From 9eb19d92dd8d3200f3530faefa7a4048f58d280d Mon Sep 17 00:00:00 2001 From: Vladyslav Katasonov Date: Wed, 10 Feb 2021 19:26:42 +0300 Subject: allow try expr? when extacting function --- crates/assists/src/handlers/extract_function.rs | 377 ++++++++++++++++++++++-- 1 file changed, 347 insertions(+), 30 deletions(-) (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/extract_function.rs b/crates/assists/src/handlers/extract_function.rs index 4372479b9..225a50d2d 100644 --- a/crates/assists/src/handlers/extract_function.rs +++ b/crates/assists/src/handlers/extract_function.rs @@ -84,7 +84,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option // We should not have variables that outlive body if we have expression block return None; } - let control_flow = external_control_flow(&body)?; + let control_flow = external_control_flow(ctx, &body)?; let target_range = body.text_range(); @@ -117,7 +117,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option ) } -fn external_control_flow(body: &FunctionBody) -> Option { +fn external_control_flow(ctx: &AssistContext, body: &FunctionBody) -> Option { let mut ret_expr = None; let mut try_expr = None; let mut break_expr = None; @@ -180,35 +180,71 @@ fn external_control_flow(body: &FunctionBody) -> Option { } } - if try_expr.is_some() { - // FIXME: support try - return None; - } + let kind = match (try_expr, ret_expr, break_expr, continue_expr) { + (Some(e), None, None, None) => { + let func = e.syntax().ancestors().find_map(ast::Fn::cast)?; + let def = ctx.sema.to_def(&func)?; + let ret_ty = def.ret_type(ctx.db()); + let kind = try_kind_of_ty(ret_ty, ctx)?; - let kind = match (ret_expr, break_expr, continue_expr) { - (Some(r), None, None) => match r.expr() { + Some(FlowKind::Try { kind }) + } + (Some(_), Some(r), None, None) => match r.expr() { + Some(expr) => { + if let Some(kind) = expr_err_kind(&expr, ctx) { + Some(FlowKind::TryReturn { expr, kind }) + } else { + mark::hit!(external_control_flow_try_and_return_non_err); + return None; + } + } + None => return None, + }, + (Some(_), _, _, _) => { + mark::hit!(external_control_flow_try_and_bc); + return None; + } + (None, Some(r), None, None) => match r.expr() { Some(expr) => Some(FlowKind::ReturnValue(expr)), None => Some(FlowKind::Return), }, - (Some(_), _, _) => { + (None, Some(_), _, _) => { mark::hit!(external_control_flow_return_and_bc); return None; } - (None, Some(_), Some(_)) => { + (None, None, Some(_), Some(_)) => { mark::hit!(external_control_flow_break_and_continue); return None; } - (None, Some(b), None) => match b.expr() { + (None, None, Some(b), None) => match b.expr() { Some(expr) => Some(FlowKind::BreakValue(expr)), None => Some(FlowKind::Break), }, - (None, None, Some(_)) => Some(FlowKind::Continue), - (None, None, None) => None, + (None, None, None, Some(_)) => Some(FlowKind::Continue), + (None, None, None, None) => None, }; Some(ControlFlow { kind }) } +/// Checks is expr is `Err(_)` or `None` +fn expr_err_kind(expr: &ast::Expr, ctx: &AssistContext) -> Option { + let call_expr = match expr { + ast::Expr::CallExpr(call_expr) => call_expr, + _ => return None, + }; + let func = call_expr.expr()?; + let text = func.syntax().text(); + + if text == "Err" { + Some(TryKind::Result { ty: ctx.sema.type_of_expr(expr)? }) + } else if text == "None" { + Some(TryKind::Option) + } else { + None + } +} + #[derive(Debug)] struct Function { name: String, @@ -330,6 +366,13 @@ enum FlowKind { Return, /// Return with value (`return $expr;`) ReturnValue(ast::Expr), + Try { + kind: TryKind, + }, + TryReturn { + expr: ast::Expr, + kind: TryKind, + }, /// Break without value (`return;`) Break, /// Break with value (`break $expr;`) @@ -338,11 +381,21 @@ enum FlowKind { Continue, } +#[derive(Debug, Clone)] +enum TryKind { + Option, + Result { ty: hir::Type }, +} + impl FlowKind { - fn make_expr(&self, expr: Option) -> ast::Expr { + fn make_result_handler(&self, expr: Option) -> ast::Expr { match self { FlowKind::Return | FlowKind::ReturnValue(_) => make::expr_return(expr), FlowKind::Break | FlowKind::BreakValue(_) => make::expr_break(expr), + FlowKind::Try { .. } | FlowKind::TryReturn { .. } => { + stdx::never!("cannot have result handler with try"); + expr.unwrap_or_else(|| make::expr_return(None)) + } FlowKind::Continue => { stdx::always!(expr.is_none(), "continue with value is not possible"); make::expr_continue() @@ -352,12 +405,34 @@ impl FlowKind { fn expr_ty(&self, ctx: &AssistContext) -> Option { match self { - FlowKind::ReturnValue(expr) | FlowKind::BreakValue(expr) => ctx.sema.type_of_expr(expr), + FlowKind::ReturnValue(expr) + | FlowKind::BreakValue(expr) + | FlowKind::TryReturn { expr, .. } => ctx.sema.type_of_expr(expr), + FlowKind::Try { .. } => { + stdx::never!("try does not have defined expr_ty"); + None + } FlowKind::Return | FlowKind::Break | FlowKind::Continue => None, } } } +fn try_kind_of_ty(ty: hir::Type, ctx: &AssistContext) -> Option { + if ty.is_unknown() { + // We favour Result for `expr?` + return Some(TryKind::Result { ty }); + } + let adt = ty.as_adt()?; + let name = adt.name(ctx.db()); + // FIXME: use lang items to determine if it is std type or user defined + // E.g. if user happens to define type named `Option`, we would have false positive + match name.to_string().as_str() { + "Option" => Some(TryKind::Option), + "Result" => Some(TryKind::Result { ty }), + _ => None, + } +} + #[derive(Debug)] enum RetType { Expr(hir::Type), @@ -851,7 +926,7 @@ fn format_replacement(ctx: &AssistContext, fun: &Function, indent: IndentLevel) let handler = FlowHandler::from_ret_ty(fun, &ret_ty); - let expr = handler.make_expr(call_expr).indent(indent); + let expr = handler.make_call_expr(call_expr).indent(indent); let mut buf = String::new(); match fun.vars_defined_in_body_and_outlive.as_slice() { @@ -877,6 +952,7 @@ fn format_replacement(ctx: &AssistContext, fun: &Function, indent: IndentLevel) enum FlowHandler { None, + Try { kind: TryKind }, If { action: FlowKind }, IfOption { action: FlowKind }, MatchOption { none: FlowKind }, @@ -897,6 +973,9 @@ impl FlowHandler { FlowKind::ReturnValue(_) | FlowKind::BreakValue(_) => { FlowHandler::IfOption { action } } + FlowKind::Try { kind } | FlowKind::TryReturn { kind, .. } => { + FlowHandler::Try { kind: kind.clone() } + } } } else { match flow_kind { @@ -906,17 +985,21 @@ impl FlowHandler { FlowKind::ReturnValue(_) | FlowKind::BreakValue(_) => { FlowHandler::MatchResult { err: action } } + FlowKind::Try { kind } | FlowKind::TryReturn { kind, .. } => { + FlowHandler::Try { kind: kind.clone() } + } } } } } } - fn make_expr(&self, call_expr: ast::Expr) -> ast::Expr { + fn make_call_expr(&self, call_expr: ast::Expr) -> ast::Expr { match self { FlowHandler::None => call_expr, + FlowHandler::Try { kind: _ } => make::expr_try(call_expr), FlowHandler::If { action } => { - let action = action.make_expr(None); + let action = action.make_result_handler(None); let stmt = make::expr_stmt(action); let block = make::block_expr(iter::once(stmt.into()), None); let condition = make::condition(call_expr, None); @@ -928,7 +1011,7 @@ impl FlowHandler { let pattern = make::tuple_struct_pat(path, iter::once(value_pat.into())); let cond = make::condition(call_expr, Some(pattern.into())); let value = make::expr_path(make_path_from_text("value")); - let action_expr = action.make_expr(Some(value)); + let action_expr = action.make_result_handler(Some(value)); let action_stmt = make::expr_stmt(action_expr); let then = make::block_expr(iter::once(action_stmt.into()), None); make::expr_if(cond, then, None) @@ -946,7 +1029,7 @@ impl FlowHandler { let none_arm = { let path = make_path_from_text("None"); let pat = make::path_pat(path); - make::match_arm(iter::once(pat), none.make_expr(None)) + make::match_arm(iter::once(pat), none.make_result_handler(None)) }; let arms = make::match_arm_list(vec![some_arm, none_arm]); make::expr_match(call_expr, arms) @@ -967,7 +1050,7 @@ impl FlowHandler { let value_pat = make::ident_pat(make::name(err_name)); let pat = make::tuple_struct_pat(path, iter::once(value_pat.into())); let value = make::expr_path(make_path_from_text(err_name)); - make::match_arm(iter::once(pat.into()), err.make_expr(Some(value))) + make::match_arm(iter::once(pat.into()), err.make_result_handler(Some(value))) }; let arms = make::match_arm_list(vec![ok_arm, err_arm]); make::expr_match(call_expr, arms) @@ -1035,14 +1118,25 @@ impl FunType { } fn make_ret_ty(ctx: &AssistContext, module: hir::Module, fun: &Function) -> Option { - let ty = fun.return_type(ctx); - let handler = FlowHandler::from_ret_ty(fun, &ty); + let fun_ty = fun.return_type(ctx); + let handler = FlowHandler::from_ret_ty(fun, &fun_ty); let ret_ty = match &handler { FlowHandler::None => { - if matches!(ty, FunType::Unit) { + if matches!(fun_ty, FunType::Unit) { return None; } - ty.make_ty(ctx, module) + fun_ty.make_ty(ctx, module) + } + FlowHandler::Try { kind: TryKind::Option } => { + make::ty_generic(make::name_ref("Option"), iter::once(fun_ty.make_ty(ctx, module))) + } + FlowHandler::Try { kind: TryKind::Result { ty: parent_ret_ty } } => { + let handler_ty = + result_err_ty(parent_ret_ty, ctx, module).unwrap_or_else(make::ty_unit); + make::ty_generic( + make::name_ref("Result"), + vec![fun_ty.make_ty(ctx, module), handler_ty], + ) } FlowHandler::If { .. } => make::ty("bool"), FlowHandler::IfOption { action } => { @@ -1053,17 +1147,42 @@ fn make_ret_ty(ctx: &AssistContext, module: hir::Module, fun: &Function) -> Opti make::ty_generic(make::name_ref("Option"), iter::once(handler_ty)) } FlowHandler::MatchOption { .. } => { - make::ty_generic(make::name_ref("Option"), iter::once(ty.make_ty(ctx, module))) + make::ty_generic(make::name_ref("Option"), iter::once(fun_ty.make_ty(ctx, module))) } FlowHandler::MatchResult { err } => { let handler_ty = err.expr_ty(ctx).map(|ty| make_ty(&ty, ctx, module)).unwrap_or_else(make::ty_unit); - make::ty_generic(make::name_ref("Result"), vec![ty.make_ty(ctx, module), handler_ty]) + make::ty_generic( + make::name_ref("Result"), + vec![fun_ty.make_ty(ctx, module), handler_ty], + ) } }; Some(make::ret_type(ret_ty)) } +/// Extract `E` type from `Result` +fn result_err_ty( + parent_ret_ty: &hir::Type, + ctx: &AssistContext, + module: hir::Module, +) -> Option { + // FIXME: use hir to extract argument information + // currently we use `format -> take part -> parse` + let path_ty = match make_ty(&parent_ret_ty, ctx, module) { + ast::Type::PathType(path_ty) => path_ty, + _ => return None, + }; + let arg_list = path_ty.path()?.segment()?.generic_arg_list()?; + let err_arg = arg_list.generic_args().nth(1)?; + let type_arg = match err_arg { + ast::GenericArg::TypeArg(type_arg) => type_arg, + _ => return None, + }; + + type_arg.ty() +} + fn make_body( ctx: &AssistContext, old_indent: IndentLevel, @@ -1128,6 +1247,18 @@ fn make_body( let block = match &handler { FlowHandler::None => block, + FlowHandler::Try { kind } => { + let block = with_default_tail_expr(block, make::expr_unit()); + map_tail_expr(block, |tail_expr| { + let constructor = match kind { + TryKind::Option => "Some", + TryKind::Result { .. } => "Ok", + }; + let func = make::expr_path(make_path_from_text(constructor)); + let args = make::arg_list(iter::once(tail_expr)); + make::expr_call(func, args) + }) + } FlowHandler::If { .. } => { let lit_false = ast::Literal::cast(make::tokens::literal("false").parent()).unwrap(); with_tail_expr(block, lit_false.into()) @@ -1142,9 +1273,9 @@ fn make_body( make::expr_call(some, args) }), FlowHandler::MatchResult { .. } => map_tail_expr(block, |tail_expr| { - let some = make::expr_path(make_path_from_text("Ok")); + let ok = make::expr_path(make_path_from_text("Ok")); let args = make::arg_list(iter::once(tail_expr)); - make::expr_call(some, args) + make::expr_call(ok, args) }), }; @@ -1159,6 +1290,13 @@ fn map_tail_expr(block: ast::BlockExpr, f: impl FnOnce(ast::Expr) -> ast::Expr) make::block_expr(block.statements(), Some(f(tail_expr))) } +fn with_default_tail_expr(block: ast::BlockExpr, tail_expr: ast::Expr) -> ast::BlockExpr { + match block.tail_expr() { + Some(_) => block, + None => make::block_expr(block.statements(), Some(tail_expr)), + } +} + fn with_tail_expr(block: ast::BlockExpr, tail_expr: ast::Expr) -> ast::BlockExpr { let stmt_tail = block.tail_expr().map(|expr| make::expr_stmt(expr).into()); let stmts = block.statements().chain(stmt_tail); @@ -1295,7 +1433,7 @@ fn update_external_control_flow(handler: &FlowHandler, syntax: &SyntaxNode) -> S fn make_rewritten_flow(handler: &FlowHandler, arg_expr: Option) -> Option { let value = match handler { - FlowHandler::None => return None, + FlowHandler::None | FlowHandler::Try { .. } => return None, FlowHandler::If { .. } => { ast::Literal::cast(make::tokens::literal("true").parent()).unwrap().into() } @@ -3036,6 +3174,185 @@ fn $0fun_name() -> Result { } let m = k + 1; Ok(m) +}"##, + ); + } + + #[test] + fn try_option() { + check_assist( + extract_function, + r##" +enum Option { None, Some(T), } +use Option::*; +fn bar() -> Option { None } +fn foo() -> Option<()> { + let n = bar()?; + $0let k = foo()?; + let m = k + 1;$0 + let h = 1 + m; + Some(()) +}"##, + r##" +enum Option { None, Some(T), } +use Option::*; +fn bar() -> Option { None } +fn foo() -> Option<()> { + let n = bar()?; + let m = fun_name()?; + let h = 1 + m; + Some(()) +} + +fn $0fun_name() -> Option { + let k = foo()?; + let m = k + 1; + Some(m) +}"##, + ); + } + + #[test] + fn try_option_unit() { + check_assist( + extract_function, + r##" +enum Option { None, Some(T), } +use Option::*; +fn foo() -> Option<()> { + let n = 1; + $0let k = foo()?; + let m = k + 1;$0 + let h = 1 + n; + Some(()) +}"##, + r##" +enum Option { None, Some(T), } +use Option::*; +fn foo() -> Option<()> { + let n = 1; + fun_name()?; + let h = 1 + n; + Some(()) +} + +fn $0fun_name() -> Option<()> { + let k = foo()?; + let m = k + 1; + Some(()) +}"##, + ); + } + + #[test] + fn try_result() { + check_assist( + extract_function, + r##" +enum Result { Ok(T), Err(E), } +use Result::*; +fn foo() -> Result<(), i64> { + let n = 1; + $0let k = foo()?; + let m = k + 1;$0 + let h = 1 + m; + Ok(()) +}"##, + r##" +enum Result { Ok(T), Err(E), } +use Result::*; +fn foo() -> Result<(), i64> { + let n = 1; + let m = fun_name()?; + let h = 1 + m; + Ok(()) +} + +fn $0fun_name() -> Result { + let k = foo()?; + let m = k + 1; + Ok(m) +}"##, + ); + } + + #[test] + fn try_result_with_return() { + check_assist( + extract_function, + r##" +enum Result { Ok(T), Err(E), } +use Result::*; +fn foo() -> Result<(), i64> { + let n = 1; + $0let k = foo()?; + if k == 42 { + return Err(1); + } + let m = k + 1;$0 + let h = 1 + m; + Ok(()) +}"##, + r##" +enum Result { Ok(T), Err(E), } +use Result::*; +fn foo() -> Result<(), i64> { + let n = 1; + let m = fun_name()?; + let h = 1 + m; + Ok(()) +} + +fn $0fun_name() -> Result { + let k = foo()?; + if k == 42 { + return Err(1); + } + let m = k + 1; + Ok(m) +}"##, + ); + } + + #[test] + fn try_and_break() { + mark::check!(external_control_flow_try_and_bc); + check_assist_not_applicable( + extract_function, + r##" +enum Option { None, Some(T) } +use Option::*; +fn foo() -> Option<()> { + loop { + let n = Some(1); + $0let m = n? + 1; + break; + let k = 2; + let k = k + 1;$0 + let r = n + k; + } + Some(()) +}"##, + ); + } + + #[test] + fn try_and_return_ok() { + mark::check!(external_control_flow_try_and_return_non_err); + check_assist_not_applicable( + extract_function, + r##" +enum Result { Ok(T), Err(E), } +use Result::*; +fn foo() -> Result<(), i64> { + let n = 1; + $0let k = foo()?; + if k == 42 { + return Ok(1); + } + let m = k + 1;$0 + let h = 1 + m; + Ok(()) }"##, ); } -- cgit v1.2.3