From bda68e23328ca62a71da348a13c4d13cc8f991f3 Mon Sep 17 00:00:00 2001 From: Jonas Schievink Date: Wed, 12 May 2021 00:27:16 +0200 Subject: Strip delimiter from fn-like proc macro input --- crates/hir_expand/src/input.rs | 31 +++++++++++++++++++++++++++++-- crates/hir_expand/src/lib.rs | 4 ++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/crates/hir_expand/src/input.rs b/crates/hir_expand/src/input.rs index 112216859..860aa049b 100644 --- a/crates/hir_expand/src/input.rs +++ b/crates/hir_expand/src/input.rs @@ -1,8 +1,9 @@ //! Macro input conditioning. +use parser::SyntaxKind; use syntax::{ ast::{self, AttrsOwner}, - AstNode, SyntaxNode, + AstNode, SyntaxElement, SyntaxNode, }; use crate::{ @@ -19,7 +20,33 @@ pub(crate) fn process_macro_input( let loc: MacroCallLoc = db.lookup_intern_macro(id); match loc.kind { - MacroCallKind::FnLike { .. } => node, + MacroCallKind::FnLike { .. } => { + if !loc.def.is_proc_macro() { + // MBE macros expect the parentheses as part of their input. + return node; + } + + // The input includes the `(` + `)` delimiter tokens, so remove them before passing this + // to the macro. + let node = node.clone_for_update(); + if let Some(SyntaxElement::Token(tkn)) = node.first_child_or_token() { + if matches!( + tkn.kind(), + SyntaxKind::L_BRACK | SyntaxKind::L_PAREN | SyntaxKind::L_CURLY + ) { + tkn.detach(); + } + } + if let Some(SyntaxElement::Token(tkn)) = node.last_child_or_token() { + if matches!( + tkn.kind(), + SyntaxKind::R_BRACK | SyntaxKind::R_PAREN | SyntaxKind::R_CURLY + ) { + tkn.detach(); + } + } + node + } MacroCallKind::Derive { derive_attr_index, .. } => { let item = match ast::Item::cast(node.clone()) { Some(item) => item, diff --git a/crates/hir_expand/src/lib.rs b/crates/hir_expand/src/lib.rs index 5df11856e..88cb16ca4 100644 --- a/crates/hir_expand/src/lib.rs +++ b/crates/hir_expand/src/lib.rs @@ -272,6 +272,10 @@ impl MacroDefId { }; Either::Left(*id) } + + pub fn is_proc_macro(&self) -> bool { + matches!(self.kind, MacroDefKind::ProcMacro(..)) + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -- cgit v1.2.3