aboutsummaryrefslogtreecommitdiff
path: root/crates/syntax/src/algo.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/syntax/src/algo.rs')
-rw-r--r--crates/syntax/src/algo.rs406
1 files changed, 406 insertions, 0 deletions
diff --git a/crates/syntax/src/algo.rs b/crates/syntax/src/algo.rs
new file mode 100644
index 000000000..6254b38ba
--- /dev/null
+++ b/crates/syntax/src/algo.rs
@@ -0,0 +1,406 @@
1//! FIXME: write short doc here
2
3use std::{
4 fmt,
5 ops::{self, RangeInclusive},
6};
7
8use itertools::Itertools;
9use rustc_hash::FxHashMap;
10use text_edit::TextEditBuilder;
11
12use crate::{
13 AstNode, Direction, NodeOrToken, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxNodePtr,
14 SyntaxToken, TextRange, TextSize,
15};
16
17/// Returns ancestors of the node at the offset, sorted by length. This should
18/// do the right thing at an edge, e.g. when searching for expressions at `{
19/// <|>foo }` we will get the name reference instead of the whole block, which
20/// we would get if we just did `find_token_at_offset(...).flat_map(|t|
21/// t.parent().ancestors())`.
22pub fn ancestors_at_offset(
23 node: &SyntaxNode,
24 offset: TextSize,
25) -> impl Iterator<Item = SyntaxNode> {
26 node.token_at_offset(offset)
27 .map(|token| token.parent().ancestors())
28 .kmerge_by(|node1, node2| node1.text_range().len() < node2.text_range().len())
29}
30
31/// Finds a node of specific Ast type at offset. Note that this is slightly
32/// imprecise: if the cursor is strictly between two nodes of the desired type,
33/// as in
34///
35/// ```no-run
36/// struct Foo {}|struct Bar;
37/// ```
38///
39/// then the shorter node will be silently preferred.
40pub fn find_node_at_offset<N: AstNode>(syntax: &SyntaxNode, offset: TextSize) -> Option<N> {
41 ancestors_at_offset(syntax, offset).find_map(N::cast)
42}
43
44pub fn find_node_at_range<N: AstNode>(syntax: &SyntaxNode, range: TextRange) -> Option<N> {
45 find_covering_element(syntax, range).ancestors().find_map(N::cast)
46}
47
48/// Skip to next non `trivia` token
49pub fn skip_trivia_token(mut token: SyntaxToken, direction: Direction) -> Option<SyntaxToken> {
50 while token.kind().is_trivia() {
51 token = match direction {
52 Direction::Next => token.next_token()?,
53 Direction::Prev => token.prev_token()?,
54 }
55 }
56 Some(token)
57}
58
59/// Finds the first sibling in the given direction which is not `trivia`
60pub fn non_trivia_sibling(element: SyntaxElement, direction: Direction) -> Option<SyntaxElement> {
61 return match element {
62 NodeOrToken::Node(node) => node.siblings_with_tokens(direction).skip(1).find(not_trivia),
63 NodeOrToken::Token(token) => token.siblings_with_tokens(direction).skip(1).find(not_trivia),
64 };
65
66 fn not_trivia(element: &SyntaxElement) -> bool {
67 match element {
68 NodeOrToken::Node(_) => true,
69 NodeOrToken::Token(token) => !token.kind().is_trivia(),
70 }
71 }
72}
73
74pub fn find_covering_element(root: &SyntaxNode, range: TextRange) -> SyntaxElement {
75 root.covering_element(range)
76}
77
78pub fn least_common_ancestor(u: &SyntaxNode, v: &SyntaxNode) -> Option<SyntaxNode> {
79 if u == v {
80 return Some(u.clone());
81 }
82
83 let u_depth = u.ancestors().count();
84 let v_depth = v.ancestors().count();
85 let keep = u_depth.min(v_depth);
86
87 let u_candidates = u.ancestors().skip(u_depth - keep);
88 let v_canidates = v.ancestors().skip(v_depth - keep);
89 let (res, _) = u_candidates.zip(v_canidates).find(|(x, y)| x == y)?;
90 Some(res)
91}
92
93pub fn neighbor<T: AstNode>(me: &T, direction: Direction) -> Option<T> {
94 me.syntax().siblings(direction).skip(1).find_map(T::cast)
95}
96
97pub fn has_errors(node: &SyntaxNode) -> bool {
98 node.children().any(|it| it.kind() == SyntaxKind::ERROR)
99}
100
101#[derive(Debug, PartialEq, Eq, Clone, Copy)]
102pub enum InsertPosition<T> {
103 First,
104 Last,
105 Before(T),
106 After(T),
107}
108
109pub struct TreeDiff {
110 replacements: FxHashMap<SyntaxElement, SyntaxElement>,
111}
112
113impl TreeDiff {
114 pub fn into_text_edit(&self, builder: &mut TextEditBuilder) {
115 for (from, to) in self.replacements.iter() {
116 builder.replace(from.text_range(), to.to_string())
117 }
118 }
119
120 pub fn is_empty(&self) -> bool {
121 self.replacements.is_empty()
122 }
123}
124
125/// Finds minimal the diff, which, applied to `from`, will result in `to`.
126///
127/// Specifically, returns a map whose keys are descendants of `from` and values
128/// are descendants of `to`, such that `replace_descendants(from, map) == to`.
129///
130/// A trivial solution is a singleton map `{ from: to }`, but this function
131/// tries to find a more fine-grained diff.
132pub fn diff(from: &SyntaxNode, to: &SyntaxNode) -> TreeDiff {
133 let mut buf = FxHashMap::default();
134 // FIXME: this is both horrible inefficient and gives larger than
135 // necessary diff. I bet there's a cool algorithm to diff trees properly.
136 go(&mut buf, from.clone().into(), to.clone().into());
137 return TreeDiff { replacements: buf };
138
139 fn go(
140 buf: &mut FxHashMap<SyntaxElement, SyntaxElement>,
141 lhs: SyntaxElement,
142 rhs: SyntaxElement,
143 ) {
144 if lhs.kind() == rhs.kind()
145 && lhs.text_range().len() == rhs.text_range().len()
146 && match (&lhs, &rhs) {
147 (NodeOrToken::Node(lhs), NodeOrToken::Node(rhs)) => {
148 lhs.green() == rhs.green() || lhs.text() == rhs.text()
149 }
150 (NodeOrToken::Token(lhs), NodeOrToken::Token(rhs)) => lhs.text() == rhs.text(),
151 _ => false,
152 }
153 {
154 return;
155 }
156 if let (Some(lhs), Some(rhs)) = (lhs.as_node(), rhs.as_node()) {
157 if lhs.children_with_tokens().count() == rhs.children_with_tokens().count() {
158 for (lhs, rhs) in lhs.children_with_tokens().zip(rhs.children_with_tokens()) {
159 go(buf, lhs, rhs)
160 }
161 return;
162 }
163 }
164 buf.insert(lhs, rhs);
165 }
166}
167
168/// Adds specified children (tokens or nodes) to the current node at the
169/// specific position.
170///
171/// This is a type-unsafe low-level editing API, if you need to use it,
172/// prefer to create a type-safe abstraction on top of it instead.
173pub fn insert_children(
174 parent: &SyntaxNode,
175 position: InsertPosition<SyntaxElement>,
176 to_insert: impl IntoIterator<Item = SyntaxElement>,
177) -> SyntaxNode {
178 let mut to_insert = to_insert.into_iter();
179 _insert_children(parent, position, &mut to_insert)
180}
181
182fn _insert_children(
183 parent: &SyntaxNode,
184 position: InsertPosition<SyntaxElement>,
185 to_insert: &mut dyn Iterator<Item = SyntaxElement>,
186) -> SyntaxNode {
187 let mut delta = TextSize::default();
188 let to_insert = to_insert.map(|element| {
189 delta += element.text_range().len();
190 to_green_element(element)
191 });
192
193 let mut old_children = parent.green().children().map(|it| match it {
194 NodeOrToken::Token(it) => NodeOrToken::Token(it.clone()),
195 NodeOrToken::Node(it) => NodeOrToken::Node(it.clone()),
196 });
197
198 let new_children = match &position {
199 InsertPosition::First => to_insert.chain(old_children).collect::<Vec<_>>(),
200 InsertPosition::Last => old_children.chain(to_insert).collect::<Vec<_>>(),
201 InsertPosition::Before(anchor) | InsertPosition::After(anchor) => {
202 let take_anchor = if let InsertPosition::After(_) = position { 1 } else { 0 };
203 let split_at = position_of_child(parent, anchor.clone()) + take_anchor;
204 let before = old_children.by_ref().take(split_at).collect::<Vec<_>>();
205 before.into_iter().chain(to_insert).chain(old_children).collect::<Vec<_>>()
206 }
207 };
208
209 with_children(parent, new_children)
210}
211
212/// Replaces all nodes in `to_delete` with nodes from `to_insert`
213///
214/// This is a type-unsafe low-level editing API, if you need to use it,
215/// prefer to create a type-safe abstraction on top of it instead.
216pub fn replace_children(
217 parent: &SyntaxNode,
218 to_delete: RangeInclusive<SyntaxElement>,
219 to_insert: impl IntoIterator<Item = SyntaxElement>,
220) -> SyntaxNode {
221 let mut to_insert = to_insert.into_iter();
222 _replace_children(parent, to_delete, &mut to_insert)
223}
224
225fn _replace_children(
226 parent: &SyntaxNode,
227 to_delete: RangeInclusive<SyntaxElement>,
228 to_insert: &mut dyn Iterator<Item = SyntaxElement>,
229) -> SyntaxNode {
230 let start = position_of_child(parent, to_delete.start().clone());
231 let end = position_of_child(parent, to_delete.end().clone());
232 let mut old_children = parent.green().children().map(|it| match it {
233 NodeOrToken::Token(it) => NodeOrToken::Token(it.clone()),
234 NodeOrToken::Node(it) => NodeOrToken::Node(it.clone()),
235 });
236
237 let before = old_children.by_ref().take(start).collect::<Vec<_>>();
238 let new_children = before
239 .into_iter()
240 .chain(to_insert.map(to_green_element))
241 .chain(old_children.skip(end + 1 - start))
242 .collect::<Vec<_>>();
243 with_children(parent, new_children)
244}
245
246#[derive(Default)]
247pub struct SyntaxRewriter<'a> {
248 f: Option<Box<dyn Fn(&SyntaxElement) -> Option<SyntaxElement> + 'a>>,
249 //FIXME: add debug_assertions that all elements are in fact from the same file.
250 replacements: FxHashMap<SyntaxElement, Replacement>,
251}
252
253impl fmt::Debug for SyntaxRewriter<'_> {
254 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255 f.debug_struct("SyntaxRewriter").field("replacements", &self.replacements).finish()
256 }
257}
258
259impl<'a> SyntaxRewriter<'a> {
260 pub fn from_fn(f: impl Fn(&SyntaxElement) -> Option<SyntaxElement> + 'a) -> SyntaxRewriter<'a> {
261 SyntaxRewriter { f: Some(Box::new(f)), replacements: FxHashMap::default() }
262 }
263 pub fn delete<T: Clone + Into<SyntaxElement>>(&mut self, what: &T) {
264 let what = what.clone().into();
265 let replacement = Replacement::Delete;
266 self.replacements.insert(what, replacement);
267 }
268 pub fn replace<T: Clone + Into<SyntaxElement>>(&mut self, what: &T, with: &T) {
269 let what = what.clone().into();
270 let replacement = Replacement::Single(with.clone().into());
271 self.replacements.insert(what, replacement);
272 }
273 pub fn replace_with_many<T: Clone + Into<SyntaxElement>>(
274 &mut self,
275 what: &T,
276 with: Vec<SyntaxElement>,
277 ) {
278 let what = what.clone().into();
279 let replacement = Replacement::Many(with);
280 self.replacements.insert(what, replacement);
281 }
282 pub fn replace_ast<T: AstNode>(&mut self, what: &T, with: &T) {
283 self.replace(what.syntax(), with.syntax())
284 }
285
286 pub fn rewrite(&self, node: &SyntaxNode) -> SyntaxNode {
287 if self.f.is_none() && self.replacements.is_empty() {
288 return node.clone();
289 }
290 self.rewrite_children(node)
291 }
292
293 pub fn rewrite_ast<N: AstNode>(self, node: &N) -> N {
294 N::cast(self.rewrite(node.syntax())).unwrap()
295 }
296
297 /// Returns a node that encompasses all replacements to be done by this rewriter.
298 ///
299 /// Passing the returned node to `rewrite` will apply all replacements queued up in `self`.
300 ///
301 /// Returns `None` when there are no replacements.
302 pub fn rewrite_root(&self) -> Option<SyntaxNode> {
303 assert!(self.f.is_none());
304 self.replacements
305 .keys()
306 .map(|element| match element {
307 SyntaxElement::Node(it) => it.clone(),
308 SyntaxElement::Token(it) => it.parent(),
309 })
310 // If we only have one replacement, we must return its parent node, since `rewrite` does
311 // not replace the node passed to it.
312 .map(|it| it.parent().unwrap_or(it))
313 .fold1(|a, b| least_common_ancestor(&a, &b).unwrap())
314 }
315
316 fn replacement(&self, element: &SyntaxElement) -> Option<Replacement> {
317 if let Some(f) = &self.f {
318 assert!(self.replacements.is_empty());
319 return f(element).map(Replacement::Single);
320 }
321 self.replacements.get(element).cloned()
322 }
323
324 fn rewrite_children(&self, node: &SyntaxNode) -> SyntaxNode {
325 // FIXME: this could be made much faster.
326 let mut new_children = Vec::new();
327 for child in node.children_with_tokens() {
328 self.rewrite_self(&mut new_children, &child);
329 }
330 with_children(node, new_children)
331 }
332
333 fn rewrite_self(
334 &self,
335 acc: &mut Vec<NodeOrToken<rowan::GreenNode, rowan::GreenToken>>,
336 element: &SyntaxElement,
337 ) {
338 if let Some(replacement) = self.replacement(&element) {
339 match replacement {
340 Replacement::Single(NodeOrToken::Node(it)) => {
341 acc.push(NodeOrToken::Node(it.green().clone()))
342 }
343 Replacement::Single(NodeOrToken::Token(it)) => {
344 acc.push(NodeOrToken::Token(it.green().clone()))
345 }
346 Replacement::Many(replacements) => {
347 acc.extend(replacements.iter().map(|it| match it {
348 NodeOrToken::Node(it) => NodeOrToken::Node(it.green().clone()),
349 NodeOrToken::Token(it) => NodeOrToken::Token(it.green().clone()),
350 }))
351 }
352 Replacement::Delete => (),
353 };
354 return;
355 }
356 let res = match element {
357 NodeOrToken::Token(it) => NodeOrToken::Token(it.green().clone()),
358 NodeOrToken::Node(it) => NodeOrToken::Node(self.rewrite_children(it).green().clone()),
359 };
360 acc.push(res)
361 }
362}
363
364impl ops::AddAssign for SyntaxRewriter<'_> {
365 fn add_assign(&mut self, rhs: SyntaxRewriter) {
366 assert!(rhs.f.is_none());
367 self.replacements.extend(rhs.replacements)
368 }
369}
370
371#[derive(Clone, Debug)]
372enum Replacement {
373 Delete,
374 Single(SyntaxElement),
375 Many(Vec<SyntaxElement>),
376}
377
378fn with_children(
379 parent: &SyntaxNode,
380 new_children: Vec<NodeOrToken<rowan::GreenNode, rowan::GreenToken>>,
381) -> SyntaxNode {
382 let len = new_children.iter().map(|it| it.text_len()).sum::<TextSize>();
383 let new_node = rowan::GreenNode::new(rowan::SyntaxKind(parent.kind() as u16), new_children);
384 let new_root_node = parent.replace_with(new_node);
385 let new_root_node = SyntaxNode::new_root(new_root_node);
386
387 // FIXME: use a more elegant way to re-fetch the node (#1185), make
388 // `range` private afterwards
389 let mut ptr = SyntaxNodePtr::new(parent);
390 ptr.range = TextRange::at(ptr.range.start(), len);
391 ptr.to_node(&new_root_node)
392}
393
394fn position_of_child(parent: &SyntaxNode, child: SyntaxElement) -> usize {
395 parent
396 .children_with_tokens()
397 .position(|it| it == child)
398 .expect("element is not a child of current element")
399}
400
401fn to_green_element(element: SyntaxElement) -> NodeOrToken<rowan::GreenNode, rowan::GreenToken> {
402 match element {
403 NodeOrToken::Node(it) => it.green().clone().into(),
404 NodeOrToken::Token(it) => it.green().clone().into(),
405 }
406}