aboutsummaryrefslogtreecommitdiff
path: root/crates/assists
diff options
context:
space:
mode:
Diffstat (limited to 'crates/assists')
-rw-r--r--crates/assists/src/handlers/extract_struct_from_enum_variant.rs120
1 files changed, 74 insertions, 46 deletions
diff --git a/crates/assists/src/handlers/extract_struct_from_enum_variant.rs b/crates/assists/src/handlers/extract_struct_from_enum_variant.rs
index 1bf5a4214..14209b771 100644
--- a/crates/assists/src/handlers/extract_struct_from_enum_variant.rs
+++ b/crates/assists/src/handlers/extract_struct_from_enum_variant.rs
@@ -1,3 +1,6 @@
1use std::iter;
2
3use either::Either;
1use hir::{AsName, EnumVariant, Module, ModuleDef, Name}; 4use hir::{AsName, EnumVariant, Module, ModuleDef, Name};
2use ide_db::{defs::Definition, search::Reference, RootDatabase}; 5use ide_db::{defs::Definition, search::Reference, RootDatabase};
3use rustc_hash::{FxHashMap, FxHashSet}; 6use rustc_hash::{FxHashMap, FxHashSet};
@@ -31,48 +34,32 @@ pub(crate) fn extract_struct_from_enum_variant(
31 ctx: &AssistContext, 34 ctx: &AssistContext,
32) -> Option<()> { 35) -> Option<()> {
33 let variant = ctx.find_node_at_offset::<ast::Variant>()?; 36 let variant = ctx.find_node_at_offset::<ast::Variant>()?;
34 37 let field_list = extract_field_list_if_applicable(&variant)?;
35 fn is_applicable_variant(variant: &ast::Variant) -> bool {
36 1 < match variant.kind() {
37 ast::StructKind::Record(field_list) => field_list.fields().count(),
38 ast::StructKind::Tuple(field_list) => field_list.fields().count(),
39 ast::StructKind::Unit => 0,
40 }
41 }
42
43 if !is_applicable_variant(&variant) {
44 return None;
45 }
46
47 let field_list = match variant.kind() {
48 ast::StructKind::Tuple(field_list) => field_list,
49 _ => return None,
50 };
51 38
52 let variant_name = variant.name()?; 39 let variant_name = variant.name()?;
53 let variant_hir = ctx.sema.to_def(&variant)?; 40 let variant_hir = ctx.sema.to_def(&variant)?;
54 if existing_definition(ctx.db(), &variant_name, &variant_hir) { 41 if existing_definition(ctx.db(), &variant_name, &variant_hir) {
55 return None; 42 return None;
56 } 43 }
44
57 let enum_ast = variant.parent_enum(); 45 let enum_ast = variant.parent_enum();
58 let visibility = enum_ast.visibility();
59 let enum_hir = ctx.sema.to_def(&enum_ast)?; 46 let enum_hir = ctx.sema.to_def(&enum_ast)?;
60 let variant_hir_name = variant_hir.name(ctx.db());
61 let enum_module_def = ModuleDef::from(enum_hir);
62 let current_module = enum_hir.module(ctx.db());
63 let target = variant.syntax().text_range(); 47 let target = variant.syntax().text_range();
64 acc.add( 48 acc.add(
65 AssistId("extract_struct_from_enum_variant", AssistKind::RefactorRewrite), 49 AssistId("extract_struct_from_enum_variant", AssistKind::RefactorRewrite),
66 "Extract struct from enum variant", 50 "Extract struct from enum variant",
67 target, 51 target,
68 |builder| { 52 |builder| {
69 let definition = Definition::ModuleDef(ModuleDef::EnumVariant(variant_hir)); 53 let variant_hir_name = variant_hir.name(ctx.db());
70 let res = definition.usages(&ctx.sema).all(); 54 let enum_module_def = ModuleDef::from(enum_hir);
55 let usages =
56 Definition::ModuleDef(ModuleDef::EnumVariant(variant_hir)).usages(&ctx.sema).all();
71 57
72 let mut visited_modules_set = FxHashSet::default(); 58 let mut visited_modules_set = FxHashSet::default();
59 let current_module = enum_hir.module(ctx.db());
73 visited_modules_set.insert(current_module); 60 visited_modules_set.insert(current_module);
74 let mut rewriters = FxHashMap::default(); 61 let mut rewriters = FxHashMap::default();
75 for reference in res { 62 for reference in usages {
76 let rewriter = rewriters 63 let rewriter = rewriters
77 .entry(reference.file_range.file_id) 64 .entry(reference.file_range.file_id)
78 .or_insert_with(SyntaxRewriter::default); 65 .or_insert_with(SyntaxRewriter::default);
@@ -94,20 +81,34 @@ pub(crate) fn extract_struct_from_enum_variant(
94 builder.rewrite(rewriter); 81 builder.rewrite(rewriter);
95 } 82 }
96 builder.edit_file(ctx.frange.file_id); 83 builder.edit_file(ctx.frange.file_id);
97 update_variant(&mut rewriter, &variant_name, &field_list); 84 update_variant(&mut rewriter, &variant);
98 extract_struct_def( 85 extract_struct_def(
99 &mut rewriter, 86 &mut rewriter,
100 &enum_ast, 87 &enum_ast,
101 variant_name.clone(), 88 variant_name.clone(),
102 &field_list, 89 &field_list,
103 &variant.parent_enum().syntax().clone().into(), 90 &variant.parent_enum().syntax().clone().into(),
104 visibility, 91 enum_ast.visibility(),
105 ); 92 );
106 builder.rewrite(rewriter); 93 builder.rewrite(rewriter);
107 }, 94 },
108 ) 95 )
109} 96}
110 97
98fn extract_field_list_if_applicable(
99 variant: &ast::Variant,
100) -> Option<Either<ast::RecordFieldList, ast::TupleFieldList>> {
101 match variant.kind() {
102 ast::StructKind::Record(field_list) if field_list.fields().next().is_some() => {
103 Some(Either::Left(field_list))
104 }
105 ast::StructKind::Tuple(field_list) if field_list.fields().count() > 1 => {
106 Some(Either::Right(field_list))
107 }
108 _ => None,
109 }
110}
111
111fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &EnumVariant) -> bool { 112fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &EnumVariant) -> bool {
112 variant 113 variant
113 .parent_enum(db) 114 .parent_enum(db)
@@ -150,19 +151,29 @@ fn extract_struct_def(
150 rewriter: &mut SyntaxRewriter, 151 rewriter: &mut SyntaxRewriter,
151 enum_: &ast::Enum, 152 enum_: &ast::Enum,
152 variant_name: ast::Name, 153 variant_name: ast::Name,
153 variant_list: &ast::TupleFieldList, 154 field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>,
154 start_offset: &SyntaxElement, 155 start_offset: &SyntaxElement,
155 visibility: Option<ast::Visibility>, 156 visibility: Option<ast::Visibility>,
156) -> Option<()> { 157) -> Option<()> {
157 let variant_list = make::tuple_field_list( 158 let pub_vis = Some(make::visibility_pub());
158 variant_list 159 let field_list = match field_list {
159 .fields() 160 Either::Left(field_list) => {
160 .flat_map(|field| Some(make::tuple_field(Some(make::visibility_pub()), field.ty()?))), 161 make::record_field_list(field_list.fields().flat_map(|field| {
161 ); 162 Some(make::record_field(pub_vis.clone(), field.name()?, field.ty()?))
163 }))
164 .into()
165 }
166 Either::Right(field_list) => make::tuple_field_list(
167 field_list
168 .fields()
169 .flat_map(|field| Some(make::tuple_field(pub_vis.clone(), field.ty()?))),
170 )
171 .into(),
172 };
162 173
163 rewriter.insert_before( 174 rewriter.insert_before(
164 start_offset, 175 start_offset,
165 make::struct_(visibility, variant_name, None, variant_list.into()).syntax(), 176 make::struct_(visibility, variant_name, None, field_list).syntax(),
166 ); 177 );
167 rewriter.insert_before(start_offset, &make::tokens::blank_line()); 178 rewriter.insert_before(start_offset, &make::tokens::blank_line());
168 179
@@ -173,15 +184,14 @@ fn extract_struct_def(
173 Some(()) 184 Some(())
174} 185}
175 186
176fn update_variant( 187fn update_variant(rewriter: &mut SyntaxRewriter, variant: &ast::Variant) -> Option<()> {
177 rewriter: &mut SyntaxRewriter, 188 let name = variant.name()?;
178 variant_name: &ast::Name, 189 let tuple_field = make::tuple_field(None, make::ty(name.text()));
179 field_list: &ast::TupleFieldList, 190 let replacement = make::variant(
180) -> Option<()> { 191 name,
181 let (l, r): (SyntaxElement, SyntaxElement) = 192 Some(ast::FieldList::TupleFieldList(make::tuple_field_list(iter::once(tuple_field)))),
182 (field_list.l_paren_token()?.into(), field_list.r_paren_token()?.into()); 193 );
183 let replacement = vec![l, variant_name.syntax().clone().into(), r]; 194 rewriter.replace(variant.syntax(), replacement.syntax());
184 rewriter.replace_with_many(field_list.syntax(), replacement);
185 Some(()) 195 Some(())
186} 196}
187 197
@@ -243,10 +253,18 @@ enum A { One(One) }"#,
243 check_assist( 253 check_assist(
244 extract_struct_from_enum_variant, 254 extract_struct_from_enum_variant,
245 "enum A { <|>One { foo: u32, bar: u32 } }", 255 "enum A { <|>One { foo: u32, bar: u32 } }",
246 r#"struct One { 256 r#"struct One{ pub foo: u32, pub bar: u32 }
247 pub foo: u32, 257
248 pub bar: u32 258enum A { One(One) }"#,
249} 259 );
260 }
261
262 #[test]
263 fn test_extract_struct_one_field_named() {
264 check_assist(
265 extract_struct_from_enum_variant,
266 "enum A { <|>One { foo: u32 } }",
267 r#"struct One{ pub foo: u32 }
250 268
251enum A { One(One) }"#, 269enum A { One(One) }"#,
252 ); 270 );
@@ -350,4 +368,14 @@ fn another_fn() {
350 fn test_extract_not_applicable_one_field() { 368 fn test_extract_not_applicable_one_field() {
351 check_not_applicable(r"enum A { <|>One(u32) }"); 369 check_not_applicable(r"enum A { <|>One(u32) }");
352 } 370 }
371
372 #[test]
373 fn test_extract_not_applicable_no_field_tuple() {
374 check_not_applicable(r"enum A { <|>None() }");
375 }
376
377 #[test]
378 fn test_extract_not_applicable_no_field_named() {
379 check_not_applicable(r"enum A { <|>None {} }");
380 }
353} 381}