diff options
author | Vladyslav Katasonov <[email protected]> | 2021-02-03 17:31:12 +0000 |
---|---|---|
committer | Vladyslav Katasonov <[email protected]> | 2021-02-03 18:11:12 +0000 |
commit | 82787febdee3e7dfe5a96c94aee03cd726f642f9 (patch) | |
tree | 1fa901e779885b9ad79a6f691335352388d55691 /crates/assists/src/handlers | |
parent | 313aa5f3a2a9237c96c97c5852da39cf83bcb1ae (diff) |
allow local variables to be used after extracted body
when variable is defined inside extracted body
export this variable to original scope via return value(s)
Diffstat (limited to 'crates/assists/src/handlers')
-rw-r--r-- | crates/assists/src/handlers/extract_function.rs | 224 |
1 files changed, 183 insertions, 41 deletions
diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index 958199e5e..c5e6ec733 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs | |||
@@ -1,7 +1,10 @@ | |||
1 | use either::Either; | 1 | use either::Either; |
2 | use hir::{HirDisplay, Local}; | 2 | use hir::{HirDisplay, Local}; |
3 | use ide_db::defs::{Definition, NameRefClass}; | 3 | use ide_db::{ |
4 | use rustc_hash::FxHashSet; | 4 | defs::{Definition, NameRefClass}, |
5 | search::SearchScope, | ||
6 | }; | ||
7 | use itertools::Itertools; | ||
5 | use stdx::format_to; | 8 | use stdx::format_to; |
6 | use syntax::{ | 9 | use syntax::{ |
7 | ast::{ | 10 | ast::{ |
@@ -81,9 +84,10 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option | |||
81 | } | 84 | } |
82 | let body = body?; | 85 | let body = body?; |
83 | 86 | ||
87 | let vars_used_in_body = vars_used_in_body(&body, &ctx); | ||
84 | let mut self_param = None; | 88 | let mut self_param = None; |
85 | let mut param_pats: Vec<_> = local_variables(&body, &ctx) | 89 | let param_pats: Vec<_> = vars_used_in_body |
86 | .into_iter() | 90 | .iter() |
87 | .map(|node| node.source(ctx.db())) | 91 | .map(|node| node.source(ctx.db())) |
88 | .filter(|src| { | 92 | .filter(|src| { |
89 | src.file_id.original_file(ctx.db()) == ctx.frange.file_id | 93 | src.file_id.original_file(ctx.db()) == ctx.frange.file_id |
@@ -98,12 +102,27 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option | |||
98 | } | 102 | } |
99 | }) | 103 | }) |
100 | .collect(); | 104 | .collect(); |
101 | deduplicate_params(&mut param_pats); | ||
102 | 105 | ||
103 | let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding }; | 106 | let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding }; |
104 | let insert_after = body.scope_for_fn_insertion(anchor)?; | 107 | let insert_after = body.scope_for_fn_insertion(anchor)?; |
105 | let module = ctx.sema.scope(&insert_after).module()?; | 108 | let module = ctx.sema.scope(&insert_after).module()?; |
106 | 109 | ||
110 | let vars_defined_in_body = vars_defined_in_body(&body, ctx); | ||
111 | |||
112 | let vars_in_body_used_afterwards: Vec<_> = vars_defined_in_body | ||
113 | .iter() | ||
114 | .copied() | ||
115 | .filter(|node| { | ||
116 | let usages = Definition::Local(*node) | ||
117 | .usages(&ctx.sema) | ||
118 | .in_scope(SearchScope::single_file(ctx.frange.file_id)) | ||
119 | .all(); | ||
120 | let mut usages = usages.iter().flat_map(|(_, rs)| rs.iter()); | ||
121 | |||
122 | usages.any(|reference| body.preceedes_range(reference.range)) | ||
123 | }) | ||
124 | .collect(); | ||
125 | |||
107 | let params = param_pats | 126 | let params = param_pats |
108 | .into_iter() | 127 | .into_iter() |
109 | .map(|pat| { | 128 | .map(|pat| { |
@@ -119,20 +138,18 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option | |||
119 | }) | 138 | }) |
120 | .collect::<Vec<_>>(); | 139 | .collect::<Vec<_>>(); |
121 | 140 | ||
122 | let self_param = | ||
123 | if let Some(self_param) = self_param { Some(self_param.to_string()) } else { None }; | ||
124 | |||
125 | let expr = body.tail_expr(); | 141 | let expr = body.tail_expr(); |
126 | let ret_ty = match expr { | 142 | let ret_ty = match expr { |
127 | Some(expr) => { | 143 | Some(expr) => Some(ctx.sema.type_of_expr(&expr)?), |
128 | // FIXME: can we do assist when type is unknown? | ||
129 | // We can insert something like `-> ()` | ||
130 | let ty = ctx.sema.type_of_expr(&expr)?; | ||
131 | Some(ty.display_source_code(ctx.db(), module.into()).ok()?) | ||
132 | } | ||
133 | None => None, | 144 | None => None, |
134 | }; | 145 | }; |
135 | 146 | ||
147 | let has_unit_ret = ret_ty.as_ref().map_or(true, |it| it.is_unit()); | ||
148 | if stdx::never!(!vars_in_body_used_afterwards.is_empty() && !has_unit_ret) { | ||
149 | // We should not have variables that outlive body if we have expression block | ||
150 | return None; | ||
151 | } | ||
152 | |||
136 | let target_range = match &body { | 153 | let target_range = match &body { |
137 | FunctionBody::Expr(expr) => expr.syntax().text_range(), | 154 | FunctionBody::Expr(expr) => expr.syntax().text_range(), |
138 | FunctionBody::Span { .. } => ctx.frange.range, | 155 | FunctionBody::Span { .. } => ctx.frange.range, |
@@ -143,21 +160,46 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option | |||
143 | "Extract into function", | 160 | "Extract into function", |
144 | target_range, | 161 | target_range, |
145 | move |builder| { | 162 | move |builder| { |
146 | let fun = Function { name: "fun_name".to_string(), self_param, params, ret_ty, body }; | 163 | let fun = Function { |
164 | name: "fun_name".to_string(), | ||
165 | self_param, | ||
166 | params, | ||
167 | ret_ty, | ||
168 | body, | ||
169 | vars_in_body_used_afterwards, | ||
170 | }; | ||
147 | 171 | ||
148 | builder.replace(target_range, format_replacement(&fun)); | 172 | builder.replace(target_range, format_replacement(ctx, &fun)); |
149 | 173 | ||
150 | let indent = IndentLevel::from_node(&insert_after); | 174 | let indent = IndentLevel::from_node(&insert_after); |
151 | 175 | ||
152 | let fn_def = format_function(&fun, indent); | 176 | let fn_def = format_function(ctx, module, &fun, indent); |
153 | let insert_offset = insert_after.text_range().end(); | 177 | let insert_offset = insert_after.text_range().end(); |
154 | builder.insert(insert_offset, fn_def); | 178 | builder.insert(insert_offset, fn_def); |
155 | }, | 179 | }, |
156 | ) | 180 | ) |
157 | } | 181 | } |
158 | 182 | ||
159 | fn format_replacement(fun: &Function) -> String { | 183 | fn format_replacement(ctx: &AssistContext, fun: &Function) -> String { |
160 | let mut buf = String::new(); | 184 | let mut buf = String::new(); |
185 | |||
186 | match fun.vars_in_body_used_afterwards.len() { | ||
187 | 0 => {} | ||
188 | 1 => format_to!( | ||
189 | buf, | ||
190 | "let {} = ", | ||
191 | fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap() | ||
192 | ), | ||
193 | _ => { | ||
194 | buf.push_str("let ("); | ||
195 | format_to!(buf, "{}", fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap()); | ||
196 | for local in fun.vars_in_body_used_afterwards.iter().skip(1) { | ||
197 | format_to!(buf, ", {}", local.name(ctx.db()).unwrap()); | ||
198 | } | ||
199 | buf.push_str(") = "); | ||
200 | } | ||
201 | } | ||
202 | |||
161 | if fun.self_param.is_some() { | 203 | if fun.self_param.is_some() { |
162 | format_to!(buf, "self."); | 204 | format_to!(buf, "self."); |
163 | } | 205 | } |
@@ -182,16 +224,17 @@ fn format_replacement(fun: &Function) -> String { | |||
182 | 224 | ||
183 | struct Function { | 225 | struct Function { |
184 | name: String, | 226 | name: String, |
185 | self_param: Option<String>, | 227 | self_param: Option<ast::SelfParam>, |
186 | params: Vec<Param>, | 228 | params: Vec<Param>, |
187 | ret_ty: Option<String>, | 229 | ret_ty: Option<hir::Type>, |
188 | body: FunctionBody, | 230 | body: FunctionBody, |
231 | vars_in_body_used_afterwards: Vec<Local>, | ||
189 | } | 232 | } |
190 | 233 | ||
191 | impl Function { | 234 | impl Function { |
192 | fn has_unit_ret(&self) -> bool { | 235 | fn has_unit_ret(&self) -> bool { |
193 | match &self.ret_ty { | 236 | match &self.ret_ty { |
194 | Some(ty) => ty == "()", | 237 | Some(ty) => ty.is_unit(), |
195 | None => true, | 238 | None => true, |
196 | } | 239 | } |
197 | } | 240 | } |
@@ -203,7 +246,12 @@ struct Param { | |||
203 | ty: String, | 246 | ty: String, |
204 | } | 247 | } |
205 | 248 | ||
206 | fn format_function(fun: &Function, indent: IndentLevel) -> String { | 249 | fn format_function( |
250 | ctx: &AssistContext, | ||
251 | module: hir::Module, | ||
252 | fun: &Function, | ||
253 | indent: IndentLevel, | ||
254 | ) -> String { | ||
207 | let mut fn_def = String::new(); | 255 | let mut fn_def = String::new(); |
208 | format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name); | 256 | format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name); |
209 | { | 257 | { |
@@ -221,10 +269,24 @@ fn format_function(fun: &Function, indent: IndentLevel) -> String { | |||
221 | format_to!(fn_def, ")"); | 269 | format_to!(fn_def, ")"); |
222 | if !fun.has_unit_ret() { | 270 | if !fun.has_unit_ret() { |
223 | if let Some(ty) = &fun.ret_ty { | 271 | if let Some(ty) = &fun.ret_ty { |
224 | format_to!(fn_def, " -> {}", ty); | 272 | format_to!(fn_def, " -> {}", format_type(ty, ctx, module)); |
273 | } | ||
274 | } else { | ||
275 | match fun.vars_in_body_used_afterwards.as_slice() { | ||
276 | [] => {} | ||
277 | [var] => { | ||
278 | format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module)); | ||
279 | } | ||
280 | [v0, vs @ ..] => { | ||
281 | format_to!(fn_def, " -> ({}", format_type(&v0.ty(ctx.db()), ctx, module)); | ||
282 | for var in vs { | ||
283 | format_to!(fn_def, ", {}", format_type(&var.ty(ctx.db()), ctx, module)); | ||
284 | } | ||
285 | fn_def.push(')'); | ||
286 | } | ||
225 | } | 287 | } |
226 | } | 288 | } |
227 | format_to!(fn_def, " {{"); | 289 | fn_def.push_str(" {"); |
228 | 290 | ||
229 | match &fun.body { | 291 | match &fun.body { |
230 | FunctionBody::Expr(expr) => { | 292 | FunctionBody::Expr(expr) => { |
@@ -243,11 +305,28 @@ fn format_function(fun: &Function, indent: IndentLevel) -> String { | |||
243 | } | 305 | } |
244 | } | 306 | } |
245 | } | 307 | } |
308 | |||
309 | match fun.vars_in_body_used_afterwards.as_slice() { | ||
310 | [] => {} | ||
311 | [var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()), | ||
312 | [v0, vs @ ..] => { | ||
313 | format_to!(fn_def, "{}({}", indent + 1, v0.name(ctx.db()).unwrap()); | ||
314 | for var in vs { | ||
315 | format_to!(fn_def, ", {}", var.name(ctx.db()).unwrap()); | ||
316 | } | ||
317 | fn_def.push_str(")\n"); | ||
318 | } | ||
319 | } | ||
320 | |||
246 | format_to!(fn_def, "{}}}", indent); | 321 | format_to!(fn_def, "{}}}", indent); |
247 | 322 | ||
248 | fn_def | 323 | fn_def |
249 | } | 324 | } |
250 | 325 | ||
326 | fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> String { | ||
327 | ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string()) | ||
328 | } | ||
329 | |||
251 | #[derive(Debug)] | 330 | #[derive(Debug)] |
252 | enum FunctionBody { | 331 | enum FunctionBody { |
253 | Expr(ast::Expr), | 332 | Expr(ast::Expr), |
@@ -339,18 +418,26 @@ impl FunctionBody { | |||
339 | } | 418 | } |
340 | } | 419 | } |
341 | 420 | ||
342 | fn contains_node(&self, node: &SyntaxNode) -> bool { | 421 | fn text_range(&self) -> TextRange { |
343 | fn is_node(body: &FunctionBody, n: &SyntaxNode) -> bool { | 422 | match self { |
344 | match body { | 423 | FunctionBody::Expr(expr) => expr.syntax().text_range(), |
345 | FunctionBody::Expr(expr) => n == expr.syntax(), | 424 | FunctionBody::Span { elements, .. } => TextRange::new( |
346 | FunctionBody::Span { elements, .. } => { | 425 | elements.first().unwrap().text_range().start(), |
347 | // FIXME: can it be quadratic? | 426 | elements.last().unwrap().text_range().end(), |
348 | elements.iter().filter_map(SyntaxElement::as_node).any(|e| e == n) | 427 | ), |
349 | } | ||
350 | } | ||
351 | } | 428 | } |
429 | } | ||
430 | |||
431 | fn contains_range(&self, range: TextRange) -> bool { | ||
432 | self.text_range().contains_range(range) | ||
433 | } | ||
352 | 434 | ||
353 | node.ancestors().any(|a| is_node(self, &a)) | 435 | fn preceedes_range(&self, range: TextRange) -> bool { |
436 | self.text_range().end() <= range.start() | ||
437 | } | ||
438 | |||
439 | fn contains_node(&self, node: &SyntaxNode) -> bool { | ||
440 | self.contains_range(node.text_range()) | ||
354 | } | 441 | } |
355 | } | 442 | } |
356 | 443 | ||
@@ -383,11 +470,6 @@ fn scope_for_fn_insertion(node: &SyntaxNode, anchor: Anchor) -> Option<SyntaxNod | |||
383 | last_ancestor | 470 | last_ancestor |
384 | } | 471 | } |
385 | 472 | ||
386 | fn deduplicate_params(params: &mut Vec<ast::IdentPat>) { | ||
387 | let mut seen_params = FxHashSet::default(); | ||
388 | params.retain(|p| seen_params.insert(p.clone())); | ||
389 | } | ||
390 | |||
391 | fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode { | 473 | fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode { |
392 | match value { | 474 | match value { |
393 | Either::Left(pat) => pat.syntax(), | 475 | Either::Left(pat) => pat.syntax(), |
@@ -395,8 +477,8 @@ fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode { | |||
395 | } | 477 | } |
396 | } | 478 | } |
397 | 479 | ||
398 | /// Returns a vector of local variables that are refferenced in `body` | 480 | /// Returns a vector of local variables that are referenced in `body` |
399 | fn local_variables(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> { | 481 | fn vars_used_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> { |
400 | body.descendants() | 482 | body.descendants() |
401 | .filter_map(ast::NameRef::cast) | 483 | .filter_map(ast::NameRef::cast) |
402 | .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref)) | 484 | .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref)) |
@@ -405,6 +487,16 @@ fn local_variables(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> { | |||
405 | Definition::Local(local) => Some(local), | 487 | Definition::Local(local) => Some(local), |
406 | _ => None, | 488 | _ => None, |
407 | }) | 489 | }) |
490 | .unique() | ||
491 | .collect() | ||
492 | } | ||
493 | |||
494 | /// Returns a vector of local variables that are defined in `body` | ||
495 | fn vars_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> { | ||
496 | body.descendants() | ||
497 | .filter_map(ast::IdentPat::cast) | ||
498 | .filter_map(|let_stmt| ctx.sema.to_def(&let_stmt)) | ||
499 | .unique() | ||
408 | .collect() | 500 | .collect() |
409 | } | 501 | } |
410 | 502 | ||
@@ -973,4 +1065,54 @@ impl S { | |||
973 | }", | 1065 | }", |
974 | ); | 1066 | ); |
975 | } | 1067 | } |
1068 | |||
1069 | #[test] | ||
1070 | fn variable_defined_inside_and_used_after_no_ret() { | ||
1071 | check_assist( | ||
1072 | extract_function, | ||
1073 | r" | ||
1074 | fn foo() { | ||
1075 | let n = 1; | ||
1076 | $0let k = n * n;$0 | ||
1077 | let m = k + 1; | ||
1078 | }", | ||
1079 | r" | ||
1080 | fn foo() { | ||
1081 | let n = 1; | ||
1082 | let k = fun_name(n); | ||
1083 | let m = k + 1; | ||
1084 | } | ||
1085 | |||
1086 | fn $0fun_name(n: i32) -> i32 { | ||
1087 | let k = n * n; | ||
1088 | k | ||
1089 | }", | ||
1090 | ); | ||
1091 | } | ||
1092 | |||
1093 | #[test] | ||
1094 | fn two_variables_defined_inside_and_used_after_no_ret() { | ||
1095 | check_assist( | ||
1096 | extract_function, | ||
1097 | r" | ||
1098 | fn foo() { | ||
1099 | let n = 1; | ||
1100 | $0let k = n * n; | ||
1101 | let m = k + 2;$0 | ||
1102 | let h = k + m; | ||
1103 | }", | ||
1104 | r" | ||
1105 | fn foo() { | ||
1106 | let n = 1; | ||
1107 | let (k, m) = fun_name(n); | ||
1108 | let h = k + m; | ||
1109 | } | ||
1110 | |||
1111 | fn $0fun_name(n: i32) -> (i32, i32) { | ||
1112 | let k = n * n; | ||
1113 | let m = k + 2; | ||
1114 | (k, m) | ||
1115 | }", | ||
1116 | ); | ||
1117 | } | ||
976 | } | 1118 | } |