aboutsummaryrefslogtreecommitdiff
path: root/crates
diff options
context:
space:
mode:
Diffstat (limited to 'crates')
-rw-r--r--crates/assists/src/handlers/infer_function_return_type.rs133
1 files changed, 114 insertions, 19 deletions
diff --git a/crates/assists/src/handlers/infer_function_return_type.rs b/crates/assists/src/handlers/infer_function_return_type.rs
index da60ff9de..f363a56f3 100644
--- a/crates/assists/src/handlers/infer_function_return_type.rs
+++ b/crates/assists/src/handlers/infer_function_return_type.rs
@@ -1,12 +1,12 @@
1use hir::HirDisplay; 1use hir::HirDisplay;
2use syntax::{ast, AstNode, TextSize}; 2use syntax::{ast, AstNode, SyntaxToken, TextSize};
3use test_utils::mark; 3use test_utils::mark;
4 4
5use crate::{AssistContext, AssistId, AssistKind, Assists}; 5use crate::{AssistContext, AssistId, AssistKind, Assists};
6 6
7// Assist: infer_function_return_type 7// Assist: infer_function_return_type
8// 8//
9// Adds the return type to a function inferred from its tail expression if it doesn't have a return 9// Adds the return type to a function or closure inferred from its tail expression if it doesn't have a return
10// type specified. 10// type specified.
11// 11//
12// ``` 12// ```
@@ -18,36 +18,52 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
18// ``` 18// ```
19pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 19pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
20 let expr = ctx.find_node_at_offset::<ast::Expr>()?; 20 let expr = ctx.find_node_at_offset::<ast::Expr>()?;
21 let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?; 21 let (tail_expr, insert_pos) = extract(expr)?;
22 22 let module = ctx.sema.scope(tail_expr.syntax()).module()?;
23 if func.ret_type().is_some() {
24 mark::hit!(existing_ret_type);
25 return None;
26 }
27 let body = func.body()?;
28 let tail_expr = body.expr()?;
29 // check whether the expr we were at is indeed the tail expression
30 if !tail_expr.syntax().text_range().contains_range(expr.syntax().text_range()) {
31 mark::hit!(not_tail_expr);
32 return None;
33 }
34 let module = ctx.sema.scope(func.syntax()).module()?;
35 let ty = ctx.sema.type_of_expr(&tail_expr)?; 23 let ty = ctx.sema.type_of_expr(&tail_expr)?;
36 let ty = ty.display_source_code(ctx.db(), module.into()).ok()?; 24 let ty = ty.display_source_code(ctx.db(), module.into()).ok()?;
37 let rparen = func.param_list()?.r_paren_token()?;
38 25
39 acc.add( 26 acc.add(
40 AssistId("change_return_type_to_result", AssistKind::RefactorRewrite), 27 AssistId("change_return_type_to_result", AssistKind::RefactorRewrite),
41 "Wrap return type in Result", 28 "Wrap return type in Result",
42 tail_expr.syntax().text_range(), 29 tail_expr.syntax().text_range(),
43 |builder| { 30 |builder| {
44 let insert_pos = rparen.text_range().end() + TextSize::from(1); 31 let insert_pos = insert_pos.text_range().end() + TextSize::from(1);
45
46 builder.insert(insert_pos, &format!("-> {} ", ty)); 32 builder.insert(insert_pos, &format!("-> {} ", ty));
47 }, 33 },
48 ) 34 )
49} 35}
50 36
37fn extract(expr: ast::Expr) -> Option<(ast::Expr, SyntaxToken)> {
38 let (ret_ty, tail_expr, insert_pos) =
39 if let Some(closure) = expr.syntax().ancestors().find_map(ast::ClosureExpr::cast) {
40 let tail_expr = match closure.body()? {
41 ast::Expr::BlockExpr(block) => block.expr()?,
42 body => body,
43 };
44 let ret_ty = closure.ret_type();
45 let rpipe = closure.param_list()?.syntax().last_token()?;
46 (ret_ty, tail_expr, rpipe)
47 } else {
48 let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?;
49 let tail_expr = func.body()?.expr()?;
50 let ret_ty = func.ret_type();
51 let rparen = func.param_list()?.r_paren_token()?;
52 (ret_ty, tail_expr, rparen)
53 };
54 if ret_ty.is_some() {
55 mark::hit!(existing_ret_type);
56 mark::hit!(existing_ret_type_closure);
57 return None;
58 }
59 // check whether the expr we were at is indeed the tail expression
60 if !tail_expr.syntax().text_range().contains_range(expr.syntax().text_range()) {
61 mark::hit!(not_tail_expr);
62 return None;
63 }
64 Some((tail_expr, insert_pos))
65}
66
51#[cfg(test)] 67#[cfg(test)]
52mod tests { 68mod tests {
53 use crate::tests::{check_assist, check_assist_not_applicable}; 69 use crate::tests::{check_assist, check_assist_not_applicable};
@@ -110,4 +126,83 @@ mod tests {
110 }"#, 126 }"#,
111 ); 127 );
112 } 128 }
129
130 #[test]
131 fn infer_return_type_closure_block() {
132 check_assist(
133 infer_function_return_type,
134 r#"fn foo() {
135 |x: i32| {
136 x<|>
137 };
138 }"#,
139 r#"fn foo() {
140 |x: i32| -> i32 {
141 x
142 };
143 }"#,
144 );
145 }
146
147 #[test]
148 fn infer_return_type_closure() {
149 check_assist(
150 infer_function_return_type,
151 r#"fn foo() {
152 |x: i32| x<|>;
153 }"#,
154 r#"fn foo() {
155 |x: i32| -> i32 x;
156 }"#,
157 );
158 }
159
160 #[test]
161 fn infer_return_type_nested_closure() {
162 check_assist(
163 infer_function_return_type,
164 r#"fn foo() {
165 || {
166 if true {
167 3<|>
168 } else {
169 5
170 }
171 }
172 }"#,
173 r#"fn foo() {
174 || -> i32 {
175 if true {
176 3
177 } else {
178 5
179 }
180 }
181 }"#,
182 );
183 }
184
185 #[test]
186 fn not_applicable_ret_type_specified_closure() {
187 mark::check!(existing_ret_type_closure);
188 check_assist_not_applicable(
189 infer_function_return_type,
190 r#"fn foo() {
191 || -> i32 { 3<|> }
192 }"#,
193 );
194 }
195
196 #[test]
197 fn not_applicable_non_tail_expr_closure() {
198 check_assist_not_applicable(
199 infer_function_return_type,
200 r#"fn foo() {
201 || -> i32 {
202 let x = 3<|>;
203 6
204 }
205 }"#,
206 );
207 }
113} 208}