aboutsummaryrefslogtreecommitdiff
path: root/crates/assists
diff options
context:
space:
mode:
authorVladyslav Katasonov <[email protected]>2021-02-03 17:31:12 +0000
committerVladyslav Katasonov <[email protected]>2021-02-03 18:11:12 +0000
commit82787febdee3e7dfe5a96c94aee03cd726f642f9 (patch)
tree1fa901e779885b9ad79a6f691335352388d55691 /crates/assists
parent313aa5f3a2a9237c96c97c5852da39cf83bcb1ae (diff)
allow local variables to be used after extracted body
when variable is defined inside extracted body export this variable to original scope via return value(s)
Diffstat (limited to 'crates/assists')
-rw-r--r--crates/assists/src/handlers/extract_function.rs224
1 files changed, 183 insertions, 41 deletions
diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs
index 958199e5e..c5e6ec733 100644
--- a/crates/assists/src/handlers/extract_function.rs
+++ b/crates/assists/src/handlers/extract_function.rs
@@ -1,7 +1,10 @@
1use either::Either; 1use either::Either;
2use hir::{HirDisplay, Local}; 2use hir::{HirDisplay, Local};
3use ide_db::defs::{Definition, NameRefClass}; 3use ide_db::{
4use rustc_hash::FxHashSet; 4 defs::{Definition, NameRefClass},
5 search::SearchScope,
6};
7use itertools::Itertools;
5use stdx::format_to; 8use stdx::format_to;
6use syntax::{ 9use syntax::{
7 ast::{ 10 ast::{
@@ -81,9 +84,10 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
81 } 84 }
82 let body = body?; 85 let body = body?;
83 86
87 let vars_used_in_body = vars_used_in_body(&body, &ctx);
84 let mut self_param = None; 88 let mut self_param = None;
85 let mut param_pats: Vec<_> = local_variables(&body, &ctx) 89 let param_pats: Vec<_> = vars_used_in_body
86 .into_iter() 90 .iter()
87 .map(|node| node.source(ctx.db())) 91 .map(|node| node.source(ctx.db()))
88 .filter(|src| { 92 .filter(|src| {
89 src.file_id.original_file(ctx.db()) == ctx.frange.file_id 93 src.file_id.original_file(ctx.db()) == ctx.frange.file_id
@@ -98,12 +102,27 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
98 } 102 }
99 }) 103 })
100 .collect(); 104 .collect();
101 deduplicate_params(&mut param_pats);
102 105
103 let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding }; 106 let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding };
104 let insert_after = body.scope_for_fn_insertion(anchor)?; 107 let insert_after = body.scope_for_fn_insertion(anchor)?;
105 let module = ctx.sema.scope(&insert_after).module()?; 108 let module = ctx.sema.scope(&insert_after).module()?;
106 109
110 let vars_defined_in_body = vars_defined_in_body(&body, ctx);
111
112 let vars_in_body_used_afterwards: Vec<_> = vars_defined_in_body
113 .iter()
114 .copied()
115 .filter(|node| {
116 let usages = Definition::Local(*node)
117 .usages(&ctx.sema)
118 .in_scope(SearchScope::single_file(ctx.frange.file_id))
119 .all();
120 let mut usages = usages.iter().flat_map(|(_, rs)| rs.iter());
121
122 usages.any(|reference| body.preceedes_range(reference.range))
123 })
124 .collect();
125
107 let params = param_pats 126 let params = param_pats
108 .into_iter() 127 .into_iter()
109 .map(|pat| { 128 .map(|pat| {
@@ -119,20 +138,18 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
119 }) 138 })
120 .collect::<Vec<_>>(); 139 .collect::<Vec<_>>();
121 140
122 let self_param =
123 if let Some(self_param) = self_param { Some(self_param.to_string()) } else { None };
124
125 let expr = body.tail_expr(); 141 let expr = body.tail_expr();
126 let ret_ty = match expr { 142 let ret_ty = match expr {
127 Some(expr) => { 143 Some(expr) => Some(ctx.sema.type_of_expr(&expr)?),
128 // FIXME: can we do assist when type is unknown?
129 // We can insert something like `-> ()`
130 let ty = ctx.sema.type_of_expr(&expr)?;
131 Some(ty.display_source_code(ctx.db(), module.into()).ok()?)
132 }
133 None => None, 144 None => None,
134 }; 145 };
135 146
147 let has_unit_ret = ret_ty.as_ref().map_or(true, |it| it.is_unit());
148 if stdx::never!(!vars_in_body_used_afterwards.is_empty() && !has_unit_ret) {
149 // We should not have variables that outlive body if we have expression block
150 return None;
151 }
152
136 let target_range = match &body { 153 let target_range = match &body {
137 FunctionBody::Expr(expr) => expr.syntax().text_range(), 154 FunctionBody::Expr(expr) => expr.syntax().text_range(),
138 FunctionBody::Span { .. } => ctx.frange.range, 155 FunctionBody::Span { .. } => ctx.frange.range,
@@ -143,21 +160,46 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
143 "Extract into function", 160 "Extract into function",
144 target_range, 161 target_range,
145 move |builder| { 162 move |builder| {
146 let fun = Function { name: "fun_name".to_string(), self_param, params, ret_ty, body }; 163 let fun = Function {
164 name: "fun_name".to_string(),
165 self_param,
166 params,
167 ret_ty,
168 body,
169 vars_in_body_used_afterwards,
170 };
147 171
148 builder.replace(target_range, format_replacement(&fun)); 172 builder.replace(target_range, format_replacement(ctx, &fun));
149 173
150 let indent = IndentLevel::from_node(&insert_after); 174 let indent = IndentLevel::from_node(&insert_after);
151 175
152 let fn_def = format_function(&fun, indent); 176 let fn_def = format_function(ctx, module, &fun, indent);
153 let insert_offset = insert_after.text_range().end(); 177 let insert_offset = insert_after.text_range().end();
154 builder.insert(insert_offset, fn_def); 178 builder.insert(insert_offset, fn_def);
155 }, 179 },
156 ) 180 )
157} 181}
158 182
159fn format_replacement(fun: &Function) -> String { 183fn format_replacement(ctx: &AssistContext, fun: &Function) -> String {
160 let mut buf = String::new(); 184 let mut buf = String::new();
185
186 match fun.vars_in_body_used_afterwards.len() {
187 0 => {}
188 1 => format_to!(
189 buf,
190 "let {} = ",
191 fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap()
192 ),
193 _ => {
194 buf.push_str("let (");
195 format_to!(buf, "{}", fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap());
196 for local in fun.vars_in_body_used_afterwards.iter().skip(1) {
197 format_to!(buf, ", {}", local.name(ctx.db()).unwrap());
198 }
199 buf.push_str(") = ");
200 }
201 }
202
161 if fun.self_param.is_some() { 203 if fun.self_param.is_some() {
162 format_to!(buf, "self."); 204 format_to!(buf, "self.");
163 } 205 }
@@ -182,16 +224,17 @@ fn format_replacement(fun: &Function) -> String {
182 224
183struct Function { 225struct Function {
184 name: String, 226 name: String,
185 self_param: Option<String>, 227 self_param: Option<ast::SelfParam>,
186 params: Vec<Param>, 228 params: Vec<Param>,
187 ret_ty: Option<String>, 229 ret_ty: Option<hir::Type>,
188 body: FunctionBody, 230 body: FunctionBody,
231 vars_in_body_used_afterwards: Vec<Local>,
189} 232}
190 233
191impl Function { 234impl Function {
192 fn has_unit_ret(&self) -> bool { 235 fn has_unit_ret(&self) -> bool {
193 match &self.ret_ty { 236 match &self.ret_ty {
194 Some(ty) => ty == "()", 237 Some(ty) => ty.is_unit(),
195 None => true, 238 None => true,
196 } 239 }
197 } 240 }
@@ -203,7 +246,12 @@ struct Param {
203 ty: String, 246 ty: String,
204} 247}
205 248
206fn format_function(fun: &Function, indent: IndentLevel) -> String { 249fn format_function(
250 ctx: &AssistContext,
251 module: hir::Module,
252 fun: &Function,
253 indent: IndentLevel,
254) -> String {
207 let mut fn_def = String::new(); 255 let mut fn_def = String::new();
208 format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name); 256 format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name);
209 { 257 {
@@ -221,10 +269,24 @@ fn format_function(fun: &Function, indent: IndentLevel) -> String {
221 format_to!(fn_def, ")"); 269 format_to!(fn_def, ")");
222 if !fun.has_unit_ret() { 270 if !fun.has_unit_ret() {
223 if let Some(ty) = &fun.ret_ty { 271 if let Some(ty) = &fun.ret_ty {
224 format_to!(fn_def, " -> {}", ty); 272 format_to!(fn_def, " -> {}", format_type(ty, ctx, module));
273 }
274 } else {
275 match fun.vars_in_body_used_afterwards.as_slice() {
276 [] => {}
277 [var] => {
278 format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module));
279 }
280 [v0, vs @ ..] => {
281 format_to!(fn_def, " -> ({}", format_type(&v0.ty(ctx.db()), ctx, module));
282 for var in vs {
283 format_to!(fn_def, ", {}", format_type(&var.ty(ctx.db()), ctx, module));
284 }
285 fn_def.push(')');
286 }
225 } 287 }
226 } 288 }
227 format_to!(fn_def, " {{"); 289 fn_def.push_str(" {");
228 290
229 match &fun.body { 291 match &fun.body {
230 FunctionBody::Expr(expr) => { 292 FunctionBody::Expr(expr) => {
@@ -243,11 +305,28 @@ fn format_function(fun: &Function, indent: IndentLevel) -> String {
243 } 305 }
244 } 306 }
245 } 307 }
308
309 match fun.vars_in_body_used_afterwards.as_slice() {
310 [] => {}
311 [var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()),
312 [v0, vs @ ..] => {
313 format_to!(fn_def, "{}({}", indent + 1, v0.name(ctx.db()).unwrap());
314 for var in vs {
315 format_to!(fn_def, ", {}", var.name(ctx.db()).unwrap());
316 }
317 fn_def.push_str(")\n");
318 }
319 }
320
246 format_to!(fn_def, "{}}}", indent); 321 format_to!(fn_def, "{}}}", indent);
247 322
248 fn_def 323 fn_def
249} 324}
250 325
326fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> String {
327 ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string())
328}
329
251#[derive(Debug)] 330#[derive(Debug)]
252enum FunctionBody { 331enum FunctionBody {
253 Expr(ast::Expr), 332 Expr(ast::Expr),
@@ -339,18 +418,26 @@ impl FunctionBody {
339 } 418 }
340 } 419 }
341 420
342 fn contains_node(&self, node: &SyntaxNode) -> bool { 421 fn text_range(&self) -> TextRange {
343 fn is_node(body: &FunctionBody, n: &SyntaxNode) -> bool { 422 match self {
344 match body { 423 FunctionBody::Expr(expr) => expr.syntax().text_range(),
345 FunctionBody::Expr(expr) => n == expr.syntax(), 424 FunctionBody::Span { elements, .. } => TextRange::new(
346 FunctionBody::Span { elements, .. } => { 425 elements.first().unwrap().text_range().start(),
347 // FIXME: can it be quadratic? 426 elements.last().unwrap().text_range().end(),
348 elements.iter().filter_map(SyntaxElement::as_node).any(|e| e == n) 427 ),
349 }
350 }
351 } 428 }
429 }
430
431 fn contains_range(&self, range: TextRange) -> bool {
432 self.text_range().contains_range(range)
433 }
352 434
353 node.ancestors().any(|a| is_node(self, &a)) 435 fn preceedes_range(&self, range: TextRange) -> bool {
436 self.text_range().end() <= range.start()
437 }
438
439 fn contains_node(&self, node: &SyntaxNode) -> bool {
440 self.contains_range(node.text_range())
354 } 441 }
355} 442}
356 443
@@ -383,11 +470,6 @@ fn scope_for_fn_insertion(node: &SyntaxNode, anchor: Anchor) -> Option<SyntaxNod
383 last_ancestor 470 last_ancestor
384} 471}
385 472
386fn deduplicate_params(params: &mut Vec<ast::IdentPat>) {
387 let mut seen_params = FxHashSet::default();
388 params.retain(|p| seen_params.insert(p.clone()));
389}
390
391fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode { 473fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode {
392 match value { 474 match value {
393 Either::Left(pat) => pat.syntax(), 475 Either::Left(pat) => pat.syntax(),
@@ -395,8 +477,8 @@ fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode {
395 } 477 }
396} 478}
397 479
398/// Returns a vector of local variables that are refferenced in `body` 480/// Returns a vector of local variables that are referenced in `body`
399fn local_variables(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> { 481fn vars_used_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> {
400 body.descendants() 482 body.descendants()
401 .filter_map(ast::NameRef::cast) 483 .filter_map(ast::NameRef::cast)
402 .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref)) 484 .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref))
@@ -405,6 +487,16 @@ fn local_variables(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> {
405 Definition::Local(local) => Some(local), 487 Definition::Local(local) => Some(local),
406 _ => None, 488 _ => None,
407 }) 489 })
490 .unique()
491 .collect()
492}
493
494/// Returns a vector of local variables that are defined in `body`
495fn vars_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> {
496 body.descendants()
497 .filter_map(ast::IdentPat::cast)
498 .filter_map(|let_stmt| ctx.sema.to_def(&let_stmt))
499 .unique()
408 .collect() 500 .collect()
409} 501}
410 502
@@ -973,4 +1065,54 @@ impl S {
973}", 1065}",
974 ); 1066 );
975 } 1067 }
1068
1069 #[test]
1070 fn variable_defined_inside_and_used_after_no_ret() {
1071 check_assist(
1072 extract_function,
1073 r"
1074fn foo() {
1075 let n = 1;
1076 $0let k = n * n;$0
1077 let m = k + 1;
1078}",
1079 r"
1080fn foo() {
1081 let n = 1;
1082 let k = fun_name(n);
1083 let m = k + 1;
1084}
1085
1086fn $0fun_name(n: i32) -> i32 {
1087 let k = n * n;
1088 k
1089}",
1090 );
1091 }
1092
1093 #[test]
1094 fn two_variables_defined_inside_and_used_after_no_ret() {
1095 check_assist(
1096 extract_function,
1097 r"
1098fn foo() {
1099 let n = 1;
1100 $0let k = n * n;
1101 let m = k + 2;$0
1102 let h = k + m;
1103}",
1104 r"
1105fn foo() {
1106 let n = 1;
1107 let (k, m) = fun_name(n);
1108 let h = k + m;
1109}
1110
1111fn $0fun_name(n: i32) -> (i32, i32) {
1112 let k = n * n;
1113 let m = k + 2;
1114 (k, m)
1115}",
1116 );
1117 }
976} 1118}