From fc34403018079ea053f26d0a31b7517053c7dd8c Mon Sep 17 00:00:00 2001 From: Aleksey Kladov Date: Thu, 13 Aug 2020 17:33:38 +0200 Subject: Rename ra_assists -> assists --- crates/assists/src/handlers/add_custom_impl.rs | 208 ++++ crates/assists/src/handlers/add_explicit_type.rs | 211 ++++ .../src/handlers/add_missing_impl_members.rs | 711 +++++++++++++ crates/assists/src/handlers/add_turbo_fish.rs | 164 +++ crates/assists/src/handlers/apply_demorgan.rs | 93 ++ crates/assists/src/handlers/auto_import.rs | 1088 ++++++++++++++++++++ .../src/handlers/change_return_type_to_result.rs | 998 ++++++++++++++++++ crates/assists/src/handlers/change_visibility.rs | 200 ++++ crates/assists/src/handlers/early_return.rs | 515 +++++++++ crates/assists/src/handlers/expand_glob_import.rs | 391 +++++++ .../handlers/extract_struct_from_enum_variant.rs | 317 ++++++ crates/assists/src/handlers/extract_variable.rs | 588 +++++++++++ crates/assists/src/handlers/fill_match_arms.rs | 747 ++++++++++++++ crates/assists/src/handlers/fix_visibility.rs | 607 +++++++++++ crates/assists/src/handlers/flip_binexpr.rs | 142 +++ crates/assists/src/handlers/flip_comma.rs | 84 ++ crates/assists/src/handlers/flip_trait_bound.rs | 121 +++ crates/assists/src/handlers/generate_derive.rs | 132 +++ .../src/handlers/generate_from_impl_for_enum.rs | 200 ++++ crates/assists/src/handlers/generate_function.rs | 1058 +++++++++++++++++++ crates/assists/src/handlers/generate_impl.rs | 110 ++ crates/assists/src/handlers/generate_new.rs | 421 ++++++++ .../assists/src/handlers/inline_local_variable.rs | 695 +++++++++++++ .../src/handlers/introduce_named_lifetime.rs | 318 ++++++ crates/assists/src/handlers/invert_if.rs | 109 ++ crates/assists/src/handlers/merge_imports.rs | 321 ++++++ crates/assists/src/handlers/merge_match_arms.rs | 248 +++++ crates/assists/src/handlers/move_bounds.rs | 152 +++ crates/assists/src/handlers/move_guard.rs | 293 ++++++ crates/assists/src/handlers/raw_string.rs | 504 +++++++++ crates/assists/src/handlers/remove_dbg.rs | 205 ++++ crates/assists/src/handlers/remove_mut.rs | 37 + crates/assists/src/handlers/reorder_fields.rs | 220 ++++ .../src/handlers/replace_if_let_with_match.rs | 257 +++++ .../src/handlers/replace_let_with_if_let.rs | 100 ++ .../handlers/replace_qualified_name_with_use.rs | 688 +++++++++++++ .../src/handlers/replace_unwrap_with_match.rs | 187 ++++ crates/assists/src/handlers/split_import.rs | 79 ++ crates/assists/src/handlers/unwrap_block.rs | 517 ++++++++++ 39 files changed, 14036 insertions(+) create mode 100644 crates/assists/src/handlers/add_custom_impl.rs create mode 100644 crates/assists/src/handlers/add_explicit_type.rs create mode 100644 crates/assists/src/handlers/add_missing_impl_members.rs create mode 100644 crates/assists/src/handlers/add_turbo_fish.rs create mode 100644 crates/assists/src/handlers/apply_demorgan.rs create mode 100644 crates/assists/src/handlers/auto_import.rs create mode 100644 crates/assists/src/handlers/change_return_type_to_result.rs create mode 100644 crates/assists/src/handlers/change_visibility.rs create mode 100644 crates/assists/src/handlers/early_return.rs create mode 100644 crates/assists/src/handlers/expand_glob_import.rs create mode 100644 crates/assists/src/handlers/extract_struct_from_enum_variant.rs create mode 100644 crates/assists/src/handlers/extract_variable.rs create mode 100644 crates/assists/src/handlers/fill_match_arms.rs create mode 100644 crates/assists/src/handlers/fix_visibility.rs create mode 100644 crates/assists/src/handlers/flip_binexpr.rs create mode 100644 crates/assists/src/handlers/flip_comma.rs create mode 100644 crates/assists/src/handlers/flip_trait_bound.rs create mode 100644 crates/assists/src/handlers/generate_derive.rs create mode 100644 crates/assists/src/handlers/generate_from_impl_for_enum.rs create mode 100644 crates/assists/src/handlers/generate_function.rs create mode 100644 crates/assists/src/handlers/generate_impl.rs create mode 100644 crates/assists/src/handlers/generate_new.rs create mode 100644 crates/assists/src/handlers/inline_local_variable.rs create mode 100644 crates/assists/src/handlers/introduce_named_lifetime.rs create mode 100644 crates/assists/src/handlers/invert_if.rs create mode 100644 crates/assists/src/handlers/merge_imports.rs create mode 100644 crates/assists/src/handlers/merge_match_arms.rs create mode 100644 crates/assists/src/handlers/move_bounds.rs create mode 100644 crates/assists/src/handlers/move_guard.rs create mode 100644 crates/assists/src/handlers/raw_string.rs create mode 100644 crates/assists/src/handlers/remove_dbg.rs create mode 100644 crates/assists/src/handlers/remove_mut.rs create mode 100644 crates/assists/src/handlers/reorder_fields.rs create mode 100644 crates/assists/src/handlers/replace_if_let_with_match.rs create mode 100644 crates/assists/src/handlers/replace_let_with_if_let.rs create mode 100644 crates/assists/src/handlers/replace_qualified_name_with_use.rs create mode 100644 crates/assists/src/handlers/replace_unwrap_with_match.rs create mode 100644 crates/assists/src/handlers/split_import.rs create mode 100644 crates/assists/src/handlers/unwrap_block.rs (limited to 'crates/assists/src/handlers') diff --git a/crates/assists/src/handlers/add_custom_impl.rs b/crates/assists/src/handlers/add_custom_impl.rs new file mode 100644 index 000000000..8757fa33f --- /dev/null +++ b/crates/assists/src/handlers/add_custom_impl.rs @@ -0,0 +1,208 @@ +use itertools::Itertools; +use syntax::{ + ast::{self, AstNode}, + Direction, SmolStr, + SyntaxKind::{IDENT, WHITESPACE}, + TextRange, TextSize, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + AssistId, AssistKind, +}; + +// Assist: add_custom_impl +// +// Adds impl block for derived trait. +// +// ``` +// #[derive(Deb<|>ug, Display)] +// struct S; +// ``` +// -> +// ``` +// #[derive(Display)] +// struct S; +// +// impl Debug for S { +// $0 +// } +// ``` +pub(crate) fn add_custom_impl(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let attr = ctx.find_node_at_offset::()?; + let input = attr.token_tree()?; + + let attr_name = attr + .syntax() + .descendants_with_tokens() + .filter(|t| t.kind() == IDENT) + .find_map(|i| i.into_token()) + .filter(|t| *t.text() == "derive")? + .text() + .clone(); + + let trait_token = + ctx.token_at_offset().find(|t| t.kind() == IDENT && *t.text() != attr_name)?; + + let annotated = attr.syntax().siblings(Direction::Next).find_map(ast::Name::cast)?; + let annotated_name = annotated.syntax().text().to_string(); + let start_offset = annotated.syntax().parent()?.text_range().end(); + + let label = + format!("Add custom impl `{}` for `{}`", trait_token.text().as_str(), annotated_name); + + let target = attr.syntax().text_range(); + acc.add(AssistId("add_custom_impl", AssistKind::Refactor), label, target, |builder| { + let new_attr_input = input + .syntax() + .descendants_with_tokens() + .filter(|t| t.kind() == IDENT) + .filter_map(|t| t.into_token().map(|t| t.text().clone())) + .filter(|t| t != trait_token.text()) + .collect::>(); + let has_more_derives = !new_attr_input.is_empty(); + + if has_more_derives { + let new_attr_input = format!("({})", new_attr_input.iter().format(", ")); + builder.replace(input.syntax().text_range(), new_attr_input); + } else { + let attr_range = attr.syntax().text_range(); + builder.delete(attr_range); + + let line_break_range = attr + .syntax() + .next_sibling_or_token() + .filter(|t| t.kind() == WHITESPACE) + .map(|t| t.text_range()) + .unwrap_or_else(|| TextRange::new(TextSize::from(0), TextSize::from(0))); + builder.delete(line_break_range); + } + + match ctx.config.snippet_cap { + Some(cap) => { + builder.insert_snippet( + cap, + start_offset, + format!("\n\nimpl {} for {} {{\n $0\n}}", trait_token, annotated_name), + ); + } + None => { + builder.insert( + start_offset, + format!("\n\nimpl {} for {} {{\n\n}}", trait_token, annotated_name), + ); + } + } + }) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn add_custom_impl_for_unique_input() { + check_assist( + add_custom_impl, + " +#[derive(Debu<|>g)] +struct Foo { + bar: String, +} + ", + " +struct Foo { + bar: String, +} + +impl Debug for Foo { + $0 +} + ", + ) + } + + #[test] + fn add_custom_impl_for_with_visibility_modifier() { + check_assist( + add_custom_impl, + " +#[derive(Debug<|>)] +pub struct Foo { + bar: String, +} + ", + " +pub struct Foo { + bar: String, +} + +impl Debug for Foo { + $0 +} + ", + ) + } + + #[test] + fn add_custom_impl_when_multiple_inputs() { + check_assist( + add_custom_impl, + " +#[derive(Display, Debug<|>, Serialize)] +struct Foo {} + ", + " +#[derive(Display, Serialize)] +struct Foo {} + +impl Debug for Foo { + $0 +} + ", + ) + } + + #[test] + fn test_ignore_derive_macro_without_input() { + check_assist_not_applicable( + add_custom_impl, + " +#[derive(<|>)] +struct Foo {} + ", + ) + } + + #[test] + fn test_ignore_if_cursor_on_param() { + check_assist_not_applicable( + add_custom_impl, + " +#[derive<|>(Debug)] +struct Foo {} + ", + ); + + check_assist_not_applicable( + add_custom_impl, + " +#[derive(Debug)<|>] +struct Foo {} + ", + ) + } + + #[test] + fn test_ignore_if_not_derive() { + check_assist_not_applicable( + add_custom_impl, + " +#[allow(non_camel_<|>case_types)] +struct Foo {} + ", + ) + } +} diff --git a/crates/assists/src/handlers/add_explicit_type.rs b/crates/assists/src/handlers/add_explicit_type.rs new file mode 100644 index 000000000..563cbf505 --- /dev/null +++ b/crates/assists/src/handlers/add_explicit_type.rs @@ -0,0 +1,211 @@ +use hir::HirDisplay; +use syntax::{ + ast::{self, AstNode, LetStmt, NameOwner}, + TextRange, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: add_explicit_type +// +// Specify type for a let binding. +// +// ``` +// fn main() { +// let x<|> = 92; +// } +// ``` +// -> +// ``` +// fn main() { +// let x: i32 = 92; +// } +// ``` +pub(crate) fn add_explicit_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let let_stmt = ctx.find_node_at_offset::()?; + let module = ctx.sema.scope(let_stmt.syntax()).module()?; + let expr = let_stmt.initializer()?; + // Must be a binding + let pat = match let_stmt.pat()? { + ast::Pat::IdentPat(bind_pat) => bind_pat, + _ => return None, + }; + let pat_range = pat.syntax().text_range(); + // The binding must have a name + let name = pat.name()?; + let name_range = name.syntax().text_range(); + let stmt_range = let_stmt.syntax().text_range(); + let eq_range = let_stmt.eq_token()?.text_range(); + // Assist should only be applicable if cursor is between 'let' and '=' + let let_range = TextRange::new(stmt_range.start(), eq_range.start()); + let cursor_in_range = let_range.contains_range(ctx.frange.range); + if !cursor_in_range { + return None; + } + // Assist not applicable if the type has already been specified + // and it has no placeholders + let ascribed_ty = let_stmt.ty(); + if let Some(ty) = &ascribed_ty { + if ty.syntax().descendants().find_map(ast::InferType::cast).is_none() { + return None; + } + } + // Infer type + let ty = ctx.sema.type_of_expr(&expr)?; + + if ty.contains_unknown() || ty.is_closure() { + return None; + } + + let inferred_type = ty.display_source_code(ctx.db(), module.into()).ok()?; + acc.add( + AssistId("add_explicit_type", AssistKind::RefactorRewrite), + format!("Insert explicit type `{}`", inferred_type), + pat_range, + |builder| match ascribed_ty { + Some(ascribed_ty) => { + builder.replace(ascribed_ty.syntax().text_range(), inferred_type); + } + None => { + builder.insert(name_range.end(), format!(": {}", inferred_type)); + } + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + #[test] + fn add_explicit_type_target() { + check_assist_target(add_explicit_type, "fn f() { let a<|> = 1; }", "a"); + } + + #[test] + fn add_explicit_type_works_for_simple_expr() { + check_assist(add_explicit_type, "fn f() { let a<|> = 1; }", "fn f() { let a: i32 = 1; }"); + } + + #[test] + fn add_explicit_type_works_for_underscore() { + check_assist( + add_explicit_type, + "fn f() { let a<|>: _ = 1; }", + "fn f() { let a: i32 = 1; }", + ); + } + + #[test] + fn add_explicit_type_works_for_nested_underscore() { + check_assist( + add_explicit_type, + r#" + enum Option { + Some(T), + None + } + + fn f() { + let a<|>: Option<_> = Option::Some(1); + }"#, + r#" + enum Option { + Some(T), + None + } + + fn f() { + let a: Option = Option::Some(1); + }"#, + ); + } + + #[test] + fn add_explicit_type_works_for_macro_call() { + check_assist( + add_explicit_type, + r"macro_rules! v { () => {0u64} } fn f() { let a<|> = v!(); }", + r"macro_rules! v { () => {0u64} } fn f() { let a: u64 = v!(); }", + ); + } + + #[test] + fn add_explicit_type_works_for_macro_call_recursive() { + check_assist( + add_explicit_type, + r#"macro_rules! u { () => {0u64} } macro_rules! v { () => {u!()} } fn f() { let a<|> = v!(); }"#, + r#"macro_rules! u { () => {0u64} } macro_rules! v { () => {u!()} } fn f() { let a: u64 = v!(); }"#, + ); + } + + #[test] + fn add_explicit_type_not_applicable_if_ty_not_inferred() { + check_assist_not_applicable(add_explicit_type, "fn f() { let a<|> = None; }"); + } + + #[test] + fn add_explicit_type_not_applicable_if_ty_already_specified() { + check_assist_not_applicable(add_explicit_type, "fn f() { let a<|>: i32 = 1; }"); + } + + #[test] + fn add_explicit_type_not_applicable_if_specified_ty_is_tuple() { + check_assist_not_applicable(add_explicit_type, "fn f() { let a<|>: (i32, i32) = (3, 4); }"); + } + + #[test] + fn add_explicit_type_not_applicable_if_cursor_after_equals() { + check_assist_not_applicable( + add_explicit_type, + "fn f() {let a =<|> match 1 {2 => 3, 3 => 5};}", + ) + } + + #[test] + fn add_explicit_type_not_applicable_if_cursor_before_let() { + check_assist_not_applicable( + add_explicit_type, + "fn f() <|>{let a = match 1 {2 => 3, 3 => 5};}", + ) + } + + #[test] + fn closure_parameters_are_not_added() { + check_assist_not_applicable( + add_explicit_type, + r#" +fn main() { + let multiply_by_two<|> = |i| i * 3; + let six = multiply_by_two(2); +}"#, + ) + } + + #[test] + fn default_generics_should_not_be_added() { + check_assist( + add_explicit_type, + r#" +struct Test { + k: K, + t: T, +} + +fn main() { + let test<|> = Test { t: 23u8, k: 33 }; +}"#, + r#" +struct Test { + k: K, + t: T, +} + +fn main() { + let test: Test = Test { t: 23u8, k: 33 }; +}"#, + ); + } +} diff --git a/crates/assists/src/handlers/add_missing_impl_members.rs b/crates/assists/src/handlers/add_missing_impl_members.rs new file mode 100644 index 000000000..81b61ebf8 --- /dev/null +++ b/crates/assists/src/handlers/add_missing_impl_members.rs @@ -0,0 +1,711 @@ +use hir::HasSource; +use syntax::{ + ast::{ + self, + edit::{self, AstNodeEdit, IndentLevel}, + make, AstNode, NameOwner, + }, + SmolStr, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + ast_transform::{self, AstTransform, QualifyPaths, SubstituteTypeParams}, + utils::{get_missing_assoc_items, render_snippet, resolve_target_trait, Cursor}, + AssistId, AssistKind, +}; + +#[derive(PartialEq)] +enum AddMissingImplMembersMode { + DefaultMethodsOnly, + NoDefaultMethods, +} + +// Assist: add_impl_missing_members +// +// Adds scaffold for required impl members. +// +// ``` +// trait Trait { +// Type X; +// fn foo(&self) -> T; +// fn bar(&self) {} +// } +// +// impl Trait for () {<|> +// +// } +// ``` +// -> +// ``` +// trait Trait { +// Type X; +// fn foo(&self) -> T; +// fn bar(&self) {} +// } +// +// impl Trait for () { +// fn foo(&self) -> u32 { +// ${0:todo!()} +// } +// +// } +// ``` +pub(crate) fn add_missing_impl_members(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + add_missing_impl_members_inner( + acc, + ctx, + AddMissingImplMembersMode::NoDefaultMethods, + "add_impl_missing_members", + "Implement missing members", + ) +} + +// Assist: add_impl_default_members +// +// Adds scaffold for overriding default impl members. +// +// ``` +// trait Trait { +// Type X; +// fn foo(&self); +// fn bar(&self) {} +// } +// +// impl Trait for () { +// Type X = (); +// fn foo(&self) {}<|> +// +// } +// ``` +// -> +// ``` +// trait Trait { +// Type X; +// fn foo(&self); +// fn bar(&self) {} +// } +// +// impl Trait for () { +// Type X = (); +// fn foo(&self) {} +// $0fn bar(&self) {} +// +// } +// ``` +pub(crate) fn add_missing_default_members(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + add_missing_impl_members_inner( + acc, + ctx, + AddMissingImplMembersMode::DefaultMethodsOnly, + "add_impl_default_members", + "Implement default members", + ) +} + +fn add_missing_impl_members_inner( + acc: &mut Assists, + ctx: &AssistContext, + mode: AddMissingImplMembersMode, + assist_id: &'static str, + label: &'static str, +) -> Option<()> { + let _p = profile::span("add_missing_impl_members_inner"); + let impl_def = ctx.find_node_at_offset::()?; + let impl_item_list = impl_def.assoc_item_list()?; + + let trait_ = resolve_target_trait(&ctx.sema, &impl_def)?; + + let def_name = |item: &ast::AssocItem| -> Option { + match item { + ast::AssocItem::Fn(def) => def.name(), + ast::AssocItem::TypeAlias(def) => def.name(), + ast::AssocItem::Const(def) => def.name(), + ast::AssocItem::MacroCall(_) => None, + } + .map(|it| it.text().clone()) + }; + + let missing_items = get_missing_assoc_items(&ctx.sema, &impl_def) + .iter() + .map(|i| match i { + hir::AssocItem::Function(i) => ast::AssocItem::Fn(i.source(ctx.db()).value), + hir::AssocItem::TypeAlias(i) => ast::AssocItem::TypeAlias(i.source(ctx.db()).value), + hir::AssocItem::Const(i) => ast::AssocItem::Const(i.source(ctx.db()).value), + }) + .filter(|t| def_name(&t).is_some()) + .filter(|t| match t { + ast::AssocItem::Fn(def) => match mode { + AddMissingImplMembersMode::DefaultMethodsOnly => def.body().is_some(), + AddMissingImplMembersMode::NoDefaultMethods => def.body().is_none(), + }, + _ => mode == AddMissingImplMembersMode::NoDefaultMethods, + }) + .collect::>(); + + if missing_items.is_empty() { + return None; + } + + let target = impl_def.syntax().text_range(); + acc.add(AssistId(assist_id, AssistKind::QuickFix), label, target, |builder| { + let n_existing_items = impl_item_list.assoc_items().count(); + let source_scope = ctx.sema.scope_for_def(trait_); + let target_scope = ctx.sema.scope(impl_item_list.syntax()); + let ast_transform = QualifyPaths::new(&target_scope, &source_scope) + .or(SubstituteTypeParams::for_trait_impl(&source_scope, trait_, impl_def)); + let items = missing_items + .into_iter() + .map(|it| ast_transform::apply(&*ast_transform, it)) + .map(|it| match it { + ast::AssocItem::Fn(def) => ast::AssocItem::Fn(add_body(def)), + ast::AssocItem::TypeAlias(def) => ast::AssocItem::TypeAlias(def.remove_bounds()), + _ => it, + }) + .map(|it| edit::remove_attrs_and_docs(&it)); + let new_impl_item_list = impl_item_list.append_items(items); + let first_new_item = new_impl_item_list.assoc_items().nth(n_existing_items).unwrap(); + + let original_range = impl_item_list.syntax().text_range(); + match ctx.config.snippet_cap { + None => builder.replace(original_range, new_impl_item_list.to_string()), + Some(cap) => { + let mut cursor = Cursor::Before(first_new_item.syntax()); + let placeholder; + if let ast::AssocItem::Fn(func) = &first_new_item { + if let Some(m) = func.syntax().descendants().find_map(ast::MacroCall::cast) { + if m.syntax().text() == "todo!()" { + placeholder = m; + cursor = Cursor::Replace(placeholder.syntax()); + } + } + } + builder.replace_snippet( + cap, + original_range, + render_snippet(cap, new_impl_item_list.syntax(), cursor), + ) + } + }; + }) +} + +fn add_body(fn_def: ast::Fn) -> ast::Fn { + if fn_def.body().is_some() { + return fn_def; + } + let body = make::block_expr(None, Some(make::expr_todo())).indent(IndentLevel(1)); + fn_def.with_body(body) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_add_missing_impl_members() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { + type Output; + + const CONST: usize = 42; + + fn foo(&self); + fn bar(&self); + fn baz(&self); +} + +struct S; + +impl Foo for S { + fn bar(&self) {} +<|> +}"#, + r#" +trait Foo { + type Output; + + const CONST: usize = 42; + + fn foo(&self); + fn bar(&self); + fn baz(&self); +} + +struct S; + +impl Foo for S { + fn bar(&self) {} + $0type Output; + const CONST: usize = 42; + fn foo(&self) { + todo!() + } + fn baz(&self) { + todo!() + } + +}"#, + ); + } + + #[test] + fn test_copied_overriden_members() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { + fn foo(&self); + fn bar(&self) -> bool { true } + fn baz(&self) -> u32 { 42 } +} + +struct S; + +impl Foo for S { + fn bar(&self) {} +<|> +}"#, + r#" +trait Foo { + fn foo(&self); + fn bar(&self) -> bool { true } + fn baz(&self) -> u32 { 42 } +} + +struct S; + +impl Foo for S { + fn bar(&self) {} + fn foo(&self) { + ${0:todo!()} + } + +}"#, + ); + } + + #[test] + fn test_empty_impl_def() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { fn foo(&self); } +struct S; +impl Foo for S { <|> }"#, + r#" +trait Foo { fn foo(&self); } +struct S; +impl Foo for S { + fn foo(&self) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn fill_in_type_params_1() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { fn foo(&self, t: T) -> &T; } +struct S; +impl Foo for S { <|> }"#, + r#" +trait Foo { fn foo(&self, t: T) -> &T; } +struct S; +impl Foo for S { + fn foo(&self, t: u32) -> &u32 { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn fill_in_type_params_2() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { fn foo(&self, t: T) -> &T; } +struct S; +impl Foo for S { <|> }"#, + r#" +trait Foo { fn foo(&self, t: T) -> &T; } +struct S; +impl Foo for S { + fn foo(&self, t: U) -> &U { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_cursor_after_empty_impl_def() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { fn foo(&self); } +struct S; +impl Foo for S {}<|>"#, + r#" +trait Foo { fn foo(&self); } +struct S; +impl Foo for S { + fn foo(&self) { + ${0:todo!()} + } +}"#, + ) + } + + #[test] + fn test_qualify_path_1() { + check_assist( + add_missing_impl_members, + r#" +mod foo { + pub struct Bar; + trait Foo { fn foo(&self, bar: Bar); } +} +struct S; +impl foo::Foo for S { <|> }"#, + r#" +mod foo { + pub struct Bar; + trait Foo { fn foo(&self, bar: Bar); } +} +struct S; +impl foo::Foo for S { + fn foo(&self, bar: foo::Bar) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_qualify_path_generic() { + check_assist( + add_missing_impl_members, + r#" +mod foo { + pub struct Bar; + trait Foo { fn foo(&self, bar: Bar); } +} +struct S; +impl foo::Foo for S { <|> }"#, + r#" +mod foo { + pub struct Bar; + trait Foo { fn foo(&self, bar: Bar); } +} +struct S; +impl foo::Foo for S { + fn foo(&self, bar: foo::Bar) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_qualify_path_and_substitute_param() { + check_assist( + add_missing_impl_members, + r#" +mod foo { + pub struct Bar; + trait Foo { fn foo(&self, bar: Bar); } +} +struct S; +impl foo::Foo for S { <|> }"#, + r#" +mod foo { + pub struct Bar; + trait Foo { fn foo(&self, bar: Bar); } +} +struct S; +impl foo::Foo for S { + fn foo(&self, bar: foo::Bar) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_substitute_param_no_qualify() { + // when substituting params, the substituted param should not be qualified! + check_assist( + add_missing_impl_members, + r#" +mod foo { + trait Foo { fn foo(&self, bar: T); } + pub struct Param; +} +struct Param; +struct S; +impl foo::Foo for S { <|> }"#, + r#" +mod foo { + trait Foo { fn foo(&self, bar: T); } + pub struct Param; +} +struct Param; +struct S; +impl foo::Foo for S { + fn foo(&self, bar: Param) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_qualify_path_associated_item() { + check_assist( + add_missing_impl_members, + r#" +mod foo { + pub struct Bar; + impl Bar { type Assoc = u32; } + trait Foo { fn foo(&self, bar: Bar::Assoc); } +} +struct S; +impl foo::Foo for S { <|> }"#, + r#" +mod foo { + pub struct Bar; + impl Bar { type Assoc = u32; } + trait Foo { fn foo(&self, bar: Bar::Assoc); } +} +struct S; +impl foo::Foo for S { + fn foo(&self, bar: foo::Bar::Assoc) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_qualify_path_nested() { + check_assist( + add_missing_impl_members, + r#" +mod foo { + pub struct Bar; + pub struct Baz; + trait Foo { fn foo(&self, bar: Bar); } +} +struct S; +impl foo::Foo for S { <|> }"#, + r#" +mod foo { + pub struct Bar; + pub struct Baz; + trait Foo { fn foo(&self, bar: Bar); } +} +struct S; +impl foo::Foo for S { + fn foo(&self, bar: foo::Bar) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_qualify_path_fn_trait_notation() { + check_assist( + add_missing_impl_members, + r#" +mod foo { + pub trait Fn { type Output; } + trait Foo { fn foo(&self, bar: dyn Fn(u32) -> i32); } +} +struct S; +impl foo::Foo for S { <|> }"#, + r#" +mod foo { + pub trait Fn { type Output; } + trait Foo { fn foo(&self, bar: dyn Fn(u32) -> i32); } +} +struct S; +impl foo::Foo for S { + fn foo(&self, bar: dyn Fn(u32) -> i32) { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_empty_trait() { + check_assist_not_applicable( + add_missing_impl_members, + r#" +trait Foo; +struct S; +impl Foo for S { <|> }"#, + ) + } + + #[test] + fn test_ignore_unnamed_trait_members_and_default_methods() { + check_assist_not_applicable( + add_missing_impl_members, + r#" +trait Foo { + fn (arg: u32); + fn valid(some: u32) -> bool { false } +} +struct S; +impl Foo for S { <|> }"#, + ) + } + + #[test] + fn test_with_docstring_and_attrs() { + check_assist( + add_missing_impl_members, + r#" +#[doc(alias = "test alias")] +trait Foo { + /// doc string + type Output; + + #[must_use] + fn foo(&self); +} +struct S; +impl Foo for S {}<|>"#, + r#" +#[doc(alias = "test alias")] +trait Foo { + /// doc string + type Output; + + #[must_use] + fn foo(&self); +} +struct S; +impl Foo for S { + $0type Output; + fn foo(&self) { + todo!() + } +}"#, + ) + } + + #[test] + fn test_default_methods() { + check_assist( + add_missing_default_members, + r#" +trait Foo { + type Output; + + const CONST: usize = 42; + + fn valid(some: u32) -> bool { false } + fn foo(some: u32) -> bool; +} +struct S; +impl Foo for S { <|> }"#, + r#" +trait Foo { + type Output; + + const CONST: usize = 42; + + fn valid(some: u32) -> bool { false } + fn foo(some: u32) -> bool; +} +struct S; +impl Foo for S { + $0fn valid(some: u32) -> bool { false } +}"#, + ) + } + + #[test] + fn test_generic_single_default_parameter() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { + fn bar(&self, other: &T); +} + +struct S; +impl Foo for S { <|> }"#, + r#" +trait Foo { + fn bar(&self, other: &T); +} + +struct S; +impl Foo for S { + fn bar(&self, other: &Self) { + ${0:todo!()} + } +}"#, + ) + } + + #[test] + fn test_generic_default_parameter_is_second() { + check_assist( + add_missing_impl_members, + r#" +trait Foo { + fn bar(&self, this: &T1, that: &T2); +} + +struct S; +impl Foo for S { <|> }"#, + r#" +trait Foo { + fn bar(&self, this: &T1, that: &T2); +} + +struct S; +impl Foo for S { + fn bar(&self, this: &T, that: &Self) { + ${0:todo!()} + } +}"#, + ) + } + + #[test] + fn test_assoc_type_bounds_are_removed() { + check_assist( + add_missing_impl_members, + r#" +trait Tr { + type Ty: Copy + 'static; +} + +impl Tr for ()<|> { +}"#, + r#" +trait Tr { + type Ty: Copy + 'static; +} + +impl Tr for () { + $0type Ty; +}"#, + ) + } +} diff --git a/crates/assists/src/handlers/add_turbo_fish.rs b/crates/assists/src/handlers/add_turbo_fish.rs new file mode 100644 index 000000000..f4f997d8e --- /dev/null +++ b/crates/assists/src/handlers/add_turbo_fish.rs @@ -0,0 +1,164 @@ +use ide_db::defs::{classify_name_ref, Definition, NameRefClass}; +use syntax::{ast, AstNode, SyntaxKind, T}; +use test_utils::mark; + +use crate::{ + assist_context::{AssistContext, Assists}, + AssistId, AssistKind, +}; + +// Assist: add_turbo_fish +// +// Adds `::<_>` to a call of a generic method or function. +// +// ``` +// fn make() -> T { todo!() } +// fn main() { +// let x = make<|>(); +// } +// ``` +// -> +// ``` +// fn make() -> T { todo!() } +// fn main() { +// let x = make::<${0:_}>(); +// } +// ``` +pub(crate) fn add_turbo_fish(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let ident = ctx.find_token_at_offset(SyntaxKind::IDENT).or_else(|| { + let arg_list = ctx.find_node_at_offset::()?; + if arg_list.args().count() > 0 { + return None; + } + mark::hit!(add_turbo_fish_after_call); + arg_list.l_paren_token()?.prev_token().filter(|it| it.kind() == SyntaxKind::IDENT) + })?; + let next_token = ident.next_token()?; + if next_token.kind() == T![::] { + mark::hit!(add_turbo_fish_one_fish_is_enough); + return None; + } + let name_ref = ast::NameRef::cast(ident.parent())?; + let def = match classify_name_ref(&ctx.sema, &name_ref)? { + NameRefClass::Definition(def) => def, + NameRefClass::ExternCrate(_) | NameRefClass::FieldShorthand { .. } => return None, + }; + let fun = match def { + Definition::ModuleDef(hir::ModuleDef::Function(it)) => it, + _ => return None, + }; + let generics = hir::GenericDef::Function(fun).params(ctx.sema.db); + if generics.is_empty() { + mark::hit!(add_turbo_fish_non_generic); + return None; + } + acc.add( + AssistId("add_turbo_fish", AssistKind::RefactorRewrite), + "Add `::<>`", + ident.text_range(), + |builder| match ctx.config.snippet_cap { + Some(cap) => builder.insert_snippet(cap, ident.text_range().end(), "::<${0:_}>"), + None => builder.insert(ident.text_range().end(), "::<_>"), + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + use test_utils::mark; + + #[test] + fn add_turbo_fish_function() { + check_assist( + add_turbo_fish, + r#" +fn make() -> T {} +fn main() { + make<|>(); +} +"#, + r#" +fn make() -> T {} +fn main() { + make::<${0:_}>(); +} +"#, + ); + } + + #[test] + fn add_turbo_fish_after_call() { + mark::check!(add_turbo_fish_after_call); + check_assist( + add_turbo_fish, + r#" +fn make() -> T {} +fn main() { + make()<|>; +} +"#, + r#" +fn make() -> T {} +fn main() { + make::<${0:_}>(); +} +"#, + ); + } + + #[test] + fn add_turbo_fish_method() { + check_assist( + add_turbo_fish, + r#" +struct S; +impl S { + fn make(&self) -> T {} +} +fn main() { + S.make<|>(); +} +"#, + r#" +struct S; +impl S { + fn make(&self) -> T {} +} +fn main() { + S.make::<${0:_}>(); +} +"#, + ); + } + + #[test] + fn add_turbo_fish_one_fish_is_enough() { + mark::check!(add_turbo_fish_one_fish_is_enough); + check_assist_not_applicable( + add_turbo_fish, + r#" +fn make() -> T {} +fn main() { + make<|>::<()>(); +} +"#, + ); + } + + #[test] + fn add_turbo_fish_non_generic() { + mark::check!(add_turbo_fish_non_generic); + check_assist_not_applicable( + add_turbo_fish, + r#" +fn make() -> () {} +fn main() { + make<|>(); +} +"#, + ); + } +} diff --git a/crates/assists/src/handlers/apply_demorgan.rs b/crates/assists/src/handlers/apply_demorgan.rs new file mode 100644 index 000000000..1a6fdafda --- /dev/null +++ b/crates/assists/src/handlers/apply_demorgan.rs @@ -0,0 +1,93 @@ +use syntax::ast::{self, AstNode}; + +use crate::{utils::invert_boolean_expression, AssistContext, AssistId, AssistKind, Assists}; + +// Assist: apply_demorgan +// +// Apply https://en.wikipedia.org/wiki/De_Morgan%27s_laws[De Morgan's law]. +// This transforms expressions of the form `!l || !r` into `!(l && r)`. +// This also works with `&&`. This assist can only be applied with the cursor +// on either `||` or `&&`, with both operands being a negation of some kind. +// This means something of the form `!x` or `x != y`. +// +// ``` +// fn main() { +// if x != 4 ||<|> !y {} +// } +// ``` +// -> +// ``` +// fn main() { +// if !(x == 4 && y) {} +// } +// ``` +pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let expr = ctx.find_node_at_offset::()?; + let op = expr.op_kind()?; + let op_range = expr.op_token()?.text_range(); + let opposite_op = opposite_logic_op(op)?; + let cursor_in_range = op_range.contains_range(ctx.frange.range); + if !cursor_in_range { + return None; + } + + let lhs = expr.lhs()?; + let lhs_range = lhs.syntax().text_range(); + let not_lhs = invert_boolean_expression(lhs); + + let rhs = expr.rhs()?; + let rhs_range = rhs.syntax().text_range(); + let not_rhs = invert_boolean_expression(rhs); + + acc.add( + AssistId("apply_demorgan", AssistKind::RefactorRewrite), + "Apply De Morgan's law", + op_range, + |edit| { + edit.replace(op_range, opposite_op); + edit.replace(lhs_range, format!("!({}", not_lhs.syntax().text())); + edit.replace(rhs_range, format!("{})", not_rhs.syntax().text())); + }, + ) +} + +// Return the opposite text for a given logical operator, if it makes sense +fn opposite_logic_op(kind: ast::BinOp) -> Option<&'static str> { + match kind { + ast::BinOp::BooleanOr => Some("&&"), + ast::BinOp::BooleanAnd => Some("||"), + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable}; + + #[test] + fn demorgan_turns_and_into_or() { + check_assist(apply_demorgan, "fn f() { !x &&<|> !x }", "fn f() { !(x || x) }") + } + + #[test] + fn demorgan_turns_or_into_and() { + check_assist(apply_demorgan, "fn f() { !x ||<|> !x }", "fn f() { !(x && x) }") + } + + #[test] + fn demorgan_removes_inequality() { + check_assist(apply_demorgan, "fn f() { x != x ||<|> !x }", "fn f() { !(x == x && x) }") + } + + #[test] + fn demorgan_general_case() { + check_assist(apply_demorgan, "fn f() { x ||<|> x }", "fn f() { !(!x && !x) }") + } + + #[test] + fn demorgan_doesnt_apply_with_cursor_not_on_op() { + check_assist_not_applicable(apply_demorgan, "fn f() { <|> !x || !x }") + } +} diff --git a/crates/assists/src/handlers/auto_import.rs b/crates/assists/src/handlers/auto_import.rs new file mode 100644 index 000000000..cce789972 --- /dev/null +++ b/crates/assists/src/handlers/auto_import.rs @@ -0,0 +1,1088 @@ +use std::collections::BTreeSet; + +use either::Either; +use hir::{ + AsAssocItem, AssocItemContainer, ModPath, Module, ModuleDef, PathResolution, Semantics, Trait, + Type, +}; +use ide_db::{imports_locator, RootDatabase}; +use rustc_hash::FxHashSet; +use syntax::{ + ast::{self, AstNode}, + SyntaxNode, +}; + +use crate::{ + utils::insert_use_statement, AssistContext, AssistId, AssistKind, Assists, GroupLabel, +}; + +// Assist: auto_import +// +// If the name is unresolved, provides all possible imports for it. +// +// ``` +// fn main() { +// let map = HashMap<|>::new(); +// } +// # pub mod std { pub mod collections { pub struct HashMap { } } } +// ``` +// -> +// ``` +// use std::collections::HashMap; +// +// fn main() { +// let map = HashMap::new(); +// } +// # pub mod std { pub mod collections { pub struct HashMap { } } } +// ``` +pub(crate) fn auto_import(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let auto_import_assets = AutoImportAssets::new(ctx)?; + let proposed_imports = auto_import_assets.search_for_imports(ctx); + if proposed_imports.is_empty() { + return None; + } + + let range = ctx.sema.original_range(&auto_import_assets.syntax_under_caret).range; + let group = auto_import_assets.get_import_group_message(); + for import in proposed_imports { + acc.add_group( + &group, + AssistId("auto_import", AssistKind::QuickFix), + format!("Import `{}`", &import), + range, + |builder| { + insert_use_statement( + &auto_import_assets.syntax_under_caret, + &import, + ctx, + builder.text_edit_builder(), + ); + }, + ); + } + Some(()) +} + +#[derive(Debug)] +struct AutoImportAssets { + import_candidate: ImportCandidate, + module_with_name_to_import: Module, + syntax_under_caret: SyntaxNode, +} + +impl AutoImportAssets { + fn new(ctx: &AssistContext) -> Option { + if let Some(path_under_caret) = ctx.find_node_at_offset_with_descend::() { + Self::for_regular_path(path_under_caret, &ctx) + } else { + Self::for_method_call(ctx.find_node_at_offset_with_descend()?, &ctx) + } + } + + fn for_method_call(method_call: ast::MethodCallExpr, ctx: &AssistContext) -> Option { + let syntax_under_caret = method_call.syntax().to_owned(); + let module_with_name_to_import = ctx.sema.scope(&syntax_under_caret).module()?; + Some(Self { + import_candidate: ImportCandidate::for_method_call(&ctx.sema, &method_call)?, + module_with_name_to_import, + syntax_under_caret, + }) + } + + fn for_regular_path(path_under_caret: ast::Path, ctx: &AssistContext) -> Option { + let syntax_under_caret = path_under_caret.syntax().to_owned(); + if syntax_under_caret.ancestors().find_map(ast::Use::cast).is_some() { + return None; + } + + let module_with_name_to_import = ctx.sema.scope(&syntax_under_caret).module()?; + Some(Self { + import_candidate: ImportCandidate::for_regular_path(&ctx.sema, &path_under_caret)?, + module_with_name_to_import, + syntax_under_caret, + }) + } + + fn get_search_query(&self) -> &str { + match &self.import_candidate { + ImportCandidate::UnqualifiedName(name) => name, + ImportCandidate::QualifierStart(qualifier_start) => qualifier_start, + ImportCandidate::TraitAssocItem(_, trait_assoc_item_name) => trait_assoc_item_name, + ImportCandidate::TraitMethod(_, trait_method_name) => trait_method_name, + } + } + + fn get_import_group_message(&self) -> GroupLabel { + let name = match &self.import_candidate { + ImportCandidate::UnqualifiedName(name) => format!("Import {}", name), + ImportCandidate::QualifierStart(qualifier_start) => { + format!("Import {}", qualifier_start) + } + ImportCandidate::TraitAssocItem(_, trait_assoc_item_name) => { + format!("Import a trait for item {}", trait_assoc_item_name) + } + ImportCandidate::TraitMethod(_, trait_method_name) => { + format!("Import a trait for method {}", trait_method_name) + } + }; + GroupLabel(name) + } + + fn search_for_imports(&self, ctx: &AssistContext) -> BTreeSet { + let _p = profile::span("auto_import::search_for_imports"); + let db = ctx.db(); + let current_crate = self.module_with_name_to_import.krate(); + imports_locator::find_imports(&ctx.sema, current_crate, &self.get_search_query()) + .into_iter() + .filter_map(|candidate| match &self.import_candidate { + ImportCandidate::TraitAssocItem(assoc_item_type, _) => { + let located_assoc_item = match candidate { + Either::Left(ModuleDef::Function(located_function)) => located_function + .as_assoc_item(db) + .map(|assoc| assoc.container(db)) + .and_then(Self::assoc_to_trait), + Either::Left(ModuleDef::Const(located_const)) => located_const + .as_assoc_item(db) + .map(|assoc| assoc.container(db)) + .and_then(Self::assoc_to_trait), + _ => None, + }?; + + let mut trait_candidates = FxHashSet::default(); + trait_candidates.insert(located_assoc_item.into()); + + assoc_item_type + .iterate_path_candidates( + db, + current_crate, + &trait_candidates, + None, + |_, assoc| Self::assoc_to_trait(assoc.container(db)), + ) + .map(ModuleDef::from) + .map(Either::Left) + } + ImportCandidate::TraitMethod(function_callee, _) => { + let located_assoc_item = + if let Either::Left(ModuleDef::Function(located_function)) = candidate { + located_function + .as_assoc_item(db) + .map(|assoc| assoc.container(db)) + .and_then(Self::assoc_to_trait) + } else { + None + }?; + + let mut trait_candidates = FxHashSet::default(); + trait_candidates.insert(located_assoc_item.into()); + + function_callee + .iterate_method_candidates( + db, + current_crate, + &trait_candidates, + None, + |_, function| { + Self::assoc_to_trait(function.as_assoc_item(db)?.container(db)) + }, + ) + .map(ModuleDef::from) + .map(Either::Left) + } + _ => Some(candidate), + }) + .filter_map(|candidate| match candidate { + Either::Left(module_def) => { + self.module_with_name_to_import.find_use_path(db, module_def) + } + Either::Right(macro_def) => { + self.module_with_name_to_import.find_use_path(db, macro_def) + } + }) + .filter(|use_path| !use_path.segments.is_empty()) + .take(20) + .collect::>() + } + + fn assoc_to_trait(assoc: AssocItemContainer) -> Option { + if let AssocItemContainer::Trait(extracted_trait) = assoc { + Some(extracted_trait) + } else { + None + } + } +} + +#[derive(Debug)] +enum ImportCandidate { + /// Simple name like 'HashMap' + UnqualifiedName(String), + /// First part of the qualified name. + /// For 'std::collections::HashMap', that will be 'std'. + QualifierStart(String), + /// A trait associated function (with no self parameter) or associated constant. + /// For 'test_mod::TestEnum::test_function', `Type` is the `test_mod::TestEnum` expression type + /// and `String` is the `test_function` + TraitAssocItem(Type, String), + /// A trait method with self parameter. + /// For 'test_enum.test_method()', `Type` is the `test_enum` expression type + /// and `String` is the `test_method` + TraitMethod(Type, String), +} + +impl ImportCandidate { + fn for_method_call( + sema: &Semantics, + method_call: &ast::MethodCallExpr, + ) -> Option { + if sema.resolve_method_call(method_call).is_some() { + return None; + } + Some(Self::TraitMethod( + sema.type_of_expr(&method_call.expr()?)?, + method_call.name_ref()?.syntax().to_string(), + )) + } + + fn for_regular_path( + sema: &Semantics, + path_under_caret: &ast::Path, + ) -> Option { + if sema.resolve_path(path_under_caret).is_some() { + return None; + } + + let segment = path_under_caret.segment()?; + if let Some(qualifier) = path_under_caret.qualifier() { + let qualifier_start = qualifier.syntax().descendants().find_map(ast::NameRef::cast)?; + let qualifier_start_path = + qualifier_start.syntax().ancestors().find_map(ast::Path::cast)?; + if let Some(qualifier_start_resolution) = sema.resolve_path(&qualifier_start_path) { + let qualifier_resolution = if qualifier_start_path == qualifier { + qualifier_start_resolution + } else { + sema.resolve_path(&qualifier)? + }; + if let PathResolution::Def(ModuleDef::Adt(assoc_item_path)) = qualifier_resolution { + Some(ImportCandidate::TraitAssocItem( + assoc_item_path.ty(sema.db), + segment.syntax().to_string(), + )) + } else { + None + } + } else { + Some(ImportCandidate::QualifierStart(qualifier_start.syntax().to_string())) + } + } else { + Some(ImportCandidate::UnqualifiedName( + segment.syntax().descendants().find_map(ast::NameRef::cast)?.syntax().to_string(), + )) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + #[test] + fn applicable_when_found_an_import() { + check_assist( + auto_import, + r" + <|>PubStruct + + pub mod PubMod { + pub struct PubStruct; + } + ", + r" + use PubMod::PubStruct; + + PubStruct + + pub mod PubMod { + pub struct PubStruct; + } + ", + ); + } + + #[test] + fn applicable_when_found_an_import_in_macros() { + check_assist( + auto_import, + r" + macro_rules! foo { + ($i:ident) => { fn foo(a: $i) {} } + } + foo!(Pub<|>Struct); + + pub mod PubMod { + pub struct PubStruct; + } + ", + r" + use PubMod::PubStruct; + + macro_rules! foo { + ($i:ident) => { fn foo(a: $i) {} } + } + foo!(PubStruct); + + pub mod PubMod { + pub struct PubStruct; + } + ", + ); + } + + #[test] + fn auto_imports_are_merged() { + check_assist( + auto_import, + r" + use PubMod::PubStruct1; + + struct Test { + test: Pub<|>Struct2, + } + + pub mod PubMod { + pub struct PubStruct1; + pub struct PubStruct2 { + _t: T, + } + } + ", + r" + use PubMod::{PubStruct2, PubStruct1}; + + struct Test { + test: PubStruct2, + } + + pub mod PubMod { + pub struct PubStruct1; + pub struct PubStruct2 { + _t: T, + } + } + ", + ); + } + + #[test] + fn applicable_when_found_multiple_imports() { + check_assist( + auto_import, + r" + PubSt<|>ruct + + pub mod PubMod1 { + pub struct PubStruct; + } + pub mod PubMod2 { + pub struct PubStruct; + } + pub mod PubMod3 { + pub struct PubStruct; + } + ", + r" + use PubMod3::PubStruct; + + PubStruct + + pub mod PubMod1 { + pub struct PubStruct; + } + pub mod PubMod2 { + pub struct PubStruct; + } + pub mod PubMod3 { + pub struct PubStruct; + } + ", + ); + } + + #[test] + fn not_applicable_for_already_imported_types() { + check_assist_not_applicable( + auto_import, + r" + use PubMod::PubStruct; + + PubStruct<|> + + pub mod PubMod { + pub struct PubStruct; + } + ", + ); + } + + #[test] + fn not_applicable_for_types_with_private_paths() { + check_assist_not_applicable( + auto_import, + r" + PrivateStruct<|> + + pub mod PubMod { + struct PrivateStruct; + } + ", + ); + } + + #[test] + fn not_applicable_when_no_imports_found() { + check_assist_not_applicable( + auto_import, + " + PubStruct<|>", + ); + } + + #[test] + fn not_applicable_in_import_statements() { + check_assist_not_applicable( + auto_import, + r" + use PubStruct<|>; + + pub mod PubMod { + pub struct PubStruct; + }", + ); + } + + #[test] + fn function_import() { + check_assist( + auto_import, + r" + test_function<|> + + pub mod PubMod { + pub fn test_function() {}; + } + ", + r" + use PubMod::test_function; + + test_function + + pub mod PubMod { + pub fn test_function() {}; + } + ", + ); + } + + #[test] + fn macro_import() { + check_assist( + auto_import, + r" +//- /lib.rs crate:crate_with_macro +#[macro_export] +macro_rules! foo { + () => () +} + +//- /main.rs crate:main deps:crate_with_macro +fn main() { + foo<|> +} +", + r"use crate_with_macro::foo; + +fn main() { + foo +} +", + ); + } + + #[test] + fn auto_import_target() { + check_assist_target( + auto_import, + r" + struct AssistInfo { + group_label: Option<<|>GroupLabel>, + } + + mod m { pub struct GroupLabel; } + ", + "GroupLabel", + ) + } + + #[test] + fn not_applicable_when_path_start_is_imported() { + check_assist_not_applicable( + auto_import, + r" + pub mod mod1 { + pub mod mod2 { + pub mod mod3 { + pub struct TestStruct; + } + } + } + + use mod1::mod2; + fn main() { + mod2::mod3::TestStruct<|> + } + ", + ); + } + + #[test] + fn not_applicable_for_imported_function() { + check_assist_not_applicable( + auto_import, + r" + pub mod test_mod { + pub fn test_function() {} + } + + use test_mod::test_function; + fn main() { + test_function<|> + } + ", + ); + } + + #[test] + fn associated_struct_function() { + check_assist( + auto_import, + r" + mod test_mod { + pub struct TestStruct {} + impl TestStruct { + pub fn test_function() {} + } + } + + fn main() { + TestStruct::test_function<|> + } + ", + r" + use test_mod::TestStruct; + + mod test_mod { + pub struct TestStruct {} + impl TestStruct { + pub fn test_function() {} + } + } + + fn main() { + TestStruct::test_function + } + ", + ); + } + + #[test] + fn associated_struct_const() { + check_assist( + auto_import, + r" + mod test_mod { + pub struct TestStruct {} + impl TestStruct { + const TEST_CONST: u8 = 42; + } + } + + fn main() { + TestStruct::TEST_CONST<|> + } + ", + r" + use test_mod::TestStruct; + + mod test_mod { + pub struct TestStruct {} + impl TestStruct { + const TEST_CONST: u8 = 42; + } + } + + fn main() { + TestStruct::TEST_CONST + } + ", + ); + } + + #[test] + fn associated_trait_function() { + check_assist( + auto_import, + r" + mod test_mod { + pub trait TestTrait { + fn test_function(); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_function() {} + } + } + + fn main() { + test_mod::TestStruct::test_function<|> + } + ", + r" + use test_mod::TestTrait; + + mod test_mod { + pub trait TestTrait { + fn test_function(); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_function() {} + } + } + + fn main() { + test_mod::TestStruct::test_function + } + ", + ); + } + + #[test] + fn not_applicable_for_imported_trait_for_function() { + check_assist_not_applicable( + auto_import, + r" + mod test_mod { + pub trait TestTrait { + fn test_function(); + } + pub trait TestTrait2 { + fn test_function(); + } + pub enum TestEnum { + One, + Two, + } + impl TestTrait2 for TestEnum { + fn test_function() {} + } + impl TestTrait for TestEnum { + fn test_function() {} + } + } + + use test_mod::TestTrait2; + fn main() { + test_mod::TestEnum::test_function<|>; + } + ", + ) + } + + #[test] + fn associated_trait_const() { + check_assist( + auto_import, + r" + mod test_mod { + pub trait TestTrait { + const TEST_CONST: u8; + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + const TEST_CONST: u8 = 42; + } + } + + fn main() { + test_mod::TestStruct::TEST_CONST<|> + } + ", + r" + use test_mod::TestTrait; + + mod test_mod { + pub trait TestTrait { + const TEST_CONST: u8; + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + const TEST_CONST: u8 = 42; + } + } + + fn main() { + test_mod::TestStruct::TEST_CONST + } + ", + ); + } + + #[test] + fn not_applicable_for_imported_trait_for_const() { + check_assist_not_applicable( + auto_import, + r" + mod test_mod { + pub trait TestTrait { + const TEST_CONST: u8; + } + pub trait TestTrait2 { + const TEST_CONST: f64; + } + pub enum TestEnum { + One, + Two, + } + impl TestTrait2 for TestEnum { + const TEST_CONST: f64 = 42.0; + } + impl TestTrait for TestEnum { + const TEST_CONST: u8 = 42; + } + } + + use test_mod::TestTrait2; + fn main() { + test_mod::TestEnum::TEST_CONST<|>; + } + ", + ) + } + + #[test] + fn trait_method() { + check_assist( + auto_import, + r" + mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } + } + + fn main() { + let test_struct = test_mod::TestStruct {}; + test_struct.test_meth<|>od() + } + ", + r" + use test_mod::TestTrait; + + mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } + } + + fn main() { + let test_struct = test_mod::TestStruct {}; + test_struct.test_method() + } + ", + ); + } + + #[test] + fn trait_method_cross_crate() { + check_assist( + auto_import, + r" + //- /main.rs crate:main deps:dep + fn main() { + let test_struct = dep::test_mod::TestStruct {}; + test_struct.test_meth<|>od() + } + //- /dep.rs crate:dep + pub mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } + } + ", + r" + use dep::test_mod::TestTrait; + + fn main() { + let test_struct = dep::test_mod::TestStruct {}; + test_struct.test_method() + } + ", + ); + } + + #[test] + fn assoc_fn_cross_crate() { + check_assist( + auto_import, + r" + //- /main.rs crate:main deps:dep + fn main() { + dep::test_mod::TestStruct::test_func<|>tion + } + //- /dep.rs crate:dep + pub mod test_mod { + pub trait TestTrait { + fn test_function(); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_function() {} + } + } + ", + r" + use dep::test_mod::TestTrait; + + fn main() { + dep::test_mod::TestStruct::test_function + } + ", + ); + } + + #[test] + fn assoc_const_cross_crate() { + check_assist( + auto_import, + r" + //- /main.rs crate:main deps:dep + fn main() { + dep::test_mod::TestStruct::CONST<|> + } + //- /dep.rs crate:dep + pub mod test_mod { + pub trait TestTrait { + const CONST: bool; + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + const CONST: bool = true; + } + } + ", + r" + use dep::test_mod::TestTrait; + + fn main() { + dep::test_mod::TestStruct::CONST + } + ", + ); + } + + #[test] + fn assoc_fn_as_method_cross_crate() { + check_assist_not_applicable( + auto_import, + r" + //- /main.rs crate:main deps:dep + fn main() { + let test_struct = dep::test_mod::TestStruct {}; + test_struct.test_func<|>tion() + } + //- /dep.rs crate:dep + pub mod test_mod { + pub trait TestTrait { + fn test_function(); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_function() {} + } + } + ", + ); + } + + #[test] + fn private_trait_cross_crate() { + check_assist_not_applicable( + auto_import, + r" + //- /main.rs crate:main deps:dep + fn main() { + let test_struct = dep::test_mod::TestStruct {}; + test_struct.test_meth<|>od() + } + //- /dep.rs crate:dep + pub mod test_mod { + trait TestTrait { + fn test_method(&self); + } + pub struct TestStruct {} + impl TestTrait for TestStruct { + fn test_method(&self) {} + } + } + ", + ); + } + + #[test] + fn not_applicable_for_imported_trait_for_method() { + check_assist_not_applicable( + auto_import, + r" + mod test_mod { + pub trait TestTrait { + fn test_method(&self); + } + pub trait TestTrait2 { + fn test_method(&self); + } + pub enum TestEnum { + One, + Two, + } + impl TestTrait2 for TestEnum { + fn test_method(&self) {} + } + impl TestTrait for TestEnum { + fn test_method(&self) {} + } + } + + use test_mod::TestTrait2; + fn main() { + let one = test_mod::TestEnum::One; + one.test<|>_method(); + } + ", + ) + } + + #[test] + fn dep_import() { + check_assist( + auto_import, + r" +//- /lib.rs crate:dep +pub struct Struct; + +//- /main.rs crate:main deps:dep +fn main() { + Struct<|> +} +", + r"use dep::Struct; + +fn main() { + Struct +} +", + ); + } + + #[test] + fn whole_segment() { + // Tests that only imports whose last segment matches the identifier get suggested. + check_assist( + auto_import, + r" +//- /lib.rs crate:dep +pub mod fmt { + pub trait Display {} +} + +pub fn panic_fmt() {} + +//- /main.rs crate:main deps:dep +struct S; + +impl f<|>mt::Display for S {} +", + r"use dep::fmt; + +struct S; + +impl fmt::Display for S {} +", + ); + } + + #[test] + fn macro_generated() { + // Tests that macro-generated items are suggested from external crates. + check_assist( + auto_import, + r" +//- /lib.rs crate:dep +macro_rules! mac { + () => { + pub struct Cheese; + }; +} + +mac!(); + +//- /main.rs crate:main deps:dep +fn main() { + Cheese<|>; +} +", + r"use dep::Cheese; + +fn main() { + Cheese; +} +", + ); + } + + #[test] + fn casing() { + // Tests that differently cased names don't interfere and we only suggest the matching one. + check_assist( + auto_import, + r" +//- /lib.rs crate:dep +pub struct FMT; +pub struct fmt; + +//- /main.rs crate:main deps:dep +fn main() { + FMT<|>; +} +", + r"use dep::FMT; + +fn main() { + FMT; +} +", + ); + } +} diff --git a/crates/assists/src/handlers/change_return_type_to_result.rs b/crates/assists/src/handlers/change_return_type_to_result.rs new file mode 100644 index 000000000..be480943c --- /dev/null +++ b/crates/assists/src/handlers/change_return_type_to_result.rs @@ -0,0 +1,998 @@ +use std::iter; + +use syntax::{ + ast::{self, make, BlockExpr, Expr, LoopBodyOwner}, + AstNode, SyntaxNode, +}; +use test_utils::mark; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: change_return_type_to_result +// +// Change the function's return type to Result. +// +// ``` +// fn foo() -> i32<|> { 42i32 } +// ``` +// -> +// ``` +// fn foo() -> Result { Ok(42i32) } +// ``` +pub(crate) fn change_return_type_to_result(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let ret_type = ctx.find_node_at_offset::()?; + // FIXME: extend to lambdas as well + let fn_def = ret_type.syntax().parent().and_then(ast::Fn::cast)?; + + let type_ref = &ret_type.ty()?; + let ret_type_str = type_ref.syntax().text().to_string(); + let first_part_ret_type = ret_type_str.splitn(2, '<').next(); + if let Some(ret_type_first_part) = first_part_ret_type { + if ret_type_first_part.ends_with("Result") { + mark::hit!(change_return_type_to_result_simple_return_type_already_result); + return None; + } + } + + let block_expr = &fn_def.body()?; + + acc.add( + AssistId("change_return_type_to_result", AssistKind::RefactorRewrite), + "Wrap return type in Result", + type_ref.syntax().text_range(), + |builder| { + let mut tail_return_expr_collector = TailReturnCollector::new(); + tail_return_expr_collector.collect_jump_exprs(block_expr, false); + tail_return_expr_collector.collect_tail_exprs(block_expr); + + for ret_expr_arg in tail_return_expr_collector.exprs_to_wrap { + let ok_wrapped = make::expr_call( + make::expr_path(make::path_unqualified(make::path_segment(make::name_ref( + "Ok", + )))), + make::arg_list(iter::once(ret_expr_arg.clone())), + ); + builder.replace_ast(ret_expr_arg, ok_wrapped); + } + + match ctx.config.snippet_cap { + Some(cap) => { + let snippet = format!("Result<{}, ${{0:_}}>", type_ref); + builder.replace_snippet(cap, type_ref.syntax().text_range(), snippet) + } + None => builder + .replace(type_ref.syntax().text_range(), format!("Result<{}, _>", type_ref)), + } + }, + ) +} + +struct TailReturnCollector { + exprs_to_wrap: Vec, +} + +impl TailReturnCollector { + fn new() -> Self { + Self { exprs_to_wrap: vec![] } + } + /// Collect all`return` expression + fn collect_jump_exprs(&mut self, block_expr: &BlockExpr, collect_break: bool) { + let statements = block_expr.statements(); + for stmt in statements { + let expr = match &stmt { + ast::Stmt::ExprStmt(stmt) => stmt.expr(), + ast::Stmt::LetStmt(stmt) => stmt.initializer(), + ast::Stmt::Item(_) => continue, + }; + if let Some(expr) = &expr { + self.handle_exprs(expr, collect_break); + } + } + + // Browse tail expressions for each block + if let Some(expr) = block_expr.expr() { + if let Some(last_exprs) = get_tail_expr_from_block(&expr) { + for last_expr in last_exprs { + let last_expr = match last_expr { + NodeType::Node(expr) => expr, + NodeType::Leaf(expr) => expr.syntax().clone(), + }; + + if let Some(last_expr) = Expr::cast(last_expr.clone()) { + self.handle_exprs(&last_expr, collect_break); + } else if let Some(expr_stmt) = ast::Stmt::cast(last_expr) { + let expr_stmt = match &expr_stmt { + ast::Stmt::ExprStmt(stmt) => stmt.expr(), + ast::Stmt::LetStmt(stmt) => stmt.initializer(), + ast::Stmt::Item(_) => None, + }; + if let Some(expr) = &expr_stmt { + self.handle_exprs(expr, collect_break); + } + } + } + } + } + } + + fn handle_exprs(&mut self, expr: &Expr, collect_break: bool) { + match expr { + Expr::BlockExpr(block_expr) => { + self.collect_jump_exprs(&block_expr, collect_break); + } + Expr::ReturnExpr(ret_expr) => { + if let Some(ret_expr_arg) = &ret_expr.expr() { + self.exprs_to_wrap.push(ret_expr_arg.clone()); + } + } + Expr::BreakExpr(break_expr) if collect_break => { + if let Some(break_expr_arg) = &break_expr.expr() { + self.exprs_to_wrap.push(break_expr_arg.clone()); + } + } + Expr::IfExpr(if_expr) => { + for block in if_expr.blocks() { + self.collect_jump_exprs(&block, collect_break); + } + } + Expr::LoopExpr(loop_expr) => { + if let Some(block_expr) = loop_expr.loop_body() { + self.collect_jump_exprs(&block_expr, collect_break); + } + } + Expr::ForExpr(for_expr) => { + if let Some(block_expr) = for_expr.loop_body() { + self.collect_jump_exprs(&block_expr, collect_break); + } + } + Expr::WhileExpr(while_expr) => { + if let Some(block_expr) = while_expr.loop_body() { + self.collect_jump_exprs(&block_expr, collect_break); + } + } + Expr::MatchExpr(match_expr) => { + if let Some(arm_list) = match_expr.match_arm_list() { + arm_list.arms().filter_map(|match_arm| match_arm.expr()).for_each(|expr| { + self.handle_exprs(&expr, collect_break); + }); + } + } + _ => {} + } + } + + fn collect_tail_exprs(&mut self, block: &BlockExpr) { + if let Some(expr) = block.expr() { + self.handle_exprs(&expr, true); + self.fetch_tail_exprs(&expr); + } + } + + fn fetch_tail_exprs(&mut self, expr: &Expr) { + if let Some(exprs) = get_tail_expr_from_block(expr) { + for node_type in &exprs { + match node_type { + NodeType::Leaf(expr) => { + self.exprs_to_wrap.push(expr.clone()); + } + NodeType::Node(expr) => { + if let Some(last_expr) = Expr::cast(expr.clone()) { + self.fetch_tail_exprs(&last_expr); + } + } + } + } + } + } +} + +#[derive(Debug)] +enum NodeType { + Leaf(ast::Expr), + Node(SyntaxNode), +} + +/// Get a tail expression inside a block +fn get_tail_expr_from_block(expr: &Expr) -> Option> { + match expr { + Expr::IfExpr(if_expr) => { + let mut nodes = vec![]; + for block in if_expr.blocks() { + if let Some(block_expr) = block.expr() { + if let Some(tail_exprs) = get_tail_expr_from_block(&block_expr) { + nodes.extend(tail_exprs); + } + } else if let Some(last_expr) = block.syntax().last_child() { + nodes.push(NodeType::Node(last_expr)); + } else { + nodes.push(NodeType::Node(block.syntax().clone())); + } + } + Some(nodes) + } + Expr::LoopExpr(loop_expr) => { + loop_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)]) + } + Expr::ForExpr(for_expr) => { + for_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)]) + } + Expr::WhileExpr(while_expr) => { + while_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)]) + } + Expr::BlockExpr(block_expr) => { + block_expr.expr().map(|lc| vec![NodeType::Node(lc.syntax().clone())]) + } + Expr::MatchExpr(match_expr) => { + let arm_list = match_expr.match_arm_list()?; + let arms: Vec = arm_list + .arms() + .filter_map(|match_arm| match_arm.expr()) + .map(|expr| match expr { + Expr::ReturnExpr(ret_expr) => NodeType::Node(ret_expr.syntax().clone()), + Expr::BreakExpr(break_expr) => NodeType::Node(break_expr.syntax().clone()), + _ => match expr.syntax().last_child() { + Some(last_expr) => NodeType::Node(last_expr), + None => NodeType::Node(expr.syntax().clone()), + }, + }) + .collect(); + + Some(arms) + } + Expr::BreakExpr(expr) => expr.expr().map(|e| vec![NodeType::Leaf(e)]), + Expr::ReturnExpr(ret_expr) => Some(vec![NodeType::Node(ret_expr.syntax().clone())]), + + Expr::CallExpr(_) + | Expr::Literal(_) + | Expr::TupleExpr(_) + | Expr::ArrayExpr(_) + | Expr::ParenExpr(_) + | Expr::PathExpr(_) + | Expr::RecordExpr(_) + | Expr::IndexExpr(_) + | Expr::MethodCallExpr(_) + | Expr::AwaitExpr(_) + | Expr::CastExpr(_) + | Expr::RefExpr(_) + | Expr::PrefixExpr(_) + | Expr::RangeExpr(_) + | Expr::BinExpr(_) + | Expr::MacroCall(_) + | Expr::BoxExpr(_) => Some(vec![NodeType::Leaf(expr.clone())]), + _ => None, + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn change_return_type_to_result_simple() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i3<|>2 { + let test = "test"; + return 42i32; + }"#, + r#"fn foo() -> Result { + let test = "test"; + return Ok(42i32); + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_return_type() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let test = "test"; + return 42i32; + }"#, + r#"fn foo() -> Result { + let test = "test"; + return Ok(42i32); + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_return_type_bad_cursor() { + check_assist_not_applicable( + change_return_type_to_result, + r#"fn foo() -> i32 { + let test = "test";<|> + return 42i32; + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_return_type_already_result_std() { + check_assist_not_applicable( + change_return_type_to_result, + r#"fn foo() -> std::result::Result, String> { + let test = "test"; + return 42i32; + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_return_type_already_result() { + mark::check!(change_return_type_to_result_simple_return_type_already_result); + check_assist_not_applicable( + change_return_type_to_result, + r#"fn foo() -> Result, String> { + let test = "test"; + return 42i32; + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_cursor() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> <|>i32 { + let test = "test"; + return 42i32; + }"#, + r#"fn foo() -> Result { + let test = "test"; + return Ok(42i32); + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_tail() { + check_assist( + change_return_type_to_result, + r#"fn foo() -><|> i32 { + let test = "test"; + 42i32 + }"#, + r#"fn foo() -> Result { + let test = "test"; + Ok(42i32) + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_tail_only() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + 42i32 + }"#, + r#"fn foo() -> Result { + Ok(42i32) + }"#, + ); + } + #[test] + fn change_return_type_to_result_simple_with_tail_block_like() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + if true { + 42i32 + } else { + 24i32 + } + }"#, + r#"fn foo() -> Result { + if true { + Ok(42i32) + } else { + Ok(24i32) + } + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_nested_if() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + if true { + if false { + 1 + } else { + 2 + } + } else { + 24i32 + } + }"#, + r#"fn foo() -> Result { + if true { + if false { + Ok(1) + } else { + Ok(2) + } + } else { + Ok(24i32) + } + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_await() { + check_assist( + change_return_type_to_result, + r#"async fn foo() -> i<|>32 { + if true { + if false { + 1.await + } else { + 2.await + } + } else { + 24i32.await + } + }"#, + r#"async fn foo() -> Result { + if true { + if false { + Ok(1.await) + } else { + Ok(2.await) + } + } else { + Ok(24i32.await) + } + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_array() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> [i32;<|> 3] { + [1, 2, 3] + }"#, + r#"fn foo() -> Result<[i32; 3], ${0:_}> { + Ok([1, 2, 3]) + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_cast() { + check_assist( + change_return_type_to_result, + r#"fn foo() -<|>> i32 { + if true { + if false { + 1 as i32 + } else { + 2 as i32 + } + } else { + 24 as i32 + } + }"#, + r#"fn foo() -> Result { + if true { + if false { + Ok(1 as i32) + } else { + Ok(2 as i32) + } + } else { + Ok(24 as i32) + } + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_tail_block_like_match() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let my_var = 5; + match my_var { + 5 => 42i32, + _ => 24i32, + } + }"#, + r#"fn foo() -> Result { + let my_var = 5; + match my_var { + 5 => Ok(42i32), + _ => Ok(24i32), + } + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_loop_with_tail() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let my_var = 5; + loop { + println!("test"); + 5 + } + + my_var + }"#, + r#"fn foo() -> Result { + let my_var = 5; + loop { + println!("test"); + 5 + } + + Ok(my_var) + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_loop_in_let_stmt() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let my_var = let x = loop { + break 1; + }; + + my_var + }"#, + r#"fn foo() -> Result { + let my_var = let x = loop { + break 1; + }; + + Ok(my_var) + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_tail_block_like_match_return_expr() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let my_var = 5; + let res = match my_var { + 5 => 42i32, + _ => return 24i32, + }; + + res + }"#, + r#"fn foo() -> Result { + let my_var = 5; + let res = match my_var { + 5 => 42i32, + _ => return Ok(24i32), + }; + + Ok(res) + }"#, + ); + + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let my_var = 5; + let res = if my_var == 5 { + 42i32 + } else { + return 24i32; + }; + + res + }"#, + r#"fn foo() -> Result { + let my_var = 5; + let res = if my_var == 5 { + 42i32 + } else { + return Ok(24i32); + }; + + Ok(res) + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_tail_block_like_match_deeper() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let my_var = 5; + match my_var { + 5 => { + if true { + 42i32 + } else { + 25i32 + } + }, + _ => { + let test = "test"; + if test == "test" { + return bar(); + } + 53i32 + }, + } + }"#, + r#"fn foo() -> Result { + let my_var = 5; + match my_var { + 5 => { + if true { + Ok(42i32) + } else { + Ok(25i32) + } + }, + _ => { + let test = "test"; + if test == "test" { + return Ok(bar()); + } + Ok(53i32) + }, + } + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_tail_block_like_early_return() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i<|>32 { + let test = "test"; + if test == "test" { + return 24i32; + } + 53i32 + }"#, + r#"fn foo() -> Result { + let test = "test"; + if test == "test" { + return Ok(24i32); + } + Ok(53i32) + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_closure() { + check_assist( + change_return_type_to_result, + r#"fn foo(the_field: u32) -><|> u32 { + let true_closure = || { + return true; + }; + if the_field < 5 { + let mut i = 0; + + + if true_closure() { + return 99; + } else { + return 0; + } + } + + the_field + }"#, + r#"fn foo(the_field: u32) -> Result { + let true_closure = || { + return true; + }; + if the_field < 5 { + let mut i = 0; + + + if true_closure() { + return Ok(99); + } else { + return Ok(0); + } + } + + Ok(the_field) + }"#, + ); + + check_assist( + change_return_type_to_result, + r#"fn foo(the_field: u32) -> u32<|> { + let true_closure = || { + return true; + }; + if the_field < 5 { + let mut i = 0; + + + if true_closure() { + return 99; + } else { + return 0; + } + } + let t = None; + + t.unwrap_or_else(|| the_field) + }"#, + r#"fn foo(the_field: u32) -> Result { + let true_closure = || { + return true; + }; + if the_field < 5 { + let mut i = 0; + + + if true_closure() { + return Ok(99); + } else { + return Ok(0); + } + } + let t = None; + + Ok(t.unwrap_or_else(|| the_field)) + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_weird_forms() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let test = "test"; + if test == "test" { + return 24i32; + } + let mut i = 0; + loop { + if i == 1 { + break 55; + } + i += 1; + } + }"#, + r#"fn foo() -> Result { + let test = "test"; + if test == "test" { + return Ok(24i32); + } + let mut i = 0; + loop { + if i == 1 { + break Ok(55); + } + i += 1; + } + }"#, + ); + + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let test = "test"; + if test == "test" { + return 24i32; + } + let mut i = 0; + loop { + loop { + if i == 1 { + break 55; + } + i += 1; + } + } + }"#, + r#"fn foo() -> Result { + let test = "test"; + if test == "test" { + return Ok(24i32); + } + let mut i = 0; + loop { + loop { + if i == 1 { + break Ok(55); + } + i += 1; + } + } + }"#, + ); + + check_assist( + change_return_type_to_result, + r#"fn foo() -> i3<|>2 { + let test = "test"; + let other = 5; + if test == "test" { + let res = match other { + 5 => 43, + _ => return 56, + }; + } + let mut i = 0; + loop { + loop { + if i == 1 { + break 55; + } + i += 1; + } + } + }"#, + r#"fn foo() -> Result { + let test = "test"; + let other = 5; + if test == "test" { + let res = match other { + 5 => 43, + _ => return Ok(56), + }; + } + let mut i = 0; + loop { + loop { + if i == 1 { + break Ok(55); + } + i += 1; + } + } + }"#, + ); + + check_assist( + change_return_type_to_result, + r#"fn foo(the_field: u32) -> u32<|> { + if the_field < 5 { + let mut i = 0; + loop { + if i > 5 { + return 55u32; + } + i += 3; + } + + match i { + 5 => return 99, + _ => return 0, + }; + } + + the_field + }"#, + r#"fn foo(the_field: u32) -> Result { + if the_field < 5 { + let mut i = 0; + loop { + if i > 5 { + return Ok(55u32); + } + i += 3; + } + + match i { + 5 => return Ok(99), + _ => return Ok(0), + }; + } + + Ok(the_field) + }"#, + ); + + check_assist( + change_return_type_to_result, + r#"fn foo(the_field: u32) -> u3<|>2 { + if the_field < 5 { + let mut i = 0; + + match i { + 5 => return 99, + _ => return 0, + } + } + + the_field + }"#, + r#"fn foo(the_field: u32) -> Result { + if the_field < 5 { + let mut i = 0; + + match i { + 5 => return Ok(99), + _ => return Ok(0), + } + } + + Ok(the_field) + }"#, + ); + + check_assist( + change_return_type_to_result, + r#"fn foo(the_field: u32) -> u32<|> { + if the_field < 5 { + let mut i = 0; + + if i == 5 { + return 99 + } else { + return 0 + } + } + + the_field + }"#, + r#"fn foo(the_field: u32) -> Result { + if the_field < 5 { + let mut i = 0; + + if i == 5 { + return Ok(99) + } else { + return Ok(0) + } + } + + Ok(the_field) + }"#, + ); + + check_assist( + change_return_type_to_result, + r#"fn foo(the_field: u32) -> <|>u32 { + if the_field < 5 { + let mut i = 0; + + if i == 5 { + return 99; + } else { + return 0; + } + } + + the_field + }"#, + r#"fn foo(the_field: u32) -> Result { + if the_field < 5 { + let mut i = 0; + + if i == 5 { + return Ok(99); + } else { + return Ok(0); + } + } + + Ok(the_field) + }"#, + ); + } +} diff --git a/crates/assists/src/handlers/change_visibility.rs b/crates/assists/src/handlers/change_visibility.rs new file mode 100644 index 000000000..32dc05378 --- /dev/null +++ b/crates/assists/src/handlers/change_visibility.rs @@ -0,0 +1,200 @@ +use syntax::{ + ast::{self, NameOwner, VisibilityOwner}, + AstNode, + SyntaxKind::{CONST, ENUM, FN, MODULE, STATIC, STRUCT, TRAIT, VISIBILITY}, + T, +}; +use test_utils::mark; + +use crate::{utils::vis_offset, AssistContext, AssistId, AssistKind, Assists}; + +// Assist: change_visibility +// +// Adds or changes existing visibility specifier. +// +// ``` +// <|>fn frobnicate() {} +// ``` +// -> +// ``` +// pub(crate) fn frobnicate() {} +// ``` +pub(crate) fn change_visibility(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + if let Some(vis) = ctx.find_node_at_offset::() { + return change_vis(acc, vis); + } + add_vis(acc, ctx) +} + +fn add_vis(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let item_keyword = ctx.token_at_offset().find(|leaf| { + matches!( + leaf.kind(), + T![const] | T![static] | T![fn] | T![mod] | T![struct] | T![enum] | T![trait] + ) + }); + + let (offset, target) = if let Some(keyword) = item_keyword { + let parent = keyword.parent(); + let def_kws = vec![CONST, STATIC, FN, MODULE, STRUCT, ENUM, TRAIT]; + // Parent is not a definition, can't add visibility + if !def_kws.iter().any(|&def_kw| def_kw == parent.kind()) { + return None; + } + // Already have visibility, do nothing + if parent.children().any(|child| child.kind() == VISIBILITY) { + return None; + } + (vis_offset(&parent), keyword.text_range()) + } else if let Some(field_name) = ctx.find_node_at_offset::() { + let field = field_name.syntax().ancestors().find_map(ast::RecordField::cast)?; + if field.name()? != field_name { + mark::hit!(change_visibility_field_false_positive); + return None; + } + if field.visibility().is_some() { + return None; + } + (vis_offset(field.syntax()), field_name.syntax().text_range()) + } else if let Some(field) = ctx.find_node_at_offset::() { + if field.visibility().is_some() { + return None; + } + (vis_offset(field.syntax()), field.syntax().text_range()) + } else { + return None; + }; + + acc.add( + AssistId("change_visibility", AssistKind::RefactorRewrite), + "Change visibility to pub(crate)", + target, + |edit| { + edit.insert(offset, "pub(crate) "); + }, + ) +} + +fn change_vis(acc: &mut Assists, vis: ast::Visibility) -> Option<()> { + if vis.syntax().text() == "pub" { + let target = vis.syntax().text_range(); + return acc.add( + AssistId("change_visibility", AssistKind::RefactorRewrite), + "Change Visibility to pub(crate)", + target, + |edit| { + edit.replace(vis.syntax().text_range(), "pub(crate)"); + }, + ); + } + if vis.syntax().text() == "pub(crate)" { + let target = vis.syntax().text_range(); + return acc.add( + AssistId("change_visibility", AssistKind::RefactorRewrite), + "Change visibility to pub", + target, + |edit| { + edit.replace(vis.syntax().text_range(), "pub"); + }, + ); + } + None +} + +#[cfg(test)] +mod tests { + use test_utils::mark; + + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + use super::*; + + #[test] + fn change_visibility_adds_pub_crate_to_items() { + check_assist(change_visibility, "<|>fn foo() {}", "pub(crate) fn foo() {}"); + check_assist(change_visibility, "f<|>n foo() {}", "pub(crate) fn foo() {}"); + check_assist(change_visibility, "<|>struct Foo {}", "pub(crate) struct Foo {}"); + check_assist(change_visibility, "<|>mod foo {}", "pub(crate) mod foo {}"); + check_assist(change_visibility, "<|>trait Foo {}", "pub(crate) trait Foo {}"); + check_assist(change_visibility, "m<|>od {}", "pub(crate) mod {}"); + check_assist(change_visibility, "unsafe f<|>n foo() {}", "pub(crate) unsafe fn foo() {}"); + } + + #[test] + fn change_visibility_works_with_struct_fields() { + check_assist( + change_visibility, + r"struct S { <|>field: u32 }", + r"struct S { pub(crate) field: u32 }", + ); + check_assist(change_visibility, r"struct S ( <|>u32 )", r"struct S ( pub(crate) u32 )"); + } + + #[test] + fn change_visibility_field_false_positive() { + mark::check!(change_visibility_field_false_positive); + check_assist_not_applicable( + change_visibility, + r"struct S { field: [(); { let <|>x = ();}] }", + ) + } + + #[test] + fn change_visibility_pub_to_pub_crate() { + check_assist(change_visibility, "<|>pub fn foo() {}", "pub(crate) fn foo() {}") + } + + #[test] + fn change_visibility_pub_crate_to_pub() { + check_assist(change_visibility, "<|>pub(crate) fn foo() {}", "pub fn foo() {}") + } + + #[test] + fn change_visibility_const() { + check_assist(change_visibility, "<|>const FOO = 3u8;", "pub(crate) const FOO = 3u8;"); + } + + #[test] + fn change_visibility_static() { + check_assist(change_visibility, "<|>static FOO = 3u8;", "pub(crate) static FOO = 3u8;"); + } + + #[test] + fn change_visibility_handles_comment_attrs() { + check_assist( + change_visibility, + r" + /// docs + + // comments + + #[derive(Debug)] + <|>struct Foo; + ", + r" + /// docs + + // comments + + #[derive(Debug)] + pub(crate) struct Foo; + ", + ) + } + + #[test] + fn not_applicable_for_enum_variants() { + check_assist_not_applicable( + change_visibility, + r"mod foo { pub enum Foo {Foo1} } + fn main() { foo::Foo::Foo1<|> } ", + ); + } + + #[test] + fn change_visibility_target() { + check_assist_target(change_visibility, "<|>fn foo() {}", "fn"); + check_assist_target(change_visibility, "pub(crate)<|> fn foo() {}", "pub(crate)"); + check_assist_target(change_visibility, "struct S { <|>field: u32 }", "field"); + } +} diff --git a/crates/assists/src/handlers/early_return.rs b/crates/assists/src/handlers/early_return.rs new file mode 100644 index 000000000..7fd78e9d4 --- /dev/null +++ b/crates/assists/src/handlers/early_return.rs @@ -0,0 +1,515 @@ +use std::{iter::once, ops::RangeInclusive}; + +use syntax::{ + algo::replace_children, + ast::{ + self, + edit::{AstNodeEdit, IndentLevel}, + make, + }, + AstNode, + SyntaxKind::{FN, LOOP_EXPR, L_CURLY, R_CURLY, WHILE_EXPR, WHITESPACE}, + SyntaxNode, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + utils::invert_boolean_expression, + AssistId, AssistKind, +}; + +// Assist: convert_to_guarded_return +// +// Replace a large conditional with a guarded return. +// +// ``` +// fn main() { +// <|>if cond { +// foo(); +// bar(); +// } +// } +// ``` +// -> +// ``` +// fn main() { +// if !cond { +// return; +// } +// foo(); +// bar(); +// } +// ``` +pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let if_expr: ast::IfExpr = ctx.find_node_at_offset()?; + if if_expr.else_branch().is_some() { + return None; + } + + let cond = if_expr.condition()?; + + // Check if there is an IfLet that we can handle. + let if_let_pat = match cond.pat() { + None => None, // No IfLet, supported. + Some(ast::Pat::TupleStructPat(pat)) if pat.fields().count() == 1 => { + let path = pat.path()?; + match path.qualifier() { + None => { + let bound_ident = pat.fields().next().unwrap(); + Some((path, bound_ident)) + } + Some(_) => return None, + } + } + Some(_) => return None, // Unsupported IfLet. + }; + + let cond_expr = cond.expr()?; + let then_block = if_expr.then_branch()?; + + let parent_block = if_expr.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?; + + if parent_block.expr()? != if_expr.clone().into() { + return None; + } + + // check for early return and continue + let first_in_then_block = then_block.syntax().first_child()?; + if ast::ReturnExpr::can_cast(first_in_then_block.kind()) + || ast::ContinueExpr::can_cast(first_in_then_block.kind()) + || first_in_then_block + .children() + .any(|x| ast::ReturnExpr::can_cast(x.kind()) || ast::ContinueExpr::can_cast(x.kind())) + { + return None; + } + + let parent_container = parent_block.syntax().parent()?; + + let early_expression: ast::Expr = match parent_container.kind() { + WHILE_EXPR | LOOP_EXPR => make::expr_continue(), + FN => make::expr_return(), + _ => return None, + }; + + if then_block.syntax().first_child_or_token().map(|t| t.kind() == L_CURLY).is_none() { + return None; + } + + then_block.syntax().last_child_or_token().filter(|t| t.kind() == R_CURLY)?; + + let target = if_expr.syntax().text_range(); + acc.add( + AssistId("convert_to_guarded_return", AssistKind::RefactorRewrite), + "Convert to guarded return", + target, + |edit| { + let if_indent_level = IndentLevel::from_node(&if_expr.syntax()); + let new_block = match if_let_pat { + None => { + // If. + let new_expr = { + let then_branch = + make::block_expr(once(make::expr_stmt(early_expression).into()), None); + let cond = invert_boolean_expression(cond_expr); + make::expr_if(make::condition(cond, None), then_branch) + .indent(if_indent_level) + }; + replace(new_expr.syntax(), &then_block, &parent_block, &if_expr) + } + Some((path, bound_ident)) => { + // If-let. + let match_expr = { + let happy_arm = { + let pat = make::tuple_struct_pat( + path, + once(make::ident_pat(make::name("it")).into()), + ); + let expr = { + let name_ref = make::name_ref("it"); + let segment = make::path_segment(name_ref); + let path = make::path_unqualified(segment); + make::expr_path(path) + }; + make::match_arm(once(pat.into()), expr) + }; + + let sad_arm = make::match_arm( + // FIXME: would be cool to use `None` or `Err(_)` if appropriate + once(make::wildcard_pat().into()), + early_expression, + ); + + make::expr_match(cond_expr, make::match_arm_list(vec![happy_arm, sad_arm])) + }; + + let let_stmt = make::let_stmt( + make::ident_pat(make::name(&bound_ident.syntax().to_string())).into(), + Some(match_expr), + ); + let let_stmt = let_stmt.indent(if_indent_level); + replace(let_stmt.syntax(), &then_block, &parent_block, &if_expr) + } + }; + edit.replace_ast(parent_block, ast::BlockExpr::cast(new_block).unwrap()); + + fn replace( + new_expr: &SyntaxNode, + then_block: &ast::BlockExpr, + parent_block: &ast::BlockExpr, + if_expr: &ast::IfExpr, + ) -> SyntaxNode { + let then_block_items = then_block.dedent(IndentLevel(1)); + let end_of_then = then_block_items.syntax().last_child_or_token().unwrap(); + let end_of_then = + if end_of_then.prev_sibling_or_token().map(|n| n.kind()) == Some(WHITESPACE) { + end_of_then.prev_sibling_or_token().unwrap() + } else { + end_of_then + }; + let mut then_statements = new_expr.children_with_tokens().chain( + then_block_items + .syntax() + .children_with_tokens() + .skip(1) + .take_while(|i| *i != end_of_then), + ); + replace_children( + &parent_block.syntax(), + RangeInclusive::new( + if_expr.clone().syntax().clone().into(), + if_expr.syntax().clone().into(), + ), + &mut then_statements, + ) + } + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn convert_inside_fn() { + check_assist( + convert_to_guarded_return, + r#" + fn main() { + bar(); + if<|> true { + foo(); + + //comment + bar(); + } + } + "#, + r#" + fn main() { + bar(); + if !true { + return; + } + foo(); + + //comment + bar(); + } + "#, + ); + } + + #[test] + fn convert_let_inside_fn() { + check_assist( + convert_to_guarded_return, + r#" + fn main(n: Option) { + bar(); + if<|> let Some(n) = n { + foo(n); + + //comment + bar(); + } + } + "#, + r#" + fn main(n: Option) { + bar(); + let n = match n { + Some(it) => it, + _ => return, + }; + foo(n); + + //comment + bar(); + } + "#, + ); + } + + #[test] + fn convert_if_let_result() { + check_assist( + convert_to_guarded_return, + r#" + fn main() { + if<|> let Ok(x) = Err(92) { + foo(x); + } + } + "#, + r#" + fn main() { + let x = match Err(92) { + Ok(it) => it, + _ => return, + }; + foo(x); + } + "#, + ); + } + + #[test] + fn convert_let_ok_inside_fn() { + check_assist( + convert_to_guarded_return, + r#" + fn main(n: Option) { + bar(); + if<|> let Ok(n) = n { + foo(n); + + //comment + bar(); + } + } + "#, + r#" + fn main(n: Option) { + bar(); + let n = match n { + Ok(it) => it, + _ => return, + }; + foo(n); + + //comment + bar(); + } + "#, + ); + } + + #[test] + fn convert_inside_while() { + check_assist( + convert_to_guarded_return, + r#" + fn main() { + while true { + if<|> true { + foo(); + bar(); + } + } + } + "#, + r#" + fn main() { + while true { + if !true { + continue; + } + foo(); + bar(); + } + } + "#, + ); + } + + #[test] + fn convert_let_inside_while() { + check_assist( + convert_to_guarded_return, + r#" + fn main() { + while true { + if<|> let Some(n) = n { + foo(n); + bar(); + } + } + } + "#, + r#" + fn main() { + while true { + let n = match n { + Some(it) => it, + _ => continue, + }; + foo(n); + bar(); + } + } + "#, + ); + } + + #[test] + fn convert_inside_loop() { + check_assist( + convert_to_guarded_return, + r#" + fn main() { + loop { + if<|> true { + foo(); + bar(); + } + } + } + "#, + r#" + fn main() { + loop { + if !true { + continue; + } + foo(); + bar(); + } + } + "#, + ); + } + + #[test] + fn convert_let_inside_loop() { + check_assist( + convert_to_guarded_return, + r#" + fn main() { + loop { + if<|> let Some(n) = n { + foo(n); + bar(); + } + } + } + "#, + r#" + fn main() { + loop { + let n = match n { + Some(it) => it, + _ => continue, + }; + foo(n); + bar(); + } + } + "#, + ); + } + + #[test] + fn ignore_already_converted_if() { + check_assist_not_applicable( + convert_to_guarded_return, + r#" + fn main() { + if<|> true { + return; + } + } + "#, + ); + } + + #[test] + fn ignore_already_converted_loop() { + check_assist_not_applicable( + convert_to_guarded_return, + r#" + fn main() { + loop { + if<|> true { + continue; + } + } + } + "#, + ); + } + + #[test] + fn ignore_return() { + check_assist_not_applicable( + convert_to_guarded_return, + r#" + fn main() { + if<|> true { + return + } + } + "#, + ); + } + + #[test] + fn ignore_else_branch() { + check_assist_not_applicable( + convert_to_guarded_return, + r#" + fn main() { + if<|> true { + foo(); + } else { + bar() + } + } + "#, + ); + } + + #[test] + fn ignore_statements_aftert_if() { + check_assist_not_applicable( + convert_to_guarded_return, + r#" + fn main() { + if<|> true { + foo(); + } + bar(); + } + "#, + ); + } + + #[test] + fn ignore_statements_inside_if() { + check_assist_not_applicable( + convert_to_guarded_return, + r#" + fn main() { + if false { + if<|> true { + foo(); + } + } + } + "#, + ); + } +} diff --git a/crates/assists/src/handlers/expand_glob_import.rs b/crates/assists/src/handlers/expand_glob_import.rs new file mode 100644 index 000000000..f690ec343 --- /dev/null +++ b/crates/assists/src/handlers/expand_glob_import.rs @@ -0,0 +1,391 @@ +use hir::{AssocItem, MacroDef, ModuleDef, Name, PathResolution, ScopeDef, SemanticsScope}; +use ide_db::{ + defs::{classify_name_ref, Definition, NameRefClass}, + RootDatabase, +}; +use syntax::{algo, ast, match_ast, AstNode, SyntaxNode, SyntaxToken, T}; + +use crate::{ + assist_context::{AssistBuilder, AssistContext, Assists}, + AssistId, AssistKind, +}; + +use either::Either; + +// Assist: expand_glob_import +// +// Expands glob imports. +// +// ``` +// mod foo { +// pub struct Bar; +// pub struct Baz; +// } +// +// use foo::*<|>; +// +// fn qux(bar: Bar, baz: Baz) {} +// ``` +// -> +// ``` +// mod foo { +// pub struct Bar; +// pub struct Baz; +// } +// +// use foo::{Baz, Bar}; +// +// fn qux(bar: Bar, baz: Baz) {} +// ``` +pub(crate) fn expand_glob_import(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let star = ctx.find_token_at_offset(T![*])?; + let mod_path = find_mod_path(&star)?; + + let source_file = ctx.source_file(); + let scope = ctx.sema.scope_at_offset(source_file.syntax(), ctx.offset()); + + let defs_in_mod = find_defs_in_mod(ctx, scope, &mod_path)?; + let name_refs_in_source_file = + source_file.syntax().descendants().filter_map(ast::NameRef::cast).collect(); + let used_names = find_used_names(ctx, defs_in_mod, name_refs_in_source_file); + + let parent = star.parent().parent()?; + acc.add( + AssistId("expand_glob_import", AssistKind::RefactorRewrite), + "Expand glob import", + parent.text_range(), + |builder| { + replace_ast(builder, &parent, mod_path, used_names); + }, + ) +} + +fn find_mod_path(star: &SyntaxToken) -> Option { + star.ancestors().find_map(|n| ast::UseTree::cast(n).and_then(|u| u.path())) +} + +#[derive(PartialEq)] +enum Def { + ModuleDef(ModuleDef), + MacroDef(MacroDef), +} + +impl Def { + fn name(&self, db: &RootDatabase) -> Option { + match self { + Def::ModuleDef(def) => def.name(db), + Def::MacroDef(def) => def.name(db), + } + } +} + +fn find_defs_in_mod( + ctx: &AssistContext, + from: SemanticsScope<'_>, + path: &ast::Path, +) -> Option> { + let hir_path = ctx.sema.lower_path(&path)?; + let module = if let Some(PathResolution::Def(ModuleDef::Module(module))) = + from.resolve_hir_path_qualifier(&hir_path) + { + module + } else { + return None; + }; + + let module_scope = module.scope(ctx.db(), from.module()); + + let mut defs = vec![]; + for (_, def) in module_scope { + match def { + ScopeDef::ModuleDef(def) => defs.push(Def::ModuleDef(def)), + ScopeDef::MacroDef(def) => defs.push(Def::MacroDef(def)), + _ => continue, + } + } + + Some(defs) +} + +fn find_used_names( + ctx: &AssistContext, + defs_in_mod: Vec, + name_refs_in_source_file: Vec, +) -> Vec { + let defs_in_source_file = name_refs_in_source_file + .iter() + .filter_map(|r| classify_name_ref(&ctx.sema, r)) + .filter_map(|rc| match rc { + NameRefClass::Definition(Definition::ModuleDef(def)) => Some(Def::ModuleDef(def)), + NameRefClass::Definition(Definition::Macro(def)) => Some(Def::MacroDef(def)), + _ => None, + }) + .collect::>(); + + defs_in_mod + .iter() + .filter(|def| { + if let Def::ModuleDef(ModuleDef::Trait(tr)) = def { + for item in tr.items(ctx.db()) { + if let AssocItem::Function(f) = item { + if defs_in_source_file.contains(&Def::ModuleDef(ModuleDef::Function(f))) { + return true; + } + } + } + } + + defs_in_source_file.contains(def) + }) + .filter_map(|d| d.name(ctx.db())) + .collect() +} + +fn replace_ast( + builder: &mut AssistBuilder, + node: &SyntaxNode, + path: ast::Path, + used_names: Vec, +) { + let replacement: Either = match used_names.as_slice() { + [name] => Either::Left(ast::make::use_tree( + ast::make::path_from_text(&format!("{}::{}", path, name)), + None, + None, + false, + )), + names => Either::Right(ast::make::use_tree_list(names.iter().map(|n| { + ast::make::use_tree(ast::make::path_from_text(&n.to_string()), None, None, false) + }))), + }; + + let mut replace_node = |replacement: Either| { + algo::diff(node, &replacement.either(|u| u.syntax().clone(), |ut| ut.syntax().clone())) + .into_text_edit(builder.text_edit_builder()); + }; + + match_ast! { + match node { + ast::UseTree(use_tree) => { + replace_node(replacement); + }, + ast::UseTreeList(use_tree_list) => { + replace_node(replacement); + }, + ast::Use(use_item) => { + builder.replace_ast(use_item, ast::make::use_(replacement.left_or_else(|ut| ast::make::use_tree(path, Some(ut), None, false)))); + }, + _ => {}, + } + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn expanding_glob_import() { + check_assist( + expand_glob_import, + r" +mod foo { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} +} + +use foo::*<|>; + +fn qux(bar: Bar, baz: Baz) { + f(); +} +", + r" +mod foo { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} +} + +use foo::{Baz, Bar, f}; + +fn qux(bar: Bar, baz: Baz) { + f(); +} +", + ) + } + + #[test] + fn expanding_glob_import_with_existing_explicit_names() { + check_assist( + expand_glob_import, + r" +mod foo { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} +} + +use foo::{*<|>, f}; + +fn qux(bar: Bar, baz: Baz) { + f(); +} +", + r" +mod foo { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} +} + +use foo::{Baz, Bar, f}; + +fn qux(bar: Bar, baz: Baz) { + f(); +} +", + ) + } + + #[test] + fn expanding_nested_glob_import() { + check_assist( + expand_glob_import, + r" +mod foo { + mod bar { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} + } + + mod baz { + pub fn g() {} + } +} + +use foo::{bar::{*<|>, f}, baz::*}; + +fn qux(bar: Bar, baz: Baz) { + f(); + g(); +} +", + r" +mod foo { + mod bar { + pub struct Bar; + pub struct Baz; + pub struct Qux; + + pub fn f() {} + } + + mod baz { + pub fn g() {} + } +} + +use foo::{bar::{Baz, Bar, f}, baz::*}; + +fn qux(bar: Bar, baz: Baz) { + f(); + g(); +} +", + ) + } + + #[test] + fn expanding_glob_import_with_macro_defs() { + check_assist( + expand_glob_import, + r" +//- /lib.rs crate:foo +#[macro_export] +macro_rules! bar { + () => () +} + +pub fn baz() {} + +//- /main.rs crate:main deps:foo +use foo::*<|>; + +fn main() { + bar!(); + baz(); +} +", + r" +use foo::{bar, baz}; + +fn main() { + bar!(); + baz(); +} +", + ) + } + + #[test] + fn expanding_glob_import_with_trait_method_uses() { + check_assist( + expand_glob_import, + r" +//- /lib.rs crate:foo +pub trait Tr { + fn method(&self) {} +} +impl Tr for () {} + +//- /main.rs crate:main deps:foo +use foo::*<|>; + +fn main() { + ().method(); +} +", + r" +use foo::Tr; + +fn main() { + ().method(); +} +", + ) + } + + #[test] + fn expanding_is_not_applicable_if_cursor_is_not_in_star_token() { + check_assist_not_applicable( + expand_glob_import, + r" + mod foo { + pub struct Bar; + pub struct Baz; + pub struct Qux; + } + + use foo::Bar<|>; + + fn qux(bar: Bar, baz: Baz) {} + ", + ) + } +} diff --git a/crates/assists/src/handlers/extract_struct_from_enum_variant.rs b/crates/assists/src/handlers/extract_struct_from_enum_variant.rs new file mode 100644 index 000000000..4bcdae7ba --- /dev/null +++ b/crates/assists/src/handlers/extract_struct_from_enum_variant.rs @@ -0,0 +1,317 @@ +use base_db::FileId; +use hir::{EnumVariant, Module, ModuleDef, Name}; +use ide_db::{defs::Definition, search::Reference, RootDatabase}; +use rustc_hash::FxHashSet; +use syntax::{ + algo::find_node_at_offset, + ast::{self, edit::IndentLevel, ArgListOwner, AstNode, NameOwner, VisibilityOwner}, + SourceFile, TextRange, TextSize, +}; + +use crate::{ + assist_context::AssistBuilder, utils::insert_use_statement, AssistContext, AssistId, + AssistKind, Assists, +}; + +// Assist: extract_struct_from_enum_variant +// +// Extracts a struct from enum variant. +// +// ``` +// enum A { <|>One(u32, u32) } +// ``` +// -> +// ``` +// struct One(pub u32, pub u32); +// +// enum A { One(One) } +// ``` +pub(crate) fn extract_struct_from_enum_variant( + acc: &mut Assists, + ctx: &AssistContext, +) -> Option<()> { + let variant = ctx.find_node_at_offset::()?; + let field_list = match variant.kind() { + ast::StructKind::Tuple(field_list) => field_list, + _ => return None, + }; + let variant_name = variant.name()?.to_string(); + let variant_hir = ctx.sema.to_def(&variant)?; + if existing_struct_def(ctx.db(), &variant_name, &variant_hir) { + return None; + } + let enum_ast = variant.parent_enum(); + let visibility = enum_ast.visibility(); + let enum_hir = ctx.sema.to_def(&enum_ast)?; + let variant_hir_name = variant_hir.name(ctx.db()); + let enum_module_def = ModuleDef::from(enum_hir); + let current_module = enum_hir.module(ctx.db()); + let target = variant.syntax().text_range(); + acc.add( + AssistId("extract_struct_from_enum_variant", AssistKind::RefactorRewrite), + "Extract struct from enum variant", + target, + |builder| { + let definition = Definition::ModuleDef(ModuleDef::EnumVariant(variant_hir)); + let res = definition.find_usages(&ctx.sema, None); + let start_offset = variant.parent_enum().syntax().text_range().start(); + let mut visited_modules_set = FxHashSet::default(); + visited_modules_set.insert(current_module); + for reference in res { + let source_file = ctx.sema.parse(reference.file_range.file_id); + update_reference( + ctx, + builder, + reference, + &source_file, + &enum_module_def, + &variant_hir_name, + &mut visited_modules_set, + ); + } + extract_struct_def( + builder, + &enum_ast, + &variant_name, + &field_list.to_string(), + start_offset, + ctx.frange.file_id, + &visibility, + ); + let list_range = field_list.syntax().text_range(); + update_variant(builder, &variant_name, ctx.frange.file_id, list_range); + }, + ) +} + +fn existing_struct_def(db: &RootDatabase, variant_name: &str, variant: &EnumVariant) -> bool { + variant + .parent_enum(db) + .module(db) + .scope(db, None) + .into_iter() + .any(|(name, _)| name.to_string() == variant_name.to_string()) +} + +fn insert_import( + ctx: &AssistContext, + builder: &mut AssistBuilder, + path: &ast::PathExpr, + module: &Module, + enum_module_def: &ModuleDef, + variant_hir_name: &Name, +) -> Option<()> { + let db = ctx.db(); + let mod_path = module.find_use_path(db, enum_module_def.clone()); + if let Some(mut mod_path) = mod_path { + mod_path.segments.pop(); + mod_path.segments.push(variant_hir_name.clone()); + insert_use_statement(path.syntax(), &mod_path, ctx, builder.text_edit_builder()); + } + Some(()) +} + +// FIXME: this should use strongly-typed `make`, rather than string manipulation. +fn extract_struct_def( + builder: &mut AssistBuilder, + enum_: &ast::Enum, + variant_name: &str, + variant_list: &str, + start_offset: TextSize, + file_id: FileId, + visibility: &Option, +) -> Option<()> { + let visibility_string = if let Some(visibility) = visibility { + format!("{} ", visibility.to_string()) + } else { + "".to_string() + }; + let indent = IndentLevel::from_node(enum_.syntax()); + let struct_def = format!( + r#"{}struct {}{}; + +{}"#, + visibility_string, + variant_name, + list_with_visibility(variant_list), + indent + ); + builder.edit_file(file_id); + builder.insert(start_offset, struct_def); + Some(()) +} + +fn update_variant( + builder: &mut AssistBuilder, + variant_name: &str, + file_id: FileId, + list_range: TextRange, +) -> Option<()> { + let inside_variant_range = TextRange::new( + list_range.start().checked_add(TextSize::from(1))?, + list_range.end().checked_sub(TextSize::from(1))?, + ); + builder.edit_file(file_id); + builder.replace(inside_variant_range, variant_name); + Some(()) +} + +fn update_reference( + ctx: &AssistContext, + builder: &mut AssistBuilder, + reference: Reference, + source_file: &SourceFile, + enum_module_def: &ModuleDef, + variant_hir_name: &Name, + visited_modules_set: &mut FxHashSet, +) -> Option<()> { + let path_expr: ast::PathExpr = find_node_at_offset::( + source_file.syntax(), + reference.file_range.range.start(), + )?; + let call = path_expr.syntax().parent().and_then(ast::CallExpr::cast)?; + let list = call.arg_list()?; + let segment = path_expr.path()?.segment()?; + let module = ctx.sema.scope(&path_expr.syntax()).module()?; + let list_range = list.syntax().text_range(); + let inside_list_range = TextRange::new( + list_range.start().checked_add(TextSize::from(1))?, + list_range.end().checked_sub(TextSize::from(1))?, + ); + builder.edit_file(reference.file_range.file_id); + if !visited_modules_set.contains(&module) { + if insert_import(ctx, builder, &path_expr, &module, enum_module_def, variant_hir_name) + .is_some() + { + visited_modules_set.insert(module); + } + } + builder.replace(inside_list_range, format!("{}{}", segment, list)); + Some(()) +} + +fn list_with_visibility(list: &str) -> String { + list.split(',') + .map(|part| { + let index = if part.chars().next().unwrap() == '(' { 1usize } else { 0 }; + let mut mod_part = part.trim().to_string(); + mod_part.insert_str(index, "pub "); + mod_part + }) + .collect::>() + .join(", ") +} + +#[cfg(test)] +mod tests { + + use crate::{ + tests::{check_assist, check_assist_not_applicable}, + utils::FamousDefs, + }; + + use super::*; + + #[test] + fn test_extract_struct_several_fields() { + check_assist( + extract_struct_from_enum_variant, + "enum A { <|>One(u32, u32) }", + r#"struct One(pub u32, pub u32); + +enum A { One(One) }"#, + ); + } + + #[test] + fn test_extract_struct_one_field() { + check_assist( + extract_struct_from_enum_variant, + "enum A { <|>One(u32) }", + r#"struct One(pub u32); + +enum A { One(One) }"#, + ); + } + + #[test] + fn test_extract_struct_pub_visibility() { + check_assist( + extract_struct_from_enum_variant, + "pub enum A { <|>One(u32, u32) }", + r#"pub struct One(pub u32, pub u32); + +pub enum A { One(One) }"#, + ); + } + + #[test] + fn test_extract_struct_with_complex_imports() { + check_assist( + extract_struct_from_enum_variant, + r#"mod my_mod { + fn another_fn() { + let m = my_other_mod::MyEnum::MyField(1, 1); + } + + pub mod my_other_mod { + fn another_fn() { + let m = MyEnum::MyField(1, 1); + } + + pub enum MyEnum { + <|>MyField(u8, u8), + } + } +} + +fn another_fn() { + let m = my_mod::my_other_mod::MyEnum::MyField(1, 1); +}"#, + r#"use my_mod::my_other_mod::MyField; + +mod my_mod { + use my_other_mod::MyField; + + fn another_fn() { + let m = my_other_mod::MyEnum::MyField(MyField(1, 1)); + } + + pub mod my_other_mod { + fn another_fn() { + let m = MyEnum::MyField(MyField(1, 1)); + } + + pub struct MyField(pub u8, pub u8); + + pub enum MyEnum { + MyField(MyField), + } + } +} + +fn another_fn() { + let m = my_mod::my_other_mod::MyEnum::MyField(MyField(1, 1)); +}"#, + ); + } + + fn check_not_applicable(ra_fixture: &str) { + let fixture = + format!("//- /main.rs crate:main deps:core\n{}\n{}", ra_fixture, FamousDefs::FIXTURE); + check_assist_not_applicable(extract_struct_from_enum_variant, &fixture) + } + + #[test] + fn test_extract_enum_not_applicable_for_element_with_no_fields() { + check_not_applicable("enum A { <|>One }"); + } + + #[test] + fn test_extract_enum_not_applicable_if_struct_exists() { + check_not_applicable( + r#"struct One; + enum A { <|>One(u8) }"#, + ); + } +} diff --git a/crates/assists/src/handlers/extract_variable.rs b/crates/assists/src/handlers/extract_variable.rs new file mode 100644 index 000000000..d2ae137cd --- /dev/null +++ b/crates/assists/src/handlers/extract_variable.rs @@ -0,0 +1,588 @@ +use stdx::format_to; +use syntax::{ + ast::{self, AstNode}, + SyntaxKind::{ + BLOCK_EXPR, BREAK_EXPR, CLOSURE_EXPR, COMMENT, LOOP_EXPR, MATCH_ARM, PATH_EXPR, RETURN_EXPR, + }, + SyntaxNode, +}; +use test_utils::mark; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: extract_variable +// +// Extracts subexpression into a variable. +// +// ``` +// fn main() { +// <|>(1 + 2)<|> * 4; +// } +// ``` +// -> +// ``` +// fn main() { +// let $0var_name = (1 + 2); +// var_name * 4; +// } +// ``` +pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + if ctx.frange.range.is_empty() { + return None; + } + let node = ctx.covering_element(); + if node.kind() == COMMENT { + mark::hit!(extract_var_in_comment_is_not_applicable); + return None; + } + let to_extract = node.ancestors().find_map(valid_target_expr)?; + let anchor = Anchor::from(&to_extract)?; + let indent = anchor.syntax().prev_sibling_or_token()?.as_token()?.clone(); + let target = to_extract.syntax().text_range(); + acc.add( + AssistId("extract_variable", AssistKind::RefactorExtract), + "Extract into variable", + target, + move |edit| { + let field_shorthand = + match to_extract.syntax().parent().and_then(ast::RecordExprField::cast) { + Some(field) => field.name_ref(), + None => None, + }; + + let mut buf = String::new(); + + let var_name = match &field_shorthand { + Some(it) => it.to_string(), + None => "var_name".to_string(), + }; + let expr_range = match &field_shorthand { + Some(it) => it.syntax().text_range().cover(to_extract.syntax().text_range()), + None => to_extract.syntax().text_range(), + }; + + if let Anchor::WrapInBlock(_) = anchor { + format_to!(buf, "{{ let {} = ", var_name); + } else { + format_to!(buf, "let {} = ", var_name); + }; + format_to!(buf, "{}", to_extract.syntax()); + + if let Anchor::Replace(stmt) = anchor { + mark::hit!(test_extract_var_expr_stmt); + if stmt.semicolon_token().is_none() { + buf.push_str(";"); + } + match ctx.config.snippet_cap { + Some(cap) => { + let snip = buf + .replace(&format!("let {}", var_name), &format!("let $0{}", var_name)); + edit.replace_snippet(cap, expr_range, snip) + } + None => edit.replace(expr_range, buf), + } + return; + } + + buf.push_str(";"); + + // We want to maintain the indent level, + // but we do not want to duplicate possible + // extra newlines in the indent block + let text = indent.text(); + if text.starts_with('\n') { + buf.push_str("\n"); + buf.push_str(text.trim_start_matches('\n')); + } else { + buf.push_str(text); + } + + edit.replace(expr_range, var_name.clone()); + let offset = anchor.syntax().text_range().start(); + match ctx.config.snippet_cap { + Some(cap) => { + let snip = + buf.replace(&format!("let {}", var_name), &format!("let $0{}", var_name)); + edit.insert_snippet(cap, offset, snip) + } + None => edit.insert(offset, buf), + } + + if let Anchor::WrapInBlock(_) = anchor { + edit.insert(anchor.syntax().text_range().end(), " }"); + } + }, + ) +} + +/// Check whether the node is a valid expression which can be extracted to a variable. +/// In general that's true for any expression, but in some cases that would produce invalid code. +fn valid_target_expr(node: SyntaxNode) -> Option { + match node.kind() { + PATH_EXPR | LOOP_EXPR => None, + BREAK_EXPR => ast::BreakExpr::cast(node).and_then(|e| e.expr()), + RETURN_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()), + BLOCK_EXPR => { + ast::BlockExpr::cast(node).filter(|it| it.is_standalone()).map(ast::Expr::from) + } + _ => ast::Expr::cast(node), + } +} + +enum Anchor { + Before(SyntaxNode), + Replace(ast::ExprStmt), + WrapInBlock(SyntaxNode), +} + +impl Anchor { + fn from(to_extract: &ast::Expr) -> Option { + to_extract.syntax().ancestors().find_map(|node| { + if let Some(expr) = + node.parent().and_then(ast::BlockExpr::cast).and_then(|it| it.expr()) + { + if expr.syntax() == &node { + mark::hit!(test_extract_var_last_expr); + return Some(Anchor::Before(node)); + } + } + + if let Some(parent) = node.parent() { + if parent.kind() == MATCH_ARM || parent.kind() == CLOSURE_EXPR { + return Some(Anchor::WrapInBlock(node)); + } + } + + if let Some(stmt) = ast::Stmt::cast(node.clone()) { + if let ast::Stmt::ExprStmt(stmt) = stmt { + if stmt.expr().as_ref() == Some(to_extract) { + return Some(Anchor::Replace(stmt)); + } + } + return Some(Anchor::Before(node)); + } + None + }) + } + + fn syntax(&self) -> &SyntaxNode { + match self { + Anchor::Before(it) | Anchor::WrapInBlock(it) => it, + Anchor::Replace(stmt) => stmt.syntax(), + } + } +} + +#[cfg(test)] +mod tests { + use test_utils::mark; + + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + use super::*; + + #[test] + fn test_extract_var_simple() { + check_assist( + extract_variable, + r#" +fn foo() { + foo(<|>1 + 1<|>); +}"#, + r#" +fn foo() { + let $0var_name = 1 + 1; + foo(var_name); +}"#, + ); + } + + #[test] + fn extract_var_in_comment_is_not_applicable() { + mark::check!(extract_var_in_comment_is_not_applicable); + check_assist_not_applicable(extract_variable, "fn main() { 1 + /* <|>comment<|> */ 1; }"); + } + + #[test] + fn test_extract_var_expr_stmt() { + mark::check!(test_extract_var_expr_stmt); + check_assist( + extract_variable, + r#" +fn foo() { + <|>1 + 1<|>; +}"#, + r#" +fn foo() { + let $0var_name = 1 + 1; +}"#, + ); + check_assist( + extract_variable, + " +fn foo() { + <|>{ let x = 0; x }<|> + something_else(); +}", + " +fn foo() { + let $0var_name = { let x = 0; x }; + something_else(); +}", + ); + } + + #[test] + fn test_extract_var_part_of_expr_stmt() { + check_assist( + extract_variable, + " +fn foo() { + <|>1<|> + 1; +}", + " +fn foo() { + let $0var_name = 1; + var_name + 1; +}", + ); + } + + #[test] + fn test_extract_var_last_expr() { + mark::check!(test_extract_var_last_expr); + check_assist( + extract_variable, + r#" +fn foo() { + bar(<|>1 + 1<|>) +} +"#, + r#" +fn foo() { + let $0var_name = 1 + 1; + bar(var_name) +} +"#, + ); + check_assist( + extract_variable, + r#" +fn foo() { + <|>bar(1 + 1)<|> +} +"#, + r#" +fn foo() { + let $0var_name = bar(1 + 1); + var_name +} +"#, + ) + } + + #[test] + fn test_extract_var_in_match_arm_no_block() { + check_assist( + extract_variable, + " +fn main() { + let x = true; + let tuple = match x { + true => (<|>2 + 2<|>, true) + _ => (0, false) + }; +} +", + " +fn main() { + let x = true; + let tuple = match x { + true => { let $0var_name = 2 + 2; (var_name, true) } + _ => (0, false) + }; +} +", + ); + } + + #[test] + fn test_extract_var_in_match_arm_with_block() { + check_assist( + extract_variable, + " +fn main() { + let x = true; + let tuple = match x { + true => { + let y = 1; + (<|>2 + y<|>, true) + } + _ => (0, false) + }; +} +", + " +fn main() { + let x = true; + let tuple = match x { + true => { + let y = 1; + let $0var_name = 2 + y; + (var_name, true) + } + _ => (0, false) + }; +} +", + ); + } + + #[test] + fn test_extract_var_in_closure_no_block() { + check_assist( + extract_variable, + " +fn main() { + let lambda = |x: u32| <|>x * 2<|>; +} +", + " +fn main() { + let lambda = |x: u32| { let $0var_name = x * 2; var_name }; +} +", + ); + } + + #[test] + fn test_extract_var_in_closure_with_block() { + check_assist( + extract_variable, + " +fn main() { + let lambda = |x: u32| { <|>x * 2<|> }; +} +", + " +fn main() { + let lambda = |x: u32| { let $0var_name = x * 2; var_name }; +} +", + ); + } + + #[test] + fn test_extract_var_path_simple() { + check_assist( + extract_variable, + " +fn main() { + let o = <|>Some(true)<|>; +} +", + " +fn main() { + let $0var_name = Some(true); + let o = var_name; +} +", + ); + } + + #[test] + fn test_extract_var_path_method() { + check_assist( + extract_variable, + " +fn main() { + let v = <|>bar.foo()<|>; +} +", + " +fn main() { + let $0var_name = bar.foo(); + let v = var_name; +} +", + ); + } + + #[test] + fn test_extract_var_return() { + check_assist( + extract_variable, + " +fn foo() -> u32 { + <|>return 2 + 2<|>; +} +", + " +fn foo() -> u32 { + let $0var_name = 2 + 2; + return var_name; +} +", + ); + } + + #[test] + fn test_extract_var_does_not_add_extra_whitespace() { + check_assist( + extract_variable, + " +fn foo() -> u32 { + + + <|>return 2 + 2<|>; +} +", + " +fn foo() -> u32 { + + + let $0var_name = 2 + 2; + return var_name; +} +", + ); + + check_assist( + extract_variable, + " +fn foo() -> u32 { + + <|>return 2 + 2<|>; +} +", + " +fn foo() -> u32 { + + let $0var_name = 2 + 2; + return var_name; +} +", + ); + + check_assist( + extract_variable, + " +fn foo() -> u32 { + let foo = 1; + + // bar + + + <|>return 2 + 2<|>; +} +", + " +fn foo() -> u32 { + let foo = 1; + + // bar + + + let $0var_name = 2 + 2; + return var_name; +} +", + ); + } + + #[test] + fn test_extract_var_break() { + check_assist( + extract_variable, + " +fn main() { + let result = loop { + <|>break 2 + 2<|>; + }; +} +", + " +fn main() { + let result = loop { + let $0var_name = 2 + 2; + break var_name; + }; +} +", + ); + } + + #[test] + fn test_extract_var_for_cast() { + check_assist( + extract_variable, + " +fn main() { + let v = <|>0f32 as u32<|>; +} +", + " +fn main() { + let $0var_name = 0f32 as u32; + let v = var_name; +} +", + ); + } + + #[test] + fn extract_var_field_shorthand() { + check_assist( + extract_variable, + r#" +struct S { + foo: i32 +} + +fn main() { + S { foo: <|>1 + 1<|> } +} +"#, + r#" +struct S { + foo: i32 +} + +fn main() { + let $0foo = 1 + 1; + S { foo } +} +"#, + ) + } + + #[test] + fn test_extract_var_for_return_not_applicable() { + check_assist_not_applicable(extract_variable, "fn foo() { <|>return<|>; } "); + } + + #[test] + fn test_extract_var_for_break_not_applicable() { + check_assist_not_applicable(extract_variable, "fn main() { loop { <|>break<|>; }; }"); + } + + // FIXME: This is not quite correct, but good enough(tm) for the sorting heuristic + #[test] + fn extract_var_target() { + check_assist_target(extract_variable, "fn foo() -> u32 { <|>return 2 + 2<|>; }", "2 + 2"); + + check_assist_target( + extract_variable, + " +fn main() { + let x = true; + let tuple = match x { + true => (<|>2 + 2<|>, true) + _ => (0, false) + }; +} +", + "2 + 2", + ); + } +} diff --git a/crates/assists/src/handlers/fill_match_arms.rs b/crates/assists/src/handlers/fill_match_arms.rs new file mode 100644 index 000000000..3d9bdb2bf --- /dev/null +++ b/crates/assists/src/handlers/fill_match_arms.rs @@ -0,0 +1,747 @@ +use std::iter; + +use hir::{Adt, HasSource, ModuleDef, Semantics}; +use ide_db::RootDatabase; +use itertools::Itertools; +use syntax::ast::{self, make, AstNode, MatchArm, NameOwner, Pat}; +use test_utils::mark; + +use crate::{ + utils::{render_snippet, Cursor, FamousDefs}, + AssistContext, AssistId, AssistKind, Assists, +}; + +// Assist: fill_match_arms +// +// Adds missing clauses to a `match` expression. +// +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// <|> +// } +// } +// ``` +// -> +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// $0Action::Move { distance } => {} +// Action::Stop => {} +// } +// } +// ``` +pub(crate) fn fill_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let match_expr = ctx.find_node_at_offset::()?; + let match_arm_list = match_expr.match_arm_list()?; + + let expr = match_expr.expr()?; + + let mut arms: Vec = match_arm_list.arms().collect(); + if arms.len() == 1 { + if let Some(Pat::WildcardPat(..)) = arms[0].pat() { + arms.clear(); + } + } + + let module = ctx.sema.scope(expr.syntax()).module()?; + + let missing_arms: Vec = if let Some(enum_def) = resolve_enum_def(&ctx.sema, &expr) { + let variants = enum_def.variants(ctx.db()); + + let mut variants = variants + .into_iter() + .filter_map(|variant| build_pat(ctx.db(), module, variant)) + .filter(|variant_pat| is_variant_missing(&mut arms, variant_pat)) + .map(|pat| make::match_arm(iter::once(pat), make::expr_empty_block())) + .collect::>(); + if Some(enum_def) == FamousDefs(&ctx.sema, module.krate()).core_option_Option() { + // Match `Some` variant first. + mark::hit!(option_order); + variants.reverse() + } + variants + } else if let Some(enum_defs) = resolve_tuple_of_enum_def(&ctx.sema, &expr) { + // Partial fill not currently supported for tuple of enums. + if !arms.is_empty() { + return None; + } + + // We do not currently support filling match arms for a tuple + // containing a single enum. + if enum_defs.len() < 2 { + return None; + } + + // When calculating the match arms for a tuple of enums, we want + // to create a match arm for each possible combination of enum + // values. The `multi_cartesian_product` method transforms + // Vec> into Vec<(EnumVariant, .., EnumVariant)> + // where each tuple represents a proposed match arm. + enum_defs + .into_iter() + .map(|enum_def| enum_def.variants(ctx.db())) + .multi_cartesian_product() + .map(|variants| { + let patterns = + variants.into_iter().filter_map(|variant| build_pat(ctx.db(), module, variant)); + ast::Pat::from(make::tuple_pat(patterns)) + }) + .filter(|variant_pat| is_variant_missing(&mut arms, variant_pat)) + .map(|pat| make::match_arm(iter::once(pat), make::expr_empty_block())) + .collect() + } else { + return None; + }; + + if missing_arms.is_empty() { + return None; + } + + let target = match_expr.syntax().text_range(); + acc.add( + AssistId("fill_match_arms", AssistKind::QuickFix), + "Fill match arms", + target, + |builder| { + let new_arm_list = match_arm_list.remove_placeholder(); + let n_old_arms = new_arm_list.arms().count(); + let new_arm_list = new_arm_list.append_arms(missing_arms); + let first_new_arm = new_arm_list.arms().nth(n_old_arms); + let old_range = match_arm_list.syntax().text_range(); + match (first_new_arm, ctx.config.snippet_cap) { + (Some(first_new_arm), Some(cap)) => { + let extend_lifetime; + let cursor = + match first_new_arm.syntax().descendants().find_map(ast::WildcardPat::cast) + { + Some(it) => { + extend_lifetime = it.syntax().clone(); + Cursor::Replace(&extend_lifetime) + } + None => Cursor::Before(first_new_arm.syntax()), + }; + let snippet = render_snippet(cap, new_arm_list.syntax(), cursor); + builder.replace_snippet(cap, old_range, snippet); + } + _ => builder.replace(old_range, new_arm_list.to_string()), + } + }, + ) +} + +fn is_variant_missing(existing_arms: &mut Vec, var: &Pat) -> bool { + existing_arms.iter().filter_map(|arm| arm.pat()).all(|pat| { + // Special casee OrPat as separate top-level pats + let top_level_pats: Vec = match pat { + Pat::OrPat(pats) => pats.pats().collect::>(), + _ => vec![pat], + }; + + !top_level_pats.iter().any(|pat| does_pat_match_variant(pat, var)) + }) +} + +fn does_pat_match_variant(pat: &Pat, var: &Pat) -> bool { + let first_node_text = |pat: &Pat| pat.syntax().first_child().map(|node| node.text()); + + let pat_head = match pat { + Pat::IdentPat(bind_pat) => { + if let Some(p) = bind_pat.pat() { + first_node_text(&p) + } else { + return false; + } + } + pat => first_node_text(pat), + }; + + let var_head = first_node_text(var); + + pat_head == var_head +} + +fn resolve_enum_def(sema: &Semantics, expr: &ast::Expr) -> Option { + sema.type_of_expr(&expr)?.autoderef(sema.db).find_map(|ty| match ty.as_adt() { + Some(Adt::Enum(e)) => Some(e), + _ => None, + }) +} + +fn resolve_tuple_of_enum_def( + sema: &Semantics, + expr: &ast::Expr, +) -> Option> { + sema.type_of_expr(&expr)? + .tuple_fields(sema.db) + .iter() + .map(|ty| { + ty.autoderef(sema.db).find_map(|ty| match ty.as_adt() { + Some(Adt::Enum(e)) => Some(e), + // For now we only handle expansion for a tuple of enums. Here + // we map non-enum items to None and rely on `collect` to + // convert Vec> into Option>. + _ => None, + }) + }) + .collect() +} + +fn build_pat(db: &RootDatabase, module: hir::Module, var: hir::EnumVariant) -> Option { + let path = crate::ast_transform::path_to_ast(module.find_use_path(db, ModuleDef::from(var))?); + + // FIXME: use HIR for this; it doesn't currently expose struct vs. tuple vs. unit variants though + let pat: ast::Pat = match var.source(db).value.kind() { + ast::StructKind::Tuple(field_list) => { + let pats = iter::repeat(make::wildcard_pat().into()).take(field_list.fields().count()); + make::tuple_struct_pat(path, pats).into() + } + ast::StructKind::Record(field_list) => { + let pats = field_list.fields().map(|f| make::ident_pat(f.name().unwrap()).into()); + make::record_pat(path, pats).into() + } + ast::StructKind::Unit => make::path_pat(path), + }; + + Some(pat) +} + +#[cfg(test)] +mod tests { + use test_utils::mark; + + use crate::{ + tests::{check_assist, check_assist_not_applicable, check_assist_target}, + utils::FamousDefs, + }; + + use super::fill_match_arms; + + #[test] + fn all_match_arms_provided() { + check_assist_not_applicable( + fill_match_arms, + r#" + enum A { + As, + Bs{x:i32, y:Option}, + Cs(i32, Option), + } + fn main() { + match A::As<|> { + A::As, + A::Bs{x,y:Some(_)} => {} + A::Cs(_, Some(_)) => {} + } + } + "#, + ); + } + + #[test] + fn tuple_of_non_enum() { + // for now this case is not handled, although it potentially could be + // in the future + check_assist_not_applicable( + fill_match_arms, + r#" + fn main() { + match (0, false)<|> { + } + } + "#, + ); + } + + #[test] + fn partial_fill_record_tuple() { + check_assist( + fill_match_arms, + r#" + enum A { + As, + Bs { x: i32, y: Option }, + Cs(i32, Option), + } + fn main() { + match A::As<|> { + A::Bs { x, y: Some(_) } => {} + A::Cs(_, Some(_)) => {} + } + } + "#, + r#" + enum A { + As, + Bs { x: i32, y: Option }, + Cs(i32, Option), + } + fn main() { + match A::As { + A::Bs { x, y: Some(_) } => {} + A::Cs(_, Some(_)) => {} + $0A::As => {} + } + } + "#, + ); + } + + #[test] + fn partial_fill_or_pat() { + check_assist( + fill_match_arms, + r#" +enum A { As, Bs, Cs(Option) } +fn main() { + match A::As<|> { + A::Cs(_) | A::Bs => {} + } +} +"#, + r#" +enum A { As, Bs, Cs(Option) } +fn main() { + match A::As { + A::Cs(_) | A::Bs => {} + $0A::As => {} + } +} +"#, + ); + } + + #[test] + fn partial_fill() { + check_assist( + fill_match_arms, + r#" +enum A { As, Bs, Cs, Ds(String), Es(B) } +enum B { Xs, Ys } +fn main() { + match A::As<|> { + A::Bs if 0 < 1 => {} + A::Ds(_value) => { let x = 1; } + A::Es(B::Xs) => (), + } +} +"#, + r#" +enum A { As, Bs, Cs, Ds(String), Es(B) } +enum B { Xs, Ys } +fn main() { + match A::As { + A::Bs if 0 < 1 => {} + A::Ds(_value) => { let x = 1; } + A::Es(B::Xs) => (), + $0A::As => {} + A::Cs => {} + } +} +"#, + ); + } + + #[test] + fn partial_fill_bind_pat() { + check_assist( + fill_match_arms, + r#" +enum A { As, Bs, Cs(Option) } +fn main() { + match A::As<|> { + A::As(_) => {} + a @ A::Bs(_) => {} + } +} +"#, + r#" +enum A { As, Bs, Cs(Option) } +fn main() { + match A::As { + A::As(_) => {} + a @ A::Bs(_) => {} + A::Cs(${0:_}) => {} + } +} +"#, + ); + } + + #[test] + fn fill_match_arms_empty_body() { + check_assist( + fill_match_arms, + r#" +enum A { As, Bs, Cs(String), Ds(String, String), Es { x: usize, y: usize } } + +fn main() { + let a = A::As; + match a<|> {} +} +"#, + r#" +enum A { As, Bs, Cs(String), Ds(String, String), Es { x: usize, y: usize } } + +fn main() { + let a = A::As; + match a { + $0A::As => {} + A::Bs => {} + A::Cs(_) => {} + A::Ds(_, _) => {} + A::Es { x, y } => {} + } +} +"#, + ); + } + + #[test] + fn fill_match_arms_tuple_of_enum() { + check_assist( + fill_match_arms, + r#" + enum A { One, Two } + enum B { One, Two } + + fn main() { + let a = A::One; + let b = B::One; + match (a<|>, b) {} + } + "#, + r#" + enum A { One, Two } + enum B { One, Two } + + fn main() { + let a = A::One; + let b = B::One; + match (a, b) { + $0(A::One, B::One) => {} + (A::One, B::Two) => {} + (A::Two, B::One) => {} + (A::Two, B::Two) => {} + } + } + "#, + ); + } + + #[test] + fn fill_match_arms_tuple_of_enum_ref() { + check_assist( + fill_match_arms, + r#" + enum A { One, Two } + enum B { One, Two } + + fn main() { + let a = A::One; + let b = B::One; + match (&a<|>, &b) {} + } + "#, + r#" + enum A { One, Two } + enum B { One, Two } + + fn main() { + let a = A::One; + let b = B::One; + match (&a, &b) { + $0(A::One, B::One) => {} + (A::One, B::Two) => {} + (A::Two, B::One) => {} + (A::Two, B::Two) => {} + } + } + "#, + ); + } + + #[test] + fn fill_match_arms_tuple_of_enum_partial() { + check_assist_not_applicable( + fill_match_arms, + r#" + enum A { One, Two } + enum B { One, Two } + + fn main() { + let a = A::One; + let b = B::One; + match (a<|>, b) { + (A::Two, B::One) => {} + } + } + "#, + ); + } + + #[test] + fn fill_match_arms_tuple_of_enum_not_applicable() { + check_assist_not_applicable( + fill_match_arms, + r#" + enum A { One, Two } + enum B { One, Two } + + fn main() { + let a = A::One; + let b = B::One; + match (a<|>, b) { + (A::Two, B::One) => {} + (A::One, B::One) => {} + (A::One, B::Two) => {} + (A::Two, B::Two) => {} + } + } + "#, + ); + } + + #[test] + fn fill_match_arms_single_element_tuple_of_enum() { + // For now we don't hande the case of a single element tuple, but + // we could handle this in the future if `make::tuple_pat` allowed + // creating a tuple with a single pattern. + check_assist_not_applicable( + fill_match_arms, + r#" + enum A { One, Two } + + fn main() { + let a = A::One; + match (a<|>, ) { + } + } + "#, + ); + } + + #[test] + fn test_fill_match_arm_refs() { + check_assist( + fill_match_arms, + r#" + enum A { As } + + fn foo(a: &A) { + match a<|> { + } + } + "#, + r#" + enum A { As } + + fn foo(a: &A) { + match a { + $0A::As => {} + } + } + "#, + ); + + check_assist( + fill_match_arms, + r#" + enum A { + Es { x: usize, y: usize } + } + + fn foo(a: &mut A) { + match a<|> { + } + } + "#, + r#" + enum A { + Es { x: usize, y: usize } + } + + fn foo(a: &mut A) { + match a { + $0A::Es { x, y } => {} + } + } + "#, + ); + } + + #[test] + fn fill_match_arms_target() { + check_assist_target( + fill_match_arms, + r#" + enum E { X, Y } + + fn main() { + match E::X<|> {} + } + "#, + "match E::X {}", + ); + } + + #[test] + fn fill_match_arms_trivial_arm() { + check_assist( + fill_match_arms, + r#" + enum E { X, Y } + + fn main() { + match E::X { + <|>_ => {} + } + } + "#, + r#" + enum E { X, Y } + + fn main() { + match E::X { + $0E::X => {} + E::Y => {} + } + } + "#, + ); + } + + #[test] + fn fill_match_arms_qualifies_path() { + check_assist( + fill_match_arms, + r#" + mod foo { pub enum E { X, Y } } + use foo::E::X; + + fn main() { + match X { + <|> + } + } + "#, + r#" + mod foo { pub enum E { X, Y } } + use foo::E::X; + + fn main() { + match X { + $0X => {} + foo::E::Y => {} + } + } + "#, + ); + } + + #[test] + fn fill_match_arms_preserves_comments() { + check_assist( + fill_match_arms, + r#" + enum A { One, Two } + fn foo(a: A) { + match a { + // foo bar baz<|> + A::One => {} + // This is where the rest should be + } + } + "#, + r#" + enum A { One, Two } + fn foo(a: A) { + match a { + // foo bar baz + A::One => {} + // This is where the rest should be + $0A::Two => {} + } + } + "#, + ); + } + + #[test] + fn fill_match_arms_preserves_comments_empty() { + check_assist( + fill_match_arms, + r#" + enum A { One, Two } + fn foo(a: A) { + match a { + // foo bar baz<|> + } + } + "#, + r#" + enum A { One, Two } + fn foo(a: A) { + match a { + // foo bar baz + $0A::One => {} + A::Two => {} + } + } + "#, + ); + } + + #[test] + fn fill_match_arms_placeholder() { + check_assist( + fill_match_arms, + r#" + enum A { One, Two, } + fn foo(a: A) { + match a<|> { + _ => (), + } + } + "#, + r#" + enum A { One, Two, } + fn foo(a: A) { + match a { + $0A::One => {} + A::Two => {} + } + } + "#, + ); + } + + #[test] + fn option_order() { + mark::check!(option_order); + let before = r#" +fn foo(opt: Option) { + match opt<|> { + } +} +"#; + let before = &format!("//- /main.rs crate:main deps:core{}{}", before, FamousDefs::FIXTURE); + + check_assist( + fill_match_arms, + before, + r#" +fn foo(opt: Option) { + match opt { + Some(${0:_}) => {} + None => {} + } +} +"#, + ); + } +} diff --git a/crates/assists/src/handlers/fix_visibility.rs b/crates/assists/src/handlers/fix_visibility.rs new file mode 100644 index 000000000..7cd76ea06 --- /dev/null +++ b/crates/assists/src/handlers/fix_visibility.rs @@ -0,0 +1,607 @@ +use base_db::FileId; +use hir::{db::HirDatabase, HasSource, HasVisibility, PathResolution}; +use syntax::{ast, AstNode, TextRange, TextSize}; + +use crate::{utils::vis_offset, AssistContext, AssistId, AssistKind, Assists}; +use ast::VisibilityOwner; + +// FIXME: this really should be a fix for diagnostic, rather than an assist. + +// Assist: fix_visibility +// +// Makes inaccessible item public. +// +// ``` +// mod m { +// fn frobnicate() {} +// } +// fn main() { +// m::frobnicate<|>() {} +// } +// ``` +// -> +// ``` +// mod m { +// $0pub(crate) fn frobnicate() {} +// } +// fn main() { +// m::frobnicate() {} +// } +// ``` +pub(crate) fn fix_visibility(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + add_vis_to_referenced_module_def(acc, ctx) + .or_else(|| add_vis_to_referenced_record_field(acc, ctx)) +} + +fn add_vis_to_referenced_module_def(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let path: ast::Path = ctx.find_node_at_offset()?; + let path_res = ctx.sema.resolve_path(&path)?; + let def = match path_res { + PathResolution::Def(def) => def, + _ => return None, + }; + + let current_module = ctx.sema.scope(&path.syntax()).module()?; + let target_module = def.module(ctx.db())?; + + let vis = target_module.visibility_of(ctx.db(), &def)?; + if vis.is_visible_from(ctx.db(), current_module.into()) { + return None; + }; + + let (offset, current_visibility, target, target_file, target_name) = + target_data_for_def(ctx.db(), def)?; + + let missing_visibility = + if current_module.krate() == target_module.krate() { "pub(crate)" } else { "pub" }; + + let assist_label = match target_name { + None => format!("Change visibility to {}", missing_visibility), + Some(name) => format!("Change visibility of {} to {}", name, missing_visibility), + }; + + acc.add(AssistId("fix_visibility", AssistKind::QuickFix), assist_label, target, |builder| { + builder.edit_file(target_file); + match ctx.config.snippet_cap { + Some(cap) => match current_visibility { + Some(current_visibility) => builder.replace_snippet( + cap, + current_visibility.syntax().text_range(), + format!("$0{}", missing_visibility), + ), + None => builder.insert_snippet(cap, offset, format!("$0{} ", missing_visibility)), + }, + None => match current_visibility { + Some(current_visibility) => { + builder.replace(current_visibility.syntax().text_range(), missing_visibility) + } + None => builder.insert(offset, format!("{} ", missing_visibility)), + }, + } + }) +} + +fn add_vis_to_referenced_record_field(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let record_field: ast::RecordExprField = ctx.find_node_at_offset()?; + let (record_field_def, _) = ctx.sema.resolve_record_field(&record_field)?; + + let current_module = ctx.sema.scope(record_field.syntax()).module()?; + let visibility = record_field_def.visibility(ctx.db()); + if visibility.is_visible_from(ctx.db(), current_module.into()) { + return None; + } + + let parent = record_field_def.parent_def(ctx.db()); + let parent_name = parent.name(ctx.db()); + let target_module = parent.module(ctx.db()); + + let in_file_source = record_field_def.source(ctx.db()); + let (offset, current_visibility, target) = match in_file_source.value { + hir::FieldSource::Named(it) => { + let s = it.syntax(); + (vis_offset(s), it.visibility(), s.text_range()) + } + hir::FieldSource::Pos(it) => { + let s = it.syntax(); + (vis_offset(s), it.visibility(), s.text_range()) + } + }; + + let missing_visibility = + if current_module.krate() == target_module.krate() { "pub(crate)" } else { "pub" }; + let target_file = in_file_source.file_id.original_file(ctx.db()); + + let target_name = record_field_def.name(ctx.db()); + let assist_label = + format!("Change visibility of {}.{} to {}", parent_name, target_name, missing_visibility); + + acc.add(AssistId("fix_visibility", AssistKind::QuickFix), assist_label, target, |builder| { + builder.edit_file(target_file); + match ctx.config.snippet_cap { + Some(cap) => match current_visibility { + Some(current_visibility) => builder.replace_snippet( + cap, + current_visibility.syntax().text_range(), + format!("$0{}", missing_visibility), + ), + None => builder.insert_snippet(cap, offset, format!("$0{} ", missing_visibility)), + }, + None => match current_visibility { + Some(current_visibility) => { + builder.replace(current_visibility.syntax().text_range(), missing_visibility) + } + None => builder.insert(offset, format!("{} ", missing_visibility)), + }, + } + }) +} + +fn target_data_for_def( + db: &dyn HirDatabase, + def: hir::ModuleDef, +) -> Option<(TextSize, Option, TextRange, FileId, Option)> { + fn offset_target_and_file_id( + db: &dyn HirDatabase, + x: S, + ) -> (TextSize, Option, TextRange, FileId) + where + S: HasSource, + Ast: AstNode + ast::VisibilityOwner, + { + let source = x.source(db); + let in_file_syntax = source.syntax(); + let file_id = in_file_syntax.file_id; + let syntax = in_file_syntax.value; + let current_visibility = source.value.visibility(); + ( + vis_offset(syntax), + current_visibility, + syntax.text_range(), + file_id.original_file(db.upcast()), + ) + } + + let target_name; + let (offset, current_visibility, target, target_file) = match def { + hir::ModuleDef::Function(f) => { + target_name = Some(f.name(db)); + offset_target_and_file_id(db, f) + } + hir::ModuleDef::Adt(adt) => { + target_name = Some(adt.name(db)); + match adt { + hir::Adt::Struct(s) => offset_target_and_file_id(db, s), + hir::Adt::Union(u) => offset_target_and_file_id(db, u), + hir::Adt::Enum(e) => offset_target_and_file_id(db, e), + } + } + hir::ModuleDef::Const(c) => { + target_name = c.name(db); + offset_target_and_file_id(db, c) + } + hir::ModuleDef::Static(s) => { + target_name = s.name(db); + offset_target_and_file_id(db, s) + } + hir::ModuleDef::Trait(t) => { + target_name = Some(t.name(db)); + offset_target_and_file_id(db, t) + } + hir::ModuleDef::TypeAlias(t) => { + target_name = Some(t.name(db)); + offset_target_and_file_id(db, t) + } + hir::ModuleDef::Module(m) => { + target_name = m.name(db); + let in_file_source = m.declaration_source(db)?; + let file_id = in_file_source.file_id.original_file(db.upcast()); + let syntax = in_file_source.value.syntax(); + (vis_offset(syntax), in_file_source.value.visibility(), syntax.text_range(), file_id) + } + // Enum variants can't be private, we can't modify builtin types + hir::ModuleDef::EnumVariant(_) | hir::ModuleDef::BuiltinType(_) => return None, + }; + + Some((offset, current_visibility, target, target_file, target_name)) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn fix_visibility_of_fn() { + check_assist( + fix_visibility, + r"mod foo { fn foo() {} } + fn main() { foo::foo<|>() } ", + r"mod foo { $0pub(crate) fn foo() {} } + fn main() { foo::foo() } ", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub fn foo() {} } + fn main() { foo::foo<|>() } ", + ) + } + + #[test] + fn fix_visibility_of_adt_in_submodule() { + check_assist( + fix_visibility, + r"mod foo { struct Foo; } + fn main() { foo::Foo<|> } ", + r"mod foo { $0pub(crate) struct Foo; } + fn main() { foo::Foo } ", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub struct Foo; } + fn main() { foo::Foo<|> } ", + ); + check_assist( + fix_visibility, + r"mod foo { enum Foo; } + fn main() { foo::Foo<|> } ", + r"mod foo { $0pub(crate) enum Foo; } + fn main() { foo::Foo } ", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub enum Foo; } + fn main() { foo::Foo<|> } ", + ); + check_assist( + fix_visibility, + r"mod foo { union Foo; } + fn main() { foo::Foo<|> } ", + r"mod foo { $0pub(crate) union Foo; } + fn main() { foo::Foo } ", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub union Foo; } + fn main() { foo::Foo<|> } ", + ); + } + + #[test] + fn fix_visibility_of_adt_in_other_file() { + check_assist( + fix_visibility, + r" +//- /main.rs +mod foo; +fn main() { foo::Foo<|> } + +//- /foo.rs +struct Foo; +", + r"$0pub(crate) struct Foo; +", + ); + } + + #[test] + fn fix_visibility_of_struct_field() { + check_assist( + fix_visibility, + r"mod foo { pub struct Foo { bar: (), } } + fn main() { foo::Foo { <|>bar: () }; } ", + r"mod foo { pub struct Foo { $0pub(crate) bar: (), } } + fn main() { foo::Foo { bar: () }; } ", + ); + check_assist( + fix_visibility, + r" +//- /lib.rs +mod foo; +fn main() { foo::Foo { <|>bar: () }; } +//- /foo.rs +pub struct Foo { bar: () } +", + r"pub struct Foo { $0pub(crate) bar: () } +", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub struct Foo { pub bar: (), } } + fn main() { foo::Foo { <|>bar: () }; } ", + ); + check_assist_not_applicable( + fix_visibility, + r" +//- /lib.rs +mod foo; +fn main() { foo::Foo { <|>bar: () }; } +//- /foo.rs +pub struct Foo { pub bar: () } +", + ); + } + + #[test] + fn fix_visibility_of_enum_variant_field() { + check_assist( + fix_visibility, + r"mod foo { pub enum Foo { Bar { bar: () } } } + fn main() { foo::Foo::Bar { <|>bar: () }; } ", + r"mod foo { pub enum Foo { Bar { $0pub(crate) bar: () } } } + fn main() { foo::Foo::Bar { bar: () }; } ", + ); + check_assist( + fix_visibility, + r" +//- /lib.rs +mod foo; +fn main() { foo::Foo::Bar { <|>bar: () }; } +//- /foo.rs +pub enum Foo { Bar { bar: () } } +", + r"pub enum Foo { Bar { $0pub(crate) bar: () } } +", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub struct Foo { pub bar: (), } } + fn main() { foo::Foo { <|>bar: () }; } ", + ); + check_assist_not_applicable( + fix_visibility, + r" +//- /lib.rs +mod foo; +fn main() { foo::Foo { <|>bar: () }; } +//- /foo.rs +pub struct Foo { pub bar: () } +", + ); + } + + #[test] + #[ignore] + // FIXME reenable this test when `Semantics::resolve_record_field` works with union fields + fn fix_visibility_of_union_field() { + check_assist( + fix_visibility, + r"mod foo { pub union Foo { bar: (), } } + fn main() { foo::Foo { <|>bar: () }; } ", + r"mod foo { pub union Foo { $0pub(crate) bar: (), } } + fn main() { foo::Foo { bar: () }; } ", + ); + check_assist( + fix_visibility, + r" +//- /lib.rs +mod foo; +fn main() { foo::Foo { <|>bar: () }; } +//- /foo.rs +pub union Foo { bar: () } +", + r"pub union Foo { $0pub(crate) bar: () } +", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub union Foo { pub bar: (), } } + fn main() { foo::Foo { <|>bar: () }; } ", + ); + check_assist_not_applicable( + fix_visibility, + r" +//- /lib.rs +mod foo; +fn main() { foo::Foo { <|>bar: () }; } +//- /foo.rs +pub union Foo { pub bar: () } +", + ); + } + + #[test] + fn fix_visibility_of_const() { + check_assist( + fix_visibility, + r"mod foo { const FOO: () = (); } + fn main() { foo::FOO<|> } ", + r"mod foo { $0pub(crate) const FOO: () = (); } + fn main() { foo::FOO } ", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub const FOO: () = (); } + fn main() { foo::FOO<|> } ", + ); + } + + #[test] + fn fix_visibility_of_static() { + check_assist( + fix_visibility, + r"mod foo { static FOO: () = (); } + fn main() { foo::FOO<|> } ", + r"mod foo { $0pub(crate) static FOO: () = (); } + fn main() { foo::FOO } ", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub static FOO: () = (); } + fn main() { foo::FOO<|> } ", + ); + } + + #[test] + fn fix_visibility_of_trait() { + check_assist( + fix_visibility, + r"mod foo { trait Foo { fn foo(&self) {} } } + fn main() { let x: &dyn foo::<|>Foo; } ", + r"mod foo { $0pub(crate) trait Foo { fn foo(&self) {} } } + fn main() { let x: &dyn foo::Foo; } ", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub trait Foo { fn foo(&self) {} } } + fn main() { let x: &dyn foo::Foo<|>; } ", + ); + } + + #[test] + fn fix_visibility_of_type_alias() { + check_assist( + fix_visibility, + r"mod foo { type Foo = (); } + fn main() { let x: foo::Foo<|>; } ", + r"mod foo { $0pub(crate) type Foo = (); } + fn main() { let x: foo::Foo; } ", + ); + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub type Foo = (); } + fn main() { let x: foo::Foo<|>; } ", + ); + } + + #[test] + fn fix_visibility_of_module() { + check_assist( + fix_visibility, + r"mod foo { mod bar { fn bar() {} } } + fn main() { foo::bar<|>::bar(); } ", + r"mod foo { $0pub(crate) mod bar { fn bar() {} } } + fn main() { foo::bar::bar(); } ", + ); + + check_assist( + fix_visibility, + r" +//- /main.rs +mod foo; +fn main() { foo::bar<|>::baz(); } + +//- /foo.rs +mod bar { + pub fn baz() {} +} +", + r"$0pub(crate) mod bar { + pub fn baz() {} +} +", + ); + + check_assist_not_applicable( + fix_visibility, + r"mod foo { pub mod bar { pub fn bar() {} } } + fn main() { foo::bar<|>::bar(); } ", + ); + } + + #[test] + fn fix_visibility_of_inline_module_in_other_file() { + check_assist( + fix_visibility, + r" +//- /main.rs +mod foo; +fn main() { foo::bar<|>::baz(); } + +//- /foo.rs +mod bar; +//- /foo/bar.rs +pub fn baz() {} +", + r"$0pub(crate) mod bar; +", + ); + } + + #[test] + fn fix_visibility_of_module_declaration_in_other_file() { + check_assist( + fix_visibility, + r" +//- /main.rs +mod foo; +fn main() { foo::bar<|>>::baz(); } + +//- /foo.rs +mod bar { + pub fn baz() {} +} +", + r"$0pub(crate) mod bar { + pub fn baz() {} +} +", + ); + } + + #[test] + fn adds_pub_when_target_is_in_another_crate() { + check_assist( + fix_visibility, + r" +//- /main.rs crate:a deps:foo +foo::Bar<|> +//- /lib.rs crate:foo +struct Bar; +", + r"$0pub struct Bar; +", + ) + } + + #[test] + fn replaces_pub_crate_with_pub() { + check_assist( + fix_visibility, + r" +//- /main.rs crate:a deps:foo +foo::Bar<|> +//- /lib.rs crate:foo +pub(crate) struct Bar; +", + r"$0pub struct Bar; +", + ); + check_assist( + fix_visibility, + r" +//- /main.rs crate:a deps:foo +fn main() { + foo::Foo { <|>bar: () }; +} +//- /lib.rs crate:foo +pub struct Foo { pub(crate) bar: () } +", + r"pub struct Foo { $0pub bar: () } +", + ); + } + + #[test] + #[ignore] + // FIXME handle reexports properly + fn fix_visibility_of_reexport() { + check_assist( + fix_visibility, + r" + mod foo { + use bar::Baz; + mod bar { pub(super) struct Baz; } + } + foo::Baz<|> + ", + r" + mod foo { + $0pub(crate) use bar::Baz; + mod bar { pub(super) struct Baz; } + } + foo::Baz + ", + ) + } +} diff --git a/crates/assists/src/handlers/flip_binexpr.rs b/crates/assists/src/handlers/flip_binexpr.rs new file mode 100644 index 000000000..404f06133 --- /dev/null +++ b/crates/assists/src/handlers/flip_binexpr.rs @@ -0,0 +1,142 @@ +use syntax::ast::{AstNode, BinExpr, BinOp}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: flip_binexpr +// +// Flips operands of a binary expression. +// +// ``` +// fn main() { +// let _ = 90 +<|> 2; +// } +// ``` +// -> +// ``` +// fn main() { +// let _ = 2 + 90; +// } +// ``` +pub(crate) fn flip_binexpr(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let expr = ctx.find_node_at_offset::()?; + let lhs = expr.lhs()?.syntax().clone(); + let rhs = expr.rhs()?.syntax().clone(); + let op_range = expr.op_token()?.text_range(); + // The assist should be applied only if the cursor is on the operator + let cursor_in_range = op_range.contains_range(ctx.frange.range); + if !cursor_in_range { + return None; + } + let action: FlipAction = expr.op_kind()?.into(); + // The assist should not be applied for certain operators + if let FlipAction::DontFlip = action { + return None; + } + + acc.add( + AssistId("flip_binexpr", AssistKind::RefactorRewrite), + "Flip binary expression", + op_range, + |edit| { + if let FlipAction::FlipAndReplaceOp(new_op) = action { + edit.replace(op_range, new_op); + } + edit.replace(lhs.text_range(), rhs.text()); + edit.replace(rhs.text_range(), lhs.text()); + }, + ) +} + +enum FlipAction { + // Flip the expression + Flip, + // Flip the expression and replace the operator with this string + FlipAndReplaceOp(&'static str), + // Do not flip the expression + DontFlip, +} + +impl From for FlipAction { + fn from(op_kind: BinOp) -> Self { + match op_kind { + kind if kind.is_assignment() => FlipAction::DontFlip, + BinOp::GreaterTest => FlipAction::FlipAndReplaceOp("<"), + BinOp::GreaterEqualTest => FlipAction::FlipAndReplaceOp("<="), + BinOp::LesserTest => FlipAction::FlipAndReplaceOp(">"), + BinOp::LesserEqualTest => FlipAction::FlipAndReplaceOp(">="), + _ => FlipAction::Flip, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + #[test] + fn flip_binexpr_target_is_the_op() { + check_assist_target(flip_binexpr, "fn f() { let res = 1 ==<|> 2; }", "==") + } + + #[test] + fn flip_binexpr_not_applicable_for_assignment() { + check_assist_not_applicable(flip_binexpr, "fn f() { let mut _x = 1; _x +=<|> 2 }") + } + + #[test] + fn flip_binexpr_works_for_eq() { + check_assist( + flip_binexpr, + "fn f() { let res = 1 ==<|> 2; }", + "fn f() { let res = 2 == 1; }", + ) + } + + #[test] + fn flip_binexpr_works_for_gt() { + check_assist(flip_binexpr, "fn f() { let res = 1 ><|> 2; }", "fn f() { let res = 2 < 1; }") + } + + #[test] + fn flip_binexpr_works_for_lteq() { + check_assist( + flip_binexpr, + "fn f() { let res = 1 <=<|> 2; }", + "fn f() { let res = 2 >= 1; }", + ) + } + + #[test] + fn flip_binexpr_works_for_complex_expr() { + check_assist( + flip_binexpr, + "fn f() { let res = (1 + 1) ==<|> (2 + 2); }", + "fn f() { let res = (2 + 2) == (1 + 1); }", + ) + } + + #[test] + fn flip_binexpr_works_inside_match() { + check_assist( + flip_binexpr, + r#" + fn dyn_eq(&self, other: &dyn Diagnostic) -> bool { + match other.downcast_ref::() { + None => false, + Some(it) => it ==<|> self, + } + } + "#, + r#" + fn dyn_eq(&self, other: &dyn Diagnostic) -> bool { + match other.downcast_ref::() { + None => false, + Some(it) => self == it, + } + } + "#, + ) + } +} diff --git a/crates/assists/src/handlers/flip_comma.rs b/crates/assists/src/handlers/flip_comma.rs new file mode 100644 index 000000000..5c69db53e --- /dev/null +++ b/crates/assists/src/handlers/flip_comma.rs @@ -0,0 +1,84 @@ +use syntax::{algo::non_trivia_sibling, Direction, T}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: flip_comma +// +// Flips two comma-separated items. +// +// ``` +// fn main() { +// ((1, 2),<|> (3, 4)); +// } +// ``` +// -> +// ``` +// fn main() { +// ((3, 4), (1, 2)); +// } +// ``` +pub(crate) fn flip_comma(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let comma = ctx.find_token_at_offset(T![,])?; + let prev = non_trivia_sibling(comma.clone().into(), Direction::Prev)?; + let next = non_trivia_sibling(comma.clone().into(), Direction::Next)?; + + // Don't apply a "flip" in case of a last comma + // that typically comes before punctuation + if next.kind().is_punct() { + return None; + } + + acc.add( + AssistId("flip_comma", AssistKind::RefactorRewrite), + "Flip comma", + comma.text_range(), + |edit| { + edit.replace(prev.text_range(), next.to_string()); + edit.replace(next.text_range(), prev.to_string()); + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_target}; + + #[test] + fn flip_comma_works_for_function_parameters() { + check_assist( + flip_comma, + "fn foo(x: i32,<|> y: Result<(), ()>) {}", + "fn foo(y: Result<(), ()>, x: i32) {}", + ) + } + + #[test] + fn flip_comma_target() { + check_assist_target(flip_comma, "fn foo(x: i32,<|> y: Result<(), ()>) {}", ",") + } + + #[test] + #[should_panic] + fn flip_comma_before_punct() { + // See https://github.com/rust-analyzer/rust-analyzer/issues/1619 + // "Flip comma" assist shouldn't be applicable to the last comma in enum or struct + // declaration body. + check_assist_target( + flip_comma, + "pub enum Test { \ + A,<|> \ + }", + ",", + ); + + check_assist_target( + flip_comma, + "pub struct Test { \ + foo: usize,<|> \ + }", + ",", + ); + } +} diff --git a/crates/assists/src/handlers/flip_trait_bound.rs b/crates/assists/src/handlers/flip_trait_bound.rs new file mode 100644 index 000000000..347e79b1d --- /dev/null +++ b/crates/assists/src/handlers/flip_trait_bound.rs @@ -0,0 +1,121 @@ +use syntax::{ + algo::non_trivia_sibling, + ast::{self, AstNode}, + Direction, T, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: flip_trait_bound +// +// Flips two trait bounds. +// +// ``` +// fn foo Copy>() { } +// ``` +// -> +// ``` +// fn foo() { } +// ``` +pub(crate) fn flip_trait_bound(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + // We want to replicate the behavior of `flip_binexpr` by only suggesting + // the assist when the cursor is on a `+` + let plus = ctx.find_token_at_offset(T![+])?; + + // Make sure we're in a `TypeBoundList` + if ast::TypeBoundList::cast(plus.parent()).is_none() { + return None; + } + + let (before, after) = ( + non_trivia_sibling(plus.clone().into(), Direction::Prev)?, + non_trivia_sibling(plus.clone().into(), Direction::Next)?, + ); + + let target = plus.text_range(); + acc.add( + AssistId("flip_trait_bound", AssistKind::RefactorRewrite), + "Flip trait bounds", + target, + |edit| { + edit.replace(before.text_range(), after.to_string()); + edit.replace(after.text_range(), before.to_string()); + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + #[test] + fn flip_trait_bound_assist_available() { + check_assist_target(flip_trait_bound, "struct S where T: A <|>+ B + C { }", "+") + } + + #[test] + fn flip_trait_bound_not_applicable_for_single_trait_bound() { + check_assist_not_applicable(flip_trait_bound, "struct S where T: <|>A { }") + } + + #[test] + fn flip_trait_bound_works_for_struct() { + check_assist( + flip_trait_bound, + "struct S where T: A <|>+ B { }", + "struct S where T: B + A { }", + ) + } + + #[test] + fn flip_trait_bound_works_for_trait_impl() { + check_assist( + flip_trait_bound, + "impl X for S where T: A +<|> B { }", + "impl X for S where T: B + A { }", + ) + } + + #[test] + fn flip_trait_bound_works_for_fn() { + check_assist(flip_trait_bound, "fn f+ B>(t: T) { }", "fn f(t: T) { }") + } + + #[test] + fn flip_trait_bound_works_for_fn_where_clause() { + check_assist( + flip_trait_bound, + "fn f(t: T) where T: A +<|> B { }", + "fn f(t: T) where T: B + A { }", + ) + } + + #[test] + fn flip_trait_bound_works_for_lifetime() { + check_assist( + flip_trait_bound, + "fn f(t: T) where T: A <|>+ 'static { }", + "fn f(t: T) where T: 'static + A { }", + ) + } + + #[test] + fn flip_trait_bound_works_for_complex_bounds() { + check_assist( + flip_trait_bound, + "struct S where T: A <|>+ b_mod::B + C { }", + "struct S where T: b_mod::B + A + C { }", + ) + } + + #[test] + fn flip_trait_bound_works_for_long_bounds() { + check_assist( + flip_trait_bound, + "struct S where T: A + B + C + D + E + F +<|> G + H + I + J { }", + "struct S where T: A + B + C + D + E + G + F + H + I + J { }", + ) + } +} diff --git a/crates/assists/src/handlers/generate_derive.rs b/crates/assists/src/handlers/generate_derive.rs new file mode 100644 index 000000000..314504e15 --- /dev/null +++ b/crates/assists/src/handlers/generate_derive.rs @@ -0,0 +1,132 @@ +use syntax::{ + ast::{self, AstNode, AttrsOwner}, + SyntaxKind::{COMMENT, WHITESPACE}, + TextSize, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: generate_derive +// +// Adds a new `#[derive()]` clause to a struct or enum. +// +// ``` +// struct Point { +// x: u32, +// y: u32,<|> +// } +// ``` +// -> +// ``` +// #[derive($0)] +// struct Point { +// x: u32, +// y: u32, +// } +// ``` +pub(crate) fn generate_derive(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let cap = ctx.config.snippet_cap?; + let nominal = ctx.find_node_at_offset::()?; + let node_start = derive_insertion_offset(&nominal)?; + let target = nominal.syntax().text_range(); + acc.add( + AssistId("generate_derive", AssistKind::Generate), + "Add `#[derive]`", + target, + |builder| { + let derive_attr = nominal + .attrs() + .filter_map(|x| x.as_simple_call()) + .filter(|(name, _arg)| name == "derive") + .map(|(_name, arg)| arg) + .next(); + match derive_attr { + None => { + builder.insert_snippet(cap, node_start, "#[derive($0)]\n"); + } + Some(tt) => { + // Just move the cursor. + builder.insert_snippet( + cap, + tt.syntax().text_range().end() - TextSize::of(')'), + "$0", + ) + } + }; + }, + ) +} + +// Insert `derive` after doc comments. +fn derive_insertion_offset(nominal: &ast::AdtDef) -> Option { + let non_ws_child = nominal + .syntax() + .children_with_tokens() + .find(|it| it.kind() != COMMENT && it.kind() != WHITESPACE)?; + Some(non_ws_child.text_range().start()) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_target}; + + use super::*; + + #[test] + fn add_derive_new() { + check_assist( + generate_derive, + "struct Foo { a: i32, <|>}", + "#[derive($0)]\nstruct Foo { a: i32, }", + ); + check_assist( + generate_derive, + "struct Foo { <|> a: i32, }", + "#[derive($0)]\nstruct Foo { a: i32, }", + ); + } + + #[test] + fn add_derive_existing() { + check_assist( + generate_derive, + "#[derive(Clone)]\nstruct Foo { a: i32<|>, }", + "#[derive(Clone$0)]\nstruct Foo { a: i32, }", + ); + } + + #[test] + fn add_derive_new_with_doc_comment() { + check_assist( + generate_derive, + " +/// `Foo` is a pretty important struct. +/// It does stuff. +struct Foo { a: i32<|>, } + ", + " +/// `Foo` is a pretty important struct. +/// It does stuff. +#[derive($0)] +struct Foo { a: i32, } + ", + ); + } + + #[test] + fn add_derive_target() { + check_assist_target( + generate_derive, + " +struct SomeThingIrrelevant; +/// `Foo` is a pretty important struct. +/// It does stuff. +struct Foo { a: i32<|>, } +struct EvenMoreIrrelevant; + ", + "/// `Foo` is a pretty important struct. +/// It does stuff. +struct Foo { a: i32, }", + ); + } +} diff --git a/crates/assists/src/handlers/generate_from_impl_for_enum.rs b/crates/assists/src/handlers/generate_from_impl_for_enum.rs new file mode 100644 index 000000000..7f04b9572 --- /dev/null +++ b/crates/assists/src/handlers/generate_from_impl_for_enum.rs @@ -0,0 +1,200 @@ +use ide_db::RootDatabase; +use syntax::ast::{self, AstNode, NameOwner}; +use test_utils::mark; + +use crate::{utils::FamousDefs, AssistContext, AssistId, AssistKind, Assists}; + +// Assist: generate_from_impl_for_enum +// +// Adds a From impl for an enum variant with one tuple field. +// +// ``` +// enum A { <|>One(u32) } +// ``` +// -> +// ``` +// enum A { One(u32) } +// +// impl From for A { +// fn from(v: u32) -> Self { +// A::One(v) +// } +// } +// ``` +pub(crate) fn generate_from_impl_for_enum(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let variant = ctx.find_node_at_offset::()?; + let variant_name = variant.name()?; + let enum_name = variant.parent_enum().name()?; + let field_list = match variant.kind() { + ast::StructKind::Tuple(field_list) => field_list, + _ => return None, + }; + if field_list.fields().count() != 1 { + return None; + } + let field_type = field_list.fields().next()?.ty()?; + let path = match field_type { + ast::Type::PathType(it) => it, + _ => return None, + }; + + if existing_from_impl(&ctx.sema, &variant).is_some() { + mark::hit!(test_add_from_impl_already_exists); + return None; + } + + let target = variant.syntax().text_range(); + acc.add( + AssistId("generate_from_impl_for_enum", AssistKind::Generate), + "Generate `From` impl for this enum variant", + target, + |edit| { + let start_offset = variant.parent_enum().syntax().text_range().end(); + let buf = format!( + r#" + +impl From<{0}> for {1} {{ + fn from(v: {0}) -> Self {{ + {1}::{2}(v) + }} +}}"#, + path.syntax(), + enum_name, + variant_name + ); + edit.insert(start_offset, buf); + }, + ) +} + +fn existing_from_impl( + sema: &'_ hir::Semantics<'_, RootDatabase>, + variant: &ast::Variant, +) -> Option<()> { + let variant = sema.to_def(variant)?; + let enum_ = variant.parent_enum(sema.db); + let krate = enum_.module(sema.db).krate(); + + let from_trait = FamousDefs(sema, krate).core_convert_From()?; + + let enum_type = enum_.ty(sema.db); + + let wrapped_type = variant.fields(sema.db).get(0)?.signature_ty(sema.db); + + if enum_type.impls_trait(sema.db, from_trait, &[wrapped_type]) { + Some(()) + } else { + None + } +} + +#[cfg(test)] +mod tests { + use test_utils::mark; + + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_generate_from_impl_for_enum() { + check_assist( + generate_from_impl_for_enum, + "enum A { <|>One(u32) }", + r#"enum A { One(u32) } + +impl From for A { + fn from(v: u32) -> Self { + A::One(v) + } +}"#, + ); + } + + #[test] + fn test_generate_from_impl_for_enum_complicated_path() { + check_assist( + generate_from_impl_for_enum, + r#"enum A { <|>One(foo::bar::baz::Boo) }"#, + r#"enum A { One(foo::bar::baz::Boo) } + +impl From for A { + fn from(v: foo::bar::baz::Boo) -> Self { + A::One(v) + } +}"#, + ); + } + + fn check_not_applicable(ra_fixture: &str) { + let fixture = + format!("//- /main.rs crate:main deps:core\n{}\n{}", ra_fixture, FamousDefs::FIXTURE); + check_assist_not_applicable(generate_from_impl_for_enum, &fixture) + } + + #[test] + fn test_add_from_impl_no_element() { + check_not_applicable("enum A { <|>One }"); + } + + #[test] + fn test_add_from_impl_more_than_one_element_in_tuple() { + check_not_applicable("enum A { <|>One(u32, String) }"); + } + + #[test] + fn test_add_from_impl_struct_variant() { + check_not_applicable("enum A { <|>One { x: u32 } }"); + } + + #[test] + fn test_add_from_impl_already_exists() { + mark::check!(test_add_from_impl_already_exists); + check_not_applicable( + r#" +enum A { <|>One(u32), } + +impl From for A { + fn from(v: u32) -> Self { + A::One(v) + } +} +"#, + ); + } + + #[test] + fn test_add_from_impl_different_variant_impl_exists() { + check_assist( + generate_from_impl_for_enum, + r#"enum A { <|>One(u32), Two(String), } + +impl From for A { + fn from(v: String) -> Self { + A::Two(v) + } +} + +pub trait From { + fn from(T) -> Self; +}"#, + r#"enum A { One(u32), Two(String), } + +impl From for A { + fn from(v: u32) -> Self { + A::One(v) + } +} + +impl From for A { + fn from(v: String) -> Self { + A::Two(v) + } +} + +pub trait From { + fn from(T) -> Self; +}"#, + ); + } +} diff --git a/crates/assists/src/handlers/generate_function.rs b/crates/assists/src/handlers/generate_function.rs new file mode 100644 index 000000000..b38d64058 --- /dev/null +++ b/crates/assists/src/handlers/generate_function.rs @@ -0,0 +1,1058 @@ +use base_db::FileId; +use hir::HirDisplay; +use rustc_hash::{FxHashMap, FxHashSet}; +use syntax::{ + ast::{ + self, + edit::{AstNodeEdit, IndentLevel}, + make, ArgListOwner, AstNode, ModuleItemOwner, + }, + SyntaxKind, SyntaxNode, TextSize, +}; + +use crate::{ + assist_config::SnippetCap, + utils::{render_snippet, Cursor}, + AssistContext, AssistId, AssistKind, Assists, +}; + +// Assist: generate_function +// +// Adds a stub function with a signature matching the function under the cursor. +// +// ``` +// struct Baz; +// fn baz() -> Baz { Baz } +// fn foo() { +// bar<|>("", baz()); +// } +// +// ``` +// -> +// ``` +// struct Baz; +// fn baz() -> Baz { Baz } +// fn foo() { +// bar("", baz()); +// } +// +// fn bar(arg: &str, baz: Baz) { +// ${0:todo!()} +// } +// +// ``` +pub(crate) fn generate_function(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let path_expr: ast::PathExpr = ctx.find_node_at_offset()?; + let call = path_expr.syntax().parent().and_then(ast::CallExpr::cast)?; + let path = path_expr.path()?; + + if ctx.sema.resolve_path(&path).is_some() { + // The function call already resolves, no need to add a function + return None; + } + + let target_module = match path.qualifier() { + Some(qualifier) => match ctx.sema.resolve_path(&qualifier) { + Some(hir::PathResolution::Def(hir::ModuleDef::Module(module))) => Some(module), + _ => return None, + }, + None => None, + }; + + let function_builder = FunctionBuilder::from_call(&ctx, &call, &path, target_module)?; + + let target = call.syntax().text_range(); + acc.add( + AssistId("generate_function", AssistKind::Generate), + format!("Generate `{}` function", function_builder.fn_name), + target, + |builder| { + let function_template = function_builder.render(); + builder.edit_file(function_template.file); + let new_fn = function_template.to_string(ctx.config.snippet_cap); + match ctx.config.snippet_cap { + Some(cap) => builder.insert_snippet(cap, function_template.insert_offset, new_fn), + None => builder.insert(function_template.insert_offset, new_fn), + } + }, + ) +} + +struct FunctionTemplate { + insert_offset: TextSize, + placeholder_expr: ast::MacroCall, + leading_ws: String, + fn_def: ast::Fn, + trailing_ws: String, + file: FileId, +} + +impl FunctionTemplate { + fn to_string(&self, cap: Option) -> String { + let f = match cap { + Some(cap) => render_snippet( + cap, + self.fn_def.syntax(), + Cursor::Replace(self.placeholder_expr.syntax()), + ), + None => self.fn_def.to_string(), + }; + format!("{}{}{}", self.leading_ws, f, self.trailing_ws) + } +} + +struct FunctionBuilder { + target: GeneratedFunctionTarget, + fn_name: ast::Name, + type_params: Option, + params: ast::ParamList, + file: FileId, + needs_pub: bool, +} + +impl FunctionBuilder { + /// Prepares a generated function that matches `call`. + /// The function is generated in `target_module` or next to `call` + fn from_call( + ctx: &AssistContext, + call: &ast::CallExpr, + path: &ast::Path, + target_module: Option, + ) -> Option { + let mut file = ctx.frange.file_id; + let target = match &target_module { + Some(target_module) => { + let module_source = target_module.definition_source(ctx.db()); + let (in_file, target) = next_space_for_fn_in_module(ctx.sema.db, &module_source)?; + file = in_file; + target + } + None => next_space_for_fn_after_call_site(&call)?, + }; + let needs_pub = target_module.is_some(); + let target_module = target_module.or_else(|| ctx.sema.scope(target.syntax()).module())?; + let fn_name = fn_name(&path)?; + let (type_params, params) = fn_args(ctx, target_module, &call)?; + + Some(Self { target, fn_name, type_params, params, file, needs_pub }) + } + + fn render(self) -> FunctionTemplate { + let placeholder_expr = make::expr_todo(); + let fn_body = make::block_expr(vec![], Some(placeholder_expr)); + let visibility = if self.needs_pub { Some(make::visibility_pub_crate()) } else { None }; + let mut fn_def = + make::fn_(visibility, self.fn_name, self.type_params, self.params, fn_body); + let leading_ws; + let trailing_ws; + + let insert_offset = match self.target { + GeneratedFunctionTarget::BehindItem(it) => { + let indent = IndentLevel::from_node(&it); + leading_ws = format!("\n\n{}", indent); + fn_def = fn_def.indent(indent); + trailing_ws = String::new(); + it.text_range().end() + } + GeneratedFunctionTarget::InEmptyItemList(it) => { + let indent = IndentLevel::from_node(it.syntax()); + leading_ws = format!("\n{}", indent + 1); + fn_def = fn_def.indent(indent + 1); + trailing_ws = format!("\n{}", indent); + it.syntax().text_range().start() + TextSize::of('{') + } + }; + + let placeholder_expr = + fn_def.syntax().descendants().find_map(ast::MacroCall::cast).unwrap(); + FunctionTemplate { + insert_offset, + placeholder_expr, + leading_ws, + fn_def, + trailing_ws, + file: self.file, + } + } +} + +enum GeneratedFunctionTarget { + BehindItem(SyntaxNode), + InEmptyItemList(ast::ItemList), +} + +impl GeneratedFunctionTarget { + fn syntax(&self) -> &SyntaxNode { + match self { + GeneratedFunctionTarget::BehindItem(it) => it, + GeneratedFunctionTarget::InEmptyItemList(it) => it.syntax(), + } + } +} + +fn fn_name(call: &ast::Path) -> Option { + let name = call.segment()?.syntax().to_string(); + Some(make::name(&name)) +} + +/// Computes the type variables and arguments required for the generated function +fn fn_args( + ctx: &AssistContext, + target_module: hir::Module, + call: &ast::CallExpr, +) -> Option<(Option, ast::ParamList)> { + let mut arg_names = Vec::new(); + let mut arg_types = Vec::new(); + for arg in call.arg_list()?.args() { + arg_names.push(match fn_arg_name(&arg) { + Some(name) => name, + None => String::from("arg"), + }); + arg_types.push(match fn_arg_type(ctx, target_module, &arg) { + Some(ty) => ty, + None => String::from("()"), + }); + } + deduplicate_arg_names(&mut arg_names); + let params = arg_names.into_iter().zip(arg_types).map(|(name, ty)| make::param(name, ty)); + Some((None, make::param_list(params))) +} + +/// Makes duplicate argument names unique by appending incrementing numbers. +/// +/// ``` +/// let mut names: Vec = +/// vec!["foo".into(), "foo".into(), "bar".into(), "baz".into(), "bar".into()]; +/// deduplicate_arg_names(&mut names); +/// let expected: Vec = +/// vec!["foo_1".into(), "foo_2".into(), "bar_1".into(), "baz".into(), "bar_2".into()]; +/// assert_eq!(names, expected); +/// ``` +fn deduplicate_arg_names(arg_names: &mut Vec) { + let arg_name_counts = arg_names.iter().fold(FxHashMap::default(), |mut m, name| { + *m.entry(name).or_insert(0) += 1; + m + }); + let duplicate_arg_names: FxHashSet = arg_name_counts + .into_iter() + .filter(|(_, count)| *count >= 2) + .map(|(name, _)| name.clone()) + .collect(); + + let mut counter_per_name = FxHashMap::default(); + for arg_name in arg_names.iter_mut() { + if duplicate_arg_names.contains(arg_name) { + let counter = counter_per_name.entry(arg_name.clone()).or_insert(1); + arg_name.push('_'); + arg_name.push_str(&counter.to_string()); + *counter += 1; + } + } +} + +fn fn_arg_name(fn_arg: &ast::Expr) -> Option { + match fn_arg { + ast::Expr::CastExpr(cast_expr) => fn_arg_name(&cast_expr.expr()?), + _ => Some( + fn_arg + .syntax() + .descendants() + .filter(|d| ast::NameRef::can_cast(d.kind())) + .last()? + .to_string(), + ), + } +} + +fn fn_arg_type( + ctx: &AssistContext, + target_module: hir::Module, + fn_arg: &ast::Expr, +) -> Option { + let ty = ctx.sema.type_of_expr(fn_arg)?; + if ty.is_unknown() { + return None; + } + + if let Ok(rendered) = ty.display_source_code(ctx.db(), target_module.into()) { + Some(rendered) + } else { + None + } +} + +/// Returns the position inside the current mod or file +/// directly after the current block +/// We want to write the generated function directly after +/// fns, impls or macro calls, but inside mods +fn next_space_for_fn_after_call_site(expr: &ast::CallExpr) -> Option { + let mut ancestors = expr.syntax().ancestors().peekable(); + let mut last_ancestor: Option = None; + while let Some(next_ancestor) = ancestors.next() { + match next_ancestor.kind() { + SyntaxKind::SOURCE_FILE => { + break; + } + SyntaxKind::ITEM_LIST => { + if ancestors.peek().map(|a| a.kind()) == Some(SyntaxKind::MODULE) { + break; + } + } + _ => {} + } + last_ancestor = Some(next_ancestor); + } + last_ancestor.map(GeneratedFunctionTarget::BehindItem) +} + +fn next_space_for_fn_in_module( + db: &dyn hir::db::AstDatabase, + module_source: &hir::InFile, +) -> Option<(FileId, GeneratedFunctionTarget)> { + let file = module_source.file_id.original_file(db); + let assist_item = match &module_source.value { + hir::ModuleSource::SourceFile(it) => { + if let Some(last_item) = it.items().last() { + GeneratedFunctionTarget::BehindItem(last_item.syntax().clone()) + } else { + GeneratedFunctionTarget::BehindItem(it.syntax().clone()) + } + } + hir::ModuleSource::Module(it) => { + if let Some(last_item) = it.item_list().and_then(|it| it.items().last()) { + GeneratedFunctionTarget::BehindItem(last_item.syntax().clone()) + } else { + GeneratedFunctionTarget::InEmptyItemList(it.item_list()?) + } + } + }; + Some((file, assist_item)) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn add_function_with_no_args() { + check_assist( + generate_function, + r" +fn foo() { + bar<|>(); +} +", + r" +fn foo() { + bar(); +} + +fn bar() { + ${0:todo!()} +} +", + ) + } + + #[test] + fn add_function_from_method() { + // This ensures that the function is correctly generated + // in the next outer mod or file + check_assist( + generate_function, + r" +impl Foo { + fn foo() { + bar<|>(); + } +} +", + r" +impl Foo { + fn foo() { + bar(); + } +} + +fn bar() { + ${0:todo!()} +} +", + ) + } + + #[test] + fn add_function_directly_after_current_block() { + // The new fn should not be created at the end of the file or module + check_assist( + generate_function, + r" +fn foo1() { + bar<|>(); +} + +fn foo2() {} +", + r" +fn foo1() { + bar(); +} + +fn bar() { + ${0:todo!()} +} + +fn foo2() {} +", + ) + } + + #[test] + fn add_function_with_no_args_in_same_module() { + check_assist( + generate_function, + r" +mod baz { + fn foo() { + bar<|>(); + } +} +", + r" +mod baz { + fn foo() { + bar(); + } + + fn bar() { + ${0:todo!()} + } +} +", + ) + } + + #[test] + fn add_function_with_function_call_arg() { + check_assist( + generate_function, + r" +struct Baz; +fn baz() -> Baz { todo!() } +fn foo() { + bar<|>(baz()); +} +", + r" +struct Baz; +fn baz() -> Baz { todo!() } +fn foo() { + bar(baz()); +} + +fn bar(baz: Baz) { + ${0:todo!()} +} +", + ); + } + + #[test] + fn add_function_with_method_call_arg() { + check_assist( + generate_function, + r" +struct Baz; +impl Baz { + fn foo(&self) -> Baz { + ba<|>r(self.baz()) + } + fn baz(&self) -> Baz { + Baz + } +} +", + r" +struct Baz; +impl Baz { + fn foo(&self) -> Baz { + bar(self.baz()) + } + fn baz(&self) -> Baz { + Baz + } +} + +fn bar(baz: Baz) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn add_function_with_string_literal_arg() { + check_assist( + generate_function, + r#" +fn foo() { + <|>bar("bar") +} +"#, + r#" +fn foo() { + bar("bar") +} + +fn bar(arg: &str) { + ${0:todo!()} +} +"#, + ) + } + + #[test] + fn add_function_with_char_literal_arg() { + check_assist( + generate_function, + r#" +fn foo() { + <|>bar('x') +} +"#, + r#" +fn foo() { + bar('x') +} + +fn bar(arg: char) { + ${0:todo!()} +} +"#, + ) + } + + #[test] + fn add_function_with_int_literal_arg() { + check_assist( + generate_function, + r" +fn foo() { + <|>bar(42) +} +", + r" +fn foo() { + bar(42) +} + +fn bar(arg: i32) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn add_function_with_cast_int_literal_arg() { + check_assist( + generate_function, + r" +fn foo() { + <|>bar(42 as u8) +} +", + r" +fn foo() { + bar(42 as u8) +} + +fn bar(arg: u8) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn name_of_cast_variable_is_used() { + // Ensures that the name of the cast type isn't used + // in the generated function signature. + check_assist( + generate_function, + r" +fn foo() { + let x = 42; + bar<|>(x as u8) +} +", + r" +fn foo() { + let x = 42; + bar(x as u8) +} + +fn bar(x: u8) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn add_function_with_variable_arg() { + check_assist( + generate_function, + r" +fn foo() { + let worble = (); + <|>bar(worble) +} +", + r" +fn foo() { + let worble = (); + bar(worble) +} + +fn bar(worble: ()) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn add_function_with_impl_trait_arg() { + check_assist( + generate_function, + r" +trait Foo {} +fn foo() -> impl Foo { + todo!() +} +fn baz() { + <|>bar(foo()) +} +", + r" +trait Foo {} +fn foo() -> impl Foo { + todo!() +} +fn baz() { + bar(foo()) +} + +fn bar(foo: impl Foo) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn borrowed_arg() { + check_assist( + generate_function, + r" +struct Baz; +fn baz() -> Baz { todo!() } + +fn foo() { + bar<|>(&baz()) +} +", + r" +struct Baz; +fn baz() -> Baz { todo!() } + +fn foo() { + bar(&baz()) +} + +fn bar(baz: &Baz) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn add_function_with_qualified_path_arg() { + check_assist( + generate_function, + r" +mod Baz { + pub struct Bof; + pub fn baz() -> Bof { Bof } +} +fn foo() { + <|>bar(Baz::baz()) +} +", + r" +mod Baz { + pub struct Bof; + pub fn baz() -> Bof { Bof } +} +fn foo() { + bar(Baz::baz()) +} + +fn bar(baz: Baz::Bof) { + ${0:todo!()} +} +", + ) + } + + #[test] + #[ignore] + // FIXME fix printing the generics of a `Ty` to make this test pass + fn add_function_with_generic_arg() { + check_assist( + generate_function, + r" +fn foo(t: T) { + <|>bar(t) +} +", + r" +fn foo(t: T) { + bar(t) +} + +fn bar(t: T) { + ${0:todo!()} +} +", + ) + } + + #[test] + #[ignore] + // FIXME Fix function type printing to make this test pass + fn add_function_with_fn_arg() { + check_assist( + generate_function, + r" +struct Baz; +impl Baz { + fn new() -> Self { Baz } +} +fn foo() { + <|>bar(Baz::new); +} +", + r" +struct Baz; +impl Baz { + fn new() -> Self { Baz } +} +fn foo() { + bar(Baz::new); +} + +fn bar(arg: fn() -> Baz) { + ${0:todo!()} +} +", + ) + } + + #[test] + #[ignore] + // FIXME Fix closure type printing to make this test pass + fn add_function_with_closure_arg() { + check_assist( + generate_function, + r" +fn foo() { + let closure = |x: i64| x - 1; + <|>bar(closure) +} +", + r" +fn foo() { + let closure = |x: i64| x - 1; + bar(closure) +} + +fn bar(closure: impl Fn(i64) -> i64) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn unresolveable_types_default_to_unit() { + check_assist( + generate_function, + r" +fn foo() { + <|>bar(baz) +} +", + r" +fn foo() { + bar(baz) +} + +fn bar(baz: ()) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn arg_names_dont_overlap() { + check_assist( + generate_function, + r" +struct Baz; +fn baz() -> Baz { Baz } +fn foo() { + <|>bar(baz(), baz()) +} +", + r" +struct Baz; +fn baz() -> Baz { Baz } +fn foo() { + bar(baz(), baz()) +} + +fn bar(baz_1: Baz, baz_2: Baz) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn arg_name_counters_start_at_1_per_name() { + check_assist( + generate_function, + r#" +struct Baz; +fn baz() -> Baz { Baz } +fn foo() { + <|>bar(baz(), baz(), "foo", "bar") +} +"#, + r#" +struct Baz; +fn baz() -> Baz { Baz } +fn foo() { + bar(baz(), baz(), "foo", "bar") +} + +fn bar(baz_1: Baz, baz_2: Baz, arg_1: &str, arg_2: &str) { + ${0:todo!()} +} +"#, + ) + } + + #[test] + fn add_function_in_module() { + check_assist( + generate_function, + r" +mod bar {} + +fn foo() { + bar::my_fn<|>() +} +", + r" +mod bar { + pub(crate) fn my_fn() { + ${0:todo!()} + } +} + +fn foo() { + bar::my_fn() +} +", + ) + } + + #[test] + #[ignore] + // Ignored until local imports are supported. + // See https://github.com/rust-analyzer/rust-analyzer/issues/1165 + fn qualified_path_uses_correct_scope() { + check_assist( + generate_function, + " +mod foo { + pub struct Foo; +} +fn bar() { + use foo::Foo; + let foo = Foo; + baz<|>(foo) +} +", + " +mod foo { + pub struct Foo; +} +fn bar() { + use foo::Foo; + let foo = Foo; + baz(foo) +} + +fn baz(foo: foo::Foo) { + ${0:todo!()} +} +", + ) + } + + #[test] + fn add_function_in_module_containing_other_items() { + check_assist( + generate_function, + r" +mod bar { + fn something_else() {} +} + +fn foo() { + bar::my_fn<|>() +} +", + r" +mod bar { + fn something_else() {} + + pub(crate) fn my_fn() { + ${0:todo!()} + } +} + +fn foo() { + bar::my_fn() +} +", + ) + } + + #[test] + fn add_function_in_nested_module() { + check_assist( + generate_function, + r" +mod bar { + mod baz {} +} + +fn foo() { + bar::baz::my_fn<|>() +} +", + r" +mod bar { + mod baz { + pub(crate) fn my_fn() { + ${0:todo!()} + } + } +} + +fn foo() { + bar::baz::my_fn() +} +", + ) + } + + #[test] + fn add_function_in_another_file() { + check_assist( + generate_function, + r" +//- /main.rs +mod foo; + +fn main() { + foo::bar<|>() +} +//- /foo.rs +", + r" + + +pub(crate) fn bar() { + ${0:todo!()} +}", + ) + } + + #[test] + fn add_function_not_applicable_if_function_already_exists() { + check_assist_not_applicable( + generate_function, + r" +fn foo() { + bar<|>(); +} + +fn bar() {} +", + ) + } + + #[test] + fn add_function_not_applicable_if_unresolved_variable_in_call_is_selected() { + check_assist_not_applicable( + // bar is resolved, but baz isn't. + // The assist is only active if the cursor is on an unresolved path, + // but the assist should only be offered if the path is a function call. + generate_function, + r" +fn foo() { + bar(b<|>az); +} + +fn bar(baz: ()) {} +", + ) + } + + #[test] + #[ignore] + fn create_method_with_no_args() { + check_assist( + generate_function, + r" +struct Foo; +impl Foo { + fn foo(&self) { + self.bar()<|>; + } +} + ", + r" +struct Foo; +impl Foo { + fn foo(&self) { + self.bar(); + } + fn bar(&self) { + todo!(); + } +} + ", + ) + } +} diff --git a/crates/assists/src/handlers/generate_impl.rs b/crates/assists/src/handlers/generate_impl.rs new file mode 100644 index 000000000..9989109b5 --- /dev/null +++ b/crates/assists/src/handlers/generate_impl.rs @@ -0,0 +1,110 @@ +use itertools::Itertools; +use stdx::format_to; +use syntax::ast::{self, AstNode, GenericParamsOwner, NameOwner}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: generate_impl +// +// Adds a new inherent impl for a type. +// +// ``` +// struct Ctx { +// data: T,<|> +// } +// ``` +// -> +// ``` +// struct Ctx { +// data: T, +// } +// +// impl Ctx { +// $0 +// } +// ``` +pub(crate) fn generate_impl(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let nominal = ctx.find_node_at_offset::()?; + let name = nominal.name()?; + let target = nominal.syntax().text_range(); + acc.add( + AssistId("generate_impl", AssistKind::Generate), + format!("Generate impl for `{}`", name), + target, + |edit| { + let type_params = nominal.generic_param_list(); + let start_offset = nominal.syntax().text_range().end(); + let mut buf = String::new(); + buf.push_str("\n\nimpl"); + if let Some(type_params) = &type_params { + format_to!(buf, "{}", type_params.syntax()); + } + buf.push_str(" "); + buf.push_str(name.text().as_str()); + if let Some(type_params) = type_params { + let lifetime_params = type_params + .lifetime_params() + .filter_map(|it| it.lifetime_token()) + .map(|it| it.text().clone()); + let type_params = type_params + .type_params() + .filter_map(|it| it.name()) + .map(|it| it.text().clone()); + + let generic_params = lifetime_params.chain(type_params).format(", "); + format_to!(buf, "<{}>", generic_params) + } + match ctx.config.snippet_cap { + Some(cap) => { + buf.push_str(" {\n $0\n}"); + edit.insert_snippet(cap, start_offset, buf); + } + None => { + buf.push_str(" {\n}"); + edit.insert(start_offset, buf); + } + } + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_target}; + + use super::*; + + #[test] + fn test_add_impl() { + check_assist( + generate_impl, + "struct Foo {<|>}\n", + "struct Foo {}\n\nimpl Foo {\n $0\n}\n", + ); + check_assist( + generate_impl, + "struct Foo {<|>}", + "struct Foo {}\n\nimpl Foo {\n $0\n}", + ); + check_assist( + generate_impl, + "struct Foo<'a, T: Foo<'a>> {<|>}", + "struct Foo<'a, T: Foo<'a>> {}\n\nimpl<'a, T: Foo<'a>> Foo<'a, T> {\n $0\n}", + ); + } + + #[test] + fn add_impl_target() { + check_assist_target( + generate_impl, + " +struct SomeThingIrrelevant; +/// Has a lifetime parameter +struct Foo<'a, T: Foo<'a>> {<|>} +struct EvenMoreIrrelevant; +", + "/// Has a lifetime parameter +struct Foo<'a, T: Foo<'a>> {}", + ); + } +} diff --git a/crates/assists/src/handlers/generate_new.rs b/crates/assists/src/handlers/generate_new.rs new file mode 100644 index 000000000..7db10f276 --- /dev/null +++ b/crates/assists/src/handlers/generate_new.rs @@ -0,0 +1,421 @@ +use hir::Adt; +use itertools::Itertools; +use stdx::format_to; +use syntax::{ + ast::{self, AstNode, GenericParamsOwner, NameOwner, StructKind, VisibilityOwner}, + T, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: generate_new +// +// Adds a new inherent impl for a type. +// +// ``` +// struct Ctx { +// data: T,<|> +// } +// ``` +// -> +// ``` +// struct Ctx { +// data: T, +// } +// +// impl Ctx { +// fn $0new(data: T) -> Self { Self { data } } +// } +// +// ``` +pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let strukt = ctx.find_node_at_offset::()?; + + // We want to only apply this to non-union structs with named fields + let field_list = match strukt.kind() { + StructKind::Record(named) => named, + _ => return None, + }; + + // Return early if we've found an existing new fn + let impl_def = find_struct_impl(&ctx, &strukt)?; + + let target = strukt.syntax().text_range(); + acc.add(AssistId("generate_new", AssistKind::Generate), "Generate `new`", target, |builder| { + let mut buf = String::with_capacity(512); + + if impl_def.is_some() { + buf.push('\n'); + } + + let vis = strukt.visibility().map_or(String::new(), |v| format!("{} ", v)); + + let params = field_list + .fields() + .filter_map(|f| Some(format!("{}: {}", f.name()?.syntax(), f.ty()?.syntax()))) + .format(", "); + let fields = field_list.fields().filter_map(|f| f.name()).format(", "); + + format_to!(buf, " {}fn new({}) -> Self {{ Self {{ {} }} }}", vis, params, fields); + + let start_offset = impl_def + .and_then(|impl_def| { + buf.push('\n'); + let start = impl_def + .syntax() + .descendants_with_tokens() + .find(|t| t.kind() == T!['{'])? + .text_range() + .end(); + + Some(start) + }) + .unwrap_or_else(|| { + buf = generate_impl_text(&strukt, &buf); + strukt.syntax().text_range().end() + }); + + match ctx.config.snippet_cap { + None => builder.insert(start_offset, buf), + Some(cap) => { + buf = buf.replace("fn new", "fn $0new"); + builder.insert_snippet(cap, start_offset, buf); + } + } + }) +} + +// Generates the surrounding `impl Type { }` including type and lifetime +// parameters +fn generate_impl_text(strukt: &ast::Struct, code: &str) -> String { + let type_params = strukt.generic_param_list(); + let mut buf = String::with_capacity(code.len()); + buf.push_str("\n\nimpl"); + if let Some(type_params) = &type_params { + format_to!(buf, "{}", type_params.syntax()); + } + buf.push_str(" "); + buf.push_str(strukt.name().unwrap().text().as_str()); + if let Some(type_params) = type_params { + let lifetime_params = type_params + .lifetime_params() + .filter_map(|it| it.lifetime_token()) + .map(|it| it.text().clone()); + let type_params = + type_params.type_params().filter_map(|it| it.name()).map(|it| it.text().clone()); + format_to!(buf, "<{}>", lifetime_params.chain(type_params).format(", ")) + } + + format_to!(buf, " {{\n{}\n}}\n", code); + + buf +} + +// Uses a syntax-driven approach to find any impl blocks for the struct that +// exist within the module/file +// +// Returns `None` if we've found an existing `new` fn +// +// FIXME: change the new fn checking to a more semantic approach when that's more +// viable (e.g. we process proc macros, etc) +fn find_struct_impl(ctx: &AssistContext, strukt: &ast::Struct) -> Option> { + let db = ctx.db(); + let module = strukt.syntax().ancestors().find(|node| { + ast::Module::can_cast(node.kind()) || ast::SourceFile::can_cast(node.kind()) + })?; + + let struct_def = ctx.sema.to_def(strukt)?; + + let block = module.descendants().filter_map(ast::Impl::cast).find_map(|impl_blk| { + let blk = ctx.sema.to_def(&impl_blk)?; + + // FIXME: handle e.g. `struct S; impl S {}` + // (we currently use the wrong type parameter) + // also we wouldn't want to use e.g. `impl S` + let same_ty = match blk.target_ty(db).as_adt() { + Some(def) => def == Adt::Struct(struct_def), + None => false, + }; + let not_trait_impl = blk.target_trait(db).is_none(); + + if !(same_ty && not_trait_impl) { + None + } else { + Some(impl_blk) + } + }); + + if let Some(ref impl_blk) = block { + if has_new_fn(impl_blk) { + return None; + } + } + + Some(block) +} + +fn has_new_fn(imp: &ast::Impl) -> bool { + if let Some(il) = imp.assoc_item_list() { + for item in il.assoc_items() { + if let ast::AssocItem::Fn(f) = item { + if let Some(name) = f.name() { + if name.text().eq_ignore_ascii_case("new") { + return true; + } + } + } + } + } + + false +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + use super::*; + + #[test] + #[rustfmt::skip] + fn test_generate_new() { + // Check output of generation + check_assist( + generate_new, +"struct Foo {<|>}", +"struct Foo {} + +impl Foo { + fn $0new() -> Self { Self { } } +} +", + ); + check_assist( + generate_new, +"struct Foo {<|>}", +"struct Foo {} + +impl Foo { + fn $0new() -> Self { Self { } } +} +", + ); + check_assist( + generate_new, +"struct Foo<'a, T: Foo<'a>> {<|>}", +"struct Foo<'a, T: Foo<'a>> {} + +impl<'a, T: Foo<'a>> Foo<'a, T> { + fn $0new() -> Self { Self { } } +} +", + ); + check_assist( + generate_new, +"struct Foo { baz: String <|>}", +"struct Foo { baz: String } + +impl Foo { + fn $0new(baz: String) -> Self { Self { baz } } +} +", + ); + check_assist( + generate_new, +"struct Foo { baz: String, qux: Vec <|>}", +"struct Foo { baz: String, qux: Vec } + +impl Foo { + fn $0new(baz: String, qux: Vec) -> Self { Self { baz, qux } } +} +", + ); + + // Check that visibility modifiers don't get brought in for fields + check_assist( + generate_new, +"struct Foo { pub baz: String, pub qux: Vec <|>}", +"struct Foo { pub baz: String, pub qux: Vec } + +impl Foo { + fn $0new(baz: String, qux: Vec) -> Self { Self { baz, qux } } +} +", + ); + + // Check that it reuses existing impls + check_assist( + generate_new, +"struct Foo {<|>} + +impl Foo {} +", +"struct Foo {} + +impl Foo { + fn $0new() -> Self { Self { } } +} +", + ); + check_assist( + generate_new, +"struct Foo {<|>} + +impl Foo { + fn qux(&self) {} +} +", +"struct Foo {} + +impl Foo { + fn $0new() -> Self { Self { } } + + fn qux(&self) {} +} +", + ); + + check_assist( + generate_new, +"struct Foo {<|>} + +impl Foo { + fn qux(&self) {} + fn baz() -> i32 { + 5 + } +} +", +"struct Foo {} + +impl Foo { + fn $0new() -> Self { Self { } } + + fn qux(&self) {} + fn baz() -> i32 { + 5 + } +} +", + ); + + // Check visibility of new fn based on struct + check_assist( + generate_new, +"pub struct Foo {<|>}", +"pub struct Foo {} + +impl Foo { + pub fn $0new() -> Self { Self { } } +} +", + ); + check_assist( + generate_new, +"pub(crate) struct Foo {<|>}", +"pub(crate) struct Foo {} + +impl Foo { + pub(crate) fn $0new() -> Self { Self { } } +} +", + ); + } + + #[test] + fn generate_new_not_applicable_if_fn_exists() { + check_assist_not_applicable( + generate_new, + " +struct Foo {<|>} + +impl Foo { + fn new() -> Self { + Self + } +}", + ); + + check_assist_not_applicable( + generate_new, + " +struct Foo {<|>} + +impl Foo { + fn New() -> Self { + Self + } +}", + ); + } + + #[test] + fn generate_new_target() { + check_assist_target( + generate_new, + " +struct SomeThingIrrelevant; +/// Has a lifetime parameter +struct Foo<'a, T: Foo<'a>> {<|>} +struct EvenMoreIrrelevant; +", + "/// Has a lifetime parameter +struct Foo<'a, T: Foo<'a>> {}", + ); + } + + #[test] + fn test_unrelated_new() { + check_assist( + generate_new, + r##" +pub struct AstId { + file_id: HirFileId, + file_ast_id: FileAstId, +} + +impl AstId { + pub fn new(file_id: HirFileId, file_ast_id: FileAstId) -> AstId { + AstId { file_id, file_ast_id } + } +} + +pub struct Source { + pub file_id: HirFileId,<|> + pub ast: T, +} + +impl Source { + pub fn map U, U>(self, f: F) -> Source { + Source { file_id: self.file_id, ast: f(self.ast) } + } +} +"##, + r##" +pub struct AstId { + file_id: HirFileId, + file_ast_id: FileAstId, +} + +impl AstId { + pub fn new(file_id: HirFileId, file_ast_id: FileAstId) -> AstId { + AstId { file_id, file_ast_id } + } +} + +pub struct Source { + pub file_id: HirFileId, + pub ast: T, +} + +impl Source { + pub fn $0new(file_id: HirFileId, ast: T) -> Self { Self { file_id, ast } } + + pub fn map U, U>(self, f: F) -> Source { + Source { file_id: self.file_id, ast: f(self.ast) } + } +} +"##, + ); + } +} diff --git a/crates/assists/src/handlers/inline_local_variable.rs b/crates/assists/src/handlers/inline_local_variable.rs new file mode 100644 index 000000000..2b52b333b --- /dev/null +++ b/crates/assists/src/handlers/inline_local_variable.rs @@ -0,0 +1,695 @@ +use ide_db::defs::Definition; +use syntax::{ + ast::{self, AstNode, AstToken}, + TextRange, +}; +use test_utils::mark; + +use crate::{ + assist_context::{AssistContext, Assists}, + AssistId, AssistKind, +}; + +// Assist: inline_local_variable +// +// Inlines local variable. +// +// ``` +// fn main() { +// let x<|> = 1 + 2; +// x * 4; +// } +// ``` +// -> +// ``` +// fn main() { +// (1 + 2) * 4; +// } +// ``` +pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let let_stmt = ctx.find_node_at_offset::()?; + let bind_pat = match let_stmt.pat()? { + ast::Pat::IdentPat(pat) => pat, + _ => return None, + }; + if bind_pat.mut_token().is_some() { + mark::hit!(test_not_inline_mut_variable); + return None; + } + if !bind_pat.syntax().text_range().contains_inclusive(ctx.offset()) { + mark::hit!(not_applicable_outside_of_bind_pat); + return None; + } + let initializer_expr = let_stmt.initializer()?; + + let def = ctx.sema.to_def(&bind_pat)?; + let def = Definition::Local(def); + let refs = def.find_usages(&ctx.sema, None); + if refs.is_empty() { + mark::hit!(test_not_applicable_if_variable_unused); + return None; + }; + + let delete_range = if let Some(whitespace) = let_stmt + .syntax() + .next_sibling_or_token() + .and_then(|it| ast::Whitespace::cast(it.as_token()?.clone())) + { + TextRange::new( + let_stmt.syntax().text_range().start(), + whitespace.syntax().text_range().end(), + ) + } else { + let_stmt.syntax().text_range() + }; + + let mut wrap_in_parens = vec![true; refs.len()]; + + for (i, desc) in refs.iter().enumerate() { + let usage_node = ctx + .covering_node_for_range(desc.file_range.range) + .ancestors() + .find_map(ast::PathExpr::cast)?; + let usage_parent_option = usage_node.syntax().parent().and_then(ast::Expr::cast); + let usage_parent = match usage_parent_option { + Some(u) => u, + None => { + wrap_in_parens[i] = false; + continue; + } + }; + + wrap_in_parens[i] = match (&initializer_expr, usage_parent) { + (ast::Expr::CallExpr(_), _) + | (ast::Expr::IndexExpr(_), _) + | (ast::Expr::MethodCallExpr(_), _) + | (ast::Expr::FieldExpr(_), _) + | (ast::Expr::TryExpr(_), _) + | (ast::Expr::RefExpr(_), _) + | (ast::Expr::Literal(_), _) + | (ast::Expr::TupleExpr(_), _) + | (ast::Expr::ArrayExpr(_), _) + | (ast::Expr::ParenExpr(_), _) + | (ast::Expr::PathExpr(_), _) + | (ast::Expr::BlockExpr(_), _) + | (ast::Expr::EffectExpr(_), _) + | (_, ast::Expr::CallExpr(_)) + | (_, ast::Expr::TupleExpr(_)) + | (_, ast::Expr::ArrayExpr(_)) + | (_, ast::Expr::ParenExpr(_)) + | (_, ast::Expr::ForExpr(_)) + | (_, ast::Expr::WhileExpr(_)) + | (_, ast::Expr::BreakExpr(_)) + | (_, ast::Expr::ReturnExpr(_)) + | (_, ast::Expr::MatchExpr(_)) => false, + _ => true, + }; + } + + let init_str = initializer_expr.syntax().text().to_string(); + let init_in_paren = format!("({})", &init_str); + + let target = bind_pat.syntax().text_range(); + acc.add( + AssistId("inline_local_variable", AssistKind::RefactorInline), + "Inline variable", + target, + move |builder| { + builder.delete(delete_range); + for (desc, should_wrap) in refs.iter().zip(wrap_in_parens) { + let replacement = + if should_wrap { init_in_paren.clone() } else { init_str.clone() }; + builder.replace(desc.file_range.range, replacement) + } + }, + ) +} + +#[cfg(test)] +mod tests { + use test_utils::mark; + + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_inline_let_bind_literal_expr() { + check_assist( + inline_local_variable, + r" +fn bar(a: usize) {} +fn foo() { + let a<|> = 1; + a + 1; + if a > 10 { + } + + while a > 10 { + + } + let b = a * 10; + bar(a); +}", + r" +fn bar(a: usize) {} +fn foo() { + 1 + 1; + if 1 > 10 { + } + + while 1 > 10 { + + } + let b = 1 * 10; + bar(1); +}", + ); + } + + #[test] + fn test_inline_let_bind_bin_expr() { + check_assist( + inline_local_variable, + r" +fn bar(a: usize) {} +fn foo() { + let a<|> = 1 + 1; + a + 1; + if a > 10 { + } + + while a > 10 { + + } + let b = a * 10; + bar(a); +}", + r" +fn bar(a: usize) {} +fn foo() { + (1 + 1) + 1; + if (1 + 1) > 10 { + } + + while (1 + 1) > 10 { + + } + let b = (1 + 1) * 10; + bar(1 + 1); +}", + ); + } + + #[test] + fn test_inline_let_bind_function_call_expr() { + check_assist( + inline_local_variable, + r" +fn bar(a: usize) {} +fn foo() { + let a<|> = bar(1); + a + 1; + if a > 10 { + } + + while a > 10 { + + } + let b = a * 10; + bar(a); +}", + r" +fn bar(a: usize) {} +fn foo() { + bar(1) + 1; + if bar(1) > 10 { + } + + while bar(1) > 10 { + + } + let b = bar(1) * 10; + bar(bar(1)); +}", + ); + } + + #[test] + fn test_inline_let_bind_cast_expr() { + check_assist( + inline_local_variable, + r" +fn bar(a: usize): usize { a } +fn foo() { + let a<|> = bar(1) as u64; + a + 1; + if a > 10 { + } + + while a > 10 { + + } + let b = a * 10; + bar(a); +}", + r" +fn bar(a: usize): usize { a } +fn foo() { + (bar(1) as u64) + 1; + if (bar(1) as u64) > 10 { + } + + while (bar(1) as u64) > 10 { + + } + let b = (bar(1) as u64) * 10; + bar(bar(1) as u64); +}", + ); + } + + #[test] + fn test_inline_let_bind_block_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a<|> = { 10 + 1 }; + a + 1; + if a > 10 { + } + + while a > 10 { + + } + let b = a * 10; + bar(a); +}", + r" +fn foo() { + { 10 + 1 } + 1; + if { 10 + 1 } > 10 { + } + + while { 10 + 1 } > 10 { + + } + let b = { 10 + 1 } * 10; + bar({ 10 + 1 }); +}", + ); + } + + #[test] + fn test_inline_let_bind_paren_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a<|> = ( 10 + 1 ); + a + 1; + if a > 10 { + } + + while a > 10 { + + } + let b = a * 10; + bar(a); +}", + r" +fn foo() { + ( 10 + 1 ) + 1; + if ( 10 + 1 ) > 10 { + } + + while ( 10 + 1 ) > 10 { + + } + let b = ( 10 + 1 ) * 10; + bar(( 10 + 1 )); +}", + ); + } + + #[test] + fn test_not_inline_mut_variable() { + mark::check!(test_not_inline_mut_variable); + check_assist_not_applicable( + inline_local_variable, + r" +fn foo() { + let mut a<|> = 1 + 1; + a + 1; +}", + ); + } + + #[test] + fn test_call_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a<|> = bar(10 + 1); + let b = a * 10; + let c = a as usize; +}", + r" +fn foo() { + let b = bar(10 + 1) * 10; + let c = bar(10 + 1) as usize; +}", + ); + } + + #[test] + fn test_index_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let x = vec![1, 2, 3]; + let a<|> = x[0]; + let b = a * 10; + let c = a as usize; +}", + r" +fn foo() { + let x = vec![1, 2, 3]; + let b = x[0] * 10; + let c = x[0] as usize; +}", + ); + } + + #[test] + fn test_method_call_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let bar = vec![1]; + let a<|> = bar.len(); + let b = a * 10; + let c = a as usize; +}", + r" +fn foo() { + let bar = vec![1]; + let b = bar.len() * 10; + let c = bar.len() as usize; +}", + ); + } + + #[test] + fn test_field_expr() { + check_assist( + inline_local_variable, + r" +struct Bar { + foo: usize +} + +fn foo() { + let bar = Bar { foo: 1 }; + let a<|> = bar.foo; + let b = a * 10; + let c = a as usize; +}", + r" +struct Bar { + foo: usize +} + +fn foo() { + let bar = Bar { foo: 1 }; + let b = bar.foo * 10; + let c = bar.foo as usize; +}", + ); + } + + #[test] + fn test_try_expr() { + check_assist( + inline_local_variable, + r" +fn foo() -> Option { + let bar = Some(1); + let a<|> = bar?; + let b = a * 10; + let c = a as usize; + None +}", + r" +fn foo() -> Option { + let bar = Some(1); + let b = bar? * 10; + let c = bar? as usize; + None +}", + ); + } + + #[test] + fn test_ref_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let bar = 10; + let a<|> = &bar; + let b = a * 10; +}", + r" +fn foo() { + let bar = 10; + let b = &bar * 10; +}", + ); + } + + #[test] + fn test_tuple_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a<|> = (10, 20); + let b = a[0]; +}", + r" +fn foo() { + let b = (10, 20)[0]; +}", + ); + } + + #[test] + fn test_array_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a<|> = [1, 2, 3]; + let b = a.len(); +}", + r" +fn foo() { + let b = [1, 2, 3].len(); +}", + ); + } + + #[test] + fn test_paren() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a<|> = (10 + 20); + let b = a * 10; + let c = a as usize; +}", + r" +fn foo() { + let b = (10 + 20) * 10; + let c = (10 + 20) as usize; +}", + ); + } + + #[test] + fn test_path_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let d = 10; + let a<|> = d; + let b = a * 10; + let c = a as usize; +}", + r" +fn foo() { + let d = 10; + let b = d * 10; + let c = d as usize; +}", + ); + } + + #[test] + fn test_block_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a<|> = { 10 }; + let b = a * 10; + let c = a as usize; +}", + r" +fn foo() { + let b = { 10 } * 10; + let c = { 10 } as usize; +}", + ); + } + + #[test] + fn test_used_in_different_expr1() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a<|> = 10 + 20; + let b = a * 10; + let c = (a, 20); + let d = [a, 10]; + let e = (a); +}", + r" +fn foo() { + let b = (10 + 20) * 10; + let c = (10 + 20, 20); + let d = [10 + 20, 10]; + let e = (10 + 20); +}", + ); + } + + #[test] + fn test_used_in_for_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a<|> = vec![10, 20]; + for i in a {} +}", + r" +fn foo() { + for i in vec![10, 20] {} +}", + ); + } + + #[test] + fn test_used_in_while_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a<|> = 1 > 0; + while a {} +}", + r" +fn foo() { + while 1 > 0 {} +}", + ); + } + + #[test] + fn test_used_in_break_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a<|> = 1 + 1; + loop { + break a; + } +}", + r" +fn foo() { + loop { + break 1 + 1; + } +}", + ); + } + + #[test] + fn test_used_in_return_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a<|> = 1 > 0; + return a; +}", + r" +fn foo() { + return 1 > 0; +}", + ); + } + + #[test] + fn test_used_in_match_expr() { + check_assist( + inline_local_variable, + r" +fn foo() { + let a<|> = 1 > 0; + match a {} +}", + r" +fn foo() { + match 1 > 0 {} +}", + ); + } + + #[test] + fn test_not_applicable_if_variable_unused() { + mark::check!(test_not_applicable_if_variable_unused); + check_assist_not_applicable( + inline_local_variable, + r" +fn foo() { + let <|>a = 0; +} + ", + ) + } + + #[test] + fn not_applicable_outside_of_bind_pat() { + mark::check!(not_applicable_outside_of_bind_pat); + check_assist_not_applicable( + inline_local_variable, + r" +fn main() { + let x = <|>1 + 2; + x * 4; +} +", + ) + } +} diff --git a/crates/assists/src/handlers/introduce_named_lifetime.rs b/crates/assists/src/handlers/introduce_named_lifetime.rs new file mode 100644 index 000000000..5f623e5f7 --- /dev/null +++ b/crates/assists/src/handlers/introduce_named_lifetime.rs @@ -0,0 +1,318 @@ +use rustc_hash::FxHashSet; +use syntax::{ + ast::{self, GenericParamsOwner, NameOwner}, + AstNode, SyntaxKind, TextRange, TextSize, +}; + +use crate::{assist_context::AssistBuilder, AssistContext, AssistId, AssistKind, Assists}; + +static ASSIST_NAME: &str = "introduce_named_lifetime"; +static ASSIST_LABEL: &str = "Introduce named lifetime"; + +// Assist: introduce_named_lifetime +// +// Change an anonymous lifetime to a named lifetime. +// +// ``` +// impl Cursor<'_<|>> { +// fn node(self) -> &SyntaxNode { +// match self { +// Cursor::Replace(node) | Cursor::Before(node) => node, +// } +// } +// } +// ``` +// -> +// ``` +// impl<'a> Cursor<'a> { +// fn node(self) -> &SyntaxNode { +// match self { +// Cursor::Replace(node) | Cursor::Before(node) => node, +// } +// } +// } +// ``` +// FIXME: How can we handle renaming any one of multiple anonymous lifetimes? +// FIXME: should also add support for the case fun(f: &Foo) -> &<|>Foo +pub(crate) fn introduce_named_lifetime(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let lifetime_token = ctx + .find_token_at_offset(SyntaxKind::LIFETIME) + .filter(|lifetime| lifetime.text() == "'_")?; + if let Some(fn_def) = lifetime_token.ancestors().find_map(ast::Fn::cast) { + generate_fn_def_assist(acc, &fn_def, lifetime_token.text_range()) + } else if let Some(impl_def) = lifetime_token.ancestors().find_map(ast::Impl::cast) { + generate_impl_def_assist(acc, &impl_def, lifetime_token.text_range()) + } else { + None + } +} + +/// Generate the assist for the fn def case +fn generate_fn_def_assist( + acc: &mut Assists, + fn_def: &ast::Fn, + lifetime_loc: TextRange, +) -> Option<()> { + let param_list: ast::ParamList = fn_def.param_list()?; + let new_lifetime_param = generate_unique_lifetime_param_name(&fn_def.generic_param_list())?; + let end_of_fn_ident = fn_def.name()?.ident_token()?.text_range().end(); + let self_param = + // use the self if it's a reference and has no explicit lifetime + param_list.self_param().filter(|p| p.lifetime_token().is_none() && p.amp_token().is_some()); + // compute the location which implicitly has the same lifetime as the anonymous lifetime + let loc_needing_lifetime = if let Some(self_param) = self_param { + // if we have a self reference, use that + Some(self_param.self_token()?.text_range().start()) + } else { + // otherwise, if there's a single reference parameter without a named liftime, use that + let fn_params_without_lifetime: Vec<_> = param_list + .params() + .filter_map(|param| match param.ty() { + Some(ast::Type::RefType(ascribed_type)) + if ascribed_type.lifetime_token() == None => + { + Some(ascribed_type.amp_token()?.text_range().end()) + } + _ => None, + }) + .collect(); + match fn_params_without_lifetime.len() { + 1 => Some(fn_params_without_lifetime.into_iter().nth(0)?), + 0 => None, + // multiple unnnamed is invalid. assist is not applicable + _ => return None, + } + }; + acc.add(AssistId(ASSIST_NAME, AssistKind::Refactor), ASSIST_LABEL, lifetime_loc, |builder| { + add_lifetime_param(fn_def, builder, end_of_fn_ident, new_lifetime_param); + builder.replace(lifetime_loc, format!("'{}", new_lifetime_param)); + loc_needing_lifetime.map(|loc| builder.insert(loc, format!("'{} ", new_lifetime_param))); + }) +} + +/// Generate the assist for the impl def case +fn generate_impl_def_assist( + acc: &mut Assists, + impl_def: &ast::Impl, + lifetime_loc: TextRange, +) -> Option<()> { + let new_lifetime_param = generate_unique_lifetime_param_name(&impl_def.generic_param_list())?; + let end_of_impl_kw = impl_def.impl_token()?.text_range().end(); + acc.add(AssistId(ASSIST_NAME, AssistKind::Refactor), ASSIST_LABEL, lifetime_loc, |builder| { + add_lifetime_param(impl_def, builder, end_of_impl_kw, new_lifetime_param); + builder.replace(lifetime_loc, format!("'{}", new_lifetime_param)); + }) +} + +/// Given a type parameter list, generate a unique lifetime parameter name +/// which is not in the list +fn generate_unique_lifetime_param_name( + existing_type_param_list: &Option, +) -> Option { + match existing_type_param_list { + Some(type_params) => { + let used_lifetime_params: FxHashSet<_> = type_params + .lifetime_params() + .map(|p| p.syntax().text().to_string()[1..].to_owned()) + .collect(); + (b'a'..=b'z').map(char::from).find(|c| !used_lifetime_params.contains(&c.to_string())) + } + None => Some('a'), + } +} + +/// Add the lifetime param to `builder`. If there are type parameters in `type_params_owner`, add it to the end. Otherwise +/// add new type params brackets with the lifetime parameter at `new_type_params_loc`. +fn add_lifetime_param( + type_params_owner: &TypeParamsOwner, + builder: &mut AssistBuilder, + new_type_params_loc: TextSize, + new_lifetime_param: char, +) { + match type_params_owner.generic_param_list() { + // add the new lifetime parameter to an existing type param list + Some(type_params) => { + builder.insert( + (u32::from(type_params.syntax().text_range().end()) - 1).into(), + format!(", '{}", new_lifetime_param), + ); + } + // create a new type param list containing only the new lifetime parameter + None => { + builder.insert(new_type_params_loc, format!("<'{}>", new_lifetime_param)); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::{check_assist, check_assist_not_applicable}; + + #[test] + fn test_example_case() { + check_assist( + introduce_named_lifetime, + r#"impl Cursor<'_<|>> { + fn node(self) -> &SyntaxNode { + match self { + Cursor::Replace(node) | Cursor::Before(node) => node, + } + } + }"#, + r#"impl<'a> Cursor<'a> { + fn node(self) -> &SyntaxNode { + match self { + Cursor::Replace(node) | Cursor::Before(node) => node, + } + } + }"#, + ); + } + + #[test] + fn test_example_case_simplified() { + check_assist( + introduce_named_lifetime, + r#"impl Cursor<'_<|>> {"#, + r#"impl<'a> Cursor<'a> {"#, + ); + } + + #[test] + fn test_example_case_cursor_after_tick() { + check_assist( + introduce_named_lifetime, + r#"impl Cursor<'<|>_> {"#, + r#"impl<'a> Cursor<'a> {"#, + ); + } + + #[test] + fn test_impl_with_other_type_param() { + check_assist( + introduce_named_lifetime, + "impl fmt::Display for SepByBuilder<'_<|>, I> + where + I: Iterator, + I::Item: fmt::Display, + {", + "impl fmt::Display for SepByBuilder<'a, I> + where + I: Iterator, + I::Item: fmt::Display, + {", + ) + } + + #[test] + fn test_example_case_cursor_before_tick() { + check_assist( + introduce_named_lifetime, + r#"impl Cursor<<|>'_> {"#, + r#"impl<'a> Cursor<'a> {"#, + ); + } + + #[test] + fn test_not_applicable_cursor_position() { + check_assist_not_applicable(introduce_named_lifetime, r#"impl Cursor<'_><|> {"#); + check_assist_not_applicable(introduce_named_lifetime, r#"impl Cursor<|><'_> {"#); + } + + #[test] + fn test_not_applicable_lifetime_already_name() { + check_assist_not_applicable(introduce_named_lifetime, r#"impl Cursor<'a<|>> {"#); + check_assist_not_applicable(introduce_named_lifetime, r#"fn my_fun<'a>() -> X<'a<|>>"#); + } + + #[test] + fn test_with_type_parameter() { + check_assist( + introduce_named_lifetime, + r#"impl Cursor>"#, + r#"impl Cursor"#, + ); + } + + #[test] + fn test_with_existing_lifetime_name_conflict() { + check_assist( + introduce_named_lifetime, + r#"impl<'a, 'b> Cursor<'a, 'b, '_<|>>"#, + r#"impl<'a, 'b, 'c> Cursor<'a, 'b, 'c>"#, + ); + } + + #[test] + fn test_function_return_value_anon_lifetime_param() { + check_assist( + introduce_named_lifetime, + r#"fn my_fun() -> X<'_<|>>"#, + r#"fn my_fun<'a>() -> X<'a>"#, + ); + } + + #[test] + fn test_function_return_value_anon_reference_lifetime() { + check_assist( + introduce_named_lifetime, + r#"fn my_fun() -> &'_<|> X"#, + r#"fn my_fun<'a>() -> &'a X"#, + ); + } + + #[test] + fn test_function_param_anon_lifetime() { + check_assist( + introduce_named_lifetime, + r#"fn my_fun(x: X<'_<|>>)"#, + r#"fn my_fun<'a>(x: X<'a>)"#, + ); + } + + #[test] + fn test_function_add_lifetime_to_params() { + check_assist( + introduce_named_lifetime, + r#"fn my_fun(f: &Foo) -> X<'_<|>>"#, + r#"fn my_fun<'a>(f: &'a Foo) -> X<'a>"#, + ); + } + + #[test] + fn test_function_add_lifetime_to_params_in_presence_of_other_lifetime() { + check_assist( + introduce_named_lifetime, + r#"fn my_fun<'other>(f: &Foo, b: &'other Bar) -> X<'_<|>>"#, + r#"fn my_fun<'other, 'a>(f: &'a Foo, b: &'other Bar) -> X<'a>"#, + ); + } + + #[test] + fn test_function_not_applicable_without_self_and_multiple_unnamed_param_lifetimes() { + // this is not permitted under lifetime elision rules + check_assist_not_applicable( + introduce_named_lifetime, + r#"fn my_fun(f: &Foo, b: &Bar) -> X<'_<|>>"#, + ); + } + + #[test] + fn test_function_add_lifetime_to_self_ref_param() { + check_assist( + introduce_named_lifetime, + r#"fn my_fun<'other>(&self, f: &Foo, b: &'other Bar) -> X<'_<|>>"#, + r#"fn my_fun<'other, 'a>(&'a self, f: &Foo, b: &'other Bar) -> X<'a>"#, + ); + } + + #[test] + fn test_function_add_lifetime_to_param_with_non_ref_self() { + check_assist( + introduce_named_lifetime, + r#"fn my_fun<'other>(self, f: &Foo, b: &'other Bar) -> X<'_<|>>"#, + r#"fn my_fun<'other, 'a>(self, f: &'a Foo, b: &'other Bar) -> X<'a>"#, + ); + } +} diff --git a/crates/assists/src/handlers/invert_if.rs b/crates/assists/src/handlers/invert_if.rs new file mode 100644 index 000000000..f0e047538 --- /dev/null +++ b/crates/assists/src/handlers/invert_if.rs @@ -0,0 +1,109 @@ +use syntax::{ + ast::{self, AstNode}, + T, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + utils::invert_boolean_expression, + AssistId, AssistKind, +}; + +// Assist: invert_if +// +// Apply invert_if +// This transforms if expressions of the form `if !x {A} else {B}` into `if x {B} else {A}` +// This also works with `!=`. This assist can only be applied with the cursor +// on `if`. +// +// ``` +// fn main() { +// if<|> !y { A } else { B } +// } +// ``` +// -> +// ``` +// fn main() { +// if y { B } else { A } +// } +// ``` + +pub(crate) fn invert_if(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let if_keyword = ctx.find_token_at_offset(T![if])?; + let expr = ast::IfExpr::cast(if_keyword.parent())?; + let if_range = if_keyword.text_range(); + let cursor_in_range = if_range.contains_range(ctx.frange.range); + if !cursor_in_range { + return None; + } + + // This assist should not apply for if-let. + if expr.condition()?.pat().is_some() { + return None; + } + + let cond = expr.condition()?.expr()?; + let then_node = expr.then_branch()?.syntax().clone(); + let else_block = match expr.else_branch()? { + ast::ElseBranch::Block(it) => it, + ast::ElseBranch::IfExpr(_) => return None, + }; + + let cond_range = cond.syntax().text_range(); + let flip_cond = invert_boolean_expression(cond); + let else_node = else_block.syntax(); + let else_range = else_node.text_range(); + let then_range = then_node.text_range(); + acc.add(AssistId("invert_if", AssistKind::RefactorRewrite), "Invert if", if_range, |edit| { + edit.replace(cond_range, flip_cond.syntax().text()); + edit.replace(else_range, then_node.text()); + edit.replace(then_range, else_node.text()); + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable}; + + #[test] + fn invert_if_remove_inequality() { + check_assist( + invert_if, + "fn f() { i<|>f x != 3 { 1 } else { 3 + 2 } }", + "fn f() { if x == 3 { 3 + 2 } else { 1 } }", + ) + } + + #[test] + fn invert_if_remove_not() { + check_assist( + invert_if, + "fn f() { <|>if !cond { 3 * 2 } else { 1 } }", + "fn f() { if cond { 1 } else { 3 * 2 } }", + ) + } + + #[test] + fn invert_if_general_case() { + check_assist( + invert_if, + "fn f() { i<|>f cond { 3 * 2 } else { 1 } }", + "fn f() { if !cond { 1 } else { 3 * 2 } }", + ) + } + + #[test] + fn invert_if_doesnt_apply_with_cursor_not_on_if() { + check_assist_not_applicable(invert_if, "fn f() { if !<|>cond { 3 * 2 } else { 1 } }") + } + + #[test] + fn invert_if_doesnt_apply_with_if_let() { + check_assist_not_applicable( + invert_if, + "fn f() { i<|>f let Some(_) = Some(1) { 1 } else { 0 } }", + ) + } +} diff --git a/crates/assists/src/handlers/merge_imports.rs b/crates/assists/src/handlers/merge_imports.rs new file mode 100644 index 000000000..47d465404 --- /dev/null +++ b/crates/assists/src/handlers/merge_imports.rs @@ -0,0 +1,321 @@ +use std::iter::successors; + +use syntax::{ + algo::{neighbor, skip_trivia_token, SyntaxRewriter}, + ast::{self, edit::AstNodeEdit, make}, + AstNode, Direction, InsertPosition, SyntaxElement, T, +}; + +use crate::{ + assist_context::{AssistContext, Assists}, + AssistId, AssistKind, +}; + +// Assist: merge_imports +// +// Merges two imports with a common prefix. +// +// ``` +// use std::<|>fmt::Formatter; +// use std::io; +// ``` +// -> +// ``` +// use std::{fmt::Formatter, io}; +// ``` +pub(crate) fn merge_imports(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let tree: ast::UseTree = ctx.find_node_at_offset()?; + let mut rewriter = SyntaxRewriter::default(); + let mut offset = ctx.offset(); + + if let Some(use_item) = tree.syntax().parent().and_then(ast::Use::cast) { + let (merged, to_delete) = next_prev() + .filter_map(|dir| neighbor(&use_item, dir)) + .filter_map(|it| Some((it.clone(), it.use_tree()?))) + .find_map(|(use_item, use_tree)| { + Some((try_merge_trees(&tree, &use_tree)?, use_item)) + })?; + + rewriter.replace_ast(&tree, &merged); + rewriter += to_delete.remove(); + + if to_delete.syntax().text_range().end() < offset { + offset -= to_delete.syntax().text_range().len(); + } + } else { + let (merged, to_delete) = next_prev() + .filter_map(|dir| neighbor(&tree, dir)) + .find_map(|use_tree| Some((try_merge_trees(&tree, &use_tree)?, use_tree.clone())))?; + + rewriter.replace_ast(&tree, &merged); + rewriter += to_delete.remove(); + + if to_delete.syntax().text_range().end() < offset { + offset -= to_delete.syntax().text_range().len(); + } + }; + + let target = tree.syntax().text_range(); + acc.add( + AssistId("merge_imports", AssistKind::RefactorRewrite), + "Merge imports", + target, + |builder| { + builder.rewrite(rewriter); + }, + ) +} + +fn next_prev() -> impl Iterator { + [Direction::Next, Direction::Prev].iter().copied() +} + +fn try_merge_trees(old: &ast::UseTree, new: &ast::UseTree) -> Option { + let lhs_path = old.path()?; + let rhs_path = new.path()?; + + let (lhs_prefix, rhs_prefix) = common_prefix(&lhs_path, &rhs_path)?; + + let lhs = old.split_prefix(&lhs_prefix); + let rhs = new.split_prefix(&rhs_prefix); + + let should_insert_comma = lhs + .use_tree_list()? + .r_curly_token() + .and_then(|it| skip_trivia_token(it.prev_token()?, Direction::Prev)) + .map(|it| it.kind() != T![,]) + .unwrap_or(true); + + let mut to_insert: Vec = Vec::new(); + if should_insert_comma { + to_insert.push(make::token(T![,]).into()); + to_insert.push(make::tokens::single_space().into()); + } + to_insert.extend( + rhs.use_tree_list()? + .syntax() + .children_with_tokens() + .filter(|it| it.kind() != T!['{'] && it.kind() != T!['}']), + ); + let use_tree_list = lhs.use_tree_list()?; + let pos = InsertPosition::Before(use_tree_list.r_curly_token()?.into()); + let use_tree_list = use_tree_list.insert_children(pos, to_insert); + Some(lhs.with_use_tree_list(use_tree_list)) +} + +fn common_prefix(lhs: &ast::Path, rhs: &ast::Path) -> Option<(ast::Path, ast::Path)> { + let mut res = None; + let mut lhs_curr = first_path(&lhs); + let mut rhs_curr = first_path(&rhs); + loop { + match (lhs_curr.segment(), rhs_curr.segment()) { + (Some(lhs), Some(rhs)) if lhs.syntax().text() == rhs.syntax().text() => (), + _ => break, + } + res = Some((lhs_curr.clone(), rhs_curr.clone())); + + match (lhs_curr.parent_path(), rhs_curr.parent_path()) { + (Some(lhs), Some(rhs)) => { + lhs_curr = lhs; + rhs_curr = rhs; + } + _ => break, + } + } + + res +} + +fn first_path(path: &ast::Path) -> ast::Path { + successors(Some(path.clone()), |it| it.qualifier()).last().unwrap() +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_merge_first() { + check_assist( + merge_imports, + r" +use std::fmt<|>::Debug; +use std::fmt::Display; +", + r" +use std::fmt::{Debug, Display}; +", + ) + } + + #[test] + fn test_merge_second() { + check_assist( + merge_imports, + r" +use std::fmt::Debug; +use std::fmt<|>::Display; +", + r" +use std::fmt::{Display, Debug}; +", + ); + } + + #[test] + fn merge_self1() { + check_assist( + merge_imports, + r" +use std::fmt<|>; +use std::fmt::Display; +", + r" +use std::fmt::{self, Display}; +", + ); + } + + #[test] + fn merge_self2() { + check_assist( + merge_imports, + r" +use std::{fmt, <|>fmt::Display}; +", + r" +use std::{fmt::{Display, self}}; +", + ); + } + + #[test] + fn test_merge_nested() { + check_assist( + merge_imports, + r" +use std::{fmt<|>::Debug, fmt::Display}; +", + r" +use std::{fmt::{Debug, Display}}; +", + ); + check_assist( + merge_imports, + r" +use std::{fmt::Debug, fmt<|>::Display}; +", + r" +use std::{fmt::{Display, Debug}}; +", + ); + } + + #[test] + fn test_merge_single_wildcard_diff_prefixes() { + check_assist( + merge_imports, + r" +use std<|>::cell::*; +use std::str; +", + r" +use std::{cell::*, str}; +", + ) + } + + #[test] + fn test_merge_both_wildcard_diff_prefixes() { + check_assist( + merge_imports, + r" +use std<|>::cell::*; +use std::str::*; +", + r" +use std::{cell::*, str::*}; +", + ) + } + + #[test] + fn removes_just_enough_whitespace() { + check_assist( + merge_imports, + r" +use foo<|>::bar; +use foo::baz; + +/// Doc comment +", + r" +use foo::{bar, baz}; + +/// Doc comment +", + ); + } + + #[test] + fn works_with_trailing_comma() { + check_assist( + merge_imports, + r" +use { + foo<|>::bar, + foo::baz, +}; +", + r" +use { + foo::{bar, baz}, +}; +", + ); + check_assist( + merge_imports, + r" +use { + foo::baz, + foo<|>::bar, +}; +", + r" +use { + foo::{bar, baz}, +}; +", + ); + } + + #[test] + fn test_double_comma() { + check_assist( + merge_imports, + r" +use foo::bar::baz; +use foo::<|>{ + FooBar, +}; +", + r" +use foo::{ + FooBar, +bar::baz}; +", + ) + } + + #[test] + fn test_empty_use() { + check_assist_not_applicable( + merge_imports, + r" +use std::<|> +fn main() {}", + ); + } +} diff --git a/crates/assists/src/handlers/merge_match_arms.rs b/crates/assists/src/handlers/merge_match_arms.rs new file mode 100644 index 000000000..c347eb40e --- /dev/null +++ b/crates/assists/src/handlers/merge_match_arms.rs @@ -0,0 +1,248 @@ +use std::iter::successors; + +use syntax::{ + algo::neighbor, + ast::{self, AstNode}, + Direction, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists, TextRange}; + +// Assist: merge_match_arms +// +// Merges identical match arms. +// +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// <|>Action::Move(..) => foo(), +// Action::Stop => foo(), +// } +// } +// ``` +// -> +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// Action::Move(..) | Action::Stop => foo(), +// } +// } +// ``` +pub(crate) fn merge_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let current_arm = ctx.find_node_at_offset::()?; + // Don't try to handle arms with guards for now - can add support for this later + if current_arm.guard().is_some() { + return None; + } + let current_expr = current_arm.expr()?; + let current_text_range = current_arm.syntax().text_range(); + + // We check if the following match arms match this one. We could, but don't, + // compare to the previous match arm as well. + let arms_to_merge = successors(Some(current_arm), |it| neighbor(it, Direction::Next)) + .take_while(|arm| { + if arm.guard().is_some() { + return false; + } + match arm.expr() { + Some(expr) => expr.syntax().text() == current_expr.syntax().text(), + None => false, + } + }) + .collect::>(); + + if arms_to_merge.len() <= 1 { + return None; + } + + acc.add( + AssistId("merge_match_arms", AssistKind::RefactorRewrite), + "Merge match arms", + current_text_range, + |edit| { + let pats = if arms_to_merge.iter().any(contains_placeholder) { + "_".into() + } else { + arms_to_merge + .iter() + .filter_map(ast::MatchArm::pat) + .map(|x| x.syntax().to_string()) + .collect::>() + .join(" | ") + }; + + let arm = format!("{} => {}", pats, current_expr.syntax().text()); + + let start = arms_to_merge.first().unwrap().syntax().text_range().start(); + let end = arms_to_merge.last().unwrap().syntax().text_range().end(); + + edit.replace(TextRange::new(start, end), arm); + }, + ) +} + +fn contains_placeholder(a: &ast::MatchArm) -> bool { + matches!(a.pat(), Some(ast::Pat::WildcardPat(..))) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn merge_match_arms_single_patterns() { + check_assist( + merge_match_arms, + r#" + #[derive(Debug)] + enum X { A, B, C } + + fn main() { + let x = X::A; + let y = match x { + X::A => { 1i32<|> } + X::B => { 1i32 } + X::C => { 2i32 } + } + } + "#, + r#" + #[derive(Debug)] + enum X { A, B, C } + + fn main() { + let x = X::A; + let y = match x { + X::A | X::B => { 1i32 } + X::C => { 2i32 } + } + } + "#, + ); + } + + #[test] + fn merge_match_arms_multiple_patterns() { + check_assist( + merge_match_arms, + r#" + #[derive(Debug)] + enum X { A, B, C, D, E } + + fn main() { + let x = X::A; + let y = match x { + X::A | X::B => {<|> 1i32 }, + X::C | X::D => { 1i32 }, + X::E => { 2i32 }, + } + } + "#, + r#" + #[derive(Debug)] + enum X { A, B, C, D, E } + + fn main() { + let x = X::A; + let y = match x { + X::A | X::B | X::C | X::D => { 1i32 }, + X::E => { 2i32 }, + } + } + "#, + ); + } + + #[test] + fn merge_match_arms_placeholder_pattern() { + check_assist( + merge_match_arms, + r#" + #[derive(Debug)] + enum X { A, B, C, D, E } + + fn main() { + let x = X::A; + let y = match x { + X::A => { 1i32 }, + X::B => { 2i<|>32 }, + _ => { 2i32 } + } + } + "#, + r#" + #[derive(Debug)] + enum X { A, B, C, D, E } + + fn main() { + let x = X::A; + let y = match x { + X::A => { 1i32 }, + _ => { 2i32 } + } + } + "#, + ); + } + + #[test] + fn merges_all_subsequent_arms() { + check_assist( + merge_match_arms, + r#" + enum X { A, B, C, D, E } + + fn main() { + match X::A { + X::A<|> => 92, + X::B => 92, + X::C => 92, + X::D => 62, + _ => panic!(), + } + } + "#, + r#" + enum X { A, B, C, D, E } + + fn main() { + match X::A { + X::A | X::B | X::C => 92, + X::D => 62, + _ => panic!(), + } + } + "#, + ) + } + + #[test] + fn merge_match_arms_rejects_guards() { + check_assist_not_applicable( + merge_match_arms, + r#" + #[derive(Debug)] + enum X { + A(i32), + B, + C + } + + fn main() { + let x = X::A; + let y = match x { + X::A(a) if a > 5 => { <|>1i32 }, + X::B => { 1i32 }, + X::C => { 2i32 } + } + } + "#, + ); + } +} diff --git a/crates/assists/src/handlers/move_bounds.rs b/crates/assists/src/handlers/move_bounds.rs new file mode 100644 index 000000000..e2e461520 --- /dev/null +++ b/crates/assists/src/handlers/move_bounds.rs @@ -0,0 +1,152 @@ +use syntax::{ + ast::{self, edit::AstNodeEdit, make, AstNode, NameOwner, TypeBoundsOwner}, + match_ast, + SyntaxKind::*, + T, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: move_bounds_to_where_clause +// +// Moves inline type bounds to a where clause. +// +// ``` +// fn applyF: FnOnce(T) -> U>(f: F, x: T) -> U { +// f(x) +// } +// ``` +// -> +// ``` +// fn apply(f: F, x: T) -> U where F: FnOnce(T) -> U { +// f(x) +// } +// ``` +pub(crate) fn move_bounds_to_where_clause(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let type_param_list = ctx.find_node_at_offset::()?; + + let mut type_params = type_param_list.type_params(); + if type_params.all(|p| p.type_bound_list().is_none()) { + return None; + } + + let parent = type_param_list.syntax().parent()?; + if parent.children_with_tokens().any(|it| it.kind() == WHERE_CLAUSE) { + return None; + } + + let anchor = match_ast! { + match parent { + ast::Fn(it) => it.body()?.syntax().clone().into(), + ast::Trait(it) => it.assoc_item_list()?.syntax().clone().into(), + ast::Impl(it) => it.assoc_item_list()?.syntax().clone().into(), + ast::Enum(it) => it.variant_list()?.syntax().clone().into(), + ast::Struct(it) => { + it.syntax().children_with_tokens() + .find(|it| it.kind() == RECORD_FIELD_LIST || it.kind() == T![;])? + }, + _ => return None + } + }; + + let target = type_param_list.syntax().text_range(); + acc.add( + AssistId("move_bounds_to_where_clause", AssistKind::RefactorRewrite), + "Move to where clause", + target, + |edit| { + let new_params = type_param_list + .type_params() + .filter(|it| it.type_bound_list().is_some()) + .map(|type_param| { + let without_bounds = type_param.remove_bounds(); + (type_param, without_bounds) + }); + + let new_type_param_list = type_param_list.replace_descendants(new_params); + edit.replace_ast(type_param_list.clone(), new_type_param_list); + + let where_clause = { + let predicates = type_param_list.type_params().filter_map(build_predicate); + make::where_clause(predicates) + }; + + let to_insert = match anchor.prev_sibling_or_token() { + Some(ref elem) if elem.kind() == WHITESPACE => { + format!("{} ", where_clause.syntax()) + } + _ => format!(" {}", where_clause.syntax()), + }; + edit.insert(anchor.text_range().start(), to_insert); + }, + ) +} + +fn build_predicate(param: ast::TypeParam) -> Option { + let path = { + let name_ref = make::name_ref(¶m.name()?.syntax().to_string()); + let segment = make::path_segment(name_ref); + make::path_unqualified(segment) + }; + let predicate = make::where_pred(path, param.type_bound_list()?.bounds()); + Some(predicate) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::check_assist; + + #[test] + fn move_bounds_to_where_clause_fn() { + check_assist( + move_bounds_to_where_clause, + r#" + fn fooF: FnOnce(T) -> T>() {} + "#, + r#" + fn foo() where T: u32, F: FnOnce(T) -> T {} + "#, + ); + } + + #[test] + fn move_bounds_to_where_clause_impl() { + check_assist( + move_bounds_to_where_clause, + r#" + implT> A {} + "#, + r#" + impl A where U: u32 {} + "#, + ); + } + + #[test] + fn move_bounds_to_where_clause_struct() { + check_assist( + move_bounds_to_where_clause, + r#" + struct A<<|>T: Iterator> {} + "#, + r#" + struct A where T: Iterator {} + "#, + ); + } + + #[test] + fn move_bounds_to_where_clause_tuple_struct() { + check_assist( + move_bounds_to_where_clause, + r#" + struct Pair<<|>T: u32>(T, T); + "#, + r#" + struct Pair(T, T) where T: u32; + "#, + ); + } +} diff --git a/crates/assists/src/handlers/move_guard.rs b/crates/assists/src/handlers/move_guard.rs new file mode 100644 index 000000000..452115fe6 --- /dev/null +++ b/crates/assists/src/handlers/move_guard.rs @@ -0,0 +1,293 @@ +use syntax::{ + ast::{edit::AstNodeEdit, make, AstNode, IfExpr, MatchArm}, + SyntaxKind::WHITESPACE, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: move_guard_to_arm_body +// +// Moves match guard into match arm body. +// +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// Action::Move { distance } <|>if distance > 10 => foo(), +// _ => (), +// } +// } +// ``` +// -> +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// Action::Move { distance } => if distance > 10 { +// foo() +// }, +// _ => (), +// } +// } +// ``` +pub(crate) fn move_guard_to_arm_body(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let match_arm = ctx.find_node_at_offset::()?; + let guard = match_arm.guard()?; + let space_before_guard = guard.syntax().prev_sibling_or_token(); + + let guard_condition = guard.expr()?; + let arm_expr = match_arm.expr()?; + let if_expr = make::expr_if( + make::condition(guard_condition, None), + make::block_expr(None, Some(arm_expr.clone())), + ) + .indent(arm_expr.indent_level()); + + let target = guard.syntax().text_range(); + acc.add( + AssistId("move_guard_to_arm_body", AssistKind::RefactorRewrite), + "Move guard to arm body", + target, + |edit| { + match space_before_guard { + Some(element) if element.kind() == WHITESPACE => { + edit.delete(element.text_range()); + } + _ => (), + }; + + edit.delete(guard.syntax().text_range()); + edit.replace_ast(arm_expr, if_expr); + }, + ) +} + +// Assist: move_arm_cond_to_match_guard +// +// Moves if expression from match arm body into a guard. +// +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// Action::Move { distance } => <|>if distance > 10 { foo() }, +// _ => (), +// } +// } +// ``` +// -> +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// Action::Move { distance } if distance > 10 => foo(), +// _ => (), +// } +// } +// ``` +pub(crate) fn move_arm_cond_to_match_guard(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let match_arm: MatchArm = ctx.find_node_at_offset::()?; + let match_pat = match_arm.pat()?; + + let arm_body = match_arm.expr()?; + let if_expr: IfExpr = IfExpr::cast(arm_body.syntax().clone())?; + let cond = if_expr.condition()?; + let then_block = if_expr.then_branch()?; + + // Not support if with else branch + if if_expr.else_branch().is_some() { + return None; + } + // Not support moving if let to arm guard + if cond.pat().is_some() { + return None; + } + + let buf = format!(" if {}", cond.syntax().text()); + + let target = if_expr.syntax().text_range(); + acc.add( + AssistId("move_arm_cond_to_match_guard", AssistKind::RefactorRewrite), + "Move condition to match guard", + target, + |edit| { + let then_only_expr = then_block.statements().next().is_none(); + + match &then_block.expr() { + Some(then_expr) if then_only_expr => { + edit.replace(if_expr.syntax().text_range(), then_expr.syntax().text()) + } + _ => edit.replace(if_expr.syntax().text_range(), then_block.syntax().text()), + } + + edit.insert(match_pat.syntax().text_range().end(), buf); + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + #[test] + fn move_guard_to_arm_body_target() { + check_assist_target( + move_guard_to_arm_body, + r#" +fn main() { + match 92 { + x <|>if x > 10 => false, + _ => true + } +} +"#, + r#"if x > 10"#, + ); + } + + #[test] + fn move_guard_to_arm_body_works() { + check_assist( + move_guard_to_arm_body, + r#" +fn main() { + match 92 { + x <|>if x > 10 => false, + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x => if x > 10 { + false + }, + _ => true + } +} +"#, + ); + } + + #[test] + fn move_guard_to_arm_body_works_complex_match() { + check_assist( + move_guard_to_arm_body, + r#" +fn main() { + match 92 { + <|>x @ 4 | x @ 5 if x > 5 => true, + _ => false + } +} +"#, + r#" +fn main() { + match 92 { + x @ 4 | x @ 5 => if x > 5 { + true + }, + _ => false + } +} +"#, + ); + } + + #[test] + fn move_arm_cond_to_match_guard_works() { + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => if x > 10 { <|>false }, + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x if x > 10 => false, + _ => true + } +} +"#, + ); + } + + #[test] + fn move_arm_cond_to_match_guard_if_let_not_works() { + check_assist_not_applicable( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => if let 62 = x { <|>false }, + _ => true + } +} +"#, + ); + } + + #[test] + fn move_arm_cond_to_match_guard_if_empty_body_works() { + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => if x > 10 { <|> }, + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x if x > 10 => { }, + _ => true + } +} +"#, + ); + } + + #[test] + fn move_arm_cond_to_match_guard_if_multiline_body_works() { + check_assist( + move_arm_cond_to_match_guard, + r#" +fn main() { + match 92 { + x => if x > 10 { + 92;<|> + false + }, + _ => true + } +} +"#, + r#" +fn main() { + match 92 { + x if x > 10 => { + 92; + false + }, + _ => true + } +} +"#, + ); + } +} diff --git a/crates/assists/src/handlers/raw_string.rs b/crates/assists/src/handlers/raw_string.rs new file mode 100644 index 000000000..9ddd116e0 --- /dev/null +++ b/crates/assists/src/handlers/raw_string.rs @@ -0,0 +1,504 @@ +use std::borrow::Cow; + +use syntax::{ + ast::{self, HasQuotes, HasStringValue}, + AstToken, + SyntaxKind::{RAW_STRING, STRING}, + TextRange, TextSize, +}; +use test_utils::mark; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: make_raw_string +// +// Adds `r#` to a plain string literal. +// +// ``` +// fn main() { +// "Hello,<|> World!"; +// } +// ``` +// -> +// ``` +// fn main() { +// r#"Hello, World!"#; +// } +// ``` +pub(crate) fn make_raw_string(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let token = ctx.find_token_at_offset(STRING).and_then(ast::String::cast)?; + let value = token.value()?; + let target = token.syntax().text_range(); + acc.add( + AssistId("make_raw_string", AssistKind::RefactorRewrite), + "Rewrite as raw string", + target, + |edit| { + let hashes = "#".repeat(required_hashes(&value).max(1)); + if matches!(value, Cow::Borrowed(_)) { + // Avoid replacing the whole string to better position the cursor. + edit.insert(token.syntax().text_range().start(), format!("r{}", hashes)); + edit.insert(token.syntax().text_range().end(), format!("{}", hashes)); + } else { + edit.replace( + token.syntax().text_range(), + format!("r{}\"{}\"{}", hashes, value, hashes), + ); + } + }, + ) +} + +// Assist: make_usual_string +// +// Turns a raw string into a plain string. +// +// ``` +// fn main() { +// r#"Hello,<|> "World!""#; +// } +// ``` +// -> +// ``` +// fn main() { +// "Hello, \"World!\""; +// } +// ``` +pub(crate) fn make_usual_string(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let token = ctx.find_token_at_offset(RAW_STRING).and_then(ast::RawString::cast)?; + let value = token.value()?; + let target = token.syntax().text_range(); + acc.add( + AssistId("make_usual_string", AssistKind::RefactorRewrite), + "Rewrite as regular string", + target, + |edit| { + // parse inside string to escape `"` + let escaped = value.escape_default().to_string(); + if let Some(offsets) = token.quote_offsets() { + if token.text()[offsets.contents - token.syntax().text_range().start()] == escaped { + edit.replace(offsets.quotes.0, "\""); + edit.replace(offsets.quotes.1, "\""); + return; + } + } + + edit.replace(token.syntax().text_range(), format!("\"{}\"", escaped)); + }, + ) +} + +// Assist: add_hash +// +// Adds a hash to a raw string literal. +// +// ``` +// fn main() { +// r#"Hello,<|> World!"#; +// } +// ``` +// -> +// ``` +// fn main() { +// r##"Hello, World!"##; +// } +// ``` +pub(crate) fn add_hash(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let token = ctx.find_token_at_offset(RAW_STRING)?; + let target = token.text_range(); + acc.add(AssistId("add_hash", AssistKind::Refactor), "Add #", target, |edit| { + edit.insert(token.text_range().start() + TextSize::of('r'), "#"); + edit.insert(token.text_range().end(), "#"); + }) +} + +// Assist: remove_hash +// +// Removes a hash from a raw string literal. +// +// ``` +// fn main() { +// r#"Hello,<|> World!"#; +// } +// ``` +// -> +// ``` +// fn main() { +// r"Hello, World!"; +// } +// ``` +pub(crate) fn remove_hash(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let token = ctx.find_token_at_offset(RAW_STRING).and_then(ast::RawString::cast)?; + + let text = token.text().as_str(); + if !text.starts_with("r#") && text.ends_with('#') { + return None; + } + + let existing_hashes = text.chars().skip(1).take_while(|&it| it == '#').count(); + + let text_range = token.syntax().text_range(); + let internal_text = &text[token.text_range_between_quotes()? - text_range.start()]; + + if existing_hashes == required_hashes(internal_text) { + mark::hit!(cant_remove_required_hash); + return None; + } + + acc.add(AssistId("remove_hash", AssistKind::RefactorRewrite), "Remove #", text_range, |edit| { + edit.delete(TextRange::at(text_range.start() + TextSize::of('r'), TextSize::of('#'))); + edit.delete(TextRange::new(text_range.end() - TextSize::of('#'), text_range.end())); + }) +} + +fn required_hashes(s: &str) -> usize { + let mut res = 0usize; + for idx in s.match_indices('"').map(|(i, _)| i) { + let (_, sub) = s.split_at(idx + 1); + let n_hashes = sub.chars().take_while(|c| *c == '#').count(); + res = res.max(n_hashes + 1) + } + res +} + +#[test] +fn test_required_hashes() { + assert_eq!(0, required_hashes("abc")); + assert_eq!(0, required_hashes("###")); + assert_eq!(1, required_hashes("\"")); + assert_eq!(2, required_hashes("\"#abc")); + assert_eq!(0, required_hashes("#abc")); + assert_eq!(3, required_hashes("#ab\"##c")); + assert_eq!(5, required_hashes("#ab\"##\"####c")); +} + +#[cfg(test)] +mod tests { + use test_utils::mark; + + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + use super::*; + + #[test] + fn make_raw_string_target() { + check_assist_target( + make_raw_string, + r#" + fn f() { + let s = <|>"random\nstring"; + } + "#, + r#""random\nstring""#, + ); + } + + #[test] + fn make_raw_string_works() { + check_assist( + make_raw_string, + r#" +fn f() { + let s = <|>"random\nstring"; +} +"#, + r##" +fn f() { + let s = r#"random +string"#; +} +"##, + ) + } + + #[test] + fn make_raw_string_works_inside_macros() { + check_assist( + make_raw_string, + r#" + fn f() { + format!(<|>"x = {}", 92) + } + "#, + r##" + fn f() { + format!(r#"x = {}"#, 92) + } + "##, + ) + } + + #[test] + fn make_raw_string_hashes_inside_works() { + check_assist( + make_raw_string, + r###" +fn f() { + let s = <|>"#random##\nstring"; +} +"###, + r####" +fn f() { + let s = r#"#random## +string"#; +} +"####, + ) + } + + #[test] + fn make_raw_string_closing_hashes_inside_works() { + check_assist( + make_raw_string, + r###" +fn f() { + let s = <|>"#random\"##\nstring"; +} +"###, + r####" +fn f() { + let s = r###"#random"## +string"###; +} +"####, + ) + } + + #[test] + fn make_raw_string_nothing_to_unescape_works() { + check_assist( + make_raw_string, + r#" + fn f() { + let s = <|>"random string"; + } + "#, + r##" + fn f() { + let s = r#"random string"#; + } + "##, + ) + } + + #[test] + fn make_raw_string_not_works_on_partial_string() { + check_assist_not_applicable( + make_raw_string, + r#" + fn f() { + let s = "foo<|> + } + "#, + ) + } + + #[test] + fn make_usual_string_not_works_on_partial_string() { + check_assist_not_applicable( + make_usual_string, + r#" + fn main() { + let s = r#"bar<|> + } + "#, + ) + } + + #[test] + fn add_hash_target() { + check_assist_target( + add_hash, + r#" + fn f() { + let s = <|>r"random string"; + } + "#, + r#"r"random string""#, + ); + } + + #[test] + fn add_hash_works() { + check_assist( + add_hash, + r#" + fn f() { + let s = <|>r"random string"; + } + "#, + r##" + fn f() { + let s = r#"random string"#; + } + "##, + ) + } + + #[test] + fn add_more_hash_works() { + check_assist( + add_hash, + r##" + fn f() { + let s = <|>r#"random"string"#; + } + "##, + r###" + fn f() { + let s = r##"random"string"##; + } + "###, + ) + } + + #[test] + fn add_hash_not_works() { + check_assist_not_applicable( + add_hash, + r#" + fn f() { + let s = <|>"random string"; + } + "#, + ); + } + + #[test] + fn remove_hash_target() { + check_assist_target( + remove_hash, + r##" + fn f() { + let s = <|>r#"random string"#; + } + "##, + r##"r#"random string"#"##, + ); + } + + #[test] + fn remove_hash_works() { + check_assist( + remove_hash, + r##"fn f() { let s = <|>r#"random string"#; }"##, + r#"fn f() { let s = r"random string"; }"#, + ) + } + + #[test] + fn cant_remove_required_hash() { + mark::check!(cant_remove_required_hash); + check_assist_not_applicable( + remove_hash, + r##" + fn f() { + let s = <|>r#"random"str"ing"#; + } + "##, + ) + } + + #[test] + fn remove_more_hash_works() { + check_assist( + remove_hash, + r###" + fn f() { + let s = <|>r##"random string"##; + } + "###, + r##" + fn f() { + let s = r#"random string"#; + } + "##, + ) + } + + #[test] + fn remove_hash_doesnt_work() { + check_assist_not_applicable(remove_hash, r#"fn f() { let s = <|>"random string"; }"#); + } + + #[test] + fn remove_hash_no_hash_doesnt_work() { + check_assist_not_applicable(remove_hash, r#"fn f() { let s = <|>r"random string"; }"#); + } + + #[test] + fn make_usual_string_target() { + check_assist_target( + make_usual_string, + r##" + fn f() { + let s = <|>r#"random string"#; + } + "##, + r##"r#"random string"#"##, + ); + } + + #[test] + fn make_usual_string_works() { + check_assist( + make_usual_string, + r##" + fn f() { + let s = <|>r#"random string"#; + } + "##, + r#" + fn f() { + let s = "random string"; + } + "#, + ) + } + + #[test] + fn make_usual_string_with_quote_works() { + check_assist( + make_usual_string, + r##" + fn f() { + let s = <|>r#"random"str"ing"#; + } + "##, + r#" + fn f() { + let s = "random\"str\"ing"; + } + "#, + ) + } + + #[test] + fn make_usual_string_more_hash_works() { + check_assist( + make_usual_string, + r###" + fn f() { + let s = <|>r##"random string"##; + } + "###, + r##" + fn f() { + let s = "random string"; + } + "##, + ) + } + + #[test] + fn make_usual_string_not_works() { + check_assist_not_applicable( + make_usual_string, + r#" + fn f() { + let s = <|>"random string"; + } + "#, + ); + } +} diff --git a/crates/assists/src/handlers/remove_dbg.rs b/crates/assists/src/handlers/remove_dbg.rs new file mode 100644 index 000000000..f3dcca534 --- /dev/null +++ b/crates/assists/src/handlers/remove_dbg.rs @@ -0,0 +1,205 @@ +use syntax::{ + ast::{self, AstNode}, + TextRange, TextSize, T, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: remove_dbg +// +// Removes `dbg!()` macro call. +// +// ``` +// fn main() { +// <|>dbg!(92); +// } +// ``` +// -> +// ``` +// fn main() { +// 92; +// } +// ``` +pub(crate) fn remove_dbg(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let macro_call = ctx.find_node_at_offset::()?; + + if !is_valid_macrocall(¯o_call, "dbg")? { + return None; + } + + let is_leaf = macro_call.syntax().next_sibling().is_none(); + + let macro_end = if macro_call.semicolon_token().is_some() { + macro_call.syntax().text_range().end() - TextSize::of(';') + } else { + macro_call.syntax().text_range().end() + }; + + // macro_range determines what will be deleted and replaced with macro_content + let macro_range = TextRange::new(macro_call.syntax().text_range().start(), macro_end); + let paste_instead_of_dbg = { + let text = macro_call.token_tree()?.syntax().text(); + + // leafiness determines if we should include the parenthesis or not + let slice_index: TextRange = if is_leaf { + // leaf means - we can extract the contents of the dbg! in text + TextRange::new(TextSize::of('('), text.len() - TextSize::of(')')) + } else { + // not leaf - means we should keep the parens + TextRange::up_to(text.len()) + }; + text.slice(slice_index).to_string() + }; + + let target = macro_call.syntax().text_range(); + acc.add(AssistId("remove_dbg", AssistKind::Refactor), "Remove dbg!()", target, |builder| { + builder.replace(macro_range, paste_instead_of_dbg); + }) +} + +/// Verifies that the given macro_call actually matches the given name +/// and contains proper ending tokens +fn is_valid_macrocall(macro_call: &ast::MacroCall, macro_name: &str) -> Option { + let path = macro_call.path()?; + let name_ref = path.segment()?.name_ref()?; + + // Make sure it is actually a dbg-macro call, dbg followed by ! + let excl = path.syntax().next_sibling_or_token()?; + + if name_ref.text() != macro_name || excl.kind() != T![!] { + return None; + } + + let node = macro_call.token_tree()?.syntax().clone(); + let first_child = node.first_child_or_token()?; + let last_child = node.last_child_or_token()?; + + match (first_child.kind(), last_child.kind()) { + (T!['('], T![')']) | (T!['['], T![']']) | (T!['{'], T!['}']) => Some(true), + _ => Some(false), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + #[test] + fn test_remove_dbg() { + check_assist(remove_dbg, "<|>dbg!(1 + 1)", "1 + 1"); + + check_assist(remove_dbg, "dbg!<|>((1 + 1))", "(1 + 1)"); + + check_assist(remove_dbg, "dbg!(1 <|>+ 1)", "1 + 1"); + + check_assist(remove_dbg, "let _ = <|>dbg!(1 + 1)", "let _ = 1 + 1"); + + check_assist( + remove_dbg, + " +fn foo(n: usize) { + if let Some(_) = dbg!(n.<|>checked_sub(4)) { + // ... + } +} +", + " +fn foo(n: usize) { + if let Some(_) = n.checked_sub(4) { + // ... + } +} +", + ); + } + + #[test] + fn test_remove_dbg_with_brackets_and_braces() { + check_assist(remove_dbg, "dbg![<|>1 + 1]", "1 + 1"); + check_assist(remove_dbg, "dbg!{<|>1 + 1}", "1 + 1"); + } + + #[test] + fn test_remove_dbg_not_applicable() { + check_assist_not_applicable(remove_dbg, "<|>vec![1, 2, 3]"); + check_assist_not_applicable(remove_dbg, "<|>dbg(5, 6, 7)"); + check_assist_not_applicable(remove_dbg, "<|>dbg!(5, 6, 7"); + } + + #[test] + fn test_remove_dbg_target() { + check_assist_target( + remove_dbg, + " +fn foo(n: usize) { + if let Some(_) = dbg!(n.<|>checked_sub(4)) { + // ... + } +} +", + "dbg!(n.checked_sub(4))", + ); + } + + #[test] + fn test_remove_dbg_keep_semicolon() { + // https://github.com/rust-analyzer/rust-analyzer/issues/5129#issuecomment-651399779 + // not quite though + // adding a comment at the end of the line makes + // the ast::MacroCall to include the semicolon at the end + check_assist( + remove_dbg, + r#"let res = <|>dbg!(1 * 20); // needless comment"#, + r#"let res = 1 * 20; // needless comment"#, + ); + } + + #[test] + fn test_remove_dbg_keep_expression() { + check_assist( + remove_dbg, + r#"let res = <|>dbg!(a + b).foo();"#, + r#"let res = (a + b).foo();"#, + ); + } + + #[test] + fn test_remove_dbg_from_inside_fn() { + check_assist_target( + remove_dbg, + r#" +fn square(x: u32) -> u32 { + x * x +} + +fn main() { + let x = square(dbg<|>!(5 + 10)); + println!("{}", x); +}"#, + "dbg!(5 + 10)", + ); + + check_assist( + remove_dbg, + r#" +fn square(x: u32) -> u32 { + x * x +} + +fn main() { + let x = square(dbg<|>!(5 + 10)); + println!("{}", x); +}"#, + r#" +fn square(x: u32) -> u32 { + x * x +} + +fn main() { + let x = square(5 + 10); + println!("{}", x); +}"#, + ); + } +} diff --git a/crates/assists/src/handlers/remove_mut.rs b/crates/assists/src/handlers/remove_mut.rs new file mode 100644 index 000000000..44f41daa9 --- /dev/null +++ b/crates/assists/src/handlers/remove_mut.rs @@ -0,0 +1,37 @@ +use syntax::{SyntaxKind, TextRange, T}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: remove_mut +// +// Removes the `mut` keyword. +// +// ``` +// impl Walrus { +// fn feed(&mut<|> self, amount: u32) {} +// } +// ``` +// -> +// ``` +// impl Walrus { +// fn feed(&self, amount: u32) {} +// } +// ``` +pub(crate) fn remove_mut(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let mut_token = ctx.find_token_at_offset(T![mut])?; + let delete_from = mut_token.text_range().start(); + let delete_to = match mut_token.next_token() { + Some(it) if it.kind() == SyntaxKind::WHITESPACE => it.text_range().end(), + _ => mut_token.text_range().end(), + }; + + let target = mut_token.text_range(); + acc.add( + AssistId("remove_mut", AssistKind::Refactor), + "Remove `mut` keyword", + target, + |builder| { + builder.delete(TextRange::new(delete_from, delete_to)); + }, + ) +} diff --git a/crates/assists/src/handlers/reorder_fields.rs b/crates/assists/src/handlers/reorder_fields.rs new file mode 100644 index 000000000..527f457a7 --- /dev/null +++ b/crates/assists/src/handlers/reorder_fields.rs @@ -0,0 +1,220 @@ +use itertools::Itertools; +use rustc_hash::FxHashMap; + +use hir::{Adt, ModuleDef, PathResolution, Semantics, Struct}; +use ide_db::RootDatabase; +use syntax::{algo, ast, match_ast, AstNode, SyntaxKind, SyntaxKind::*, SyntaxNode}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: reorder_fields +// +// Reorder the fields of record literals and record patterns in the same order as in +// the definition. +// +// ``` +// struct Foo {foo: i32, bar: i32}; +// const test: Foo = <|>Foo {bar: 0, foo: 1} +// ``` +// -> +// ``` +// struct Foo {foo: i32, bar: i32}; +// const test: Foo = Foo {foo: 1, bar: 0} +// ``` +// +pub(crate) fn reorder_fields(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + reorder::(acc, ctx).or_else(|| reorder::(acc, ctx)) +} + +fn reorder(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let record = ctx.find_node_at_offset::()?; + let path = record.syntax().children().find_map(ast::Path::cast)?; + + let ranks = compute_fields_ranks(&path, &ctx)?; + + let fields = get_fields(&record.syntax()); + let sorted_fields = sorted_by_rank(&fields, |node| { + *ranks.get(&get_field_name(node)).unwrap_or(&usize::max_value()) + }); + + if sorted_fields == fields { + return None; + } + + let target = record.syntax().text_range(); + acc.add( + AssistId("reorder_fields", AssistKind::RefactorRewrite), + "Reorder record fields", + target, + |edit| { + for (old, new) in fields.iter().zip(&sorted_fields) { + algo::diff(old, new).into_text_edit(edit.text_edit_builder()); + } + }, + ) +} + +fn get_fields_kind(node: &SyntaxNode) -> Vec { + match node.kind() { + RECORD_EXPR => vec![RECORD_EXPR_FIELD], + RECORD_PAT => vec![RECORD_PAT_FIELD, IDENT_PAT], + _ => vec![], + } +} + +fn get_field_name(node: &SyntaxNode) -> String { + let res = match_ast! { + match node { + ast::RecordExprField(field) => field.field_name().map(|it| it.to_string()), + ast::RecordPatField(field) => field.field_name().map(|it| it.to_string()), + _ => None, + } + }; + res.unwrap_or_default() +} + +fn get_fields(record: &SyntaxNode) -> Vec { + let kinds = get_fields_kind(record); + record.children().flat_map(|n| n.children()).filter(|n| kinds.contains(&n.kind())).collect() +} + +fn sorted_by_rank( + fields: &[SyntaxNode], + get_rank: impl Fn(&SyntaxNode) -> usize, +) -> Vec { + fields.iter().cloned().sorted_by_key(get_rank).collect() +} + +fn struct_definition(path: &ast::Path, sema: &Semantics) -> Option { + match sema.resolve_path(path) { + Some(PathResolution::Def(ModuleDef::Adt(Adt::Struct(s)))) => Some(s), + _ => None, + } +} + +fn compute_fields_ranks(path: &ast::Path, ctx: &AssistContext) -> Option> { + Some( + struct_definition(path, &ctx.sema)? + .fields(ctx.db()) + .iter() + .enumerate() + .map(|(idx, field)| (field.name(ctx.db()).to_string(), idx)) + .collect(), + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn not_applicable_if_sorted() { + check_assist_not_applicable( + reorder_fields, + r#" + struct Foo { + foo: i32, + bar: i32, + } + + const test: Foo = <|>Foo { foo: 0, bar: 0 }; + "#, + ) + } + + #[test] + fn trivial_empty_fields() { + check_assist_not_applicable( + reorder_fields, + r#" + struct Foo {}; + const test: Foo = <|>Foo {} + "#, + ) + } + + #[test] + fn reorder_struct_fields() { + check_assist( + reorder_fields, + r#" + struct Foo {foo: i32, bar: i32}; + const test: Foo = <|>Foo {bar: 0, foo: 1} + "#, + r#" + struct Foo {foo: i32, bar: i32}; + const test: Foo = Foo {foo: 1, bar: 0} + "#, + ) + } + + #[test] + fn reorder_struct_pattern() { + check_assist( + reorder_fields, + r#" + struct Foo { foo: i64, bar: i64, baz: i64 } + + fn f(f: Foo) -> { + match f { + <|>Foo { baz: 0, ref mut bar, .. } => (), + _ => () + } + } + "#, + r#" + struct Foo { foo: i64, bar: i64, baz: i64 } + + fn f(f: Foo) -> { + match f { + Foo { ref mut bar, baz: 0, .. } => (), + _ => () + } + } + "#, + ) + } + + #[test] + fn reorder_with_extra_field() { + check_assist( + reorder_fields, + r#" + struct Foo { + foo: String, + bar: String, + } + + impl Foo { + fn new() -> Foo { + let foo = String::new(); + <|>Foo { + bar: foo.clone(), + extra: "Extra field", + foo, + } + } + } + "#, + r#" + struct Foo { + foo: String, + bar: String, + } + + impl Foo { + fn new() -> Foo { + let foo = String::new(); + Foo { + foo, + bar: foo.clone(), + extra: "Extra field", + } + } + } + "#, + ) + } +} diff --git a/crates/assists/src/handlers/replace_if_let_with_match.rs b/crates/assists/src/handlers/replace_if_let_with_match.rs new file mode 100644 index 000000000..79097621e --- /dev/null +++ b/crates/assists/src/handlers/replace_if_let_with_match.rs @@ -0,0 +1,257 @@ +use syntax::{ + ast::{ + self, + edit::{AstNodeEdit, IndentLevel}, + make, + }, + AstNode, +}; + +use crate::{ + utils::{unwrap_trivial_block, TryEnum}, + AssistContext, AssistId, AssistKind, Assists, +}; + +// Assist: replace_if_let_with_match +// +// Replaces `if let` with an else branch with a `match` expression. +// +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// <|>if let Action::Move { distance } = action { +// foo(distance) +// } else { +// bar() +// } +// } +// ``` +// -> +// ``` +// enum Action { Move { distance: u32 }, Stop } +// +// fn handle(action: Action) { +// match action { +// Action::Move { distance } => foo(distance), +// _ => bar(), +// } +// } +// ``` +pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let if_expr: ast::IfExpr = ctx.find_node_at_offset()?; + let cond = if_expr.condition()?; + let pat = cond.pat()?; + let expr = cond.expr()?; + let then_block = if_expr.then_branch()?; + let else_block = match if_expr.else_branch()? { + ast::ElseBranch::Block(it) => it, + ast::ElseBranch::IfExpr(_) => return None, + }; + + let target = if_expr.syntax().text_range(); + acc.add( + AssistId("replace_if_let_with_match", AssistKind::RefactorRewrite), + "Replace with match", + target, + move |edit| { + let match_expr = { + let then_arm = { + let then_block = then_block.reset_indent().indent(IndentLevel(1)); + let then_expr = unwrap_trivial_block(then_block); + make::match_arm(vec![pat.clone()], then_expr) + }; + let else_arm = { + let pattern = ctx + .sema + .type_of_pat(&pat) + .and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty)) + .map(|it| it.sad_pattern()) + .unwrap_or_else(|| make::wildcard_pat().into()); + let else_expr = unwrap_trivial_block(else_block); + make::match_arm(vec![pattern], else_expr) + }; + let match_expr = + make::expr_match(expr, make::match_arm_list(vec![then_arm, else_arm])); + match_expr.indent(IndentLevel::from_node(if_expr.syntax())) + }; + + edit.replace_ast::(if_expr.into(), match_expr); + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_target}; + + #[test] + fn test_replace_if_let_with_match_unwraps_simple_expressions() { + check_assist( + replace_if_let_with_match, + r#" +impl VariantData { + pub fn is_struct(&self) -> bool { + if <|>let VariantData::Struct(..) = *self { + true + } else { + false + } + } +} "#, + r#" +impl VariantData { + pub fn is_struct(&self) -> bool { + match *self { + VariantData::Struct(..) => true, + _ => false, + } + } +} "#, + ) + } + + #[test] + fn test_replace_if_let_with_match_doesnt_unwrap_multiline_expressions() { + check_assist( + replace_if_let_with_match, + r#" +fn foo() { + if <|>let VariantData::Struct(..) = a { + bar( + 123 + ) + } else { + false + } +} "#, + r#" +fn foo() { + match a { + VariantData::Struct(..) => { + bar( + 123 + ) + } + _ => false, + } +} "#, + ) + } + + #[test] + fn replace_if_let_with_match_target() { + check_assist_target( + replace_if_let_with_match, + r#" +impl VariantData { + pub fn is_struct(&self) -> bool { + if <|>let VariantData::Struct(..) = *self { + true + } else { + false + } + } +} "#, + "if let VariantData::Struct(..) = *self { + true + } else { + false + }", + ); + } + + #[test] + fn special_case_option() { + check_assist( + replace_if_let_with_match, + r#" +enum Option { Some(T), None } +use Option::*; + +fn foo(x: Option) { + <|>if let Some(x) = x { + println!("{}", x) + } else { + println!("none") + } +} + "#, + r#" +enum Option { Some(T), None } +use Option::*; + +fn foo(x: Option) { + match x { + Some(x) => println!("{}", x), + None => println!("none"), + } +} + "#, + ); + } + + #[test] + fn special_case_result() { + check_assist( + replace_if_let_with_match, + r#" +enum Result { Ok(T), Err(E) } +use Result::*; + +fn foo(x: Result) { + <|>if let Ok(x) = x { + println!("{}", x) + } else { + println!("none") + } +} + "#, + r#" +enum Result { Ok(T), Err(E) } +use Result::*; + +fn foo(x: Result) { + match x { + Ok(x) => println!("{}", x), + Err(_) => println!("none"), + } +} + "#, + ); + } + + #[test] + fn nested_indent() { + check_assist( + replace_if_let_with_match, + r#" +fn main() { + if true { + <|>if let Ok(rel_path) = path.strip_prefix(root_path) { + let rel_path = RelativePathBuf::from_path(rel_path).ok()?; + Some((*id, rel_path)) + } else { + None + } + } +} +"#, + r#" +fn main() { + if true { + match path.strip_prefix(root_path) { + Ok(rel_path) => { + let rel_path = RelativePathBuf::from_path(rel_path).ok()?; + Some((*id, rel_path)) + } + _ => None, + } + } +} +"#, + ) + } +} diff --git a/crates/assists/src/handlers/replace_let_with_if_let.rs b/crates/assists/src/handlers/replace_let_with_if_let.rs new file mode 100644 index 000000000..ed6d0c29b --- /dev/null +++ b/crates/assists/src/handlers/replace_let_with_if_let.rs @@ -0,0 +1,100 @@ +use std::iter::once; + +use syntax::{ + ast::{ + self, + edit::{AstNodeEdit, IndentLevel}, + make, + }, + AstNode, T, +}; + +use crate::{utils::TryEnum, AssistContext, AssistId, AssistKind, Assists}; + +// Assist: replace_let_with_if_let +// +// Replaces `let` with an `if-let`. +// +// ``` +// # enum Option { Some(T), None } +// +// fn main(action: Action) { +// <|>let x = compute(); +// } +// +// fn compute() -> Option { None } +// ``` +// -> +// ``` +// # enum Option { Some(T), None } +// +// fn main(action: Action) { +// if let Some(x) = compute() { +// } +// } +// +// fn compute() -> Option { None } +// ``` +pub(crate) fn replace_let_with_if_let(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let let_kw = ctx.find_token_at_offset(T![let])?; + let let_stmt = let_kw.ancestors().find_map(ast::LetStmt::cast)?; + let init = let_stmt.initializer()?; + let original_pat = let_stmt.pat()?; + let ty = ctx.sema.type_of_expr(&init)?; + let happy_variant = TryEnum::from_ty(&ctx.sema, &ty).map(|it| it.happy_case()); + + let target = let_kw.text_range(); + acc.add( + AssistId("replace_let_with_if_let", AssistKind::RefactorRewrite), + "Replace with if-let", + target, + |edit| { + let with_placeholder: ast::Pat = match happy_variant { + None => make::wildcard_pat().into(), + Some(var_name) => make::tuple_struct_pat( + make::path_unqualified(make::path_segment(make::name_ref(var_name))), + once(make::wildcard_pat().into()), + ) + .into(), + }; + let block = + make::block_expr(None, None).indent(IndentLevel::from_node(let_stmt.syntax())); + let if_ = make::expr_if(make::condition(init, Some(with_placeholder)), block); + let stmt = make::expr_stmt(if_); + + let placeholder = stmt.syntax().descendants().find_map(ast::WildcardPat::cast).unwrap(); + let stmt = stmt.replace_descendant(placeholder.into(), original_pat); + + edit.replace_ast(ast::Stmt::from(let_stmt), ast::Stmt::from(stmt)); + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::check_assist; + + use super::*; + + #[test] + fn replace_let_unknown_enum() { + check_assist( + replace_let_with_if_let, + r" +enum E { X(T), Y(T) } + +fn main() { + <|>let x = E::X(92); +} + ", + r" +enum E { X(T), Y(T) } + +fn main() { + if let x = E::X(92) { + } +} + ", + ) + } +} diff --git a/crates/assists/src/handlers/replace_qualified_name_with_use.rs b/crates/assists/src/handlers/replace_qualified_name_with_use.rs new file mode 100644 index 000000000..011bf1106 --- /dev/null +++ b/crates/assists/src/handlers/replace_qualified_name_with_use.rs @@ -0,0 +1,688 @@ +use hir; +use syntax::{algo::SyntaxRewriter, ast, match_ast, AstNode, SmolStr, SyntaxNode}; + +use crate::{ + utils::{find_insert_use_container, insert_use_statement}, + AssistContext, AssistId, AssistKind, Assists, +}; + +// Assist: replace_qualified_name_with_use +// +// Adds a use statement for a given fully-qualified name. +// +// ``` +// fn process(map: std::collections::<|>HashMap) {} +// ``` +// -> +// ``` +// use std::collections::HashMap; +// +// fn process(map: HashMap) {} +// ``` +pub(crate) fn replace_qualified_name_with_use( + acc: &mut Assists, + ctx: &AssistContext, +) -> Option<()> { + let path: ast::Path = ctx.find_node_at_offset()?; + // We don't want to mess with use statements + if path.syntax().ancestors().find_map(ast::Use::cast).is_some() { + return None; + } + + let hir_path = ctx.sema.lower_path(&path)?; + let segments = collect_hir_path_segments(&hir_path)?; + if segments.len() < 2 { + return None; + } + + let target = path.syntax().text_range(); + acc.add( + AssistId("replace_qualified_name_with_use", AssistKind::RefactorRewrite), + "Replace qualified path with use", + target, + |builder| { + let path_to_import = hir_path.mod_path().clone(); + let container = match find_insert_use_container(path.syntax(), ctx) { + Some(c) => c, + None => return, + }; + insert_use_statement(path.syntax(), &path_to_import, ctx, builder.text_edit_builder()); + + // Now that we've brought the name into scope, re-qualify all paths that could be + // affected (that is, all paths inside the node we added the `use` to). + let mut rewriter = SyntaxRewriter::default(); + let syntax = container.either(|l| l.syntax().clone(), |r| r.syntax().clone()); + shorten_paths(&mut rewriter, syntax, path); + builder.rewrite(rewriter); + }, + ) +} + +fn collect_hir_path_segments(path: &hir::Path) -> Option> { + let mut ps = Vec::::with_capacity(10); + match path.kind() { + hir::PathKind::Abs => ps.push("".into()), + hir::PathKind::Crate => ps.push("crate".into()), + hir::PathKind::Plain => {} + hir::PathKind::Super(0) => ps.push("self".into()), + hir::PathKind::Super(lvl) => { + let mut chain = "super".to_string(); + for _ in 0..*lvl { + chain += "::super"; + } + ps.push(chain.into()); + } + hir::PathKind::DollarCrate(_) => return None, + } + ps.extend(path.segments().iter().map(|it| it.name.to_string().into())); + Some(ps) +} + +/// Adds replacements to `re` that shorten `path` in all descendants of `node`. +fn shorten_paths(rewriter: &mut SyntaxRewriter<'static>, node: SyntaxNode, path: ast::Path) { + for child in node.children() { + match_ast! { + match child { + // Don't modify `use` items, as this can break the `use` item when injecting a new + // import into the use tree. + ast::Use(_it) => continue, + // Don't descend into submodules, they don't have the same `use` items in scope. + ast::Module(_it) => continue, + + ast::Path(p) => { + match maybe_replace_path(rewriter, p.clone(), path.clone()) { + Some(()) => {}, + None => shorten_paths(rewriter, p.syntax().clone(), path.clone()), + } + }, + _ => shorten_paths(rewriter, child, path.clone()), + } + } + } +} + +fn maybe_replace_path( + rewriter: &mut SyntaxRewriter<'static>, + path: ast::Path, + target: ast::Path, +) -> Option<()> { + if !path_eq(path.clone(), target) { + return None; + } + + // Shorten `path`, leaving only its last segment. + if let Some(parent) = path.qualifier() { + rewriter.delete(parent.syntax()); + } + if let Some(double_colon) = path.coloncolon_token() { + rewriter.delete(&double_colon); + } + + Some(()) +} + +fn path_eq(lhs: ast::Path, rhs: ast::Path) -> bool { + let mut lhs_curr = lhs; + let mut rhs_curr = rhs; + loop { + match (lhs_curr.segment(), rhs_curr.segment()) { + (Some(lhs), Some(rhs)) if lhs.syntax().text() == rhs.syntax().text() => (), + _ => return false, + } + + match (lhs_curr.qualifier(), rhs_curr.qualifier()) { + (Some(lhs), Some(rhs)) => { + lhs_curr = lhs; + rhs_curr = rhs; + } + (None, None) => return true, + _ => return false, + } + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_replace_add_use_no_anchor() { + check_assist( + replace_qualified_name_with_use, + r" +std::fmt::Debug<|> + ", + r" +use std::fmt::Debug; + +Debug + ", + ); + } + #[test] + fn test_replace_add_use_no_anchor_with_item_below() { + check_assist( + replace_qualified_name_with_use, + r" +std::fmt::Debug<|> + +fn main() { +} + ", + r" +use std::fmt::Debug; + +Debug + +fn main() { +} + ", + ); + } + + #[test] + fn test_replace_add_use_no_anchor_with_item_above() { + check_assist( + replace_qualified_name_with_use, + r" +fn main() { +} + +std::fmt::Debug<|> + ", + r" +use std::fmt::Debug; + +fn main() { +} + +Debug + ", + ); + } + + #[test] + fn test_replace_add_use_no_anchor_2seg() { + check_assist( + replace_qualified_name_with_use, + r" +std::fmt<|>::Debug + ", + r" +use std::fmt; + +fmt::Debug + ", + ); + } + + #[test] + fn test_replace_add_use() { + check_assist( + replace_qualified_name_with_use, + r" +use stdx; + +impl std::fmt::Debug<|> for Foo { +} + ", + r" +use stdx; +use std::fmt::Debug; + +impl Debug for Foo { +} + ", + ); + } + + #[test] + fn test_replace_file_use_other_anchor() { + check_assist( + replace_qualified_name_with_use, + r" +impl std::fmt::Debug<|> for Foo { +} + ", + r" +use std::fmt::Debug; + +impl Debug for Foo { +} + ", + ); + } + + #[test] + fn test_replace_add_use_other_anchor_indent() { + check_assist( + replace_qualified_name_with_use, + r" + impl std::fmt::Debug<|> for Foo { + } + ", + r" + use std::fmt::Debug; + + impl Debug for Foo { + } + ", + ); + } + + #[test] + fn test_replace_split_different() { + check_assist( + replace_qualified_name_with_use, + r" +use std::fmt; + +impl std::io<|> for Foo { +} + ", + r" +use std::{io, fmt}; + +impl io for Foo { +} + ", + ); + } + + #[test] + fn test_replace_split_self_for_use() { + check_assist( + replace_qualified_name_with_use, + r" +use std::fmt; + +impl std::fmt::Debug<|> for Foo { +} + ", + r" +use std::fmt::{self, Debug, }; + +impl Debug for Foo { +} + ", + ); + } + + #[test] + fn test_replace_split_self_for_target() { + check_assist( + replace_qualified_name_with_use, + r" +use std::fmt::Debug; + +impl std::fmt<|> for Foo { +} + ", + r" +use std::fmt::{self, Debug}; + +impl fmt for Foo { +} + ", + ); + } + + #[test] + fn test_replace_add_to_nested_self_nested() { + check_assist( + replace_qualified_name_with_use, + r" +use std::fmt::{Debug, nested::{Display}}; + +impl std::fmt::nested<|> for Foo { +} +", + r" +use std::fmt::{Debug, nested::{Display, self}}; + +impl nested for Foo { +} +", + ); + } + + #[test] + fn test_replace_add_to_nested_self_already_included() { + check_assist( + replace_qualified_name_with_use, + r" +use std::fmt::{Debug, nested::{self, Display}}; + +impl std::fmt::nested<|> for Foo { +} +", + r" +use std::fmt::{Debug, nested::{self, Display}}; + +impl nested for Foo { +} +", + ); + } + + #[test] + fn test_replace_add_to_nested_nested() { + check_assist( + replace_qualified_name_with_use, + r" +use std::fmt::{Debug, nested::{Display}}; + +impl std::fmt::nested::Debug<|> for Foo { +} +", + r" +use std::fmt::{Debug, nested::{Display, Debug}}; + +impl Debug for Foo { +} +", + ); + } + + #[test] + fn test_replace_split_common_target_longer() { + check_assist( + replace_qualified_name_with_use, + r" +use std::fmt::Debug; + +impl std::fmt::nested::Display<|> for Foo { +} +", + r" +use std::fmt::{nested::Display, Debug}; + +impl Display for Foo { +} +", + ); + } + + #[test] + fn test_replace_split_common_use_longer() { + check_assist( + replace_qualified_name_with_use, + r" +use std::fmt::nested::Debug; + +impl std::fmt::Display<|> for Foo { +} +", + r" +use std::fmt::{Display, nested::Debug}; + +impl Display for Foo { +} +", + ); + } + + #[test] + fn test_replace_use_nested_import() { + check_assist( + replace_qualified_name_with_use, + r" +use crate::{ + ty::{Substs, Ty}, + AssocItem, +}; + +fn foo() { crate::ty::lower<|>::trait_env() } +", + r" +use crate::{ + ty::{Substs, Ty, lower}, + AssocItem, +}; + +fn foo() { lower::trait_env() } +", + ); + } + + #[test] + fn test_replace_alias() { + check_assist( + replace_qualified_name_with_use, + r" +use std::fmt as foo; + +impl foo::Debug<|> for Foo { +} +", + r" +use std::fmt as foo; + +impl Debug for Foo { +} +", + ); + } + + #[test] + fn test_replace_not_applicable_one_segment() { + check_assist_not_applicable( + replace_qualified_name_with_use, + r" +impl foo<|> for Foo { +} +", + ); + } + + #[test] + fn test_replace_not_applicable_in_use() { + check_assist_not_applicable( + replace_qualified_name_with_use, + r" +use std::fmt<|>; +", + ); + } + + #[test] + fn test_replace_add_use_no_anchor_in_mod_mod() { + check_assist( + replace_qualified_name_with_use, + r" +mod foo { + mod bar { + std::fmt::Debug<|> + } +} + ", + r" +mod foo { + mod bar { + use std::fmt::Debug; + + Debug + } +} + ", + ); + } + + #[test] + fn inserts_imports_after_inner_attributes() { + check_assist( + replace_qualified_name_with_use, + r" +#![allow(dead_code)] + +fn main() { + std::fmt::Debug<|> +} + ", + r" +#![allow(dead_code)] +use std::fmt::Debug; + +fn main() { + Debug +} + ", + ); + } + + #[test] + fn replaces_all_affected_paths() { + check_assist( + replace_qualified_name_with_use, + r" +fn main() { + std::fmt::Debug<|>; + let x: std::fmt::Debug = std::fmt::Debug; +} + ", + r" +use std::fmt::Debug; + +fn main() { + Debug; + let x: Debug = Debug; +} + ", + ); + } + + #[test] + fn replaces_all_affected_paths_mod() { + check_assist( + replace_qualified_name_with_use, + r" +mod m { + fn f() { + std::fmt::Debug<|>; + let x: std::fmt::Debug = std::fmt::Debug; + } + fn g() { + std::fmt::Debug; + } +} + +fn f() { + std::fmt::Debug; +} + ", + r" +mod m { + use std::fmt::Debug; + + fn f() { + Debug; + let x: Debug = Debug; + } + fn g() { + Debug; + } +} + +fn f() { + std::fmt::Debug; +} + ", + ); + } + + #[test] + fn does_not_replace_in_submodules() { + check_assist( + replace_qualified_name_with_use, + r" +fn main() { + std::fmt::Debug<|>; +} + +mod sub { + fn f() { + std::fmt::Debug; + } +} + ", + r" +use std::fmt::Debug; + +fn main() { + Debug; +} + +mod sub { + fn f() { + std::fmt::Debug; + } +} + ", + ); + } + + #[test] + fn does_not_replace_in_use() { + check_assist( + replace_qualified_name_with_use, + r" +use std::fmt::Display; + +fn main() { + std::fmt<|>; +} + ", + r" +use std::fmt::{self, Display}; + +fn main() { + fmt; +} + ", + ); + } + + #[test] + fn does_not_replace_pub_use() { + check_assist( + replace_qualified_name_with_use, + r" +pub use std::fmt; + +impl std::io<|> for Foo { +} + ", + r" +use std::io; + +pub use std::fmt; + +impl io for Foo { +} + ", + ); + } + + #[test] + fn does_not_replace_pub_crate_use() { + check_assist( + replace_qualified_name_with_use, + r" +pub(crate) use std::fmt; + +impl std::io<|> for Foo { +} + ", + r" +use std::io; + +pub(crate) use std::fmt; + +impl io for Foo { +} + ", + ); + } +} diff --git a/crates/assists/src/handlers/replace_unwrap_with_match.rs b/crates/assists/src/handlers/replace_unwrap_with_match.rs new file mode 100644 index 000000000..9705f11b7 --- /dev/null +++ b/crates/assists/src/handlers/replace_unwrap_with_match.rs @@ -0,0 +1,187 @@ +use std::iter; + +use syntax::{ + ast::{ + self, + edit::{AstNodeEdit, IndentLevel}, + make, + }, + AstNode, +}; + +use crate::{ + utils::{render_snippet, Cursor, TryEnum}, + AssistContext, AssistId, AssistKind, Assists, +}; + +// Assist: replace_unwrap_with_match +// +// Replaces `unwrap` a `match` expression. Works for Result and Option. +// +// ``` +// enum Result { Ok(T), Err(E) } +// fn main() { +// let x: Result = Result::Ok(92); +// let y = x.<|>unwrap(); +// } +// ``` +// -> +// ``` +// enum Result { Ok(T), Err(E) } +// fn main() { +// let x: Result = Result::Ok(92); +// let y = match x { +// Ok(a) => a, +// $0_ => unreachable!(), +// }; +// } +// ``` +pub(crate) fn replace_unwrap_with_match(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let method_call: ast::MethodCallExpr = ctx.find_node_at_offset()?; + let name = method_call.name_ref()?; + if name.text() != "unwrap" { + return None; + } + let caller = method_call.expr()?; + let ty = ctx.sema.type_of_expr(&caller)?; + let happy_variant = TryEnum::from_ty(&ctx.sema, &ty)?.happy_case(); + let target = method_call.syntax().text_range(); + acc.add( + AssistId("replace_unwrap_with_match", AssistKind::RefactorRewrite), + "Replace unwrap with match", + target, + |builder| { + let ok_path = make::path_unqualified(make::path_segment(make::name_ref(happy_variant))); + let it = make::ident_pat(make::name("a")).into(); + let ok_tuple = make::tuple_struct_pat(ok_path, iter::once(it)).into(); + + let bind_path = make::path_unqualified(make::path_segment(make::name_ref("a"))); + let ok_arm = make::match_arm(iter::once(ok_tuple), make::expr_path(bind_path)); + + let unreachable_call = make::expr_unreachable(); + let err_arm = + make::match_arm(iter::once(make::wildcard_pat().into()), unreachable_call); + + let match_arm_list = make::match_arm_list(vec![ok_arm, err_arm]); + let match_expr = make::expr_match(caller.clone(), match_arm_list) + .indent(IndentLevel::from_node(method_call.syntax())); + + let range = method_call.syntax().text_range(); + match ctx.config.snippet_cap { + Some(cap) => { + let err_arm = match_expr + .syntax() + .descendants() + .filter_map(ast::MatchArm::cast) + .last() + .unwrap(); + let snippet = + render_snippet(cap, match_expr.syntax(), Cursor::Before(err_arm.syntax())); + builder.replace_snippet(cap, range, snippet) + } + None => builder.replace(range, match_expr.to_string()), + } + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_target}; + + use super::*; + + #[test] + fn test_replace_result_unwrap_with_match() { + check_assist( + replace_unwrap_with_match, + r" +enum Result { Ok(T), Err(E) } +fn i(a: T) -> T { a } +fn main() { + let x: Result = Result::Ok(92); + let y = i(x).<|>unwrap(); +} + ", + r" +enum Result { Ok(T), Err(E) } +fn i(a: T) -> T { a } +fn main() { + let x: Result = Result::Ok(92); + let y = match i(x) { + Ok(a) => a, + $0_ => unreachable!(), + }; +} + ", + ) + } + + #[test] + fn test_replace_option_unwrap_with_match() { + check_assist( + replace_unwrap_with_match, + r" +enum Option { Some(T), None } +fn i(a: T) -> T { a } +fn main() { + let x = Option::Some(92); + let y = i(x).<|>unwrap(); +} + ", + r" +enum Option { Some(T), None } +fn i(a: T) -> T { a } +fn main() { + let x = Option::Some(92); + let y = match i(x) { + Some(a) => a, + $0_ => unreachable!(), + }; +} + ", + ); + } + + #[test] + fn test_replace_result_unwrap_with_match_chaining() { + check_assist( + replace_unwrap_with_match, + r" +enum Result { Ok(T), Err(E) } +fn i(a: T) -> T { a } +fn main() { + let x: Result = Result::Ok(92); + let y = i(x).<|>unwrap().count_zeroes(); +} + ", + r" +enum Result { Ok(T), Err(E) } +fn i(a: T) -> T { a } +fn main() { + let x: Result = Result::Ok(92); + let y = match i(x) { + Ok(a) => a, + $0_ => unreachable!(), + }.count_zeroes(); +} + ", + ) + } + + #[test] + fn replace_unwrap_with_match_target() { + check_assist_target( + replace_unwrap_with_match, + r" +enum Option { Some(T), None } +fn i(a: T) -> T { a } +fn main() { + let x = Option::Some(92); + let y = i(x).<|>unwrap(); +} + ", + r"i(x).unwrap()", + ); + } +} diff --git a/crates/assists/src/handlers/split_import.rs b/crates/assists/src/handlers/split_import.rs new file mode 100644 index 000000000..15e67eaa1 --- /dev/null +++ b/crates/assists/src/handlers/split_import.rs @@ -0,0 +1,79 @@ +use std::iter::successors; + +use syntax::{ast, AstNode, T}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: split_import +// +// Wraps the tail of import into braces. +// +// ``` +// use std::<|>collections::HashMap; +// ``` +// -> +// ``` +// use std::{collections::HashMap}; +// ``` +pub(crate) fn split_import(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let colon_colon = ctx.find_token_at_offset(T![::])?; + let path = ast::Path::cast(colon_colon.parent())?.qualifier()?; + let top_path = successors(Some(path.clone()), |it| it.parent_path()).last()?; + + let use_tree = top_path.syntax().ancestors().find_map(ast::UseTree::cast)?; + + let new_tree = use_tree.split_prefix(&path); + if new_tree == use_tree { + return None; + } + + let target = colon_colon.text_range(); + acc.add(AssistId("split_import", AssistKind::RefactorRewrite), "Split import", target, |edit| { + edit.replace_ast(use_tree, new_tree); + }) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + + use super::*; + + #[test] + fn test_split_import() { + check_assist( + split_import, + "use crate::<|>db::RootDatabase;", + "use crate::{db::RootDatabase};", + ) + } + + #[test] + fn split_import_works_with_trees() { + check_assist( + split_import, + "use crate:<|>:db::{RootDatabase, FileSymbol}", + "use crate::{db::{RootDatabase, FileSymbol}}", + ) + } + + #[test] + fn split_import_target() { + check_assist_target(split_import, "use crate::<|>db::{RootDatabase, FileSymbol}", "::"); + } + + #[test] + fn issue4044() { + check_assist_not_applicable(split_import, "use crate::<|>:::self;") + } + + #[test] + fn test_empty_use() { + check_assist_not_applicable( + split_import, + r" +use std::<|> +fn main() {}", + ); + } +} diff --git a/crates/assists/src/handlers/unwrap_block.rs b/crates/assists/src/handlers/unwrap_block.rs new file mode 100644 index 000000000..3851aeb3e --- /dev/null +++ b/crates/assists/src/handlers/unwrap_block.rs @@ -0,0 +1,517 @@ +use syntax::{ + ast::{ + self, + edit::{AstNodeEdit, IndentLevel}, + }, + AstNode, TextRange, T, +}; + +use crate::{utils::unwrap_trivial_block, AssistContext, AssistId, AssistKind, Assists}; + +// Assist: unwrap_block +// +// This assist removes if...else, for, while and loop control statements to just keep the body. +// +// ``` +// fn foo() { +// if true {<|> +// println!("foo"); +// } +// } +// ``` +// -> +// ``` +// fn foo() { +// println!("foo"); +// } +// ``` +pub(crate) fn unwrap_block(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let assist_id = AssistId("unwrap_block", AssistKind::RefactorRewrite); + let assist_label = "Unwrap block"; + + let l_curly_token = ctx.find_token_at_offset(T!['{'])?; + let mut block = ast::BlockExpr::cast(l_curly_token.parent())?; + let mut parent = block.syntax().parent()?; + if ast::MatchArm::can_cast(parent.kind()) { + parent = parent.ancestors().find(|it| ast::MatchExpr::can_cast(it.kind()))? + } + + let parent = ast::Expr::cast(parent)?; + + match parent.clone() { + ast::Expr::ForExpr(_) | ast::Expr::WhileExpr(_) | ast::Expr::LoopExpr(_) => (), + ast::Expr::MatchExpr(_) => block = block.dedent(IndentLevel(1)), + ast::Expr::IfExpr(if_expr) => { + let then_branch = if_expr.then_branch()?; + if then_branch == block { + if let Some(ancestor) = if_expr.syntax().parent().and_then(ast::IfExpr::cast) { + // For `else if` blocks + let ancestor_then_branch = ancestor.then_branch()?; + + let target = then_branch.syntax().text_range(); + return acc.add(assist_id, assist_label, target, |edit| { + let range_to_del_else_if = TextRange::new( + ancestor_then_branch.syntax().text_range().end(), + l_curly_token.text_range().start(), + ); + let range_to_del_rest = TextRange::new( + then_branch.syntax().text_range().end(), + if_expr.syntax().text_range().end(), + ); + + edit.delete(range_to_del_rest); + edit.delete(range_to_del_else_if); + edit.replace( + target, + update_expr_string(then_branch.to_string(), &[' ', '{']), + ); + }); + } + } else { + let target = block.syntax().text_range(); + return acc.add(assist_id, assist_label, target, |edit| { + let range_to_del = TextRange::new( + then_branch.syntax().text_range().end(), + l_curly_token.text_range().start(), + ); + + edit.delete(range_to_del); + edit.replace(target, update_expr_string(block.to_string(), &[' ', '{'])); + }); + } + } + _ => return None, + }; + + let unwrapped = unwrap_trivial_block(block); + let target = unwrapped.syntax().text_range(); + acc.add(assist_id, assist_label, target, |builder| { + builder.replace( + parent.syntax().text_range(), + update_expr_string(unwrapped.to_string(), &[' ', '{', '\n']), + ); + }) +} + +fn update_expr_string(expr_str: String, trim_start_pat: &[char]) -> String { + let expr_string = expr_str.trim_start_matches(trim_start_pat); + let mut expr_string_lines: Vec<&str> = expr_string.lines().collect(); + expr_string_lines.pop(); // Delete last line + + expr_string_lines + .into_iter() + .map(|line| line.replacen(" ", "", 1)) // Delete indentation + .collect::>() + .join("\n") +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn simple_if() { + check_assist( + unwrap_block, + r#" + fn main() { + bar(); + if true {<|> + foo(); + + //comment + bar(); + } else { + println!("bar"); + } + } + "#, + r#" + fn main() { + bar(); + foo(); + + //comment + bar(); + } + "#, + ); + } + + #[test] + fn simple_if_else() { + check_assist( + unwrap_block, + r#" + fn main() { + bar(); + if true { + foo(); + + //comment + bar(); + } else {<|> + println!("bar"); + } + } + "#, + r#" + fn main() { + bar(); + if true { + foo(); + + //comment + bar(); + } + println!("bar"); + } + "#, + ); + } + + #[test] + fn simple_if_else_if() { + check_assist( + unwrap_block, + r#" + fn main() { + //bar(); + if true { + println!("true"); + + //comment + //bar(); + } else if false {<|> + println!("bar"); + } else { + println!("foo"); + } + } + "#, + r#" + fn main() { + //bar(); + if true { + println!("true"); + + //comment + //bar(); + } + println!("bar"); + } + "#, + ); + } + + #[test] + fn simple_if_else_if_nested() { + check_assist( + unwrap_block, + r#" + fn main() { + //bar(); + if true { + println!("true"); + + //comment + //bar(); + } else if false { + println!("bar"); + } else if true {<|> + println!("foo"); + } + } + "#, + r#" + fn main() { + //bar(); + if true { + println!("true"); + + //comment + //bar(); + } else if false { + println!("bar"); + } + println!("foo"); + } + "#, + ); + } + + #[test] + fn simple_if_else_if_nested_else() { + check_assist( + unwrap_block, + r#" + fn main() { + //bar(); + if true { + println!("true"); + + //comment + //bar(); + } else if false { + println!("bar"); + } else if true { + println!("foo"); + } else {<|> + println!("else"); + } + } + "#, + r#" + fn main() { + //bar(); + if true { + println!("true"); + + //comment + //bar(); + } else if false { + println!("bar"); + } else if true { + println!("foo"); + } + println!("else"); + } + "#, + ); + } + + #[test] + fn simple_if_else_if_nested_middle() { + check_assist( + unwrap_block, + r#" + fn main() { + //bar(); + if true { + println!("true"); + + //comment + //bar(); + } else if false { + println!("bar"); + } else if true {<|> + println!("foo"); + } else { + println!("else"); + } + } + "#, + r#" + fn main() { + //bar(); + if true { + println!("true"); + + //comment + //bar(); + } else if false { + println!("bar"); + } + println!("foo"); + } + "#, + ); + } + + #[test] + fn simple_if_bad_cursor_position() { + check_assist_not_applicable( + unwrap_block, + r#" + fn main() { + bar();<|> + if true { + foo(); + + //comment + bar(); + } else { + println!("bar"); + } + } + "#, + ); + } + + #[test] + fn simple_for() { + check_assist( + unwrap_block, + r#" + fn main() { + for i in 0..5 {<|> + if true { + foo(); + + //comment + bar(); + } else { + println!("bar"); + } + } + } + "#, + r#" + fn main() { + if true { + foo(); + + //comment + bar(); + } else { + println!("bar"); + } + } + "#, + ); + } + + #[test] + fn simple_if_in_for() { + check_assist( + unwrap_block, + r#" + fn main() { + for i in 0..5 { + if true {<|> + foo(); + + //comment + bar(); + } else { + println!("bar"); + } + } + } + "#, + r#" + fn main() { + for i in 0..5 { + foo(); + + //comment + bar(); + } + } + "#, + ); + } + + #[test] + fn simple_loop() { + check_assist( + unwrap_block, + r#" + fn main() { + loop {<|> + if true { + foo(); + + //comment + bar(); + } else { + println!("bar"); + } + } + } + "#, + r#" + fn main() { + if true { + foo(); + + //comment + bar(); + } else { + println!("bar"); + } + } + "#, + ); + } + + #[test] + fn simple_while() { + check_assist( + unwrap_block, + r#" + fn main() { + while true {<|> + if true { + foo(); + + //comment + bar(); + } else { + println!("bar"); + } + } + } + "#, + r#" + fn main() { + if true { + foo(); + + //comment + bar(); + } else { + println!("bar"); + } + } + "#, + ); + } + + #[test] + fn unwrap_match_arm() { + check_assist( + unwrap_block, + r#" +fn main() { + match rel_path { + Ok(rel_path) => {<|> + let rel_path = RelativePathBuf::from_path(rel_path).ok()?; + Some((*id, rel_path)) + } + Err(_) => None, + } +} +"#, + r#" +fn main() { + let rel_path = RelativePathBuf::from_path(rel_path).ok()?; + Some((*id, rel_path)) +} +"#, + ); + } + + #[test] + fn simple_if_in_while_bad_cursor_position() { + check_assist_not_applicable( + unwrap_block, + r#" + fn main() { + while true { + if true { + foo();<|> + + //comment + bar(); + } else { + println!("bar"); + } + } + } + "#, + ); + } +} -- cgit v1.2.3 From 9f548a02957314819aa0530d01f7c4a27dffdfc8 Mon Sep 17 00:00:00 2001 From: jDomantas Date: Fri, 14 Aug 2020 16:10:52 +0300 Subject: fixup whitespace when adding missing impl items --- .../src/handlers/add_missing_impl_members.rs | 63 ++++++++++++++++++++-- 1 file changed, 59 insertions(+), 4 deletions(-) (limited to 'crates/assists/src/handlers') diff --git a/crates/assists/src/handlers/add_missing_impl_members.rs b/crates/assists/src/handlers/add_missing_impl_members.rs index 81b61ebf8..83a2ada9a 100644 --- a/crates/assists/src/handlers/add_missing_impl_members.rs +++ b/crates/assists/src/handlers/add_missing_impl_members.rs @@ -48,7 +48,6 @@ enum AddMissingImplMembersMode { // fn foo(&self) -> u32 { // ${0:todo!()} // } -// // } // ``` pub(crate) fn add_missing_impl_members(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { @@ -89,8 +88,8 @@ pub(crate) fn add_missing_impl_members(acc: &mut Assists, ctx: &AssistContext) - // impl Trait for () { // Type X = (); // fn foo(&self) {} -// $0fn bar(&self) {} // +// $0fn bar(&self) {} // } // ``` pub(crate) fn add_missing_default_members(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { @@ -240,15 +239,18 @@ struct S; impl Foo for S { fn bar(&self) {} + $0type Output; + const CONST: usize = 42; + fn foo(&self) { todo!() } + fn baz(&self) { todo!() } - }"#, ); } @@ -281,10 +283,10 @@ struct S; impl Foo for S { fn bar(&self) {} + fn foo(&self) { ${0:todo!()} } - }"#, ); } @@ -599,6 +601,7 @@ trait Foo { struct S; impl Foo for S { $0type Output; + fn foo(&self) { todo!() } @@ -705,6 +708,58 @@ trait Tr { impl Tr for () { $0type Ty; +}"#, + ) + } + + #[test] + fn test_whitespace_fixup_preserves_bad_tokens() { + check_assist( + add_missing_impl_members, + r#" +trait Tr { + fn foo(); +} + +impl Tr for ()<|> { + +++ +}"#, + r#" +trait Tr { + fn foo(); +} + +impl Tr for () { + fn foo() { + ${0:todo!()} + } + +++ +}"#, + ) + } + + #[test] + fn test_whitespace_fixup_preserves_comments() { + check_assist( + add_missing_impl_members, + r#" +trait Tr { + fn foo(); +} + +impl Tr for ()<|> { + // very important +}"#, + r#" +trait Tr { + fn foo(); +} + +impl Tr for () { + fn foo() { + ${0:todo!()} + } + // very important }"#, ) } -- cgit v1.2.3 From 0ca1ba29e8e88c060dcf36946e4e02a6f015754b Mon Sep 17 00:00:00 2001 From: Aleksey Kladov Date: Sat, 15 Aug 2020 18:50:41 +0200 Subject: Don't expose hir::Path out of hir Conjecture: it's impossible to use hir::Path *correctly* from an IDE. I am not entirely sure about this, and we might need to add it back at some point, but I have to arguments that convince me that we probably won't: * `hir::Path` has to know about hygiene, which an IDE can't set up properly. * `hir::Path` lacks identity, but you actually have to know identity to resolve it correctly --- crates/assists/src/handlers/auto_import.rs | 2 +- crates/assists/src/handlers/expand_glob_import.rs | 20 +++------ .../handlers/extract_struct_from_enum_variant.rs | 7 ++- .../handlers/replace_qualified_name_with_use.rs | 50 +++++++++------------- 4 files changed, 35 insertions(+), 44 deletions(-) (limited to 'crates/assists/src/handlers') diff --git a/crates/assists/src/handlers/auto_import.rs b/crates/assists/src/handlers/auto_import.rs index cce789972..b9ec3f10b 100644 --- a/crates/assists/src/handlers/auto_import.rs +++ b/crates/assists/src/handlers/auto_import.rs @@ -53,7 +53,7 @@ pub(crate) fn auto_import(acc: &mut Assists, ctx: &AssistContext) -> Option<()> |builder| { insert_use_statement( &auto_import_assets.syntax_under_caret, - &import, + &import.to_string(), ctx, builder.text_edit_builder(), ); diff --git a/crates/assists/src/handlers/expand_glob_import.rs b/crates/assists/src/handlers/expand_glob_import.rs index f690ec343..81d0af2f3 100644 --- a/crates/assists/src/handlers/expand_glob_import.rs +++ b/crates/assists/src/handlers/expand_glob_import.rs @@ -1,3 +1,4 @@ +use either::Either; use hir::{AssocItem, MacroDef, ModuleDef, Name, PathResolution, ScopeDef, SemanticsScope}; use ide_db::{ defs::{classify_name_ref, Definition, NameRefClass}, @@ -10,8 +11,6 @@ use crate::{ AssistId, AssistKind, }; -use either::Either; - // Assist: expand_glob_import // // Expands glob imports. @@ -40,11 +39,15 @@ use either::Either; pub(crate) fn expand_glob_import(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { let star = ctx.find_token_at_offset(T![*])?; let mod_path = find_mod_path(&star)?; + let module = match ctx.sema.resolve_path(&mod_path)? { + PathResolution::Def(ModuleDef::Module(it)) => it, + _ => return None, + }; let source_file = ctx.source_file(); let scope = ctx.sema.scope_at_offset(source_file.syntax(), ctx.offset()); - let defs_in_mod = find_defs_in_mod(ctx, scope, &mod_path)?; + let defs_in_mod = find_defs_in_mod(ctx, scope, module)?; let name_refs_in_source_file = source_file.syntax().descendants().filter_map(ast::NameRef::cast).collect(); let used_names = find_used_names(ctx, defs_in_mod, name_refs_in_source_file); @@ -82,17 +85,8 @@ impl Def { fn find_defs_in_mod( ctx: &AssistContext, from: SemanticsScope<'_>, - path: &ast::Path, + module: hir::Module, ) -> Option> { - let hir_path = ctx.sema.lower_path(&path)?; - let module = if let Some(PathResolution::Def(ModuleDef::Module(module))) = - from.resolve_hir_path_qualifier(&hir_path) - { - module - } else { - return None; - }; - let module_scope = module.scope(ctx.db(), from.module()); let mut defs = vec![]; diff --git a/crates/assists/src/handlers/extract_struct_from_enum_variant.rs b/crates/assists/src/handlers/extract_struct_from_enum_variant.rs index 4bcdae7ba..d62e06b4a 100644 --- a/crates/assists/src/handlers/extract_struct_from_enum_variant.rs +++ b/crates/assists/src/handlers/extract_struct_from_enum_variant.rs @@ -106,7 +106,12 @@ fn insert_import( if let Some(mut mod_path) = mod_path { mod_path.segments.pop(); mod_path.segments.push(variant_hir_name.clone()); - insert_use_statement(path.syntax(), &mod_path, ctx, builder.text_edit_builder()); + insert_use_statement( + path.syntax(), + &mod_path.to_string(), + ctx, + builder.text_edit_builder(), + ); } Some(()) } diff --git a/crates/assists/src/handlers/replace_qualified_name_with_use.rs b/crates/assists/src/handlers/replace_qualified_name_with_use.rs index 011bf1106..470e5f8ff 100644 --- a/crates/assists/src/handlers/replace_qualified_name_with_use.rs +++ b/crates/assists/src/handlers/replace_qualified_name_with_use.rs @@ -1,5 +1,5 @@ -use hir; -use syntax::{algo::SyntaxRewriter, ast, match_ast, AstNode, SmolStr, SyntaxNode}; +use syntax::{algo::SyntaxRewriter, ast, match_ast, AstNode, SyntaxNode, TextRange}; +use test_utils::mark; use crate::{ utils::{find_insert_use_container, insert_use_statement}, @@ -28,12 +28,19 @@ pub(crate) fn replace_qualified_name_with_use( if path.syntax().ancestors().find_map(ast::Use::cast).is_some() { return None; } - - let hir_path = ctx.sema.lower_path(&path)?; - let segments = collect_hir_path_segments(&hir_path)?; - if segments.len() < 2 { + if path.qualifier().is_none() { + mark::hit!(dont_import_trivial_paths); return None; } + let path_to_import = path.to_string().clone(); + let path_to_import = match path.segment()?.generic_arg_list() { + Some(generic_args) => { + let generic_args_start = + generic_args.syntax().text_range().start() - path.syntax().text_range().start(); + &path_to_import[TextRange::up_to(generic_args_start)] + } + None => path_to_import.as_str(), + }; let target = path.syntax().text_range(); acc.add( @@ -41,12 +48,16 @@ pub(crate) fn replace_qualified_name_with_use( "Replace qualified path with use", target, |builder| { - let path_to_import = hir_path.mod_path().clone(); let container = match find_insert_use_container(path.syntax(), ctx) { Some(c) => c, None => return, }; - insert_use_statement(path.syntax(), &path_to_import, ctx, builder.text_edit_builder()); + insert_use_statement( + path.syntax(), + &path_to_import.to_string(), + ctx, + builder.text_edit_builder(), + ); // Now that we've brought the name into scope, re-qualify all paths that could be // affected (that is, all paths inside the node we added the `use` to). @@ -58,26 +69,6 @@ pub(crate) fn replace_qualified_name_with_use( ) } -fn collect_hir_path_segments(path: &hir::Path) -> Option> { - let mut ps = Vec::::with_capacity(10); - match path.kind() { - hir::PathKind::Abs => ps.push("".into()), - hir::PathKind::Crate => ps.push("crate".into()), - hir::PathKind::Plain => {} - hir::PathKind::Super(0) => ps.push("self".into()), - hir::PathKind::Super(lvl) => { - let mut chain = "super".to_string(); - for _ in 0..*lvl { - chain += "::super"; - } - ps.push(chain.into()); - } - hir::PathKind::DollarCrate(_) => return None, - } - ps.extend(path.segments().iter().map(|it| it.name.to_string().into())); - Some(ps) -} - /// Adds replacements to `re` that shorten `path` in all descendants of `node`. fn shorten_paths(rewriter: &mut SyntaxRewriter<'static>, node: SyntaxNode, path: ast::Path) { for child in node.children() { @@ -467,7 +458,8 @@ impl Debug for Foo { } #[test] - fn test_replace_not_applicable_one_segment() { + fn dont_import_trivial_paths() { + mark::check!(dont_import_trivial_paths); check_assist_not_applicable( replace_qualified_name_with_use, r" -- cgit v1.2.3 From 81b0976187d73eba4f9b14d8a0b8539ab8f06dcd Mon Sep 17 00:00:00 2001 From: Aleksey Kladov Date: Wed, 19 Aug 2020 18:58:48 +0200 Subject: Future proof find-usages API We might want to provide more efficient impls for check if usages exist, limiting the search, filtering and cancellation, so let's violate YAGNI a bit here. --- crates/assists/src/handlers/extract_struct_from_enum_variant.rs | 2 +- crates/assists/src/handlers/inline_local_variable.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'crates/assists/src/handlers') diff --git a/crates/assists/src/handlers/extract_struct_from_enum_variant.rs b/crates/assists/src/handlers/extract_struct_from_enum_variant.rs index d62e06b4a..c1124b9e2 100644 --- a/crates/assists/src/handlers/extract_struct_from_enum_variant.rs +++ b/crates/assists/src/handlers/extract_struct_from_enum_variant.rs @@ -53,7 +53,7 @@ pub(crate) fn extract_struct_from_enum_variant( target, |builder| { let definition = Definition::ModuleDef(ModuleDef::EnumVariant(variant_hir)); - let res = definition.find_usages(&ctx.sema, None); + let res = definition.usages(&ctx.sema).all(); let start_offset = variant.parent_enum().syntax().text_range().start(); let mut visited_modules_set = FxHashSet::default(); visited_modules_set.insert(current_module); diff --git a/crates/assists/src/handlers/inline_local_variable.rs b/crates/assists/src/handlers/inline_local_variable.rs index 2b52b333b..164bbce86 100644 --- a/crates/assists/src/handlers/inline_local_variable.rs +++ b/crates/assists/src/handlers/inline_local_variable.rs @@ -44,7 +44,7 @@ pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext) -> O let def = ctx.sema.to_def(&bind_pat)?; let def = Definition::Local(def); - let refs = def.find_usages(&ctx.sema, None); + let refs = def.usages(&ctx.sema).all(); if refs.is_empty() { mark::hit!(test_not_applicable_if_variable_unused); return None; -- cgit v1.2.3 From 4b5b55f6f3d1879cd974f290e2f0d92f487acc4b Mon Sep 17 00:00:00 2001 From: Aleksey Kladov Date: Wed, 19 Aug 2020 18:44:33 +0200 Subject: **Remove Unused Parameter** refactoring --- crates/assists/src/handlers/merge_imports.rs | 5 +- crates/assists/src/handlers/remove_dbg.rs | 3 +- crates/assists/src/handlers/remove_unused_param.rs | 131 +++++++++++++++++++++ 3 files changed, 134 insertions(+), 5 deletions(-) create mode 100644 crates/assists/src/handlers/remove_unused_param.rs (limited to 'crates/assists/src/handlers') diff --git a/crates/assists/src/handlers/merge_imports.rs b/crates/assists/src/handlers/merge_imports.rs index 47d465404..35b884206 100644 --- a/crates/assists/src/handlers/merge_imports.rs +++ b/crates/assists/src/handlers/merge_imports.rs @@ -8,6 +8,7 @@ use syntax::{ use crate::{ assist_context::{AssistContext, Assists}, + utils::next_prev, AssistId, AssistKind, }; @@ -66,10 +67,6 @@ pub(crate) fn merge_imports(acc: &mut Assists, ctx: &AssistContext) -> Option<() ) } -fn next_prev() -> impl Iterator { - [Direction::Next, Direction::Prev].iter().copied() -} - fn try_merge_trees(old: &ast::UseTree, new: &ast::UseTree) -> Option { let lhs_path = old.path()?; let rhs_path = new.path()?; diff --git a/crates/assists/src/handlers/remove_dbg.rs b/crates/assists/src/handlers/remove_dbg.rs index f3dcca534..4e252edf0 100644 --- a/crates/assists/src/handlers/remove_dbg.rs +++ b/crates/assists/src/handlers/remove_dbg.rs @@ -82,9 +82,10 @@ fn is_valid_macrocall(macro_call: &ast::MacroCall, macro_name: &str) -> Optiondbg!(1 + 1)", "1 + 1"); diff --git a/crates/assists/src/handlers/remove_unused_param.rs b/crates/assists/src/handlers/remove_unused_param.rs new file mode 100644 index 000000000..5fccca54b --- /dev/null +++ b/crates/assists/src/handlers/remove_unused_param.rs @@ -0,0 +1,131 @@ +use ide_db::{defs::Definition, search::Reference}; +use syntax::{ + algo::find_node_at_range, + ast::{self, ArgListOwner}, + AstNode, SyntaxNode, TextRange, T, +}; +use test_utils::mark; + +use crate::{ + assist_context::AssistBuilder, utils::next_prev, AssistContext, AssistId, AssistKind, Assists, +}; + +// Assist: remove_unused_param +// +// Removes unused function parameter. +// +// ``` +// fn frobnicate(x: i32<|>) {} +// +// fn main() { +// frobnicate(92); +// } +// ``` +// -> +// ``` +// fn frobnicate() {} +// +// fn main() { +// frobnicate(); +// } +// ``` +pub(crate) fn remove_unused_param(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let param: ast::Param = ctx.find_node_at_offset()?; + let ident_pat = match param.pat()? { + ast::Pat::IdentPat(it) => it, + _ => return None, + }; + let func = param.syntax().ancestors().find_map(ast::Fn::cast)?; + let param_position = func.param_list()?.params().position(|it| it == param)?; + + let fn_def = { + let func = ctx.sema.to_def(&func)?; + Definition::ModuleDef(func.into()) + }; + + let param_def = { + let local = ctx.sema.to_def(&ident_pat)?; + Definition::Local(local) + }; + if param_def.usages(&ctx.sema).at_least_one() { + mark::hit!(keep_used); + return None; + } + acc.add( + AssistId("remove_unused_param", AssistKind::Refactor), + "Remove unused parameter", + param.syntax().text_range(), + |builder| { + builder.delete(range_with_coma(param.syntax())); + for usage in fn_def.usages(&ctx.sema).all() { + process_usage(ctx, builder, usage, param_position); + } + }, + ) +} + +fn process_usage( + ctx: &AssistContext, + builder: &mut AssistBuilder, + usage: Reference, + arg_to_remove: usize, +) -> Option<()> { + let source_file = ctx.sema.parse(usage.file_range.file_id); + let call_expr: ast::CallExpr = + find_node_at_range(source_file.syntax(), usage.file_range.range)?; + if call_expr.expr()?.syntax().text_range() != usage.file_range.range { + return None; + } + let arg = call_expr.arg_list()?.args().nth(arg_to_remove)?; + + builder.edit_file(usage.file_range.file_id); + builder.delete(range_with_coma(arg.syntax())); + + Some(()) +} + +fn range_with_coma(node: &SyntaxNode) -> TextRange { + let up_to = next_prev().find_map(|dir| { + node.siblings_with_tokens(dir) + .filter_map(|it| it.into_token()) + .find(|it| it.kind() == T![,]) + }); + let up_to = up_to.map_or(node.text_range(), |it| it.text_range()); + node.text_range().cover(up_to) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn remove_unused() { + check_assist( + remove_unused_param, + r#" +fn a() { foo(9, 2) } +fn foo(x: i32, <|>y: i32) { x; } +fn b() { foo(9, 2,) } +"#, + r#" +fn a() { foo(9) } +fn foo(x: i32) { x; } +fn b() { foo(9, ) } +"#, + ); + } + + #[test] + fn keep_used() { + mark::check!(keep_used); + check_assist_not_applicable( + remove_unused_param, + r#" +fn foo(x: i32, <|>y: i32) { y; } +fn main() { foo(9, 2) } +"#, + ); + } +} -- cgit v1.2.3 From 863b1fb731e797f02daeac87a94d40a34222a062 Mon Sep 17 00:00:00 2001 From: Aleksey Kladov Date: Fri, 21 Aug 2020 19:12:38 +0200 Subject: :arrow_up: ungrammar --- crates/assists/src/handlers/auto_import.rs | 2 +- crates/assists/src/handlers/replace_unwrap_with_match.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'crates/assists/src/handlers') diff --git a/crates/assists/src/handlers/auto_import.rs b/crates/assists/src/handlers/auto_import.rs index b9ec3f10b..c4770f336 100644 --- a/crates/assists/src/handlers/auto_import.rs +++ b/crates/assists/src/handlers/auto_import.rs @@ -239,7 +239,7 @@ impl ImportCandidate { return None; } Some(Self::TraitMethod( - sema.type_of_expr(&method_call.expr()?)?, + sema.type_of_expr(&method_call.receiver()?)?, method_call.name_ref()?.syntax().to_string(), )) } diff --git a/crates/assists/src/handlers/replace_unwrap_with_match.rs b/crates/assists/src/handlers/replace_unwrap_with_match.rs index 9705f11b7..4043c219c 100644 --- a/crates/assists/src/handlers/replace_unwrap_with_match.rs +++ b/crates/assists/src/handlers/replace_unwrap_with_match.rs @@ -42,7 +42,7 @@ pub(crate) fn replace_unwrap_with_match(acc: &mut Assists, ctx: &AssistContext) if name.text() != "unwrap" { return None; } - let caller = method_call.expr()?; + let caller = method_call.receiver()?; let ty = ctx.sema.type_of_expr(&caller)?; let happy_variant = TryEnum::from_ty(&ctx.sema, &ty)?.happy_case(); let target = method_call.syntax().text_range(); -- cgit v1.2.3