From 33aa2f8e4f2b9c7c3a6b28427cb6d6f2aef7b802 Mon Sep 17 00:00:00 2001
From: Florian Diebold <florian.diebold@freiheit.com>
Date: Fri, 31 Jan 2020 15:57:44 +0100
Subject: Fix assoc type selection

---
 crates/ra_hir_ty/src/lower.rs        | 69 +++++++++++++++++++++---------------
 crates/ra_hir_ty/src/tests/traits.rs |  6 ++--
 2 files changed, 43 insertions(+), 32 deletions(-)

diff --git a/crates/ra_hir_ty/src/lower.rs b/crates/ra_hir_ty/src/lower.rs
index f1a11e073..5138019c7 100644
--- a/crates/ra_hir_ty/src/lower.rs
+++ b/crates/ra_hir_ty/src/lower.rs
@@ -10,7 +10,7 @@ use std::sync::Arc;
 
 use hir_def::{
     builtin_type::BuiltinType,
-    generics::{WherePredicateTarget, WherePredicate},
+    generics::{WherePredicate, WherePredicateTarget},
     path::{GenericArg, Path, PathSegment, PathSegments},
     resolver::{HasResolver, Resolver, TypeNs},
     type_ref::{TypeBound, TypeRef},
@@ -27,8 +27,8 @@ use crate::{
         all_super_traits, associated_type_by_name_including_super_traits, generics, make_mut_slice,
         variant_data,
     },
-    FnSig, GenericPredicate, ProjectionPredicate, ProjectionTy, Substs, TraitEnvironment, TraitRef,
-    Ty, TypeCtor, PolyFnSig, Binders,
+    Binders, FnSig, GenericPredicate, PolyFnSig, ProjectionPredicate, ProjectionTy, Substs,
+    TraitEnvironment, TraitRef, Ty, TypeCtor,
 };
 
 #[derive(Debug)]
@@ -62,7 +62,7 @@ impl<'a, DB: HirDatabase> TyLoweringContext<'a, DB> {
     }
 }
 
-#[derive(Copy, Clone, Debug)]
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 pub enum ImplTraitLoweringMode {
     /// `impl Trait` gets lowered into an opaque type that doesn't unify with
     /// anything except itself. This is used in places where values flow 'out',
@@ -78,7 +78,7 @@ pub enum ImplTraitLoweringMode {
     Disallowed,
 }
 
-#[derive(Copy, Clone, Debug)]
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 pub enum TypeParamLoweringMode {
     Placeholder,
     Variable,
@@ -140,12 +140,13 @@ impl Ty {
                     ImplTraitLoweringMode::Variable => {
                         let idx = ctx.impl_trait_counter.get();
                         ctx.impl_trait_counter.set(idx + 1);
-                        let (self_params, list_params, _impl_trait_params) = if let Some(def) = ctx.resolver.generic_def() {
-                            let generics = generics(ctx.db, def);
-                            generics.provenance_split()
-                        } else {
-                            (0, 0, 0)
-                        };
+                        let (self_params, list_params, _impl_trait_params) =
+                            if let Some(def) = ctx.resolver.generic_def() {
+                                let generics = generics(ctx.db, def);
+                                generics.provenance_split()
+                            } else {
+                                (0, 0, 0)
+                            };
                         // assert!((idx as usize) < impl_trait_params); // TODO return position impl trait
                         Ty::Bound(idx as u32 + self_params as u32 + list_params as u32)
                     }
@@ -251,7 +252,7 @@ impl Ty {
                         // FIXME: maybe return name in resolution?
                         let name = generics.param_name(param_id);
                         Ty::Param { idx, name }
-                    },
+                    }
                     TypeParamLoweringMode::Variable => Ty::Bound(idx),
                 }
             }
@@ -262,7 +263,7 @@ impl Ty {
                     TypeParamLoweringMode::Variable => Substs::bound_vars(&generics),
                 };
                 ctx.db.impl_self_ty(impl_id).subst(&substs)
