use rustc_hash::{FxHashMap, FxHashSet}; use ra_syntax::{ AstNode, SmolStr, SyntaxNodeRef, TextUnit, TextRange, algo::generate, ast::{self, ArgListOwner, LoopBodyOwner, NameOwner}, }; use ra_db::LocalSyntaxPtr; use crate::{ arena::{Arena, Id}, }; pub(crate) type ScopeId = Id; #[derive(Debug, PartialEq, Eq)] pub struct FnScopes { pub self_param: Option, scopes: Arena, scope_for: FxHashMap, } #[derive(Debug, PartialEq, Eq)] pub struct ScopeEntry { name: SmolStr, ptr: LocalSyntaxPtr, } #[derive(Debug, PartialEq, Eq)] pub struct ScopeData { parent: Option, entries: Vec, } impl FnScopes { pub fn new(fn_def: ast::FnDef) -> FnScopes { let mut scopes = FnScopes { self_param: fn_def .param_list() .and_then(|it| it.self_param()) .map(|it| LocalSyntaxPtr::new(it.syntax())), scopes: Arena::default(), scope_for: FxHashMap::default(), }; let root = scopes.root_scope(); scopes.add_params_bindings(root, fn_def.param_list()); if let Some(body) = fn_def.body() { compute_block_scopes(body, &mut scopes, root) } scopes } pub fn entries(&self, scope: ScopeId) -> &[ScopeEntry] { &self.scopes[scope].entries } pub fn scope_chain<'a>(&'a self, node: SyntaxNodeRef) -> impl Iterator + 'a { generate(self.scope_for(node), move |&scope| { self.scopes[scope].parent }) } pub fn scope_chain_for_offset<'a>( &'a self, offset: TextUnit, ) -> impl Iterator + 'a { let scope = self .scope_for .iter() // find containin scope .min_by_key(|(ptr, _scope)| { ( !(ptr.range().start() <= offset && offset <= ptr.range().end()), ptr.range().len(), ) }) .map(|(ptr, scope)| self.adjust(*ptr, *scope, offset)); generate(scope, move |&scope| self.scopes[scope].parent) } // XXX: during completion, cursor might be outside of any particular // expression. Try to figure out the correct scope... fn adjust(&self, ptr: LocalSyntaxPtr, original_scope: ScopeId, offset: TextUnit) -> ScopeId { let r = ptr.range(); let child_scopes = self .scope_for .iter() .map(|(ptr, scope)| (ptr.range(), scope)) .filter(|(range, _)| range.start() <= offset && range.is_subrange(&r) && *range != r); child_scopes .max_by(|(r1, _), (r2, _)| { if r2.is_subrange(&r1) { std::cmp::Ordering::Greater } else if r1.is_subrange(&r2) { std::cmp::Ordering::Less } else { r1.start().cmp(&r2.start()) } }) .map(|(ptr, scope)| *scope) .unwrap_or(original_scope) } pub fn resolve_local_name<'a>(&'a self, name_ref: ast::NameRef) -> Option<&'a ScopeEntry> { let mut shadowed = FxHashSet::default(); let ret = self .scope_chain(name_ref.syntax()) .flat_map(|scope| self.entries(scope).iter()) .filter(|entry| shadowed.insert(entry.name())) .filter(|entry| entry.name() == &name_ref.text()) .nth(0); ret } pub fn find_all_refs(&self, pat: ast::BindPat) -> Vec { let fn_def = pat.syntax().ancestors().find_map(ast::FnDef::cast).unwrap(); let name_ptr = LocalSyntaxPtr::new(pat.syntax()); let refs: Vec<_> = fn_def .syntax() .descendants() .filter_map(ast::NameRef::cast) .filter(|name_ref| match self.resolve_local_name(*name_ref) { None => false, Some(entry) => entry.ptr() == name_ptr, }) .map(|name_ref| ReferenceDescriptor { name: name_ref.syntax().text().to_string(), range: name_ref.syntax().range(), }) .collect(); refs } fn root_scope(&mut self) -> ScopeId { self.scopes.alloc(ScopeData { parent: None, entries: vec![], }) } fn new_scope(&mut self, parent: ScopeId) -> ScopeId { self.scopes.alloc(ScopeData { parent: Some(parent), entries: vec![], }) } fn add_bindings(&mut self, scope: ScopeId, pat: ast::Pat) { let entries = pat .syntax() .descendants() .filter_map(ast::BindPat::cast) .filter_map(ScopeEntry::new); self.scopes[scope].entries.extend(entries); } fn add_params_bindings(&mut self, scope: ScopeId, params: Option) { params .into_iter() .flat_map(|it| it.params()) .filter_map(|it| it.pat()) .for_each(|it| self.add_bindings(scope, it)); } fn set_scope(&mut self, node: SyntaxNodeRef, scope: ScopeId) { self.scope_for.insert(LocalSyntaxPtr::new(node), scope); } fn scope_for(&self, node: SyntaxNodeRef) -> Option { node.ancestors() .map(LocalSyntaxPtr::new) .filter_map(|it| self.scope_for.get(&it).map(|&scope| scope)) .next() } } impl ScopeEntry { fn new(pat: ast::BindPat) -> Option { let name = pat.name()?; let res = ScopeEntry { name: name.text(), ptr: LocalSyntaxPtr::new(pat.syntax()), }; Some(res) } pub fn name(&self) -> &SmolStr { &self.name } pub fn ptr(&self) -> LocalSyntaxPtr { self.ptr } } fn compute_block_scopes(block: ast::Block, scopes: &mut FnScopes, mut scope: ScopeId) { // A hack for completion :( scopes.set_scope(block.syntax(), scope); for stmt in block.statements() { match stmt { ast::Stmt::LetStmt(stmt) => { if let Some(expr) = stmt.initializer() { scopes.set_scope(expr.syntax(), scope); compute_expr_scopes(expr, scopes, scope); } scope = scopes.new_scope(scope); if let Some(pat) = stmt.pat() { scopes.add_bindings(scope, pat); } } ast::Stmt::ExprStmt(expr_stmt) => { if let Some(expr) = expr_stmt.expr() { scopes.set_scope(expr.syntax(), scope); compute_expr_scopes(expr, scopes, scope); } } } } if let Some(expr) = block.expr() { eprintln!("{:?}", expr); scopes.set_scope(expr.syntax(), scope); compute_expr_scopes(expr, scopes, scope); } } fn compute_expr_scopes(expr: ast::Expr, scopes: &mut FnScopes, scope: ScopeId) { match expr { ast::Expr::IfExpr(e) => { let cond_scope = e .condition() .and_then(|cond| compute_cond_scopes(cond, scopes, scope)); if let Some(block) = e.then_branch() { compute_block_scopes(block, scopes, cond_scope.unwrap_or(scope)); } if let Some(block) = e.else_branch() { compute_block_scopes(block, scopes, scope); } } ast::Expr::BlockExpr(e) => { if let Some(block) = e.block() { compute_block_scopes(block, scopes, scope); } } ast::Expr::LoopExpr(e) => { if let Some(block) = e.loop_body() { compute_block_scopes(block, scopes, scope); } } ast::Expr::WhileExpr(e) => { let cond_scope = e .condition() .and_then(|cond| compute_cond_scopes(cond, scopes, scope)); if let Some(block) = e.loop_body() { compute_block_scopes(block, scopes, cond_scope.unwrap_or(scope)); } } ast::Expr::ForExpr(e) => { if let Some(expr) = e.iterable() { compute_expr_scopes(expr, scopes, scope); } let mut scope = scope; if let Some(pat) = e.pat() { scope = scopes.new_scope(scope); scopes.add_bindings(scope, pat); } if let Some(block) = e.loop_body() { compute_block_scopes(block, scopes, scope); } } ast::Expr::LambdaExpr(e) => { let scope = scopes.new_scope(scope); scopes.add_params_bindings(scope, e.param_list()); if let Some(body) = e.body() { scopes.set_scope(body.syntax(), scope); compute_expr_scopes(body, scopes, scope); } } ast::Expr::CallExpr(e) => { compute_call_scopes(e.expr(), e.arg_list(), scopes, scope); } ast::Expr::MethodCallExpr(e) => { compute_call_scopes(e.expr(), e.arg_list(), scopes, scope); } ast::Expr::MatchExpr(e) => { if let Some(expr) = e.expr() { compute_expr_scopes(expr, scopes, scope); } for arm in e.match_arm_list().into_iter().flat_map(|it| it.arms()) { let scope = scopes.new_scope(scope); for pat in arm.pats() { scopes.add_bindings(scope, pat); } if let Some(expr) = arm.expr() { compute_expr_scopes(expr, scopes, scope); } } } _ => expr .syntax() .children() .filter_map(ast::Expr::cast) .for_each(|expr| compute_expr_scopes(expr, scopes, scope)), }; fn compute_call_scopes( receiver: Option, arg_list: Option, scopes: &mut FnScopes, scope: ScopeId, ) { arg_list .into_iter() .flat_map(|it| it.args()) .chain(receiver) .for_each(|expr| compute_expr_scopes(expr, scopes, scope)); } fn compute_cond_scopes( cond: ast::Condition, scopes: &mut FnScopes, scope: ScopeId, ) -> Option { if let Some(expr) = cond.expr() { compute_expr_scopes(expr, scopes, scope); } if let Some(pat) = cond.pat() { let s = scopes.new_scope(scope); scopes.add_bindings(s, pat); Some(s) } else { None } } } #[derive(Debug)] pub struct ReferenceDescriptor { pub range: TextRange, pub name: String, } #[cfg(test)] mod tests { use ra_editor::find_node_at_offset; use ra_syntax::SourceFileNode; use test_utils::extract_offset; use super::*; fn do_check(code: &str, expected: &[&str]) { let (off, code) = extract_offset(code); let code = { let mut buf = String::new(); let off = u32::from(off) as usize; buf.push_str(&code[..off]); buf.push_str("marker"); buf.push_str(&code[off..]); buf }; let file = SourceFileNode::parse(&code); let marker: ast::PathExpr = find_node_at_offset(file.syntax(), off).unwrap(); let fn_def: ast::FnDef = find_node_at_offset(file.syntax(), off).unwrap(); let scopes = FnScopes::new(fn_def); let actual = scopes .scope_chain(marker.syntax()) .flat_map(|scope| scopes.entries(scope)) .map(|it| it.name()) .collect::>(); assert_eq!(actual.as_slice(), expected); } #[test] fn test_lambda_scope() { do_check( r" fn quux(foo: i32) { let f = |bar, baz: i32| { <|> }; }", &["bar", "baz", "foo"], ); } #[test] fn test_call_scope() { do_check( r" fn quux() { f(|x| <|> ); }", &["x"], ); } #[test] fn test_metod_call_scope() { do_check( r" fn quux() { z.f(|x| <|> ); }", &["x"], ); } #[test] fn test_loop_scope() { do_check( r" fn quux() { loop { let x = (); <|> }; }", &["x"], ); } #[test] fn test_match() { do_check( r" fn quux() { match () { Some(x) => { <|> } }; }", &["x"], ); } #[test] fn test_shadow_variable() { do_check( r" fn foo(x: String) { let x : &str = &x<|>; }", &["x"], ); } fn do_check_local_name(code: &str, expected_offset: u32) { let (off, code) = extract_offset(code); let file = SourceFileNode::parse(&code); let fn_def: ast::FnDef = find_node_at_offset(file.syntax(), off).unwrap(); let name_ref: ast::NameRef = find_node_at_offset(file.syntax(), off).unwrap(); let scopes = FnScopes::new(fn_def); let local_name_entry = scopes.resolve_local_name(name_ref).unwrap(); let local_name = local_name_entry.ptr().resolve(&file); let expected_name = find_node_at_offset::(file.syntax(), expected_offset.into()).unwrap(); assert_eq!(local_name.range(), expected_name.syntax().range()); } #[test] fn test_resolve_local_name() { do_check_local_name( r#" fn foo(x: i32, y: u32) { { let z = x * 2; } { let t = x<|> * 3; } }"#, 21, ); } #[test] fn test_resolve_local_name_declaration() { do_check_local_name( r#" fn foo(x: String) { let x : &str = &x<|>; }"#, 21, ); } #[test] fn test_resolve_local_name_shadow() { do_check_local_name( r" fn foo(x: String) { let x : &str = &x; x<|> }", 46, ); } }