From 978de5cf8bfd2ff82696fc8d5369b41e147431c3 Mon Sep 17 00:00:00 2001 From: Marcus Klaas de Vries Date: Tue, 8 Jan 2019 16:01:19 +0100 Subject: Implement type inference for enum variants --- crates/ra_hir/src/adt.rs | 105 +++++++++++++++++----- crates/ra_hir/src/code_model_api.rs | 34 ++++++- crates/ra_hir/src/code_model_impl/module.rs | 38 +++++++- crates/ra_hir/src/db.rs | 7 +- crates/ra_hir/src/ids.rs | 13 +-- crates/ra_hir/src/lib.rs | 2 +- crates/ra_hir/src/mock.rs | 1 + crates/ra_hir/src/ty.rs | 20 +++-- crates/ra_hir/src/ty/tests.rs | 16 ++++ crates/ra_hir/src/ty/tests/data/enum.txt | 4 + crates/ra_ide_api/src/completion/complete_path.rs | 22 +++-- crates/ra_ide_api/src/db.rs | 1 + 12 files changed, 218 insertions(+), 45 deletions(-) create mode 100644 crates/ra_hir/src/ty/tests/data/enum.txt (limited to 'crates') diff --git a/crates/ra_hir/src/adt.rs b/crates/ra_hir/src/adt.rs index d30390f25..f1b98cdd7 100644 --- a/crates/ra_hir/src/adt.rs +++ b/crates/ra_hir/src/adt.rs @@ -1,10 +1,19 @@ use std::sync::Arc; use ra_db::Cancelable; -use ra_syntax::ast::{self, NameOwner, StructFlavor, AstNode}; +use ra_syntax::{ + SyntaxNode, + ast::{self, NameOwner, StructFlavor, AstNode} +}; use crate::{ +<<<<<<< HEAD DefId, Name, AsName, Struct, Enum, HirDatabase, DefKind, +======= + DefId, DefLoc, Name, AsName, Struct, Enum, EnumVariant, + VariantData, StructField, HirDatabase, DefKind, + SourceItemId, +>>>>>>> 95ac72a3... Implement type inference for enum variants type_ref::TypeRef, }; @@ -45,33 +54,39 @@ impl StructData { } } -impl Enum { - pub(crate) fn new(def_id: DefId) -> Self { - Enum { def_id } - } +fn get_def_id( + db: &impl HirDatabase, + same_file_loc: &DefLoc, + node: &SyntaxNode, + expected_kind: DefKind, +) -> DefId { + let file_id = same_file_loc.source_item_id.file_id; + let file_items = db.file_items(file_id); + + let item_id = file_items.id_of(file_id, node); + let source_item_id = SourceItemId { + item_id: Some(item_id), + ..same_file_loc.source_item_id + }; + let loc = DefLoc { + kind: expected_kind, + source_item_id: source_item_id, + ..*same_file_loc + }; + loc.id(db) } #[derive(Debug, Clone, PartialEq, Eq)] pub struct EnumData { pub(crate) name: Option, - pub(crate) variants: Vec<(Name, Arc)>, + // TODO: keep track of names also since we already have them? + // then we won't need additional db lookups + pub(crate) variants: Option>, } impl EnumData { - fn new(enum_def: &ast::EnumDef) -> Self { + fn new(enum_def: &ast::EnumDef, variants: Option>) -> Self { let name = enum_def.name().map(|n| n.as_name()); - let variants = if let Some(evl) = enum_def.variant_list() { - evl.variants() - .map(|v| { - ( - v.name().map(|n| n.as_name()).unwrap_or_else(Name::missing), - Arc::new(VariantData::new(v.flavor())), - ) - }) - .collect() - } else { - Vec::new() - }; EnumData { name, variants } } @@ -83,7 +98,57 @@ impl EnumData { assert!(def_loc.kind == DefKind::Enum); let syntax = db.file_item(def_loc.source_item_id); let enum_def = ast::EnumDef::cast(&syntax).expect("enum def should point to EnumDef node"); - Ok(Arc::new(EnumData::new(enum_def))) + let variants = enum_def.variant_list().map(|vl| { + vl.variants() + .map(|ev| { + let def_id = get_def_id(db, &def_loc, ev.syntax(), DefKind::EnumVariant); + EnumVariant::new(def_id) + }) + .collect() + }); + Ok(Arc::new(EnumData::new(enum_def, variants))) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EnumVariantData { + pub(crate) name: Option, + pub(crate) variant_data: Arc, + pub(crate) parent_enum: Enum, +} + +impl EnumVariantData { + fn new(variant_def: &ast::EnumVariant, parent_enum: Enum) -> EnumVariantData { + let name = variant_def.name().map(|n| n.as_name()); + let variant_data = VariantData::new(variant_def.flavor()); + let variant_data = Arc::new(variant_data); + EnumVariantData { + name, + variant_data, + parent_enum, + } + } + + pub(crate) fn enum_variant_data_query( + db: &impl HirDatabase, + def_id: DefId, + ) -> Cancelable> { + let def_loc = def_id.loc(db); + assert!(def_loc.kind == DefKind::EnumVariant); + let syntax = db.file_item(def_loc.source_item_id); + let variant_def = ast::EnumVariant::cast(&syntax) + .expect("enum variant def should point to EnumVariant node"); + let enum_node = syntax + .parent() + .expect("enum variant should have enum variant list ancestor") + .parent() + .expect("enum variant list should have enum ancestor"); + let enum_def_id = get_def_id(db, &def_loc, enum_node, DefKind::Enum); + + Ok(Arc::new(EnumVariantData::new( + variant_def, + Enum::new(enum_def_id), + ))) } } diff --git a/crates/ra_hir/src/code_model_api.rs b/crates/ra_hir/src/code_model_api.rs index fa3e4baa7..c7d1bf0a6 100644 --- a/crates/ra_hir/src/code_model_api.rs +++ b/crates/ra_hir/src/code_model_api.rs @@ -44,6 +44,7 @@ pub enum Def { Module(Module), Struct(Struct), Enum(Enum), + EnumVariant(EnumVariant), Function(Function), Item, } @@ -188,6 +189,10 @@ pub struct Enum { } impl Enum { + pub(crate) fn new(def_id: DefId) -> Self { + Enum { def_id } + } + pub fn def_id(&self) -> DefId { self.def_id } @@ -196,11 +201,38 @@ impl Enum { Ok(db.enum_data(self.def_id)?.name.clone()) } - pub fn variants(&self, db: &impl HirDatabase) -> Cancelable)>> { + pub fn variants(&self, db: &impl HirDatabase) -> Cancelable>> { Ok(db.enum_data(self.def_id)?.variants.clone()) } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct EnumVariant { + pub(crate) def_id: DefId, +} + +impl EnumVariant { + pub(crate) fn new(def_id: DefId) -> Self { + EnumVariant { def_id } + } + + pub fn def_id(&self) -> DefId { + self.def_id + } + + pub fn parent_enum(&self, db: &impl HirDatabase) -> Cancelable { + Ok(db.enum_variant_data(self.def_id)?.parent_enum.clone()) + } + + pub fn name(&self, db: &impl HirDatabase) -> Cancelable> { + Ok(db.enum_variant_data(self.def_id)?.name.clone()) + } + + pub fn variant_data(&self, db: &impl HirDatabase) -> Cancelable> { + Ok(db.enum_variant_data(self.def_id)?.variant_data.clone()) + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Function { pub(crate) def_id: DefId, diff --git a/crates/ra_hir/src/code_model_impl/module.rs b/crates/ra_hir/src/code_model_impl/module.rs index 1cb408cff..d7d62e863 100644 --- a/crates/ra_hir/src/code_model_impl/module.rs +++ b/crates/ra_hir/src/code_model_impl/module.rs @@ -13,6 +13,7 @@ impl Module { pub(crate) fn new(def_id: DefId) -> Self { crate::code_model_api::Module { def_id } } + pub(crate) fn from_module_id( db: &impl HirDatabase, source_root_id: SourceRootId, @@ -85,6 +86,7 @@ impl Module { let module_id = loc.module_id.crate_root(&module_tree); Module::from_module_id(db, loc.source_root_id, module_id) } + /// Finds a child module with the specified name. pub fn child_impl(&self, db: &impl HirDatabase, name: &Name) -> Cancelable> { let loc = self.def_id.loc(db); @@ -92,12 +94,14 @@ impl Module { let child_id = ctry!(loc.module_id.child(&module_tree, name)); Module::from_module_id(db, loc.source_root_id, child_id).map(Some) } + pub fn parent_impl(&self, db: &impl HirDatabase) -> Cancelable> { let loc = self.def_id.loc(db); let module_tree = db.module_tree(loc.source_root_id)?; let parent_id = ctry!(loc.module_id.parent(&module_tree)); Module::from_module_id(db, loc.source_root_id, parent_id).map(Some) } + /// Returns a `ModuleScope`: a set of items, visible in this module. pub fn scope_impl(&self, db: &impl HirDatabase) -> Cancelable { let loc = self.def_id.loc(db); @@ -105,6 +109,7 @@ impl Module { let res = item_map.per_module[&loc.module_id].clone(); Ok(res) } + pub fn resolve_path_impl( &self, db: &impl HirDatabase, @@ -126,7 +131,7 @@ impl Module { ); let segments = &path.segments; - for name in segments.iter() { + for (idx, name) in segments.iter().enumerate() { let curr = if let Some(r) = curr_per_ns.as_ref().take_types() { r } else { @@ -134,7 +139,35 @@ impl Module { }; let module = match curr.resolve(db)? { Def::Module(it) => it, - // TODO here would be the place to handle enum variants... + Def::Enum(e) => { + if segments.len() == idx + 1 { + // enum variant + let matching_variant = e.variants(db)?.map(|variants| { + variants + .into_iter() + // FIXME: replace by match lol + .find(|variant| { + variant + .name(db) + .map(|o| o.map(|ref n| n == name)) + .unwrap_or(Some(false)) + .unwrap_or(false) + }) + }); + + if let Some(Some(variant)) = matching_variant { + return Ok(PerNs::both(variant.def_id(), e.def_id())); + } else { + return Ok(PerNs::none()); + } + } else if segments.len() == idx { + // enum + return Ok(PerNs::types(e.def_id())); + } else { + // malformed enum? + return Ok(PerNs::none()); + } + } _ => return Ok(PerNs::none()), }; let scope = module.scope(db)?; @@ -146,6 +179,7 @@ impl Module { } Ok(curr_per_ns) } + pub fn problems_impl( &self, db: &impl HirDatabase, diff --git a/crates/ra_hir/src/db.rs b/crates/ra_hir/src/db.rs index 7dbe93f2b..9a6ef8083 100644 --- a/crates/ra_hir/src/db.rs +++ b/crates/ra_hir/src/db.rs @@ -12,7 +12,7 @@ use crate::{ module_tree::{ModuleId, ModuleTree}, nameres::{ItemMap, InputModuleItems}, ty::{InferenceResult, Ty}, - adt::{StructData, EnumData}, + adt::{StructData, EnumData, EnumVariantData}, impl_block::ModuleImplBlocks, }; @@ -47,6 +47,11 @@ pub trait HirDatabase: SyntaxDatabase use fn crate::adt::EnumData::enum_data_query; } + fn enum_variant_data(def_id: DefId) -> Cancelable> { + type EnumVariantDataQuery; + use fn crate::adt::EnumVariantData::enum_variant_data_query; + } + fn infer(def_id: DefId) -> Cancelable> { type InferQuery; use fn crate::ty::infer; diff --git a/crates/ra_hir/src/ids.rs b/crates/ra_hir/src/ids.rs index 0aa687a08..db0107e53 100644 --- a/crates/ra_hir/src/ids.rs +++ b/crates/ra_hir/src/ids.rs @@ -3,7 +3,7 @@ use ra_syntax::{TreePtr, SyntaxKind, SyntaxNode, SourceFile, AstNode, ast}; use ra_arena::{Arena, RawId, impl_arena_id}; use crate::{ - HirDatabase, PerNs, Def, Function, Struct, Enum, ImplBlock, Crate, + HirDatabase, PerNs, Def, Function, Struct, Enum, EnumVariant, ImplBlock, Crate, module_tree::ModuleId, }; @@ -145,6 +145,7 @@ pub(crate) enum DefKind { Function, Struct, Enum, + EnumVariant, Item, StructCtor, @@ -170,10 +171,8 @@ impl DefId { let struct_def = Struct::new(self); Def::Struct(struct_def) } - DefKind::Enum => { - let enum_def = Enum::new(self); - Def::Enum(enum_def) - } + DefKind::Enum => Def::Enum(Enum::new(self)), + DefKind::EnumVariant => Def::EnumVariant(EnumVariant::new(self)), DefKind::StructCtor => Def::Item, DefKind::Item => Def::Item, }; @@ -258,7 +257,9 @@ impl SourceFileItems { // change parent's id. This means that, say, adding a new function to a // trait does not chage ids of top-level items, which helps caching. bfs(source_file.syntax(), |it| { - if let Some(module_item) = ast::ModuleItem::cast(it) { + if let Some(enum_variant) = ast::EnumVariant::cast(it) { + self.alloc(enum_variant.syntax().to_owned()); + } else if let Some(module_item) = ast::ModuleItem::cast(it) { self.alloc(module_item.syntax().to_owned()); } else if let Some(macro_call) = ast::MacroCall::cast(it) { self.alloc(macro_call.syntax().to_owned()); diff --git a/crates/ra_hir/src/lib.rs b/crates/ra_hir/src/lib.rs index 1b6b72c98..74957ffc9 100644 --- a/crates/ra_hir/src/lib.rs +++ b/crates/ra_hir/src/lib.rs @@ -56,6 +56,6 @@ pub use self::code_model_api::{ Crate, CrateDependency, Def, Module, ModuleSource, Problem, - Struct, Enum, + Struct, Enum, EnumVariant, Function, FnSignature, }; diff --git a/crates/ra_hir/src/mock.rs b/crates/ra_hir/src/mock.rs index 7a0301648..6f93bb59d 100644 --- a/crates/ra_hir/src/mock.rs +++ b/crates/ra_hir/src/mock.rs @@ -233,6 +233,7 @@ salsa::database_storage! { fn type_for_field() for db::TypeForFieldQuery; fn struct_data() for db::StructDataQuery; fn enum_data() for db::EnumDataQuery; + fn enum_variant_data() for db::EnumVariantDataQuery; fn impls_in_module() for db::ImplsInModuleQuery; fn body_hir() for db::BodyHirQuery; fn body_syntax_mapping() for db::BodySyntaxMappingQuery; diff --git a/crates/ra_hir/src/ty.rs b/crates/ra_hir/src/ty.rs index eb7764f65..18c41a015 100644 --- a/crates/ra_hir/src/ty.rs +++ b/crates/ra_hir/src/ty.rs @@ -30,7 +30,7 @@ use join_to_string::join; use ra_db::Cancelable; use crate::{ - Def, DefId, Module, Function, Struct, Enum, Path, Name, ImplBlock, + Def, DefId, Module, Function, Struct, Enum, EnumVariant, Path, Name, ImplBlock, FnSignature, FnScopes, db::HirDatabase, type_ref::{TypeRef, Mutability}, @@ -453,6 +453,12 @@ pub fn type_for_enum(db: &impl HirDatabase, s: Enum) -> Cancelable { }) } +pub fn type_for_enum_variant(db: &impl HirDatabase, ev: EnumVariant) -> Cancelable { + let enum_parent = ev.parent_enum(db)?; + + type_for_enum(db, enum_parent) +} + pub(super) fn type_for_def(db: &impl HirDatabase, def_id: DefId) -> Cancelable { let def = def_id.resolve(db)?; match def { @@ -463,6 +469,7 @@ pub(super) fn type_for_def(db: &impl HirDatabase, def_id: DefId) -> Cancelable type_for_fn(db, f), Def::Struct(s) => type_for_struct(db, s), Def::Enum(e) => type_for_enum(db, e), + Def::EnumVariant(ev) => type_for_enum_variant(db, ev), Def::Item => { log::debug!("trying to get type for item of unknown type {:?}", def_id); Ok(Ty::Unknown) @@ -477,12 +484,9 @@ pub(super) fn type_for_field( ) -> Cancelable> { let def = def_id.resolve(db)?; let variant_data = match def { - Def::Struct(s) => { - let variant_data = s.variant_data(db)?; - variant_data - } + Def::Struct(s) => s.variant_data(db)?, + Def::EnumVariant(ev) => ev.variant_data(db)?, // TODO: unions - // TODO: enum variants _ => panic!( "trying to get type for field in non-struct/variant {:?}", def_id @@ -788,6 +792,10 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { let ty = type_for_struct(self.db, s)?; (ty, Some(def_id)) } + Def::EnumVariant(ev) => { + let ty = type_for_enum_variant(self.db, ev)?; + (ty, Some(def_id)) + } _ => (Ty::Unknown, None), }) } diff --git a/crates/ra_hir/src/ty/tests.rs b/crates/ra_hir/src/ty/tests.rs index ba2a44474..d8c0af326 100644 --- a/crates/ra_hir/src/ty/tests.rs +++ b/crates/ra_hir/src/ty/tests.rs @@ -94,6 +94,22 @@ fn test() { ); } +#[test] +fn infer_enum() { + check_inference( + r#" +enum E { + V1 { field: u32 }, + V2 +} +fn test() { + E::V1 { field: 1 }; + E::V2; +}"#, + "enum.txt", + ); +} + #[test] fn infer_refs() { check_inference( diff --git a/crates/ra_hir/src/ty/tests/data/enum.txt b/crates/ra_hir/src/ty/tests/data/enum.txt new file mode 100644 index 000000000..481eb0bc7 --- /dev/null +++ b/crates/ra_hir/src/ty/tests/data/enum.txt @@ -0,0 +1,4 @@ +[48; 82) '{ E:...:V2; }': () +[52; 70) 'E::V1 ...d: 1 }': E +[67; 68) '1': u32 +[74; 79) 'E::V2': E diff --git a/crates/ra_ide_api/src/completion/complete_path.rs b/crates/ra_ide_api/src/completion/complete_path.rs index 4723a65a6..6a55670d1 100644 --- a/crates/ra_ide_api/src/completion/complete_path.rs +++ b/crates/ra_ide_api/src/completion/complete_path.rs @@ -21,14 +21,20 @@ pub(super) fn complete_path(acc: &mut Completions, ctx: &CompletionContext) -> C .add_to(acc) }); } - hir::Def::Enum(e) => e - .variants(ctx.db)? - .into_iter() - .for_each(|(name, _variant)| { - CompletionItem::new(CompletionKind::Reference, name.to_string()) - .kind(CompletionItemKind::EnumVariant) - .add_to(acc) - }), + hir::Def::Enum(e) => { + e.variants(ctx.db)? + .unwrap_or(vec![]) + .into_iter() + .for_each(|variant| { + let variant_name = variant.name(ctx.db); + + if let Ok(Some(name)) = variant_name { + CompletionItem::new(CompletionKind::Reference, name.to_string()) + .kind(CompletionItemKind::EnumVariant) + .add_to(acc) + } + }) + } _ => return Ok(()), }; Ok(()) diff --git a/crates/ra_ide_api/src/db.rs b/crates/ra_ide_api/src/db.rs index a2e06f5db..efdf261be 100644 --- a/crates/ra_ide_api/src/db.rs +++ b/crates/ra_ide_api/src/db.rs @@ -122,6 +122,7 @@ salsa::database_storage! { fn type_for_field() for hir::db::TypeForFieldQuery; fn struct_data() for hir::db::StructDataQuery; fn enum_data() for hir::db::EnumDataQuery; + fn enum_variant_data() for hir::db::EnumVariantDataQuery; fn impls_in_module() for hir::db::ImplsInModuleQuery; fn body_hir() for hir::db::BodyHirQuery; fn body_syntax_mapping() for hir::db::BodySyntaxMappingQuery; -- cgit v1.2.3