aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--crates/assists/src/handlers/extract_function.rs890
1 files changed, 510 insertions, 380 deletions
diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs
index dce7ffd7b..93ff66b24 100644
--- a/crates/assists/src/handlers/extract_function.rs
+++ b/crates/assists/src/handlers/extract_function.rs
@@ -60,115 +60,21 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
60 return None; 60 return None;
61 } 61 }
62 62
63 let node = match node { 63 let node = element_to_node(node);
64 syntax::NodeOrToken::Node(n) => n,
65 syntax::NodeOrToken::Token(t) => t.parent(),
66 };
67 64
68 let mut body = None; 65 let body = extraction_target(&node, ctx.frange.range)?;
69 if node.text_range() == ctx.frange.range {
70 body = FunctionBody::from_whole_node(node.clone());
71 }
72 if body.is_none() && node.kind() == BLOCK_EXPR {
73 body = FunctionBody::from_range(&node, ctx.frange.range);
74 }
75 if let Some(parent) = node.parent() {
76 if body.is_none() && parent.kind() == BLOCK_EXPR {
77 body = FunctionBody::from_range(&parent, ctx.frange.range);
78 }
79 }
80 if body.is_none() {
81 body = FunctionBody::from_whole_node(node.clone());
82 }
83 if body.is_none() {
84 body = node.ancestors().find_map(FunctionBody::from_whole_node);
85 }
86 let body = body?;
87 66
88 let vars_used_in_body = vars_used_in_body(&body, &ctx); 67 let vars_used_in_body = vars_used_in_body(&body, &ctx);
89 let mut self_param = None; 68 let self_param = self_param_from_usages(ctx, &body, &vars_used_in_body);
90 let param_pats: Vec<_> = vars_used_in_body
91 .iter()
92 .map(|node| (node, node.source(ctx.db())))
93 .filter(|(_, src)| {
94 src.file_id.original_file(ctx.db()) == ctx.frange.file_id
95 && !body.contains_node(&either_syntax(&src.value))
96 })
97 .filter_map(|(&node, src)| match src.value {
98 Either::Left(_) => Some(node),
99 Either::Right(it) => {
100 // we filter self param, as there can only be one
101 self_param = Some((node, it));
102 None
103 }
104 })
105 .collect();
106 69
107 let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding }; 70 let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding };
108 let insert_after = body.scope_for_fn_insertion(anchor)?; 71 let insert_after = scope_for_fn_insertion(&body, anchor)?;
109 let module = ctx.sema.scope(&insert_after).module()?; 72 let module = ctx.sema.scope(&insert_after).module()?;
110 73
111 let vars_defined_in_body = vars_defined_in_body(&body, ctx); 74 let vars_defined_in_body_and_outlive = vars_defined_in_body_and_outlive(ctx, &body);
112 75 let ret_ty = body_return_ty(ctx, &body)?;
113 let vars_defined_in_body_and_outlive: Vec<_> = vars_defined_in_body
114 .iter()
115 .copied()
116 .filter(|node| {
117 let usages = Definition::Local(*node)
118 .usages(&ctx.sema)
119 .in_scope(SearchScope::single_file(ctx.frange.file_id))
120 .all();
121 let mut usages = usages.iter().flat_map(|(_, rs)| rs.iter());
122
123 usages.any(|reference| body.preceedes_range(reference.range))
124 })
125 .collect();
126
127 let params: Vec<_> = param_pats
128 .into_iter()
129 .map(|node| {
130 let usages = Definition::Local(node)
131 .usages(&ctx.sema)
132 .in_scope(SearchScope::single_file(ctx.frange.file_id))
133 .all();
134
135 let has_usages_afterwards = usages
136 .iter()
137 .flat_map(|(_, rs)| rs.iter())
138 .any(|reference| body.preceedes_range(reference.range));
139 let has_mut_inside_body = usages
140 .iter()
141 .flat_map(|(_, rs)| rs.iter())
142 .filter(|reference| body.contains_range(reference.range))
143 .any(|reference| {
144 if reference.access == Some(ReferenceAccess::Write) {
145 return true;
146 }
147
148 let path = path_at_offset(&body, reference);
149 if is_mut_ref_expr(path.as_ref()).unwrap_or(false) {
150 return true;
151 }
152
153 if is_mut_method_call(ctx, path.as_ref()).unwrap_or(false) {
154 return true;
155 }
156
157 false
158 });
159
160 Param { node, has_usages_afterwards, has_mut_inside_body, is_copy: true }
161 })
162 .collect();
163
164 let expr = body.tail_expr();
165 let ret_ty = match expr {
166 Some(expr) => Some(ctx.sema.type_of_expr(&expr)?),
167 None => None,
168 };
169 76
170 let has_unit_ret = ret_ty.as_ref().map_or(true, |it| it.is_unit()); 77 if stdx::never!(!vars_defined_in_body_and_outlive.is_empty() && !ret_ty.is_unit()) {
171 if stdx::never!(!vars_defined_in_body_and_outlive.is_empty() && !has_unit_ret) {
172 // We should not have variables that outlive body if we have expression block 78 // We should not have variables that outlive body if we have expression block
173 return None; 79 return None;
174 } 80 }
@@ -183,6 +89,8 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
183 "Extract into function", 89 "Extract into function",
184 target_range, 90 target_range,
185 move |builder| { 91 move |builder| {
92 let params = extracted_function_params(ctx, &body, &vars_used_in_body);
93
186 let fun = Function { 94 let fun = Function {
187 name: "fun_name".to_string(), 95 name: "fun_name".to_string(),
188 self_param: self_param.map(|(_, pat)| pat), 96 self_param: self_param.map(|(_, pat)| pat),
@@ -203,65 +111,19 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
203 ) 111 )
204} 112}
205 113
206fn format_replacement(ctx: &AssistContext, fun: &Function) -> String { 114#[derive(Debug)]
207 let mut buf = String::new();
208
209 match fun.vars_defined_in_body_and_outlive.as_slice() {
210 [] => {}
211 [var] => format_to!(buf, "let {} = ", var.name(ctx.db()).unwrap()),
212 [v0, vs @ ..] => {
213 buf.push_str("let (");
214 format_to!(buf, "{}", v0.name(ctx.db()).unwrap());
215 for local in vs {
216 format_to!(buf, ", {}", local.name(ctx.db()).unwrap());
217 }
218 buf.push_str(") = ");
219 }
220 }
221
222 if fun.self_param.is_some() {
223 format_to!(buf, "self.");
224 }
225 format_to!(buf, "{}(", fun.name);
226 {
227 let mut it = fun.params.iter();
228 if let Some(param) = it.next() {
229 format_to!(buf, "{}{}", param.value_prefix(), param.node.name(ctx.db()).unwrap());
230 }
231 for param in it {
232 format_to!(buf, ", {}{}", param.value_prefix(), param.node.name(ctx.db()).unwrap());
233 }
234 }
235 format_to!(buf, ")");
236
237 if fun.has_unit_ret() {
238 format_to!(buf, ";");
239 }
240
241 buf
242}
243
244struct Function { 115struct Function {
245 name: String, 116 name: String,
246 self_param: Option<ast::SelfParam>, 117 self_param: Option<ast::SelfParam>,
247 params: Vec<Param>, 118 params: Vec<Param>,
248 ret_ty: Option<hir::Type>, 119 ret_ty: RetType,
249 body: FunctionBody, 120 body: FunctionBody,
250 vars_defined_in_body_and_outlive: Vec<Local>, 121 vars_defined_in_body_and_outlive: Vec<Local>,
251} 122}
252 123
253impl Function {
254 fn has_unit_ret(&self) -> bool {
255 match &self.ret_ty {
256 Some(ty) => ty.is_unit(),
257 None => true,
258 }
259 }
260}
261
262#[derive(Debug)] 124#[derive(Debug)]
263struct Param { 125struct Param {
264 node: Local, 126 var: Local,
265 has_usages_afterwards: bool, 127 has_usages_afterwards: bool,
266 has_mut_inside_body: bool, 128 has_mut_inside_body: bool,
267 is_copy: bool, 129 is_copy: bool,
@@ -293,8 +155,7 @@ impl Param {
293 155
294 fn value_prefix(&self) -> &'static str { 156 fn value_prefix(&self) -> &'static str {
295 match self.kind() { 157 match self.kind() {
296 ParamKind::Value => "", 158 ParamKind::Value | ParamKind::MutValue => "",
297 ParamKind::MutValue => "",
298 ParamKind::SharedRef => "&", 159 ParamKind::SharedRef => "&",
299 ParamKind::MutRef => "&mut ", 160 ParamKind::MutRef => "&mut ",
300 } 161 }
@@ -302,8 +163,7 @@ impl Param {
302 163
303 fn type_prefix(&self) -> &'static str { 164 fn type_prefix(&self) -> &'static str {
304 match self.kind() { 165 match self.kind() {
305 ParamKind::Value => "", 166 ParamKind::Value | ParamKind::MutValue => "",
306 ParamKind::MutValue => "",
307 ParamKind::SharedRef => "&", 167 ParamKind::SharedRef => "&",
308 ParamKind::MutRef => "&mut ", 168 ParamKind::MutRef => "&mut ",
309 } 169 }
@@ -317,186 +177,27 @@ impl Param {
317 } 177 }
318} 178}
319 179
320fn format_function( 180#[derive(Debug)]
321 ctx: &AssistContext, 181enum RetType {
322 module: hir::Module, 182 Expr(hir::Type),
323 fun: &Function, 183 Stmt,
324 indent: IndentLevel,
325) -> String {
326 let mut fn_def = String::new();
327 format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name);
328 {
329 let mut it = fun.params.iter();
330 if let Some(self_param) = &fun.self_param {
331 format_to!(fn_def, "{}", self_param);
332 } else if let Some(param) = it.next() {
333 format_to!(
334 fn_def,
335 "{}{}: {}{}",
336 param.mut_pattern(),
337 param.node.name(ctx.db()).unwrap(),
338 param.type_prefix(),
339 format_type(&param.node.ty(ctx.db()), ctx, module)
340 );
341 }
342 for param in it {
343 format_to!(
344 fn_def,
345 ", {}{}: {}{}",
346 param.mut_pattern(),
347 param.node.name(ctx.db()).unwrap(),
348 param.type_prefix(),
349 format_type(&param.node.ty(ctx.db()), ctx, module)
350 );
351 }
352 }
353
354 format_to!(fn_def, ")");
355 if !fun.has_unit_ret() {
356 if let Some(ty) = &fun.ret_ty {
357 format_to!(fn_def, " -> {}", format_type(ty, ctx, module));
358 }
359 } else {
360 match fun.vars_defined_in_body_and_outlive.as_slice() {
361 [] => {}
362 [var] => {
363 format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module));
364 }
365 [v0, vs @ ..] => {
366 format_to!(fn_def, " -> ({}", format_type(&v0.ty(ctx.db()), ctx, module));
367 for var in vs {
368 format_to!(fn_def, ", {}", format_type(&var.ty(ctx.db()), ctx, module));
369 }
370 fn_def.push(')');
371 }
372 }
373 }
374 fn_def.push_str(" {");
375
376 match &fun.body {
377 FunctionBody::Expr(expr) => {
378 fn_def.push('\n');
379 let expr = expr.indent(indent);
380 let expr = fix_param_usages(ctx, &fun.params, expr.syntax());
381 format_to!(fn_def, "{}{}", indent + 1, expr);
382 fn_def.push('\n');
383 }
384 FunctionBody::Span { elements, leading_indent } => {
385 format_to!(fn_def, "{}", leading_indent);
386 for element in elements {
387 match element {
388 syntax::NodeOrToken::Node(node) => {
389 format_to!(fn_def, "{}", fix_param_usages(ctx, &fun.params, node));
390 }
391 syntax::NodeOrToken::Token(token) => {
392 format_to!(fn_def, "{}", token);
393 }
394 }
395 }
396 if !fn_def.ends_with('\n') {
397 fn_def.push('\n');
398 }
399 }
400 }
401
402 match fun.vars_defined_in_body_and_outlive.as_slice() {
403 [] => {}
404 [var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()),
405 [v0, vs @ ..] => {
406 format_to!(fn_def, "{}({}", indent + 1, v0.name(ctx.db()).unwrap());
407 for var in vs {
408 format_to!(fn_def, ", {}", var.name(ctx.db()).unwrap());
409 }
410 fn_def.push_str(")\n");
411 }
412 }
413
414 format_to!(fn_def, "{}}}", indent);
415
416 fn_def
417}
418
419fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> String {
420 ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string())
421}
422
423fn path_at_offset(body: &FunctionBody, reference: &FileReference) -> Option<ast::Expr> {
424 let var = body.token_at_offset(reference.range.start()).right_biased()?;
425 let path = var.ancestors().find_map(ast::Expr::cast)?;
426 stdx::always!(matches!(path, ast::Expr::PathExpr(_)));
427 Some(path)
428}
429
430fn is_mut_ref_expr(path: Option<&ast::Expr>) -> Option<bool> {
431 let path = path?;
432 let ref_expr = path.syntax().parent().and_then(ast::RefExpr::cast)?;
433 Some(ref_expr.mut_token().is_some())
434}
435
436fn is_mut_method_call(ctx: &AssistContext, path: Option<&ast::Expr>) -> Option<bool> {
437 let path = path?;
438 let method_call = path.syntax().parent().and_then(ast::MethodCallExpr::cast)?;
439
440 let func = ctx.sema.resolve_method_call(&method_call)?;
441 let self_param = func.self_param(ctx.db())?;
442 let access = self_param.access(ctx.db());
443
444 Some(matches!(access, hir::Access::Exclusive))
445} 184}
446 185
447fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode { 186impl RetType {
448 let mut rewriter = SyntaxRewriter::default(); 187 fn is_unit(&self) -> bool {
449 for param in params { 188 match self {
450 if !param.kind().is_ref() { 189 RetType::Expr(ty) => ty.is_unit(),
451 continue; 190 RetType::Stmt => true,
452 } 191 }
192 }
453 193
454 let usages = Definition::Local(param.node) 194 fn as_fn_ret(&self) -> Option<&hir::Type> {
455 .usages(&ctx.sema) 195 match self {
456 .in_scope(SearchScope::single_file(ctx.frange.file_id)) 196 RetType::Stmt => None,
457 .all(); 197 RetType::Expr(ty) if ty.is_unit() => None,
458 let usages = usages 198 RetType::Expr(ty) => Some(ty),
459 .iter()
460 .flat_map(|(_, rs)| rs.iter())
461 .filter(|reference| syntax.text_range().contains_range(reference.range));
462 for reference in usages {
463 let token = match syntax.token_at_offset(reference.range.start()).right_biased() {
464 Some(a) => a,
465 None => {
466 stdx::never!(false, "cannot find token at variable usage: {:?}", reference);
467 continue;
468 }
469 };
470 let path = match token.ancestors().find_map(ast::Expr::cast) {
471 Some(n) => n,
472 None => {
473 stdx::never!(false, "cannot find path parent of variable usage: {:?}", token);
474 continue;
475 }
476 };
477 stdx::always!(matches!(path, ast::Expr::PathExpr(_)));
478 match path.syntax().ancestors().skip(1).find_map(ast::Expr::cast) {
479 Some(ast::Expr::MethodCallExpr(_)) => {
480 // do nothing
481 }
482 Some(ast::Expr::RefExpr(node))
483 if param.kind() == ParamKind::MutRef && node.mut_token().is_some() =>
484 {
485 rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap());
486 }
487 Some(ast::Expr::RefExpr(node))
488 if param.kind() == ParamKind::SharedRef && node.mut_token().is_none() =>
489 {
490 rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap());
491 }
492 Some(_) | None => {
493 rewriter.replace_ast(&path, &ast::make::expr_prefix(T![*], path.clone()));
494 }
495 };
496 } 199 }
497 } 200 }
498
499 rewriter.rewrite(syntax)
500} 201}
501 202
502#[derive(Debug)] 203#[derive(Debug)]
@@ -505,11 +206,6 @@ enum FunctionBody {
505 Span { elements: Vec<SyntaxElement>, leading_indent: String }, 206 Span { elements: Vec<SyntaxElement>, leading_indent: String },
506} 207}
507 208
508enum Anchor {
509 Freestanding,
510 Method,
511}
512
513impl FunctionBody { 209impl FunctionBody {
514 fn from_whole_node(node: SyntaxNode) -> Option<Self> { 210 fn from_whole_node(node: SyntaxNode) -> Option<Self> {
515 match node.kind() { 211 match node.kind() {
@@ -568,16 +264,6 @@ impl FunctionBody {
568 } 264 }
569 } 265 }
570 266
571 fn scope_for_fn_insertion(&self, anchor: Anchor) -> Option<SyntaxNode> {
572 match self {
573 FunctionBody::Expr(e) => scope_for_fn_insertion(e.syntax(), anchor),
574 FunctionBody::Span { elements, .. } => {
575 let node = elements.iter().find_map(|e| e.as_node())?;
576 scope_for_fn_insertion(&node, anchor)
577 }
578 }
579 }
580
581 fn descendants(&self) -> impl Iterator<Item = SyntaxNode> + '_ { 267 fn descendants(&self) -> impl Iterator<Item = SyntaxNode> + '_ {
582 match self { 268 match self {
583 FunctionBody::Expr(expr) => Either::Right(expr.syntax().descendants()), 269 FunctionBody::Expr(expr) => Either::Right(expr.syntax().descendants()),
@@ -590,6 +276,30 @@ impl FunctionBody {
590 } 276 }
591 } 277 }
592 278
279 fn text_range(&self) -> TextRange {
280 match self {
281 FunctionBody::Expr(expr) => expr.syntax().text_range(),
282 FunctionBody::Span { elements, .. } => TextRange::new(
283 elements.first().unwrap().text_range().start(),
284 elements.last().unwrap().text_range().end(),
285 ),
286 }
287 }
288
289 fn contains_range(&self, range: TextRange) -> bool {
290 self.text_range().contains_range(range)
291 }
292
293 fn preceedes_range(&self, range: TextRange) -> bool {
294 self.text_range().end() <= range.start()
295 }
296
297 fn contains_node(&self, node: &SyntaxNode) -> bool {
298 self.contains_range(node.text_range())
299 }
300}
301
302impl HasTokenAtOffset for FunctionBody {
593 fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset<SyntaxToken> { 303 fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset<SyntaxToken> {
594 match self { 304 match self {
595 FunctionBody::Expr(expr) => expr.syntax().token_at_offset(offset), 305 FunctionBody::Expr(expr) => expr.syntax().token_at_offset(offset),
@@ -621,31 +331,278 @@ impl FunctionBody {
621 } 331 }
622 } 332 }
623 } 333 }
334}
624 335
625 fn text_range(&self) -> TextRange { 336fn element_to_node(node: SyntaxElement) -> SyntaxNode {
626 match self { 337 match node {
627 FunctionBody::Expr(expr) => expr.syntax().text_range(), 338 syntax::NodeOrToken::Node(n) => n,
628 FunctionBody::Span { elements, .. } => TextRange::new( 339 syntax::NodeOrToken::Token(t) => t.parent(),
629 elements.first().unwrap().text_range().start(), 340 }
630 elements.last().unwrap().text_range().end(), 341}
631 ), 342
343fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option<FunctionBody> {
344 if node.text_range() == selection_range {
345 let body = FunctionBody::from_whole_node(node.clone());
346 if body.is_some() {
347 return body;
632 } 348 }
633 } 349 }
634 350
635 fn contains_range(&self, range: TextRange) -> bool { 351 if node.kind() == BLOCK_EXPR {
636 self.text_range().contains_range(range) 352 let body = FunctionBody::from_range(&node, selection_range);
353 if body.is_some() {
354 return body;
355 }
356 }
357 if let Some(parent) = node.parent() {
358 if parent.kind() == BLOCK_EXPR {
359 let body = FunctionBody::from_range(&parent, selection_range);
360 if body.is_some() {
361 return body;
362 }
363 }
637 } 364 }
638 365
639 fn preceedes_range(&self, range: TextRange) -> bool { 366 let body = FunctionBody::from_whole_node(node.clone());
640 self.text_range().end() <= range.start() 367 if body.is_some() {
368 return body;
641 } 369 }
642 370
643 fn contains_node(&self, node: &SyntaxNode) -> bool { 371 let body = node.ancestors().find_map(FunctionBody::from_whole_node);
644 self.contains_range(node.text_range()) 372 if body.is_some() {
373 return body;
374 }
375
376 None
377}
378
379/// Returns a vector of local variables that are referenced in `body`
380fn vars_used_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> {
381 body.descendants()
382 .filter_map(ast::NameRef::cast)
383 .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref))
384 .map(|name_kind| name_kind.referenced(ctx.db()))
385 .filter_map(|definition| match definition {
386 Definition::Local(local) => Some(local),
387 _ => None,
388 })
389 .unique()
390 .collect()
391}
392
393fn self_param_from_usages(
394 ctx: &AssistContext,
395 body: &FunctionBody,
396 vars_used_in_body: &[Local],
397) -> Option<(Local, ast::SelfParam)> {
398 let mut iter = vars_used_in_body
399 .iter()
400 .filter(|var| var.is_self(ctx.db()))
401 .map(|var| (var, var.source(ctx.db())))
402 .filter(|(_, src)| is_defined_before(ctx, body, src))
403 .filter_map(|(&node, src)| match src.value {
404 Either::Right(it) => Some((node, it)),
405 Either::Left(_) => {
406 stdx::never!(false, "Local::is_self returned true, but source is IdentPat");
407 None
408 }
409 });
410
411 let self_param = iter.next();
412 stdx::always!(
413 iter.next().is_none(),
414 "body references two different self params both defined outside"
415 );
416
417 self_param
418}
419
420fn extracted_function_params(
421 ctx: &AssistContext,
422 body: &FunctionBody,
423 vars_used_in_body: &[Local],
424) -> Vec<Param> {
425 vars_used_in_body
426 .iter()
427 .filter(|var| !var.is_self(ctx.db()))
428 .map(|node| (node, node.source(ctx.db())))
429 .filter(|(_, src)| is_defined_before(ctx, body, src))
430 .filter_map(|(&node, src)| {
431 if src.value.is_left() {
432 Some(node)
433 } else {
434 stdx::never!(false, "Local::is_self returned false, but source is SelfParam");
435 None
436 }
437 })
438 .map(|var| {
439 let usages = LocalUsages::find(ctx, var);
440 Param {
441 var,
442 has_usages_afterwards: has_usages_after_body(&usages, body),
443 has_mut_inside_body: has_exclusive_usages(ctx, &usages, body),
444 is_copy: true,
445 }
446 })
447 .collect()
448}
449
450fn has_usages_after_body(usages: &LocalUsages, body: &FunctionBody) -> bool {
451 usages.iter().any(|reference| body.preceedes_range(reference.range))
452}
453
454fn has_exclusive_usages(ctx: &AssistContext, usages: &LocalUsages, body: &FunctionBody) -> bool {
455 usages
456 .iter()
457 .filter(|reference| body.contains_range(reference.range))
458 .any(|reference| reference_is_exclusive(reference, body, ctx))
459}
460
461fn reference_is_exclusive(
462 reference: &FileReference,
463 body: &FunctionBody,
464 ctx: &AssistContext,
465) -> bool {
466 if reference.access == Some(ReferenceAccess::Write) {
467 return true;
468 }
469
470 let path = path_at_offset(body, reference);
471 if is_mut_ref_expr(path.as_ref()).unwrap_or(false) {
472 return true;
473 }
474
475 if is_mut_method_call(ctx, path.as_ref()).unwrap_or(false) {
476 return true;
477 }
478
479 false
480}
481
482struct LocalUsages(ide_db::search::UsageSearchResult);
483
484impl LocalUsages {
485 fn find(ctx: &AssistContext, var: Local) -> Self {
486 Self(
487 Definition::Local(var)
488 .usages(&ctx.sema)
489 .in_scope(SearchScope::single_file(ctx.frange.file_id))
490 .all(),
491 )
492 }
493
494 fn iter(&self) -> impl Iterator<Item = &FileReference> + '_ {
495 self.0.iter().flat_map(|(_, rs)| rs.iter())
496 }
497}
498
499trait HasTokenAtOffset {
500 fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset<SyntaxToken>;
501}
502
503impl HasTokenAtOffset for SyntaxNode {
504 fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset<SyntaxToken> {
505 SyntaxNode::token_at_offset(&self, offset)
506 }
507}
508
509fn path_at_offset(node: &dyn HasTokenAtOffset, reference: &FileReference) -> Option<ast::Expr> {
510 let token = node.token_at_offset(reference.range.start()).right_biased().or_else(|| {
511 stdx::never!(false, "cannot find token at variable usage: {:?}", reference);
512 None
513 })?;
514 let path = token.ancestors().find_map(ast::Expr::cast).or_else(|| {
515 stdx::never!(false, "cannot find path parent of variable usage: {:?}", token);
516 None
517 })?;
518 stdx::always!(matches!(path, ast::Expr::PathExpr(_)));
519 Some(path)
520}
521
522fn is_mut_ref_expr(path: Option<&ast::Expr>) -> Option<bool> {
523 let path = path?;
524 let ref_expr = path.syntax().parent().and_then(ast::RefExpr::cast)?;
525 Some(ref_expr.mut_token().is_some())
526}
527
528fn is_mut_method_call(ctx: &AssistContext, path: Option<&ast::Expr>) -> Option<bool> {
529 let path = path?;
530 let method_call = path.syntax().parent().and_then(ast::MethodCallExpr::cast)?;
531
532 let func = ctx.sema.resolve_method_call(&method_call)?;
533 let self_param = func.self_param(ctx.db())?;
534 let access = self_param.access(ctx.db());
535
536 Some(matches!(access, hir::Access::Exclusive))
537}
538
539/// Returns a vector of local variables that are defined in `body`
540fn vars_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> {
541 body.descendants()
542 .filter_map(ast::IdentPat::cast)
543 .filter_map(|let_stmt| ctx.sema.to_def(&let_stmt))
544 .unique()
545 .collect()
546}
547
548fn vars_defined_in_body_and_outlive(ctx: &AssistContext, body: &FunctionBody) -> Vec<Local> {
549 let mut vars_defined_in_body = vars_defined_in_body(&body, ctx);
550 vars_defined_in_body.retain(|var| var_outlives_body(ctx, body, var));
551 vars_defined_in_body
552}
553
554fn is_defined_before(
555 ctx: &AssistContext,
556 body: &FunctionBody,
557 src: &hir::InFile<Either<ast::IdentPat, ast::SelfParam>>,
558) -> bool {
559 src.file_id.original_file(ctx.db()) == ctx.frange.file_id
560 && !body.contains_node(&either_syntax(&src.value))
561}
562
563fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode {
564 match value {
565 Either::Left(pat) => pat.syntax(),
566 Either::Right(it) => it.syntax(),
567 }
568}
569
570fn var_outlives_body(ctx: &AssistContext, body: &FunctionBody, var: &Local) -> bool {
571 let usages = Definition::Local(*var)
572 .usages(&ctx.sema)
573 .in_scope(SearchScope::single_file(ctx.frange.file_id))
574 .all();
575 let mut usages = usages.iter().flat_map(|(_, rs)| rs.iter());
576
577 usages.any(|reference| body.preceedes_range(reference.range))
578}
579
580fn body_return_ty(ctx: &AssistContext, body: &FunctionBody) -> Option<RetType> {
581 match body.tail_expr() {
582 Some(expr) => {
583 let ty = ctx.sema.type_of_expr(&expr)?;
584 Some(RetType::Expr(ty))
585 }
586 None => Some(RetType::Stmt),
587 }
588}
589#[derive(Debug)]
590enum Anchor {
591 Freestanding,
592 Method,
593}
594
595fn scope_for_fn_insertion(body: &FunctionBody, anchor: Anchor) -> Option<SyntaxNode> {
596 match body {
597 FunctionBody::Expr(e) => scope_for_fn_insertion_node(e.syntax(), anchor),
598 FunctionBody::Span { elements, .. } => {
599 let node = elements.iter().find_map(|e| e.as_node())?;
600 scope_for_fn_insertion_node(&node, anchor)
601 }
645 } 602 }
646} 603}
647 604
648fn scope_for_fn_insertion(node: &SyntaxNode, anchor: Anchor) -> Option<SyntaxNode> { 605fn scope_for_fn_insertion_node(node: &SyntaxNode, anchor: Anchor) -> Option<SyntaxNode> {
649 let mut ancestors = node.ancestors().peekable(); 606 let mut ancestors = node.ancestors().peekable();
650 let mut last_ancestor = None; 607 let mut last_ancestor = None;
651 while let Some(next_ancestor) = ancestors.next() { 608 while let Some(next_ancestor) = ancestors.next() {
@@ -674,34 +631,207 @@ fn scope_for_fn_insertion(node: &SyntaxNode, anchor: Anchor) -> Option<SyntaxNod
674 last_ancestor 631 last_ancestor
675} 632}
676 633
677fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode { 634fn format_replacement(ctx: &AssistContext, fun: &Function) -> String {
678 match value { 635 let mut buf = String::new();
679 Either::Left(pat) => pat.syntax(), 636
680 Either::Right(it) => it.syntax(), 637 match fun.vars_defined_in_body_and_outlive.as_slice() {
638 [] => {}
639 [var] => format_to!(buf, "let {} = ", var.name(ctx.db()).unwrap()),
640 [v0, vs @ ..] => {
641 buf.push_str("let (");
642 format_to!(buf, "{}", v0.name(ctx.db()).unwrap());
643 for var in vs {
644 format_to!(buf, ", {}", var.name(ctx.db()).unwrap());
645 }
646 buf.push_str(") = ");
647 }
681 } 648 }
649
650 if fun.self_param.is_some() {
651 format_to!(buf, "self.");
652 }
653 format_to!(buf, "{}(", fun.name);
654 format_arg_list_to(&mut buf, fun, ctx);
655 format_to!(buf, ")");
656
657 if fun.ret_ty.is_unit() {
658 format_to!(buf, ";");
659 }
660
661 buf
682} 662}
683 663
684/// Returns a vector of local variables that are referenced in `body` 664fn format_arg_list_to(buf: &mut String, fun: &Function, ctx: &AssistContext) {
685fn vars_used_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> { 665 let mut it = fun.params.iter();
686 body.descendants() 666 if let Some(param) = it.next() {
687 .filter_map(ast::NameRef::cast) 667 format_arg_to(buf, ctx, param);
688 .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref)) 668 }
689 .map(|name_kind| name_kind.referenced(ctx.db())) 669 for param in it {
690 .filter_map(|definition| match definition { 670 buf.push_str(", ");
691 Definition::Local(local) => Some(local), 671 format_arg_to(buf, ctx, param);
692 _ => None, 672 }
693 })
694 .unique()
695 .collect()
696} 673}
697 674
698/// Returns a vector of local variables that are defined in `body` 675fn format_arg_to(buf: &mut String, ctx: &AssistContext, param: &Param) {
699fn vars_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> { 676 format_to!(buf, "{}{}", param.value_prefix(), param.var.name(ctx.db()).unwrap());
700 body.descendants() 677}
701 .filter_map(ast::IdentPat::cast) 678
702 .filter_map(|let_stmt| ctx.sema.to_def(&let_stmt)) 679fn format_function(
703 .unique() 680 ctx: &AssistContext,
704 .collect() 681 module: hir::Module,
682 fun: &Function,
683 indent: IndentLevel,
684) -> String {
685 let mut fn_def = String::new();
686 format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name);
687 format_function_param_list_to(&mut fn_def, ctx, module, fun);
688 fn_def.push(')');
689 format_function_ret_to(&mut fn_def, ctx, module, fun);
690 fn_def.push_str(" {");
691 format_function_body_to(&mut fn_def, ctx, indent, fun);
692 format_to!(fn_def, "{}}}", indent);
693
694 fn_def
695}
696
697fn format_function_param_list_to(
698 fn_def: &mut String,
699 ctx: &AssistContext,
700 module: hir::Module,
701 fun: &Function,
702) {
703 let mut it = fun.params.iter();
704 if let Some(self_param) = &fun.self_param {
705 format_to!(fn_def, "{}", self_param);
706 } else if let Some(param) = it.next() {
707 format_param_to(fn_def, ctx, module, param);
708 }
709 for param in it {
710 fn_def.push_str(", ");
711 format_param_to(fn_def, ctx, module, param);
712 }
713}
714
715fn format_param_to(fn_def: &mut String, ctx: &AssistContext, module: hir::Module, param: &Param) {
716 format_to!(
717 fn_def,
718 "{}{}: {}{}",
719 param.mut_pattern(),
720 param.var.name(ctx.db()).unwrap(),
721 param.type_prefix(),
722 format_type(&param.var.ty(ctx.db()), ctx, module)
723 );
724}
725
726fn format_function_ret_to(
727 fn_def: &mut String,
728 ctx: &AssistContext,
729 module: hir::Module,
730 fun: &Function,
731) {
732 if let Some(ty) = fun.ret_ty.as_fn_ret() {
733 format_to!(fn_def, " -> {}", format_type(ty, ctx, module));
734 } else {
735 match fun.vars_defined_in_body_and_outlive.as_slice() {
736 [] => {}
737 [var] => {
738 format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module));
739 }
740 [v0, vs @ ..] => {
741 format_to!(fn_def, " -> ({}", format_type(&v0.ty(ctx.db()), ctx, module));
742 for var in vs {
743 format_to!(fn_def, ", {}", format_type(&var.ty(ctx.db()), ctx, module));
744 }
745 fn_def.push(')');
746 }
747 }
748 }
749}
750
751fn format_function_body_to(
752 fn_def: &mut String,
753 ctx: &AssistContext,
754 indent: IndentLevel,
755 fun: &Function,
756) {
757 match &fun.body {
758 FunctionBody::Expr(expr) => {
759 fn_def.push('\n');
760 let expr = expr.indent(indent);
761 let expr = fix_param_usages(ctx, &fun.params, expr.syntax());
762 format_to!(fn_def, "{}{}", indent + 1, expr);
763 fn_def.push('\n');
764 }
765 FunctionBody::Span { elements, leading_indent } => {
766 format_to!(fn_def, "{}", leading_indent);
767 for element in elements {
768 match element {
769 syntax::NodeOrToken::Node(node) => {
770 format_to!(fn_def, "{}", fix_param_usages(ctx, &fun.params, node));
771 }
772 syntax::NodeOrToken::Token(token) => {
773 format_to!(fn_def, "{}", token);
774 }
775 }
776 }
777 if !fn_def.ends_with('\n') {
778 fn_def.push('\n');
779 }
780 }
781 }
782
783 match fun.vars_defined_in_body_and_outlive.as_slice() {
784 [] => {}
785 [var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()),
786 [v0, vs @ ..] => {
787 format_to!(fn_def, "{}({}", indent + 1, v0.name(ctx.db()).unwrap());
788 for var in vs {
789 format_to!(fn_def, ", {}", var.name(ctx.db()).unwrap());
790 }
791 fn_def.push_str(")\n");
792 }
793 }
794}
795
796fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> String {
797 ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string())
798}
799
800fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode {
801 let mut rewriter = SyntaxRewriter::default();
802 for param in params {
803 if !param.kind().is_ref() {
804 continue;
805 }
806
807 let usages = LocalUsages::find(ctx, param.var);
808 let usages = usages
809 .iter()
810 .filter(|reference| syntax.text_range().contains_range(reference.range))
811 .filter_map(|reference| path_at_offset(syntax, reference));
812 for path in usages {
813 match path.syntax().ancestors().skip(1).find_map(ast::Expr::cast) {
814 Some(ast::Expr::MethodCallExpr(_)) => {
815 // do nothing
816 }
817 Some(ast::Expr::RefExpr(node))
818 if param.kind() == ParamKind::MutRef && node.mut_token().is_some() =>
819 {
820 rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap());
821 }
822 Some(ast::Expr::RefExpr(node))
823 if param.kind() == ParamKind::SharedRef && node.mut_token().is_none() =>
824 {
825 rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap());
826 }
827 Some(_) | None => {
828 rewriter.replace_ast(&path, &ast::make::expr_prefix(T![*], path.clone()));
829 }
830 };
831 }
832 }
833
834 rewriter.rewrite(syntax)
705} 835}
706 836
707#[cfg(test)] 837#[cfg(test)]