diff options
-rw-r--r-- | crates/assists/src/handlers/infer_function_return_type.rs | 24 |
1 files changed, 16 insertions, 8 deletions
diff --git a/crates/assists/src/handlers/infer_function_return_type.rs b/crates/assists/src/handlers/infer_function_return_type.rs index 520d07ae0..aa584eb03 100644 --- a/crates/assists/src/handlers/infer_function_return_type.rs +++ b/crates/assists/src/handlers/infer_function_return_type.rs | |||
@@ -17,7 +17,7 @@ use crate::{AssistContext, AssistId, AssistKind, Assists}; | |||
17 | // fn foo() -> i32 { 42i32 } | 17 | // fn foo() -> i32 { 42i32 } |
18 | // ``` | 18 | // ``` |
19 | pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { | 19 | pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { |
20 | let (tail_expr, builder_edit_pos, wrap_expr) = extract_tail(ctx)?; | 20 | let (fn_type, tail_expr, builder_edit_pos) = extract_tail(ctx)?; |
21 | let module = ctx.sema.scope(tail_expr.syntax()).module()?; | 21 | let module = ctx.sema.scope(tail_expr.syntax()).module()?; |
22 | let ty = ctx.sema.type_of_expr(&tail_expr)?; | 22 | let ty = ctx.sema.type_of_expr(&tail_expr)?; |
23 | if ty.is_unit() { | 23 | if ty.is_unit() { |
@@ -27,7 +27,10 @@ pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) | |||
27 | 27 | ||
28 | acc.add( | 28 | acc.add( |
29 | AssistId("infer_function_return_type", AssistKind::RefactorRewrite), | 29 | AssistId("infer_function_return_type", AssistKind::RefactorRewrite), |
30 | "Add this function's return type", | 30 | match fn_type { |
31 | FnType::Function => "Add this function's return type", | ||
32 | FnType::Closure { .. } => "Add this closure's return type", | ||
33 | }, | ||
31 | tail_expr.syntax().text_range(), | 34 | tail_expr.syntax().text_range(), |
32 | |builder| { | 35 | |builder| { |
33 | match builder_edit_pos { | 36 | match builder_edit_pos { |
@@ -38,7 +41,7 @@ pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) | |||
38 | builder.replace(text_range, &format!("-> {}", ty)) | 41 | builder.replace(text_range, &format!("-> {}", ty)) |
39 | } | 42 | } |
40 | } | 43 | } |
41 | if wrap_expr { | 44 | if let FnType::Closure { wrap_expr: true } = fn_type { |
42 | mark::hit!(wrap_closure_non_block_expr); | 45 | mark::hit!(wrap_closure_non_block_expr); |
43 | // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block | 46 | // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block |
44 | builder.replace(tail_expr.syntax().text_range(), &format!("{{{}}}", tail_expr)); | 47 | builder.replace(tail_expr.syntax().text_range(), &format!("{{{}}}", tail_expr)); |
@@ -72,8 +75,13 @@ fn ret_ty_to_action(ret_ty: Option<ast::RetType>, insert_pos: TextSize) -> Optio | |||
72 | } | 75 | } |
73 | } | 76 | } |
74 | 77 | ||
75 | fn extract_tail(ctx: &AssistContext) -> Option<(ast::Expr, InsertOrReplace, bool)> { | 78 | enum FnType { |
76 | let (tail_expr, return_type_range, action, wrap_expr) = | 79 | Function, |
80 | Closure { wrap_expr: bool }, | ||
81 | } | ||
82 | |||
83 | fn extract_tail(ctx: &AssistContext) -> Option<(FnType, ast::Expr, InsertOrReplace)> { | ||
84 | let (fn_type, tail_expr, return_type_range, action) = | ||
77 | if let Some(closure) = ctx.find_node_at_offset::<ast::ClosureExpr>() { | 85 | if let Some(closure) = ctx.find_node_at_offset::<ast::ClosureExpr>() { |
78 | let rpipe_pos = closure.param_list()?.syntax().last_token()?.text_range().end(); | 86 | let rpipe_pos = closure.param_list()?.syntax().last_token()?.text_range().end(); |
79 | let action = ret_ty_to_action(closure.ret_type(), rpipe_pos)?; | 87 | let action = ret_ty_to_action(closure.ret_type(), rpipe_pos)?; |
@@ -86,7 +94,7 @@ fn extract_tail(ctx: &AssistContext) -> Option<(ast::Expr, InsertOrReplace, bool | |||
86 | }; | 94 | }; |
87 | 95 | ||
88 | let ret_range = TextRange::new(rpipe_pos, body_start); | 96 | let ret_range = TextRange::new(rpipe_pos, body_start); |
89 | (tail_expr, ret_range, action, wrap_expr) | 97 | (FnType::Closure { wrap_expr }, tail_expr, ret_range, action) |
90 | } else { | 98 | } else { |
91 | let func = ctx.find_node_at_offset::<ast::Fn>()?; | 99 | let func = ctx.find_node_at_offset::<ast::Fn>()?; |
92 | let rparen_pos = func.param_list()?.r_paren_token()?.text_range().end(); | 100 | let rparen_pos = func.param_list()?.r_paren_token()?.text_range().end(); |
@@ -97,7 +105,7 @@ fn extract_tail(ctx: &AssistContext) -> Option<(ast::Expr, InsertOrReplace, bool | |||
97 | 105 | ||
98 | let ret_range_end = body.l_curly_token()?.text_range().start(); | 106 | let ret_range_end = body.l_curly_token()?.text_range().start(); |
99 | let ret_range = TextRange::new(rparen_pos, ret_range_end); | 107 | let ret_range = TextRange::new(rparen_pos, ret_range_end); |
100 | (tail_expr, ret_range, action, false) | 108 | (FnType::Function, tail_expr, ret_range, action) |
101 | }; | 109 | }; |
102 | let frange = ctx.frange.range; | 110 | let frange = ctx.frange.range; |
103 | if return_type_range.contains_range(frange) { | 111 | if return_type_range.contains_range(frange) { |
@@ -109,7 +117,7 @@ fn extract_tail(ctx: &AssistContext) -> Option<(ast::Expr, InsertOrReplace, bool | |||
109 | } else { | 117 | } else { |
110 | return None; | 118 | return None; |
111 | } | 119 | } |
112 | Some((tail_expr, action, wrap_expr)) | 120 | Some((fn_type, tail_expr, action)) |
113 | } | 121 | } |
114 | 122 | ||
115 | #[cfg(test)] | 123 | #[cfg(test)] |