From bc3ae81a873173346df6cb000e503233d7558d03 Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Wed, 3 Feb 2021 10:57:11 +0300 Subject: initial version of extract function assist there are a few currently limitations: * no modifications of function body * does not handle mutability and references * no method support * may produce incorrect results --- crates/assists/src/handlers/extract_function.rs | 819 ++++++++++++++++++++++++ crates/assists/src/lib.rs | 2 + crates/assists/src/tests/generated.rs | 27 + 3 files changed, 848 insertions(+) create mode 100644 crates/assists/src/handlers/extract_function.rs (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs new file mode 100644 index 000000000..1a6cfebed --- /dev/null +++ b/crates/assists/src/handlers/extract_function.rs @@ -0,0 +1,819 @@ +use either::Either; +use hir::{HirDisplay, Local}; +use ide_db::defs::{Definition, NameRefClass}; +use rustc_hash::FxHashSet; +use stdx::format_to; +use syntax::{ + ast::{ + self, + edit::{AstNodeEdit, IndentLevel}, + AstNode, NameOwner, + }, + Direction, SyntaxElement, + SyntaxKind::{self, BLOCK_EXPR, BREAK_EXPR, COMMENT, PATH_EXPR, RETURN_EXPR}, + SyntaxNode, TextRange, +}; +use test_utils::mark; + +use crate::{ + assist_context::{AssistContext, Assists}, + AssistId, +}; + +// Assist: extract_function +// +// Extracts selected statements into new function. +// +// ``` +// fn main() { +// let n = 1; +// $0let m = n + 2; +// let k = m + n;$0 +// let g = 3; +// } +// ``` +// -> +// ``` +// fn main() { +// let n = 1; +// fun_name(n); +// let g = 3; +// } +// +// fn $0fun_name(n: i32) { +// let m = n + 2; +// let k = m + n; +// } +// ``` +pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + if ctx.frange.range.is_empty() { + return None; + } + + let node = ctx.covering_element(); + if node.kind() == COMMENT { + mark::hit!(extract_function_in_comment_is_not_applicable); + return None; + } + + let node = match node { + syntax::NodeOrToken::Node(n) => n, + syntax::NodeOrToken::Token(t) => t.parent(), + }; + + let mut body = None; + if node.text_range() == ctx.frange.range { + body = FunctionBody::from_whole_node(node.clone()); + } + if body.is_none() && node.kind() == BLOCK_EXPR { + body = FunctionBody::from_range(&node, ctx.frange.range); + } + if body.is_none() { + body = FunctionBody::from_whole_node(node.clone()); + } + if body.is_none() { + body = node.ancestors().find_map(FunctionBody::from_whole_node); + } + let body = body?; + + let insert_after = body.scope_for_fn_insertion()?; + + let module = ctx.sema.scope(&insert_after).module()?; + + let expr = body.tail_expr(); + let ret_ty = match expr { + Some(expr) => { + // TODO: can we do assist when type is unknown? + // We can insert something like `-> ()` + let ty = ctx.sema.type_of_expr(&expr)?; + Some(ty.display_source_code(ctx.db(), module.into()).ok()?) + } + None => None, + }; + + let target_range = match &body { + FunctionBody::Expr(expr) => expr.syntax().text_range(), + FunctionBody::Span { .. } => ctx.frange.range, + }; + + let mut params = local_variables(&body, &ctx) + .into_iter() + .map(|node| node.source(ctx.db())) + .filter(|src| src.file_id.original_file(ctx.db()) == ctx.frange.file_id) + .map(|src| match src.value { + Either::Left(pat) => { + (pat.syntax().clone(), pat.name(), ctx.sema.type_of_pat(&pat.into())) + } + Either::Right(it) => (it.syntax().clone(), it.name(), ctx.sema.type_of_self(&it)), + }) + .filter(|(node, _, _)| !body.contains_node(node)) + .map(|(_, name, ty)| { + let ty = ty + .and_then(|ty| ty.display_source_code(ctx.db(), module.into()).ok()) + .unwrap_or_else(|| "()".to_string()); + + let name = name.unwrap().to_string(); + + Param { name, ty } + }) + .collect::>(); + deduplicate_params(&mut params); + + acc.add( + AssistId("extract_function", crate::AssistKind::RefactorExtract), + "Extract into function", + target_range, + move |builder| { + + let fun = Function { name: "fun_name".to_string(), params, ret_ty, body }; + + builder.replace(target_range, format_replacement(&fun)); + + let indent = IndentLevel::from_node(&insert_after); + + let fn_def = format_function(&fun, indent); + let insert_offset = insert_after.text_range().end(); + builder.insert(insert_offset, fn_def); + }, + ) +} + +fn format_replacement(fun: &Function) -> String { + let mut buf = String::new(); + format_to!(buf, "{}(", fun.name); + { + let mut it = fun.params.iter(); + if let Some(param) = it.next() { + format_to!(buf, "{}", param.name); + } + for param in it { + format_to!(buf, ", {}", param.name); + } + } + format_to!(buf, ")"); + + if fun.has_unit_ret() { + format_to!(buf, ";"); + } + + buf +} + +struct Function { + name: String, + params: Vec, + ret_ty: Option, + body: FunctionBody, +} + +impl Function { + fn has_unit_ret(&self) -> bool { + match &self.ret_ty { + Some(ty) => ty == "()", + None => true, + } + } +} + +#[derive(Debug)] +struct Param { + name: String, + ty: String, +} + +fn format_function(fun: &Function, indent: IndentLevel) -> String { + let mut fn_def = String::new(); + format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name); + { + let mut it = fun.params.iter(); + if let Some(param) = it.next() { + format_to!(fn_def, "{}: {}", param.name, param.ty); + } + for param in it { + format_to!(fn_def, ", {}: {}", param.name, param.ty); + } + } + + format_to!(fn_def, ")"); + if !fun.has_unit_ret() { + if let Some(ty) = &fun.ret_ty { + format_to!(fn_def, " -> {}", ty); + } + } + format_to!(fn_def, " {{"); + + match &fun.body { + FunctionBody::Expr(expr) => { + fn_def.push('\n'); + let expr = expr.indent(indent); + format_to!(fn_def, "{}{}", indent + 1, expr.syntax()); + fn_def.push('\n'); + } + FunctionBody::Span { elements, leading_indent } => { + format_to!(fn_def, "{}", leading_indent); + for e in elements { + format_to!(fn_def, "{}", e); + } + if !fn_def.ends_with('\n') { + fn_def.push('\n'); + } + } + } + format_to!(fn_def, "{}}}", indent); + + fn_def +} + +#[derive(Debug)] +enum FunctionBody { + Expr(ast::Expr), + Span { elements: Vec, leading_indent: String }, +} + +impl FunctionBody { + fn from_whole_node(node: SyntaxNode) -> Option { + match node.kind() { + PATH_EXPR => None, + BREAK_EXPR => ast::BreakExpr::cast(node).and_then(|e| e.expr()).map(Self::Expr), + RETURN_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()).map(Self::Expr), + BLOCK_EXPR => ast::BlockExpr::cast(node) + .filter(|it| it.is_standalone()) + .map(Into::into) + .map(Self::Expr), + _ => ast::Expr::cast(node).map(Self::Expr), + } + } + + fn from_range(node: &SyntaxNode, range: TextRange) -> Option { + let mut first = node.token_at_offset(range.start()).left_biased()?; + let last = node.token_at_offset(range.end()).right_biased()?; + + let mut leading_indent = String::new(); + + let leading_trivia = first + .siblings_with_tokens(Direction::Prev) + .skip(1) + .take_while(|e| e.kind() == SyntaxKind::WHITESPACE && e.as_token().is_some()); + + for e in leading_trivia { + let token = e.as_token().unwrap(); + let text = token.text(); + match text.rfind('\n') { + Some(pos) => { + leading_indent = text[pos..].to_owned(); + break; + } + None => first = token.clone(), + } + } + + let mut elements: Vec<_> = first + .siblings_with_tokens(Direction::Next) + .take_while(|e| e.as_token() != Some(&last)) + .collect(); + + if !(last.kind() == SyntaxKind::WHITESPACE && last.text().lines().count() <= 2) { + elements.push(last.into()); + } + + Some(FunctionBody::Span { elements, leading_indent }) + } + + fn tail_expr(&self) -> Option { + match &self { + FunctionBody::Expr(expr) => Some(expr.clone()), + FunctionBody::Span { elements, .. } => { + elements.iter().rev().find_map(|e| e.as_node()).cloned().and_then(ast::Expr::cast) + } + } + } + + fn scope_for_fn_insertion(&self) -> Option { + match self { + FunctionBody::Expr(e) => scope_for_fn_insertion(e.syntax()), + FunctionBody::Span { elements, .. } => { + let node = elements.iter().find_map(|e| e.as_node())?; + scope_for_fn_insertion(&node) + } + } + } + + fn descendants(&self) -> impl Iterator + '_ { + match self { + FunctionBody::Expr(expr) => Either::Right(expr.syntax().descendants()), + FunctionBody::Span { elements, .. } => Either::Left( + elements + .iter() + .filter_map(SyntaxElement::as_node) + .flat_map(SyntaxNode::descendants), + ), + } + } + + fn contains_node(&self, node: &SyntaxNode) -> bool { + fn is_node(body: &FunctionBody, n: &SyntaxNode) -> bool { + match body { + FunctionBody::Expr(expr) => n == expr.syntax(), + FunctionBody::Span { elements, .. } => { + // FIXME: can it be quadratic? + elements.iter().filter_map(SyntaxElement::as_node).any(|e| e == n) + } + } + } + + node.ancestors().any(|a| is_node(self, &a)) + } +} + +fn scope_for_fn_insertion(node: &SyntaxNode) -> Option { + let mut ancestors = node.ancestors().peekable(); + let mut last_ancestor = None; + while let Some(next_ancestor) = ancestors.next() { + match next_ancestor.kind() { + SyntaxKind::SOURCE_FILE => break, + SyntaxKind::ITEM_LIST => { + if ancestors.peek().map(|a| a.kind()) == Some(SyntaxKind::MODULE) { + break; + } + } + _ => {} + } + last_ancestor = Some(next_ancestor); + } + last_ancestor +} + +fn deduplicate_params(params: &mut Vec) { + let mut seen_params = FxHashSet::default(); + params.retain(|p| seen_params.insert(p.name.clone())); +} + +/// Returns a vector of local variables that are refferenced in `body` +fn local_variables(body: &FunctionBody, ctx: &AssistContext) -> Vec { + body + .descendants() + .filter_map(ast::NameRef::cast) + .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref)) + .map(|name_kind| name_kind.referenced(ctx.db())) + .filter_map(|definition| match definition { + Definition::Local(local) => Some(local), + _ => None, + }) + .collect() +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn no_args_from_binary_expr() { + check_assist( + extract_function, + r#" +fn foo() { + foo($01 + 1$0); +}"#, + r#" +fn foo() { + foo(fun_name()); +} + +fn $0fun_name() -> i32 { + 1 + 1 +}"#, + ); + } + + #[test] + fn no_args_from_binary_expr_in_module() { + check_assist( + extract_function, + r#" +mod bar { + fn foo() { + foo($01 + 1$0); + } +}"#, + r#" +mod bar { + fn foo() { + foo(fun_name()); + } + + fn $0fun_name() -> i32 { + 1 + 1 + } +}"#, + ); + } + + #[test] + fn no_args_from_binary_expr_indented() { + check_assist( + extract_function, + r#" +fn foo() { + $0{ 1 + 1 }$0; +}"#, + r#" +fn foo() { + fun_name(); +} + +fn $0fun_name() -> i32 { + { 1 + 1 } +}"#, + ); + } + + #[test] + fn no_args_from_stmt_with_last_expr() { + check_assist( + extract_function, + r#" +fn foo() -> i32 { + let k = 1; + $0let m = 1; + m + 1$0 +}"#, + r#" +fn foo() -> i32 { + let k = 1; + fun_name() +} + +fn $0fun_name() -> i32 { + let m = 1; + m + 1 +}"#, + ); + } + + #[test] + fn no_args_from_stmt_unit() { + check_assist( + extract_function, + r#" +fn foo() { + let k = 3; + $0let m = 1; + let n = m + 1;$0 + let g = 5; +}"#, + r#" +fn foo() { + let k = 3; + fun_name(); + let g = 5; +} + +fn $0fun_name() { + let m = 1; + let n = m + 1; +}"#, + ); + } + + #[test] + fn no_args_from_loop_unit() { + check_assist( + extract_function, + r#" +fn foo() { + $0loop { + let m = 1; + }$0 +}"#, + r#" +fn foo() { + fun_name() +} + +fn $0fun_name() -> ! { + loop { + let m = 1; + } +}"#, + ); + } + + #[test] + fn no_args_from_loop_with_return() { + check_assist( + extract_function, + r#" +fn foo() { + let v = $0loop { + let m = 1; + break m; + }$0; +}"#, + r#" +fn foo() { + let v = fun_name(); +} + +fn $0fun_name() -> i32 { + loop { + let m = 1; + break m; + } +}"#, + ); + } + + #[test] + fn no_args_from_match() { + check_assist( + extract_function, + r#" +fn foo() { + let v: i32 = $0match Some(1) { + Some(x) => x, + None => 0, + }$0; +}"#, + r#" +fn foo() { + let v: i32 = fun_name(); +} + +fn $0fun_name() -> i32 { + match Some(1) { + Some(x) => x, + None => 0, + } +}"#, + ); + } + + #[test] + fn argument_form_expr() { + check_assist( + extract_function, + r" +fn foo() -> u32 { + let n = 2; + $0n+2$0 +}", + r" +fn foo() -> u32 { + let n = 2; + fun_name(n) +} + +fn $0fun_name(n: u32) -> u32 { + n+2 +}", + ) + } + + #[test] + fn argument_used_twice_form_expr() { + check_assist( + extract_function, + r" +fn foo() -> u32 { + let n = 2; + $0n+n$0 +}", + r" +fn foo() -> u32 { + let n = 2; + fun_name(n) +} + +fn $0fun_name(n: u32) -> u32 { + n+n +}", + ) + } + + #[test] + fn two_arguments_form_expr() { + check_assist( + extract_function, + r" +fn foo() -> u32 { + let n = 2; + let m = 3; + $0n+n*m$0 +}", + r" +fn foo() -> u32 { + let n = 2; + let m = 3; + fun_name(n, m) +} + +fn $0fun_name(n: u32, m: u32) -> u32 { + n+n*m +}", + ) + } + + #[test] + fn argument_and_locals() { + check_assist( + extract_function, + r" +fn foo() -> u32 { + let n = 2; + $0let m = 1; + n + m$0 +}", + r" +fn foo() -> u32 { + let n = 2; + fun_name(n) +} + +fn $0fun_name(n: u32) -> u32 { + let m = 1; + n + m +}", + ) + } + + #[test] + fn in_comment_is_not_applicable() { + mark::check!(extract_function_in_comment_is_not_applicable); + check_assist_not_applicable(extract_function, r"fn main() { 1 + /* $0comment$0 */ 1; }"); + } + + #[test] + fn part_of_expr_stmt() { + check_assist( + extract_function, + " +fn foo() { + $01$0 + 1; +}", + " +fn foo() { + fun_name() + 1; +} + +fn $0fun_name() -> i32 { + 1 +}", + ); + } + + #[test] + fn function_expr() { + check_assist( + extract_function, + r#" +fn foo() { + $0bar(1 + 1)$0 +}"#, + r#" +fn foo() { + fun_name(); +} + +fn $0fun_name() { + bar(1 + 1) +}"#, + ) + } + + #[test] + fn extract_from_nested() { + check_assist( + extract_function, + r" +fn main() { + let x = true; + let tuple = match x { + true => ($02 + 2$0, true) + _ => (0, false) + }; +}", + r" +fn main() { + let x = true; + let tuple = match x { + true => (fun_name(), true) + _ => (0, false) + }; +} + +fn $0fun_name() -> i32 { + 2 + 2 +}", + ); + } + + #[test] + fn param_from_closure() { + check_assist( + extract_function, + r" +fn main() { + let lambda = |x: u32| $0x * 2$0; +}", + r" +fn main() { + let lambda = |x: u32| fun_name(x); +} + +fn $0fun_name(x: u32) -> u32 { + x * 2 +}", + ); + } + + #[test] + fn extract_return_stmt() { + check_assist( + extract_function, + r" +fn foo() -> u32 { + $0return 2 + 2$0; +}", + r" +fn foo() -> u32 { + return fun_name(); +} + +fn $0fun_name() -> u32 { + 2 + 2 +}", + ); + } + + #[test] + fn does_not_add_extra_whitespace() { + check_assist( + extract_function, + r" +fn foo() -> u32 { + + + $0return 2 + 2$0; +}", + r" +fn foo() -> u32 { + + + return fun_name(); +} + +fn $0fun_name() -> u32 { + 2 + 2 +}", + ); + } + + #[test] + fn break_stmt() { + check_assist( + extract_function, + r" +fn main() { + let result = loop { + $0break 2 + 2$0; + }; +}", + r" +fn main() { + let result = loop { + break fun_name(); + }; +} + +fn $0fun_name() -> i32 { + 2 + 2 +}", + ); + } + + #[test] + fn extract_cast() { + check_assist( + extract_function, + r" +fn main() { + let v = $00f32 as u32$0; +}", + r" +fn main() { + let v = fun_name(); +} + +fn $0fun_name() -> u32 { + 0f32 as u32 +}", + ); + } + + #[test] + fn return_not_applicable() { + check_assist_not_applicable(extract_function, r"fn foo() { $0return$0; } "); + } +} diff --git a/crates/assists/src/lib.rs b/crates/assists/src/lib.rs index 559b9651e..062a902ab 100644 --- a/crates/assists/src/lib.rs +++ b/crates/assists/src/lib.rs @@ -117,6 +117,7 @@ mod handlers { mod convert_integer_literal; mod early_return; mod expand_glob_import; + mod extract_function; mod extract_struct_from_enum_variant; mod extract_variable; mod fill_match_arms; @@ -174,6 +175,7 @@ mod handlers { early_return::convert_to_guarded_return, expand_glob_import::expand_glob_import, move_module_to_file::move_module_to_file, + extract_function::extract_function, extract_struct_from_enum_variant::extract_struct_from_enum_variant, extract_variable::extract_variable, fill_match_arms::fill_match_arms, diff --git a/crates/assists/src/tests/generated.rs b/crates/assists/src/tests/generated.rs index 9aa807f10..e84f208a3 100644 --- a/crates/assists/src/tests/generated.rs +++ b/crates/assists/src/tests/generated.rs @@ -256,6 +256,33 @@ fn qux(bar: Bar, baz: Baz) {} ) } +#[test] +fn doctest_extract_function() { + check_doc_test( + "extract_function", + r#####" +fn main() { + let n = 1; + $0let m = n + 2; + let k = m + n;$0 + let g = 3; +} +"#####, + r#####" +fn main() { + let n = 1; + fun_name(n); + let g = 3; +} + +fn $0fun_name(n: i32) { + let m = n + 2; + let k = m + n; +} +"#####, + ) +} + #[test] fn doctest_extract_struct_from_enum_variant() { check_doc_test( -- cgit v1.2.3 From 1e6f13a0bee0d9600e7b582fbd9a2e1f4a9a24fc Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Wed, 3 Feb 2021 12:27:53 +0300 Subject: support extracting methods; no mut lowering currently mut refernce will *not* be downgraded to shared if it is sufficient(see relevant test for example) --- crates/assists/src/handlers/extract_function.rs | 228 ++++++++++++++++++++---- 1 file changed, 191 insertions(+), 37 deletions(-) (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index 1a6cfebed..09c2a9bc7 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -68,6 +68,11 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option if body.is_none() && node.kind() == BLOCK_EXPR { body = FunctionBody::from_range(&node, ctx.frange.range); } + if let Some(parent) = node.parent() { + if body.is_none() && parent.kind() == BLOCK_EXPR { + body = FunctionBody::from_range(&parent, ctx.frange.range); + } + } if body.is_none() { body = FunctionBody::from_whole_node(node.clone()); } @@ -76,10 +81,47 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option } let body = body?; - let insert_after = body.scope_for_fn_insertion()?; + let mut self_param = None; + let mut param_pats: Vec<_> = local_variables(&body, &ctx) + .into_iter() + .map(|node| node.source(ctx.db())) + .filter(|src| { + src.file_id.original_file(ctx.db()) == ctx.frange.file_id + && !body.contains_node(&either_syntax(&src.value)) + }) + .filter_map(|src| match src.value { + Either::Left(pat) => Some(pat), + Either::Right(it) => { + // we filter self param, as there can only be one + self_param = Some(it); + None + } + }) + .collect(); + deduplicate_params(&mut param_pats); + let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding }; + let insert_after = body.scope_for_fn_insertion(anchor)?; let module = ctx.sema.scope(&insert_after).module()?; + let params = param_pats + .into_iter() + .map(|pat| { + let ty = pat + .pat() + .and_then(|pat| ctx.sema.type_of_pat(&pat)) + .and_then(|ty| ty.display_source_code(ctx.db(), module.into()).ok()) + .unwrap_or_else(|| "()".to_string()); + + let name = pat.name().unwrap().to_string(); + + Param { name, ty } + }) + .collect::>(); + + let self_param = + if let Some(self_param) = self_param { Some(self_param.to_string()) } else { None }; + let expr = body.tail_expr(); let ret_ty = match expr { Some(expr) => { @@ -96,36 +138,12 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option FunctionBody::Span { .. } => ctx.frange.range, }; - let mut params = local_variables(&body, &ctx) - .into_iter() - .map(|node| node.source(ctx.db())) - .filter(|src| src.file_id.original_file(ctx.db()) == ctx.frange.file_id) - .map(|src| match src.value { - Either::Left(pat) => { - (pat.syntax().clone(), pat.name(), ctx.sema.type_of_pat(&pat.into())) - } - Either::Right(it) => (it.syntax().clone(), it.name(), ctx.sema.type_of_self(&it)), - }) - .filter(|(node, _, _)| !body.contains_node(node)) - .map(|(_, name, ty)| { - let ty = ty - .and_then(|ty| ty.display_source_code(ctx.db(), module.into()).ok()) - .unwrap_or_else(|| "()".to_string()); - - let name = name.unwrap().to_string(); - - Param { name, ty } - }) - .collect::>(); - deduplicate_params(&mut params); - acc.add( AssistId("extract_function", crate::AssistKind::RefactorExtract), "Extract into function", target_range, move |builder| { - - let fun = Function { name: "fun_name".to_string(), params, ret_ty, body }; + let fun = Function { name: "fun_name".to_string(), self_param, params, ret_ty, body }; builder.replace(target_range, format_replacement(&fun)); @@ -140,6 +158,9 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option fn format_replacement(fun: &Function) -> String { let mut buf = String::new(); + if fun.self_param.is_some() { + format_to!(buf, "self."); + } format_to!(buf, "{}(", fun.name); { let mut it = fun.params.iter(); @@ -161,6 +182,7 @@ fn format_replacement(fun: &Function) -> String { struct Function { name: String, + self_param: Option, params: Vec, ret_ty: Option, body: FunctionBody, @@ -186,7 +208,9 @@ fn format_function(fun: &Function, indent: IndentLevel) -> String { format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name); { let mut it = fun.params.iter(); - if let Some(param) = it.next() { + if let Some(self_param) = &fun.self_param { + format_to!(fn_def, "{}", self_param); + } else if let Some(param) = it.next() { format_to!(fn_def, "{}: {}", param.name, param.ty); } for param in it { @@ -230,6 +254,11 @@ enum FunctionBody { Span { elements: Vec, leading_indent: String }, } +enum Anchor { + Freestanding, + Method, +} + impl FunctionBody { fn from_whole_node(node: SyntaxNode) -> Option { match node.kind() { @@ -288,12 +317,12 @@ impl FunctionBody { } } - fn scope_for_fn_insertion(&self) -> Option { + fn scope_for_fn_insertion(&self, anchor: Anchor) -> Option { match self { - FunctionBody::Expr(e) => scope_for_fn_insertion(e.syntax()), + FunctionBody::Expr(e) => scope_for_fn_insertion(e.syntax(), anchor), FunctionBody::Span { elements, .. } => { let node = elements.iter().find_map(|e| e.as_node())?; - scope_for_fn_insertion(&node) + scope_for_fn_insertion(&node, anchor) } } } @@ -325,14 +354,25 @@ impl FunctionBody { } } -fn scope_for_fn_insertion(node: &SyntaxNode) -> Option { +fn scope_for_fn_insertion(node: &SyntaxNode, anchor: Anchor) -> Option { let mut ancestors = node.ancestors().peekable(); let mut last_ancestor = None; while let Some(next_ancestor) = ancestors.next() { match next_ancestor.kind() { SyntaxKind::SOURCE_FILE => break, SyntaxKind::ITEM_LIST => { - if ancestors.peek().map(|a| a.kind()) == Some(SyntaxKind::MODULE) { + if !matches!(anchor, Anchor::Freestanding) { + continue; + } + if ancestors.peek().map(SyntaxNode::kind) == Some(SyntaxKind::MODULE) { + break; + } + } + SyntaxKind::ASSOC_ITEM_LIST => { + if !matches!(anchor, Anchor::Method) { + continue; + } + if ancestors.peek().map(SyntaxNode::kind) == Some(SyntaxKind::IMPL) { break; } } @@ -343,15 +383,21 @@ fn scope_for_fn_insertion(node: &SyntaxNode) -> Option { last_ancestor } -fn deduplicate_params(params: &mut Vec) { +fn deduplicate_params(params: &mut Vec) { let mut seen_params = FxHashSet::default(); - params.retain(|p| seen_params.insert(p.name.clone())); + params.retain(|p| seen_params.insert(p.clone())); +} + +fn either_syntax(value: &Either) -> &SyntaxNode { + match value { + Either::Left(pat) => pat.syntax(), + Either::Right(it) => it.syntax(), + } } /// Returns a vector of local variables that are refferenced in `body` fn local_variables(body: &FunctionBody, ctx: &AssistContext) -> Vec { - body - .descendants() + body.descendants() .filter_map(ast::NameRef::cast) .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref)) .map(|name_kind| name_kind.referenced(ctx.db())) @@ -386,7 +432,7 @@ fn $0fun_name() -> i32 { }"#, ); } - + #[test] fn no_args_from_binary_expr_in_module() { check_assist( @@ -816,4 +862,112 @@ fn $0fun_name() -> u32 { fn return_not_applicable() { check_assist_not_applicable(extract_function, r"fn foo() { $0return$0; } "); } + + #[test] + fn method_to_freestanding() { + check_assist( + extract_function, + r" +struct S; + +impl S { + fn foo(&self) -> i32 { + $01+1$0 + } +}", + r" +struct S; + +impl S { + fn foo(&self) -> i32 { + fun_name() + } +} + +fn $0fun_name() -> i32 { + 1+1 +}", + ); + } + + #[test] + fn method_with_reference() { + check_assist( + extract_function, + r" +struct S { f: i32 }; + +impl S { + fn foo(&self) -> i32 { + $01+self.f$0 + } +}", + r" +struct S { f: i32 }; + +impl S { + fn foo(&self) -> i32 { + self.fun_name() + } + + fn $0fun_name(&self) -> i32 { + 1+self.f + } +}", + ); + } + + #[test] + fn method_with_mut() { + check_assist( + extract_function, + r" +struct S { f: i32 }; + +impl S { + fn foo(&mut self) { + $0self.f += 1;$0 + } +}", + r" +struct S { f: i32 }; + +impl S { + fn foo(&mut self) { + self.fun_name(); + } + + fn $0fun_name(&mut self) { + self.f += 1; + } +}", + ); + } + + #[test] + fn method_with_mut_downgrade_to_shared() { + check_assist( + extract_function, + r" +struct S { f: i32 }; + +impl S { + fn foo(&mut self) -> i32 { + $01+self.f$0 + } +}", + r" +struct S { f: i32 }; + +impl S { + fn foo(&mut self) -> i32 { + self.fun_name() + } + + fn $0fun_name(&self) -> i32 { + 1+self.f + } +}", + ); + } } -- cgit v1.2.3 From 88b3034636a4f4c652e49de09a791a934573aaee Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Wed, 3 Feb 2021 17:45:36 +0300 Subject: convert IdentPat to Pat via Into before child getter was used --- crates/assists/src/handlers/extract_function.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index 09c2a9bc7..218529fcf 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -107,14 +107,14 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option let params = param_pats .into_iter() .map(|pat| { - let ty = pat - .pat() - .and_then(|pat| ctx.sema.type_of_pat(&pat)) + let name = pat.name().unwrap().to_string(); + + let ty = ctx + .sema + .type_of_pat(&pat.into()) .and_then(|ty| ty.display_source_code(ctx.db(), module.into()).ok()) .unwrap_or_else(|| "()".to_string()); - let name = pat.name().unwrap().to_string(); - Param { name, ty } }) .collect::>(); -- cgit v1.2.3 From f0d2bb9131fab4898221d30caf0b5a12800ba4e8 Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Wed, 3 Feb 2021 17:46:57 +0300 Subject: disable test for downgrading mutability on extract --- crates/assists/src/handlers/extract_function.rs | 3 +++ 1 file changed, 3 insertions(+) (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index 218529fcf..66c5cdb8f 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -944,6 +944,9 @@ impl S { ); } + // it is unclear if this is wanted behaviour + // and how this behavour can be implemented + #[ignore] #[test] fn method_with_mut_downgrade_to_shared() { check_assist( -- cgit v1.2.3 From 313aa5f3a2a9237c96c97c5852da39cf83bcb1ae Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Wed, 3 Feb 2021 17:47:21 +0300 Subject: change TODO to FIXME --- crates/assists/src/handlers/extract_function.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index 66c5cdb8f..958199e5e 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -125,8 +125,8 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option let expr = body.tail_expr(); let ret_ty = match expr { Some(expr) => { - // TODO: can we do assist when type is unknown? - // We can insert something like `-> ()` + // FIXME: can we do assist when type is unknown? + // We can insert something like `-> ()` let ty = ctx.sema.type_of_expr(&expr)?; Some(ty.display_source_code(ctx.db(), module.into()).ok()?) } -- cgit v1.2.3 From 82787febdee3e7dfe5a96c94aee03cd726f642f9 Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Wed, 3 Feb 2021 20:31:12 +0300 Subject: allow local variables to be used after extracted body when variable is defined inside extracted body export this variable to original scope via return value(s) --- crates/assists/src/handlers/extract_function.rs | 224 +++++++++++++++++++----- 1 file changed, 183 insertions(+), 41 deletions(-) (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index 958199e5e..c5e6ec733 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -1,7 +1,10 @@ use either::Either; use hir::{HirDisplay, Local}; -use ide_db::defs::{Definition, NameRefClass}; -use rustc_hash::FxHashSet; +use ide_db::{ + defs::{Definition, NameRefClass}, + search::SearchScope, +}; +use itertools::Itertools; use stdx::format_to; use syntax::{ ast::{ @@ -81,9 +84,10 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option } let body = body?; + let vars_used_in_body = vars_used_in_body(&body, &ctx); let mut self_param = None; - let mut param_pats: Vec<_> = local_variables(&body, &ctx) - .into_iter() + let param_pats: Vec<_> = vars_used_in_body + .iter() .map(|node| node.source(ctx.db())) .filter(|src| { src.file_id.original_file(ctx.db()) == ctx.frange.file_id @@ -98,12 +102,27 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option } }) .collect(); - deduplicate_params(&mut param_pats); let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding }; let insert_after = body.scope_for_fn_insertion(anchor)?; let module = ctx.sema.scope(&insert_after).module()?; + let vars_defined_in_body = vars_defined_in_body(&body, ctx); + + let vars_in_body_used_afterwards: Vec<_> = vars_defined_in_body + .iter() + .copied() + .filter(|node| { + let usages = Definition::Local(*node) + .usages(&ctx.sema) + .in_scope(SearchScope::single_file(ctx.frange.file_id)) + .all(); + let mut usages = usages.iter().flat_map(|(_, rs)| rs.iter()); + + usages.any(|reference| body.preceedes_range(reference.range)) + }) + .collect(); + let params = param_pats .into_iter() .map(|pat| { @@ -119,20 +138,18 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option }) .collect::>(); - let self_param = - if let Some(self_param) = self_param { Some(self_param.to_string()) } else { None }; - let expr = body.tail_expr(); let ret_ty = match expr { - Some(expr) => { - // FIXME: can we do assist when type is unknown? - // We can insert something like `-> ()` - let ty = ctx.sema.type_of_expr(&expr)?; - Some(ty.display_source_code(ctx.db(), module.into()).ok()?) - } + Some(expr) => Some(ctx.sema.type_of_expr(&expr)?), None => None, }; + let has_unit_ret = ret_ty.as_ref().map_or(true, |it| it.is_unit()); + if stdx::never!(!vars_in_body_used_afterwards.is_empty() && !has_unit_ret) { + // We should not have variables that outlive body if we have expression block + return None; + } + let target_range = match &body { FunctionBody::Expr(expr) => expr.syntax().text_range(), FunctionBody::Span { .. } => ctx.frange.range, @@ -143,21 +160,46 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option "Extract into function", target_range, move |builder| { - let fun = Function { name: "fun_name".to_string(), self_param, params, ret_ty, body }; + let fun = Function { + name: "fun_name".to_string(), + self_param, + params, + ret_ty, + body, + vars_in_body_used_afterwards, + }; - builder.replace(target_range, format_replacement(&fun)); + builder.replace(target_range, format_replacement(ctx, &fun)); let indent = IndentLevel::from_node(&insert_after); - let fn_def = format_function(&fun, indent); + let fn_def = format_function(ctx, module, &fun, indent); let insert_offset = insert_after.text_range().end(); builder.insert(insert_offset, fn_def); }, ) } -fn format_replacement(fun: &Function) -> String { +fn format_replacement(ctx: &AssistContext, fun: &Function) -> String { let mut buf = String::new(); + + match fun.vars_in_body_used_afterwards.len() { + 0 => {} + 1 => format_to!( + buf, + "let {} = ", + fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap() + ), + _ => { + buf.push_str("let ("); + format_to!(buf, "{}", fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap()); + for local in fun.vars_in_body_used_afterwards.iter().skip(1) { + format_to!(buf, ", {}", local.name(ctx.db()).unwrap()); + } + buf.push_str(") = "); + } + } + if fun.self_param.is_some() { format_to!(buf, "self."); } @@ -182,16 +224,17 @@ fn format_replacement(fun: &Function) -> String { struct Function { name: String, - self_param: Option, + self_param: Option, params: Vec, - ret_ty: Option, + ret_ty: Option, body: FunctionBody, + vars_in_body_used_afterwards: Vec, } impl Function { fn has_unit_ret(&self) -> bool { match &self.ret_ty { - Some(ty) => ty == "()", + Some(ty) => ty.is_unit(), None => true, } } @@ -203,7 +246,12 @@ struct Param { ty: String, } -fn format_function(fun: &Function, indent: IndentLevel) -> String { +fn format_function( + ctx: &AssistContext, + module: hir::Module, + fun: &Function, + indent: IndentLevel, +) -> String { let mut fn_def = String::new(); format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name); { @@ -221,10 +269,24 @@ fn format_function(fun: &Function, indent: IndentLevel) -> String { format_to!(fn_def, ")"); if !fun.has_unit_ret() { if let Some(ty) = &fun.ret_ty { - format_to!(fn_def, " -> {}", ty); + format_to!(fn_def, " -> {}", format_type(ty, ctx, module)); + } + } else { + match fun.vars_in_body_used_afterwards.as_slice() { + [] => {} + [var] => { + format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module)); + } + [v0, vs @ ..] => { + format_to!(fn_def, " -> ({}", format_type(&v0.ty(ctx.db()), ctx, module)); + for var in vs { + format_to!(fn_def, ", {}", format_type(&var.ty(ctx.db()), ctx, module)); + } + fn_def.push(')'); + } } } - format_to!(fn_def, " {{"); + fn_def.push_str(" {"); match &fun.body { FunctionBody::Expr(expr) => { @@ -243,11 +305,28 @@ fn format_function(fun: &Function, indent: IndentLevel) -> String { } } } + + match fun.vars_in_body_used_afterwards.as_slice() { + [] => {} + [var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()), + [v0, vs @ ..] => { + format_to!(fn_def, "{}({}", indent + 1, v0.name(ctx.db()).unwrap()); + for var in vs { + format_to!(fn_def, ", {}", var.name(ctx.db()).unwrap()); + } + fn_def.push_str(")\n"); + } + } + format_to!(fn_def, "{}}}", indent); fn_def } +fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> String { + ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string()) +} + #[derive(Debug)] enum FunctionBody { Expr(ast::Expr), @@ -339,18 +418,26 @@ impl FunctionBody { } } - fn contains_node(&self, node: &SyntaxNode) -> bool { - fn is_node(body: &FunctionBody, n: &SyntaxNode) -> bool { - match body { - FunctionBody::Expr(expr) => n == expr.syntax(), - FunctionBody::Span { elements, .. } => { - // FIXME: can it be quadratic? - elements.iter().filter_map(SyntaxElement::as_node).any(|e| e == n) - } - } + fn text_range(&self) -> TextRange { + match self { + FunctionBody::Expr(expr) => expr.syntax().text_range(), + FunctionBody::Span { elements, .. } => TextRange::new( + elements.first().unwrap().text_range().start(), + elements.last().unwrap().text_range().end(), + ), } + } + + fn contains_range(&self, range: TextRange) -> bool { + self.text_range().contains_range(range) + } - node.ancestors().any(|a| is_node(self, &a)) + fn preceedes_range(&self, range: TextRange) -> bool { + self.text_range().end() <= range.start() + } + + fn contains_node(&self, node: &SyntaxNode) -> bool { + self.contains_range(node.text_range()) } } @@ -383,11 +470,6 @@ fn scope_for_fn_insertion(node: &SyntaxNode, anchor: Anchor) -> Option) { - let mut seen_params = FxHashSet::default(); - params.retain(|p| seen_params.insert(p.clone())); -} - fn either_syntax(value: &Either) -> &SyntaxNode { match value { Either::Left(pat) => pat.syntax(), @@ -395,8 +477,8 @@ fn either_syntax(value: &Either) -> &SyntaxNode { } } -/// Returns a vector of local variables that are refferenced in `body` -fn local_variables(body: &FunctionBody, ctx: &AssistContext) -> Vec { +/// Returns a vector of local variables that are referenced in `body` +fn vars_used_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec { body.descendants() .filter_map(ast::NameRef::cast) .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref)) @@ -405,6 +487,16 @@ fn local_variables(body: &FunctionBody, ctx: &AssistContext) -> Vec { Definition::Local(local) => Some(local), _ => None, }) + .unique() + .collect() +} + +/// Returns a vector of local variables that are defined in `body` +fn vars_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec { + body.descendants() + .filter_map(ast::IdentPat::cast) + .filter_map(|let_stmt| ctx.sema.to_def(&let_stmt)) + .unique() .collect() } @@ -970,6 +1062,56 @@ impl S { fn $0fun_name(&self) -> i32 { 1+self.f } +}", + ); + } + + #[test] + fn variable_defined_inside_and_used_after_no_ret() { + check_assist( + extract_function, + r" +fn foo() { + let n = 1; + $0let k = n * n;$0 + let m = k + 1; +}", + r" +fn foo() { + let n = 1; + let k = fun_name(n); + let m = k + 1; +} + +fn $0fun_name(n: i32) -> i32 { + let k = n * n; + k +}", + ); + } + + #[test] + fn two_variables_defined_inside_and_used_after_no_ret() { + check_assist( + extract_function, + r" +fn foo() { + let n = 1; + $0let k = n * n; + let m = k + 2;$0 + let h = k + m; +}", + r" +fn foo() { + let n = 1; + let (k, m) = fun_name(n); + let h = k + m; +} + +fn $0fun_name(n: i32) -> (i32, i32) { + let k = n * n; + let m = k + 2; + (k, m) }", ); } -- cgit v1.2.3 From f102616aaea2894508f8f078cfb20ceef5411d12 Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Wed, 3 Feb 2021 23:45:03 +0300 Subject: allow modifications of vars from outer scope inside extracted function It currently allows only directly setting variable. No `&mut` references or methods. --- crates/assists/src/handlers/extract_function.rs | 381 +++++++++++++++++++++--- 1 file changed, 336 insertions(+), 45 deletions(-) (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index c5e6ec733..ffa8bd77d 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -2,19 +2,20 @@ use either::Either; use hir::{HirDisplay, Local}; use ide_db::{ defs::{Definition, NameRefClass}, - search::SearchScope, + search::{ReferenceAccess, SearchScope}, }; use itertools::Itertools; use stdx::format_to; use syntax::{ + algo::SyntaxRewriter, ast::{ self, edit::{AstNodeEdit, IndentLevel}, - AstNode, NameOwner, + AstNode, }, Direction, SyntaxElement, SyntaxKind::{self, BLOCK_EXPR, BREAK_EXPR, COMMENT, PATH_EXPR, RETURN_EXPR}, - SyntaxNode, TextRange, + SyntaxNode, TextRange, T, }; use test_utils::mark; @@ -88,16 +89,16 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option let mut self_param = None; let param_pats: Vec<_> = vars_used_in_body .iter() - .map(|node| node.source(ctx.db())) - .filter(|src| { + .map(|node| (node, node.source(ctx.db()))) + .filter(|(_, src)| { src.file_id.original_file(ctx.db()) == ctx.frange.file_id && !body.contains_node(&either_syntax(&src.value)) }) - .filter_map(|src| match src.value { - Either::Left(pat) => Some(pat), + .filter_map(|(&node, src)| match src.value { + Either::Left(_) => Some(node), Either::Right(it) => { // we filter self param, as there can only be one - self_param = Some(it); + self_param = Some((node, it)); None } }) @@ -109,7 +110,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option let vars_defined_in_body = vars_defined_in_body(&body, ctx); - let vars_in_body_used_afterwards: Vec<_> = vars_defined_in_body + let vars_defined_in_body_and_outlive: Vec<_> = vars_defined_in_body .iter() .copied() .filter(|node| { @@ -123,20 +124,27 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option }) .collect(); - let params = param_pats + let params: Vec<_> = param_pats .into_iter() - .map(|pat| { - let name = pat.name().unwrap().to_string(); - - let ty = ctx - .sema - .type_of_pat(&pat.into()) - .and_then(|ty| ty.display_source_code(ctx.db(), module.into()).ok()) - .unwrap_or_else(|| "()".to_string()); + .map(|node| { + let usages = Definition::Local(node) + .usages(&ctx.sema) + .in_scope(SearchScope::single_file(ctx.frange.file_id)) + .all(); - Param { name, ty } + let has_usages_afterwards = usages + .iter() + .flat_map(|(_, rs)| rs.iter()) + .any(|reference| body.preceedes_range(reference.range)); + let has_mut_inside_body = usages + .iter() + .flat_map(|(_, rs)| rs.iter()) + .filter(|reference| body.contains_range(reference.range)) + .any(|reference| reference.access == Some(ReferenceAccess::Write)); + + Param { node, has_usages_afterwards, has_mut_inside_body, is_copy: true } }) - .collect::>(); + .collect(); let expr = body.tail_expr(); let ret_ty = match expr { @@ -145,7 +153,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option }; let has_unit_ret = ret_ty.as_ref().map_or(true, |it| it.is_unit()); - if stdx::never!(!vars_in_body_used_afterwards.is_empty() && !has_unit_ret) { + if stdx::never!(!vars_defined_in_body_and_outlive.is_empty() && !has_unit_ret) { // We should not have variables that outlive body if we have expression block return None; } @@ -162,11 +170,11 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option move |builder| { let fun = Function { name: "fun_name".to_string(), - self_param, + self_param: self_param.map(|(_, pat)| pat), params, ret_ty, body, - vars_in_body_used_afterwards, + vars_defined_in_body_and_outlive, }; builder.replace(target_range, format_replacement(ctx, &fun)); @@ -183,17 +191,13 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option fn format_replacement(ctx: &AssistContext, fun: &Function) -> String { let mut buf = String::new(); - match fun.vars_in_body_used_afterwards.len() { - 0 => {} - 1 => format_to!( - buf, - "let {} = ", - fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap() - ), - _ => { + match fun.vars_defined_in_body_and_outlive.as_slice() { + [] => {} + [var] => format_to!(buf, "let {} = ", var.name(ctx.db()).unwrap()), + [v0, vs @ ..] => { buf.push_str("let ("); - format_to!(buf, "{}", fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap()); - for local in fun.vars_in_body_used_afterwards.iter().skip(1) { + format_to!(buf, "{}", v0.name(ctx.db()).unwrap()); + for local in vs { format_to!(buf, ", {}", local.name(ctx.db()).unwrap()); } buf.push_str(") = "); @@ -207,10 +211,10 @@ fn format_replacement(ctx: &AssistContext, fun: &Function) -> String { { let mut it = fun.params.iter(); if let Some(param) = it.next() { - format_to!(buf, "{}", param.name); + format_to!(buf, "{}{}", param.value_prefix(), param.node.name(ctx.db()).unwrap()); } for param in it { - format_to!(buf, ", {}", param.name); + format_to!(buf, ", {}{}", param.value_prefix(), param.node.name(ctx.db()).unwrap()); } } format_to!(buf, ")"); @@ -228,7 +232,7 @@ struct Function { params: Vec, ret_ty: Option, body: FunctionBody, - vars_in_body_used_afterwards: Vec, + vars_defined_in_body_and_outlive: Vec, } impl Function { @@ -242,8 +246,60 @@ impl Function { #[derive(Debug)] struct Param { - name: String, - ty: String, + node: Local, + has_usages_afterwards: bool, + has_mut_inside_body: bool, + is_copy: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ParamKind { + Value, + MutValue, + SharedRef, + MutRef, +} + +impl ParamKind { + fn is_ref(&self) -> bool { + matches!(self, ParamKind::SharedRef | ParamKind::MutRef) + } +} + +impl Param { + fn kind(&self) -> ParamKind { + match (self.has_usages_afterwards, self.has_mut_inside_body, self.is_copy) { + (true, true, _) => ParamKind::MutRef, + (true, false, false) => ParamKind::SharedRef, + (false, true, _) => ParamKind::MutValue, + (true, false, true) | (false, false, _) => ParamKind::Value, + } + } + + fn value_prefix(&self) -> &'static str { + match self.kind() { + ParamKind::Value => "", + ParamKind::MutValue => "", + ParamKind::SharedRef => "&", + ParamKind::MutRef => "&mut ", + } + } + + fn type_prefix(&self) -> &'static str { + match self.kind() { + ParamKind::Value => "", + ParamKind::MutValue => "", + ParamKind::SharedRef => "&", + ParamKind::MutRef => "&mut ", + } + } + + fn mut_pattern(&self) -> &'static str { + match self.kind() { + ParamKind::MutValue => "mut ", + _ => "", + } + } } fn format_function( @@ -259,10 +315,24 @@ fn format_function( if let Some(self_param) = &fun.self_param { format_to!(fn_def, "{}", self_param); } else if let Some(param) = it.next() { - format_to!(fn_def, "{}: {}", param.name, param.ty); + format_to!( + fn_def, + "{}{}: {}{}", + param.mut_pattern(), + param.node.name(ctx.db()).unwrap(), + param.type_prefix(), + format_type(¶m.node.ty(ctx.db()), ctx, module) + ); } for param in it { - format_to!(fn_def, ", {}: {}", param.name, param.ty); + format_to!( + fn_def, + ", {}{}: {}{}", + param.mut_pattern(), + param.node.name(ctx.db()).unwrap(), + param.type_prefix(), + format_type(¶m.node.ty(ctx.db()), ctx, module) + ); } } @@ -272,7 +342,7 @@ fn format_function( format_to!(fn_def, " -> {}", format_type(ty, ctx, module)); } } else { - match fun.vars_in_body_used_afterwards.as_slice() { + match fun.vars_defined_in_body_and_outlive.as_slice() { [] => {} [var] => { format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module)); @@ -292,13 +362,21 @@ fn format_function( FunctionBody::Expr(expr) => { fn_def.push('\n'); let expr = expr.indent(indent); - format_to!(fn_def, "{}{}", indent + 1, expr.syntax()); + let expr = fix_param_usages(ctx, &fun.params, expr.syntax()); + format_to!(fn_def, "{}{}", indent + 1, expr); fn_def.push('\n'); } FunctionBody::Span { elements, leading_indent } => { format_to!(fn_def, "{}", leading_indent); - for e in elements { - format_to!(fn_def, "{}", e); + for element in elements { + match element { + syntax::NodeOrToken::Node(node) => { + format_to!(fn_def, "{}", fix_param_usages(ctx, &fun.params, node)); + } + syntax::NodeOrToken::Token(token) => { + format_to!(fn_def, "{}", token); + } + } } if !fn_def.ends_with('\n') { fn_def.push('\n'); @@ -306,7 +384,7 @@ fn format_function( } } - match fun.vars_in_body_used_afterwards.as_slice() { + match fun.vars_defined_in_body_and_outlive.as_slice() { [] => {} [var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()), [v0, vs @ ..] => { @@ -327,6 +405,61 @@ fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> Stri ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string()) } +fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode { + let mut rewriter = SyntaxRewriter::default(); + for param in params { + if !param.kind().is_ref() { + continue; + } + + let usages = Definition::Local(param.node) + .usages(&ctx.sema) + .in_scope(SearchScope::single_file(ctx.frange.file_id)) + .all(); + let usages = usages + .iter() + .flat_map(|(_, rs)| rs.iter()) + .filter(|reference| syntax.text_range().contains_range(reference.range)); + for reference in usages { + let token = match syntax.token_at_offset(reference.range.start()).right_biased() { + Some(a) => a, + None => { + stdx::never!(false, "cannot find token at variable usage: {:?}", reference); + continue; + } + }; + let path = match token.ancestors().find_map(ast::Expr::cast) { + Some(n) => n, + None => { + stdx::never!(false, "cannot find path parent of variable usage: {:?}", token); + continue; + } + }; + stdx::always!(matches!(path, ast::Expr::PathExpr(_))); + match path.syntax().ancestors().skip(1).find_map(ast::Expr::cast) { + Some(ast::Expr::MethodCallExpr(_)) => { + // do nothing + } + Some(ast::Expr::RefExpr(node)) + if param.kind() == ParamKind::MutRef && node.mut_token().is_some() => + { + rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap()); + } + Some(ast::Expr::RefExpr(node)) + if param.kind() == ParamKind::SharedRef && node.mut_token().is_none() => + { + rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap()); + } + Some(_) | None => { + rewriter.replace_ast(&path, &ast::make::expr_prefix(T![*], path.clone())); + } + }; + } + } + + rewriter.rewrite(syntax) +} + #[derive(Debug)] enum FunctionBody { Expr(ast::Expr), @@ -1112,6 +1245,164 @@ fn $0fun_name(n: i32) -> (i32, i32) { let k = n * n; let m = k + 2; (k, m) +}", + ); + } + + #[test] + fn mut_var_from_outer_scope() { + check_assist( + extract_function, + r" +fn foo() { + let mut n = 1; + $0n += 1;$0 + let m = n + 1; +}", + r" +fn foo() { + let mut n = 1; + fun_name(&mut n); + let m = n + 1; +} + +fn $0fun_name(n: &mut i32) { + *n += 1; +}", + ); + } + + #[test] + fn mut_param_many_usages_stmt() { + check_assist( + extract_function, + r" +fn bar(k: i32) {} +trait I: Copy { + fn succ(&self) -> Self; + fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v } +} +impl I for i32 { + fn succ(&self) -> Self { *self + 1 } +} +fn foo() { + let mut n = 1; + $0n += n; + bar(n); + bar(n+1); + bar(n*n); + bar(&n); + n.inc(); + let v = &mut n; + *v = v.succ(); + n.succ();$0 + let m = n + 1; +}", + r" +fn bar(k: i32) {} +trait I: Copy { + fn succ(&self) -> Self; + fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v } +} +impl I for i32 { + fn succ(&self) -> Self { *self + 1 } +} +fn foo() { + let mut n = 1; + fun_name(&mut n); + let m = n + 1; +} + +fn $0fun_name(n: &mut i32) { + *n += *n; + bar(*n); + bar(*n+1); + bar(*n**n); + bar(&*n); + n.inc(); + let v = n; + *v = v.succ(); + n.succ(); +}", + ); + } + + #[test] + fn mut_param_many_usages_expr() { + check_assist( + extract_function, + r" +fn bar(k: i32) {} +trait I: Copy { + fn succ(&self) -> Self; + fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v } +} +impl I for i32 { + fn succ(&self) -> Self { *self + 1 } +} +fn foo() { + let mut n = 1; + $0{ + n += n; + bar(n); + bar(n+1); + bar(n*n); + bar(&n); + n.inc(); + let v = &mut n; + *v = v.succ(); + n.succ(); + }$0 + let m = n + 1; +}", + r" +fn bar(k: i32) {} +trait I: Copy { + fn succ(&self) -> Self; + fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v } +} +impl I for i32 { + fn succ(&self) -> Self { *self + 1 } +} +fn foo() { + let mut n = 1; + fun_name(&mut n); + let m = n + 1; +} + +fn $0fun_name(n: &mut i32) { + { + *n += *n; + bar(*n); + bar(*n+1); + bar(*n**n); + bar(&*n); + n.inc(); + let v = n; + *v = v.succ(); + n.succ(); + } +}", + ); + } + + #[test] + fn mut_param_by_value() { + check_assist( + extract_function, + r" +fn foo() { + let mut n = 1; + $0n += 1;$0 +}", + r" +fn foo() { + let mut n = 1; + fun_name(n); +} + +fn $0fun_name(mut n: i32) { + n += 1; }", ); } -- cgit v1.2.3 From 86ff1d4809b978f821f4339a200c9ca0f13e422e Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Thu, 4 Feb 2021 00:27:31 +0300 Subject: allow `&mut param` when extracting function Recognise &mut as variable modification. This allows extracting functions with `&mut var` with `var` being in outer scope --- crates/assists/src/handlers/extract_function.rs | 110 +++++++++++++++++++++++- 1 file changed, 107 insertions(+), 3 deletions(-) (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index ffa8bd77d..a4b23d756 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -2,7 +2,7 @@ use either::Either; use hir::{HirDisplay, Local}; use ide_db::{ defs::{Definition, NameRefClass}, - search::{ReferenceAccess, SearchScope}, + search::{FileReference, ReferenceAccess, SearchScope}, }; use itertools::Itertools; use stdx::format_to; @@ -15,7 +15,7 @@ use syntax::{ }, Direction, SyntaxElement, SyntaxKind::{self, BLOCK_EXPR, BREAK_EXPR, COMMENT, PATH_EXPR, RETURN_EXPR}, - SyntaxNode, TextRange, T, + SyntaxNode, SyntaxToken, TextRange, TextSize, TokenAtOffset, T, }; use test_utils::mark; @@ -140,7 +140,18 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option .iter() .flat_map(|(_, rs)| rs.iter()) .filter(|reference| body.contains_range(reference.range)) - .any(|reference| reference.access == Some(ReferenceAccess::Write)); + .any(|reference| { + if reference.access == Some(ReferenceAccess::Write) { + return true; + } + + let path = path_at_offset(&body, reference); + if is_mut_ref_expr(path.as_ref()).unwrap_or(false) { + return true; + } + + false + }); Param { node, has_usages_afterwards, has_mut_inside_body, is_copy: true } }) @@ -405,6 +416,19 @@ fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> Stri ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string()) } +fn path_at_offset(body: &FunctionBody, reference: &FileReference) -> Option { + let var = body.token_at_offset(reference.range.start()).right_biased()?; + let path = var.ancestors().find_map(ast::Expr::cast)?; + stdx::always!(matches!(path, ast::Expr::PathExpr(_))); + Some(path) +} + +fn is_mut_ref_expr(path: Option<&ast::Expr>) -> Option { + let path = path?; + let ref_expr = path.syntax().parent().and_then(ast::RefExpr::cast)?; + Some(ref_expr.mut_token().is_some()) +} + fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode { let mut rewriter = SyntaxRewriter::default(); for param in params { @@ -551,6 +575,38 @@ impl FunctionBody { } } + fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset { + match self { + FunctionBody::Expr(expr) => expr.syntax().token_at_offset(offset), + FunctionBody::Span { elements, .. } => { + stdx::always!(self.text_range().contains(offset)); + let mut iter = elements + .iter() + .filter(|element| element.text_range().contains_inclusive(offset)); + let element1 = iter.next().expect("offset does not fall into body"); + let element2 = iter.next(); + stdx::always!(iter.next().is_none(), "> 2 tokens at offset"); + let t1 = match element1 { + syntax::NodeOrToken::Node(node) => node.token_at_offset(offset), + syntax::NodeOrToken::Token(token) => TokenAtOffset::Single(token.clone()), + }; + let t2 = element2.map(|e| match e { + syntax::NodeOrToken::Node(node) => node.token_at_offset(offset), + syntax::NodeOrToken::Token(token) => TokenAtOffset::Single(token.clone()), + }); + + match t2 { + Some(t2) => match (t1.clone().right_biased(), t2.clone().left_biased()) { + (Some(e1), Some(e2)) => TokenAtOffset::Between(e1, e2), + (Some(_), None) => t1, + (None, _) => t2, + }, + None => t1, + } + } + } + } + fn text_range(&self) -> TextRange { match self { FunctionBody::Expr(expr) => expr.syntax().text_range(), @@ -1403,6 +1459,54 @@ fn foo() { fn $0fun_name(mut n: i32) { n += 1; +}", + ); + } + + #[test] + fn mut_param_because_of_mut_ref() { + check_assist( + extract_function, + r" +fn foo() { + let mut n = 1; + $0let v = &mut n; + *v += 1;$0 + let k = n; +}", + r" +fn foo() { + let mut n = 1; + fun_name(&mut n); + let k = n; +} + +fn $0fun_name(n: &mut i32) { + let v = n; + *v += 1; +}", + ); + } + + #[test] + fn mut_param_by_value_because_of_mut_ref() { + check_assist( + extract_function, + r" +fn foo() { + let mut n = 1; + $0let v = &mut n; + *v += 1;$0 +}", + r" +fn foo() { + let mut n = 1; + fun_name(n); +} + +fn $0fun_name(mut n: i32) { + let v = &mut n; + *v += 1; }", ); } -- cgit v1.2.3 From c4f3669e70c6b7e4bafa03f41ad29a3de46f80ad Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Thu, 4 Feb 2021 00:44:36 +0300 Subject: allow calling `&mut` methods on outer vars when extracing function --- crates/assists/src/handlers/extract_function.rs | 116 ++++++++++++++++++++++++ 1 file changed, 116 insertions(+) (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index a4b23d756..8a4073886 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -150,6 +150,10 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option return true; } + if is_mut_method_call(ctx, path.as_ref()).unwrap_or(false) { + return true; + } + false }); @@ -429,6 +433,17 @@ fn is_mut_ref_expr(path: Option<&ast::Expr>) -> Option { Some(ref_expr.mut_token().is_some()) } +fn is_mut_method_call(ctx: &AssistContext, path: Option<&ast::Expr>) -> Option { + let path = path?; + let method_call = path.syntax().parent().and_then(ast::MethodCallExpr::cast)?; + + let func = ctx.sema.resolve_method_call(&method_call)?; + let self_param = func.self_param(ctx.db())?; + let access = self_param.access(ctx.db()); + + Some(matches!(access, hir::Access::Exclusive)) +} + fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode { let mut rewriter = SyntaxRewriter::default(); for param in params { @@ -1507,6 +1522,107 @@ fn foo() { fn $0fun_name(mut n: i32) { let v = &mut n; *v += 1; +}", + ); + } + + #[test] + fn mut_method_call() { + check_assist( + extract_function, + r" +trait I { + fn inc(&mut self); +} +impl I for i32 { + fn inc(&mut self) { *self += 1 } +} +fn foo() { + let mut n = 1; + $0n.inc();$0 +}", + r" +trait I { + fn inc(&mut self); +} +impl I for i32 { + fn inc(&mut self) { *self += 1 } +} +fn foo() { + let mut n = 1; + fun_name(n); +} + +fn $0fun_name(mut n: i32) { + n.inc(); +}", + ); + } + + #[test] + fn shared_method_call() { + check_assist( + extract_function, + r" +trait I { + fn succ(&self); +} +impl I for i32 { + fn succ(&self) { *self + 1 } +} +fn foo() { + let mut n = 1; + $0n.succ();$0 +}", + r" +trait I { + fn succ(&self); +} +impl I for i32 { + fn succ(&self) { *self + 1 } +} +fn foo() { + let mut n = 1; + fun_name(n); +} + +fn $0fun_name(n: i32) { + n.succ(); +}", + ); + } + + #[test] + fn mut_method_call_with_other_receiver() { + check_assist( + extract_function, + r" +trait I { + fn inc(&mut self, n: i32); +} +impl I for i32 { + fn inc(&mut self, n: i32) { *self += n } +} +fn foo() { + let mut n = 1; + $0let mut m = 2; + m.inc(n);$0 +}", + r" +trait I { + fn inc(&mut self, n: i32); +} +impl I for i32 { + fn inc(&mut self, n: i32) { *self += n } +} +fn foo() { + let mut n = 1; + fun_name(n); +} + +fn $0fun_name(n: i32) { + let mut m = 2; + m.inc(n); }", ); } -- cgit v1.2.3 From ff77c5e68fefcf525c2aa449cff5e0c52e7d3a0d Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Thu, 4 Feb 2021 00:52:53 +0300 Subject: remove ignored test for downgrading mut to shared --- crates/assists/src/handlers/extract_function.rs | 30 ------------------------- 1 file changed, 30 deletions(-) (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index 8a4073886..dce7ffd7b 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -1240,36 +1240,6 @@ impl S { ); } - // it is unclear if this is wanted behaviour - // and how this behavour can be implemented - #[ignore] - #[test] - fn method_with_mut_downgrade_to_shared() { - check_assist( - extract_function, - r" -struct S { f: i32 }; - -impl S { - fn foo(&mut self) -> i32 { - $01+self.f$0 - } -}", - r" -struct S { f: i32 }; - -impl S { - fn foo(&mut self) -> i32 { - self.fun_name() - } - - fn $0fun_name(&self) -> i32 { - 1+self.f - } -}", - ); - } - #[test] fn variable_defined_inside_and_used_after_no_ret() { check_assist( -- cgit v1.2.3 From d9b122858b8a454c23dbeab0971571ce0b38aeec Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Fri, 5 Feb 2021 00:35:28 +0300 Subject: split extract_function into pieces and order them --- crates/assists/src/handlers/extract_function.rs | 892 ++++++++++++++---------- 1 file changed, 511 insertions(+), 381 deletions(-) (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index dce7ffd7b..93ff66b24 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -60,115 +60,21 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option return None; } - let node = match node { - syntax::NodeOrToken::Node(n) => n, - syntax::NodeOrToken::Token(t) => t.parent(), - }; + let node = element_to_node(node); - let mut body = None; - if node.text_range() == ctx.frange.range { - body = FunctionBody::from_whole_node(node.clone()); - } - if body.is_none() && node.kind() == BLOCK_EXPR { - body = FunctionBody::from_range(&node, ctx.frange.range); - } - if let Some(parent) = node.parent() { - if body.is_none() && parent.kind() == BLOCK_EXPR { - body = FunctionBody::from_range(&parent, ctx.frange.range); - } - } - if body.is_none() { - body = FunctionBody::from_whole_node(node.clone()); - } - if body.is_none() { - body = node.ancestors().find_map(FunctionBody::from_whole_node); - } - let body = body?; + let body = extraction_target(&node, ctx.frange.range)?; let vars_used_in_body = vars_used_in_body(&body, &ctx); - let mut self_param = None; - let param_pats: Vec<_> = vars_used_in_body - .iter() - .map(|node| (node, node.source(ctx.db()))) - .filter(|(_, src)| { - src.file_id.original_file(ctx.db()) == ctx.frange.file_id - && !body.contains_node(&either_syntax(&src.value)) - }) - .filter_map(|(&node, src)| match src.value { - Either::Left(_) => Some(node), - Either::Right(it) => { - // we filter self param, as there can only be one - self_param = Some((node, it)); - None - } - }) - .collect(); + let self_param = self_param_from_usages(ctx, &body, &vars_used_in_body); let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding }; - let insert_after = body.scope_for_fn_insertion(anchor)?; + let insert_after = scope_for_fn_insertion(&body, anchor)?; let module = ctx.sema.scope(&insert_after).module()?; - let vars_defined_in_body = vars_defined_in_body(&body, ctx); - - let vars_defined_in_body_and_outlive: Vec<_> = vars_defined_in_body - .iter() - .copied() - .filter(|node| { - let usages = Definition::Local(*node) - .usages(&ctx.sema) - .in_scope(SearchScope::single_file(ctx.frange.file_id)) - .all(); - let mut usages = usages.iter().flat_map(|(_, rs)| rs.iter()); - - usages.any(|reference| body.preceedes_range(reference.range)) - }) - .collect(); - - let params: Vec<_> = param_pats - .into_iter() - .map(|node| { - let usages = Definition::Local(node) - .usages(&ctx.sema) - .in_scope(SearchScope::single_file(ctx.frange.file_id)) - .all(); - - let has_usages_afterwards = usages - .iter() - .flat_map(|(_, rs)| rs.iter()) - .any(|reference| body.preceedes_range(reference.range)); - let has_mut_inside_body = usages - .iter() - .flat_map(|(_, rs)| rs.iter()) - .filter(|reference| body.contains_range(reference.range)) - .any(|reference| { - if reference.access == Some(ReferenceAccess::Write) { - return true; - } - - let path = path_at_offset(&body, reference); - if is_mut_ref_expr(path.as_ref()).unwrap_or(false) { - return true; - } - - if is_mut_method_call(ctx, path.as_ref()).unwrap_or(false) { - return true; - } - - false - }); - - Param { node, has_usages_afterwards, has_mut_inside_body, is_copy: true } - }) - .collect(); - - let expr = body.tail_expr(); - let ret_ty = match expr { - Some(expr) => Some(ctx.sema.type_of_expr(&expr)?), - None => None, - }; + let vars_defined_in_body_and_outlive = vars_defined_in_body_and_outlive(ctx, &body); + let ret_ty = body_return_ty(ctx, &body)?; - let has_unit_ret = ret_ty.as_ref().map_or(true, |it| it.is_unit()); - if stdx::never!(!vars_defined_in_body_and_outlive.is_empty() && !has_unit_ret) { + if stdx::never!(!vars_defined_in_body_and_outlive.is_empty() && !ret_ty.is_unit()) { // We should not have variables that outlive body if we have expression block return None; } @@ -183,6 +89,8 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option "Extract into function", target_range, move |builder| { + let params = extracted_function_params(ctx, &body, &vars_used_in_body); + let fun = Function { name: "fun_name".to_string(), self_param: self_param.map(|(_, pat)| pat), @@ -203,65 +111,19 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option ) } -fn format_replacement(ctx: &AssistContext, fun: &Function) -> String { - let mut buf = String::new(); - - match fun.vars_defined_in_body_and_outlive.as_slice() { - [] => {} - [var] => format_to!(buf, "let {} = ", var.name(ctx.db()).unwrap()), - [v0, vs @ ..] => { - buf.push_str("let ("); - format_to!(buf, "{}", v0.name(ctx.db()).unwrap()); - for local in vs { - format_to!(buf, ", {}", local.name(ctx.db()).unwrap()); - } - buf.push_str(") = "); - } - } - - if fun.self_param.is_some() { - format_to!(buf, "self."); - } - format_to!(buf, "{}(", fun.name); - { - let mut it = fun.params.iter(); - if let Some(param) = it.next() { - format_to!(buf, "{}{}", param.value_prefix(), param.node.name(ctx.db()).unwrap()); - } - for param in it { - format_to!(buf, ", {}{}", param.value_prefix(), param.node.name(ctx.db()).unwrap()); - } - } - format_to!(buf, ")"); - - if fun.has_unit_ret() { - format_to!(buf, ";"); - } - - buf -} - +#[derive(Debug)] struct Function { name: String, self_param: Option, params: Vec, - ret_ty: Option, + ret_ty: RetType, body: FunctionBody, vars_defined_in_body_and_outlive: Vec, } -impl Function { - fn has_unit_ret(&self) -> bool { - match &self.ret_ty { - Some(ty) => ty.is_unit(), - None => true, - } - } -} - #[derive(Debug)] struct Param { - node: Local, + var: Local, has_usages_afterwards: bool, has_mut_inside_body: bool, is_copy: bool, @@ -293,8 +155,7 @@ impl Param { fn value_prefix(&self) -> &'static str { match self.kind() { - ParamKind::Value => "", - ParamKind::MutValue => "", + ParamKind::Value | ParamKind::MutValue => "", ParamKind::SharedRef => "&", ParamKind::MutRef => "&mut ", } @@ -302,8 +163,7 @@ impl Param { fn type_prefix(&self) -> &'static str { match self.kind() { - ParamKind::Value => "", - ParamKind::MutValue => "", + ParamKind::Value | ParamKind::MutValue => "", ParamKind::SharedRef => "&", ParamKind::MutRef => "&mut ", } @@ -317,186 +177,27 @@ impl Param { } } -fn format_function( - ctx: &AssistContext, - module: hir::Module, - fun: &Function, - indent: IndentLevel, -) -> String { - let mut fn_def = String::new(); - format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name); - { - let mut it = fun.params.iter(); - if let Some(self_param) = &fun.self_param { - format_to!(fn_def, "{}", self_param); - } else if let Some(param) = it.next() { - format_to!( - fn_def, - "{}{}: {}{}", - param.mut_pattern(), - param.node.name(ctx.db()).unwrap(), - param.type_prefix(), - format_type(¶m.node.ty(ctx.db()), ctx, module) - ); - } - for param in it { - format_to!( - fn_def, - ", {}{}: {}{}", - param.mut_pattern(), - param.node.name(ctx.db()).unwrap(), - param.type_prefix(), - format_type(¶m.node.ty(ctx.db()), ctx, module) - ); - } - } - - format_to!(fn_def, ")"); - if !fun.has_unit_ret() { - if let Some(ty) = &fun.ret_ty { - format_to!(fn_def, " -> {}", format_type(ty, ctx, module)); - } - } else { - match fun.vars_defined_in_body_and_outlive.as_slice() { - [] => {} - [var] => { - format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module)); - } - [v0, vs @ ..] => { - format_to!(fn_def, " -> ({}", format_type(&v0.ty(ctx.db()), ctx, module)); - for var in vs { - format_to!(fn_def, ", {}", format_type(&var.ty(ctx.db()), ctx, module)); - } - fn_def.push(')'); - } - } - } - fn_def.push_str(" {"); - - match &fun.body { - FunctionBody::Expr(expr) => { - fn_def.push('\n'); - let expr = expr.indent(indent); - let expr = fix_param_usages(ctx, &fun.params, expr.syntax()); - format_to!(fn_def, "{}{}", indent + 1, expr); - fn_def.push('\n'); - } - FunctionBody::Span { elements, leading_indent } => { - format_to!(fn_def, "{}", leading_indent); - for element in elements { - match element { - syntax::NodeOrToken::Node(node) => { - format_to!(fn_def, "{}", fix_param_usages(ctx, &fun.params, node)); - } - syntax::NodeOrToken::Token(token) => { - format_to!(fn_def, "{}", token); - } - } - } - if !fn_def.ends_with('\n') { - fn_def.push('\n'); - } - } - } - - match fun.vars_defined_in_body_and_outlive.as_slice() { - [] => {} - [var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()), - [v0, vs @ ..] => { - format_to!(fn_def, "{}({}", indent + 1, v0.name(ctx.db()).unwrap()); - for var in vs { - format_to!(fn_def, ", {}", var.name(ctx.db()).unwrap()); - } - fn_def.push_str(")\n"); - } - } - - format_to!(fn_def, "{}}}", indent); - - fn_def -} - -fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> String { - ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string()) -} - -fn path_at_offset(body: &FunctionBody, reference: &FileReference) -> Option { - let var = body.token_at_offset(reference.range.start()).right_biased()?; - let path = var.ancestors().find_map(ast::Expr::cast)?; - stdx::always!(matches!(path, ast::Expr::PathExpr(_))); - Some(path) -} - -fn is_mut_ref_expr(path: Option<&ast::Expr>) -> Option { - let path = path?; - let ref_expr = path.syntax().parent().and_then(ast::RefExpr::cast)?; - Some(ref_expr.mut_token().is_some()) -} - -fn is_mut_method_call(ctx: &AssistContext, path: Option<&ast::Expr>) -> Option { - let path = path?; - let method_call = path.syntax().parent().and_then(ast::MethodCallExpr::cast)?; - - let func = ctx.sema.resolve_method_call(&method_call)?; - let self_param = func.self_param(ctx.db())?; - let access = self_param.access(ctx.db()); - - Some(matches!(access, hir::Access::Exclusive)) +#[derive(Debug)] +enum RetType { + Expr(hir::Type), + Stmt, } -fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode { - let mut rewriter = SyntaxRewriter::default(); - for param in params { - if !param.kind().is_ref() { - continue; +impl RetType { + fn is_unit(&self) -> bool { + match self { + RetType::Expr(ty) => ty.is_unit(), + RetType::Stmt => true, } + } - let usages = Definition::Local(param.node) - .usages(&ctx.sema) - .in_scope(SearchScope::single_file(ctx.frange.file_id)) - .all(); - let usages = usages - .iter() - .flat_map(|(_, rs)| rs.iter()) - .filter(|reference| syntax.text_range().contains_range(reference.range)); - for reference in usages { - let token = match syntax.token_at_offset(reference.range.start()).right_biased() { - Some(a) => a, - None => { - stdx::never!(false, "cannot find token at variable usage: {:?}", reference); - continue; - } - }; - let path = match token.ancestors().find_map(ast::Expr::cast) { - Some(n) => n, - None => { - stdx::never!(false, "cannot find path parent of variable usage: {:?}", token); - continue; - } - }; - stdx::always!(matches!(path, ast::Expr::PathExpr(_))); - match path.syntax().ancestors().skip(1).find_map(ast::Expr::cast) { - Some(ast::Expr::MethodCallExpr(_)) => { - // do nothing - } - Some(ast::Expr::RefExpr(node)) - if param.kind() == ParamKind::MutRef && node.mut_token().is_some() => - { - rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap()); - } - Some(ast::Expr::RefExpr(node)) - if param.kind() == ParamKind::SharedRef && node.mut_token().is_none() => - { - rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap()); - } - Some(_) | None => { - rewriter.replace_ast(&path, &ast::make::expr_prefix(T![*], path.clone())); - } - }; + fn as_fn_ret(&self) -> Option<&hir::Type> { + match self { + RetType::Stmt => None, + RetType::Expr(ty) if ty.is_unit() => None, + RetType::Expr(ty) => Some(ty), } } - - rewriter.rewrite(syntax) } #[derive(Debug)] @@ -505,11 +206,6 @@ enum FunctionBody { Span { elements: Vec, leading_indent: String }, } -enum Anchor { - Freestanding, - Method, -} - impl FunctionBody { fn from_whole_node(node: SyntaxNode) -> Option { match node.kind() { @@ -568,16 +264,6 @@ impl FunctionBody { } } - fn scope_for_fn_insertion(&self, anchor: Anchor) -> Option { - match self { - FunctionBody::Expr(e) => scope_for_fn_insertion(e.syntax(), anchor), - FunctionBody::Span { elements, .. } => { - let node = elements.iter().find_map(|e| e.as_node())?; - scope_for_fn_insertion(&node, anchor) - } - } - } - fn descendants(&self) -> impl Iterator + '_ { match self { FunctionBody::Expr(expr) => Either::Right(expr.syntax().descendants()), @@ -590,6 +276,30 @@ impl FunctionBody { } } + fn text_range(&self) -> TextRange { + match self { + FunctionBody::Expr(expr) => expr.syntax().text_range(), + FunctionBody::Span { elements, .. } => TextRange::new( + elements.first().unwrap().text_range().start(), + elements.last().unwrap().text_range().end(), + ), + } + } + + fn contains_range(&self, range: TextRange) -> bool { + self.text_range().contains_range(range) + } + + fn preceedes_range(&self, range: TextRange) -> bool { + self.text_range().end() <= range.start() + } + + fn contains_node(&self, node: &SyntaxNode) -> bool { + self.contains_range(node.text_range()) + } +} + +impl HasTokenAtOffset for FunctionBody { fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset { match self { FunctionBody::Expr(expr) => expr.syntax().token_at_offset(offset), @@ -621,31 +331,278 @@ impl FunctionBody { } } } +} - fn text_range(&self) -> TextRange { - match self { - FunctionBody::Expr(expr) => expr.syntax().text_range(), - FunctionBody::Span { elements, .. } => TextRange::new( - elements.first().unwrap().text_range().start(), - elements.last().unwrap().text_range().end(), - ), - } +fn element_to_node(node: SyntaxElement) -> SyntaxNode { + match node { + syntax::NodeOrToken::Node(n) => n, + syntax::NodeOrToken::Token(t) => t.parent(), } +} - fn contains_range(&self, range: TextRange) -> bool { - self.text_range().contains_range(range) +fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option { + if node.text_range() == selection_range { + let body = FunctionBody::from_whole_node(node.clone()); + if body.is_some() { + return body; + } } - fn preceedes_range(&self, range: TextRange) -> bool { - self.text_range().end() <= range.start() + if node.kind() == BLOCK_EXPR { + let body = FunctionBody::from_range(&node, selection_range); + if body.is_some() { + return body; + } + } + if let Some(parent) = node.parent() { + if parent.kind() == BLOCK_EXPR { + let body = FunctionBody::from_range(&parent, selection_range); + if body.is_some() { + return body; + } + } } - fn contains_node(&self, node: &SyntaxNode) -> bool { - self.contains_range(node.text_range()) + let body = FunctionBody::from_whole_node(node.clone()); + if body.is_some() { + return body; + } + + let body = node.ancestors().find_map(FunctionBody::from_whole_node); + if body.is_some() { + return body; + } + + None +} + +/// Returns a vector of local variables that are referenced in `body` +fn vars_used_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec { + body.descendants() + .filter_map(ast::NameRef::cast) + .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref)) + .map(|name_kind| name_kind.referenced(ctx.db())) + .filter_map(|definition| match definition { + Definition::Local(local) => Some(local), + _ => None, + }) + .unique() + .collect() +} + +fn self_param_from_usages( + ctx: &AssistContext, + body: &FunctionBody, + vars_used_in_body: &[Local], +) -> Option<(Local, ast::SelfParam)> { + let mut iter = vars_used_in_body + .iter() + .filter(|var| var.is_self(ctx.db())) + .map(|var| (var, var.source(ctx.db()))) + .filter(|(_, src)| is_defined_before(ctx, body, src)) + .filter_map(|(&node, src)| match src.value { + Either::Right(it) => Some((node, it)), + Either::Left(_) => { + stdx::never!(false, "Local::is_self returned true, but source is IdentPat"); + None + } + }); + + let self_param = iter.next(); + stdx::always!( + iter.next().is_none(), + "body references two different self params both defined outside" + ); + + self_param +} + +fn extracted_function_params( + ctx: &AssistContext, + body: &FunctionBody, + vars_used_in_body: &[Local], +) -> Vec { + vars_used_in_body + .iter() + .filter(|var| !var.is_self(ctx.db())) + .map(|node| (node, node.source(ctx.db()))) + .filter(|(_, src)| is_defined_before(ctx, body, src)) + .filter_map(|(&node, src)| { + if src.value.is_left() { + Some(node) + } else { + stdx::never!(false, "Local::is_self returned false, but source is SelfParam"); + None + } + }) + .map(|var| { + let usages = LocalUsages::find(ctx, var); + Param { + var, + has_usages_afterwards: has_usages_after_body(&usages, body), + has_mut_inside_body: has_exclusive_usages(ctx, &usages, body), + is_copy: true, + } + }) + .collect() +} + +fn has_usages_after_body(usages: &LocalUsages, body: &FunctionBody) -> bool { + usages.iter().any(|reference| body.preceedes_range(reference.range)) +} + +fn has_exclusive_usages(ctx: &AssistContext, usages: &LocalUsages, body: &FunctionBody) -> bool { + usages + .iter() + .filter(|reference| body.contains_range(reference.range)) + .any(|reference| reference_is_exclusive(reference, body, ctx)) +} + +fn reference_is_exclusive( + reference: &FileReference, + body: &FunctionBody, + ctx: &AssistContext, +) -> bool { + if reference.access == Some(ReferenceAccess::Write) { + return true; + } + + let path = path_at_offset(body, reference); + if is_mut_ref_expr(path.as_ref()).unwrap_or(false) { + return true; + } + + if is_mut_method_call(ctx, path.as_ref()).unwrap_or(false) { + return true; + } + + false +} + +struct LocalUsages(ide_db::search::UsageSearchResult); + +impl LocalUsages { + fn find(ctx: &AssistContext, var: Local) -> Self { + Self( + Definition::Local(var) + .usages(&ctx.sema) + .in_scope(SearchScope::single_file(ctx.frange.file_id)) + .all(), + ) + } + + fn iter(&self) -> impl Iterator + '_ { + self.0.iter().flat_map(|(_, rs)| rs.iter()) + } +} + +trait HasTokenAtOffset { + fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset; +} + +impl HasTokenAtOffset for SyntaxNode { + fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset { + SyntaxNode::token_at_offset(&self, offset) + } +} + +fn path_at_offset(node: &dyn HasTokenAtOffset, reference: &FileReference) -> Option { + let token = node.token_at_offset(reference.range.start()).right_biased().or_else(|| { + stdx::never!(false, "cannot find token at variable usage: {:?}", reference); + None + })?; + let path = token.ancestors().find_map(ast::Expr::cast).or_else(|| { + stdx::never!(false, "cannot find path parent of variable usage: {:?}", token); + None + })?; + stdx::always!(matches!(path, ast::Expr::PathExpr(_))); + Some(path) +} + +fn is_mut_ref_expr(path: Option<&ast::Expr>) -> Option { + let path = path?; + let ref_expr = path.syntax().parent().and_then(ast::RefExpr::cast)?; + Some(ref_expr.mut_token().is_some()) +} + +fn is_mut_method_call(ctx: &AssistContext, path: Option<&ast::Expr>) -> Option { + let path = path?; + let method_call = path.syntax().parent().and_then(ast::MethodCallExpr::cast)?; + + let func = ctx.sema.resolve_method_call(&method_call)?; + let self_param = func.self_param(ctx.db())?; + let access = self_param.access(ctx.db()); + + Some(matches!(access, hir::Access::Exclusive)) +} + +/// Returns a vector of local variables that are defined in `body` +fn vars_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec { + body.descendants() + .filter_map(ast::IdentPat::cast) + .filter_map(|let_stmt| ctx.sema.to_def(&let_stmt)) + .unique() + .collect() +} + +fn vars_defined_in_body_and_outlive(ctx: &AssistContext, body: &FunctionBody) -> Vec { + let mut vars_defined_in_body = vars_defined_in_body(&body, ctx); + vars_defined_in_body.retain(|var| var_outlives_body(ctx, body, var)); + vars_defined_in_body +} + +fn is_defined_before( + ctx: &AssistContext, + body: &FunctionBody, + src: &hir::InFile>, +) -> bool { + src.file_id.original_file(ctx.db()) == ctx.frange.file_id + && !body.contains_node(&either_syntax(&src.value)) +} + +fn either_syntax(value: &Either) -> &SyntaxNode { + match value { + Either::Left(pat) => pat.syntax(), + Either::Right(it) => it.syntax(), + } +} + +fn var_outlives_body(ctx: &AssistContext, body: &FunctionBody, var: &Local) -> bool { + let usages = Definition::Local(*var) + .usages(&ctx.sema) + .in_scope(SearchScope::single_file(ctx.frange.file_id)) + .all(); + let mut usages = usages.iter().flat_map(|(_, rs)| rs.iter()); + + usages.any(|reference| body.preceedes_range(reference.range)) +} + +fn body_return_ty(ctx: &AssistContext, body: &FunctionBody) -> Option { + match body.tail_expr() { + Some(expr) => { + let ty = ctx.sema.type_of_expr(&expr)?; + Some(RetType::Expr(ty)) + } + None => Some(RetType::Stmt), + } +} +#[derive(Debug)] +enum Anchor { + Freestanding, + Method, +} + +fn scope_for_fn_insertion(body: &FunctionBody, anchor: Anchor) -> Option { + match body { + FunctionBody::Expr(e) => scope_for_fn_insertion_node(e.syntax(), anchor), + FunctionBody::Span { elements, .. } => { + let node = elements.iter().find_map(|e| e.as_node())?; + scope_for_fn_insertion_node(&node, anchor) + } } } -fn scope_for_fn_insertion(node: &SyntaxNode, anchor: Anchor) -> Option { +fn scope_for_fn_insertion_node(node: &SyntaxNode, anchor: Anchor) -> Option { let mut ancestors = node.ancestors().peekable(); let mut last_ancestor = None; while let Some(next_ancestor) = ancestors.next() { @@ -674,34 +631,207 @@ fn scope_for_fn_insertion(node: &SyntaxNode, anchor: Anchor) -> Option) -> &SyntaxNode { - match value { - Either::Left(pat) => pat.syntax(), - Either::Right(it) => it.syntax(), +fn format_replacement(ctx: &AssistContext, fun: &Function) -> String { + let mut buf = String::new(); + + match fun.vars_defined_in_body_and_outlive.as_slice() { + [] => {} + [var] => format_to!(buf, "let {} = ", var.name(ctx.db()).unwrap()), + [v0, vs @ ..] => { + buf.push_str("let ("); + format_to!(buf, "{}", v0.name(ctx.db()).unwrap()); + for var in vs { + format_to!(buf, ", {}", var.name(ctx.db()).unwrap()); + } + buf.push_str(") = "); + } + } + + if fun.self_param.is_some() { + format_to!(buf, "self."); + } + format_to!(buf, "{}(", fun.name); + format_arg_list_to(&mut buf, fun, ctx); + format_to!(buf, ")"); + + if fun.ret_ty.is_unit() { + format_to!(buf, ";"); } + + buf } -/// Returns a vector of local variables that are referenced in `body` -fn vars_used_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec { - body.descendants() - .filter_map(ast::NameRef::cast) - .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref)) - .map(|name_kind| name_kind.referenced(ctx.db())) - .filter_map(|definition| match definition { - Definition::Local(local) => Some(local), - _ => None, - }) - .unique() - .collect() +fn format_arg_list_to(buf: &mut String, fun: &Function, ctx: &AssistContext) { + let mut it = fun.params.iter(); + if let Some(param) = it.next() { + format_arg_to(buf, ctx, param); + } + for param in it { + buf.push_str(", "); + format_arg_to(buf, ctx, param); + } } -/// Returns a vector of local variables that are defined in `body` -fn vars_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec { - body.descendants() - .filter_map(ast::IdentPat::cast) - .filter_map(|let_stmt| ctx.sema.to_def(&let_stmt)) - .unique() - .collect() +fn format_arg_to(buf: &mut String, ctx: &AssistContext, param: &Param) { + format_to!(buf, "{}{}", param.value_prefix(), param.var.name(ctx.db()).unwrap()); +} + +fn format_function( + ctx: &AssistContext, + module: hir::Module, + fun: &Function, + indent: IndentLevel, +) -> String { + let mut fn_def = String::new(); + format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name); + format_function_param_list_to(&mut fn_def, ctx, module, fun); + fn_def.push(')'); + format_function_ret_to(&mut fn_def, ctx, module, fun); + fn_def.push_str(" {"); + format_function_body_to(&mut fn_def, ctx, indent, fun); + format_to!(fn_def, "{}}}", indent); + + fn_def +} + +fn format_function_param_list_to( + fn_def: &mut String, + ctx: &AssistContext, + module: hir::Module, + fun: &Function, +) { + let mut it = fun.params.iter(); + if let Some(self_param) = &fun.self_param { + format_to!(fn_def, "{}", self_param); + } else if let Some(param) = it.next() { + format_param_to(fn_def, ctx, module, param); + } + for param in it { + fn_def.push_str(", "); + format_param_to(fn_def, ctx, module, param); + } +} + +fn format_param_to(fn_def: &mut String, ctx: &AssistContext, module: hir::Module, param: &Param) { + format_to!( + fn_def, + "{}{}: {}{}", + param.mut_pattern(), + param.var.name(ctx.db()).unwrap(), + param.type_prefix(), + format_type(¶m.var.ty(ctx.db()), ctx, module) + ); +} + +fn format_function_ret_to( + fn_def: &mut String, + ctx: &AssistContext, + module: hir::Module, + fun: &Function, +) { + if let Some(ty) = fun.ret_ty.as_fn_ret() { + format_to!(fn_def, " -> {}", format_type(ty, ctx, module)); + } else { + match fun.vars_defined_in_body_and_outlive.as_slice() { + [] => {} + [var] => { + format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module)); + } + [v0, vs @ ..] => { + format_to!(fn_def, " -> ({}", format_type(&v0.ty(ctx.db()), ctx, module)); + for var in vs { + format_to!(fn_def, ", {}", format_type(&var.ty(ctx.db()), ctx, module)); + } + fn_def.push(')'); + } + } + } +} + +fn format_function_body_to( + fn_def: &mut String, + ctx: &AssistContext, + indent: IndentLevel, + fun: &Function, +) { + match &fun.body { + FunctionBody::Expr(expr) => { + fn_def.push('\n'); + let expr = expr.indent(indent); + let expr = fix_param_usages(ctx, &fun.params, expr.syntax()); + format_to!(fn_def, "{}{}", indent + 1, expr); + fn_def.push('\n'); + } + FunctionBody::Span { elements, leading_indent } => { + format_to!(fn_def, "{}", leading_indent); + for element in elements { + match element { + syntax::NodeOrToken::Node(node) => { + format_to!(fn_def, "{}", fix_param_usages(ctx, &fun.params, node)); + } + syntax::NodeOrToken::Token(token) => { + format_to!(fn_def, "{}", token); + } + } + } + if !fn_def.ends_with('\n') { + fn_def.push('\n'); + } + } + } + + match fun.vars_defined_in_body_and_outlive.as_slice() { + [] => {} + [var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()), + [v0, vs @ ..] => { + format_to!(fn_def, "{}({}", indent + 1, v0.name(ctx.db()).unwrap()); + for var in vs { + format_to!(fn_def, ", {}", var.name(ctx.db()).unwrap()); + } + fn_def.push_str(")\n"); + } + } +} + +fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> String { + ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string()) +} + +fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode { + let mut rewriter = SyntaxRewriter::default(); + for param in params { + if !param.kind().is_ref() { + continue; + } + + let usages = LocalUsages::find(ctx, param.var); + let usages = usages + .iter() + .filter(|reference| syntax.text_range().contains_range(reference.range)) + .filter_map(|reference| path_at_offset(syntax, reference)); + for path in usages { + match path.syntax().ancestors().skip(1).find_map(ast::Expr::cast) { + Some(ast::Expr::MethodCallExpr(_)) => { + // do nothing + } + Some(ast::Expr::RefExpr(node)) + if param.kind() == ParamKind::MutRef && node.mut_token().is_some() => + { + rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap()); + } + Some(ast::Expr::RefExpr(node)) + if param.kind() == ParamKind::SharedRef && node.mut_token().is_none() => + { + rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap()); + } + Some(_) | None => { + rewriter.replace_ast(&path, &ast::make::expr_prefix(T![*], path.clone())); + } + }; + } + } + + rewriter.rewrite(syntax) } #[cfg(test)] -- cgit v1.2.3 From 0ff74467c0d107a0b9472e928f9f0845f68be088 Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Fri, 5 Feb 2021 01:18:45 +0300 Subject: use `&T` for non copy params of extracted function Use shared ref if param is not `T: Copy` and is used after body --- crates/assists/src/handlers/extract_function.rs | 57 ++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 2 deletions(-) (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index 93ff66b24..ac2a5a674 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -124,6 +124,7 @@ struct Function { #[derive(Debug)] struct Param { var: Local, + ty: hir::Type, has_usages_afterwards: bool, has_mut_inside_body: bool, is_copy: bool, @@ -437,11 +438,14 @@ fn extracted_function_params( }) .map(|var| { let usages = LocalUsages::find(ctx, var); + let ty = var.ty(ctx.db()); + let is_copy = ty.is_copy(ctx.db()); Param { var, + ty, has_usages_afterwards: has_usages_after_body(&usages, body), has_mut_inside_body: has_exclusive_usages(ctx, &usages, body), - is_copy: true, + is_copy, } }) .collect() @@ -719,7 +723,7 @@ fn format_param_to(fn_def: &mut String, ctx: &AssistContext, module: hir::Module param.mut_pattern(), param.var.name(ctx.db()).unwrap(), param.type_prefix(), - format_type(¶m.var.ty(ctx.db()), ctx, module) + format_type(¶m.ty, ctx, module) ); } @@ -1723,6 +1727,55 @@ fn foo() { fn $0fun_name(n: i32) { let mut m = 2; m.inc(n); +}", + ); + } + + #[test] + fn non_copy_without_usages_after() { + check_assist( + extract_function, + r" +struct Counter(i32); +fn foo() { + let c = Counter(0); + $0let n = c.0;$0 +}", + r" +struct Counter(i32); +fn foo() { + let c = Counter(0); + fun_name(c); +} + +fn $0fun_name(c: Counter) { + let n = c.0; +}", + ); + } + + + #[test] + fn non_copy_used_after() { + check_assist( + extract_function, + r" +struct Counter(i32); +fn foo() { + let c = Counter(0); + $0let n = c.0;$0 + let m = c.0; +}", + r" +struct Counter(i32); +fn foo() { + let c = Counter(0); + fun_name(&c); + let m = c.0; +} + +fn $0fun_name(c: &Counter) { + let n = *c.0; }", ); } -- cgit v1.2.3 From 4dc2a42500b52001299ef861d5105d8a8249ecd8 Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Fri, 5 Feb 2021 02:14:32 +0300 Subject: document extract_function assist implementation --- crates/assists/src/handlers/extract_function.rs | 148 ++++++++++++++++++++---- 1 file changed, 126 insertions(+), 22 deletions(-) (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index ac2a5a674..dfb3da7a5 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -64,7 +64,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option let body = extraction_target(&node, ctx.frange.range)?; - let vars_used_in_body = vars_used_in_body(&body, &ctx); + let vars_used_in_body = vars_used_in_body(ctx, &body); let self_param = self_param_from_usages(ctx, &body, &vars_used_in_body); let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding }; @@ -74,6 +74,9 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option let vars_defined_in_body_and_outlive = vars_defined_in_body_and_outlive(ctx, &body); let ret_ty = body_return_ty(ctx, &body)?; + // FIXME: we compute variables that outlive here just to check `never!` condition + // this requires traversing whole `body` (cheap) and finding all references (expensive) + // maybe we can move this check to `edit` closure somehow? if stdx::never!(!vars_defined_in_body_and_outlive.is_empty() && !ret_ty.is_unit()) { // We should not have variables that outlive body if we have expression block return None; @@ -201,6 +204,7 @@ impl RetType { } } +/// Semantically same as `ast::Expr`, but preserves identity when using only part of the Block #[derive(Debug)] enum FunctionBody { Expr(ast::Expr), @@ -334,6 +338,7 @@ impl HasTokenAtOffset for FunctionBody { } } +/// node or token's parent fn element_to_node(node: SyntaxElement) -> SyntaxNode { match node { syntax::NodeOrToken::Node(n) => n, @@ -341,7 +346,26 @@ fn element_to_node(node: SyntaxElement) -> SyntaxNode { } } +/// Try to guess what user wants to extract +/// +/// We have basically have two cases: +/// * We want whole node, like `loop {}`, `2 + 2`, `{ let n = 1; }` exprs. +/// Then we can use `ast::Expr` +/// * We want a few statements for a block. E.g. +/// ```rust,no_run +/// fn foo() -> i32 { +/// let m = 1; +/// $0 +/// let n = 2; +/// let k = 3; +/// k + n +/// $0 +/// } +/// ``` +/// fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option { + // we have selected exactly the expr node + // wrap it before anything else if node.text_range() == selection_range { let body = FunctionBody::from_whole_node(node.clone()); if body.is_some() { @@ -349,12 +373,18 @@ fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option Option Vec { +/// list local variables that are referenced in `body` +fn vars_used_in_body(ctx: &AssistContext, body: &FunctionBody) -> Vec { + // FIXME: currently usages inside macros are not found body.descendants() .filter_map(ast::NameRef::cast) .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref)) @@ -391,6 +413,9 @@ fn vars_used_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec { .collect() } +/// find `self` param, that was not defined inside `body` +/// +/// It should skip `self` params from impls inside `body` fn self_param_from_usages( ctx: &AssistContext, body: &FunctionBody, @@ -412,12 +437,15 @@ fn self_param_from_usages( let self_param = iter.next(); stdx::always!( iter.next().is_none(), - "body references two different self params both defined outside" + "body references two different self params, both defined outside" ); self_param } +/// find variables that should be extracted as params +/// +/// Computes additional info that affects param type and mutability fn extracted_function_params( ctx: &AssistContext, body: &FunctionBody, @@ -455,6 +483,7 @@ fn has_usages_after_body(usages: &LocalUsages, body: &FunctionBody) -> bool { usages.iter().any(|reference| body.preceedes_range(reference.range)) } +/// checks if relevant var is used with `&mut` access inside body fn has_exclusive_usages(ctx: &AssistContext, usages: &LocalUsages, body: &FunctionBody) -> bool { usages .iter() @@ -462,27 +491,34 @@ fn has_exclusive_usages(ctx: &AssistContext, usages: &LocalUsages, body: &Functi .any(|reference| reference_is_exclusive(reference, body, ctx)) } +/// checks if this reference requires `&mut` access inside body fn reference_is_exclusive( reference: &FileReference, body: &FunctionBody, ctx: &AssistContext, ) -> bool { + // we directly modify variable with set: `n = 0`, `n += 1` if reference.access == Some(ReferenceAccess::Write) { return true; } - let path = path_at_offset(body, reference); + // we take `&mut` reference to variable: `&mut v` + let path = path_element_of_reference(body, reference); if is_mut_ref_expr(path.as_ref()).unwrap_or(false) { return true; } - if is_mut_method_call(ctx, path.as_ref()).unwrap_or(false) { + // we call method with `&mut self` receiver + if is_mut_method_call_receiver(ctx, path.as_ref()).unwrap_or(false) { return true; } false } +/// Container of local varaible usages +/// +/// Semanticall same as `UsageSearchResult`, but provides more convenient interface struct LocalUsages(ide_db::search::UsageSearchResult); impl LocalUsages { @@ -510,7 +546,15 @@ impl HasTokenAtOffset for SyntaxNode { } } -fn path_at_offset(node: &dyn HasTokenAtOffset, reference: &FileReference) -> Option { +/// find relevant `ast::PathExpr` for reference +/// +/// # Preconditions +/// +/// `node` must cover `reference`, that is `node.text_range().contains_range(reference.range)` +fn path_element_of_reference( + node: &dyn HasTokenAtOffset, + reference: &FileReference, +) -> Option { let token = node.token_at_offset(reference.range.start()).right_biased().or_else(|| { stdx::never!(false, "cannot find token at variable usage: {:?}", reference); None @@ -529,7 +573,8 @@ fn is_mut_ref_expr(path: Option<&ast::Expr>) -> Option { Some(ref_expr.mut_token().is_some()) } -fn is_mut_method_call(ctx: &AssistContext, path: Option<&ast::Expr>) -> Option { +/// checks if `path` is the receiver in method call that requires `&mut self` access +fn is_mut_method_call_receiver(ctx: &AssistContext, path: Option<&ast::Expr>) -> Option { let path = path?; let method_call = path.syntax().parent().and_then(ast::MethodCallExpr::cast)?; @@ -540,8 +585,10 @@ fn is_mut_method_call(ctx: &AssistContext, path: Option<&ast::Expr>) -> Option Vec { + // FIXME: this doesn't work well with macros + // see https://github.com/rust-analyzer/rust-analyzer/pull/7535#discussion_r570048550 body.descendants() .filter_map(ast::IdentPat::cast) .filter_map(|let_stmt| ctx.sema.to_def(&let_stmt)) @@ -549,12 +596,14 @@ fn vars_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec .collect() } +/// list local variables defined inside `body` that should be returned from extracted function fn vars_defined_in_body_and_outlive(ctx: &AssistContext, body: &FunctionBody) -> Vec { let mut vars_defined_in_body = vars_defined_in_body(&body, ctx); vars_defined_in_body.retain(|var| var_outlives_body(ctx, body, var)); vars_defined_in_body } +/// checks if the relevant local was defined before(outside of) body fn is_defined_before( ctx: &AssistContext, body: &FunctionBody, @@ -571,6 +620,7 @@ fn either_syntax(value: &Either) -> &SyntaxNode { } } +/// checks if local variable is used after(outside of) body fn var_outlives_body(ctx: &AssistContext, body: &FunctionBody, var: &Local) -> bool { let usages = Definition::Local(*var) .usages(&ctx.sema) @@ -590,12 +640,18 @@ fn body_return_ty(ctx: &AssistContext, body: &FunctionBody) -> Option { None => Some(RetType::Stmt), } } +/// Where to put extracted function definition #[derive(Debug)] enum Anchor { + /// Extract free function and put right after current top-level function Freestanding, + /// Extract method and put right after current function in the impl-block Method, } +/// find where to put extracted function definition +/// +/// Function should be put right after returned node fn scope_for_fn_insertion(body: &FunctionBody, anchor: Anchor) -> Option { match body { FunctionBody::Expr(e) => scope_for_fn_insertion_node(e.syntax(), anchor), @@ -801,6 +857,7 @@ fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> Stri ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string()) } +/// change all usages to account for added `&`/`&mut` for some params fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode { let mut rewriter = SyntaxRewriter::default(); for param in params { @@ -812,7 +869,7 @@ fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) let usages = usages .iter() .filter(|reference| syntax.text_range().contains_range(reference.range)) - .filter_map(|reference| path_at_offset(syntax, reference)); + .filter_map(|reference| path_element_of_reference(syntax, reference)); for path in usages { match path.syntax().ancestors().skip(1).find_map(ast::Expr::cast) { Some(ast::Expr::MethodCallExpr(_)) => { @@ -1424,6 +1481,54 @@ fn $0fun_name(n: i32) -> (i32, i32) { ); } + #[test] + fn nontrivial_patterns_define_variables() { + check_assist( + extract_function, + r" +struct Counter(i32); +fn foo() { + $0let Counter(n) = Counter(0);$0 + let m = n; +}", + r" +struct Counter(i32); +fn foo() { + let n = fun_name(); + let m = n; +} + +fn $0fun_name() -> i32 { + let Counter(n) = Counter(0); + n +}", + ); + } + + #[test] + fn struct_with_two_fields_pattern_define_variables() { + check_assist( + extract_function, + r" +struct Counter { n: i32, m: i32 }; +fn foo() { + $0let Counter { n, m: k } = Counter { n: 1, m: 2 };$0 + let h = n + k; +}", + r" +struct Counter { n: i32, m: i32 }; +fn foo() { + let (n, k) = fun_name(); + let h = n + k; +} + +fn $0fun_name() -> (i32, i32) { + let Counter { n, m: k } = Counter { n: 1, m: 2 }; + (n, k) +}", + ); + } + #[test] fn mut_var_from_outer_scope() { check_assist( @@ -1754,7 +1859,6 @@ fn $0fun_name(c: Counter) { ); } - #[test] fn non_copy_used_after() { check_assist( -- cgit v1.2.3 From 271c1cb01325ac252b5153c3729462a4d96a0e0a Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Fri, 5 Feb 2021 02:30:34 +0300 Subject: add tests for extracting if/match/while/for exprs --- crates/assists/src/handlers/extract_function.rs | 120 ++++++++++++++++++++++++ 1 file changed, 120 insertions(+) (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index dfb3da7a5..5d6f5bb26 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -1010,6 +1010,126 @@ fn $0fun_name() { ); } + #[test] + fn no_args_if() { + check_assist( + extract_function, + r#" +fn foo() { + $0if true { }$0 +}"#, + r#" +fn foo() { + fun_name(); +} + +fn $0fun_name() { + if true { } +}"#, + ); + } + + #[test] + fn no_args_if_else() { + check_assist( + extract_function, + r#" +fn foo() -> i32 { + $0if true { 1 } else { 2 }$0 +}"#, + r#" +fn foo() -> i32 { + fun_name() +} + +fn $0fun_name() -> i32 { + if true { 1 } else { 2 } +}"#, + ); + } + + #[test] + fn no_args_if_let_else() { + check_assist( + extract_function, + r#" +fn foo() -> i32 { + $0if let true = false { 1 } else { 2 }$0 +}"#, + r#" +fn foo() -> i32 { + fun_name() +} + +fn $0fun_name() -> i32 { + if let true = false { 1 } else { 2 } +}"#, + ); + } + + #[test] + fn no_args_match() { + check_assist( + extract_function, + r#" +fn foo() -> i32 { + $0match true { + true => 1, + false => 2, + }$0 +}"#, + r#" +fn foo() -> i32 { + fun_name() +} + +fn $0fun_name() -> i32 { + match true { + true => 1, + false => 2, + } +}"#, + ); + } + + #[test] + fn no_args_while() { + check_assist( + extract_function, + r#" +fn foo() { + $0while true { }$0 +}"#, + r#" +fn foo() { + fun_name(); +} + +fn $0fun_name() { + while true { } +}"#, + ); + } + + #[test] + fn no_args_for() { + check_assist( + extract_function, + r#" +fn foo() { + $0for v in &[0, 1] { }$0 +}"#, + r#" +fn foo() { + fun_name(); +} + +fn $0fun_name() { + for v in &[0, 1] { } +}"#, + ); + } + #[test] fn no_args_from_loop_unit() { check_assist( -- cgit v1.2.3 From 876ca603166dcd2680652b42fb6bdd5358e59aa6 Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Fri, 5 Feb 2021 04:35:41 +0300 Subject: allow transitive `&mut` access for fields in extract_function --- crates/assists/src/handlers/extract_function.rs | 119 ++++++++++++++++++------ 1 file changed, 92 insertions(+), 27 deletions(-) (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index 5d6f5bb26..49ea1c4b3 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -503,17 +503,42 @@ fn reference_is_exclusive( } // we take `&mut` reference to variable: `&mut v` - let path = path_element_of_reference(body, reference); - if is_mut_ref_expr(path.as_ref()).unwrap_or(false) { - return true; + let path = match path_element_of_reference(body, reference) { + Some(path) => path, + None => return false, + }; + + expr_require_exclusive_access(ctx, &path).unwrap_or(false) +} + +/// checks if this expr requires `&mut` access, recurses on field access +fn expr_require_exclusive_access(ctx: &AssistContext, expr: &ast::Expr) -> Option { + let parent = expr.syntax().parent()?; + + if let Some(bin_expr) = ast::BinExpr::cast(parent.clone()) { + if bin_expr.op_kind()?.is_assignment() { + return Some(bin_expr.lhs()?.syntax() == expr.syntax()); + } + return Some(false); } - // we call method with `&mut self` receiver - if is_mut_method_call_receiver(ctx, path.as_ref()).unwrap_or(false) { - return true; + if let Some(ref_expr) = ast::RefExpr::cast(parent.clone()) { + return Some(ref_expr.mut_token().is_some()); + } + + if let Some(method_call) = ast::MethodCallExpr::cast(parent.clone()) { + let func = ctx.sema.resolve_method_call(&method_call)?; + let self_param = func.self_param(ctx.db())?; + let access = self_param.access(ctx.db()); + + return Some(matches!(access, hir::Access::Exclusive)); + } + + if let Some(field) = ast::FieldExpr::cast(parent) { + return expr_require_exclusive_access(ctx, &field.into()); } - false + Some(false) } /// Container of local varaible usages @@ -567,24 +592,6 @@ fn path_element_of_reference( Some(path) } -fn is_mut_ref_expr(path: Option<&ast::Expr>) -> Option { - let path = path?; - let ref_expr = path.syntax().parent().and_then(ast::RefExpr::cast)?; - Some(ref_expr.mut_token().is_some()) -} - -/// checks if `path` is the receiver in method call that requires `&mut self` access -fn is_mut_method_call_receiver(ctx: &AssistContext, path: Option<&ast::Expr>) -> Option { - let path = path?; - let method_call = path.syntax().parent().and_then(ast::MethodCallExpr::cast)?; - - let func = ctx.sema.resolve_method_call(&method_call)?; - let self_param = func.self_param(ctx.db())?; - let access = self_param.access(ctx.db()); - - Some(matches!(access, hir::Access::Exclusive)) -} - /// list local variables defined inside `body` fn vars_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec { // FIXME: this doesn't work well with macros @@ -872,7 +879,7 @@ fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) .filter_map(|reference| path_element_of_reference(syntax, reference)); for path in usages { match path.syntax().ancestors().skip(1).find_map(ast::Expr::cast) { - Some(ast::Expr::MethodCallExpr(_)) => { + Some(ast::Expr::MethodCallExpr(_)) | Some(ast::Expr::FieldExpr(_)) => { // do nothing } Some(ast::Expr::RefExpr(node)) @@ -1672,6 +1679,64 @@ fn $0fun_name(n: &mut i32) { ); } + #[test] + fn mut_field_from_outer_scope() { + check_assist( + extract_function, + r" +struct C { n: i32 } +fn foo() { + let mut c = C { n: 0 }; + $0c.n += 1;$0 + let m = c.n + 1; +}", + r" +struct C { n: i32 } +fn foo() { + let mut c = C { n: 0 }; + fun_name(&mut c); + let m = c.n + 1; +} + +fn $0fun_name(c: &mut C) { + c.n += 1; +}", + ); + } + + #[test] + fn mut_nested_field_from_outer_scope() { + check_assist( + extract_function, + r" +struct P { n: i32} +struct C { p: P } +fn foo() { + let mut c = C { p: P { n: 0 } }; + let mut v = C { p: P { n: 0 } }; + let u = C { p: P { n: 0 } }; + $0c.p.n += u.p.n; + let r = &mut v.p.n;$0 + let m = c.p.n + v.p.n + u.p.n; +}", + r" +struct P { n: i32} +struct C { p: P } +fn foo() { + let mut c = C { p: P { n: 0 } }; + let mut v = C { p: P { n: 0 } }; + let u = C { p: P { n: 0 } }; + fun_name(&mut c, &u, &mut v); + let m = c.p.n + v.p.n + u.p.n; +} + +fn $0fun_name(c: &mut C, u: &C, v: &mut C) { + c.p.n += u.p.n; + let r = &mut v.p.n; +}", + ); + } + #[test] fn mut_param_many_usages_stmt() { check_assist( @@ -1999,7 +2064,7 @@ fn foo() { } fn $0fun_name(c: &Counter) { - let n = *c.0; + let n = c.0; }", ); } -- cgit v1.2.3 From 7eaa3e56a01e9a275129c76817232559b0e20f2b Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Fri, 5 Feb 2021 05:00:53 +0300 Subject: allow extracted body to be indented(dedent it) --- crates/assists/src/handlers/extract_function.rs | 114 +++++++++++++++++++++--- 1 file changed, 101 insertions(+), 13 deletions(-) (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index 49ea1c4b3..d876eabca 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -13,7 +13,7 @@ use syntax::{ edit::{AstNodeEdit, IndentLevel}, AstNode, }, - Direction, SyntaxElement, + AstToken, Direction, SyntaxElement, SyntaxKind::{self, BLOCK_EXPR, BREAK_EXPR, COMMENT, PATH_EXPR, RETURN_EXPR}, SyntaxNode, SyntaxToken, TextRange, TextSize, TokenAtOffset, T, }; @@ -105,9 +105,10 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option builder.replace(target_range, format_replacement(ctx, &fun)); - let indent = IndentLevel::from_node(&insert_after); + let new_indent = IndentLevel::from_node(&insert_after); + let old_indent = fun.body.indent_level(); - let fn_def = format_function(ctx, module, &fun, indent); + let fn_def = format_function(ctx, module, &fun, old_indent, new_indent); let insert_offset = insert_after.text_range().end(); builder.insert(insert_offset, fn_def); }, @@ -260,6 +261,18 @@ impl FunctionBody { Some(FunctionBody::Span { elements, leading_indent }) } + fn indent_level(&self) -> IndentLevel { + match &self { + FunctionBody::Expr(expr) => IndentLevel::from_node(expr.syntax()), + FunctionBody::Span { elements, .. } => elements + .iter() + .filter_map(SyntaxElement::as_node) + .map(IndentLevel::from_node) + .min_by_key(|level| level.0) + .expect("body must contain at least one node"), + } + } + fn tail_expr(&self) -> Option { match &self { FunctionBody::Expr(expr) => Some(expr.clone()), @@ -747,16 +760,17 @@ fn format_function( ctx: &AssistContext, module: hir::Module, fun: &Function, - indent: IndentLevel, + old_indent: IndentLevel, + new_indent: IndentLevel, ) -> String { let mut fn_def = String::new(); - format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name); + format_to!(fn_def, "\n\n{}fn $0{}(", new_indent, fun.name); format_function_param_list_to(&mut fn_def, ctx, module, fun); fn_def.push(')'); format_function_ret_to(&mut fn_def, ctx, module, fun); fn_def.push_str(" {"); - format_function_body_to(&mut fn_def, ctx, indent, fun); - format_to!(fn_def, "{}}}", indent); + format_function_body_to(&mut fn_def, ctx, old_indent, new_indent, fun); + format_to!(fn_def, "{}}}", new_indent); fn_def } @@ -818,20 +832,32 @@ fn format_function_ret_to( fn format_function_body_to( fn_def: &mut String, ctx: &AssistContext, - indent: IndentLevel, + old_indent: IndentLevel, + new_indent: IndentLevel, fun: &Function, ) { match &fun.body { FunctionBody::Expr(expr) => { fn_def.push('\n'); - let expr = expr.indent(indent); + let expr = expr.dedent(old_indent).indent(new_indent + 1); let expr = fix_param_usages(ctx, &fun.params, expr.syntax()); - format_to!(fn_def, "{}{}", indent + 1, expr); + format_to!(fn_def, "{}{}", new_indent + 1, expr); fn_def.push('\n'); } FunctionBody::Span { elements, leading_indent } => { format_to!(fn_def, "{}", leading_indent); - for element in elements { + let new_indent_str = format!("\n{}", new_indent + 1); + for mut element in elements { + let new_ws; + if let Some(ws) = element.as_token().cloned().and_then(ast::Whitespace::cast) { + let text = ws.syntax().text(); + if text.contains('\n') { + let new_text = text.replace(&format!("\n{}", old_indent), &new_indent_str); + new_ws = ast::make::tokens::whitespace(&new_text).into(); + element = &new_ws; + } + } + match element { syntax::NodeOrToken::Node(node) => { format_to!(fn_def, "{}", fix_param_usages(ctx, &fun.params, node)); @@ -849,9 +875,9 @@ fn format_function_body_to( match fun.vars_defined_in_body_and_outlive.as_slice() { [] => {} - [var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()), + [var] => format_to!(fn_def, "{}{}\n", new_indent + 1, var.name(ctx.db()).unwrap()), [v0, vs @ ..] => { - format_to!(fn_def, "{}({}", indent + 1, v0.name(ctx.db()).unwrap()); + format_to!(fn_def, "{}({}", new_indent + 1, v0.name(ctx.db()).unwrap()); for var in vs { format_to!(fn_def, ", {}", var.name(ctx.db()).unwrap()); } @@ -2065,6 +2091,68 @@ fn foo() { fn $0fun_name(c: &Counter) { let n = c.0; +}", + ); + } + + #[test] + fn indented_stmts() { + check_assist( + extract_function, + r" +fn foo() { + if true { + loop { + $0let n = 1; + let m = 2;$0 + } + } +}", + r" +fn foo() { + if true { + loop { + fun_name(); + } + } +} + +fn $0fun_name() { + let n = 1; + let m = 2; +}", + ); + } + + #[test] + fn indented_stmts_inside_mod() { + check_assist( + extract_function, + r" +mod bar { + fn foo() { + if true { + loop { + $0let n = 1; + let m = 2;$0 + } + } + } +}", + r" +mod bar { + fn foo() { + if true { + loop { + fun_name(); + } + } + } + + fn $0fun_name() { + let n = 1; + let m = 2; + } }", ); } -- cgit v1.2.3