aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Lattimore <[email protected]>2020-08-05 10:48:52 +0100
committerDavid Lattimore <[email protected]>2020-08-13 11:24:55 +0100
commit3100de842b3cc33c9ad364f10c7f740ac760f564 (patch)
tree7713e2aea0b47ec8141fcefba101871137137c09
parentde1d93455f85747410efb69c28e0c1379e8e328a (diff)
Structured search replace now handles UFCS calls to trait methods
-rw-r--r--crates/ra_ssr/src/matching.rs67
-rw-r--r--crates/ra_ssr/src/resolving.rs33
-rw-r--r--crates/ra_ssr/src/tests.rs29
3 files changed, 114 insertions, 15 deletions
diff --git a/crates/ra_ssr/src/matching.rs b/crates/ra_ssr/src/matching.rs
index 125bf3895..6e0b92352 100644
--- a/crates/ra_ssr/src/matching.rs
+++ b/crates/ra_ssr/src/matching.rs
@@ -3,7 +3,7 @@
3 3
4use crate::{ 4use crate::{
5 parsing::{Constraint, NodeKind, Placeholder}, 5 parsing::{Constraint, NodeKind, Placeholder},
6 resolving::{ResolvedPattern, ResolvedRule}, 6 resolving::{ResolvedPattern, ResolvedRule, UfcsCallInfo},
7 SsrMatches, 7 SsrMatches,
8}; 8};
9use hir::Semantics; 9use hir::Semantics;
@@ -190,11 +190,12 @@ impl<'db, 'sema> Matcher<'db, 'sema> {
190 return Ok(()); 190 return Ok(());
191 } 191 }
192 // We allow a UFCS call to match a method call, provided they resolve to the same function. 192 // We allow a UFCS call to match a method call, provided they resolve to the same function.
193 if let Some(pattern_function) = self.rule.pattern.ufcs_function_calls.get(pattern) { 193 if let Some(pattern_ufcs) = self.rule.pattern.ufcs_function_calls.get(pattern) {
194 if let (Some(pattern), Some(code)) = 194 if let Some(code) = ast::MethodCallExpr::cast(code.clone()) {
195 (ast::CallExpr::cast(pattern.clone()), ast::MethodCallExpr::cast(code.clone())) 195 return self.attempt_match_ufcs_to_method_call(phase, pattern_ufcs, &code);
196 { 196 }
197 return self.attempt_match_ufcs(phase, &pattern, &code, *pattern_function); 197 if let Some(code) = ast::CallExpr::cast(code.clone()) {
198 return self.attempt_match_ufcs_to_ufcs(phase, pattern_ufcs, &code);
198 } 199 }
199 } 200 }
200 if pattern.kind() != code.kind() { 201 if pattern.kind() != code.kind() {
@@ -521,23 +522,28 @@ impl<'db, 'sema> Matcher<'db, 'sema> {
521 Ok(()) 522 Ok(())
522 } 523 }
523 524
524 fn attempt_match_ufcs( 525 fn attempt_match_ufcs_to_method_call(
525 &self, 526 &self,
526 phase: &mut Phase, 527 phase: &mut Phase,
527 pattern: &ast::CallExpr, 528 pattern_ufcs: &UfcsCallInfo,
528 code: &ast::MethodCallExpr, 529 code: &ast::MethodCallExpr,
529 pattern_function: hir::Function,
530 ) -> Result<(), MatchFailed> { 530 ) -> Result<(), MatchFailed> {
531 use ast::ArgListOwner; 531 use ast::ArgListOwner;
532 let code_resolved_function = self 532 let code_resolved_function = self
533 .sema 533 .sema
534 .resolve_method_call(code) 534 .resolve_method_call(code)
535 .ok_or_else(|| match_error!("Failed to resolve method call"))?; 535 .ok_or_else(|| match_error!("Failed to resolve method call"))?;
536 if pattern_function != code_resolved_function { 536 if pattern_ufcs.function != code_resolved_function {
537 fail_match!("Method call resolved to a different function"); 537 fail_match!("Method call resolved to a different function");
538 } 538 }
539 if code_resolved_function.has_self_param(self.sema.db) {
540 if let (Some(pattern_type), Some(expr)) = (&pattern_ufcs.qualifier_type, &code.expr()) {
541 self.check_expr_type(pattern_type, expr)?;
542 }
543 }
539 // Check arguments. 544 // Check arguments.
540 let mut pattern_args = pattern 545 let mut pattern_args = pattern_ufcs
546 .call_expr
541 .arg_list() 547 .arg_list()
542 .ok_or_else(|| match_error!("Pattern function call has no args"))? 548 .ok_or_else(|| match_error!("Pattern function call has no args"))?
543 .args(); 549 .args();
@@ -552,6 +558,45 @@ impl<'db, 'sema> Matcher<'db, 'sema> {
552 } 558 }
553 } 559 }
554 560
561 fn attempt_match_ufcs_to_ufcs(
562 &self,
563 phase: &mut Phase,
564 pattern_ufcs: &UfcsCallInfo,
565 code: &ast::CallExpr,
566 ) -> Result<(), MatchFailed> {
567 use ast::ArgListOwner;
568 // Check that the first argument is the expected type.
569 if let (Some(pattern_type), Some(expr)) = (
570 &pattern_ufcs.qualifier_type,
571 &code.arg_list().and_then(|code_args| code_args.args().next()),
572 ) {
573 self.check_expr_type(pattern_type, expr)?;
574 }
575 self.attempt_match_node_children(phase, pattern_ufcs.call_expr.syntax(), code.syntax())
576 }
577
578 fn check_expr_type(
579 &self,
580 pattern_type: &hir::Type,
581 expr: &ast::Expr,
582 ) -> Result<(), MatchFailed> {
583 use hir::HirDisplay;
584 let code_type = self.sema.type_of_expr(&expr).ok_or_else(|| {
585 match_error!("Failed to get receiver type for `{}`", expr.syntax().text())
586 })?;
587 if !code_type
588 .autoderef(self.sema.db)
589 .any(|deref_code_type| *pattern_type == deref_code_type)
590 {
591 fail_match!(
592 "Pattern type `{}` didn't match code type `{}`",
593 pattern_type.display(self.sema.db),
594 code_type.display(self.sema.db)
595 );
596 }
597 Ok(())
598 }
599
555 fn get_placeholder(&self, element: &SyntaxElement) -> Option<&Placeholder> { 600 fn get_placeholder(&self, element: &SyntaxElement) -> Option<&Placeholder> {
556 only_ident(element.clone()).and_then(|ident| self.rule.get_placeholder(&ident)) 601 only_ident(element.clone()).and_then(|ident| self.rule.get_placeholder(&ident))
557 } 602 }
diff --git a/crates/ra_ssr/src/resolving.rs b/crates/ra_ssr/src/resolving.rs
index 7e7585c8b..bfc20705b 100644
--- a/crates/ra_ssr/src/resolving.rs
+++ b/crates/ra_ssr/src/resolving.rs
@@ -25,7 +25,7 @@ pub(crate) struct ResolvedPattern {
25 pub(crate) node: SyntaxNode, 25 pub(crate) node: SyntaxNode,
26 // Paths in `node` that we've resolved. 26 // Paths in `node` that we've resolved.
27 pub(crate) resolved_paths: FxHashMap<SyntaxNode, ResolvedPath>, 27 pub(crate) resolved_paths: FxHashMap<SyntaxNode, ResolvedPath>,
28 pub(crate) ufcs_function_calls: FxHashMap<SyntaxNode, hir::Function>, 28 pub(crate) ufcs_function_calls: FxHashMap<SyntaxNode, UfcsCallInfo>,
29 pub(crate) contains_self: bool, 29 pub(crate) contains_self: bool,
30} 30}
31 31
@@ -35,6 +35,12 @@ pub(crate) struct ResolvedPath {
35 pub(crate) depth: u32, 35 pub(crate) depth: u32,
36} 36}
37 37
38pub(crate) struct UfcsCallInfo {
39 pub(crate) call_expr: ast::CallExpr,
40 pub(crate) function: hir::Function,
41 pub(crate) qualifier_type: Option<hir::Type>,
42}
43
38impl ResolvedRule { 44impl ResolvedRule {
39 pub(crate) fn new( 45 pub(crate) fn new(
40 rule: parsing::ParsedRule, 46 rule: parsing::ParsedRule,
@@ -70,6 +76,7 @@ struct Resolver<'a, 'db> {
70 76
71impl Resolver<'_, '_> { 77impl Resolver<'_, '_> {
72 fn resolve_pattern_tree(&self, pattern: SyntaxNode) -> Result<ResolvedPattern, SsrError> { 78 fn resolve_pattern_tree(&self, pattern: SyntaxNode) -> Result<ResolvedPattern, SsrError> {
79 use syntax::ast::AstNode;
73 use syntax::{SyntaxElement, T}; 80 use syntax::{SyntaxElement, T};
74 let mut resolved_paths = FxHashMap::default(); 81 let mut resolved_paths = FxHashMap::default();
75 self.resolve(pattern.clone(), 0, &mut resolved_paths)?; 82 self.resolve(pattern.clone(), 0, &mut resolved_paths)?;
@@ -77,11 +84,15 @@ impl Resolver<'_, '_> {
77 .iter() 84 .iter()
78 .filter_map(|(path_node, resolved)| { 85 .filter_map(|(path_node, resolved)| {
79 if let Some(grandparent) = path_node.parent().and_then(|parent| parent.parent()) { 86 if let Some(grandparent) = path_node.parent().and_then(|parent| parent.parent()) {
80 if grandparent.kind() == SyntaxKind::CALL_EXPR { 87 if let Some(call_expr) = ast::CallExpr::cast(grandparent.clone()) {
81 if let hir::PathResolution::AssocItem(hir::AssocItem::Function(function)) = 88 if let hir::PathResolution::AssocItem(hir::AssocItem::Function(function)) =
82 &resolved.resolution 89 resolved.resolution
83 { 90 {
84 return Some((grandparent, *function)); 91 let qualifier_type = self.resolution_scope.qualifier_type(path_node);
92 return Some((
93 grandparent,
94 UfcsCallInfo { call_expr, function, qualifier_type },
95 ));
85 } 96 }
86 } 97 }
87 } 98 }
@@ -226,6 +237,20 @@ impl<'db> ResolutionScope<'db> {
226 None 237 None
227 } 238 }
228 } 239 }
240
241 fn qualifier_type(&self, path: &SyntaxNode) -> Option<hir::Type> {
242 use syntax::ast::AstNode;
243 if let Some(path) = ast::Path::cast(path.clone()) {
244 if let Some(qualifier) = path.qualifier() {
245 if let Some(resolved_qualifier) = self.resolve_path(&qualifier) {
246 if let hir::PathResolution::Def(hir::ModuleDef::Adt(adt)) = resolved_qualifier {
247 return Some(adt.ty(self.scope.db));
248 }
249 }
250 }
251 }
252 None
253 }
229} 254}
230 255
231fn is_self(path: &ast::Path) -> bool { 256fn is_self(path: &ast::Path) -> bool {
diff --git a/crates/ra_ssr/src/tests.rs b/crates/ra_ssr/src/tests.rs
index 7d4d470c0..4bc09c1e4 100644
--- a/crates/ra_ssr/src/tests.rs
+++ b/crates/ra_ssr/src/tests.rs
@@ -1143,3 +1143,32 @@ fn replace_self() {
1143 "#]], 1143 "#]],
1144 ); 1144 );
1145} 1145}
1146
1147#[test]
1148fn match_trait_method_call() {
1149 // `Bar::foo` and `Bar2::foo` resolve to the same function. Make sure we only match if the type
1150 // matches what's in the pattern. Also checks that we handle autoderef.
1151 let code = r#"
1152 pub struct Bar {}
1153 pub struct Bar2 {}
1154 pub trait Foo {
1155 fn foo(&self, _: i32) {}
1156 }
1157 impl Foo for Bar {}
1158 impl Foo for Bar2 {}
1159 fn main() {
1160 let v1 = Bar {};
1161 let v2 = Bar2 {};
1162 let v1_ref = &v1;
1163 let v2_ref = &v2;
1164 v1.foo(1);
1165 v2.foo(2);
1166 Bar::foo(&v1, 3);
1167 Bar2::foo(&v2, 4);
1168 v1_ref.foo(5);
1169 v2_ref.foo(6);
1170 }
1171 "#;
1172 assert_matches("Bar::foo($a, $b)", code, &["v1.foo(1)", "Bar::foo(&v1, 3)", "v1_ref.foo(5)"]);
1173 assert_matches("Bar2::foo($a, $b)", code, &["v2.foo(2)", "Bar2::foo(&v2, 4)", "v2_ref.foo(6)"]);
1174}