From d644728d82df10b034d0ea736590c781afa2ba15 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sun, 7 Feb 2021 18:38:12 +0100 Subject: Refactor reference searching to work with the ast --- .../assists/src/handlers/inline_local_variable.rs | 110 ++++--- crates/ide/src/call_hierarchy.rs | 8 +- crates/ide/src/lib.rs | 4 +- crates/ide/src/references.rs | 357 +++++++++++---------- crates/ide/src/references/rename.rs | 197 +++++++----- crates/ide_db/src/search.rs | 140 ++++---- crates/rust-analyzer/src/handlers.rs | 49 ++- crates/syntax/src/ast/node_ext.rs | 30 +- 8 files changed, 489 insertions(+), 406 deletions(-) (limited to 'crates') diff --git a/crates/assists/src/handlers/inline_local_variable.rs b/crates/assists/src/handlers/inline_local_variable.rs index 0e63a60e8..e4f984713 100644 --- a/crates/assists/src/handlers/inline_local_variable.rs +++ b/crates/assists/src/handlers/inline_local_variable.rs @@ -1,7 +1,6 @@ -use ide_db::{ - defs::Definition, - search::{FileReference, ReferenceKind}, -}; +use std::collections::HashMap; + +use ide_db::{defs::Definition, search::FileReference}; use syntax::{ ast::{self, AstNode, AstToken}, TextRange, @@ -68,44 +67,51 @@ pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext) -> O let wrap_in_parens = usages .references - .values() - .flatten() - .map(|&FileReference { range, .. }| { - let usage_node = - ctx.covering_node_for_range(range).ancestors().find_map(ast::PathExpr::cast)?; - let usage_parent_option = usage_node.syntax().parent().and_then(ast::Expr::cast); - let usage_parent = match usage_parent_option { - Some(u) => u, - None => return Ok(false), - }; - - Ok(!matches!( - (&initializer_expr, usage_parent), - (ast::Expr::CallExpr(_), _) - | (ast::Expr::IndexExpr(_), _) - | (ast::Expr::MethodCallExpr(_), _) - | (ast::Expr::FieldExpr(_), _) - | (ast::Expr::TryExpr(_), _) - | (ast::Expr::RefExpr(_), _) - | (ast::Expr::Literal(_), _) - | (ast::Expr::TupleExpr(_), _) - | (ast::Expr::ArrayExpr(_), _) - | (ast::Expr::ParenExpr(_), _) - | (ast::Expr::PathExpr(_), _) - | (ast::Expr::BlockExpr(_), _) - | (ast::Expr::EffectExpr(_), _) - | (_, ast::Expr::CallExpr(_)) - | (_, ast::Expr::TupleExpr(_)) - | (_, ast::Expr::ArrayExpr(_)) - | (_, ast::Expr::ParenExpr(_)) - | (_, ast::Expr::ForExpr(_)) - | (_, ast::Expr::WhileExpr(_)) - | (_, ast::Expr::BreakExpr(_)) - | (_, ast::Expr::ReturnExpr(_)) - | (_, ast::Expr::MatchExpr(_)) - )) + .iter() + .map(|(&file_id, refs)| { + refs.iter() + .map(|&FileReference { range, .. }| { + let usage_node = ctx + .covering_node_for_range(range) + .ancestors() + .find_map(ast::PathExpr::cast)?; + let usage_parent_option = + usage_node.syntax().parent().and_then(ast::Expr::cast); + let usage_parent = match usage_parent_option { + Some(u) => u, + None => return Ok(false), + }; + + Ok(!matches!( + (&initializer_expr, usage_parent), + (ast::Expr::CallExpr(_), _) + | (ast::Expr::IndexExpr(_), _) + | (ast::Expr::MethodCallExpr(_), _) + | (ast::Expr::FieldExpr(_), _) + | (ast::Expr::TryExpr(_), _) + | (ast::Expr::RefExpr(_), _) + | (ast::Expr::Literal(_), _) + | (ast::Expr::TupleExpr(_), _) + | (ast::Expr::ArrayExpr(_), _) + | (ast::Expr::ParenExpr(_), _) + | (ast::Expr::PathExpr(_), _) + | (ast::Expr::BlockExpr(_), _) + | (ast::Expr::EffectExpr(_), _) + | (_, ast::Expr::CallExpr(_)) + | (_, ast::Expr::TupleExpr(_)) + | (_, ast::Expr::ArrayExpr(_)) + | (_, ast::Expr::ParenExpr(_)) + | (_, ast::Expr::ForExpr(_)) + | (_, ast::Expr::WhileExpr(_)) + | (_, ast::Expr::BreakExpr(_)) + | (_, ast::Expr::ReturnExpr(_)) + | (_, ast::Expr::MatchExpr(_)) + )) + }) + .collect::>() + .map(|b| (file_id, b)) }) - .collect::, _>>()?; + .collect::>, _>>()?; let init_str = initializer_expr.syntax().text().to_string(); let init_in_paren = format!("({})", &init_str); @@ -117,16 +123,20 @@ pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext) -> O target, move |builder| { builder.delete(delete_range); - for (reference, should_wrap) in usages.references.values().flatten().zip(wrap_in_parens) - { - let replacement = - if should_wrap { init_in_paren.clone() } else { init_str.clone() }; - match reference.kind { - ReferenceKind::FieldShorthandForLocal => { - mark::hit!(inline_field_shorthand); - builder.insert(reference.range.end(), format!(": {}", replacement)) + for (file_id, references) in usages.references { + let root = ctx.sema.parse(file_id); + for (&should_wrap, reference) in wrap_in_parens[&file_id].iter().zip(references) { + let replacement = + if should_wrap { init_in_paren.clone() } else { init_str.clone() }; + match &reference.as_name_ref(root.syntax()) { + Some(name_ref) + if ast::RecordExprField::for_field_name(name_ref).is_some() => + { + mark::hit!(inline_field_shorthand); + builder.insert(reference.range.end(), format!(": {}", replacement)); + } + _ => builder.replace(reference.range, replacement), } - _ => builder.replace(reference.range, replacement), } } }, diff --git a/crates/ide/src/call_hierarchy.rs b/crates/ide/src/call_hierarchy.rs index b10a0a78b..b848945d7 100644 --- a/crates/ide/src/call_hierarchy.rs +++ b/crates/ide/src/call_hierarchy.rs @@ -47,11 +47,11 @@ pub(crate) fn incoming_calls(db: &RootDatabase, position: FilePosition) -> Optio let mut calls = CallLocations::default(); - for (&file_id, references) in refs.references().iter() { + for (file_id, references) in refs.references { let file = sema.parse(file_id); let file = file.syntax(); - for reference in references { - let token = file.token_at_offset(reference.range.start()).next()?; + for (r_range, _) in references { + let token = file.token_at_offset(r_range.start()).next()?; let token = sema.descend_into_macros(token); let syntax = token.parent(); @@ -61,7 +61,7 @@ pub(crate) fn incoming_calls(db: &RootDatabase, position: FilePosition) -> Optio let def = sema.to_def(&fn_)?; def.try_to_nav(sema.db) }) { - let relative_range = reference.range; + let relative_range = r_range; calls.add(&nav, relative_range); } } diff --git a/crates/ide/src/lib.rs b/crates/ide/src/lib.rs index 989e94a31..a245d9341 100644 --- a/crates/ide/src/lib.rs +++ b/crates/ide/src/lib.rs @@ -73,7 +73,7 @@ pub use crate::{ inlay_hints::{InlayHint, InlayHintsConfig, InlayKind}, markup::Markup, prime_caches::PrimeCachesProgress, - references::{rename::RenameError, Declaration, ReferenceSearchResult}, + references::{rename::RenameError, ReferenceSearchResult}, runnables::{Runnable, RunnableKind, TestId}, syntax_highlighting::{ tags::{Highlight, HlMod, HlMods, HlPunct, HlTag}, @@ -94,7 +94,7 @@ pub use ide_db::{ call_info::CallInfo, label::Label, line_index::{LineCol, LineIndex}, - search::{FileReference, ReferenceAccess, ReferenceKind, SearchScope}, + search::{FileReference, ReferenceAccess, SearchScope}, source_change::{FileSystemEdit, SourceChange}, symbol_index::Query, RootDatabase, diff --git a/crates/ide/src/references.rs b/crates/ide/src/references.rs index e8737dcfa..f96fac9c1 100644 --- a/crates/ide/src/references.rs +++ b/crates/ide/src/references.rs @@ -11,13 +11,14 @@ pub(crate) mod rename; -use either::Either; use hir::Semantics; use ide_db::{ + base_db::FileId, defs::{Definition, NameClass, NameRefClass}, - search::{FileReference, ReferenceAccess, ReferenceKind, SearchScope, UsageSearchResult}, + search::{ReferenceAccess, SearchScope}, RootDatabase, }; +use rustc_hash::FxHashMap; use syntax::{ algo::find_node_at_offset, ast::{self, NameOwner}, @@ -28,32 +29,14 @@ use crate::{display::TryToNav, FilePosition, NavigationTarget}; #[derive(Debug, Clone)] pub struct ReferenceSearchResult { - declaration: Declaration, - references: UsageSearchResult, + pub declaration: Declaration, + pub references: FxHashMap)>>, } #[derive(Debug, Clone)] pub struct Declaration { - nav: NavigationTarget, - kind: ReferenceKind, - access: Option, -} - -impl ReferenceSearchResult { - pub fn references(self) -> UsageSearchResult { - self.references - } - - pub fn references_with_declaration(mut self) -> UsageSearchResult { - let decl_ref = FileReference { - range: self.declaration.nav.focus_or_full_range(), - kind: self.declaration.kind, - access: self.declaration.access, - }; - let file_id = self.declaration.nav.file_id; - self.references.references.entry(file_id).or_default().push(decl_ref); - self.references - } + pub nav: NavigationTarget, + pub access: Option, } pub(crate) fn find_all_refs( @@ -64,83 +47,76 @@ pub(crate) fn find_all_refs( let _p = profile::span("find_all_refs"); let syntax = sema.parse(position.file_id).syntax().clone(); - let (opt_name, search_kind) = if let Some(name) = + let (opt_name, ctor_filter): (_, Option bool>) = if let Some(name) = get_struct_def_name_for_struct_literal_search(&sema, &syntax, position) { - (Some(name), ReferenceKind::StructLiteral) + ( + Some(name), + Some(|name_ref| is_record_lit_name_ref(name_ref) || is_call_expr_name_ref(name_ref)), + ) } else if let Some(name) = get_enum_def_name_for_struct_literal_search(&sema, &syntax, position) { - (Some(name), ReferenceKind::EnumLiteral) + (Some(name), Some(is_enum_lit_name_ref)) } else { - ( - sema.find_node_at_offset_with_descend::(&syntax, position.offset), - ReferenceKind::Other, - ) + (sema.find_node_at_offset_with_descend::(&syntax, position.offset), None) }; - let def = find_name(&sema, &syntax, position, opt_name)?; + let def = find_def(&sema, &syntax, position, opt_name)?; let mut usages = def.usages(sema).set_scope(search_scope).all(); - usages - .references - .values_mut() - .for_each(|it| it.retain(|r| search_kind == ReferenceKind::Other || search_kind == r.kind)); - usages.references.retain(|_, it| !it.is_empty()); - + if let Some(ctor_filter) = ctor_filter { + // filter for constructor-literals + usages.references.iter_mut().for_each(|(&file_id, it)| { + let root = sema.parse(file_id); + let root = root.syntax(); + it.retain(|reference| { + reference.as_name_ref(root).map_or(false, |name_ref| ctor_filter(&name_ref)) + }) + }); + usages.references.retain(|_, it| !it.is_empty()); + } let nav = def.try_to_nav(sema.db)?; let decl_range = nav.focus_or_full_range(); - let mut kind = ReferenceKind::Other; - if let Definition::Local(local) = def { - match local.source(sema.db).value { - Either::Left(pat) => { - if matches!( - pat.syntax().parent().and_then(ast::RecordPatField::cast), - Some(pat_field) if pat_field.name_ref().is_none() - ) { - kind = ReferenceKind::FieldShorthandForLocal; - } - } - Either::Right(_) => kind = ReferenceKind::SelfParam, - } - } else if matches!( - def, - Definition::GenericParam(hir::GenericParam::LifetimeParam(_)) | Definition::Label(_) - ) { - kind = ReferenceKind::Lifetime; - }; - - let declaration = Declaration { nav, kind, access: decl_access(&def, &syntax, decl_range) }; + let declaration = Declaration { nav, access: decl_access(&def, &syntax, decl_range) }; + let references = usages + .into_iter() + .map(|(file_id, refs)| { + (file_id, refs.into_iter().map(|file_ref| (file_ref.range, file_ref.access)).collect()) + }) + .collect(); - Some(ReferenceSearchResult { declaration, references: usages }) + Some(ReferenceSearchResult { declaration, references }) } -fn find_name( +fn find_def( sema: &Semantics, syntax: &SyntaxNode, position: FilePosition, opt_name: Option, ) -> Option { - let def = if let Some(name) = opt_name { - NameClass::classify(sema, &name)?.referenced_or_defined(sema.db) + if let Some(name) = opt_name { + let class = NameClass::classify(sema, &name)?; + Some(class.referenced_or_defined(sema.db)) } else if let Some(lifetime) = sema.find_node_at_offset_with_descend::(&syntax, position.offset) { - if let Some(def) = + let def = if let Some(def) = NameRefClass::classify_lifetime(sema, &lifetime).map(|class| class.referenced(sema.db)) { def } else { NameClass::classify_lifetime(sema, &lifetime)?.referenced_or_defined(sema.db) - } + }; + Some(def) } else if let Some(name_ref) = sema.find_node_at_offset_with_descend::(&syntax, position.offset) { - NameRefClass::classify(sema, &name_ref)?.referenced(sema.db) + let class = NameRefClass::classify(sema, &name_ref)?; + Some(class.referenced(sema.db)) } else { - return None; - }; - Some(def) + None + } } fn decl_access(def: &Definition, syntax: &SyntaxNode, range: TextRange) -> Option { @@ -216,6 +192,43 @@ fn get_enum_def_name_for_struct_literal_search( None } +fn is_call_expr_name_ref(name_ref: &ast::NameRef) -> bool { + name_ref + .syntax() + .ancestors() + .find_map(ast::CallExpr::cast) + .and_then(|c| match c.expr()? { + ast::Expr::PathExpr(p) => { + Some(p.path()?.segment()?.name_ref().as_ref() == Some(name_ref)) + } + _ => None, + }) + .unwrap_or(false) +} + +fn is_record_lit_name_ref(name_ref: &ast::NameRef) -> bool { + name_ref + .syntax() + .ancestors() + .find_map(ast::RecordExpr::cast) + .and_then(|l| l.path()) + .and_then(|p| p.segment()) + .map(|p| p.name_ref().as_ref() == Some(name_ref)) + .unwrap_or(false) +} + +fn is_enum_lit_name_ref(name_ref: &ast::NameRef) -> bool { + name_ref + .syntax() + .ancestors() + .find_map(ast::PathExpr::cast) + .and_then(|p| p.path()) + .and_then(|p| p.qualifier()) + .and_then(|p| p.segment()) + .map(|p| p.name_ref().as_ref() == Some(name_ref)) + .unwrap_or(false) +} + #[cfg(test)] mod tests { use expect_test::{expect, Expect}; @@ -240,9 +253,9 @@ fn main() { } "#, expect![[r#" - Foo Struct FileId(0) 0..26 7..10 Other + Foo Struct FileId(0) 0..26 7..10 - FileId(0) 101..104 StructLiteral + FileId(0) 101..104 "#]], ); } @@ -258,10 +271,10 @@ struct Foo$0 {} } "#, expect![[r#" - Foo Struct FileId(0) 0..13 7..10 Other + Foo Struct FileId(0) 0..13 7..10 - FileId(0) 41..44 Other - FileId(0) 54..57 StructLiteral + FileId(0) 41..44 + FileId(0) 54..57 "#]], ); } @@ -277,9 +290,9 @@ struct Foo $0{} } "#, expect![[r#" - Foo Struct FileId(0) 0..16 7..10 Other + Foo Struct FileId(0) 0..16 7..10 - FileId(0) 64..67 StructLiteral + FileId(0) 64..67 "#]], ); } @@ -296,9 +309,9 @@ fn main() { } "#, expect![[r#" - Foo Struct FileId(0) 0..16 7..10 Other + Foo Struct FileId(0) 0..16 7..10 - FileId(0) 54..57 StructLiteral + FileId(0) 54..57 "#]], ); } @@ -317,9 +330,9 @@ fn main() { } "#, expect![[r#" - Foo Enum FileId(0) 0..26 5..8 Other + Foo Enum FileId(0) 0..26 5..8 - FileId(0) 63..66 EnumLiteral + FileId(0) 63..66 "#]], ); } @@ -338,10 +351,10 @@ fn main() { } "#, expect![[r#" - Foo Enum FileId(0) 0..26 5..8 Other + Foo Enum FileId(0) 0..26 5..8 - FileId(0) 50..53 Other - FileId(0) 63..66 EnumLiteral + FileId(0) 50..53 + FileId(0) 63..66 "#]], ); } @@ -360,9 +373,9 @@ fn main() { } "#, expect![[r#" - Foo Enum FileId(0) 0..32 5..8 Other + Foo Enum FileId(0) 0..32 5..8 - FileId(0) 73..76 EnumLiteral + FileId(0) 73..76 "#]], ); } @@ -381,9 +394,9 @@ fn main() { } "#, expect![[r#" - Foo Enum FileId(0) 0..33 5..8 Other + Foo Enum FileId(0) 0..33 5..8 - FileId(0) 70..73 EnumLiteral + FileId(0) 70..73 "#]], ); } @@ -404,12 +417,12 @@ fn main() { i = 5; }"#, expect![[r#" - i Local FileId(0) 20..25 24..25 Other Write + i Local FileId(0) 20..25 24..25 Write - FileId(0) 50..51 Other Write - FileId(0) 54..55 Other Read - FileId(0) 76..77 Other Write - FileId(0) 94..95 Other Write + FileId(0) 50..51 Write + FileId(0) 54..55 Read + FileId(0) 76..77 Write + FileId(0) 94..95 Write "#]], ); } @@ -428,10 +441,10 @@ fn bar() { } "#, expect![[r#" - spam Local FileId(0) 19..23 19..23 Other + spam Local FileId(0) 19..23 19..23 - FileId(0) 34..38 Other Read - FileId(0) 41..45 Other Read + FileId(0) 34..38 Read + FileId(0) 41..45 Read "#]], ); } @@ -443,9 +456,9 @@ fn bar() { fn foo(i : u32) -> u32 { i$0 } "#, expect![[r#" - i ValueParam FileId(0) 7..8 7..8 Other + i ValueParam FileId(0) 7..8 7..8 - FileId(0) 25..26 Other Read + FileId(0) 25..26 Read "#]], ); } @@ -457,9 +470,9 @@ fn foo(i : u32) -> u32 { i$0 } fn foo(i$0 : u32) -> u32 { i } "#, expect![[r#" - i ValueParam FileId(0) 7..8 7..8 Other + i ValueParam FileId(0) 7..8 7..8 - FileId(0) 25..26 Other Read + FileId(0) 25..26 Read "#]], ); } @@ -478,9 +491,9 @@ fn main(s: Foo) { } "#, expect![[r#" - spam Field FileId(0) 17..30 21..25 Other + spam Field FileId(0) 17..30 21..25 - FileId(0) 67..71 Other Read + FileId(0) 67..71 Read "#]], ); } @@ -495,7 +508,7 @@ impl Foo { } "#, expect![[r#" - f Function FileId(0) 27..43 30..31 Other + f Function FileId(0) 27..43 30..31 "#]], ); @@ -512,7 +525,7 @@ enum Foo { } "#, expect![[r#" - B Variant FileId(0) 22..23 22..23 Other + B Variant FileId(0) 22..23 22..23 "#]], ); @@ -529,7 +542,7 @@ enum Foo { } "#, expect![[r#" - field Field FileId(0) 26..35 26..31 Other + field Field FileId(0) 26..35 26..31 "#]], ); @@ -570,10 +583,10 @@ fn f() { } "#, expect![[r#" - Foo Struct FileId(1) 17..51 28..31 Other + Foo Struct FileId(1) 17..51 28..31 - FileId(0) 53..56 StructLiteral - FileId(2) 79..82 StructLiteral + FileId(0) 53..56 + FileId(2) 79..82 "#]], ); } @@ -600,9 +613,9 @@ pub struct Foo { } "#, expect![[r#" - foo Module FileId(1) 0..35 Other + foo Module FileId(1) 0..35 - FileId(0) 14..17 Other + FileId(0) 14..17 "#]], ); } @@ -628,10 +641,10 @@ pub(super) struct Foo$0 { } "#, expect![[r#" - Foo Struct FileId(2) 0..41 18..21 Other + Foo Struct FileId(2) 0..41 18..21 - FileId(1) 20..23 Other - FileId(1) 47..50 StructLiteral + FileId(1) 20..23 + FileId(1) 47..50 "#]], ); } @@ -656,10 +669,10 @@ pub(super) struct Foo$0 { code, None, expect![[r#" - quux Function FileId(0) 19..35 26..30 Other + quux Function FileId(0) 19..35 26..30 - FileId(1) 16..20 StructLiteral - FileId(2) 16..20 StructLiteral + FileId(1) 16..20 + FileId(2) 16..20 "#]], ); @@ -667,9 +680,9 @@ pub(super) struct Foo$0 { code, Some(SearchScope::single_file(FileId(2))), expect![[r#" - quux Function FileId(0) 19..35 26..30 Other + quux Function FileId(0) 19..35 26..30 - FileId(2) 16..20 StructLiteral + FileId(2) 16..20 "#]], ); } @@ -687,10 +700,10 @@ fn foo() { } "#, expect![[r#" - m1 Macro FileId(0) 0..46 29..31 Other + m1 Macro FileId(0) 0..46 29..31 - FileId(0) 63..65 StructLiteral - FileId(0) 73..75 StructLiteral + FileId(0) 63..65 + FileId(0) 73..75 "#]], ); } @@ -705,10 +718,10 @@ fn foo() { } "#, expect![[r#" - i Local FileId(0) 19..24 23..24 Other Write + i Local FileId(0) 19..24 23..24 Write - FileId(0) 34..35 Other Write - FileId(0) 38..39 Other Read + FileId(0) 34..35 Write + FileId(0) 38..39 Read "#]], ); } @@ -727,10 +740,10 @@ fn foo() { } "#, expect![[r#" - f Field FileId(0) 15..21 15..16 Other + f Field FileId(0) 15..21 15..16 - FileId(0) 55..56 RecordFieldExprOrPat Read - FileId(0) 68..69 Other Write + FileId(0) 55..56 Read + FileId(0) 68..69 Write "#]], ); } @@ -745,9 +758,9 @@ fn foo() { } "#, expect![[r#" - i Local FileId(0) 19..20 19..20 Other + i Local FileId(0) 19..20 19..20 - FileId(0) 26..27 Other Write + FileId(0) 26..27 Write "#]], ); } @@ -769,9 +782,9 @@ fn main() { } "#, expect![[r#" - new Function FileId(0) 54..81 61..64 Other + new Function FileId(0) 54..81 61..64 - FileId(0) 126..129 StructLiteral + FileId(0) 126..129 "#]], ); } @@ -791,10 +804,10 @@ use crate::f; fn g() { f(); } "#, expect![[r#" - f Function FileId(0) 22..31 25..26 Other + f Function FileId(0) 22..31 25..26 - FileId(1) 11..12 Other - FileId(1) 24..25 StructLiteral + FileId(1) 11..12 + FileId(1) 24..25 "#]], ); } @@ -814,9 +827,9 @@ fn f(s: S) { } "#, expect![[r#" - field Field FileId(0) 15..24 15..20 Other + field Field FileId(0) 15..24 15..20 - FileId(0) 68..73 FieldShorthandForField Read + FileId(0) 68..73 Read "#]], ); } @@ -838,9 +851,9 @@ fn f(e: En) { } "#, expect![[r#" - field Field FileId(0) 32..41 32..37 Other + field Field FileId(0) 32..41 32..37 - FileId(0) 102..107 FieldShorthandForField Read + FileId(0) 102..107 Read "#]], ); } @@ -862,9 +875,9 @@ fn f() -> m::En { } "#, expect![[r#" - field Field FileId(0) 56..65 56..61 Other + field Field FileId(0) 56..65 56..61 - FileId(0) 125..130 RecordFieldExprOrPat Read + FileId(0) 125..130 Read "#]], ); } @@ -887,10 +900,10 @@ impl Foo { } "#, expect![[r#" - self SelfParam FileId(0) 47..51 47..51 SelfParam + self SelfParam FileId(0) 47..51 47..51 - FileId(0) 71..75 Other Read - FileId(0) 152..156 Other Read + FileId(0) 71..75 Read + FileId(0) 152..156 Read "#]], ); } @@ -908,9 +921,9 @@ impl Foo { } "#, expect![[r#" - self SelfParam FileId(0) 47..51 47..51 SelfParam + self SelfParam FileId(0) 47..51 47..51 - FileId(0) 63..67 Other Read + FileId(0) 63..67 Read "#]], ); } @@ -926,7 +939,7 @@ impl Foo { let mut actual = String::new(); { let decl = refs.declaration; - format_to!(actual, "{} {:?}", decl.nav.debug_render(), decl.kind); + format_to!(actual, "{}", decl.nav.debug_render()); if let Some(access) = decl.access { format_to!(actual, " {:?}", access) } @@ -934,9 +947,9 @@ impl Foo { } for (file_id, references) in refs.references { - for r in references { - format_to!(actual, "{:?} {:?} {:?}", file_id, r.range, r.kind); - if let Some(access) = r.access { + for (range, access) in references { + format_to!(actual, "{:?} {:?}", file_id, range); + if let Some(access) = access { format_to!(actual, " {:?}", access); } actual += "\n"; @@ -957,13 +970,13 @@ fn foo<'a, 'b: 'a>(x: &'a$0 ()) -> &'a () where &'a (): Foo<'a> { } "#, expect![[r#" - 'a LifetimeParam FileId(0) 55..57 55..57 Lifetime + 'a LifetimeParam FileId(0) 55..57 55..57 - FileId(0) 63..65 Lifetime - FileId(0) 71..73 Lifetime - FileId(0) 82..84 Lifetime - FileId(0) 95..97 Lifetime - FileId(0) 106..108 Lifetime + FileId(0) 63..65 + FileId(0) 71..73 + FileId(0) 82..84 + FileId(0) 95..97 + FileId(0) 106..108 "#]], ); } @@ -975,10 +988,10 @@ fn foo<'a, 'b: 'a>(x: &'a$0 ()) -> &'a () where &'a (): Foo<'a> { type Foo<'a, T> where T: 'a$0 = &'a T; "#, expect![[r#" - 'a LifetimeParam FileId(0) 9..11 9..11 Lifetime + 'a LifetimeParam FileId(0) 9..11 9..11 - FileId(0) 25..27 Lifetime - FileId(0) 31..33 Lifetime + FileId(0) 25..27 + FileId(0) 31..33 "#]], ); } @@ -997,11 +1010,11 @@ impl<'a> Foo<'a> for &'a () { } "#, expect![[r#" - 'a LifetimeParam FileId(0) 47..49 47..49 Lifetime + 'a LifetimeParam FileId(0) 47..49 47..49 - FileId(0) 55..57 Lifetime - FileId(0) 64..66 Lifetime - FileId(0) 89..91 Lifetime + FileId(0) 55..57 + FileId(0) 64..66 + FileId(0) 89..91 "#]], ); } @@ -1017,9 +1030,9 @@ fn main() { } "#, expect![[r#" - a Local FileId(0) 59..60 59..60 Other + a Local FileId(0) 59..60 59..60 - FileId(0) 80..81 Other Read + FileId(0) 80..81 Read "#]], ); } @@ -1035,9 +1048,9 @@ fn main() { } "#, expect![[r#" - a Local FileId(0) 59..60 59..60 Other + a Local FileId(0) 59..60 59..60 - FileId(0) 80..81 Other Read + FileId(0) 80..81 Read "#]], ); } @@ -1056,10 +1069,10 @@ fn foo<'a>() -> &'a () { } "#, expect![[r#" - 'a Label FileId(0) 29..32 29..31 Lifetime + 'a Label FileId(0) 29..32 29..31 - FileId(0) 80..82 Lifetime - FileId(0) 108..110 Lifetime + FileId(0) 80..82 + FileId(0) 108..110 "#]], ); } @@ -1073,9 +1086,9 @@ fn foo() -> usize { } "#, expect![[r#" - FOO ConstParam FileId(0) 7..23 13..16 Other + FOO ConstParam FileId(0) 7..23 13..16 - FileId(0) 42..45 Other + FileId(0) 42..45 "#]], ); } @@ -1089,9 +1102,9 @@ trait Foo { } "#, expect![[r#" - Self TypeParam FileId(0) 6..9 6..9 Other + Self TypeParam FileId(0) 6..9 6..9 - FileId(0) 26..30 Other + FileId(0) 26..30 "#]], ); } diff --git a/crates/ide/src/references/rename.rs b/crates/ide/src/references/rename.rs index ebb1ce7dd..64992c72d 100644 --- a/crates/ide/src/references/rename.rs +++ b/crates/ide/src/references/rename.rs @@ -4,9 +4,9 @@ use std::fmt::{self, Display}; use either::Either; use hir::{HasSource, InFile, Module, ModuleDef, ModuleSource, Semantics}; use ide_db::{ - base_db::{AnchoredPathBuf, FileId, FileRange}, + base_db::{AnchoredPathBuf, FileId}, defs::{Definition, NameClass, NameRefClass}, - search::FileReference, + search::{FileReference, NameLike}, RootDatabase, }; use stdx::never; @@ -17,10 +17,7 @@ use syntax::{ use test_utils::mark; use text_edit::TextEdit; -use crate::{ - display::TryToNav, FilePosition, FileSystemEdit, RangeInfo, ReferenceKind, SourceChange, - TextRange, -}; +use crate::{display::TryToNav, FilePosition, FileSystemEdit, RangeInfo, SourceChange, TextRange}; type RenameResult = Result; #[derive(Debug)] @@ -41,6 +38,8 @@ macro_rules! bail { ($($tokens:tt)*) => {return Err(format_err!($($tokens)*))} } +/// Prepares a rename. The sole job of this function is to return the TextRange of the thing that is +/// being targeted for a rename. pub(crate) fn prepare_rename( db: &RootDatabase, position: FilePosition, @@ -123,12 +122,6 @@ fn check_identifier(new_name: &str) -> RenameResult { } } -enum NameLike { - Name(ast::Name), - NameRef(ast::NameRef), - Lifetime(ast::Lifetime), -} - fn find_name_like( sema: &Semantics, syntax: &SyntaxNode, @@ -174,69 +167,96 @@ fn source_edit_from_references( sema: &Semantics, file_id: FileId, references: &[FileReference], + def: Definition, new_name: &str, ) -> (FileId, TextEdit) { + let root = sema.parse(file_id); let mut edit = TextEdit::builder(); for reference in references { - let mut replacement_text = String::new(); - let range = match reference.kind { - ReferenceKind::FieldShorthandForField => { - mark::hit!(test_rename_struct_field_for_shorthand); - replacement_text.push_str(new_name); - replacement_text.push_str(": "); - TextRange::new(reference.range.start(), reference.range.start()) - } - ReferenceKind::FieldShorthandForLocal => { - mark::hit!(test_rename_local_for_field_shorthand); - replacement_text.push_str(": "); - replacement_text.push_str(new_name); - TextRange::new(reference.range.end(), reference.range.end()) - } - ReferenceKind::RecordFieldExprOrPat => { - mark::hit!(test_rename_field_expr_pat); - replacement_text.push_str(new_name); - edit_text_range_for_record_field_expr_or_pat( - sema, - FileRange { file_id, range: reference.range }, - new_name, - ) - } - _ => { - replacement_text.push_str(new_name); - reference.range - } + let (range, replacement) = match &reference.name_from_syntax(root.syntax()) { + Some(NameLike::Name(_)) => (None, format!("{}", new_name)), + Some(NameLike::NameRef(name_ref)) => source_edit_from_name_ref(name_ref, new_name, def), + Some(NameLike::Lifetime(_)) => (None, format!("{}", new_name)), + None => (None, new_name.to_owned()), }; - edit.replace(range, replacement_text); + // FIXME: Some(range) will be incorrect when we are inside macros + edit.replace(range.unwrap_or(reference.range), replacement); } (file_id, edit.finish()) } -fn edit_text_range_for_record_field_expr_or_pat( - sema: &Semantics, - file_range: FileRange, +fn source_edit_from_name_ref( + name_ref: &ast::NameRef, new_name: &str, -) -> TextRange { - let source_file = sema.parse(file_range.file_id); - let file_syntax = source_file.syntax(); - let original_range = file_range.range; - - syntax::algo::find_node_at_range::(file_syntax, original_range) - .and_then(|field_expr| match field_expr.expr().and_then(|e| e.name_ref()) { - Some(name) if &name.to_string() == new_name => Some(field_expr.syntax().text_range()), - _ => None, - }) - .or_else(|| { - syntax::algo::find_node_at_range::(file_syntax, original_range) - .and_then(|field_pat| match field_pat.pat() { - Some(ast::Pat::IdentPat(pat)) - if pat.name().map(|n| n.to_string()).as_deref() == Some(new_name) => - { - Some(field_pat.syntax().text_range()) + def: Definition, +) -> (Option, String) { + if let Some(record_field) = ast::RecordExprField::for_name_ref(name_ref) { + let rcf_name_ref = record_field.name_ref(); + let rcf_expr = record_field.expr(); + match (rcf_name_ref, rcf_expr.and_then(|it| it.name_ref())) { + // field: init-expr, check if we can use a field init shorthand + (Some(field_name), Some(init)) => { + if field_name == *name_ref { + if init.text() == new_name { + mark::hit!(test_rename_field_put_init_shorthand); + // same names, we can use a shorthand here instead + // we do not want to erase attributes hence this range start + let s = field_name.syntax().text_range().start(); + let e = record_field.syntax().text_range().end(); + return (Some(TextRange::new(s, e)), format!("{}", new_name)); } - _ => None, - }) - }) - .unwrap_or(original_range) + } else if init == *name_ref { + if field_name.text() == new_name { + mark::hit!(test_rename_local_put_init_shorthand); + // same names, we can use a shorthand here instead + // we do not want to erase attributes hence this range start + let s = field_name.syntax().text_range().start(); + let e = record_field.syntax().text_range().end(); + return (Some(TextRange::new(s, e)), format!("{}", new_name)); + } + } + } + // init shorthand + (None, Some(_)) => { + // FIXME: instead of splitting the shorthand, recursively trigger a rename of the + // other name https://github.com/rust-analyzer/rust-analyzer/issues/6547 + match def { + Definition::Field(_) => { + mark::hit!(test_rename_field_in_field_shorthand); + let s = name_ref.syntax().text_range().start(); + return (Some(TextRange::empty(s)), format!("{}: ", new_name)); + } + Definition::Local(_) => { + mark::hit!(test_rename_local_in_field_shorthand); + let s = name_ref.syntax().text_range().end(); + return (Some(TextRange::empty(s)), format!(": {}", new_name)); + } + _ => {} + } + } + _ => {} + } + } + if let Some(record_field) = ast::RecordPatField::for_field_name_ref(name_ref) { + let rcf_name_ref = record_field.name_ref(); + let rcf_pat = record_field.pat(); + match (rcf_name_ref, rcf_pat) { + // field: rename + (Some(field_name), Some(ast::Pat::IdentPat(pat))) if field_name == *name_ref => { + // field name is being renamed + if pat.name().map_or(false, |it| it.text() == new_name) { + mark::hit!(test_rename_field_put_init_shorthand_pat); + // same names, we can use a shorthand here instead + // we do not want to erase attributes hence this range start + let s = field_name.syntax().text_range().start(); + let e = record_field.syntax().text_range().end(); + return (Some(TextRange::new(s, e)), format!("{}", new_name)); + } + } + _ => {} + } + } + (None, format!("{}", new_name)) } fn rename_mod( @@ -277,7 +297,7 @@ fn rename_mod( let def = Definition::ModuleDef(ModuleDef::Module(module)); let usages = def.usages(sema).all(); let ref_edits = usages.iter().map(|(&file_id, references)| { - source_edit_from_references(sema, file_id, references, new_name) + source_edit_from_references(sema, file_id, references, def, new_name) }); source_change.extend(ref_edits); @@ -346,7 +366,7 @@ fn rename_to_self(sema: &Semantics, local: hir::Local) -> RenameRe let usages = def.usages(sema).all(); let mut source_change = SourceChange::default(); source_change.extend(usages.iter().map(|(&file_id, references)| { - source_edit_from_references(sema, file_id, references, "self") + source_edit_from_references(sema, file_id, references, def, "self") })); source_change.insert_source_edit( file_id.original_file(sema.db), @@ -403,7 +423,7 @@ fn rename_self_to_param( let mut source_change = SourceChange::default(); source_change.insert_source_edit(file_id.original_file(sema.db), edit); source_change.extend(usages.iter().map(|(&file_id, references)| { - source_edit_from_references(sema, file_id, &references, new_name) + source_edit_from_references(sema, file_id, &references, def, new_name) })); Ok(source_change) } @@ -457,7 +477,7 @@ fn rename_reference( } let mut source_change = SourceChange::default(); source_change.extend(usages.iter().map(|(&file_id, references)| { - source_edit_from_references(sema, file_id, &references, new_name) + source_edit_from_references(sema, file_id, &references, def, new_name) })); let (file_id, edit) = source_edit_from_def(sema, def, new_name)?; @@ -545,10 +565,8 @@ mod tests { fn check_expect(new_name: &str, ra_fixture: &str, expect: Expect) { let (analysis, position) = fixture::position(ra_fixture); - let source_change = analysis - .rename(position, new_name) - .unwrap() - .expect("Expect returned RangeInfo to be Some, but was None"); + let source_change = + analysis.rename(position, new_name).unwrap().expect("Expect returned a RenameError"); expect.assert_debug_eq(&source_change) } @@ -792,8 +810,8 @@ impl Foo { } #[test] - fn test_rename_struct_field_for_shorthand() { - mark::check!(test_rename_struct_field_for_shorthand); + fn test_rename_field_in_field_shorthand() { + mark::check!(test_rename_field_in_field_shorthand); check( "j", r#" @@ -818,8 +836,8 @@ impl Foo { } #[test] - fn test_rename_local_for_field_shorthand() { - mark::check!(test_rename_local_for_field_shorthand); + fn test_rename_local_in_field_shorthand() { + mark::check!(test_rename_local_in_field_shorthand); check( "j", r#" @@ -1417,8 +1435,8 @@ impl Foo { } #[test] - fn test_initializer_use_field_init_shorthand() { - mark::check!(test_rename_field_expr_pat); + fn test_rename_field_put_init_shorthand() { + mark::check!(test_rename_field_put_init_shorthand); check( "bar", r#" @@ -1438,8 +1456,31 @@ fn foo(bar: i32) -> Foo { ); } + #[test] + fn test_rename_local_put_init_shorthand() { + mark::check!(test_rename_local_put_init_shorthand); + check( + "i", + r#" +struct Foo { i: i32 } + +fn foo(bar$0: i32) -> Foo { + Foo { i: bar } +} +"#, + r#" +struct Foo { i: i32 } + +fn foo(i: i32) -> Foo { + Foo { i } +} +"#, + ); + } + #[test] fn test_struct_field_destructure_into_shorthand() { + mark::check!(test_rename_field_put_init_shorthand_pat); check( "baz", r#" diff --git a/crates/ide_db/src/search.rs b/crates/ide_db/src/search.rs index b9ba0aed5..d0aed26f7 100644 --- a/crates/ide_db/src/search.rs +++ b/crates/ide_db/src/search.rs @@ -10,7 +10,9 @@ use base_db::{FileId, FileRange, SourceDatabaseExt}; use hir::{DefWithBody, HasSource, Module, ModuleSource, Semantics, Visibility}; use once_cell::unsync::Lazy; use rustc_hash::FxHashMap; -use syntax::{ast, match_ast, AstNode, TextRange, TextSize}; +use syntax::{ + ast, match_ast, AstNode, NodeOrToken, SyntaxElement, SyntaxNode, TextRange, TextSize, +}; use crate::defs::NameClass; use crate::{ @@ -18,6 +20,13 @@ use crate::{ RootDatabase, }; +#[derive(Debug, Clone)] +pub enum NameKind { + Name, + NameRef, + Lifetime, +} + #[derive(Debug, Default, Clone)] pub struct UsageSearchResult { pub references: FxHashMap>, @@ -52,23 +61,53 @@ impl IntoIterator for UsageSearchResult { } } +#[derive(Debug, Clone)] +pub enum NameLike { + NameRef(ast::NameRef), + Name(ast::Name), + Lifetime(ast::Lifetime), +} + +mod __ { + use super::{ + ast::{Lifetime, Name, NameRef}, + NameLike, + }; + stdx::impl_from!(NameRef, Name, Lifetime for NameLike); +} + #[derive(Debug, Clone)] pub struct FileReference { pub range: TextRange, - pub kind: ReferenceKind, + pub name: NameKind, pub access: Option, } -#[derive(Debug, Clone, PartialEq)] -pub enum ReferenceKind { - FieldShorthandForField, - FieldShorthandForLocal, - StructLiteral, - RecordFieldExprOrPat, - SelfParam, - EnumLiteral, - Lifetime, - Other, +impl FileReference { + pub fn name_from_syntax(&self, root: &SyntaxNode) -> Option { + let node = node_or_parent(root.covering_element(self.range)); + match self.name { + NameKind::Name => ast::Name::cast(node).map(Into::into), + NameKind::NameRef => ast::NameRef::cast(node).map(Into::into), + NameKind::Lifetime => ast::Lifetime::cast(node).map(Into::into), + } + } + + pub fn as_name_ref(&self, root: &SyntaxNode) -> Option { + match self.name { + NameKind::NameRef => { + ast::NameRef::cast(node_or_parent(root.covering_element(self.range))) + } + _ => None, + } + } +} + +fn node_or_parent(ele: SyntaxElement) -> SyntaxNode { + match ele { + NodeOrToken::Node(node) => node, + NodeOrToken::Token(token) => token.parent(), + } } #[derive(Debug, Copy, Clone, PartialEq)] @@ -369,8 +408,7 @@ impl<'a> FindUsages<'a> { match NameRefClass::classify_lifetime(self.sema, lifetime) { Some(NameRefClass::Definition(def)) if &def == self.def => { let FileRange { file_id, range } = self.sema.original_range(lifetime.syntax()); - let reference = - FileReference { range, kind: ReferenceKind::Lifetime, access: None }; + let reference = FileReference { range, name: NameKind::Lifetime, access: None }; sink(file_id, reference) } _ => false, // not a usage @@ -384,19 +422,12 @@ impl<'a> FindUsages<'a> { ) -> bool { match NameRefClass::classify(self.sema, &name_ref) { Some(NameRefClass::Definition(def)) if &def == self.def => { - let kind = if is_record_field_expr_or_pat(&name_ref) { - ReferenceKind::RecordFieldExprOrPat - } else if is_record_lit_name_ref(&name_ref) || is_call_expr_name_ref(&name_ref) { - ReferenceKind::StructLiteral - } else if is_enum_lit_name_ref(&name_ref) { - ReferenceKind::EnumLiteral - } else { - ReferenceKind::Other - }; - let FileRange { file_id, range } = self.sema.original_range(name_ref.syntax()); - let reference = - FileReference { range, kind, access: reference_access(&def, &name_ref) }; + let reference = FileReference { + range, + name: NameKind::NameRef, + access: reference_access(&def, &name_ref), + }; sink(file_id, reference) } Some(NameRefClass::FieldShorthand { local_ref: local, field_ref: field }) => { @@ -404,12 +435,12 @@ impl<'a> FindUsages<'a> { let reference = match self.def { Definition::Field(_) if &field == self.def => FileReference { range, - kind: ReferenceKind::FieldShorthandForField, + name: NameKind::NameRef, access: reference_access(&field, &name_ref), }, Definition::Local(l) if &local == l => FileReference { range, - kind: ReferenceKind::FieldShorthandForLocal, + name: NameKind::NameRef, access: reference_access(&Definition::Local(local), &name_ref), }, _ => return false, // not a usage @@ -433,7 +464,7 @@ impl<'a> FindUsages<'a> { let FileRange { file_id, range } = self.sema.original_range(name.syntax()); let reference = FileReference { range, - kind: ReferenceKind::FieldShorthandForField, + name: NameKind::Name, // FIXME: mutable patterns should have `Write` access access: Some(ReferenceAccess::Read), }; @@ -473,54 +504,3 @@ fn reference_access(def: &Definition, name_ref: &ast::NameRef) -> Option bool { - name_ref - .syntax() - .ancestors() - .find_map(ast::CallExpr::cast) - .and_then(|c| match c.expr()? { - ast::Expr::PathExpr(p) => { - Some(p.path()?.segment()?.name_ref().as_ref() == Some(name_ref)) - } - _ => None, - }) - .unwrap_or(false) -} - -fn is_record_lit_name_ref(name_ref: &ast::NameRef) -> bool { - name_ref - .syntax() - .ancestors() - .find_map(ast::RecordExpr::cast) - .and_then(|l| l.path()) - .and_then(|p| p.segment()) - .map(|p| p.name_ref().as_ref() == Some(name_ref)) - .unwrap_or(false) -} - -fn is_record_field_expr_or_pat(name_ref: &ast::NameRef) -> bool { - if let Some(parent) = name_ref.syntax().parent() { - match_ast! { - match parent { - ast::RecordExprField(it) => true, - ast::RecordPatField(_it) => true, - _ => false, - } - } - } else { - false - } -} - -fn is_enum_lit_name_ref(name_ref: &ast::NameRef) -> bool { - name_ref - .syntax() - .ancestors() - .find_map(ast::PathExpr::cast) - .and_then(|p| p.path()) - .and_then(|p| p.qualifier()) - .and_then(|p| p.segment()) - .map(|p| p.name_ref().as_ref() == Some(name_ref)) - .unwrap_or(false) -} diff --git a/crates/rust-analyzer/src/handlers.rs b/crates/rust-analyzer/src/handlers.rs index 5a6501216..8898c12e3 100644 --- a/crates/rust-analyzer/src/handlers.rs +++ b/crates/rust-analyzer/src/handlers.rs @@ -827,18 +827,23 @@ pub(crate) fn handle_references( Some(refs) => refs, }; - let locations = if params.context.include_declaration { - refs.references_with_declaration() - .file_ranges() - .filter_map(|frange| to_proto::location(&snap, frange).ok()) - .collect() + let decl = if params.context.include_declaration { + Some(FileRange { + file_id: refs.declaration.nav.file_id, + range: refs.declaration.nav.focus_or_full_range(), + }) } else { - // Only iterate over the references if include_declaration was false - refs.references() - .file_ranges() - .filter_map(|frange| to_proto::location(&snap, frange).ok()) - .collect() + None }; + let locations = refs + .references + .into_iter() + .flat_map(|(file_id, refs)| { + refs.into_iter().map(move |(range, _)| FileRange { file_id, range }) + }) + .chain(decl) + .filter_map(|frange| to_proto::location(&snap, frange).ok()) + .collect(); Ok(Some(locations)) } @@ -1214,8 +1219,11 @@ pub(crate) fn handle_code_lens_resolve( .find_all_refs(position, None) .unwrap_or(None) .map(|r| { - r.references() - .file_ranges() + r.references + .into_iter() + .flat_map(|(file_id, ranges)| { + ranges.into_iter().map(move |(range, _)| FileRange { file_id, range }) + }) .filter_map(|frange| to_proto::location(&snap, frange).ok()) .collect_vec() }) @@ -1259,17 +1267,26 @@ pub(crate) fn handle_document_highlight( Some(refs) => refs, }; + let decl = if refs.declaration.nav.file_id == position.file_id { + Some(DocumentHighlight { + range: to_proto::range(&line_index, refs.declaration.nav.focus_or_full_range()), + kind: refs.declaration.access.map(to_proto::document_highlight_kind), + }) + } else { + None + }; + let res = refs - .references_with_declaration() .references .get(&position.file_id) .map(|file_refs| { file_refs .into_iter() - .map(|r| DocumentHighlight { - range: to_proto::range(&line_index, r.range), - kind: r.access.map(to_proto::document_highlight_kind), + .map(|&(range, access)| DocumentHighlight { + range: to_proto::range(&line_index, range), + kind: access.map(to_proto::document_highlight_kind), }) + .chain(decl) .collect() }) .unwrap_or_default(); diff --git a/crates/syntax/src/ast/node_ext.rs b/crates/syntax/src/ast/node_ext.rs index 5c8cf900f..b105cb0e0 100644 --- a/crates/syntax/src/ast/node_ext.rs +++ b/crates/syntax/src/ast/node_ext.rs @@ -274,10 +274,7 @@ impl ast::Struct { impl ast::RecordExprField { pub fn for_field_name(field_name: &ast::NameRef) -> Option { - let candidate = - field_name.syntax().parent().and_then(ast::RecordExprField::cast).or_else(|| { - field_name.syntax().ancestors().nth(4).and_then(ast::RecordExprField::cast) - })?; + let candidate = Self::for_name_ref(field_name)?; if candidate.field_name().as_ref() == Some(field_name) { Some(candidate) } else { @@ -285,6 +282,13 @@ impl ast::RecordExprField { } } + pub fn for_name_ref(name_ref: &ast::NameRef) -> Option { + let syn = name_ref.syntax(); + syn.parent() + .and_then(ast::RecordExprField::cast) + .or_else(|| syn.ancestors().nth(4).and_then(ast::RecordExprField::cast)) + } + /// Deals with field init shorthand pub fn field_name(&self) -> Option { if let Some(name_ref) = self.name_ref() { @@ -294,6 +298,7 @@ impl ast::RecordExprField { } } +#[derive(Debug, Clone, PartialEq)] pub enum NameOrNameRef { Name(ast::Name), NameRef(ast::NameRef), @@ -309,6 +314,23 @@ impl fmt::Display for NameOrNameRef { } impl ast::RecordPatField { + pub fn for_field_name_ref(field_name: &ast::NameRef) -> Option { + let candidate = field_name.syntax().parent().and_then(ast::RecordPatField::cast)?; + match candidate.field_name()? { + NameOrNameRef::NameRef(name_ref) if name_ref == *field_name => Some(candidate), + _ => None, + } + } + + pub fn for_field_name(field_name: &ast::Name) -> Option { + let candidate = + field_name.syntax().ancestors().nth(3).and_then(ast::RecordPatField::cast)?; + match candidate.field_name()? { + NameOrNameRef::Name(name) if name == *field_name => Some(candidate), + _ => None, + } + } + /// Deals with field init shorthand pub fn field_name(&self) -> Option { if let Some(name_ref) = self.name_ref() { -- cgit v1.2.3