From 3100de842b3cc33c9ad364f10c7f740ac760f564 Mon Sep 17 00:00:00 2001 From: David Lattimore Date: Wed, 5 Aug 2020 19:48:52 +1000 Subject: Structured search replace now handles UFCS calls to trait methods --- crates/ra_ssr/src/matching.rs | 67 +++++++++++++++++++++++++++++++++++------- crates/ra_ssr/src/resolving.rs | 33 ++++++++++++++++++--- crates/ra_ssr/src/tests.rs | 29 ++++++++++++++++++ 3 files changed, 114 insertions(+), 15 deletions(-) diff --git a/crates/ra_ssr/src/matching.rs b/crates/ra_ssr/src/matching.rs index 125bf3895..6e0b92352 100644 --- a/crates/ra_ssr/src/matching.rs +++ b/crates/ra_ssr/src/matching.rs @@ -3,7 +3,7 @@ use crate::{ parsing::{Constraint, NodeKind, Placeholder}, - resolving::{ResolvedPattern, ResolvedRule}, + resolving::{ResolvedPattern, ResolvedRule, UfcsCallInfo}, SsrMatches, }; use hir::Semantics; @@ -190,11 +190,12 @@ impl<'db, 'sema> Matcher<'db, 'sema> { return Ok(()); } // We allow a UFCS call to match a method call, provided they resolve to the same function. - if let Some(pattern_function) = self.rule.pattern.ufcs_function_calls.get(pattern) { - if let (Some(pattern), Some(code)) = - (ast::CallExpr::cast(pattern.clone()), ast::MethodCallExpr::cast(code.clone())) - { - return self.attempt_match_ufcs(phase, &pattern, &code, *pattern_function); + if let Some(pattern_ufcs) = self.rule.pattern.ufcs_function_calls.get(pattern) { + if let Some(code) = ast::MethodCallExpr::cast(code.clone()) { + return self.attempt_match_ufcs_to_method_call(phase, pattern_ufcs, &code); + } + if let Some(code) = ast::CallExpr::cast(code.clone()) { + return self.attempt_match_ufcs_to_ufcs(phase, pattern_ufcs, &code); } } if pattern.kind() != code.kind() { @@ -521,23 +522,28 @@ impl<'db, 'sema> Matcher<'db, 'sema> { Ok(()) } - fn attempt_match_ufcs( + fn attempt_match_ufcs_to_method_call( &self, phase: &mut Phase, - pattern: &ast::CallExpr, + pattern_ufcs: &UfcsCallInfo, code: &ast::MethodCallExpr, - pattern_function: hir::Function, ) -> Result<(), MatchFailed> { use ast::ArgListOwner; let code_resolved_function = self .sema .resolve_method_call(code) .ok_or_else(|| match_error!("Failed to resolve method call"))?; - if pattern_function != code_resolved_function { + if pattern_ufcs.function != code_resolved_function { fail_match!("Method call resolved to a different function"); } + if code_resolved_function.has_self_param(self.sema.db) { + if let (Some(pattern_type), Some(expr)) = (&pattern_ufcs.qualifier_type, &code.expr()) { + self.check_expr_type(pattern_type, expr)?; + } + } // Check arguments. - let mut pattern_args = pattern + let mut pattern_args = pattern_ufcs + .call_expr .arg_list() .ok_or_else(|| match_error!("Pattern function call has no args"))? .args(); @@ -552,6 +558,45 @@ impl<'db, 'sema> Matcher<'db, 'sema> { } } + fn attempt_match_ufcs_to_ufcs( + &self, + phase: &mut Phase, + pattern_ufcs: &UfcsCallInfo, + code: &ast::CallExpr, + ) -> Result<(), MatchFailed> { + use ast::ArgListOwner; + // Check that the first argument is the expected type. + if let (Some(pattern_type), Some(expr)) = ( + &pattern_ufcs.qualifier_type, + &code.arg_list().and_then(|code_args| code_args.args().next()), + ) { + self.check_expr_type(pattern_type, expr)?; + } + self.attempt_match_node_children(phase, pattern_ufcs.call_expr.syntax(), code.syntax()) + } + + fn check_expr_type( + &self, + pattern_type: &hir::Type, + expr: &ast::Expr, + ) -> Result<(), MatchFailed> { + use hir::HirDisplay; + let code_type = self.sema.type_of_expr(&expr).ok_or_else(|| { + match_error!("Failed to get receiver type for `{}`", expr.syntax().text()) + })?; + if !code_type + .autoderef(self.sema.db) + .any(|deref_code_type| *pattern_type == deref_code_type) + { + fail_match!( + "Pattern type `{}` didn't match code type `{}`", + pattern_type.display(self.sema.db), + code_type.display(self.sema.db) + ); + } + Ok(()) + } + fn get_placeholder(&self, element: &SyntaxElement) -> Option<&Placeholder> { only_ident(element.clone()).and_then(|ident| self.rule.get_placeholder(&ident)) } diff --git a/crates/ra_ssr/src/resolving.rs b/crates/ra_ssr/src/resolving.rs index 7e7585c8b..bfc20705b 100644 --- a/crates/ra_ssr/src/resolving.rs +++ b/crates/ra_ssr/src/resolving.rs @@ -25,7 +25,7 @@ pub(crate) struct ResolvedPattern { pub(crate) node: SyntaxNode, // Paths in `node` that we've resolved. pub(crate) resolved_paths: FxHashMap, - pub(crate) ufcs_function_calls: FxHashMap, + pub(crate) ufcs_function_calls: FxHashMap, pub(crate) contains_self: bool, } @@ -35,6 +35,12 @@ pub(crate) struct ResolvedPath { pub(crate) depth: u32, } +pub(crate) struct UfcsCallInfo { + pub(crate) call_expr: ast::CallExpr, + pub(crate) function: hir::Function, + pub(crate) qualifier_type: Option, +} + impl ResolvedRule { pub(crate) fn new( rule: parsing::ParsedRule, @@ -70,6 +76,7 @@ struct Resolver<'a, 'db> { impl Resolver<'_, '_> { fn resolve_pattern_tree(&self, pattern: SyntaxNode) -> Result { + use syntax::ast::AstNode; use syntax::{SyntaxElement, T}; let mut resolved_paths = FxHashMap::default(); self.resolve(pattern.clone(), 0, &mut resolved_paths)?; @@ -77,11 +84,15 @@ impl Resolver<'_, '_> { .iter() .filter_map(|(path_node, resolved)| { if let Some(grandparent) = path_node.parent().and_then(|parent| parent.parent()) { - if grandparent.kind() == SyntaxKind::CALL_EXPR { + if let Some(call_expr) = ast::CallExpr::cast(grandparent.clone()) { if let hir::PathResolution::AssocItem(hir::AssocItem::Function(function)) = - &resolved.resolution + resolved.resolution { - return Some((grandparent, *function)); + let qualifier_type = self.resolution_scope.qualifier_type(path_node); + return Some(( + grandparent, + UfcsCallInfo { call_expr, function, qualifier_type }, + )); } } } @@ -226,6 +237,20 @@ impl<'db> ResolutionScope<'db> { None } } + + fn qualifier_type(&self, path: &SyntaxNode) -> Option { + use syntax::ast::AstNode; + if let Some(path) = ast::Path::cast(path.clone()) { + if let Some(qualifier) = path.qualifier() { + if let Some(resolved_qualifier) = self.resolve_path(&qualifier) { + if let hir::PathResolution::Def(hir::ModuleDef::Adt(adt)) = resolved_qualifier { + return Some(adt.ty(self.scope.db)); + } + } + } + } + None + } } fn is_self(path: &ast::Path) -> bool { diff --git a/crates/ra_ssr/src/tests.rs b/crates/ra_ssr/src/tests.rs index 7d4d470c0..4bc09c1e4 100644 --- a/crates/ra_ssr/src/tests.rs +++ b/crates/ra_ssr/src/tests.rs @@ -1143,3 +1143,32 @@ fn replace_self() { "#]], ); } + +#[test] +fn match_trait_method_call() { + // `Bar::foo` and `Bar2::foo` resolve to the same function. Make sure we only match if the type + // matches what's in the pattern. Also checks that we handle autoderef. + let code = r#" + pub struct Bar {} + pub struct Bar2 {} + pub trait Foo { + fn foo(&self, _: i32) {} + } + impl Foo for Bar {} + impl Foo for Bar2 {} + fn main() { + let v1 = Bar {}; + let v2 = Bar2 {}; + let v1_ref = &v1; + let v2_ref = &v2; + v1.foo(1); + v2.foo(2); + Bar::foo(&v1, 3); + Bar2::foo(&v2, 4); + v1_ref.foo(5); + v2_ref.foo(6); + } + "#; + assert_matches("Bar::foo($a, $b)", code, &["v1.foo(1)", "Bar::foo(&v1, 3)", "v1_ref.foo(5)"]); + assert_matches("Bar2::foo($a, $b)", code, &["v2.foo(2)", "Bar2::foo(&v2, 4)", "v2_ref.foo(6)"]); +} -- cgit v1.2.3