From 86ff1d4809b978f821f4339a200c9ca0f13e422e Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Thu, 4 Feb 2021 00:27:31 +0300 Subject: allow `&mut param` when extracting function Recognise &mut as variable modification. This allows extracting functions with `&mut var` with `var` being in outer scope --- crates/assists/src/handlers/extract_function.rs | 110 +++++++++++++++++++++++- 1 file changed, 107 insertions(+), 3 deletions(-) (limited to 'crates/assists/src/handlers/extract_function.rs') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index ffa8bd77d..a4b23d756 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -2,7 +2,7 @@ use either::Either; use hir::{HirDisplay, Local}; use ide_db::{ defs::{Definition, NameRefClass}, - search::{ReferenceAccess, SearchScope}, + search::{FileReference, ReferenceAccess, SearchScope}, }; use itertools::Itertools; use stdx::format_to; @@ -15,7 +15,7 @@ use syntax::{ }, Direction, SyntaxElement, SyntaxKind::{self, BLOCK_EXPR, BREAK_EXPR, COMMENT, PATH_EXPR, RETURN_EXPR}, - SyntaxNode, TextRange, T, + SyntaxNode, SyntaxToken, TextRange, TextSize, TokenAtOffset, T, }; use test_utils::mark; @@ -140,7 +140,18 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option .iter() .flat_map(|(_, rs)| rs.iter()) .filter(|reference| body.contains_range(reference.range)) - .any(|reference| reference.access == Some(ReferenceAccess::Write)); + .any(|reference| { + if reference.access == Some(ReferenceAccess::Write) { + return true; + } + + let path = path_at_offset(&body, reference); + if is_mut_ref_expr(path.as_ref()).unwrap_or(false) { + return true; + } + + false + }); Param { node, has_usages_afterwards, has_mut_inside_body, is_copy: true } }) @@ -405,6 +416,19 @@ fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> Stri ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string()) } +fn path_at_offset(body: &FunctionBody, reference: &FileReference) -> Option { + let var = body.token_at_offset(reference.range.start()).right_biased()?; + let path = var.ancestors().find_map(ast::Expr::cast)?; + stdx::always!(matches!(path, ast::Expr::PathExpr(_))); + Some(path) +} + +fn is_mut_ref_expr(path: Option<&ast::Expr>) -> Option { + let path = path?; + let ref_expr = path.syntax().parent().and_then(ast::RefExpr::cast)?; + Some(ref_expr.mut_token().is_some()) +} + fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode { let mut rewriter = SyntaxRewriter::default(); for param in params { @@ -551,6 +575,38 @@ impl FunctionBody { } } + fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset { + match self { + FunctionBody::Expr(expr) => expr.syntax().token_at_offset(offset), + FunctionBody::Span { elements, .. } => { + stdx::always!(self.text_range().contains(offset)); + let mut iter = elements + .iter() + .filter(|element| element.text_range().contains_inclusive(offset)); + let element1 = iter.next().expect("offset does not fall into body"); + let element2 = iter.next(); + stdx::always!(iter.next().is_none(), "> 2 tokens at offset"); + let t1 = match element1 { + syntax::NodeOrToken::Node(node) => node.token_at_offset(offset), + syntax::NodeOrToken::Token(token) => TokenAtOffset::Single(token.clone()), + }; + let t2 = element2.map(|e| match e { + syntax::NodeOrToken::Node(node) => node.token_at_offset(offset), + syntax::NodeOrToken::Token(token) => TokenAtOffset::Single(token.clone()), + }); + + match t2 { + Some(t2) => match (t1.clone().right_biased(), t2.clone().left_biased()) { + (Some(e1), Some(e2)) => TokenAtOffset::Between(e1, e2), + (Some(_), None) => t1, + (None, _) => t2, + }, + None => t1, + } + } + } + } + fn text_range(&self) -> TextRange { match self { FunctionBody::Expr(expr) => expr.syntax().text_range(), @@ -1403,6 +1459,54 @@ fn foo() { fn $0fun_name(mut n: i32) { n += 1; +}", + ); + } + + #[test] + fn mut_param_because_of_mut_ref() { + check_assist( + extract_function, + r" +fn foo() { + let mut n = 1; + $0let v = &mut n; + *v += 1;$0 + let k = n; +}", + r" +fn foo() { + let mut n = 1; + fun_name(&mut n); + let k = n; +} + +fn $0fun_name(n: &mut i32) { + let v = n; + *v += 1; +}", + ); + } + + #[test] + fn mut_param_by_value_because_of_mut_ref() { + check_assist( + extract_function, + r" +fn foo() { + let mut n = 1; + $0let v = &mut n; + *v += 1;$0 +}", + r" +fn foo() { + let mut n = 1; + fun_name(n); +} + +fn $0fun_name(mut n: i32) { + let v = &mut n; + *v += 1; }", ); } -- cgit v1.2.3