aboutsummaryrefslogtreecommitdiff
path: root/crates/assists/src/handlers/infer_function_return_type.rs
blob: da60ff9dedf9311088aab34618f689d5bd844aba (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use hir::HirDisplay;
use syntax::{ast, AstNode, TextSize};
use test_utils::mark;

use crate::{AssistContext, AssistId, AssistKind, Assists};

// Assist: infer_function_return_type
//
// Adds the return type to a function inferred from its tail expression if it doesn't have a return
// type specified.
//
// ```
// fn foo() { 4<|>2i32 }
// ```
// ->
// ```
// fn foo() -> i32 { 42i32 }
// ```
pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
    let expr = ctx.find_node_at_offset::<ast::Expr>()?;
    let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?;

    if func.ret_type().is_some() {
        mark::hit!(existing_ret_type);
        return None;
    }
    let body = func.body()?;
    let tail_expr = body.expr()?;
    // check whether the expr we were at is indeed the tail expression
    if !tail_expr.syntax().text_range().contains_range(expr.syntax().text_range()) {
        mark::hit!(not_tail_expr);
        return None;
    }
    let module = ctx.sema.scope(func.syntax()).module()?;
    let ty = ctx.sema.type_of_expr(&tail_expr)?;
    let ty = ty.display_source_code(ctx.db(), module.into()).ok()?;
    let rparen = func.param_list()?.r_paren_token()?;

    acc.add(
        AssistId("change_return_type_to_result", AssistKind::RefactorRewrite),
        "Wrap return type in Result",
        tail_expr.syntax().text_range(),
        |builder| {
            let insert_pos = rparen.text_range().end() + TextSize::from(1);

            builder.insert(insert_pos, &format!("-> {} ", ty));
        },
    )
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::tests::{check_assist, check_assist_not_applicable};

    // Cursor sits directly on the tail expression of a function without a return type.
    #[test]
    fn infer_return_type() {
        let before = r#"fn foo() {
                45<|>
            }"#;
        let after = r#"fn foo() -> i32 {
                45
            }"#;
        check_assist(infer_function_return_type, before, after);
    }

    // Cursor sits on an expression nested inside the tail expression.
    #[test]
    fn infer_return_type_nested() {
        let before = r#"fn foo() {
                if true {
                    3<|>
                } else {
                    5
                }
            }"#;
        let after = r#"fn foo() -> i32 {
                if true {
                    3
                } else {
                    5
                }
            }"#;
        check_assist(infer_function_return_type, before, after);
    }

    // A function that already declares its return type must be left alone.
    #[test]
    fn not_applicable_ret_type_specified() {
        mark::check!(existing_ret_type);
        let fixture = r#"fn foo() -> i32 {
                ( 45<|> + 32 ) * 123
            }"#;
        check_assist_not_applicable(infer_function_return_type, fixture);
    }

    // Cursor is in a `let` statement, not in the tail expression.
    #[test]
    fn not_applicable_non_tail_expr() {
        mark::check!(not_tail_expr);
        let fixture = r#"fn foo() {
                let x = <|>3;
                ( 45 + 32 ) * 123
            }"#;
        check_assist_not_applicable(infer_function_return_type, fixture);
    }
}