aboutsummaryrefslogtreecommitdiff
path: root/crates/ra_ide
diff options
context:
space:
mode:
Diffstat (limited to 'crates/ra_ide')
-rw-r--r--crates/ra_ide/src/lib.rs12
-rw-r--r--crates/ra_ide/src/ssr.rs324
2 files changed, 336 insertions, 0 deletions
diff --git a/crates/ra_ide/src/lib.rs b/crates/ra_ide/src/lib.rs
index 9d66c365b..f86f98be7 100644
--- a/crates/ra_ide/src/lib.rs
+++ b/crates/ra_ide/src/lib.rs
@@ -37,6 +37,7 @@ mod display;
37mod inlay_hints; 37mod inlay_hints;
38mod expand; 38mod expand;
39mod expand_macro; 39mod expand_macro;
40mod ssr;
40 41
41#[cfg(test)] 42#[cfg(test)]
42mod marks; 43mod marks;
@@ -73,6 +74,7 @@ pub use crate::{
73 }, 74 },
74 runnables::{Runnable, RunnableKind, TestId}, 75 runnables::{Runnable, RunnableKind, TestId},
75 source_change::{FileSystemEdit, SourceChange, SourceFileEdit}, 76 source_change::{FileSystemEdit, SourceChange, SourceFileEdit},
77 ssr::SsrError,
76 syntax_highlighting::HighlightedRange, 78 syntax_highlighting::HighlightedRange,
77}; 79};
78 80
@@ -464,6 +466,16 @@ impl Analysis {
464 self.with_db(|db| references::rename(db, position, new_name)) 466 self.with_db(|db| references::rename(db, position, new_name))
465 } 467 }
466 468
469 pub fn structural_search_replace(
470 &self,
471 query: &str,
472 ) -> Cancelable<Result<SourceChange, SsrError>> {
473 self.with_db(|db| {
474 let edits = ssr::parse_search_replace(query, db)?;
475 Ok(SourceChange::source_file_edits("ssr", edits))
476 })
477 }
478
467 /// Performs an operation on that may be Canceled. 479 /// Performs an operation on that may be Canceled.
468 fn with_db<F: FnOnce(&RootDatabase) -> T + std::panic::UnwindSafe, T>( 480 fn with_db<F: FnOnce(&RootDatabase) -> T + std::panic::UnwindSafe, T>(
469 &self, 481 &self,
diff --git a/crates/ra_ide/src/ssr.rs b/crates/ra_ide/src/ssr.rs
new file mode 100644
index 000000000..14eb0b8b2
--- /dev/null
+++ b/crates/ra_ide/src/ssr.rs
@@ -0,0 +1,324 @@
1//! structural search replace
2
3use crate::source_change::SourceFileEdit;
4use ra_ide_db::RootDatabase;
5use ra_syntax::ast::make::expr_from_text;
6use ra_syntax::AstNode;
7use ra_syntax::SyntaxElement;
8use ra_syntax::SyntaxNode;
9use ra_text_edit::{TextEdit, TextEditBuilder};
10use rustc_hash::FxHashMap;
11use std::collections::HashMap;
12use std::str::FromStr;
13
14pub use ra_db::{SourceDatabase, SourceDatabaseExt};
15use ra_ide_db::symbol_index::SymbolsDatabase;
16
17#[derive(Debug, PartialEq)]
18pub struct SsrError(String);
19
20impl std::fmt::Display for SsrError {
21 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
22 write!(f, "Parse error: {}", self.0)
23 }
24}
25
26impl std::error::Error for SsrError {}
27
28pub fn parse_search_replace(
29 query: &str,
30 db: &RootDatabase,
31) -> Result<Vec<SourceFileEdit>, SsrError> {
32 let mut edits = vec![];
33 let query: SsrQuery = query.parse()?;
34 for &root in db.local_roots().iter() {
35 let sr = db.source_root(root);
36 for file_id in sr.walk() {
37 dbg!(db.file_relative_path(file_id));
38 let matches = find(&query.pattern, db.parse(file_id).tree().syntax());
39 if !matches.matches.is_empty() {
40 edits.push(SourceFileEdit { file_id, edit: replace(&matches, &query.template) });
41 }
42 }
43 }
44 Ok(edits)
45}
46
47#[derive(Debug)]
48struct SsrQuery {
49 pattern: SsrPattern,
50 template: SsrTemplate,
51}
52
53#[derive(Debug)]
54struct SsrPattern {
55 pattern: SyntaxNode,
56 vars: Vec<Var>,
57}
58
59/// represents an `$var` in an SSR query
60#[derive(Debug, Clone, PartialEq, Eq, Hash)]
61struct Var(String);
62
63#[derive(Debug)]
64struct SsrTemplate {
65 template: SyntaxNode,
66 placeholders: FxHashMap<SyntaxNode, Var>,
67}
68
69type Binding = HashMap<Var, SyntaxNode>;
70
71#[derive(Debug)]
72struct Match {
73 place: SyntaxNode,
74 binding: Binding,
75}
76
77#[derive(Debug)]
78struct SsrMatches {
79 matches: Vec<Match>,
80}
81
82impl FromStr for SsrQuery {
83 type Err = SsrError;
84
85 fn from_str(query: &str) -> Result<SsrQuery, SsrError> {
86 let mut it = query.split("==>>");
87 let pattern = it.next().expect("at least empty string").trim();
88 let mut template =
89 it.next().ok_or(SsrError("Cannot find delemiter `==>>`".into()))?.trim().to_string();
90 if it.next().is_some() {
91 return Err(SsrError("More than one delimiter found".into()));
92 }
93 let mut vars = vec![];
94 let mut it = pattern.split('$');
95 let mut pattern = it.next().expect("something").to_string();
96
97 for part in it.map(split_by_var) {
98 let (var, var_type, remainder) = part?;
99 is_expr(var_type)?;
100 let new_var = create_name(var, &mut vars)?;
101 pattern.push_str(new_var);
102 pattern.push_str(remainder);
103 template = replace_in_template(template, var, new_var);
104 }
105
106 let template = expr_from_text(&template).syntax().clone();
107 let mut placeholders = FxHashMap::default();
108
109 traverse(&template, &mut |n| {
110 if let Some(v) = vars.iter().find(|v| v.0.as_str() == n.text()) {
111 placeholders.insert(n.clone(), v.clone());
112 false
113 } else {
114 true
115 }
116 });
117
118 let pattern = SsrPattern { pattern: expr_from_text(&pattern).syntax().clone(), vars };
119 let template = SsrTemplate { template, placeholders };
120 Ok(SsrQuery { pattern, template })
121 }
122}
123
124fn traverse(node: &SyntaxNode, go: &mut impl FnMut(&SyntaxNode) -> bool) {
125 if !go(node) {
126 return;
127 }
128 for ref child in node.children() {
129 traverse(child, go);
130 }
131}
132
133fn split_by_var(s: &str) -> Result<(&str, &str, &str), SsrError> {
134 let end_of_name = s.find(":").ok_or(SsrError("Use $<name>:expr".into()))?;
135 let name = &s[0..end_of_name];
136 is_name(name)?;
137 let type_begin = end_of_name + 1;
138 let type_length = s[type_begin..].find(|c| !char::is_ascii_alphanumeric(&c)).unwrap_or(s.len());
139 let type_name = &s[type_begin..type_begin + type_length];
140 Ok((name, type_name, &s[type_begin + type_length..]))
141}
142
143fn is_name(s: &str) -> Result<(), SsrError> {
144 if s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
145 Ok(())
146 } else {
147 Err(SsrError("Name can contain only alphanumerics and _".into()))
148 }
149}
150
151fn is_expr(s: &str) -> Result<(), SsrError> {
152 if s == "expr" {
153 Ok(())
154 } else {
155 Err(SsrError("Only $<name>:expr is supported".into()))
156 }
157}
158
159fn replace_in_template(template: String, var: &str, new_var: &str) -> String {
160 let name = format!("${}", var);
161 template.replace(&name, new_var)
162}
163
164fn create_name<'a>(name: &str, vars: &'a mut Vec<Var>) -> Result<&'a str, SsrError> {
165 let sanitized_name = format!("__search_pattern_{}", name);
166 if vars.iter().any(|a| a.0 == sanitized_name) {
167 return Err(SsrError(format!("Name `{}` repeats more than once", name)));
168 }
169 vars.push(Var(sanitized_name));
170 Ok(&vars.last().unwrap().0)
171}
172
173fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
174 fn check(
175 pattern: &SyntaxElement,
176 code: &SyntaxElement,
177 placeholders: &[Var],
178 match_: &mut Match,
179 ) -> bool {
180 match (pattern, code) {
181 (SyntaxElement::Token(ref pattern), SyntaxElement::Token(ref code)) => {
182 pattern.text() == code.text()
183 }
184 (SyntaxElement::Node(ref pattern), SyntaxElement::Node(ref code)) => {
185 if placeholders.iter().find(|&n| n.0.as_str() == pattern.text()).is_some() {
186 match_.binding.insert(Var(pattern.text().to_string()), code.clone());
187 true
188 } else {
189 pattern.green().children().count() == code.green().children().count()
190 && pattern
191 .children_with_tokens()
192 .zip(code.children_with_tokens())
193 .all(|(a, b)| check(&a, &b, placeholders, match_))
194 }
195 }
196 _ => false,
197 }
198 }
199 let kind = pattern.pattern.kind();
200 let matches = code
201 .descendants_with_tokens()
202 .filter(|n| n.kind() == kind)
203 .filter_map(|code| {
204 let mut match_ =
205 Match { place: code.as_node().unwrap().clone(), binding: HashMap::new() };
206 if check(
207 &SyntaxElement::from(pattern.pattern.clone()),
208 &code,
209 &pattern.vars,
210 &mut match_,
211 ) {
212 Some(match_)
213 } else {
214 None
215 }
216 })
217 .collect();
218 SsrMatches { matches }
219}
220
221fn replace(matches: &SsrMatches, template: &SsrTemplate) -> TextEdit {
222 let mut builder = TextEditBuilder::default();
223 for match_ in &matches.matches {
224 builder.replace(match_.place.text_range(), render_replace(&match_.binding, template));
225 }
226 builder.finish()
227}
228
229fn render_replace(binding: &Binding, template: &SsrTemplate) -> String {
230 let mut builder = TextEditBuilder::default();
231 for element in template.template.descendants() {
232 if let Some(var) = template.placeholders.get(&element) {
233 builder.replace(element.text_range(), binding[var].to_string())
234 }
235 }
236 builder.finish().apply(&template.template.text().to_string())
237}
238
239#[cfg(test)]
240mod tests {
241 use super::*;
242 use ra_syntax::SourceFile;
243
244 fn parse_error_text(query: &str) -> String {
245 format!("{}", query.parse::<SsrQuery>().unwrap_err())
246 }
247
248 #[test]
249 fn parser_happy_case() {
250 let result: SsrQuery = "foo($a:expr, $b:expr) ==>> bar($b, $a)".parse().unwrap();
251 assert_eq!(&result.pattern.pattern.text(), "foo(__search_pattern_a, __search_pattern_b)");
252 assert_eq!(result.pattern.vars.len(), 2);
253 assert_eq!(result.pattern.vars[0].0, "__search_pattern_a");
254 assert_eq!(result.pattern.vars[1].0, "__search_pattern_b");
255 assert_eq!(&result.template.template.text(), "bar(__search_pattern_b, __search_pattern_a)");
256 dbg!(result.template.placeholders);
257 }
258
259 #[test]
260 fn parser_empty_query() {
261 assert_eq!(parse_error_text(""), "Parse error: Cannot find delemiter `==>>`");
262 }
263
264 #[test]
265 fn parser_no_delimiter() {
266 assert_eq!(parse_error_text("foo()"), "Parse error: Cannot find delemiter `==>>`");
267 }
268
269 #[test]
270 fn parser_two_delimiters() {
271 assert_eq!(
272 parse_error_text("foo() ==>> a ==>> b "),
273 "Parse error: More than one delimiter found"
274 );
275 }
276
277 #[test]
278 fn parser_no_pattern_type() {
279 assert_eq!(parse_error_text("foo($a) ==>>"), "Parse error: Use $<name>:expr");
280 }
281
282 #[test]
283 fn parser_invalid_name() {
284 assert_eq!(
285 parse_error_text("foo($a+:expr) ==>>"),
286 "Parse error: Name can contain only alphanumerics and _"
287 );
288 }
289
290 #[test]
291 fn parser_invalid_type() {
292 assert_eq!(
293 parse_error_text("foo($a:ident) ==>>"),
294 "Parse error: Only $<name>:expr is supported"
295 );
296 }
297
298 #[test]
299 fn parser_repeated_name() {
300 assert_eq!(
301 parse_error_text("foo($a:expr, $a:expr) ==>>"),
302 "Parse error: Name `a` repeats more than once"
303 );
304 }
305
306 #[test]
307 fn parse_match_replace() {
308 let query: SsrQuery = "foo($x:expr) ==>> bar($x)".parse().unwrap();
309 let input = "fn main() { foo(1+2); }";
310
311 let code = SourceFile::parse(input).tree();
312 let matches = find(&query.pattern, code.syntax());
313 assert_eq!(matches.matches.len(), 1);
314 assert_eq!(matches.matches[0].place.text(), "foo(1+2)");
315 assert_eq!(matches.matches[0].binding.len(), 1);
316 assert_eq!(
317 matches.matches[0].binding[&Var("__search_pattern_x".to_string())].text(),
318 "1+2"
319 );
320
321 let edit = replace(&matches, &query.template);
322 assert_eq!(edit.apply(input), "fn main() { bar(1+2); }");
323 }
324}