diff --git a/src/main.rs b/src/main.rs index 985f41a9..fb15a03f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -225,6 +225,10 @@ pub enum TransformPass { EncodeAdts, #[value(alias = "no-adts")] NoEncodeAdts, + #[value(alias = "eta")] + EtaReduce, + #[value(alias = "no-eta")] + NoEtaReduce, } #[derive(Default)] @@ -232,6 +236,7 @@ pub struct TransformPasses { pre_reduce: bool, coalesce_ctrs: bool, encode_adts: bool, + eta_reduce: bool, } impl TransformPasses { @@ -239,6 +244,7 @@ impl TransformPasses { self.pre_reduce = true; self.coalesce_ctrs = true; self.encode_adts = true; + self.eta_reduce = true; } } @@ -256,6 +262,8 @@ impl TransformPass { NoCoalesceCtrs => opts.coalesce_ctrs = false, EncodeAdts => opts.encode_adts = true, NoEncodeAdts => opts.encode_adts = false, + EtaReduce => opts.eta_reduce = true, + NoEtaReduce => opts.eta_reduce = false, } } opts @@ -320,6 +328,9 @@ fn load_book(files: &[String], transform_opts: &TransformOpts) -> Book { ); } for (_, def) in &mut book.nets { + if transform_passes.eta_reduce { + def.eta_reduce(); + } for tree in def.trees_mut() { if transform_passes.coalesce_ctrs { tree.coalesce_constructors(); @@ -432,9 +443,10 @@ fn compile_executable(target: &str, host: &host::Host) -> Result<(), io::Error> stdlib trace transform { - pre_reduce - encode_adts coalesce_ctrs + encode_adts + eta_reduce + pre_reduce } util { apply_tree diff --git a/src/transform.rs b/src/transform.rs index 460dd479..a79e85e9 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -1,3 +1,4 @@ pub mod coalesce_ctrs; pub mod encode_adts; +pub mod eta_reduce; pub mod pre_reduce; diff --git a/src/transform/eta_reduce.rs b/src/transform/eta_reduce.rs new file mode 100644 index 00000000..2f345185 --- /dev/null +++ b/src/transform/eta_reduce.rs @@ -0,0 +1,177 @@ +//! Carries out simple eta-reduction, to reduce the amount of rewrites at +//! runtime. +//! +//! ### Eta-equivalence +//! +//! In interaction combinators, there are some nets that are equivalent and +//! 
have no observable difference.
+//!
+//! ![Image of eta-equivalence](https://i.postimg.cc/XYVxdMFW/image.png)
+//!
+//! This module implements the eta-equivalence rule at the top-left of the image
+//! above:
+//!
+//! ```txt
+//!     /|-, ,-|\     eta_reduce
+//! ---| |  X  | |-- ~~~~~~~~~~~~> -------------
+//!     \|-' '-|/
+//! ```
+//!
+//! In hvm-core's AST representation, this reduction looks like this:
+//!
+//! ```txt
+//! {lab x y} ... {lab x y} ~~~~~~~~> x ..... x
+//! ```
+//!
+//! Essentially, both occurrences of the same constructor are replaced by a
+//! variable.
+//!
+//! ### The algorithm
+//!
+//! The code uses a two-pass O(n) algorithm, where `n` is the number of nodes
+//! in the AST.
+//!
+//! In the first pass, a node list is built out of an ordered traversal of the
+//! AST. Crucially, the node list stores variable offsets instead of the
+//! variables' names. Since the AST's order is consistent, the ordering of nodes
+//! in the node list can be reproduced with a traversal.
+//!
+//! This means that each occurrence of a variable is encoded with the offset in
+//! the node list to the _other_ occurrence of the variable.
+//!
+//! For example, if we start with the net: `[(x y) (x y)]`
+//!
+//! The resulting node list will look like this:
+//!
+//! `[Ctr(1), Ctr(0), Var(3), Var(3), Ctr(0), Var(-3), Var(-3)]`
+//!
+//! The second pass uses the node list to find repeated constructors. If a
+//! constructor's children are both variables with the same offset, then we
+//! look up that offset relative to the constructor. If it is equal to the first
+//! constructor, it means both of them are equal and they can be replaced with a
+//! variable.
+//!
+//! 
The pass also reduces subnets such as `(* *) -> *`
+
+use std::{collections::HashMap, ops::RangeFrom};
+
+use crate::ast::{Net, Tree};
+
+impl Net {
+  /// Eta-reduces every tree of this net in place, using the two-pass algorithm described in the module docs.
+  pub fn eta_reduce(&mut self) {
+    let mut phase1 = Phase1::default();
+    for tree in self.trees() {
+      phase1.walk_tree(tree); // pass 1: record every node, in traversal order
+    }
+    let mut phase2 = Phase2 { nodes: phase1.nodes, index: 0 .. };
+    for tree in self.trees_mut() {
+      phase2.reduce_tree(tree); // pass 2: replay the same traversal and rewrite
+    }
+  }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum NodeType {
+  Ctr(u16),   // head of one binary constructor cell, carrying its label
+  Var(isize), // variable; the payload is the offset to its other occurrence
+  Num(u64),   // number literal, carrying its value
+  Era,        // eraser
+  Other,      // any other kind of tree; never merged
+  Hole,       // variable whose other occurrence has not been seen (yet)
+}
+
+#[derive(Default)]
+struct Phase1<'a> {
+  vars: HashMap<&'a str, usize>, // variable name -> node-list index of its first occurrence
+  nodes: Vec<NodeType>,          // the node list, in traversal order
+}
+
+impl<'a> Phase1<'a> {
+  fn walk_tree(&mut self, tree: &'a Tree) { // appends the nodes of `tree` to the node list
+    match tree {
+      Tree::Ctr { lab, ports } => {
+        let last_port = ports.len().saturating_sub(1); // saturating: an empty `ports` must not underflow
+        for (idx, i) in ports.iter().enumerate() {
+          if idx != last_port { // n ports are recorded as a chain of n - 1 binary cells
+            self.nodes.push(NodeType::Ctr(*lab));
+          }
+          self.walk_tree(i);
+        }
+      }
+      Tree::Var { nam } => {
+        let nam: &str = &*nam;
+        if let Some(i) = self.vars.get(nam) { // second occurrence: link both ends to each other
+          let j = self.nodes.len() as isize;
+          self.nodes.push(NodeType::Var(*i as isize - j));
+          self.nodes[*i] = NodeType::Var(j - *i as isize);
+        } else { // first occurrence: placeholder, overwritten once the partner shows up
+          self.vars.insert(nam, self.nodes.len());
+          self.nodes.push(NodeType::Hole);
+        }
+      }
+      Tree::Era => self.nodes.push(NodeType::Era),
+      Tree::Num { val } => self.nodes.push(NodeType::Num(*val)),
+      _ => {
+        self.nodes.push(NodeType::Other);
+        for i in tree.children() {
+          self.walk_tree(i);
+        }
+      }
+    }
+  }
+}
+
+struct Phase2 {
+  nodes: Vec<NodeType>,    // node list produced by phase 1
+  index: RangeFrom<usize>, // cursor into `nodes`; advanced once per node that phase 1 pushed
+}
+
+impl Phase2 {
+  fn reduce_ctr(&mut self, lab: u16, ports: &mut Vec<Tree>, skip: usize) -> NodeType { // reduces the chain `ports[skip ..]`
+    if skip == ports.len() {
+      return NodeType::Other; // no ports: phase 1 recorded nothing for this constructor
+    }
+    if skip == ports.len() - 1 {
+      return self.reduce_tree(&mut ports[skip]); // last port is the tail of the chain, not a cell of its own
+    }
+    let head_index = self.index.next().unwrap(); // a `RangeFrom` never yields `None`
+    let a = self.reduce_tree(&mut ports[skip]);
+    let b = self.reduce_ctr(lab, ports, skip + 1);
+    if a == b {
+      match a {
+        NodeType::Var(delta) => { // `head_index + delta` may miss the list, e.g. in `x & y ~ (x y)`
+          if head_index.checked_add_signed(delta).and_then(|i| self.nodes.get(i)) == Some(&NodeType::Ctr(lab)) {
+            ports.pop(); // drop the tail; `ports[skip]` now stands for the whole cell
+            return NodeType::Var(delta);
+          }
+        }
+        NodeType::Era => {
+          ports.pop(); // `(* *)` becomes `*`
+          return NodeType::Era;
+        }
+        NodeType::Num(val) => {
+          ports.pop(); // `(#n #n)` becomes `#n`
+          return NodeType::Num(val);
+        }
+        _ => {}
+      }
+    }
+    NodeType::Ctr(lab)
+  }
+  fn reduce_tree(&mut self, tree: &mut Tree) -> NodeType { // reduces `tree` in place and reports what it became
+    if let Tree::Ctr { lab, ports } = tree {
+      let ty = self.reduce_ctr(*lab, ports, 0);
+      if ports.len() == 1 { // fully collapsed: replace the constructor by its only remaining port
+        *tree = ports.pop().unwrap();
+      }
+      ty
+    } else {
+      let index = self.index.next().unwrap(); // consume the slot phase 1 pushed for this node
+      for i in tree.children_mut() {
+        self.reduce_tree(i);
+      }
+      self.nodes[index]
+    }
+  }
+}
diff --git a/tests/transform.rs b/tests/transform.rs
index 22a2ed98..b201f673 100644
--- a/tests/transform.rs
+++ b/tests/transform.rs
@@ -43,9 +43,7 @@ pub fn test_adt_encoding() {
   use std::str::FromStr;
   pub fn parse_and_encode(net: &str) -> String {
     let mut net = Net::from_str(net).unwrap();
-    println!("{net}");
     net.trees_mut().for_each(Tree::coalesce_constructors);
-    println!("{net}");
     net.trees_mut().for_each(Tree::encode_scott_adts);
     format!("{net}")
   }
@@ -65,3 +63,30 @@ pub fn test_adt_encoding() {
   assert_display_snapshot!(parse_and_encode("(* x x)"), @"(:1:2)");
   assert_display_snapshot!(parse_and_encode("(((((* x x) x) * x) x) * x)"), @"(:0:2 (:0:2 (:1:2)))");
 }
+
+#[test]
+pub fn test_eta() {
+  use hvmc::ast::Net;
+  use std::str::FromStr;
+  pub fn parse_and_reduce(net: &str) -> String {
+    let mut net = Net::from_str(net).unwrap();
+    net.eta_reduce();
+    format!("{net}")
+  }
+  assert_display_snapshot!(parse_and_reduce("((x y) (x y))"), @"(x x)");
+  assert_display_snapshot!(parse_and_reduce("((a b c d e f) (a b c d e f))"), @"(a a)");
+  assert_display_snapshot!(parse_and_reduce("<+ (a b) (a b)>"), @"<+ a a>");
+  assert_display_snapshot!(parse_and_reduce("(a b) & ((a b) (c d)) ~ (c d) "), @r###"
+  a
+  & (a c) ~ c
+  "###);
+  assert_display_snapshot!(parse_and_reduce("((a b) [a b])"), @"((a b) [a b])");
+  assert_display_snapshot!(parse_and_reduce("((a b c) b c)"), @"((a b) b)");
+  assert_display_snapshot!(parse_and_reduce("([(a b) (c d)] [(a b) (c d)])"), @"(a a)");
+  assert_display_snapshot!(parse_and_reduce("(* *)"), @"*");
+  assert_display_snapshot!(parse_and_reduce("([(#0 #0) (#12345 #12345)] [(* *) (a a)])"), @"([#0 #12345] [* (a a)])");
+  assert_display_snapshot!(parse_and_reduce("x & y ~ (x y)"), @r###"
+  x
+  & y ~ (x y)
+  "###);
+}