aboutsummaryrefslogtreecommitdiff
path: root/crates
diff options
context:
space:
mode:
Diffstat (limited to 'crates')
-rw-r--r--crates/assists/src/handlers/infer_function_return_type.rs24
1 files changed, 16 insertions, 8 deletions
diff --git a/crates/assists/src/handlers/infer_function_return_type.rs b/crates/assists/src/handlers/infer_function_return_type.rs
index 520d07ae0..aa584eb03 100644
--- a/crates/assists/src/handlers/infer_function_return_type.rs
+++ b/crates/assists/src/handlers/infer_function_return_type.rs
@@ -17,7 +17,7 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
17// fn foo() -> i32 { 42i32 } 17// fn foo() -> i32 { 42i32 }
18// ``` 18// ```
19pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 19pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
20 let (tail_expr, builder_edit_pos, wrap_expr) = extract_tail(ctx)?; 20 let (fn_type, tail_expr, builder_edit_pos) = extract_tail(ctx)?;
21 let module = ctx.sema.scope(tail_expr.syntax()).module()?; 21 let module = ctx.sema.scope(tail_expr.syntax()).module()?;
22 let ty = ctx.sema.type_of_expr(&tail_expr)?; 22 let ty = ctx.sema.type_of_expr(&tail_expr)?;
23 if ty.is_unit() { 23 if ty.is_unit() {
@@ -27,7 +27,10 @@ pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext)
27 27
28 acc.add( 28 acc.add(
29 AssistId("infer_function_return_type", AssistKind::RefactorRewrite), 29 AssistId("infer_function_return_type", AssistKind::RefactorRewrite),
30 "Add this function's return type", 30 match fn_type {
31 FnType::Function => "Add this function's return type",
32 FnType::Closure { .. } => "Add this closure's return type",
33 },
31 tail_expr.syntax().text_range(), 34 tail_expr.syntax().text_range(),
32 |builder| { 35 |builder| {
33 match builder_edit_pos { 36 match builder_edit_pos {
@@ -38,7 +41,7 @@ pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext)
38 builder.replace(text_range, &format!("-> {}", ty)) 41 builder.replace(text_range, &format!("-> {}", ty))
39 } 42 }
40 } 43 }
41 if wrap_expr { 44 if let FnType::Closure { wrap_expr: true } = fn_type {
42 mark::hit!(wrap_closure_non_block_expr); 45 mark::hit!(wrap_closure_non_block_expr);
43 // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block 46 // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block
44 builder.replace(tail_expr.syntax().text_range(), &format!("{{{}}}", tail_expr)); 47 builder.replace(tail_expr.syntax().text_range(), &format!("{{{}}}", tail_expr));
@@ -72,8 +75,13 @@ fn ret_ty_to_action(ret_ty: Option<ast::RetType>, insert_pos: TextSize) -> Optio
72 } 75 }
73} 76}
74 77
75fn extract_tail(ctx: &AssistContext) -> Option<(ast::Expr, InsertOrReplace, bool)> { 78enum FnType {
76 let (tail_expr, return_type_range, action, wrap_expr) = 79 Function,
80 Closure { wrap_expr: bool },
81}
82
83fn extract_tail(ctx: &AssistContext) -> Option<(FnType, ast::Expr, InsertOrReplace)> {
84 let (fn_type, tail_expr, return_type_range, action) =
77 if let Some(closure) = ctx.find_node_at_offset::<ast::ClosureExpr>() { 85 if let Some(closure) = ctx.find_node_at_offset::<ast::ClosureExpr>() {
78 let rpipe_pos = closure.param_list()?.syntax().last_token()?.text_range().end(); 86 let rpipe_pos = closure.param_list()?.syntax().last_token()?.text_range().end();
79 let action = ret_ty_to_action(closure.ret_type(), rpipe_pos)?; 87 let action = ret_ty_to_action(closure.ret_type(), rpipe_pos)?;
@@ -86,7 +94,7 @@ fn extract_tail(ctx: &AssistContext) -> Option<(ast::Expr, InsertOrReplace, bool
86 }; 94 };
87 95
88 let ret_range = TextRange::new(rpipe_pos, body_start); 96 let ret_range = TextRange::new(rpipe_pos, body_start);
89 (tail_expr, ret_range, action, wrap_expr) 97 (FnType::Closure { wrap_expr }, tail_expr, ret_range, action)
90 } else { 98 } else {
91 let func = ctx.find_node_at_offset::<ast::Fn>()?; 99 let func = ctx.find_node_at_offset::<ast::Fn>()?;
92 let rparen_pos = func.param_list()?.r_paren_token()?.text_range().end(); 100 let rparen_pos = func.param_list()?.r_paren_token()?.text_range().end();
@@ -97,7 +105,7 @@ fn extract_tail(ctx: &AssistContext) -> Option<(ast::Expr, InsertOrReplace, bool
97 105
98 let ret_range_end = body.l_curly_token()?.text_range().start(); 106 let ret_range_end = body.l_curly_token()?.text_range().start();
99 let ret_range = TextRange::new(rparen_pos, ret_range_end); 107 let ret_range = TextRange::new(rparen_pos, ret_range_end);
100 (tail_expr, ret_range, action, false) 108 (FnType::Function, tail_expr, ret_range, action)
101 }; 109 };
102 let frange = ctx.frange.range; 110 let frange = ctx.frange.range;
103 if return_type_range.contains_range(frange) { 111 if return_type_range.contains_range(frange) {
@@ -109,7 +117,7 @@ fn extract_tail(ctx: &AssistContext) -> Option<(ast::Expr, InsertOrReplace, bool
109 } else { 117 } else {
110 return None; 118 return None;
111 } 119 }
112 Some((tail_expr, action, wrap_expr)) 120 Some((fn_type, tail_expr, action))
113} 121}
114 122
115#[cfg(test)] 123#[cfg(test)]