aboutsummaryrefslogtreecommitdiff
path: root/crates/ra_hir/src/function/scope.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/ra_hir/src/function/scope.rs')
-rw-r--r--crates/ra_hir/src/function/scope.rs447
1 files changed, 447 insertions, 0 deletions
diff --git a/crates/ra_hir/src/function/scope.rs b/crates/ra_hir/src/function/scope.rs
new file mode 100644
index 000000000..863453291
--- /dev/null
+++ b/crates/ra_hir/src/function/scope.rs
@@ -0,0 +1,447 @@
1use rustc_hash::{FxHashMap, FxHashSet};
2
3use ra_syntax::{
4 AstNode, SmolStr, SyntaxNodeRef, TextRange,
5 algo::generate,
6 ast::{self, ArgListOwner, LoopBodyOwner, NameOwner},
7};
8use ra_db::LocalSyntaxPtr;
9
10use crate::{
11 arena::{Arena, Id},
12};
13
14pub(crate) type ScopeId = Id<ScopeData>;
15
16#[derive(Debug, PartialEq, Eq)]
17pub struct FnScopes {
18 pub self_param: Option<LocalSyntaxPtr>,
19 scopes: Arena<ScopeData>,
20 scope_for: FxHashMap<LocalSyntaxPtr, ScopeId>,
21}
22
23#[derive(Debug, PartialEq, Eq)]
24pub struct ScopeEntry {
25 name: SmolStr,
26 ptr: LocalSyntaxPtr,
27}
28
29#[derive(Debug, PartialEq, Eq)]
30pub struct ScopeData {
31 parent: Option<ScopeId>,
32 entries: Vec<ScopeEntry>,
33}
34
35impl FnScopes {
36 pub fn new(fn_def: ast::FnDef) -> FnScopes {
37 let mut scopes = FnScopes {
38 self_param: fn_def
39 .param_list()
40 .and_then(|it| it.self_param())
41 .map(|it| LocalSyntaxPtr::new(it.syntax())),
42 scopes: Arena::default(),
43 scope_for: FxHashMap::default(),
44 };
45 let root = scopes.root_scope();
46 scopes.add_params_bindings(root, fn_def.param_list());
47 if let Some(body) = fn_def.body() {
48 compute_block_scopes(body, &mut scopes, root)
49 }
50 scopes
51 }
52 pub fn entries(&self, scope: ScopeId) -> &[ScopeEntry] {
53 &self.scopes[scope].entries
54 }
55 pub fn scope_chain<'a>(&'a self, node: SyntaxNodeRef) -> impl Iterator<Item = ScopeId> + 'a {
56 generate(self.scope_for(node), move |&scope| {
57 self.scopes[scope].parent
58 })
59 }
60 pub fn resolve_local_name<'a>(&'a self, name_ref: ast::NameRef) -> Option<&'a ScopeEntry> {
61 let mut shadowed = FxHashSet::default();
62 let ret = self
63 .scope_chain(name_ref.syntax())
64 .flat_map(|scope| self.entries(scope).iter())
65 .filter(|entry| shadowed.insert(entry.name()))
66 .filter(|entry| entry.name() == &name_ref.text())
67 .nth(0);
68 ret
69 }
70
71 pub fn find_all_refs(&self, pat: ast::BindPat) -> Vec<ReferenceDescriptor> {
72 let fn_def = pat.syntax().ancestors().find_map(ast::FnDef::cast).unwrap();
73 let name_ptr = LocalSyntaxPtr::new(pat.syntax());
74 let refs: Vec<_> = fn_def
75 .syntax()
76 .descendants()
77 .filter_map(ast::NameRef::cast)
78 .filter(|name_ref| match self.resolve_local_name(*name_ref) {
79 None => false,
80 Some(entry) => entry.ptr() == name_ptr,
81 })
82 .map(|name_ref| ReferenceDescriptor {
83 name: name_ref.syntax().text().to_string(),
84 range: name_ref.syntax().range(),
85 })
86 .collect();
87
88 refs
89 }
90
91 fn root_scope(&mut self) -> ScopeId {
92 self.scopes.alloc(ScopeData {
93 parent: None,
94 entries: vec![],
95 })
96 }
97 fn new_scope(&mut self, parent: ScopeId) -> ScopeId {
98 self.scopes.alloc(ScopeData {
99 parent: Some(parent),
100 entries: vec![],
101 })
102 }
103 fn add_bindings(&mut self, scope: ScopeId, pat: ast::Pat) {
104 let entries = pat
105 .syntax()
106 .descendants()
107 .filter_map(ast::BindPat::cast)
108 .filter_map(ScopeEntry::new);
109 self.scopes[scope].entries.extend(entries);
110 }
111 fn add_params_bindings(&mut self, scope: ScopeId, params: Option<ast::ParamList>) {
112 params
113 .into_iter()
114 .flat_map(|it| it.params())
115 .filter_map(|it| it.pat())
116 .for_each(|it| self.add_bindings(scope, it));
117 }
118 fn set_scope(&mut self, node: SyntaxNodeRef, scope: ScopeId) {
119 self.scope_for.insert(LocalSyntaxPtr::new(node), scope);
120 }
121 fn scope_for(&self, node: SyntaxNodeRef) -> Option<ScopeId> {
122 node.ancestors()
123 .map(LocalSyntaxPtr::new)
124 .filter_map(|it| self.scope_for.get(&it).map(|&scope| scope))
125 .next()
126 }
127}
128
129impl ScopeEntry {
130 fn new(pat: ast::BindPat) -> Option<ScopeEntry> {
131 let name = pat.name()?;
132 let res = ScopeEntry {
133 name: name.text(),
134 ptr: LocalSyntaxPtr::new(pat.syntax()),
135 };
136 Some(res)
137 }
138 pub fn name(&self) -> &SmolStr {
139 &self.name
140 }
141 pub fn ptr(&self) -> LocalSyntaxPtr {
142 self.ptr
143 }
144}
145
146fn compute_block_scopes(block: ast::Block, scopes: &mut FnScopes, mut scope: ScopeId) {
147 for stmt in block.statements() {
148 match stmt {
149 ast::Stmt::LetStmt(stmt) => {
150 if let Some(expr) = stmt.initializer() {
151 scopes.set_scope(expr.syntax(), scope);
152 compute_expr_scopes(expr, scopes, scope);
153 }
154 scope = scopes.new_scope(scope);
155 if let Some(pat) = stmt.pat() {
156 scopes.add_bindings(scope, pat);
157 }
158 }
159 ast::Stmt::ExprStmt(expr_stmt) => {
160 if let Some(expr) = expr_stmt.expr() {
161 scopes.set_scope(expr.syntax(), scope);
162 compute_expr_scopes(expr, scopes, scope);
163 }
164 }
165 }
166 }
167 if let Some(expr) = block.expr() {
168 scopes.set_scope(expr.syntax(), scope);
169 compute_expr_scopes(expr, scopes, scope);
170 }
171}
172
173fn compute_expr_scopes(expr: ast::Expr, scopes: &mut FnScopes, scope: ScopeId) {
174 match expr {
175 ast::Expr::IfExpr(e) => {
176 let cond_scope = e
177 .condition()
178 .and_then(|cond| compute_cond_scopes(cond, scopes, scope));
179 if let Some(block) = e.then_branch() {
180 compute_block_scopes(block, scopes, cond_scope.unwrap_or(scope));
181 }
182 if let Some(block) = e.else_branch() {
183 compute_block_scopes(block, scopes, scope);
184 }
185 }
186 ast::Expr::BlockExpr(e) => {
187 if let Some(block) = e.block() {
188 compute_block_scopes(block, scopes, scope);
189 }
190 }
191 ast::Expr::LoopExpr(e) => {
192 if let Some(block) = e.loop_body() {
193 compute_block_scopes(block, scopes, scope);
194 }
195 }
196 ast::Expr::WhileExpr(e) => {
197 let cond_scope = e
198 .condition()
199 .and_then(|cond| compute_cond_scopes(cond, scopes, scope));
200 if let Some(block) = e.loop_body() {
201 compute_block_scopes(block, scopes, cond_scope.unwrap_or(scope));
202 }
203 }
204 ast::Expr::ForExpr(e) => {
205 if let Some(expr) = e.iterable() {
206 compute_expr_scopes(expr, scopes, scope);
207 }
208 let mut scope = scope;
209 if let Some(pat) = e.pat() {
210 scope = scopes.new_scope(scope);
211 scopes.add_bindings(scope, pat);
212 }
213 if let Some(block) = e.loop_body() {
214 compute_block_scopes(block, scopes, scope);
215 }
216 }
217 ast::Expr::LambdaExpr(e) => {
218 let scope = scopes.new_scope(scope);
219 scopes.add_params_bindings(scope, e.param_list());
220 if let Some(body) = e.body() {
221 scopes.set_scope(body.syntax(), scope);
222 compute_expr_scopes(body, scopes, scope);
223 }
224 }
225 ast::Expr::CallExpr(e) => {
226 compute_call_scopes(e.expr(), e.arg_list(), scopes, scope);
227 }
228 ast::Expr::MethodCallExpr(e) => {
229 compute_call_scopes(e.expr(), e.arg_list(), scopes, scope);
230 }
231 ast::Expr::MatchExpr(e) => {
232 if let Some(expr) = e.expr() {
233 compute_expr_scopes(expr, scopes, scope);
234 }
235 for arm in e.match_arm_list().into_iter().flat_map(|it| it.arms()) {
236 let scope = scopes.new_scope(scope);
237 for pat in arm.pats() {
238 scopes.add_bindings(scope, pat);
239 }
240 if let Some(expr) = arm.expr() {
241 compute_expr_scopes(expr, scopes, scope);
242 }
243 }
244 }
245 _ => expr
246 .syntax()
247 .children()
248 .filter_map(ast::Expr::cast)
249 .for_each(|expr| compute_expr_scopes(expr, scopes, scope)),
250 };
251
252 fn compute_call_scopes(
253 receiver: Option<ast::Expr>,
254 arg_list: Option<ast::ArgList>,
255 scopes: &mut FnScopes,
256 scope: ScopeId,
257 ) {
258 arg_list
259 .into_iter()
260 .flat_map(|it| it.args())
261 .chain(receiver)
262 .for_each(|expr| compute_expr_scopes(expr, scopes, scope));
263 }
264
265 fn compute_cond_scopes(
266 cond: ast::Condition,
267 scopes: &mut FnScopes,
268 scope: ScopeId,
269 ) -> Option<ScopeId> {
270 if let Some(expr) = cond.expr() {
271 compute_expr_scopes(expr, scopes, scope);
272 }
273 if let Some(pat) = cond.pat() {
274 let s = scopes.new_scope(scope);
275 scopes.add_bindings(s, pat);
276 Some(s)
277 } else {
278 None
279 }
280 }
281}
282
283#[derive(Debug)]
284pub struct ReferenceDescriptor {
285 pub range: TextRange,
286 pub name: String,
287}
288
289#[cfg(test)]
290mod tests {
291 use ra_editor::find_node_at_offset;
292 use ra_syntax::SourceFileNode;
293 use test_utils::extract_offset;
294
295 use super::*;
296
297 fn do_check(code: &str, expected: &[&str]) {
298 let (off, code) = extract_offset(code);
299 let code = {
300 let mut buf = String::new();
301 let off = u32::from(off) as usize;
302 buf.push_str(&code[..off]);
303 buf.push_str("marker");
304 buf.push_str(&code[off..]);
305 buf
306 };
307 let file = SourceFileNode::parse(&code);
308 let marker: ast::PathExpr = find_node_at_offset(file.syntax(), off).unwrap();
309 let fn_def: ast::FnDef = find_node_at_offset(file.syntax(), off).unwrap();
310 let scopes = FnScopes::new(fn_def);
311 let actual = scopes
312 .scope_chain(marker.syntax())
313 .flat_map(|scope| scopes.entries(scope))
314 .map(|it| it.name())
315 .collect::<Vec<_>>();
316 assert_eq!(actual.as_slice(), expected);
317 }
318
319 #[test]
320 fn test_lambda_scope() {
321 do_check(
322 r"
323 fn quux(foo: i32) {
324 let f = |bar, baz: i32| {
325 <|>
326 };
327 }",
328 &["bar", "baz", "foo"],
329 );
330 }
331
332 #[test]
333 fn test_call_scope() {
334 do_check(
335 r"
336 fn quux() {
337 f(|x| <|> );
338 }",
339 &["x"],
340 );
341 }
342
343 #[test]
344 fn test_metod_call_scope() {
345 do_check(
346 r"
347 fn quux() {
348 z.f(|x| <|> );
349 }",
350 &["x"],
351 );
352 }
353
354 #[test]
355 fn test_loop_scope() {
356 do_check(
357 r"
358 fn quux() {
359 loop {
360 let x = ();
361 <|>
362 };
363 }",
364 &["x"],
365 );
366 }
367
368 #[test]
369 fn test_match() {
370 do_check(
371 r"
372 fn quux() {
373 match () {
374 Some(x) => {
375 <|>
376 }
377 };
378 }",
379 &["x"],
380 );
381 }
382
383 #[test]
384 fn test_shadow_variable() {
385 do_check(
386 r"
387 fn foo(x: String) {
388 let x : &str = &x<|>;
389 }",
390 &["x"],
391 );
392 }
393
394 fn do_check_local_name(code: &str, expected_offset: u32) {
395 let (off, code) = extract_offset(code);
396 let file = SourceFileNode::parse(&code);
397 let fn_def: ast::FnDef = find_node_at_offset(file.syntax(), off).unwrap();
398 let name_ref: ast::NameRef = find_node_at_offset(file.syntax(), off).unwrap();
399
400 let scopes = FnScopes::new(fn_def);
401
402 let local_name_entry = scopes.resolve_local_name(name_ref).unwrap();
403 let local_name = local_name_entry.ptr().resolve(&file);
404 let expected_name =
405 find_node_at_offset::<ast::Name>(file.syntax(), expected_offset.into()).unwrap();
406 assert_eq!(local_name.range(), expected_name.syntax().range());
407 }
408
409 #[test]
410 fn test_resolve_local_name() {
411 do_check_local_name(
412 r#"
413 fn foo(x: i32, y: u32) {
414 {
415 let z = x * 2;
416 }
417 {
418 let t = x<|> * 3;
419 }
420 }"#,
421 21,
422 );
423 }
424
425 #[test]
426 fn test_resolve_local_name_declaration() {
427 do_check_local_name(
428 r#"
429 fn foo(x: String) {
430 let x : &str = &x<|>;
431 }"#,
432 21,
433 );
434 }
435
436 #[test]
437 fn test_resolve_local_name_shadow() {
438 do_check_local_name(
439 r"
440 fn foo(x: String) {
441 let x : &str = &x;
442 x<|>
443 }",
444 46,
445 );
446 }
447}