aboutsummaryrefslogtreecommitdiff
path: root/crates/ra_hir/src/function/scope.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/ra_hir/src/function/scope.rs')
-rw-r--r--crates/ra_hir/src/function/scope.rs450
1 files changed, 450 insertions, 0 deletions
diff --git a/crates/ra_hir/src/function/scope.rs b/crates/ra_hir/src/function/scope.rs
new file mode 100644
index 000000000..c8b6b1934
--- /dev/null
+++ b/crates/ra_hir/src/function/scope.rs
@@ -0,0 +1,450 @@
1use rustc_hash::{FxHashMap, FxHashSet};
2
3use ra_syntax::{
4 AstNode, SmolStr, SyntaxNodeRef, TextRange,
5 algo::generate,
6 ast::{self, ArgListOwner, LoopBodyOwner, NameOwner},
7};
8use ra_db::LocalSyntaxPtr;
9
10use crate::{
11 arena::{Arena, Id},
12};
13
14pub(crate) type ScopeId = Id<ScopeData>;
15
16#[derive(Debug, PartialEq, Eq)]
17pub struct FnScopes {
18 pub(crate) self_param: Option<LocalSyntaxPtr>,
19 scopes: Arena<ScopeData>,
20 scope_for: FxHashMap<LocalSyntaxPtr, ScopeId>,
21}
22
23#[derive(Debug, PartialEq, Eq)]
24pub struct ScopeEntry {
25 name: SmolStr,
26 ptr: LocalSyntaxPtr,
27}
28
29#[derive(Debug, PartialEq, Eq)]
30pub(crate) struct ScopeData {
31 parent: Option<ScopeId>,
32 entries: Vec<ScopeEntry>,
33}
34
35impl FnScopes {
36 pub(crate) fn new(fn_def: ast::FnDef) -> FnScopes {
37 let mut scopes = FnScopes {
38 self_param: fn_def
39 .param_list()
40 .and_then(|it| it.self_param())
41 .map(|it| LocalSyntaxPtr::new(it.syntax())),
42 scopes: Arena::default(),
43 scope_for: FxHashMap::default(),
44 };
45 let root = scopes.root_scope();
46 scopes.add_params_bindings(root, fn_def.param_list());
47 if let Some(body) = fn_def.body() {
48 compute_block_scopes(body, &mut scopes, root)
49 }
50 scopes
51 }
52 pub(crate) fn entries(&self, scope: ScopeId) -> &[ScopeEntry] {
53 &self.scopes[scope].entries
54 }
55 pub fn scope_chain<'a>(&'a self, node: SyntaxNodeRef) -> impl Iterator<Item = ScopeId> + 'a {
56 generate(self.scope_for(node), move |&scope| {
57 self.scopes[scope].parent
58 })
59 }
60 pub(crate) fn resolve_local_name<'a>(
61 &'a self,
62 name_ref: ast::NameRef,
63 ) -> Option<&'a ScopeEntry> {
64 let mut shadowed = FxHashSet::default();
65 let ret = self
66 .scope_chain(name_ref.syntax())
67 .flat_map(|scope| self.entries(scope).iter())
68 .filter(|entry| shadowed.insert(entry.name()))
69 .filter(|entry| entry.name() == &name_ref.text())
70 .nth(0);
71 ret
72 }
73
74 pub fn find_all_refs(&self, pat: ast::BindPat) -> Vec<ReferenceDescriptor> {
75 let fn_def = pat.syntax().ancestors().find_map(ast::FnDef::cast).unwrap();
76 let name_ptr = LocalSyntaxPtr::new(pat.syntax());
77 let refs: Vec<_> = fn_def
78 .syntax()
79 .descendants()
80 .filter_map(ast::NameRef::cast)
81 .filter(|name_ref| match self.resolve_local_name(*name_ref) {
82 None => false,
83 Some(entry) => entry.ptr() == name_ptr,
84 })
85 .map(|name_ref| ReferenceDescriptor {
86 name: name_ref.syntax().text().to_string(),
87 range: name_ref.syntax().range(),
88 })
89 .collect();
90
91 refs
92 }
93
94 fn root_scope(&mut self) -> ScopeId {
95 self.scopes.alloc(ScopeData {
96 parent: None,
97 entries: vec![],
98 })
99 }
100 fn new_scope(&mut self, parent: ScopeId) -> ScopeId {
101 self.scopes.alloc(ScopeData {
102 parent: Some(parent),
103 entries: vec![],
104 })
105 }
106 fn add_bindings(&mut self, scope: ScopeId, pat: ast::Pat) {
107 let entries = pat
108 .syntax()
109 .descendants()
110 .filter_map(ast::BindPat::cast)
111 .filter_map(ScopeEntry::new);
112 self.scopes[scope].entries.extend(entries);
113 }
114 fn add_params_bindings(&mut self, scope: ScopeId, params: Option<ast::ParamList>) {
115 params
116 .into_iter()
117 .flat_map(|it| it.params())
118 .filter_map(|it| it.pat())
119 .for_each(|it| self.add_bindings(scope, it));
120 }
121 fn set_scope(&mut self, node: SyntaxNodeRef, scope: ScopeId) {
122 self.scope_for.insert(LocalSyntaxPtr::new(node), scope);
123 }
124 fn scope_for(&self, node: SyntaxNodeRef) -> Option<ScopeId> {
125 node.ancestors()
126 .map(LocalSyntaxPtr::new)
127 .filter_map(|it| self.scope_for.get(&it).map(|&scope| scope))
128 .next()
129 }
130}
131
132impl ScopeEntry {
133 fn new(pat: ast::BindPat) -> Option<ScopeEntry> {
134 let name = pat.name()?;
135 let res = ScopeEntry {
136 name: name.text(),
137 ptr: LocalSyntaxPtr::new(pat.syntax()),
138 };
139 Some(res)
140 }
141 pub(crate) fn name(&self) -> &SmolStr {
142 &self.name
143 }
144 pub(crate) fn ptr(&self) -> LocalSyntaxPtr {
145 self.ptr
146 }
147}
148
149fn compute_block_scopes(block: ast::Block, scopes: &mut FnScopes, mut scope: ScopeId) {
150 for stmt in block.statements() {
151 match stmt {
152 ast::Stmt::LetStmt(stmt) => {
153 if let Some(expr) = stmt.initializer() {
154 scopes.set_scope(expr.syntax(), scope);
155 compute_expr_scopes(expr, scopes, scope);
156 }
157 scope = scopes.new_scope(scope);
158 if let Some(pat) = stmt.pat() {
159 scopes.add_bindings(scope, pat);
160 }
161 }
162 ast::Stmt::ExprStmt(expr_stmt) => {
163 if let Some(expr) = expr_stmt.expr() {
164 scopes.set_scope(expr.syntax(), scope);
165 compute_expr_scopes(expr, scopes, scope);
166 }
167 }
168 }
169 }
170 if let Some(expr) = block.expr() {
171 scopes.set_scope(expr.syntax(), scope);
172 compute_expr_scopes(expr, scopes, scope);
173 }
174}
175
176fn compute_expr_scopes(expr: ast::Expr, scopes: &mut FnScopes, scope: ScopeId) {
177 match expr {
178 ast::Expr::IfExpr(e) => {
179 let cond_scope = e
180 .condition()
181 .and_then(|cond| compute_cond_scopes(cond, scopes, scope));
182 if let Some(block) = e.then_branch() {
183 compute_block_scopes(block, scopes, cond_scope.unwrap_or(scope));
184 }
185 if let Some(block) = e.else_branch() {
186 compute_block_scopes(block, scopes, scope);
187 }
188 }
189 ast::Expr::BlockExpr(e) => {
190 if let Some(block) = e.block() {
191 compute_block_scopes(block, scopes, scope);
192 }
193 }
194 ast::Expr::LoopExpr(e) => {
195 if let Some(block) = e.loop_body() {
196 compute_block_scopes(block, scopes, scope);
197 }
198 }
199 ast::Expr::WhileExpr(e) => {
200 let cond_scope = e
201 .condition()
202 .and_then(|cond| compute_cond_scopes(cond, scopes, scope));
203 if let Some(block) = e.loop_body() {
204 compute_block_scopes(block, scopes, cond_scope.unwrap_or(scope));
205 }
206 }
207 ast::Expr::ForExpr(e) => {
208 if let Some(expr) = e.iterable() {
209 compute_expr_scopes(expr, scopes, scope);
210 }
211 let mut scope = scope;
212 if let Some(pat) = e.pat() {
213 scope = scopes.new_scope(scope);
214 scopes.add_bindings(scope, pat);
215 }
216 if let Some(block) = e.loop_body() {
217 compute_block_scopes(block, scopes, scope);
218 }
219 }
220 ast::Expr::LambdaExpr(e) => {
221 let scope = scopes.new_scope(scope);
222 scopes.add_params_bindings(scope, e.param_list());
223 if let Some(body) = e.body() {
224 scopes.set_scope(body.syntax(), scope);
225 compute_expr_scopes(body, scopes, scope);
226 }
227 }
228 ast::Expr::CallExpr(e) => {
229 compute_call_scopes(e.expr(), e.arg_list(), scopes, scope);
230 }
231 ast::Expr::MethodCallExpr(e) => {
232 compute_call_scopes(e.expr(), e.arg_list(), scopes, scope);
233 }
234 ast::Expr::MatchExpr(e) => {
235 if let Some(expr) = e.expr() {
236 compute_expr_scopes(expr, scopes, scope);
237 }
238 for arm in e.match_arm_list().into_iter().flat_map(|it| it.arms()) {
239 let scope = scopes.new_scope(scope);
240 for pat in arm.pats() {
241 scopes.add_bindings(scope, pat);
242 }
243 if let Some(expr) = arm.expr() {
244 compute_expr_scopes(expr, scopes, scope);
245 }
246 }
247 }
248 _ => expr
249 .syntax()
250 .children()
251 .filter_map(ast::Expr::cast)
252 .for_each(|expr| compute_expr_scopes(expr, scopes, scope)),
253 };
254
255 fn compute_call_scopes(
256 receiver: Option<ast::Expr>,
257 arg_list: Option<ast::ArgList>,
258 scopes: &mut FnScopes,
259 scope: ScopeId,
260 ) {
261 arg_list
262 .into_iter()
263 .flat_map(|it| it.args())
264 .chain(receiver)
265 .for_each(|expr| compute_expr_scopes(expr, scopes, scope));
266 }
267
268 fn compute_cond_scopes(
269 cond: ast::Condition,
270 scopes: &mut FnScopes,
271 scope: ScopeId,
272 ) -> Option<ScopeId> {
273 if let Some(expr) = cond.expr() {
274 compute_expr_scopes(expr, scopes, scope);
275 }
276 if let Some(pat) = cond.pat() {
277 let s = scopes.new_scope(scope);
278 scopes.add_bindings(s, pat);
279 Some(s)
280 } else {
281 None
282 }
283 }
284}
285
286#[derive(Debug)]
287pub struct ReferenceDescriptor {
288 pub range: TextRange,
289 pub name: String,
290}
291
292#[cfg(test)]
293mod tests {
294 use ra_editor::find_node_at_offset;
295 use ra_syntax::SourceFileNode;
296 use test_utils::extract_offset;
297
298 use super::*;
299
300 fn do_check(code: &str, expected: &[&str]) {
301 let (off, code) = extract_offset(code);
302 let code = {
303 let mut buf = String::new();
304 let off = u32::from(off) as usize;
305 buf.push_str(&code[..off]);
306 buf.push_str("marker");
307 buf.push_str(&code[off..]);
308 buf
309 };
310 let file = SourceFileNode::parse(&code);
311 let marker: ast::PathExpr = find_node_at_offset(file.syntax(), off).unwrap();
312 let fn_def: ast::FnDef = find_node_at_offset(file.syntax(), off).unwrap();
313 let scopes = FnScopes::new(fn_def);
314 let actual = scopes
315 .scope_chain(marker.syntax())
316 .flat_map(|scope| scopes.entries(scope))
317 .map(|it| it.name())
318 .collect::<Vec<_>>();
319 assert_eq!(actual.as_slice(), expected);
320 }
321
322 #[test]
323 fn test_lambda_scope() {
324 do_check(
325 r"
326 fn quux(foo: i32) {
327 let f = |bar, baz: i32| {
328 <|>
329 };
330 }",
331 &["bar", "baz", "foo"],
332 );
333 }
334
335 #[test]
336 fn test_call_scope() {
337 do_check(
338 r"
339 fn quux() {
340 f(|x| <|> );
341 }",
342 &["x"],
343 );
344 }
345
346 #[test]
347 fn test_metod_call_scope() {
348 do_check(
349 r"
350 fn quux() {
351 z.f(|x| <|> );
352 }",
353 &["x"],
354 );
355 }
356
357 #[test]
358 fn test_loop_scope() {
359 do_check(
360 r"
361 fn quux() {
362 loop {
363 let x = ();
364 <|>
365 };
366 }",
367 &["x"],
368 );
369 }
370
371 #[test]
372 fn test_match() {
373 do_check(
374 r"
375 fn quux() {
376 match () {
377 Some(x) => {
378 <|>
379 }
380 };
381 }",
382 &["x"],
383 );
384 }
385
386 #[test]
387 fn test_shadow_variable() {
388 do_check(
389 r"
390 fn foo(x: String) {
391 let x : &str = &x<|>;
392 }",
393 &["x"],
394 );
395 }
396
397 fn do_check_local_name(code: &str, expected_offset: u32) {
398 let (off, code) = extract_offset(code);
399 let file = SourceFileNode::parse(&code);
400 let fn_def: ast::FnDef = find_node_at_offset(file.syntax(), off).unwrap();
401 let name_ref: ast::NameRef = find_node_at_offset(file.syntax(), off).unwrap();
402
403 let scopes = FnScopes::new(fn_def);
404
405 let local_name_entry = scopes.resolve_local_name(name_ref).unwrap();
406 let local_name = local_name_entry.ptr().resolve(&file);
407 let expected_name =
408 find_node_at_offset::<ast::Name>(file.syntax(), expected_offset.into()).unwrap();
409 assert_eq!(local_name.range(), expected_name.syntax().range());
410 }
411
412 #[test]
413 fn test_resolve_local_name() {
414 do_check_local_name(
415 r#"
416 fn foo(x: i32, y: u32) {
417 {
418 let z = x * 2;
419 }
420 {
421 let t = x<|> * 3;
422 }
423 }"#,
424 21,
425 );
426 }
427
428 #[test]
429 fn test_resolve_local_name_declaration() {
430 do_check_local_name(
431 r#"
432 fn foo(x: String) {
433 let x : &str = &x<|>;
434 }"#,
435 21,
436 );
437 }
438
439 #[test]
440 fn test_resolve_local_name_shadow() {
441 do_check_local_name(
442 r"
443 fn foo(x: String) {
444 let x : &str = &x;
445 x<|>
446 }",
447 46,
448 );
449 }
450}