diff options
Diffstat (limited to 'crates')
-rw-r--r-- | crates/assists/src/handlers/infer_function_return_type.rs | 133 |
1 files changed, 114 insertions, 19 deletions
diff --git a/crates/assists/src/handlers/infer_function_return_type.rs b/crates/assists/src/handlers/infer_function_return_type.rs index da60ff9de..f363a56f3 100644 --- a/crates/assists/src/handlers/infer_function_return_type.rs +++ b/crates/assists/src/handlers/infer_function_return_type.rs | |||
@@ -1,12 +1,12 @@ | |||
1 | use hir::HirDisplay; | 1 | use hir::HirDisplay; |
2 | use syntax::{ast, AstNode, TextSize}; | 2 | use syntax::{ast, AstNode, SyntaxToken, TextSize}; |
3 | use test_utils::mark; | 3 | use test_utils::mark; |
4 | 4 | ||
5 | use crate::{AssistContext, AssistId, AssistKind, Assists}; | 5 | use crate::{AssistContext, AssistId, AssistKind, Assists}; |
6 | 6 | ||
7 | // Assist: infer_function_return_type | 7 | // Assist: infer_function_return_type |
8 | // | 8 | // |
9 | // Adds the return type to a function inferred from its tail expression if it doesn't have a return | 9 | // Adds the return type to a function or closure inferred from its tail expression if it doesn't have a return |
10 | // type specified. | 10 | // type specified. |
11 | // | 11 | // |
12 | // ``` | 12 | // ``` |
@@ -18,36 +18,52 @@ use crate::{AssistContext, AssistId, AssistKind, Assists}; | |||
18 | // ``` | 18 | // ``` |
19 | pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { | 19 | pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { |
20 | let expr = ctx.find_node_at_offset::<ast::Expr>()?; | 20 | let expr = ctx.find_node_at_offset::<ast::Expr>()?; |
21 | let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?; | 21 | let (tail_expr, insert_pos) = extract(expr)?; |
22 | 22 | let module = ctx.sema.scope(tail_expr.syntax()).module()?; | |
23 | if func.ret_type().is_some() { | ||
24 | mark::hit!(existing_ret_type); | ||
25 | return None; | ||
26 | } | ||
27 | let body = func.body()?; | ||
28 | let tail_expr = body.expr()?; | ||
29 | // check whether the expr we were at is indeed the tail expression | ||
30 | if !tail_expr.syntax().text_range().contains_range(expr.syntax().text_range()) { | ||
31 | mark::hit!(not_tail_expr); | ||
32 | return None; | ||
33 | } | ||
34 | let module = ctx.sema.scope(func.syntax()).module()?; | ||
35 | let ty = ctx.sema.type_of_expr(&tail_expr)?; | 23 | let ty = ctx.sema.type_of_expr(&tail_expr)?; |
36 | let ty = ty.display_source_code(ctx.db(), module.into()).ok()?; | 24 | let ty = ty.display_source_code(ctx.db(), module.into()).ok()?; |
37 | let rparen = func.param_list()?.r_paren_token()?; | ||
38 | 25 | ||
39 | acc.add( | 26 | acc.add( |
40 | AssistId("change_return_type_to_result", AssistKind::RefactorRewrite), | 27 | AssistId("change_return_type_to_result", AssistKind::RefactorRewrite), |
41 | "Wrap return type in Result", | 28 | "Wrap return type in Result", |
42 | tail_expr.syntax().text_range(), | 29 | tail_expr.syntax().text_range(), |
43 | |builder| { | 30 | |builder| { |
44 | let insert_pos = rparen.text_range().end() + TextSize::from(1); | 31 | let insert_pos = insert_pos.text_range().end() + TextSize::from(1); |
45 | |||
46 | builder.insert(insert_pos, &format!("-> {} ", ty)); | 32 | builder.insert(insert_pos, &format!("-> {} ", ty)); |
47 | }, | 33 | }, |
48 | ) | 34 | ) |
49 | } | 35 | } |
50 | 36 | ||
37 | fn extract(expr: ast::Expr) -> Option<(ast::Expr, SyntaxToken)> { | ||
38 | let (ret_ty, tail_expr, insert_pos) = | ||
39 | if let Some(closure) = expr.syntax().ancestors().find_map(ast::ClosureExpr::cast) { | ||
40 | let tail_expr = match closure.body()? { | ||
41 | ast::Expr::BlockExpr(block) => block.expr()?, | ||
42 | body => body, | ||
43 | }; | ||
44 | let ret_ty = closure.ret_type(); | ||
45 | let rpipe = closure.param_list()?.syntax().last_token()?; | ||
46 | (ret_ty, tail_expr, rpipe) | ||
47 | } else { | ||
48 | let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?; | ||
49 | let tail_expr = func.body()?.expr()?; | ||
50 | let ret_ty = func.ret_type(); | ||
51 | let rparen = func.param_list()?.r_paren_token()?; | ||
52 | (ret_ty, tail_expr, rparen) | ||
53 | }; | ||
54 | if ret_ty.is_some() { | ||
55 | mark::hit!(existing_ret_type); | ||
56 | mark::hit!(existing_ret_type_closure); | ||
57 | return None; | ||
58 | } | ||
59 | // check whether the expr we were at is indeed the tail expression | ||
60 | if !tail_expr.syntax().text_range().contains_range(expr.syntax().text_range()) { | ||
61 | mark::hit!(not_tail_expr); | ||
62 | return None; | ||
63 | } | ||
64 | Some((tail_expr, insert_pos)) | ||
65 | } | ||
66 | |||
51 | #[cfg(test)] | 67 | #[cfg(test)] |
52 | mod tests { | 68 | mod tests { |
53 | use crate::tests::{check_assist, check_assist_not_applicable}; | 69 | use crate::tests::{check_assist, check_assist_not_applicable}; |
@@ -110,4 +126,83 @@ mod tests { | |||
110 | }"#, | 126 | }"#, |
111 | ); | 127 | ); |
112 | } | 128 | } |
129 | |||
130 | #[test] | ||
131 | fn infer_return_type_closure_block() { | ||
132 | check_assist( | ||
133 | infer_function_return_type, | ||
134 | r#"fn foo() { | ||
135 | |x: i32| { | ||
136 | x<|> | ||
137 | }; | ||
138 | }"#, | ||
139 | r#"fn foo() { | ||
140 | |x: i32| -> i32 { | ||
141 | x | ||
142 | }; | ||
143 | }"#, | ||
144 | ); | ||
145 | } | ||
146 | |||
147 | #[test] | ||
148 | fn infer_return_type_closure() { | ||
149 | check_assist( | ||
150 | infer_function_return_type, | ||
151 | r#"fn foo() { | ||
152 | |x: i32| x<|>; | ||
153 | }"#, | ||
154 | r#"fn foo() { | ||
155 | |x: i32| -> i32 x; | ||
156 | }"#, | ||
157 | ); | ||
158 | } | ||
159 | |||
160 | #[test] | ||
161 | fn infer_return_type_nested_closure() { | ||
162 | check_assist( | ||
163 | infer_function_return_type, | ||
164 | r#"fn foo() { | ||
165 | || { | ||
166 | if true { | ||
167 | 3<|> | ||
168 | } else { | ||
169 | 5 | ||
170 | } | ||
171 | } | ||
172 | }"#, | ||
173 | r#"fn foo() { | ||
174 | || -> i32 { | ||
175 | if true { | ||
176 | 3 | ||
177 | } else { | ||
178 | 5 | ||
179 | } | ||
180 | } | ||
181 | }"#, | ||
182 | ); | ||
183 | } | ||
184 | |||
185 | #[test] | ||
186 | fn not_applicable_ret_type_specified_closure() { | ||
187 | mark::check!(existing_ret_type_closure); | ||
188 | check_assist_not_applicable( | ||
189 | infer_function_return_type, | ||
190 | r#"fn foo() { | ||
191 | || -> i32 { 3<|> } | ||
192 | }"#, | ||
193 | ); | ||
194 | } | ||
195 | |||
196 | #[test] | ||
197 | fn not_applicable_non_tail_expr_closure() { | ||
198 | check_assist_not_applicable( | ||
199 | infer_function_return_type, | ||
200 | r#"fn foo() { | ||
201 | || -> i32 { | ||
202 | let x = 3<|>; | ||
203 | 6 | ||
204 | } | ||
205 | }"#, | ||
206 | ); | ||
207 | } | ||
113 | } | 208 | } |