diff options
-rw-r--r-- | crates/assists/src/handlers/extract_struct_from_enum_variant.rs | 120 | ||||
-rw-r--r-- | crates/ide/src/diagnostics/fixes.rs | 3 | ||||
-rw-r--r-- | crates/syntax/src/ast/make.rs | 27 |
3 files changed, 101 insertions, 49 deletions
diff --git a/crates/assists/src/handlers/extract_struct_from_enum_variant.rs b/crates/assists/src/handlers/extract_struct_from_enum_variant.rs index 1bf5a4214..14209b771 100644 --- a/crates/assists/src/handlers/extract_struct_from_enum_variant.rs +++ b/crates/assists/src/handlers/extract_struct_from_enum_variant.rs | |||
@@ -1,3 +1,6 @@ | |||
1 | use std::iter; | ||
2 | |||
3 | use either::Either; | ||
1 | use hir::{AsName, EnumVariant, Module, ModuleDef, Name}; | 4 | use hir::{AsName, EnumVariant, Module, ModuleDef, Name}; |
2 | use ide_db::{defs::Definition, search::Reference, RootDatabase}; | 5 | use ide_db::{defs::Definition, search::Reference, RootDatabase}; |
3 | use rustc_hash::{FxHashMap, FxHashSet}; | 6 | use rustc_hash::{FxHashMap, FxHashSet}; |
@@ -31,48 +34,32 @@ pub(crate) fn extract_struct_from_enum_variant( | |||
31 | ctx: &AssistContext, | 34 | ctx: &AssistContext, |
32 | ) -> Option<()> { | 35 | ) -> Option<()> { |
33 | let variant = ctx.find_node_at_offset::<ast::Variant>()?; | 36 | let variant = ctx.find_node_at_offset::<ast::Variant>()?; |
34 | 37 | let field_list = extract_field_list_if_applicable(&variant)?; | |
35 | fn is_applicable_variant(variant: &ast::Variant) -> bool { | ||
36 | 1 < match variant.kind() { | ||
37 | ast::StructKind::Record(field_list) => field_list.fields().count(), | ||
38 | ast::StructKind::Tuple(field_list) => field_list.fields().count(), | ||
39 | ast::StructKind::Unit => 0, | ||
40 | } | ||
41 | } | ||
42 | |||
43 | if !is_applicable_variant(&variant) { | ||
44 | return None; | ||
45 | } | ||
46 | |||
47 | let field_list = match variant.kind() { | ||
48 | ast::StructKind::Tuple(field_list) => field_list, | ||
49 | _ => return None, | ||
50 | }; | ||
51 | 38 | ||
52 | let variant_name = variant.name()?; | 39 | let variant_name = variant.name()?; |
53 | let variant_hir = ctx.sema.to_def(&variant)?; | 40 | let variant_hir = ctx.sema.to_def(&variant)?; |
54 | if existing_definition(ctx.db(), &variant_name, &variant_hir) { | 41 | if existing_definition(ctx.db(), &variant_name, &variant_hir) { |
55 | return None; | 42 | return None; |
56 | } | 43 | } |
44 | |||
57 | let enum_ast = variant.parent_enum(); | 45 | let enum_ast = variant.parent_enum(); |
58 | let visibility = enum_ast.visibility(); | ||
59 | let enum_hir = ctx.sema.to_def(&enum_ast)?; | 46 | let enum_hir = ctx.sema.to_def(&enum_ast)?; |
60 | let variant_hir_name = variant_hir.name(ctx.db()); | ||
61 | let enum_module_def = ModuleDef::from(enum_hir); | ||
62 | let current_module = enum_hir.module(ctx.db()); | ||
63 | let target = variant.syntax().text_range(); | 47 | let target = variant.syntax().text_range(); |
64 | acc.add( | 48 | acc.add( |
65 | AssistId("extract_struct_from_enum_variant", AssistKind::RefactorRewrite), | 49 | AssistId("extract_struct_from_enum_variant", AssistKind::RefactorRewrite), |
66 | "Extract struct from enum variant", | 50 | "Extract struct from enum variant", |
67 | target, | 51 | target, |
68 | |builder| { | 52 | |builder| { |
69 | let definition = Definition::ModuleDef(ModuleDef::EnumVariant(variant_hir)); | 53 | let variant_hir_name = variant_hir.name(ctx.db()); |
70 | let res = definition.usages(&ctx.sema).all(); | 54 | let enum_module_def = ModuleDef::from(enum_hir); |
55 | let usages = | ||
56 | Definition::ModuleDef(ModuleDef::EnumVariant(variant_hir)).usages(&ctx.sema).all(); | ||
71 | 57 | ||
72 | let mut visited_modules_set = FxHashSet::default(); | 58 | let mut visited_modules_set = FxHashSet::default(); |
59 | let current_module = enum_hir.module(ctx.db()); | ||
73 | visited_modules_set.insert(current_module); | 60 | visited_modules_set.insert(current_module); |
74 | let mut rewriters = FxHashMap::default(); | 61 | let mut rewriters = FxHashMap::default(); |
75 | for reference in res { | 62 | for reference in usages { |
76 | let rewriter = rewriters | 63 | let rewriter = rewriters |
77 | .entry(reference.file_range.file_id) | 64 | .entry(reference.file_range.file_id) |
78 | .or_insert_with(SyntaxRewriter::default); | 65 | .or_insert_with(SyntaxRewriter::default); |
@@ -94,20 +81,34 @@ pub(crate) fn extract_struct_from_enum_variant( | |||
94 | builder.rewrite(rewriter); | 81 | builder.rewrite(rewriter); |
95 | } | 82 | } |
96 | builder.edit_file(ctx.frange.file_id); | 83 | builder.edit_file(ctx.frange.file_id); |
97 | update_variant(&mut rewriter, &variant_name, &field_list); | 84 | update_variant(&mut rewriter, &variant); |
98 | extract_struct_def( | 85 | extract_struct_def( |
99 | &mut rewriter, | 86 | &mut rewriter, |
100 | &enum_ast, | 87 | &enum_ast, |
101 | variant_name.clone(), | 88 | variant_name.clone(), |
102 | &field_list, | 89 | &field_list, |
103 | &variant.parent_enum().syntax().clone().into(), | 90 | &variant.parent_enum().syntax().clone().into(), |
104 | visibility, | 91 | enum_ast.visibility(), |
105 | ); | 92 | ); |
106 | builder.rewrite(rewriter); | 93 | builder.rewrite(rewriter); |
107 | }, | 94 | }, |
108 | ) | 95 | ) |
109 | } | 96 | } |
110 | 97 | ||
98 | fn extract_field_list_if_applicable( | ||
99 | variant: &ast::Variant, | ||
100 | ) -> Option<Either<ast::RecordFieldList, ast::TupleFieldList>> { | ||
101 | match variant.kind() { | ||
102 | ast::StructKind::Record(field_list) if field_list.fields().next().is_some() => { | ||
103 | Some(Either::Left(field_list)) | ||
104 | } | ||
105 | ast::StructKind::Tuple(field_list) if field_list.fields().count() > 1 => { | ||
106 | Some(Either::Right(field_list)) | ||
107 | } | ||
108 | _ => None, | ||
109 | } | ||
110 | } | ||
111 | |||
111 | fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &EnumVariant) -> bool { | 112 | fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &EnumVariant) -> bool { |
112 | variant | 113 | variant |
113 | .parent_enum(db) | 114 | .parent_enum(db) |
@@ -150,19 +151,29 @@ fn extract_struct_def( | |||
150 | rewriter: &mut SyntaxRewriter, | 151 | rewriter: &mut SyntaxRewriter, |
151 | enum_: &ast::Enum, | 152 | enum_: &ast::Enum, |
152 | variant_name: ast::Name, | 153 | variant_name: ast::Name, |
153 | variant_list: &ast::TupleFieldList, | 154 | field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>, |
154 | start_offset: &SyntaxElement, | 155 | start_offset: &SyntaxElement, |
155 | visibility: Option<ast::Visibility>, | 156 | visibility: Option<ast::Visibility>, |
156 | ) -> Option<()> { | 157 | ) -> Option<()> { |
157 | let variant_list = make::tuple_field_list( | 158 | let pub_vis = Some(make::visibility_pub()); |
158 | variant_list | 159 | let field_list = match field_list { |
159 | .fields() | 160 | Either::Left(field_list) => { |
160 | .flat_map(|field| Some(make::tuple_field(Some(make::visibility_pub()), field.ty()?))), | 161 | make::record_field_list(field_list.fields().flat_map(|field| { |
161 | ); | 162 | Some(make::record_field(pub_vis.clone(), field.name()?, field.ty()?)) |
163 | })) | ||
164 | .into() | ||
165 | } | ||
166 | Either::Right(field_list) => make::tuple_field_list( | ||
167 | field_list | ||
168 | .fields() | ||
169 | .flat_map(|field| Some(make::tuple_field(pub_vis.clone(), field.ty()?))), | ||
170 | ) | ||
171 | .into(), | ||
172 | }; | ||
162 | 173 | ||
163 | rewriter.insert_before( | 174 | rewriter.insert_before( |
164 | start_offset, | 175 | start_offset, |
165 | make::struct_(visibility, variant_name, None, variant_list.into()).syntax(), | 176 | make::struct_(visibility, variant_name, None, field_list).syntax(), |
166 | ); | 177 | ); |
167 | rewriter.insert_before(start_offset, &make::tokens::blank_line()); | 178 | rewriter.insert_before(start_offset, &make::tokens::blank_line()); |
168 | 179 | ||
@@ -173,15 +184,14 @@ fn extract_struct_def( | |||
173 | Some(()) | 184 | Some(()) |
174 | } | 185 | } |
175 | 186 | ||
176 | fn update_variant( | 187 | fn update_variant(rewriter: &mut SyntaxRewriter, variant: &ast::Variant) -> Option<()> { |
177 | rewriter: &mut SyntaxRewriter, | 188 | let name = variant.name()?; |
178 | variant_name: &ast::Name, | 189 | let tuple_field = make::tuple_field(None, make::ty(name.text())); |
179 | field_list: &ast::TupleFieldList, | 190 | let replacement = make::variant( |
180 | ) -> Option<()> { | 191 | name, |
181 | let (l, r): (SyntaxElement, SyntaxElement) = | 192 | Some(ast::FieldList::TupleFieldList(make::tuple_field_list(iter::once(tuple_field)))), |
182 | (field_list.l_paren_token()?.into(), field_list.r_paren_token()?.into()); | 193 | ); |
183 | let replacement = vec![l, variant_name.syntax().clone().into(), r]; | 194 | rewriter.replace(variant.syntax(), replacement.syntax()); |
184 | rewriter.replace_with_many(field_list.syntax(), replacement); | ||
185 | Some(()) | 195 | Some(()) |
186 | } | 196 | } |
187 | 197 | ||
@@ -243,10 +253,18 @@ enum A { One(One) }"#, | |||
243 | check_assist( | 253 | check_assist( |
244 | extract_struct_from_enum_variant, | 254 | extract_struct_from_enum_variant, |
245 | "enum A { <|>One { foo: u32, bar: u32 } }", | 255 | "enum A { <|>One { foo: u32, bar: u32 } }", |
246 | r#"struct One { | 256 | r#"struct One{ pub foo: u32, pub bar: u32 } |
247 | pub foo: u32, | 257 | |
248 | pub bar: u32 | 258 | enum A { One(One) }"#, |
249 | } | 259 | ); |
260 | } | ||
261 | |||
262 | #[test] | ||
263 | fn test_extract_struct_one_field_named() { | ||
264 | check_assist( | ||
265 | extract_struct_from_enum_variant, | ||
266 | "enum A { <|>One { foo: u32 } }", | ||
267 | r#"struct One{ pub foo: u32 } | ||
250 | 268 | ||
251 | enum A { One(One) }"#, | 269 | enum A { One(One) }"#, |
252 | ); | 270 | ); |
@@ -350,4 +368,14 @@ fn another_fn() { | |||
350 | fn test_extract_not_applicable_one_field() { | 368 | fn test_extract_not_applicable_one_field() { |
351 | check_not_applicable(r"enum A { <|>One(u32) }"); | 369 | check_not_applicable(r"enum A { <|>One(u32) }"); |
352 | } | 370 | } |
371 | |||
372 | #[test] | ||
373 | fn test_extract_not_applicable_no_field_tuple() { | ||
374 | check_not_applicable(r"enum A { <|>None() }"); | ||
375 | } | ||
376 | |||
377 | #[test] | ||
378 | fn test_extract_not_applicable_no_field_named() { | ||
379 | check_not_applicable(r"enum A { <|>None {} }"); | ||
380 | } | ||
353 | } | 381 | } |
diff --git a/crates/ide/src/diagnostics/fixes.rs b/crates/ide/src/diagnostics/fixes.rs index 02e17ba43..d275dd75b 100644 --- a/crates/ide/src/diagnostics/fixes.rs +++ b/crates/ide/src/diagnostics/fixes.rs | |||
@@ -157,7 +157,8 @@ fn missing_record_expr_field_fix( | |||
157 | return None; | 157 | return None; |
158 | } | 158 | } |
159 | let new_field = make::record_field( | 159 | let new_field = make::record_field( |
160 | record_expr_field.field_name()?, | 160 | None, |
161 | make::name(record_expr_field.field_name()?.text()), | ||
161 | make::ty(&new_field_type.display_source_code(sema.db, module.into()).ok()?), | 162 | make::ty(&new_field_type.display_source_code(sema.db, module.into()).ok()?), |
162 | ); | 163 | ); |
163 | 164 | ||
diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs index 2cf436e7a..b1578820f 100644 --- a/crates/syntax/src/ast/make.rs +++ b/crates/syntax/src/ast/make.rs | |||
@@ -110,8 +110,16 @@ pub fn record_expr_field(name: ast::NameRef, expr: Option<ast::Expr>) -> ast::Re | |||
110 | } | 110 | } |
111 | } | 111 | } |
112 | 112 | ||
113 | pub fn record_field(name: ast::NameRef, ty: ast::Type) -> ast::RecordField { | 113 | pub fn record_field( |
114 | ast_from_text(&format!("struct S {{ {}: {}, }}", name, ty)) | 114 | visibility: Option<ast::Visibility>, |
115 | name: ast::Name, | ||
116 | ty: ast::Type, | ||
117 | ) -> ast::RecordField { | ||
118 | let visibility = match visibility { | ||
119 | None => String::new(), | ||
120 | Some(it) => format!("{} ", it), | ||
121 | }; | ||
122 | ast_from_text(&format!("struct S {{ {}{}: {}, }}", visibility, name, ty)) | ||
115 | } | 123 | } |
116 | 124 | ||
117 | pub fn block_expr( | 125 | pub fn block_expr( |
@@ -360,6 +368,13 @@ pub fn tuple_field_list(fields: impl IntoIterator<Item = ast::TupleField>) -> as | |||
360 | ast_from_text(&format!("struct f({});", fields)) | 368 | ast_from_text(&format!("struct f({});", fields)) |
361 | } | 369 | } |
362 | 370 | ||
371 | pub fn record_field_list( | ||
372 | fields: impl IntoIterator<Item = ast::RecordField>, | ||
373 | ) -> ast::RecordFieldList { | ||
374 | let fields = fields.into_iter().join(", "); | ||
375 | ast_from_text(&format!("struct f {{ {} }}", fields)) | ||
376 | } | ||
377 | |||
363 | pub fn tuple_field(visibility: Option<ast::Visibility>, ty: ast::Type) -> ast::TupleField { | 378 | pub fn tuple_field(visibility: Option<ast::Visibility>, ty: ast::Type) -> ast::TupleField { |
364 | let visibility = match visibility { | 379 | let visibility = match visibility { |
365 | None => String::new(), | 380 | None => String::new(), |
@@ -368,6 +383,14 @@ pub fn tuple_field(visibility: Option<ast::Visibility>, ty: ast::Type) -> ast::T | |||
368 | ast_from_text(&format!("struct f({}{});", visibility, ty)) | 383 | ast_from_text(&format!("struct f({}{});", visibility, ty)) |
369 | } | 384 | } |
370 | 385 | ||
386 | pub fn variant(name: ast::Name, field_list: Option<ast::FieldList>) -> ast::Variant { | ||
387 | let field_list = match field_list { | ||
388 | None => String::new(), | ||
389 | Some(it) => format!("{}", it), | ||
390 | }; | ||
391 | ast_from_text(&format!("enum f {{ {}{} }}", name, field_list)) | ||
392 | } | ||
393 | |||
371 | pub fn fn_( | 394 | pub fn fn_( |
372 | visibility: Option<ast::Visibility>, | 395 | visibility: Option<ast::Visibility>, |
373 | fn_name: ast::Name, | 396 | fn_name: ast::Name, |