From 662ab2ecc8e29eb5995b3c162fac869838bea9a2 Mon Sep 17 00:00:00 2001 From: David Lattimore Date: Wed, 17 Jun 2020 16:53:51 +1000 Subject: Allow SSR to match type references, items, paths and patterns Part of #3186 --- crates/ra_ide/Cargo.toml | 1 + crates/ra_ide/src/lib.rs | 2 +- crates/ra_ide/src/ssr.rs | 563 +---------------------------------------- crates/ra_ssr/Cargo.toml | 19 ++ crates/ra_ssr/src/lib.rs | 120 +++++++++ crates/ra_ssr/src/matching.rs | 494 ++++++++++++++++++++++++++++++++++++ crates/ra_ssr/src/parsing.rs | 272 ++++++++++++++++++++ crates/ra_ssr/src/replacing.rs | 55 ++++ crates/ra_ssr/src/tests.rs | 496 ++++++++++++++++++++++++++++++++++++ 9 files changed, 1467 insertions(+), 555 deletions(-) create mode 100644 crates/ra_ssr/Cargo.toml create mode 100644 crates/ra_ssr/src/lib.rs create mode 100644 crates/ra_ssr/src/matching.rs create mode 100644 crates/ra_ssr/src/parsing.rs create mode 100644 crates/ra_ssr/src/replacing.rs create mode 100644 crates/ra_ssr/src/tests.rs (limited to 'crates') diff --git a/crates/ra_ide/Cargo.toml b/crates/ra_ide/Cargo.toml index 05c940605..bbc6a5c9b 100644 --- a/crates/ra_ide/Cargo.toml +++ b/crates/ra_ide/Cargo.toml @@ -29,6 +29,7 @@ ra_fmt = { path = "../ra_fmt" } ra_prof = { path = "../ra_prof" } test_utils = { path = "../test_utils" } ra_assists = { path = "../ra_assists" } +ra_ssr = { path = "../ra_ssr" } # ra_ide should depend only on the top-level `hir` package. if you need # something from some `hir_xxx` subpackage, reexport the API via `hir`. diff --git a/crates/ra_ide/src/lib.rs b/crates/ra_ide/src/lib.rs index be9ab62c0..47823718f 100644 --- a/crates/ra_ide/src/lib.rs +++ b/crates/ra_ide/src/lib.rs @@ -70,7 +70,6 @@ pub use crate::{ inlay_hints::{InlayHint, InlayHintsConfig, InlayKind}, references::{Declaration, Reference, ReferenceAccess, ReferenceKind, ReferenceSearchResult}, runnables::{Runnable, RunnableKind, TestId}, - ssr::SsrError, syntax_highlighting::{ Highlight, HighlightModifier, HighlightModifiers, HighlightTag, HighlightedRange, }, @@ -89,6 +88,7 @@ pub use ra_ide_db::{ symbol_index::Query, RootDatabase, }; +pub use ra_ssr::SsrError; pub use ra_text_edit::{Indel, TextEdit}; pub type Cancelable = Result; diff --git a/crates/ra_ide/src/ssr.rs b/crates/ra_ide/src/ssr.rs index 762aab962..59c230f6c 100644 --- a/crates/ra_ide/src/ssr.rs +++ b/crates/ra_ide/src/ssr.rs @@ -1,31 +1,12 @@ -use std::{collections::HashMap, iter::once, str::FromStr}; - -use ra_db::{SourceDatabase, SourceDatabaseExt}; +use ra_db::SourceDatabaseExt; use ra_ide_db::{symbol_index::SymbolsDatabase, RootDatabase}; -use ra_syntax::ast::{ - make::try_expr_from_text, ArgList, AstToken, CallExpr, Comment, Expr, MethodCallExpr, - RecordField, RecordLit, -}; -use ra_syntax::{AstNode, SyntaxElement, SyntaxKind, SyntaxNode}; -use ra_text_edit::{TextEdit, TextEditBuilder}; -use rustc_hash::FxHashMap; use crate::SourceFileEdit; - -#[derive(Debug, PartialEq)] -pub struct SsrError(String); - -impl std::fmt::Display for SsrError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "Parse error: {}", self.0) - } -} - -impl std::error::Error for SsrError {} +use ra_ssr::{MatchFinder, SsrError, SsrRule}; // Feature: Structural Seach and Replace // -// Search and replace with named wildcards that will match any expression. +// Search and replace with named wildcards that will match any expression, type, path, pattern or item. // The syntax for a structural search replace command is ` ==>> `. // A `$` placeholder in the search pattern will match any AST node and `$` will reference it in the replacement. // Available via the command `rust-analyzer.ssr`. @@ -46,550 +27,24 @@ impl std::error::Error for SsrError {} // | VS Code | **Rust Analyzer: Structural Search Replace** // |=== pub fn parse_search_replace( - query: &str, + rule: &str, parse_only: bool, db: &RootDatabase, ) -> Result, SsrError> { let mut edits = vec![]; - let query: SsrQuery = query.parse()?; + let rule: SsrRule = rule.parse()?; if parse_only { return Ok(edits); } + let mut match_finder = MatchFinder::new(db); + match_finder.add_rule(rule); for &root in db.local_roots().iter() { let sr = db.source_root(root); for file_id in sr.walk() { - let matches = find(&query.pattern, db.parse(file_id).tree().syntax()); - if !matches.matches.is_empty() { - edits.push(SourceFileEdit { file_id, edit: replace(&matches, &query.template) }); + if let Some(edit) = match_finder.edits_for_file(file_id) { + edits.push(SourceFileEdit { file_id, edit }); } } } Ok(edits) } - -#[derive(Debug)] -struct SsrQuery { - pattern: SsrPattern, - template: SsrTemplate, -} - -#[derive(Debug)] -struct SsrPattern { - pattern: SyntaxNode, - vars: Vec, -} - -/// Represents a `$var` in an SSR query. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -struct Var(String); - -#[derive(Debug)] -struct SsrTemplate { - template: SyntaxNode, - placeholders: FxHashMap, -} - -type Binding = HashMap; - -#[derive(Debug)] -struct Match { - place: SyntaxNode, - binding: Binding, - ignored_comments: Vec, -} - -#[derive(Debug)] -struct SsrMatches { - matches: Vec, -} - -impl FromStr for SsrQuery { - type Err = SsrError; - - fn from_str(query: &str) -> Result { - let mut it = query.split("==>>"); - let pattern = it.next().expect("at least empty string").trim(); - let mut template = it - .next() - .ok_or_else(|| SsrError("Cannot find delemiter `==>>`".into()))? - .trim() - .to_string(); - if it.next().is_some() { - return Err(SsrError("More than one delimiter found".into())); - } - let mut vars = vec![]; - let mut it = pattern.split('$'); - let mut pattern = it.next().expect("something").to_string(); - - for part in it.map(split_by_var) { - let (var, remainder) = part?; - let new_var = create_name(var, &mut vars)?; - pattern.push_str(new_var); - pattern.push_str(remainder); - template = replace_in_template(template, var, new_var); - } - - let template = try_expr_from_text(&template) - .ok_or(SsrError("Template is not an expression".into()))? - .syntax() - .clone(); - let mut placeholders = FxHashMap::default(); - - traverse(&template, &mut |n| { - if let Some(v) = vars.iter().find(|v| v.0.as_str() == n.text()) { - placeholders.insert(n.clone(), v.clone()); - false - } else { - true - } - }); - - let pattern = SsrPattern { - pattern: try_expr_from_text(&pattern) - .ok_or(SsrError("Pattern is not an expression".into()))? - .syntax() - .clone(), - vars, - }; - let template = SsrTemplate { template, placeholders }; - Ok(SsrQuery { pattern, template }) - } -} - -fn traverse(node: &SyntaxNode, go: &mut impl FnMut(&SyntaxNode) -> bool) { - if !go(node) { - return; - } - for ref child in node.children() { - traverse(child, go); - } -} - -fn split_by_var(s: &str) -> Result<(&str, &str), SsrError> { - let end_of_name = s.find(|c| !char::is_ascii_alphanumeric(&c)).unwrap_or_else(|| s.len()); - let name = &s[..end_of_name]; - is_name(name)?; - Ok((name, &s[end_of_name..])) -} - -fn is_name(s: &str) -> Result<(), SsrError> { - if s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') { - Ok(()) - } else { - Err(SsrError("Name can contain only alphanumerics and _".into())) - } -} - -fn replace_in_template(template: String, var: &str, new_var: &str) -> String { - let name = format!("${}", var); - template.replace(&name, new_var) -} - -fn create_name<'a>(name: &str, vars: &'a mut Vec) -> Result<&'a str, SsrError> { - let sanitized_name = format!("__search_pattern_{}", name); - if vars.iter().any(|a| a.0 == sanitized_name) { - return Err(SsrError(format!("Name `{}` repeats more than once", name))); - } - vars.push(Var(sanitized_name)); - Ok(&vars.last().unwrap().0) -} - -fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches { - fn check_record_lit( - pattern: RecordLit, - code: RecordLit, - placeholders: &[Var], - match_: Match, - ) -> Option { - let match_ = check_opt_nodes(pattern.path(), code.path(), placeholders, match_)?; - - let mut pattern_fields: Vec = - pattern.record_field_list().map(|x| x.fields().collect()).unwrap_or_default(); - let mut code_fields: Vec = - code.record_field_list().map(|x| x.fields().collect()).unwrap_or_default(); - - if pattern_fields.len() != code_fields.len() { - return None; - } - - let by_name = |a: &RecordField, b: &RecordField| { - a.name_ref() - .map(|x| x.syntax().text().to_string()) - .cmp(&b.name_ref().map(|x| x.syntax().text().to_string())) - }; - pattern_fields.sort_by(by_name); - code_fields.sort_by(by_name); - - pattern_fields.into_iter().zip(code_fields.into_iter()).fold( - Some(match_), - |accum, (a, b)| { - accum.and_then(|match_| check_opt_nodes(Some(a), Some(b), placeholders, match_)) - }, - ) - } - - fn check_call_and_method_call( - pattern: CallExpr, - code: MethodCallExpr, - placeholders: &[Var], - match_: Match, - ) -> Option { - let (pattern_name, pattern_type_args) = if let Some(Expr::PathExpr(path_exr)) = - pattern.expr() - { - let segment = path_exr.path().and_then(|p| p.segment()); - (segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list())) - } else { - (None, None) - }; - let match_ = check_opt_nodes(pattern_name, code.name_ref(), placeholders, match_)?; - let match_ = - check_opt_nodes(pattern_type_args, code.type_arg_list(), placeholders, match_)?; - let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args(); - let code_args = code.syntax().children().find_map(ArgList::cast)?.args(); - let code_args = once(code.expr()?).chain(code_args); - check_iter(pattern_args, code_args, placeholders, match_) - } - - fn check_method_call_and_call( - pattern: MethodCallExpr, - code: CallExpr, - placeholders: &[Var], - match_: Match, - ) -> Option { - let (code_name, code_type_args) = if let Some(Expr::PathExpr(path_exr)) = code.expr() { - let segment = path_exr.path().and_then(|p| p.segment()); - (segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list())) - } else { - (None, None) - }; - let match_ = check_opt_nodes(pattern.name_ref(), code_name, placeholders, match_)?; - let match_ = - check_opt_nodes(pattern.type_arg_list(), code_type_args, placeholders, match_)?; - let code_args = code.syntax().children().find_map(ArgList::cast)?.args(); - let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args(); - let pattern_args = once(pattern.expr()?).chain(pattern_args); - check_iter(pattern_args, code_args, placeholders, match_) - } - - fn check_opt_nodes( - pattern: Option, - code: Option, - placeholders: &[Var], - match_: Match, - ) -> Option { - match (pattern, code) { - (Some(pattern), Some(code)) => check( - &pattern.syntax().clone().into(), - &code.syntax().clone().into(), - placeholders, - match_, - ), - (None, None) => Some(match_), - _ => None, - } - } - - fn check_iter( - mut pattern: I1, - mut code: I2, - placeholders: &[Var], - match_: Match, - ) -> Option - where - T: AstNode, - I1: Iterator, - I2: Iterator, - { - pattern - .by_ref() - .zip(code.by_ref()) - .fold(Some(match_), |accum, (a, b)| { - accum.and_then(|match_| { - check( - &a.syntax().clone().into(), - &b.syntax().clone().into(), - placeholders, - match_, - ) - }) - }) - .filter(|_| pattern.next().is_none() && code.next().is_none()) - } - - fn check( - pattern: &SyntaxElement, - code: &SyntaxElement, - placeholders: &[Var], - mut match_: Match, - ) -> Option { - match (&pattern, &code) { - (SyntaxElement::Token(pattern), SyntaxElement::Token(code)) => { - if pattern.text() == code.text() { - Some(match_) - } else { - None - } - } - (SyntaxElement::Node(pattern), SyntaxElement::Node(code)) => { - if placeholders.iter().any(|n| n.0.as_str() == pattern.text()) { - match_.binding.insert(Var(pattern.text().to_string()), code.clone()); - Some(match_) - } else { - if let (Some(pattern), Some(code)) = - (RecordLit::cast(pattern.clone()), RecordLit::cast(code.clone())) - { - check_record_lit(pattern, code, placeholders, match_) - } else if let (Some(pattern), Some(code)) = - (CallExpr::cast(pattern.clone()), MethodCallExpr::cast(code.clone())) - { - check_call_and_method_call(pattern, code, placeholders, match_) - } else if let (Some(pattern), Some(code)) = - (MethodCallExpr::cast(pattern.clone()), CallExpr::cast(code.clone())) - { - check_method_call_and_call(pattern, code, placeholders, match_) - } else { - let mut pattern_children = pattern - .children_with_tokens() - .filter(|element| !element.kind().is_trivia()); - let mut code_children = code - .children_with_tokens() - .filter(|element| !element.kind().is_trivia()); - let new_ignored_comments = - code.children_with_tokens().filter_map(|element| { - element.as_token().and_then(|token| Comment::cast(token.clone())) - }); - match_.ignored_comments.extend(new_ignored_comments); - pattern_children - .by_ref() - .zip(code_children.by_ref()) - .fold(Some(match_), |accum, (a, b)| { - accum.and_then(|match_| check(&a, &b, placeholders, match_)) - }) - .filter(|_| { - pattern_children.next().is_none() && code_children.next().is_none() - }) - } - } - } - _ => None, - } - } - let kind = pattern.pattern.kind(); - let matches = code - .descendants() - .filter(|n| { - n.kind() == kind - || (kind == SyntaxKind::CALL_EXPR && n.kind() == SyntaxKind::METHOD_CALL_EXPR) - || (kind == SyntaxKind::METHOD_CALL_EXPR && n.kind() == SyntaxKind::CALL_EXPR) - }) - .filter_map(|code| { - let match_ = - Match { place: code.clone(), binding: HashMap::new(), ignored_comments: vec![] }; - check(&pattern.pattern.clone().into(), &code.into(), &pattern.vars, match_) - }) - .collect(); - SsrMatches { matches } -} - -fn replace(matches: &SsrMatches, template: &SsrTemplate) -> TextEdit { - let mut builder = TextEditBuilder::default(); - for match_ in &matches.matches { - builder.replace( - match_.place.text_range(), - render_replace(&match_.binding, &match_.ignored_comments, template), - ); - } - builder.finish() -} - -fn render_replace( - binding: &Binding, - ignored_comments: &Vec, - template: &SsrTemplate, -) -> String { - let edit = { - let mut builder = TextEditBuilder::default(); - for element in template.template.descendants() { - if let Some(var) = template.placeholders.get(&element) { - builder.replace(element.text_range(), binding[var].to_string()) - } - } - for comment in ignored_comments { - builder.insert(template.template.text_range().end(), comment.syntax().to_string()) - } - builder.finish() - }; - - let mut text = template.template.text().to_string(); - edit.apply(&mut text); - text -} - -#[cfg(test)] -mod tests { - use super::*; - use ra_syntax::SourceFile; - - fn parse_error_text(query: &str) -> String { - format!("{}", query.parse::().unwrap_err()) - } - - #[test] - fn parser_happy_case() { - let result: SsrQuery = "foo($a, $b) ==>> bar($b, $a)".parse().unwrap(); - assert_eq!(&result.pattern.pattern.text(), "foo(__search_pattern_a, __search_pattern_b)"); - assert_eq!(result.pattern.vars.len(), 2); - assert_eq!(result.pattern.vars[0].0, "__search_pattern_a"); - assert_eq!(result.pattern.vars[1].0, "__search_pattern_b"); - assert_eq!(&result.template.template.text(), "bar(__search_pattern_b, __search_pattern_a)"); - } - - #[test] - fn parser_empty_query() { - assert_eq!(parse_error_text(""), "Parse error: Cannot find delemiter `==>>`"); - } - - #[test] - fn parser_no_delimiter() { - assert_eq!(parse_error_text("foo()"), "Parse error: Cannot find delemiter `==>>`"); - } - - #[test] - fn parser_two_delimiters() { - assert_eq!( - parse_error_text("foo() ==>> a ==>> b "), - "Parse error: More than one delimiter found" - ); - } - - #[test] - fn parser_repeated_name() { - assert_eq!( - parse_error_text("foo($a, $a) ==>>"), - "Parse error: Name `a` repeats more than once" - ); - } - - #[test] - fn parser_invlid_pattern() { - assert_eq!(parse_error_text(" ==>> ()"), "Parse error: Pattern is not an expression"); - } - - #[test] - fn parser_invlid_template() { - assert_eq!(parse_error_text("() ==>> )"), "Parse error: Template is not an expression"); - } - - #[test] - fn parse_match_replace() { - let query: SsrQuery = "foo($x) ==>> bar($x)".parse().unwrap(); - let input = "fn main() { foo(1+2); }"; - - let code = SourceFile::parse(input).tree(); - let matches = find(&query.pattern, code.syntax()); - assert_eq!(matches.matches.len(), 1); - assert_eq!(matches.matches[0].place.text(), "foo(1+2)"); - assert_eq!(matches.matches[0].binding.len(), 1); - assert_eq!( - matches.matches[0].binding[&Var("__search_pattern_x".to_string())].text(), - "1+2" - ); - - let edit = replace(&matches, &query.template); - let mut after = input.to_string(); - edit.apply(&mut after); - assert_eq!(after, "fn main() { bar(1+2); }"); - } - - fn assert_ssr_transform(query: &str, input: &str, result: &str) { - let query: SsrQuery = query.parse().unwrap(); - let code = SourceFile::parse(input).tree(); - let matches = find(&query.pattern, code.syntax()); - let edit = replace(&matches, &query.template); - let mut after = input.to_string(); - edit.apply(&mut after); - assert_eq!(after, result); - } - - #[test] - fn ssr_function_to_method() { - assert_ssr_transform( - "my_function($a, $b) ==>> ($a).my_method($b)", - "loop { my_function( other_func(x, y), z + w) }", - "loop { (other_func(x, y)).my_method(z + w) }", - ) - } - - #[test] - fn ssr_nested_function() { - assert_ssr_transform( - "foo($a, $b, $c) ==>> bar($c, baz($a, $b))", - "fn main { foo (x + value.method(b), x+y-z, true && false) }", - "fn main { bar(true && false, baz(x + value.method(b), x+y-z)) }", - ) - } - - #[test] - fn ssr_expected_spacing() { - assert_ssr_transform( - "foo($x) + bar() ==>> bar($x)", - "fn main() { foo(5) + bar() }", - "fn main() { bar(5) }", - ); - } - - #[test] - fn ssr_with_extra_space() { - assert_ssr_transform( - "foo($x ) + bar() ==>> bar($x)", - "fn main() { foo( 5 ) +bar( ) }", - "fn main() { bar(5) }", - ); - } - - #[test] - fn ssr_keeps_nested_comment() { - assert_ssr_transform( - "foo($x) ==>> bar($x)", - "fn main() { foo(other(5 /* using 5 */)) }", - "fn main() { bar(other(5 /* using 5 */)) }", - ) - } - - #[test] - fn ssr_keeps_comment() { - assert_ssr_transform( - "foo($x) ==>> bar($x)", - "fn main() { foo(5 /* using 5 */) }", - "fn main() { bar(5)/* using 5 */ }", - ) - } - - #[test] - fn ssr_struct_lit() { - assert_ssr_transform( - "foo{a: $a, b: $b} ==>> foo::new($a, $b)", - "fn main() { foo{b:2, a:1} }", - "fn main() { foo::new(1, 2) }", - ) - } - - #[test] - fn ssr_call_and_method_call() { - assert_ssr_transform( - "foo::<'a>($a, $b)) ==>> foo2($a, $b)", - "fn main() { get().bar.foo::<'a>(1); }", - "fn main() { foo2(get().bar, 1); }", - ) - } - - #[test] - fn ssr_method_call_and_call() { - assert_ssr_transform( - "$o.foo::($a)) ==>> $o.foo2($a)", - "fn main() { X::foo::(x, 1); }", - "fn main() { x.foo2(1); }", - ) - } -} diff --git a/crates/ra_ssr/Cargo.toml b/crates/ra_ssr/Cargo.toml new file mode 100644 index 000000000..3c2f15a83 --- /dev/null +++ b/crates/ra_ssr/Cargo.toml @@ -0,0 +1,19 @@ +[package] +edition = "2018" +name = "ra_ssr" +version = "0.1.0" +authors = ["rust-analyzer developers"] +license = "MIT OR Apache-2.0" +description = "Structural search and replace of Rust code" +repository = "https://github.com/rust-analyzer/rust-analyzer" + +[lib] +doctest = false + +[dependencies] +ra_text_edit = { path = "../ra_text_edit" } +ra_syntax = { path = "../ra_syntax" } +ra_db = { path = "../ra_db" } +ra_ide_db = { path = "../ra_ide_db" } +hir = { path = "../ra_hir", package = "ra_hir" } +rustc-hash = "1.1.0" diff --git a/crates/ra_ssr/src/lib.rs b/crates/ra_ssr/src/lib.rs new file mode 100644 index 000000000..fc716ae82 --- /dev/null +++ b/crates/ra_ssr/src/lib.rs @@ -0,0 +1,120 @@ +//! Structural Search Replace +//! +//! Allows searching the AST for code that matches one or more patterns and then replacing that code +//! based on a template. + +mod matching; +mod parsing; +mod replacing; +#[cfg(test)] +mod tests; + +use crate::matching::Match; +use hir::Semantics; +use ra_db::{FileId, FileRange}; +use ra_syntax::{AstNode, SmolStr, SyntaxNode}; +use ra_text_edit::TextEdit; +use rustc_hash::FxHashMap; + +// A structured search replace rule. Create by calling `parse` on a str. +#[derive(Debug)] +pub struct SsrRule { + /// A structured pattern that we're searching for. + pattern: SsrPattern, + /// What we'll replace it with. + template: parsing::SsrTemplate, +} + +#[derive(Debug)] +struct SsrPattern { + raw: parsing::RawSearchPattern, + /// Placeholders keyed by the stand-in ident that we use in Rust source code. + placeholders_by_stand_in: FxHashMap, + // We store our search pattern, parsed as each different kind of thing we can look for. As we + // traverse the AST, we get the appropriate one of these for the type of node we're on. For many + // search patterns, only some of these will be present. + expr: Option, + type_ref: Option, + item: Option, + path: Option, + pattern: Option, +} + +#[derive(Debug, PartialEq)] +pub struct SsrError(String); + +#[derive(Debug, Default)] +pub struct SsrMatches { + matches: Vec, +} + +/// Searches a crate for pattern matches and possibly replaces them with something else. +pub struct MatchFinder<'db> { + /// Our source of information about the user's code. + sema: Semantics<'db, ra_ide_db::RootDatabase>, + rules: Vec, +} + +impl<'db> MatchFinder<'db> { + pub fn new(db: &'db ra_ide_db::RootDatabase) -> MatchFinder<'db> { + MatchFinder { sema: Semantics::new(db), rules: Vec::new() } + } + + pub fn add_rule(&mut self, rule: SsrRule) { + self.rules.push(rule); + } + + pub fn edits_for_file(&self, file_id: FileId) -> Option { + let matches = self.find_matches_in_file(file_id); + if matches.matches.is_empty() { + None + } else { + Some(replacing::matches_to_edit(&matches)) + } + } + + fn find_matches_in_file(&self, file_id: FileId) -> SsrMatches { + let file = self.sema.parse(file_id); + let code = file.syntax(); + let mut matches = SsrMatches::default(); + self.find_matches(code, &None, &mut matches); + matches + } + + fn find_matches( + &self, + code: &SyntaxNode, + restrict_range: &Option, + matches_out: &mut SsrMatches, + ) { + for rule in &self.rules { + if let Ok(mut m) = matching::get_match(false, rule, &code, restrict_range, &self.sema) { + // Continue searching in each of our placeholders. + for placeholder_value in m.placeholder_values.values_mut() { + // Don't search our placeholder if it's the entire matched node, otherwise we'd + // find the same match over and over until we got a stack overflow. + if placeholder_value.node != *code { + self.find_matches( + &placeholder_value.node, + restrict_range, + &mut placeholder_value.inner_matches, + ); + } + } + matches_out.matches.push(m); + return; + } + } + for child in code.children() { + self.find_matches(&child, restrict_range, matches_out); + } + } +} + +impl std::fmt::Display for SsrError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "Parse error: {}", self.0) + } +} + +impl std::error::Error for SsrError {} diff --git a/crates/ra_ssr/src/matching.rs b/crates/ra_ssr/src/matching.rs new file mode 100644 index 000000000..265b6d793 --- /dev/null +++ b/crates/ra_ssr/src/matching.rs @@ -0,0 +1,494 @@ +//! This module is responsible for matching a search pattern against a node in the AST. In the +//! process of matching, placeholder values are recorded. + +use crate::{ + parsing::{Placeholder, SsrTemplate}, + SsrMatches, SsrPattern, SsrRule, +}; +use hir::Semantics; +use ra_db::FileRange; +use ra_syntax::ast::{AstNode, AstToken}; +use ra_syntax::{ + ast, SyntaxElement, SyntaxElementChildren, SyntaxKind, SyntaxNode, SyntaxToken, TextRange, +}; +use rustc_hash::FxHashMap; +use std::{cell::Cell, iter::Peekable}; + +// Creates a match error. If we're currently attempting to match some code that we thought we were +// going to match, as indicated by the --debug-snippet flag, then populate the reason field. +macro_rules! match_error { + ($e:expr) => {{ + MatchFailed { + reason: if recording_match_fail_reasons() { + Some(format!("{}", $e)) + } else { + None + } + } + }}; + ($fmt:expr, $($arg:tt)+) => {{ + MatchFailed { + reason: if recording_match_fail_reasons() { + Some(format!($fmt, $($arg)+)) + } else { + None + } + } + }}; +} + +// Fails the current match attempt, recording the supplied reason if we're recording match fail reasons. +macro_rules! fail_match { + ($($args:tt)*) => {return Err(match_error!($($args)*))}; +} + +/// Information about a match that was found. +#[derive(Debug)] +pub(crate) struct Match { + pub(crate) range: TextRange, + pub(crate) matched_node: SyntaxNode, + pub(crate) placeholder_values: FxHashMap, + pub(crate) ignored_comments: Vec, + // A copy of the template for the rule that produced this match. We store this on the match for + // if/when we do replacement. + pub(crate) template: SsrTemplate, +} + +/// Represents a `$var` in an SSR query. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct Var(pub String); + +/// Information about a placeholder bound in a match. +#[derive(Debug)] +pub(crate) struct PlaceholderMatch { + /// The node that the placeholder matched to. + pub(crate) node: SyntaxNode, + pub(crate) range: FileRange, + /// More matches, found within `node`. + pub(crate) inner_matches: SsrMatches, +} + +#[derive(Debug)] +pub(crate) struct MatchFailureReason { + pub(crate) reason: String, +} + +/// An "error" indicating that matching failed. Use the fail_match! macro to create and return this. +#[derive(Clone)] +pub(crate) struct MatchFailed { + /// The reason why we failed to match. Only present when debug_active true in call to + /// `get_match`. + pub(crate) reason: Option, +} + +/// Checks if `code` matches the search pattern found in `search_scope`, returning information about +/// the match, if it does. Since we only do matching in this module and searching is done by the +/// parent module, we don't populate nested matches. +pub(crate) fn get_match( + debug_active: bool, + rule: &SsrRule, + code: &SyntaxNode, + restrict_range: &Option, + sema: &Semantics, +) -> Result { + record_match_fails_reasons_scope(debug_active, || { + MatchState::try_match(rule, code, restrict_range, sema) + }) +} + +/// Inputs to matching. This cannot be part of `MatchState`, since we mutate `MatchState` and in at +/// least one case need to hold a borrow of a placeholder from the input pattern while calling a +/// mutable `MatchState` method. +struct MatchInputs<'pattern> { + ssr_pattern: &'pattern SsrPattern, +} + +/// State used while attempting to match our search pattern against a particular node of the AST. +struct MatchState<'db, 'sema> { + sema: &'sema Semantics<'db, ra_ide_db::RootDatabase>, + /// If any placeholders come from anywhere outside of this range, then the match will be + /// rejected. + restrict_range: Option, + /// The match that we're building. We do two passes for a successful match. On the first pass, + /// this is None so that we can avoid doing things like storing copies of what placeholders + /// matched to. If that pass succeeds, then we do a second pass where we collect those details. + /// This means that if we have a pattern like `$a.foo()` we won't do an insert into the + /// placeholders map for every single method call in the codebase. Instead we'll discard all the + /// method calls that aren't calls to `foo` on the first pass and only insert into the + /// placeholders map on the second pass. Likewise for ignored comments. + match_out: Option, +} + +impl<'db, 'sema> MatchState<'db, 'sema> { + fn try_match( + rule: &SsrRule, + code: &SyntaxNode, + restrict_range: &Option, + sema: &'sema Semantics<'db, ra_ide_db::RootDatabase>, + ) -> Result { + let mut match_state = + MatchState { sema, restrict_range: restrict_range.clone(), match_out: None }; + let match_inputs = MatchInputs { ssr_pattern: &rule.pattern }; + let pattern_tree = rule.pattern.tree_for_kind(code.kind())?; + // First pass at matching, where we check that node types and idents match. + match_state.attempt_match_node(&match_inputs, &pattern_tree, code)?; + match_state.validate_range(&sema.original_range(code))?; + match_state.match_out = Some(Match { + range: sema.original_range(code).range, + matched_node: code.clone(), + placeholder_values: FxHashMap::default(), + ignored_comments: Vec::new(), + template: rule.template.clone(), + }); + // Second matching pass, where we record placeholder matches, ignored comments and maybe do + // any other more expensive checks that we didn't want to do on the first pass. + match_state.attempt_match_node(&match_inputs, &pattern_tree, code)?; + Ok(match_state.match_out.unwrap()) + } + + /// Checks that `range` is within the permitted range if any. This is applicable when we're + /// processing a macro expansion and we want to fail the match if we're working with a node that + /// didn't originate from the token tree of the macro call. + fn validate_range(&self, range: &FileRange) -> Result<(), MatchFailed> { + if let Some(restrict_range) = &self.restrict_range { + if restrict_range.file_id != range.file_id + || !restrict_range.range.contains_range(range.range) + { + fail_match!("Node originated from a macro"); + } + } + Ok(()) + } + + fn attempt_match_node( + &mut self, + match_inputs: &MatchInputs, + pattern: &SyntaxNode, + code: &SyntaxNode, + ) -> Result<(), MatchFailed> { + // Handle placeholders. + if let Some(placeholder) = + match_inputs.get_placeholder(&SyntaxElement::Node(pattern.clone())) + { + if self.match_out.is_none() { + return Ok(()); + } + let original_range = self.sema.original_range(code); + // We validated the range for the node when we started the match, so the placeholder + // probably can't fail range validation, but just to be safe... + self.validate_range(&original_range)?; + if let Some(match_out) = &mut self.match_out { + match_out.placeholder_values.insert( + Var(placeholder.ident.to_string()), + PlaceholderMatch::new(code, original_range), + ); + } + return Ok(()); + } + // Non-placeholders. + if pattern.kind() != code.kind() { + fail_match!("Pattern had a {:?}, code had {:?}", pattern.kind(), code.kind()); + } + // Some kinds of nodes have special handling. For everything else, we fall back to default + // matching. + match code.kind() { + SyntaxKind::RECORD_FIELD_LIST => { + self.attempt_match_record_field_list(match_inputs, pattern, code) + } + _ => self.attempt_match_node_children(match_inputs, pattern, code), + } + } + + fn attempt_match_node_children( + &mut self, + match_inputs: &MatchInputs, + pattern: &SyntaxNode, + code: &SyntaxNode, + ) -> Result<(), MatchFailed> { + self.attempt_match_sequences( + match_inputs, + PatternIterator::new(pattern), + code.children_with_tokens(), + ) + } + + fn attempt_match_sequences( + &mut self, + match_inputs: &MatchInputs, + pattern_it: PatternIterator, + mut code_it: SyntaxElementChildren, + ) -> Result<(), MatchFailed> { + let mut pattern_it = pattern_it.peekable(); + loop { + match self.next_non_trivial(&mut code_it) { + None => { + if let Some(p) = pattern_it.next() { + fail_match!("Part of the pattern was unmached: {:?}", p); + } + return Ok(()); + } + Some(SyntaxElement::Token(c)) => { + self.attempt_match_token(&mut pattern_it, &c)?; + } + Some(SyntaxElement::Node(c)) => match pattern_it.next() { + Some(SyntaxElement::Node(p)) => { + self.attempt_match_node(match_inputs, &p, &c)?; + } + Some(p) => fail_match!("Pattern wanted '{}', code has {}", p, c.text()), + None => fail_match!("Pattern reached end, code has {}", c.text()), + }, + } + } + } + + fn attempt_match_token( + &mut self, + pattern: &mut Peekable, + code: &ra_syntax::SyntaxToken, + ) -> Result<(), MatchFailed> { + self.record_ignored_comments(code); + // Ignore whitespace and comments. + if code.kind().is_trivia() { + return Ok(()); + } + if let Some(SyntaxElement::Token(p)) = pattern.peek() { + // If the code has a comma and the pattern is about to close something, then accept the + // comma without advancing the pattern. i.e. ignore trailing commas. + if code.kind() == SyntaxKind::COMMA && is_closing_token(p.kind()) { + return Ok(()); + } + // Conversely, if the pattern has a comma and the code doesn't, skip that part of the + // pattern and continue to match the code. + if p.kind() == SyntaxKind::COMMA && is_closing_token(code.kind()) { + pattern.next(); + } + } + // Consume an element from the pattern and make sure it matches. + match pattern.next() { + Some(SyntaxElement::Token(p)) => { + if p.kind() != code.kind() || p.text() != code.text() { + fail_match!( + "Pattern wanted token '{}' ({:?}), but code had token '{}' ({:?})", + p.text(), + p.kind(), + code.text(), + code.kind() + ) + } + } + Some(SyntaxElement::Node(p)) => { + // Not sure if this is actually reachable. + fail_match!( + "Pattern wanted {:?}, but code had token '{}' ({:?})", + p, + code.text(), + code.kind() + ); + } + None => { + fail_match!("Pattern exhausted, while code remains: `{}`", code.text()); + } + } + Ok(()) + } + + /// We want to allow the records to match in any order, so we have special matching logic for + /// them. + fn attempt_match_record_field_list( + &mut self, + match_inputs: &MatchInputs, + pattern: &SyntaxNode, + code: &SyntaxNode, + ) -> Result<(), MatchFailed> { + // Build a map keyed by field name. + let mut fields_by_name = FxHashMap::default(); + for child in code.children() { + if let Some(record) = ast::RecordField::cast(child.clone()) { + if let Some(name) = record.field_name() { + fields_by_name.insert(name.text().clone(), child.clone()); + } + } + } + for p in pattern.children_with_tokens() { + if let SyntaxElement::Node(p) = p { + if let Some(name_element) = p.first_child_or_token() { + if match_inputs.get_placeholder(&name_element).is_some() { + // If the pattern is using placeholders for field names then order + // independence doesn't make sense. Fall back to regular ordered + // matching. + return self.attempt_match_node_children(match_inputs, pattern, code); + } + if let Some(ident) = only_ident(name_element) { + let code_record = fields_by_name.remove(ident.text()).ok_or_else(|| { + match_error!( + "Placeholder has record field '{}', but code doesn't", + ident + ) + })?; + self.attempt_match_node(match_inputs, &p, &code_record)?; + } + } + } + } + if let Some(unmatched_fields) = fields_by_name.keys().next() { + fail_match!( + "{} field(s) of a record literal failed to match, starting with {}", + fields_by_name.len(), + unmatched_fields + ); + } + Ok(()) + } + + fn next_non_trivial(&mut self, code_it: &mut SyntaxElementChildren) -> Option { + loop { + let c = code_it.next(); + if let Some(SyntaxElement::Token(t)) = &c { + self.record_ignored_comments(t); + if t.kind().is_trivia() { + continue; + } + } + return c; + } + } + + fn record_ignored_comments(&mut self, token: &SyntaxToken) { + if token.kind() == SyntaxKind::COMMENT { + if let Some(match_out) = &mut self.match_out { + if let Some(comment) = ast::Comment::cast(token.clone()) { + match_out.ignored_comments.push(comment); + } + } + } + } +} + +impl MatchInputs<'_> { + fn get_placeholder(&self, element: &SyntaxElement) -> Option<&Placeholder> { + only_ident(element.clone()) + .and_then(|ident| self.ssr_pattern.placeholders_by_stand_in.get(ident.text())) + } +} + +fn is_closing_token(kind: SyntaxKind) -> bool { + kind == SyntaxKind::R_PAREN || kind == SyntaxKind::R_CURLY || kind == SyntaxKind::R_BRACK +} + +pub(crate) fn record_match_fails_reasons_scope(debug_active: bool, f: F) -> T +where + F: Fn() -> T, +{ + RECORDING_MATCH_FAIL_REASONS.with(|c| c.set(debug_active)); + let res = f(); + RECORDING_MATCH_FAIL_REASONS.with(|c| c.set(false)); + res +} + +// For performance reasons, we don't want to record the reason why every match fails, only the bit +// of code that the user indicated they thought would match. We use a thread local to indicate when +// we are trying to match that bit of code. This saves us having to pass a boolean into all the bits +// of code that can make the decision to not match. +thread_local! { + pub static RECORDING_MATCH_FAIL_REASONS: Cell = Cell::new(false); +} + +fn recording_match_fail_reasons() -> bool { + RECORDING_MATCH_FAIL_REASONS.with(|c| c.get()) +} + +impl PlaceholderMatch { + fn new(node: &SyntaxNode, range: FileRange) -> Self { + Self { node: node.clone(), range, inner_matches: SsrMatches::default() } + } +} + +impl SsrPattern { + pub(crate) fn tree_for_kind(&self, kind: SyntaxKind) -> Result<&SyntaxNode, MatchFailed> { + let (tree, kind_name) = if ast::Expr::can_cast(kind) { + (&self.expr, "expression") + } else if ast::TypeRef::can_cast(kind) { + (&self.type_ref, "type reference") + } else if ast::ModuleItem::can_cast(kind) { + (&self.item, "item") + } else if ast::Path::can_cast(kind) { + (&self.path, "path") + } else if ast::Pat::can_cast(kind) { + (&self.pattern, "pattern") + } else { + fail_match!("Matching nodes of kind {:?} is not supported", kind); + }; + match tree { + Some(tree) => Ok(tree), + None => fail_match!("Pattern cannot be parsed as a {}", kind_name), + } + } +} + +// If `node` contains nothing but an ident then return it, otherwise return None. +fn only_ident(element: SyntaxElement) -> Option { + match element { + SyntaxElement::Token(t) => { + if t.kind() == SyntaxKind::IDENT { + return Some(t); + } + } + SyntaxElement::Node(n) => { + let mut children = n.children_with_tokens(); + if let (Some(only_child), None) = (children.next(), children.next()) { + return only_ident(only_child); + } + } + } + None +} + +struct PatternIterator { + iter: SyntaxElementChildren, +} + +impl Iterator for PatternIterator { + type Item = SyntaxElement; + + fn next(&mut self) -> Option { + while let Some(element) = self.iter.next() { + if !element.kind().is_trivia() { + return Some(element); + } + } + None + } +} + +impl PatternIterator { + fn new(parent: &SyntaxNode) -> Self { + Self { iter: parent.children_with_tokens() } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::MatchFinder; + + #[test] + fn parse_match_replace() { + let rule: SsrRule = "foo($x) ==>> bar($x)".parse().unwrap(); + let input = "fn main() { foo(1+2); }"; + + use ra_db::fixture::WithFixture; + let (db, file_id) = ra_ide_db::RootDatabase::with_single_file(input); + let mut match_finder = MatchFinder::new(&db); + match_finder.add_rule(rule); + let matches = match_finder.find_matches_in_file(file_id); + assert_eq!(matches.matches.len(), 1); + assert_eq!(matches.matches[0].matched_node.text(), "foo(1+2)"); + assert_eq!(matches.matches[0].placeholder_values.len(), 1); + assert_eq!(matches.matches[0].placeholder_values[&Var("x".to_string())].node.text(), "1+2"); + + let edit = crate::replacing::matches_to_edit(&matches); + let mut after = input.to_string(); + edit.apply(&mut after); + assert_eq!(after, "fn main() { bar(1+2); }"); + } +} diff --git a/crates/ra_ssr/src/parsing.rs b/crates/ra_ssr/src/parsing.rs new file mode 100644 index 000000000..90c13dbc2 --- /dev/null +++ b/crates/ra_ssr/src/parsing.rs @@ -0,0 +1,272 @@ +//! This file contains code for parsing SSR rules, which look something like `foo($a) ==>> bar($b)`. +//! We first split everything before and after the separator `==>>`. Next, both the search pattern +//! and the replacement template get tokenized by the Rust tokenizer. Tokens are then searched for +//! placeholders, which start with `$`. For replacement templates, this is the final form. For +//! search patterns, we go further and parse the pattern as each kind of thing that we can match. +//! e.g. expressions, type references etc. + +use crate::{SsrError, SsrPattern, SsrRule}; +use ra_syntax::{ast, AstNode, SmolStr, SyntaxKind}; +use rustc_hash::{FxHashMap, FxHashSet}; +use std::str::FromStr; + +/// Returns from the current function with an error, supplied by arguments as for format! +macro_rules! bail { + ($e:expr) => {return Err($crate::SsrError::new($e))}; + ($fmt:expr, $($arg:tt)+) => {return Err($crate::SsrError::new(format!($fmt, $($arg)+)))} +} + +#[derive(Clone, Debug)] +pub(crate) struct SsrTemplate { + pub(crate) tokens: Vec, +} + +#[derive(Debug)] +pub(crate) struct RawSearchPattern { + tokens: Vec, +} + +// Part of a search or replace pattern. +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) enum PatternElement { + Token(Token), + Placeholder(Placeholder), +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) struct Placeholder { + /// The name of this placeholder. e.g. for "$a", this would be "a" + pub(crate) ident: SmolStr, + /// A unique name used in place of this placeholder when we parse the pattern as Rust code. + stand_in_name: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct Token { + kind: SyntaxKind, + pub(crate) text: SmolStr, +} + +impl FromStr for SsrRule { + type Err = SsrError; + + fn from_str(query: &str) -> Result { + let mut it = query.split("==>>"); + let pattern = it.next().expect("at least empty string").trim(); + let template = it + .next() + .ok_or_else(|| SsrError("Cannot find delemiter `==>>`".into()))? + .trim() + .to_string(); + if it.next().is_some() { + return Err(SsrError("More than one delimiter found".into())); + } + let rule = SsrRule { pattern: pattern.parse()?, template: template.parse()? }; + validate_rule(&rule)?; + Ok(rule) + } +} + +impl FromStr for RawSearchPattern { + type Err = SsrError; + + fn from_str(pattern_str: &str) -> Result { + Ok(RawSearchPattern { tokens: parse_pattern(pattern_str)? }) + } +} + +impl RawSearchPattern { + /// Returns this search pattern as Rust source code that we can feed to the Rust parser. + fn as_rust_code(&self) -> String { + let mut res = String::new(); + for t in &self.tokens { + res.push_str(match t { + PatternElement::Token(token) => token.text.as_str(), + PatternElement::Placeholder(placeholder) => placeholder.stand_in_name.as_str(), + }); + } + res + } + + fn placeholders_by_stand_in(&self) -> FxHashMap { + let mut res = FxHashMap::default(); + for t in &self.tokens { + if let PatternElement::Placeholder(placeholder) = t { + res.insert(SmolStr::new(placeholder.stand_in_name.clone()), placeholder.clone()); + } + } + res + } +} + +impl FromStr for SsrPattern { + type Err = SsrError; + + fn from_str(pattern_str: &str) -> Result { + let raw: RawSearchPattern = pattern_str.parse()?; + let raw_str = raw.as_rust_code(); + let res = SsrPattern { + expr: ast::Expr::parse(&raw_str).ok().map(|n| n.syntax().clone()), + type_ref: ast::TypeRef::parse(&raw_str).ok().map(|n| n.syntax().clone()), + item: ast::ModuleItem::parse(&raw_str).ok().map(|n| n.syntax().clone()), + path: ast::Path::parse(&raw_str).ok().map(|n| n.syntax().clone()), + pattern: ast::Pat::parse(&raw_str).ok().map(|n| n.syntax().clone()), + placeholders_by_stand_in: raw.placeholders_by_stand_in(), + raw, + }; + if res.expr.is_none() + && res.type_ref.is_none() + && res.item.is_none() + && res.path.is_none() + && res.pattern.is_none() + { + bail!("Pattern is not a valid Rust expression, type, item, path or pattern"); + } + Ok(res) + } +} + +impl FromStr for SsrTemplate { + type Err = SsrError; + + fn from_str(pattern_str: &str) -> Result { + let tokens = parse_pattern(pattern_str)?; + // Validate that the template is a valid fragment of Rust code. We reuse the validation + // logic for search patterns since the only thing that differs is the error message. + if SsrPattern::from_str(pattern_str).is_err() { + bail!("Replacement is not a valid Rust expression, type, item, path or pattern"); + } + // Our actual template needs to preserve whitespace, so we can't reuse `tokens`. + Ok(SsrTemplate { tokens }) + } +} + +/// Returns `pattern_str`, parsed as a search or replace pattern. If `remove_whitespace` is true, +/// then any whitespace tokens will be removed, which we do for the search pattern, but not for the +/// replace pattern. +fn parse_pattern(pattern_str: &str) -> Result, SsrError> { + let mut res = Vec::new(); + let mut placeholder_names = FxHashSet::default(); + let mut tokens = tokenize(pattern_str)?.into_iter(); + while let Some(token) = tokens.next() { + if token.kind == SyntaxKind::DOLLAR { + let placeholder = parse_placeholder(&mut tokens)?; + if !placeholder_names.insert(placeholder.ident.clone()) { + bail!("Name `{}` repeats more than once", placeholder.ident); + } + res.push(PatternElement::Placeholder(placeholder)); + } else { + res.push(PatternElement::Token(token)); + } + } + Ok(res) +} + +/// Checks for errors in a rule. e.g. the replace pattern referencing placeholders that the search +/// pattern didn't define. +fn validate_rule(rule: &SsrRule) -> Result<(), SsrError> { + let mut defined_placeholders = std::collections::HashSet::new(); + for p in &rule.pattern.raw.tokens { + if let PatternElement::Placeholder(placeholder) = p { + defined_placeholders.insert(&placeholder.ident); + } + } + let mut undefined = Vec::new(); + for p in &rule.template.tokens { + if let PatternElement::Placeholder(placeholder) = p { + if !defined_placeholders.contains(&placeholder.ident) { + undefined.push(format!("${}", placeholder.ident)); + } + } + } + if !undefined.is_empty() { + bail!("Replacement contains undefined placeholders: {}", undefined.join(", ")); + } + Ok(()) +} + +fn tokenize(source: &str) -> Result, SsrError> { + let mut start = 0; + let (raw_tokens, errors) = ra_syntax::tokenize(source); + if let Some(first_error) = errors.first() { + bail!("Failed to parse pattern: {}", first_error); + } + let mut tokens: Vec = Vec::new(); + for raw_token in raw_tokens { + let token_len = usize::from(raw_token.len); + tokens.push(Token { + kind: raw_token.kind, + text: SmolStr::new(&source[start..start + token_len]), + }); + start += token_len; + } + Ok(tokens) +} + +fn parse_placeholder(tokens: &mut std::vec::IntoIter) -> Result { + let mut name = None; + if let Some(token) = tokens.next() { + match token.kind { + SyntaxKind::IDENT => { + name = Some(token.text); + } + _ => { + bail!("Placeholders should be $name"); + } + } + } + let name = name.ok_or_else(|| SsrError::new("Placeholder ($) with no name"))?; + Ok(Placeholder::new(name)) +} + +impl Placeholder { + fn new(name: SmolStr) -> Self { + Self { stand_in_name: format!("__placeholder_{}", name), ident: name } + } +} + +impl SsrError { + fn new(message: impl Into) -> SsrError { + SsrError(message.into()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parser_happy_case() { + fn token(kind: SyntaxKind, text: &str) -> PatternElement { + PatternElement::Token(Token { kind, text: SmolStr::new(text) }) + } + fn placeholder(name: &str) -> PatternElement { + PatternElement::Placeholder(Placeholder::new(SmolStr::new(name))) + } + let result: SsrRule = "foo($a, $b) ==>> bar($b, $a)".parse().unwrap(); + assert_eq!( + result.pattern.raw.tokens, + vec![ + token(SyntaxKind::IDENT, "foo"), + token(SyntaxKind::L_PAREN, "("), + placeholder("a"), + token(SyntaxKind::COMMA, ","), + token(SyntaxKind::WHITESPACE, " "), + placeholder("b"), + token(SyntaxKind::R_PAREN, ")"), + ] + ); + assert_eq!( + result.template.tokens, + vec![ + token(SyntaxKind::IDENT, "bar"), + token(SyntaxKind::L_PAREN, "("), + placeholder("b"), + token(SyntaxKind::COMMA, ","), + token(SyntaxKind::WHITESPACE, " "), + placeholder("a"), + token(SyntaxKind::R_PAREN, ")"), + ] + ); + } +} diff --git a/crates/ra_ssr/src/replacing.rs b/crates/ra_ssr/src/replacing.rs new file mode 100644 index 000000000..81a5e06a9 --- /dev/null +++ b/crates/ra_ssr/src/replacing.rs @@ -0,0 +1,55 @@ +//! Code for applying replacement templates for matches that have previously been found. + +use crate::matching::Var; +use crate::parsing::PatternElement; +use crate::{Match, SsrMatches}; +use ra_syntax::ast::AstToken; +use ra_syntax::TextSize; +use ra_text_edit::TextEdit; + +/// Returns a text edit that will replace each match in `matches` with its corresponding replacement +/// template. Placeholders in the template will have been substituted with whatever they matched to +/// in the original code. +pub(crate) fn matches_to_edit(matches: &SsrMatches) -> TextEdit { + matches_to_edit_at_offset(matches, 0.into()) +} + +fn matches_to_edit_at_offset(matches: &SsrMatches, relative_start: TextSize) -> TextEdit { + let mut edit_builder = ra_text_edit::TextEditBuilder::default(); + for m in &matches.matches { + edit_builder.replace(m.range.checked_sub(relative_start).unwrap(), render_replace(m)); + } + edit_builder.finish() +} + +fn render_replace(match_info: &Match) -> String { + let mut out = String::new(); + for r in &match_info.template.tokens { + match r { + PatternElement::Token(t) => out.push_str(t.text.as_str()), + PatternElement::Placeholder(p) => { + if let Some(placeholder_value) = + match_info.placeholder_values.get(&Var(p.ident.to_string())) + { + let range = &placeholder_value.range.range; + let mut matched_text = placeholder_value.node.text().to_string(); + let edit = + matches_to_edit_at_offset(&placeholder_value.inner_matches, range.start()); + edit.apply(&mut matched_text); + out.push_str(&matched_text); + } else { + // We validated that all placeholder references were valid before we + // started, so this shouldn't happen. + panic!( + "Internal error: replacement referenced unknown placeholder {}", + p.ident + ); + } + } + } + } + for comment in &match_info.ignored_comments { + out.push_str(&comment.syntax().to_string()); + } + out +} diff --git a/crates/ra_ssr/src/tests.rs b/crates/ra_ssr/src/tests.rs new file mode 100644 index 000000000..4b747fe18 --- /dev/null +++ b/crates/ra_ssr/src/tests.rs @@ -0,0 +1,496 @@ +use crate::matching::MatchFailureReason; +use crate::{matching, Match, MatchFinder, SsrMatches, SsrPattern, SsrRule}; +use matching::record_match_fails_reasons_scope; +use ra_db::{FileId, FileRange, SourceDatabaseExt}; +use ra_syntax::ast::AstNode; +use ra_syntax::{ast, SyntaxKind, SyntaxNode, TextRange}; + +struct MatchDebugInfo { + node: SyntaxNode, + /// Our search pattern parsed as the same kind of syntax node as `node`. e.g. expression, item, + /// etc. Will be absent if the pattern can't be parsed as that kind. + pattern: Result, + matched: Result, +} + +impl SsrPattern { + pub(crate) fn tree_for_kind_with_reason( + &self, + kind: SyntaxKind, + ) -> Result<&SyntaxNode, MatchFailureReason> { + record_match_fails_reasons_scope(true, || self.tree_for_kind(kind)) + .map_err(|e| MatchFailureReason { reason: e.reason.unwrap() }) + } +} + +impl std::fmt::Debug for MatchDebugInfo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "========= PATTERN ==========\n")?; + match &self.pattern { + Ok(pattern) => { + write!(f, "{:#?}", pattern)?; + } + Err(err) => { + write!(f, "{}", err.reason)?; + } + } + write!( + f, + "\n============ AST ===========\n\ + {:#?}\n============================", + self.node + )?; + match &self.matched { + Ok(_) => write!(f, "Node matched")?, + Err(reason) => write!(f, "Node failed to match because: {}", reason.reason)?, + } + Ok(()) + } +} + +impl SsrMatches { + /// Returns `self` with any nested matches removed and made into top-level matches. + pub(crate) fn flattened(self) -> SsrMatches { + let mut out = SsrMatches::default(); + self.flatten_into(&mut out); + out + } + + fn flatten_into(self, out: &mut SsrMatches) { + for mut m in self.matches { + for p in m.placeholder_values.values_mut() { + std::mem::replace(&mut p.inner_matches, SsrMatches::default()).flatten_into(out); + } + out.matches.push(m); + } + } +} + +impl Match { + pub(crate) fn matched_text(&self) -> String { + self.matched_node.text().to_string() + } +} + +impl<'db> MatchFinder<'db> { + /// Adds a search pattern. For use if you intend to only call `find_matches_in_file`. If you + /// intend to do replacement, use `add_rule` instead. + fn add_search_pattern(&mut self, pattern: SsrPattern) { + self.add_rule(SsrRule { pattern, template: "()".parse().unwrap() }) + } + + /// Finds all nodes in `file_id` whose text is exactly equal to `snippet` and attempts to match + /// them, while recording reasons why they don't match. This API is useful for command + /// line-based debugging where providing a range is difficult. + fn debug_where_text_equal(&self, file_id: FileId, snippet: &str) -> Vec { + let file = self.sema.parse(file_id); + let mut res = Vec::new(); + let file_text = self.sema.db.file_text(file_id); + let mut remaining_text = file_text.as_str(); + let mut base = 0; + let len = snippet.len() as u32; + while let Some(offset) = remaining_text.find(snippet) { + let start = base + offset as u32; + let end = start + len; + self.output_debug_for_nodes_at_range( + file.syntax(), + TextRange::new(start.into(), end.into()), + &None, + &mut res, + ); + remaining_text = &remaining_text[offset + snippet.len()..]; + base = end; + } + res + } + + fn output_debug_for_nodes_at_range( + &self, + node: &SyntaxNode, + range: TextRange, + restrict_range: &Option, + out: &mut Vec, + ) { + for node in node.children() { + if !node.text_range().contains_range(range) { + continue; + } + if node.text_range() == range { + for rule in &self.rules { + let pattern = + rule.pattern.tree_for_kind_with_reason(node.kind()).map(|p| p.clone()); + out.push(MatchDebugInfo { + matched: matching::get_match(true, rule, &node, restrict_range, &self.sema) + .map_err(|e| MatchFailureReason { + reason: e.reason.unwrap_or_else(|| { + "Match failed, but no reason was given".to_owned() + }), + }), + pattern, + node: node.clone(), + }); + } + } else if let Some(macro_call) = ast::MacroCall::cast(node.clone()) { + if let Some(expanded) = self.sema.expand(¯o_call) { + if let Some(tt) = macro_call.token_tree() { + self.output_debug_for_nodes_at_range( + &expanded, + range, + &Some(self.sema.original_range(tt.syntax())), + out, + ); + } + } + } + } + } +} + +fn parse_error_text(query: &str) -> String { + format!("{}", query.parse::().unwrap_err()) +} + +#[test] +fn parser_empty_query() { + assert_eq!(parse_error_text(""), "Parse error: Cannot find delemiter `==>>`"); +} + +#[test] +fn parser_no_delimiter() { + assert_eq!(parse_error_text("foo()"), "Parse error: Cannot find delemiter `==>>`"); +} + +#[test] +fn parser_two_delimiters() { + assert_eq!( + parse_error_text("foo() ==>> a ==>> b "), + "Parse error: More than one delimiter found" + ); +} + +#[test] +fn parser_repeated_name() { + assert_eq!( + parse_error_text("foo($a, $a) ==>>"), + "Parse error: Name `a` repeats more than once" + ); +} + +#[test] +fn parser_invalid_pattern() { + assert_eq!( + parse_error_text(" ==>> ()"), + "Parse error: Pattern is not a valid Rust expression, type, item, path or pattern" + ); +} + +#[test] +fn parser_invalid_template() { + assert_eq!( + parse_error_text("() ==>> )"), + "Parse error: Replacement is not a valid Rust expression, type, item, path or pattern" + ); +} + +#[test] +fn parser_undefined_placeholder_in_replacement() { + assert_eq!( + parse_error_text("42 ==>> $a"), + "Parse error: Replacement contains undefined placeholders: $a" + ); +} + +fn single_file(code: &str) -> (ra_ide_db::RootDatabase, FileId) { + use ra_db::fixture::WithFixture; + ra_ide_db::RootDatabase::with_single_file(code) +} + +fn assert_ssr_transform(rule: &str, input: &str, result: &str) { + assert_ssr_transforms(&[rule], input, result); +} + +fn assert_ssr_transforms(rules: &[&str], input: &str, result: &str) { + let (db, file_id) = single_file(input); + let mut match_finder = MatchFinder::new(&db); + for rule in rules { + let rule: SsrRule = rule.parse().unwrap(); + match_finder.add_rule(rule); + } + if let Some(edits) = match_finder.edits_for_file(file_id) { + let mut after = input.to_string(); + edits.apply(&mut after); + assert_eq!(after, result); + } else { + panic!("No edits were made"); + } +} + +fn assert_matches(pattern: &str, code: &str, expected: &[&str]) { + let (db, file_id) = single_file(code); + let mut match_finder = MatchFinder::new(&db); + match_finder.add_search_pattern(pattern.parse().unwrap()); + let matched_strings: Vec = match_finder + .find_matches_in_file(file_id) + .flattened() + .matches + .iter() + .map(|m| m.matched_text()) + .collect(); + if matched_strings != expected && !expected.is_empty() { + let debug_info = match_finder.debug_where_text_equal(file_id, &expected[0]); + eprintln!("Test is about to fail. Some possibly useful info: {} nodes had text exactly equal to '{}'", debug_info.len(), &expected[0]); + for d in debug_info { + eprintln!("{:#?}", d); + } + } + assert_eq!(matched_strings, expected); +} + +fn assert_no_match(pattern: &str, code: &str) { + assert_matches(pattern, code, &[]); +} + +#[test] +fn ssr_function_to_method() { + assert_ssr_transform( + "my_function($a, $b) ==>> ($a).my_method($b)", + "loop { my_function( other_func(x, y), z + w) }", + "loop { (other_func(x, y)).my_method(z + w) }", + ) +} + +#[test] +fn ssr_nested_function() { + assert_ssr_transform( + "foo($a, $b, $c) ==>> bar($c, baz($a, $b))", + "fn main { foo (x + value.method(b), x+y-z, true && false) }", + "fn main { bar(true && false, baz(x + value.method(b), x+y-z)) }", + ) +} + +#[test] +fn ssr_expected_spacing() { + assert_ssr_transform( + "foo($x) + bar() ==>> bar($x)", + "fn main() { foo(5) + bar() }", + "fn main() { bar(5) }", + ); +} + +#[test] +fn ssr_with_extra_space() { + assert_ssr_transform( + "foo($x ) + bar() ==>> bar($x)", + "fn main() { foo( 5 ) +bar( ) }", + "fn main() { bar(5) }", + ); +} + +#[test] +fn ssr_keeps_nested_comment() { + assert_ssr_transform( + "foo($x) ==>> bar($x)", + "fn main() { foo(other(5 /* using 5 */)) }", + "fn main() { bar(other(5 /* using 5 */)) }", + ) +} + +#[test] +fn ssr_keeps_comment() { + assert_ssr_transform( + "foo($x) ==>> bar($x)", + "fn main() { foo(5 /* using 5 */) }", + "fn main() { bar(5)/* using 5 */ }", + ) +} + +#[test] +fn ssr_struct_lit() { + assert_ssr_transform( + "foo{a: $a, b: $b} ==>> foo::new($a, $b)", + "fn main() { foo{b:2, a:1} }", + "fn main() { foo::new(1, 2) }", + ) +} + +#[test] +fn ignores_whitespace() { + assert_matches("1+2", "fn f() -> i32 {1 + 2}", &["1 + 2"]); + assert_matches("1 + 2", "fn f() -> i32 {1+2}", &["1+2"]); +} + +#[test] +fn no_match() { + assert_no_match("1 + 3", "fn f() -> i32 {1 + 2}"); +} + +#[test] +fn match_fn_definition() { + assert_matches("fn $a($b: $t) {$c}", "fn f(a: i32) {bar()}", &["fn f(a: i32) {bar()}"]); +} + +#[test] +fn match_struct_definition() { + assert_matches( + "struct $n {$f: Option}", + "struct Bar {} struct Foo {name: Option}", + &["struct Foo {name: Option}"], + ); +} + +#[test] +fn match_expr() { + let code = "fn f() -> i32 {foo(40 + 2, 42)}"; + assert_matches("foo($a, $b)", code, &["foo(40 + 2, 42)"]); + assert_no_match("foo($a, $b, $c)", code); + assert_no_match("foo($a)", code); +} + +#[test] +fn match_nested_method_calls() { + assert_matches( + "$a.z().z().z()", + "fn f() {h().i().j().z().z().z().d().e()}", + &["h().i().j().z().z().z()"], + ); +} + +#[test] +fn match_complex_expr() { + let code = "fn f() -> i32 {foo(bar(40, 2), 42)}"; + assert_matches("foo($a, $b)", code, &["foo(bar(40, 2), 42)"]); + assert_no_match("foo($a, $b, $c)", code); + assert_no_match("foo($a)", code); + assert_matches("bar($a, $b)", code, &["bar(40, 2)"]); +} + +// Trailing commas in the code should be ignored. +#[test] +fn match_with_trailing_commas() { + // Code has comma, pattern doesn't. + assert_matches("foo($a, $b)", "fn f() {foo(1, 2,);}", &["foo(1, 2,)"]); + assert_matches("Foo{$a, $b}", "fn f() {Foo{1, 2,};}", &["Foo{1, 2,}"]); + + // Pattern has comma, code doesn't. + assert_matches("foo($a, $b,)", "fn f() {foo(1, 2);}", &["foo(1, 2)"]); + assert_matches("Foo{$a, $b,}", "fn f() {Foo{1, 2};}", &["Foo{1, 2}"]); +} + +#[test] +fn match_type() { + assert_matches("i32", "fn f() -> i32 {1 + 2}", &["i32"]); + assert_matches("Option<$a>", "fn f() -> Option {42}", &["Option"]); + assert_no_match("Option<$a>", "fn f() -> Result {42}"); +} + +#[test] +fn match_struct_instantiation() { + assert_matches( + "Foo {bar: 1, baz: 2}", + "fn f() {Foo {bar: 1, baz: 2}}", + &["Foo {bar: 1, baz: 2}"], + ); + // Now with placeholders for all parts of the struct. + assert_matches( + "Foo {$a: $b, $c: $d}", + "fn f() {Foo {bar: 1, baz: 2}}", + &["Foo {bar: 1, baz: 2}"], + ); + assert_matches("Foo {}", "fn f() {Foo {}}", &["Foo {}"]); +} + +#[test] +fn match_path() { + assert_matches("foo::bar", "fn f() {foo::bar(42)}", &["foo::bar"]); + assert_matches("$a::bar", "fn f() {foo::bar(42)}", &["foo::bar"]); + assert_matches("foo::$b", "fn f() {foo::bar(42)}", &["foo::bar"]); +} + +#[test] +fn match_pattern() { + assert_matches("Some($a)", "fn f() {if let Some(x) = foo() {}}", &["Some(x)"]); +} + +#[test] +fn match_reordered_struct_instantiation() { + assert_matches( + "Foo {aa: 1, b: 2, ccc: 3}", + "fn f() {Foo {b: 2, ccc: 3, aa: 1}}", + &["Foo {b: 2, ccc: 3, aa: 1}"], + ); + assert_no_match("Foo {a: 1}", "fn f() {Foo {b: 1}}"); + assert_no_match("Foo {a: 1}", "fn f() {Foo {a: 2}}"); + assert_no_match("Foo {a: 1, b: 2}", "fn f() {Foo {a: 1}}"); + assert_no_match("Foo {a: 1, b: 2}", "fn f() {Foo {b: 2}}"); + assert_no_match("Foo {a: 1, }", "fn f() {Foo {a: 1, b: 2}}"); + assert_no_match("Foo {a: 1, z: 9}", "fn f() {Foo {a: 1}}"); +} + +#[test] +fn replace_function_call() { + assert_ssr_transform("foo() ==>> bar()", "fn f1() {foo(); foo();}", "fn f1() {bar(); bar();}"); +} + +#[test] +fn replace_function_call_with_placeholders() { + assert_ssr_transform( + "foo($a, $b) ==>> bar($b, $a)", + "fn f1() {foo(5, 42)}", + "fn f1() {bar(42, 5)}", + ); +} + +#[test] +fn replace_nested_function_calls() { + assert_ssr_transform( + "foo($a) ==>> bar($a)", + "fn f1() {foo(foo(42))}", + "fn f1() {bar(bar(42))}", + ); +} + +#[test] +fn replace_type() { + assert_ssr_transform( + "Result<(), $a> ==>> Option<$a>", + "fn f1() -> Result<(), Vec> {foo()}", + "fn f1() -> Option> {foo()}", + ); +} + +#[test] +fn replace_struct_init() { + assert_ssr_transform( + "Foo {a: $a, b: $b} ==>> Foo::new($a, $b)", + "fn f1() {Foo{b: 1, a: 2}}", + "fn f1() {Foo::new(2, 1)}", + ); +} + +#[test] +fn replace_binary_op() { + assert_ssr_transform( + "$a + $b ==>> $b + $a", + "fn f() {2 * 3 + 4 * 5}", + "fn f() {4 * 5 + 2 * 3}", + ); + assert_ssr_transform( + "$a + $b ==>> $b + $a", + "fn f() {1 + 2 + 3 + 4}", + "fn f() {4 + 3 + 2 + 1}", + ); +} + +#[test] +fn match_binary_op() { + assert_matches("$a + $b", "fn f() {1 + 2 + 3 + 4}", &["1 + 2", "1 + 2 + 3", "1 + 2 + 3 + 4"]); +} + +#[test] +fn multiple_rules() { + assert_ssr_transforms( + &["$a + 1 ==>> add_one($a)", "$a + $b ==>> add($a, $b)"], + "fn f() -> i32 {3 + 2 + 1}", + "fn f() -> i32 {add_one(add(3, 2))}", + ) +} -- cgit v1.2.3