aboutsummaryrefslogtreecommitdiff
path: root/crates/hir_expand/src/input.rs
blob: fe4790e7b6ba76680ced79dc7db9540cfb1edc9b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
//! Macro input conditioning.

use syntax::{
    ast::{self, AttrsOwner},
    AstNode, SyntaxNode,
};

use crate::{
    db::AstDatabase,
    name::{name, AsName},
    MacroCallId, MacroCallKind, MacroCallLoc,
};

/// Prepares the syntax handed to the expander of the macro call `id`.
///
/// * Function-like macro calls get `node` back unchanged.
/// * Derive macro calls: when `node` is an `ast::Item`, the result is a copy of
///   that item with the `#[derive]` attributes among its attributes up to and
///   including index `derive_attr_index` removed (see `remove_derives_up_to`).
///   Any node that is not an item is returned unchanged.
pub(crate) fn process_macro_input(
    db: &dyn AstDatabase,
    node: SyntaxNode,
    id: MacroCallId,
) -> SyntaxNode {
    let call_loc: MacroCallLoc = db.lookup_intern_macro(id);
    match call_loc.kind {
        // Function-like macros receive their input verbatim.
        MacroCallKind::FnLike { .. } => node,
        // NOTE(review): derive attributes are presumably stripped so the derive
        // expander does not see its own (or earlier) `#[derive]`s — confirm.
        MacroCallKind::Derive { derive_attr_index, .. } => ast::Item::cast(node.clone())
            .map(|item| remove_derives_up_to(item, derive_attr_index as usize).syntax().clone())
            .unwrap_or(node),
    }
}

/// Removes `#[derive]` attributes from `item`, up to `attr_index`.
/// Removes `#[derive]` attributes from `item`, considering only its first
/// `attr_index + 1` attributes (i.e. up to and including index `attr_index`).
///
/// The edits are applied to the mutable copy produced by `clone_for_update`,
/// and that copy is returned. An attribute counts as a derive only when its
/// path is the single segment `derive`; other attributes in the range are
/// kept. Detached attributes leave their surrounding whitespace behind, so
/// the returned item's text may contain blank lines where they used to be.
fn remove_derives_up_to(item: ast::Item, attr_index: usize) -> ast::Item {
    let is_derive = |attr: &ast::Attr| {
        attr.path()
            .and_then(|path| path.as_single_segment())
            .and_then(|segment| segment.name_ref())
            .map_or(false, |name_ref| name_ref.as_name() == name![derive])
    };

    let editable = item.clone_for_update();
    // Detaching the attribute that was just yielded does not disturb the
    // iteration over the remaining siblings.
    for derive_attr in editable.attrs().take(attr_index + 1).filter(is_derive) {
        derive_attr.syntax().detach();
    }
    editable
}

#[cfg(test)]
mod tests {
    use base_db::fixture::WithFixture;
    use base_db::SourceDatabase;
    use expect_test::{expect, Expect};

    use crate::test_db::TestDB;

    use super::*;

    /// Parses `ra_fixture`, which must contain exactly one item, runs
    /// `remove_derives_up_to(item, attr)` on that item and checks the text of
    /// the resulting item against `expect`.
    fn test_remove_derives_up_to(attr: usize, ra_fixture: &str, expect: Expect) {
        // `ra_fixture` is already a `&str`; no extra borrow is needed.
        let (db, file_id) = TestDB::with_single_file(ra_fixture);
        let parsed = db.parse(file_id);

        let mut items: Vec<_> =
            parsed.syntax_node().descendants().filter_map(ast::Item::cast).collect();
        assert_eq!(items.len(), 1);

        let item = remove_derives_up_to(items.pop().unwrap(), attr);
        expect.assert_eq(&item.to_string());
    }

    #[test]
    fn remove_derive() {
        // Index 2 covers the first three attributes. `allow` is kept because it
        // is not a derive, and the two derives are detached. `derive(Clone)` lies
        // past the index and survives. The newlines around the detached
        // attributes remain, hence the blank lines.
        test_remove_derives_up_to(
            2,
            r#"
#[allow(unused)]
#[derive(Copy)]
#[derive(Hello)]
#[derive(Clone)]
struct A {
    bar: u32
}
        "#,
            expect![[r#"
#[allow(unused)]


#[derive(Clone)]
struct A {
    bar: u32
}"#]],
        );
    }

    #[test]
    fn keep_non_derive_and_qualified_derive() {
        // Only a single-segment `derive` path is matched, so a path-qualified
        // derive inside the range is left in place, as is the `allow`.
        test_remove_derives_up_to(
            1,
            r#"
#[allow(unused)]
#[core::derive(Copy)]
struct A;
        "#,
            expect![[r#"
#[allow(unused)]
#[core::derive(Copy)]
struct A;"#]],
        );
    }
}