aboutsummaryrefslogtreecommitdiff
path: root/crates
diff options
context:
space:
mode:
authorbors[bot] <26634292+bors[bot]@users.noreply.github.com>2020-06-22 12:50:34 +0100
committerGitHub <[email protected]>2020-06-22 12:50:34 +0100
commitd144d69d2eded43a59c8edb59419b1b9e85c10a5 (patch)
tree0d52bdbb15723d25b7d3fff9ad25274c72e43434 /crates
parent19701b39ac232b023ff9ab077a33c743df96d178 (diff)
parent662ab2ecc8e29eb5995b3c162fac869838bea9a2 (diff)
Merge #4921
4921: Allow SSR to match type references, items, paths and patterns r=davidlattimore a=davidlattimore Part of #3186 Co-authored-by: David Lattimore <[email protected]>
Diffstat (limited to 'crates')
-rw-r--r--crates/ra_ide/Cargo.toml1
-rw-r--r--crates/ra_ide/src/lib.rs2
-rw-r--r--crates/ra_ide/src/ssr.rs563
-rw-r--r--crates/ra_ssr/Cargo.toml19
-rw-r--r--crates/ra_ssr/src/lib.rs120
-rw-r--r--crates/ra_ssr/src/matching.rs494
-rw-r--r--crates/ra_ssr/src/parsing.rs272
-rw-r--r--crates/ra_ssr/src/replacing.rs55
-rw-r--r--crates/ra_ssr/src/tests.rs496
9 files changed, 1467 insertions, 555 deletions
diff --git a/crates/ra_ide/Cargo.toml b/crates/ra_ide/Cargo.toml
index 05c940605..bbc6a5c9b 100644
--- a/crates/ra_ide/Cargo.toml
+++ b/crates/ra_ide/Cargo.toml
@@ -29,6 +29,7 @@ ra_fmt = { path = "../ra_fmt" }
29ra_prof = { path = "../ra_prof" } 29ra_prof = { path = "../ra_prof" }
30test_utils = { path = "../test_utils" } 30test_utils = { path = "../test_utils" }
31ra_assists = { path = "../ra_assists" } 31ra_assists = { path = "../ra_assists" }
32ra_ssr = { path = "../ra_ssr" }
32 33
33# ra_ide should depend only on the top-level `hir` package. if you need 34# ra_ide should depend only on the top-level `hir` package. if you need
34# something from some `hir_xxx` subpackage, reexport the API via `hir`. 35# something from some `hir_xxx` subpackage, reexport the API via `hir`.
diff --git a/crates/ra_ide/src/lib.rs b/crates/ra_ide/src/lib.rs
index be9ab62c0..47823718f 100644
--- a/crates/ra_ide/src/lib.rs
+++ b/crates/ra_ide/src/lib.rs
@@ -70,7 +70,6 @@ pub use crate::{
70 inlay_hints::{InlayHint, InlayHintsConfig, InlayKind}, 70 inlay_hints::{InlayHint, InlayHintsConfig, InlayKind},
71 references::{Declaration, Reference, ReferenceAccess, ReferenceKind, ReferenceSearchResult}, 71 references::{Declaration, Reference, ReferenceAccess, ReferenceKind, ReferenceSearchResult},
72 runnables::{Runnable, RunnableKind, TestId}, 72 runnables::{Runnable, RunnableKind, TestId},
73 ssr::SsrError,
74 syntax_highlighting::{ 73 syntax_highlighting::{
75 Highlight, HighlightModifier, HighlightModifiers, HighlightTag, HighlightedRange, 74 Highlight, HighlightModifier, HighlightModifiers, HighlightTag, HighlightedRange,
76 }, 75 },
@@ -89,6 +88,7 @@ pub use ra_ide_db::{
89 symbol_index::Query, 88 symbol_index::Query,
90 RootDatabase, 89 RootDatabase,
91}; 90};
91pub use ra_ssr::SsrError;
92pub use ra_text_edit::{Indel, TextEdit}; 92pub use ra_text_edit::{Indel, TextEdit};
93 93
94pub type Cancelable<T> = Result<T, Canceled>; 94pub type Cancelable<T> = Result<T, Canceled>;
diff --git a/crates/ra_ide/src/ssr.rs b/crates/ra_ide/src/ssr.rs
index 762aab962..59c230f6c 100644
--- a/crates/ra_ide/src/ssr.rs
+++ b/crates/ra_ide/src/ssr.rs
@@ -1,31 +1,12 @@
1use std::{collections::HashMap, iter::once, str::FromStr}; 1use ra_db::SourceDatabaseExt;
2
3use ra_db::{SourceDatabase, SourceDatabaseExt};
4use ra_ide_db::{symbol_index::SymbolsDatabase, RootDatabase}; 2use ra_ide_db::{symbol_index::SymbolsDatabase, RootDatabase};
5use ra_syntax::ast::{
6 make::try_expr_from_text, ArgList, AstToken, CallExpr, Comment, Expr, MethodCallExpr,
7 RecordField, RecordLit,
8};
9use ra_syntax::{AstNode, SyntaxElement, SyntaxKind, SyntaxNode};
10use ra_text_edit::{TextEdit, TextEditBuilder};
11use rustc_hash::FxHashMap;
12 3
13use crate::SourceFileEdit; 4use crate::SourceFileEdit;
14 5use ra_ssr::{MatchFinder, SsrError, SsrRule};
15#[derive(Debug, PartialEq)]
16pub struct SsrError(String);
17
18impl std::fmt::Display for SsrError {
19 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
20 write!(f, "Parse error: {}", self.0)
21 }
22}
23
24impl std::error::Error for SsrError {}
25 6
26// Feature: Structural Seach and Replace 7// Feature: Structural Seach and Replace
27// 8//
28// Search and replace with named wildcards that will match any expression. 9// Search and replace with named wildcards that will match any expression, type, path, pattern or item.
29// The syntax for a structural search replace command is `<search_pattern> ==>> <replace_pattern>`. 10// The syntax for a structural search replace command is `<search_pattern> ==>> <replace_pattern>`.
30// A `$<name>` placeholder in the search pattern will match any AST node and `$<name>` will reference it in the replacement. 11// A `$<name>` placeholder in the search pattern will match any AST node and `$<name>` will reference it in the replacement.
31// Available via the command `rust-analyzer.ssr`. 12// Available via the command `rust-analyzer.ssr`.
@@ -46,550 +27,24 @@ impl std::error::Error for SsrError {}
46// | VS Code | **Rust Analyzer: Structural Search Replace** 27// | VS Code | **Rust Analyzer: Structural Search Replace**
47// |=== 28// |===
48pub fn parse_search_replace( 29pub fn parse_search_replace(
49 query: &str, 30 rule: &str,
50 parse_only: bool, 31 parse_only: bool,
51 db: &RootDatabase, 32 db: &RootDatabase,
52) -> Result<Vec<SourceFileEdit>, SsrError> { 33) -> Result<Vec<SourceFileEdit>, SsrError> {
53 let mut edits = vec![]; 34 let mut edits = vec![];
54 let query: SsrQuery = query.parse()?; 35 let rule: SsrRule = rule.parse()?;
55 if parse_only { 36 if parse_only {
56 return Ok(edits); 37 return Ok(edits);
57 } 38 }
39 let mut match_finder = MatchFinder::new(db);
40 match_finder.add_rule(rule);
58 for &root in db.local_roots().iter() { 41 for &root in db.local_roots().iter() {
59 let sr = db.source_root(root); 42 let sr = db.source_root(root);
60 for file_id in sr.walk() { 43 for file_id in sr.walk() {
61 let matches = find(&query.pattern, db.parse(file_id).tree().syntax()); 44 if let Some(edit) = match_finder.edits_for_file(file_id) {
62 if !matches.matches.is_empty() { 45 edits.push(SourceFileEdit { file_id, edit });
63 edits.push(SourceFileEdit { file_id, edit: replace(&matches, &query.template) });
64 } 46 }
65 } 47 }
66 } 48 }
67 Ok(edits) 49 Ok(edits)
68} 50}
69
70#[derive(Debug)]
71struct SsrQuery {
72 pattern: SsrPattern,
73 template: SsrTemplate,
74}
75
76#[derive(Debug)]
77struct SsrPattern {
78 pattern: SyntaxNode,
79 vars: Vec<Var>,
80}
81
82/// Represents a `$var` in an SSR query.
83#[derive(Debug, Clone, PartialEq, Eq, Hash)]
84struct Var(String);
85
86#[derive(Debug)]
87struct SsrTemplate {
88 template: SyntaxNode,
89 placeholders: FxHashMap<SyntaxNode, Var>,
90}
91
92type Binding = HashMap<Var, SyntaxNode>;
93
94#[derive(Debug)]
95struct Match {
96 place: SyntaxNode,
97 binding: Binding,
98 ignored_comments: Vec<Comment>,
99}
100
101#[derive(Debug)]
102struct SsrMatches {
103 matches: Vec<Match>,
104}
105
106impl FromStr for SsrQuery {
107 type Err = SsrError;
108
109 fn from_str(query: &str) -> Result<SsrQuery, SsrError> {
110 let mut it = query.split("==>>");
111 let pattern = it.next().expect("at least empty string").trim();
112 let mut template = it
113 .next()
114 .ok_or_else(|| SsrError("Cannot find delemiter `==>>`".into()))?
115 .trim()
116 .to_string();
117 if it.next().is_some() {
118 return Err(SsrError("More than one delimiter found".into()));
119 }
120 let mut vars = vec![];
121 let mut it = pattern.split('$');
122 let mut pattern = it.next().expect("something").to_string();
123
124 for part in it.map(split_by_var) {
125 let (var, remainder) = part?;
126 let new_var = create_name(var, &mut vars)?;
127 pattern.push_str(new_var);
128 pattern.push_str(remainder);
129 template = replace_in_template(template, var, new_var);
130 }
131
132 let template = try_expr_from_text(&template)
133 .ok_or(SsrError("Template is not an expression".into()))?
134 .syntax()
135 .clone();
136 let mut placeholders = FxHashMap::default();
137
138 traverse(&template, &mut |n| {
139 if let Some(v) = vars.iter().find(|v| v.0.as_str() == n.text()) {
140 placeholders.insert(n.clone(), v.clone());
141 false
142 } else {
143 true
144 }
145 });
146
147 let pattern = SsrPattern {
148 pattern: try_expr_from_text(&pattern)
149 .ok_or(SsrError("Pattern is not an expression".into()))?
150 .syntax()
151 .clone(),
152 vars,
153 };
154 let template = SsrTemplate { template, placeholders };
155 Ok(SsrQuery { pattern, template })
156 }
157}
158
159fn traverse(node: &SyntaxNode, go: &mut impl FnMut(&SyntaxNode) -> bool) {
160 if !go(node) {
161 return;
162 }
163 for ref child in node.children() {
164 traverse(child, go);
165 }
166}
167
168fn split_by_var(s: &str) -> Result<(&str, &str), SsrError> {
169 let end_of_name = s.find(|c| !char::is_ascii_alphanumeric(&c)).unwrap_or_else(|| s.len());
170 let name = &s[..end_of_name];
171 is_name(name)?;
172 Ok((name, &s[end_of_name..]))
173}
174
175fn is_name(s: &str) -> Result<(), SsrError> {
176 if s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
177 Ok(())
178 } else {
179 Err(SsrError("Name can contain only alphanumerics and _".into()))
180 }
181}
182
183fn replace_in_template(template: String, var: &str, new_var: &str) -> String {
184 let name = format!("${}", var);
185 template.replace(&name, new_var)
186}
187
188fn create_name<'a>(name: &str, vars: &'a mut Vec<Var>) -> Result<&'a str, SsrError> {
189 let sanitized_name = format!("__search_pattern_{}", name);
190 if vars.iter().any(|a| a.0 == sanitized_name) {
191 return Err(SsrError(format!("Name `{}` repeats more than once", name)));
192 }
193 vars.push(Var(sanitized_name));
194 Ok(&vars.last().unwrap().0)
195}
196
197fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
198 fn check_record_lit(
199 pattern: RecordLit,
200 code: RecordLit,
201 placeholders: &[Var],
202 match_: Match,
203 ) -> Option<Match> {
204 let match_ = check_opt_nodes(pattern.path(), code.path(), placeholders, match_)?;
205
206 let mut pattern_fields: Vec<RecordField> =
207 pattern.record_field_list().map(|x| x.fields().collect()).unwrap_or_default();
208 let mut code_fields: Vec<RecordField> =
209 code.record_field_list().map(|x| x.fields().collect()).unwrap_or_default();
210
211 if pattern_fields.len() != code_fields.len() {
212 return None;
213 }
214
215 let by_name = |a: &RecordField, b: &RecordField| {
216 a.name_ref()
217 .map(|x| x.syntax().text().to_string())
218 .cmp(&b.name_ref().map(|x| x.syntax().text().to_string()))
219 };
220 pattern_fields.sort_by(by_name);
221 code_fields.sort_by(by_name);
222
223 pattern_fields.into_iter().zip(code_fields.into_iter()).fold(
224 Some(match_),
225 |accum, (a, b)| {
226 accum.and_then(|match_| check_opt_nodes(Some(a), Some(b), placeholders, match_))
227 },
228 )
229 }
230
231 fn check_call_and_method_call(
232 pattern: CallExpr,
233 code: MethodCallExpr,
234 placeholders: &[Var],
235 match_: Match,
236 ) -> Option<Match> {
237 let (pattern_name, pattern_type_args) = if let Some(Expr::PathExpr(path_exr)) =
238 pattern.expr()
239 {
240 let segment = path_exr.path().and_then(|p| p.segment());
241 (segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list()))
242 } else {
243 (None, None)
244 };
245 let match_ = check_opt_nodes(pattern_name, code.name_ref(), placeholders, match_)?;
246 let match_ =
247 check_opt_nodes(pattern_type_args, code.type_arg_list(), placeholders, match_)?;
248 let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args();
249 let code_args = code.syntax().children().find_map(ArgList::cast)?.args();
250 let code_args = once(code.expr()?).chain(code_args);
251 check_iter(pattern_args, code_args, placeholders, match_)
252 }
253
254 fn check_method_call_and_call(
255 pattern: MethodCallExpr,
256 code: CallExpr,
257 placeholders: &[Var],
258 match_: Match,
259 ) -> Option<Match> {
260 let (code_name, code_type_args) = if let Some(Expr::PathExpr(path_exr)) = code.expr() {
261 let segment = path_exr.path().and_then(|p| p.segment());
262 (segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list()))
263 } else {
264 (None, None)
265 };
266 let match_ = check_opt_nodes(pattern.name_ref(), code_name, placeholders, match_)?;
267 let match_ =
268 check_opt_nodes(pattern.type_arg_list(), code_type_args, placeholders, match_)?;
269 let code_args = code.syntax().children().find_map(ArgList::cast)?.args();
270 let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args();
271 let pattern_args = once(pattern.expr()?).chain(pattern_args);
272 check_iter(pattern_args, code_args, placeholders, match_)
273 }
274
275 fn check_opt_nodes(
276 pattern: Option<impl AstNode>,
277 code: Option<impl AstNode>,
278 placeholders: &[Var],
279 match_: Match,
280 ) -> Option<Match> {
281 match (pattern, code) {
282 (Some(pattern), Some(code)) => check(
283 &pattern.syntax().clone().into(),
284 &code.syntax().clone().into(),
285 placeholders,
286 match_,
287 ),
288 (None, None) => Some(match_),
289 _ => None,
290 }
291 }
292
293 fn check_iter<T, I1, I2>(
294 mut pattern: I1,
295 mut code: I2,
296 placeholders: &[Var],
297 match_: Match,
298 ) -> Option<Match>
299 where
300 T: AstNode,
301 I1: Iterator<Item = T>,
302 I2: Iterator<Item = T>,
303 {
304 pattern
305 .by_ref()
306 .zip(code.by_ref())
307 .fold(Some(match_), |accum, (a, b)| {
308 accum.and_then(|match_| {
309 check(
310 &a.syntax().clone().into(),
311 &b.syntax().clone().into(),
312 placeholders,
313 match_,
314 )
315 })
316 })
317 .filter(|_| pattern.next().is_none() && code.next().is_none())
318 }
319
320 fn check(
321 pattern: &SyntaxElement,
322 code: &SyntaxElement,
323 placeholders: &[Var],
324 mut match_: Match,
325 ) -> Option<Match> {
326 match (&pattern, &code) {
327 (SyntaxElement::Token(pattern), SyntaxElement::Token(code)) => {
328 if pattern.text() == code.text() {
329 Some(match_)
330 } else {
331 None
332 }
333 }
334 (SyntaxElement::Node(pattern), SyntaxElement::Node(code)) => {
335 if placeholders.iter().any(|n| n.0.as_str() == pattern.text()) {
336 match_.binding.insert(Var(pattern.text().to_string()), code.clone());
337 Some(match_)
338 } else {
339 if let (Some(pattern), Some(code)) =
340 (RecordLit::cast(pattern.clone()), RecordLit::cast(code.clone()))
341 {
342 check_record_lit(pattern, code, placeholders, match_)
343 } else if let (Some(pattern), Some(code)) =
344 (CallExpr::cast(pattern.clone()), MethodCallExpr::cast(code.clone()))
345 {
346 check_call_and_method_call(pattern, code, placeholders, match_)
347 } else if let (Some(pattern), Some(code)) =
348 (MethodCallExpr::cast(pattern.clone()), CallExpr::cast(code.clone()))
349 {
350 check_method_call_and_call(pattern, code, placeholders, match_)
351 } else {
352 let mut pattern_children = pattern
353 .children_with_tokens()
354 .filter(|element| !element.kind().is_trivia());
355 let mut code_children = code
356 .children_with_tokens()
357 .filter(|element| !element.kind().is_trivia());
358 let new_ignored_comments =
359 code.children_with_tokens().filter_map(|element| {
360 element.as_token().and_then(|token| Comment::cast(token.clone()))
361 });
362 match_.ignored_comments.extend(new_ignored_comments);
363 pattern_children
364 .by_ref()
365 .zip(code_children.by_ref())
366 .fold(Some(match_), |accum, (a, b)| {
367 accum.and_then(|match_| check(&a, &b, placeholders, match_))
368 })
369 .filter(|_| {
370 pattern_children.next().is_none() && code_children.next().is_none()
371 })
372 }
373 }
374 }
375 _ => None,
376 }
377 }
378 let kind = pattern.pattern.kind();
379 let matches = code
380 .descendants()
381 .filter(|n| {
382 n.kind() == kind
383 || (kind == SyntaxKind::CALL_EXPR && n.kind() == SyntaxKind::METHOD_CALL_EXPR)
384 || (kind == SyntaxKind::METHOD_CALL_EXPR && n.kind() == SyntaxKind::CALL_EXPR)
385 })
386 .filter_map(|code| {
387 let match_ =
388 Match { place: code.clone(), binding: HashMap::new(), ignored_comments: vec![] };
389 check(&pattern.pattern.clone().into(), &code.into(), &pattern.vars, match_)
390 })
391 .collect();
392 SsrMatches { matches }
393}
394
395fn replace(matches: &SsrMatches, template: &SsrTemplate) -> TextEdit {
396 let mut builder = TextEditBuilder::default();
397 for match_ in &matches.matches {
398 builder.replace(
399 match_.place.text_range(),
400 render_replace(&match_.binding, &match_.ignored_comments, template),
401 );
402 }
403 builder.finish()
404}
405
406fn render_replace(
407 binding: &Binding,
408 ignored_comments: &Vec<Comment>,
409 template: &SsrTemplate,
410) -> String {
411 let edit = {
412 let mut builder = TextEditBuilder::default();
413 for element in template.template.descendants() {
414 if let Some(var) = template.placeholders.get(&element) {
415 builder.replace(element.text_range(), binding[var].to_string())
416 }
417 }
418 for comment in ignored_comments {
419 builder.insert(template.template.text_range().end(), comment.syntax().to_string())
420 }
421 builder.finish()
422 };
423
424 let mut text = template.template.text().to_string();
425 edit.apply(&mut text);
426 text
427}
428
429#[cfg(test)]
430mod tests {
431 use super::*;
432 use ra_syntax::SourceFile;
433
434 fn parse_error_text(query: &str) -> String {
435 format!("{}", query.parse::<SsrQuery>().unwrap_err())
436 }
437
438 #[test]
439 fn parser_happy_case() {
440 let result: SsrQuery = "foo($a, $b) ==>> bar($b, $a)".parse().unwrap();
441 assert_eq!(&result.pattern.pattern.text(), "foo(__search_pattern_a, __search_pattern_b)");
442 assert_eq!(result.pattern.vars.len(), 2);
443 assert_eq!(result.pattern.vars[0].0, "__search_pattern_a");
444 assert_eq!(result.pattern.vars[1].0, "__search_pattern_b");
445 assert_eq!(&result.template.template.text(), "bar(__search_pattern_b, __search_pattern_a)");
446 }
447
448 #[test]
449 fn parser_empty_query() {
450 assert_eq!(parse_error_text(""), "Parse error: Cannot find delemiter `==>>`");
451 }
452
453 #[test]
454 fn parser_no_delimiter() {
455 assert_eq!(parse_error_text("foo()"), "Parse error: Cannot find delemiter `==>>`");
456 }
457
458 #[test]
459 fn parser_two_delimiters() {
460 assert_eq!(
461 parse_error_text("foo() ==>> a ==>> b "),
462 "Parse error: More than one delimiter found"
463 );
464 }
465
466 #[test]
467 fn parser_repeated_name() {
468 assert_eq!(
469 parse_error_text("foo($a, $a) ==>>"),
470 "Parse error: Name `a` repeats more than once"
471 );
472 }
473
474 #[test]
475 fn parser_invlid_pattern() {
476 assert_eq!(parse_error_text(" ==>> ()"), "Parse error: Pattern is not an expression");
477 }
478
479 #[test]
480 fn parser_invlid_template() {
481 assert_eq!(parse_error_text("() ==>> )"), "Parse error: Template is not an expression");
482 }
483
484 #[test]
485 fn parse_match_replace() {
486 let query: SsrQuery = "foo($x) ==>> bar($x)".parse().unwrap();
487 let input = "fn main() { foo(1+2); }";
488
489 let code = SourceFile::parse(input).tree();
490 let matches = find(&query.pattern, code.syntax());
491 assert_eq!(matches.matches.len(), 1);
492 assert_eq!(matches.matches[0].place.text(), "foo(1+2)");
493 assert_eq!(matches.matches[0].binding.len(), 1);
494 assert_eq!(
495 matches.matches[0].binding[&Var("__search_pattern_x".to_string())].text(),
496 "1+2"
497 );
498
499 let edit = replace(&matches, &query.template);
500 let mut after = input.to_string();
501 edit.apply(&mut after);
502 assert_eq!(after, "fn main() { bar(1+2); }");
503 }
504
505 fn assert_ssr_transform(query: &str, input: &str, result: &str) {
506 let query: SsrQuery = query.parse().unwrap();
507 let code = SourceFile::parse(input).tree();
508 let matches = find(&query.pattern, code.syntax());
509 let edit = replace(&matches, &query.template);
510 let mut after = input.to_string();
511 edit.apply(&mut after);
512 assert_eq!(after, result);
513 }
514
515 #[test]
516 fn ssr_function_to_method() {
517 assert_ssr_transform(
518 "my_function($a, $b) ==>> ($a).my_method($b)",
519 "loop { my_function( other_func(x, y), z + w) }",
520 "loop { (other_func(x, y)).my_method(z + w) }",
521 )
522 }
523
524 #[test]
525 fn ssr_nested_function() {
526 assert_ssr_transform(
527 "foo($a, $b, $c) ==>> bar($c, baz($a, $b))",
528 "fn main { foo (x + value.method(b), x+y-z, true && false) }",
529 "fn main { bar(true && false, baz(x + value.method(b), x+y-z)) }",
530 )
531 }
532
533 #[test]
534 fn ssr_expected_spacing() {
535 assert_ssr_transform(
536 "foo($x) + bar() ==>> bar($x)",
537 "fn main() { foo(5) + bar() }",
538 "fn main() { bar(5) }",
539 );
540 }
541
542 #[test]
543 fn ssr_with_extra_space() {
544 assert_ssr_transform(
545 "foo($x ) + bar() ==>> bar($x)",
546 "fn main() { foo( 5 ) +bar( ) }",
547 "fn main() { bar(5) }",
548 );
549 }
550
551 #[test]
552 fn ssr_keeps_nested_comment() {
553 assert_ssr_transform(
554 "foo($x) ==>> bar($x)",
555 "fn main() { foo(other(5 /* using 5 */)) }",
556 "fn main() { bar(other(5 /* using 5 */)) }",
557 )
558 }
559
560 #[test]
561 fn ssr_keeps_comment() {
562 assert_ssr_transform(
563 "foo($x) ==>> bar($x)",
564 "fn main() { foo(5 /* using 5 */) }",
565 "fn main() { bar(5)/* using 5 */ }",
566 )
567 }
568
569 #[test]
570 fn ssr_struct_lit() {
571 assert_ssr_transform(
572 "foo{a: $a, b: $b} ==>> foo::new($a, $b)",
573 "fn main() { foo{b:2, a:1} }",
574 "fn main() { foo::new(1, 2) }",
575 )
576 }
577
578 #[test]
579 fn ssr_call_and_method_call() {
580 assert_ssr_transform(
581 "foo::<'a>($a, $b)) ==>> foo2($a, $b)",
582 "fn main() { get().bar.foo::<'a>(1); }",
583 "fn main() { foo2(get().bar, 1); }",
584 )
585 }
586
587 #[test]
588 fn ssr_method_call_and_call() {
589 assert_ssr_transform(
590 "$o.foo::<i32>($a)) ==>> $o.foo2($a)",
591 "fn main() { X::foo::<i32>(x, 1); }",
592 "fn main() { x.foo2(1); }",
593 )
594 }
595}
diff --git a/crates/ra_ssr/Cargo.toml b/crates/ra_ssr/Cargo.toml
new file mode 100644
index 000000000..3c2f15a83
--- /dev/null
+++ b/crates/ra_ssr/Cargo.toml
@@ -0,0 +1,19 @@
1[package]
2edition = "2018"
3name = "ra_ssr"
4version = "0.1.0"
5authors = ["rust-analyzer developers"]
6license = "MIT OR Apache-2.0"
7description = "Structural search and replace of Rust code"
8repository = "https://github.com/rust-analyzer/rust-analyzer"
9
10[lib]
11doctest = false
12
13[dependencies]
14ra_text_edit = { path = "../ra_text_edit" }
15ra_syntax = { path = "../ra_syntax" }
16ra_db = { path = "../ra_db" }
17ra_ide_db = { path = "../ra_ide_db" }
18hir = { path = "../ra_hir", package = "ra_hir" }
19rustc-hash = "1.1.0"
diff --git a/crates/ra_ssr/src/lib.rs b/crates/ra_ssr/src/lib.rs
new file mode 100644
index 000000000..fc716ae82
--- /dev/null
+++ b/crates/ra_ssr/src/lib.rs
@@ -0,0 +1,120 @@
1//! Structural Search Replace
2//!
3//! Allows searching the AST for code that matches one or more patterns and then replacing that code
4//! based on a template.
5
6mod matching;
7mod parsing;
8mod replacing;
9#[cfg(test)]
10mod tests;
11
12use crate::matching::Match;
13use hir::Semantics;
14use ra_db::{FileId, FileRange};
15use ra_syntax::{AstNode, SmolStr, SyntaxNode};
16use ra_text_edit::TextEdit;
17use rustc_hash::FxHashMap;
18
19// A structured search replace rule. Create by calling `parse` on a str.
20#[derive(Debug)]
21pub struct SsrRule {
22 /// A structured pattern that we're searching for.
23 pattern: SsrPattern,
24 /// What we'll replace it with.
25 template: parsing::SsrTemplate,
26}
27
28#[derive(Debug)]
29struct SsrPattern {
30 raw: parsing::RawSearchPattern,
31 /// Placeholders keyed by the stand-in ident that we use in Rust source code.
32 placeholders_by_stand_in: FxHashMap<SmolStr, parsing::Placeholder>,
33 // We store our search pattern, parsed as each different kind of thing we can look for. As we
34 // traverse the AST, we get the appropriate one of these for the type of node we're on. For many
35 // search patterns, only some of these will be present.
36 expr: Option<SyntaxNode>,
37 type_ref: Option<SyntaxNode>,
38 item: Option<SyntaxNode>,
39 path: Option<SyntaxNode>,
40 pattern: Option<SyntaxNode>,
41}
42
43#[derive(Debug, PartialEq)]
44pub struct SsrError(String);
45
46#[derive(Debug, Default)]
47pub struct SsrMatches {
48 matches: Vec<Match>,
49}
50
51/// Searches a crate for pattern matches and possibly replaces them with something else.
52pub struct MatchFinder<'db> {
53 /// Our source of information about the user's code.
54 sema: Semantics<'db, ra_ide_db::RootDatabase>,
55 rules: Vec<SsrRule>,
56}
57
58impl<'db> MatchFinder<'db> {
59 pub fn new(db: &'db ra_ide_db::RootDatabase) -> MatchFinder<'db> {
60 MatchFinder { sema: Semantics::new(db), rules: Vec::new() }
61 }
62
63 pub fn add_rule(&mut self, rule: SsrRule) {
64 self.rules.push(rule);
65 }
66
67 pub fn edits_for_file(&self, file_id: FileId) -> Option<TextEdit> {
68 let matches = self.find_matches_in_file(file_id);
69 if matches.matches.is_empty() {
70 None
71 } else {
72 Some(replacing::matches_to_edit(&matches))
73 }
74 }
75
76 fn find_matches_in_file(&self, file_id: FileId) -> SsrMatches {
77 let file = self.sema.parse(file_id);
78 let code = file.syntax();
79 let mut matches = SsrMatches::default();
80 self.find_matches(code, &None, &mut matches);
81 matches
82 }
83
84 fn find_matches(
85 &self,
86 code: &SyntaxNode,
87 restrict_range: &Option<FileRange>,
88 matches_out: &mut SsrMatches,
89 ) {
90 for rule in &self.rules {
91 if let Ok(mut m) = matching::get_match(false, rule, &code, restrict_range, &self.sema) {
92 // Continue searching in each of our placeholders.
93 for placeholder_value in m.placeholder_values.values_mut() {
94 // Don't search our placeholder if it's the entire matched node, otherwise we'd
95 // find the same match over and over until we got a stack overflow.
96 if placeholder_value.node != *code {
97 self.find_matches(
98 &placeholder_value.node,
99 restrict_range,
100 &mut placeholder_value.inner_matches,
101 );
102 }
103 }
104 matches_out.matches.push(m);
105 return;
106 }
107 }
108 for child in code.children() {
109 self.find_matches(&child, restrict_range, matches_out);
110 }
111 }
112}
113
114impl std::fmt::Display for SsrError {
115 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
116 write!(f, "Parse error: {}", self.0)
117 }
118}
119
120impl std::error::Error for SsrError {}
diff --git a/crates/ra_ssr/src/matching.rs b/crates/ra_ssr/src/matching.rs
new file mode 100644
index 000000000..265b6d793
--- /dev/null
+++ b/crates/ra_ssr/src/matching.rs
@@ -0,0 +1,494 @@
1//! This module is responsible for matching a search pattern against a node in the AST. In the
2//! process of matching, placeholder values are recorded.
3
4use crate::{
5 parsing::{Placeholder, SsrTemplate},
6 SsrMatches, SsrPattern, SsrRule,
7};
8use hir::Semantics;
9use ra_db::FileRange;
10use ra_syntax::ast::{AstNode, AstToken};
11use ra_syntax::{
12 ast, SyntaxElement, SyntaxElementChildren, SyntaxKind, SyntaxNode, SyntaxToken, TextRange,
13};
14use rustc_hash::FxHashMap;
15use std::{cell::Cell, iter::Peekable};
16
17// Creates a match error. If we're currently attempting to match some code that we thought we were
18// going to match, as indicated by the --debug-snippet flag, then populate the reason field.
19macro_rules! match_error {
20 ($e:expr) => {{
21 MatchFailed {
22 reason: if recording_match_fail_reasons() {
23 Some(format!("{}", $e))
24 } else {
25 None
26 }
27 }
28 }};
29 ($fmt:expr, $($arg:tt)+) => {{
30 MatchFailed {
31 reason: if recording_match_fail_reasons() {
32 Some(format!($fmt, $($arg)+))
33 } else {
34 None
35 }
36 }
37 }};
38}
39
40// Fails the current match attempt, recording the supplied reason if we're recording match fail reasons.
41macro_rules! fail_match {
42 ($($args:tt)*) => {return Err(match_error!($($args)*))};
43}
44
45/// Information about a match that was found.
46#[derive(Debug)]
47pub(crate) struct Match {
48 pub(crate) range: TextRange,
49 pub(crate) matched_node: SyntaxNode,
50 pub(crate) placeholder_values: FxHashMap<Var, PlaceholderMatch>,
51 pub(crate) ignored_comments: Vec<ast::Comment>,
52 // A copy of the template for the rule that produced this match. We store this on the match for
53 // if/when we do replacement.
54 pub(crate) template: SsrTemplate,
55}
56
57/// Represents a `$var` in an SSR query.
58#[derive(Debug, Clone, PartialEq, Eq, Hash)]
59pub(crate) struct Var(pub String);
60
61/// Information about a placeholder bound in a match.
62#[derive(Debug)]
63pub(crate) struct PlaceholderMatch {
64 /// The node that the placeholder matched to.
65 pub(crate) node: SyntaxNode,
66 pub(crate) range: FileRange,
67 /// More matches, found within `node`.
68 pub(crate) inner_matches: SsrMatches,
69}
70
71#[derive(Debug)]
72pub(crate) struct MatchFailureReason {
73 pub(crate) reason: String,
74}
75
76/// An "error" indicating that matching failed. Use the fail_match! macro to create and return this.
77#[derive(Clone)]
78pub(crate) struct MatchFailed {
79 /// The reason why we failed to match. Only present when debug_active true in call to
80 /// `get_match`.
81 pub(crate) reason: Option<String>,
82}
83
84/// Checks if `code` matches the search pattern found in `search_scope`, returning information about
85/// the match, if it does. Since we only do matching in this module and searching is done by the
86/// parent module, we don't populate nested matches.
87pub(crate) fn get_match(
88 debug_active: bool,
89 rule: &SsrRule,
90 code: &SyntaxNode,
91 restrict_range: &Option<FileRange>,
92 sema: &Semantics<ra_ide_db::RootDatabase>,
93) -> Result<Match, MatchFailed> {
94 record_match_fails_reasons_scope(debug_active, || {
95 MatchState::try_match(rule, code, restrict_range, sema)
96 })
97}
98
99/// Inputs to matching. This cannot be part of `MatchState`, since we mutate `MatchState` and in at
100/// least one case need to hold a borrow of a placeholder from the input pattern while calling a
101/// mutable `MatchState` method.
102struct MatchInputs<'pattern> {
103 ssr_pattern: &'pattern SsrPattern,
104}
105
106/// State used while attempting to match our search pattern against a particular node of the AST.
107struct MatchState<'db, 'sema> {
108 sema: &'sema Semantics<'db, ra_ide_db::RootDatabase>,
109 /// If any placeholders come from anywhere outside of this range, then the match will be
110 /// rejected.
111 restrict_range: Option<FileRange>,
112 /// The match that we're building. We do two passes for a successful match. On the first pass,
113 /// this is None so that we can avoid doing things like storing copies of what placeholders
114 /// matched to. If that pass succeeds, then we do a second pass where we collect those details.
115 /// This means that if we have a pattern like `$a.foo()` we won't do an insert into the
116 /// placeholders map for every single method call in the codebase. Instead we'll discard all the
117 /// method calls that aren't calls to `foo` on the first pass and only insert into the
118 /// placeholders map on the second pass. Likewise for ignored comments.
119 match_out: Option<Match>,
120}
121
122impl<'db, 'sema> MatchState<'db, 'sema> {
123 fn try_match(
124 rule: &SsrRule,
125 code: &SyntaxNode,
126 restrict_range: &Option<FileRange>,
127 sema: &'sema Semantics<'db, ra_ide_db::RootDatabase>,
128 ) -> Result<Match, MatchFailed> {
129 let mut match_state =
130 MatchState { sema, restrict_range: restrict_range.clone(), match_out: None };
131 let match_inputs = MatchInputs { ssr_pattern: &rule.pattern };
132 let pattern_tree = rule.pattern.tree_for_kind(code.kind())?;
133 // First pass at matching, where we check that node types and idents match.
134 match_state.attempt_match_node(&match_inputs, &pattern_tree, code)?;
135 match_state.validate_range(&sema.original_range(code))?;
136 match_state.match_out = Some(Match {
137 range: sema.original_range(code).range,
138 matched_node: code.clone(),
139 placeholder_values: FxHashMap::default(),
140 ignored_comments: Vec::new(),
141 template: rule.template.clone(),
142 });
143 // Second matching pass, where we record placeholder matches, ignored comments and maybe do
144 // any other more expensive checks that we didn't want to do on the first pass.
145 match_state.attempt_match_node(&match_inputs, &pattern_tree, code)?;
146 Ok(match_state.match_out.unwrap())
147 }
148
149 /// Checks that `range` is within the permitted range if any. This is applicable when we're
150 /// processing a macro expansion and we want to fail the match if we're working with a node that
151 /// didn't originate from the token tree of the macro call.
152 fn validate_range(&self, range: &FileRange) -> Result<(), MatchFailed> {
153 if let Some(restrict_range) = &self.restrict_range {
154 if restrict_range.file_id != range.file_id
155 || !restrict_range.range.contains_range(range.range)
156 {
157 fail_match!("Node originated from a macro");
158 }
159 }
160 Ok(())
161 }
162
163 fn attempt_match_node(
164 &mut self,
165 match_inputs: &MatchInputs,
166 pattern: &SyntaxNode,
167 code: &SyntaxNode,
168 ) -> Result<(), MatchFailed> {
169 // Handle placeholders.
170 if let Some(placeholder) =
171 match_inputs.get_placeholder(&SyntaxElement::Node(pattern.clone()))
172 {
173 if self.match_out.is_none() {
174 return Ok(());
175 }
176 let original_range = self.sema.original_range(code);
177 // We validated the range for the node when we started the match, so the placeholder
178 // probably can't fail range validation, but just to be safe...
179 self.validate_range(&original_range)?;
180 if let Some(match_out) = &mut self.match_out {
181 match_out.placeholder_values.insert(
182 Var(placeholder.ident.to_string()),
183 PlaceholderMatch::new(code, original_range),
184 );
185 }
186 return Ok(());
187 }
188 // Non-placeholders.
189 if pattern.kind() != code.kind() {
190 fail_match!("Pattern had a {:?}, code had {:?}", pattern.kind(), code.kind());
191 }
192 // Some kinds of nodes have special handling. For everything else, we fall back to default
193 // matching.
194 match code.kind() {
195 SyntaxKind::RECORD_FIELD_LIST => {
196 self.attempt_match_record_field_list(match_inputs, pattern, code)
197 }
198 _ => self.attempt_match_node_children(match_inputs, pattern, code),
199 }
200 }
201
202 fn attempt_match_node_children(
203 &mut self,
204 match_inputs: &MatchInputs,
205 pattern: &SyntaxNode,
206 code: &SyntaxNode,
207 ) -> Result<(), MatchFailed> {
208 self.attempt_match_sequences(
209 match_inputs,
210 PatternIterator::new(pattern),
211 code.children_with_tokens(),
212 )
213 }
214
215 fn attempt_match_sequences(
216 &mut self,
217 match_inputs: &MatchInputs,
218 pattern_it: PatternIterator,
219 mut code_it: SyntaxElementChildren,
220 ) -> Result<(), MatchFailed> {
221 let mut pattern_it = pattern_it.peekable();
222 loop {
223 match self.next_non_trivial(&mut code_it) {
224 None => {
225 if let Some(p) = pattern_it.next() {
226 fail_match!("Part of the pattern was unmached: {:?}", p);
227 }
228 return Ok(());
229 }
230 Some(SyntaxElement::Token(c)) => {
231 self.attempt_match_token(&mut pattern_it, &c)?;
232 }
233 Some(SyntaxElement::Node(c)) => match pattern_it.next() {
234 Some(SyntaxElement::Node(p)) => {
235 self.attempt_match_node(match_inputs, &p, &c)?;
236 }
237 Some(p) => fail_match!("Pattern wanted '{}', code has {}", p, c.text()),
238 None => fail_match!("Pattern reached end, code has {}", c.text()),
239 },
240 }
241 }
242 }
243
244 fn attempt_match_token(
245 &mut self,
246 pattern: &mut Peekable<PatternIterator>,
247 code: &ra_syntax::SyntaxToken,
248 ) -> Result<(), MatchFailed> {
249 self.record_ignored_comments(code);
250 // Ignore whitespace and comments.
251 if code.kind().is_trivia() {
252 return Ok(());
253 }
254 if let Some(SyntaxElement::Token(p)) = pattern.peek() {
255 // If the code has a comma and the pattern is about to close something, then accept the
256 // comma without advancing the pattern. i.e. ignore trailing commas.
257 if code.kind() == SyntaxKind::COMMA && is_closing_token(p.kind()) {
258 return Ok(());
259 }
260 // Conversely, if the pattern has a comma and the code doesn't, skip that part of the
261 // pattern and continue to match the code.
262 if p.kind() == SyntaxKind::COMMA && is_closing_token(code.kind()) {
263 pattern.next();
264 }
265 }
266 // Consume an element from the pattern and make sure it matches.
267 match pattern.next() {
268 Some(SyntaxElement::Token(p)) => {
269 if p.kind() != code.kind() || p.text() != code.text() {
270 fail_match!(
271 "Pattern wanted token '{}' ({:?}), but code had token '{}' ({:?})",
272 p.text(),
273 p.kind(),
274 code.text(),
275 code.kind()
276 )
277 }
278 }
279 Some(SyntaxElement::Node(p)) => {
280 // Not sure if this is actually reachable.
281 fail_match!(
282 "Pattern wanted {:?}, but code had token '{}' ({:?})",
283 p,
284 code.text(),
285 code.kind()
286 );
287 }
288 None => {
289 fail_match!("Pattern exhausted, while code remains: `{}`", code.text());
290 }
291 }
292 Ok(())
293 }
294
295 /// We want to allow the records to match in any order, so we have special matching logic for
296 /// them.
297 fn attempt_match_record_field_list(
298 &mut self,
299 match_inputs: &MatchInputs,
300 pattern: &SyntaxNode,
301 code: &SyntaxNode,
302 ) -> Result<(), MatchFailed> {
303 // Build a map keyed by field name.
304 let mut fields_by_name = FxHashMap::default();
305 for child in code.children() {
306 if let Some(record) = ast::RecordField::cast(child.clone()) {
307 if let Some(name) = record.field_name() {
308 fields_by_name.insert(name.text().clone(), child.clone());
309 }
310 }
311 }
312 for p in pattern.children_with_tokens() {
313 if let SyntaxElement::Node(p) = p {
314 if let Some(name_element) = p.first_child_or_token() {
315 if match_inputs.get_placeholder(&name_element).is_some() {
316 // If the pattern is using placeholders for field names then order
317 // independence doesn't make sense. Fall back to regular ordered
318 // matching.
319 return self.attempt_match_node_children(match_inputs, pattern, code);
320 }
321 if let Some(ident) = only_ident(name_element) {
322 let code_record = fields_by_name.remove(ident.text()).ok_or_else(|| {
323 match_error!(
324 "Placeholder has record field '{}', but code doesn't",
325 ident
326 )
327 })?;
328 self.attempt_match_node(match_inputs, &p, &code_record)?;
329 }
330 }
331 }
332 }
333 if let Some(unmatched_fields) = fields_by_name.keys().next() {
334 fail_match!(
335 "{} field(s) of a record literal failed to match, starting with {}",
336 fields_by_name.len(),
337 unmatched_fields
338 );
339 }
340 Ok(())
341 }
342
343 fn next_non_trivial(&mut self, code_it: &mut SyntaxElementChildren) -> Option<SyntaxElement> {
344 loop {
345 let c = code_it.next();
346 if let Some(SyntaxElement::Token(t)) = &c {
347 self.record_ignored_comments(t);
348 if t.kind().is_trivia() {
349 continue;
350 }
351 }
352 return c;
353 }
354 }
355
356 fn record_ignored_comments(&mut self, token: &SyntaxToken) {
357 if token.kind() == SyntaxKind::COMMENT {
358 if let Some(match_out) = &mut self.match_out {
359 if let Some(comment) = ast::Comment::cast(token.clone()) {
360 match_out.ignored_comments.push(comment);
361 }
362 }
363 }
364 }
365}
366
367impl MatchInputs<'_> {
368 fn get_placeholder(&self, element: &SyntaxElement) -> Option<&Placeholder> {
369 only_ident(element.clone())
370 .and_then(|ident| self.ssr_pattern.placeholders_by_stand_in.get(ident.text()))
371 }
372}
373
374fn is_closing_token(kind: SyntaxKind) -> bool {
375 kind == SyntaxKind::R_PAREN || kind == SyntaxKind::R_CURLY || kind == SyntaxKind::R_BRACK
376}
377
378pub(crate) fn record_match_fails_reasons_scope<F, T>(debug_active: bool, f: F) -> T
379where
380 F: Fn() -> T,
381{
382 RECORDING_MATCH_FAIL_REASONS.with(|c| c.set(debug_active));
383 let res = f();
384 RECORDING_MATCH_FAIL_REASONS.with(|c| c.set(false));
385 res
386}
387
388// For performance reasons, we don't want to record the reason why every match fails, only the bit
389// of code that the user indicated they thought would match. We use a thread local to indicate when
390// we are trying to match that bit of code. This saves us having to pass a boolean into all the bits
391// of code that can make the decision to not match.
392thread_local! {
393 pub static RECORDING_MATCH_FAIL_REASONS: Cell<bool> = Cell::new(false);
394}
395
396fn recording_match_fail_reasons() -> bool {
397 RECORDING_MATCH_FAIL_REASONS.with(|c| c.get())
398}
399
400impl PlaceholderMatch {
401 fn new(node: &SyntaxNode, range: FileRange) -> Self {
402 Self { node: node.clone(), range, inner_matches: SsrMatches::default() }
403 }
404}
405
406impl SsrPattern {
407 pub(crate) fn tree_for_kind(&self, kind: SyntaxKind) -> Result<&SyntaxNode, MatchFailed> {
408 let (tree, kind_name) = if ast::Expr::can_cast(kind) {
409 (&self.expr, "expression")
410 } else if ast::TypeRef::can_cast(kind) {
411 (&self.type_ref, "type reference")
412 } else if ast::ModuleItem::can_cast(kind) {
413 (&self.item, "item")
414 } else if ast::Path::can_cast(kind) {
415 (&self.path, "path")
416 } else if ast::Pat::can_cast(kind) {
417 (&self.pattern, "pattern")
418 } else {
419 fail_match!("Matching nodes of kind {:?} is not supported", kind);
420 };
421 match tree {
422 Some(tree) => Ok(tree),
423 None => fail_match!("Pattern cannot be parsed as a {}", kind_name),
424 }
425 }
426}
427
428// If `node` contains nothing but an ident then return it, otherwise return None.
429fn only_ident(element: SyntaxElement) -> Option<SyntaxToken> {
430 match element {
431 SyntaxElement::Token(t) => {
432 if t.kind() == SyntaxKind::IDENT {
433 return Some(t);
434 }
435 }
436 SyntaxElement::Node(n) => {
437 let mut children = n.children_with_tokens();
438 if let (Some(only_child), None) = (children.next(), children.next()) {
439 return only_ident(only_child);
440 }
441 }
442 }
443 None
444}
445
446struct PatternIterator {
447 iter: SyntaxElementChildren,
448}
449
450impl Iterator for PatternIterator {
451 type Item = SyntaxElement;
452
453 fn next(&mut self) -> Option<SyntaxElement> {
454 while let Some(element) = self.iter.next() {
455 if !element.kind().is_trivia() {
456 return Some(element);
457 }
458 }
459 None
460 }
461}
462
463impl PatternIterator {
464 fn new(parent: &SyntaxNode) -> Self {
465 Self { iter: parent.children_with_tokens() }
466 }
467}
468
469#[cfg(test)]
470mod tests {
471 use super::*;
472 use crate::MatchFinder;
473
474 #[test]
475 fn parse_match_replace() {
476 let rule: SsrRule = "foo($x) ==>> bar($x)".parse().unwrap();
477 let input = "fn main() { foo(1+2); }";
478
479 use ra_db::fixture::WithFixture;
480 let (db, file_id) = ra_ide_db::RootDatabase::with_single_file(input);
481 let mut match_finder = MatchFinder::new(&db);
482 match_finder.add_rule(rule);
483 let matches = match_finder.find_matches_in_file(file_id);
484 assert_eq!(matches.matches.len(), 1);
485 assert_eq!(matches.matches[0].matched_node.text(), "foo(1+2)");
486 assert_eq!(matches.matches[0].placeholder_values.len(), 1);
487 assert_eq!(matches.matches[0].placeholder_values[&Var("x".to_string())].node.text(), "1+2");
488
489 let edit = crate::replacing::matches_to_edit(&matches);
490 let mut after = input.to_string();
491 edit.apply(&mut after);
492 assert_eq!(after, "fn main() { bar(1+2); }");
493 }
494}
diff --git a/crates/ra_ssr/src/parsing.rs b/crates/ra_ssr/src/parsing.rs
new file mode 100644
index 000000000..90c13dbc2
--- /dev/null
+++ b/crates/ra_ssr/src/parsing.rs
@@ -0,0 +1,272 @@
1//! This file contains code for parsing SSR rules, which look something like `foo($a) ==>> bar($b)`.
2//! We first split everything before and after the separator `==>>`. Next, both the search pattern
3//! and the replacement template get tokenized by the Rust tokenizer. Tokens are then searched for
4//! placeholders, which start with `$`. For replacement templates, this is the final form. For
5//! search patterns, we go further and parse the pattern as each kind of thing that we can match.
6//! e.g. expressions, type references etc.
7
8use crate::{SsrError, SsrPattern, SsrRule};
9use ra_syntax::{ast, AstNode, SmolStr, SyntaxKind};
10use rustc_hash::{FxHashMap, FxHashSet};
11use std::str::FromStr;
12
13/// Returns from the current function with an error, supplied by arguments as for format!
14macro_rules! bail {
15 ($e:expr) => {return Err($crate::SsrError::new($e))};
16 ($fmt:expr, $($arg:tt)+) => {return Err($crate::SsrError::new(format!($fmt, $($arg)+)))}
17}
18
19#[derive(Clone, Debug)]
20pub(crate) struct SsrTemplate {
21 pub(crate) tokens: Vec<PatternElement>,
22}
23
24#[derive(Debug)]
25pub(crate) struct RawSearchPattern {
26 tokens: Vec<PatternElement>,
27}
28
29// Part of a search or replace pattern.
30#[derive(Clone, Debug, PartialEq, Eq)]
31pub(crate) enum PatternElement {
32 Token(Token),
33 Placeholder(Placeholder),
34}
35
36#[derive(Clone, Debug, PartialEq, Eq)]
37pub(crate) struct Placeholder {
38 /// The name of this placeholder. e.g. for "$a", this would be "a"
39 pub(crate) ident: SmolStr,
40 /// A unique name used in place of this placeholder when we parse the pattern as Rust code.
41 stand_in_name: String,
42}
43
44#[derive(Debug, Clone, PartialEq, Eq)]
45pub(crate) struct Token {
46 kind: SyntaxKind,
47 pub(crate) text: SmolStr,
48}
49
50impl FromStr for SsrRule {
51 type Err = SsrError;
52
53 fn from_str(query: &str) -> Result<SsrRule, SsrError> {
54 let mut it = query.split("==>>");
55 let pattern = it.next().expect("at least empty string").trim();
56 let template = it
57 .next()
58 .ok_or_else(|| SsrError("Cannot find delemiter `==>>`".into()))?
59 .trim()
60 .to_string();
61 if it.next().is_some() {
62 return Err(SsrError("More than one delimiter found".into()));
63 }
64 let rule = SsrRule { pattern: pattern.parse()?, template: template.parse()? };
65 validate_rule(&rule)?;
66 Ok(rule)
67 }
68}
69
70impl FromStr for RawSearchPattern {
71 type Err = SsrError;
72
73 fn from_str(pattern_str: &str) -> Result<RawSearchPattern, SsrError> {
74 Ok(RawSearchPattern { tokens: parse_pattern(pattern_str)? })
75 }
76}
77
78impl RawSearchPattern {
79 /// Returns this search pattern as Rust source code that we can feed to the Rust parser.
80 fn as_rust_code(&self) -> String {
81 let mut res = String::new();
82 for t in &self.tokens {
83 res.push_str(match t {
84 PatternElement::Token(token) => token.text.as_str(),
85 PatternElement::Placeholder(placeholder) => placeholder.stand_in_name.as_str(),
86 });
87 }
88 res
89 }
90
91 fn placeholders_by_stand_in(&self) -> FxHashMap<SmolStr, Placeholder> {
92 let mut res = FxHashMap::default();
93 for t in &self.tokens {
94 if let PatternElement::Placeholder(placeholder) = t {
95 res.insert(SmolStr::new(placeholder.stand_in_name.clone()), placeholder.clone());
96 }
97 }
98 res
99 }
100}
101
102impl FromStr for SsrPattern {
103 type Err = SsrError;
104
105 fn from_str(pattern_str: &str) -> Result<SsrPattern, SsrError> {
106 let raw: RawSearchPattern = pattern_str.parse()?;
107 let raw_str = raw.as_rust_code();
108 let res = SsrPattern {
109 expr: ast::Expr::parse(&raw_str).ok().map(|n| n.syntax().clone()),
110 type_ref: ast::TypeRef::parse(&raw_str).ok().map(|n| n.syntax().clone()),
111 item: ast::ModuleItem::parse(&raw_str).ok().map(|n| n.syntax().clone()),
112 path: ast::Path::parse(&raw_str).ok().map(|n| n.syntax().clone()),
113 pattern: ast::Pat::parse(&raw_str).ok().map(|n| n.syntax().clone()),
114 placeholders_by_stand_in: raw.placeholders_by_stand_in(),
115 raw,
116 };
117 if res.expr.is_none()
118 && res.type_ref.is_none()
119 && res.item.is_none()
120 && res.path.is_none()
121 && res.pattern.is_none()
122 {
123 bail!("Pattern is not a valid Rust expression, type, item, path or pattern");
124 }
125 Ok(res)
126 }
127}
128
129impl FromStr for SsrTemplate {
130 type Err = SsrError;
131
132 fn from_str(pattern_str: &str) -> Result<SsrTemplate, SsrError> {
133 let tokens = parse_pattern(pattern_str)?;
134 // Validate that the template is a valid fragment of Rust code. We reuse the validation
135 // logic for search patterns since the only thing that differs is the error message.
136 if SsrPattern::from_str(pattern_str).is_err() {
137 bail!("Replacement is not a valid Rust expression, type, item, path or pattern");
138 }
139 // Our actual template needs to preserve whitespace, so we can't reuse `tokens`.
140 Ok(SsrTemplate { tokens })
141 }
142}
143
144/// Returns `pattern_str`, parsed as a search or replace pattern. If `remove_whitespace` is true,
145/// then any whitespace tokens will be removed, which we do for the search pattern, but not for the
146/// replace pattern.
147fn parse_pattern(pattern_str: &str) -> Result<Vec<PatternElement>, SsrError> {
148 let mut res = Vec::new();
149 let mut placeholder_names = FxHashSet::default();
150 let mut tokens = tokenize(pattern_str)?.into_iter();
151 while let Some(token) = tokens.next() {
152 if token.kind == SyntaxKind::DOLLAR {
153 let placeholder = parse_placeholder(&mut tokens)?;
154 if !placeholder_names.insert(placeholder.ident.clone()) {
155 bail!("Name `{}` repeats more than once", placeholder.ident);
156 }
157 res.push(PatternElement::Placeholder(placeholder));
158 } else {
159 res.push(PatternElement::Token(token));
160 }
161 }
162 Ok(res)
163}
164
165/// Checks for errors in a rule. e.g. the replace pattern referencing placeholders that the search
166/// pattern didn't define.
167fn validate_rule(rule: &SsrRule) -> Result<(), SsrError> {
168 let mut defined_placeholders = std::collections::HashSet::new();
169 for p in &rule.pattern.raw.tokens {
170 if let PatternElement::Placeholder(placeholder) = p {
171 defined_placeholders.insert(&placeholder.ident);
172 }
173 }
174 let mut undefined = Vec::new();
175 for p in &rule.template.tokens {
176 if let PatternElement::Placeholder(placeholder) = p {
177 if !defined_placeholders.contains(&placeholder.ident) {
178 undefined.push(format!("${}", placeholder.ident));
179 }
180 }
181 }
182 if !undefined.is_empty() {
183 bail!("Replacement contains undefined placeholders: {}", undefined.join(", "));
184 }
185 Ok(())
186}
187
188fn tokenize(source: &str) -> Result<Vec<Token>, SsrError> {
189 let mut start = 0;
190 let (raw_tokens, errors) = ra_syntax::tokenize(source);
191 if let Some(first_error) = errors.first() {
192 bail!("Failed to parse pattern: {}", first_error);
193 }
194 let mut tokens: Vec<Token> = Vec::new();
195 for raw_token in raw_tokens {
196 let token_len = usize::from(raw_token.len);
197 tokens.push(Token {
198 kind: raw_token.kind,
199 text: SmolStr::new(&source[start..start + token_len]),
200 });
201 start += token_len;
202 }
203 Ok(tokens)
204}
205
206fn parse_placeholder(tokens: &mut std::vec::IntoIter<Token>) -> Result<Placeholder, SsrError> {
207 let mut name = None;
208 if let Some(token) = tokens.next() {
209 match token.kind {
210 SyntaxKind::IDENT => {
211 name = Some(token.text);
212 }
213 _ => {
214 bail!("Placeholders should be $name");
215 }
216 }
217 }
218 let name = name.ok_or_else(|| SsrError::new("Placeholder ($) with no name"))?;
219 Ok(Placeholder::new(name))
220}
221
222impl Placeholder {
223 fn new(name: SmolStr) -> Self {
224 Self { stand_in_name: format!("__placeholder_{}", name), ident: name }
225 }
226}
227
228impl SsrError {
229 fn new(message: impl Into<String>) -> SsrError {
230 SsrError(message.into())
231 }
232}
233
234#[cfg(test)]
235mod tests {
236 use super::*;
237
238 #[test]
239 fn parser_happy_case() {
240 fn token(kind: SyntaxKind, text: &str) -> PatternElement {
241 PatternElement::Token(Token { kind, text: SmolStr::new(text) })
242 }
243 fn placeholder(name: &str) -> PatternElement {
244 PatternElement::Placeholder(Placeholder::new(SmolStr::new(name)))
245 }
246 let result: SsrRule = "foo($a, $b) ==>> bar($b, $a)".parse().unwrap();
247 assert_eq!(
248 result.pattern.raw.tokens,
249 vec![
250 token(SyntaxKind::IDENT, "foo"),
251 token(SyntaxKind::L_PAREN, "("),
252 placeholder("a"),
253 token(SyntaxKind::COMMA, ","),
254 token(SyntaxKind::WHITESPACE, " "),
255 placeholder("b"),
256 token(SyntaxKind::R_PAREN, ")"),
257 ]
258 );
259 assert_eq!(
260 result.template.tokens,
261 vec![
262 token(SyntaxKind::IDENT, "bar"),
263 token(SyntaxKind::L_PAREN, "("),
264 placeholder("b"),
265 token(SyntaxKind::COMMA, ","),
266 token(SyntaxKind::WHITESPACE, " "),
267 placeholder("a"),
268 token(SyntaxKind::R_PAREN, ")"),
269 ]
270 );
271 }
272}
diff --git a/crates/ra_ssr/src/replacing.rs b/crates/ra_ssr/src/replacing.rs
new file mode 100644
index 000000000..81a5e06a9
--- /dev/null
+++ b/crates/ra_ssr/src/replacing.rs
@@ -0,0 +1,55 @@
1//! Code for applying replacement templates for matches that have previously been found.
2
3use crate::matching::Var;
4use crate::parsing::PatternElement;
5use crate::{Match, SsrMatches};
6use ra_syntax::ast::AstToken;
7use ra_syntax::TextSize;
8use ra_text_edit::TextEdit;
9
10/// Returns a text edit that will replace each match in `matches` with its corresponding replacement
11/// template. Placeholders in the template will have been substituted with whatever they matched to
12/// in the original code.
13pub(crate) fn matches_to_edit(matches: &SsrMatches) -> TextEdit {
14 matches_to_edit_at_offset(matches, 0.into())
15}
16
17fn matches_to_edit_at_offset(matches: &SsrMatches, relative_start: TextSize) -> TextEdit {
18 let mut edit_builder = ra_text_edit::TextEditBuilder::default();
19 for m in &matches.matches {
20 edit_builder.replace(m.range.checked_sub(relative_start).unwrap(), render_replace(m));
21 }
22 edit_builder.finish()
23}
24
25fn render_replace(match_info: &Match) -> String {
26 let mut out = String::new();
27 for r in &match_info.template.tokens {
28 match r {
29 PatternElement::Token(t) => out.push_str(t.text.as_str()),
30 PatternElement::Placeholder(p) => {
31 if let Some(placeholder_value) =
32 match_info.placeholder_values.get(&Var(p.ident.to_string()))
33 {
34 let range = &placeholder_value.range.range;
35 let mut matched_text = placeholder_value.node.text().to_string();
36 let edit =
37 matches_to_edit_at_offset(&placeholder_value.inner_matches, range.start());
38 edit.apply(&mut matched_text);
39 out.push_str(&matched_text);
40 } else {
41 // We validated that all placeholder references were valid before we
42 // started, so this shouldn't happen.
43 panic!(
44 "Internal error: replacement referenced unknown placeholder {}",
45 p.ident
46 );
47 }
48 }
49 }
50 }
51 for comment in &match_info.ignored_comments {
52 out.push_str(&comment.syntax().to_string());
53 }
54 out
55}
diff --git a/crates/ra_ssr/src/tests.rs b/crates/ra_ssr/src/tests.rs
new file mode 100644
index 000000000..4b747fe18
--- /dev/null
+++ b/crates/ra_ssr/src/tests.rs
@@ -0,0 +1,496 @@
1use crate::matching::MatchFailureReason;
2use crate::{matching, Match, MatchFinder, SsrMatches, SsrPattern, SsrRule};
3use matching::record_match_fails_reasons_scope;
4use ra_db::{FileId, FileRange, SourceDatabaseExt};
5use ra_syntax::ast::AstNode;
6use ra_syntax::{ast, SyntaxKind, SyntaxNode, TextRange};
7
8struct MatchDebugInfo {
9 node: SyntaxNode,
10 /// Our search pattern parsed as the same kind of syntax node as `node`. e.g. expression, item,
11 /// etc. Will be absent if the pattern can't be parsed as that kind.
12 pattern: Result<SyntaxNode, MatchFailureReason>,
13 matched: Result<Match, MatchFailureReason>,
14}
15
16impl SsrPattern {
17 pub(crate) fn tree_for_kind_with_reason(
18 &self,
19 kind: SyntaxKind,
20 ) -> Result<&SyntaxNode, MatchFailureReason> {
21 record_match_fails_reasons_scope(true, || self.tree_for_kind(kind))
22 .map_err(|e| MatchFailureReason { reason: e.reason.unwrap() })
23 }
24}
25
26impl std::fmt::Debug for MatchDebugInfo {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 write!(f, "========= PATTERN ==========\n")?;
29 match &self.pattern {
30 Ok(pattern) => {
31 write!(f, "{:#?}", pattern)?;
32 }
33 Err(err) => {
34 write!(f, "{}", err.reason)?;
35 }
36 }
37 write!(
38 f,
39 "\n============ AST ===========\n\
40 {:#?}\n============================",
41 self.node
42 )?;
43 match &self.matched {
44 Ok(_) => write!(f, "Node matched")?,
45 Err(reason) => write!(f, "Node failed to match because: {}", reason.reason)?,
46 }
47 Ok(())
48 }
49}
50
51impl SsrMatches {
52 /// Returns `self` with any nested matches removed and made into top-level matches.
53 pub(crate) fn flattened(self) -> SsrMatches {
54 let mut out = SsrMatches::default();
55 self.flatten_into(&mut out);
56 out
57 }
58
59 fn flatten_into(self, out: &mut SsrMatches) {
60 for mut m in self.matches {
61 for p in m.placeholder_values.values_mut() {
62 std::mem::replace(&mut p.inner_matches, SsrMatches::default()).flatten_into(out);
63 }
64 out.matches.push(m);
65 }
66 }
67}
68
69impl Match {
70 pub(crate) fn matched_text(&self) -> String {
71 self.matched_node.text().to_string()
72 }
73}
74
75impl<'db> MatchFinder<'db> {
76 /// Adds a search pattern. For use if you intend to only call `find_matches_in_file`. If you
77 /// intend to do replacement, use `add_rule` instead.
78 fn add_search_pattern(&mut self, pattern: SsrPattern) {
79 self.add_rule(SsrRule { pattern, template: "()".parse().unwrap() })
80 }
81
82 /// Finds all nodes in `file_id` whose text is exactly equal to `snippet` and attempts to match
83 /// them, while recording reasons why they don't match. This API is useful for command
84 /// line-based debugging where providing a range is difficult.
85 fn debug_where_text_equal(&self, file_id: FileId, snippet: &str) -> Vec<MatchDebugInfo> {
86 let file = self.sema.parse(file_id);
87 let mut res = Vec::new();
88 let file_text = self.sema.db.file_text(file_id);
89 let mut remaining_text = file_text.as_str();
90 let mut base = 0;
91 let len = snippet.len() as u32;
92 while let Some(offset) = remaining_text.find(snippet) {
93 let start = base + offset as u32;
94 let end = start + len;
95 self.output_debug_for_nodes_at_range(
96 file.syntax(),
97 TextRange::new(start.into(), end.into()),
98 &None,
99 &mut res,
100 );
101 remaining_text = &remaining_text[offset + snippet.len()..];
102 base = end;
103 }
104 res
105 }
106
107 fn output_debug_for_nodes_at_range(
108 &self,
109 node: &SyntaxNode,
110 range: TextRange,
111 restrict_range: &Option<FileRange>,
112 out: &mut Vec<MatchDebugInfo>,
113 ) {
114 for node in node.children() {
115 if !node.text_range().contains_range(range) {
116 continue;
117 }
118 if node.text_range() == range {
119 for rule in &self.rules {
120 let pattern =
121 rule.pattern.tree_for_kind_with_reason(node.kind()).map(|p| p.clone());
122 out.push(MatchDebugInfo {
123 matched: matching::get_match(true, rule, &node, restrict_range, &self.sema)
124 .map_err(|e| MatchFailureReason {
125 reason: e.reason.unwrap_or_else(|| {
126 "Match failed, but no reason was given".to_owned()
127 }),
128 }),
129 pattern,
130 node: node.clone(),
131 });
132 }
133 } else if let Some(macro_call) = ast::MacroCall::cast(node.clone()) {
134 if let Some(expanded) = self.sema.expand(&macro_call) {
135 if let Some(tt) = macro_call.token_tree() {
136 self.output_debug_for_nodes_at_range(
137 &expanded,
138 range,
139 &Some(self.sema.original_range(tt.syntax())),
140 out,
141 );
142 }
143 }
144 }
145 }
146 }
147}
148
149fn parse_error_text(query: &str) -> String {
150 format!("{}", query.parse::<SsrRule>().unwrap_err())
151}
152
153#[test]
154fn parser_empty_query() {
155 assert_eq!(parse_error_text(""), "Parse error: Cannot find delemiter `==>>`");
156}
157
158#[test]
159fn parser_no_delimiter() {
160 assert_eq!(parse_error_text("foo()"), "Parse error: Cannot find delemiter `==>>`");
161}
162
163#[test]
164fn parser_two_delimiters() {
165 assert_eq!(
166 parse_error_text("foo() ==>> a ==>> b "),
167 "Parse error: More than one delimiter found"
168 );
169}
170
171#[test]
172fn parser_repeated_name() {
173 assert_eq!(
174 parse_error_text("foo($a, $a) ==>>"),
175 "Parse error: Name `a` repeats more than once"
176 );
177}
178
179#[test]
180fn parser_invalid_pattern() {
181 assert_eq!(
182 parse_error_text(" ==>> ()"),
183 "Parse error: Pattern is not a valid Rust expression, type, item, path or pattern"
184 );
185}
186
187#[test]
188fn parser_invalid_template() {
189 assert_eq!(
190 parse_error_text("() ==>> )"),
191 "Parse error: Replacement is not a valid Rust expression, type, item, path or pattern"
192 );
193}
194
195#[test]
196fn parser_undefined_placeholder_in_replacement() {
197 assert_eq!(
198 parse_error_text("42 ==>> $a"),
199 "Parse error: Replacement contains undefined placeholders: $a"
200 );
201}
202
203fn single_file(code: &str) -> (ra_ide_db::RootDatabase, FileId) {
204 use ra_db::fixture::WithFixture;
205 ra_ide_db::RootDatabase::with_single_file(code)
206}
207
208fn assert_ssr_transform(rule: &str, input: &str, result: &str) {
209 assert_ssr_transforms(&[rule], input, result);
210}
211
212fn assert_ssr_transforms(rules: &[&str], input: &str, result: &str) {
213 let (db, file_id) = single_file(input);
214 let mut match_finder = MatchFinder::new(&db);
215 for rule in rules {
216 let rule: SsrRule = rule.parse().unwrap();
217 match_finder.add_rule(rule);
218 }
219 if let Some(edits) = match_finder.edits_for_file(file_id) {
220 let mut after = input.to_string();
221 edits.apply(&mut after);
222 assert_eq!(after, result);
223 } else {
224 panic!("No edits were made");
225 }
226}
227
228fn assert_matches(pattern: &str, code: &str, expected: &[&str]) {
229 let (db, file_id) = single_file(code);
230 let mut match_finder = MatchFinder::new(&db);
231 match_finder.add_search_pattern(pattern.parse().unwrap());
232 let matched_strings: Vec<String> = match_finder
233 .find_matches_in_file(file_id)
234 .flattened()
235 .matches
236 .iter()
237 .map(|m| m.matched_text())
238 .collect();
239 if matched_strings != expected && !expected.is_empty() {
240 let debug_info = match_finder.debug_where_text_equal(file_id, &expected[0]);
241 eprintln!("Test is about to fail. Some possibly useful info: {} nodes had text exactly equal to '{}'", debug_info.len(), &expected[0]);
242 for d in debug_info {
243 eprintln!("{:#?}", d);
244 }
245 }
246 assert_eq!(matched_strings, expected);
247}
248
249fn assert_no_match(pattern: &str, code: &str) {
250 assert_matches(pattern, code, &[]);
251}
252
253#[test]
254fn ssr_function_to_method() {
255 assert_ssr_transform(
256 "my_function($a, $b) ==>> ($a).my_method($b)",
257 "loop { my_function( other_func(x, y), z + w) }",
258 "loop { (other_func(x, y)).my_method(z + w) }",
259 )
260}
261
262#[test]
263fn ssr_nested_function() {
264 assert_ssr_transform(
265 "foo($a, $b, $c) ==>> bar($c, baz($a, $b))",
266 "fn main { foo (x + value.method(b), x+y-z, true && false) }",
267 "fn main { bar(true && false, baz(x + value.method(b), x+y-z)) }",
268 )
269}
270
271#[test]
272fn ssr_expected_spacing() {
273 assert_ssr_transform(
274 "foo($x) + bar() ==>> bar($x)",
275 "fn main() { foo(5) + bar() }",
276 "fn main() { bar(5) }",
277 );
278}
279
280#[test]
281fn ssr_with_extra_space() {
282 assert_ssr_transform(
283 "foo($x ) + bar() ==>> bar($x)",
284 "fn main() { foo( 5 ) +bar( ) }",
285 "fn main() { bar(5) }",
286 );
287}
288
289#[test]
290fn ssr_keeps_nested_comment() {
291 assert_ssr_transform(
292 "foo($x) ==>> bar($x)",
293 "fn main() { foo(other(5 /* using 5 */)) }",
294 "fn main() { bar(other(5 /* using 5 */)) }",
295 )
296}
297
298#[test]
299fn ssr_keeps_comment() {
300 assert_ssr_transform(
301 "foo($x) ==>> bar($x)",
302 "fn main() { foo(5 /* using 5 */) }",
303 "fn main() { bar(5)/* using 5 */ }",
304 )
305}
306
307#[test]
308fn ssr_struct_lit() {
309 assert_ssr_transform(
310 "foo{a: $a, b: $b} ==>> foo::new($a, $b)",
311 "fn main() { foo{b:2, a:1} }",
312 "fn main() { foo::new(1, 2) }",
313 )
314}
315
316#[test]
317fn ignores_whitespace() {
318 assert_matches("1+2", "fn f() -> i32 {1 + 2}", &["1 + 2"]);
319 assert_matches("1 + 2", "fn f() -> i32 {1+2}", &["1+2"]);
320}
321
322#[test]
323fn no_match() {
324 assert_no_match("1 + 3", "fn f() -> i32 {1 + 2}");
325}
326
327#[test]
328fn match_fn_definition() {
329 assert_matches("fn $a($b: $t) {$c}", "fn f(a: i32) {bar()}", &["fn f(a: i32) {bar()}"]);
330}
331
332#[test]
333fn match_struct_definition() {
334 assert_matches(
335 "struct $n {$f: Option<String>}",
336 "struct Bar {} struct Foo {name: Option<String>}",
337 &["struct Foo {name: Option<String>}"],
338 );
339}
340
341#[test]
342fn match_expr() {
343 let code = "fn f() -> i32 {foo(40 + 2, 42)}";
344 assert_matches("foo($a, $b)", code, &["foo(40 + 2, 42)"]);
345 assert_no_match("foo($a, $b, $c)", code);
346 assert_no_match("foo($a)", code);
347}
348
349#[test]
350fn match_nested_method_calls() {
351 assert_matches(
352 "$a.z().z().z()",
353 "fn f() {h().i().j().z().z().z().d().e()}",
354 &["h().i().j().z().z().z()"],
355 );
356}
357
358#[test]
359fn match_complex_expr() {
360 let code = "fn f() -> i32 {foo(bar(40, 2), 42)}";
361 assert_matches("foo($a, $b)", code, &["foo(bar(40, 2), 42)"]);
362 assert_no_match("foo($a, $b, $c)", code);
363 assert_no_match("foo($a)", code);
364 assert_matches("bar($a, $b)", code, &["bar(40, 2)"]);
365}
366
367// Trailing commas in the code should be ignored.
368#[test]
369fn match_with_trailing_commas() {
370 // Code has comma, pattern doesn't.
371 assert_matches("foo($a, $b)", "fn f() {foo(1, 2,);}", &["foo(1, 2,)"]);
372 assert_matches("Foo{$a, $b}", "fn f() {Foo{1, 2,};}", &["Foo{1, 2,}"]);
373
374 // Pattern has comma, code doesn't.
375 assert_matches("foo($a, $b,)", "fn f() {foo(1, 2);}", &["foo(1, 2)"]);
376 assert_matches("Foo{$a, $b,}", "fn f() {Foo{1, 2};}", &["Foo{1, 2}"]);
377}
378
379#[test]
380fn match_type() {
381 assert_matches("i32", "fn f() -> i32 {1 + 2}", &["i32"]);
382 assert_matches("Option<$a>", "fn f() -> Option<i32> {42}", &["Option<i32>"]);
383 assert_no_match("Option<$a>", "fn f() -> Result<i32, ()> {42}");
384}
385
386#[test]
387fn match_struct_instantiation() {
388 assert_matches(
389 "Foo {bar: 1, baz: 2}",
390 "fn f() {Foo {bar: 1, baz: 2}}",
391 &["Foo {bar: 1, baz: 2}"],
392 );
393 // Now with placeholders for all parts of the struct.
394 assert_matches(
395 "Foo {$a: $b, $c: $d}",
396 "fn f() {Foo {bar: 1, baz: 2}}",
397 &["Foo {bar: 1, baz: 2}"],
398 );
399 assert_matches("Foo {}", "fn f() {Foo {}}", &["Foo {}"]);
400}
401
402#[test]
403fn match_path() {
404 assert_matches("foo::bar", "fn f() {foo::bar(42)}", &["foo::bar"]);
405 assert_matches("$a::bar", "fn f() {foo::bar(42)}", &["foo::bar"]);
406 assert_matches("foo::$b", "fn f() {foo::bar(42)}", &["foo::bar"]);
407}
408
409#[test]
410fn match_pattern() {
411 assert_matches("Some($a)", "fn f() {if let Some(x) = foo() {}}", &["Some(x)"]);
412}
413
414#[test]
415fn match_reordered_struct_instantiation() {
416 assert_matches(
417 "Foo {aa: 1, b: 2, ccc: 3}",
418 "fn f() {Foo {b: 2, ccc: 3, aa: 1}}",
419 &["Foo {b: 2, ccc: 3, aa: 1}"],
420 );
421 assert_no_match("Foo {a: 1}", "fn f() {Foo {b: 1}}");
422 assert_no_match("Foo {a: 1}", "fn f() {Foo {a: 2}}");
423 assert_no_match("Foo {a: 1, b: 2}", "fn f() {Foo {a: 1}}");
424 assert_no_match("Foo {a: 1, b: 2}", "fn f() {Foo {b: 2}}");
425 assert_no_match("Foo {a: 1, }", "fn f() {Foo {a: 1, b: 2}}");
426 assert_no_match("Foo {a: 1, z: 9}", "fn f() {Foo {a: 1}}");
427}
428
429#[test]
430fn replace_function_call() {
431 assert_ssr_transform("foo() ==>> bar()", "fn f1() {foo(); foo();}", "fn f1() {bar(); bar();}");
432}
433
434#[test]
435fn replace_function_call_with_placeholders() {
436 assert_ssr_transform(
437 "foo($a, $b) ==>> bar($b, $a)",
438 "fn f1() {foo(5, 42)}",
439 "fn f1() {bar(42, 5)}",
440 );
441}
442
443#[test]
444fn replace_nested_function_calls() {
445 assert_ssr_transform(
446 "foo($a) ==>> bar($a)",
447 "fn f1() {foo(foo(42))}",
448 "fn f1() {bar(bar(42))}",
449 );
450}
451
452#[test]
453fn replace_type() {
454 assert_ssr_transform(
455 "Result<(), $a> ==>> Option<$a>",
456 "fn f1() -> Result<(), Vec<Error>> {foo()}",
457 "fn f1() -> Option<Vec<Error>> {foo()}",
458 );
459}
460
461#[test]
462fn replace_struct_init() {
463 assert_ssr_transform(
464 "Foo {a: $a, b: $b} ==>> Foo::new($a, $b)",
465 "fn f1() {Foo{b: 1, a: 2}}",
466 "fn f1() {Foo::new(2, 1)}",
467 );
468}
469
470#[test]
471fn replace_binary_op() {
472 assert_ssr_transform(
473 "$a + $b ==>> $b + $a",
474 "fn f() {2 * 3 + 4 * 5}",
475 "fn f() {4 * 5 + 2 * 3}",
476 );
477 assert_ssr_transform(
478 "$a + $b ==>> $b + $a",
479 "fn f() {1 + 2 + 3 + 4}",
480 "fn f() {4 + 3 + 2 + 1}",
481 );
482}
483
484#[test]
485fn match_binary_op() {
486 assert_matches("$a + $b", "fn f() {1 + 2 + 3 + 4}", &["1 + 2", "1 + 2 + 3", "1 + 2 + 3 + 4"]);
487}
488
489#[test]
490fn multiple_rules() {
491 assert_ssr_transforms(
492 &["$a + 1 ==>> add_one($a)", "$a + $b ==>> add($a, $b)"],
493 "fn f() -> i32 {3 + 2 + 1}",
494 "fn f() -> i32 {add_one(add(3, 2))}",
495 )
496}