From 6a7fc76b89dca4d1b4e3e50047183535aee98627 Mon Sep 17 00:00:00 2001 From: Florian Diebold Date: Fri, 17 Apr 2020 19:41:37 +0200 Subject: Fix type equality for dyn Trait Fixes a lot of false type mismatches. (And as always when touching the unification code, I have to say I'm looking forward to replacing it by Chalk's...) --- crates/ra_hir_ty/src/infer/coerce.rs | 4 ++-- crates/ra_hir_ty/src/infer/unify.rs | 42 +++++++++++++++++++++++++++++++++--- crates/ra_hir_ty/src/tests/traits.rs | 24 +++++++++++++++++++++ 3 files changed, 65 insertions(+), 5 deletions(-) (limited to 'crates') diff --git a/crates/ra_hir_ty/src/infer/coerce.rs b/crates/ra_hir_ty/src/infer/coerce.rs index 959b1e212..89200255a 100644 --- a/crates/ra_hir_ty/src/infer/coerce.rs +++ b/crates/ra_hir_ty/src/infer/coerce.rs @@ -51,7 +51,7 @@ impl<'a> InferenceContext<'a> { // Trivial cases, this should go after `never` check to // avoid infer result type to be never _ => { - if self.table.unify_inner_trivial(&from_ty, &to_ty) { + if self.table.unify_inner_trivial(&from_ty, &to_ty, 0) { return true; } } @@ -175,7 +175,7 @@ impl<'a> InferenceContext<'a> { return self.table.unify_substs(st1, st2, 0); } _ => { - if self.table.unify_inner_trivial(&derefed_ty, &to_ty) { + if self.table.unify_inner_trivial(&derefed_ty, &to_ty, 0) { return true; } } diff --git a/crates/ra_hir_ty/src/infer/unify.rs b/crates/ra_hir_ty/src/infer/unify.rs index 5f6cea8d3..ab0bc8b70 100644 --- a/crates/ra_hir_ty/src/infer/unify.rs +++ b/crates/ra_hir_ty/src/infer/unify.rs @@ -8,7 +8,8 @@ use test_utils::tested_by; use super::{InferenceContext, Obligation}; use crate::{ - BoundVar, Canonical, DebruijnIndex, InEnvironment, InferTy, Substs, Ty, TypeCtor, TypeWalk, + BoundVar, Canonical, DebruijnIndex, GenericPredicate, InEnvironment, InferTy, Substs, Ty, + TypeCtor, TypeWalk, }; impl<'a> InferenceContext<'a> { @@ -226,16 +227,26 @@ impl InferenceTable { (Ty::Apply(a_ty1), Ty::Apply(a_ty2)) if a_ty1.ctor == a_ty2.ctor => { self.unify_substs(&a_ty1.parameters, &a_ty2.parameters, depth + 1) } - _ => self.unify_inner_trivial(&ty1, &ty2), + + _ => self.unify_inner_trivial(&ty1, &ty2, depth), } } - pub(super) fn unify_inner_trivial(&mut self, ty1: &Ty, ty2: &Ty) -> bool { + pub(super) fn unify_inner_trivial(&mut self, ty1: &Ty, ty2: &Ty, depth: usize) -> bool { match (ty1, ty2) { (Ty::Unknown, _) | (_, Ty::Unknown) => true, (Ty::Placeholder(p1), Ty::Placeholder(p2)) if *p1 == *p2 => true, + (Ty::Dyn(dyn1), Ty::Dyn(dyn2)) if dyn1.len() == dyn2.len() => { + for (pred1, pred2) in dyn1.iter().zip(dyn2.iter()) { + if !self.unify_preds(pred1, pred2, depth + 1) { + return false; + } + } + true + } + (Ty::Infer(InferTy::TypeVar(tv1)), Ty::Infer(InferTy::TypeVar(tv2))) | (Ty::Infer(InferTy::IntVar(tv1)), Ty::Infer(InferTy::IntVar(tv2))) | (Ty::Infer(InferTy::FloatVar(tv1)), Ty::Infer(InferTy::FloatVar(tv2))) @@ -268,6 +279,31 @@ impl InferenceTable { } } + fn unify_preds( + &mut self, + pred1: &GenericPredicate, + pred2: &GenericPredicate, + depth: usize, + ) -> bool { + match (pred1, pred2) { + (GenericPredicate::Implemented(tr1), GenericPredicate::Implemented(tr2)) + if tr1.trait_ == tr2.trait_ => + { + self.unify_substs(&tr1.substs, &tr2.substs, depth + 1) + } + (GenericPredicate::Projection(proj1), GenericPredicate::Projection(proj2)) + if proj1.projection_ty.associated_ty == proj2.projection_ty.associated_ty => + { + self.unify_substs( + &proj1.projection_ty.parameters, + &proj2.projection_ty.parameters, + depth + 1, + ) && self.unify_inner(&proj1.ty, &proj2.ty, depth + 1) + } + _ => false, + } + } + /// If `ty` is a type variable with known type, returns that type; /// otherwise, return ty. pub fn resolve_ty_shallow<'b>(&mut self, ty: &'b Ty) -> Cow<'b, Ty> { diff --git a/crates/ra_hir_ty/src/tests/traits.rs b/crates/ra_hir_ty/src/tests/traits.rs index dc517fc4a..f6e3e07cd 100644 --- a/crates/ra_hir_ty/src/tests/traits.rs +++ b/crates/ra_hir_ty/src/tests/traits.rs @@ -2378,3 +2378,27 @@ fn main() { ); assert_eq!(t, "Foo"); } + +#[test] +fn trait_object_no_coercion() { + assert_snapshot!( + infer_with_mismatches(r#" +trait Foo {} + +fn foo(x: &dyn Foo) {} + +fn test(x: &dyn Foo) { + foo(x); +} +"#, true), + @r###" + [22; 23) 'x': &dyn Foo + [35; 37) '{}': () + [47; 48) 'x': &dyn Foo + [60; 75) '{ foo(x); }': () + [66; 69) 'foo': fn foo(&dyn Foo) + [66; 72) 'foo(x)': () + [70; 71) 'x': &dyn Foo + "### + ); +} -- cgit v1.2.3