From 9e4b5ecec4fa4f6a20bb4d47f09de602e9c29608 Mon Sep 17 00:00:00 2001 From: Florian Diebold Date: Sat, 19 Jan 2019 15:48:55 +0100 Subject: Make generics work in struct patterns --- crates/ra_hir/src/ty.rs | 61 ++++++++++++++++------ crates/ra_hir/src/ty/tests.rs | 26 +++++++++ .../src/ty/tests/data/generics_in_patterns.txt | 17 ++++++ 3 files changed, 87 insertions(+), 17 deletions(-) create mode 100644 crates/ra_hir/src/ty/tests/data/generics_in_patterns.txt (limited to 'crates/ra_hir/src') diff --git a/crates/ra_hir/src/ty.rs b/crates/ra_hir/src/ty.rs index 1d2d1b906..3608daae4 100644 --- a/crates/ra_hir/src/ty.rs +++ b/crates/ra_hir/src/ty.rs @@ -683,9 +683,9 @@ pub(super) fn type_for_def(db: &impl HirDatabase, def_id: DefId) -> Ty { pub(super) fn type_for_field(db: &impl HirDatabase, def_id: DefId, field: Name) -> Option { let def = def_id.resolve(db); - let variant_data = match def { - Def::Struct(s) => s.variant_data(db), - Def::EnumVariant(ev) => ev.variant_data(db), + let (variant_data, generics) = match def { + Def::Struct(s) => (s.variant_data(db), s.generics(db)), + Def::EnumVariant(ev) => (ev.variant_data(db), ev.parent_enum(db).generics(db)), // TODO: unions _ => panic!( "trying to get type for field in non-struct/variant {:?}", @@ -694,7 +694,6 @@ pub(super) fn type_for_field(db: &impl HirDatabase, def_id: DefId, field: Name) }; let module = def_id.module(db); let impl_block = def_id.impl_block(db); - let generics = db.generics(def_id); let type_ref = variant_data.get_field_type_ref(&field)?; Some(Ty::from_hir( db, @@ -893,6 +892,14 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { ty } + fn unify_substs(&mut self, substs1: &Substs, substs2: &Substs) -> bool { + substs1 + .0 + .iter() + .zip(substs2.0.iter()) + .all(|(t1, t2)| self.unify(t1, t2)) + } + fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool { // try to resolve type vars first let ty1 = self.resolve_ty_shallow(ty1); @@ -913,12 +920,16 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { (Ty::Bool, _) | (Ty::Str, _) | (Ty::Never, _) | (Ty::Char, _) => ty1 == ty2, ( Ty::Adt { - def_id: def_id1, .. + def_id: def_id1, + substs: substs1, + .. }, Ty::Adt { - def_id: def_id2, .. + def_id: def_id2, + substs: substs2, + .. }, - ) if def_id1 == def_id2 => true, + ) if def_id1 == def_id2 => self.unify_substs(substs1, substs2), (Ty::Slice(t1), Ty::Slice(t2)) => self.unify(t1, t2), (Ty::RawPtr(t1, m1), Ty::RawPtr(t2, m2)) if m1 == m2 => self.unify(t1, t2), (Ty::Ref(t1, m1), Ty::Ref(t2, m2)) if m1 == m2 => self.unify(t1, t2), @@ -1088,49 +1099,65 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { } } - fn resolve_fields(&self, path: Option<&Path>) -> Option<(Ty, Vec)> { - let def_id = self.module.resolve_path(self.db, path?).take_types()?; + fn resolve_fields(&mut self, path: Option<&Path>) -> Option<(Ty, Vec)> { + let (ty, def_id) = self.resolve_variant(path); + let def_id = def_id?; let def = def_id.resolve(self.db); match def { Def::Struct(s) => { let fields = s.fields(self.db); - Some((type_for_struct(self.db, s), fields)) + Some((ty, fields)) } Def::EnumVariant(ev) => { let fields = ev.fields(self.db); - Some((type_for_enum_variant(self.db, ev), fields)) + Some((ty, fields)) } _ => None, } } - fn infer_tuple_struct_pat(&mut self, path: Option<&Path>, subpats: &[PatId]) -> Ty { + fn infer_tuple_struct_pat( + &mut self, + path: Option<&Path>, + subpats: &[PatId], + expected: &Ty, + ) -> Ty { let (ty, fields) = self .resolve_fields(path) .unwrap_or((Ty::Unknown, Vec::new())); + self.unify(&ty, expected); + + let substs = ty.substs().expect("adt should have substs"); + for (i, &subpat) in subpats.iter().enumerate() { let expected_ty = fields .get(i) .and_then(|field| field.ty(self.db)) - .unwrap_or(Ty::Unknown); + .unwrap_or(Ty::Unknown) + .subst(&substs); self.infer_pat(subpat, &expected_ty); } ty } - fn infer_struct_pat(&mut self, path: Option<&Path>, subpats: &[FieldPat]) -> Ty { + fn infer_struct_pat(&mut self, path: Option<&Path>, subpats: &[FieldPat], expected: &Ty) -> Ty { let (ty, fields) = self .resolve_fields(path) .unwrap_or((Ty::Unknown, Vec::new())); + self.unify(&ty, expected); + + let substs = ty.substs().expect("adt should have substs"); + for subpat in subpats { let matching_field = fields.iter().find(|field| field.name() == &subpat.name); let expected_ty = matching_field .and_then(|field| field.ty(self.db)) - .unwrap_or(Ty::Unknown); + .unwrap_or(Ty::Unknown) + .subst(&substs); self.infer_pat(subpat.pat, &expected_ty); } @@ -1175,11 +1202,11 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { Pat::TupleStruct { path: ref p, args: ref subpats, - } => self.infer_tuple_struct_pat(p.as_ref(), subpats), + } => self.infer_tuple_struct_pat(p.as_ref(), subpats, expected), Pat::Struct { path: ref p, args: ref fields, - } => self.infer_struct_pat(p.as_ref(), fields), + } => self.infer_struct_pat(p.as_ref(), fields, expected), Pat::Path(path) => self .module .resolve_path(self.db, &path) diff --git a/crates/ra_hir/src/ty/tests.rs b/crates/ra_hir/src/ty/tests.rs index c590a09db..06e32df59 100644 --- a/crates/ra_hir/src/ty/tests.rs +++ b/crates/ra_hir/src/ty/tests.rs @@ -438,6 +438,32 @@ fn test(a1: A, i: i32) { ); } +#[test] +fn infer_generics_in_patterns() { + check_inference( + r#" +struct A { + x: T, +} + +enum Option { + Some(T), + None, +} + +fn test(a1: A, o: Option) { + let A { x: x2 } = a1; + let A:: { x: x3 } = A { x: 1 }; + match o { + Option::Some(t) => t, + _ => 1, + }; +} +"#, + "generics_in_patterns.txt", + ); +} + #[test] fn infer_function_generics() { check_inference( diff --git a/crates/ra_hir/src/ty/tests/data/generics_in_patterns.txt b/crates/ra_hir/src/ty/tests/data/generics_in_patterns.txt new file mode 100644 index 000000000..1b01ef19e --- /dev/null +++ b/crates/ra_hir/src/ty/tests/data/generics_in_patterns.txt @@ -0,0 +1,17 @@ +[79; 81) 'a1': A +[91; 92) 'o': Option +[107; 244) '{ ... }; }': () +[117; 128) 'A { x: x2 }': A +[124; 126) 'x2': u32 +[131; 133) 'a1': A +[143; 161) 'A:: +[157; 159) 'x3': i64 +[164; 174) 'A { x: 1 }': A +[171; 172) '1': i64 +[180; 241) 'match ... }': u64 +[186; 187) 'o': Option +[198; 213) 'Option::Some(t)': Option +[211; 212) 't': u64 +[217; 218) 't': u64 +[228; 229) '_': Option +[233; 234) '1': u64 -- cgit v1.2.3