From 7a0c93c58ac17b089edd8c9763fef303b7a81414 Mon Sep 17 00:00:00 2001 From: Florian Diebold Date: Sun, 23 May 2021 18:10:40 +0200 Subject: Infer correct expected type for generic struct fields --- crates/hir/src/lib.rs | 18 +++++++++++++++--- crates/hir_def/src/lib.rs | 8 ++++++++ crates/ide_completion/src/context.rs | 32 ++++++++++++++++---------------- crates/ide_completion/src/render.rs | 7 +++++++ 4 files changed, 46 insertions(+), 19 deletions(-) diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index a7c42ca1e..edee99356 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -513,9 +513,9 @@ impl Field { } /// Returns the type as in the signature of the struct (i.e., with - /// placeholder types for type parameters). This is good for showing - /// signature help, but not so good to actually get the type of the field - /// when you actually have a variable of the struct. + /// placeholder types for type parameters). Only use this in the context of + /// the field *definition*; if you've already got a variable of the struct + /// type, use `Type::field_type` to get to the field type. pub fn ty(&self, db: &dyn HirDatabase) -> Type { let var_id = self.parent.into(); let generic_def_id: GenericDefId = match self.parent { @@ -1944,6 +1944,18 @@ impl Type { } } + pub fn field_type(&self, db: &dyn HirDatabase, field: Field) -> Option { + let (adt_id, substs) = self.ty.as_adt()?; + let variant_id: hir_def::VariantId = field.parent.into(); + if variant_id.adt_id() != adt_id { + return None; + } + + let ty = db.field_types(variant_id).get(field.id)?.clone(); + let ty = ty.substitute(&Interner, substs); + Some(self.derived(ty)) + } + pub fn fields(&self, db: &dyn HirDatabase) -> Vec<(Field, Type)> { let (variant_id, substs) = match self.ty.kind(&Interner) { &TyKind::Adt(hir_ty::AdtId(AdtId::StructId(s)), ref substs) => (s.into(), substs), diff --git a/crates/hir_def/src/lib.rs b/crates/hir_def/src/lib.rs index a82ea5957..70001cac8 100644 --- a/crates/hir_def/src/lib.rs +++ b/crates/hir_def/src/lib.rs @@ -485,6 +485,14 @@ impl VariantId { VariantId::UnionId(it) => it.lookup(db).id.file_id(), } } + + pub fn adt_id(self) -> AdtId { + match self { + VariantId::EnumVariantId(it) => it.parent.into(), + VariantId::StructId(it) => it.into(), + VariantId::UnionId(it) => it.into(), + } + } } trait Intern { diff --git a/crates/ide_completion/src/context.rs b/crates/ide_completion/src/context.rs index c929d7394..4a88a6e88 100644 --- a/crates/ide_completion/src/context.rs +++ b/crates/ide_completion/src/context.rs @@ -337,25 +337,25 @@ impl<'a> CompletionContext<'a> { }, ast::RecordExprFieldList(_it) => { cov_mark::hit!(expected_type_struct_field_without_leading_char); - self.token.prev_sibling_or_token() - .and_then(|se| se.into_node()) - .and_then(|node| ast::RecordExprField::cast(node)) - .and_then(|rf| self.sema.resolve_record_field(&rf).zip(Some(rf))) - .map(|(f, rf)|( - Some(f.0.ty(self.db)), - rf.field_name().map(NameOrNameRef::NameRef), + // wouldn't try {} be nice... + (|| { + let record_ty = self.sema.type_of_expr(&ast::Expr::cast(node.parent()?)?)?; + let expr_field = self.token.prev_sibling_or_token()? + .into_node() + .and_then(|node| ast::RecordExprField::cast(node))?; + let field = self.sema.resolve_record_field(&expr_field)?.0; + Some(( + record_ty.field_type(self.db, field), + expr_field.field_name().map(NameOrNameRef::NameRef), )) - .unwrap_or((None, None)) + })().unwrap_or((None, None)) }, ast::RecordExprField(it) => { cov_mark::hit!(expected_type_struct_field_with_leading_char); - self.sema - .resolve_record_field(&it) - .map(|f|( - Some(f.0.ty(self.db)), - it.field_name().map(NameOrNameRef::NameRef), - )) - .unwrap_or((None, None)) + ( + it.expr().as_ref().and_then(|e| self.sema.type_of_expr(e)), + it.field_name().map(NameOrNameRef::NameRef), + ) }, ast::MatchExpr(it) => { cov_mark::hit!(expected_type_match_arm_without_leading_char); @@ -910,7 +910,7 @@ fn foo() -> u32 { } #[test] - fn expected_type_closure_param() { + fn expected_type_closure_param_return() { check_expected_type_and_name( r#" fn foo() { diff --git a/crates/ide_completion/src/render.rs b/crates/ide_completion/src/render.rs index 6b04ee164..d7f96b864 100644 --- a/crates/ide_completion/src/render.rs +++ b/crates/ide_completion/src/render.rs @@ -667,6 +667,13 @@ fn foo() { A { the$0 } } ), detail: "u32", deprecated: true, + relevance: CompletionRelevance { + exact_name_match: false, + type_match: Some( + CouldUnify, + ), + is_local: false, + }, }, ] "#]], -- cgit v1.2.3