aboutsummaryrefslogtreecommitdiff
path: root/crates
diff options
context:
space:
mode:
Diffstat (limited to 'crates')
-rw-r--r--crates/assists/src/handlers/infer_function_return_type.rs43
1 files changed, 31 insertions, 12 deletions
diff --git a/crates/assists/src/handlers/infer_function_return_type.rs b/crates/assists/src/handlers/infer_function_return_type.rs
index 81217378a..b4944a6b0 100644
--- a/crates/assists/src/handlers/infer_function_return_type.rs
+++ b/crates/assists/src/handlers/infer_function_return_type.rs
@@ -18,38 +18,43 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
18// ``` 18// ```
19pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 19pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
20 let expr = ctx.find_node_at_offset::<ast::Expr>()?; 20 let expr = ctx.find_node_at_offset::<ast::Expr>()?;
21 let (tail_expr, insert_pos) = extract_tail(expr)?; 21 let (tail_expr, insert_pos, wrap_expr) = extract_tail(expr)?;
22 let module = ctx.sema.scope(tail_expr.syntax()).module()?; 22 let module = ctx.sema.scope(tail_expr.syntax()).module()?;
23 let ty = ctx.sema.type_of_expr(&tail_expr).filter(|ty| !ty.is_unit())?; 23 let ty = ctx.sema.type_of_expr(&tail_expr).filter(|ty| !ty.is_unit())?;
24 let ty = ty.display_source_code(ctx.db(), module.into()).ok()?; 24 let ty = ty.display_source_code(ctx.db(), module.into()).ok()?;
25 25
26 acc.add( 26 acc.add(
27 AssistId("change_return_type_to_result", AssistKind::RefactorRewrite), 27 AssistId("infer_function_return_type", AssistKind::RefactorRewrite),
28 "Wrap return type in Result", 28 "Add this function's return type",
29 tail_expr.syntax().text_range(), 29 tail_expr.syntax().text_range(),
30 |builder| { 30 |builder| {
31 let insert_pos = insert_pos.text_range().end() + TextSize::from(1); 31 let insert_pos = insert_pos.text_range().end() + TextSize::from(1);
32 builder.insert(insert_pos, &format!("-> {} ", ty)); 32 builder.insert(insert_pos, &format!("-> {} ", ty));
33 if wrap_expr {
34 mark::hit!(wrap_closure_non_block_expr);
35 // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block
36 builder.replace(tail_expr.syntax().text_range(), &format!("{{{}}}", tail_expr));
37 }
33 }, 38 },
34 ) 39 )
35} 40}
36 41
37fn extract_tail(expr: ast::Expr) -> Option<(ast::Expr, SyntaxToken)> { 42fn extract_tail(expr: ast::Expr) -> Option<(ast::Expr, SyntaxToken, bool)> {
38 let (ret_ty, tail_expr, insert_pos) = 43 let (ret_ty, tail_expr, insert_pos, wrap_expr) =
39 if let Some(closure) = expr.syntax().ancestors().find_map(ast::ClosureExpr::cast) { 44 if let Some(closure) = expr.syntax().ancestors().find_map(ast::ClosureExpr::cast) {
40 let tail_expr = match closure.body()? { 45 let (tail_expr, wrap_expr) = match closure.body()? {
41 ast::Expr::BlockExpr(block) => block.expr()?, 46 ast::Expr::BlockExpr(block) => (block.expr()?, false),
42 body => body, 47 body => (body, true),
43 }; 48 };
44 let ret_ty = closure.ret_type(); 49 let ret_ty = closure.ret_type();
45 let rpipe = closure.param_list()?.syntax().last_token()?; 50 let rpipe = closure.param_list()?.syntax().last_token()?;
46 (ret_ty, tail_expr, rpipe) 51 (ret_ty, tail_expr, rpipe, wrap_expr)
47 } else { 52 } else {
48 let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?; 53 let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?;
49 let tail_expr = func.body()?.expr()?; 54 let tail_expr = func.body()?.expr()?;
50 let ret_ty = func.ret_type(); 55 let ret_ty = func.ret_type();
51 let rparen = func.param_list()?.r_paren_token()?; 56 let rparen = func.param_list()?.r_paren_token()?;
52 (ret_ty, tail_expr, rparen) 57 (ret_ty, tail_expr, rparen, false)
53 }; 58 };
54 if ret_ty.is_some() { 59 if ret_ty.is_some() {
55 mark::hit!(existing_ret_type); 60 mark::hit!(existing_ret_type);
@@ -61,7 +66,7 @@ fn extract_tail(expr: ast::Expr) -> Option<(ast::Expr, SyntaxToken)> {
61 mark::hit!(not_tail_expr); 66 mark::hit!(not_tail_expr);
62 return None; 67 return None;
63 } 68 }
64 Some((tail_expr, insert_pos)) 69 Some((tail_expr, insert_pos, wrap_expr))
65} 70}
66 71
67#[cfg(test)] 72#[cfg(test)]
@@ -159,10 +164,24 @@ mod tests {
159 check_assist( 164 check_assist(
160 infer_function_return_type, 165 infer_function_return_type,
161 r#"fn foo() { 166 r#"fn foo() {
167 |x: i32| { x<|> };
168 }"#,
169 r#"fn foo() {
170 |x: i32| -> i32 { x };
171 }"#,
172 );
173 }
174
175 #[test]
176 fn infer_return_type_closure_wrap() {
177 mark::check!(wrap_closure_non_block_expr);
178 check_assist(
179 infer_function_return_type,
180 r#"fn foo() {
162 |x: i32| x<|>; 181 |x: i32| x<|>;
163 }"#, 182 }"#,
164 r#"fn foo() { 183 r#"fn foo() {
165 |x: i32| -> i32 x; 184 |x: i32| -> i32 {x};
166 }"#, 185 }"#,
167 ); 186 );
168 } 187 }