From 3340807bd24f398dca158e85eebae74012d8ef4b Mon Sep 17 00:00:00 2001 From: Marcus Klaas de Vries Date: Wed, 16 Jan 2019 20:26:58 +0100 Subject: Get basic struct pattern type inference working! --- crates/ra_hir/src/expr.rs | 39 ++++--- crates/ra_hir/src/ty.rs | 143 ++++++++++++++++-------- crates/ra_hir/src/ty/tests/data/adt_pattern.txt | 23 ++-- 3 files changed, 124 insertions(+), 81 deletions(-) (limited to 'crates') diff --git a/crates/ra_hir/src/expr.rs b/crates/ra_hir/src/expr.rs index c6d442ec4..893bad9cd 100644 --- a/crates/ra_hir/src/expr.rs +++ b/crates/ra_hir/src/expr.rs @@ -331,8 +331,8 @@ impl_arena_id!(PatId); #[derive(Debug, Clone, Eq, PartialEq)] pub struct FieldPat { - name: Name, - pat: Option, + pub(crate) name: Name, + pub(crate) pat: Option, } /// Close relative to rustc's hir::PatKind @@ -392,7 +392,9 @@ impl Pat { let total_iter = prefix.iter().chain(rest.iter()).chain(suffix.iter()); total_iter.map(|pat| *pat).for_each(f); } - Pat::Struct { .. } => {} // TODO + Pat::Struct { args, .. } => { + args.iter().filter_map(|a| a.pat).for_each(f); + } } } } @@ -814,23 +816,20 @@ impl ExprCollector { ast::PatKind::PlaceholderPat(_) => Pat::Wild, ast::PatKind::StructPat(p) => { let path = p.path().and_then(Path::from_ast); - - if let Some(field_list) = p.field_pat_list() { - let fields = field_list - .field_pats() - .into_iter() - .map(|f| FieldPat { - name: Name::new(f.ident), - pat: f.pat.as_ref().map(|p| self.collect_pat(p)), - }) - .collect(); - - Pat::Struct { - path: path, - args: fields, - } - } else { - Pat::Missing + let fields = p + .field_pat_list() + .expect("every struct should have a field list") + .field_pats() + .into_iter() + .map(|f| FieldPat { + name: Name::new(f.ident), + pat: f.pat.as_ref().map(|p| self.collect_pat(p)), + }) + .collect(); + + Pat::Struct { + path: path, + args: fields, } } diff --git a/crates/ra_hir/src/ty.rs b/crates/ra_hir/src/ty.rs index 3e1a4f02e..6bad61a2a 100644 --- a/crates/ra_hir/src/ty.rs +++ b/crates/ra_hir/src/ty.rs @@ -36,7 +36,7 @@ use crate::{ db::HirDatabase, type_ref::{TypeRef, Mutability}, name::KnownName, - expr::{Body, Expr, Literal, ExprId, Pat, PatId, UnaryOp, BinaryOp, Statement}, + expr::{Body, Expr, Literal, ExprId, Pat, PatId, UnaryOp, BinaryOp, Statement, FieldPat}, }; /// The ID of a type variable. @@ -872,6 +872,90 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { } } + fn resolve_fields(&self, path: Option<&Path>) -> Option<(Ty, Vec)> { + let def = path + .and_then(|path| self.module.resolve_path(self.db, &path).take_types()) + .map(|def_id| def_id.resolve(self.db)); + + let def = if let Some(def) = def { + def + } else { + return None; + }; + + match def { + Def::Struct(s) => { + let fields: Vec<_> = self + .db + .struct_data(s.def_id()) + .variant_data + .fields() + .iter() + .cloned() + .collect(); + Some((type_for_struct(self.db, s), fields)) + } + Def::EnumVariant(ev) => { + let fields: Vec<_> = ev.variant_data(self.db).fields().iter().cloned().collect(); + Some((type_for_enum_variant(self.db, ev), fields)) + } + _ => None, + } + } + + fn infer_tuple_struct(&mut self, path: Option<&Path>, sub_pats: &[PatId]) -> Ty { + let (ty, fields) = if let Some(x) = self.resolve_fields(path) { + x + } else { + return Ty::Unknown; + }; + + // walk subpats + if fields.len() != sub_pats.len() { + return Ty::Unknown; + } + + for (&sub_pat, field) in sub_pats.iter().zip(fields.iter()) { + let sub_ty = Ty::from_hir( + self.db, + &self.module, + self.impl_block.as_ref(), + &field.type_ref, + ); + + self.infer_pat(sub_pat, &Expectation::has_type(sub_ty)); + } + + ty + } + + fn infer_struct(&mut self, path: Option<&Path>, sub_pats: &[FieldPat]) -> Ty { + let (ty, fields) = if let Some(x) = self.resolve_fields(path) { + x + } else { + return Ty::Unknown; + }; + + for sub_pat in sub_pats { + let tyref = fields + .iter() + .find(|field| field.name == sub_pat.name) + .map(|field| &field.type_ref); + + if let Some(typeref) = tyref { + let sub_ty = Ty::from_hir(self.db, &self.module, self.impl_block.as_ref(), typeref); + + if let Some(pat) = sub_pat.pat { + self.infer_pat(pat, &Expectation::has_type(sub_ty)); + } else { + // TODO: deal with this case: S { x, y } + } + } + } + + ty + } + // FIXME: Expectation should probably contain a reference to a Ty instead of // a Ty itself fn infer_pat(&mut self, pat: PatId, expected: &Expectation) -> Ty { @@ -900,54 +984,15 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { path: ref p, args: ref sub_pats, }, - _expected, - ) => { - let def = p - .as_ref() - .and_then(|path| self.module.resolve_path(self.db, &path).take_types()) - .map(|def_id| def_id.resolve(self.db)); - - if let Some(def) = def { - let (ty, fields) = match def { - Def::Struct(s) => { - let fields: Vec<_> = self - .db - .struct_data(s.def_id()) - .variant_data - .fields() - .iter() - .cloned() - .collect(); - (type_for_struct(self.db, s), fields) - } - Def::EnumVariant(ev) => { - let fields: Vec<_> = - ev.variant_data(self.db).fields().iter().cloned().collect(); - (type_for_enum_variant(self.db, ev), fields) - } - _ => unreachable!(), - }; - // walk subpats - if fields.len() == sub_pats.len() { - for (&sub_pat, field) in sub_pats.iter().zip(fields.iter()) { - let sub_ty = Ty::from_hir( - self.db, - &self.module, - self.impl_block.as_ref(), - &field.type_ref, - ); - - self.infer_pat(sub_pat, &Expectation::has_type(sub_ty)); - } - - ty - } else { - expected.ty.clone() - } - } else { - expected.ty.clone() - } - } + _, + ) => self.infer_tuple_struct(p.as_ref(), sub_pats), + ( + &Pat::Struct { + path: ref p, + args: ref fields, + }, + _, + ) => self.infer_struct(p.as_ref(), fields), (_, ref _expected_ty) => expected.ty.clone(), }; // use a new type variable if we got Ty::Unknown here diff --git a/crates/ra_hir/src/ty/tests/data/adt_pattern.txt b/crates/ra_hir/src/ty/tests/data/adt_pattern.txt index d23b865a0..41e9c9d34 100644 --- a/crates/ra_hir/src/ty/tests/data/adt_pattern.txt +++ b/crates/ra_hir/src/ty/tests/data/adt_pattern.txt @@ -1,12 +1,11 @@ -[49; 192) '{ ... }; }': () -[59; 60) 'e': E -[63; 76) 'E::A { x: 3 }': E -[73; 74) '3': usize -[82; 124) 'if let... }': [unknown] -[105; 106) 'e': E -[107; 124) '{ ... }': [unknown] -[117; 118) 'x': [unknown] -[130; 189) 'match ... }': [unknown] -[136; 137) 'e': E -[162; 163) 'x': [unknown] -[181; 182) '1': i32 +[68; 155) '{ ...= e; }': () +[78; 79) 'e': E +[82; 95) 'E::A { x: 3 }': E +[92; 93) '3': usize +[106; 113) 'S(y, z)': S +[108; 109) 'y': u32 +[111; 112) 'z': E +[116; 119) 'foo': S +[129; 148) 'E::A {..._var }': E +[139; 146) 'new_var': usize +[151; 152) 'e': E -- cgit v1.2.3