From 1702f955a4546828cd535be6cecad57b90128de8 Mon Sep 17 00:00:00 2001 From: Akshay Date: Fri, 23 Aug 2024 23:00:52 +0100 Subject: add lists and index exprs --- src/ast.rs | 18 ++++++++- src/eval.rs | 106 ++++++++++++++++++++++++++++++++++++++++++++++++++-- src/parser.rs | 118 ++++++++++++++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 229 insertions(+), 13 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 35bf6c3..fe986da 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -59,14 +59,15 @@ pub enum Statement { #[derive(Debug, Eq, PartialEq, Clone)] pub enum Expr { Node, - FieldAccess(Box, Identifier), Unit, Lit(Literal), Ident(Identifier), - // List(Vec), + FieldAccess(Box, Identifier), + Index(Box, Box), Bin(Box, BinOp, Box), Unary(Box, UnaryOp), Call(Call), + List(List), If(IfExpr), Block(Block), } @@ -154,6 +155,18 @@ impl From for Expr { } } +/// A list construction expression +#[derive(Debug, Eq, PartialEq, Clone)] +pub struct List { + pub items: Vec, +} + +impl From for Expr { + fn from(list: List) -> Expr { + Expr::List(list) + } +} + #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum Type { Unit, @@ -161,6 +174,7 @@ pub enum Type { String, Boolean, Node, + List, } #[derive(Debug, PartialEq, Eq, Clone)] diff --git a/src/eval.rs b/src/eval.rs index 7822419..104571b 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -15,17 +15,17 @@ impl Variable { &self.value } - fn ty(&self) -> ast::Type { - self.ty + fn ty(&self) -> &ast::Type { + &self.ty } fn assign(&mut self, value: Value) -> Result { - if self.ty() == value.ty() { + if self.ty() == &value.ty() { self.value = value; Ok(self.value.clone()) } else { Err(Error::TypeMismatch { - expected: self.ty(), + expected: self.ty().clone(), got: value.ty(), }) } @@ -39,6 +39,7 @@ pub enum Value { String(String), Boolean(bool), Node(NodeId), + List(Vec), } type NodeId = usize; @@ -51,6 +52,7 @@ impl Value { Self::String(_) => ast::Type::String, Self::Boolean(_) => ast::Type::Boolean, Self::Node(_) => ast::Type::Node, + Self::List(_) => ast::Type::List, } } @@ -61,6 +63,7 @@ impl Value { ast::Type::String => Self::default_string(), ast::Type::Boolean => Self::default_bool(), ast::Type::Node => unreachable!(), + ast::Type::List => Self::default_list(), } } @@ -76,6 +79,10 @@ impl Value { Self::String(String::default()) } + fn default_list() -> Self { + Self::List(Vec::new()) + } + fn as_boolean(&self) -> std::result::Result { match self { Self::Boolean(b) => Ok(*b), @@ -116,6 +123,16 @@ impl Value { } } + fn as_list(&self) -> std::result::Result, Error> { + match self { + Self::List(values) => Ok(values.clone()), + v => Err(Error::TypeMismatch { + expected: ast::Type::List, + got: v.ty(), + }), + } + } + fn add(&self, other: &Self) -> Result { match (self, other) { (Self::Integer(s), Self::Integer(o)) => Ok(Self::Integer(*s + *o)), @@ -278,6 +295,13 @@ impl fmt::Display for Value { Self::String(s) => write!(f, "{s}"), Self::Boolean(b) => write!(f, "{b}"), Self::Node(id) => write!(f, ""), + Self::List(items) => { + write!(f, "[")?; + for i in items { + write!(f, "{i}")?; + } + write!(f, "]") + } } } } @@ -300,6 +324,12 @@ impl From<&str> for Value { } } +impl From> for Value { + fn from(value: Vec) -> Self { + Self::List(value) + } +} + type NodeKind = u16; #[derive(Debug, Default)] @@ -371,11 +401,16 @@ pub enum Error { AlreadyBound(ast::Identifier), MalformedExpr(String), InvalidNodeKind(String), + NoParentNode(tree_sitter::Node<'static>), InvalidStringSlice { length: usize, start: i128, end: i128, }, + ArrayOutOfBounds { + idx: i128, + len: usize + }, // current node is only set in visitors, not in BEGIN or END blocks CurrentNodeNotPresent, } @@ -484,6 +519,8 @@ impl Context { ast::Expr::Bin(lhs, op, rhs) => self.eval_bin(&*lhs, *op, &*rhs), ast::Expr::Unary(expr, op) => self.eval_unary(&*expr, *op), ast::Expr::Call(call) => self.eval_call(&*call), + ast::Expr::List(list) => self.eval_list(&*list), + ast::Expr::Index(target, idx) => self.eval_index(&*target, &*idx), ast::Expr::If(if_expr) => self.eval_if(if_expr), ast::Expr::Block(block) => self.eval_block(block), ast::Expr::Node => self.eval_node(), @@ -704,10 +741,40 @@ impl Context { .unwrap(); Ok(Value::String(text.to_owned())) } + ("parent", [arg]) => { + let v = self.eval_expr(arg)?; + let id = v.as_node()?; + let node = self.get_node_by_id(id).unwrap(); + let parent = node.parent(); + parent + .map(|n| Value::Node(n.id())) + .ok_or(Error::NoParentNode(node)) + } (s, _) => Err(Error::FailedLookup(s.to_owned())), } } + fn eval_list(&mut self, list: &ast::List) -> Result { + let mut vals = vec![]; + for i in &list.items { + vals.push(self.eval_expr(i)?); + } + Ok(vals.into()) + } + + fn eval_index(&mut self, target: &ast::Expr, idx: &ast::Expr) -> Result { + let mut target = self.eval_expr(target)?.as_list()?; + let idx = self.eval_expr(idx)?.as_int()?; + if idx < 0 || idx >= target.len() as i128 { + Err(Error::ArrayOutOfBounds { + idx, + len: target.len() + }) + } else { + Ok(target.remove(idx as usize)) + } + } + fn eval_declaration(&mut self, decl: &ast::Declaration) -> Result { let initial_value = match decl.init.as_ref() { Some(init) => Some(self.eval_expr(&*init)?), @@ -986,4 +1053,35 @@ mod test { } ); } + + #[test] + fn test_list() { + let language = tree_sitter_python::language(); + let mut ctx = Context::new(language) + .with_program(ast::Program::new()) + .unwrap(); + assert_eq!( + ctx.eval_block(&ast::Block { + body: vec![ast::Statement::Declaration(ast::Declaration { + ty: ast::Type::List, + name: "a".to_owned(), + init: Some( + ast::Expr::List(ast::List { + items: vec![ast::Expr::int(5)] + }) + .boxed() + ), + }),] + }), + Ok(Value::Unit) + ); + assert_eq!( + ctx.lookup(&String::from("a")).unwrap().clone(), + Variable { + ty: ast::Type::List, + name: "a".to_owned(), + value: vec![5.into()].into(), + } + ); + } } diff --git a/src/parser.rs b/src/parser.rs index 4ec8e57..15d03fe 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -169,28 +169,51 @@ fn parse_mul<'a>(i: &'a str) -> IResult<&'a str, Expr> { let div = parse_op("/", BinOp::Arith(ArithOp::Div)); let mod_ = parse_op("%", BinOp::Arith(ArithOp::Mod)); let op = alt((mul, div, mod_)); - let recursive = parse_binary(parse_field_access, op, parse_mul); - let base = parse_field_access; + let recursive = parse_binary(parse_field_or_index, op, parse_mul); + let base = parse_field_or_index; alt((recursive, base)).parse(i) } -fn parse_field_access<'a>(i: &'a str) -> IResult<&'a str, Expr> { - let trailing = map(tuple((ws(char('.')), ws(parse_ident))), |(_, i)| i); +fn parse_field_or_index<'a>(i: &'a str) -> IResult<&'a str, Expr> { + enum FieldOrIndex { + Field(String), + Index(Expr), + } + let (i, base) = parse_atom(i)?; + let field = map(tuple((ws(char('.')), ws(parse_ident))), |(_, i)| { + FieldOrIndex::Field(i) + }); + let index = map( + tuple((ws(char('[')), parse_expr, ws(char(']')))), + |(_, idx, _)| FieldOrIndex::Index(idx), + ); + fold_many0( - trailing, + alt((field, index)), move || base.clone(), - move |acc, new| Expr::FieldAccess(acc.boxed(), new), + move |acc, new| match new { + FieldOrIndex::Field(f) => Expr::FieldAccess(acc.boxed(), f), + FieldOrIndex::Index(idx) => Expr::Index(acc.boxed(), idx.boxed()), + }, )(i) } +fn parse_list<'a>(i: &'a str) -> IResult<&'a str, List> { + let open = ws(char('[')); + let items = separated_list0(char(','), parse_expr); + let close = ws(char(']')); + map(tuple((open, items, close)), |(_, items, _)| List { items }).parse(i) +} + fn parse_atom<'a>(i: &'a str) -> IResult<&'a str, Expr> { let inner = alt(( map(tag("node"), |_| Expr::Node), map(parse_block, Expr::Block), map(parse_if, Expr::If), map(parse_call, Expr::Call), + map(parse_list, Expr::List), map(parse_lit, Expr::Lit), map(parse_ident, Expr::Ident), map(parse_unit, |_| Expr::Unit), @@ -254,7 +277,8 @@ fn parse_type<'a>(i: &'a str) -> IResult<&'a str, Type> { let int = value(Type::Integer, tag("int")); let string = value(Type::String, tag("string")); let bool_ = value(Type::Boolean, tag("bool")); - alt((int, string, bool_)).parse(i) + let list = value(Type::List, tag("list")); + alt((int, string, bool_, list)).parse(i) } fn parse_declaration<'a>(i: &'a str) -> IResult<&'a str, Declaration> { @@ -510,6 +534,27 @@ mod test { ) )) ); + assert_eq!( + parse_expr("a[0]"), + Ok(( + "", + Expr::Index(Expr::Ident("a".to_owned()).boxed(), Expr::int(0).boxed()) + )) + ); + assert_eq!( + parse_expr("children(node)[0]"), + Ok(( + "", + Expr::Index( + Expr::Call(Call { + function: "children".to_owned(), + parameters: vec![Expr::Node] + }) + .boxed(), + Expr::int(0).boxed() + ) + )) + ); } #[test] @@ -544,6 +589,22 @@ mod test { }) )) ); + assert_eq!( + parse_statement(r#"list a =["a", "b", "c"]; "#), + Ok(( + "", + Statement::Declaration(Declaration { + ty: Type::List, + name: "a".to_owned(), + init: Some( + Expr::List(List { + items: vec![Expr::str("a"), Expr::str("b"), Expr::str("c"),] + }) + .boxed() + ) + }) + )) + ); } #[test] @@ -624,6 +685,49 @@ mod test { ); } + #[test] + fn test_parse_index() { + assert_eq!( + parse_expr( + r#" + a[0] + "# + ), + Ok(( + "", + Expr::Index(Expr::Ident("a".to_owned()).boxed(), Expr::int(0).boxed()), + )) + ); + assert_eq!( + parse_expr(r#"node.children[idx]"#), + Ok(( + "", + Expr::Index( + Expr::FieldAccess(Expr::Node.boxed(), Identifier::from("children")).boxed(), + Expr::Ident("idx".to_owned()).boxed() + ) + )) + ); + assert_eq!( + parse_expr(r#"foo[i].bar[j]"#), + Ok(( + "", + Expr::Index( + Expr::FieldAccess( + Expr::Index( + Expr::Ident("foo".to_owned()).boxed(), + Expr::Ident("i".to_owned()).boxed() + ) + .boxed(), + "bar".to_owned() + ) + .boxed(), + Expr::Ident("j".to_owned()).boxed() + ), + )) + ); + } + // #[test] // fn test_skip_query() { // assert_eq!( -- cgit v1.2.3