From 662ab2ecc8e29eb5995b3c162fac869838bea9a2 Mon Sep 17 00:00:00 2001 From: David Lattimore Date: Wed, 17 Jun 2020 16:53:51 +1000 Subject: Allow SSR to match type references, items, paths and patterns Part of #3186 --- crates/ra_ide/src/ssr.rs | 563 +---------------------------------------------- 1 file changed, 9 insertions(+), 554 deletions(-) (limited to 'crates/ra_ide/src/ssr.rs') diff --git a/crates/ra_ide/src/ssr.rs b/crates/ra_ide/src/ssr.rs index 762aab962..59c230f6c 100644 --- a/crates/ra_ide/src/ssr.rs +++ b/crates/ra_ide/src/ssr.rs @@ -1,31 +1,12 @@ -use std::{collections::HashMap, iter::once, str::FromStr}; - -use ra_db::{SourceDatabase, SourceDatabaseExt}; +use ra_db::SourceDatabaseExt; use ra_ide_db::{symbol_index::SymbolsDatabase, RootDatabase}; -use ra_syntax::ast::{ - make::try_expr_from_text, ArgList, AstToken, CallExpr, Comment, Expr, MethodCallExpr, - RecordField, RecordLit, -}; -use ra_syntax::{AstNode, SyntaxElement, SyntaxKind, SyntaxNode}; -use ra_text_edit::{TextEdit, TextEditBuilder}; -use rustc_hash::FxHashMap; use crate::SourceFileEdit; - -#[derive(Debug, PartialEq)] -pub struct SsrError(String); - -impl std::fmt::Display for SsrError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "Parse error: {}", self.0) - } -} - -impl std::error::Error for SsrError {} +use ra_ssr::{MatchFinder, SsrError, SsrRule}; // Feature: Structural Seach and Replace // -// Search and replace with named wildcards that will match any expression. +// Search and replace with named wildcards that will match any expression, type, path, pattern or item. // The syntax for a structural search replace command is ` ==>> `. // A `$` placeholder in the search pattern will match any AST node and `$` will reference it in the replacement. // Available via the command `rust-analyzer.ssr`. @@ -46,550 +27,24 @@ impl std::error::Error for SsrError {} // | VS Code | **Rust Analyzer: Structural Search Replace** // |=== pub fn parse_search_replace( - query: &str, + rule: &str, parse_only: bool, db: &RootDatabase, ) -> Result, SsrError> { let mut edits = vec![]; - let query: SsrQuery = query.parse()?; + let rule: SsrRule = rule.parse()?; if parse_only { return Ok(edits); } + let mut match_finder = MatchFinder::new(db); + match_finder.add_rule(rule); for &root in db.local_roots().iter() { let sr = db.source_root(root); for file_id in sr.walk() { - let matches = find(&query.pattern, db.parse(file_id).tree().syntax()); - if !matches.matches.is_empty() { - edits.push(SourceFileEdit { file_id, edit: replace(&matches, &query.template) }); + if let Some(edit) = match_finder.edits_for_file(file_id) { + edits.push(SourceFileEdit { file_id, edit }); } } } Ok(edits) } - -#[derive(Debug)] -struct SsrQuery { - pattern: SsrPattern, - template: SsrTemplate, -} - -#[derive(Debug)] -struct SsrPattern { - pattern: SyntaxNode, - vars: Vec, -} - -/// Represents a `$var` in an SSR query. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -struct Var(String); - -#[derive(Debug)] -struct SsrTemplate { - template: SyntaxNode, - placeholders: FxHashMap, -} - -type Binding = HashMap; - -#[derive(Debug)] -struct Match { - place: SyntaxNode, - binding: Binding, - ignored_comments: Vec, -} - -#[derive(Debug)] -struct SsrMatches { - matches: Vec, -} - -impl FromStr for SsrQuery { - type Err = SsrError; - - fn from_str(query: &str) -> Result { - let mut it = query.split("==>>"); - let pattern = it.next().expect("at least empty string").trim(); - let mut template = it - .next() - .ok_or_else(|| SsrError("Cannot find delemiter `==>>`".into()))? - .trim() - .to_string(); - if it.next().is_some() { - return Err(SsrError("More than one delimiter found".into())); - } - let mut vars = vec![]; - let mut it = pattern.split('$'); - let mut pattern = it.next().expect("something").to_string(); - - for part in it.map(split_by_var) { - let (var, remainder) = part?; - let new_var = create_name(var, &mut vars)?; - pattern.push_str(new_var); - pattern.push_str(remainder); - template = replace_in_template(template, var, new_var); - } - - let template = try_expr_from_text(&template) - .ok_or(SsrError("Template is not an expression".into()))? - .syntax() - .clone(); - let mut placeholders = FxHashMap::default(); - - traverse(&template, &mut |n| { - if let Some(v) = vars.iter().find(|v| v.0.as_str() == n.text()) { - placeholders.insert(n.clone(), v.clone()); - false - } else { - true - } - }); - - let pattern = SsrPattern { - pattern: try_expr_from_text(&pattern) - .ok_or(SsrError("Pattern is not an expression".into()))? - .syntax() - .clone(), - vars, - }; - let template = SsrTemplate { template, placeholders }; - Ok(SsrQuery { pattern, template }) - } -} - -fn traverse(node: &SyntaxNode, go: &mut impl FnMut(&SyntaxNode) -> bool) { - if !go(node) { - return; - } - for ref child in node.children() { - traverse(child, go); - } -} - -fn split_by_var(s: &str) -> Result<(&str, &str), SsrError> { - let end_of_name = s.find(|c| !char::is_ascii_alphanumeric(&c)).unwrap_or_else(|| s.len()); - let name = &s[..end_of_name]; - is_name(name)?; - Ok((name, &s[end_of_name..])) -} - -fn is_name(s: &str) -> Result<(), SsrError> { - if s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') { - Ok(()) - } else { - Err(SsrError("Name can contain only alphanumerics and _".into())) - } -} - -fn replace_in_template(template: String, var: &str, new_var: &str) -> String { - let name = format!("${}", var); - template.replace(&name, new_var) -} - -fn create_name<'a>(name: &str, vars: &'a mut Vec) -> Result<&'a str, SsrError> { - let sanitized_name = format!("__search_pattern_{}", name); - if vars.iter().any(|a| a.0 == sanitized_name) { - return Err(SsrError(format!("Name `{}` repeats more than once", name))); - } - vars.push(Var(sanitized_name)); - Ok(&vars.last().unwrap().0) -} - -fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches { - fn check_record_lit( - pattern: RecordLit, - code: RecordLit, - placeholders: &[Var], - match_: Match, - ) -> Option { - let match_ = check_opt_nodes(pattern.path(), code.path(), placeholders, match_)?; - - let mut pattern_fields: Vec = - pattern.record_field_list().map(|x| x.fields().collect()).unwrap_or_default(); - let mut code_fields: Vec = - code.record_field_list().map(|x| x.fields().collect()).unwrap_or_default(); - - if pattern_fields.len() != code_fields.len() { - return None; - } - - let by_name = |a: &RecordField, b: &RecordField| { - a.name_ref() - .map(|x| x.syntax().text().to_string()) - .cmp(&b.name_ref().map(|x| x.syntax().text().to_string())) - }; - pattern_fields.sort_by(by_name); - code_fields.sort_by(by_name); - - pattern_fields.into_iter().zip(code_fields.into_iter()).fold( - Some(match_), - |accum, (a, b)| { - accum.and_then(|match_| check_opt_nodes(Some(a), Some(b), placeholders, match_)) - }, - ) - } - - fn check_call_and_method_call( - pattern: CallExpr, - code: MethodCallExpr, - placeholders: &[Var], - match_: Match, - ) -> Option { - let (pattern_name, pattern_type_args) = if let Some(Expr::PathExpr(path_exr)) = - pattern.expr() - { - let segment = path_exr.path().and_then(|p| p.segment()); - (segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list())) - } else { - (None, None) - }; - let match_ = check_opt_nodes(pattern_name, code.name_ref(), placeholders, match_)?; - let match_ = - check_opt_nodes(pattern_type_args, code.type_arg_list(), placeholders, match_)?; - let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args(); - let code_args = code.syntax().children().find_map(ArgList::cast)?.args(); - let code_args = once(code.expr()?).chain(code_args); - check_iter(pattern_args, code_args, placeholders, match_) - } - - fn check_method_call_and_call( - pattern: MethodCallExpr, - code: CallExpr, - placeholders: &[Var], - match_: Match, - ) -> Option { - let (code_name, code_type_args) = if let Some(Expr::PathExpr(path_exr)) = code.expr() { - let segment = path_exr.path().and_then(|p| p.segment()); - (segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list())) - } else { - (None, None) - }; - let match_ = check_opt_nodes(pattern.name_ref(), code_name, placeholders, match_)?; - let match_ = - check_opt_nodes(pattern.type_arg_list(), code_type_args, placeholders, match_)?; - let code_args = code.syntax().children().find_map(ArgList::cast)?.args(); - let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args(); - let pattern_args = once(pattern.expr()?).chain(pattern_args); - check_iter(pattern_args, code_args, placeholders, match_) - } - - fn check_opt_nodes( - pattern: Option, - code: Option, - placeholders: &[Var], - match_: Match, - ) -> Option { - match (pattern, code) { - (Some(pattern), Some(code)) => check( - &pattern.syntax().clone().into(), - &code.syntax().clone().into(), - placeholders, - match_, - ), - (None, None) => Some(match_), - _ => None, - } - } - - fn check_iter( - mut pattern: I1, - mut code: I2, - placeholders: &[Var], - match_: Match, - ) -> Option - where - T: AstNode, - I1: Iterator, - I2: Iterator, - { - pattern - .by_ref() - .zip(code.by_ref()) - .fold(Some(match_), |accum, (a, b)| { - accum.and_then(|match_| { - check( - &a.syntax().clone().into(), - &b.syntax().clone().into(), - placeholders, - match_, - ) - }) - }) - .filter(|_| pattern.next().is_none() && code.next().is_none()) - } - - fn check( - pattern: &SyntaxElement, - code: &SyntaxElement, - placeholders: &[Var], - mut match_: Match, - ) -> Option { - match (&pattern, &code) { - (SyntaxElement::Token(pattern), SyntaxElement::Token(code)) => { - if pattern.text() == code.text() { - Some(match_) - } else { - None - } - } - (SyntaxElement::Node(pattern), SyntaxElement::Node(code)) => { - if placeholders.iter().any(|n| n.0.as_str() == pattern.text()) { - match_.binding.insert(Var(pattern.text().to_string()), code.clone()); - Some(match_) - } else { - if let (Some(pattern), Some(code)) = - (RecordLit::cast(pattern.clone()), RecordLit::cast(code.clone())) - { - check_record_lit(pattern, code, placeholders, match_) - } else if let (Some(pattern), Some(code)) = - (CallExpr::cast(pattern.clone()), MethodCallExpr::cast(code.clone())) - { - check_call_and_method_call(pattern, code, placeholders, match_) - } else if let (Some(pattern), Some(code)) = - (MethodCallExpr::cast(pattern.clone()), CallExpr::cast(code.clone())) - { - check_method_call_and_call(pattern, code, placeholders, match_) - } else { - let mut pattern_children = pattern - .children_with_tokens() - .filter(|element| !element.kind().is_trivia()); - let mut code_children = code - .children_with_tokens() - .filter(|element| !element.kind().is_trivia()); - let new_ignored_comments = - code.children_with_tokens().filter_map(|element| { - element.as_token().and_then(|token| Comment::cast(token.clone())) - }); - match_.ignored_comments.extend(new_ignored_comments); - pattern_children - .by_ref() - .zip(code_children.by_ref()) - .fold(Some(match_), |accum, (a, b)| { - accum.and_then(|match_| check(&a, &b, placeholders, match_)) - }) - .filter(|_| { - pattern_children.next().is_none() && code_children.next().is_none() - }) - } - } - } - _ => None, - } - } - let kind = pattern.pattern.kind(); - let matches = code - .descendants() - .filter(|n| { - n.kind() == kind - || (kind == SyntaxKind::CALL_EXPR && n.kind() == SyntaxKind::METHOD_CALL_EXPR) - || (kind == SyntaxKind::METHOD_CALL_EXPR && n.kind() == SyntaxKind::CALL_EXPR) - }) - .filter_map(|code| { - let match_ = - Match { place: code.clone(), binding: HashMap::new(), ignored_comments: vec![] }; - check(&pattern.pattern.clone().into(), &code.into(), &pattern.vars, match_) - }) - .collect(); - SsrMatches { matches } -} - -fn replace(matches: &SsrMatches, template: &SsrTemplate) -> TextEdit { - let mut builder = TextEditBuilder::default(); - for match_ in &matches.matches { - builder.replace( - match_.place.text_range(), - render_replace(&match_.binding, &match_.ignored_comments, template), - ); - } - builder.finish() -} - -fn render_replace( - binding: &Binding, - ignored_comments: &Vec, - template: &SsrTemplate, -) -> String { - let edit = { - let mut builder = TextEditBuilder::default(); - for element in template.template.descendants() { - if let Some(var) = template.placeholders.get(&element) { - builder.replace(element.text_range(), binding[var].to_string()) - } - } - for comment in ignored_comments { - builder.insert(template.template.text_range().end(), comment.syntax().to_string()) - } - builder.finish() - }; - - let mut text = template.template.text().to_string(); - edit.apply(&mut text); - text -} - -#[cfg(test)] -mod tests { - use super::*; - use ra_syntax::SourceFile; - - fn parse_error_text(query: &str) -> String { - format!("{}", query.parse::().unwrap_err()) - } - - #[test] - fn parser_happy_case() { - let result: SsrQuery = "foo($a, $b) ==>> bar($b, $a)".parse().unwrap(); - assert_eq!(&result.pattern.pattern.text(), "foo(__search_pattern_a, __search_pattern_b)"); - assert_eq!(result.pattern.vars.len(), 2); - assert_eq!(result.pattern.vars[0].0, "__search_pattern_a"); - assert_eq!(result.pattern.vars[1].0, "__search_pattern_b"); - assert_eq!(&result.template.template.text(), "bar(__search_pattern_b, __search_pattern_a)"); - } - - #[test] - fn parser_empty_query() { - assert_eq!(parse_error_text(""), "Parse error: Cannot find delemiter `==>>`"); - } - - #[test] - fn parser_no_delimiter() { - assert_eq!(parse_error_text("foo()"), "Parse error: Cannot find delemiter `==>>`"); - } - - #[test] - fn parser_two_delimiters() { - assert_eq!( - parse_error_text("foo() ==>> a ==>> b "), - "Parse error: More than one delimiter found" - ); - } - - #[test] - fn parser_repeated_name() { - assert_eq!( - parse_error_text("foo($a, $a) ==>>"), - "Parse error: Name `a` repeats more than once" - ); - } - - #[test] - fn parser_invlid_pattern() { - assert_eq!(parse_error_text(" ==>> ()"), "Parse error: Pattern is not an expression"); - } - - #[test] - fn parser_invlid_template() { - assert_eq!(parse_error_text("() ==>> )"), "Parse error: Template is not an expression"); - } - - #[test] - fn parse_match_replace() { - let query: SsrQuery = "foo($x) ==>> bar($x)".parse().unwrap(); - let input = "fn main() { foo(1+2); }"; - - let code = SourceFile::parse(input).tree(); - let matches = find(&query.pattern, code.syntax()); - assert_eq!(matches.matches.len(), 1); - assert_eq!(matches.matches[0].place.text(), "foo(1+2)"); - assert_eq!(matches.matches[0].binding.len(), 1); - assert_eq!( - matches.matches[0].binding[&Var("__search_pattern_x".to_string())].text(), - "1+2" - ); - - let edit = replace(&matches, &query.template); - let mut after = input.to_string(); - edit.apply(&mut after); - assert_eq!(after, "fn main() { bar(1+2); }"); - } - - fn assert_ssr_transform(query: &str, input: &str, result: &str) { - let query: SsrQuery = query.parse().unwrap(); - let code = SourceFile::parse(input).tree(); - let matches = find(&query.pattern, code.syntax()); - let edit = replace(&matches, &query.template); - let mut after = input.to_string(); - edit.apply(&mut after); - assert_eq!(after, result); - } - - #[test] - fn ssr_function_to_method() { - assert_ssr_transform( - "my_function($a, $b) ==>> ($a).my_method($b)", - "loop { my_function( other_func(x, y), z + w) }", - "loop { (other_func(x, y)).my_method(z + w) }", - ) - } - - #[test] - fn ssr_nested_function() { - assert_ssr_transform( - "foo($a, $b, $c) ==>> bar($c, baz($a, $b))", - "fn main { foo (x + value.method(b), x+y-z, true && false) }", - "fn main { bar(true && false, baz(x + value.method(b), x+y-z)) }", - ) - } - - #[test] - fn ssr_expected_spacing() { - assert_ssr_transform( - "foo($x) + bar() ==>> bar($x)", - "fn main() { foo(5) + bar() }", - "fn main() { bar(5) }", - ); - } - - #[test] - fn ssr_with_extra_space() { - assert_ssr_transform( - "foo($x ) + bar() ==>> bar($x)", - "fn main() { foo( 5 ) +bar( ) }", - "fn main() { bar(5) }", - ); - } - - #[test] - fn ssr_keeps_nested_comment() { - assert_ssr_transform( - "foo($x) ==>> bar($x)", - "fn main() { foo(other(5 /* using 5 */)) }", - "fn main() { bar(other(5 /* using 5 */)) }", - ) - } - - #[test] - fn ssr_keeps_comment() { - assert_ssr_transform( - "foo($x) ==>> bar($x)", - "fn main() { foo(5 /* using 5 */) }", - "fn main() { bar(5)/* using 5 */ }", - ) - } - - #[test] - fn ssr_struct_lit() { - assert_ssr_transform( - "foo{a: $a, b: $b} ==>> foo::new($a, $b)", - "fn main() { foo{b:2, a:1} }", - "fn main() { foo::new(1, 2) }", - ) - } - - #[test] - fn ssr_call_and_method_call() { - assert_ssr_transform( - "foo::<'a>($a, $b)) ==>> foo2($a, $b)", - "fn main() { get().bar.foo::<'a>(1); }", - "fn main() { foo2(get().bar, 1); }", - ) - } - - #[test] - fn ssr_method_call_and_call() { - assert_ssr_transform( - "$o.foo::($a)) ==>> $o.foo2($a)", - "fn main() { X::foo::(x, 1); }", - "fn main() { x.foo2(1); }", - ) - } -} -- cgit v1.2.3