aboutsummaryrefslogtreecommitdiff
path: root/crates/ra_hir/src/expr/scope.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/ra_hir/src/expr/scope.rs')
-rw-r--r--crates/ra_hir/src/expr/scope.rs521
1 files changed, 521 insertions, 0 deletions
diff --git a/crates/ra_hir/src/expr/scope.rs b/crates/ra_hir/src/expr/scope.rs
new file mode 100644
index 000000000..f8b5ba581
--- /dev/null
+++ b/crates/ra_hir/src/expr/scope.rs
@@ -0,0 +1,521 @@
1use std::sync::Arc;
2
3use rustc_hash::{FxHashMap, FxHashSet};
4
5use ra_syntax::{
6 AstNode, SyntaxNode, TextUnit, TextRange, SyntaxNodePtr,
7 algo::generate,
8 ast,
9};
10use ra_arena::{Arena, RawId, impl_arena_id};
11
12use crate::{
13 Name, AsName, Function,
14 expr::{PatId, ExprId, Pat, Expr, Body, Statement, BodySyntaxMapping},
15 db::HirDatabase,
16};
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
19pub struct ScopeId(RawId);
20impl_arena_id!(ScopeId);
21
22#[derive(Debug, PartialEq, Eq)]
23pub struct ExprScopes {
24 body: Arc<Body>,
25 scopes: Arena<ScopeId, ScopeData>,
26 scope_for: FxHashMap<ExprId, ScopeId>,
27}
28
29#[derive(Debug, PartialEq, Eq)]
30pub struct ScopeEntry {
31 name: Name,
32 pat: PatId,
33}
34
35#[derive(Debug, PartialEq, Eq)]
36pub struct ScopeData {
37 parent: Option<ScopeId>,
38 entries: Vec<ScopeEntry>,
39}
40
41impl ExprScopes {
42 // TODO: This should take something more general than Function
43 pub(crate) fn expr_scopes_query(db: &impl HirDatabase, function: Function) -> Arc<ExprScopes> {
44 let body = db.body_hir(function);
45 let res = ExprScopes::new(body);
46 Arc::new(res)
47 }
48
49 fn new(body: Arc<Body>) -> ExprScopes {
50 let mut scopes = ExprScopes {
51 body: body.clone(),
52 scopes: Arena::default(),
53 scope_for: FxHashMap::default(),
54 };
55 let root = scopes.root_scope();
56 scopes.add_params_bindings(root, body.params());
57 compute_expr_scopes(body.body_expr(), &body, &mut scopes, root);
58 scopes
59 }
60
61 pub fn entries(&self, scope: ScopeId) -> &[ScopeEntry] {
62 &self.scopes[scope].entries
63 }
64
65 pub fn scope_chain_for<'a>(&'a self, expr: ExprId) -> impl Iterator<Item = ScopeId> + 'a {
66 generate(self.scope_for(expr), move |&scope| {
67 self.scopes[scope].parent
68 })
69 }
70
71 pub fn resolve_local_name<'a>(
72 &'a self,
73 context_expr: ExprId,
74 name: Name,
75 ) -> Option<&'a ScopeEntry> {
76 let mut shadowed = FxHashSet::default();
77 let ret = self
78 .scope_chain_for(context_expr)
79 .flat_map(|scope| self.entries(scope).iter())
80 .filter(|entry| shadowed.insert(entry.name()))
81 .find(|entry| entry.name() == &name);
82 ret
83 }
84
85 fn root_scope(&mut self) -> ScopeId {
86 self.scopes.alloc(ScopeData {
87 parent: None,
88 entries: vec![],
89 })
90 }
91
92 fn new_scope(&mut self, parent: ScopeId) -> ScopeId {
93 self.scopes.alloc(ScopeData {
94 parent: Some(parent),
95 entries: vec![],
96 })
97 }
98
99 fn add_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId) {
100 match &body[pat] {
101 Pat::Bind { name, .. } => {
102 // bind can have a subpattern, but it's actually not allowed
103 // to bind to things in there
104 let entry = ScopeEntry {
105 name: name.clone(),
106 pat,
107 };
108 self.scopes[scope].entries.push(entry)
109 }
110 p => p.walk_child_pats(|pat| self.add_bindings(body, scope, pat)),
111 }
112 }
113
114 fn add_params_bindings(&mut self, scope: ScopeId, params: &[PatId]) {
115 let body = Arc::clone(&self.body);
116 params
117 .into_iter()
118 .for_each(|pat| self.add_bindings(&body, scope, *pat));
119 }
120
121 fn set_scope(&mut self, node: ExprId, scope: ScopeId) {
122 self.scope_for.insert(node, scope);
123 }
124
125 fn scope_for(&self, expr: ExprId) -> Option<ScopeId> {
126 self.scope_for.get(&expr).map(|&scope| scope)
127 }
128}
129
130#[derive(Debug, Clone, PartialEq, Eq)]
131pub struct ScopesWithSyntaxMapping {
132 pub syntax_mapping: Arc<BodySyntaxMapping>,
133 pub scopes: Arc<ExprScopes>,
134}
135
136#[derive(Debug, Clone, PartialEq, Eq)]
137pub struct ScopeEntryWithSyntax {
138 name: Name,
139 ptr: SyntaxNodePtr,
140}
141
142impl ScopeEntryWithSyntax {
143 pub fn name(&self) -> &Name {
144 &self.name
145 }
146
147 pub fn ptr(&self) -> SyntaxNodePtr {
148 self.ptr
149 }
150}
151
152impl ScopesWithSyntaxMapping {
153 pub fn scope_chain<'a>(&'a self, node: &SyntaxNode) -> impl Iterator<Item = ScopeId> + 'a {
154 generate(self.scope_for(node), move |&scope| {
155 self.scopes.scopes[scope].parent
156 })
157 }
158
159 pub fn scope_chain_for_offset<'a>(
160 &'a self,
161 offset: TextUnit,
162 ) -> impl Iterator<Item = ScopeId> + 'a {
163 let scope = self
164 .scopes
165 .scope_for
166 .iter()
167 .filter_map(|(id, scope)| Some((self.syntax_mapping.expr_syntax(*id)?, scope)))
168 // find containing scope
169 .min_by_key(|(ptr, _scope)| {
170 (
171 !(ptr.range().start() <= offset && offset <= ptr.range().end()),
172 ptr.range().len(),
173 )
174 })
175 .map(|(ptr, scope)| self.adjust(ptr, *scope, offset));
176
177 generate(scope, move |&scope| self.scopes.scopes[scope].parent)
178 }
179
180 // XXX: during completion, cursor might be outside of any particular
181 // expression. Try to figure out the correct scope...
182 fn adjust(&self, ptr: SyntaxNodePtr, original_scope: ScopeId, offset: TextUnit) -> ScopeId {
183 let r = ptr.range();
184 let child_scopes = self
185 .scopes
186 .scope_for
187 .iter()
188 .filter_map(|(id, scope)| Some((self.syntax_mapping.expr_syntax(*id)?, scope)))
189 .map(|(ptr, scope)| (ptr.range(), scope))
190 .filter(|(range, _)| range.start() <= offset && range.is_subrange(&r) && *range != r);
191
192 child_scopes
193 .max_by(|(r1, _), (r2, _)| {
194 if r2.is_subrange(&r1) {
195 std::cmp::Ordering::Greater
196 } else if r1.is_subrange(&r2) {
197 std::cmp::Ordering::Less
198 } else {
199 r1.start().cmp(&r2.start())
200 }
201 })
202 .map(|(_ptr, scope)| *scope)
203 .unwrap_or(original_scope)
204 }
205
206 pub fn resolve_local_name(&self, name_ref: &ast::NameRef) -> Option<ScopeEntryWithSyntax> {
207 let mut shadowed = FxHashSet::default();
208 let name = name_ref.as_name();
209 let ret = self
210 .scope_chain(name_ref.syntax())
211 .flat_map(|scope| self.scopes.entries(scope).iter())
212 .filter(|entry| shadowed.insert(entry.name()))
213 .filter(|entry| entry.name() == &name)
214 .nth(0);
215 ret.and_then(|entry| {
216 Some(ScopeEntryWithSyntax {
217 name: entry.name().clone(),
218 ptr: self.syntax_mapping.pat_syntax(entry.pat())?,
219 })
220 })
221 }
222
223 pub fn find_all_refs(&self, pat: &ast::BindPat) -> Vec<ReferenceDescriptor> {
224 let fn_def = pat.syntax().ancestors().find_map(ast::FnDef::cast).unwrap();
225 let name_ptr = SyntaxNodePtr::new(pat.syntax());
226 fn_def
227 .syntax()
228 .descendants()
229 .filter_map(ast::NameRef::cast)
230 .filter(|name_ref| match self.resolve_local_name(*name_ref) {
231 None => false,
232 Some(entry) => entry.ptr() == name_ptr,
233 })
234 .map(|name_ref| ReferenceDescriptor {
235 name: name_ref.syntax().text().to_string(),
236 range: name_ref.syntax().range(),
237 })
238 .collect()
239 }
240
241 fn scope_for(&self, node: &SyntaxNode) -> Option<ScopeId> {
242 node.ancestors()
243 .map(SyntaxNodePtr::new)
244 .filter_map(|ptr| self.syntax_mapping.syntax_expr(ptr))
245 .find_map(|it| self.scopes.scope_for(it))
246 }
247}
248
249impl ScopeEntry {
250 pub fn name(&self) -> &Name {
251 &self.name
252 }
253
254 pub fn pat(&self) -> PatId {
255 self.pat
256 }
257}
258
259fn compute_block_scopes(
260 statements: &[Statement],
261 tail: Option<ExprId>,
262 body: &Body,
263 scopes: &mut ExprScopes,
264 mut scope: ScopeId,
265) {
266 for stmt in statements {
267 match stmt {
268 Statement::Let {
269 pat, initializer, ..
270 } => {
271 if let Some(expr) = initializer {
272 scopes.set_scope(*expr, scope);
273 compute_expr_scopes(*expr, body, scopes, scope);
274 }
275 scope = scopes.new_scope(scope);
276 scopes.add_bindings(body, scope, *pat);
277 }
278 Statement::Expr(expr) => {
279 scopes.set_scope(*expr, scope);
280 compute_expr_scopes(*expr, body, scopes, scope);
281 }
282 }
283 }
284 if let Some(expr) = tail {
285 compute_expr_scopes(expr, body, scopes, scope);
286 }
287}
288
289fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: ScopeId) {
290 scopes.set_scope(expr, scope);
291 match &body[expr] {
292 Expr::Block { statements, tail } => {
293 compute_block_scopes(&statements, *tail, body, scopes, scope);
294 }
295 Expr::For {
296 iterable,
297 pat,
298 body: body_expr,
299 } => {
300 compute_expr_scopes(*iterable, body, scopes, scope);
301 let scope = scopes.new_scope(scope);
302 scopes.add_bindings(body, scope, *pat);
303 compute_expr_scopes(*body_expr, body, scopes, scope);
304 }
305 Expr::Lambda {
306 args,
307 body: body_expr,
308 ..
309 } => {
310 let scope = scopes.new_scope(scope);
311 scopes.add_params_bindings(scope, &args);
312 compute_expr_scopes(*body_expr, body, scopes, scope);
313 }
314 Expr::Match { expr, arms } => {
315 compute_expr_scopes(*expr, body, scopes, scope);
316 for arm in arms {
317 let scope = scopes.new_scope(scope);
318 for pat in &arm.pats {
319 scopes.add_bindings(body, scope, *pat);
320 }
321 scopes.set_scope(arm.expr, scope);
322 compute_expr_scopes(arm.expr, body, scopes, scope);
323 }
324 }
325 e => e.walk_child_exprs(|e| compute_expr_scopes(e, body, scopes, scope)),
326 };
327}
328
329#[derive(Debug)]
330pub struct ReferenceDescriptor {
331 pub range: TextRange,
332 pub name: String,
333}
334
335#[cfg(test)]
336mod tests {
337 use ra_syntax::{SourceFile, algo::find_node_at_offset};
338 use test_utils::{extract_offset, assert_eq_text};
339
340 use crate::expr;
341
342 use super::*;
343
344 fn do_check(code: &str, expected: &[&str]) {
345 let (off, code) = extract_offset(code);
346 let code = {
347 let mut buf = String::new();
348 let off = u32::from(off) as usize;
349 buf.push_str(&code[..off]);
350 buf.push_str("marker");
351 buf.push_str(&code[off..]);
352 buf
353 };
354 let file = SourceFile::parse(&code);
355 let marker: &ast::PathExpr = find_node_at_offset(file.syntax(), off).unwrap();
356 let fn_def: &ast::FnDef = find_node_at_offset(file.syntax(), off).unwrap();
357 let body_hir = expr::collect_fn_body_syntax(fn_def);
358 let scopes = ExprScopes::new(Arc::clone(body_hir.body()));
359 let scopes = ScopesWithSyntaxMapping {
360 scopes: Arc::new(scopes),
361 syntax_mapping: Arc::new(body_hir),
362 };
363 let actual = scopes
364 .scope_chain(marker.syntax())
365 .flat_map(|scope| scopes.scopes.entries(scope))
366 .map(|it| it.name().to_string())
367 .collect::<Vec<_>>()
368 .join("\n");
369 let expected = expected.join("\n");
370 assert_eq_text!(&expected, &actual);
371 }
372
373 #[test]
374 fn test_lambda_scope() {
375 do_check(
376 r"
377 fn quux(foo: i32) {
378 let f = |bar, baz: i32| {
379 <|>
380 };
381 }",
382 &["bar", "baz", "foo"],
383 );
384 }
385
386 #[test]
387 fn test_call_scope() {
388 do_check(
389 r"
390 fn quux() {
391 f(|x| <|> );
392 }",
393 &["x"],
394 );
395 }
396
397 #[test]
398 fn test_method_call_scope() {
399 do_check(
400 r"
401 fn quux() {
402 z.f(|x| <|> );
403 }",
404 &["x"],
405 );
406 }
407
408 #[test]
409 fn test_loop_scope() {
410 do_check(
411 r"
412 fn quux() {
413 loop {
414 let x = ();
415 <|>
416 };
417 }",
418 &["x"],
419 );
420 }
421
422 #[test]
423 fn test_match() {
424 do_check(
425 r"
426 fn quux() {
427 match () {
428 Some(x) => {
429 <|>
430 }
431 };
432 }",
433 &["x"],
434 );
435 }
436
437 #[test]
438 fn test_shadow_variable() {
439 do_check(
440 r"
441 fn foo(x: String) {
442 let x : &str = &x<|>;
443 }",
444 &["x"],
445 );
446 }
447
448 fn do_check_local_name(code: &str, expected_offset: u32) {
449 let (off, code) = extract_offset(code);
450 let file = SourceFile::parse(&code);
451 let expected_name = find_node_at_offset::<ast::Name>(file.syntax(), expected_offset.into())
452 .expect("failed to find a name at the target offset");
453
454 let fn_def: &ast::FnDef = find_node_at_offset(file.syntax(), off).unwrap();
455 let name_ref: &ast::NameRef = find_node_at_offset(file.syntax(), off).unwrap();
456
457 let body_hir = expr::collect_fn_body_syntax(fn_def);
458 let scopes = ExprScopes::new(Arc::clone(body_hir.body()));
459 let scopes = ScopesWithSyntaxMapping {
460 scopes: Arc::new(scopes),
461 syntax_mapping: Arc::new(body_hir),
462 };
463 let local_name_entry = scopes.resolve_local_name(name_ref).unwrap();
464 let local_name = local_name_entry.ptr();
465 assert_eq!(local_name.range(), expected_name.syntax().range());
466 }
467
468 #[test]
469 fn test_resolve_local_name() {
470 do_check_local_name(
471 r#"
472 fn foo(x: i32, y: u32) {
473 {
474 let z = x * 2;
475 }
476 {
477 let t = x<|> * 3;
478 }
479 }"#,
480 21,
481 );
482 }
483
484 #[test]
485 fn test_resolve_local_name_declaration() {
486 do_check_local_name(
487 r#"
488 fn foo(x: String) {
489 let x : &str = &x<|>;
490 }"#,
491 21,
492 );
493 }
494
495 #[test]
496 fn test_resolve_local_name_shadow() {
497 do_check_local_name(
498 r"
499 fn foo(x: String) {
500 let x : &str = &x;
501 x<|>
502 }
503 ",
504 53,
505 );
506 }
507
508 #[test]
509 fn ref_patterns_contribute_bindings() {
510 do_check_local_name(
511 r"
512 fn foo() {
513 if let Some(&from) = bar() {
514 from<|>;
515 }
516 }
517 ",
518 53,
519 );
520 }
521}