aboutsummaryrefslogtreecommitdiff
path: root/crates/assists
diff options
context:
space:
mode:
Diffstat (limited to 'crates/assists')
-rw-r--r--crates/assists/src/handlers/extract_function.rs819
-rw-r--r--crates/assists/src/lib.rs2
-rw-r--r--crates/assists/src/tests/generated.rs27
3 files changed, 848 insertions, 0 deletions
diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs
new file mode 100644
index 000000000..1a6cfebed
--- /dev/null
+++ b/crates/assists/src/handlers/extract_function.rs
@@ -0,0 +1,819 @@
1use either::Either;
2use hir::{HirDisplay, Local};
3use ide_db::defs::{Definition, NameRefClass};
4use rustc_hash::FxHashSet;
5use stdx::format_to;
6use syntax::{
7 ast::{
8 self,
9 edit::{AstNodeEdit, IndentLevel},
10 AstNode, NameOwner,
11 },
12 Direction, SyntaxElement,
13 SyntaxKind::{self, BLOCK_EXPR, BREAK_EXPR, COMMENT, PATH_EXPR, RETURN_EXPR},
14 SyntaxNode, TextRange,
15};
16use test_utils::mark;
17
18use crate::{
19 assist_context::{AssistContext, Assists},
20 AssistId,
21};
22
23// Assist: extract_function
24//
25// Extracts selected statements into new function.
26//
27// ```
28// fn main() {
29// let n = 1;
30// $0let m = n + 2;
31// let k = m + n;$0
32// let g = 3;
33// }
34// ```
35// ->
36// ```
37// fn main() {
38// let n = 1;
39// fun_name(n);
40// let g = 3;
41// }
42//
43// fn $0fun_name(n: i32) {
44// let m = n + 2;
45// let k = m + n;
46// }
47// ```
48pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
49 if ctx.frange.range.is_empty() {
50 return None;
51 }
52
53 let node = ctx.covering_element();
54 if node.kind() == COMMENT {
55 mark::hit!(extract_function_in_comment_is_not_applicable);
56 return None;
57 }
58
59 let node = match node {
60 syntax::NodeOrToken::Node(n) => n,
61 syntax::NodeOrToken::Token(t) => t.parent(),
62 };
63
64 let mut body = None;
65 if node.text_range() == ctx.frange.range {
66 body = FunctionBody::from_whole_node(node.clone());
67 }
68 if body.is_none() && node.kind() == BLOCK_EXPR {
69 body = FunctionBody::from_range(&node, ctx.frange.range);
70 }
71 if body.is_none() {
72 body = FunctionBody::from_whole_node(node.clone());
73 }
74 if body.is_none() {
75 body = node.ancestors().find_map(FunctionBody::from_whole_node);
76 }
77 let body = body?;
78
79 let insert_after = body.scope_for_fn_insertion()?;
80
81 let module = ctx.sema.scope(&insert_after).module()?;
82
83 let expr = body.tail_expr();
84 let ret_ty = match expr {
85 Some(expr) => {
86 // TODO: can we do assist when type is unknown?
87 // We can insert something like `-> ()`
88 let ty = ctx.sema.type_of_expr(&expr)?;
89 Some(ty.display_source_code(ctx.db(), module.into()).ok()?)
90 }
91 None => None,
92 };
93
94 let target_range = match &body {
95 FunctionBody::Expr(expr) => expr.syntax().text_range(),
96 FunctionBody::Span { .. } => ctx.frange.range,
97 };
98
99 let mut params = local_variables(&body, &ctx)
100 .into_iter()
101 .map(|node| node.source(ctx.db()))
102 .filter(|src| src.file_id.original_file(ctx.db()) == ctx.frange.file_id)
103 .map(|src| match src.value {
104 Either::Left(pat) => {
105 (pat.syntax().clone(), pat.name(), ctx.sema.type_of_pat(&pat.into()))
106 }
107 Either::Right(it) => (it.syntax().clone(), it.name(), ctx.sema.type_of_self(&it)),
108 })
109 .filter(|(node, _, _)| !body.contains_node(node))
110 .map(|(_, name, ty)| {
111 let ty = ty
112 .and_then(|ty| ty.display_source_code(ctx.db(), module.into()).ok())
113 .unwrap_or_else(|| "()".to_string());
114
115 let name = name.unwrap().to_string();
116
117 Param { name, ty }
118 })
119 .collect::<Vec<_>>();
120 deduplicate_params(&mut params);
121
122 acc.add(
123 AssistId("extract_function", crate::AssistKind::RefactorExtract),
124 "Extract into function",
125 target_range,
126 move |builder| {
127
128 let fun = Function { name: "fun_name".to_string(), params, ret_ty, body };
129
130 builder.replace(target_range, format_replacement(&fun));
131
132 let indent = IndentLevel::from_node(&insert_after);
133
134 let fn_def = format_function(&fun, indent);
135 let insert_offset = insert_after.text_range().end();
136 builder.insert(insert_offset, fn_def);
137 },
138 )
139}
140
141fn format_replacement(fun: &Function) -> String {
142 let mut buf = String::new();
143 format_to!(buf, "{}(", fun.name);
144 {
145 let mut it = fun.params.iter();
146 if let Some(param) = it.next() {
147 format_to!(buf, "{}", param.name);
148 }
149 for param in it {
150 format_to!(buf, ", {}", param.name);
151 }
152 }
153 format_to!(buf, ")");
154
155 if fun.has_unit_ret() {
156 format_to!(buf, ";");
157 }
158
159 buf
160}
161
162struct Function {
163 name: String,
164 params: Vec<Param>,
165 ret_ty: Option<String>,
166 body: FunctionBody,
167}
168
169impl Function {
170 fn has_unit_ret(&self) -> bool {
171 match &self.ret_ty {
172 Some(ty) => ty == "()",
173 None => true,
174 }
175 }
176}
177
178#[derive(Debug)]
179struct Param {
180 name: String,
181 ty: String,
182}
183
184fn format_function(fun: &Function, indent: IndentLevel) -> String {
185 let mut fn_def = String::new();
186 format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name);
187 {
188 let mut it = fun.params.iter();
189 if let Some(param) = it.next() {
190 format_to!(fn_def, "{}: {}", param.name, param.ty);
191 }
192 for param in it {
193 format_to!(fn_def, ", {}: {}", param.name, param.ty);
194 }
195 }
196
197 format_to!(fn_def, ")");
198 if !fun.has_unit_ret() {
199 if let Some(ty) = &fun.ret_ty {
200 format_to!(fn_def, " -> {}", ty);
201 }
202 }
203 format_to!(fn_def, " {{");
204
205 match &fun.body {
206 FunctionBody::Expr(expr) => {
207 fn_def.push('\n');
208 let expr = expr.indent(indent);
209 format_to!(fn_def, "{}{}", indent + 1, expr.syntax());
210 fn_def.push('\n');
211 }
212 FunctionBody::Span { elements, leading_indent } => {
213 format_to!(fn_def, "{}", leading_indent);
214 for e in elements {
215 format_to!(fn_def, "{}", e);
216 }
217 if !fn_def.ends_with('\n') {
218 fn_def.push('\n');
219 }
220 }
221 }
222 format_to!(fn_def, "{}}}", indent);
223
224 fn_def
225}
226
227#[derive(Debug)]
228enum FunctionBody {
229 Expr(ast::Expr),
230 Span { elements: Vec<SyntaxElement>, leading_indent: String },
231}
232
233impl FunctionBody {
234 fn from_whole_node(node: SyntaxNode) -> Option<Self> {
235 match node.kind() {
236 PATH_EXPR => None,
237 BREAK_EXPR => ast::BreakExpr::cast(node).and_then(|e| e.expr()).map(Self::Expr),
238 RETURN_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()).map(Self::Expr),
239 BLOCK_EXPR => ast::BlockExpr::cast(node)
240 .filter(|it| it.is_standalone())
241 .map(Into::into)
242 .map(Self::Expr),
243 _ => ast::Expr::cast(node).map(Self::Expr),
244 }
245 }
246
247 fn from_range(node: &SyntaxNode, range: TextRange) -> Option<FunctionBody> {
248 let mut first = node.token_at_offset(range.start()).left_biased()?;
249 let last = node.token_at_offset(range.end()).right_biased()?;
250
251 let mut leading_indent = String::new();
252
253 let leading_trivia = first
254 .siblings_with_tokens(Direction::Prev)
255 .skip(1)
256 .take_while(|e| e.kind() == SyntaxKind::WHITESPACE && e.as_token().is_some());
257
258 for e in leading_trivia {
259 let token = e.as_token().unwrap();
260 let text = token.text();
261 match text.rfind('\n') {
262 Some(pos) => {
263 leading_indent = text[pos..].to_owned();
264 break;
265 }
266 None => first = token.clone(),
267 }
268 }
269
270 let mut elements: Vec<_> = first
271 .siblings_with_tokens(Direction::Next)
272 .take_while(|e| e.as_token() != Some(&last))
273 .collect();
274
275 if !(last.kind() == SyntaxKind::WHITESPACE && last.text().lines().count() <= 2) {
276 elements.push(last.into());
277 }
278
279 Some(FunctionBody::Span { elements, leading_indent })
280 }
281
282 fn tail_expr(&self) -> Option<ast::Expr> {
283 match &self {
284 FunctionBody::Expr(expr) => Some(expr.clone()),
285 FunctionBody::Span { elements, .. } => {
286 elements.iter().rev().find_map(|e| e.as_node()).cloned().and_then(ast::Expr::cast)
287 }
288 }
289 }
290
291 fn scope_for_fn_insertion(&self) -> Option<SyntaxNode> {
292 match self {
293 FunctionBody::Expr(e) => scope_for_fn_insertion(e.syntax()),
294 FunctionBody::Span { elements, .. } => {
295 let node = elements.iter().find_map(|e| e.as_node())?;
296 scope_for_fn_insertion(&node)
297 }
298 }
299 }
300
301 fn descendants(&self) -> impl Iterator<Item = SyntaxNode> + '_ {
302 match self {
303 FunctionBody::Expr(expr) => Either::Right(expr.syntax().descendants()),
304 FunctionBody::Span { elements, .. } => Either::Left(
305 elements
306 .iter()
307 .filter_map(SyntaxElement::as_node)
308 .flat_map(SyntaxNode::descendants),
309 ),
310 }
311 }
312
313 fn contains_node(&self, node: &SyntaxNode) -> bool {
314 fn is_node(body: &FunctionBody, n: &SyntaxNode) -> bool {
315 match body {
316 FunctionBody::Expr(expr) => n == expr.syntax(),
317 FunctionBody::Span { elements, .. } => {
318 // FIXME: can it be quadratic?
319 elements.iter().filter_map(SyntaxElement::as_node).any(|e| e == n)
320 }
321 }
322 }
323
324 node.ancestors().any(|a| is_node(self, &a))
325 }
326}
327
328fn scope_for_fn_insertion(node: &SyntaxNode) -> Option<SyntaxNode> {
329 let mut ancestors = node.ancestors().peekable();
330 let mut last_ancestor = None;
331 while let Some(next_ancestor) = ancestors.next() {
332 match next_ancestor.kind() {
333 SyntaxKind::SOURCE_FILE => break,
334 SyntaxKind::ITEM_LIST => {
335 if ancestors.peek().map(|a| a.kind()) == Some(SyntaxKind::MODULE) {
336 break;
337 }
338 }
339 _ => {}
340 }
341 last_ancestor = Some(next_ancestor);
342 }
343 last_ancestor
344}
345
346fn deduplicate_params(params: &mut Vec<Param>) {
347 let mut seen_params = FxHashSet::default();
348 params.retain(|p| seen_params.insert(p.name.clone()));
349}
350
351/// Returns a vector of local variables that are refferenced in `body`
352fn local_variables(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> {
353 body
354 .descendants()
355 .filter_map(ast::NameRef::cast)
356 .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref))
357 .map(|name_kind| name_kind.referenced(ctx.db()))
358 .filter_map(|definition| match definition {
359 Definition::Local(local) => Some(local),
360 _ => None,
361 })
362 .collect()
363}
364
365#[cfg(test)]
366mod tests {
367 use crate::tests::{check_assist, check_assist_not_applicable};
368
369 use super::*;
370
371 #[test]
372 fn no_args_from_binary_expr() {
373 check_assist(
374 extract_function,
375 r#"
376fn foo() {
377 foo($01 + 1$0);
378}"#,
379 r#"
380fn foo() {
381 foo(fun_name());
382}
383
384fn $0fun_name() -> i32 {
385 1 + 1
386}"#,
387 );
388 }
389
390 #[test]
391 fn no_args_from_binary_expr_in_module() {
392 check_assist(
393 extract_function,
394 r#"
395mod bar {
396 fn foo() {
397 foo($01 + 1$0);
398 }
399}"#,
400 r#"
401mod bar {
402 fn foo() {
403 foo(fun_name());
404 }
405
406 fn $0fun_name() -> i32 {
407 1 + 1
408 }
409}"#,
410 );
411 }
412
413 #[test]
414 fn no_args_from_binary_expr_indented() {
415 check_assist(
416 extract_function,
417 r#"
418fn foo() {
419 $0{ 1 + 1 }$0;
420}"#,
421 r#"
422fn foo() {
423 fun_name();
424}
425
426fn $0fun_name() -> i32 {
427 { 1 + 1 }
428}"#,
429 );
430 }
431
432 #[test]
433 fn no_args_from_stmt_with_last_expr() {
434 check_assist(
435 extract_function,
436 r#"
437fn foo() -> i32 {
438 let k = 1;
439 $0let m = 1;
440 m + 1$0
441}"#,
442 r#"
443fn foo() -> i32 {
444 let k = 1;
445 fun_name()
446}
447
448fn $0fun_name() -> i32 {
449 let m = 1;
450 m + 1
451}"#,
452 );
453 }
454
455 #[test]
456 fn no_args_from_stmt_unit() {
457 check_assist(
458 extract_function,
459 r#"
460fn foo() {
461 let k = 3;
462 $0let m = 1;
463 let n = m + 1;$0
464 let g = 5;
465}"#,
466 r#"
467fn foo() {
468 let k = 3;
469 fun_name();
470 let g = 5;
471}
472
473fn $0fun_name() {
474 let m = 1;
475 let n = m + 1;
476}"#,
477 );
478 }
479
480 #[test]
481 fn no_args_from_loop_unit() {
482 check_assist(
483 extract_function,
484 r#"
485fn foo() {
486 $0loop {
487 let m = 1;
488 }$0
489}"#,
490 r#"
491fn foo() {
492 fun_name()
493}
494
495fn $0fun_name() -> ! {
496 loop {
497 let m = 1;
498 }
499}"#,
500 );
501 }
502
503 #[test]
504 fn no_args_from_loop_with_return() {
505 check_assist(
506 extract_function,
507 r#"
508fn foo() {
509 let v = $0loop {
510 let m = 1;
511 break m;
512 }$0;
513}"#,
514 r#"
515fn foo() {
516 let v = fun_name();
517}
518
519fn $0fun_name() -> i32 {
520 loop {
521 let m = 1;
522 break m;
523 }
524}"#,
525 );
526 }
527
528 #[test]
529 fn no_args_from_match() {
530 check_assist(
531 extract_function,
532 r#"
533fn foo() {
534 let v: i32 = $0match Some(1) {
535 Some(x) => x,
536 None => 0,
537 }$0;
538}"#,
539 r#"
540fn foo() {
541 let v: i32 = fun_name();
542}
543
544fn $0fun_name() -> i32 {
545 match Some(1) {
546 Some(x) => x,
547 None => 0,
548 }
549}"#,
550 );
551 }
552
553 #[test]
554 fn argument_form_expr() {
555 check_assist(
556 extract_function,
557 r"
558fn foo() -> u32 {
559 let n = 2;
560 $0n+2$0
561}",
562 r"
563fn foo() -> u32 {
564 let n = 2;
565 fun_name(n)
566}
567
568fn $0fun_name(n: u32) -> u32 {
569 n+2
570}",
571 )
572 }
573
574 #[test]
575 fn argument_used_twice_form_expr() {
576 check_assist(
577 extract_function,
578 r"
579fn foo() -> u32 {
580 let n = 2;
581 $0n+n$0
582}",
583 r"
584fn foo() -> u32 {
585 let n = 2;
586 fun_name(n)
587}
588
589fn $0fun_name(n: u32) -> u32 {
590 n+n
591}",
592 )
593 }
594
595 #[test]
596 fn two_arguments_form_expr() {
597 check_assist(
598 extract_function,
599 r"
600fn foo() -> u32 {
601 let n = 2;
602 let m = 3;
603 $0n+n*m$0
604}",
605 r"
606fn foo() -> u32 {
607 let n = 2;
608 let m = 3;
609 fun_name(n, m)
610}
611
612fn $0fun_name(n: u32, m: u32) -> u32 {
613 n+n*m
614}",
615 )
616 }
617
618 #[test]
619 fn argument_and_locals() {
620 check_assist(
621 extract_function,
622 r"
623fn foo() -> u32 {
624 let n = 2;
625 $0let m = 1;
626 n + m$0
627}",
628 r"
629fn foo() -> u32 {
630 let n = 2;
631 fun_name(n)
632}
633
634fn $0fun_name(n: u32) -> u32 {
635 let m = 1;
636 n + m
637}",
638 )
639 }
640
641 #[test]
642 fn in_comment_is_not_applicable() {
643 mark::check!(extract_function_in_comment_is_not_applicable);
644 check_assist_not_applicable(extract_function, r"fn main() { 1 + /* $0comment$0 */ 1; }");
645 }
646
647 #[test]
648 fn part_of_expr_stmt() {
649 check_assist(
650 extract_function,
651 "
652fn foo() {
653 $01$0 + 1;
654}",
655 "
656fn foo() {
657 fun_name() + 1;
658}
659
660fn $0fun_name() -> i32 {
661 1
662}",
663 );
664 }
665
666 #[test]
667 fn function_expr() {
668 check_assist(
669 extract_function,
670 r#"
671fn foo() {
672 $0bar(1 + 1)$0
673}"#,
674 r#"
675fn foo() {
676 fun_name();
677}
678
679fn $0fun_name() {
680 bar(1 + 1)
681}"#,
682 )
683 }
684
685 #[test]
686 fn extract_from_nested() {
687 check_assist(
688 extract_function,
689 r"
690fn main() {
691 let x = true;
692 let tuple = match x {
693 true => ($02 + 2$0, true)
694 _ => (0, false)
695 };
696}",
697 r"
698fn main() {
699 let x = true;
700 let tuple = match x {
701 true => (fun_name(), true)
702 _ => (0, false)
703 };
704}
705
706fn $0fun_name() -> i32 {
707 2 + 2
708}",
709 );
710 }
711
712 #[test]
713 fn param_from_closure() {
714 check_assist(
715 extract_function,
716 r"
717fn main() {
718 let lambda = |x: u32| $0x * 2$0;
719}",
720 r"
721fn main() {
722 let lambda = |x: u32| fun_name(x);
723}
724
725fn $0fun_name(x: u32) -> u32 {
726 x * 2
727}",
728 );
729 }
730
731 #[test]
732 fn extract_return_stmt() {
733 check_assist(
734 extract_function,
735 r"
736fn foo() -> u32 {
737 $0return 2 + 2$0;
738}",
739 r"
740fn foo() -> u32 {
741 return fun_name();
742}
743
744fn $0fun_name() -> u32 {
745 2 + 2
746}",
747 );
748 }
749
750 #[test]
751 fn does_not_add_extra_whitespace() {
752 check_assist(
753 extract_function,
754 r"
755fn foo() -> u32 {
756
757
758 $0return 2 + 2$0;
759}",
760 r"
761fn foo() -> u32 {
762
763
764 return fun_name();
765}
766
767fn $0fun_name() -> u32 {
768 2 + 2
769}",
770 );
771 }
772
773 #[test]
774 fn break_stmt() {
775 check_assist(
776 extract_function,
777 r"
778fn main() {
779 let result = loop {
780 $0break 2 + 2$0;
781 };
782}",
783 r"
784fn main() {
785 let result = loop {
786 break fun_name();
787 };
788}
789
790fn $0fun_name() -> i32 {
791 2 + 2
792}",
793 );
794 }
795
796 #[test]
797 fn extract_cast() {
798 check_assist(
799 extract_function,
800 r"
801fn main() {
802 let v = $00f32 as u32$0;
803}",
804 r"
805fn main() {
806 let v = fun_name();
807}
808
809fn $0fun_name() -> u32 {
810 0f32 as u32
811}",
812 );
813 }
814
815 #[test]
816 fn return_not_applicable() {
817 check_assist_not_applicable(extract_function, r"fn foo() { $0return$0; } ");
818 }
819}
diff --git a/crates/assists/src/lib.rs b/crates/assists/src/lib.rs
index 559b9651e..062a902ab 100644
--- a/crates/assists/src/lib.rs
+++ b/crates/assists/src/lib.rs
@@ -117,6 +117,7 @@ mod handlers {
117 mod convert_integer_literal; 117 mod convert_integer_literal;
118 mod early_return; 118 mod early_return;
119 mod expand_glob_import; 119 mod expand_glob_import;
120 mod extract_function;
120 mod extract_struct_from_enum_variant; 121 mod extract_struct_from_enum_variant;
121 mod extract_variable; 122 mod extract_variable;
122 mod fill_match_arms; 123 mod fill_match_arms;
@@ -174,6 +175,7 @@ mod handlers {
174 early_return::convert_to_guarded_return, 175 early_return::convert_to_guarded_return,
175 expand_glob_import::expand_glob_import, 176 expand_glob_import::expand_glob_import,
176 move_module_to_file::move_module_to_file, 177 move_module_to_file::move_module_to_file,
178 extract_function::extract_function,
177 extract_struct_from_enum_variant::extract_struct_from_enum_variant, 179 extract_struct_from_enum_variant::extract_struct_from_enum_variant,
178 extract_variable::extract_variable, 180 extract_variable::extract_variable,
179 fill_match_arms::fill_match_arms, 181 fill_match_arms::fill_match_arms,
diff --git a/crates/assists/src/tests/generated.rs b/crates/assists/src/tests/generated.rs
index 9aa807f10..e84f208a3 100644
--- a/crates/assists/src/tests/generated.rs
+++ b/crates/assists/src/tests/generated.rs
@@ -257,6 +257,33 @@ fn qux(bar: Bar, baz: Baz) {}
257} 257}
258 258
259#[test] 259#[test]
260fn doctest_extract_function() {
261 check_doc_test(
262 "extract_function",
263 r#####"
264fn main() {
265 let n = 1;
266 $0let m = n + 2;
267 let k = m + n;$0
268 let g = 3;
269}
270"#####,
271 r#####"
272fn main() {
273 let n = 1;
274 fun_name(n);
275 let g = 3;
276}
277
278fn $0fun_name(n: i32) {
279 let m = n + 2;
280 let k = m + n;
281}
282"#####,
283 )
284}
285
286#[test]
260fn doctest_extract_struct_from_enum_variant() { 287fn doctest_extract_struct_from_enum_variant() {
261 check_doc_test( 288 check_doc_test(
262 "extract_struct_from_enum_variant", 289 "extract_struct_from_enum_variant",