aboutsummaryrefslogtreecommitdiff
path: root/crates/assists
diff options
context:
space:
mode:
Diffstat (limited to 'crates/assists')
-rw-r--r--crates/assists/src/assist_context.rs45
-rw-r--r--crates/assists/src/handlers/add_custom_impl.rs284
-rw-r--r--crates/assists/src/handlers/add_missing_impl_members.rs92
-rw-r--r--crates/assists/src/handlers/add_turbo_fish.rs2
-rw-r--r--crates/assists/src/handlers/change_return_type_to_result.rs998
-rw-r--r--crates/assists/src/handlers/convert_integer_literal.rs395
-rw-r--r--crates/assists/src/handlers/expand_glob_import.rs78
-rw-r--r--crates/assists/src/handlers/extract_struct_from_enum_variant.rs67
-rw-r--r--crates/assists/src/handlers/flip_comma.rs2
-rw-r--r--crates/assists/src/handlers/flip_trait_bound.rs2
-rw-r--r--crates/assists/src/handlers/infer_function_return_type.rs337
-rw-r--r--crates/assists/src/handlers/introduce_named_lifetime.rs2
-rw-r--r--crates/assists/src/handlers/invert_if.rs2
-rw-r--r--crates/assists/src/handlers/raw_string.rs34
-rw-r--r--crates/assists/src/handlers/remove_mut.rs2
-rw-r--r--crates/assists/src/handlers/reorder_fields.rs4
-rw-r--r--crates/assists/src/handlers/replace_derive_with_manual_impl.rs398
-rw-r--r--crates/assists/src/handlers/replace_let_with_if_let.rs2
-rw-r--r--crates/assists/src/handlers/replace_string_with_char.rs8
-rw-r--r--crates/assists/src/handlers/split_import.rs2
-rw-r--r--crates/assists/src/handlers/unwrap_block.rs2
-rw-r--r--crates/assists/src/handlers/wrap_return_type_in_result.rs1158
-rw-r--r--crates/assists/src/lib.rs10
-rw-r--r--crates/assists/src/tests.rs25
-rw-r--r--crates/assists/src/tests/generated.rs81
-rw-r--r--crates/assists/src/utils.rs92
26 files changed, 2260 insertions, 1864 deletions
diff --git a/crates/assists/src/assist_context.rs b/crates/assists/src/assist_context.rs
index d11fee196..69499ea32 100644
--- a/crates/assists/src/assist_context.rs
+++ b/crates/assists/src/assist_context.rs
@@ -12,7 +12,7 @@ use ide_db::{
12}; 12};
13use syntax::{ 13use syntax::{
14 algo::{self, find_node_at_offset, SyntaxRewriter}, 14 algo::{self, find_node_at_offset, SyntaxRewriter},
15 AstNode, SourceFile, SyntaxElement, SyntaxKind, SyntaxToken, TextRange, TextSize, 15 AstNode, AstToken, SourceFile, SyntaxElement, SyntaxKind, SyntaxToken, TextRange, TextSize,
16 TokenAtOffset, 16 TokenAtOffset,
17}; 17};
18use text_edit::{TextEdit, TextEditBuilder}; 18use text_edit::{TextEdit, TextEditBuilder};
@@ -81,9 +81,12 @@ impl<'a> AssistContext<'a> {
81 pub(crate) fn token_at_offset(&self) -> TokenAtOffset<SyntaxToken> { 81 pub(crate) fn token_at_offset(&self) -> TokenAtOffset<SyntaxToken> {
82 self.source_file.syntax().token_at_offset(self.offset()) 82 self.source_file.syntax().token_at_offset(self.offset())
83 } 83 }
84 pub(crate) fn find_token_at_offset(&self, kind: SyntaxKind) -> Option<SyntaxToken> { 84 pub(crate) fn find_token_syntax_at_offset(&self, kind: SyntaxKind) -> Option<SyntaxToken> {
85 self.token_at_offset().find(|it| it.kind() == kind) 85 self.token_at_offset().find(|it| it.kind() == kind)
86 } 86 }
87 pub(crate) fn find_token_at_offset<T: AstToken>(&self) -> Option<T> {
88 self.token_at_offset().find_map(T::cast)
89 }
87 pub(crate) fn find_node_at_offset<N: AstNode>(&self) -> Option<N> { 90 pub(crate) fn find_node_at_offset<N: AstNode>(&self) -> Option<N> {
88 find_node_at_offset(self.source_file.syntax(), self.offset()) 91 find_node_at_offset(self.source_file.syntax(), self.offset())
89 } 92 }
@@ -205,7 +208,7 @@ pub(crate) struct AssistBuilder {
205 edit: TextEditBuilder, 208 edit: TextEditBuilder,
206 file_id: FileId, 209 file_id: FileId,
207 is_snippet: bool, 210 is_snippet: bool,
208 change: SourceChange, 211 source_file_edits: Vec<SourceFileEdit>,
209} 212}
210 213
211impl AssistBuilder { 214impl AssistBuilder {
@@ -214,20 +217,27 @@ impl AssistBuilder {
214 edit: TextEdit::builder(), 217 edit: TextEdit::builder(),
215 file_id, 218 file_id,
216 is_snippet: false, 219 is_snippet: false,
217 change: SourceChange::default(), 220 source_file_edits: Vec::default(),
218 } 221 }
219 } 222 }
220 223
221 pub(crate) fn edit_file(&mut self, file_id: FileId) { 224 pub(crate) fn edit_file(&mut self, file_id: FileId) {
225 self.commit();
222 self.file_id = file_id; 226 self.file_id = file_id;
223 } 227 }
224 228
225 fn commit(&mut self) { 229 fn commit(&mut self) {
226 let edit = mem::take(&mut self.edit).finish(); 230 let edit = mem::take(&mut self.edit).finish();
227 if !edit.is_empty() { 231 if !edit.is_empty() {
228 let new_edit = SourceFileEdit { file_id: self.file_id, edit }; 232 match self.source_file_edits.binary_search_by_key(&self.file_id, |edit| edit.file_id) {
229 assert!(!self.change.source_file_edits.iter().any(|it| it.file_id == new_edit.file_id)); 233 Ok(idx) => self.source_file_edits[idx]
230 self.change.source_file_edits.push(new_edit); 234 .edit
235 .union(edit)
236 .expect("overlapping edits for same file"),
237 Err(idx) => self
238 .source_file_edits
239 .insert(idx, SourceFileEdit { file_id: self.file_id, edit }),
240 }
231 } 241 }
232 } 242 }
233 243
@@ -267,23 +277,18 @@ impl AssistBuilder {
267 algo::diff(old.syntax(), new.syntax()).into_text_edit(&mut self.edit) 277 algo::diff(old.syntax(), new.syntax()).into_text_edit(&mut self.edit)
268 } 278 }
269 pub(crate) fn rewrite(&mut self, rewriter: SyntaxRewriter) { 279 pub(crate) fn rewrite(&mut self, rewriter: SyntaxRewriter) {
270 let node = rewriter.rewrite_root().unwrap(); 280 if let Some(node) = rewriter.rewrite_root() {
271 let new = rewriter.rewrite(&node); 281 let new = rewriter.rewrite(&node);
272 algo::diff(&node, &new).into_text_edit(&mut self.edit); 282 algo::diff(&node, &new).into_text_edit(&mut self.edit);
273 } 283 }
274
275 // FIXME: kill this API
276 /// Get access to the raw `TextEditBuilder`.
277 pub(crate) fn text_edit_builder(&mut self) -> &mut TextEditBuilder {
278 &mut self.edit
279 } 284 }
280 285
281 fn finish(mut self) -> SourceChange { 286 fn finish(mut self) -> SourceChange {
282 self.commit(); 287 self.commit();
283 let mut change = mem::take(&mut self.change); 288 SourceChange {
284 if self.is_snippet { 289 source_file_edits: mem::take(&mut self.source_file_edits),
285 change.is_snippet = true; 290 file_system_edits: Default::default(),
291 is_snippet: self.is_snippet,
286 } 292 }
287 change
288 } 293 }
289} 294}
diff --git a/crates/assists/src/handlers/add_custom_impl.rs b/crates/assists/src/handlers/add_custom_impl.rs
deleted file mode 100644
index 669dd9b21..000000000
--- a/crates/assists/src/handlers/add_custom_impl.rs
+++ /dev/null
@@ -1,284 +0,0 @@
1use ide_db::imports_locator;
2use itertools::Itertools;
3use syntax::{
4 ast::{self, make, AstNode},
5 Direction, SmolStr,
6 SyntaxKind::{IDENT, WHITESPACE},
7 TextRange, TextSize,
8};
9
10use crate::{
11 assist_config::SnippetCap,
12 assist_context::{AssistBuilder, AssistContext, Assists},
13 utils::mod_path_to_ast,
14 AssistId, AssistKind,
15};
16
17// Assist: add_custom_impl
18//
19// Adds impl block for derived trait.
20//
21// ```
22// #[derive(Deb<|>ug, Display)]
23// struct S;
24// ```
25// ->
26// ```
27// #[derive(Display)]
28// struct S;
29//
30// impl Debug for S {
31// $0
32// }
33// ```
34pub(crate) fn add_custom_impl(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
35 let attr = ctx.find_node_at_offset::<ast::Attr>()?;
36
37 let attr_name = attr
38 .syntax()
39 .descendants_with_tokens()
40 .filter(|t| t.kind() == IDENT)
41 .find_map(syntax::NodeOrToken::into_token)
42 .filter(|t| t.text() == "derive")?
43 .text()
44 .clone();
45
46 let trait_token =
47 ctx.token_at_offset().find(|t| t.kind() == IDENT && *t.text() != attr_name)?;
48 let trait_path = make::path_unqualified(make::path_segment(make::name_ref(trait_token.text())));
49
50 let annotated = attr.syntax().siblings(Direction::Next).find_map(ast::Name::cast)?;
51 let annotated_name = annotated.syntax().text().to_string();
52 let insert_pos = annotated.syntax().parent()?.text_range().end();
53
54 let current_module = ctx.sema.scope(annotated.syntax()).module()?;
55 let current_crate = current_module.krate();
56
57 let found_traits = imports_locator::find_imports(&ctx.sema, current_crate, trait_token.text())
58 .into_iter()
59 .filter_map(|candidate: either::Either<hir::ModuleDef, hir::MacroDef>| match candidate {
60 either::Either::Left(hir::ModuleDef::Trait(trait_)) => Some(trait_),
61 _ => None,
62 })
63 .flat_map(|trait_| {
64 current_module
65 .find_use_path(ctx.sema.db, hir::ModuleDef::Trait(trait_))
66 .as_ref()
67 .map(mod_path_to_ast)
68 .zip(Some(trait_))
69 });
70
71 let mut no_traits_found = true;
72 for (trait_path, _trait) in found_traits.inspect(|_| no_traits_found = false) {
73 add_assist(acc, ctx.config.snippet_cap, &attr, &trait_path, &annotated_name, insert_pos)?;
74 }
75 if no_traits_found {
76 add_assist(acc, ctx.config.snippet_cap, &attr, &trait_path, &annotated_name, insert_pos)?;
77 }
78 Some(())
79}
80
81fn add_assist(
82 acc: &mut Assists,
83 snippet_cap: Option<SnippetCap>,
84 attr: &ast::Attr,
85 trait_path: &ast::Path,
86 annotated_name: &str,
87 insert_pos: TextSize,
88) -> Option<()> {
89 let target = attr.syntax().text_range();
90 let input = attr.token_tree()?;
91 let label = format!("Add custom impl `{}` for `{}`", trait_path, annotated_name);
92 let trait_name = trait_path.segment().and_then(|seg| seg.name_ref())?;
93
94 acc.add(AssistId("add_custom_impl", AssistKind::Refactor), label, target, |builder| {
95 update_attribute(builder, &input, &trait_name, &attr);
96 match snippet_cap {
97 Some(cap) => {
98 builder.insert_snippet(
99 cap,
100 insert_pos,
101 format!("\n\nimpl {} for {} {{\n $0\n}}", trait_path, annotated_name),
102 );
103 }
104 None => {
105 builder.insert(
106 insert_pos,
107 format!("\n\nimpl {} for {} {{\n\n}}", trait_path, annotated_name),
108 );
109 }
110 }
111 })
112}
113
114fn update_attribute(
115 builder: &mut AssistBuilder,
116 input: &ast::TokenTree,
117 trait_name: &ast::NameRef,
118 attr: &ast::Attr,
119) {
120 let new_attr_input = input
121 .syntax()
122 .descendants_with_tokens()
123 .filter(|t| t.kind() == IDENT)
124 .filter_map(|t| t.into_token().map(|t| t.text().clone()))
125 .filter(|t| t != trait_name.text())
126 .collect::<Vec<SmolStr>>();
127 let has_more_derives = !new_attr_input.is_empty();
128
129 if has_more_derives {
130 let new_attr_input = format!("({})", new_attr_input.iter().format(", "));
131 builder.replace(input.syntax().text_range(), new_attr_input);
132 } else {
133 let attr_range = attr.syntax().text_range();
134 builder.delete(attr_range);
135
136 let line_break_range = attr
137 .syntax()
138 .next_sibling_or_token()
139 .filter(|t| t.kind() == WHITESPACE)
140 .map(|t| t.text_range())
141 .unwrap_or_else(|| TextRange::new(TextSize::from(0), TextSize::from(0)));
142 builder.delete(line_break_range);
143 }
144}
145
146#[cfg(test)]
147mod tests {
148 use crate::tests::{check_assist, check_assist_not_applicable};
149
150 use super::*;
151
152 #[test]
153 fn add_custom_impl_qualified() {
154 check_assist(
155 add_custom_impl,
156 "
157mod fmt {
158 pub trait Debug {}
159}
160
161#[derive(Debu<|>g)]
162struct Foo {
163 bar: String,
164}
165",
166 "
167mod fmt {
168 pub trait Debug {}
169}
170
171struct Foo {
172 bar: String,
173}
174
175impl fmt::Debug for Foo {
176 $0
177}
178",
179 )
180 }
181 #[test]
182 fn add_custom_impl_for_unique_input() {
183 check_assist(
184 add_custom_impl,
185 "
186#[derive(Debu<|>g)]
187struct Foo {
188 bar: String,
189}
190 ",
191 "
192struct Foo {
193 bar: String,
194}
195
196impl Debug for Foo {
197 $0
198}
199 ",
200 )
201 }
202
203 #[test]
204 fn add_custom_impl_for_with_visibility_modifier() {
205 check_assist(
206 add_custom_impl,
207 "
208#[derive(Debug<|>)]
209pub struct Foo {
210 bar: String,
211}
212 ",
213 "
214pub struct Foo {
215 bar: String,
216}
217
218impl Debug for Foo {
219 $0
220}
221 ",
222 )
223 }
224
225 #[test]
226 fn add_custom_impl_when_multiple_inputs() {
227 check_assist(
228 add_custom_impl,
229 "
230#[derive(Display, Debug<|>, Serialize)]
231struct Foo {}
232 ",
233 "
234#[derive(Display, Serialize)]
235struct Foo {}
236
237impl Debug for Foo {
238 $0
239}
240 ",
241 )
242 }
243
244 #[test]
245 fn test_ignore_derive_macro_without_input() {
246 check_assist_not_applicable(
247 add_custom_impl,
248 "
249#[derive(<|>)]
250struct Foo {}
251 ",
252 )
253 }
254
255 #[test]
256 fn test_ignore_if_cursor_on_param() {
257 check_assist_not_applicable(
258 add_custom_impl,
259 "
260#[derive<|>(Debug)]
261struct Foo {}
262 ",
263 );
264
265 check_assist_not_applicable(
266 add_custom_impl,
267 "
268#[derive(Debug)<|>]
269struct Foo {}
270 ",
271 )
272 }
273
274 #[test]
275 fn test_ignore_if_not_derive() {
276 check_assist_not_applicable(
277 add_custom_impl,
278 "
279#[allow(non_camel_<|>case_types)]
280struct Foo {}
281 ",
282 )
283 }
284}
diff --git a/crates/assists/src/handlers/add_missing_impl_members.rs b/crates/assists/src/handlers/add_missing_impl_members.rs
index b82fb30ad..bbb71e261 100644
--- a/crates/assists/src/handlers/add_missing_impl_members.rs
+++ b/crates/assists/src/handlers/add_missing_impl_members.rs
@@ -1,27 +1,14 @@
1use hir::HasSource; 1use ide_db::traits::resolve_target_trait;
2use ide_db::traits::{get_missing_assoc_items, resolve_target_trait}; 2use syntax::ast::{self, AstNode};
3use syntax::{
4 ast::{
5 self,
6 edit::{self, AstNodeEdit, IndentLevel},
7 make, AstNode, NameOwner,
8 },
9 SmolStr,
10};
11 3
12use crate::{ 4use crate::{
13 assist_context::{AssistContext, Assists}, 5 assist_context::{AssistContext, Assists},
14 ast_transform::{self, AstTransform, QualifyPaths, SubstituteTypeParams}, 6 utils::add_trait_assoc_items_to_impl,
15 utils::{render_snippet, Cursor}, 7 utils::DefaultMethods,
8 utils::{filter_assoc_items, render_snippet, Cursor},
16 AssistId, AssistKind, 9 AssistId, AssistKind,
17}; 10};
18 11
19#[derive(PartialEq)]
20enum AddMissingImplMembersMode {
21 DefaultMethodsOnly,
22 NoDefaultMethods,
23}
24
25// Assist: add_impl_missing_members 12// Assist: add_impl_missing_members
26// 13//
27// Adds scaffold for required impl members. 14// Adds scaffold for required impl members.
@@ -55,7 +42,7 @@ pub(crate) fn add_missing_impl_members(acc: &mut Assists, ctx: &AssistContext) -
55 add_missing_impl_members_inner( 42 add_missing_impl_members_inner(
56 acc, 43 acc,
57 ctx, 44 ctx,
58 AddMissingImplMembersMode::NoDefaultMethods, 45 DefaultMethods::No,
59 "add_impl_missing_members", 46 "add_impl_missing_members",
60 "Implement missing members", 47 "Implement missing members",
61 ) 48 )
@@ -97,7 +84,7 @@ pub(crate) fn add_missing_default_members(acc: &mut Assists, ctx: &AssistContext
97 add_missing_impl_members_inner( 84 add_missing_impl_members_inner(
98 acc, 85 acc,
99 ctx, 86 ctx,
100 AddMissingImplMembersMode::DefaultMethodsOnly, 87 DefaultMethods::Only,
101 "add_impl_default_members", 88 "add_impl_default_members",
102 "Implement default members", 89 "Implement default members",
103 ) 90 )
@@ -106,7 +93,7 @@ pub(crate) fn add_missing_default_members(acc: &mut Assists, ctx: &AssistContext
106fn add_missing_impl_members_inner( 93fn add_missing_impl_members_inner(
107 acc: &mut Assists, 94 acc: &mut Assists,
108 ctx: &AssistContext, 95 ctx: &AssistContext,
109 mode: AddMissingImplMembersMode, 96 mode: DefaultMethods,
110 assist_id: &'static str, 97 assist_id: &'static str,
111 label: &'static str, 98 label: &'static str,
112) -> Option<()> { 99) -> Option<()> {
@@ -114,32 +101,11 @@ fn add_missing_impl_members_inner(
114 let impl_def = ctx.find_node_at_offset::<ast::Impl>()?; 101 let impl_def = ctx.find_node_at_offset::<ast::Impl>()?;
115 let trait_ = resolve_target_trait(&ctx.sema, &impl_def)?; 102 let trait_ = resolve_target_trait(&ctx.sema, &impl_def)?;
116 103
117 let def_name = |item: &ast::AssocItem| -> Option<SmolStr> { 104 let missing_items = filter_assoc_items(
118 match item { 105 ctx.db(),
119 ast::AssocItem::Fn(def) => def.name(), 106 &ide_db::traits::get_missing_assoc_items(&ctx.sema, &impl_def),
120 ast::AssocItem::TypeAlias(def) => def.name(), 107 mode,
121 ast::AssocItem::Const(def) => def.name(), 108 );
122 ast::AssocItem::MacroCall(_) => None,
123 }
124 .map(|it| it.text().clone())
125 };
126
127 let missing_items = get_missing_assoc_items(&ctx.sema, &impl_def)
128 .iter()
129 .map(|i| match i {
130 hir::AssocItem::Function(i) => ast::AssocItem::Fn(i.source(ctx.db()).value),
131 hir::AssocItem::TypeAlias(i) => ast::AssocItem::TypeAlias(i.source(ctx.db()).value),
132 hir::AssocItem::Const(i) => ast::AssocItem::Const(i.source(ctx.db()).value),
133 })
134 .filter(|t| def_name(&t).is_some())
135 .filter(|t| match t {
136 ast::AssocItem::Fn(def) => match mode {
137 AddMissingImplMembersMode::DefaultMethodsOnly => def.body().is_some(),
138 AddMissingImplMembersMode::NoDefaultMethods => def.body().is_none(),
139 },
140 _ => mode == AddMissingImplMembersMode::NoDefaultMethods,
141 })
142 .collect::<Vec<_>>();
143 109
144 if missing_items.is_empty() { 110 if missing_items.is_empty() {
145 return None; 111 return None;
@@ -147,29 +113,9 @@ fn add_missing_impl_members_inner(
147 113
148 let target = impl_def.syntax().text_range(); 114 let target = impl_def.syntax().text_range();
149 acc.add(AssistId(assist_id, AssistKind::QuickFix), label, target, |builder| { 115 acc.add(AssistId(assist_id, AssistKind::QuickFix), label, target, |builder| {
150 let impl_item_list = impl_def.assoc_item_list().unwrap_or_else(make::assoc_item_list);
151
152 let n_existing_items = impl_item_list.assoc_items().count();
153 let source_scope = ctx.sema.scope_for_def(trait_);
154 let target_scope = ctx.sema.scope(impl_def.syntax()); 116 let target_scope = ctx.sema.scope(impl_def.syntax());
155 let ast_transform = QualifyPaths::new(&target_scope, &source_scope) 117 let (new_impl_def, first_new_item) =
156 .or(SubstituteTypeParams::for_trait_impl(&source_scope, trait_, impl_def.clone())); 118 add_trait_assoc_items_to_impl(&ctx.sema, missing_items, trait_, impl_def, target_scope);
157
158 let items = missing_items
159 .into_iter()
160 .map(|it| ast_transform::apply(&*ast_transform, it))
161 .map(|it| match it {
162 ast::AssocItem::Fn(def) => ast::AssocItem::Fn(add_body(def)),
163 ast::AssocItem::TypeAlias(def) => ast::AssocItem::TypeAlias(def.remove_bounds()),
164 _ => it,
165 })
166 .map(|it| edit::remove_attrs_and_docs(&it));
167
168 let new_impl_item_list = impl_item_list.append_items(items);
169 let new_impl_def = impl_def.with_assoc_item_list(new_impl_item_list);
170 let first_new_item =
171 new_impl_def.assoc_item_list().unwrap().assoc_items().nth(n_existing_items).unwrap();
172
173 match ctx.config.snippet_cap { 119 match ctx.config.snippet_cap {
174 None => builder.replace(target, new_impl_def.to_string()), 120 None => builder.replace(target, new_impl_def.to_string()),
175 Some(cap) => { 121 Some(cap) => {
@@ -193,14 +139,6 @@ fn add_missing_impl_members_inner(
193 }) 139 })
194} 140}
195 141
196fn add_body(fn_def: ast::Fn) -> ast::Fn {
197 if fn_def.body().is_some() {
198 return fn_def;
199 }
200 let body = make::block_expr(None, Some(make::expr_todo())).indent(IndentLevel(1));
201 fn_def.with_body(body)
202}
203
204#[cfg(test)] 142#[cfg(test)]
205mod tests { 143mod tests {
206 use crate::tests::{check_assist, check_assist_not_applicable}; 144 use crate::tests::{check_assist, check_assist_not_applicable};
diff --git a/crates/assists/src/handlers/add_turbo_fish.rs b/crates/assists/src/handlers/add_turbo_fish.rs
index e3d84d698..1f486c013 100644
--- a/crates/assists/src/handlers/add_turbo_fish.rs
+++ b/crates/assists/src/handlers/add_turbo_fish.rs
@@ -25,7 +25,7 @@ use crate::{
25// } 25// }
26// ``` 26// ```
27pub(crate) fn add_turbo_fish(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 27pub(crate) fn add_turbo_fish(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
28 let ident = ctx.find_token_at_offset(SyntaxKind::IDENT).or_else(|| { 28 let ident = ctx.find_token_syntax_at_offset(SyntaxKind::IDENT).or_else(|| {
29 let arg_list = ctx.find_node_at_offset::<ast::ArgList>()?; 29 let arg_list = ctx.find_node_at_offset::<ast::ArgList>()?;
30 if arg_list.args().count() > 0 { 30 if arg_list.args().count() > 0 {
31 return None; 31 return None;
diff --git a/crates/assists/src/handlers/change_return_type_to_result.rs b/crates/assists/src/handlers/change_return_type_to_result.rs
deleted file mode 100644
index be480943c..000000000
--- a/crates/assists/src/handlers/change_return_type_to_result.rs
+++ /dev/null
@@ -1,998 +0,0 @@
1use std::iter;
2
3use syntax::{
4 ast::{self, make, BlockExpr, Expr, LoopBodyOwner},
5 AstNode, SyntaxNode,
6};
7use test_utils::mark;
8
9use crate::{AssistContext, AssistId, AssistKind, Assists};
10
11// Assist: change_return_type_to_result
12//
13// Change the function's return type to Result.
14//
15// ```
16// fn foo() -> i32<|> { 42i32 }
17// ```
18// ->
19// ```
20// fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }
21// ```
22pub(crate) fn change_return_type_to_result(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
23 let ret_type = ctx.find_node_at_offset::<ast::RetType>()?;
24 // FIXME: extend to lambdas as well
25 let fn_def = ret_type.syntax().parent().and_then(ast::Fn::cast)?;
26
27 let type_ref = &ret_type.ty()?;
28 let ret_type_str = type_ref.syntax().text().to_string();
29 let first_part_ret_type = ret_type_str.splitn(2, '<').next();
30 if let Some(ret_type_first_part) = first_part_ret_type {
31 if ret_type_first_part.ends_with("Result") {
32 mark::hit!(change_return_type_to_result_simple_return_type_already_result);
33 return None;
34 }
35 }
36
37 let block_expr = &fn_def.body()?;
38
39 acc.add(
40 AssistId("change_return_type_to_result", AssistKind::RefactorRewrite),
41 "Wrap return type in Result",
42 type_ref.syntax().text_range(),
43 |builder| {
44 let mut tail_return_expr_collector = TailReturnCollector::new();
45 tail_return_expr_collector.collect_jump_exprs(block_expr, false);
46 tail_return_expr_collector.collect_tail_exprs(block_expr);
47
48 for ret_expr_arg in tail_return_expr_collector.exprs_to_wrap {
49 let ok_wrapped = make::expr_call(
50 make::expr_path(make::path_unqualified(make::path_segment(make::name_ref(
51 "Ok",
52 )))),
53 make::arg_list(iter::once(ret_expr_arg.clone())),
54 );
55 builder.replace_ast(ret_expr_arg, ok_wrapped);
56 }
57
58 match ctx.config.snippet_cap {
59 Some(cap) => {
60 let snippet = format!("Result<{}, ${{0:_}}>", type_ref);
61 builder.replace_snippet(cap, type_ref.syntax().text_range(), snippet)
62 }
63 None => builder
64 .replace(type_ref.syntax().text_range(), format!("Result<{}, _>", type_ref)),
65 }
66 },
67 )
68}
69
70struct TailReturnCollector {
71 exprs_to_wrap: Vec<ast::Expr>,
72}
73
74impl TailReturnCollector {
75 fn new() -> Self {
76 Self { exprs_to_wrap: vec![] }
77 }
78 /// Collect all`return` expression
79 fn collect_jump_exprs(&mut self, block_expr: &BlockExpr, collect_break: bool) {
80 let statements = block_expr.statements();
81 for stmt in statements {
82 let expr = match &stmt {
83 ast::Stmt::ExprStmt(stmt) => stmt.expr(),
84 ast::Stmt::LetStmt(stmt) => stmt.initializer(),
85 ast::Stmt::Item(_) => continue,
86 };
87 if let Some(expr) = &expr {
88 self.handle_exprs(expr, collect_break);
89 }
90 }
91
92 // Browse tail expressions for each block
93 if let Some(expr) = block_expr.expr() {
94 if let Some(last_exprs) = get_tail_expr_from_block(&expr) {
95 for last_expr in last_exprs {
96 let last_expr = match last_expr {
97 NodeType::Node(expr) => expr,
98 NodeType::Leaf(expr) => expr.syntax().clone(),
99 };
100
101 if let Some(last_expr) = Expr::cast(last_expr.clone()) {
102 self.handle_exprs(&last_expr, collect_break);
103 } else if let Some(expr_stmt) = ast::Stmt::cast(last_expr) {
104 let expr_stmt = match &expr_stmt {
105 ast::Stmt::ExprStmt(stmt) => stmt.expr(),
106 ast::Stmt::LetStmt(stmt) => stmt.initializer(),
107 ast::Stmt::Item(_) => None,
108 };
109 if let Some(expr) = &expr_stmt {
110 self.handle_exprs(expr, collect_break);
111 }
112 }
113 }
114 }
115 }
116 }
117
118 fn handle_exprs(&mut self, expr: &Expr, collect_break: bool) {
119 match expr {
120 Expr::BlockExpr(block_expr) => {
121 self.collect_jump_exprs(&block_expr, collect_break);
122 }
123 Expr::ReturnExpr(ret_expr) => {
124 if let Some(ret_expr_arg) = &ret_expr.expr() {
125 self.exprs_to_wrap.push(ret_expr_arg.clone());
126 }
127 }
128 Expr::BreakExpr(break_expr) if collect_break => {
129 if let Some(break_expr_arg) = &break_expr.expr() {
130 self.exprs_to_wrap.push(break_expr_arg.clone());
131 }
132 }
133 Expr::IfExpr(if_expr) => {
134 for block in if_expr.blocks() {
135 self.collect_jump_exprs(&block, collect_break);
136 }
137 }
138 Expr::LoopExpr(loop_expr) => {
139 if let Some(block_expr) = loop_expr.loop_body() {
140 self.collect_jump_exprs(&block_expr, collect_break);
141 }
142 }
143 Expr::ForExpr(for_expr) => {
144 if let Some(block_expr) = for_expr.loop_body() {
145 self.collect_jump_exprs(&block_expr, collect_break);
146 }
147 }
148 Expr::WhileExpr(while_expr) => {
149 if let Some(block_expr) = while_expr.loop_body() {
150 self.collect_jump_exprs(&block_expr, collect_break);
151 }
152 }
153 Expr::MatchExpr(match_expr) => {
154 if let Some(arm_list) = match_expr.match_arm_list() {
155 arm_list.arms().filter_map(|match_arm| match_arm.expr()).for_each(|expr| {
156 self.handle_exprs(&expr, collect_break);
157 });
158 }
159 }
160 _ => {}
161 }
162 }
163
164 fn collect_tail_exprs(&mut self, block: &BlockExpr) {
165 if let Some(expr) = block.expr() {
166 self.handle_exprs(&expr, true);
167 self.fetch_tail_exprs(&expr);
168 }
169 }
170
171 fn fetch_tail_exprs(&mut self, expr: &Expr) {
172 if let Some(exprs) = get_tail_expr_from_block(expr) {
173 for node_type in &exprs {
174 match node_type {
175 NodeType::Leaf(expr) => {
176 self.exprs_to_wrap.push(expr.clone());
177 }
178 NodeType::Node(expr) => {
179 if let Some(last_expr) = Expr::cast(expr.clone()) {
180 self.fetch_tail_exprs(&last_expr);
181 }
182 }
183 }
184 }
185 }
186 }
187}
188
189#[derive(Debug)]
190enum NodeType {
191 Leaf(ast::Expr),
192 Node(SyntaxNode),
193}
194
195/// Get a tail expression inside a block
196fn get_tail_expr_from_block(expr: &Expr) -> Option<Vec<NodeType>> {
197 match expr {
198 Expr::IfExpr(if_expr) => {
199 let mut nodes = vec![];
200 for block in if_expr.blocks() {
201 if let Some(block_expr) = block.expr() {
202 if let Some(tail_exprs) = get_tail_expr_from_block(&block_expr) {
203 nodes.extend(tail_exprs);
204 }
205 } else if let Some(last_expr) = block.syntax().last_child() {
206 nodes.push(NodeType::Node(last_expr));
207 } else {
208 nodes.push(NodeType::Node(block.syntax().clone()));
209 }
210 }
211 Some(nodes)
212 }
213 Expr::LoopExpr(loop_expr) => {
214 loop_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
215 }
216 Expr::ForExpr(for_expr) => {
217 for_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
218 }
219 Expr::WhileExpr(while_expr) => {
220 while_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
221 }
222 Expr::BlockExpr(block_expr) => {
223 block_expr.expr().map(|lc| vec![NodeType::Node(lc.syntax().clone())])
224 }
225 Expr::MatchExpr(match_expr) => {
226 let arm_list = match_expr.match_arm_list()?;
227 let arms: Vec<NodeType> = arm_list
228 .arms()
229 .filter_map(|match_arm| match_arm.expr())
230 .map(|expr| match expr {
231 Expr::ReturnExpr(ret_expr) => NodeType::Node(ret_expr.syntax().clone()),
232 Expr::BreakExpr(break_expr) => NodeType::Node(break_expr.syntax().clone()),
233 _ => match expr.syntax().last_child() {
234 Some(last_expr) => NodeType::Node(last_expr),
235 None => NodeType::Node(expr.syntax().clone()),
236 },
237 })
238 .collect();
239
240 Some(arms)
241 }
242 Expr::BreakExpr(expr) => expr.expr().map(|e| vec![NodeType::Leaf(e)]),
243 Expr::ReturnExpr(ret_expr) => Some(vec![NodeType::Node(ret_expr.syntax().clone())]),
244
245 Expr::CallExpr(_)
246 | Expr::Literal(_)
247 | Expr::TupleExpr(_)
248 | Expr::ArrayExpr(_)
249 | Expr::ParenExpr(_)
250 | Expr::PathExpr(_)
251 | Expr::RecordExpr(_)
252 | Expr::IndexExpr(_)
253 | Expr::MethodCallExpr(_)
254 | Expr::AwaitExpr(_)
255 | Expr::CastExpr(_)
256 | Expr::RefExpr(_)
257 | Expr::PrefixExpr(_)
258 | Expr::RangeExpr(_)
259 | Expr::BinExpr(_)
260 | Expr::MacroCall(_)
261 | Expr::BoxExpr(_) => Some(vec![NodeType::Leaf(expr.clone())]),
262 _ => None,
263 }
264}
265
266#[cfg(test)]
267mod tests {
268 use crate::tests::{check_assist, check_assist_not_applicable};
269
270 use super::*;
271
272 #[test]
273 fn change_return_type_to_result_simple() {
274 check_assist(
275 change_return_type_to_result,
276 r#"fn foo() -> i3<|>2 {
277 let test = "test";
278 return 42i32;
279 }"#,
280 r#"fn foo() -> Result<i32, ${0:_}> {
281 let test = "test";
282 return Ok(42i32);
283 }"#,
284 );
285 }
286
287 #[test]
288 fn change_return_type_to_result_simple_return_type() {
289 check_assist(
290 change_return_type_to_result,
291 r#"fn foo() -> i32<|> {
292 let test = "test";
293 return 42i32;
294 }"#,
295 r#"fn foo() -> Result<i32, ${0:_}> {
296 let test = "test";
297 return Ok(42i32);
298 }"#,
299 );
300 }
301
302 #[test]
303 fn change_return_type_to_result_simple_return_type_bad_cursor() {
304 check_assist_not_applicable(
305 change_return_type_to_result,
306 r#"fn foo() -> i32 {
307 let test = "test";<|>
308 return 42i32;
309 }"#,
310 );
311 }
312
313 #[test]
314 fn change_return_type_to_result_simple_return_type_already_result_std() {
315 check_assist_not_applicable(
316 change_return_type_to_result,
317 r#"fn foo() -> std::result::Result<i32<|>, String> {
318 let test = "test";
319 return 42i32;
320 }"#,
321 );
322 }
323
324 #[test]
325 fn change_return_type_to_result_simple_return_type_already_result() {
326 mark::check!(change_return_type_to_result_simple_return_type_already_result);
327 check_assist_not_applicable(
328 change_return_type_to_result,
329 r#"fn foo() -> Result<i32<|>, String> {
330 let test = "test";
331 return 42i32;
332 }"#,
333 );
334 }
335
336 #[test]
337 fn change_return_type_to_result_simple_with_cursor() {
338 check_assist(
339 change_return_type_to_result,
340 r#"fn foo() -> <|>i32 {
341 let test = "test";
342 return 42i32;
343 }"#,
344 r#"fn foo() -> Result<i32, ${0:_}> {
345 let test = "test";
346 return Ok(42i32);
347 }"#,
348 );
349 }
350
351 #[test]
352 fn change_return_type_to_result_simple_with_tail() {
353 check_assist(
354 change_return_type_to_result,
355 r#"fn foo() -><|> i32 {
356 let test = "test";
357 42i32
358 }"#,
359 r#"fn foo() -> Result<i32, ${0:_}> {
360 let test = "test";
361 Ok(42i32)
362 }"#,
363 );
364 }
365
366 #[test]
367 fn change_return_type_to_result_simple_with_tail_only() {
368 check_assist(
369 change_return_type_to_result,
370 r#"fn foo() -> i32<|> {
371 42i32
372 }"#,
373 r#"fn foo() -> Result<i32, ${0:_}> {
374 Ok(42i32)
375 }"#,
376 );
377 }
378 #[test]
379 fn change_return_type_to_result_simple_with_tail_block_like() {
380 check_assist(
381 change_return_type_to_result,
382 r#"fn foo() -> i32<|> {
383 if true {
384 42i32
385 } else {
386 24i32
387 }
388 }"#,
389 r#"fn foo() -> Result<i32, ${0:_}> {
390 if true {
391 Ok(42i32)
392 } else {
393 Ok(24i32)
394 }
395 }"#,
396 );
397 }
398
399 #[test]
400 fn change_return_type_to_result_simple_with_nested_if() {
401 check_assist(
402 change_return_type_to_result,
403 r#"fn foo() -> i32<|> {
404 if true {
405 if false {
406 1
407 } else {
408 2
409 }
410 } else {
411 24i32
412 }
413 }"#,
414 r#"fn foo() -> Result<i32, ${0:_}> {
415 if true {
416 if false {
417 Ok(1)
418 } else {
419 Ok(2)
420 }
421 } else {
422 Ok(24i32)
423 }
424 }"#,
425 );
426 }
427
428 #[test]
429 fn change_return_type_to_result_simple_with_await() {
430 check_assist(
431 change_return_type_to_result,
432 r#"async fn foo() -> i<|>32 {
433 if true {
434 if false {
435 1.await
436 } else {
437 2.await
438 }
439 } else {
440 24i32.await
441 }
442 }"#,
443 r#"async fn foo() -> Result<i32, ${0:_}> {
444 if true {
445 if false {
446 Ok(1.await)
447 } else {
448 Ok(2.await)
449 }
450 } else {
451 Ok(24i32.await)
452 }
453 }"#,
454 );
455 }
456
457 #[test]
458 fn change_return_type_to_result_simple_with_array() {
459 check_assist(
460 change_return_type_to_result,
461 r#"fn foo() -> [i32;<|> 3] {
462 [1, 2, 3]
463 }"#,
464 r#"fn foo() -> Result<[i32; 3], ${0:_}> {
465 Ok([1, 2, 3])
466 }"#,
467 );
468 }
469
470 #[test]
471 fn change_return_type_to_result_simple_with_cast() {
472 check_assist(
473 change_return_type_to_result,
474 r#"fn foo() -<|>> i32 {
475 if true {
476 if false {
477 1 as i32
478 } else {
479 2 as i32
480 }
481 } else {
482 24 as i32
483 }
484 }"#,
485 r#"fn foo() -> Result<i32, ${0:_}> {
486 if true {
487 if false {
488 Ok(1 as i32)
489 } else {
490 Ok(2 as i32)
491 }
492 } else {
493 Ok(24 as i32)
494 }
495 }"#,
496 );
497 }
498
499 #[test]
500 fn change_return_type_to_result_simple_with_tail_block_like_match() {
501 check_assist(
502 change_return_type_to_result,
503 r#"fn foo() -> i32<|> {
504 let my_var = 5;
505 match my_var {
506 5 => 42i32,
507 _ => 24i32,
508 }
509 }"#,
510 r#"fn foo() -> Result<i32, ${0:_}> {
511 let my_var = 5;
512 match my_var {
513 5 => Ok(42i32),
514 _ => Ok(24i32),
515 }
516 }"#,
517 );
518 }
519
520 #[test]
521 fn change_return_type_to_result_simple_with_loop_with_tail() {
522 check_assist(
523 change_return_type_to_result,
524 r#"fn foo() -> i32<|> {
525 let my_var = 5;
526 loop {
527 println!("test");
528 5
529 }
530
531 my_var
532 }"#,
533 r#"fn foo() -> Result<i32, ${0:_}> {
534 let my_var = 5;
535 loop {
536 println!("test");
537 5
538 }
539
540 Ok(my_var)
541 }"#,
542 );
543 }
544
545 #[test]
546 fn change_return_type_to_result_simple_with_loop_in_let_stmt() {
547 check_assist(
548 change_return_type_to_result,
549 r#"fn foo() -> i32<|> {
550 let my_var = let x = loop {
551 break 1;
552 };
553
554 my_var
555 }"#,
556 r#"fn foo() -> Result<i32, ${0:_}> {
557 let my_var = let x = loop {
558 break 1;
559 };
560
561 Ok(my_var)
562 }"#,
563 );
564 }
565
566 #[test]
567 fn change_return_type_to_result_simple_with_tail_block_like_match_return_expr() {
568 check_assist(
569 change_return_type_to_result,
570 r#"fn foo() -> i32<|> {
571 let my_var = 5;
572 let res = match my_var {
573 5 => 42i32,
574 _ => return 24i32,
575 };
576
577 res
578 }"#,
579 r#"fn foo() -> Result<i32, ${0:_}> {
580 let my_var = 5;
581 let res = match my_var {
582 5 => 42i32,
583 _ => return Ok(24i32),
584 };
585
586 Ok(res)
587 }"#,
588 );
589
590 check_assist(
591 change_return_type_to_result,
592 r#"fn foo() -> i32<|> {
593 let my_var = 5;
594 let res = if my_var == 5 {
595 42i32
596 } else {
597 return 24i32;
598 };
599
600 res
601 }"#,
602 r#"fn foo() -> Result<i32, ${0:_}> {
603 let my_var = 5;
604 let res = if my_var == 5 {
605 42i32
606 } else {
607 return Ok(24i32);
608 };
609
610 Ok(res)
611 }"#,
612 );
613 }
614
615 #[test]
616 fn change_return_type_to_result_simple_with_tail_block_like_match_deeper() {
617 check_assist(
618 change_return_type_to_result,
619 r#"fn foo() -> i32<|> {
620 let my_var = 5;
621 match my_var {
622 5 => {
623 if true {
624 42i32
625 } else {
626 25i32
627 }
628 },
629 _ => {
630 let test = "test";
631 if test == "test" {
632 return bar();
633 }
634 53i32
635 },
636 }
637 }"#,
638 r#"fn foo() -> Result<i32, ${0:_}> {
639 let my_var = 5;
640 match my_var {
641 5 => {
642 if true {
643 Ok(42i32)
644 } else {
645 Ok(25i32)
646 }
647 },
648 _ => {
649 let test = "test";
650 if test == "test" {
651 return Ok(bar());
652 }
653 Ok(53i32)
654 },
655 }
656 }"#,
657 );
658 }
659
660 #[test]
661 fn change_return_type_to_result_simple_with_tail_block_like_early_return() {
662 check_assist(
663 change_return_type_to_result,
664 r#"fn foo() -> i<|>32 {
665 let test = "test";
666 if test == "test" {
667 return 24i32;
668 }
669 53i32
670 }"#,
671 r#"fn foo() -> Result<i32, ${0:_}> {
672 let test = "test";
673 if test == "test" {
674 return Ok(24i32);
675 }
676 Ok(53i32)
677 }"#,
678 );
679 }
680
681 #[test]
682 fn change_return_type_to_result_simple_with_closure() {
683 check_assist(
684 change_return_type_to_result,
685 r#"fn foo(the_field: u32) -><|> u32 {
686 let true_closure = || {
687 return true;
688 };
689 if the_field < 5 {
690 let mut i = 0;
691
692
693 if true_closure() {
694 return 99;
695 } else {
696 return 0;
697 }
698 }
699
700 the_field
701 }"#,
702 r#"fn foo(the_field: u32) -> Result<u32, ${0:_}> {
703 let true_closure = || {
704 return true;
705 };
706 if the_field < 5 {
707 let mut i = 0;
708
709
710 if true_closure() {
711 return Ok(99);
712 } else {
713 return Ok(0);
714 }
715 }
716
717 Ok(the_field)
718 }"#,
719 );
720
721 check_assist(
722 change_return_type_to_result,
723 r#"fn foo(the_field: u32) -> u32<|> {
724 let true_closure = || {
725 return true;
726 };
727 if the_field < 5 {
728 let mut i = 0;
729
730
731 if true_closure() {
732 return 99;
733 } else {
734 return 0;
735 }
736 }
737 let t = None;
738
739 t.unwrap_or_else(|| the_field)
740 }"#,
741 r#"fn foo(the_field: u32) -> Result<u32, ${0:_}> {
742 let true_closure = || {
743 return true;
744 };
745 if the_field < 5 {
746 let mut i = 0;
747
748
749 if true_closure() {
750 return Ok(99);
751 } else {
752 return Ok(0);
753 }
754 }
755 let t = None;
756
757 Ok(t.unwrap_or_else(|| the_field))
758 }"#,
759 );
760 }
761
762 #[test]
763 fn change_return_type_to_result_simple_with_weird_forms() {
764 check_assist(
765 change_return_type_to_result,
766 r#"fn foo() -> i32<|> {
767 let test = "test";
768 if test == "test" {
769 return 24i32;
770 }
771 let mut i = 0;
772 loop {
773 if i == 1 {
774 break 55;
775 }
776 i += 1;
777 }
778 }"#,
779 r#"fn foo() -> Result<i32, ${0:_}> {
780 let test = "test";
781 if test == "test" {
782 return Ok(24i32);
783 }
784 let mut i = 0;
785 loop {
786 if i == 1 {
787 break Ok(55);
788 }
789 i += 1;
790 }
791 }"#,
792 );
793
794 check_assist(
795 change_return_type_to_result,
796 r#"fn foo() -> i32<|> {
797 let test = "test";
798 if test == "test" {
799 return 24i32;
800 }
801 let mut i = 0;
802 loop {
803 loop {
804 if i == 1 {
805 break 55;
806 }
807 i += 1;
808 }
809 }
810 }"#,
811 r#"fn foo() -> Result<i32, ${0:_}> {
812 let test = "test";
813 if test == "test" {
814 return Ok(24i32);
815 }
816 let mut i = 0;
817 loop {
818 loop {
819 if i == 1 {
820 break Ok(55);
821 }
822 i += 1;
823 }
824 }
825 }"#,
826 );
827
828 check_assist(
829 change_return_type_to_result,
830 r#"fn foo() -> i3<|>2 {
831 let test = "test";
832 let other = 5;
833 if test == "test" {
834 let res = match other {
835 5 => 43,
836 _ => return 56,
837 };
838 }
839 let mut i = 0;
840 loop {
841 loop {
842 if i == 1 {
843 break 55;
844 }
845 i += 1;
846 }
847 }
848 }"#,
849 r#"fn foo() -> Result<i32, ${0:_}> {
850 let test = "test";
851 let other = 5;
852 if test == "test" {
853 let res = match other {
854 5 => 43,
855 _ => return Ok(56),
856 };
857 }
858 let mut i = 0;
859 loop {
860 loop {
861 if i == 1 {
862 break Ok(55);
863 }
864 i += 1;
865 }
866 }
867 }"#,
868 );
869
870 check_assist(
871 change_return_type_to_result,
872 r#"fn foo(the_field: u32) -> u32<|> {
873 if the_field < 5 {
874 let mut i = 0;
875 loop {
876 if i > 5 {
877 return 55u32;
878 }
879 i += 3;
880 }
881
882 match i {
883 5 => return 99,
884 _ => return 0,
885 };
886 }
887
888 the_field
889 }"#,
890 r#"fn foo(the_field: u32) -> Result<u32, ${0:_}> {
891 if the_field < 5 {
892 let mut i = 0;
893 loop {
894 if i > 5 {
895 return Ok(55u32);
896 }
897 i += 3;
898 }
899
900 match i {
901 5 => return Ok(99),
902 _ => return Ok(0),
903 };
904 }
905
906 Ok(the_field)
907 }"#,
908 );
909
910 check_assist(
911 change_return_type_to_result,
912 r#"fn foo(the_field: u32) -> u3<|>2 {
913 if the_field < 5 {
914 let mut i = 0;
915
916 match i {
917 5 => return 99,
918 _ => return 0,
919 }
920 }
921
922 the_field
923 }"#,
924 r#"fn foo(the_field: u32) -> Result<u32, ${0:_}> {
925 if the_field < 5 {
926 let mut i = 0;
927
928 match i {
929 5 => return Ok(99),
930 _ => return Ok(0),
931 }
932 }
933
934 Ok(the_field)
935 }"#,
936 );
937
938 check_assist(
939 change_return_type_to_result,
940 r#"fn foo(the_field: u32) -> u32<|> {
941 if the_field < 5 {
942 let mut i = 0;
943
944 if i == 5 {
945 return 99
946 } else {
947 return 0
948 }
949 }
950
951 the_field
952 }"#,
953 r#"fn foo(the_field: u32) -> Result<u32, ${0:_}> {
954 if the_field < 5 {
955 let mut i = 0;
956
957 if i == 5 {
958 return Ok(99)
959 } else {
960 return Ok(0)
961 }
962 }
963
964 Ok(the_field)
965 }"#,
966 );
967
968 check_assist(
969 change_return_type_to_result,
970 r#"fn foo(the_field: u32) -> <|>u32 {
971 if the_field < 5 {
972 let mut i = 0;
973
974 if i == 5 {
975 return 99;
976 } else {
977 return 0;
978 }
979 }
980
981 the_field
982 }"#,
983 r#"fn foo(the_field: u32) -> Result<u32, ${0:_}> {
984 if the_field < 5 {
985 let mut i = 0;
986
987 if i == 5 {
988 return Ok(99);
989 } else {
990 return Ok(0);
991 }
992 }
993
994 Ok(the_field)
995 }"#,
996 );
997 }
998}
diff --git a/crates/assists/src/handlers/convert_integer_literal.rs b/crates/assists/src/handlers/convert_integer_literal.rs
index c8af80701..667115382 100644
--- a/crates/assists/src/handlers/convert_integer_literal.rs
+++ b/crates/assists/src/handlers/convert_integer_literal.rs
@@ -1,4 +1,4 @@
1use syntax::{ast, ast::Radix, AstNode}; 1use syntax::{ast, ast::Radix, AstToken};
2 2
3use crate::{AssistContext, AssistId, AssistKind, Assists, GroupLabel}; 3use crate::{AssistContext, AssistId, AssistKind, Assists, GroupLabel};
4 4
@@ -15,14 +15,16 @@ use crate::{AssistContext, AssistId, AssistKind, Assists, GroupLabel};
15// ``` 15// ```
16pub(crate) fn convert_integer_literal(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 16pub(crate) fn convert_integer_literal(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
17 let literal = ctx.find_node_at_offset::<ast::Literal>()?; 17 let literal = ctx.find_node_at_offset::<ast::Literal>()?;
18 let (radix, value) = literal.int_value()?; 18 let literal = match literal.kind() {
19 ast::LiteralKind::IntNumber(it) => it,
20 _ => return None,
21 };
22 let radix = literal.radix();
23 let value = literal.value()?;
24 let suffix = literal.suffix();
19 25
20 let range = literal.syntax().text_range(); 26 let range = literal.syntax().text_range();
21 let group_id = GroupLabel("Convert integer base".into()); 27 let group_id = GroupLabel("Convert integer base".into());
22 let suffix = match literal.kind() {
23 ast::LiteralKind::IntNumber { suffix } => suffix,
24 _ => return None,
25 };
26 28
27 for &target_radix in Radix::ALL { 29 for &target_radix in Radix::ALL {
28 if target_radix == radix { 30 if target_radix == radix {
@@ -36,16 +38,11 @@ pub(crate) fn convert_integer_literal(acc: &mut Assists, ctx: &AssistContext) ->
36 Radix::Hexadecimal => format!("0x{:X}", value), 38 Radix::Hexadecimal => format!("0x{:X}", value),
37 }; 39 };
38 40
39 let label = format!( 41 let label = format!("Convert {} to {}{}", literal, converted, suffix.unwrap_or_default());
40 "Convert {} to {}{}",
41 literal,
42 converted,
43 suffix.as_deref().unwrap_or_default()
44 );
45 42
46 // Appends the type suffix back into the new literal if it exists. 43 // Appends the type suffix back into the new literal if it exists.
47 if let Some(suffix) = &suffix { 44 if let Some(suffix) = suffix {
48 converted.push_str(&suffix); 45 converted.push_str(suffix);
49 } 46 }
50 47
51 acc.add_group( 48 acc.add_group(
@@ -132,34 +129,6 @@ mod tests {
132 ); 129 );
133 } 130 }
134 131
135 // Decimal numbers under 3 digits have a special case where they return early because we can't fit a
136 // other base's prefix, so we have a separate test for that.
137 #[test]
138 fn convert_small_decimal_integer() {
139 let before = "const _: i32 = 10<|>;";
140
141 check_assist_by_label(
142 convert_integer_literal,
143 before,
144 "const _: i32 = 0b1010;",
145 "Convert 10 to 0b1010",
146 );
147
148 check_assist_by_label(
149 convert_integer_literal,
150 before,
151 "const _: i32 = 0o12;",
152 "Convert 10 to 0o12",
153 );
154
155 check_assist_by_label(
156 convert_integer_literal,
157 before,
158 "const _: i32 = 0xA;",
159 "Convert 10 to 0xA",
160 );
161 }
162
163 #[test] 132 #[test]
164 fn convert_hexadecimal_integer() { 133 fn convert_hexadecimal_integer() {
165 let before = "const _: i32 = 0xFF<|>;"; 134 let before = "const _: i32 = 0xFF<|>;";
@@ -239,7 +208,7 @@ mod tests {
239 } 208 }
240 209
241 #[test] 210 #[test]
242 fn convert_decimal_integer_with_underscores() { 211 fn convert_integer_with_underscores() {
243 let before = "const _: i32 = 1_00_0<|>;"; 212 let before = "const _: i32 = 1_00_0<|>;";
244 213
245 check_assist_by_label( 214 check_assist_by_label(
@@ -265,111 +234,7 @@ mod tests {
265 } 234 }
266 235
267 #[test] 236 #[test]
268 fn convert_small_decimal_integer_with_underscores() { 237 fn convert_integer_with_suffix() {
269 let before = "const _: i32 = 1_0<|>;";
270
271 check_assist_by_label(
272 convert_integer_literal,
273 before,
274 "const _: i32 = 0b1010;",
275 "Convert 1_0 to 0b1010",
276 );
277
278 check_assist_by_label(
279 convert_integer_literal,
280 before,
281 "const _: i32 = 0o12;",
282 "Convert 1_0 to 0o12",
283 );
284
285 check_assist_by_label(
286 convert_integer_literal,
287 before,
288 "const _: i32 = 0xA;",
289 "Convert 1_0 to 0xA",
290 );
291 }
292
293 #[test]
294 fn convert_hexadecimal_integer_with_underscores() {
295 let before = "const _: i32 = 0x_F_F<|>;";
296
297 check_assist_by_label(
298 convert_integer_literal,
299 before,
300 "const _: i32 = 0b11111111;",
301 "Convert 0x_F_F to 0b11111111",
302 );
303
304 check_assist_by_label(
305 convert_integer_literal,
306 before,
307 "const _: i32 = 0o377;",
308 "Convert 0x_F_F to 0o377",
309 );
310
311 check_assist_by_label(
312 convert_integer_literal,
313 before,
314 "const _: i32 = 255;",
315 "Convert 0x_F_F to 255",
316 );
317 }
318
319 #[test]
320 fn convert_binary_integer_with_underscores() {
321 let before = "const _: i32 = 0b1111_1111<|>;";
322
323 check_assist_by_label(
324 convert_integer_literal,
325 before,
326 "const _: i32 = 0o377;",
327 "Convert 0b1111_1111 to 0o377",
328 );
329
330 check_assist_by_label(
331 convert_integer_literal,
332 before,
333 "const _: i32 = 255;",
334 "Convert 0b1111_1111 to 255",
335 );
336
337 check_assist_by_label(
338 convert_integer_literal,
339 before,
340 "const _: i32 = 0xFF;",
341 "Convert 0b1111_1111 to 0xFF",
342 );
343 }
344
345 #[test]
346 fn convert_octal_integer_with_underscores() {
347 let before = "const _: i32 = 0o3_77<|>;";
348
349 check_assist_by_label(
350 convert_integer_literal,
351 before,
352 "const _: i32 = 0b11111111;",
353 "Convert 0o3_77 to 0b11111111",
354 );
355
356 check_assist_by_label(
357 convert_integer_literal,
358 before,
359 "const _: i32 = 255;",
360 "Convert 0o3_77 to 255",
361 );
362
363 check_assist_by_label(
364 convert_integer_literal,
365 before,
366 "const _: i32 = 0xFF;",
367 "Convert 0o3_77 to 0xFF",
368 );
369 }
370
371 #[test]
372 fn convert_decimal_integer_with_suffix() {
373 let before = "const _: i32 = 1000i32<|>;"; 238 let before = "const _: i32 = 1000i32<|>;";
374 239
375 check_assist_by_label( 240 check_assist_by_label(
@@ -395,240 +260,6 @@ mod tests {
395 } 260 }
396 261
397 #[test] 262 #[test]
398 fn convert_small_decimal_integer_with_suffix() {
399 let before = "const _: i32 = 10i32<|>;";
400
401 check_assist_by_label(
402 convert_integer_literal,
403 before,
404 "const _: i32 = 0b1010i32;",
405 "Convert 10i32 to 0b1010i32",
406 );
407
408 check_assist_by_label(
409 convert_integer_literal,
410 before,
411 "const _: i32 = 0o12i32;",
412 "Convert 10i32 to 0o12i32",
413 );
414
415 check_assist_by_label(
416 convert_integer_literal,
417 before,
418 "const _: i32 = 0xAi32;",
419 "Convert 10i32 to 0xAi32",
420 );
421 }
422
423 #[test]
424 fn convert_hexadecimal_integer_with_suffix() {
425 let before = "const _: i32 = 0xFFi32<|>;";
426
427 check_assist_by_label(
428 convert_integer_literal,
429 before,
430 "const _: i32 = 0b11111111i32;",
431 "Convert 0xFFi32 to 0b11111111i32",
432 );
433
434 check_assist_by_label(
435 convert_integer_literal,
436 before,
437 "const _: i32 = 0o377i32;",
438 "Convert 0xFFi32 to 0o377i32",
439 );
440
441 check_assist_by_label(
442 convert_integer_literal,
443 before,
444 "const _: i32 = 255i32;",
445 "Convert 0xFFi32 to 255i32",
446 );
447 }
448
449 #[test]
450 fn convert_binary_integer_with_suffix() {
451 let before = "const _: i32 = 0b11111111i32<|>;";
452
453 check_assist_by_label(
454 convert_integer_literal,
455 before,
456 "const _: i32 = 0o377i32;",
457 "Convert 0b11111111i32 to 0o377i32",
458 );
459
460 check_assist_by_label(
461 convert_integer_literal,
462 before,
463 "const _: i32 = 255i32;",
464 "Convert 0b11111111i32 to 255i32",
465 );
466
467 check_assist_by_label(
468 convert_integer_literal,
469 before,
470 "const _: i32 = 0xFFi32;",
471 "Convert 0b11111111i32 to 0xFFi32",
472 );
473 }
474
475 #[test]
476 fn convert_octal_integer_with_suffix() {
477 let before = "const _: i32 = 0o377i32<|>;";
478
479 check_assist_by_label(
480 convert_integer_literal,
481 before,
482 "const _: i32 = 0b11111111i32;",
483 "Convert 0o377i32 to 0b11111111i32",
484 );
485
486 check_assist_by_label(
487 convert_integer_literal,
488 before,
489 "const _: i32 = 255i32;",
490 "Convert 0o377i32 to 255i32",
491 );
492
493 check_assist_by_label(
494 convert_integer_literal,
495 before,
496 "const _: i32 = 0xFFi32;",
497 "Convert 0o377i32 to 0xFFi32",
498 );
499 }
500
501 #[test]
502 fn convert_decimal_integer_with_underscores_and_suffix() {
503 let before = "const _: i32 = 1_00_0i32<|>;";
504
505 check_assist_by_label(
506 convert_integer_literal,
507 before,
508 "const _: i32 = 0b1111101000i32;",
509 "Convert 1_00_0i32 to 0b1111101000i32",
510 );
511
512 check_assist_by_label(
513 convert_integer_literal,
514 before,
515 "const _: i32 = 0o1750i32;",
516 "Convert 1_00_0i32 to 0o1750i32",
517 );
518
519 check_assist_by_label(
520 convert_integer_literal,
521 before,
522 "const _: i32 = 0x3E8i32;",
523 "Convert 1_00_0i32 to 0x3E8i32",
524 );
525 }
526
527 #[test]
528 fn convert_small_decimal_integer_with_underscores_and_suffix() {
529 let before = "const _: i32 = 1_0i32<|>;";
530
531 check_assist_by_label(
532 convert_integer_literal,
533 before,
534 "const _: i32 = 0b1010i32;",
535 "Convert 1_0i32 to 0b1010i32",
536 );
537
538 check_assist_by_label(
539 convert_integer_literal,
540 before,
541 "const _: i32 = 0o12i32;",
542 "Convert 1_0i32 to 0o12i32",
543 );
544
545 check_assist_by_label(
546 convert_integer_literal,
547 before,
548 "const _: i32 = 0xAi32;",
549 "Convert 1_0i32 to 0xAi32",
550 );
551 }
552
553 #[test]
554 fn convert_hexadecimal_integer_with_underscores_and_suffix() {
555 let before = "const _: i32 = 0x_F_Fi32<|>;";
556
557 check_assist_by_label(
558 convert_integer_literal,
559 before,
560 "const _: i32 = 0b11111111i32;",
561 "Convert 0x_F_Fi32 to 0b11111111i32",
562 );
563
564 check_assist_by_label(
565 convert_integer_literal,
566 before,
567 "const _: i32 = 0o377i32;",
568 "Convert 0x_F_Fi32 to 0o377i32",
569 );
570
571 check_assist_by_label(
572 convert_integer_literal,
573 before,
574 "const _: i32 = 255i32;",
575 "Convert 0x_F_Fi32 to 255i32",
576 );
577 }
578
579 #[test]
580 fn convert_binary_integer_with_underscores_and_suffix() {
581 let before = "const _: i32 = 0b1111_1111i32<|>;";
582
583 check_assist_by_label(
584 convert_integer_literal,
585 before,
586 "const _: i32 = 0o377i32;",
587 "Convert 0b1111_1111i32 to 0o377i32",
588 );
589
590 check_assist_by_label(
591 convert_integer_literal,
592 before,
593 "const _: i32 = 255i32;",
594 "Convert 0b1111_1111i32 to 255i32",
595 );
596
597 check_assist_by_label(
598 convert_integer_literal,
599 before,
600 "const _: i32 = 0xFFi32;",
601 "Convert 0b1111_1111i32 to 0xFFi32",
602 );
603 }
604
605 #[test]
606 fn convert_octal_integer_with_underscores_and_suffix() {
607 let before = "const _: i32 = 0o3_77i32<|>;";
608
609 check_assist_by_label(
610 convert_integer_literal,
611 before,
612 "const _: i32 = 0b11111111i32;",
613 "Convert 0o3_77i32 to 0b11111111i32",
614 );
615
616 check_assist_by_label(
617 convert_integer_literal,
618 before,
619 "const _: i32 = 255i32;",
620 "Convert 0o3_77i32 to 255i32",
621 );
622
623 check_assist_by_label(
624 convert_integer_literal,
625 before,
626 "const _: i32 = 0xFFi32;",
627 "Convert 0o3_77i32 to 0xFFi32",
628 );
629 }
630
631 #[test]
632 fn convert_overflowing_literal() { 263 fn convert_overflowing_literal() {
633 let before = "const _: i32 = 264 let before = "const _: i32 =
634 111111111111111111111111111111111111111111111111111111111111111111111111<|>;"; 265 111111111111111111111111111111111111111111111111111111111111111111111111<|>;";
diff --git a/crates/assists/src/handlers/expand_glob_import.rs b/crates/assists/src/handlers/expand_glob_import.rs
index 316a58d88..f51a9a4ad 100644
--- a/crates/assists/src/handlers/expand_glob_import.rs
+++ b/crates/assists/src/handlers/expand_glob_import.rs
@@ -5,13 +5,13 @@ use ide_db::{
5 search::SearchScope, 5 search::SearchScope,
6}; 6};
7use syntax::{ 7use syntax::{
8 algo, 8 algo::SyntaxRewriter,
9 ast::{self, make}, 9 ast::{self, make},
10 AstNode, Direction, SyntaxNode, SyntaxToken, T, 10 AstNode, Direction, SyntaxNode, SyntaxToken, T,
11}; 11};
12 12
13use crate::{ 13use crate::{
14 assist_context::{AssistBuilder, AssistContext, Assists}, 14 assist_context::{AssistContext, Assists},
15 AssistId, AssistKind, 15 AssistId, AssistKind,
16}; 16};
17 17
@@ -41,7 +41,7 @@ use crate::{
41// fn qux(bar: Bar, baz: Baz) {} 41// fn qux(bar: Bar, baz: Baz) {}
42// ``` 42// ```
43pub(crate) fn expand_glob_import(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 43pub(crate) fn expand_glob_import(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
44 let star = ctx.find_token_at_offset(T![*])?; 44 let star = ctx.find_token_syntax_at_offset(T![*])?;
45 let (parent, mod_path) = find_parent_and_path(&star)?; 45 let (parent, mod_path) = find_parent_and_path(&star)?;
46 let target_module = match ctx.sema.resolve_path(&mod_path)? { 46 let target_module = match ctx.sema.resolve_path(&mod_path)? {
47 PathResolution::Def(ModuleDef::Module(it)) => it, 47 PathResolution::Def(ModuleDef::Module(it)) => it,
@@ -61,7 +61,9 @@ pub(crate) fn expand_glob_import(acc: &mut Assists, ctx: &AssistContext) -> Opti
61 "Expand glob import", 61 "Expand glob import",
62 target.text_range(), 62 target.text_range(),
63 |builder| { 63 |builder| {
64 replace_ast(builder, parent, mod_path, names_to_import); 64 let mut rewriter = SyntaxRewriter::default();
65 replace_ast(&mut rewriter, parent, mod_path, names_to_import);
66 builder.rewrite(rewriter);
65 }, 67 },
66 ) 68 )
67} 69}
@@ -236,7 +238,7 @@ fn find_names_to_import(
236} 238}
237 239
238fn replace_ast( 240fn replace_ast(
239 builder: &mut AssistBuilder, 241 rewriter: &mut SyntaxRewriter,
240 parent: Either<ast::UseTree, ast::UseTreeList>, 242 parent: Either<ast::UseTree, ast::UseTreeList>,
241 path: ast::Path, 243 path: ast::Path,
242 names_to_import: Vec<Name>, 244 names_to_import: Vec<Name>,
@@ -264,32 +266,21 @@ fn replace_ast(
264 match use_trees.as_slice() { 266 match use_trees.as_slice() {
265 [name] => { 267 [name] => {
266 if let Some(end_path) = name.path() { 268 if let Some(end_path) = name.path() {
267 let replacement = 269 rewriter.replace_ast(
268 make::use_tree(make::path_concat(path, end_path), None, None, false); 270 &parent.left_or_else(|tl| tl.parent_use_tree()),
269 271 &make::use_tree(make::path_concat(path, end_path), None, None, false),
270 algo::diff( 272 );
271 &parent.either(|n| n.syntax().clone(), |n| n.syntax().clone()),
272 replacement.syntax(),
273 )
274 .into_text_edit(builder.text_edit_builder());
275 } 273 }
276 } 274 }
277 names => { 275 names => match &parent {
278 let replacement = match parent { 276 Either::Left(parent) => rewriter.replace_ast(
279 Either::Left(_) => { 277 parent,
280 make::use_tree(path, Some(make::use_tree_list(names.to_owned())), None, false) 278 &make::use_tree(path, Some(make::use_tree_list(names.to_owned())), None, false),
281 .syntax() 279 ),
282 .clone() 280 Either::Right(parent) => {
283 } 281 rewriter.replace_ast(parent, &make::use_tree_list(names.to_owned()))
284 Either::Right(_) => make::use_tree_list(names.to_owned()).syntax().clone(), 282 }
285 }; 283 },
286
287 algo::diff(
288 &parent.either(|n| n.syntax().clone(), |n| n.syntax().clone()),
289 &replacement,
290 )
291 .into_text_edit(builder.text_edit_builder());
292 }
293 }; 284 };
294} 285}
295 286
@@ -884,4 +875,33 @@ fn qux(baz: Baz) {}
884 ", 875 ",
885 ) 876 )
886 } 877 }
878
879 #[test]
880 fn expanding_glob_import_single_nested_glob_only() {
881 check_assist(
882 expand_glob_import,
883 r"
884mod foo {
885 pub struct Bar;
886}
887
888use foo::{*<|>};
889
890struct Baz {
891 bar: Bar
892}
893",
894 r"
895mod foo {
896 pub struct Bar;
897}
898
899use foo::Bar;
900
901struct Baz {
902 bar: Bar
903}
904",
905 );
906 }
887} 907}
diff --git a/crates/assists/src/handlers/extract_struct_from_enum_variant.rs b/crates/assists/src/handlers/extract_struct_from_enum_variant.rs
index 14209b771..84662d832 100644
--- a/crates/assists/src/handlers/extract_struct_from_enum_variant.rs
+++ b/crates/assists/src/handlers/extract_struct_from_enum_variant.rs
@@ -345,6 +345,73 @@ fn another_fn() {
345 ); 345 );
346 } 346 }
347 347
348 #[test]
349 fn test_several_files() {
350 check_assist(
351 extract_struct_from_enum_variant,
352 r#"
353//- /main.rs
354enum E {
355 <|>V(i32, i32)
356}
357mod foo;
358
359//- /foo.rs
360use crate::E;
361fn f() {
362 let e = E::V(9, 2);
363}
364"#,
365 r#"
366//- /main.rs
367struct V(pub i32, pub i32);
368
369enum E {
370 V(V)
371}
372mod foo;
373
374//- /foo.rs
375use V;
376
377use crate::E;
378fn f() {
379 let e = E::V(V(9, 2));
380}
381"#,
382 )
383 }
384
385 #[test]
386 fn test_several_files_record() {
387 // FIXME: this should fix the usage as well!
388 check_assist(
389 extract_struct_from_enum_variant,
390 r#"
391//- /main.rs
392enum E {
393 <|>V { i: i32, j: i32 }
394}
395mod foo;
396
397//- /foo.rs
398use crate::E;
399fn f() {
400 let e = E::V { i: 9, j: 2 };
401}
402"#,
403 r#"
404struct V{ pub i: i32, pub j: i32 }
405
406enum E {
407 V(V)
408}
409mod foo;
410
411"#,
412 )
413 }
414
348 fn check_not_applicable(ra_fixture: &str) { 415 fn check_not_applicable(ra_fixture: &str) {
349 let fixture = 416 let fixture =
350 format!("//- /main.rs crate:main deps:core\n{}\n{}", ra_fixture, FamousDefs::FIXTURE); 417 format!("//- /main.rs crate:main deps:core\n{}\n{}", ra_fixture, FamousDefs::FIXTURE);
diff --git a/crates/assists/src/handlers/flip_comma.rs b/crates/assists/src/handlers/flip_comma.rs
index 5c69db53e..64b4b1a76 100644
--- a/crates/assists/src/handlers/flip_comma.rs
+++ b/crates/assists/src/handlers/flip_comma.rs
@@ -18,7 +18,7 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
18// } 18// }
19// ``` 19// ```
20pub(crate) fn flip_comma(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 20pub(crate) fn flip_comma(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
21 let comma = ctx.find_token_at_offset(T![,])?; 21 let comma = ctx.find_token_syntax_at_offset(T![,])?;
22 let prev = non_trivia_sibling(comma.clone().into(), Direction::Prev)?; 22 let prev = non_trivia_sibling(comma.clone().into(), Direction::Prev)?;
23 let next = non_trivia_sibling(comma.clone().into(), Direction::Next)?; 23 let next = non_trivia_sibling(comma.clone().into(), Direction::Next)?;
24 24
diff --git a/crates/assists/src/handlers/flip_trait_bound.rs b/crates/assists/src/handlers/flip_trait_bound.rs
index 347e79b1d..92ee42181 100644
--- a/crates/assists/src/handlers/flip_trait_bound.rs
+++ b/crates/assists/src/handlers/flip_trait_bound.rs
@@ -20,7 +20,7 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
20pub(crate) fn flip_trait_bound(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 20pub(crate) fn flip_trait_bound(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
21 // We want to replicate the behavior of `flip_binexpr` by only suggesting 21 // We want to replicate the behavior of `flip_binexpr` by only suggesting
22 // the assist when the cursor is on a `+` 22 // the assist when the cursor is on a `+`
23 let plus = ctx.find_token_at_offset(T![+])?; 23 let plus = ctx.find_token_syntax_at_offset(T![+])?;
24 24
25 // Make sure we're in a `TypeBoundList` 25 // Make sure we're in a `TypeBoundList`
26 if ast::TypeBoundList::cast(plus.parent()).is_none() { 26 if ast::TypeBoundList::cast(plus.parent()).is_none() {
diff --git a/crates/assists/src/handlers/infer_function_return_type.rs b/crates/assists/src/handlers/infer_function_return_type.rs
new file mode 100644
index 000000000..520d07ae0
--- /dev/null
+++ b/crates/assists/src/handlers/infer_function_return_type.rs
@@ -0,0 +1,337 @@
1use hir::HirDisplay;
2use syntax::{ast, AstNode, TextRange, TextSize};
3use test_utils::mark;
4
5use crate::{AssistContext, AssistId, AssistKind, Assists};
6
7// Assist: infer_function_return_type
8//
9// Adds the return type to a function or closure inferred from its tail expression if it doesn't have a return
10// type specified. This assists is useable in a functions or closures tail expression or return type position.
11//
12// ```
13// fn foo() { 4<|>2i32 }
14// ```
15// ->
16// ```
17// fn foo() -> i32 { 42i32 }
18// ```
19pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
20 let (tail_expr, builder_edit_pos, wrap_expr) = extract_tail(ctx)?;
21 let module = ctx.sema.scope(tail_expr.syntax()).module()?;
22 let ty = ctx.sema.type_of_expr(&tail_expr)?;
23 if ty.is_unit() {
24 return None;
25 }
26 let ty = ty.display_source_code(ctx.db(), module.into()).ok()?;
27
28 acc.add(
29 AssistId("infer_function_return_type", AssistKind::RefactorRewrite),
30 "Add this function's return type",
31 tail_expr.syntax().text_range(),
32 |builder| {
33 match builder_edit_pos {
34 InsertOrReplace::Insert(insert_pos) => {
35 builder.insert(insert_pos, &format!("-> {} ", ty))
36 }
37 InsertOrReplace::Replace(text_range) => {
38 builder.replace(text_range, &format!("-> {}", ty))
39 }
40 }
41 if wrap_expr {
42 mark::hit!(wrap_closure_non_block_expr);
43 // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block
44 builder.replace(tail_expr.syntax().text_range(), &format!("{{{}}}", tail_expr));
45 }
46 },
47 )
48}
49
50enum InsertOrReplace {
51 Insert(TextSize),
52 Replace(TextRange),
53}
54
55/// Check the potentially already specified return type and reject it or turn it into a builder command
56/// if allowed.
57fn ret_ty_to_action(ret_ty: Option<ast::RetType>, insert_pos: TextSize) -> Option<InsertOrReplace> {
58 match ret_ty {
59 Some(ret_ty) => match ret_ty.ty() {
60 Some(ast::Type::InferType(_)) | None => {
61 mark::hit!(existing_infer_ret_type);
62 mark::hit!(existing_infer_ret_type_closure);
63 Some(InsertOrReplace::Replace(ret_ty.syntax().text_range()))
64 }
65 _ => {
66 mark::hit!(existing_ret_type);
67 mark::hit!(existing_ret_type_closure);
68 None
69 }
70 },
71 None => Some(InsertOrReplace::Insert(insert_pos + TextSize::from(1))),
72 }
73}
74
75fn extract_tail(ctx: &AssistContext) -> Option<(ast::Expr, InsertOrReplace, bool)> {
76 let (tail_expr, return_type_range, action, wrap_expr) =
77 if let Some(closure) = ctx.find_node_at_offset::<ast::ClosureExpr>() {
78 let rpipe_pos = closure.param_list()?.syntax().last_token()?.text_range().end();
79 let action = ret_ty_to_action(closure.ret_type(), rpipe_pos)?;
80
81 let body = closure.body()?;
82 let body_start = body.syntax().first_token()?.text_range().start();
83 let (tail_expr, wrap_expr) = match body {
84 ast::Expr::BlockExpr(block) => (block.expr()?, false),
85 body => (body, true),
86 };
87
88 let ret_range = TextRange::new(rpipe_pos, body_start);
89 (tail_expr, ret_range, action, wrap_expr)
90 } else {
91 let func = ctx.find_node_at_offset::<ast::Fn>()?;
92 let rparen_pos = func.param_list()?.r_paren_token()?.text_range().end();
93 let action = ret_ty_to_action(func.ret_type(), rparen_pos)?;
94
95 let body = func.body()?;
96 let tail_expr = body.expr()?;
97
98 let ret_range_end = body.l_curly_token()?.text_range().start();
99 let ret_range = TextRange::new(rparen_pos, ret_range_end);
100 (tail_expr, ret_range, action, false)
101 };
102 let frange = ctx.frange.range;
103 if return_type_range.contains_range(frange) {
104 mark::hit!(cursor_in_ret_position);
105 mark::hit!(cursor_in_ret_position_closure);
106 } else if tail_expr.syntax().text_range().contains_range(frange) {
107 mark::hit!(cursor_on_tail);
108 mark::hit!(cursor_on_tail_closure);
109 } else {
110 return None;
111 }
112 Some((tail_expr, action, wrap_expr))
113}
114
115#[cfg(test)]
116mod tests {
117 use crate::tests::{check_assist, check_assist_not_applicable};
118
119 use super::*;
120
121 #[test]
122 fn infer_return_type_specified_inferred() {
123 mark::check!(existing_infer_ret_type);
124 check_assist(
125 infer_function_return_type,
126 r#"fn foo() -> <|>_ {
127 45
128}"#,
129 r#"fn foo() -> i32 {
130 45
131}"#,
132 );
133 }
134
135 #[test]
136 fn infer_return_type_specified_inferred_closure() {
137 mark::check!(existing_infer_ret_type_closure);
138 check_assist(
139 infer_function_return_type,
140 r#"fn foo() {
141 || -> _ {<|>45};
142}"#,
143 r#"fn foo() {
144 || -> i32 {45};
145}"#,
146 );
147 }
148
149 #[test]
150 fn infer_return_type_cursor_at_return_type_pos() {
151 mark::check!(cursor_in_ret_position);
152 check_assist(
153 infer_function_return_type,
154 r#"fn foo() <|>{
155 45
156}"#,
157 r#"fn foo() -> i32 {
158 45
159}"#,
160 );
161 }
162
163 #[test]
164 fn infer_return_type_cursor_at_return_type_pos_closure() {
165 mark::check!(cursor_in_ret_position_closure);
166 check_assist(
167 infer_function_return_type,
168 r#"fn foo() {
169 || <|>45
170}"#,
171 r#"fn foo() {
172 || -> i32 {45}
173}"#,
174 );
175 }
176
177 #[test]
178 fn infer_return_type() {
179 mark::check!(cursor_on_tail);
180 check_assist(
181 infer_function_return_type,
182 r#"fn foo() {
183 45<|>
184}"#,
185 r#"fn foo() -> i32 {
186 45
187}"#,
188 );
189 }
190
191 #[test]
192 fn infer_return_type_nested() {
193 check_assist(
194 infer_function_return_type,
195 r#"fn foo() {
196 if true {
197 3<|>
198 } else {
199 5
200 }
201}"#,
202 r#"fn foo() -> i32 {
203 if true {
204 3
205 } else {
206 5
207 }
208}"#,
209 );
210 }
211
212 #[test]
213 fn not_applicable_ret_type_specified() {
214 mark::check!(existing_ret_type);
215 check_assist_not_applicable(
216 infer_function_return_type,
217 r#"fn foo() -> i32 {
218 ( 45<|> + 32 ) * 123
219}"#,
220 );
221 }
222
223 #[test]
224 fn not_applicable_non_tail_expr() {
225 check_assist_not_applicable(
226 infer_function_return_type,
227 r#"fn foo() {
228 let x = <|>3;
229 ( 45 + 32 ) * 123
230}"#,
231 );
232 }
233
234 #[test]
235 fn not_applicable_unit_return_type() {
236 check_assist_not_applicable(
237 infer_function_return_type,
238 r#"fn foo() {
239 (<|>)
240}"#,
241 );
242 }
243
244 #[test]
245 fn infer_return_type_closure_block() {
246 mark::check!(cursor_on_tail_closure);
247 check_assist(
248 infer_function_return_type,
249 r#"fn foo() {
250 |x: i32| {
251 x<|>
252 };
253}"#,
254 r#"fn foo() {
255 |x: i32| -> i32 {
256 x
257 };
258}"#,
259 );
260 }
261
262 #[test]
263 fn infer_return_type_closure() {
264 check_assist(
265 infer_function_return_type,
266 r#"fn foo() {
267 |x: i32| { x<|> };
268}"#,
269 r#"fn foo() {
270 |x: i32| -> i32 { x };
271}"#,
272 );
273 }
274
275 #[test]
276 fn infer_return_type_closure_wrap() {
277 mark::check!(wrap_closure_non_block_expr);
278 check_assist(
279 infer_function_return_type,
280 r#"fn foo() {
281 |x: i32| x<|>;
282}"#,
283 r#"fn foo() {
284 |x: i32| -> i32 {x};
285}"#,
286 );
287 }
288
289 #[test]
290 fn infer_return_type_nested_closure() {
291 check_assist(
292 infer_function_return_type,
293 r#"fn foo() {
294 || {
295 if true {
296 3<|>
297 } else {
298 5
299 }
300 }
301}"#,
302 r#"fn foo() {
303 || -> i32 {
304 if true {
305 3
306 } else {
307 5
308 }
309 }
310}"#,
311 );
312 }
313
314 #[test]
315 fn not_applicable_ret_type_specified_closure() {
316 mark::check!(existing_ret_type_closure);
317 check_assist_not_applicable(
318 infer_function_return_type,
319 r#"fn foo() {
320 || -> i32 { 3<|> }
321}"#,
322 );
323 }
324
325 #[test]
326 fn not_applicable_non_tail_expr_closure() {
327 check_assist_not_applicable(
328 infer_function_return_type,
329 r#"fn foo() {
330 || -> i32 {
331 let x = 3<|>;
332 6
333 }
334}"#,
335 );
336 }
337}
diff --git a/crates/assists/src/handlers/introduce_named_lifetime.rs b/crates/assists/src/handlers/introduce_named_lifetime.rs
index 5f623e5f7..4cc8dae65 100644
--- a/crates/assists/src/handlers/introduce_named_lifetime.rs
+++ b/crates/assists/src/handlers/introduce_named_lifetime.rs
@@ -36,7 +36,7 @@ static ASSIST_LABEL: &str = "Introduce named lifetime";
36// FIXME: should also add support for the case fun(f: &Foo) -> &<|>Foo 36// FIXME: should also add support for the case fun(f: &Foo) -> &<|>Foo
37pub(crate) fn introduce_named_lifetime(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 37pub(crate) fn introduce_named_lifetime(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
38 let lifetime_token = ctx 38 let lifetime_token = ctx
39 .find_token_at_offset(SyntaxKind::LIFETIME) 39 .find_token_syntax_at_offset(SyntaxKind::LIFETIME)
40 .filter(|lifetime| lifetime.text() == "'_")?; 40 .filter(|lifetime| lifetime.text() == "'_")?;
41 if let Some(fn_def) = lifetime_token.ancestors().find_map(ast::Fn::cast) { 41 if let Some(fn_def) = lifetime_token.ancestors().find_map(ast::Fn::cast) {
42 generate_fn_def_assist(acc, &fn_def, lifetime_token.text_range()) 42 generate_fn_def_assist(acc, &fn_def, lifetime_token.text_range())
diff --git a/crates/assists/src/handlers/invert_if.rs b/crates/assists/src/handlers/invert_if.rs
index 461fcf862..ea722b91b 100644
--- a/crates/assists/src/handlers/invert_if.rs
+++ b/crates/assists/src/handlers/invert_if.rs
@@ -29,7 +29,7 @@ use crate::{
29// ``` 29// ```
30 30
31pub(crate) fn invert_if(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 31pub(crate) fn invert_if(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
32 let if_keyword = ctx.find_token_at_offset(T![if])?; 32 let if_keyword = ctx.find_token_syntax_at_offset(T![if])?;
33 let expr = ast::IfExpr::cast(if_keyword.parent())?; 33 let expr = ast::IfExpr::cast(if_keyword.parent())?;
34 let if_range = if_keyword.text_range(); 34 let if_range = if_keyword.text_range();
35 let cursor_in_range = if_range.contains_range(ctx.frange.range); 35 let cursor_in_range = if_range.contains_range(ctx.frange.range);
diff --git a/crates/assists/src/handlers/raw_string.rs b/crates/assists/src/handlers/raw_string.rs
index 9ddd116e0..4c759cc25 100644
--- a/crates/assists/src/handlers/raw_string.rs
+++ b/crates/assists/src/handlers/raw_string.rs
@@ -1,11 +1,6 @@
1use std::borrow::Cow; 1use std::borrow::Cow;
2 2
3use syntax::{ 3use syntax::{ast, AstToken, TextRange, TextSize};
4 ast::{self, HasQuotes, HasStringValue},
5 AstToken,
6 SyntaxKind::{RAW_STRING, STRING},
7 TextRange, TextSize,
8};
9use test_utils::mark; 4use test_utils::mark;
10 5
11use crate::{AssistContext, AssistId, AssistKind, Assists}; 6use crate::{AssistContext, AssistId, AssistKind, Assists};
@@ -26,7 +21,10 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
26// } 21// }
27// ``` 22// ```
28pub(crate) fn make_raw_string(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 23pub(crate) fn make_raw_string(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
29 let token = ctx.find_token_at_offset(STRING).and_then(ast::String::cast)?; 24 let token = ctx.find_token_at_offset::<ast::String>()?;
25 if token.is_raw() {
26 return None;
27 }
30 let value = token.value()?; 28 let value = token.value()?;
31 let target = token.syntax().text_range(); 29 let target = token.syntax().text_range();
32 acc.add( 30 acc.add(
@@ -65,7 +63,10 @@ pub(crate) fn make_raw_string(acc: &mut Assists, ctx: &AssistContext) -> Option<
65// } 63// }
66// ``` 64// ```
67pub(crate) fn make_usual_string(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 65pub(crate) fn make_usual_string(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
68 let token = ctx.find_token_at_offset(RAW_STRING).and_then(ast::RawString::cast)?; 66 let token = ctx.find_token_at_offset::<ast::String>()?;
67 if !token.is_raw() {
68 return None;
69 }
69 let value = token.value()?; 70 let value = token.value()?;
70 let target = token.syntax().text_range(); 71 let target = token.syntax().text_range();
71 acc.add( 72 acc.add(
@@ -104,11 +105,15 @@ pub(crate) fn make_usual_string(acc: &mut Assists, ctx: &AssistContext) -> Optio
104// } 105// }
105// ``` 106// ```
106pub(crate) fn add_hash(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 107pub(crate) fn add_hash(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
107 let token = ctx.find_token_at_offset(RAW_STRING)?; 108 let token = ctx.find_token_at_offset::<ast::String>()?;
108 let target = token.text_range(); 109 if !token.is_raw() {
110 return None;
111 }
112 let text_range = token.syntax().text_range();
113 let target = text_range;
109 acc.add(AssistId("add_hash", AssistKind::Refactor), "Add #", target, |edit| { 114 acc.add(AssistId("add_hash", AssistKind::Refactor), "Add #", target, |edit| {
110 edit.insert(token.text_range().start() + TextSize::of('r'), "#"); 115 edit.insert(text_range.start() + TextSize::of('r'), "#");
111 edit.insert(token.text_range().end(), "#"); 116 edit.insert(text_range.end(), "#");
112 }) 117 })
113} 118}
114 119
@@ -128,7 +133,10 @@ pub(crate) fn add_hash(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
128// } 133// }
129// ``` 134// ```
130pub(crate) fn remove_hash(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 135pub(crate) fn remove_hash(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
131 let token = ctx.find_token_at_offset(RAW_STRING).and_then(ast::RawString::cast)?; 136 let token = ctx.find_token_at_offset::<ast::String>()?;
137 if !token.is_raw() {
138 return None;
139 }
132 140
133 let text = token.text().as_str(); 141 let text = token.text().as_str();
134 if !text.starts_with("r#") && text.ends_with('#') { 142 if !text.starts_with("r#") && text.ends_with('#') {
diff --git a/crates/assists/src/handlers/remove_mut.rs b/crates/assists/src/handlers/remove_mut.rs
index 44f41daa9..575b271f7 100644
--- a/crates/assists/src/handlers/remove_mut.rs
+++ b/crates/assists/src/handlers/remove_mut.rs
@@ -18,7 +18,7 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
18// } 18// }
19// ``` 19// ```
20pub(crate) fn remove_mut(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 20pub(crate) fn remove_mut(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
21 let mut_token = ctx.find_token_at_offset(T![mut])?; 21 let mut_token = ctx.find_token_syntax_at_offset(T![mut])?;
22 let delete_from = mut_token.text_range().start(); 22 let delete_from = mut_token.text_range().start();
23 let delete_to = match mut_token.next_token() { 23 let delete_to = match mut_token.next_token() {
24 Some(it) if it.kind() == SyntaxKind::WHITESPACE => it.text_range().end(), 24 Some(it) if it.kind() == SyntaxKind::WHITESPACE => it.text_range().end(),
diff --git a/crates/assists/src/handlers/reorder_fields.rs b/crates/assists/src/handlers/reorder_fields.rs
index 527f457a7..7c0f0f44e 100644
--- a/crates/assists/src/handlers/reorder_fields.rs
+++ b/crates/assists/src/handlers/reorder_fields.rs
@@ -47,9 +47,11 @@ fn reorder<R: AstNode>(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
47 "Reorder record fields", 47 "Reorder record fields",
48 target, 48 target,
49 |edit| { 49 |edit| {
50 let mut rewriter = algo::SyntaxRewriter::default();
50 for (old, new) in fields.iter().zip(&sorted_fields) { 51 for (old, new) in fields.iter().zip(&sorted_fields) {
51 algo::diff(old, new).into_text_edit(edit.text_edit_builder()); 52 rewriter.replace(old, new);
52 } 53 }
54 edit.rewrite(rewriter);
53 }, 55 },
54 ) 56 )
55} 57}
diff --git a/crates/assists/src/handlers/replace_derive_with_manual_impl.rs b/crates/assists/src/handlers/replace_derive_with_manual_impl.rs
new file mode 100644
index 000000000..82625516c
--- /dev/null
+++ b/crates/assists/src/handlers/replace_derive_with_manual_impl.rs
@@ -0,0 +1,398 @@
1use ide_db::imports_locator;
2use itertools::Itertools;
3use syntax::{
4 ast::{self, make, AstNode},
5 Direction, SmolStr,
6 SyntaxKind::{IDENT, WHITESPACE},
7 TextSize,
8};
9
10use crate::{
11 assist_context::{AssistBuilder, AssistContext, Assists},
12 utils::{
13 add_trait_assoc_items_to_impl, filter_assoc_items, mod_path_to_ast, render_snippet, Cursor,
14 DefaultMethods,
15 },
16 AssistId, AssistKind,
17};
18
19// Assist: replace_derive_with_manual_impl
20//
21// Converts a `derive` impl into a manual one.
22//
23// ```
24// # trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; }
25// #[derive(Deb<|>ug, Display)]
26// struct S;
27// ```
28// ->
29// ```
30// # trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; }
31// #[derive(Display)]
32// struct S;
33//
34// impl Debug for S {
35// fn fmt(&self, f: &mut Formatter) -> Result<()> {
36// ${0:todo!()}
37// }
38// }
39// ```
40pub(crate) fn replace_derive_with_manual_impl(
41 acc: &mut Assists,
42 ctx: &AssistContext,
43) -> Option<()> {
44 let attr = ctx.find_node_at_offset::<ast::Attr>()?;
45
46 let attr_name = attr
47 .syntax()
48 .descendants_with_tokens()
49 .filter(|t| t.kind() == IDENT)
50 .find_map(syntax::NodeOrToken::into_token)
51 .filter(|t| t.text() == "derive")?
52 .text()
53 .clone();
54
55 let trait_token =
56 ctx.token_at_offset().find(|t| t.kind() == IDENT && *t.text() != attr_name)?;
57 let trait_path = make::path_unqualified(make::path_segment(make::name_ref(trait_token.text())));
58
59 let annotated_name = attr.syntax().siblings(Direction::Next).find_map(ast::Name::cast)?;
60 let insert_pos = annotated_name.syntax().parent()?.text_range().end();
61
62 let current_module = ctx.sema.scope(annotated_name.syntax()).module()?;
63 let current_crate = current_module.krate();
64
65 let found_traits = imports_locator::find_imports(&ctx.sema, current_crate, trait_token.text())
66 .into_iter()
67 .filter_map(|candidate: either::Either<hir::ModuleDef, hir::MacroDef>| match candidate {
68 either::Either::Left(hir::ModuleDef::Trait(trait_)) => Some(trait_),
69 _ => None,
70 })
71 .flat_map(|trait_| {
72 current_module
73 .find_use_path(ctx.sema.db, hir::ModuleDef::Trait(trait_))
74 .as_ref()
75 .map(mod_path_to_ast)
76 .zip(Some(trait_))
77 });
78
79 let mut no_traits_found = true;
80 for (trait_path, trait_) in found_traits.inspect(|_| no_traits_found = false) {
81 add_assist(acc, ctx, &attr, &trait_path, Some(trait_), &annotated_name, insert_pos)?;
82 }
83 if no_traits_found {
84 add_assist(acc, ctx, &attr, &trait_path, None, &annotated_name, insert_pos)?;
85 }
86 Some(())
87}
88
89fn add_assist(
90 acc: &mut Assists,
91 ctx: &AssistContext,
92 attr: &ast::Attr,
93 trait_path: &ast::Path,
94 trait_: Option<hir::Trait>,
95 annotated_name: &ast::Name,
96 insert_pos: TextSize,
97) -> Option<()> {
98 let target = attr.syntax().text_range();
99 let input = attr.token_tree()?;
100 let label = format!("Convert to manual `impl {} for {}`", trait_path, annotated_name);
101 let trait_name = trait_path.segment().and_then(|seg| seg.name_ref())?;
102
103 acc.add(
104 AssistId("replace_derive_with_manual_impl", AssistKind::Refactor),
105 label,
106 target,
107 |builder| {
108 let impl_def_with_items =
109 impl_def_from_trait(&ctx.sema, annotated_name, trait_, trait_path);
110 update_attribute(builder, &input, &trait_name, &attr);
111 match (ctx.config.snippet_cap, impl_def_with_items) {
112 (None, _) => builder.insert(
113 insert_pos,
114 format!("\n\nimpl {} for {} {{\n\n}}", trait_path, annotated_name),
115 ),
116 (Some(cap), None) => builder.insert_snippet(
117 cap,
118 insert_pos,
119 format!("\n\nimpl {} for {} {{\n $0\n}}", trait_path, annotated_name),
120 ),
121 (Some(cap), Some((impl_def, first_assoc_item))) => {
122 let mut cursor = Cursor::Before(first_assoc_item.syntax());
123 let placeholder;
124 if let ast::AssocItem::Fn(ref func) = first_assoc_item {
125 if let Some(m) = func.syntax().descendants().find_map(ast::MacroCall::cast)
126 {
127 if m.syntax().text() == "todo!()" {
128 placeholder = m;
129 cursor = Cursor::Replace(placeholder.syntax());
130 }
131 }
132 }
133
134 builder.insert_snippet(
135 cap,
136 insert_pos,
137 format!("\n\n{}", render_snippet(cap, impl_def.syntax(), cursor)),
138 )
139 }
140 };
141 },
142 )
143}
144
145fn impl_def_from_trait(
146 sema: &hir::Semantics<ide_db::RootDatabase>,
147 annotated_name: &ast::Name,
148 trait_: Option<hir::Trait>,
149 trait_path: &ast::Path,
150) -> Option<(ast::Impl, ast::AssocItem)> {
151 let trait_ = trait_?;
152 let target_scope = sema.scope(annotated_name.syntax());
153 let trait_items = filter_assoc_items(sema.db, &trait_.items(sema.db), DefaultMethods::No);
154 if trait_items.is_empty() {
155 return None;
156 }
157 let impl_def = make::impl_trait(
158 trait_path.clone(),
159 make::path_unqualified(make::path_segment(make::name_ref(annotated_name.text()))),
160 );
161 let (impl_def, first_assoc_item) =
162 add_trait_assoc_items_to_impl(sema, trait_items, trait_, impl_def, target_scope);
163 Some((impl_def, first_assoc_item))
164}
165
166fn update_attribute(
167 builder: &mut AssistBuilder,
168 input: &ast::TokenTree,
169 trait_name: &ast::NameRef,
170 attr: &ast::Attr,
171) {
172 let new_attr_input = input
173 .syntax()
174 .descendants_with_tokens()
175 .filter(|t| t.kind() == IDENT)
176 .filter_map(|t| t.into_token().map(|t| t.text().clone()))
177 .filter(|t| t != trait_name.text())
178 .collect::<Vec<SmolStr>>();
179 let has_more_derives = !new_attr_input.is_empty();
180
181 if has_more_derives {
182 let new_attr_input = format!("({})", new_attr_input.iter().format(", "));
183 builder.replace(input.syntax().text_range(), new_attr_input);
184 } else {
185 let attr_range = attr.syntax().text_range();
186 builder.delete(attr_range);
187
188 if let Some(line_break_range) = attr
189 .syntax()
190 .next_sibling_or_token()
191 .filter(|t| t.kind() == WHITESPACE)
192 .map(|t| t.text_range())
193 {
194 builder.delete(line_break_range);
195 }
196 }
197}
198
199#[cfg(test)]
200mod tests {
201 use crate::tests::{check_assist, check_assist_not_applicable};
202
203 use super::*;
204
205 #[test]
206 fn add_custom_impl_debug() {
207 check_assist(
208 replace_derive_with_manual_impl,
209 "
210mod fmt {
211 pub struct Error;
212 pub type Result = Result<(), Error>;
213 pub struct Formatter<'a>;
214 pub trait Debug {
215 fn fmt(&self, f: &mut Formatter<'_>) -> Result;
216 }
217}
218
219#[derive(Debu<|>g)]
220struct Foo {
221 bar: String,
222}
223",
224 "
225mod fmt {
226 pub struct Error;
227 pub type Result = Result<(), Error>;
228 pub struct Formatter<'a>;
229 pub trait Debug {
230 fn fmt(&self, f: &mut Formatter<'_>) -> Result;
231 }
232}
233
234struct Foo {
235 bar: String,
236}
237
238impl fmt::Debug for Foo {
239 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
240 ${0:todo!()}
241 }
242}
243",
244 )
245 }
246 #[test]
247 fn add_custom_impl_all() {
248 check_assist(
249 replace_derive_with_manual_impl,
250 "
251mod foo {
252 pub trait Bar {
253 type Qux;
254 const Baz: usize = 42;
255 const Fez: usize;
256 fn foo();
257 fn bar() {}
258 }
259}
260
261#[derive(<|>Bar)]
262struct Foo {
263 bar: String,
264}
265",
266 "
267mod foo {
268 pub trait Bar {
269 type Qux;
270 const Baz: usize = 42;
271 const Fez: usize;
272 fn foo();
273 fn bar() {}
274 }
275}
276
277struct Foo {
278 bar: String,
279}
280
281impl foo::Bar for Foo {
282 $0type Qux;
283
284 const Baz: usize = 42;
285
286 const Fez: usize;
287
288 fn foo() {
289 todo!()
290 }
291}
292",
293 )
294 }
295 #[test]
296 fn add_custom_impl_for_unique_input() {
297 check_assist(
298 replace_derive_with_manual_impl,
299 "
300#[derive(Debu<|>g)]
301struct Foo {
302 bar: String,
303}
304 ",
305 "
306struct Foo {
307 bar: String,
308}
309
310impl Debug for Foo {
311 $0
312}
313 ",
314 )
315 }
316
317 #[test]
318 fn add_custom_impl_for_with_visibility_modifier() {
319 check_assist(
320 replace_derive_with_manual_impl,
321 "
322#[derive(Debug<|>)]
323pub struct Foo {
324 bar: String,
325}
326 ",
327 "
328pub struct Foo {
329 bar: String,
330}
331
332impl Debug for Foo {
333 $0
334}
335 ",
336 )
337 }
338
339 #[test]
340 fn add_custom_impl_when_multiple_inputs() {
341 check_assist(
342 replace_derive_with_manual_impl,
343 "
344#[derive(Display, Debug<|>, Serialize)]
345struct Foo {}
346 ",
347 "
348#[derive(Display, Serialize)]
349struct Foo {}
350
351impl Debug for Foo {
352 $0
353}
354 ",
355 )
356 }
357
358 #[test]
359 fn test_ignore_derive_macro_without_input() {
360 check_assist_not_applicable(
361 replace_derive_with_manual_impl,
362 "
363#[derive(<|>)]
364struct Foo {}
365 ",
366 )
367 }
368
369 #[test]
370 fn test_ignore_if_cursor_on_param() {
371 check_assist_not_applicable(
372 replace_derive_with_manual_impl,
373 "
374#[derive<|>(Debug)]
375struct Foo {}
376 ",
377 );
378
379 check_assist_not_applicable(
380 replace_derive_with_manual_impl,
381 "
382#[derive(Debug)<|>]
383struct Foo {}
384 ",
385 )
386 }
387
388 #[test]
389 fn test_ignore_if_not_derive() {
390 check_assist_not_applicable(
391 replace_derive_with_manual_impl,
392 "
393#[allow(non_camel_<|>case_types)]
394struct Foo {}
395 ",
396 )
397 }
398}
diff --git a/crates/assists/src/handlers/replace_let_with_if_let.rs b/crates/assists/src/handlers/replace_let_with_if_let.rs
index a5bcbda24..69d3b08d3 100644
--- a/crates/assists/src/handlers/replace_let_with_if_let.rs
+++ b/crates/assists/src/handlers/replace_let_with_if_let.rs
@@ -37,7 +37,7 @@ use ide_db::ty_filter::TryEnum;
37// fn compute() -> Option<i32> { None } 37// fn compute() -> Option<i32> { None }
38// ``` 38// ```
39pub(crate) fn replace_let_with_if_let(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 39pub(crate) fn replace_let_with_if_let(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
40 let let_kw = ctx.find_token_at_offset(T![let])?; 40 let let_kw = ctx.find_token_syntax_at_offset(T![let])?;
41 let let_stmt = let_kw.ancestors().find_map(ast::LetStmt::cast)?; 41 let let_stmt = let_kw.ancestors().find_map(ast::LetStmt::cast)?;
42 let init = let_stmt.initializer()?; 42 let init = let_stmt.initializer()?;
43 let original_pat = let_stmt.pat()?; 43 let original_pat = let_stmt.pat()?;
diff --git a/crates/assists/src/handlers/replace_string_with_char.rs b/crates/assists/src/handlers/replace_string_with_char.rs
index 4ca87a8ec..b4b898846 100644
--- a/crates/assists/src/handlers/replace_string_with_char.rs
+++ b/crates/assists/src/handlers/replace_string_with_char.rs
@@ -1,8 +1,4 @@
1use syntax::{ 1use syntax::{ast, AstToken, SyntaxKind::STRING};
2 ast::{self, HasStringValue},
3 AstToken,
4 SyntaxKind::STRING,
5};
6 2
7use crate::{AssistContext, AssistId, AssistKind, Assists}; 3use crate::{AssistContext, AssistId, AssistKind, Assists};
8 4
@@ -22,7 +18,7 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
22// } 18// }
23// ``` 19// ```
24pub(crate) fn replace_string_with_char(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 20pub(crate) fn replace_string_with_char(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
25 let token = ctx.find_token_at_offset(STRING).and_then(ast::String::cast)?; 21 let token = ctx.find_token_syntax_at_offset(STRING).and_then(ast::String::cast)?;
26 let value = token.value()?; 22 let value = token.value()?;
27 let target = token.syntax().text_range(); 23 let target = token.syntax().text_range();
28 24
diff --git a/crates/assists/src/handlers/split_import.rs b/crates/assists/src/handlers/split_import.rs
index 15e67eaa1..ef1f6b8a1 100644
--- a/crates/assists/src/handlers/split_import.rs
+++ b/crates/assists/src/handlers/split_import.rs
@@ -16,7 +16,7 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
16// use std::{collections::HashMap}; 16// use std::{collections::HashMap};
17// ``` 17// ```
18pub(crate) fn split_import(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { 18pub(crate) fn split_import(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
19 let colon_colon = ctx.find_token_at_offset(T![::])?; 19 let colon_colon = ctx.find_token_syntax_at_offset(T![::])?;
20 let path = ast::Path::cast(colon_colon.parent())?.qualifier()?; 20 let path = ast::Path::cast(colon_colon.parent())?.qualifier()?;
21 let top_path = successors(Some(path.clone()), |it| it.parent_path()).last()?; 21 let top_path = successors(Some(path.clone()), |it| it.parent_path()).last()?;
22 22
diff --git a/crates/assists/src/handlers/unwrap_block.rs b/crates/assists/src/handlers/unwrap_block.rs
index 3851aeb3e..36ef871b9 100644
--- a/crates/assists/src/handlers/unwrap_block.rs
+++ b/crates/assists/src/handlers/unwrap_block.rs
@@ -29,7 +29,7 @@ pub(crate) fn unwrap_block(acc: &mut Assists, ctx: &AssistContext) -> Option<()>
29 let assist_id = AssistId("unwrap_block", AssistKind::RefactorRewrite); 29 let assist_id = AssistId("unwrap_block", AssistKind::RefactorRewrite);
30 let assist_label = "Unwrap block"; 30 let assist_label = "Unwrap block";
31 31
32 let l_curly_token = ctx.find_token_at_offset(T!['{'])?; 32 let l_curly_token = ctx.find_token_syntax_at_offset(T!['{'])?;
33 let mut block = ast::BlockExpr::cast(l_curly_token.parent())?; 33 let mut block = ast::BlockExpr::cast(l_curly_token.parent())?;
34 let mut parent = block.syntax().parent()?; 34 let mut parent = block.syntax().parent()?;
35 if ast::MatchArm::can_cast(parent.kind()) { 35 if ast::MatchArm::can_cast(parent.kind()) {
diff --git a/crates/assists/src/handlers/wrap_return_type_in_result.rs b/crates/assists/src/handlers/wrap_return_type_in_result.rs
new file mode 100644
index 000000000..59e5debb1
--- /dev/null
+++ b/crates/assists/src/handlers/wrap_return_type_in_result.rs
@@ -0,0 +1,1158 @@
1use std::iter;
2
3use syntax::{
4 ast::{self, make, BlockExpr, Expr, LoopBodyOwner},
5 match_ast, AstNode, SyntaxNode,
6};
7use test_utils::mark;
8
9use crate::{AssistContext, AssistId, AssistKind, Assists};
10
11// Assist: wrap_return_type_in_result
12//
13// Wrap the function's return type into Result.
14//
15// ```
16// fn foo() -> i32<|> { 42i32 }
17// ```
18// ->
19// ```
20// fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }
21// ```
22pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
23 let ret_type = ctx.find_node_at_offset::<ast::RetType>()?;
24 let parent = ret_type.syntax().parent()?;
25 let block_expr = match_ast! {
26 match parent {
27 ast::Fn(func) => func.body()?,
28 ast::ClosureExpr(closure) => match closure.body()? {
29 Expr::BlockExpr(block) => block,
30 // closures require a block when a return type is specified
31 _ => return None,
32 },
33 _ => return None,
34 }
35 };
36
37 let type_ref = &ret_type.ty()?;
38 let ret_type_str = type_ref.syntax().text().to_string();
39 let first_part_ret_type = ret_type_str.splitn(2, '<').next();
40 if let Some(ret_type_first_part) = first_part_ret_type {
41 if ret_type_first_part.ends_with("Result") {
42 mark::hit!(wrap_return_type_in_result_simple_return_type_already_result);
43 return None;
44 }
45 }
46
47 acc.add(
48 AssistId("wrap_return_type_in_result", AssistKind::RefactorRewrite),
49 "Wrap return type in Result",
50 type_ref.syntax().text_range(),
51 |builder| {
52 let mut tail_return_expr_collector = TailReturnCollector::new();
53 tail_return_expr_collector.collect_jump_exprs(&block_expr, false);
54 tail_return_expr_collector.collect_tail_exprs(&block_expr);
55
56 for ret_expr_arg in tail_return_expr_collector.exprs_to_wrap {
57 let ok_wrapped = make::expr_call(
58 make::expr_path(make::path_unqualified(make::path_segment(make::name_ref(
59 "Ok",
60 )))),
61 make::arg_list(iter::once(ret_expr_arg.clone())),
62 );
63 builder.replace_ast(ret_expr_arg, ok_wrapped);
64 }
65
66 match ctx.config.snippet_cap {
67 Some(cap) => {
68 let snippet = format!("Result<{}, ${{0:_}}>", type_ref);
69 builder.replace_snippet(cap, type_ref.syntax().text_range(), snippet)
70 }
71 None => builder
72 .replace(type_ref.syntax().text_range(), format!("Result<{}, _>", type_ref)),
73 }
74 },
75 )
76}
77
78struct TailReturnCollector {
79 exprs_to_wrap: Vec<ast::Expr>,
80}
81
82impl TailReturnCollector {
83 fn new() -> Self {
84 Self { exprs_to_wrap: vec![] }
85 }
86 /// Collect all`return` expression
87 fn collect_jump_exprs(&mut self, block_expr: &BlockExpr, collect_break: bool) {
88 let statements = block_expr.statements();
89 for stmt in statements {
90 let expr = match &stmt {
91 ast::Stmt::ExprStmt(stmt) => stmt.expr(),
92 ast::Stmt::LetStmt(stmt) => stmt.initializer(),
93 ast::Stmt::Item(_) => continue,
94 };
95 if let Some(expr) = &expr {
96 self.handle_exprs(expr, collect_break);
97 }
98 }
99
100 // Browse tail expressions for each block
101 if let Some(expr) = block_expr.expr() {
102 if let Some(last_exprs) = get_tail_expr_from_block(&expr) {
103 for last_expr in last_exprs {
104 let last_expr = match last_expr {
105 NodeType::Node(expr) => expr,
106 NodeType::Leaf(expr) => expr.syntax().clone(),
107 };
108
109 if let Some(last_expr) = Expr::cast(last_expr.clone()) {
110 self.handle_exprs(&last_expr, collect_break);
111 } else if let Some(expr_stmt) = ast::Stmt::cast(last_expr) {
112 let expr_stmt = match &expr_stmt {
113 ast::Stmt::ExprStmt(stmt) => stmt.expr(),
114 ast::Stmt::LetStmt(stmt) => stmt.initializer(),
115 ast::Stmt::Item(_) => None,
116 };
117 if let Some(expr) = &expr_stmt {
118 self.handle_exprs(expr, collect_break);
119 }
120 }
121 }
122 }
123 }
124 }
125
126 fn handle_exprs(&mut self, expr: &Expr, collect_break: bool) {
127 match expr {
128 Expr::BlockExpr(block_expr) => {
129 self.collect_jump_exprs(&block_expr, collect_break);
130 }
131 Expr::ReturnExpr(ret_expr) => {
132 if let Some(ret_expr_arg) = &ret_expr.expr() {
133 self.exprs_to_wrap.push(ret_expr_arg.clone());
134 }
135 }
136 Expr::BreakExpr(break_expr) if collect_break => {
137 if let Some(break_expr_arg) = &break_expr.expr() {
138 self.exprs_to_wrap.push(break_expr_arg.clone());
139 }
140 }
141 Expr::IfExpr(if_expr) => {
142 for block in if_expr.blocks() {
143 self.collect_jump_exprs(&block, collect_break);
144 }
145 }
146 Expr::LoopExpr(loop_expr) => {
147 if let Some(block_expr) = loop_expr.loop_body() {
148 self.collect_jump_exprs(&block_expr, collect_break);
149 }
150 }
151 Expr::ForExpr(for_expr) => {
152 if let Some(block_expr) = for_expr.loop_body() {
153 self.collect_jump_exprs(&block_expr, collect_break);
154 }
155 }
156 Expr::WhileExpr(while_expr) => {
157 if let Some(block_expr) = while_expr.loop_body() {
158 self.collect_jump_exprs(&block_expr, collect_break);
159 }
160 }
161 Expr::MatchExpr(match_expr) => {
162 if let Some(arm_list) = match_expr.match_arm_list() {
163 arm_list.arms().filter_map(|match_arm| match_arm.expr()).for_each(|expr| {
164 self.handle_exprs(&expr, collect_break);
165 });
166 }
167 }
168 _ => {}
169 }
170 }
171
172 fn collect_tail_exprs(&mut self, block: &BlockExpr) {
173 if let Some(expr) = block.expr() {
174 self.handle_exprs(&expr, true);
175 self.fetch_tail_exprs(&expr);
176 }
177 }
178
179 fn fetch_tail_exprs(&mut self, expr: &Expr) {
180 if let Some(exprs) = get_tail_expr_from_block(expr) {
181 for node_type in &exprs {
182 match node_type {
183 NodeType::Leaf(expr) => {
184 self.exprs_to_wrap.push(expr.clone());
185 }
186 NodeType::Node(expr) => {
187 if let Some(last_expr) = Expr::cast(expr.clone()) {
188 self.fetch_tail_exprs(&last_expr);
189 }
190 }
191 }
192 }
193 }
194 }
195}
196
197#[derive(Debug)]
198enum NodeType {
199 Leaf(ast::Expr),
200 Node(SyntaxNode),
201}
202
203/// Get a tail expression inside a block
204fn get_tail_expr_from_block(expr: &Expr) -> Option<Vec<NodeType>> {
205 match expr {
206 Expr::IfExpr(if_expr) => {
207 let mut nodes = vec![];
208 for block in if_expr.blocks() {
209 if let Some(block_expr) = block.expr() {
210 if let Some(tail_exprs) = get_tail_expr_from_block(&block_expr) {
211 nodes.extend(tail_exprs);
212 }
213 } else if let Some(last_expr) = block.syntax().last_child() {
214 nodes.push(NodeType::Node(last_expr));
215 } else {
216 nodes.push(NodeType::Node(block.syntax().clone()));
217 }
218 }
219 Some(nodes)
220 }
221 Expr::LoopExpr(loop_expr) => {
222 loop_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
223 }
224 Expr::ForExpr(for_expr) => {
225 for_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
226 }
227 Expr::WhileExpr(while_expr) => {
228 while_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
229 }
230 Expr::BlockExpr(block_expr) => {
231 block_expr.expr().map(|lc| vec![NodeType::Node(lc.syntax().clone())])
232 }
233 Expr::MatchExpr(match_expr) => {
234 let arm_list = match_expr.match_arm_list()?;
235 let arms: Vec<NodeType> = arm_list
236 .arms()
237 .filter_map(|match_arm| match_arm.expr())
238 .map(|expr| match expr {
239 Expr::ReturnExpr(ret_expr) => NodeType::Node(ret_expr.syntax().clone()),
240 Expr::BreakExpr(break_expr) => NodeType::Node(break_expr.syntax().clone()),
241 _ => match expr.syntax().last_child() {
242 Some(last_expr) => NodeType::Node(last_expr),
243 None => NodeType::Node(expr.syntax().clone()),
244 },
245 })
246 .collect();
247
248 Some(arms)
249 }
250 Expr::BreakExpr(expr) => expr.expr().map(|e| vec![NodeType::Leaf(e)]),
251 Expr::ReturnExpr(ret_expr) => Some(vec![NodeType::Node(ret_expr.syntax().clone())]),
252
253 Expr::CallExpr(_)
254 | Expr::Literal(_)
255 | Expr::TupleExpr(_)
256 | Expr::ArrayExpr(_)
257 | Expr::ParenExpr(_)
258 | Expr::PathExpr(_)
259 | Expr::RecordExpr(_)
260 | Expr::IndexExpr(_)
261 | Expr::MethodCallExpr(_)
262 | Expr::AwaitExpr(_)
263 | Expr::CastExpr(_)
264 | Expr::RefExpr(_)
265 | Expr::PrefixExpr(_)
266 | Expr::RangeExpr(_)
267 | Expr::BinExpr(_)
268 | Expr::MacroCall(_)
269 | Expr::BoxExpr(_) => Some(vec![NodeType::Leaf(expr.clone())]),
270 _ => None,
271 }
272}
273
274#[cfg(test)]
275mod tests {
276 use crate::tests::{check_assist, check_assist_not_applicable};
277
278 use super::*;
279
280 #[test]
281 fn wrap_return_type_in_result_simple() {
282 check_assist(
283 wrap_return_type_in_result,
284 r#"
285fn foo() -> i3<|>2 {
286 let test = "test";
287 return 42i32;
288}
289"#,
290 r#"
291fn foo() -> Result<i32, ${0:_}> {
292 let test = "test";
293 return Ok(42i32);
294}
295"#,
296 );
297 }
298
299 #[test]
300 fn wrap_return_type_in_result_simple_closure() {
301 check_assist(
302 wrap_return_type_in_result,
303 r#"
304fn foo() {
305 || -> i32<|> {
306 let test = "test";
307 return 42i32;
308 };
309}
310"#,
311 r#"
312fn foo() {
313 || -> Result<i32, ${0:_}> {
314 let test = "test";
315 return Ok(42i32);
316 };
317}
318"#,
319 );
320 }
321
322 #[test]
323 fn wrap_return_type_in_result_simple_return_type_bad_cursor() {
324 check_assist_not_applicable(
325 wrap_return_type_in_result,
326 r#"
327fn foo() -> i32 {
328 let test = "test";<|>
329 return 42i32;
330}
331"#,
332 );
333 }
334
335 #[test]
336 fn wrap_return_type_in_result_simple_return_type_bad_cursor_closure() {
337 check_assist_not_applicable(
338 wrap_return_type_in_result,
339 r#"
340fn foo() {
341 || -> i32 {
342 let test = "test";<|>
343 return 42i32;
344 };
345}
346"#,
347 );
348 }
349
350 #[test]
351 fn wrap_return_type_in_result_closure_non_block() {
352 check_assist_not_applicable(wrap_return_type_in_result, r#"fn foo() { || -> i<|>32 3; }"#);
353 }
354
355 #[test]
356 fn wrap_return_type_in_result_simple_return_type_already_result_std() {
357 check_assist_not_applicable(
358 wrap_return_type_in_result,
359 r#"
360fn foo() -> std::result::Result<i32<|>, String> {
361 let test = "test";
362 return 42i32;
363}
364"#,
365 );
366 }
367
368 #[test]
369 fn wrap_return_type_in_result_simple_return_type_already_result() {
370 mark::check!(wrap_return_type_in_result_simple_return_type_already_result);
371 check_assist_not_applicable(
372 wrap_return_type_in_result,
373 r#"
374fn foo() -> Result<i32<|>, String> {
375 let test = "test";
376 return 42i32;
377}
378"#,
379 );
380 }
381
382 #[test]
383 fn wrap_return_type_in_result_simple_return_type_already_result_closure() {
384 check_assist_not_applicable(
385 wrap_return_type_in_result,
386 r#"
387fn foo() {
388 || -> Result<i32<|>, String> {
389 let test = "test";
390 return 42i32;
391 };
392}
393"#,
394 );
395 }
396
397 #[test]
398 fn wrap_return_type_in_result_simple_with_cursor() {
399 check_assist(
400 wrap_return_type_in_result,
401 r#"
402fn foo() -> <|>i32 {
403 let test = "test";
404 return 42i32;
405}
406"#,
407 r#"
408fn foo() -> Result<i32, ${0:_}> {
409 let test = "test";
410 return Ok(42i32);
411}
412"#,
413 );
414 }
415
416 #[test]
417 fn wrap_return_type_in_result_simple_with_tail() {
418 check_assist(
419 wrap_return_type_in_result,
420 r#"
421fn foo() -><|> i32 {
422 let test = "test";
423 42i32
424}
425"#,
426 r#"
427fn foo() -> Result<i32, ${0:_}> {
428 let test = "test";
429 Ok(42i32)
430}
431"#,
432 );
433 }
434
435 #[test]
436 fn wrap_return_type_in_result_simple_with_tail_closure() {
437 check_assist(
438 wrap_return_type_in_result,
439 r#"
440fn foo() {
441 || -><|> i32 {
442 let test = "test";
443 42i32
444 };
445}
446"#,
447 r#"
448fn foo() {
449 || -> Result<i32, ${0:_}> {
450 let test = "test";
451 Ok(42i32)
452 };
453}
454"#,
455 );
456 }
457
458 #[test]
459 fn wrap_return_type_in_result_simple_with_tail_only() {
460 check_assist(
461 wrap_return_type_in_result,
462 r#"fn foo() -> i32<|> { 42i32 }"#,
463 r#"fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }"#,
464 );
465 }
466
467 #[test]
468 fn wrap_return_type_in_result_simple_with_tail_block_like() {
469 check_assist(
470 wrap_return_type_in_result,
471 r#"
472fn foo() -> i32<|> {
473 if true {
474 42i32
475 } else {
476 24i32
477 }
478}
479"#,
480 r#"
481fn foo() -> Result<i32, ${0:_}> {
482 if true {
483 Ok(42i32)
484 } else {
485 Ok(24i32)
486 }
487}
488"#,
489 );
490 }
491
492 #[test]
493 fn wrap_return_type_in_result_simple_without_block_closure() {
494 check_assist(
495 wrap_return_type_in_result,
496 r#"
497fn foo() {
498 || -> i32<|> {
499 if true {
500 42i32
501 } else {
502 24i32
503 }
504 };
505}
506"#,
507 r#"
508fn foo() {
509 || -> Result<i32, ${0:_}> {
510 if true {
511 Ok(42i32)
512 } else {
513 Ok(24i32)
514 }
515 };
516}
517"#,
518 );
519 }
520
521 #[test]
522 fn wrap_return_type_in_result_simple_with_nested_if() {
523 check_assist(
524 wrap_return_type_in_result,
525 r#"
526fn foo() -> i32<|> {
527 if true {
528 if false {
529 1
530 } else {
531 2
532 }
533 } else {
534 24i32
535 }
536}
537"#,
538 r#"
539fn foo() -> Result<i32, ${0:_}> {
540 if true {
541 if false {
542 Ok(1)
543 } else {
544 Ok(2)
545 }
546 } else {
547 Ok(24i32)
548 }
549}
550"#,
551 );
552 }
553
554 #[test]
555 fn wrap_return_type_in_result_simple_with_await() {
556 check_assist(
557 wrap_return_type_in_result,
558 r#"
559async fn foo() -> i<|>32 {
560 if true {
561 if false {
562 1.await
563 } else {
564 2.await
565 }
566 } else {
567 24i32.await
568 }
569}
570"#,
571 r#"
572async fn foo() -> Result<i32, ${0:_}> {
573 if true {
574 if false {
575 Ok(1.await)
576 } else {
577 Ok(2.await)
578 }
579 } else {
580 Ok(24i32.await)
581 }
582}
583"#,
584 );
585 }
586
587 #[test]
588 fn wrap_return_type_in_result_simple_with_array() {
589 check_assist(
590 wrap_return_type_in_result,
591 r#"fn foo() -> [i32;<|> 3] { [1, 2, 3] }"#,
592 r#"fn foo() -> Result<[i32; 3], ${0:_}> { Ok([1, 2, 3]) }"#,
593 );
594 }
595
596 #[test]
597 fn wrap_return_type_in_result_simple_with_cast() {
598 check_assist(
599 wrap_return_type_in_result,
600 r#"
601fn foo() -<|>> i32 {
602 if true {
603 if false {
604 1 as i32
605 } else {
606 2 as i32
607 }
608 } else {
609 24 as i32
610 }
611}
612"#,
613 r#"
614fn foo() -> Result<i32, ${0:_}> {
615 if true {
616 if false {
617 Ok(1 as i32)
618 } else {
619 Ok(2 as i32)
620 }
621 } else {
622 Ok(24 as i32)
623 }
624}
625"#,
626 );
627 }
628
629 #[test]
630 fn wrap_return_type_in_result_simple_with_tail_block_like_match() {
631 check_assist(
632 wrap_return_type_in_result,
633 r#"
634fn foo() -> i32<|> {
635 let my_var = 5;
636 match my_var {
637 5 => 42i32,
638 _ => 24i32,
639 }
640}
641"#,
642 r#"
643fn foo() -> Result<i32, ${0:_}> {
644 let my_var = 5;
645 match my_var {
646 5 => Ok(42i32),
647 _ => Ok(24i32),
648 }
649}
650"#,
651 );
652 }
653
654 #[test]
655 fn wrap_return_type_in_result_simple_with_loop_with_tail() {
656 check_assist(
657 wrap_return_type_in_result,
658 r#"
659fn foo() -> i32<|> {
660 let my_var = 5;
661 loop {
662 println!("test");
663 5
664 }
665 my_var
666}
667"#,
668 r#"
669fn foo() -> Result<i32, ${0:_}> {
670 let my_var = 5;
671 loop {
672 println!("test");
673 5
674 }
675 Ok(my_var)
676}
677"#,
678 );
679 }
680
681 #[test]
682 fn wrap_return_type_in_result_simple_with_loop_in_let_stmt() {
683 check_assist(
684 wrap_return_type_in_result,
685 r#"
686fn foo() -> i32<|> {
687 let my_var = let x = loop {
688 break 1;
689 };
690 my_var
691}
692"#,
693 r#"
694fn foo() -> Result<i32, ${0:_}> {
695 let my_var = let x = loop {
696 break 1;
697 };
698 Ok(my_var)
699}
700"#,
701 );
702 }
703
704 #[test]
705 fn wrap_return_type_in_result_simple_with_tail_block_like_match_return_expr() {
706 check_assist(
707 wrap_return_type_in_result,
708 r#"
709fn foo() -> i32<|> {
710 let my_var = 5;
711 let res = match my_var {
712 5 => 42i32,
713 _ => return 24i32,
714 };
715 res
716}
717"#,
718 r#"
719fn foo() -> Result<i32, ${0:_}> {
720 let my_var = 5;
721 let res = match my_var {
722 5 => 42i32,
723 _ => return Ok(24i32),
724 };
725 Ok(res)
726}
727"#,
728 );
729
730 check_assist(
731 wrap_return_type_in_result,
732 r#"
733fn foo() -> i32<|> {
734 let my_var = 5;
735 let res = if my_var == 5 {
736 42i32
737 } else {
738 return 24i32;
739 };
740 res
741}
742"#,
743 r#"
744fn foo() -> Result<i32, ${0:_}> {
745 let my_var = 5;
746 let res = if my_var == 5 {
747 42i32
748 } else {
749 return Ok(24i32);
750 };
751 Ok(res)
752}
753"#,
754 );
755 }
756
757 #[test]
758 fn wrap_return_type_in_result_simple_with_tail_block_like_match_deeper() {
759 check_assist(
760 wrap_return_type_in_result,
761 r#"
762fn foo() -> i32<|> {
763 let my_var = 5;
764 match my_var {
765 5 => {
766 if true {
767 42i32
768 } else {
769 25i32
770 }
771 },
772 _ => {
773 let test = "test";
774 if test == "test" {
775 return bar();
776 }
777 53i32
778 },
779 }
780}
781"#,
782 r#"
783fn foo() -> Result<i32, ${0:_}> {
784 let my_var = 5;
785 match my_var {
786 5 => {
787 if true {
788 Ok(42i32)
789 } else {
790 Ok(25i32)
791 }
792 },
793 _ => {
794 let test = "test";
795 if test == "test" {
796 return Ok(bar());
797 }
798 Ok(53i32)
799 },
800 }
801}
802"#,
803 );
804 }
805
806 #[test]
807 fn wrap_return_type_in_result_simple_with_tail_block_like_early_return() {
808 check_assist(
809 wrap_return_type_in_result,
810 r#"
811fn foo() -> i<|>32 {
812 let test = "test";
813 if test == "test" {
814 return 24i32;
815 }
816 53i32
817}
818"#,
819 r#"
820fn foo() -> Result<i32, ${0:_}> {
821 let test = "test";
822 if test == "test" {
823 return Ok(24i32);
824 }
825 Ok(53i32)
826}
827"#,
828 );
829 }
830
831 #[test]
832 fn wrap_return_type_in_result_simple_with_closure() {
833 check_assist(
834 wrap_return_type_in_result,
835 r#"
836fn foo(the_field: u32) -><|> u32 {
837 let true_closure = || { return true; };
838 if the_field < 5 {
839 let mut i = 0;
840 if true_closure() {
841 return 99;
842 } else {
843 return 0;
844 }
845 }
846 the_field
847}
848"#,
849 r#"
850fn foo(the_field: u32) -> Result<u32, ${0:_}> {
851 let true_closure = || { return true; };
852 if the_field < 5 {
853 let mut i = 0;
854 if true_closure() {
855 return Ok(99);
856 } else {
857 return Ok(0);
858 }
859 }
860 Ok(the_field)
861}
862"#,
863 );
864
865 check_assist(
866 wrap_return_type_in_result,
867 r#"
868 fn foo(the_field: u32) -> u32<|> {
869 let true_closure = || {
870 return true;
871 };
872 if the_field < 5 {
873 let mut i = 0;
874
875
876 if true_closure() {
877 return 99;
878 } else {
879 return 0;
880 }
881 }
882 let t = None;
883
884 t.unwrap_or_else(|| the_field)
885 }
886 "#,
887 r#"
888 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
889 let true_closure = || {
890 return true;
891 };
892 if the_field < 5 {
893 let mut i = 0;
894
895
896 if true_closure() {
897 return Ok(99);
898 } else {
899 return Ok(0);
900 }
901 }
902 let t = None;
903
904 Ok(t.unwrap_or_else(|| the_field))
905 }
906 "#,
907 );
908 }
909
910 #[test]
911 fn wrap_return_type_in_result_simple_with_weird_forms() {
912 check_assist(
913 wrap_return_type_in_result,
914 r#"
915fn foo() -> i32<|> {
916 let test = "test";
917 if test == "test" {
918 return 24i32;
919 }
920 let mut i = 0;
921 loop {
922 if i == 1 {
923 break 55;
924 }
925 i += 1;
926 }
927}
928"#,
929 r#"
930fn foo() -> Result<i32, ${0:_}> {
931 let test = "test";
932 if test == "test" {
933 return Ok(24i32);
934 }
935 let mut i = 0;
936 loop {
937 if i == 1 {
938 break Ok(55);
939 }
940 i += 1;
941 }
942}
943"#,
944 );
945
946 check_assist(
947 wrap_return_type_in_result,
948 r#"
949fn foo() -> i32<|> {
950 let test = "test";
951 if test == "test" {
952 return 24i32;
953 }
954 let mut i = 0;
955 loop {
956 loop {
957 if i == 1 {
958 break 55;
959 }
960 i += 1;
961 }
962 }
963}
964"#,
965 r#"
966fn foo() -> Result<i32, ${0:_}> {
967 let test = "test";
968 if test == "test" {
969 return Ok(24i32);
970 }
971 let mut i = 0;
972 loop {
973 loop {
974 if i == 1 {
975 break Ok(55);
976 }
977 i += 1;
978 }
979 }
980}
981"#,
982 );
983
984 check_assist(
985 wrap_return_type_in_result,
986 r#"
987fn foo() -> i3<|>2 {
988 let test = "test";
989 let other = 5;
990 if test == "test" {
991 let res = match other {
992 5 => 43,
993 _ => return 56,
994 };
995 }
996 let mut i = 0;
997 loop {
998 loop {
999 if i == 1 {
1000 break 55;
1001 }
1002 i += 1;
1003 }
1004 }
1005}
1006"#,
1007 r#"
1008fn foo() -> Result<i32, ${0:_}> {
1009 let test = "test";
1010 let other = 5;
1011 if test == "test" {
1012 let res = match other {
1013 5 => 43,
1014 _ => return Ok(56),
1015 };
1016 }
1017 let mut i = 0;
1018 loop {
1019 loop {
1020 if i == 1 {
1021 break Ok(55);
1022 }
1023 i += 1;
1024 }
1025 }
1026}
1027"#,
1028 );
1029
1030 check_assist(
1031 wrap_return_type_in_result,
1032 r#"
1033fn foo(the_field: u32) -> u32<|> {
1034 if the_field < 5 {
1035 let mut i = 0;
1036 loop {
1037 if i > 5 {
1038 return 55u32;
1039 }
1040 i += 3;
1041 }
1042 match i {
1043 5 => return 99,
1044 _ => return 0,
1045 };
1046 }
1047 the_field
1048}
1049"#,
1050 r#"
1051fn foo(the_field: u32) -> Result<u32, ${0:_}> {
1052 if the_field < 5 {
1053 let mut i = 0;
1054 loop {
1055 if i > 5 {
1056 return Ok(55u32);
1057 }
1058 i += 3;
1059 }
1060 match i {
1061 5 => return Ok(99),
1062 _ => return Ok(0),
1063 };
1064 }
1065 Ok(the_field)
1066}
1067"#,
1068 );
1069
1070 check_assist(
1071 wrap_return_type_in_result,
1072 r#"
1073fn foo(the_field: u32) -> u3<|>2 {
1074 if the_field < 5 {
1075 let mut i = 0;
1076 match i {
1077 5 => return 99,
1078 _ => return 0,
1079 }
1080 }
1081 the_field
1082}
1083"#,
1084 r#"
1085fn foo(the_field: u32) -> Result<u32, ${0:_}> {
1086 if the_field < 5 {
1087 let mut i = 0;
1088 match i {
1089 5 => return Ok(99),
1090 _ => return Ok(0),
1091 }
1092 }
1093 Ok(the_field)
1094}
1095"#,
1096 );
1097
1098 check_assist(
1099 wrap_return_type_in_result,
1100 r#"
1101fn foo(the_field: u32) -> u32<|> {
1102 if the_field < 5 {
1103 let mut i = 0;
1104 if i == 5 {
1105 return 99
1106 } else {
1107 return 0
1108 }
1109 }
1110 the_field
1111}
1112"#,
1113 r#"
1114fn foo(the_field: u32) -> Result<u32, ${0:_}> {
1115 if the_field < 5 {
1116 let mut i = 0;
1117 if i == 5 {
1118 return Ok(99)
1119 } else {
1120 return Ok(0)
1121 }
1122 }
1123 Ok(the_field)
1124}
1125"#,
1126 );
1127
1128 check_assist(
1129 wrap_return_type_in_result,
1130 r#"
1131fn foo(the_field: u32) -> <|>u32 {
1132 if the_field < 5 {
1133 let mut i = 0;
1134 if i == 5 {
1135 return 99;
1136 } else {
1137 return 0;
1138 }
1139 }
1140 the_field
1141}
1142"#,
1143 r#"
1144fn foo(the_field: u32) -> Result<u32, ${0:_}> {
1145 if the_field < 5 {
1146 let mut i = 0;
1147 if i == 5 {
1148 return Ok(99);
1149 } else {
1150 return Ok(0);
1151 }
1152 }
1153 Ok(the_field)
1154}
1155"#,
1156 );
1157 }
1158}
diff --git a/crates/assists/src/lib.rs b/crates/assists/src/lib.rs
index b804e495d..e8d81b33d 100644
--- a/crates/assists/src/lib.rs
+++ b/crates/assists/src/lib.rs
@@ -120,13 +120,11 @@ mod handlers {
120 120
121 pub(crate) type Handler = fn(&mut Assists, &AssistContext) -> Option<()>; 121 pub(crate) type Handler = fn(&mut Assists, &AssistContext) -> Option<()>;
122 122
123 mod add_custom_impl;
124 mod add_explicit_type; 123 mod add_explicit_type;
125 mod add_missing_impl_members; 124 mod add_missing_impl_members;
126 mod add_turbo_fish; 125 mod add_turbo_fish;
127 mod apply_demorgan; 126 mod apply_demorgan;
128 mod auto_import; 127 mod auto_import;
129 mod change_return_type_to_result;
130 mod change_visibility; 128 mod change_visibility;
131 mod convert_integer_literal; 129 mod convert_integer_literal;
132 mod early_return; 130 mod early_return;
@@ -143,6 +141,7 @@ mod handlers {
143 mod generate_function; 141 mod generate_function;
144 mod generate_impl; 142 mod generate_impl;
145 mod generate_new; 143 mod generate_new;
144 mod infer_function_return_type;
146 mod inline_local_variable; 145 mod inline_local_variable;
147 mod introduce_named_lifetime; 146 mod introduce_named_lifetime;
148 mod invert_if; 147 mod invert_if;
@@ -156,6 +155,7 @@ mod handlers {
156 mod remove_mut; 155 mod remove_mut;
157 mod remove_unused_param; 156 mod remove_unused_param;
158 mod reorder_fields; 157 mod reorder_fields;
158 mod replace_derive_with_manual_impl;
159 mod replace_if_let_with_match; 159 mod replace_if_let_with_match;
160 mod replace_impl_trait_with_generic; 160 mod replace_impl_trait_with_generic;
161 mod replace_let_with_if_let; 161 mod replace_let_with_if_let;
@@ -164,16 +164,15 @@ mod handlers {
164 mod replace_unwrap_with_match; 164 mod replace_unwrap_with_match;
165 mod split_import; 165 mod split_import;
166 mod unwrap_block; 166 mod unwrap_block;
167 mod wrap_return_type_in_result;
167 168
168 pub(crate) fn all() -> &'static [Handler] { 169 pub(crate) fn all() -> &'static [Handler] {
169 &[ 170 &[
170 // These are alphabetic for the foolish consistency 171 // These are alphabetic for the foolish consistency
171 add_custom_impl::add_custom_impl,
172 add_explicit_type::add_explicit_type, 172 add_explicit_type::add_explicit_type,
173 add_turbo_fish::add_turbo_fish, 173 add_turbo_fish::add_turbo_fish,
174 apply_demorgan::apply_demorgan, 174 apply_demorgan::apply_demorgan,
175 auto_import::auto_import, 175 auto_import::auto_import,
176 change_return_type_to_result::change_return_type_to_result,
177 change_visibility::change_visibility, 176 change_visibility::change_visibility,
178 convert_integer_literal::convert_integer_literal, 177 convert_integer_literal::convert_integer_literal,
179 early_return::convert_to_guarded_return, 178 early_return::convert_to_guarded_return,
@@ -190,6 +189,7 @@ mod handlers {
190 generate_function::generate_function, 189 generate_function::generate_function,
191 generate_impl::generate_impl, 190 generate_impl::generate_impl,
192 generate_new::generate_new, 191 generate_new::generate_new,
192 infer_function_return_type::infer_function_return_type,
193 inline_local_variable::inline_local_variable, 193 inline_local_variable::inline_local_variable,
194 introduce_named_lifetime::introduce_named_lifetime, 194 introduce_named_lifetime::introduce_named_lifetime,
195 invert_if::invert_if, 195 invert_if::invert_if,
@@ -206,6 +206,7 @@ mod handlers {
206 remove_mut::remove_mut, 206 remove_mut::remove_mut,
207 remove_unused_param::remove_unused_param, 207 remove_unused_param::remove_unused_param,
208 reorder_fields::reorder_fields, 208 reorder_fields::reorder_fields,
209 replace_derive_with_manual_impl::replace_derive_with_manual_impl,
209 replace_if_let_with_match::replace_if_let_with_match, 210 replace_if_let_with_match::replace_if_let_with_match,
210 replace_impl_trait_with_generic::replace_impl_trait_with_generic, 211 replace_impl_trait_with_generic::replace_impl_trait_with_generic,
211 replace_let_with_if_let::replace_let_with_if_let, 212 replace_let_with_if_let::replace_let_with_if_let,
@@ -213,6 +214,7 @@ mod handlers {
213 replace_unwrap_with_match::replace_unwrap_with_match, 214 replace_unwrap_with_match::replace_unwrap_with_match,
214 split_import::split_import, 215 split_import::split_import,
215 unwrap_block::unwrap_block, 216 unwrap_block::unwrap_block,
217 wrap_return_type_in_result::wrap_return_type_in_result,
216 // These are manually sorted for better priorities 218 // These are manually sorted for better priorities
217 add_missing_impl_members::add_missing_impl_members, 219 add_missing_impl_members::add_missing_impl_members,
218 add_missing_impl_members::add_missing_default_members, 220 add_missing_impl_members::add_missing_default_members,
diff --git a/crates/assists/src/tests.rs b/crates/assists/src/tests.rs
index 849d85e76..709a34d03 100644
--- a/crates/assists/src/tests.rs
+++ b/crates/assists/src/tests.rs
@@ -7,7 +7,7 @@ use syntax::TextRange;
7use test_utils::{assert_eq_text, extract_offset, extract_range}; 7use test_utils::{assert_eq_text, extract_offset, extract_range};
8 8
9use crate::{handlers::Handler, Assist, AssistConfig, AssistContext, AssistKind, Assists}; 9use crate::{handlers::Handler, Assist, AssistConfig, AssistContext, AssistKind, Assists};
10use stdx::trim_indent; 10use stdx::{format_to, trim_indent};
11 11
12pub(crate) fn with_single_file(text: &str) -> (RootDatabase, FileId) { 12pub(crate) fn with_single_file(text: &str) -> (RootDatabase, FileId) {
13 RootDatabase::with_single_file(text) 13 RootDatabase::with_single_file(text)
@@ -98,11 +98,24 @@ fn check(handler: Handler, before: &str, expected: ExpectedResult, assist_label:
98 match (assist, expected) { 98 match (assist, expected) {
99 (Some(assist), ExpectedResult::After(after)) => { 99 (Some(assist), ExpectedResult::After(after)) => {
100 let mut source_change = assist.source_change; 100 let mut source_change = assist.source_change;
101 let change = source_change.source_file_edits.pop().unwrap(); 101 assert!(!source_change.source_file_edits.is_empty());
102 102 let skip_header = source_change.source_file_edits.len() == 1;
103 let mut actual = db.file_text(change.file_id).as_ref().to_owned(); 103 source_change.source_file_edits.sort_by_key(|it| it.file_id);
104 change.edit.apply(&mut actual); 104
105 assert_eq_text!(after, &actual); 105 let mut buf = String::new();
106 for source_file_edit in source_change.source_file_edits {
107 let mut text = db.file_text(source_file_edit.file_id).as_ref().to_owned();
108 source_file_edit.edit.apply(&mut text);
109 if !skip_header {
110 let sr = db.file_source_root(source_file_edit.file_id);
111 let sr = db.source_root(sr);
112 let path = sr.path_for_file(&source_file_edit.file_id).unwrap();
113 format_to!(buf, "//- {}\n", path)
114 }
115 buf.push_str(&text);
116 }
117
118 assert_eq_text!(after, &buf);
106 } 119 }
107 (Some(assist), ExpectedResult::Target(target)) => { 120 (Some(assist), ExpectedResult::Target(target)) => {
108 let range = assist.assist.target; 121 let range = assist.assist.target;
diff --git a/crates/assists/src/tests/generated.rs b/crates/assists/src/tests/generated.rs
index acbf5b652..dbf4f21aa 100644
--- a/crates/assists/src/tests/generated.rs
+++ b/crates/assists/src/tests/generated.rs
@@ -3,25 +3,6 @@
3use super::check_doc_test; 3use super::check_doc_test;
4 4
5#[test] 5#[test]
6fn doctest_add_custom_impl() {
7 check_doc_test(
8 "add_custom_impl",
9 r#####"
10#[derive(Deb<|>ug, Display)]
11struct S;
12"#####,
13 r#####"
14#[derive(Display)]
15struct S;
16
17impl Debug for S {
18 $0
19}
20"#####,
21 )
22}
23
24#[test]
25fn doctest_add_explicit_type() { 6fn doctest_add_explicit_type() {
26 check_doc_test( 7 check_doc_test(
27 "add_explicit_type", 8 "add_explicit_type",
@@ -178,19 +159,6 @@ pub mod std { pub mod collections { pub struct HashMap { } } }
178} 159}
179 160
180#[test] 161#[test]
181fn doctest_change_return_type_to_result() {
182 check_doc_test(
183 "change_return_type_to_result",
184 r#####"
185fn foo() -> i32<|> { 42i32 }
186"#####,
187 r#####"
188fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }
189"#####,
190 )
191}
192
193#[test]
194fn doctest_change_visibility() { 162fn doctest_change_visibility() {
195 check_doc_test( 163 check_doc_test(
196 "change_visibility", 164 "change_visibility",
@@ -506,6 +474,19 @@ impl<T: Clone> Ctx<T> {
506} 474}
507 475
508#[test] 476#[test]
477fn doctest_infer_function_return_type() {
478 check_doc_test(
479 "infer_function_return_type",
480 r#####"
481fn foo() { 4<|>2i32 }
482"#####,
483 r#####"
484fn foo() -> i32 { 42i32 }
485"#####,
486 )
487}
488
489#[test]
509fn doctest_inline_local_variable() { 490fn doctest_inline_local_variable() {
510 check_doc_test( 491 check_doc_test(
511 "inline_local_variable", 492 "inline_local_variable",
@@ -819,6 +800,29 @@ const test: Foo = Foo {foo: 1, bar: 0}
819} 800}
820 801
821#[test] 802#[test]
803fn doctest_replace_derive_with_manual_impl() {
804 check_doc_test(
805 "replace_derive_with_manual_impl",
806 r#####"
807trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; }
808#[derive(Deb<|>ug, Display)]
809struct S;
810"#####,
811 r#####"
812trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; }
813#[derive(Display)]
814struct S;
815
816impl Debug for S {
817 fn fmt(&self, f: &mut Formatter) -> Result<()> {
818 ${0:todo!()}
819 }
820}
821"#####,
822 )
823}
824
825#[test]
822fn doctest_replace_if_let_with_match() { 826fn doctest_replace_if_let_with_match() {
823 check_doc_test( 827 check_doc_test(
824 "replace_if_let_with_match", 828 "replace_if_let_with_match",
@@ -972,3 +976,16 @@ fn foo() {
972"#####, 976"#####,
973 ) 977 )
974} 978}
979
980#[test]
981fn doctest_wrap_return_type_in_result() {
982 check_doc_test(
983 "wrap_return_type_in_result",
984 r#####"
985fn foo() -> i32<|> { 42i32 }
986"#####,
987 r#####"
988fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }
989"#####,
990 )
991}
diff --git a/crates/assists/src/utils.rs b/crates/assists/src/utils.rs
index 56f925ee6..7071fe96b 100644
--- a/crates/assists/src/utils.rs
+++ b/crates/assists/src/utils.rs
@@ -4,17 +4,22 @@ pub(crate) mod import_assets;
4 4
5use std::ops; 5use std::ops;
6 6
7use hir::{Crate, Enum, Module, ScopeDef, Semantics, Trait}; 7use hir::{Crate, Enum, HasSource, Module, ScopeDef, Semantics, Trait};
8use ide_db::RootDatabase; 8use ide_db::RootDatabase;
9use itertools::Itertools; 9use itertools::Itertools;
10use syntax::{ 10use syntax::{
11 ast::{self, make, ArgListOwner}, 11 ast::edit::AstNodeEdit,
12 ast::NameOwner,
13 ast::{self, edit, make, ArgListOwner},
12 AstNode, Direction, 14 AstNode, Direction,
13 SyntaxKind::*, 15 SyntaxKind::*,
14 SyntaxNode, TextSize, T, 16 SyntaxNode, TextSize, T,
15}; 17};
16 18
17use crate::assist_config::SnippetCap; 19use crate::{
20 assist_config::SnippetCap,
21 ast_transform::{self, AstTransform, QualifyPaths, SubstituteTypeParams},
22};
18 23
19pub use insert_use::MergeBehaviour; 24pub use insert_use::MergeBehaviour;
20pub(crate) use insert_use::{insert_use, ImportScope}; 25pub(crate) use insert_use::{insert_use, ImportScope};
@@ -77,6 +82,87 @@ pub fn extract_trivial_expression(block: &ast::BlockExpr) -> Option<ast::Expr> {
77 None 82 None
78} 83}
79 84
85#[derive(Copy, Clone, PartialEq)]
86pub enum DefaultMethods {
87 Only,
88 No,
89}
90
91pub fn filter_assoc_items(
92 db: &RootDatabase,
93 items: &[hir::AssocItem],
94 default_methods: DefaultMethods,
95) -> Vec<ast::AssocItem> {
96 fn has_def_name(item: &ast::AssocItem) -> bool {
97 match item {
98 ast::AssocItem::Fn(def) => def.name(),
99 ast::AssocItem::TypeAlias(def) => def.name(),
100 ast::AssocItem::Const(def) => def.name(),
101 ast::AssocItem::MacroCall(_) => None,
102 }
103 .is_some()
104 };
105
106 items
107 .iter()
108 .map(|i| match i {
109 hir::AssocItem::Function(i) => ast::AssocItem::Fn(i.source(db).value),
110 hir::AssocItem::TypeAlias(i) => ast::AssocItem::TypeAlias(i.source(db).value),
111 hir::AssocItem::Const(i) => ast::AssocItem::Const(i.source(db).value),
112 })
113 .filter(has_def_name)
114 .filter(|it| match it {
115 ast::AssocItem::Fn(def) => matches!(
116 (default_methods, def.body()),
117 (DefaultMethods::Only, Some(_)) | (DefaultMethods::No, None)
118 ),
119 _ => default_methods == DefaultMethods::No,
120 })
121 .collect::<Vec<_>>()
122}
123
124pub fn add_trait_assoc_items_to_impl(
125 sema: &hir::Semantics<ide_db::RootDatabase>,
126 items: Vec<ast::AssocItem>,
127 trait_: hir::Trait,
128 impl_def: ast::Impl,
129 target_scope: hir::SemanticsScope,
130) -> (ast::Impl, ast::AssocItem) {
131 let impl_item_list = impl_def.assoc_item_list().unwrap_or_else(make::assoc_item_list);
132
133 let n_existing_items = impl_item_list.assoc_items().count();
134 let source_scope = sema.scope_for_def(trait_);
135 let ast_transform = QualifyPaths::new(&target_scope, &source_scope)
136 .or(SubstituteTypeParams::for_trait_impl(&source_scope, trait_, impl_def.clone()));
137
138 let items = items
139 .into_iter()
140 .map(|it| ast_transform::apply(&*ast_transform, it))
141 .map(|it| match it {
142 ast::AssocItem::Fn(def) => ast::AssocItem::Fn(add_body(def)),
143 ast::AssocItem::TypeAlias(def) => ast::AssocItem::TypeAlias(def.remove_bounds()),
144 _ => it,
145 })
146 .map(|it| edit::remove_attrs_and_docs(&it));
147
148 let new_impl_item_list = impl_item_list.append_items(items);
149 let new_impl_def = impl_def.with_assoc_item_list(new_impl_item_list);
150 let first_new_item =
151 new_impl_def.assoc_item_list().unwrap().assoc_items().nth(n_existing_items).unwrap();
152 return (new_impl_def, first_new_item);
153
154 fn add_body(fn_def: ast::Fn) -> ast::Fn {
155 match fn_def.body() {
156 Some(_) => fn_def,
157 None => {
158 let body =
159 make::block_expr(None, Some(make::expr_todo())).indent(edit::IndentLevel(1));
160 fn_def.with_body(body)
161 }
162 }
163 }
164}
165
80#[derive(Clone, Copy, Debug)] 166#[derive(Clone, Copy, Debug)]
81pub(crate) enum Cursor<'a> { 167pub(crate) enum Cursor<'a> {
82 Replace(&'a SyntaxNode), 168 Replace(&'a SyntaxNode),