From 3290bb4112e7988f98b108e3c590d39c881f00e0 Mon Sep 17 00:00:00 2001 From: Aleksey Kladov Date: Fri, 2 Oct 2020 20:52:48 +0200 Subject: Simplify ast_transform --- crates/assists/src/ast_transform.rs | 74 +++++++++++++++++-------------------- 1 file changed, 34 insertions(+), 40 deletions(-) diff --git a/crates/assists/src/ast_transform.rs b/crates/assists/src/ast_transform.rs index 835da3bb2..4307e0191 100644 --- a/crates/assists/src/ast_transform.rs +++ b/crates/assists/src/ast_transform.rs @@ -5,12 +5,13 @@ use hir::{HirDisplay, PathResolution, SemanticsScope}; use syntax::{ algo::SyntaxRewriter, ast::{self, AstNode}, + SyntaxNode, }; pub fn apply<'a, N: AstNode>(transformer: &dyn AstTransform<'a>, node: N) -> N { SyntaxRewriter::from_fn(|element| match element { syntax::SyntaxElement::Node(n) => { - let replacement = transformer.get_substitution(&n)?; + let replacement = transformer.get_substitution(&n, transformer)?; Some(replacement.into()) } _ => None, @@ -47,32 +48,35 @@ pub fn apply<'a, N: AstNode>(transformer: &dyn AstTransform<'a>, node: N) -> N { /// We'd want to somehow express this concept simpler, but so far nobody got to /// simplifying this! pub trait AstTransform<'a> { - fn get_substitution(&self, node: &syntax::SyntaxNode) -> Option; + fn get_substitution( + &self, + node: &SyntaxNode, + recur: &dyn AstTransform<'a>, + ) -> Option; - fn chain_before(self, other: Box + 'a>) -> Box + 'a>; fn or + 'a>(self, other: T) -> Box + 'a> where Self: Sized + 'a, { - self.chain_before(Box::new(other)) + Box::new(Or(Box::new(self), Box::new(other))) } } -struct NullTransformer; +struct Or<'a>(Box + 'a>, Box + 'a>); -impl<'a> AstTransform<'a> for NullTransformer { - fn get_substitution(&self, _node: &syntax::SyntaxNode) -> Option { - None - } - fn chain_before(self, other: Box + 'a>) -> Box + 'a> { - other +impl<'a> AstTransform<'a> for Or<'a> { + fn get_substitution( + &self, + node: &SyntaxNode, + recur: &dyn AstTransform<'a>, + ) -> Option { + self.0.get_substitution(node, recur).or_else(|| self.1.get_substitution(node, recur)) } } pub struct SubstituteTypeParams<'a> { source_scope: &'a SemanticsScope<'a>, substs: FxHashMap, - previous: Box + 'a>, } impl<'a> SubstituteTypeParams<'a> { @@ -111,11 +115,7 @@ impl<'a> SubstituteTypeParams<'a> { } }) .collect(); - return SubstituteTypeParams { - source_scope, - substs: substs_by_param, - previous: Box::new(NullTransformer), - }; + return SubstituteTypeParams { source_scope, substs: substs_by_param }; // FIXME: It would probably be nicer if we could get this via HIR (i.e. get the // trait ref, and then go from the types in the substs back to the syntax). @@ -140,7 +140,14 @@ impl<'a> SubstituteTypeParams<'a> { Some(result) } } - fn get_substitution_inner(&self, node: &syntax::SyntaxNode) -> Option { +} + +impl<'a> AstTransform<'a> for SubstituteTypeParams<'a> { + fn get_substitution( + &self, + node: &SyntaxNode, + _recur: &dyn AstTransform<'a>, + ) -> Option { let type_ref = ast::Type::cast(node.clone())?; let path = match &type_ref { ast::Type::PathType(path_type) => path_type.path()?, @@ -154,27 +161,23 @@ impl<'a> SubstituteTypeParams<'a> { } } -impl<'a> AstTransform<'a> for SubstituteTypeParams<'a> { - fn get_substitution(&self, node: &syntax::SyntaxNode) -> Option { - self.get_substitution_inner(node).or_else(|| self.previous.get_substitution(node)) - } - fn chain_before(self, other: Box + 'a>) -> Box + 'a> { - Box::new(SubstituteTypeParams { previous: other, ..self }) - } -} - pub struct QualifyPaths<'a> { target_scope: &'a SemanticsScope<'a>, source_scope: &'a SemanticsScope<'a>, - previous: Box + 'a>, } impl<'a> QualifyPaths<'a> { pub fn new(target_scope: &'a SemanticsScope<'a>, source_scope: &'a SemanticsScope<'a>) -> Self { - Self { target_scope, source_scope, previous: Box::new(NullTransformer) } + Self { target_scope, source_scope } } +} - fn get_substitution_inner(&self, node: &syntax::SyntaxNode) -> Option { +impl<'a> AstTransform<'a> for QualifyPaths<'a> { + fn get_substitution( + &self, + node: &SyntaxNode, + recur: &dyn AstTransform<'a>, + ) -> Option { // FIXME handle value ns? let from = self.target_scope.module()?; let p = ast::Path::cast(node.clone())?; @@ -191,7 +194,7 @@ impl<'a> QualifyPaths<'a> { let type_args = p .segment() .and_then(|s| s.generic_arg_list()) - .map(|arg_list| apply(self, arg_list)); + .map(|arg_list| apply(recur, arg_list)); if let Some(type_args) = type_args { let last_segment = path.segment().unwrap(); path = path.with_segment(last_segment.with_generic_args(type_args)) @@ -208,15 +211,6 @@ impl<'a> QualifyPaths<'a> { } } -impl<'a> AstTransform<'a> for QualifyPaths<'a> { - fn get_substitution(&self, node: &syntax::SyntaxNode) -> Option { - self.get_substitution_inner(node).or_else(|| self.previous.get_substitution(node)) - } - fn chain_before(self, other: Box + 'a>) -> Box + 'a> { - Box::new(QualifyPaths { previous: other, ..self }) - } -} - pub(crate) fn path_to_ast(path: hir::ModPath) -> ast::Path { let parse = ast::SourceFile::parse(&path.to_string()); parse -- cgit v1.2.3