aboutsummaryrefslogtreecommitdiff
path: root/crates/assists/src/handlers/infer_function_return_type.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/assists/src/handlers/infer_function_return_type.rs')
-rw-r--r--crates/assists/src/handlers/infer_function_return_type.rs113
1 files changed, 113 insertions, 0 deletions
diff --git a/crates/assists/src/handlers/infer_function_return_type.rs b/crates/assists/src/handlers/infer_function_return_type.rs
new file mode 100644
index 000000000..da60ff9de
--- /dev/null
+++ b/crates/assists/src/handlers/infer_function_return_type.rs
@@ -0,0 +1,113 @@
1use hir::HirDisplay;
2use syntax::{ast, AstNode, TextSize};
3use test_utils::mark;
4
5use crate::{AssistContext, AssistId, AssistKind, Assists};
6
7// Assist: infer_function_return_type
8//
9// Adds the return type to a function inferred from its tail expression if it doesn't have a return
10// type specified.
11//
12// ```
13// fn foo() { 4<|>2i32 }
14// ```
15// ->
16// ```
17// fn foo() -> i32 { 42i32 }
18// ```
19pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
20 let expr = ctx.find_node_at_offset::<ast::Expr>()?;
21 let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?;
22
23 if func.ret_type().is_some() {
24 mark::hit!(existing_ret_type);
25 return None;
26 }
27 let body = func.body()?;
28 let tail_expr = body.expr()?;
29 // check whether the expr we were at is indeed the tail expression
30 if !tail_expr.syntax().text_range().contains_range(expr.syntax().text_range()) {
31 mark::hit!(not_tail_expr);
32 return None;
33 }
34 let module = ctx.sema.scope(func.syntax()).module()?;
35 let ty = ctx.sema.type_of_expr(&tail_expr)?;
36 let ty = ty.display_source_code(ctx.db(), module.into()).ok()?;
37 let rparen = func.param_list()?.r_paren_token()?;
38
39 acc.add(
40 AssistId("change_return_type_to_result", AssistKind::RefactorRewrite),
41 "Wrap return type in Result",
42 tail_expr.syntax().text_range(),
43 |builder| {
44 let insert_pos = rparen.text_range().end() + TextSize::from(1);
45
46 builder.insert(insert_pos, &format!("-> {} ", ty));
47 },
48 )
49}
50
51#[cfg(test)]
52mod tests {
53 use crate::tests::{check_assist, check_assist_not_applicable};
54
55 use super::*;
56
57 #[test]
58 fn infer_return_type() {
59 check_assist(
60 infer_function_return_type,
61 r#"fn foo() {
62 45<|>
63 }"#,
64 r#"fn foo() -> i32 {
65 45
66 }"#,
67 );
68 }
69
70 #[test]
71 fn infer_return_type_nested() {
72 check_assist(
73 infer_function_return_type,
74 r#"fn foo() {
75 if true {
76 3<|>
77 } else {
78 5
79 }
80 }"#,
81 r#"fn foo() -> i32 {
82 if true {
83 3
84 } else {
85 5
86 }
87 }"#,
88 );
89 }
90
91 #[test]
92 fn not_applicable_ret_type_specified() {
93 mark::check!(existing_ret_type);
94 check_assist_not_applicable(
95 infer_function_return_type,
96 r#"fn foo() -> i32 {
97 ( 45<|> + 32 ) * 123
98 }"#,
99 );
100 }
101
102 #[test]
103 fn not_applicable_non_tail_expr() {
104 mark::check!(not_tail_expr);
105 check_assist_not_applicable(
106 infer_function_return_type,
107 r#"fn foo() {
108 let x = <|>3;
109 ( 45 + 32 ) * 123
110 }"#,
111 );
112 }
113}