aboutsummaryrefslogtreecommitdiff
path: root/crates
diff options
context:
space:
mode:
Diffstat (limited to 'crates')
-rw-r--r--crates/ra_ide/src/ssr.rs122
1 files changed, 110 insertions, 12 deletions
diff --git a/crates/ra_ide/src/ssr.rs b/crates/ra_ide/src/ssr.rs
index 1abb891c1..7b93ff2d2 100644
--- a/crates/ra_ide/src/ssr.rs
+++ b/crates/ra_ide/src/ssr.rs
@@ -5,12 +5,14 @@ use ra_db::{SourceDatabase, SourceDatabaseExt};
5use ra_ide_db::symbol_index::SymbolsDatabase; 5use ra_ide_db::symbol_index::SymbolsDatabase;
6use ra_ide_db::RootDatabase; 6use ra_ide_db::RootDatabase;
7use ra_syntax::ast::make::try_expr_from_text; 7use ra_syntax::ast::make::try_expr_from_text;
8use ra_syntax::ast::{AstToken, Comment, RecordField, RecordLit}; 8use ra_syntax::ast::{
9use ra_syntax::{AstNode, SyntaxElement, SyntaxNode}; 9 ArgList, AstToken, CallExpr, Comment, Expr, MethodCallExpr, RecordField, RecordLit,
10};
11use ra_syntax::{AstNode, SyntaxElement, SyntaxKind, SyntaxNode};
10use ra_text_edit::{TextEdit, TextEditBuilder}; 12use ra_text_edit::{TextEdit, TextEditBuilder};
11use rustc_hash::FxHashMap; 13use rustc_hash::FxHashMap;
12use std::collections::HashMap; 14use std::collections::HashMap;
13use std::str::FromStr; 15use std::{iter::once, str::FromStr};
14 16
15#[derive(Debug, PartialEq)] 17#[derive(Debug, PartialEq)]
16pub struct SsrError(String); 18pub struct SsrError(String);
@@ -219,6 +221,50 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
219 ) 221 )
220 } 222 }
221 223
224 fn check_call_and_method_call(
225 pattern: CallExpr,
226 code: MethodCallExpr,
227 placeholders: &[Var],
228 match_: Match,
229 ) -> Option<Match> {
230 let (pattern_name, pattern_type_args) = if let Some(Expr::PathExpr(path_exr)) =
231 pattern.expr()
232 {
233 let segment = path_exr.path().and_then(|p| p.segment());
234 (segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list()))
235 } else {
236 (None, None)
237 };
238 let match_ = check_opt_nodes(pattern_name, code.name_ref(), placeholders, match_)?;
239 let match_ =
240 check_opt_nodes(pattern_type_args, code.type_arg_list(), placeholders, match_)?;
241 let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args();
242 let code_args = code.syntax().children().find_map(ArgList::cast)?.args();
243 let code_args = once(code.expr()?).chain(code_args);
244 check_iter(pattern_args, code_args, placeholders, match_)
245 }
246
247 fn check_method_call_and_call(
248 pattern: MethodCallExpr,
249 code: CallExpr,
250 placeholders: &[Var],
251 match_: Match,
252 ) -> Option<Match> {
253 let (code_name, code_type_args) = if let Some(Expr::PathExpr(path_exr)) = code.expr() {
254 let segment = path_exr.path().and_then(|p| p.segment());
255 (segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list()))
256 } else {
257 (None, None)
258 };
259 let match_ = check_opt_nodes(pattern.name_ref(), code_name, placeholders, match_)?;
260 let match_ =
261 check_opt_nodes(pattern.type_arg_list(), code_type_args, placeholders, match_)?;
262 let code_args = code.syntax().children().find_map(ArgList::cast)?.args();
263 let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args();
264 let pattern_args = once(pattern.expr()?).chain(pattern_args);
265 check_iter(pattern_args, code_args, placeholders, match_)
266 }
267
222 fn check_opt_nodes( 268 fn check_opt_nodes(
223 pattern: Option<impl AstNode>, 269 pattern: Option<impl AstNode>,
224 code: Option<impl AstNode>, 270 code: Option<impl AstNode>,
@@ -227,8 +273,8 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
227 ) -> Option<Match> { 273 ) -> Option<Match> {
228 match (pattern, code) { 274 match (pattern, code) {
229 (Some(pattern), Some(code)) => check( 275 (Some(pattern), Some(code)) => check(
230 &SyntaxElement::from(pattern.syntax().clone()), 276 &pattern.syntax().clone().into(),
231 &SyntaxElement::from(code.syntax().clone()), 277 &code.syntax().clone().into(),
232 placeholders, 278 placeholders,
233 match_, 279 match_,
234 ), 280 ),
@@ -237,6 +283,33 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
237 } 283 }
238 } 284 }
239 285
286 fn check_iter<T, I1, I2>(
287 mut pattern: I1,
288 mut code: I2,
289 placeholders: &[Var],
290 match_: Match,
291 ) -> Option<Match>
292 where
293 T: AstNode,
294 I1: Iterator<Item = T>,
295 I2: Iterator<Item = T>,
296 {
297 pattern
298 .by_ref()
299 .zip(code.by_ref())
300 .fold(Some(match_), |accum, (a, b)| {
301 accum.and_then(|match_| {
302 check(
303 &a.syntax().clone().into(),
304 &b.syntax().clone().into(),
305 placeholders,
306 match_,
307 )
308 })
309 })
310 .filter(|_| pattern.next().is_none() && code.next().is_none())
311 }
312
240 fn check( 313 fn check(
241 pattern: &SyntaxElement, 314 pattern: &SyntaxElement,
242 code: &SyntaxElement, 315 code: &SyntaxElement,
@@ -260,6 +333,14 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
260 (RecordLit::cast(pattern.clone()), RecordLit::cast(code.clone())) 333 (RecordLit::cast(pattern.clone()), RecordLit::cast(code.clone()))
261 { 334 {
262 check_record_lit(pattern, code, placeholders, match_) 335 check_record_lit(pattern, code, placeholders, match_)
336 } else if let (Some(pattern), Some(code)) =
337 (CallExpr::cast(pattern.clone()), MethodCallExpr::cast(code.clone()))
338 {
339 check_call_and_method_call(pattern, code, placeholders, match_)
340 } else if let (Some(pattern), Some(code)) =
341 (MethodCallExpr::cast(pattern.clone()), CallExpr::cast(code.clone()))
342 {
343 check_method_call_and_call(pattern, code, placeholders, match_)
263 } else { 344 } else {
264 let mut pattern_children = pattern 345 let mut pattern_children = pattern
265 .children_with_tokens() 346 .children_with_tokens()
@@ -290,16 +371,15 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
290 let kind = pattern.pattern.kind(); 371 let kind = pattern.pattern.kind();
291 let matches = code 372 let matches = code
292 .descendants() 373 .descendants()
293 .filter(|n| n.kind() == kind) 374 .filter(|n| {
375 n.kind() == kind
376 || (kind == SyntaxKind::CALL_EXPR && n.kind() == SyntaxKind::METHOD_CALL_EXPR)
377 || (kind == SyntaxKind::METHOD_CALL_EXPR && n.kind() == SyntaxKind::CALL_EXPR)
378 })
294 .filter_map(|code| { 379 .filter_map(|code| {
295 let match_ = 380 let match_ =
296 Match { place: code.clone(), binding: HashMap::new(), ignored_comments: vec![] }; 381 Match { place: code.clone(), binding: HashMap::new(), ignored_comments: vec![] };
297 check( 382 check(&pattern.pattern.clone().into(), &code.into(), &pattern.vars, match_)
298 &SyntaxElement::from(pattern.pattern.clone()),
299 &SyntaxElement::from(code),
300 &pattern.vars,
301 match_,
302 )
303 }) 383 })
304 .collect(); 384 .collect();
305 SsrMatches { matches } 385 SsrMatches { matches }
@@ -498,4 +578,22 @@ mod tests {
498 "fn main() { foo::new(1, 2) }", 578 "fn main() { foo::new(1, 2) }",
499 ) 579 )
500 } 580 }
581
582 #[test]
583 fn ssr_call_and_method_call() {
584 assert_ssr_transform(
585 "foo::<'a>($a:expr, $b:expr)) ==>> foo2($a, $b)",
586 "fn main() { get().bar.foo::<'a>(1); }",
587 "fn main() { foo2(get().bar, 1); }",
588 )
589 }
590
591 #[test]
592 fn ssr_method_call_and_call() {
593 assert_ssr_transform(
594 "$o:expr.foo::<i32>($a:expr)) ==>> $o.foo2($a)",
595 "fn main() { X::foo::<i32>(x, 1); }",
596 "fn main() { x.foo2(1); }",
597 )
598 }
501} 599}