From c6ddb907144688ae77a6de3666159feef53638e1 Mon Sep 17 00:00:00 2001 From: adamrk Date: Sat, 22 Aug 2020 20:11:37 +0200 Subject: Add references to fn args during completion --- crates/hir/src/code_model.rs | 34 ++++++++-- crates/hir_ty/src/db.rs | 3 + crates/hir_ty/src/infer.rs | 15 ++++- crates/hir_ty/src/lib.rs | 2 +- crates/ide/src/completion/presentation.rs | 107 +++++++++++++++++++++++++++++- 5 files changed, 151 insertions(+), 10 deletions(-) diff --git a/crates/hir/src/code_model.rs b/crates/hir/src/code_model.rs index c2fc819e7..f182ab228 100644 --- a/crates/hir/src/code_model.rs +++ b/crates/hir/src/code_model.rs @@ -708,12 +708,24 @@ impl Function { Some(SelfParam { func: self.id }) } - pub fn params(self, db: &dyn HirDatabase) -> Vec { + pub fn params(self, db: &dyn HirDatabase) -> Vec { + let resolver = self.id.resolver(db.upcast()); + let ctx = hir_ty::TyLoweringContext::new(db, &resolver); + let environment = TraitEnvironment::lower(db, &resolver); db.function_data(self.id) .params .iter() .skip(if self.self_param(db).is_some() { 1 } else { 0 }) - .map(|_| Param { _ty: () }) + .map(|type_ref| { + let ty = Type { + krate: self.id.lookup(db.upcast()).container.module(db.upcast()).krate, + ty: InEnvironment { + value: Ty::from_hir_ext(&ctx, type_ref).0, + environment: environment.clone(), + }, + }; + ty + }) .collect() } @@ -747,10 +759,6 @@ pub struct SelfParam { func: FunctionId, } -pub struct Param { - _ty: (), -} - impl SelfParam { pub fn access(self, db: &dyn HirDatabase) -> Access { let func_data = db.function_data(self.func); @@ -1100,6 +1108,12 @@ impl Local { ast.map_left(|it| it.cast().unwrap().to_node(&root)).map_right(|it| it.to_node(&root)) }) } + + pub fn can_unify(self, other: Type, db: &dyn HirDatabase) -> bool { + let def = DefWithBodyId::from(self.parent); + let infer = db.infer(def); + db.can_unify(def, infer[self.pat_id].clone(), other.ty.value) + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] @@ -1276,6 +1290,14 @@ impl Type { ) } + pub fn remove_ref(&self) -> Option { + if let Ty::Apply(ApplicationTy { ctor: TypeCtor::Ref(_), .. }) = self.ty.value { + self.ty.value.substs().map(|substs| self.derived(substs[0].clone())) + } else { + None + } + } + pub fn is_unknown(&self) -> bool { matches!(self.ty.value, Ty::Unknown) } diff --git a/crates/hir_ty/src/db.rs b/crates/hir_ty/src/db.rs index 25cf9eb7f..57e60c53b 100644 --- a/crates/hir_ty/src/db.rs +++ b/crates/hir_ty/src/db.rs @@ -26,6 +26,9 @@ pub trait HirDatabase: DefDatabase + Upcast { #[salsa::invoke(crate::infer::infer_query)] fn infer_query(&self, def: DefWithBodyId) -> Arc; + #[salsa::invoke(crate::infer::can_unify)] + fn can_unify(&self, def: DefWithBodyId, ty1: Ty, ty2: Ty) -> bool; + #[salsa::invoke(crate::lower::ty_query)] #[salsa::cycle(crate::lower::ty_recover)] fn ty(&self, def: TyDefId) -> Binders; diff --git a/crates/hir_ty/src/infer.rs b/crates/hir_ty/src/infer.rs index 03b00b101..d461e077b 100644 --- a/crates/hir_ty/src/infer.rs +++ b/crates/hir_ty/src/infer.rs @@ -55,7 +55,7 @@ macro_rules! ty_app { }; } -mod unify; +pub mod unify; mod path; mod expr; mod pat; @@ -78,6 +78,19 @@ pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc bool { + let resolver = def.resolver(db.upcast()); + let mut ctx = InferenceContext::new(db, def, resolver); + + let ty1 = ctx.canonicalizer().canonicalize_ty(ty1).value; + let ty2 = ctx.canonicalizer().canonicalize_ty(ty2).value; + let mut kinds = Vec::from(ty1.kinds.to_vec()); + kinds.extend_from_slice(ty2.kinds.as_ref()); + let tys = crate::Canonical::new((ty1.value, ty2.value), kinds); + + unify(&tys).is_some() +} + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] enum ExprOrPatId { ExprId(ExprId), diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index 1e748476a..681f98bde 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -43,7 +43,7 @@ use crate::{ }; pub use autoderef::autoderef; -pub use infer::{InferTy, InferenceResult}; +pub use infer::{unify, InferTy, InferenceResult}; pub use lower::CallableDefId; pub use lower::{ associated_type_shorthand_candidates, callable_item_sig, ImplTraitLoweringMode, TyDefId, diff --git a/crates/ide/src/completion/presentation.rs b/crates/ide/src/completion/presentation.rs index 24c507f9b..cfcb6dfa1 100644 --- a/crates/ide/src/completion/presentation.rs +++ b/crates/ide/src/completion/presentation.rs @@ -191,6 +191,22 @@ impl Completions { func: hir::Function, local_name: Option, ) { + fn add_arg(arg: &str, ty: &Type, ctx: &CompletionContext) -> String { + let mut prefix = ""; + if let Some(derefed_ty) = ty.remove_ref() { + ctx.scope.process_all_names(&mut |name, scope| { + if prefix != "" { + return; + } + if let ScopeDef::Local(local) = scope { + if name.to_string() == arg && local.can_unify(derefed_ty.clone(), ctx.db) { + prefix = if ty.is_mutable_reference() { "&mut " } else { "&" }; + } + } + }); + } + prefix.to_string() + arg + }; let name = local_name.unwrap_or_else(|| func.name(ctx.db).to_string()); let ast_node = func.source(ctx.db).value; @@ -205,12 +221,20 @@ impl Completions { .set_deprecated(is_deprecated(func, ctx.db)) .detail(function_declaration(&ast_node)); + let params_ty = func.params(ctx.db); let params = ast_node .param_list() .into_iter() .flat_map(|it| it.params()) - .flat_map(|it| it.pat()) - .map(|pat| pat.to_string().trim_start_matches('_').into()) + .zip(params_ty) + .flat_map(|(it, param_ty)| { + if let Some(pat) = it.pat() { + let name = pat.to_string(); + let arg = name.trim_start_matches('_'); + return Some(add_arg(arg, ¶m_ty, ctx)); + } + None + }) .collect(); builder = builder.add_call_parens(ctx, name, Params::Named(params)); @@ -863,6 +887,85 @@ fn main() { foo(${1:foo}, ${2:bar}, ${3:ho_ge_})$0 } ); } + #[test] + fn insert_ref_when_matching_local_in_scope() { + check_edit( + "ref_arg", + r#" +struct Foo {} +fn ref_arg(x: &Foo) {} +fn main() { + let x = Foo {}; + ref_ar<|> +} +"#, + r#" +struct Foo {} +fn ref_arg(x: &Foo) {} +fn main() { + let x = Foo {}; + ref_arg(${1:&x})$0 +} +"#, + ); + } + + #[test] + fn insert_mut_ref_when_matching_local_in_scope() { + check_edit( + "ref_arg", + r#" +struct Foo {} +fn ref_arg(x: &mut Foo) {} +fn main() { + let x = Foo {}; + ref_ar<|> +} +"#, + r#" +struct Foo {} +fn ref_arg(x: &mut Foo) {} +fn main() { + let x = Foo {}; + ref_arg(${1:&mut x})$0 +} +"#, + ); + } + + #[test] + fn insert_ref_when_matching_local_in_scope_for_method() { + check_edit( + "apply_foo", + r#" +struct Foo {} +struct Bar {} +impl Bar { + fn apply_foo(&self, x: &Foo) {} +} + +fn main() { + let x = Foo {}; + let y = Bar {}; + y.<|> +} +"#, + r#" +struct Foo {} +struct Bar {} +impl Bar { + fn apply_foo(&self, x: &Foo) {} +} + +fn main() { + let x = Foo {}; + let y = Bar {}; + y.apply_foo(${1:&x})$0 +} +"#, + ); + } + #[test] fn inserts_parens_for_tuple_enums() { mark::check!(inserts_parens_for_tuple_enums); -- cgit v1.2.3