From 8eb38033e0c615983c4490354dad4abb00031042 Mon Sep 17 00:00:00 2001 From: Akshay Date: Sat, 13 Jul 2024 18:32:41 +0100 Subject: init trawk --- src/ast.rs | 186 ++++++++++++++ src/eval.rs | 764 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 7 + src/main.rs | 47 ++++ src/parser.rs | 689 ++++++++++++++++++++++++++++++++++++++++++++++++++++ src/string.rs | 152 ++++++++++++ 6 files changed, 1845 insertions(+) create mode 100644 src/ast.rs create mode 100644 src/eval.rs create mode 100644 src/lib.rs create mode 100644 src/main.rs create mode 100644 src/parser.rs create mode 100644 src/string.rs (limited to 'src') diff --git a/src/ast.rs b/src/ast.rs new file mode 100644 index 0000000..07b5c39 --- /dev/null +++ b/src/ast.rs @@ -0,0 +1,186 @@ +#[derive(Debug)] +pub struct Program { + pub stanzas: Vec, +} + +impl Program { + pub fn new() -> Self { + Self { + stanzas: Vec::new(), + } + } + + pub fn from_str(mut self, i: &str) -> Result> { + use nom::Finish; + let (remaining_input, stanzas) = crate::parser::parse_file(i).finish()?; + assert!(remaining_input.trim().is_empty(), "{remaining_input}"); + self.stanzas = stanzas; + Ok(self) + } +} + +#[derive(Debug, PartialEq, Eq)] +pub struct Stanza { + pub pattern: Pattern, + pub statements: Block, +} + +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum Pattern { + Begin, + End, + Node(NodePattern), +} + +#[derive(Debug, Eq, PartialEq, Clone)] +pub struct NodePattern { + pub modifier: Modifier, + pub kind: String, +} + +#[derive(Default, Debug, Eq, PartialEq, Clone, Copy)] +pub enum Modifier { + #[default] + Enter, + Leave, +} + +#[derive(Debug, Default, Eq, PartialEq, Clone)] +pub struct Block { + pub body: Vec, +} + +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum Statement { + Bare(Expr), + Declaration(Declaration), +} + +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum Expr { + Node, + Unit, + Lit(Literal), + Ident(Identifier), + // List(Vec), + Bin(Box, BinOp, Box), + Unary(Box, UnaryOp), + Call(Call), + IfExpr(If), + Block(Block), +} + +impl Expr { + pub fn int(int: i128) -> Expr { + Self::Lit(Literal::Int(int)) + } + + pub fn str(s: &str) -> Expr { + Self::Lit(Literal::Str(s.to_owned())) + } + + pub const fn false_() -> Expr { + Self::Lit(Literal::Bool(false)) + } + + pub const fn true_() -> Expr { + Self::Lit(Literal::Bool(true)) + } + + pub fn boxed(self) -> Box { + Box::new(self) + } +} + +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +pub enum UnaryOp { + Not, +} + +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +pub enum BinOp { + Arith(ArithOp), + Cmp(CmpOp), + Logic(LogicOp), + // = + Assign(AssignOp), +} + +// + - * / +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +pub enum ArithOp { + Add, + Sub, + Mul, + Div, + Mod, +} + +// && || +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +pub enum LogicOp { + And, + Or, +} + +// == != > < >= <= +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +pub enum CmpOp { + Eq, + Neq, + Gt, + Lt, + Gte, + Lte, +} + +// =, +=, -=, *=, /= +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +pub struct AssignOp { + pub op: Option, +} + +pub type Identifier = String; + +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum Literal { + Str(String), + Int(i128), + Bool(bool), +} + +/// A function call +#[derive(Debug, Eq, PartialEq, Clone)] +pub struct Call { + pub function: Identifier, + pub parameters: Vec, +} + +impl From for Expr { + fn from(expr: Call) -> Expr { + Expr::Call(expr) + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum Type { + Unit, + Integer, + String, + Boolean, + Node, +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Declaration { + pub ty: Type, + pub name: Identifier, + pub init: Option>, +} + +#[derive(Debug, Eq, PartialEq, Clone)] +pub struct If { + pub condition: Box, + pub then: Block, + pub else_: Block, +} diff --git a/src/eval.rs b/src/eval.rs new file mode 100644 index 0000000..859979d --- /dev/null +++ b/src/eval.rs @@ -0,0 +1,764 @@ +//! tree walking interpreter for trawk + +use crate::ast; +use std::{collections::HashMap, fmt}; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Variable { + pub ty: ast::Type, + pub name: ast::Identifier, + pub value: Value, +} + +impl Variable { + fn value(&self) -> &Value { + &self.value + } + + fn ty(&self) -> ast::Type { + self.ty + } + + fn assign(&mut self, value: Value) -> Result { + if self.ty() == value.ty() { + self.value = value; + Ok(self.value.clone()) + } else { + Err(Error::TypeMismatch { + expected: self.ty(), + got: value.ty(), + }) + } + } +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +pub enum Value { + Unit, + Integer(i128), + String(String), + Boolean(bool), + Node, +} + +impl Value { + fn ty(&self) -> ast::Type { + match self { + Self::Unit => ast::Type::Unit, + Self::Integer(_) => ast::Type::Integer, + Self::String(_) => ast::Type::String, + Self::Boolean(_) => ast::Type::Boolean, + Self::Node => ast::Type::Node, + } + } + + fn default(ty: ast::Type) -> Self { + match ty { + ast::Type::Unit => Self::Unit, + ast::Type::Integer => Self::default_int(), + ast::Type::String => Self::default_string(), + ast::Type::Boolean => Self::default_bool(), + ast::Type::Node => unreachable!(), + } + } + + fn default_int() -> Self { + Self::Integer(0) + } + + fn default_bool() -> Self { + Self::Boolean(false) + } + + fn default_string() -> Self { + Self::String(String::default()) + } + + fn as_boolean(&self) -> Option { + match self { + Self::Boolean(b) => Some(*b), + _ => None, + } + } + + fn add(&self, other: &Self) -> Result { + match (self, other) { + (Self::Integer(s), Self::Integer(o)) => Ok(Self::Integer(*s + *o)), + (Self::String(s), Self::String(o)) => Ok(Self::String(format!("{s}{o}"))), + _ => Err(Error::UndefinedBinOp( + ast::BinOp::Arith(ast::ArithOp::Add), + self.ty(), + other.ty(), + )), + } + } + + fn sub(&self, other: &Self) -> Result { + match (self, other) { + (Self::Integer(s), Self::Integer(o)) => Ok(Self::Integer(*s - *o)), + (Self::String(s), Self::String(o)) => { + Ok(Self::String(s.strip_suffix(o).unwrap_or(s).to_owned())) + } + _ => Err(Error::UndefinedBinOp( + ast::BinOp::Arith(ast::ArithOp::Sub), + self.ty(), + other.ty(), + )), + } + } + + fn mul(&self, other: &Self) -> Result { + match (self, other) { + (Self::Integer(s), Self::Integer(o)) => Ok(Self::Integer(*s * *o)), + _ => Err(Error::UndefinedBinOp( + ast::BinOp::Arith(ast::ArithOp::Mul), + self.ty(), + other.ty(), + )), + } + } + + fn div(&self, other: &Self) -> Result { + match (self, other) { + (Self::Integer(s), Self::Integer(o)) => Ok(Self::Integer(*s / *o)), + _ => Err(Error::UndefinedBinOp( + ast::BinOp::Arith(ast::ArithOp::Div), + self.ty(), + other.ty(), + )), + } + } + + fn mod_(&self, other: &Self) -> Result { + match (self, other) { + (Self::Integer(s), Self::Integer(o)) => Ok(Self::Integer(*s % *o)), + _ => Err(Error::UndefinedBinOp( + ast::BinOp::Arith(ast::ArithOp::Mod), + self.ty(), + other.ty(), + )), + } + } + + fn equals(&self, other: &Self) -> Result { + match (self, other) { + (Self::Integer(s), Self::Integer(o)) => Ok(Self::Boolean(s == o)), + (Self::String(s), Self::String(o)) => Ok(Self::Boolean(s == o)), + (Self::Boolean(s), Self::Boolean(o)) => Ok(Self::Boolean(s == o)), + _ => Err(Error::UndefinedBinOp( + ast::BinOp::Cmp(ast::CmpOp::Eq), + self.ty(), + other.ty(), + )), + } + } + + fn greater_than(&self, other: &Self) -> Result { + match (self, other) { + (Self::Integer(s), Self::Integer(o)) => Ok(Self::Boolean(s > o)), + (Self::String(s), Self::String(o)) => Ok(Self::Boolean(s.cmp(o).is_gt())), + _ => Err(Error::UndefinedBinOp( + ast::BinOp::Cmp(ast::CmpOp::Gt), + self.ty(), + other.ty(), + )), + } + } + + fn less_than(&self, other: &Self) -> Result { + match (self, other) { + (Self::Integer(s), Self::Integer(o)) => Ok(Self::Boolean(s < o)), + (Self::String(s), Self::String(o)) => Ok(Self::Boolean(s.cmp(o).is_lt())), + _ => Err(Error::UndefinedBinOp( + ast::BinOp::Cmp(ast::CmpOp::Lt), + self.ty(), + other.ty(), + )), + } + } + + fn greater_than_equals(&self, other: &Self) -> Result { + match (self, other) { + (Self::Integer(s), Self::Integer(o)) => Ok(Self::Boolean(s >= o)), + (Self::String(s), Self::String(o)) => Ok(Self::Boolean(s.cmp(o).is_ge())), + (Self::Boolean(s), Self::Boolean(o)) => Ok(Self::Boolean(s == o)), + _ => Err(Error::UndefinedBinOp( + ast::BinOp::Cmp(ast::CmpOp::Gte), + self.ty(), + other.ty(), + )), + } + } + + fn less_than_equals(&self, other: &Self) -> Result { + match (self, other) { + (Self::Integer(s), Self::Integer(o)) => Ok(Self::Boolean(s <= o)), + (Self::String(s), Self::String(o)) => Ok(Self::Boolean(s.cmp(o).is_le())), + (Self::Boolean(s), Self::Boolean(o)) => Ok(Self::Boolean(s == o)), + _ => Err(Error::UndefinedBinOp( + ast::BinOp::Cmp(ast::CmpOp::Lte), + self.ty(), + other.ty(), + )), + } + } + + fn not(&self) -> Result { + match self { + Self::Boolean(s) => Ok(Self::Boolean(!s)), + _ => Err(Error::UndefinedUnaryOp(ast::UnaryOp::Not, self.ty())), + } + } + + fn and(&self, other: &Self) -> Result { + match (self, other) { + (Self::Boolean(s), Self::Boolean(o)) => Ok(Self::Boolean(*s && *o)), + _ => Err(Error::UndefinedBinOp( + ast::BinOp::Logic(ast::LogicOp::And), + self.ty(), + other.ty(), + )), + } + } + + fn or(&self, other: &Self) -> Result { + match (self, other) { + (Self::Boolean(s), Self::Boolean(o)) => Ok(Self::Boolean(*s || *o)), + _ => Err(Error::UndefinedBinOp( + ast::BinOp::Logic(ast::LogicOp::Or), + self.ty(), + other.ty(), + )), + } + } +} + +impl fmt::Display for Value { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Unit => write!(f, "()"), + Self::Integer(i) => write!(f, "{i}"), + Self::String(s) => write!(f, "{s}"), + Self::Boolean(b) => write!(f, "{b}"), + Self::Node => write!(f, ""), + } + } +} + +type NodeKind = u16; + +#[derive(Debug, Default)] +struct Visitor { + enter: ast::Block, + leave: ast::Block, +} + +#[derive(Debug)] +struct Visitors { + visitors: HashMap, + begin: ast::Block, + end: ast::Block, +} + +impl Default for Visitors { + fn default() -> Self { + Self::new() + } +} + +impl Visitors { + pub fn new() -> Self { + Self { + visitors: HashMap::new(), + begin: ast::Block { body: vec![] }, + end: ast::Block { body: vec![] }, + } + } + + pub fn insert( + &mut self, + stanza: ast::Stanza, + language: &tree_sitter::Language, + ) -> std::result::Result<(), Error> { + match &stanza.pattern { + ast::Pattern::Begin => self.begin = stanza.statements, + ast::Pattern::End => self.end = stanza.statements, + ast::Pattern::Node(ast::NodePattern { modifier, kind }) => { + let id = language.id_for_node_kind(&kind, true); + if id == 0 { + return Err(Error::InvalidNodeKind(kind.to_owned())); + } + let v = self.visitors.entry(id).or_default(); + match modifier { + ast::Modifier::Enter => v.enter = stanza.statements.clone(), + ast::Modifier::Leave => v.leave = stanza.statements.clone(), + }; + } + } + Ok(()) + } + + pub fn get_by_node(&self, node: tree_sitter::Node) -> Option<&Visitor> { + let node_id = node.kind_id(); + self.visitors.get(&node_id) + } +} + +#[derive(Debug, PartialEq, Eq)] +pub enum Error { + FailedLookup(ast::Identifier), + TypeMismatch { expected: ast::Type, got: ast::Type }, + UndefinedBinOp(ast::BinOp, ast::Type, ast::Type), + UndefinedUnaryOp(ast::UnaryOp, ast::Type), + AlreadyBound(ast::Identifier), + MalformedExpr(String), + InvalidNodeKind(String), + // current node is only set in visitors, not in BEGIN or END blocks + CurrentNodeNotPresent, +} + +type Result = std::result::Result; + +pub struct Context<'a> { + variables: HashMap, + language: tree_sitter::Language, + visitors: Visitors, + input_src: Option, + cursor: Option>, +} + +impl<'a> fmt::Debug for Context<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Context") + .field("variables", &self.variables) + .field("language", &self.language) + .field("visitors", &self.visitors) + .field("input_src", &self.input_src) + .field( + "cursor", + if self.cursor.is_some() { + &"Some()" + } else { + &"None" + }, + ) + .finish() + } +} + +impl<'a> Context<'a> { + pub fn new(language: tree_sitter::Language) -> Self { + Self { + visitors: Default::default(), + variables: Default::default(), + language, + input_src: None, + cursor: None, + } + } + + pub fn with_program(mut self, program: ast::Program) -> std::result::Result { + for stanza in program.stanzas.into_iter() { + self.visitors.insert(stanza, &self.language)?; + } + Ok(self) + } + + pub fn with_input(mut self, src: String) -> Self { + self.input_src = Some(src); + self + } + + pub fn with_cursor(mut self, cursor: tree_sitter::TreeCursor<'a>) -> Self { + self.cursor = Some(cursor); + self + } + + fn eval_expr(&mut self, expr: &ast::Expr) -> Result { + match expr { + ast::Expr::Unit => Ok(Value::Unit), + ast::Expr::Lit(lit) => self.eval_lit(lit), + ast::Expr::Ident(ident) => self.lookup(ident).map(Variable::value).cloned(), + ast::Expr::Bin(lhs, op, rhs) => self.eval_bin(&*lhs, *op, &*rhs), + ast::Expr::Unary(expr, op) => self.eval_unary(&*expr, *op), + ast::Expr::Call(call) => self.eval_call(&*call), + ast::Expr::IfExpr(if_expr) => self.eval_if(if_expr), + ast::Expr::Block(block) => self.eval_block(block), + ast::Expr::Node => Ok(Value::Node), + } + } + + fn eval_lit(&mut self, lit: &ast::Literal) -> Result { + match lit { + ast::Literal::Str(s) => Ok(Value::String(s.to_owned())), + ast::Literal::Int(i) => Ok(Value::Integer(*i)), + ast::Literal::Bool(b) => Ok(Value::Boolean(*b)), + } + } + + fn lookup(&mut self, ident: &ast::Identifier) -> std::result::Result<&Variable, Error> { + self.variables + .get(ident) + .ok_or_else(|| Error::FailedLookup(ident.to_owned())) + } + + fn lookup_mut(&mut self, ident: &ast::Identifier) -> std::result::Result<&mut Variable, Error> { + self.variables + .get_mut(ident) + .ok_or_else(|| Error::FailedLookup(ident.to_owned())) + } + + fn bind( + &mut self, + ident: &ast::Identifier, + ty: ast::Type, + ) -> std::result::Result<&mut Variable, Error> { + if self.lookup(ident).is_err() { + Ok(self + .variables + .entry(ident.to_owned()) + .or_insert_with(|| Variable { + name: ident.to_owned(), + value: Value::default(ty), + ty, + })) + } else { + Err(Error::AlreadyBound(ident.to_owned())) + } + } + + fn eval_bin(&mut self, lhs: &ast::Expr, op: ast::BinOp, rhs: &ast::Expr) -> Result { + match op { + ast::BinOp::Assign(op) => self.eval_assign(lhs, op, rhs), + ast::BinOp::Arith(op) => self.eval_arith(lhs, op, rhs), + ast::BinOp::Cmp(op) => self.eval_cmp(lhs, op, rhs), + ast::BinOp::Logic(op) => self.eval_logic(lhs, op, rhs), + } + } + + fn eval_assign( + &mut self, + lhs: &ast::Expr, + ast::AssignOp { op }: ast::AssignOp, + rhs: &ast::Expr, + ) -> Result { + let ast::Expr::Ident(ident) = lhs else { + return Err(Error::MalformedExpr(format!( + "malformed assigment, lhs: {:?}", + lhs + ))); + }; + let value = self.eval_expr(rhs)?; + let variable = self.lookup_mut(ident)?; + match op { + None => variable.assign(value), + Some(ast::ArithOp::Add) => variable.assign(variable.value().add(&value)?), + Some(ast::ArithOp::Sub) => variable.assign(variable.value().sub(&value)?), + Some(ast::ArithOp::Mul) => variable.assign(variable.value().mul(&value)?), + Some(ast::ArithOp::Div) => variable.assign(variable.value().div(&value)?), + Some(ast::ArithOp::Mod) => variable.assign(variable.value().mod_(&value)?), + } + } + + fn eval_arith(&mut self, lhs: &ast::Expr, op: ast::ArithOp, rhs: &ast::Expr) -> Result { + let l = self.eval_expr(lhs)?; + let r = self.eval_expr(rhs)?; + match op { + ast::ArithOp::Add => l.add(&r), + ast::ArithOp::Sub => l.sub(&r), + ast::ArithOp::Mul => l.mul(&r), + ast::ArithOp::Div => l.div(&r), + ast::ArithOp::Mod => l.mod_(&r), + } + } + + fn eval_cmp(&mut self, lhs: &ast::Expr, op: ast::CmpOp, rhs: &ast::Expr) -> Result { + let l = self.eval_expr(lhs)?; + let r = self.eval_expr(rhs)?; + + match op { + ast::CmpOp::Eq => l.equals(&r), + ast::CmpOp::Gt => l.greater_than(&r), + ast::CmpOp::Lt => l.less_than(&r), + ast::CmpOp::Neq => l.equals(&r).and_then(|v| v.not()), + ast::CmpOp::Gte => l.greater_than_equals(&r), + ast::CmpOp::Lte => l.less_than_equals(&r), + } + } + + fn eval_logic(&mut self, lhs: &ast::Expr, op: ast::LogicOp, rhs: &ast::Expr) -> Result { + let l = self.eval_expr(lhs)?; + + // short-circuit + let l_value = l.as_boolean().ok_or_else(|| Error::TypeMismatch { + expected: ast::Type::Boolean, + got: l.ty(), + })?; + + match op { + ast::LogicOp::Or => { + if l_value { + return Ok(l); + } else { + let r = self.eval_expr(rhs)?; + l.or(&r) + } + } + ast::LogicOp::And => { + if !l_value { + return Ok(l); + } else { + let r = self.eval_expr(rhs)?; + l.and(&r) + } + } + } + } + + fn eval_unary(&mut self, expr: &ast::Expr, op: ast::UnaryOp) -> Result { + let val = self.eval_expr(expr)?; + match op { + ast::UnaryOp::Not => val.not(), + } + } + + fn eval_if(&mut self, if_expr: &ast::If) -> Result { + let cond = self.eval_expr(&if_expr.condition)?; + + if cond.as_boolean().ok_or_else(|| Error::TypeMismatch { + expected: ast::Type::Boolean, + got: cond.ty(), + })? { + self.eval_block(&if_expr.then) + } else { + self.eval_block(&if_expr.else_) + } + } + + fn eval_call(&mut self, call: &ast::Call) -> Result { + match (call.function.as_str(), call.parameters.as_slice()) { + ("print", args) => { + for arg in args { + let val = self.eval_expr(arg)?; + print!("{val}"); + } + Ok(Value::Unit) + } + ("text", [arg]) if self.eval_expr(arg)? == Value::Node => { + let node = self + .cursor + .as_ref() + .ok_or(Error::CurrentNodeNotPresent)? + .node(); + let text = node + .utf8_text(self.input_src.as_ref().unwrap().as_bytes()) + .unwrap(); + Ok(Value::String(text.to_owned())) + } + (s, _) => Err(Error::FailedLookup(s.to_owned())), + } + } + + fn eval_declaration(&mut self, decl: &ast::Declaration) -> Result { + let initial_value = match decl.init.as_ref() { + Some(init) => Some(self.eval_expr(&*init)?), + None => None, + }; + let variable = self.bind(&decl.name, decl.ty)?; + + if let Some(init) = initial_value { + variable.assign(init)?; + } + + Ok(Value::Unit) + } + + fn eval_statement(&mut self, stmt: &ast::Statement) -> Result { + match stmt { + ast::Statement::Bare(expr) => self.eval_expr(expr).map(|_| Value::Unit), + ast::Statement::Declaration(decl) => self.eval_declaration(decl), + } + } + + fn eval_block(&mut self, block: &ast::Block) -> Result { + for stmt in block.body.iter() { + self.eval_statement(stmt)?; + } + Ok(Value::Unit) + } + + pub fn eval(&mut self) -> Result { + let visitors = std::mem::take(&mut self.visitors); + let mut has_next = true; + let mut postorder = Vec::new(); + + // BEGIN block + self.eval_block(&visitors.begin)?; + + while has_next { + let current_node = self.cursor.as_mut().unwrap().node(); + postorder.push(current_node); + + let visitor = visitors.get_by_node(current_node); + + visitor.map(|v| self.eval_block(&v.enter)); + + has_next = self.cursor.as_mut().unwrap().goto_first_child(); + + if !has_next { + has_next = self.cursor.as_mut().unwrap().goto_next_sibling(); + postorder + .pop() + .and_then(|n| visitors.get_by_node(n)) + .map(|v| self.eval_block(&v.leave)); + } + + while !has_next && self.cursor.as_mut().unwrap().goto_parent() { + has_next = self.cursor.as_mut().unwrap().goto_next_sibling(); + postorder + .pop() + .and_then(|n| visitors.get_by_node(n)) + .map(|v| self.eval_block(&v.leave)); + } + } + + // END block + self.eval_block(&visitors.end)?; + + Ok(Value::Unit) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn bin() { + let language = tree_sitter_python::language(); + let mut ctx = Context::new(language) + .with_program(ast::Program::new()) + .unwrap(); + assert_eq!( + ctx.eval_expr(&ast::Expr::Bin( + ast::Expr::int(5).boxed(), + ast::BinOp::Arith(ast::ArithOp::Add), + ast::Expr::int(10).boxed(), + )), + Ok(Value::Integer(15)) + ); + assert_eq!( + ctx.eval_expr(&ast::Expr::Bin( + ast::Expr::int(5).boxed(), + ast::BinOp::Cmp(ast::CmpOp::Eq), + ast::Expr::int(10).boxed(), + )), + Ok(Value::Boolean(false)) + ); + assert_eq!( + ctx.eval_expr(&ast::Expr::Bin( + ast::Expr::int(5).boxed(), + ast::BinOp::Cmp(ast::CmpOp::Lt), + ast::Expr::int(10).boxed(), + )), + Ok(Value::Boolean(true)) + ); + assert_eq!( + ctx.eval_expr(&ast::Expr::Bin( + ast::Expr::Bin( + ast::Expr::int(5).boxed(), + ast::BinOp::Cmp(ast::CmpOp::Lt), + ast::Expr::int(10).boxed(), + ) + .boxed(), + ast::BinOp::Logic(ast::LogicOp::And), + ast::Expr::false_().boxed() + )), + Ok(Value::Boolean(false)) + ); + } + + #[test] + fn test_evaluate_blocks() { + let language = tree_sitter_python::language(); + let mut ctx = Context::new(language) + .with_program(ast::Program::new()) + .unwrap(); + assert_eq!( + ctx.eval_block(&ast::Block { + body: vec![ + ast::Statement::Declaration(ast::Declaration { + ty: ast::Type::Integer, + name: "a".to_owned(), + init: None, + }), + ast::Statement::Bare(ast::Expr::Bin( + ast::Expr::Ident("a".to_owned()).boxed(), + ast::BinOp::Assign(ast::AssignOp { + op: Some(ast::ArithOp::Add) + }), + ast::Expr::int(5).boxed() + )), + ] + }), + Ok(Value::Unit) + ); + assert_eq!( + ctx.lookup(&String::from("a")).unwrap().clone(), + Variable { + ty: ast::Type::Integer, + name: "a".to_owned(), + value: Value::Integer(5) + } + ); + } + + #[test] + fn test_evaluate_if() { + let language = tree_sitter_python::language(); + let mut ctx = Context::new(language) + .with_program(ast::Program::new()) + .unwrap(); + assert_eq!( + ctx.eval_block(&ast::Block { + body: vec![ + ast::Statement::Declaration(ast::Declaration { + ty: ast::Type::Integer, + name: "a".to_owned(), + init: Some(ast::Expr::int(1).boxed()), + }), + ast::Statement::Bare(ast::Expr::IfExpr(ast::If { + condition: ast::Expr::true_().boxed(), + then: ast::Block { + body: vec![ast::Statement::Bare(ast::Expr::Bin( + ast::Expr::Ident("a".to_owned()).boxed(), + ast::BinOp::Assign(ast::AssignOp { + op: Some(ast::ArithOp::Add) + }), + ast::Expr::int(5).boxed() + ))] + }, + else_: ast::Block { + body: vec![ast::Statement::Bare(ast::Expr::Bin( + ast::Expr::Ident("a".to_owned()).boxed(), + ast::BinOp::Assign(ast::AssignOp { + op: Some(ast::ArithOp::Add) + }), + ast::Expr::int(10).boxed() + ))] + } + })) + ] + }), + Ok(Value::Unit) + ); + assert_eq!( + ctx.lookup(&String::from("a")).unwrap().clone(), + Variable { + ty: ast::Type::Integer, + name: "a".to_owned(), + value: Value::Integer(6) + } + ); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..8780b74 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,7 @@ +mod ast; +mod eval; +mod parser; +mod string; + +pub use ast::Program; +pub use eval::Context; diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..09a15ef --- /dev/null +++ b/src/main.rs @@ -0,0 +1,47 @@ +use trawk::{Context, Program}; + +fn main() { + let src = r#" +bar = 0 +def foo(): + baz = 5 + "# + .to_owned(); + + let program = Program::new() + .from_str( + r#" + BEGIN { + bool in_def = false; + } + pre function_definition { + in_def = true; + } + post function_definition { + in_def = false; + } + pre identifier { + if (in_def) { + print(text(node)); + print(" "); + print("in def\n"); + } else { + }; + }"#, + ) + .unwrap(); + + let mut parser = tree_sitter::Parser::new(); + let _ = parser.set_language(tree_sitter_python::language()); + + let tree = parser.parse(&src, None).unwrap(); + let cursor = tree.walk(); + + let mut ctx = Context::new(tree_sitter_python::language()) + .with_input(src) + .with_cursor(cursor) + .with_program(program) + .unwrap(); + + let _ = ctx.eval(); +} diff --git a/src/parser.rs b/src/parser.rs new file mode 100644 index 0000000..3a020dc --- /dev/null +++ b/src/parser.rs @@ -0,0 +1,689 @@ +use nom::{ + branch::alt, + bytes::complete::tag, + character::complete::{alpha1, alphanumeric1, char, multispace0, multispace1, one_of}, + combinator::{map, opt, recognize, value}, + error::ParseError, + multi::{many0, many0_count, many1, separated_list0}, + sequence::{delimited, pair, preceded, terminated, tuple}, + IResult, Parser, +}; +// use tree_sitter::Query; + +use crate::ast::*; +use crate::string::parse_string; + +fn ws<'a, F: 'a, O, E>(inner: F) -> impl FnMut(&'a str) -> IResult<&'a str, O, E> +where + F: FnMut(&'a str) -> IResult<&'a str, O, E>, + E: ParseError<&'a str>, +{ + delimited(multispace0, inner, multispace0) +} + +fn parse_unit<'a>(i: &'a str) -> IResult<&'a str, ()> { + let open = char('('); + let close = char(')'); + let unit = tuple((open, close)); + value((), unit)(i) +} + +fn parse_bool(i: &str) -> IResult<&str, bool> { + let t = value(true, tag("true")); + let f = value(false, tag("false")); + alt((t, f)).parse(i) +} + +fn parse_int<'a>(i: &'a str) -> IResult<&'a str, i128> { + map(recognize(many1(one_of("0123456789"))), |s: &str| { + s.parse::().unwrap() + })(i) +} + +fn parse_name(i: &str) -> IResult<&str, &str> { + recognize(pair( + alt((alpha1, tag("_"))), + many0_count(alt((alphanumeric1, tag("_")))), + )) + .parse(i) +} + +fn parse_ident(i: &str) -> IResult<&str, Identifier> { + map(parse_name, str::to_owned)(i) +} + +fn parse_lit<'a>(i: &'a str) -> IResult<&'a str, Literal> { + alt(( + map(parse_string, Literal::Str), + map(parse_int, Literal::Int), + map(parse_bool, Literal::Bool), + )) + .parse(i) +} + +fn parse_cmp_op(i: &str) -> IResult<&str, CmpOp> { + alt(( + value(CmpOp::Eq, tag("==")), + value(CmpOp::Neq, tag("!=")), + value(CmpOp::Gte, tag(">=")), + value(CmpOp::Lte, tag("<=")), + value(CmpOp::Gt, tag(">")), + value(CmpOp::Lt, tag("<")), + )) + .parse(i) +} + +fn parse_assign_op(i: &str) -> IResult<&str, AssignOp> { + let parse_arith_op = alt(( + value(ArithOp::Add, char('+')), + value(ArithOp::Sub, char('-')), + value(ArithOp::Mul, char('*')), + value(ArithOp::Div, char('/')), + value(ArithOp::Mod, char('%')), + )); + map(tuple((opt(parse_arith_op), char('='))), |(op, _)| { + AssignOp { op } + })(i) +} + +fn parse_op<'a, E, T>( + op_str: &'static str, + op: T, +) -> impl FnMut(&'a str) -> Result<(&'a str, T), nom::Err> +where + E: ParseError<&'a str>, + T: Copy, +{ + value(op, tag(op_str)) +} + +fn parse_binary<'a, P1, P2, P3, E>( + lhs: P1, + op: P2, + rhs: P3, +) -> impl FnMut(&'a str) -> Result<(&'a str, Expr), nom::Err> +where + P1: Parser<&'a str, Expr, E>, + P2: Parser<&'a str, BinOp, E>, + P3: Parser<&'a str, Expr, E>, + E: ParseError<&'a str>, +{ + map(tuple((lhs, op, rhs)), |(l, o, r)| { + Expr::Bin(l.boxed(), o, r.boxed()) + }) +} + +fn parse_assign<'a>(i: &'a str) -> IResult<&'a str, Expr> { + let op = map(parse_assign_op, BinOp::Assign); + let recursive = parse_binary(parse_atom, op, parse_assign); + let base = parse_union; + alt((recursive, base)).parse(i) +} + +fn parse_union<'a>(i: &'a str) -> IResult<&'a str, Expr> { + let op = parse_op("||", BinOp::Logic(LogicOp::Or)); + let recursive = parse_binary(parse_intersection, op, parse_union); + let base = parse_intersection; + alt((recursive, base)).parse(i) +} + +fn parse_intersection<'a>(i: &'a str) -> IResult<&'a str, Expr> { + let op = parse_op("&&", BinOp::Logic(LogicOp::And)); + let recursive = parse_binary(parse_negated, op, parse_intersection); + let base = parse_negated; + alt((recursive, base)).parse(i) +} + +fn parse_negated<'a>(i: &'a str) -> IResult<&'a str, Expr> { + let op = parse_op("!", UnaryOp::Not); + let recursive = map(tuple((op, parse_rel)), |(op, expr)| { + Expr::Unary(expr.boxed(), op) + }); + let base = parse_rel; + alt((recursive, base)).parse(i) +} + +fn parse_rel<'a>(i: &'a str) -> IResult<&'a str, Expr> { + let op = map(parse_cmp_op, BinOp::Cmp); + let recursive = parse_binary(parse_sum, op, parse_rel); + let base = parse_sum; + alt((recursive, base)).parse(i) +} + +fn parse_sum<'a>(i: &'a str) -> IResult<&'a str, Expr> { + let add = parse_op("+", BinOp::Arith(ArithOp::Add)); + let sub = parse_op("-", BinOp::Arith(ArithOp::Sub)); + let op = alt((add, sub)); + let recursive = parse_binary(parse_mul, op, parse_sum); + let base = parse_mul; + alt((recursive, base)).parse(i) +} + +fn parse_mul<'a>(i: &'a str) -> IResult<&'a str, Expr> { + let mul = parse_op("*", BinOp::Arith(ArithOp::Mul)); + let div = parse_op("/", BinOp::Arith(ArithOp::Div)); + let mod_ = parse_op("%", BinOp::Arith(ArithOp::Mod)); + let op = alt((mul, div, mod_)); + let recursive = parse_binary(parse_atom, op, parse_mul); + let base = parse_atom; + alt((recursive, base)).parse(i) +} + +fn parse_atom<'a>(i: &'a str) -> IResult<&'a str, Expr> { + let inner = alt(( + map(tag("node"), |_| Expr::Node), + map(parse_block, Expr::Block), + map(parse_if, Expr::IfExpr), + map(parse_call, Expr::Call), + map(parse_lit, Expr::Lit), + map(parse_ident, Expr::Ident), + map(parse_unit, |_| Expr::Unit), + )); + ws(inner).parse(i) +} + +fn parse_call<'a>(i: &'a str) -> IResult<&'a str, Call> { + let ident = parse_ident; + let open = ws(char('(')); + let args = separated_list0(char(','), parse_expr); + let close = ws(char(')')); + map( + tuple((ident, open, args, close)), + |(function, _, parameters, _)| Call { + function, + parameters, + }, + ) + .parse(i) +} + +fn parse_block<'a>(i: &'a str) -> IResult<&'a str, Block> { + let open = ws(char('{')); + let statements = map(many0(parse_statement), |body| Block { body }); + let close = ws(char('}')); + delimited(open, statements, close).parse(i) +} + +fn parse_if<'a>(i: &'a str) -> IResult<&'a str, If> { + let if_ = delimited(multispace0, tag("if"), multispace1); + + let open = char('('); + let condition = ws(parse_expr); + let close = terminated(char(')'), multispace0); + + let then = parse_block; + + let else_kw = ws(tag("else")); + let else_ = opt(preceded(else_kw, parse_block)); + + map( + tuple((if_, open, condition, close, then, else_)), + |(_, _, condition, _, then, else_)| If { + condition: condition.boxed(), + then, + else_: else_.unwrap_or_default(), + }, + )(i) +} + +fn parse_expr<'a>(i: &'a str) -> IResult<&'a str, Expr> { + parse_assign.parse(i) +} + +fn parse_bare<'a>(i: &'a str) -> IResult<&'a str, Expr> { + parse_expr(i) +} + +fn parse_type<'a>(i: &'a str) -> IResult<&'a str, Type> { + let int = value(Type::Integer, tag("int")); + let string = value(Type::String, tag("string")); + let bool_ = value(Type::Boolean, tag("bool")); + alt((int, string, bool_)).parse(i) +} + +fn parse_declaration<'a>(i: &'a str) -> IResult<&'a str, Declaration> { + let ty = parse_type; + let name = parse_ident; + let op = ws(char('=')); + let init = opt(preceded(op, map(parse_expr, Expr::boxed))); + map( + tuple((ty, multispace0, name, init)), + |(ty, _, name, init)| Declaration { ty, name, init }, + )(i) +} + +fn parse_statement<'a>(i: &'a str) -> IResult<&'a str, Statement> { + let semicolon = ws(char(';')); + let inner = alt(( + map(parse_declaration, Statement::Declaration), + map(parse_bare, Statement::Bare), + )); + terminated(inner, semicolon).parse(i) +} + +// pub fn skip_query(mut i: &str) -> IResult<&str, ()> { +// let mut paren_depth = 0; +// let mut in_string = false; +// let mut in_escape = false; +// let mut in_comment = false; +// loop { +// let ch = i +// .chars() +// .next() +// .ok_or(nom::Err::Error(nom::error::Error::new( +// i, +// nom::error::ErrorKind::Eof, +// )))?; +// if in_escape { +// in_escape = false; +// } else if in_string { +// match ch { +// '\\' => { +// in_escape = true; +// } +// '"' | '\n' => { +// in_string = false; +// } +// _ => {} +// } +// } else if in_comment { +// if ch == '\n' { +// in_comment = false; +// } +// } else { +// match ch { +// '"' => in_string = true, +// '(' => paren_depth += 1, +// ')' => { +// if paren_depth > 0 { +// paren_depth -= 1; +// } +// } +// '{' => return Ok((i, ())), +// ';' => in_comment = true, +// _ => {} +// } +// } +// i = &i[1..]; +// } +// } + +// fn parse_query<'a>( +// language: tree_sitter::Language, +// ) -> impl FnMut(&'a str) -> IResult<&'a str, Query> { +// return move |initial: &'a str| { +// let query_start = 0; +// let (skipped, _) = skip_query(initial)?; +// let query_end = initial.len() - skipped.len(); +// let query_source = &initial[query_start..query_end].to_owned(); +// +// let query = Query::new(language, &query_source).map_err(|mut _e| { +// nom::Err::Error(nom::error::Error::new(initial, nom::error::ErrorKind::Fail)) +// })?; +// Ok((skipped, query)) +// }; +// } + +fn parse_modifier<'a>(i: &str) -> IResult<&str, Modifier> { + let pre = value(Modifier::Enter, tag("enter")); + let post = value(Modifier::Leave, tag("leave")); + map(opt(alt((pre, post))), Option::unwrap_or_default)(i) +} + +fn parse_pattern<'a>(i: &str) -> IResult<&str, Pattern> { + let begin = value(Pattern::Begin, ws(tag("BEGIN"))); + let end = value(Pattern::End, ws(tag("END"))); + let node = map( + tuple((parse_modifier, multispace0, parse_ident)), + |(modifier, _, kind)| Pattern::Node(NodePattern { modifier, kind }), + ); + alt((begin, end, node)).parse(i) +} + +pub fn parse_stanza<'a>(i: &str) -> IResult<&str, Stanza> { + map( + tuple((parse_pattern, parse_block)), + |(pattern, statements)| Stanza { + pattern, + statements, + }, + )(i) +} + +pub fn parse_file(i: &str) -> IResult<&str, Vec> { + many0(parse_stanza).parse(i) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_parse_unit() { + assert_eq!(parse_unit("()"), Ok(("", ()))) + } + + #[test] + fn test_parse_int() { + assert_eq!(parse_int("123456"), Ok(("", 123456))); + assert_eq!(parse_int("00123456"), Ok(("", 123456))); + } + + #[test] + fn test_parse_bool() { + assert_eq!(parse_bool("true"), Ok(("", true))); + assert_eq!(parse_bool("false"), Ok(("", false))); + } + + #[test] + fn test_parse_name() { + assert_eq!(parse_name("true"), Ok(("", "true"))); + assert_eq!(parse_name("_abc"), Ok(("", "_abc"))); + } + + #[test] + fn test_parse_literal() { + assert_eq!( + parse_lit(r#""foobarbaz""#), + Ok(("", Literal::Str("foobarbaz".to_owned()))) + ); + assert_eq!(parse_lit("123"), Ok(("", Literal::Int(123)))); + assert_eq!(parse_lit("true"), Ok(("", Literal::Bool(true)))); + } + + #[test] + fn test_parse_expr() { + assert_eq!(parse_expr(" () "), Ok(("", Expr::Unit))); + assert_eq!(parse_expr(" 55 "), Ok(("", Expr::int(55)))); + assert_eq!( + parse_expr(" true || true "), + Ok(( + "", + Expr::Bin( + Expr::true_().boxed(), + BinOp::Logic(LogicOp::Or), + Expr::true_().boxed() + ) + )) + ); + assert_eq!( + parse_expr("true || false && 5 == 5 "), + Ok(( + "", + Expr::Bin( + Expr::true_().boxed(), + BinOp::Logic(LogicOp::Or), + Expr::Bin( + Expr::false_().boxed(), + BinOp::Logic(LogicOp::And), + Expr::Bin( + Expr::int(5).boxed(), + BinOp::Cmp(CmpOp::Eq), + Expr::int(5).boxed(), + ) + .boxed() + ) + .boxed() + ) + )) + ); + assert_eq!( + parse_expr(" foo ( 1, 2,3 , 1 == 1)"), + Ok(( + "", + Expr::Call(Call { + function: "foo".to_owned(), + parameters: vec![ + Expr::int(1), + Expr::int(2), + Expr::int(3), + Expr::Bin( + Expr::int(1).boxed(), + BinOp::Cmp(CmpOp::Eq), + Expr::int(1).boxed() + ) + ], + }) + )) + ); + assert_eq!( + parse_expr("a = b"), + Ok(( + "", + Expr::Bin( + Expr::Ident("a".to_owned()).boxed(), + BinOp::Assign(AssignOp { op: None }), + Expr::Ident("b".to_owned()).boxed(), + ) + )) + ); + assert_eq!( + parse_expr(" a += 4 + 5"), + Ok(( + "", + Expr::Bin( + Expr::Ident("a".to_owned()).boxed(), + BinOp::Assign(AssignOp { + op: Some(ArithOp::Add) + }), + Expr::Bin( + Expr::int(4).boxed(), + BinOp::Arith(ArithOp::Add), + Expr::int(5).boxed(), + ) + .boxed() + ) + )) + ); + } + + #[test] + fn test_parse_statement() { + assert_eq!( + parse_statement("true;"), + Ok(("", Statement::Bare(Expr::true_()))) + ); + assert_eq!( + parse_statement("true ; "), + Ok(("", Statement::Bare(Expr::true_()))) + ); + assert_eq!( + parse_statement("int a ; "), + Ok(( + "", + Statement::Declaration(Declaration { + ty: Type::Integer, + name: "a".to_owned(), + init: None + }) + )) + ); + assert_eq!( + parse_statement("int a =5 ; "), + Ok(( + "", + Statement::Declaration(Declaration { + ty: Type::Integer, + name: "a".to_owned(), + init: Some(Expr::int(5).boxed()) + }) + )) + ); + } + + #[test] + fn test_parse_block() { + assert_eq!( + parse_expr( + r#" + { + true; + 1; + } + "# + ), + Ok(( + "", + Expr::Block(Block { + body: vec![ + Statement::Bare(Expr::true_()), + Statement::Bare(Expr::int(1)), + ] + }) + )) + ); + } + + #[test] + fn test_parse_if() { + assert_eq!( + parse_expr( + r#" + if (1 == true) { + 5; + } else { + 10; + } + "# + ), + Ok(( + "", + Expr::IfExpr(If { + condition: Expr::Bin( + Expr::int(1).boxed(), + BinOp::Cmp(CmpOp::Eq), + Expr::true_().boxed() + ) + .boxed(), + then: Block { + body: vec![Statement::Bare(Expr::int(5)),] + }, + else_: Block { + body: vec![Statement::Bare(Expr::int(10)),] + } + }) + )) + ); + } + + // #[test] + // fn test_skip_query() { + // assert_eq!( + // skip_query( + // r#"(heading + // (paragraph) @foo) {}"# + // ), + // Ok(("{}", ())) + // ); + // } + + #[test] + fn test_parse_pattern() { + assert_eq!( + parse_pattern("enter function_definition"), + Ok(( + "", + Pattern::Node(NodePattern { + modifier: Modifier::Enter, + kind: "function_definition".to_owned() + }) + )) + ); + assert_eq!( + parse_pattern("function_definition"), + Ok(( + "", + Pattern::Node(NodePattern { + modifier: Modifier::Enter, + kind: "function_definition".to_owned() + }) + )) + ); + assert_eq!( + parse_pattern("leave function_definition"), + Ok(( + "", + Pattern::Node(NodePattern { + modifier: Modifier::Leave, + kind: "function_definition".to_owned() + }) + )) + ); + } + + #[test] + fn test_parse_stanza() { + assert_eq!( + parse_stanza("enter function_definition { true; }"), + Ok(( + "", + Stanza { + pattern: Pattern::Node(NodePattern { + modifier: Modifier::Enter, + kind: "function_definition".to_owned() + }), + statements: Block { + body: vec![Statement::Bare(Expr::true_())] + } + } + )) + ); + assert_eq!( + parse_stanza("BEGIN { true; }"), + Ok(( + "", + Stanza { + pattern: Pattern::Begin, + statements: Block { + body: vec![Statement::Bare(Expr::true_())] + } + } + )) + ); + assert_eq!( + parse_block( + " { + true; + }" + ), + Ok(( + "", + Block { + body: vec![Statement::Bare(Expr::true_())] + } + )) + ); + } + + #[test] + fn test_parse_if_statement_regression() { + assert_eq!( + parse_statement("if (true) { true; };"), + Ok(( + "", + Statement::Bare(Expr::IfExpr(If { + condition: Expr::true_().boxed(), + then: Block { + body: vec![Statement::Bare(Expr::true_())] + }, + else_: Block::default(), + })) + )) + ); + assert_eq!( + parse_expr("if (true) { true; } else { true; }"), + Ok(( + "", + Expr::IfExpr(If { + condition: Expr::true_().boxed(), + then: Block { + body: vec![Statement::Bare(Expr::true_())] + }, + else_: Block { + body: vec![Statement::Bare(Expr::true_())] + }, + }) + )) + ); + } +} diff --git a/src/string.rs b/src/string.rs new file mode 100644 index 0000000..820f9ff --- /dev/null +++ b/src/string.rs @@ -0,0 +1,152 @@ +use nom::branch::alt; +use nom::bytes::streaming::{is_not, take_while_m_n}; +use nom::character::streaming::{char, multispace1}; +use nom::combinator::{map, map_opt, map_res, value, verify}; +use nom::error::{FromExternalError, ParseError}; +use nom::multi::fold_many0; +use nom::sequence::{delimited, preceded}; +use nom::{IResult, Parser}; + +// parser combinators are constructed from the bottom up: +// first we write parsers for the smallest elements (escaped characters), +// then combine them into larger parsers. + +/// Parse a unicode sequence, of the form u{XXXX}, where XXXX is 1 to 6 +/// hexadecimal numerals. We will combine this later with parse_escaped_char +/// to parse sequences like \u{00AC}. +fn parse_unicode<'a, E>(input: &'a str) -> IResult<&'a str, char, E> +where + E: ParseError<&'a str> + FromExternalError<&'a str, std::num::ParseIntError>, +{ + // `take_while_m_n` parses between `m` and `n` bytes (inclusive) that match + // a predicate. `parse_hex` here parses between 1 and 6 hexadecimal numerals. + let parse_hex = take_while_m_n(1, 6, |c: char| c.is_ascii_hexdigit()); + + // `preceded` takes a prefix parser, and if it succeeds, returns the result + // of the body parser. In this case, it parses u{XXXX}. + let parse_delimited_hex = preceded( + char('u'), + // `delimited` is like `preceded`, but it parses both a prefix and a suffix. + // It returns the result of the middle parser. In this case, it parses + // {XXXX}, where XXXX is 1 to 6 hex numerals, and returns XXXX + delimited(char('{'), parse_hex, char('}')), + ); + + // `map_res` takes the result of a parser and applies a function that returns + // a Result. In this case we take the hex bytes from parse_hex and attempt to + // convert them to a u32. + let parse_u32 = map_res(parse_delimited_hex, move |hex| u32::from_str_radix(hex, 16)); + + // map_opt is like map_res, but it takes an Option instead of a Result. If + // the function returns None, map_opt returns an error. In this case, because + // not all u32 values are valid unicode code points, we have to fallibly + // convert to char with from_u32. + map_opt(parse_u32, std::char::from_u32).parse(input) +} + +/// Parse an escaped character: \n, \t, \r, \u{00AC}, etc. +fn parse_escaped_char<'a, E>(input: &'a str) -> IResult<&'a str, char, E> +where + E: ParseError<&'a str> + FromExternalError<&'a str, std::num::ParseIntError>, +{ + preceded( + char('\\'), + // `alt` tries each parser in sequence, returning the result of + // the first successful match + alt(( + parse_unicode, + // The `value` parser returns a fixed value (the first argument) if its + // parser (the second argument) succeeds. In these cases, it looks for + // the marker characters (n, r, t, etc) and returns the matching + // character (\n, \r, \t, etc). + value('\n', char('n')), + value('\r', char('r')), + value('\t', char('t')), + value('\u{08}', char('b')), + value('\u{0C}', char('f')), + value('\\', char('\\')), + value('/', char('/')), + value('"', char('"')), + )), + ) + .parse(input) +} + +/// Parse a backslash, followed by any amount of whitespace. This is used later +/// to discard any escaped whitespace. +fn parse_escaped_whitespace<'a, E: ParseError<&'a str>>( + input: &'a str, +) -> IResult<&'a str, &'a str, E> { + preceded(char('\\'), multispace1).parse(input) +} + +/// Parse a non-empty block of text that doesn't include \ or " +fn parse_literal<'a, E: ParseError<&'a str>>(input: &'a str) -> IResult<&'a str, &'a str, E> { + // `is_not` parses a string of 0 or more characters that aren't one of the + // given characters. + let not_quote_slash = is_not("\"\\"); + + // `verify` runs a parser, then runs a verification function on the output of + // the parser. The verification function accepts out output only if it + // returns true. In this case, we want to ensure that the output of is_not + // is non-empty. + verify(not_quote_slash, |s: &str| !s.is_empty()).parse(input) +} + +/// A string fragment contains a fragment of a string being parsed: either +/// a non-empty Literal (a series of non-escaped characters), a single +/// parsed escaped character, or a block of escaped whitespace. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum StringFragment<'a> { + Literal(&'a str), + EscapedChar(char), + EscapedWS, +} + +/// Combine parse_literal, parse_escaped_whitespace, and parse_escaped_char +/// into a StringFragment. +fn parse_fragment<'a, E>(input: &'a str) -> IResult<&'a str, StringFragment<'a>, E> +where + E: ParseError<&'a str> + FromExternalError<&'a str, std::num::ParseIntError>, +{ + alt(( + // The `map` combinator runs a parser, then applies a function to the output + // of that parser. + map(parse_literal, StringFragment::Literal), + map(parse_escaped_char, StringFragment::EscapedChar), + value(StringFragment::EscapedWS, parse_escaped_whitespace), + )) + .parse(input) +} + +/// Parse a string. Use a loop of parse_fragment and push all of the fragments +/// into an output string. +pub fn parse_string<'a, E>(input: &'a str) -> IResult<&'a str, String, E> +where + E: ParseError<&'a str> + FromExternalError<&'a str, std::num::ParseIntError>, +{ + // fold is the equivalent of iterator::fold. It runs a parser in a loop, + // and for each output value, calls a folding function on each output value. + let build_string = fold_many0( + // Our parser function – parses a single string fragment + parse_fragment, + // Our init value, an empty string + String::new, + // Our folding function. For each fragment, append the fragment to the + // string. + |mut string, fragment| { + match fragment { + StringFragment::Literal(s) => string.push_str(s), + StringFragment::EscapedChar(c) => string.push(c), + StringFragment::EscapedWS => {} + } + string + }, + ); + + // Finally, parse the string. Note that, if `build_string` could accept a raw + // " character, the closing delimiter " would never match. When using + // `delimited` with a looping parser (like fold), be sure that the + // loop won't accidentally match your closing delimiter! + delimited(char('"'), build_string, char('"')).parse(input) +} -- cgit v1.2.3