aboutsummaryrefslogtreecommitdiff
path: root/crates/ide_assists/src/handlers/generate_function.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/ide_assists/src/handlers/generate_function.rs')
-rw-r--r--crates/ide_assists/src/handlers/generate_function.rs129
1 files changed, 119 insertions, 10 deletions
diff --git a/crates/ide_assists/src/handlers/generate_function.rs b/crates/ide_assists/src/handlers/generate_function.rs
index 959824981..6f95b1a07 100644
--- a/crates/ide_assists/src/handlers/generate_function.rs
+++ b/crates/ide_assists/src/handlers/generate_function.rs
@@ -1,6 +1,7 @@
1use hir::HirDisplay; 1use hir::HirDisplay;
2use ide_db::{base_db::FileId, helpers::SnippetCap}; 2use ide_db::{base_db::FileId, helpers::SnippetCap};
3use rustc_hash::{FxHashMap, FxHashSet}; 3use rustc_hash::{FxHashMap, FxHashSet};
4use stdx::to_lower_snake_case;
4use syntax::{ 5use syntax::{
5 ast::{ 6 ast::{
6 self, 7 self,
@@ -82,17 +83,18 @@ struct FunctionTemplate {
82 leading_ws: String, 83 leading_ws: String,
83 fn_def: ast::Fn, 84 fn_def: ast::Fn,
84 ret_type: ast::RetType, 85 ret_type: ast::RetType,
86 should_render_snippet: bool,
85 trailing_ws: String, 87 trailing_ws: String,
86 file: FileId, 88 file: FileId,
87} 89}
88 90
89impl FunctionTemplate { 91impl FunctionTemplate {
90 fn to_string(&self, cap: Option<SnippetCap>) -> String { 92 fn to_string(&self, cap: Option<SnippetCap>) -> String {
91 let f = match cap { 93 let f = match (cap, self.should_render_snippet) {
92 Some(cap) => { 94 (Some(cap), true) => {
93 render_snippet(cap, self.fn_def.syntax(), Cursor::Replace(self.ret_type.syntax())) 95 render_snippet(cap, self.fn_def.syntax(), Cursor::Replace(self.ret_type.syntax()))
94 } 96 }
95 None => self.fn_def.to_string(), 97 _ => self.fn_def.to_string(),
96 }; 98 };
97 format!("{}{}{}", self.leading_ws, f, self.trailing_ws) 99 format!("{}{}{}", self.leading_ws, f, self.trailing_ws)
98 } 100 }
@@ -103,6 +105,8 @@ struct FunctionBuilder {
103 fn_name: ast::Name, 105 fn_name: ast::Name,
104 type_params: Option<ast::GenericParamList>, 106 type_params: Option<ast::GenericParamList>,
105 params: ast::ParamList, 107 params: ast::ParamList,
108 ret_type: ast::RetType,
109 should_render_snippet: bool,
106 file: FileId, 110 file: FileId,
107 needs_pub: bool, 111 needs_pub: bool,
108} 112}
@@ -131,7 +135,43 @@ impl FunctionBuilder {
131 let fn_name = fn_name(&path)?; 135 let fn_name = fn_name(&path)?;
132 let (type_params, params) = fn_args(ctx, target_module, &call)?; 136 let (type_params, params) = fn_args(ctx, target_module, &call)?;
133 137
134 Some(Self { target, fn_name, type_params, params, file, needs_pub }) 138 // should_render_snippet intends to express a rough level of confidence about
139 // the correctness of the return type.
140 //
141 // If we are able to infer some return type, and that return type is not unit, we
142 // don't want to render the snippet. The assumption here is in this situation the
143 // return type is just as likely to be correct as any other part of the generated
144 // function.
145 //
146 // In the case where the return type is inferred as unit it is likely that the
147 // user does in fact intend for this generated function to return some non unit
148 // type, but that the current state of their code doesn't allow that return type
149 // to be accurately inferred.
150 let (ret_ty, should_render_snippet) = {
151 match ctx.sema.type_of_expr(&ast::Expr::CallExpr(call.clone())) {
152 Some(ty) if ty.is_unknown() || ty.is_unit() => (make::ty_unit(), true),
153 Some(ty) => {
154 let rendered = ty.display_source_code(ctx.db(), target_module.into());
155 match rendered {
156 Ok(rendered) => (make::ty(&rendered), false),
157 Err(_) => (make::ty_unit(), true),
158 }
159 }
160 None => (make::ty_unit(), true),
161 }
162 };
163 let ret_type = make::ret_type(ret_ty);
164
165 Some(Self {
166 target,
167 fn_name,
168 type_params,
169 params,
170 ret_type,
171 should_render_snippet,
172 file,
173 needs_pub,
174 })
135 } 175 }
136 176
137 fn render(self) -> FunctionTemplate { 177 fn render(self) -> FunctionTemplate {
@@ -144,7 +184,7 @@ impl FunctionBuilder {
144 self.type_params, 184 self.type_params,
145 self.params, 185 self.params,
146 fn_body, 186 fn_body,
147 Some(make::ret_type(make::ty_unit())), 187 Some(self.ret_type),
148 ); 188 );
149 let leading_ws; 189 let leading_ws;
150 let trailing_ws; 190 let trailing_ws;
@@ -170,6 +210,7 @@ impl FunctionBuilder {
170 insert_offset, 210 insert_offset,
171 leading_ws, 211 leading_ws,
172 ret_type: fn_def.ret_type().unwrap(), 212 ret_type: fn_def.ret_type().unwrap(),
213 should_render_snippet: self.should_render_snippet,
173 fn_def, 214 fn_def,
174 trailing_ws, 215 trailing_ws,
175 file: self.file, 216 file: self.file,
@@ -257,14 +298,15 @@ fn deduplicate_arg_names(arg_names: &mut Vec<String>) {
257fn fn_arg_name(fn_arg: &ast::Expr) -> Option<String> { 298fn fn_arg_name(fn_arg: &ast::Expr) -> Option<String> {
258 match fn_arg { 299 match fn_arg {
259 ast::Expr::CastExpr(cast_expr) => fn_arg_name(&cast_expr.expr()?), 300 ast::Expr::CastExpr(cast_expr) => fn_arg_name(&cast_expr.expr()?),
260 _ => Some( 301 _ => {
261 fn_arg 302 let s = fn_arg
262 .syntax() 303 .syntax()
263 .descendants() 304 .descendants()
264 .filter(|d| ast::NameRef::can_cast(d.kind())) 305 .filter(|d| ast::NameRef::can_cast(d.kind()))
265 .last()? 306 .last()?
266 .to_string(), 307 .to_string();
267 ), 308 Some(to_lower_snake_case(&s))
309 }
268 } 310 }
269} 311}
270 312
@@ -448,6 +490,52 @@ mod baz {
448 } 490 }
449 491
450 #[test] 492 #[test]
493 fn add_function_with_upper_camel_case_arg() {
494 check_assist(
495 generate_function,
496 r"
497struct BazBaz;
498fn foo() {
499 bar$0(BazBaz);
500}
501",
502 r"
503struct BazBaz;
504fn foo() {
505 bar(BazBaz);
506}
507
508fn bar(baz_baz: BazBaz) ${0:-> ()} {
509 todo!()
510}
511",
512 );
513 }
514
515 #[test]
516 fn add_function_with_upper_camel_case_arg_as_cast() {
517 check_assist(
518 generate_function,
519 r"
520struct BazBaz;
521fn foo() {
522 bar$0(&BazBaz as *const BazBaz);
523}
524",
525 r"
526struct BazBaz;
527fn foo() {
528 bar(&BazBaz as *const BazBaz);
529}
530
531fn bar(baz_baz: *const BazBaz) ${0:-> ()} {
532 todo!()
533}
534",
535 );
536 }
537
538 #[test]
451 fn add_function_with_function_call_arg() { 539 fn add_function_with_function_call_arg() {
452 check_assist( 540 check_assist(
453 generate_function, 541 generate_function,
@@ -498,7 +586,7 @@ impl Baz {
498 } 586 }
499} 587}
500 588
501fn bar(baz: Baz) ${0:-> ()} { 589fn bar(baz: Baz) -> Baz {
502 todo!() 590 todo!()
503} 591}
504", 592",
@@ -1012,6 +1100,27 @@ pub(crate) fn bar() ${0:-> ()} {
1012 } 1100 }
1013 1101
1014 #[test] 1102 #[test]
1103 fn add_function_with_return_type() {
1104 check_assist(
1105 generate_function,
1106 r"
1107fn main() {
1108 let x: u32 = foo$0();
1109}
1110",
1111 r"
1112fn main() {
1113 let x: u32 = foo();
1114}
1115
1116fn foo() -> u32 {
1117 todo!()
1118}
1119",
1120 )
1121 }
1122
1123 #[test]
1015 fn add_function_not_applicable_if_function_already_exists() { 1124 fn add_function_not_applicable_if_function_already_exists() {
1016 check_assist_not_applicable( 1125 check_assist_not_applicable(
1017 generate_function, 1126 generate_function,