From a14df19d825152aff823fae3344f9e4c2d31937b Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Fri, 6 Nov 2020 01:47:41 +0100 Subject: Add infer_function_return_type assist --- .../src/handlers/infer_function_return_type.rs | 113 +++++++++++++++++++++ crates/assists/src/lib.rs | 2 + crates/assists/src/tests/generated.rs | 13 +++ 3 files changed, 128 insertions(+) create mode 100644 crates/assists/src/handlers/infer_function_return_type.rs (limited to 'crates/assists') diff --git a/crates/assists/src/handlers/infer_function_return_type.rs b/crates/assists/src/handlers/infer_function_return_type.rs new file mode 100644 index 000000000..da60ff9de --- /dev/null +++ b/crates/assists/src/handlers/infer_function_return_type.rs @@ -0,0 +1,113 @@ +use hir::HirDisplay; +use syntax::{ast, AstNode, TextSize}; +use test_utils::mark; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: infer_function_return_type +// +// Adds the return type to a function inferred from its tail expression if it doesn't have a return +// type specified. +// +// ``` +// fn foo() { 4<|>2i32 } +// ``` +// -> +// ``` +// fn foo() -> i32 { 42i32 } +// ``` +pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let expr = ctx.find_node_at_offset::()?; + let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?; + + if func.ret_type().is_some() { + mark::hit!(existing_ret_type); + return None; + } + let body = func.body()?; + let tail_expr = body.expr()?; + // check whether the expr we were at is indeed the tail expression + if !tail_expr.syntax().text_range().contains_range(expr.syntax().text_range()) { + mark::hit!(not_tail_expr); + return None; + } + let module = ctx.sema.scope(func.syntax()).module()?; + let ty = ctx.sema.type_of_expr(&tail_expr)?; + let ty = ty.display_source_code(ctx.db(), module.into()).ok()?; + let rparen = func.param_list()?.r_paren_token()?; + + acc.add( + AssistId("change_return_type_to_result", AssistKind::RefactorRewrite), + "Wrap return type in Result", + tail_expr.syntax().text_range(), + |builder| { + let insert_pos = rparen.text_range().end() + TextSize::from(1); + + builder.insert(insert_pos, &format!("-> {} ", ty)); + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn infer_return_type() { + check_assist( + infer_function_return_type, + r#"fn foo() { + 45<|> + }"#, + r#"fn foo() -> i32 { + 45 + }"#, + ); + } + + #[test] + fn infer_return_type_nested() { + check_assist( + infer_function_return_type, + r#"fn foo() { + if true { + 3<|> + } else { + 5 + } + }"#, + r#"fn foo() -> i32 { + if true { + 3 + } else { + 5 + } + }"#, + ); + } + + #[test] + fn not_applicable_ret_type_specified() { + mark::check!(existing_ret_type); + check_assist_not_applicable( + infer_function_return_type, + r#"fn foo() -> i32 { + ( 45<|> + 32 ) * 123 + }"#, + ); + } + + #[test] + fn not_applicable_non_tail_expr() { + mark::check!(not_tail_expr); + check_assist_not_applicable( + infer_function_return_type, + r#"fn foo() { + let x = <|>3; + ( 45 + 32 ) * 123 + }"#, + ); + } +} diff --git a/crates/assists/src/lib.rs b/crates/assists/src/lib.rs index b804e495d..af88b3437 100644 --- a/crates/assists/src/lib.rs +++ b/crates/assists/src/lib.rs @@ -143,6 +143,7 @@ mod handlers { mod generate_function; mod generate_impl; mod generate_new; + mod infer_function_return_type; mod inline_local_variable; mod introduce_named_lifetime; mod invert_if; @@ -190,6 +191,7 @@ mod handlers { generate_function::generate_function, generate_impl::generate_impl, generate_new::generate_new, + infer_function_return_type::infer_function_return_type, inline_local_variable::inline_local_variable, introduce_named_lifetime::introduce_named_lifetime, invert_if::invert_if, diff --git a/crates/assists/src/tests/generated.rs b/crates/assists/src/tests/generated.rs index acbf5b652..168e1626a 100644 --- a/crates/assists/src/tests/generated.rs +++ b/crates/assists/src/tests/generated.rs @@ -505,6 +505,19 @@ impl Ctx { ) } +#[test] +fn doctest_infer_function_return_type() { + check_doc_test( + "infer_function_return_type", + r#####" +fn foo() { 4<|>2i32 } +"#####, + r#####" +fn foo() -> i32 { 42i32 } +"#####, + ) +} + #[test] fn doctest_inline_local_variable() { check_doc_test( -- cgit v1.2.3