diff options
Diffstat (limited to 'crates/syntax/src/algo.rs')
-rw-r--r-- | crates/syntax/src/algo.rs | 406 |
1 files changed, 406 insertions, 0 deletions
diff --git a/crates/syntax/src/algo.rs b/crates/syntax/src/algo.rs new file mode 100644 index 000000000..6254b38ba --- /dev/null +++ b/crates/syntax/src/algo.rs | |||
@@ -0,0 +1,406 @@ | |||
1 | //! FIXME: write short doc here | ||
2 | |||
3 | use std::{ | ||
4 | fmt, | ||
5 | ops::{self, RangeInclusive}, | ||
6 | }; | ||
7 | |||
8 | use itertools::Itertools; | ||
9 | use rustc_hash::FxHashMap; | ||
10 | use text_edit::TextEditBuilder; | ||
11 | |||
12 | use crate::{ | ||
13 | AstNode, Direction, NodeOrToken, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxNodePtr, | ||
14 | SyntaxToken, TextRange, TextSize, | ||
15 | }; | ||
16 | |||
17 | /// Returns ancestors of the node at the offset, sorted by length. This should | ||
18 | /// do the right thing at an edge, e.g. when searching for expressions at `{ | ||
19 | /// <|>foo }` we will get the name reference instead of the whole block, which | ||
20 | /// we would get if we just did `find_token_at_offset(...).flat_map(|t| | ||
21 | /// t.parent().ancestors())`. | ||
22 | pub fn ancestors_at_offset( | ||
23 | node: &SyntaxNode, | ||
24 | offset: TextSize, | ||
25 | ) -> impl Iterator<Item = SyntaxNode> { | ||
26 | node.token_at_offset(offset) | ||
27 | .map(|token| token.parent().ancestors()) | ||
28 | .kmerge_by(|node1, node2| node1.text_range().len() < node2.text_range().len()) | ||
29 | } | ||
30 | |||
31 | /// Finds a node of specific Ast type at offset. Note that this is slightly | ||
32 | /// imprecise: if the cursor is strictly between two nodes of the desired type, | ||
33 | /// as in | ||
34 | /// | ||
35 | /// ```no-run | ||
36 | /// struct Foo {}|struct Bar; | ||
37 | /// ``` | ||
38 | /// | ||
39 | /// then the shorter node will be silently preferred. | ||
40 | pub fn find_node_at_offset<N: AstNode>(syntax: &SyntaxNode, offset: TextSize) -> Option<N> { | ||
41 | ancestors_at_offset(syntax, offset).find_map(N::cast) | ||
42 | } | ||
43 | |||
44 | pub fn find_node_at_range<N: AstNode>(syntax: &SyntaxNode, range: TextRange) -> Option<N> { | ||
45 | find_covering_element(syntax, range).ancestors().find_map(N::cast) | ||
46 | } | ||
47 | |||
48 | /// Skip to next non `trivia` token | ||
49 | pub fn skip_trivia_token(mut token: SyntaxToken, direction: Direction) -> Option<SyntaxToken> { | ||
50 | while token.kind().is_trivia() { | ||
51 | token = match direction { | ||
52 | Direction::Next => token.next_token()?, | ||
53 | Direction::Prev => token.prev_token()?, | ||
54 | } | ||
55 | } | ||
56 | Some(token) | ||
57 | } | ||
58 | |||
59 | /// Finds the first sibling in the given direction which is not `trivia` | ||
60 | pub fn non_trivia_sibling(element: SyntaxElement, direction: Direction) -> Option<SyntaxElement> { | ||
61 | return match element { | ||
62 | NodeOrToken::Node(node) => node.siblings_with_tokens(direction).skip(1).find(not_trivia), | ||
63 | NodeOrToken::Token(token) => token.siblings_with_tokens(direction).skip(1).find(not_trivia), | ||
64 | }; | ||
65 | |||
66 | fn not_trivia(element: &SyntaxElement) -> bool { | ||
67 | match element { | ||
68 | NodeOrToken::Node(_) => true, | ||
69 | NodeOrToken::Token(token) => !token.kind().is_trivia(), | ||
70 | } | ||
71 | } | ||
72 | } | ||
73 | |||
74 | pub fn find_covering_element(root: &SyntaxNode, range: TextRange) -> SyntaxElement { | ||
75 | root.covering_element(range) | ||
76 | } | ||
77 | |||
78 | pub fn least_common_ancestor(u: &SyntaxNode, v: &SyntaxNode) -> Option<SyntaxNode> { | ||
79 | if u == v { | ||
80 | return Some(u.clone()); | ||
81 | } | ||
82 | |||
83 | let u_depth = u.ancestors().count(); | ||
84 | let v_depth = v.ancestors().count(); | ||
85 | let keep = u_depth.min(v_depth); | ||
86 | |||
87 | let u_candidates = u.ancestors().skip(u_depth - keep); | ||
88 | let v_canidates = v.ancestors().skip(v_depth - keep); | ||
89 | let (res, _) = u_candidates.zip(v_canidates).find(|(x, y)| x == y)?; | ||
90 | Some(res) | ||
91 | } | ||
92 | |||
93 | pub fn neighbor<T: AstNode>(me: &T, direction: Direction) -> Option<T> { | ||
94 | me.syntax().siblings(direction).skip(1).find_map(T::cast) | ||
95 | } | ||
96 | |||
97 | pub fn has_errors(node: &SyntaxNode) -> bool { | ||
98 | node.children().any(|it| it.kind() == SyntaxKind::ERROR) | ||
99 | } | ||
100 | |||
/// Where new children should be placed relative to a parent's existing
/// children.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InsertPosition<T> {
    /// Before all existing children.
    First,
    /// After all existing children.
    Last,
    /// Immediately before the given anchor child.
    Before(T),
    /// Immediately after the given anchor child.
    After(T),
}
108 | |||
109 | pub struct TreeDiff { | ||
110 | replacements: FxHashMap<SyntaxElement, SyntaxElement>, | ||
111 | } | ||
112 | |||
113 | impl TreeDiff { | ||
114 | pub fn into_text_edit(&self, builder: &mut TextEditBuilder) { | ||
115 | for (from, to) in self.replacements.iter() { | ||
116 | builder.replace(from.text_range(), to.to_string()) | ||
117 | } | ||
118 | } | ||
119 | |||
120 | pub fn is_empty(&self) -> bool { | ||
121 | self.replacements.is_empty() | ||
122 | } | ||
123 | } | ||
124 | |||
125 | /// Finds minimal the diff, which, applied to `from`, will result in `to`. | ||
126 | /// | ||
127 | /// Specifically, returns a map whose keys are descendants of `from` and values | ||
128 | /// are descendants of `to`, such that `replace_descendants(from, map) == to`. | ||
129 | /// | ||
130 | /// A trivial solution is a singleton map `{ from: to }`, but this function | ||
131 | /// tries to find a more fine-grained diff. | ||
132 | pub fn diff(from: &SyntaxNode, to: &SyntaxNode) -> TreeDiff { | ||
133 | let mut buf = FxHashMap::default(); | ||
134 | // FIXME: this is both horrible inefficient and gives larger than | ||
135 | // necessary diff. I bet there's a cool algorithm to diff trees properly. | ||
136 | go(&mut buf, from.clone().into(), to.clone().into()); | ||
137 | return TreeDiff { replacements: buf }; | ||
138 | |||
139 | fn go( | ||
140 | buf: &mut FxHashMap<SyntaxElement, SyntaxElement>, | ||
141 | lhs: SyntaxElement, | ||
142 | rhs: SyntaxElement, | ||
143 | ) { | ||
144 | if lhs.kind() == rhs.kind() | ||
145 | && lhs.text_range().len() == rhs.text_range().len() | ||
146 | && match (&lhs, &rhs) { | ||
147 | (NodeOrToken::Node(lhs), NodeOrToken::Node(rhs)) => { | ||
148 | lhs.green() == rhs.green() || lhs.text() == rhs.text() | ||
149 | } | ||
150 | (NodeOrToken::Token(lhs), NodeOrToken::Token(rhs)) => lhs.text() == rhs.text(), | ||
151 | _ => false, | ||
152 | } | ||
153 | { | ||
154 | return; | ||
155 | } | ||
156 | if let (Some(lhs), Some(rhs)) = (lhs.as_node(), rhs.as_node()) { | ||
157 | if lhs.children_with_tokens().count() == rhs.children_with_tokens().count() { | ||
158 | for (lhs, rhs) in lhs.children_with_tokens().zip(rhs.children_with_tokens()) { | ||
159 | go(buf, lhs, rhs) | ||
160 | } | ||
161 | return; | ||
162 | } | ||
163 | } | ||
164 | buf.insert(lhs, rhs); | ||
165 | } | ||
166 | } | ||
167 | |||
168 | /// Adds specified children (tokens or nodes) to the current node at the | ||
169 | /// specific position. | ||
170 | /// | ||
171 | /// This is a type-unsafe low-level editing API, if you need to use it, | ||
172 | /// prefer to create a type-safe abstraction on top of it instead. | ||
173 | pub fn insert_children( | ||
174 | parent: &SyntaxNode, | ||
175 | position: InsertPosition<SyntaxElement>, | ||
176 | to_insert: impl IntoIterator<Item = SyntaxElement>, | ||
177 | ) -> SyntaxNode { | ||
178 | let mut to_insert = to_insert.into_iter(); | ||
179 | _insert_children(parent, position, &mut to_insert) | ||
180 | } | ||
181 | |||
182 | fn _insert_children( | ||
183 | parent: &SyntaxNode, | ||
184 | position: InsertPosition<SyntaxElement>, | ||
185 | to_insert: &mut dyn Iterator<Item = SyntaxElement>, | ||
186 | ) -> SyntaxNode { | ||
187 | let mut delta = TextSize::default(); | ||
188 | let to_insert = to_insert.map(|element| { | ||
189 | delta += element.text_range().len(); | ||
190 | to_green_element(element) | ||
191 | }); | ||
192 | |||
193 | let mut old_children = parent.green().children().map(|it| match it { | ||
194 | NodeOrToken::Token(it) => NodeOrToken::Token(it.clone()), | ||
195 | NodeOrToken::Node(it) => NodeOrToken::Node(it.clone()), | ||
196 | }); | ||
197 | |||
198 | let new_children = match &position { | ||
199 | InsertPosition::First => to_insert.chain(old_children).collect::<Vec<_>>(), | ||
200 | InsertPosition::Last => old_children.chain(to_insert).collect::<Vec<_>>(), | ||
201 | InsertPosition::Before(anchor) | InsertPosition::After(anchor) => { | ||
202 | let take_anchor = if let InsertPosition::After(_) = position { 1 } else { 0 }; | ||
203 | let split_at = position_of_child(parent, anchor.clone()) + take_anchor; | ||
204 | let before = old_children.by_ref().take(split_at).collect::<Vec<_>>(); | ||
205 | before.into_iter().chain(to_insert).chain(old_children).collect::<Vec<_>>() | ||
206 | } | ||
207 | }; | ||
208 | |||
209 | with_children(parent, new_children) | ||
210 | } | ||
211 | |||
212 | /// Replaces all nodes in `to_delete` with nodes from `to_insert` | ||
213 | /// | ||
214 | /// This is a type-unsafe low-level editing API, if you need to use it, | ||
215 | /// prefer to create a type-safe abstraction on top of it instead. | ||
216 | pub fn replace_children( | ||
217 | parent: &SyntaxNode, | ||
218 | to_delete: RangeInclusive<SyntaxElement>, | ||
219 | to_insert: impl IntoIterator<Item = SyntaxElement>, | ||
220 | ) -> SyntaxNode { | ||
221 | let mut to_insert = to_insert.into_iter(); | ||
222 | _replace_children(parent, to_delete, &mut to_insert) | ||
223 | } | ||
224 | |||
225 | fn _replace_children( | ||
226 | parent: &SyntaxNode, | ||
227 | to_delete: RangeInclusive<SyntaxElement>, | ||
228 | to_insert: &mut dyn Iterator<Item = SyntaxElement>, | ||
229 | ) -> SyntaxNode { | ||
230 | let start = position_of_child(parent, to_delete.start().clone()); | ||
231 | let end = position_of_child(parent, to_delete.end().clone()); | ||
232 | let mut old_children = parent.green().children().map(|it| match it { | ||
233 | NodeOrToken::Token(it) => NodeOrToken::Token(it.clone()), | ||
234 | NodeOrToken::Node(it) => NodeOrToken::Node(it.clone()), | ||
235 | }); | ||
236 | |||
237 | let before = old_children.by_ref().take(start).collect::<Vec<_>>(); | ||
238 | let new_children = before | ||
239 | .into_iter() | ||
240 | .chain(to_insert.map(to_green_element)) | ||
241 | .chain(old_children.skip(end + 1 - start)) | ||
242 | .collect::<Vec<_>>(); | ||
243 | with_children(parent, new_children) | ||
244 | } | ||
245 | |||
246 | #[derive(Default)] | ||
247 | pub struct SyntaxRewriter<'a> { | ||
248 | f: Option<Box<dyn Fn(&SyntaxElement) -> Option<SyntaxElement> + 'a>>, | ||
249 | //FIXME: add debug_assertions that all elements are in fact from the same file. | ||
250 | replacements: FxHashMap<SyntaxElement, Replacement>, | ||
251 | } | ||
252 | |||
253 | impl fmt::Debug for SyntaxRewriter<'_> { | ||
254 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
255 | f.debug_struct("SyntaxRewriter").field("replacements", &self.replacements).finish() | ||
256 | } | ||
257 | } | ||
258 | |||
259 | impl<'a> SyntaxRewriter<'a> { | ||
260 | pub fn from_fn(f: impl Fn(&SyntaxElement) -> Option<SyntaxElement> + 'a) -> SyntaxRewriter<'a> { | ||
261 | SyntaxRewriter { f: Some(Box::new(f)), replacements: FxHashMap::default() } | ||
262 | } | ||
263 | pub fn delete<T: Clone + Into<SyntaxElement>>(&mut self, what: &T) { | ||
264 | let what = what.clone().into(); | ||
265 | let replacement = Replacement::Delete; | ||
266 | self.replacements.insert(what, replacement); | ||
267 | } | ||
268 | pub fn replace<T: Clone + Into<SyntaxElement>>(&mut self, what: &T, with: &T) { | ||
269 | let what = what.clone().into(); | ||
270 | let replacement = Replacement::Single(with.clone().into()); | ||
271 | self.replacements.insert(what, replacement); | ||
272 | } | ||
273 | pub fn replace_with_many<T: Clone + Into<SyntaxElement>>( | ||
274 | &mut self, | ||
275 | what: &T, | ||
276 | with: Vec<SyntaxElement>, | ||
277 | ) { | ||
278 | let what = what.clone().into(); | ||
279 | let replacement = Replacement::Many(with); | ||
280 | self.replacements.insert(what, replacement); | ||
281 | } | ||
282 | pub fn replace_ast<T: AstNode>(&mut self, what: &T, with: &T) { | ||
283 | self.replace(what.syntax(), with.syntax()) | ||
284 | } | ||
285 | |||
286 | pub fn rewrite(&self, node: &SyntaxNode) -> SyntaxNode { | ||
287 | if self.f.is_none() && self.replacements.is_empty() { | ||
288 | return node.clone(); | ||
289 | } | ||
290 | self.rewrite_children(node) | ||
291 | } | ||
292 | |||
293 | pub fn rewrite_ast<N: AstNode>(self, node: &N) -> N { | ||
294 | N::cast(self.rewrite(node.syntax())).unwrap() | ||
295 | } | ||
296 | |||
297 | /// Returns a node that encompasses all replacements to be done by this rewriter. | ||
298 | /// | ||
299 | /// Passing the returned node to `rewrite` will apply all replacements queued up in `self`. | ||
300 | /// | ||
301 | /// Returns `None` when there are no replacements. | ||
302 | pub fn rewrite_root(&self) -> Option<SyntaxNode> { | ||
303 | assert!(self.f.is_none()); | ||
304 | self.replacements | ||
305 | .keys() | ||
306 | .map(|element| match element { | ||
307 | SyntaxElement::Node(it) => it.clone(), | ||
308 | SyntaxElement::Token(it) => it.parent(), | ||
309 | }) | ||
310 | // If we only have one replacement, we must return its parent node, since `rewrite` does | ||
311 | // not replace the node passed to it. | ||
312 | .map(|it| it.parent().unwrap_or(it)) | ||
313 | .fold1(|a, b| least_common_ancestor(&a, &b).unwrap()) | ||
314 | } | ||
315 | |||
316 | fn replacement(&self, element: &SyntaxElement) -> Option<Replacement> { | ||
317 | if let Some(f) = &self.f { | ||
318 | assert!(self.replacements.is_empty()); | ||
319 | return f(element).map(Replacement::Single); | ||
320 | } | ||
321 | self.replacements.get(element).cloned() | ||
322 | } | ||
323 | |||
324 | fn rewrite_children(&self, node: &SyntaxNode) -> SyntaxNode { | ||
325 | // FIXME: this could be made much faster. | ||
326 | let mut new_children = Vec::new(); | ||
327 | for child in node.children_with_tokens() { | ||
328 | self.rewrite_self(&mut new_children, &child); | ||
329 | } | ||
330 | with_children(node, new_children) | ||
331 | } | ||
332 | |||
333 | fn rewrite_self( | ||
334 | &self, | ||
335 | acc: &mut Vec<NodeOrToken<rowan::GreenNode, rowan::GreenToken>>, | ||
336 | element: &SyntaxElement, | ||
337 | ) { | ||
338 | if let Some(replacement) = self.replacement(&element) { | ||
339 | match replacement { | ||
340 | Replacement::Single(NodeOrToken::Node(it)) => { | ||
341 | acc.push(NodeOrToken::Node(it.green().clone())) | ||
342 | } | ||
343 | Replacement::Single(NodeOrToken::Token(it)) => { | ||
344 | acc.push(NodeOrToken::Token(it.green().clone())) | ||
345 | } | ||
346 | Replacement::Many(replacements) => { | ||
347 | acc.extend(replacements.iter().map(|it| match it { | ||
348 | NodeOrToken::Node(it) => NodeOrToken::Node(it.green().clone()), | ||
349 | NodeOrToken::Token(it) => NodeOrToken::Token(it.green().clone()), | ||
350 | })) | ||
351 | } | ||
352 | Replacement::Delete => (), | ||
353 | }; | ||
354 | return; | ||
355 | } | ||
356 | let res = match element { | ||
357 | NodeOrToken::Token(it) => NodeOrToken::Token(it.green().clone()), | ||
358 | NodeOrToken::Node(it) => NodeOrToken::Node(self.rewrite_children(it).green().clone()), | ||
359 | }; | ||
360 | acc.push(res) | ||
361 | } | ||
362 | } | ||
363 | |||
364 | impl ops::AddAssign for SyntaxRewriter<'_> { | ||
365 | fn add_assign(&mut self, rhs: SyntaxRewriter) { | ||
366 | assert!(rhs.f.is_none()); | ||
367 | self.replacements.extend(rhs.replacements) | ||
368 | } | ||
369 | } | ||
370 | |||
371 | #[derive(Clone, Debug)] | ||
372 | enum Replacement { | ||
373 | Delete, | ||
374 | Single(SyntaxElement), | ||
375 | Many(Vec<SyntaxElement>), | ||
376 | } | ||
377 | |||
378 | fn with_children( | ||
379 | parent: &SyntaxNode, | ||
380 | new_children: Vec<NodeOrToken<rowan::GreenNode, rowan::GreenToken>>, | ||
381 | ) -> SyntaxNode { | ||
382 | let len = new_children.iter().map(|it| it.text_len()).sum::<TextSize>(); | ||
383 | let new_node = rowan::GreenNode::new(rowan::SyntaxKind(parent.kind() as u16), new_children); | ||
384 | let new_root_node = parent.replace_with(new_node); | ||
385 | let new_root_node = SyntaxNode::new_root(new_root_node); | ||
386 | |||
387 | // FIXME: use a more elegant way to re-fetch the node (#1185), make | ||
388 | // `range` private afterwards | ||
389 | let mut ptr = SyntaxNodePtr::new(parent); | ||
390 | ptr.range = TextRange::at(ptr.range.start(), len); | ||
391 | ptr.to_node(&new_root_node) | ||
392 | } | ||
393 | |||
394 | fn position_of_child(parent: &SyntaxNode, child: SyntaxElement) -> usize { | ||
395 | parent | ||
396 | .children_with_tokens() | ||
397 | .position(|it| it == child) | ||
398 | .expect("element is not a child of current element") | ||
399 | } | ||
400 | |||
401 | fn to_green_element(element: SyntaxElement) -> NodeOrToken<rowan::GreenNode, rowan::GreenToken> { | ||
402 | match element { | ||
403 | NodeOrToken::Node(it) => it.green().clone().into(), | ||
404 | NodeOrToken::Token(it) => it.green().clone().into(), | ||
405 | } | ||
406 | } | ||