From 9e4b5ecec4fa4f6a20bb4d47f09de602e9c29608 Mon Sep 17 00:00:00 2001 From: Florian Diebold Date: Sat, 19 Jan 2019 15:48:55 +0100 Subject: Make generics work in struct patterns --- crates/ra_hir/src/ty.rs | 61 +++++++++++++++++++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 17 deletions(-) (limited to 'crates/ra_hir/src/ty.rs') diff --git a/crates/ra_hir/src/ty.rs b/crates/ra_hir/src/ty.rs index 1d2d1b906..3608daae4 100644 --- a/crates/ra_hir/src/ty.rs +++ b/crates/ra_hir/src/ty.rs @@ -683,9 +683,9 @@ pub(super) fn type_for_def(db: &impl HirDatabase, def_id: DefId) -> Ty { pub(super) fn type_for_field(db: &impl HirDatabase, def_id: DefId, field: Name) -> Option { let def = def_id.resolve(db); - let variant_data = match def { - Def::Struct(s) => s.variant_data(db), - Def::EnumVariant(ev) => ev.variant_data(db), + let (variant_data, generics) = match def { + Def::Struct(s) => (s.variant_data(db), s.generics(db)), + Def::EnumVariant(ev) => (ev.variant_data(db), ev.parent_enum(db).generics(db)), // TODO: unions _ => panic!( "trying to get type for field in non-struct/variant {:?}", @@ -694,7 +694,6 @@ pub(super) fn type_for_field(db: &impl HirDatabase, def_id: DefId, field: Name) }; let module = def_id.module(db); let impl_block = def_id.impl_block(db); - let generics = db.generics(def_id); let type_ref = variant_data.get_field_type_ref(&field)?; Some(Ty::from_hir( db, @@ -893,6 +892,14 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { ty } + fn unify_substs(&mut self, substs1: &Substs, substs2: &Substs) -> bool { + substs1 + .0 + .iter() + .zip(substs2.0.iter()) + .all(|(t1, t2)| self.unify(t1, t2)) + } + fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool { // try to resolve type vars first let ty1 = self.resolve_ty_shallow(ty1); @@ -913,12 +920,16 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { (Ty::Bool, _) | (Ty::Str, _) | (Ty::Never, _) | (Ty::Char, _) => ty1 == ty2, ( Ty::Adt { - def_id: def_id1, .. + def_id: def_id1, + substs: substs1, + .. }, Ty::Adt { - def_id: def_id2, .. + def_id: def_id2, + substs: substs2, + .. }, - ) if def_id1 == def_id2 => true, + ) if def_id1 == def_id2 => self.unify_substs(substs1, substs2), (Ty::Slice(t1), Ty::Slice(t2)) => self.unify(t1, t2), (Ty::RawPtr(t1, m1), Ty::RawPtr(t2, m2)) if m1 == m2 => self.unify(t1, t2), (Ty::Ref(t1, m1), Ty::Ref(t2, m2)) if m1 == m2 => self.unify(t1, t2), @@ -1088,49 +1099,65 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { } } - fn resolve_fields(&self, path: Option<&Path>) -> Option<(Ty, Vec)> { - let def_id = self.module.resolve_path(self.db, path?).take_types()?; + fn resolve_fields(&mut self, path: Option<&Path>) -> Option<(Ty, Vec)> { + let (ty, def_id) = self.resolve_variant(path); + let def_id = def_id?; let def = def_id.resolve(self.db); match def { Def::Struct(s) => { let fields = s.fields(self.db); - Some((type_for_struct(self.db, s), fields)) + Some((ty, fields)) } Def::EnumVariant(ev) => { let fields = ev.fields(self.db); - Some((type_for_enum_variant(self.db, ev), fields)) + Some((ty, fields)) } _ => None, } } - fn infer_tuple_struct_pat(&mut self, path: Option<&Path>, subpats: &[PatId]) -> Ty { + fn infer_tuple_struct_pat( + &mut self, + path: Option<&Path>, + subpats: &[PatId], + expected: &Ty, + ) -> Ty { let (ty, fields) = self .resolve_fields(path) .unwrap_or((Ty::Unknown, Vec::new())); + self.unify(&ty, expected); + + let substs = ty.substs().expect("adt should have substs"); + for (i, &subpat) in subpats.iter().enumerate() { let expected_ty = fields .get(i) .and_then(|field| field.ty(self.db)) - .unwrap_or(Ty::Unknown); + .unwrap_or(Ty::Unknown) + .subst(&substs); self.infer_pat(subpat, &expected_ty); } ty } - fn infer_struct_pat(&mut self, path: Option<&Path>, subpats: &[FieldPat]) -> Ty { + fn infer_struct_pat(&mut self, path: Option<&Path>, subpats: &[FieldPat], expected: &Ty) -> Ty { let (ty, fields) = self .resolve_fields(path) .unwrap_or((Ty::Unknown, Vec::new())); + self.unify(&ty, expected); + + let substs = ty.substs().expect("adt should have substs"); + for subpat in subpats { let matching_field = fields.iter().find(|field| field.name() == &subpat.name); let expected_ty = matching_field .and_then(|field| field.ty(self.db)) - .unwrap_or(Ty::Unknown); + .unwrap_or(Ty::Unknown) + .subst(&substs); self.infer_pat(subpat.pat, &expected_ty); } @@ -1175,11 +1202,11 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { Pat::TupleStruct { path: ref p, args: ref subpats, - } => self.infer_tuple_struct_pat(p.as_ref(), subpats), + } => self.infer_tuple_struct_pat(p.as_ref(), subpats, expected), Pat::Struct { path: ref p, args: ref fields, - } => self.infer_struct_pat(p.as_ref(), fields), + } => self.infer_struct_pat(p.as_ref(), fields, expected), Pat::Path(path) => self .module .resolve_path(self.db, &path) -- cgit v1.2.3