From 8ce15b02dea7152953775904fd937cced2422bc6 Mon Sep 17 00:00:00 2001 From: Edwin Cheng Date: Fri, 26 Mar 2021 03:52:35 +0800 Subject: Fix recursive macro statement expansion --- crates/hir_def/src/body/lower.rs | 68 +++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 35 deletions(-) (limited to 'crates/hir_def/src/body') diff --git a/crates/hir_def/src/body/lower.rs b/crates/hir_def/src/body/lower.rs index 19f5065d1..229e81dd4 100644 --- a/crates/hir_def/src/body/lower.rs +++ b/crates/hir_def/src/body/lower.rs @@ -74,6 +74,7 @@ pub(super) fn lower( _c: Count::new(), }, expander, + statements_in_scope: Vec::new(), } .collect(params, body) } @@ -83,6 +84,7 @@ struct ExprCollector<'a> { expander: Expander, body: Body, source_map: BodySourceMap, + statements_in_scope: Vec, } impl ExprCollector<'_> { @@ -533,15 +535,13 @@ impl ExprCollector<'_> { ids[0] } ast::Expr::MacroStmts(e) => { - // FIXME: these statements should be held by some hir containter - for stmt in e.statements() { - self.collect_stmt(stmt); - } - if let Some(expr) = e.expr() { - self.collect_expr(expr) - } else { - self.alloc_expr(Expr::Missing, syntax_ptr) - } + e.statements().for_each(|s| self.collect_stmt(s)); + let tail = e + .expr() + .map(|e| self.collect_expr(e)) + .unwrap_or_else(|| self.alloc_expr(Expr::Missing, syntax_ptr.clone())); + + self.alloc_expr(Expr::MacroStmts { tail }, syntax_ptr) } }) } @@ -618,58 +618,54 @@ impl ExprCollector<'_> { } } - fn collect_stmt(&mut self, s: ast::Stmt) -> Option> { - let stmt = match s { + fn collect_stmt(&mut self, s: ast::Stmt) { + match s { ast::Stmt::LetStmt(stmt) => { - self.check_cfg(&stmt)?; - + if self.check_cfg(&stmt).is_none() { + return; + } let pat = self.collect_pat_opt(stmt.pat()); let type_ref = stmt.ty().map(|it| TypeRef::from_ast(&self.ctx(), it)); let initializer = stmt.initializer().map(|e| self.collect_expr(e)); - vec![Statement::Let { pat, type_ref, initializer }] + self.statements_in_scope.push(Statement::Let { pat, type_ref, initializer }); } ast::Stmt::ExprStmt(stmt) => { - self.check_cfg(&stmt)?; + if self.check_cfg(&stmt).is_none() { + return; + } // Note that macro could be expended to multiple statements if let Some(ast::Expr::MacroCall(m)) = stmt.expr() { let syntax_ptr = AstPtr::new(&stmt.expr().unwrap()); - let mut stmts = vec![]; self.collect_macro_call(m, syntax_ptr.clone(), false, |this, expansion| { match expansion { Some(expansion) => { let statements: ast::MacroStmts = expansion; - statements.statements().for_each(|stmt| { - if let Some(mut r) = this.collect_stmt(stmt) { - stmts.append(&mut r); - } - }); + statements.statements().for_each(|stmt| this.collect_stmt(stmt)); if let Some(expr) = statements.expr() { - stmts.push(Statement::Expr(this.collect_expr(expr))); + let expr = this.collect_expr(expr); + this.statements_in_scope.push(Statement::Expr(expr)); } } None => { - stmts.push(Statement::Expr( - this.alloc_expr(Expr::Missing, syntax_ptr.clone()), - )); + let expr = this.alloc_expr(Expr::Missing, syntax_ptr.clone()); + this.statements_in_scope.push(Statement::Expr(expr)); } } }); - stmts } else { - vec![Statement::Expr(self.collect_expr_opt(stmt.expr()))] + let expr = self.collect_expr_opt(stmt.expr()); + self.statements_in_scope.push(Statement::Expr(expr)); } } ast::Stmt::Item(item) => { - self.check_cfg(&item)?; - - return None; + if self.check_cfg(&item).is_none() { + return; + } } - }; - - Some(stmt) + } } fn collect_block(&mut self, block: ast::BlockExpr) -> ExprId { @@ -685,10 +681,12 @@ impl ExprCollector<'_> { let module = if has_def_map { def_map.root() } else { self.expander.module }; let prev_def_map = mem::replace(&mut self.expander.def_map, def_map); let prev_local_module = mem::replace(&mut self.expander.module, module); + let prev_statements = std::mem::take(&mut self.statements_in_scope); + + block.statements().for_each(|s| self.collect_stmt(s)); - let statements = - block.statements().filter_map(|s| self.collect_stmt(s)).flatten().collect(); let tail = block.tail_expr().map(|e| self.collect_expr(e)); + let statements = std::mem::replace(&mut self.statements_in_scope, prev_statements); let syntax_node_ptr = AstPtr::new(&block.into()); let expr_id = self.alloc_expr( Expr::Block { id: block_id, statements, tail, label: None }, -- cgit v1.2.3