diff options
Diffstat (limited to 'crates/ra_ide')
-rw-r--r-- | crates/ra_ide/src/ssr.rs | 122 |
1 files changed, 110 insertions, 12 deletions
diff --git a/crates/ra_ide/src/ssr.rs b/crates/ra_ide/src/ssr.rs index 1abb891c1..7b93ff2d2 100644 --- a/crates/ra_ide/src/ssr.rs +++ b/crates/ra_ide/src/ssr.rs | |||
@@ -5,12 +5,14 @@ use ra_db::{SourceDatabase, SourceDatabaseExt}; | |||
5 | use ra_ide_db::symbol_index::SymbolsDatabase; | 5 | use ra_ide_db::symbol_index::SymbolsDatabase; |
6 | use ra_ide_db::RootDatabase; | 6 | use ra_ide_db::RootDatabase; |
7 | use ra_syntax::ast::make::try_expr_from_text; | 7 | use ra_syntax::ast::make::try_expr_from_text; |
8 | use ra_syntax::ast::{AstToken, Comment, RecordField, RecordLit}; | 8 | use ra_syntax::ast::{ |
9 | use ra_syntax::{AstNode, SyntaxElement, SyntaxNode}; | 9 | ArgList, AstToken, CallExpr, Comment, Expr, MethodCallExpr, RecordField, RecordLit, |
10 | }; | ||
11 | use ra_syntax::{AstNode, SyntaxElement, SyntaxKind, SyntaxNode}; | ||
10 | use ra_text_edit::{TextEdit, TextEditBuilder}; | 12 | use ra_text_edit::{TextEdit, TextEditBuilder}; |
11 | use rustc_hash::FxHashMap; | 13 | use rustc_hash::FxHashMap; |
12 | use std::collections::HashMap; | 14 | use std::collections::HashMap; |
13 | use std::str::FromStr; | 15 | use std::{iter::once, str::FromStr}; |
14 | 16 | ||
15 | #[derive(Debug, PartialEq)] | 17 | #[derive(Debug, PartialEq)] |
16 | pub struct SsrError(String); | 18 | pub struct SsrError(String); |
@@ -219,6 +221,50 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches { | |||
219 | ) | 221 | ) |
220 | } | 222 | } |
221 | 223 | ||
224 | fn check_call_and_method_call( | ||
225 | pattern: CallExpr, | ||
226 | code: MethodCallExpr, | ||
227 | placeholders: &[Var], | ||
228 | match_: Match, | ||
229 | ) -> Option<Match> { | ||
230 | let (pattern_name, pattern_type_args) = if let Some(Expr::PathExpr(path_exr)) = | ||
231 | pattern.expr() | ||
232 | { | ||
233 | let segment = path_exr.path().and_then(|p| p.segment()); | ||
234 | (segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list())) | ||
235 | } else { | ||
236 | (None, None) | ||
237 | }; | ||
238 | let match_ = check_opt_nodes(pattern_name, code.name_ref(), placeholders, match_)?; | ||
239 | let match_ = | ||
240 | check_opt_nodes(pattern_type_args, code.type_arg_list(), placeholders, match_)?; | ||
241 | let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args(); | ||
242 | let code_args = code.syntax().children().find_map(ArgList::cast)?.args(); | ||
243 | let code_args = once(code.expr()?).chain(code_args); | ||
244 | check_iter(pattern_args, code_args, placeholders, match_) | ||
245 | } | ||
246 | |||
247 | fn check_method_call_and_call( | ||
248 | pattern: MethodCallExpr, | ||
249 | code: CallExpr, | ||
250 | placeholders: &[Var], | ||
251 | match_: Match, | ||
252 | ) -> Option<Match> { | ||
253 | let (code_name, code_type_args) = if let Some(Expr::PathExpr(path_exr)) = code.expr() { | ||
254 | let segment = path_exr.path().and_then(|p| p.segment()); | ||
255 | (segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list())) | ||
256 | } else { | ||
257 | (None, None) | ||
258 | }; | ||
259 | let match_ = check_opt_nodes(pattern.name_ref(), code_name, placeholders, match_)?; | ||
260 | let match_ = | ||
261 | check_opt_nodes(pattern.type_arg_list(), code_type_args, placeholders, match_)?; | ||
262 | let code_args = code.syntax().children().find_map(ArgList::cast)?.args(); | ||
263 | let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args(); | ||
264 | let pattern_args = once(pattern.expr()?).chain(pattern_args); | ||
265 | check_iter(pattern_args, code_args, placeholders, match_) | ||
266 | } | ||
267 | |||
222 | fn check_opt_nodes( | 268 | fn check_opt_nodes( |
223 | pattern: Option<impl AstNode>, | 269 | pattern: Option<impl AstNode>, |
224 | code: Option<impl AstNode>, | 270 | code: Option<impl AstNode>, |
@@ -227,8 +273,8 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches { | |||
227 | ) -> Option<Match> { | 273 | ) -> Option<Match> { |
228 | match (pattern, code) { | 274 | match (pattern, code) { |
229 | (Some(pattern), Some(code)) => check( | 275 | (Some(pattern), Some(code)) => check( |
230 | &SyntaxElement::from(pattern.syntax().clone()), | 276 | &pattern.syntax().clone().into(), |
231 | &SyntaxElement::from(code.syntax().clone()), | 277 | &code.syntax().clone().into(), |
232 | placeholders, | 278 | placeholders, |
233 | match_, | 279 | match_, |
234 | ), | 280 | ), |
@@ -237,6 +283,33 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches { | |||
237 | } | 283 | } |
238 | } | 284 | } |
239 | 285 | ||
286 | fn check_iter<T, I1, I2>( | ||
287 | mut pattern: I1, | ||
288 | mut code: I2, | ||
289 | placeholders: &[Var], | ||
290 | match_: Match, | ||
291 | ) -> Option<Match> | ||
292 | where | ||
293 | T: AstNode, | ||
294 | I1: Iterator<Item = T>, | ||
295 | I2: Iterator<Item = T>, | ||
296 | { | ||
297 | pattern | ||
298 | .by_ref() | ||
299 | .zip(code.by_ref()) | ||
300 | .fold(Some(match_), |accum, (a, b)| { | ||
301 | accum.and_then(|match_| { | ||
302 | check( | ||
303 | &a.syntax().clone().into(), | ||
304 | &b.syntax().clone().into(), | ||
305 | placeholders, | ||
306 | match_, | ||
307 | ) | ||
308 | }) | ||
309 | }) | ||
310 | .filter(|_| pattern.next().is_none() && code.next().is_none()) | ||
311 | } | ||
312 | |||
240 | fn check( | 313 | fn check( |
241 | pattern: &SyntaxElement, | 314 | pattern: &SyntaxElement, |
242 | code: &SyntaxElement, | 315 | code: &SyntaxElement, |
@@ -260,6 +333,14 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches { | |||
260 | (RecordLit::cast(pattern.clone()), RecordLit::cast(code.clone())) | 333 | (RecordLit::cast(pattern.clone()), RecordLit::cast(code.clone())) |
261 | { | 334 | { |
262 | check_record_lit(pattern, code, placeholders, match_) | 335 | check_record_lit(pattern, code, placeholders, match_) |
336 | } else if let (Some(pattern), Some(code)) = | ||
337 | (CallExpr::cast(pattern.clone()), MethodCallExpr::cast(code.clone())) | ||
338 | { | ||
339 | check_call_and_method_call(pattern, code, placeholders, match_) | ||
340 | } else if let (Some(pattern), Some(code)) = | ||
341 | (MethodCallExpr::cast(pattern.clone()), CallExpr::cast(code.clone())) | ||
342 | { | ||
343 | check_method_call_and_call(pattern, code, placeholders, match_) | ||
263 | } else { | 344 | } else { |
264 | let mut pattern_children = pattern | 345 | let mut pattern_children = pattern |
265 | .children_with_tokens() | 346 | .children_with_tokens() |
@@ -290,16 +371,15 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches { | |||
290 | let kind = pattern.pattern.kind(); | 371 | let kind = pattern.pattern.kind(); |
291 | let matches = code | 372 | let matches = code |
292 | .descendants() | 373 | .descendants() |
293 | .filter(|n| n.kind() == kind) | 374 | .filter(|n| { |
375 | n.kind() == kind | ||
376 | || (kind == SyntaxKind::CALL_EXPR && n.kind() == SyntaxKind::METHOD_CALL_EXPR) | ||
377 | || (kind == SyntaxKind::METHOD_CALL_EXPR && n.kind() == SyntaxKind::CALL_EXPR) | ||
378 | }) | ||
294 | .filter_map(|code| { | 379 | .filter_map(|code| { |
295 | let match_ = | 380 | let match_ = |
296 | Match { place: code.clone(), binding: HashMap::new(), ignored_comments: vec![] }; | 381 | Match { place: code.clone(), binding: HashMap::new(), ignored_comments: vec![] }; |
297 | check( | 382 | check(&pattern.pattern.clone().into(), &code.into(), &pattern.vars, match_) |
298 | &SyntaxElement::from(pattern.pattern.clone()), | ||
299 | &SyntaxElement::from(code), | ||
300 | &pattern.vars, | ||
301 | match_, | ||
302 | ) | ||
303 | }) | 383 | }) |
304 | .collect(); | 384 | .collect(); |
305 | SsrMatches { matches } | 385 | SsrMatches { matches } |
@@ -498,4 +578,22 @@ mod tests { | |||
498 | "fn main() { foo::new(1, 2) }", | 578 | "fn main() { foo::new(1, 2) }", |
499 | ) | 579 | ) |
500 | } | 580 | } |
581 | |||
582 | #[test] | ||
583 | fn ssr_call_and_method_call() { | ||
584 | assert_ssr_transform( | ||
585 | "foo::<'a>($a:expr, $b:expr)) ==>> foo2($a, $b)", | ||
586 | "fn main() { get().bar.foo::<'a>(1); }", | ||
587 | "fn main() { foo2(get().bar, 1); }", | ||
588 | ) | ||
589 | } | ||
590 | |||
591 | #[test] | ||
592 | fn ssr_method_call_and_call() { | ||
593 | assert_ssr_transform( | ||
594 | "$o:expr.foo::<i32>($a:expr)) ==>> $o.foo2($a)", | ||
595 | "fn main() { X::foo::<i32>(x, 1); }", | ||
596 | "fn main() { x.foo2(1); }", | ||
597 | ) | ||
598 | } | ||
501 | } | 599 | } |