From f345d1772ab3827fbc3e31428b0d9479cab0ea39 Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Wed, 10 Feb 2021 05:50:03 +0300 Subject: handle return, break and continue when extracting function --- crates/assists/src/handlers/early_return.rs | 2 +- crates/assists/src/handlers/extract_function.rs | 1147 +++++++++++++++++++--- crates/assists/src/handlers/generate_function.rs | 7 +- crates/assists/src/tests.rs | 5 +- crates/syntax/src/ast/make.rs | 56 +- 5 files changed, 1083 insertions(+), 134 deletions(-) diff --git a/crates/assists/src/handlers/early_return.rs b/crates/assists/src/handlers/early_return.rs index 8bbbb7ed5..6b87c3c05 100644 --- a/crates/assists/src/handlers/early_return.rs +++ b/crates/assists/src/handlers/early_return.rs @@ -88,7 +88,7 @@ pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext) let early_expression: ast::Expr = match parent_container.kind() { WHILE_EXPR | LOOP_EXPR => make::expr_continue(), - FN => make::expr_return(), + FN => make::expr_return(None), _ => return None, }; diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index 4bddd4eec..4372479b9 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -1,4 +1,4 @@ -use std::{fmt, iter}; +use std::iter; use ast::make; use either::Either; @@ -16,9 +16,9 @@ use syntax::{ edit::{AstNodeEdit, IndentLevel}, AstNode, }, - AstToken, Direction, SyntaxElement, + SyntaxElement, SyntaxKind::{self, BLOCK_EXPR, BREAK_EXPR, COMMENT, PATH_EXPR, RETURN_EXPR}, - SyntaxNode, SyntaxToken, TextRange, TextSize, TokenAtOffset, T, + SyntaxNode, SyntaxToken, TextRange, TextSize, TokenAtOffset, WalkEvent, T, }; use test_utils::mark; @@ -84,6 +84,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option // We should not have variables that outlive body if we have expression block return None; } + let control_flow = external_control_flow(&body)?; let target_range = body.text_range(); @@ -98,16 +99,17 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option name: "fun_name".to_string(), self_param: self_param.map(|(_, pat)| pat), params, + control_flow, ret_ty, body, vars_defined_in_body_and_outlive, }; - builder.replace(target_range, format_replacement(ctx, &fun)); - let new_indent = IndentLevel::from_node(&insert_after); let old_indent = fun.body.indent_level(); + builder.replace(target_range, format_replacement(ctx, &fun, old_indent)); + let fn_def = format_function(ctx, module, &fun, old_indent, new_indent); let insert_offset = insert_after.text_range().end(); builder.insert(insert_offset, fn_def); @@ -115,11 +117,104 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option ) } +fn external_control_flow(body: &FunctionBody) -> Option { + let mut ret_expr = None; + let mut try_expr = None; + let mut break_expr = None; + let mut continue_expr = None; + let (syntax, text_range) = match body { + FunctionBody::Expr(expr) => (expr.syntax(), expr.syntax().text_range()), + FunctionBody::Span { parent, text_range } => (parent.syntax(), *text_range), + }; + + let mut nested_loop = None; + let mut nested_scope = None; + + for e in syntax.preorder() { + let e = match e { + WalkEvent::Enter(e) => e, + WalkEvent::Leave(e) => { + if nested_loop.as_ref() == Some(&e) { + nested_loop = None; + } + if nested_scope.as_ref() == Some(&e) { + nested_scope = None; + } + continue; + } + }; + if nested_scope.is_some() { + continue; + } + if !text_range.contains_range(e.text_range()) { + continue; + } + match e.kind() { + SyntaxKind::LOOP_EXPR | SyntaxKind::WHILE_EXPR | SyntaxKind::FOR_EXPR => { + if nested_loop.is_none() { + nested_loop = Some(e); + } + } + SyntaxKind::FN + | SyntaxKind::CONST + | SyntaxKind::STATIC + | SyntaxKind::IMPL + | SyntaxKind::MODULE => { + if nested_scope.is_none() { + nested_scope = Some(e); + } + } + SyntaxKind::RETURN_EXPR => { + ret_expr = Some(ast::ReturnExpr::cast(e).unwrap()); + } + SyntaxKind::TRY_EXPR => { + try_expr = Some(ast::TryExpr::cast(e).unwrap()); + } + SyntaxKind::BREAK_EXPR if nested_loop.is_none() => { + break_expr = Some(ast::BreakExpr::cast(e).unwrap()); + } + SyntaxKind::CONTINUE_EXPR if nested_loop.is_none() => { + continue_expr = Some(ast::ContinueExpr::cast(e).unwrap()); + } + _ => {} + } + } + + if try_expr.is_some() { + // FIXME: support try + return None; + } + + let kind = match (ret_expr, break_expr, continue_expr) { + (Some(r), None, None) => match r.expr() { + Some(expr) => Some(FlowKind::ReturnValue(expr)), + None => Some(FlowKind::Return), + }, + (Some(_), _, _) => { + mark::hit!(external_control_flow_return_and_bc); + return None; + } + (None, Some(_), Some(_)) => { + mark::hit!(external_control_flow_break_and_continue); + return None; + } + (None, Some(b), None) => match b.expr() { + Some(expr) => Some(FlowKind::BreakValue(expr)), + None => Some(FlowKind::Break), + }, + (None, None, Some(_)) => Some(FlowKind::Continue), + (None, None, None) => None, + }; + + Some(ControlFlow { kind }) +} + #[derive(Debug)] struct Function { name: String, self_param: Option, params: Vec, + control_flow: ControlFlow, ret_ty: RetType, body: FunctionBody, vars_defined_in_body_and_outlive: Vec, @@ -134,6 +229,11 @@ struct Param { is_copy: bool, } +#[derive(Debug)] +struct ControlFlow { + kind: Option, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum ParamKind { Value, @@ -142,6 +242,30 @@ enum ParamKind { MutRef, } +#[derive(Debug, Eq, PartialEq)] +enum FunType { + Unit, + Single(hir::Type), + Tuple(Vec), +} + +impl Function { + fn return_type(&self, ctx: &AssistContext) -> FunType { + match &self.ret_ty { + RetType::Expr(ty) if ty.is_unit() => FunType::Unit, + RetType::Expr(ty) => FunType::Single(ty.clone()), + RetType::Stmt => match self.vars_defined_in_body_and_outlive.as_slice() { + [] => FunType::Unit, + [var] => FunType::Single(var.ty(ctx.db())), + vars => { + let types = vars.iter().map(|v| v.ty(ctx.db())).collect(); + FunType::Tuple(types) + } + }, + } + } +} + impl ParamKind { fn is_ref(&self) -> bool { matches!(self, ParamKind::SharedRef | ParamKind::MutRef) @@ -158,26 +282,78 @@ impl Param { } } - fn value_prefix(&self) -> &'static str { + fn to_arg(&self, ctx: &AssistContext) -> ast::Expr { + let var = path_expr_from_local(ctx, self.var); match self.kind() { - ParamKind::Value | ParamKind::MutValue => "", - ParamKind::SharedRef => "&", - ParamKind::MutRef => "&mut ", + ParamKind::Value | ParamKind::MutValue => var, + ParamKind::SharedRef => make::expr_ref(var, false), + ParamKind::MutRef => make::expr_ref(var, true), } } - fn type_prefix(&self) -> &'static str { - match self.kind() { - ParamKind::Value | ParamKind::MutValue => "", - ParamKind::SharedRef => "&", - ParamKind::MutRef => "&mut ", + fn to_param(&self, ctx: &AssistContext, module: hir::Module) -> ast::Param { + let var = self.var.name(ctx.db()).unwrap().to_string(); + let var_name = make::name(&var); + let pat = match self.kind() { + ParamKind::MutValue => make::ident_mut_pat(var_name), + ParamKind::Value | ParamKind::SharedRef | ParamKind::MutRef => { + make::ident_pat(var_name) + } + }; + + let ty = make_ty(&self.ty, ctx, module); + let ty = match self.kind() { + ParamKind::Value | ParamKind::MutValue => ty, + ParamKind::SharedRef => make::ty_ref(ty, false), + ParamKind::MutRef => make::ty_ref(ty, true), + }; + + make::param(pat.into(), ty) + } +} + +/// Control flow that is exported from extracted function +/// +/// E.g.: +/// ```rust,no_run +/// loop { +/// $0 +/// if 42 == 42 { +/// break; +/// } +/// $0 +/// } +/// ``` +#[derive(Debug, Clone)] +enum FlowKind { + /// Return without value (`return;`) + Return, + /// Return with value (`return $expr;`) + ReturnValue(ast::Expr), + /// Break without value (`return;`) + Break, + /// Break with value (`break $expr;`) + BreakValue(ast::Expr), + /// Continue + Continue, +} + +impl FlowKind { + fn make_expr(&self, expr: Option) -> ast::Expr { + match self { + FlowKind::Return | FlowKind::ReturnValue(_) => make::expr_return(expr), + FlowKind::Break | FlowKind::BreakValue(_) => make::expr_break(expr), + FlowKind::Continue => { + stdx::always!(expr.is_none(), "continue with value is not possible"); + make::expr_continue() + } } } - fn mut_pattern(&self) -> &'static str { - match self.kind() { - ParamKind::MutValue => "mut ", - _ => "", + fn expr_ty(&self, ctx: &AssistContext) -> Option { + match self { + FlowKind::ReturnValue(expr) | FlowKind::BreakValue(expr) => ctx.sema.type_of_expr(expr), + FlowKind::Return | FlowKind::Break | FlowKind::Continue => None, } } } @@ -195,14 +371,6 @@ impl RetType { RetType::Stmt => true, } } - - fn as_fn_ret(&self) -> Option<&hir::Type> { - match self { - RetType::Stmt => None, - RetType::Expr(ty) if ty.is_unit() => None, - RetType::Expr(ty) => Some(ty), - } - } } /// Semantically same as `ast::Expr`, but preserves identity when using only part of the Block @@ -234,7 +402,7 @@ impl FunctionBody { fn indent_level(&self) -> IndentLevel { match &self { FunctionBody::Expr(expr) => IndentLevel::from_node(expr.syntax()), - FunctionBody::Span { parent, .. } => IndentLevel::from_node(parent.syntax()), + FunctionBody::Span { parent, .. } => IndentLevel::from_node(parent.syntax()) + 1, } } @@ -668,9 +836,24 @@ fn scope_for_fn_insertion_node(node: &SyntaxNode, anchor: Anchor) -> Option String { - let mut buf = String::new(); +fn format_replacement(ctx: &AssistContext, fun: &Function, indent: IndentLevel) -> String { + let ret_ty = fun.return_type(ctx); + let args = fun.params.iter().map(|param| param.to_arg(ctx)); + let args = make::arg_list(args); + let call_expr = if fun.self_param.is_some() { + let self_arg = make::expr_path(make_path_from_text("self")); + make::expr_method_call(self_arg, &fun.name, args) + } else { + let func = make::expr_path(make_path_from_text(&fun.name)); + make::expr_call(func, args) + }; + + let handler = FlowHandler::from_ret_ty(fun, &ret_ty); + + let expr = handler.make_expr(call_expr).indent(indent); + + let mut buf = String::new(); match fun.vars_defined_in_body_and_outlive.as_slice() { [] => {} [var] => format_to!(buf, "let {} = ", var.name(ctx.db()).unwrap()), @@ -683,34 +866,123 @@ fn format_replacement(ctx: &AssistContext, fun: &Function) -> String { buf.push_str(") = "); } } - - if fun.self_param.is_some() { - format_to!(buf, "self."); - } - format_to!(buf, "{}(", fun.name); - format_arg_list_to(&mut buf, fun, ctx); - format_to!(buf, ")"); - - if fun.ret_ty.is_unit() { - format_to!(buf, ";"); + format_to!(buf, "{}", expr); + if fun.ret_ty.is_unit() + && (!fun.vars_defined_in_body_and_outlive.is_empty() || !expr.is_block_like()) + { + buf.push(';'); } - buf } -fn format_arg_list_to(buf: &mut String, fun: &Function, ctx: &AssistContext) { - let mut it = fun.params.iter(); - if let Some(param) = it.next() { - format_arg_to(buf, ctx, param); +enum FlowHandler { + None, + If { action: FlowKind }, + IfOption { action: FlowKind }, + MatchOption { none: FlowKind }, + MatchResult { err: FlowKind }, +} + +impl FlowHandler { + fn from_ret_ty(fun: &Function, ret_ty: &FunType) -> FlowHandler { + match &fun.control_flow.kind { + None => FlowHandler::None, + Some(flow_kind) => { + let action = flow_kind.clone(); + if *ret_ty == FunType::Unit { + match flow_kind { + FlowKind::Return | FlowKind::Break | FlowKind::Continue => { + FlowHandler::If { action } + } + FlowKind::ReturnValue(_) | FlowKind::BreakValue(_) => { + FlowHandler::IfOption { action } + } + } + } else { + match flow_kind { + FlowKind::Return | FlowKind::Break | FlowKind::Continue => { + FlowHandler::MatchOption { none: action } + } + FlowKind::ReturnValue(_) | FlowKind::BreakValue(_) => { + FlowHandler::MatchResult { err: action } + } + } + } + } + } } - for param in it { - buf.push_str(", "); - format_arg_to(buf, ctx, param); + + fn make_expr(&self, call_expr: ast::Expr) -> ast::Expr { + match self { + FlowHandler::None => call_expr, + FlowHandler::If { action } => { + let action = action.make_expr(None); + let stmt = make::expr_stmt(action); + let block = make::block_expr(iter::once(stmt.into()), None); + let condition = make::condition(call_expr, None); + make::expr_if(condition, block, None) + } + FlowHandler::IfOption { action } => { + let path = make_path_from_text("Some"); + let value_pat = make::ident_pat(make::name("value")); + let pattern = make::tuple_struct_pat(path, iter::once(value_pat.into())); + let cond = make::condition(call_expr, Some(pattern.into())); + let value = make::expr_path(make_path_from_text("value")); + let action_expr = action.make_expr(Some(value)); + let action_stmt = make::expr_stmt(action_expr); + let then = make::block_expr(iter::once(action_stmt.into()), None); + make::expr_if(cond, then, None) + } + FlowHandler::MatchOption { none } => { + let some_name = "value"; + + let some_arm = { + let path = make_path_from_text("Some"); + let value_pat = make::ident_pat(make::name(some_name)); + let pat = make::tuple_struct_pat(path, iter::once(value_pat.into())); + let value = make::expr_path(make_path_from_text(some_name)); + make::match_arm(iter::once(pat.into()), value) + }; + let none_arm = { + let path = make_path_from_text("None"); + let pat = make::path_pat(path); + make::match_arm(iter::once(pat), none.make_expr(None)) + }; + let arms = make::match_arm_list(vec![some_arm, none_arm]); + make::expr_match(call_expr, arms) + } + FlowHandler::MatchResult { err } => { + let ok_name = "value"; + let err_name = "value"; + + let ok_arm = { + let path = make_path_from_text("Ok"); + let value_pat = make::ident_pat(make::name(ok_name)); + let pat = make::tuple_struct_pat(path, iter::once(value_pat.into())); + let value = make::expr_path(make_path_from_text(ok_name)); + make::match_arm(iter::once(pat.into()), value) + }; + let err_arm = { + let path = make_path_from_text("Err"); + let value_pat = make::ident_pat(make::name(err_name)); + let pat = make::tuple_struct_pat(path, iter::once(value_pat.into())); + let value = make::expr_path(make_path_from_text(err_name)); + make::match_arm(iter::once(pat.into()), err.make_expr(Some(value))) + }; + let arms = make::match_arm_list(vec![ok_arm, err_arm]); + make::expr_match(call_expr, arms) + } + } } } -fn format_arg_to(buf: &mut String, ctx: &AssistContext, param: &Param) { - format_to!(buf, "{}{}", param.value_prefix(), param.var.name(ctx.db()).unwrap()); +fn make_path_from_text(text: &str) -> ast::Path { + make::path_unqualified(make::path_segment(make::name_ref(text))) +} + +fn path_expr_from_local(ctx: &AssistContext, var: Local) -> ast::Expr { + let name = var.name(ctx.db()).unwrap().to_string(); + make::expr_path(make_path_from_text(&name)) } fn format_function( @@ -721,91 +993,99 @@ fn format_function( new_indent: IndentLevel, ) -> String { let mut fn_def = String::new(); - format_to!(fn_def, "\n\n{}fn $0{}(", new_indent, fun.name); - format_function_param_list_to(&mut fn_def, ctx, module, fun); - fn_def.push(')'); - format_function_ret_to(&mut fn_def, ctx, module, fun); - fn_def.push(' '); - format_function_body_to(&mut fn_def, ctx, old_indent, new_indent, fun); + let params = make_param_list(ctx, module, fun); + let ret_ty = make_ret_ty(ctx, module, fun); + let body = make_body(ctx, old_indent, new_indent, fun); + format_to!(fn_def, "\n\n{}fn $0{}{}", new_indent, fun.name, params); + if let Some(ret_ty) = ret_ty { + format_to!(fn_def, " {}", ret_ty); + } + format_to!(fn_def, " {}", body); fn_def } -fn format_function_param_list_to( - fn_def: &mut String, - ctx: &AssistContext, - module: hir::Module, - fun: &Function, -) { - let mut it = fun.params.iter(); - if let Some(self_param) = &fun.self_param { - format_to!(fn_def, "{}", self_param); - } else if let Some(param) = it.next() { - format_param_to(fn_def, ctx, module, param); - } - for param in it { - fn_def.push_str(", "); - format_param_to(fn_def, ctx, module, param); - } -} - -fn format_param_to(fn_def: &mut String, ctx: &AssistContext, module: hir::Module, param: &Param) { - format_to!( - fn_def, - "{}{}: {}{}", - param.mut_pattern(), - param.var.name(ctx.db()).unwrap(), - param.type_prefix(), - format_type(¶m.ty, ctx, module) - ); +fn make_param_list(ctx: &AssistContext, module: hir::Module, fun: &Function) -> ast::ParamList { + let self_param = fun.self_param.clone(); + let params = fun.params.iter().map(|param| param.to_param(ctx, module)); + make::param_list(self_param, params) } -fn format_function_ret_to( - fn_def: &mut String, - ctx: &AssistContext, - module: hir::Module, - fun: &Function, -) { - if let Some(ty) = fun.ret_ty.as_fn_ret() { - format_to!(fn_def, " -> {}", format_type(ty, ctx, module)); - } else { - match fun.vars_defined_in_body_and_outlive.as_slice() { - [] => {} - [var] => { - format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module)); - } - [v0, vs @ ..] => { - format_to!(fn_def, " -> ({}", format_type(&v0.ty(ctx.db()), ctx, module)); - for var in vs { - format_to!(fn_def, ", {}", format_type(&var.ty(ctx.db()), ctx, module)); +impl FunType { + fn make_ty(&self, ctx: &AssistContext, module: hir::Module) -> ast::Type { + match self { + FunType::Unit => make::ty_unit(), + FunType::Single(ty) => make_ty(ty, ctx, module), + FunType::Tuple(types) => match types.as_slice() { + [] => { + stdx::never!("tuple type with 0 elements"); + make::ty_unit() } - fn_def.push(')'); - } + [ty] => { + stdx::never!("tuple type with 1 element"); + make_ty(ty, ctx, module) + } + types => { + let types = types.iter().map(|ty| make_ty(ty, ctx, module)); + make::ty_tuple(types) + } + }, } } } -fn format_function_body_to( - fn_def: &mut String, +fn make_ret_ty(ctx: &AssistContext, module: hir::Module, fun: &Function) -> Option { + let ty = fun.return_type(ctx); + let handler = FlowHandler::from_ret_ty(fun, &ty); + let ret_ty = match &handler { + FlowHandler::None => { + if matches!(ty, FunType::Unit) { + return None; + } + ty.make_ty(ctx, module) + } + FlowHandler::If { .. } => make::ty("bool"), + FlowHandler::IfOption { action } => { + let handler_ty = action + .expr_ty(ctx) + .map(|ty| make_ty(&ty, ctx, module)) + .unwrap_or_else(make::ty_unit); + make::ty_generic(make::name_ref("Option"), iter::once(handler_ty)) + } + FlowHandler::MatchOption { .. } => { + make::ty_generic(make::name_ref("Option"), iter::once(ty.make_ty(ctx, module))) + } + FlowHandler::MatchResult { err } => { + let handler_ty = + err.expr_ty(ctx).map(|ty| make_ty(&ty, ctx, module)).unwrap_or_else(make::ty_unit); + make::ty_generic(make::name_ref("Result"), vec![ty.make_ty(ctx, module), handler_ty]) + } + }; + Some(make::ret_type(ret_ty)) +} + +fn make_body( ctx: &AssistContext, old_indent: IndentLevel, new_indent: IndentLevel, fun: &Function, -) { +) -> ast::BlockExpr { + let ret_ty = fun.return_type(ctx); + let handler = FlowHandler::from_ret_ty(fun, &ret_ty); let block = match &fun.body { FunctionBody::Expr(expr) => { - let expr = rewrite_body_segment(ctx, &fun.params, expr.syntax()); + let expr = rewrite_body_segment(ctx, &fun.params, &handler, expr.syntax()); let expr = ast::Expr::cast(expr).unwrap(); let expr = expr.dedent(old_indent).indent(IndentLevel(1)); - let block = make::block_expr(Vec::new(), Some(expr)); - block.indent(new_indent) + + make::block_expr(Vec::new(), Some(expr)) } FunctionBody::Span { parent, text_range } => { let mut elements: Vec<_> = parent .syntax() .children() .filter(|it| text_range.contains_range(it.text_range())) - .map(|it| rewrite_body_segment(ctx, &fun.params, &it)) + .map(|it| rewrite_body_segment(ctx, &fun.params, &handler, &it)) .collect(); let mut tail_expr = match elements.pop() { @@ -821,10 +1101,9 @@ fn format_function_body_to( [] => {} [var] => { tail_expr = Some(path_expr_from_local(ctx, *var)); - }, + } vars => { - let exprs = vars.iter() - .map(|var| path_expr_from_local(ctx, *var)); + let exprs = vars.iter().map(|var| path_expr_from_local(ctx, *var)); let expr = make::expr_tuple(exprs); tail_expr = Some(expr); } @@ -839,33 +1118,70 @@ fn format_function_body_to( } }); + let body_indent = IndentLevel(1); + let elements = elements.map(|stmt| stmt.dedent(old_indent).indent(body_indent)); + let tail_expr = tail_expr.map(|expr| expr.dedent(old_indent).indent(body_indent)); - let block = make::block_expr(elements, tail_expr); - block.dedent(old_indent).indent(new_indent) + make::block_expr(elements, tail_expr) } }; + let block = match &handler { + FlowHandler::None => block, + FlowHandler::If { .. } => { + let lit_false = ast::Literal::cast(make::tokens::literal("false").parent()).unwrap(); + with_tail_expr(block, lit_false.into()) + } + FlowHandler::IfOption { .. } => { + let none = make::expr_path(make_path_from_text("None")); + with_tail_expr(block, none) } - - format_to!(fn_def, "{}", block); + FlowHandler::MatchOption { .. } => map_tail_expr(block, |tail_expr| { + let some = make::expr_path(make_path_from_text("Some")); + let args = make::arg_list(iter::once(tail_expr)); + make::expr_call(some, args) + }), + FlowHandler::MatchResult { .. } => map_tail_expr(block, |tail_expr| { + let some = make::expr_path(make_path_from_text("Ok")); + let args = make::arg_list(iter::once(tail_expr)); + make::expr_call(some, args) + }), + }; + + block.indent(new_indent) } -fn path_expr_from_local(ctx: &AssistContext, var: Local) -> ast::Expr { - let name = var.name(ctx.db()).unwrap().to_string(); - let path = make::path_unqualified(make::path_segment(make::name_ref(&name))); - make::expr_path(path) +fn map_tail_expr(block: ast::BlockExpr, f: impl FnOnce(ast::Expr) -> ast::Expr) -> ast::BlockExpr { + let tail_expr = match block.tail_expr() { + Some(tail_expr) => tail_expr, + None => return block, + }; + make::block_expr(block.statements(), Some(f(tail_expr))) +} + +fn with_tail_expr(block: ast::BlockExpr, tail_expr: ast::Expr) -> ast::BlockExpr { + let stmt_tail = block.tail_expr().map(|expr| make::expr_stmt(expr).into()); + let stmts = block.statements().chain(stmt_tail); + make::block_expr(stmts, Some(tail_expr)) } fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> String { ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string()) } +fn make_ty(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> ast::Type { + let ty_str = format_type(ty, ctx, module); + make::ty(&ty_str) +} + fn rewrite_body_segment( ctx: &AssistContext, params: &[Param], + handler: &FlowHandler, syntax: &SyntaxNode, ) -> SyntaxNode { - fix_param_usages(ctx, params, syntax) + let syntax = fix_param_usages(ctx, params, syntax); + update_external_control_flow(handler, &syntax) } /// change all usages to account for added `&`/`&mut` for some params @@ -906,6 +1222,98 @@ fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) rewriter.rewrite(syntax) } +fn update_external_control_flow(handler: &FlowHandler, syntax: &SyntaxNode) -> SyntaxNode { + let mut rewriter = SyntaxRewriter::default(); + + let mut nested_loop = None; + let mut nested_scope = None; + for event in syntax.preorder() { + let node = match event { + WalkEvent::Enter(e) => { + match e.kind() { + SyntaxKind::LOOP_EXPR | SyntaxKind::WHILE_EXPR | SyntaxKind::FOR_EXPR => { + if nested_loop.is_none() { + nested_loop = Some(e.clone()); + } + } + SyntaxKind::FN + | SyntaxKind::CONST + | SyntaxKind::STATIC + | SyntaxKind::IMPL + | SyntaxKind::MODULE => { + if nested_scope.is_none() { + nested_scope = Some(e.clone()); + } + } + _ => {} + } + e + } + WalkEvent::Leave(e) => { + if nested_loop.as_ref() == Some(&e) { + nested_loop = None; + } + if nested_scope.as_ref() == Some(&e) { + nested_scope = None; + } + continue; + } + }; + if nested_scope.is_some() { + continue; + } + let expr = match ast::Expr::cast(node) { + Some(e) => e, + None => continue, + }; + match expr { + ast::Expr::ReturnExpr(return_expr) if nested_scope.is_none() => { + let expr = return_expr.expr(); + if let Some(replacement) = make_rewritten_flow(handler, expr) { + rewriter.replace_ast(&return_expr.into(), &replacement); + } + } + ast::Expr::BreakExpr(break_expr) if nested_loop.is_none() => { + let expr = break_expr.expr(); + if let Some(replacement) = make_rewritten_flow(handler, expr) { + rewriter.replace_ast(&break_expr.into(), &replacement); + } + } + ast::Expr::ContinueExpr(continue_expr) if nested_loop.is_none() => { + if let Some(replacement) = make_rewritten_flow(handler, None) { + rewriter.replace_ast(&continue_expr.into(), &replacement); + } + } + _ => { + // do nothing + } + } + } + + rewriter.rewrite(syntax) +} + +fn make_rewritten_flow(handler: &FlowHandler, arg_expr: Option) -> Option { + let value = match handler { + FlowHandler::None => return None, + FlowHandler::If { .. } => { + ast::Literal::cast(make::tokens::literal("true").parent()).unwrap().into() + } + FlowHandler::IfOption { .. } => { + let expr = arg_expr.unwrap_or_else(|| make::expr_tuple(Vec::new())); + let args = make::arg_list(iter::once(expr)); + make::expr_call(make::expr_path(make_path_from_text("Some")), args) + } + FlowHandler::MatchOption { .. } => make::expr_path(make_path_from_text("None")), + FlowHandler::MatchResult { .. } => { + let expr = arg_expr.unwrap_or_else(|| make::expr_tuple(Vec::new())); + let args = make::arg_list(iter::once(expr)); + make::expr_call(make::expr_path(make_path_from_text("Err")), args) + } + }; + Some(make::expr_return(Some(value))) +} + #[cfg(test)] mod tests { use crate::tests::{check_assist, check_assist_not_applicable}; @@ -2073,6 +2481,66 @@ fn $0fun_name(c: &Counter) { ); } + #[test] + fn copy_used_after() { + check_assist( + extract_function, + r##" +#[lang = "copy"] +pub trait Copy {} +impl Copy for i32 {} +fn foo() { + let n = 0; + $0let m = n;$0 + let k = n; +}"##, + r##" +#[lang = "copy"] +pub trait Copy {} +impl Copy for i32 {} +fn foo() { + let n = 0; + fun_name(n); + let k = n; +} + +fn $0fun_name(n: i32) { + let m = n; +}"##, + ) + } + + #[test] + fn copy_custom_used_after() { + check_assist( + extract_function, + r##" +#[lang = "copy"] +pub trait Copy {} +struct Counter(i32); +impl Copy for Counter {} +fn foo() { + let c = Counter(0); + $0let n = c.0;$0 + let m = c.0; +}"##, + r##" +#[lang = "copy"] +pub trait Copy {} +struct Counter(i32); +impl Copy for Counter {} +fn foo() { + let c = Counter(0); + fun_name(c); + let m = c.0; +} + +fn $0fun_name(c: Counter) { + let n = c.0; +}"##, + ); + } + #[test] fn indented_stmts() { check_assist( @@ -2134,4 +2602,441 @@ mod bar { }", ); } + + #[test] + fn break_loop() { + check_assist( + extract_function, + r##" +enum Option { + #[lang = "None"] None, + #[lang = "Some"] Some(T), +} +use Option::*; +fn foo() { + loop { + let n = 1; + $0let m = n + 1; + break; + let k = 2;$0 + let h = 1 + k; + } +}"##, + r##" +enum Option { + #[lang = "None"] None, + #[lang = "Some"] Some(T), +} +use Option::*; +fn foo() { + loop { + let n = 1; + let k = match fun_name(n) { + Some(value) => value, + None => break, + }; + let h = 1 + k; + } +} + +fn $0fun_name(n: i32) -> Option { + let m = n + 1; + return None; + let k = 2; + Some(k) +}"##, + ); + } + + #[test] + fn return_to_parent() { + check_assist( + extract_function, + r##" +#[lang = "copy"] +pub trait Copy {} +impl Copy for i32 {} +enum Result { + #[lang = "Ok"] Ok(T), + #[lang = "Err"] Err(E), +} +use Result::*; +fn foo() -> i64 { + let n = 1; + $0let m = n + 1; + return 1; + let k = 2;$0 + (n + k) as i64 +}"##, + r##" +#[lang = "copy"] +pub trait Copy {} +impl Copy for i32 {} +enum Result { + #[lang = "Ok"] Ok(T), + #[lang = "Err"] Err(E), +} +use Result::*; +fn foo() -> i64 { + let n = 1; + let k = match fun_name(n) { + Ok(value) => value, + Err(value) => return value, + }; + (n + k) as i64 +} + +fn $0fun_name(n: i32) -> Result { + let m = n + 1; + return Err(1); + let k = 2; + Ok(k) +}"##, + ); + } + + #[test] + fn break_and_continue() { + mark::check!(external_control_flow_break_and_continue); + check_assist_not_applicable( + extract_function, + r##" +fn foo() { + loop { + let n = 1; + $0let m = n + 1; + break; + let k = 2; + continue; + let k = k + 1;$0 + let r = n + k; + } +}"##, + ); + } + + #[test] + fn return_and_break() { + mark::check!(external_control_flow_return_and_bc); + check_assist_not_applicable( + extract_function, + r##" +fn foo() { + loop { + let n = 1; + $0let m = n + 1; + break; + let k = 2; + return; + let k = k + 1;$0 + let r = n + k; + } +}"##, + ); + } + + #[test] + fn break_loop_with_if() { + check_assist( + extract_function, + r##" +fn foo() { + loop { + let mut n = 1; + $0let m = n + 1; + break; + n += m;$0 + let h = 1 + n; + } +}"##, + r##" +fn foo() { + loop { + let mut n = 1; + if fun_name(&mut n) { + break; + } + let h = 1 + n; + } +} + +fn $0fun_name(n: &mut i32) -> bool { + let m = *n + 1; + return true; + *n += m; + false +}"##, + ); + } + + #[test] + fn break_loop_nested() { + check_assist( + extract_function, + r##" +fn foo() { + loop { + let mut n = 1; + $0let m = n + 1; + if m == 42 { + break; + }$0 + let h = 1; + } +}"##, + r##" +fn foo() { + loop { + let mut n = 1; + if fun_name(n) { + break; + } + let h = 1; + } +} + +fn $0fun_name(n: i32) -> bool { + let m = n + 1; + if m == 42 { + return true; + } + false +}"##, + ); + } + + #[test] + fn return_from_nested_loop() { + check_assist( + extract_function, + r##" +fn foo() { + loop { + let n = 1; + $0 + let k = 1; + loop { + return; + } + let m = k + 1;$0 + let h = 1 + m; + } +}"##, + r##" +fn foo() { + loop { + let n = 1; + let m = match fun_name() { + Some(value) => value, + None => return, + }; + let h = 1 + m; + } +} + +fn $0fun_name() -> Option { + let k = 1; + loop { + return None; + } + let m = k + 1; + Some(m) +}"##, + ); + } + + #[test] + fn break_from_nested_loop() { + check_assist( + extract_function, + r##" +fn foo() { + loop { + let n = 1; + $0let k = 1; + loop { + break; + } + let m = k + 1;$0 + let h = 1 + m; + } +}"##, + r##" +fn foo() { + loop { + let n = 1; + let m = fun_name(); + let h = 1 + m; + } +} + +fn $0fun_name() -> i32 { + let k = 1; + loop { + break; + } + let m = k + 1; + m +}"##, + ); + } + + #[test] + fn break_from_nested_and_outer_loops() { + check_assist( + extract_function, + r##" +fn foo() { + loop { + let n = 1; + $0let k = 1; + loop { + break; + } + if k == 42 { + break; + } + let m = k + 1;$0 + let h = 1 + m; + } +}"##, + r##" +fn foo() { + loop { + let n = 1; + let m = match fun_name() { + Some(value) => value, + None => break, + }; + let h = 1 + m; + } +} + +fn $0fun_name() -> Option { + let k = 1; + loop { + break; + } + if k == 42 { + return None; + } + let m = k + 1; + Some(m) +}"##, + ); + } + + #[test] + fn return_from_nested_fn() { + check_assist( + extract_function, + r##" +fn foo() { + loop { + let n = 1; + $0let k = 1; + fn test() { + return; + } + let m = k + 1;$0 + let h = 1 + m; + } +}"##, + r##" +fn foo() { + loop { + let n = 1; + let m = fun_name(); + let h = 1 + m; + } +} + +fn $0fun_name() -> i32 { + let k = 1; + fn test() { + return; + } + let m = k + 1; + m +}"##, + ); + } + + #[test] + fn break_with_value() { + check_assist( + extract_function, + r##" +fn foo() -> i32 { + loop { + let n = 1; + $0let k = 1; + if k == 42 { + break 3; + } + let m = k + 1;$0 + let h = 1; + } +}"##, + r##" +fn foo() -> i32 { + loop { + let n = 1; + if let Some(value) = fun_name() { + break value; + } + let h = 1; + } +} + +fn $0fun_name() -> Option { + let k = 1; + if k == 42 { + return Some(3); + } + let m = k + 1; + None +}"##, + ); + } + + #[test] + fn break_with_value_and_return() { + check_assist( + extract_function, + r##" +fn foo() -> i64 { + loop { + let n = 1; + $0 + let k = 1; + if k == 42 { + break 3; + } + let m = k + 1;$0 + let h = 1 + m; + } +}"##, + r##" +fn foo() -> i64 { + loop { + let n = 1; + let m = match fun_name() { + Ok(value) => value, + Err(value) => break value, + }; + let h = 1 + m; + } +} + +fn $0fun_name() -> Result { + let k = 1; + if k == 42 { + return Err(3); + } + let m = k + 1; + Ok(m) +}"##, + ); + } } diff --git a/crates/assists/src/handlers/generate_function.rs b/crates/assists/src/handlers/generate_function.rs index 1805c1dfd..959824981 100644 --- a/crates/assists/src/handlers/generate_function.rs +++ b/crates/assists/src/handlers/generate_function.rs @@ -215,8 +215,11 @@ fn fn_args( }); } deduplicate_arg_names(&mut arg_names); - let params = arg_names.into_iter().zip(arg_types).map(|(name, ty)| make::param(name, ty)); - Some((None, make::param_list(params))) + let params = arg_names + .into_iter() + .zip(arg_types) + .map(|(name, ty)| make::param(make::ident_pat(make::name(&name)).into(), make::ty(&ty))); + Some((None, make::param_list(None, params))) } /// Makes duplicate argument names unique by appending incrementing numbers. diff --git a/crates/assists/src/tests.rs b/crates/assists/src/tests.rs index 5b9992f15..720f561a1 100644 --- a/crates/assists/src/tests.rs +++ b/crates/assists/src/tests.rs @@ -195,6 +195,7 @@ fn assist_order_if_expr() { let assists = Assist::get(&db, &TEST_CONFIG, false, frange); let mut assists = assists.iter(); + assert_eq!(assists.next().expect("expected assist").label, "Extract into function"); assert_eq!(assists.next().expect("expected assist").label, "Extract into variable"); assert_eq!(assists.next().expect("expected assist").label, "Replace with match"); } @@ -220,6 +221,7 @@ fn assist_filter_works() { let assists = Assist::get(&db, &cfg, false, frange); let mut assists = assists.iter(); + assert_eq!(assists.next().expect("expected assist").label, "Extract into function"); assert_eq!(assists.next().expect("expected assist").label, "Extract into variable"); assert_eq!(assists.next().expect("expected assist").label, "Replace with match"); } @@ -228,9 +230,10 @@ fn assist_filter_works() { let mut cfg = TEST_CONFIG; cfg.allowed = Some(vec![AssistKind::RefactorExtract]); let assists = Assist::get(&db, &cfg, false, frange); - assert_eq!(assists.len(), 1); + assert_eq!(assists.len(), 2); let mut assists = assists.iter(); + assert_eq!(assists.next().expect("expected assist").label, "Extract into function"); assert_eq!(assists.next().expect("expected assist").label, "Extract into variable"); } diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs index 1da5a125e..5f6b96c23 100644 --- a/crates/syntax/src/ast/make.rs +++ b/crates/syntax/src/ast/make.rs @@ -24,11 +24,24 @@ pub fn name_ref(text: &str) -> ast::NameRef { // FIXME: replace stringly-typed constructor with a family of typed ctors, a-la // `expr_xxx`. pub fn ty(text: &str) -> ast::Type { - ast_from_text(&format!("impl {} for D {{}};", text)) + ast_from_text(&format!("fn f() -> {} {{}}", text)) } pub fn ty_unit() -> ast::Type { ty("()") } +// FIXME: handle types of length == 1 +pub fn ty_tuple(types: impl IntoIterator) -> ast::Type { + let contents = types.into_iter().join(", "); + ty(&format!("({})", contents)) +} +// FIXME: handle path to type +pub fn ty_generic(name: ast::NameRef, types: impl IntoIterator) -> ast::Type { + let contents = types.into_iter().join(", "); + ty(&format!("{}<{}>", name, contents)) +} +pub fn ty_ref(target: ast::Type, exclusive: bool) -> ast::Type { + ty(&if exclusive { format!("&mut {}", target) } else { format!("&{}", target) }) +} pub fn assoc_item_list() -> ast::AssocItemList { ast_from_text("impl C for D {};") @@ -175,11 +188,17 @@ pub fn expr_path(path: ast::Path) -> ast::Expr { pub fn expr_continue() -> ast::Expr { expr_from_text("continue") } -pub fn expr_break() -> ast::Expr { - expr_from_text("break") +pub fn expr_break(expr: Option) -> ast::Expr { + match expr { + Some(expr) => expr_from_text(&format!("break {}", expr)), + None => expr_from_text("break"), + } } -pub fn expr_return() -> ast::Expr { - expr_from_text("return") +pub fn expr_return(expr: Option) -> ast::Expr { + match expr { + Some(expr) => expr_from_text(&format!("return {}", expr)), + None => expr_from_text("return"), + } } pub fn expr_match(expr: ast::Expr, match_arm_list: ast::MatchArmList) -> ast::Expr { expr_from_text(&format!("match {} {}", expr, match_arm_list)) @@ -212,6 +231,10 @@ pub fn expr_ref(expr: ast::Expr, exclusive: bool) -> ast::Expr { pub fn expr_paren(expr: ast::Expr) -> ast::Expr { expr_from_text(&format!("({})", expr)) } +pub fn expr_tuple(elements: impl IntoIterator) -> ast::Expr { + let expr = elements.into_iter().format(", "); + expr_from_text(&format!("({})", expr)) +} fn expr_from_text(text: &str) -> ast::Expr { ast_from_text(&format!("const C: () = {};", text)) } @@ -236,6 +259,13 @@ pub fn ident_pat(name: ast::Name) -> ast::IdentPat { ast_from_text(&format!("fn f({}: ())", text)) } } +pub fn ident_mut_pat(name: ast::Name) -> ast::IdentPat { + return from_text(name.text()); + + fn from_text(text: &str) -> ast::IdentPat { + ast_from_text(&format!("fn f(mut {}: ())", text)) + } +} pub fn wildcard_pat() -> ast::WildcardPat { return from_text("_"); @@ -356,17 +386,25 @@ pub fn token(kind: SyntaxKind) -> SyntaxToken { .unwrap_or_else(|| panic!("unhandled token: {:?}", kind)) } -pub fn param(name: String, ty: String) -> ast::Param { - ast_from_text(&format!("fn f({}: {}) {{ }}", name, ty)) +pub fn param(pat: ast::Pat, ty: ast::Type) -> ast::Param { + ast_from_text(&format!("fn f({}: {}) {{ }}", pat, ty)) } pub fn ret_type(ty: ast::Type) -> ast::RetType { ast_from_text(&format!("fn f() -> {} {{ }}", ty)) } -pub fn param_list(pats: impl IntoIterator) -> ast::ParamList { +pub fn param_list( + self_param: Option, + pats: impl IntoIterator, +) -> ast::ParamList { let args = pats.into_iter().join(", "); - ast_from_text(&format!("fn f({}) {{ }}", args)) + let list = match self_param { + Some(self_param) if args.is_empty() => format!("fn f({}) {{ }}", self_param), + Some(self_param) => format!("fn f({}, {}) {{ }}", self_param, args), + None => format!("fn f({}) {{ }}", args), + }; + ast_from_text(&list) } pub fn generic_param(name: String, ty: Option) -> ast::GenericParam { -- cgit v1.2.3