From 0a7c8512ffec9b6cf695f546ac5f4f297c92fa53 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Fri, 6 Nov 2020 02:13:29 +0100 Subject: Support closures in infer_function_return_type assist --- .../src/handlers/infer_function_return_type.rs | 133 ++++++++++++++++++--- 1 file changed, 114 insertions(+), 19 deletions(-) (limited to 'crates/assists/src') diff --git a/crates/assists/src/handlers/infer_function_return_type.rs b/crates/assists/src/handlers/infer_function_return_type.rs index da60ff9de..f363a56f3 100644 --- a/crates/assists/src/handlers/infer_function_return_type.rs +++ b/crates/assists/src/handlers/infer_function_return_type.rs @@ -1,12 +1,12 @@ use hir::HirDisplay; -use syntax::{ast, AstNode, TextSize}; +use syntax::{ast, AstNode, SyntaxToken, TextSize}; use test_utils::mark; use crate::{AssistContext, AssistId, AssistKind, Assists}; // Assist: infer_function_return_type // -// Adds the return type to a function inferred from its tail expression if it doesn't have a return +// Adds the return type to a function or closure inferred from its tail expression if it doesn't have a return // type specified. // // ``` @@ -18,36 +18,52 @@ use crate::{AssistContext, AssistId, AssistKind, Assists}; // ``` pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { let expr = ctx.find_node_at_offset::()?; - let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?; - - if func.ret_type().is_some() { - mark::hit!(existing_ret_type); - return None; - } - let body = func.body()?; - let tail_expr = body.expr()?; - // check whether the expr we were at is indeed the tail expression - if !tail_expr.syntax().text_range().contains_range(expr.syntax().text_range()) { - mark::hit!(not_tail_expr); - return None; - } - let module = ctx.sema.scope(func.syntax()).module()?; + let (tail_expr, insert_pos) = extract(expr)?; + let module = ctx.sema.scope(tail_expr.syntax()).module()?; let ty = ctx.sema.type_of_expr(&tail_expr)?; let ty = ty.display_source_code(ctx.db(), module.into()).ok()?; - let rparen = func.param_list()?.r_paren_token()?; acc.add( AssistId("change_return_type_to_result", AssistKind::RefactorRewrite), "Wrap return type in Result", tail_expr.syntax().text_range(), |builder| { - let insert_pos = rparen.text_range().end() + TextSize::from(1); - + let insert_pos = insert_pos.text_range().end() + TextSize::from(1); builder.insert(insert_pos, &format!("-> {} ", ty)); }, ) } +fn extract(expr: ast::Expr) -> Option<(ast::Expr, SyntaxToken)> { + let (ret_ty, tail_expr, insert_pos) = + if let Some(closure) = expr.syntax().ancestors().find_map(ast::ClosureExpr::cast) { + let tail_expr = match closure.body()? { + ast::Expr::BlockExpr(block) => block.expr()?, + body => body, + }; + let ret_ty = closure.ret_type(); + let rpipe = closure.param_list()?.syntax().last_token()?; + (ret_ty, tail_expr, rpipe) + } else { + let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?; + let tail_expr = func.body()?.expr()?; + let ret_ty = func.ret_type(); + let rparen = func.param_list()?.r_paren_token()?; + (ret_ty, tail_expr, rparen) + }; + if ret_ty.is_some() { + mark::hit!(existing_ret_type); + mark::hit!(existing_ret_type_closure); + return None; + } + // check whether the expr we were at is indeed the tail expression + if !tail_expr.syntax().text_range().contains_range(expr.syntax().text_range()) { + mark::hit!(not_tail_expr); + return None; + } + Some((tail_expr, insert_pos)) +} + #[cfg(test)] mod tests { use crate::tests::{check_assist, check_assist_not_applicable}; @@ -110,4 +126,83 @@ mod tests { }"#, ); } + + #[test] + fn infer_return_type_closure_block() { + check_assist( + infer_function_return_type, + r#"fn foo() { + |x: i32| { + x<|> + }; + }"#, + r#"fn foo() { + |x: i32| -> i32 { + x + }; + }"#, + ); + } + + #[test] + fn infer_return_type_closure() { + check_assist( + infer_function_return_type, + r#"fn foo() { + |x: i32| x<|>; + }"#, + r#"fn foo() { + |x: i32| -> i32 x; + }"#, + ); + } + + #[test] + fn infer_return_type_nested_closure() { + check_assist( + infer_function_return_type, + r#"fn foo() { + || { + if true { + 3<|> + } else { + 5 + } + } + }"#, + r#"fn foo() { + || -> i32 { + if true { + 3 + } else { + 5 + } + } + }"#, + ); + } + + #[test] + fn not_applicable_ret_type_specified_closure() { + mark::check!(existing_ret_type_closure); + check_assist_not_applicable( + infer_function_return_type, + r#"fn foo() { + || -> i32 { 3<|> } + }"#, + ); + } + + #[test] + fn not_applicable_non_tail_expr_closure() { + check_assist_not_applicable( + infer_function_return_type, + r#"fn foo() { + || -> i32 { + let x = 3<|>; + 6 + } + }"#, + ); + } } -- cgit v1.2.3