aboutsummaryrefslogtreecommitdiff
path: root/crates
diff options
context:
space:
mode:
Diffstat (limited to 'crates')
-rw-r--r--crates/assists/src/handlers/extract_struct_from_enum_variant.rs148
-rw-r--r--crates/ide/src/diagnostics/fixes.rs3
-rw-r--r--crates/syntax/src/ast/make.rs27
3 files changed, 137 insertions, 41 deletions
diff --git a/crates/assists/src/handlers/extract_struct_from_enum_variant.rs b/crates/assists/src/handlers/extract_struct_from_enum_variant.rs
index dddab255e..14209b771 100644
--- a/crates/assists/src/handlers/extract_struct_from_enum_variant.rs
+++ b/crates/assists/src/handlers/extract_struct_from_enum_variant.rs
@@ -1,3 +1,6 @@
1use std::iter;
2
3use either::Either;
1use hir::{AsName, EnumVariant, Module, ModuleDef, Name}; 4use hir::{AsName, EnumVariant, Module, ModuleDef, Name};
2use ide_db::{defs::Definition, search::Reference, RootDatabase}; 5use ide_db::{defs::Definition, search::Reference, RootDatabase};
3use rustc_hash::{FxHashMap, FxHashSet}; 6use rustc_hash::{FxHashMap, FxHashSet};
@@ -31,40 +34,32 @@ pub(crate) fn extract_struct_from_enum_variant(
31 ctx: &AssistContext, 34 ctx: &AssistContext,
32) -> Option<()> { 35) -> Option<()> {
33 let variant = ctx.find_node_at_offset::<ast::Variant>()?; 36 let variant = ctx.find_node_at_offset::<ast::Variant>()?;
34 let field_list = match variant.kind() { 37 let field_list = extract_field_list_if_applicable(&variant)?;
35 ast::StructKind::Tuple(field_list) => field_list,
36 _ => return None,
37 };
38
39 // skip 1-tuple variants
40 if field_list.fields().count() == 1 {
41 return None;
42 }
43 38
44 let variant_name = variant.name()?; 39 let variant_name = variant.name()?;
45 let variant_hir = ctx.sema.to_def(&variant)?; 40 let variant_hir = ctx.sema.to_def(&variant)?;
46 if existing_struct_def(ctx.db(), &variant_name, &variant_hir) { 41 if existing_definition(ctx.db(), &variant_name, &variant_hir) {
47 return None; 42 return None;
48 } 43 }
44
49 let enum_ast = variant.parent_enum(); 45 let enum_ast = variant.parent_enum();
50 let visibility = enum_ast.visibility();
51 let enum_hir = ctx.sema.to_def(&enum_ast)?; 46 let enum_hir = ctx.sema.to_def(&enum_ast)?;
52 let variant_hir_name = variant_hir.name(ctx.db());
53 let enum_module_def = ModuleDef::from(enum_hir);
54 let current_module = enum_hir.module(ctx.db());
55 let target = variant.syntax().text_range(); 47 let target = variant.syntax().text_range();
56 acc.add( 48 acc.add(
57 AssistId("extract_struct_from_enum_variant", AssistKind::RefactorRewrite), 49 AssistId("extract_struct_from_enum_variant", AssistKind::RefactorRewrite),
58 "Extract struct from enum variant", 50 "Extract struct from enum variant",
59 target, 51 target,
60 |builder| { 52 |builder| {
61 let definition = Definition::ModuleDef(ModuleDef::EnumVariant(variant_hir)); 53 let variant_hir_name = variant_hir.name(ctx.db());
62 let res = definition.usages(&ctx.sema).all(); 54 let enum_module_def = ModuleDef::from(enum_hir);
55 let usages =
56 Definition::ModuleDef(ModuleDef::EnumVariant(variant_hir)).usages(&ctx.sema).all();
63 57
64 let mut visited_modules_set = FxHashSet::default(); 58 let mut visited_modules_set = FxHashSet::default();
59 let current_module = enum_hir.module(ctx.db());
65 visited_modules_set.insert(current_module); 60 visited_modules_set.insert(current_module);
66 let mut rewriters = FxHashMap::default(); 61 let mut rewriters = FxHashMap::default();
67 for reference in res { 62 for reference in usages {
68 let rewriter = rewriters 63 let rewriter = rewriters
69 .entry(reference.file_range.file_id) 64 .entry(reference.file_range.file_id)
70 .or_insert_with(SyntaxRewriter::default); 65 .or_insert_with(SyntaxRewriter::default);
@@ -86,26 +81,49 @@ pub(crate) fn extract_struct_from_enum_variant(
86 builder.rewrite(rewriter); 81 builder.rewrite(rewriter);
87 } 82 }
88 builder.edit_file(ctx.frange.file_id); 83 builder.edit_file(ctx.frange.file_id);
89 update_variant(&mut rewriter, &variant_name, &field_list); 84 update_variant(&mut rewriter, &variant);
90 extract_struct_def( 85 extract_struct_def(
91 &mut rewriter, 86 &mut rewriter,
92 &enum_ast, 87 &enum_ast,
93 variant_name.clone(), 88 variant_name.clone(),
94 &field_list, 89 &field_list,
95 &variant.parent_enum().syntax().clone().into(), 90 &variant.parent_enum().syntax().clone().into(),
96 visibility, 91 enum_ast.visibility(),
97 ); 92 );
98 builder.rewrite(rewriter); 93 builder.rewrite(rewriter);
99 }, 94 },
100 ) 95 )
101} 96}
102 97
103fn existing_struct_def(db: &RootDatabase, variant_name: &ast::Name, variant: &EnumVariant) -> bool { 98fn extract_field_list_if_applicable(
99 variant: &ast::Variant,
100) -> Option<Either<ast::RecordFieldList, ast::TupleFieldList>> {
101 match variant.kind() {
102 ast::StructKind::Record(field_list) if field_list.fields().next().is_some() => {
103 Some(Either::Left(field_list))
104 }
105 ast::StructKind::Tuple(field_list) if field_list.fields().count() > 1 => {
106 Some(Either::Right(field_list))
107 }
108 _ => None,
109 }
110}
111
112fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &EnumVariant) -> bool {
104 variant 113 variant
105 .parent_enum(db) 114 .parent_enum(db)
106 .module(db) 115 .module(db)
107 .scope(db, None) 116 .scope(db, None)
108 .into_iter() 117 .into_iter()
118 .filter(|(_, def)| match def {
119 // only check type-namespace
120 hir::ScopeDef::ModuleDef(def) => matches!(def,
121 ModuleDef::Module(_) | ModuleDef::Adt(_) |
122 ModuleDef::EnumVariant(_) | ModuleDef::Trait(_) |
123 ModuleDef::TypeAlias(_) | ModuleDef::BuiltinType(_)
124 ),
125 _ => false,
126 })
109 .any(|(name, _)| name == variant_name.as_name()) 127 .any(|(name, _)| name == variant_name.as_name())
110} 128}
111 129
@@ -133,19 +151,29 @@ fn extract_struct_def(
133 rewriter: &mut SyntaxRewriter, 151 rewriter: &mut SyntaxRewriter,
134 enum_: &ast::Enum, 152 enum_: &ast::Enum,
135 variant_name: ast::Name, 153 variant_name: ast::Name,
136 variant_list: &ast::TupleFieldList, 154 field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>,
137 start_offset: &SyntaxElement, 155 start_offset: &SyntaxElement,
138 visibility: Option<ast::Visibility>, 156 visibility: Option<ast::Visibility>,
139) -> Option<()> { 157) -> Option<()> {
140 let variant_list = make::tuple_field_list( 158 let pub_vis = Some(make::visibility_pub());
141 variant_list 159 let field_list = match field_list {
142 .fields() 160 Either::Left(field_list) => {
143 .flat_map(|field| Some(make::tuple_field(Some(make::visibility_pub()), field.ty()?))), 161 make::record_field_list(field_list.fields().flat_map(|field| {
144 ); 162 Some(make::record_field(pub_vis.clone(), field.name()?, field.ty()?))
163 }))
164 .into()
165 }
166 Either::Right(field_list) => make::tuple_field_list(
167 field_list
168 .fields()
169 .flat_map(|field| Some(make::tuple_field(pub_vis.clone(), field.ty()?))),
170 )
171 .into(),
172 };
145 173
146 rewriter.insert_before( 174 rewriter.insert_before(
147 start_offset, 175 start_offset,
148 make::struct_(visibility, variant_name, None, variant_list.into()).syntax(), 176 make::struct_(visibility, variant_name, None, field_list).syntax(),
149 ); 177 );
150 rewriter.insert_before(start_offset, &make::tokens::blank_line()); 178 rewriter.insert_before(start_offset, &make::tokens::blank_line());
151 179
@@ -156,15 +184,14 @@ fn extract_struct_def(
156 Some(()) 184 Some(())
157} 185}
158 186
159fn update_variant( 187fn update_variant(rewriter: &mut SyntaxRewriter, variant: &ast::Variant) -> Option<()> {
160 rewriter: &mut SyntaxRewriter, 188 let name = variant.name()?;
161 variant_name: &ast::Name, 189 let tuple_field = make::tuple_field(None, make::ty(name.text()));
162 field_list: &ast::TupleFieldList, 190 let replacement = make::variant(
163) -> Option<()> { 191 name,
164 let (l, r): (SyntaxElement, SyntaxElement) = 192 Some(ast::FieldList::TupleFieldList(make::tuple_field_list(iter::once(tuple_field)))),
165 (field_list.l_paren_token()?.into(), field_list.r_paren_token()?.into()); 193 );
166 let replacement = vec![l, variant_name.syntax().clone().into(), r]; 194 rewriter.replace(variant.syntax(), replacement.syntax());
167 rewriter.replace_with_many(field_list.syntax(), replacement);
168 Some(()) 195 Some(())
169} 196}
170 197
@@ -211,7 +238,7 @@ mod tests {
211 use super::*; 238 use super::*;
212 239
213 #[test] 240 #[test]
214 fn test_extract_struct_several_fields() { 241 fn test_extract_struct_several_fields_tuple() {
215 check_assist( 242 check_assist(
216 extract_struct_from_enum_variant, 243 extract_struct_from_enum_variant,
217 "enum A { <|>One(u32, u32) }", 244 "enum A { <|>One(u32, u32) }",
@@ -222,6 +249,41 @@ enum A { One(One) }"#,
222 } 249 }
223 250
224 #[test] 251 #[test]
252 fn test_extract_struct_several_fields_named() {
253 check_assist(
254 extract_struct_from_enum_variant,
255 "enum A { <|>One { foo: u32, bar: u32 } }",
256 r#"struct One{ pub foo: u32, pub bar: u32 }
257
258enum A { One(One) }"#,
259 );
260 }
261
262 #[test]
263 fn test_extract_struct_one_field_named() {
264 check_assist(
265 extract_struct_from_enum_variant,
266 "enum A { <|>One { foo: u32 } }",
267 r#"struct One{ pub foo: u32 }
268
269enum A { One(One) }"#,
270 );
271 }
272
273 #[test]
274 fn test_extract_enum_variant_name_value_namespace() {
275 check_assist(
276 extract_struct_from_enum_variant,
277 r#"const One: () = ();
278enum A { <|>One(u32, u32) }"#,
279 r#"const One: () = ();
280struct One(pub u32, pub u32);
281
282enum A { One(One) }"#,
283 );
284 }
285
286 #[test]
225 fn test_extract_struct_pub_visibility() { 287 fn test_extract_struct_pub_visibility() {
226 check_assist( 288 check_assist(
227 extract_struct_from_enum_variant, 289 extract_struct_from_enum_variant,
@@ -298,7 +360,7 @@ fn another_fn() {
298 fn test_extract_enum_not_applicable_if_struct_exists() { 360 fn test_extract_enum_not_applicable_if_struct_exists() {
299 check_not_applicable( 361 check_not_applicable(
300 r#"struct One; 362 r#"struct One;
301 enum A { <|>One(u8) }"#, 363 enum A { <|>One(u8, u32) }"#,
302 ); 364 );
303 } 365 }
304 366
@@ -306,4 +368,14 @@ fn another_fn() {
306 fn test_extract_not_applicable_one_field() { 368 fn test_extract_not_applicable_one_field() {
307 check_not_applicable(r"enum A { <|>One(u32) }"); 369 check_not_applicable(r"enum A { <|>One(u32) }");
308 } 370 }
371
372 #[test]
373 fn test_extract_not_applicable_no_field_tuple() {
374 check_not_applicable(r"enum A { <|>None() }");
375 }
376
377 #[test]
378 fn test_extract_not_applicable_no_field_named() {
379 check_not_applicable(r"enum A { <|>None {} }");
380 }
309} 381}
diff --git a/crates/ide/src/diagnostics/fixes.rs b/crates/ide/src/diagnostics/fixes.rs
index 02e17ba43..d275dd75b 100644
--- a/crates/ide/src/diagnostics/fixes.rs
+++ b/crates/ide/src/diagnostics/fixes.rs
@@ -157,7 +157,8 @@ fn missing_record_expr_field_fix(
157 return None; 157 return None;
158 } 158 }
159 let new_field = make::record_field( 159 let new_field = make::record_field(
160 record_expr_field.field_name()?, 160 None,
161 make::name(record_expr_field.field_name()?.text()),
161 make::ty(&new_field_type.display_source_code(sema.db, module.into()).ok()?), 162 make::ty(&new_field_type.display_source_code(sema.db, module.into()).ok()?),
162 ); 163 );
163 164
diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs
index 2cf436e7a..b1578820f 100644
--- a/crates/syntax/src/ast/make.rs
+++ b/crates/syntax/src/ast/make.rs
@@ -110,8 +110,16 @@ pub fn record_expr_field(name: ast::NameRef, expr: Option<ast::Expr>) -> ast::Re
110 } 110 }
111} 111}
112 112
113pub fn record_field(name: ast::NameRef, ty: ast::Type) -> ast::RecordField { 113pub fn record_field(
114 ast_from_text(&format!("struct S {{ {}: {}, }}", name, ty)) 114 visibility: Option<ast::Visibility>,
115 name: ast::Name,
116 ty: ast::Type,
117) -> ast::RecordField {
118 let visibility = match visibility {
119 None => String::new(),
120 Some(it) => format!("{} ", it),
121 };
122 ast_from_text(&format!("struct S {{ {}{}: {}, }}", visibility, name, ty))
115} 123}
116 124
117pub fn block_expr( 125pub fn block_expr(
@@ -360,6 +368,13 @@ pub fn tuple_field_list(fields: impl IntoIterator<Item = ast::TupleField>) -> as
360 ast_from_text(&format!("struct f({});", fields)) 368 ast_from_text(&format!("struct f({});", fields))
361} 369}
362 370
371pub fn record_field_list(
372 fields: impl IntoIterator<Item = ast::RecordField>,
373) -> ast::RecordFieldList {
374 let fields = fields.into_iter().join(", ");
375 ast_from_text(&format!("struct f {{ {} }}", fields))
376}
377
363pub fn tuple_field(visibility: Option<ast::Visibility>, ty: ast::Type) -> ast::TupleField { 378pub fn tuple_field(visibility: Option<ast::Visibility>, ty: ast::Type) -> ast::TupleField {
364 let visibility = match visibility { 379 let visibility = match visibility {
365 None => String::new(), 380 None => String::new(),
@@ -368,6 +383,14 @@ pub fn tuple_field(visibility: Option<ast::Visibility>, ty: ast::Type) -> ast::T
368 ast_from_text(&format!("struct f({}{});", visibility, ty)) 383 ast_from_text(&format!("struct f({}{});", visibility, ty))
369} 384}
370 385
386pub fn variant(name: ast::Name, field_list: Option<ast::FieldList>) -> ast::Variant {
387 let field_list = match field_list {
388 None => String::new(),
389 Some(it) => format!("{}", it),
390 };
391 ast_from_text(&format!("enum f {{ {}{} }}", name, field_list))
392}
393
371pub fn fn_( 394pub fn fn_(
372 visibility: Option<ast::Visibility>, 395 visibility: Option<ast::Visibility>,
373 fn_name: ast::Name, 396 fn_name: ast::Name,