From 8ce15b02dea7152953775904fd937cced2422bc6 Mon Sep 17 00:00:00 2001 From: Edwin Cheng Date: Fri, 26 Mar 2021 03:52:35 +0800 Subject: Fix recursive macro statement expansion --- crates/hir_def/src/body/lower.rs | 68 +++++++++++++++++------------------ crates/hir_def/src/expr.rs | 4 +++ crates/hir_def/src/item_tree.rs | 9 ----- crates/hir_def/src/item_tree/lower.rs | 8 ----- crates/hir_expand/src/db.rs | 59 +++++++++++++++++++++--------- crates/hir_ty/src/infer/expr.rs | 1 + crates/hir_ty/src/tests/macros.rs | 40 ++++++++++++++++++++- 7 files changed, 119 insertions(+), 70 deletions(-) diff --git a/crates/hir_def/src/body/lower.rs b/crates/hir_def/src/body/lower.rs index 19f5065d1..229e81dd4 100644 --- a/crates/hir_def/src/body/lower.rs +++ b/crates/hir_def/src/body/lower.rs @@ -74,6 +74,7 @@ pub(super) fn lower( _c: Count::new(), }, expander, + statements_in_scope: Vec::new(), } .collect(params, body) } @@ -83,6 +84,7 @@ struct ExprCollector<'a> { expander: Expander, body: Body, source_map: BodySourceMap, + statements_in_scope: Vec, } impl ExprCollector<'_> { @@ -533,15 +535,13 @@ impl ExprCollector<'_> { ids[0] } ast::Expr::MacroStmts(e) => { - // FIXME: these statements should be held by some hir containter - for stmt in e.statements() { - self.collect_stmt(stmt); - } - if let Some(expr) = e.expr() { - self.collect_expr(expr) - } else { - self.alloc_expr(Expr::Missing, syntax_ptr) - } + e.statements().for_each(|s| self.collect_stmt(s)); + let tail = e + .expr() + .map(|e| self.collect_expr(e)) + .unwrap_or_else(|| self.alloc_expr(Expr::Missing, syntax_ptr.clone())); + + self.alloc_expr(Expr::MacroStmts { tail }, syntax_ptr) } }) } @@ -618,58 +618,54 @@ impl ExprCollector<'_> { } } - fn collect_stmt(&mut self, s: ast::Stmt) -> Option> { - let stmt = match s { + fn collect_stmt(&mut self, s: ast::Stmt) { + match s { ast::Stmt::LetStmt(stmt) => { - self.check_cfg(&stmt)?; - + if self.check_cfg(&stmt).is_none() { + return; + } let pat = self.collect_pat_opt(stmt.pat()); let type_ref = stmt.ty().map(|it| TypeRef::from_ast(&self.ctx(), it)); let initializer = stmt.initializer().map(|e| self.collect_expr(e)); - vec![Statement::Let { pat, type_ref, initializer }] + self.statements_in_scope.push(Statement::Let { pat, type_ref, initializer }); } ast::Stmt::ExprStmt(stmt) => { - self.check_cfg(&stmt)?; + if self.check_cfg(&stmt).is_none() { + return; + } // Note that macro could be expended to multiple statements if let Some(ast::Expr::MacroCall(m)) = stmt.expr() { let syntax_ptr = AstPtr::new(&stmt.expr().unwrap()); - let mut stmts = vec![]; self.collect_macro_call(m, syntax_ptr.clone(), false, |this, expansion| { match expansion { Some(expansion) => { let statements: ast::MacroStmts = expansion; - statements.statements().for_each(|stmt| { - if let Some(mut r) = this.collect_stmt(stmt) { - stmts.append(&mut r); - } - }); + statements.statements().for_each(|stmt| this.collect_stmt(stmt)); if let Some(expr) = statements.expr() { - stmts.push(Statement::Expr(this.collect_expr(expr))); + let expr = this.collect_expr(expr); + this.statements_in_scope.push(Statement::Expr(expr)); } } None => { - stmts.push(Statement::Expr( - this.alloc_expr(Expr::Missing, syntax_ptr.clone()), - )); + let expr = this.alloc_expr(Expr::Missing, syntax_ptr.clone()); + this.statements_in_scope.push(Statement::Expr(expr)); } } }); - stmts } else { - vec![Statement::Expr(self.collect_expr_opt(stmt.expr()))] + let expr = self.collect_expr_opt(stmt.expr()); + self.statements_in_scope.push(Statement::Expr(expr)); } } ast::Stmt::Item(item) => { - self.check_cfg(&item)?; - - return None; + if self.check_cfg(&item).is_none() { + return; + } } - }; - - Some(stmt) + } } fn collect_block(&mut self, block: ast::BlockExpr) -> ExprId { @@ -685,10 +681,12 @@ impl ExprCollector<'_> { let module = if has_def_map { def_map.root() } else { self.expander.module }; let prev_def_map = mem::replace(&mut self.expander.def_map, def_map); let prev_local_module = mem::replace(&mut self.expander.module, module); + let prev_statements = std::mem::take(&mut self.statements_in_scope); + + block.statements().for_each(|s| self.collect_stmt(s)); - let statements = - block.statements().filter_map(|s| self.collect_stmt(s)).flatten().collect(); let tail = block.tail_expr().map(|e| self.collect_expr(e)); + let statements = std::mem::replace(&mut self.statements_in_scope, prev_statements); let syntax_node_ptr = AstPtr::new(&block.into()); let expr_id = self.alloc_expr( Expr::Block { id: block_id, statements, tail, label: None }, diff --git a/crates/hir_def/src/expr.rs b/crates/hir_def/src/expr.rs index 24be93773..6c7376fad 100644 --- a/crates/hir_def/src/expr.rs +++ b/crates/hir_def/src/expr.rs @@ -171,6 +171,9 @@ pub enum Expr { Unsafe { body: ExprId, }, + MacroStmts { + tail: ExprId, + }, Array(Array), Literal(Literal), } @@ -357,6 +360,7 @@ impl Expr { f(*repeat) } }, + Expr::MacroStmts { tail } => f(*tail), Expr::Literal(_) => {} } } diff --git a/crates/hir_def/src/item_tree.rs b/crates/hir_def/src/item_tree.rs index ae2475b4e..ca0048b16 100644 --- a/crates/hir_def/src/item_tree.rs +++ b/crates/hir_def/src/item_tree.rs @@ -110,15 +110,6 @@ impl ItemTree { // still need to collect inner items. ctx.lower_inner_items(e.syntax()) }, - ast::ExprStmt(stmt) => { - // Macros can expand to stmt. We return an empty item tree in this case, but - // still need to collect inner items. - ctx.lower_inner_items(stmt.syntax()) - }, - ast::Item(item) => { - // Macros can expand to stmt and other item, and we add it as top level item - ctx.lower_single_item(item) - }, _ => { panic!("cannot create item tree from {:?} {}", syntax, syntax); }, diff --git a/crates/hir_def/src/item_tree/lower.rs b/crates/hir_def/src/item_tree/lower.rs index d3fe1ce1e..3f558edd8 100644 --- a/crates/hir_def/src/item_tree/lower.rs +++ b/crates/hir_def/src/item_tree/lower.rs @@ -87,14 +87,6 @@ impl Ctx { self.tree } - pub(super) fn lower_single_item(mut self, item: ast::Item) -> ItemTree { - self.tree.top_level = self - .lower_mod_item(&item, false) - .map(|item| item.0) - .unwrap_or_else(|| Default::default()); - self.tree - } - pub(super) fn lower_inner_items(mut self, within: &SyntaxNode) -> ItemTree { self.collect_inner_items(within); self.tree diff --git a/crates/hir_expand/src/db.rs b/crates/hir_expand/src/db.rs index fc73e435b..d672f6723 100644 --- a/crates/hir_expand/src/db.rs +++ b/crates/hir_expand/src/db.rs @@ -5,7 +5,13 @@ use std::sync::Arc; use base_db::{salsa, SourceDatabase}; use mbe::{ExpandError, ExpandResult, MacroRules}; use parser::FragmentKind; -use syntax::{algo::diff, ast::NameOwner, AstNode, GreenNode, Parse, SyntaxKind::*, SyntaxNode}; +use syntax::{ + algo::diff, + ast::{MacroStmts, NameOwner}, + AstNode, GreenNode, Parse, + SyntaxKind::*, + SyntaxNode, +}; use crate::{ ast_id_map::AstIdMap, hygiene::HygieneFrame, BuiltinDeriveExpander, BuiltinFnLikeExpander, @@ -340,13 +346,19 @@ fn parse_macro_with_arg( None => return ExpandResult { value: None, err: result.err }, }; - log::debug!("expanded = {}", tt.as_debug_string()); - let fragment_kind = to_fragment_kind(db, macro_call_id); + log::debug!("expanded = {}", tt.as_debug_string()); + log::debug!("kind = {:?}", fragment_kind); + let (parse, rev_token_map) = match mbe::token_tree_to_syntax_node(&tt, fragment_kind) { Ok(it) => it, Err(err) => { + log::debug!( + "failed to parse expanstion to {:?} = {}", + fragment_kind, + tt.as_debug_string() + ); return ExpandResult::only_err(err); } }; @@ -362,15 +374,34 @@ fn parse_macro_with_arg( return ExpandResult::only_err(err); } }; - - if !diff(&node, &call_node.value).is_empty() { - ExpandResult { value: Some((parse, Arc::new(rev_token_map))), err: Some(err) } - } else { + if is_self_replicating(&node, &call_node.value) { return ExpandResult::only_err(err); + } else { + ExpandResult { value: Some((parse, Arc::new(rev_token_map))), err: Some(err) } + } + } + None => { + log::debug!("parse = {:?}", parse.syntax_node().kind()); + ExpandResult { value: Some((parse, Arc::new(rev_token_map))), err: None } + } + } +} + +fn is_self_replicating(from: &SyntaxNode, to: &SyntaxNode) -> bool { + if diff(from, to).is_empty() { + return true; + } + if let Some(stmts) = MacroStmts::cast(from.clone()) { + if stmts.statements().any(|stmt| diff(stmt.syntax(), to).is_empty()) { + return true; + } + if let Some(expr) = stmts.expr() { + if diff(expr.syntax(), to).is_empty() { + return true; } } - None => ExpandResult { value: Some((parse, Arc::new(rev_token_map))), err: None }, } + false } fn hygiene_frame(db: &dyn AstDatabase, file_id: HirFileId) -> Arc { @@ -390,21 +421,15 @@ fn to_fragment_kind(db: &dyn AstDatabase, id: MacroCallId) -> FragmentKind { let parent = match syn.parent() { Some(it) => it, - None => { - // FIXME: - // If it is root, which means the parent HirFile - // MacroKindFile must be non-items - // return expr now. - return FragmentKind::Expr; - } + None => return FragmentKind::Statements, }; match parent.kind() { MACRO_ITEMS | SOURCE_FILE => FragmentKind::Items, - MACRO_STMTS => FragmentKind::Statement, + MACRO_STMTS => FragmentKind::Statements, ITEM_LIST => FragmentKind::Items, LET_STMT => { - // FIXME: Handle Pattern + // FIXME: Handle LHS Pattern FragmentKind::Expr } EXPR_STMT => FragmentKind::Statements, diff --git a/crates/hir_ty/src/infer/expr.rs b/crates/hir_ty/src/infer/expr.rs index 3f3187ea2..e6ede05ca 100644 --- a/crates/hir_ty/src/infer/expr.rs +++ b/crates/hir_ty/src/infer/expr.rs @@ -767,6 +767,7 @@ impl<'a> InferenceContext<'a> { None => self.table.new_float_var(), }, }, + Expr::MacroStmts { tail } => self.infer_expr(*tail, expected), }; // use a new type variable if we got unknown here let ty = self.insert_type_vars_shallow(ty); diff --git a/crates/hir_ty/src/tests/macros.rs b/crates/hir_ty/src/tests/macros.rs index 7eda51866..01935ec99 100644 --- a/crates/hir_ty/src/tests/macros.rs +++ b/crates/hir_ty/src/tests/macros.rs @@ -226,11 +226,48 @@ fn expr_macro_expanded_in_stmts() { "#, expect![[r#" !0..8 'leta=();': () + !0..8 'leta=();': () + !3..4 'a': () + !5..7 '()': () 57..84 '{ ...); } }': () "#]], ); } +#[test] +fn recurisve_macro_expanded_in_stmts() { + check_infer( + r#" + macro_rules! ng { + ([$($tts:tt)*]) => { + $($tts)*; + }; + ([$($tts:tt)*] $head:tt $($rest:tt)*) => { + ng! { + [$($tts)* $head] $($rest)* + } + }; + } + fn foo() { + ng!([] let a = 3); + let b = a; + } + "#, + expect![[r#" + !0..7 'leta=3;': {unknown} + !0..7 'leta=3;': {unknown} + !0..13 'ng!{[leta=3]}': {unknown} + !0..13 'ng!{[leta=]3}': {unknown} + !0..13 'ng!{[leta]=3}': {unknown} + !3..4 'a': i32 + !5..6 '3': i32 + 196..237 '{ ...= a; }': () + 229..230 'b': i32 + 233..234 'a': i32 + "#]], + ); +} + #[test] fn recursive_inner_item_macro_rules() { check_infer( @@ -246,7 +283,8 @@ fn recursive_inner_item_macro_rules() { "#, expect![[r#" !0..1 '1': i32 - !0..7 'mac!($)': {unknown} + !0..26 'macro_...>{1};}': {unknown} + !0..26 'macro_...>{1};}': {unknown} 107..143 '{ ...!(); }': () 129..130 'a': i32 "#]], -- cgit v1.2.3