From 4f1d90e73bb47b4da3df62b05a84ebd4297ed78d Mon Sep 17 00:00:00 2001 From: adamrk Date: Sat, 22 Feb 2020 22:58:48 +0100 Subject: Handle trivia in strucural search and replace --- crates/ra_ide/src/ssr.rs | 145 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 118 insertions(+), 27 deletions(-) (limited to 'crates') diff --git a/crates/ra_ide/src/ssr.rs b/crates/ra_ide/src/ssr.rs index 902c29fc6..83c212494 100644 --- a/crates/ra_ide/src/ssr.rs +++ b/crates/ra_ide/src/ssr.rs @@ -3,9 +3,7 @@ use crate::source_change::SourceFileEdit; use ra_ide_db::RootDatabase; use ra_syntax::ast::make::expr_from_text; -use ra_syntax::AstNode; -use ra_syntax::SyntaxElement; -use ra_syntax::SyntaxNode; +use ra_syntax::{AstNode, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxToken}; use ra_text_edit::{TextEdit, TextEditBuilder}; use rustc_hash::FxHashMap; use std::collections::HashMap; @@ -72,6 +70,7 @@ type Binding = HashMap; struct Match { place: SyntaxNode, binding: Binding, + ignored_comments: Vec, } #[derive(Debug)] @@ -179,25 +178,51 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches { pattern: &SyntaxElement, code: &SyntaxElement, placeholders: &[Var], - match_: &mut Match, - ) -> bool { + mut match_: Match, + ) -> Option { match (pattern, code) { (SyntaxElement::Token(ref pattern), SyntaxElement::Token(ref code)) => { - pattern.text() == code.text() + if pattern.text() == code.text() { + Some(match_) + } else { + None + } } (SyntaxElement::Node(ref pattern), SyntaxElement::Node(ref code)) => { if placeholders.iter().any(|n| n.0.as_str() == pattern.text()) { match_.binding.insert(Var(pattern.text().to_string()), code.clone()); - true + Some(match_) } else { - pattern.green().children().count() == code.green().children().count() - && pattern - .children_with_tokens() - .zip(code.children_with_tokens()) - .all(|(a, b)| check(&a, &b, placeholders, match_)) + let mut pattern_children = pattern + .children_with_tokens() + .filter(|element| !element.kind().is_trivia()); + let mut code_children = + code.children_with_tokens().filter(|element| !element.kind().is_trivia()); + let new_ignored_comments = code.children_with_tokens().filter_map(|element| { + if let SyntaxElement::Token(token) = element { + if token.kind() == SyntaxKind::COMMENT { + return Some(token.clone()); + } + } + None + }); + match_.ignored_comments.extend(new_ignored_comments); + let match_from_children = pattern_children + .by_ref() + .zip(code_children.by_ref()) + .fold(Some(match_), |accum, (a, b)| { + accum.and_then(|match_| check(&a, &b, placeholders, match_)) + }); + match_from_children.and_then(|match_| { + if pattern_children.count() == 0 && code_children.count() == 0 { + Some(match_) + } else { + None + } + }) } } - _ => false, + _ => None, } } let kind = pattern.pattern.kind(); @@ -205,18 +230,12 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches { .descendants_with_tokens() .filter(|n| n.kind() == kind) .filter_map(|code| { - let mut match_ = - Match { place: code.as_node().unwrap().clone(), binding: HashMap::new() }; - if check( - &SyntaxElement::from(pattern.pattern.clone()), - &code, - &pattern.vars, - &mut match_, - ) { - Some(match_) - } else { - None - } + let match_ = Match { + place: code.as_node().unwrap().clone(), + binding: HashMap::new(), + ignored_comments: vec![], + }; + check(&SyntaxElement::from(pattern.pattern.clone()), &code, &pattern.vars, match_) }) .collect(); SsrMatches { matches } @@ -225,18 +244,28 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches { fn replace(matches: &SsrMatches, template: &SsrTemplate) -> TextEdit { let mut builder = TextEditBuilder::default(); for match_ in &matches.matches { - builder.replace(match_.place.text_range(), render_replace(&match_.binding, template)); + builder.replace( + match_.place.text_range(), + render_replace(&match_.binding, &match_.ignored_comments, template), + ); } builder.finish() } -fn render_replace(binding: &Binding, template: &SsrTemplate) -> String { +fn render_replace( + binding: &Binding, + ignored_comments: &Vec, + template: &SsrTemplate, +) -> String { let mut builder = TextEditBuilder::default(); for element in template.template.descendants() { if let Some(var) = template.placeholders.get(&element) { builder.replace(element.text_range(), binding[var].to_string()) } } + for comment in ignored_comments { + builder.insert(template.template.text_range().end(), comment.to_string()) + } builder.finish().apply(&template.template.text().to_string()) } @@ -325,4 +354,66 @@ mod tests { let edit = replace(&matches, &query.template); assert_eq!(edit.apply(input), "fn main() { bar(1+2); }"); } + + fn assert_ssr_transform(query: &str, input: &str, result: &str) { + let query: SsrQuery = query.parse().unwrap(); + let code = SourceFile::parse(input).tree(); + let matches = find(&query.pattern, code.syntax()); + let edit = replace(&matches, &query.template); + assert_eq!(edit.apply(input), result); + } + + #[test] + fn ssr_function_to_method() { + assert_ssr_transform( + "my_function($a:expr, $b:expr) ==>> ($a).my_method($b)", + "loop { my_function( other_func(x, y), z + w) }", + "loop { (other_func(x, y)).my_method(z + w) }", + ) + } + + #[test] + fn ssr_nested_function() { + assert_ssr_transform( + "foo($a:expr, $b:expr, $c:expr) ==>> bar($c, baz($a, $b))", + "fn main { foo (x + value.method(b), x+y-z, true && false) }", + "fn main { bar(true && false, baz(x + value.method(b), x+y-z)) }", + ) + } + + #[test] + fn ssr_expected_spacing() { + assert_ssr_transform( + "foo($x:expr) + bar() ==>> bar($x)", + "fn main() { foo(5) + bar() }", + "fn main() { bar(5) }", + ); + } + + #[test] + fn ssr_with_extra_space() { + assert_ssr_transform( + "foo($x:expr ) + bar() ==>> bar($x)", + "fn main() { foo( 5 ) +bar( ) }", + "fn main() { bar(5) }", + ); + } + + #[test] + fn ssr_keeps_nested_comment() { + assert_ssr_transform( + "foo($x:expr) ==>> bar($x)", + "fn main() { foo(other(5 /* using 5 */)) }", + "fn main() { bar(other(5 /* using 5 */)) }", + ) + } + + #[test] + fn ssr_keeps_comment() { + assert_ssr_transform( + "foo($x:expr) ==>> bar($x)", + "fn main() { foo(5 /* using 5 */) }", + "fn main() { bar(5)/* using 5 */ }", + ) + } } -- cgit v1.2.3