aboutsummaryrefslogtreecommitdiff
path: root/crates/hir_def/src/body/scope.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/hir_def/src/body/scope.rs')
-rw-r--r--crates/hir_def/src/body/scope.rs456
1 files changed, 456 insertions, 0 deletions
diff --git a/crates/hir_def/src/body/scope.rs b/crates/hir_def/src/body/scope.rs
new file mode 100644
index 000000000..9142bc05b
--- /dev/null
+++ b/crates/hir_def/src/body/scope.rs
@@ -0,0 +1,456 @@
1//! Name resolution for expressions.
2use std::sync::Arc;
3
4use arena::{Arena, Idx};
5use hir_expand::name::Name;
6use rustc_hash::FxHashMap;
7
8use crate::{
9 body::Body,
10 db::DefDatabase,
11 expr::{Expr, ExprId, Pat, PatId, Statement},
12 DefWithBodyId,
13};
14
15pub type ScopeId = Idx<ScopeData>;
16
17#[derive(Debug, PartialEq, Eq)]
18pub struct ExprScopes {
19 scopes: Arena<ScopeData>,
20 scope_by_expr: FxHashMap<ExprId, ScopeId>,
21}
22
23#[derive(Debug, PartialEq, Eq)]
24pub struct ScopeEntry {
25 name: Name,
26 pat: PatId,
27}
28
29impl ScopeEntry {
30 pub fn name(&self) -> &Name {
31 &self.name
32 }
33
34 pub fn pat(&self) -> PatId {
35 self.pat
36 }
37}
38
39#[derive(Debug, PartialEq, Eq)]
40pub struct ScopeData {
41 parent: Option<ScopeId>,
42 entries: Vec<ScopeEntry>,
43}
44
45impl ExprScopes {
46 pub(crate) fn expr_scopes_query(db: &dyn DefDatabase, def: DefWithBodyId) -> Arc<ExprScopes> {
47 let body = db.body(def);
48 Arc::new(ExprScopes::new(&*body))
49 }
50
51 fn new(body: &Body) -> ExprScopes {
52 let mut scopes =
53 ExprScopes { scopes: Arena::default(), scope_by_expr: FxHashMap::default() };
54 let root = scopes.root_scope();
55 scopes.add_params_bindings(body, root, &body.params);
56 compute_expr_scopes(body.body_expr, body, &mut scopes, root);
57 scopes
58 }
59
60 pub fn entries(&self, scope: ScopeId) -> &[ScopeEntry] {
61 &self.scopes[scope].entries
62 }
63
64 pub fn scope_chain(&self, scope: Option<ScopeId>) -> impl Iterator<Item = ScopeId> + '_ {
65 std::iter::successors(scope, move |&scope| self.scopes[scope].parent)
66 }
67
68 pub fn resolve_name_in_scope(&self, scope: ScopeId, name: &Name) -> Option<&ScopeEntry> {
69 self.scope_chain(Some(scope))
70 .find_map(|scope| self.entries(scope).iter().find(|it| it.name == *name))
71 }
72
73 pub fn scope_for(&self, expr: ExprId) -> Option<ScopeId> {
74 self.scope_by_expr.get(&expr).copied()
75 }
76
77 pub fn scope_by_expr(&self) -> &FxHashMap<ExprId, ScopeId> {
78 &self.scope_by_expr
79 }
80
81 fn root_scope(&mut self) -> ScopeId {
82 self.scopes.alloc(ScopeData { parent: None, entries: vec![] })
83 }
84
85 fn new_scope(&mut self, parent: ScopeId) -> ScopeId {
86 self.scopes.alloc(ScopeData { parent: Some(parent), entries: vec![] })
87 }
88
89 fn add_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId) {
90 let pattern = &body[pat];
91 if let Pat::Bind { name, .. } = pattern {
92 let entry = ScopeEntry { name: name.clone(), pat };
93 self.scopes[scope].entries.push(entry);
94 }
95
96 pattern.walk_child_pats(|pat| self.add_bindings(body, scope, pat));
97 }
98
99 fn add_params_bindings(&mut self, body: &Body, scope: ScopeId, params: &[PatId]) {
100 params.iter().for_each(|pat| self.add_bindings(body, scope, *pat));
101 }
102
103 fn set_scope(&mut self, node: ExprId, scope: ScopeId) {
104 self.scope_by_expr.insert(node, scope);
105 }
106}
107
108fn compute_block_scopes(
109 statements: &[Statement],
110 tail: Option<ExprId>,
111 body: &Body,
112 scopes: &mut ExprScopes,
113 mut scope: ScopeId,
114) {
115 for stmt in statements {
116 match stmt {
117 Statement::Let { pat, initializer, .. } => {
118 if let Some(expr) = initializer {
119 scopes.set_scope(*expr, scope);
120 compute_expr_scopes(*expr, body, scopes, scope);
121 }
122 scope = scopes.new_scope(scope);
123 scopes.add_bindings(body, scope, *pat);
124 }
125 Statement::Expr(expr) => {
126 scopes.set_scope(*expr, scope);
127 compute_expr_scopes(*expr, body, scopes, scope);
128 }
129 }
130 }
131 if let Some(expr) = tail {
132 compute_expr_scopes(expr, body, scopes, scope);
133 }
134}
135
136fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: ScopeId) {
137 scopes.set_scope(expr, scope);
138 match &body[expr] {
139 Expr::Block { statements, tail, .. } => {
140 compute_block_scopes(&statements, *tail, body, scopes, scope);
141 }
142 Expr::For { iterable, pat, body: body_expr, .. } => {
143 compute_expr_scopes(*iterable, body, scopes, scope);
144 let scope = scopes.new_scope(scope);
145 scopes.add_bindings(body, scope, *pat);
146 compute_expr_scopes(*body_expr, body, scopes, scope);
147 }
148 Expr::Lambda { args, body: body_expr, .. } => {
149 let scope = scopes.new_scope(scope);
150 scopes.add_params_bindings(body, scope, &args);
151 compute_expr_scopes(*body_expr, body, scopes, scope);
152 }
153 Expr::Match { expr, arms } => {
154 compute_expr_scopes(*expr, body, scopes, scope);
155 for arm in arms {
156 let scope = scopes.new_scope(scope);
157 scopes.add_bindings(body, scope, arm.pat);
158 if let Some(guard) = arm.guard {
159 scopes.set_scope(guard, scope);
160 compute_expr_scopes(guard, body, scopes, scope);
161 }
162 scopes.set_scope(arm.expr, scope);
163 compute_expr_scopes(arm.expr, body, scopes, scope);
164 }
165 }
166 e => e.walk_child_exprs(|e| compute_expr_scopes(e, body, scopes, scope)),
167 };
168}
169
170#[cfg(test)]
171mod tests {
172 use base_db::{fixture::WithFixture, FileId, SourceDatabase};
173 use hir_expand::{name::AsName, InFile};
174 use syntax::{algo::find_node_at_offset, ast, AstNode};
175 use test_utils::{assert_eq_text, extract_offset, mark};
176
177 use crate::{db::DefDatabase, test_db::TestDB, FunctionId, ModuleDefId};
178
179 fn find_function(db: &TestDB, file_id: FileId) -> FunctionId {
180 let krate = db.test_crate();
181 let crate_def_map = db.crate_def_map(krate);
182
183 let module = crate_def_map.modules_for_file(file_id).next().unwrap();
184 let (_, def) = crate_def_map[module].scope.entries().next().unwrap();
185 match def.take_values().unwrap() {
186 ModuleDefId::FunctionId(it) => it,
187 _ => panic!(),
188 }
189 }
190
191 fn do_check(ra_fixture: &str, expected: &[&str]) {
192 let (offset, code) = extract_offset(ra_fixture);
193 let code = {
194 let mut buf = String::new();
195 let off: usize = offset.into();
196 buf.push_str(&code[..off]);
197 buf.push_str("<|>marker");
198 buf.push_str(&code[off..]);
199 buf
200 };
201
202 let (db, position) = TestDB::with_position(&code);
203 let file_id = position.file_id;
204 let offset = position.offset;
205
206 let file_syntax = db.parse(file_id).syntax_node();
207 let marker: ast::PathExpr = find_node_at_offset(&file_syntax, offset).unwrap();
208 let function = find_function(&db, file_id);
209
210 let scopes = db.expr_scopes(function.into());
211 let (_body, source_map) = db.body_with_source_map(function.into());
212
213 let expr_id = source_map
214 .node_expr(InFile { file_id: file_id.into(), value: &marker.into() })
215 .unwrap();
216 let scope = scopes.scope_for(expr_id);
217
218 let actual = scopes
219 .scope_chain(scope)
220 .flat_map(|scope| scopes.entries(scope))
221 .map(|it| it.name().to_string())
222 .collect::<Vec<_>>()
223 .join("\n");
224 let expected = expected.join("\n");
225 assert_eq_text!(&expected, &actual);
226 }
227
228 #[test]
229 fn test_lambda_scope() {
230 do_check(
231 r"
232 fn quux(foo: i32) {
233 let f = |bar, baz: i32| {
234 <|>
235 };
236 }",
237 &["bar", "baz", "foo"],
238 );
239 }
240
241 #[test]
242 fn test_call_scope() {
243 do_check(
244 r"
245 fn quux() {
246 f(|x| <|> );
247 }",
248 &["x"],
249 );
250 }
251
252 #[test]
253 fn test_method_call_scope() {
254 do_check(
255 r"
256 fn quux() {
257 z.f(|x| <|> );
258 }",
259 &["x"],
260 );
261 }
262
263 #[test]
264 fn test_loop_scope() {
265 do_check(
266 r"
267 fn quux() {
268 loop {
269 let x = ();
270 <|>
271 };
272 }",
273 &["x"],
274 );
275 }
276
277 #[test]
278 fn test_match() {
279 do_check(
280 r"
281 fn quux() {
282 match () {
283 Some(x) => {
284 <|>
285 }
286 };
287 }",
288 &["x"],
289 );
290 }
291
292 #[test]
293 fn test_shadow_variable() {
294 do_check(
295 r"
296 fn foo(x: String) {
297 let x : &str = &x<|>;
298 }",
299 &["x"],
300 );
301 }
302
303 #[test]
304 fn test_bindings_after_at() {
305 do_check(
306 r"
307fn foo() {
308 match Some(()) {
309 opt @ Some(unit) => {
310 <|>
311 }
312 _ => {}
313 }
314}
315",
316 &["opt", "unit"],
317 );
318 }
319
320 #[test]
321 fn macro_inner_item() {
322 do_check(
323 r"
324 macro_rules! mac {
325 () => {{
326 fn inner() {}
327 inner();
328 }};
329 }
330
331 fn foo() {
332 mac!();
333 <|>
334 }
335 ",
336 &[],
337 );
338 }
339
340 #[test]
341 fn broken_inner_item() {
342 do_check(
343 r"
344 fn foo() {
345 trait {}
346 <|>
347 }
348 ",
349 &[],
350 );
351 }
352
353 fn do_check_local_name(ra_fixture: &str, expected_offset: u32) {
354 let (db, position) = TestDB::with_position(ra_fixture);
355 let file_id = position.file_id;
356 let offset = position.offset;
357
358 let file = db.parse(file_id).ok().unwrap();
359 let expected_name = find_node_at_offset::<ast::Name>(file.syntax(), expected_offset.into())
360 .expect("failed to find a name at the target offset");
361 let name_ref: ast::NameRef = find_node_at_offset(file.syntax(), offset).unwrap();
362
363 let function = find_function(&db, file_id);
364
365 let scopes = db.expr_scopes(function.into());
366 let (_body, source_map) = db.body_with_source_map(function.into());
367
368 let expr_scope = {
369 let expr_ast = name_ref.syntax().ancestors().find_map(ast::Expr::cast).unwrap();
370 let expr_id =
371 source_map.node_expr(InFile { file_id: file_id.into(), value: &expr_ast }).unwrap();
372 scopes.scope_for(expr_id).unwrap()
373 };
374
375 let resolved = scopes.resolve_name_in_scope(expr_scope, &name_ref.as_name()).unwrap();
376 let pat_src = source_map.pat_syntax(resolved.pat()).unwrap();
377
378 let local_name = pat_src.value.either(
379 |it| it.syntax_node_ptr().to_node(file.syntax()),
380 |it| it.syntax_node_ptr().to_node(file.syntax()),
381 );
382 assert_eq!(local_name.text_range(), expected_name.syntax().text_range());
383 }
384
385 #[test]
386 fn test_resolve_local_name() {
387 do_check_local_name(
388 r#"
389fn foo(x: i32, y: u32) {
390 {
391 let z = x * 2;
392 }
393 {
394 let t = x<|> * 3;
395 }
396}
397"#,
398 7,
399 );
400 }
401
402 #[test]
403 fn test_resolve_local_name_declaration() {
404 do_check_local_name(
405 r#"
406fn foo(x: String) {
407 let x : &str = &x<|>;
408}
409"#,
410 7,
411 );
412 }
413
414 #[test]
415 fn test_resolve_local_name_shadow() {
416 do_check_local_name(
417 r"
418fn foo(x: String) {
419 let x : &str = &x;
420 x<|>
421}
422",
423 28,
424 );
425 }
426
427 #[test]
428 fn ref_patterns_contribute_bindings() {
429 do_check_local_name(
430 r"
431fn foo() {
432 if let Some(&from) = bar() {
433 from<|>;
434 }
435}
436",
437 28,
438 );
439 }
440
441 #[test]
442 fn while_let_desugaring() {
443 mark::check!(infer_resolve_while_let);
444 do_check_local_name(
445 r#"
446fn test() {
447 let foo: Option<f32> = None;
448 while let Option::Some(spam) = foo {
449 spam<|>
450 }
451}
452"#,
453 75,
454 );
455 }
456}