aboutsummaryrefslogtreecommitdiff
path: root/crates/assists/src/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'crates/assists/src/handlers')
-rw-r--r--crates/assists/src/handlers/extract_function.rs224
1 files changed, 183 insertions, 41 deletions
diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs
index 958199e5e..c5e6ec733 100644
--- a/crates/assists/src/handlers/extract_function.rs
+++ b/crates/assists/src/handlers/extract_function.rs
@@ -1,7 +1,10 @@
1use either::Either; 1use either::Either;
2use hir::{HirDisplay, Local}; 2use hir::{HirDisplay, Local};
3use ide_db::defs::{Definition, NameRefClass}; 3use ide_db::{
4use rustc_hash::FxHashSet; 4 defs::{Definition, NameRefClass},
5 search::SearchScope,
6};
7use itertools::Itertools;
5use stdx::format_to; 8use stdx::format_to;
6use syntax::{ 9use syntax::{
7 ast::{ 10 ast::{
@@ -81,9 +84,10 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
81 } 84 }
82 let body = body?; 85 let body = body?;
83 86
87 let vars_used_in_body = vars_used_in_body(&body, &ctx);
84 let mut self_param = None; 88 let mut self_param = None;
85 let mut param_pats: Vec<_> = local_variables(&body, &ctx) 89 let param_pats: Vec<_> = vars_used_in_body
86 .into_iter() 90 .iter()
87 .map(|node| node.source(ctx.db())) 91 .map(|node| node.source(ctx.db()))
88 .filter(|src| { 92 .filter(|src| {
89 src.file_id.original_file(ctx.db()) == ctx.frange.file_id 93 src.file_id.original_file(ctx.db()) == ctx.frange.file_id
@@ -98,12 +102,27 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
98 } 102 }
99 }) 103 })
100 .collect(); 104 .collect();
101 deduplicate_params(&mut param_pats);
102 105
103 let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding }; 106 let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding };
104 let insert_after = body.scope_for_fn_insertion(anchor)?; 107 let insert_after = body.scope_for_fn_insertion(anchor)?;
105 let module = ctx.sema.scope(&insert_after).module()?; 108 let module = ctx.sema.scope(&insert_after).module()?;
106 109
110 let vars_defined_in_body = vars_defined_in_body(&body, ctx);
111
112 let vars_in_body_used_afterwards: Vec<_> = vars_defined_in_body
113 .iter()
114 .copied()
115 .filter(|node| {
116 let usages = Definition::Local(*node)
117 .usages(&ctx.sema)
118 .in_scope(SearchScope::single_file(ctx.frange.file_id))
119 .all();
120 let mut usages = usages.iter().flat_map(|(_, rs)| rs.iter());
121
122 usages.any(|reference| body.preceedes_range(reference.range))
123 })
124 .collect();
125
107 let params = param_pats 126 let params = param_pats
108 .into_iter() 127 .into_iter()
109 .map(|pat| { 128 .map(|pat| {
@@ -119,20 +138,18 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
119 }) 138 })
120 .collect::<Vec<_>>(); 139 .collect::<Vec<_>>();
121 140
122 let self_param =
123 if let Some(self_param) = self_param { Some(self_param.to_string()) } else { None };
124
125 let expr = body.tail_expr(); 141 let expr = body.tail_expr();
126 let ret_ty = match expr { 142 let ret_ty = match expr {
127 Some(expr) => { 143 Some(expr) => Some(ctx.sema.type_of_expr(&expr)?),
128 // FIXME: can we do assist when type is unknown?
129 // We can insert something like `-> ()`
130 let ty = ctx.sema.type_of_expr(&expr)?;
131 Some(ty.display_source_code(ctx.db(), module.into()).ok()?)
132 }
133 None => None, 144 None => None,
134 }; 145 };
135 146
147 let has_unit_ret = ret_ty.as_ref().map_or(true, |it| it.is_unit());
148 if stdx::never!(!vars_in_body_used_afterwards.is_empty() && !has_unit_ret) {
149 // We should not have variables that outlive body if we have expression block
150 return None;
151 }
152
136 let target_range = match &body { 153 let target_range = match &body {
137 FunctionBody::Expr(expr) => expr.syntax().text_range(), 154 FunctionBody::Expr(expr) => expr.syntax().text_range(),
138 FunctionBody::Span { .. } => ctx.frange.range, 155 FunctionBody::Span { .. } => ctx.frange.range,
@@ -143,21 +160,46 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
143 "Extract into function", 160 "Extract into function",
144 target_range, 161 target_range,
145 move |builder| { 162 move |builder| {
146 let fun = Function { name: "fun_name".to_string(), self_param, params, ret_ty, body }; 163 let fun = Function {
164 name: "fun_name".to_string(),
165 self_param,
166 params,
167 ret_ty,
168 body,
169 vars_in_body_used_afterwards,
170 };
147 171
148 builder.replace(target_range, format_replacement(&fun)); 172 builder.replace(target_range, format_replacement(ctx, &fun));
149 173
150 let indent = IndentLevel::from_node(&insert_after); 174 let indent = IndentLevel::from_node(&insert_after);
151 175
152 let fn_def = format_function(&fun, indent); 176 let fn_def = format_function(ctx, module, &fun, indent);
153 let insert_offset = insert_after.text_range().end(); 177 let insert_offset = insert_after.text_range().end();
154 builder.insert(insert_offset, fn_def); 178 builder.insert(insert_offset, fn_def);
155 }, 179 },
156 ) 180 )
157} 181}
158 182
159fn format_replacement(fun: &Function) -> String { 183fn format_replacement(ctx: &AssistContext, fun: &Function) -> String {
160 let mut buf = String::new(); 184 let mut buf = String::new();
185
186 match fun.vars_in_body_used_afterwards.len() {
187 0 => {}
188 1 => format_to!(
189 buf,
190 "let {} = ",
191 fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap()
192 ),
193 _ => {
194 buf.push_str("let (");
195 format_to!(buf, "{}", fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap());
196 for local in fun.vars_in_body_used_afterwards.iter().skip(1) {
197 format_to!(buf, ", {}", local.name(ctx.db()).unwrap());
198 }
199 buf.push_str(") = ");
200 }
201 }
202
161 if fun.self_param.is_some() { 203 if fun.self_param.is_some() {
162 format_to!(buf, "self."); 204 format_to!(buf, "self.");
163 } 205 }
@@ -182,16 +224,17 @@ fn format_replacement(fun: &Function) -> String {
182 224
183struct Function { 225struct Function {
184 name: String, 226 name: String,
185 self_param: Option<String>, 227 self_param: Option<ast::SelfParam>,
186 params: Vec<Param>, 228 params: Vec<Param>,
187 ret_ty: Option<String>, 229 ret_ty: Option<hir::Type>,
188 body: FunctionBody, 230 body: FunctionBody,
231 vars_in_body_used_afterwards: Vec<Local>,
189} 232}
190 233
191impl Function { 234impl Function {
192 fn has_unit_ret(&self) -> bool { 235 fn has_unit_ret(&self) -> bool {
193 match &self.ret_ty { 236 match &self.ret_ty {
194 Some(ty) => ty == "()", 237 Some(ty) => ty.is_unit(),
195 None => true, 238 None => true,
196 } 239 }
197 } 240 }
@@ -203,7 +246,12 @@ struct Param {
203 ty: String, 246 ty: String,
204} 247}
205 248
206fn format_function(fun: &Function, indent: IndentLevel) -> String { 249fn format_function(
250 ctx: &AssistContext,
251 module: hir::Module,
252 fun: &Function,
253 indent: IndentLevel,
254) -> String {
207 let mut fn_def = String::new(); 255 let mut fn_def = String::new();
208 format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name); 256 format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name);
209 { 257 {
@@ -221,10 +269,24 @@ fn format_function(fun: &Function, indent: IndentLevel) -> String {
221 format_to!(fn_def, ")"); 269 format_to!(fn_def, ")");
222 if !fun.has_unit_ret() { 270 if !fun.has_unit_ret() {
223 if let Some(ty) = &fun.ret_ty { 271 if let Some(ty) = &fun.ret_ty {
224 format_to!(fn_def, " -> {}", ty); 272 format_to!(fn_def, " -> {}", format_type(ty, ctx, module));
273 }
274 } else {
275 match fun.vars_in_body_used_afterwards.as_slice() {
276 [] => {}
277 [var] => {
278 format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module));
279 }
280 [v0, vs @ ..] => {
281 format_to!(fn_def, " -> ({}", format_type(&v0.ty(ctx.db()), ctx, module));
282 for var in vs {
283 format_to!(fn_def, ", {}", format_type(&var.ty(ctx.db()), ctx, module));
284 }
285 fn_def.push(')');
286 }
225 } 287 }
226 } 288 }
227 format_to!(fn_def, " {{"); 289 fn_def.push_str(" {");
228 290
229 match &fun.body { 291 match &fun.body {
230 FunctionBody::Expr(expr) => { 292 FunctionBody::Expr(expr) => {
@@ -243,11 +305,28 @@ fn format_function(fun: &Function, indent: IndentLevel) -> String {
243 } 305 }
244 } 306 }
245 } 307 }
308
309 match fun.vars_in_body_used_afterwards.as_slice() {
310 [] => {}
311 [var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()),
312 [v0, vs @ ..] => {
313 format_to!(fn_def, "{}({}", indent + 1, v0.name(ctx.db()).unwrap());
314 for var in vs {
315 format_to!(fn_def, ", {}", var.name(ctx.db()).unwrap());
316 }
317 fn_def.push_str(")\n");
318 }
319 }
320
246 format_to!(fn_def, "{}}}", indent); 321 format_to!(fn_def, "{}}}", indent);
247 322
248 fn_def 323 fn_def
249} 324}
250 325
326fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> String {
327 ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string())
328}
329
251#[derive(Debug)] 330#[derive(Debug)]
252enum FunctionBody { 331enum FunctionBody {
253 Expr(ast::Expr), 332 Expr(ast::Expr),
@@ -339,18 +418,26 @@ impl FunctionBody {
339 } 418 }
340 } 419 }
341 420
342 fn contains_node(&self, node: &SyntaxNode) -> bool { 421 fn text_range(&self) -> TextRange {
343 fn is_node(body: &FunctionBody, n: &SyntaxNode) -> bool { 422 match self {
344 match body { 423 FunctionBody::Expr(expr) => expr.syntax().text_range(),
345 FunctionBody::Expr(expr) => n == expr.syntax(), 424 FunctionBody::Span { elements, .. } => TextRange::new(
346 FunctionBody::Span { elements, .. } => { 425 elements.first().unwrap().text_range().start(),
347 // FIXME: can it be quadratic? 426 elements.last().unwrap().text_range().end(),
348 elements.iter().filter_map(SyntaxElement::as_node).any(|e| e == n) 427 ),
349 }
350 }
351 } 428 }
429 }
430
431 fn contains_range(&self, range: TextRange) -> bool {
432 self.text_range().contains_range(range)
433 }
352 434
353 node.ancestors().any(|a| is_node(self, &a)) 435 fn preceedes_range(&self, range: TextRange) -> bool {
436 self.text_range().end() <= range.start()
437 }
438
439 fn contains_node(&self, node: &SyntaxNode) -> bool {
440 self.contains_range(node.text_range())
354 } 441 }
355} 442}
356 443
@@ -383,11 +470,6 @@ fn scope_for_fn_insertion(node: &SyntaxNode, anchor: Anchor) -> Option<SyntaxNod
383 last_ancestor 470 last_ancestor
384} 471}
385 472
386fn deduplicate_params(params: &mut Vec<ast::IdentPat>) {
387 let mut seen_params = FxHashSet::default();
388 params.retain(|p| seen_params.insert(p.clone()));
389}
390
391fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode { 473fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode {
392 match value { 474 match value {
393 Either::Left(pat) => pat.syntax(), 475 Either::Left(pat) => pat.syntax(),
@@ -395,8 +477,8 @@ fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode {
395 } 477 }
396} 478}
397 479
398/// Returns a vector of local variables that are refferenced in `body` 480/// Returns a vector of local variables that are referenced in `body`
399fn local_variables(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> { 481fn vars_used_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> {
400 body.descendants() 482 body.descendants()
401 .filter_map(ast::NameRef::cast) 483 .filter_map(ast::NameRef::cast)
402 .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref)) 484 .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref))
@@ -405,6 +487,16 @@ fn local_variables(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> {
405 Definition::Local(local) => Some(local), 487 Definition::Local(local) => Some(local),
406 _ => None, 488 _ => None,
407 }) 489 })
490 .unique()
491 .collect()
492}
493
494/// Returns a vector of local variables that are defined in `body`
495fn vars_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> {
496 body.descendants()
497 .filter_map(ast::IdentPat::cast)
498 .filter_map(|let_stmt| ctx.sema.to_def(&let_stmt))
499 .unique()
408 .collect() 500 .collect()
409} 501}
410 502
@@ -973,4 +1065,54 @@ impl S {
973}", 1065}",
974 ); 1066 );
975 } 1067 }
1068
1069 #[test]
1070 fn variable_defined_inside_and_used_after_no_ret() {
1071 check_assist(
1072 extract_function,
1073 r"
1074fn foo() {
1075 let n = 1;
1076 $0let k = n * n;$0
1077 let m = k + 1;
1078}",
1079 r"
1080fn foo() {
1081 let n = 1;
1082 let k = fun_name(n);
1083 let m = k + 1;
1084}
1085
1086fn $0fun_name(n: i32) -> i32 {
1087 let k = n * n;
1088 k
1089}",
1090 );
1091 }
1092
1093 #[test]
1094 fn two_variables_defined_inside_and_used_after_no_ret() {
1095 check_assist(
1096 extract_function,
1097 r"
1098fn foo() {
1099 let n = 1;
1100 $0let k = n * n;
1101 let m = k + 2;$0
1102 let h = k + m;
1103}",
1104 r"
1105fn foo() {
1106 let n = 1;
1107 let (k, m) = fun_name(n);
1108 let h = k + m;
1109}
1110
1111fn $0fun_name(n: i32) -> (i32, i32) {
1112 let k = n * n;
1113 let m = k + 2;
1114 (k, m)
1115}",
1116 );
1117 }
976} 1118}