aboutsummaryrefslogtreecommitdiff
path: root/crates/syntax/src/ast
diff options
context:
space:
mode:
Diffstat (limited to 'crates/syntax/src/ast')
-rw-r--r--crates/syntax/src/ast/edit_in_place.rs153
-rw-r--r--crates/syntax/src/ast/node_ext.rs12
2 files changed, 165 insertions, 0 deletions
diff --git a/crates/syntax/src/ast/edit_in_place.rs b/crates/syntax/src/ast/edit_in_place.rs
index 529bd0eb1..04f97f368 100644
--- a/crates/syntax/src/ast/edit_in_place.rs
+++ b/crates/syntax/src/ast/edit_in_place.rs
@@ -14,10 +14,29 @@ use crate::{
14use super::NameOwner; 14use super::NameOwner;
15 15
16pub trait GenericParamsOwnerEdit: ast::GenericParamsOwner + AstNodeEdit { 16pub trait GenericParamsOwnerEdit: ast::GenericParamsOwner + AstNodeEdit {
17 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList;
17 fn get_or_create_where_clause(&self) -> ast::WhereClause; 18 fn get_or_create_where_clause(&self) -> ast::WhereClause;
18} 19}
19 20
20impl GenericParamsOwnerEdit for ast::Fn { 21impl GenericParamsOwnerEdit for ast::Fn {
22 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
23 match self.generic_param_list() {
24 Some(it) => it,
25 None => {
26 let position = if let Some(name) = self.name() {
27 Position::after(name.syntax)
28 } else if let Some(fn_token) = self.fn_token() {
29 Position::after(fn_token)
30 } else if let Some(param_list) = self.param_list() {
31 Position::before(param_list.syntax)
32 } else {
33 Position::last_child_of(self.syntax())
34 };
35 create_generic_param_list(position)
36 }
37 }
38 }
39
21 fn get_or_create_where_clause(&self) -> WhereClause { 40 fn get_or_create_where_clause(&self) -> WhereClause {
22 if self.where_clause().is_none() { 41 if self.where_clause().is_none() {
23 let position = if let Some(ty) = self.ret_type() { 42 let position = if let Some(ty) = self.ret_type() {
@@ -34,6 +53,20 @@ impl GenericParamsOwnerEdit for ast::Fn {
34} 53}
35 54
36impl GenericParamsOwnerEdit for ast::Impl { 55impl GenericParamsOwnerEdit for ast::Impl {
56 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
57 match self.generic_param_list() {
58 Some(it) => it,
59 None => {
60 let position = if let Some(imp_token) = self.impl_token() {
61 Position::after(imp_token)
62 } else {
63 Position::last_child_of(self.syntax())
64 };
65 create_generic_param_list(position)
66 }
67 }
68 }
69
37 fn get_or_create_where_clause(&self) -> WhereClause { 70 fn get_or_create_where_clause(&self) -> WhereClause {
38 if self.where_clause().is_none() { 71 if self.where_clause().is_none() {
39 let position = if let Some(items) = self.assoc_item_list() { 72 let position = if let Some(items) = self.assoc_item_list() {
@@ -48,6 +81,22 @@ impl GenericParamsOwnerEdit for ast::Impl {
48} 81}
49 82
50impl GenericParamsOwnerEdit for ast::Trait { 83impl GenericParamsOwnerEdit for ast::Trait {
84 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
85 match self.generic_param_list() {
86 Some(it) => it,
87 None => {
88 let position = if let Some(name) = self.name() {
89 Position::after(name.syntax)
90 } else if let Some(trait_token) = self.trait_token() {
91 Position::after(trait_token)
92 } else {
93 Position::last_child_of(self.syntax())
94 };
95 create_generic_param_list(position)
96 }
97 }
98 }
99
51 fn get_or_create_where_clause(&self) -> WhereClause { 100 fn get_or_create_where_clause(&self) -> WhereClause {
52 if self.where_clause().is_none() { 101 if self.where_clause().is_none() {
53 let position = if let Some(items) = self.assoc_item_list() { 102 let position = if let Some(items) = self.assoc_item_list() {
@@ -62,6 +111,22 @@ impl GenericParamsOwnerEdit for ast::Trait {
62} 111}
63 112
64impl GenericParamsOwnerEdit for ast::Struct { 113impl GenericParamsOwnerEdit for ast::Struct {
114 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
115 match self.generic_param_list() {
116 Some(it) => it,
117 None => {
118 let position = if let Some(name) = self.name() {
119 Position::after(name.syntax)
120 } else if let Some(struct_token) = self.struct_token() {
121 Position::after(struct_token)
122 } else {
123 Position::last_child_of(self.syntax())
124 };
125 create_generic_param_list(position)
126 }
127 }
128 }
129
65 fn get_or_create_where_clause(&self) -> WhereClause { 130 fn get_or_create_where_clause(&self) -> WhereClause {
66 if self.where_clause().is_none() { 131 if self.where_clause().is_none() {
67 let tfl = self.field_list().and_then(|fl| match fl { 132 let tfl = self.field_list().and_then(|fl| match fl {
@@ -84,6 +149,22 @@ impl GenericParamsOwnerEdit for ast::Struct {
84} 149}
85 150
86impl GenericParamsOwnerEdit for ast::Enum { 151impl GenericParamsOwnerEdit for ast::Enum {
152 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
153 match self.generic_param_list() {
154 Some(it) => it,
155 None => {
156 let position = if let Some(name) = self.name() {
157 Position::after(name.syntax)
158 } else if let Some(enum_token) = self.enum_token() {
159 Position::after(enum_token)
160 } else {
161 Position::last_child_of(self.syntax())
162 };
163 create_generic_param_list(position)
164 }
165 }
166 }
167
87 fn get_or_create_where_clause(&self) -> WhereClause { 168 fn get_or_create_where_clause(&self) -> WhereClause {
88 if self.where_clause().is_none() { 169 if self.where_clause().is_none() {
89 let position = if let Some(gpl) = self.generic_param_list() { 170 let position = if let Some(gpl) = self.generic_param_list() {
@@ -104,6 +185,37 @@ fn create_where_clause(position: Position) {
104 ted::insert(position, where_clause.syntax()); 185 ted::insert(position, where_clause.syntax());
105} 186}
106 187
188fn create_generic_param_list(position: Position) -> ast::GenericParamList {
189 let gpl = make::generic_param_list(empty()).clone_for_update();
190 ted::insert_raw(position, gpl.syntax());
191 gpl
192}
193
194impl ast::GenericParamList {
195 pub fn add_generic_param(&self, generic_param: ast::GenericParam) {
196 match self.generic_params().last() {
197 Some(last_param) => {
198 let mut elems = Vec::new();
199 if !last_param
200 .syntax()
201 .siblings_with_tokens(Direction::Next)
202 .any(|it| it.kind() == T![,])
203 {
204 elems.push(make::token(T![,]).into());
205 elems.push(make::tokens::single_space().into());
206 };
207 elems.push(generic_param.syntax().clone().into());
208 let after_last_param = Position::after(last_param.syntax());
209 ted::insert_all(after_last_param, elems);
210 }
211 None => {
212 let after_l_angle = Position::after(self.l_angle_token().unwrap());
213 ted::insert(after_l_angle, generic_param.syntax())
214 }
215 }
216 }
217}
218
107impl ast::WhereClause { 219impl ast::WhereClause {
108 pub fn add_predicate(&self, predicate: ast::WherePred) { 220 pub fn add_predicate(&self, predicate: ast::WherePred) {
109 if let Some(pred) = self.predicates().last() { 221 if let Some(pred) = self.predicates().last() {
@@ -164,3 +276,44 @@ impl ast::Use {
164 ted::remove(self.syntax()) 276 ted::remove(self.syntax())
165 } 277 }
166} 278}
279
280#[cfg(test)]
281mod tests {
282 use std::fmt;
283
284 use crate::SourceFile;
285
286 use super::*;
287
288 fn ast_mut_from_text<N: AstNode>(text: &str) -> N {
289 let parse = SourceFile::parse(text);
290 parse.tree().syntax().descendants().find_map(N::cast).unwrap().clone_for_update()
291 }
292
293 #[test]
294 fn test_create_generic_param_list() {
295 fn check_create_gpl<N: GenericParamsOwnerEdit + fmt::Display>(before: &str, after: &str) {
296 let gpl_owner = ast_mut_from_text::<N>(before);
297 gpl_owner.get_or_create_generic_param_list();
298 assert_eq!(gpl_owner.to_string(), after);
299 }
300
301 check_create_gpl::<ast::Fn>("fn foo", "fn foo<>");
302 check_create_gpl::<ast::Fn>("fn foo() {}", "fn foo<>() {}");
303
304 check_create_gpl::<ast::Impl>("impl", "impl<>");
305 check_create_gpl::<ast::Impl>("impl Struct {}", "impl<> Struct {}");
306 check_create_gpl::<ast::Impl>("impl Trait for Struct {}", "impl<> Trait for Struct {}");
307
308 check_create_gpl::<ast::Trait>("trait Trait<>", "trait Trait<>");
309 check_create_gpl::<ast::Trait>("trait Trait<> {}", "trait Trait<> {}");
310
311 check_create_gpl::<ast::Struct>("struct A", "struct A<>");
312 check_create_gpl::<ast::Struct>("struct A;", "struct A<>;");
313 check_create_gpl::<ast::Struct>("struct A();", "struct A<>();");
314 check_create_gpl::<ast::Struct>("struct A {}", "struct A<> {}");
315
316 check_create_gpl::<ast::Enum>("enum E", "enum E<>");
317 check_create_gpl::<ast::Enum>("enum E {", "enum E<> {");
318 }
319}
diff --git a/crates/syntax/src/ast/node_ext.rs b/crates/syntax/src/ast/node_ext.rs
index ae98dbd26..171099661 100644
--- a/crates/syntax/src/ast/node_ext.rs
+++ b/crates/syntax/src/ast/node_ext.rs
@@ -125,6 +125,18 @@ pub enum AttrKind {
125 Outer, 125 Outer,
126} 126}
127 127
128impl AttrKind {
129 /// Returns `true` if the attr_kind is [`Inner`].
130 pub fn is_inner(&self) -> bool {
131 matches!(self, Self::Inner)
132 }
133
134 /// Returns `true` if the attr_kind is [`Outer`].
135 pub fn is_outer(&self) -> bool {
136 matches!(self, Self::Outer)
137 }
138}
139
128impl ast::Attr { 140impl ast::Attr {
129 pub fn as_simple_atom(&self) -> Option<SmolStr> { 141 pub fn as_simple_atom(&self) -> Option<SmolStr> {
130 if self.eq_token().is_some() || self.token_tree().is_some() { 142 if self.eq_token().is_some() || self.token_tree().is_some() {