aboutsummaryrefslogtreecommitdiff
path: root/crates/syntax/src/ast
diff options
context:
space:
mode:
Diffstat (limited to 'crates/syntax/src/ast')
-rw-r--r--crates/syntax/src/ast/edit_in_place.rs153
-rw-r--r--crates/syntax/src/ast/make.rs53
-rw-r--r--crates/syntax/src/ast/node_ext.rs34
3 files changed, 233 insertions, 7 deletions
diff --git a/crates/syntax/src/ast/edit_in_place.rs b/crates/syntax/src/ast/edit_in_place.rs
index 529bd0eb1..04f97f368 100644
--- a/crates/syntax/src/ast/edit_in_place.rs
+++ b/crates/syntax/src/ast/edit_in_place.rs
@@ -14,10 +14,29 @@ use crate::{
14use super::NameOwner; 14use super::NameOwner;
15 15
16pub trait GenericParamsOwnerEdit: ast::GenericParamsOwner + AstNodeEdit { 16pub trait GenericParamsOwnerEdit: ast::GenericParamsOwner + AstNodeEdit {
17 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList;
17 fn get_or_create_where_clause(&self) -> ast::WhereClause; 18 fn get_or_create_where_clause(&self) -> ast::WhereClause;
18} 19}
19 20
20impl GenericParamsOwnerEdit for ast::Fn { 21impl GenericParamsOwnerEdit for ast::Fn {
22 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
23 match self.generic_param_list() {
24 Some(it) => it,
25 None => {
26 let position = if let Some(name) = self.name() {
27 Position::after(name.syntax)
28 } else if let Some(fn_token) = self.fn_token() {
29 Position::after(fn_token)
30 } else if let Some(param_list) = self.param_list() {
31 Position::before(param_list.syntax)
32 } else {
33 Position::last_child_of(self.syntax())
34 };
35 create_generic_param_list(position)
36 }
37 }
38 }
39
21 fn get_or_create_where_clause(&self) -> WhereClause { 40 fn get_or_create_where_clause(&self) -> WhereClause {
22 if self.where_clause().is_none() { 41 if self.where_clause().is_none() {
23 let position = if let Some(ty) = self.ret_type() { 42 let position = if let Some(ty) = self.ret_type() {
@@ -34,6 +53,20 @@ impl GenericParamsOwnerEdit for ast::Fn {
34} 53}
35 54
36impl GenericParamsOwnerEdit for ast::Impl { 55impl GenericParamsOwnerEdit for ast::Impl {
56 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
57 match self.generic_param_list() {
58 Some(it) => it,
59 None => {
60 let position = if let Some(imp_token) = self.impl_token() {
61 Position::after(imp_token)
62 } else {
63 Position::last_child_of(self.syntax())
64 };
65 create_generic_param_list(position)
66 }
67 }
68 }
69
37 fn get_or_create_where_clause(&self) -> WhereClause { 70 fn get_or_create_where_clause(&self) -> WhereClause {
38 if self.where_clause().is_none() { 71 if self.where_clause().is_none() {
39 let position = if let Some(items) = self.assoc_item_list() { 72 let position = if let Some(items) = self.assoc_item_list() {
@@ -48,6 +81,22 @@ impl GenericParamsOwnerEdit for ast::Impl {
48} 81}
49 82
50impl GenericParamsOwnerEdit for ast::Trait { 83impl GenericParamsOwnerEdit for ast::Trait {
84 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
85 match self.generic_param_list() {
86 Some(it) => it,
87 None => {
88 let position = if let Some(name) = self.name() {
89 Position::after(name.syntax)
90 } else if let Some(trait_token) = self.trait_token() {
91 Position::after(trait_token)
92 } else {
93 Position::last_child_of(self.syntax())
94 };
95 create_generic_param_list(position)
96 }
97 }
98 }
99
51 fn get_or_create_where_clause(&self) -> WhereClause { 100 fn get_or_create_where_clause(&self) -> WhereClause {
52 if self.where_clause().is_none() { 101 if self.where_clause().is_none() {
53 let position = if let Some(items) = self.assoc_item_list() { 102 let position = if let Some(items) = self.assoc_item_list() {
@@ -62,6 +111,22 @@ impl GenericParamsOwnerEdit for ast::Trait {
62} 111}
63 112
64impl GenericParamsOwnerEdit for ast::Struct { 113impl GenericParamsOwnerEdit for ast::Struct {
114 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
115 match self.generic_param_list() {
116 Some(it) => it,
117 None => {
118 let position = if let Some(name) = self.name() {
119 Position::after(name.syntax)
120 } else if let Some(struct_token) = self.struct_token() {
121 Position::after(struct_token)
122 } else {
123 Position::last_child_of(self.syntax())
124 };
125 create_generic_param_list(position)
126 }
127 }
128 }
129
65 fn get_or_create_where_clause(&self) -> WhereClause { 130 fn get_or_create_where_clause(&self) -> WhereClause {
66 if self.where_clause().is_none() { 131 if self.where_clause().is_none() {
67 let tfl = self.field_list().and_then(|fl| match fl { 132 let tfl = self.field_list().and_then(|fl| match fl {
@@ -84,6 +149,22 @@ impl GenericParamsOwnerEdit for ast::Struct {
84} 149}
85 150
86impl GenericParamsOwnerEdit for ast::Enum { 151impl GenericParamsOwnerEdit for ast::Enum {
152 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
153 match self.generic_param_list() {
154 Some(it) => it,
155 None => {
156 let position = if let Some(name) = self.name() {
157 Position::after(name.syntax)
158 } else if let Some(enum_token) = self.enum_token() {
159 Position::after(enum_token)
160 } else {
161 Position::last_child_of(self.syntax())
162 };
163 create_generic_param_list(position)
164 }
165 }
166 }
167
87 fn get_or_create_where_clause(&self) -> WhereClause { 168 fn get_or_create_where_clause(&self) -> WhereClause {
88 if self.where_clause().is_none() { 169 if self.where_clause().is_none() {
89 let position = if let Some(gpl) = self.generic_param_list() { 170 let position = if let Some(gpl) = self.generic_param_list() {
@@ -104,6 +185,37 @@ fn create_where_clause(position: Position) {
104 ted::insert(position, where_clause.syntax()); 185 ted::insert(position, where_clause.syntax());
105} 186}
106 187
188fn create_generic_param_list(position: Position) -> ast::GenericParamList {
189 let gpl = make::generic_param_list(empty()).clone_for_update();
190 ted::insert_raw(position, gpl.syntax());
191 gpl
192}
193
194impl ast::GenericParamList {
195 pub fn add_generic_param(&self, generic_param: ast::GenericParam) {
196 match self.generic_params().last() {
197 Some(last_param) => {
198 let mut elems = Vec::new();
199 if !last_param
200 .syntax()
201 .siblings_with_tokens(Direction::Next)
202 .any(|it| it.kind() == T![,])
203 {
204 elems.push(make::token(T![,]).into());
205 elems.push(make::tokens::single_space().into());
206 };
207 elems.push(generic_param.syntax().clone().into());
208 let after_last_param = Position::after(last_param.syntax());
209 ted::insert_all(after_last_param, elems);
210 }
211 None => {
212 let after_l_angle = Position::after(self.l_angle_token().unwrap());
213 ted::insert(after_l_angle, generic_param.syntax())
214 }
215 }
216 }
217}
218
107impl ast::WhereClause { 219impl ast::WhereClause {
108 pub fn add_predicate(&self, predicate: ast::WherePred) { 220 pub fn add_predicate(&self, predicate: ast::WherePred) {
109 if let Some(pred) = self.predicates().last() { 221 if let Some(pred) = self.predicates().last() {
@@ -164,3 +276,44 @@ impl ast::Use {
164 ted::remove(self.syntax()) 276 ted::remove(self.syntax())
165 } 277 }
166} 278}
279
280#[cfg(test)]
281mod tests {
282 use std::fmt;
283
284 use crate::SourceFile;
285
286 use super::*;
287
288 fn ast_mut_from_text<N: AstNode>(text: &str) -> N {
289 let parse = SourceFile::parse(text);
290 parse.tree().syntax().descendants().find_map(N::cast).unwrap().clone_for_update()
291 }
292
293 #[test]
294 fn test_create_generic_param_list() {
295 fn check_create_gpl<N: GenericParamsOwnerEdit + fmt::Display>(before: &str, after: &str) {
296 let gpl_owner = ast_mut_from_text::<N>(before);
297 gpl_owner.get_or_create_generic_param_list();
298 assert_eq!(gpl_owner.to_string(), after);
299 }
300
301 check_create_gpl::<ast::Fn>("fn foo", "fn foo<>");
302 check_create_gpl::<ast::Fn>("fn foo() {}", "fn foo<>() {}");
303
304 check_create_gpl::<ast::Impl>("impl", "impl<>");
305 check_create_gpl::<ast::Impl>("impl Struct {}", "impl<> Struct {}");
306 check_create_gpl::<ast::Impl>("impl Trait for Struct {}", "impl<> Trait for Struct {}");
307
308 check_create_gpl::<ast::Trait>("trait Trait<>", "trait Trait<>");
309 check_create_gpl::<ast::Trait>("trait Trait<> {}", "trait Trait<> {}");
310
311 check_create_gpl::<ast::Struct>("struct A", "struct A<>");
312 check_create_gpl::<ast::Struct>("struct A;", "struct A<>;");
313 check_create_gpl::<ast::Struct>("struct A();", "struct A<>();");
314 check_create_gpl::<ast::Struct>("struct A {}", "struct A<> {}");
315
316 check_create_gpl::<ast::Enum>("enum E", "enum E<>");
317 check_create_gpl::<ast::Enum>("enum E {", "enum E<> {");
318 }
319}
diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs
index c6a7b99b7..42da09606 100644
--- a/crates/syntax/src/ast/make.rs
+++ b/crates/syntax/src/ast/make.rs
@@ -29,9 +29,13 @@ pub fn ty(text: &str) -> ast::Type {
29pub fn ty_unit() -> ast::Type { 29pub fn ty_unit() -> ast::Type {
30 ty("()") 30 ty("()")
31} 31}
32// FIXME: handle types of length == 1
33pub fn ty_tuple(types: impl IntoIterator<Item = ast::Type>) -> ast::Type { 32pub fn ty_tuple(types: impl IntoIterator<Item = ast::Type>) -> ast::Type {
34 let contents = types.into_iter().join(", "); 33 let mut count: usize = 0;
34 let mut contents = types.into_iter().inspect(|_| count += 1).join(", ");
35 if count == 1 {
36 contents.push(',');
37 }
38
35 ty(&format!("({})", contents)) 39 ty(&format!("({})", contents))
36} 40}
37// FIXME: handle path to type 41// FIXME: handle path to type
@@ -133,6 +137,17 @@ pub fn use_(visibility: Option<ast::Visibility>, use_tree: ast::UseTree) -> ast:
133 ast_from_text(&format!("{}use {};", visibility, use_tree)) 137 ast_from_text(&format!("{}use {};", visibility, use_tree))
134} 138}
135 139
140pub fn record_expr(path: ast::Path, fields: ast::RecordExprFieldList) -> ast::RecordExpr {
141 ast_from_text(&format!("fn f() {{ {} {} }}", path, fields))
142}
143
144pub fn record_expr_field_list(
145 fields: impl IntoIterator<Item = ast::RecordExprField>,
146) -> ast::RecordExprFieldList {
147 let fields = fields.into_iter().join(", ");
148 ast_from_text(&format!("fn f() {{ S {{ {} }} }}", fields))
149}
150
136pub fn record_expr_field(name: ast::NameRef, expr: Option<ast::Expr>) -> ast::RecordExprField { 151pub fn record_expr_field(name: ast::NameRef, expr: Option<ast::Expr>) -> ast::RecordExprField {
137 return match expr { 152 return match expr {
138 Some(expr) => from_text(&format!("{}: {}", name, expr)), 153 Some(expr) => from_text(&format!("{}: {}", name, expr)),
@@ -290,13 +305,23 @@ pub fn wildcard_pat() -> ast::WildcardPat {
290 } 305 }
291} 306}
292 307
308pub fn literal_pat(lit: &str) -> ast::LiteralPat {
309 return from_text(lit);
310
311 fn from_text(text: &str) -> ast::LiteralPat {
312 ast_from_text(&format!("fn f() {{ match x {{ {} => {{}} }} }}", text))
313 }
314}
315
293/// Creates a tuple of patterns from an iterator of patterns. 316/// Creates a tuple of patterns from an iterator of patterns.
294/// 317///
295/// Invariant: `pats` must be length > 1 318/// Invariant: `pats` must be length > 0
296///
297/// FIXME handle `pats` length == 1
298pub fn tuple_pat(pats: impl IntoIterator<Item = ast::Pat>) -> ast::TuplePat { 319pub fn tuple_pat(pats: impl IntoIterator<Item = ast::Pat>) -> ast::TuplePat {
299 let pats_str = pats.into_iter().map(|p| p.to_string()).join(", "); 320 let mut count: usize = 0;
321 let mut pats_str = pats.into_iter().inspect(|_| count += 1).join(", ");
322 if count == 1 {
323 pats_str.push(',');
324 }
300 return from_text(&format!("({})", pats_str)); 325 return from_text(&format!("({})", pats_str));
301 326
302 fn from_text(text: &str) -> ast::TuplePat { 327 fn from_text(text: &str) -> ast::TuplePat {
@@ -325,6 +350,21 @@ pub fn record_pat(path: ast::Path, pats: impl IntoIterator<Item = ast::Pat>) ->
325 } 350 }
326} 351}
327 352
353pub fn record_pat_with_fields(path: ast::Path, fields: ast::RecordPatFieldList) -> ast::RecordPat {
354 ast_from_text(&format!("fn f({} {}: ()))", path, fields))
355}
356
357pub fn record_pat_field_list(
358 fields: impl IntoIterator<Item = ast::RecordPatField>,
359) -> ast::RecordPatFieldList {
360 let fields = fields.into_iter().join(", ");
361 ast_from_text(&format!("fn f(S {{ {} }}: ()))", fields))
362}
363
364pub fn record_pat_field(name_ref: ast::NameRef, pat: ast::Pat) -> ast::RecordPatField {
365 ast_from_text(&format!("fn f(S {{ {}: {} }}: ()))", name_ref, pat))
366}
367
328/// Returns a `BindPat` if the path has just one segment, a `PathPat` otherwise. 368/// Returns a `BindPat` if the path has just one segment, a `PathPat` otherwise.
329pub fn path_pat(path: ast::Path) -> ast::Pat { 369pub fn path_pat(path: ast::Path) -> ast::Pat {
330 return from_text(&path.to_string()); 370 return from_text(&path.to_string());
@@ -592,6 +632,7 @@ pub mod tokens {
592 SOURCE_FILE 632 SOURCE_FILE
593 .tree() 633 .tree()
594 .syntax() 634 .syntax()
635 .clone_for_update()
595 .descendants_with_tokens() 636 .descendants_with_tokens()
596 .filter_map(|it| it.into_token()) 637 .filter_map(|it| it.into_token())
597 .find(|it| it.kind() == WHITESPACE && it.text() == "\n\n") 638 .find(|it| it.kind() == WHITESPACE && it.text() == "\n\n")
diff --git a/crates/syntax/src/ast/node_ext.rs b/crates/syntax/src/ast/node_ext.rs
index ae98dbd26..492fbc4a0 100644
--- a/crates/syntax/src/ast/node_ext.rs
+++ b/crates/syntax/src/ast/node_ext.rs
@@ -1,7 +1,7 @@
1//! Various extension methods to ast Nodes, which are hard to code-generate. 1//! Various extension methods to ast Nodes, which are hard to code-generate.
2//! Extensions for various expressions live in a sibling `expr_extensions` module. 2//! Extensions for various expressions live in a sibling `expr_extensions` module.
3 3
4use std::fmt; 4use std::{fmt, iter::successors};
5 5
6use itertools::Itertools; 6use itertools::Itertools;
7use parser::SyntaxKind; 7use parser::SyntaxKind;
@@ -125,6 +125,18 @@ pub enum AttrKind {
125 Outer, 125 Outer,
126} 126}
127 127
128impl AttrKind {
129 /// Returns `true` if the attr_kind is [`Inner`].
130 pub fn is_inner(&self) -> bool {
131 matches!(self, Self::Inner)
132 }
133
134 /// Returns `true` if the attr_kind is [`Outer`].
135 pub fn is_outer(&self) -> bool {
136 matches!(self, Self::Outer)
137 }
138}
139
128impl ast::Attr { 140impl ast::Attr {
129 pub fn as_simple_atom(&self) -> Option<SmolStr> { 141 pub fn as_simple_atom(&self) -> Option<SmolStr> {
130 if self.eq_token().is_some() || self.token_tree().is_some() { 142 if self.eq_token().is_some() || self.token_tree().is_some() {
@@ -225,6 +237,26 @@ impl ast::Path {
225 None => self.segment(), 237 None => self.segment(),
226 } 238 }
227 } 239 }
240
241 pub fn first_qualifier_or_self(&self) -> ast::Path {
242 successors(Some(self.clone()), ast::Path::qualifier).last().unwrap()
243 }
244
245 pub fn first_segment(&self) -> Option<ast::PathSegment> {
246 self.first_qualifier_or_self().segment()
247 }
248
249 pub fn segments(&self) -> impl Iterator<Item = ast::PathSegment> + Clone {
250 // cant make use of SyntaxNode::siblings, because the returned Iterator is not clone
251 successors(self.first_segment(), |p| {
252 p.parent_path().parent_path().and_then(|p| p.segment())
253 })
254 }
255}
256impl ast::UseTree {
257 pub fn is_simple_path(&self) -> bool {
258 self.use_tree_list().is_none() && self.star_token().is_none()
259 }
228} 260}
229 261
230impl ast::UseTreeList { 262impl ast::UseTreeList {