-            },
+            }
             TypeNs::AdtSelfType(adt) => {
                 let generics = generics(ctx.db, adt.into());
                 let substs = match ctx.type_param_mode {
@@ -270,7 +271,7 @@ impl Ty {
                     TypeParamLoweringMode::Variable => Substs::bound_vars(&generics),
                 };
                 ctx.db.ty(adt.into()).subst(&substs)
-            },
+            }
 
             TypeNs::AdtId(it) => Ty::from_hir_path_inner(ctx, resolved_segment, it.into()),
             TypeNs::BuiltinType(it) => Ty::from_hir_path_inner(ctx, resolved_segment, it.into()),
@@ -309,7 +310,8 @@ impl Ty {
         segment: PathSegment<'_>,
     ) -> Ty {
         let param_idx = match self_ty {
-            Ty::Param { idx, .. } => idx,
+            Ty::Param { idx, .. } if ctx.type_param_mode == TypeParamLoweringMode::Placeholder => idx,
+            Ty::Bound(idx) if ctx.type_param_mode == TypeParamLoweringMode::Variable => idx,
             _ => return Ty::Unknown, // Error: Ambiguous associated type
         };
         let def = match ctx.resolver.generic_def() {
@@ -318,7 +320,14 @@ impl Ty {
         };
         let predicates = ctx.db.generic_predicates_for_param(def.into(), param_idx);
         let traits_from_env = predicates.iter().filter_map(|pred| match pred {
-            GenericPredicate::Implemented(tr) if tr.self_ty() == &self_ty => Some(tr.trait_),
+            GenericPredicate::Implemented(tr) => {
+                if let Ty::Param { idx, .. } = tr.self_ty() {
+                    if *idx == param_idx {
+                        return Some(tr.trait_);
+                    }
+                }
+                None
+            }
             _ => None,
         });
         let traits = traits_from_env.flat_map(|t| all_super_traits(ctx.db, t));
@@ -516,10 +525,10 @@ impl GenericPredicate {
                     TypeParamLoweringMode::Placeholder => {
                         let name = generics.param_name(param_id);
                         Ty::Param { idx, name }
-                    },
+                    }
                     TypeParamLoweringMode::Variable => Ty::Bound(idx),
                 }
-            },
+            }
         };
         GenericPredicate::from_type_bound(ctx, &where_predicate.bound, self_ty)
     }
