diff options
Diffstat (limited to 'src/eval.rs')
-rw-r--r-- | src/eval.rs | 174 |
1 files changed, 113 insertions, 61 deletions
diff --git a/src/eval.rs b/src/eval.rs index 1a3f8a8..7822419 100644 --- a/src/eval.rs +++ b/src/eval.rs | |||
@@ -1,6 +1,6 @@ | |||
1 | //! tree walking interpreter for tbsp | 1 | //! tree walking interpreter for tbsp |
2 | 2 | ||
3 | use crate::ast; | 3 | use crate::{ast, Wrap}; |
4 | use std::{collections::HashMap, fmt}; | 4 | use std::{collections::HashMap, fmt}; |
5 | 5 | ||
6 | #[derive(Debug, PartialEq, Eq, Clone)] | 6 | #[derive(Debug, PartialEq, Eq, Clone)] |
@@ -38,10 +38,11 @@ pub enum Value { | |||
38 | Integer(i128), | 38 | Integer(i128), |
39 | String(String), | 39 | String(String), |
40 | Boolean(bool), | 40 | Boolean(bool), |
41 | Node, | 41 | Node(NodeId), |
42 | FieldAccess(Vec<String>), | ||
43 | } | 42 | } |
44 | 43 | ||
44 | type NodeId = usize; | ||
45 | |||
45 | impl Value { | 46 | impl Value { |
46 | fn ty(&self) -> ast::Type { | 47 | fn ty(&self) -> ast::Type { |
47 | match self { | 48 | match self { |
@@ -49,8 +50,7 @@ impl Value { | |||
49 | Self::Integer(_) => ast::Type::Integer, | 50 | Self::Integer(_) => ast::Type::Integer, |
50 | Self::String(_) => ast::Type::String, | 51 | Self::String(_) => ast::Type::String, |
51 | Self::Boolean(_) => ast::Type::Boolean, | 52 | Self::Boolean(_) => ast::Type::Boolean, |
52 | Self::Node => ast::Type::Node, | 53 | Self::Node(_) => ast::Type::Node, |
53 | Self::FieldAccess(_) => ast::Type::Node, | ||
54 | } | 54 | } |
55 | } | 55 | } |
56 | 56 | ||
@@ -106,6 +106,16 @@ impl Value { | |||
106 | } | 106 | } |
107 | } | 107 | } |
108 | 108 | ||
109 | fn as_node(&self) -> std::result::Result<NodeId, Error> { | ||
110 | match self { | ||
111 | Self::Node(id) => Ok(*id), | ||
112 | v => Err(Error::TypeMismatch { | ||
113 | expected: ast::Type::Node, | ||
114 | got: v.ty(), | ||
115 | }), | ||
116 | } | ||
117 | } | ||
118 | |||
109 | fn add(&self, other: &Self) -> Result { | 119 | fn add(&self, other: &Self) -> Result { |
110 | match (self, other) { | 120 | match (self, other) { |
111 | (Self::Integer(s), Self::Integer(o)) => Ok(Self::Integer(*s + *o)), | 121 | (Self::Integer(s), Self::Integer(o)) => Ok(Self::Integer(*s + *o)), |
@@ -267,8 +277,7 @@ impl fmt::Display for Value { | |||
267 | Self::Integer(i) => write!(f, "{i}"), | 277 | Self::Integer(i) => write!(f, "{i}"), |
268 | Self::String(s) => write!(f, "{s}"), | 278 | Self::String(s) => write!(f, "{s}"), |
269 | Self::Boolean(b) => write!(f, "{b}"), | 279 | Self::Boolean(b) => write!(f, "{b}"), |
270 | Self::Node => write!(f, "<node>"), | 280 | Self::Node(id) => write!(f, "<node #{id}>"), |
271 | Self::FieldAccess(items) => write!(f, "<node>.{}", items.join(".")), | ||
272 | } | 281 | } |
273 | } | 282 | } |
274 | } | 283 | } |
@@ -373,15 +382,17 @@ pub enum Error { | |||
373 | 382 | ||
374 | pub type Result = std::result::Result<Value, Error>; | 383 | pub type Result = std::result::Result<Value, Error>; |
375 | 384 | ||
376 | pub struct Context<'a> { | 385 | pub struct Context { |
377 | variables: HashMap<ast::Identifier, Variable>, | 386 | variables: HashMap<ast::Identifier, Variable>, |
378 | language: tree_sitter::Language, | 387 | language: tree_sitter::Language, |
379 | visitors: Visitors, | 388 | visitors: Visitors, |
380 | input_src: Option<String>, | 389 | input_src: Option<String>, |
381 | cursor: Option<tree_sitter::TreeCursor<'a>>, | 390 | cursor: Option<tree_sitter::TreeCursor<'static>>, |
391 | tree: Option<&'static tree_sitter::Tree>, | ||
392 | cache: HashMap<NodeId, tree_sitter::Node<'static>>, | ||
382 | } | 393 | } |
383 | 394 | ||
384 | impl<'a> fmt::Debug for Context<'a> { | 395 | impl fmt::Debug for Context { |
385 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | 396 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
386 | f.debug_struct("Context") | 397 | f.debug_struct("Context") |
387 | .field("variables", &self.variables) | 398 | .field("variables", &self.variables) |
@@ -400,7 +411,7 @@ impl<'a> fmt::Debug for Context<'a> { | |||
400 | } | 411 | } |
401 | } | 412 | } |
402 | 413 | ||
403 | impl<'a> Context<'a> { | 414 | impl Context { |
404 | pub fn new(language: tree_sitter::Language) -> Self { | 415 | pub fn new(language: tree_sitter::Language) -> Self { |
405 | Self { | 416 | Self { |
406 | visitors: Default::default(), | 417 | visitors: Default::default(), |
@@ -408,9 +419,44 @@ impl<'a> Context<'a> { | |||
408 | language, | 419 | language, |
409 | input_src: None, | 420 | input_src: None, |
410 | cursor: None, | 421 | cursor: None, |
422 | tree: None, | ||
423 | cache: HashMap::default(), | ||
411 | } | 424 | } |
412 | } | 425 | } |
413 | 426 | ||
427 | pub fn cache_node(&mut self, node: tree_sitter::Node<'static>) { | ||
428 | self.cache.entry(node.id()).or_insert(node); | ||
429 | } | ||
430 | |||
431 | pub fn get_node_by_id(&mut self, id: usize) -> Option<tree_sitter::Node<'static>> { | ||
432 | let root_node = self.tree.as_ref().map(|t| t.root_node())?; | ||
433 | self.get_node_by_id_helper(root_node, id) | ||
434 | } | ||
435 | |||
436 | fn get_node_by_id_helper( | ||
437 | &mut self, | ||
438 | start: tree_sitter::Node<'static>, | ||
439 | id: usize, | ||
440 | ) -> Option<tree_sitter::Node<'static>> { | ||
441 | self.cache_node(start); | ||
442 | |||
443 | if let Some(found) = self.cache.get(&id) { | ||
444 | return Some(*found); | ||
445 | } | ||
446 | |||
447 | if start.id() == id { | ||
448 | return Some(start); | ||
449 | } else { | ||
450 | for child in start.children(&mut start.walk()) { | ||
451 | if let Some(n) = self.get_node_by_id_helper(child, id) { | ||
452 | return Some(n); | ||
453 | }; | ||
454 | } | ||
455 | } | ||
456 | |||
457 | None | ||
458 | } | ||
459 | |||
414 | pub fn with_program(mut self, program: ast::Program) -> std::result::Result<Self, Error> { | 460 | pub fn with_program(mut self, program: ast::Program) -> std::result::Result<Self, Error> { |
415 | for stanza in program.stanzas.into_iter() { | 461 | for stanza in program.stanzas.into_iter() { |
416 | self.visitors.insert(stanza, &self.language)?; | 462 | self.visitors.insert(stanza, &self.language)?; |
@@ -423,8 +469,10 @@ impl<'a> Context<'a> { | |||
423 | self | 469 | self |
424 | } | 470 | } |
425 | 471 | ||
426 | pub fn with_cursor(mut self, cursor: tree_sitter::TreeCursor<'a>) -> Self { | 472 | pub fn with_tree(mut self, tree: tree_sitter::Tree) -> Self { |
427 | self.cursor = Some(cursor); | 473 | let tree = Box::leak(Box::new(tree)); |
474 | self.cursor = Some(tree.walk()); | ||
475 | self.tree = Some(tree); | ||
428 | self | 476 | self |
429 | } | 477 | } |
430 | 478 | ||
@@ -436,10 +484,10 @@ impl<'a> Context<'a> { | |||
436 | ast::Expr::Bin(lhs, op, rhs) => self.eval_bin(&*lhs, *op, &*rhs), | 484 | ast::Expr::Bin(lhs, op, rhs) => self.eval_bin(&*lhs, *op, &*rhs), |
437 | ast::Expr::Unary(expr, op) => self.eval_unary(&*expr, *op), | 485 | ast::Expr::Unary(expr, op) => self.eval_unary(&*expr, *op), |
438 | ast::Expr::Call(call) => self.eval_call(&*call), | 486 | ast::Expr::Call(call) => self.eval_call(&*call), |
439 | ast::Expr::IfExpr(if_expr) => self.eval_if(if_expr), | 487 | ast::Expr::If(if_expr) => self.eval_if(if_expr), |
440 | ast::Expr::Block(block) => self.eval_block(block), | 488 | ast::Expr::Block(block) => self.eval_block(block), |
441 | ast::Expr::Node => Ok(Value::Node), | 489 | ast::Expr::Node => self.eval_node(), |
442 | ast::Expr::FieldAccess(items) => Ok(Value::FieldAccess(items.to_owned())), | 490 | ast::Expr::FieldAccess(expr, items) => self.eval_field_access(expr, items), |
443 | } | 491 | } |
444 | } | 492 | } |
445 | 493 | ||
@@ -451,6 +499,16 @@ impl<'a> Context<'a> { | |||
451 | } | 499 | } |
452 | } | 500 | } |
453 | 501 | ||
502 | fn eval_node(&mut self) -> Result { | ||
503 | self.cursor | ||
504 | .as_ref() | ||
505 | .ok_or(Error::CurrentNodeNotPresent)? | ||
506 | .node() | ||
507 | .id() | ||
508 | .wrap(Value::Node) | ||
509 | .wrap_ok() | ||
510 | } | ||
511 | |||
454 | fn lookup(&mut self, ident: &ast::Identifier) -> std::result::Result<&Variable, Error> { | 512 | fn lookup(&mut self, ident: &ast::Identifier) -> std::result::Result<&Variable, Error> { |
455 | self.variables | 513 | self.variables |
456 | .get(ident) | 514 | .get(ident) |
@@ -469,14 +527,14 @@ impl<'a> Context<'a> { | |||
469 | ty: ast::Type, | 527 | ty: ast::Type, |
470 | ) -> std::result::Result<&mut Variable, Error> { | 528 | ) -> std::result::Result<&mut Variable, Error> { |
471 | if self.lookup(ident).is_err() { | 529 | if self.lookup(ident).is_err() { |
472 | Ok(self | 530 | self.variables |
473 | .variables | ||
474 | .entry(ident.to_owned()) | 531 | .entry(ident.to_owned()) |
475 | .or_insert_with(|| Variable { | 532 | .or_insert_with(|| Variable { |
476 | name: ident.to_owned(), | 533 | name: ident.to_owned(), |
477 | value: Value::default(ty), | 534 | value: Value::default(ty), |
478 | ty, | 535 | ty, |
479 | })) | 536 | }) |
537 | .wrap_ok() | ||
480 | } else { | 538 | } else { |
481 | Err(Error::AlreadyBound(ident.to_owned())) | 539 | Err(Error::AlreadyBound(ident.to_owned())) |
482 | } | 540 | } |
@@ -574,7 +632,7 @@ impl<'a> Context<'a> { | |||
574 | } | 632 | } |
575 | } | 633 | } |
576 | 634 | ||
577 | fn eval_if(&mut self, if_expr: &ast::If) -> Result { | 635 | fn eval_if(&mut self, if_expr: &ast::IfExpr) -> Result { |
578 | let cond = self.eval_expr(&if_expr.condition)?; | 636 | let cond = self.eval_expr(&if_expr.condition)?; |
579 | 637 | ||
580 | if cond.as_boolean()? { | 638 | if cond.as_boolean()? { |
@@ -638,32 +696,9 @@ impl<'a> Context<'a> { | |||
638 | } | 696 | } |
639 | } | 697 | } |
640 | ("text", [arg]) => { | 698 | ("text", [arg]) => { |
641 | let node = match self.eval_expr(arg)? { | 699 | let v = self.eval_expr(arg)?; |
642 | Value::Node => self | 700 | let id = v.as_node()?; |
643 | .cursor | 701 | let node = self.get_node_by_id(id).unwrap(); |
644 | .as_ref() | ||
645 | .ok_or(Error::CurrentNodeNotPresent)? | ||
646 | .node(), | ||
647 | Value::FieldAccess(fields) => { | ||
648 | let mut node = self | ||
649 | .cursor | ||
650 | .as_ref() | ||
651 | .ok_or(Error::CurrentNodeNotPresent)? | ||
652 | .node(); | ||
653 | for field in &fields { | ||
654 | node = node | ||
655 | .child_by_field_name(field.as_bytes()) | ||
656 | .ok_or_else(|| Error::FailedLookup(field.to_owned()))?; | ||
657 | } | ||
658 | node | ||
659 | } | ||
660 | v => { | ||
661 | return Err(Error::TypeMismatch { | ||
662 | expected: ast::Type::Node, | ||
663 | got: v.ty(), | ||
664 | }) | ||
665 | } | ||
666 | }; | ||
667 | let text = node | 702 | let text = node |
668 | .utf8_text(self.input_src.as_ref().unwrap().as_bytes()) | 703 | .utf8_text(self.input_src.as_ref().unwrap().as_bytes()) |
669 | .unwrap(); | 704 | .unwrap(); |
@@ -689,7 +724,7 @@ impl<'a> Context<'a> { | |||
689 | 724 | ||
690 | fn eval_statement(&mut self, stmt: &ast::Statement) -> Result { | 725 | fn eval_statement(&mut self, stmt: &ast::Statement) -> Result { |
691 | match stmt { | 726 | match stmt { |
692 | ast::Statement::Bare(expr) => self.eval_expr(expr).map(|_| Value::Unit), | 727 | ast::Statement::Bare(expr) => self.eval_expr(expr), |
693 | ast::Statement::Declaration(decl) => self.eval_declaration(decl), | 728 | ast::Statement::Declaration(decl) => self.eval_declaration(decl), |
694 | } | 729 | } |
695 | } | 730 | } |
@@ -701,6 +736,24 @@ impl<'a> Context<'a> { | |||
701 | Ok(Value::Unit) | 736 | Ok(Value::Unit) |
702 | } | 737 | } |
703 | 738 | ||
739 | fn eval_field_access(&mut self, expr: &ast::Expr, field: &ast::Identifier) -> Result { | ||
740 | let v = self.eval_expr(expr)?; | ||
741 | let base = v.as_node()?; | ||
742 | let base_node = self.get_node_by_id(base).unwrap(); | ||
743 | base_node | ||
744 | .child_by_field_name(field) | ||
745 | .ok_or(Error::InvalidNodeKind(field.clone())) | ||
746 | .map(|n| n.id()) | ||
747 | .map(Value::Node) | ||
748 | } | ||
749 | |||
750 | fn goto_first_child(&mut self) -> bool { | ||
751 | self.cursor | ||
752 | .as_mut() | ||
753 | .map(|c| c.goto_first_child()) | ||
754 | .unwrap_or_default() | ||
755 | } | ||
756 | |||
704 | pub fn eval(&mut self) -> Result { | 757 | pub fn eval(&mut self) -> Result { |
705 | let visitors = std::mem::take(&mut self.visitors); | 758 | let visitors = std::mem::take(&mut self.visitors); |
706 | let mut has_next = true; | 759 | let mut has_next = true; |
@@ -710,29 +763,29 @@ impl<'a> Context<'a> { | |||
710 | self.eval_block(&visitors.begin)?; | 763 | self.eval_block(&visitors.begin)?; |
711 | 764 | ||
712 | while has_next { | 765 | while has_next { |
713 | let current_node = self.cursor.as_mut().unwrap().node(); | 766 | let current_node = self.cursor.as_ref().unwrap().node(); |
714 | postorder.push(current_node); | 767 | postorder.push(current_node); |
715 | 768 | ||
716 | let visitor = visitors.get_by_node(current_node); | 769 | let visitor = visitors.get_by_node(current_node); |
717 | 770 | ||
718 | visitor.map(|v| self.eval_block(&v.enter)); | 771 | if let Some(v) = visitor { |
772 | self.eval_block(&v.enter)?; | ||
773 | } | ||
719 | 774 | ||
720 | has_next = self.cursor.as_mut().unwrap().goto_first_child(); | 775 | has_next = self.goto_first_child(); |
721 | 776 | ||
722 | if !has_next { | 777 | if !has_next { |
723 | has_next = self.cursor.as_mut().unwrap().goto_next_sibling(); | 778 | has_next = self.cursor.as_mut().unwrap().goto_next_sibling(); |
724 | postorder | 779 | if let Some(v) = postorder.pop().and_then(|n| visitors.get_by_node(n)) { |
725 | .pop() | 780 | self.eval_block(&v.leave)?; |
726 | .and_then(|n| visitors.get_by_node(n)) | 781 | } |
727 | .map(|v| self.eval_block(&v.leave)); | ||
728 | } | 782 | } |
729 | 783 | ||
730 | while !has_next && self.cursor.as_mut().unwrap().goto_parent() { | 784 | while !has_next && self.cursor.as_mut().unwrap().goto_parent() { |
731 | has_next = self.cursor.as_mut().unwrap().goto_next_sibling(); | 785 | has_next = self.cursor.as_mut().unwrap().goto_next_sibling(); |
732 | postorder | 786 | if let Some(v) = postorder.pop().and_then(|n| visitors.get_by_node(n)) { |
733 | .pop() | 787 | self.eval_block(&v.leave)?; |
734 | .and_then(|n| visitors.get_by_node(n)) | 788 | } |
735 | .map(|v| self.eval_block(&v.leave)); | ||
736 | } | 789 | } |
737 | } | 790 | } |
738 | 791 | ||
@@ -748,12 +801,11 @@ pub fn evaluate(file: &str, program: &str, language: tree_sitter::Language) -> R | |||
748 | let _ = parser.set_language(&language); | 801 | let _ = parser.set_language(&language); |
749 | 802 | ||
750 | let tree = parser.parse(file, None).unwrap(); | 803 | let tree = parser.parse(file, None).unwrap(); |
751 | let cursor = tree.walk(); | ||
752 | 804 | ||
753 | let program = ast::Program::new().from_str(program).unwrap(); | 805 | let program = ast::Program::new().from_str(program).unwrap(); |
754 | let mut ctx = Context::new(language) | 806 | let mut ctx = Context::new(language) |
755 | .with_input(file.to_owned()) | 807 | .with_input(file.to_owned()) |
756 | .with_cursor(cursor) | 808 | .with_tree(tree) |
757 | .with_program(program)?; | 809 | .with_program(program)?; |
758 | 810 | ||
759 | ctx.eval() | 811 | ctx.eval() |
@@ -857,7 +909,7 @@ mod test { | |||
857 | name: "a".to_owned(), | 909 | name: "a".to_owned(), |
858 | init: Some(ast::Expr::int(1).boxed()), | 910 | init: Some(ast::Expr::int(1).boxed()), |
859 | }), | 911 | }), |
860 | ast::Statement::Bare(ast::Expr::IfExpr(ast::If { | 912 | ast::Statement::Bare(ast::Expr::If(ast::IfExpr { |
861 | condition: ast::Expr::true_().boxed(), | 913 | condition: ast::Expr::true_().boxed(), |
862 | then: ast::Block { | 914 | then: ast::Block { |
863 | body: vec![ast::Statement::Bare(ast::Expr::Bin( | 915 | body: vec![ast::Statement::Bare(ast::Expr::Bin( |