From 35a2cd08c154c6a0429bd02611991b7d887015ee Mon Sep 17 00:00:00 2001 From: Mikhail Modin Date: Thu, 2 Apr 2020 20:17:33 +0100 Subject: Adds to SSR match for semantically equivalent call and method call --- crates/ra_ide/src/ssr.rs | 122 ++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 110 insertions(+), 12 deletions(-) (limited to 'crates/ra_ide/src') diff --git a/crates/ra_ide/src/ssr.rs b/crates/ra_ide/src/ssr.rs index 1abb891c1..7b93ff2d2 100644 --- a/crates/ra_ide/src/ssr.rs +++ b/crates/ra_ide/src/ssr.rs @@ -5,12 +5,14 @@ use ra_db::{SourceDatabase, SourceDatabaseExt}; use ra_ide_db::symbol_index::SymbolsDatabase; use ra_ide_db::RootDatabase; use ra_syntax::ast::make::try_expr_from_text; -use ra_syntax::ast::{AstToken, Comment, RecordField, RecordLit}; -use ra_syntax::{AstNode, SyntaxElement, SyntaxNode}; +use ra_syntax::ast::{ + ArgList, AstToken, CallExpr, Comment, Expr, MethodCallExpr, RecordField, RecordLit, +}; +use ra_syntax::{AstNode, SyntaxElement, SyntaxKind, SyntaxNode}; use ra_text_edit::{TextEdit, TextEditBuilder}; use rustc_hash::FxHashMap; use std::collections::HashMap; -use std::str::FromStr; +use std::{iter::once, str::FromStr}; #[derive(Debug, PartialEq)] pub struct SsrError(String); @@ -219,6 +221,50 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches { ) } + fn check_call_and_method_call( + pattern: CallExpr, + code: MethodCallExpr, + placeholders: &[Var], + match_: Match, + ) -> Option { + let (pattern_name, pattern_type_args) = if let Some(Expr::PathExpr(path_exr)) = + pattern.expr() + { + let segment = path_exr.path().and_then(|p| p.segment()); + (segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list())) + } else { + (None, None) + }; + let match_ = check_opt_nodes(pattern_name, code.name_ref(), placeholders, match_)?; + let match_ = + check_opt_nodes(pattern_type_args, code.type_arg_list(), placeholders, match_)?; + let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args(); + let code_args = code.syntax().children().find_map(ArgList::cast)?.args(); + let code_args = once(code.expr()?).chain(code_args); + check_iter(pattern_args, code_args, placeholders, match_) + } + + fn check_method_call_and_call( + pattern: MethodCallExpr, + code: CallExpr, + placeholders: &[Var], + match_: Match, + ) -> Option { + let (code_name, code_type_args) = if let Some(Expr::PathExpr(path_exr)) = code.expr() { + let segment = path_exr.path().and_then(|p| p.segment()); + (segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list())) + } else { + (None, None) + }; + let match_ = check_opt_nodes(pattern.name_ref(), code_name, placeholders, match_)?; + let match_ = + check_opt_nodes(pattern.type_arg_list(), code_type_args, placeholders, match_)?; + let code_args = code.syntax().children().find_map(ArgList::cast)?.args(); + let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args(); + let pattern_args = once(pattern.expr()?).chain(pattern_args); + check_iter(pattern_args, code_args, placeholders, match_) + } + fn check_opt_nodes( pattern: Option, code: Option, @@ -227,8 +273,8 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches { ) -> Option { match (pattern, code) { (Some(pattern), Some(code)) => check( - &SyntaxElement::from(pattern.syntax().clone()), - &SyntaxElement::from(code.syntax().clone()), + &pattern.syntax().clone().into(), + &code.syntax().clone().into(), placeholders, match_, ), @@ -237,6 +283,33 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches { } } + fn check_iter( + mut pattern: I1, + mut code: I2, + placeholders: &[Var], + match_: Match, + ) -> Option + where + T: AstNode, + I1: Iterator, + I2: Iterator, + { + pattern + .by_ref() + .zip(code.by_ref()) + .fold(Some(match_), |accum, (a, b)| { + accum.and_then(|match_| { + check( + &a.syntax().clone().into(), + &b.syntax().clone().into(), + placeholders, + match_, + ) + }) + }) + .filter(|_| pattern.next().is_none() && code.next().is_none()) + } + fn check( pattern: &SyntaxElement, code: &SyntaxElement, @@ -260,6 +333,14 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches { (RecordLit::cast(pattern.clone()), RecordLit::cast(code.clone())) { check_record_lit(pattern, code, placeholders, match_) + } else if let (Some(pattern), Some(code)) = + (CallExpr::cast(pattern.clone()), MethodCallExpr::cast(code.clone())) + { + check_call_and_method_call(pattern, code, placeholders, match_) + } else if let (Some(pattern), Some(code)) = + (MethodCallExpr::cast(pattern.clone()), CallExpr::cast(code.clone())) + { + check_method_call_and_call(pattern, code, placeholders, match_) } else { let mut pattern_children = pattern .children_with_tokens() @@ -290,16 +371,15 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches { let kind = pattern.pattern.kind(); let matches = code .descendants() - .filter(|n| n.kind() == kind) + .filter(|n| { + n.kind() == kind + || (kind == SyntaxKind::CALL_EXPR && n.kind() == SyntaxKind::METHOD_CALL_EXPR) + || (kind == SyntaxKind::METHOD_CALL_EXPR && n.kind() == SyntaxKind::CALL_EXPR) + }) .filter_map(|code| { let match_ = Match { place: code.clone(), binding: HashMap::new(), ignored_comments: vec![] }; - check( - &SyntaxElement::from(pattern.pattern.clone()), - &SyntaxElement::from(code), - &pattern.vars, - match_, - ) + check(&pattern.pattern.clone().into(), &code.into(), &pattern.vars, match_) }) .collect(); SsrMatches { matches } @@ -498,4 +578,22 @@ mod tests { "fn main() { foo::new(1, 2) }", ) } + + #[test] + fn ssr_call_and_method_call() { + assert_ssr_transform( + "foo::<'a>($a:expr, $b:expr)) ==>> foo2($a, $b)", + "fn main() { get().bar.foo::<'a>(1); }", + "fn main() { foo2(get().bar, 1); }", + ) + } + + #[test] + fn ssr_method_call_and_call() { + assert_ssr_transform( + "$o:expr.foo::($a:expr)) ==>> $o.foo2($a)", + "fn main() { X::foo::(x, 1); }", + "fn main() { x.foo2(1); }", + ) + } } -- cgit v1.2.3