diff options
Diffstat (limited to 'crates/syntax/src/ast')
-rw-r--r-- | crates/syntax/src/ast/edit_in_place.rs | 153 | ||||
-rw-r--r-- | crates/syntax/src/ast/make.rs | 53 | ||||
-rw-r--r-- | crates/syntax/src/ast/node_ext.rs | 34 |
3 files changed, 233 insertions, 7 deletions
diff --git a/crates/syntax/src/ast/edit_in_place.rs b/crates/syntax/src/ast/edit_in_place.rs index 529bd0eb1..04f97f368 100644 --- a/crates/syntax/src/ast/edit_in_place.rs +++ b/crates/syntax/src/ast/edit_in_place.rs | |||
@@ -14,10 +14,29 @@ use crate::{ | |||
14 | use super::NameOwner; | 14 | use super::NameOwner; |
15 | 15 | ||
16 | pub trait GenericParamsOwnerEdit: ast::GenericParamsOwner + AstNodeEdit { | 16 | pub trait GenericParamsOwnerEdit: ast::GenericParamsOwner + AstNodeEdit { |
17 | fn get_or_create_generic_param_list(&self) -> ast::GenericParamList; | ||
17 | fn get_or_create_where_clause(&self) -> ast::WhereClause; | 18 | fn get_or_create_where_clause(&self) -> ast::WhereClause; |
18 | } | 19 | } |
19 | 20 | ||
20 | impl GenericParamsOwnerEdit for ast::Fn { | 21 | impl GenericParamsOwnerEdit for ast::Fn { |
22 | fn get_or_create_generic_param_list(&self) -> ast::GenericParamList { | ||
23 | match self.generic_param_list() { | ||
24 | Some(it) => it, | ||
25 | None => { | ||
26 | let position = if let Some(name) = self.name() { | ||
27 | Position::after(name.syntax) | ||
28 | } else if let Some(fn_token) = self.fn_token() { | ||
29 | Position::after(fn_token) | ||
30 | } else if let Some(param_list) = self.param_list() { | ||
31 | Position::before(param_list.syntax) | ||
32 | } else { | ||
33 | Position::last_child_of(self.syntax()) | ||
34 | }; | ||
35 | create_generic_param_list(position) | ||
36 | } | ||
37 | } | ||
38 | } | ||
39 | |||
21 | fn get_or_create_where_clause(&self) -> WhereClause { | 40 | fn get_or_create_where_clause(&self) -> WhereClause { |
22 | if self.where_clause().is_none() { | 41 | if self.where_clause().is_none() { |
23 | let position = if let Some(ty) = self.ret_type() { | 42 | let position = if let Some(ty) = self.ret_type() { |
@@ -34,6 +53,20 @@ impl GenericParamsOwnerEdit for ast::Fn { | |||
34 | } | 53 | } |
35 | 54 | ||
36 | impl GenericParamsOwnerEdit for ast::Impl { | 55 | impl GenericParamsOwnerEdit for ast::Impl { |
56 | fn get_or_create_generic_param_list(&self) -> ast::GenericParamList { | ||
57 | match self.generic_param_list() { | ||
58 | Some(it) => it, | ||
59 | None => { | ||
60 | let position = if let Some(imp_token) = self.impl_token() { | ||
61 | Position::after(imp_token) | ||
62 | } else { | ||
63 | Position::last_child_of(self.syntax()) | ||
64 | }; | ||
65 | create_generic_param_list(position) | ||
66 | } | ||
67 | } | ||
68 | } | ||
69 | |||
37 | fn get_or_create_where_clause(&self) -> WhereClause { | 70 | fn get_or_create_where_clause(&self) -> WhereClause { |
38 | if self.where_clause().is_none() { | 71 | if self.where_clause().is_none() { |
39 | let position = if let Some(items) = self.assoc_item_list() { | 72 | let position = if let Some(items) = self.assoc_item_list() { |
@@ -48,6 +81,22 @@ impl GenericParamsOwnerEdit for ast::Impl { | |||
48 | } | 81 | } |
49 | 82 | ||
50 | impl GenericParamsOwnerEdit for ast::Trait { | 83 | impl GenericParamsOwnerEdit for ast::Trait { |
84 | fn get_or_create_generic_param_list(&self) -> ast::GenericParamList { | ||
85 | match self.generic_param_list() { | ||
86 | Some(it) => it, | ||
87 | None => { | ||
88 | let position = if let Some(name) = self.name() { | ||
89 | Position::after(name.syntax) | ||
90 | } else if let Some(trait_token) = self.trait_token() { | ||
91 | Position::after(trait_token) | ||
92 | } else { | ||
93 | Position::last_child_of(self.syntax()) | ||
94 | }; | ||
95 | create_generic_param_list(position) | ||
96 | } | ||
97 | } | ||
98 | } | ||
99 | |||
51 | fn get_or_create_where_clause(&self) -> WhereClause { | 100 | fn get_or_create_where_clause(&self) -> WhereClause { |
52 | if self.where_clause().is_none() { | 101 | if self.where_clause().is_none() { |
53 | let position = if let Some(items) = self.assoc_item_list() { | 102 | let position = if let Some(items) = self.assoc_item_list() { |
@@ -62,6 +111,22 @@ impl GenericParamsOwnerEdit for ast::Trait { | |||
62 | } | 111 | } |
63 | 112 | ||
64 | impl GenericParamsOwnerEdit for ast::Struct { | 113 | impl GenericParamsOwnerEdit for ast::Struct { |
114 | fn get_or_create_generic_param_list(&self) -> ast::GenericParamList { | ||
115 | match self.generic_param_list() { | ||
116 | Some(it) => it, | ||
117 | None => { | ||
118 | let position = if let Some(name) = self.name() { | ||
119 | Position::after(name.syntax) | ||
120 | } else if let Some(struct_token) = self.struct_token() { | ||
121 | Position::after(struct_token) | ||
122 | } else { | ||
123 | Position::last_child_of(self.syntax()) | ||
124 | }; | ||
125 | create_generic_param_list(position) | ||
126 | } | ||
127 | } | ||
128 | } | ||
129 | |||
65 | fn get_or_create_where_clause(&self) -> WhereClause { | 130 | fn get_or_create_where_clause(&self) -> WhereClause { |
66 | if self.where_clause().is_none() { | 131 | if self.where_clause().is_none() { |
67 | let tfl = self.field_list().and_then(|fl| match fl { | 132 | let tfl = self.field_list().and_then(|fl| match fl { |
@@ -84,6 +149,22 @@ impl GenericParamsOwnerEdit for ast::Struct { | |||
84 | } | 149 | } |
85 | 150 | ||
86 | impl GenericParamsOwnerEdit for ast::Enum { | 151 | impl GenericParamsOwnerEdit for ast::Enum { |
152 | fn get_or_create_generic_param_list(&self) -> ast::GenericParamList { | ||
153 | match self.generic_param_list() { | ||
154 | Some(it) => it, | ||
155 | None => { | ||
156 | let position = if let Some(name) = self.name() { | ||
157 | Position::after(name.syntax) | ||
158 | } else if let Some(enum_token) = self.enum_token() { | ||
159 | Position::after(enum_token) | ||
160 | } else { | ||
161 | Position::last_child_of(self.syntax()) | ||
162 | }; | ||
163 | create_generic_param_list(position) | ||
164 | } | ||
165 | } | ||
166 | } | ||
167 | |||
87 | fn get_or_create_where_clause(&self) -> WhereClause { | 168 | fn get_or_create_where_clause(&self) -> WhereClause { |
88 | if self.where_clause().is_none() { | 169 | if self.where_clause().is_none() { |
89 | let position = if let Some(gpl) = self.generic_param_list() { | 170 | let position = if let Some(gpl) = self.generic_param_list() { |
@@ -104,6 +185,37 @@ fn create_where_clause(position: Position) { | |||
104 | ted::insert(position, where_clause.syntax()); | 185 | ted::insert(position, where_clause.syntax()); |
105 | } | 186 | } |
106 | 187 | ||
188 | fn create_generic_param_list(position: Position) -> ast::GenericParamList { | ||
189 | let gpl = make::generic_param_list(empty()).clone_for_update(); | ||
190 | ted::insert_raw(position, gpl.syntax()); | ||
191 | gpl | ||
192 | } | ||
193 | |||
194 | impl ast::GenericParamList { | ||
195 | pub fn add_generic_param(&self, generic_param: ast::GenericParam) { | ||
196 | match self.generic_params().last() { | ||
197 | Some(last_param) => { | ||
198 | let mut elems = Vec::new(); | ||
199 | if !last_param | ||
200 | .syntax() | ||
201 | .siblings_with_tokens(Direction::Next) | ||
202 | .any(|it| it.kind() == T![,]) | ||
203 | { | ||
204 | elems.push(make::token(T![,]).into()); | ||
205 | elems.push(make::tokens::single_space().into()); | ||
206 | }; | ||
207 | elems.push(generic_param.syntax().clone().into()); | ||
208 | let after_last_param = Position::after(last_param.syntax()); | ||
209 | ted::insert_all(after_last_param, elems); | ||
210 | } | ||
211 | None => { | ||
212 | let after_l_angle = Position::after(self.l_angle_token().unwrap()); | ||
213 | ted::insert(after_l_angle, generic_param.syntax()) | ||
214 | } | ||
215 | } | ||
216 | } | ||
217 | } | ||
218 | |||
107 | impl ast::WhereClause { | 219 | impl ast::WhereClause { |
108 | pub fn add_predicate(&self, predicate: ast::WherePred) { | 220 | pub fn add_predicate(&self, predicate: ast::WherePred) { |
109 | if let Some(pred) = self.predicates().last() { | 221 | if let Some(pred) = self.predicates().last() { |
@@ -164,3 +276,44 @@ impl ast::Use { | |||
164 | ted::remove(self.syntax()) | 276 | ted::remove(self.syntax()) |
165 | } | 277 | } |
166 | } | 278 | } |
279 | |||
280 | #[cfg(test)] | ||
281 | mod tests { | ||
282 | use std::fmt; | ||
283 | |||
284 | use crate::SourceFile; | ||
285 | |||
286 | use super::*; | ||
287 | |||
288 | fn ast_mut_from_text<N: AstNode>(text: &str) -> N { | ||
289 | let parse = SourceFile::parse(text); | ||
290 | parse.tree().syntax().descendants().find_map(N::cast).unwrap().clone_for_update() | ||
291 | } | ||
292 | |||
293 | #[test] | ||
294 | fn test_create_generic_param_list() { | ||
295 | fn check_create_gpl<N: GenericParamsOwnerEdit + fmt::Display>(before: &str, after: &str) { | ||
296 | let gpl_owner = ast_mut_from_text::<N>(before); | ||
297 | gpl_owner.get_or_create_generic_param_list(); | ||
298 | assert_eq!(gpl_owner.to_string(), after); | ||
299 | } | ||
300 | |||
301 | check_create_gpl::<ast::Fn>("fn foo", "fn foo<>"); | ||
302 | check_create_gpl::<ast::Fn>("fn foo() {}", "fn foo<>() {}"); | ||
303 | |||
304 | check_create_gpl::<ast::Impl>("impl", "impl<>"); | ||
305 | check_create_gpl::<ast::Impl>("impl Struct {}", "impl<> Struct {}"); | ||
306 | check_create_gpl::<ast::Impl>("impl Trait for Struct {}", "impl<> Trait for Struct {}"); | ||
307 | |||
308 | check_create_gpl::<ast::Trait>("trait Trait<>", "trait Trait<>"); | ||
309 | check_create_gpl::<ast::Trait>("trait Trait<> {}", "trait Trait<> {}"); | ||
310 | |||
311 | check_create_gpl::<ast::Struct>("struct A", "struct A<>"); | ||
312 | check_create_gpl::<ast::Struct>("struct A;", "struct A<>;"); | ||
313 | check_create_gpl::<ast::Struct>("struct A();", "struct A<>();"); | ||
314 | check_create_gpl::<ast::Struct>("struct A {}", "struct A<> {}"); | ||
315 | |||
316 | check_create_gpl::<ast::Enum>("enum E", "enum E<>"); | ||
317 | check_create_gpl::<ast::Enum>("enum E {", "enum E<> {"); | ||
318 | } | ||
319 | } | ||
diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs index c6a7b99b7..42da09606 100644 --- a/crates/syntax/src/ast/make.rs +++ b/crates/syntax/src/ast/make.rs | |||
@@ -29,9 +29,13 @@ pub fn ty(text: &str) -> ast::Type { | |||
29 | pub fn ty_unit() -> ast::Type { | 29 | pub fn ty_unit() -> ast::Type { |
30 | ty("()") | 30 | ty("()") |
31 | } | 31 | } |
32 | // FIXME: handle types of length == 1 | ||
33 | pub fn ty_tuple(types: impl IntoIterator<Item = ast::Type>) -> ast::Type { | 32 | pub fn ty_tuple(types: impl IntoIterator<Item = ast::Type>) -> ast::Type { |
34 | let contents = types.into_iter().join(", "); | 33 | let mut count: usize = 0; |
34 | let mut contents = types.into_iter().inspect(|_| count += 1).join(", "); | ||
35 | if count == 1 { | ||
36 | contents.push(','); | ||
37 | } | ||
38 | |||
35 | ty(&format!("({})", contents)) | 39 | ty(&format!("({})", contents)) |
36 | } | 40 | } |
37 | // FIXME: handle path to type | 41 | // FIXME: handle path to type |
@@ -133,6 +137,17 @@ pub fn use_(visibility: Option<ast::Visibility>, use_tree: ast::UseTree) -> ast: | |||
133 | ast_from_text(&format!("{}use {};", visibility, use_tree)) | 137 | ast_from_text(&format!("{}use {};", visibility, use_tree)) |
134 | } | 138 | } |
135 | 139 | ||
140 | pub fn record_expr(path: ast::Path, fields: ast::RecordExprFieldList) -> ast::RecordExpr { | ||
141 | ast_from_text(&format!("fn f() {{ {} {} }}", path, fields)) | ||
142 | } | ||
143 | |||
144 | pub fn record_expr_field_list( | ||
145 | fields: impl IntoIterator<Item = ast::RecordExprField>, | ||
146 | ) -> ast::RecordExprFieldList { | ||
147 | let fields = fields.into_iter().join(", "); | ||
148 | ast_from_text(&format!("fn f() {{ S {{ {} }} }}", fields)) | ||
149 | } | ||
150 | |||
136 | pub fn record_expr_field(name: ast::NameRef, expr: Option<ast::Expr>) -> ast::RecordExprField { | 151 | pub fn record_expr_field(name: ast::NameRef, expr: Option<ast::Expr>) -> ast::RecordExprField { |
137 | return match expr { | 152 | return match expr { |
138 | Some(expr) => from_text(&format!("{}: {}", name, expr)), | 153 | Some(expr) => from_text(&format!("{}: {}", name, expr)), |
@@ -290,13 +305,23 @@ pub fn wildcard_pat() -> ast::WildcardPat { | |||
290 | } | 305 | } |
291 | } | 306 | } |
292 | 307 | ||
308 | pub fn literal_pat(lit: &str) -> ast::LiteralPat { | ||
309 | return from_text(lit); | ||
310 | |||
311 | fn from_text(text: &str) -> ast::LiteralPat { | ||
312 | ast_from_text(&format!("fn f() {{ match x {{ {} => {{}} }} }}", text)) | ||
313 | } | ||
314 | } | ||
315 | |||
293 | /// Creates a tuple of patterns from an iterator of patterns. | 316 | /// Creates a tuple of patterns from an iterator of patterns. |
294 | /// | 317 | /// |
295 | /// Invariant: `pats` must be length > 1 | 318 | /// Invariant: `pats` must be length > 0 |
296 | /// | ||
297 | /// FIXME handle `pats` length == 1 | ||
298 | pub fn tuple_pat(pats: impl IntoIterator<Item = ast::Pat>) -> ast::TuplePat { | 319 | pub fn tuple_pat(pats: impl IntoIterator<Item = ast::Pat>) -> ast::TuplePat { |
299 | let pats_str = pats.into_iter().map(|p| p.to_string()).join(", "); | 320 | let mut count: usize = 0; |
321 | let mut pats_str = pats.into_iter().inspect(|_| count += 1).join(", "); | ||
322 | if count == 1 { | ||
323 | pats_str.push(','); | ||
324 | } | ||
300 | return from_text(&format!("({})", pats_str)); | 325 | return from_text(&format!("({})", pats_str)); |
301 | 326 | ||
302 | fn from_text(text: &str) -> ast::TuplePat { | 327 | fn from_text(text: &str) -> ast::TuplePat { |
@@ -325,6 +350,21 @@ pub fn record_pat(path: ast::Path, pats: impl IntoIterator<Item = ast::Pat>) -> | |||
325 | } | 350 | } |
326 | } | 351 | } |
327 | 352 | ||
353 | pub fn record_pat_with_fields(path: ast::Path, fields: ast::RecordPatFieldList) -> ast::RecordPat { | ||
354 | ast_from_text(&format!("fn f({} {}: ()))", path, fields)) | ||
355 | } | ||
356 | |||
357 | pub fn record_pat_field_list( | ||
358 | fields: impl IntoIterator<Item = ast::RecordPatField>, | ||
359 | ) -> ast::RecordPatFieldList { | ||
360 | let fields = fields.into_iter().join(", "); | ||
361 | ast_from_text(&format!("fn f(S {{ {} }}: ()))", fields)) | ||
362 | } | ||
363 | |||
364 | pub fn record_pat_field(name_ref: ast::NameRef, pat: ast::Pat) -> ast::RecordPatField { | ||
365 | ast_from_text(&format!("fn f(S {{ {}: {} }}: ()))", name_ref, pat)) | ||
366 | } | ||
367 | |||
328 | /// Returns a `BindPat` if the path has just one segment, a `PathPat` otherwise. | 368 | /// Returns a `BindPat` if the path has just one segment, a `PathPat` otherwise. |
329 | pub fn path_pat(path: ast::Path) -> ast::Pat { | 369 | pub fn path_pat(path: ast::Path) -> ast::Pat { |
330 | return from_text(&path.to_string()); | 370 | return from_text(&path.to_string()); |
@@ -592,6 +632,7 @@ pub mod tokens { | |||
592 | SOURCE_FILE | 632 | SOURCE_FILE |
593 | .tree() | 633 | .tree() |
594 | .syntax() | 634 | .syntax() |
635 | .clone_for_update() | ||
595 | .descendants_with_tokens() | 636 | .descendants_with_tokens() |
596 | .filter_map(|it| it.into_token()) | 637 | .filter_map(|it| it.into_token()) |
597 | .find(|it| it.kind() == WHITESPACE && it.text() == "\n\n") | 638 | .find(|it| it.kind() == WHITESPACE && it.text() == "\n\n") |
diff --git a/crates/syntax/src/ast/node_ext.rs b/crates/syntax/src/ast/node_ext.rs index ae98dbd26..492fbc4a0 100644 --- a/crates/syntax/src/ast/node_ext.rs +++ b/crates/syntax/src/ast/node_ext.rs | |||
@@ -1,7 +1,7 @@ | |||
1 | //! Various extension methods to ast Nodes, which are hard to code-generate. | 1 | //! Various extension methods to ast Nodes, which are hard to code-generate. |
2 | //! Extensions for various expressions live in a sibling `expr_extensions` module. | 2 | //! Extensions for various expressions live in a sibling `expr_extensions` module. |
3 | 3 | ||
4 | use std::fmt; | 4 | use std::{fmt, iter::successors}; |
5 | 5 | ||
6 | use itertools::Itertools; | 6 | use itertools::Itertools; |
7 | use parser::SyntaxKind; | 7 | use parser::SyntaxKind; |
@@ -125,6 +125,18 @@ pub enum AttrKind { | |||
125 | Outer, | 125 | Outer, |
126 | } | 126 | } |
127 | 127 | ||
128 | impl AttrKind { | ||
129 | /// Returns `true` if the attr_kind is [`Inner`]. | ||
130 | pub fn is_inner(&self) -> bool { | ||
131 | matches!(self, Self::Inner) | ||
132 | } | ||
133 | |||
134 | /// Returns `true` if the attr_kind is [`Outer`]. | ||
135 | pub fn is_outer(&self) -> bool { | ||
136 | matches!(self, Self::Outer) | ||
137 | } | ||
138 | } | ||
139 | |||
128 | impl ast::Attr { | 140 | impl ast::Attr { |
129 | pub fn as_simple_atom(&self) -> Option<SmolStr> { | 141 | pub fn as_simple_atom(&self) -> Option<SmolStr> { |
130 | if self.eq_token().is_some() || self.token_tree().is_some() { | 142 | if self.eq_token().is_some() || self.token_tree().is_some() { |
@@ -225,6 +237,26 @@ impl ast::Path { | |||
225 | None => self.segment(), | 237 | None => self.segment(), |
226 | } | 238 | } |
227 | } | 239 | } |
240 | |||
241 | pub fn first_qualifier_or_self(&self) -> ast::Path { | ||
242 | successors(Some(self.clone()), ast::Path::qualifier).last().unwrap() | ||
243 | } | ||
244 | |||
245 | pub fn first_segment(&self) -> Option<ast::PathSegment> { | ||
246 | self.first_qualifier_or_self().segment() | ||
247 | } | ||
248 | |||
249 | pub fn segments(&self) -> impl Iterator<Item = ast::PathSegment> + Clone { | ||
250 | // cant make use of SyntaxNode::siblings, because the returned Iterator is not clone | ||
251 | successors(self.first_segment(), |p| { | ||
252 | p.parent_path().parent_path().and_then(|p| p.segment()) | ||
253 | }) | ||
254 | } | ||
255 | } | ||
256 | impl ast::UseTree { | ||
257 | pub fn is_simple_path(&self) -> bool { | ||
258 | self.use_tree_list().is_none() && self.star_token().is_none() | ||
259 | } | ||
228 | } | 260 | } |
229 | 261 | ||
230 | impl ast::UseTreeList { | 262 | impl ast::UseTreeList { |