aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs169
-rw-r--r--crates/syntax/src/ast/make.rs5
2 files changed, 127 insertions, 47 deletions
diff --git a/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs b/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs
index 007aba23d..d3ff7b65c 100644
--- a/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs
+++ b/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs
@@ -11,14 +11,19 @@ use ide_db::{
11 search::FileReference, 11 search::FileReference,
12 RootDatabase, 12 RootDatabase,
13}; 13};
14use itertools::Itertools;
14use rustc_hash::FxHashSet; 15use rustc_hash::FxHashSet;
15use syntax::{ 16use syntax::{
16 algo::find_node_at_offset, 17 ast::{
17 ast::{self, make, AstNode, NameOwner, VisibilityOwner}, 18 self, make, AstNode, AttrsOwner, GenericParamsOwner, NameOwner, TypeBoundsOwner,
18 ted, SyntaxNode, T, 19 VisibilityOwner,
20 },
21 match_ast,
22 ted::{self, Position},
23 SyntaxNode, T,
19}; 24};
20 25
21use crate::{AssistContext, AssistId, AssistKind, Assists}; 26use crate::{assist_context::AssistBuilder, AssistContext, AssistId, AssistKind, Assists};
22 27
23// Assist: extract_struct_from_enum_variant 28// Assist: extract_struct_from_enum_variant
24// 29//
@@ -70,11 +75,10 @@ pub(crate) fn extract_struct_from_enum_variant(
70 continue; 75 continue;
71 } 76 }
72 builder.edit_file(file_id); 77 builder.edit_file(file_id);
73 let source_file = builder.make_mut(ctx.sema.parse(file_id));
74 let processed = process_references( 78 let processed = process_references(
75 ctx, 79 ctx,
80 builder,
76 &mut visited_modules_set, 81 &mut visited_modules_set,
77 source_file.syntax(),
78 &enum_module_def, 82 &enum_module_def,
79 &variant_hir_name, 83 &variant_hir_name,
80 references, 84 references,
@@ -84,13 +88,12 @@ pub(crate) fn extract_struct_from_enum_variant(
84 }); 88 });
85 } 89 }
86 builder.edit_file(ctx.frange.file_id); 90 builder.edit_file(ctx.frange.file_id);
87 let source_file = builder.make_mut(ctx.sema.parse(ctx.frange.file_id));
88 let variant = builder.make_mut(variant.clone()); 91 let variant = builder.make_mut(variant.clone());
89 if let Some(references) = def_file_references { 92 if let Some(references) = def_file_references {
90 let processed = process_references( 93 let processed = process_references(
91 ctx, 94 ctx,
95 builder,
92 &mut visited_modules_set, 96 &mut visited_modules_set,
93 source_file.syntax(),
94 &enum_module_def, 97 &enum_module_def,
95 &variant_hir_name, 98 &variant_hir_name,
96 references, 99 references,
@@ -100,12 +103,12 @@ pub(crate) fn extract_struct_from_enum_variant(
100 }); 103 });
101 } 104 }
102 105
103 let def = create_struct_def(variant_name.clone(), &field_list, enum_ast.visibility()); 106 let def = create_struct_def(variant_name.clone(), &field_list, &enum_ast);
104 let start_offset = &variant.parent_enum().syntax().clone(); 107 let start_offset = &variant.parent_enum().syntax().clone();
105 ted::insert_raw(ted::Position::before(start_offset), def.syntax()); 108 ted::insert_raw(ted::Position::before(start_offset), def.syntax());
106 ted::insert_raw(ted::Position::before(start_offset), &make::tokens::blank_line()); 109 ted::insert_raw(ted::Position::before(start_offset), &make::tokens::blank_line());
107 110
108 update_variant(&variant); 111 update_variant(&variant, enum_ast.generic_param_list());
109 }, 112 },
110 ) 113 )
111} 114}
@@ -149,7 +152,7 @@ fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &Va
149fn create_struct_def( 152fn create_struct_def(
150 variant_name: ast::Name, 153 variant_name: ast::Name,
151 field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>, 154 field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>,
152 visibility: Option<ast::Visibility>, 155 enum_: &ast::Enum,
153) -> ast::Struct { 156) -> ast::Struct {
154 let pub_vis = make::visibility_pub(); 157 let pub_vis = make::visibility_pub();
155 158
@@ -184,12 +187,38 @@ fn create_struct_def(
184 } 187 }
185 }; 188 };
186 189
187 make::struct_(visibility, variant_name, None, field_list).clone_for_update() 190 // FIXME: This uses all the generic params of the enum, but the variant might not use all of them.
191 let strukt =
192 make::struct_(enum_.visibility(), variant_name, enum_.generic_param_list(), field_list)
193 .clone_for_update();
194
195 // copy attributes
196 ted::insert_all(
197 Position::first_child_of(strukt.syntax()),
198 enum_.attrs().map(|it| it.syntax().clone_for_update().into()).collect(),
199 );
200 strukt
188} 201}
189 202
190fn update_variant(variant: &ast::Variant) -> Option<()> { 203fn update_variant(variant: &ast::Variant, generic: Option<ast::GenericParamList>) -> Option<()> {
191 let name = variant.name()?; 204 let name = variant.name()?;
192 let tuple_field = make::tuple_field(None, make::ty(&name.text())); 205 let ty = match generic {
206 // FIXME: This uses all the generic params of the enum, but the variant might not use all of them.
207 Some(gpl) => {
208 let gpl = gpl.clone_for_update();
209 gpl.generic_params().for_each(|gp| {
210 match gp {
211 ast::GenericParam::LifetimeParam(it) => it.type_bound_list(),
212 ast::GenericParam::TypeParam(it) => it.type_bound_list(),
213 ast::GenericParam::ConstParam(_) => return,
214 }
215 .map(|it| it.remove());
216 });
217 make::ty(&format!("{}<{}>", name.text(), gpl.generic_params().join(", ")))
218 }
219 None => make::ty(&name.text()),
220 };
221 let tuple_field = make::tuple_field(None, ty);
193 let replacement = make::variant( 222 let replacement = make::variant(
194 name, 223 name,
195 Some(ast::FieldList::TupleFieldList(make::tuple_field_list(iter::once(tuple_field)))), 224 Some(ast::FieldList::TupleFieldList(make::tuple_field_list(iter::once(tuple_field)))),
@@ -208,18 +237,17 @@ fn apply_references(
208 if let Some((scope, path)) = import { 237 if let Some((scope, path)) = import {
209 insert_use(&scope, mod_path_to_ast(&path), insert_use_cfg); 238 insert_use(&scope, mod_path_to_ast(&path), insert_use_cfg);
210 } 239 }
211 ted::insert_raw( 240 // deep clone to prevent cycle
212 ted::Position::before(segment.syntax()), 241 let path = make::path_from_segments(iter::once(segment.clone_subtree()), false);
213 make::path_from_text(&format!("{}", segment)).clone_for_update().syntax(), 242 ted::insert_raw(ted::Position::before(segment.syntax()), path.clone_for_update().syntax());
214 );
215 ted::insert_raw(ted::Position::before(segment.syntax()), make::token(T!['('])); 243 ted::insert_raw(ted::Position::before(segment.syntax()), make::token(T!['(']));
216 ted::insert_raw(ted::Position::after(&node), make::token(T![')'])); 244 ted::insert_raw(ted::Position::after(&node), make::token(T![')']));
217} 245}
218 246
219fn process_references( 247fn process_references(
220 ctx: &AssistContext, 248 ctx: &AssistContext,
249 builder: &mut AssistBuilder,
221 visited_modules: &mut FxHashSet<Module>, 250 visited_modules: &mut FxHashSet<Module>,
222 source_file: &SyntaxNode,
223 enum_module_def: &ModuleDef, 251 enum_module_def: &ModuleDef,
224 variant_hir_name: &Name, 252 variant_hir_name: &Name,
225 refs: Vec<FileReference>, 253 refs: Vec<FileReference>,
@@ -228,8 +256,9 @@ fn process_references(
228 // and corresponding nodes up front 256 // and corresponding nodes up front
229 refs.into_iter() 257 refs.into_iter()
230 .flat_map(|reference| { 258 .flat_map(|reference| {
231 let (segment, scope_node, module) = 259 let (segment, scope_node, module) = reference_to_node(&ctx.sema, reference)?;
232 reference_to_node(&ctx.sema, source_file, reference)?; 260 let segment = builder.make_mut(segment);
261 let scope_node = builder.make_syntax_mut(scope_node);
233 if !visited_modules.contains(&module) { 262 if !visited_modules.contains(&module) {
234 let mod_path = module.find_use_path_prefixed( 263 let mod_path = module.find_use_path_prefixed(
235 ctx.sema.db, 264 ctx.sema.db,
@@ -251,23 +280,22 @@ fn process_references(
251 280
252fn reference_to_node( 281fn reference_to_node(
253 sema: &hir::Semantics<RootDatabase>, 282 sema: &hir::Semantics<RootDatabase>,
254 source_file: &SyntaxNode,
255 reference: FileReference, 283 reference: FileReference,
256) -> Option<(ast::PathSegment, SyntaxNode, hir::Module)> { 284) -> Option<(ast::PathSegment, SyntaxNode, hir::Module)> {
257 let offset = reference.range.start(); 285 let segment =
258 if let Some(path_expr) = find_node_at_offset::<ast::PathExpr>(source_file, offset) { 286 reference.name.as_name_ref()?.syntax().parent().and_then(ast::PathSegment::cast)?;
259 // tuple variant 287 let parent = segment.parent_path().syntax().parent()?;
260 Some((path_expr.path()?.segment()?, path_expr.syntax().parent()?)) 288 let expr_or_pat = match_ast! {
261 } else if let Some(record_expr) = find_node_at_offset::<ast::RecordExpr>(source_file, offset) { 289 match parent {
262 // record variant 290 ast::PathExpr(_it) => parent.parent()?,
263 Some((record_expr.path()?.segment()?, record_expr.syntax().clone())) 291 ast::RecordExpr(_it) => parent,
264 } else { 292 ast::TupleStructPat(_it) => parent,
265 None 293 ast::RecordPat(_it) => parent,
266 } 294 _ => return None,
267 .and_then(|(segment, expr)| { 295 }
268 let module = sema.scope(&expr).module()?; 296 };
269 Some((segment, expr, module)) 297 let module = sema.scope(&expr_or_pat).module()?;
270 }) 298 Some((segment, expr_or_pat, module))
271} 299}
272 300
273#[cfg(test)] 301#[cfg(test)]
@@ -278,6 +306,12 @@ mod tests {
278 306
279 use super::*; 307 use super::*;
280 308
309 fn check_not_applicable(ra_fixture: &str) {
310 let fixture =
311 format!("//- /main.rs crate:main deps:core\n{}\n{}", ra_fixture, FamousDefs::FIXTURE);
312 check_assist_not_applicable(extract_struct_from_enum_variant, &fixture)
313 }
314
281 #[test] 315 #[test]
282 fn test_extract_struct_several_fields_tuple() { 316 fn test_extract_struct_several_fields_tuple() {
283 check_assist( 317 check_assist(
@@ -312,6 +346,32 @@ enum A { One(One) }"#,
312 } 346 }
313 347
314 #[test] 348 #[test]
349 fn test_extract_struct_carries_over_generics() {
350 check_assist(
351 extract_struct_from_enum_variant,
352 r"enum En<T> { Var { a: T$0 } }",
353 r#"struct Var<T>{ pub a: T }
354
355enum En<T> { Var(Var<T>) }"#,
356 );
357 }
358
359 #[test]
360 fn test_extract_struct_carries_over_attributes() {
361 check_assist(
362 extract_struct_from_enum_variant,
363 r#"#[derive(Debug)]
364#[derive(Clone)]
365enum Enum { Variant{ field: u32$0 } }"#,
366 r#"#[derive(Debug)]#[derive(Clone)] struct Variant{ pub field: u32 }
367
368#[derive(Debug)]
369#[derive(Clone)]
370enum Enum { Variant(Variant) }"#,
371 );
372 }
373
374 #[test]
315 fn test_extract_struct_keep_comments_and_attrs_one_field_named() { 375 fn test_extract_struct_keep_comments_and_attrs_one_field_named() {
316 check_assist( 376 check_assist(
317 extract_struct_from_enum_variant, 377 extract_struct_from_enum_variant,
@@ -496,7 +556,7 @@ enum E {
496} 556}
497 557
498fn f() { 558fn f() {
499 let e = E::V { i: 9, j: 2 }; 559 let E::V { i, j } = E::V { i: 9, j: 2 };
500} 560}
501"#, 561"#,
502 r#" 562 r#"
@@ -507,7 +567,34 @@ enum E {
507} 567}
508 568
509fn f() { 569fn f() {
510 let e = E::V(V { i: 9, j: 2 }); 570 let E::V(V { i, j }) = E::V(V { i: 9, j: 2 });
571}
572"#,
573 )
574 }
575
576 #[test]
577 fn extract_record_fix_references2() {
578 check_assist(
579 extract_struct_from_enum_variant,
580 r#"
581enum E {
582 $0V(i32, i32)
583}
584
585fn f() {
586 let E::V(i, j) = E::V(9, 2);
587}
588"#,
589 r#"
590struct V(pub i32, pub i32);
591
592enum E {
593 V(V)
594}
595
596fn f() {
597 let E::V(V(i, j)) = E::V(V(9, 2));
511} 598}
512"#, 599"#,
513 ) 600 )
@@ -610,12 +697,6 @@ fn foo() {
610 ); 697 );
611 } 698 }
612 699
613 fn check_not_applicable(ra_fixture: &str) {
614 let fixture =
615 format!("//- /main.rs crate:main deps:core\n{}\n{}", ra_fixture, FamousDefs::FIXTURE);
616 check_assist_not_applicable(extract_struct_from_enum_variant, &fixture)
617 }
618
619 #[test] 700 #[test]
620 fn test_extract_enum_not_applicable_for_element_with_no_fields() { 701 fn test_extract_enum_not_applicable_for_element_with_no_fields() {
621 check_not_applicable("enum A { $0One }"); 702 check_not_applicable("enum A { $0One }");
diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs
index 0cf170626..4c3c9661d 100644
--- a/crates/syntax/src/ast/make.rs
+++ b/crates/syntax/src/ast/make.rs
@@ -580,12 +580,11 @@ pub fn fn_(
580pub fn struct_( 580pub fn struct_(
581 visibility: Option<ast::Visibility>, 581 visibility: Option<ast::Visibility>,
582 strukt_name: ast::Name, 582 strukt_name: ast::Name,
583 type_params: Option<ast::GenericParamList>, 583 generic_param_list: Option<ast::GenericParamList>,
584 field_list: ast::FieldList, 584 field_list: ast::FieldList,
585) -> ast::Struct { 585) -> ast::Struct {
586 let semicolon = if matches!(field_list, ast::FieldList::TupleFieldList(_)) { ";" } else { "" }; 586 let semicolon = if matches!(field_list, ast::FieldList::TupleFieldList(_)) { ";" } else { "" };
587 let type_params = 587 let type_params = generic_param_list.map_or_else(String::new, |it| it.to_string());
588 if let Some(type_params) = type_params { format!("<{}>", type_params) } else { "".into() };
589 let visibility = match visibility { 588 let visibility = match visibility {
590 None => String::new(), 589 None => String::new(),
591 Some(it) => format!("{} ", it), 590 Some(it) => format!("{} ", it),