aboutsummaryrefslogtreecommitdiff
path: root/crates/assists/src/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'crates/assists/src/handlers')
-rw-r--r--crates/assists/src/handlers/infer_function_return_type.rs158
1 files changed, 129 insertions, 29 deletions
diff --git a/crates/assists/src/handlers/infer_function_return_type.rs b/crates/assists/src/handlers/infer_function_return_type.rs
index b4944a6b0..80864c530 100644
--- a/crates/assists/src/handlers/infer_function_return_type.rs
+++ b/crates/assists/src/handlers/infer_function_return_type.rs
@@ -1,5 +1,5 @@
1use hir::HirDisplay; 1use hir::HirDisplay;
2use syntax::{ast, AstNode, SyntaxToken, TextSize}; 2use syntax::{ast, AstNode, TextRange, TextSize};
3use test_utils::mark; 3use test_utils::mark;
4 4
5use crate::{AssistContext, AssistId, AssistKind, Assists}; 5use crate::{AssistContext, AssistId, AssistKind, Assists};
@@ -7,7 +7,7 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
7// Assist: infer_function_return_type 7// Assist: infer_function_return_type
8// 8//
9// Adds the return type to a function or closure inferred from its tail expression if it doesn't have a return 9// Adds the return type to a function or closure inferred from its tail expression if it doesn't have a return
10// type specified. 10// type specified. This assists is useable in a functions or closures tail expression or return type position.
11// 11//
12// ``` 12// ```
13// fn foo() { 4<|>2i32 } 13// fn foo() { 4<|>2i32 }
@@ -17,10 +17,12 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
17// fn foo() -> i32 { 42i32 } 17// fn foo() -> i32 { 42i32 }
18// ``` 18// ```
19pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 19pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
20 let expr = ctx.find_node_at_offset::<ast::Expr>()?; 20 let (tail_expr, builder_edit_pos, wrap_expr) = extract_tail(ctx)?;
21 let (tail_expr, insert_pos, wrap_expr) = extract_tail(expr)?;
22 let module = ctx.sema.scope(tail_expr.syntax()).module()?; 21 let module = ctx.sema.scope(tail_expr.syntax()).module()?;
23 let ty = ctx.sema.type_of_expr(&tail_expr).filter(|ty| !ty.is_unit())?; 22 let ty = ctx.sema.type_of_expr(&tail_expr)?;
23 if ty.is_unit() {
24 return None;
25 }
24 let ty = ty.display_source_code(ctx.db(), module.into()).ok()?; 26 let ty = ty.display_source_code(ctx.db(), module.into()).ok()?;
25 27
26 acc.add( 28 acc.add(
@@ -28,8 +30,14 @@ pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext)
28 "Add this function's return type", 30 "Add this function's return type",
29 tail_expr.syntax().text_range(), 31 tail_expr.syntax().text_range(),
30 |builder| { 32 |builder| {
31 let insert_pos = insert_pos.text_range().end() + TextSize::from(1); 33 match builder_edit_pos {
32 builder.insert(insert_pos, &format!("-> {} ", ty)); 34 InsertOrReplace::Insert(insert_pos) => {
35 builder.insert(insert_pos, &format!("-> {} ", ty))
36 }
37 InsertOrReplace::Replace(text_range) => {
38 builder.replace(text_range, &format!("-> {}", ty))
39 }
40 }
33 if wrap_expr { 41 if wrap_expr {
34 mark::hit!(wrap_closure_non_block_expr); 42 mark::hit!(wrap_closure_non_block_expr);
35 // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block 43 // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block
@@ -39,34 +47,69 @@ pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext)
39 ) 47 )
40} 48}
41 49
42fn extract_tail(expr: ast::Expr) -> Option<(ast::Expr, SyntaxToken, bool)> { 50enum InsertOrReplace {
43 let (ret_ty, tail_expr, insert_pos, wrap_expr) = 51 Insert(TextSize),
44 if let Some(closure) = expr.syntax().ancestors().find_map(ast::ClosureExpr::cast) { 52 Replace(TextRange),
45 let (tail_expr, wrap_expr) = match closure.body()? { 53}
54
55/// Check the potentially already specified return type and reject it or turn it into a builder command
56/// if allowed.
57fn ret_ty_to_action(ret_ty: Option<ast::RetType>, insert_pos: TextSize) -> Option<InsertOrReplace> {
58 match ret_ty {
59 Some(ret_ty) => match ret_ty.ty() {
60 Some(ast::Type::InferType(_)) | None => {
61 mark::hit!(existing_infer_ret_type);
62 mark::hit!(existing_infer_ret_type_closure);
63 Some(InsertOrReplace::Replace(ret_ty.syntax().text_range()))
64 }
65 _ => {
66 mark::hit!(existing_ret_type);
67 mark::hit!(existing_ret_type_closure);
68 None
69 }
70 },
71 None => Some(InsertOrReplace::Insert(insert_pos + TextSize::from(1))),
72 }
73}
74
75fn extract_tail(ctx: &AssistContext) -> Option<(ast::Expr, InsertOrReplace, bool)> {
76 let (tail_expr, return_type_range, action, wrap_expr) =
77 if let Some(closure) = ctx.find_node_at_offset::<ast::ClosureExpr>() {
78 let rpipe_pos = closure.param_list()?.syntax().last_token()?.text_range().end();
79 let action = ret_ty_to_action(closure.ret_type(), rpipe_pos)?;
80
81 let body = closure.body()?;
82 let body_start = body.syntax().first_token()?.text_range().start();
83 let (tail_expr, wrap_expr) = match body {
46 ast::Expr::BlockExpr(block) => (block.expr()?, false), 84 ast::Expr::BlockExpr(block) => (block.expr()?, false),
47 body => (body, true), 85 body => (body, true),
48 }; 86 };
49 let ret_ty = closure.ret_type(); 87
50 let rpipe = closure.param_list()?.syntax().last_token()?; 88 let ret_range = TextRange::new(rpipe_pos, body_start);
51 (ret_ty, tail_expr, rpipe, wrap_expr) 89 (tail_expr, ret_range, action, wrap_expr)
52 } else { 90 } else {
53 let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?; 91 let func = ctx.find_node_at_offset::<ast::Fn>()?;
54 let tail_expr = func.body()?.expr()?; 92 let rparen_pos = func.param_list()?.r_paren_token()?.text_range().end();
55 let ret_ty = func.ret_type(); 93 let action = ret_ty_to_action(func.ret_type(), rparen_pos)?;
56 let rparen = func.param_list()?.r_paren_token()?; 94
57 (ret_ty, tail_expr, rparen, false) 95 let body = func.body()?;
96 let tail_expr = body.expr()?;
97
98 let ret_range_end = body.l_curly_token()?.text_range().start();
99 let ret_range = TextRange::new(rparen_pos, ret_range_end);
100 (tail_expr, ret_range, action, false)
58 }; 101 };
59 if ret_ty.is_some() { 102 let frange = ctx.frange.range;
60 mark::hit!(existing_ret_type); 103 if return_type_range.contains_range(frange) {
61 mark::hit!(existing_ret_type_closure); 104 mark::hit!(cursor_in_ret_position);
62 return None; 105 mark::hit!(cursor_in_ret_position_closure);
63 } 106 } else if tail_expr.syntax().text_range().contains_range(frange) {
64 // check whether the expr we were at is indeed the tail expression 107 mark::hit!(cursor_on_tail);
65 if !tail_expr.syntax().text_range().contains_range(expr.syntax().text_range()) { 108 mark::hit!(cursor_on_tail_closure);
66 mark::hit!(not_tail_expr); 109 } else {
67 return None; 110 return None;
68 } 111 }
69 Some((tail_expr, insert_pos, wrap_expr)) 112 Some((tail_expr, action, wrap_expr))
70} 113}
71 114
72#[cfg(test)] 115#[cfg(test)]
@@ -76,7 +119,64 @@ mod tests {
76 use super::*; 119 use super::*;
77 120
78 #[test] 121 #[test]
122 fn infer_return_type_specified_inferred() {
123 mark::check!(existing_infer_ret_type);
124 check_assist(
125 infer_function_return_type,
126 r#"fn foo() -> <|>_ {
127 45
128 }"#,
129 r#"fn foo() -> i32 {
130 45
131 }"#,
132 );
133 }
134
135 #[test]
136 fn infer_return_type_specified_inferred_closure() {
137 mark::check!(existing_infer_ret_type_closure);
138 check_assist(
139 infer_function_return_type,
140 r#"fn foo() {
141 || -> _ {<|>45};
142 }"#,
143 r#"fn foo() {
144 || -> i32 {45};
145 }"#,
146 );
147 }
148
149 #[test]
150 fn infer_return_type_cursor_at_return_type_pos() {
151 mark::check!(cursor_in_ret_position);
152 check_assist(
153 infer_function_return_type,
154 r#"fn foo() <|>{
155 45
156 }"#,
157 r#"fn foo() -> i32 {
158 45
159 }"#,
160 );
161 }
162
163 #[test]
164 fn infer_return_type_cursor_at_return_type_pos_closure() {
165 mark::check!(cursor_in_ret_position_closure);
166 check_assist(
167 infer_function_return_type,
168 r#"fn foo() {
169 || <|>45
170 }"#,
171 r#"fn foo() {
172 || -> i32 {45}
173 }"#,
174 );
175 }
176
177 #[test]
79 fn infer_return_type() { 178 fn infer_return_type() {
179 mark::check!(cursor_on_tail);
80 check_assist( 180 check_assist(
81 infer_function_return_type, 181 infer_function_return_type,
82 r#"fn foo() { 182 r#"fn foo() {
@@ -122,7 +222,6 @@ mod tests {
122 222
123 #[test] 223 #[test]
124 fn not_applicable_non_tail_expr() { 224 fn not_applicable_non_tail_expr() {
125 mark::check!(not_tail_expr);
126 check_assist_not_applicable( 225 check_assist_not_applicable(
127 infer_function_return_type, 226 infer_function_return_type,
128 r#"fn foo() { 227 r#"fn foo() {
@@ -144,6 +243,7 @@ mod tests {
144 243
145 #[test] 244 #[test]
146 fn infer_return_type_closure_block() { 245 fn infer_return_type_closure_block() {
246 mark::check!(cursor_on_tail_closure);
147 check_assist( 247 check_assist(
148 infer_function_return_type, 248 infer_function_return_type,
149 r#"fn foo() { 249 r#"fn foo() {