From 1316422a7c2ef26e9da78fa23f170407b1cb39bb Mon Sep 17 00:00:00 2001 From: Phil Ellison Date: Mon, 28 Dec 2020 13:41:15 +0000 Subject: Add diagnostic for filter_map followed by next --- crates/hir_ty/src/diagnostics/expr.rs | 70 ++++++++++++++++++++++++++++++----- 1 file changed, 61 insertions(+), 9 deletions(-) (limited to 'crates/hir_ty/src/diagnostics/expr.rs') diff --git a/crates/hir_ty/src/diagnostics/expr.rs b/crates/hir_ty/src/diagnostics/expr.rs index 107417c27..170d23178 100644 --- a/crates/hir_ty/src/diagnostics/expr.rs +++ b/crates/hir_ty/src/diagnostics/expr.rs @@ -24,6 +24,8 @@ pub(crate) use hir_def::{ LocalFieldId, VariantId, }; +use super::ReplaceFilterMapNextWithFindMap; + pub(super) struct ExprValidator<'a, 'b: 'a> { owner: DefWithBodyId, infer: Arc, @@ -39,7 +41,18 @@ impl<'a, 'b> ExprValidator<'a, 'b> { ExprValidator { owner, infer, sink } } + fn bar() { + // LOOK FOR THIS + let m = [1, 2, 3] + .iter() + .filter_map(|x| if *x == 2 { Some(4) } else { None }) + .next(); + } + pub(super) fn validate_body(&mut self, db: &dyn HirDatabase) { + // DO NOT MERGE: just getting something working for now + self.check_for_filter_map_next(db); + let body = db.body(self.owner.into()); for (id, expr) in body.exprs.iter() { @@ -150,20 +163,58 @@ impl<'a, 'b> ExprValidator<'a, 'b> { } } - fn validate_call(&mut self, db: &dyn HirDatabase, call_id: ExprId, expr: &Expr) -> Option<()> { + fn check_for_filter_map_next(&mut self, db: &dyn HirDatabase) { + let body = db.body(self.owner.into()); + let mut prev = None; + + for (id, expr) in body.exprs.iter() { + if let Expr::MethodCall { receiver, method_name, args, .. } = expr { + let method_name_hack_do_not_merge = format!("{}", method_name); + + if method_name_hack_do_not_merge == "filter_map" && args.len() == 1 { + prev = Some((id, args[0])); + continue; + } + + if method_name_hack_do_not_merge == "next" { + if let Some((filter_map_id, filter_map_args)) = prev { + if *receiver == filter_map_id { + let (_, source_map) = db.body_with_source_map(self.owner.into()); + if let (Ok(filter_map_source_ptr), Ok(next_source_ptr)) = ( + source_map.expr_syntax(filter_map_id), + source_map.expr_syntax(id), + ) { + self.sink.push(ReplaceFilterMapNextWithFindMap { + file: filter_map_source_ptr.file_id, + filter_map_expr: filter_map_source_ptr.value, + next_expr: next_source_ptr.value, + }); + } + } + } + } + } + prev = None; + } + } + + fn validate_call(&mut self, db: &dyn HirDatabase, call_id: ExprId, expr: &Expr) { // Check that the number of arguments matches the number of parameters. // FIXME: Due to shortcomings in the current type system implementation, only emit this // diagnostic if there are no type mismatches in the containing function. if self.infer.type_mismatches.iter().next().is_some() { - return None; + return; } let is_method_call = matches!(expr, Expr::MethodCall { .. }); let (sig, args) = match expr { Expr::Call { callee, args } => { let callee = &self.infer.type_of_expr[*callee]; - let sig = callee.callable_sig(db)?; + let sig = match callee.callable_sig(db) { + Some(sig) => sig, + None => return, + }; (sig, args.clone()) } Expr::MethodCall { receiver, args, .. } => { @@ -175,22 +226,25 @@ impl<'a, 'b> ExprValidator<'a, 'b> { // if the receiver is of unknown type, it's very likely we // don't know enough to correctly resolve the method call. // This is kind of a band-aid for #6975. - return None; + return; } // FIXME: note that we erase information about substs here. This // is not right, but, luckily, doesn't matter as we care only // about the number of params - let callee = self.infer.method_resolution(call_id)?; + let callee = match self.infer.method_resolution(call_id) { + Some(callee) => callee, + None => return, + }; let sig = db.callable_item_signature(callee.into()).value; (sig, args) } - _ => return None, + _ => return, }; if sig.is_varargs { - return None; + return; } let params = sig.params(); @@ -213,8 +267,6 @@ impl<'a, 'b> ExprValidator<'a, 'b> { }); } } - - None } fn validate_match( -- cgit v1.2.3 From 1ff860b93c972e0f8d3a8ee582c255fa59e9b284 Mon Sep 17 00:00:00 2001 From: Phil Ellison Date: Wed, 30 Dec 2020 15:46:05 +0000 Subject: Implement fix, add tests --- crates/hir_ty/src/diagnostics/expr.rs | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) (limited to 'crates/hir_ty/src/diagnostics/expr.rs') diff --git a/crates/hir_ty/src/diagnostics/expr.rs b/crates/hir_ty/src/diagnostics/expr.rs index 170d23178..b87557ff5 100644 --- a/crates/hir_ty/src/diagnostics/expr.rs +++ b/crates/hir_ty/src/diagnostics/expr.rs @@ -41,16 +41,7 @@ impl<'a, 'b> ExprValidator<'a, 'b> { ExprValidator { owner, infer, sink } } - fn bar() { - // LOOK FOR THIS - let m = [1, 2, 3] - .iter() - .filter_map(|x| if *x == 2 { Some(4) } else { None }) - .next(); - } - pub(super) fn validate_body(&mut self, db: &dyn HirDatabase) { - // DO NOT MERGE: just getting something working for now self.check_for_filter_map_next(db); let body = db.body(self.owner.into()); @@ -169,24 +160,20 @@ impl<'a, 'b> ExprValidator<'a, 'b> { for (id, expr) in body.exprs.iter() { if let Expr::MethodCall { receiver, method_name, args, .. } = expr { - let method_name_hack_do_not_merge = format!("{}", method_name); + let method_name = format!("{}", method_name); - if method_name_hack_do_not_merge == "filter_map" && args.len() == 1 { - prev = Some((id, args[0])); + if method_name == "filter_map" && args.len() == 1 { + prev = Some(id); continue; } - if method_name_hack_do_not_merge == "next" { - if let Some((filter_map_id, filter_map_args)) = prev { + if method_name == "next" { + if let Some(filter_map_id) = prev { if *receiver == filter_map_id { let (_, source_map) = db.body_with_source_map(self.owner.into()); - if let (Ok(filter_map_source_ptr), Ok(next_source_ptr)) = ( - source_map.expr_syntax(filter_map_id), - source_map.expr_syntax(id), - ) { + if let Ok(next_source_ptr) = source_map.expr_syntax(id) { self.sink.push(ReplaceFilterMapNextWithFindMap { - file: filter_map_source_ptr.file_id, - filter_map_expr: filter_map_source_ptr.value, + file: next_source_ptr.file_id, next_expr: next_source_ptr.value, }); } -- cgit v1.2.3 From 8c7ccdc29d071649e816030ac744338e91eb5558 Mon Sep 17 00:00:00 2001 From: Phil Ellison Date: Fri, 1 Jan 2021 21:11:08 +0000 Subject: Identify methods using functions ids rather than string names --- crates/hir_ty/src/diagnostics/expr.rs | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) (limited to 'crates/hir_ty/src/diagnostics/expr.rs') diff --git a/crates/hir_ty/src/diagnostics/expr.rs b/crates/hir_ty/src/diagnostics/expr.rs index b87557ff5..16bbd48fb 100644 --- a/crates/hir_ty/src/diagnostics/expr.rs +++ b/crates/hir_ty/src/diagnostics/expr.rs @@ -2,8 +2,8 @@ use std::sync::Arc; -use hir_def::{expr::Statement, path::path, resolver::HasResolver, AdtId, DefWithBodyId}; -use hir_expand::diagnostics::DiagnosticSink; +use hir_def::{AdtId, AssocItemId, DefWithBodyId, expr::Statement, path::path, resolver::HasResolver}; +use hir_expand::{diagnostics::DiagnosticSink, name}; use rustc_hash::FxHashSet; use syntax::{ast, AstPtr}; @@ -155,19 +155,39 @@ impl<'a, 'b> ExprValidator<'a, 'b> { } fn check_for_filter_map_next(&mut self, db: &dyn HirDatabase) { + // Find the FunctionIds for Iterator::filter_map and Iterator::next + let iterator_path = path![core::iter::Iterator]; + let resolver = self.owner.resolver(db.upcast()); + let iterator_trait_id = match resolver.resolve_known_trait(db.upcast(), &iterator_path) { + Some(id) => id, + None => return, + }; + let iterator_trait_items = &db.trait_data(iterator_trait_id).items; + let filter_map_function_id = match iterator_trait_items.iter().find(|item| item.0 == name![filter_map]) { + Some((_, AssocItemId::FunctionId(id))) => id, + _ => return, + }; + let next_function_id = match iterator_trait_items.iter().find(|item| item.0 == name![next]) { + Some((_, AssocItemId::FunctionId(id))) => id, + _ => return, + }; + + // Search function body for instances of .filter_map(..).next() let body = db.body(self.owner.into()); let mut prev = None; - for (id, expr) in body.exprs.iter() { - if let Expr::MethodCall { receiver, method_name, args, .. } = expr { - let method_name = format!("{}", method_name); + if let Expr::MethodCall { receiver, .. } = expr { + let function_id = match self.infer.method_resolution(id) { + Some(id) => id, + None => continue, + }; - if method_name == "filter_map" && args.len() == 1 { + if function_id == *filter_map_function_id { prev = Some(id); continue; } - if method_name == "next" { + if function_id == *next_function_id { if let Some(filter_map_id) = prev { if *receiver == filter_map_id { let (_, source_map) = db.body_with_source_map(self.owner.into()); -- cgit v1.2.3 From 65a5ea581d547c36e98b4a3c5a99671ad5d4c117 Mon Sep 17 00:00:00 2001 From: Phil Ellison Date: Fri, 1 Jan 2021 21:40:11 +0000 Subject: Update tests to register the required standard library types --- crates/hir_ty/src/diagnostics/expr.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) (limited to 'crates/hir_ty/src/diagnostics/expr.rs') diff --git a/crates/hir_ty/src/diagnostics/expr.rs b/crates/hir_ty/src/diagnostics/expr.rs index 16bbd48fb..d740b7265 100644 --- a/crates/hir_ty/src/diagnostics/expr.rs +++ b/crates/hir_ty/src/diagnostics/expr.rs @@ -2,7 +2,9 @@ use std::sync::Arc; -use hir_def::{AdtId, AssocItemId, DefWithBodyId, expr::Statement, path::path, resolver::HasResolver}; +use hir_def::{ + expr::Statement, path::path, resolver::HasResolver, AdtId, AssocItemId, DefWithBodyId, +}; use hir_expand::{diagnostics::DiagnosticSink, name}; use rustc_hash::FxHashSet; use syntax::{ast, AstPtr}; @@ -163,11 +165,13 @@ impl<'a, 'b> ExprValidator<'a, 'b> { None => return, }; let iterator_trait_items = &db.trait_data(iterator_trait_id).items; - let filter_map_function_id = match iterator_trait_items.iter().find(|item| item.0 == name![filter_map]) { - Some((_, AssocItemId::FunctionId(id))) => id, - _ => return, - }; - let next_function_id = match iterator_trait_items.iter().find(|item| item.0 == name![next]) { + let filter_map_function_id = + match iterator_trait_items.iter().find(|item| item.0 == name![filter_map]) { + Some((_, AssocItemId::FunctionId(id))) => id, + _ => return, + }; + let next_function_id = match iterator_trait_items.iter().find(|item| item.0 == name![next]) + { Some((_, AssocItemId::FunctionId(id))) => id, _ => return, }; -- cgit v1.2.3