aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--crates/ra_assists/src/assists/early_return.rs95
-rw-r--r--crates/ra_syntax/src/ast/make.rs52
2 files changed, 102 insertions, 45 deletions
diff --git a/crates/ra_assists/src/assists/early_return.rs b/crates/ra_assists/src/assists/early_return.rs
index 8507a60fb..264412526 100644
--- a/crates/ra_assists/src/assists/early_return.rs
+++ b/crates/ra_assists/src/assists/early_return.rs
@@ -1,4 +1,4 @@
1use std::ops::RangeInclusive; 1use std::{iter::once, ops::RangeInclusive};
2 2
3use hir::db::HirDatabase; 3use hir::db::HirDatabase;
4use ra_syntax::{ 4use ra_syntax::{
@@ -45,19 +45,22 @@ pub(crate) fn convert_to_guarded_return(ctx: AssistCtx<impl HirDatabase>) -> Opt
45 let cond = if_expr.condition()?; 45 let cond = if_expr.condition()?;
46 46
47 // Check if there is an IfLet that we can handle. 47 // Check if there is an IfLet that we can handle.
48 let bound_ident = match cond.pat() { 48 let if_let_pat = match cond.pat() {
49 None => None, // No IfLet, supported. 49 None => None, // No IfLet, supported.
50 Some(TupleStructPat(pat)) if pat.args().count() == 1 => { 50 Some(TupleStructPat(pat)) if pat.args().count() == 1 => {
51 let path = pat.path()?; 51 let path = pat.path()?;
52 match path.qualifier() { 52 match path.qualifier() {
53 None => Some(path.segment()?.name_ref()?), 53 None => {
54 let bound_ident = pat.args().next().unwrap();
55 Some((path, bound_ident))
56 }
54 Some(_) => return None, 57 Some(_) => return None,
55 } 58 }
56 } 59 }
57 Some(_) => return None, // Unsupported IfLet. 60 Some(_) => return None, // Unsupported IfLet.
58 }; 61 };
59 62
60 let expr = cond.expr()?; 63 let cond_expr = cond.expr()?;
61 let then_block = if_expr.then_branch()?.block()?; 64 let then_block = if_expr.then_branch()?.block()?;
62 65
63 let parent_block = if_expr.syntax().parent()?.ancestors().find_map(ast::Block::cast)?; 66 let parent_block = if_expr.syntax().parent()?.ancestors().find_map(ast::Block::cast)?;
@@ -79,11 +82,11 @@ pub(crate) fn convert_to_guarded_return(ctx: AssistCtx<impl HirDatabase>) -> Opt
79 82
80 let parent_container = parent_block.syntax().parent()?.parent()?; 83 let parent_container = parent_block.syntax().parent()?.parent()?;
81 84
82 let early_expression = match parent_container.kind() { 85 let early_expression: ast::Expr = match parent_container.kind() {
83 WHILE_EXPR | LOOP_EXPR => Some("continue"), 86 WHILE_EXPR | LOOP_EXPR => make::expr_continue().into(),
84 FN_DEF => Some("return"), 87 FN_DEF => make::expr_return().into(),
85 _ => None, 88 _ => return None,
86 }?; 89 };
87 90
88 if then_block.syntax().first_child_or_token().map(|t| t.kind() == L_CURLY).is_none() { 91 if then_block.syntax().first_child_or_token().map(|t| t.kind() == L_CURLY).is_none() {
89 return None; 92 return None;
@@ -94,22 +97,43 @@ pub(crate) fn convert_to_guarded_return(ctx: AssistCtx<impl HirDatabase>) -> Opt
94 97
95 ctx.add_assist(AssistId("convert_to_guarded_return"), "convert to guarded return", |edit| { 98 ctx.add_assist(AssistId("convert_to_guarded_return"), "convert to guarded return", |edit| {
96 let if_indent_level = IndentLevel::from_node(&if_expr.syntax()); 99 let if_indent_level = IndentLevel::from_node(&if_expr.syntax());
97 let new_block = match bound_ident { 100 let new_block = match if_let_pat {
98 None => { 101 None => {
99 // If. 102 // If.
100 let early_expression = &(early_expression.to_owned() + ";"); 103 let early_expression = &(early_expression.syntax().to_string() + ";");
101 let new_expr = 104 let new_expr = if_indent_level
102 if_indent_level.increase_indent(make::if_expression(&expr, early_expression)); 105 .increase_indent(make::if_expression(&cond_expr, early_expression));
103 replace(new_expr.syntax(), &then_block, &parent_block, &if_expr) 106 replace(new_expr.syntax(), &then_block, &parent_block, &if_expr)
104 } 107 }
105 Some(bound_ident) => { 108 Some((path, bound_ident)) => {
106 // If-let. 109 // If-let.
107 let new_expr = if_indent_level.increase_indent(make::let_match_early( 110 let match_expr = {
108 expr, 111 let happy_arm = make::match_arm(
109 &bound_ident.syntax().to_string(), 112 once(
110 early_expression, 113 make::tuple_struct_pat(
111 )); 114 path,
112 replace(new_expr.syntax(), &then_block, &parent_block, &if_expr) 115 once(make::bind_pat(make::name("it")).into()),
116 )
117 .into(),
118 ),
119 make::expr_path(make::path_from_name_ref(make::name_ref("it"))).into(),
120 );
121
122 let sad_arm = make::match_arm(
123 // FIXME: would be cool to use `None` or `Err(_)` if appropriate
124 once(make::placeholder_pat().into()),
125 early_expression.into(),
126 );
127
128 make::expr_match(cond_expr, make::match_arm_list(vec![happy_arm, sad_arm]))
129 };
130
131 let let_stmt = make::let_stmt(
132 make::bind_pat(make::name(&bound_ident.syntax().to_string())).into(),
133 Some(match_expr.into()),
134 );
135 let let_stmt = if_indent_level.increase_indent(let_stmt);
136 replace(let_stmt.syntax(), &then_block, &parent_block, &if_expr)
113 } 137 }
114 }; 138 };
115 edit.target(if_expr.syntax().text_range()); 139 edit.target(if_expr.syntax().text_range());
@@ -205,7 +229,7 @@ mod tests {
205 bar(); 229 bar();
206 le<|>t n = match n { 230 le<|>t n = match n {
207 Some(it) => it, 231 Some(it) => it,
208 None => return, 232 _ => return,
209 }; 233 };
210 foo(n); 234 foo(n);
211 235
@@ -217,6 +241,29 @@ mod tests {
217 } 241 }
218 242
219 #[test] 243 #[test]
244 fn convert_if_let_result() {
245 check_assist(
246 convert_to_guarded_return,
247 r#"
248 fn main() {
249 if<|> let Ok(x) = Err(92) {
250 foo(x);
251 }
252 }
253 "#,
254 r#"
255 fn main() {
256 le<|>t x = match Err(92) {
257 Ok(it) => it,
258 _ => return,
259 };
260 foo(x);
261 }
262 "#,
263 );
264 }
265
266 #[test]
220 fn convert_let_ok_inside_fn() { 267 fn convert_let_ok_inside_fn() {
221 check_assist( 268 check_assist(
222 convert_to_guarded_return, 269 convert_to_guarded_return,
@@ -236,7 +283,7 @@ mod tests {
236 bar(); 283 bar();
237 le<|>t n = match n { 284 le<|>t n = match n {
238 Ok(it) => it, 285 Ok(it) => it,
239 None => return, 286 _ => return,
240 }; 287 };
241 foo(n); 288 foo(n);
242 289
@@ -294,7 +341,7 @@ mod tests {
294 while true { 341 while true {
295 le<|>t n = match n { 342 le<|>t n = match n {
296 Some(it) => it, 343 Some(it) => it,
297 None => continue, 344 _ => continue,
298 }; 345 };
299 foo(n); 346 foo(n);
300 bar(); 347 bar();
@@ -351,7 +398,7 @@ mod tests {
351 loop { 398 loop {
352 le<|>t n = match n { 399 le<|>t n = match n {
353 Some(it) => it, 400 Some(it) => it,
354 None => continue, 401 _ => continue,
355 }; 402 };
356 foo(n); 403 foo(n);
357 bar(); 404 bar();
diff --git a/crates/ra_syntax/src/ast/make.rs b/crates/ra_syntax/src/ast/make.rs
index 95062ef6c..6c903ca64 100644
--- a/crates/ra_syntax/src/ast/make.rs
+++ b/crates/ra_syntax/src/ast/make.rs
@@ -4,6 +4,10 @@ use itertools::Itertools;
4 4
5use crate::{ast, AstNode, SourceFile}; 5use crate::{ast, AstNode, SourceFile};
6 6
7pub fn name(text: &str) -> ast::Name {
8 ast_from_text(&format!("mod {};", text))
9}
10
7pub fn name_ref(text: &str) -> ast::NameRef { 11pub fn name_ref(text: &str) -> ast::NameRef {
8 ast_from_text(&format!("fn f() {{ {}; }}", text)) 12 ast_from_text(&format!("fn f() {{ {}; }}", text))
9} 13}
@@ -43,6 +47,21 @@ pub fn expr_unit() -> ast::Expr {
43pub fn expr_unimplemented() -> ast::Expr { 47pub fn expr_unimplemented() -> ast::Expr {
44 expr_from_text("unimplemented!()") 48 expr_from_text("unimplemented!()")
45} 49}
50pub fn expr_path(path: ast::Path) -> ast::Expr {
51 expr_from_text(&path.syntax().to_string())
52}
53pub fn expr_continue() -> ast::Expr {
54 expr_from_text("continue")
55}
56pub fn expr_break() -> ast::Expr {
57 expr_from_text("break")
58}
59pub fn expr_return() -> ast::Expr {
60 expr_from_text("return")
61}
62pub fn expr_match(expr: ast::Expr, match_arm_list: ast::MatchArmList) -> ast::Expr {
63 expr_from_text(&format!("match {} {}", expr.syntax(), match_arm_list.syntax()))
64}
46fn expr_from_text(text: &str) -> ast::Expr { 65fn expr_from_text(text: &str) -> ast::Expr {
47 ast_from_text(&format!("const C: () = {};", text)) 66 ast_from_text(&format!("const C: () = {};", text))
48} 67}
@@ -92,8 +111,8 @@ pub fn path_pat(path: ast::Path) -> ast::PathPat {
92 } 111 }
93} 112}
94 113
95pub fn match_arm(pats: impl Iterator<Item = ast::Pat>, expr: ast::Expr) -> ast::MatchArm { 114pub fn match_arm(pats: impl IntoIterator<Item = ast::Pat>, expr: ast::Expr) -> ast::MatchArm {
96 let pats_str = pats.map(|p| p.syntax().to_string()).join(" | "); 115 let pats_str = pats.into_iter().map(|p| p.syntax().to_string()).join(" | ");
97 return from_text(&format!("{} => {}", pats_str, expr.syntax())); 116 return from_text(&format!("{} => {}", pats_str, expr.syntax()));
98 117
99 fn from_text(text: &str) -> ast::MatchArm { 118 fn from_text(text: &str) -> ast::MatchArm {
@@ -101,8 +120,8 @@ pub fn match_arm(pats: impl Iterator<Item = ast::Pat>, expr: ast::Expr) -> ast::
101 } 120 }
102} 121}
103 122
104pub fn match_arm_list(arms: impl Iterator<Item = ast::MatchArm>) -> ast::MatchArmList { 123pub fn match_arm_list(arms: impl IntoIterator<Item = ast::MatchArm>) -> ast::MatchArmList {
105 let arms_str = arms.map(|arm| format!("\n {}", arm.syntax())).join(","); 124 let arms_str = arms.into_iter().map(|arm| format!("\n {}", arm.syntax())).join(",");
106 return from_text(&format!("{},\n", arms_str)); 125 return from_text(&format!("{},\n", arms_str));
107 126
108 fn from_text(text: &str) -> ast::MatchArmList { 127 fn from_text(text: &str) -> ast::MatchArmList {
@@ -110,23 +129,6 @@ pub fn match_arm_list(arms: impl Iterator<Item = ast::MatchArm>) -> ast::MatchAr
110 } 129 }
111} 130}
112 131
113pub fn let_match_early(expr: ast::Expr, path: &str, early_expression: &str) -> ast::LetStmt {
114 return from_text(&format!(
115 r#"let {} = match {} {{
116 {}(it) => it,
117 None => {},
118}};"#,
119 expr.syntax().text(),
120 expr.syntax().text(),
121 path,
122 early_expression
123 ));
124
125 fn from_text(text: &str) -> ast::LetStmt {
126 ast_from_text(&format!("fn f() {{ {} }}", text))
127 }
128}
129
130pub fn where_pred(path: ast::Path, bounds: impl Iterator<Item = ast::TypeBound>) -> ast::WherePred { 132pub fn where_pred(path: ast::Path, bounds: impl Iterator<Item = ast::TypeBound>) -> ast::WherePred {
131 let bounds = bounds.map(|b| b.syntax().to_string()).join(" + "); 133 let bounds = bounds.map(|b| b.syntax().to_string()).join(" + ");
132 return from_text(&format!("{}: {}", path.syntax(), bounds)); 134 return from_text(&format!("{}: {}", path.syntax(), bounds));
@@ -153,6 +155,14 @@ pub fn if_expression(condition: &ast::Expr, statement: &str) -> ast::IfExpr {
153 )) 155 ))
154} 156}
155 157
158pub fn let_stmt(pattern: ast::Pat, initializer: Option<ast::Expr>) -> ast::LetStmt {
159 let text = match initializer {
160 Some(it) => format!("let {} = {};", pattern.syntax(), it.syntax()),
161 None => format!("let {};", pattern.syntax()),
162 };
163 ast_from_text(&format!("fn f() {{ {} }}", text))
164}
165
156fn ast_from_text<N: AstNode>(text: &str) -> N { 166fn ast_from_text<N: AstNode>(text: &str) -> N {
157 let parse = SourceFile::parse(text); 167 let parse = SourceFile::parse(text);
158 let res = parse.tree().syntax().descendants().find_map(N::cast).unwrap(); 168 let res = parse.tree().syntax().descendants().find_map(N::cast).unwrap();