From d5c3808545e26d246d75e0754e81de803f9e53e6 Mon Sep 17 00:00:00 2001 From: Kirill Bulatov Date: Tue, 11 Feb 2020 18:24:43 +0200 Subject: Support trait method call autoimports --- crates/ra_assists/src/handlers/auto_import.rs | 306 +++++++++++++++++--------- crates/ra_hir/src/code_model.rs | 6 +- 2 files changed, 211 insertions(+), 101 deletions(-) (limited to 'crates') diff --git a/crates/ra_assists/src/handlers/auto_import.rs b/crates/ra_assists/src/handlers/auto_import.rs index a9778fab7..9a366414c 100644 --- a/crates/ra_assists/src/handlers/auto_import.rs +++ b/crates/ra_assists/src/handlers/auto_import.rs @@ -1,15 +1,17 @@ use ra_ide_db::{imports_locator::ImportsLocator, RootDatabase}; -use ra_syntax::ast::{self, AstNode}; +use ra_syntax::{ + ast::{self, AstNode}, + SyntaxNode, +}; use crate::{ assist_ctx::{Assist, AssistCtx}, insert_use_statement, AssistId, }; -use ast::{FnDefOwner, ModuleItem, ModuleItemOwner}; use hir::{ db::{DefDatabase, HirDatabase}, - Adt, AssocContainerId, Crate, Function, HasSource, InFile, ModPath, Module, ModuleDef, - PathResolution, SourceAnalyzer, SourceBinder, Trait, + AssocContainerId, AssocItem, Crate, Function, ModPath, Module, ModuleDef, PathResolution, + SourceAnalyzer, Trait, Type, }; use rustc_hash::FxHashSet; use std::collections::BTreeSet; @@ -34,36 +36,28 @@ use std::collections::BTreeSet; // # pub mod std { pub mod collections { pub struct HashMap { } } } // ``` pub(crate) fn auto_import(ctx: AssistCtx) -> Option { - let path_under_caret: ast::Path = ctx.find_node_at_offset()?; - if path_under_caret.syntax().ancestors().find_map(ast::UseItem::cast).is_some() { - return None; - } - - let module = path_under_caret.syntax().ancestors().find_map(ast::Module::cast); - let position = match module.and_then(|it| it.item_list()) { - Some(item_list) => item_list.syntax().clone(), - None => { - let current_file = - path_under_caret.syntax().ancestors().find_map(ast::SourceFile::cast)?; - current_file.syntax().clone() - } + let auto_import_assets = if let Some(path_under_caret) = ctx.find_node_at_offset::() + { + AutoImportAssets::for_regular_path(path_under_caret, &ctx)? + } else { + AutoImportAssets::for_method_call(ctx.find_node_at_offset()?, &ctx)? }; - let source_analyzer = ctx.source_analyzer(&position, None); - let module_with_name_to_import = source_analyzer.module()?; - let import_candidate = ImportCandidate::new(&path_under_caret, &source_analyzer, ctx.db)?; - let proposed_imports = import_candidate.search_for_imports(ctx.db, module_with_name_to_import); + let proposed_imports = auto_import_assets + .search_for_imports(ctx.db, auto_import_assets.module_with_name_to_import); if proposed_imports.is_empty() { return None; } - let mut group = ctx.add_assist_group(format!("Import {}", import_candidate.get_search_query())); + let mut group = + // TODO kb create another method and add something about traits there + ctx.add_assist_group(format!("Import {}", auto_import_assets.get_search_query())); for import in proposed_imports { group.add_assist(AssistId("auto_import"), format!("Import `{}`", &import), |edit| { - edit.target(path_under_caret.syntax().text_range()); + edit.target(auto_import_assets.syntax_under_caret.text_range()); insert_use_statement( - &position, - path_under_caret.syntax(), + &auto_import_assets.syntax_under_caret, + &auto_import_assets.syntax_under_caret, &import, edit.text_edit_builder(), ); @@ -72,64 +66,55 @@ pub(crate) fn auto_import(ctx: AssistCtx) -> Option { group.finish() } -#[derive(Debug)] -// TODO kb rustdocs -enum ImportCandidate { - UnqualifiedName(ast::NameRef), - QualifierStart(ast::NameRef), - TraitFunction(Adt, ast::PathSegment), +struct AutoImportAssets { + import_candidate: ImportCandidate, + module_with_name_to_import: Module, + syntax_under_caret: SyntaxNode, } -impl ImportCandidate { - // TODO kb refactor this mess - fn new( - path_under_caret: &ast::Path, - source_analyzer: &SourceAnalyzer, - db: &impl HirDatabase, - ) -> Option { - if source_analyzer.resolve_path(db, path_under_caret).is_some() { +impl AutoImportAssets { + fn for_method_call(method_call: ast::MethodCallExpr, ctx: &AssistCtx) -> Option { + let syntax_under_caret = method_call.syntax().to_owned(); + let source_analyzer = ctx.source_analyzer(&syntax_under_caret, None); + let module_with_name_to_import = source_analyzer.module()?; + Some(Self { + import_candidate: ImportCandidate::for_method_call( + &method_call, + &source_analyzer, + ctx.db, + )?, + module_with_name_to_import, + syntax_under_caret, + }) + } + + fn for_regular_path(path_under_caret: ast::Path, ctx: &AssistCtx) -> Option { + let syntax_under_caret = path_under_caret.syntax().to_owned(); + if syntax_under_caret.ancestors().find_map(ast::UseItem::cast).is_some() { return None; } - let segment = path_under_caret.segment()?; - if let Some(qualifier) = path_under_caret.qualifier() { - let qualifier_start = qualifier.syntax().descendants().find_map(ast::NameRef::cast)?; - let qualifier_start_path = - qualifier_start.syntax().ancestors().find_map(ast::Path::cast)?; - if let Some(qualifier_start_resolution) = - source_analyzer.resolve_path(db, &qualifier_start_path) - { - let qualifier_resolution = if &qualifier_start_path == path_under_caret { - qualifier_start_resolution - } else { - source_analyzer.resolve_path(db, &qualifier)? - }; - if let PathResolution::Def(ModuleDef::Adt(function_callee)) = qualifier_resolution { - Some(ImportCandidate::TraitFunction(function_callee, segment)) - } else { - None - } - } else { - Some(ImportCandidate::QualifierStart(qualifier_start)) - } - } else { - if source_analyzer.resolve_path(db, path_under_caret).is_none() { - Some(ImportCandidate::UnqualifiedName( - segment.syntax().descendants().find_map(ast::NameRef::cast)?, - )) - } else { - None - } - } + let source_analyzer = ctx.source_analyzer(&syntax_under_caret, None); + let module_with_name_to_import = source_analyzer.module()?; + Some(Self { + import_candidate: ImportCandidate::for_regular_path( + &path_under_caret, + &source_analyzer, + ctx.db, + )?, + module_with_name_to_import, + syntax_under_caret, + }) } fn get_search_query(&self) -> String { - match self { + match &self.import_candidate { ImportCandidate::UnqualifiedName(name_ref) | ImportCandidate::QualifierStart(name_ref) => name_ref.syntax().to_string(), ImportCandidate::TraitFunction(_, trait_function) => { trait_function.syntax().to_string() } + ImportCandidate::TraitMethod(_, trait_method) => trait_method.syntax().to_string(), } } @@ -141,7 +126,7 @@ impl ImportCandidate { ImportsLocator::new(db) .find_imports(&self.get_search_query()) .into_iter() - .map(|module_def| match self { + .map(|module_def| match &self.import_candidate { ImportCandidate::TraitFunction(function_callee, _) => { let mut applicable_traits = Vec::new(); if let ModuleDef::Function(located_function) = module_def { @@ -154,7 +139,7 @@ impl ImportCandidate { .map(|trait_candidate| trait_candidate.into()) .collect(); - function_callee.ty(db).iterate_path_candidates( + function_callee.iterate_path_candidates( db, module_with_name_to_import.krate(), &trait_candidates, @@ -172,6 +157,42 @@ impl ImportCandidate { } applicable_traits } + ImportCandidate::TraitMethod(function_callee, _) => { + let mut applicable_traits = Vec::new(); + if let ModuleDef::Function(located_function) = module_def { + let trait_candidates: FxHashSet<_> = Self::get_trait_candidates( + db, + located_function, + module_with_name_to_import.krate(), + ) + .into_iter() + .map(|trait_candidate| trait_candidate.into()) + .collect(); + + if !trait_candidates.is_empty() { + function_callee.iterate_method_candidates( + db, + module_with_name_to_import.krate(), + &trait_candidates, + None, + |_, funciton| { + if let AssocContainerId::TraitId(trait_id) = + funciton.container(db) + { + applicable_traits.push( + module_with_name_to_import.find_use_path( + db, + ModuleDef::Trait(trait_id.into()), + ), + ); + }; + None::<()> + }, + ); + } + } + applicable_traits + } _ => vec![module_with_name_to_import.find_use_path(db, module_def)], }) .flatten() @@ -186,7 +207,6 @@ impl ImportCandidate { called_function: Function, root_crate: Crate, ) -> FxHashSet { - let mut source_binder = SourceBinder::new(db); root_crate .dependencies(db) .into_iter() @@ -196,28 +216,22 @@ impl ImportCandidate { crate_def_map .modules .iter() - .filter_map(|(_, module_data)| module_data.declaration_source(db)) - .filter_map(|in_file_module| { - Some((in_file_module.file_id, in_file_module.value.item_list()?.items())) - }) - .map(|(file_id, item_list)| { - let mut if_file_trait_defs = Vec::new(); - for module_item in item_list { - if let ModuleItem::TraitDef(trait_def) = module_item { - if let Some(item_list) = trait_def.item_list() { - if item_list - .functions() - .any(|fn_def| fn_def == called_function.source(db).value) - { - if_file_trait_defs.push(InFile::new(file_id, trait_def)) - } + .map(|(_, module_data)| { + let mut traits = Vec::new(); + for module_def_id in module_data.scope.declarations() { + if let ModuleDef::Trait(trait_candidate) = module_def_id.into() { + if trait_candidate + .items(db) + .into_iter() + .any(|item| item == AssocItem::Function(called_function)) + { + traits.push(trait_candidate) } } } - if_file_trait_defs + traits }) .flatten() - .filter_map(|in_file_trait_def| source_binder.to_def(in_file_trait_def)) .collect::>() }) .flatten() @@ -225,6 +239,72 @@ impl ImportCandidate { } } +#[derive(Debug)] +// TODO kb rustdocs +enum ImportCandidate { + UnqualifiedName(ast::NameRef), + QualifierStart(ast::NameRef), + TraitFunction(Type, ast::PathSegment), + TraitMethod(Type, ast::NameRef), +} + +impl ImportCandidate { + fn for_method_call( + method_call: &ast::MethodCallExpr, + source_analyzer: &SourceAnalyzer, + db: &impl HirDatabase, + ) -> Option { + if source_analyzer.resolve_method_call(method_call).is_some() { + return None; + } + Some(Self::TraitMethod( + source_analyzer.type_of(db, &method_call.expr()?)?, + method_call.name_ref()?, + )) + } + + fn for_regular_path( + path_under_caret: &ast::Path, + source_analyzer: &SourceAnalyzer, + db: &impl HirDatabase, + ) -> Option { + if source_analyzer.resolve_path(db, path_under_caret).is_some() { + return None; + } + + let segment = path_under_caret.segment()?; + if let Some(qualifier) = path_under_caret.qualifier() { + let qualifier_start = qualifier.syntax().descendants().find_map(ast::NameRef::cast)?; + let qualifier_start_path = + qualifier_start.syntax().ancestors().find_map(ast::Path::cast)?; + if let Some(qualifier_start_resolution) = + source_analyzer.resolve_path(db, &qualifier_start_path) + { + let qualifier_resolution = if &qualifier_start_path == path_under_caret { + qualifier_start_resolution + } else { + source_analyzer.resolve_path(db, &qualifier)? + }; + if let PathResolution::Def(ModuleDef::Adt(function_callee)) = qualifier_resolution { + Some(ImportCandidate::TraitFunction(function_callee.ty(db), segment)) + } else { + None + } + } else { + Some(ImportCandidate::QualifierStart(qualifier_start)) + } + } else { + if source_analyzer.resolve_path(db, path_under_caret).is_none() { + Some(ImportCandidate::UnqualifiedName( + segment.syntax().descendants().find_map(ast::NameRef::cast)?, + )) + } else { + None + } + } + } +} + #[cfg(test)] mod tests { use crate::helpers::{check_assist, check_assist_not_applicable, check_assist_target}; @@ -525,32 +605,25 @@ mod tests { } #[test] - fn not_applicable_for_imported_trait() { + fn not_applicable_for_imported_trait_for_function() { check_assist_not_applicable( auto_import, r" mod test_mod { pub trait TestTrait { - fn test_method(&self); fn test_function(); } - pub trait TestTrait2 { - fn test_method(&self); fn test_function(); } pub enum TestEnum { One, Two, } - impl TestTrait2 for TestEnum { - fn test_method(&self) {} fn test_function() {} } - impl TestTrait for TestEnum { - fn test_method(&self) {} fn test_function() {} } } @@ -580,7 +653,7 @@ mod tests { fn main() { let test_struct = test_mod::TestStruct {}; - test_struct.test_method<|> + test_struct.test_meth<|>od() } ", r" @@ -598,9 +671,42 @@ mod tests { fn main() { let test_struct = test_mod::TestStruct {}; - test_struct.test_method<|> + test_struct.test_meth<|>od() } ", ); } + + #[test] + fn not_applicable_for_imported_trait_for_method() { + check_assist_not_applicable( + auto_import, + r" + mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub trait TestTrait2 { + fn test_method(&self); + } + pub enum TestEnum { + One, + Two, + } + impl TestTrait2 for TestEnum { + fn test_method(&self) {} + } + impl TestTrait for TestEnum { + fn test_method(&self) {} + } + } + + use test_mod::TestTrait2; + fn main() { + let one = test_mod::TestEnum::One; + one.test<|>_method(); + } + ", + ) + } } diff --git a/crates/ra_hir/src/code_model.rs b/crates/ra_hir/src/code_model.rs index 73158b8bd..140b3a87f 100644 --- a/crates/ra_hir/src/code_model.rs +++ b/crates/ra_hir/src/code_model.rs @@ -548,6 +548,10 @@ impl Function { let mut validator = ExprValidator::new(self.id, infer, sink); validator.validate_body(db); } + + pub fn container(self, db: &impl DefDatabase) -> AssocContainerId { + self.id.lookup(db).container + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -699,7 +703,7 @@ impl AssocItem { pub fn container(self, db: &impl DefDatabase) -> AssocContainerId { match self { - AssocItem::Function(f) => f.id.lookup(db).container, + AssocItem::Function(f) => f.container(db), AssocItem::Const(c) => c.id.lookup(db).container, AssocItem::TypeAlias(t) => t.id.lookup(db).container, } -- cgit v1.2.3