From 8e3e5ab2c81f238ea4e731f55eac79b74d9d84c3 Mon Sep 17 00:00:00 2001 From: Florian Diebold Date: Sat, 5 Jan 2019 22:37:59 +0100 Subject: Make FnScopes use hir::Expr This was a bit complicated. I've added a wrapper type for now that does the LocalSyntaxPtr <-> ExprId translation; we might want to get rid of that or give it a nicer interface. --- crates/ra_hir/src/function/scope.rs | 368 +++++++++++++++++------------------- 1 file changed, 178 insertions(+), 190 deletions(-) (limited to 'crates/ra_hir/src/function') diff --git a/crates/ra_hir/src/function/scope.rs b/crates/ra_hir/src/function/scope.rs index 42bfe4f32..0607a99cb 100644 --- a/crates/ra_hir/src/function/scope.rs +++ b/crates/ra_hir/src/function/scope.rs @@ -1,14 +1,16 @@ +use std::sync::Arc; + use rustc_hash::{FxHashMap, FxHashSet}; use ra_syntax::{ AstNode, SyntaxNodeRef, TextUnit, TextRange, algo::generate, - ast::{self, ArgListOwner, LoopBodyOwner, NameOwner}, + ast, }; use ra_arena::{Arena, RawId, impl_arena_id}; use ra_db::LocalSyntaxPtr; -use crate::{Name, AsName}; +use crate::{Name, AsName, expr::{PatId, ExprId, Pat, Expr, Body, Statement, BodySyntaxMapping}}; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct ScopeId(RawId); @@ -16,15 +18,15 @@ impl_arena_id!(ScopeId); #[derive(Debug, PartialEq, Eq)] pub struct FnScopes { - pub self_param: Option, + body: Arc, scopes: Arena, - scope_for: FxHashMap, + scope_for: FxHashMap, } #[derive(Debug, PartialEq, Eq)] pub struct ScopeEntry { name: Name, - ptr: LocalSyntaxPtr, + pat: PatId, } #[derive(Debug, PartialEq, Eq)] @@ -34,28 +36,101 @@ pub struct ScopeData { } impl FnScopes { - pub(crate) fn new(fn_def: ast::FnDef) -> FnScopes { + pub(crate) fn new(body: Arc) -> FnScopes { let mut scopes = FnScopes { - self_param: fn_def - .param_list() - .and_then(|it| it.self_param()) - .map(|it| LocalSyntaxPtr::new(it.syntax())), + body: body.clone(), scopes: Arena::default(), scope_for: FxHashMap::default(), }; let root = scopes.root_scope(); - scopes.add_params_bindings(root, fn_def.param_list()); - if let Some(body) = fn_def.body() { - compute_block_scopes(body, &mut scopes, root) - } + scopes.add_params_bindings(root, body.args()); + compute_expr_scopes(body.body_expr(), &body, &mut scopes, root); scopes } pub fn entries(&self, scope: ScopeId) -> &[ScopeEntry] { &self.scopes[scope].entries } + pub fn scope_chain_for<'a>(&'a self, expr: ExprId) -> impl Iterator + 'a { + generate(self.scope_for(expr), move |&scope| { + self.scopes[scope].parent + }) + } + + pub fn resolve_local_name<'a>( + &'a self, + context_expr: ExprId, + name: Name, + ) -> Option<&'a ScopeEntry> { + let mut shadowed = FxHashSet::default(); + let ret = self + .scope_chain_for(context_expr) + .flat_map(|scope| self.entries(scope).iter()) + .filter(|entry| shadowed.insert(entry.name())) + .filter(|entry| entry.name() == &name) + .nth(0); + ret + } + + fn root_scope(&mut self) -> ScopeId { + self.scopes.alloc(ScopeData { + parent: None, + entries: vec![], + }) + } + fn new_scope(&mut self, parent: ScopeId) -> ScopeId { + self.scopes.alloc(ScopeData { + parent: Some(parent), + entries: vec![], + }) + } + fn add_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId) { + match body.pat(pat) { + Pat::Bind { name } => self.scopes[scope].entries.push(ScopeEntry { + name: name.clone(), + pat, + }), + p => p.walk_child_pats(|pat| self.add_bindings(body, scope, pat)), + } + } + fn add_params_bindings(&mut self, scope: ScopeId, params: &[PatId]) { + let body = Arc::clone(&self.body); + params + .into_iter() + .for_each(|it| self.add_bindings(&body, scope, *it)); + } + fn set_scope(&mut self, node: ExprId, scope: ScopeId) { + self.scope_for.insert(node, scope); + } + fn scope_for(&self, expr: ExprId) -> Option { + self.scope_for.get(&expr).map(|&scope| scope) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ScopesWithSyntaxMapping { + pub syntax_mapping: Arc, + pub scopes: Arc, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ScopeEntryWithSyntax { + name: Name, + ptr: LocalSyntaxPtr, +} + +impl ScopeEntryWithSyntax { + pub fn name(&self) -> &Name { + &self.name + } + pub fn ptr(&self) -> LocalSyntaxPtr { + self.ptr + } +} + +impl ScopesWithSyntaxMapping { pub fn scope_chain<'a>(&'a self, node: SyntaxNodeRef) -> impl Iterator + 'a { generate(self.scope_for(node), move |&scope| { - self.scopes[scope].parent + self.scopes.scopes[scope].parent }) } pub fn scope_chain_for_offset<'a>( @@ -63,26 +138,30 @@ impl FnScopes { offset: TextUnit, ) -> impl Iterator + 'a { let scope = self + .scopes .scope_for .iter() - // find containin scope + .filter_map(|(id, scope)| Some((self.syntax_mapping.expr_syntax(*id)?, scope))) + // find containing scope .min_by_key(|(ptr, _scope)| { ( !(ptr.range().start() <= offset && offset <= ptr.range().end()), ptr.range().len(), ) }) - .map(|(ptr, scope)| self.adjust(*ptr, *scope, offset)); + .map(|(ptr, scope)| self.adjust(ptr, *scope, offset)); - generate(scope, move |&scope| self.scopes[scope].parent) + generate(scope, move |&scope| self.scopes.scopes[scope].parent) } // XXX: during completion, cursor might be outside of any particular // expression. Try to figure out the correct scope... fn adjust(&self, ptr: LocalSyntaxPtr, original_scope: ScopeId, offset: TextUnit) -> ScopeId { let r = ptr.range(); let child_scopes = self + .scopes .scope_for .iter() + .filter_map(|(id, scope)| Some((self.syntax_mapping.expr_syntax(*id)?, scope))) .map(|(ptr, scope)| (ptr.range(), scope)) .filter(|(range, _)| range.start() <= offset && range.is_subrange(&r) && *range != r); @@ -100,22 +179,27 @@ impl FnScopes { .unwrap_or(original_scope) } - pub fn resolve_local_name<'a>(&'a self, name_ref: ast::NameRef) -> Option<&'a ScopeEntry> { + pub fn resolve_local_name(&self, name_ref: ast::NameRef) -> Option { let mut shadowed = FxHashSet::default(); let name = name_ref.as_name(); let ret = self .scope_chain(name_ref.syntax()) - .flat_map(|scope| self.entries(scope).iter()) + .flat_map(|scope| self.scopes.entries(scope).iter()) .filter(|entry| shadowed.insert(entry.name())) .filter(|entry| entry.name() == &name) .nth(0); - ret + ret.and_then(|entry| { + Some(ScopeEntryWithSyntax { + name: entry.name().clone(), + ptr: self.syntax_mapping.pat_syntax(entry.pat())?, + }) + }) } pub fn find_all_refs(&self, pat: ast::BindPat) -> Vec { let fn_def = pat.syntax().ancestors().find_map(ast::FnDef::cast).unwrap(); let name_ptr = LocalSyntaxPtr::new(pat.syntax()); - let refs: Vec<_> = fn_def + fn_def .syntax() .descendants() .filter_map(ast::NameRef::cast) @@ -127,203 +211,95 @@ impl FnScopes { name: name_ref.syntax().text().to_string(), range: name_ref.syntax().range(), }) - .collect(); - - refs + .collect() } - fn root_scope(&mut self) -> ScopeId { - self.scopes.alloc(ScopeData { - parent: None, - entries: vec![], - }) - } - fn new_scope(&mut self, parent: ScopeId) -> ScopeId { - self.scopes.alloc(ScopeData { - parent: Some(parent), - entries: vec![], - }) - } - fn add_bindings(&mut self, scope: ScopeId, pat: ast::Pat) { - let entries = pat - .syntax() - .descendants() - .filter_map(ast::BindPat::cast) - .filter_map(ScopeEntry::new); - self.scopes[scope].entries.extend(entries); - } - fn add_params_bindings(&mut self, scope: ScopeId, params: Option) { - params - .into_iter() - .flat_map(|it| it.params()) - .filter_map(|it| it.pat()) - .for_each(|it| self.add_bindings(scope, it)); - } - fn set_scope(&mut self, node: SyntaxNodeRef, scope: ScopeId) { - self.scope_for.insert(LocalSyntaxPtr::new(node), scope); - } fn scope_for(&self, node: SyntaxNodeRef) -> Option { node.ancestors() .map(LocalSyntaxPtr::new) - .filter_map(|it| self.scope_for.get(&it).map(|&scope| scope)) + .filter_map(|ptr| self.syntax_mapping.syntax_expr(ptr)) + .filter_map(|it| self.scopes.scope_for(it)) .next() } } impl ScopeEntry { - fn new(pat: ast::BindPat) -> Option { - let name = pat.name()?.as_name(); - let res = ScopeEntry { - name, - ptr: LocalSyntaxPtr::new(pat.syntax()), - }; - Some(res) - } pub fn name(&self) -> &Name { &self.name } - pub fn ptr(&self) -> LocalSyntaxPtr { - self.ptr + pub fn pat(&self) -> PatId { + self.pat } } -fn compute_block_scopes(block: ast::Block, scopes: &mut FnScopes, mut scope: ScopeId) { - // A hack for completion :( - scopes.set_scope(block.syntax(), scope); - for stmt in block.statements() { +fn compute_block_scopes( + statements: &[Statement], + tail: Option, + body: &Body, + scopes: &mut FnScopes, + mut scope: ScopeId, +) { + for stmt in statements { match stmt { - ast::Stmt::LetStmt(stmt) => { - if let Some(expr) = stmt.initializer() { - scopes.set_scope(expr.syntax(), scope); - compute_expr_scopes(expr, scopes, scope); + Statement::Let { + pat, initializer, .. + } => { + if let Some(expr) = initializer { + scopes.set_scope(*expr, scope); + compute_expr_scopes(*expr, body, scopes, scope); } scope = scopes.new_scope(scope); - if let Some(pat) = stmt.pat() { - scopes.add_bindings(scope, pat); - } + scopes.add_bindings(body, scope, *pat); } - ast::Stmt::ExprStmt(expr_stmt) => { - if let Some(expr) = expr_stmt.expr() { - scopes.set_scope(expr.syntax(), scope); - compute_expr_scopes(expr, scopes, scope); - } + Statement::Expr(expr) => { + scopes.set_scope(*expr, scope); + compute_expr_scopes(*expr, body, scopes, scope); } } } - if let Some(expr) = block.expr() { - scopes.set_scope(expr.syntax(), scope); - compute_expr_scopes(expr, scopes, scope); + if let Some(expr) = tail { + compute_expr_scopes(expr, body, scopes, scope); } } -fn compute_expr_scopes(expr: ast::Expr, scopes: &mut FnScopes, scope: ScopeId) { - match expr { - ast::Expr::IfExpr(e) => { - let cond_scope = e - .condition() - .and_then(|cond| compute_cond_scopes(cond, scopes, scope)); - if let Some(block) = e.then_branch() { - compute_block_scopes(block, scopes, cond_scope.unwrap_or(scope)); - } - if let Some(block) = e.else_branch() { - compute_block_scopes(block, scopes, scope); - } - } - ast::Expr::BlockExpr(e) => { - if let Some(block) = e.block() { - compute_block_scopes(block, scopes, scope); - } - } - ast::Expr::LoopExpr(e) => { - if let Some(block) = e.loop_body() { - compute_block_scopes(block, scopes, scope); - } - } - ast::Expr::WhileExpr(e) => { - let cond_scope = e - .condition() - .and_then(|cond| compute_cond_scopes(cond, scopes, scope)); - if let Some(block) = e.loop_body() { - compute_block_scopes(block, scopes, cond_scope.unwrap_or(scope)); - } +fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut FnScopes, scope: ScopeId) { + scopes.set_scope(expr, scope); + match body.expr(expr) { + Expr::Block { statements, tail } => { + compute_block_scopes(&statements, *tail, body, scopes, scope); } - ast::Expr::ForExpr(e) => { - if let Some(expr) = e.iterable() { - compute_expr_scopes(expr, scopes, scope); - } - let mut scope = scope; - if let Some(pat) = e.pat() { - scope = scopes.new_scope(scope); - scopes.add_bindings(scope, pat); - } - if let Some(block) = e.loop_body() { - compute_block_scopes(block, scopes, scope); - } - } - ast::Expr::LambdaExpr(e) => { + Expr::For { + iterable, + pat, + body: body_expr, + } => { + compute_expr_scopes(*iterable, body, scopes, scope); let scope = scopes.new_scope(scope); - scopes.add_params_bindings(scope, e.param_list()); - if let Some(body) = e.body() { - scopes.set_scope(body.syntax(), scope); - compute_expr_scopes(body, scopes, scope); - } + scopes.add_bindings(body, scope, *pat); + compute_expr_scopes(*body_expr, body, scopes, scope); } - ast::Expr::CallExpr(e) => { - compute_call_scopes(e.expr(), e.arg_list(), scopes, scope); - } - ast::Expr::MethodCallExpr(e) => { - compute_call_scopes(e.expr(), e.arg_list(), scopes, scope); + Expr::Lambda { + args, + body: body_expr, + .. + } => { + let scope = scopes.new_scope(scope); + scopes.add_params_bindings(scope, &args); + compute_expr_scopes(*body_expr, body, scopes, scope); } - ast::Expr::MatchExpr(e) => { - if let Some(expr) = e.expr() { - compute_expr_scopes(expr, scopes, scope); - } - for arm in e.match_arm_list().into_iter().flat_map(|it| it.arms()) { + Expr::Match { expr, arms } => { + compute_expr_scopes(*expr, body, scopes, scope); + for arm in arms { let scope = scopes.new_scope(scope); - for pat in arm.pats() { - scopes.add_bindings(scope, pat); - } - if let Some(expr) = arm.expr() { - compute_expr_scopes(expr, scopes, scope); + for pat in &arm.pats { + scopes.add_bindings(body, scope, *pat); } + scopes.set_scope(arm.expr, scope); + compute_expr_scopes(arm.expr, body, scopes, scope); } } - _ => expr - .syntax() - .children() - .filter_map(ast::Expr::cast) - .for_each(|expr| compute_expr_scopes(expr, scopes, scope)), + e => e.walk_child_exprs(|e| compute_expr_scopes(e, body, scopes, scope)), }; - - fn compute_call_scopes( - receiver: Option, - arg_list: Option, - scopes: &mut FnScopes, - scope: ScopeId, - ) { - arg_list - .into_iter() - .flat_map(|it| it.args()) - .chain(receiver) - .for_each(|expr| compute_expr_scopes(expr, scopes, scope)); - } - - fn compute_cond_scopes( - cond: ast::Condition, - scopes: &mut FnScopes, - scope: ScopeId, - ) -> Option { - if let Some(expr) = cond.expr() { - compute_expr_scopes(expr, scopes, scope); - } - if let Some(pat) = cond.pat() { - let s = scopes.new_scope(scope); - scopes.add_bindings(s, pat); - Some(s) - } else { - None - } - } } #[derive(Debug)] @@ -338,6 +314,8 @@ mod tests { use ra_syntax::SourceFileNode; use test_utils::{extract_offset, assert_eq_text}; + use crate::expr; + use super::*; fn do_check(code: &str, expected: &[&str]) { @@ -353,15 +331,20 @@ mod tests { let file = SourceFileNode::parse(&code); let marker: ast::PathExpr = find_node_at_offset(file.syntax(), off).unwrap(); let fn_def: ast::FnDef = find_node_at_offset(file.syntax(), off).unwrap(); - let scopes = FnScopes::new(fn_def); + let body_hir = expr::collect_fn_body_syntax(fn_def); + let scopes = FnScopes::new(Arc::clone(body_hir.body())); + let scopes = ScopesWithSyntaxMapping { + scopes: Arc::new(scopes), + syntax_mapping: Arc::new(body_hir), + }; let actual = scopes .scope_chain(marker.syntax()) - .flat_map(|scope| scopes.entries(scope)) + .flat_map(|scope| scopes.scopes.entries(scope)) .map(|it| it.name().to_string()) .collect::>() .join("\n"); let expected = expected.join("\n"); - assert_eq_text!(&actual, &expected); + assert_eq_text!(&expected, &actual); } #[test] @@ -389,7 +372,7 @@ mod tests { } #[test] - fn test_metod_call_scope() { + fn test_method_call_scope() { do_check( r" fn quux() { @@ -445,10 +428,15 @@ mod tests { let fn_def: ast::FnDef = find_node_at_offset(file.syntax(), off).unwrap(); let name_ref: ast::NameRef = find_node_at_offset(file.syntax(), off).unwrap(); - let scopes = FnScopes::new(fn_def); + let body_hir = expr::collect_fn_body_syntax(fn_def); + let scopes = FnScopes::new(Arc::clone(body_hir.body())); + let scopes = ScopesWithSyntaxMapping { + scopes: Arc::new(scopes), + syntax_mapping: Arc::new(body_hir), + }; let local_name_entry = scopes.resolve_local_name(name_ref).unwrap(); - let local_name = local_name_entry.ptr().resolve(&file); + let local_name = local_name_entry.ptr(); let expected_name = find_node_at_offset::(file.syntax(), expected_offset.into()).unwrap(); assert_eq!(local_name.range(), expected_name.syntax().range()); -- cgit v1.2.3 From e5a6cf815372150ad40dee995b7b89f29e701427 Mon Sep 17 00:00:00 2001 From: Florian Diebold Date: Sun, 6 Jan 2019 00:33:58 +0100 Subject: Various small code review improvements --- crates/ra_hir/src/function/scope.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) (limited to 'crates/ra_hir/src/function') diff --git a/crates/ra_hir/src/function/scope.rs b/crates/ra_hir/src/function/scope.rs index 0607a99cb..0a12f0b35 100644 --- a/crates/ra_hir/src/function/scope.rs +++ b/crates/ra_hir/src/function/scope.rs @@ -66,8 +66,7 @@ impl FnScopes { .scope_chain_for(context_expr) .flat_map(|scope| self.entries(scope).iter()) .filter(|entry| shadowed.insert(entry.name())) - .filter(|entry| entry.name() == &name) - .nth(0); + .find(|entry| entry.name() == &name); ret } @@ -84,7 +83,7 @@ impl FnScopes { }) } fn add_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId) { - match body.pat(pat) { + match &body[pat] { Pat::Bind { name } => self.scopes[scope].entries.push(ScopeEntry { name: name.clone(), pat, @@ -96,7 +95,7 @@ impl FnScopes { let body = Arc::clone(&self.body); params .into_iter() - .for_each(|it| self.add_bindings(&body, scope, *it)); + .for_each(|pat| self.add_bindings(&body, scope, *pat)); } fn set_scope(&mut self, node: ExprId, scope: ScopeId) { self.scope_for.insert(node, scope); @@ -218,8 +217,7 @@ impl ScopesWithSyntaxMapping { node.ancestors() .map(LocalSyntaxPtr::new) .filter_map(|ptr| self.syntax_mapping.syntax_expr(ptr)) - .filter_map(|it| self.scopes.scope_for(it)) - .next() + .find_map(|it| self.scopes.scope_for(it)) } } @@ -264,7 +262,7 @@ fn compute_block_scopes( fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut FnScopes, scope: ScopeId) { scopes.set_scope(expr, scope); - match body.expr(expr) { + match &body[expr] { Expr::Block { statements, tail } => { compute_block_scopes(&statements, *tail, body, scopes, scope); } -- cgit v1.2.3