diff options
-rw-r--r-- | crates/assists/src/handlers/infer_function_return_type.rs | 158 |
1 files changed, 129 insertions, 29 deletions
diff --git a/crates/assists/src/handlers/infer_function_return_type.rs b/crates/assists/src/handlers/infer_function_return_type.rs index b4944a6b0..80864c530 100644 --- a/crates/assists/src/handlers/infer_function_return_type.rs +++ b/crates/assists/src/handlers/infer_function_return_type.rs | |||
@@ -1,5 +1,5 @@ | |||
1 | use hir::HirDisplay; | 1 | use hir::HirDisplay; |
2 | use syntax::{ast, AstNode, SyntaxToken, TextSize}; | 2 | use syntax::{ast, AstNode, TextRange, TextSize}; |
3 | use test_utils::mark; | 3 | use test_utils::mark; |
4 | 4 | ||
5 | use crate::{AssistContext, AssistId, AssistKind, Assists}; | 5 | use crate::{AssistContext, AssistId, AssistKind, Assists}; |
@@ -7,7 +7,7 @@ use crate::{AssistContext, AssistId, AssistKind, Assists}; | |||
7 | // Assist: infer_function_return_type | 7 | // Assist: infer_function_return_type |
8 | // | 8 | // |
9 | // Adds the return type to a function or closure inferred from its tail expression if it doesn't have a return | 9 | // Adds the return type to a function or closure inferred from its tail expression if it doesn't have a return |
10 | // type specified. | 10 | // type specified. This assists is useable in a functions or closures tail expression or return type position. |
11 | // | 11 | // |
12 | // ``` | 12 | // ``` |
13 | // fn foo() { 4<|>2i32 } | 13 | // fn foo() { 4<|>2i32 } |
@@ -17,10 +17,12 @@ use crate::{AssistContext, AssistId, AssistKind, Assists}; | |||
17 | // fn foo() -> i32 { 42i32 } | 17 | // fn foo() -> i32 { 42i32 } |
18 | // ``` | 18 | // ``` |
19 | pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { | 19 | pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { |
20 | let expr = ctx.find_node_at_offset::<ast::Expr>()?; | 20 | let (tail_expr, builder_edit_pos, wrap_expr) = extract_tail(ctx)?; |
21 | let (tail_expr, insert_pos, wrap_expr) = extract_tail(expr)?; | ||
22 | let module = ctx.sema.scope(tail_expr.syntax()).module()?; | 21 | let module = ctx.sema.scope(tail_expr.syntax()).module()?; |
23 | let ty = ctx.sema.type_of_expr(&tail_expr).filter(|ty| !ty.is_unit())?; | 22 | let ty = ctx.sema.type_of_expr(&tail_expr)?; |
23 | if ty.is_unit() { | ||
24 | return None; | ||
25 | } | ||
24 | let ty = ty.display_source_code(ctx.db(), module.into()).ok()?; | 26 | let ty = ty.display_source_code(ctx.db(), module.into()).ok()?; |
25 | 27 | ||
26 | acc.add( | 28 | acc.add( |
@@ -28,8 +30,14 @@ pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) | |||
28 | "Add this function's return type", | 30 | "Add this function's return type", |
29 | tail_expr.syntax().text_range(), | 31 | tail_expr.syntax().text_range(), |
30 | |builder| { | 32 | |builder| { |
31 | let insert_pos = insert_pos.text_range().end() + TextSize::from(1); | 33 | match builder_edit_pos { |
32 | builder.insert(insert_pos, &format!("-> {} ", ty)); | 34 | InsertOrReplace::Insert(insert_pos) => { |
35 | builder.insert(insert_pos, &format!("-> {} ", ty)) | ||
36 | } | ||
37 | InsertOrReplace::Replace(text_range) => { | ||
38 | builder.replace(text_range, &format!("-> {}", ty)) | ||
39 | } | ||
40 | } | ||
33 | if wrap_expr { | 41 | if wrap_expr { |
34 | mark::hit!(wrap_closure_non_block_expr); | 42 | mark::hit!(wrap_closure_non_block_expr); |
35 | // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block | 43 | // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block |
@@ -39,34 +47,69 @@ pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) | |||
39 | ) | 47 | ) |
40 | } | 48 | } |
41 | 49 | ||
42 | fn extract_tail(expr: ast::Expr) -> Option<(ast::Expr, SyntaxToken, bool)> { | 50 | enum InsertOrReplace { |
43 | let (ret_ty, tail_expr, insert_pos, wrap_expr) = | 51 | Insert(TextSize), |
44 | if let Some(closure) = expr.syntax().ancestors().find_map(ast::ClosureExpr::cast) { | 52 | Replace(TextRange), |
45 | let (tail_expr, wrap_expr) = match closure.body()? { | 53 | } |
54 | |||
55 | /// Check the potentially already specified return type and reject it or turn it into a builder command | ||
56 | /// if allowed. | ||
57 | fn ret_ty_to_action(ret_ty: Option<ast::RetType>, insert_pos: TextSize) -> Option<InsertOrReplace> { | ||
58 | match ret_ty { | ||
59 | Some(ret_ty) => match ret_ty.ty() { | ||
60 | Some(ast::Type::InferType(_)) | None => { | ||
61 | mark::hit!(existing_infer_ret_type); | ||
62 | mark::hit!(existing_infer_ret_type_closure); | ||
63 | Some(InsertOrReplace::Replace(ret_ty.syntax().text_range())) | ||
64 | } | ||
65 | _ => { | ||
66 | mark::hit!(existing_ret_type); | ||
67 | mark::hit!(existing_ret_type_closure); | ||
68 | None | ||
69 | } | ||
70 | }, | ||
71 | None => Some(InsertOrReplace::Insert(insert_pos + TextSize::from(1))), | ||
72 | } | ||
73 | } | ||
74 | |||
75 | fn extract_tail(ctx: &AssistContext) -> Option<(ast::Expr, InsertOrReplace, bool)> { | ||
76 | let (tail_expr, return_type_range, action, wrap_expr) = | ||
77 | if let Some(closure) = ctx.find_node_at_offset::<ast::ClosureExpr>() { | ||
78 | let rpipe_pos = closure.param_list()?.syntax().last_token()?.text_range().end(); | ||
79 | let action = ret_ty_to_action(closure.ret_type(), rpipe_pos)?; | ||
80 | |||
81 | let body = closure.body()?; | ||
82 | let body_start = body.syntax().first_token()?.text_range().start(); | ||
83 | let (tail_expr, wrap_expr) = match body { | ||
46 | ast::Expr::BlockExpr(block) => (block.expr()?, false), | 84 | ast::Expr::BlockExpr(block) => (block.expr()?, false), |
47 | body => (body, true), | 85 | body => (body, true), |
48 | }; | 86 | }; |
49 | let ret_ty = closure.ret_type(); | 87 | |
50 | let rpipe = closure.param_list()?.syntax().last_token()?; | 88 | let ret_range = TextRange::new(rpipe_pos, body_start); |
51 | (ret_ty, tail_expr, rpipe, wrap_expr) | 89 | (tail_expr, ret_range, action, wrap_expr) |
52 | } else { | 90 | } else { |
53 | let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?; | 91 | let func = ctx.find_node_at_offset::<ast::Fn>()?; |
54 | let tail_expr = func.body()?.expr()?; | 92 | let rparen_pos = func.param_list()?.r_paren_token()?.text_range().end(); |
55 | let ret_ty = func.ret_type(); | 93 | let action = ret_ty_to_action(func.ret_type(), rparen_pos)?; |
56 | let rparen = func.param_list()?.r_paren_token()?; | 94 | |
57 | (ret_ty, tail_expr, rparen, false) | 95 | let body = func.body()?; |
96 | let tail_expr = body.expr()?; | ||
97 | |||
98 | let ret_range_end = body.l_curly_token()?.text_range().start(); | ||
99 | let ret_range = TextRange::new(rparen_pos, ret_range_end); | ||
100 | (tail_expr, ret_range, action, false) | ||
58 | }; | 101 | }; |
59 | if ret_ty.is_some() { | 102 | let frange = ctx.frange.range; |
60 | mark::hit!(existing_ret_type); | 103 | if return_type_range.contains_range(frange) { |
61 | mark::hit!(existing_ret_type_closure); | 104 | mark::hit!(cursor_in_ret_position); |
62 | return None; | 105 | mark::hit!(cursor_in_ret_position_closure); |
63 | } | 106 | } else if tail_expr.syntax().text_range().contains_range(frange) { |
64 | // check whether the expr we were at is indeed the tail expression | 107 | mark::hit!(cursor_on_tail); |
65 | if !tail_expr.syntax().text_range().contains_range(expr.syntax().text_range()) { | 108 | mark::hit!(cursor_on_tail_closure); |
66 | mark::hit!(not_tail_expr); | 109 | } else { |
67 | return None; | 110 | return None; |
68 | } | 111 | } |
69 | Some((tail_expr, insert_pos, wrap_expr)) | 112 | Some((tail_expr, action, wrap_expr)) |
70 | } | 113 | } |
71 | 114 | ||
72 | #[cfg(test)] | 115 | #[cfg(test)] |
@@ -76,7 +119,64 @@ mod tests { | |||
76 | use super::*; | 119 | use super::*; |
77 | 120 | ||
78 | #[test] | 121 | #[test] |
122 | fn infer_return_type_specified_inferred() { | ||
123 | mark::check!(existing_infer_ret_type); | ||
124 | check_assist( | ||
125 | infer_function_return_type, | ||
126 | r#"fn foo() -> <|>_ { | ||
127 | 45 | ||
128 | }"#, | ||
129 | r#"fn foo() -> i32 { | ||
130 | 45 | ||
131 | }"#, | ||
132 | ); | ||
133 | } | ||
134 | |||
135 | #[test] | ||
136 | fn infer_return_type_specified_inferred_closure() { | ||
137 | mark::check!(existing_infer_ret_type_closure); | ||
138 | check_assist( | ||
139 | infer_function_return_type, | ||
140 | r#"fn foo() { | ||
141 | || -> _ {<|>45}; | ||
142 | }"#, | ||
143 | r#"fn foo() { | ||
144 | || -> i32 {45}; | ||
145 | }"#, | ||
146 | ); | ||
147 | } | ||
148 | |||
149 | #[test] | ||
150 | fn infer_return_type_cursor_at_return_type_pos() { | ||
151 | mark::check!(cursor_in_ret_position); | ||
152 | check_assist( | ||
153 | infer_function_return_type, | ||
154 | r#"fn foo() <|>{ | ||
155 | 45 | ||
156 | }"#, | ||
157 | r#"fn foo() -> i32 { | ||
158 | 45 | ||
159 | }"#, | ||
160 | ); | ||
161 | } | ||
162 | |||
163 | #[test] | ||
164 | fn infer_return_type_cursor_at_return_type_pos_closure() { | ||
165 | mark::check!(cursor_in_ret_position_closure); | ||
166 | check_assist( | ||
167 | infer_function_return_type, | ||
168 | r#"fn foo() { | ||
169 | || <|>45 | ||
170 | }"#, | ||
171 | r#"fn foo() { | ||
172 | || -> i32 {45} | ||
173 | }"#, | ||
174 | ); | ||
175 | } | ||
176 | |||
177 | #[test] | ||
79 | fn infer_return_type() { | 178 | fn infer_return_type() { |
179 | mark::check!(cursor_on_tail); | ||
80 | check_assist( | 180 | check_assist( |
81 | infer_function_return_type, | 181 | infer_function_return_type, |
82 | r#"fn foo() { | 182 | r#"fn foo() { |
@@ -122,7 +222,6 @@ mod tests { | |||
122 | 222 | ||
123 | #[test] | 223 | #[test] |
124 | fn not_applicable_non_tail_expr() { | 224 | fn not_applicable_non_tail_expr() { |
125 | mark::check!(not_tail_expr); | ||
126 | check_assist_not_applicable( | 225 | check_assist_not_applicable( |
127 | infer_function_return_type, | 226 | infer_function_return_type, |
128 | r#"fn foo() { | 227 | r#"fn foo() { |
@@ -144,6 +243,7 @@ mod tests { | |||
144 | 243 | ||
145 | #[test] | 244 | #[test] |
146 | fn infer_return_type_closure_block() { | 245 | fn infer_return_type_closure_block() { |
246 | mark::check!(cursor_on_tail_closure); | ||
147 | check_assist( | 247 | check_assist( |
148 | infer_function_return_type, | 248 | infer_function_return_type, |
149 | r#"fn foo() { | 249 | r#"fn foo() { |