diff options
-rw-r--r-- | crates/assists/src/handlers/infer_function_return_type.rs | 43 |
1 files changed, 31 insertions, 12 deletions
diff --git a/crates/assists/src/handlers/infer_function_return_type.rs b/crates/assists/src/handlers/infer_function_return_type.rs index 81217378a..b4944a6b0 100644 --- a/crates/assists/src/handlers/infer_function_return_type.rs +++ b/crates/assists/src/handlers/infer_function_return_type.rs | |||
@@ -18,38 +18,43 @@ use crate::{AssistContext, AssistId, AssistKind, Assists}; | |||
18 | // ``` | 18 | // ``` |
19 | pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { | 19 | pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { |
20 | let expr = ctx.find_node_at_offset::<ast::Expr>()?; | 20 | let expr = ctx.find_node_at_offset::<ast::Expr>()?; |
21 | let (tail_expr, insert_pos) = extract_tail(expr)?; | 21 | let (tail_expr, insert_pos, wrap_expr) = extract_tail(expr)?; |
22 | let module = ctx.sema.scope(tail_expr.syntax()).module()?; | 22 | let module = ctx.sema.scope(tail_expr.syntax()).module()?; |
23 | let ty = ctx.sema.type_of_expr(&tail_expr).filter(|ty| !ty.is_unit())?; | 23 | let ty = ctx.sema.type_of_expr(&tail_expr).filter(|ty| !ty.is_unit())?; |
24 | let ty = ty.display_source_code(ctx.db(), module.into()).ok()?; | 24 | let ty = ty.display_source_code(ctx.db(), module.into()).ok()?; |
25 | 25 | ||
26 | acc.add( | 26 | acc.add( |
27 | AssistId("change_return_type_to_result", AssistKind::RefactorRewrite), | 27 | AssistId("infer_function_return_type", AssistKind::RefactorRewrite), |
28 | "Wrap return type in Result", | 28 | "Add this function's return type", |
29 | tail_expr.syntax().text_range(), | 29 | tail_expr.syntax().text_range(), |
30 | |builder| { | 30 | |builder| { |
31 | let insert_pos = insert_pos.text_range().end() + TextSize::from(1); | 31 | let insert_pos = insert_pos.text_range().end() + TextSize::from(1); |
32 | builder.insert(insert_pos, &format!("-> {} ", ty)); | 32 | builder.insert(insert_pos, &format!("-> {} ", ty)); |
33 | if wrap_expr { | ||
34 | mark::hit!(wrap_closure_non_block_expr); | ||
35 | // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block | ||
36 | builder.replace(tail_expr.syntax().text_range(), &format!("{{{}}}", tail_expr)); | ||
37 | } | ||
33 | }, | 38 | }, |
34 | ) | 39 | ) |
35 | } | 40 | } |
36 | 41 | ||
37 | fn extract_tail(expr: ast::Expr) -> Option<(ast::Expr, SyntaxToken)> { | 42 | fn extract_tail(expr: ast::Expr) -> Option<(ast::Expr, SyntaxToken, bool)> { |
38 | let (ret_ty, tail_expr, insert_pos) = | 43 | let (ret_ty, tail_expr, insert_pos, wrap_expr) = |
39 | if let Some(closure) = expr.syntax().ancestors().find_map(ast::ClosureExpr::cast) { | 44 | if let Some(closure) = expr.syntax().ancestors().find_map(ast::ClosureExpr::cast) { |
40 | let tail_expr = match closure.body()? { | 45 | let (tail_expr, wrap_expr) = match closure.body()? { |
41 | ast::Expr::BlockExpr(block) => block.expr()?, | 46 | ast::Expr::BlockExpr(block) => (block.expr()?, false), |
42 | body => body, | 47 | body => (body, true), |
43 | }; | 48 | }; |
44 | let ret_ty = closure.ret_type(); | 49 | let ret_ty = closure.ret_type(); |
45 | let rpipe = closure.param_list()?.syntax().last_token()?; | 50 | let rpipe = closure.param_list()?.syntax().last_token()?; |
46 | (ret_ty, tail_expr, rpipe) | 51 | (ret_ty, tail_expr, rpipe, wrap_expr) |
47 | } else { | 52 | } else { |
48 | let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?; | 53 | let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?; |
49 | let tail_expr = func.body()?.expr()?; | 54 | let tail_expr = func.body()?.expr()?; |
50 | let ret_ty = func.ret_type(); | 55 | let ret_ty = func.ret_type(); |
51 | let rparen = func.param_list()?.r_paren_token()?; | 56 | let rparen = func.param_list()?.r_paren_token()?; |
52 | (ret_ty, tail_expr, rparen) | 57 | (ret_ty, tail_expr, rparen, false) |
53 | }; | 58 | }; |
54 | if ret_ty.is_some() { | 59 | if ret_ty.is_some() { |
55 | mark::hit!(existing_ret_type); | 60 | mark::hit!(existing_ret_type); |
@@ -61,7 +66,7 @@ fn extract_tail(expr: ast::Expr) -> Option<(ast::Expr, SyntaxToken)> { | |||
61 | mark::hit!(not_tail_expr); | 66 | mark::hit!(not_tail_expr); |
62 | return None; | 67 | return None; |
63 | } | 68 | } |
64 | Some((tail_expr, insert_pos)) | 69 | Some((tail_expr, insert_pos, wrap_expr)) |
65 | } | 70 | } |
66 | 71 | ||
67 | #[cfg(test)] | 72 | #[cfg(test)] |
@@ -159,10 +164,24 @@ mod tests { | |||
159 | check_assist( | 164 | check_assist( |
160 | infer_function_return_type, | 165 | infer_function_return_type, |
161 | r#"fn foo() { | 166 | r#"fn foo() { |
167 | |x: i32| { x<|> }; | ||
168 | }"#, | ||
169 | r#"fn foo() { | ||
170 | |x: i32| -> i32 { x }; | ||
171 | }"#, | ||
172 | ); | ||
173 | } | ||
174 | |||
175 | #[test] | ||
176 | fn infer_return_type_closure_wrap() { | ||
177 | mark::check!(wrap_closure_non_block_expr); | ||
178 | check_assist( | ||
179 | infer_function_return_type, | ||
180 | r#"fn foo() { | ||
162 | |x: i32| x<|>; | 181 | |x: i32| x<|>; |
163 | }"#, | 182 | }"#, |
164 | r#"fn foo() { | 183 | r#"fn foo() { |
165 | |x: i32| -> i32 x; | 184 | |x: i32| -> i32 {x}; |
166 | }"#, | 185 | }"#, |
167 | ); | 186 | ); |
168 | } | 187 | } |