@@ -615,7 +624,9 @@ pub(crate) fn generic_predicates_for_param_query(
         .where_predicates_in_scope()
         // we have to filter out all other predicates *first*, before attempting to lower them
         .filter(|pred| match &pred.target {
-            WherePredicateTarget::TypeRef(type_ref) => Ty::from_hir_only_param(&ctx, type_ref) == Some(param_idx),
+            WherePredicateTarget::TypeRef(type_ref) => {
+                Ty::from_hir_only_param(&ctx, type_ref) == Some(param_idx)
+            }
             WherePredicateTarget::TypeParam(local_id) => {
                 let param_id = hir_def::TypeParamId { parent: def, local_id: *local_id };
                 let idx = generics.param_idx(param_id);
@@ -701,8 +712,8 @@ fn type_for_const(db: &impl HirDatabase, def: ConstId) -> Binders<Ty> {
     let data = db.const_data(def);
     let generics = generics(db, def.into());
     let resolver = def.resolver(db);
-    let ctx = TyLoweringContext::new(db, &resolver)
-        .with_type_param_mode(TypeParamLoweringMode::Variable);
+    let ctx =
+        TyLoweringContext::new(db, &resolver).with_type_param_mode(TypeParamLoweringMode::Variable);
 
     Binders::new(generics.len(), Ty::from_hir(&ctx, &data.type_ref))
 }
@@ -731,8 +742,8 @@ fn fn_sig_for_struct_constructor(db: &impl HirDatabase, def: StructId) -> PolyFn
     let struct_data = db.struct_data(def.into());
     let fields = struct_data.variant_data.fields();
     let resolver = def.resolver(db);
-    let ctx = TyLoweringContext::new(db, &resolver)
-        .with_type_param_mode(TypeParamLoweringMode::Variable);
+    let ctx =
+        TyLoweringContext::new(db, &resolver).with_type_param_mode(TypeParamLoweringMode::Variable);
     let params =
         fields.iter().map(|(_, field)| Ty::from_hir(&ctx, &field.type_ref)).collect::<Vec<_>>();
     let ret = type_for_adt(db, def.into());
@@ -755,8 +766,8 @@ fn fn_sig_for_enum_variant_constructor(db: &impl HirDatabase, def: EnumVariantId
     let var_data = &enum_data.variants[def.local_id];
     let fields = var_data.variant_data.fields();
     let resolver = def.parent.resolver(db);
-    let ctx = TyLoweringContext::new(db, &resolver)
-        .with_type_param_mode(TypeParamLoweringMode::Variable);
+    let ctx =
+        TyLoweringContext::new(db, &resolver).with_type_param_mode(TypeParamLoweringMode::Variable);
     let params =
         fields.iter().map(|(_, field)| Ty::from_hir(&ctx, &field.type_ref)).collect::<Vec<_>>();
     let ret = type_for_adt(db, def.parent.into());
@@ -784,8 +795,8 @@ fn type_for_adt(db: &impl HirDatabase, adt: AdtId) -> Binders<Ty> {
 fn type_for_type_alias(db: &impl HirDatabase, t: TypeAliasId) -> Binders<Ty> {
     let generics = generics(db, t.into());
     let resolver = t.resolver(db);
-    let ctx = TyLoweringContext::new(db, &resolver)
-        .with_type_param_mode(TypeParamLoweringMode::Variable);
+    let ctx =
+        TyLoweringContext::new(db, &resolver).with_type_param_mode(TypeParamLoweringMode::Variable);
     let type_ref = &db.type_alias_data(t).type_ref;
     let substs = Substs::bound_vars(&generics);
     let inner = Ty::from_hir(&ctx, type_ref.as_ref().unwrap_or(&TypeRef::Error));
@@ -870,8 +881,8 @@ pub(crate) fn impl_self_ty_query(db: &impl HirDatabase, impl_id: ImplId) -> Bind
     let impl_data = db.impl_data(impl_id);
     let resolver = impl_id.resolver(db);
     let generics = generics(db, impl_id.into());
-    let ctx = TyLoweringContext::new(db, &resolver)
-        .with_type_param_mode(TypeParamLoweringMode::Variable);
+    let ctx =
+        TyLoweringContext::new(db, &resolver).with_type_param_mode(TypeParamLoweringMode::Variable);
     Binders::new(generics.len(), Ty::from_hir(&ctx, &impl_data.target_type))
 }
 
diff --git a/crates/ra_hir_ty/src/tests/traits.rs b/crates/ra_hir_ty/src/tests/traits.rs
index 9ff396ad5..e2351ca98 100644
--- a/crates/ra_hir_ty/src/tests/traits.rs
+++ b/crates/ra_hir_ty/src/tests/traits.rs
@@ -358,15 +358,15 @@ fn test() {
     [221; 223) '{}': ()
     [234; 300) '{     ...(S); }': ()
     [244; 245) 'x': u32
-    [248; 252) 'foo1': fn foo1<S>(T) -> <T as Iterable>::Item
+    [248; 252) 'foo1': fn foo1<S>(S) -> <S as Iterable>::Item
     [248; 255) 'foo1(S)': u32
     [253; 254) 'S': S
     [265; 266) 'y': u32
-    [269; 273) 'foo2': fn foo2<S>(T) -> <T as Iterable>::Item
+    [269; 273) 'foo2': fn foo2<S>(S) -> <S as Iterable>::Item
     [269; 276) 'foo2(S)': u32
     [274; 275) 'S': S
     [286; 287) 'z': u32
-    [290; 294) 'foo3': fn foo3<S>(T) -> <T as Iterable>::Item
+    [290; 294) 'foo3': fn foo3<S>(S) -> <S as Iterable>::Item
     [290; 297) 'foo3(S)': u32
     [295; 296) 'S': S
     "###
-- 
cgit v1.2.3