aboutsummaryrefslogtreecommitdiff
path: root/crates/assists/src
diff options
context:
space:
mode:
authorVladyslav Katasonov <[email protected]>2021-02-03 20:45:03 +0000
committerVladyslav Katasonov <[email protected]>2021-02-03 20:45:03 +0000
commitf102616aaea2894508f8f078cfb20ceef5411d12 (patch)
treefe2f951fd8fe9d3ed9aa7e92db9f467a8bb7fc66 /crates/assists/src
parent82787febdee3e7dfe5a96c94aee03cd726f642f9 (diff)
allow modifications of vars from outer scope inside extracted function
It currently allows only directly setting variable. No `&mut` references or methods.
Diffstat (limited to 'crates/assists/src')
-rw-r--r--crates/assists/src/handlers/extract_function.rs381
1 files changed, 336 insertions, 45 deletions
diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs
index c5e6ec733..ffa8bd77d 100644
--- a/crates/assists/src/handlers/extract_function.rs
+++ b/crates/assists/src/handlers/extract_function.rs
@@ -2,19 +2,20 @@ use either::Either;
2use hir::{HirDisplay, Local}; 2use hir::{HirDisplay, Local};
3use ide_db::{ 3use ide_db::{
4 defs::{Definition, NameRefClass}, 4 defs::{Definition, NameRefClass},
5 search::SearchScope, 5 search::{ReferenceAccess, SearchScope},
6}; 6};
7use itertools::Itertools; 7use itertools::Itertools;
8use stdx::format_to; 8use stdx::format_to;
9use syntax::{ 9use syntax::{
10 algo::SyntaxRewriter,
10 ast::{ 11 ast::{
11 self, 12 self,
12 edit::{AstNodeEdit, IndentLevel}, 13 edit::{AstNodeEdit, IndentLevel},
13 AstNode, NameOwner, 14 AstNode,
14 }, 15 },
15 Direction, SyntaxElement, 16 Direction, SyntaxElement,
16 SyntaxKind::{self, BLOCK_EXPR, BREAK_EXPR, COMMENT, PATH_EXPR, RETURN_EXPR}, 17 SyntaxKind::{self, BLOCK_EXPR, BREAK_EXPR, COMMENT, PATH_EXPR, RETURN_EXPR},
17 SyntaxNode, TextRange, 18 SyntaxNode, TextRange, T,
18}; 19};
19use test_utils::mark; 20use test_utils::mark;
20 21
@@ -88,16 +89,16 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
88 let mut self_param = None; 89 let mut self_param = None;
89 let param_pats: Vec<_> = vars_used_in_body 90 let param_pats: Vec<_> = vars_used_in_body
90 .iter() 91 .iter()
91 .map(|node| node.source(ctx.db())) 92 .map(|node| (node, node.source(ctx.db())))
92 .filter(|src| { 93 .filter(|(_, src)| {
93 src.file_id.original_file(ctx.db()) == ctx.frange.file_id 94 src.file_id.original_file(ctx.db()) == ctx.frange.file_id
94 && !body.contains_node(&either_syntax(&src.value)) 95 && !body.contains_node(&either_syntax(&src.value))
95 }) 96 })
96 .filter_map(|src| match src.value { 97 .filter_map(|(&node, src)| match src.value {
97 Either::Left(pat) => Some(pat), 98 Either::Left(_) => Some(node),
98 Either::Right(it) => { 99 Either::Right(it) => {
99 // we filter self param, as there can only be one 100 // we filter self param, as there can only be one
100 self_param = Some(it); 101 self_param = Some((node, it));
101 None 102 None
102 } 103 }
103 }) 104 })
@@ -109,7 +110,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
109 110
110 let vars_defined_in_body = vars_defined_in_body(&body, ctx); 111 let vars_defined_in_body = vars_defined_in_body(&body, ctx);
111 112
112 let vars_in_body_used_afterwards: Vec<_> = vars_defined_in_body 113 let vars_defined_in_body_and_outlive: Vec<_> = vars_defined_in_body
113 .iter() 114 .iter()
114 .copied() 115 .copied()
115 .filter(|node| { 116 .filter(|node| {
@@ -123,20 +124,27 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
123 }) 124 })
124 .collect(); 125 .collect();
125 126
126 let params = param_pats 127 let params: Vec<_> = param_pats
127 .into_iter() 128 .into_iter()
128 .map(|pat| { 129 .map(|node| {
129 let name = pat.name().unwrap().to_string(); 130 let usages = Definition::Local(node)
130 131 .usages(&ctx.sema)
131 let ty = ctx 132 .in_scope(SearchScope::single_file(ctx.frange.file_id))
132 .sema 133 .all();
133 .type_of_pat(&pat.into())
134 .and_then(|ty| ty.display_source_code(ctx.db(), module.into()).ok())
135 .unwrap_or_else(|| "()".to_string());
136 134
137 Param { name, ty } 135 let has_usages_afterwards = usages
136 .iter()
137 .flat_map(|(_, rs)| rs.iter())
138 .any(|reference| body.preceedes_range(reference.range));
139 let has_mut_inside_body = usages
140 .iter()
141 .flat_map(|(_, rs)| rs.iter())
142 .filter(|reference| body.contains_range(reference.range))
143 .any(|reference| reference.access == Some(ReferenceAccess::Write));
144
145 Param { node, has_usages_afterwards, has_mut_inside_body, is_copy: true }
138 }) 146 })
139 .collect::<Vec<_>>(); 147 .collect();
140 148
141 let expr = body.tail_expr(); 149 let expr = body.tail_expr();
142 let ret_ty = match expr { 150 let ret_ty = match expr {
@@ -145,7 +153,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
145 }; 153 };
146 154
147 let has_unit_ret = ret_ty.as_ref().map_or(true, |it| it.is_unit()); 155 let has_unit_ret = ret_ty.as_ref().map_or(true, |it| it.is_unit());
148 if stdx::never!(!vars_in_body_used_afterwards.is_empty() && !has_unit_ret) { 156 if stdx::never!(!vars_defined_in_body_and_outlive.is_empty() && !has_unit_ret) {
149 // We should not have variables that outlive body if we have expression block 157 // We should not have variables that outlive body if we have expression block
150 return None; 158 return None;
151 } 159 }
@@ -162,11 +170,11 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
162 move |builder| { 170 move |builder| {
163 let fun = Function { 171 let fun = Function {
164 name: "fun_name".to_string(), 172 name: "fun_name".to_string(),
165 self_param, 173 self_param: self_param.map(|(_, pat)| pat),
166 params, 174 params,
167 ret_ty, 175 ret_ty,
168 body, 176 body,
169 vars_in_body_used_afterwards, 177 vars_defined_in_body_and_outlive,
170 }; 178 };
171 179
172 builder.replace(target_range, format_replacement(ctx, &fun)); 180 builder.replace(target_range, format_replacement(ctx, &fun));
@@ -183,17 +191,13 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
183fn format_replacement(ctx: &AssistContext, fun: &Function) -> String { 191fn format_replacement(ctx: &AssistContext, fun: &Function) -> String {
184 let mut buf = String::new(); 192 let mut buf = String::new();
185 193
186 match fun.vars_in_body_used_afterwards.len() { 194 match fun.vars_defined_in_body_and_outlive.as_slice() {
187 0 => {} 195 [] => {}
188 1 => format_to!( 196 [var] => format_to!(buf, "let {} = ", var.name(ctx.db()).unwrap()),
189 buf, 197 [v0, vs @ ..] => {
190 "let {} = ",
191 fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap()
192 ),
193 _ => {
194 buf.push_str("let ("); 198 buf.push_str("let (");
195 format_to!(buf, "{}", fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap()); 199 format_to!(buf, "{}", v0.name(ctx.db()).unwrap());
196 for local in fun.vars_in_body_used_afterwards.iter().skip(1) { 200 for local in vs {
197 format_to!(buf, ", {}", local.name(ctx.db()).unwrap()); 201 format_to!(buf, ", {}", local.name(ctx.db()).unwrap());
198 } 202 }
199 buf.push_str(") = "); 203 buf.push_str(") = ");
@@ -207,10 +211,10 @@ fn format_replacement(ctx: &AssistContext, fun: &Function) -> String {
207 { 211 {
208 let mut it = fun.params.iter(); 212 let mut it = fun.params.iter();
209 if let Some(param) = it.next() { 213 if let Some(param) = it.next() {
210 format_to!(buf, "{}", param.name); 214 format_to!(buf, "{}{}", param.value_prefix(), param.node.name(ctx.db()).unwrap());
211 } 215 }
212 for param in it { 216 for param in it {
213 format_to!(buf, ", {}", param.name); 217 format_to!(buf, ", {}{}", param.value_prefix(), param.node.name(ctx.db()).unwrap());
214 } 218 }
215 } 219 }
216 format_to!(buf, ")"); 220 format_to!(buf, ")");
@@ -228,7 +232,7 @@ struct Function {
228 params: Vec<Param>, 232 params: Vec<Param>,
229 ret_ty: Option<hir::Type>, 233 ret_ty: Option<hir::Type>,
230 body: FunctionBody, 234 body: FunctionBody,
231 vars_in_body_used_afterwards: Vec<Local>, 235 vars_defined_in_body_and_outlive: Vec<Local>,
232} 236}
233 237
234impl Function { 238impl Function {
@@ -242,8 +246,60 @@ impl Function {
242 246
243#[derive(Debug)] 247#[derive(Debug)]
244struct Param { 248struct Param {
245 name: String, 249 node: Local,
246 ty: String, 250 has_usages_afterwards: bool,
251 has_mut_inside_body: bool,
252 is_copy: bool,
253}
254
255#[derive(Debug, Clone, Copy, PartialEq, Eq)]
256enum ParamKind {
257 Value,
258 MutValue,
259 SharedRef,
260 MutRef,
261}
262
263impl ParamKind {
264 fn is_ref(&self) -> bool {
265 matches!(self, ParamKind::SharedRef | ParamKind::MutRef)
266 }
267}
268
269impl Param {
270 fn kind(&self) -> ParamKind {
271 match (self.has_usages_afterwards, self.has_mut_inside_body, self.is_copy) {
272 (true, true, _) => ParamKind::MutRef,
273 (true, false, false) => ParamKind::SharedRef,
274 (false, true, _) => ParamKind::MutValue,
275 (true, false, true) | (false, false, _) => ParamKind::Value,
276 }
277 }
278
279 fn value_prefix(&self) -> &'static str {
280 match self.kind() {
281 ParamKind::Value => "",
282 ParamKind::MutValue => "",
283 ParamKind::SharedRef => "&",
284 ParamKind::MutRef => "&mut ",
285 }
286 }
287
288 fn type_prefix(&self) -> &'static str {
289 match self.kind() {
290 ParamKind::Value => "",
291 ParamKind::MutValue => "",
292 ParamKind::SharedRef => "&",
293 ParamKind::MutRef => "&mut ",
294 }
295 }
296
297 fn mut_pattern(&self) -> &'static str {
298 match self.kind() {
299 ParamKind::MutValue => "mut ",
300 _ => "",
301 }
302 }
247} 303}
248 304
249fn format_function( 305fn format_function(
@@ -259,10 +315,24 @@ fn format_function(
259 if let Some(self_param) = &fun.self_param { 315 if let Some(self_param) = &fun.self_param {
260 format_to!(fn_def, "{}", self_param); 316 format_to!(fn_def, "{}", self_param);
261 } else if let Some(param) = it.next() { 317 } else if let Some(param) = it.next() {
262 format_to!(fn_def, "{}: {}", param.name, param.ty); 318 format_to!(
319 fn_def,
320 "{}{}: {}{}",
321 param.mut_pattern(),
322 param.node.name(ctx.db()).unwrap(),
323 param.type_prefix(),
324 format_type(&param.node.ty(ctx.db()), ctx, module)
325 );
263 } 326 }
264 for param in it { 327 for param in it {
265 format_to!(fn_def, ", {}: {}", param.name, param.ty); 328 format_to!(
329 fn_def,
330 ", {}{}: {}{}",
331 param.mut_pattern(),
332 param.node.name(ctx.db()).unwrap(),
333 param.type_prefix(),
334 format_type(&param.node.ty(ctx.db()), ctx, module)
335 );
266 } 336 }
267 } 337 }
268 338
@@ -272,7 +342,7 @@ fn format_function(
272 format_to!(fn_def, " -> {}", format_type(ty, ctx, module)); 342 format_to!(fn_def, " -> {}", format_type(ty, ctx, module));
273 } 343 }
274 } else { 344 } else {
275 match fun.vars_in_body_used_afterwards.as_slice() { 345 match fun.vars_defined_in_body_and_outlive.as_slice() {
276 [] => {} 346 [] => {}
277 [var] => { 347 [var] => {
278 format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module)); 348 format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module));
@@ -292,13 +362,21 @@ fn format_function(
292 FunctionBody::Expr(expr) => { 362 FunctionBody::Expr(expr) => {
293 fn_def.push('\n'); 363 fn_def.push('\n');
294 let expr = expr.indent(indent); 364 let expr = expr.indent(indent);
295 format_to!(fn_def, "{}{}", indent + 1, expr.syntax()); 365 let expr = fix_param_usages(ctx, &fun.params, expr.syntax());
366 format_to!(fn_def, "{}{}", indent + 1, expr);
296 fn_def.push('\n'); 367 fn_def.push('\n');
297 } 368 }
298 FunctionBody::Span { elements, leading_indent } => { 369 FunctionBody::Span { elements, leading_indent } => {
299 format_to!(fn_def, "{}", leading_indent); 370 format_to!(fn_def, "{}", leading_indent);
300 for e in elements { 371 for element in elements {
301 format_to!(fn_def, "{}", e); 372 match element {
373 syntax::NodeOrToken::Node(node) => {
374 format_to!(fn_def, "{}", fix_param_usages(ctx, &fun.params, node));
375 }
376 syntax::NodeOrToken::Token(token) => {
377 format_to!(fn_def, "{}", token);
378 }
379 }
302 } 380 }
303 if !fn_def.ends_with('\n') { 381 if !fn_def.ends_with('\n') {
304 fn_def.push('\n'); 382 fn_def.push('\n');
@@ -306,7 +384,7 @@ fn format_function(
306 } 384 }
307 } 385 }
308 386
309 match fun.vars_in_body_used_afterwards.as_slice() { 387 match fun.vars_defined_in_body_and_outlive.as_slice() {
310 [] => {} 388 [] => {}
311 [var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()), 389 [var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()),
312 [v0, vs @ ..] => { 390 [v0, vs @ ..] => {
@@ -327,6 +405,61 @@ fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> Stri
327 ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string()) 405 ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string())
328} 406}
329 407
408fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode {
409 let mut rewriter = SyntaxRewriter::default();
410 for param in params {
411 if !param.kind().is_ref() {
412 continue;
413 }
414
415 let usages = Definition::Local(param.node)
416 .usages(&ctx.sema)
417 .in_scope(SearchScope::single_file(ctx.frange.file_id))
418 .all();
419 let usages = usages
420 .iter()
421 .flat_map(|(_, rs)| rs.iter())
422 .filter(|reference| syntax.text_range().contains_range(reference.range));
423 for reference in usages {
424 let token = match syntax.token_at_offset(reference.range.start()).right_biased() {
425 Some(a) => a,
426 None => {
427 stdx::never!(false, "cannot find token at variable usage: {:?}", reference);
428 continue;
429 }
430 };
431 let path = match token.ancestors().find_map(ast::Expr::cast) {
432 Some(n) => n,
433 None => {
434 stdx::never!(false, "cannot find path parent of variable usage: {:?}", token);
435 continue;
436 }
437 };
438 stdx::always!(matches!(path, ast::Expr::PathExpr(_)));
439 match path.syntax().ancestors().skip(1).find_map(ast::Expr::cast) {
440 Some(ast::Expr::MethodCallExpr(_)) => {
441 // do nothing
442 }
443 Some(ast::Expr::RefExpr(node))
444 if param.kind() == ParamKind::MutRef && node.mut_token().is_some() =>
445 {
446 rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap());
447 }
448 Some(ast::Expr::RefExpr(node))
449 if param.kind() == ParamKind::SharedRef && node.mut_token().is_none() =>
450 {
451 rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap());
452 }
453 Some(_) | None => {
454 rewriter.replace_ast(&path, &ast::make::expr_prefix(T![*], path.clone()));
455 }
456 };
457 }
458 }
459
460 rewriter.rewrite(syntax)
461}
462
330#[derive(Debug)] 463#[derive(Debug)]
331enum FunctionBody { 464enum FunctionBody {
332 Expr(ast::Expr), 465 Expr(ast::Expr),
@@ -1115,4 +1248,162 @@ fn $0fun_name(n: i32) -> (i32, i32) {
1115}", 1248}",
1116 ); 1249 );
1117 } 1250 }
1251
1252 #[test]
1253 fn mut_var_from_outer_scope() {
1254 check_assist(
1255 extract_function,
1256 r"
1257fn foo() {
1258 let mut n = 1;
1259 $0n += 1;$0
1260 let m = n + 1;
1261}",
1262 r"
1263fn foo() {
1264 let mut n = 1;
1265 fun_name(&mut n);
1266 let m = n + 1;
1267}
1268
1269fn $0fun_name(n: &mut i32) {
1270 *n += 1;
1271}",
1272 );
1273 }
1274
1275 #[test]
1276 fn mut_param_many_usages_stmt() {
1277 check_assist(
1278 extract_function,
1279 r"
1280fn bar(k: i32) {}
1281trait I: Copy {
1282 fn succ(&self) -> Self;
1283 fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v }
1284}
1285impl I for i32 {
1286 fn succ(&self) -> Self { *self + 1 }
1287}
1288fn foo() {
1289 let mut n = 1;
1290 $0n += n;
1291 bar(n);
1292 bar(n+1);
1293 bar(n*n);
1294 bar(&n);
1295 n.inc();
1296 let v = &mut n;
1297 *v = v.succ();
1298 n.succ();$0
1299 let m = n + 1;
1300}",
1301 r"
1302fn bar(k: i32) {}
1303trait I: Copy {
1304 fn succ(&self) -> Self;
1305 fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v }
1306}
1307impl I for i32 {
1308 fn succ(&self) -> Self { *self + 1 }
1309}
1310fn foo() {
1311 let mut n = 1;
1312 fun_name(&mut n);
1313 let m = n + 1;
1314}
1315
1316fn $0fun_name(n: &mut i32) {
1317 *n += *n;
1318 bar(*n);
1319 bar(*n+1);
1320 bar(*n**n);
1321 bar(&*n);
1322 n.inc();
1323 let v = n;
1324 *v = v.succ();
1325 n.succ();
1326}",
1327 );
1328 }
1329
1330 #[test]
1331 fn mut_param_many_usages_expr() {
1332 check_assist(
1333 extract_function,
1334 r"
1335fn bar(k: i32) {}
1336trait I: Copy {
1337 fn succ(&self) -> Self;
1338 fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v }
1339}
1340impl I for i32 {
1341 fn succ(&self) -> Self { *self + 1 }
1342}
1343fn foo() {
1344 let mut n = 1;
1345 $0{
1346 n += n;
1347 bar(n);
1348 bar(n+1);
1349 bar(n*n);
1350 bar(&n);
1351 n.inc();
1352 let v = &mut n;
1353 *v = v.succ();
1354 n.succ();
1355 }$0
1356 let m = n + 1;
1357}",
1358 r"
1359fn bar(k: i32) {}
1360trait I: Copy {
1361 fn succ(&self) -> Self;
1362 fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v }
1363}
1364impl I for i32 {
1365 fn succ(&self) -> Self { *self + 1 }
1366}
1367fn foo() {
1368 let mut n = 1;
1369 fun_name(&mut n);
1370 let m = n + 1;
1371}
1372
1373fn $0fun_name(n: &mut i32) {
1374 {
1375 *n += *n;
1376 bar(*n);
1377 bar(*n+1);
1378 bar(*n**n);
1379 bar(&*n);
1380 n.inc();
1381 let v = n;
1382 *v = v.succ();
1383 n.succ();
1384 }
1385}",
1386 );
1387 }
1388
1389 #[test]
1390 fn mut_param_by_value() {
1391 check_assist(
1392 extract_function,
1393 r"
1394fn foo() {
1395 let mut n = 1;
1396 $0n += 1;$0
1397}",
1398 r"
1399fn foo() {
1400 let mut n = 1;
1401 fun_name(n);
1402}
1403
1404fn $0fun_name(mut n: i32) {
1405 n += 1;
1406}",
1407 );
1408 }
1118} 1409}