From c66588447440b4c1d32c75dd307dc752c83550e4 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Fri, 6 Nov 2020 03:06:08 +0100 Subject: Wrap non-block expressions in closures with a block --- .../src/handlers/infer_function_return_type.rs | 43 ++++++++++++++++------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/crates/assists/src/handlers/infer_function_return_type.rs b/crates/assists/src/handlers/infer_function_return_type.rs index 81217378a..b4944a6b0 100644 --- a/crates/assists/src/handlers/infer_function_return_type.rs +++ b/crates/assists/src/handlers/infer_function_return_type.rs @@ -18,38 +18,43 @@ use crate::{AssistContext, AssistId, AssistKind, Assists}; // ``` pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { let expr = ctx.find_node_at_offset::()?; - let (tail_expr, insert_pos) = extract_tail(expr)?; + let (tail_expr, insert_pos, wrap_expr) = extract_tail(expr)?; let module = ctx.sema.scope(tail_expr.syntax()).module()?; let ty = ctx.sema.type_of_expr(&tail_expr).filter(|ty| !ty.is_unit())?; let ty = ty.display_source_code(ctx.db(), module.into()).ok()?; acc.add( - AssistId("change_return_type_to_result", AssistKind::RefactorRewrite), - "Wrap return type in Result", + AssistId("infer_function_return_type", AssistKind::RefactorRewrite), + "Add this function's return type", tail_expr.syntax().text_range(), |builder| { let insert_pos = insert_pos.text_range().end() + TextSize::from(1); builder.insert(insert_pos, &format!("-> {} ", ty)); + if wrap_expr { + mark::hit!(wrap_closure_non_block_expr); + // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block + builder.replace(tail_expr.syntax().text_range(), &format!("{{{}}}", tail_expr)); + } }, ) } -fn extract_tail(expr: ast::Expr) -> Option<(ast::Expr, SyntaxToken)> { - let (ret_ty, tail_expr, insert_pos) = +fn extract_tail(expr: ast::Expr) -> Option<(ast::Expr, SyntaxToken, bool)> { + let (ret_ty, tail_expr, insert_pos, wrap_expr) = if let Some(closure) = expr.syntax().ancestors().find_map(ast::ClosureExpr::cast) { - let tail_expr = match closure.body()? { - ast::Expr::BlockExpr(block) => block.expr()?, - body => body, + let (tail_expr, wrap_expr) = match closure.body()? { + ast::Expr::BlockExpr(block) => (block.expr()?, false), + body => (body, true), }; let ret_ty = closure.ret_type(); let rpipe = closure.param_list()?.syntax().last_token()?; - (ret_ty, tail_expr, rpipe) + (ret_ty, tail_expr, rpipe, wrap_expr) } else { let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?; let tail_expr = func.body()?.expr()?; let ret_ty = func.ret_type(); let rparen = func.param_list()?.r_paren_token()?; - (ret_ty, tail_expr, rparen) + (ret_ty, tail_expr, rparen, false) }; if ret_ty.is_some() { mark::hit!(existing_ret_type); @@ -61,7 +66,7 @@ fn extract_tail(expr: ast::Expr) -> Option<(ast::Expr, SyntaxToken)> { mark::hit!(not_tail_expr); return None; } - Some((tail_expr, insert_pos)) + Some((tail_expr, insert_pos, wrap_expr)) } #[cfg(test)] @@ -156,13 +161,27 @@ mod tests { #[test] fn infer_return_type_closure() { + check_assist( + infer_function_return_type, + r#"fn foo() { + |x: i32| { x<|> }; + }"#, + r#"fn foo() { + |x: i32| -> i32 { x }; + }"#, + ); + } + + #[test] + fn infer_return_type_closure_wrap() { + mark::check!(wrap_closure_non_block_expr); check_assist( infer_function_return_type, r#"fn foo() { |x: i32| x<|>; }"#, r#"fn foo() { - |x: i32| -> i32 x; + |x: i32| -> i32 {x}; }"#, ); } -- cgit v1.2.3