From b177813f3bef708636ec4be271e376b111c36a59 Mon Sep 17 00:00:00 2001 From: Edwin Cheng Date: Mon, 22 Apr 2019 15:33:55 +0800 Subject: Add mbe expand limit and poision macro set --- crates/ra_hir/src/ids.rs | 7 ++ crates/ra_hir/src/nameres.rs | 16 ++- crates/ra_hir/src/nameres/collector.rs | 186 ++++++++++++++++++++++++++++++--- crates/ra_mbe/src/subtree_parser.rs | 13 ++- crates/ra_tt/src/lib.rs | 12 +++ 5 files changed, 216 insertions(+), 18 deletions(-) diff --git a/crates/ra_hir/src/ids.rs b/crates/ra_hir/src/ids.rs index e771a311c..c7849c995 100644 --- a/crates/ra_hir/src/ids.rs +++ b/crates/ra_hir/src/ids.rs @@ -94,6 +94,13 @@ fn parse_macro( let macro_rules = db.macro_def(loc.def).ok_or("Fail to find macro definition")?; let tt = macro_rules.expand(¯o_arg).map_err(|err| format!("{:?}", err))?; + + // Set a hard limit for the expanded tt + let count = tt.count(); + if count > 65536 { + return Err(format!("Total tokens count exceed limit : count = {}", count)); + } + Ok(mbe::token_tree_to_ast_item_list(&tt)) } diff --git a/crates/ra_hir/src/nameres.rs b/crates/ra_hir/src/nameres.rs index 39152360c..fbfff4fd7 100644 --- a/crates/ra_hir/src/nameres.rs +++ b/crates/ra_hir/src/nameres.rs @@ -55,7 +55,7 @@ mod tests; use std::sync::Arc; -use rustc_hash::FxHashMap; +use rustc_hash::{FxHashMap, FxHashSet}; use ra_arena::{Arena, RawId, impl_arena_id}; use ra_db::{FileId, Edition}; use test_utils::tested_by; @@ -91,6 +91,19 @@ pub struct CrateDefMap { root: CrateModuleId, modules: Arena, public_macros: FxHashMap, + + /// Some macros are not well-behavior, which leads to infinite loop + /// e.g. macro_rules! foo { ($ty:ty) => { foo!($ty); } } + /// We mark it down and skip it in collector + /// + /// FIXME: + /// Right now it only handle a poison macro in a single crate, + /// such that if other crate try to call that macro, + /// the whole process will do again until it became poisoned in that crate. + /// We should handle this macro set globally + /// However, do we want to put it as a global variable? + poison_macros: FxHashSet, + diagnostics: Vec, } @@ -195,6 +208,7 @@ impl CrateDefMap { root, modules, public_macros: FxHashMap::default(), + poison_macros: FxHashSet::default(), diagnostics: Vec::new(), } }; diff --git a/crates/ra_hir/src/nameres/collector.rs b/crates/ra_hir/src/nameres/collector.rs index 6147b3219..82738cce3 100644 --- a/crates/ra_hir/src/nameres/collector.rs +++ b/crates/ra_hir/src/nameres/collector.rs @@ -42,14 +42,40 @@ pub(super) fn collect_defs(db: &impl DefDatabase, mut def_map: CrateDefMap) -> C unresolved_imports: Vec::new(), unexpanded_macros: Vec::new(), global_macro_scope: FxHashMap::default(), - marco_stack_count: 0, + macro_stack_monitor: SimpleMacroStackMonitor::default(), }; collector.collect(); collector.finish() } +trait MacroStackMonitor { + fn increase(&mut self, macro_def_id: MacroDefId); + fn decrease(&mut self, macro_def_id: MacroDefId); + + fn is_poison(&self, macro_def_id: MacroDefId) -> bool; +} + +#[derive(Default)] +struct SimpleMacroStackMonitor { + counts: FxHashMap, +} + +impl MacroStackMonitor for SimpleMacroStackMonitor { + fn increase(&mut self, macro_def_id: MacroDefId) { + *self.counts.entry(macro_def_id).or_default() += 1; + } + + fn decrease(&mut self, macro_def_id: MacroDefId) { + *self.counts.entry(macro_def_id).or_default() -= 1; + } + + fn is_poison(&self, macro_def_id: MacroDefId) -> bool { + *self.counts.get(¯o_def_id).unwrap_or(&0) > 100 + } +} + /// Walks the tree of module recursively -struct DefCollector { +struct DefCollector { db: DB, def_map: CrateDefMap, glob_imports: FxHashMap>, @@ -59,12 +85,13 @@ struct DefCollector { /// Some macro use `$tt:tt which mean we have to handle the macro perfectly /// To prevent stackoverflow, we add a deep counter here for prevent that. - marco_stack_count: u32, + macro_stack_monitor: M, } -impl<'a, DB> DefCollector<&'a DB> +impl<'a, DB, M> DefCollector<&'a DB, M> where DB: DefDatabase, + M: MacroStackMonitor, { fn collect(&mut self) { let crate_graph = self.db.crate_graph(); @@ -317,30 +344,40 @@ where let def_map = self.db.crate_def_map(krate); if let Some(macro_id) = def_map.public_macros.get(&path.segments[1].name).cloned() { let call_id = MacroCallLoc { def: macro_id, ast_id: *ast_id }.id(self.db); - resolved.push((*module_id, call_id)); + resolved.push((*module_id, call_id, macro_id)); } false }); - for (module_id, macro_call_id) in resolved { - self.collect_macro_expansion(module_id, macro_call_id); + for (module_id, macro_call_id, macro_def_id) in resolved { + self.collect_macro_expansion(module_id, macro_call_id, macro_def_id); } res } - fn collect_macro_expansion(&mut self, module_id: CrateModuleId, macro_call_id: MacroCallId) { - self.marco_stack_count += 1; + fn collect_macro_expansion( + &mut self, + module_id: CrateModuleId, + macro_call_id: MacroCallId, + macro_def_id: MacroDefId, + ) { + if self.def_map.poison_macros.contains(¯o_def_id) { + return; + } + + self.macro_stack_monitor.increase(macro_def_id); - if self.marco_stack_count < 300 { + if !self.macro_stack_monitor.is_poison(macro_def_id) { let file_id: HirFileId = macro_call_id.into(); let raw_items = self.db.raw_items(file_id); ModCollector { def_collector: &mut *self, file_id, module_id, raw_items: &raw_items } - .collect(raw_items.items()) + .collect(raw_items.items()); } else { log::error!("Too deep macro expansion: {}", macro_call_id.debug_dump(self.db)); + self.def_map.poison_macros.insert(macro_def_id); } - self.marco_stack_count -= 1; + self.macro_stack_monitor.decrease(macro_def_id); } fn finish(self) -> CrateDefMap { @@ -356,9 +393,10 @@ struct ModCollector<'a, D> { raw_items: &'a raw::RawItems, } -impl ModCollector<'_, &'_ mut DefCollector<&'_ DB>> +impl ModCollector<'_, &'_ mut DefCollector<&'_ DB, M>> where DB: DefDatabase, + M: MacroStackMonitor, { fn collect(&mut self, items: &[raw::RawItem]) { for item in items { @@ -484,7 +522,7 @@ where { let macro_call_id = MacroCallLoc { def: macro_id, ast_id }.id(self.def_collector.db); - self.def_collector.collect_macro_expansion(self.module_id, macro_call_id); + self.def_collector.collect_macro_expansion(self.module_id, macro_call_id, macro_id); return; } @@ -530,3 +568,123 @@ fn resolve_submodule( None => Err(if is_dir_owner { file_mod } else { file_dir_mod }), } } + +#[cfg(test)] +mod tests { + use ra_db::SourceDatabase; + + use crate::{Crate, mock::MockDatabase, DefDatabase}; + use ra_arena::{Arena}; + use super::*; + use rustc_hash::FxHashSet; + + struct LimitedMacroStackMonitor { + count: u32, + limit: u32, + poison_limit: u32, + } + + impl MacroStackMonitor for LimitedMacroStackMonitor { + fn increase(&mut self, _: MacroDefId) { + self.count += 1; + assert!(self.count < self.limit); + } + + fn decrease(&mut self, _: MacroDefId) { + self.count -= 1; + } + + fn is_poison(&self, _: MacroDefId) -> bool { + self.count >= self.poison_limit + } + } + + fn do_collect_defs( + db: &impl DefDatabase, + def_map: CrateDefMap, + monitor: impl MacroStackMonitor, + ) -> CrateDefMap { + let mut collector = DefCollector { + db, + def_map, + glob_imports: FxHashMap::default(), + unresolved_imports: Vec::new(), + unexpanded_macros: Vec::new(), + global_macro_scope: FxHashMap::default(), + macro_stack_monitor: monitor, + }; + collector.collect(); + collector.finish() + } + + fn do_limited_resolve(code: &str, limit: u32, poison_limit: u32) -> CrateDefMap { + let (db, _source_root, _) = MockDatabase::with_single_file(&code); + let crate_id = db.crate_graph().iter().next().unwrap(); + let krate = Crate { crate_id }; + + let def_map = { + let edition = krate.edition(&db); + let mut modules: Arena = Arena::default(); + let root = modules.alloc(ModuleData::default()); + CrateDefMap { + krate, + edition, + extern_prelude: FxHashMap::default(), + prelude: None, + root, + modules, + public_macros: FxHashMap::default(), + poison_macros: FxHashSet::default(), + diagnostics: Vec::new(), + } + }; + + do_collect_defs(&db, def_map, LimitedMacroStackMonitor { count: 0, limit, poison_limit }) + } + + #[test] + fn test_macro_expand_limit_width() { + do_limited_resolve( + r#" + macro_rules! foo { + ($($ty:ty)*) => { foo!($($ty)*, $($ty)*); } + } +foo!(KABOOM); + "#, + 16, + 1000, + ); + } + + #[test] + fn test_macro_expand_poisoned() { + let def = do_limited_resolve( + r#" + macro_rules! foo { + ($ty:ty) => { foo!($ty); } + } +foo!(KABOOM); + "#, + 100, + 16, + ); + + assert_eq!(def.poison_macros.len(), 1); + } + + #[test] + fn test_macro_expand_normal() { + let def = do_limited_resolve( + r#" + macro_rules! foo { + ($ident:ident) => { struct $ident {} } + } +foo!(Bar); + "#, + 16, + 16, + ); + + assert_eq!(def.poison_macros.len(), 0); + } +} diff --git a/crates/ra_mbe/src/subtree_parser.rs b/crates/ra_mbe/src/subtree_parser.rs index 528aa0f8a..f07107414 100644 --- a/crates/ra_mbe/src/subtree_parser.rs +++ b/crates/ra_mbe/src/subtree_parser.rs @@ -5,6 +5,7 @@ use ra_syntax::{SyntaxKind}; struct OffsetTokenSink { token_pos: usize, + error: bool, } impl TreeSink for OffsetTokenSink { @@ -13,7 +14,9 @@ impl TreeSink for OffsetTokenSink { } fn start_node(&mut self, _kind: SyntaxKind) {} fn finish_node(&mut self) {} - fn error(&mut self, _error: ra_parser::ParseError) {} + fn error(&mut self, _error: ra_parser::ParseError) { + self.error = true; + } } pub(crate) struct Parser<'a> { @@ -67,11 +70,15 @@ impl<'a> Parser<'a> { F: FnOnce(&dyn TokenSource, &mut dyn TreeSink), { let mut src = SubtreeTokenSource::new(&self.subtree.token_trees[*self.cur_pos..]); - let mut sink = OffsetTokenSink { token_pos: 0 }; + let mut sink = OffsetTokenSink { token_pos: 0, error: false }; f(&src, &mut sink); - self.finish(sink.token_pos, &mut src) + let r = self.finish(sink.token_pos, &mut src); + if sink.error { + return None; + } + r } fn finish(self, parsed_token: usize, src: &mut SubtreeTokenSource) -> Option { diff --git a/crates/ra_tt/src/lib.rs b/crates/ra_tt/src/lib.rs index 0b0b9b4d2..9cc646140 100644 --- a/crates/ra_tt/src/lib.rs +++ b/crates/ra_tt/src/lib.rs @@ -149,3 +149,15 @@ impl fmt::Display for Punct { fmt::Display::fmt(&self.char, f) } } + +impl Subtree { + /// Count the number of tokens recursively + pub fn count(&self) -> usize { + self.token_trees.iter().fold(self.token_trees.len(), |acc, c| { + acc + match c { + TokenTree::Subtree(c) => c.count(), + _ => 0, + } + }) + } +} -- cgit v1.2.3