From 8e07b23b84bff16c0decc6f2b80c27862eac6df1 Mon Sep 17 00:00:00 2001 From: Edwin Cheng Date: Tue, 16 Mar 2021 13:44:50 +0800 Subject: Fix macro expansion for statements w/o semicolon --- crates/hir_def/src/body/lower.rs | 105 +++++++++++++++++-------------- crates/hir_def/src/item_tree.rs | 5 ++ crates/hir_expand/src/db.rs | 3 +- crates/hir_ty/src/tests/macros.rs | 16 +++++ crates/mbe/src/tests.rs | 11 ++-- crates/parser/src/grammar.rs | 6 +- crates/parser/src/grammar/expressions.rs | 6 +- crates/syntax/src/ast/generated/nodes.rs | 10 ++- 8 files changed, 100 insertions(+), 62 deletions(-) (limited to 'crates') diff --git a/crates/hir_def/src/body/lower.rs b/crates/hir_def/src/body/lower.rs index 8934ae6c9..7052058f2 100644 --- a/crates/hir_def/src/body/lower.rs +++ b/crates/hir_def/src/body/lower.rs @@ -519,7 +519,7 @@ impl ExprCollector<'_> { } ast::Expr::MacroCall(e) => { let mut ids = vec![]; - self.collect_macro_call(e, syntax_ptr.clone(), |this, expansion| { + self.collect_macro_call(e, syntax_ptr.clone(), true, |this, expansion| { ids.push(match expansion { Some(it) => this.collect_expr(it), None => this.alloc_expr(Expr::Missing, syntax_ptr.clone()), @@ -527,6 +527,17 @@ impl ExprCollector<'_> { }); ids[0] } + ast::Expr::MacroStmts(e) => { + // FIXME: these statements should be held by some hir containter + for stmt in e.statements() { + self.collect_stmt(stmt); + } + if let Some(expr) = e.expr() { + self.collect_expr(expr) + } else { + self.alloc_expr(Expr::Missing, syntax_ptr) + } + } } } @@ -534,6 +545,7 @@ impl ExprCollector<'_> { &mut self, e: ast::MacroCall, syntax_ptr: AstPtr, + is_error_recoverable: bool, mut collector: F, ) { // File containing the macro call. Expansion errors will be attached here. @@ -567,7 +579,7 @@ impl ExprCollector<'_> { Some((mark, expansion)) => { // FIXME: Statements are too complicated to recover from error for now. // It is because we don't have any hygiene for local variable expansion right now. - if T::can_cast(syntax::SyntaxKind::MACRO_STMTS) && res.err.is_some() { + if !is_error_recoverable && res.err.is_some() { self.expander.exit(self.db, mark); collector(self, None); } else { @@ -591,56 +603,55 @@ impl ExprCollector<'_> { } fn collect_stmt(&mut self, s: ast::Stmt) -> Option> { - let stmt = - match s { - ast::Stmt::LetStmt(stmt) => { - self.check_cfg(&stmt)?; - - let pat = self.collect_pat_opt(stmt.pat()); - let type_ref = stmt.ty().map(|it| TypeRef::from_ast(&self.ctx(), it)); - let initializer = stmt.initializer().map(|e| self.collect_expr(e)); - vec![Statement::Let { pat, type_ref, initializer }] - } - ast::Stmt::ExprStmt(stmt) => { - self.check_cfg(&stmt)?; - - // Note that macro could be expended to multiple statements - if let Some(ast::Expr::MacroCall(m)) = stmt.expr() { - let syntax_ptr = AstPtr::new(&stmt.expr().unwrap()); - let mut stmts = vec![]; - - self.collect_macro_call(m, syntax_ptr.clone(), |this, expansion| { - match expansion { - Some(expansion) => { - let statements: ast::MacroStmts = expansion; - - statements.statements().for_each(|stmt| { - if let Some(mut r) = this.collect_stmt(stmt) { - stmts.append(&mut r); - } - }); - if let Some(expr) = statements.expr() { - stmts.push(Statement::Expr(this.collect_expr(expr))); + let stmt = match s { + ast::Stmt::LetStmt(stmt) => { + self.check_cfg(&stmt)?; + + let pat = self.collect_pat_opt(stmt.pat()); + let type_ref = stmt.ty().map(|it| TypeRef::from_ast(&self.ctx(), it)); + let initializer = stmt.initializer().map(|e| self.collect_expr(e)); + vec![Statement::Let { pat, type_ref, initializer }] + } + ast::Stmt::ExprStmt(stmt) => { + self.check_cfg(&stmt)?; + + // Note that macro could be expended to multiple statements + if let Some(ast::Expr::MacroCall(m)) = stmt.expr() { + let syntax_ptr = AstPtr::new(&stmt.expr().unwrap()); + let mut stmts = vec![]; + + self.collect_macro_call(m, syntax_ptr.clone(), false, |this, expansion| { + match expansion { + Some(expansion) => { + let statements: ast::MacroStmts = expansion; + + statements.statements().for_each(|stmt| { + if let Some(mut r) = this.collect_stmt(stmt) { + stmts.append(&mut r); } - } - None => { - stmts.push(Statement::Expr( - this.alloc_expr(Expr::Missing, syntax_ptr.clone()), - )); + }); + if let Some(expr) = statements.expr() { + stmts.push(Statement::Expr(this.collect_expr(expr))); } } - }); - stmts - } else { - vec![Statement::Expr(self.collect_expr_opt(stmt.expr()))] - } + None => { + stmts.push(Statement::Expr( + this.alloc_expr(Expr::Missing, syntax_ptr.clone()), + )); + } + } + }); + stmts + } else { + vec![Statement::Expr(self.collect_expr_opt(stmt.expr()))] } - ast::Stmt::Item(item) => { - self.check_cfg(&item)?; + } + ast::Stmt::Item(item) => { + self.check_cfg(&item)?; - return None; - } - }; + return None; + } + }; Some(stmt) } diff --git a/crates/hir_def/src/item_tree.rs b/crates/hir_def/src/item_tree.rs index 09bcb10dc..86239d903 100644 --- a/crates/hir_def/src/item_tree.rs +++ b/crates/hir_def/src/item_tree.rs @@ -110,6 +110,11 @@ impl ItemTree { // still need to collect inner items. ctx.lower_inner_items(e.syntax()) }, + ast::ExprStmt(stmt) => { + // Macros can expand to stmt. We return an empty item tree in this case, but + // still need to collect inner items. + ctx.lower_inner_items(stmt.syntax()) + }, _ => { panic!("cannot create item tree from {:?} {}", syntax, syntax); }, diff --git a/crates/hir_expand/src/db.rs b/crates/hir_expand/src/db.rs index 9086e6c17..a3070f1f9 100644 --- a/crates/hir_expand/src/db.rs +++ b/crates/hir_expand/src/db.rs @@ -401,13 +401,14 @@ fn to_fragment_kind(db: &dyn AstDatabase, id: MacroCallId) -> FragmentKind { match parent.kind() { MACRO_ITEMS | SOURCE_FILE => FragmentKind::Items, + MACRO_STMTS => FragmentKind::Statement, ITEM_LIST => FragmentKind::Items, LET_STMT => { // FIXME: Handle Pattern FragmentKind::Expr } EXPR_STMT => FragmentKind::Statements, - BLOCK_EXPR => FragmentKind::Expr, + BLOCK_EXPR => FragmentKind::Statements, ARG_LIST => FragmentKind::Expr, TRY_EXPR => FragmentKind::Expr, TUPLE_EXPR => FragmentKind::Expr, diff --git a/crates/hir_ty/src/tests/macros.rs b/crates/hir_ty/src/tests/macros.rs index fb3afaedc..af4f8bb11 100644 --- a/crates/hir_ty/src/tests/macros.rs +++ b/crates/hir_ty/src/tests/macros.rs @@ -215,6 +215,22 @@ fn expr_macro_expanded_in_various_places() { ); } +#[test] +fn expr_macro_expanded_in_stmts() { + check_infer( + r#" + macro_rules! id { ($($es:tt)*) => { $($es)* } } + fn foo() { + id! { let a = (); } + } + "#, + expect![[r#" + !0..8 'leta=();': () + 57..84 '{ ...); } }': () + "#]], + ); +} + #[test] fn infer_type_value_macro_having_same_name() { check_infer( diff --git a/crates/mbe/src/tests.rs b/crates/mbe/src/tests.rs index 3a168bb4b..eca0bcc18 100644 --- a/crates/mbe/src/tests.rs +++ b/crates/mbe/src/tests.rs @@ -662,12 +662,11 @@ fn test_tt_to_stmts() { LITERAL@12..13 INT_NUMBER@12..13 "1" SEMICOLON@13..14 ";" - EXPR_STMT@14..15 - PATH_EXPR@14..15 - PATH@14..15 - PATH_SEGMENT@14..15 - NAME_REF@14..15 - IDENT@14..15 "a""#, + PATH_EXPR@14..15 + PATH@14..15 + PATH_SEGMENT@14..15 + NAME_REF@14..15 + IDENT@14..15 "a""#, ); } diff --git a/crates/parser/src/grammar.rs b/crates/parser/src/grammar.rs index 6c0e22722..cebb8f400 100644 --- a/crates/parser/src/grammar.rs +++ b/crates/parser/src/grammar.rs @@ -63,11 +63,11 @@ pub(crate) mod fragments { } pub(crate) fn stmt(p: &mut Parser) { - expressions::stmt(p, expressions::StmtWithSemi::No) + expressions::stmt(p, expressions::StmtWithSemi::No, true) } pub(crate) fn stmt_optional_semi(p: &mut Parser) { - expressions::stmt(p, expressions::StmtWithSemi::Optional) + expressions::stmt(p, expressions::StmtWithSemi::Optional, false) } pub(crate) fn opt_visibility(p: &mut Parser) { @@ -133,7 +133,7 @@ pub(crate) mod fragments { continue; } - expressions::stmt(p, expressions::StmtWithSemi::Optional); + expressions::stmt(p, expressions::StmtWithSemi::Optional, true); } m.complete(p, MACRO_STMTS); diff --git a/crates/parser/src/grammar/expressions.rs b/crates/parser/src/grammar/expressions.rs index 5f885edfd..0d9dc9348 100644 --- a/crates/parser/src/grammar/expressions.rs +++ b/crates/parser/src/grammar/expressions.rs @@ -54,7 +54,7 @@ fn is_expr_stmt_attr_allowed(kind: SyntaxKind) -> bool { !forbid } -pub(super) fn stmt(p: &mut Parser, with_semi: StmtWithSemi) { +pub(super) fn stmt(p: &mut Parser, with_semi: StmtWithSemi, prefer_expr: bool) { let m = p.start(); // test attr_on_expr_stmt // fn foo() { @@ -90,7 +90,7 @@ pub(super) fn stmt(p: &mut Parser, with_semi: StmtWithSemi) { p.error(format!("attributes are not allowed on {:?}", kind)); } - if p.at(T!['}']) { + if p.at(T!['}']) || (prefer_expr && p.at(EOF)) { // test attr_on_last_expr_in_block // fn foo() { // { #[A] bar!()? } @@ -198,7 +198,7 @@ pub(super) fn expr_block_contents(p: &mut Parser) { continue; } - stmt(p, StmtWithSemi::Yes) + stmt(p, StmtWithSemi::Yes, false) } } diff --git a/crates/syntax/src/ast/generated/nodes.rs b/crates/syntax/src/ast/generated/nodes.rs index 064931aec..6097178b6 100644 --- a/crates/syntax/src/ast/generated/nodes.rs +++ b/crates/syntax/src/ast/generated/nodes.rs @@ -1336,6 +1336,7 @@ pub enum Expr { Literal(Literal), LoopExpr(LoopExpr), MacroCall(MacroCall), + MacroStmts(MacroStmts), MatchExpr(MatchExpr), MethodCallExpr(MethodCallExpr), ParenExpr(ParenExpr), @@ -3034,6 +3035,9 @@ impl From for Expr { impl From for Expr { fn from(node: MacroCall) -> Expr { Expr::MacroCall(node) } } +impl From for Expr { + fn from(node: MacroStmts) -> Expr { Expr::MacroStmts(node) } +} impl From for Expr { fn from(node: MatchExpr) -> Expr { Expr::MatchExpr(node) } } @@ -3078,8 +3082,8 @@ impl AstNode for Expr { match kind { ARRAY_EXPR | AWAIT_EXPR | BIN_EXPR | BLOCK_EXPR | BOX_EXPR | BREAK_EXPR | CALL_EXPR | CAST_EXPR | CLOSURE_EXPR | CONTINUE_EXPR | EFFECT_EXPR | FIELD_EXPR | FOR_EXPR - | IF_EXPR | INDEX_EXPR | LITERAL | LOOP_EXPR | MACRO_CALL | MATCH_EXPR - | METHOD_CALL_EXPR | PAREN_EXPR | PATH_EXPR | PREFIX_EXPR | RANGE_EXPR + | IF_EXPR | INDEX_EXPR | LITERAL | LOOP_EXPR | MACRO_CALL | MACRO_STMTS + | MATCH_EXPR | METHOD_CALL_EXPR | PAREN_EXPR | PATH_EXPR | PREFIX_EXPR | RANGE_EXPR | RECORD_EXPR | REF_EXPR | RETURN_EXPR | TRY_EXPR | TUPLE_EXPR | WHILE_EXPR | YIELD_EXPR => true, _ => false, @@ -3105,6 +3109,7 @@ impl AstNode for Expr { LITERAL => Expr::Literal(Literal { syntax }), LOOP_EXPR => Expr::LoopExpr(LoopExpr { syntax }), MACRO_CALL => Expr::MacroCall(MacroCall { syntax }), + MACRO_STMTS => Expr::MacroStmts(MacroStmts { syntax }), MATCH_EXPR => Expr::MatchExpr(MatchExpr { syntax }), METHOD_CALL_EXPR => Expr::MethodCallExpr(MethodCallExpr { syntax }), PAREN_EXPR => Expr::ParenExpr(ParenExpr { syntax }), @@ -3142,6 +3147,7 @@ impl AstNode for Expr { Expr::Literal(it) => &it.syntax, Expr::LoopExpr(it) => &it.syntax, Expr::MacroCall(it) => &it.syntax, + Expr::MacroStmts(it) => &it.syntax, Expr::MatchExpr(it) => &it.syntax, Expr::MethodCallExpr(it) => &it.syntax, Expr::ParenExpr(it) => &it.syntax, -- cgit v1.2.3