From 32de8bd5dac80a2c09e7106144cab5a8e16accc4 Mon Sep 17 00:00:00 2001 From: Akshay Date: Thu, 8 Aug 2024 22:19:14 +0100 Subject: store nodes as usize --- src/ast.rs | 13 ++++- src/eval.rs | 174 ++++++++++++++++++++++++++++++++++++++-------------------- src/lib.rs | 25 ++++++++- src/parser.rs | 40 ++++++++------ 4 files changed, 170 insertions(+), 82 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 96fe8ab..35bf6c3 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -59,7 +59,7 @@ pub enum Statement { #[derive(Debug, Eq, PartialEq, Clone)] pub enum Expr { Node, - FieldAccess(Vec), + FieldAccess(Box, Identifier), Unit, Lit(Literal), Ident(Identifier), @@ -67,7 +67,7 @@ pub enum Expr { Bin(Box, BinOp, Box), Unary(Box, UnaryOp), Call(Call), - IfExpr(If), + If(IfExpr), Block(Block), } @@ -75,6 +75,13 @@ impl Expr { pub fn boxed(self) -> Box { Box::new(self) } + + pub fn as_ident(self) -> Option { + match self { + Self::Ident(i) => Some(i), + _ => None, + } + } } #[derive(Debug, Eq, PartialEq, Clone, Copy)] @@ -164,7 +171,7 @@ pub struct Declaration { } #[derive(Debug, Eq, PartialEq, Clone)] -pub struct If { +pub struct IfExpr { pub condition: Box, pub then: Block, pub else_: Block, diff --git a/src/eval.rs b/src/eval.rs index 1a3f8a8..7822419 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -1,6 +1,6 @@ //! tree walking interpreter for tbsp -use crate::ast; +use crate::{ast, Wrap}; use std::{collections::HashMap, fmt}; #[derive(Debug, PartialEq, Eq, Clone)] @@ -38,10 +38,11 @@ pub enum Value { Integer(i128), String(String), Boolean(bool), - Node, - FieldAccess(Vec), + Node(NodeId), } +type NodeId = usize; + impl Value { fn ty(&self) -> ast::Type { match self { @@ -49,8 +50,7 @@ impl Value { Self::Integer(_) => ast::Type::Integer, Self::String(_) => ast::Type::String, Self::Boolean(_) => ast::Type::Boolean, - Self::Node => ast::Type::Node, - Self::FieldAccess(_) => ast::Type::Node, + Self::Node(_) => ast::Type::Node, } } @@ -106,6 +106,16 @@ impl Value { } } + fn as_node(&self) -> std::result::Result { + match self { + Self::Node(id) => Ok(*id), + v => Err(Error::TypeMismatch { + expected: ast::Type::Node, + got: v.ty(), + }), + } + } + fn add(&self, other: &Self) -> Result { match (self, other) { (Self::Integer(s), Self::Integer(o)) => Ok(Self::Integer(*s + *o)), @@ -267,8 +277,7 @@ impl fmt::Display for Value { Self::Integer(i) => write!(f, "{i}"), Self::String(s) => write!(f, "{s}"), Self::Boolean(b) => write!(f, "{b}"), - Self::Node => write!(f, ""), - Self::FieldAccess(items) => write!(f, ".{}", items.join(".")), + Self::Node(id) => write!(f, ""), } } } @@ -373,15 +382,17 @@ pub enum Error { pub type Result = std::result::Result; -pub struct Context<'a> { +pub struct Context { variables: HashMap, language: tree_sitter::Language, visitors: Visitors, input_src: Option, - cursor: Option>, + cursor: Option>, + tree: Option<&'static tree_sitter::Tree>, + cache: HashMap>, } -impl<'a> fmt::Debug for Context<'a> { +impl fmt::Debug for Context { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Context") .field("variables", &self.variables) @@ -400,7 +411,7 @@ impl<'a> fmt::Debug for Context<'a> { } } -impl<'a> Context<'a> { +impl Context { pub fn new(language: tree_sitter::Language) -> Self { Self { visitors: Default::default(), @@ -408,9 +419,44 @@ impl<'a> Context<'a> { language, input_src: None, cursor: None, + tree: None, + cache: HashMap::default(), } } + pub fn cache_node(&mut self, node: tree_sitter::Node<'static>) { + self.cache.entry(node.id()).or_insert(node); + } + + pub fn get_node_by_id(&mut self, id: usize) -> Option> { + let root_node = self.tree.as_ref().map(|t| t.root_node())?; + self.get_node_by_id_helper(root_node, id) + } + + fn get_node_by_id_helper( + &mut self, + start: tree_sitter::Node<'static>, + id: usize, + ) -> Option> { + self.cache_node(start); + + if let Some(found) = self.cache.get(&id) { + return Some(*found); + } + + if start.id() == id { + return Some(start); + } else { + for child in start.children(&mut start.walk()) { + if let Some(n) = self.get_node_by_id_helper(child, id) { + return Some(n); + }; + } + } + + None + } + pub fn with_program(mut self, program: ast::Program) -> std::result::Result { for stanza in program.stanzas.into_iter() { self.visitors.insert(stanza, &self.language)?; @@ -423,8 +469,10 @@ impl<'a> Context<'a> { self } - pub fn with_cursor(mut self, cursor: tree_sitter::TreeCursor<'a>) -> Self { - self.cursor = Some(cursor); + pub fn with_tree(mut self, tree: tree_sitter::Tree) -> Self { + let tree = Box::leak(Box::new(tree)); + self.cursor = Some(tree.walk()); + self.tree = Some(tree); self } @@ -436,10 +484,10 @@ impl<'a> Context<'a> { ast::Expr::Bin(lhs, op, rhs) => self.eval_bin(&*lhs, *op, &*rhs), ast::Expr::Unary(expr, op) => self.eval_unary(&*expr, *op), ast::Expr::Call(call) => self.eval_call(&*call), - ast::Expr::IfExpr(if_expr) => self.eval_if(if_expr), + ast::Expr::If(if_expr) => self.eval_if(if_expr), ast::Expr::Block(block) => self.eval_block(block), - ast::Expr::Node => Ok(Value::Node), - ast::Expr::FieldAccess(items) => Ok(Value::FieldAccess(items.to_owned())), + ast::Expr::Node => self.eval_node(), + ast::Expr::FieldAccess(expr, items) => self.eval_field_access(expr, items), } } @@ -451,6 +499,16 @@ impl<'a> Context<'a> { } } + fn eval_node(&mut self) -> Result { + self.cursor + .as_ref() + .ok_or(Error::CurrentNodeNotPresent)? + .node() + .id() + .wrap(Value::Node) + .wrap_ok() + } + fn lookup(&mut self, ident: &ast::Identifier) -> std::result::Result<&Variable, Error> { self.variables .get(ident) @@ -469,14 +527,14 @@ impl<'a> Context<'a> { ty: ast::Type, ) -> std::result::Result<&mut Variable, Error> { if self.lookup(ident).is_err() { - Ok(self - .variables + self.variables .entry(ident.to_owned()) .or_insert_with(|| Variable { name: ident.to_owned(), value: Value::default(ty), ty, - })) + }) + .wrap_ok() } else { Err(Error::AlreadyBound(ident.to_owned())) } @@ -574,7 +632,7 @@ impl<'a> Context<'a> { } } - fn eval_if(&mut self, if_expr: &ast::If) -> Result { + fn eval_if(&mut self, if_expr: &ast::IfExpr) -> Result { let cond = self.eval_expr(&if_expr.condition)?; if cond.as_boolean()? { @@ -638,32 +696,9 @@ impl<'a> Context<'a> { } } ("text", [arg]) => { - let node = match self.eval_expr(arg)? { - Value::Node => self - .cursor - .as_ref() - .ok_or(Error::CurrentNodeNotPresent)? - .node(), - Value::FieldAccess(fields) => { - let mut node = self - .cursor - .as_ref() - .ok_or(Error::CurrentNodeNotPresent)? - .node(); - for field in &fields { - node = node - .child_by_field_name(field.as_bytes()) - .ok_or_else(|| Error::FailedLookup(field.to_owned()))?; - } - node - } - v => { - return Err(Error::TypeMismatch { - expected: ast::Type::Node, - got: v.ty(), - }) - } - }; + let v = self.eval_expr(arg)?; + let id = v.as_node()?; + let node = self.get_node_by_id(id).unwrap(); let text = node .utf8_text(self.input_src.as_ref().unwrap().as_bytes()) .unwrap(); @@ -689,7 +724,7 @@ impl<'a> Context<'a> { fn eval_statement(&mut self, stmt: &ast::Statement) -> Result { match stmt { - ast::Statement::Bare(expr) => self.eval_expr(expr).map(|_| Value::Unit), + ast::Statement::Bare(expr) => self.eval_expr(expr), ast::Statement::Declaration(decl) => self.eval_declaration(decl), } } @@ -701,6 +736,24 @@ impl<'a> Context<'a> { Ok(Value::Unit) } + fn eval_field_access(&mut self, expr: &ast::Expr, field: &ast::Identifier) -> Result { + let v = self.eval_expr(expr)?; + let base = v.as_node()?; + let base_node = self.get_node_by_id(base).unwrap(); + base_node + .child_by_field_name(field) + .ok_or(Error::InvalidNodeKind(field.clone())) + .map(|n| n.id()) + .map(Value::Node) + } + + fn goto_first_child(&mut self) -> bool { + self.cursor + .as_mut() + .map(|c| c.goto_first_child()) + .unwrap_or_default() + } + pub fn eval(&mut self) -> Result { let visitors = std::mem::take(&mut self.visitors); let mut has_next = true; @@ -710,29 +763,29 @@ impl<'a> Context<'a> { self.eval_block(&visitors.begin)?; while has_next { - let current_node = self.cursor.as_mut().unwrap().node(); + let current_node = self.cursor.as_ref().unwrap().node(); postorder.push(current_node); let visitor = visitors.get_by_node(current_node); - visitor.map(|v| self.eval_block(&v.enter)); + if let Some(v) = visitor { + self.eval_block(&v.enter)?; + } - has_next = self.cursor.as_mut().unwrap().goto_first_child(); + has_next = self.goto_first_child(); if !has_next { has_next = self.cursor.as_mut().unwrap().goto_next_sibling(); - postorder - .pop() - .and_then(|n| visitors.get_by_node(n)) - .map(|v| self.eval_block(&v.leave)); + if let Some(v) = postorder.pop().and_then(|n| visitors.get_by_node(n)) { + self.eval_block(&v.leave)?; + } } while !has_next && self.cursor.as_mut().unwrap().goto_parent() { has_next = self.cursor.as_mut().unwrap().goto_next_sibling(); - postorder - .pop() - .and_then(|n| visitors.get_by_node(n)) - .map(|v| self.eval_block(&v.leave)); + if let Some(v) = postorder.pop().and_then(|n| visitors.get_by_node(n)) { + self.eval_block(&v.leave)?; + } } } @@ -748,12 +801,11 @@ pub fn evaluate(file: &str, program: &str, language: tree_sitter::Language) -> R let _ = parser.set_language(&language); let tree = parser.parse(file, None).unwrap(); - let cursor = tree.walk(); let program = ast::Program::new().from_str(program).unwrap(); let mut ctx = Context::new(language) .with_input(file.to_owned()) - .with_cursor(cursor) + .with_tree(tree) .with_program(program)?; ctx.eval() @@ -857,7 +909,7 @@ mod test { name: "a".to_owned(), init: Some(ast::Expr::int(1).boxed()), }), - ast::Statement::Bare(ast::Expr::IfExpr(ast::If { + ast::Statement::Bare(ast::Expr::If(ast::IfExpr { condition: ast::Expr::true_().boxed(), then: ast::Block { body: vec![ast::Statement::Bare(ast::Expr::Bin( diff --git a/src/lib.rs b/src/lib.rs index fca5dd5..39f4605 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,29 @@ mod ast; mod eval; -mod parser; +pub mod parser; mod string; pub use eval::evaluate; + +trait Wrap { + fn wrap(self, f: F) -> U + where + F: Fn(T) -> U, + Self: Sized; + + fn wrap_ok(self) -> Result + where + Self: Sized, + { + Ok(self) + } +} + +impl Wrap for T { + fn wrap(self, f: F) -> U + where + F: Fn(T) -> U, + { + f(self) + } +} diff --git a/src/parser.rs b/src/parser.rs index d705a11..4ec8e57 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -4,7 +4,7 @@ use nom::{ character::complete::{alpha1, alphanumeric1, char, multispace0, multispace1, one_of}, combinator::{map, opt, recognize, value}, error::ParseError, - multi::{many0, many0_count, many1, separated_list0, separated_list1}, + multi::{fold_many0, many0, many0_count, many1, separated_list0}, sequence::{delimited, pair, preceded, terminated, tuple}, IResult, Parser, }; @@ -169,24 +169,27 @@ fn parse_mul<'a>(i: &'a str) -> IResult<&'a str, Expr> { let div = parse_op("/", BinOp::Arith(ArithOp::Div)); let mod_ = parse_op("%", BinOp::Arith(ArithOp::Mod)); let op = alt((mul, div, mod_)); - let recursive = parse_binary(parse_atom, op, parse_mul); - let base = parse_atom; + let recursive = parse_binary(parse_field_access, op, parse_mul); + let base = parse_field_access; alt((recursive, base)).parse(i) } -fn parse_field_access<'a>(i: &'a str) -> IResult<&'a str, Vec> { - let node = tag("node"); - let dot = ws(char('.')); - let fields = separated_list1(ws(char('.')), map(parse_name, str::to_owned)); - map(tuple((node, dot, fields)), |(_, _, fields)| fields)(i) +fn parse_field_access<'a>(i: &'a str) -> IResult<&'a str, Expr> { + let trailing = map(tuple((ws(char('.')), ws(parse_ident))), |(_, i)| i); + let (i, base) = parse_atom(i)?; + + fold_many0( + trailing, + move || base.clone(), + move |acc, new| Expr::FieldAccess(acc.boxed(), new), + )(i) } fn parse_atom<'a>(i: &'a str) -> IResult<&'a str, Expr> { let inner = alt(( - map(parse_field_access, Expr::FieldAccess), map(tag("node"), |_| Expr::Node), map(parse_block, Expr::Block), - map(parse_if, Expr::IfExpr), + map(parse_if, Expr::If), map(parse_call, Expr::Call), map(parse_lit, Expr::Lit), map(parse_ident, Expr::Ident), @@ -217,7 +220,7 @@ fn parse_block<'a>(i: &'a str) -> IResult<&'a str, Block> { delimited(open, statements, close).parse(i) } -fn parse_if<'a>(i: &'a str) -> IResult<&'a str, If> { +fn parse_if<'a>(i: &'a str) -> IResult<&'a str, IfExpr> { let if_ = delimited(multispace0, tag("if"), multispace1); let open = char('('); @@ -231,7 +234,7 @@ fn parse_if<'a>(i: &'a str) -> IResult<&'a str, If> { map( tuple((if_, open, condition, close, then, else_)), - |(_, _, condition, _, then, else_)| If { + |(_, _, condition, _, then, else_)| IfExpr { condition: condition.boxed(), then, else_: else_.unwrap_or_default(), @@ -571,7 +574,7 @@ mod test { assert_eq!(parse_expr(r#" node "#), Ok(("", Expr::Node))); assert_eq!( parse_expr(r#" node.foo "#), - Ok(("", Expr::FieldAccess(vec!["foo".to_owned()]))) + Ok(("", Expr::FieldAccess(Expr::Node.boxed(), "foo".to_owned()))) ); assert_eq!( parse_expr( @@ -581,7 +584,10 @@ mod test { ), Ok(( "", - Expr::FieldAccess(vec!["foo".to_owned(), "bar".to_owned()]) + Expr::FieldAccess( + Expr::FieldAccess(Expr::Node.boxed(), "foo".to_owned()).boxed(), + "bar".to_owned() + ) )) ); } @@ -600,7 +606,7 @@ mod test { ), Ok(( "", - Expr::IfExpr(If { + Expr::If(IfExpr { condition: Expr::Bin( Expr::int(1).boxed(), BinOp::Cmp(CmpOp::Eq), @@ -713,7 +719,7 @@ mod test { parse_statement("if (true) { true; };"), Ok(( "", - Statement::Bare(Expr::IfExpr(If { + Statement::Bare(Expr::If(IfExpr { condition: Expr::true_().boxed(), then: Block { body: vec![Statement::Bare(Expr::true_())] @@ -726,7 +732,7 @@ mod test { parse_expr("if (true) { true; } else { true; }"), Ok(( "", - Expr::IfExpr(If { + Expr::If(IfExpr { condition: Expr::true_().boxed(), then: Block { body: vec![Statement::Bare(Expr::true_())] -- cgit v1.2.3