aboutsummaryrefslogtreecommitdiff
path: root/crates/ide_assists/src/handlers/move_bounds.rs
blob: 9ad0c98168cae3d12b498f78916924d37c72afe7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
use syntax::{
    ast::{self, edit_in_place::GenericParamsOwnerEdit, make, AstNode, NameOwner, TypeBoundsOwner},
    match_ast,
};

use crate::{AssistContext, AssistId, AssistKind, Assists};

// Assist: move_bounds_to_where_clause
//
// Moves inline type bounds to a where clause.
//
// ```
// fn apply<T, U, $0F: FnOnce(T) -> U>(f: F, x: T) -> U {
//     f(x)
// }
// ```
// ->
// ```
// fn apply<T, U, F>(f: F, x: T) -> U where F: FnOnce(T) -> U {
//     f(x)
// }
// ```
pub(crate) fn move_bounds_to_where_clause(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
    // The edit mutates syntax nodes in place, so work on a mutable copy of the tree.
    let type_param_list = ctx.find_node_at_offset::<ast::GenericParamList>()?.clone_for_update();

    // Only applicable when at least one type parameter carries inline bounds to move.
    if type_param_list.type_params().all(|p| p.type_bound_list().is_none()) {
        return None;
    }

    let parent = type_param_list.syntax().parent()?;

    // The edit closure below only handles these item kinds. Other owners of a generic
    // parameter list (e.g. unions or type aliases) would make the closure bail out, and the
    // assist would be offered but apply an empty edit, so reject them up front.
    let parent_kind = parent.kind();
    let supports_where_clause = ast::Fn::can_cast(parent_kind)
        || ast::Trait::can_cast(parent_kind)
        || ast::Impl::can_cast(parent_kind)
        || ast::Enum::can_cast(parent_kind)
        || ast::Struct::can_cast(parent_kind);
    if !supports_where_clause {
        return None;
    }

    // Captured before any mutation: the edits below change the parent's text.
    let original_parent_range = parent.text_range();

    let target = type_param_list.syntax().text_range();
    acc.add(
        AssistId("move_bounds_to_where_clause", AssistKind::RefactorRewrite),
        "Move to where clause",
        target,
        |edit| {
            // Obtain the item's where clause, creating one if needed.
            let where_clause: ast::WhereClause = match_ast! {
                match parent {
                    ast::Fn(it) => it.get_or_create_where_clause(),
                    ast::Trait(it) => it.get_or_create_where_clause(),
                    ast::Impl(it) => it.get_or_create_where_clause(),
                    ast::Enum(it) => it.get_or_create_where_clause(),
                    ast::Struct(it) => it.get_or_create_where_clause(),
                    // Not reachable after the kind check above; kept as a safe fallback.
                    _ => return,
                }
            };

            // Turn every inline bound list into a where predicate, then strip the inline
            // bounds from the parameter itself.
            for type_param in type_param_list.type_params() {
                if let Some(tbl) = type_param.type_bound_list() {
                    if let Some(predicate) = build_predicate(type_param.clone()) {
                        where_clause.add_predicate(predicate.clone_for_update())
                    }
                    tbl.remove()
                }
            }

            // The where clause can sit far from the parameter list (e.g. after a tuple
            // struct's fields), so the whole rewritten item replaces the original one.
            edit.replace(original_parent_range, parent.to_string())
        },
    )
}

/// Builds a `Name: Bounds` where predicate from a type parameter's name and its inline bound
/// list. Returns `None` if the parameter has no name or no bound list.
fn build_predicate(param: ast::TypeParam) -> Option<ast::WherePred> {
    let param_name = param.name()?.syntax().to_string();
    let bound_list = param.type_bound_list()?;

    // The predicate's left-hand side is a plain single-segment path naming the parameter.
    let name_path = make::path_unqualified(make::path_segment(make::name_ref(&param_name)));
    Some(make::where_pred(name_path, bound_list.bounds()))
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::tests::{check_assist, check_assist_not_applicable};

    #[test]
    fn move_bounds_to_where_clause_fn() {
        check_assist(
            move_bounds_to_where_clause,
            r#"fn foo<T: u32, $0F: FnOnce(T) -> T>() {}"#,
            r#"fn foo<T, F>() where T: u32, F: FnOnce(T) -> T {}"#,
        );
    }

    #[test]
    fn move_bounds_to_where_clause_impl() {
        check_assist(
            move_bounds_to_where_clause,
            r#"impl<U: u32, $0T> A<U, T> {}"#,
            r#"impl<U, T> A<U, T> where U: u32 {}"#,
        );
    }

    #[test]
    fn move_bounds_to_where_clause_struct() {
        check_assist(
            move_bounds_to_where_clause,
            r#"struct A<$0T: Iterator<Item = u32>> {}"#,
            r#"struct A<T> where T: Iterator<Item = u32> {}"#,
        );
    }

    #[test]
    fn move_bounds_to_where_clause_tuple_struct() {
        check_assist(
            move_bounds_to_where_clause,
            r#"struct Pair<$0T: u32>(T, T);"#,
            r#"struct Pair<T>(T, T) where T: u32;"#,
        );
    }

    // Covers the `ast::Trait` arm of the handler's `match_ast!`.
    #[test]
    fn move_bounds_to_where_clause_trait() {
        check_assist(
            move_bounds_to_where_clause,
            r#"trait Foo<$0T: Copy> {}"#,
            r#"trait Foo<T> where T: Copy {}"#,
        );
    }

    // Covers the `ast::Enum` arm; unbounded parameters are left as they are.
    #[test]
    fn move_bounds_to_where_clause_enum() {
        check_assist(
            move_bounds_to_where_clause,
            r#"enum Either<$0L: Copy, R> { Left(L), Right(R) }"#,
            r#"enum Either<L, R> where L: Copy { Left(L), Right(R) }"#,
        );
    }

    // With no inline bounds there is nothing to move.
    #[test]
    fn move_bounds_to_where_clause_not_applicable_without_bounds() {
        check_assist_not_applicable(move_bounds_to_where_clause, r#"fn foo<$0T, U>() {}"#);
    }
}