From e12a4b46d3c1d3a0dc95249af5b32a82f7763367 Mon Sep 17 00:00:00 2001 From: dark64 Date: Wed, 25 Jan 2023 14:03:39 +0100 Subject: [PATCH 01/32] wip --- zokrates_ast/src/common/mod.rs | 2 +- zokrates_ast/src/common/solvers.rs | 34 +++++++++++++++++++++++++++--- zokrates_ast/src/ir/clean.rs | 1 + zokrates_ast/src/ir/folder.rs | 1 + zokrates_ast/src/ir/from_flat.rs | 3 +++ zokrates_ast/src/ir/mod.rs | 17 ++++++++++++--- zokrates_codegen/src/lib.rs | 7 ++++-- 7 files changed, 56 insertions(+), 9 deletions(-) diff --git a/zokrates_ast/src/common/mod.rs b/zokrates_ast/src/common/mod.rs index 13d23bfdb..e0337dec2 100644 --- a/zokrates_ast/src/common/mod.rs +++ b/zokrates_ast/src/common/mod.rs @@ -10,6 +10,6 @@ pub use self::embed::FlatEmbed; pub use self::error::RuntimeError; pub use self::metadata::SourceMetadata; pub use self::parameter::Parameter; -pub use self::solvers::Solver; +pub use self::solvers::{Solver, ZirSolver}; pub use self::variable::Variable; pub use format_string::FormatString; diff --git a/zokrates_ast/src/common/solvers.rs b/zokrates_ast/src/common/solvers.rs index 9b4f5c900..42318342f 100644 --- a/zokrates_ast/src/common/solvers.rs +++ b/zokrates_ast/src/common/solvers.rs @@ -2,6 +2,22 @@ use crate::zir::ZirFunction; use serde::{Deserialize, Serialize}; use std::fmt; +#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] +pub enum ZirSolver<'ast, T> { + #[serde(borrow)] + Function(ZirFunction<'ast, T>), + Indexed(usize, usize), +} + +impl<'ast, T> fmt::Display for ZirSolver<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ZirSolver::Function(_) => write!(f, "Zir(..)"), + ZirSolver::Indexed(index, n) => write!(f, "Zir@{}({})", index, n), + } + } +} + #[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] pub enum Solver<'ast, T> { ConditionEq, @@ -13,13 +29,22 @@ pub enum Solver<'ast, T> { ShaCh, EuclideanDiv, #[serde(borrow)] - Zir(ZirFunction<'ast, T>), + Zir(ZirSolver<'ast, T>), #[cfg(feature = "bellman")] Sha256Round, #[cfg(feature = "ark")] SnarkVerifyBls12377(usize), } +impl<'ast, T> Solver<'ast, T> { + pub fn zir_function(function: ZirFunction<'ast, T>) -> Self { + Solver::Zir(ZirSolver::Function(function)) + } + pub fn zir_indexed(index: usize, argument_count: usize) -> Self { + Solver::Zir(ZirSolver::Indexed(index, argument_count)) + } +} + impl<'ast, T> fmt::Display for Solver<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -31,7 +56,7 @@ impl<'ast, T> fmt::Display for Solver<'ast, T> { Solver::ShaAndXorAndXorAnd => write!(f, "ShaAndXorAndXorAnd"), Solver::ShaCh => write!(f, "ShaCh"), Solver::EuclideanDiv => write!(f, "EuclideanDiv"), - Solver::Zir(_) => write!(f, "Zir(..)"), + Solver::Zir(s) => write!(f, "{}", s), #[cfg(feature = "bellman")] Solver::Sha256Round => write!(f, "Sha256Round"), #[cfg(feature = "ark")] @@ -51,7 +76,10 @@ impl<'ast, T> Solver<'ast, T> { Solver::ShaAndXorAndXorAnd => (3, 1), Solver::ShaCh => (3, 1), Solver::EuclideanDiv => (2, 2), - Solver::Zir(f) => (f.arguments.len(), 1), + Solver::Zir(s) => match s { + ZirSolver::Function(f) => (f.arguments.len(), 1), + ZirSolver::Indexed(_, n) => (*n, 1), + }, #[cfg(feature = "bellman")] Solver::Sha256Round => (768, 26935), #[cfg(feature = "ark")] diff --git a/zokrates_ast/src/ir/clean.rs b/zokrates_ast/src/ir/clean.rs index b4fb8f445..998e32987 100644 --- a/zokrates_ast/src/ir/clean.rs +++ b/zokrates_ast/src/ir/clean.rs @@ -14,6 +14,7 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a .statements .into_iter() .flat_map(|s| Cleaner::default().fold_statement(s)), + solvers: self.solvers, } } } diff --git a/zokrates_ast/src/ir/folder.rs b/zokrates_ast/src/ir/folder.rs index 6e67c15de..99466addc 100644 --- a/zokrates_ast/src/ir/folder.rs +++ b/zokrates_ast/src/ir/folder.rs @@ -50,6 +50,7 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( .flat_map(|s| f.fold_statement(s)) .collect(), return_count: p.return_count, + solvers: p.solvers, } } diff --git a/zokrates_ast/src/ir/from_flat.rs b/zokrates_ast/src/ir/from_flat.rs index fc961cd85..7932f74b3 100644 --- a/zokrates_ast/src/ir/from_flat.rs +++ b/zokrates_ast/src/ir/from_flat.rs @@ -2,6 +2,8 @@ use crate::flat::{FlatDirective, FlatExpression, FlatProgIterator, FlatStatement use crate::ir::{Directive, LinComb, ProgIterator, QuadComb, Statement}; use zokrates_field::Field; +use super::SolverMap; + impl QuadComb { fn from_flat_expression>>(flat_expression: U) -> QuadComb { let flat_expression = flat_expression.into(); @@ -24,6 +26,7 @@ pub fn from_flat<'ast, T: Field, I: IntoIterator>> statements: flat_prog_iterator.statements.into_iter().map(Into::into), arguments: flat_prog_iterator.arguments, return_count: flat_prog_iterator.return_count, + solvers: SolverMap::default(), } } diff --git a/zokrates_ast/src/ir/mod.rs b/zokrates_ast/src/ir/mod.rs index 78b48f808..de73b6834 100644 --- a/zokrates_ast/src/ir/mod.rs +++ b/zokrates_ast/src/ir/mod.rs @@ -2,7 +2,7 @@ use crate::common::FormatString; use crate::typed::ConcreteType; use derivative::Derivative; use serde::{Deserialize, Serialize}; -use std::collections::BTreeSet; +use std::collections::{BTreeSet, HashMap}; use std::fmt; use std::hash::Hash; use zokrates_field::Field; @@ -22,8 +22,8 @@ pub use self::expression::{CanonicalLinComb, LinComb}; pub use self::serialize::ProgEnum; pub use crate::common::Parameter; pub use crate::common::RuntimeError; -pub use crate::common::Solver; pub use crate::common::Variable; +pub use crate::common::{Solver, ZirSolver}; pub use self::witness::Witness; @@ -124,20 +124,29 @@ impl<'ast, T: Field> fmt::Display for Statement<'ast, T> { } pub type Prog<'ast, T> = ProgIterator<'ast, T, Vec>>; +pub type SolverMap<'ast, T> = HashMap>; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] pub struct ProgIterator<'ast, T, I: IntoIterator>> { pub arguments: Vec, pub return_count: usize, pub statements: I, + #[serde(borrow)] + pub solvers: SolverMap<'ast, T>, } impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, I> { - pub fn new(arguments: Vec, statements: I, return_count: usize) -> Self { + pub fn new( + arguments: Vec, + statements: I, + return_count: usize, + solvers: SolverMap<'ast, T>, + ) -> Self { Self { arguments, return_count, statements, + solvers, } } @@ -146,6 +155,7 @@ impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, statements: self.statements.into_iter().collect::>(), arguments: self.arguments, return_count: self.return_count, + solvers: self.solvers, } } @@ -192,6 +202,7 @@ impl<'ast, T> Prog<'ast, T> { statements: self.statements.into_iter(), arguments: self.arguments, return_count: self.return_count, + solvers: self.solvers, } } } diff --git a/zokrates_codegen/src/lib.rs b/zokrates_codegen/src/lib.rs index e149a9688..795fa00ab 100644 --- a/zokrates_codegen/src/lib.rs +++ b/zokrates_codegen/src/lib.rs @@ -24,7 +24,7 @@ use zokrates_ast::common::embed::*; use zokrates_ast::common::FlatEmbed; use zokrates_ast::common::{RuntimeError, Variable}; use zokrates_ast::flat::*; -use zokrates_ast::ir::Solver; +use zokrates_ast::ir::{Solver, ZirSolver}; use zokrates_ast::zir::types::{Type, UBitwidth}; use zokrates_ast::zir::{ BooleanExpression, Conditional, FieldElementExpression, Identifier, Parameter as ZirParameter, @@ -2241,7 +2241,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { .into_iter() .map(|assignee| self.use_variable(&assignee)) .collect(); - let directive = FlatDirective::new(outputs, Solver::Zir(function), inputs); + + let solver = Solver::Zir(ZirSolver::Function(function)); + let directive = FlatDirective::new(outputs, solver, inputs); + statements_flattened.push_back(FlatStatement::Directive(directive)); } ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => { From 49889111836b9b6f7c3d65ebda2d058af8a8552f Mon Sep 17 00:00:00 2001 From: dark64 Date: Thu, 26 Jan 2023 15:18:31 +0100 Subject: [PATCH 02/32] optimize zir solvers by indexing --- .../src/flatten_complex_types.rs | 4 +- zokrates_ast/src/common/mod.rs | 2 +- zokrates_ast/src/common/solvers.rs | 37 +---- zokrates_ast/src/ir/from_flat.rs | 4 +- zokrates_ast/src/ir/mod.rs | 10 +- zokrates_ast/src/ir/serialize.rs | 142 ++++++++++-------- zokrates_ast/src/ir/solver_indexer.rs | 50 ++++++ zokrates_ast/src/zir/identifier.rs | 12 +- zokrates_ast/src/zir/mod.rs | 1 + zokrates_ast/src/zir/substitution.rs | 26 ++++ zokrates_codegen/src/lib.rs | 32 +++- zokrates_core/src/optimizer/duplicate.rs | 20 ++- zokrates_core/src/optimizer/mod.rs | 1 + zokrates_core/src/optimizer/redefinition.rs | 2 +- zokrates_interpreter/src/lib.rs | 11 +- 15 files changed, 235 insertions(+), 119 deletions(-) create mode 100644 zokrates_ast/src/ir/solver_indexer.rs create mode 100644 zokrates_ast/src/zir/substitution.rs diff --git a/zokrates_analysis/src/flatten_complex_types.rs b/zokrates_analysis/src/flatten_complex_types.rs index f4b81d8e1..4fa123bc1 100644 --- a/zokrates_analysis/src/flatten_complex_types.rs +++ b/zokrates_analysis/src/flatten_complex_types.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::BTreeMap; use std::convert::{TryFrom, TryInto}; use std::marker::PhantomData; use zokrates_ast::typed::types::{ConcreteArrayType, IntoType, UBitwidth}; @@ -481,7 +481,7 @@ impl<'ast, T: Field> Flattener { // } #[derive(Default)] pub struct ArgumentFinder<'ast, T> { - pub identifiers: HashMap, zir::Type>, + pub identifiers: BTreeMap, zir::Type>, _phantom: PhantomData, } diff --git a/zokrates_ast/src/common/mod.rs b/zokrates_ast/src/common/mod.rs index e0337dec2..13d23bfdb 100644 --- a/zokrates_ast/src/common/mod.rs +++ b/zokrates_ast/src/common/mod.rs @@ -10,6 +10,6 @@ pub use self::embed::FlatEmbed; pub use self::error::RuntimeError; pub use self::metadata::SourceMetadata; pub use self::parameter::Parameter; -pub use self::solvers::{Solver, ZirSolver}; +pub use self::solvers::Solver; pub use self::variable::Variable; pub use format_string::FormatString; diff --git a/zokrates_ast/src/common/solvers.rs b/zokrates_ast/src/common/solvers.rs index 42318342f..0551bb6d3 100644 --- a/zokrates_ast/src/common/solvers.rs +++ b/zokrates_ast/src/common/solvers.rs @@ -2,22 +2,6 @@ use crate::zir::ZirFunction; use serde::{Deserialize, Serialize}; use std::fmt; -#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] -pub enum ZirSolver<'ast, T> { - #[serde(borrow)] - Function(ZirFunction<'ast, T>), - Indexed(usize, usize), -} - -impl<'ast, T> fmt::Display for ZirSolver<'ast, T> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - ZirSolver::Function(_) => write!(f, "Zir(..)"), - ZirSolver::Indexed(index, n) => write!(f, "Zir@{}({})", index, n), - } - } -} - #[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] pub enum Solver<'ast, T> { ConditionEq, @@ -29,22 +13,14 @@ pub enum Solver<'ast, T> { ShaCh, EuclideanDiv, #[serde(borrow)] - Zir(ZirSolver<'ast, T>), + Zir(ZirFunction<'ast, T>), + IndexedCall(usize, usize), #[cfg(feature = "bellman")] Sha256Round, #[cfg(feature = "ark")] SnarkVerifyBls12377(usize), } -impl<'ast, T> Solver<'ast, T> { - pub fn zir_function(function: ZirFunction<'ast, T>) -> Self { - Solver::Zir(ZirSolver::Function(function)) - } - pub fn zir_indexed(index: usize, argument_count: usize) -> Self { - Solver::Zir(ZirSolver::Indexed(index, argument_count)) - } -} - impl<'ast, T> fmt::Display for Solver<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -56,7 +32,8 @@ impl<'ast, T> fmt::Display for Solver<'ast, T> { Solver::ShaAndXorAndXorAnd => write!(f, "ShaAndXorAndXorAnd"), Solver::ShaCh => write!(f, "ShaCh"), Solver::EuclideanDiv => write!(f, "EuclideanDiv"), - Solver::Zir(s) => write!(f, "{}", s), + Solver::Zir(_) => write!(f, "Zir(..)"), + Solver::IndexedCall(index, argc) => write!(f, "IndexedCall@{}({})", index, argc), #[cfg(feature = "bellman")] Solver::Sha256Round => write!(f, "Sha256Round"), #[cfg(feature = "ark")] @@ -76,10 +53,8 @@ impl<'ast, T> Solver<'ast, T> { Solver::ShaAndXorAndXorAnd => (3, 1), Solver::ShaCh => (3, 1), Solver::EuclideanDiv => (2, 2), - Solver::Zir(s) => match s { - ZirSolver::Function(f) => (f.arguments.len(), 1), - ZirSolver::Indexed(_, n) => (*n, 1), - }, + Solver::Zir(f) => (f.arguments.len(), 1), + Solver::IndexedCall(_, n) => (*n, 1), #[cfg(feature = "bellman")] Solver::Sha256Round => (768, 26935), #[cfg(feature = "ark")] diff --git a/zokrates_ast/src/ir/from_flat.rs b/zokrates_ast/src/ir/from_flat.rs index 7932f74b3..3404c6c49 100644 --- a/zokrates_ast/src/ir/from_flat.rs +++ b/zokrates_ast/src/ir/from_flat.rs @@ -2,8 +2,6 @@ use crate::flat::{FlatDirective, FlatExpression, FlatProgIterator, FlatStatement use crate::ir::{Directive, LinComb, ProgIterator, QuadComb, Statement}; use zokrates_field::Field; -use super::SolverMap; - impl QuadComb { fn from_flat_expression>>(flat_expression: U) -> QuadComb { let flat_expression = flat_expression.into(); @@ -26,7 +24,7 @@ pub fn from_flat<'ast, T: Field, I: IntoIterator>> statements: flat_prog_iterator.statements.into_iter().map(Into::into), arguments: flat_prog_iterator.arguments, return_count: flat_prog_iterator.return_count, - solvers: SolverMap::default(), + solvers: vec![], } } diff --git a/zokrates_ast/src/ir/mod.rs b/zokrates_ast/src/ir/mod.rs index de73b6834..b401b757a 100644 --- a/zokrates_ast/src/ir/mod.rs +++ b/zokrates_ast/src/ir/mod.rs @@ -2,7 +2,7 @@ use crate::common::FormatString; use crate::typed::ConcreteType; use derivative::Derivative; use serde::{Deserialize, Serialize}; -use std::collections::{BTreeSet, HashMap}; +use std::collections::BTreeSet; use std::fmt; use std::hash::Hash; use zokrates_field::Field; @@ -14,6 +14,7 @@ pub mod folder; pub mod from_flat; mod serialize; pub mod smtlib2; +mod solver_indexer; pub mod visitor; mod witness; @@ -22,8 +23,8 @@ pub use self::expression::{CanonicalLinComb, LinComb}; pub use self::serialize::ProgEnum; pub use crate::common::Parameter; pub use crate::common::RuntimeError; +pub use crate::common::Solver; pub use crate::common::Variable; -pub use crate::common::{Solver, ZirSolver}; pub use self::witness::Witness; @@ -124,7 +125,6 @@ impl<'ast, T: Field> fmt::Display for Statement<'ast, T> { } pub type Prog<'ast, T> = ProgIterator<'ast, T, Vec>>; -pub type SolverMap<'ast, T> = HashMap>; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] pub struct ProgIterator<'ast, T, I: IntoIterator>> { @@ -132,7 +132,7 @@ pub struct ProgIterator<'ast, T, I: IntoIterator>> { pub return_count: usize, pub statements: I, #[serde(borrow)] - pub solvers: SolverMap<'ast, T>, + pub solvers: Vec>, } impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, I> { @@ -140,7 +140,7 @@ impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, arguments: Vec, statements: I, return_count: usize, - solvers: SolverMap<'ast, T>, + solvers: Vec>, ) -> Self { Self { arguments, diff --git a/zokrates_ast/src/ir/serialize.rs b/zokrates_ast/src/ir/serialize.rs index 09d003900..3bd403668 100644 --- a/zokrates_ast/src/ir/serialize.rs +++ b/zokrates_ast/src/ir/serialize.rs @@ -1,14 +1,18 @@ -use crate::ir::check::UnconstrainedVariableDetector; +use crate::{ + ir::{check::UnconstrainedVariableDetector, solver_indexer::SolverIndexer}, + Solver, +}; use super::{ProgIterator, Statement}; +use serde::Deserialize; use serde_cbor::{self, StreamDeserializer}; -use std::io::{Read, Write}; +use std::io::{Read, Seek, Write}; use zokrates_field::*; type DynamicError = Box; const ZOKRATES_MAGIC: &[u8; 4] = &[0x5a, 0x4f, 0x4b, 0]; -const ZOKRATES_VERSION_2: &[u8; 4] = &[0, 0, 0, 2]; +const ZOKRATES_VERSION_3: &[u8; 4] = &[0, 0, 0, 3]; #[derive(PartialEq, Eq, Debug)] pub enum ProgEnum< @@ -61,17 +65,21 @@ impl< impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'ast, T, I> { /// serialize a program iterator, returning the number of constraints serialized /// Note that we only return constraints, not other statements such as directives - pub fn serialize(self, mut w: W) -> Result { + pub fn serialize(self, mut w: W) -> Result { use super::folder::Folder; w.write_all(ZOKRATES_MAGIC)?; - w.write_all(ZOKRATES_VERSION_2)?; + w.write_all(ZOKRATES_VERSION_3)?; w.write_all(&T::id())?; + let solver_list_ptr_offset = w.stream_position()?; + w.write_all(&[0u8; std::mem::size_of::()])?; // reserve 8 bytes + serde_cbor::to_writer(&mut w, &self.arguments)?; serde_cbor::to_writer(&mut w, &self.return_count)?; let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self); + let mut solver_indexer: SolverIndexer<'ast, T> = SolverIndexer::default(); let statements = self.statements.into_iter(); @@ -80,12 +88,22 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a if matches!(s, Statement::Constraint(..)) { count += 1; } - let s = unconstrained_variable_detector.fold_statement(s); + let s: Vec> = solver_indexer + .fold_statement(s) + .into_iter() + .flat_map(|s| unconstrained_variable_detector.fold_statement(s)) + .collect(); for s in s { serde_cbor::to_writer(&mut w, &s)?; } } + let solver_list_offset = w.stream_position()?; + serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?; + + w.seek(std::io::SeekFrom::Start(solver_list_ptr_offset))?; + w.write_all(&solver_list_offset.to_le_bytes())?; + unconstrained_variable_detector .finalize() .map(|_| count) @@ -103,11 +121,11 @@ impl<'de, R: serde_cbor::de::Read<'de>, T: serde::Deserialize<'de>> Iterator type Item = T; fn next(&mut self) -> Option { - self.s.next().transpose().unwrap() + self.s.next().and_then(|v| v.ok()) } } -impl<'de, R: Read> +impl<'de, R: Read + Seek> ProgEnum< 'de, UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, Bls12_381Field>>, @@ -128,104 +146,108 @@ impl<'de, R: Read> r.read_exact(&mut version) .map_err(|_| String::from("Cannot read version"))?; - if &version == ZOKRATES_VERSION_2 { + if &version == ZOKRATES_VERSION_3 { // Check the curve identifier, deserializing accordingly let mut curve = [0; 4]; r.read_exact(&mut curve) .map_err(|_| String::from("Cannot read curve identifier"))?; - use serde::de::Deserializer; - let mut p = serde_cbor::Deserializer::from_reader(r); - - struct ArgumentsVisitor; - - impl<'de> serde::de::Visitor<'de> for ArgumentsVisitor { - type Value = Vec; - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("seq of flat param") - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: serde::de::SeqAccess<'de>, - { - let mut res = vec![]; - while let Some(e) = seq.next_element().unwrap() { - res.push(e); - } - Ok(res) - } - } - - let arguments = p.deserialize_seq(ArgumentsVisitor).unwrap(); + let mut buffer = [0u8; std::mem::size_of::()]; + r.read_exact(&mut buffer) + .map_err(|_| String::from("Cannot read solver list pointer"))?; - struct ReturnCountVisitor; + let solver_list_offset = u64::from_le_bytes(buffer); - impl<'de> serde::de::Visitor<'de> for ReturnCountVisitor { - type Value = usize; - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("usize") - } + let (arguments, return_count) = { + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - fn visit_u32(self, v: u32) -> Result - where - E: serde::de::Error, - { - Ok(v as usize) - } + let arguments: Vec = Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read parameters"))?; - fn visit_u8(self, v: u8) -> Result - where - E: serde::de::Error, - { - Ok(v as usize) - } + let return_count = usize::deserialize(&mut p) + .map_err(|_| String::from("Cannot read return count"))?; - fn visit_u16(self, v: u16) -> Result - where - E: serde::de::Error, - { - Ok(v as usize) - } - } + (arguments, return_count) + }; - let return_count = p.deserialize_u32(ReturnCountVisitor).unwrap(); + let statement_offset = r.stream_position().unwrap(); + r.seek(std::io::SeekFrom::Start(solver_list_offset)) + .unwrap(); match curve { m if m == Bls12_381Field::id() => { - let s = p.into_iter::>(); + let solvers: Vec> = { + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read solver map"))? + }; + + r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); + let p = serde_cbor::Deserializer::from_reader(r); + let s = p.into_iter::>(); Ok(ProgEnum::Bls12_381Program(ProgIterator::new( arguments, UnwrappedStreamDeserializer { s }, return_count, + solvers, ))) } m if m == Bn128Field::id() => { + let solvers: Vec> = { + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read solver map"))? + }; + + r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); + + let p = serde_cbor::Deserializer::from_reader(r); let s = p.into_iter::>(); Ok(ProgEnum::Bn128Program(ProgIterator::new( arguments, UnwrappedStreamDeserializer { s }, return_count, + solvers, ))) } m if m == Bls12_377Field::id() => { + let solvers: Vec> = { + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read solver map"))? + }; + + r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); + + let p = serde_cbor::Deserializer::from_reader(r); let s = p.into_iter::>(); Ok(ProgEnum::Bls12_377Program(ProgIterator::new( arguments, UnwrappedStreamDeserializer { s }, return_count, + solvers, ))) } m if m == Bw6_761Field::id() => { + let solvers: Vec> = { + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read solver map"))? + }; + + r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); + + let p = serde_cbor::Deserializer::from_reader(r); let s = p.into_iter::>(); Ok(ProgEnum::Bw6_761Program(ProgIterator::new( arguments, UnwrappedStreamDeserializer { s }, return_count, + solvers, ))) } _ => Err(String::from("Unknown curve identifier")), diff --git a/zokrates_ast/src/ir/solver_indexer.rs b/zokrates_ast/src/ir/solver_indexer.rs new file mode 100644 index 000000000..59393d25b --- /dev/null +++ b/zokrates_ast/src/ir/solver_indexer.rs @@ -0,0 +1,50 @@ +use crate::ir::folder::Folder; +use crate::ir::Directive; +use crate::ir::Solver; +use crate::zir::ZirFunction; +use std::collections::hash_map::DefaultHasher; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use zokrates_field::Field; + +type Hash = u64; + +fn hash(f: &ZirFunction) -> Hash { + use std::hash::Hash; + use std::hash::Hasher; + let mut hasher = DefaultHasher::new(); + f.hash(&mut hasher); + hasher.finish() +} + +#[derive(Debug, Default)] +pub struct SolverIndexer<'ast, T> { + pub solvers: Vec>, + pub index_map: HashMap, +} + +impl<'ast, T: Field> Folder<'ast, T> for SolverIndexer<'ast, T> { + fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> { + match d.solver { + Solver::Zir(f) => { + let argc = f.arguments.len(); + let h = hash(&f); + let index = match self.index_map.entry(h) { + Entry::Occupied(v) => *v.get(), + Entry::Vacant(entry) => { + let index = self.solvers.len(); + entry.insert(index); + self.solvers.push(Solver::Zir(f)); + index + } + }; + Directive { + inputs: d.inputs, + outputs: d.outputs, + solver: Solver::IndexedCall(index, argc), + } + } + _ => d, + } + } +} diff --git a/zokrates_ast/src/zir/identifier.rs b/zokrates_ast/src/zir/identifier.rs index 249b2630e..bc60f5666 100644 --- a/zokrates_ast/src/zir/identifier.rs +++ b/zokrates_ast/src/zir/identifier.rs @@ -4,13 +4,14 @@ use std::fmt; use crate::typed::Identifier as CoreIdentifier; -#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum Identifier<'ast> { #[serde(borrow)] Source(SourceIdentifier<'ast>), + Internal(String), } -#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum SourceIdentifier<'ast> { #[serde(borrow)] Basic(CoreIdentifier<'ast>), @@ -34,10 +35,17 @@ impl<'ast> fmt::Display for Identifier<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Identifier::Source(s) => write!(f, "{}", s), + Identifier::Internal(s) => write!(f, "{}", s), } } } +impl<'ast> From for Identifier<'ast> { + fn from(id: String) -> Identifier<'ast> { + Identifier::Internal(id) + } +} + // this is only used in tests but somehow cfg(test) does not work impl<'ast> From<&'ast str> for Identifier<'ast> { fn from(id: &'ast str) -> Identifier<'ast> { diff --git a/zokrates_ast/src/zir/mod.rs b/zokrates_ast/src/zir/mod.rs index 60dc1467f..c2af87b0b 100644 --- a/zokrates_ast/src/zir/mod.rs +++ b/zokrates_ast/src/zir/mod.rs @@ -4,6 +4,7 @@ mod identifier; pub mod lqc; mod parameter; pub mod result_folder; +pub mod substitution; pub mod types; mod uint; mod variable; diff --git a/zokrates_ast/src/zir/substitution.rs b/zokrates_ast/src/zir/substitution.rs new file mode 100644 index 000000000..89a8f82db --- /dev/null +++ b/zokrates_ast/src/zir/substitution.rs @@ -0,0 +1,26 @@ +use super::{Folder, Identifier}; +use std::{collections::HashMap, marker::PhantomData}; +use zokrates_field::Field; + +pub struct ZirSubstitutor<'a, 'ast, T> { + substitution: &'a HashMap, Identifier<'ast>>, + _phantom: PhantomData, +} + +impl<'a, 'ast, T: Field> ZirSubstitutor<'a, 'ast, T> { + pub fn new(substitution: &'a HashMap, Identifier<'ast>>) -> Self { + ZirSubstitutor { + substitution, + _phantom: PhantomData::default(), + } + } +} + +impl<'a, 'ast, T: Field> Folder<'ast, T> for ZirSubstitutor<'a, 'ast, T> { + fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { + match self.substitution.get(&n) { + Some(v) => v.clone(), + None => n, + } + } +} diff --git a/zokrates_codegen/src/lib.rs b/zokrates_codegen/src/lib.rs index 795fa00ab..ea954b89b 100644 --- a/zokrates_codegen/src/lib.rs +++ b/zokrates_codegen/src/lib.rs @@ -11,8 +11,8 @@ mod utils; use self::utils::flat_expression_from_bits; use zokrates_ast::zir::{ - ConditionalExpression, SelectExpression, ShouldReduce, UMetadata, ZirAssemblyStatement, - ZirExpressionList, + substitution::ZirSubstitutor, ConditionalExpression, Folder, SelectExpression, ShouldReduce, + UMetadata, ZirAssemblyStatement, ZirExpressionList, }; use zokrates_interpreter::Interpreter; @@ -24,7 +24,7 @@ use zokrates_ast::common::embed::*; use zokrates_ast::common::FlatEmbed; use zokrates_ast::common::{RuntimeError, Variable}; use zokrates_ast::flat::*; -use zokrates_ast::ir::{Solver, ZirSolver}; +use zokrates_ast::ir::Solver; use zokrates_ast::zir::types::{Type, UBitwidth}; use zokrates_ast::zir::{ BooleanExpression, Conditional, FieldElementExpression, Identifier, Parameter as ZirParameter, @@ -1885,7 +1885,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // constants do not require directives if let Some(FlatExpression::Number(ref x)) = e.field { - let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()]) + let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()], &[]) .unwrap() .into_iter() .map(FlatExpression::Number) @@ -2237,12 +2237,34 @@ impl<'ast, T: Field> Flattener<'ast, T> { .cloned() .map(|p| self.layout.get(&p.id.id).cloned().unwrap().into()) .collect(); + let outputs: Vec = assignees .into_iter() .map(|assignee| self.use_variable(&assignee)) .collect(); - let solver = Solver::Zir(ZirSolver::Function(function)); + let mut substitution_map = HashMap::default(); + for (index, p) in function.arguments.iter().enumerate() { + let new_id = format!("i{}", index).into(); + substitution_map.insert(p.id.id.clone(), new_id); + } + + let mut substitutor = ZirSubstitutor::new(&substitution_map); + let function = ZirFunction { + arguments: function + .arguments + .into_iter() + .map(|p| substitutor.fold_parameter(p)) + .collect(), + statements: function + .statements + .into_iter() + .flat_map(|s| substitutor.fold_statement(s)) + .collect(), + signature: function.signature, + }; + + let solver = Solver::Zir(function); let directive = FlatDirective::new(outputs, solver, inputs); statements_flattened.push_back(FlatStatement::Directive(directive)); diff --git a/zokrates_core/src/optimizer/duplicate.rs b/zokrates_core/src/optimizer/duplicate.rs index 664cfc2db..a08a3f73b 100644 --- a/zokrates_core/src/optimizer/duplicate.rs +++ b/zokrates_core/src/optimizer/duplicate.rs @@ -39,14 +39,18 @@ impl<'ast, T: Field> Folder<'ast, T> for DuplicateOptimizer { } fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { - let hashed = hash(&s); - let result = match self.seen.get(&hashed) { - Some(_) => vec![], - None => vec![s], - }; - - self.seen.insert(hashed); - result + match s { + Statement::Block(s) => s.into_iter().flat_map(|s| self.fold_statement(s)).collect(), + s => { + let hashed = hash(&s); + let result = match self.seen.get(&hashed) { + Some(_) => vec![], + None => vec![s], + }; + self.seen.insert(hashed); + result + } + } } } diff --git a/zokrates_core/src/optimizer/mod.rs b/zokrates_core/src/optimizer/mod.rs index 1f94740a9..95e9f25cb 100644 --- a/zokrates_core/src/optimizer/mod.rs +++ b/zokrates_core/src/optimizer/mod.rs @@ -54,6 +54,7 @@ pub fn optimize<'ast, T: Field, I: IntoIterator>>( .flat_map(move |s| directive_optimizer.fold_statement(s)) .flat_map(move |s| duplicate_optimizer.fold_statement(s)), return_count: p.return_count, + solvers: p.solvers, }; log::debug!("Done"); diff --git a/zokrates_core/src/optimizer/redefinition.rs b/zokrates_core/src/optimizer/redefinition.rs index b0877fd02..bc7fc279c 100644 --- a/zokrates_core/src/optimizer/redefinition.rs +++ b/zokrates_core/src/optimizer/redefinition.rs @@ -146,7 +146,7 @@ impl RedefinitionOptimizer { // unwrap inputs to their constant value let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect(); // run the solver - let outputs = Interpreter::execute_solver(&d.solver, &inputs).unwrap(); + let outputs = Interpreter::execute_solver(&d.solver, &inputs, &[]).unwrap(); assert_eq!(outputs.len(), d.outputs.len()); // insert the results in the substitution diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index 9f90e00ca..c8182b202 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -83,7 +83,7 @@ impl Interpreter { inputs.pop().unwrap(), )) } - _ => Self::execute_solver(&d.solver, &inputs), + _ => Self::execute_solver(&d.solver, &inputs, &program.solvers), } .map_err(Error::Solver)?; @@ -164,7 +164,15 @@ impl Interpreter { pub fn execute_solver<'ast, T: Field>( solver: &Solver<'ast, T>, inputs: &[T], + solvers: &[Solver<'ast, T>], ) -> Result, String> { + let solver = match solver { + Solver::IndexedCall(index, _) => solvers + .get(*index) + .ok_or_else(|| format!("Could not resolve solver at index {}", index))?, + s => s, + }; + let (expected_input_count, expected_output_count) = solver.get_signature(); assert_eq!(inputs.len(), expected_input_count); @@ -334,6 +342,7 @@ impl Interpreter { &inputs[*n + 8usize..], ) } + _ => unreachable!("unexpected solver"), }; assert_eq!(res.len(), expected_output_count); From 6a16198bed54afc7dbe7ceb1afdee32ec15044c0 Mon Sep 17 00:00:00 2001 From: dark64 Date: Thu, 26 Jan 2023 15:21:00 +0100 Subject: [PATCH 03/32] fix message --- zokrates_ast/src/ir/serialize.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/zokrates_ast/src/ir/serialize.rs b/zokrates_ast/src/ir/serialize.rs index 3bd403668..246cdbb56 100644 --- a/zokrates_ast/src/ir/serialize.rs +++ b/zokrates_ast/src/ir/serialize.rs @@ -154,7 +154,7 @@ impl<'de, R: Read + Seek> let mut buffer = [0u8; std::mem::size_of::()]; r.read_exact(&mut buffer) - .map_err(|_| String::from("Cannot read solver list pointer"))?; + .map_err(|_| String::from("Cannot read solver list offset"))?; let solver_list_offset = u64::from_le_bytes(buffer); @@ -179,7 +179,7 @@ impl<'de, R: Read + Seek> let solvers: Vec> = { let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solver map"))? + .map_err(|_| String::from("Cannot read solver list"))? }; r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); @@ -197,7 +197,7 @@ impl<'de, R: Read + Seek> let solvers: Vec> = { let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solver map"))? + .map_err(|_| String::from("Cannot read solver list"))? }; r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); @@ -216,7 +216,7 @@ impl<'de, R: Read + Seek> let solvers: Vec> = { let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solver map"))? + .map_err(|_| String::from("Cannot read solver list"))? }; r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); @@ -235,7 +235,7 @@ impl<'de, R: Read + Seek> let solvers: Vec> = { let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solver map"))? + .map_err(|_| String::from("Cannot read solver list"))? }; r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); From fdd441c37447abafbe13ed740871cc2d50b3006d Mon Sep 17 00:00:00 2001 From: dark64 Date: Thu, 26 Jan 2023 15:24:24 +0100 Subject: [PATCH 04/32] use cursor in zokrates_js --- zokrates_js/src/lib.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/zokrates_js/src/lib.rs b/zokrates_js/src/lib.rs index 07082c3e5..c776f34b5 100644 --- a/zokrates_js/src/lib.rs +++ b/zokrates_js/src/lib.rs @@ -471,7 +471,8 @@ pub fn compute_witness( config: JsValue, log_callback: &js_sys::Function, ) -> Result { - let prog = ir::ProgEnum::deserialize(program) + let cursor = Cursor::new(program); + let prog = ir::ProgEnum::deserialize(cursor) .map_err(|err| JsValue::from_str(&err))? .collect(); match prog { @@ -533,7 +534,8 @@ pub fn setup(program: &[u8], options: JsValue) -> Result { ) .map_err(|e| JsValue::from_str(&e))?; - let prog = ir::ProgEnum::deserialize(program) + let cursor = Cursor::new(program); + let prog = ir::ProgEnum::deserialize(cursor) .map_err(|err| JsValue::from_str(&err))? .collect(); @@ -572,7 +574,8 @@ pub fn setup_with_srs(srs: &[u8], program: &[u8], options: JsValue) -> Result Date: Tue, 28 Feb 2023 02:04:02 +0100 Subject: [PATCH 05/32] refactor, fix tests --- Cargo.lock | 1 + zokrates_ark/src/gm17.rs | 2 + zokrates_ark/src/groth16.rs | 2 + zokrates_ark/src/marlin.rs | 2 + zokrates_ast/Cargo.toml | 1 + zokrates_ast/src/ir/serialize.rs | 401 +++++++++++++------- zokrates_bellman/src/groth16.rs | 1 + zokrates_bellman/src/lib.rs | 6 + zokrates_circom/src/lib.rs | 1 + zokrates_circom/src/r1cs.rs | 2 + zokrates_core/src/optimizer/duplicate.rs | 3 + zokrates_core/src/optimizer/redefinition.rs | 11 + zokrates_interpreter/src/lib.rs | 22 +- zokrates_test/tests/wasm.rs | 1 + 14 files changed, 313 insertions(+), 143 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e21917722..dce729b97 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2959,6 +2959,7 @@ name = "zokrates_ast" version = "0.1.4" dependencies = [ "ark-bls12-377", + "byteorder", "cfg-if 0.1.10", "csv", "derivative", diff --git a/zokrates_ark/src/gm17.rs b/zokrates_ark/src/gm17.rs index a8941ed03..337770b78 100644 --- a/zokrates_ark/src/gm17.rs +++ b/zokrates_ark/src/gm17.rs @@ -123,6 +123,7 @@ mod tests { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); @@ -148,6 +149,7 @@ mod tests { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); diff --git a/zokrates_ark/src/groth16.rs b/zokrates_ark/src/groth16.rs index e9281210a..1a832f1ee 100644 --- a/zokrates_ark/src/groth16.rs +++ b/zokrates_ark/src/groth16.rs @@ -120,6 +120,7 @@ mod tests { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); @@ -145,6 +146,7 @@ mod tests { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); diff --git a/zokrates_ark/src/marlin.rs b/zokrates_ark/src/marlin.rs index b0e0f75e7..233f9ac58 100644 --- a/zokrates_ark/src/marlin.rs +++ b/zokrates_ark/src/marlin.rs @@ -404,6 +404,7 @@ mod tests { ), Statement::constraint(Variable::new(1), Variable::public(0)), ], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); @@ -439,6 +440,7 @@ mod tests { ), Statement::constraint(Variable::new(1), Variable::public(0)), ], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); diff --git a/zokrates_ast/Cargo.toml b/zokrates_ast/Cargo.toml index 6d9b4324b..62a403437 100644 --- a/zokrates_ast/Cargo.toml +++ b/zokrates_ast/Cargo.toml @@ -9,6 +9,7 @@ bellman = ["zokrates_field/bellman", "pairing_ce", "zokrates_embed/bellman"] ark = ["ark-bls12-377", "zokrates_embed/ark"] [dependencies] +byteorder = "1.4.3" zokrates_pest_ast = { version = "0.3.0", path = "../zokrates_pest_ast" } cfg-if = "0.1" zokrates_field = { version = "0.5", path = "../zokrates_field", default-features = false } diff --git a/zokrates_ast/src/ir/serialize.rs b/zokrates_ast/src/ir/serialize.rs index 246cdbb56..6266662d4 100644 --- a/zokrates_ast/src/ir/serialize.rs +++ b/zokrates_ast/src/ir/serialize.rs @@ -4,6 +4,7 @@ use crate::{ }; use super::{ProgIterator, Statement}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use serde::Deserialize; use serde_cbor::{self, StreamDeserializer}; use std::io::{Read, Seek, Write}; @@ -12,7 +13,7 @@ use zokrates_field::*; type DynamicError = Box; const ZOKRATES_MAGIC: &[u8; 4] = &[0x5a, 0x4f, 0x4b, 0]; -const ZOKRATES_VERSION_3: &[u8; 4] = &[0, 0, 0, 3]; +const FILE_VERSION: &[u8; 4] = &[3, 0, 0, 0]; #[derive(PartialEq, Eq, Debug)] pub enum ProgEnum< @@ -62,47 +63,195 @@ impl< } } +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum SectionType { + Arguments = 1, + Constraints = 2, + Solvers = 3, +} + +impl TryFrom for SectionType { + type Error = String; + + fn try_from(value: u32) -> Result { + match value { + 1 => Ok(SectionType::Arguments), + 2 => Ok(SectionType::Constraints), + 3 => Ok(SectionType::Solvers), + _ => Err("invalid section type".to_string()), + } + } +} + +#[derive(Debug, Clone)] +pub struct Section { + pub ty: SectionType, + pub offset: u64, + pub length: u64, +} + +impl Section { + pub fn new(ty: SectionType) -> Self { + Self { + ty, + offset: 0, + length: 0, + } + } + + pub fn set_offset(&mut self, offset: u64) { + self.offset = offset; + } + + pub fn set_length(&mut self, length: u64) { + self.length = length; + } +} + +#[derive(Debug, Clone)] +pub struct ProgHeader { + pub magic: [u8; 4], + pub version: [u8; 4], + pub curve_id: [u8; 4], + pub constraint_count: u32, + pub return_count: u32, + pub sections: Vec
, +} + +impl ProgHeader { + pub fn write(&self, mut w: W) -> std::io::Result<()> { + w.write_all(&self.magic)?; + w.write_all(&self.version)?; + w.write_all(&self.curve_id)?; + w.write_u32::(self.constraint_count)?; + w.write_u32::(self.return_count)?; + + w.write_u32::(self.sections.len() as u32)?; + for s in &self.sections { + w.write_u32::(s.ty as u32)?; + w.write_u64::(s.offset)?; + w.write_u64::(s.length)?; + } + + Ok(()) + } + + pub fn read(mut r: R) -> std::io::Result { + let mut magic = [0; 4]; + r.read_exact(&mut magic)?; + + let mut version = [0; 4]; + r.read_exact(&mut version)?; + + let mut curve_id = [0; 4]; + r.read_exact(&mut curve_id)?; + + let constraint_count = r.read_u32::()?; + let return_count = r.read_u32::()?; + + let section_count = r.read_u32::()?; + let mut sections = vec![]; + + for _ in 0..section_count { + let id = r.read_u32::()?; + let mut section = Section::new( + SectionType::try_from(id) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?, + ); + section.set_offset(r.read_u64::()?); + section.set_length(r.read_u64::()?); + sections.push(section); + } + + Ok(ProgHeader { + magic, + version, + curve_id, + constraint_count, + return_count, + sections, + }) + } + + fn get_section(&self, ty: SectionType) -> Option<&Section> { + self.sections.iter().find(|s| s.ty == ty) + } +} + impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'ast, T, I> { /// serialize a program iterator, returning the number of constraints serialized /// Note that we only return constraints, not other statements such as directives pub fn serialize(self, mut w: W) -> Result { use super::folder::Folder; - w.write_all(ZOKRATES_MAGIC)?; - w.write_all(ZOKRATES_VERSION_3)?; - w.write_all(&T::id())?; + const SECTION_COUNT: usize = 3; + const HEADER_SIZE: usize = 24 + SECTION_COUNT * 20; - let solver_list_ptr_offset = w.stream_position()?; - w.write_all(&[0u8; std::mem::size_of::()])?; // reserve 8 bytes + let mut header = ProgHeader { + magic: *ZOKRATES_MAGIC, + version: *FILE_VERSION, + curve_id: T::id(), + constraint_count: 0, + return_count: self.return_count as u32, + sections: Vec::with_capacity(SECTION_COUNT), + }; - serde_cbor::to_writer(&mut w, &self.arguments)?; - serde_cbor::to_writer(&mut w, &self.return_count)?; + w.write_all(&[0u8; HEADER_SIZE])?; // reserve bytes for the header + + let arguments = { + let mut section = Section::new(SectionType::Arguments); + section.set_offset(w.stream_position()?); + + serde_cbor::to_writer(&mut w, &self.arguments)?; + + section.set_length(w.stream_position()? - section.offset); + section + }; - let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self); let mut solver_indexer: SolverIndexer<'ast, T> = SolverIndexer::default(); + let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self); + let mut count: usize = 0; - let statements = self.statements.into_iter(); + let constraints = { + let mut section = Section::new(SectionType::Constraints); + section.set_offset(w.stream_position()?); - let mut count = 0; - for s in statements { - if matches!(s, Statement::Constraint(..)) { - count += 1; - } - let s: Vec> = solver_indexer - .fold_statement(s) - .into_iter() - .flat_map(|s| unconstrained_variable_detector.fold_statement(s)) - .collect(); - for s in s { - serde_cbor::to_writer(&mut w, &s)?; + let statements = self.statements.into_iter(); + for s in statements { + if matches!(s, Statement::Constraint(..)) { + count += 1; + } + let s: Vec> = solver_indexer + .fold_statement(s) + .into_iter() + .flat_map(|s| unconstrained_variable_detector.fold_statement(s)) + .collect(); + for s in s { + serde_cbor::to_writer(&mut w, &s)?; + } } - } - let solver_list_offset = w.stream_position()?; - serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?; + section.set_length(w.stream_position()? - section.offset); + section + }; - w.seek(std::io::SeekFrom::Start(solver_list_ptr_offset))?; - w.write_all(&solver_list_offset.to_le_bytes())?; + let solvers = { + let mut section = Section::new(SectionType::Solvers); + section.set_offset(w.stream_position()?); + + serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?; + + section.set_length(w.stream_position()? - section.offset); + section + }; + + header.constraint_count = count as u32; + header + .sections + .extend_from_slice(&[arguments, constraints, solvers]); + + w.rewind()?; + header.write(&mut w)?; unconstrained_variable_detector .finalize() @@ -135,128 +284,108 @@ impl<'de, R: Read + Seek> > { pub fn deserialize(mut r: R) -> Result { + let header = ProgHeader::read(&mut r).map_err(|_| String::from("Invalid header"))?; + // Check the magic number, `ZOK` - let mut magic = [0; 4]; - r.read_exact(&mut magic) - .map_err(|_| String::from("Cannot read magic number"))?; + if &header.magic != ZOKRATES_MAGIC { + return Err("Invalid magic number".to_string()); + } - if &magic == ZOKRATES_MAGIC { - // Check the version, 2 - let mut version = [0; 4]; - r.read_exact(&mut version) - .map_err(|_| String::from("Cannot read version"))?; + // Check the file version + if &header.version != FILE_VERSION { + return Err("Invalid file version".to_string()); + } - if &version == ZOKRATES_VERSION_3 { - // Check the curve identifier, deserializing accordingly - let mut curve = [0; 4]; - r.read_exact(&mut curve) - .map_err(|_| String::from("Cannot read curve identifier"))?; + let arguments = { + let section = header.get_section(SectionType::Arguments).unwrap(); + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - let mut buffer = [0u8; std::mem::size_of::()]; - r.read_exact(&mut buffer) - .map_err(|_| String::from("Cannot read solver list offset"))?; + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read parameters"))? + }; - let solver_list_offset = u64::from_le_bytes(buffer); + let solvers_section = header.get_section(SectionType::Solvers).unwrap(); + r.seek(std::io::SeekFrom::Start(solvers_section.offset)) + .unwrap(); - let (arguments, return_count) = { + match header.curve_id { + m if m == Bls12_381Field::id() => { + let solvers: Vec> = { let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))? + }; - let arguments: Vec = Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read parameters"))?; + let section = header.get_section(SectionType::Constraints).unwrap(); + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - let return_count = usize::deserialize(&mut p) - .map_err(|_| String::from("Cannot read return count"))?; + let p = serde_cbor::Deserializer::from_reader(r); + let s = p.into_iter::>(); - (arguments, return_count) + Ok(ProgEnum::Bls12_381Program(ProgIterator::new( + arguments, + UnwrappedStreamDeserializer { s }, + header.return_count as usize, + solvers, + ))) + } + m if m == Bn128Field::id() => { + let solvers: Vec> = { + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))? }; - let statement_offset = r.stream_position().unwrap(); - r.seek(std::io::SeekFrom::Start(solver_list_offset)) - .unwrap(); - - match curve { - m if m == Bls12_381Field::id() => { - let solvers: Vec> = { - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solver list"))? - }; - - r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); - - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); - Ok(ProgEnum::Bls12_381Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - return_count, - solvers, - ))) - } - m if m == Bn128Field::id() => { - let solvers: Vec> = { - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solver list"))? - }; - - r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); - - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); - - Ok(ProgEnum::Bn128Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - return_count, - solvers, - ))) - } - m if m == Bls12_377Field::id() => { - let solvers: Vec> = { - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solver list"))? - }; - - r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); - - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); - - Ok(ProgEnum::Bls12_377Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - return_count, - solvers, - ))) - } - m if m == Bw6_761Field::id() => { - let solvers: Vec> = { - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solver list"))? - }; - - r.seek(std::io::SeekFrom::Start(statement_offset)).unwrap(); - - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); - - Ok(ProgEnum::Bw6_761Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - return_count, - solvers, - ))) - } - _ => Err(String::from("Unknown curve identifier")), - } - } else { - Err(String::from("Unknown version")) + let section = header.get_section(SectionType::Constraints).unwrap(); + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let p = serde_cbor::Deserializer::from_reader(r); + let s = p.into_iter::>(); + + Ok(ProgEnum::Bn128Program(ProgIterator::new( + arguments, + UnwrappedStreamDeserializer { s }, + header.return_count as usize, + solvers, + ))) + } + m if m == Bls12_377Field::id() => { + let solvers: Vec> = { + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))? + }; + + let section = header.get_section(SectionType::Constraints).unwrap(); + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let p = serde_cbor::Deserializer::from_reader(r); + let s = p.into_iter::>(); + + Ok(ProgEnum::Bls12_377Program(ProgIterator::new( + arguments, + UnwrappedStreamDeserializer { s }, + header.return_count as usize, + solvers, + ))) + } + m if m == Bw6_761Field::id() => { + let solvers: Vec> = { + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))? + }; + + let section = header.get_section(SectionType::Constraints).unwrap(); + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let p = serde_cbor::Deserializer::from_reader(r); + let s = p.into_iter::>(); + + Ok(ProgEnum::Bw6_761Program(ProgIterator::new( + arguments, + UnwrappedStreamDeserializer { s }, + header.return_count as usize, + solvers, + ))) } - } else { - Err(String::from("Wrong magic number")) + _ => Err(String::from("Unknown curve identifier")), } } } diff --git a/zokrates_bellman/src/groth16.rs b/zokrates_bellman/src/groth16.rs index d0e6c4178..1eae3b46a 100644 --- a/zokrates_bellman/src/groth16.rs +++ b/zokrates_bellman/src/groth16.rs @@ -214,6 +214,7 @@ mod tests { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); diff --git a/zokrates_bellman/src/lib.rs b/zokrates_bellman/src/lib.rs index 9828b9745..821e3ea2d 100644 --- a/zokrates_bellman/src/lib.rs +++ b/zokrates_bellman/src/lib.rs @@ -276,6 +276,7 @@ mod tests { arguments: vec![Parameter::private(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let interpreter = Interpreter::default(); @@ -297,6 +298,7 @@ mod tests { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + solvers: vec![], }; let interpreter = Interpreter::default(); @@ -318,6 +320,7 @@ mod tests { arguments: vec![], return_count: 1, statements: vec![Statement::constraint(Variable::one(), Variable::public(0))], + solvers: vec![], }; let interpreter = Interpreter::default(); @@ -350,6 +353,7 @@ mod tests { Variable::public(1), ), ], + solvers: vec![], }; let interpreter = Interpreter::default(); @@ -373,6 +377,7 @@ mod tests { LinComb::from(Variable::new(42)) + LinComb::one(), Variable::public(0), )], + solvers: vec![], }; let interpreter = Interpreter::default(); @@ -400,6 +405,7 @@ mod tests { LinComb::from(Variable::new(42)) + LinComb::from(Variable::new(51)), Variable::public(0), )], + solvers: vec![], }; let interpreter = Interpreter::default(); diff --git a/zokrates_circom/src/lib.rs b/zokrates_circom/src/lib.rs index 9b16742fb..e87c0af38 100644 --- a/zokrates_circom/src/lib.rs +++ b/zokrates_circom/src/lib.rs @@ -44,6 +44,7 @@ mod tests { None, ), ], + solvers: vec![], }; let mut r1cs = vec![]; diff --git a/zokrates_circom/src/r1cs.rs b/zokrates_circom/src/r1cs.rs index 854bc0eab..b2fc8db26 100644 --- a/zokrates_circom/src/r1cs.rs +++ b/zokrates_circom/src/r1cs.rs @@ -296,6 +296,7 @@ mod tests { Variable::public(0).into(), None, )], + solvers: vec![], }; let mut buf = Vec::new(); @@ -365,6 +366,7 @@ mod tests { None, ), ], + solvers: vec![], }; let mut buf = Vec::new(); diff --git a/zokrates_core/src/optimizer/duplicate.rs b/zokrates_core/src/optimizer/duplicate.rs index a08a3f73b..346695a27 100644 --- a/zokrates_core/src/optimizer/duplicate.rs +++ b/zokrates_core/src/optimizer/duplicate.rs @@ -81,6 +81,7 @@ mod tests { ], return_count: 0, arguments: vec![], + solvers: vec![], }; let expected = p.clone(); @@ -117,6 +118,7 @@ mod tests { ], return_count: 0, arguments: vec![], + solvers: vec![], }; let expected = Prog { @@ -132,6 +134,7 @@ mod tests { ], return_count: 0, arguments: vec![], + solvers: vec![], }; assert_eq!( diff --git a/zokrates_core/src/optimizer/redefinition.rs b/zokrates_core/src/optimizer/redefinition.rs index bc7fc279c..c1f90aa5d 100644 --- a/zokrates_core/src/optimizer/redefinition.rs +++ b/zokrates_core/src/optimizer/redefinition.rs @@ -256,12 +256,14 @@ mod tests { Statement::definition(out, y), ], return_count: 1, + solvers: vec![], }; let optimized: Prog = Prog { arguments: vec![x], statements: vec![Statement::definition(out, x.id)], return_count: 1, + solvers: vec![], }; let mut optimizer = RedefinitionOptimizer::init(&p); @@ -280,6 +282,7 @@ mod tests { arguments: vec![x], statements: vec![Statement::definition(one, x.id)], return_count: 1, + solvers: vec![], }; let optimized = p.clone(); @@ -316,6 +319,7 @@ mod tests { Statement::definition(out, z), ], return_count: 1, + solvers: vec![], }; let optimized: Prog = Prog { @@ -325,6 +329,7 @@ mod tests { Statement::definition(out, x.id), ], return_count: 1, + solvers: vec![], }; let mut optimizer = RedefinitionOptimizer::init(&p); @@ -365,6 +370,7 @@ mod tests { Statement::definition(out_1, w), ], return_count: 2, + solvers: vec![], }; let optimized: Prog = Prog { @@ -374,6 +380,7 @@ mod tests { Statement::definition(out_1, Bn128Field::from(1)), ], return_count: 2, + solvers: vec![], }; let mut optimizer = RedefinitionOptimizer::init(&p); @@ -422,6 +429,7 @@ mod tests { Statement::definition(r, LinComb::from(a) + LinComb::from(b) + LinComb::from(c)), ], return_count: 1, + solvers: vec![], }; let expected: Prog = Prog { @@ -442,6 +450,7 @@ mod tests { ), ], return_count: 1, + solvers: vec![], }; let mut optimizer = RedefinitionOptimizer::init(&p); @@ -479,6 +488,7 @@ mod tests { Statement::definition(z, LinComb::from(x.id)), ], return_count: 0, + solvers: vec![], }; let optimized = p.clone(); @@ -507,6 +517,7 @@ mod tests { Statement::constraint(x.id, Bn128Field::from(2)), ], return_count: 1, + solvers: vec![], }; let optimized = p.clone(); diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index c8182b202..14fbade2f 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -443,6 +443,7 @@ mod tests { .iter() .map(|&i| Bn128Field::from(i)) .collect::>(), + &[], ) .unwrap(); let res: Vec = vec![0, 1].iter().map(|&i| Bn128Field::from(i)).collect(); @@ -459,6 +460,7 @@ mod tests { .iter() .map(|&i| Bn128Field::from(i)) .collect::>(), + &[], ) .unwrap(); let res: Vec = vec![1, 1].iter().map(|&i| Bn128Field::from(i)).collect(); @@ -469,9 +471,12 @@ mod tests { #[test] fn bits_of_one() { let inputs = vec![Bn128Field::from(1)]; - let res = - Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) - .unwrap(); + let res = Interpreter::execute_solver( + &Solver::Bits(Bn128Field::get_required_bits()), + &inputs, + &[], + ) + .unwrap(); assert_eq!(res[253], Bn128Field::from(1)); for r in &res[0..253] { assert_eq!(*r, Bn128Field::from(0)); @@ -481,9 +486,12 @@ mod tests { #[test] fn bits_of_42() { let inputs = vec![Bn128Field::from(42)]; - let res = - Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) - .unwrap(); + let res = Interpreter::execute_solver( + &Solver::Bits(Bn128Field::get_required_bits()), + &inputs, + &[], + ) + .unwrap(); assert_eq!(res[253], Bn128Field::from(0)); assert_eq!(res[252], Bn128Field::from(1)); assert_eq!(res[251], Bn128Field::from(0)); @@ -496,7 +504,7 @@ mod tests { #[test] fn five_hundred_bits_of_1() { let inputs = vec![Bn128Field::from(1)]; - let res = Interpreter::execute_solver(&Solver::Bits(500), &inputs).unwrap(); + let res = Interpreter::execute_solver(&Solver::Bits(500), &inputs, &[]).unwrap(); let mut expected = vec![Bn128Field::from(0); 500]; expected[499] = Bn128Field::from(1); diff --git a/zokrates_test/tests/wasm.rs b/zokrates_test/tests/wasm.rs index 8b81f3bf0..659f5a237 100644 --- a/zokrates_test/tests/wasm.rs +++ b/zokrates_test/tests/wasm.rs @@ -20,6 +20,7 @@ fn generate_proof() { arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, statements: vec![Statement::constraint(Variable::new(0), Variable::new(0))], + solvers: vec![], }; let interpreter = Interpreter::default(); From 779bb2bc203ac955028743086dc472616aa51bc6 Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 28 Feb 2023 02:06:01 +0100 Subject: [PATCH 06/32] add changelog --- changelogs/unreleased/1268-dark64 | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelogs/unreleased/1268-dark64 diff --git a/changelogs/unreleased/1268-dark64 b/changelogs/unreleased/1268-dark64 new file mode 100644 index 000000000..8b060b168 --- /dev/null +++ b/changelogs/unreleased/1268-dark64 @@ -0,0 +1 @@ +Optimize assembly solver \ No newline at end of file From 3ce57d93a4bd70456bc3b83821a99a55dc4583e3 Mon Sep 17 00:00:00 2001 From: dark64 Date: Fri, 3 Mar 2023 18:19:24 +0100 Subject: [PATCH 07/32] optimize evaluation of lincomb --- zokrates_ast/src/ir/expression.rs | 7 ++--- zokrates_field/src/dummy_curve.rs | 4 ++- zokrates_field/src/lib.rs | 3 +- zokrates_interpreter/src/lib.rs | 52 ++++++++++++++----------------- 4 files changed, 31 insertions(+), 35 deletions(-) diff --git a/zokrates_ast/src/ir/expression.rs b/zokrates_ast/src/ir/expression.rs index a32a1293c..28ac7993e 100644 --- a/zokrates_ast/src/ir/expression.rs +++ b/zokrates_ast/src/ir/expression.rs @@ -121,9 +121,8 @@ impl LinComb { } pub fn is_assignee(&self, witness: &Witness) -> bool { - self.0.len() == 1 - && self.0.get(0).unwrap().1 == T::from(1) - && !witness.0.contains_key(&self.0.get(0).unwrap().0) + let (var, val) = self.0.get(0).unwrap(); + self.0.len() == 1 && val == &T::from(1) && !witness.0.contains_key(var) } pub fn try_summand(self) -> Result<(Variable, T), Self> { @@ -258,7 +257,7 @@ impl Mul<&T> for LinComb { LinComb( self.0 .into_iter() - .map(|(var, coeff)| (var, coeff * scalar.clone())) + .map(|(var, coeff)| (var, coeff * scalar)) .collect(), ) } diff --git a/zokrates_field/src/dummy_curve.rs b/zokrates_field/src/dummy_curve.rs index 5d3aed4a3..b13265923 100644 --- a/zokrates_field/src/dummy_curve.rs +++ b/zokrates_field/src/dummy_curve.rs @@ -10,7 +10,9 @@ use std::ops::{Add, Div, Mul, Sub}; const _PRIME: u8 = 7; -#[derive(Default, Debug, Hash, Clone, PartialOrd, Ord, Serialize, Deserialize, PartialEq, Eq)] +#[derive( + Default, Debug, Hash, Clone, Copy, PartialOrd, Ord, Serialize, Deserialize, PartialEq, Eq, +)] pub struct FieldPrime { v: u8, } diff --git a/zokrates_field/src/lib.rs b/zokrates_field/src/lib.rs index 38f76905b..81e9915fe 100644 --- a/zokrates_field/src/lib.rs +++ b/zokrates_field/src/lib.rs @@ -70,6 +70,7 @@ pub trait Field: + Zero + One + Clone + + Copy + PartialEq + Eq + Hash @@ -148,7 +149,7 @@ mod prime_field { type Fr = <$v as ark_ec::PairingEngine>::Fr; - #[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash)] + #[derive(PartialEq, PartialOrd, Clone, Copy, Eq, Ord, Hash)] pub struct FieldPrime { v: Fr, } diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index 14fbade2f..275624fe5 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -88,7 +88,7 @@ impl Interpreter { .map_err(Error::Solver)?; for (i, o) in d.outputs.iter().enumerate() { - witness.insert(*o, res[i].clone()); + witness.insert(*o, res[i]); } } Statement::Log(l, expressions) => { @@ -264,41 +264,38 @@ impl Interpreter { .collect() } Solver::Xor => { - let x = inputs[0].clone(); - let y = inputs[1].clone(); + let x = inputs[0]; + let y = inputs[1]; - vec![x.clone() + y.clone() - T::from(2) * x * y] + vec![x + y - T::from(2) * x * y] } Solver::Or => { - let x = inputs[0].clone(); - let y = inputs[1].clone(); + let x = inputs[0]; + let y = inputs[1]; - vec![x.clone() + y.clone() - x * y] + vec![x + y - x * y] } // res = b * c - (2b * c - b - c) * (a) Solver::ShaAndXorAndXorAnd => { - let a = inputs[0].clone(); - let b = inputs[1].clone(); - let c = inputs[2].clone(); - vec![b.clone() * c.clone() - (T::from(2) * b.clone() * c.clone() - b - c) * a] + let a = inputs[0]; + let b = inputs[1]; + let c = inputs[2]; + vec![b * c - (T::from(2) * b * c - b - c) * a] } // res = a(b - c) + c Solver::ShaCh => { - let a = inputs[0].clone(); - let b = inputs[1].clone(); - let c = inputs[2].clone(); - vec![a * (b - c.clone()) + c] + let a = inputs[0]; + let b = inputs[1]; + let c = inputs[2]; + vec![a * (b - c) + c] } - Solver::Div => vec![inputs[0] - .clone() - .checked_div(&inputs[1]) - .unwrap_or_else(T::one)], + Solver::Div => vec![inputs[0].checked_div(&inputs[1]).unwrap_or_else(T::one)], Solver::EuclideanDiv => { use num::CheckedDiv; - let n = inputs[0].clone().to_biguint(); - let d = inputs[1].clone().to_biguint(); + let n = inputs[0].to_biguint(); + let d = inputs[1].to_biguint(); let q = n.checked_div(&d).unwrap_or_else(|| 0u32.into()); let r = n - d * &q; @@ -363,14 +360,11 @@ pub enum Error { } fn evaluate_lin(w: &Witness, l: &LinComb) -> Result { - l.0.iter() - .map(|(var, mult)| { - w.0.get(var) - .map(|v| v.clone() * mult) - .ok_or(EvaluationError) - }) // get each term - .collect::, _>>() // fail if any term isn't found - .map(|v| v.iter().fold(T::from(0), |acc, t| acc + t)) // return the sum + l.0.iter().try_fold(T::from(0), |acc, (var, mult)| { + w.0.get(var) + .map(|v| acc + (*v * mult)) + .ok_or(EvaluationError) // fail if any term isn't found + }) } pub fn evaluate_quad(w: &Witness, q: &QuadComb) -> Result { From e8abfb51ea3b12a8196ddfabb6fed202c8437401 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 6 Mar 2023 13:37:12 +0400 Subject: [PATCH 08/32] clippy --- zokrates_analysis/src/flat_propagation.rs | 2 +- zokrates_analysis/src/propagation.rs | 2 +- zokrates_analysis/src/uint_optimizer.rs | 6 +++--- zokrates_ark/src/lib.rs | 2 +- zokrates_ast/src/flat/utils.rs | 2 +- zokrates_ast/src/ir/expression.rs | 4 ++-- zokrates_ast/src/ir/mod.rs | 2 +- zokrates_ast/src/zir/lqc.rs | 16 ++++++++-------- zokrates_bellman/src/lib.rs | 2 +- zokrates_codegen/src/lib.rs | 2 +- zokrates_interpreter/src/lib.rs | 4 ++-- 11 files changed, 22 insertions(+), 22 deletions(-) diff --git a/zokrates_analysis/src/flat_propagation.rs b/zokrates_analysis/src/flat_propagation.rs index 155d803c8..eb2ca797a 100644 --- a/zokrates_analysis/src/flat_propagation.rs +++ b/zokrates_analysis/src/flat_propagation.rs @@ -32,7 +32,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator { match e { FlatExpression::Number(n) => FlatExpression::Number(n), FlatExpression::Identifier(id) => match self.constants.get(&id) { - Some(c) => FlatExpression::Number(c.clone()), + Some(c) => FlatExpression::Number(*c), None => FlatExpression::Identifier(id), }, FlatExpression::Add(box e1, box e2) => { diff --git a/zokrates_analysis/src/propagation.rs b/zokrates_analysis/src/propagation.rs index b7e5c0a17..708ebabbc 100644 --- a/zokrates_analysis/src/propagation.rs +++ b/zokrates_analysis/src/propagation.rs @@ -517,7 +517,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { .unwrap() { FieldElementExpression::Number(num) => { - let mut acc = num.clone(); + let mut acc = num; let mut res = vec![]; for i in (0..bit_width as usize).rev() { diff --git a/zokrates_analysis/src/uint_optimizer.rs b/zokrates_analysis/src/uint_optimizer.rs index ac96dbfe1..6a6b217ec 100644 --- a/zokrates_analysis/src/uint_optimizer.rs +++ b/zokrates_analysis/src/uint_optimizer.rs @@ -170,7 +170,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { left_max .checked_add(&range_max.clone()) .map(|max| (false, true, max)) - .unwrap_or_else(|| (true, true, range_max.clone() + range_max)) + .unwrap_or_else(|| (true, true, range_max + range_max)) }) }); @@ -223,7 +223,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { left_max .checked_add(&target_offset) .map(|max| (false, true, max)) - .unwrap_or_else(|| (true, true, range_max.clone() + target_offset)) + .unwrap_or_else(|| (true, true, range_max + target_offset)) } else { left_max .checked_add(&offset) @@ -294,7 +294,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { left_max .checked_mul(&range_max.clone()) .map(|max| (false, true, max)) - .unwrap_or_else(|| (true, true, range_max.clone() * range_max)) + .unwrap_or_else(|| (true, true, range_max * range_max)) }) }); diff --git a/zokrates_ark/src/lib.rs b/zokrates_ark/src/lib.rs index 425be3a8d..322785c32 100644 --- a/zokrates_ark/src/lib.rs +++ b/zokrates_ark/src/lib.rs @@ -150,7 +150,7 @@ impl<'a, T: Field + ArkFieldExtensions, I: IntoIterator> self.program .public_inputs_values(self.witness.as_ref().unwrap()) .iter() - .map(|v| v.clone().into_ark()) + .map(|v| v.into_ark()) .collect() } } diff --git a/zokrates_ast/src/flat/utils.rs b/zokrates_ast/src/flat/utils.rs index 03239687c..c37d2f68c 100644 --- a/zokrates_ast/src/flat/utils.rs +++ b/zokrates_ast/src/flat/utils.rs @@ -36,7 +36,7 @@ pub fn flat_expression_from_variable_summands(v: &[(T, usize)]) -> Fla match v.len() { 0 => FlatExpression::Number(T::zero()), 1 => { - let (val, var) = v[0].clone(); + let (val, var) = v[0]; FlatExpression::Mult( box FlatExpression::Number(val), box FlatExpression::Identifier(Variable::new(var)), diff --git a/zokrates_ast/src/ir/expression.rs b/zokrates_ast/src/ir/expression.rs index 28ac7993e..ccebfa9a7 100644 --- a/zokrates_ast/src/ir/expression.rs +++ b/zokrates_ast/src/ir/expression.rs @@ -164,8 +164,8 @@ impl LinComb { match acc.entry(val) { Entry::Occupied(o) => { // if the new value is non zero, update, else remove the term entirely - if o.get().clone() + coeff.clone() != T::zero() { - *o.into_mut() = o.get().clone() + coeff; + if *o.get() + coeff != T::zero() { + *o.into_mut() = *o.get() + coeff; } else { o.remove(); } diff --git a/zokrates_ast/src/ir/mod.rs b/zokrates_ast/src/ir/mod.rs index b401b757a..66535d3c1 100644 --- a/zokrates_ast/src/ir/mod.rs +++ b/zokrates_ast/src/ir/mod.rs @@ -181,7 +181,7 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a self.arguments .iter() .filter(|p| !p.private) - .map(|p| witness.0.get(&p.id).unwrap().clone()) + .map(|p| *witness.0.get(&p.id).unwrap()) .chain(witness.return_values()) .collect() } diff --git a/zokrates_ast/src/zir/lqc.rs b/zokrates_ast/src/zir/lqc.rs index 121b2a396..834f601f4 100644 --- a/zokrates_ast/src/zir/lqc.rs +++ b/zokrates_ast/src/zir/lqc.rs @@ -43,7 +43,7 @@ impl<'ast, T: Field> std::ops::Sub for LinQuadComb<'ast, T> { linear: { let mut l = self.linear; other.linear.iter_mut().for_each(|(c, _)| { - *c = T::zero() - &*c; + *c = T::zero() - *c; }); l.append(&mut other.linear); l @@ -51,7 +51,7 @@ impl<'ast, T: Field> std::ops::Sub for LinQuadComb<'ast, T> { quadratic: { let mut q = self.quadratic; other.quadratic.iter_mut().for_each(|(c, _, _)| { - *c = T::zero() - &*c; + *c = T::zero() - *c; }); q.append(&mut other.quadratic); q @@ -68,18 +68,18 @@ impl<'ast, T: Field> LinQuadComb<'ast, T> { } Ok(Self { - constant: self.constant.clone() * rhs.constant.clone(), + constant: self.constant * rhs.constant, linear: { // lin0 * const1 + lin1 * const0 self.linear .clone() .into_iter() - .map(|(c, i)| (c * rhs.constant.clone(), i)) + .map(|(c, i)| (c * rhs.constant, i)) .chain( rhs.linear .clone() .into_iter() - .map(|(c, i)| (c * self.constant.clone(), i)), + .map(|(c, i)| (c * self.constant, i)), ) .collect() }, @@ -87,16 +87,16 @@ impl<'ast, T: Field> LinQuadComb<'ast, T> { // quad0 * const1 + quad1 * const0 + lin0 * lin1 self.quadratic .into_iter() - .map(|(c, i0, i1)| (c * rhs.constant.clone(), i0, i1)) + .map(|(c, i0, i1)| (c * rhs.constant, i0, i1)) .chain( rhs.quadratic .into_iter() - .map(|(c, i0, i1)| (c * self.constant.clone(), i0, i1)), + .map(|(c, i0, i1)| (c * self.constant, i0, i1)), ) .chain(self.linear.iter().flat_map(|(cl, l)| { rhs.linear .iter() - .map(|(cr, r)| (cl.clone() * cr.clone(), l.clone(), r.clone())) + .map(|(cr, r)| (*cl * *cr, l.clone(), r.clone())) })) .collect() }, diff --git a/zokrates_bellman/src/lib.rs b/zokrates_bellman/src/lib.rs index 821e3ea2d..263cfc211 100644 --- a/zokrates_bellman/src/lib.rs +++ b/zokrates_bellman/src/lib.rs @@ -189,7 +189,7 @@ impl<'a, T: BellmanFieldExtensions + Field, I: IntoIterator Flattener<'ast, T> { // constants do not require directives if let Some(FlatExpression::Number(ref x)) = e.field { - let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()], &[]) + let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[*x], &[]) .unwrap() .into_iter() .map(FlatExpression::Number) diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index 275624fe5..8a7f88541 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -50,7 +50,7 @@ impl Interpreter { witness.insert(Variable::one(), T::one()); for (arg, value) in program.arguments.iter().zip(inputs.iter()) { - witness.insert(arg.id, value.clone()); + witness.insert(arg.id, *value); } for statement in program.statements.into_iter() { @@ -188,7 +188,7 @@ impl Interpreter { .map(|(a, v)| match &a.id._type { zir::Type::FieldElement => Ok(( a.id.id.clone(), - zokrates_ast::zir::FieldElementExpression::Number(v.clone()).into(), + zokrates_ast::zir::FieldElementExpression::Number(*v).into(), )), zir::Type::Boolean => match v { v if *v == T::from(0) => Ok(( From cfe43e672dd05f8953c3772cd65cd9ced275b200 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 6 Mar 2023 13:53:03 +0400 Subject: [PATCH 09/32] revert is_assignee --- zokrates_ast/src/ir/expression.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/zokrates_ast/src/ir/expression.rs b/zokrates_ast/src/ir/expression.rs index ccebfa9a7..33e3856e7 100644 --- a/zokrates_ast/src/ir/expression.rs +++ b/zokrates_ast/src/ir/expression.rs @@ -121,8 +121,9 @@ impl LinComb { } pub fn is_assignee(&self, witness: &Witness) -> bool { - let (var, val) = self.0.get(0).unwrap(); - self.0.len() == 1 && val == &T::from(1) && !witness.0.contains_key(var) + self.0.len() == 1 + && self.0.get(0).unwrap().1 == T::from(1) + && !witness.0.contains_key(&self.0.get(0).unwrap().0) } pub fn try_summand(self) -> Result<(Variable, T), Self> { From 0f31a1b42ea03c419a7b5600d3ee3a05b7db6fbd Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 14 Mar 2023 19:20:02 +0100 Subject: [PATCH 10/32] apply suggestions (part 1) --- changelogs/unreleased/1268-dark64 | 2 +- zokrates_ast/src/common/mod.rs | 2 +- zokrates_ast/src/common/solvers.rs | 12 +++++++++--- zokrates_ast/src/ir/solver_indexer.rs | 10 +++++++--- zokrates_ast/src/zir/identifier.rs | 12 ++++++------ zokrates_ast/src/zir/substitution.rs | 19 ++++++++----------- zokrates_codegen/src/lib.rs | 18 +++--------------- zokrates_interpreter/src/lib.rs | 6 +++--- 8 files changed, 38 insertions(+), 43 deletions(-) diff --git a/changelogs/unreleased/1268-dark64 b/changelogs/unreleased/1268-dark64 index 8b060b168..7d2f33daa 100644 --- a/changelogs/unreleased/1268-dark64 +++ b/changelogs/unreleased/1268-dark64 @@ -1 +1 @@ -Optimize assembly solver \ No newline at end of file +Reduce compiled program size by deduplicating assembly solvers \ No newline at end of file diff --git a/zokrates_ast/src/common/mod.rs b/zokrates_ast/src/common/mod.rs index 13d23bfdb..c8ad09440 100644 --- a/zokrates_ast/src/common/mod.rs +++ b/zokrates_ast/src/common/mod.rs @@ -10,6 +10,6 @@ pub use self::embed::FlatEmbed; pub use self::error::RuntimeError; pub use self::metadata::SourceMetadata; pub use self::parameter::Parameter; -pub use self::solvers::Solver; +pub use self::solvers::{RefCall, Solver}; pub use self::variable::Variable; pub use format_string::FormatString; diff --git a/zokrates_ast/src/common/solvers.rs b/zokrates_ast/src/common/solvers.rs index 0551bb6d3..9c6e9bbcc 100644 --- a/zokrates_ast/src/common/solvers.rs +++ b/zokrates_ast/src/common/solvers.rs @@ -2,6 +2,12 @@ use crate::zir::ZirFunction; use serde::{Deserialize, Serialize}; use std::fmt; +#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] +pub struct RefCall { + pub index: usize, + pub argument_count: usize, +} + #[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] pub enum Solver<'ast, T> { ConditionEq, @@ -14,7 +20,7 @@ pub enum Solver<'ast, T> { EuclideanDiv, #[serde(borrow)] Zir(ZirFunction<'ast, T>), - IndexedCall(usize, usize), + Ref(RefCall), #[cfg(feature = "bellman")] Sha256Round, #[cfg(feature = "ark")] @@ -33,7 +39,7 @@ impl<'ast, T> fmt::Display for Solver<'ast, T> { Solver::ShaCh => write!(f, "ShaCh"), Solver::EuclideanDiv => write!(f, "EuclideanDiv"), Solver::Zir(_) => write!(f, "Zir(..)"), - Solver::IndexedCall(index, argc) => write!(f, "IndexedCall@{}({})", index, argc), + Solver::Ref(call) => write!(f, "Ref@{}({})", call.index, call.argument_count), #[cfg(feature = "bellman")] Solver::Sha256Round => write!(f, "Sha256Round"), #[cfg(feature = "ark")] @@ -54,7 +60,7 @@ impl<'ast, T> Solver<'ast, T> { Solver::ShaCh => (3, 1), Solver::EuclideanDiv => (2, 2), Solver::Zir(f) => (f.arguments.len(), 1), - Solver::IndexedCall(_, n) => (*n, 1), + Solver::Ref(c) => (c.argument_count, 1), #[cfg(feature = "bellman")] Solver::Sha256Round => (768, 26935), #[cfg(feature = "ark")] diff --git a/zokrates_ast/src/ir/solver_indexer.rs b/zokrates_ast/src/ir/solver_indexer.rs index 59393d25b..cbe1a1511 100644 --- a/zokrates_ast/src/ir/solver_indexer.rs +++ b/zokrates_ast/src/ir/solver_indexer.rs @@ -1,3 +1,4 @@ +use crate::common::RefCall; use crate::ir::folder::Folder; use crate::ir::Directive; use crate::ir::Solver; @@ -20,14 +21,14 @@ fn hash(f: &ZirFunction) -> Hash { #[derive(Debug, Default)] pub struct SolverIndexer<'ast, T> { pub solvers: Vec>, - pub index_map: HashMap, + pub index_map: HashMap, } impl<'ast, T: Field> Folder<'ast, T> for SolverIndexer<'ast, T> { fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> { match d.solver { Solver::Zir(f) => { - let argc = f.arguments.len(); + let argument_count = f.arguments.len(); let h = hash(&f); let index = match self.index_map.entry(h) { Entry::Occupied(v) => *v.get(), @@ -41,7 +42,10 @@ impl<'ast, T: Field> Folder<'ast, T> for SolverIndexer<'ast, T> { Directive { inputs: d.inputs, outputs: d.outputs, - solver: Solver::IndexedCall(index, argc), + solver: Solver::Ref(RefCall { + index, + argument_count, + }), } } _ => d, diff --git a/zokrates_ast/src/zir/identifier.rs b/zokrates_ast/src/zir/identifier.rs index bc60f5666..b036e42a7 100644 --- a/zokrates_ast/src/zir/identifier.rs +++ b/zokrates_ast/src/zir/identifier.rs @@ -31,6 +31,12 @@ impl<'ast> fmt::Display for SourceIdentifier<'ast> { } } +impl<'ast> Identifier<'ast> { + pub fn internal>(name: S) -> Self { + Identifier::Internal(name.into()) + } +} + impl<'ast> fmt::Display for Identifier<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -40,12 +46,6 @@ impl<'ast> fmt::Display for Identifier<'ast> { } } -impl<'ast> From for Identifier<'ast> { - fn from(id: String) -> Identifier<'ast> { - Identifier::Internal(id) - } -} - // this is only used in tests but somehow cfg(test) does not work impl<'ast> From<&'ast str> for Identifier<'ast> { fn from(id: &'ast str) -> Identifier<'ast> { diff --git a/zokrates_ast/src/zir/substitution.rs b/zokrates_ast/src/zir/substitution.rs index 89a8f82db..d559ccffe 100644 --- a/zokrates_ast/src/zir/substitution.rs +++ b/zokrates_ast/src/zir/substitution.rs @@ -1,22 +1,19 @@ use super::{Folder, Identifier}; -use std::{collections::HashMap, marker::PhantomData}; +use std::collections::HashMap; use zokrates_field::Field; -pub struct ZirSubstitutor<'a, 'ast, T> { - substitution: &'a HashMap, Identifier<'ast>>, - _phantom: PhantomData, +#[derive(Default)] +pub struct ZirSubstitutor<'ast> { + substitution: HashMap, Identifier<'ast>>, } -impl<'a, 'ast, T: Field> ZirSubstitutor<'a, 'ast, T> { - pub fn new(substitution: &'a HashMap, Identifier<'ast>>) -> Self { - ZirSubstitutor { - substitution, - _phantom: PhantomData::default(), - } +impl<'ast> ZirSubstitutor<'ast> { + pub fn new(substitution: HashMap, Identifier<'ast>>) -> Self { + Self { substitution } } } -impl<'a, 'ast, T: Field> Folder<'ast, T> for ZirSubstitutor<'a, 'ast, T> { +impl<'ast, T: Field> Folder<'ast, T> for ZirSubstitutor<'ast> { fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { match self.substitution.get(&n) { Some(v) => v.clone(), diff --git a/zokrates_codegen/src/lib.rs b/zokrates_codegen/src/lib.rs index 77a3944fb..57aa7fcd6 100644 --- a/zokrates_codegen/src/lib.rs +++ b/zokrates_codegen/src/lib.rs @@ -2245,24 +2245,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { let mut substitution_map = HashMap::default(); for (index, p) in function.arguments.iter().enumerate() { - let new_id = format!("i{}", index).into(); + let new_id = Identifier::internal(format!("i{}", index)); substitution_map.insert(p.id.id.clone(), new_id); } - let mut substitutor = ZirSubstitutor::new(&substitution_map); - let function = ZirFunction { - arguments: function - .arguments - .into_iter() - .map(|p| substitutor.fold_parameter(p)) - .collect(), - statements: function - .statements - .into_iter() - .flat_map(|s| substitutor.fold_statement(s)) - .collect(), - signature: function.signature, - }; + let mut substitutor = ZirSubstitutor::new(substitution_map); + let function = substitutor.fold_function(function); let solver = Solver::Zir(function); let directive = FlatDirective::new(outputs, solver, inputs); diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index 8a7f88541..19a3ef399 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -167,9 +167,9 @@ impl Interpreter { solvers: &[Solver<'ast, T>], ) -> Result, String> { let solver = match solver { - Solver::IndexedCall(index, _) => solvers - .get(*index) - .ok_or_else(|| format!("Could not resolve solver at index {}", index))?, + Solver::Ref(call) => solvers + .get(call.index) + .ok_or_else(|| format!("Could not get solver at index {}", call.index))?, s => s, }; From f84faaae834a9276ad820ed7e9d8561c61ade04d Mon Sep 17 00:00:00 2001 From: Ahmed Castro Date: Sat, 18 Mar 2023 20:09:28 -0600 Subject: [PATCH 11/32] Sudoku example fix **Problem** the logic behind `countDuplicates` is invalid because it's doing this: (incorrect) ``` if(duplicates + e11 == e21) duplicates = 1; else duplicates = 0; ``` Instead of this: (correct) ``` if(e11 == e21) duplicates = duplicates + 1; else duplicates = duplicates; ``` **Solution** I'm using an auxiliary variable to fix this **Alternative solution** We can also fix it with an array and a loop. But I went for the auxiliary variable for no particular reason. ``` def countDuplicates(field e11, field e12, field e21, field e22) -> u32 { u32[6] mut duplicates = [0,0,0,0,0,0]; duplicates[0] = e11 == e12 ? 1 : 0; duplicates[1] = e11 == e21 ? 1 : 0; duplicates[2] = e11 == e22 ? 1 : 0; duplicates[3] = e12 == e21 ? 1 : 0; duplicates[4] = e12 == e22 ? 1 : 0; duplicates[5] = e21 == e22 ? 1 : 0; u32 mut count = 0; for u32 i in 0..5 { count = count + duplicates[i]; } return count; } ``` And btw, I also added an assert to make sure we check for this validation. I've tested this on the playground and on chain on goerli and with this changes it works correctly. --- .../examples/sudoku/sudoku_checker.zok | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/zokrates_cli/examples/sudoku/sudoku_checker.zok b/zokrates_cli/examples/sudoku/sudoku_checker.zok index d6b596f35..2272a42e9 100644 --- a/zokrates_cli/examples/sudoku/sudoku_checker.zok +++ b/zokrates_cli/examples/sudoku/sudoku_checker.zok @@ -11,13 +11,20 @@ // We use a naive encoding of the values as `[1, 2, 3, 4]` and rely on if-else statements to detect duplicates def countDuplicates(field e11, field e12, field e21, field e22) -> field { - field mut duplicates = e11 == e12 ? 1 : 0; - duplicates = duplicates + e11 == e21 ? 1 : 0; - duplicates = duplicates + e11 == e22 ? 1 : 0; - duplicates = duplicates + e12 == e21 ? 1 : 0; - duplicates = duplicates + e12 == e22 ? 1 : 0; - duplicates = duplicates + e21 == e22 ? 1 : 0; - return duplicates; + field mut duplicates = 0; + field mut duplicatesAux = e11 == e12 ? 1 : 0; + duplicates = duplicates + duplicatesAux; + duplicatesAux = e11 == e21 ? 1 : 0; + duplicates = duplicates + duplicatesAux; + duplicatesAux = e11 == e22 ? 1 : 0; + duplicates = duplicates + duplicatesAux; + duplicatesAux = e12 == e21 ? 1 : 0; + duplicates = duplicates + duplicatesAux; + duplicatesAux = e12 == e22 ? 1 : 0; + duplicates = duplicates + duplicatesAux; + duplicatesAux = e21 == e22 ? 1 : 0; + + return countDuplicates; } // returns 0 for x in (1..4) @@ -74,5 +81,6 @@ def main(field a21, field b11, field b22, field c11, field c22, field d21, priva duplicates = duplicates + countDuplicates(b12, b22, d12, d22); // the solution is correct if and only if there are no duplicates + assert(duplicates == 0); return duplicates == 0; } From 475744bf6e0469b714d08f49dd42119fef38a4d9 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 20 Mar 2023 20:21:23 +0100 Subject: [PATCH 12/32] refactor --- zokrates_ast/src/ir/mod.rs | 2 +- zokrates_ast/src/ir/serialize.rs | 168 ++++++++++++------------------- 2 files changed, 67 insertions(+), 103 deletions(-) diff --git a/zokrates_ast/src/ir/mod.rs b/zokrates_ast/src/ir/mod.rs index 66535d3c1..7de47396e 100644 --- a/zokrates_ast/src/ir/mod.rs +++ b/zokrates_ast/src/ir/mod.rs @@ -20,7 +20,7 @@ mod witness; pub use self::expression::QuadComb; pub use self::expression::{CanonicalLinComb, LinComb}; -pub use self::serialize::ProgEnum; +pub use self::serialize::{ProgEnum, ProgHeader}; pub use crate::common::Parameter; pub use crate::common::RuntimeError; pub use crate::common::Solver; diff --git a/zokrates_ast/src/ir/serialize.rs b/zokrates_ast/src/ir/serialize.rs index 6266662d4..37b28f7df 100644 --- a/zokrates_ast/src/ir/serialize.rs +++ b/zokrates_ast/src/ir/serialize.rs @@ -1,7 +1,4 @@ -use crate::{ - ir::{check::UnconstrainedVariableDetector, solver_indexer::SolverIndexer}, - Solver, -}; +use crate::ir::{check::UnconstrainedVariableDetector, solver_indexer::SolverIndexer}; use super::{ProgIterator, Statement}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; @@ -65,7 +62,7 @@ impl< #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum SectionType { - Arguments = 1, + Parameters = 1, Constraints = 2, Solvers = 3, } @@ -75,7 +72,7 @@ impl TryFrom for SectionType { fn try_from(value: u32) -> Result { match value { - 1 => Ok(SectionType::Arguments), + 1 => Ok(SectionType::Parameters), 2 => Ok(SectionType::Constraints), 3 => Ok(SectionType::Solvers), _ => Err("invalid section type".to_string()), @@ -198,21 +195,23 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a w.write_all(&[0u8; HEADER_SIZE])?; // reserve bytes for the header - let arguments = { - let mut section = Section::new(SectionType::Arguments); + // write parameters + if self.arguments.len() > 0 { + let mut section = Section::new(SectionType::Parameters); section.set_offset(w.stream_position()?); serde_cbor::to_writer(&mut w, &self.arguments)?; section.set_length(w.stream_position()? - section.offset); - section - }; + header.sections.push(section); + } let mut solver_indexer: SolverIndexer<'ast, T> = SolverIndexer::default(); let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self); let mut count: usize = 0; - let constraints = { + // write constraints + { let mut section = Section::new(SectionType::Constraints); section.set_offset(w.stream_position()?); @@ -232,24 +231,23 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a } section.set_length(w.stream_position()? - section.offset); - section + header.sections.push(section); }; - let solvers = { + // write solvers + if solver_indexer.solvers.len() > 0 { let mut section = Section::new(SectionType::Solvers); section.set_offset(w.stream_position()?); serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?; section.set_length(w.stream_position()? - section.offset); - section - }; + header.sections.push(section); + } header.constraint_count = count as u32; - header - .sections - .extend_from_slice(&[arguments, constraints, solvers]); + // rewind to write the header w.rewind()?; header.write(&mut w)?; @@ -283,6 +281,52 @@ impl<'de, R: Read + Seek> UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, Bw6_761Field>>, > { + fn read( + mut r: R, + header: &ProgHeader, + ) -> ProgIterator< + 'de, + T, + UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, T>>, + > { + let parameters = match header.get_section(SectionType::Parameters) { + Some(section) => { + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read parameters")) + .unwrap() + } + None => vec![], + }; + + let solvers = match header.get_section(SectionType::Solvers) { + Some(section) => { + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read solvers")) + .unwrap() + } + None => vec![], + }; + + let section = header.get_section(SectionType::Constraints).unwrap(); + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let p = serde_cbor::Deserializer::from_reader(r); + let s = p.into_iter::>(); + + ProgIterator::new( + parameters, + UnwrappedStreamDeserializer { s }, + header.return_count as usize, + solvers, + ) + } + pub fn deserialize(mut r: R) -> Result { let header = ProgHeader::read(&mut r).map_err(|_| String::from("Invalid header"))?; @@ -296,95 +340,15 @@ impl<'de, R: Read + Seek> return Err("Invalid file version".to_string()); } - let arguments = { - let section = header.get_section(SectionType::Arguments).unwrap(); - r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read parameters"))? - }; - - let solvers_section = header.get_section(SectionType::Solvers).unwrap(); - r.seek(std::io::SeekFrom::Start(solvers_section.offset)) - .unwrap(); - match header.curve_id { m if m == Bls12_381Field::id() => { - let solvers: Vec> = { - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))? - }; - - let section = header.get_section(SectionType::Constraints).unwrap(); - r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); - - Ok(ProgEnum::Bls12_381Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - header.return_count as usize, - solvers, - ))) - } - m if m == Bn128Field::id() => { - let solvers: Vec> = { - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))? - }; - - let section = header.get_section(SectionType::Constraints).unwrap(); - r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); - - Ok(ProgEnum::Bn128Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - header.return_count as usize, - solvers, - ))) + Ok(ProgEnum::Bls12_381Program(Self::read(r, &header))) } + m if m == Bn128Field::id() => Ok(ProgEnum::Bn128Program(Self::read(r, &header))), m if m == Bls12_377Field::id() => { - let solvers: Vec> = { - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))? - }; - - let section = header.get_section(SectionType::Constraints).unwrap(); - r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); - - Ok(ProgEnum::Bls12_377Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - header.return_count as usize, - solvers, - ))) - } - m if m == Bw6_761Field::id() => { - let solvers: Vec> = { - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p).map_err(|_| String::from("Cannot read solvers"))? - }; - - let section = header.get_section(SectionType::Constraints).unwrap(); - r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); - - Ok(ProgEnum::Bw6_761Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - header.return_count as usize, - solvers, - ))) + Ok(ProgEnum::Bls12_377Program(Self::read(r, &header))) } + m if m == Bw6_761Field::id() => Ok(ProgEnum::Bw6_761Program(Self::read(r, &header))), _ => Err(String::from("Unknown curve identifier")), } } From 87f356a7c62f118ee5000a6fd27e6e3c923274b5 Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 21 Mar 2023 00:09:37 +0100 Subject: [PATCH 13/32] add test --- zokrates_ast/src/ir/serialize.rs | 4 +-- zokrates_interpreter/src/lib.rs | 42 ++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/zokrates_ast/src/ir/serialize.rs b/zokrates_ast/src/ir/serialize.rs index 37b28f7df..2882171f4 100644 --- a/zokrates_ast/src/ir/serialize.rs +++ b/zokrates_ast/src/ir/serialize.rs @@ -196,7 +196,7 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a w.write_all(&[0u8; HEADER_SIZE])?; // reserve bytes for the header // write parameters - if self.arguments.len() > 0 { + if !self.arguments.is_empty() { let mut section = Section::new(SectionType::Parameters); section.set_offset(w.stream_position()?); @@ -235,7 +235,7 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a }; // write solvers - if solver_indexer.solvers.len() > 0 { + if !solver_indexer.solvers.is_empty() { let mut section = Section::new(SectionType::Solvers); section.set_offset(w.stream_position()?); diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index 19a3ef399..ccfff8184 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -505,4 +505,46 @@ mod tests { assert_eq!(res, expected); } + + #[test] + fn solver_ref() { + use zir::{ + types::{Signature, Type}, + FieldElementExpression, Identifier, IdentifierExpression, Parameter, Variable, + ZirFunction, ZirStatement, + }; + use zokrates_ast::common::RefCall; + + let id = IdentifierExpression::new(Identifier::internal("i0")); + + // (field i0) -> i0 * i0 + let solvers = vec![Solver::Zir(ZirFunction { + arguments: vec![Parameter { + id: Variable::with_id_and_type(id.id.clone(), Type::FieldElement), + private: true, + }], + statements: vec![ZirStatement::Return(vec![FieldElementExpression::Mult( + Box::new(FieldElementExpression::Identifier(id.clone())), + Box::new(FieldElementExpression::Identifier(id.clone())), + ) + .into()])], + signature: Signature::new() + .inputs(vec![Type::FieldElement]) + .outputs(vec![Type::FieldElement]), + })]; + + let inputs = vec![Bn128Field::from(2)]; + let res = Interpreter::execute_solver( + &Solver::Ref(RefCall { + index: 0, + argument_count: 1, + }), + &inputs, + &solvers, + ) + .unwrap(); + + let expected = vec![Bn128Field::from(4)]; + assert_eq!(res, expected); + } } From 95bec7be646ca57e2cfc5ae8a4ef46d4732ec531 Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 21 Mar 2023 13:16:07 +0100 Subject: [PATCH 14/32] fix zir substitutor --- zokrates_ast/src/ir/serialize.rs | 127 +++++++++++++-------------- zokrates_ast/src/zir/identifier.rs | 8 +- zokrates_ast/src/zir/substitution.rs | 28 +++--- zokrates_codegen/src/lib.rs | 12 +-- zokrates_interpreter/src/lib.rs | 2 +- 5 files changed, 85 insertions(+), 92 deletions(-) diff --git a/zokrates_ast/src/ir/serialize.rs b/zokrates_ast/src/ir/serialize.rs index 2882171f4..f32624299 100644 --- a/zokrates_ast/src/ir/serialize.rs +++ b/zokrates_ast/src/ir/serialize.rs @@ -61,6 +61,7 @@ impl< } #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[repr(u32)] pub enum SectionType { Parameters = 1, Constraints = 2, @@ -112,7 +113,7 @@ pub struct ProgHeader { pub curve_id: [u8; 4], pub constraint_count: u32, pub return_count: u32, - pub sections: Vec
, + pub sections: [Section; 3], } impl ProgHeader { @@ -123,7 +124,6 @@ impl ProgHeader { w.write_u32::(self.constraint_count)?; w.write_u32::(self.return_count)?; - w.write_u32::(self.sections.len() as u32)?; for s in &self.sections { w.write_u32::(s.ty as u32)?; w.write_u64::(s.offset)?; @@ -146,19 +146,9 @@ impl ProgHeader { let constraint_count = r.read_u32::()?; let return_count = r.read_u32::()?; - let section_count = r.read_u32::()?; - let mut sections = vec![]; - - for _ in 0..section_count { - let id = r.read_u32::()?; - let mut section = Section::new( - SectionType::try_from(id) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?, - ); - section.set_offset(r.read_u64::()?); - section.set_length(r.read_u64::()?); - sections.push(section); - } + let parameters = Self::read_section(r.by_ref())?; + let constraints = Self::read_section(r.by_ref())?; + let solvers = Self::read_section(r.by_ref())?; Ok(ProgHeader { magic, @@ -166,12 +156,19 @@ impl ProgHeader { curve_id, constraint_count, return_count, - sections, + sections: [parameters, constraints, solvers], }) } - fn get_section(&self, ty: SectionType) -> Option<&Section> { - self.sections.iter().find(|s| s.ty == ty) + fn read_section(mut r: R) -> std::io::Result
{ + let id = r.read_u32::()?; + let mut section = Section::new( + SectionType::try_from(id) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?, + ); + section.set_offset(r.read_u64::()?); + section.set_length(r.read_u64::()?); + Ok(section) } } @@ -181,37 +178,26 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a pub fn serialize(self, mut w: W) -> Result { use super::folder::Folder; - const SECTION_COUNT: usize = 3; - const HEADER_SIZE: usize = 24 + SECTION_COUNT * 20; + // reserve bytes for the header + w.write_all(&[0u8; std::mem::size_of::()])?; - let mut header = ProgHeader { - magic: *ZOKRATES_MAGIC, - version: *FILE_VERSION, - curve_id: T::id(), - constraint_count: 0, - return_count: self.return_count as u32, - sections: Vec::with_capacity(SECTION_COUNT), - }; - - w.write_all(&[0u8; HEADER_SIZE])?; // reserve bytes for the header - - // write parameters - if !self.arguments.is_empty() { + // write parameters section + let parameters = { let mut section = Section::new(SectionType::Parameters); section.set_offset(w.stream_position()?); serde_cbor::to_writer(&mut w, &self.arguments)?; section.set_length(w.stream_position()? - section.offset); - header.sections.push(section); - } + section + }; let mut solver_indexer: SolverIndexer<'ast, T> = SolverIndexer::default(); let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self); let mut count: usize = 0; - // write constraints - { + // write constraints section + let constraints = { let mut section = Section::new(SectionType::Constraints); section.set_offset(w.stream_position()?); @@ -231,21 +217,28 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a } section.set_length(w.stream_position()? - section.offset); - header.sections.push(section); + section }; - // write solvers - if !solver_indexer.solvers.is_empty() { + // write solvers section + let solvers = { let mut section = Section::new(SectionType::Solvers); section.set_offset(w.stream_position()?); serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?; section.set_length(w.stream_position()? - section.offset); - header.sections.push(section); - } + section + }; - header.constraint_count = count as u32; + let header = ProgHeader { + magic: *ZOKRATES_MAGIC, + version: *FILE_VERSION, + curve_id: T::id(), + constraint_count: count as u32, + return_count: self.return_count as u32, + sections: [parameters, constraints, solvers], + }; // rewind to write the header w.rewind()?; @@ -289,39 +282,39 @@ impl<'de, R: Read + Seek> T, UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, T>>, > { - let parameters = match header.get_section(SectionType::Parameters) { - Some(section) => { - r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read parameters")) - .unwrap() - } - None => vec![], + let parameters = { + let section = &header.sections[0]; + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read parameters")) + .unwrap() }; - let solvers = match header.get_section(SectionType::Solvers) { - Some(section) => { - r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + let solvers = { + let section = &header.sections[2]; + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); - Vec::deserialize(&mut p) - .map_err(|_| String::from("Cannot read solvers")) - .unwrap() - } - None => vec![], + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read solvers")) + .unwrap() }; - let section = header.get_section(SectionType::Constraints).unwrap(); - r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + let statements_deserializer = { + let section = &header.sections[1]; + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); - let p = serde_cbor::Deserializer::from_reader(r); - let s = p.into_iter::>(); + let p = serde_cbor::Deserializer::from_reader(r); + let s = p.into_iter::>(); + + UnwrappedStreamDeserializer { s } + }; ProgIterator::new( parameters, - UnwrappedStreamDeserializer { s }, + statements_deserializer, header.return_count as usize, solvers, ) diff --git a/zokrates_ast/src/zir/identifier.rs b/zokrates_ast/src/zir/identifier.rs index b036e42a7..f0c387df3 100644 --- a/zokrates_ast/src/zir/identifier.rs +++ b/zokrates_ast/src/zir/identifier.rs @@ -8,7 +8,7 @@ use crate::typed::Identifier as CoreIdentifier; pub enum Identifier<'ast> { #[serde(borrow)] Source(SourceIdentifier<'ast>), - Internal(String), + Internal(usize), } #[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] @@ -32,8 +32,8 @@ impl<'ast> fmt::Display for SourceIdentifier<'ast> { } impl<'ast> Identifier<'ast> { - pub fn internal>(name: S) -> Self { - Identifier::Internal(name.into()) + pub fn internal>(id: T) -> Self { + Identifier::Internal(id.into()) } } @@ -41,7 +41,7 @@ impl<'ast> fmt::Display for Identifier<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Identifier::Source(s) => write!(f, "{}", s), - Identifier::Internal(s) => write!(f, "{}", s), + Identifier::Internal(i) => write!(f, "i{}", i), } } } diff --git a/zokrates_ast/src/zir/substitution.rs b/zokrates_ast/src/zir/substitution.rs index d559ccffe..d5709c5d0 100644 --- a/zokrates_ast/src/zir/substitution.rs +++ b/zokrates_ast/src/zir/substitution.rs @@ -1,23 +1,31 @@ -use super::{Folder, Identifier}; +use super::{Folder, Identifier, Parameter, Variable, ZirAssignee}; use std::collections::HashMap; use zokrates_field::Field; #[derive(Default)] pub struct ZirSubstitutor<'ast> { - substitution: HashMap, Identifier<'ast>>, -} - -impl<'ast> ZirSubstitutor<'ast> { - pub fn new(substitution: HashMap, Identifier<'ast>>) -> Self { - Self { substitution } - } + substitution: HashMap, usize>, } impl<'ast, T: Field> Folder<'ast, T> for ZirSubstitutor<'ast> { + fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> { + let new_id = self.substitution.len(); + self.substitution.insert(p.id.id.clone(), new_id); + + Parameter { + id: Variable::with_id_and_type(Identifier::internal(new_id), p.id._type), + ..p + } + } + fn fold_assignee(&mut self, a: ZirAssignee<'ast>) -> ZirAssignee<'ast> { + let new_id = self.substitution.len(); + self.substitution.insert(a.id.clone(), new_id); + ZirAssignee::with_id_and_type(Identifier::internal(new_id), a._type) + } fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { match self.substitution.get(&n) { - Some(v) => v.clone(), - None => n, + Some(v) => Identifier::internal(*v), + None => unreachable!(), } } } diff --git a/zokrates_codegen/src/lib.rs b/zokrates_codegen/src/lib.rs index 57aa7fcd6..9ec1f4268 100644 --- a/zokrates_codegen/src/lib.rs +++ b/zokrates_codegen/src/lib.rs @@ -2243,18 +2243,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { .map(|assignee| self.use_variable(&assignee)) .collect(); - let mut substitution_map = HashMap::default(); - for (index, p) in function.arguments.iter().enumerate() { - let new_id = Identifier::internal(format!("i{}", index)); - substitution_map.insert(p.id.id.clone(), new_id); - } - - let mut substitutor = ZirSubstitutor::new(substitution_map); + let mut substitutor = ZirSubstitutor::default(); let function = substitutor.fold_function(function); - let solver = Solver::Zir(function); - let directive = FlatDirective::new(outputs, solver, inputs); - + let directive = FlatDirective::new(outputs, Solver::Zir(function), inputs); statements_flattened.push_back(FlatStatement::Directive(directive)); } ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => { diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index ccfff8184..73709addd 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -515,7 +515,7 @@ mod tests { }; use zokrates_ast::common::RefCall; - let id = IdentifierExpression::new(Identifier::internal("i0")); + let id = IdentifierExpression::new(Identifier::internal(0usize)); // (field i0) -> i0 * i0 let solvers = vec![Solver::Zir(ZirFunction { From 2e551b3d161c4f4cc81ee99472ad3a2d2dcad6f2 Mon Sep 17 00:00:00 2001 From: Ahmed Castro Date: Tue, 21 Mar 2023 11:53:10 -0600 Subject: [PATCH 15/32] Update sudoku_checker.zok --- .../examples/sudoku/sudoku_checker.zok | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/zokrates_cli/examples/sudoku/sudoku_checker.zok b/zokrates_cli/examples/sudoku/sudoku_checker.zok index 2272a42e9..7fc6e2252 100644 --- a/zokrates_cli/examples/sudoku/sudoku_checker.zok +++ b/zokrates_cli/examples/sudoku/sudoku_checker.zok @@ -11,20 +11,13 @@ // We use a naive encoding of the values as `[1, 2, 3, 4]` and rely on if-else statements to detect duplicates def countDuplicates(field e11, field e12, field e21, field e22) -> field { - field mut duplicates = 0; - field mut duplicatesAux = e11 == e12 ? 1 : 0; - duplicates = duplicates + duplicatesAux; - duplicatesAux = e11 == e21 ? 1 : 0; - duplicates = duplicates + duplicatesAux; - duplicatesAux = e11 == e22 ? 1 : 0; - duplicates = duplicates + duplicatesAux; - duplicatesAux = e12 == e21 ? 1 : 0; - duplicates = duplicates + duplicatesAux; - duplicatesAux = e12 == e22 ? 1 : 0; - duplicates = duplicates + duplicatesAux; - duplicatesAux = e21 == e22 ? 1 : 0; - - return countDuplicates; + field mut duplicates = e11 == e12 ? 1 : 0; + duplicates = duplicates + (e11 == e21 ? 1 : 0); + duplicates = duplicates + (e11 == e22 ? 1 : 0); + duplicates = duplicates + (e12 == e21 ? 1 : 0); + duplicates = duplicates + (e12 == e22 ? 1 : 0); + duplicates = duplicates + (e21 == e22 ? 1 : 0); + return duplicates; } // returns 0 for x in (1..4) @@ -33,7 +26,7 @@ def validateInput(field x) -> bool { } // variables naming: box'row''column' -def main(field a21, field b11, field b22, field c11, field c22, field d21, private field a11, private field a12, private field a22, private field b12, private field b21, private field c12, private field c21, private field d11, private field d12, private field d22) -> bool { +def main(field a21, field b11, field b22, field c11, field c22, field d21, private field a11, private field a12, private field a22, private field b12, private field b21, private field c12, private field c21, private field d11, private field d12, private field d22) { // validate inputs assert(validateInput(a11)); @@ -82,5 +75,6 @@ def main(field a21, field b11, field b22, field c11, field c22, field d21, priva // the solution is correct if and only if there are no duplicates assert(duplicates == 0); - return duplicates == 0; + + return; } From e210f469ca1a6cbfaee7c13a7754fa450dce1e70 Mon Sep 17 00:00:00 2001 From: dark64 Date: Fri, 24 Mar 2023 14:13:15 +0100 Subject: [PATCH 16/32] change into_bellman impl on field, use lincomb in bellman circuit construction instead of canonical --- zokrates_bellman/src/lib.rs | 20 +++++--------------- zokrates_field/src/lib.rs | 9 ++++++--- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/zokrates_bellman/src/lib.rs b/zokrates_bellman/src/lib.rs index 9828b9745..040683bf9 100644 --- a/zokrates_bellman/src/lib.rs +++ b/zokrates_bellman/src/lib.rs @@ -11,7 +11,7 @@ use bellman::{ }; use std::collections::BTreeMap; use zokrates_ast::common::Variable; -use zokrates_ast::ir::{CanonicalLinComb, ProgIterator, Statement, Witness}; +use zokrates_ast::ir::{LinComb, ProgIterator, Statement, Witness}; use zokrates_field::BellmanFieldExtensions; use zokrates_field::Field; @@ -45,7 +45,7 @@ impl<'a, T: Field, I: IntoIterator>> Computation<'a, T, } fn bellman_combination>( - l: CanonicalLinComb, + l: LinComb, cs: &mut CS, symbols: &mut BTreeMap, witness: &mut Witness, @@ -127,19 +127,9 @@ impl<'a, T: BellmanFieldExtensions + Field, I: IntoIterator ::Fr { - use bellman_ce::pairing::ff::PrimeField; - let s = self.to_dec_string(); - ::Fr::from_str(&s).unwrap() + use bellman_ce::pairing::ff::{PrimeField, PrimeFieldRepr}; + let bytes = self.to_byte_vector(); + let mut repr = + <::Fr as PrimeField>::Repr::default(); + repr.read_le(bytes.as_slice()).unwrap(); + ::Fr::from_repr(repr).unwrap() } fn new_fq2( From e0b029959d8c34eca7eebfba0ca72ac7096fabe3 Mon Sep 17 00:00:00 2001 From: dark64 Date: Fri, 24 Mar 2023 14:39:08 +0100 Subject: [PATCH 17/32] rename substitutor to canonicalizer, add tests --- zokrates_ast/src/zir/canonicalizer.rs | 94 +++++++++++++++++++++++++++ zokrates_ast/src/zir/mod.rs | 2 +- zokrates_ast/src/zir/substitution.rs | 31 --------- zokrates_codegen/src/lib.rs | 6 +- zokrates_interpreter/src/lib.rs | 2 +- 5 files changed, 99 insertions(+), 36 deletions(-) create mode 100644 zokrates_ast/src/zir/canonicalizer.rs delete mode 100644 zokrates_ast/src/zir/substitution.rs diff --git a/zokrates_ast/src/zir/canonicalizer.rs b/zokrates_ast/src/zir/canonicalizer.rs new file mode 100644 index 000000000..ca67d8b51 --- /dev/null +++ b/zokrates_ast/src/zir/canonicalizer.rs @@ -0,0 +1,94 @@ +use super::{Folder, Identifier, Parameter, Variable, ZirAssignee}; +use std::collections::HashMap; +use zokrates_field::Field; + +#[derive(Default)] +pub struct ZirCanonicalizer<'ast> { + identifier_map: HashMap, usize>, +} + +impl<'ast, T: Field> Folder<'ast, T> for ZirCanonicalizer<'ast> { + fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> { + let new_id = self.identifier_map.len(); + self.identifier_map.insert(p.id.id.clone(), new_id); + + Parameter { + id: Variable::with_id_and_type(Identifier::internal(new_id), p.id._type), + ..p + } + } + fn fold_assignee(&mut self, a: ZirAssignee<'ast>) -> ZirAssignee<'ast> { + let new_id = self.identifier_map.len(); + self.identifier_map.insert(a.id.clone(), new_id); + ZirAssignee::with_id_and_type(Identifier::internal(new_id), a._type) + } + fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { + match self.identifier_map.get(&n) { + Some(v) => Identifier::internal(*v), + None => unreachable!(), + } + } +} + +#[cfg(test)] +mod tests { + use crate::zir::{ + FieldElementExpression, IdentifierExpression, Signature, Type, ZirAssignee, ZirFunction, + ZirStatement, + }; + + use super::*; + use zokrates_field::Bn128Field; + + #[test] + fn canonicalize() { + let func = ZirFunction:: { + arguments: vec![Parameter { + id: Variable::field_element("a"), + private: true, + }], + statements: vec![ + ZirStatement::Definition( + ZirAssignee::field_element("b"), + FieldElementExpression::Identifier(IdentifierExpression::new("a".into())) + .into(), + ), + ZirStatement::Return(vec![FieldElementExpression::Identifier( + IdentifierExpression::new("b".into()), + ) + .into()]), + ], + signature: Signature::new() + .inputs(vec![Type::FieldElement]) + .outputs(vec![Type::FieldElement]), + }; + + let mut canonicalizer = ZirCanonicalizer::default(); + let result = canonicalizer.fold_function(func); + + let expected = ZirFunction:: { + arguments: vec![Parameter { + id: Variable::field_element(Identifier::internal(0usize)), + private: true, + }], + statements: vec![ + ZirStatement::Definition( + ZirAssignee::field_element(Identifier::internal(1usize)), + FieldElementExpression::Identifier(IdentifierExpression::new( + Identifier::internal(0usize), + )) + .into(), + ), + ZirStatement::Return(vec![FieldElementExpression::Identifier( + IdentifierExpression::new(Identifier::internal(1usize)), + ) + .into()]), + ], + signature: Signature::new() + .inputs(vec![Type::FieldElement]) + .outputs(vec![Type::FieldElement]), + }; + + assert_eq!(result, expected); + } +} diff --git a/zokrates_ast/src/zir/mod.rs b/zokrates_ast/src/zir/mod.rs index c2af87b0b..081109666 100644 --- a/zokrates_ast/src/zir/mod.rs +++ b/zokrates_ast/src/zir/mod.rs @@ -1,10 +1,10 @@ +pub mod canonicalizer; pub mod folder; mod from_typed; mod identifier; pub mod lqc; mod parameter; pub mod result_folder; -pub mod substitution; pub mod types; mod uint; mod variable; diff --git a/zokrates_ast/src/zir/substitution.rs b/zokrates_ast/src/zir/substitution.rs deleted file mode 100644 index d5709c5d0..000000000 --- a/zokrates_ast/src/zir/substitution.rs +++ /dev/null @@ -1,31 +0,0 @@ -use super::{Folder, Identifier, Parameter, Variable, ZirAssignee}; -use std::collections::HashMap; -use zokrates_field::Field; - -#[derive(Default)] -pub struct ZirSubstitutor<'ast> { - substitution: HashMap, usize>, -} - -impl<'ast, T: Field> Folder<'ast, T> for ZirSubstitutor<'ast> { - fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> { - let new_id = self.substitution.len(); - self.substitution.insert(p.id.id.clone(), new_id); - - Parameter { - id: Variable::with_id_and_type(Identifier::internal(new_id), p.id._type), - ..p - } - } - fn fold_assignee(&mut self, a: ZirAssignee<'ast>) -> ZirAssignee<'ast> { - let new_id = self.substitution.len(); - self.substitution.insert(a.id.clone(), new_id); - ZirAssignee::with_id_and_type(Identifier::internal(new_id), a._type) - } - fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { - match self.substitution.get(&n) { - Some(v) => Identifier::internal(*v), - None => unreachable!(), - } - } -} diff --git a/zokrates_codegen/src/lib.rs b/zokrates_codegen/src/lib.rs index 9ec1f4268..a09eb61e9 100644 --- a/zokrates_codegen/src/lib.rs +++ b/zokrates_codegen/src/lib.rs @@ -11,7 +11,7 @@ mod utils; use self::utils::flat_expression_from_bits; use zokrates_ast::zir::{ - substitution::ZirSubstitutor, ConditionalExpression, Folder, SelectExpression, ShouldReduce, + canonicalizer::ZirCanonicalizer, ConditionalExpression, Folder, SelectExpression, ShouldReduce, UMetadata, ZirAssemblyStatement, ZirExpressionList, }; use zokrates_interpreter::Interpreter; @@ -2243,8 +2243,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { .map(|assignee| self.use_variable(&assignee)) .collect(); - let mut substitutor = ZirSubstitutor::default(); - let function = substitutor.fold_function(function); + let mut canonicalizer = ZirCanonicalizer::default(); + let function = canonicalizer.fold_function(function); let directive = FlatDirective::new(outputs, Solver::Zir(function), inputs); statements_flattened.push_back(FlatStatement::Directive(directive)); diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index 73709addd..db6ed4a59 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -520,7 +520,7 @@ mod tests { // (field i0) -> i0 * i0 let solvers = vec![Solver::Zir(ZirFunction { arguments: vec![Parameter { - id: Variable::with_id_and_type(id.id.clone(), Type::FieldElement), + id: Variable::field_element(id.id.clone()), private: true, }], statements: vec![ZirStatement::Return(vec![FieldElementExpression::Mult( From 30c9d8b8b187a019ef7c735f681b221dd3ff281d Mon Sep 17 00:00:00 2001 From: Ahmed Castro Date: Fri, 24 Mar 2023 16:36:49 -0600 Subject: [PATCH 18/32] Update sudoku_checker.zok --- zokrates_cli/examples/sudoku/sudoku_checker.zok | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/zokrates_cli/examples/sudoku/sudoku_checker.zok b/zokrates_cli/examples/sudoku/sudoku_checker.zok index 7fc6e2252..b77748f42 100644 --- a/zokrates_cli/examples/sudoku/sudoku_checker.zok +++ b/zokrates_cli/examples/sudoku/sudoku_checker.zok @@ -74,7 +74,5 @@ def main(field a21, field b11, field b22, field c11, field c22, field d21, priva duplicates = duplicates + countDuplicates(b12, b22, d12, d22); // the solution is correct if and only if there are no duplicates - assert(duplicates == 0); - - return; + return duplicates == 0; } From 48885d64a4ecced90cda65392eded9dcd3234a5b Mon Sep 17 00:00:00 2001 From: Ahmed Castro Date: Fri, 24 Mar 2023 16:37:58 -0600 Subject: [PATCH 19/32] Update sudoku_checker.zok --- zokrates_cli/examples/sudoku/sudoku_checker.zok | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zokrates_cli/examples/sudoku/sudoku_checker.zok b/zokrates_cli/examples/sudoku/sudoku_checker.zok index b77748f42..4c01df6ec 100644 --- a/zokrates_cli/examples/sudoku/sudoku_checker.zok +++ b/zokrates_cli/examples/sudoku/sudoku_checker.zok @@ -26,7 +26,7 @@ def validateInput(field x) -> bool { } // variables naming: box'row''column' -def main(field a21, field b11, field b22, field c11, field c22, field d21, private field a11, private field a12, private field a22, private field b12, private field b21, private field c12, private field c21, private field d11, private field d12, private field d22) { +def main(field a21, field b11, field b22, field c11, field c22, field d21, private field a11, private field a12, private field a22, private field b12, private field b21, private field c12, private field c21, private field d11, private field d12, private field d22) -> bool { // validate inputs assert(validateInput(a11)); From aea19eb4179e82bb742249581ae9f07675844e42 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 27 Mar 2023 21:18:54 +0200 Subject: [PATCH 20/32] parallelize witness deserialization --- Cargo.lock | 9 +-- zokrates_ark/Cargo.toml | 2 +- zokrates_ark/src/gm17.rs | 31 ++++++--- zokrates_ark/src/groth16.rs | 32 ++++++--- zokrates_ark/src/marlin.rs | 24 +++++-- zokrates_ast/Cargo.toml | 4 +- zokrates_ast/src/common/embed.rs | 3 +- zokrates_ast/src/ir/witness.rs | 90 ++++++++++++++++++-------- zokrates_bellman/Cargo.toml | 2 +- zokrates_bellman/src/groth16.rs | 26 +++++--- zokrates_cli/Cargo.toml | 2 +- zokrates_cli/src/ops/generate_proof.rs | 14 ++-- zokrates_cli/src/ops/mpc/verify.rs | 4 +- zokrates_js/src/lib.rs | 2 +- zokrates_proof_systems/src/lib.rs | 19 ++++-- 15 files changed, 174 insertions(+), 90 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e21917722..28c3264b7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2036,9 +2036,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.6.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db3a213adf02b3bcfd2d3846bb41cb22857d131789e01df434fb7e7bc0759b7" +checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" dependencies = [ "either", "rayon-core", @@ -2046,9 +2046,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cac410af5d00ab6884528b4ab69d1e8e146e8d471201800fa1b4524126de6ad3" +checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" dependencies = [ "crossbeam-channel 0.5.6", "crossbeam-deque 0.8.2", @@ -2964,6 +2964,7 @@ dependencies = [ "derivative", "num-bigint 0.2.6", "pairing_ce", + "rayon", "serde", "serde_cbor", "serde_json", diff --git a/zokrates_ark/Cargo.toml b/zokrates_ark/Cargo.toml index c31dc8d3e..97bd86920 100644 --- a/zokrates_ark/Cargo.toml +++ b/zokrates_ark/Cargo.toml @@ -28,7 +28,7 @@ ark-bls12-377 = { version = "^0.3.0", features = ["curve"], default-features = f ark-bw6-761 = { version = "^0.3.0", default-features = false } ark-gm17 = { version = "^0.3.0", default-features = false } ark-groth16 = { version = "^0.3.0", default-features = false } -ark-serialize = { version = "^0.3.0", default-features = false } +ark-serialize = { version = "^0.3.0", default-features = false, features = ["std"] } ark-relations = { version = "^0.3.0", default-features = false } ark-marlin = { git = "https://github.com/arkworks-rs/marlin", rev = "63cfd82", default-features = false } ark-poly = { version = "^0.3.0", default-features = false } diff --git a/zokrates_ark/src/gm17.rs b/zokrates_ark/src/gm17.rs index a8941ed03..d1347c1aa 100644 --- a/zokrates_ark/src/gm17.rs +++ b/zokrates_ark/src/gm17.rs @@ -40,11 +40,16 @@ impl NonUniversalBackend for Ark { } impl Backend for Ark { - fn generate_proof<'a, I: IntoIterator>, R: RngCore + CryptoRng>( + fn generate_proof< + 'a, + I: IntoIterator>, + R: std::io::Read, + G: RngCore + CryptoRng, + >( program: ProgIterator<'a, T, I>, witness: Witness, - proving_key: Vec, - rng: &mut R, + proving_key: R, + rng: &mut G, ) -> Proof { let computation = Computation::with_witness(program, witness); @@ -54,10 +59,9 @@ impl Backend for Ark { .map(parse_fr::) .collect::>(); - let pk = ProvingKey::<::ArkEngine>::deserialize_unchecked( - &mut proving_key.as_slice(), - ) - .unwrap(); + let pk = + ProvingKey::<::ArkEngine>::deserialize_unchecked(proving_key) + .unwrap(); let proof = ArkGM17::::prove(&pk, computation, rng).unwrap(); let proof_points = ProofPoints { @@ -135,7 +139,10 @@ mod tests { .unwrap(); let proof = >::generate_proof( - program, witness, keypair.pk, rng, + program, + witness, + keypair.pk.as_slice(), + rng, ); let ans = >::verify(keypair.vk, proof); @@ -158,8 +165,12 @@ mod tests { .execute(program.clone(), &[Bw6_761Field::from(42)]) .unwrap(); - let proof = - >::generate_proof(program, witness, keypair.pk, rng); + let proof = >::generate_proof( + program, + witness, + keypair.pk.as_slice(), + rng, + ); let ans = >::verify(keypair.vk, proof); assert!(ans); diff --git a/zokrates_ark/src/groth16.rs b/zokrates_ark/src/groth16.rs index e9281210a..250fd7e64 100644 --- a/zokrates_ark/src/groth16.rs +++ b/zokrates_ark/src/groth16.rs @@ -4,6 +4,7 @@ use ark_groth16::{ ProvingKey, VerifyingKey, }; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use std::io::Read; use zokrates_field::ArkFieldExtensions; use zokrates_field::Field; use zokrates_proof_systems::{Backend, NonUniversalBackend, Proof, SetupKeypair}; @@ -17,11 +18,16 @@ use zokrates_proof_systems::groth16::{ProofPoints, VerificationKey, G16}; use zokrates_proof_systems::Scheme; impl Backend for Ark { - fn generate_proof<'a, I: IntoIterator>, R: RngCore + CryptoRng>( + fn generate_proof< + 'a, + I: IntoIterator>, + R: Read, + G: RngCore + CryptoRng, + >( program: ProgIterator<'a, T, I>, witness: Witness, - proving_key: Vec, - rng: &mut R, + proving_key: R, + rng: &mut G, ) -> Proof { let computation = Computation::with_witness(program, witness); @@ -31,10 +37,9 @@ impl Backend for Ark { .map(parse_fr::) .collect::>(); - let pk = ProvingKey::<::ArkEngine>::deserialize_unchecked( - &mut proving_key.as_slice(), - ) - .unwrap(); + let pk = + ProvingKey::<::ArkEngine>::deserialize_unchecked(proving_key) + .unwrap(); let proof = Groth16::::prove(&pk, computation, rng).unwrap(); let proof_points = ProofPoints { @@ -132,7 +137,10 @@ mod tests { .unwrap(); let proof = >::generate_proof( - program, witness, keypair.pk, rng, + program, + witness, + keypair.pk.as_slice(), + rng, ); let ans = >::verify(keypair.vk, proof); @@ -155,8 +163,12 @@ mod tests { .execute(program.clone(), &[Bw6_761Field::from(42)]) .unwrap(); - let proof = - >::generate_proof(program, witness, keypair.pk, rng); + let proof = >::generate_proof( + program, + witness, + keypair.pk.as_slice(), + rng, + ); let ans = >::verify(keypair.vk, proof); assert!(ans); diff --git a/zokrates_ark/src/marlin.rs b/zokrates_ark/src/marlin.rs index b0e0f75e7..d4eeb1c42 100644 --- a/zokrates_ark/src/marlin.rs +++ b/zokrates_ark/src/marlin.rs @@ -19,6 +19,7 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use digest::Digest; use rand_0_8::{CryptoRng, Error, RngCore, SeedableRng}; use sha3::Keccak256; +use std::io::Read; use std::marker::PhantomData; use zokrates_field::{ArkFieldExtensions, Field}; @@ -206,11 +207,16 @@ impl UniversalBackend for Ark } impl Backend for Ark { - fn generate_proof<'a, I: IntoIterator>, R: RngCore + CryptoRng>( + fn generate_proof< + 'a, + I: IntoIterator>, + R: Read, + G: RngCore + CryptoRng, + >( program: ProgIterator<'a, T, I>, witness: Witness, - proving_key: Vec, - rng: &mut R, + proving_key: R, + rng: &mut G, ) -> Proof { let computation = Computation::with_witness(program, witness); @@ -220,7 +226,7 @@ impl Backend for Ark { T::ArkEngine, DensePolynomial<<::ArkEngine as PairingEngine>::Fr>, >, - >::deserialize_unchecked(&mut proving_key.as_slice()) + >::deserialize_unchecked(proving_key) .unwrap(); let public_inputs = computation.public_inputs_values(); @@ -417,7 +423,10 @@ mod tests { .unwrap(); let proof = >::generate_proof( - program, witness, keypair.pk, rng, + program, + witness, + keypair.pk.as_slice(), + rng, ); let ans = >::verify(keypair.vk, proof); @@ -452,7 +461,10 @@ mod tests { .unwrap(); let proof = >::generate_proof( - program, witness, keypair.pk, rng, + program, + witness, + keypair.pk.as_slice(), + rng, ); let ans = >::verify(keypair.vk, proof); diff --git a/zokrates_ast/Cargo.toml b/zokrates_ast/Cargo.toml index 6d9b4324b..b85e71674 100644 --- a/zokrates_ast/Cargo.toml +++ b/zokrates_ast/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" default = ["bellman", "ark"] bellman = ["zokrates_field/bellman", "pairing_ce", "zokrates_embed/bellman"] ark = ["ark-bls12-377", "zokrates_embed/ark"] +parallel = ["rayon"] [dependencies] zokrates_pest_ast = { version = "0.3.0", path = "../zokrates_pest_ast" } @@ -20,4 +21,5 @@ serde_json = { version = "1.0", features = ["preserve_order"] } zokrates_embed = { version = "0.1.0", path = "../zokrates_embed", default-features = false } pairing_ce = { version = "^0.21", optional = true } ark-bls12-377 = { version = "^0.3.0", features = ["curve"], default-features = false, optional = true } -derivative = "2.2.0" \ No newline at end of file +derivative = "2.2.0" +rayon = { version = "1.7", optional = true } \ No newline at end of file diff --git a/zokrates_ast/src/common/embed.rs b/zokrates_ast/src/common/embed.rs index 58294142f..3eeed58ce 100644 --- a/zokrates_ast/src/common/embed.rs +++ b/zokrates_ast/src/common/embed.rs @@ -1,5 +1,5 @@ use crate::common::{Parameter, RuntimeError, Solver, Variable}; -use crate::flat::{flat_expression_from_bits, flat_expression_from_variable_summands}; +use crate::flat::flat_expression_from_bits; use crate::flat::{FlatDirective, FlatExpression, FlatFunctionIterator, FlatStatement}; use crate::typed::types::{ ConcreteGenericsAssignment, DeclarationConstant, DeclarationSignature, DeclarationType, @@ -17,6 +17,7 @@ cfg_if::cfg_if! { if #[cfg(feature = "bellman")] { use pairing_ce::bn256::Bn256; use zokrates_embed::{bellman::{from_bellman, generate_sha256_round_constraints}}; + use crate::flat::flat_expression_from_variable_summands; } } diff --git a/zokrates_ast/src/ir/witness.rs b/zokrates_ast/src/ir/witness.rs index 366c62d8d..b8c0b22db 100644 --- a/zokrates_ast/src/ir/witness.rs +++ b/zokrates_ast/src/ir/witness.rs @@ -44,7 +44,6 @@ impl Witness { pub fn write(&self, writer: W) -> io::Result<()> { let mut wtr = csv::WriterBuilder::new() .delimiter(b' ') - .flexible(true) .has_headers(false) .from_writer(writer); @@ -59,36 +58,71 @@ impl Witness { pub fn read(mut reader: R) -> io::Result { let mut rdr = csv::ReaderBuilder::new() .delimiter(b' ') - .flexible(true) .has_headers(false) .from_reader(&mut reader); - let map = rdr - .deserialize::<(String, String)>() - .map(|r| { - r.map(|(variable, value)| { - let variable = Variable::try_from_human_readable(&variable).map_err(|why| { - io::Error::new( - io::ErrorKind::Other, - format!("Invalid variable in witness: {}", why), - ) - })?; - let value = T::try_from_dec_str(&value).map_err(|_| { - io::Error::new( - io::ErrorKind::Other, - format!("Invalid value in witness: {}", value), - ) - })?; - Ok((variable, value)) - }) - .map_err(|e| match e.into_kind() { - csv::ErrorKind::Io(e) => e, - e => io::Error::new(io::ErrorKind::Other, format!("{:?}", e)), - })? - }) - .collect::>>()?; - - Ok(Witness(map)) + cfg_if::cfg_if! { + if #[cfg(feature = "parallel")] { + use rayon::prelude::{IntoParallelIterator, ParallelIterator}; + + let map = rdr + .deserialize::<(String, String)>() + .collect::>>() + .into_par_iter() + .map(|r| { + let (var, val) = r.map_err(|e| match e.into_kind() { + csv::ErrorKind::Io(e) => e, + e => io::Error::new(io::ErrorKind::Other, format!("{:?}", e)), + })?; + + let variable = Variable::try_from_human_readable(&var).map_err(|why| { + io::Error::new( + io::ErrorKind::Other, + format!("Invalid variable in witness: {}", why), + ) + })?; + + let value = T::try_from_dec_str(&val).map_err(|_| { + io::Error::new( + io::ErrorKind::Other, + format!("Invalid value in witness: {}", &val), + ) + })?; + + Ok((variable, value)) + }) + .collect::>>()?; + + Ok(Witness(map)) + } else { + let map = rdr + .deserialize::<(String, String)>() + .map(|r| { + r.map(|(variable, value)| { + let variable = Variable::try_from_human_readable(&variable).map_err(|why| { + io::Error::new( + io::ErrorKind::Other, + format!("Invalid variable in witness: {}", why), + ) + })?; + let value = T::try_from_dec_str(&value).map_err(|_| { + io::Error::new( + io::ErrorKind::Other, + format!("Invalid value in witness: {}", value), + ) + })?; + Ok((variable, value)) + }) + .map_err(|e| match e.into_kind() { + csv::ErrorKind::Io(e) => e, + e => io::Error::new(io::ErrorKind::Other, format!("{:?}", e)), + })? + }) + .collect::>>()?; + + Ok(Witness(map)) + } + } } } diff --git a/zokrates_bellman/Cargo.toml b/zokrates_bellman/Cargo.toml index 9d8fff446..fb4c70fee 100644 --- a/zokrates_bellman/Cargo.toml +++ b/zokrates_bellman/Cargo.toml @@ -9,7 +9,7 @@ wasm = ["bellman/nolog", "bellman/wasm"] multicore = ["bellman/multicore", "phase2/multicore"] [dependencies] -zokrates_field = { version = "0.5", path = "../zokrates_field", default-features = false } +zokrates_field = { version = "0.5", path = "../zokrates_field", default-features = false, features = ["bellman"] } zokrates_ast = { version = "0.1", path = "../zokrates_ast", default-features = false } zokrates_proof_systems = { version = "0.1", path = "../zokrates_proof_systems", default-features = false } diff --git a/zokrates_bellman/src/groth16.rs b/zokrates_bellman/src/groth16.rs index d0e6c4178..457c7ed7b 100644 --- a/zokrates_bellman/src/groth16.rs +++ b/zokrates_bellman/src/groth16.rs @@ -20,14 +20,19 @@ use zokrates_proof_systems::groth16::{ProofPoints, VerificationKey, G16}; use zokrates_proof_systems::Scheme; impl Backend for Bellman { - fn generate_proof<'a, I: IntoIterator>, R: RngCore + CryptoRng>( + fn generate_proof< + 'a, + I: IntoIterator>, + R: Read, + G: RngCore + CryptoRng, + >( program: ProgIterator<'a, T, I>, witness: Witness, - proving_key: Vec, - rng: &mut R, + proving_key: R, + rng: &mut G, ) -> Proof { let computation = Computation::with_witness(program, witness); - let params = Parameters::read(proving_key.as_slice(), true).unwrap(); + let params = Parameters::read(proving_key, true).unwrap(); let public_inputs: Vec = computation .public_inputs_values() @@ -108,7 +113,7 @@ impl MpcBackend for Bellman { } fn contribute( - params: &mut R, + params: R, rng: &mut G, output: &mut W, ) -> Result<[u8; 64], String> { @@ -124,8 +129,8 @@ impl MpcBackend for Bellman { Ok(hash) } - fn verify<'a, P: Read, R: Read, I: IntoIterator>>( - params: &mut P, + fn verify<'a, R: Read, I: IntoIterator>>( + params: R, program: ProgIterator<'a, T, I>, phase1_radix: &mut R, ) -> Result, String> { @@ -140,7 +145,7 @@ impl MpcBackend for Bellman { Ok(hashes) } - fn export_keypair(params: &mut R) -> Result, String> { + fn export_keypair(params: R) -> Result, String> { let params = MPCParameters::::read(params, true).map_err(|e| e.to_string())?; @@ -226,7 +231,10 @@ mod tests { .unwrap(); let proof = >::generate_proof( - program, witness, keypair.pk, rng, + program, + witness, + keypair.pk.as_slice(), + rng, ); let ans = >::verify(keypair.vk, proof); diff --git a/zokrates_cli/Cargo.toml b/zokrates_cli/Cargo.toml index 3fa5b5feb..5b3b888ac 100644 --- a/zokrates_cli/Cargo.toml +++ b/zokrates_cli/Cargo.toml @@ -20,7 +20,7 @@ regex = "0.2" zokrates_field = { version = "0.5", path = "../zokrates_field", features = ["multicore"] } zokrates_abi = { version = "0.1", path = "../zokrates_abi" } zokrates_core = { version = "0.7", path = "../zokrates_core", default-features = false } -zokrates_ast = { version = "0.1", path = "../zokrates_ast", default-features = false } +zokrates_ast = { version = "0.1", path = "../zokrates_ast", default-features = false, features = ["parallel"] } zokrates_interpreter = { version = "0.1", path = "../zokrates_interpreter", default-features = false } zokrates_circom = { version = "0.1", path = "../zokrates_circom", default-features = false } zokrates_embed = { version = "0.1", path = "../zokrates_embed", features = ["multicore"] } diff --git a/zokrates_cli/src/ops/generate_proof.rs b/zokrates_cli/src/ops/generate_proof.rs index b2ee77d44..99857f0aa 100644 --- a/zokrates_cli/src/ops/generate_proof.rs +++ b/zokrates_cli/src/ops/generate_proof.rs @@ -4,7 +4,7 @@ use rand_0_8::rngs::StdRng; use rand_0_8::SeedableRng; use std::convert::TryFrom; use std::fs::File; -use std::io::{BufReader, Read, Write}; +use std::io::{BufReader, Write}; use std::path::Path; #[cfg(feature = "ark")] use zokrates_ark::Ark; @@ -163,7 +163,9 @@ fn cli_generate_proof< let witness_file = File::open(&witness_path) .map_err(|why| format!("Could not open {}: {}", witness_path.display(), why))?; - let witness = ir::Witness::read(witness_file) + let witness_reader = BufReader::new(witness_file); + + let witness = ir::Witness::read(witness_reader) .map_err(|why| format!("Could not load witness: {:?}", why))?; let pk_path = Path::new(sub_matches.value_of("proving-key-path").unwrap()); @@ -172,18 +174,14 @@ fn cli_generate_proof< let pk_file = File::open(&pk_path) .map_err(|why| format!("Could not open {}: {}", pk_path.display(), why))?; - let mut pk: Vec = Vec::new(); - let mut pk_reader = BufReader::new(pk_file); - pk_reader - .read_to_end(&mut pk) - .map_err(|why| format!("Could not read {}: {}", pk_path.display(), why))?; + let pk_reader = BufReader::new(pk_file); let mut rng = sub_matches .value_of("entropy") .map(get_rng_from_entropy) .unwrap_or_else(StdRng::from_entropy); - let proof = B::generate_proof(program, witness, pk, &mut rng); + let proof = B::generate_proof(program, witness, pk_reader, &mut rng); let mut proof_file = File::create(proof_path).unwrap(); let proof = diff --git a/zokrates_cli/src/ops/mpc/verify.rs b/zokrates_cli/src/ops/mpc/verify.rs index fa014bd06..d0520e64c 100644 --- a/zokrates_cli/src/ops/mpc/verify.rs +++ b/zokrates_cli/src/ops/mpc/verify.rs @@ -73,7 +73,7 @@ fn cli_mpc_verify< let file = File::open(&path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; - let mut reader = BufReader::new(file); + let reader = BufReader::new(file); let radix_path = Path::new(sub_matches.value_of("radix-path").unwrap()); let radix_file = File::open(radix_path) @@ -81,7 +81,7 @@ fn cli_mpc_verify< let mut radix_reader = BufReader::new(radix_file); - let result = B::verify(&mut reader, program, &mut radix_reader) + let result = B::verify(reader, program, &mut radix_reader) .map_err(|e| format!("Verification failed: {}", e))?; let contribution_count = result.len(); diff --git a/zokrates_js/src/lib.rs b/zokrates_js/src/lib.rs index 2b4dccebc..6838e4be9 100644 --- a/zokrates_js/src/lib.rs +++ b/zokrates_js/src/lib.rs @@ -424,7 +424,7 @@ mod internal { let ir_witness: ir::Witness = ir::Witness::read(str_witness.as_bytes()) .map_err(|err| JsValue::from_str(&format!("Could not read witness: {}", err)))?; - let proof = B::generate_proof(prog, ir_witness, pk.to_vec(), rng); + let proof = B::generate_proof(prog, ir_witness, pk, rng); Ok(JsValue::from_serde(&TaggedProof::::new(proof.proof, proof.inputs)).unwrap()) } diff --git a/zokrates_proof_systems/src/lib.rs b/zokrates_proof_systems/src/lib.rs index 69a9740df..80432e3f6 100644 --- a/zokrates_proof_systems/src/lib.rs +++ b/zokrates_proof_systems/src/lib.rs @@ -96,11 +96,16 @@ impl ToString for G2AffineFq2 { } pub trait Backend> { - fn generate_proof<'a, I: IntoIterator>, R: RngCore + CryptoRng>( + fn generate_proof< + 'a, + I: IntoIterator>, + R: Read, + G: RngCore + CryptoRng, + >( program: ir::ProgIterator<'a, T, I>, witness: ir::Witness, - proving_key: Vec, - rng: &mut R, + proving_key: R, + rng: &mut G, ) -> Proof; fn verify(vk: S::VerificationKey, proof: Proof) -> bool; @@ -129,16 +134,16 @@ pub trait MpcBackend> { ) -> Result<(), String>; fn contribute( - params: &mut R, + params: R, rng: &mut G, output: &mut W, ) -> Result<[u8; 64], String>; - fn verify<'a, P: Read, R: Read, I: IntoIterator>>( - params: &mut P, + fn verify<'a, R: Read, I: IntoIterator>>( + params: R, program: ir::ProgIterator<'a, T, I>, phase1_radix: &mut R, ) -> Result, String>; - fn export_keypair(params: &mut R) -> Result, String>; + fn export_keypair(params: R) -> Result, String>; } From f83c26cd4588ce5e5aa2981ba0fb19a1bdf22045 Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 28 Mar 2023 14:31:10 +0200 Subject: [PATCH 21/32] fix test --- zokrates_test/tests/wasm.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/zokrates_test/tests/wasm.rs b/zokrates_test/tests/wasm.rs index 8b81f3bf0..1278e5bbc 100644 --- a/zokrates_test/tests/wasm.rs +++ b/zokrates_test/tests/wasm.rs @@ -29,6 +29,10 @@ fn generate_proof() { let rng = &mut StdRng::from_entropy(); let keypair = >::setup(program.clone(), rng); - let _proof = - >::generate_proof(program, witness, keypair.pk, rng); + let _proof = >::generate_proof( + program, + witness, + keypair.pk.as_slice(), + rng, + ); } From bd873daf1a4d911a5601573f83b9a10ebed6e6cf Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 28 Mar 2023 14:49:01 +0200 Subject: [PATCH 22/32] remove into_canonical call --- zokrates_ark/src/lib.rs | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/zokrates_ark/src/lib.rs b/zokrates_ark/src/lib.rs index 425be3a8d..586e09b28 100644 --- a/zokrates_ark/src/lib.rs +++ b/zokrates_ark/src/lib.rs @@ -9,7 +9,7 @@ use ark_relations::r1cs::{ }; use std::collections::BTreeMap; use zokrates_ast::common::Variable; -use zokrates_ast::ir::{CanonicalLinComb, ProgIterator, Statement, Witness}; +use zokrates_ast::ir::{LinComb, ProgIterator, Statement, Witness}; use zokrates_field::{ArkFieldExtensions, Field}; pub use self::parse::*; @@ -39,7 +39,7 @@ impl<'a, T, I: IntoIterator>> Computation<'a, T, I> { } fn ark_combination( - l: CanonicalLinComb, + l: LinComb, cs: &mut ConstraintSystem<<::ArkEngine as PairingEngine>::Fr>, symbols: &mut BTreeMap, witness: &mut Witness, @@ -113,24 +113,9 @@ impl<'a, T: Field + ArkFieldExtensions, I: IntoIterator> for statement in self.program.statements { if let Statement::Constraint(quad, lin, _) = statement { - let a = ark_combination( - quad.left.clone().into_canonical(), - &mut cs, - &mut symbols, - &mut witness, - ); - let b = ark_combination( - quad.right.clone().into_canonical(), - &mut cs, - &mut symbols, - &mut witness, - ); - let c = ark_combination( - lin.into_canonical(), - &mut cs, - &mut symbols, - &mut witness, - ); + let a = ark_combination(quad.left, &mut cs, &mut symbols, &mut witness); + let b = ark_combination(quad.right, &mut cs, &mut symbols, &mut witness); + let c = ark_combination(lin, &mut cs, &mut symbols, &mut witness); cs.enforce_constraint(a, b, c)?; } From 3ba789e7f3a322ad1282a679f5ea3ac45ad01737 Mon Sep 17 00:00:00 2001 From: dark64 Date: Wed, 29 Mar 2023 20:36:34 +0200 Subject: [PATCH 23/32] use binary format for witness --- .gitignore | 1 + Cargo.lock | 1 - zokrates_ark/src/groth16.rs | 1 + zokrates_ast/Cargo.toml | 4 +- zokrates_ast/src/common/variable.rs | 15 +++ zokrates_ast/src/ir/witness.rs | 133 ++++++------------------ zokrates_cli/Cargo.toml | 2 +- zokrates_cli/src/ops/compute_witness.rs | 17 +++ zokrates_field/src/dummy_curve.rs | 8 ++ zokrates_field/src/lib.rs | 21 +++- zokrates_js/index.d.ts | 4 +- zokrates_js/src/lib.rs | 23 ++-- 12 files changed, 110 insertions(+), 120 deletions(-) diff --git a/.gitignore b/.gitignore index a11fa3038..3e0371edf 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ verifier.sol proof.json universal_setup.dat witness +witness.json # ZoKrates source files at the root of the repository /*.zok diff --git a/Cargo.lock b/Cargo.lock index 28c3264b7..94becac14 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2964,7 +2964,6 @@ dependencies = [ "derivative", "num-bigint 0.2.6", "pairing_ce", - "rayon", "serde", "serde_cbor", "serde_json", diff --git a/zokrates_ark/src/groth16.rs b/zokrates_ark/src/groth16.rs index 250fd7e64..f24d82093 100644 --- a/zokrates_ark/src/groth16.rs +++ b/zokrates_ark/src/groth16.rs @@ -42,6 +42,7 @@ impl Backend for Ark { .unwrap(); let proof = Groth16::::prove(&pk, computation, rng).unwrap(); + let proof_points = ProofPoints { a: parse_g1::(&proof.a), b: parse_g2::(&proof.b), diff --git a/zokrates_ast/Cargo.toml b/zokrates_ast/Cargo.toml index b85e71674..6d9b4324b 100644 --- a/zokrates_ast/Cargo.toml +++ b/zokrates_ast/Cargo.toml @@ -7,7 +7,6 @@ edition = "2021" default = ["bellman", "ark"] bellman = ["zokrates_field/bellman", "pairing_ce", "zokrates_embed/bellman"] ark = ["ark-bls12-377", "zokrates_embed/ark"] -parallel = ["rayon"] [dependencies] zokrates_pest_ast = { version = "0.3.0", path = "../zokrates_pest_ast" } @@ -21,5 +20,4 @@ serde_json = { version = "1.0", features = ["preserve_order"] } zokrates_embed = { version = "0.1.0", path = "../zokrates_embed", default-features = false } pairing_ce = { version = "^0.21", optional = true } ark-bls12-377 = { version = "^0.3.0", features = ["curve"], default-features = false, optional = true } -derivative = "2.2.0" -rayon = { version = "1.7", optional = true } \ No newline at end of file +derivative = "2.2.0" \ No newline at end of file diff --git a/zokrates_ast/src/common/variable.rs b/zokrates_ast/src/common/variable.rs index 983e9e178..967d47a2c 100644 --- a/zokrates_ast/src/common/variable.rs +++ b/zokrates_ast/src/common/variable.rs @@ -1,6 +1,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fmt; +use std::io::{Read, Write}; // A variable in a constraint system // id > 0 for intermediate variables @@ -33,6 +34,20 @@ impl Variable { (self.id as usize) - 1 } + pub fn write(&self, mut writer: W) -> std::io::Result<()> { + writer.write_all(&self.id.to_le_bytes())?; + Ok(()) + } + + pub fn read(mut reader: R) -> std::io::Result { + let mut buf = [0; std::mem::size_of::()]; + reader.read_exact(&mut buf)?; + + Ok(Variable { + id: isize::from_le_bytes(buf), + }) + } + pub fn try_from_human_readable(s: &str) -> Result { if s == "~one" { return Ok(Variable::one()); diff --git a/zokrates_ast/src/ir/witness.rs b/zokrates_ast/src/ir/witness.rs index b8c0b22db..ff58667f0 100644 --- a/zokrates_ast/src/ir/witness.rs +++ b/zokrates_ast/src/ir/witness.rs @@ -41,88 +41,44 @@ impl Witness { Witness(BTreeMap::new()) } - pub fn write(&self, writer: W) -> io::Result<()> { - let mut wtr = csv::WriterBuilder::new() - .delimiter(b' ') - .has_headers(false) - .from_writer(writer); + pub fn write(&self, mut writer: W) -> io::Result<()> { + let length = self.0.len(); + writer.write_all(&length.to_le_bytes())?; - // Write each line of the witness to the file for (variable, value) in &self.0 { - wtr.serialize((variable.to_string(), value.to_dec_string()))?; + variable.write(&mut writer)?; + value.write(&mut writer)?; } - Ok(()) } pub fn read(mut reader: R) -> io::Result { - let mut rdr = csv::ReaderBuilder::new() - .delimiter(b' ') - .has_headers(false) - .from_reader(&mut reader); - - cfg_if::cfg_if! { - if #[cfg(feature = "parallel")] { - use rayon::prelude::{IntoParallelIterator, ParallelIterator}; - - let map = rdr - .deserialize::<(String, String)>() - .collect::>>() - .into_par_iter() - .map(|r| { - let (var, val) = r.map_err(|e| match e.into_kind() { - csv::ErrorKind::Io(e) => e, - e => io::Error::new(io::ErrorKind::Other, format!("{:?}", e)), - })?; - - let variable = Variable::try_from_human_readable(&var).map_err(|why| { - io::Error::new( - io::ErrorKind::Other, - format!("Invalid variable in witness: {}", why), - ) - })?; - - let value = T::try_from_dec_str(&val).map_err(|_| { - io::Error::new( - io::ErrorKind::Other, - format!("Invalid value in witness: {}", &val), - ) - })?; - - Ok((variable, value)) - }) - .collect::>>()?; - - Ok(Witness(map)) - } else { - let map = rdr - .deserialize::<(String, String)>() - .map(|r| { - r.map(|(variable, value)| { - let variable = Variable::try_from_human_readable(&variable).map_err(|why| { - io::Error::new( - io::ErrorKind::Other, - format!("Invalid variable in witness: {}", why), - ) - })?; - let value = T::try_from_dec_str(&value).map_err(|_| { - io::Error::new( - io::ErrorKind::Other, - format!("Invalid value in witness: {}", value), - ) - })?; - Ok((variable, value)) - }) - .map_err(|e| match e.into_kind() { - csv::ErrorKind::Io(e) => e, - e => io::Error::new(io::ErrorKind::Other, format!("{:?}", e)), - })? - }) - .collect::>>()?; - - Ok(Witness(map)) - } + let mut witness = Self::empty(); + + let mut buf = [0; std::mem::size_of::()]; + reader.read_exact(&mut buf)?; + + let length: usize = usize::from_le_bytes(buf); + + for _ in 0..length { + let var = Variable::read(&mut reader)?; + let val = T::read(&mut reader)?; + + assert_eq!(witness.insert(var, val), None); } + + Ok(witness) + } + + pub fn write_json(&self, writer: W) -> io::Result<()> { + let map = self + .0 + .iter() + .map(|(k, v)| (k.to_string(), serde_json::json!(v.to_dec_string()))) + .collect::>(); + + serde_json::to_writer_pretty(writer, &map)?; + Ok(()) } } @@ -170,34 +126,5 @@ mod tests { assert_eq!(w, r); } - - #[test] - fn wrong_value() { - let mut buff = Cursor::new(vec![]); - - buff.write_all("_1 123bug".as_ref()).unwrap(); - buff.set_position(0); - - assert!(Witness::::read(buff).is_err()); - } - - #[test] - fn wrong_variable() { - let mut buff = Cursor::new(vec![]); - - buff.write_all("_1bug 123".as_ref()).unwrap(); - buff.set_position(0); - - assert!(Witness::::read(buff).is_err()); - } - - #[test] - fn not_csv() { - let mut buff = Cursor::new(vec![]); - buff.write_all("whatwhat".as_ref()).unwrap(); - buff.set_position(0); - - assert!(Witness::::read(buff).is_err()); - } } } diff --git a/zokrates_cli/Cargo.toml b/zokrates_cli/Cargo.toml index 5b3b888ac..3fa5b5feb 100644 --- a/zokrates_cli/Cargo.toml +++ b/zokrates_cli/Cargo.toml @@ -20,7 +20,7 @@ regex = "0.2" zokrates_field = { version = "0.5", path = "../zokrates_field", features = ["multicore"] } zokrates_abi = { version = "0.1", path = "../zokrates_abi" } zokrates_core = { version = "0.7", path = "../zokrates_core", default-features = false } -zokrates_ast = { version = "0.1", path = "../zokrates_ast", default-features = false, features = ["parallel"] } +zokrates_ast = { version = "0.1", path = "../zokrates_ast", default-features = false } zokrates_interpreter = { version = "0.1", path = "../zokrates_interpreter", default-features = false } zokrates_circom = { version = "0.1", path = "../zokrates_circom", default-features = false } zokrates_embed = { version = "0.1", path = "../zokrates_embed", features = ["multicore"] } diff --git a/zokrates_cli/src/ops/compute_witness.rs b/zokrates_cli/src/ops/compute_witness.rs index ab2a959bd..3c8e42361 100644 --- a/zokrates_cli/src/ops/compute_witness.rs +++ b/zokrates_cli/src/ops/compute_witness.rs @@ -66,6 +66,10 @@ pub fn subcommand() -> App<'static, 'static> { .help("Read arguments from stdin") .conflicts_with("arguments") .required(false) + ).arg(Arg::with_name("json") + .long("json") + .help("Write witness in a json format") + .required(false) ) } @@ -195,6 +199,19 @@ fn cli_compute<'a, T: Field, I: Iterator>>( .write(writer) .map_err(|why| format!("Could not save witness: {:?}", why))?; + // write witness in the json format + if sub_matches.is_present("json") { + let json_path = Path::new(sub_matches.value_of("output").unwrap()).with_extension("json"); + let json_file = File::create(&json_path) + .map_err(|why| format!("Could not create {}: {}", json_path.display(), why))?; + + let writer = BufWriter::new(json_file); + + witness + .write_json(writer) + .map_err(|why| format!("Could not save {}: {:?}", json_path.display(), why))?; + } + // write circom witness to file let wtns_path = Path::new(sub_matches.value_of("circom-witness").unwrap()); let wtns_file = File::create(&wtns_path) diff --git a/zokrates_field/src/dummy_curve.rs b/zokrates_field/src/dummy_curve.rs index 5d3aed4a3..cb3af053e 100644 --- a/zokrates_field/src/dummy_curve.rs +++ b/zokrates_field/src/dummy_curve.rs @@ -250,4 +250,12 @@ impl Field for FieldPrime { fn to_biguint(&self) -> num_bigint::BigUint { unimplemented!() } + + fn read(_: R) -> std::io::Result { + unimplemented!() + } + + fn write(&self, _: W) -> std::io::Result<()> { + unimplemented!() + } } diff --git a/zokrates_field/src/lib.rs b/zokrates_field/src/lib.rs index 9b0eb7331..a33bf31a9 100644 --- a/zokrates_field/src/lib.rs +++ b/zokrates_field/src/lib.rs @@ -8,7 +8,6 @@ extern crate num_bigint; #[cfg(feature = "bellman")] use bellman_ce::pairing::{ff::ScalarEngine, Engine}; - use num_bigint::BigUint; use num_traits::{CheckedDiv, One, Zero}; use serde::{Deserialize, Serialize}; @@ -16,6 +15,7 @@ use std::convert::{From, TryFrom}; use std::fmt; use std::fmt::{Debug, Display}; use std::hash::Hash; +use std::io::{Read, Write}; use std::ops::{Add, Div, Mul, Sub}; pub trait Pow { @@ -95,6 +95,10 @@ pub trait Field: + num_traits::CheckedMul { const G2_TYPE: G2Type = G2Type::Fq2; + // Read field from the reader + fn read(reader: R) -> std::io::Result; + // Write field to the writer + fn write(&self, writer: W) -> std::io::Result<()>; /// Returns this `Field`'s contents as little-endian byte vector fn to_byte_vector(&self) -> Vec; /// Returns an element of this `Field` from a little-endian byte vector @@ -144,6 +148,7 @@ mod prime_field { use std::convert::TryFrom; use std::fmt; use std::fmt::{Debug, Display}; + use std::io::{Read, Write}; use std::ops::{Add, Div, Mul, Sub}; type Fr = <$v as ark_ec::PairingEngine>::Fr; @@ -186,9 +191,21 @@ mod prime_field { self.v.into_repr().to_bytes_le() } - fn from_byte_vector(bytes: Vec) -> Self { + fn read(reader: R) -> std::io::Result { use ark_ff::FromBytes; + Ok(FieldPrime { + v: Fr::read(reader)?, + }) + } + fn write(&self, mut writer: W) -> std::io::Result<()> { + use ark_ff::ToBytes; + self.v.write(&mut writer)?; + Ok(()) + } + + fn from_byte_vector(bytes: Vec) -> Self { + use ark_ff::FromBytes; FieldPrime { v: Fr::from(::BigInt::read(&bytes[..]).unwrap()), } diff --git a/zokrates_js/index.d.ts b/zokrates_js/index.d.ts index d2755a2c2..ff7cba246 100644 --- a/zokrates_js/index.d.ts +++ b/zokrates_js/index.d.ts @@ -42,7 +42,7 @@ declare module "zokrates-js" { } export interface ComputationResult { - witness: string; + witness: Uint8Array; output: string; snarkjs?: { witness: Uint8Array; @@ -90,7 +90,7 @@ declare module "zokrates-js" { setupWithSrs(srs: Uint8Array, program: Uint8Array): SetupKeypair; generateProof( program: Uint8Array, - witness: string, + witness: Uint8Array, provingKey: Uint8Array, entropy?: string ): Proof; diff --git a/zokrates_js/src/lib.rs b/zokrates_js/src/lib.rs index 6838e4be9..2f3ac09e8 100644 --- a/zokrates_js/src/lib.rs +++ b/zokrates_js/src/lib.rs @@ -79,15 +79,17 @@ pub struct ResolverResult { #[wasm_bindgen] pub struct ComputationResult { - witness: String, + witness: Vec, output: String, snarkjs_witness: Option>, } #[wasm_bindgen] impl ComputationResult { - pub fn witness(&self) -> JsValue { - JsValue::from_str(&self.witness) + pub fn witness(&self) -> js_sys::Uint8Array { + let arr = js_sys::Uint8Array::new_with_length(self.witness.len() as u32); + arr.copy_from(&self.witness); + arr } pub fn output(&self) -> JsValue { @@ -360,8 +362,14 @@ mod internal { buffer.into_inner() }); + let witness = { + let mut buffer = Cursor::new(vec![]); + witness.write(&mut buffer).unwrap(); + buffer.into_inner() + }; + Ok(ComputationResult { - witness: format!("{}", witness), + witness, output: to_string_pretty(&return_values).unwrap(), snarkjs_witness, }) @@ -416,12 +424,11 @@ mod internal { pub fn generate_proof, B: Backend, R: RngCore + CryptoRng>( prog: ir::Prog, - witness: JsValue, + witness: &[u8], pk: &[u8], rng: &mut R, ) -> Result { - let str_witness = witness.as_string().unwrap(); - let ir_witness: ir::Witness = ir::Witness::read(str_witness.as_bytes()) + let ir_witness: ir::Witness = ir::Witness::read(witness) .map_err(|err| JsValue::from_str(&format!("Could not read witness: {}", err)))?; let proof = B::generate_proof(prog, ir_witness, pk, rng); @@ -693,7 +700,7 @@ pub fn universal_setup(curve: JsValue, size: u32, entropy: JsValue) -> Result Date: Wed, 29 Mar 2023 20:40:26 +0200 Subject: [PATCH 24/32] remove csv dependency --- Cargo.lock | 1 - zokrates_ast/Cargo.toml | 1 - zokrates_ast/src/ir/witness.rs | 2 +- zokrates_field/Cargo.toml | 2 +- 4 files changed, 2 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 94becac14..1be2fe836 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2960,7 +2960,6 @@ version = "0.1.4" dependencies = [ "ark-bls12-377", "cfg-if 0.1.10", - "csv", "derivative", "num-bigint 0.2.6", "pairing_ce", diff --git a/zokrates_ast/Cargo.toml b/zokrates_ast/Cargo.toml index 6d9b4324b..c19392aa9 100644 --- a/zokrates_ast/Cargo.toml +++ b/zokrates_ast/Cargo.toml @@ -13,7 +13,6 @@ zokrates_pest_ast = { version = "0.3.0", path = "../zokrates_pest_ast" } cfg-if = "0.1" zokrates_field = { version = "0.5", path = "../zokrates_field", default-features = false } serde = { version = "1.0", features = ["derive"] } -csv = "1" serde_cbor = "0.11.2" num-bigint = { version = "0.2", default-features = false } serde_json = { version = "1.0", features = ["preserve_order"] } diff --git a/zokrates_ast/src/ir/witness.rs b/zokrates_ast/src/ir/witness.rs index ff58667f0..b2d6e508d 100644 --- a/zokrates_ast/src/ir/witness.rs +++ b/zokrates_ast/src/ir/witness.rs @@ -64,7 +64,7 @@ impl Witness { let var = Variable::read(&mut reader)?; let val = T::read(&mut reader)?; - assert_eq!(witness.insert(var, val), None); + witness.insert(var, val); } Ok(witness) diff --git a/zokrates_field/Cargo.toml b/zokrates_field/Cargo.toml index bb420c7f9..8bcde430d 100644 --- a/zokrates_field/Cargo.toml +++ b/zokrates_field/Cargo.toml @@ -29,7 +29,7 @@ ark-bn254 = { version = "^0.3.0", features = ["curve"], default-features = false ark-bls12-377 = { version = "^0.3.0", features = ["curve"], default-features = false } ark-bls12-381 = { version = "^0.3.0", features = ["curve"] } ark-bw6-761 = { version = "^0.3.0", default-features = false } -ark-serialize = { version = "^0.3.0", default-features = false } +ark-serialize = { version = "^0.3.0", default-features = false, features = ["std"] } [dev-dependencies] rand = "0.4" From 2e9c66a5c1ce9c6cfd5be46ba4cca9f5a9f921eb Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 4 Apr 2023 18:46:08 +0200 Subject: [PATCH 25/32] fix integration test, update book --- zokrates_ast/src/ir/witness.rs | 34 +++++++++++++++++ zokrates_book/src/examples/sha256example.md | 22 ++++------- .../examples/book/sha256_tutorial/test.sh | 4 +- .../tests/code/arithmetics.expected.witness | 1 - .../code/arithmetics.expected.witness.json | 3 ++ .../code/conditional_false.expected.witness | 1 - .../conditional_false.expected.witness.json | 3 ++ .../code/conditional_true.expected.witness | 1 - .../conditional_true.expected.witness.json | 3 ++ .../code/multidim_update.expected.witness | 4 -- .../multidim_update.expected.witness.json | 6 +++ .../tests/code/n_choose_k.expected.witness | 1 - .../code/n_choose_k.expected.witness.json | 3 ++ .../tests/code/no_return.expected.witness | 0 .../code/no_return.expected.witness.json | 1 + .../tests/code/return_array.expected.witness | 8 ---- .../code/return_array.expected.witness.json | 10 +++++ .../tests/code/simple_add.expected.witness | 1 - .../code/simple_add.expected.witness.json | 3 ++ .../tests/code/simple_mul.expected.witness | 1 - .../code/simple_mul.expected.witness.json | 3 ++ .../tests/code/taxation.expected.witness | 1 - .../tests/code/taxation.expected.witness.json | 3 ++ zokrates_cli/tests/integration.rs | 37 ++++++++----------- 24 files changed, 96 insertions(+), 58 deletions(-) delete mode 100644 zokrates_cli/tests/code/arithmetics.expected.witness create mode 100644 zokrates_cli/tests/code/arithmetics.expected.witness.json delete mode 100644 zokrates_cli/tests/code/conditional_false.expected.witness create mode 100644 zokrates_cli/tests/code/conditional_false.expected.witness.json delete mode 100644 zokrates_cli/tests/code/conditional_true.expected.witness create mode 100644 zokrates_cli/tests/code/conditional_true.expected.witness.json delete mode 100644 zokrates_cli/tests/code/multidim_update.expected.witness create mode 100644 zokrates_cli/tests/code/multidim_update.expected.witness.json delete mode 100644 zokrates_cli/tests/code/n_choose_k.expected.witness create mode 100644 zokrates_cli/tests/code/n_choose_k.expected.witness.json delete mode 100644 zokrates_cli/tests/code/no_return.expected.witness create mode 100644 zokrates_cli/tests/code/no_return.expected.witness.json delete mode 100644 zokrates_cli/tests/code/return_array.expected.witness create mode 100644 zokrates_cli/tests/code/return_array.expected.witness.json delete mode 100644 zokrates_cli/tests/code/simple_add.expected.witness create mode 100644 zokrates_cli/tests/code/simple_add.expected.witness.json delete mode 100644 zokrates_cli/tests/code/simple_mul.expected.witness create mode 100644 zokrates_cli/tests/code/simple_mul.expected.witness.json delete mode 100644 zokrates_cli/tests/code/taxation.expected.witness create mode 100644 zokrates_cli/tests/code/taxation.expected.witness.json diff --git a/zokrates_ast/src/ir/witness.rs b/zokrates_ast/src/ir/witness.rs index b2d6e508d..8edec8549 100644 --- a/zokrates_ast/src/ir/witness.rs +++ b/zokrates_ast/src/ir/witness.rs @@ -80,6 +80,40 @@ impl Witness { serde_json::to_writer_pretty(writer, &map)?; Ok(()) } + + pub fn read_json(reader: R) -> io::Result { + let json: serde_json::Value = serde_json::from_reader(reader)?; + let object = json + .as_object() + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Witness must be an object"))?; + + let mut witness = Witness::empty(); + for (k, v) in object { + let variable = Variable::try_from_human_readable(k).map_err(|why| { + io::Error::new( + io::ErrorKind::Other, + format!("Invalid variable in witness: {}", why), + ) + })?; + + let value = v + .as_str() + .ok_or_else(|| { + io::Error::new(io::ErrorKind::Other, "Witness value must be a string") + }) + .and_then(|v| { + T::try_from_dec_str(v).map_err(|_| { + io::Error::new( + io::ErrorKind::Other, + format!("Invalid value in witness: {}", v), + ) + }) + })?; + + witness.insert(variable, value); + } + Ok(witness) + } } impl fmt::Display for Witness { diff --git a/zokrates_book/src/examples/sha256example.md b/zokrates_book/src/examples/sha256example.md index e970d7dc2..a933b8cb8 100644 --- a/zokrates_book/src/examples/sha256example.md +++ b/zokrates_book/src/examples/sha256example.md @@ -43,20 +43,14 @@ As a next step we can create a witness file using the following command: Using the flag `-a` we pass arguments to the program. Recall that our goal is to compute the hash for the number `5`. Consequently we set `a`, `b` and `c` to `0` and `d` to `5`. -Still here? Great! At this point, we can check the `witness` file for the return values: +Still here? Great! At this point we can check the return values. We should see the following output: ``` -{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:13}} +Witness: +["263561599766550617289250058199814760685","65303172752238645975888084098459749904"] ``` -which should lead to the following output: - -```sh -~out_0 263561599766550617289250058199814760685 -~out_1 65303172752238645975888084098459749904 -``` - -Hence, by concatenating the outputs as 128 bit numbers, we arrive at the following value as the hash for our selected pre-image : +By concatenating the outputs as 128 bit numbers, we arrive at the following value as the hash for our selected pre-image : `0xc6481e22c5ff4164af680b8cfaa5e8ed3120eeff89c4f307c4a6faaae059ce10` ## Prove knowledge of pre-image @@ -78,13 +72,13 @@ Note that we now compare the result of `sha256packed` with the hard-coded correc So, having defined the program, Victor is now ready to compile the code: ``` -{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:17}} +{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:15}} ``` Based on that Victor can run the setup phase and export a verifier smart contract as a Solidity file: ``` -{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:18:19}} +{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:16:17}} ``` `setup` creates a `verification.key` file and a `proving.key` file. Victor gives the proving key to Peggy. @@ -94,13 +88,13 @@ Based on that Victor can run the setup phase and export a verifier smart contrac Peggy provides the correct pre-image as an argument to the program. ``` -{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:20}} +{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:18}} ``` Finally, Peggy can run the command to construct the proof: ``` -{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:21}} +{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:19}} ``` As the inputs were declared as private in the program, they do not appear in the proof thanks to the zero-knowledge property of the protocol. diff --git a/zokrates_cli/examples/book/sha256_tutorial/test.sh b/zokrates_cli/examples/book/sha256_tutorial/test.sh index 664749ccf..07756593e 100755 --- a/zokrates_cli/examples/book/sha256_tutorial/test.sh +++ b/zokrates_cli/examples/book/sha256_tutorial/test.sh @@ -8,9 +8,7 @@ function zokrates() { } zokrates compile -i hashexample.zok -zokrates compute-witness -a 0 0 0 5 - -grep '~out' witness +zokrates compute-witness -a 0 0 0 5 --verbose cp -f hashexample_updated.zok hashexample.zok diff --git a/zokrates_cli/tests/code/arithmetics.expected.witness b/zokrates_cli/tests/code/arithmetics.expected.witness deleted file mode 100644 index caaf539db..000000000 --- a/zokrates_cli/tests/code/arithmetics.expected.witness +++ /dev/null @@ -1 +0,0 @@ -~out_0 12 diff --git a/zokrates_cli/tests/code/arithmetics.expected.witness.json b/zokrates_cli/tests/code/arithmetics.expected.witness.json new file mode 100644 index 000000000..4e1ae3537 --- /dev/null +++ b/zokrates_cli/tests/code/arithmetics.expected.witness.json @@ -0,0 +1,3 @@ +{ + "~out_0": "12" +} diff --git a/zokrates_cli/tests/code/conditional_false.expected.witness b/zokrates_cli/tests/code/conditional_false.expected.witness deleted file mode 100644 index 1b8f13fa2..000000000 --- a/zokrates_cli/tests/code/conditional_false.expected.witness +++ /dev/null @@ -1 +0,0 @@ -~out_0 0 \ No newline at end of file diff --git a/zokrates_cli/tests/code/conditional_false.expected.witness.json b/zokrates_cli/tests/code/conditional_false.expected.witness.json new file mode 100644 index 000000000..50655725f --- /dev/null +++ b/zokrates_cli/tests/code/conditional_false.expected.witness.json @@ -0,0 +1,3 @@ +{ + "~out_0": "0" +} diff --git a/zokrates_cli/tests/code/conditional_true.expected.witness b/zokrates_cli/tests/code/conditional_true.expected.witness deleted file mode 100644 index 1e61044c7..000000000 --- a/zokrates_cli/tests/code/conditional_true.expected.witness +++ /dev/null @@ -1 +0,0 @@ -~out_0 1 \ No newline at end of file diff --git a/zokrates_cli/tests/code/conditional_true.expected.witness.json b/zokrates_cli/tests/code/conditional_true.expected.witness.json new file mode 100644 index 000000000..cd003d107 --- /dev/null +++ b/zokrates_cli/tests/code/conditional_true.expected.witness.json @@ -0,0 +1,3 @@ +{ + "~out_0": "1" +} diff --git a/zokrates_cli/tests/code/multidim_update.expected.witness b/zokrates_cli/tests/code/multidim_update.expected.witness deleted file mode 100644 index 0b054f2fe..000000000 --- a/zokrates_cli/tests/code/multidim_update.expected.witness +++ /dev/null @@ -1,4 +0,0 @@ -~out_0 0 -~out_1 0 -~out_2 0 -~out_3 42 \ No newline at end of file diff --git a/zokrates_cli/tests/code/multidim_update.expected.witness.json b/zokrates_cli/tests/code/multidim_update.expected.witness.json new file mode 100644 index 000000000..98c0fbd76 --- /dev/null +++ b/zokrates_cli/tests/code/multidim_update.expected.witness.json @@ -0,0 +1,6 @@ +{ + "~out_0": "0", + "~out_1": "0", + "~out_2": "0", + "~out_3": "42" +} diff --git a/zokrates_cli/tests/code/n_choose_k.expected.witness b/zokrates_cli/tests/code/n_choose_k.expected.witness deleted file mode 100644 index b51f241b4..000000000 --- a/zokrates_cli/tests/code/n_choose_k.expected.witness +++ /dev/null @@ -1 +0,0 @@ -~out_0 5 \ No newline at end of file diff --git a/zokrates_cli/tests/code/n_choose_k.expected.witness.json b/zokrates_cli/tests/code/n_choose_k.expected.witness.json new file mode 100644 index 000000000..839b53c97 --- /dev/null +++ b/zokrates_cli/tests/code/n_choose_k.expected.witness.json @@ -0,0 +1,3 @@ +{ + "~out_0": "5" +} diff --git a/zokrates_cli/tests/code/no_return.expected.witness b/zokrates_cli/tests/code/no_return.expected.witness deleted file mode 100644 index e69de29bb..000000000 diff --git a/zokrates_cli/tests/code/no_return.expected.witness.json b/zokrates_cli/tests/code/no_return.expected.witness.json new file mode 100644 index 000000000..0967ef424 --- /dev/null +++ b/zokrates_cli/tests/code/no_return.expected.witness.json @@ -0,0 +1 @@ +{} diff --git a/zokrates_cli/tests/code/return_array.expected.witness b/zokrates_cli/tests/code/return_array.expected.witness deleted file mode 100644 index f9f9cf798..000000000 --- a/zokrates_cli/tests/code/return_array.expected.witness +++ /dev/null @@ -1,8 +0,0 @@ -~out_0 2 -~out_1 1 -~out_2 1 -~out_3 1 -~out_4 3 -~out_5 3 -~out_6 3 -~out_7 3 \ No newline at end of file diff --git a/zokrates_cli/tests/code/return_array.expected.witness.json b/zokrates_cli/tests/code/return_array.expected.witness.json new file mode 100644 index 000000000..d1b0a1d25 --- /dev/null +++ b/zokrates_cli/tests/code/return_array.expected.witness.json @@ -0,0 +1,10 @@ +{ + "~out_0": "2", + "~out_1": "1", + "~out_2": "1", + "~out_3": "1", + "~out_4": "3", + "~out_5": "3", + "~out_6": "3", + "~out_7": "3" +} diff --git a/zokrates_cli/tests/code/simple_add.expected.witness b/zokrates_cli/tests/code/simple_add.expected.witness deleted file mode 100644 index 23b7f950e..000000000 --- a/zokrates_cli/tests/code/simple_add.expected.witness +++ /dev/null @@ -1 +0,0 @@ -~out_0 3 diff --git a/zokrates_cli/tests/code/simple_add.expected.witness.json b/zokrates_cli/tests/code/simple_add.expected.witness.json new file mode 100644 index 000000000..ca9fd972e --- /dev/null +++ b/zokrates_cli/tests/code/simple_add.expected.witness.json @@ -0,0 +1,3 @@ +{ + "~out_0": "3" +} diff --git a/zokrates_cli/tests/code/simple_mul.expected.witness b/zokrates_cli/tests/code/simple_mul.expected.witness deleted file mode 100644 index 8eb3a8d7a..000000000 --- a/zokrates_cli/tests/code/simple_mul.expected.witness +++ /dev/null @@ -1 +0,0 @@ -~out_0 24 \ No newline at end of file diff --git a/zokrates_cli/tests/code/simple_mul.expected.witness.json b/zokrates_cli/tests/code/simple_mul.expected.witness.json new file mode 100644 index 000000000..7143416e9 --- /dev/null +++ b/zokrates_cli/tests/code/simple_mul.expected.witness.json @@ -0,0 +1,3 @@ +{ + "~out_0": "24" +} diff --git a/zokrates_cli/tests/code/taxation.expected.witness b/zokrates_cli/tests/code/taxation.expected.witness deleted file mode 100644 index 1b8f13fa2..000000000 --- a/zokrates_cli/tests/code/taxation.expected.witness +++ /dev/null @@ -1 +0,0 @@ -~out_0 0 \ No newline at end of file diff --git a/zokrates_cli/tests/code/taxation.expected.witness.json b/zokrates_cli/tests/code/taxation.expected.witness.json new file mode 100644 index 000000000..50655725f --- /dev/null +++ b/zokrates_cli/tests/code/taxation.expected.witness.json @@ -0,0 +1,3 @@ +{ + "~out_0": "0" +} diff --git a/zokrates_cli/tests/integration.rs b/zokrates_cli/tests/integration.rs index 35aff1deb..b909082c2 100644 --- a/zokrates_cli/tests/integration.rs +++ b/zokrates_cli/tests/integration.rs @@ -16,11 +16,11 @@ mod integration { use std::fs; use std::fs::File; use std::io::{BufReader, Read, Write}; - use std::panic; use std::path::Path; use std::process::Command; use tempdir::TempDir; use zokrates_abi::{parse_strict, Encode}; + use zokrates_ast::ir::Witness; use zokrates_ast::typed::abi::Abi; use zokrates_field::Bn128Field; use zokrates_proof_systems::{ @@ -101,7 +101,9 @@ mod integration { let program_name = Path::new(Path::new(path.file_stem().unwrap()).file_stem().unwrap()); let prog = dir.join(program_name).with_extension("zok"); - let witness = dir.join(program_name).with_extension("expected.witness"); + let witness = dir + .join(program_name) + .with_extension("expected.witness.json"); let json_input = dir.join(program_name).with_extension("arguments.json"); test_compile_and_witness( @@ -250,33 +252,24 @@ mod integration { .unwrap(); // load the expected witness - let mut expected_witness_file = File::open(&expected_witness_path).unwrap(); - let mut expected_witness = String::new(); - expected_witness_file - .read_to_string(&mut expected_witness) - .unwrap(); + let expected_witness_file = File::open(&expected_witness_path).unwrap(); + let expected_witness = + Witness::::read_json(expected_witness_file).unwrap(); // load the actual witness - let mut witness_file = File::open(&witness_path).unwrap(); - let mut witness = String::new(); - witness_file.read_to_string(&mut witness).unwrap(); + let witness_file = File::open(&witness_path).unwrap(); + let witness = Witness::::read(witness_file).unwrap(); // load the actual inline witness - let mut inline_witness_file = File::open(&inline_witness_path).unwrap(); - let mut inline_witness = String::new(); - inline_witness_file - .read_to_string(&mut inline_witness) - .unwrap(); + let inline_witness_file = File::open(&inline_witness_path).unwrap(); + let inline_witness = + Witness::::read(inline_witness_file).unwrap(); assert_eq!(inline_witness, witness); - for line in expected_witness.as_str().split('\n') { - assert!( - witness.contains(line), - "Witness generation failed for {}\n\nLine \"{}\" not found in witness", - program_path.to_str().unwrap(), - line - ); + for (k, v) in expected_witness.0 { + let value = witness.0.get(&k).expect("should contain key"); + assert!(v.eq(value)); } let backends = map! { From 83835a72494034065b681469dcb3731bd86436d2 Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 4 Apr 2023 18:53:46 +0200 Subject: [PATCH 26/32] add changelog --- changelogs/unreleased/1289-dark64 | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelogs/unreleased/1289-dark64 diff --git a/changelogs/unreleased/1289-dark64 b/changelogs/unreleased/1289-dark64 new file mode 100644 index 000000000..950a2f588 --- /dev/null +++ b/changelogs/unreleased/1289-dark64 @@ -0,0 +1 @@ +Change witness format to binary, optimize backend integration code to improve proving time \ No newline at end of file From 405b24da81ad84b5c33a5aff8169b492b5da7a17 Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 4 Apr 2023 19:16:10 +0200 Subject: [PATCH 27/32] change help message --- zokrates_cli/src/ops/compute_witness.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zokrates_cli/src/ops/compute_witness.rs b/zokrates_cli/src/ops/compute_witness.rs index 3c8e42361..0a886523b 100644 --- a/zokrates_cli/src/ops/compute_witness.rs +++ b/zokrates_cli/src/ops/compute_witness.rs @@ -68,7 +68,7 @@ pub fn subcommand() -> App<'static, 'static> { .required(false) ).arg(Arg::with_name("json") .long("json") - .help("Write witness in a json format") + .help("Write witness in a json format for debugging purposes") .required(false) ) } From 2ba870a9756824562db9e539360ccffbb9091fc1 Mon Sep 17 00:00:00 2001 From: Ahmed Castro Date: Mon, 10 Apr 2023 20:09:01 -0600 Subject: [PATCH 28/32] Create 1287-Turupawn --- changelogs/unreleased/1287-Turupawn | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelogs/unreleased/1287-Turupawn diff --git a/changelogs/unreleased/1287-Turupawn b/changelogs/unreleased/1287-Turupawn new file mode 100644 index 000000000..7cb706c0d --- /dev/null +++ b/changelogs/unreleased/1287-Turupawn @@ -0,0 +1 @@ +Fixed precedence issue on Sudoku example. From 36d00668e4f77b90697b0ba32ab09260f443eff5 Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 11 Apr 2023 15:05:17 +0200 Subject: [PATCH 29/32] refactor --- zokrates_ast/src/common/variable.rs | 24 ----------- zokrates_ast/src/ir/witness.rs | 60 ++++++++++++-------------- zokrates_cli/tests/integration.rs | 65 ++++++++++++++++++++++++++++- 3 files changed, 89 insertions(+), 60 deletions(-) diff --git a/zokrates_ast/src/common/variable.rs b/zokrates_ast/src/common/variable.rs index 967d47a2c..ab460623c 100644 --- a/zokrates_ast/src/common/variable.rs +++ b/zokrates_ast/src/common/variable.rs @@ -47,30 +47,6 @@ impl Variable { id: isize::from_le_bytes(buf), }) } - - pub fn try_from_human_readable(s: &str) -> Result { - if s == "~one" { - return Ok(Variable::one()); - } - - let mut public = s.split("~out_"); - match public.nth(1) { - Some(v) => { - let v = v.parse().map_err(|_| s)?; - Ok(Variable::public(v)) - } - None => { - let mut private = s.split('_'); - match private.nth(1) { - Some(v) => { - let v = v.parse().map_err(|_| s)?; - Ok(Variable::new(v)) - } - None => Err(s), - } - } - } - } } impl fmt::Display for Variable { diff --git a/zokrates_ast/src/ir/witness.rs b/zokrates_ast/src/ir/witness.rs index 8edec8549..70da540c1 100644 --- a/zokrates_ast/src/ir/witness.rs +++ b/zokrates_ast/src/ir/witness.rs @@ -80,40 +80,6 @@ impl Witness { serde_json::to_writer_pretty(writer, &map)?; Ok(()) } - - pub fn read_json(reader: R) -> io::Result { - let json: serde_json::Value = serde_json::from_reader(reader)?; - let object = json - .as_object() - .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Witness must be an object"))?; - - let mut witness = Witness::empty(); - for (k, v) in object { - let variable = Variable::try_from_human_readable(k).map_err(|why| { - io::Error::new( - io::ErrorKind::Other, - format!("Invalid variable in witness: {}", why), - ) - })?; - - let value = v - .as_str() - .ok_or_else(|| { - io::Error::new(io::ErrorKind::Other, "Witness value must be a string") - }) - .and_then(|v| { - T::try_from_dec_str(v).map_err(|_| { - io::Error::new( - io::ErrorKind::Other, - format!("Invalid value in witness: {}", v), - ) - }) - })?; - - witness.insert(variable, value); - } - Ok(witness) - } } impl fmt::Display for Witness { @@ -160,5 +126,31 @@ mod tests { assert_eq!(w, r); } + + #[test] + fn serialize_json() { + let w = Witness( + vec![ + (Variable::new(42), Bn128Field::from(42)), + (Variable::public(8), Bn128Field::from(8)), + (Variable::one(), Bn128Field::from(1)), + ] + .into_iter() + .collect(), + ); + + let mut buf = Cursor::new(vec![]); + w.write_json(&mut buf).unwrap(); + + let output = String::from_utf8(buf.into_inner()).unwrap(); + assert_eq!( + output.as_str(), + r#"{ + "~out_8": "8", + "~one": "1", + "_42": "42" +}"# + ) + } } } diff --git a/zokrates_cli/tests/integration.rs b/zokrates_cli/tests/integration.rs index b909082c2..cd988b104 100644 --- a/zokrates_cli/tests/integration.rs +++ b/zokrates_cli/tests/integration.rs @@ -27,6 +27,67 @@ mod integration { to_token::ToToken, Marlin, Proof, SolidityCompatibleScheme, G16, GM17, }; + mod helpers { + use super::*; + use zokrates_ast::common::Variable; + use zokrates_field::Field; + + pub fn parse_variable(s: &str) -> Result { + if s == "~one" { + return Ok(Variable::one()); + } + + let mut public = s.split("~out_"); + match public.nth(1) { + Some(v) => { + let v = v.parse().map_err(|_| s)?; + Ok(Variable::public(v)) + } + None => { + let mut private = s.split('_'); + match private.nth(1) { + Some(v) => { + let v = v.parse().map_err(|_| s)?; + Ok(Variable::new(v)) + } + None => Err(s), + } + } + } + } + + pub fn parse_witness_json(reader: R) -> std::io::Result> { + use std::io::{Error, ErrorKind}; + + let json: serde_json::Value = serde_json::from_reader(reader)?; + let object = json + .as_object() + .ok_or_else(|| Error::new(ErrorKind::Other, "Witness must be an object"))?; + + let mut witness = Witness::empty(); + for (k, v) in object { + let variable = parse_variable(k).map_err(|why| { + Error::new( + ErrorKind::Other, + format!("Invalid variable in witness: {}", why), + ) + })?; + + let value = v + .as_str() + .ok_or_else(|| Error::new(ErrorKind::Other, "Witness value must be a string")) + .and_then(|v| { + T::try_from_dec_str(v).map_err(|_| { + Error::new(ErrorKind::Other, format!("Invalid value in witness: {}", v)) + }) + })?; + + witness.insert(variable, value); + } + Ok(witness) + } + } + macro_rules! map( { $($key:expr => $value:expr),+ } => { @@ -253,8 +314,8 @@ mod integration { // load the expected witness let expected_witness_file = File::open(&expected_witness_path).unwrap(); - let expected_witness = - Witness::::read_json(expected_witness_file).unwrap(); + let expected_witness: Witness = + helpers::parse_witness_json(expected_witness_file).unwrap(); // load the actual witness let witness_file = File::open(&witness_path).unwrap(); From 52e45a396c857c0dc3eaff3918fc0abc0846e44c Mon Sep 17 00:00:00 2001 From: Thibaut Schaeffer Date: Thu, 13 Apr 2023 15:33:37 +0200 Subject: [PATCH 30/32] Add sourcemaps (#1285) add sourcemaps --- .circleci/config.yml | 4 +- Cargo.lock | 799 ++++--- Cargo.toml | 6 +- changelogs/unreleased/1285-schaeff | 1 + zokrates_analysis/src/assembly_transformer.rs | 453 ++-- .../src/boolean_array_comparator.rs | 111 +- zokrates_analysis/src/branch_isolator.rs | 23 +- zokrates_analysis/src/condition_redefiner.rs | 143 +- .../src/constant_argument_checker.rs | 111 +- zokrates_analysis/src/constant_resolver.rs | 109 +- zokrates_analysis/src/dead_code.rs | 76 +- zokrates_analysis/src/expression_validator.rs | 77 +- zokrates_analysis/src/flat_propagation.rs | 100 +- .../src/flatten_complex_types.rs | 596 ++--- zokrates_analysis/src/log_ignorer.rs | 9 +- zokrates_analysis/src/out_of_bounds.rs | 4 +- zokrates_analysis/src/panic_extractor.rs | 239 +- zokrates_analysis/src/propagation.rs | 1561 +++++++------ .../src/reducer/constants_reader.rs | 194 +- .../src/reducer/constants_writer.rs | 33 +- zokrates_analysis/src/reducer/inline.rs | 10 +- zokrates_analysis/src/reducer/mod.rs | 261 +-- zokrates_analysis/src/reducer/shallow_ssa.rs | 201 +- zokrates_analysis/src/uint_optimizer.rs | 346 +-- .../src/variable_write_remover.rs | 251 ++- zokrates_analysis/src/zir_propagation.rs | 1616 +++++++------- zokrates_ark/src/gm17.rs | 14 +- zokrates_ark/src/groth16.rs | 14 +- zokrates_ark/src/lib.rs | 13 +- zokrates_ark/src/marlin.rs | 18 +- zokrates_ast/Cargo.toml | 2 +- zokrates_ast/src/common/embed.rs | 109 +- zokrates_ast/src/common/expressions.rs | 280 +++ zokrates_ast/src/common/flat/mod.rs | 5 + zokrates_ast/src/common/flat/parameter.rs | 63 + zokrates_ast/src/common/flat/variable.rs | 102 + zokrates_ast/src/common/fold.rs | 7 + zokrates_ast/src/common/mod.rs | 13 + zokrates_ast/src/common/operators.rs | 143 ++ zokrates_ast/src/common/parameter.rs | 62 +- zokrates_ast/src/common/position.rs | 242 +++ zokrates_ast/src/common/statements.rs | 278 +++ zokrates_ast/src/common/value.rs | 3 + zokrates_ast/src/common/variable.rs | 110 +- zokrates_ast/src/flat/folder.rs | 102 +- zokrates_ast/src/flat/mod.rs | 371 +++- zokrates_ast/src/flat/utils.rs | 25 +- zokrates_ast/src/ir/check.rs | 9 +- zokrates_ast/src/ir/clean.rs | 10 +- zokrates_ast/src/ir/expression.rs | 262 ++- zokrates_ast/src/ir/folder.rs | 144 +- zokrates_ast/src/ir/from_flat.rs | 132 +- zokrates_ast/src/ir/mod.rs | 217 +- zokrates_ast/src/ir/serialize.rs | 32 +- zokrates_ast/src/ir/smtlib2.rs | 19 +- zokrates_ast/src/ir/solver_indexer.rs | 24 +- zokrates_ast/src/ir/visitor.rs | 28 +- zokrates_ast/src/ir/witness.rs | 2 +- zokrates_ast/src/typed/abi.rs | 17 +- zokrates_ast/src/typed/folder.rs | 1075 ++++++--- zokrates_ast/src/typed/integer.rs | 367 ++-- zokrates_ast/src/typed/mod.rs | 1551 ++++++++++--- zokrates_ast/src/typed/parameter.rs | 30 +- zokrates_ast/src/typed/result_folder.rs | 1129 ++++++---- zokrates_ast/src/typed/types.rs | 27 +- zokrates_ast/src/typed/uint.rs | 87 +- zokrates_ast/src/typed/utils/mod.rs | 18 +- zokrates_ast/src/typed/variable.rs | 74 +- zokrates_ast/src/untyped/mod.rs | 10 +- zokrates_ast/src/untyped/node.rs | 22 +- zokrates_ast/src/untyped/variable.rs | 10 +- zokrates_ast/src/zir/canonicalizer.rs | 25 +- zokrates_ast/src/zir/folder.rs | 701 ++++-- zokrates_ast/src/zir/identifier.rs | 22 +- zokrates_ast/src/zir/lqc.rs | 103 +- zokrates_ast/src/zir/mod.rs | 937 ++++++-- zokrates_ast/src/zir/parameter.rs | 34 +- zokrates_ast/src/zir/result_folder.rs | 754 ++++--- zokrates_ast/src/zir/uint.rs | 72 +- zokrates_ast/src/zir/variable.rs | 30 +- zokrates_bellman/src/groth16.rs | 11 +- zokrates_bellman/src/lib.rs | 41 +- zokrates_circom/src/lib.rs | 11 +- zokrates_circom/src/r1cs.rs | 36 +- zokrates_cli/Cargo.toml | 1 + zokrates_cli/src/bin.rs | 5 +- zokrates_cli/src/ops/compile.rs | 14 +- zokrates_cli/src/ops/mod.rs | 1 + zokrates_cli/src/ops/profile.rs | 52 + zokrates_cli/tests/integration.rs | 2 +- zokrates_codegen/src/lib.rs | 1923 +++++++++-------- zokrates_codegen/src/utils.rs | 11 +- zokrates_core/src/compile.rs | 10 +- zokrates_core/src/imports.rs | 44 +- zokrates_core/src/optimizer/directive.rs | 42 +- zokrates_core/src/optimizer/duplicate.rs | 24 +- zokrates_core/src/optimizer/mod.rs | 1 + zokrates_core/src/optimizer/redefinition.rs | 88 +- zokrates_core/src/optimizer/tautology.rs | 22 +- zokrates_core/src/semantics.rs | 619 +++--- zokrates_core_test/tests/tests/uint/ch.json | 6 +- zokrates_field/src/bn128.rs | 24 +- zokrates_interpreter/src/lib.rs | 51 +- zokrates_js/Cargo.toml | 2 +- zokrates_parser/Cargo.toml | 2 +- zokrates_pest_ast/Cargo.toml | 2 +- zokrates_profiler/Cargo.toml | 9 + zokrates_profiler/src/lib.rs | 52 + zokrates_test/tests/wasm.rs | 7 +- 109 files changed, 12715 insertions(+), 7682 deletions(-) create mode 100644 changelogs/unreleased/1285-schaeff create mode 100644 zokrates_ast/src/common/expressions.rs create mode 100644 zokrates_ast/src/common/flat/mod.rs create mode 100644 zokrates_ast/src/common/flat/parameter.rs create mode 100644 zokrates_ast/src/common/flat/variable.rs create mode 100644 zokrates_ast/src/common/fold.rs create mode 100644 zokrates_ast/src/common/operators.rs create mode 100644 zokrates_ast/src/common/position.rs create mode 100644 zokrates_ast/src/common/statements.rs create mode 100644 zokrates_ast/src/common/value.rs create mode 100644 zokrates_cli/src/ops/profile.rs create mode 100644 zokrates_profiler/Cargo.toml create mode 100644 zokrates_profiler/src/lib.rs diff --git a/.circleci/config.yml b/.circleci/config.yml index c4d216cac..7e6cd0a25 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -12,7 +12,7 @@ jobs: build: docker: - image: zokrates/env:latest - resource_class: large + resource_class: xlarge steps: - checkout - run: @@ -68,7 +68,7 @@ jobs: docker: - image: zokrates/env:latest - image: trufflesuite/ganache-cli:next - resource_class: large + resource_class: xlarge steps: - checkout - run: diff --git a/Cargo.lock b/Cargo.lock index b79522728..5073a553b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "addr2line" -version = "0.17.0" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9ecd88a8c8378ca913a680cd98f0f13ac67383d35993f86c90a70e3f137816b" +checksum = "a76fd60b23679b7d19bd066031410fb7e458ccc5e958eb5c325888ce4baedc97" dependencies = [ "gimli", ] @@ -39,9 +39,9 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "0.7.18" +version = "0.7.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f" +checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" dependencies = [ "memchr", ] @@ -171,8 +171,8 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "db02d390bf6643fb404d3d22d31aee1c4bc4459600aef9113833d17e786c6e44" dependencies = [ - "quote 1.0.20", - "syn 1.0.107", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -183,8 +183,8 @@ checksum = "db2fd794a08ccb318058009eefdf15bcaaaaf6f8161eb3345f907222bac38b20" dependencies = [ "num-bigint 0.4.3", "num-traits 0.2.15", - "quote 1.0.20", - "syn 1.0.107", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -334,9 +334,9 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8dd4e5f0bf8285d5ed538d27fab7411f3e297908fd93c62195de8bee3f199e82" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -409,9 +409,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "backtrace" -version = "0.3.65" +version = "0.3.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11a17d453482a265fd5f8479f2a3f405566e6ca627837aaddb85af8b1ab8ef61" +checksum = "233d376d6d185f2a3093e58f283f60f880315b6c60075b01f36b3b85154564ca" dependencies = [ "addr2line", "cc", @@ -464,9 +464,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitvec" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1489fcb93a5bb47da0462ca93ad252ad6af2145cce58d10d46a83931ba9f016b" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" dependencies = [ "funty", "radium", @@ -537,16 +537,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" dependencies = [ "block-padding 0.2.1", - "generic-array 0.14.6", + "generic-array 0.14.7", ] [[package]] name = "block-buffer" -version = "0.10.2" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf7fe51849ea569fd452f37822f606a5cabb684dc918707a0193fd4664ff324" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" dependencies = [ - "generic-array 0.14.6", + "generic-array 0.14.7", ] [[package]] @@ -564,29 +564,17 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d696c370c750c948ada61c69a0ee2cbbb9c50b1019ddb86d9317157a99c2cae" -[[package]] -name = "bstr" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba3569f383e8f1598449f1a423e72e99569137b47740b1da11ef19af3d5c3223" -dependencies = [ - "lazy_static", - "memchr", - "regex-automata", - "serde", -] - [[package]] name = "bumpalo" -version = "3.10.0" +version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37ccbd214614c6783386c1af30caf03192f17891059cecc394b4fb119e363de3" +checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" [[package]] name = "byte-slice-cast" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87c5fdd0166095e1d463fc6cc01aa8ce547ad77a4e84d42eb6762b084e28067e" +checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c" [[package]] name = "byte-tools" @@ -608,15 +596,15 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfb24e866b15a1af2a1b663f10c6b6b8f397a84aadb828f12e5b289ec23a3a3c" +checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" [[package]] name = "camino" -version = "1.0.9" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "869119e97797867fd90f5e22af7d0bd274bd4635ebb9eb68c04f3f513ae6c412" +checksum = "c530edf18f37068ac2d977409ed5cd50d53d73bc653c7647b48eb78976ac9ae2" dependencies = [ "serde", ] @@ -638,16 +626,16 @@ checksum = "4acbb09d9ee8e23699b9634375c72795d095bf268439da88562cf9b501f181fa" dependencies = [ "camino", "cargo-platform", - "semver 1.0.16", + "semver 1.0.17", "serde", "serde_json", ] [[package]] name = "cc" -version = "1.0.73" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11" +checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" [[package]] name = "cfg-if" @@ -702,9 +690,9 @@ checksum = "c6dd675567eb3e35787bd2583d129e85fabc7503b0a093d08c51198a307e2091" dependencies = [ "heck", "proc-macro-error", - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -736,9 +724,9 @@ checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" [[package]] name = "cpufeatures" -version = "0.2.2" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59a6001667ab124aebae2a495118e11d30984c3a653e99d86d58971708cf5e4b" +checksum = "280a9f2d8b3a38871a3c8a46fb80db65e5e5ed97da80c4d08bf27fb63e35e181" dependencies = [ "libc", ] @@ -769,12 +757,12 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.6" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2dd04ddaf88237dc3b8d8f9a3c1004b506b54b3313403944054d23c0870c521" +checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" dependencies = [ "cfg-if 1.0.0", - "crossbeam-utils 0.8.14", + "crossbeam-utils 0.8.15", ] [[package]] @@ -790,13 +778,13 @@ dependencies = [ [[package]] name = "crossbeam-deque" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "715e8152b692bba2d374b53d4875445368fdf21a94751410af607a5ac677d1fc" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" dependencies = [ "cfg-if 1.0.0", - "crossbeam-epoch 0.9.13", - "crossbeam-utils 0.8.14", + "crossbeam-epoch 0.9.14", + "crossbeam-utils 0.8.15", ] [[package]] @@ -816,14 +804,14 @@ dependencies = [ [[package]] name = "crossbeam-epoch" -version = "0.9.13" +version = "0.9.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01a9af1f4c2ef74bb8aa1f7e19706bc72d03598c8a570bb5de72243c7a9d9d5a" +checksum = "46bd5f3f85273295a9d14aedfb86f6aadbff6d8f5295c4a9edb08e819dcf5695" dependencies = [ "autocfg", "cfg-if 1.0.0", - "crossbeam-utils 0.8.14", - "memoffset 0.7.1", + "crossbeam-utils 0.8.15", + "memoffset 0.8.0", "scopeguard", ] @@ -851,9 +839,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.14" +version = "0.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fb766fa798726286dbbb842f174001dab8abc7b627a1dd86e0b7222a95d929f" +checksum = "3c063cd8cc95f5c377ed0d4b49a4b21f632396ff690e8470c29b3359b346984b" dependencies = [ "cfg-if 1.0.0", ] @@ -866,11 +854,11 @@ checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" [[package]] name = "crypto-common" -version = "0.1.4" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5999502d32b9c48d492abe66392408144895020ec4709e549e840799f3bb74c0" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ - "generic-array 0.14.6", + "generic-array 0.14.7", "typenum", ] @@ -890,19 +878,18 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b584a330336237c1eecd3e94266efb216c56ed91225d634cb2991c5f3fd1aeab" dependencies = [ - "generic-array 0.14.6", - "subtle 2.4.1", + "generic-array 0.14.7", + "subtle 2.5.0", ] [[package]] name = "csv" -version = "1.1.6" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22813a6dc45b335f9bade10bf7271dc477e81113e89eb251a0bc2a8a81c536e1" +checksum = "0b015497079b9a9d69c02ad25de6c0a6edef051ea6360a327d0bd05802ef64ad" dependencies = [ - "bstr", "csv-core", - "itoa 0.4.8", + "itoa", "ryu", "serde", ] @@ -918,12 +905,12 @@ dependencies = [ [[package]] name = "ctor" -version = "0.1.22" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f877be4f7c9f246b183111634f75baa039715e3f46ce860677d3b19a69fb229c" +checksum = "6d2301688392eb071b0bf1a37be05c469d3cc4dbbd95df672fe28ab021e6a096" dependencies = [ - "quote 1.0.20", - "syn 1.0.107", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -932,9 +919,9 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -970,7 +957,7 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" dependencies = [ - "generic-array 0.14.6", + "generic-array 0.14.7", ] [[package]] @@ -979,7 +966,7 @@ version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" dependencies = [ - "block-buffer 0.10.2", + "block-buffer 0.10.4", "crypto-common", ] @@ -1005,20 +992,20 @@ dependencies = [ [[package]] name = "either" -version = "1.7.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f107b87b6afc2a64fd13cac55fe06d6c8859f12d4b14cbcdd2c67d0976781be" +checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" [[package]] name = "env_logger" -version = "0.9.0" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b2cf0344971ee6c64c31be0d530793fba457d322dfec2810c453d0ef228f9c3" +checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7" dependencies = [ "atty", "humantime", "log", - "regex 1.7.1", + "regex 1.7.3", "termcolor", ] @@ -1028,6 +1015,27 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f4b14e20978669064c33b4c1e0fb4083412e40fe56cbea2eae80fd7591503ee" +[[package]] +name = "errno" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" +dependencies = [ + "errno-dragonfly", + "libc", + "windows-sys 0.48.0", +] + +[[package]] +name = "errno-dragonfly" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +dependencies = [ + "cc", + "libc", +] + [[package]] name = "error-chain" version = "0.11.0" @@ -1048,17 +1056,17 @@ dependencies = [ [[package]] name = "ethabi" -version = "17.1.0" +version = "17.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f186de076b3e77b8e6d73c99d1b52edc2a229e604f4b5eb6992c06c11d79d537" +checksum = "e4966fba78396ff92db3b817ee71143eccd98acf0f876b8d600e585a670c5d1b" dependencies = [ "ethereum-types", "hex 0.4.3", "once_cell", - "regex 1.7.1", + "regex 1.7.3", "serde", "serde_json", - "sha3 0.10.1", + "sha3 0.10.7", "thiserror", "uint", ] @@ -1112,9 +1120,9 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aa4da3c766cd7a0db8242e326e9e4e081edd567072893ed320008189715366a4" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", "synstructure", ] @@ -1126,9 +1134,9 @@ checksum = "e88a8acf291dafb59c2d96e8f59828f3838bb1a70398823ade51a84de6a6deed" [[package]] name = "fastrand" -version = "1.7.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3fcf0cee53519c866c09b5de1f6c56ff9d647101f81c1964fa632e148896cdf" +checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" dependencies = [ "instant", ] @@ -1154,9 +1162,9 @@ dependencies = [ "num-bigint 0.2.6", "num-integer", "num-traits 0.2.15", - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -1183,9 +1191,9 @@ dependencies = [ [[package]] name = "fs_extra" -version = "1.2.0" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2022715d62ab30faffd124d40b76f4134a550a87792276512b18d63272333394" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "fuchsia-cprng" @@ -1201,9 +1209,9 @@ checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" [[package]] name = "futures" -version = "0.3.21" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f73fe65f54d1e12b726f517d3e2135ca3125a437b6d998caf1962961f7172d9e" +checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" dependencies = [ "futures-channel", "futures-core", @@ -1216,9 +1224,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.21" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3083ce4b914124575708913bca19bfe887522d6e2e6d0952943f5eac4a74010" +checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" dependencies = [ "futures-core", "futures-sink", @@ -1226,15 +1234,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.21" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c09fd04b7e4073ac7156a9539b57a484a8ea920f79c7c675d05d289ab6110d3" +checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" [[package]] name = "futures-executor" -version = "0.3.21" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9420b90cfa29e327d0429f19be13e7ddb68fa1cccb09d65e5706b8c7a749b8a6" +checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" dependencies = [ "futures-core", "futures-task", @@ -1244,27 +1252,27 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.21" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc4045962a5a5e935ee2fdedaa4e08284547402885ab326734432bed5d12966b" +checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" [[package]] name = "futures-sink" -version = "0.3.21" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21163e139fa306126e6eedaf49ecdb4588f939600f0b1e770f4205ee4b7fa868" +checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" [[package]] name = "futures-task" -version = "0.3.21" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57c66a976bf5909d801bbef33416c41372779507e7a6b3a5e25e4749c58f776a" +checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" [[package]] name = "futures-util" -version = "0.3.21" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8b7abd5d659d9b90c8cba917f6ec750a74e2dc23902ef9cd4cc8c8b22e6036a" +checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" dependencies = [ "futures-channel", "futures-core", @@ -1288,9 +1296,9 @@ dependencies = [ [[package]] name = "generic-array" -version = "0.14.6" +version = "0.14.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bff49e947297f3312447abdca79f45f4738097cc82b06e72054d2223f601f1b9" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", @@ -1298,9 +1306,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" +checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" dependencies = [ "cfg-if 1.0.0", "js-sys", @@ -1311,9 +1319,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.26.1" +version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78cc372d058dcf6d5ecd98510e7fbc9e5aec4d21de70f65fea8fecebcd881bd4" +checksum = "ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4" [[package]] name = "glob" @@ -1323,9 +1331,9 @@ checksum = "8be18de09a56b60ed0edf84bc9df007e30040691af7acd1c41874faac5895bfb" [[package]] name = "glob" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "half" @@ -1375,6 +1383,12 @@ dependencies = [ "libc", ] +[[package]] +name = "hermit-abi" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" + [[package]] name = "hex" version = "0.3.2" @@ -1445,9 +1459,9 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "11d7a9f6330b71fea57921c9b61c47ee6e84f72d394754eff6163ae67e7395eb" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -1469,6 +1483,17 @@ dependencies = [ "cfg-if 1.0.0", ] +[[package]] +name = "io-lifetimes" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220" +dependencies = [ + "hermit-abi 0.3.1", + "libc", + "windows-sys 0.48.0", +] + [[package]] name = "itertools" version = "0.7.11" @@ -1489,15 +1514,9 @@ dependencies = [ [[package]] name = "itoa" -version = "0.4.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4" - -[[package]] -name = "itoa" -version = "1.0.2" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "112c678d4050afce233f4f2852bb2eb519230b3cf12f33585275537d7e41578d" +checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" [[package]] name = "js-sys" @@ -1516,9 +1535,12 @@ checksum = "078e285eafdfb6c4b434e0d31e8cfcb5115b651496faca5749b88fafd4f23bfd" [[package]] name = "keccak" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9b7d56ba4a8344d6be9729995e6b06f928af29998cdf79fe390cbf6b1fee838" +checksum = "3afef3b6eff9ce9d8ff9b3601125eec7f0c8cbac7abd14f355d053fa56c98768" +dependencies = [ + "cpufeatures", +] [[package]] name = "lazy_static" @@ -1528,9 +1550,15 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.126" +version = "0.2.141" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5" + +[[package]] +name = "linux-raw-sys" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "349d5a591cd28b49e1d1037471617a32ddcda5731b99419008085f72d5a53836" +checksum = "d59d8c75012853d2e872fb56bc8a2e53718e2cafe1a4c823143141c6d90c322f" [[package]] name = "log" @@ -1570,18 +1598,18 @@ dependencies = [ [[package]] name = "memoffset" -version = "0.7.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4" +checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" dependencies = [ "autocfg", ] [[package]] name = "miniz_oxide" -version = "0.5.3" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f5c75688da582b8ffc1f1799e9db273f32133c49e048f614d22ec3256773ccc" +checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa" dependencies = [ "adler", ] @@ -1677,18 +1705,18 @@ dependencies = [ [[package]] name = "object" -version = "0.28.4" +version = "0.30.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e42c982f2d955fac81dd7e1d0e1426a7d702acd9c98d19ab01083a6a0328c424" +checksum = "ea86265d3d3dcb6a27fc51bd29a4bf387fae9d2986b823079d4986af253eb439" dependencies = [ "memchr", ] [[package]] name = "once_cell" -version = "1.17.0" +version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66" +checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" [[package]] name = "opaque-debug" @@ -1724,9 +1752,9 @@ dependencies = [ [[package]] name = "parity-scale-codec" -version = "3.1.5" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9182e4a71cae089267ab03e67c99368db7cd877baf50f931e5d6d4b71e195ac0" +checksum = "637935964ff85a605d114591d4d2c13c5d1ba2806dae97cea6bf180238a749ac" dependencies = [ "arrayvec 0.7.2", "bitvec", @@ -1738,21 +1766,21 @@ dependencies = [ [[package]] name = "parity-scale-codec-derive" -version = "3.1.3" +version = "3.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9299338969a3d2f491d65f140b00ddec470858402f888af98e8642fb5e8965cd" +checksum = "86b26a931f824dd4eca30b3e43bb4f31cd5f0d3a403c5f5ff27106b805bfde7b" dependencies = [ "proc-macro-crate", - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] name = "paste" -version = "1.0.7" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c520e05135d6e763148b6426a837e239041653ba7becd2e538c076c738025fc" +checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" [[package]] name = "pest" @@ -1794,9 +1822,9 @@ checksum = "99b8db626e31e5b81787b9783425769681b347011cc59471e33ea46d2ea0cf55" dependencies = [ "pest", "pest_meta", - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -1840,9 +1868,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "ppv-lite86" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "pretty_assertions" @@ -1858,14 +1886,14 @@ dependencies = [ [[package]] name = "pretty_assertions" -version = "1.2.1" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c89f989ac94207d048d92db058e4f6ec7342b0971fc58d1271ca148b799b3563" +checksum = "a25e9bcb20aa780fd0bb16b72403a9064d6b3f22f026946029acb941a50af755" dependencies = [ - "ansi_term 0.12.1", "ctor", "diff", "output_vt100", + "yansi", ] [[package]] @@ -1883,10 +1911,11 @@ dependencies = [ [[package]] name = "proc-macro-crate" -version = "1.1.3" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e17d47ce914bf4de440332250b0edd23ce48c005f59fab39d3335866b114f11a" +checksum = "eda0fc3b0fb7c975631757e14d9049da17374063edb6ebbcbc54d880d4fe94e9" dependencies = [ + "once_cell", "thiserror", "toml", ] @@ -1898,9 +1927,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "18f33027081eba0a6d8aba6d1b1c3a3be58cbb12106341c2d5759fcd9b5277e7" dependencies = [ "proc-macro-error-attr", - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", "version_check", ] @@ -1910,18 +1939,18 @@ version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a5b4b77fdb63c1eca72173d68d24501c54ab1269409f6b672c85deb18af69de" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", "syn-mid", "version_check", ] [[package]] name = "proc-macro-hack" -version = "0.5.19" +version = "0.5.20+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" +checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068" [[package]] name = "proc-macro2" @@ -1934,18 +1963,18 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.50" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ef7d57beacfaf2d8aee5937dab7b7f28de3cb8b1828479bb5de2a7106f2bae2" +checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435" dependencies = [ "unicode-ident", ] [[package]] name = "pulldown-cmark" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34f197a544b0c9ab3ae46c359a7ec9cbbb5c7bf97054266fecb7ead794a181d6" +checksum = "2d9cc634bc78768157b5cbfe988ffcd1dcba95cd2b2f03a88316c08c6d00ed63" dependencies = [ "bitflags", "memchr", @@ -1963,11 +1992,11 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.20" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bcdf212e9776fbcb2d23ab029360416bb1706b1aea2d1a5ba002727cbcab804" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" dependencies = [ - "proc-macro2 1.0.50", + "proc-macro2 1.0.56", ] [[package]] @@ -1997,7 +2026,7 @@ checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", "rand_chacha", - "rand_core 0.6.3", + "rand_core 0.6.4", ] [[package]] @@ -2007,7 +2036,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core 0.6.3", + "rand_core 0.6.4", ] [[package]] @@ -2027,9 +2056,9 @@ checksum = "9c33a3c44ca05fa6f1807d8e6743f3824e8509beca625669633be0acbdf509dc" [[package]] name = "rand_core" -version = "0.6.3" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d34f1408f55294453790c48b2f1ebbb1c5b4b7563eb1f418bcfcfdbb06ebb4e7" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ "getrandom", ] @@ -2050,9 +2079,9 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" dependencies = [ - "crossbeam-channel 0.5.6", - "crossbeam-deque 0.8.2", - "crossbeam-utils 0.8.14", + "crossbeam-channel 0.5.8", + "crossbeam-deque 0.8.3", + "crossbeam-utils 0.8.15", "num_cpus", ] @@ -2067,9 +2096,18 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.2.13" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +dependencies = [ + "bitflags", +] + +[[package]] +name = "redox_syscall" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f25bc4c7e55e0b0b7a1d43fb893f4fa1361d0abe38b9ce4f323c2adfe6ef42" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" dependencies = [ "bitflags", ] @@ -2081,15 +2119,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" dependencies = [ "getrandom", - "redox_syscall", + "redox_syscall 0.2.16", "thiserror", ] [[package]] name = "reduce" -version = "0.1.4" +version = "0.1.5+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16d2dc47b68ac15ea328cd7ebe01d7d512ed29787f7d534ad2a3c341328b35d7" +checksum = "feff7c275fbc4a96ccdc240a5a180487a61a31baffaff6cdd4fb2c8e9e0a2ecd" [[package]] name = "regex" @@ -2106,21 +2144,15 @@ dependencies = [ [[package]] name = "regex" -version = "1.7.1" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" +checksum = "8b1f693b24f6ac912f4893ef08244d70b6067480d2f1a46e950c9691e6749d1d" dependencies = [ - "aho-corasick 0.7.18", + "aho-corasick 0.7.20", "memchr", - "regex-syntax 0.6.27", + "regex-syntax 0.6.29", ] -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" - [[package]] name = "regex-syntax" version = "0.5.6" @@ -2132,9 +2164,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.27" +version = "0.6.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3f87b73ce11b1619a3c6332f45341e0047173771e8b8b73f87bfeefb7b56244" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "remove_dir_all" @@ -2157,9 +2189,9 @@ dependencies = [ [[package]] name = "rustc-demangle" -version = "0.1.21" +version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342" +checksum = "d4a36c42d1873f9a77c53bde094f9664d9891bc604a45b4798fd2c389ed12e5b" [[package]] name = "rustc-hex" @@ -2176,11 +2208,25 @@ dependencies = [ "semver 0.11.0", ] +[[package]] +name = "rustix" +version = "0.37.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85597d61f83914ddeba6a47b3b8ffe7365107221c2e557ed94426489fefb5f77" +dependencies = [ + "bitflags", + "errno", + "io-lifetimes", + "libc", + "linux-raw-sys", + "windows-sys 0.48.0", +] + [[package]] name = "ryu" -version = "1.0.10" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3f6f92acf49d1b98f7a81226834412ada05458b7364277387724a237f062695" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" [[package]] name = "same-file" @@ -2210,9 +2256,9 @@ dependencies = [ [[package]] name = "scoped-tls" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea6a9290e3c9cf0f18145ef7ffa62d68ee0bf5fcd651017e586dc7fd5da448c2" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" [[package]] name = "scopeguard" @@ -2231,9 +2277,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58bc9567378fc7690d6b2addae4e60ac2eeea07becb2c64b9f218b53865cba2a" +checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" dependencies = [ "serde", ] @@ -2249,9 +2295,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.138" +version = "1.0.160" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1578c6245786b9d168c5447eeacfb96856573ca56c9d68fdcf394be134882a47" +checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c" dependencies = [ "serde_derive", ] @@ -2268,23 +2314,23 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.138" +version = "1.0.160" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "023e9b1467aef8a10fb88f25611870ada9800ef7e22afce356bb0d2387b6f27c" +checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 2.0.14", ] [[package]] name = "serde_json" -version = "1.0.82" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82c2c1fdcd807d1098552c5b9a36e425e42e9fbd7c6a37a8425f390f781f7fa7" +checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744" dependencies = [ "indexmap", - "itoa 1.0.2", + "itoa", "ryu", "serde", ] @@ -2338,9 +2384,9 @@ dependencies = [ [[package]] name = "sha3" -version = "0.10.1" +version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "881bf8156c87b6301fc5ca6b27f11eeb2761224c7081e69b409d5a1951a70c86" +checksum = "54c2bb1a323307527314a36bfb73f24febb08ce2b8a554bf4ffd6f51ad15198c" dependencies = [ "digest 0.10.6", "keccak", @@ -2348,9 +2394,9 @@ dependencies = [ [[package]] name = "single" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd5add732a1ab689845591a1b50339cf5310b563e08dc5813c65991f30369ea2" +checksum = "9db45bb685b486eec37e0271dcc0dac76eae5e893125f8a4f0511d0a1d29543b" dependencies = [ "failure", ] @@ -2364,7 +2410,7 @@ dependencies = [ "bytecount", "cargo_metadata", "error-chain 0.12.4", - "glob 0.3.0", + "glob 0.3.1", "pulldown-cmark", "tempfile", "walkdir", @@ -2372,9 +2418,12 @@ dependencies = [ [[package]] name = "slab" -version = "0.4.6" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb703cfe953bccee95685111adeedb76fabe4e97549a58d16f03ea7b9367bb32" +checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" +dependencies = [ + "autocfg", +] [[package]] name = "static_assertions" @@ -2402,9 +2451,9 @@ checksum = "2d67a5a62ba6e01cb2192ff309324cb4875d0c451d55fe2319433abe7a05a8ee" [[package]] name = "subtle" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" +checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" [[package]] name = "syn" @@ -2419,12 +2468,23 @@ dependencies = [ [[package]] name = "syn" -version = "1.0.107" +version = "1.0.109" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", + "proc-macro2 1.0.56", + "quote 1.0.26", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcf316d5356ed6847742d036f8a39c3b8435cac10bd528a4bd461928a6ab34d5" +dependencies = [ + "proc-macro2 1.0.56", + "quote 1.0.26", "unicode-ident", ] @@ -2434,9 +2494,9 @@ version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baa8e7560a164edb1621a55d18a0c59abf49d360f47aa7b821061dd7eea7fac9" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -2445,9 +2505,9 @@ version = "0.12.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", "unicode-xid 0.2.4", ] @@ -2469,23 +2529,22 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.3.0" +version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4" +checksum = "b9fbec84f381d5795b08656e4912bec604d162bff9291d6189a78f4c8ab87998" dependencies = [ "cfg-if 1.0.0", "fastrand", - "libc", - "redox_syscall", - "remove_dir_all", - "winapi", + "redox_syscall 0.3.5", + "rustix", + "windows-sys 0.45.0", ] [[package]] name = "termcolor" -version = "1.1.3" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bab24d30b911b2376f3a13cc2cd443142f0c81dda04c118693e35b3835757755" +checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" dependencies = [ "winapi-util", ] @@ -2501,22 +2560,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.38" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a9cd18aa97d5c45c6603caea1da6628790b37f7a34b6ca89522331c5180fed0" +checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.38" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fb327af4685e4d03fa8cbcf1716380da910eeb2bb8be417e7f9fd3fb164f36f" +checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 2.0.14", ] [[package]] @@ -2539,9 +2598,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.5.9" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82e1a7758622a465f8cee077614c73484dac5b836c02ff6a40d5d1010324d7" +checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" dependencies = [ "serde", ] @@ -2564,9 +2623,9 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -2596,27 +2655,27 @@ checksum = "a9b2228007eba4120145f785df0f6c92ea538f5a3635a612ecf4e334c8c1446d" [[package]] name = "typenum" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" +checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" [[package]] name = "ucd-trie" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89570599c4fe5585de2b388aab47e99f7fa4e9238a1399f707a02e356058141c" +checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81" [[package]] name = "ucd-util" -version = "0.1.9" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65bfcbf611b122f2c10eb1bb6172fbc4c2e25df9970330e4d75ce2b5201c9bfc" +checksum = "abd2fc5d32b590614af8b0a20d837f32eca055edd0bbead59a9cfe80858be003" [[package]] name = "uint" -version = "0.9.3" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12f03af7ccf01dd611cc450a0d10dbc9b745770d096473e2faf0ca6e2d66d1e0" +checksum = "76f64bba2c53b04fcab63c01a7d7427eadc821e3bc48c34dc9ba29c501164b52" dependencies = [ "byteorder", "crunchy", @@ -2635,21 +2694,21 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.1" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bd2fe26506023ed7b5e1e315add59d6f584c621d037f9368fea9cfb988f368c" +checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" [[package]] name = "unicode-segmentation" -version = "1.9.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e8820f5d777f6224dc4be3632222971ac30164d4a258d595640799554ebfd99" +checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" [[package]] name = "unicode-width" -version = "0.1.9" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ed742d4ea2bd1176e236172c8429aaf54486e7ac098db29ffe6529e0ce50973" +checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" [[package]] name = "unicode-xid" @@ -2695,12 +2754,11 @@ checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" [[package]] name = "walkdir" -version = "2.3.2" +version = "2.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56" +checksum = "36df944cda56c7d8d8b7496af378e6b16de9284591917d307c9b4d313c44e698" dependencies = [ "same-file", - "winapi", "winapi-util", ] @@ -2731,9 +2789,9 @@ dependencies = [ "bumpalo", "lazy_static", "log", - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", "wasm-bindgen-shared", ] @@ -2755,7 +2813,7 @@ version = "0.2.81" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c441e177922bc58f1e12c022624b6216378e5febc2f0533e41ba443d505b80aa" dependencies = [ - "quote 1.0.20", + "quote 1.0.26", "wasm-bindgen-macro-support", ] @@ -2765,9 +2823,9 @@ version = "0.2.81" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7d94ac45fcf608c1f45ef53e748d35660f168490c10b23704c7779ab8f5c3048" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2798,8 +2856,8 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "88ad594bf33e73cafcac2ae9062fc119d4f75f9c77e25022f91c9a64bd5b6463" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", + "proc-macro2 1.0.56", + "quote 1.0.26", ] [[package]] @@ -2843,34 +2901,171 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.0", +] + +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + +[[package]] +name = "windows-targets" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +dependencies = [ + "windows_aarch64_gnullvm 0.48.0", + "windows_aarch64_msvc 0.48.0", + "windows_i686_gnu 0.48.0", + "windows_i686_msvc 0.48.0", + "windows_x86_64_gnu 0.48.0", + "windows_x86_64_gnullvm 0.48.0", + "windows_x86_64_msvc 0.48.0", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" + +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" + +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" + [[package]] name = "wyz" -version = "0.5.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30b31594f29d27036c383b53b59ed3476874d518f0efb151b27a4c275141390e" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" dependencies = [ "tap", ] +[[package]] +name = "yansi" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" + [[package]] name = "zeroize" -version = "1.5.6" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20b578acffd8516a6c3f2a1bdefc1ec37e547bb4e0fb8b6b01a4cafc886b4442" +checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" dependencies = [ "zeroize_derive", ] [[package]] name = "zeroize_derive" -version = "1.3.2" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f8f187641dad4f680d25c4bfc4225b418165984179f26ca76ec4fb6441d3a17" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", - "synstructure", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 2.0.14", ] [[package]] @@ -2995,7 +3190,7 @@ version = "0.1.2" dependencies = [ "bellman_ce", "byteorder", - "pretty_assertions 1.2.1", + "pretty_assertions 1.3.0", "zkutil", "zokrates_ast", "zokrates_core", @@ -3019,7 +3214,7 @@ dependencies = [ "hex 0.3.2", "lazy_static", "log", - "pretty_assertions 1.2.1", + "pretty_assertions 1.3.0", "primitive-types", "rand 0.4.6", "rand 0.8.5", @@ -3041,6 +3236,7 @@ dependencies = [ "zokrates_field", "zokrates_fs_resolver", "zokrates_interpreter", + "zokrates_profiler", "zokrates_proof_systems", ] @@ -3212,6 +3408,13 @@ dependencies = [ "zokrates_parser", ] +[[package]] +name = "zokrates_profiler" +version = "0.1.0" +dependencies = [ + "zokrates_ast", +] + [[package]] name = "zokrates_proof_systems" version = "0.1.1" @@ -3264,5 +3467,5 @@ dependencies = [ name = "zokrates_test_derive" version = "0.0.1" dependencies = [ - "glob 0.3.0", + "glob 0.3.1", ] diff --git a/Cargo.toml b/Cargo.toml index d521322d6..fd10dd820 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,9 @@ members = [ "zokrates_bellman", "zokrates_proof_systems", "zokrates_js", - "zokrates_circom" + "zokrates_circom", + "zokrates_profiler", ] -exclude = [] \ No newline at end of file +[profile.dev] +opt-level = 1 \ No newline at end of file diff --git a/changelogs/unreleased/1285-schaeff b/changelogs/unreleased/1285-schaeff new file mode 100644 index 000000000..aa8478a07 --- /dev/null +++ b/changelogs/unreleased/1285-schaeff @@ -0,0 +1 @@ +Introduce sourcemaps, introduce `inspect` command to identify costly parts of the source \ No newline at end of file diff --git a/zokrates_analysis/src/assembly_transformer.rs b/zokrates_analysis/src/assembly_transformer.rs index 7c8e856e0..446f4c8f6 100644 --- a/zokrates_analysis/src/assembly_transformer.rs +++ b/zokrates_analysis/src/assembly_transformer.rs @@ -3,9 +3,13 @@ use crate::ZirPropagator; use std::fmt; +use std::ops::*; use zokrates_ast::zir::lqc::LinQuadComb; use zokrates_ast::zir::result_folder::ResultFolder; -use zokrates_ast::zir::{FieldElementExpression, Id, Identifier, ZirAssemblyStatement, ZirProgram}; +use zokrates_ast::zir::AssemblyConstraint; +use zokrates_ast::zir::{ + Expr, FieldElementExpression, Id, Identifier, ZirAssemblyStatement, ZirProgram, +}; use zokrates_field::Field; #[derive(Debug)] @@ -28,143 +32,132 @@ impl AssemblyTransformer { impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer { type Error = Error; - fn fold_assembly_statement( + fn fold_assembly_constraint( &mut self, - s: ZirAssemblyStatement<'ast, T>, + s: AssemblyConstraint<'ast, T>, ) -> Result>, Self::Error> { - match s { - ZirAssemblyStatement::Assignment(_, _) => Ok(vec![s]), - ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => { - let lhs = self.fold_field_expression(lhs)?; - let rhs = self.fold_field_expression(rhs)?; - - let (is_quadratic, lhs, rhs) = match (lhs, rhs) { - ( - lhs @ FieldElementExpression::Identifier(..), - rhs @ FieldElementExpression::Identifier(..), - ) => (true, lhs, rhs), - (FieldElementExpression::Mult(x, y), other) - | (other, FieldElementExpression::Mult(x, y)) - if other.is_linear() => - { - ( - x.is_linear() && y.is_linear(), - other, - FieldElementExpression::Mult(x, y), + let lhs = self.fold_field_expression(s.left)?; + let rhs = self.fold_field_expression(s.right)?; + + let (is_quadratic, lhs, rhs) = match (lhs, rhs) { + ( + lhs @ FieldElementExpression::Identifier(..), + rhs @ FieldElementExpression::Identifier(..), + ) => (true, lhs, rhs), + (FieldElementExpression::Mult(e), other) | (other, FieldElementExpression::Mult(e)) + if other.is_linear() => + { + ( + e.left.is_linear() && e.right.is_linear(), + other, + FieldElementExpression::Mult(e), + ) + } + (lhs, rhs) => (false, lhs, rhs), + }; + + match is_quadratic { + true => Ok(vec![ZirAssemblyStatement::constraint(lhs, rhs, s.metadata)]), + false => { + let sub = FieldElementExpression::sub(lhs, rhs); + let mut lqc = LinQuadComb::try_from(sub.clone()) + .map_err(|_| Error("Non-quadratic constraints are not allowed".to_string()))?; + + let linear = lqc + .linear + .into_iter() + .map(|(c, i)| { + FieldElementExpression::mul( + FieldElementExpression::value(c), + FieldElementExpression::identifier(i), ) - } - (lhs, rhs) => (false, lhs, rhs), - }; - - match is_quadratic { - true => Ok(vec![ZirAssemblyStatement::Constraint(lhs, rhs, metadata)]), - false => { - let sub = FieldElementExpression::Sub(box lhs, box rhs); - let mut lqc = LinQuadComb::try_from(sub.clone()).map_err(|_| { - Error("Non-quadratic constraints are not allowed".to_string()) - })?; - - let linear = lqc - .linear - .into_iter() - .map(|(c, i)| { - FieldElementExpression::Mult( - box FieldElementExpression::Number(c), - box FieldElementExpression::identifier(i), - ) - }) - .fold(FieldElementExpression::Number(T::from(0)), |acc, e| { - FieldElementExpression::Add(box acc, box e) - }); - - let lhs = FieldElementExpression::Add( - box FieldElementExpression::Number(lqc.constant), - box linear, - ); - - let rhs: FieldElementExpression<'ast, T> = if lqc.quadratic.len() > 1 { - let common_factor = lqc - .quadratic - .iter() - .scan(None, |state: &mut Option>, (_, a, b)| { - // short circuit if we do not have any common factors anymore - if *state == Some(vec![]) { - None - } else { - match state { - // only keep factors found in this term - Some(factors) => { - factors.retain(|&x| x == a || x == b); - } - // initialisation step, start with the two factors in the first term - None => { - *state = Some(vec![a, b]); - } - }; - state.clone() + }) + .fold(FieldElementExpression::value(T::from(0)), |acc, e| { + FieldElementExpression::add(acc, e) + }); + + let lhs = FieldElementExpression::add( + FieldElementExpression::value(lqc.constant), + linear, + ); + + let rhs: FieldElementExpression<'ast, T> = if lqc.quadratic.len() > 1 { + let common_factor = lqc + .quadratic + .iter() + .scan(None, |state: &mut Option>, (_, a, b)| { + // short circuit if we do not have any common factors anymore + if *state == Some(vec![]) { + None + } else { + match state { + // only keep factors found in this term + Some(factors) => { + factors.retain(|&x| x == a || x == b); } - }) - .last() - .and_then(|mut v| v.pop().cloned()); - - match common_factor { - Some(factor) => Ok(FieldElementExpression::Mult( - box lqc - .quadratic - .into_iter() - .map(|(c, i0, i1)| { - let c = T::zero() - c; - let e = match (i0, i1) { - (i0, i1) if factor.eq(&i0) => { - FieldElementExpression::identifier(i1) - } - (i0, i1) if factor.eq(&i1) => { - FieldElementExpression::identifier(i0) - } - _ => unreachable!(), - }; - FieldElementExpression::Mult( - box FieldElementExpression::Number(c), - box e, - ) - }) - .fold( - FieldElementExpression::Number(T::from(0)), - |acc, e| FieldElementExpression::Add(box acc, box e), - ), - box FieldElementExpression::identifier(factor), - )), - None => Err(Error( - "Non-quadratic constraints are not allowed".to_string(), - )), - }? - } else { + // initialisation step, start with the two factors in the first term + None => { + *state = Some(vec![a, b]); + } + }; + state.clone() + } + }) + .last() + .and_then(|mut v| v.pop().cloned()); + + match common_factor { + Some(factor) => Ok(FieldElementExpression::mul( lqc.quadratic - .pop() + .into_iter() .map(|(c, i0, i1)| { - FieldElementExpression::Mult( - box FieldElementExpression::Mult( - box FieldElementExpression::Number(T::zero() - c), - box FieldElementExpression::identifier(i0), - ), - box FieldElementExpression::identifier(i1), - ) + let c = T::zero() - c; + let e = match (i0, i1) { + (i0, i1) if factor.eq(&i0) => { + FieldElementExpression::identifier(i1) + } + (i0, i1) if factor.eq(&i1) => { + FieldElementExpression::identifier(i0) + } + _ => unreachable!(), + }; + FieldElementExpression::mul(FieldElementExpression::value(c), e) }) - .unwrap_or_else(|| FieldElementExpression::Number(T::from(0))) - }; + .fold( + FieldElementExpression::value(T::from(0)), + FieldElementExpression::add, + ), + FieldElementExpression::identifier(factor), + )), + None => Err(Error( + "Non-quadratic constraints are not allowed".to_string(), + )), + }? + } else { + lqc.quadratic + .pop() + .map(|(c, i0, i1)| { + FieldElementExpression::mul( + FieldElementExpression::mul( + FieldElementExpression::value(T::zero() - c), + FieldElementExpression::identifier(i0), + ), + FieldElementExpression::identifier(i1), + ) + }) + .unwrap_or_else(|| FieldElementExpression::value(T::from(0))) + }; - let mut propagator = ZirPropagator::default(); - let lhs = propagator - .fold_field_expression(lhs) - .map_err(|e| Error(e.to_string()))?; + let mut propagator = ZirPropagator::default(); + let lhs = propagator + .fold_field_expression(lhs) + .map_err(|e| Error(e.to_string()))?; - let rhs = propagator - .fold_field_expression(rhs) - .map_err(|e| Error(e.to_string()))?; + let rhs = propagator + .fold_field_expression(rhs) + .map_err(|e| Error(e.to_string()))?; - Ok(vec![ZirAssemblyStatement::Constraint(lhs, rhs, metadata)]) - } - } + Ok(vec![ZirAssemblyStatement::constraint(lhs, rhs, s.metadata)]) } } } @@ -180,21 +173,21 @@ mod tests { fn quadratic() { // x === a * b; let lhs = FieldElementExpression::::identifier("x".into()); - let rhs = FieldElementExpression::Mult( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::identifier("b".into()), + let rhs = FieldElementExpression::mul( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("b".into()), ); - let expected = vec![ZirAssemblyStatement::Constraint( + let expected = vec![ZirAssemblyStatement::constraint( FieldElementExpression::identifier("x".into()), - FieldElementExpression::Mult( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::mul( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("b".into()), ), SourceMetadata::default(), )]; let result = AssemblyTransformer - .fold_assembly_statement(ZirAssemblyStatement::Constraint( + .fold_assembly_statement(ZirAssemblyStatement::constraint( lhs, rhs, SourceMetadata::default(), @@ -208,15 +201,15 @@ mod tests { fn non_quadratic() { // x === ((a * b) * c); let lhs = FieldElementExpression::::identifier("x".into()); - let rhs = FieldElementExpression::Mult( - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::identifier("b".into()), + let rhs = FieldElementExpression::mul( + FieldElementExpression::mul( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("b".into()), ), - box FieldElementExpression::identifier("c".into()), + FieldElementExpression::identifier("c".into()), ); - let result = AssemblyTransformer.fold_assembly_statement(ZirAssemblyStatement::Constraint( + let result = AssemblyTransformer.fold_assembly_statement(ZirAssemblyStatement::constraint( lhs, rhs, SourceMetadata::default(), @@ -229,31 +222,31 @@ mod tests { fn transform() { // x === 1 - a * b; --> (-1) + x === (((-1) * a) * b); let lhs = FieldElementExpression::identifier("x".into()); - let rhs = FieldElementExpression::Sub( - box FieldElementExpression::Number(Bn128Field::from(1)), - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::identifier("b".into()), + let rhs = FieldElementExpression::sub( + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::mul( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("b".into()), ), ); - let expected = vec![ZirAssemblyStatement::Constraint( - FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(-1)), - box FieldElementExpression::identifier("x".into()), + let expected = vec![ZirAssemblyStatement::constraint( + FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(-1)), + FieldElementExpression::identifier("x".into()), ), - FieldElementExpression::Mult( - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(-1)), - box FieldElementExpression::identifier("a".into()), + FieldElementExpression::mul( + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(-1)), + FieldElementExpression::identifier("a".into()), ), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::identifier("b".into()), ), SourceMetadata::default(), )]; let result = AssemblyTransformer - .fold_assembly_statement(ZirAssemblyStatement::Constraint( + .fold_assembly_statement(ZirAssemblyStatement::constraint( lhs, rhs, SourceMetadata::default(), @@ -267,30 +260,30 @@ mod tests { fn factorize() { // x === (a * b) + (b * c); --> x === ((a + c) * b); let lhs = FieldElementExpression::::identifier("x".into()); - let rhs = FieldElementExpression::Add( - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::identifier("b".into()), + let rhs = FieldElementExpression::add( + FieldElementExpression::mul( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("b".into()), ), - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("b".into()), - box FieldElementExpression::identifier("c".into()), + FieldElementExpression::mul( + FieldElementExpression::identifier("b".into()), + FieldElementExpression::identifier("c".into()), ), ); - let expected = vec![ZirAssemblyStatement::Constraint( + let expected = vec![ZirAssemblyStatement::constraint( FieldElementExpression::identifier("x".into()), - FieldElementExpression::Mult( - box FieldElementExpression::Add( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::identifier("c".into()), + FieldElementExpression::mul( + FieldElementExpression::add( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("c".into()), ), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::identifier("b".into()), ), SourceMetadata::default(), )]; let result = AssemblyTransformer - .fold_assembly_statement(ZirAssemblyStatement::Constraint( + .fold_assembly_statement(ZirAssemblyStatement::constraint( lhs, rhs, SourceMetadata::default(), @@ -307,100 +300,100 @@ mod tests { // --> // ((((x + ((-1)*a)) + ((-1)*b)) + ((-1)*c)) + (2*mid)) === (((((-2)*b) + ((-2)*c)) + (4*mid)) * a); let lhs = FieldElementExpression::::identifier("x".into()); - let rhs = FieldElementExpression::Add( - box FieldElementExpression::Sub( - box FieldElementExpression::Sub( - box FieldElementExpression::Sub( - box FieldElementExpression::Add( - box FieldElementExpression::Add( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::identifier("b".into()), + let rhs = FieldElementExpression::add( + FieldElementExpression::sub( + FieldElementExpression::sub( + FieldElementExpression::sub( + FieldElementExpression::add( + FieldElementExpression::add( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("b".into()), ), - box FieldElementExpression::identifier("c".into()), + FieldElementExpression::identifier("c".into()), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("a".into()), + FieldElementExpression::mul( + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("a".into()), ), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::identifier("b".into()), ), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("a".into()), + FieldElementExpression::mul( + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("a".into()), ), - box FieldElementExpression::identifier("c".into()), + FieldElementExpression::identifier("c".into()), ), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("mid".into()), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("mid".into()), ), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(4)), - box FieldElementExpression::identifier("a".into()), + FieldElementExpression::mul( + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::identifier("a".into()), ), - box FieldElementExpression::identifier("mid".into()), + FieldElementExpression::identifier("mid".into()), ), ); - let lhs_expected = FieldElementExpression::Add( - box FieldElementExpression::Add( - box FieldElementExpression::Add( - box FieldElementExpression::Add( - box FieldElementExpression::identifier("x".into()), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(-1)), - box FieldElementExpression::identifier("a".into()), + let lhs_expected = FieldElementExpression::add( + FieldElementExpression::add( + FieldElementExpression::add( + FieldElementExpression::add( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(-1)), + FieldElementExpression::identifier("a".into()), ), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(-1)), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(-1)), + FieldElementExpression::identifier("b".into()), ), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(-1)), - box FieldElementExpression::identifier("c".into()), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(-1)), + FieldElementExpression::identifier("c".into()), ), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("mid".into()), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("mid".into()), ), ); - let rhs_expected = FieldElementExpression::Mult( - box FieldElementExpression::Add( - box FieldElementExpression::Add( - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(-2)), - box FieldElementExpression::identifier("b".into()), + let rhs_expected = FieldElementExpression::mul( + FieldElementExpression::add( + FieldElementExpression::add( + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(-2)), + FieldElementExpression::identifier("b".into()), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(-2)), - box FieldElementExpression::identifier("c".into()), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(-2)), + FieldElementExpression::identifier("c".into()), ), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(4)), - box FieldElementExpression::identifier("mid".into()), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::identifier("mid".into()), ), ), - box FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("a".into()), ); - let expected = vec![ZirAssemblyStatement::Constraint( + let expected = vec![ZirAssemblyStatement::constraint( lhs_expected, rhs_expected, SourceMetadata::default(), )]; let result = AssemblyTransformer - .fold_assembly_statement(ZirAssemblyStatement::Constraint( + .fold_assembly_statement(ZirAssemblyStatement::constraint( lhs, rhs, SourceMetadata::default(), diff --git a/zokrates_analysis/src/boolean_array_comparator.rs b/zokrates_analysis/src/boolean_array_comparator.rs index cb016c034..f9696d276 100644 --- a/zokrates_analysis/src/boolean_array_comparator.rs +++ b/zokrates_analysis/src/boolean_array_comparator.rs @@ -1,10 +1,22 @@ -use zokrates_ast::typed::{ - folder::*, ArrayExpressionInner, ArrayValue, BooleanExpression, ConditionalExpression, - ConditionalKind, EqExpression, FieldElementExpression, SelectExpression, Type, TypedExpression, - TypedProgram, UExpressionInner, +use zokrates_ast::{ + common::WithSpan, + typed::{ + folder::*, ArrayExpression, BooleanExpression, Conditional, ConditionalKind, Expr, + FieldElementExpression, Select, Type, TypedExpression, TypedProgram, UExpression, + UExpressionInner, + }, }; + use zokrates_field::Field; +fn sum_rec + Clone>(a: &[T], default: &T) -> T { + match a.len() { + 0 => default.clone(), + 1 => a[0].clone(), + n => sum_rec(&a[..n / 2], default) + sum_rec(&a[n / 2..], default), + } +} + #[derive(Default)] pub struct BooleanArrayComparator; @@ -15,73 +27,57 @@ impl BooleanArrayComparator { } impl<'ast, T: Field> Folder<'ast, T> for BooleanArrayComparator { - fn fold_boolean_expression( + fn fold_boolean_expression_cases( &mut self, e: BooleanExpression<'ast, T>, ) -> BooleanExpression<'ast, T> { match e { - BooleanExpression::ArrayEq(e) => match e.left.inner_type() { + BooleanExpression::ArrayEq(e) => match *e.left.inner_type() { Type::Boolean => { + let span = e.get_span(); + let len = e.left.size(); let len = match len.as_inner() { - UExpressionInner::Value(v) => *v as usize, + UExpressionInner::Value(v) => v.value as usize, _ => unreachable!("array size should be known"), }; let chunk_size = T::get_required_bits() as usize - 1; let left_elements: Vec<_> = (0..len) - .map(|i| { - BooleanExpression::Select(SelectExpression::new( - *e.left.clone(), - (i as u32).into(), - )) - }) + .map(|i| BooleanExpression::select(*e.left.clone(), i as u32).span(span)) .collect(); let right_elements: Vec<_> = (0..len) - .map(|i| { - BooleanExpression::Select(SelectExpression::new( - *e.right.clone(), - (i as u32).into(), - )) - }) + .map(|i| BooleanExpression::select(*e.right.clone(), i as u32).span(span)) .collect(); let process = |elements: &[BooleanExpression<'ast, T>]| { elements .chunks(chunk_size) .map(|chunk| { - TypedExpression::from( - chunk + TypedExpression::from(sum_rec( + &chunk .iter() .rev() .enumerate() .rev() .map(|(index, c)| { - FieldElementExpression::Conditional( - ConditionalExpression::new( - c.clone(), - FieldElementExpression::Pow( - box FieldElementExpression::Number( - T::from(2), - ), - box (index as u32).into(), - ), - T::zero().into(), - ConditionalKind::Ternary, + FieldElementExpression::conditional( + c.clone().span(span), + FieldElementExpression::pow( + FieldElementExpression::value(T::from(2)) + .span(span), + UExpression::from(index as u32).span(span), ), + FieldElementExpression::from(T::zero()).span(span), + ConditionalKind::Ternary, ) + .span(span) }) - .fold(None, |acc, e| match acc { - Some(acc) => { - Some(FieldElementExpression::Add(box acc, box e)) - } - None => Some(e), - }) - .unwrap_or_else(|| { - FieldElementExpression::Number(T::zero()) - }), - ) + .collect::>(), + &FieldElementExpression::value(T::from(0)).span(span), + )) + .span(span) .into() }) .collect() @@ -93,23 +89,28 @@ impl<'ast, T: Field> Folder<'ast, T> for BooleanArrayComparator { let chunk_count = left.len(); - BooleanExpression::ArrayEq(EqExpression::new( - ArrayExpressionInner::Value(ArrayValue(left)) - .annotate(Type::FieldElement, chunk_count as u32), - ArrayExpressionInner::Value(ArrayValue(right)) - .annotate(Type::FieldElement, chunk_count as u32), - )) + BooleanExpression::array_eq( + ArrayExpression::value(left) + .annotate(Type::FieldElement, chunk_count as u32) + .span(span), + ArrayExpression::value(right) + .annotate(Type::FieldElement, chunk_count as u32) + .span(span), + ) } - _ => fold_boolean_expression(self, BooleanExpression::ArrayEq(e)), + _ => fold_boolean_expression_cases(self, BooleanExpression::ArrayEq(e)), }, - e => fold_boolean_expression(self, e), + e => fold_boolean_expression_cases(self, e), } } } #[cfg(test)] mod tests { - use zokrates_ast::typed::{BooleanExpression, EqExpression, Type}; + use zokrates_ast::{ + common::expressions::BinaryExpression, + typed::{BooleanExpression, Type}, + }; use zokrates_field::DummyCurveField; use zokrates_ast::typed::utils::{a, a_id, conditional, f, select, u_32}; @@ -127,9 +128,9 @@ mod tests { let y = a_id("y").annotate(Type::Boolean, 2u32); let e: BooleanExpression = - BooleanExpression::ArrayEq(EqExpression::new(x.clone(), y.clone())); + BooleanExpression::ArrayEq(BinaryExpression::new(x.clone(), y.clone())); - let expected = BooleanExpression::ArrayEq(EqExpression::new( + let expected = BooleanExpression::ArrayEq(BinaryExpression::new( a([ conditional(select(x.clone(), 0u32), f(2).pow(u_32(1)), f(0)) + conditional(select(x.clone(), 1u32), f(2).pow(u_32(0)), f(0)), @@ -155,9 +156,9 @@ mod tests { let y = a_id("y").annotate(Type::Boolean, 3u32); let e: BooleanExpression = - BooleanExpression::ArrayEq(EqExpression::new(x.clone(), y.clone())); + BooleanExpression::ArrayEq(BinaryExpression::new(x.clone(), y.clone())); - let expected = BooleanExpression::ArrayEq(EqExpression::new( + let expected = BooleanExpression::ArrayEq(BinaryExpression::new( a([ conditional(select(x.clone(), 0u32), f(2).pow(u_32(1)), f(0)) + conditional(select(x.clone(), 1u32), f(2).pow(u_32(0)), f(0)), diff --git a/zokrates_analysis/src/branch_isolator.rs b/zokrates_analysis/src/branch_isolator.rs index ee41233a4..6b6789ec6 100644 --- a/zokrates_analysis/src/branch_isolator.rs +++ b/zokrates_analysis/src/branch_isolator.rs @@ -3,6 +3,7 @@ // `if c then a else b fi` becomes `if c then { a } else { b } fi`, and down the line any statements resulting from trating `a` and `b` can be safely kept inside the respective blocks. +use zokrates_ast::common::{Fold, WithSpan}; use zokrates_ast::typed::folder::*; use zokrates_ast::typed::*; use zokrates_field::Field; @@ -18,17 +19,25 @@ impl Isolator { impl<'ast, T: Field> Folder<'ast, T> for Isolator { fn fold_conditional_expression< - E: Expr<'ast, T> + Block<'ast, T> + Fold<'ast, T> + Conditional<'ast, T>, + E: Expr<'ast, T> + Block<'ast, T> + Fold + Conditional<'ast, T>, >( &mut self, _: &E::Ty, e: ConditionalExpression<'ast, T, E>, ) -> ConditionalOrExpression<'ast, T, E> { - ConditionalOrExpression::Conditional(ConditionalExpression::new( - self.fold_boolean_expression(*e.condition), - E::block(vec![], e.consequence.fold(self)), - E::block(vec![], e.alternative.fold(self)), - e.kind, - )) + let span = e.get_span(); + + let consequence_span = e.consequence.get_span(); + let alternative_span = e.alternative.get_span(); + + ConditionalOrExpression::Conditional( + ConditionalExpression::new( + self.fold_boolean_expression(*e.condition), + E::block(vec![], e.consequence.fold(self)).span(consequence_span), + E::block(vec![], e.alternative.fold(self)).span(alternative_span), + e.kind, + ) + .span(span), + ) } } diff --git a/zokrates_analysis/src/condition_redefiner.rs b/zokrates_analysis/src/condition_redefiner.rs index 775271b1d..535a51375 100644 --- a/zokrates_analysis/src/condition_redefiner.rs +++ b/zokrates_analysis/src/condition_redefiner.rs @@ -1,7 +1,10 @@ -use zokrates_ast::typed::{ - folder::*, BlockExpression, BooleanExpression, Conditional, ConditionalExpression, - ConditionalOrExpression, CoreIdentifier, Expr, Id, Identifier, Type, TypedExpression, - TypedProgram, TypedStatement, Variable, +use zokrates_ast::{ + common::{Fold, WithSpan}, + typed::{ + folder::*, BlockExpression, BooleanExpression, Conditional, ConditionalExpression, + ConditionalOrExpression, CoreIdentifier, Expr, Id, Identifier, Type, TypedExpression, + TypedProgram, TypedStatement, Variable, + }, }; use zokrates_field::Field; @@ -18,14 +21,14 @@ impl<'ast, T: Field> ConditionRedefiner<'ast, T> { } impl<'ast, T: Field> Folder<'ast, T> for ConditionRedefiner<'ast, T> { - fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { + fn fold_statement_cases(&mut self, s: TypedStatement<'ast, T>) -> Vec> { assert!(self.buffer.is_empty()); - let s = fold_statement(self, s); + let s = fold_statement_cases(self, s); let buffer = std::mem::take(&mut self.buffer); buffer.into_iter().chain(s).collect() } - fn fold_block_expression>( + fn fold_block_expression>( &mut self, b: BlockExpression<'ast, T, E>, ) -> BlockExpression<'ast, T, E> { @@ -54,25 +57,32 @@ impl<'ast, T: Field> Folder<'ast, T> for ConditionRedefiner<'ast, T> { b } - fn fold_conditional_expression + Conditional<'ast, T> + Fold<'ast, T>>( + fn fold_conditional_expression + Conditional<'ast, T> + Fold>( &mut self, _: &E::Ty, e: ConditionalExpression<'ast, T, E>, ) -> ConditionalOrExpression<'ast, T, E> { let condition = self.fold_boolean_expression(*e.condition); + let condition_span = condition.get_span(); let condition = match condition { condition @ BooleanExpression::Value(_) | condition @ BooleanExpression::Identifier(_) => condition, condition => { let condition_id = Identifier::from(CoreIdentifier::Condition(self.index)); - self.buffer.push(TypedStatement::definition( - Variable::immutable(condition_id.clone(), Type::Boolean).into(), - TypedExpression::from(condition), - )); + self.buffer.push( + TypedStatement::definition( + Variable::new(condition_id.clone(), Type::Boolean) + .span(condition_span) + .into(), + TypedExpression::from(condition).span(condition_span), + ) + .span(condition_span), + ); self.index += 1; - BooleanExpression::identifier(condition_id) + BooleanExpression::identifier(condition_id).span(condition_span) } - }; + } + .span(condition_span); let consequence = e.consequence.fold(self); let alternative = e.alternative.fold(self); @@ -89,6 +99,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConditionRedefiner<'ast, T> { #[cfg(test)] mod tests { use super::*; + use std::ops::*; use zokrates_ast::typed::{ Block, BooleanExpression, Conditional, ConditionalKind, FieldElementExpression, Type, }; @@ -102,9 +113,9 @@ mod tests { let s = TypedStatement::definition( Variable::field_element("foo").into(), FieldElementExpression::conditional( - BooleanExpression::Value(true), - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + BooleanExpression::value(true), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), ConditionalKind::IfElse, ) .into(), @@ -124,8 +135,8 @@ mod tests { Variable::field_element("foo").into(), FieldElementExpression::conditional( BooleanExpression::identifier("c".into()), - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), ConditionalKind::IfElse, ) .into(), @@ -143,17 +154,17 @@ mod tests { // bool #CONDITION_0 = c && d; // field foo = if #CONDITION_0 { 1 } else { 2 }; - let condition = BooleanExpression::And( - box BooleanExpression::identifier("c".into()), - box BooleanExpression::identifier("d".into()), + let condition = BooleanExpression::bitand( + BooleanExpression::identifier("c".into()), + BooleanExpression::identifier("d".into()), ); let s = TypedStatement::definition( Variable::field_element("foo").into(), FieldElementExpression::conditional( condition.clone(), - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), ConditionalKind::IfElse, ) .into(), @@ -164,7 +175,7 @@ mod tests { let expected = vec![ // define condition TypedStatement::definition( - Variable::immutable(CoreIdentifier::Condition(0), Type::Boolean).into(), + Variable::new(CoreIdentifier::Condition(0), Type::Boolean).into(), condition.into(), ), // rewrite statement @@ -172,8 +183,8 @@ mod tests { Variable::field_element("foo").into(), FieldElementExpression::conditional( BooleanExpression::identifier(CoreIdentifier::Condition(0).into()), - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), ConditionalKind::IfElse, ) .into(), @@ -202,14 +213,14 @@ mod tests { // 3 // }; - let condition_0 = BooleanExpression::And( - box BooleanExpression::identifier("c".into()), - box BooleanExpression::identifier("d".into()), + let condition_0 = BooleanExpression::bitand( + BooleanExpression::identifier("c".into()), + BooleanExpression::identifier("d".into()), ); - let condition_1 = BooleanExpression::And( - box BooleanExpression::identifier("e".into()), - box BooleanExpression::identifier("f".into()), + let condition_1 = BooleanExpression::bitand( + BooleanExpression::identifier("e".into()), + BooleanExpression::identifier("f".into()), ); let s = TypedStatement::definition( @@ -218,11 +229,11 @@ mod tests { condition_0.clone(), FieldElementExpression::conditional( condition_1.clone(), - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), ConditionalKind::IfElse, ), - FieldElementExpression::Number(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(3)), ConditionalKind::IfElse, ) .into(), @@ -233,11 +244,11 @@ mod tests { let expected = vec![ // define conditions TypedStatement::definition( - Variable::immutable(CoreIdentifier::Condition(0), Type::Boolean).into(), + Variable::new(CoreIdentifier::Condition(0), Type::Boolean).into(), condition_0.into(), ), TypedStatement::definition( - Variable::immutable(CoreIdentifier::Condition(1), Type::Boolean).into(), + Variable::new(CoreIdentifier::Condition(1), Type::Boolean).into(), condition_1.into(), ), // rewrite statement @@ -247,11 +258,11 @@ mod tests { BooleanExpression::identifier(CoreIdentifier::Condition(0).into()), FieldElementExpression::conditional( BooleanExpression::identifier(CoreIdentifier::Condition(1).into()), - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), ConditionalKind::IfElse, ), - FieldElementExpression::Number(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(3)), ConditionalKind::IfElse, ) .into(), @@ -284,19 +295,19 @@ mod tests { // if #CONDITION_2 { 2 } : { 3 } // }; - let condition_0 = BooleanExpression::And( - box BooleanExpression::identifier("c".into()), - box BooleanExpression::identifier("d".into()), + let condition_0 = BooleanExpression::bitand( + BooleanExpression::identifier("c".into()), + BooleanExpression::identifier("d".into()), ); - let condition_1 = BooleanExpression::And( - box BooleanExpression::identifier("e".into()), - box BooleanExpression::identifier("f".into()), + let condition_1 = BooleanExpression::bitand( + BooleanExpression::identifier("e".into()), + BooleanExpression::identifier("f".into()), ); - let condition_2 = BooleanExpression::And( - box BooleanExpression::identifier("e".into()), - box BooleanExpression::identifier("f".into()), + let condition_2 = BooleanExpression::bitand( + BooleanExpression::identifier("e".into()), + BooleanExpression::identifier("f".into()), ); let condition_id_0 = BooleanExpression::identifier(CoreIdentifier::Condition(0).into()); @@ -310,24 +321,24 @@ mod tests { FieldElementExpression::block( vec![TypedStatement::definition( Variable::field_element("a").into(), - FieldElementExpression::Number(Bn128Field::from(1)).into(), + FieldElementExpression::value(Bn128Field::from(1)).into(), )], FieldElementExpression::conditional( condition_1.clone(), - FieldElementExpression::Number(Bn128Field::from(2)), - FieldElementExpression::Number(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), ConditionalKind::IfElse, ), ), FieldElementExpression::block( vec![TypedStatement::definition( Variable::field_element("b").into(), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), )], FieldElementExpression::conditional( condition_2.clone(), - FieldElementExpression::Number(Bn128Field::from(2)), - FieldElementExpression::Number(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), ConditionalKind::IfElse, ), ), @@ -341,7 +352,7 @@ mod tests { let expected = vec![ // define conditions TypedStatement::definition( - Variable::immutable(CoreIdentifier::Condition(0), Type::Boolean).into(), + Variable::new(CoreIdentifier::Condition(0), Type::Boolean).into(), condition_0.into(), ), // rewrite statement @@ -353,18 +364,17 @@ mod tests { vec![ TypedStatement::definition( Variable::field_element("a").into(), - FieldElementExpression::Number(Bn128Field::from(1)).into(), + FieldElementExpression::value(Bn128Field::from(1)).into(), ), TypedStatement::definition( - Variable::immutable(CoreIdentifier::Condition(1), Type::Boolean) - .into(), + Variable::new(CoreIdentifier::Condition(1), Type::Boolean).into(), condition_1.into(), ), ], FieldElementExpression::conditional( condition_id_1, - FieldElementExpression::Number(Bn128Field::from(2)), - FieldElementExpression::Number(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), ConditionalKind::IfElse, ), ), @@ -372,18 +382,17 @@ mod tests { vec![ TypedStatement::definition( Variable::field_element("b").into(), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), TypedStatement::definition( - Variable::immutable(CoreIdentifier::Condition(2), Type::Boolean) - .into(), + Variable::new(CoreIdentifier::Condition(2), Type::Boolean).into(), condition_2.into(), ), ], FieldElementExpression::conditional( condition_id_2, - FieldElementExpression::Number(Bn128Field::from(2)), - FieldElementExpression::Number(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), ConditionalKind::IfElse, ), ), diff --git a/zokrates_analysis/src/constant_argument_checker.rs b/zokrates_analysis/src/constant_argument_checker.rs index 485cf9181..a8b772f51 100644 --- a/zokrates_analysis/src/constant_argument_checker.rs +++ b/zokrates_analysis/src/constant_argument_checker.rs @@ -1,9 +1,10 @@ use std::fmt; use zokrates_ast::common::FlatEmbed; +use zokrates_ast::typed::result_folder::*; +use zokrates_ast::typed::{result_folder::ResultFolder, Constant, EmbedCall, TypedStatement}; use zokrates_ast::typed::{ - result_folder::fold_statement, result_folder::ResultFolder, Constant, EmbedCall, TypedStatement, + DefinitionRhs, DefinitionStatement, TypedProgram, UBitwidth, UExpression, UExpressionInner, }; -use zokrates_ast::typed::{DefinitionRhs, TypedProgram}; use zokrates_field::Field; pub struct ConstantArgumentChecker; @@ -26,47 +27,83 @@ impl fmt::Display for Error { impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker { type Error = Error; - fn fold_statement( + fn fold_definition_statement( &mut self, - s: TypedStatement<'ast, T>, + s: DefinitionStatement<'ast, T>, ) -> Result>, Self::Error> { - match s { - TypedStatement::Definition(assignee, DefinitionRhs::EmbedCall(embed_call)) => { - match embed_call { - EmbedCall { - embed: FlatEmbed::BitArrayLe, - .. - } => { - let arguments = embed_call - .arguments - .into_iter() - .map(|a| self.fold_expression(a)) - .collect::, _>>()?; + match s.rhs { + DefinitionRhs::EmbedCall(embed_call) => match embed_call { + EmbedCall { + embed: FlatEmbed::BitArrayLe, + .. + } => { + let arguments = embed_call + .arguments + .into_iter() + .map(|a| self.fold_expression(a)) + .collect::, _>>()?; - if arguments[1].is_constant() { - Ok(vec![TypedStatement::Definition( - assignee, - EmbedCall { - embed: FlatEmbed::BitArrayLe, - generics: embed_call.generics, - arguments, - } - .into(), - )]) - } else { - Err(Error(format!( - "Cannot compare to a variable value, found `{}`", - arguments[1] - ))) - } + if arguments[1].is_constant() { + Ok(vec![TypedStatement::embed_call_definition( + s.assignee, + EmbedCall { + embed: FlatEmbed::BitArrayLe, + generics: embed_call.generics, + arguments, + }, + )]) + } else { + Err(Error(format!( + "Cannot compare to a variable value, found `{}`", + arguments[1] + ))) } - embed_call => Ok(vec![TypedStatement::Definition( - assignee, - embed_call.into(), - )]), + } + embed_call => Ok(vec![TypedStatement::embed_call_definition( + s.assignee, embed_call, + )]), + }, + _ => fold_definition_statement(self, s), + } + } + + fn fold_uint_expression_cases( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> Result, Error> { + match e { + UExpressionInner::LeftShift(e) => { + let left = self.fold_uint_expression(*e.left)?; + let right = self.fold_uint_expression(*e.right)?; + + match right.as_inner() { + UExpressionInner::Value(_) => { + Ok(UExpression::left_shift(left, right).into_inner()) + } + by => Err(Error(format!( + "Cannot shift by a variable value, found `{} << {}`", + left, + by.clone().annotate(UBitwidth::B32) + ))), + } + } + UExpressionInner::RightShift(e) => { + let left = self.fold_uint_expression(*e.left)?; + let right = self.fold_uint_expression(*e.right)?; + + match right.as_inner() { + UExpressionInner::Value(_) => { + Ok(UExpression::right_shift(left, right).into_inner()) + } + by => Err(Error(format!( + "Cannot shift by a variable value, found `{} >> {}`", + left, + by.clone().annotate(UBitwidth::B32) + ))), } } - s => fold_statement(self, s), + e => fold_uint_expression_cases(self, bitwidth, e), } } } diff --git a/zokrates_analysis/src/constant_resolver.rs b/zokrates_analysis/src/constant_resolver.rs index bf816248f..f39167d39 100644 --- a/zokrates_analysis/src/constant_resolver.rs +++ b/zokrates_analysis/src/constant_resolver.rs @@ -9,18 +9,18 @@ use zokrates_field::Field; // a map of the canonical constants in this program. with all imported constants reduced to their canonical value type ProgramConstants<'ast, T> = - HashMap, TypedConstant<'ast, T>>>; + HashMap, TypedConstant<'ast, T>>>; pub struct ConstantResolver<'ast, T> { modules: TypedModules<'ast, T>, - location: OwnedTypedModuleId, + location: OwnedModuleId, constants: ProgramConstants<'ast, T>, } impl<'ast, T: Field> ConstantResolver<'ast, T> { pub fn new( modules: TypedModules<'ast, T>, - location: OwnedTypedModuleId, + location: OwnedModuleId, constants: ProgramConstants<'ast, T>, ) -> Self { ConstantResolver { @@ -35,14 +35,14 @@ impl<'ast, T: Field> ConstantResolver<'ast, T> { inliner.fold_program(p) } - fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId { + fn change_location(&mut self, location: OwnedModuleId) -> OwnedModuleId { let prev = self.location.clone(); self.location = location; self.constants.entry(self.location.clone()).or_default(); prev } - fn treated(&self, id: &TypedModuleId) -> bool { + fn treated(&self, id: &ModuleId) -> bool { self.constants.contains_key(id) } @@ -67,7 +67,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantResolver<'ast, T> { } } - fn fold_module_id(&mut self, id: OwnedTypedModuleId) -> OwnedTypedModuleId { + fn fold_module_id(&mut self, id: OwnedModuleId) -> OwnedModuleId { // anytime we encounter a module id, visit the corresponding module if it hasn't been done yet if !self.treated(&id) { let current_m_id = self.change_location(id.clone()); @@ -109,6 +109,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantResolver<'ast, T> { #[cfg(test)] mod tests { use super::*; + use std::ops::*; use zokrates_ast::typed::types::{DeclarationSignature, GTupleType}; use zokrates_ast::typed::{ DeclarationArrayType, DeclarationFunctionKey, DeclarationType, FieldElementExpression, @@ -130,7 +131,7 @@ mod tests { let const_id = "a"; let main: TypedFunction = TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( FieldElementExpression::identifier(Identifier::from(const_id)).into(), )], signature: DeclarationSignature::new() @@ -139,6 +140,7 @@ mod tests { }; let program = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -147,7 +149,7 @@ mod tests { TypedConstantSymbolDeclaration::new( CanonicalConstantIdentifier::new(const_id, "main".into()), TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::FieldElement(FieldElementExpression::Number( + TypedExpression::FieldElement(FieldElementExpression::value( Bn128Field::from(1), )), DeclarationType::FieldElement, @@ -190,7 +192,7 @@ mod tests { let const_id = CanonicalConstantIdentifier::new("a", "main".into()); let main: TypedFunction = TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( BooleanExpression::identifier(Identifier::from(const_id.clone())).into(), )], signature: DeclarationSignature::new() @@ -199,6 +201,7 @@ mod tests { }; let program = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -207,7 +210,7 @@ mod tests { TypedConstantSymbolDeclaration::new( const_id, TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::Boolean(BooleanExpression::Value(true)), + TypedExpression::Boolean(BooleanExpression::value(true)), DeclarationType::Boolean, )), ) @@ -248,7 +251,7 @@ mod tests { let const_id = CanonicalConstantIdentifier::new("a", "main".into()); let main: TypedFunction = TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( UExpression::identifier(Identifier::from(const_id.clone())) .annotate(UBitwidth::B32) .into(), @@ -259,6 +262,7 @@ mod tests { }; let program = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -267,9 +271,7 @@ mod tests { TypedConstantSymbolDeclaration::new( const_id, TypedConstantSymbol::Here(TypedConstant::new( - UExpressionInner::Value(1u128) - .annotate(UBitwidth::B32) - .into(), + UExpression::value(1u128).annotate(UBitwidth::B32).into(), DeclarationType::Uint(UBitwidth::B32), )), ) @@ -310,20 +312,18 @@ mod tests { let const_id = CanonicalConstantIdentifier::new("a", "main".into()); let main: TypedFunction = TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( - FieldElementExpression::Add( + statements: vec![TypedStatement::ret( + FieldElementExpression::add( FieldElementExpression::select( ArrayExpression::identifier(Identifier::from(const_id.clone())) .annotate(GType::FieldElement, 2u32), - UExpressionInner::Value(0u128).annotate(UBitwidth::B32), - ) - .into(), + UExpression::value(0u128).annotate(UBitwidth::B32), + ), FieldElementExpression::select( ArrayExpression::identifier(Identifier::from(const_id.clone())) .annotate(GType::FieldElement, 2u32), - UExpressionInner::Value(1u128).annotate(UBitwidth::B32), - ) - .into(), + UExpression::value(1u128).annotate(UBitwidth::B32), + ), ) .into(), )], @@ -333,6 +333,7 @@ mod tests { }; let program = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -342,15 +343,10 @@ mod tests { const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( TypedExpression::Array( - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(2)) - .into(), - FieldElementExpression::Number(Bn128Field::from(2)) - .into(), - ] - .into(), - ) + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), + ]) .annotate(GType::FieldElement, 2u32), ), DeclarationType::Array(DeclarationArrayType::new( @@ -397,7 +393,7 @@ mod tests { let main: TypedFunction = TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( FieldElementExpression::identifier(Identifier::from(const_b_id.clone())).into(), )], signature: DeclarationSignature::new() @@ -406,6 +402,7 @@ mod tests { }; let program = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -414,7 +411,7 @@ mod tests { TypedConstantSymbolDeclaration::new( const_a_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::FieldElement(FieldElementExpression::Number( + TypedExpression::FieldElement(FieldElementExpression::value( Bn128Field::from(1), )), DeclarationType::FieldElement, @@ -424,11 +421,11 @@ mod tests { TypedConstantSymbolDeclaration::new( const_b_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::FieldElement(FieldElementExpression::Add( - box FieldElementExpression::identifier(Identifier::from( + TypedExpression::FieldElement(FieldElementExpression::add( + FieldElementExpression::identifier(Identifier::from( const_a_id.clone(), )), - box FieldElementExpression::Number(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(1)), )), DeclarationType::FieldElement, )), @@ -505,7 +502,7 @@ mod tests { TypedConstantSymbolDeclaration::new( foo_const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::FieldElement(FieldElementExpression::Number( + TypedExpression::FieldElement(FieldElementExpression::value( Bn128Field::from(42), )), DeclarationType::FieldElement, @@ -556,7 +553,7 @@ mod tests { ), TypedFunctionSymbol::Here(TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( FieldElementExpression::identifier(Identifier::from( main_const_id.clone(), )) @@ -572,6 +569,7 @@ mod tests { }; let program = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![ ("main".into(), main_module), @@ -602,7 +600,7 @@ mod tests { ), TypedFunctionSymbol::Here(TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( FieldElementExpression::identifier(Identifier::from( main_const_id.clone(), )) @@ -618,6 +616,7 @@ mod tests { }; let expected_program: TypedProgram = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![ ("main".into(), expected_main_module), @@ -683,7 +682,7 @@ mod tests { TypedConstantSymbolDeclaration::new( foo_const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::FieldElement(FieldElementExpression::Number( + TypedExpression::FieldElement(FieldElementExpression::value( Bn128Field::from(2), )), DeclarationType::FieldElement, @@ -693,13 +692,10 @@ mod tests { TypedConstantSymbolDeclaration::new( bar_const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::Array( - ArrayExpressionInner::Repeat( - box FieldElementExpression::Number(Bn128Field::from(1)).into(), - box UExpression::from(foo_const_id.clone()), - ) - .annotate(Type::FieldElement, foo_const_id.clone()), - ), + TypedExpression::Array(ArrayExpression::repeat( + FieldElementExpression::value(Bn128Field::from(1)).into(), + UExpression::from(foo_const_id.clone()), + )), DeclarationType::Array(DeclarationArrayType::new( DeclarationType::FieldElement, DeclarationConstant::Constant(foo_const_id.clone()), @@ -763,7 +759,7 @@ mod tests { ), TypedFunctionSymbol::Here(TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( FieldElementExpression::identifier(Identifier::from( main_foo_const_id.clone(), )) @@ -779,6 +775,7 @@ mod tests { }; let program = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![ ("main".into(), main_module), @@ -794,7 +791,7 @@ mod tests { TypedConstantSymbolDeclaration::new( main_foo_const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), DeclarationType::FieldElement, )), ) @@ -802,13 +799,10 @@ mod tests { TypedConstantSymbolDeclaration::new( main_bar_const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::Array( - ArrayExpressionInner::Repeat( - box FieldElementExpression::Number(Bn128Field::from(1)).into(), - box UExpression::from(foo_const_id.clone()), - ) - .annotate(Type::FieldElement, foo_const_id.clone()), - ), + TypedExpression::Array(ArrayExpression::repeat( + FieldElementExpression::value(Bn128Field::from(1)).into(), + UExpression::from(foo_const_id.clone()), + )), DeclarationType::Array(DeclarationArrayType::new( DeclarationType::FieldElement, DeclarationConstant::Constant(foo_const_id.clone()), @@ -838,7 +832,7 @@ mod tests { ), TypedFunctionSymbol::Here(TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( FieldElementExpression::identifier(Identifier::from( main_foo_const_id.clone(), )) @@ -854,6 +848,7 @@ mod tests { }; let expected_program: TypedProgram = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![ ("main".into(), expected_main_module), diff --git a/zokrates_analysis/src/dead_code.rs b/zokrates_analysis/src/dead_code.rs index 4b67b08b2..d1baba976 100644 --- a/zokrates_analysis/src/dead_code.rs +++ b/zokrates_analysis/src/dead_code.rs @@ -26,38 +26,56 @@ impl<'ast, T: Field> Folder<'ast, T> for DeadCodeEliminator<'ast> { ZirFunction { statements, ..f } } - fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec> { - match s { - ZirStatement::Definition(v, e) => { - // if the lhs is used later in the program - if self.used.remove(&v.id) { - // include this statement - fold_statement(self, ZirStatement::Definition(v, e)) - } else { - // otherwise remove it - vec![] - } - } - ZirStatement::IfElse(condition, consequence, alternative) => { - let condition = self.fold_boolean_expression(condition); + fn fold_if_else_statement( + &mut self, + s: zokrates_ast::zir::IfElseStatement<'ast, T>, + ) -> Vec> { + let condition = self.fold_boolean_expression(s.condition); - let mut consequence: Vec<_> = consequence - .into_iter() - .rev() - .flat_map(|e| self.fold_statement(e)) - .collect(); - consequence.reverse(); + let mut consequence: Vec<_> = s + .consequence + .into_iter() + .rev() + .flat_map(|e| self.fold_statement(e)) + .collect(); + consequence.reverse(); - let mut alternative: Vec<_> = alternative - .into_iter() - .rev() - .flat_map(|e| self.fold_statement(e)) - .collect(); - alternative.reverse(); + let mut alternative: Vec<_> = s + .alternative + .into_iter() + .rev() + .flat_map(|e| self.fold_statement(e)) + .collect(); + alternative.reverse(); + + vec![ZirStatement::if_else(condition, consequence, alternative)] + } + + fn fold_multiple_definition_statement( + &mut self, + s: zokrates_ast::zir::MultipleDefinitionStatement<'ast, T>, + ) -> Vec> { + // if the lhs is used later in the program + if s.assignees.iter().any(|a| self.used.remove(&a.id)) { + // include this statement + fold_multiple_definition_statement(self, s) + } else { + // otherwise remove it + vec![] + } + } - vec![ZirStatement::IfElse(condition, consequence, alternative)] - } - s => fold_statement(self, s), + fn fold_definition_statement( + &mut self, + s: zokrates_ast::zir::DefinitionStatement<'ast, T>, + ) -> Vec> { + // if the lhs is used later in the program + if self.used.remove(&s.assignee.id) { + // include this statement + fold_definition_statement(self, s) + } else { + // otherwise remove it + vec![] } } diff --git a/zokrates_analysis/src/expression_validator.rs b/zokrates_analysis/src/expression_validator.rs index d3a0a96ef..9dcd63413 100644 --- a/zokrates_analysis/src/expression_validator.rs +++ b/zokrates_analysis/src/expression_validator.rs @@ -1,7 +1,5 @@ use std::fmt; -use zokrates_ast::typed::result_folder::{ - fold_assembly_statement, fold_field_expression, fold_uint_expression_inner, ResultFolder, -}; +use zokrates_ast::typed::{result_folder::*, AssemblyAssignment, UExpression}; use zokrates_ast::typed::{ FieldElementExpression, TypedAssemblyStatement, TypedProgram, UBitwidth, UExpressionInner, }; @@ -27,81 +25,80 @@ impl ExpressionValidator { impl<'ast, T: Field> ResultFolder<'ast, T> for ExpressionValidator { type Error = Error; - fn fold_assembly_statement( + // we allow more dynamic expressions in witness generation + fn fold_assembly_assignment( &mut self, - s: TypedAssemblyStatement<'ast, T>, + s: AssemblyAssignment<'ast, T>, ) -> Result>, Self::Error> { - match s { - // we allow more dynamic expressions in witness generation - TypedAssemblyStatement::Assignment(_, _) => Ok(vec![s]), - s => fold_assembly_statement(self, s), - } + Ok(vec![TypedAssemblyStatement::Assignment(s)]) } - fn fold_field_expression( + fn fold_field_expression_cases( &mut self, e: FieldElementExpression<'ast, T>, ) -> Result, Self::Error> { match e { // these should have been propagated away - FieldElementExpression::And(_, _) - | FieldElementExpression::Or(_, _) - | FieldElementExpression::Xor(_, _) - | FieldElementExpression::LeftShift(_, _) - | FieldElementExpression::RightShift(_, _) => Err(Error(format!( + FieldElementExpression::And(_) + | FieldElementExpression::Or(_) + | FieldElementExpression::Xor(_) + | FieldElementExpression::LeftShift(_) + | FieldElementExpression::RightShift(_) => Err(Error(format!( "Found non-constant bitwise operation in field element expression `{}`", e ))), - FieldElementExpression::Pow(box e, box exp) => { - let e = self.fold_field_expression(e)?; - let exp = self.fold_uint_expression(exp)?; + FieldElementExpression::Pow(e) => { + let base = self.fold_field_expression(*e.left)?; + let exp = self.fold_uint_expression(*e.right)?; match exp.as_inner() { - UExpressionInner::Value(_) => Ok(FieldElementExpression::Pow(box e, box exp)), + UExpressionInner::Value(_) => Ok(FieldElementExpression::pow(base, exp)), exp => Err(Error(format!( "Found non-constant exponent in power expression `{}**{}`", - e, + base, exp.clone().annotate(UBitwidth::B32) ))), } } - e => fold_field_expression(self, e), + e => fold_field_expression_cases(self, e), } } - fn fold_uint_expression_inner( + fn fold_uint_expression_cases( &mut self, bitwidth: UBitwidth, e: UExpressionInner<'ast, T>, ) -> Result, Error> { match e { - UExpressionInner::LeftShift(box e, box by) => { - let e = self.fold_uint_expression(e)?; - let by = self.fold_uint_expression(by)?; + UExpressionInner::LeftShift(e) => { + let expr = self.fold_uint_expression(*e.left)?; + let by = self.fold_uint_expression(*e.right)?; match by.as_inner() { - UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)), - by => Err(Error(format!( - "Cannot shift by a variable value, found `{} << {}`", - e, - by.clone().annotate(UBitwidth::B32) + UExpressionInner::Value(_) => { + Ok(UExpression::left_shift(expr, by).into_inner()) + } + _ => Err(Error(format!( + "Cannot shift by a variable value, found `{}`", + UExpression::left_shift(expr, by) ))), } } - UExpressionInner::RightShift(box e, box by) => { - let e = self.fold_uint_expression(e)?; - let by = self.fold_uint_expression(by)?; + UExpressionInner::RightShift(e) => { + let expr = self.fold_uint_expression(*e.left)?; + let by = self.fold_uint_expression(*e.right)?; match by.as_inner() { - UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)), - by => Err(Error(format!( - "Cannot shift by a variable value, found `{} >> {}`", - e, - by.clone().annotate(UBitwidth::B32) + UExpressionInner::Value(_) => { + Ok(UExpression::right_shift(expr, by).into_inner()) + } + _ => Err(Error(format!( + "Cannot shift by a variable value, found `{}`", + UExpression::right_shift(expr, by) ))), } } - e => fold_uint_expression_inner(self, bitwidth, e), + e => fold_uint_expression_cases(self, bitwidth, e), } } } diff --git a/zokrates_analysis/src/flat_propagation.rs b/zokrates_analysis/src/flat_propagation.rs index eb2ca797a..547712ff3 100644 --- a/zokrates_analysis/src/flat_propagation.rs +++ b/zokrates_analysis/src/flat_propagation.rs @@ -5,6 +5,9 @@ //! @date 2018 use std::collections::HashMap; +use std::ops::*; +use zokrates_ast::common::expressions::IdentifierOrExpression; +use zokrates_ast::common::WithSpan; use zokrates_ast::flat::folder::*; use zokrates_ast::flat::*; use zokrates_field::Field; @@ -17,49 +20,62 @@ struct Propagator { impl<'ast, T: Field> Folder<'ast, T> for Propagator { fn fold_statement(&mut self, s: FlatStatement<'ast, T>) -> Vec> { match s { - FlatStatement::Definition(var, expr) => match self.fold_expression(expr) { - FlatExpression::Number(n) => { - self.constants.insert(var, n); + FlatStatement::Definition(s) => match self.fold_expression(s.rhs) { + FlatExpression::Value(n) => { + self.constants.insert(s.assignee, n.value); vec![] } - e => vec![FlatStatement::Definition(var, e)], + e => vec![FlatStatement::definition(s.assignee, e)], }, s => fold_statement(self, s), } } + fn fold_identifier_expression( + &mut self, + e: zokrates_ast::common::expressions::IdentifierExpression>, + ) -> IdentifierOrExpression, FlatExpression> { + match self.constants.get(&e.id) { + Some(c) => IdentifierOrExpression::Expression(FlatExpression::value(*c)), + None => IdentifierOrExpression::Identifier(e), + } + } + fn fold_expression(&mut self, e: FlatExpression) -> FlatExpression { + let span = e.get_span(); + match e { - FlatExpression::Number(n) => FlatExpression::Number(n), - FlatExpression::Identifier(id) => match self.constants.get(&id) { - Some(c) => FlatExpression::Number(*c), - None => FlatExpression::Identifier(id), - }, - FlatExpression::Add(box e1, box e2) => { - match (self.fold_expression(e1), self.fold_expression(e2)) { - (FlatExpression::Number(n1), FlatExpression::Number(n2)) => { - FlatExpression::Number(n1 + n2) - } - (e1, e2) => FlatExpression::Add(box e1, box e2), + FlatExpression::Value(n) => FlatExpression::Value(n), + FlatExpression::Add(e) => match ( + self.fold_expression(*e.left), + self.fold_expression(*e.right), + ) { + (FlatExpression::Value(n1), FlatExpression::Value(n2)) => { + FlatExpression::value(n1.value + n2.value) } - } - FlatExpression::Sub(box e1, box e2) => { - match (self.fold_expression(e1), self.fold_expression(e2)) { - (FlatExpression::Number(n1), FlatExpression::Number(n2)) => { - FlatExpression::Number(n1 - n2) - } - (e1, e2) => FlatExpression::Sub(box e1, box e2), + (e1, e2) => FlatExpression::add(e1, e2), + }, + FlatExpression::Sub(e) => match ( + self.fold_expression(*e.left), + self.fold_expression(*e.right), + ) { + (FlatExpression::Value(n1), FlatExpression::Value(n2)) => { + FlatExpression::value(n1.value - n2.value) } - } - FlatExpression::Mult(box e1, box e2) => { - match (self.fold_expression(e1), self.fold_expression(e2)) { - (FlatExpression::Number(n1), FlatExpression::Number(n2)) => { - FlatExpression::Number(n1 * n2) - } - (e1, e2) => FlatExpression::Mult(box e1, box e2), + (e1, e2) => FlatExpression::sub(e1, e2), + }, + FlatExpression::Mult(e) => match ( + self.fold_expression(*e.left), + self.fold_expression(*e.right), + ) { + (FlatExpression::Value(n1), FlatExpression::Value(n2)) => { + FlatExpression::value(n1.value * n2.value) } - } + (e1, e2) => FlatExpression::mul(e1, e2), + }, + e => fold_expression(self, e), } + .span(span) } } @@ -80,14 +96,14 @@ mod tests { fn add() { let mut propagator = Propagator::default(); - let e = FlatExpression::Add( - box FlatExpression::Number(Bn128Field::from(2)), - box FlatExpression::Number(Bn128Field::from(3)), + let e = FlatExpression::add( + FlatExpression::value(Bn128Field::from(2)), + FlatExpression::value(Bn128Field::from(3)), ); assert_eq!( propagator.fold_expression(e), - FlatExpression::Number(Bn128Field::from(5)) + FlatExpression::value(Bn128Field::from(5)) ); } @@ -95,14 +111,14 @@ mod tests { fn sub() { let mut propagator = Propagator::default(); - let e = FlatExpression::Sub( - box FlatExpression::Number(Bn128Field::from(3)), - box FlatExpression::Number(Bn128Field::from(2)), + let e = FlatExpression::sub( + FlatExpression::value(Bn128Field::from(3)), + FlatExpression::value(Bn128Field::from(2)), ); assert_eq!( propagator.fold_expression(e), - FlatExpression::Number(Bn128Field::from(1)) + FlatExpression::value(Bn128Field::from(1)) ); } @@ -110,14 +126,14 @@ mod tests { fn mult() { let mut propagator = Propagator::default(); - let e = FlatExpression::Mult( - box FlatExpression::Number(Bn128Field::from(3)), - box FlatExpression::Number(Bn128Field::from(2)), + let e = FlatExpression::mul( + FlatExpression::value(Bn128Field::from(3)), + FlatExpression::value(Bn128Field::from(2)), ); assert_eq!( propagator.fold_expression(e), - FlatExpression::Number(Bn128Field::from(6)) + FlatExpression::value(Bn128Field::from(6)) ); } } diff --git a/zokrates_analysis/src/flatten_complex_types.rs b/zokrates_analysis/src/flatten_complex_types.rs index 578cd8ee2..20cfa9436 100644 --- a/zokrates_analysis/src/flatten_complex_types.rs +++ b/zokrates_analysis/src/flatten_complex_types.rs @@ -1,10 +1,15 @@ use std::collections::BTreeMap; use std::convert::{TryFrom, TryInto}; use std::marker::PhantomData; +use std::ops::*; +use zokrates_ast::common::expressions::{BinaryExpression, ValueExpression}; +use zokrates_ast::common::operators::OpEq; +use zokrates_ast::common::statements::LogStatement; +use zokrates_ast::common::{Span, WithSpan}; use zokrates_ast::typed::types::{ConcreteArrayType, IntoType, UBitwidth}; -use zokrates_ast::typed::{self, Expr, Typed}; -use zokrates_ast::zir::IntoType as ZirIntoType; -use zokrates_ast::zir::{self, Folder, Id, Select}; +use zokrates_ast::typed::{self, Basic, Expr, Typed}; +use zokrates_ast::zir::{self, Expr as ZirExpr, Folder, Id, MultipleDefinitionStatement, Select}; +use zokrates_ast::zir::{IntoType as ZirIntoType, SourceIdentifier}; use zokrates_field::Field; #[derive(Default)] @@ -18,22 +23,24 @@ fn flatten_identifier_rec<'ast>( ) -> Vec> { match ty { typed::ConcreteType::Int => unreachable!(), - typed::ConcreteType::FieldElement => vec![zir::Variable { - id: zir::Identifier::Source(id), - _type: zir::Type::FieldElement, - }], - typed::types::ConcreteType::Boolean => vec![zir::Variable { - id: zir::Identifier::Source(id), - _type: zir::Type::Boolean, - }], - typed::types::ConcreteType::Uint(bitwidth) => vec![zir::Variable { - id: zir::Identifier::Source(id), - _type: zir::Type::uint(bitwidth.to_usize()), - }], + typed::ConcreteType::FieldElement => vec![zir::Variable::new( + zir::Identifier::Source(id), + zir::Type::FieldElement, + )], + typed::types::ConcreteType::Boolean => vec![zir::Variable::new( + zir::Identifier::Source(id), + zir::Type::Boolean, + )], + typed::types::ConcreteType::Uint(bitwidth) => { + vec![zir::Variable::new( + zir::Identifier::Source(id), + zir::Type::uint(bitwidth.to_usize()), + )] + } typed::types::ConcreteType::Array(array_type) => (0..*array_type.size) .flat_map(|i| { flatten_identifier_rec( - zir::SourceIdentifier::Select(box id.clone(), i), + SourceIdentifier::Select(Box::new(id.clone()), i), &array_type.ty, ) }) @@ -42,7 +49,7 @@ fn flatten_identifier_rec<'ast>( .iter() .flat_map(|struct_member| { flatten_identifier_rec( - zir::SourceIdentifier::Member(box id.clone(), struct_member.id.clone()), + SourceIdentifier::Member(Box::new(id.clone()), struct_member.id.clone()), &struct_member.ty, ) }) @@ -52,7 +59,10 @@ fn flatten_identifier_rec<'ast>( .iter() .enumerate() .flat_map(|(i, ty)| { - flatten_identifier_rec(zir::SourceIdentifier::Element(box id.clone(), i as u32), ty) + flatten_identifier_rec( + SourceIdentifier::Element(Box::new(id.clone()), i as u32), + ty, + ) }) .collect(), } @@ -78,7 +88,7 @@ fn flatten_identifier_to_expression_rec<'ast, T: Field>( typed::ConcreteType::Array(array_type) => (0..*array_type.size) .flat_map(|i| { flatten_identifier_to_expression_rec( - zir::SourceIdentifier::Select(box id.clone(), i), + SourceIdentifier::Select(box id.clone(), i), &array_type.ty, ) }) @@ -87,7 +97,7 @@ fn flatten_identifier_to_expression_rec<'ast, T: Field>( .iter() .flat_map(|struct_member| { flatten_identifier_to_expression_rec( - zir::SourceIdentifier::Member(box id.clone(), struct_member.id.clone()), + SourceIdentifier::Member(box id.clone(), struct_member.id.clone()), &struct_member.ty, ) }) @@ -98,7 +108,7 @@ fn flatten_identifier_to_expression_rec<'ast, T: Field>( .enumerate() .flat_map(|(i, ty)| { flatten_identifier_to_expression_rec( - zir::SourceIdentifier::Element(box id.clone(), i as u32), + SourceIdentifier::Element(box id.clone(), i as u32), ty, ) }) @@ -192,24 +202,30 @@ impl<'ast, T: Field> Flattener { &mut self, p: typed::DeclarationParameter<'ast, T>, ) -> Vec> { + let span = p.get_span(); + let private = p.private; self.fold_variable(zokrates_ast::typed::variable::try_from_g_variable(p.id).unwrap()) .into_iter() - .map(|v| zir::Parameter { id: v, private }) + .map(|v| zir::Parameter::new(v, private).span(span)) .collect() } fn fold_name(&mut self, n: typed::Identifier<'ast>) -> zir::SourceIdentifier<'ast> { - zir::SourceIdentifier::Basic(n) + SourceIdentifier::Basic(n) } fn fold_variable(&mut self, v: typed::Variable<'ast, T>) -> Vec> { + let span = v.get_span(); let ty = v.get_type(); let id = self.fold_name(v.id); let ty = typed::types::ConcreteType::try_from(ty).unwrap(); flatten_identifier_rec(id, &ty) + .into_iter() + .map(|v| v.span(span)) + .collect() } fn fold_assignee(&mut self, a: typed::TypedAssignee<'ast, T>) -> Vec> { @@ -224,7 +240,7 @@ impl<'ast, T: Field> Flattener { match i.as_inner() { typed::UExpressionInner::Value(index) => { - a[*index as usize * count..(*index as usize + 1) * count].to_vec() + a[index.value as usize * count..(index.value as usize + 1) * count].to_vec() } i => unreachable!("index {:?} not allowed, should be a constant", i), } @@ -366,6 +382,20 @@ impl<'ast, T: Field> Flattener { fold_conditional_expression(self, statements_buffer, c) } + fn fold_binary_expression< + L: Flatten<'ast, T> + typed::Expr<'ast, T> + Basic<'ast, T>, + R: Flatten<'ast, T> + typed::Expr<'ast, T> + Basic<'ast, T>, + E: Flatten<'ast, T> + typed::Expr<'ast, T> + Basic<'ast, T>, + Op, + >( + &mut self, + statements_buffer: &mut Vec>, + e: BinaryExpression, + ) -> BinaryExpression + { + fold_binary_expression(self, statements_buffer, e) + } + fn fold_member_expression( &mut self, statements_buffer: &mut Vec>, @@ -393,7 +423,7 @@ impl<'ast, T: Field> Flattener { fn fold_eq_expression>( &mut self, statements_buffer: &mut Vec>, - eq: typed::EqExpression, + eq: BinaryExpression>, ) -> zir::BooleanExpression<'ast, T> { fold_eq_expression(self, statements_buffer, eq) } @@ -488,22 +518,23 @@ pub struct ArgumentFinder<'ast, T> { impl<'ast, T: Field> Folder<'ast, T> for ArgumentFinder<'ast, T> { fn fold_statement(&mut self, s: zir::ZirStatement<'ast, T>) -> Vec> { match s { - zir::ZirStatement::Definition(assignee, expr) => { - let assignee = self.fold_assignee(assignee); - let expr = self.fold_expression(expr); + zir::ZirStatement::Definition(s) => { + let assignee = self.fold_assignee(s.assignee); + let expr = self.fold_expression(s.rhs); self.identifiers.remove(&assignee.id); - vec![zir::ZirStatement::Definition(assignee, expr)] + vec![zir::ZirStatement::definition(assignee, expr)] } - zir::ZirStatement::MultipleDefinition(assignees, list) => { - let assignees: Vec> = assignees + zir::ZirStatement::MultipleDefinition(s) => { + let assignees: Vec> = s + .assignees .into_iter() .map(|v| self.fold_assignee(v)) .collect(); - let list = self.fold_expression_list(list); + let list = self.fold_expression_list(s.rhs); for a in &assignees { self.identifiers.remove(&a.id); } - vec![zir::ZirStatement::MultipleDefinition(assignees, list)] + vec![zir::ZirStatement::multiple_definition(assignees, list)] } s => zir::folder::fold_statement(self, s), } @@ -525,12 +556,14 @@ fn fold_assembly_statement<'ast, T: Field>( statements_buffer: &mut Vec>, s: typed::TypedAssemblyStatement<'ast, T>, ) -> zir::ZirAssemblyStatement<'ast, T> { + let span = s.get_span(); + match s { - typed::TypedAssemblyStatement::Assignment(a, e) => { + typed::TypedAssemblyStatement::Assignment(s) => { let mut statements_buffer: Vec> = vec![]; - let a = f.fold_assignee(a); - let e = f.fold_expression(&mut statements_buffer, e); - statements_buffer.push(zir::ZirStatement::Return(e)); + let a = f.fold_assignee(s.assignee); + let e = f.fold_expression(&mut statements_buffer, s.expression); + statements_buffer.push(zir::ZirStatement::ret(e)); let mut finder = ArgumentFinder::default(); let mut statements_buffer: Vec> = statements_buffer @@ -547,22 +580,22 @@ fn fold_assembly_statement<'ast, T: Field>( arguments: finder .identifiers .into_iter() - .map(|(id, ty)| zir::Parameter { - id: zir::Variable::with_id_and_type(id, ty), - private: true, + .map(|(id, ty)| { + zir::Parameter::private(zir::Variable::with_id_and_type(id, ty)) }) .collect(), statements: statements_buffer, }; - zir::ZirAssemblyStatement::Assignment(a, function) + zir::ZirAssemblyStatement::assignment(a, function) } - typed::TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { - let lhs = f.fold_field_expression(statements_buffer, lhs); - let rhs = f.fold_field_expression(statements_buffer, rhs); - zir::ZirAssemblyStatement::Constraint(lhs, rhs, metadata) + typed::TypedAssemblyStatement::Constraint(s) => { + let lhs = f.fold_field_expression(statements_buffer, s.left); + let rhs = f.fold_field_expression(statements_buffer, s.right); + zir::ZirAssemblyStatement::constraint(lhs, rhs, s.metadata) } } + .span(span) } fn fold_statement<'ast, T: Field>( @@ -570,57 +603,68 @@ fn fold_statement<'ast, T: Field>( statements_buffer: &mut Vec>, s: typed::TypedStatement<'ast, T>, ) { + let span = s.get_span(); + let res = match s { - typed::TypedStatement::Assembly(statements) => { - let statements = statements + typed::TypedStatement::Return(s) => vec![zir::ZirStatement::ret( + f.fold_expression(statements_buffer, s.inner), + )], + typed::TypedStatement::Assembly(s) => { + let statements = s + .inner .into_iter() .map(|s| f.fold_assembly_statement(statements_buffer, s)) .collect(); - vec![zir::ZirStatement::Assembly(statements)] + vec![zir::ZirStatement::assembly(statements)] } - typed::TypedStatement::Return(expression) => vec![zir::ZirStatement::Return( - f.fold_expression(statements_buffer, expression), - )], - typed::TypedStatement::Definition(a, typed::DefinitionRhs::Expression(e)) => { + typed::TypedStatement::Definition(typed::DefinitionStatement { + assignee: a, + rhs: typed::DefinitionRhs::Expression(e), + .. + }) => { let a = f.fold_assignee(a); let e = f.fold_expression(statements_buffer, e); assert_eq!(a.len(), e.len()); a.into_iter() .zip(e.into_iter()) - .map(|(a, e)| zir::ZirStatement::Definition(a, e)) + .map(|(a, e)| zir::ZirStatement::definition(a, e)) .collect() } - typed::TypedStatement::Assertion(e, error) => { - let e = f.fold_boolean_expression(statements_buffer, e); - let error = match error { + typed::TypedStatement::Assertion(s) => { + let e = f.fold_boolean_expression(statements_buffer, s.expression); + let error = match s.error { typed::RuntimeError::SourceAssertion(metadata) => { zir::RuntimeError::SourceAssertion(metadata) } typed::RuntimeError::SelectRangeCheck => zir::RuntimeError::SelectRangeCheck, typed::RuntimeError::DivisionByZero => zir::RuntimeError::DivisionByZero, }; - vec![zir::ZirStatement::Assertion(e, error)] + vec![zir::ZirStatement::assertion(e, error)] } - typed::TypedStatement::Definition( - assignee, - typed::DefinitionRhs::EmbedCall(embed_call), - ) => { + typed::TypedStatement::Definition(typed::DefinitionStatement { + assignee: a, + rhs: typed::DefinitionRhs::EmbedCall(embed_call), + .. + }) => { vec![zir::ZirStatement::MultipleDefinition( - f.fold_assignee(assignee), - zir::ZirExpressionList::EmbedCall( - embed_call.embed, - embed_call.generics, - embed_call - .arguments - .into_iter() - .flat_map(|a| f.fold_expression(statements_buffer, a)) - .collect(), + MultipleDefinitionStatement::new( + f.fold_assignee(a), + zir::ZirExpressionList::EmbedCall( + embed_call.embed, + embed_call.generics, + embed_call + .arguments + .into_iter() + .flat_map(|a| f.fold_expression(statements_buffer, a)) + .collect(), + ), ), )] } - typed::TypedStatement::Log(l, e) => vec![zir::ZirStatement::Log( - l, - e.into_iter() + typed::TypedStatement::Log(e) => vec![zir::ZirStatement::Log(LogStatement::new( + e.format_string, + e.expressions + .into_iter() .map(|e| { ( e.get_type().try_into().unwrap(), @@ -628,11 +672,11 @@ fn fold_statement<'ast, T: Field>( ) }) .collect(), - )], + ))], typed::TypedStatement::For(..) => unreachable!(), }; - statements_buffer.extend(res); + statements_buffer.extend(res.into_iter().map(|s| s.span(span))); } fn fold_array_expression_inner<'ast, T: Field>( @@ -671,31 +715,32 @@ fn fold_array_expression_inner<'ast, T: Field>( typed::ArrayExpressionInner::Select(select) => { f.fold_select_expression(statements_buffer, select) } - typed::ArrayExpressionInner::Slice(box array, box from, box to) => { - let array = f.fold_array_expression(statements_buffer, array); - let from = f.fold_uint_expression(statements_buffer, from); - let to = f.fold_uint_expression(statements_buffer, to); + typed::ArrayExpressionInner::Slice(e) => { + let array = f.fold_array_expression(statements_buffer, *e.array); + let from = f.fold_uint_expression(statements_buffer, *e.from); + let to = f.fold_uint_expression(statements_buffer, *e.to); match (from.into_inner(), to.into_inner()) { (zir::UExpressionInner::Value(from), zir::UExpressionInner::Value(to)) => { - assert_eq!(size, to.saturating_sub(from) as u32); + assert_eq!(size, to.value.saturating_sub(from.value) as u32); let element_size = ty.get_primitive_count(); - let start = from as usize * element_size; - let end = to as usize * element_size; + let start = from.value as usize * element_size; + let end = to.value as usize * element_size; array[start..end].to_vec() } _ => unreachable!(), } } - typed::ArrayExpressionInner::Repeat(box e, box count) => { - let e = f.fold_expression(statements_buffer, e); - let count = f.fold_uint_expression(statements_buffer, count); + typed::ArrayExpressionInner::Repeat(r) => { + let e = f.fold_expression(statements_buffer, *r.e); + let count = f.fold_uint_expression(statements_buffer, *r.count); match count.into_inner() { - zir::UExpressionInner::Value(count) => { - vec![e; count as usize].into_iter().flatten().collect() - } + zir::UExpressionInner::Value(count) => vec![e; count.value as usize] + .into_iter() + .flatten() + .collect(), _ => unreachable!(), } } @@ -860,7 +905,7 @@ fn fold_select_expression<'ast, T: Field, E>( match index.as_inner() { zir::UExpressionInner::Value(v) => { - let v = *v as usize; + let v = v.value as usize; array[v * size..(v + 1) * size].to_vec() } @@ -925,6 +970,9 @@ fn fold_conditional_expression<'ast, T: Field, E: Flatten<'ast, T>>( statements_buffer: &mut Vec>, c: typed::ConditionalExpression<'ast, T, E>, ) -> Vec> { + let span = c.get_span(); + let condition_span = c.condition.get_span(); + let mut consequence_statements = vec![]; let mut alternative_statements = vec![]; @@ -935,11 +983,14 @@ fn fold_conditional_expression<'ast, T: Field, E: Flatten<'ast, T>>( assert_eq!(consequence.len(), alternative.len()); if !consequence_statements.is_empty() || !alternative_statements.is_empty() { - statements_buffer.push(zir::ZirStatement::IfElse( - condition.clone(), - consequence_statements, - alternative_statements, - )); + statements_buffer.push( + zir::ZirStatement::if_else( + condition.clone().span(condition_span), + consequence_statements, + alternative_statements, + ) + .span(span), + ); } use zokrates_ast::zir::Conditional; @@ -949,19 +1000,46 @@ fn fold_conditional_expression<'ast, T: Field, E: Flatten<'ast, T>>( .zip(alternative.into_iter()) .map(|(c, a)| match (c, a) { (zir::ZirExpression::FieldElement(c), zir::ZirExpression::FieldElement(a)) => { - zir::FieldElementExpression::conditional(condition.clone(), c, a).into() + zir::FieldElementExpression::conditional(condition.clone(), c, a) + .span(span) + .into() } (zir::ZirExpression::Boolean(c), zir::ZirExpression::Boolean(a)) => { - zir::BooleanExpression::conditional(condition.clone(), c, a).into() + zir::BooleanExpression::conditional(condition.clone(), c, a) + .span(span) + .into() } (zir::ZirExpression::Uint(c), zir::ZirExpression::Uint(a)) => { - zir::UExpression::conditional(condition.clone(), c, a).into() + zir::UExpression::conditional(condition.clone(), c, a) + .span(span) + .into() } _ => unreachable!(), }) .collect() } +fn fold_binary_expression< + 'ast, + T: Field, + L: Flatten<'ast, T> + typed::Expr<'ast, T> + Basic<'ast, T>, + R: Flatten<'ast, T> + typed::Expr<'ast, T> + Basic<'ast, T>, + E: Flatten<'ast, T> + typed::Expr<'ast, T> + Basic<'ast, T>, + Op, +>( + f: &mut Flattener, + statements_buffer: &mut Vec>, + e: BinaryExpression, +) -> BinaryExpression { + let left_span = e.left.get_span(); + let right_span = e.left.get_span(); + + let left: L::ZirExpressionType = e.left.flatten(f, statements_buffer).pop().unwrap().into(); + let right: R::ZirExpressionType = e.right.flatten(f, statements_buffer).pop().unwrap().into(); + + BinaryExpression::new(left.span(left_span), right.span(right_span)).span(e.span) +} + fn fold_identifier_expression<'ast, T: Field, E: Expr<'ast, T>>( f: &mut Flattener, ty: E::ConcreteTy, @@ -975,77 +1053,56 @@ fn fold_field_expression<'ast, T: Field>( statements_buffer: &mut Vec>, e: typed::FieldElementExpression<'ast, T>, ) -> zir::FieldElementExpression<'ast, T> { + let span = e.get_span(); + match e { - typed::FieldElementExpression::Number(n) => zir::FieldElementExpression::Number(n), + typed::FieldElementExpression::Value(n) => zir::FieldElementExpression::Value(n), typed::FieldElementExpression::Identifier(id) => f .fold_identifier_expression(typed::ConcreteType::FieldElement, id) .pop() .unwrap() .try_into() .unwrap(), - typed::FieldElementExpression::Add(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_field_expression(statements_buffer, e2); - zir::FieldElementExpression::Add(box e1, box e2) - } - typed::FieldElementExpression::Sub(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_field_expression(statements_buffer, e2); - zir::FieldElementExpression::Sub(box e1, box e2) - } - typed::FieldElementExpression::Mult(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_field_expression(statements_buffer, e2); - zir::FieldElementExpression::Mult(box e1, box e2) - } - typed::FieldElementExpression::Div(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_field_expression(statements_buffer, e2); - zir::FieldElementExpression::Div(box e1, box e2) - } - typed::FieldElementExpression::Pow(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_uint_expression(statements_buffer, e2); - zir::FieldElementExpression::Pow(box e1, box e2) - } - typed::FieldElementExpression::Neg(box e) => { - let e = f.fold_field_expression(statements_buffer, e); - - zir::FieldElementExpression::Sub( - box zir::FieldElementExpression::Number(T::zero()), - box e, - ) + typed::FieldElementExpression::Add(e) => { + zir::FieldElementExpression::Add(f.fold_binary_expression(statements_buffer, e)) } - typed::FieldElementExpression::Pos(box e) => f.fold_field_expression(statements_buffer, e), - typed::FieldElementExpression::Xor(box left, box right) => { - let left = f.fold_field_expression(statements_buffer, left); - let right = f.fold_field_expression(statements_buffer, right); - - zir::FieldElementExpression::Xor(box left, box right) + typed::FieldElementExpression::Sub(e) => { + zir::FieldElementExpression::Sub(f.fold_binary_expression(statements_buffer, e)) } - typed::FieldElementExpression::And(box left, box right) => { - let left = f.fold_field_expression(statements_buffer, left); - let right = f.fold_field_expression(statements_buffer, right); - - zir::FieldElementExpression::And(box left, box right) + typed::FieldElementExpression::Mult(e) => { + zir::FieldElementExpression::Mult(f.fold_binary_expression(statements_buffer, e)) } - typed::FieldElementExpression::Or(box left, box right) => { - let left = f.fold_field_expression(statements_buffer, left); - let right = f.fold_field_expression(statements_buffer, right); - - zir::FieldElementExpression::Or(box left, box right) + typed::FieldElementExpression::Div(e) => { + zir::FieldElementExpression::Div(f.fold_binary_expression(statements_buffer, e)) } - typed::FieldElementExpression::LeftShift(box e, box by) => { - let e = f.fold_field_expression(statements_buffer, e); - let by = f.fold_uint_expression(statements_buffer, by); - - zir::FieldElementExpression::LeftShift(box e, box by) + typed::FieldElementExpression::Pow(e) => { + zir::FieldElementExpression::Pow(f.fold_binary_expression(statements_buffer, e)) } - typed::FieldElementExpression::RightShift(box e, box by) => { - let e = f.fold_field_expression(statements_buffer, e); - let by = f.fold_uint_expression(statements_buffer, by); + typed::FieldElementExpression::And(e) => { + zir::FieldElementExpression::And(f.fold_binary_expression(statements_buffer, e)) + } + typed::FieldElementExpression::Or(e) => { + zir::FieldElementExpression::Or(f.fold_binary_expression(statements_buffer, e)) + } + typed::FieldElementExpression::Xor(e) => { + zir::FieldElementExpression::Xor(f.fold_binary_expression(statements_buffer, e)) + } + typed::FieldElementExpression::LeftShift(e) => { + zir::FieldElementExpression::LeftShift(f.fold_binary_expression(statements_buffer, e)) + } + typed::FieldElementExpression::RightShift(e) => { + zir::FieldElementExpression::RightShift(f.fold_binary_expression(statements_buffer, e)) + } + typed::FieldElementExpression::Neg(e) => { + let e = f.fold_field_expression(statements_buffer, *e.inner); - zir::FieldElementExpression::RightShift(box e, box by) + zir::FieldElementExpression::sub( + zir::FieldElementExpression::Value(ValueExpression::new(T::zero())), + e, + ) + } + typed::FieldElementExpression::Pos(e) => { + f.fold_field_expression(statements_buffer, *e.inner) } typed::FieldElementExpression::Conditional(c) => f .fold_conditional_expression(statements_buffer, c) @@ -1080,6 +1137,7 @@ fn fold_field_expression<'ast, T: Field>( f.fold_field_expression(statements_buffer, *block.value) } } + .span(span) } // util function to output a boolean expression representing the equality of two lists of ZirExpression. @@ -1088,27 +1146,32 @@ fn fold_field_expression<'ast, T: Field>( fn conjunction_tree<'ast, T: Field>( v: &[zir::ZirExpression<'ast, T>], w: &[zir::ZirExpression<'ast, T>], + span: Option, ) -> zir::BooleanExpression<'ast, T> { assert_eq!(v.len(), w.len()); match v.len() { - 0 => zir::BooleanExpression::Value(true), + 0 => zir::BooleanExpression::value(true), 1 => match (v[0].clone(), w[0].clone()) { (zir::ZirExpression::Boolean(v), zir::ZirExpression::Boolean(w)) => { - zir::BooleanExpression::BoolEq(box v, box w) + zir::BooleanExpression::bool_eq(v, w).span(span) } (zir::ZirExpression::FieldElement(v), zir::ZirExpression::FieldElement(w)) => { - zir::BooleanExpression::FieldEq(box v, box w) + zir::BooleanExpression::field_eq(v, w).span(span) } (zir::ZirExpression::Uint(v), zir::ZirExpression::Uint(w)) => { - zir::BooleanExpression::UintEq(box v, box w) + zir::BooleanExpression::uint_eq(v, w).span(span) } _ => unreachable!(), }, n => { let (x0, y0) = v.split_at(n / 2); let (x1, y1) = w.split_at(n / 2); - zir::BooleanExpression::And(box conjunction_tree(x0, x1), box conjunction_tree(y0, y1)) + zir::BooleanExpression::bitand( + conjunction_tree(x0, x1, span), + conjunction_tree(y0, y1, span), + ) + .span(span) } } } @@ -1116,11 +1179,18 @@ fn conjunction_tree<'ast, T: Field>( fn fold_eq_expression<'ast, T: Field, E: Flatten<'ast, T>>( f: &mut Flattener, statements_buffer: &mut Vec>, - e: typed::EqExpression, + e: zokrates_ast::common::expressions::BinaryExpression< + OpEq, + E, + E, + typed::BooleanExpression<'ast, T>, + >, ) -> zir::BooleanExpression<'ast, T> { + let span = e.get_span(); + let left = e.left.flatten(f, statements_buffer); let right = e.right.flatten(f, statements_buffer); - conjunction_tree(&left, &right) + conjunction_tree(&left, &right, span) } fn fold_boolean_expression<'ast, T: Field>( @@ -1128,6 +1198,8 @@ fn fold_boolean_expression<'ast, T: Field>( statements_buffer: &mut Vec>, e: typed::BooleanExpression<'ast, T>, ) -> zir::BooleanExpression<'ast, T> { + let span = e.get_span(); + match e { typed::BooleanExpression::Block(block) => { block @@ -1149,59 +1221,26 @@ fn fold_boolean_expression<'ast, T: Field>( typed::BooleanExpression::StructEq(e) => f.fold_eq_expression(statements_buffer, e), typed::BooleanExpression::TupleEq(e) => f.fold_eq_expression(statements_buffer, e), typed::BooleanExpression::UintEq(e) => f.fold_eq_expression(statements_buffer, e), - typed::BooleanExpression::FieldLt(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_field_expression(statements_buffer, e2); - zir::BooleanExpression::FieldLt(box e1, box e2) - } - typed::BooleanExpression::FieldLe(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_field_expression(statements_buffer, e2); - zir::BooleanExpression::FieldLe(box e1, box e2) - } - typed::BooleanExpression::FieldGt(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_field_expression(statements_buffer, e2); - zir::BooleanExpression::FieldLt(box e2, box e1) - } - typed::BooleanExpression::FieldGe(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_field_expression(statements_buffer, e2); - zir::BooleanExpression::FieldLe(box e2, box e1) - } - typed::BooleanExpression::UintLt(box e1, box e2) => { - let e1 = f.fold_uint_expression(statements_buffer, e1); - let e2 = f.fold_uint_expression(statements_buffer, e2); - zir::BooleanExpression::UintLt(box e1, box e2) - } - typed::BooleanExpression::UintLe(box e1, box e2) => { - let e1 = f.fold_uint_expression(statements_buffer, e1); - let e2 = f.fold_uint_expression(statements_buffer, e2); - zir::BooleanExpression::UintLe(box e1, box e2) - } - typed::BooleanExpression::UintGt(box e1, box e2) => { - let e1 = f.fold_uint_expression(statements_buffer, e1); - let e2 = f.fold_uint_expression(statements_buffer, e2); - zir::BooleanExpression::UintLt(box e2, box e1) - } - typed::BooleanExpression::UintGe(box e1, box e2) => { - let e1 = f.fold_uint_expression(statements_buffer, e1); - let e2 = f.fold_uint_expression(statements_buffer, e2); - zir::BooleanExpression::UintLe(box e2, box e1) - } - typed::BooleanExpression::Or(box e1, box e2) => { - let e1 = f.fold_boolean_expression(statements_buffer, e1); - let e2 = f.fold_boolean_expression(statements_buffer, e2); - zir::BooleanExpression::Or(box e1, box e2) - } - typed::BooleanExpression::And(box e1, box e2) => { - let e1 = f.fold_boolean_expression(statements_buffer, e1); - let e2 = f.fold_boolean_expression(statements_buffer, e2); - zir::BooleanExpression::And(box e1, box e2) - } - typed::BooleanExpression::Not(box e) => { - let e = f.fold_boolean_expression(statements_buffer, e); - zir::BooleanExpression::Not(box e) + typed::BooleanExpression::FieldLt(e) => { + zir::BooleanExpression::FieldLt(f.fold_binary_expression(statements_buffer, e)) + } + typed::BooleanExpression::FieldLe(e) => { + zir::BooleanExpression::FieldLe(f.fold_binary_expression(statements_buffer, e)) + } + typed::BooleanExpression::UintLt(e) => { + zir::BooleanExpression::UintLt(f.fold_binary_expression(statements_buffer, e)) + } + typed::BooleanExpression::UintLe(e) => { + zir::BooleanExpression::UintLe(f.fold_binary_expression(statements_buffer, e)) + } + typed::BooleanExpression::Or(e) => { + zir::BooleanExpression::Or(f.fold_binary_expression(statements_buffer, e)) + } + typed::BooleanExpression::And(e) => { + zir::BooleanExpression::And(f.fold_binary_expression(statements_buffer, e)) + } + typed::BooleanExpression::Not(e) => { + zir::BooleanExpression::not(f.fold_boolean_expression(statements_buffer, *e.inner)) } typed::BooleanExpression::Conditional(c) => f .fold_conditional_expression(statements_buffer, c) @@ -1229,6 +1268,7 @@ fn fold_boolean_expression<'ast, T: Field>( .try_into() .unwrap(), } + .span(span) } fn fold_uint_expression<'ast, T: Field>( @@ -1236,8 +1276,10 @@ fn fold_uint_expression<'ast, T: Field>( statements_buffer: &mut Vec>, e: typed::UExpression<'ast, T>, ) -> zir::UExpression<'ast, T> { + let span = e.get_span(); f.fold_uint_expression_inner(statements_buffer, e.bitwidth, e.inner) .annotate(e.bitwidth.to_usize()) + .span(span) } fn fold_uint_expression_inner<'ast, T: Field>( @@ -1246,6 +1288,8 @@ fn fold_uint_expression_inner<'ast, T: Field>( bitwidth: UBitwidth, e: typed::UExpressionInner<'ast, T>, ) -> zir::UExpressionInner<'ast, T> { + let span = e.get_span(); + match e { typed::UExpressionInner::Block(block) => { block @@ -1261,95 +1305,83 @@ fn fold_uint_expression_inner<'ast, T: Field>( .unwrap() .into_inner() } - typed::UExpressionInner::Add(box left, box right) => { - let left = f.fold_uint_expression(statements_buffer, left); - let right = f.fold_uint_expression(statements_buffer, right); + typed::UExpressionInner::Add(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::Add(box left, box right) + zir::UExpression::add(left, right).into_inner() } - typed::UExpressionInner::Sub(box left, box right) => { - let left = f.fold_uint_expression(statements_buffer, left); - let right = f.fold_uint_expression(statements_buffer, right); + typed::UExpressionInner::Sub(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::Sub(box left, box right) + zir::UExpression::sub(left, right).into_inner() } typed::UExpressionInner::FloorSub(..) => unreachable!(), - typed::UExpressionInner::Mult(box left, box right) => { - let left = f.fold_uint_expression(statements_buffer, left); - let right = f.fold_uint_expression(statements_buffer, right); + typed::UExpressionInner::Mult(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::Mult(box left, box right) + zir::UExpression::mult(left, right).into_inner() } - typed::UExpressionInner::Div(box left, box right) => { - let left = f.fold_uint_expression(statements_buffer, left); - let right = f.fold_uint_expression(statements_buffer, right); + typed::UExpressionInner::Div(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::Div(box left, box right) + zir::UExpression::div(left, right).into_inner() } - typed::UExpressionInner::Rem(box left, box right) => { - let left = f.fold_uint_expression(statements_buffer, left); - let right = f.fold_uint_expression(statements_buffer, right); + typed::UExpressionInner::Rem(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::Rem(box left, box right) + zir::UExpression::rem(left, right).into_inner() } - typed::UExpressionInner::Xor(box left, box right) => { - let left = f.fold_uint_expression(statements_buffer, left); - let right = f.fold_uint_expression(statements_buffer, right); + typed::UExpressionInner::Xor(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::Xor(box left, box right) + zir::UExpression::xor(left, right).into_inner() } - typed::UExpressionInner::And(box left, box right) => { - let left = f.fold_uint_expression(statements_buffer, left); - let right = f.fold_uint_expression(statements_buffer, right); + typed::UExpressionInner::And(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::And(box left, box right) + zir::UExpression::and(left, right).into_inner() } - typed::UExpressionInner::Or(box left, box right) => { - let left = f.fold_uint_expression(statements_buffer, left); - let right = f.fold_uint_expression(statements_buffer, right); + typed::UExpressionInner::Or(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::Or(box left, box right) + zir::UExpression::or(left, right).into_inner() } - typed::UExpressionInner::LeftShift(box e, box by) => { - let e = f.fold_uint_expression(statements_buffer, e); - - let by = match by.as_inner() { - typed::UExpressionInner::Value(by) => by, - _ => unreachable!("static analysis should have made sure that this is constant"), - }; + typed::UExpressionInner::LeftShift(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::LeftShift(box e, *by as u32) + zir::UExpression::left_shift(left, right).into_inner() } - typed::UExpressionInner::RightShift(box e, box by) => { - let e = f.fold_uint_expression(statements_buffer, e); - - let by = match by.as_inner() { - typed::UExpressionInner::Value(by) => by, - _ => unreachable!("static analysis should have made sure that this is constant"), - }; + typed::UExpressionInner::RightShift(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::RightShift(box e, *by as u32) + zir::UExpression::right_shift(left, right).into_inner() } - typed::UExpressionInner::Not(box e) => { - let e = f.fold_uint_expression(statements_buffer, e); - - zir::UExpressionInner::Not(box e) + typed::UExpressionInner::Not(e) => { + zir::UExpression::not(f.fold_uint_expression(statements_buffer, *e.inner)).into_inner() } - typed::UExpressionInner::Neg(box e) => { - let bitwidth = e.bitwidth(); + typed::UExpressionInner::Neg(e) => { + let bitwidth = e.inner.bitwidth(); f.fold_uint_expression( statements_buffer, - typed::UExpressionInner::Value(0).annotate(bitwidth) - e, + typed::UExpression::value(0).annotate(bitwidth) - *e.inner, ) .into_inner() } - typed::UExpressionInner::Pos(box e) => { - let e = f.fold_uint_expression(statements_buffer, e); - - e.into_inner() - } + typed::UExpressionInner::Pos(e) => f + .fold_uint_expression(statements_buffer, *e.inner) + .into_inner(), typed::UExpressionInner::FunctionCall(..) => { unreachable!("function calls should have been removed") } @@ -1382,6 +1414,7 @@ fn fold_uint_expression_inner<'ast, T: Field>( .unwrap() .into_inner(), } + .span(span) } fn fold_function<'ast, T: Field>( @@ -1418,6 +1451,7 @@ fn fold_array_expression<'ast, T: Field>( statements_buffer: &mut Vec>, e: typed::ArrayExpression<'ast, T>, ) -> Vec> { + let span = e.get_span(); let size: u32 = e.size().try_into().unwrap(); f.fold_array_expression_inner( statements_buffer, @@ -1425,6 +1459,9 @@ fn fold_array_expression<'ast, T: Field>( size, e.into_inner(), ) + .into_iter() + .map(|e| e.span(span)) + .collect() } fn fold_struct_expression<'ast, T: Field>( @@ -1432,11 +1469,15 @@ fn fold_struct_expression<'ast, T: Field>( statements_buffer: &mut Vec>, e: typed::StructExpression<'ast, T>, ) -> Vec> { + let span = e.get_span(); f.fold_struct_expression_inner( statements_buffer, typed::types::ConcreteStructType::try_from(e.ty().clone()).unwrap(), e.into_inner(), ) + .into_iter() + .map(|e| e.span(span)) + .collect() } fn fold_tuple_expression<'ast, T: Field>( @@ -1444,11 +1485,15 @@ fn fold_tuple_expression<'ast, T: Field>( statements_buffer: &mut Vec>, e: typed::TupleExpression<'ast, T>, ) -> Vec> { + let span = e.get_span(); f.fold_tuple_expression_inner( statements_buffer, typed::types::ConcreteTupleType::try_from(e.ty().clone()).unwrap(), e.into_inner(), ) + .into_iter() + .map(|e| e.span(span)) + .collect() } fn fold_program<'ast, T: Field>( @@ -1469,5 +1514,6 @@ fn fold_program<'ast, T: Field>( zir::ZirProgram { main: f.fold_function(main_function), + module_map: p.module_map, } } diff --git a/zokrates_analysis/src/log_ignorer.rs b/zokrates_analysis/src/log_ignorer.rs index a5a8cd54e..2c8aa5a8c 100644 --- a/zokrates_analysis/src/log_ignorer.rs +++ b/zokrates_analysis/src/log_ignorer.rs @@ -1,4 +1,4 @@ -use zokrates_ast::typed::{folder::*, TypedProgram, TypedStatement}; +use zokrates_ast::typed::{folder::*, LogStatement, TypedProgram, TypedStatement}; use zokrates_field::Field; #[derive(Default)] @@ -11,10 +11,7 @@ impl LogIgnorer { } impl<'ast, T: Field> Folder<'ast, T> for LogIgnorer { - fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { - match s { - TypedStatement::Log(..) => vec![], - s => fold_statement(self, s), - } + fn fold_log_statement(&mut self, _: LogStatement<'ast, T>) -> Vec> { + vec![] } } diff --git a/zokrates_analysis/src/out_of_bounds.rs b/zokrates_analysis/src/out_of_bounds.rs index 7cdc88b6c..5be99af50 100644 --- a/zokrates_analysis/src/out_of_bounds.rs +++ b/zokrates_analysis/src/out_of_bounds.rs @@ -53,14 +53,14 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for OutOfBoundsChecker { let size = match array.get_type() { Type::Array(array_ty) => match array_ty.size.as_inner() { - UExpressionInner::Value(size) => *size, + UExpressionInner::Value(size) => size.value, _ => unreachable!(), }, _ => unreachable!(), }; match index.as_inner() { - UExpressionInner::Value(i) if i >= &size => Err(Error(format!( + UExpressionInner::Value(i) if i.value >= size => Err(Error(format!( "Out of bounds write to `{}` because `{}` has size {}", TypedAssignee::Select(box array.clone(), box index), array, diff --git a/zokrates_analysis/src/panic_extractor.rs b/zokrates_analysis/src/panic_extractor.rs index d17675a83..312abab05 100644 --- a/zokrates_analysis/src/panic_extractor.rs +++ b/zokrates_analysis/src/panic_extractor.rs @@ -1,6 +1,11 @@ -use zokrates_ast::zir::{ - folder::*, BooleanExpression, Conditional, ConditionalExpression, ConditionalOrExpression, - FieldElementExpression, RuntimeError, UBitwidth, UExpressionInner, ZirProgram, ZirStatement, +use std::ops::*; +use zokrates_ast::{ + common::{expressions::BinaryExpression, Fold, WithSpan}, + zir::{ + folder::*, BooleanExpression, Conditional, ConditionalExpression, ConditionalOrExpression, + Expr, FieldElementExpression, IfElseStatement, RuntimeError, UBitwidth, UExpression, + UExpressionInner, ZirProgram, ZirStatement, + }, }; use zokrates_field::Field; @@ -18,67 +23,85 @@ impl<'ast, T: Field> PanicExtractor<'ast, T> { } impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> { - fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec> { + fn fold_if_else_statement( + &mut self, + s: IfElseStatement<'ast, T>, + ) -> Vec> { + let condition = self.fold_boolean_expression(s.condition); + let mut consequence_extractor = Self::default(); + let consequence = s + .consequence + .into_iter() + .flat_map(|s| consequence_extractor.fold_statement(s)) + .collect(); + assert!(consequence_extractor.panic_buffer.is_empty()); + let mut alternative_extractor = Self::default(); + let alternative = s + .alternative + .into_iter() + .flat_map(|s| alternative_extractor.fold_statement(s)) + .collect(); + assert!(alternative_extractor.panic_buffer.is_empty()); + + self.panic_buffer + .drain(..) + .chain(std::iter::once(ZirStatement::if_else( + condition, + consequence, + alternative, + ))) + .collect() + } + + fn fold_statement_cases(&mut self, s: ZirStatement<'ast, T>) -> Vec> { match s { - ZirStatement::IfElse(condition, consequence, alternative) => { - let condition = self.fold_boolean_expression(condition); - let mut consequence_extractor = Self::default(); - let consequence = consequence - .into_iter() - .flat_map(|s| consequence_extractor.fold_statement(s)) - .collect(); - assert!(consequence_extractor.panic_buffer.is_empty()); - let mut alternative_extractor = Self::default(); - let alternative = alternative - .into_iter() - .flat_map(|s| alternative_extractor.fold_statement(s)) - .collect(); - assert!(alternative_extractor.panic_buffer.is_empty()); - - self.panic_buffer - .drain(..) - .chain(std::iter::once(ZirStatement::IfElse( - condition, - consequence, - alternative, - ))) - .collect() - } + ZirStatement::IfElse(s) => self.fold_if_else_statement(s), s => { - let s = fold_statement(self, s); + let s = fold_statement_cases(self, s); self.panic_buffer.drain(..).chain(s).collect() } } } - fn fold_field_expression( + fn fold_field_expression_cases( &mut self, e: FieldElementExpression<'ast, T>, ) -> FieldElementExpression<'ast, T> { + let span = e.get_span(); + match e { - FieldElementExpression::Div(box n, box d) => { - let n = self.fold_field_expression(n); - let d = self.fold_field_expression(d); - self.panic_buffer.push(ZirStatement::Assertion( - BooleanExpression::Not(box BooleanExpression::FieldEq( - box d.clone(), - box FieldElementExpression::Number(T::zero()), - )), - RuntimeError::DivisionByZero, - )); - FieldElementExpression::Div(box n, box d) + FieldElementExpression::Div(e) => { + let n = self.fold_field_expression(*e.left); + let d = self.fold_field_expression(*e.right); + self.panic_buffer.push( + ZirStatement::assertion( + BooleanExpression::not( + BooleanExpression::field_eq( + d.clone().span(span), + FieldElementExpression::value(T::zero()).span(span), + ) + .span(span), + ) + .span(span), + RuntimeError::DivisionByZero, + ) + .span(span), + ); + FieldElementExpression::div(n, d) } - e => fold_field_expression(self, e), + e => fold_field_expression_cases(self, e), } } fn fold_conditional_expression< - E: zokrates_ast::zir::Expr<'ast, T> + Fold<'ast, T> + Conditional<'ast, T>, + E: zokrates_ast::zir::Expr<'ast, T> + Fold + Conditional<'ast, T>, >( &mut self, _: &E::Ty, e: ConditionalExpression<'ast, T, E>, ) -> ConditionalOrExpression<'ast, T, E> { + let span = e.get_span(); + let condition = self.fold_boolean_expression(*e.condition); let mut consequence_extractor = Self::default(); let consequence = e.consequence.fold(&mut consequence_extractor); @@ -89,62 +112,74 @@ impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> { let alternative_panics: Vec<_> = alternative_extractor.panic_buffer.drain(..).collect(); if !(consequence_panics.is_empty() && alternative_panics.is_empty()) { - self.panic_buffer.push(ZirStatement::IfElse( - condition.clone(), - consequence_panics, - alternative_panics, - )); + self.panic_buffer.push( + ZirStatement::if_else(condition.clone(), consequence_panics, alternative_panics) + .span(span), + ); } - ConditionalOrExpression::Conditional(ConditionalExpression::new( - condition, - consequence, - alternative, - )) + ConditionalOrExpression::Conditional( + ConditionalExpression::new(condition, consequence, alternative).span(span), + ) } - fn fold_uint_expression_inner( + fn fold_uint_expression_cases( &mut self, b: UBitwidth, e: UExpressionInner<'ast, T>, ) -> UExpressionInner<'ast, T> { + let span = e.get_span(); + match e { - UExpressionInner::Div(box n, box d) => { - let n = self.fold_uint_expression(n); - let d = self.fold_uint_expression(d); - self.panic_buffer.push(ZirStatement::Assertion( - BooleanExpression::Not(box BooleanExpression::UintEq( - box d.clone(), - box UExpressionInner::Value(0).annotate(b), - )), - RuntimeError::DivisionByZero, - )); - UExpressionInner::Div(box n, box d) + UExpressionInner::Div(e) => { + let n = self.fold_uint_expression(*e.left); + let d = self.fold_uint_expression(*e.right); + self.panic_buffer.push( + ZirStatement::assertion( + BooleanExpression::not( + BooleanExpression::uint_eq( + d.clone().span(span), + UExpression::value(0).annotate(b).span(span), + ) + .span(span), + ) + .span(span), + RuntimeError::DivisionByZero, + ) + .span(span), + ); + UExpression::div(n, d).into_inner() } - e => fold_uint_expression_inner(self, b, e), + e => fold_uint_expression_cases(self, b, e), } } - fn fold_boolean_expression( + fn fold_boolean_expression_cases( &mut self, e: BooleanExpression<'ast, T>, ) -> BooleanExpression<'ast, T> { match e { // constant range checks are complete, so no panic needs to be extracted - e @ BooleanExpression::FieldLt(box FieldElementExpression::Number(_), _) - | e @ BooleanExpression::FieldLt(_, box FieldElementExpression::Number(_)) => { - fold_boolean_expression(self, e) - } - BooleanExpression::FieldLt(box left, box right) => { - let left = self.fold_field_expression(left); - let right = self.fold_field_expression(right); + e @ BooleanExpression::FieldLt(BinaryExpression { + left: box FieldElementExpression::Value(_), + .. + }) + | e @ BooleanExpression::FieldLt(BinaryExpression { + right: box FieldElementExpression::Value(_), + .. + }) => fold_boolean_expression_cases(self, e), + BooleanExpression::FieldLt(e) => { + let span = e.get_span(); + + let left = self.fold_field_expression(*e.left); + let right = self.fold_field_expression(*e.right); let bit_width = T::get_required_bits(); let safe_width = bit_width - 2; // dynamic comparison is not complete, it only applies to field elements whose difference is strictly smaller than 2**(bitwidth - 2) - let offset = FieldElementExpression::Number(T::from(2).pow(safe_width)); - let max = FieldElementExpression::Number(T::from(2).pow(safe_width + 1)); + let offset = FieldElementExpression::number(T::from(2).pow(safe_width)); + let max = FieldElementExpression::number(T::from(2).pow(safe_width + 1)); // `|left - right|` must be of bitwidth at most `safe_bitwidth` // this means we need to guarantee the following: `-2**(safe_width) < left - right < 2**(safe_width)` @@ -153,26 +188,40 @@ impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> { // we split this check in two: // `2**(safe_width) + left - right < 2**(safe_width + 1)` - self.panic_buffer.push(ZirStatement::Assertion( - BooleanExpression::FieldLt( - box FieldElementExpression::Add( - box offset.clone(), - box FieldElementExpression::Sub(box left.clone(), box right.clone()), - ), - box max, - ), - RuntimeError::IncompleteDynamicRange, - )); + self.panic_buffer.push( + ZirStatement::assertion( + BooleanExpression::field_lt( + offset.clone().span(span) + + FieldElementExpression::sub(left.clone(), right.clone()) + .span(span), + max, + ) + .span(span), + RuntimeError::IncompleteDynamicRange, + ) + .span(span), + ); // and // `2**(safe_width) + left - right != 0` - self.panic_buffer.push(ZirStatement::Assertion( - BooleanExpression::Not(box BooleanExpression::FieldEq( - box FieldElementExpression::Sub(box right.clone(), box left.clone()), - box offset, - )), - RuntimeError::IncompleteDynamicRange, - )); + self.panic_buffer.push( + ZirStatement::assertion( + BooleanExpression::not( + BooleanExpression::field_eq( + FieldElementExpression::sub( + right.clone().span(span), + left.clone().span(span), + ) + .span(span), + offset.span(span), + ) + .span(span), + ) + .span(span), + RuntimeError::IncompleteDynamicRange, + ) + .span(span), + ); // NOTE: // instead of splitting the check in two, we could have used a single `Lt` here, by simply subtracting 1 from all sides: @@ -182,9 +231,9 @@ impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> { // if we use `x - 1` here, we end up having to calculate the bits of both `x` and `x - 1`, which is expensive // by splitting, we can reuse the bits of `x` needed for this completeness check when computing the result - BooleanExpression::FieldLt(box left, box right) + BooleanExpression::field_lt(left, right) } - e => fold_boolean_expression(self, e), + e => fold_boolean_expression_cases(self, e), } } } diff --git a/zokrates_analysis/src/propagation.rs b/zokrates_analysis/src/propagation.rs index 1385a243f..06e8c5b43 100644 --- a/zokrates_analysis/src/propagation.rs +++ b/zokrates_analysis/src/propagation.rs @@ -12,8 +12,13 @@ use num_bigint::BigUint; use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; use std::fmt; +use std::ops::*; use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr, Sub}; -use zokrates_ast::common::FlatEmbed; +use zokrates_ast::common::expressions::{ + BinaryExpression, BinaryOrExpression, EqExpression, ValueExpression, +}; +use zokrates_ast::common::operators::OpEq; +use zokrates_ast::common::{FlatEmbed, ResultFold, WithSpan}; use zokrates_ast::typed::result_folder::*; use zokrates_ast::typed::types::Type; use zokrates_ast::typed::*; @@ -74,7 +79,7 @@ impl<'ast, T: Field> Propagator<'ast, T> { UExpressionInner::Value(n) => match constant { TypedExpression::Array(a) => match a.as_inner_mut() { ArrayExpressionInner::Value(value) => { - match value.0.get_mut(*n as usize) { + match value.value.get_mut(n.value as usize) { Some(TypedExpressionOrSpread::Expression(ref mut e)) => { Ok((variable, e)) } @@ -106,7 +111,7 @@ impl<'ast, T: Field> Propagator<'ast, T> { match c { TypedExpression::Struct(a) => match a.as_inner_mut() { - StructExpressionInner::Value(value) => Ok((v, &mut value[index])), + StructExpressionInner::Value(value) => Ok((v, &mut value.value[index])), _ => unreachable!("should be a struct value"), }, _ => unreachable!("should be a struct expression"), @@ -119,7 +124,7 @@ impl<'ast, T: Field> Propagator<'ast, T> { Ok((v, c)) => match c { TypedExpression::Tuple(a) => match a.as_inner_mut() { TupleExpressionInner::Value(value) => { - Ok((v, &mut value[*index as usize])) + Ok((v, &mut value.value[*index as usize])) } _ => unreachable!("should be a tuple value"), }, @@ -150,7 +155,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { } }) .collect::>()?, - main: p.main, + ..p }) } @@ -168,7 +173,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { } fn fold_conditional_expression< - E: Expr<'ast, T> + Conditional<'ast, T> + PartialEq + ResultFold<'ast, T>, + E: Expr<'ast, T> + Conditional<'ast, T> + PartialEq + ResultFold, >( &mut self, _: &E::Ty, @@ -180,10 +185,10 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { e.consequence.fold(self)?, e.alternative.fold(self)?, ) { - (BooleanExpression::Value(true), consequence, _) => { + (BooleanExpression::Value(v), consequence, _) if v.value => { ConditionalOrExpression::Expression(consequence.into_inner()) } - (BooleanExpression::Value(false), _, alternative) => { + (BooleanExpression::Value(v), _, alternative) if !v.value => { ConditionalOrExpression::Expression(alternative.into_inner()) } (_, consequence, alternative) if consequence == alternative => { @@ -196,105 +201,93 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { ) } - fn fold_assembly_statement( + fn fold_assembly_assignment( &mut self, - s: TypedAssemblyStatement<'ast, T>, + s: AssemblyAssignment<'ast, T>, ) -> Result>, Self::Error> { - match s { - TypedAssemblyStatement::Assignment(assignee, expr) => { - let assignee = self.fold_assignee(assignee)?; - let expr = self.fold_expression(expr)?; + let assignee = self.fold_assignee(s.assignee)?; + let expr = self.fold_expression(s.expression)?; - if expr.is_constant() { - match assignee { - TypedAssignee::Identifier(var) => { - let expr = expr.into_canonical_constant(); + if expr.is_constant() { + match assignee { + TypedAssignee::Identifier(var) => { + let expr = expr.into_canonical_constant(); - assert!(self.constants.insert(var.id, expr).is_none()); + assert!(self.constants.insert(var.id, expr).is_none()); - Ok(vec![]) - } - assignee => match self.try_get_constant_mut(&assignee) { - Ok((_, c)) => { - *c = expr.into_canonical_constant(); - Ok(vec![]) - } - Err(v) => match self.constants.remove(&v.id) { - // invalidate the cache for this identifier, and define the latest - // version of the constant in the program, if any - Some(c) => Ok(vec![ - TypedAssemblyStatement::Assignment(v.clone().into(), c), - TypedAssemblyStatement::Assignment(assignee, expr), - ]), - None => { - Ok(vec![TypedAssemblyStatement::Assignment(assignee, expr)]) - } - }, - }, + Ok(vec![]) + } + assignee => match self.try_get_constant_mut(&assignee) { + Ok((_, c)) => { + *c = expr.into_canonical_constant(); + Ok(vec![]) } - } else { - // the expression being assigned is not constant, invalidate the cache - let v = self - .try_get_constant_mut(&assignee) - .map(|(v, _)| v) - .unwrap_or_else(|v| v); - - match self.constants.remove(&v.id) { + Err(v) => match self.constants.remove(&v.id) { + // invalidate the cache for this identifier, and define the latest + // version of the constant in the program, if any Some(c) => Ok(vec![ - TypedAssemblyStatement::Assignment(v.clone().into(), c), - TypedAssemblyStatement::Assignment(assignee, expr), + TypedAssemblyStatement::assignment(v.clone().into(), c), + TypedAssemblyStatement::assignment(assignee, expr), ]), - None => Ok(vec![TypedAssemblyStatement::Assignment(assignee, expr)]), - } - } + None => Ok(vec![TypedAssemblyStatement::assignment(assignee, expr)]), + }, + }, } - TypedAssemblyStatement::Constraint(left, right, metadata) => { - let left = self.fold_field_expression(left)?; - let right = self.fold_field_expression(right)?; - - // a bit hacky, but we use a fake boolean expression to check this - let is_equal = - BooleanExpression::FieldEq(EqExpression::new(left.clone(), right.clone())); - let is_equal = self.fold_boolean_expression(is_equal)?; - - match is_equal { - BooleanExpression::Value(true) => Ok(vec![]), - BooleanExpression::Value(false) => { - Err(Error::AssertionFailed(RuntimeError::SourceAssertion( - metadata - .message(Some(format!("In asm block: `{} !== {}`", left, right))), - ))) - } - _ => Ok(vec![TypedAssemblyStatement::Constraint( - left, right, metadata, - )]), - } + } else { + // the expression being assigned is not constant, invalidate the cache + let v = self + .try_get_constant_mut(&assignee) + .map(|(v, _)| v) + .unwrap_or_else(|v| v); + + match self.constants.remove(&v.id) { + Some(c) => Ok(vec![ + TypedAssemblyStatement::assignment(v.clone().into(), c), + TypedAssemblyStatement::assignment(assignee, expr), + ]), + None => Ok(vec![TypedAssemblyStatement::assignment(assignee, expr)]), } } } - fn fold_statement( + fn fold_assembly_constraint( &mut self, - s: TypedStatement<'ast, T>, - ) -> Result>, Error> { - match s { - TypedStatement::Assembly(statements) => { - let statements: Vec<_> = statements - .into_iter() - .map(|s| self.fold_assembly_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(); - match statements.len() { - 0 => Ok(vec![]), - _ => Ok(vec![TypedStatement::Assembly(statements)]), - } + s: AssemblyConstraint<'ast, T>, + ) -> Result>, Self::Error> { + let span = s.get_span(); + + let left = self.fold_field_expression(s.left)?; + let right = self.fold_field_expression(s.right)?; + + // a bit hacky, but we use a fake boolean expression to check this + let is_equal = BooleanExpression::field_eq(left.clone(), right.clone()).span(span); + let is_equal = self.fold_boolean_expression(is_equal)?; + + match is_equal { + BooleanExpression::Value(v) if v.value => Ok(vec![]), + BooleanExpression::Value(v) if !v.value => { + Err(Error::AssertionFailed(RuntimeError::SourceAssertion( + s.metadata + .message(Some(format!("In asm block: `{} !== {}`", left, right))), + ))) } + _ => Ok(vec![TypedAssemblyStatement::constraint( + left, right, s.metadata, + )]), + } + } + + fn fold_definition_statement( + &mut self, + s: DefinitionStatement<'ast, T>, + ) -> Result>, Self::Error> { + let span = s.get_span(); + + match s.rhs { // propagation to the defined variable if rhs is a constant - TypedStatement::Definition(assignee, DefinitionRhs::Expression(expr)) => { - let assignee = self.fold_assignee(assignee)?; - let expr = self.fold_expression(expr)?; + DefinitionRhs::Expression(e) => { + let assignee = self.fold_assignee(s.assignee)?; + let expr = self.fold_expression(e)?; if let (Ok(a), Ok(e)) = ( ConcreteType::try_from(assignee.get_type()), @@ -326,10 +319,10 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { // invalidate the cache for this identifier, and define the latest // version of the constant in the program, if any Some(c) => Ok(vec![ - TypedStatement::Definition(v.clone().into(), c.into()), - TypedStatement::Definition(assignee, expr.into()), + TypedStatement::definition(v.clone().into(), c), + TypedStatement::definition(assignee, expr), ]), - None => Ok(vec![TypedStatement::Definition(assignee, expr.into())]), + None => Ok(vec![TypedStatement::definition(assignee, expr)]), }, }, } @@ -342,23 +335,16 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { match self.constants.remove(&v.id) { Some(c) => Ok(vec![ - TypedStatement::Definition(v.clone().into(), c.into()), - TypedStatement::Definition(assignee, expr.into()), + TypedStatement::definition(v.clone().into(), c), + TypedStatement::definition(assignee, expr), ]), - None => Ok(vec![TypedStatement::Definition(assignee, expr.into())]), + None => Ok(vec![TypedStatement::definition(assignee, expr)]), } } } - // we do not visit the for-loop statements - TypedStatement::For(v, from, to, statements) => { - let from = self.fold_uint_expression(from)?; - let to = self.fold_uint_expression(to)?; - - Ok(vec![TypedStatement::For(v, from, to, statements)]) - } - TypedStatement::Definition(assignee, DefinitionRhs::EmbedCall(embed_call)) => { - let assignee = self.fold_assignee(assignee)?; - let embed_call = self.fold_embed_call(embed_call)?; + DefinitionRhs::EmbedCall(e) => { + let assignee = self.fold_assignee(s.assignee)?; + let embed_call = self.fold_embed_call(e)?; fn process_u_from_bits<'ast, T: Field>( arguments: &[TypedExpression<'ast, T>], @@ -374,7 +360,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { .into_inner() { ArrayExpressionInner::Value(v) => - UExpressionInner::Value( + UExpression::value( v.into_iter() .map(|v| match v { TypedExpressionOrSpread::Expression( @@ -386,7 +372,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { }) .enumerate() .fold(0, |acc, (i, v)| { - if v { + if v.value { acc + 2u128.pow( (bitwidth.to_usize() - i - 1) .try_into() @@ -414,7 +400,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { .into_inner() { UExpressionInner::Value(v) => { - let mut num = v; + let mut num = v.value; let mut res = vec![]; for i in (0..bitwidth as u32).rev() { @@ -427,11 +413,10 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { } assert_eq!(num, 0); - ArrayExpressionInner::Value( + ArrayExpression::value( res.into_iter() - .map(|v| BooleanExpression::Value(v).into()) - .collect::>() - .into(), + .map(|v| BooleanExpression::value(v).into()) + .collect::>(), ) .annotate(Type::Boolean, bitwidth.to_usize() as u32) .into() @@ -448,13 +433,17 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { match FieldElementExpression::try_from_typed( embed_call.arguments[0].clone(), ) { - Ok(FieldElementExpression::Number(n)) if n == T::from(0) => { - Ok(Some(BooleanExpression::Value(false).into())) + Ok(FieldElementExpression::Value(n)) + if n.value == T::from(0) => + { + Ok(Some(BooleanExpression::value(false).span(span).into())) } - Ok(FieldElementExpression::Number(n)) if n == T::from(1) => { - Ok(Some(BooleanExpression::Value(true).into())) + Ok(FieldElementExpression::Value(n)) + if n.value == T::from(1) => + { + Ok(Some(BooleanExpression::value(true).span(span).into())) } - Ok(FieldElementExpression::Number(n)) => { + Ok(FieldElementExpression::Value(n)) => { Err(Error::InvalidValue(format!( "Cannot call `{}` with value `{}`: should be 0 or 1", embed_call.embed.id(), @@ -507,8 +496,8 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { ) .unwrap() { - FieldElementExpression::Number(num) => { - let mut acc = num; + FieldElementExpression::Value(num) => { + let mut acc = num.value; let mut res = vec![]; for i in (0..bit_width as usize).rev() { @@ -528,13 +517,17 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { ))) } else { Ok(Some( - ArrayExpressionInner::Value( + ArrayExpression::value( res.into_iter() - .map(|v| BooleanExpression::Value(v).into()) - .collect::>() - .into(), + .map(|v| { + BooleanExpression::value(v) + .span(span) + .into() + }) + .collect::>(), ) .annotate(Type::Boolean, bit_width) + .span(span) .into(), )) } @@ -562,11 +555,11 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { } Err(v) => match self.constants.remove(&v.id) { Some(c) => vec![ - TypedStatement::Definition(v.clone().into(), c.into()), - TypedStatement::Definition(assignee, expr.into()), + TypedStatement::definition(v.clone().into(), c), + TypedStatement::definition(assignee, expr), ], None => { - vec![TypedStatement::Definition(assignee, expr.into())] + vec![TypedStatement::definition(assignee, expr)] } }, }, @@ -582,12 +575,11 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { match self.constants.remove(&v.id) { Some(c) => vec![ - TypedStatement::Definition(v.clone().into(), c.into()), - TypedStatement::Definition(assignee, embed_call.into()), + TypedStatement::definition(v.clone().into(), c), + TypedStatement::embed_call_definition(assignee, embed_call), ], - None => vec![TypedStatement::Definition( - assignee, - embed_call.into(), + None => vec![TypedStatement::embed_call_definition( + assignee, embed_call, )], } } @@ -596,7 +588,8 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { false => { // if the function arguments are not constant, invalidate the cache // for the return assignees - let def = TypedStatement::Definition(assignee.clone(), embed_call.into()); + let def = + TypedStatement::embed_call_definition(assignee.clone(), embed_call); let v = self .try_get_constant_mut(&assignee) @@ -605,401 +598,447 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { Ok(match self.constants.remove(&v.id) { Some(c) => { - vec![TypedStatement::Definition(v.clone().into(), c.into()), def] + vec![TypedStatement::definition(v.clone().into(), c), def] } None => vec![def], }) } } } - TypedStatement::Assertion(e, err) => { - let expr = self.fold_boolean_expression(e)?; - match expr { - BooleanExpression::Value(false) => Err(Error::AssertionFailed(err)), - BooleanExpression::Value(true) => Ok(vec![]), - _ => Ok(vec![TypedStatement::Assertion(expr, err)]), - } - } - s => fold_statement(self, s), } } - fn fold_uint_expression_inner( + fn fold_assembly_block( + &mut self, + s: AssemblyBlockStatement<'ast, T>, + ) -> Result>, Self::Error> { + Ok(fold_assembly_block(self, s)? + .into_iter() + .filter(|s| match s { + TypedStatement::Assembly(s) => !s.inner.is_empty(), + _ => true, + }) + .collect()) + } + + fn fold_for_statement( + &mut self, + s: ForStatement<'ast, T>, + ) -> Result>, Self::Error> { + // we do not visit the for-loop statements + let from = self.fold_uint_expression(s.from)?; + let to = self.fold_uint_expression(s.to)?; + + Ok(vec![TypedStatement::for_(s.var, from, to, s.statements)]) + } + + fn fold_assertion_statement( + &mut self, + s: AssertionStatement<'ast, T>, + ) -> Result>, Self::Error> { + let _e_str = s.expression.to_string(); + let expr = self.fold_boolean_expression(s.expression)?; + + match expr { + BooleanExpression::Value(v) if !v.value => Err(Error::AssertionFailed(s.error)), + BooleanExpression::Value(v) if v.value => Ok(vec![]), + _ => Ok(vec![TypedStatement::assertion(expr, s.error)]), + } + } + + fn fold_uint_expression_cases( &mut self, bitwidth: UBitwidth, e: UExpressionInner<'ast, T>, ) -> Result, Error> { match e { - UExpressionInner::Add(box e1, box e2) => match ( - self.fold_uint_expression(e1)?.into_inner(), - self.fold_uint_expression(e2)?.into_inner(), + UExpressionInner::Add(e) => match ( + self.fold_uint_expression(*e.left)?.into_inner(), + self.fold_uint_expression(*e.right)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(UExpressionInner::Value( - (v1 + v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), + Ok(UExpression::value( + (v1.value + v2.value) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), )) } - (e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) => match v { - 0 => Ok(e), - _ => Ok(UExpressionInner::Add( - box e.annotate(bitwidth), - box UExpressionInner::Value(v).annotate(bitwidth), - )), - }, - (e1, e2) => Ok(UExpressionInner::Add( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) => { + match v.value { + 0 => Ok(e), + _ => Ok(UExpression::add( + e.annotate(bitwidth), + UExpression::value(v.value).annotate(bitwidth), + ) + .into_inner()), + } + } + (e1, e2) => { + Ok(UExpression::add(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner()) + } }, - UExpressionInner::Sub(box e1, box e2) => match ( - self.fold_uint_expression(e1)?.into_inner(), - self.fold_uint_expression(e2)?.into_inner(), + UExpressionInner::Sub(e) => match ( + self.fold_uint_expression(*e.left)?.into_inner(), + self.fold_uint_expression(*e.right)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(UExpressionInner::Value( - (v1.wrapping_sub(v2)) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), + Ok(UExpression::value( + (v1.value.wrapping_sub(v2.value)) + % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), )) } - (e, UExpressionInner::Value(v)) => match v { + (e, UExpressionInner::Value(v)) => match v.value { 0 => Ok(e), - _ => Ok(UExpressionInner::Sub( - box e.annotate(bitwidth), - box UExpressionInner::Value(v).annotate(bitwidth), - )), + _ => Ok(UExpression::sub( + e.annotate(bitwidth), + UExpressionInner::Value(v).annotate(bitwidth), + ) + .into_inner()), }, - (e1, e2) => Ok(UExpressionInner::Sub( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e1, e2) => { + Ok(UExpression::sub(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner()) + } }, - UExpressionInner::FloorSub(box e1, box e2) => match ( - self.fold_uint_expression(e1)?.into_inner(), - self.fold_uint_expression(e2)?.into_inner(), + UExpressionInner::FloorSub(e) => match ( + self.fold_uint_expression(*e.left)?.into_inner(), + self.fold_uint_expression(*e.right)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(UExpressionInner::Value( - v1.saturating_sub(v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), + Ok(UExpression::value( + v1.value.saturating_sub(v2.value) + % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), )) } - (e, UExpressionInner::Value(v)) => match v { + (e, UExpressionInner::Value(v)) => match v.value { 0 => Ok(e), - _ => Ok(UExpressionInner::FloorSub( - box e.annotate(bitwidth), - box UExpressionInner::Value(v).annotate(bitwidth), - )), + _ => Ok(UExpression::floor_sub( + e.annotate(bitwidth), + UExpressionInner::Value(v).annotate(bitwidth), + ) + .into_inner()), }, - (e1, e2) => Ok(UExpressionInner::Sub( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e1, e2) => { + Ok(UExpression::sub(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner()) + } }, - UExpressionInner::Mult(box e1, box e2) => match ( - self.fold_uint_expression(e1)?.into_inner(), - self.fold_uint_expression(e2)?.into_inner(), + UExpressionInner::Mult(e) => match ( + self.fold_uint_expression(*e.left)?.into_inner(), + self.fold_uint_expression(*e.right)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(UExpressionInner::Value( - (v1 * v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), + Ok(UExpression::value( + (v1.value * v2.value) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), )) } - (e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) => match v { - 0 => Ok(UExpressionInner::Value(0)), - 1 => Ok(e), - _ => Ok(UExpressionInner::Mult( - box e.annotate(bitwidth), - box UExpressionInner::Value(v).annotate(bitwidth), - )), - }, - (e1, e2) => Ok(UExpressionInner::Mult( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) => { + match v.value { + 0 => Ok(UExpression::value(0)), + 1 => Ok(e), + _ => Ok(UExpression::mul( + e.annotate(bitwidth), + UExpressionInner::Value(v).annotate(bitwidth), + ) + .into_inner()), + } + } + (e1, e2) => { + Ok(UExpression::mul(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner()) + } }, - UExpressionInner::Div(box e1, box e2) => match ( - self.fold_uint_expression(e1)?.into_inner(), - self.fold_uint_expression(e2)?.into_inner(), + UExpressionInner::Div(e) => match ( + self.fold_uint_expression(*e.left)?.into_inner(), + self.fold_uint_expression(*e.right)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(UExpressionInner::Value( - (v1 / v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), + Ok(UExpression::value( + (v1.value / v2.value) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), )) } - (e, UExpressionInner::Value(v)) => match v { + (e, UExpressionInner::Value(v)) => match v.value { 1 => Ok(e), - _ => Ok(UExpressionInner::Div( - box e.annotate(bitwidth), - box UExpressionInner::Value(v).annotate(bitwidth), - )), + _ => Ok(UExpression::div( + e.annotate(bitwidth), + UExpressionInner::Value(v).annotate(bitwidth), + ) + .into_inner()), }, - (e1, e2) => Ok(UExpressionInner::Div( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e1, e2) => { + Ok(UExpression::div(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner()) + } }, - UExpressionInner::Rem(box e1, box e2) => match ( - self.fold_uint_expression(e1)?.into_inner(), - self.fold_uint_expression(e2)?.into_inner(), + UExpressionInner::Rem(e) => match ( + self.fold_uint_expression(*e.left)?.into_inner(), + self.fold_uint_expression(*e.right)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(UExpressionInner::Value( - (v1 % v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), + Ok(UExpression::value( + (v1.value % v2.value) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), )) } - (e, UExpressionInner::Value(v)) => match v { - 1 => Ok(UExpressionInner::Value(0)), - _ => Ok(UExpressionInner::Rem( - box e.annotate(bitwidth), - box UExpressionInner::Value(v).annotate(bitwidth), - )), + (e, UExpressionInner::Value(v)) => match v.value { + 1 => Ok(UExpression::value(0)), + _ => Ok(UExpression::rem( + e.annotate(bitwidth), + UExpressionInner::Value(v).annotate(bitwidth), + ) + .into_inner()), }, - (e1, e2) => Ok(UExpressionInner::Rem( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e1, e2) => { + Ok(UExpression::rem(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner()) + } }, - UExpressionInner::RightShift(box e, box by) => { - let e = self.fold_uint_expression(e)?; - let by = self.fold_uint_expression(by)?; - match (e.into_inner(), by.into_inner()) { + UExpressionInner::RightShift(e) => { + let left = self.fold_uint_expression(*e.left)?; + let right = self.fold_uint_expression(*e.right)?; + match (left.into_inner(), right.into_inner()) { (UExpressionInner::Value(v), UExpressionInner::Value(by)) => { - Ok(UExpressionInner::Value(v >> by)) + Ok(UExpression::value(v.value >> by.value)) } - (e, by) => Ok(UExpressionInner::RightShift( - box e.annotate(bitwidth), - box by.annotate(UBitwidth::B32), - )), + (e, by) => Ok(UExpression::right_shift( + e.annotate(bitwidth), + by.annotate(UBitwidth::B32), + ) + .into_inner()), } } - UExpressionInner::LeftShift(box e, box by) => { - let e = self.fold_uint_expression(e)?; - let by = self.fold_uint_expression(by)?; - match (e.into_inner(), by.into_inner()) { - (UExpressionInner::Value(v), UExpressionInner::Value(by)) => Ok( - UExpressionInner::Value((v << by) & (2_u128.pow(bitwidth as u32) - 1)), - ), - (e, by) => Ok(UExpressionInner::LeftShift( - box e.annotate(bitwidth), - box by.annotate(UBitwidth::B32), - )), + UExpressionInner::LeftShift(e) => { + let left = self.fold_uint_expression(*e.left)?; + let right = self.fold_uint_expression(*e.right)?; + match (left.into_inner(), right.into_inner()) { + (UExpressionInner::Value(v), UExpressionInner::Value(by)) => { + Ok(UExpression::value( + (v.value << by.value) & (2_u128.pow(bitwidth as u32) - 1), + )) + } + (e, by) => Ok(UExpression::left_shift( + e.annotate(bitwidth), + by.annotate(UBitwidth::B32), + ) + .into_inner()), } } - UExpressionInner::Xor(box e1, box e2) => match ( - self.fold_uint_expression(e1)?.into_inner(), - self.fold_uint_expression(e2)?.into_inner(), + UExpressionInner::Xor(e) => match ( + self.fold_uint_expression(*e.left)?.into_inner(), + self.fold_uint_expression(*e.right)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(UExpressionInner::Value(v1 ^ v2)) + Ok(UExpression::value(v1.value ^ v2.value)) + } + (UExpressionInner::Value(v), e2) | (e2, UExpressionInner::Value(v)) + if v.value == 0 => + { + Ok(e2) } - (UExpressionInner::Value(0), e2) => Ok(e2), - (e1, UExpressionInner::Value(0)) => Ok(e1), (e1, e2) => { if e1 == e2 { - Ok(UExpressionInner::Value(0)) + Ok(UExpression::value(0)) } else { - Ok(UExpressionInner::Xor( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )) + Ok( + UExpression::xor(e1.annotate(bitwidth), e2.annotate(bitwidth)) + .into_inner(), + ) } } }, - UExpressionInner::And(box e1, box e2) => match ( - self.fold_uint_expression(e1)?.into_inner(), - self.fold_uint_expression(e2)?.into_inner(), + UExpressionInner::And(e) => match ( + self.fold_uint_expression(*e.left)?.into_inner(), + self.fold_uint_expression(*e.right)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(UExpressionInner::Value(v1 & v2)) + Ok(UExpression::value(v1.value & v2.value)) } - (UExpressionInner::Value(0), _) | (_, UExpressionInner::Value(0)) => { - Ok(UExpressionInner::Value(0)) + (UExpressionInner::Value(v), _) | (_, UExpressionInner::Value(v)) + if v.value == 0 => + { + Ok(UExpression::value(0)) + } + (e1, e2) => { + Ok(UExpression::and(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner()) } - (e1, e2) => Ok(UExpressionInner::And( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), }, - UExpressionInner::Not(box e) => { - let e = self.fold_uint_expression(e)?.into_inner(); + UExpressionInner::Not(e) => { + let e = self.fold_uint_expression(*e.inner)?.into_inner(); match e { - UExpressionInner::Value(v) => Ok(UExpressionInner::Value( - (!v) & (2_u128.pow(bitwidth as u32) - 1), + UExpressionInner::Value(v) => Ok(UExpression::value( + (!v.value) & (2_u128.pow(bitwidth as u32) - 1), )), - e => Ok(UExpressionInner::Not(box e.annotate(bitwidth))), + e => Ok(UExpression::not(e.annotate(bitwidth)).into_inner()), } } - UExpressionInner::Neg(box e) => { - let e = self.fold_uint_expression(e)?.into_inner(); + UExpressionInner::Neg(e) => { + let e = self.fold_uint_expression(*e.inner)?.into_inner(); match e { - UExpressionInner::Value(v) => Ok(UExpressionInner::Value( - (0u128.wrapping_sub(v)) + UExpressionInner::Value(v) => Ok(UExpression::value( + (0u128.wrapping_sub(v.value)) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), )), - e => Ok(UExpressionInner::Neg(box e.annotate(bitwidth))), + e => Ok(UExpression::neg(e.annotate(bitwidth)).into_inner()), } } - UExpressionInner::Pos(box e) => { - let e = self.fold_uint_expression(e)?.into_inner(); + UExpressionInner::Pos(e) => { + let e = self.fold_uint_expression(*e.inner)?.into_inner(); match e { - UExpressionInner::Value(v) => Ok(UExpressionInner::Value(v)), - e => Ok(UExpressionInner::Pos(box e.annotate(bitwidth))), + UExpressionInner::Value(v) => Ok(UExpression::value(v.value)), + e => Ok(UExpression::pos(e.annotate(bitwidth)).into_inner()), } } - e => fold_uint_expression_inner(self, bitwidth, e), + e => fold_uint_expression_cases(self, bitwidth, e), } } - fn fold_field_expression( + fn fold_field_expression_cases( &mut self, e: FieldElementExpression<'ast, T>, ) -> Result, Error> { match e { - FieldElementExpression::Add(box e1, box e2) => match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, - ) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number(n1 + n2)) - } - (e1, e2) => Ok(FieldElementExpression::Add(box e1, box e2)), - }, - FieldElementExpression::Sub(box e1, box e2) => match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, - ) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number(n1 - n2)) - } - (e1, e2) => Ok(FieldElementExpression::Sub(box e1, box e2)), - }, - FieldElementExpression::Mult(box e1, box e2) => match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, - ) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number(n1 * n2)) - } - (e1, e2) => Ok(FieldElementExpression::Mult(box e1, box e2)), - }, - FieldElementExpression::Div(box e1, box e2) => match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, - ) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number(n1 / n2)) - } - (e1, e2) => Ok(FieldElementExpression::Div(box e1, box e2)), - }, - FieldElementExpression::Neg(box e) => match self.fold_field_expression(e)? { - FieldElementExpression::Number(n) => { - Ok(FieldElementExpression::Number(T::zero() - n)) + FieldElementExpression::Add(e) => { + let left = self.fold_field_expression(*e.left)?; + let right = self.fold_field_expression(*e.right)?; + + Ok(match (left, right) { + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + FieldElementExpression::Value(ValueExpression::new(n1.value + n2.value)) + } + (e1, e2) => e1 + e2, + }) + } + FieldElementExpression::Sub(e) => { + let left = self.fold_field_expression(*e.left)?; + let right = self.fold_field_expression(*e.right)?; + + Ok(match (left, right) { + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + FieldElementExpression::Value(ValueExpression::new(n1.value - n2.value)) + } + (e1, e2) => e1 - e2, + }) + } + FieldElementExpression::Mult(e) => { + let left = self.fold_field_expression(*e.left)?; + let right = self.fold_field_expression(*e.right)?; + + Ok(match (left, right) { + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + FieldElementExpression::Value(ValueExpression::new(n1.value * n2.value)) + } + (e1, e2) => e1 * e2, + }) + } + FieldElementExpression::Div(e) => { + let left = self.fold_field_expression(*e.left)?; + let right = self.fold_field_expression(*e.right)?; + + Ok(match (left, right) { + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + FieldElementExpression::Value(ValueExpression::new(n1.value / n2.value)) + } + (e1, e2) => e1 / e2, + }) + } + FieldElementExpression::Neg(e) => match self.fold_field_expression(*e.inner)? { + FieldElementExpression::Value(n) => { + Ok(FieldElementExpression::value(T::zero() - n.value)) } - e => Ok(FieldElementExpression::Neg(box e)), + e => Ok(FieldElementExpression::neg(e)), }, - FieldElementExpression::Pos(box e) => match self.fold_field_expression(e)? { - FieldElementExpression::Number(n) => Ok(FieldElementExpression::Number(n)), - e => Ok(FieldElementExpression::Pos(box e)), + FieldElementExpression::Pos(e) => match self.fold_field_expression(*e.inner)? { + FieldElementExpression::Value(n) => Ok(FieldElementExpression::Value(n)), + e => Ok(FieldElementExpression::pos(e)), }, - FieldElementExpression::Pow(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + FieldElementExpression::Pow(e) => { + let e1 = self.fold_field_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1, e2.into_inner()) { - (_, UExpressionInner::Value(ref n2)) if *n2 == 0 => { - Ok(FieldElementExpression::Number(T::from(1))) + (_, UExpressionInner::Value(ref n2)) if n2.value == 0 => { + Ok(FieldElementExpression::value(T::from(1))) } - (FieldElementExpression::Number(n1), UExpressionInner::Value(n2)) => { - Ok(FieldElementExpression::Number(n1.pow(n2 as usize))) - } - (e1, UExpressionInner::Value(n2)) => Ok(FieldElementExpression::Pow( - box e1, - box UExpressionInner::Value(n2).annotate(UBitwidth::B32), - )), - (e1, e2) => Ok(FieldElementExpression::Pow( - box e1, - box e2.annotate(UBitwidth::B32), - )), + (FieldElementExpression::Value(n1), UExpressionInner::Value(n2)) => Ok( + FieldElementExpression::value(n1.value.pow(n2.value as usize)), + ), + (e1, e2) => Ok(FieldElementExpression::pow(e1, e2.annotate(UBitwidth::B32))), } } - FieldElementExpression::Xor(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; + FieldElementExpression::Xor(e) => { + let e1 = self.fold_field_expression(*e.left)?; + let e2 = self.fold_field_expression(*e.right)?; match (e1, e2) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number( - T::try_from(n1.to_biguint().bitxor(n2.to_biguint())).unwrap(), + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(FieldElementExpression::value( + T::try_from(n1.value.to_biguint().bitxor(n2.value.to_biguint())) + .unwrap(), )) } - (FieldElementExpression::Number(n), e) - | (e, FieldElementExpression::Number(n)) - if n == T::from(0) => + (FieldElementExpression::Value(n), e) + | (e, FieldElementExpression::Value(n)) + if n.value == T::from(0) => { Ok(e) } - (e1, e2) if e1.eq(&e2) => Ok(FieldElementExpression::Number(T::from(0))), - (e1, e2) => Ok(FieldElementExpression::Xor(box e1, box e2)), + (e1, e2) if e1.eq(&e2) => Ok(FieldElementExpression::value(T::from(0))), + (e1, e2) => Ok(FieldElementExpression::bitxor(e1, e2)), } } - - FieldElementExpression::And(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; + FieldElementExpression::And(e) => { + let e1 = self.fold_field_expression(*e.left)?; + let e2 = self.fold_field_expression(*e.right)?; match (e1, e2) { - (_, FieldElementExpression::Number(n)) - | (FieldElementExpression::Number(n), _) - if n == T::from(0) => + (_, FieldElementExpression::Value(n)) + | (FieldElementExpression::Value(n), _) + if n.value == T::from(0) => { - Ok(FieldElementExpression::Number(n)) + Ok(FieldElementExpression::Value(n)) } - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number( - T::try_from(n1.to_biguint().bitand(n2.to_biguint())).unwrap(), + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(FieldElementExpression::value( + T::try_from(n1.value.to_biguint().bitand(n2.value.to_biguint())) + .unwrap(), )) } - (e1, e2) => Ok(FieldElementExpression::And(box e1, box e2)), + (e1, e2) => Ok(FieldElementExpression::bitand(e1, e2)), } } - FieldElementExpression::Or(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; + FieldElementExpression::Or(e) => { + let e1 = self.fold_field_expression(*e.left)?; + let e2 = self.fold_field_expression(*e.right)?; match (e1, e2) { - (e, FieldElementExpression::Number(n)) - | (FieldElementExpression::Number(n), e) - if n == T::from(0) => + (e, FieldElementExpression::Value(n)) + | (FieldElementExpression::Value(n), e) + if n.value == T::from(0) => { Ok(e) } - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number( - T::try_from(n1.to_biguint().bitor(n2.to_biguint())).unwrap(), + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(FieldElementExpression::value( + T::try_from(n1.value.to_biguint().bitor(n2.value.to_biguint())) + .unwrap(), )) } - (e1, e2) => Ok(FieldElementExpression::Or(box e1, box e2)), + (e1, e2) => Ok(FieldElementExpression::bitor(e1, e2)), } } - FieldElementExpression::LeftShift(box e, box by) => { - let e = self.fold_field_expression(e)?; - let by = self.fold_uint_expression(by)?; - match (e, by) { + FieldElementExpression::LeftShift(e) => { + let expr = self.fold_field_expression(*e.left)?; + let by = self.fold_uint_expression(*e.right)?; + match (expr, by) { ( e, UExpression { inner: UExpressionInner::Value(by), .. }, - ) if by == 0 => Ok(e), + ) if by.value == 0 => Ok(e), ( _, UExpression { inner: UExpressionInner::Value(by), .. }, - ) if by as usize >= T::get_required_bits() => { - Ok(FieldElementExpression::Number(T::from(0))) + ) if by.value as usize >= T::get_required_bits() => { + Ok(FieldElementExpression::value(T::from(0))) } ( - FieldElementExpression::Number(n), + FieldElementExpression::Value(n), UExpression { inner: UExpressionInner::Value(by), .. @@ -1008,46 +1047,47 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { let two = BigUint::from(2usize); let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize); - Ok(FieldElementExpression::Number( - T::try_from(n.to_biguint().shl(by as usize).bitand(mask)).unwrap(), + Ok(FieldElementExpression::value( + T::try_from(n.value.to_biguint().shl(by.value as usize).bitand(mask)) + .unwrap(), )) } - (e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)), + (e, by) => Ok(FieldElementExpression::left_shift(e, by)), } } - FieldElementExpression::RightShift(box e, box by) => { - let e = self.fold_field_expression(e)?; - let by = self.fold_uint_expression(by)?; - match (e, by) { + FieldElementExpression::RightShift(e) => { + let expr = self.fold_field_expression(*e.left)?; + let by = self.fold_uint_expression(*e.right)?; + match (expr, by) { ( e, UExpression { inner: UExpressionInner::Value(by), .. }, - ) if by == 0 => Ok(e), + ) if by.value == 0 => Ok(e), ( _, UExpression { inner: UExpressionInner::Value(by), .. }, - ) if by as usize >= T::get_required_bits() => { - Ok(FieldElementExpression::Number(T::from(0))) + ) if by.value as usize >= T::get_required_bits() => { + Ok(FieldElementExpression::value(T::from(0))) } ( - FieldElementExpression::Number(n), + FieldElementExpression::Value(n), UExpression { inner: UExpressionInner::Value(by), .. }, - ) => Ok(FieldElementExpression::Number( - T::try_from(n.to_biguint().shr(by as usize)).unwrap(), + ) => Ok(FieldElementExpression::value( + T::try_from(n.value.to_biguint().shr(by.value as usize)).unwrap(), )), - (e, by) => Ok(FieldElementExpression::RightShift(box e, box by)), + (e, by) => Ok(FieldElementExpression::right_shift(e, by)), } } - e => fold_field_expression(self, e), + e => fold_field_expression_cases(self, e), } } @@ -1128,10 +1168,10 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { (ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { if n < size { Ok(SelectOrExpression::Expression( - v.expression_at::(n as usize).unwrap().into_inner(), + v.expression_at::(n.value as usize).unwrap().into_inner(), )) } else { - Err(Error::OutOfBounds(n, size)) + Err(Error::OutOfBounds(n.value, size.value)) } } (ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => { @@ -1140,7 +1180,9 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { TypedExpression::Array(a) => match a.as_inner() { ArrayExpressionInner::Value(v) => { Ok(SelectOrExpression::Expression( - v.expression_at::(n as usize).unwrap().into_inner(), + v.expression_at::(n.value as usize) + .unwrap() + .into_inner(), )) } _ => unreachable!("should be an array value"), @@ -1150,7 +1192,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { None => Ok(SelectOrExpression::Expression( E::select( ArrayExpressionInner::Identifier(id) - .annotate(inner_type, size as u32), + .annotate(inner_type, size.value as u32), UExpressionInner::Value(n).annotate(UBitwidth::B32), ) .into_inner(), @@ -1158,7 +1200,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { } } (a, i) => Ok(SelectOrExpression::Select(SelectExpression::new( - a.annotate(inner_type, size as u32), + a.annotate(inner_type, size.value as u32), i.annotate(UBitwidth::B32), ))), }, @@ -1168,7 +1210,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { } } - fn fold_array_expression_inner( + fn fold_array_expression_cases( &mut self, ty: &ArrayType<'ast, T>, e: ArrayExpressionInner<'ast, T>, @@ -1190,7 +1232,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { inner: ArrayExpressionInner::Value(v), .. }, - }) => v.0, + }) => v.value, e => vec![e], } }) @@ -1214,11 +1256,11 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { .collect(), )) } - e => fold_array_expression_inner(self, ty, e), + e => fold_array_expression_cases(self, ty, e), } } - fn fold_struct_expression_inner( + fn fold_struct_expression_cases( &mut self, ty: &StructType<'ast, T>, e: StructExpressionInner<'ast, T>, @@ -1240,11 +1282,13 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { Ok(StructExpressionInner::Value(v)) } - e => fold_struct_expression_inner(self, ty, e), + e => fold_struct_expression_cases(self, ty, e), } } - fn fold_identifier_expression + Id<'ast, T> + ResultFold<'ast, T>>( + fn fold_identifier_expression< + E: Expr<'ast, T> + Id<'ast, T> + ResultFold, + >( &mut self, _: &E::Ty, id: IdentifierExpression<'ast, E>, @@ -1255,7 +1299,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { } } - fn fold_tuple_expression_inner( + fn fold_tuple_expression_cases( &mut self, ty: &TupleType<'ast, T>, e: TupleExpressionInner<'ast, T>, @@ -1277,16 +1321,19 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { Ok(TupleExpressionInner::Value(v)) } - e => fold_tuple_expression_inner(self, ty, e), + e => fold_tuple_expression_cases(self, ty, e), } } fn fold_eq_expression< - E: Expr<'ast, T> + PartialEq + Constant + Typed<'ast, T> + ResultFold<'ast, T>, + E: Expr<'ast, T> + PartialEq + Constant + Typed<'ast, T> + ResultFold, >( &mut self, - e: EqExpression, - ) -> Result, Self::Error> { + e: EqExpression>, + ) -> Result< + BinaryOrExpression, BooleanExpression<'ast, T>>, + Self::Error, + > { let left = e.left.fold(self)?; let right = e.right.fold(self)?; @@ -1305,22 +1352,26 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { // if the two expressions are the same, we can reduce to `true`. // Note that if they are different we cannot reduce to `false`: `a == 1` may still be `true` even though `a` and `1` are different expressions if left == right { - return Ok(EqOrBoolean::Boolean(BooleanExpression::Value(true))); + return Ok(BinaryOrExpression::Expression(BooleanExpression::value( + true, + ))); } // if both expressions are constant, we can reduce the equality check after we put them in canonical form if left.is_constant() && right.is_constant() { let left = left.into_canonical_constant(); let right = right.into_canonical_constant(); - Ok(EqOrBoolean::Boolean(BooleanExpression::Value( + Ok(BinaryOrExpression::Expression(BooleanExpression::value( left == right, ))) } else { - Ok(EqOrBoolean::Eq(EqExpression::new(left, right))) + Ok(BinaryOrExpression::Binary(BinaryExpression::new( + left, right, + ))) } } - fn fold_boolean_expression( + fn fold_boolean_expression_cases( &mut self, e: BooleanExpression<'ast, T>, ) -> Result, Error> { @@ -1330,142 +1381,106 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { // These kind of reduction rules are easier to apply later in the process, when we have canonical representations // of expressions, ie `a + a` would always be written `2 * a` match e { - BooleanExpression::FieldLt(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; + BooleanExpression::FieldLt(e) => { + let e1 = self.fold_field_expression(*e.left)?; + let e2 = self.fold_field_expression(*e.right)?; match (e1, e2) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(BooleanExpression::Value(n1 < n2)) + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(BooleanExpression::value(n1.value < n2.value)) } - (e1, e2) => Ok(BooleanExpression::FieldLt(box e1, box e2)), + (e1, e2) => Ok(BooleanExpression::field_lt(e1, e2)), } } - BooleanExpression::FieldLe(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; + BooleanExpression::FieldLe(e) => { + let e1 = self.fold_field_expression(*e.left)?; + let e2 = self.fold_field_expression(*e.right)?; match (e1, e2) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(BooleanExpression::Value(n1 <= n2)) + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(BooleanExpression::value(n1.value <= n2.value)) } - (e1, e2) => Ok(BooleanExpression::FieldLe(box e1, box e2)), + (e1, e2) => Ok(BooleanExpression::field_le(e1, e2)), } } - BooleanExpression::FieldGt(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; - - match (e1, e2) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(BooleanExpression::Value(n1 > n2)) - } - (e1, e2) => Ok(BooleanExpression::FieldGt(box e1, box e2)), - } - } - BooleanExpression::FieldGe(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; - - match (e1, e2) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(BooleanExpression::Value(n1 >= n2)) - } - (e1, e2) => Ok(BooleanExpression::FieldGe(box e1, box e2)), - } - } - BooleanExpression::UintLt(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + BooleanExpression::UintLt(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - Ok(BooleanExpression::Value(n1 < n2)) + Ok(BooleanExpression::value(n1.value < n2.value)) } - _ => Ok(BooleanExpression::UintLt(box e1, box e2)), + _ => Ok(BooleanExpression::uint_lt(e1, e2)), } } - BooleanExpression::UintLe(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + BooleanExpression::UintLe(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - Ok(BooleanExpression::Value(n1 <= n2)) + Ok(BooleanExpression::value(n1.value <= n2.value)) } - _ => Ok(BooleanExpression::UintLe(box e1, box e2)), + _ => Ok(BooleanExpression::uint_le(e1, e2)), } } - BooleanExpression::UintGt(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; - - match (e1.as_inner(), e2.as_inner()) { - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - Ok(BooleanExpression::Value(n1 > n2)) - } - _ => Ok(BooleanExpression::UintGt(box e1, box e2)), - } - } - BooleanExpression::UintGe(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; - - match (e1.as_inner(), e2.as_inner()) { - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - Ok(BooleanExpression::Value(n1 >= n2)) - } - _ => Ok(BooleanExpression::UintGe(box e1, box e2)), - } - } - BooleanExpression::Or(box e1, box e2) => { - let e1 = self.fold_boolean_expression(e1)?; - let e2 = self.fold_boolean_expression(e2)?; + BooleanExpression::Or(e) => { + let e1 = self.fold_boolean_expression(*e.left)?; + let e2 = self.fold_boolean_expression(*e.right)?; match (e1, e2) { // reduction of constants (BooleanExpression::Value(v1), BooleanExpression::Value(v2)) => { - Ok(BooleanExpression::Value(v1 || v2)) + Ok(BooleanExpression::value(v1.value || v2.value)) } // x || true == true - (_, BooleanExpression::Value(true)) | (BooleanExpression::Value(true), _) => { - Ok(BooleanExpression::Value(true)) + (_, BooleanExpression::Value(v)) | (BooleanExpression::Value(v), _) + if v.value => + { + Ok(BooleanExpression::value(true)) } // x || false == x - (e, BooleanExpression::Value(false)) | (BooleanExpression::Value(false), e) => { + (e, BooleanExpression::Value(v)) | (BooleanExpression::Value(v), e) + if !v.value => + { Ok(e) } - (e1, e2) => Ok(BooleanExpression::Or(box e1, box e2)), + (e1, e2) => Ok(BooleanExpression::bitor(e1, e2)), } } - BooleanExpression::And(box e1, box e2) => { - let e1 = self.fold_boolean_expression(e1)?; - let e2 = self.fold_boolean_expression(e2)?; + BooleanExpression::And(e) => { + let e1 = self.fold_boolean_expression(*e.left)?; + let e2 = self.fold_boolean_expression(*e.right)?; match (e1, e2) { // reduction of constants (BooleanExpression::Value(v1), BooleanExpression::Value(v2)) => { - Ok(BooleanExpression::Value(v1 && v2)) + Ok(BooleanExpression::value(v1.value && v2.value)) } // x && true == x - (e, BooleanExpression::Value(true)) | (BooleanExpression::Value(true), e) => { + (e, BooleanExpression::Value(v)) | (BooleanExpression::Value(v), e) + if v.value => + { Ok(e) } // x && false == false - (_, BooleanExpression::Value(false)) | (BooleanExpression::Value(false), _) => { - Ok(BooleanExpression::Value(false)) + (_, BooleanExpression::Value(v)) | (BooleanExpression::Value(v), _) + if !v.value => + { + Ok(BooleanExpression::value(false)) } - (e1, e2) => Ok(BooleanExpression::And(box e1, box e2)), + (e1, e2) => Ok(BooleanExpression::bitand(e1, e2)), } } - BooleanExpression::Not(box e) => { - let e = self.fold_boolean_expression(e)?; + BooleanExpression::Not(e) => { + let e = self.fold_boolean_expression(*e.inner)?; match e { - BooleanExpression::Value(v) => Ok(BooleanExpression::Value(!v)), - e => Ok(BooleanExpression::Not(box e)), + BooleanExpression::Value(v) => Ok(BooleanExpression::value(!v.value)), + e => Ok(BooleanExpression::not(e)), } } - e => fold_boolean_expression(self, e), + e => fold_boolean_expression_cases(self, e), } } } @@ -1485,66 +1500,66 @@ mod tests { #[test] fn add() { - let e = FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Number(Bn128Field::from(3)), + let e = FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), ); assert_eq!( Propagator::default().fold_field_expression(e), - Ok(FieldElementExpression::Number(Bn128Field::from(5))) + Ok(FieldElementExpression::value(Bn128Field::from(5))) ); } #[test] fn sub() { - let e = FieldElementExpression::Sub( - box FieldElementExpression::Number(Bn128Field::from(3)), - box FieldElementExpression::Number(Bn128Field::from(2)), + let e = FieldElementExpression::sub( + FieldElementExpression::value(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), ); assert_eq!( Propagator::default().fold_field_expression(e), - Ok(FieldElementExpression::Number(Bn128Field::from(1))) + Ok(FieldElementExpression::value(Bn128Field::from(1))) ); } #[test] fn mult() { - let e = FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(3)), - box FieldElementExpression::Number(Bn128Field::from(2)), + let e = FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), ); assert_eq!( Propagator::default().fold_field_expression(e), - Ok(FieldElementExpression::Number(Bn128Field::from(6))) + Ok(FieldElementExpression::value(Bn128Field::from(6))) ); } #[test] fn div() { - let e = FieldElementExpression::Div( - box FieldElementExpression::Number(Bn128Field::from(6)), - box FieldElementExpression::Number(Bn128Field::from(2)), + let e = FieldElementExpression::div( + FieldElementExpression::value(Bn128Field::from(6)), + FieldElementExpression::value(Bn128Field::from(2)), ); assert_eq!( Propagator::default().fold_field_expression(e), - Ok(FieldElementExpression::Number(Bn128Field::from(3))) + Ok(FieldElementExpression::value(Bn128Field::from(3))) ); } #[test] fn pow() { - let e = FieldElementExpression::Pow( - box FieldElementExpression::Number(Bn128Field::from(2)), - box 3u32.into(), + let e = FieldElementExpression::pow( + FieldElementExpression::value(Bn128Field::from(2)), + 3u32.into(), ); assert_eq!( Propagator::default().fold_field_expression(e), - Ok(FieldElementExpression::Number(Bn128Field::from(8))) + Ok(FieldElementExpression::value(Bn128Field::from(8))) ); } @@ -1553,43 +1568,43 @@ mod tests { let mut propagator = Propagator::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::identifier("a".into()), - box 0u32.into(), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::identifier("a".into()), + 0u32.into(), )), Ok(FieldElementExpression::identifier("a".into())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::Number(Bn128Field::from(2)), - box 2u32.into(), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::value(Bn128Field::from(2)), + 2u32.into(), )), - Ok(FieldElementExpression::Number(Bn128Field::from(8))) + Ok(FieldElementExpression::value(Bn128Field::from(8))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::Number(Bn128Field::from(1)), - box ((Bn128Field::get_required_bits() - 1) as u32).into(), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::value(Bn128Field::from(1)), + ((Bn128Field::get_required_bits() - 1) as u32).into(), )), - Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("14474011154664524427946373126085988481658748083205070504932198000989141204992").unwrap())) + Ok(FieldElementExpression::value(Bn128Field::try_from_dec_str("14474011154664524427946373126085988481658748083205070504932198000989141204992").unwrap())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::Number(Bn128Field::from(3)), - box ((Bn128Field::get_required_bits() - 3) as u32).into(), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::value(Bn128Field::from(3)), + ((Bn128Field::get_required_bits() - 3) as u32).into(), )), - Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("10855508365998393320959779844564491361244061062403802878699148500741855903744").unwrap())) + Ok(FieldElementExpression::value(Bn128Field::try_from_dec_str("10855508365998393320959779844564491361244061062403802878699148500741855903744").unwrap())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::Number(Bn128Field::from(1)), - box (Bn128Field::get_required_bits() as u32).into(), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::value(Bn128Field::from(1)), + (Bn128Field::get_required_bits() as u32).into(), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); } @@ -1598,110 +1613,106 @@ mod tests { let mut propagator = Propagator::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::identifier("a".into()), - box 0u32.into(), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::identifier("a".into()), + 0u32.into(), )), Ok(FieldElementExpression::identifier("a".into())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::identifier("a".into()), - box (Bn128Field::get_required_bits() as u32).into(), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::identifier("a".into()), + (Bn128Field::get_required_bits() as u32).into(), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::from(3)), - box 1u32.into(), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::from(3)), + 1u32.into(), )), - Ok(FieldElementExpression::Number(Bn128Field::from(1))) + Ok(FieldElementExpression::value(Bn128Field::from(1))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::from(2)), - box 2u32.into(), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::from(2)), + 2u32.into(), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::from(2)), - box 4u32.into(), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::from(2)), + 4u32.into(), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::max_value()), - box ((Bn128Field::get_required_bits() - 1) as u32).into(), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::max_value()), + ((Bn128Field::get_required_bits() - 1) as u32).into(), )), - Ok(FieldElementExpression::Number(Bn128Field::from(1))) + Ok(FieldElementExpression::value(Bn128Field::from(1))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::max_value()), - box (Bn128Field::get_required_bits() as u32).into(), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::max_value()), + (Bn128Field::get_required_bits() as u32).into(), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); } #[test] fn if_else_true() { let e = FieldElementExpression::conditional( - BooleanExpression::Value(true), - FieldElementExpression::Number(Bn128Field::from(2)), - FieldElementExpression::Number(Bn128Field::from(3)), + BooleanExpression::value(true), + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), ConditionalKind::IfElse, ); assert_eq!( Propagator::default().fold_field_expression(e), - Ok(FieldElementExpression::Number(Bn128Field::from(2))) + Ok(FieldElementExpression::value(Bn128Field::from(2))) ); } #[test] fn if_else_false() { let e = FieldElementExpression::conditional( - BooleanExpression::Value(false), - FieldElementExpression::Number(Bn128Field::from(2)), - FieldElementExpression::Number(Bn128Field::from(3)), + BooleanExpression::value(false), + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), ConditionalKind::IfElse, ); assert_eq!( Propagator::default().fold_field_expression(e), - Ok(FieldElementExpression::Number(Bn128Field::from(3))) + Ok(FieldElementExpression::value(Bn128Field::from(3))) ); } #[test] fn select() { let e = FieldElementExpression::select( - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(1)).into(), - FieldElementExpression::Number(Bn128Field::from(2)).into(), - FieldElementExpression::Number(Bn128Field::from(3)).into(), - ] - .into(), - ) + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(1)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(3)).into(), + ]) .annotate(Type::FieldElement, 3u32), - UExpressionInner::Add(box 1u32.into(), box 1u32.into()) - .annotate(UBitwidth::B32), + UExpression::add(1u32.into(), 1u32.into()), ); assert_eq!( Propagator::default().fold_field_expression(e), - Ok(FieldElementExpression::Number(Bn128Field::from(3))) + Ok(FieldElementExpression::value(Bn128Field::from(3))) ); } } @@ -1713,21 +1724,21 @@ mod tests { #[test] fn not() { let e_true: BooleanExpression = - BooleanExpression::Not(box BooleanExpression::Value(false)); + BooleanExpression::not(BooleanExpression::value(false)); let e_false: BooleanExpression = - BooleanExpression::Not(box BooleanExpression::Value(true)); + BooleanExpression::not(BooleanExpression::value(true)); let e_default: BooleanExpression = - BooleanExpression::Not(box BooleanExpression::identifier("a".into())); + BooleanExpression::not(BooleanExpression::identifier("a".into())); assert_eq!( Propagator::default().fold_boolean_expression(e_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_false), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_default.clone()), @@ -1737,39 +1748,39 @@ mod tests { #[test] fn field_eq() { - let e_constant_true = BooleanExpression::FieldEq(EqExpression::new( - FieldElementExpression::Number(Bn128Field::from(2)), - FieldElementExpression::Number(Bn128Field::from(2)), + let e_constant_true = BooleanExpression::FieldEq(BinaryExpression::new( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(2)), )); - let e_constant_false = BooleanExpression::FieldEq(EqExpression::new( - FieldElementExpression::Number(Bn128Field::from(4)), - FieldElementExpression::Number(Bn128Field::from(2)), + let e_constant_false = BooleanExpression::FieldEq(BinaryExpression::new( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::value(Bn128Field::from(2)), )); let e_identifier_true: BooleanExpression = - BooleanExpression::FieldEq(EqExpression::new( + BooleanExpression::FieldEq(BinaryExpression::new( FieldElementExpression::identifier("a".into()), FieldElementExpression::identifier("a".into()), )); let e_identifier_unchanged: BooleanExpression = - BooleanExpression::FieldEq(EqExpression::new( + BooleanExpression::FieldEq(BinaryExpression::new( FieldElementExpression::identifier("a".into()), FieldElementExpression::identifier("b".into()), )); assert_eq!( Propagator::default().fold_boolean_expression(e_constant_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_constant_false), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_identifier_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()), @@ -1782,151 +1793,121 @@ mod tests { assert_eq!( Propagator::::default().fold_boolean_expression( BooleanExpression::BoolEq(EqExpression::new( - BooleanExpression::Value(false), - BooleanExpression::Value(false) + BooleanExpression::value(false), + BooleanExpression::value(false) )) ), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::::default().fold_boolean_expression( BooleanExpression::BoolEq(EqExpression::new( - BooleanExpression::Value(true), - BooleanExpression::Value(true) + BooleanExpression::value(true), + BooleanExpression::value(true) )) ), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::::default().fold_boolean_expression( BooleanExpression::BoolEq(EqExpression::new( - BooleanExpression::Value(true), - BooleanExpression::Value(false) + BooleanExpression::value(true), + BooleanExpression::value(false) )) ), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( Propagator::::default().fold_boolean_expression( BooleanExpression::BoolEq(EqExpression::new( - BooleanExpression::Value(false), - BooleanExpression::Value(true) + BooleanExpression::value(false), + BooleanExpression::value(true) )) ), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } #[test] fn array_eq() { - let e_constant_true = BooleanExpression::ArrayEq(EqExpression::new( - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(2usize)).into(), - )] - .into(), - ) + let e_constant_true = BooleanExpression::ArrayEq(BinaryExpression::new( + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( + FieldElementExpression::value(Bn128Field::from(2usize)).into(), + )]) .annotate(Type::FieldElement, 1u32), - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(2usize)).into(), - )] - .into(), - ) + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( + FieldElementExpression::value(Bn128Field::from(2usize)).into(), + )]) .annotate(Type::FieldElement, 1u32), )); - let e_constant_false = BooleanExpression::ArrayEq(EqExpression::new( - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(2usize)).into(), - )] - .into(), - ) + let e_constant_false = BooleanExpression::ArrayEq(BinaryExpression::new( + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( + FieldElementExpression::value(Bn128Field::from(2usize)).into(), + )]) .annotate(Type::FieldElement, 1u32), - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(4usize)).into(), - )] - .into(), - ) + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( + FieldElementExpression::value(Bn128Field::from(4usize)).into(), + )]) .annotate(Type::FieldElement, 1u32), )); let e_identifier_true: BooleanExpression = - BooleanExpression::ArrayEq(EqExpression::new( + BooleanExpression::ArrayEq(BinaryExpression::new( ArrayExpression::identifier("a".into()).annotate(Type::FieldElement, 1u32), ArrayExpression::identifier("a".into()).annotate(Type::FieldElement, 1u32), )); let e_identifier_unchanged: BooleanExpression = - BooleanExpression::ArrayEq(EqExpression::new( + BooleanExpression::ArrayEq(BinaryExpression::new( ArrayExpression::identifier("a".into()).annotate(Type::FieldElement, 1u32), ArrayExpression::identifier("b".into()).annotate(Type::FieldElement, 1u32), )); - let e_non_canonical_true = BooleanExpression::ArrayEq(EqExpression::new( - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Spread( - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(2usize)).into(), - )] - .into(), - ) - .annotate(Type::FieldElement, 1u32) - .into(), - )] + let e_non_canonical_true = BooleanExpression::ArrayEq(BinaryExpression::new( + ArrayExpression::value(vec![TypedExpressionOrSpread::Spread( + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( + FieldElementExpression::value(Bn128Field::from(2usize)).into(), + )]) + .annotate(Type::FieldElement, 1u32) .into(), - ) + )]) .annotate(Type::FieldElement, 1u32), - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(2usize)).into(), - )] - .into(), - ) + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( + FieldElementExpression::value(Bn128Field::from(2usize)).into(), + )]) .annotate(Type::FieldElement, 1u32), )); - let e_non_canonical_false = BooleanExpression::ArrayEq(EqExpression::new( - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Spread( - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(2usize)).into(), - )] - .into(), - ) - .annotate(Type::FieldElement, 1u32) - .into(), - )] + let e_non_canonical_false = BooleanExpression::ArrayEq(BinaryExpression::new( + ArrayExpression::value(vec![TypedExpressionOrSpread::Spread( + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( + FieldElementExpression::value(Bn128Field::from(2usize)).into(), + )]) + .annotate(Type::FieldElement, 1u32) .into(), - ) + )]) .annotate(Type::FieldElement, 1u32), - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(4usize)).into(), - )] - .into(), - ) + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( + FieldElementExpression::value(Bn128Field::from(4usize)).into(), + )]) .annotate(Type::FieldElement, 1u32), )); assert_eq!( Propagator::default().fold_boolean_expression(e_constant_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_constant_false), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_identifier_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()), @@ -1934,99 +1915,99 @@ mod tests { ); assert_eq!( Propagator::default().fold_boolean_expression(e_non_canonical_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_non_canonical_false), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } #[test] fn lt() { - let e_true = BooleanExpression::FieldLt( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Number(Bn128Field::from(4)), + let e_true = BooleanExpression::field_lt( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(4)), ); - let e_false = BooleanExpression::FieldLt( - box FieldElementExpression::Number(Bn128Field::from(4)), - box FieldElementExpression::Number(Bn128Field::from(2)), + let e_false = BooleanExpression::field_lt( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::value(Bn128Field::from(2)), ); assert_eq!( Propagator::default().fold_boolean_expression(e_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_false), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } #[test] fn le() { - let e_true = BooleanExpression::FieldLe( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Number(Bn128Field::from(2)), + let e_true = BooleanExpression::field_le( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(2)), ); - let e_false = BooleanExpression::FieldLe( - box FieldElementExpression::Number(Bn128Field::from(4)), - box FieldElementExpression::Number(Bn128Field::from(2)), + let e_false = BooleanExpression::field_le( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::value(Bn128Field::from(2)), ); assert_eq!( Propagator::default().fold_boolean_expression(e_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_false), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } #[test] fn gt() { - let e_true = BooleanExpression::FieldGt( - box FieldElementExpression::Number(Bn128Field::from(5)), - box FieldElementExpression::Number(Bn128Field::from(4)), + let e_true = BooleanExpression::field_gt( + FieldElementExpression::value(Bn128Field::from(5)), + FieldElementExpression::value(Bn128Field::from(4)), ); - let e_false = BooleanExpression::FieldGt( - box FieldElementExpression::Number(Bn128Field::from(4)), - box FieldElementExpression::Number(Bn128Field::from(5)), + let e_false = BooleanExpression::field_gt( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::value(Bn128Field::from(5)), ); assert_eq!( Propagator::default().fold_boolean_expression(e_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_false), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } #[test] fn ge() { - let e_true = BooleanExpression::FieldGe( - box FieldElementExpression::Number(Bn128Field::from(5)), - box FieldElementExpression::Number(Bn128Field::from(5)), + let e_true = BooleanExpression::field_ge( + FieldElementExpression::value(Bn128Field::from(5)), + FieldElementExpression::value(Bn128Field::from(5)), ); - let e_false = BooleanExpression::FieldGe( - box FieldElementExpression::Number(Bn128Field::from(4)), - box FieldElementExpression::Number(Bn128Field::from(5)), + let e_false = BooleanExpression::field_ge( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::value(Bn128Field::from(5)), ); assert_eq!( Propagator::default().fold_boolean_expression(e_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_false), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } @@ -2036,75 +2017,75 @@ mod tests { assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::And( - box BooleanExpression::Value(true), - box BooleanExpression::identifier(a_bool.clone()) + BooleanExpression::bitand( + BooleanExpression::value(true), + BooleanExpression::identifier(a_bool.clone()) ) ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::And( - box BooleanExpression::identifier(a_bool.clone()), - box BooleanExpression::Value(true), + BooleanExpression::bitand( + BooleanExpression::identifier(a_bool.clone()), + BooleanExpression::value(true), ) ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::And( - box BooleanExpression::Value(false), - box BooleanExpression::identifier(a_bool.clone()) + BooleanExpression::bitand( + BooleanExpression::value(false), + BooleanExpression::identifier(a_bool.clone()) ) ), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::And( - box BooleanExpression::identifier(a_bool.clone()), - box BooleanExpression::Value(false), + BooleanExpression::bitand( + BooleanExpression::identifier(a_bool.clone()), + BooleanExpression::value(false), ) ), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::And( - box BooleanExpression::Value(true), - box BooleanExpression::Value(false), + BooleanExpression::bitand( + BooleanExpression::value(true), + BooleanExpression::value(false), ) ), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::And( - box BooleanExpression::Value(false), - box BooleanExpression::Value(true), + BooleanExpression::bitand( + BooleanExpression::value(false), + BooleanExpression::value(true), ) ), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::And( - box BooleanExpression::Value(true), - box BooleanExpression::Value(true), + BooleanExpression::bitand( + BooleanExpression::value(true), + BooleanExpression::value(true), ) ), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::And( - box BooleanExpression::Value(false), - box BooleanExpression::Value(false), + BooleanExpression::bitand( + BooleanExpression::value(false), + BooleanExpression::value(false), ) ), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } @@ -2114,75 +2095,75 @@ mod tests { assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::Or( - box BooleanExpression::Value(true), - box BooleanExpression::identifier(a_bool.clone()) + BooleanExpression::bitor( + BooleanExpression::value(true), + BooleanExpression::identifier(a_bool.clone()) ) ), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::Or( - box BooleanExpression::identifier(a_bool.clone()), - box BooleanExpression::Value(true), + BooleanExpression::bitor( + BooleanExpression::identifier(a_bool.clone()), + BooleanExpression::value(true), ) ), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::Or( - box BooleanExpression::Value(false), - box BooleanExpression::identifier(a_bool.clone()) + BooleanExpression::bitor( + BooleanExpression::value(false), + BooleanExpression::identifier(a_bool.clone()) ) ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::Or( - box BooleanExpression::identifier(a_bool.clone()), - box BooleanExpression::Value(false), + BooleanExpression::bitor( + BooleanExpression::identifier(a_bool.clone()), + BooleanExpression::value(false), ) ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::Or( - box BooleanExpression::Value(true), - box BooleanExpression::Value(false), + BooleanExpression::bitor( + BooleanExpression::value(true), + BooleanExpression::value(false), ) ), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::Or( - box BooleanExpression::Value(false), - box BooleanExpression::Value(true), + BooleanExpression::bitor( + BooleanExpression::value(false), + BooleanExpression::value(true), ) ), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::Or( - box BooleanExpression::Value(true), - box BooleanExpression::Value(true), + BooleanExpression::bitor( + BooleanExpression::value(true), + BooleanExpression::value(true), ) ), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::Or( - box BooleanExpression::Value(false), - box BooleanExpression::Value(false), + BooleanExpression::bitor( + BooleanExpression::value(false), + BooleanExpression::value(false), ) ), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } } diff --git a/zokrates_analysis/src/reducer/constants_reader.rs b/zokrates_analysis/src/reducer/constants_reader.rs index f991f77fd..ca882dfdd 100644 --- a/zokrates_analysis/src/reducer/constants_reader.rs +++ b/zokrates_analysis/src/reducer/constants_reader.rs @@ -2,15 +2,13 @@ use crate::reducer::ConstantDefinitions; use zokrates_ast::typed::{ - folder::*, identifier::FrameIdentifier, ArrayExpression, ArrayExpressionInner, ArrayType, - BooleanExpression, CoreIdentifier, DeclarationConstant, Expr, FieldElementExpression, Id, - Identifier, IdentifierExpression, StructExpression, StructExpressionInner, StructType, - TupleExpression, TupleExpressionInner, TupleType, TypedProgram, TypedSymbolDeclaration, - UBitwidth, UExpression, UExpressionInner, + folder::*, identifier::FrameIdentifier, CoreIdentifier, DeclarationConstant, Expr, Id, + Identifier, IdentifierExpression, IdentifierOrExpression, TypedExpression, TypedProgram, + TypedSymbolDeclaration, UExpression, UExpressionInner, }; use zokrates_field::Field; -use std::convert::{TryFrom, TryInto}; +use std::convert::TryFrom; pub struct ConstantsReader<'a, 'ast, T> { constants: &'a ConstantDefinitions<'ast, T>, @@ -44,7 +42,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { match self.constants.get(&c).cloned() { Some(e) => match UExpression::try_from(e).unwrap().into_inner() { - UExpressionInner::Value(v) => DeclarationConstant::Concrete(v as u32), + UExpressionInner::Value(v) => DeclarationConstant::Concrete(v.value as u32), _ => unreachable!(), }, None => DeclarationConstant::Constant(c), @@ -54,179 +52,31 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { } } - fn fold_field_expression( + fn fold_identifier_expression< + E: Expr<'ast, T> + Id<'ast, T> + From>, + >( &mut self, - e: FieldElementExpression<'ast, T>, - ) -> FieldElementExpression<'ast, T> { - match e { - FieldElementExpression::Identifier(IdentifierExpression { + ty: &E::Ty, + e: IdentifierExpression<'ast, E>, + ) -> IdentifierOrExpression<'ast, T, E> { + match e.id { + Identifier { id: - Identifier { - id: - FrameIdentifier { - id: CoreIdentifier::Constant(c), - frame: _, - }, - version, + FrameIdentifier { + id: CoreIdentifier::Constant(c), + frame: _, }, - .. - }) => { + version, + } => { assert_eq!(version, 0); match self.constants.get(&c).cloned() { - Some(v) => v.try_into().unwrap(), - None => FieldElementExpression::identifier(Identifier::from( - CoreIdentifier::Constant(c), + Some(v) => IdentifierOrExpression::Expression(E::from(v).into_inner()), + None => IdentifierOrExpression::Identifier(IdentifierExpression::new( + CoreIdentifier::Constant(c).into(), )), } } - e => fold_field_expression(self, e), - } - } - - fn fold_boolean_expression( - &mut self, - e: BooleanExpression<'ast, T>, - ) -> BooleanExpression<'ast, T> { - match e { - BooleanExpression::Identifier(IdentifierExpression { - id: - Identifier { - id: - FrameIdentifier { - id: CoreIdentifier::Constant(c), - frame: _, - }, - version, - }, - .. - }) => { - assert_eq!(version, 0); - match self.constants.get(&c).cloned() { - Some(v) => v.try_into().unwrap(), - None => { - BooleanExpression::identifier(Identifier::from(CoreIdentifier::Constant(c))) - } - } - } - e => fold_boolean_expression(self, e), - } - } - - fn fold_uint_expression_inner( - &mut self, - ty: UBitwidth, - e: UExpressionInner<'ast, T>, - ) -> UExpressionInner<'ast, T> { - match e { - UExpressionInner::Identifier(IdentifierExpression { - id: - Identifier { - id: - FrameIdentifier { - id: CoreIdentifier::Constant(c), - frame: _, - }, - version, - }, - .. - }) => { - assert_eq!(version, 0); - match self.constants.get(&c).cloned() { - Some(v) => UExpression::try_from(v).unwrap().into_inner(), - None => UExpression::identifier(Identifier::from(CoreIdentifier::Constant(c))), - } - } - e => fold_uint_expression_inner(self, ty, e), - } - } - - fn fold_array_expression_inner( - &mut self, - ty: &ArrayType<'ast, T>, - e: ArrayExpressionInner<'ast, T>, - ) -> ArrayExpressionInner<'ast, T> { - match e { - ArrayExpressionInner::Identifier(IdentifierExpression { - id: - Identifier { - id: - FrameIdentifier { - id: CoreIdentifier::Constant(c), - frame: _, - }, - version, - }, - .. - }) => { - assert_eq!(version, 0); - match self.constants.get(&c).cloned() { - Some(v) => ArrayExpression::try_from(v).unwrap().into_inner(), - None => { - ArrayExpression::identifier(Identifier::from(CoreIdentifier::Constant(c))) - } - } - } - e => fold_array_expression_inner(self, ty, e), - } - } - - fn fold_tuple_expression_inner( - &mut self, - ty: &TupleType<'ast, T>, - e: TupleExpressionInner<'ast, T>, - ) -> TupleExpressionInner<'ast, T> { - match e { - TupleExpressionInner::Identifier(IdentifierExpression { - id: - Identifier { - id: - FrameIdentifier { - id: CoreIdentifier::Constant(c), - frame: _, - }, - version, - }, - .. - }) => { - assert_eq!(version, 0); - match self.constants.get(&c).cloned() { - Some(v) => TupleExpression::try_from(v).unwrap().into_inner(), - None => { - TupleExpression::identifier(Identifier::from(CoreIdentifier::Constant(c))) - } - } - } - e => fold_tuple_expression_inner(self, ty, e), - } - } - - fn fold_struct_expression_inner( - &mut self, - ty: &StructType<'ast, T>, - e: StructExpressionInner<'ast, T>, - ) -> StructExpressionInner<'ast, T> { - match e { - StructExpressionInner::Identifier(IdentifierExpression { - id: - Identifier { - id: - FrameIdentifier { - id: CoreIdentifier::Constant(c), - frame: _, - }, - version, - }, - .. - }) => { - assert_eq!(version, 0); - match self.constants.get(&c).cloned() { - Some(v) => StructExpression::try_from(v).unwrap().into_inner(), - None => { - StructExpression::identifier(Identifier::from(CoreIdentifier::Constant(c))) - } - } - } - e => fold_struct_expression_inner(self, ty, e), + _ => fold_identifier_expression(self, ty, e), } } } diff --git a/zokrates_analysis/src/reducer/constants_writer.rs b/zokrates_analysis/src/reducer/constants_writer.rs index 50a6d25ad..a96d0c8d5 100644 --- a/zokrates_analysis/src/reducer/constants_writer.rs +++ b/zokrates_analysis/src/reducer/constants_writer.rs @@ -3,18 +3,17 @@ use crate::reducer::{ constants_reader::ConstantsReader, reduce_function, ConstantDefinitions, Error, }; -use std::collections::{BTreeMap, HashSet}; +use std::collections::HashSet; use zokrates_ast::typed::{ - result_folder::*, Constant, OwnedTypedModuleId, Typed, TypedConstant, TypedConstantSymbol, - TypedConstantSymbolDeclaration, TypedModuleId, TypedProgram, TypedSymbolDeclaration, - UExpression, + result_folder::*, Constant, ModuleId, OwnedModuleId, Typed, TypedConstant, TypedConstantSymbol, + TypedConstantSymbolDeclaration, TypedProgram, TypedSymbolDeclaration, UExpression, }; use zokrates_field::Field; pub struct ConstantsWriter<'ast, T> { - treated: HashSet, + treated: HashSet, constants: ConstantDefinitions<'ast, T>, - location: OwnedTypedModuleId, + location: OwnedModuleId, program: TypedProgram<'ast, T>, } @@ -28,22 +27,19 @@ impl<'ast, T: Field> ConstantsWriter<'ast, T> { } } - fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId { + fn change_location(&mut self, location: OwnedModuleId) -> OwnedModuleId { let prev = self.location.clone(); self.location = location; self.treated.insert(self.location.clone()); prev } - fn treated(&self, id: &TypedModuleId) -> bool { + fn treated(&self, id: &ModuleId) -> bool { self.treated.contains(id) } fn update_program(&mut self) { - let mut p = TypedProgram { - main: "".into(), - modules: BTreeMap::default(), - }; + let mut p = TypedProgram::default(); std::mem::swap(&mut self.program, &mut p); self.program = ConstantsReader::with_constants(&self.constants).read_into_program(p); } @@ -59,10 +55,7 @@ impl<'ast, T: Field> ConstantsWriter<'ast, T> { impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantsWriter<'ast, T> { type Error = Error; - fn fold_module_id( - &mut self, - id: OwnedTypedModuleId, - ) -> Result { + fn fold_module_id(&mut self, id: OwnedModuleId) -> Result { // anytime we encounter a module id, visit the corresponding module if it hasn't been done yet if !self.treated(&id) { let current_m_id = self.change_location(id.clone()); @@ -114,15 +107,15 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantsWriter<'ast, T> { // wrap this expression in a function let wrapper = TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return(c.expression)], + statements: vec![TypedStatement::ret(c.expression)], signature: DeclarationSignature::new().output(c.ty.clone()), }; let mut inlined_wrapper = reduce_function(wrapper, &self.program)?; - if let TypedStatement::Return(expression) = - inlined_wrapper.statements.pop().unwrap() - { + if let TypedStatement::Return(ret) = inlined_wrapper.statements.pop().unwrap() { + let expression = ret.inner; + if !expression.is_constant() { return Err(Error::ConstantReduction(id.id.to_string(), id.module)); }; diff --git a/zokrates_analysis/src/reducer/inline.rs b/zokrates_analysis/src/reducer/inline.rs index 8d3727d3d..254952498 100644 --- a/zokrates_analysis/src/reducer/inline.rs +++ b/zokrates_analysis/src/reducer/inline.rs @@ -91,7 +91,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( .map(|g| { g.as_ref() .map(|g| match g.as_inner() { - UExpressionInner::Value(v) => Ok(*v as u32), + UExpressionInner::Value(v) => Ok(v.value as u32), _ => Err(()), }) .transpose() @@ -143,12 +143,12 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( assert_eq!(f.arguments.len(), arguments.len()); let generic_bindings = assignment.0.into_iter().map(|(identifier, value)| { - TypedStatement::Definition( + TypedStatement::definition( TypedAssignee::Identifier(Variable::uint( CoreIdentifier::from(identifier), UBitwidth::B32, )), - TypedExpression::from(UExpression::from(value)).into(), + TypedExpression::from(UExpression::from(value)), ) }); @@ -156,7 +156,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( .arguments .into_iter() .zip(inferred_signature.inputs.clone()) - .map(|(p, t)| ConcreteVariable::new(p.id.id, t, false)) + .map(|(p, t)| ConcreteVariable::new(p.id.id, t)) .map(Variable::from) .collect(); @@ -167,7 +167,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( assert_eq!(returns.len(), 1); let return_value = match returns.pop().unwrap() { - TypedStatement::Return(e) => e, + TypedStatement::Return(e) => e.inner, _ => unreachable!(), }; diff --git a/zokrates_analysis/src/reducer/mod.rs b/zokrates_analysis/src/reducer/mod.rs index 826cd6663..3328bf5cb 100644 --- a/zokrates_analysis/src/reducer/mod.rs +++ b/zokrates_analysis/src/reducer/mod.rs @@ -20,18 +20,24 @@ mod shallow_ssa; use self::inline::InlineValue; use self::inline::{inline_call, InlineError}; use std::collections::HashMap; -use zokrates_ast::typed::result_folder::*; +use zokrates_ast::common::ResultFold; +use zokrates_ast::common::WithSpan; use zokrates_ast::typed::DeclarationParameter; use zokrates_ast::typed::Folder; +use zokrates_ast::typed::SliceExpression; +use zokrates_ast::typed::SliceOrExpression; +use zokrates_ast::typed::{result_folder::*, ArrayExpression}; +use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable}; + use zokrates_ast::typed::TypedAssignee; use zokrates_ast::typed::{ - ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall, - FunctionCallExpression, FunctionCallOrExpression, Id, OwnedTypedModuleId, TypedExpression, - TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedModule, TypedProgram, - TypedStatement, UExpression, UExpressionInner, + BlockExpression, CoreIdentifier, Expr, FunctionCall, FunctionCallExpression, + FunctionCallOrExpression, Id, TypedExpression, TypedFunction, TypedFunctionSymbol, + TypedFunctionSymbolDeclaration, TypedModule, TypedProgram, TypedStatement, UExpression, + UExpressionInner, }; -use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable}; +use zokrates_ast::untyped::OwnedModuleId; use zokrates_field::Field; use self::constants_writer::ConstantsWriter; @@ -53,7 +59,7 @@ pub enum Error { Incompatible(String), GenericsInMain, LoopTooLarge(u128), - ConstantReduction(String, OwnedTypedModuleId), + ConstantReduction(String, OwnedModuleId), NonConstant(String), Type(String), Propagation(propagation::Error), @@ -119,8 +125,9 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { ty: &E::Ty, e: FunctionCallExpression<'ast, T, E>, ) -> Result, Self::Error> { - // generics are already in ssa form + let span = e.get_span(); + // generics are already in ssa form let generics: Vec<_> = e .generics .into_iter() @@ -166,7 +173,9 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { let input_bindings: Vec<_> = input_variables .into_iter() .zip(arguments) - .map(|(v, a)| TypedStatement::definition(self.ssa.fold_assignee(v.into()), a)) + .map(|(v, a)| { + TypedStatement::definition(self.ssa.fold_assignee(v.into()), a).span(span) + }) .collect(); let input_bindings = input_bindings @@ -174,7 +183,8 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { .map(|s| self.propagator.fold_statement(s)) .collect::, _>>()? .into_iter() - .flatten(); + .flatten() + .collect::>(); self.statement_buffer.extend(input_bindings); @@ -183,7 +193,8 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { .map(|s| self.fold_statement(s)) .collect::, _>>()? .into_iter() - .flatten(); + .flatten() + .collect::>(); self.statement_buffer.extend(statements); @@ -208,18 +219,20 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { Err(InlineError::Flat(embed, generics, output_type)) => { let identifier = self.ssa.issue_next_identifier(CoreIdentifier::Call(0)); - let var = Variable::immutable(identifier.clone(), output_type); + let var = Variable::new(identifier.clone(), output_type); let v: TypedAssignee<'ast, T> = var.clone().into(); - self.statement_buffer - .push(TypedStatement::embed_call_definition( + self.statement_buffer.push( + TypedStatement::embed_call_definition( v, EmbedCall::new(embed, generics, arguments), - )); - Ok(FunctionCallOrExpression::Expression(E::identifier( - identifier, - ))) + ) + .span(span), + ); + Ok(FunctionCallOrExpression::Expression( + E::identifier(identifier).span(span), + )) } }; @@ -228,7 +241,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { res } - fn fold_block_expression>( + fn fold_block_expression>( &mut self, b: BlockExpression<'ast, T, E>, ) -> Result, Self::Error> { @@ -258,36 +271,39 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { unreachable!("canonical constant identifiers should not be folded, they should be inlined") } + // here we implement fold_statement and not fold_statement_cases because we do not want the span of the input + // to be applied to all the outputs: a statement which contains a call which gets inline should not hide the + // inlined statements behind its own span fn fold_statement( &mut self, s: TypedStatement<'ast, T>, ) -> Result>, Self::Error> { let res = match s { - TypedStatement::For(v, from, to, statements) => { - let from = self.ssa.fold_uint_expression(from); + TypedStatement::For(s) => { + let from = self.ssa.fold_uint_expression(s.from); let from = self.propagator.fold_uint_expression(from)?; let from = self.fold_uint_expression(from)?; let from = self.propagator.fold_uint_expression(from)?; - let to = self.ssa.fold_uint_expression(to); + let to = self.ssa.fold_uint_expression(s.to); let to = self.propagator.fold_uint_expression(to)?; let to = self.fold_uint_expression(to)?; let to = self.propagator.fold_uint_expression(to)?; match (from.as_inner(), to.as_inner()) { (UExpressionInner::Value(from), UExpressionInner::Value(to)) - if to - from > MAX_FOR_LOOP_SIZE => + if to.value - from.value > MAX_FOR_LOOP_SIZE => { - Err(Error::LoopTooLarge(to.saturating_sub(*from))) + Err(Error::LoopTooLarge(to.value.saturating_sub(from.value))) } - (UExpressionInner::Value(from), UExpressionInner::Value(to)) => Ok((*from - ..*to) + (UExpressionInner::Value(from), UExpressionInner::Value(to)) => Ok((from.value + ..to.value) .flat_map(|index| { std::iter::once(TypedStatement::definition( - v.clone().into(), + s.var.clone().into(), UExpression::from(index as u32).into(), )) - .chain(statements.clone()) + .chain(s.statements.clone()) .map(|s| self.fold_statement(s)) .collect::>() }) @@ -330,39 +346,33 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { Ok(self.statement_buffer.drain(..).chain(res).collect()) } - fn fold_array_expression_inner( + fn fold_slice_expression( &mut self, - array_ty: &ArrayType<'ast, T>, - e: ArrayExpressionInner<'ast, T>, - ) -> Result, Self::Error> { - match e { - ArrayExpressionInner::Slice(box array, box from, box to) => { - let array = self.ssa.fold_array_expression(array); - let array = self.propagator.fold_array_expression(array)?; - let array = self.fold_array_expression(array)?; - let array = self.propagator.fold_array_expression(array)?; - - let from = self.ssa.fold_uint_expression(from); - let from = self.propagator.fold_uint_expression(from)?; - let from = self.fold_uint_expression(from)?; - let from = self.propagator.fold_uint_expression(from)?; - - let to = self.ssa.fold_uint_expression(to); - let to = self.propagator.fold_uint_expression(to)?; - let to = self.fold_uint_expression(to)?; - let to = self.propagator.fold_uint_expression(to)?; - - match (from.as_inner(), to.as_inner()) { - (UExpressionInner::Value(..), UExpressionInner::Value(..)) => { - Ok(ArrayExpressionInner::Slice(box array, box from, box to)) - } - _ => Err(Error::NonConstant(format!( - "Slice bounds must be compile time constants, found {}", - ArrayExpressionInner::Slice(box array, box from, box to) - ))), - } - } - _ => fold_array_expression_inner(self, array_ty, e), + e: zokrates_ast::typed::SliceExpression<'ast, T>, + ) -> Result, Self::Error> { + let array = self.ssa.fold_array_expression(*e.array); + let array = self.propagator.fold_array_expression(array)?; + let array = self.fold_array_expression(array)?; + let array = self.propagator.fold_array_expression(array)?; + + let from = self.ssa.fold_uint_expression(*e.from); + let from = self.propagator.fold_uint_expression(from)?; + let from = self.fold_uint_expression(from)?; + let from = self.propagator.fold_uint_expression(from)?; + + let to = self.ssa.fold_uint_expression(*e.to); + let to = self.propagator.fold_uint_expression(to)?; + let to = self.fold_uint_expression(to)?; + let to = self.propagator.fold_uint_expression(to)?; + + match (from.as_inner(), to.as_inner()) { + (UExpressionInner::Value(..), UExpressionInner::Value(..)) => Ok( + SliceOrExpression::Slice(SliceExpression::new(array, from, to)), + ), + _ => Err(Error::NonConstant(format!( + "Slice bounds must be compile time constants, found {}", + ArrayExpression::slice(array, from, to) + ))), } } } @@ -405,6 +415,7 @@ pub fn reduce_program(p: TypedProgram) -> Result, E )] .into_iter() .collect(), + ..p }) } _ => Err(Error::GenericsInMain), @@ -426,17 +437,16 @@ mod tests { use zokrates_ast::typed::types::DeclarationSignature; use zokrates_ast::typed::types::{DeclarationConstant, GTupleType}; use zokrates_ast::typed::{ - ArrayExpression, ArrayExpressionInner, DeclarationFunctionKey, DeclarationType, - DeclarationVariable, FieldElementExpression, GenericIdentifier, Identifier, - OwnedTypedModuleId, Select, TupleExpressionInner, TupleType, Type, TypedExpression, - TypedExpressionOrSpread, UBitwidth, UExpressionInner, Variable, + ArrayExpression, DeclarationFunctionKey, DeclarationType, DeclarationVariable, + FieldElementExpression, GenericIdentifier, Identifier, Select, TupleExpression, TupleType, + Type, TypedExpression, TypedExpressionOrSpread, UBitwidth, Variable, }; use zokrates_field::Bn128Field; use lazy_static::lazy_static; lazy_static! { - static ref MAIN_MODULE_ID: OwnedTypedModuleId = OwnedTypedModuleId::from("main"); + static ref MAIN_MODULE_ID: OwnedModuleId = OwnedModuleId::from("main"); } #[test] @@ -463,7 +473,7 @@ mod tests { let foo: TypedFunction = TypedFunction { arguments: vec![DeclarationVariable::field_element("a").into()], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( FieldElementExpression::identifier("a".into()).into(), )], signature: DeclarationSignature::new() @@ -507,7 +517,7 @@ mod tests { .annotate(UBitwidth::B32) .into(), ), - TypedStatement::Return(FieldElementExpression::identifier("a".into()).into()), + TypedStatement::ret(FieldElementExpression::identifier("a".into()).into()), ], signature: DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) @@ -515,6 +525,7 @@ mod tests { }; let p = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -562,7 +573,7 @@ mod tests { Variable::field_element(Identifier::from("a").version(2)).into(), FieldElementExpression::identifier(Identifier::from("a").in_frame(1)).into(), ), - TypedStatement::Return( + TypedStatement::ret( FieldElementExpression::identifier(Identifier::from("a").version(2)).into(), ), ], @@ -572,6 +583,7 @@ mod tests { }; let expected: TypedProgram = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -636,7 +648,7 @@ mod tests { GenericIdentifier::with_name("K").with_index(0), ) .into()], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( ArrayExpression::identifier("a".into()) .annotate(Type::FieldElement, 1u32) .into(), @@ -659,9 +671,9 @@ mod tests { ), TypedStatement::definition( Variable::array("b", Type::FieldElement, 1u32).into(), - ArrayExpressionInner::Value( - vec![FieldElementExpression::identifier("a".into()).into()].into(), - ) + ArrayExpression::value(vec![ + FieldElementExpression::identifier("a".into()).into() + ]) .annotate(Type::FieldElement, 1u32) .into(), ), @@ -684,7 +696,7 @@ mod tests { .annotate(UBitwidth::B32) .into(), ), - TypedStatement::Return( + TypedStatement::ret( (FieldElementExpression::identifier("a".into()) + FieldElementExpression::select( ArrayExpression::identifier("b".into()) @@ -700,6 +712,7 @@ mod tests { }; let p = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -734,9 +747,9 @@ mod tests { statements: vec![ TypedStatement::definition( Variable::array("b", Type::FieldElement, 1u32).into(), - ArrayExpressionInner::Value( - vec![FieldElementExpression::identifier("a".into()).into()].into(), - ) + ArrayExpression::value(vec![ + FieldElementExpression::identifier("a".into()).into() + ]) .annotate(Type::FieldElement, 1u32) .into(), ), @@ -754,7 +767,7 @@ mod tests { .annotate(Type::FieldElement, 1u32) .into(), ), - TypedStatement::Return( + TypedStatement::ret( (FieldElementExpression::identifier("a".into()) + FieldElementExpression::select( ArrayExpression::identifier(Identifier::from("b").version(1)) @@ -770,6 +783,7 @@ mod tests { }; let expected = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -834,7 +848,7 @@ mod tests { GenericIdentifier::with_name("K").with_index(0), ) .into()], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( ArrayExpression::identifier("a".into()) .annotate(Type::FieldElement, 1u32) .into(), @@ -842,6 +856,8 @@ mod tests { signature: foo_signature.clone(), }; + use std::ops::Sub; + let main: TypedFunction = TypedFunction { arguments: vec![DeclarationVariable::field_element("a").into()], statements: vec![ @@ -859,16 +875,15 @@ mod tests { Variable::array( "b", Type::FieldElement, - UExpressionInner::Sub( - box UExpression::identifier("n".into()).annotate(UBitwidth::B32), - box 1u32.into(), - ) - .annotate(UBitwidth::B32), + UExpression::sub( + UExpression::identifier("n".into()).annotate(UBitwidth::B32), + 1u32.into(), + ), ) .into(), - ArrayExpressionInner::Value( - vec![FieldElementExpression::identifier("a".into()).into()].into(), - ) + ArrayExpression::value(vec![ + FieldElementExpression::identifier("a".into()).into() + ]) .annotate(Type::FieldElement, 1u32) .into(), ), @@ -891,7 +906,7 @@ mod tests { .annotate(UBitwidth::B32) .into(), ), - TypedStatement::Return( + TypedStatement::ret( (FieldElementExpression::identifier("a".into()) + FieldElementExpression::select( ArrayExpression::identifier("b".into()) @@ -907,6 +922,7 @@ mod tests { }; let p = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -941,10 +957,9 @@ mod tests { statements: vec![ TypedStatement::definition( Variable::array("b", Type::FieldElement, 1u32).into(), - ArrayExpressionInner::Value( - vec![FieldElementExpression::identifier(Identifier::from("a")).into()] - .into(), - ) + ArrayExpression::value(vec![ + FieldElementExpression::identifier("a".into()).into() + ]) .annotate(Type::FieldElement, 1u32) .into(), ), @@ -962,7 +977,7 @@ mod tests { .annotate(Type::FieldElement, 1u32) .into(), ), - TypedStatement::Return( + TypedStatement::ret( (FieldElementExpression::identifier("a".into()) + FieldElementExpression::select( ArrayExpression::identifier(Identifier::from("b").version(1)) @@ -978,6 +993,7 @@ mod tests { }; let expected = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -1021,10 +1037,6 @@ mod tests { // expected: // def main() { - // # PUSH CALL to foo::<1> - // # PUSH CALL to bar::<2> - // # POP CALL - // # POP CALL // return; // } @@ -1056,22 +1068,19 @@ mod tests { UExpression::identifier("K".into()).annotate(UBitwidth::B32), ) .into(), - ArrayExpressionInner::Slice( - box ArrayExpression::function_call( + ArrayExpression::slice( + ArrayExpression::function_call( DeclarationFunctionKey::with_location("main", "bar") .signature(foo_signature.clone()), vec![None], - vec![ArrayExpressionInner::Value( - vec![ - TypedExpressionOrSpread::Spread( - ArrayExpression::identifier("a".into()) - .annotate(Type::FieldElement, 1u32) - .into(), - ), - FieldElementExpression::Number(Bn128Field::from(0)).into(), - ] - .into(), - ) + vec![ArrayExpression::value(vec![ + TypedExpressionOrSpread::Spread( + ArrayExpression::identifier("a".into()) + .annotate(Type::FieldElement, 1u32) + .into(), + ), + FieldElementExpression::value(Bn128Field::from(0)).into(), + ]) .annotate( Type::FieldElement, UExpression::identifier("K".into()).annotate(UBitwidth::B32) @@ -1084,16 +1093,12 @@ mod tests { UExpression::identifier("K".into()).annotate(UBitwidth::B32) + 1u32.into(), ), - box 0u32.into(), - box UExpression::identifier("K".into()).annotate(UBitwidth::B32), - ) - .annotate( - Type::FieldElement, + 0u32.into(), UExpression::identifier("K".into()).annotate(UBitwidth::B32), ) .into(), ), - TypedStatement::Return( + TypedStatement::ret( ArrayExpression::identifier("ret".into()) .annotate( Type::FieldElement, @@ -1114,7 +1119,7 @@ mod tests { DeclarationConstant::Generic(GenericIdentifier::with_name("K").with_index(0)), ) .into()], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( ArrayExpression::identifier("a".into()) .annotate( Type::FieldElement, @@ -1134,17 +1139,18 @@ mod tests { DeclarationFunctionKey::with_location("main", "foo") .signature(foo_signature.clone()), vec![None], - vec![ArrayExpressionInner::Value( - vec![FieldElementExpression::Number(Bn128Field::from(1)).into()].into(), + vec![ArrayExpression::value(vec![FieldElementExpression::value( + Bn128Field::from(1), ) + .into()]) .annotate(Type::FieldElement, 1u32) .into()], ) .annotate(Type::FieldElement, 1u32) .into(), ), - TypedStatement::Return( - TupleExpressionInner::Value(vec![]) + TypedStatement::ret( + TupleExpression::value(vec![]) .annotate(TupleType::new(vec![])) .into(), ), @@ -1180,14 +1186,15 @@ mod tests { )] .into_iter() .collect(), + module_map: Default::default(), }; let reduced = reduce_program(p); let expected_main = TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( - TupleExpressionInner::Value(vec![]) + statements: vec![TypedStatement::ret( + TupleExpression::value(vec![]) .annotate(TupleType::new(vec![])) .into(), )], @@ -1195,6 +1202,7 @@ mod tests { }; let expected = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -1247,7 +1255,7 @@ mod tests { GenericIdentifier::with_name("K").with_index(0), ) .into()], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( ArrayExpression::identifier("a".into()) .annotate(Type::FieldElement, 1u32) .into(), @@ -1264,15 +1272,15 @@ mod tests { DeclarationFunctionKey::with_location("main", "foo") .signature(foo_signature.clone()), vec![None], - vec![ArrayExpressionInner::Value(vec![].into()) + vec![ArrayExpression::value(vec![]) .annotate(Type::FieldElement, 0u32) .into()], ) .annotate(Type::FieldElement, 1u32) .into(), ), - TypedStatement::Return( - TupleExpressionInner::Value(vec![]) + TypedStatement::ret( + TupleExpression::value(vec![]) .annotate(TupleType::new(vec![])) .into(), ), @@ -1308,6 +1316,7 @@ mod tests { )] .into_iter() .collect(), + module_map: Default::default(), }; let reduced = reduce_program(p); diff --git a/zokrates_analysis/src/reducer/shallow_ssa.rs b/zokrates_analysis/src/reducer/shallow_ssa.rs index aaec4fd56..8bb244a73 100644 --- a/zokrates_analysis/src/reducer/shallow_ssa.rs +++ b/zokrates_analysis/src/reducer/shallow_ssa.rs @@ -155,16 +155,11 @@ impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> { } } - fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { - match s { - // only fold bounds of for loop statements - TypedStatement::For(v, from, to, stats) => { - let from = self.fold_uint_expression(from); - let to = self.fold_uint_expression(to); - vec![TypedStatement::For(v, from, to, stats)] - } - s => fold_statement(self, s), - } + // only fold bounds of for loop statements + fn fold_for_statement(&mut self, s: ForStatement<'ast, T>) -> Vec> { + let from = self.fold_uint_expression(s.from); + let to = self.fold_uint_expression(s.to); + vec![TypedStatement::for_(s.var, from, to, s.statements)] } // retrieve the latest version @@ -185,6 +180,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> { #[cfg(test)] mod tests { use super::*; + use std::ops::*; use zokrates_ast::typed::types::DeclarationSignature; use zokrates_field::Bn128Field; mod normal { @@ -205,7 +201,7 @@ mod tests { statements: vec![ TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element(Identifier::from("foo"))), - FieldElementExpression::Number(Bn128Field::from(4)).into(), + FieldElementExpression::value(Bn128Field::from(4)).into(), ), TypedStatement::definition( TypedAssignee::Identifier(Variable::uint( @@ -214,19 +210,19 @@ mod tests { )), UExpression::from(0u32).into(), ), - TypedStatement::For( - Variable::new("i", Type::Uint(UBitwidth::B32), false), + TypedStatement::for_( + Variable::new("i", Type::Uint(UBitwidth::B32)), UExpression::identifier("i".into()).annotate(UBitwidth::B32), 2u32.into(), vec![TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element(Identifier::from( "foo", ))), - FieldElementExpression::Number(Bn128Field::from(5)).into(), + FieldElementExpression::value(Bn128Field::from(5)).into(), )], ), - TypedStatement::Return( - TupleExpressionInner::Value(vec![]) + TypedStatement::ret( + TupleExpression::value(vec![]) .annotate(TupleType::new(vec![])) .into(), ), @@ -254,7 +250,7 @@ mod tests { let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), - FieldElementExpression::Number(Bn128Field::from(5)).into(), + FieldElementExpression::value(Bn128Field::from(5)).into(), ); assert_eq!( u.fold_statement(s), @@ -262,13 +258,13 @@ mod tests { TypedAssignee::Identifier(Variable::field_element( Identifier::from("a").version(0) )), - FieldElementExpression::Number(Bn128Field::from(5)).into() + FieldElementExpression::value(Bn128Field::from(5)).into() )] ); let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), - FieldElementExpression::Number(Bn128Field::from(6)).into(), + FieldElementExpression::value(Bn128Field::from(6)).into(), ); assert_eq!( u.fold_statement(s), @@ -276,7 +272,7 @@ mod tests { TypedAssignee::Identifier(Variable::field_element( Identifier::from("a").version(1) )), - FieldElementExpression::Number(Bn128Field::from(6)).into() + FieldElementExpression::value(Bn128Field::from(6)).into() )] ); @@ -301,7 +297,7 @@ mod tests { let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), - FieldElementExpression::Number(Bn128Field::from(5)).into(), + FieldElementExpression::value(Bn128Field::from(5)).into(), ); assert_eq!( u.fold_statement(s), @@ -309,15 +305,15 @@ mod tests { TypedAssignee::Identifier(Variable::field_element( Identifier::from("a").version(0) )), - FieldElementExpression::Number(Bn128Field::from(5)).into() + FieldElementExpression::value(Bn128Field::from(5)).into() )] ); let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), - FieldElementExpression::Add( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(1)), + FieldElementExpression::add( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::value(Bn128Field::from(1)), ) .into(), ); @@ -327,9 +323,9 @@ mod tests { TypedAssignee::Identifier(Variable::field_element( Identifier::from("a").version(1) )), - FieldElementExpression::Add( - box FieldElementExpression::identifier(Identifier::from("a").version(0)), - box FieldElementExpression::Number(Bn128Field::from(1)) + FieldElementExpression::add( + FieldElementExpression::identifier(Identifier::from("a").version(0)), + FieldElementExpression::value(Bn128Field::from(1)) ) .into() )] @@ -350,7 +346,7 @@ mod tests { let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ); assert_eq!( u.fold_statement(s), @@ -358,7 +354,7 @@ mod tests { TypedAssignee::Identifier(Variable::field_element( Identifier::from("a").version(0) )), - FieldElementExpression::Number(Bn128Field::from(2)).into() + FieldElementExpression::value(Bn128Field::from(2)).into() )] ); @@ -409,13 +405,10 @@ mod tests { let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)), - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(1)).into(), - FieldElementExpression::Number(Bn128Field::from(1)).into(), - ] - .into(), - ) + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(1)).into(), + FieldElementExpression::value(Bn128Field::from(1)).into(), + ]) .annotate(Type::FieldElement, 2u32) .into(), ); @@ -428,13 +421,10 @@ mod tests { Type::FieldElement, 2u32 )), - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(1)).into(), - FieldElementExpression::Number(Bn128Field::from(1)).into() - ] - .into() - ) + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(1)).into(), + FieldElementExpression::value(Bn128Field::from(1)).into() + ]) .annotate(Type::FieldElement, 2u32) .into() )] @@ -445,7 +435,7 @@ mod tests { box TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)), box UExpression::from(1u32), ), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ); assert_eq!(u.fold_statement(s.clone()), vec![s]); @@ -465,30 +455,21 @@ mod tests { let array_of_array_ty = Type::array((Type::array((Type::FieldElement, 2u32)), 2u32)); let s = TypedStatement::definition( - TypedAssignee::Identifier(Variable::new("a", array_of_array_ty.clone(), true)), - ArrayExpressionInner::Value( - vec![ - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(0)).into(), - FieldElementExpression::Number(Bn128Field::from(1)).into(), - ] - .into(), - ) - .annotate(Type::FieldElement, 2u32) - .into(), - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(2)).into(), - FieldElementExpression::Number(Bn128Field::from(3)).into(), - ] - .into(), - ) - .annotate(Type::FieldElement, 2u32) - .into(), - ] + TypedAssignee::Identifier(Variable::new("a", array_of_array_ty.clone())), + ArrayExpression::value(vec![ + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(0)).into(), + FieldElementExpression::value(Bn128Field::from(1)).into(), + ]) + .annotate(Type::FieldElement, 2u32) .into(), - ) + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(3)).into(), + ]) + .annotate(Type::FieldElement, 2u32) + .into(), + ]) .annotate(Type::array((Type::FieldElement, 2u32)), 2u32) .into(), ); @@ -499,31 +480,21 @@ mod tests { TypedAssignee::Identifier(Variable::new( Identifier::from("a").version(0), array_of_array_ty.clone(), - true, )), - ArrayExpressionInner::Value( - vec![ - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(0)).into(), - FieldElementExpression::Number(Bn128Field::from(1)).into(), - ] - .into() - ) - .annotate(Type::FieldElement, 2u32) - .into(), - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(2)).into(), - FieldElementExpression::Number(Bn128Field::from(3)).into(), - ] - .into() - ) - .annotate(Type::FieldElement, 2u32) - .into(), - ] - .into() - ) + ArrayExpression::value(vec![ + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(0)).into(), + FieldElementExpression::value(Bn128Field::from(1)).into(), + ]) + .annotate(Type::FieldElement, 2u32) + .into(), + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(3)).into(), + ]) + .annotate(Type::FieldElement, 2u32) + .into(), + ]) .annotate(Type::array((Type::FieldElement, 2u32)), 2u32) .into(), )] @@ -531,20 +502,13 @@ mod tests { let s: TypedStatement = TypedStatement::definition( TypedAssignee::Select( - box TypedAssignee::Identifier(Variable::new( - "a", - array_of_array_ty.clone(), - true, - )), + box TypedAssignee::Identifier(Variable::new("a", array_of_array_ty.clone())), box UExpression::from(1u32), ), - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(4)).into(), - FieldElementExpression::Number(Bn128Field::from(5)).into(), - ] - .into(), - ) + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(4)).into(), + FieldElementExpression::value(Bn128Field::from(5)).into(), + ]) .annotate(Type::FieldElement, 2u32) .into(), ); @@ -606,7 +570,7 @@ mod tests { Variable::field_element("a").into(), FieldElementExpression::identifier("a".into()).into(), ), - TypedStatement::For( + TypedStatement::for_( Variable::uint("i", UBitwidth::B32), UExpression::identifier("n".into()).annotate(UBitwidth::B32), UExpression::identifier("n".into()).annotate(UBitwidth::B32) @@ -620,7 +584,7 @@ mod tests { Variable::field_element("a").into(), FieldElementExpression::identifier("a".into()).into(), ), - TypedStatement::For( + TypedStatement::for_( Variable::uint("i", UBitwidth::B32), UExpression::identifier("n".into()).annotate(UBitwidth::B32), UExpression::identifier("n".into()).annotate(UBitwidth::B32) @@ -634,7 +598,7 @@ mod tests { Variable::field_element("a").into(), FieldElementExpression::identifier("a".into()).into(), ), - TypedStatement::Return(FieldElementExpression::identifier("a".into()).into()), + TypedStatement::ret(FieldElementExpression::identifier("a".into()).into()), ], signature: DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) @@ -660,7 +624,7 @@ mod tests { Variable::field_element(Identifier::from("a").version(1)).into(), FieldElementExpression::identifier("a".into()).into(), ), - TypedStatement::For( + TypedStatement::for_( Variable::uint("i", UBitwidth::B32), UExpression::identifier(Identifier::from("n").version(1)) .annotate(UBitwidth::B32), @@ -677,7 +641,7 @@ mod tests { Variable::field_element(Identifier::from("a").version(2)).into(), FieldElementExpression::identifier(Identifier::from("a").version(1)).into(), ), - TypedStatement::For( + TypedStatement::for_( Variable::uint("i", UBitwidth::B32), UExpression::identifier(Identifier::from("n").version(1)) .annotate(UBitwidth::B32), @@ -694,7 +658,7 @@ mod tests { Variable::field_element(Identifier::from("a").version(3)).into(), FieldElementExpression::identifier(Identifier::from("a").version(2)).into(), ), - TypedStatement::Return( + TypedStatement::ret( FieldElementExpression::identifier(Identifier::from("a").version(3)).into(), ), ], @@ -720,6 +684,7 @@ mod tests { } mod shadowing { + use super::*; #[test] @@ -751,10 +716,10 @@ mod tests { 1, ))) .into(), - BooleanExpression::Value(true).into(), + BooleanExpression::value(true).into(), ), - TypedStatement::Return( - TupleExpressionInner::Value(vec![]) + TypedStatement::ret( + TupleExpression::value(vec![]) .annotate(TupleType::new(vec![])) .into(), ), @@ -775,10 +740,10 @@ mod tests { 1, ))) .into(), - BooleanExpression::Value(true).into(), + BooleanExpression::value(true).into(), ), - TypedStatement::Return( - TupleExpressionInner::Value(vec![]) + TypedStatement::ret( + TupleExpression::value(vec![]) .annotate(TupleType::new(vec![])) .into(), ), @@ -860,7 +825,7 @@ mod tests { )) .into(), ), - TypedStatement::Return(FieldElementExpression::identifier("a".into()).into()), + TypedStatement::ret(FieldElementExpression::identifier("a".into()).into()), ], signature: DeclarationSignature::new() .generics(vec![Some( @@ -924,7 +889,7 @@ mod tests { )) .into(), ), - TypedStatement::Return( + TypedStatement::ret( FieldElementExpression::identifier(Identifier::from("a").version(3)).into(), ), ], diff --git a/zokrates_analysis/src/uint_optimizer.rs b/zokrates_analysis/src/uint_optimizer.rs index 6a6b217ec..c006cdd67 100644 --- a/zokrates_analysis/src/uint_optimizer.rs +++ b/zokrates_analysis/src/uint_optimizer.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; use std::ops::{BitAnd, Shl, Shr}; -use zokrates_ast::common::FlatEmbed; +use zokrates_ast::common::{FlatEmbed, Fold, WithSpan}; use zokrates_ast::zir::folder::*; use zokrates_ast::zir::*; use zokrates_field::Field; @@ -55,7 +55,7 @@ fn force_no_reduce(e: UExpression) -> UExpression { } impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { - fn fold_select_expression + Fold<'ast, T> + Select<'ast, T>>( + fn fold_select_expression + Fold + Select<'ast, T>>( &mut self, _: &E::Ty, e: SelectExpression<'ast, T, E>, @@ -66,43 +66,45 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { SelectOrExpression::Select(SelectExpression::new(array, force_reduce(index))) } - fn fold_boolean_expression( + fn fold_boolean_expression_cases( &mut self, e: BooleanExpression<'ast, T>, ) -> BooleanExpression<'ast, T> { match e { - BooleanExpression::UintEq(box left, box right) => { - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + BooleanExpression::UintEq(e) => { + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); let left = force_reduce(left); let right = force_reduce(right); - BooleanExpression::UintEq(box left, box right) + BooleanExpression::uint_eq(left, right) } - BooleanExpression::UintLt(box left, box right) => { - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + BooleanExpression::UintLt(e) => { + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); let left = force_reduce(left); let right = force_reduce(right); - BooleanExpression::UintLt(box left, box right) + BooleanExpression::uint_lt(left, right) } - BooleanExpression::UintLe(box left, box right) => { - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + BooleanExpression::UintLe(e) => { + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); let left = force_reduce(left); let right = force_reduce(right); - BooleanExpression::UintLe(box left, box right) + BooleanExpression::uint_le(left, right) } - e => fold_boolean_expression(self, e), + e => fold_boolean_expression_cases(self, e), } } fn fold_uint_expression(&mut self, e: UExpression<'ast, T>) -> UExpression<'ast, T> { + let span = e.get_span(); + if e.metadata.is_some() { return e; } @@ -120,7 +122,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { use self::UExpressionInner::*; let res = match inner { - Value(v) => Value(v).annotate(range).with_max(v), + Value(v) => Value(v.clone()).annotate(range).with_max(v.value), Identifier(id) => Identifier(id.clone()).annotate(range).metadata( self.ids .get(&Variable::uint(id.id.clone(), range)) @@ -151,10 +153,10 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { UExpression::select(values, index).with_max(max_value) } - Add(box left, box right) => { + Add(e) => { // reduce the two terms - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); let left_max = left.metadata.clone().unwrap().max; let right_max = right.metadata.clone().unwrap().max; @@ -187,7 +189,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { UExpression::add(left, right).with_max(max) } - Sub(box left, box right) => { + Sub(e) => { // let `target` the target bitwidth of `left` and `right` // `0 <= left <= max_left` // `0 <= right <= max_right` @@ -205,8 +207,8 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { // smaller or equal to N for target in {8, 16, 32} // reduce the two terms - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); let left_max = left.metadata.clone().unwrap().max; let right_bitwidth = right.metadata.clone().unwrap().bitwidth(); @@ -254,31 +256,31 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { UExpression::sub(left, right).with_max(max) } - Xor(box left, box right) => { + Xor(e) => { // reduce the two terms - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); UExpression::xor(force_reduce(left), force_reduce(right)).with_max(range_max) } - And(box left, box right) => { + And(e) => { // reduce the two terms - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); UExpression::and(force_reduce(left), force_reduce(right)).with_max(range_max) } - Or(box left, box right) => { + Or(e) => { // reduce the two terms - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); UExpression::or(force_reduce(left), force_reduce(right)).with_max(range_max) } - Mult(box left, box right) => { + Mult(e) => { // reduce the two terms - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); let left_max = left.metadata.clone().unwrap().max; let right_max = right.metadata.clone().unwrap().max; @@ -311,52 +313,72 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { UExpression::mult(left, right).with_max(max) } - Div(box left, box right) => { + Div(e) => { // reduce the two terms - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); UExpression::div(force_reduce(left), force_reduce(right)).with_max(range_max) } - Rem(box left, box right) => { + Rem(e) => { // reduce the two terms - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); UExpression::rem(force_reduce(left), force_reduce(right)).with_max(range_max) } - Not(box e) => { - let e = self.fold_uint_expression(e); + Not(e) => { + let inner = self.fold_uint_expression(*e.inner); - UExpressionInner::Not(box force_reduce(e)) - .annotate(range) - .with_max(range_max) + UExpression::not(force_reduce(inner)).with_max(range_max) } - LeftShift(box e, by) => { + LeftShift(e) => { // reduce both terms - let e = self.fold_uint_expression(e); - - let e_max: num_bigint::BigUint = e.metadata.as_ref().unwrap().max.to_biguint(); - let max = e_max - .shl(by as usize) - .bitand(&(2_u128.pow(range as u32) - 1).into()); - - let max = T::try_from(max).unwrap(); - - UExpression::left_shift(force_reduce(e), by).with_max(max) + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); + + match right.into_inner() { + UExpressionInner::Value(by) => { + let e_max: num_bigint::BigUint = + left.metadata.as_ref().unwrap().max.to_biguint(); + let max = e_max + .shl(by.value as usize) + .bitand(&(2_u128.pow(range as u32) - 1).into()); + + let max = T::try_from(max).unwrap(); + + UExpression::left_shift( + force_reduce(left), + UExpression::value(by.value).annotate(UBitwidth::B32), + ) + .with_max(max) + } + _ => unreachable!(), + } } - RightShift(box e, by) => { + RightShift(e) => { // reduce both terms - let e = self.fold_uint_expression(e); - - let e_max: num_bigint::BigUint = e.metadata.as_ref().unwrap().max.to_biguint(); - let max = e_max - .bitand(&(2_u128.pow(range as u32) - 1).into()) - .shr(by as usize); - - let max = T::try_from(max).unwrap(); - - UExpression::right_shift(force_reduce(e), by).with_max(max) + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); + + match right.into_inner() { + UExpressionInner::Value(by) => { + let e_max: num_bigint::BigUint = + left.metadata.as_ref().unwrap().max.to_biguint(); + let max = e_max + .bitand(&(2_u128.pow(range as u32) - 1).into()) + .shr(by.value as usize); + + let max = T::try_from(max).unwrap(); + + UExpression::right_shift( + force_reduce(left), + UExpression::value(by.value).annotate(UBitwidth::B32), + ) + .with_max(max) + } + _ => unreachable!(), + } } Conditional(e) => { let condition = self.fold_boolean_expression(*e.condition); @@ -379,44 +401,52 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { assert!(res.metadata.is_some()); - res + res.span(span) } - fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec> { - match s { - ZirStatement::Definition(a, e) => { - let e = self.fold_expression(e); + fn fold_return_statement(&mut self, s: ReturnStatement<'ast, T>) -> Vec> { + // we need to put back in range to return + vec![ZirStatement::ret( + s.inner + .into_iter() + .map(|e| match e { + ZirExpression::Uint(e) => { + let e = self.fold_uint_expression(e); + + let e = force_reduce(e); - let e = match e { - ZirExpression::Uint(i) => { - let i = force_no_reduce(i); - self.register(a.clone(), i.metadata.clone().unwrap()); - ZirExpression::Uint(i) + ZirExpression::Uint(e) } - e => e, - }; - vec![ZirStatement::Definition(a, e)] + e => self.fold_expression(e), + }) + .collect(), + )] + } + + fn fold_definition_statement( + &mut self, + s: DefinitionStatement<'ast, T>, + ) -> Vec> { + let e = self.fold_expression(s.rhs); + + let e = match e { + ZirExpression::Uint(i) => { + let i = force_no_reduce(i); + self.register(s.assignee.clone(), i.metadata.clone().unwrap()); + ZirExpression::Uint(i) } - // we need to put back in range to return - ZirStatement::Return(expressions) => vec![ZirStatement::Return( - expressions - .into_iter() - .map(|e| match e { - ZirExpression::Uint(e) => { - let e = self.fold_uint_expression(e); - - let e = force_reduce(e); - - ZirExpression::Uint(e) - } - e => self.fold_expression(e), - }) - .collect(), - )], - ZirStatement::MultipleDefinition( - lhs, - ZirExpressionList::EmbedCall(embed, generics, arguments), - ) => { + e => e, + }; + vec![ZirStatement::definition(s.assignee, e)] + } + + fn fold_multiple_definition_statement( + &mut self, + s: MultipleDefinitionStatement<'ast, T>, + ) -> Vec> { + let lhs = s.assignees; + match s.rhs { + ZirExpressionList::EmbedCall(embed, generics, arguments) => { match embed { FlatEmbed::U64FromBits => { assert_eq!(lhs.len(), 1); @@ -467,74 +497,68 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { | FlatEmbed::U32ToBits | FlatEmbed::U64ToBits => { vec![ZirStatement::MultipleDefinition( - lhs, - ZirExpressionList::EmbedCall( - embed, - generics, - arguments - .into_iter() - .map(|e| match e { - ZirExpression::Uint(e) => { - let e = self.fold_uint_expression(e); - let e = force_reduce(e); - ZirExpression::Uint(e) - } - e => self.fold_expression(e), - }) - .collect(), + MultipleDefinitionStatement::new( + lhs, + ZirExpressionList::EmbedCall( + embed, + generics, + arguments + .into_iter() + .map(|e| match e { + ZirExpression::Uint(e) => { + let e = self.fold_uint_expression(e); + let e = force_reduce(e); + ZirExpression::Uint(e) + } + e => self.fold_expression(e), + }) + .collect(), + ), ), )] } _ => { vec![ZirStatement::MultipleDefinition( - lhs, - ZirExpressionList::EmbedCall( - embed, - generics, - arguments - .into_iter() - .map(|e| self.fold_expression(e)) - .collect(), + MultipleDefinitionStatement::new( + lhs, + ZirExpressionList::EmbedCall( + embed, + generics, + arguments + .into_iter() + .map(|e| self.fold_expression(e)) + .collect(), + ), ), )] } } } - ZirStatement::Assertion(BooleanExpression::UintEq(box left, box right), metadata) => { - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); - - // we can only compare two unsigned integers if they are in range - let left = force_reduce(left); - let right = force_reduce(right); - - vec![ZirStatement::Assertion( - BooleanExpression::UintEq(box left, box right), - metadata, - )] - } - ZirStatement::Log(l, e) => vec![ZirStatement::Log( - l, - e.into_iter() - .map(|(t, e)| { - ( - t, - e.into_iter() - .map(|e| match e { - ZirExpression::Uint(e) => { - force_reduce(self.fold_uint_expression(e)).into() - } - e => self.fold_expression(e), - }) - .collect(), - ) - }) - .collect(), - )], - s => fold_statement(self, s), } } + fn fold_log_statement(&mut self, s: LogStatement<'ast, T>) -> Vec> { + vec![ZirStatement::Log(LogStatement::new( + s.format_string, + s.expressions + .into_iter() + .map(|(t, e)| { + ( + t, + e.into_iter() + .map(|e| match e { + ZirExpression::Uint(e) => { + force_reduce(self.fold_uint_expression(e)).into() + } + e => self.fold_expression(e), + }) + .collect(), + ) + }) + .collect(), + ))] + } + fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> { let id = match p.id.get_type() { Type::Uint(bitwidth) => { @@ -730,8 +754,8 @@ mod tests { assert_eq!( UintOptimizer::new() - .fold_uint_expression(UExpression::right_shift(left.clone(), right)), - UExpression::right_shift(left_expected, right_expected).with_max(output_max) + .fold_uint_expression(UExpression::right_shift(left.clone(), right.into())), + UExpression::right_shift(left_expected, right_expected.into()).with_max(output_max) ); } @@ -753,8 +777,8 @@ mod tests { assert_eq!( UintOptimizer::new() - .fold_uint_expression(UExpression::left_shift(left.clone(), right)), - UExpression::left_shift(left_expected, right_expected).with_max(output_max) + .fold_uint_expression(UExpression::left_shift(left.clone(), right.into())), + UExpression::left_shift(left_expected, right_expected.into()).with_max(output_max) ); } @@ -777,7 +801,7 @@ mod tests { assert_eq!( UintOptimizer::new() .fold_uint_expression(UExpression::conditional( - BooleanExpression::Value(true), + BooleanExpression::value(true), consequence, alternative )) diff --git a/zokrates_analysis/src/variable_write_remover.rs b/zokrates_analysis/src/variable_write_remover.rs index b218acc33..9defeeaf5 100644 --- a/zokrates_analysis/src/variable_write_remover.rs +++ b/zokrates_analysis/src/variable_write_remover.rs @@ -6,6 +6,7 @@ use std::collections::HashSet; use std::fmt; +use zokrates_ast::common::{Span, WithSpan}; use zokrates_ast::typed::result_folder::ResultFolder; use zokrates_ast::typed::result_folder::*; use zokrates_ast::typed::types::{MemberId, Type}; @@ -38,6 +39,7 @@ impl<'ast> VariableWriteRemover { indices: Vec>, new_expression: TypedExpression<'ast, T>, statements: &mut HashSet>, + span: Option, ) -> TypedExpression<'ast, T> { let mut indices = indices; @@ -55,25 +57,36 @@ impl<'ast> VariableWriteRemover { match head { Access::Select(box head) => { - statements.insert(TypedStatement::Assertion( - BooleanExpression::UintLt(box head.clone(), box size.into()), + statements.insert(TypedStatement::assertion( + BooleanExpression::uint_lt( + head.clone(), + UExpression::from(size).span(span), + ) + .span(span), RuntimeError::SelectRangeCheck, )); - ArrayExpressionInner::Value( + ArrayExpression::value( (0..size) .map(|i| match inner_ty { Type::Int => unreachable!(), Type::Array(..) => ArrayExpression::conditional( - BooleanExpression::UintEq(EqExpression::new( - i.into(), + BooleanExpression::uint_eq( + UExpression::from(i).span(span), head.clone(), - )), + ) + .span(span), match Self::choose_many( - ArrayExpression::select(base.clone(), i).into(), + ArrayExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span) + .into(), tail.clone(), new_expression.clone(), statements, + span, ) { TypedExpression::Array(e) => e, e => unreachable!( @@ -81,20 +94,32 @@ impl<'ast> VariableWriteRemover { e.get_type() ), }, - ArrayExpression::select(base.clone(), i), + ArrayExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span), ConditionalKind::IfElse, ) .into(), Type::Struct(..) => StructExpression::conditional( - BooleanExpression::UintEq(EqExpression::new( - i.into(), + BooleanExpression::uint_eq( + UExpression::from(i).span(span), head.clone(), - )), + ) + .span(span), match Self::choose_many( - StructExpression::select(base.clone(), i).into(), + StructExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span) + .span(span) + .into(), tail.clone(), new_expression.clone(), statements, + span, ) { TypedExpression::Struct(e) => e, e => unreachable!( @@ -102,20 +127,31 @@ impl<'ast> VariableWriteRemover { e.get_type() ), }, - StructExpression::select(base.clone(), i), + StructExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span), ConditionalKind::IfElse, ) .into(), Type::Tuple(..) => TupleExpression::conditional( - BooleanExpression::UintEq(EqExpression::new( - i.into(), + BooleanExpression::uint_eq( + UExpression::from(i).span(span), head.clone(), - )), + ) + .span(span), match Self::choose_many( - TupleExpression::select(base.clone(), i).into(), + TupleExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span) + .into(), tail.clone(), new_expression.clone(), statements, + span, ) { TypedExpression::Tuple(e) => e, e => unreachable!( @@ -123,21 +159,31 @@ impl<'ast> VariableWriteRemover { e.get_type() ), }, - TupleExpression::select(base.clone(), i), + TupleExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span), ConditionalKind::IfElse, ) .into(), Type::FieldElement => FieldElementExpression::conditional( - BooleanExpression::UintEq(EqExpression::new( - i.into(), + BooleanExpression::uint_eq( + UExpression::from(i).span(span), head.clone(), - )), + ) + .span(span), match Self::choose_many( - FieldElementExpression::select(base.clone(), i) - .into(), + FieldElementExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span) + .into(), tail.clone(), new_expression.clone(), statements, + span, ) { TypedExpression::FieldElement(e) => e, e => unreachable!( @@ -145,20 +191,31 @@ impl<'ast> VariableWriteRemover { e.get_type() ), }, - FieldElementExpression::select(base.clone(), i), + FieldElementExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span), ConditionalKind::IfElse, ) .into(), Type::Boolean => BooleanExpression::conditional( - BooleanExpression::UintEq(EqExpression::new( - i.into(), + BooleanExpression::uint_eq( + UExpression::from(i).span(span), head.clone(), - )), + ) + .span(span), match Self::choose_many( - BooleanExpression::select(base.clone(), i).into(), + BooleanExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span) + .into(), tail.clone(), new_expression.clone(), statements, + span, ) { TypedExpression::Boolean(e) => e, e => unreachable!( @@ -166,20 +223,31 @@ impl<'ast> VariableWriteRemover { e.get_type() ), }, - BooleanExpression::select(base.clone(), i), + BooleanExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span), ConditionalKind::IfElse, ) .into(), Type::Uint(..) => UExpression::conditional( - BooleanExpression::UintEq(EqExpression::new( - i.into(), + BooleanExpression::uint_eq( + UExpression::from(i).span(span), head.clone(), - )), + ) + .span(span), match Self::choose_many( - UExpression::select(base.clone(), i).into(), + UExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span) + .into(), tail.clone(), new_expression.clone(), statements, + span, ) { TypedExpression::Uint(e) => e, e => unreachable!( @@ -187,13 +255,16 @@ impl<'ast> VariableWriteRemover { e.get_type() ), }, - UExpression::select(base.clone(), i), + UExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span), ConditionalKind::IfElse, ) .into(), }) - .collect::>() - .into(), + .collect::>(), ) .annotate(inner_ty.clone(), size) .into() @@ -228,6 +299,7 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { FieldElementExpression::member(base.clone(), member.id) @@ -242,9 +314,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - UExpression::member(base.clone(), member.id).into() + UExpression::member(base.clone(), member.id) + .span(span) + .into() } } Type::Boolean => { @@ -258,6 +333,7 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { BooleanExpression::member(base.clone(), member.id) @@ -272,9 +348,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - ArrayExpression::member(base.clone(), member.id).into() + ArrayExpression::member(base.clone(), member.id) + .span(span) + .into() } } Type::Struct(..) => { @@ -288,9 +367,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - StructExpression::member(base.clone(), member.id).into() + StructExpression::member(base.clone(), member.id) + .span(span) + .into() } } Type::Tuple(..) => { @@ -301,9 +383,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - TupleExpression::member(base.clone(), member.id).into() + TupleExpression::member(base.clone(), member.id) + .span(span) + .into() } } }) @@ -341,9 +426,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - FieldElementExpression::element(base.clone(), i).into() + FieldElementExpression::element(base.clone(), i) + .span(span) + .into() } } Type::Uint(..) => { @@ -353,9 +441,10 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - UExpression::element(base.clone(), i).into() + UExpression::element(base.clone(), i).span(span).into() } } Type::Boolean => { @@ -366,9 +455,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - BooleanExpression::element(base.clone(), i).into() + BooleanExpression::element(base.clone(), i) + .span(span) + .into() } } Type::Array(..) => { @@ -378,9 +470,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - ArrayExpression::element(base.clone(), i).into() + ArrayExpression::element(base.clone(), i) + .span(span) + .into() } } Type::Struct(..) => { @@ -391,9 +486,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - StructExpression::element(base.clone(), i).into() + StructExpression::element(base.clone(), i) + .span(span) + .into() } } Type::Tuple(..) => { @@ -403,9 +501,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - TupleExpression::element(base.clone(), i).into() + TupleExpression::element(base.clone(), i) + .span(span) + .into() } } }) @@ -466,39 +567,46 @@ fn is_constant(assignee: &TypedAssignee) -> bool { impl<'ast, T: Field> ResultFolder<'ast, T> for VariableWriteRemover { type Error = Error; - fn fold_assembly_statement( + fn fold_assembly_assignment( &mut self, - s: TypedAssemblyStatement<'ast, T>, + s: AssemblyAssignment<'ast, T>, ) -> Result>, Self::Error> { - match s { - TypedAssemblyStatement::Assignment(a, e) if is_constant(&a) => { - Ok(vec![TypedAssemblyStatement::Assignment(a, e)]) - } - TypedAssemblyStatement::Assignment(a, _) => Err(Error(format!( + match is_constant(&s.assignee) { + true => Ok(vec![TypedAssemblyStatement::Assignment(s)]), + false => Err(Error(format!( "Cannot assign to an assignee with a variable index `{}`", - a + s.assignee ))), - s => Ok(vec![s]), } } - fn fold_statement( + fn fold_assembly_constraint( + &mut self, + s: AssemblyConstraint<'ast, T>, + ) -> Result>, Self::Error> { + Ok(vec![TypedAssemblyStatement::Constraint(s)]) + } + + fn fold_definition_statement( &mut self, - s: TypedStatement<'ast, T>, + s: DefinitionStatement<'ast, T>, ) -> Result>, Self::Error> { - match s { - TypedStatement::Definition(assignee, DefinitionRhs::Expression(expr)) => { + let span = s.get_span(); + + match s.rhs { + DefinitionRhs::Expression(expr) => { + let a = s.assignee; let expr = self.fold_expression(expr)?; - if is_constant(&assignee) { - Ok(vec![TypedStatement::definition(assignee, expr)]) + if is_constant(&a) { + Ok(vec![TypedStatement::definition(a, expr).span(span)]) } else { // Note: here we redefine the whole object, ideally we would only redefine some of it // Example: `a[0][i] = 42` we redefine `a` but we could redefine just `a[0]` - let (variable, indices) = linear(assignee); + let (variable, indices) = linear(a); - let base = match variable.get_type() { + let base: TypedExpression<'ast, T> = match variable.get_type() { Type::Int => unreachable!(), Type::FieldElement => { FieldElementExpression::identifier(variable.id.clone()).into() @@ -518,6 +626,8 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for VariableWriteRemover { .into(), }; + let base = base.span(span); + let base = self.fold_expression(base)?; let indices = indices @@ -531,18 +641,21 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for VariableWriteRemover { .collect::>()?; let mut range_checks = HashSet::new(); - let e = Self::choose_many(base, indices, expr, &mut range_checks); + let e = Self::choose_many(base, indices, expr, &mut range_checks, span); Ok(range_checks .into_iter() - .chain(std::iter::once(TypedStatement::definition( - TypedAssignee::Identifier(variable), - e, - ))) + .chain(std::iter::once( + TypedStatement::definition( + TypedAssignee::Identifier(variable.span(span)), + e, + ) + .span(span), + )) .collect()) } } - s => fold_statement(self, s), + _ => fold_definition_statement(self, s), } } } diff --git a/zokrates_analysis/src/zir_propagation.rs b/zokrates_analysis/src/zir_propagation.rs index 8fdaaf73c..0d3fb1c73 100644 --- a/zokrates_analysis/src/zir_propagation.rs +++ b/zokrates_analysis/src/zir_propagation.rs @@ -2,8 +2,11 @@ use num::traits::Pow; use num_bigint::BigUint; use std::collections::HashMap; use std::fmt; -use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr, Sub}; +use std::ops::*; +use zokrates_ast::common::{ResultFold, WithSpan}; use zokrates_ast::zir::types::UBitwidth; +use zokrates_ast::zir::AssertionStatement; +use zokrates_ast::zir::IfElseStatement; use zokrates_ast::zir::{ result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Constant, Expr, Id, IdentifierExpression, IdentifierOrExpression, SelectExpression, SelectOrExpression, @@ -79,158 +82,158 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { }) } - fn fold_assembly_statement( + fn fold_assembly_assignment( &mut self, - s: ZirAssemblyStatement<'ast, T>, + s: zokrates_ast::zir::AssemblyAssignment<'ast, T>, ) -> Result>, Self::Error> { - match s { - ZirAssemblyStatement::Assignment(assignees, function) => { - let assignees: Vec<_> = assignees - .into_iter() - .map(|a| self.fold_assignee(a)) - .collect::>()?; - - let function = self.fold_function(function)?; - - match &function.statements.last().unwrap() { - ZirStatement::Return(values) => { - if values.iter().all(|v| v.is_constant()) { - self.constants.extend( - assignees - .into_iter() - .zip(values.iter()) - .map(|(a, v)| (a.id, v.clone())), - ); - Ok(vec![]) - } else { - assignees.iter().for_each(|a| { - self.constants.remove(&a.id); - }); - Ok(vec![ZirAssemblyStatement::Assignment(assignees, function)]) - } - } - _ => { - assignees.iter().for_each(|a| { - self.constants.remove(&a.id); - }); - Ok(vec![ZirAssemblyStatement::Assignment(assignees, function)]) - } + let assignees: Vec<_> = s + .assignee + .into_iter() + .map(|a| self.fold_assignee(a)) + .collect::>()?; + + let function = self.fold_function(s.expression)?; + + match &function.statements.last().unwrap() { + ZirStatement::Return(s) => { + if s.inner.iter().all(|v| v.is_constant()) { + self.constants.extend( + assignees + .into_iter() + .zip(s.inner.iter()) + .map(|(a, v)| (a.id, v.clone())), + ); + Ok(vec![]) + } else { + assignees.iter().for_each(|a| { + self.constants.remove(&a.id); + }); + Ok(vec![ZirAssemblyStatement::assignment(assignees, function)]) } } - ZirAssemblyStatement::Constraint(left, right, metadata) => { - let left = self.fold_field_expression(left)?; - let right = self.fold_field_expression(right)?; - - // a bit hacky, but we use a fake boolean expression to check this - let is_equal = BooleanExpression::FieldEq(box left.clone(), box right.clone()); - let is_equal = self.fold_boolean_expression(is_equal)?; - - match is_equal { - BooleanExpression::Value(true) => Ok(vec![]), - BooleanExpression::Value(false) => { - Err(Error::AssertionFailed(RuntimeError::SourceAssertion( - metadata - .message(Some(format!("In asm block: `{} !== {}`", left, right))), - ))) - } - _ => Ok(vec![ZirAssemblyStatement::Constraint( - left, right, metadata, - )]), - } + _ => { + assignees.iter().for_each(|a| { + self.constants.remove(&a.id); + }); + Ok(vec![ZirAssemblyStatement::assignment(assignees, function)]) } } } - fn fold_statement( + fn fold_assembly_constraint( &mut self, - s: ZirStatement<'ast, T>, + s: zokrates_ast::zir::AssemblyConstraint<'ast, T>, + ) -> Result>, Self::Error> { + let left = self.fold_field_expression(s.left)?; + let right = self.fold_field_expression(s.right)?; + + // a bit hacky, but we use a fake boolean expression to check this + let is_equal = BooleanExpression::field_eq(left.clone(), right.clone()); + let is_equal = self.fold_boolean_expression(is_equal)?; + + match is_equal { + BooleanExpression::Value(v) if v.value => Ok(vec![]), + BooleanExpression::Value(v) if !v.value => { + Err(Error::AssertionFailed(RuntimeError::SourceAssertion( + s.metadata + .message(Some(format!("In asm block: `{} !== {}`", left, right))), + ))) + } + _ => Ok(vec![ZirAssemblyStatement::constraint( + left, right, s.metadata, + )]), + } + } + + fn fold_assertion_statement( + &mut self, + s: AssertionStatement<'ast, T>, ) -> Result>, Self::Error> { - match s { - ZirStatement::Assertion(e, error) => match self.fold_boolean_expression(e)? { - BooleanExpression::Value(true) => Ok(vec![]), - BooleanExpression::Value(false) => Err(Error::AssertionFailed(error)), - e => Ok(vec![ZirStatement::Assertion(e, error)]), - }, - ZirStatement::Definition(a, e) => { - let e = self.fold_expression(e)?; - match e { - ZirExpression::FieldElement(FieldElementExpression::Number(..)) - | ZirExpression::Boolean(BooleanExpression::Value(..)) - | ZirExpression::Uint(UExpression { - inner: UExpressionInner::Value(..), - .. - }) => { - self.constants.insert(a.id, e); - Ok(vec![]) - } - _ => { - self.constants.remove(&a.id); - Ok(vec![ZirStatement::Definition(a, e)]) - } - } + match self.fold_boolean_expression(s.expression)? { + BooleanExpression::Value(v) if v.value => Ok(vec![]), + BooleanExpression::Value(v) if !v.value => Err(Error::AssertionFailed(s.error)), + e => Ok(vec![ZirStatement::assertion(e, s.error)]), + } + } + + fn fold_definition_statement( + &mut self, + s: zokrates_ast::zir::DefinitionStatement<'ast, T>, + ) -> Result>, Self::Error> { + let e = self.fold_expression(s.rhs)?; + match e { + ZirExpression::FieldElement(FieldElementExpression::Value(..)) + | ZirExpression::Boolean(BooleanExpression::Value(..)) + | ZirExpression::Uint(UExpression { + inner: UExpressionInner::Value(..), + .. + }) => { + self.constants.insert(s.assignee.id, e); + Ok(vec![]) } - ZirStatement::IfElse(e, consequence, alternative) => { - match self.fold_boolean_expression(e)? { - BooleanExpression::Value(true) => Ok(consequence + _ => { + self.constants.remove(&s.assignee.id); + Ok(vec![ZirStatement::definition(s.assignee, e)]) + } + } + } + + fn fold_if_else_statement( + &mut self, + s: zokrates_ast::zir::IfElseStatement<'ast, T>, + ) -> Result>, Self::Error> { + { + match self.fold_boolean_expression(s.condition)? { + BooleanExpression::Value(v) if v.value => Ok(s + .consequence + .into_iter() + .map(|s| self.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect()), + BooleanExpression::Value(v) if !v.value => Ok(s + .alternative + .into_iter() + .map(|s| self.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect()), + e => Ok(vec![ZirStatement::IfElse(IfElseStatement::new( + e, + s.consequence .into_iter() .map(|s| self.fold_statement(s)) .collect::, _>>()? .into_iter() .flatten() - .collect()), - BooleanExpression::Value(false) => Ok(alternative + .collect(), + s.alternative .into_iter() .map(|s| self.fold_statement(s)) .collect::, _>>()? .into_iter() .flatten() - .collect()), - e => Ok(vec![ZirStatement::IfElse( - e, - consequence - .into_iter() - .map(|s| self.fold_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(), - alternative - .into_iter() - .map(|s| self.fold_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(), - )]), - } - } - ZirStatement::MultipleDefinition(assignees, list) => { - for a in &assignees { - self.constants.remove(&a.id); - } - Ok(vec![ZirStatement::MultipleDefinition( - assignees, - self.fold_expression_list(list)?, - )]) + .collect(), + ))]), } - ZirStatement::Assembly(statements) => { - let statements: Vec<_> = statements - .into_iter() - .map(|s| self.fold_assembly_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(); - match statements.len() { - 0 => Ok(vec![]), - _ => Ok(vec![ZirStatement::Assembly(statements)]), - } - } - _ => fold_statement(self, s), } } - fn fold_identifier_expression + Id<'ast, T> + ResultFold<'ast, T>>( + fn fold_multiple_definition_statement( + &mut self, + s: zokrates_ast::zir::MultipleDefinitionStatement<'ast, T>, + ) -> Result>, Self::Error> { + for a in &s.assignees { + self.constants.remove(&a.id); + } + fold_multiple_definition_statement(self, s) + } + + fn fold_identifier_expression< + E: Expr<'ast, T> + Id<'ast, T> + ResultFold, + >( &mut self, _: &E::Ty, id: IdentifierExpression<'ast, E>, @@ -241,169 +244,175 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { } } - fn fold_field_expression( + fn fold_field_expression_cases( &mut self, e: FieldElementExpression<'ast, T>, ) -> Result, Self::Error> { match e { - FieldElementExpression::Number(n) => Ok(FieldElementExpression::Number(n)), - FieldElementExpression::Add(box e1, box e2) => { - match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, - ) { - (FieldElementExpression::Number(n), e) - | (e, FieldElementExpression::Number(n)) - if n == T::from(0) => + FieldElementExpression::Value(n) => Ok(FieldElementExpression::Value(n)), + FieldElementExpression::Add(e) => { + let left = self.fold_field_expression(*e.left)?; + let right = self.fold_field_expression(*e.right)?; + + Ok(match (left, right) { + (FieldElementExpression::Value(n), e) + | (e, FieldElementExpression::Value(n)) + if n.value == T::from(0) => { - Ok(e) + e } - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number(n1 + n2)) + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + FieldElementExpression::number(n1.value + n2.value) } - (e1, e2) => Ok(FieldElementExpression::Add(box e1, box e2)), + (e1, e2) => FieldElementExpression::add(e1, e2), } + .span(e.span)) } - FieldElementExpression::Sub(box e1, box e2) => { - match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, - ) { - (e, FieldElementExpression::Number(n)) if n == T::from(0) => Ok(e), - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number(n1 - n2)) + FieldElementExpression::Sub(e) => { + let left = self.fold_field_expression(*e.left)?; + let right = self.fold_field_expression(*e.right)?; + + Ok(match (left, right) { + (e, FieldElementExpression::Value(n)) if n.value == T::from(0) => e, + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + FieldElementExpression::number(n1.value - n2.value) } - (e1, e2) => Ok(FieldElementExpression::Sub(box e1, box e2)), + (e1, e2) => FieldElementExpression::sub(e1, e2), } + .span(e.span)) } - FieldElementExpression::Mult(box e1, box e2) => { - match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, - ) { - (FieldElementExpression::Number(n), _) - | (_, FieldElementExpression::Number(n)) - if n == T::from(0) => + FieldElementExpression::Mult(e) => { + let left = self.fold_field_expression(*e.left)?; + let right = self.fold_field_expression(*e.right)?; + + Ok(match (left, right) { + (FieldElementExpression::Value(n), _) + | (_, FieldElementExpression::Value(n)) + if n.value == T::from(0) => { - Ok(FieldElementExpression::Number(T::from(0))) + FieldElementExpression::number(T::from(0)) } - (FieldElementExpression::Number(n), e) - | (e, FieldElementExpression::Number(n)) - if n == T::from(1) => + (FieldElementExpression::Value(n), e) + | (e, FieldElementExpression::Value(n)) + if n.value == T::from(1) => { - Ok(e) + e } - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number(n1 * n2)) + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + FieldElementExpression::number(n1.value * n2.value) } - (e1, e2) => Ok(FieldElementExpression::Mult(box e1, box e2)), + (e1, e2) => FieldElementExpression::mul(e1, e2), } + .span(e.span)) } - FieldElementExpression::Div(box e1, box e2) => { - match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, - ) { - (_, FieldElementExpression::Number(n)) if n == T::from(0) => { + FieldElementExpression::Div(e) => { + let left = self.fold_field_expression(*e.left)?; + let right = self.fold_field_expression(*e.right)?; + + match (left, right) { + (_, FieldElementExpression::Value(n)) if n.value == T::from(0) => { Err(Error::DivisionByZero) } - (e, FieldElementExpression::Number(n)) if n == T::from(1) => Ok(e), - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number(n1 / n2)) + (e, FieldElementExpression::Value(n)) if n.value == T::from(1) => Ok(e), + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(FieldElementExpression::number(n1.value / n2.value).span(e.span)) } - (e1, e2) => Ok(FieldElementExpression::Div(box e1, box e2)), + (e1, e2) => Ok(FieldElementExpression::div(e1, e2).span(e.span)), } } - FieldElementExpression::Pow(box e, box exponent) => { - let exponent = self.fold_uint_expression(exponent)?; - match (self.fold_field_expression(e)?, exponent.into_inner()) { - (_, UExpressionInner::Value(n2)) if n2 == 0 => { - Ok(FieldElementExpression::Number(T::from(1))) + FieldElementExpression::Pow(e) => { + let exponent = self.fold_uint_expression(*e.right)?; + match (self.fold_field_expression(*e.left)?, exponent.into_inner()) { + (_, UExpressionInner::Value(n2)) if n2.value == 0 => { + Ok(FieldElementExpression::number(T::from(1))) } - (e, UExpressionInner::Value(n2)) if n2 == 1 => Ok(e), - (FieldElementExpression::Number(n), UExpressionInner::Value(e)) => { - Ok(FieldElementExpression::Number(n.pow(e as usize))) + (e, UExpressionInner::Value(n2)) if n2.value == 1 => Ok(e), + (FieldElementExpression::Value(n), UExpressionInner::Value(e)) => Ok( + FieldElementExpression::number(n.value.pow(e.value as usize)), + ), + (e, exp) => { + Ok(FieldElementExpression::pow(e, exp.annotate(UBitwidth::B32)) + .into_inner()) } - (e, exp) => Ok(FieldElementExpression::Pow( - box e, - box exp.annotate(UBitwidth::B32), - )), } } - FieldElementExpression::Xor(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; + FieldElementExpression::Xor(e) => { + let e1 = self.fold_field_expression(*e.right)?; + let e2 = self.fold_field_expression(*e.left)?; match (e1, e2) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number( - T::try_from(n1.to_biguint().bitxor(n2.to_biguint())).unwrap(), + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(FieldElementExpression::value( + T::try_from(n1.value.to_biguint().bitxor(n2.value.to_biguint())) + .unwrap(), )) } - (e1, e2) if e1.eq(&e2) => Ok(FieldElementExpression::Number(T::from(0))), - (e1, e2) => Ok(FieldElementExpression::Xor(box e1, box e2)), + (e1, e2) if e1.eq(&e2) => Ok(FieldElementExpression::value(T::from(0))), + (e1, e2) => Ok(FieldElementExpression::bitxor(e1, e2)), } } - FieldElementExpression::And(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; + FieldElementExpression::And(e) => { + let e1 = self.fold_field_expression(*e.left)?; + let e2 = self.fold_field_expression(*e.right)?; match (e1, e2) { - (_, FieldElementExpression::Number(n)) - | (FieldElementExpression::Number(n), _) - if n == T::from(0) => + (_, FieldElementExpression::Value(n)) + | (FieldElementExpression::Value(n), _) + if n.value == T::from(0) => { - Ok(FieldElementExpression::Number(n)) + Ok(FieldElementExpression::Value(n)) } - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number( - T::try_from(n1.to_biguint().bitand(n2.to_biguint())).unwrap(), + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(FieldElementExpression::value( + T::try_from(n1.value.to_biguint().bitand(n2.value.to_biguint())) + .unwrap(), )) } - (e1, e2) => Ok(FieldElementExpression::And(box e1, box e2)), + (e1, e2) => Ok(FieldElementExpression::bitand(e1, e2)), } } - FieldElementExpression::Or(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; + FieldElementExpression::Or(e) => { + let e1 = self.fold_field_expression(*e.left)?; + let e2 = self.fold_field_expression(*e.right)?; match (e1, e2) { - (e, FieldElementExpression::Number(n)) - | (FieldElementExpression::Number(n), e) - if n == T::from(0) => + (e, FieldElementExpression::Value(n)) + | (FieldElementExpression::Value(n), e) + if n.value == T::from(0) => { Ok(e) } - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number( - T::try_from(n1.to_biguint().bitor(n2.to_biguint())).unwrap(), + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(FieldElementExpression::value( + T::try_from(n1.value.to_biguint().bitor(n2.value.to_biguint())) + .unwrap(), )) } - (e1, e2) => Ok(FieldElementExpression::Or(box e1, box e2)), + (e1, e2) => Ok(FieldElementExpression::bitor(e1, e2)), } } - FieldElementExpression::LeftShift(box e, box by) => { - let e = self.fold_field_expression(e)?; - let by = self.fold_uint_expression(by)?; - match (e, by) { + FieldElementExpression::LeftShift(e) => { + let expr = self.fold_field_expression(*e.left)?; + let by = self.fold_uint_expression(*e.right)?; + match (expr, by) { ( e, UExpression { inner: UExpressionInner::Value(by), .. }, - ) if by == 0 => Ok(e), + ) if by.value == 0 => Ok(e), ( _, UExpression { inner: UExpressionInner::Value(by), .. }, - ) if by as usize >= T::get_required_bits() => { - Ok(FieldElementExpression::Number(T::from(0))) + ) if by.value as usize >= T::get_required_bits() => { + Ok(FieldElementExpression::value(T::from(0))) } ( - FieldElementExpression::Number(n), + FieldElementExpression::Value(n), UExpression { inner: UExpressionInner::Value(by), .. @@ -412,197 +421,206 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { let two = BigUint::from(2usize); let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize); - Ok(FieldElementExpression::Number( - T::try_from(n.to_biguint().shl(by as usize).bitand(mask)).unwrap(), + Ok(FieldElementExpression::value( + T::try_from(n.value.to_biguint().shl(by.value as usize).bitand(mask)) + .unwrap(), )) } - (e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)), + (expr, by) => Ok(FieldElementExpression::left_shift(expr, by)), } } - FieldElementExpression::RightShift(box e, box by) => { - let e = self.fold_field_expression(e)?; - let by = self.fold_uint_expression(by)?; - match (e, by) { + FieldElementExpression::RightShift(e) => { + let expr = self.fold_field_expression(*e.left)?; + let by = self.fold_uint_expression(*e.right)?; + match (expr, by) { ( e, UExpression { inner: UExpressionInner::Value(by), .. }, - ) if by == 0 => Ok(e), + ) if by.value == 0 => Ok(e), ( _, UExpression { inner: UExpressionInner::Value(by), .. }, - ) if by as usize >= T::get_required_bits() => { - Ok(FieldElementExpression::Number(T::from(0))) + ) if by.value as usize >= T::get_required_bits() => { + Ok(FieldElementExpression::value(T::from(0))) } ( - FieldElementExpression::Number(n), + FieldElementExpression::Value(n), UExpression { inner: UExpressionInner::Value(by), .. }, - ) => Ok(FieldElementExpression::Number( - T::try_from(n.to_biguint().shr(by as usize)).unwrap(), + ) => Ok(FieldElementExpression::value( + T::try_from(n.value.to_biguint().shr(by.value as usize)).unwrap(), )), - (e, by) => Ok(FieldElementExpression::RightShift(box e, box by)), + (expr, by) => Ok(FieldElementExpression::right_shift(expr, by)), } } - e => fold_field_expression(self, e), + e => fold_field_expression_cases(self, e), } } - fn fold_boolean_expression( + fn fold_boolean_expression_cases( &mut self, e: BooleanExpression<'ast, T>, ) -> Result, Error> { match e { BooleanExpression::Value(v) => Ok(BooleanExpression::Value(v)), - BooleanExpression::FieldLt(box e1, box e2) => { + BooleanExpression::FieldLt(e) => { match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, + self.fold_field_expression(*e.left)?, + self.fold_field_expression(*e.right)?, ) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(BooleanExpression::Value(n1 < n2)) + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(BooleanExpression::value(n1 < n2)) } - (_, FieldElementExpression::Number(c)) if c == T::zero() => { - Ok(BooleanExpression::Value(false)) + (_, FieldElementExpression::Value(c)) if c.value == T::zero() => { + Ok(BooleanExpression::value(false)) } - (FieldElementExpression::Number(c), _) if c == T::max_value() => { - Ok(BooleanExpression::Value(false)) + (FieldElementExpression::Value(c), _) if c.value == T::max_value() => { + Ok(BooleanExpression::value(false)) } - (e1, e2) => Ok(BooleanExpression::FieldLt(box e1, box e2)), + (e1, e2) => Ok(BooleanExpression::field_lt(e1, e2)), } } - BooleanExpression::FieldLe(box e1, box e2) => { + BooleanExpression::FieldLe(e) => { match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, + self.fold_field_expression(*e.left)?, + self.fold_field_expression(*e.right)?, ) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(BooleanExpression::Value(n1 <= n2)) + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(BooleanExpression::value(n1 <= n2)) } - (e1, e2) => Ok(BooleanExpression::FieldLe(box e1, box e2)), + (e1, e2) => Ok(BooleanExpression::field_le(e1, e2)), } } - BooleanExpression::FieldEq(box e1, box e2) => { + BooleanExpression::FieldEq(e) => { match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, + self.fold_field_expression(*e.left)?, + self.fold_field_expression(*e.right)?, ) { - (FieldElementExpression::Number(v1), FieldElementExpression::Number(v2)) => { - Ok(BooleanExpression::Value(v1.eq(&v2))) + (FieldElementExpression::Value(v1), FieldElementExpression::Value(v2)) => { + Ok(BooleanExpression::value(v1.eq(&v2))) } (e1, e2) => { if e1.eq(&e2) { - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) } else { - Ok(BooleanExpression::FieldEq(box e1, box e2)) + Ok(BooleanExpression::field_eq(e1, e2)) } } } } - BooleanExpression::UintLt(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + BooleanExpression::UintLt(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(BooleanExpression::Value(v1 < v2)) + Ok(BooleanExpression::value(v1 < v2)) } - _ => Ok(BooleanExpression::UintLt(box e1, box e2)), + _ => Ok(BooleanExpression::uint_lt(e1, e2)), } } - BooleanExpression::UintLe(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + BooleanExpression::UintLe(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(BooleanExpression::Value(v1 <= v2)) + Ok(BooleanExpression::value(v1 <= v2)) } - _ => Ok(BooleanExpression::UintLe(box e1, box e2)), + _ => Ok(BooleanExpression::uint_le(e1, e2)), } } - BooleanExpression::UintEq(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + BooleanExpression::UintEq(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(BooleanExpression::Value(v1 == v2)) + Ok(BooleanExpression::value(v1 == v2)) } _ => { if e1.eq(&e2) { - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) } else { - Ok(BooleanExpression::UintEq(box e1, box e2)) + Ok(BooleanExpression::uint_eq(e1, e2)) } } } } - BooleanExpression::BoolEq(box e1, box e2) => { + BooleanExpression::BoolEq(e) => { match ( - self.fold_boolean_expression(e1)?, - self.fold_boolean_expression(e2)?, + self.fold_boolean_expression(*e.left)?, + self.fold_boolean_expression(*e.right)?, ) { (BooleanExpression::Value(v1), BooleanExpression::Value(v2)) => { - Ok(BooleanExpression::Value(v1 == v2)) + Ok(BooleanExpression::value(v1 == v2)) } (e1, e2) => { if e1.eq(&e2) { - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) } else { - Ok(BooleanExpression::BoolEq(box e1, box e2)) + Ok(BooleanExpression::bool_eq(e1, e2)) } } } } - BooleanExpression::Or(box e1, box e2) => { + BooleanExpression::Or(e) => { match ( - self.fold_boolean_expression(e1)?, - self.fold_boolean_expression(e2)?, + self.fold_boolean_expression(*e.left)?, + self.fold_boolean_expression(*e.right)?, ) { (BooleanExpression::Value(v1), BooleanExpression::Value(v2)) => { - Ok(BooleanExpression::Value(v1 || v2)) + Ok(BooleanExpression::value(v1.value || v2.value)) } - (_, BooleanExpression::Value(true)) | (BooleanExpression::Value(true), _) => { - Ok(BooleanExpression::Value(true)) + (_, BooleanExpression::Value(v)) | (BooleanExpression::Value(v), _) + if v.value => + { + Ok(BooleanExpression::value(true)) } - (e, BooleanExpression::Value(false)) | (BooleanExpression::Value(false), e) => { + (e, BooleanExpression::Value(v)) | (BooleanExpression::Value(v), e) + if !v.value => + { Ok(e) } - (e1, e2) => Ok(BooleanExpression::Or(box e1, box e2)), + (e1, e2) => Ok(BooleanExpression::bitor(e1, e2)), } } - BooleanExpression::And(box e1, box e2) => { + BooleanExpression::And(e) => { match ( - self.fold_boolean_expression(e1)?, - self.fold_boolean_expression(e2)?, + self.fold_boolean_expression(*e.left)?, + self.fold_boolean_expression(*e.right)?, ) { - (BooleanExpression::Value(true), e) | (e, BooleanExpression::Value(true)) => { + (BooleanExpression::Value(v), e) | (e, BooleanExpression::Value(v)) + if v.value => + { Ok(e) } - (BooleanExpression::Value(false), _) | (_, BooleanExpression::Value(false)) => { - Ok(BooleanExpression::Value(false)) + (BooleanExpression::Value(v), _) | (_, BooleanExpression::Value(v)) + if !v.value => + { + Ok(BooleanExpression::value(false)) } - (e1, e2) => Ok(BooleanExpression::And(box e1, box e2)), + (e1, e2) => Ok(BooleanExpression::bitand(e1, e2)), } } - BooleanExpression::Not(box e) => match self.fold_boolean_expression(e)? { - BooleanExpression::Value(v) => Ok(BooleanExpression::Value(!v)), - e => Ok(BooleanExpression::Not(box e)), + BooleanExpression::Not(e) => match self.fold_boolean_expression(*e.inner)? { + BooleanExpression::Value(v) => Ok(BooleanExpression::value(!v.value)), + e => Ok(BooleanExpression::not(e)), }, - e => fold_boolean_expression(self, e), + e => fold_boolean_expression_cases(self, e), } } fn fold_select_expression< - E: Clone + Expr<'ast, T> + ResultFold<'ast, T> + zokrates_ast::zir::Select<'ast, T>, + E: Clone + Expr<'ast, T> + ResultFold + zokrates_ast::zir::Select<'ast, T>, >( &mut self, _: &E::Ty, @@ -617,9 +635,9 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { match index.as_inner() { UExpressionInner::Value(v) => array - .get(*v as usize) + .get(v.value as usize) .cloned() - .ok_or(Error::OutOfBounds(*v as usize, array.len())) + .ok_or(Error::OutOfBounds(v.value as usize, array.len())) .map(|e| SelectOrExpression::Expression(e.into_inner())), _ => Ok(SelectOrExpression::Expression( E::select(array, index).into_inner(), @@ -627,185 +645,221 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { } } - fn fold_uint_expression_inner( + fn fold_uint_expression_cases( &mut self, bitwidth: UBitwidth, e: UExpressionInner<'ast, T>, ) -> Result, Self::Error> { match e { UExpressionInner::Value(v) => Ok(UExpressionInner::Value(v)), - UExpressionInner::Add(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + UExpressionInner::Add(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.into_inner(), e2.into_inner()) { - (UExpressionInner::Value(0), e) | (e, UExpressionInner::Value(0)) => Ok(e), - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => Ok( - UExpressionInner::Value((n1 + n2) % 2_u128.pow(bitwidth.to_usize() as u32)), + (UExpressionInner::Value(v), e) | (e, UExpressionInner::Value(v)) + if v.value == 0 => + { + Ok(e) + } + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + Ok(UExpression::value( + (n1.value + n2.value) % 2_u128.pow(bitwidth.to_usize() as u32), + )) + } + (e1, e2) => Ok( + UExpression::add(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner(), ), - (e1, e2) => Ok(UExpressionInner::Add( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), } } - UExpressionInner::Sub(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + UExpressionInner::Sub(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.into_inner(), e2.into_inner()) { - (e, UExpressionInner::Value(0)) => Ok(e), + (e, UExpressionInner::Value(v)) if v.value == 0 => Ok(e), (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - Ok(UExpressionInner::Value( - n1.wrapping_sub(n2) % 2_u128.pow(bitwidth.to_usize() as u32), + Ok(UExpression::value( + n1.value.wrapping_sub(n2.value) + % 2_u128.pow(bitwidth.to_usize() as u32), )) } - (e1, e2) => Ok(UExpressionInner::Sub( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e1, e2) => Ok( + UExpression::sub(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner(), + ), } } - UExpressionInner::Mult(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + UExpressionInner::Mult(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.into_inner(), e2.into_inner()) { - (_, UExpressionInner::Value(0)) | (UExpressionInner::Value(0), _) => { - Ok(UExpressionInner::Value(0)) + (_, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), _) + if v.value == 0 => + { + Ok(UExpression::value(0)) + } + (e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) + if v.value == 1 => + { + Ok(e) } - (e, UExpressionInner::Value(1)) | (UExpressionInner::Value(1), e) => Ok(e), - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => Ok( - UExpressionInner::Value((n1 * n2) % 2_u128.pow(bitwidth.to_usize() as u32)), + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + Ok(UExpression::value( + (n1.value * n2.value) % 2_u128.pow(bitwidth.to_usize() as u32), + )) + } + (e1, e2) => Ok( + UExpression::mult(e1.annotate(bitwidth), e2.annotate(bitwidth)) + .into_inner(), ), - (e1, e2) => Ok(UExpressionInner::Mult( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), } } - UExpressionInner::Div(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + UExpressionInner::Div(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.into_inner(), e2.into_inner()) { - (_, UExpressionInner::Value(n)) if n == 0 => Err(Error::DivisionByZero), - (e, UExpressionInner::Value(n)) if n == 1 => Ok(e), - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => Ok( - UExpressionInner::Value((n1 / n2) % 2_u128.pow(bitwidth.to_usize() as u32)), + (_, UExpressionInner::Value(n)) if n.value == 0 => Err(Error::DivisionByZero), + (e, UExpressionInner::Value(n)) if n.value == 1 => Ok(e), + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + Ok(UExpression::value( + (n1.value / n2.value) % 2_u128.pow(bitwidth.to_usize() as u32), + )) + } + (e1, e2) => Ok( + UExpression::div(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner(), ), - (e1, e2) => Ok(UExpressionInner::Div( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), } } - UExpressionInner::Rem(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + UExpressionInner::Rem(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.into_inner(), e2.into_inner()) { - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => Ok( - UExpressionInner::Value((n1 % n2) % 2_u128.pow(bitwidth.to_usize() as u32)), + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + Ok(UExpression::value( + (n1.value % n2.value) % 2_u128.pow(bitwidth.to_usize() as u32), + )) + } + (e1, e2) => Ok( + UExpression::rem(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner(), ), - (e1, e2) => Ok(UExpressionInner::Rem( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), } } - UExpressionInner::Xor(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + UExpressionInner::Xor(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.into_inner(), e2.into_inner()) { (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - Ok(UExpressionInner::Value(n1 ^ n2)) + Ok(UExpression::value(n1.value ^ n2.value)) } - (e1, e2) if e1.eq(&e2) => Ok(UExpressionInner::Value(0)), - (e1, e2) => Ok(UExpressionInner::Xor( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e1, e2) if e1.eq(&e2) => Ok(UExpression::value(0)), + (e1, e2) => Ok( + UExpression::xor(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner(), + ), } } - UExpressionInner::And(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + UExpressionInner::And(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.into_inner(), e2.into_inner()) { (e, UExpressionInner::Value(n)) | (UExpressionInner::Value(n), e) - if n == 2_u128.pow(bitwidth.to_usize() as u32) - 1 => + if n.value == 2_u128.pow(bitwidth.to_usize() as u32) - 1 => { Ok(e) } - (_, UExpressionInner::Value(0)) | (UExpressionInner::Value(0), _) => { - Ok(UExpressionInner::Value(0)) + (_, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), _) + if v.value == 0 => + { + Ok(UExpression::value(0)) } (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - Ok(UExpressionInner::Value(n1 & n2)) + Ok(UExpression::value(n1.value & n2.value)) } - (e1, e2) => Ok(UExpressionInner::And( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e1, e2) => Ok( + UExpression::and(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner(), + ), } } - UExpressionInner::Or(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + UExpressionInner::Or(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.into_inner(), e2.into_inner()) { - (e, UExpressionInner::Value(0)) | (UExpressionInner::Value(0), e) => Ok(e), + (e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) + if v.value == 0 => + { + Ok(e) + } (_, UExpressionInner::Value(n)) | (UExpressionInner::Value(n), _) - if n == 2_u128.pow(bitwidth.to_usize() as u32) - 1 => + if n.value == 2_u128.pow(bitwidth.to_usize() as u32) - 1 => { - Ok(UExpressionInner::Value(n)) + Ok(UExpression::value(n.value)) } (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - Ok(UExpressionInner::Value(n1 | n2)) + Ok(UExpression::value(n1.value | n2.value)) } - (e1, e2) => Ok(UExpressionInner::Or( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e1, e2) => Ok( + UExpression::or(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner(), + ), } } - UExpressionInner::LeftShift(box e, by) => { - let e = self.fold_uint_expression(e)?; - match (e.into_inner(), by) { - (e, 0) => Ok(e), - (_, by) if by >= bitwidth as u32 => Ok(UExpressionInner::Value(0)), - (UExpressionInner::Value(n), by) => Ok(UExpressionInner::Value( - (n << by) & (2_u128.pow(bitwidth as u32) - 1), - )), - (e, by) => Ok(UExpressionInner::LeftShift(box e.annotate(bitwidth), by)), + UExpressionInner::LeftShift(e) => { + let left = self.fold_uint_expression(*e.left)?; + let right = self.fold_uint_expression(*e.right)?; + match (left.into_inner(), right.into_inner()) { + (e, UExpressionInner::Value(by)) if by.value == 0 => Ok(e), + (_, UExpressionInner::Value(by)) if by.value as u32 >= bitwidth as u32 => { + Ok(UExpression::value(0)) + } + (UExpressionInner::Value(n), UExpressionInner::Value(by)) => { + Ok(UExpression::value( + (n.value << by.value) & (2_u128.pow(bitwidth as u32) - 1), + )) + } + (e, by) => Ok(UExpression::left_shift( + e.annotate(bitwidth), + by.annotate(UBitwidth::B32), + ) + .into_inner()), } } - UExpressionInner::RightShift(box e, by) => { - let e = self.fold_uint_expression(e)?; - match (e.into_inner(), by) { - (e, 0) => Ok(e), - (_, by) if by >= bitwidth as u32 => Ok(UExpressionInner::Value(0)), - (UExpressionInner::Value(n), by) => Ok(UExpressionInner::Value(n >> by)), - (e, by) => Ok(UExpressionInner::RightShift(box e.annotate(bitwidth), by)), + UExpressionInner::RightShift(e) => { + let left = self.fold_uint_expression(*e.left)?; + let right = self.fold_uint_expression(*e.right)?; + match (left.into_inner(), right.into_inner()) { + (e, UExpressionInner::Value(by)) if by.value == 0 => Ok(e), + (_, UExpressionInner::Value(by)) if by.value as u32 >= bitwidth as u32 => { + Ok(UExpression::value(0)) + } + (UExpressionInner::Value(n), UExpressionInner::Value(by)) => { + Ok(UExpression::value(n.value >> by.value)) + } + (e, by) => Ok(UExpression::right_shift( + e.annotate(bitwidth), + by.annotate(UBitwidth::B32), + ) + .into_inner()), } } - UExpressionInner::Not(box e) => { - let e = self.fold_uint_expression(e)?; + UExpressionInner::Not(e) => { + let e = self.fold_uint_expression(*e.inner)?; match e.into_inner() { - UExpressionInner::Value(n) => Ok(UExpressionInner::Value( - !n & (2_u128.pow(bitwidth as u32) - 1), + UExpressionInner::Value(n) => Ok(UExpression::value( + !n.value & (2_u128.pow(bitwidth as u32) - 1), )), - e => Ok(UExpressionInner::Not(box e.annotate(bitwidth))), + e => Ok(UExpression::not(e.annotate(bitwidth)).into_inner()), } } - e => fold_uint_expression_inner(self, bitwidth, e), + e => fold_uint_expression_cases(self, bitwidth, e), } } fn fold_conditional_expression< - E: Expr<'ast, T> + ResultFold<'ast, T> + Conditional<'ast, T>, + E: Expr<'ast, T> + ResultFold + Conditional<'ast, T>, >( &mut self, _: &E::Ty, @@ -814,10 +868,10 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { let condition = self.fold_boolean_expression(*e.condition)?; match condition { - BooleanExpression::Value(true) => Ok(ConditionalOrExpression::Expression( + BooleanExpression::Value(v) if v.value => Ok(ConditionalOrExpression::Expression( e.consequence.fold(self)?.into_inner(), )), - BooleanExpression::Value(false) => Ok(ConditionalOrExpression::Expression( + BooleanExpression::Value(v) if !v.value => Ok(ConditionalOrExpression::Expression( e.alternative.fold(self)?.into_inner(), )), condition => { @@ -847,15 +901,15 @@ mod tests { #[test] fn propagation() { // assert([x, 1] == [y, 1]) - let statements = vec![ZirStatement::Assertion( - BooleanExpression::And( - box BooleanExpression::FieldEq( - box FieldElementExpression::identifier("x".into()), - box FieldElementExpression::identifier("y".into()), + let statements = vec![ZirStatement::assertion( + BooleanExpression::bitand( + BooleanExpression::field_eq( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::identifier("y".into()), ), - box BooleanExpression::FieldEq( - box FieldElementExpression::Number(Bn128Field::from(1)), - box FieldElementExpression::Number(Bn128Field::from(1)), + BooleanExpression::field_eq( + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(1)), ), ), RuntimeError::mock(), @@ -873,10 +927,10 @@ mod tests { assert_eq!( statements, - vec![ZirStatement::Assertion( - BooleanExpression::FieldEq( - box FieldElementExpression::identifier("x".into()), - box FieldElementExpression::identifier("y".into()), + vec![ZirStatement::assertion( + BooleanExpression::field_eq( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::identifier("y".into()), ), RuntimeError::mock() )] @@ -897,21 +951,21 @@ mod tests { assert_eq!( propagator.fold_field_expression(FieldElementExpression::select( vec![ - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), ], - UExpressionInner::Value(1).annotate(UBitwidth::B32), + UExpression::value(1).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(2))) + Ok(FieldElementExpression::value(Bn128Field::from(2))) ); assert_eq!( propagator.fold_field_expression(FieldElementExpression::select( vec![ - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), ], - UExpressionInner::Value(3).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), )), Err(Error::OutOfBounds(3, 2)) ); @@ -922,18 +976,18 @@ mod tests { let mut propagator = ZirPropagator::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Number(Bn128Field::from(3)), + propagator.fold_field_expression(FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), )), - Ok(FieldElementExpression::Number(Bn128Field::from(5))) + Ok(FieldElementExpression::value(Bn128Field::from(5))) ); // a + 0 = a assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Add( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(0)), + propagator.fold_field_expression(FieldElementExpression::add( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::value(Bn128Field::from(0)), )), Ok(FieldElementExpression::identifier("a".into())) ); @@ -944,18 +998,18 @@ mod tests { let mut propagator = ZirPropagator::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Sub( - box FieldElementExpression::Number(Bn128Field::from(3)), - box FieldElementExpression::Number(Bn128Field::from(2)), + propagator.fold_field_expression(FieldElementExpression::sub( + FieldElementExpression::value(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), )), - Ok(FieldElementExpression::Number(Bn128Field::from(1))) + Ok(FieldElementExpression::value(Bn128Field::from(1))) ); // a - 0 = a assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Sub( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(0)), + propagator.fold_field_expression(FieldElementExpression::sub( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::value(Bn128Field::from(0)), )), Ok(FieldElementExpression::identifier("a".into())) ); @@ -966,27 +1020,27 @@ mod tests { let mut propagator = ZirPropagator::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(3)), - box FieldElementExpression::Number(Bn128Field::from(2)), + propagator.fold_field_expression(FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), )), - Ok(FieldElementExpression::Number(Bn128Field::from(6))) + Ok(FieldElementExpression::value(Bn128Field::from(6))) ); // a * 0 = 0 assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Mult( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(0)), + propagator.fold_field_expression(FieldElementExpression::mul( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::value(Bn128Field::from(0)), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); // a * 1 = a assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Mult( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(1)), + propagator.fold_field_expression(FieldElementExpression::mul( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::value(Bn128Field::from(1)), )), Ok(FieldElementExpression::identifier("a".into())) ); @@ -997,25 +1051,25 @@ mod tests { let mut propagator = ZirPropagator::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Div( - box FieldElementExpression::Number(Bn128Field::from(6)), - box FieldElementExpression::Number(Bn128Field::from(2)), + propagator.fold_field_expression(FieldElementExpression::div( + FieldElementExpression::value(Bn128Field::from(6)), + FieldElementExpression::value(Bn128Field::from(2)), )), - Ok(FieldElementExpression::Number(Bn128Field::from(3))) + Ok(FieldElementExpression::value(Bn128Field::from(3))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Div( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(1)), + propagator.fold_field_expression(FieldElementExpression::div( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::value(Bn128Field::from(1)), )), Ok(FieldElementExpression::identifier("a".into())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Div( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(0)), + propagator.fold_field_expression(FieldElementExpression::div( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::value(Bn128Field::from(0)), )), Err(Error::DivisionByZero) ); @@ -1026,27 +1080,27 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Pow( - box FieldElementExpression::Number(Bn128Field::from(3)), - box UExpressionInner::Value(2).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::pow( + FieldElementExpression::value(Bn128Field::from(3)), + UExpression::value(2).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(9))) + Ok(FieldElementExpression::value(Bn128Field::from(9))) ); // a ** 0 = 1 assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Pow( - box FieldElementExpression::identifier("a".into()), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::pow( + FieldElementExpression::identifier("a".into()), + UExpression::value(0).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(1))) + Ok(FieldElementExpression::value(Bn128Field::from(1))) ); // a ** 1 = a assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Pow( - box FieldElementExpression::identifier("a".into()), - box UExpressionInner::Value(1).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::pow( + FieldElementExpression::identifier("a".into()), + UExpression::value(1).annotate(UBitwidth::B32), )), Ok(FieldElementExpression::identifier("a".into())) ); @@ -1057,44 +1111,44 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::identifier("a".into()), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::identifier("a".into()), + UExpression::value(0).annotate(UBitwidth::B32), )), Ok(FieldElementExpression::identifier("a".into())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::Number(Bn128Field::from(2)), - box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::value(Bn128Field::from(2)), + UExpression::value(2_u128).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(8))) + Ok(FieldElementExpression::value(Bn128Field::from(8))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::Number(Bn128Field::from(1)), - box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::value(Bn128Field::from(1)), + UExpression::value((Bn128Field::get_required_bits() - 1) as u128).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("14474011154664524427946373126085988481658748083205070504932198000989141204992").unwrap())) + Ok(FieldElementExpression::value(Bn128Field::try_from_dec_str("14474011154664524427946373126085988481658748083205070504932198000989141204992").unwrap())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::Number(Bn128Field::from(3)), - box UExpressionInner::Value((Bn128Field::get_required_bits() - 3) as u128).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::value(Bn128Field::from(3)), + UExpression::value((Bn128Field::get_required_bits() - 3) as u128).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("10855508365998393320959779844564491361244061062403802878699148500741855903744").unwrap())) + Ok(FieldElementExpression::value(Bn128Field::try_from_dec_str("10855508365998393320959779844564491361244061062403802878699148500741855903744").unwrap())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::Number(Bn128Field::from(1)), - box UExpressionInner::Value((Bn128Field::get_required_bits()) as u128) + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::value(Bn128Field::from(1)), + UExpression::value((Bn128Field::get_required_bits()) as u128) .annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); } @@ -1103,61 +1157,61 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::identifier("a".into()), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::identifier("a".into()), + UExpression::value(0).annotate(UBitwidth::B32), )), Ok(FieldElementExpression::identifier("a".into())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::identifier("a".into()), - box UExpressionInner::Value(Bn128Field::get_required_bits() as u128) + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::identifier("a".into()), + UExpression::value(Bn128Field::get_required_bits() as u128) .annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::from(3)), - box UExpressionInner::Value(1 as u128).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::from(3)), + UExpression::value(1_u128).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(1))) + Ok(FieldElementExpression::value(Bn128Field::from(1))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::from(2)), - box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::from(2)), + UExpression::value(2_u128).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::from(2)), - box UExpressionInner::Value(4 as u128).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::from(2)), + UExpression::value(4_u128).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::max_value()), - box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128) + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::max_value()), + UExpression::value((Bn128Field::get_required_bits() - 1) as u128) .annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(1))) + Ok(FieldElementExpression::value(Bn128Field::from(1))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::max_value()), - box UExpressionInner::Value(Bn128Field::get_required_bits() as u128) + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::max_value()), + UExpression::value(Bn128Field::get_required_bits() as u128) .annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); } @@ -1167,29 +1221,29 @@ mod tests { assert_eq!( propagator.fold_field_expression(FieldElementExpression::conditional( - BooleanExpression::Value(true), - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + BooleanExpression::value(true), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), )), - Ok(FieldElementExpression::Number(Bn128Field::from(1))) + Ok(FieldElementExpression::value(Bn128Field::from(1))) ); assert_eq!( propagator.fold_field_expression(FieldElementExpression::conditional( - BooleanExpression::Value(false), - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + BooleanExpression::value(false), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), )), - Ok(FieldElementExpression::Number(Bn128Field::from(2))) + Ok(FieldElementExpression::value(Bn128Field::from(2))) ); assert_eq!( propagator.fold_field_expression(FieldElementExpression::conditional( BooleanExpression::identifier("a".into()), - FieldElementExpression::Number(Bn128Field::from(2)), - FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(2)), )), - Ok(FieldElementExpression::Number(Bn128Field::from(2))) + Ok(FieldElementExpression::value(Bn128Field::from(2))) ); } } @@ -1208,21 +1262,21 @@ mod tests { assert_eq!( propagator.fold_boolean_expression(BooleanExpression::select( vec![ - BooleanExpression::Value(false), - BooleanExpression::Value(true), + BooleanExpression::value(false), + BooleanExpression::value(true) ], - UExpressionInner::Value(1).annotate(UBitwidth::B32), + UExpression::value(1).annotate(UBitwidth::B32), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::select( vec![ - BooleanExpression::Value(false), - BooleanExpression::Value(true), + BooleanExpression::value(false), + BooleanExpression::value(true) ], - UExpressionInner::Value(3).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), )), Err(Error::OutOfBounds(3, 2)) ); @@ -1233,35 +1287,35 @@ mod tests { let mut propagator = ZirPropagator::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::FieldLt( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Number(Bn128Field::from(3)), + propagator.fold_boolean_expression(BooleanExpression::field_lt( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::FieldLt( - box FieldElementExpression::Number(Bn128Field::from(3)), - box FieldElementExpression::Number(Bn128Field::from(3)), + propagator.fold_boolean_expression(BooleanExpression::field_lt( + FieldElementExpression::value(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(3)), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::FieldLt( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(0)), + propagator.fold_boolean_expression(BooleanExpression::field_lt( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::value(Bn128Field::from(0)), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::FieldLt( - box FieldElementExpression::Number(Bn128Field::max_value()), - box FieldElementExpression::identifier("a".into()), + propagator.fold_boolean_expression(BooleanExpression::field_lt( + FieldElementExpression::value(Bn128Field::max_value()), + FieldElementExpression::identifier("a".into()), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } @@ -1270,19 +1324,19 @@ mod tests { let mut propagator = ZirPropagator::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::FieldLe( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Number(Bn128Field::from(3)), + propagator.fold_boolean_expression(BooleanExpression::field_le( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::FieldLe( - box FieldElementExpression::Number(Bn128Field::from(3)), - box FieldElementExpression::Number(Bn128Field::from(3)), + propagator.fold_boolean_expression(BooleanExpression::field_le( + FieldElementExpression::value(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(3)), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); } @@ -1291,19 +1345,19 @@ mod tests { let mut propagator = ZirPropagator::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::FieldEq( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Number(Bn128Field::from(2)), + propagator.fold_boolean_expression(BooleanExpression::field_eq( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(2)), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::FieldEq( - box FieldElementExpression::Number(Bn128Field::from(3)), - box FieldElementExpression::Number(Bn128Field::from(2)), + propagator.fold_boolean_expression(BooleanExpression::field_eq( + FieldElementExpression::value(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } @@ -1312,19 +1366,19 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::UintLt( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + propagator.fold_boolean_expression(BooleanExpression::uint_lt( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::UintLt( - box UExpressionInner::Value(3).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + propagator.fold_boolean_expression(BooleanExpression::uint_lt( + UExpression::value(3).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } @@ -1333,19 +1387,19 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::UintLe( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + propagator.fold_boolean_expression(BooleanExpression::uint_le( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::UintLe( - box UExpressionInner::Value(3).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + propagator.fold_boolean_expression(BooleanExpression::uint_le( + UExpression::value(3).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); } @@ -1354,19 +1408,19 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::UintEq( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(2).annotate(UBitwidth::B32), + propagator.fold_boolean_expression(BooleanExpression::uint_eq( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::UintEq( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + propagator.fold_boolean_expression(BooleanExpression::uint_eq( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } @@ -1375,19 +1429,19 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::BoolEq( - box BooleanExpression::Value(true), - box BooleanExpression::Value(true), + propagator.fold_boolean_expression(BooleanExpression::bool_eq( + BooleanExpression::value(true), + BooleanExpression::value(true), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::BoolEq( - box BooleanExpression::Value(true), - box BooleanExpression::Value(false), + propagator.fold_boolean_expression(BooleanExpression::bool_eq( + BooleanExpression::value(true), + BooleanExpression::value(false), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } @@ -1396,35 +1450,35 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::And( - box BooleanExpression::Value(true), - box BooleanExpression::Value(true), + propagator.fold_boolean_expression(BooleanExpression::bitand( + BooleanExpression::value(true), + BooleanExpression::value(true), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::And( - box BooleanExpression::Value(true), - box BooleanExpression::Value(false), + propagator.fold_boolean_expression(BooleanExpression::bitand( + BooleanExpression::value(true), + BooleanExpression::value(false), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::And( - box BooleanExpression::identifier("a".into()), - box BooleanExpression::Value(true), + propagator.fold_boolean_expression(BooleanExpression::bitand( + BooleanExpression::identifier("a".into()), + BooleanExpression::value(true) )), Ok(BooleanExpression::identifier("a".into())) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::And( - box BooleanExpression::identifier("a".into()), - box BooleanExpression::Value(false), + propagator.fold_boolean_expression(BooleanExpression::bitand( + BooleanExpression::identifier("a".into()), + BooleanExpression::value(false), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } @@ -1433,19 +1487,19 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::Or( - box BooleanExpression::Value(true), - box BooleanExpression::Value(true), + propagator.fold_boolean_expression(BooleanExpression::bitor( + BooleanExpression::value(true), + BooleanExpression::value(true), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::Or( - box BooleanExpression::Value(true), - box BooleanExpression::Value(false), + propagator.fold_boolean_expression(BooleanExpression::bitor( + BooleanExpression::value(true), + BooleanExpression::value(false), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); } @@ -1454,17 +1508,17 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::Not( - box BooleanExpression::Value(true), + propagator.fold_boolean_expression(BooleanExpression::not( + BooleanExpression::value(true) )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::Not( - box BooleanExpression::Value(false), + propagator.fold_boolean_expression(BooleanExpression::not( + BooleanExpression::value(false), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); } @@ -1474,29 +1528,29 @@ mod tests { assert_eq!( propagator.fold_boolean_expression(BooleanExpression::conditional( - BooleanExpression::Value(true), - BooleanExpression::Value(true), - BooleanExpression::Value(false) + BooleanExpression::value(true), + BooleanExpression::value(true), + BooleanExpression::value(false) )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::conditional( - BooleanExpression::Value(false), - BooleanExpression::Value(true), - BooleanExpression::Value(false) + BooleanExpression::value(false), + BooleanExpression::value(true), + BooleanExpression::value(false), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::conditional( BooleanExpression::identifier("a".into()), - BooleanExpression::Value(true), - BooleanExpression::Value(true) + BooleanExpression::value(true), + BooleanExpression::value(true), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); } } @@ -1516,14 +1570,14 @@ mod tests { UBitwidth::B32, UExpression::select( vec![ - UExpressionInner::Value(1).annotate(UBitwidth::B32), - UExpressionInner::Value(2).annotate(UBitwidth::B32), + UExpression::value(1).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ], - UExpressionInner::Value(1).annotate(UBitwidth::B32), + UExpression::value(1).annotate(UBitwidth::B32), ) .into_inner() ), - Ok(UExpressionInner::Value(2)) + Ok(UExpression::value(2)) ); assert_eq!( @@ -1531,10 +1585,10 @@ mod tests { UBitwidth::B32, UExpression::select( vec![ - UExpressionInner::Value(1).annotate(UBitwidth::B32), - UExpressionInner::Value(2).annotate(UBitwidth::B32), + UExpression::value(1).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ], - UExpressionInner::Value(3).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), ) .into_inner() ), @@ -1549,22 +1603,24 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Add( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + UExpression::add( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(5)) + Ok(UExpression::value(5)) ); // a + 0 = a assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Add( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + UExpression::add( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(0).annotate(UBitwidth::B32), ) + .into_inner() ), Ok(UExpression::identifier("a".into())) ); @@ -1577,22 +1633,24 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Sub( - box UExpressionInner::Value(3).annotate(UBitwidth::B32), - box UExpressionInner::Value(2).annotate(UBitwidth::B32), + UExpression::sub( + UExpression::value(3).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(1)) + Ok(UExpression::value(1)) ); // a - 0 = a assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Sub( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + UExpression::sub( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(0).annotate(UBitwidth::B32), ) + .into_inner() ), Ok(UExpression::identifier("a".into())) ); @@ -1605,22 +1663,24 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Mult( - box UExpressionInner::Value(3).annotate(UBitwidth::B32), - box UExpressionInner::Value(2).annotate(UBitwidth::B32), + UExpression::mult( + UExpression::value(3).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(6)) + Ok(UExpression::value(6)) ); // a * 1 = a assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Mult( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(1).annotate(UBitwidth::B32), + UExpression::mult( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(1).annotate(UBitwidth::B32), ) + .into_inner() ), Ok(UExpression::identifier("a".into())) ); @@ -1629,12 +1689,13 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Mult( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + UExpression::mult( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(0).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(0)) + Ok(UExpression::value(0)) ); } @@ -1645,21 +1706,23 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Div( - box UExpressionInner::Value(6).annotate(UBitwidth::B32), - box UExpressionInner::Value(2).annotate(UBitwidth::B32), + UExpression::div( + UExpression::value(6).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(3)) + Ok(UExpression::value(3)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Div( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(1).annotate(UBitwidth::B32), + UExpression::div( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(1).annotate(UBitwidth::B32), ) + .into_inner() ), Ok(UExpression::identifier("a".into())) ); @@ -1667,10 +1730,11 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Div( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + UExpression::div( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(0).annotate(UBitwidth::B32), ) + .into_inner() ), Err(Error::DivisionByZero) ); @@ -1683,23 +1747,25 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Rem( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + UExpression::rem( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(2)) + Ok(UExpression::value(2)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Rem( - box UExpressionInner::Value(3).annotate(UBitwidth::B32), - box UExpressionInner::Value(2).annotate(UBitwidth::B32), + UExpression::rem( + UExpression::value(3).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(1)) + Ok(UExpression::value(1)) ); } @@ -1710,23 +1776,25 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Xor( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + UExpression::xor( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(1)) + Ok(UExpression::value(1)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Xor( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::xor( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::identifier("a".into()).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(0)) + Ok(UExpression::value(0)) ); } @@ -1737,32 +1805,35 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::And( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + UExpression::and( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(2)) + Ok(UExpression::value(2)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::And( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + UExpression::and( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(0).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(0)) + Ok(UExpression::value(0)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::And( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(u32::MAX as u128).annotate(UBitwidth::B32), + UExpression::and( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(u32::MAX as u128).annotate(UBitwidth::B32), ) + .into_inner() ), Ok(UExpression::identifier("a".into())) ); @@ -1775,21 +1846,23 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Or( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + UExpression::or( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(3)) + Ok(UExpression::value(3)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Or( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + UExpression::or( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(0).annotate(UBitwidth::B32), ) + .into_inner() ), Ok(UExpression::identifier("a".into())) ); @@ -1797,12 +1870,13 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Or( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(u32::MAX as u128).annotate(UBitwidth::B32), + UExpression::or( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(u32::MAX as u128).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(u32::MAX as u128)) + Ok(UExpression::value(u32::MAX as u128)) ); } @@ -1813,34 +1887,37 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::LeftShift( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - 3, + UExpression::left_shift( + UExpression::value(2).annotate(UBitwidth::B32), + 3.into(), ) + .into_inner() ), - Ok(UExpressionInner::Value(16)) + Ok(UExpression::value(16)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::LeftShift( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - 0, + UExpression::left_shift( + UExpression::value(2).annotate(UBitwidth::B32), + 0.into(), ) + .into_inner() ), - Ok(UExpressionInner::Value(2)) + Ok(UExpression::value(2)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::LeftShift( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - 32, + UExpression::left_shift( + UExpression::value(2).annotate(UBitwidth::B32), + 32.into(), ) + .into_inner() ), - Ok(UExpressionInner::Value(0)) + Ok(UExpression::value(0)) ); } @@ -1851,34 +1928,37 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::RightShift( - box UExpressionInner::Value(4).annotate(UBitwidth::B32), - 2, + UExpression::right_shift( + UExpression::value(4).annotate(UBitwidth::B32), + 2.into(), ) + .into_inner() ), - Ok(UExpressionInner::Value(1)) + Ok(UExpression::value(1)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::RightShift( - box UExpressionInner::Value(4).annotate(UBitwidth::B32), - 0, + UExpression::right_shift( + UExpression::value(4).annotate(UBitwidth::B32), + 0.into(), ) + .into_inner() ), - Ok(UExpressionInner::Value(4)) + Ok(UExpression::value(4)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::RightShift( - box UExpressionInner::Value(4).annotate(UBitwidth::B32), - 32, + UExpression::right_shift( + UExpression::value(4).annotate(UBitwidth::B32), + 32.into(), ) + .into_inner() ), - Ok(UExpressionInner::Value(0)) + Ok(UExpression::value(0)) ); } @@ -1889,9 +1969,9 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Not(box UExpressionInner::Value(2).annotate(UBitwidth::B32),) + UExpression::not(UExpression::value(2).annotate(UBitwidth::B32),).into_inner() ), - Ok(UExpressionInner::Value(4294967293)) + Ok(UExpression::value(4294967293)) ); } @@ -1903,26 +1983,26 @@ mod tests { propagator.fold_uint_expression_inner( UBitwidth::B32, UExpression::conditional( - BooleanExpression::Value(true), - UExpressionInner::Value(1).annotate(UBitwidth::B32), - UExpressionInner::Value(2).annotate(UBitwidth::B32), + BooleanExpression::value(true), + UExpression::value(1).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ) .into_inner() ), - Ok(UExpressionInner::Value(1)) + Ok(UExpression::value(1)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, UExpression::conditional( - BooleanExpression::Value(false), - UExpressionInner::Value(1).annotate(UBitwidth::B32), - UExpressionInner::Value(2).annotate(UBitwidth::B32), + BooleanExpression::value(false), + UExpression::value(1).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ) .into_inner() ), - Ok(UExpressionInner::Value(2)) + Ok(UExpression::value(2)) ); assert_eq!( @@ -1930,12 +2010,12 @@ mod tests { UBitwidth::B32, UExpression::conditional( BooleanExpression::identifier("a".into()), - UExpressionInner::Value(2).annotate(UBitwidth::B32), - UExpressionInner::Value(2).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ) .into_inner() ), - Ok(UExpressionInner::Value(2)) + Ok(UExpression::value(2)) ); } } diff --git a/zokrates_ark/src/gm17.rs b/zokrates_ark/src/gm17.rs index b756913a5..4148c1361 100644 --- a/zokrates_ark/src/gm17.rs +++ b/zokrates_ark/src/gm17.rs @@ -124,9 +124,14 @@ mod tests { #[test] fn verify_bls12_377_field() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, - statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + statements: vec![Statement::constraint( + Variable::new(0), + Variable::public(0), + None, + )], solvers: vec![], }; @@ -153,9 +158,14 @@ mod tests { #[test] fn verify_bw6_761_field() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, - statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + statements: vec![Statement::constraint( + Variable::new(0), + Variable::public(0), + None, + )], solvers: vec![], }; diff --git a/zokrates_ark/src/groth16.rs b/zokrates_ark/src/groth16.rs index 3c7b4b806..c49c8cab1 100644 --- a/zokrates_ark/src/groth16.rs +++ b/zokrates_ark/src/groth16.rs @@ -123,9 +123,14 @@ mod tests { #[test] fn verify_bls12_377_field() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, - statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + statements: vec![Statement::constraint( + Variable::new(0), + Variable::public(0), + None, + )], solvers: vec![], }; @@ -152,9 +157,14 @@ mod tests { #[test] fn verify_bw6_761_field() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, - statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + statements: vec![Statement::constraint( + Variable::new(0), + Variable::public(0), + None, + )], solvers: vec![], }; diff --git a/zokrates_ark/src/lib.rs b/zokrates_ark/src/lib.rs index e466dd1f6..dba49e9a9 100644 --- a/zokrates_ark/src/lib.rs +++ b/zokrates_ark/src/lib.rs @@ -8,7 +8,7 @@ use ark_relations::r1cs::{ SynthesisError, Variable as ArkVariable, }; use std::collections::BTreeMap; -use zokrates_ast::common::Variable; +use zokrates_ast::common::flat::Variable; use zokrates_ast::ir::{LinComb, ProgIterator, Statement, Witness}; use zokrates_field::{ArkFieldExtensions, Field}; @@ -44,7 +44,8 @@ fn ark_combination( symbols: &mut BTreeMap, witness: &mut Witness, ) -> LinearCombination<<::ArkEngine as PairingEngine>::Fr> { - l.0.into_iter() + l.value + .into_iter() .map(|(k, v)| { ( v.into_ark(), @@ -112,10 +113,10 @@ impl<'a, T: Field + ArkFieldExtensions, I: IntoIterator> })); for statement in self.program.statements { - if let Statement::Constraint(quad, lin, _) = statement { - let a = ark_combination(quad.left, &mut cs, &mut symbols, &mut witness); - let b = ark_combination(quad.right, &mut cs, &mut symbols, &mut witness); - let c = ark_combination(lin, &mut cs, &mut symbols, &mut witness); + if let Statement::Constraint(s) = statement { + let a = ark_combination(s.quad.left, &mut cs, &mut symbols, &mut witness); + let b = ark_combination(s.quad.right, &mut cs, &mut symbols, &mut witness); + let c = ark_combination(s.lin, &mut cs, &mut symbols, &mut witness); cs.enforce_constraint(a, b, c)?; } diff --git a/zokrates_ark/src/marlin.rs b/zokrates_ark/src/marlin.rs index 0750c2b84..5292fba94 100644 --- a/zokrates_ark/src/marlin.rs +++ b/zokrates_ark/src/marlin.rs @@ -398,17 +398,16 @@ mod tests { #[test] fn verify_bls12_377_field() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::private(Variable::new(0))], return_count: 1, statements: vec![ Statement::constraint( - QuadComb::from_linear_combinations( - Variable::new(0).into(), - Variable::new(0).into(), - ), + QuadComb::new(Variable::new(0).into(), Variable::new(0).into()), Variable::new(1), + None, ), - Statement::constraint(Variable::new(1), Variable::public(0)), + Statement::constraint(Variable::new(1), Variable::public(0), None), ], solvers: vec![], }; @@ -437,17 +436,16 @@ mod tests { #[test] fn verify_bw6_761_field() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::private(Variable::new(0))], return_count: 1, statements: vec![ Statement::constraint( - QuadComb::from_linear_combinations( - Variable::new(0).into(), - Variable::new(0).into(), - ), + QuadComb::new(Variable::new(0).into(), Variable::new(0).into()), Variable::new(1), + None, ), - Statement::constraint(Variable::new(1), Variable::public(0)), + Statement::constraint(Variable::new(1), Variable::public(0), None), ], solvers: vec![], }; diff --git a/zokrates_ast/Cargo.toml b/zokrates_ast/Cargo.toml index b135b06f5..39ba46946 100644 --- a/zokrates_ast/Cargo.toml +++ b/zokrates_ast/Cargo.toml @@ -20,4 +20,4 @@ serde_json = { version = "1.0", features = ["preserve_order"] } zokrates_embed = { version = "0.1.0", path = "../zokrates_embed", default-features = false } pairing_ce = { version = "^0.21", optional = true } ark-bls12-377 = { version = "^0.3.0", features = ["curve"], default-features = false, optional = true } -derivative = "2.2.0" \ No newline at end of file +derivative = "2.2.0" diff --git a/zokrates_ast/src/common/embed.rs b/zokrates_ast/src/common/embed.rs index 3eeed58ce..18e902716 100644 --- a/zokrates_ast/src/common/embed.rs +++ b/zokrates_ast/src/common/embed.rs @@ -1,4 +1,7 @@ -use crate::common::{Parameter, RuntimeError, Solver, Variable}; +use crate::common::{ + flat::{Parameter, Variable}, + RuntimeError, Solver, +}; use crate::flat::flat_expression_from_bits; use crate::flat::{FlatDirective, FlatExpression, FlatFunctionIterator, FlatStatement}; use crate::typed::types::{ @@ -11,8 +14,11 @@ use crate::untyped::{ }; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::ops::*; use zokrates_field::Field; +use super::ModuleMap; + cfg_if::cfg_if! { if #[cfg(feature = "bellman")] { use pairing_ce::bn256::Bn256; @@ -365,21 +371,18 @@ pub fn sha256_round<'ast, T: Field>( .clone() .into_iter() .chain(current_hash_argument_indices.clone()) - .map(|i| Parameter { - id: Variable::new(i), - private: true, - }) + .map(|i| Parameter::private(Variable::new(i))) .collect(); // define a binding of the first variable in the constraint system to one - let one_binding_statement = FlatStatement::Condition( + let one_binding_statement = FlatStatement::condition( Variable::new(0).into(), - FlatExpression::Number(T::from(1)), + FlatExpression::value(T::from(1)), RuntimeError::BellmanOneBinding, ); let input_binding_statements = // bind input and current_hash to inputs input_indices.chain(current_hash_indices).zip(input_argument_indices.clone().into_iter().chain(current_hash_argument_indices.clone())).map(|(cs_index, argument_index)| { - FlatStatement::Condition( + FlatStatement::condition( Variable::new(cs_index).into(), Variable::new(argument_index).into(), RuntimeError::BellmanInputBinding @@ -392,30 +395,30 @@ pub fn sha256_round<'ast, T: Field>( let rhs_b = flat_expression_from_variable_summands::(c.b.as_slice()); let lhs = flat_expression_from_variable_summands::(c.c.as_slice()); - FlatStatement::Condition( + FlatStatement::condition( lhs, - FlatExpression::Mult(box rhs_a, box rhs_b), + FlatExpression::mul(rhs_a, rhs_b), RuntimeError::BellmanConstraint, ) }); // define which subset of the witness is returned - let outputs = output_indices.map(|o| FlatExpression::Identifier(Variable::new(o))); + let outputs = output_indices.map(|o| FlatExpression::identifier(Variable::new(o))); // insert a directive to set the witness based on the bellman gadget and inputs - let directive_statement = FlatStatement::Directive(FlatDirective { - outputs: cs_indices.map(Variable::new).collect(), - inputs: input_argument_indices + let directive_statement = FlatStatement::Directive(FlatDirective::new( + cs_indices.map(Variable::new).collect(), + Solver::Sha256Round, + input_argument_indices .into_iter() .chain(current_hash_argument_indices) .map(|i| Variable::new(i).into()) .collect(), - solver: Solver::Sha256Round, - }); + )); // insert a statement to return the subset of the witness let return_statements = outputs .into_iter() .enumerate() - .map(|(index, e)| FlatStatement::Definition(Variable::public(index), e)); + .map(|(index, e)| FlatStatement::definition(Variable::public(index), e)); let statements = std::iter::once(directive_statement) .chain(std::iter::once(one_binding_statement)) .chain(input_binding_statements) @@ -426,6 +429,7 @@ pub fn sha256_round<'ast, T: Field>( arguments, statements, return_count, + module_map: ModuleMap::default(), } } @@ -468,9 +472,9 @@ pub fn snark_verify_bls12_377<'ast, T: Field>( .chain(vk_arguments) .collect(); - let one_binding_statement = FlatStatement::Condition( - FlatExpression::Identifier(Variable::new(0)), - FlatExpression::Number(T::from(1)), + let one_binding_statement = FlatStatement::condition( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::value(T::from(1)), RuntimeError::ArkOneBinding, ); @@ -484,7 +488,7 @@ pub fn snark_verify_bls12_377<'ast, T: Field>( .chain(vk_argument_indices.clone()), ) .map(|(cs_index, argument_index)| { - FlatStatement::Condition( + FlatStatement::condition( Variable::new(cs_index).into(), Variable::new(argument_index).into(), RuntimeError::ArkInputBinding, @@ -500,29 +504,29 @@ pub fn snark_verify_bls12_377<'ast, T: Field>( let rhs_b = flat_expression_from_variable_summands::(c.b.as_slice()); let lhs = flat_expression_from_variable_summands::(c.c.as_slice()); - FlatStatement::Condition( + FlatStatement::condition( lhs, - FlatExpression::Mult(box rhs_a, box rhs_b), + FlatExpression::mul(rhs_a, rhs_b), RuntimeError::ArkConstraint, ) }) .collect(); - let return_statement = FlatStatement::Definition( + let return_statement = FlatStatement::definition( Variable::public(0), - FlatExpression::Identifier(Variable::new(out_index)), + FlatExpression::identifier(Variable::new(out_index)), ); // insert a directive to set the witness - let directive_statement = FlatStatement::Directive(FlatDirective { - outputs: cs_indices.map(Variable::new).collect(), - inputs: input_argument_indices + let directive_statement = FlatStatement::Directive(FlatDirective::new( + cs_indices.map(Variable::new).collect(), + Solver::SnarkVerifyBls12377(n), + input_argument_indices .chain(proof_argument_indices) .chain(vk_argument_indices) .map(|i| Variable::new(i).into()) .collect(), - solver: Solver::SnarkVerifyBls12377(n), - }); + )); let statements = std::iter::once(directive_statement) .chain(std::iter::once(one_binding_statement)) @@ -534,6 +538,7 @@ pub fn snark_verify_bls12_377<'ast, T: Field>( arguments, statements, return_count: 1, + module_map: ModuleMap::default(), } } @@ -563,14 +568,11 @@ pub fn unpack_to_bitwidth<'ast, T: Field>( let mut layout = HashMap::new(); - let arguments = vec![Parameter { - id: Variable::new(0), - private: true, - }]; + let arguments = vec![Parameter::private(Variable::new(0))]; // o0, ..., o253 = ToBits(i0) - let directive_inputs = vec![FlatExpression::Identifier(use_variable( + let directive_inputs = vec![FlatExpression::identifier(use_variable( &mut layout, "i0".into(), &mut counter, @@ -586,16 +588,16 @@ pub fn unpack_to_bitwidth<'ast, T: Field>( let outputs: Vec<_> = directive_outputs .iter() .enumerate() - .map(|(_, o)| FlatExpression::Identifier(*o)) + .map(|(_, o)| FlatExpression::identifier(*o)) .collect(); // o253, o252, ... o{253 - (bit_width - 1)} are bits let mut statements: Vec> = (0..bit_width) .map(|index| { - let bit = FlatExpression::Identifier(Variable::new(bit_width - index)); - FlatStatement::Condition( + let bit = FlatExpression::identifier(Variable::new(bit_width - index)); + FlatStatement::condition( bit.clone(), - FlatExpression::Mult(box bit.clone(), box bit.clone()), + FlatExpression::mul(bit.clone(), bit), RuntimeError::Bitness, ) }) @@ -604,39 +606,40 @@ pub fn unpack_to_bitwidth<'ast, T: Field>( // sum check: o253 + o252 * 2 + ... + o{253 - (bit_width - 1)} * 2**(bit_width - 1) let lhs_sum = flat_expression_from_bits( (0..bit_width) - .map(|i| FlatExpression::Identifier(Variable::new(i + 1))) + .map(|i| FlatExpression::identifier(Variable::new(i + 1))) .collect(), ); - statements.push(FlatStatement::Condition( + statements.push(FlatStatement::condition( lhs_sum, - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Number(T::from(1)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::value(T::from(1)), ), RuntimeError::Sum, )); statements.insert( 0, - FlatStatement::Directive(FlatDirective { - inputs: directive_inputs, - outputs: directive_outputs, + FlatStatement::Directive(FlatDirective::new( + directive_outputs, solver, - }), + directive_inputs, + )), ); statements.extend( outputs .into_iter() .enumerate() - .map(|(index, e)| FlatStatement::Definition(Variable::public(index), e)), + .map(|(index, e)| FlatStatement::definition(Variable::public(index), e)), ); FlatFunctionIterator { arguments, statements: statements.into_iter(), return_count: bit_width, + module_map: ModuleMap::default(), } } @@ -662,7 +665,7 @@ mod tests { .map(|i| Variable::new(i + 1)) .collect(), Solver::bits(Bn128Field::get_required_bits()), - vec![Variable::new(0)] + vec![Variable::new(0).into()] )) ); assert_eq!( @@ -710,9 +713,9 @@ mod tests { // bellman variable #0: index 0 should equal 1 assert_eq!( compiled.statements[1], - FlatStatement::Condition( + FlatStatement::condition( Variable::new(0).into(), - FlatExpression::Number(Bn128Field::from(1)), + FlatExpression::value(Bn128Field::from(1)), RuntimeError::BellmanOneBinding ) ); @@ -720,7 +723,7 @@ mod tests { // bellman input #0: index 1 should equal zokrates input #0: index v_count assert_eq!( compiled.statements[2], - FlatStatement::Condition( + FlatStatement::condition( Variable::new(1).into(), Variable::new(26936).into(), RuntimeError::BellmanInputBinding diff --git a/zokrates_ast/src/common/expressions.rs b/zokrates_ast/src/common/expressions.rs new file mode 100644 index 000000000..4b61b1f82 --- /dev/null +++ b/zokrates_ast/src/common/expressions.rs @@ -0,0 +1,280 @@ +use derivative::Derivative; +use num_bigint::BigUint; +use serde::{Deserialize, Serialize}; + +use super::operators::OpEq; +use super::{operators::OperatorStr, Span, WithSpan}; +use std::fmt; +use std::marker::PhantomData; + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct BinaryExpression { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub left: Box, + pub right: Box, + operator: PhantomData, + output: PhantomData, +} + +impl BinaryExpression { + pub fn new(left: L, right: R) -> Self { + Self { + span: None, + left: box left, + right: box right, + operator: PhantomData, + output: PhantomData, + } + } +} + +impl fmt::Display + for BinaryExpression +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "({} {} {})", self.left, Op::STR, self.right,) + } +} + +impl WithSpan for BinaryExpression { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +pub enum BinaryOrExpression { + Binary(BinaryExpression), + Expression(I), +} + +pub type EqExpression = BinaryExpression; + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct UnaryExpression { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub inner: Box, + operator: PhantomData, + output: PhantomData, +} + +impl UnaryExpression { + pub fn new(inner: In) -> Self { + Self { + span: None, + inner: box inner, + operator: PhantomData, + output: PhantomData, + } + } +} + +impl fmt::Display + for UnaryExpression +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "({}{})", Op::STR, self.inner,) + } +} + +impl WithSpan for UnaryExpression { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +pub enum UnaryOrExpression { + Unary(UnaryExpression), + Expression(I), +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ValueExpression { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub value: V, +} + +impl ValueExpression { + pub fn new(value: V) -> Self { + Self { span: None, value } + } +} + +impl fmt::Display for ValueExpression { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.value,) + } +} + +pub type FieldValueExpression = ValueExpression; + +pub type BooleanValueExpression = ValueExpression; + +pub type UValueExpression = ValueExpression; + +pub type IntValueExpression = ValueExpression; + +pub enum ValueOrExpression { + Value(V), + Expression(I), +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct IdentifierExpression { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub id: I, + pub ty: PhantomData, +} + +impl IdentifierExpression { + pub fn new(id: I) -> Self { + IdentifierExpression { + span: None, + id, + ty: PhantomData, + } + } +} + +impl fmt::Display for IdentifierExpression { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.id,) + } +} + +impl WithSpan for IdentifierExpression { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +pub enum IdentifierOrExpression { + Identifier(IdentifierExpression), + Expression(I), +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct DefinitionStatement { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub assignee: A, + pub rhs: E, +} + +impl DefinitionStatement { + pub fn new(assignee: A, rhs: E) -> Self { + DefinitionStatement { + span: None, + assignee, + rhs, + } + } +} + +impl WithSpan for DefinitionStatement { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl fmt::Display for DefinitionStatement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{} = {}", self.assignee, self.rhs,) + } +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AssertionStatement { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub expression: B, + pub error: E, +} + +impl AssertionStatement { + pub fn new(expression: B, error: E) -> Self { + AssertionStatement { + span: None, + expression, + error, + } + } +} + +impl WithSpan for AssertionStatement { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ReturnStatement { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub inner: E, +} + +impl ReturnStatement { + pub fn new(e: E) -> Self { + ReturnStatement { + span: None, + inner: e, + } + } +} + +impl WithSpan for ReturnStatement { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl fmt::Display for ReturnStatement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "return {};", self.inner) + } +} diff --git a/zokrates_ast/src/common/flat/mod.rs b/zokrates_ast/src/common/flat/mod.rs new file mode 100644 index 000000000..d1892a6bb --- /dev/null +++ b/zokrates_ast/src/common/flat/mod.rs @@ -0,0 +1,5 @@ +pub mod parameter; +pub mod variable; + +pub use parameter::Parameter; +pub use variable::Variable; diff --git a/zokrates_ast/src/common/flat/parameter.rs b/zokrates_ast/src/common/flat/parameter.rs new file mode 100644 index 000000000..c4a86c962 --- /dev/null +++ b/zokrates_ast/src/common/flat/parameter.rs @@ -0,0 +1,63 @@ +use crate::common::{Span, WithSpan}; + +use super::variable::Variable; +use derivative::Derivative; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fmt; + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +pub struct Parameter { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub id: Variable, + pub private: bool, +} + +impl WithSpan for Parameter { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl Parameter { + pub fn new(id: Variable, private: bool) -> Self { + Parameter { + id, + private, + span: None, + } + } + + pub fn public(v: Variable) -> Self { + Self::new(v, false) + } + + pub fn private(v: Variable) -> Self { + Self::new(v, true) + } +} + +impl fmt::Display for Parameter { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let visibility = if self.private { "private " } else { "" }; + write!(f, "{}{}", visibility, self.id) + } +} + +impl Parameter { + pub fn apply_substitution(self, substitution: &HashMap) -> Parameter { + Parameter { + id: *substitution.get(&self.id).unwrap(), + private: self.private, + ..self + } + } +} diff --git a/zokrates_ast/src/common/flat/variable.rs b/zokrates_ast/src/common/flat/variable.rs new file mode 100644 index 000000000..ab460623c --- /dev/null +++ b/zokrates_ast/src/common/flat/variable.rs @@ -0,0 +1,102 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fmt; +use std::io::{Read, Write}; + +// A variable in a constraint system +// id > 0 for intermediate variables +// id == 0 for ~one +// id < 0 for public outputs +#[derive(Serialize, Deserialize, Clone, PartialEq, Hash, Eq, Ord, PartialOrd, Copy)] +pub struct Variable { + pub id: isize, +} + +impl Variable { + pub fn new(id: usize) -> Self { + Variable { + id: 1 + id as isize, + } + } + + pub fn one() -> Self { + Variable { id: 0 } + } + + pub fn public(id: usize) -> Self { + Variable { + id: -(id as isize) - 1, + } + } + + pub fn id(&self) -> usize { + assert!(self.id > 0); + (self.id as usize) - 1 + } + + pub fn write(&self, mut writer: W) -> std::io::Result<()> { + writer.write_all(&self.id.to_le_bytes())?; + Ok(()) + } + + pub fn read(mut reader: R) -> std::io::Result { + let mut buf = [0; std::mem::size_of::()]; + reader.read_exact(&mut buf)?; + + Ok(Variable { + id: isize::from_le_bytes(buf), + }) + } +} + +impl fmt::Display for Variable { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.id { + 0 => write!(f, "~one"), + i if i > 0 => write!(f, "_{}", i - 1), + i => write!(f, "~out_{}", -(i + 1)), + } + } +} + +impl fmt::Debug for Variable { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.id { + 0 => write!(f, "~one"), + i if i > 0 => write!(f, "_{}", i - 1), + i => write!(f, "~out_{}", -(i + 1)), + } + } +} + +impl Variable { + pub fn apply_substitution(self, substitution: &HashMap) -> &Self { + substitution.get(&self).unwrap() + } + + pub fn is_output(&self) -> bool { + self.id < 0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn one() { + assert_eq!(format!("{}", Variable::one()), "~one"); + } + + #[test] + fn public() { + assert_eq!(format!("{}", Variable::public(0)), "~out_0"); + assert_eq!(format!("{}", Variable::public(42)), "~out_42"); + } + + #[test] + fn private() { + assert_eq!(format!("{}", Variable::new(0)), "_0"); + assert_eq!(format!("{}", Variable::new(42)), "_42"); + } +} diff --git a/zokrates_ast/src/common/fold.rs b/zokrates_ast/src/common/fold.rs new file mode 100644 index 000000000..98369bef1 --- /dev/null +++ b/zokrates_ast/src/common/fold.rs @@ -0,0 +1,7 @@ +pub trait Fold: Sized { + fn fold(self, f: &mut F) -> Self; +} + +pub trait ResultFold: Sized { + fn fold(self, f: &mut F) -> Result; +} diff --git a/zokrates_ast/src/common/mod.rs b/zokrates_ast/src/common/mod.rs index c8ad09440..186ea3caa 100644 --- a/zokrates_ast/src/common/mod.rs +++ b/zokrates_ast/src/common/mod.rs @@ -1,15 +1,28 @@ pub mod embed; mod error; +pub mod expressions; +pub mod flat; +mod fold; mod format_string; mod metadata; +pub mod operators; mod parameter; +mod position; mod solvers; +pub mod statements; +mod value; mod variable; pub use self::embed::FlatEmbed; pub use self::error::RuntimeError; +pub use self::fold::{Fold, ResultFold}; pub use self::metadata::SourceMetadata; pub use self::parameter::Parameter; +pub use self::position::{ + LocalSourceSpan, ModuleId, ModuleIdHash, ModuleMap, OwnedModuleId, Position, SourceSpan, Span, + WithSpan, +}; pub use self::solvers::{RefCall, Solver}; +pub use self::value::Value; pub use self::variable::Variable; pub use format_string::FormatString; diff --git a/zokrates_ast/src/common/operators.rs b/zokrates_ast/src/common/operators.rs new file mode 100644 index 000000000..18a26cec2 --- /dev/null +++ b/zokrates_ast/src/common/operators.rs @@ -0,0 +1,143 @@ +pub trait OperatorStr { + const STR: &'static str; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpAdd; + +impl OperatorStr for OpAdd { + const STR: &'static str = "+"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpSub; + +impl OperatorStr for OpSub { + const STR: &'static str = "-"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpFloorSub; + +impl OperatorStr for OpFloorSub { + const STR: &'static str = "- /* FLOOR_SUB */"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpMul; + +impl OperatorStr for OpMul { + const STR: &'static str = "*"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpDiv; + +impl OperatorStr for OpDiv { + const STR: &'static str = "/"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpRem; + +impl OperatorStr for OpRem { + const STR: &'static str = "%"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpPow; + +impl OperatorStr for OpPow { + const STR: &'static str = "**"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpEq; + +impl OperatorStr for OpEq { + const STR: &'static str = "=="; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpLt; + +impl OperatorStr for OpLt { + const STR: &'static str = "<"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpLe; + +impl OperatorStr for OpLe { + const STR: &'static str = "<="; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpGt; + +impl OperatorStr for OpGt { + const STR: &'static str = ">"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpGe; + +impl OperatorStr for OpGe { + const STR: &'static str = ">="; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpXor; + +impl OperatorStr for OpXor { + const STR: &'static str = "^"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpOr; + +impl OperatorStr for OpOr { + const STR: &'static str = "|"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpAnd; + +impl OperatorStr for OpAnd { + const STR: &'static str = "&"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpLsh; + +impl OperatorStr for OpLsh { + const STR: &'static str = "<<"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpRsh; + +impl OperatorStr for OpRsh { + const STR: &'static str = ">>"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpNot; + +impl OperatorStr for OpNot { + const STR: &'static str = "!"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpNeg; + +impl OperatorStr for OpNeg { + const STR: &'static str = "-"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpPos; + +impl OperatorStr for OpPos { + const STR: &'static str = "+"; +} diff --git a/zokrates_ast/src/common/parameter.rs b/zokrates_ast/src/common/parameter.rs index 4b17395f4..e63b8be5c 100644 --- a/zokrates_ast/src/common/parameter.rs +++ b/zokrates_ast/src/common/parameter.rs @@ -1,46 +1,58 @@ -use super::variable::Variable; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use std::fmt; -#[derive(Serialize, Deserialize, Hash, Eq, PartialEq, Clone, Copy)] -pub struct Parameter { - pub id: Variable, +use derivative::Derivative; +use serde::{Deserialize, Serialize}; + +use super::{Span, WithSpan}; + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Hash, Eq)] +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct Parameter { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub id: V, pub private: bool, } -impl Parameter { - fn new(id: Variable, private: bool) -> Self { - Parameter { id, private } +impl From for Parameter { + fn from(v: V) -> Self { + Self::private(v) + } +} + +impl Parameter { + pub fn new(v: V, private: bool) -> Self { + Parameter { + span: None, + id: v, + private, + } } - pub fn public(v: Variable) -> Self { + pub fn public(v: V) -> Self { Self::new(v, false) } - pub fn private(v: Variable) -> Self { + pub fn private(v: V) -> Self { Self::new(v, true) } } -impl fmt::Display for Parameter { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let visibility = if self.private { "private " } else { "" }; - write!(f, "{}{}", visibility, self.id) +impl WithSpan for Parameter { + fn span(mut self, span: Option) -> Self { + self.span = span; + self } -} -impl fmt::Debug for Parameter { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Parameter(id: {:?})", self.id) + fn get_span(&self) -> Option { + self.span } } -impl Parameter { - pub fn apply_substitution(self, substitution: &HashMap) -> Parameter { - Parameter { - id: *substitution.get(&self.id).unwrap(), - private: self.private, - } +impl fmt::Display for Parameter { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let visibility = if self.private { "private " } else { "" }; + write!(f, "{}{}", visibility, self.id) } } diff --git a/zokrates_ast/src/common/position.rs b/zokrates_ast/src/common/position.rs new file mode 100644 index 000000000..a817f0749 --- /dev/null +++ b/zokrates_ast/src/common/position.rs @@ -0,0 +1,242 @@ +use std::{ + collections::BTreeMap, + fmt, + path::{Path, PathBuf}, +}; + +use serde::{Deserialize, Serialize}; + +use super::FlatEmbed; + +#[derive(Clone, PartialEq, Eq, Copy, Hash, Default, PartialOrd, Ord, Deserialize, Serialize)] +pub struct LocalSourceSpan { + pub from: Position, + pub to: Position, +} + +pub type ModuleIdHash = u64; + +pub type ModuleId = Path; + +pub type OwnedModuleId = PathBuf; + +#[derive(Clone, PartialEq, Debug, Eq, Hash, Default, PartialOrd, Ord, Deserialize, Serialize)] +pub struct ModuleMap { + modules: BTreeMap, +} + +impl ModuleMap { + pub fn new>(i: I) -> Self { + Self { + modules: i.into_iter().map(|id| (hash(&id), id)).collect(), + } + } + + pub fn remap_prefix(self, prefix: &Path, to: &Path) -> Self { + Self { + modules: self + .modules + .into_iter() + .map(|(id, path)| { + ( + id, + path.strip_prefix(prefix) + .map(|path| to.join(path)) + .unwrap_or(path), + ) + }) + .collect(), + } + } +} + +#[derive(Clone, PartialEq, Eq, Copy, Hash, Default, PartialOrd, Ord, Deserialize, Serialize)] +pub struct Position { + pub line: usize, + pub col: usize, +} + +#[derive(Clone, PartialEq, Eq, Copy, Hash, PartialOrd, Ord, Deserialize, Serialize, Debug)] +pub enum Span { + Source(SourceSpan), + Embed(FlatEmbed), +} + +impl fmt::Display for Span { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Span::Source(s) => write!(f, "{}", s), + Span::Embed(e) => write!(f, "{:?}", e), + } + } +} + +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum ResolvedSpan { + Source(ResolvedSourceSpan), + Embed(FlatEmbed), +} + +impl fmt::Display for ResolvedSpan { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ResolvedSpan::Source(s) => write!(f, "{}", s), + ResolvedSpan::Embed(e) => write!(f, "{:?}", e), + } + } +} + +impl Span { + pub fn resolve(self, map: &ModuleMap) -> ResolvedSpan { + match self { + Span::Source(s) => ResolvedSpan::Source(ResolvedSourceSpan { + module: map.modules.get(&s.module).cloned().unwrap(), + from: s.from, + to: s.to, + }), + Span::Embed(s) => ResolvedSpan::Embed(s), + } + } +} + +#[derive(Clone, PartialEq, Eq, Copy, Hash, Default, PartialOrd, Ord, Deserialize, Serialize)] +pub struct SourceSpan { + pub module: ModuleIdHash, + pub from: Position, + pub to: Position, +} + +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct ResolvedSourceSpan { + pub module: OwnedModuleId, + pub from: Position, + pub to: Position, +} + +impl From for Span { + fn from(span: SourceSpan) -> Self { + Self::Source(span) + } +} + +impl From for Span { + fn from(embed: FlatEmbed) -> Self { + Self::Embed(embed) + } +} + +impl SourceSpan { + pub fn mock() -> Self { + Self { + module: hash(&OwnedModuleId::default()), + from: Position::mock(), + to: Position::mock(), + } + } +} + +pub trait WithSpan: Sized { + fn span(self, _: Option) -> Self; + + fn with_span>(self, span: S) -> Self { + self.span(Some(span.into())) + } + + fn get_span(&self) -> Option; +} + +fn hash(id: &ModuleId) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher = DefaultHasher::new(); + id.hash(&mut hasher); + hasher.finish() +} + +impl LocalSourceSpan { + pub fn in_module(self, module_id: &ModuleId) -> SourceSpan { + SourceSpan { + module: hash(module_id), + from: self.from, + to: self.to, + } + } + + pub fn mock() -> Self { + Self { + from: Position::mock(), + to: Position::mock(), + } + } +} + +impl Position { + pub fn col(&self, delta: isize) -> Position { + assert!(self.col <= isize::max_value() as usize); + assert!(self.col as isize >= delta); + Position { + line: self.line, + col: (self.col as isize + delta) as usize, + } + } + + pub fn mock() -> Self { + Position { line: 42, col: 42 } + } +} +impl fmt::Display for Position { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}:{}", self.line, self.col) + } +} +impl fmt::Debug for Position { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl fmt::Display for SourceSpan { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.from) + } +} +impl fmt::Debug for SourceSpan { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl fmt::Display for ResolvedSourceSpan { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}:{} (until {})", + self.module.display(), + self.from, + self.to + ) + } +} + +#[test] +fn position_col() { + let pos = Position { + line: 100, + col: 258, + }; + assert_eq!( + pos.col(26), + Position { + line: 100, + col: 284, + } + ); + assert_eq!( + pos.col(-23), + Position { + line: 100, + col: 235, + } + ); +} diff --git a/zokrates_ast/src/common/statements.rs b/zokrates_ast/src/common/statements.rs new file mode 100644 index 000000000..ba67b8283 --- /dev/null +++ b/zokrates_ast/src/common/statements.rs @@ -0,0 +1,278 @@ +use derivative::Derivative; +use serde::{Deserialize, Serialize}; + +use crate::Solver; + +use super::{FormatString, SourceMetadata}; +use super::{Span, WithSpan}; +use std::fmt; + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] +pub struct DefinitionStatement { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub assignee: A, + pub rhs: E, +} + +impl DefinitionStatement { + pub fn new(assignee: A, rhs: E) -> Self { + DefinitionStatement { + span: None, + assignee, + rhs, + } + } +} + +impl WithSpan for DefinitionStatement { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl fmt::Display for DefinitionStatement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{} = {};", self.assignee, self.rhs,) + } +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] +pub struct AssertionStatement { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub expression: B, + pub error: E, +} + +impl AssertionStatement { + pub fn new(expression: B, error: E) -> Self { + AssertionStatement { + span: None, + expression, + error, + } + } +} + +impl WithSpan for AssertionStatement { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] +pub struct ReturnStatement { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub inner: E, +} + +impl ReturnStatement { + pub fn new(e: E) -> Self { + ReturnStatement { + span: None, + inner: e, + } + } +} + +impl WithSpan for ReturnStatement { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl fmt::Display for ReturnStatement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "return {};", self.inner) + } +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct LogStatement { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub format_string: FormatString, + pub expressions: Vec, +} + +impl LogStatement { + pub fn new(format_string: FormatString, expressions: Vec) -> Self { + LogStatement { + span: None, + format_string, + expressions, + } + } +} + +impl WithSpan for LogStatement { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl fmt::Display for LogStatement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "log({}, {});", + self.format_string, + self.expressions + .iter() + .map(|e| e.to_string()) + .collect::>() + .join(", ") + ) + } +} + +#[derive(Derivative, Clone, Debug, Serialize, Deserialize)] +#[derivative(Hash, PartialEq, Eq)] +pub struct DirectiveStatement { + #[derivative(Hash = "ignore", PartialEq = "ignore")] + pub span: Option, + pub inputs: Vec, + pub outputs: Vec, + pub solver: S, +} + +impl<'ast, T, I, O> DirectiveStatement> { + pub fn new(outputs: Vec, solver: Solver<'ast, T>, inputs: Vec) -> Self { + let (in_len, out_len) = solver.get_signature(); + assert_eq!(in_len, inputs.len()); + assert_eq!(out_len, outputs.len()); + Self { + span: None, + inputs, + outputs, + solver, + } + } +} + +impl WithSpan for DirectiveStatement { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl fmt::Display + for DirectiveStatement +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "# {} = {}({})", + self.outputs + .iter() + .map(|o| format!("{}", o)) + .collect::>() + .join(", "), + self.solver, + self.inputs + .iter() + .map(|i| format!("{}", i)) + .collect::>() + .join(", ") + ) + } +} + +#[derive(Derivative)] +#[derivative(PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AssemblyAssignment { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub assignee: A, + pub expression: E, +} + +impl AssemblyAssignment { + pub fn new(assignee: A, expression: E) -> Self { + Self { + span: None, + assignee, + expression, + } + } +} + +impl WithSpan for AssemblyAssignment { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +#[derive(Derivative)] +#[derivative(PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AssemblyConstraint { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub left: E, + pub right: E, + pub metadata: SourceMetadata, +} + +impl AssemblyConstraint { + pub fn new(left: E, right: E, metadata: SourceMetadata) -> Self { + Self { + span: None, + left, + right, + metadata, + } + } +} + +impl WithSpan for AssemblyConstraint { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl fmt::Display for AssemblyConstraint { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{} === {};", self.left, self.right) + } +} diff --git a/zokrates_ast/src/common/value.rs b/zokrates_ast/src/common/value.rs new file mode 100644 index 000000000..af204a813 --- /dev/null +++ b/zokrates_ast/src/common/value.rs @@ -0,0 +1,3 @@ +pub trait Value { + type Value: Clone; +} diff --git a/zokrates_ast/src/common/variable.rs b/zokrates_ast/src/common/variable.rs index ab460623c..c5556805a 100644 --- a/zokrates_ast/src/common/variable.rs +++ b/zokrates_ast/src/common/variable.rs @@ -1,102 +1,42 @@ +use derivative::Derivative; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use std::fmt; -use std::io::{Read, Write}; -// A variable in a constraint system -// id > 0 for intermediate variables -// id == 0 for ~one -// id < 0 for public outputs -#[derive(Serialize, Deserialize, Clone, PartialEq, Hash, Eq, Ord, PartialOrd, Copy)] -pub struct Variable { - pub id: isize, -} - -impl Variable { - pub fn new(id: usize) -> Self { - Variable { - id: 1 + id as isize, - } - } - - pub fn one() -> Self { - Variable { id: 0 } - } - - pub fn public(id: usize) -> Self { - Variable { - id: -(id as isize) - 1, - } - } +use super::{Span, WithSpan}; - pub fn id(&self) -> usize { - assert!(self.id > 0); - (self.id as usize) - 1 - } +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Hash, Eq, Ord)] +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct Variable { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub id: I, + pub ty: T, +} - pub fn write(&self, mut writer: W) -> std::io::Result<()> { - writer.write_all(&self.id.to_le_bytes())?; - Ok(()) +impl WithSpan for Variable { + fn span(mut self, span: Option) -> Self { + self.span = span; + self } - pub fn read(mut reader: R) -> std::io::Result { - let mut buf = [0; std::mem::size_of::()]; - reader.read_exact(&mut buf)?; - - Ok(Variable { - id: isize::from_le_bytes(buf), - }) + fn get_span(&self) -> Option { + self.span } } -impl fmt::Display for Variable { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.id { - 0 => write!(f, "~one"), - i if i > 0 => write!(f, "_{}", i - 1), - i => write!(f, "~out_{}", -(i + 1)), +impl Variable { + pub fn new>(id: J, ty: T) -> Self { + Self { + span: None, + id: id.into(), + ty, } } } -impl fmt::Debug for Variable { +impl fmt::Display for Variable { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.id { - 0 => write!(f, "~one"), - i if i > 0 => write!(f, "_{}", i - 1), - i => write!(f, "~out_{}", -(i + 1)), - } - } -} - -impl Variable { - pub fn apply_substitution(self, substitution: &HashMap) -> &Self { - substitution.get(&self).unwrap() - } - - pub fn is_output(&self) -> bool { - self.id < 0 - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn one() { - assert_eq!(format!("{}", Variable::one()), "~one"); - } - - #[test] - fn public() { - assert_eq!(format!("{}", Variable::public(0)), "~out_0"); - assert_eq!(format!("{}", Variable::public(42)), "~out_42"); - } - - #[test] - fn private() { - assert_eq!(format!("{}", Variable::new(0)), "_0"); - assert_eq!(format!("{}", Variable::new(42)), "_42"); + write!(f, "{} {}", self.ty, self.id,) } } diff --git a/zokrates_ast/src/flat/folder.rs b/zokrates_ast/src/flat/folder.rs index ce50d7dac..539e02ebe 100644 --- a/zokrates_ast/src/flat/folder.rs +++ b/zokrates_ast/src/flat/folder.rs @@ -1,9 +1,18 @@ // Generic walk through an IR AST. Not mutating in place use super::*; -use crate::common::Variable; +use crate::common::{ + expressions::{BinaryOrExpression, IdentifierOrExpression}, + flat::Variable, + Fold, WithSpan, +}; use zokrates_field::Field; +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for FlatExpression { + fn fold(self, f: &mut F) -> Self { + f.fold_expression(self) + } +} pub trait Folder<'ast, T: Field>: Sized { fn fold_program(&mut self, p: FlatProg<'ast, T>) -> FlatProg<'ast, T> { fold_program(self, p) @@ -25,6 +34,20 @@ pub trait Folder<'ast, T: Field>: Sized { fold_expression(self, e) } + fn fold_binary_expression, R: Fold, E>( + &mut self, + e: BinaryExpression, + ) -> BinaryOrExpression> { + fold_binary_expression(self, e) + } + + fn fold_identifier_expression( + &mut self, + e: IdentifierExpression>, + ) -> IdentifierOrExpression, FlatExpression> { + fold_identifier_expression(self, e) + } + fn fold_directive(&mut self, d: FlatDirective<'ast, T>) -> FlatDirective<'ast, T> { fold_directive(self, d) } @@ -45,7 +68,7 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( .into_iter() .flat_map(|s| f.fold_statement(s)) .collect(), - return_count: p.return_count, + ..p } } @@ -54,25 +77,27 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( s: FlatStatement<'ast, T>, ) -> Vec> { match s { - FlatStatement::Block(statements) => vec![FlatStatement::Block( + FlatStatement::Condition(s) => vec![FlatStatement::condition( + f.fold_expression(s.quad), + f.fold_expression(s.lin), + s.error, + )], + FlatStatement::Block(statements) => vec![FlatStatement::block( statements + .inner .into_iter() .flat_map(|s| f.fold_statement(s)) .collect(), )], - FlatStatement::Condition(left, right, error) => vec![FlatStatement::Condition( - f.fold_expression(left), - f.fold_expression(right), - error, - )], - FlatStatement::Definition(v, e) => vec![FlatStatement::Definition( - f.fold_variable(v), - f.fold_expression(e), + FlatStatement::Definition(s) => vec![FlatStatement::definition( + f.fold_variable(s.assignee), + f.fold_expression(s.rhs), )], FlatStatement::Directive(d) => vec![FlatStatement::Directive(f.fold_directive(d))], - FlatStatement::Log(s, e) => vec![FlatStatement::Log( - s, - e.into_iter() + FlatStatement::Log(s) => vec![FlatStatement::log( + s.format_string, + s.expressions + .into_iter() .map(|(t, e)| (t, e.into_iter().map(|e| f.fold_expression(e)).collect())) .collect(), )], @@ -84,20 +109,45 @@ pub fn fold_expression<'ast, T: Field, F: Folder<'ast, T>>( e: FlatExpression, ) -> FlatExpression { match e { - FlatExpression::Number(n) => FlatExpression::Number(n), - FlatExpression::Identifier(id) => FlatExpression::Identifier(f.fold_variable(id)), - FlatExpression::Add(box left, box right) => { - FlatExpression::Add(box f.fold_expression(left), box f.fold_expression(right)) - } - FlatExpression::Sub(box left, box right) => { - FlatExpression::Sub(box f.fold_expression(left), box f.fold_expression(right)) - } - FlatExpression::Mult(box left, box right) => { - FlatExpression::Mult(box f.fold_expression(left), box f.fold_expression(right)) - } + FlatExpression::Value(n) => FlatExpression::Value(n), + FlatExpression::Identifier(id) => match f.fold_identifier_expression(id) { + IdentifierOrExpression::Identifier(e) => FlatExpression::Identifier(e), + IdentifierOrExpression::Expression(e) => e, + }, + FlatExpression::Add(e) => match f.fold_binary_expression(e) { + BinaryOrExpression::Binary(e) => FlatExpression::Add(e), + BinaryOrExpression::Expression(e) => e, + }, + FlatExpression::Sub(e) => match f.fold_binary_expression(e) { + BinaryOrExpression::Binary(e) => FlatExpression::Sub(e), + BinaryOrExpression::Expression(e) => e, + }, + FlatExpression::Mult(e) => match f.fold_binary_expression(e) { + BinaryOrExpression::Binary(e) => FlatExpression::Mult(e), + BinaryOrExpression::Expression(e) => e, + }, } } +fn fold_identifier_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + e: IdentifierExpression>, +) -> IdentifierOrExpression, FlatExpression> { + let id = f.fold_variable(e.id); + + IdentifierOrExpression::Identifier(IdentifierExpression { id, ..e }) +} + +fn fold_binary_expression<'ast, T: Field, F: Folder<'ast, T>, Op, L: Fold, R: Fold, E>( + f: &mut F, + e: BinaryExpression, +) -> BinaryOrExpression> { + let left = e.left.fold(f); + let right = e.right.fold(f); + + BinaryOrExpression::Binary(BinaryExpression::new(left, right).span(e.span)) +} + pub fn fold_directive<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, ds: FlatDirective<'ast, T>, @@ -116,7 +166,7 @@ pub fn fold_directive<'ast, T: Field, F: Folder<'ast, T>>( pub fn fold_argument<'ast, T: Field, F: Folder<'ast, T>>(f: &mut F, a: Parameter) -> Parameter { Parameter { id: f.fold_variable(a.id), - private: a.private, + ..a } } diff --git a/zokrates_ast/src/flat/mod.rs b/zokrates_ast/src/flat/mod.rs index 015e670a5..e1a86e7a9 100644 --- a/zokrates_ast/src/flat/mod.rs +++ b/zokrates_ast/src/flat/mod.rs @@ -8,11 +8,20 @@ pub mod folder; pub mod utils; +use crate::common; +pub use crate::common::flat::Parameter; +pub use crate::common::flat::Variable; +use crate::common::statements::DirectiveStatement; use crate::common::FormatString; -pub use crate::common::Parameter; +use crate::common::ModuleMap; pub use crate::common::RuntimeError; -pub use crate::common::Variable; +use crate::common::{ + expressions::{BinaryExpression, IdentifierExpression, ValueExpression}, + operators::*, +}; +use crate::common::{Span, WithSpan}; +use derivative::Derivative; pub use utils::{ flat_expression_from_bits, flat_expression_from_expression_summands, flat_expression_from_variable_summands, @@ -32,6 +41,8 @@ pub type FlatProgIterator<'ast, T, I> = FlatFunctionIterator<'ast, T, I>; #[derive(Clone, PartialEq, Eq, Debug)] pub struct FlatFunctionIterator<'ast, T, I: IntoIterator>> { + /// The map of the modules for sourcemaps + pub module_map: ModuleMap, /// Arguments of the function pub arguments: Vec, /// Vector of statements that are executed when running the function @@ -46,6 +57,7 @@ impl<'ast, T, I: IntoIterator>> FlatFunctionIterat statements: self.statements.into_iter().collect(), arguments: self.arguments, return_count: self.return_count, + module_map: self.module_map, } } } @@ -70,45 +82,155 @@ impl<'ast, T: Field> fmt::Display for FlatFunction<'ast, T> { } } -/// Calculates a flattened function based on a R1CS (A, B, C) and returns that flattened function: -/// * The Rank 1 Constraint System (R1CS) is defined as: -/// * `* = ` for a witness `x` -/// * Since the matrices in R1CS are usually sparse, the following encoding is used: -/// * For each constraint (i.e., row in the R1CS), only non-zero values are supplied and encoded as a tuple (index, value). -/// -/// # Arguments -/// -/// * r1cs - R1CS in standard JSON data format +pub type DefinitionStatement = + common::expressions::DefinitionStatement>; +pub type LogStatement = common::statements::LogStatement<(ConcreteType, Vec>)>; +pub type FlatDirective<'ast, T> = + common::statements::DirectiveStatement, Variable, Solver<'ast, T>>; -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug)] +pub struct BlockStatement<'ast, T> { + #[derivative(PartialEq = "ignore", Hash = "ignore")] + pub span: Option, + pub inner: Vec>, +} + +impl<'ast, T> BlockStatement<'ast, T> { + pub fn new(inner: Vec>) -> Self { + BlockStatement { span: None, inner } + } +} + +impl WithSpan for AssertionStatement { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug)] +pub struct AssertionStatement { + #[derivative(PartialEq = "ignore", Hash = "ignore")] + pub span: Option, + pub quad: FlatExpression, + pub lin: FlatExpression, + pub error: RuntimeError, +} + +impl AssertionStatement { + pub fn new(lin: FlatExpression, quad: FlatExpression, error: RuntimeError) -> Self { + AssertionStatement { + span: None, + quad, + lin, + error, + } + } +} + +impl<'ast, T> WithSpan for BlockStatement<'ast, T> { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug)] pub enum FlatStatement<'ast, T> { - Block(Vec>), - Condition(FlatExpression, FlatExpression, RuntimeError), - Definition(Variable, FlatExpression), + Condition(AssertionStatement), + Definition(DefinitionStatement), Directive(FlatDirective<'ast, T>), - Log(FormatString, Vec<(ConcreteType, Vec>)>), + Log(LogStatement), + Block(BlockStatement<'ast, T>), +} + +impl<'ast, T> FlatStatement<'ast, T> { + pub fn definition(assignee: Variable, rhs: FlatExpression) -> Self { + Self::Definition(DefinitionStatement::new(assignee, rhs)) + } + + pub fn condition(lin: FlatExpression, quad: FlatExpression, error: RuntimeError) -> Self { + Self::Condition(AssertionStatement::new(lin, quad, error)) + } + + pub fn log( + format_string: FormatString, + expressions: Vec<(ConcreteType, Vec>)>, + ) -> Self { + Self::Log(LogStatement::new(format_string, expressions)) + } + + pub fn directive( + outputs: Vec, + solver: Solver<'ast, T>, + inputs: Vec>, + ) -> Self { + Self::Directive(DirectiveStatement::new(outputs, solver, inputs)) + } + + pub fn block(inner: Vec>) -> Self { + Self::Block(BlockStatement::new(inner)) + } +} + +impl<'ast, T> WithSpan for FlatStatement<'ast, T> { + fn span(self, span: Option) -> Self { + use FlatStatement::*; + + match self { + Condition(e) => Condition(e.span(span)), + Definition(e) => Definition(e.span(span)), + Directive(e) => Directive(e.span(span)), + Log(e) => Log(e.span(span)), + Block(e) => Block(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use FlatStatement::*; + + match self { + Condition(e) => e.get_span(), + Definition(e) => e.get_span(), + Directive(e) => e.get_span(), + Log(e) => e.get_span(), + Block(e) => e.get_span(), + } + } } impl<'ast, T: Field> fmt::Display for FlatStatement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - FlatStatement::Block(ref statements) => { + FlatStatement::Definition(ref e) => write!(f, "{}", e), + FlatStatement::Condition(ref s) => { + write!(f, "{} == {} // {}", s.lin, s.quad, s.error) + } + FlatStatement::Block(ref s) => { writeln!(f, "{{")?; - for s in statements { + for s in &s.inner { writeln!(f, "{}", s)?; } writeln!(f, "}}") } - FlatStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs), - FlatStatement::Condition(ref lhs, ref rhs, ref message) => { - write!(f, "{} == {} // {}", lhs, rhs, message) - } FlatStatement::Directive(ref d) => write!(f, "{}", d), - FlatStatement::Log(ref l, ref expressions) => write!( + FlatStatement::Log(ref s) => write!( f, "log(\"{}\"), {})", - l, - expressions + s.format_string, + s.expressions .iter() .map(|(_, e)| format!( "[{}]", @@ -130,20 +252,20 @@ impl<'ast, T: Field> FlatStatement<'ast, T> { substitution: &'ast HashMap, ) -> FlatStatement { match self { - FlatStatement::Block(statements) => FlatStatement::Block( - statements + FlatStatement::Definition(s) => FlatStatement::definition( + *s.assignee.apply_substitution(substitution), + s.rhs.apply_substitution(substitution), + ), + FlatStatement::Block(s) => FlatStatement::block( + s.inner .into_iter() .map(|s| s.apply_substitution(substitution)) .collect(), ), - FlatStatement::Definition(id, x) => FlatStatement::Definition( - *id.apply_substitution(substitution), - x.apply_substitution(substitution), - ), - FlatStatement::Condition(x, y, message) => FlatStatement::Condition( - x.apply_substitution(substitution), - y.apply_substitution(substitution), - message, + FlatStatement::Condition(s) => FlatStatement::condition( + s.quad.apply_substitution(substitution), + s.lin.apply_substitution(substitution), + s.error, ), FlatStatement::Directive(d) => { let outputs = d @@ -163,9 +285,10 @@ impl<'ast, T: Field> FlatStatement<'ast, T> { ..d }) } - FlatStatement::Log(l, e) => FlatStatement::Log( - l, - e.into_iter() + FlatStatement::Log(s) => FlatStatement::Log(LogStatement::new( + s.format_string, + s.expressions + .into_iter() .map(|(t, e)| { ( t, @@ -175,111 +298,103 @@ impl<'ast, T: Field> FlatStatement<'ast, T> { ) }) .collect(), - ), + )), } } } -#[derive(Clone, Hash, Debug, PartialEq, Eq)] -pub struct FlatDirective<'ast, T> { - pub inputs: Vec>, - pub outputs: Vec, - pub solver: Solver<'ast, T>, +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub enum FlatExpression { + Value(ValueExpression), + Identifier(IdentifierExpression), + Add(BinaryExpression), + Sub(BinaryExpression), + Mult(BinaryExpression), } -impl<'ast, T> FlatDirective<'ast, T> { - pub fn new>>( - outputs: Vec, - solver: Solver<'ast, T>, - inputs: Vec, - ) -> Self { - let (in_len, out_len) = solver.get_signature(); - assert_eq!(in_len, inputs.len()); - assert_eq!(out_len, outputs.len()); - FlatDirective { - solver, - inputs: inputs.into_iter().map(|i| i.into()).collect(), - outputs, - } +impl std::ops::Add for FlatExpression { + type Output = Self; + + fn add(self, other: Self) -> Self::Output { + FlatExpression::Add(BinaryExpression::new(self, other)) } } -impl<'ast, T: Field> fmt::Display for FlatDirective<'ast, T> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "# {} = {}({})", - self.outputs - .iter() - .map(|o| o.to_string()) - .collect::>() - .join(", "), - self.solver, - self.inputs - .iter() - .map(|i| i.to_string()) - .collect::>() - .join(", ") - ) +impl std::ops::Sub for FlatExpression { + type Output = Self; + + fn sub(self, other: Self) -> Self::Output { + FlatExpression::Sub(BinaryExpression::new(self, other)) } } -#[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub enum FlatExpression { - Number(T), - Identifier(Variable), - Add(Box>, Box>), - Sub(Box>, Box>), - Mult(Box>, Box>), +impl std::ops::Mul for FlatExpression { + type Output = Self; + + fn mul(self, other: Self) -> Self::Output { + FlatExpression::Mult(BinaryExpression::new(self, other)) + } } impl From for FlatExpression { fn from(other: T) -> Self { - Self::Number(other) + Self::value(other) } } -impl FlatExpression { - pub fn apply_substitution( - self, - substitution: &HashMap, - ) -> FlatExpression { +impl BinaryExpression, FlatExpression, FlatExpression> { + fn apply_substitution(self, substitution: &HashMap) -> Self { + let left = self.left.apply_substitution(substitution); + let right = self.right.apply_substitution(substitution); + + Self::new(left, right).span(self.span) + } +} + +impl IdentifierExpression> { + fn apply_substitution(self, substitution: &HashMap) -> Self { + let id = *self.id.apply_substitution(substitution); + + IdentifierExpression { id, ..self } + } +} + +impl FlatExpression { + pub fn identifier(v: Variable) -> Self { + Self::Identifier(IdentifierExpression::new(v)) + } + + pub fn value(t: T) -> Self { + Self::Value(ValueExpression::new(t)) + } + + pub fn apply_substitution(self, substitution: &HashMap) -> Self { match self { - e @ FlatExpression::Number(_) => e, + e @ FlatExpression::Value(_) => e, FlatExpression::Identifier(id) => { - FlatExpression::Identifier(*id.apply_substitution(substitution)) + FlatExpression::Identifier(id.apply_substitution(substitution)) } - FlatExpression::Add(e1, e2) => FlatExpression::Add( - box e1.apply_substitution(substitution), - box e2.apply_substitution(substitution), - ), - FlatExpression::Sub(e1, e2) => FlatExpression::Sub( - box e1.apply_substitution(substitution), - box e2.apply_substitution(substitution), - ), - FlatExpression::Mult(e1, e2) => FlatExpression::Mult( - box e1.apply_substitution(substitution), - box e2.apply_substitution(substitution), - ), + FlatExpression::Add(e) => FlatExpression::Add(e.apply_substitution(substitution)), + FlatExpression::Sub(e) => FlatExpression::Sub(e.apply_substitution(substitution)), + FlatExpression::Mult(e) => FlatExpression::Mult(e.apply_substitution(substitution)), } } pub fn is_linear(&self) -> bool { match *self { - FlatExpression::Number(_) | FlatExpression::Identifier(_) => true, - FlatExpression::Add(ref x, ref y) | FlatExpression::Sub(ref x, ref y) => { - x.is_linear() && y.is_linear() - } - FlatExpression::Mult(ref x, ref y) => matches!( - (x.clone(), y.clone()), - (box FlatExpression::Number(_), box FlatExpression::Number(_)) + FlatExpression::Value(_) | FlatExpression::Identifier(_) => true, + FlatExpression::Add(ref e) => e.left.is_linear() && e.right.is_linear(), + FlatExpression::Sub(ref e) => e.left.is_linear() && e.right.is_linear(), + FlatExpression::Mult(ref e) => matches!( + (&e.left, &e.right), + (box FlatExpression::Value(_), box FlatExpression::Value(_)) | ( - box FlatExpression::Number(_), + box FlatExpression::Value(_), box FlatExpression::Identifier(_) ) | ( box FlatExpression::Identifier(_), - box FlatExpression::Number(_) + box FlatExpression::Value(_) ) ), } @@ -289,18 +404,18 @@ impl FlatExpression { impl fmt::Display for FlatExpression { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - FlatExpression::Number(ref i) => write!(f, "{}", i), + FlatExpression::Value(ref i) => write!(f, "{}", i), FlatExpression::Identifier(ref var) => write!(f, "{}", var), - FlatExpression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), - FlatExpression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs), - FlatExpression::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), + FlatExpression::Add(ref e) => write!(f, "{}", e), + FlatExpression::Sub(ref e) => write!(f, "{}", e), + FlatExpression::Mult(ref e) => write!(f, "{}", e), } } } impl From for FlatExpression { fn from(v: Variable) -> FlatExpression { - FlatExpression::Identifier(v) + FlatExpression::identifier(v) } } @@ -314,3 +429,27 @@ impl fmt::Display for Error { write!(f, "{}", self.message) } } + +impl WithSpan for FlatExpression { + fn span(self, span: Option) -> Self { + use FlatExpression::*; + match self { + Add(e) => Add(e.span(span)), + Sub(e) => Sub(e.span(span)), + Mult(e) => Mult(e.span(span)), + Value(e) => Value(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use FlatExpression::*; + match self { + Add(e) => e.get_span(), + Sub(e) => e.get_span(), + Mult(e) => e.get_span(), + Value(e) => e.get_span(), + Identifier(e) => e.get_span(), + } + } +} diff --git a/zokrates_ast/src/flat/utils.rs b/zokrates_ast/src/flat/utils.rs index c37d2f68c..7d1f1ff57 100644 --- a/zokrates_ast/src/flat/utils.rs +++ b/zokrates_ast/src/flat/utils.rs @@ -1,4 +1,5 @@ use crate::flat::{FlatExpression, Variable}; +use std::ops::*; use zokrates_field::Field; // util to convert a vector of `(coefficient, expression)` to a flat_expression @@ -7,16 +8,16 @@ pub fn flat_expression_from_expression_summands FlatExpression { match v.len() { - 0 => FlatExpression::Number(T::zero()), + 0 => FlatExpression::value(T::zero()), 1 => { let (val, var) = v[0].clone(); - FlatExpression::Mult(box FlatExpression::Number(val), box var.into()) + FlatExpression::mul(FlatExpression::value(val), var.into()) } n => { let (u, v) = v.split_at(n / 2); - FlatExpression::Add( - box flat_expression_from_expression_summands(u), - box flat_expression_from_expression_summands(v), + FlatExpression::add( + flat_expression_from_expression_summands(u), + flat_expression_from_expression_summands(v), ) } } @@ -34,19 +35,19 @@ pub fn flat_expression_from_bits(v: Vec>) -> FlatExp pub fn flat_expression_from_variable_summands(v: &[(T, usize)]) -> FlatExpression { match v.len() { - 0 => FlatExpression::Number(T::zero()), + 0 => FlatExpression::value(T::zero()), 1 => { let (val, var) = v[0]; - FlatExpression::Mult( - box FlatExpression::Number(val), - box FlatExpression::Identifier(Variable::new(var)), + FlatExpression::mul( + FlatExpression::value(val), + FlatExpression::identifier(Variable::new(var)), ) } n => { let (u, v) = v.split_at(n / 2); - FlatExpression::Add( - box flat_expression_from_variable_summands(u), - box flat_expression_from_variable_summands(v), + FlatExpression::add( + flat_expression_from_variable_summands(u), + flat_expression_from_variable_summands(v), ) } } diff --git a/zokrates_ast/src/ir/check.rs b/zokrates_ast/src/ir/check.rs index 41cac7b0d..e4c1652ae 100644 --- a/zokrates_ast/src/ir/check.rs +++ b/zokrates_ast/src/ir/check.rs @@ -1,5 +1,5 @@ use crate::ir::folder::Folder; -use crate::ir::Directive; +use crate::ir::DirectiveStatement; use crate::ir::Parameter; use crate::ir::ProgIterator; use crate::ir::Statement; @@ -42,8 +42,11 @@ impl<'ast, T: Field> Folder<'ast, T> for UnconstrainedVariableDetector { self.variables.remove(&v); v } - fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> { + fn fold_directive_statement( + &mut self, + d: DirectiveStatement<'ast, T>, + ) -> Vec> { self.variables.extend(d.outputs.iter()); - d + vec![Statement::Directive(d)] } } diff --git a/zokrates_ast/src/ir/clean.rs b/zokrates_ast/src/ir/clean.rs index 998e32987..1c58faff6 100644 --- a/zokrates_ast/src/ir/clean.rs +++ b/zokrates_ast/src/ir/clean.rs @@ -1,3 +1,5 @@ +use crate::common::WithSpan; + use super::folder::Folder; use super::{ProgIterator, Statement}; use zokrates_field::Field; @@ -8,6 +10,7 @@ pub struct Cleaner; impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'ast, T, I> { pub fn clean(self) -> ProgIterator<'ast, T, impl IntoIterator>> { ProgIterator { + module_map: self.module_map, arguments: self.arguments, return_count: self.return_count, statements: self @@ -21,8 +24,13 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a impl<'ast, T: Field> Folder<'ast, T> for Cleaner { fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { + if s.get_span().is_none() { + eprintln!("Internal compiler warning: found a statement without source information. Please open an issue https://github.com/Zokrates/ZoKrates/issues/new?template=bug_report.md"); + } + match s { - Statement::Block(statements) => statements + Statement::Block(s) => s + .inner .into_iter() .flat_map(|s| self.fold_statement(s)) .collect(), diff --git a/zokrates_ast/src/ir/expression.rs b/zokrates_ast/src/ir/expression.rs index 33e3856e7..5394b6eac 100644 --- a/zokrates_ast/src/ir/expression.rs +++ b/zokrates_ast/src/ir/expression.rs @@ -1,23 +1,34 @@ use super::Witness; -use crate::common::Variable; +use crate::common::{flat::Variable, Span, WithSpan}; +use derivative::Derivative; use serde::{Deserialize, Serialize}; use std::collections::btree_map::{BTreeMap, Entry}; use std::fmt; -use std::hash::Hash; use std::ops::{Add, Div, Mul, Sub}; use zokrates_field::Field; -#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Debug, Clone, Serialize, Deserialize, Eq)] pub struct QuadComb { + #[derivative(PartialEq = "ignore", Hash = "ignore")] + pub span: Option, pub left: LinComb, pub right: LinComb, } -impl QuadComb { - pub fn from_linear_combinations(left: LinComb, right: LinComb) -> Self { - QuadComb { left, right } +impl WithSpan for QuadComb { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span } +} +impl QuadComb { pub fn try_linear(self) -> Result, Self> { // identify `(k * ~ONE) * (lincomb)` and `(lincomb) * (k * ~ONE)` and return (k * lincomb) // if not, error out with the input @@ -30,7 +41,7 @@ impl QuadComb { Ok(coefficient) => Ok(self.right * &coefficient), Err(left) => match self.right.try_constant() { Ok(coefficient) => Ok(left * &coefficient), - Err(right) => Err(QuadComb::from_linear_combinations(left, right)), + Err(right) => Err(QuadComb::new(left, right)), }, } } @@ -44,7 +55,7 @@ impl From for LinComb { impl>> From for QuadComb { fn from(x: U) -> QuadComb { - QuadComb::from_linear_combinations(LinComb::one(), x.into()) + QuadComb::new(LinComb::one(), x.into()) } } @@ -54,21 +65,80 @@ impl fmt::Display for QuadComb { } } -#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)] -pub struct LinComb(pub Vec<(Variable, T)>); +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize, Eq)] +pub struct LinComb { + #[derivative(PartialEq = "ignore", Hash = "ignore")] + pub span: Option, + pub value: Vec<(Variable, T)>, +} + +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize, Eq)] +pub struct CanonicalLinComb { + #[derivative(PartialEq = "ignore", Hash = "ignore")] + pub span: Option, + pub value: BTreeMap, +} + +impl WithSpan for CanonicalLinComb { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } -#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)] -pub struct CanonicalLinComb(pub BTreeMap); + fn get_span(&self) -> Option { + self.span + } +} -#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize, Eq)] pub struct CanonicalQuadComb { + #[derivative(PartialEq = "ignore", Hash = "ignore")] + span: Option, left: CanonicalLinComb, right: CanonicalLinComb, } +impl WithSpan for CanonicalQuadComb { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl CanonicalQuadComb { + pub fn new(left: CanonicalLinComb, right: CanonicalLinComb) -> Self { + Self { + span: None, + left, + right, + } + } +} + +impl QuadComb { + pub fn new(left: LinComb, right: LinComb) -> Self { + Self { + span: None, + left, + right, + } + } +} + impl From> for QuadComb { fn from(q: CanonicalQuadComb) -> Self { QuadComb { + span: q.span, left: q.left.into(), right: q.right.into(), } @@ -77,42 +147,66 @@ impl From> for QuadComb { impl From> for LinComb { fn from(l: CanonicalLinComb) -> Self { - LinComb(l.0.into_iter().collect()) + LinComb { + span: l.span, + value: l.value.into_iter().collect(), + } + } +} + +impl CanonicalLinComb { + pub fn new(value: BTreeMap) -> Self { + Self { span: None, value } } } impl LinComb { + pub fn new(value: Vec<(Variable, T)>) -> Self { + Self { span: None, value } + } + pub fn summand>(mult: U, var: Variable) -> LinComb { let res = vec![(var, mult.into())]; - LinComb(res) + LinComb::new(res) } pub fn zero() -> LinComb { - LinComb(Vec::new()) + LinComb::new(Vec::new()) } pub fn is_zero(&self) -> bool { - self.0.is_empty() + self.value.is_empty() + } +} + +impl WithSpan for LinComb { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span } } impl LinComb { pub fn try_constant(self) -> Result { - match self.0.len() { + match self.value.len() { // if the lincomb is empty, it is reduceable to 0 0 => Ok(T::zero()), _ => { // take the first variable in the lincomb - let first = &self.0[0].0; + let first = &self.value[0].0; if first != &Variable::one() { return Err(self); } // all terms must contain the same variable - if self.0.iter().all(|element| element.0 == *first) { - Ok(self.0.into_iter().fold(T::zero(), |acc, e| acc + e.1)) + if self.value.iter().all(|element| element.0 == *first) { + Ok(self.value.into_iter().fold(T::zero(), |acc, e| acc + e.1)) } else { Err(self) } @@ -121,26 +215,26 @@ impl LinComb { } pub fn is_assignee(&self, witness: &Witness) -> bool { - self.0.len() == 1 - && self.0.get(0).unwrap().1 == T::from(1) - && !witness.0.contains_key(&self.0.get(0).unwrap().0) + self.value.len() == 1 + && self.value.get(0).unwrap().1 == T::from(1) + && !witness.0.contains_key(&self.value.get(0).unwrap().0) } pub fn try_summand(self) -> Result<(Variable, T), Self> { - match self.0.len() { + match self.value.len() { // if the lincomb is empty, it is not reduceable to a summand 0 => Err(self), _ => { // take the first variable in the lincomb - let first = &self.0[0].0; + let first = &self.value[0].0; - if self.0.iter().all(|element| + if self.value.iter().all(|element| // all terms must contain the same variable element.0 == *first) { Ok(( *first, - self.0.into_iter().fold(T::zero(), |acc, e| acc + e.1), + self.value.into_iter().fold(T::zero(), |acc, e| acc + e.1), )) } else { Err(self) @@ -156,32 +250,34 @@ impl LinComb { impl LinComb { pub fn into_canonical(self) -> CanonicalLinComb { - CanonicalLinComb( - self.0 - .into_iter() - .fold(BTreeMap::new(), |mut acc, (val, coeff)| { - // if we're adding 0 times some variable, we can ignore this term - if coeff != T::zero() { - match acc.entry(val) { - Entry::Occupied(o) => { - // if the new value is non zero, update, else remove the term entirely - if *o.get() + coeff != T::zero() { - *o.into_mut() = *o.get() + coeff; - } else { - o.remove(); - } - } - Entry::Vacant(v) => { - // We checked earlier but let's make sure we're not creating zero-coeff terms - assert!(coeff != T::zero()); - v.insert(coeff); + let span = self.get_span(); + + CanonicalLinComb::new(self.value.into_iter().fold( + BTreeMap::::new(), + |mut acc, (val, coeff)| { + // if we're adding 0 times some variable, we can ignore this term + if coeff != T::zero() { + match acc.entry(val) { + Entry::Occupied(o) => { + // if the new value is non zero, update, else remove the term entirely + if *o.get() + coeff != T::zero() { + *o.into_mut() = *o.get() + coeff; + } else { + o.remove(); } } + Entry::Vacant(v) => { + // We checked earlier but let's make sure we're not creating zero-coeff terms + assert!(coeff != T::zero()); + v.insert(coeff); + } } + } - acc - }), - ) + acc + }, + )) + .span(span) } pub fn reduce(self) -> Self { @@ -191,10 +287,8 @@ impl LinComb { impl QuadComb { pub fn into_canonical(self) -> CanonicalQuadComb { - CanonicalQuadComb { - left: self.left.into_canonical(), - right: self.right.into_canonical(), - } + CanonicalQuadComb::new(self.left.into_canonical(), self.right.into_canonical()) + .span(self.span) } pub fn reduce(self) -> Self { @@ -209,7 +303,7 @@ impl fmt::Display for LinComb { false => write!( f, "{}", - self.0 + self.value .iter() .map(|(k, v)| format!("{} * {}", v.to_compact_dec_string(), k)) .collect::>() @@ -222,7 +316,7 @@ impl fmt::Display for LinComb { impl From for LinComb { fn from(v: Variable) -> LinComb { let r = vec![(v, T::one())]; - LinComb(r) + LinComb::new(r) } } @@ -230,9 +324,9 @@ impl Add> for LinComb { type Output = LinComb; fn add(self, other: LinComb) -> LinComb { - let mut res = self.0; - res.extend(other.0); - LinComb(res) + let mut res = self.value; + res.extend(other.value); + LinComb::new(res) } } @@ -241,9 +335,14 @@ impl Sub> for LinComb { fn sub(self, other: LinComb) -> LinComb { // Concatenate with second vector that have negative coeffs - let mut res = self.0; - res.extend(other.0.into_iter().map(|(var, val)| (var, T::zero() - val))); - LinComb(res) + let mut res = self.value; + res.extend( + other + .value + .into_iter() + .map(|(var, val)| (var, T::zero() - val)), + ); + LinComb::new(res) } } @@ -255,8 +354,8 @@ impl Mul<&T> for LinComb { return self; } - LinComb( - self.0 + LinComb::new( + self.value .into_iter() .map(|(var, coeff)| (var, coeff * scalar)) .collect(), @@ -299,7 +398,7 @@ mod tests { (Variable::new(42), Bn128Field::from(1)), ]; - assert_eq!(c, LinComb(expected_vec)); + assert_eq!(c, LinComb::new(expected_vec)); } #[test] fn sub() { @@ -312,7 +411,7 @@ mod tests { (Variable::new(42), Bn128Field::from(-1)), ]; - assert_eq!(c, LinComb(expected_vec)); + assert_eq!(c, LinComb::new(expected_vec)); } #[test] @@ -331,35 +430,26 @@ mod tests { fn from_linear() { let a: LinComb = LinComb::summand(3, Variable::new(42)) + LinComb::summand(4, Variable::new(33)); - let expected = QuadComb { - left: LinComb::one(), - right: a.clone(), - }; + let expected = QuadComb::new(LinComb::one(), a.clone()); assert_eq!(QuadComb::from(a), expected); } #[test] fn zero() { let a: LinComb = LinComb::zero(); - let expected: QuadComb = QuadComb { - left: LinComb::one(), - right: LinComb::zero(), - }; + let expected: QuadComb = QuadComb::new(LinComb::one(), LinComb::zero()); assert_eq!(QuadComb::from(a), expected); } #[test] fn display() { - let a: QuadComb = QuadComb { - left: LinComb::summand(3, Variable::new(42)) - + LinComb::summand(4, Variable::new(33)), - right: LinComb::summand(1, Variable::new(21)), - }; + let a: QuadComb = QuadComb::new( + LinComb::summand(3, Variable::new(42)) + LinComb::summand(4, Variable::new(33)), + LinComb::summand(1, Variable::new(21)), + ); assert_eq!(&a.to_string(), "(3 * _42 + 4 * _33) * (1 * _21)"); - let a: QuadComb = QuadComb { - left: LinComb::zero(), - right: LinComb::summand(1, Variable::new(21)), - }; + let a: QuadComb = + QuadComb::new(LinComb::zero(), LinComb::summand(1, Variable::new(21))); assert_eq!(&a.to_string(), "(0) * (1 * _21)"); } } @@ -369,7 +459,7 @@ mod tests { #[test] fn try_summand() { - let summand = LinComb(vec![ + let summand = LinComb::new(vec![ (Variable::new(42), Bn128Field::from(1)), (Variable::new(42), Bn128Field::from(2)), (Variable::new(42), Bn128Field::from(3)), @@ -379,14 +469,14 @@ mod tests { Ok((Variable::new(42), Bn128Field::from(6))) ); - let not_summand = LinComb(vec![ + let not_summand = LinComb::new(vec![ (Variable::new(41), Bn128Field::from(1)), (Variable::new(42), Bn128Field::from(2)), (Variable::new(42), Bn128Field::from(3)), ]); assert!(not_summand.try_summand().is_err()); - let empty: LinComb = LinComb(vec![]); + let empty: LinComb = LinComb::new(vec![]); assert!(empty.try_summand().is_err()); } } diff --git a/zokrates_ast/src/ir/folder.rs b/zokrates_ast/src/ir/folder.rs index 99466addc..3b17c722f 100644 --- a/zokrates_ast/src/ir/folder.rs +++ b/zokrates_ast/src/ir/folder.rs @@ -1,7 +1,7 @@ // Generic walk through an IR AST. Not mutating in place use super::*; -use crate::common::Variable; +use crate::common::{flat::Variable, WithSpan}; use zokrates_field::Field; pub trait Folder<'ast, T: Field>: Sized { @@ -21,6 +21,29 @@ pub trait Folder<'ast, T: Field>: Sized { fold_statement(self, s) } + fn fold_statement_cases(&mut self, s: Statement<'ast, T>) -> Vec> { + fold_statement_cases(self, s) + } + + fn fold_constraint_statement(&mut self, s: ConstraintStatement) -> Vec> { + fold_constraint_statement(self, s) + } + + fn fold_directive_statement( + &mut self, + s: DirectiveStatement<'ast, T>, + ) -> Vec> { + fold_directive_statement(self, s) + } + + fn fold_log_statement(&mut self, s: LogStatement) -> Vec> { + fold_log_statement(self, s) + } + + fn fold_block_statement(&mut self, s: BlockStatement<'ast, T>) -> Vec> { + fold_block_statement(self, s) + } + fn fold_linear_combination(&mut self, e: LinComb) -> LinComb { fold_linear_combination(self, e) } @@ -28,10 +51,6 @@ pub trait Folder<'ast, T: Field>: Sized { fn fold_quadratic_combination(&mut self, es: QuadComb) -> QuadComb { fold_quadratic_combination(self, es) } - - fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> { - fold_directive(self, d) - } } pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( @@ -49,41 +68,73 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( .into_iter() .flat_map(|s| f.fold_statement(s)) .collect(), - return_count: p.return_count, - solvers: p.solvers, + ..p } } -pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( +pub fn fold_constraint_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: ConstraintStatement, +) -> Vec> { + vec![Statement::constraint( + f.fold_quadratic_combination(s.quad), + f.fold_linear_combination(s.lin), + s.error, + )] +} + +pub fn fold_log_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: LogStatement, +) -> Vec> { + vec![Statement::log( + s.format_string, + s.expressions + .into_iter() + .map(|(t, e)| { + ( + t, + e.into_iter() + .map(|e| f.fold_linear_combination(e)) + .collect(), + ) + }) + .collect(), + )] +} + +pub fn fold_block_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: BlockStatement<'ast, T>, +) -> Vec> { + vec![Statement::block( + s.inner + .into_iter() + .flat_map(|s| f.fold_statement(s)) + .collect(), + )] +} + +fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: Statement<'ast, T>, +) -> Vec> { + let span = s.get_span(); + f.fold_statement_cases(s) + .into_iter() + .map(|s| s.span(span)) + .collect() +} + +pub fn fold_statement_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: Statement<'ast, T>, ) -> Vec> { match s { - Statement::Block(statements) => vec![Statement::Block( - statements - .into_iter() - .flat_map(|s| f.fold_statement(s)) - .collect(), - )], - Statement::Constraint(quad, lin, message) => vec![Statement::Constraint( - f.fold_quadratic_combination(quad), - f.fold_linear_combination(lin), - message, - )], - Statement::Directive(dir) => vec![Statement::Directive(f.fold_directive(dir))], - Statement::Log(l, e) => vec![Statement::Log( - l, - e.into_iter() - .map(|(t, e)| { - ( - t, - e.into_iter() - .map(|e| f.fold_linear_combination(e)) - .collect(), - ) - }) - .collect(), - )], + Statement::Constraint(s) => f.fold_constraint_statement(s), + Statement::Directive(s) => f.fold_directive_statement(s), + Statement::Log(s) => f.fold_log_statement(s), + Statement::Block(s) => f.fold_block_statement(s), } } @@ -91,28 +142,31 @@ pub fn fold_linear_combination<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: LinComb, ) -> LinComb { - LinComb( - e.0.into_iter() + LinComb::new( + e.value + .into_iter() .map(|(variable, coefficient)| (f.fold_variable(variable), coefficient)) .collect(), ) + .span(e.span) } pub fn fold_quadratic_combination<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: QuadComb, ) -> QuadComb { - QuadComb { - left: f.fold_linear_combination(e.left), - right: f.fold_linear_combination(e.right), - } + QuadComb::new( + f.fold_linear_combination(e.left), + f.fold_linear_combination(e.right), + ) + .span(e.span) } -pub fn fold_directive<'ast, T: Field, F: Folder<'ast, T>>( +pub fn fold_directive_statement<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, - ds: Directive<'ast, T>, -) -> Directive<'ast, T> { - Directive { + ds: DirectiveStatement<'ast, T>, +) -> Vec> { + vec![Statement::Directive(DirectiveStatement { inputs: ds .inputs .into_iter() @@ -120,13 +174,13 @@ pub fn fold_directive<'ast, T: Field, F: Folder<'ast, T>>( .collect(), outputs: ds.outputs.into_iter().map(|o| f.fold_variable(o)).collect(), ..ds - } + })] } pub fn fold_argument<'ast, T: Field, F: Folder<'ast, T>>(f: &mut F, a: Parameter) -> Parameter { Parameter { id: f.fold_variable(a.id), - private: a.private, + ..a } } diff --git a/zokrates_ast/src/ir/from_flat.rs b/zokrates_ast/src/ir/from_flat.rs index 3404c6c49..f7f8c42f8 100644 --- a/zokrates_ast/src/ir/from_flat.rs +++ b/zokrates_ast/src/ir/from_flat.rs @@ -1,5 +1,8 @@ +use crate::common::statements::LogStatement; +use crate::common::WithSpan; use crate::flat::{FlatDirective, FlatExpression, FlatProgIterator, FlatStatement, Variable}; -use crate::ir::{Directive, LinComb, ProgIterator, QuadComb, Statement}; +use crate::ir::{DirectiveStatement, LinComb, ProgIterator, QuadComb, Statement}; +use std::ops::*; use zokrates_field::Field; impl QuadComb { @@ -8,9 +11,7 @@ impl QuadComb { match flat_expression.is_linear() { true => LinComb::from(flat_expression).into(), false => match flat_expression { - FlatExpression::Mult(box e1, box e2) => { - QuadComb::from_linear_combinations(e1.into(), e2.into()) - } + FlatExpression::Mult(e) => QuadComb::new((*e.left).into(), (*e.right).into()), e => unimplemented!("{}", e), }, } @@ -24,72 +25,76 @@ pub fn from_flat<'ast, T: Field, I: IntoIterator>> statements: flat_prog_iterator.statements.into_iter().map(Into::into), arguments: flat_prog_iterator.arguments, return_count: flat_prog_iterator.return_count, + module_map: flat_prog_iterator.module_map, solvers: vec![], } } impl From> for LinComb { fn from(flat_expression: FlatExpression) -> LinComb { + let span = flat_expression.get_span(); + match flat_expression { - FlatExpression::Number(ref n) if *n == T::from(0) => LinComb::zero(), - FlatExpression::Number(n) => LinComb::summand(n, Variable::one()), - FlatExpression::Identifier(id) => LinComb::from(id), - FlatExpression::Add(box e1, box e2) => LinComb::from(e1) + LinComb::from(e2), - FlatExpression::Sub(box e1, box e2) => LinComb::from(e1) - LinComb::from(e2), - FlatExpression::Mult( - box FlatExpression::Number(n1), - box FlatExpression::Identifier(v1), - ) - | FlatExpression::Mult( - box FlatExpression::Identifier(v1), - box FlatExpression::Number(n1), - ) => LinComb::summand(n1, v1), - FlatExpression::Mult( - box FlatExpression::Number(n1), - box FlatExpression::Number(n2), - ) => LinComb::summand(n1 * n2, Variable::one()), - e => unreachable!("{}", e), + FlatExpression::Value(ref n) if n.value == T::from(0) => LinComb::zero(), + FlatExpression::Value(n) => LinComb::summand(n.value, Variable::one()), + FlatExpression::Identifier(id) => LinComb::from(id.id), + FlatExpression::Add(e) => LinComb::from(*e.left) + LinComb::from(*e.right), + FlatExpression::Sub(e) => LinComb::from(*e.left) - LinComb::from(*e.right), + FlatExpression::Mult(e) => match (*e.left, *e.right) { + (FlatExpression::Value(n1), FlatExpression::Identifier(v1)) + | (FlatExpression::Identifier(v1), FlatExpression::Value(n1)) => { + LinComb::summand(n1.value, v1.id) + } + (FlatExpression::Value(n1), FlatExpression::Value(n2)) => { + LinComb::summand(n1.value * n2.value, Variable::one()) + } + (left, right) => unreachable!("{}", FlatExpression::mul(left, right).span(e.span)), + }, } + .span(span) } } impl<'ast, T: Field> From> for Statement<'ast, T> { fn from(flat_statement: FlatStatement<'ast, T>) -> Statement<'ast, T> { + let span = flat_statement.get_span(); + match flat_statement { - FlatStatement::Block(statements) => { - Statement::Block(statements.into_iter().map(Statement::from).collect()) - } - FlatStatement::Condition(linear, quadratic, message) => match quadratic { - FlatExpression::Mult(box lhs, box rhs) => Statement::Constraint( - QuadComb::from_linear_combinations(lhs.into(), rhs.into()), - linear.into(), - Some(message), + FlatStatement::Condition(s) => match s.quad { + FlatExpression::Mult(e) => Statement::constraint( + QuadComb::new((*e.left).into(), (*e.right).into()).span(e.span), + LinComb::from(s.lin), + Some(s.error), ), - e => Statement::Constraint(LinComb::from(e).into(), linear.into(), Some(message)), + e => Statement::constraint(LinComb::from(e), s.lin, Some(s.error)), }, - FlatStatement::Definition(var, quadratic) => match quadratic { - FlatExpression::Mult(box lhs, box rhs) => Statement::Constraint( - QuadComb::from_linear_combinations(lhs.into(), rhs.into()), - var.into(), + FlatStatement::Block(statements) => { + Statement::block(statements.inner.into_iter().map(Statement::from).collect()) + } + FlatStatement::Definition(s) => match s.rhs { + FlatExpression::Mult(e) => Statement::constraint( + QuadComb::new((*e.left).into(), (*e.right).into()).span(e.span), + s.assignee, None, ), - e => Statement::Constraint(LinComb::from(e).into(), var.into(), None), + e => Statement::constraint(LinComb::from(e), s.assignee, None), }, FlatStatement::Directive(ds) => Statement::Directive(ds.into()), - FlatStatement::Log(l, expressions) => Statement::Log( - l, - expressions + FlatStatement::Log(s) => Statement::Log(LogStatement::new( + s.format_string, + s.expressions .into_iter() .map(|(t, e)| (t, e.into_iter().map(LinComb::from).collect())) .collect(), - ), + )), } + .span(span) } } -impl<'ast, T: Field> From> for Directive<'ast, T> { - fn from(ds: FlatDirective<'ast, T>) -> Directive { - Directive { +impl<'ast, T: Field> From> for DirectiveStatement<'ast, T> { + fn from(ds: FlatDirective) -> DirectiveStatement { + DirectiveStatement { inputs: ds .inputs .into_iter() @@ -97,6 +102,7 @@ impl<'ast, T: Field> From> for Directive<'ast, T> { .collect(), solver: ds.solver, outputs: ds.outputs, + span: ds.span, } } } @@ -109,7 +115,7 @@ mod tests { #[test] fn zero() { // 0 - let zero = FlatExpression::Number(Bn128Field::from(0)); + let zero = FlatExpression::value(Bn128Field::from(0)); let expected: LinComb = LinComb::zero(); assert_eq!(LinComb::from(zero), expected); } @@ -117,7 +123,7 @@ mod tests { #[test] fn one() { // 1 - let one = FlatExpression::Number(Bn128Field::from(1)); + let one = FlatExpression::value(Bn128Field::from(1)); let expected: LinComb = Variable::one().into(); assert_eq!(LinComb::from(one), expected); } @@ -125,7 +131,7 @@ mod tests { #[test] fn forty_two() { // 42 - let one = FlatExpression::Number(Bn128Field::from(42)); + let one = FlatExpression::value(Bn128Field::from(42)); let expected: LinComb = LinComb::summand(42, Variable::one()); assert_eq!(LinComb::from(one), expected); } @@ -133,9 +139,9 @@ mod tests { #[test] fn add() { // x + y - let add = FlatExpression::Add( - box FlatExpression::Identifier(Variable::new(42)), - box FlatExpression::Identifier(Variable::new(21)), + let add = FlatExpression::add( + FlatExpression::identifier(Variable::new(42)), + FlatExpression::identifier(Variable::new(21)), ); let expected: LinComb = LinComb::summand(1, Variable::new(42)) + LinComb::summand(1, Variable::new(21)); @@ -145,14 +151,14 @@ mod tests { #[test] fn linear_combination() { // 42*x + 21*y - let add = FlatExpression::Add( - box FlatExpression::Mult( - box FlatExpression::Number(Bn128Field::from(42)), - box FlatExpression::Identifier(Variable::new(42)), + let add = FlatExpression::add( + FlatExpression::mul( + FlatExpression::value(Bn128Field::from(42)), + FlatExpression::identifier(Variable::new(42)), ), - box FlatExpression::Mult( - box FlatExpression::Number(Bn128Field::from(21)), - box FlatExpression::Identifier(Variable::new(21)), + FlatExpression::mul( + FlatExpression::value(Bn128Field::from(21)), + FlatExpression::identifier(Variable::new(21)), ), ); let expected: LinComb = @@ -163,14 +169,14 @@ mod tests { #[test] fn linear_combination_inverted() { // x*42 + y*21 - let add = FlatExpression::Add( - box FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(42)), - box FlatExpression::Number(Bn128Field::from(42)), + let add = FlatExpression::add( + FlatExpression::mul( + FlatExpression::identifier(Variable::new(42)), + FlatExpression::value(Bn128Field::from(42)), ), - box FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(21)), - box FlatExpression::Number(Bn128Field::from(21)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(21)), + FlatExpression::value(Bn128Field::from(21)), ), ); let expected: LinComb = diff --git a/zokrates_ast/src/ir/mod.rs b/zokrates_ast/src/ir/mod.rs index 7de47396e..2adbd802c 100644 --- a/zokrates_ast/src/ir/mod.rs +++ b/zokrates_ast/src/ir/mod.rs @@ -1,10 +1,9 @@ -use crate::common::FormatString; +use crate::common::{FormatString, ModuleMap, Span, WithSpan}; use crate::typed::ConcreteType; use derivative::Derivative; use serde::{Deserialize, Serialize}; use std::collections::BTreeSet; use std::fmt; -use std::hash::Hash; use zokrates_field::Field; mod check; @@ -21,94 +20,178 @@ mod witness; pub use self::expression::QuadComb; pub use self::expression::{CanonicalLinComb, LinComb}; pub use self::serialize::{ProgEnum, ProgHeader}; -pub use crate::common::Parameter; +pub use crate::common::flat::Parameter; +pub use crate::common::flat::Variable; pub use crate::common::RuntimeError; pub use crate::common::Solver; -pub use crate::common::Variable; pub use self::witness::Witness; +pub type LogStatement = crate::common::statements::LogStatement<(ConcreteType, Vec>)>; +pub type DirectiveStatement<'ast, T> = + crate::common::statements::DirectiveStatement, Variable, Solver<'ast, T>>; + +#[derive(Derivative, Clone, Debug, Serialize, Deserialize)] +#[derivative(Hash, PartialEq, Eq)] +pub struct ConstraintStatement { + #[derivative(Hash = "ignore", PartialEq = "ignore")] + pub span: Option, + pub quad: QuadComb, + pub lin: LinComb, + #[derivative(Hash = "ignore", PartialEq = "ignore")] + pub error: Option, +} + +impl ConstraintStatement { + pub fn new(quad: QuadComb, lin: LinComb, error: Option) -> Self { + Self { + span: None, + quad, + lin, + error, + } + } +} + +impl WithSpan for ConstraintStatement { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl fmt::Display for ConstraintStatement { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{} == {}{}", + self.quad, + self.lin, + self.error + .as_ref() + .map(|e| format!(" // {}", e)) + .unwrap_or_else(|| "".to_string()) + ) + } +} + +#[derive(Derivative, Clone, Debug, Serialize, Deserialize)] +#[derivative(Hash, PartialEq, Eq)] +pub struct BlockStatement<'ast, T> { + #[derivative(Hash = "ignore", PartialEq = "ignore")] + pub span: Option, + #[serde(borrow)] + pub inner: Vec>, +} + +impl<'ast, T> BlockStatement<'ast, T> { + pub fn new(inner: Vec>) -> Self { + Self { span: None, inner } + } +} + +impl<'ast, T> WithSpan for BlockStatement<'ast, T> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl<'ast, T: Field> fmt::Display for BlockStatement<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "{{")?; + for s in &self.inner { + writeln!(f, "{}", s)?; + } + write!(f, "}}") + } +} + #[derive(Debug, Serialize, Deserialize, Clone, Derivative)] #[derivative(Hash, PartialEq, Eq)] pub enum Statement<'ast, T> { #[serde(skip)] - Block(Vec>), - Constraint( - QuadComb, - LinComb, - #[derivative(Hash = "ignore")] Option, - ), + Block(BlockStatement<'ast, T>), + Constraint(ConstraintStatement), #[serde(borrow)] - Directive(Directive<'ast, T>), - Log(FormatString, Vec<(ConcreteType, Vec>)>), + Directive(DirectiveStatement<'ast, T>), + Log(LogStatement), } pub type PublicInputs = BTreeSet; +impl<'ast, T> WithSpan for Statement<'ast, T> { + fn span(self, span: Option) -> Self { + match self { + Statement::Constraint(c) => Statement::Constraint(c.span(span)), + Statement::Directive(c) => Statement::Directive(c.span(span)), + Statement::Log(c) => Statement::Log(c.span(span)), + Statement::Block(c) => Statement::Block(c.span(span)), + } + } + + fn get_span(&self) -> Option { + match self { + Statement::Constraint(c) => c.get_span(), + Statement::Directive(c) => c.get_span(), + Statement::Log(c) => c.get_span(), + Statement::Block(c) => c.get_span(), + } + } +} + impl<'ast, T: Field> Statement<'ast, T> { pub fn definition>>(v: Variable, e: U) -> Self { - Statement::Constraint(e.into(), v.into(), None) + Statement::constraint(e, v, None) } - pub fn constraint>, V: Into>>(quad: U, lin: V) -> Self { - Statement::Constraint(quad.into(), lin.into(), None) + pub fn constraint>, V: Into>>( + quad: U, + lin: V, + error: Option, + ) -> Self { + Statement::Constraint(ConstraintStatement::new(quad.into(), lin.into(), error)) } -} -#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)] -pub struct Directive<'ast, T> { - pub inputs: Vec>, - pub outputs: Vec, - #[serde(borrow)] - pub solver: Solver<'ast, T>, -} + pub fn log( + format_string: FormatString, + expressions: Vec<(ConcreteType, Vec>)>, + ) -> Self { + Statement::Log(LogStatement::new(format_string, expressions)) + } -impl<'ast, T: Field> fmt::Display for Directive<'ast, T> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "# {} = {}({})", - self.outputs - .iter() - .map(|o| format!("{}", o)) - .collect::>() - .join(", "), - self.solver, - self.inputs - .iter() - .map(|i| format!("{}", i)) - .collect::>() - .join(", ") - ) + pub fn block(inner: Vec>) -> Self { + Statement::Block(BlockStatement::new(inner)) + } + + pub fn directive( + outputs: Vec, + solver: Solver<'ast, T>, + inputs: Vec>, + ) -> Self { + Statement::Directive(DirectiveStatement::new(outputs, solver, inputs)) } } impl<'ast, T: Field> fmt::Display for Statement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Statement::Block(ref statements) => { - writeln!(f, "{{")?; - for s in statements { - writeln!(f, "{}", s)?; - } - write!(f, "}}") - } - Statement::Constraint(ref quad, ref lin, ref error) => write!( - f, - "{} == {}{}", - quad, - lin, - error - .as_ref() - .map(|e| format!(" // {}", e)) - .unwrap_or_else(|| "".to_string()) - ), + Statement::Constraint(ref s) => write!(f, "{}", s), + Statement::Block(ref s) => write!(f, "{}", s), Statement::Directive(ref s) => write!(f, "{}", s), - Statement::Log(ref s, ref expressions) => write!( + Statement::Log(ref s) => write!( f, "log(\"{}\", {})", - s, - expressions + s.format_string, + s.expressions .iter() .map(|(_, l)| format!( "[{}]", @@ -128,6 +211,7 @@ pub type Prog<'ast, T> = ProgIterator<'ast, T, Vec>>; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] pub struct ProgIterator<'ast, T, I: IntoIterator>> { + pub module_map: ModuleMap, pub arguments: Vec, pub return_count: usize, pub statements: I, @@ -140,12 +224,14 @@ impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, arguments: Vec, statements: I, return_count: usize, + module_map: ModuleMap, solvers: Vec>, ) -> Self { Self { arguments, return_count, statements, + module_map, solvers, } } @@ -155,6 +241,7 @@ impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, statements: self.statements.into_iter().collect::>(), arguments: self.arguments, return_count: self.return_count, + module_map: self.module_map, solvers: self.solvers, } } @@ -202,6 +289,7 @@ impl<'ast, T> Prog<'ast, T> { statements: self.statements.into_iter(), arguments: self.arguments, return_count: self.return_count, + module_map: self.module_map, solvers: self.solvers, } } @@ -244,12 +332,9 @@ mod tests { #[test] fn print_constraint() { - let c: Statement = Statement::Constraint( - QuadComb::from_linear_combinations( - Variable::new(42).into(), - Variable::new(42).into(), - ), - Variable::new(42).into(), + let c: Statement = Statement::constraint( + QuadComb::new(Variable::new(42).into(), Variable::new(42).into()), + Variable::new(42), None, ); assert_eq!(format!("{}", c), "(1 * _42) * (1 * _42) == 1 * _42") diff --git a/zokrates_ast/src/ir/serialize.rs b/zokrates_ast/src/ir/serialize.rs index f32624299..02b2d420a 100644 --- a/zokrates_ast/src/ir/serialize.rs +++ b/zokrates_ast/src/ir/serialize.rs @@ -1,6 +1,7 @@ use crate::ir::{check::UnconstrainedVariableDetector, solver_indexer::SolverIndexer}; use super::{ProgIterator, Statement}; +use crate::ir::ModuleMap; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use serde::Deserialize; use serde_cbor::{self, StreamDeserializer}; @@ -66,6 +67,7 @@ pub enum SectionType { Parameters = 1, Constraints = 2, Solvers = 3, + Modules = 4, } impl TryFrom for SectionType { @@ -76,6 +78,7 @@ impl TryFrom for SectionType { 1 => Ok(SectionType::Parameters), 2 => Ok(SectionType::Constraints), 3 => Ok(SectionType::Solvers), + 4 => Ok(SectionType::Modules), _ => Err("invalid section type".to_string()), } } @@ -113,7 +116,7 @@ pub struct ProgHeader { pub curve_id: [u8; 4], pub constraint_count: u32, pub return_count: u32, - pub sections: [Section; 3], + pub sections: [Section; 4], } impl ProgHeader { @@ -149,6 +152,7 @@ impl ProgHeader { let parameters = Self::read_section(r.by_ref())?; let constraints = Self::read_section(r.by_ref())?; let solvers = Self::read_section(r.by_ref())?; + let module_map = Self::read_section(r.by_ref())?; Ok(ProgHeader { magic, @@ -156,7 +160,7 @@ impl ProgHeader { curve_id, constraint_count, return_count, - sections: [parameters, constraints, solvers], + sections: [parameters, constraints, solvers, module_map], }) } @@ -231,13 +235,24 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a section }; + // write module map section + let module_map = { + let mut section = Section::new(SectionType::Solvers); + section.set_offset(w.stream_position()?); + + serde_cbor::to_writer(&mut w, &self.module_map)?; + + section.set_length(w.stream_position()? - section.offset); + section + }; + let header = ProgHeader { magic: *ZOKRATES_MAGIC, version: *FILE_VERSION, curve_id: T::id(), constraint_count: count as u32, return_count: self.return_count as u32, - sections: [parameters, constraints, solvers], + sections: [parameters, constraints, solvers, module_map], }; // rewind to write the header @@ -302,6 +317,16 @@ impl<'de, R: Read + Seek> .unwrap() }; + let module_map = { + let section = &header.sections[3]; + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + ModuleMap::deserialize(&mut p) + .map_err(|_| String::from("Cannot read module map")) + .unwrap() + }; + let statements_deserializer = { let section = &header.sections[1]; r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); @@ -316,6 +341,7 @@ impl<'de, R: Read + Seek> parameters, statements_deserializer, header.return_count as usize, + module_map, solvers, ) } diff --git a/zokrates_ast/src/ir/smtlib2.rs b/zokrates_ast/src/ir/smtlib2.rs index bc1188518..4d43d0d28 100644 --- a/zokrates_ast/src/ir/smtlib2.rs +++ b/zokrates_ast/src/ir/smtlib2.rs @@ -79,11 +79,11 @@ impl<'ast, T: Field> SMTLib2 for Statement<'ast, T> { fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { Statement::Block(..) => unreachable!(), - Statement::Constraint(ref quad, ref lin, _) => { + Statement::Constraint(ref s) => { write!(f, "(= (mod ")?; - quad.to_smtlib2(f)?; + s.quad.to_smtlib2(f)?; write!(f, " |~prime|) (mod ")?; - lin.to_smtlib2(f)?; + s.lin.to_smtlib2(f)?; write!(f, " |~prime|))") } Statement::Directive(ref s) => s.to_smtlib2(f), @@ -92,7 +92,7 @@ impl<'ast, T: Field> SMTLib2 for Statement<'ast, T> { } } -impl<'ast, T: Field> SMTLib2 for Directive<'ast, T> { +impl<'ast, T: Field> SMTLib2 for DirectiveStatement<'ast, T> { fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "") } @@ -109,15 +109,20 @@ impl SMTLib2 for LinComb { match self.is_zero() { true => write!(f, "0"), false => { - if self.0.len() > 1 { + if self.value.len() > 1 { write!(f, "(+")?; - for expr in self.0.iter() { + for expr in self.value.iter() { write!(f, " ")?; format_prefix_op_smtlib2(f, "*", &expr.0, &expr.1.to_biguint())?; } write!(f, ")") } else { - format_prefix_op_smtlib2(f, "*", &self.0[0].0, &self.0[0].1.to_biguint()) + format_prefix_op_smtlib2( + f, + "*", + &self.value[0].0, + &self.value[0].1.to_biguint(), + ) } } } diff --git a/zokrates_ast/src/ir/solver_indexer.rs b/zokrates_ast/src/ir/solver_indexer.rs index cbe1a1511..334da750b 100644 --- a/zokrates_ast/src/ir/solver_indexer.rs +++ b/zokrates_ast/src/ir/solver_indexer.rs @@ -1,6 +1,5 @@ use crate::common::RefCall; use crate::ir::folder::Folder; -use crate::ir::Directive; use crate::ir::Solver; use crate::zir::ZirFunction; use std::collections::hash_map::DefaultHasher; @@ -8,6 +7,9 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; use zokrates_field::Field; +use super::DirectiveStatement; +use super::Statement; + type Hash = u64; fn hash(f: &ZirFunction) -> Hash { @@ -25,8 +27,11 @@ pub struct SolverIndexer<'ast, T> { } impl<'ast, T: Field> Folder<'ast, T> for SolverIndexer<'ast, T> { - fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> { - match d.solver { + fn fold_directive_statement( + &mut self, + d: DirectiveStatement<'ast, T>, + ) -> Vec> { + let res = match d.solver { Solver::Zir(f) => { let argument_count = f.arguments.len(); let h = hash(&f); @@ -39,16 +44,17 @@ impl<'ast, T: Field> Folder<'ast, T> for SolverIndexer<'ast, T> { index } }; - Directive { - inputs: d.inputs, - outputs: d.outputs, - solver: Solver::Ref(RefCall { + DirectiveStatement::new( + d.outputs, + Solver::Ref(RefCall { index, argument_count, }), - } + d.inputs, + ) } _ => d, - } + }; + vec![Statement::Directive(res)] } } diff --git a/zokrates_ast/src/ir/visitor.rs b/zokrates_ast/src/ir/visitor.rs index d3894ca6b..175c48bfe 100644 --- a/zokrates_ast/src/ir/visitor.rs +++ b/zokrates_ast/src/ir/visitor.rs @@ -1,7 +1,7 @@ // Generic walk through an IR AST. Not mutating in place use super::*; -use crate::common::Variable; +use crate::common::flat::Variable; use zokrates_field::Field; pub trait Visitor: Sized { @@ -33,8 +33,8 @@ pub trait Visitor: Sized { visit_quadratic_combination(self, es) } - fn visit_directive(&mut self, d: &Directive) { - visit_directive(self, d) + fn visit_directive_statement(&mut self, d: &DirectiveStatement) { + visit_directive_statement(self, d) } fn visit_runtime_error(&mut self, e: &RuntimeError) { @@ -53,21 +53,21 @@ pub fn visit_module>(f: &mut F, p: &Prog) { pub fn visit_statement>(f: &mut F, s: &Statement) { match s { - Statement::Block(statements) => { - for s in statements { + Statement::Block(s) => { + for s in &s.inner { f.visit_statement(s); } } - Statement::Constraint(quad, lin, error) => { - f.visit_quadratic_combination(quad); - f.visit_linear_combination(lin); - if let Some(error) = error.as_ref() { + Statement::Constraint(s) => { + f.visit_quadratic_combination(&s.quad); + f.visit_linear_combination(&s.lin); + if let Some(error) = s.error.as_ref() { f.visit_runtime_error(error); } } - Statement::Directive(dir) => f.visit_directive(dir), - Statement::Log(_, expressions) => { - for (_, e) in expressions { + Statement::Directive(dir) => f.visit_directive_statement(dir), + Statement::Log(s) => { + for (_, e) in &s.expressions { for e in e { f.visit_linear_combination(e); } @@ -77,7 +77,7 @@ pub fn visit_statement>(f: &mut F, s: &Statement) { } pub fn visit_linear_combination>(f: &mut F, e: &LinComb) { - for expr in e.0.iter() { + for expr in e.value.iter() { f.visit_variable(&expr.0); f.visit_value(&expr.1); } @@ -88,7 +88,7 @@ pub fn visit_quadratic_combination>(f: &mut F, e: &QuadC f.visit_linear_combination(&e.right); } -pub fn visit_directive>(f: &mut F, ds: &Directive) { +pub fn visit_directive_statement>(f: &mut F, ds: &DirectiveStatement) { for expr in ds.inputs.iter() { f.visit_quadratic_combination(expr); } diff --git a/zokrates_ast/src/ir/witness.rs b/zokrates_ast/src/ir/witness.rs index 70da540c1..b8117279f 100644 --- a/zokrates_ast/src/ir/witness.rs +++ b/zokrates_ast/src/ir/witness.rs @@ -1,4 +1,4 @@ -use crate::common::Variable; +use crate::common::flat::Variable; use std::collections::{BTreeMap, HashMap}; use std::fmt; use std::io; diff --git a/zokrates_ast/src/typed/abi.rs b/zokrates_ast/src/typed/abi.rs index 253d27e28..7c9d995fa 100644 --- a/zokrates_ast/src/typed/abi.rs +++ b/zokrates_ast/src/typed/abi.rs @@ -49,14 +49,14 @@ mod tests { ConcreteFunctionKey::with_location("main", "main").into(), TypedFunctionSymbol::Here(TypedFunction { arguments: vec![ - DeclarationParameter { - id: DeclarationVariable::new("a", DeclarationType::FieldElement, true), - private: true, - }, - DeclarationParameter { - id: DeclarationVariable::new("b", DeclarationType::Boolean, false), - private: false, - }, + DeclarationParameter::private(DeclarationVariable::new( + "a", + DeclarationType::FieldElement, + )), + DeclarationParameter::public(DeclarationVariable::new( + "b", + DeclarationType::Boolean, + )), ], statements: vec![], signature: ConcreteSignature::new() @@ -73,6 +73,7 @@ mod tests { let typed_ast: TypedProgram = TypedProgram { main: "main".into(), modules, + module_map: Default::default(), }; let abi: Abi = typed_ast.abi(); diff --git a/zokrates_ast/src/typed/folder.rs b/zokrates_ast/src/typed/folder.rs index 722dcbf87..bc05d026f 100644 --- a/zokrates_ast/src/typed/folder.rs +++ b/zokrates_ast/src/typed/folder.rs @@ -1,47 +1,47 @@ +use crate::common::expressions::{ + BooleanValueExpression, EqExpression, FieldValueExpression, IntValueExpression, + UnaryOrExpression, ValueOrExpression, +}; // Generic walk through a typed AST. Not mutating in place - -use crate::typed::types::*; +use crate::common::{expressions::BinaryOrExpression, Fold}; use crate::typed::*; use zokrates_field::Field; use super::identifier::FrameIdentifier; +use super::types::{DeclarationStructMember, DeclarationTupleType, StructMember}; -pub trait Fold<'ast, T: Field>: Sized { - fn fold>(self, f: &mut F) -> Self; -} - -impl<'ast, T: Field> Fold<'ast, T> for FieldElementExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for FieldElementExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_field_expression(self) } } -impl<'ast, T: Field> Fold<'ast, T> for BooleanExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for BooleanExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_boolean_expression(self) } } -impl<'ast, T: Field> Fold<'ast, T> for UExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for UExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_uint_expression(self) } } -impl<'ast, T: Field> Fold<'ast, T> for StructExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for StructExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_struct_expression(self) } } -impl<'ast, T: Field> Fold<'ast, T> for ArrayExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for ArrayExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_array_expression(self) } } -impl<'ast, T: Field> Fold<'ast, T> for TupleExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for TupleExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_tuple_expression(self) } } @@ -142,22 +142,16 @@ pub trait Folder<'ast, T: Field>: Sized { } fn fold_variable(&mut self, v: Variable<'ast, T>) -> Variable<'ast, T> { - Variable { - id: self.fold_name(v.id), - _type: self.fold_type(v._type), - is_mutable: v.is_mutable, - } + let span = v.get_span(); + Variable::new(self.fold_name(v.id), self.fold_type(v.ty)).span(span) } fn fold_declaration_variable( &mut self, v: DeclarationVariable<'ast, T>, ) -> DeclarationVariable<'ast, T> { - DeclarationVariable { - id: self.fold_name(v.id), - _type: self.fold_declaration_type(v._type), - is_mutable: v.is_mutable, - } + let span = v.get_span(); + DeclarationVariable::new(self.fold_name(v.id), self.fold_declaration_type(v.ty)).span(span) } fn fold_type(&mut self, t: Type<'ast, T>) -> Type<'ast, T> { @@ -263,6 +257,27 @@ pub trait Folder<'ast, T: Field>: Sized { fold_assignee(self, a) } + fn fold_assembly_block( + &mut self, + s: AssemblyBlockStatement<'ast, T>, + ) -> Vec> { + fold_assembly_block(self, s) + } + + fn fold_assembly_assignment( + &mut self, + s: AssemblyAssignment<'ast, T>, + ) -> Vec> { + fold_assembly_assignment(self, s) + } + + fn fold_assembly_constraint( + &mut self, + s: AssemblyConstraint<'ast, T>, + ) -> Vec> { + fold_assembly_constraint(self, s) + } + fn fold_assembly_statement( &mut self, s: TypedAssemblyStatement<'ast, T>, @@ -270,10 +285,50 @@ pub trait Folder<'ast, T: Field>: Sized { fold_assembly_statement(self, s) } + fn fold_assembly_statement_cases( + &mut self, + s: TypedAssemblyStatement<'ast, T>, + ) -> Vec> { + fold_assembly_statement_cases(self, s) + } + fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { fold_statement(self, s) } + fn fold_statement_cases(&mut self, s: TypedStatement<'ast, T>) -> Vec> { + fold_statement_cases(self, s) + } + + fn fold_definition_statement( + &mut self, + s: DefinitionStatement<'ast, T>, + ) -> Vec> { + fold_definition_statement(self, s) + } + + fn fold_return_statement( + &mut self, + s: ReturnStatement<'ast, T>, + ) -> Vec> { + fold_return_statement(self, s) + } + + fn fold_assertion_statement( + &mut self, + s: AssertionStatement<'ast, T>, + ) -> Vec> { + fold_assertion_statement(self, s) + } + + fn fold_log_statement(&mut self, s: LogStatement<'ast, T>) -> Vec> { + fold_log_statement(self, s) + } + + fn fold_for_statement(&mut self, s: ForStatement<'ast, T>) -> Vec> { + fold_for_statement(self, s) + } + fn fold_definition_rhs(&mut self, rhs: DefinitionRhs<'ast, T>) -> DefinitionRhs<'ast, T> { fold_definition_rhs(self, rhs) } @@ -312,7 +367,7 @@ pub trait Folder<'ast, T: Field>: Sized { } } - fn fold_module_id(&mut self, i: OwnedTypedModuleId) -> OwnedTypedModuleId { + fn fold_module_id(&mut self, i: OwnedModuleId) -> OwnedModuleId { i } @@ -320,7 +375,7 @@ pub trait Folder<'ast, T: Field>: Sized { fold_expression(self, e) } - fn fold_block_expression>( + fn fold_block_expression>( &mut self, block: BlockExpression<'ast, T, E>, ) -> BlockExpression<'ast, T, E> { @@ -329,7 +384,7 @@ pub trait Folder<'ast, T: Field>: Sized { fn fold_conditional_expression< E: Expr<'ast, T> - + Fold<'ast, T> + + Fold + Block<'ast, T> + Conditional<'ast, T> + From>, @@ -341,6 +396,35 @@ pub trait Folder<'ast, T: Field>: Sized { fold_conditional_expression(self, ty, e) } + fn fold_binary_expression< + L: Expr<'ast, T> + Fold, + R: Expr<'ast, T> + Fold, + E: Expr<'ast, T> + Fold, + Op, + >( + &mut self, + ty: &E::Ty, + e: BinaryExpression, + ) -> BinaryOrExpression { + fold_binary_expression(self, ty, e) + } + + fn fold_eq_expression + Fold>( + &mut self, + e: EqExpression>, + ) -> BinaryOrExpression, BooleanExpression<'ast, T>> + { + fold_binary_expression(self, &Type::Boolean, e) + } + + fn fold_unary_expression + Fold, E: Expr<'ast, T> + Fold, Op>( + &mut self, + ty: &E::Ty, + e: UnaryExpression, + ) -> UnaryOrExpression { + fold_unary_expression(self, ty, e) + } + fn fold_member_expression< E: Expr<'ast, T> + Member<'ast, T> + From>, >( @@ -351,6 +435,17 @@ pub trait Folder<'ast, T: Field>: Sized { fold_member_expression(self, ty, e) } + fn fold_slice_expression(&mut self, e: SliceExpression<'ast, T>) -> SliceOrExpression<'ast, T> { + fold_slice_expression(self, e) + } + + fn fold_repeat_expression( + &mut self, + e: RepeatExpression<'ast, T>, + ) -> RepeatOrExpression<'ast, T> { + fold_repeat_expression(self, e) + } + fn fold_identifier_expression< E: Expr<'ast, T> + Id<'ast, T> + From>, >( @@ -371,13 +466,6 @@ pub trait Folder<'ast, T: Field>: Sized { fold_element_expression(self, ty, e) } - fn fold_eq_expression + PartialEq + Constant + Fold<'ast, T>>( - &mut self, - e: EqExpression, - ) -> EqOrBoolean<'ast, T, E> { - fold_eq_expression(self, e) - } - fn fold_function_call_expression< E: Id<'ast, T> + From> + Expr<'ast, T> + FunctionCall<'ast, T>, >( @@ -417,6 +505,13 @@ pub trait Folder<'ast, T: Field>: Sized { fold_int_expression(self, e) } + fn fold_field_expression_cases( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> FieldElementExpression<'ast, T> { + fold_field_expression_cases(self, e) + } + fn fold_field_expression( &mut self, e: FieldElementExpression<'ast, T>, @@ -424,6 +519,13 @@ pub trait Folder<'ast, T: Field>: Sized { fold_field_expression(self, e) } + fn fold_boolean_expression_cases( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> BooleanExpression<'ast, T> { + fold_boolean_expression_cases(self, e) + } + fn fold_boolean_expression( &mut self, e: BooleanExpression<'ast, T>, @@ -435,6 +537,14 @@ pub trait Folder<'ast, T: Field>: Sized { fold_uint_expression(self, e) } + fn fold_uint_expression_cases( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> UExpressionInner<'ast, T> { + fold_uint_expression_cases(self, bitwidth, e) + } + fn fold_uint_expression_inner( &mut self, bitwidth: UBitwidth, @@ -443,6 +553,14 @@ pub trait Folder<'ast, T: Field>: Sized { fold_uint_expression_inner(self, bitwidth, e) } + fn fold_array_expression_cases( + &mut self, + ty: &ArrayType<'ast, T>, + e: ArrayExpressionInner<'ast, T>, + ) -> ArrayExpressionInner<'ast, T> { + fold_array_expression_cases(self, ty, e) + } + fn fold_array_expression_inner( &mut self, ty: &ArrayType<'ast, T>, @@ -451,6 +569,14 @@ pub trait Folder<'ast, T: Field>: Sized { fold_array_expression_inner(self, ty, e) } + fn fold_tuple_expression_cases( + &mut self, + ty: &TupleType<'ast, T>, + e: TupleExpressionInner<'ast, T>, + ) -> TupleExpressionInner<'ast, T> { + fold_tuple_expression_cases(self, ty, e) + } + fn fold_tuple_expression_inner( &mut self, ty: &TupleType<'ast, T>, @@ -459,6 +585,14 @@ pub trait Folder<'ast, T: Field>: Sized { fold_tuple_expression_inner(self, ty, e) } + fn fold_struct_expression_cases( + &mut self, + ty: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, + ) -> StructExpressionInner<'ast, T> { + fold_struct_expression_cases(self, ty, e) + } + fn fold_struct_expression_inner( &mut self, ty: &StructType<'ast, T>, @@ -466,6 +600,51 @@ pub trait Folder<'ast, T: Field>: Sized { ) -> StructExpressionInner<'ast, T> { fold_struct_expression_inner(self, ty, e) } + + fn fold_field_value_expression( + &mut self, + v: FieldValueExpression, + ) -> ValueOrExpression, FieldElementExpression<'ast, T>> { + fold_field_value_expression(self, v) + } + + fn fold_boolean_value_expression( + &mut self, + v: BooleanValueExpression, + ) -> ValueOrExpression> { + fold_boolean_value_expression(self, v) + } + + fn fold_integer_value_expression( + &mut self, + v: IntValueExpression, + ) -> ValueOrExpression> { + fold_integer_value_expression(self, v) + } + + fn fold_struct_value_expression( + &mut self, + ty: &StructType<'ast, T>, + v: StructValueExpression<'ast, T>, + ) -> ValueOrExpression, StructExpressionInner<'ast, T>> { + fold_struct_value_expression(self, ty, v) + } + + fn fold_array_value_expression( + &mut self, + ty: &ArrayType<'ast, T>, + v: ArrayValueExpression<'ast, T>, + ) -> ValueOrExpression, ArrayExpressionInner<'ast, T>> { + fold_array_value_expression(self, ty, v) + } + + fn fold_tuple_value_expression( + &mut self, + ty: &TupleType<'ast, T>, + v: TupleValueExpression<'ast, T>, + ) -> ValueOrExpression, TupleExpressionInner<'ast, T>> { + fold_tuple_value_expression(self, ty, v) + } } pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>( @@ -525,58 +704,136 @@ pub fn fold_definition_rhs<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>( +pub fn fold_return_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: ReturnStatement<'ast, T>, +) -> Vec> { + vec![TypedStatement::Return(ReturnStatement::new( + f.fold_expression(s.inner), + ))] +} + +pub fn fold_definition_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: DefinitionStatement<'ast, T>, +) -> Vec> { + let rhs = f.fold_definition_rhs(s.rhs); + vec![TypedStatement::Definition( + DefinitionStatement::new(f.fold_assignee(s.assignee), rhs).span(s.span), + )] +} + +pub fn fold_assertion_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: AssertionStatement<'ast, T>, +) -> Vec> { + vec![TypedStatement::Assertion( + AssertionStatement::new(f.fold_boolean_expression(s.expression), s.error).span(s.span), + )] +} + +pub fn fold_for_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: ForStatement<'ast, T>, +) -> Vec> { + vec![TypedStatement::For(ForStatement::new( + f.fold_variable(s.var), + f.fold_uint_expression(s.from), + f.fold_uint_expression(s.to), + s.statements + .into_iter() + .flat_map(|s| f.fold_statement(s)) + .collect(), + ))] +} + +pub fn fold_log_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: LogStatement<'ast, T>, +) -> Vec> { + vec![TypedStatement::Log(LogStatement::new( + s.format_string, + s.expressions + .into_iter() + .map(|e| f.fold_expression(e)) + .collect(), + ))] +} + +pub fn fold_statement_cases<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: TypedStatement<'ast, T>, +) -> Vec> { + match s { + TypedStatement::Return(s) => f.fold_return_statement(s), + TypedStatement::Definition(s) => f.fold_definition_statement(s), + TypedStatement::Assertion(s) => f.fold_assertion_statement(s), + TypedStatement::For(s) => f.fold_for_statement(s), + TypedStatement::Log(s) => f.fold_log_statement(s), + TypedStatement::Assembly(s) => f.fold_assembly_block(s), + } +} + +pub fn fold_assembly_block<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: AssemblyBlockStatement<'ast, T>, +) -> Vec> { + vec![TypedStatement::Assembly(AssemblyBlockStatement::new( + s.inner + .into_iter() + .flat_map(|s| f.fold_assembly_statement(s)) + .collect(), + ))] +} + +pub fn fold_assembly_assignment<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: AssemblyAssignment<'ast, T>, +) -> Vec> { + let assignee = f.fold_assignee(s.assignee); + let expression = f.fold_expression(s.expression); + vec![TypedAssemblyStatement::assignment(assignee, expression)] +} + +pub fn fold_assembly_constraint<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: AssemblyConstraint<'ast, T>, +) -> Vec> { + let left = f.fold_field_expression(s.left); + let right = f.fold_field_expression(s.right); + vec![TypedAssemblyStatement::constraint(left, right, s.metadata)] +} + +fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: TypedAssemblyStatement<'ast, T>, +) -> Vec> { + let span = s.get_span(); + f.fold_assembly_statement_cases(s) + .into_iter() + .map(|s| s.span(span)) + .collect() +} + +pub fn fold_assembly_statement_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: TypedAssemblyStatement<'ast, T>, ) -> Vec> { match s { - TypedAssemblyStatement::Assignment(a, e) => { - let e = f.fold_expression(e); - vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a), e)] - } - TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { - vec![TypedAssemblyStatement::Constraint( - f.fold_field_expression(lhs), - f.fold_field_expression(rhs), - metadata, - )] - } + TypedAssemblyStatement::Assignment(s) => f.fold_assembly_assignment(s), + TypedAssemblyStatement::Constraint(s) => f.fold_assembly_constraint(s), } } -pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: TypedStatement<'ast, T>, ) -> Vec> { - let res = match s { - TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)), - TypedStatement::Definition(a, rhs) => { - let rhs = f.fold_definition_rhs(rhs); - TypedStatement::Definition(f.fold_assignee(a), rhs) - } - TypedStatement::Assertion(e, error) => { - TypedStatement::Assertion(f.fold_boolean_expression(e), error) - } - TypedStatement::For(v, from, to, statements) => TypedStatement::For( - f.fold_variable(v), - f.fold_uint_expression(from), - f.fold_uint_expression(to), - statements - .into_iter() - .flat_map(|s| f.fold_statement(s)) - .collect(), - ), - TypedStatement::Log(s, e) => { - TypedStatement::Log(s, e.into_iter().map(|e| f.fold_expression(e)).collect()) - } - TypedStatement::Assembly(statements) => TypedStatement::Assembly( - statements - .into_iter() - .flat_map(|s| f.fold_assembly_statement(s)) - .collect(), - ), - }; - vec![res] + let span = s.get_span(); + f.fold_statement_cases(s) + .into_iter() + .map(|s| s.span(span)) + .collect() } pub fn fold_identifier_expression< @@ -621,7 +878,16 @@ pub fn fold_expression<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + ty: &ArrayType<'ast, T>, + e: ArrayExpressionInner<'ast, T>, +) -> ArrayExpressionInner<'ast, T> { + let span = e.get_span(); + f.fold_array_expression_cases(ty, e).span(span) +} + +pub fn fold_array_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, ty: &ArrayType<'ast, T>, e: ArrayExpressionInner<'ast, T>, @@ -634,16 +900,14 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( IdentifierOrExpression::Expression(u) => u, }, Block(block) => Block(f.fold_block_expression(block)), - Value(exprs) => Value( - exprs - .into_iter() - .map(|e| f.fold_expression_or_spread(e)) - .collect(), - ), FunctionCall(function_call) => match f.fold_function_call_expression(ty, function_call) { FunctionCallOrExpression::FunctionCall(function_call) => FunctionCall(function_call), FunctionCallOrExpression::Expression(u) => u, }, + Value(function_call) => match f.fold_array_value_expression(ty, function_call) { + ValueOrExpression::Value(value) => Value(value), + ValueOrExpression::Expression(e) => e, + }, Conditional(c) => match f.fold_conditional_expression(ty, c) { ConditionalOrExpression::Conditional(s) => Conditional(s), ConditionalOrExpression::Expression(u) => u, @@ -656,17 +920,14 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( MemberOrExpression::Member(m) => Member(m), MemberOrExpression::Expression(u) => u, }, - Slice(box array, box from, box to) => { - let array = f.fold_array_expression(array); - let from = f.fold_uint_expression(from); - let to = f.fold_uint_expression(to); - Slice(box array, box from, box to) - } - Repeat(box e, box count) => { - let e = f.fold_expression(e); - let count = f.fold_uint_expression(count); - Repeat(box e, box count) - } + Slice(s) => match f.fold_slice_expression(s) { + SliceOrExpression::Slice(m) => Slice(m), + SliceOrExpression::Expression(u) => u, + }, + Repeat(s) => match f.fold_repeat_expression(s) { + RepeatOrExpression::Repeat(m) => Repeat(m), + RepeatOrExpression::Expression(u) => u, + }, Element(m) => match f.fold_element_expression(ty, m) { ElementOrExpression::Element(m) => Element(m), ElementOrExpression::Expression(u) => u, @@ -674,7 +935,16 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + ty: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, +) -> StructExpressionInner<'ast, T> { + let span = e.get_span(); + f.fold_struct_expression_cases(ty, e).span(span) +} + +pub fn fold_struct_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, ty: &StructType<'ast, T>, e: StructExpressionInner<'ast, T>, @@ -687,7 +957,10 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( IdentifierOrExpression::Expression(u) => u, }, Block(block) => Block(f.fold_block_expression(block)), - Value(exprs) => Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect()), + Value(function_call) => match f.fold_struct_value_expression(ty, function_call) { + ValueOrExpression::Value(value) => Value(value), + ValueOrExpression::Expression(e) => e, + }, FunctionCall(function_call) => match f.fold_function_call_expression(ty, function_call) { FunctionCallOrExpression::FunctionCall(function_call) => FunctionCall(function_call), FunctionCallOrExpression::Expression(u) => u, @@ -711,7 +984,16 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_tuple_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_tuple_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + ty: &TupleType<'ast, T>, + e: TupleExpressionInner<'ast, T>, +) -> TupleExpressionInner<'ast, T> { + let span = e.get_span(); + f.fold_tuple_expression_cases(ty, e).span(span) +} + +pub fn fold_tuple_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, ty: &TupleType<'ast, T>, e: TupleExpressionInner<'ast, T>, @@ -724,7 +1006,10 @@ pub fn fold_tuple_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( IdentifierOrExpression::Identifier(i) => Identifier(i), IdentifierOrExpression::Expression(u) => u, }, - Value(exprs) => Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect()), + Value(function_call) => match f.fold_tuple_value_expression(ty, function_call) { + ValueOrExpression::Value(value) => Value(value), + ValueOrExpression::Expression(e) => e, + }, FunctionCall(function_call) => match f.fold_function_call_expression(ty, function_call) { FunctionCallOrExpression::FunctionCall(function_call) => FunctionCall(function_call), FunctionCallOrExpression::Expression(u) => u, @@ -748,7 +1033,15 @@ pub fn fold_tuple_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + e: FieldElementExpression<'ast, T>, +) -> FieldElementExpression<'ast, T> { + let span = e.get_span(); + f.fold_field_expression_cases(e).span(span) +} + +pub fn fold_field_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: FieldElementExpression<'ast, T>, ) -> FieldElementExpression<'ast, T> { @@ -760,72 +1053,58 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( IdentifierOrExpression::Expression(u) => u, }, Block(block) => Block(f.fold_block_expression(block)), - Number(n) => Number(n), - Add(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - Add(box e1, box e2) - } - Sub(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - Sub(box e1, box e2) - } - Mult(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - Mult(box e1, box e2) - } - Div(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - Div(box e1, box e2) - } - Pow(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_uint_expression(e2); - Pow(box e1, box e2) - } - Neg(box e) => { - let e = f.fold_field_expression(e); - - Neg(box e) - } - Pos(box e) => { - let e = f.fold_field_expression(e); - - Pos(box e) - } - And(box left, box right) => { - let left = f.fold_field_expression(left); - let right = f.fold_field_expression(right); - - And(box left, box right) - } - Or(box left, box right) => { - let left = f.fold_field_expression(left); - let right = f.fold_field_expression(right); - - Or(box left, box right) - } - Xor(box left, box right) => { - let left = f.fold_field_expression(left); - let right = f.fold_field_expression(right); - - Xor(box left, box right) - } - LeftShift(box e, box by) => { - let e = f.fold_field_expression(e); - let by = f.fold_uint_expression(by); - - LeftShift(box e, box by) - } - RightShift(box e, box by) => { - let e = f.fold_field_expression(e); - let by = f.fold_uint_expression(by); - - RightShift(box e, box by) - } + Value(value) => match f.fold_field_value_expression(value) { + ValueOrExpression::Value(value) => Value(value), + ValueOrExpression::Expression(e) => e, + }, + Add(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => Add(e), + BinaryOrExpression::Expression(u) => u, + }, + Sub(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => Sub(e), + BinaryOrExpression::Expression(u) => u, + }, + Mult(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => Mult(e), + BinaryOrExpression::Expression(u) => u, + }, + Div(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => Div(e), + BinaryOrExpression::Expression(u) => u, + }, + Pow(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => Pow(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + Xor(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => Xor(e), + BinaryOrExpression::Expression(u) => u, + }, + LeftShift(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => LeftShift(e), + BinaryOrExpression::Expression(u) => u, + }, + RightShift(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => RightShift(e), + BinaryOrExpression::Expression(u) => u, + }, + Neg(e) => match f.fold_unary_expression(&Type::FieldElement, e) { + UnaryOrExpression::Unary(e) => Neg(e), + UnaryOrExpression::Expression(u) => u, + }, + Pos(e) => match f.fold_unary_expression(&Type::FieldElement, e) { + UnaryOrExpression::Unary(e) => Pos(e), + UnaryOrExpression::Expression(u) => u, + }, Conditional(c) => match f.fold_conditional_expression(&Type::FieldElement, c) { ConditionalOrExpression::Conditional(s) => Conditional(s), ConditionalOrExpression::Expression(u) => u, @@ -856,19 +1135,53 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( pub fn fold_conditional_expression< 'ast, T: Field, - E: Expr<'ast, T> + Fold<'ast, T> + Conditional<'ast, T> + From>, + E: Expr<'ast, T> + Fold + Conditional<'ast, T> + From>, F: Folder<'ast, T>, >( f: &mut F, _: &E::Ty, e: ConditionalExpression<'ast, T, E>, ) -> ConditionalOrExpression<'ast, T, E> { - ConditionalOrExpression::Conditional(ConditionalExpression::new( - f.fold_boolean_expression(*e.condition), - e.consequence.fold(f), - e.alternative.fold(f), - e.kind, - )) + ConditionalOrExpression::Conditional( + ConditionalExpression::new( + f.fold_boolean_expression(*e.condition), + e.consequence.fold(f), + e.alternative.fold(f), + e.kind, + ) + .span(e.span), + ) +} + +pub fn fold_binary_expression< + 'ast, + T: Field, + L: Expr<'ast, T> + Fold + From>, + R: Expr<'ast, T> + Fold + From>, + E: Expr<'ast, T> + Fold + From>, + F: Folder<'ast, T>, + Op, +>( + f: &mut F, + _: &E::Ty, + e: BinaryExpression, +) -> BinaryOrExpression { + BinaryOrExpression::Binary(BinaryExpression::new(e.left.fold(f), e.right.fold(f))) +} + +pub fn fold_unary_expression< + 'ast, + T: Field, + In: Expr<'ast, T> + Fold + From>, + E: Expr<'ast, T> + Fold + From>, + F: Folder<'ast, T>, + Op, +>( + f: &mut F, + _: &E::Ty, + e: UnaryExpression, +) -> UnaryOrExpression { + UnaryOrExpression::Unary(UnaryExpression::new(e.inner.fold(f))) } pub fn fold_member_expression< @@ -887,6 +1200,27 @@ pub fn fold_member_expression< )) } +pub fn fold_slice_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + e: SliceExpression<'ast, T>, +) -> SliceOrExpression<'ast, T> { + SliceOrExpression::Slice(SliceExpression::new( + f.fold_array_expression(*e.array), + f.fold_uint_expression(*e.from), + f.fold_uint_expression(*e.to), + )) +} + +pub fn fold_repeat_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + e: RepeatExpression<'ast, T>, +) -> RepeatOrExpression<'ast, T> { + RepeatOrExpression::Repeat(RepeatExpression::new( + f.fold_expression(*e.e), + f.fold_uint_expression(*e.count), + )) +} + pub fn fold_element_expression< 'ast, T: Field, @@ -897,17 +1231,9 @@ pub fn fold_element_expression< _: &E::Ty, e: ElementExpression<'ast, T, E>, ) -> ElementOrExpression<'ast, T, E> { - ElementOrExpression::Element(ElementExpression::new( - f.fold_tuple_expression(*e.tuple), - e.index, - )) -} - -pub fn fold_eq_expression<'ast, T: Field, E: Fold<'ast, T>, F: Folder<'ast, T>>( - f: &mut F, - e: EqExpression, -) -> EqOrBoolean<'ast, T, E> { - EqOrBoolean::Eq(EqExpression::new(e.left.fold(f), e.right.fold(f))) + ElementOrExpression::Element( + ElementExpression::new(f.fold_tuple_expression(*e.tuple), e.index).span(e.span), + ) } pub fn fold_select_expression< @@ -920,10 +1246,13 @@ pub fn fold_select_expression< _: &E::Ty, e: SelectExpression<'ast, T, E>, ) -> SelectOrExpression<'ast, T, E> { - SelectOrExpression::Select(SelectExpression::new( - f.fold_array_expression(*e.array), - f.fold_uint_expression(*e.index), - )) + SelectOrExpression::Select( + SelectExpression::new( + f.fold_array_expression(*e.array), + f.fold_uint_expression(*e.index), + ) + .span(e.span), + ) } pub fn fold_int_expression<'ast, T: Field, F: Folder<'ast, T>>( @@ -933,7 +1262,15 @@ pub fn fold_int_expression<'ast, T: Field, F: Folder<'ast, T>>( unreachable!() } -pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + e: BooleanExpression<'ast, T>, +) -> BooleanExpression<'ast, T> { + let span = e.get_span(); + f.fold_boolean_expression_cases(e).span(span) +} + +pub fn fold_boolean_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: BooleanExpression<'ast, T>, ) -> BooleanExpression<'ast, T> { @@ -945,85 +1282,62 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( IdentifierOrExpression::Expression(u) => u, }, Block(block) => BooleanExpression::Block(f.fold_block_expression(block)), - Value(v) => BooleanExpression::Value(v), + Value(e) => match f.fold_boolean_value_expression(e) { + ValueOrExpression::Value(e) => Value(e), + ValueOrExpression::Expression(u) => u, + }, FieldEq(e) => match f.fold_eq_expression(e) { - EqOrBoolean::Eq(e) => BooleanExpression::FieldEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => FieldEq(e), + BinaryOrExpression::Expression(u) => u, }, BoolEq(e) => match f.fold_eq_expression(e) { - EqOrBoolean::Eq(e) => BooleanExpression::BoolEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => BoolEq(e), + BinaryOrExpression::Expression(u) => u, }, ArrayEq(e) => match f.fold_eq_expression(e) { - EqOrBoolean::Eq(e) => BooleanExpression::ArrayEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => ArrayEq(e), + BinaryOrExpression::Expression(u) => u, }, StructEq(e) => match f.fold_eq_expression(e) { - EqOrBoolean::Eq(e) => BooleanExpression::StructEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => StructEq(e), + BinaryOrExpression::Expression(u) => u, }, TupleEq(e) => match f.fold_eq_expression(e) { - EqOrBoolean::Eq(e) => BooleanExpression::TupleEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => TupleEq(e), + BinaryOrExpression::Expression(u) => u, }, UintEq(e) => match f.fold_eq_expression(e) { - EqOrBoolean::Eq(e) => BooleanExpression::UintEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => UintEq(e), + BinaryOrExpression::Expression(u) => u, + }, + FieldLt(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => FieldLt(e), + BinaryOrExpression::Expression(u) => u, + }, + FieldLe(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => FieldLe(e), + BinaryOrExpression::Expression(u) => u, + }, + UintLt(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => UintLt(e), + BinaryOrExpression::Expression(u) => u, + }, + UintLe(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => UintLe(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Not(e) => match f.fold_unary_expression(&Type::Boolean, e) { + UnaryOrExpression::Unary(e) => Not(e), + UnaryOrExpression::Expression(u) => u, }, - FieldLt(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - FieldLt(box e1, box e2) - } - FieldLe(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - FieldLe(box e1, box e2) - } - FieldGt(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - FieldGt(box e1, box e2) - } - FieldGe(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - FieldGe(box e1, box e2) - } - UintLt(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1); - let e2 = f.fold_uint_expression(e2); - UintLt(box e1, box e2) - } - UintLe(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1); - let e2 = f.fold_uint_expression(e2); - UintLe(box e1, box e2) - } - UintGt(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1); - let e2 = f.fold_uint_expression(e2); - UintGt(box e1, box e2) - } - UintGe(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1); - let e2 = f.fold_uint_expression(e2); - UintGe(box e1, box e2) - } - Or(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1); - let e2 = f.fold_boolean_expression(e2); - Or(box e1, box e2) - } - And(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1); - let e2 = f.fold_boolean_expression(e2); - And(box e1, box e2) - } - Not(box e) => { - let e = f.fold_boolean_expression(e); - Not(box e) - } FunctionCall(function_call) => { match f.fold_function_call_expression(&Type::Boolean, function_call) { FunctionCallOrExpression::FunctionCall(function_call) => { @@ -1061,7 +1375,16 @@ pub fn fold_uint_expression<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + ty: UBitwidth, + e: UExpressionInner<'ast, T>, +) -> UExpressionInner<'ast, T> { + let span = e.get_span(); + f.fold_uint_expression_cases(ty, e).span(span) +} + +pub fn fold_uint_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, ty: UBitwidth, e: UExpressionInner<'ast, T>, @@ -1075,87 +1398,62 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( }, Block(block) => Block(f.fold_block_expression(block)), Value(v) => Value(v), - Add(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - Add(box left, box right) - } - Sub(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - Sub(box left, box right) - } - FloorSub(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - FloorSub(box left, box right) - } - Mult(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - Mult(box left, box right) - } - Div(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - Div(box left, box right) - } - Rem(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - Rem(box left, box right) - } - Xor(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - Xor(box left, box right) - } - And(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - And(box left, box right) - } - Or(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - Or(box left, box right) - } - LeftShift(box e, box by) => { - let e = f.fold_uint_expression(e); - let by = f.fold_uint_expression(by); - - LeftShift(box e, box by) - } - RightShift(box e, box by) => { - let e = f.fold_uint_expression(e); - let by = f.fold_uint_expression(by); - - RightShift(box e, box by) - } - Not(box e) => { - let e = f.fold_uint_expression(e); - - Not(box e) - } - Neg(box e) => { - let e = f.fold_uint_expression(e); - - Neg(box e) - } - Pos(box e) => { - let e = f.fold_uint_expression(e); - - Pos(box e) - } + Add(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Add(e), + BinaryOrExpression::Expression(u) => u, + }, + Sub(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Sub(e), + BinaryOrExpression::Expression(u) => u, + }, + FloorSub(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => FloorSub(e), + BinaryOrExpression::Expression(u) => u, + }, + Mult(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Mult(e), + BinaryOrExpression::Expression(u) => u, + }, + Div(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Div(e), + BinaryOrExpression::Expression(u) => u, + }, + Rem(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Rem(e), + BinaryOrExpression::Expression(u) => u, + }, + Xor(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Xor(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + LeftShift(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => LeftShift(e), + BinaryOrExpression::Expression(u) => u, + }, + RightShift(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => RightShift(e), + BinaryOrExpression::Expression(u) => u, + }, + Not(e) => match f.fold_unary_expression(&ty, e) { + UnaryOrExpression::Unary(e) => Not(e), + UnaryOrExpression::Expression(u) => u, + }, + Neg(e) => match f.fold_unary_expression(&ty, e) { + UnaryOrExpression::Unary(e) => Neg(e), + UnaryOrExpression::Expression(u) => u, + }, + Pos(e) => match f.fold_unary_expression(&ty, e) { + UnaryOrExpression::Unary(e) => Pos(e), + UnaryOrExpression::Expression(u) => u, + }, FunctionCall(function_call) => match f.fold_function_call_expression(&ty, function_call) { FunctionCallOrExpression::FunctionCall(function_call) => FunctionCall(function_call), FunctionCallOrExpression::Expression(u) => u, @@ -1179,18 +1477,19 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_block_expression<'ast, T: Field, E: Fold<'ast, T>, F: Folder<'ast, T>>( +pub fn fold_block_expression<'ast, T: Field, E: Fold, F: Folder<'ast, T>>( f: &mut F, block: BlockExpression<'ast, T, E>, ) -> BlockExpression<'ast, T, E> { - BlockExpression { - statements: block + BlockExpression::new( + block .statements .into_iter() .flat_map(|s| f.fold_statement(s)) .collect(), - value: box block.value.fold(f), - } + block.value.fold(f), + ) + .span(block.span) } pub fn fold_declaration_function_key<'ast, T: Field, F: Folder<'ast, T>>( @@ -1214,17 +1513,20 @@ pub fn fold_function_call_expression< _: &E::Ty, e: FunctionCallExpression<'ast, T, E>, ) -> FunctionCallOrExpression<'ast, T, E> { - FunctionCallOrExpression::FunctionCall(FunctionCallExpression::new( - f.fold_declaration_function_key(e.function_key), - e.generics - .into_iter() - .map(|g| g.map(|g| f.fold_uint_expression(g))) - .collect(), - e.arguments - .into_iter() - .map(|e| f.fold_expression(e)) - .collect(), - )) + FunctionCallOrExpression::FunctionCall( + FunctionCallExpression::new( + f.fold_declaration_function_key(e.function_key), + e.generics + .into_iter() + .map(|g| g.map(|g| f.fold_uint_expression(g))) + .collect(), + e.arguments + .into_iter() + .map(|e| f.fold_expression(e)) + .collect(), + ) + .span(e.span), + ) } pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>( @@ -1308,6 +1610,64 @@ pub fn fold_tuple_expression<'ast, T: Field, F: Folder<'ast, T>>( } } +pub fn fold_integer_value_expression<'ast, T: Field, F: Folder<'ast, T>>( + _f: &mut F, + a: IntValueExpression, +) -> ValueOrExpression> { + ValueOrExpression::Value(a) +} + +pub fn fold_boolean_value_expression<'ast, T: Field, F: Folder<'ast, T>>( + _f: &mut F, + a: BooleanValueExpression, +) -> ValueOrExpression> { + ValueOrExpression::Value(a) +} + +pub fn fold_field_value_expression<'ast, T: Field, F: Folder<'ast, T>>( + _f: &mut F, + a: FieldValueExpression, +) -> ValueOrExpression, FieldElementExpression<'ast, T>> { + ValueOrExpression::Value(a) +} + +pub fn fold_struct_value_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + _ty: &StructType<'ast, T>, + a: StructValueExpression<'ast, T>, +) -> ValueOrExpression, StructExpressionInner<'ast, T>> { + ValueOrExpression::Value(StructValueExpression { + value: a.value.into_iter().map(|v| f.fold_expression(v)).collect(), + ..a + }) +} + +pub fn fold_array_value_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + _ty: &ArrayType<'ast, T>, + a: ArrayValueExpression<'ast, T>, +) -> ValueOrExpression, ArrayExpressionInner<'ast, T>> { + ValueOrExpression::Value(ArrayValueExpression { + value: a + .value + .into_iter() + .map(|v| f.fold_expression_or_spread(v)) + .collect(), + ..a + }) +} + +pub fn fold_tuple_value_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + _ty: &TupleType<'ast, T>, + a: TupleValueExpression<'ast, T>, +) -> ValueOrExpression, TupleExpressionInner<'ast, T>> { + ValueOrExpression::Value(TupleValueExpression { + value: a.value.into_iter().map(|v| f.fold_expression(v)).collect(), + ..a + }) +} + pub fn fold_constant<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, c: TypedConstant<'ast, T>, @@ -1370,5 +1730,6 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( .map(|(module_id, module)| (f.fold_module_id(module_id), f.fold_module(module))) .collect(), main: f.fold_module_id(p.main), + ..p } } diff --git a/zokrates_ast/src/typed/integer.rs b/zokrates_ast/src/typed/integer.rs index 507ce9174..8f8b11915 100644 --- a/zokrates_ast/src/typed/integer.rs +++ b/zokrates_ast/src/typed/integer.rs @@ -8,14 +8,21 @@ use crate::typed::{ ArrayExpression, ArrayExpressionInner, BooleanExpression, Conditional, ConditionalExpression, Expr, FieldElementExpression, Select, SelectExpression, StructExpression, StructExpressionInner, TupleExpression, TupleExpressionInner, Typed, TypedExpression, - TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner, + TypedExpressionOrSpread, TypedSpread, UExpression, }; + +use crate::common::{operators::*, WithSpan}; + use num_bigint::BigUint; use std::convert::TryFrom; use std::fmt; -use std::ops::{Add, Div, Mul, Neg, Not, Rem, Sub}; +use std::ops::*; use zokrates_field::Field; +use crate::common::expressions::*; + +use super::{ArrayValueExpression, RepeatExpression}; + type TypedExpressionPair<'ast, T> = (TypedExpression<'ast, T>, TypedExpression<'ast, T>); impl<'ast, T: Field> TypedExpressionOrSpread<'ast, T> { @@ -274,30 +281,107 @@ impl<'ast, T: Field> TypedExpression<'ast, T> { #[derive(Clone, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)] pub enum IntExpression<'ast, T> { - Value(BigUint), - Pos(Box>), - Neg(Box>), - Add(Box>, Box>), - Sub(Box>, Box>), - Mult(Box>, Box>), - Div(Box>, Box>), - Rem(Box>, Box>), - Pow(Box>, Box>), + Value(IntValueExpression), + Pos(UnaryExpression, IntExpression<'ast, T>>), + Neg(UnaryExpression, IntExpression<'ast, T>>), + Not(UnaryExpression, IntExpression<'ast, T>>), + Add( + BinaryExpression< + OpAdd, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + Sub( + BinaryExpression< + OpSub, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + Mult( + BinaryExpression< + OpMul, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + Div( + BinaryExpression< + OpDiv, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + Rem( + BinaryExpression< + OpRem, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + Pow( + BinaryExpression< + OpPow, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + Xor( + BinaryExpression< + OpXor, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + And( + BinaryExpression< + OpAnd, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + Or( + BinaryExpression< + OpOr, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + LeftShift( + BinaryExpression< + OpLsh, + IntExpression<'ast, T>, + UExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + RightShift( + BinaryExpression< + OpRsh, + IntExpression<'ast, T>, + UExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), Conditional(ConditionalExpression<'ast, T, IntExpression<'ast, T>>), Select(SelectExpression<'ast, T, IntExpression<'ast, T>>), - Xor(Box>, Box>), - And(Box>, Box>), - Or(Box>, Box>), - Not(Box>), - LeftShift(Box>, Box>), - RightShift(Box>, Box>), } impl<'ast, T> Add for IntExpression<'ast, T> { type Output = Self; fn add(self, other: Self) -> Self { - IntExpression::Add(box self, box other) + IntExpression::Add(BinaryExpression::new(self, other)) } } @@ -305,7 +389,7 @@ impl<'ast, T> Sub for IntExpression<'ast, T> { type Output = Self; fn sub(self, other: Self) -> Self { - IntExpression::Sub(box self, box other) + IntExpression::Sub(BinaryExpression::new(self, other)) } } @@ -313,7 +397,7 @@ impl<'ast, T> Mul for IntExpression<'ast, T> { type Output = Self; fn mul(self, other: Self) -> Self { - IntExpression::Mult(box self, box other) + IntExpression::Mult(BinaryExpression::new(self, other)) } } @@ -321,7 +405,7 @@ impl<'ast, T> Div for IntExpression<'ast, T> { type Output = Self; fn div(self, other: Self) -> Self { - IntExpression::Div(box self, box other) + IntExpression::Div(BinaryExpression::new(self, other)) } } @@ -329,7 +413,7 @@ impl<'ast, T> Rem for IntExpression<'ast, T> { type Output = Self; fn rem(self, other: Self) -> Self { - IntExpression::Rem(box self, box other) + IntExpression::Rem(BinaryExpression::new(self, other)) } } @@ -337,7 +421,7 @@ impl<'ast, T> Not for IntExpression<'ast, T> { type Output = Self; fn not(self) -> Self { - IntExpression::Not(box self) + IntExpression::Not(UnaryExpression::new(self)) } } @@ -345,37 +429,37 @@ impl<'ast, T> Neg for IntExpression<'ast, T> { type Output = Self; fn neg(self) -> Self { - IntExpression::Neg(box self) + IntExpression::Neg(UnaryExpression::new(self)) } } impl<'ast, T> IntExpression<'ast, T> { pub fn pow(self, other: Self) -> Self { - IntExpression::Pow(box self, box other) + IntExpression::Pow(BinaryExpression::new(self, other)) } pub fn and(self, other: Self) -> Self { - IntExpression::And(box self, box other) + IntExpression::And(BinaryExpression::new(self, other)) } pub fn xor(self, other: Self) -> Self { - IntExpression::Xor(box self, box other) + IntExpression::Xor(BinaryExpression::new(self, other)) } pub fn or(self, other: Self) -> Self { - IntExpression::Or(box self, box other) + IntExpression::Or(BinaryExpression::new(self, other)) } pub fn left_shift(self, by: UExpression<'ast, T>) -> Self { - IntExpression::LeftShift(box self, box by) + IntExpression::LeftShift(BinaryExpression::new(self, by)) } pub fn right_shift(self, by: UExpression<'ast, T>) -> Self { - IntExpression::RightShift(box self, box by) + IntExpression::RightShift(BinaryExpression::new(self, by)) } pub fn pos(self) -> Self { - IntExpression::Pos(box self) + IntExpression::Pos(UnaryExpression::new(self)) } } @@ -383,21 +467,21 @@ impl<'ast, T: fmt::Display> fmt::Display for IntExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { IntExpression::Value(ref v) => write!(f, "{}", v), - IntExpression::Pos(ref e) => write!(f, "(+{})", e), - IntExpression::Neg(ref e) => write!(f, "(-{})", e), - IntExpression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs), - IntExpression::Rem(ref lhs, ref rhs) => write!(f, "({} % {})", lhs, rhs), - IntExpression::Pow(ref lhs, ref rhs) => write!(f, "({} ** {})", lhs, rhs), + IntExpression::Pos(ref e) => write!(f, "{}", e), + IntExpression::Neg(ref e) => write!(f, "{}", e), + IntExpression::Div(ref e) => write!(f, "{}", e), + IntExpression::Rem(ref e) => write!(f, "{}", e), + IntExpression::Pow(ref e) => write!(f, "{}", e), IntExpression::Select(ref select) => write!(f, "{}", select), - IntExpression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), - IntExpression::And(ref lhs, ref rhs) => write!(f, "({} & {})", lhs, rhs), - IntExpression::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs), - IntExpression::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs), - IntExpression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs), - IntExpression::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), - IntExpression::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by), - IntExpression::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by), - IntExpression::Not(ref e) => write!(f, "!{}", e), + IntExpression::Add(ref e) => write!(f, "{}", e), + IntExpression::And(ref e) => write!(f, "{}", e), + IntExpression::Or(ref e) => write!(f, "{}", e), + IntExpression::Xor(ref e) => write!(f, "{}", e), + IntExpression::Sub(ref e) => write!(f, "{}", e), + IntExpression::Mult(ref e) => write!(f, "{}", e), + IntExpression::RightShift(ref e) => write!(f, "{}", e), + IntExpression::LeftShift(ref e) => write!(f, "{}", e), + IntExpression::Not(ref e) => write!(f, "{}", e), IntExpression::Conditional(ref c) => write!(f, "{}", c), } } @@ -424,48 +508,52 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> { } pub fn try_from_int(i: IntExpression<'ast, T>) -> Result> { + let span = i.get_span(); + match i { - IntExpression::Value(i) => Ok(Self::Number(T::try_from(i.clone()).map_err(|_| i)?)), - IntExpression::Add(box e1, box e2) => Ok(Self::Add( - box Self::try_from_int(e1)?, - box Self::try_from_int(e2)?, + IntExpression::Value(i) => Ok(Self::Value(ValueExpression::new( + T::try_from(i.value.clone()).map_err(|_| IntExpression::Value(i))?, + ))), + IntExpression::Add(e) => Ok(Self::add( + Self::try_from_int(*e.left)?, + Self::try_from_int(*e.right)?, )), - IntExpression::Sub(box e1, box e2) => Ok(Self::Sub( - box Self::try_from_int(e1)?, - box Self::try_from_int(e2)?, + IntExpression::Sub(e) => Ok(Self::sub( + Self::try_from_int(*e.left)?, + Self::try_from_int(*e.right)?, )), - IntExpression::Mult(box e1, box e2) => Ok(Self::Mult( - box Self::try_from_int(e1)?, - box Self::try_from_int(e2)?, + IntExpression::Mult(e) => Ok(Self::mul( + Self::try_from_int(*e.left)?, + Self::try_from_int(*e.right)?, )), - IntExpression::Pow(box e1, box e2) => Ok(Self::Pow( - box Self::try_from_int(e1)?, - box UExpression::try_from_int(e2, &UBitwidth::B32)?, + IntExpression::Pow(e) => Ok(Self::pow( + Self::try_from_int(*e.left)?, + UExpression::try_from_int(*e.right, &UBitwidth::B32)?, )), - IntExpression::Div(box e1, box e2) => Ok(Self::Div( - box Self::try_from_int(e1)?, - box Self::try_from_int(e2)?, + IntExpression::Div(e) => Ok(Self::div( + Self::try_from_int(*e.left)?, + Self::try_from_int(*e.right)?, )), - IntExpression::And(box e1, box e2) => Ok(Self::And( - box Self::try_from_int(e1)?, - box Self::try_from_int(e2)?, + IntExpression::And(e) => Ok(Self::bitand( + Self::try_from_int(*e.left)?, + Self::try_from_int(*e.right)?, )), - IntExpression::Or(box e1, box e2) => Ok(Self::Or( - box Self::try_from_int(e1)?, - box Self::try_from_int(e2)?, + IntExpression::Or(e) => Ok(Self::bitor( + Self::try_from_int(*e.left)?, + Self::try_from_int(*e.right)?, )), - IntExpression::Xor(box e1, box e2) => Ok(Self::Xor( - box Self::try_from_int(e1)?, - box Self::try_from_int(e2)?, + IntExpression::Xor(e) => Ok(Self::bitxor( + Self::try_from_int(*e.left)?, + Self::try_from_int(*e.right)?, )), - IntExpression::LeftShift(box e1, box e2) => { - Ok(Self::LeftShift(box Self::try_from_int(e1)?, box e2)) + IntExpression::LeftShift(e) => { + Ok(Self::left_shift(Self::try_from_int(*e.left)?, *e.right)) } - IntExpression::RightShift(box e1, box e2) => { - Ok(Self::RightShift(box Self::try_from_int(e1)?, box e2)) + IntExpression::RightShift(e) => { + Ok(Self::right_shift(Self::try_from_int(*e.left)?, *e.right)) } - IntExpression::Pos(box e) => Ok(Self::Pos(box Self::try_from_int(e)?)), - IntExpression::Neg(box e) => Ok(Self::Neg(box Self::try_from_int(e)?)), + IntExpression::Pos(e) => Ok(Self::pos(Self::try_from_int(*e.inner)?)), + IntExpression::Neg(e) => Ok(Self::neg(Self::try_from_int(*e.inner)?)), IntExpression::Conditional(c) => Ok(Self::Conditional(ConditionalExpression::new( *c.condition, Self::try_from_int(*c.consequence)?, @@ -481,6 +569,7 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> { match array.into_inner() { ArrayExpressionInner::Value(values) => { let values = values + .value .into_iter() .map(|v| { TypedExpressionOrSpread::align_to_type( @@ -501,8 +590,7 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> { }) .collect::, _>>()?; Ok(FieldElementExpression::select( - ArrayExpressionInner::Value(values.into()) - .annotate(Type::FieldElement, size), + ArrayExpression::value(values).annotate(Type::FieldElement, size), index, )) } @@ -511,6 +599,7 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> { } i => Err(i), } + .map(|e| e.span(span)) } } @@ -537,54 +626,56 @@ impl<'ast, T: Field> UExpression<'ast, T> { ) -> Result> { use self::IntExpression::*; + let span = i.get_span(); + match i { Value(i) => { - if i <= BigUint::from(2u128.pow(bitwidth.to_usize() as u32) - 1) { - Ok(UExpressionInner::Value( - u128::from_str_radix(&i.to_str_radix(16), 16).unwrap(), + if i.value <= BigUint::from(2u128.pow(bitwidth.to_usize() as u32) - 1) { + Ok(UExpression::value( + u128::from_str_radix(&i.value.to_str_radix(16), 16).unwrap(), ) .annotate(*bitwidth)) } else { Err(Value(i)) } } - Add(box e1, box e2) => { - Ok(Self::try_from_int(e1, bitwidth)? + Self::try_from_int(e2, bitwidth)?) - } - Pos(box e) => Ok(Self::pos(Self::try_from_int(e, bitwidth)?)), - Neg(box e) => Ok(Self::neg(Self::try_from_int(e, bitwidth)?)), - Sub(box e1, box e2) => { - Ok(Self::try_from_int(e1, bitwidth)? - Self::try_from_int(e2, bitwidth)?) - } - Mult(box e1, box e2) => { - Ok(Self::try_from_int(e1, bitwidth)? * Self::try_from_int(e2, bitwidth)?) - } - Div(box e1, box e2) => { - Ok(Self::try_from_int(e1, bitwidth)? / Self::try_from_int(e2, bitwidth)?) - } - Rem(box e1, box e2) => { - Ok(Self::try_from_int(e1, bitwidth)? % Self::try_from_int(e2, bitwidth)?) - } - And(box e1, box e2) => Ok(UExpression::and( - Self::try_from_int(e1, bitwidth)?, - Self::try_from_int(e2, bitwidth)?, + Add(e) => Ok( + Self::try_from_int(*e.left, bitwidth)? + Self::try_from_int(*e.right, bitwidth)? + ), + Pos(e) => Ok(Self::pos(Self::try_from_int(*e.inner, bitwidth)?)), + Neg(e) => Ok(Self::neg(Self::try_from_int(*e.inner, bitwidth)?)), + Sub(e) => Ok( + Self::try_from_int(*e.left, bitwidth)? - Self::try_from_int(*e.right, bitwidth)? + ), + Mult(e) => Ok( + Self::try_from_int(*e.left, bitwidth)? * Self::try_from_int(*e.right, bitwidth)? + ), + Div(e) => Ok( + Self::try_from_int(*e.left, bitwidth)? / Self::try_from_int(*e.right, bitwidth)? + ), + Rem(e) => Ok( + Self::try_from_int(*e.left, bitwidth)? % Self::try_from_int(*e.right, bitwidth)? + ), + And(e) => Ok(UExpression::and( + Self::try_from_int(*e.left, bitwidth)?, + Self::try_from_int(*e.right, bitwidth)?, )), - Or(box e1, box e2) => Ok(UExpression::or( - Self::try_from_int(e1, bitwidth)?, - Self::try_from_int(e2, bitwidth)?, + Or(e) => Ok(UExpression::or( + Self::try_from_int(*e.left, bitwidth)?, + Self::try_from_int(*e.right, bitwidth)?, )), - Not(box e) => Ok(!Self::try_from_int(e, bitwidth)?), - Xor(box e1, box e2) => Ok(UExpression::xor( - Self::try_from_int(e1, bitwidth)?, - Self::try_from_int(e2, bitwidth)?, + Not(e) => Ok(!Self::try_from_int(*e.inner, bitwidth)?), + Xor(e) => Ok(UExpression::xor( + Self::try_from_int(*e.left, bitwidth)?, + Self::try_from_int(*e.right, bitwidth)?, )), - RightShift(box e1, box e2) => Ok(UExpression::right_shift( - Self::try_from_int(e1, bitwidth)?, - e2, + RightShift(e) => Ok(UExpression::right_shift( + Self::try_from_int(*e.left, bitwidth)?, + *e.right, )), - LeftShift(box e1, box e2) => Ok(UExpression::left_shift( - Self::try_from_int(e1, bitwidth)?, - e2, + LeftShift(e) => Ok(UExpression::left_shift( + Self::try_from_int(*e.left, bitwidth)?, + *e.right, )), Conditional(c) => Ok(UExpression::conditional( *c.condition, @@ -600,6 +691,7 @@ impl<'ast, T: Field> UExpression<'ast, T> { match array.into_inner() { ArrayExpressionInner::Value(values) => { let values = values + .value .into_iter() .map(|v| { TypedExpressionOrSpread::align_to_type( @@ -620,8 +712,7 @@ impl<'ast, T: Field> UExpression<'ast, T> { }) .collect::, _>>()?; Ok(UExpression::select( - ArrayExpressionInner::Value(values.into()) - .annotate(Type::Uint(*bitwidth), size), + ArrayExpression::value(values).annotate(Type::Uint(*bitwidth), size), index, )) } @@ -630,6 +721,7 @@ impl<'ast, T: Field> UExpression<'ast, T> { } i => Err(i), } + .map(|e| e.span(span)) } } @@ -649,6 +741,8 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> { array: Self, target_array_ty: &GArrayType, ) -> Result> { + let span = array.get_span(); + let array_ty = array.ty().clone(); // elements must fit in the target type @@ -659,6 +753,7 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> { _ => { // try to convert all elements to the target type inline_array + .value .into_iter() .map(|v| { TypedExpressionOrSpread::align_to_type(v, target_array_ty).map_err( @@ -671,24 +766,25 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> { ) }) .collect::, _>>() - .map(|v| v.into()) + .map(ArrayValueExpression::new) } }?; - let inner_ty = res.0[0].get_type().0; + let inner_ty = res.value[0].get_type().0; Ok(ArrayExpressionInner::Value(res).annotate(inner_ty, *array_ty.size)) } - ArrayExpressionInner::Repeat(box e, box count) => { + ArrayExpressionInner::Repeat(r) => { match &*target_array_ty.ty { - GType::Int => Ok(ArrayExpressionInner::Repeat(box e, box count) - .annotate(Type::Int, *array_ty.size)), + GType::Int => { + Ok(ArrayExpressionInner::Repeat(r).annotate(Type::Int, *array_ty.size)) + } // try to align the repeated element to the target type - t => TypedExpression::align_to_type(e, t) + t => TypedExpression::align_to_type(*r.e, t) .map(|e| { let ty = e.get_type().clone(); - ArrayExpressionInner::Repeat(box e, box count) + ArrayExpressionInner::Repeat(RepeatExpression::new(e, *r.count)) .annotate(ty, *array_ty.size) }) .map_err(|(e, _)| e), @@ -702,6 +798,7 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> { } } } + .map(|e| e.span(span)) } } @@ -710,6 +807,8 @@ impl<'ast, T: Field> StructExpression<'ast, T> { struc: Self, target_struct_ty: &GStructType, ) -> Result> { + let span = struc.get_span(); + let struct_ty = struc.ty().clone(); if struct_ty.members.len() != target_struct_ty.members.len() { @@ -724,7 +823,7 @@ impl<'ast, T: Field> StructExpression<'ast, T> { TypedExpression::align_to_type(value, &*target_member.ty) }) .collect::, _>>() - .map(|v| StructExpressionInner::Value(v).annotate(struct_ty.clone())) + .map(|v| StructExpression::value(v).annotate(struct_ty.clone())) .map_err(|(v, _)| v), s => { if struct_ty @@ -739,6 +838,7 @@ impl<'ast, T: Field> StructExpression<'ast, T> { } } } + .map(|e| e.span(span)) } pub fn try_from_typed>>( @@ -757,6 +857,8 @@ impl<'ast, T: Field> TupleExpression<'ast, T> { tuple: Self, target_tuple_ty: >upleType, ) -> Result> { + let span = tuple.get_span(); + let tuple_ty = tuple.ty().clone(); if tuple_ty.elements.len() != target_tuple_ty.elements.len() { @@ -771,7 +873,7 @@ impl<'ast, T: Field> TupleExpression<'ast, T> { .collect::, _>>() .map(|v| { let ty = TupleType::new(v.iter().map(|e| e.get_type()).collect()); - TupleExpressionInner::Value(v).annotate(ty) + TupleExpression::value(v).annotate(ty) }) .map_err(|(v, _)| v), s => { @@ -787,6 +889,7 @@ impl<'ast, T: Field> TupleExpression<'ast, T> { } } } + .map(|e| e.span(span)) } pub fn try_from_typed>>( @@ -800,9 +903,9 @@ impl<'ast, T: Field> TupleExpression<'ast, T> { } } -impl<'ast, T> From for IntExpression<'ast, T> { +impl<'ast, T: Field> From for IntExpression<'ast, T> { fn from(v: BigUint) -> Self { - IntExpression::Value(v) + IntExpression::value(v) } } @@ -816,10 +919,11 @@ mod tests { fn field_from_int() { let n: IntExpression = BigUint::from(42usize).into(); let n_a: ArrayExpression = - ArrayExpressionInner::Value(vec![n.clone().into()].into()).annotate(Type::Int, 1u32); + ArrayExpressionInner::Value(ValueExpression::new(vec![n.clone().into()])) + .annotate(Type::Int, 1u32); let t: FieldElementExpression = Bn128Field::from(42).into(); let t_a: ArrayExpression = - ArrayExpressionInner::Value(vec![t.clone().into()].into()) + ArrayExpressionInner::Value(ValueExpression::new(vec![t.clone().into()])) .annotate(Type::FieldElement, 1u32); let i: UExpression = 42u32.into(); let c: BooleanExpression = true.into(); @@ -876,10 +980,11 @@ mod tests { fn uint_from_int() { let n: IntExpression = BigUint::from(42usize).into(); let n_a: ArrayExpression = - ArrayExpressionInner::Value(vec![n.clone().into()].into()).annotate(Type::Int, 1u32); + ArrayExpressionInner::Value(ValueExpression::new(vec![n.clone().into()])) + .annotate(Type::Int, 1u32); let t: UExpression = 42u32.into(); let t_a: ArrayExpression = - ArrayExpressionInner::Value(vec![t.clone().into()].into()) + ArrayExpressionInner::Value(ValueExpression::new(vec![t.clone().into()])) .annotate(Type::Uint(UBitwidth::B32), 1u32); let i: UExpression = 0u32.into(); let c: BooleanExpression = true.into(); diff --git a/zokrates_ast/src/typed/mod.rs b/zokrates_ast/src/typed/mod.rs index bd000d12a..53e2096c5 100644 --- a/zokrates_ast/src/typed/mod.rs +++ b/zokrates_ast/src/typed/mod.rs @@ -27,16 +27,23 @@ pub use self::types::{ UBitwidth, }; use self::types::{ConcreteArrayType, ConcreteStructType}; + +use crate::common::expressions::{ + self, BinaryExpression, BooleanValueExpression, FieldValueExpression, UnaryExpression, + ValueExpression, +}; +use crate::common::{self, ModuleMap, Span, Value, WithSpan}; +pub use crate::common::{ModuleId, OwnedModuleId}; use crate::typed::types::IntoType; pub use self::variable::{ConcreteVariable, DeclarationVariable, GVariable, Variable}; use std::marker::PhantomData; -use std::path::{Path, PathBuf}; +use std::ops::Deref; pub use crate::typed::integer::IntExpression; pub use crate::typed::uint::{bitwidth, UExpression, UExpressionInner, UMetadata}; -use crate::common::{FlatEmbed, FormatString, SourceMetadata}; +use crate::common::{operators::*, FlatEmbed, FormatString, SourceMetadata}; use std::collections::BTreeMap; use std::convert::{TryFrom, TryInto}; @@ -44,20 +51,17 @@ use std::fmt; pub use crate::typed::types::{ArrayType, FunctionKey, MemberId}; +use derivative::Derivative; +use num_bigint::BigUint; use zokrates_field::Field; pub use self::folder::Folder; use crate::typed::abi::{Abi, AbiInput}; -use std::ops::{Add, Div, Mul, Sub}; pub use self::identifier::Identifier; -/// An identifier for a `TypedModule`. Typically a path or uri. -pub type OwnedTypedModuleId = PathBuf; -pub type TypedModuleId = Path; - /// A collection of `TypedModule`s -pub type TypedModules<'ast, T> = BTreeMap>; +pub type TypedModules<'ast, T> = BTreeMap>; /// A collection of `TypedFunctionSymbol`s /// # Remarks @@ -80,12 +84,16 @@ pub type TypedConstantSymbols<'ast, T> = Vec<( )>; /// A typed program as a collection of modules, one of them being the main -#[derive(PartialEq, Eq, Debug, Clone)] +#[derive(PartialEq, Eq, Debug, Clone, Default)] pub struct TypedProgram<'ast, T> { + pub module_map: ModuleMap, pub modules: TypedModules<'ast, T>, - pub main: OwnedTypedModuleId, + pub main: OwnedModuleId, } +pub type IdentifierOrExpression<'ast, T, E> = + expressions::IdentifierOrExpression, E, >::Inner>; + impl<'ast, T: Field> TypedProgram<'ast, T> { pub fn abi(&self) -> Abi { let main = &self.modules[&self.main] @@ -107,7 +115,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { crate::typed::types::try_from_g_type::< DeclarationConstant<'ast, T>, UExpression<'ast, T>, - >(p.id._type.clone()) + >(p.id.ty.clone()) .unwrap(), ) .map(|ty| AbiInput { @@ -499,7 +507,7 @@ impl<'ast, T: Clone> TypedExpressionOrSpread<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for TypedExpressionOrSpread<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - TypedExpressionOrSpread::Expression(e) => write!(f, "{}", e), + TypedExpressionOrSpread::Expression(ref e) => write!(f, "{}", e), TypedExpressionOrSpread::Spread(s) => write!(f, "{}", s), } } @@ -642,98 +650,242 @@ impl<'ast, T: fmt::Display> fmt::Display for DefinitionRhs<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { DefinitionRhs::EmbedCall(c) => write!(f, "{}", c), - DefinitionRhs::Expression(e) => write!(f, "{}", e), + DefinitionRhs::Expression(ref e) => write!(f, "{}", e), + } + } +} + +pub type DefinitionStatement<'ast, T> = + common::statements::DefinitionStatement, DefinitionRhs<'ast, T>>; +pub type AssertionStatement<'ast, T> = + common::statements::AssertionStatement, RuntimeError>; +pub type ReturnStatement<'ast, T> = common::statements::ReturnStatement>; +pub type LogStatement<'ast, T> = common::statements::LogStatement>; + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] +pub struct ForStatement<'ast, T> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub var: Variable<'ast, T>, + pub from: UExpression<'ast, T>, + pub to: UExpression<'ast, T>, + pub statements: Vec>, +} + +impl<'ast, T> ForStatement<'ast, T> { + fn new( + var: Variable<'ast, T>, + from: UExpression<'ast, T>, + to: UExpression<'ast, T>, + statements: Vec>, + ) -> Self { + Self { + span: None, + var, + from, + to, + statements, } } } +impl<'ast, T> WithSpan for ForStatement<'ast, T> { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] +pub struct AssemblyBlockStatement<'ast, T> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub inner: Vec>, +} + +impl<'ast, T> AssemblyBlockStatement<'ast, T> { + pub fn new(inner: Vec>) -> Self { + Self { span: None, inner } + } +} + +pub type AssemblyConstraint<'ast, T> = + crate::common::statements::AssemblyConstraint>; +pub type AssemblyAssignment<'ast, T> = + crate::common::statements::AssemblyAssignment, TypedExpression<'ast, T>>; + #[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum TypedAssemblyStatement<'ast, T> { - Assignment(TypedAssignee<'ast, T>, TypedExpression<'ast, T>), - Constraint( - FieldElementExpression<'ast, T>, - FieldElementExpression<'ast, T>, - SourceMetadata, - ), + Assignment(AssemblyAssignment<'ast, T>), + Constraint(AssemblyConstraint<'ast, T>), +} + +impl<'ast, T> WithSpan for TypedAssemblyStatement<'ast, T> { + fn span(self, span: Option) -> Self { + match self { + TypedAssemblyStatement::Assignment(s) => { + TypedAssemblyStatement::Assignment(s.span(span)) + } + TypedAssemblyStatement::Constraint(s) => { + TypedAssemblyStatement::Constraint(s.span(span)) + } + } + } + + fn get_span(&self) -> Option { + match self { + TypedAssemblyStatement::Assignment(s) => s.get_span(), + TypedAssemblyStatement::Constraint(s) => s.get_span(), + } + } +} + +impl<'ast, T> TypedAssemblyStatement<'ast, T> { + pub fn assignment( + assignee: TypedAssignee<'ast, T>, + expression: TypedExpression<'ast, T>, + ) -> Self { + TypedAssemblyStatement::Assignment(AssemblyAssignment::new(assignee, expression)) + } + + pub fn constraint( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + metadata: SourceMetadata, + ) -> Self { + TypedAssemblyStatement::Constraint(AssemblyConstraint::new(left, right, metadata)) + } } impl<'ast, T: fmt::Display> fmt::Display for TypedAssemblyStatement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - TypedAssemblyStatement::Assignment(ref lhs, ref rhs) => { - write!(f, "{} <-- {};", lhs, rhs) + TypedAssemblyStatement::Assignment(ref s) => { + write!(f, "{} <-- {};", s.assignee, s.expression) } - TypedAssemblyStatement::Constraint(ref lhs, ref rhs, _) => { - write!(f, "{} === {};", lhs, rhs) + TypedAssemblyStatement::Constraint(ref s) => { + write!(f, "{}", s) } } } } +impl<'ast, T> WithSpan for AssemblyBlockStatement<'ast, T> { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + /// A statement in a `TypedFunction` #[allow(clippy::large_enum_variant)] #[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum TypedStatement<'ast, T> { - Return(TypedExpression<'ast, T>), - Definition(TypedAssignee<'ast, T>, DefinitionRhs<'ast, T>), - Assertion(BooleanExpression<'ast, T>, RuntimeError), - For( - Variable<'ast, T>, - UExpression<'ast, T>, - UExpression<'ast, T>, - Vec>, - ), - Log(FormatString, Vec>), - Assembly(Vec>), + Return(ReturnStatement<'ast, T>), + Definition(DefinitionStatement<'ast, T>), + Assertion(AssertionStatement<'ast, T>), + For(ForStatement<'ast, T>), + Log(LogStatement<'ast, T>), + Assembly(AssemblyBlockStatement<'ast, T>), +} + +impl<'ast, T> WithSpan for TypedStatement<'ast, T> { + fn span(self, span: Option) -> Self { + use TypedStatement::*; + + match self { + Return(e) => Return(e.span(span)), + Definition(e) => Definition(e.span(span)), + Assertion(e) => Assertion(e.span(span)), + For(e) => For(e.span(span)), + Log(e) => Log(e.span(span)), + Assembly(e) => Assembly(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use TypedStatement::*; + + match self { + Return(e) => e.get_span(), + Definition(e) => e.get_span(), + Assertion(e) => e.get_span(), + For(e) => e.get_span(), + Log(e) => e.get_span(), + Assembly(e) => e.get_span(), + } + } } impl<'ast, T> TypedStatement<'ast, T> { pub fn definition(a: TypedAssignee<'ast, T>, e: TypedExpression<'ast, T>) -> Self { - Self::Definition(a, e.into()) + Self::Definition(DefinitionStatement::new(a, e.into())) + } + + pub fn for_( + var: Variable<'ast, T>, + from: UExpression<'ast, T>, + to: UExpression<'ast, T>, + statements: Vec>, + ) -> Self { + Self::For(ForStatement::new(var, from, to, statements)) + } + + pub fn assertion(e: BooleanExpression<'ast, T>, error: RuntimeError) -> Self { + Self::Assertion(AssertionStatement::new(e, error)) + } + + pub fn ret(e: TypedExpression<'ast, T>) -> Self { + Self::Return(ReturnStatement::new(e)) } pub fn embed_call_definition(a: TypedAssignee<'ast, T>, c: EmbedCall<'ast, T>) -> Self { - Self::Definition(a, c.into()) + Self::Definition(DefinitionStatement::new(a, c.into())) + } + + pub fn log(format_string: FormatString, expressions: Vec>) -> Self { + Self::Log(LogStatement::new(format_string, expressions)) } } impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - TypedStatement::Return(ref e) => { - write!(f, "return {};", e) + TypedStatement::Return(ref s) => { + write!(f, "{}", s) } - TypedStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {};", lhs, rhs), - TypedStatement::Assertion(ref e, ref error) => { - write!(f, "assert({}", e)?; - match error { - RuntimeError::SourceAssertion(metadata) => match &metadata.message { + TypedStatement::Definition(ref s) => write!(f, "{}", s), + TypedStatement::Assertion(ref s) => { + write!(f, "assert({}", s.expression)?; + match s.error { + RuntimeError::SourceAssertion(ref metadata) => match &metadata.message { Some(m) => write!(f, ", \"{}\");", m), None => write!(f, ");"), }, - error => write!(f, "); // {}", error), + ref error => write!(f, "); // {}", error), } } - TypedStatement::For(ref var, ref start, ref stop, ref list) => { - writeln!(f, "for {} in {}..{} {{", var, start, stop)?; - for l in list { + TypedStatement::For(ref s) => { + writeln!(f, "for {} in {}..{} {{", s.var, s.from, s.to)?; + for l in &s.statements { writeln!(f, "\t\t{}", l)?; } write!(f, "\t}}") } - TypedStatement::Log(ref l, ref expressions) => write!( - f, - "log({}, {})", - l, - expressions - .iter() - .map(|e| e.to_string()) - .collect::>() - .join(", ") - ), - TypedStatement::Assembly(ref statements) => { + TypedStatement::Log(ref s) => write!(f, "{}", s), + TypedStatement::Assembly(ref s) => { writeln!(f, "asm {{")?; - for s in statements { + for s in &s.inner { writeln!(f, "\t\t{}", s)?; } write!(f, "\t}}") @@ -759,9 +911,9 @@ pub enum TypedExpression<'ast, T> { Int(IntExpression<'ast, T>), } -impl<'ast, T> TypedExpression<'ast, T> { +impl<'ast, T: Field> TypedExpression<'ast, T> { pub fn empty_tuple() -> TypedExpression<'ast, T> { - TypedExpression::Tuple(TupleExpressionInner::Value(vec![]).annotate(TupleType::new(vec![]))) + TypedExpression::Tuple(TupleExpression::value(vec![]).annotate(TupleType::new(vec![]))) } } @@ -905,29 +1057,12 @@ impl<'ast, T: Clone> Typed<'ast, T> for BooleanExpression<'ast, T> { } } -#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] -pub struct EqExpression { - pub left: Box, - pub right: Box, -} - -impl EqExpression { - pub fn new(left: E, right: E) -> Self { - EqExpression { - left: box left, - right: box right, - } - } -} - -impl fmt::Display for EqExpression { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "({} == {})", self.left, self.right) - } -} - -#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] pub struct BlockExpression<'ast, T, E> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, pub statements: Vec>, pub value: Box, } @@ -935,35 +1070,99 @@ pub struct BlockExpression<'ast, T, E> { impl<'ast, T, E> BlockExpression<'ast, T, E> { pub fn new(statements: Vec>, value: E) -> Self { BlockExpression { + span: None, statements, value: box value, } } } -#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] -pub struct IdentifierExpression<'ast, E> { - pub id: Identifier<'ast>, - ty: PhantomData, +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] +pub struct SliceExpression<'ast, T> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub array: Box>, + pub from: Box>, + pub to: Box>, } -impl<'ast, E> IdentifierExpression<'ast, E> { - pub fn new(id: Identifier<'ast>) -> Self { - IdentifierExpression { - id, - ty: PhantomData, +impl<'ast, T> SliceExpression<'ast, T> { + pub fn new( + array: ArrayExpression<'ast, T>, + from: UExpression<'ast, T>, + to: UExpression<'ast, T>, + ) -> Self { + SliceExpression { + span: None, + array: Box::new(array), + from: Box::new(from), + to: Box::new(to), } } } -impl<'ast, E> fmt::Display for IdentifierExpression<'ast, E> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.id) +impl<'ast, T> WithSpan for SliceExpression<'ast, T> { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span } } -#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +impl<'ast, T: fmt::Display> fmt::Display for SliceExpression<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}[{}..{}]", self.array, self.from, self.to) + } +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] +pub struct RepeatExpression<'ast, T> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub count: Box>, + pub e: Box>, +} + +impl<'ast, T> RepeatExpression<'ast, T> { + pub fn new(e: TypedExpression<'ast, T>, count: UExpression<'ast, T>) -> Self { + RepeatExpression { + span: None, + e: Box::new(e), + count: Box::new(count), + } + } +} + +impl<'ast, T> WithSpan for RepeatExpression<'ast, T> { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl<'ast, T: fmt::Display> fmt::Display for RepeatExpression<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[{}; {}]", self.e, self.count) + } +} + +pub type IdentifierExpression<'ast, E> = expressions::IdentifierExpression, E>; + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] pub struct MemberExpression<'ast, T, E> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, pub struc: Box>, pub id: MemberId, ty: PhantomData, @@ -972,6 +1171,7 @@ pub struct MemberExpression<'ast, T, E> { impl<'ast, T, E> MemberExpression<'ast, T, E> { pub fn new(struc: StructExpression<'ast, T>, id: MemberId) -> Self { MemberExpression { + span: None, struc: box struc, id, ty: PhantomData, @@ -985,8 +1185,12 @@ impl<'ast, T: fmt::Display, E> fmt::Display for MemberExpression<'ast, T, E> { } } -#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] pub struct SelectExpression<'ast, T, E> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, pub array: Box>, pub index: Box>, ty: PhantomData, @@ -995,6 +1199,7 @@ pub struct SelectExpression<'ast, T, E> { impl<'ast, T, E> SelectExpression<'ast, T, E> { pub fn new(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self { SelectExpression { + span: None, array: box array, index: box index, ty: PhantomData, @@ -1008,8 +1213,12 @@ impl<'ast, T: fmt::Display, E> fmt::Display for SelectExpression<'ast, T, E> { } } -#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] pub struct ElementExpression<'ast, T, E> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, pub tuple: Box>, pub index: u32, ty: PhantomData, @@ -1018,6 +1227,7 @@ pub struct ElementExpression<'ast, T, E> { impl<'ast, T, E> ElementExpression<'ast, T, E> { pub fn new(tuple: TupleExpression<'ast, T>, index: u32) -> Self { ElementExpression { + span: None, tuple: box tuple, index, ty: PhantomData, @@ -1037,8 +1247,12 @@ pub enum ConditionalKind { Ternary, } -#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)] +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] pub struct ConditionalExpression<'ast, T, E> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, pub condition: Box>, pub consequence: Box, pub alternative: Box, @@ -1053,6 +1267,7 @@ impl<'ast, T, E> ConditionalExpression<'ast, T, E> { kind: ConditionalKind, ) -> Self { ConditionalExpression { + span: None, condition: box condition, consequence: box consequence, alternative: box alternative, @@ -1080,8 +1295,12 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for ConditionalExpress } } -#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] pub struct FunctionCallExpression<'ast, T, E> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, pub function_key: DeclarationFunctionKey<'ast, T>, pub generics: Vec>>, pub arguments: Vec>, @@ -1095,6 +1314,7 @@ impl<'ast, T, E> FunctionCallExpression<'ast, T, E> { arguments: Vec>, ) -> Self { FunctionCallExpression { + span: None, function_key, generics, arguments, @@ -1136,51 +1356,21 @@ impl<'ast, T: fmt::Display, E> fmt::Display for FunctionCallExpression<'ast, T, #[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum FieldElementExpression<'ast, T> { Block(BlockExpression<'ast, T, Self>), - Number(T), + Value(FieldValueExpression), Identifier(IdentifierExpression<'ast, Self>), - Add( - Box>, - Box>, - ), - Sub( - Box>, - Box>, - ), - Mult( - Box>, - Box>, - ), - Div( - Box>, - Box>, - ), - Pow( - Box>, - Box>, - ), - And( - Box>, - Box>, - ), - Or( - Box>, - Box>, - ), - Xor( - Box>, - Box>, - ), - LeftShift( - Box>, - Box>, - ), - RightShift( - Box>, - Box>, - ), + Add(BinaryExpression), + Sub(BinaryExpression), + Mult(BinaryExpression), + Div(BinaryExpression), + Pow(BinaryExpression, Self>), + And(BinaryExpression), + Or(BinaryExpression), + Xor(BinaryExpression), + LeftShift(BinaryExpression, Self>), + RightShift(BinaryExpression, Self>), Conditional(ConditionalExpression<'ast, T, Self>), - Neg(Box>), - Pos(Box>), + Neg(UnaryExpression), + Pos(UnaryExpression), FunctionCall(FunctionCallExpression<'ast, T, Self>), Member(MemberExpression<'ast, T, Self>), Select(SelectExpression<'ast, T, Self>), @@ -1192,7 +1382,7 @@ impl<'ast, T: Field> From> for TupleExpression<'ast, T> { match assignee { TypedAssignee::Identifier(v) => { let inner = TupleExpression::identifier(v.id); - match v._type { + match v.ty { GType::Tuple(tuple_ty) => inner.annotate(tuple_ty), _ => unreachable!(), } @@ -1209,7 +1399,7 @@ impl<'ast, T: Field> From> for StructExpression<'ast, T> match assignee { TypedAssignee::Identifier(v) => { let inner = StructExpression::identifier(v.id); - match v._type { + match v.ty { GType::Struct(struct_ty) => inner.annotate(struct_ty), _ => unreachable!(), } @@ -1226,7 +1416,7 @@ impl<'ast, T: Field> From> for ArrayExpression<'ast, T> { match assignee { TypedAssignee::Identifier(v) => { let inner = ArrayExpression::identifier(v.id); - match v._type { + match v.ty { GType::Array(array_ty) => inner.annotate(*array_ty.ty, *array_ty.size), _ => unreachable!(), } @@ -1253,47 +1443,91 @@ impl<'ast, T: Field> From> for FieldElementExpression<'as } } -impl<'ast, T> Add for FieldElementExpression<'ast, T> { +impl<'ast, T> std::ops::Add for FieldElementExpression<'ast, T> { type Output = Self; fn add(self, other: Self) -> Self { - FieldElementExpression::Add(box self, box other) + FieldElementExpression::Add(BinaryExpression::new(self, other)) } } -impl<'ast, T> Sub for FieldElementExpression<'ast, T> { +impl<'ast, T: Field> std::ops::Sub for FieldElementExpression<'ast, T> { type Output = Self; fn sub(self, other: Self) -> Self { - FieldElementExpression::Sub(box self, box other) + FieldElementExpression::Sub(BinaryExpression::new(self, other)) } } -impl<'ast, T> Mul for FieldElementExpression<'ast, T> { +impl<'ast, T: Field> std::ops::Mul for FieldElementExpression<'ast, T> { type Output = Self; fn mul(self, other: Self) -> Self { - FieldElementExpression::Mult(box self, box other) + FieldElementExpression::Mult(BinaryExpression::new(self, other)) } } -impl<'ast, T> Div for FieldElementExpression<'ast, T> { +impl<'ast, T: Field> std::ops::Div for FieldElementExpression<'ast, T> { type Output = Self; fn div(self, other: Self) -> Self { - FieldElementExpression::Div(box self, box other) + FieldElementExpression::Div(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Field> std::ops::BitAnd for FieldElementExpression<'ast, T> { + type Output = Self; + + fn bitand(self, other: Self) -> Self { + FieldElementExpression::And(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Field> std::ops::BitOr for FieldElementExpression<'ast, T> { + type Output = Self; + + fn bitor(self, other: Self) -> Self { + FieldElementExpression::Or(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Field> std::ops::BitXor for FieldElementExpression<'ast, T> { + type Output = Self; + + fn bitxor(self, other: Self) -> Self { + FieldElementExpression::Xor(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T> std::ops::Neg for FieldElementExpression<'ast, T> { + type Output = Self; + + fn neg(self) -> Self { + FieldElementExpression::Neg(UnaryExpression::new(self)) } } -impl<'ast, T> FieldElementExpression<'ast, T> { +impl<'ast, T: Field> FieldElementExpression<'ast, T> { pub fn pow(self, other: UExpression<'ast, T>) -> Self { - FieldElementExpression::Pow(box self, box other) + FieldElementExpression::Pow(BinaryExpression::new(self, other)) + } + + pub fn pos(self) -> Self { + FieldElementExpression::Pos(UnaryExpression::new(self)) + } + + pub fn left_shift(self, by: UExpression<'ast, T>) -> Self { + FieldElementExpression::LeftShift(BinaryExpression::new(self, by)) + } + + pub fn right_shift(self, by: UExpression<'ast, T>) -> Self { + FieldElementExpression::RightShift(BinaryExpression::new(self, by)) } } -impl<'ast, T> From for FieldElementExpression<'ast, T> { +impl<'ast, T: Clone> From for FieldElementExpression<'ast, T> { fn from(n: T) -> Self { - FieldElementExpression::Number(n) + FieldElementExpression::Value(ValueExpression::new(n)) } } @@ -1302,42 +1536,41 @@ impl<'ast, T> From for FieldElementExpression<'ast, T> { pub enum BooleanExpression<'ast, T> { Block(BlockExpression<'ast, T, Self>), Identifier(IdentifierExpression<'ast, Self>), - Value(bool), + Value(BooleanValueExpression), FieldLt( - Box>, - Box>, + BinaryExpression< + OpLt, + FieldElementExpression<'ast, T>, + FieldElementExpression<'ast, T>, + Self, + >, ), FieldLe( - Box>, - Box>, - ), - FieldGe( - Box>, - Box>, - ), - FieldGt( - Box>, - Box>, - ), - UintLt(Box>, Box>), - UintLe(Box>, Box>), - UintGe(Box>, Box>), - UintGt(Box>, Box>), - FieldEq(EqExpression>), - BoolEq(EqExpression>), - ArrayEq(EqExpression>), - StructEq(EqExpression>), - TupleEq(EqExpression>), - UintEq(EqExpression>), - Or( - Box>, - Box>, + BinaryExpression< + OpLe, + FieldElementExpression<'ast, T>, + FieldElementExpression<'ast, T>, + Self, + >, ), - And( - Box>, - Box>, + UintLt(BinaryExpression, UExpression<'ast, T>, Self>), + UintLe(BinaryExpression, UExpression<'ast, T>, Self>), + FieldEq( + BinaryExpression< + OpEq, + FieldElementExpression<'ast, T>, + FieldElementExpression<'ast, T>, + Self, + >, ), - Not(Box>), + BoolEq(BinaryExpression, BooleanExpression<'ast, T>, Self>), + ArrayEq(BinaryExpression, ArrayExpression<'ast, T>, Self>), + StructEq(BinaryExpression, StructExpression<'ast, T>, Self>), + TupleEq(BinaryExpression, TupleExpression<'ast, T>, Self>), + UintEq(BinaryExpression, UExpression<'ast, T>, Self>), + Or(BinaryExpression, BooleanExpression<'ast, T>, Self>), + And(BinaryExpression, BooleanExpression<'ast, T>, Self>), + Not(UnaryExpression), Conditional(ConditionalExpression<'ast, T, Self>), Member(MemberExpression<'ast, T, Self>), FunctionCall(FunctionCallExpression<'ast, T, Self>), @@ -1347,7 +1580,104 @@ pub enum BooleanExpression<'ast, T> { impl<'ast, T> From for BooleanExpression<'ast, T> { fn from(b: bool) -> Self { - BooleanExpression::Value(b) + BooleanExpression::Value(ValueExpression::new(b)) + } +} + +impl<'ast, T> std::ops::Not for BooleanExpression<'ast, T> { + type Output = Self; + + fn not(self) -> Self { + Self::Not(UnaryExpression::new(self)) + } +} + +impl<'ast, T> std::ops::BitAnd for BooleanExpression<'ast, T> { + type Output = Self; + + fn bitand(self, other: Self) -> Self { + Self::And(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T> std::ops::BitOr for BooleanExpression<'ast, T> { + type Output = Self; + + fn bitor(self, other: Self) -> Self { + Self::Or(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T> BooleanExpression<'ast, T> { + pub fn uint_eq(left: UExpression<'ast, T>, right: UExpression<'ast, T>) -> Self { + Self::UintEq(BinaryExpression::new(left, right)) + } + + pub fn bool_eq(left: BooleanExpression<'ast, T>, right: BooleanExpression<'ast, T>) -> Self { + Self::BoolEq(BinaryExpression::new(left, right)) + } + + pub fn field_eq( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + ) -> Self { + Self::FieldEq(BinaryExpression::new(left, right)) + } + + pub fn struct_eq(left: StructExpression<'ast, T>, right: StructExpression<'ast, T>) -> Self { + Self::StructEq(BinaryExpression::new(left, right)) + } + + pub fn array_eq(left: ArrayExpression<'ast, T>, right: ArrayExpression<'ast, T>) -> Self { + Self::ArrayEq(BinaryExpression::new(left, right)) + } + + pub fn tuple_eq(left: TupleExpression<'ast, T>, right: TupleExpression<'ast, T>) -> Self { + Self::TupleEq(BinaryExpression::new(left, right)) + } + + pub fn uint_lt(left: UExpression<'ast, T>, right: UExpression<'ast, T>) -> Self { + Self::UintLt(BinaryExpression::new(left, right)) + } + + pub fn uint_le(left: UExpression<'ast, T>, right: UExpression<'ast, T>) -> Self { + Self::UintLe(BinaryExpression::new(left, right)) + } + + pub fn uint_gt(left: UExpression<'ast, T>, right: UExpression<'ast, T>) -> Self { + Self::UintLt(BinaryExpression::new(right, left)) + } + + pub fn uint_ge(left: UExpression<'ast, T>, right: UExpression<'ast, T>) -> Self { + Self::UintLe(BinaryExpression::new(right, left)) + } + + pub fn field_lt( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + ) -> Self { + Self::FieldLt(BinaryExpression::new(left, right)) + } + + pub fn field_le( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + ) -> Self { + Self::FieldLe(BinaryExpression::new(left, right)) + } + + pub fn field_gt( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + ) -> Self { + Self::FieldLt(BinaryExpression::new(right, left)) + } + + pub fn field_ge( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + ) -> Self { + Self::FieldLe(BinaryExpression::new(right, left)) } } @@ -1362,25 +1692,34 @@ pub struct ArrayExpression<'ast, T> { pub inner: ArrayExpressionInner<'ast, T>, } -#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)] -pub struct ArrayValue<'ast, T>(pub Vec>); - -impl<'ast, T> From>> for ArrayValue<'ast, T> { - fn from(array: Vec>) -> Self { - Self(array) - } -} +type ArrayValueExpression<'ast, T> = ValueExpression>>; -impl<'ast, T> IntoIterator for ArrayValue<'ast, T> { +impl<'ast, T> IntoIterator for ArrayValueExpression<'ast, T> { type Item = TypedExpressionOrSpread<'ast, T>; type IntoIter = std::vec::IntoIter; fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() + self.value.into_iter() + } +} + +impl<'ast, T> Deref for ArrayValueExpression<'ast, T> { + type Target = [TypedExpressionOrSpread<'ast, T>]; + + fn deref(&self) -> &Self::Target { + &self.value[..] + } +} + +impl<'ast, T> std::iter::FromIterator> + for ArrayValueExpression<'ast, T> +{ + fn from_iter>>(iter: I) -> Self { + Self::new(iter.into_iter().collect()) } } -impl<'ast, T: Field> ArrayValue<'ast, T> { +impl<'ast, T: Field> ArrayValueExpression<'ast, T> { fn expression_at_aux< U: Select<'ast, T> + From> + Into>, >( @@ -1396,7 +1735,7 @@ impl<'ast, T: Field> ArrayValue<'ast, T> { ArrayExpressionInner::Value(v) => { v.into_iter().flat_map(Self::expression_at_aux).collect() } - a => (0..size) + a => (0..size.value) .map(|i| { Some(U::select( a.clone() @@ -1418,8 +1757,7 @@ impl<'ast, T: Field> ArrayValue<'ast, T> { &self, index: usize, ) -> Option { - self.0 - .iter() + self.iter() .flat_map(|v| Self::expression_at_aux(v.clone())) .take_while(|e| e.is_some()) .map(|e| e.unwrap()) @@ -1427,34 +1765,18 @@ impl<'ast, T: Field> ArrayValue<'ast, T> { } } -impl<'ast, T> ArrayValue<'ast, T> { - fn iter(&self) -> std::slice::Iter> { - self.0.iter() - } -} - -impl<'ast, T> std::iter::FromIterator> for ArrayValue<'ast, T> { - fn from_iter>>(iter: I) -> Self { - Self(iter.into_iter().collect()) - } -} - #[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum ArrayExpressionInner<'ast, T> { Block(BlockExpression<'ast, T, ArrayExpression<'ast, T>>), Identifier(IdentifierExpression<'ast, ArrayExpression<'ast, T>>), - Value(ArrayValue<'ast, T>), + Value(ArrayValueExpression<'ast, T>), FunctionCall(FunctionCallExpression<'ast, T, ArrayExpression<'ast, T>>), Conditional(ConditionalExpression<'ast, T, ArrayExpression<'ast, T>>), Member(MemberExpression<'ast, T, ArrayExpression<'ast, T>>), Select(SelectExpression<'ast, T, ArrayExpression<'ast, T>>), Element(ElementExpression<'ast, T, ArrayExpression<'ast, T>>), - Slice( - Box>, - Box>, - Box>, - ), - Repeat(Box>, Box>), + Slice(SliceExpression<'ast, T>), + Repeat(RepeatExpression<'ast, T>), } impl<'ast, T> ArrayExpressionInner<'ast, T> { @@ -1478,6 +1800,22 @@ impl<'ast, T: Clone> ArrayExpression<'ast, T> { pub fn size(&self) -> UExpression<'ast, T> { *self.ty.size.clone() } + + pub fn slice( + array: ArrayExpression<'ast, T>, + from: UExpression<'ast, T>, + to: UExpression<'ast, T>, + ) -> Self { + let inner = array.inner_type().clone(); + let size = to.clone() - from.clone(); + ArrayExpressionInner::Slice(SliceExpression::new(array, from, to)).annotate(inner, size) + } + + pub fn repeat(e: TypedExpression<'ast, T>, count: UExpression<'ast, T>) -> Self { + let inner = e.get_type().clone(); + let size = count.clone(); + ArrayExpressionInner::Repeat(RepeatExpression::new(e, count)).annotate(inner, size) + } } #[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] @@ -1486,6 +1824,31 @@ pub struct StructExpression<'ast, T> { inner: StructExpressionInner<'ast, T>, } +type StructValueExpression<'ast, T> = ValueExpression>>; + +impl<'ast, T> IntoIterator for StructValueExpression<'ast, T> { + type Item = TypedExpression<'ast, T>; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.value.into_iter() + } +} + +impl<'ast, T> Deref for StructValueExpression<'ast, T> { + type Target = [TypedExpression<'ast, T>]; + + fn deref(&self) -> &Self::Target { + &self.value[..] + } +} + +impl<'ast, T> std::iter::FromIterator> for StructValueExpression<'ast, T> { + fn from_iter>>(iter: I) -> Self { + Self::new(iter.into_iter().collect()) + } +} + impl<'ast, T> StructExpression<'ast, T> { pub fn ty(&self) -> &StructType<'ast, T> { &self.ty @@ -1508,7 +1871,7 @@ impl<'ast, T> StructExpression<'ast, T> { pub enum StructExpressionInner<'ast, T> { Block(BlockExpression<'ast, T, StructExpression<'ast, T>>), Identifier(IdentifierExpression<'ast, StructExpression<'ast, T>>), - Value(Vec>), + Value(TupleValueExpression<'ast, T>), FunctionCall(FunctionCallExpression<'ast, T, StructExpression<'ast, T>>), Conditional(ConditionalExpression<'ast, T, StructExpression<'ast, T>>), Member(MemberExpression<'ast, T, StructExpression<'ast, T>>), @@ -1546,11 +1909,13 @@ impl<'ast, T> TupleExpression<'ast, T> { } } +type TupleValueExpression<'ast, T> = ValueExpression>>; + #[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum TupleExpressionInner<'ast, T> { Block(BlockExpression<'ast, T, TupleExpression<'ast, T>>), Identifier(IdentifierExpression<'ast, TupleExpression<'ast, T>>), - Value(Vec>), + Value(TupleValueExpression<'ast, T>), FunctionCall(FunctionCallExpression<'ast, T, TupleExpression<'ast, T>>), Conditional(ConditionalExpression<'ast, T, TupleExpression<'ast, T>>), Member(MemberExpression<'ast, T, TupleExpression<'ast, T>>), @@ -1715,7 +2080,7 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for BlockExpression<'a .map(|s| s.to_string()) .chain(std::iter::once(self.value.to_string())) .collect::>() - .join("\n") + .join("\n"), ) } } @@ -1724,21 +2089,21 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { FieldElementExpression::Block(ref block) => write!(f, "{}", block), - FieldElementExpression::Number(ref i) => write!(f, "{}f", i), + FieldElementExpression::Value(ref i) => write!(f, "{}f", i), FieldElementExpression::Identifier(ref var) => write!(f, "{}", var), - FieldElementExpression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), - FieldElementExpression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs), - FieldElementExpression::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), - FieldElementExpression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs), - FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs), - FieldElementExpression::Neg(ref e) => write!(f, "(-{})", e), - FieldElementExpression::Pos(ref e) => write!(f, "(+{})", e), - FieldElementExpression::And(ref lhs, ref rhs) => write!(f, "({} & {})", lhs, rhs), - FieldElementExpression::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs), - FieldElementExpression::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs), - FieldElementExpression::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by), - FieldElementExpression::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by), + FieldElementExpression::Add(ref e) => write!(f, "{}", e), + FieldElementExpression::Sub(ref e) => write!(f, "{}", e), + FieldElementExpression::Mult(ref e) => write!(f, "{}", e), + FieldElementExpression::Div(ref e) => write!(f, "{}", e), + FieldElementExpression::Pow(ref e) => write!(f, "{}", e), + FieldElementExpression::Neg(ref e) => write!(f, "{}", e), + FieldElementExpression::Pos(ref e) => write!(f, "{}", e), FieldElementExpression::Conditional(ref c) => write!(f, "{}", c), + FieldElementExpression::And(ref e) => write!(f, "{}", e), + FieldElementExpression::Or(ref e) => write!(f, "{}", e), + FieldElementExpression::Xor(ref e) => write!(f, "{}", e), + FieldElementExpression::LeftShift(ref e) => write!(f, "{}", e), + FieldElementExpression::RightShift(ref e) => write!(f, "{}", e), FieldElementExpression::FunctionCall(ref function_call) => { write!(f, "{}", function_call) } @@ -1755,22 +2120,22 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> { UExpressionInner::Block(ref block) => write!(f, "{}", block,), UExpressionInner::Value(ref v) => write!(f, "{}", v), UExpressionInner::Identifier(ref var) => write!(f, "{}", var), - UExpressionInner::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), - UExpressionInner::And(ref lhs, ref rhs) => write!(f, "({} & {})", lhs, rhs), - UExpressionInner::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs), - UExpressionInner::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs), - UExpressionInner::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs), - UExpressionInner::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), - UExpressionInner::FloorSub(ref lhs, ref rhs) => { - write!(f, "(FLOOR_SUB({}, {}))", lhs, rhs) + UExpressionInner::Add(ref e) => write!(f, "{}", e), + UExpressionInner::And(ref e) => write!(f, "{}", e), + UExpressionInner::Or(ref e) => write!(f, "{}", e), + UExpressionInner::Xor(ref e) => write!(f, "{}", e), + UExpressionInner::Sub(ref e) => write!(f, "{}", e), + UExpressionInner::Mult(ref e) => write!(f, "{}", e), + UExpressionInner::FloorSub(ref e) => { + write!(f, "{}", e) } - UExpressionInner::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs), - UExpressionInner::Rem(ref lhs, ref rhs) => write!(f, "({} % {})", lhs, rhs), - UExpressionInner::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by), - UExpressionInner::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by), - UExpressionInner::Not(ref e) => write!(f, "!{}", e), - UExpressionInner::Neg(ref e) => write!(f, "(-{})", e), - UExpressionInner::Pos(ref e) => write!(f, "(+{})", e), + UExpressionInner::Div(ref e) => write!(f, "{}", e), + UExpressionInner::Rem(ref e) => write!(f, "{}", e), + UExpressionInner::RightShift(ref e) => write!(f, "{}", e), + UExpressionInner::LeftShift(ref e) => write!(f, "{}", e), + UExpressionInner::Not(ref e) => write!(f, "{}", e), + UExpressionInner::Neg(ref e) => write!(f, "{}", e), + UExpressionInner::Pos(ref e) => write!(f, "{}", e), UExpressionInner::Select(ref select) => write!(f, "{}", select), UExpressionInner::FunctionCall(ref function_call) => write!(f, "{}", function_call), UExpressionInner::Conditional(ref c) => write!(f, "{}", c), @@ -1782,27 +2147,23 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { + match &self { BooleanExpression::Block(ref block) => write!(f, "{}", block,), BooleanExpression::Identifier(ref var) => write!(f, "{}", var), - BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "({} < {})", lhs, rhs), - BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "({} <= {})", lhs, rhs), - BooleanExpression::FieldGe(ref lhs, ref rhs) => write!(f, "({} >= {})", lhs, rhs), - BooleanExpression::FieldGt(ref lhs, ref rhs) => write!(f, "({} > {})", lhs, rhs), - BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "({} < {})", lhs, rhs), - BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "({} <= {})", lhs, rhs), - BooleanExpression::UintGe(ref lhs, ref rhs) => write!(f, "({} >= {})", lhs, rhs), - BooleanExpression::UintGt(ref lhs, ref rhs) => write!(f, "({} > {})", lhs, rhs), + BooleanExpression::FieldLt(ref e) => write!(f, "{}", e), + BooleanExpression::FieldLe(ref e) => write!(f, "{}", e), + BooleanExpression::UintLt(ref e) => write!(f, "{}", e), + BooleanExpression::UintLe(ref e) => write!(f, "{}", e), BooleanExpression::FieldEq(ref e) => write!(f, "{}", e), BooleanExpression::BoolEq(ref e) => write!(f, "{}", e), BooleanExpression::ArrayEq(ref e) => write!(f, "{}", e), BooleanExpression::StructEq(ref e) => write!(f, "{}", e), BooleanExpression::TupleEq(ref e) => write!(f, "{}", e), BooleanExpression::UintEq(ref e) => write!(f, "{}", e), - BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "({} || {})", lhs, rhs), - BooleanExpression::And(ref lhs, ref rhs) => write!(f, "({} && {})", lhs, rhs), - BooleanExpression::Not(ref exp) => write!(f, "!{}", exp), - BooleanExpression::Value(b) => write!(f, "{}", b), + BooleanExpression::Or(ref e) => write!(f, "{}", e), + BooleanExpression::And(ref e) => write!(f, "{}", e), + BooleanExpression::Not(ref exp) => write!(f, "{}", exp), + BooleanExpression::Value(ref b) => write!(f, "{}", b), BooleanExpression::FunctionCall(ref function_call) => write!(f, "{}", function_call), BooleanExpression::Conditional(ref c) => write!(f, "{}", c), BooleanExpression::Member(ref m) => write!(f, "{}", m), @@ -1830,12 +2191,8 @@ impl<'ast, T: fmt::Display> fmt::Display for ArrayExpressionInner<'ast, T> { ArrayExpressionInner::Conditional(ref c) => write!(f, "{}", c), ArrayExpressionInner::Member(ref m) => write!(f, "{}", m), ArrayExpressionInner::Select(ref select) => write!(f, "{}", select), - ArrayExpressionInner::Slice(ref a, ref from, ref to) => { - write!(f, "{}[{}..{}]", a, from, to) - } - ArrayExpressionInner::Repeat(ref e, ref count) => { - write!(f, "[{}; {}]", e, count) - } + ArrayExpressionInner::Slice(ref e) => write!(f, "{}", e), + ArrayExpressionInner::Repeat(ref e) => write!(f, "{}", e), ArrayExpressionInner::Element(ref element) => write!(f, "{}", element), } } @@ -1859,10 +2216,507 @@ impl<'ast, T: Field> From> for TypedExpression<'ast, T> { } } +// TODO: MACROS + +impl<'ast, T> WithSpan for TypedExpression<'ast, T> { + fn span(self, span: Option) -> Self { + use TypedExpression::*; + match self { + Boolean(e) => Boolean(e.span(span)), + FieldElement(e) => FieldElement(e.span(span)), + Uint(e) => Uint(e.span(span)), + Array(e) => Array(e.span(span)), + Struct(e) => Struct(e.span(span)), + Tuple(e) => Tuple(e.span(span)), + Int(e) => Int(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use TypedExpression::*; + match self { + Boolean(e) => e.get_span(), + FieldElement(e) => e.get_span(), + Uint(e) => e.get_span(), + Array(e) => e.get_span(), + Struct(e) => e.get_span(), + Tuple(e) => e.get_span(), + Int(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for FieldElementExpression<'ast, T> { + fn span(self, span: Option) -> Self { + use FieldElementExpression::*; + match self { + Select(e) => Select(e.span(span)), + Block(e) => Block(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + FunctionCall(e) => FunctionCall(e.span(span)), + Member(e) => Member(e.span(span)), + Element(e) => Element(e.span(span)), + Add(e) => Add(e.span(span)), + Value(e) => Value(e.span(span)), + Mult(e) => Mult(e.span(span)), + Sub(e) => Sub(e.span(span)), + Pow(e) => Pow(e.span(span)), + Div(e) => Div(e.span(span)), + Pos(e) => Pos(e.span(span)), + Neg(e) => Neg(e.span(span)), + And(e) => And(e.span(span)), + Or(e) => Or(e.span(span)), + Xor(e) => Xor(e.span(span)), + LeftShift(e) => LeftShift(e.span(span)), + RightShift(e) => RightShift(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use FieldElementExpression::*; + match self { + Select(e) => e.get_span(), + Block(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + FunctionCall(e) => e.get_span(), + Member(e) => e.get_span(), + Element(e) => e.get_span(), + Add(e) => e.get_span(), + Value(e) => e.get_span(), + Sub(e) => e.get_span(), + Mult(e) => e.get_span(), + Div(e) => e.get_span(), + Pow(e) => e.get_span(), + Neg(e) => e.get_span(), + Pos(e) => e.get_span(), + And(e) => e.get_span(), + Or(e) => e.get_span(), + Xor(e) => e.get_span(), + LeftShift(e) => e.get_span(), + RightShift(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for BooleanExpression<'ast, T> { + fn span(self, span: Option) -> Self { + use BooleanExpression::*; + match self { + Select(e) => Select(e.span(span)), + Block(e) => Block(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + FunctionCall(e) => FunctionCall(e.span(span)), + Member(e) => Member(e.span(span)), + Element(e) => Element(e.span(span)), + Value(e) => Value(e.span(span)), + FieldLt(e) => FieldLt(e.span(span)), + FieldLe(e) => FieldLe(e.span(span)), + UintLt(e) => UintLt(e.span(span)), + UintLe(e) => UintLe(e.span(span)), + FieldEq(e) => FieldEq(e.span(span)), + BoolEq(e) => BoolEq(e.span(span)), + ArrayEq(e) => ArrayEq(e.span(span)), + StructEq(e) => StructEq(e.span(span)), + TupleEq(e) => TupleEq(e.span(span)), + UintEq(e) => UintEq(e.span(span)), + Or(e) => Or(e.span(span)), + And(e) => And(e.span(span)), + Not(e) => Not(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use BooleanExpression::*; + match self { + Select(e) => e.get_span(), + Block(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + FunctionCall(e) => e.get_span(), + Member(e) => e.get_span(), + Element(e) => e.get_span(), + Value(e) => e.get_span(), + FieldLt(e) => e.get_span(), + FieldLe(e) => e.get_span(), + UintLt(e) => e.get_span(), + UintLe(e) => e.get_span(), + FieldEq(e) => e.get_span(), + BoolEq(e) => e.get_span(), + ArrayEq(e) => e.get_span(), + StructEq(e) => e.get_span(), + TupleEq(e) => e.get_span(), + UintEq(e) => e.get_span(), + Or(e) => e.get_span(), + And(e) => e.get_span(), + Not(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for UExpressionInner<'ast, T> { + fn span(self, span: Option) -> Self { + use UExpressionInner::*; + match self { + Select(e) => Select(e.span(span)), + Block(e) => Block(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + FunctionCall(e) => FunctionCall(e.span(span)), + Member(e) => Member(e.span(span)), + Element(e) => Element(e.span(span)), + Value(e) => Value(e.span(span)), + Add(e) => Add(e.span(span)), + Sub(e) => Sub(e.span(span)), + FloorSub(e) => FloorSub(e.span(span)), + Mult(e) => Mult(e.span(span)), + Div(e) => Div(e.span(span)), + Rem(e) => Rem(e.span(span)), + Xor(e) => Xor(e.span(span)), + And(e) => And(e.span(span)), + Or(e) => Or(e.span(span)), + Not(e) => Not(e.span(span)), + Neg(e) => Neg(e.span(span)), + Pos(e) => Pos(e.span(span)), + LeftShift(e) => LeftShift(e.span(span)), + RightShift(e) => RightShift(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use UExpressionInner::*; + match self { + Select(e) => e.get_span(), + Block(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + FunctionCall(e) => e.get_span(), + Member(e) => e.get_span(), + Element(e) => e.get_span(), + Value(e) => e.get_span(), + Add(e) => e.get_span(), + Sub(e) => e.get_span(), + FloorSub(e) => e.get_span(), + Mult(e) => e.get_span(), + Div(e) => e.get_span(), + Rem(e) => e.get_span(), + Xor(e) => e.get_span(), + And(e) => e.get_span(), + Or(e) => e.get_span(), + Not(e) => e.get_span(), + Neg(e) => e.get_span(), + Pos(e) => e.get_span(), + LeftShift(e) => e.get_span(), + RightShift(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for ArrayExpressionInner<'ast, T> { + fn span(self, span: Option) -> Self { + use ArrayExpressionInner::*; + match self { + Select(e) => Select(e.span(span)), + Block(e) => Block(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + FunctionCall(e) => FunctionCall(e.span(span)), + Member(e) => Member(e.span(span)), + Element(e) => Element(e.span(span)), + Value(e) => Value(e.span(span)), + Slice(e) => Slice(e.span(span)), + Repeat(e) => Repeat(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use ArrayExpressionInner::*; + match self { + Select(e) => e.get_span(), + Block(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + FunctionCall(e) => e.get_span(), + Member(e) => e.get_span(), + Element(e) => e.get_span(), + Value(e) => e.get_span(), + Slice(e) => e.get_span(), + Repeat(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for StructExpressionInner<'ast, T> { + fn span(self, span: Option) -> Self { + use StructExpressionInner::*; + match self { + Select(e) => Select(e.span(span)), + Block(e) => Block(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + FunctionCall(e) => FunctionCall(e.span(span)), + Member(e) => Member(e.span(span)), + Element(e) => Element(e.span(span)), + Value(e) => Value(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use StructExpressionInner::*; + match self { + Select(e) => e.get_span(), + Block(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + FunctionCall(e) => e.get_span(), + Member(e) => e.get_span(), + Element(e) => e.get_span(), + Value(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for TupleExpressionInner<'ast, T> { + fn span(self, span: Option) -> Self { + use TupleExpressionInner::*; + match self { + Select(e) => Select(e.span(span)), + Block(e) => Block(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + FunctionCall(e) => FunctionCall(e.span(span)), + Member(e) => Member(e.span(span)), + Element(e) => Element(e.span(span)), + Value(e) => Value(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use TupleExpressionInner::*; + match self { + Select(e) => e.get_span(), + Block(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + FunctionCall(e) => e.get_span(), + Member(e) => e.get_span(), + Element(e) => e.get_span(), + Value(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for IntExpression<'ast, T> { + fn span(self, span: Option) -> Self { + use IntExpression::*; + match self { + Conditional(e) => Conditional(e.span(span)), + Select(e) => Select(e.span(span)), + Value(e) => Value(e.span(span)), + Pos(e) => Pos(e.span(span)), + Neg(e) => Neg(e.span(span)), + Not(e) => Not(e.span(span)), + Add(e) => Add(e.span(span)), + Sub(e) => Sub(e.span(span)), + Mult(e) => Mult(e.span(span)), + Div(e) => Div(e.span(span)), + Rem(e) => Rem(e.span(span)), + Pow(e) => Pow(e.span(span)), + Xor(e) => Xor(e.span(span)), + And(e) => And(e.span(span)), + Or(e) => Or(e.span(span)), + LeftShift(e) => LeftShift(e.span(span)), + RightShift(e) => RightShift(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use IntExpression::*; + match self { + Conditional(e) => e.get_span(), + Select(e) => e.get_span(), + Value(e) => e.get_span(), + Pos(e) => e.get_span(), + Neg(e) => e.get_span(), + Not(e) => e.get_span(), + Add(e) => e.get_span(), + Sub(e) => e.get_span(), + Mult(e) => e.get_span(), + Div(e) => e.get_span(), + Rem(e) => e.get_span(), + Pow(e) => e.get_span(), + Xor(e) => e.get_span(), + And(e) => e.get_span(), + Or(e) => e.get_span(), + LeftShift(e) => e.get_span(), + RightShift(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for TupleExpression<'ast, T> { + fn span(self, span: Option) -> Self { + TupleExpression { + inner: self.inner.span(span), + ..self + } + } + + fn get_span(&self) -> Option { + self.inner.get_span() + } +} + +impl<'ast, T> WithSpan for StructExpression<'ast, T> { + fn span(self, span: Option) -> Self { + StructExpression { + inner: self.inner.span(span), + ..self + } + } + + fn get_span(&self) -> Option { + self.inner.get_span() + } +} + +impl<'ast, T> WithSpan for ArrayExpression<'ast, T> { + fn span(self, span: Option) -> Self { + ArrayExpression { + inner: self.inner.span(span), + ..self + } + } + + fn get_span(&self) -> Option { + self.inner.get_span() + } +} + +impl<'ast, T> WithSpan for UExpression<'ast, T> { + fn span(self, span: Option) -> Self { + UExpression { + inner: self.inner.span(span), + ..self + } + } + + fn get_span(&self) -> Option { + self.inner.get_span() + } +} + +impl<'ast, T, E> WithSpan for ConditionalExpression<'ast, T, E> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl WithSpan for ValueExpression { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl<'ast, T, E> WithSpan for SelectExpression<'ast, T, E> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl<'ast, T, E> WithSpan for ElementExpression<'ast, T, E> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl<'ast, T, E> WithSpan for MemberExpression<'ast, T, E> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl<'ast, T, E> WithSpan for FunctionCallExpression<'ast, T, E> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl<'ast, T, E> WithSpan for BlockExpression<'ast, T, E> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl<'ast, T: Clone> Value for FieldElementExpression<'ast, T> { + type Value = T; +} + +impl<'ast, T> Value for BooleanExpression<'ast, T> { + type Value = bool; +} + +impl<'ast, T> Value for UExpression<'ast, T> { + type Value = u128; +} + +impl<'ast, T: Clone> Value for ArrayExpression<'ast, T> { + type Value = Vec>; +} + +impl<'ast, T: Clone> Value for StructExpression<'ast, T> { + type Value = Vec>; +} + +impl<'ast, T: Clone> Value for TupleExpression<'ast, T> { + type Value = Vec>; +} + +impl<'ast, T> Value for IntExpression<'ast, T> { + type Value = BigUint; +} + // Common behaviour across expressions -pub trait Expr<'ast, T>: fmt::Display + From> { - type Inner; +pub trait Expr<'ast, T>: Value + WithSpan + fmt::Display + From> { + type Inner: WithSpan; type Ty: Clone + IntoType>; type ConcreteTy: Clone + IntoType; @@ -1873,9 +2727,11 @@ pub trait Expr<'ast, T>: fmt::Display + From> { fn as_inner(&self) -> &Self::Inner; fn as_inner_mut(&mut self) -> &mut Self::Inner; + + fn value(_: Self::Value) -> Self::Inner; } -impl<'ast, T: Field> Expr<'ast, T> for FieldElementExpression<'ast, T> { +impl<'ast, T: fmt::Display + Clone> Expr<'ast, T> for FieldElementExpression<'ast, T> { type Inner = Self; type Ty = Type<'ast, T>; type ConcreteTy = ConcreteType; @@ -1895,9 +2751,13 @@ impl<'ast, T: Field> Expr<'ast, T> for FieldElementExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { self } + + fn value(v: ::Value) -> Self::Inner { + Self::Value(ValueExpression::new(v)) + } } -impl<'ast, T: Field> Expr<'ast, T> for BooleanExpression<'ast, T> { +impl<'ast, T: fmt::Display + Clone> Expr<'ast, T> for BooleanExpression<'ast, T> { type Inner = Self; type Ty = Type<'ast, T>; type ConcreteTy = ConcreteType; @@ -1917,9 +2777,13 @@ impl<'ast, T: Field> Expr<'ast, T> for BooleanExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { self } + + fn value(v: ::Value) -> Self::Inner { + Self::Value(ValueExpression::new(v)) + } } -impl<'ast, T: Field> Expr<'ast, T> for UExpression<'ast, T> { +impl<'ast, T: fmt::Display + Clone> Expr<'ast, T> for UExpression<'ast, T> { type Inner = UExpressionInner<'ast, T>; type Ty = UBitwidth; type ConcreteTy = UBitwidth; @@ -1939,9 +2803,13 @@ impl<'ast, T: Field> Expr<'ast, T> for UExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { &mut self.inner } + + fn value(v: Self::Value) -> Self::Inner { + UExpressionInner::Value(ValueExpression::new(v)) + } } -impl<'ast, T: Field> Expr<'ast, T> for StructExpression<'ast, T> { +impl<'ast, T: fmt::Display + Clone> Expr<'ast, T> for StructExpression<'ast, T> { type Inner = StructExpressionInner<'ast, T>; type Ty = StructType<'ast, T>; type ConcreteTy = ConcreteStructType; @@ -1961,9 +2829,13 @@ impl<'ast, T: Field> Expr<'ast, T> for StructExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { &mut self.inner } + + fn value(v: Self::Value) -> Self::Inner { + StructExpressionInner::Value(ValueExpression::new(v)) + } } -impl<'ast, T: Field> Expr<'ast, T> for ArrayExpression<'ast, T> { +impl<'ast, T: fmt::Display + Clone> Expr<'ast, T> for ArrayExpression<'ast, T> { type Inner = ArrayExpressionInner<'ast, T>; type Ty = ArrayType<'ast, T>; type ConcreteTy = ConcreteArrayType; @@ -1983,9 +2855,13 @@ impl<'ast, T: Field> Expr<'ast, T> for ArrayExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { &mut self.inner } + + fn value(v: Self::Value) -> Self::Inner { + ArrayExpressionInner::Value(ValueExpression::new(v)) + } } -impl<'ast, T: Field> Expr<'ast, T> for TupleExpression<'ast, T> { +impl<'ast, T: fmt::Display + Clone> Expr<'ast, T> for TupleExpression<'ast, T> { type Inner = TupleExpressionInner<'ast, T>; type Ty = TupleType<'ast, T>; type ConcreteTy = ConcreteTupleType; @@ -2005,9 +2881,13 @@ impl<'ast, T: Field> Expr<'ast, T> for TupleExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { &mut self.inner } + + fn value(v: Self::Value) -> Self::Inner { + TupleExpressionInner::Value(ValueExpression::new(v)) + } } -impl<'ast, T: Field> Expr<'ast, T> for IntExpression<'ast, T> { +impl<'ast, T: fmt::Display + Clone> Expr<'ast, T> for IntExpression<'ast, T> { type Inner = Self; type Ty = Type<'ast, T>; type ConcreteTy = ConcreteType; @@ -2027,6 +2907,10 @@ impl<'ast, T: Field> Expr<'ast, T> for IntExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { self } + + fn value(v: as Value>::Value) -> Self { + IntExpression::Value(ValueExpression::new(v)) + } } // Enums types to enable returning e.g a member expression OR another type of expression of this type @@ -2039,20 +2923,19 @@ pub enum SelectOrExpression<'ast, T, E: Expr<'ast, T>> { Select(SelectExpression<'ast, T, E>), Expression(E::Inner), } - -pub enum EqOrBoolean<'ast, T, E> { - Eq(EqExpression), - Boolean(BooleanExpression<'ast, T>), -} - pub enum MemberOrExpression<'ast, T, E: Expr<'ast, T>> { Member(MemberExpression<'ast, T, E>), Expression(E::Inner), } -pub enum IdentifierOrExpression<'ast, T, E: Expr<'ast, T>> { - Identifier(IdentifierExpression<'ast, E>), - Expression(E::Inner), +pub enum RepeatOrExpression<'ast, T> { + Repeat(RepeatExpression<'ast, T>), + Expression(ArrayExpressionInner<'ast, T>), +} + +pub enum SliceOrExpression<'ast, T> { + Slice(SliceExpression<'ast, T>), + Expression(ArrayExpressionInner<'ast, T>), } pub enum ElementOrExpression<'ast, T, E: Expr<'ast, T>> { @@ -2065,6 +2948,22 @@ pub enum ConditionalOrExpression<'ast, T, E: Expr<'ast, T>> { Expression(E::Inner), } +pub trait Basic<'ast, T> { + type ZirExpressionType: WithSpan + Value + From>; +} + +impl<'ast, T: Clone> Basic<'ast, T> for FieldElementExpression<'ast, T> { + type ZirExpressionType = crate::zir::FieldElementExpression<'ast, T>; +} + +impl<'ast, T> Basic<'ast, T> for BooleanExpression<'ast, T> { + type ZirExpressionType = crate::zir::BooleanExpression<'ast, T>; +} + +impl<'ast, T> Basic<'ast, T> for UExpression<'ast, T> { + type ZirExpressionType = crate::zir::UExpression<'ast, T>; +} + pub trait Conditional<'ast, T> { fn conditional( condition: BooleanExpression<'ast, T>, @@ -2576,7 +3475,7 @@ pub trait Constant: Sized { impl<'ast, T: Field> Constant for FieldElementExpression<'ast, T> { fn is_constant(&self) -> bool { - matches!(self, FieldElementExpression::Number(..)) + matches!(self, FieldElementExpression::Value(..)) } } @@ -2595,16 +3494,14 @@ impl<'ast, T: Field> Constant for UExpression<'ast, T> { impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> { fn is_constant(&self) -> bool { match self.as_inner() { - ArrayExpressionInner::Value(v) => v.0.iter().all(|e| match e { + ArrayExpressionInner::Value(v) => v.iter().all(|e| match e { TypedExpressionOrSpread::Expression(e) => e.is_constant(), TypedExpressionOrSpread::Spread(s) => s.array.is_constant(), }), - ArrayExpressionInner::Slice(box a, box from, box to) => { - from.is_constant() && to.is_constant() && a.is_constant() - } - ArrayExpressionInner::Repeat(box e, box count) => { - count.is_constant() && e.is_constant() + ArrayExpressionInner::Slice(e) => { + e.from.is_constant() && e.to.is_constant() && e.array.is_constant() } + ArrayExpressionInner::Repeat(e) => e.count.is_constant() && e.e.is_constant(), _ => false, } } @@ -2620,35 +3517,35 @@ impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> { .into_iter() .flat_map(into_canonical_constant_aux) .collect(), - ArrayExpressionInner::Slice(box v, box from, box to) => { - let from = match from.into_inner() { + ArrayExpressionInner::Slice(e) => { + let from = match e.from.into_inner() { UExpressionInner::Value(v) => v, _ => unreachable!(), }; - let to = match to.into_inner() { + let to = match e.to.into_inner() { UExpressionInner::Value(v) => v, _ => unreachable!(), }; - let v = match v.into_inner() { + let v = match e.array.into_inner() { ArrayExpressionInner::Value(v) => v, _ => unreachable!(), }; v.into_iter() .flat_map(into_canonical_constant_aux) - .skip(from as usize) - .take(to as usize - from as usize) + .skip(from.value as usize) + .take(to.value as usize - from.value as usize) .collect() } - ArrayExpressionInner::Repeat(box e, box count) => { - let count = match count.into_inner() { + ArrayExpressionInner::Repeat(e) => { + let count = match e.count.into_inner() { UExpressionInner::Value(count) => count, _ => unreachable!(), }; - vec![e.into_canonical_constant(); count as usize] + vec![e.e.into_canonical_constant(); count.value as usize] } a => unreachable!("{}", a), }, @@ -2658,53 +3555,49 @@ impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> { let array_ty = self.ty().clone(); match self.into_inner() { - ArrayExpressionInner::Value(v) => ArrayExpressionInner::Value( + ArrayExpressionInner::Value(v) => ArrayExpression::value( v.into_iter() .flat_map(into_canonical_constant_aux) .map(|e| e.into()) - .collect::>() - .into(), + .collect::>(), ) .annotate(*array_ty.ty, *array_ty.size), - ArrayExpressionInner::Slice(box a, box from, box to) => { - let from = match from.into_inner() { - UExpressionInner::Value(from) => from as usize, + ArrayExpressionInner::Slice(e) => { + let from = match e.from.into_inner() { + UExpressionInner::Value(from) => from.value as usize, _ => unreachable!("should be a uint value"), }; - let to = match to.into_inner() { - UExpressionInner::Value(to) => to as usize, + let to = match e.to.into_inner() { + UExpressionInner::Value(to) => to.value as usize, _ => unreachable!("should be a uint value"), }; - let v = match a.into_inner() { + let v = match e.array.into_inner() { ArrayExpressionInner::Value(v) => v, _ => unreachable!("should be an array value"), }; - ArrayExpressionInner::Value( + ArrayExpression::value( v.into_iter() .flat_map(into_canonical_constant_aux) .map(|e| e.into()) .skip(from) .take(to - from) - .collect::>() - .into(), + .collect::>(), ) .annotate(*array_ty.ty, *array_ty.size) } - ArrayExpressionInner::Repeat(box e, box count) => { - let count = match count.into_inner() { - UExpressionInner::Value(from) => from as usize, + ArrayExpressionInner::Repeat(e) => { + let count = match e.count.into_inner() { + UExpressionInner::Value(from) => from.value as usize, _ => unreachable!("should be a uint value"), }; - let e = e.into_canonical_constant(); + let e = e.e.into_canonical_constant(); - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression(e); count].into(), - ) - .annotate(*array_ty.ty, *array_ty.size) + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression(e); count]) + .annotate(*array_ty.ty, *array_ty.size) } _ => unreachable!(), } @@ -2723,7 +3616,7 @@ impl<'ast, T: Field> Constant for StructExpression<'ast, T> { let struct_ty = self.ty().clone(); match self.into_inner() { - StructExpressionInner::Value(expressions) => StructExpressionInner::Value( + StructExpressionInner::Value(expressions) => StructExpression::value( expressions .into_iter() .map(|e| e.into_canonical_constant()) @@ -2747,7 +3640,7 @@ impl<'ast, T: Field> Constant for TupleExpression<'ast, T> { let tuple_ty = self.ty().clone(); match self.into_inner() { - TupleExpressionInner::Value(expressions) => TupleExpressionInner::Value( + TupleExpressionInner::Value(expressions) => TupleExpression::value( expressions .into_iter() .map(|e| e.into_canonical_constant()) diff --git a/zokrates_ast/src/typed/parameter.rs b/zokrates_ast/src/typed/parameter.rs index 45b0dcae4..4b8332eb1 100644 --- a/zokrates_ast/src/typed/parameter.rs +++ b/zokrates_ast/src/typed/parameter.rs @@ -1,33 +1,5 @@ use crate::typed::types::DeclarationConstant; use crate::typed::GVariable; -use std::fmt; - -#[derive(Clone, PartialEq, Eq, Hash)] -pub struct GParameter<'ast, S> { - pub id: GVariable<'ast, S>, - pub private: bool, -} - -impl<'ast, S> From> for GParameter<'ast, S> { - fn from(v: GVariable<'ast, S>) -> Self { - GParameter { - id: v, - private: true, - } - } -} +pub type GParameter<'ast, S> = crate::common::Parameter>; pub type DeclarationParameter<'ast, T> = GParameter<'ast, DeclarationConstant<'ast, T>>; - -impl<'ast, S: fmt::Display> fmt::Display for GParameter<'ast, S> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let visibility = if self.private { "private " } else { "" }; - write!(f, "{}{} {}", visibility, self.id._type, self.id.id) - } -} - -impl<'ast, S: fmt::Debug> fmt::Debug for GParameter<'ast, S> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Parameter(variable: {:?})", self.id) - } -} diff --git a/zokrates_ast/src/typed/result_folder.rs b/zokrates_ast/src/typed/result_folder.rs index 8ed911314..531d45275 100644 --- a/zokrates_ast/src/typed/result_folder.rs +++ b/zokrates_ast/src/typed/result_folder.rs @@ -1,47 +1,57 @@ // Generic walk through a typed AST. Not mutating in place +use crate::common::expressions::{ + BinaryOrExpression, EqExpression, UnaryOrExpression, ValueOrExpression, +}; +use crate::common::ResultFold; use crate::typed::types::*; use crate::typed::*; use zokrates_field::Field; use super::identifier::FrameIdentifier; -pub trait ResultFold<'ast, T: Field>: Sized { - fn fold>(self, f: &mut F) -> Result; -} - -impl<'ast, T: Field> ResultFold<'ast, T> for FieldElementExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold + for FieldElementExpression<'ast, T> +{ + fn fold(self, f: &mut F) -> Result { f.fold_field_expression(self) } } -impl<'ast, T: Field> ResultFold<'ast, T> for BooleanExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold + for BooleanExpression<'ast, T> +{ + fn fold(self, f: &mut F) -> Result { f.fold_boolean_expression(self) } } -impl<'ast, T: Field> ResultFold<'ast, T> for UExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold for UExpression<'ast, T> { + fn fold(self, f: &mut F) -> Result { f.fold_uint_expression(self) } } -impl<'ast, T: Field> ResultFold<'ast, T> for ArrayExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { - f.fold_array_expression(self) +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold + for StructExpression<'ast, T> +{ + fn fold(self, f: &mut F) -> Result { + f.fold_struct_expression(self) } } -impl<'ast, T: Field> ResultFold<'ast, T> for StructExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { - f.fold_struct_expression(self) +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold + for ArrayExpression<'ast, T> +{ + fn fold(self, f: &mut F) -> Result { + f.fold_array_expression(self) } } -impl<'ast, T: Field> ResultFold<'ast, T> for TupleExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold + for TupleExpression<'ast, T> +{ + fn fold(self, f: &mut F) -> Result { f.fold_tuple_expression(self) } } @@ -153,7 +163,7 @@ pub trait ResultFolder<'ast, T: Field>: Sized { }) } - fn fold_module_id(&mut self, i: OwnedTypedModuleId) -> Result { + fn fold_module_id(&mut self, i: OwnedModuleId) -> Result { Ok(i) } @@ -170,22 +180,19 @@ pub trait ResultFolder<'ast, T: Field>: Sized { } fn fold_variable(&mut self, v: Variable<'ast, T>) -> Result, Self::Error> { - Ok(Variable { - id: self.fold_name(v.id)?, - _type: self.fold_type(v._type)?, - is_mutable: v.is_mutable, - }) + let span = v.get_span(); + Ok(Variable::new(self.fold_name(v.id)?, self.fold_type(v.ty)?).span(span)) } fn fold_declaration_variable( &mut self, v: DeclarationVariable<'ast, T>, ) -> Result, Self::Error> { - Ok(DeclarationVariable { - id: self.fold_name(v.id)?, - _type: self.fold_declaration_type(v._type)?, - is_mutable: v.is_mutable, - }) + let span = v.get_span(); + Ok( + DeclarationVariable::new(self.fold_name(v.id)?, self.fold_declaration_type(v.ty)?) + .span(span), + ) } fn fold_type(&mut self, t: Type<'ast, T>) -> Result, Self::Error> { @@ -200,7 +207,7 @@ pub trait ResultFolder<'ast, T: Field>: Sized { } fn fold_conditional_expression< - E: Expr<'ast, T> + PartialEq + Conditional<'ast, T> + ResultFold<'ast, T>, + E: Expr<'ast, T> + PartialEq + Conditional<'ast, T> + ResultFold, >( &mut self, ty: &E::Ty, @@ -209,14 +216,55 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_conditional_expression(self, ty, e) } - fn fold_block_expression>( + #[allow(clippy::type_complexity)] + fn fold_binary_expression< + L: Expr<'ast, T> + PartialEq + ResultFold, + R: Expr<'ast, T> + PartialEq + ResultFold, + E: Expr<'ast, T> + PartialEq + ResultFold, + Op: OperatorStr, + >( + &mut self, + ty: &E::Ty, + e: BinaryExpression, + ) -> Result, Self::Error> { + fold_binary_expression(self, ty, e) + } + + #[allow(clippy::type_complexity)] + fn fold_eq_expression< + E: Expr<'ast, T> + Constant + Typed<'ast, T> + PartialEq + ResultFold, + >( + &mut self, + e: EqExpression>, + ) -> Result< + BinaryOrExpression, BooleanExpression<'ast, T>>, + Self::Error, + > { + fold_binary_expression(self, &Type::Boolean, e) + } + + fn fold_unary_expression< + In: Expr<'ast, T> + PartialEq + ResultFold, + E: Expr<'ast, T> + PartialEq + ResultFold, + Op, + >( + &mut self, + ty: &E::Ty, + e: UnaryExpression, + ) -> Result, Self::Error> { + fold_unary_expression(self, ty, e) + } + + fn fold_block_expression>( &mut self, block: BlockExpression<'ast, T, E>, ) -> Result, Self::Error> { fold_block_expression(self, block) } - fn fold_identifier_expression + Id<'ast, T> + ResultFold<'ast, T>>( + fn fold_identifier_expression< + E: Expr<'ast, T> + Id<'ast, T> + ResultFold, + >( &mut self, ty: &E::Ty, id: IdentifierExpression<'ast, E>, @@ -234,6 +282,20 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_member_expression(self, ty, e) } + fn fold_slice_expression( + &mut self, + e: SliceExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_slice_expression(self, e) + } + + fn fold_repeat_expression( + &mut self, + e: RepeatExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_repeat_expression(self, e) + } + fn fold_element_expression< E: Expr<'ast, T> + Element<'ast, T> + From>, >( @@ -244,15 +306,6 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_element_expression(self, ty, e) } - fn fold_eq_expression< - E: Expr<'ast, T> + Typed<'ast, T> + PartialEq + Constant + ResultFold<'ast, T>, - >( - &mut self, - e: EqExpression, - ) -> Result, Self::Error> { - fold_eq_expression(self, e) - } - fn fold_select_expression< E: Expr<'ast, T> + Select<'ast, T> @@ -389,6 +442,27 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_assignee(self, a) } + fn fold_assembly_block( + &mut self, + s: AssemblyBlockStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_block(self, s) + } + + fn fold_assembly_assignment( + &mut self, + s: AssemblyAssignment<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_assignment(self, s) + } + + fn fold_assembly_constraint( + &mut self, + s: AssemblyConstraint<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_constraint(self, s) + } + fn fold_assembly_statement( &mut self, s: TypedAssemblyStatement<'ast, T>, @@ -396,6 +470,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_assembly_statement(self, s) } + fn fold_assembly_statement_cases( + &mut self, + s: TypedAssemblyStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_statement_cases(self, s) + } + fn fold_statement( &mut self, s: TypedStatement<'ast, T>, @@ -403,6 +484,48 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_statement(self, s) } + fn fold_statement_cases( + &mut self, + s: TypedStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_statement_cases(self, s) + } + + fn fold_definition_statement( + &mut self, + s: DefinitionStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_definition_statement(self, s) + } + + fn fold_return_statement( + &mut self, + s: ReturnStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_return_statement(self, s) + } + + fn fold_assertion_statement( + &mut self, + s: AssertionStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_assertion_statement(self, s) + } + + fn fold_log_statement( + &mut self, + s: LogStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_log_statement(self, s) + } + + fn fold_for_statement( + &mut self, + s: ForStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_for_statement(self, s) + } + fn fold_definition_rhs( &mut self, rhs: DefinitionRhs<'ast, T>, @@ -481,12 +604,28 @@ pub trait ResultFolder<'ast, T: Field>: Sized { ) -> Result, Self::Error> { fold_field_expression(self, e) } + + fn fold_field_expression_cases( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_field_expression_cases(self, e) + } + fn fold_boolean_expression( &mut self, e: BooleanExpression<'ast, T>, ) -> Result, Self::Error> { fold_boolean_expression(self, e) } + + fn fold_boolean_expression_cases( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_boolean_expression_cases(self, e) + } + fn fold_uint_expression( &mut self, e: UExpression<'ast, T>, @@ -502,6 +641,14 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_uint_expression_inner(self, bitwidth, e) } + fn fold_uint_expression_cases( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + fold_uint_expression_cases(self, bitwidth, e) + } + fn fold_array_expression_inner( &mut self, ty: &ArrayType<'ast, T>, @@ -509,6 +656,15 @@ pub trait ResultFolder<'ast, T: Field>: Sized { ) -> Result, Self::Error> { fold_array_expression_inner(self, ty, e) } + + fn fold_array_expression_cases( + &mut self, + ty: &ArrayType<'ast, T>, + e: ArrayExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + fold_array_expression_cases(self, ty, e) + } + fn fold_struct_expression_inner( &mut self, ty: &StructType<'ast, T>, @@ -517,6 +673,14 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_struct_expression_inner(self, ty, e) } + fn fold_struct_expression_cases( + &mut self, + ty: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + fold_struct_expression_cases(self, ty, e) + } + fn fold_tuple_expression_inner( &mut self, ty: &TupleType<'ast, T>, @@ -524,69 +688,186 @@ pub trait ResultFolder<'ast, T: Field>: Sized { ) -> Result, Self::Error> { fold_tuple_expression_inner(self, ty, e) } + + fn fold_tuple_expression_cases( + &mut self, + ty: &TupleType<'ast, T>, + e: TupleExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + fold_tuple_expression_cases(self, ty, e) + } + + fn fold_struct_value_expression( + &mut self, + ty: &StructType<'ast, T>, + v: StructValueExpression<'ast, T>, + ) -> Result< + ValueOrExpression, StructExpressionInner<'ast, T>>, + Self::Error, + > { + fold_struct_value_expression(self, ty, v) + } + + fn fold_array_value_expression( + &mut self, + ty: &ArrayType<'ast, T>, + v: ArrayValueExpression<'ast, T>, + ) -> Result< + ValueOrExpression, ArrayExpressionInner<'ast, T>>, + Self::Error, + > { + fold_array_value_expression(self, ty, v) + } + + fn fold_tuple_value_expression( + &mut self, + ty: &TupleType<'ast, T>, + v: TupleValueExpression<'ast, T>, + ) -> Result< + ValueOrExpression, TupleExpressionInner<'ast, T>>, + Self::Error, + > { + fold_tuple_value_expression(self, ty, v) + } } -pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( +pub fn fold_statement_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: TypedStatement<'ast, T>, +) -> Result>, F::Error> { + match s { + TypedStatement::Return(s) => f.fold_return_statement(s), + TypedStatement::Definition(s) => f.fold_definition_statement(s), + TypedStatement::Assertion(s) => f.fold_assertion_statement(s), + TypedStatement::For(s) => f.fold_for_statement(s), + TypedStatement::Log(s) => f.fold_log_statement(s), + TypedStatement::Assembly(s) => f.fold_assembly_block(s), + } +} + +pub fn fold_assembly_block<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: AssemblyBlockStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![TypedStatement::Assembly(AssemblyBlockStatement::new( + s.inner + .into_iter() + .map(|s| f.fold_assembly_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), + ))]) +} + +pub fn fold_assembly_assignment<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: AssemblyAssignment<'ast, T>, +) -> Result>, F::Error> { + let assignee = f.fold_assignee(s.assignee)?; + let expression = f.fold_expression(s.expression)?; + Ok(vec![TypedAssemblyStatement::assignment( + assignee, expression, + )]) +} + +pub fn fold_assembly_constraint<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: AssemblyConstraint<'ast, T>, +) -> Result>, F::Error> { + let left = f.fold_field_expression(s.left)?; + let right = f.fold_field_expression(s.right)?; + Ok(vec![TypedAssemblyStatement::constraint( + left, right, s.metadata, + )]) +} + +fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, s: TypedAssemblyStatement<'ast, T>, ) -> Result>, F::Error> { - Ok(match s { - TypedAssemblyStatement::Assignment(a, e) => { - let e = f.fold_expression(e)?; - vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a)?, e)] - } - TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { - vec![TypedAssemblyStatement::Constraint( - f.fold_field_expression(lhs)?, - f.fold_field_expression(rhs)?, - metadata, - )] - } - }) + let span = s.get_span(); + f.fold_assembly_statement_cases(s) + .map(|s| s.into_iter().map(|s| s.span(span)).collect()) +} + +pub fn fold_assembly_statement_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: TypedAssemblyStatement<'ast, T>, +) -> Result>, F::Error> { + match s { + TypedAssemblyStatement::Assignment(s) => f.fold_assembly_assignment(s), + TypedAssemblyStatement::Constraint(s) => f.fold_assembly_constraint(s), + } } pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, s: TypedStatement<'ast, T>, ) -> Result>, F::Error> { - let res = match s { - TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)?), - TypedStatement::Definition(a, e) => { - let rhs = f.fold_definition_rhs(e)?; - TypedStatement::Definition(f.fold_assignee(a)?, rhs) - } - TypedStatement::Assertion(e, error) => { - TypedStatement::Assertion(f.fold_boolean_expression(e)?, error) - } - TypedStatement::For(v, from, to, statements) => TypedStatement::For( - f.fold_variable(v)?, - f.fold_uint_expression(from)?, - f.fold_uint_expression(to)?, - statements - .into_iter() - .map(|s| f.fold_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(), - ), - TypedStatement::Log(s, e) => TypedStatement::Log( - s, - e.into_iter() - .map(|e| f.fold_expression(e)) - .collect::, _>>()?, - ), - TypedStatement::Assembly(statements) => TypedStatement::Assembly( - statements - .into_iter() - .map(|s| f.fold_assembly_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(), - ), - }; - Ok(vec![res]) + let span = s.get_span(); + f.fold_statement_cases(s) + .map(|s| s.into_iter().map(|s| s.span(span)).collect()) +} + +pub fn fold_return_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: ReturnStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![TypedStatement::Return(ReturnStatement::new( + f.fold_expression(s.inner)?, + ))]) +} + +pub fn fold_definition_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: DefinitionStatement<'ast, T>, +) -> Result>, F::Error> { + let rhs = f.fold_definition_rhs(s.rhs)?; + Ok(vec![TypedStatement::Definition(DefinitionStatement::new( + f.fold_assignee(s.assignee)?, + rhs, + ))]) +} + +pub fn fold_assertion_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: AssertionStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![TypedStatement::Assertion( + AssertionStatement::new(f.fold_boolean_expression(s.expression)?, s.error).span(s.span), + )]) +} + +pub fn fold_for_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: ForStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![TypedStatement::For(ForStatement::new( + f.fold_variable(s.var)?, + f.fold_uint_expression(s.from)?, + f.fold_uint_expression(s.to)?, + s.statements + .into_iter() + .map(|s| f.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), + ))]) +} + +pub fn fold_log_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: LogStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![TypedStatement::Log(LogStatement::new( + s.format_string, + s.expressions + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?, + ))]) } pub fn fold_definition_rhs<'ast, T: Field, F: ResultFolder<'ast, T>>( @@ -613,7 +894,16 @@ pub fn fold_embed_call<'ast, T: Field, F: ResultFolder<'ast, T>>( }) } -pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( +fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + ty: &ArrayType<'ast, T>, + e: ArrayExpressionInner<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_array_expression_cases(ty, e).map(|e| e.span(span)) +} + +pub fn fold_array_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, ty: &ArrayType<'ast, T>, e: ArrayExpressionInner<'ast, T>, @@ -626,16 +916,14 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( IdentifierOrExpression::Identifier(i) => ArrayExpressionInner::Identifier(i), IdentifierOrExpression::Expression(u) => u, }, - Value(exprs) => Value( - exprs - .into_iter() - .map(|e| f.fold_expression_or_spread(e)) - .collect::>()?, - ), FunctionCall(function_call) => match f.fold_function_call_expression(ty, function_call)? { FunctionCallOrExpression::FunctionCall(c) => FunctionCall(c), FunctionCallOrExpression::Expression(u) => u, }, + Value(value) => match f.fold_array_value_expression(ty, value)? { + ValueOrExpression::Value(c) => Value(c), + ValueOrExpression::Expression(u) => u, + }, Conditional(c) => match f.fold_conditional_expression(ty, c)? { ConditionalOrExpression::Conditional(c) => Conditional(c), ConditionalOrExpression::Expression(u) => u, @@ -648,17 +936,14 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( SelectOrExpression::Select(m) => Select(m), SelectOrExpression::Expression(u) => u, }, - Slice(box array, box from, box to) => { - let array = f.fold_array_expression(array)?; - let from = f.fold_uint_expression(from)?; - let to = f.fold_uint_expression(to)?; - Slice(box array, box from, box to) - } - Repeat(box e, box count) => { - let e = f.fold_expression(e)?; - let count = f.fold_uint_expression(count)?; - Repeat(box e, box count) - } + Slice(s) => match f.fold_slice_expression(s)? { + SliceOrExpression::Slice(m) => Slice(m), + SliceOrExpression::Expression(u) => u, + }, + Repeat(s) => match f.fold_repeat_expression(s)? { + RepeatOrExpression::Repeat(m) => Repeat(m), + RepeatOrExpression::Expression(u) => u, + }, Element(element) => match f.fold_element_expression(ty, element)? { ElementOrExpression::Element(m) => Element(m), ElementOrExpression::Expression(u) => u, @@ -684,7 +969,66 @@ pub fn fold_assignee<'ast, T: Field, F: ResultFolder<'ast, T>>( } } -pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( +pub fn fold_struct_value_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + _: &StructType<'ast, T>, + a: StructValueExpression<'ast, T>, +) -> Result< + ValueOrExpression, StructExpressionInner<'ast, T>>, + F::Error, +> { + Ok(ValueOrExpression::Value(StructValueExpression { + value: a + .value + .into_iter() + .map(|v| f.fold_expression(v)) + .collect::, _>>()?, + ..a + })) +} + +pub fn fold_array_value_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + _: &ArrayType<'ast, T>, + a: ArrayValueExpression<'ast, T>, +) -> Result, ArrayExpressionInner<'ast, T>>, F::Error> +{ + Ok(ValueOrExpression::Value(ArrayValueExpression { + value: a + .value + .into_iter() + .map(|v| f.fold_expression_or_spread(v)) + .collect::, _>>()?, + ..a + })) +} + +pub fn fold_tuple_value_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + _: &TupleType<'ast, T>, + a: TupleValueExpression<'ast, T>, +) -> Result, TupleExpressionInner<'ast, T>>, F::Error> +{ + Ok(ValueOrExpression::Value(TupleValueExpression { + value: a + .value + .into_iter() + .map(|v| f.fold_expression(v)) + .collect::, _>>()?, + ..a + })) +} + +fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + ty: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_struct_expression_cases(ty, e).map(|e| e.span(span)) +} + +pub fn fold_struct_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, ty: &StructType<'ast, T>, e: StructExpressionInner<'ast, T>, @@ -697,12 +1041,10 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( IdentifierOrExpression::Expression(u) => u, }, Block(block) => Block(f.fold_block_expression(block)?), - Value(exprs) => Value( - exprs - .into_iter() - .map(|e| f.fold_expression(e)) - .collect::>()?, - ), + Value(value) => match f.fold_struct_value_expression(ty, value)? { + ValueOrExpression::Value(c) => Value(c), + ValueOrExpression::Expression(u) => u, + }, FunctionCall(function_call) => match f.fold_function_call_expression(ty, function_call)? { FunctionCallOrExpression::FunctionCall(c) => FunctionCall(c), FunctionCallOrExpression::Expression(u) => u, @@ -727,7 +1069,16 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( Ok(e) } -pub fn fold_tuple_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( +fn fold_tuple_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + ty: &TupleType<'ast, T>, + e: TupleExpressionInner<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_tuple_expression_cases(ty, e).map(|e| e.span(span)) +} + +pub fn fold_tuple_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, ty: &TupleType<'ast, T>, e: TupleExpressionInner<'ast, T>, @@ -740,12 +1091,10 @@ pub fn fold_tuple_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( IdentifierOrExpression::Identifier(i) => Identifier(i), IdentifierOrExpression::Expression(u) => u, }, - Value(exprs) => Value( - exprs - .into_iter() - .map(|e| f.fold_expression(e)) - .collect::>()?, - ), + Value(value) => match f.fold_tuple_value_expression(ty, value)? { + ValueOrExpression::Value(c) => Value(c), + ValueOrExpression::Expression(u) => u, + }, FunctionCall(function_call) => match f.fold_function_call_expression(ty, function_call)? { FunctionCallOrExpression::FunctionCall(c) => FunctionCall(c), FunctionCallOrExpression::Expression(u) => u, @@ -770,7 +1119,15 @@ pub fn fold_tuple_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( Ok(e) } -pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( +fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: FieldElementExpression<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_field_expression_cases(e).map(|e| e.span(span)) +} + +pub fn fold_field_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, e: FieldElementExpression<'ast, T>, ) -> Result, F::Error> { @@ -782,72 +1139,55 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( IdentifierOrExpression::Expression(u) => u, }, Block(block) => Block(f.fold_block_expression(block)?), - Number(n) => Number(n), - Add(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - Add(box e1, box e2) - } - Sub(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - Sub(box e1, box e2) - } - Mult(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - Mult(box e1, box e2) - } - Div(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - Div(box e1, box e2) - } - Pow(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - Pow(box e1, box e2) - } - Neg(box e) => { - let e = f.fold_field_expression(e)?; - - Neg(box e) - } - Pos(box e) => { - let e = f.fold_field_expression(e)?; - - Pos(box e) - } - And(box left, box right) => { - let left = f.fold_field_expression(left)?; - let right = f.fold_field_expression(right)?; - - And(box left, box right) - } - Or(box left, box right) => { - let left = f.fold_field_expression(left)?; - let right = f.fold_field_expression(right)?; - - Or(box left, box right) - } - Xor(box left, box right) => { - let left = f.fold_field_expression(left)?; - let right = f.fold_field_expression(right)?; - - Xor(box left, box right) - } - LeftShift(box e, box by) => { - let e = f.fold_field_expression(e)?; - let by = f.fold_uint_expression(by)?; - - LeftShift(box e, box by) - } - RightShift(box e, box by) => { - let e = f.fold_field_expression(e)?; - let by = f.fold_uint_expression(by)?; - - RightShift(box e, box by) - } + Value(n) => Value(n), + Add(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => Add(e), + BinaryOrExpression::Expression(u) => u, + }, + Sub(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => Sub(e), + BinaryOrExpression::Expression(u) => u, + }, + Mult(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => Mult(e), + BinaryOrExpression::Expression(u) => u, + }, + Div(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => Div(e), + BinaryOrExpression::Expression(u) => u, + }, + Pow(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => Pow(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + Xor(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => Xor(e), + BinaryOrExpression::Expression(u) => u, + }, + LeftShift(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => LeftShift(e), + BinaryOrExpression::Expression(u) => u, + }, + RightShift(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => RightShift(e), + BinaryOrExpression::Expression(u) => u, + }, + Neg(e) => match f.fold_unary_expression(&Type::FieldElement, e)? { + UnaryOrExpression::Unary(e) => Neg(e), + UnaryOrExpression::Expression(u) => u, + }, + Pos(e) => match f.fold_unary_expression(&Type::FieldElement, e)? { + UnaryOrExpression::Unary(e) => Pos(e), + UnaryOrExpression::Expression(u) => u, + }, Conditional(c) => match f.fold_conditional_expression(&Type::FieldElement, c)? { ConditionalOrExpression::Conditional(c) => Conditional(c), ConditionalOrExpression::Expression(u) => u, @@ -881,12 +1221,17 @@ pub fn fold_int_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( unreachable!() } -pub fn fold_block_expression<'ast, T: Field, E: ResultFold<'ast, T>, F: ResultFolder<'ast, T>>( +pub fn fold_block_expression< + 'ast, + T: Field, + E: ResultFold, + F: ResultFolder<'ast, T>, +>( f: &mut F, block: BlockExpression<'ast, T, E>, ) -> Result, F::Error> { - Ok(BlockExpression { - statements: block + Ok(BlockExpression::new( + block .statements .into_iter() .map(|s| f.fold_statement(s)) @@ -894,8 +1239,9 @@ pub fn fold_block_expression<'ast, T: Field, E: ResultFold<'ast, T>, F: ResultFo .into_iter() .flatten() .collect(), - value: box block.value.fold(f)?, - }) + block.value.fold(f)?, + ) + .span(block.span)) } pub fn fold_conditional_expression< @@ -904,7 +1250,7 @@ pub fn fold_conditional_expression< E: Expr<'ast, T> + Conditional<'ast, T> + PartialEq - + ResultFold<'ast, T> + + ResultFold + From>, F: ResultFolder<'ast, T>, >( @@ -918,7 +1264,44 @@ pub fn fold_conditional_expression< e.consequence.fold(f)?, e.alternative.fold(f)?, e.kind, - ), + ) + .span(e.span), + )) +} + +#[allow(clippy::type_complexity)] +pub fn fold_binary_expression< + 'ast, + T: Field, + L: Expr<'ast, T> + PartialEq + ResultFold + From>, + R: Expr<'ast, T> + PartialEq + ResultFold + From>, + E: Expr<'ast, T> + PartialEq + ResultFold + From>, + F: ResultFolder<'ast, T>, + Op: OperatorStr, +>( + f: &mut F, + _: &E::Ty, + e: BinaryExpression, +) -> Result, F::Error> { + Ok(BinaryOrExpression::Binary( + BinaryExpression::new(e.left.fold(f)?, e.right.fold(f)?).span(e.span), + )) +} + +pub fn fold_unary_expression< + 'ast, + T: Field, + In: Expr<'ast, T> + PartialEq + ResultFold + From>, + E: Expr<'ast, T> + PartialEq + ResultFold + From>, + F: ResultFolder<'ast, T>, + Op, +>( + f: &mut F, + _: &E::Ty, + e: UnaryExpression, +) -> Result, F::Error> { + Ok(UnaryOrExpression::Unary( + UnaryExpression::new(e.inner.fold(f)?).span(e.span), )) } @@ -932,9 +1315,29 @@ pub fn fold_member_expression< _: &E::Ty, e: MemberExpression<'ast, T, E>, ) -> Result, F::Error> { - Ok(MemberOrExpression::Member(MemberExpression::new( - f.fold_struct_expression(*e.struc)?, - e.id, + Ok(MemberOrExpression::Member( + MemberExpression::new(f.fold_struct_expression(*e.struc)?, e.id).span(e.span), + )) +} + +pub fn fold_slice_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: SliceExpression<'ast, T>, +) -> Result, F::Error> { + Ok(SliceOrExpression::Slice(SliceExpression::new( + f.fold_array_expression(*e.array)?, + f.fold_uint_expression(*e.from)?, + f.fold_uint_expression(*e.to)?, + ))) +} + +pub fn fold_repeat_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: RepeatExpression<'ast, T>, +) -> Result, F::Error> { + Ok(RepeatOrExpression::Repeat(RepeatExpression::new( + f.fold_expression(*e.e)?, + f.fold_uint_expression(*e.count)?, ))) } @@ -949,20 +1352,10 @@ pub fn fold_identifier_expression< e: IdentifierExpression<'ast, E>, ) -> Result, F::Error> { Ok(IdentifierOrExpression::Identifier( - IdentifierExpression::new(f.fold_name(e.id)?), + IdentifierExpression::new(f.fold_name(e.id)?).span(e.span), )) } -pub fn fold_eq_expression<'ast, T: Field, E: ResultFold<'ast, T>, F: ResultFolder<'ast, T>>( - f: &mut F, - e: EqExpression, -) -> Result, F::Error> { - Ok(EqOrBoolean::Eq(EqExpression::new( - e.left.fold(f)?, - e.right.fold(f)?, - ))) -} - pub fn fold_select_expression< 'ast, T: Field, @@ -976,10 +1369,13 @@ pub fn fold_select_expression< _: &E::Ty, e: SelectExpression<'ast, T, E>, ) -> Result, F::Error> { - Ok(SelectOrExpression::Select(SelectExpression::new( - f.fold_array_expression(*e.array)?, - f.fold_uint_expression(*e.index)?, - ))) + Ok(SelectOrExpression::Select( + SelectExpression::new( + f.fold_array_expression(*e.array)?, + f.fold_uint_expression(*e.index)?, + ) + .span(e.span), + )) } pub fn fold_element_expression< @@ -992,10 +1388,9 @@ pub fn fold_element_expression< _: &E::Ty, e: ElementExpression<'ast, T, E>, ) -> Result, F::Error> { - Ok(ElementOrExpression::Element(ElementExpression::new( - f.fold_tuple_expression(*e.tuple)?, - e.index, - ))) + Ok(ElementOrExpression::Element( + ElementExpression::new(f.fold_tuple_expression(*e.tuple)?, e.index).span(e.span), + )) } pub fn fold_function_call_expression< @@ -1008,20 +1403,31 @@ pub fn fold_function_call_expression< _: &E::Ty, e: FunctionCallExpression<'ast, T, E>, ) -> Result, F::Error> { - Ok(FunctionCallOrExpression::Expression(E::function_call( - f.fold_declaration_function_key(e.function_key)?, - e.generics - .into_iter() - .map(|g| g.map(|g| f.fold_uint_expression(g)).transpose()) - .collect::>()?, - e.arguments - .into_iter() - .map(|e| f.fold_expression(e)) - .collect::>()?, - ))) + Ok(FunctionCallOrExpression::Expression( + E::function_call( + f.fold_declaration_function_key(e.function_key)?, + e.generics + .into_iter() + .map(|g| g.map(|g| f.fold_uint_expression(g)).transpose()) + .collect::>()?, + e.arguments + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?, + ) + .span(e.span), + )) +} + +fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: BooleanExpression<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_boolean_expression_cases(e).map(|e| e.span(span)) } -pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( +pub fn fold_boolean_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, e: BooleanExpression<'ast, T>, ) -> Result, F::Error> { @@ -1035,83 +1441,57 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( Block(block) => Block(f.fold_block_expression(block)?), Value(v) => Value(v), FieldEq(e) => match f.fold_eq_expression(e)? { - EqOrBoolean::Eq(e) => FieldEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => FieldEq(e), + BinaryOrExpression::Expression(u) => u, }, BoolEq(e) => match f.fold_eq_expression(e)? { - EqOrBoolean::Eq(e) => BoolEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => BoolEq(e), + BinaryOrExpression::Expression(u) => u, }, ArrayEq(e) => match f.fold_eq_expression(e)? { - EqOrBoolean::Eq(e) => ArrayEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => ArrayEq(e), + BinaryOrExpression::Expression(u) => u, }, StructEq(e) => match f.fold_eq_expression(e)? { - EqOrBoolean::Eq(e) => StructEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => StructEq(e), + BinaryOrExpression::Expression(u) => u, }, TupleEq(e) => match f.fold_eq_expression(e)? { - EqOrBoolean::Eq(e) => TupleEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => TupleEq(e), + BinaryOrExpression::Expression(u) => u, }, UintEq(e) => match f.fold_eq_expression(e)? { - EqOrBoolean::Eq(e) => UintEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => UintEq(e), + BinaryOrExpression::Expression(u) => u, + }, + FieldLt(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => FieldLt(e), + BinaryOrExpression::Expression(u) => u, + }, + FieldLe(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => FieldLe(e), + BinaryOrExpression::Expression(u) => u, + }, + UintLt(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => UintLt(e), + BinaryOrExpression::Expression(u) => u, + }, + UintLe(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => UintLe(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Not(e) => match f.fold_unary_expression(&Type::Boolean, e)? { + UnaryOrExpression::Unary(e) => Not(e), + UnaryOrExpression::Expression(u) => u, }, - FieldLt(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - FieldLt(box e1, box e2) - } - FieldLe(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - FieldLe(box e1, box e2) - } - FieldGt(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - FieldGt(box e1, box e2) - } - FieldGe(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - FieldGe(box e1, box e2) - } - UintLt(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - UintLt(box e1, box e2) - } - UintLe(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - UintLe(box e1, box e2) - } - UintGt(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - UintGt(box e1, box e2) - } - UintGe(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - UintGe(box e1, box e2) - } - Or(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1)?; - let e2 = f.fold_boolean_expression(e2)?; - Or(box e1, box e2) - } - And(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1)?; - let e2 = f.fold_boolean_expression(e2)?; - And(box e1, box e2) - } - Not(box e) => { - let e = f.fold_boolean_expression(e)?; - Not(box e) - } FunctionCall(function_call) => { match f.fold_function_call_expression(&Type::Boolean, function_call)? { FunctionCallOrExpression::FunctionCall(c) => FunctionCall(c), @@ -1135,6 +1515,7 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( ElementOrExpression::Expression(u) => u, }, }; + Ok(e) } @@ -1148,7 +1529,16 @@ pub fn fold_uint_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( }) } -pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( +fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + ty: UBitwidth, + e: UExpressionInner<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_uint_expression_cases(ty, e).map(|e| e.span(span)) +} + +pub fn fold_uint_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, ty: UBitwidth, e: UExpressionInner<'ast, T>, @@ -1162,87 +1552,62 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( }, Block(block) => Block(f.fold_block_expression(block)?), Value(v) => Value(v), - Add(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - Add(box left, box right) - } - Sub(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - Sub(box left, box right) - } - FloorSub(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - FloorSub(box left, box right) - } - Mult(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - Mult(box left, box right) - } - Div(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - Div(box left, box right) - } - Rem(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - Rem(box left, box right) - } - Xor(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - Xor(box left, box right) - } - And(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - And(box left, box right) - } - Or(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - Or(box left, box right) - } - LeftShift(box e, box by) => { - let e = f.fold_uint_expression(e)?; - let by = f.fold_uint_expression(by)?; - - LeftShift(box e, box by) - } - RightShift(box e, box by) => { - let e = f.fold_uint_expression(e)?; - let by = f.fold_uint_expression(by)?; - - RightShift(box e, box by) - } - Not(box e) => { - let e = f.fold_uint_expression(e)?; - - Not(box e) - } - Neg(box e) => { - let e = f.fold_uint_expression(e)?; - - Neg(box e) - } - Pos(box e) => { - let e = f.fold_uint_expression(e)?; - - Pos(box e) - } + Add(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Add(e), + BinaryOrExpression::Expression(u) => u, + }, + Sub(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Sub(e), + BinaryOrExpression::Expression(u) => u, + }, + FloorSub(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => FloorSub(e), + BinaryOrExpression::Expression(u) => u, + }, + Mult(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Mult(e), + BinaryOrExpression::Expression(u) => u, + }, + Div(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Div(e), + BinaryOrExpression::Expression(u) => u, + }, + Rem(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Rem(e), + BinaryOrExpression::Expression(u) => u, + }, + Xor(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Xor(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + LeftShift(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => LeftShift(e), + BinaryOrExpression::Expression(u) => u, + }, + RightShift(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => RightShift(e), + BinaryOrExpression::Expression(u) => u, + }, + Pos(e) => match f.fold_unary_expression(&ty, e)? { + UnaryOrExpression::Unary(e) => Pos(e), + UnaryOrExpression::Expression(u) => u, + }, + Neg(e) => match f.fold_unary_expression(&ty, e)? { + UnaryOrExpression::Unary(e) => Neg(e), + UnaryOrExpression::Expression(u) => u, + }, + Not(e) => match f.fold_unary_expression(&ty, e)? { + UnaryOrExpression::Unary(e) => Not(e), + UnaryOrExpression::Expression(u) => u, + }, FunctionCall(function_call) => { match f.fold_function_call_expression(&ty, function_call)? { FunctionCallOrExpression::FunctionCall(c) => FunctionCall(c), @@ -1266,6 +1631,7 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( ElementOrExpression::Expression(u) => u, }, }; + Ok(e) } @@ -1481,5 +1847,6 @@ pub fn fold_program<'ast, T: Field, F: ResultFolder<'ast, T>>( f.fold_module(module).map(|m| (module_id, m)) }) .collect::>()?, + ..p }) } diff --git a/zokrates_ast/src/typed/types.rs b/zokrates_ast/src/typed/types.rs index f2bf23d5b..8ae15de8b 100644 --- a/zokrates_ast/src/typed/types.rs +++ b/zokrates_ast/src/typed/types.rs @@ -1,5 +1,6 @@ +use crate::common::expressions::ValueExpression; use crate::typed::{ - CoreIdentifier, Identifier, OwnedTypedModuleId, TypedExpression, UExpression, UExpressionInner, + CoreIdentifier, Identifier, OwnedModuleId, TypedExpression, UExpression, UExpressionInner, }; use crate::typed::{TryFrom, TryInto}; use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer}; @@ -125,7 +126,7 @@ pub type ConstantIdentifier<'ast> = &'ast str; #[derive(Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord, Serialize, Deserialize)] pub struct CanonicalConstantIdentifier<'ast> { - pub module: OwnedTypedModuleId, + pub module: OwnedModuleId, #[serde(borrow)] pub id: ConstantIdentifier<'ast>, } @@ -137,7 +138,7 @@ impl<'ast> fmt::Display for CanonicalConstantIdentifier<'ast> { } impl<'ast> CanonicalConstantIdentifier<'ast> { - pub fn new(id: ConstantIdentifier<'ast>, module: OwnedTypedModuleId) -> Self { + pub fn new(id: ConstantIdentifier<'ast>, module: OwnedModuleId) -> Self { CanonicalConstantIdentifier { module, id } } } @@ -188,7 +189,7 @@ impl<'ast, T: PartialEq> PartialEq> for DeclarationConstant inner: UExpressionInner::Value(v), .. }, - ) => *c == *v as u32, + ) => *c == v.value as u32, (DeclarationConstant::Expression(TypedExpression::Uint(e0)), e1) => e0 == e1, (DeclarationConstant::Expression(..), _) => false, // type error _ => true, @@ -233,7 +234,7 @@ impl<'ast, T: fmt::Display> fmt::Display for DeclarationConstant<'ast, T> { impl<'ast, T> From for UExpression<'ast, T> { fn from(i: u32) -> Self { - UExpressionInner::Value(i as u128).annotate(UBitwidth::B32) + UExpressionInner::Value(ValueExpression::new(i as u128)).annotate(UBitwidth::B32) } } @@ -245,7 +246,7 @@ impl<'ast, T: Field> From> for UExpression<'ast, T> .annotate(UBitwidth::B32) } DeclarationConstant::Concrete(v) => { - UExpressionInner::Value(v as u128).annotate(UBitwidth::B32) + UExpression::value(v as u128).annotate(UBitwidth::B32) } DeclarationConstant::Constant(v) => { UExpression::identifier(FrameIdentifier::from(v).into()).annotate(UBitwidth::B32) @@ -262,7 +263,7 @@ impl<'ast, T> TryInto for UExpression<'ast, T> { assert_eq!(self.bitwidth, UBitwidth::B32); match self.into_inner() { - UExpressionInner::Value(v) => Ok(v as u32), + UExpressionInner::Value(v) => Ok(v.value as u32), _ => Err(SpecializationError), } } @@ -920,7 +921,7 @@ impl<'ast, T: fmt::Display + PartialEq + fmt::Debug> Type<'ast, T> { match (&l.size.as_inner(), &*r.size) { // compare the sizes for concrete ones (UExpressionInner::Value(v), DeclarationConstant::Concrete(c)) => { - (*v as u32) == *c + (v.value as u32) == *c } _ => true, } @@ -968,7 +969,7 @@ pub type FunctionIdentifier<'ast> = &'ast str; #[derive(PartialEq, Eq, Hash, Debug, Clone, PartialOrd, Ord)] pub struct GFunctionKey<'ast, S> { - pub module: OwnedTypedModuleId, + pub module: OwnedModuleId, pub id: FunctionIdentifier<'ast>, pub signature: GSignature, } @@ -1046,7 +1047,7 @@ impl<'ast, T> From> for DeclarationFunctionKey<'ast, T } impl<'ast, S> GFunctionKey<'ast, S> { - pub fn with_location, U: Into>>( + pub fn with_location, U: Into>>( module: T, id: U, ) -> Self { @@ -1067,7 +1068,7 @@ impl<'ast, S> GFunctionKey<'ast, S> { self } - pub fn module>(mut self, module: T) -> Self { + pub fn module>(mut self, module: T) -> Self { self.module = module.into(); self } @@ -1100,7 +1101,7 @@ pub fn check_generic<'ast, T, S: Clone + PartialEq + PartialEq>( DeclarationConstant::Constant(..) => true, DeclarationConstant::Expression(e) => match e { TypedExpression::Uint(e) => match e.as_inner() { - UExpressionInner::Value(v) => *value == *v as u32, + UExpressionInner::Value(v) => *value == v.value as u32, _ => true, }, _ => unreachable!(), @@ -1231,7 +1232,7 @@ pub use self::signature::{ }; use super::identifier::FrameIdentifier; -use super::{Id, ShadowedIdentifier}; +use super::{Expr, Id, ShadowedIdentifier}; pub mod signature { use super::*; diff --git a/zokrates_ast/src/typed/uint.rs b/zokrates_ast/src/typed/uint.rs index c2120fc6f..a04550a4c 100644 --- a/zokrates_ast/src/typed/uint.rs +++ b/zokrates_ast/src/typed/uint.rs @@ -1,5 +1,5 @@ -use crate::typed::types::UBitwidth; use crate::typed::*; +use crate::{common::expressions::UValueExpression, typed::types::UBitwidth}; use std::ops::{Add, Div, Mul, Neg, Not, Rem, Sub}; use zokrates_field::Field; @@ -12,10 +12,12 @@ impl<'ast, T> Add for UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); + // we apply a basic simplification here which enables more precise comparison of array sizes during semantic checking + // this could be done by the caller by calling propagation, but it is deemed simple enough to be done here match (self.as_inner(), other.as_inner()) { - (UExpressionInner::Value(0), _) => other, - (_, UExpressionInner::Value(0)) => self, - _ => UExpressionInner::Add(box self, box other).annotate(bitwidth), + (UExpressionInner::Value(v), _) if v.value == 0 => other, + (_, UExpressionInner::Value(v)) if v.value == 0 => self, + _ => UExpressionInner::Add(BinaryExpression::new(self, other)).annotate(bitwidth), } } } @@ -26,7 +28,7 @@ impl<'ast, T> Sub for UExpression<'ast, T> { fn sub(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Sub(box self, box other).annotate(bitwidth) + UExpressionInner::Sub(BinaryExpression::new(self, other)).annotate(bitwidth) } } @@ -36,7 +38,7 @@ impl<'ast, T> Mul for UExpression<'ast, T> { fn mul(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Mult(box self, box other).annotate(bitwidth) + UExpressionInner::Mult(BinaryExpression::new(self, other)).annotate(bitwidth) } } @@ -46,7 +48,7 @@ impl<'ast, T> Div for UExpression<'ast, T> { fn div(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Div(box self, box other).annotate(bitwidth) + UExpressionInner::Div(BinaryExpression::new(self, other)).annotate(bitwidth) } } @@ -56,7 +58,7 @@ impl<'ast, T> Rem for UExpression<'ast, T> { fn rem(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Rem(box self, box other).annotate(bitwidth) + UExpressionInner::Rem(BinaryExpression::new(self, other)).annotate(bitwidth) } } @@ -65,7 +67,7 @@ impl<'ast, T> Not for UExpression<'ast, T> { fn not(self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; - UExpressionInner::Not(box self).annotate(bitwidth) + UExpressionInner::Not(UnaryExpression::new(self)).annotate(bitwidth) } } @@ -74,7 +76,7 @@ impl<'ast, T> Neg for UExpression<'ast, T> { fn neg(self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; - UExpressionInner::Neg(box self).annotate(bitwidth) + UExpressionInner::Neg(UnaryExpression::new(self)).annotate(bitwidth) } } @@ -82,48 +84,48 @@ impl<'ast, T: Field> UExpression<'ast, T> { pub fn xor(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Xor(box self, box other).annotate(bitwidth) + UExpressionInner::Xor(BinaryExpression::new(self, other)).annotate(bitwidth) } pub fn or(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Or(box self, box other).annotate(bitwidth) + UExpressionInner::Or(BinaryExpression::new(self, other)).annotate(bitwidth) } pub fn and(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::And(box self, box other).annotate(bitwidth) + UExpressionInner::And(BinaryExpression::new(self, other)).annotate(bitwidth) } pub fn pos(self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; - UExpressionInner::Pos(box self).annotate(bitwidth) + UExpressionInner::Pos(UnaryExpression::new(self)).annotate(bitwidth) } pub fn left_shift(self, by: UExpression<'ast, T>) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(by.bitwidth, UBitwidth::B32); - UExpressionInner::LeftShift(box self, box by).annotate(bitwidth) + UExpressionInner::LeftShift(BinaryExpression::new(self, by)).annotate(bitwidth) } pub fn right_shift(self, by: UExpression<'ast, T>) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(by.bitwidth, UBitwidth::B32); - UExpressionInner::RightShift(box self, box by).annotate(bitwidth) + UExpressionInner::RightShift(BinaryExpression::new(self, by)).annotate(bitwidth) } pub fn floor_sub(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::FloorSub(box self, box other).annotate(bitwidth) + UExpressionInner::FloorSub(BinaryExpression::new(self, other)).annotate(bitwidth) } } -impl<'ast, T: Field> From for UExpressionInner<'ast, T> { +impl<'ast, T> From for UExpressionInner<'ast, T> { fn from(e: u128) -> Self { - UExpressionInner::Value(e) + UExpressionInner::Value(ValueExpression::new(e)) } } @@ -142,20 +144,20 @@ pub struct UExpression<'ast, T> { impl<'ast, T> From for UExpression<'ast, T> { fn from(u: u16) -> Self { - UExpressionInner::Value(u as u128).annotate(UBitwidth::B16) + UExpressionInner::Value(ValueExpression::new(u as u128)).annotate(UBitwidth::B16) } } impl<'ast, T> From for UExpression<'ast, T> { fn from(u: u8) -> Self { - UExpressionInner::Value(u as u128).annotate(UBitwidth::B8) + UExpressionInner::Value(ValueExpression::new(u as u128)).annotate(UBitwidth::B8) } } impl<'ast, T> PartialEq for UExpression<'ast, T> { fn eq(&self, other: &u32) -> bool { match self.as_inner() { - UExpressionInner::Value(v) => *v == *other as u128, + UExpressionInner::Value(v) => v.value == *other as u128, _ => true, } } @@ -165,22 +167,33 @@ impl<'ast, T> PartialEq for UExpression<'ast, T> { pub enum UExpressionInner<'ast, T> { Block(BlockExpression<'ast, T, UExpression<'ast, T>>), Identifier(IdentifierExpression<'ast, UExpression<'ast, T>>), - Value(u128), - Add(Box>, Box>), - Sub(Box>, Box>), - FloorSub(Box>, Box>), - Mult(Box>, Box>), - Div(Box>, Box>), - Rem(Box>, Box>), - Xor(Box>, Box>), - And(Box>, Box>), - Or(Box>, Box>), - Not(Box>), - Neg(Box>), - Pos(Box>), + Value(UValueExpression), + Add(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Sub(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + FloorSub( + BinaryExpression< + OpFloorSub, + UExpression<'ast, T>, + UExpression<'ast, T>, + UExpression<'ast, T>, + >, + ), + Mult(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Div(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Rem(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Xor(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + And(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Or(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Not(UnaryExpression, UExpression<'ast, T>>), + Neg(UnaryExpression, UExpression<'ast, T>>), + Pos(UnaryExpression, UExpression<'ast, T>>), FunctionCall(FunctionCallExpression<'ast, T, UExpression<'ast, T>>), - LeftShift(Box>, Box>), - RightShift(Box>, Box>), + LeftShift( + BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>, + ), + RightShift( + BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>, + ), Conditional(ConditionalExpression<'ast, T, UExpression<'ast, T>>), Member(MemberExpression<'ast, T, UExpression<'ast, T>>), Select(SelectExpression<'ast, T, UExpression<'ast, T>>), diff --git a/zokrates_ast/src/typed/utils/mod.rs b/zokrates_ast/src/typed/utils/mod.rs index 6ec6885c2..36b6d6aea 100644 --- a/zokrates_ast/src/typed/utils/mod.rs +++ b/zokrates_ast/src/typed/utils/mod.rs @@ -1,13 +1,13 @@ use super::{ - ArrayExpression, ArrayExpressionInner, ArrayValue, BooleanExpression, Conditional, - ConditionalKind, Expr, FieldElementExpression, Id, Identifier, Select, Typed, TypedExpression, - TypedExpressionOrSpread, UBitwidth, UExpression, UExpressionInner, + ArrayExpression, ArrayExpressionInner, BooleanExpression, Conditional, ConditionalKind, Expr, + FieldElementExpression, Id, Identifier, Select, Typed, TypedExpression, + TypedExpressionOrSpread, UBitwidth, UExpression, ValueExpression, }; use zokrates_field::Field; pub fn f<'ast, T, U: TryInto>(v: U) -> FieldElementExpression<'ast, T> { - FieldElementExpression::Number(v.try_into().map_err(|_| ()).unwrap()) + FieldElementExpression::Value(ValueExpression::new(v.try_into().map_err(|_| ()).unwrap())) } pub fn a_id<'ast, T: Field, I: TryInto>>(v: I) -> ArrayExpressionInner<'ast, T> { @@ -16,24 +16,24 @@ pub fn a_id<'ast, T: Field, I: TryInto>>(v: I) -> ArrayExpressi pub fn a< 'ast, - T, + T: Field, E: Typed<'ast, T> + Expr<'ast, T> + Into>, const N: usize, >( values: [E; N], ) -> ArrayExpression<'ast, T> { let ty = values[0].get_type(); - ArrayExpressionInner::Value(ArrayValue( + ArrayExpression::value( values .into_iter() .map(|e| TypedExpressionOrSpread::Expression(e.into())) .collect(), - )) + ) .annotate(ty, N as u32) } -pub fn u_32<'ast, T, U: TryInto>(v: U) -> UExpression<'ast, T> { - UExpressionInner::Value(v.try_into().map_err(|_| ()).unwrap() as u128).annotate(UBitwidth::B32) +pub fn u_32<'ast, T: Field, U: TryInto>(v: U) -> UExpression<'ast, T> { + UExpression::value(v.try_into().map_err(|_| ()).unwrap() as u128).annotate(UBitwidth::B32) } pub fn conditional<'ast, T, E: Conditional<'ast, T>>( diff --git a/zokrates_ast/src/typed/variable.rs b/zokrates_ast/src/typed/variable.rs index d47d48728..0ad50061c 100644 --- a/zokrates_ast/src/typed/variable.rs +++ b/zokrates_ast/src/typed/variable.rs @@ -1,17 +1,11 @@ +use crate::common::WithSpan; use crate::typed::types::{DeclarationConstant, GStructType, UBitwidth}; use crate::typed::types::{GType, SpecializationError}; use crate::typed::Identifier; use crate::typed::UExpression; use crate::typed::{TryFrom, TryInto}; -use std::fmt; - -#[derive(Clone, PartialEq, Hash, Eq, PartialOrd, Ord, Debug)] -pub struct GVariable<'ast, S> { - pub id: Identifier<'ast>, - pub _type: GType, - pub is_mutable: bool, -} +pub type GVariable<'ast, S> = crate::common::Variable, GType>; pub type DeclarationVariable<'ast, T> = GVariable<'ast, DeclarationConstant<'ast, T>>; pub type ConcreteVariable<'ast> = GVariable<'ast, u32>; pub type Variable<'ast, T> = GVariable<'ast, UExpression<'ast, T>>; @@ -20,84 +14,56 @@ impl<'ast, T> TryFrom> for ConcreteVariable<'ast> { type Error = SpecializationError; fn try_from(v: Variable<'ast, T>) -> Result { - let _type = v._type.try_into()?; + let span = v.get_span(); - Ok(Self { - _type, - id: v.id, - is_mutable: v.is_mutable, - }) + let ty = v.ty.try_into()?; + + Ok(Self::new(v.id, ty).span(span)) } } impl<'ast, T> From> for Variable<'ast, T> { fn from(v: ConcreteVariable<'ast>) -> Self { - let _type = v._type.into(); + let span = v.get_span(); + + let ty = v.ty.into(); - Self { - _type, - id: v.id, - is_mutable: v.is_mutable, - } + Self::new(v.id, ty).span(span) } } pub fn try_from_g_variable, U>( v: GVariable, ) -> Result, SpecializationError> { - let _type = crate::typed::types::try_from_g_type(v._type)?; + let span = v.get_span(); - Ok(GVariable { - _type, - id: v.id, - is_mutable: v.is_mutable, - }) + let ty = crate::typed::types::try_from_g_type(v.ty)?; + + Ok(GVariable::new(v.id, ty).span(span)) } impl<'ast, S: Clone> GVariable<'ast, S> { pub fn field_element>>(id: I) -> Self { - Self::immutable(id, GType::FieldElement) + Self::new(id, GType::FieldElement) } pub fn boolean>>(id: I) -> Self { - Self::immutable(id, GType::Boolean) + Self::new(id, GType::Boolean) } pub fn uint>, W: Into>(id: I, bitwidth: W) -> Self { - Self::immutable(id, GType::uint(bitwidth)) + Self::new(id, GType::uint(bitwidth)) } pub fn array>, U: Into>(id: I, ty: GType, size: U) -> Self { - Self::immutable(id, GType::array((ty, size.into()))) + Self::new(id, GType::array((ty, size.into()))) } pub fn struc>>(id: I, ty: GStructType) -> Self { - Self::immutable(id, GType::Struct(ty)) - } - - pub fn immutable>>(id: I, _type: GType) -> Self { - Self::new(id, _type, false) - } - - pub fn mutable>>(id: I, _type: GType) -> Self { - Self::new(id, _type, true) - } - - pub fn new>>(id: I, _type: GType, is_mutable: bool) -> Self { - GVariable { - id: id.into(), - _type, - is_mutable, - } + Self::new(id, GType::Struct(ty)) } pub fn get_type(&self) -> GType { - self._type.clone() - } -} - -impl<'ast, S: fmt::Display> fmt::Display for GVariable<'ast, S> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} {}", self._type, self.id,) + self.ty.clone() } } diff --git a/zokrates_ast/src/untyped/mod.rs b/zokrates_ast/src/untyped/mod.rs index 07541edec..c0f77ea64 100644 --- a/zokrates_ast/src/untyped/mod.rs +++ b/zokrates_ast/src/untyped/mod.rs @@ -8,17 +8,17 @@ mod from_ast; mod node; pub mod parameter; -mod position; pub mod types; pub mod variable; pub use self::node::{Node, NodeValue}; pub use self::parameter::{Parameter, ParameterNode}; -pub use self::position::Position; use self::types::{UnresolvedSignature, UnresolvedType, UserTypeId}; pub use self::variable::{Variable, VariableNode}; use crate::common::FlatEmbed; -use std::path::{Path, PathBuf}; +pub use crate::common::Position; +pub use crate::common::{ModuleId, OwnedModuleId}; +use std::path::Path; use std::fmt; @@ -28,10 +28,6 @@ use std::collections::HashMap; /// An identifier of a function or a variable pub type Identifier<'ast> = &'ast str; -/// The identifier of a `Module`, typically a path or uri -pub type OwnedModuleId = PathBuf; -pub type ModuleId = Path; - /// A collection of `Module`s pub type Modules<'ast> = HashMap>; diff --git a/zokrates_ast/src/untyped/node.rs b/zokrates_ast/src/untyped/node.rs index 62ef299df..fbff15dd8 100644 --- a/zokrates_ast/src/untyped/node.rs +++ b/zokrates_ast/src/untyped/node.rs @@ -1,26 +1,25 @@ +use crate::common::LocalSourceSpan as Span; use std::fmt; -use zokrates_pest_ast::Span; +use zokrates_pest_ast::Span as PestSpan; #[derive(Clone)] pub struct Node { - pub start: Position, - pub end: Position, + pub span: Span, pub value: T, } impl Node { pub fn mock(e: T) -> Self { Self { - start: Position::mock(), - end: Position::mock(), + span: Span::mock(), value: e, } } } impl Node { - pub fn pos(&self) -> (Position, Position) { - (self.start, self.end) + pub fn span(&self) -> Span { + self.span } } @@ -37,8 +36,11 @@ impl fmt::Debug for Node { } impl Node { - pub fn new(start: Position, end: Position, value: T) -> Node { - Node { start, end, value } + pub fn new(from: Position, to: Position, value: T) -> Node { + Node { + span: Span { from, to }, + value, + } } } @@ -56,7 +58,7 @@ pub trait NodeValue: fmt::Display + fmt::Debug + Sized + PartialEq { Node::new(Position::mock(), Position::mock(), self) } - fn span(self, span: Span) -> Node { + fn span(self, span: PestSpan) -> Node { let from = span.start_pos().line_col(); let to = span.end_pos().line_col(); diff --git a/zokrates_ast/src/untyped/variable.rs b/zokrates_ast/src/untyped/variable.rs index edbc7f4cf..61f86f5dc 100644 --- a/zokrates_ast/src/untyped/variable.rs +++ b/zokrates_ast/src/untyped/variable.rs @@ -8,7 +8,7 @@ use super::Identifier; pub struct Variable<'ast> { pub is_mutable: bool, pub id: Identifier<'ast>, - pub _type: UnresolvedTypeNode<'ast>, + pub ty: UnresolvedTypeNode<'ast>, } pub type VariableNode<'ast> = Node>; @@ -22,7 +22,7 @@ impl<'ast> Variable<'ast> { Variable { is_mutable, id: id.into(), - _type: t, + ty: t, } } @@ -35,13 +35,13 @@ impl<'ast> Variable<'ast> { } pub fn get_type(&self) -> &UnresolvedType<'ast> { - &self._type.value + &self.ty.value } } impl<'ast> fmt::Display for Variable<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} {}", self._type, self.id,) + write!(f, "{} {}", self.ty, self.id,) } } @@ -50,7 +50,7 @@ impl<'ast> fmt::Debug for Variable<'ast> { write!( f, "Variable(type: {:?}, id: {:?}, is_mutable: {:?})", - self._type, self.id, self.is_mutable + self.ty, self.id, self.is_mutable ) } } diff --git a/zokrates_ast/src/zir/canonicalizer.rs b/zokrates_ast/src/zir/canonicalizer.rs index ca67d8b51..61bf1970a 100644 --- a/zokrates_ast/src/zir/canonicalizer.rs +++ b/zokrates_ast/src/zir/canonicalizer.rs @@ -13,14 +13,14 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirCanonicalizer<'ast> { self.identifier_map.insert(p.id.id.clone(), new_id); Parameter { - id: Variable::with_id_and_type(Identifier::internal(new_id), p.id._type), + id: Variable::with_id_and_type(Identifier::internal(new_id), p.id.ty), ..p } } fn fold_assignee(&mut self, a: ZirAssignee<'ast>) -> ZirAssignee<'ast> { let new_id = self.identifier_map.len(); self.identifier_map.insert(a.id.clone(), new_id); - ZirAssignee::with_id_and_type(Identifier::internal(new_id), a._type) + ZirAssignee::with_id_and_type(Identifier::internal(new_id), a.ty) } fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { match self.identifier_map.get(&n) { @@ -43,17 +43,14 @@ mod tests { #[test] fn canonicalize() { let func = ZirFunction:: { - arguments: vec![Parameter { - id: Variable::field_element("a"), - private: true, - }], + arguments: vec![Parameter::new(Variable::field_element("a"), true)], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( ZirAssignee::field_element("b"), FieldElementExpression::Identifier(IdentifierExpression::new("a".into())) .into(), ), - ZirStatement::Return(vec![FieldElementExpression::Identifier( + ZirStatement::ret(vec![FieldElementExpression::Identifier( IdentifierExpression::new("b".into()), ) .into()]), @@ -67,19 +64,19 @@ mod tests { let result = canonicalizer.fold_function(func); let expected = ZirFunction:: { - arguments: vec![Parameter { - id: Variable::field_element(Identifier::internal(0usize)), - private: true, - }], + arguments: vec![Parameter::new( + Variable::field_element(Identifier::internal(0usize)), + true, + )], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( ZirAssignee::field_element(Identifier::internal(1usize)), FieldElementExpression::Identifier(IdentifierExpression::new( Identifier::internal(0usize), )) .into(), ), - ZirStatement::Return(vec![FieldElementExpression::Identifier( + ZirStatement::ret(vec![FieldElementExpression::Identifier( IdentifierExpression::new(Identifier::internal(1usize)), ) .into()]), diff --git a/zokrates_ast/src/zir/folder.rs b/zokrates_ast/src/zir/folder.rs index 770eb0f03..b536ccc78 100644 --- a/zokrates_ast/src/zir/folder.rs +++ b/zokrates_ast/src/zir/folder.rs @@ -1,27 +1,25 @@ // Generic walk through ZIR. Not mutating in place +use crate::common::expressions::{BinaryOrExpression, IdentifierOrExpression, UnaryOrExpression}; +use crate::common::{Fold, WithSpan}; use crate::zir::types::UBitwidth; use crate::zir::*; use zokrates_field::Field; -pub trait Fold<'ast, T: Field>: Sized { - fn fold>(self, f: &mut F) -> Self; -} - -impl<'ast, T: Field> Fold<'ast, T> for FieldElementExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for FieldElementExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_field_expression(self) } } -impl<'ast, T: Field> Fold<'ast, T> for BooleanExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for BooleanExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_boolean_expression(self) } } -impl<'ast, T: Field> Fold<'ast, T> for UExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for UExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_uint_expression(self) } } @@ -56,6 +54,20 @@ pub trait Folder<'ast, T: Field>: Sized { self.fold_variable(a) } + fn fold_assembly_constraint( + &mut self, + s: AssemblyConstraint<'ast, T>, + ) -> Vec> { + fold_assembly_constraint(self, s) + } + + fn fold_assembly_assignment( + &mut self, + s: AssemblyAssignment<'ast, T>, + ) -> Vec> { + fold_assembly_assignment(self, s) + } + fn fold_assembly_statement( &mut self, s: ZirAssemblyStatement<'ast, T>, @@ -63,19 +75,73 @@ pub trait Folder<'ast, T: Field>: Sized { fold_assembly_statement(self, s) } + fn fold_assembly_statement_cases( + &mut self, + s: ZirAssemblyStatement<'ast, T>, + ) -> Vec> { + fold_assembly_statement_cases(self, s) + } + fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec> { fold_statement(self, s) } + fn fold_statement_cases(&mut self, s: ZirStatement<'ast, T>) -> Vec> { + fold_statement_cases(self, s) + } + + fn fold_definition_statement( + &mut self, + s: DefinitionStatement<'ast, T>, + ) -> Vec> { + fold_definition_statement(self, s) + } + + fn fold_if_else_statement( + &mut self, + s: IfElseStatement<'ast, T>, + ) -> Vec> { + fold_if_else_statement(self, s) + } + + fn fold_multiple_definition_statement( + &mut self, + s: MultipleDefinitionStatement<'ast, T>, + ) -> Vec> { + fold_multiple_definition_statement(self, s) + } + + fn fold_assertion_statement( + &mut self, + s: AssertionStatement<'ast, T>, + ) -> Vec> { + fold_assertion_statement(self, s) + } + + fn fold_return_statement(&mut self, s: ReturnStatement<'ast, T>) -> Vec> { + fold_return_statement(self, s) + } + + fn fold_log_statement(&mut self, s: LogStatement<'ast, T>) -> Vec> { + fold_log_statement(self, s) + } + + fn fold_assembly_block( + &mut self, + s: AssemblyBlockStatement<'ast, T>, + ) -> Vec> { + fold_assembly_block(self, s) + } + fn fold_identifier_expression + Id<'ast, T>>( &mut self, ty: &E::Ty, e: IdentifierExpression<'ast, E>, - ) -> IdentifierOrExpression<'ast, T, E> { + ) -> IdentifierOrExpression, E, E::Inner> { fold_identifier_expression(self, ty, e) } - fn fold_conditional_expression + Fold<'ast, T> + Conditional<'ast, T>>( + fn fold_conditional_expression + Fold + Conditional<'ast, T>>( &mut self, ty: &E::Ty, e: ConditionalExpression<'ast, T, E>, @@ -83,7 +149,7 @@ pub trait Folder<'ast, T: Field>: Sized { fold_conditional_expression(self, ty, e) } - fn fold_select_expression + Fold<'ast, T> + Select<'ast, T>>( + fn fold_select_expression + Fold + Select<'ast, T>>( &mut self, ty: &E::Ty, e: SelectExpression<'ast, T, E>, @@ -91,6 +157,27 @@ pub trait Folder<'ast, T: Field>: Sized { fold_select_expression(self, ty, e) } + fn fold_binary_expression< + L: Expr<'ast, T> + Fold, + R: Expr<'ast, T> + Fold, + E: Expr<'ast, T> + Fold, + Op, + >( + &mut self, + ty: &E::Ty, + e: BinaryExpression, + ) -> BinaryOrExpression { + fold_binary_expression(self, ty, e) + } + + fn fold_unary_expression + Fold, E: Expr<'ast, T> + Fold, Op>( + &mut self, + ty: &E::Ty, + e: UnaryExpression, + ) -> UnaryOrExpression { + fold_unary_expression(self, ty, e) + } + fn fold_expression(&mut self, e: ZirExpression<'ast, T>) -> ZirExpression<'ast, T> { match e { ZirExpression::FieldElement(e) => self.fold_field_expression(e).into(), @@ -140,23 +227,66 @@ pub trait Folder<'ast, T: Field>: Sized { ) -> UExpressionInner<'ast, T> { fold_uint_expression_inner(self, bitwidth, e) } + + fn fold_field_expression_cases( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> FieldElementExpression<'ast, T> { + fold_field_expression_cases(self, e) + } + + fn fold_boolean_expression_cases( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> BooleanExpression<'ast, T> { + fold_boolean_expression_cases(self, e) + } + + fn fold_uint_expression_cases( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> UExpressionInner<'ast, T> { + fold_uint_expression_cases(self, bitwidth, e) + } } -pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>( +pub fn fold_assembly_assignment<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: AssemblyAssignment<'ast, T>, +) -> Vec> { + let assignees = s.assignee.into_iter().map(|a| f.fold_assignee(a)).collect(); + let expression = f.fold_function(s.expression); + vec![ZirAssemblyStatement::assignment(assignees, expression)] +} + +pub fn fold_assembly_constraint<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: AssemblyConstraint<'ast, T>, +) -> Vec> { + let left = f.fold_field_expression(s.left); + let right = f.fold_field_expression(s.right); + vec![ZirAssemblyStatement::constraint(left, right, s.metadata)] +} + +fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: ZirAssemblyStatement<'ast, T>, +) -> Vec> { + let span = s.get_span(); + f.fold_assembly_statement_cases(s) + .into_iter() + .map(|s| s.span(span)) + .collect() +} + +fn fold_assembly_statement_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: ZirAssemblyStatement<'ast, T>, ) -> Vec> { match s { - ZirAssemblyStatement::Assignment(assignees, function) => { - let assignees = assignees.into_iter().map(|a| f.fold_assignee(a)).collect(); - let function = f.fold_function(function); - vec![ZirAssemblyStatement::Assignment(assignees, function)] - } - ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => { - let lhs = f.fold_field_expression(lhs); - let rhs = f.fold_field_expression(rhs); - vec![ZirAssemblyStatement::Constraint(lhs, rhs, metadata)] - } + ZirAssemblyStatement::Assignment(s) => f.fold_assembly_assignment(s), + ZirAssemblyStatement::Constraint(s) => f.fold_assembly_constraint(s), } } @@ -164,48 +294,119 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: ZirStatement<'ast, T>, ) -> Vec> { - let res = match s { - ZirStatement::Return(expressions) => ZirStatement::Return( - expressions - .into_iter() - .map(|e| f.fold_expression(e)) - .collect(), - ), - ZirStatement::Definition(a, e) => { - ZirStatement::Definition(f.fold_assignee(a), f.fold_expression(e)) - } - ZirStatement::IfElse(condition, consequence, alternative) => ZirStatement::IfElse( - f.fold_boolean_expression(condition), - consequence - .into_iter() - .flat_map(|e| f.fold_statement(e)) - .collect(), - alternative - .into_iter() - .flat_map(|e| f.fold_statement(e)) - .collect(), - ), - ZirStatement::Assertion(e, error) => { - ZirStatement::Assertion(f.fold_boolean_expression(e), error) - } - ZirStatement::MultipleDefinition(variables, elist) => ZirStatement::MultipleDefinition( - variables.into_iter().map(|v| f.fold_variable(v)).collect(), - f.fold_expression_list(elist), - ), - ZirStatement::Log(l, e) => ZirStatement::Log( - l, - e.into_iter() - .map(|(t, e)| (t, e.into_iter().map(|e| f.fold_expression(e)).collect())) - .collect(), - ), - ZirStatement::Assembly(statements) => ZirStatement::Assembly( - statements + let span = s.get_span(); + f.fold_statement_cases(s) + .into_iter() + .map(|s| s.span(span)) + .collect() +} + +pub fn fold_statement_cases<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: ZirStatement<'ast, T>, +) -> Vec> { + let span = s.get_span(); + + match s { + ZirStatement::Return(s) => f.fold_return_statement(s), + ZirStatement::Definition(s) => f.fold_definition_statement(s), + ZirStatement::IfElse(s) => f.fold_if_else_statement(s), + ZirStatement::Assertion(s) => f.fold_assertion_statement(s), + ZirStatement::MultipleDefinition(s) => f.fold_multiple_definition_statement(s), + ZirStatement::Log(s) => f.fold_log_statement(s), + ZirStatement::Assembly(s) => f.fold_assembly_block(s), + } + .into_iter() + .map(|s| s.span(span)) + .collect() +} + +pub fn fold_multiple_definition_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: MultipleDefinitionStatement<'ast, T>, +) -> Vec> { + let expression_list = f.fold_expression_list(s.rhs); + vec![ZirStatement::MultipleDefinition( + MultipleDefinitionStatement::new( + s.assignees .into_iter() - .flat_map(|s| f.fold_assembly_statement(s)) + .map(|v| f.fold_variable(v)) .collect(), + expression_list, ), - }; - vec![res] + )] +} + +pub fn fold_if_else_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: IfElseStatement<'ast, T>, +) -> Vec> { + vec![ZirStatement::IfElse(IfElseStatement::new( + f.fold_boolean_expression(s.condition), + s.consequence + .into_iter() + .flat_map(|e| f.fold_statement(e)) + .collect(), + s.alternative + .into_iter() + .flat_map(|e| f.fold_statement(e)) + .collect(), + ))] +} + +pub fn fold_definition_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: DefinitionStatement<'ast, T>, +) -> Vec> { + let rhs = f.fold_expression(s.rhs); + vec![ZirStatement::Definition(DefinitionStatement::new( + f.fold_assignee(s.assignee), + rhs, + ))] +} + +pub fn fold_assertion_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: AssertionStatement<'ast, T>, +) -> Vec> { + vec![ZirStatement::Assertion(AssertionStatement::new( + f.fold_boolean_expression(s.expression), + s.error, + ))] +} + +pub fn fold_return_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: ReturnStatement<'ast, T>, +) -> Vec> { + vec![ZirStatement::Return(ReturnStatement::new( + s.inner.into_iter().map(|e| f.fold_expression(e)).collect(), + ))] +} + +pub fn fold_log_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: LogStatement<'ast, T>, +) -> Vec> { + vec![ZirStatement::Log(LogStatement::new( + s.format_string, + s.expressions + .into_iter() + .map(|(t, e)| (t, e.into_iter().map(|e| f.fold_expression(e)).collect())) + .collect(), + ))] +} + +pub fn fold_assembly_block<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: AssemblyBlockStatement<'ast, T>, +) -> Vec> { + vec![ZirStatement::Assembly(AssemblyBlockStatement::new( + s.inner + .into_iter() + .flat_map(|s| f.fold_assembly_statement(s)) + .collect(), + ))] } pub fn fold_identifier_expression< @@ -217,16 +418,24 @@ pub fn fold_identifier_expression< f: &mut F, _: &E::Ty, e: IdentifierExpression<'ast, E>, -) -> IdentifierOrExpression<'ast, T, E> { +) -> IdentifierOrExpression, E, E::Inner> { IdentifierOrExpression::Identifier(IdentifierExpression::new(f.fold_name(e.id))) } -pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + e: FieldElementExpression<'ast, T>, +) -> FieldElementExpression<'ast, T> { + let span = e.get_span(); + f.fold_field_expression_cases(e).span(span) +} + +pub fn fold_field_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: FieldElementExpression<'ast, T>, ) -> FieldElementExpression<'ast, T> { match e { - FieldElementExpression::Number(n) => FieldElementExpression::Number(n), + FieldElementExpression::Value(n) => FieldElementExpression::Value(n), FieldElementExpression::Identifier(id) => { match f.fold_identifier_expression(&Type::FieldElement, id) { IdentifierOrExpression::Identifier(i) => FieldElementExpression::Identifier(i), @@ -239,60 +448,49 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( SelectOrExpression::Expression(u) => u, } } - FieldElementExpression::Add(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - FieldElementExpression::Add(box e1, box e2) - } - FieldElementExpression::Sub(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - FieldElementExpression::Sub(box e1, box e2) - } - FieldElementExpression::Mult(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - FieldElementExpression::Mult(box e1, box e2) - } - FieldElementExpression::Div(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - FieldElementExpression::Div(box e1, box e2) - } - FieldElementExpression::Pow(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_uint_expression(e2); - FieldElementExpression::Pow(box e1, box e2) - } - FieldElementExpression::And(box left, box right) => { - let left = f.fold_field_expression(left); - let right = f.fold_field_expression(right); - - FieldElementExpression::And(box left, box right) - } - FieldElementExpression::Or(box left, box right) => { - let left = f.fold_field_expression(left); - let right = f.fold_field_expression(right); - - FieldElementExpression::Or(box left, box right) - } - FieldElementExpression::Xor(box left, box right) => { - let left = f.fold_field_expression(left); - let right = f.fold_field_expression(right); - - FieldElementExpression::Xor(box left, box right) - } - FieldElementExpression::LeftShift(box e, box by) => { - let e = f.fold_field_expression(e); - let by = f.fold_uint_expression(by); - - FieldElementExpression::LeftShift(box e, box by) + FieldElementExpression::Add(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::Add(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::Sub(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::Sub(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::Mult(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::Mult(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::Div(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::Div(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::Pow(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::Pow(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::And(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::And(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::Or(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::Or(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::Xor(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::Xor(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::LeftShift(e) => { + match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::LeftShift(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::RightShift(box e, box by) => { - let e = f.fold_field_expression(e); - let by = f.fold_uint_expression(by); - - FieldElementExpression::RightShift(box e, box by) + FieldElementExpression::RightShift(e) => { + match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::RightShift(e), + BinaryOrExpression::Expression(e) => e, + } } FieldElementExpression::Conditional(c) => { match f.fold_conditional_expression(&Type::FieldElement, c) { @@ -303,10 +501,20 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: BooleanExpression<'ast, T>, ) -> BooleanExpression<'ast, T> { + let span = e.get_span(); + f.fold_boolean_expression_cases(e).span(span) +} + +pub fn fold_boolean_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + e: BooleanExpression<'ast, T>, +) -> BooleanExpression<'ast, T> { + use BooleanExpression::*; + match e { BooleanExpression::Value(v) => BooleanExpression::Value(v), BooleanExpression::Identifier(id) => match f.fold_identifier_expression(&Type::Boolean, id) @@ -318,55 +526,46 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( SelectOrExpression::Select(s) => BooleanExpression::Select(s), SelectOrExpression::Expression(u) => u, }, - BooleanExpression::FieldEq(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - BooleanExpression::FieldEq(box e1, box e2) - } - BooleanExpression::BoolEq(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1); - let e2 = f.fold_boolean_expression(e2); - BooleanExpression::BoolEq(box e1, box e2) - } - BooleanExpression::UintEq(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1); - let e2 = f.fold_uint_expression(e2); - BooleanExpression::UintEq(box e1, box e2) - } - BooleanExpression::FieldLt(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - BooleanExpression::FieldLt(box e1, box e2) - } - BooleanExpression::UintLt(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1); - let e2 = f.fold_uint_expression(e2); - BooleanExpression::UintLt(box e1, box e2) - } - BooleanExpression::FieldLe(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - BooleanExpression::FieldLe(box e1, box e2) - } - BooleanExpression::UintLe(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1); - let e2 = f.fold_uint_expression(e2); - BooleanExpression::UintLe(box e1, box e2) - } - BooleanExpression::Or(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1); - let e2 = f.fold_boolean_expression(e2); - BooleanExpression::Or(box e1, box e2) - } - BooleanExpression::And(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1); - let e2 = f.fold_boolean_expression(e2); - BooleanExpression::And(box e1, box e2) - } - BooleanExpression::Not(box e) => { - let e = f.fold_boolean_expression(e); - BooleanExpression::Not(box e) - } + FieldEq(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => FieldEq(e), + BinaryOrExpression::Expression(u) => u, + }, + BoolEq(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => BoolEq(e), + BinaryOrExpression::Expression(u) => u, + }, + UintEq(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => UintEq(e), + BinaryOrExpression::Expression(u) => u, + }, + FieldLt(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => FieldLt(e), + BinaryOrExpression::Expression(u) => u, + }, + FieldLe(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => FieldLe(e), + BinaryOrExpression::Expression(u) => u, + }, + UintLt(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => UintLt(e), + BinaryOrExpression::Expression(u) => u, + }, + UintLe(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => UintLe(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Not(e) => match f.fold_unary_expression(&Type::Boolean, e) { + UnaryOrExpression::Unary(e) => Not(e), + UnaryOrExpression::Expression(u) => u, + }, BooleanExpression::Conditional(c) => match f.fold_conditional_expression(&Type::Boolean, c) { ConditionalOrExpression::Conditional(s) => BooleanExpression::Conditional(s), @@ -385,86 +584,78 @@ pub fn fold_uint_expression<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, ty: UBitwidth, e: UExpressionInner<'ast, T>, ) -> UExpressionInner<'ast, T> { + let span = e.get_span(); + f.fold_uint_expression_cases(ty, e).span(span) +} + +pub fn fold_uint_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + ty: UBitwidth, + e: UExpressionInner<'ast, T>, +) -> UExpressionInner<'ast, T> { + use UExpressionInner::*; + match e { - UExpressionInner::Value(v) => UExpressionInner::Value(v), - UExpressionInner::Identifier(id) => match f.fold_identifier_expression(&ty, id) { + Value(v) => UExpressionInner::Value(v), + Identifier(id) => match f.fold_identifier_expression(&ty, id) { IdentifierOrExpression::Identifier(i) => UExpressionInner::Identifier(i), IdentifierOrExpression::Expression(e) => e, }, - UExpressionInner::Select(e) => match f.fold_select_expression(&ty, e) { + Select(e) => match f.fold_select_expression(&ty, e) { SelectOrExpression::Select(s) => UExpressionInner::Select(s), SelectOrExpression::Expression(u) => u, }, - UExpressionInner::Add(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - UExpressionInner::Add(box left, box right) - } - UExpressionInner::Sub(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - UExpressionInner::Sub(box left, box right) - } - UExpressionInner::Mult(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - UExpressionInner::Mult(box left, box right) - } - UExpressionInner::Div(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - UExpressionInner::Div(box left, box right) - } - UExpressionInner::Rem(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - UExpressionInner::Rem(box left, box right) - } - UExpressionInner::Xor(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - UExpressionInner::Xor(box left, box right) - } - UExpressionInner::And(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - UExpressionInner::And(box left, box right) - } - UExpressionInner::Or(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - UExpressionInner::Or(box left, box right) - } - UExpressionInner::LeftShift(box e, by) => { - let e = f.fold_uint_expression(e); - - UExpressionInner::LeftShift(box e, by) - } - UExpressionInner::RightShift(box e, by) => { - let e = f.fold_uint_expression(e); - - UExpressionInner::RightShift(box e, by) - } - UExpressionInner::Not(box e) => { - let e = f.fold_uint_expression(e); - - UExpressionInner::Not(box e) - } + Add(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Add(e), + BinaryOrExpression::Expression(u) => u, + }, + Sub(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Sub(e), + BinaryOrExpression::Expression(u) => u, + }, + Mult(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Mult(e), + BinaryOrExpression::Expression(u) => u, + }, + Div(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Div(e), + BinaryOrExpression::Expression(u) => u, + }, + Rem(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Rem(e), + BinaryOrExpression::Expression(u) => u, + }, + Xor(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Xor(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + LeftShift(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => LeftShift(e), + BinaryOrExpression::Expression(u) => u, + }, + RightShift(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => RightShift(e), + BinaryOrExpression::Expression(u) => u, + }, + Not(e) => match f.fold_unary_expression(&ty, e) { + UnaryOrExpression::Unary(e) => Not(e), + UnaryOrExpression::Expression(u) => u, + }, UExpressionInner::Conditional(c) => match f.fold_conditional_expression(&ty, c) { - ConditionalOrExpression::Conditional(s) => UExpressionInner::Conditional(s), + ConditionalOrExpression::Conditional(s) => Conditional(s), ConditionalOrExpression::Expression(u) => u, }, } @@ -495,13 +686,14 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( ) -> ZirProgram<'ast, T> { ZirProgram { main: f.fold_function(p.main), + ..p } } pub fn fold_conditional_expression< 'ast, T: Field, - E: Expr<'ast, T> + Fold<'ast, T> + Conditional<'ast, T>, + E: Expr<'ast, T> + Fold + Conditional<'ast, T>, F: Folder<'ast, T>, >( f: &mut F, @@ -518,7 +710,7 @@ pub fn fold_conditional_expression< pub fn fold_select_expression< 'ast, T: Field, - E: Expr<'ast, T> + Fold<'ast, T> + Select<'ast, T>, + E: Expr<'ast, T> + Fold + Select<'ast, T>, F: Folder<'ast, T>, >( f: &mut F, @@ -530,3 +722,34 @@ pub fn fold_select_expression< e.index.fold(f), )) } + +pub fn fold_binary_expression< + 'ast, + T: Field, + L: Expr<'ast, T> + Fold + From>, + R: Expr<'ast, T> + Fold + From>, + E: Expr<'ast, T> + Fold + From>, + F: Folder<'ast, T>, + Op, +>( + f: &mut F, + _: &E::Ty, + e: BinaryExpression, +) -> BinaryOrExpression { + BinaryOrExpression::Binary(BinaryExpression::new(e.left.fold(f), e.right.fold(f)).span(e.span)) +} + +pub fn fold_unary_expression< + 'ast, + T: Field, + In: Expr<'ast, T> + Fold + From>, + E: Expr<'ast, T> + Fold + From>, + F: Folder<'ast, T>, + Op, +>( + f: &mut F, + _: &E::Ty, + e: UnaryExpression, +) -> UnaryOrExpression { + UnaryOrExpression::Unary(UnaryExpression::new(e.inner.fold(f)).span(e.span)) +} diff --git a/zokrates_ast/src/zir/identifier.rs b/zokrates_ast/src/zir/identifier.rs index f0c387df3..a7ae6b959 100644 --- a/zokrates_ast/src/zir/identifier.rs +++ b/zokrates_ast/src/zir/identifier.rs @@ -20,17 +20,6 @@ pub enum SourceIdentifier<'ast> { Element(Box>, u32), } -impl<'ast> fmt::Display for SourceIdentifier<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - SourceIdentifier::Basic(i) => write!(f, "{}", i), - SourceIdentifier::Select(box i, index) => write!(f, "{}~{}", i, index), - SourceIdentifier::Member(box i, m) => write!(f, "{}.{}", i, m), - SourceIdentifier::Element(box i, index) => write!(f, "{}.{}", i, index), - } - } -} - impl<'ast> Identifier<'ast> { pub fn internal>(id: T) -> Self { Identifier::Internal(id.into()) @@ -46,6 +35,17 @@ impl<'ast> fmt::Display for Identifier<'ast> { } } +impl<'ast> fmt::Display for SourceIdentifier<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + SourceIdentifier::Basic(i) => write!(f, "{}", i), + SourceIdentifier::Select(box i, index) => write!(f, "{}~{}", i, index), + SourceIdentifier::Member(box i, m) => write!(f, "{}.{}", i, m), + SourceIdentifier::Element(box i, index) => write!(f, "{}.{}", i, index), + } + } +} + // this is only used in tests but somehow cfg(test) does not work impl<'ast> From<&'ast str> for Identifier<'ast> { fn from(id: &'ast str) -> Identifier<'ast> { diff --git a/zokrates_ast/src/zir/lqc.rs b/zokrates_ast/src/zir/lqc.rs index 834f601f4..e34877c5a 100644 --- a/zokrates_ast/src/zir/lqc.rs +++ b/zokrates_ast/src/zir/lqc.rs @@ -109,23 +109,23 @@ impl<'ast, T: Field> TryFrom> for LinQuadComb<'a fn try_from(e: FieldElementExpression<'ast, T>) -> Result { match e { - FieldElementExpression::Number(v) => Ok(Self { - constant: v, + FieldElementExpression::Value(v) => Ok(Self { + constant: v.value, ..Self::default() }), FieldElementExpression::Identifier(id) => Ok(Self { linear: vec![(T::one(), id.id)], ..Self::default() }), - FieldElementExpression::Add(box left, box right) => { - Ok(Self::try_from(left)? + Self::try_from(right)?) + FieldElementExpression::Add(e) => { + Ok(Self::try_from(*e.left)? + Self::try_from(*e.right)?) } - FieldElementExpression::Sub(box left, box right) => { - Ok(Self::try_from(left)? - Self::try_from(right)?) + FieldElementExpression::Sub(e) => { + Ok(Self::try_from(*e.left)? - Self::try_from(*e.right)?) } - FieldElementExpression::Mult(box left, box right) => { - let left = Self::try_from(left)?; - let right = Self::try_from(right)?; + FieldElementExpression::Mult(e) => { + let left = Self::try_from(*e.left)?; + let right = Self::try_from(*e.right)?; left.try_mul(right) } @@ -137,30 +137,31 @@ impl<'ast, T: Field> TryFrom> for LinQuadComb<'a #[cfg(test)] mod tests { use super::*; - use crate::zir::Id; + use crate::zir::{Expr, Id}; + use std::ops::*; use zokrates_field::Bn128Field; #[test] fn add() { // (2 + 2*a) - let a = LinQuadComb::try_from(FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("a".into()), + let a = LinQuadComb::try_from(FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("a".into()), ), )) .unwrap(); // (2 + 2*a*b) - let b = LinQuadComb::try_from(FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Mult( - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("a".into()), + let b = LinQuadComb::try_from(FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::mul( + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("a".into()), ), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::identifier("b".into()), ), )) .unwrap(); @@ -186,24 +187,24 @@ mod tests { #[test] fn sub() { // (2 + 2*a) - let a = LinQuadComb::try_from(FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("a".into()), + let a = LinQuadComb::try_from(FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("a".into()), ), )) .unwrap(); // (2 + 2*a*b) - let b = LinQuadComb::try_from(FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Mult( - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("a".into()), + let b = LinQuadComb::try_from(FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::mul( + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("a".into()), ), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::identifier("b".into()), ), )) .unwrap(); @@ -227,23 +228,23 @@ mod tests { } #[test] - fn mult() { + fn mul() { // (2 + 2*a) - let a = LinQuadComb::try_from(FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("a".into()), + let a = LinQuadComb::try_from(FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("a".into()), ), )) .unwrap(); // (2 + 2*b) - let b = LinQuadComb::try_from(FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("b".into()), + let b = LinQuadComb::try_from(FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("b".into()), ), )) .unwrap(); @@ -266,13 +267,13 @@ mod tests { } #[test] - fn mult_degree_error() { + fn mul_degree_error() { // 2*a*b - let a = LinQuadComb::try_from(FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::identifier("b".into()), + let a = LinQuadComb::try_from(FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::mul( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("b".into()), ), )) .unwrap(); diff --git a/zokrates_ast/src/zir/mod.rs b/zokrates_ast/src/zir/mod.rs index 081109666..dcb4cefaf 100644 --- a/zokrates_ast/src/zir/mod.rs +++ b/zokrates_ast/src/zir/mod.rs @@ -12,13 +12,20 @@ mod variable; pub use self::parameter::Parameter; pub use self::types::{Type, UBitwidth}; pub use self::variable::Variable; -use crate::common::{FlatEmbed, FormatString, SourceMetadata}; +use crate::common::expressions::{BooleanValueExpression, UnaryExpression}; +use crate::common::SourceMetadata; +use crate::common::{self, FlatEmbed, ModuleMap, Span, Value, WithSpan}; +use crate::common::{ + expressions::{self, BinaryExpression, ValueExpression}, + operators::*, +}; use crate::typed::ConcreteType; pub use crate::zir::uint::{ShouldReduce, UExpression, UExpressionInner, UMetadata}; use crate::zir::types::Signature; use std::fmt; -use std::marker::PhantomData; + +use derivative::Derivative; use zokrates_field::Field; pub use self::folder::Folder; @@ -26,9 +33,12 @@ pub use self::identifier::{Identifier, SourceIdentifier}; use serde::{Deserialize, Serialize}; /// A typed program as a collection of modules, one of them being the main -#[derive(PartialEq, Eq, Debug, Clone)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug)] pub struct ZirProgram<'ast, T> { pub main: ZirFunction<'ast, T>, + pub module_map: ModuleMap, } impl<'ast, T: fmt::Display> fmt::Display for ZirProgram<'ast, T> { @@ -37,10 +47,11 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirProgram<'ast, T> { } } /// A typed function -#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash, Eq)] +#[derive(Clone, Serialize, Deserialize)] pub struct ZirFunction<'ast, T> { /// Arguments of the function - #[serde(borrow)] pub arguments: Vec>, /// Vector of statements that are executed when running the function #[serde(borrow)] @@ -49,6 +60,9 @@ pub struct ZirFunction<'ast, T> { pub signature: Signature, } +pub type IdentifierOrExpression<'ast, T, E> = + expressions::IdentifierOrExpression, E, >::Inner>; + impl<'ast, T: fmt::Display> fmt::Display for ZirFunction<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { writeln!( @@ -118,58 +132,263 @@ impl RuntimeError { } } +pub type AssemblyConstraint<'ast, T> = + crate::common::statements::AssemblyConstraint>; +pub type AssemblyAssignment<'ast, T> = + crate::common::statements::AssemblyAssignment>, ZirFunction<'ast, T>>; + #[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)] pub enum ZirAssemblyStatement<'ast, T> { - Assignment( - #[serde(borrow)] Vec>, - ZirFunction<'ast, T>, - ), - Constraint( - FieldElementExpression<'ast, T>, - FieldElementExpression<'ast, T>, - SourceMetadata, - ), + #[serde(borrow)] + Assignment(AssemblyAssignment<'ast, T>), + Constraint(AssemblyConstraint<'ast, T>), +} + +impl<'ast, T> ZirAssemblyStatement<'ast, T> { + pub fn assignment(assignee: Vec>, expression: ZirFunction<'ast, T>) -> Self { + ZirAssemblyStatement::Assignment(AssemblyAssignment::new(assignee, expression)) + } + + pub fn constraint( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + metadata: SourceMetadata, + ) -> Self { + ZirAssemblyStatement::Constraint(AssemblyConstraint::new(left, right, metadata)) + } +} + +impl<'ast, T> WithSpan for ZirAssemblyStatement<'ast, T> { + fn span(self, span: Option) -> Self { + match self { + ZirAssemblyStatement::Assignment(s) => ZirAssemblyStatement::Assignment(s.span(span)), + ZirAssemblyStatement::Constraint(s) => ZirAssemblyStatement::Constraint(s.span(span)), + } + } + + fn get_span(&self) -> Option { + match self { + ZirAssemblyStatement::Assignment(s) => s.get_span(), + ZirAssemblyStatement::Constraint(s) => s.get_span(), + } + } } impl<'ast, T: fmt::Display> fmt::Display for ZirAssemblyStatement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - ZirAssemblyStatement::Assignment(ref lhs, ref rhs) => { + ZirAssemblyStatement::Assignment(ref s) => { write!( f, "{} <-- {};", - lhs.iter() + s.assignee + .iter() .map(|a| a.to_string()) .collect::>() .join(", "), - rhs + s.expression ) } - ZirAssemblyStatement::Constraint(ref lhs, ref rhs, _) => { - write!(f, "{} === {};", lhs, rhs) + ZirAssemblyStatement::Constraint(ref s) => { + write!(f, "{}", s) } } } } +pub type DefinitionStatement<'ast, T> = + common::expressions::DefinitionStatement, ZirExpression<'ast, T>>; +pub type AssertionStatement<'ast, T> = + common::expressions::AssertionStatement, RuntimeError>; +pub type ReturnStatement<'ast, T> = + common::expressions::ReturnStatement>>; +pub type LogStatement<'ast, T> = + common::statements::LogStatement<(ConcreteType, Vec>)>; + +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct IfElseStatement<'ast, T> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + #[serde(borrow)] + pub condition: BooleanExpression<'ast, T>, + pub consequence: Vec>, + pub alternative: Vec>, +} + +impl<'ast, T> IfElseStatement<'ast, T> { + pub fn new( + condition: BooleanExpression<'ast, T>, + consequence: Vec>, + alternative: Vec>, + ) -> Self { + Self { + span: None, + condition, + consequence, + alternative, + } + } +} + +impl<'ast, T> WithSpan for IfElseStatement<'ast, T> { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct MultipleDefinitionStatement<'ast, T> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + #[serde(borrow)] + pub assignees: Vec>, + pub rhs: ZirExpressionList<'ast, T>, +} + +impl<'ast, T> MultipleDefinitionStatement<'ast, T> { + pub fn new(assignees: Vec>, rhs: ZirExpressionList<'ast, T>) -> Self { + Self { + span: None, + assignees, + rhs, + } + } +} + +impl<'ast, T> WithSpan for MultipleDefinitionStatement<'ast, T> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl<'ast, T: fmt::Display> fmt::Display for MultipleDefinitionStatement<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for (i, id) in self.assignees.iter().enumerate() { + write!(f, "{}", id)?; + if i < self.assignees.len() - 1 { + write!(f, ", ")?; + } + } + write!(f, " = {};", self.rhs) + } +} + +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AssemblyBlockStatement<'ast, T> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + #[serde(borrow)] + pub inner: Vec>, +} + +impl<'ast, T> AssemblyBlockStatement<'ast, T> { + pub fn new(inner: Vec>) -> Self { + Self { span: None, inner } + } +} + +impl<'ast, T> WithSpan for AssemblyBlockStatement<'ast, T> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + /// A statement in a `ZirFunction` -#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum ZirStatement<'ast, T> { - Return(Vec>), - Definition(ZirAssignee<'ast>, ZirExpression<'ast, T>), - IfElse( - BooleanExpression<'ast, T>, - Vec>, - Vec>, - ), - Assertion(BooleanExpression<'ast, T>, RuntimeError), - MultipleDefinition(Vec>, ZirExpressionList<'ast, T>), - Log( - FormatString, - Vec<(ConcreteType, Vec>)>, - ), + Return(ReturnStatement<'ast, T>), + Definition(DefinitionStatement<'ast, T>), + IfElse(IfElseStatement<'ast, T>), + Assertion(AssertionStatement<'ast, T>), + MultipleDefinition(MultipleDefinitionStatement<'ast, T>), + Log(LogStatement<'ast, T>), #[serde(borrow)] - Assembly(Vec>), + Assembly(AssemblyBlockStatement<'ast, T>), +} + +impl<'ast, T> ZirStatement<'ast, T> { + pub fn definition(a: ZirAssignee<'ast>, e: ZirExpression<'ast, T>) -> Self { + Self::Definition(DefinitionStatement::new(a, e)) + } + + pub fn multiple_definition( + assignees: Vec>, + e: ZirExpressionList<'ast, T>, + ) -> Self { + Self::MultipleDefinition(MultipleDefinitionStatement::new(assignees, e)) + } + + pub fn assertion(e: BooleanExpression<'ast, T>, error: RuntimeError) -> Self { + Self::Assertion(AssertionStatement::new(e, error)) + } + + pub fn ret(e: Vec>) -> Self { + Self::Return(ReturnStatement::new(e)) + } + + pub fn assembly(s: Vec>) -> Self { + Self::Assembly(AssemblyBlockStatement::new(s)) + } + + pub fn if_else( + condition: BooleanExpression<'ast, T>, + consequence: Vec>, + alternative: Vec>, + ) -> Self { + Self::IfElse(IfElseStatement::new(condition, consequence, alternative)) + } +} + +impl<'ast, T> WithSpan for ZirStatement<'ast, T> { + fn span(self, span: Option) -> Self { + use ZirStatement::*; + + match self { + Return(e) => Return(e.span(span)), + Definition(e) => Definition(e.span(span)), + Assertion(e) => Assertion(e.span(span)), + IfElse(e) => IfElse(e.span(span)), + MultipleDefinition(e) => MultipleDefinition(e.span(span)), + Log(e) => Log(e.span(span)), + Assembly(e) => Assembly(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use ZirStatement::*; + + match self { + Return(e) => e.get_span(), + Definition(e) => e.get_span(), + Assertion(e) => e.get_span(), + IfElse(e) => e.get_span(), + MultipleDefinition(e) => e.get_span(), + Log(e) => e.get_span(), + Assembly(e) => e.get_span(), + } + } } impl<'ast, T: fmt::Display> fmt::Display for ZirStatement<'ast, T> { @@ -181,14 +400,15 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirStatement<'ast, T> { impl<'ast, T: fmt::Display> ZirStatement<'ast, T> { fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result { write!(f, "{}", "\t".repeat(depth))?; + match self { - ZirStatement::Return(ref exprs) => { + ZirStatement::Return(ref s) => { write!(f, "return")?; - if !exprs.is_empty() { + if !s.inner.is_empty() { write!( f, " {}", - exprs + s.inner .iter() .map(|e| e.to_string()) .collect::>() @@ -197,43 +417,37 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> { } write!(f, ";") } - ZirStatement::Definition(ref lhs, ref rhs) => { - write!(f, "{} = {};", lhs, rhs) + ZirStatement::Definition(ref s) => { + write!(f, "{}", s) } - ZirStatement::IfElse(ref condition, ref consequence, ref alternative) => { - writeln!(f, "if {} {{", condition)?; - for s in consequence { + ZirStatement::IfElse(ref s) => { + writeln!(f, "if {} {{", s.condition)?; + for s in &s.consequence { s.fmt_indented(f, depth + 1)?; writeln!(f)?; } writeln!(f, "{}}} else {{", "\t".repeat(depth))?; - for s in alternative { + for s in &s.alternative { s.fmt_indented(f, depth + 1)?; writeln!(f)?; } write!(f, "{}}}", "\t".repeat(depth)) } - ZirStatement::Assertion(ref e, ref error) => { - write!(f, "assert({}", e)?; - match error { + ZirStatement::Assertion(ref s) => { + write!(f, "assert({}", s.expression)?; + match &s.error { RuntimeError::SourceAssertion(message) => write!(f, ", \"{}\");", message), error => write!(f, "); // {}", error), } } - ZirStatement::MultipleDefinition(ref ids, ref rhs) => { - for (i, id) in ids.iter().enumerate() { - write!(f, "{}", id)?; - if i < ids.len() - 1 { - write!(f, ", ")?; - } - } - write!(f, " = {};", rhs) + ZirStatement::MultipleDefinition(ref s) => { + write!(f, "{}", s) } - ZirStatement::Log(ref l, ref expressions) => write!( + ZirStatement::Log(ref e) => write!( f, "log(\"{}\"), {});", - l, - expressions + e.format_string, + e.expressions .iter() .map(|(_, e)| format!( "[{}]", @@ -245,9 +459,9 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> { .collect::>() .join(", ") ), - ZirStatement::Assembly(statements) => { + ZirStatement::Assembly(s) => { writeln!(f, "asm {{")?; - for s in statements { + for s in &s.inner { writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?; } write!(f, "{}}}", "\t".repeat(depth)) @@ -260,30 +474,12 @@ pub trait Typed { fn get_type(&self) -> Type; } -#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] -pub struct IdentifierExpression<'ast, E> { - #[serde(borrow)] - pub id: Identifier<'ast>, - ty: PhantomData, -} - -impl<'ast, E> fmt::Display for IdentifierExpression<'ast, E> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.id) - } -} - -impl<'ast, E> IdentifierExpression<'ast, E> { - pub fn new(id: Identifier<'ast>) -> Self { - IdentifierExpression { - id, - ty: PhantomData, - } - } -} - -#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct ConditionalExpression<'ast, T, E> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, #[serde(borrow)] pub condition: Box>, pub consequence: Box, @@ -293,6 +489,7 @@ pub struct ConditionalExpression<'ast, T, E> { impl<'ast, T, E> ConditionalExpression<'ast, T, E> { pub fn new(condition: BooleanExpression<'ast, T>, consequence: E, alternative: E) -> Self { ConditionalExpression { + span: None, condition: box condition, consequence: box consequence, alternative: box alternative, @@ -305,13 +502,28 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for ConditionalExpress write!( f, "{} ? {} : {}", - self.condition, self.consequence, self.alternative + self.condition, self.consequence, self.alternative, ) } } -#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] +impl<'ast, T, E> WithSpan for ConditionalExpression<'ast, T, E> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct SelectExpression<'ast, T, E> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, pub array: Vec, #[serde(borrow)] pub index: Box>, @@ -320,6 +532,7 @@ pub struct SelectExpression<'ast, T, E> { impl<'ast, T, E> SelectExpression<'ast, T, E> { pub fn new(array: Vec, index: UExpression<'ast, T>) -> Self { SelectExpression { + span: None, array, index: box index, } @@ -341,12 +554,46 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for SelectExpression<' } } +impl<'ast, T, E> WithSpan for SelectExpression<'ast, T, E> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + /// A typed expression -#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Serialize, Deserialize)] pub enum ZirExpression<'ast, T> { + #[serde(borrow)] Boolean(BooleanExpression<'ast, T>), FieldElement(FieldElementExpression<'ast, T>), - Uint(#[serde(borrow)] UExpression<'ast, T>), + Uint(UExpression<'ast, T>), +} + +impl<'ast, T> WithSpan for ZirExpression<'ast, T> { + fn span(self, span: Option) -> Self { + use ZirExpression::*; + match self { + Boolean(e) => Boolean(e.span(span)), + FieldElement(e) => FieldElement(e.span(span)), + Uint(e) => Uint(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use ZirExpression::*; + match self { + Boolean(e) => e.get_span(), + FieldElement(e) => e.get_span(), + Uint(e) => e.get_span(), + } + } } impl<'ast, T: Field> From> for ZirExpression<'ast, T> { @@ -419,7 +666,9 @@ pub trait MultiTyped { fn get_types(&self) -> &Vec; } -#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Serialize, Deserialize)] pub enum ZirExpressionList<'ast, T> { EmbedCall( FlatEmbed, @@ -428,111 +677,126 @@ pub enum ZirExpressionList<'ast, T> { ), } +pub type IdentifierExpression<'ast, E> = expressions::IdentifierExpression, E>; + /// An expression of type `field` -#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Eq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum FieldElementExpression<'ast, T> { - Number(T), + Value(ValueExpression), #[serde(borrow)] Identifier(IdentifierExpression<'ast, Self>), Select(SelectExpression<'ast, T, Self>), - Add( - Box>, - Box>, - ), - Sub( - Box>, - Box>, - ), - Mult( - Box>, - Box>, - ), - Div( - Box>, - Box>, - ), - Pow( - Box>, - #[serde(borrow)] Box>, - ), - And( - Box>, - Box>, - ), - Or( - Box>, - Box>, - ), - Xor( - Box>, - Box>, - ), - LeftShift( - Box>, - Box>, - ), - RightShift( - Box>, - Box>, - ), + Add(BinaryExpression), + Sub(BinaryExpression), + Mult(BinaryExpression), + Div(BinaryExpression), + Pow(BinaryExpression, Self>), + And(BinaryExpression), + Or(BinaryExpression), + Xor(BinaryExpression), + LeftShift(BinaryExpression, Self>), + RightShift(BinaryExpression, Self>), Conditional(ConditionalExpression<'ast, T, FieldElementExpression<'ast, T>>), } impl<'ast, T> FieldElementExpression<'ast, T> { + pub fn number(n: T) -> Self { + Self::Value(ValueExpression::new(n)) + } + + pub fn pow(self, right: UExpression<'ast, T>) -> Self { + Self::Pow(BinaryExpression::new(self, right)) + } + pub fn is_linear(&self) -> bool { match self { - FieldElementExpression::Number(_) => true, + FieldElementExpression::Value(_) => true, FieldElementExpression::Identifier(_) => true, - FieldElementExpression::Add(box left, box right) => { - left.is_linear() && right.is_linear() - } - FieldElementExpression::Sub(box left, box right) => { - left.is_linear() && right.is_linear() - } - FieldElementExpression::Mult(box left, box right) => matches!( - (left, right), - (FieldElementExpression::Number(_), _) | (_, FieldElementExpression::Number(_)) + FieldElementExpression::Add(e) => e.left.is_linear() && e.right.is_linear(), + FieldElementExpression::Sub(e) => e.left.is_linear() && e.right.is_linear(), + FieldElementExpression::Mult(e) => matches!( + (&*e.left, &*e.right), + (FieldElementExpression::Value(_), _) | (_, FieldElementExpression::Value(_)) ), _ => false, } } + + pub fn left_shift(self, by: UExpression<'ast, T>) -> Self { + FieldElementExpression::LeftShift(BinaryExpression::new(self, by)) + } + + pub fn right_shift(self, by: UExpression<'ast, T>) -> Self { + FieldElementExpression::RightShift(BinaryExpression::new(self, by)) + } +} + +impl<'ast, T: Field> std::ops::BitAnd for FieldElementExpression<'ast, T> { + type Output = Self; + + fn bitand(self, other: Self) -> Self { + FieldElementExpression::And(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Field> std::ops::BitOr for FieldElementExpression<'ast, T> { + type Output = Self; + + fn bitor(self, other: Self) -> Self { + FieldElementExpression::Or(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Field> std::ops::BitXor for FieldElementExpression<'ast, T> { + type Output = Self; + + fn bitxor(self, other: Self) -> Self { + FieldElementExpression::Xor(BinaryExpression::new(self, other)) + } } /// An expression of type `bool` -#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum BooleanExpression<'ast, T> { - Value(bool), + Value(BooleanValueExpression), #[serde(borrow)] Identifier(IdentifierExpression<'ast, Self>), Select(SelectExpression<'ast, T, Self>), FieldLt( - Box>, - Box>, + BinaryExpression< + OpLt, + FieldElementExpression<'ast, T>, + FieldElementExpression<'ast, T>, + Self, + >, ), FieldLe( - Box>, - Box>, + BinaryExpression< + OpLe, + FieldElementExpression<'ast, T>, + FieldElementExpression<'ast, T>, + Self, + >, ), FieldEq( - Box>, - Box>, - ), - UintLt(Box>, Box>), - UintLe(Box>, Box>), - UintEq(Box>, Box>), - BoolEq( - Box>, - Box>, - ), - Or( - Box>, - Box>, - ), - And( - Box>, - Box>, + BinaryExpression< + OpEq, + FieldElementExpression<'ast, T>, + FieldElementExpression<'ast, T>, + Self, + >, ), - Not(Box>), + UintLt(BinaryExpression, UExpression<'ast, T>, Self>), + UintLe(BinaryExpression, UExpression<'ast, T>, Self>), + UintEq(BinaryExpression, UExpression<'ast, T>, Self>), + BoolEq(BinaryExpression), + Or(BinaryExpression), + And(BinaryExpression), + Not(UnaryExpression), Conditional(ConditionalExpression<'ast, T, BooleanExpression<'ast, T>>), } @@ -545,9 +809,9 @@ impl<'ast, T> Iterator for ConjunctionIterator> { fn next(&mut self) -> Option { self.current.pop().and_then(|n| match n { - BooleanExpression::And(box left, box right) => { - self.current.push(left); - self.current.push(right); + BooleanExpression::And(e) => { + self.current.push(*e.left); + self.current.push(*e.right); self.next() } n => Some(n), @@ -594,23 +858,19 @@ impl<'ast, T> From> for UExpression<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - FieldElementExpression::Number(ref i) => write!(f, "{}", i), + FieldElementExpression::Value(ref i) => write!(f, "{}", i), FieldElementExpression::Identifier(ref var) => write!(f, "{}", var), FieldElementExpression::Select(ref e) => write!(f, "{}", e), - FieldElementExpression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), - FieldElementExpression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs), - FieldElementExpression::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), - FieldElementExpression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs), - FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs), - FieldElementExpression::And(ref lhs, ref rhs) => write!(f, "({} & {})", lhs, rhs), - FieldElementExpression::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs), - FieldElementExpression::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs), - FieldElementExpression::LeftShift(ref lhs, ref rhs) => { - write!(f, "({} << {})", lhs, rhs) - } - FieldElementExpression::RightShift(ref lhs, ref rhs) => { - write!(f, "({} >> {})", lhs, rhs) - } + FieldElementExpression::Add(ref e) => write!(f, "{}", e), + FieldElementExpression::Sub(ref e) => write!(f, "{}", e), + FieldElementExpression::Mult(ref e) => write!(f, "{}", e), + FieldElementExpression::Div(ref e) => write!(f, "{}", e), + FieldElementExpression::Pow(ref e) => write!(f, "{}", e), + FieldElementExpression::And(ref e) => write!(f, "{}", e), + FieldElementExpression::Or(ref e) => write!(f, "{}", e), + FieldElementExpression::Xor(ref e) => write!(f, "{}", e), + FieldElementExpression::LeftShift(ref e) => write!(f, "{}", e), + FieldElementExpression::RightShift(ref e) => write!(f, "{}", e), FieldElementExpression::Conditional(ref c) => { write!(f, "{}", c) } @@ -624,17 +884,17 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> { UExpressionInner::Value(ref v) => write!(f, "{}", v), UExpressionInner::Identifier(ref var) => write!(f, "{}", var), UExpressionInner::Select(ref e) => write!(f, "{}", e), - UExpressionInner::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), - UExpressionInner::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs), - UExpressionInner::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), - UExpressionInner::Div(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), - UExpressionInner::Rem(ref lhs, ref rhs) => write!(f, "({} % {})", lhs, rhs), - UExpressionInner::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs), - UExpressionInner::And(ref lhs, ref rhs) => write!(f, "({} & {})", lhs, rhs), - UExpressionInner::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs), - UExpressionInner::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by), - UExpressionInner::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by), - UExpressionInner::Not(ref e) => write!(f, "!{}", e), + UExpressionInner::Add(ref e) => write!(f, "{}", e), + UExpressionInner::Sub(ref e) => write!(f, "{}", e), + UExpressionInner::Mult(ref e) => write!(f, "{}", e), + UExpressionInner::Div(ref e) => write!(f, "{}", e), + UExpressionInner::Rem(ref e) => write!(f, "{}", e), + UExpressionInner::Xor(ref e) => write!(f, "{}", e), + UExpressionInner::And(ref e) => write!(f, "{}", e), + UExpressionInner::Or(ref e) => write!(f, "{}", e), + UExpressionInner::LeftShift(ref e) => write!(f, "{}", e), + UExpressionInner::RightShift(ref e) => write!(f, "{}", e), + UExpressionInner::Not(ref e) => write!(f, "{}", e), UExpressionInner::Conditional(ref c) => { write!(f, "{}", c) } @@ -646,18 +906,18 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { BooleanExpression::Identifier(ref var) => write!(f, "{}", var), - BooleanExpression::Value(b) => write!(f, "{}", b), + BooleanExpression::Value(ref b) => write!(f, "{}", b), BooleanExpression::Select(ref e) => write!(f, "{}", e), - BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "({} < {})", lhs, rhs), - BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "({} < {})", lhs, rhs), - BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "({} <= {})", lhs, rhs), - BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "({} <= {})", lhs, rhs), - BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs), - BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs), - BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs), - BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "({} || {})", lhs, rhs), - BooleanExpression::And(ref lhs, ref rhs) => write!(f, "({} && {})", lhs, rhs), - BooleanExpression::Not(ref exp) => write!(f, "!{}", exp), + BooleanExpression::FieldLt(ref e) => write!(f, "{}", e), + BooleanExpression::UintLt(ref e) => write!(f, "{}", e), + BooleanExpression::FieldLe(ref e) => write!(f, "{}", e), + BooleanExpression::UintLe(ref e) => write!(f, "{}", e), + BooleanExpression::FieldEq(ref e) => write!(f, "{}", e), + BooleanExpression::BoolEq(ref e) => write!(f, "{}", e), + BooleanExpression::UintEq(ref e) => write!(f, "{}", e), + BooleanExpression::Or(ref e) => write!(f, "{}", e), + BooleanExpression::And(ref e) => write!(f, "{}", e), + BooleanExpression::Not(ref exp) => write!(f, "{}", exp), BooleanExpression::Conditional(ref c) => { write!(f, "{}", c) } @@ -665,6 +925,69 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> { } } +impl<'ast, T> std::ops::Not for BooleanExpression<'ast, T> { + type Output = Self; + + fn not(self) -> Self { + Self::Not(UnaryExpression::new(self)) + } +} + +impl<'ast, T> std::ops::BitAnd for BooleanExpression<'ast, T> { + type Output = Self; + + fn bitand(self, other: Self) -> Self { + Self::And(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T> std::ops::BitOr for BooleanExpression<'ast, T> { + type Output = Self; + + fn bitor(self, other: Self) -> Self { + Self::Or(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T> BooleanExpression<'ast, T> { + pub fn uint_eq(left: UExpression<'ast, T>, right: UExpression<'ast, T>) -> Self { + Self::UintEq(BinaryExpression::new(left, right)) + } + + pub fn bool_eq(left: BooleanExpression<'ast, T>, right: BooleanExpression<'ast, T>) -> Self { + Self::BoolEq(BinaryExpression::new(left, right)) + } + + pub fn field_eq( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + ) -> Self { + Self::FieldEq(BinaryExpression::new(left, right)) + } + + pub fn uint_lt(left: UExpression<'ast, T>, right: UExpression<'ast, T>) -> Self { + Self::UintLt(BinaryExpression::new(left, right)) + } + + pub fn uint_le(left: UExpression<'ast, T>, right: UExpression<'ast, T>) -> Self { + Self::UintLe(BinaryExpression::new(left, right)) + } + + pub fn field_lt( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + ) -> Self { + Self::FieldLt(BinaryExpression::new(left, right)) + } + + pub fn field_le( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + ) -> Self { + Self::FieldLe(BinaryExpression::new(left, right)) + } +} + impl<'ast, T: fmt::Display> fmt::Display for ZirExpressionList<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { @@ -710,8 +1033,52 @@ impl<'ast, T: fmt::Debug> fmt::Debug for ZirExpressionList<'ast, T> { } } +impl<'ast, T: Field> std::ops::Add for FieldElementExpression<'ast, T> { + type Output = Self; + + fn add(self, other: Self) -> Self { + FieldElementExpression::Add(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Field> std::ops::Sub for FieldElementExpression<'ast, T> { + type Output = Self; + + fn sub(self, other: Self) -> Self { + FieldElementExpression::Sub(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Field> std::ops::Mul for FieldElementExpression<'ast, T> { + type Output = Self; + + fn mul(self, other: Self) -> Self { + FieldElementExpression::Mult(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Field> std::ops::Div for FieldElementExpression<'ast, T> { + type Output = Self; + + fn div(self, other: Self) -> Self { + FieldElementExpression::Div(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Clone> Value for FieldElementExpression<'ast, T> { + type Value = T; +} + +impl<'ast, T> Value for BooleanExpression<'ast, T> { + type Value = bool; +} + +impl<'ast, T> Value for UExpression<'ast, T> { + type Value = u128; +} + // Common behaviour across expressions -pub trait Expr<'ast, T>: fmt::Display + PartialEq + From> { +pub trait Expr<'ast, T>: Value + fmt::Display + PartialEq + From> { type Inner; type Ty: Clone + IntoType; @@ -722,6 +1089,8 @@ pub trait Expr<'ast, T>: fmt::Display + PartialEq + From> fn as_inner(&self) -> &Self::Inner; fn as_inner_mut(&mut self) -> &mut Self::Inner; + + fn value(_: Self::Value) -> Self::Inner; } impl<'ast, T: Field> Expr<'ast, T> for FieldElementExpression<'ast, T> { @@ -743,6 +1112,10 @@ impl<'ast, T: Field> Expr<'ast, T> for FieldElementExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { self } + + fn value(v: ::Value) -> Self::Inner { + Self::Value(ValueExpression::new(v)) + } } impl<'ast, T: Field> Expr<'ast, T> for BooleanExpression<'ast, T> { @@ -764,6 +1137,10 @@ impl<'ast, T: Field> Expr<'ast, T> for BooleanExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { self } + + fn value(v: as Value>::Value) -> Self::Inner { + Self::Value(ValueExpression::new(v)) + } } impl<'ast, T: Field> Expr<'ast, T> for UExpression<'ast, T> { @@ -785,6 +1162,10 @@ impl<'ast, T: Field> Expr<'ast, T> for UExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { &mut self.inner } + + fn value(v: Self::Value) -> Self::Inner { + UExpressionInner::Value(ValueExpression::new(v)) + } } pub trait Id<'ast, T>: Expr<'ast, T> { @@ -809,11 +1190,6 @@ impl<'ast, T: Field> Id<'ast, T> for UExpression<'ast, T> { } } -pub enum IdentifierOrExpression<'ast, T, E: Expr<'ast, T>> { - Identifier(IdentifierExpression<'ast, E>), - Expression(E::Inner), -} - pub trait Conditional<'ast, T> { fn conditional( condition: BooleanExpression<'ast, T>, @@ -930,9 +1306,150 @@ impl<'ast, T: Field> Constant for ZirExpression<'ast, T> { } } +impl<'ast, T> WithSpan for FieldElementExpression<'ast, T> { + fn span(self, span: Option) -> Self { + use FieldElementExpression::*; + match self { + Select(e) => Select(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + Add(e) => Add(e.span(span)), + Value(e) => Value(e.span(span)), + Sub(e) => Sub(e.span(span)), + Mult(e) => Mult(e.span(span)), + Div(e) => Div(e.span(span)), + Pow(e) => Pow(e.span(span)), + And(e) => And(e.span(span)), + Or(e) => Or(e.span(span)), + Xor(e) => Xor(e.span(span)), + LeftShift(e) => LeftShift(e.span(span)), + RightShift(e) => RightShift(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use FieldElementExpression::*; + match self { + Select(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + Add(e) => e.get_span(), + Value(e) => e.get_span(), + Sub(e) => e.get_span(), + Mult(e) => e.get_span(), + Div(e) => e.get_span(), + Pow(e) => e.get_span(), + And(e) => e.get_span(), + Or(e) => e.get_span(), + Xor(e) => e.get_span(), + LeftShift(e) => e.get_span(), + RightShift(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for BooleanExpression<'ast, T> { + fn span(self, span: Option) -> Self { + use BooleanExpression::*; + match self { + Select(e) => Select(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + Value(e) => Value(e.span(span)), + FieldLt(e) => FieldLt(e.span(span)), + FieldLe(e) => FieldLe(e.span(span)), + FieldEq(e) => FieldEq(e.span(span)), + UintLt(e) => UintLt(e.span(span)), + UintLe(e) => UintLe(e.span(span)), + UintEq(e) => UintEq(e.span(span)), + BoolEq(e) => BoolEq(e.span(span)), + Or(e) => Or(e.span(span)), + And(e) => And(e.span(span)), + Not(e) => Not(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use BooleanExpression::*; + match self { + Select(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + Value(e) => e.get_span(), + FieldLt(e) => e.get_span(), + FieldLe(e) => e.get_span(), + FieldEq(e) => e.get_span(), + UintLt(e) => e.get_span(), + UintLe(e) => e.get_span(), + UintEq(e) => e.get_span(), + BoolEq(e) => e.get_span(), + Or(e) => e.get_span(), + And(e) => e.get_span(), + Not(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for UExpression<'ast, T> { + fn span(self, span: Option) -> Self { + Self { + inner: self.inner.span(span), + ..self + } + } + + fn get_span(&self) -> Option { + self.inner.get_span() + } +} + +impl<'ast, T> WithSpan for UExpressionInner<'ast, T> { + fn span(self, span: Option) -> Self { + use UExpressionInner::*; + match self { + Select(e) => Select(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + Value(e) => Value(e.span(span)), + Add(e) => Add(e.span(span)), + Sub(e) => Sub(e.span(span)), + Mult(e) => Mult(e.span(span)), + Div(e) => Div(e.span(span)), + Rem(e) => Rem(e.span(span)), + Xor(e) => Xor(e.span(span)), + And(e) => And(e.span(span)), + Or(e) => Or(e.span(span)), + LeftShift(e) => LeftShift(e.span(span)), + RightShift(e) => RightShift(e.span(span)), + Not(e) => Not(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use UExpressionInner::*; + match self { + Select(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + Value(e) => e.get_span(), + Add(e) => e.get_span(), + Sub(e) => e.get_span(), + Mult(e) => e.get_span(), + Div(e) => e.get_span(), + Rem(e) => e.get_span(), + Xor(e) => e.get_span(), + And(e) => e.get_span(), + Or(e) => e.get_span(), + LeftShift(e) => e.get_span(), + RightShift(e) => e.get_span(), + Not(e) => e.get_span(), + } + } +} + impl<'ast, T: Field> Constant for FieldElementExpression<'ast, T> { fn is_constant(&self) -> bool { - matches!(self, FieldElementExpression::Number(..)) + matches!(self, FieldElementExpression::Value(..)) } } diff --git a/zokrates_ast/src/zir/parameter.rs b/zokrates_ast/src/zir/parameter.rs index 203a291a7..ac413e54a 100644 --- a/zokrates_ast/src/zir/parameter.rs +++ b/zokrates_ast/src/zir/parameter.rs @@ -1,33 +1,3 @@ -use crate::zir::Variable; -use serde::{Deserialize, Serialize}; -use std::fmt; +use super::Variable; -#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] -pub struct Parameter<'ast> { - #[serde(borrow)] - pub id: Variable<'ast>, - pub private: bool, -} - -impl<'ast> Parameter<'ast> { - #[cfg(test)] - pub fn private(v: Variable<'ast>) -> Self { - Parameter { - id: v, - private: true, - } - } -} - -impl<'ast> fmt::Display for Parameter<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let visibility = if self.private { "private " } else { "" }; - write!(f, "{}{} {}", visibility, self.id.get_type(), self.id.id) - } -} - -impl<'ast> fmt::Debug for Parameter<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Parameter(variable: {:?})", self.id) - } -} +pub type Parameter<'ast> = crate::common::Parameter>; diff --git a/zokrates_ast/src/zir/result_folder.rs b/zokrates_ast/src/zir/result_folder.rs index 6ea741cf6..d368d972d 100644 --- a/zokrates_ast/src/zir/result_folder.rs +++ b/zokrates_ast/src/zir/result_folder.rs @@ -1,27 +1,30 @@ // Generic walk through ZIR. Not mutating in place +use crate::common::expressions::{BinaryOrExpression, IdentifierOrExpression, UnaryOrExpression}; +use crate::common::ResultFold; +use crate::common::WithSpan; use crate::zir::types::UBitwidth; use crate::zir::*; use zokrates_field::Field; -pub trait ResultFold<'ast, T: Field>: Sized { - fn fold>(self, f: &mut F) -> Result; -} - -impl<'ast, T: Field> ResultFold<'ast, T> for FieldElementExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold + for FieldElementExpression<'ast, T> +{ + fn fold(self, f: &mut F) -> Result { f.fold_field_expression(self) } } -impl<'ast, T: Field> ResultFold<'ast, T> for BooleanExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold + for BooleanExpression<'ast, T> +{ + fn fold(self, f: &mut F) -> Result { f.fold_boolean_expression(self) } } -impl<'ast, T: Field> ResultFold<'ast, T> for UExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold for UExpression<'ast, T> { + fn fold(self, f: &mut F) -> Result { f.fold_uint_expression(self) } } @@ -61,6 +64,20 @@ pub trait ResultFolder<'ast, T: Field>: Sized { self.fold_variable(a) } + fn fold_assembly_constraint( + &mut self, + s: AssemblyConstraint<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_constraint(self, s) + } + + fn fold_assembly_assignment( + &mut self, + s: AssemblyAssignment<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_assignment(self, s) + } + fn fold_assembly_statement( &mut self, s: ZirAssemblyStatement<'ast, T>, @@ -68,6 +85,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_assembly_statement(self, s) } + fn fold_assembly_statement_cases( + &mut self, + s: ZirAssemblyStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_statement_cases(self, s) + } + fn fold_statement( &mut self, s: ZirStatement<'ast, T>, @@ -75,6 +99,62 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_statement(self, s) } + fn fold_statement_cases( + &mut self, + s: ZirStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_statement_cases(self, s) + } + + fn fold_definition_statement( + &mut self, + s: DefinitionStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_definition_statement(self, s) + } + + fn fold_multiple_definition_statement( + &mut self, + s: MultipleDefinitionStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_multiple_definition_statement(self, s) + } + + fn fold_return_statement( + &mut self, + s: ReturnStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_return_statement(self, s) + } + + fn fold_log_statement( + &mut self, + s: LogStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_log_statement(self, s) + } + + fn fold_assembly_block( + &mut self, + s: AssemblyBlockStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_block(self, s) + } + + fn fold_assertion_statement( + &mut self, + s: AssertionStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_assertion_statement(self, s) + } + + fn fold_if_else_statement( + &mut self, + s: IfElseStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_if_else_statement(self, s) + } + fn fold_expression( &mut self, e: ZirExpression<'ast, T>, @@ -104,16 +184,18 @@ pub trait ResultFolder<'ast, T: Field>: Sized { } } - fn fold_identifier_expression + Id<'ast, T> + ResultFold<'ast, T>>( + fn fold_identifier_expression< + E: Expr<'ast, T> + Id<'ast, T> + ResultFold, + >( &mut self, ty: &E::Ty, id: IdentifierExpression<'ast, E>, - ) -> Result, Self::Error> { + ) -> Result, E, E::Inner>, Self::Error> { fold_identifier_expression(self, ty, id) } fn fold_conditional_expression< - E: Expr<'ast, T> + ResultFold<'ast, T> + Conditional<'ast, T>, + E: Expr<'ast, T> + ResultFold + Conditional<'ast, T>, >( &mut self, ty: &E::Ty, @@ -122,7 +204,35 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_conditional_expression(self, ty, e) } - fn fold_select_expression + ResultFold<'ast, T> + Select<'ast, T>>( + #[allow(clippy::type_complexity)] + fn fold_binary_expression< + L: Expr<'ast, T> + PartialEq + ResultFold, + R: Expr<'ast, T> + PartialEq + ResultFold, + E: Expr<'ast, T> + PartialEq + ResultFold, + Op, + >( + &mut self, + ty: &E::Ty, + e: BinaryExpression, + ) -> Result, Self::Error> { + fold_binary_expression(self, ty, e) + } + + fn fold_unary_expression< + In: Expr<'ast, T> + PartialEq + ResultFold, + E: Expr<'ast, T> + PartialEq + ResultFold, + Op, + >( + &mut self, + ty: &E::Ty, + e: UnaryExpression, + ) -> Result, Self::Error> { + fold_unary_expression(self, ty, e) + } + + fn fold_select_expression< + E: Clone + Expr<'ast, T> + ResultFold + Select<'ast, T>, + >( &mut self, ty: &E::Ty, e: SelectExpression<'ast, T, E>, @@ -137,6 +247,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_field_expression(self, e) } + fn fold_field_expression_cases( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_field_expression_cases(self, e) + } + fn fold_boolean_expression( &mut self, e: BooleanExpression<'ast, T>, @@ -144,6 +261,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_boolean_expression(self, e) } + fn fold_boolean_expression_cases( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_boolean_expression_cases(self, e) + } + fn fold_uint_expression( &mut self, e: UExpression<'ast, T>, @@ -158,102 +282,213 @@ pub trait ResultFolder<'ast, T: Field>: Sized { ) -> Result, Self::Error> { fold_uint_expression_inner(self, bitwidth, e) } + + fn fold_uint_expression_cases( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + fold_uint_expression_cases(self, bitwidth, e) + } +} + +pub fn fold_assembly_assignment<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: AssemblyAssignment<'ast, T>, +) -> Result>, F::Error> { + let assignees = s + .assignee + .into_iter() + .map(|a| f.fold_assignee(a)) + .collect::>()?; + let expression = f.fold_function(s.expression)?; + Ok(vec![ZirAssemblyStatement::assignment( + assignees, expression, + )]) +} + +pub fn fold_assembly_constraint<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: AssemblyConstraint<'ast, T>, +) -> Result>, F::Error> { + let left = f.fold_field_expression(s.left)?; + let right = f.fold_field_expression(s.right)?; + Ok(vec![ZirAssemblyStatement::constraint( + left, right, s.metadata, + )]) } -pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + +fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, s: ZirAssemblyStatement<'ast, T>, ) -> Result>, F::Error> { - Ok(match s { - ZirAssemblyStatement::Assignment(assignees, function) => { - let assignees = assignees - .into_iter() - .map(|a| f.fold_assignee(a)) - .collect::>()?; - let function = f.fold_function(function)?; - vec![ZirAssemblyStatement::Assignment(assignees, function)] - } - ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => { - let lhs = f.fold_field_expression(lhs)?; - let rhs = f.fold_field_expression(rhs)?; - vec![ZirAssemblyStatement::Constraint(lhs, rhs, metadata)] - } - }) + let span = s.get_span(); + f.fold_assembly_statement_cases(s) + .map(|s| s.into_iter().map(|s| s.span(span)).collect()) +} + +pub fn fold_assembly_statement_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: ZirAssemblyStatement<'ast, T>, +) -> Result>, F::Error> { + match s { + ZirAssemblyStatement::Assignment(s) => f.fold_assembly_assignment(s), + ZirAssemblyStatement::Constraint(s) => f.fold_assembly_constraint(s), + } } pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, s: ZirStatement<'ast, T>, ) -> Result>, F::Error> { - let res = match s { - ZirStatement::Return(expressions) => ZirStatement::Return( - expressions + let span = s.get_span(); + f.fold_statement_cases(s) + .map(|s| s.into_iter().map(|s| s.span(span)).collect()) +} + +pub fn fold_statement_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: ZirStatement<'ast, T>, +) -> Result>, F::Error> { + let span = s.get_span(); + + match s { + ZirStatement::Return(s) => f.fold_return_statement(s), + ZirStatement::Definition(s) => f.fold_definition_statement(s), + ZirStatement::IfElse(s) => f.fold_if_else_statement(s), + ZirStatement::Assertion(s) => f.fold_assertion_statement(s), + ZirStatement::MultipleDefinition(s) => f.fold_multiple_definition_statement(s), + ZirStatement::Log(s) => f.fold_log_statement(s), + ZirStatement::Assembly(s) => f.fold_assembly_block(s), + } + .map(|s| s.into_iter().map(|s| s.span(span)).collect()) +} + +pub fn fold_return_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: ReturnStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![ZirStatement::Return( + ReturnStatement::new( + s.inner .into_iter() .map(|e| f.fold_expression(e)) .collect::>()?, - ), - ZirStatement::Definition(a, e) => { - ZirStatement::Definition(f.fold_assignee(a)?, f.fold_expression(e)?) - } - ZirStatement::IfElse(condition, consequence, alternative) => ZirStatement::IfElse( - f.fold_boolean_expression(condition)?, - consequence + ) + .span(s.span), + )]) +} + +pub fn fold_definition_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: DefinitionStatement<'ast, T>, +) -> Result>, F::Error> { + let rhs = f.fold_expression(s.rhs)?; + Ok(vec![ZirStatement::Definition( + DefinitionStatement::new(f.fold_assignee(s.assignee)?, rhs).span(s.span), + )]) +} + +pub fn fold_if_else_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: IfElseStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![ZirStatement::IfElse( + IfElseStatement::new( + f.fold_boolean_expression(s.condition)?, + s.consequence .into_iter() .map(|s| f.fold_statement(s)) .collect::, _>>()? .into_iter() .flatten() .collect(), - alternative + s.alternative .into_iter() .map(|s| f.fold_statement(s)) .collect::, _>>()? .into_iter() .flatten() .collect(), - ), - ZirStatement::Assertion(e, error) => { - ZirStatement::Assertion(f.fold_boolean_expression(e)?, error) - } - ZirStatement::MultipleDefinition(variables, elist) => ZirStatement::MultipleDefinition( - variables + ) + .span(s.span), + )]) +} + +pub fn fold_assertion_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: AssertionStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![ZirStatement::Assertion( + AssertionStatement::new(f.fold_boolean_expression(s.expression)?, s.error).span(s.span), + )]) +} + +pub fn fold_log_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: LogStatement<'ast, T>, +) -> Result>, F::Error> { + let expressions = s + .expressions + .into_iter() + .map(|(t, e)| { + e.into_iter() + .map(|e| f.fold_expression(e)) + .collect::, _>>() + .map(|e| (t, e)) + }) + .collect::, _>>()?; + Ok(vec![ZirStatement::Log(LogStatement::new( + s.format_string, + expressions, + ))]) +} + +pub fn fold_assembly_block<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: AssemblyBlockStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![ZirStatement::Assembly(AssemblyBlockStatement::new( + s.inner + .into_iter() + .map(|s| f.fold_assembly_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), + ))]) +} + +pub fn fold_multiple_definition_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: MultipleDefinitionStatement<'ast, T>, +) -> Result>, F::Error> { + let expression_list = f.fold_expression_list(s.rhs)?; + Ok(vec![ZirStatement::MultipleDefinition( + MultipleDefinitionStatement::new( + s.assignees .into_iter() .map(|v| f.fold_assignee(v)) .collect::>()?, - f.fold_expression_list(elist)?, + expression_list, ), - ZirStatement::Log(l, e) => { - let e = e - .into_iter() - .map(|(t, e)| { - e.into_iter() - .map(|e| f.fold_expression(e)) - .collect::, _>>() - .map(|e| (t, e)) - }) - .collect::, _>>()?; - - ZirStatement::Log(l, e) - } - ZirStatement::Assembly(statements) => { - let statements = statements - .into_iter() - .map(|s| f.fold_assembly_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(); - ZirStatement::Assembly(statements) - } - }; - Ok(vec![res]) + )]) +} + +fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: FieldElementExpression<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_field_expression_cases(e).map(|e| e.span(span)) } -pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( +pub fn fold_field_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, e: FieldElementExpression<'ast, T>, ) -> Result, F::Error> { Ok(match e { - FieldElementExpression::Number(n) => FieldElementExpression::Number(n), + FieldElementExpression::Value(n) => FieldElementExpression::Value(n), FieldElementExpression::Identifier(id) => { match f.fold_identifier_expression(&Type::FieldElement, id)? { IdentifierOrExpression::Identifier(i) => FieldElementExpression::Identifier(i), @@ -266,60 +501,63 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( SelectOrExpression::Expression(u) => u, } } - FieldElementExpression::Add(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - FieldElementExpression::Add(box e1, box e2) - } - FieldElementExpression::Sub(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - FieldElementExpression::Sub(box e1, box e2) + FieldElementExpression::Add(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::Add(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::Mult(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - FieldElementExpression::Mult(box e1, box e2) + FieldElementExpression::Sub(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::Sub(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::Div(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - FieldElementExpression::Div(box e1, box e2) + FieldElementExpression::Mult(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::Mult(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::Pow(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - FieldElementExpression::Pow(box e1, box e2) + FieldElementExpression::Div(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::Div(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::Xor(box left, box right) => { - let left = f.fold_field_expression(left)?; - let right = f.fold_field_expression(right)?; - - FieldElementExpression::Xor(box left, box right) + FieldElementExpression::Pow(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::Pow(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::And(box left, box right) => { - let left = f.fold_field_expression(left)?; - let right = f.fold_field_expression(right)?; - - FieldElementExpression::And(box left, box right) + FieldElementExpression::And(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::And(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::Or(box left, box right) => { - let left = f.fold_field_expression(left)?; - let right = f.fold_field_expression(right)?; - - FieldElementExpression::Or(box left, box right) + FieldElementExpression::Or(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::Or(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::Xor(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::Xor(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::LeftShift(box e, box by) => { - let e = f.fold_field_expression(e)?; - let by = f.fold_uint_expression(by)?; - - FieldElementExpression::LeftShift(box e, box by) + FieldElementExpression::LeftShift(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::LeftShift(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::RightShift(box e, box by) => { - let e = f.fold_field_expression(e)?; - let by = f.fold_uint_expression(by)?; - - FieldElementExpression::RightShift(box e, box by) + FieldElementExpression::RightShift(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::RightShift(e), + BinaryOrExpression::Expression(e) => e, + } } FieldElementExpression::Conditional(c) => { match f.fold_conditional_expression(&Type::FieldElement, c)? { @@ -330,10 +568,20 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( }) } -pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( +fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: BooleanExpression<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_boolean_expression_cases(e).map(|e| e.span(span)) +} + +pub fn fold_boolean_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, e: BooleanExpression<'ast, T>, ) -> Result, F::Error> { + use BooleanExpression::*; + Ok(match e { BooleanExpression::Value(v) => BooleanExpression::Value(v), BooleanExpression::Identifier(id) => { @@ -346,55 +594,46 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( SelectOrExpression::Select(s) => BooleanExpression::Select(s), SelectOrExpression::Expression(u) => u, }, - BooleanExpression::FieldEq(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - BooleanExpression::FieldEq(box e1, box e2) - } - BooleanExpression::BoolEq(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1)?; - let e2 = f.fold_boolean_expression(e2)?; - BooleanExpression::BoolEq(box e1, box e2) - } - BooleanExpression::UintEq(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - BooleanExpression::UintEq(box e1, box e2) - } - BooleanExpression::FieldLt(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - BooleanExpression::FieldLt(box e1, box e2) - } - BooleanExpression::UintLt(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - BooleanExpression::UintLt(box e1, box e2) - } - BooleanExpression::FieldLe(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - BooleanExpression::FieldLe(box e1, box e2) - } - BooleanExpression::UintLe(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - BooleanExpression::UintLe(box e1, box e2) - } - BooleanExpression::Or(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1)?; - let e2 = f.fold_boolean_expression(e2)?; - BooleanExpression::Or(box e1, box e2) - } - BooleanExpression::And(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1)?; - let e2 = f.fold_boolean_expression(e2)?; - BooleanExpression::And(box e1, box e2) - } - BooleanExpression::Not(box e) => { - let e = f.fold_boolean_expression(e)?; - BooleanExpression::Not(box e) - } + FieldEq(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => FieldEq(e), + BinaryOrExpression::Expression(u) => u, + }, + BoolEq(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => BoolEq(e), + BinaryOrExpression::Expression(u) => u, + }, + UintEq(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => UintEq(e), + BinaryOrExpression::Expression(u) => u, + }, + FieldLt(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => FieldLt(e), + BinaryOrExpression::Expression(u) => u, + }, + FieldLe(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => FieldLe(e), + BinaryOrExpression::Expression(u) => u, + }, + UintLt(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => UintLt(e), + BinaryOrExpression::Expression(u) => u, + }, + UintLe(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => UintLe(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Not(e) => match f.fold_unary_expression(&Type::Boolean, e)? { + UnaryOrExpression::Unary(e) => Not(e), + UnaryOrExpression::Expression(u) => u, + }, BooleanExpression::Conditional(c) => { match f.fold_conditional_expression(&Type::Boolean, c)? { ConditionalOrExpression::Conditional(s) => BooleanExpression::Conditional(s), @@ -414,86 +653,78 @@ pub fn fold_uint_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( }) } -pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( +fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + ty: UBitwidth, + e: UExpressionInner<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_uint_expression_cases(ty, e).map(|e| e.span(span)) +} + +pub fn fold_uint_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, ty: UBitwidth, e: UExpressionInner<'ast, T>, ) -> Result, F::Error> { + use UExpressionInner::*; + Ok(match e { - UExpressionInner::Value(v) => UExpressionInner::Value(v), - UExpressionInner::Identifier(id) => match f.fold_identifier_expression(&ty, id)? { + Value(v) => UExpressionInner::Value(v), + Identifier(id) => match f.fold_identifier_expression(&ty, id)? { IdentifierOrExpression::Identifier(i) => UExpressionInner::Identifier(i), IdentifierOrExpression::Expression(e) => e, }, - UExpressionInner::Select(e) => match f.fold_select_expression(&ty, e)? { + Select(e) => match f.fold_select_expression(&ty, e)? { SelectOrExpression::Select(s) => UExpressionInner::Select(s), SelectOrExpression::Expression(u) => u, }, - UExpressionInner::Add(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - UExpressionInner::Add(box left, box right) - } - UExpressionInner::Sub(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - UExpressionInner::Sub(box left, box right) - } - UExpressionInner::Mult(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - UExpressionInner::Mult(box left, box right) - } - UExpressionInner::Div(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - UExpressionInner::Div(box left, box right) - } - UExpressionInner::Rem(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - UExpressionInner::Rem(box left, box right) - } - UExpressionInner::Xor(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - UExpressionInner::Xor(box left, box right) - } - UExpressionInner::And(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - UExpressionInner::And(box left, box right) - } - UExpressionInner::Or(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - UExpressionInner::Or(box left, box right) - } - UExpressionInner::LeftShift(box e, by) => { - let e = f.fold_uint_expression(e)?; - - UExpressionInner::LeftShift(box e, by) - } - UExpressionInner::RightShift(box e, by) => { - let e = f.fold_uint_expression(e)?; - - UExpressionInner::RightShift(box e, by) - } - UExpressionInner::Not(box e) => { - let e = f.fold_uint_expression(e)?; - - UExpressionInner::Not(box e) - } - UExpressionInner::Conditional(c) => match f.fold_conditional_expression(&ty, c)? { - ConditionalOrExpression::Conditional(s) => UExpressionInner::Conditional(s), + Add(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Add(e), + BinaryOrExpression::Expression(u) => u, + }, + Sub(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Sub(e), + BinaryOrExpression::Expression(u) => u, + }, + Mult(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Mult(e), + BinaryOrExpression::Expression(u) => u, + }, + Div(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Div(e), + BinaryOrExpression::Expression(u) => u, + }, + Rem(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Rem(e), + BinaryOrExpression::Expression(u) => u, + }, + Xor(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Xor(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + LeftShift(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => LeftShift(e), + BinaryOrExpression::Expression(u) => u, + }, + RightShift(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => RightShift(e), + BinaryOrExpression::Expression(u) => u, + }, + Not(e) => match f.fold_unary_expression(&ty, e)? { + UnaryOrExpression::Unary(e) => Not(e), + UnaryOrExpression::Expression(u) => u, + }, + Conditional(c) => match f.fold_conditional_expression(&ty, c)? { + ConditionalOrExpression::Conditional(s) => Conditional(s), ConditionalOrExpression::Expression(u) => u, }, }) @@ -527,6 +758,7 @@ pub fn fold_program<'ast, T: Field, F: ResultFolder<'ast, T>>( ) -> Result, F::Error> { Ok(ZirProgram { main: f.fold_function(p.main)?, + ..p }) } @@ -539,7 +771,7 @@ pub fn fold_identifier_expression< f: &mut F, _: &E::Ty, e: IdentifierExpression<'ast, E>, -) -> Result, F::Error> { +) -> Result, E, E::Inner>, F::Error> { Ok(IdentifierOrExpression::Identifier( IdentifierExpression::new(f.fold_name(e.id)?), )) @@ -548,7 +780,7 @@ pub fn fold_identifier_expression< pub fn fold_conditional_expression< 'ast, T: Field, - E: Expr<'ast, T> + ResultFold<'ast, T> + Conditional<'ast, T>, + E: Expr<'ast, T> + ResultFold + Conditional<'ast, T>, F: ResultFolder<'ast, T>, >( f: &mut F, @@ -567,7 +799,7 @@ pub fn fold_conditional_expression< pub fn fold_select_expression< 'ast, T: Field, - E: Expr<'ast, T> + ResultFold<'ast, T> + Select<'ast, T>, + E: Expr<'ast, T> + ResultFold + Select<'ast, T>, F: ResultFolder<'ast, T>, >( f: &mut F, @@ -582,3 +814,39 @@ pub fn fold_select_expression< e.index.fold(f)?, ))) } + +#[allow(clippy::type_complexity)] +pub fn fold_binary_expression< + 'ast, + T: Field, + L: Expr<'ast, T> + PartialEq + ResultFold + From>, + R: Expr<'ast, T> + PartialEq + ResultFold + From>, + E: Expr<'ast, T> + PartialEq + ResultFold + From>, + F: ResultFolder<'ast, T>, + Op, +>( + f: &mut F, + _: &E::Ty, + e: BinaryExpression, +) -> Result, F::Error> { + Ok(BinaryOrExpression::Binary( + BinaryExpression::new(e.left.fold(f)?, e.right.fold(f)?).span(e.span), + )) +} + +pub fn fold_unary_expression< + 'ast, + T: Field, + In: Expr<'ast, T> + PartialEq + ResultFold + From>, + E: Expr<'ast, T> + PartialEq + ResultFold + From>, + F: ResultFolder<'ast, T>, + Op, +>( + f: &mut F, + _: &E::Ty, + e: UnaryExpression, +) -> Result, F::Error> { + Ok(UnaryOrExpression::Unary( + UnaryExpression::new(e.inner.fold(f)?).span(e.span), + )) +} diff --git a/zokrates_ast/src/zir/uint.rs b/zokrates_ast/src/zir/uint.rs index 9ae30d40e..dfb0b3ec9 100644 --- a/zokrates_ast/src/zir/uint.rs +++ b/zokrates_ast/src/zir/uint.rs @@ -1,5 +1,7 @@ -use crate::zir::types::UBitwidth; +use crate::common::expressions::{UValueExpression, UnaryExpression, ValueExpression}; use crate::zir::IdentifierExpression; +use crate::{common::expressions::BinaryExpression, common::operators::*, zir::types::UBitwidth}; +use derivative::Derivative; use serde::{Deserialize, Serialize}; use zokrates_field::Field; @@ -10,14 +12,14 @@ impl<'ast, T: Field> UExpression<'ast, T> { pub fn add(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Add(box self, box other).annotate(bitwidth) + UExpressionInner::Add(BinaryExpression::new(self, other)).annotate(bitwidth) } #[allow(clippy::should_implement_trait)] pub fn sub(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Sub(box self, box other).annotate(bitwidth) + UExpressionInner::Sub(BinaryExpression::new(self, other)).annotate(bitwidth) } pub fn select(values: Vec, index: Self) -> UExpression<'ast, T> { @@ -28,67 +30,67 @@ impl<'ast, T: Field> UExpression<'ast, T> { pub fn mult(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Mult(box self, box other).annotate(bitwidth) + UExpressionInner::Mult(BinaryExpression::new(self, other)).annotate(bitwidth) } #[allow(clippy::should_implement_trait)] pub fn div(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Div(box self, box other).annotate(bitwidth) + UExpressionInner::Div(BinaryExpression::new(self, other)).annotate(bitwidth) } #[allow(clippy::should_implement_trait)] pub fn rem(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Rem(box self, box other).annotate(bitwidth) + UExpressionInner::Rem(BinaryExpression::new(self, other)).annotate(bitwidth) } pub fn xor(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Xor(box self, box other).annotate(bitwidth) + UExpressionInner::Xor(BinaryExpression::new(self, other)).annotate(bitwidth) } #[allow(clippy::should_implement_trait)] pub fn not(self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; - UExpressionInner::Not(box self).annotate(bitwidth) + UExpressionInner::Not(UnaryExpression::new(self)).annotate(bitwidth) } pub fn or(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Or(box self, box other).annotate(bitwidth) + UExpressionInner::Or(BinaryExpression::new(self, other)).annotate(bitwidth) } pub fn and(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::And(box self, box other).annotate(bitwidth) + UExpressionInner::And(BinaryExpression::new(self, other)).annotate(bitwidth) } - pub fn left_shift(self, by: u32) -> UExpression<'ast, T> { + pub fn left_shift(self, by: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; - UExpressionInner::LeftShift(box self, by).annotate(bitwidth) + UExpressionInner::LeftShift(BinaryExpression::new(self, by)).annotate(bitwidth) } - pub fn right_shift(self, by: u32) -> UExpression<'ast, T> { + pub fn right_shift(self, by: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; - UExpressionInner::RightShift(box self, by).annotate(bitwidth) + UExpressionInner::RightShift(BinaryExpression::new(self, by)).annotate(bitwidth) } } -impl<'ast, T: Field> From for UExpressionInner<'ast, T> { +impl<'ast, T> From for UExpressionInner<'ast, T> { fn from(e: u128) -> Self { - UExpressionInner::Value(e) + UExpressionInner::Value(ValueExpression::new(e)) } } impl<'ast, T> From for UExpression<'ast, T> { fn from(u: u32) -> Self { - UExpressionInner::Value(u as u128).annotate(UBitwidth::B32) + UExpressionInner::from(u as u128).annotate(UBitwidth::B32) } } @@ -163,7 +165,9 @@ impl UMetadata { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct UExpression<'ast, T> { pub bitwidth: UBitwidth, pub metadata: Option>, @@ -171,23 +175,29 @@ pub struct UExpression<'ast, T> { pub inner: UExpressionInner<'ast, T>, } -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum UExpressionInner<'ast, T> { - Value(u128), + Value(UValueExpression), #[serde(borrow)] Identifier(IdentifierExpression<'ast, UExpression<'ast, T>>), Select(SelectExpression<'ast, T, UExpression<'ast, T>>), - Add(Box>, Box>), - Sub(Box>, Box>), - Mult(Box>, Box>), - Div(Box>, Box>), - Rem(Box>, Box>), - Xor(Box>, Box>), - And(Box>, Box>), - Or(Box>, Box>), - LeftShift(Box>, u32), - RightShift(Box>, u32), - Not(Box>), + Add(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Sub(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Mult(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Div(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Rem(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Xor(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + And(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Or(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + LeftShift( + BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>, + ), + RightShift( + BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>, + ), + Not(UnaryExpression, UExpression<'ast, T>>), Conditional(ConditionalExpression<'ast, T, UExpression<'ast, T>>), } diff --git a/zokrates_ast/src/zir/variable.rs b/zokrates_ast/src/zir/variable.rs index 14d329727..62ee8bb12 100644 --- a/zokrates_ast/src/zir/variable.rs +++ b/zokrates_ast/src/zir/variable.rs @@ -1,14 +1,7 @@ use crate::zir::types::{Type, UBitwidth}; use crate::zir::Identifier; -use serde::{Deserialize, Serialize}; -use std::fmt; -#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] -pub struct Variable<'ast> { - #[serde(borrow)] - pub id: Identifier<'ast>, - pub _type: Type, -} +pub type Variable<'ast> = crate::common::Variable, Type>; impl<'ast> Variable<'ast> { pub fn field_element>>(id: I) -> Variable<'ast> { @@ -23,26 +16,11 @@ impl<'ast> Variable<'ast> { Self::with_id_and_type(id, Type::uint(bitwidth)) } - pub fn with_id_and_type>>(id: I, _type: Type) -> Variable<'ast> { - Variable { - id: id.into(), - _type, - } + pub fn with_id_and_type>>(id: I, ty: Type) -> Variable<'ast> { + Variable::new(id.into(), ty) } pub fn get_type(&self) -> Type { - self._type.clone() - } -} - -impl<'ast> fmt::Display for Variable<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} {}", self._type, self.id,) - } -} - -impl<'ast> fmt::Debug for Variable<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Variable(type: {:?}, id: {:?})", self._type, self.id,) + self.ty.clone() } } diff --git a/zokrates_bellman/src/groth16.rs b/zokrates_bellman/src/groth16.rs index 619d95e70..2efabd991 100644 --- a/zokrates_bellman/src/groth16.rs +++ b/zokrates_bellman/src/groth16.rs @@ -210,15 +210,20 @@ mod tests { use zokrates_interpreter::Interpreter; use super::*; - use zokrates_ast::common::{Parameter, Variable}; - use zokrates_ast::ir::{Prog, Statement}; + use zokrates_ast::common::flat::Parameter; + use zokrates_ast::ir::{Prog, Statement, Variable}; #[test] fn verify() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, - statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + statements: vec![Statement::constraint( + Variable::new(0), + Variable::public(0), + None, + )], solvers: vec![], }; diff --git a/zokrates_bellman/src/lib.rs b/zokrates_bellman/src/lib.rs index ce1ce717b..0f8a7d511 100644 --- a/zokrates_bellman/src/lib.rs +++ b/zokrates_bellman/src/lib.rs @@ -10,7 +10,7 @@ use bellman::{ Circuit, ConstraintSystem, LinearCombination, SynthesisError, Variable as BellmanVariable, }; use std::collections::BTreeMap; -use zokrates_ast::common::Variable; +use zokrates_ast::common::flat::Variable; use zokrates_ast::ir::{LinComb, ProgIterator, Statement, Witness}; use zokrates_field::BellmanFieldExtensions; use zokrates_field::Field; @@ -50,7 +50,8 @@ fn bellman_combination, witness: &mut Witness, ) -> LinearCombination { - l.0.into_iter() + l.value + .into_iter() .map(|(k, v)| { ( v.into_bellman(), @@ -126,10 +127,10 @@ impl<'a, T: BellmanFieldExtensions + Field, I: IntoIterator = Prog { + module_map: Default::default(), arguments: vec![Parameter::private(Variable::new(0))], return_count: 1, - statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + statements: vec![Statement::constraint( + Variable::new(0), + Variable::public(0), + None, + )], solvers: vec![], }; @@ -285,9 +291,14 @@ mod tests { #[test] fn public_identity() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, - statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + statements: vec![Statement::constraint( + Variable::new(0), + Variable::public(0), + None, + )], solvers: vec![], }; @@ -307,9 +318,14 @@ mod tests { #[test] fn no_arguments() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![], return_count: 1, - statements: vec![Statement::constraint(Variable::one(), Variable::public(0))], + statements: vec![Statement::constraint( + Variable::one(), + Variable::public(0), + None, + )], solvers: vec![], }; @@ -328,6 +344,7 @@ mod tests { // public variables must be ordered from 0 // private variables can be unordered let program: Prog = Prog { + module_map: Default::default(), arguments: vec![ Parameter::private(Variable::new(42)), Parameter::public(Variable::new(51)), @@ -337,10 +354,12 @@ mod tests { Statement::constraint( LinComb::from(Variable::new(42)) + LinComb::from(Variable::new(51)), Variable::public(0), + None, ), Statement::constraint( LinComb::from(Variable::one()) + LinComb::from(Variable::new(42)), Variable::public(1), + None, ), ], solvers: vec![], @@ -361,11 +380,13 @@ mod tests { #[test] fn one() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::public(Variable::new(42))], return_count: 1, statements: vec![Statement::constraint( LinComb::from(Variable::new(42)) + LinComb::one(), Variable::public(0), + None, )], solvers: vec![], }; @@ -386,6 +407,7 @@ mod tests { #[test] fn with_directives() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![ Parameter::private(Variable::new(42)), Parameter::public(Variable::new(51)), @@ -394,6 +416,7 @@ mod tests { statements: vec![Statement::constraint( LinComb::from(Variable::new(42)) + LinComb::from(Variable::new(51)), Variable::public(0), + None, )], solvers: vec![], }; diff --git a/zokrates_circom/src/lib.rs b/zokrates_circom/src/lib.rs index e87c0af38..9960fc19c 100644 --- a/zokrates_circom/src/lib.rs +++ b/zokrates_circom/src/lib.rs @@ -24,23 +24,24 @@ mod tests { #[test] fn setup_and_prove() { let prog: Prog = Prog { + module_map: Default::default(), arguments: vec![ Parameter::private(Variable::new(0)), Parameter::public(Variable::new(1)), ], return_count: 1, statements: vec![ - Statement::Constraint( - QuadComb::from_linear_combinations( + Statement::constraint( + QuadComb::new( LinComb::from(Variable::new(0)), LinComb::from(Variable::new(0)), ), LinComb::from(Variable::new(0)), None, ), - Statement::Constraint( - (LinComb::from(Variable::new(0)) + LinComb::from(Variable::new(1))).into(), - Variable::public(0).into(), + Statement::constraint( + LinComb::from(Variable::new(0)) + LinComb::from(Variable::new(1)), + Variable::public(0), None, ), ], diff --git a/zokrates_circom/src/r1cs.rs b/zokrates_circom/src/r1cs.rs index b2fc8db26..00c0d8dea 100644 --- a/zokrates_circom/src/r1cs.rs +++ b/zokrates_circom/src/r1cs.rs @@ -69,19 +69,19 @@ pub fn r1cs_program(prog: Prog) -> (Vec, usize, Vec Some((quad, lin)), + for s in prog.statements.iter().filter_map(|s| match s { + Statement::Constraint(s) => Some(s), Statement::Directive(..) => None, Statement::Block(..) => unreachable!(), Statement::Log(..) => None, }) { - for (k, _) in &quad.left.0 { + for (k, _) in &s.quad.left.value { ordered_variables_set.insert(k); } - for (k, _) in &quad.right.0 { + for (k, _) in &s.quad.right.value { ordered_variables_set.insert(k); } - for (k, _) in &lin.0 { + for (k, _) in &s.lin.value { ordered_variables_set.insert(k); } } @@ -95,23 +95,23 @@ pub fn r1cs_program(prog: Prog) -> (Vec, usize, Vec Some((quad, lin)), Statement::Block(..) => unreachable!(), Statement::Directive(..) => None, Statement::Log(..) => None, + Statement::Constraint(s) => Some((s.quad, s.lin)), }) { constraints.push(( quad.left - .0 + .value .into_iter() .map(|(k, v)| (*variables.get(&k).unwrap(), v)) .collect(), quad.right - .0 + .value .into_iter() .map(|(k, v)| (*variables.get(&k).unwrap(), v)) .collect(), - lin.0 + lin.value .into_iter() .map(|(k, v)| (*variables.get(&k).unwrap(), v)) .collect(), @@ -289,11 +289,12 @@ mod tests { #[test] fn return_one() { let prog: Prog = Prog { + module_map: Default::default(), arguments: vec![], return_count: 1, - statements: vec![Statement::Constraint( - LinComb::one().into(), - Variable::public(0).into(), + statements: vec![Statement::constraint( + LinComb::one(), + Variable::public(0), None, )], solvers: vec![], @@ -346,23 +347,24 @@ mod tests { #[test] fn with_inputs() { let prog: Prog = Prog { + module_map: Default::default(), arguments: vec![ Parameter::private(Variable::new(0)), Parameter::public(Variable::new(1)), ], return_count: 1, statements: vec![ - Statement::Constraint( - QuadComb::from_linear_combinations( + Statement::constraint( + QuadComb::new( LinComb::from(Variable::new(0)), LinComb::from(Variable::new(0)), ), LinComb::from(Variable::new(0)), None, ), - Statement::Constraint( - (LinComb::from(Variable::new(0)) + LinComb::from(Variable::new(1))).into(), - Variable::public(0).into(), + Statement::constraint( + LinComb::from(Variable::new(0)) + LinComb::from(Variable::new(1)), + Variable::public(0), None, ), ], diff --git a/zokrates_cli/Cargo.toml b/zokrates_cli/Cargo.toml index ccf98794d..90066f8fc 100644 --- a/zokrates_cli/Cargo.toml +++ b/zokrates_cli/Cargo.toml @@ -21,6 +21,7 @@ zokrates_field = { version = "0.5", path = "../zokrates_field", features = ["mul zokrates_abi = { version = "0.1", path = "../zokrates_abi" } zokrates_core = { version = "0.7", path = "../zokrates_core", default-features = false } zokrates_ast = { version = "0.1", path = "../zokrates_ast", default-features = false } +zokrates_profiler = { version = "0.1", path = "../zokrates_profiler", default-features = false } zokrates_interpreter = { version = "0.1", path = "../zokrates_interpreter", default-features = false } zokrates_circom = { version = "0.1", path = "../zokrates_circom", default-features = false } zokrates_embed = { version = "0.1", path = "../zokrates_embed", features = ["multicore"] } diff --git a/zokrates_cli/src/bin.rs b/zokrates_cli/src/bin.rs index a5f8a31f8..a0aa12882 100644 --- a/zokrates_cli/src/bin.rs +++ b/zokrates_cli/src/bin.rs @@ -57,7 +57,9 @@ fn cli() -> Result<(), String> { generate_smtlib2::subcommand(), print_proof::subcommand(), #[cfg(any(feature = "bellman", feature = "ark"))] - verify::subcommand()]) + verify::subcommand(), + profile::subcommand() + ]) .get_matches(); match matches.subcommand() { @@ -78,6 +80,7 @@ fn cli() -> Result<(), String> { ("print-proof", Some(sub_matches)) => print_proof::exec(sub_matches), #[cfg(any(feature = "bellman", feature = "ark"))] ("verify", Some(sub_matches)) => verify::exec(sub_matches), + ("profile", Some(sub_matches)) => profile::exec(sub_matches), _ => unreachable!(), } } diff --git a/zokrates_cli/src/ops/compile.rs b/zokrates_cli/src/ops/compile.rs index 74d26b6f2..35d47b3f8 100644 --- a/zokrates_cli/src/ops/compile.rs +++ b/zokrates_cli/src/ops/compile.rs @@ -130,8 +130,8 @@ fn cli_compile(sub_matches: &ArgMatches) -> Result<(), String> { let arena = Arena::new(); - let artifacts = - compile::(source, path, Some(&resolver), config, &arena).map_err(|e| { + let artifacts = compile::(source, path.clone(), Some(&resolver), config, &arena) + .map_err(|e| { format!( "Compilation failed:\n\n{}", e.0.iter() @@ -154,7 +154,15 @@ fn cli_compile(sub_matches: &ArgMatches) -> Result<(), String> { let mut bin_writer = BufWriter::new(bin_output_file); let mut r1cs_writer = BufWriter::new(r1cs_output_file); - let program_flattened = program_flattened.collect(); + let mut program_flattened = program_flattened.collect(); + + // hide user path + program_flattened.module_map = program_flattened + .module_map + .remap_prefix(path.parent().unwrap(), Path::new("")); + program_flattened.module_map = program_flattened + .module_map + .remap_prefix(Path::new(stdlib_path), Path::new("STDLIB")); write_r1cs(&mut r1cs_writer, program_flattened.clone()).unwrap(); diff --git a/zokrates_cli/src/ops/mod.rs b/zokrates_cli/src/ops/mod.rs index e82dd506a..cb21da334 100644 --- a/zokrates_cli/src/ops/mod.rs +++ b/zokrates_cli/src/ops/mod.rs @@ -9,6 +9,7 @@ pub mod inspect; #[cfg(feature = "bellman")] pub mod mpc; pub mod print_proof; +pub mod profile; #[cfg(any(feature = "bellman", feature = "ark"))] pub mod setup; #[cfg(feature = "ark")] diff --git a/zokrates_cli/src/ops/profile.rs b/zokrates_cli/src/ops/profile.rs new file mode 100644 index 000000000..79e0b1f71 --- /dev/null +++ b/zokrates_cli/src/ops/profile.rs @@ -0,0 +1,52 @@ +use crate::cli_constants::FLATTENED_CODE_DEFAULT_PATH; +use clap::{App, Arg, ArgMatches, SubCommand}; +use std::fs::File; +use std::io::BufReader; +use std::path::Path; +use zokrates_ast::ir::{self, ProgEnum}; +use zokrates_field::Field; +use zokrates_profiler::profile; + +pub fn subcommand() -> App<'static, 'static> { + SubCommand::with_name("profile") + .about("Profiles a compiled program, indicating which parts of the source yield the most constraints") + .arg( + Arg::with_name("input") + .short("i") + .long("input") + .help("Path of the binary") + .value_name("FILE") + .takes_value(true) + .required(false) + .default_value(FLATTENED_CODE_DEFAULT_PATH), + ) +} + +pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { + // read compiled program + let path = Path::new(sub_matches.value_of("input").unwrap()); + let file = + File::open(&path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; + + let mut reader = BufReader::new(file); + + match ProgEnum::deserialize(&mut reader)? { + ProgEnum::Bn128Program(p) => cli_profile(p, sub_matches), + ProgEnum::Bls12_377Program(p) => cli_profile(p, sub_matches), + ProgEnum::Bls12_381Program(p) => cli_profile(p, sub_matches), + ProgEnum::Bw6_761Program(p) => cli_profile(p, sub_matches), + } +} + +fn cli_profile<'ast, T: Field, I: Iterator>>( + ir_prog: ir::ProgIterator<'ast, T, I>, + _: &ArgMatches, +) -> Result<(), String> { + let module_map = ir_prog.module_map.clone(); + + let heat_map = profile(ir_prog); + + println!("{}", heat_map.display(&module_map)); + + Ok(()) +} diff --git a/zokrates_cli/tests/integration.rs b/zokrates_cli/tests/integration.rs index cd988b104..628dfaf18 100644 --- a/zokrates_cli/tests/integration.rs +++ b/zokrates_cli/tests/integration.rs @@ -29,7 +29,7 @@ mod integration { mod helpers { use super::*; - use zokrates_ast::common::Variable; + use zokrates_ast::common::flat::Variable; use zokrates_field::Field; pub fn parse_variable(s: &str) -> Result { diff --git a/zokrates_codegen/src/lib.rs b/zokrates_codegen/src/lib.rs index a09eb61e9..8484d8b62 100644 --- a/zokrates_codegen/src/lib.rs +++ b/zokrates_codegen/src/lib.rs @@ -10,9 +10,15 @@ mod utils; use self::utils::flat_expression_from_bits; -use zokrates_ast::zir::{ - canonicalizer::ZirCanonicalizer, ConditionalExpression, Folder, SelectExpression, ShouldReduce, - UMetadata, ZirAssemblyStatement, ZirExpressionList, +use zokrates_ast::{ + common::{ + expressions::{BinaryExpression, UnaryExpression, ValueExpression}, + Span, + }, + zir::{ + canonicalizer::ZirCanonicalizer, ConditionalExpression, Expr, Folder, SelectExpression, + ShouldReduce, UMetadata, ZirAssemblyStatement, ZirExpressionList, ZirProgram, + }, }; use zokrates_interpreter::Interpreter; @@ -20,32 +26,73 @@ use std::collections::{ hash_map::{Entry, HashMap}, VecDeque, }; +use std::ops::*; use zokrates_ast::common::embed::*; use zokrates_ast::common::FlatEmbed; -use zokrates_ast::common::{RuntimeError, Variable}; +use zokrates_ast::common::WithSpan; +use zokrates_ast::common::{flat::Variable, RuntimeError}; use zokrates_ast::flat::*; use zokrates_ast::ir::Solver; use zokrates_ast::zir::types::{Type, UBitwidth}; use zokrates_ast::zir::{ BooleanExpression, Conditional, FieldElementExpression, Identifier, Parameter as ZirParameter, - UExpression, UExpressionInner, Variable as ZirVariable, ZirExpression, ZirFunction, - ZirStatement, + UExpression, UExpressionInner, Variable as ZirVariable, ZirExpression, ZirStatement, }; use zokrates_common::CompileConfig; use zokrates_field::Field; -type FlatStatements<'ast, T> = VecDeque>; +/// A container for statements produced during code generation +/// New statements are registered with the span set in the container +#[derive(Default)] +pub struct FlatStatements<'ast, T> { + span: Option, + buffer: VecDeque>, +} + +impl<'ast, T> FlatStatements<'ast, T> { + fn push_back(&mut self, s: FlatStatement<'ast, T>) { + self.buffer.push_back(s.span(self.span)) + } + + fn pop_front(&mut self) -> Option> { + self.buffer.pop_front() + } + + fn extend(&mut self, i: impl IntoIterator>) { + self.buffer.extend(i.into_iter().map(|s| s.span(self.span))) + } + + fn set_span(&mut self, span: Option) { + self.span = span; + } + + fn is_empty(&self) -> bool { + self.buffer.is_empty() + } +} + +impl<'ast, T> IntoIterator for FlatStatements<'ast, T> { + type Item = FlatStatement<'ast, T>; + + type IntoIter = std::collections::vec_deque::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.buffer.into_iter() + } +} /// Flattens a function /// /// # Arguments /// * `funct` - `ZirFunction` that will be flattened -pub fn from_function_and_config( - funct: ZirFunction, +pub fn from_program_and_config( + prog: ZirProgram, config: CompileConfig, ) -> FlattenerIterator { + let funct = prog.main; + let mut flattener = Flattener::new(config); - let mut statements_flattened = FlatStatements::new(); + let mut statements_flattened = FlatStatements::default(); // push parameters let arguments_flattened = funct .arguments @@ -61,6 +108,7 @@ pub fn from_function_and_config( flattener, }, return_count: funct.signature.outputs.len(), + module_map: prog.module_map, } } @@ -168,12 +216,28 @@ impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> { } #[derive(Clone, Debug)] -struct FlatUExpression { +struct FlatUExpression { field: Option>, bits: Option>>, } -impl FlatUExpression { +impl WithSpan for FlatUExpression { + fn span(self, span: Option) -> Self { + Self { + field: self.field.map(|e| e.span(span)), + bits: self + .bits + .map(|bits| bits.into_iter().map(|b| b.span(span)).collect()), + } + } + + fn get_span(&self) -> Option { + let field_span = self.field.as_ref().map(|f| f.get_span()); + field_span.unwrap_or_else(|| unimplemented!()) + } +} + +impl FlatUExpression { fn default() -> Self { FlatUExpression { field: None, @@ -182,7 +246,7 @@ impl FlatUExpression { } } -impl FlatUExpression { +impl FlatUExpression { fn field>>>(mut self, e: U) -> Self { self.field = e.into(); self @@ -200,7 +264,9 @@ impl FlatUExpression { fn with_bits>>>>(e: U) -> Self { Self::default().bits(e) } +} +impl FlatUExpression { fn get_field_unchecked(self) -> FlatExpression { match self.field { Some(f) => f, @@ -230,10 +296,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened: &mut FlatStatements<'ast, T>, ) -> Variable { match e { - FlatExpression::Identifier(id) => id, + FlatExpression::Identifier(id) => id.id, e => { let res = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(res, e)); + statements_flattened.push_back(FlatStatement::definition(res, e)); res } } @@ -290,9 +356,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { .iter() .map(|e| { let e_id = self.define(e.clone(), statements_flattened); - FlatStatement::Condition( + FlatStatement::condition( e_id.into(), - FlatExpression::Mult(box e_id.into(), box e_id.into()), + FlatExpression::mul(e_id.into(), e_id.into()), RuntimeError::Bitness, ) }) @@ -312,28 +378,25 @@ impl<'ast, T: Field> Flattener<'ast, T> { } // init size_unknown = true - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( size_unknown[0], - FlatExpression::Number(T::from(1)), + FlatExpression::value(T::from(1)), )); let mut res = vec![]; for (i, b) in b.iter().enumerate() { if *b { - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( is_not_smaller_run[i], a[i].clone(), )); // don't need to update size_unknown in the last round if i < len - 1 { - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( size_unknown[i + 1], - FlatExpression::Mult( - box size_unknown[i].into(), - box is_not_smaller_run[i].into(), - ), + FlatExpression::mul(size_unknown[i].into(), is_not_smaller_run[i].into()), )); } } else { @@ -343,24 +406,20 @@ impl<'ast, T: Field> Flattener<'ast, T> { // sizeUnknown is not changing in this case // We sill have to assign the old value to the variable of the current run // This trivial definition will later be removed by the optimiser - FlatStatement::Definition(size_unknown[i + 1], size_unknown[i].into()), + FlatStatement::definition(size_unknown[i + 1], size_unknown[i].into()), ); } - let or_left = FlatExpression::Sub( - box FlatExpression::Number(T::from(1)), - box size_unknown[i].into(), - ); + let or_left = + FlatExpression::sub(FlatExpression::value(T::from(1)), size_unknown[i].into()); let or_right: FlatExpression<_> = - FlatExpression::Sub(box FlatExpression::Number(T::from(1)), box a[i].clone()); + FlatExpression::sub(FlatExpression::value(T::from(1)), a[i].clone()); let and_name = self.use_sym(); - let and = FlatExpression::Mult(box or_left.clone(), box or_right.clone()); - statements_flattened.push_back(FlatStatement::Definition(and_name, and)); - let or = FlatExpression::Sub( - box FlatExpression::Add(box or_left, box or_right), - box and_name.into(), - ); + let and = FlatExpression::mul(or_left.clone(), or_right.clone()); + statements_flattened.push_back(FlatStatement::definition(and_name, and)); + let or = + FlatExpression::sub(FlatExpression::add(or_left, or_right), and_name.into()); res.push(or); } @@ -395,30 +454,30 @@ impl<'ast, T: Field> Flattener<'ast, T> { // Y == X * M // 0 == (1-Y) * X - let x = FlatExpression::Sub(box left.into(), box right.into()); + let x = FlatExpression::sub(left.into(), right.into()); let name_y = self.use_sym(); let name_m = self.use_sym(); - statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new( + statements_flattened.push_back(FlatStatement::directive( vec![name_y, name_m], Solver::ConditionEq, vec![x.clone()], - ))); - statements_flattened.push_back(FlatStatement::Condition( - FlatExpression::Identifier(name_y), - FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)), + )); + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::identifier(name_y), + FlatExpression::mul(x.clone(), FlatExpression::identifier(name_m)), RuntimeError::Equal, )); - let res = FlatExpression::Sub( - box FlatExpression::Number(T::one()), - box FlatExpression::Identifier(name_y), + let res = FlatExpression::sub( + FlatExpression::value(T::one()), + FlatExpression::identifier(name_y), ); - statements_flattened.push_back(FlatStatement::Condition( - FlatExpression::Number(T::zero()), - FlatExpression::Mult(box res.clone(), box x), + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::value(T::zero()), + FlatExpression::mul(res.clone(), x), RuntimeError::Equal, )); @@ -446,11 +505,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { let conditions_sum = conditions .into_iter() .fold(FlatExpression::from(T::zero()), |acc, e| { - FlatExpression::Add(box acc, box e) + FlatExpression::add(acc, e) }); - statements_flattened.push_back(FlatStatement::Condition( - FlatExpression::Number(T::from(0)), - FlatExpression::Sub(box conditions_sum, box T::from(conditions_count).into()), + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::value(T::from(0)), + FlatExpression::sub(conditions_sum, T::from(conditions_count).into()), error, )); } @@ -507,7 +566,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ) { // `e < 0` will always result in false value, so we constrain `0 == 1` if c == T::zero() { - statements_flattened.push_back(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::condition( T::zero().into(), T::one().into(), error, @@ -525,49 +584,58 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements .into_iter() .flat_map(|s| match s { - FlatStatement::Condition(left, right, message) => { - let mut output = VecDeque::new(); + FlatStatement::Condition(s) => { + let span = s.get_span(); + + let mut output = FlatStatements::default(); + + output.set_span(span); // we transform (a == b) into (c => (a == b)) which is (!c || (a == b)) // let's introduce new variables to make sure everything is linear - let name_left = self.define(left, &mut output); - let name_right = self.define(right, &mut output); + let name_lin = self.define(s.lin, &mut output); + let name_quad = self.define(s.quad, &mut output); // let's introduce an expression which is 1 iff `a == b` - let y = FlatExpression::Add( - box FlatExpression::Sub(box name_left.into(), box name_right.into()), - box T::one().into(), - ); - // let's introduce !c - let x = FlatExpression::Sub(box T::one().into(), box condition.clone()); - + let y = FlatExpression::add( + FlatExpression::sub(name_lin.into(), name_quad.into()), + T::one().into(), + ); // let's introduce !c + let x = FlatExpression::sub(T::one().into(), condition.clone()); assert!(x.is_linear() && y.is_linear()); let name_x_or_y = self.use_sym(); - output.push_back(FlatStatement::Directive(FlatDirective { - solver: Solver::Or, - outputs: vec![name_x_or_y], - inputs: vec![x.clone(), y.clone()], - })); - output.push_back(FlatStatement::Condition( - FlatExpression::Add( - box x.clone(), - box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()), + output.push_back(FlatStatement::directive( + vec![name_x_or_y], + Solver::Or, + vec![x.clone(), y.clone()], + )); + output.push_back(FlatStatement::condition( + FlatExpression::add( + x.clone(), + FlatExpression::sub(y.clone(), name_x_or_y.into()), ), - FlatExpression::Mult(box x.clone(), box y.clone()), + FlatExpression::mul(x, y), RuntimeError::BranchIsolation, )); - output.push_back(FlatStatement::Condition( + output.push_back(FlatStatement::condition( name_x_or_y.into(), T::one().into(), - message, + s.error, )); output } - s => VecDeque::from([s]), + s => { + let mut v = FlatStatements::default(); + v.push_back(s); + v + } + }) + .fold(FlatStatements::default(), |mut acc, s| { + acc.push_back(s); + acc }) - .collect() } /// Flatten an if/else expression @@ -593,14 +661,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { self.flatten_boolean_expression(statements_flattened, condition.clone()); let condition_id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(condition_id, condition_flat)); + statements_flattened.push_back(FlatStatement::definition(condition_id, condition_flat)); let (consequence, alternative) = if self.config.isolate_branches { - let mut consequence_statements = VecDeque::new(); + let mut consequence_statements = FlatStatements::default(); let consequence = consequence.flatten(self, &mut consequence_statements); - let mut alternative_statements = VecDeque::new(); + let mut alternative_statements = FlatStatements::default(); let alternative = alternative.flatten(self, &mut alternative_statements); @@ -608,10 +676,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { self.make_conditional(consequence_statements, condition_id.into()); let alternative_statements = self.make_conditional( alternative_statements, - FlatExpression::Sub( - box FlatExpression::Number(T::one()), - box condition_id.into(), - ), + FlatExpression::sub(FlatExpression::value(T::one()), condition_id.into()), ); statements_flattened.extend(consequence_statements); @@ -629,43 +694,37 @@ impl<'ast, T: Field> Flattener<'ast, T> { let alternative = alternative.flat(); let consequence_id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(consequence_id, consequence)); + statements_flattened.push_back(FlatStatement::definition(consequence_id, consequence)); let alternative_id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(alternative_id, alternative)); + statements_flattened.push_back(FlatStatement::definition(alternative_id, alternative)); let term0_id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( term0_id, - FlatExpression::Mult( - box condition_id.into(), - box FlatExpression::from(consequence_id), - ), + FlatExpression::mul(condition_id.into(), FlatExpression::from(consequence_id)), )); let term1_id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( term1_id, - FlatExpression::Mult( - box FlatExpression::Sub( - box FlatExpression::Number(T::one()), - box condition_id.into(), - ), - box FlatExpression::from(alternative_id), + FlatExpression::mul( + FlatExpression::sub(FlatExpression::value(T::one()), condition_id.into()), + FlatExpression::from(alternative_id), ), )); let res = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( res, - FlatExpression::Add( - box FlatExpression::from(term0_id), - box FlatExpression::from(term1_id), + FlatExpression::add( + FlatExpression::from(term0_id), + FlatExpression::from(term1_id), ), )); FlatUExpression { - field: Some(FlatExpression::Identifier(res)), + field: Some(FlatExpression::identifier(res)), bits: None, } } @@ -736,8 +795,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { T::from(conditions.len()).into(), conditions .into_iter() - .fold(FlatExpression::Number(T::zero()), |acc, e| { - FlatExpression::Add(box acc, box e) + .fold(FlatExpression::value(T::zero()), |acc, e| { + FlatExpression::add(acc, e) }), ) } @@ -750,13 +809,13 @@ impl<'ast, T: Field> Flattener<'ast, T> { rhs_flattened: FlatExpression, bit_width: usize, ) -> FlatExpression { - FlatExpression::Add( - box self.eq_check( + FlatExpression::add( + self.eq_check( statements_flattened, lhs_flattened.clone(), rhs_flattened.clone(), ), - box self.lt_check( + self.lt_check( statements_flattened, lhs_flattened, rhs_flattened, @@ -774,25 +833,25 @@ impl<'ast, T: Field> Flattener<'ast, T> { bit_width: usize, ) -> FlatExpression { match (lhs_flattened, rhs_flattened) { - (x, FlatExpression::Number(constant)) => { - self.constant_lt_check(statements_flattened, x, constant) + (x, FlatExpression::Value(constant)) => { + self.constant_lt_check(statements_flattened, x, constant.value) } // (c < x <= p - 1) <=> (0 <= p - 1 - x < p - 1 - c) - (FlatExpression::Number(constant), x) => self.constant_lt_check( + (FlatExpression::Value(constant), x) => self.constant_lt_check( statements_flattened, - FlatExpression::Sub(box T::max_value().into(), box x), - T::max_value() - constant, + FlatExpression::sub(T::max_value().into(), x), + T::max_value() - constant.value, ), (lhs_flattened, rhs_flattened) => { let lhs_id = self.define(lhs_flattened, statements_flattened); let rhs_id = self.define(rhs_flattened, statements_flattened); // shifted_sub := 2**safe_width + lhs - rhs - let shifted_sub = FlatExpression::Add( - box FlatExpression::Number(T::from(2).pow(bit_width)), - box FlatExpression::Sub( - box FlatExpression::Identifier(lhs_id), - box FlatExpression::Identifier(rhs_id), + let shifted_sub = FlatExpression::add( + FlatExpression::value(T::from(2).pow(bit_width)), + FlatExpression::sub( + FlatExpression::identifier(lhs_id), + FlatExpression::identifier(rhs_id), ), ); @@ -806,9 +865,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { RuntimeError::IncompleteDynamicRange, ); - FlatExpression::Sub( - box FlatExpression::Number(T::one()), - box shifted_sub_bits_be[0].clone(), + FlatExpression::sub( + FlatExpression::value(T::one()), + shifted_sub_bits_be[0].clone(), ) } } @@ -830,21 +889,27 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened: &mut FlatStatements<'ast, T>, expression: BooleanExpression<'ast, T>, ) -> FlatExpression { - match expression { + let span = expression.get_span(); + + let span_backup = statements_flattened.span; + + statements_flattened.set_span(span); + + let res = match expression { BooleanExpression::Identifier(x) => { - FlatExpression::Identifier(*self.layout.get(&x.id).unwrap()) + FlatExpression::identifier(*self.layout.get(&x.id).unwrap()) } BooleanExpression::Select(e) => self .flatten_select_expression(statements_flattened, e) .get_field_unchecked(), - BooleanExpression::FieldLt(box lhs, box rhs) => { + BooleanExpression::FieldLt(e) => { // Get the bit width to know the size of the binary decompositions for this Field let bit_width = T::get_required_bits(); let safe_width = bit_width - 2; // dynamic comparison is not complete, it only applies to field elements whose difference is strictly smaller than 2**(bitwidth - 2) - let lhs_flattened = self.flatten_field_expression(statements_flattened, lhs); - let rhs_flattened = self.flatten_field_expression(statements_flattened, rhs); + let lhs_flattened = self.flatten_field_expression(statements_flattened, *e.left); + let rhs_flattened = self.flatten_field_expression(statements_flattened, *e.right); self.lt_check( statements_flattened, lhs_flattened, @@ -852,10 +917,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { safe_width, ) } - BooleanExpression::BoolEq(box lhs, box rhs) => { + BooleanExpression::BoolEq(e) => { // lhs and rhs are booleans, they flatten to 0 or 1 - let x = self.flatten_boolean_expression(statements_flattened, lhs); - let y = self.flatten_boolean_expression(statements_flattened, rhs); + let x = self.flatten_boolean_expression(statements_flattened, *e.left); + let y = self.flatten_boolean_expression(statements_flattened, *e.right); // Wanted: Not(X - Y)**2 which is an XNOR // We know that X and Y are [0, 1] // (X - Y) can become a negative values, which is why squaring the result is needed @@ -871,27 +936,27 @@ impl<'ast, T: Field> Flattener<'ast, T> { // | 0 | 0 | 0 | 1 | // +---+---+-------+---------------+ - let x_sub_y = FlatExpression::Sub(box x, box y); + let x_sub_y = FlatExpression::sub(x, y); let name_x_mult_x = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( name_x_mult_x, - FlatExpression::Mult(box x_sub_y.clone(), box x_sub_y), + FlatExpression::mul(x_sub_y.clone(), x_sub_y), )); - FlatExpression::Sub( - box FlatExpression::Number(T::one()), - box FlatExpression::Identifier(name_x_mult_x), + FlatExpression::sub( + FlatExpression::value(T::one()), + FlatExpression::identifier(name_x_mult_x), ) } - BooleanExpression::FieldEq(box lhs, box rhs) => { - let lhs = self.flatten_field_expression(statements_flattened, lhs); + BooleanExpression::FieldEq(e) => { + let lhs = self.flatten_field_expression(statements_flattened, *e.left); - let rhs = self.flatten_field_expression(statements_flattened, rhs); + let rhs = self.flatten_field_expression(statements_flattened, *e.right); self.eq_check(statements_flattened, lhs, rhs) } - BooleanExpression::UintEq(box lhs, box rhs) => { + BooleanExpression::UintEq(e) => { // We reduce each side into range and apply the same approach as for field elements // Wanted: (Y = (X != 0) ? 1 : 0) @@ -901,39 +966,41 @@ impl<'ast, T: Field> Flattener<'ast, T> { // Y == X * M // 0 == (1-Y) * X - assert!(lhs.metadata.as_ref().unwrap().should_reduce.to_bool()); - assert!(rhs.metadata.as_ref().unwrap().should_reduce.to_bool()); + assert!(e.left.metadata.as_ref().unwrap().should_reduce.to_bool()); + assert!(e.right.metadata.as_ref().unwrap().should_reduce.to_bool()); let lhs = self - .flatten_uint_expression(statements_flattened, lhs) + .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let rhs = self - .flatten_uint_expression(statements_flattened, rhs) + .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); self.eq_check(statements_flattened, lhs, rhs) } - BooleanExpression::FieldLe(box lhs, box rhs) => { + BooleanExpression::FieldLe(e) => { let lt = self.flatten_boolean_expression( statements_flattened, - BooleanExpression::FieldLt(box lhs.clone(), box rhs.clone()), + BooleanExpression::field_lt(*e.left.clone(), *e.right.clone()).span(span), ); + let eq = self.flatten_boolean_expression( statements_flattened, - BooleanExpression::FieldEq(box lhs, box rhs), + BooleanExpression::field_eq(*e.left, *e.right).span(span), ); - FlatExpression::Add(box eq, box lt) + + FlatExpression::add(eq, lt) } - BooleanExpression::UintLt(box lhs, box rhs) => { - let bit_width = lhs.bitwidth.to_usize(); - assert!(lhs.metadata.as_ref().unwrap().should_reduce.to_bool()); - assert!(rhs.metadata.as_ref().unwrap().should_reduce.to_bool()); + BooleanExpression::UintLt(e) => { + let bit_width = e.left.bitwidth.to_usize(); + assert!(e.left.metadata.as_ref().unwrap().should_reduce.to_bool()); + assert!(e.right.metadata.as_ref().unwrap().should_reduce.to_bool()); let lhs_flattened = self - .flatten_uint_expression(statements_flattened, lhs) + .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let rhs_flattened = self - .flatten_uint_expression(statements_flattened, rhs) + .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); self.lt_check( @@ -943,55 +1010,55 @@ impl<'ast, T: Field> Flattener<'ast, T> { bit_width, ) } - BooleanExpression::UintLe(box lhs, box rhs) => { + BooleanExpression::UintLe(e) => { let lt = self.flatten_boolean_expression( statements_flattened, - BooleanExpression::UintLt(box lhs.clone(), box rhs.clone()), + BooleanExpression::uint_lt(*e.left.clone(), *e.right.clone()).span(span), ); let eq = self.flatten_boolean_expression( statements_flattened, - BooleanExpression::UintEq(box lhs, box rhs), + BooleanExpression::uint_eq(*e.left, *e.right).span(span), ); - FlatExpression::Add(box eq, box lt) + FlatExpression::add(eq, lt) } - BooleanExpression::Or(box lhs, box rhs) => { - let x = self.flatten_boolean_expression(statements_flattened, lhs); - let y = self.flatten_boolean_expression(statements_flattened, rhs); + BooleanExpression::Or(e) => { + let x = self.flatten_boolean_expression(statements_flattened, *e.left); + let y = self.flatten_boolean_expression(statements_flattened, *e.right); assert!(x.is_linear() && y.is_linear()); let name_x_or_y = self.use_sym(); - statements_flattened.push_back(FlatStatement::Directive(FlatDirective { - solver: Solver::Or, - outputs: vec![name_x_or_y], - inputs: vec![x.clone(), y.clone()], - })); - statements_flattened.push_back(FlatStatement::Condition( - FlatExpression::Add( - box x.clone(), - box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()), + statements_flattened.push_back(FlatStatement::directive( + vec![name_x_or_y], + Solver::Or, + vec![x.clone(), y.clone()], + )); + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::add( + x.clone(), + FlatExpression::sub(y.clone(), name_x_or_y.into()), ), - FlatExpression::Mult(box x, box y), + FlatExpression::mul(x, y), RuntimeError::Or, )); name_x_or_y.into() } - BooleanExpression::And(box lhs, box rhs) => { - let x = self.flatten_boolean_expression(statements_flattened, lhs); - let y = self.flatten_boolean_expression(statements_flattened, rhs); + BooleanExpression::And(e) => { + let x = self.flatten_boolean_expression(statements_flattened, *e.left); + let y = self.flatten_boolean_expression(statements_flattened, *e.right); let name_x_and_y = self.use_sym(); assert!(x.is_linear() && y.is_linear()); - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( name_x_and_y, - FlatExpression::Mult(box x, box y), + FlatExpression::mul(x, y), )); - FlatExpression::Identifier(name_x_and_y) + FlatExpression::identifier(name_x_and_y) } - BooleanExpression::Not(box exp) => { - let x = self.flatten_boolean_expression(statements_flattened, exp); - FlatExpression::Sub(box FlatExpression::Number(T::one()), box x) + BooleanExpression::Not(e) => { + let x = self.flatten_boolean_expression(statements_flattened, *e.inner); + FlatExpression::sub(FlatExpression::value(T::one()), x) } - BooleanExpression::Value(b) => FlatExpression::Number(match b { + BooleanExpression::Value(b) => FlatExpression::value(match b.value { true => T::from(1), false => T::from(0), }), @@ -999,6 +1066,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { .flatten_conditional_expression(statements_flattened, e) .get_field_unchecked(), } + .span(span); + + statements_flattened.set_span(span_backup); + + res } fn u_to_bits( @@ -1089,8 +1161,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { let constants: Vec<_> = constants .into_iter() .map(|e| match e.get_field_unchecked() { - FlatExpression::Number(n) if n == T::one() => true, - FlatExpression::Number(n) if n == T::zero() => false, + FlatExpression::Value(n) if n.value == T::one() => true, + FlatExpression::Value(n) if n.value == T::zero() => false, _ => unreachable!(), }) .collect(); @@ -1106,8 +1178,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { T::from(conditions.len()).into(), conditions .into_iter() - .fold(FlatExpression::Number(T::zero()), |acc, e| { - FlatExpression::Add(box acc, box e) + .fold(FlatExpression::value(T::zero()), |acc, e| { + FlatExpression::add(acc, e) }), ), )] @@ -1157,16 +1229,16 @@ impl<'ast, T: Field> Flattener<'ast, T> { let statements = funct.statements.into_iter().map(|stat| match stat { FlatStatement::Block(..) => unreachable!(), - FlatStatement::Definition(var, rhs) => { + FlatStatement::Definition(s) => { let new_var = self.use_sym(); - replacement_map.insert(var, new_var); - let new_rhs = rhs.apply_substitution(&replacement_map); - FlatStatement::Definition(new_var, new_rhs) + replacement_map.insert(s.assignee, new_var); + let new_rhs = s.rhs.apply_substitution(&replacement_map); + FlatStatement::definition(new_var, new_rhs) } - FlatStatement::Condition(lhs, rhs, message) => { - let new_lhs = lhs.apply_substitution(&replacement_map); - let new_rhs = rhs.apply_substitution(&replacement_map); - FlatStatement::Condition(new_lhs, new_rhs, message) + FlatStatement::Condition(s) => { + let new_quad = s.quad.apply_substitution(&replacement_map); + let new_lin = s.lin.apply_substitution(&replacement_map); + FlatStatement::condition(new_lin, new_quad, s.error) } FlatStatement::Directive(d) => { let new_outputs = d @@ -1183,15 +1255,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { .into_iter() .map(|i| i.apply_substitution(&replacement_map)) .collect(); - FlatStatement::Directive(FlatDirective { - outputs: new_outputs, - solver: d.solver, - inputs: new_inputs, - }) + FlatStatement::directive(new_outputs, d.solver, new_inputs) } - FlatStatement::Log(l, expressions) => FlatStatement::Log( - l, - expressions + FlatStatement::Log(s) => FlatStatement::Log(LogStatement::new( + s.format_string, + s.expressions .into_iter() .map(|(t, e)| { ( @@ -1202,7 +1270,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ) }) .collect(), - ), + )), }); statements_flattened.extend(statements); @@ -1253,12 +1321,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { .into_iter() .zip(right_bits.into_iter()) .map(|(x, y)| match (x, y) { - (FlatExpression::Number(n), e) | (e, FlatExpression::Number(n)) => { - if n == T::from(0) { + (FlatExpression::Value(n), e) | (e, FlatExpression::Value(n)) => { + if n.value == T::from(0) { self.define(e, statements_flattened).into() - } else if n == T::from(1) { + } else if n.value == T::from(1) { self.define( - FlatExpression::Sub(box FlatExpression::Number(T::from(1)), box e), + FlatExpression::sub(FlatExpression::value(T::from(1)), e), statements_flattened, ) .into() @@ -1270,20 +1338,17 @@ impl<'ast, T: Field> Flattener<'ast, T> { let name = self.use_sym(); statements_flattened.extend(vec![ - FlatStatement::Directive(FlatDirective::new( + FlatStatement::directive( vec![name], Solver::Xor, vec![x.clone(), y.clone()], - )), - FlatStatement::Condition( - FlatExpression::Add( - box x.clone(), - box FlatExpression::Sub(box y.clone(), box name.into()), - ), - FlatExpression::Mult( - box FlatExpression::Add(box x.clone(), box x), - box y, + ), + FlatStatement::condition( + FlatExpression::add( + x.clone(), + FlatExpression::sub(y.clone(), name.into()), ), + FlatExpression::mul(FlatExpression::add(x.clone(), x), y), RuntimeError::Xor, ), ]); @@ -1313,26 +1378,26 @@ impl<'ast, T: Field> Flattener<'ast, T> { left_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); + FlatExpression::identifier(id) }; let d = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); + FlatExpression::identifier(id) }; // introduce the quotient and remainder let q = self.use_sym(); let r = self.use_sym(); - statements_flattened.push_back(FlatStatement::Directive(FlatDirective { - inputs: vec![n.clone(), d.clone()], - outputs: vec![q, r], - solver: Solver::EuclideanDiv, - })); + statements_flattened.push_back(FlatStatement::directive( + vec![q, r], + Solver::EuclideanDiv, + vec![n.clone(), d.clone()], + )); let target_bitwidth = target_bitwidth.to_usize(); @@ -1356,9 +1421,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { // r < d <=> r - d + 2**w < 2**w let _ = self.get_bits_unchecked( - &FlatUExpression::with_field(FlatExpression::Add( - box FlatExpression::Sub(box r.into(), box d.clone()), - box FlatExpression::Number(T::from(2_u128.pow(target_bitwidth as u32))), + &FlatUExpression::with_field(FlatExpression::add( + FlatExpression::sub(r.into(), d.clone()), + FlatExpression::value(T::from(2_u128.pow(target_bitwidth as u32))), )), target_bitwidth, target_bitwidth, @@ -1367,9 +1432,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { ); // q*d == n - r - statements_flattened.push_back(FlatStatement::Condition( - FlatExpression::Sub(box n, box r.into()), - FlatExpression::Mult(box q.into(), box d), + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::sub(n, r.into()), + FlatExpression::mul(q.into(), d), RuntimeError::Euclidean, )); @@ -1387,6 +1452,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened: &mut FlatStatements<'ast, T>, expr: UExpression<'ast, T>, ) -> FlatUExpression { + let span = expr.as_inner().get_span(); + + let span_backup = statements_flattened.span; + + statements_flattened.set_span(span); + // the bitwidth for this type of uint (8, 16 or 32) let target_bitwidth = expr.bitwidth; @@ -1402,10 +1473,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { let res = match expr.into_inner() { UExpressionInner::Value(x) => { - FlatUExpression::with_field(FlatExpression::Number(T::from(x))) + FlatUExpression::with_field(FlatExpression::value(T::from(x.value))) } // force to be a field element UExpressionInner::Identifier(x) => { - let field = FlatExpression::Identifier(*self.layout.get(&x.id).unwrap()); + let field = FlatExpression::identifier(*self.layout.get(&x.id).unwrap()); let bits = self.bits_cache.get(&field).map(|bits| { assert_eq!(bits.len(), target_bitwidth.to_usize()); bits.clone() @@ -1413,8 +1484,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatUExpression::with_field(field).bits(bits) } UExpressionInner::Select(e) => self.flatten_select_expression(statements_flattened, e), - UExpressionInner::Not(box e) => { - let e = self.flatten_uint_expression(statements_flattened, e); + UExpressionInner::Not(e) => { + let e = self.flatten_uint_expression(statements_flattened, *e.inner); let e_bits = e.bits.unwrap(); @@ -1424,7 +1495,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { .into_iter() .map(|bit| { self.define( - FlatExpression::Sub(box FlatExpression::Number(T::from(1)), box bit), + FlatExpression::sub(FlatExpression::value(T::from(1)), bit), statements_flattened, ) .into() @@ -1433,66 +1504,73 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatUExpression::with_bits(name_not) } - UExpressionInner::Add(box left, box right) => { + UExpressionInner::Add(e) => { let left_flattened = self - .flatten_uint_expression(statements_flattened, left) + .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let right_flattened = self - .flatten_uint_expression(statements_flattened, right) + .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); let new_left = if left_flattened.is_linear() { left_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); + FlatExpression::identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); + FlatExpression::identifier(id) }; - FlatUExpression::with_field(FlatExpression::Add(box new_left, box new_right)) + FlatUExpression::with_field(FlatExpression::add(new_left, new_right)) } - UExpressionInner::Sub(box left, box right) => { + UExpressionInner::Sub(e) => { // see uint optimizer for the reasoning here - let offset = FlatExpression::Number(T::from(2).pow(std::cmp::max( - right.metadata.as_ref().unwrap().bitwidth() as usize, + let offset = FlatExpression::value(T::from(2).pow(std::cmp::max( + e.right.metadata.as_ref().unwrap().bitwidth() as usize, target_bitwidth as usize, ))); let left_flattened = self - .flatten_uint_expression(statements_flattened, left) + .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let right_flattened = self - .flatten_uint_expression(statements_flattened, right) + .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); let new_left = if left_flattened.is_linear() { left_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); + FlatExpression::identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); + FlatExpression::identifier(id) }; - FlatUExpression::with_field(FlatExpression::Add( - box offset, - box FlatExpression::Sub(box new_left, box new_right), + FlatUExpression::with_field(FlatExpression::add( + offset, + FlatExpression::sub(new_left, new_right), )) } - UExpressionInner::LeftShift(box e, by) => { + UExpressionInner::LeftShift(e) => { + let by = match e.right.into_inner() { + UExpressionInner::Value(v) => v.value as u32, + _ => unreachable!(), + }; + + let e = *e.left; + let e = self.flatten_uint_expression(statements_flattened, e); let e_bits = e.bits.unwrap(); @@ -1505,12 +1583,19 @@ impl<'ast, T: Field> Flattener<'ast, T> { .skip(by as usize) .chain( (0..std::cmp::min(by as usize, target_bitwidth.to_usize())) - .map(|_| FlatExpression::Number(T::from(0))), + .map(|_| FlatExpression::value(T::from(0))), ) .collect::>(), ) } - UExpressionInner::RightShift(box e, by) => { + UExpressionInner::RightShift(e) => { + let by = match e.right.into_inner() { + UExpressionInner::Value(v) => v.value as u32, + _ => unreachable!(), + }; + + let e = *e.left; + let e = self.flatten_uint_expression(statements_flattened, e); let e_bits = e.bits.unwrap(); @@ -1519,7 +1604,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatUExpression::with_bits( (0..std::cmp::min(by as usize, target_bitwidth.to_usize())) - .map(|_| FlatExpression::Number(T::from(0))) + .map(|_| FlatExpression::value(T::from(0))) .chain(e_bits.into_iter().take( target_bitwidth.to_usize() - std::cmp::min(by as usize, target_bitwidth.to_usize()), @@ -1527,59 +1612,75 @@ impl<'ast, T: Field> Flattener<'ast, T> { .collect::>(), ) } - UExpressionInner::Mult(box left, box right) => { + UExpressionInner::Mult(e) => { let left_flattened = self - .flatten_uint_expression(statements_flattened, left) + .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let right_flattened = self - .flatten_uint_expression(statements_flattened, right) + .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); let new_left = if left_flattened.is_linear() { left_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); + FlatExpression::identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); + FlatExpression::identifier(id) }; let res = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( res, - FlatExpression::Mult(box new_left, box new_right), + FlatExpression::mul(new_left, new_right), )); - FlatUExpression::with_field(FlatExpression::Identifier(res)) + FlatUExpression::with_field(FlatExpression::identifier(res)) } - UExpressionInner::Div(box left, box right) => { - let (q, _) = - self.euclidean_division(statements_flattened, target_bitwidth, left, right); + UExpressionInner::Div(e) => { + let (q, _) = self.euclidean_division( + statements_flattened, + target_bitwidth, + *e.left, + *e.right, + ); FlatUExpression::with_field(q) } - UExpressionInner::Rem(box left, box right) => { - let (_, r) = - self.euclidean_division(statements_flattened, target_bitwidth, left, right); + UExpressionInner::Rem(e) => { + let (_, r) = self.euclidean_division( + statements_flattened, + target_bitwidth, + *e.left, + *e.right, + ); FlatUExpression::with_field(r) } UExpressionInner::Conditional(e) => { self.flatten_conditional_expression(statements_flattened, e) } - UExpressionInner::Xor(box left, box right) => { - let left_metadata = left.metadata.clone().unwrap(); - let right_metadata = right.metadata.clone().unwrap(); + UExpressionInner::Xor(e) => { + let left_metadata = e.left.metadata.clone().unwrap(); + let right_metadata = e.right.metadata.clone().unwrap(); - match (left.into_inner(), right.into_inner()) { - (UExpressionInner::And(box a, box b), UExpressionInner::And(box aa, box c)) => { - if aa.clone().into_inner() == UExpressionInner::Not(box a.clone()) { + let left_span = e.left.get_span(); + let right_span = e.right.get_span(); + + match (e.left.into_inner(), e.right.into_inner()) { + (UExpressionInner::And(e), UExpressionInner::And(ee)) => { + let a = *e.left; + let b = *e.right; + let aa = *ee.left; + let c = *ee.right; + + if aa.as_inner() == UExpression::not(a.clone()).as_inner() { let a_flattened = self.flatten_uint_expression(statements_flattened, a); let b_flattened = self.flatten_uint_expression(statements_flattened, b); let c_flattened = self.flatten_uint_expression(statements_flattened, c); @@ -1598,17 +1699,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { let ch = self.use_sym(); statements_flattened.extend(vec![ - FlatStatement::Directive(FlatDirective::new( + FlatStatement::directive( vec![ch], Solver::ShaCh, vec![a.clone(), b.clone(), c.clone()], - )), - FlatStatement::Condition( - FlatExpression::Sub(box ch.into(), box c.clone()), - FlatExpression::Mult( - box a, - box FlatExpression::Sub(box b, box c), - ), + ), + FlatStatement::condition( + FlatExpression::sub(ch.into(), c.clone()), + a * (b - c), RuntimeError::ShaXor, ), ]); @@ -1620,25 +1718,32 @@ impl<'ast, T: Field> Flattener<'ast, T> { } else { self.default_xor( statements_flattened, - UExpressionInner::And(box a, box b) - .annotate(target_bitwidth) - .metadata(left_metadata), - UExpressionInner::And(box aa, box c) - .annotate(target_bitwidth) - .metadata(right_metadata), + UExpression::and(a, b).metadata(left_metadata), + UExpression::and(aa, c).metadata(right_metadata), ) } } - (UExpressionInner::Xor(box a, box b), c) => { - let a_metadata = a.metadata.clone().unwrap(); - let b_metadata = b.metadata.clone().unwrap(); + (UExpressionInner::Xor(e), c) => { + let a_metadata = e.left.metadata.clone().unwrap(); + let b_metadata = e.right.metadata.clone().unwrap(); - match (a.into_inner(), b.into_inner(), c) { + let a_span = e.left.get_span(); + let b_span = e.right.get_span(); + let c_span = right_span; + + match (e.left.into_inner(), e.right.into_inner(), c) { ( - UExpressionInner::And(box a, box b), - UExpressionInner::And(box aa, box c), - UExpressionInner::And(box bb, box cc), + UExpressionInner::And(e0), + UExpressionInner::And(e1), + UExpressionInner::And(e2), ) => { + let a = *e0.left; + let b = *e0.right; + let aa = *e1.left; + let c = *e1.right; + let bb = *e2.left; + let cc = *e2.right; + if (aa == a) && (bb == b) && (cc == c) { let a_flattened = self.flatten_uint_expression(statements_flattened, a); @@ -1663,33 +1768,27 @@ impl<'ast, T: Field> Flattener<'ast, T> { let bc = self.use_sym(); statements_flattened.extend(vec![ - FlatStatement::Directive(FlatDirective::new( + FlatStatement::directive( vec![maj], Solver::ShaAndXorAndXorAnd, vec![a.clone(), b.clone(), c.clone()], - )), - FlatStatement::Condition( + ), + FlatStatement::condition( bc.into(), - FlatExpression::Mult( - box b.clone(), - box c.clone(), - ), + FlatExpression::mul(b.clone(), c.clone()), RuntimeError::ShaXor, ), - FlatStatement::Condition( - FlatExpression::Sub( - box bc.into(), - box maj.into(), - ), - FlatExpression::Mult( - box FlatExpression::Sub( - box FlatExpression::Add( - box bc.into(), - box bc.into(), + FlatStatement::condition( + FlatExpression::sub(bc.into(), maj.into()), + FlatExpression::mul( + FlatExpression::sub( + FlatExpression::add( + bc.into(), + bc.into(), ), - box FlatExpression::Add(box b, box c), + FlatExpression::add(b, c), ), - box a, + a, ), RuntimeError::ShaXor, ), @@ -1702,45 +1801,57 @@ impl<'ast, T: Field> Flattener<'ast, T> { } else { self.default_xor( statements_flattened, - UExpressionInner::Xor( - box UExpressionInner::And(box a, box b) - .annotate(target_bitwidth) - .metadata(a_metadata), - box UExpressionInner::And(box aa, box c) - .annotate(target_bitwidth) - .metadata(b_metadata), + UExpression::xor( + UExpression::and(a, b) + .metadata(a_metadata) + .span(a_span), + UExpression::and(aa, c) + .metadata(b_metadata) + .span(b_span), ) - .annotate(target_bitwidth) - .metadata(left_metadata), - UExpressionInner::And(box bb, box cc) - .annotate(target_bitwidth) - .metadata(right_metadata), + .metadata(left_metadata) + .span(left_span), + UExpression::and(bb, cc) + .metadata(right_metadata) + .span(c_span), ) } } (a, b, c) => self.default_xor( statements_flattened, - UExpressionInner::Xor( - box a.annotate(target_bitwidth).metadata(a_metadata), - box b.annotate(target_bitwidth).metadata(b_metadata), + UExpression::xor( + a.annotate(target_bitwidth) + .metadata(a_metadata) + .span(a_span), + b.annotate(target_bitwidth) + .metadata(b_metadata) + .span(b_span), ) - .annotate(target_bitwidth) - .metadata(left_metadata), - c.annotate(target_bitwidth).metadata(right_metadata), + .metadata(left_metadata) + .span(left_span), + c.annotate(target_bitwidth) + .metadata(right_metadata) + .span(c_span), ), } } (left_i, right_i) => self.default_xor( statements_flattened, - left_i.annotate(target_bitwidth).metadata(left_metadata), - right_i.annotate(target_bitwidth).metadata(right_metadata), + left_i + .annotate(target_bitwidth) + .metadata(left_metadata) + .span(left_span), + right_i + .annotate(target_bitwidth) + .metadata(right_metadata) + .span(right_span), ), } } - UExpressionInner::And(box left, box right) => { - let left_flattened = self.flatten_uint_expression(statements_flattened, left); + UExpressionInner::And(e) => { + let left_flattened = self.flatten_uint_expression(statements_flattened, *e.left); - let right_flattened = self.flatten_uint_expression(statements_flattened, right); + let right_flattened = self.flatten_uint_expression(statements_flattened, *e.right); let left_bits = left_flattened.bits.unwrap(); let right_bits = right_flattened.bits.unwrap(); @@ -1752,26 +1863,26 @@ impl<'ast, T: Field> Flattener<'ast, T> { .into_iter() .zip(right_bits.into_iter()) .map(|(x, y)| match (x, y) { - (FlatExpression::Number(n), e) | (e, FlatExpression::Number(n)) => { - if n == T::from(0) { - FlatExpression::Number(T::from(0)) - } else if n == T::from(1) { + (FlatExpression::Value(n), e) | (e, FlatExpression::Value(n)) => { + if n.value == T::from(0) { + FlatExpression::value(T::from(0)) + } else if n.value == T::from(1) { e } else { unreachable!(); } } (x, y) => self - .define(FlatExpression::Mult(box x, box y), statements_flattened) + .define(FlatExpression::mul(x, y), statements_flattened) .into(), }) .collect(); FlatUExpression::with_bits(and) } - UExpressionInner::Or(box left, box right) => { - let left_flattened = self.flatten_uint_expression(statements_flattened, left); - let right_flattened = self.flatten_uint_expression(statements_flattened, right); + UExpressionInner::Or(e) => { + let left_flattened = self.flatten_uint_expression(statements_flattened, *e.left); + let right_flattened = self.flatten_uint_expression(statements_flattened, *e.right); let left_bits = left_flattened.bits.unwrap(); let right_bits = right_flattened.bits.unwrap(); @@ -1786,11 +1897,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { .into_iter() .zip(right_bits.into_iter()) .map(|(x, y)| match (x, y) { - (FlatExpression::Number(n), e) | (e, FlatExpression::Number(n)) => { - if n == T::from(0) { + (FlatExpression::Value(n), e) | (e, FlatExpression::Value(n)) => { + if n.value == T::from(0) { self.define(e, statements_flattened).into() - } else if n == T::from(1) { - FlatExpression::Number(T::from(1)) + } else if n.value == T::from(1) { + FlatExpression::value(T::from(1)) } else { unreachable!() } @@ -1799,17 +1910,17 @@ impl<'ast, T: Field> Flattener<'ast, T> { let name = self.use_sym(); statements_flattened.extend(vec![ - FlatStatement::Directive(FlatDirective::new( + FlatStatement::directive( vec![name], Solver::Or, vec![x.clone(), y.clone()], - )), - FlatStatement::Condition( - FlatExpression::Add( - box x.clone(), - box FlatExpression::Sub(box y.clone(), box name.into()), + ), + FlatStatement::condition( + FlatExpression::add( + x.clone(), + FlatExpression::sub(y.clone(), name.into()), ), - FlatExpression::Mult(box x, box y), + FlatExpression::mul(x, y), RuntimeError::Or, ), ]); @@ -1820,7 +1931,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatUExpression::with_bits(or) } - }; + } + .span(span); let res = match should_reduce { true => { @@ -1834,15 +1946,15 @@ impl<'ast, T: Field> Flattener<'ast, T> { let field = if actual_bitwidth > target_bitwidth.to_usize() { bits.iter().enumerate().fold( - FlatExpression::Number(T::from(0)), + FlatExpression::value(T::from(0)), |acc, (index, bit)| { - FlatExpression::Add( - box acc, - box FlatExpression::Mult( - box FlatExpression::Number( + FlatExpression::add( + acc, + FlatExpression::mul( + FlatExpression::value( T::from(2).pow(target_bitwidth.to_usize() - index - 1), ), - box bit.clone(), + bit.clone(), ), ) }, @@ -1856,6 +1968,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { false => res, }; + statements_flattened.set_span(span_backup); + res } @@ -1884,11 +1998,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { assert!(to <= T::get_required_bits()); // constants do not require directives - if let Some(FlatExpression::Number(ref x)) = e.field { - let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[*x], &[]) + if let Some(FlatExpression::Value(ref x)) = e.field { + let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.value], &[]) .unwrap() .into_iter() - .map(FlatExpression::Number) + .map(FlatExpression::value) .collect(); assert_eq!(bits.len(), to); @@ -1915,28 +2029,28 @@ impl<'ast, T: Field> Flattener<'ast, T> { res } else { (0..to - res.len()) - .map(|_| FlatExpression::Number(T::zero())) + .map(|_| FlatExpression::value(T::zero())) .chain(res) .collect() } } Entry::Vacant(_) => { let bits = (0..from).map(|_| self.use_sym()).collect::>(); - statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new( + statements_flattened.push_back(FlatStatement::directive( bits.clone(), Solver::Bits(from), vec![e.field.clone().unwrap()], - ))); + )); - let bits: Vec<_> = bits.into_iter().map(FlatExpression::Identifier).collect(); + let bits: Vec<_> = bits.into_iter().map(FlatExpression::identifier).collect(); // decompose to the actual bitwidth // bit checks statements_flattened.extend(bits.iter().take(from).map(|bit| { - FlatStatement::Condition( + FlatStatement::condition( bit.clone(), - FlatExpression::Mult(box bit.clone(), box bit.clone()), + FlatExpression::mul(bit.clone(), bit.clone()), RuntimeError::Bitness, ) })); @@ -1944,7 +2058,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { let sum = flat_expression_from_bits(bits.clone()); // sum check - statements_flattened.push_back(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::condition( e.field.clone().unwrap(), sum.clone(), error, @@ -1974,6 +2088,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened: &mut FlatStatements<'ast, T>, e: SelectExpression<'ast, T, U>, ) -> FlatUExpression { + let span = e.get_span(); + let array = e.array; let index = *e.index; @@ -1983,15 +2099,17 @@ impl<'ast, T: Field> Flattener<'ast, T> { .map(|(i, e)| { let condition = self.flatten_boolean_expression( statements_flattened, - BooleanExpression::UintEq( - box UExpressionInner::Value(i as u128) + BooleanExpression::uint_eq( + UExpression::value(i as u128) .annotate(UBitwidth::B32) .metadata(UMetadata { should_reduce: ShouldReduce::True, max: T::from(i), - }), - box index.clone(), - ), + }) + .span(span), + index.clone(), + ) + .span(span), ); let element = e.flatten(self, statements_flattened); @@ -2002,26 +2120,29 @@ impl<'ast, T: Field> Flattener<'ast, T> { .into_iter() .fold( ( - FlatExpression::Number(T::zero()), - FlatExpression::Number(T::zero()), + FlatExpression::value(T::zero()), + FlatExpression::value(T::zero()), ), |(mut range_check, mut result), (condition, element)| { - range_check = FlatExpression::Add(box range_check, box condition.clone()); + range_check = FlatExpression::add(range_check, condition.clone()).span(span); let conditional_element_id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition( - conditional_element_id, - FlatExpression::Mult(box condition, box element.flat()), - )); + statements_flattened.push_back( + FlatStatement::definition( + conditional_element_id, + FlatExpression::mul(condition, element.flat()).span(span), + ) + .span(span), + ); - result = FlatExpression::Add(box result, box conditional_element_id.into()); + result = FlatExpression::add(result, conditional_element_id.into()).span(span); (range_check, result) }, ); - statements_flattened.push_back(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::condition( range_check, - FlatExpression::Number(T::one()), + FlatExpression::value(T::one()), RuntimeError::SelectRangeCheck, )); FlatUExpression::with_field(result) @@ -2038,84 +2159,90 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened: &mut FlatStatements<'ast, T>, expr: FieldElementExpression<'ast, T>, ) -> FlatExpression { - match expr { - FieldElementExpression::Number(x) => FlatExpression::Number(x), // force to be a field element - FieldElementExpression::Identifier(x) => FlatExpression::Identifier( + let span = expr.get_span(); + + let span_backup = statements_flattened.span; + + statements_flattened.set_span(span); + + let res = match expr { + FieldElementExpression::Value(x) => FlatExpression::Value(x), // force to be a field element + FieldElementExpression::Identifier(x) => FlatExpression::identifier( *self.layout.get(&x.id).unwrap_or_else(|| panic!("{}", x)), ), FieldElementExpression::Select(e) => self .flatten_select_expression(statements_flattened, e) .get_field_unchecked(), - FieldElementExpression::Add(box left, box right) => { - let left_flattened = self.flatten_field_expression(statements_flattened, left); - let right_flattened = self.flatten_field_expression(statements_flattened, right); + FieldElementExpression::Add(e) => { + let left_flattened = self.flatten_field_expression(statements_flattened, *e.left); + let right_flattened = self.flatten_field_expression(statements_flattened, *e.right); let new_left = if left_flattened.is_linear() { left_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); + FlatExpression::identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); + FlatExpression::identifier(id) }; - FlatExpression::Add(box new_left, box new_right) + FlatExpression::add(new_left, new_right) } - FieldElementExpression::Sub(box left, box right) => { - let left_flattened = self.flatten_field_expression(statements_flattened, left); - let right_flattened = self.flatten_field_expression(statements_flattened, right); + FieldElementExpression::Sub(e) => { + let left_flattened = self.flatten_field_expression(statements_flattened, *e.left); + let right_flattened = self.flatten_field_expression(statements_flattened, *e.right); let new_left = if left_flattened.is_linear() { left_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); + FlatExpression::identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); + FlatExpression::identifier(id) }; - FlatExpression::Sub(box new_left, box new_right) + FlatExpression::sub(new_left, new_right) } - FieldElementExpression::Mult(box left, box right) => { - let left_flattened = self.flatten_field_expression(statements_flattened, left); - let right_flattened = self.flatten_field_expression(statements_flattened, right); + FieldElementExpression::Mult(e) => { + let left_flattened = self.flatten_field_expression(statements_flattened, *e.left); + let right_flattened = self.flatten_field_expression(statements_flattened, *e.right); let new_left = if left_flattened.is_linear() { left_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); + FlatExpression::identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); + FlatExpression::identifier(id) }; - FlatExpression::Mult(box new_left, box new_right) + FlatExpression::mul(new_left, new_right) } - FieldElementExpression::Div(box left, box right) => { - let left_flattened = self.flatten_field_expression(statements_flattened, left); - let right_flattened = self.flatten_field_expression(statements_flattened, right); + FieldElementExpression::Div(e) => { + let left_flattened = self.flatten_field_expression(statements_flattened, *e.left); + let right_flattened = self.flatten_field_expression(statements_flattened, *e.right); let new_left: FlatExpression = { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); + statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); id.into() }; let new_right: FlatExpression = { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); + statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); id.into() }; @@ -2125,34 +2252,34 @@ impl<'ast, T: Field> Flattener<'ast, T> { let inverse = self.use_sym(); // # c = a/b - statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new( + statements_flattened.push_back(FlatStatement::directive( vec![inverse], Solver::Div, vec![new_left.clone(), new_right.clone()], - ))); + )); // assert(c * b == a) - statements_flattened.push_back(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::condition( new_left, - FlatExpression::Mult(box new_right, box inverse.into()), + FlatExpression::mul(new_right, inverse.into()), RuntimeError::Division, )); inverse.into() } - FieldElementExpression::Pow(box base, box exponent) => { - match exponent.into_inner() { - UExpressionInner::Value(ref e) => { + FieldElementExpression::Pow(e) => { + match e.right.into_inner() { + UExpressionInner::Value(ref exp) => { // flatten the base expression let base_flattened = - self.flatten_field_expression(statements_flattened, base.clone()); + self.flatten_field_expression(statements_flattened, *e.left.clone()); // we require from the base to be linear // TODO change that assert!(base_flattened.is_linear()); // convert the exponent to bytes, big endian - let ebytes_be = e.to_be_bytes(); + let ebytes_be = exp.value.to_be_bytes(); // convert the bytes to bits, remove leading zeroes (we only need powers up to the highest non-zero bit) #[allow(clippy::needless_collect)] @@ -2181,17 +2308,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { // introduce a new variable let id = self.use_sym(); // set it to the square of the previous one, stored in state - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( id, - FlatExpression::Mult( - box previous.clone(), - box previous.clone(), - ), + FlatExpression::mul(previous.clone(), previous.clone()), )); // store it in the state for later squaring - *state = Some(FlatExpression::Identifier(id)); + *state = Some(FlatExpression::identifier(id)); // return it for later use constructing the result - Some(FlatExpression::Identifier(id)) + Some(FlatExpression::identifier(id)) } } }) @@ -2199,16 +2323,16 @@ impl<'ast, T: Field> Flattener<'ast, T> { // construct the result iterating through the bits, multiplying by the associated power iff the bit is true ebits_le.into_iter().zip(powers).fold( - FlatExpression::Number(T::from(1)), // initialise the result at 1. If we have no bits to itegrate through, we're computing x**0 == 1 + FlatExpression::value(T::from(1)), // initialise the result at 1. If we have no bits to iterate through, we're computing x**0 == 1 |acc, (bit, power)| match bit { true => { // update the result by introducing a new variable let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( id, - FlatExpression::Mult(box acc.clone(), box power), // set the new result to the current result times the current power + FlatExpression::mul(acc, power), // set the new result to the current result times the current power )); - FlatExpression::Identifier(id) + FlatExpression::identifier(id) } false => acc, // this bit is false, keep the previous result }, @@ -2221,7 +2345,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { .flatten_conditional_expression(statements_flattened, e) .get_field_unchecked(), _ => unreachable!(), - } + }; + + statements_flattened.set_span(span_backup); + + res } fn flatten_assembly_statement( @@ -2229,36 +2357,46 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened: &mut FlatStatements<'ast, T>, stat: ZirAssemblyStatement<'ast, T>, ) { + let span = stat.get_span(); + + let span_backup = statements_flattened.span; + + statements_flattened.set_span(span); + match stat { - ZirAssemblyStatement::Assignment(assignees, function) => { - let inputs: Vec> = function + ZirAssemblyStatement::Assignment(s) => { + let inputs: Vec> = s + .expression .arguments .iter() .cloned() .map(|p| self.layout.get(&p.id.id).cloned().unwrap().into()) .collect(); - let outputs: Vec = assignees + let outputs: Vec = s + .assignee .into_iter() .map(|assignee| self.use_variable(&assignee)) .collect(); let mut canonicalizer = ZirCanonicalizer::default(); - let function = canonicalizer.fold_function(function); + let function = canonicalizer.fold_function(s.expression); let directive = FlatDirective::new(outputs, Solver::Zir(function), inputs); statements_flattened.push_back(FlatStatement::Directive(directive)); } - ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => { - let lhs = self.flatten_field_expression(statements_flattened, lhs); - let rhs = self.flatten_field_expression(statements_flattened, rhs); - statements_flattened.push_back(FlatStatement::Condition( + ZirAssemblyStatement::Constraint(s) => { + let lhs = self.flatten_field_expression(statements_flattened, s.left); + let rhs = self.flatten_field_expression(statements_flattened, s.right); + statements_flattened.push_back(FlatStatement::condition( lhs, rhs, - RuntimeError::SourceAssemblyConstraint(metadata), + RuntimeError::SourceAssemblyConstraint(s.metadata), )); } - } + }; + + statements_flattened.set_span(span_backup); } /// Flattens a statement @@ -2270,21 +2408,21 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn flatten_statement( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, - stat: ZirStatement<'ast, T>, + s: ZirStatement<'ast, T>, ) { - match stat { - ZirStatement::Assembly(statements) => { - let mut block_statements = VecDeque::new(); - for s in statements { - self.flatten_assembly_statement(&mut block_statements, s); - } - statements_flattened.push_back(FlatStatement::Block(block_statements.into())); - } - ZirStatement::Return(exprs) => { + let span = s.get_span(); + + let span_backup = statements_flattened.span; + + statements_flattened.set_span(span); + + match s { + ZirStatement::Return(s) => { #[allow(clippy::needless_collect)] // clippy suggests to not collect here, but `statements_flattened` is borrowed in the iterator, // so we cannot borrow again when extending - let flat_expressions: Vec<_> = exprs + let flat_expressions: Vec<_> = s + .inner .into_iter() .map(|expr| self.flatten_expression(statements_flattened, expr)) .map(|x| x.get_field_unchecked()) @@ -2294,25 +2432,35 @@ impl<'ast, T: Field> Flattener<'ast, T> { flat_expressions .into_iter() .enumerate() - .map(|(index, e)| FlatStatement::Definition(Variable::public(index), e)), + .map(|(index, e)| FlatStatement::definition(Variable::public(index), e)), ); } - ZirStatement::IfElse(condition, consequence, alternative) => { + ZirStatement::Assembly(s) => { + let mut block_statements = FlatStatements::default(); + block_statements.set_span(s.get_span()); + + for s in s.inner { + self.flatten_assembly_statement(&mut block_statements, s); + } + statements_flattened + .push_back(FlatStatement::block(block_statements.buffer.into())); + } + ZirStatement::IfElse(s) => { let condition_flat = - self.flatten_boolean_expression(statements_flattened, condition.clone()); + self.flatten_boolean_expression(statements_flattened, s.condition.clone()); let condition_id = self.use_sym(); statements_flattened - .push_back(FlatStatement::Definition(condition_id, condition_flat)); + .push_back(FlatStatement::definition(condition_id, condition_flat)); if self.config.isolate_branches { - let mut consequence_statements = VecDeque::new(); - let mut alternative_statements = VecDeque::new(); + let mut consequence_statements = FlatStatements::default(); + let mut alternative_statements = FlatStatements::default(); - consequence + s.consequence .into_iter() .for_each(|s| self.flatten_statement(&mut consequence_statements, s)); - alternative + s.alternative .into_iter() .for_each(|s| self.flatten_statement(&mut alternative_statements, s)); @@ -2320,41 +2468,41 @@ impl<'ast, T: Field> Flattener<'ast, T> { self.make_conditional(consequence_statements, condition_id.into()); let alternative_statements = self.make_conditional( alternative_statements, - FlatExpression::Sub( - box FlatExpression::Number(T::one()), - box condition_id.into(), - ), + FlatExpression::sub(FlatExpression::value(T::one()), condition_id.into()), ); statements_flattened.extend(consequence_statements); statements_flattened.extend(alternative_statements); } else { - consequence + s.consequence .into_iter() .for_each(|s| self.flatten_statement(statements_flattened, s)); - alternative + s.alternative .into_iter() .for_each(|s| self.flatten_statement(statements_flattened, s)); } } - ZirStatement::Definition(assignee, expr) => { + ZirStatement::Definition(s) => { // define n variables with n the number of primitive types for v_type // assign them to the n primitive types for expr + let assignee = s.assignee; + let expr = s.rhs; + let rhs = self.flatten_expression(statements_flattened, expr); let bits = rhs.bits.clone(); let var = match rhs.get_field_unchecked() { FlatExpression::Identifier(id) => { - self.use_variable_with_existing(&assignee, id); - id + self.use_variable_with_existing(&assignee, id.id); + id.id } e => { let var = self.use_variable(&assignee); // handle return of function call - statements_flattened.push_back(FlatStatement::Definition(var, e)); + statements_flattened.push_back(FlatStatement::definition(var, e)); var } @@ -2363,22 +2511,25 @@ impl<'ast, T: Field> Flattener<'ast, T> { // register bits if let Some(bits) = bits { self.bits_cache - .insert(FlatExpression::Identifier(var), bits); + .insert(FlatExpression::identifier(var), bits); } } - ZirStatement::Assertion(e, error) => { + ZirStatement::Assertion(s) => { + let e = s.expression; + let error = s.error; + match e { BooleanExpression::And(..) => { for boolean in e.into_conjunction_iterator() { self.flatten_statement( statements_flattened, - ZirStatement::Assertion(boolean, error.clone()), + ZirStatement::assertion(boolean, error.clone()).span(span), ) } } - BooleanExpression::FieldEq(box lhs, box rhs) => { - let lhs = self.flatten_field_expression(statements_flattened, lhs); - let rhs = self.flatten_field_expression(statements_flattened, rhs); + BooleanExpression::FieldEq(e) => { + let lhs = self.flatten_field_expression(statements_flattened, *e.left); + let rhs = self.flatten_field_expression(statements_flattened, *e.right); self.flatten_equality_assertion( statements_flattened, @@ -2387,106 +2538,106 @@ impl<'ast, T: Field> Flattener<'ast, T> { error.into(), ) } - BooleanExpression::FieldLt(box lhs, box rhs) => { - let lhs = self.flatten_field_expression(statements_flattened, lhs); - let rhs = self.flatten_field_expression(statements_flattened, rhs); + BooleanExpression::FieldLt(e) => { + let lhs = self.flatten_field_expression(statements_flattened, *e.left); + let rhs = self.flatten_field_expression(statements_flattened, *e.right); match (lhs, rhs) { - (e, FlatExpression::Number(c)) => self.enforce_constant_lt_check( + (e, FlatExpression::Value(c)) => self.enforce_constant_lt_check( statements_flattened, e, - c, + c.value, error.into(), ), // c < e <=> p - 1 - e < p - 1 - c - (FlatExpression::Number(c), e) => self.enforce_constant_lt_check( + (FlatExpression::Value(c), e) => self.enforce_constant_lt_check( statements_flattened, - FlatExpression::Sub(box T::max_value().into(), box e), - T::max_value() - c, + FlatExpression::sub(T::max_value().into(), e), + T::max_value() - c.value, error.into(), ), (lhs, rhs) => { let bit_width = T::get_required_bits(); let safe_width = bit_width - 2; // dynamic comparison is not complete let e = self.lt_check(statements_flattened, lhs, rhs, safe_width); - statements_flattened.push_back(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::condition( e, - FlatExpression::Number(T::one()), + FlatExpression::value(T::one()), error.into(), )); } } } - BooleanExpression::FieldLe(box lhs, box rhs) => { - let lhs = self.flatten_field_expression(statements_flattened, lhs); - let rhs = self.flatten_field_expression(statements_flattened, rhs); + BooleanExpression::FieldLe(e) => { + let lhs = self.flatten_field_expression(statements_flattened, *e.left); + let rhs = self.flatten_field_expression(statements_flattened, *e.right); match (lhs, rhs) { - (e, FlatExpression::Number(c)) => self.enforce_constant_le_check( + (e, FlatExpression::Value(c)) => self.enforce_constant_le_check( statements_flattened, e, - c, + c.value, error.into(), ), // c <= e <=> p - 1 - e <= p - 1 - c - (FlatExpression::Number(c), e) => self.enforce_constant_le_check( + (FlatExpression::Value(c), e) => self.enforce_constant_le_check( statements_flattened, - FlatExpression::Sub(box T::max_value().into(), box e), - T::max_value() - c, + FlatExpression::sub(T::max_value().into(), e), + T::max_value() - c.value, error.into(), ), (lhs, rhs) => { let bit_width = T::get_required_bits(); let safe_width = bit_width - 2; // dynamic comparison is not complete let e = self.le_check(statements_flattened, lhs, rhs, safe_width); - statements_flattened.push_back(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::condition( e, - FlatExpression::Number(T::one()), + FlatExpression::value(T::one()), error.into(), )); } } } - BooleanExpression::UintLe(box lhs, box rhs) => { + BooleanExpression::UintLe(e) => { let lhs = self - .flatten_uint_expression(statements_flattened, lhs) + .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let rhs = self - .flatten_uint_expression(statements_flattened, rhs) + .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); match (lhs, rhs) { - (e, FlatExpression::Number(c)) => self.enforce_constant_le_check( + (e, FlatExpression::Value(c)) => self.enforce_constant_le_check( statements_flattened, e, - c, + c.value, error.into(), ), // c <= e <=> p - 1 - e <= p - 1 - c - (FlatExpression::Number(c), e) => self.enforce_constant_le_check( + (FlatExpression::Value(c), e) => self.enforce_constant_le_check( statements_flattened, - FlatExpression::Sub(box T::max_value().into(), box e), - T::max_value() - c, + FlatExpression::sub(T::max_value().into(), e), + T::max_value() - c.value, error.into(), ), (lhs, rhs) => { let bit_width = T::get_required_bits(); let safe_width = bit_width - 2; // dynamic comparison is not complete let e = self.le_check(statements_flattened, lhs, rhs, safe_width); - statements_flattened.push_back(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::condition( e, - FlatExpression::Number(T::one()), + FlatExpression::value(T::one()), error.into(), )); } } } - BooleanExpression::UintEq(box lhs, box rhs) => { + BooleanExpression::UintEq(e) => { let lhs = self - .flatten_uint_expression(statements_flattened, lhs) + .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let rhs = self - .flatten_uint_expression(statements_flattened, rhs) + .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); self.flatten_equality_assertion( @@ -2496,9 +2647,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { error.into(), ) } - BooleanExpression::BoolEq(box lhs, box rhs) => { - let lhs = self.flatten_boolean_expression(statements_flattened, lhs); - let rhs = self.flatten_boolean_expression(statements_flattened, rhs); + BooleanExpression::BoolEq(e) => { + let lhs = self.flatten_boolean_expression(statements_flattened, *e.left); + let rhs = self.flatten_boolean_expression(statements_flattened, *e.right); self.flatten_equality_assertion( statements_flattened, @@ -2508,20 +2659,38 @@ impl<'ast, T: Field> Flattener<'ast, T> { ) } // `!(x == 0)` can be asserted by giving the inverse of `x` - BooleanExpression::Not(box BooleanExpression::UintEq( - box UExpression { - inner: UExpressionInner::Value(0), - .. - }, - box x, - )) - | BooleanExpression::Not(box BooleanExpression::UintEq( - box x, - box UExpression { - inner: UExpressionInner::Value(0), - .. - }, - )) => { + BooleanExpression::Not(UnaryExpression { + inner: + box BooleanExpression::UintEq(BinaryExpression { + left: + box UExpression { + inner: + UExpressionInner::Value(ValueExpression { + value: 0, .. + }), + .. + }, + right: box x, + .. + }), + .. + }) + | BooleanExpression::Not(UnaryExpression { + inner: + box BooleanExpression::UintEq(BinaryExpression { + right: + box UExpression { + inner: + UExpressionInner::Value(ValueExpression { + value: 0, .. + }), + .. + }, + left: box x, + .. + }), + .. + }) => { let x = self .flatten_uint_expression(statements_flattened, x) .get_field_unchecked(); @@ -2537,26 +2706,42 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatDirective::new( vec![invx], Solver::Div, - vec![FlatExpression::Number(T::one()), x_id.into()], + vec![FlatExpression::value(T::one()), x_id.into()], ), )); // assert(invx * x == 1) - statements_flattened.push_back(FlatStatement::Condition( - FlatExpression::Number(T::one()), - FlatExpression::Mult(box invx.into(), box x_id.into()), + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::value(T::one()), + FlatExpression::mul(invx.into(), x_id.into()), RuntimeError::Inverse, )); } // `!(x == 0)` can be asserted by giving the inverse of `x` - BooleanExpression::Not(box BooleanExpression::FieldEq( - box FieldElementExpression::Number(zero), - box x, - )) - | BooleanExpression::Not(box BooleanExpression::FieldEq( - box x, - box FieldElementExpression::Number(zero), - )) if zero == T::from(0) => { + BooleanExpression::Not(UnaryExpression { + inner: + box BooleanExpression::FieldEq( + BinaryExpression { + left: box FieldElementExpression::Value(zero), + right: box x, + .. + }, + .., + ), + .. + }) + | BooleanExpression::Not(UnaryExpression { + inner: + box BooleanExpression::FieldEq( + BinaryExpression { + left: box x, + right: box FieldElementExpression::Value(zero), + .. + }, + .., + ), + .. + }) if zero.value == T::from(0) => { let x = self.flatten_field_expression(statements_flattened, x); // introduce intermediate variable @@ -2570,14 +2755,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatDirective::new( vec![invx], Solver::Div, - vec![FlatExpression::Number(T::one()), x_id.into()], + vec![FlatExpression::value(T::one()), x_id.into()], ), )); // assert(invx * x == 1) - statements_flattened.push_back(FlatStatement::Condition( - FlatExpression::Number(T::one()), - FlatExpression::Mult(box invx.into(), box x_id.into()), + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::value(T::one()), + FlatExpression::mul(invx.into(), x_id.into()), RuntimeError::Inverse, )); } @@ -2586,15 +2771,15 @@ impl<'ast, T: Field> Flattener<'ast, T> { let e = self.flatten_boolean_expression(statements_flattened, e); if e.is_linear() { - statements_flattened.push_back(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::condition( e, - FlatExpression::Number(T::from(1)), + FlatExpression::value(T::from(1)), error.into(), )); } else { // swap so that left side is linear - statements_flattened.push_back(FlatStatement::Condition( - FlatExpression::Number(T::from(1)), + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::value(T::from(1)), e, error.into(), )); @@ -2602,11 +2787,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { } } } - ZirStatement::MultipleDefinition(vars, rhs) => { + ZirStatement::MultipleDefinition(s) => { // flatten the right side to p = sum(var_i.type.primitive_count) expressions // define p new variables to the right side expressions - match rhs { + match s.rhs { ZirExpressionList::EmbedCall(embed, generics, exprs) => { let rhs_flattened = self.flatten_embed_call( statements_flattened, @@ -2617,20 +2802,21 @@ impl<'ast, T: Field> Flattener<'ast, T> { let rhs = rhs_flattened.into_iter(); - assert_eq!(vars.len(), rhs.len()); + assert_eq!(s.assignees.len(), rhs.len()); - let vars: Vec<_> = vars + let assignees: Vec<_> = s + .assignees .into_iter() .zip(rhs) .map(|(v, r)| match r.get_field_unchecked() { FlatExpression::Identifier(id) => { - self.use_variable_with_existing(&v, id); - id + self.use_variable_with_existing(&v, id.id); + id.id } e => { let id = self.use_variable(&v); statements_flattened - .push_back(FlatStatement::Definition(id, e)); + .push_back(FlatStatement::definition(id, e)); id } }) @@ -2648,15 +2834,16 @@ impl<'ast, T: Field> Flattener<'ast, T> { .get_field_unchecked() }) .collect(); - self.bits_cache.insert(vars[0].into(), bits); + self.bits_cache.insert(assignees[0].into(), bits); } _ => {} } } } } - ZirStatement::Log(l, expressions) => { - let expressions = expressions + ZirStatement::Log(s) => { + let expressions = s + .expressions .into_iter() .map(|(t, e)| { ( @@ -2671,9 +2858,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { }) .collect(); - statements_flattened.push_back(FlatStatement::Log(l, expressions)); + statements_flattened.push_back(FlatStatement::Log(LogStatement::new( + s.format_string, + expressions, + ))); } - } + }; + + statements_flattened.set_span(span_backup); } /// Flattens an equality assertion, enforcing it in the circuit. @@ -2691,22 +2883,22 @@ impl<'ast, T: Field> Flattener<'ast, T> { error: RuntimeError, ) { let (lhs, rhs) = match (lhs, rhs) { - (FlatExpression::Mult(box x, box y), z) | (z, FlatExpression::Mult(box x, box y)) => ( + (FlatExpression::Mult(e), z) | (z, FlatExpression::Mult(e)) => ( self.identify_expression(z, statements_flattened), - FlatExpression::Mult( - box self.identify_expression(x, statements_flattened), - box self.identify_expression(y, statements_flattened), + FlatExpression::mul( + self.identify_expression(*e.left, statements_flattened), + self.identify_expression(*e.right, statements_flattened), ), ), (x, z) => ( self.identify_expression(z, statements_flattened), - FlatExpression::Mult( - box self.identify_expression(x, statements_flattened), - box FlatExpression::Number(T::from(1)), + FlatExpression::mul( + self.identify_expression(x, statements_flattened), + FlatExpression::value(T::from(1)), ), ), }; - statements_flattened.push_back(FlatStatement::Condition(lhs, rhs, error)); + statements_flattened.push_back(FlatStatement::condition(lhs, rhs, error)); } /// Identifies a non-linear expression by assigning it to a new identifier. @@ -2724,8 +2916,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { true => e, false => { let sym = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(sym, e)); - FlatExpression::Identifier(sym) + statements_flattened.push_back(FlatStatement::definition(sym, e)); + FlatExpression::identifier(sym) } } } @@ -2759,6 +2951,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { parameter: &ZirParameter<'ast>, statements_flattened: &mut FlatStatements<'ast, T>, ) -> Parameter { + let span = parameter.get_span(); + + let backup_span = statements_flattened.span; + + statements_flattened.set_span(span); + let variable = self.use_variable(¶meter.id); match parameter.id.get_type() { @@ -2766,7 +2964,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // to constrain unsigned integer inputs to be in range, we get their bit decomposition. // it will be cached self.get_bits_unchecked( - &FlatUExpression::with_field(FlatExpression::Identifier(variable)), + &FlatUExpression::with_field(FlatExpression::identifier(variable)), bitwidth.to_usize(), bitwidth.to_usize(), statements_flattened, @@ -2774,19 +2972,18 @@ impl<'ast, T: Field> Flattener<'ast, T> { ); } Type::Boolean => { - statements_flattened.push_back(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::condition( variable.into(), - FlatExpression::Mult(box variable.into(), box variable.into()), + FlatExpression::mul(variable.into(), variable.into()), RuntimeError::ArgumentBitness, )); } Type::FieldElement => {} } - Parameter { - id: variable, - private: parameter.private, - } + statements_flattened.set_span(backup_span); + + Parameter::new(variable, parameter.private).span(span) } fn issue_new_variable(&mut self) -> Variable { @@ -2808,10 +3005,18 @@ mod tests { use zokrates_ast::zir::types::Signature; use zokrates_ast::zir::types::Type; use zokrates_ast::zir::Id; + use zokrates_ast::zir::ZirFunction; use zokrates_field::Bn128Field; fn flatten_function(f: ZirFunction) -> FlatProg { - from_function_and_config(f, CompileConfig::default()).collect() + from_program_and_config( + ZirProgram { + main: f, + module_map: Default::default(), + }, + CompileConfig::default(), + ) + .collect() } #[test] @@ -2831,22 +3036,22 @@ mod tests { let function = ZirFunction:: { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::boolean("x".into()), - BooleanExpression::Value(true).into(), + BooleanExpression::value(true).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::boolean("y".into()), - BooleanExpression::Value(true).into(), + BooleanExpression::value(true).into(), ), - ZirStatement::Assertion( - BooleanExpression::BoolEq( - box BooleanExpression::identifier("x".into()), - box BooleanExpression::identifier("y".into()), + ZirStatement::assertion( + BooleanExpression::bool_eq( + BooleanExpression::identifier("x".into()), + BooleanExpression::identifier("y".into()), ), zir::RuntimeError::mock(), ), - ZirStatement::Return(vec![]), + ZirStatement::ret(vec![]), ], signature: Signature { inputs: vec![], @@ -2856,22 +3061,23 @@ mod tests { let flat = flatten_function(function); let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 0, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(1)), + FlatExpression::value(Bn128Field::from(1)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Number(Bn128Field::from(1)), + FlatExpression::value(Bn128Field::from(1)), ), - FlatStatement::Condition( - FlatExpression::Identifier(Variable::new(1)), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Number(Bn128Field::from(1)), + FlatStatement::condition( + FlatExpression::identifier(Variable::new(1)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::value(Bn128Field::from(1)), ), zir::RuntimeError::mock().into(), ), @@ -2898,25 +3104,25 @@ mod tests { let function = ZirFunction { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("x"), - FieldElementExpression::Number(Bn128Field::from(1)).into(), + FieldElementExpression::value(Bn128Field::from(1)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("y"), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), - ZirStatement::Assertion( - BooleanExpression::FieldEq( - box FieldElementExpression::Add( - box FieldElementExpression::identifier("x".into()), - box FieldElementExpression::Number(Bn128Field::from(1)), + ZirStatement::assertion( + BooleanExpression::field_eq( + FieldElementExpression::add( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::value(Bn128Field::from(1)), ), - box FieldElementExpression::identifier("y".into()), + FieldElementExpression::identifier("y".into()), ), zir::RuntimeError::mock(), ), - ZirStatement::Return(vec![]), + ZirStatement::ret(vec![]), ], signature: Signature { inputs: vec![], @@ -2927,25 +3133,26 @@ mod tests { let flat = flatten_function(function); let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 0, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(1)), + FlatExpression::value(Bn128Field::from(1)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Number(Bn128Field::from(2)), - ), - FlatStatement::Condition( - FlatExpression::Identifier(Variable::new(1)), - FlatExpression::Mult( - box FlatExpression::Add( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Number(Bn128Field::from(1)), + FlatExpression::value(Bn128Field::from(2)), + ), + FlatStatement::condition( + FlatExpression::identifier(Variable::new(1)), + FlatExpression::mul( + FlatExpression::add( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::value(Bn128Field::from(1)), ), - box FlatExpression::Number(Bn128Field::from(1)), + FlatExpression::value(Bn128Field::from(1)), ), zir::RuntimeError::mock().into(), ), @@ -2974,24 +3181,24 @@ mod tests { let function = ZirFunction:: { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::uint("x".into(), 32), ZirExpression::Uint( - UExpressionInner::Value(42) + UExpression::value(42) .annotate(32) .metadata(metadata.clone()), ), ), - ZirStatement::Assertion( - BooleanExpression::UintEq( - box UExpression::identifier("x".into()) + ZirStatement::assertion( + BooleanExpression::uint_eq( + UExpression::identifier("x".into()) .annotate(32) .metadata(metadata.clone()), - box UExpressionInner::Value(42).annotate(32).metadata(metadata), + UExpression::value(42).annotate(32).metadata(metadata), ), zir::RuntimeError::mock(), ), - ZirStatement::Return(vec![]), + ZirStatement::ret(vec![]), ], signature: Signature { inputs: vec![], @@ -3002,18 +3209,19 @@ mod tests { let flat = flatten_function(function); let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 0, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(42)), + FlatExpression::value(Bn128Field::from(42)), ), - FlatStatement::Condition( - FlatExpression::Number(Bn128Field::from(42)), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Number(Bn128Field::from(1)), + FlatStatement::condition( + FlatExpression::value(Bn128Field::from(42)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::value(Bn128Field::from(1)), ), zir::RuntimeError::mock().into(), ), @@ -3040,22 +3248,22 @@ mod tests { let function = ZirFunction { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("x"), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("y"), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), - ZirStatement::Assertion( - BooleanExpression::FieldEq( - box FieldElementExpression::identifier("x".into()), - box FieldElementExpression::identifier("y".into()), + ZirStatement::assertion( + BooleanExpression::field_eq( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::identifier("y".into()), ), zir::RuntimeError::mock(), ), - ZirStatement::Return(vec![]), + ZirStatement::ret(vec![]), ], signature: Signature { inputs: vec![], @@ -3066,22 +3274,23 @@ mod tests { let flat = flatten_function(function); let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 0, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(2)), + FlatExpression::value(Bn128Field::from(2)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Number(Bn128Field::from(2)), + FlatExpression::value(Bn128Field::from(2)), ), - FlatStatement::Condition( - FlatExpression::Identifier(Variable::new(1)), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Number(Bn128Field::from(1)), + FlatStatement::condition( + FlatExpression::identifier(Variable::new(1)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::value(Bn128Field::from(1)), ), zir::RuntimeError::mock().into(), ), @@ -3110,29 +3319,29 @@ mod tests { let function = ZirFunction { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("x"), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("y"), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("z"), - FieldElementExpression::Number(Bn128Field::from(4)).into(), + FieldElementExpression::value(Bn128Field::from(4)).into(), ), - ZirStatement::Assertion( - BooleanExpression::FieldEq( - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("x".into()), - box FieldElementExpression::identifier("y".into()), + ZirStatement::assertion( + BooleanExpression::field_eq( + FieldElementExpression::mul( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::identifier("y".into()), ), - box FieldElementExpression::identifier("z".into()), + FieldElementExpression::identifier("z".into()), ), zir::RuntimeError::mock(), ), - ZirStatement::Return(vec![]), + ZirStatement::ret(vec![]), ], signature: Signature { inputs: vec![], @@ -3143,26 +3352,27 @@ mod tests { let flat = flatten_function(function); let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 0, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(2)), + FlatExpression::value(Bn128Field::from(2)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Number(Bn128Field::from(2)), + FlatExpression::value(Bn128Field::from(2)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(2), - FlatExpression::Number(Bn128Field::from(4)), + FlatExpression::value(Bn128Field::from(4)), ), - FlatStatement::Condition( - FlatExpression::Identifier(Variable::new(2)), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Identifier(Variable::new(1)), + FlatStatement::condition( + FlatExpression::identifier(Variable::new(2)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::identifier(Variable::new(1)), ), zir::RuntimeError::mock().into(), ), @@ -3191,29 +3401,29 @@ mod tests { let function = ZirFunction { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("x"), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("y"), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("z"), - FieldElementExpression::Number(Bn128Field::from(4)).into(), - ), - ZirStatement::Assertion( - BooleanExpression::FieldEq( - box FieldElementExpression::identifier("z".into()), - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("x".into()), - box FieldElementExpression::identifier("y".into()), + FieldElementExpression::value(Bn128Field::from(4)).into(), + ), + ZirStatement::assertion( + BooleanExpression::field_eq( + FieldElementExpression::identifier("z".into()), + FieldElementExpression::mul( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::identifier("y".into()), ), ), zir::RuntimeError::mock(), ), - ZirStatement::Return(vec![]), + ZirStatement::ret(vec![]), ], signature: Signature { inputs: vec![], @@ -3224,26 +3434,27 @@ mod tests { let flat = flatten_function(function); let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 0, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(2)), + FlatExpression::value(Bn128Field::from(2)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Number(Bn128Field::from(2)), + FlatExpression::value(Bn128Field::from(2)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(2), - FlatExpression::Number(Bn128Field::from(4)), + FlatExpression::value(Bn128Field::from(4)), ), - FlatStatement::Condition( - FlatExpression::Identifier(Variable::new(2)), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Identifier(Variable::new(1)), + FlatStatement::condition( + FlatExpression::identifier(Variable::new(2)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::identifier(Variable::new(1)), ), zir::RuntimeError::mock().into(), ), @@ -3275,31 +3486,31 @@ mod tests { let function = ZirFunction { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("x"), - FieldElementExpression::Number(Bn128Field::from(4)).into(), + FieldElementExpression::value(Bn128Field::from(4)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("y"), - FieldElementExpression::Number(Bn128Field::from(4)).into(), + FieldElementExpression::value(Bn128Field::from(4)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("z"), - FieldElementExpression::Number(Bn128Field::from(8)).into(), + FieldElementExpression::value(Bn128Field::from(8)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("t"), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), - ZirStatement::Assertion( - BooleanExpression::FieldEq( - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("x".into()), - box FieldElementExpression::identifier("y".into()), + ZirStatement::assertion( + BooleanExpression::field_eq( + FieldElementExpression::mul( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::identifier("y".into()), ), - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("z".into()), - box FieldElementExpression::identifier("t".into()), + FieldElementExpression::mul( + FieldElementExpression::identifier("z".into()), + FieldElementExpression::identifier("t".into()), ), ), zir::RuntimeError::mock(), @@ -3314,37 +3525,38 @@ mod tests { let flat = flatten_function(function); let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 0, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(4)), + FlatExpression::value(Bn128Field::from(4)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Number(Bn128Field::from(4)), + FlatExpression::value(Bn128Field::from(4)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(2), - FlatExpression::Number(Bn128Field::from(8)), + FlatExpression::value(Bn128Field::from(8)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(3), - FlatExpression::Number(Bn128Field::from(2)), + FlatExpression::value(Bn128Field::from(2)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(4), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(2)), - box FlatExpression::Identifier(Variable::new(3)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(2)), + FlatExpression::identifier(Variable::new(3)), ), ), - FlatStatement::Condition( - FlatExpression::Identifier(Variable::new(4)), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Identifier(Variable::new(1)), + FlatStatement::condition( + FlatExpression::identifier(Variable::new(4)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::identifier(Variable::new(1)), ), zir::RuntimeError::mock().into(), ), @@ -3370,19 +3582,19 @@ mod tests { let function = ZirFunction { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("a"), - FieldElementExpression::Number(Bn128Field::from(7)).into(), + FieldElementExpression::value(Bn128Field::from(7)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("b"), - FieldElementExpression::Pow( - box FieldElementExpression::identifier("a".into()), - box 0u32.into(), + FieldElementExpression::pow( + FieldElementExpression::identifier("a".into()), + 0u32.into(), ) .into(), ), - ZirStatement::Return(vec![FieldElementExpression::identifier("b".into()).into()]), + ZirStatement::ret(vec![FieldElementExpression::identifier("b".into()).into()]), ], signature: Signature { inputs: vec![], @@ -3391,20 +3603,21 @@ mod tests { }; let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 1, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(7)), + FlatExpression::value(Bn128Field::from(7)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Number(Bn128Field::from(1)), + FlatExpression::value(Bn128Field::from(1)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::public(0), - FlatExpression::Identifier(Variable::new(1)), + FlatExpression::identifier(Variable::new(1)), ), ], }; @@ -3431,19 +3644,19 @@ mod tests { let function = ZirFunction { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("a"), - FieldElementExpression::Number(Bn128Field::from(7)).into(), + FieldElementExpression::value(Bn128Field::from(7)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("b"), - FieldElementExpression::Pow( - box FieldElementExpression::identifier("a".into()), - box 1u32.into(), + FieldElementExpression::pow( + FieldElementExpression::identifier("a".into()), + 1u32.into(), ) .into(), ), - ZirStatement::Return(vec![FieldElementExpression::identifier("b".into()).into()]), + ZirStatement::ret(vec![FieldElementExpression::identifier("b".into()).into()]), ], signature: Signature { inputs: vec![], @@ -3452,23 +3665,24 @@ mod tests { }; let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 1, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(7)), + FlatExpression::value(Bn128Field::from(7)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Mult( - box FlatExpression::Number(Bn128Field::from(1)), - box FlatExpression::Identifier(Variable::new(0)), + FlatExpression::mul( + FlatExpression::value(Bn128Field::from(1)), + FlatExpression::identifier(Variable::new(0)), ), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::public(0), - FlatExpression::Identifier(Variable::new(1)), + FlatExpression::identifier(Variable::new(1)), ), ], }; @@ -3512,19 +3726,19 @@ mod tests { let function = ZirFunction { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("a"), - FieldElementExpression::Number(Bn128Field::from(7)).into(), + FieldElementExpression::value(Bn128Field::from(7)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("b"), - FieldElementExpression::Pow( - box FieldElementExpression::identifier("a".into()), - box 13u32.into(), + FieldElementExpression::pow( + FieldElementExpression::identifier("a".into()), + 13u32.into(), ) .into(), ), - ZirStatement::Return(vec![FieldElementExpression::identifier("b".into()).into()]), + ZirStatement::ret(vec![FieldElementExpression::identifier("b".into()).into()]), ], signature: Signature { inputs: vec![], @@ -3533,58 +3747,59 @@ mod tests { }; let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 1, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(7)), + FlatExpression::value(Bn128Field::from(7)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Identifier(Variable::new(0)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::identifier(Variable::new(0)), ), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(2), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(1)), - box FlatExpression::Identifier(Variable::new(1)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(1)), + FlatExpression::identifier(Variable::new(1)), ), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(3), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(2)), - box FlatExpression::Identifier(Variable::new(2)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(2)), + FlatExpression::identifier(Variable::new(2)), ), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(4), - FlatExpression::Mult( - box FlatExpression::Number(Bn128Field::from(1)), - box FlatExpression::Identifier(Variable::new(0)), + FlatExpression::mul( + FlatExpression::value(Bn128Field::from(1)), + FlatExpression::identifier(Variable::new(0)), ), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(5), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(4)), - box FlatExpression::Identifier(Variable::new(2)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(4)), + FlatExpression::identifier(Variable::new(2)), ), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(6), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(5)), - box FlatExpression::Identifier(Variable::new(3)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(5)), + FlatExpression::identifier(Variable::new(3)), ), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::public(0), - FlatExpression::Identifier(Variable::new(6)), + FlatExpression::identifier(Variable::new(6)), ), ], }; @@ -3598,28 +3813,28 @@ mod tests { fn if_else() { let config = CompileConfig::default(); let expression = FieldElementExpression::conditional( - BooleanExpression::FieldEq( - box FieldElementExpression::Number(Bn128Field::from(32)), - box FieldElementExpression::Number(Bn128Field::from(4)), + BooleanExpression::field_eq( + FieldElementExpression::value(Bn128Field::from(32)), + FieldElementExpression::value(Bn128Field::from(4)), ), - FieldElementExpression::Number(Bn128Field::from(12)), - FieldElementExpression::Number(Bn128Field::from(51)), + FieldElementExpression::value(Bn128Field::from(12)), + FieldElementExpression::value(Bn128Field::from(51)), ); let mut flattener = Flattener::new(config); - flattener.flatten_field_expression(&mut FlatStatements::new(), expression); + flattener.flatten_field_expression(&mut FlatStatements::default(), expression); } #[test] fn geq_leq() { let config = CompileConfig::default(); let mut flattener = Flattener::new(config); - let expression_le = BooleanExpression::FieldLe( - box FieldElementExpression::Number(Bn128Field::from(32)), - box FieldElementExpression::Number(Bn128Field::from(4)), + let expression_le = BooleanExpression::field_le( + FieldElementExpression::value(Bn128Field::from(32)), + FieldElementExpression::value(Bn128Field::from(4)), ); - flattener.flatten_boolean_expression(&mut FlatStatements::new(), expression_le); + flattener.flatten_boolean_expression(&mut FlatStatements::default(), expression_le); } #[test] @@ -3628,21 +3843,21 @@ mod tests { let mut flattener = Flattener::new(config); let expression = FieldElementExpression::conditional( - BooleanExpression::And( - box BooleanExpression::FieldEq( - box FieldElementExpression::Number(Bn128Field::from(4)), - box FieldElementExpression::Number(Bn128Field::from(4)), + BooleanExpression::bitand( + BooleanExpression::field_eq( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::value(Bn128Field::from(4)), ), - box BooleanExpression::FieldLt( - box FieldElementExpression::Number(Bn128Field::from(4)), - box FieldElementExpression::Number(Bn128Field::from(20)), + BooleanExpression::field_lt( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::value(Bn128Field::from(20)), ), ), - FieldElementExpression::Number(Bn128Field::from(12)), - FieldElementExpression::Number(Bn128Field::from(51)), + FieldElementExpression::value(Bn128Field::from(12)), + FieldElementExpression::value(Bn128Field::from(51)), ); - flattener.flatten_field_expression(&mut FlatStatements::new(), expression); + flattener.flatten_field_expression(&mut FlatStatements::default(), expression); } #[test] @@ -3650,21 +3865,21 @@ mod tests { // a = 5 / b / b let config = CompileConfig::default(); let mut flattener = Flattener::new(config); - let mut statements_flattened = FlatStatements::new(); + let mut statements_flattened = FlatStatements::default(); - let definition = ZirStatement::Definition( + let definition = ZirStatement::definition( zir::Variable::field_element("b"), - FieldElementExpression::Number(Bn128Field::from(42)).into(), + FieldElementExpression::value(Bn128Field::from(42)).into(), ); - let statement = ZirStatement::Definition( + let statement = ZirStatement::definition( zir::Variable::field_element("a"), - FieldElementExpression::Div( - box FieldElementExpression::Div( - box FieldElementExpression::Number(Bn128Field::from(5)), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::div( + FieldElementExpression::div( + FieldElementExpression::value(Bn128Field::from(5)), + FieldElementExpression::identifier("b".into()), ), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::identifier("b".into()), ) .into(), ); @@ -3688,35 +3903,35 @@ mod tests { let sym_2 = Variable::new(6); assert_eq!( - statements_flattened, + statements_flattened.buffer, vec![ - FlatStatement::Definition(b, FlatExpression::Number(Bn128Field::from(42))), + FlatStatement::definition(b, FlatExpression::value(Bn128Field::from(42))), // inputs to first div (5/b) - FlatStatement::Definition(five, FlatExpression::Number(Bn128Field::from(5))), - FlatStatement::Definition(b0, b.into()), + FlatStatement::definition(five, FlatExpression::value(Bn128Field::from(5))), + FlatStatement::definition(b0, b.into()), // execute div FlatStatement::Directive(FlatDirective::new( vec![sym_0], Solver::Div, - vec![five, b0] + vec![five.into(), b0.into()] )), - FlatStatement::Condition( + FlatStatement::condition( five.into(), - FlatExpression::Mult(box b0.into(), box sym_0.into()), + FlatExpression::mul(b0.into(), sym_0.into()), RuntimeError::Division ), // inputs to second div (res/b) - FlatStatement::Definition(sym_1, sym_0.into()), - FlatStatement::Definition(b1, b.into()), + FlatStatement::definition(sym_1, sym_0.into()), + FlatStatement::definition(b1, b.into()), // execute div FlatStatement::Directive(FlatDirective::new( vec![sym_2], Solver::Div, - vec![sym_1, b1] + vec![sym_1.into(), b1.into()] )), - FlatStatement::Condition( + FlatStatement::condition( sym_1.into(), - FlatExpression::Mult(box b1.into(), box sym_2.into()), + FlatExpression::mul(b1.into(), sym_2.into()), RuntimeError::Division ), ] diff --git a/zokrates_codegen/src/utils.rs b/zokrates_codegen/src/utils.rs index 6fc257921..ab53c4464 100644 --- a/zokrates_codegen/src/utils.rs +++ b/zokrates_codegen/src/utils.rs @@ -1,3 +1,4 @@ +use std::ops::*; use zokrates_ast::flat::*; use zokrates_field::Field; @@ -6,16 +7,16 @@ pub fn flat_expression_from_bits(v: Vec>) -> FlatExp v: Vec<(T, FlatExpression)>, ) -> FlatExpression { match v.len() { - 0 => FlatExpression::Number(T::zero()), + 0 => FlatExpression::value(T::zero()), 1 => { let (coeff, var) = v[0].clone(); - FlatExpression::Mult(box FlatExpression::Number(coeff), box var) + FlatExpression::mul(FlatExpression::value(coeff), var) } n => { let (u, v) = v.split_at(n / 2); - FlatExpression::Add( - box flat_expression_from_bits_aux(u.to_vec()), - box flat_expression_from_bits_aux(v.to_vec()), + FlatExpression::add( + flat_expression_from_bits_aux(u.to_vec()), + flat_expression_from_bits_aux(v.to_vec()), ) } } diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index 0c7583fa2..d1a57705b 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -18,7 +18,7 @@ use zokrates_ast::ir::{self, from_flat::from_flat}; use zokrates_ast::typed::abi::Abi; use zokrates_ast::untyped::{Module, OwnedModuleId, Program}; use zokrates_ast::zir::ZirProgram; -use zokrates_codegen::from_function_and_config; +use zokrates_codegen::from_program_and_config; use zokrates_common::{CompileConfig, Resolver}; use zokrates_field::Field; use zokrates_pest_ast as pest; @@ -155,15 +155,15 @@ impl fmt::Display for CompileErrorInner { CompileErrorInner::SemanticError(ref e) => { let location = e .pos() - .map(|p| format!("{}", p.0)) + .map(|p| format!("{}", p.from)) .unwrap_or_else(|| "".to_string()); write!(f, "{}\n\t{}", location, e.message()) } CompileErrorInner::ReadError(ref e) => write!(f, "\n\t{}", e), CompileErrorInner::ImportError(ref e) => { let location = e - .pos() - .map(|p| format!("{}", p.0)) + .span() + .map(|p| format!("{}", p.from)) .unwrap_or_else(|| "".to_string()); write!(f, "{}\n\t{}", location, e.message()) } @@ -189,7 +189,7 @@ pub fn compile<'ast, T: Field, E: Into>( // flatten input program log::debug!("Flatten"); - let program_flattened = from_function_and_config(typed_ast.main, config); + let program_flattened = from_program_and_config(typed_ast, config); // convert to ir log::debug!("Convert to IR"); diff --git a/zokrates_core/src/imports.rs b/zokrates_core/src/imports.rs index f7fa95f13..3bc08c5b8 100644 --- a/zokrates_core/src/imports.rs +++ b/zokrates_core/src/imports.rs @@ -13,43 +13,43 @@ use std::path::{Path, PathBuf}; use zokrates_ast::untyped::*; use typed_arena::Arena; -use zokrates_ast::common::FlatEmbed; +use zokrates_ast::common::{FlatEmbed, SourceSpan}; use zokrates_ast::untyped::types::UnresolvedType; use zokrates_common::Resolver; use zokrates_field::Field; #[derive(PartialEq, Eq, Debug)] pub struct Error { - pos: Option<(Position, Position)>, + span: Option, message: String, } impl Error { pub fn new>(message: T) -> Error { Error { - pos: None, + span: None, message: message.into(), } } - pub fn pos(&self) -> &Option<(Position, Position)> { - &self.pos + pub fn span(&self) -> &Option { + &self.span } pub fn message(&self) -> &str { &self.message } - fn with_pos(self, pos: Option<(Position, Position)>) -> Error { - Error { pos, ..self } + fn with_span(self, span: Option) -> Error { + Error { span, ..self } } } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let location = self - .pos - .map(|p| format!("{}", p.0)) + .span + .map(|p| format!("{}", p.from)) .unwrap_or_else(|| "?".to_string()); write!(f, "{}\n\t{}", location, self.message) } @@ -58,7 +58,7 @@ impl fmt::Display for Error { impl From for Error { fn from(error: io::Error) -> Self { Error { - pos: None, + span: None, message: format!("I/O Error: {}", error), } } @@ -95,7 +95,7 @@ impl Importer { modules: &mut HashMap>, arena: &'ast Arena, ) -> Result, CompileErrors> { - let pos = import.pos(); + let span = import.span().in_module(location); let module_id = import.value.source; let symbol = import.value.id; @@ -111,7 +111,7 @@ impl Importer { Bn128Field::name(), T::name() )) - .with_pos(Some(pos)), + .with_span(Some(span)), ) .in_file(location) .into()); @@ -132,7 +132,7 @@ impl Importer { Bw6_761Field::name(), T::name() )) - .with_pos(Some(pos)), + .with_span(Some(span)), ) .in_file(location) .into()); @@ -195,12 +195,12 @@ impl Importer { expression: Expression::U32Constant(T::get_required_bits() as u32) .into(), } - .start_end(pos.0, pos.1), + .start_end(span.from, span.to), )), }, s => { return Err(CompileErrorInner::ImportError( - Error::new(format!("Embed {} not found", s)).with_pos(Some(pos)), + Error::new(format!("Embed {} not found", s)).with_span(Some(span)), ) .in_file(location) .into()); @@ -245,16 +245,16 @@ impl Importer { id: alias, symbol: Symbol::There( SymbolImport::with_id_in_module(symbol.id, new_location) - .start_end(pos.0, pos.1), + .start_end(span.from, span.to), ), } } Err(err) => { - return Err( - CompileErrorInner::ImportError(err.into().with_pos(Some(pos))) - .in_file(location) - .into(), - ); + return Err(CompileErrorInner::ImportError( + err.into().with_span(Some(span)), + ) + .in_file(location) + .into()); } }, None => { @@ -267,6 +267,6 @@ impl Importer { }, }; - Ok(symbol_declaration.start_end(pos.0, pos.1)) + Ok(symbol_declaration.start_end(span.from, span.to)) } } diff --git a/zokrates_core/src/optimizer/directive.rs b/zokrates_core/src/optimizer/directive.rs index 4d140637a..dd8159013 100644 --- a/zokrates_core/src/optimizer/directive.rs +++ b/zokrates_core/src/optimizer/directive.rs @@ -28,24 +28,34 @@ impl<'ast, T: Field> Folder<'ast, T> for DirectiveOptimizer<'ast, T> { *self.substitution.get(&v).unwrap_or(&v) } - fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { - match s { - Statement::Directive(d) => { - let d = self.fold_directive(d); + fn fold_directive_statement( + &mut self, + d: DirectiveStatement<'ast, T>, + ) -> Vec> { + let d = DirectiveStatement { + inputs: d + .inputs + .into_iter() + .map(|e| self.fold_quadratic_combination(e)) + .collect(), + outputs: d + .outputs + .into_iter() + .map(|o| self.fold_variable(o)) + .collect(), + ..d + }; - match self.calls.entry((d.solver.clone(), d.inputs.clone())) { - Entry::Vacant(e) => { - e.insert(d.outputs.clone()); - vec![Statement::Directive(d)] - } - Entry::Occupied(e) => { - self.substitution - .extend(d.outputs.into_iter().zip(e.get().iter().cloned())); - vec![] - } - } + match self.calls.entry((d.solver.clone(), d.inputs.clone())) { + Entry::Vacant(e) => { + e.insert(d.outputs.clone()); + vec![Statement::Directive(d)] + } + Entry::Occupied(e) => { + self.substitution + .extend(d.outputs.into_iter().zip(e.get().iter().cloned())); + vec![] } - s => fold_statement(self, s), } } } diff --git a/zokrates_core/src/optimizer/duplicate.rs b/zokrates_core/src/optimizer/duplicate.rs index 346695a27..9f12e1e04 100644 --- a/zokrates_core/src/optimizer/duplicate.rs +++ b/zokrates_core/src/optimizer/duplicate.rs @@ -40,7 +40,11 @@ impl<'ast, T: Field> Folder<'ast, T> for DuplicateOptimizer { fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { match s { - Statement::Block(s) => s.into_iter().flat_map(|s| self.fold_statement(s)).collect(), + Statement::Block(s) => s + .inner + .into_iter() + .flat_map(|s| self.fold_statement(s)) + .collect(), s => { let hashed = hash(&s); let result = match self.seen.get(&hashed) { @@ -63,20 +67,23 @@ mod tests { #[test] fn identity() { let p: Prog = Prog { + module_map: Default::default(), statements: vec![ Statement::constraint( - QuadComb::from_linear_combinations( + QuadComb::new( LinComb::summand(3, Variable::new(3)), LinComb::summand(3, Variable::new(3)), ), LinComb::one(), + None, ), Statement::constraint( - QuadComb::from_linear_combinations( + QuadComb::new( LinComb::summand(3, Variable::new(42)), LinComb::summand(3, Variable::new(3)), ), LinComb::zero(), + None, ), ], return_count: 0, @@ -95,23 +102,26 @@ mod tests { #[test] fn remove_duplicates() { let constraint = Statement::constraint( - QuadComb::from_linear_combinations( + QuadComb::new( LinComb::summand(3, Variable::new(3)), LinComb::summand(3, Variable::new(3)), ), LinComb::one(), + None, ); let p: Prog = Prog { + module_map: Default::default(), statements: vec![ constraint.clone(), constraint.clone(), Statement::constraint( - QuadComb::from_linear_combinations( + QuadComb::new( LinComb::summand(3, Variable::new(42)), LinComb::summand(3, Variable::new(3)), ), LinComb::zero(), + None, ), constraint.clone(), constraint.clone(), @@ -122,14 +132,16 @@ mod tests { }; let expected = Prog { + module_map: Default::default(), statements: vec![ constraint, Statement::constraint( - QuadComb::from_linear_combinations( + QuadComb::new( LinComb::summand(3, Variable::new(42)), LinComb::summand(3, Variable::new(3)), ), LinComb::zero(), + None, ), ], return_count: 0, diff --git a/zokrates_core/src/optimizer/mod.rs b/zokrates_core/src/optimizer/mod.rs index 95e9f25cb..08a6c1b7d 100644 --- a/zokrates_core/src/optimizer/mod.rs +++ b/zokrates_core/src/optimizer/mod.rs @@ -54,6 +54,7 @@ pub fn optimize<'ast, T: Field, I: IntoIterator>>( .flat_map(move |s| directive_optimizer.fold_statement(s)) .flat_map(move |s| duplicate_optimizer.fold_statement(s)), return_count: p.return_count, + module_map: p.module_map, solvers: p.solvers, }; diff --git a/zokrates_core/src/optimizer/redefinition.rs b/zokrates_core/src/optimizer/redefinition.rs index c1f90aa5d..770e9c025 100644 --- a/zokrates_core/src/optimizer/redefinition.rs +++ b/zokrates_core/src/optimizer/redefinition.rs @@ -37,8 +37,9 @@ // - otherwise return `c_0` use std::collections::{HashMap, HashSet}; +use zokrates_ast::common::WithSpan; use zokrates_ast::flat::Variable; -use zokrates_ast::ir::folder::{fold_statement, Folder}; +use zokrates_ast::ir::folder::{fold_statement_cases, Folder}; use zokrates_ast::ir::LinComb; use zokrates_ast::ir::*; use zokrates_field::Field; @@ -67,24 +68,24 @@ impl RedefinitionOptimizer { } } - fn fold_statement<'ast>( + fn fold_statement_cases<'ast>( &mut self, s: Statement<'ast, T>, aggressive: bool, ) -> Vec> { match s { - Statement::Constraint(quad, lin, message) => { - let quad = self.fold_quadratic_combination(quad); - let lin = self.fold_linear_combination(lin); + Statement::Constraint(s) => { + let quad = self.fold_quadratic_combination(s.quad); + let lin = self.fold_linear_combination(s.lin); if lin.is_zero() { - return vec![Statement::Constraint(quad, lin, message)]; + return vec![Statement::constraint(quad, lin, s.error)]; } - let (constraint, to_insert, to_ignore) = match self.ignore.contains(&lin.0[0].0) - || self.substitution.contains_key(&lin.0[0].0) + let (constraint, to_insert, to_ignore) = match self.ignore.contains(&lin.value[0].0) + || self.substitution.contains_key(&lin.value[0].0) { - true => (Some(Statement::Constraint(quad, lin, message)), None, None), + true => (Some(Statement::constraint(quad, lin, s.error)), None, None), false => match lin.try_summand() { // if the right side is a single variable Ok((variable, coefficient)) => match quad.try_linear() { @@ -92,16 +93,16 @@ impl RedefinitionOptimizer { Ok(l) => (None, Some((variable, l / &coefficient)), None), // if the left side isn't linear Err(quad) => ( - Some(Statement::Constraint( + Some(Statement::constraint( quad, LinComb::summand(coefficient, variable), - message, + s.error, )), None, Some(variable), ), }, - Err(l) => (Some(Statement::Constraint(quad, l, message)), None, None), + Err(l) => (Some(Statement::constraint(quad, l, s.error)), None, None), }, }; @@ -122,7 +123,19 @@ impl RedefinitionOptimizer { } } Statement::Directive(d) => { - let d = self.fold_directive(d); + let d = DirectiveStatement { + inputs: d + .inputs + .into_iter() + .map(|e| self.fold_quadratic_combination(e)) + .collect(), + outputs: d + .outputs + .into_iter() + .map(|o| self.fold_variable(o)) + .collect(), + ..d + }; // check if the inputs are constants, ie reduce to the form `coeff * ~one` let inputs: Vec<_> = d @@ -171,24 +184,30 @@ impl RedefinitionOptimizer { self.ignore.insert(o); } } - vec![Statement::Directive(Directive { inputs, ..d })] + vec![Statement::directive(d.outputs, d.solver, inputs)] } } } - s => fold_statement(self, s), + s => fold_statement_cases(self, s), } } } impl<'ast, T: Field> Folder<'ast, T> for RedefinitionOptimizer { - fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { + fn fold_statement_cases(&mut self, s: Statement<'ast, T>) -> Vec> { match s { Statement::Block(statements) => { #[allow(clippy::needless_collect)] // optimize aggressively and clean up in a second pass (we need to collect here) let statements: Vec<_> = statements + .inner .into_iter() - .flat_map(|s| self.fold_statement(s, true)) + .flat_map(|s| { + let span = s.get_span(); + self.fold_statement_cases(s, true) + .into_iter() + .map(move |s| s.span(span)) + }) .collect(); // clean up @@ -203,22 +222,23 @@ impl<'ast, T: Field> Folder<'ast, T> for RedefinitionOptimizer { }) .collect(); - vec![Statement::Block(statements)] + vec![Statement::block(statements)] } - s => self.fold_statement(s, false), + s => self.fold_statement_cases(s, false), } } fn fold_linear_combination(&mut self, lc: LinComb) -> LinComb { match lc - .0 + .value .iter() .any(|(variable, _)| self.substitution.get(variable).is_some()) { true => // for each summand, check if it is equal to a linear term in our substitution, otherwise keep it as is { - lc.0.into_iter() + lc.value + .into_iter() .map(|(variable, coefficient)| { self.substitution .get(&variable) @@ -250,6 +270,7 @@ mod tests { let out = Variable::public(0); let p: Prog = Prog { + module_map: Default::default(), arguments: vec![x], statements: vec![ Statement::definition(y, x.id), @@ -260,6 +281,7 @@ mod tests { }; let optimized: Prog = Prog { + module_map: Default::default(), arguments: vec![x], statements: vec![Statement::definition(out, x.id)], return_count: 1, @@ -279,6 +301,7 @@ mod tests { let x = Parameter::public(Variable::new(0)); let p: Prog = Prog { + module_map: Default::default(), arguments: vec![x], statements: vec![Statement::definition(one, x.id)], return_count: 1, @@ -311,11 +334,12 @@ mod tests { let out = Variable::public(0); let p: Prog = Prog { + module_map: Default::default(), arguments: vec![x], statements: vec![ Statement::definition(y, x.id), Statement::definition(z, y), - Statement::constraint(z, y), + Statement::constraint(z, y, None), Statement::definition(out, z), ], return_count: 1, @@ -323,9 +347,10 @@ mod tests { }; let optimized: Prog = Prog { + module_map: Default::default(), arguments: vec![x], statements: vec![ - Statement::constraint(x.id, x.id), + Statement::constraint(x.id, x.id, None), Statement::definition(out, x.id), ], return_count: 1, @@ -360,6 +385,7 @@ mod tests { let out_0 = Variable::public(1); let p: Prog = Prog { + module_map: Default::default(), arguments: vec![x], statements: vec![ Statement::definition(y, x.id), @@ -374,6 +400,7 @@ mod tests { }; let optimized: Prog = Prog { + module_map: Default::default(), arguments: vec![x], statements: vec![ Statement::definition(out_0, x.id), @@ -411,6 +438,7 @@ mod tests { let r = Variable::public(0); let p: Prog = Prog { + module_map: Default::default(), arguments: vec![x, y], statements: vec![ Statement::definition(a, LinComb::from(x.id) + LinComb::from(y.id)), @@ -425,6 +453,7 @@ mod tests { Statement::constraint( LinComb::summand(2, c), LinComb::summand(6, x.id) + LinComb::summand(6, y.id), + None, ), Statement::definition(r, LinComb::from(a) + LinComb::from(b) + LinComb::from(c)), ], @@ -433,11 +462,13 @@ mod tests { }; let expected: Prog = Prog { + module_map: Default::default(), arguments: vec![x, y], statements: vec![ Statement::constraint( LinComb::summand(6, x.id) + LinComb::summand(6, y.id), LinComb::summand(6, x.id) + LinComb::summand(6, y.id), + None, ), Statement::definition( r, @@ -479,12 +510,10 @@ mod tests { let z = Variable::new(2); let p: Prog = Prog { + module_map: Default::default(), arguments: vec![x, y], statements: vec![ - Statement::definition( - z, - QuadComb::from_linear_combinations(LinComb::from(x.id), LinComb::from(y.id)), - ), + Statement::definition(z, QuadComb::new(LinComb::from(x.id), LinComb::from(y.id))), Statement::definition(z, LinComb::from(x.id)), ], return_count: 0, @@ -511,10 +540,11 @@ mod tests { let x = Parameter::public(Variable::new(0)); let p: Prog = Prog { + module_map: Default::default(), arguments: vec![x], statements: vec![ - Statement::constraint(x.id, Bn128Field::from(1)), - Statement::constraint(x.id, Bn128Field::from(2)), + Statement::constraint(x.id, Bn128Field::from(1), None), + Statement::constraint(x.id, Bn128Field::from(2), None), ], return_count: 1, solvers: vec![], diff --git a/zokrates_core/src/optimizer/tautology.rs b/zokrates_core/src/optimizer/tautology.rs index 855efa11d..d33ddeeda 100644 --- a/zokrates_core/src/optimizer/tautology.rs +++ b/zokrates_core/src/optimizer/tautology.rs @@ -5,7 +5,6 @@ // // This makes the assumption that ~one has value 1, as should be guaranteed by the verifier -use zokrates_ast::ir::folder::fold_statement; use zokrates_ast::ir::folder::Folder; use zokrates_ast::ir::*; use zokrates_field::Field; @@ -14,19 +13,16 @@ use zokrates_field::Field; pub struct TautologyOptimizer; impl<'ast, T: Field> Folder<'ast, T> for TautologyOptimizer { - fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { - match s { - Statement::Constraint(quad, lin, message) => match quad.try_linear() { - Ok(l) => { - if l == lin { - vec![] - } else { - vec![Statement::Constraint(l.into(), lin, message)] - } + fn fold_constraint_statement(&mut self, s: ConstraintStatement) -> Vec> { + match s.quad.try_linear() { + Ok(l) => { + if l == s.lin { + vec![] + } else { + vec![Statement::constraint(l, s.lin, s.error)] } - Err(quad) => vec![Statement::Constraint(quad, lin, message)], - }, - _ => fold_statement(self, s), + } + Err(quad) => vec![Statement::constraint(quad, s.lin, s.error)], } } } diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 4e006eed3..a0d7c5081 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -8,7 +8,8 @@ use num_bigint::BigUint; use std::collections::{btree_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet}; use std::fmt; use std::path::PathBuf; -use zokrates_ast::common::{FormatString, SourceMetadata}; +use zokrates_ast::common::expressions::ValueExpression; +use zokrates_ast::common::{FormatString, ModuleMap, SourceMetadata, SourceSpan, WithSpan}; use zokrates_ast::typed::types::{GGenericsAssignment, GTupleType, GenericsAssignment}; use zokrates_ast::typed::SourceIdentifier; use zokrates_ast::typed::*; @@ -17,6 +18,8 @@ use zokrates_ast::untyped::Identifier; use zokrates_ast::untyped::*; use zokrates_field::Field; +use std::ops::*; + use zokrates_ast::untyped::types::{UnresolvedSignature, UnresolvedType, UserTypeId}; use std::hash::Hash; @@ -29,7 +32,7 @@ use zokrates_ast::typed::types::{ #[derive(PartialEq, Eq, Debug)] pub struct ErrorInner { - pos: Option<(Position, Position)>, + span: Option, message: String, } @@ -40,8 +43,8 @@ pub struct Error { } impl ErrorInner { - pub fn pos(&self) -> &Option<(Position, Position)> { - &self.pos + pub fn pos(&self) -> &Option { + &self.span } pub fn message(&self) -> &str { @@ -369,6 +372,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(TypedProgram { main: program.main, + module_map: ModuleMap::new(state.typed_modules.iter().map(|(id, _)| id).cloned()), modules: state.typed_modules, }) } @@ -379,7 +383,7 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, state: &State<'ast, T>, ) -> Result, Vec> { - let pos = ty.pos(); + let span = ty.span().in_module(module_id); let ty = ty.value; let mut errors = vec![]; @@ -395,7 +399,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .is_some() { errors.push(ErrorInner { - pos: Some(g.pos()), + span: Some(g.span().in_module(module_id)), message: format!( "Generic parameter {p} conflicts with constant symbol {p}", p = g.value @@ -410,7 +414,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } false => { errors.push(ErrorInner { - pos: Some(g.pos()), + span: Some(g.span().in_module(module_id)), message: format!("Generic parameter {} is already declared", g.value), }); } @@ -432,7 +436,7 @@ impl<'ast, T: Field> Checker<'ast, T> { for declared_generic in generics_map.keys() { if !used_generics.contains(declared_generic) { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Generic parameter {} must be used", declared_generic), }); } @@ -458,7 +462,7 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, state: &State<'ast, T>, ) -> Result, ErrorInner> { - let pos = c.pos(); + let span = c.span().in_module(module_id); let ty = self.check_declaration_type( c.value.ty.clone(), @@ -492,7 +496,7 @@ impl<'ast, T: Field> Checker<'ast, T> { DeclarationType::Int => Err(checked_expr), // Integers cannot be assigned } .map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expression `{}` of type `{}` cannot be assigned to constant `{}` of type `{}`", e, @@ -511,7 +515,7 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, state: &State<'ast, T>, ) -> Result, Vec> { - let pos = s.pos(); + let span = s.span().in_module(module_id); let s = s.value; let mut errors = vec![]; @@ -529,7 +533,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .is_some() { errors.push(ErrorInner { - pos: Some(g.pos()), + span: Some(g.span().in_module(module_id)), message: format!( "Generic parameter {p} conflicts with constant symbol {p}", p = g.value @@ -544,7 +548,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } false => { errors.push(ErrorInner { - pos: Some(g.pos()), + span: Some(g.span().in_module(module_id)), message: format!("Generic parameter {} is already declared", g.value), }); } @@ -569,7 +573,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(f) => match fields_set.insert(f.0.clone()) { true => fields.push(f), false => errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Duplicate key {} in struct definition", f.0,), }), }, @@ -583,7 +587,7 @@ impl<'ast, T: Field> Checker<'ast, T> { for declared_generic in generics_map.keys() { if !used_generics.contains(declared_generic) { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Generic parameter {} must be used", declared_generic), }); } @@ -614,7 +618,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ) -> Result<(), Vec> { let mut errors: Vec = vec![]; - let pos = declaration.pos(); + let span = declaration.span().in_module(module_id); let declaration = declaration.value; match declaration.symbol.clone() { @@ -629,7 +633,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match symbol_unifier.insert_type(declaration.id) { false => errors.push( ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} conflicts with another symbol", declaration.id @@ -671,7 +675,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match symbol_unifier.insert_constant(declaration.id) { false => errors.push( ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} conflicts with another symbol", declaration.id @@ -722,7 +726,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match symbol_unifier.insert_type(declaration.id) { false => errors.push( ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} conflicts with another symbol", declaration.id, @@ -753,7 +757,7 @@ impl<'ast, T: Field> Checker<'ast, T> { { false => errors.push( ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} conflicts with another symbol", declaration.id @@ -789,7 +793,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Symbol::There(import) => { - let pos = import.pos(); + let span = import.span().in_module(module_id); let import = import.value; match Checker::default().check_module(&import.module_id, state) { @@ -850,7 +854,7 @@ impl<'ast, T: Field> Checker<'ast, T> { errors.push(Error { module_id: module_id.to_path_buf(), inner: ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} conflicts with another symbol", declaration.id, @@ -871,7 +875,7 @@ impl<'ast, T: Field> Checker<'ast, T> { errors.push(Error { module_id: module_id.to_path_buf(), inner: ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} conflicts with another symbol", declaration.id, @@ -909,7 +913,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } (0, None, None) => { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Could not find symbol {} in module {}", import.symbol_id, import.module_id.display(), @@ -923,7 +927,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match symbol_unifier.insert_function(declaration.id, candidate.signature.clone()) { false => { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} conflicts with another symbol", declaration.id, @@ -957,7 +961,7 @@ impl<'ast, T: Field> Checker<'ast, T> { false => { errors.push( ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} conflicts with another symbol", declaration.id @@ -1054,11 +1058,11 @@ impl<'ast, T: Field> Checker<'ast, T> { { 1 => Ok(()), 0 => Err(ErrorInner { - pos: None, + span: None, message: "No main function found".into(), }), n => Err(ErrorInner { - pos: None, + span: None, message: format!("Only one main function allowed, found {}", n), }), } @@ -1070,14 +1074,14 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, types: &TypeMap<'ast, T>, ) -> Result, Vec> { - let pos = var.pos(); + let span = var.span().in_module(module_id); let var = self.check_variable(var, module_id, types)?; match var.get_type() { Type::Uint(UBitwidth::B32) => Ok(()), t => Err(vec![ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Variable in for loop cannot have type {}", t), }]), }?; @@ -1106,7 +1110,7 @@ impl<'ast, T: Field> Checker<'ast, T> { self.enter_scope(); - let pos = funct_node.pos(); + let span = funct_node.span().in_module(module_id); let mut errors = vec![]; let funct = funct_node.value; @@ -1142,14 +1146,14 @@ impl<'ast, T: Field> Checker<'ast, T> { } for (arg, decl_ty) in funct.arguments.into_iter().zip(s.inputs.iter()) { - let pos = arg.pos(); + let span = arg.span().in_module(module_id); let arg = arg.value; // parameters defined on a non-entrypoint function should not have visibility modifiers if (state.main_id != module_id || id != "main") && arg.is_private.is_some() { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: "Visibility modifiers on arguments are only allowed on the entrypoint function" .into(), @@ -1159,12 +1163,11 @@ impl<'ast, T: Field> Checker<'ast, T> { let decl_v = DeclarationVariable::new( self.id_in_this_scope(arg.id.value.id), decl_ty.clone(), - arg.id.value.is_mutable, ); let is_mutable = arg.id.value.is_mutable; - let ty = specialize_declaration_type(decl_v.clone()._type, &generics).unwrap(); + let ty = specialize_declaration_type(decl_v.clone().ty, &generics).unwrap(); assert_eq!(self.scope.level, 1); @@ -1178,27 +1181,27 @@ impl<'ast, T: Field> Checker<'ast, T> { false => {} true => { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Duplicate name in function definition: `{}` was previously declared as an argument, a generic parameter or a constant", arg.id.value.id) }); } }; - arguments_checked.push(DeclarationParameter { - id: decl_v, - private: arg.is_private.unwrap_or(false), - }); + arguments_checked.push( + DeclarationParameter::new(decl_v, arg.is_private.unwrap_or(false)) + .with_span(span), + ); } let mut found_return = false; for stat in funct.statements.into_iter() { - let pos = Some(stat.pos()); + let span = Some(stat.span().in_module(module_id)); if let Statement::Return(..) = stat.value { if found_return { errors.push(ErrorInner { - pos, + span, message: "Expected a single return statement".to_string(), }); } @@ -1218,11 +1221,12 @@ impl<'ast, T: Field> Checker<'ast, T> { if !found_return { match (&*s.output).is_empty_tuple() { - true => statements_checked - .push(TypedStatement::Return(TypedExpression::empty_tuple())), + true => statements_checked.push( + TypedStatement::ret(TypedExpression::empty_tuple()).with_span(span), + ), false => { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: "Expected a return statement".to_string(), }); } @@ -1271,7 +1275,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .is_some() { errors.push(ErrorInner { - pos: Some(g.pos()), + span: Some(g.span().in_module(module_id)), message: format!( "Generic parameter {p} conflicts with constant symbol {p}", p = g.value @@ -1286,7 +1290,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } false => { errors.push(ErrorInner { - pos: Some(g.pos()), + span: Some(g.span().in_module(module_id)), message: format!("Generic parameter {} is already declared", g.value), }); } @@ -1348,7 +1352,7 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, types: &TypeMap<'ast, T>, ) -> Result, ErrorInner> { - let pos = ty.pos(); + let span = ty.span().in_module(module_id); let ty = ty.value; match ty { @@ -1364,7 +1368,7 @@ impl<'ast, T: Field> Checker<'ast, T> { TypedExpression::Uint(e) => match e.bitwidth() { UBitwidth::B32 => Ok(e), _ => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected array dimension to be a u32 constant, found {} of type {}", e, ty @@ -1374,7 +1378,7 @@ impl<'ast, T: Field> Checker<'ast, T> { TypedExpression::Int(v) => { UExpression::try_from_int(v.clone(), &UBitwidth::B32).map_err(|_| { ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected array dimension to be a u32 constant, found {} of type {}", v, ty @@ -1383,7 +1387,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }) } _ => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected array dimension to be a u32 constant, found {} of type {}", size, ty @@ -1411,7 +1415,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .get(&id) .cloned() .ok_or_else(|| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Undefined type {}", id), })?; @@ -1438,13 +1442,13 @@ impl<'ast, T: Field> Checker<'ast, T> { UExpression::try_from_typed(e, &UBitwidth::B32) .map(|e| (GenericIdentifier::with_name(g).with_index(i), e)) .map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Expected u32 expression, but got expression of type {}", e.get_type()), }) }) }, None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: "Expected u32 constant or identifier, but found `_`. Generic inference is not supported yet." .into(), @@ -1456,7 +1460,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(specialize_declaration_type(declaration_type, &assignment).unwrap()) } false => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected {} generic argument{} on type {}, but got {}", generic_identifiers.len(), @@ -1481,7 +1485,7 @@ impl<'ast, T: Field> Checker<'ast, T> { generics_map: &BTreeMap, usize>, used_generics: &mut HashSet>, ) -> Result, ErrorInner> { - let pos = expr.pos(); + let span = expr.span().in_module(module_id); match expr.value { Expression::U32Constant(c) => Ok(DeclarationConstant::from(c)), @@ -1492,7 +1496,7 @@ impl<'ast, T: Field> Checker<'ast, T> { )) } else { Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected array dimension to be a u32 constant or an identifier, found {}", Expression::IntConstant(c) @@ -1508,7 +1512,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match ty { DeclarationType::Uint(UBitwidth::B32) => Ok(DeclarationConstant::Constant(CanonicalConstantIdentifier::new(name, module_id.into()))), _ => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected array dimension to be a u32 constant or an identifier, found {} of type {}", name, ty @@ -1518,13 +1522,13 @@ impl<'ast, T: Field> Checker<'ast, T> { } (None, Some(index)) => Ok(DeclarationConstant::Generic(GenericIdentifier::with_name(name).with_index(*index))), _ => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Undeclared symbol `{}`", name) }) } } e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected array dimension to be a u32 constant or an identifier, found {}", e @@ -1541,7 +1545,7 @@ impl<'ast, T: Field> Checker<'ast, T> { generics_map: &BTreeMap, usize>, used_generics: &mut HashSet>, ) -> Result, ErrorInner> { - let pos = ty.pos(); + let span = ty.span().in_module(module_id); let ty = ty.value; match ty { @@ -1587,7 +1591,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .get(&id) .cloned() .ok_or_else(|| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Undefined type {}", id), })?; @@ -1605,7 +1609,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ) .map(Some), None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: "Expected u32 constant or identifier, but found `_`".into(), }), }) @@ -1639,7 +1643,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(res) } false => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected {} generic argument{} on type {}, but got {}", ty.generics.len(), @@ -1660,17 +1664,13 @@ impl<'ast, T: Field> Checker<'ast, T> { types: &TypeMap<'ast, T>, ) -> Result, Vec> { let ty = self - .check_type(v.value._type, module_id, types) + .check_type(v.value.ty, module_id, types) .map_err(|e| vec![e])?; // insert into the scope and ignore whether shadowing happened self.insert_into_scope(v.value.id, ty.clone(), v.value.is_mutable); - Ok(Variable::new( - self.id_in_this_scope(v.value.id), - ty, - v.value.is_mutable, - )) + Ok(Variable::new(self.id_in_this_scope(v.value.id), ty)) } fn check_for_loop( @@ -1678,7 +1678,7 @@ impl<'ast, T: Field> Checker<'ast, T> { var: zokrates_ast::untyped::VariableNode<'ast>, range: (ExpressionNode<'ast>, ExpressionNode<'ast>), statements: Vec>, - pos: (Position, Position), + span: SourceSpan, module_id: &ModuleId, types: &TypeMap<'ast, T>, ) -> Result, Vec> { @@ -1693,7 +1693,7 @@ impl<'ast, T: Field> Checker<'ast, T> { TypedExpression::Uint(from) => match from.bitwidth() { UBitwidth::B32 => Ok(from), bitwidth => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected lower loop bound to be of type u32, found {}", Type::::Uint(bitwidth) @@ -1702,7 +1702,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }, TypedExpression::Int(v) => { UExpression::try_from_int(v, &UBitwidth::B32).map_err(|_| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected lower loop bound to be of type u32, found {}", Type::::Int @@ -1710,7 +1710,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }) } from => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected lower loop bound to be of type u32, found {}", from.get_type() @@ -1723,7 +1723,7 @@ impl<'ast, T: Field> Checker<'ast, T> { TypedExpression::Uint(to) => match to.bitwidth() { UBitwidth::B32 => Ok(to), bitwidth => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected upper loop bound to be of type u32, found {}", Type::::Uint(bitwidth) @@ -1732,7 +1732,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }, TypedExpression::Int(v) => { UExpression::try_from_int(v, &UBitwidth::B32).map_err(|_| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected upper loop bound to be of type u32, found {}", Type::::Int @@ -1740,7 +1740,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }) } to => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected upper loop bound to be of type u32, found {}", to.get_type() @@ -1756,7 +1756,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .map(|s| self.check_statement(s, module_id, types)) .collect::, _>>()?; - Ok(TypedStatement::For(var, from, to, checked_statements)) + Ok(TypedStatement::for_(var, from, to, checked_statements).with_span(span)) } // the assignee is already checked to be defined and mutable @@ -1789,7 +1789,7 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, types: &TypeMap<'ast, T>, ) -> Result>, ErrorInner> { - let pos = stat.pos(); + let span = stat.span().in_module(module_id); match stat.value { AssemblyStatement::Assignment(assignee, expression, constrained) => { @@ -1797,7 +1797,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let e = self.check_expression(expression, module_id, types)?; let e = FieldElementExpression::try_from_typed(e).map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected right hand side of an assembly assignment to be of type field, found {}", e.get_type(), @@ -1809,25 +1809,29 @@ impl<'ast, T: Field> Checker<'ast, T> { let e = FieldElementExpression::block(vec![], e); match assignee.get_type() { Type::FieldElement => Ok(vec![ - TypedAssemblyStatement::Assignment( + TypedAssemblyStatement::assignment( assignee.clone(), e.clone().into(), - ), - TypedAssemblyStatement::Constraint( + ) + .with_span(span), + TypedAssemblyStatement::constraint( assignee.into(), e, - SourceMetadata::new(module_id.display().to_string(), pos.0), - ), + SourceMetadata::new(module_id.display().to_string(), span.from), + ) + .with_span(span), ]), ty => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Assignee must be of type field, found {}", ty), }), } } false => { let e = FieldElementExpression::block(vec![], e); - Ok(vec![TypedAssemblyStatement::Assignment(assignee, e.into())]) + Ok(vec![ + TypedAssemblyStatement::assignment(assignee, e.into()).with_span(span) + ]) } } } @@ -1836,7 +1840,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let rhs = self.check_expression(rhs, module_id, types)?; let lhs = FieldElementExpression::try_from_typed(lhs).map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected left hand side of a constraint to be of type field, found {}", e.get_type(), @@ -1844,18 +1848,19 @@ impl<'ast, T: Field> Checker<'ast, T> { })?; let rhs = FieldElementExpression::try_from_typed(rhs).map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected right hand side of a constraint to be of type field, found {}", e.get_type(), ), })?; - Ok(vec![TypedAssemblyStatement::Constraint( + Ok(vec![TypedAssemblyStatement::constraint( lhs, rhs, - SourceMetadata::new(module_id.display().to_string(), pos.0), - )]) + SourceMetadata::new(module_id.display().to_string(), span.from), + ) + .with_span(span)]) } } } @@ -1866,7 +1871,7 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, types: &TypeMap<'ast, T>, ) -> Result, Vec> { - let pos = stat.pos(); + let span = stat.span().in_module(module_id); match stat.value { Statement::Assembly(statements) => { @@ -1877,7 +1882,9 @@ impl<'ast, T: Field> Checker<'ast, T> { .map_err(|e| vec![e])?, ); } - Ok(TypedStatement::Assembly(checked_statements)) + Ok(TypedStatement::Assembly( + AssemblyBlockStatement::new(checked_statements).with_span(span), + )) } Statement::Log(l, expressions) => { let l = FormatString::from(l); @@ -1893,7 +1900,7 @@ impl<'ast, T: Field> Checker<'ast, T> { for e in &expressions { if let TypedExpression::Int(e) = e { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot determine type for expression `{}`", e), }); } @@ -1901,7 +1908,7 @@ impl<'ast, T: Field> Checker<'ast, T> { if expressions.len() != l.len() { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Wrong argument count in log call: expected {}, got {}", l.len(), @@ -1914,7 +1921,7 @@ impl<'ast, T: Field> Checker<'ast, T> { return Err(errors); } - Ok(TypedStatement::Log(l, expressions)) + Ok(TypedStatement::log(l, expressions).with_span(span)) } Statement::Return(e) => { let mut errors = vec![]; @@ -1954,7 +1961,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let res = match TypedExpression::align_to_type(e_checked.clone(), &return_type) .map_err(|e| { vec![ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected return value to be of type `{}`, found `{}` of type `{}`", e.1, @@ -1967,7 +1974,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match e.get_type() == return_type { true => {} false => errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected `{}` in return statement, found `{}`", return_type, @@ -1975,11 +1982,11 @@ impl<'ast, T: Field> Checker<'ast, T> { ), }), } - TypedStatement::Return(e) + TypedStatement::ret(e).with_span(span) } Err(err) => { errors.extend(err); - TypedStatement::Return(e_checked) + TypedStatement::ret(e_checked).with_span(span) } }; @@ -1992,7 +1999,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Statement::Definition(var, expr) => { // get the lhs type let var_ty = self - .check_type(var.value._type, module_id, types) + .check_type(var.value.ty, module_id, types) .map_err(|e| vec![e])?; // check the rhs based on the lhs type @@ -2003,11 +2010,7 @@ impl<'ast, T: Field> Checker<'ast, T> { // insert the lhs into the scope and ignore whether shadowing happened self.insert_into_scope(var.value.id, var_ty.clone(), var.value.is_mutable); - let var = Variable::new( - self.id_in_this_scope(var.value.id), - var_ty.clone(), - var.value.is_mutable, - ); + let var = Variable::new(self.id_in_this_scope(var.value.id), var_ty.clone()); match var_ty { Type::FieldElement => FieldElementExpression::try_from_typed(checked_expr) @@ -2032,7 +2035,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Type::Int => Err(checked_expr), // Integers cannot be assigned } .map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expression `{}` of type `{}` cannot be assigned to `{}` of type `{}`", e, @@ -2041,7 +2044,7 @@ impl<'ast, T: Field> Checker<'ast, T> { var_ty ), }) - .map(|e| TypedStatement::Definition(var.into(), e.into())) + .map(|e| TypedStatement::definition(var.into(), e).with_span(span)) .map_err(|e| vec![e]) } Statement::Assignment(assignee, expr) => { @@ -2079,7 +2082,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Type::Int => Err(checked_expr), // Integers cannot be assigned } .map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expression `{}` of type `{}` cannot be assigned to `{}` of type `{}`", e, @@ -2088,7 +2091,7 @@ impl<'ast, T: Field> Checker<'ast, T> { assignee_ty ), }) - .map(|e| TypedStatement::Definition(assignee, e.into())) + .map(|e| TypedStatement::definition(assignee, e).with_span(span)) .map_err(|e| vec![e]) } Statement::Assertion(e, message) => { @@ -2097,15 +2100,16 @@ impl<'ast, T: Field> Checker<'ast, T> { .map_err(|e| vec![e])?; match e { - TypedExpression::Boolean(e) => Ok(TypedStatement::Assertion( + TypedExpression::Boolean(e) => Ok(TypedStatement::assertion( e, RuntimeError::SourceAssertion( - SourceMetadata::new(module_id.display().to_string(), pos.0) + SourceMetadata::new(module_id.display().to_string(), span.from) .message(message), ), - )), + ) + .with_span(span)), e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected {} to be of type bool, found {}", e, @@ -2118,7 +2122,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Statement::For(var, from, to, statements) => { self.enter_scope(); - let res = self.check_for_loop(var, (from, to), statements, pos, module_id, types); + let res = self.check_for_loop(var, (from, to), statements, span, module_id, types); self.exit_scope(); @@ -2133,23 +2137,22 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, types: &TypeMap<'ast, T>, ) -> Result, ErrorInner> { - let pos = assignee.pos(); + let span = assignee.span().in_module(module_id); // check that the assignee is declared match assignee.value { Assignee::Identifier(variable_name) => match self.scope.get(variable_name) { Some(info) => match info.is_mutable { false => Err(ErrorInner { - pos: Some(assignee.pos()), + span: Some(assignee.span().in_module(module_id)), message: format!("Assignment to an immutable variable `{}`", variable_name), }), _ => Ok(TypedAssignee::Identifier(Variable::new( info.id, info.ty.clone(), - info.is_mutable, ))), }, None => Err(ErrorInner { - pos: Some(assignee.pos()), + span: Some(assignee.span().in_module(module_id)), message: format!("Variable `{}` is undeclared", variable_name), }), }, @@ -2172,7 +2175,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let checked_typed_index = UExpression::try_from_typed(checked_index, &UBitwidth::B32).map_err( |e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected array {} index to have type u32, found {}", checked_assignee, @@ -2187,7 +2190,7 @@ impl<'ast, T: Field> Checker<'ast, T> { )) } ty => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot access element at index {} on {} of type {}", index, checked_assignee, ty, @@ -2203,7 +2206,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Type::Struct(members) => match members.iter().find(|m| m.id == member) { Some(_) => Ok(TypedAssignee::Member(box checked_assignee, member.into())), None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} {{{}}} doesn't have member {}", ty, @@ -2217,7 +2220,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }), }, ty => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot access field {} on {} of type {}", @@ -2234,7 +2237,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Type::Tuple(tuple_ty) => match tuple_ty.elements.get(index as usize) { Some(_) => Ok(TypedAssignee::Element(box checked_assignee, index)), None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Tuple of size {} cannot be accessed at index {}", tuple_ty.elements.len(), @@ -2243,7 +2246,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }), }, ty => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot access element {} on {} of type {}", @@ -2263,7 +2266,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ) -> Result, ErrorInner> { match spread_or_expression { SpreadOrExpression::Spread(s) => { - let pos = s.pos(); + let span = s.span().in_module(module_id); let checked_expression = self.check_expression(s.value.expression, module_id, types)?; @@ -2271,7 +2274,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match checked_expression { TypedExpression::Array(a) => Ok(TypedExpressionOrSpread::Spread(a.into())), e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected spread operator to apply on array, found {}", e.get_type() @@ -2294,11 +2297,11 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, types: &TypeMap<'ast, T>, ) -> Result, ErrorInner> { - let pos = function_id.pos(); + let span = function_id.span().in_module(module_id); let fun_id = match function_id.value { Expression::Identifier(id) => Ok(id), e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected function in function call to be an identifier, found `{}`", e @@ -2313,11 +2316,11 @@ impl<'ast, T: Field> Checker<'ast, T> { .into_iter() .map(|g| { g.map(|g| { - let pos = g.pos(); + let span = g.span().in_module(module_id); self.check_expression(g, module_id, types).and_then(|g| { UExpression::try_from_typed(g, &UBitwidth::B32).map_err(|e| { ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected {} to be of type u32, found {}", e, @@ -2361,7 +2364,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let signature = f.signature; let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.iter()).map(|(a, t)| TypedExpression::align_to_type(a, t)).collect::, _>>().map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Expected function call argument to be of type `{}`, found `{}` of type `{}`", e.1, e.0, e.0.get_type()) })?; @@ -2371,7 +2374,7 @@ impl<'ast, T: Field> Checker<'ast, T> { generics_checked.clone(), arguments_checked.iter().map(|a| a.get_type()).collect() ).map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Failed to infer value for generic parameter `{}`, try providing an explicit value", e, @@ -2384,7 +2387,7 @@ impl<'ast, T: Field> Checker<'ast, T> { signature: signature.clone(), }; - match output_type { + let res: Result, _> = match output_type { Type::Int => unreachable!(), Type::FieldElement => Ok(FieldElementExpression::function_call( function_key, @@ -2416,17 +2419,19 @@ impl<'ast, T: Field> Checker<'ast, T> { generics_checked, arguments_checked, ).annotate(tuple_ty).into()), - } + }; + + res.map(|e| e.with_span(span)) } 0 => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Function definition for function {} with signature {} not found.", fun_id, query ), }), n => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Ambiguous call to function {}, {} candidates were found. Please be more explicit.", fun_id, n) }), } @@ -2438,11 +2443,11 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, types: &TypeMap<'ast, T>, ) -> Result, ErrorInner> { - let pos = expr.pos(); + let span: SourceSpan = expr.span().in_module(module_id); match expr.value { - Expression::IntConstant(v) => Ok(IntExpression::Value(v).into()), - Expression::BooleanConstant(b) => Ok(BooleanExpression::Value(b).into()), + Expression::IntConstant(v) => Ok(IntExpression::value(v).with_span(span).into()), + Expression::BooleanConstant(b) => Ok(BooleanExpression::value(b).with_span(span).into()), Expression::Identifier(name) => { // check that `id` is defined in the scope match self.scope.get(name) { @@ -2469,7 +2474,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Identifier \"{}\" is undefined", name), }), } @@ -2484,14 +2489,15 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot apply `+` to {}, {}", e1.get_type(), e2.get_type()), })?; match (e1_checked, e2_checked) { - (Int(e1), Int(e2)) => Ok(IntExpression::Add(box e1, box e2).into()), + (Int(e1), Int(e2)) => Ok(IntExpression::add(e1, e2).into()), (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(FieldElementExpression::Add(box e1, box e2).into()) + Ok(FieldElementExpression::add(e1, e2) + .into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.get_type() == e2.get_type() => @@ -2499,7 +2505,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok((e1 + e2).into()) } (t1, t2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply `+` to {}, {}", @@ -2519,18 +2525,16 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot apply `-` to {}, {}", e1.get_type(), e2.get_type()), })?; match (e1_checked, e2_checked) { - (Int(e1), Int(e2)) => Ok(IntExpression::Sub(box e1, box e2).into()), - (FieldElement(e1), FieldElement(e2)) => { - Ok(FieldElementExpression::Sub(box e1, box e2).into()) - } + (Int(e1), Int(e2)) => Ok(IntExpression::sub(e1, e2).into()), + (FieldElement(e1), FieldElement(e2)) => Ok((e1 - e2).into()), (Uint(e1), Uint(e2)) if e1.get_type() == e2.get_type() => Ok((e1 - e2).into()), (t1, t2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected only field elements, found {}, {}", @@ -2550,14 +2554,14 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot apply `*` to {}, {}", e1.get_type(), e2.get_type()), })?; match (e1_checked, e2_checked) { - (Int(e1), Int(e2)) => Ok(IntExpression::Mult(box e1, box e2).into()), + (Int(e1), Int(e2)) => Ok(IntExpression::mul(e1, e2).into()), (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(FieldElementExpression::Mult(box e1, box e2).into()) + Ok((e1 * e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.get_type() == e2.get_type() => @@ -2565,7 +2569,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok((e1 * e2).into()) } (t1, t2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply `*` to {}, {}", @@ -2585,22 +2589,20 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot apply `/` to {}, {}", e1.get_type(), e2.get_type()), })?; match (e1_checked, e2_checked) { - (Int(e1), Int(e2)) => Ok(IntExpression::Div(box e1, box e2).into()), - (FieldElement(e1), FieldElement(e2)) => { - Ok(FieldElementExpression::Div(box e1, box e2).into()) - } + (Int(e1), Int(e2)) => Ok(IntExpression::div(e1, e2).into()), + (FieldElement(e1), FieldElement(e2)) => Ok((e1 / e2).into()), (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.get_type() == e2.get_type() => { Ok((e1 / e2).into()) } (t1, t2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply `/` to {}, {}", @@ -2618,7 +2620,7 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot apply `%` to {}, {}", e1.get_type(), e2.get_type()), })?; @@ -2629,7 +2631,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok((e1 % e2).into()) } (t1, t2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply `%` to {}, {}", @@ -2654,10 +2656,10 @@ impl<'ast, T: Field> Checker<'ast, T> { match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::Uint(e2)) => Ok( - TypedExpression::FieldElement(FieldElementExpression::Pow(box e1, box e2)), + TypedExpression::FieldElement(FieldElementExpression::pow(e1, e2)), ), (t1, t2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected `field` and `u32`, found {}, {}", @@ -2671,13 +2673,13 @@ impl<'ast, T: Field> Checker<'ast, T> { let e = self.check_expression(e, module_id, types)?; match e { - TypedExpression::Int(e) => Ok(IntExpression::Neg(box e).into()), + TypedExpression::Int(e) => Ok(IntExpression::neg(e).into()), TypedExpression::FieldElement(e) => { - Ok(FieldElementExpression::Neg(box e).into()) + Ok(FieldElementExpression::neg(e).into()) } TypedExpression::Uint(e) => Ok((-e).into()), e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Unary operator `-` cannot be applied to {} of type {}", e, @@ -2690,13 +2692,13 @@ impl<'ast, T: Field> Checker<'ast, T> { let e = self.check_expression(e, module_id, types)?; match e { - TypedExpression::Int(e) => Ok(IntExpression::Pos(box e).into()), + TypedExpression::Int(e) => Ok(IntExpression::pos(e).into()), TypedExpression::FieldElement(e) => { - Ok(FieldElementExpression::Pos(box e).into()) + Ok(FieldElementExpression::pos(e).into()) } TypedExpression::Uint(e) => Ok(UExpression::pos(e).into()), e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Unary operator `+` cannot be applied to {} of type {}", e, @@ -2713,7 +2715,7 @@ impl<'ast, T: Field> Checker<'ast, T> { || !conditional.alternative_statements.is_empty() { return Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: "Statements are not supported in conditional branches".to_string(), }); } @@ -2729,7 +2731,7 @@ impl<'ast, T: Field> Checker<'ast, T> { alternative_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("{{consequence}} and {{alternative}} in conditional expression should have the same type, found {}, {}", e1.get_type(), e2.get_type()), })?; @@ -2767,13 +2769,13 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(IntExpression::conditional(condition, consequence, alternative, kind).into()) }, (c, a) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("{{consequence}} and {{alternative}} in conditional expression should have the same type, found {}, {}", c.get_type(), a.get_type()) }) } } c => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{{condition}} should be a boolean, found {}", c.get_type() @@ -2781,21 +2783,23 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::FieldConstant(n) => Ok(FieldElementExpression::Number( - T::try_from(n).map_err(|_| ErrorInner { - pos: Some(pos), - message: format!( - "Field constant not in the representable range [{}, {}]", - T::min_value(), - T::max_value() - ), - })?, - ) + Expression::FieldConstant(n) => Ok(FieldElementExpression::Value( + T::try_from(n) + .map(ValueExpression::new) + .map_err(|_| ErrorInner { + span: Some(span), + message: format!( + "Field constant not in the representable range [{}, {}]", + T::min_value(), + T::max_value() + ), + })?, + ).with_span(span) .into()), - Expression::U8Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(8).into()), - Expression::U16Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(16).into()), - Expression::U32Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(32).into()), - Expression::U64Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(64).into()), + Expression::U8Constant(n) => Ok(UExpression::value(n.into()).annotate(8).with_span(span).into()), + Expression::U16Constant(n) => Ok(UExpression::value(n.into()).annotate(16).with_span(span).into()), + Expression::U32Constant(n) => Ok(UExpression::value(n.into()).annotate(32).with_span(span).into()), + Expression::U64Constant(n) => Ok(UExpression::value(n.into()).annotate(64).with_span(span).into()), Expression::FunctionCall(box fun_id_expression, generics, arguments) => self .check_function_call_expression( fun_id_expression, @@ -2813,7 +2817,7 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2825,14 +2829,14 @@ impl<'ast, T: Field> Checker<'ast, T> { match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(BooleanExpression::FieldLt(box e1, box e2).into()) + Ok(BooleanExpression::field_lt(e1, e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => { if e1.get_type() == e2.get_type() { - Ok(BooleanExpression::UintLt(box e1, box e2).into()) + Ok(BooleanExpression::uint_lt(e1, e2).into()) } else { Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2844,7 +2848,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2863,7 +2867,7 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2875,14 +2879,14 @@ impl<'ast, T: Field> Checker<'ast, T> { match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(BooleanExpression::FieldLe(box e1, box e2).into()) + Ok(BooleanExpression::field_le(e1, e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => { if e1.get_type() == e2.get_type() { - Ok(BooleanExpression::UintLe(box e1, box e2).into()) + Ok(BooleanExpression::uint_le(e1, e2).into()) } else { Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2894,7 +2898,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2913,7 +2917,7 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2925,27 +2929,27 @@ impl<'ast, T: Field> Checker<'ast, T> { match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(BooleanExpression::FieldEq(EqExpression::new(e1, e2)).into()) + Ok(BooleanExpression::field_eq(e1, e2).into()) } (TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => { - Ok(BooleanExpression::BoolEq(EqExpression::new(e1, e2)).into()) + Ok(BooleanExpression::bool_eq(e1, e2).into()) } (TypedExpression::Array(e1), TypedExpression::Array(e2)) => { - Ok(BooleanExpression::ArrayEq(EqExpression::new(e1, e2)).into()) + Ok(BooleanExpression::array_eq(e1, e2).into()) } (TypedExpression::Struct(e1), TypedExpression::Struct(e2)) => { - Ok(BooleanExpression::StructEq(EqExpression::new(e1, e2)).into()) + Ok(BooleanExpression::struct_eq(e1, e2).into()) } (TypedExpression::Tuple(e1), TypedExpression::Tuple(e2)) => { - Ok(BooleanExpression::TupleEq(EqExpression::new(e1, e2)).into()) + Ok(BooleanExpression::tuple_eq(e1, e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.get_type() == e2.get_type() => { - Ok(BooleanExpression::UintEq(EqExpression::new(e1, e2)).into()) + Ok(BooleanExpression::uint_eq(e1, e2).into()) } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2964,7 +2968,7 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2976,14 +2980,14 @@ impl<'ast, T: Field> Checker<'ast, T> { match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(BooleanExpression::FieldGe(box e1, box e2).into()) + Ok(BooleanExpression::field_ge(e1, e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => { if e1.get_type() == e2.get_type() { - Ok(BooleanExpression::UintGe(box e1, box e2).into()) + Ok(BooleanExpression::uint_ge(e1, e2).into()) } else { Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2995,7 +2999,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -3014,7 +3018,7 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -3026,14 +3030,14 @@ impl<'ast, T: Field> Checker<'ast, T> { match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(BooleanExpression::FieldGt(box e1, box e2).into()) + Ok(BooleanExpression::field_gt(e1, e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => { if e1.get_type() == e2.get_type() { - Ok(BooleanExpression::UintGt(box e1, box e2).into()) + Ok(BooleanExpression::uint_gt(e1, e2).into()) } else { Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -3045,7 +3049,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -3081,7 +3085,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .unwrap_or_else(|| Ok(array_size.clone().into()))?; let from = UExpression::try_from_typed(from, &UBitwidth::B32).map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected the lower bound of the range to be a u32, found {} of type {}", e, @@ -3090,7 +3094,7 @@ impl<'ast, T: Field> Checker<'ast, T> { })?; let to = UExpression::try_from_typed(to, &UBitwidth::B32).map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected the upper bound of the range to be a u32, found {} of type {}", e, @@ -3099,15 +3103,16 @@ impl<'ast, T: Field> Checker<'ast, T> { })?; Ok(ArrayExpressionInner::Slice( - box array, - box from.clone(), - box to.clone(), - ) + SliceExpression::new( + array, + from.clone(), + to.clone(), + )) .annotate(inner_type, UExpression::floor_sub(to, from)) .into()) } e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot access slice of expression {} of type {}", e, @@ -3122,7 +3127,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let index = UExpression::try_from_typed(index, &UBitwidth::B32).map_err(|e| { ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected index to be of type u32, found {}", e @@ -3145,7 +3150,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } a => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot access element at index {} of type {} on expression {} of type {}", index, @@ -3177,7 +3182,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Type::Tuple(..) => Ok(TupleExpression::element(t, index).into()), }, None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Tuple of size {} cannot be accessed at index {}", t.ty().elements.len(), @@ -3187,7 +3192,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot access tuple element {} on expression of type {}", index, @@ -3225,7 +3230,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } }, None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} {{{}}} doesn't have member {}", s.get_type(), @@ -3241,7 +3246,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot access member {} on expression of type {}", id, @@ -3260,7 +3265,7 @@ impl<'ast, T: Field> Checker<'ast, T> { if expressions_or_spreads_checked.is_empty() { return Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: "Empty arrays are not allowed".to_string(), }); } @@ -3280,14 +3285,14 @@ impl<'ast, T: Field> Checker<'ast, T> { Type::Int => expressions_or_spreads_checked, t => { let target_array_ty = - ArrayType::new(t, UExpressionInner::Value(0).annotate(UBitwidth::B32)); + ArrayType::new(t, UExpression::value(0).annotate(UBitwidth::B32)); expressions_or_spreads_checked .into_iter() .map(|e| { TypedExpressionOrSpread::align_to_type(e, &target_array_ty).map_err( |(e, ty)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Expected {} to have type {}", e, ty,), }, ) @@ -3308,11 +3313,11 @@ impl<'ast, T: Field> Checker<'ast, T> { .map(|e| e.size()) .fold(None, |acc, e| match acc { Some((c_acc, e_acc)) => match e.as_inner() { - UExpressionInner::Value(e) => Some(((c_acc + *e as u32), e_acc)), + UExpressionInner::Value(e) => Some(((c_acc + e.value as u32), e_acc)), _ => Some((c_acc, e_acc + e)), }, None => match e.as_inner() { - UExpressionInner::Value(e) => Some((*e as u32, 0u32.into())), + UExpressionInner::Value(e) => Some((e.value as u32, 0u32.into())), _ => Some((0u32, e)), }, }) @@ -3320,7 +3325,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .unwrap_or_else(|| 0u32.into()); Ok( - ArrayExpressionInner::Value(unwrapped_expressions_or_spreads.into()) + ArrayExpression::value(unwrapped_expressions_or_spreads) .annotate(inferred_type, size) .into(), ) @@ -3331,7 +3336,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .map(|e| self.check_expression(e, module_id, types)) .collect::>()?; let ty = TupleType::new(elements.iter().map(|e| e.get_type()).collect()); - Ok(TupleExpressionInner::Value(elements).annotate(ty).into()) + Ok(TupleExpression::value(elements).annotate(ty).into()) } Expression::ArrayInitializer(box e, box count) => { let e = self.check_expression(e, module_id, types)?; @@ -3341,7 +3346,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let count = UExpression::try_from_typed(count, &UBitwidth::B32).map_err(|e| { ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected array initializer count to be a u32, found {} of type {}", e, @@ -3350,14 +3355,14 @@ impl<'ast, T: Field> Checker<'ast, T> { } })?; - Ok(ArrayExpressionInner::Repeat(box e, box count.clone()) + Ok(ArrayExpressionInner::Repeat(RepeatExpression::new(e, count.clone())) .annotate(ty, count) .into()) } Expression::InlineStruct(id, inline_members) => { let ty = match types.get(module_id).unwrap().get(&id).cloned() { None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Undefined type `{}`", id), }), Some(ty) => Ok(ty), @@ -3379,7 +3384,7 @@ impl<'ast, T: Field> Checker<'ast, T> { // check that we provided the required number of values if declared_struct_type.members_count() != inline_members.len() { return Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Inline struct {} does not match {} {{{}}}", Expression::InlineStruct(id, inline_members), @@ -3414,7 +3419,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let expression_checked = TypedExpression::align_to_type(expression_checked, &*member.ty) .map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Member {} of struct {} has type {}, found {} of type {}", member.id, @@ -3428,7 +3433,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(expression_checked) } None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Member {} of struct {} {{{}}} not found in value {}", member.id, @@ -3454,7 +3459,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .map(|(m, v)| { if !check_type(&m.ty, &v.get_type(), &mut generics_map) { Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Value `{}` doesn't match the expected type `{}` because of conflict in generic values", Expression::InlineStruct(id.clone(), inline_members.clone()), @@ -3479,7 +3484,7 @@ impl<'ast, T: Field> Checker<'ast, T> { members, }; - Ok(StructExpressionInner::Value(inferred_values) + Ok(StructExpression::value(inferred_values) .annotate(inferred_struct_type) .into()) } @@ -3491,7 +3496,7 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply boolean operators to {} and {}", e1.get_type(), @@ -3501,13 +3506,13 @@ impl<'ast, T: Field> Checker<'ast, T> { match (e1_checked, e2_checked) { (TypedExpression::Int(e1), TypedExpression::Int(e2)) => { - Ok(IntExpression::And(box e1, box e2).into()) + Ok(IntExpression::and(e1, e2).into()) } (TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => { - Ok(BooleanExpression::And(box e1, box e2).into()) + Ok(BooleanExpression::bitand(e1, e2).into()) } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply boolean operators to {} and {}", @@ -3522,10 +3527,10 @@ impl<'ast, T: Field> Checker<'ast, T> { let e2_checked = self.check_expression(e2, module_id, types)?; match (e1_checked, e2_checked) { (TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => { - Ok(BooleanExpression::Or(box e1, box e2).into()) + Ok(BooleanExpression::bitor(e1, e2).into()) } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply `||` to {}, {}", e1.get_type(), @@ -3540,7 +3545,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let e2 = UExpression::try_from_typed(e2, &UBitwidth::B32).map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected the left shift right operand to have type `u32`, found {}", e @@ -3548,13 +3553,13 @@ impl<'ast, T: Field> Checker<'ast, T> { })?; match e1 { - TypedExpression::Int(e1) => Ok(IntExpression::LeftShift(box e1, box e2).into()), + TypedExpression::Int(e1) => Ok(IntExpression::left_shift(e1, e2).into()), TypedExpression::Uint(e1) => Ok(UExpression::left_shift(e1, e2).into()), TypedExpression::FieldElement(e1) => { - Ok(FieldElementExpression::LeftShift(box e1, box e2).into()) + Ok(FieldElementExpression::left_shift(e1, e2).into()) } e1 => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot left-shift {} by {}", e1.get_type(), @@ -3569,7 +3574,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let e2 = UExpression::try_from_typed(e2, &UBitwidth::B32).map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected the right shift right operand to be of type `u32`, found {}", e @@ -3578,14 +3583,14 @@ impl<'ast, T: Field> Checker<'ast, T> { match e1 { TypedExpression::Int(e1) => { - Ok(IntExpression::RightShift(box e1, box e2).into()) + Ok(IntExpression::right_shift(e1, e2).into()) } TypedExpression::Uint(e1) => Ok(UExpression::right_shift(e1, e2).into()), TypedExpression::FieldElement(e1) => { - Ok(FieldElementExpression::RightShift(box e1, box e2).into()) + Ok(FieldElementExpression::right_shift(e1, e2).into()) } e1 => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot right-shift {} by {}", e1.get_type(), @@ -3602,16 +3607,16 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot apply `|` to {}, {}", e1.get_type(), e2.get_type()), })?; match (e1_checked, e2_checked) { (TypedExpression::Int(e1), TypedExpression::Int(e2)) => { - Ok(IntExpression::Or(box e1, box e2).into()) + Ok(IntExpression::or(e1, e2).into()) } (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(FieldElementExpression::Or(box e1, box e2).into()) + Ok(FieldElementExpression::bitor(e1, e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.bitwidth() == e2.bitwidth() => @@ -3619,7 +3624,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(UExpression::or(e1, e2).into()) } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply `|` to {}, {}", e1.get_type(), @@ -3636,16 +3641,16 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot apply `&` to {}, {}", e1.get_type(), e2.get_type()), })?; match (e1_checked, e2_checked) { (TypedExpression::Int(e1), TypedExpression::Int(e2)) => { - Ok(IntExpression::And(box e1, box e2).into()) + Ok(IntExpression::and(e1, e2).into()) } (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(FieldElementExpression::And(box e1, box e2).into()) + Ok(FieldElementExpression::bitand(e1, e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.bitwidth() == e2.bitwidth() => @@ -3653,7 +3658,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(UExpression::and(e1, e2).into()) } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply `&` to {}, {}", e1.get_type(), @@ -3670,16 +3675,16 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot apply `^` to {}, {}", e1.get_type(), e2.get_type()), })?; match (e1_checked, e2_checked) { (TypedExpression::Int(e1), TypedExpression::Int(e2)) => { - Ok(IntExpression::Xor(box e1, box e2).into()) + Ok(IntExpression::xor(e1, e2).into()) } (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(FieldElementExpression::Xor(box e1, box e2).into()) + Ok(FieldElementExpression::bitxor(e1, e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.bitwidth() == e2.bitwidth() => @@ -3687,7 +3692,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(UExpression::xor(e1, e2).into()) } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply `^` to {}, {}", e1.get_type(), @@ -3699,16 +3704,16 @@ impl<'ast, T: Field> Checker<'ast, T> { Expression::Not(box e) => { let e_checked = self.check_expression(e, module_id, types)?; match e_checked { - TypedExpression::Int(e) => Ok(IntExpression::Not(box e).into()), - TypedExpression::Boolean(e) => Ok(BooleanExpression::Not(box e).into()), + TypedExpression::Int(e) => Ok(IntExpression::not(e).into()), + TypedExpression::Boolean(e) => Ok(BooleanExpression::not(e).into()), TypedExpression::Uint(e) => Ok((!e).into()), e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot negate {}", e.get_type()), }), } } - } + }.map(|e| e.with_span(span)) } fn insert_into_scope>>( @@ -3751,13 +3756,11 @@ mod tests { use zokrates_field::Bn128Field; lazy_static! { - static ref MODULE_ID: OwnedModuleId = OwnedModuleId::from(""); + static ref MODULE_ID: OwnedModuleId = OwnedModuleId::default(); } mod constants { use super::*; - use std::ops::Add; - #[test] fn field_in_range() { // The value of `P - 1` is a valid field literal @@ -3994,7 +3997,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_module(&OwnedTypedModuleId::from("bar"), &mut state), + checker.check_module(&OwnedModuleId::from("bar"), &mut state), Ok(()) ); assert_eq!( @@ -4502,7 +4505,7 @@ mod tests { assert_eq!( Checker::::default().check_signature(signature, &*MODULE_ID, &state), Err(vec![ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Undeclared symbol `K`".to_string() }]) ); @@ -4574,7 +4577,7 @@ mod tests { assert_eq!( checker.check_statement(statement, &*MODULE_ID, &TypeMap::new()), Err(vec![ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Identifier \"b\" is undefined".into() }]) ); @@ -4675,7 +4678,7 @@ mod tests { checker.check_module(&*MODULE_ID, &mut state), Err(vec![Error { inner: ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Identifier \"a\" is undefined".into() }, module_id: (*MODULE_ID).clone() @@ -4804,7 +4807,7 @@ mod tests { assert_eq!( checker.check_function("foo", foo, &*MODULE_ID, &state), Err(vec![ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Identifier \"i\" is undefined".into() }]) ); @@ -4851,7 +4854,7 @@ mod tests { )]; let foo_statements_checked = vec![ - TypedStatement::For( + TypedStatement::for_( typed::Variable::uint( CoreIdentifier::Source(ShadowedIdentifier::shadow("i".into(), 1)), UBitwidth::B32, @@ -4860,7 +4863,7 @@ mod tests { 10u32.into(), for_statements_checked, ), - TypedStatement::Return(TypedExpression::empty_tuple()), + TypedStatement::ret(TypedExpression::empty_tuple()), ]; let foo = Function { @@ -4930,7 +4933,7 @@ mod tests { assert_eq!( checker.check_function("bar", bar, &*MODULE_ID, &state), Err(vec![ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Function definition for function foo with signature () -> field not found." .into() @@ -4969,7 +4972,7 @@ mod tests { assert_eq!( checker.check_function("bar", bar, &*MODULE_ID, &state), Err(vec![ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Function definition for function foo with signature () -> field not found." @@ -5045,7 +5048,7 @@ mod tests { checker.check_module(&*MODULE_ID, &mut state), Err(vec![Error { inner: ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Variable `a` is undeclared".into() }, module_id: (*MODULE_ID).clone() @@ -5182,7 +5185,7 @@ mod tests { assert_eq!( checker.check_function("bar", bar, &*MODULE_ID, &state), Err(vec![ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Function definition for function foo with signature () -> _ not found." .into() @@ -5215,7 +5218,7 @@ mod tests { assert_eq!( checker.check_function("bar", bar, &*MODULE_ID, &state), Err(vec![ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Identifier \"a\" is undefined".into() }]) ); @@ -5326,7 +5329,7 @@ mod tests { checker.check_program(program), Err(vec![Error { inner: ErrorInner { - pos: None, + span: None, message: "Only one main function allowed, found 2".into() }, module_id: (*MODULE_ID).clone() @@ -5459,16 +5462,14 @@ mod tests { typed::Variable::new( CoreIdentifier::from(ShadowedIdentifier::shadow("a".into(), 0)), Type::FieldElement, - true, ) .into(), - FieldElementExpression::Number(2u32.into()).into(), + FieldElementExpression::value(2u32.into()).into(), ), - TypedStatement::For( + TypedStatement::for_( typed::Variable::new( CoreIdentifier::from(ShadowedIdentifier::shadow("i".into(), 1)), Type::Uint(UBitwidth::B32), - false, ), 0u32.into(), 0u32.into(), @@ -5477,19 +5478,17 @@ mod tests { typed::Variable::new( CoreIdentifier::from(ShadowedIdentifier::shadow("a".into(), 0)), Type::FieldElement, - true, ) .into(), - FieldElementExpression::Number(3u32.into()).into(), + FieldElementExpression::value(3u32.into()).into(), ), TypedStatement::definition( typed::Variable::new( CoreIdentifier::from(ShadowedIdentifier::shadow("a".into(), 1)), Type::FieldElement, - false, ) .into(), - FieldElementExpression::Number(4u32.into()).into(), + FieldElementExpression::value(4u32.into()).into(), ), ], ), @@ -5497,10 +5496,9 @@ mod tests { typed::Variable::new( CoreIdentifier::from(ShadowedIdentifier::shadow("a".into(), 0)), Type::FieldElement, - true, ) .into(), - FieldElementExpression::Number(5u32.into()).into(), + FieldElementExpression::value(5u32.into()).into(), ), ]; @@ -5941,7 +5939,7 @@ mod tests { &state.types ), Ok(FieldElementExpression::member( - StructExpressionInner::Value(vec![FieldElementExpression::Number( + StructExpression::value(vec![FieldElementExpression::value( Bn128Field::from(42u32) ) .into()]) @@ -6065,9 +6063,9 @@ mod tests { &*MODULE_ID, &state.types ), - Ok(StructExpressionInner::Value(vec![ - FieldElementExpression::Number(Bn128Field::from(42u32)).into(), - BooleanExpression::Value(true).into() + Ok(StructExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(42u32)).into(), + BooleanExpression::value(true).into() ]) .annotate(StructType::new( "".into(), @@ -6118,9 +6116,9 @@ mod tests { &*MODULE_ID, &state.types ), - Ok(StructExpressionInner::Value(vec![ - FieldElementExpression::Number(Bn128Field::from(42u32)).into(), - BooleanExpression::Value(true).into() + Ok(StructExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(42u32)).into(), + BooleanExpression::value(true).into() ]) .annotate(StructType::new( "".into(), @@ -6366,7 +6364,6 @@ mod tests { Ok(TypedAssignee::Identifier(typed::Variable::new( "a", Type::FieldElement, - true ))) ); } @@ -6419,7 +6416,6 @@ mod tests { box TypedAssignee::Identifier(typed::Variable::new( "a", Type::array((Type::FieldElement, 3u32)), - true, )), box 2u32.into() )) @@ -6481,7 +6477,6 @@ mod tests { box TypedAssignee::Identifier(typed::Variable::new( "a", Type::array((Type::array((Type::FieldElement, 1u32)), 1u32)), - true, )), box 0u32.into() ), diff --git a/zokrates_core_test/tests/tests/uint/ch.json b/zokrates_core_test/tests/tests/uint/ch.json index 0039cf2a6..435893ece 100644 --- a/zokrates_core_test/tests/tests/uint/ch.json +++ b/zokrates_core_test/tests/tests/uint/ch.json @@ -1,14 +1,14 @@ { "entry_point": "./tests/tests/uint/ch.zok", - "max_constraint_count": 200, + "max_constraint_count": 132, "tests": [ { "input": { - "values": ["0x00000000", "0x00000000", "0x00000000"] + "values": ["0x0000000f", "0x0000000f", "0x0000000f"] }, "output": { "Ok": { - "value": "0x00000000" + "value": "0x0000000f" } } } diff --git a/zokrates_field/src/bn128.rs b/zokrates_field/src/bn128.rs index 4b6c9cbb3..a30b3a45f 100644 --- a/zokrates_field/src/bn128.rs +++ b/zokrates_field/src/bn128.rs @@ -48,7 +48,7 @@ mod tests { ); assert_eq!( FieldPrime::from("65484493"), - FieldPrime::from("65416358") + &FieldPrime::from("68135") + FieldPrime::from("65416358") + FieldPrime::from("68135") ); } @@ -60,7 +60,7 @@ mod tests { ); assert_eq!( FieldPrime::from("3"), - FieldPrime::from("5") + &FieldPrime::from(-2) + FieldPrime::from("5") + FieldPrime::from(-2) ); } @@ -72,7 +72,7 @@ mod tests { ); assert_eq!( FieldPrime::from("65348223"), - FieldPrime::from("65416358") + &FieldPrime::from(-68135) + FieldPrime::from("65416358") + FieldPrime::from(-68135) ); } @@ -84,7 +84,7 @@ mod tests { ); assert_eq!( FieldPrime::from("65348223"), - FieldPrime::from("65416358") - &FieldPrime::from("68135") + FieldPrime::from("65416358") - FieldPrime::from("68135") ); } @@ -96,7 +96,7 @@ mod tests { ); assert_eq!( FieldPrime::from("65484493"), - FieldPrime::from("65416358") - &FieldPrime::from(-68135) + FieldPrime::from("65416358") - FieldPrime::from(-68135) ); } @@ -112,7 +112,7 @@ mod tests { FieldPrime::from( "21888242871839275222246405745257275088548364400416034343698204186575743147394" ), - FieldPrime::from("68135") - &FieldPrime::from("65416358") + FieldPrime::from("68135") - FieldPrime::from("65416358") ); } @@ -124,7 +124,7 @@ mod tests { ); assert_eq!( FieldPrime::from("13472"), - FieldPrime::from("32") * &FieldPrime::from("421") + FieldPrime::from("32") * FieldPrime::from("421") ); } @@ -140,7 +140,7 @@ mod tests { FieldPrime::from( "21888242871839275222246405745257275088548364400416034343698204186575808014369" ), - FieldPrime::from("54") * &FieldPrime::from(-8912) + FieldPrime::from("54") * FieldPrime::from(-8912) ); } @@ -152,7 +152,7 @@ mod tests { ); assert_eq!( FieldPrime::from("648"), - FieldPrime::from(-54) * &FieldPrime::from(-12) + FieldPrime::from(-54) * FieldPrime::from(-12) ); } @@ -172,7 +172,7 @@ mod tests { ), FieldPrime::from( "21888242871839225222246405785257275088694311157297823662689037894645225727" - ) * &FieldPrime::from("218882428715392752222464057432572755886923") + ) * FieldPrime::from("218882428715392752222464057432572755886923") ); } @@ -184,7 +184,7 @@ mod tests { ); assert_eq!( FieldPrime::from(4), - FieldPrime::from(48) / &FieldPrime::from(12) + FieldPrime::from(48) / FieldPrime::from(12) ); } @@ -318,7 +318,7 @@ mod tests { let a: Fr = rng.gen(); // now test idempotence let a = FieldPrime::from_bellman(a); - assert_eq!(FieldPrime::from_bellman(a.clone().into_bellman()), a); + assert_eq!(FieldPrime::from_bellman(a.into_bellman()), a); } } diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index db6ed4a59..44fabfd73 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -5,7 +5,7 @@ use zokrates_abi::{Decode, Value}; use zokrates_ast::ir::{ LinComb, ProgIterator, QuadComb, RuntimeError, Solver, Statement, Variable, Witness, }; -use zokrates_ast::zir; +use zokrates_ast::zir::{self, Expr}; use zokrates_field::Field; pub type ExecutionResult = Result, Error>; @@ -56,16 +56,16 @@ impl Interpreter { for statement in program.statements.into_iter() { match statement { Statement::Block(..) => unreachable!(), - Statement::Constraint(quad, lin, error) => match lin.is_assignee(&witness) { + Statement::Constraint(s) => match s.lin.is_assignee(&witness) { true => { - let val = evaluate_quad(&witness, &quad).unwrap(); - witness.insert(lin.0.get(0).unwrap().0, val); + let val = evaluate_quad(&witness, &s.quad).unwrap(); + witness.insert(s.lin.value.get(0).unwrap().0, val); } false => { - let lhs_value = evaluate_quad(&witness, &quad).unwrap(); - let rhs_value = evaluate_lin(&witness, &lin).unwrap(); + let lhs_value = evaluate_quad(&witness, &s.quad).unwrap(); + let rhs_value = evaluate_lin(&witness, &s.lin).unwrap(); if lhs_value != rhs_value { - return Err(Error::UnsatisfiedConstraint { error }); + return Err(Error::UnsatisfiedConstraint { error: s.error }); } } }, @@ -91,13 +91,13 @@ impl Interpreter { witness.insert(*o, res[i]); } } - Statement::Log(l, expressions) => { - let mut parts = l.parts.into_iter(); + Statement::Log(s) => { + let mut parts = s.format_string.parts.into_iter(); write!(log_stream, "{}", parts.next().unwrap()) .map_err(|_| Error::LogStream)?; - for ((t, e), part) in expressions.into_iter().zip(parts) { + for ((t, e), part) in s.expressions.into_iter().zip(parts) { let values: Vec<_> = e .iter() .map(|e| evaluate_lin(&witness, e).unwrap()) @@ -185,26 +185,26 @@ impl Interpreter { .arguments .iter() .zip(inputs) - .map(|(a, v)| match &a.id._type { + .map(|(a, v)| match &a.id.ty { zir::Type::FieldElement => Ok(( a.id.id.clone(), - zokrates_ast::zir::FieldElementExpression::Number(*v).into(), + zokrates_ast::zir::FieldElementExpression::value(*v).into(), )), zir::Type::Boolean => match v { v if *v == T::from(0) => Ok(( a.id.id.clone(), - zokrates_ast::zir::BooleanExpression::Value(false).into(), + zokrates_ast::zir::BooleanExpression::value(false).into(), )), v if *v == T::from(1) => Ok(( a.id.id.clone(), - zokrates_ast::zir::BooleanExpression::Value(true).into(), + zokrates_ast::zir::BooleanExpression::value(true).into(), )), v => Err(format!("`{}` has unexpected value `{}`", a.id, v)), }, zir::Type::Uint(bitwidth) => match v.bits() <= bitwidth.to_usize() as u32 { true => Ok(( a.id.id.clone(), - zokrates_ast::zir::UExpressionInner::Value( + zokrates_ast::zir::UExpression::value( v.to_dec_string().parse::().unwrap(), ) .annotate(*bitwidth) @@ -230,11 +230,12 @@ impl Interpreter { if let zokrates_ast::zir::ZirStatement::Return(v) = folded_function.statements[0].clone() { - v.into_iter() + v.inner + .into_iter() .map(|v| match v { zokrates_ast::zir::ZirExpression::FieldElement( - zokrates_ast::zir::FieldElementExpression::Number(n), - ) => n, + zokrates_ast::zir::FieldElementExpression::Value(n), + ) => n.value, _ => unreachable!(), }) .collect() @@ -360,7 +361,7 @@ pub enum Error { } fn evaluate_lin(w: &Witness, l: &LinComb) -> Result { - l.0.iter().try_fold(T::from(0), |acc, (var, mult)| { + l.value.iter().try_fold(T::from(0), |acc, (var, mult)| { w.0.get(var) .map(|v| acc + (*v * mult)) .ok_or(EvaluationError) // fail if any term isn't found @@ -508,6 +509,7 @@ mod tests { #[test] fn solver_ref() { + use std::ops::Mul; use zir::{ types::{Signature, Type}, FieldElementExpression, Identifier, IdentifierExpression, Parameter, Variable, @@ -519,13 +521,10 @@ mod tests { // (field i0) -> i0 * i0 let solvers = vec![Solver::Zir(ZirFunction { - arguments: vec![Parameter { - id: Variable::field_element(id.id.clone()), - private: true, - }], - statements: vec![ZirStatement::Return(vec![FieldElementExpression::Mult( - Box::new(FieldElementExpression::Identifier(id.clone())), - Box::new(FieldElementExpression::Identifier(id.clone())), + arguments: vec![Parameter::new(Variable::field_element(id.id.clone()), true)], + statements: vec![ZirStatement::ret(vec![FieldElementExpression::mul( + FieldElementExpression::Identifier(id.clone()), + FieldElementExpression::Identifier(id.clone()), ) .into()])], signature: Signature::new() diff --git a/zokrates_js/Cargo.toml b/zokrates_js/Cargo.toml index 02374289b..243f12a80 100644 --- a/zokrates_js/Cargo.toml +++ b/zokrates_js/Cargo.toml @@ -11,7 +11,7 @@ crate-type = ["cdylib"] js-sys = "0.3.33" serde = { version = "^1.0.59", features = ["derive"] } serde_json = { version = "1.0", features = ["preserve_order"] } -wasm-bindgen = { version = "0.2.46", features = ["serde-serialize"] } +wasm-bindgen = { version = "=0.2.81", features = ["serde-serialize"] } typed-arena = "1.4.1" lazy_static = "1.4.0" zokrates_field = { path = "../zokrates_field" } diff --git a/zokrates_parser/Cargo.toml b/zokrates_parser/Cargo.toml index 65f55f24f..2928dd0fe 100644 --- a/zokrates_parser/Cargo.toml +++ b/zokrates_parser/Cargo.toml @@ -5,7 +5,7 @@ authors = ["JacobEberhardt "] edition = "2018" [dependencies] -pest = "2.0" +pest = "=2.1" pest_derive = "2.0" [dev-dependencies] diff --git a/zokrates_pest_ast/Cargo.toml b/zokrates_pest_ast/Cargo.toml index 372eb132d..6e94a080d 100644 --- a/zokrates_pest_ast/Cargo.toml +++ b/zokrates_pest_ast/Cargo.toml @@ -6,7 +6,7 @@ edition = "2018" [dependencies] zokrates_parser = { version = "0.3.0", path = "../zokrates_parser" } -pest = "2.0" +pest = "=2.1" pest-ast = "0.3.3" from-pest = "0.3.1" lazy_static = "1.3.0" diff --git a/zokrates_profiler/Cargo.toml b/zokrates_profiler/Cargo.toml new file mode 100644 index 000000000..00c0295f3 --- /dev/null +++ b/zokrates_profiler/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "zokrates_profiler" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +zokrates_ast = { version = "0.1", path = "../zokrates_ast", default-features = false } \ No newline at end of file diff --git a/zokrates_profiler/src/lib.rs b/zokrates_profiler/src/lib.rs new file mode 100644 index 000000000..126178572 --- /dev/null +++ b/zokrates_profiler/src/lib.rs @@ -0,0 +1,52 @@ +use std::collections::HashMap; + +use zokrates_ast::{ + common::{ModuleMap, Span}, + ir::{ProgIterator, Statement}, +}; + +#[derive(Default, Debug)] +pub struct HeatMap { + /// the total number of constraints + count: usize, + /// for each span, how many constraints are linked to it + map: HashMap, usize>, +} + +impl HeatMap { + pub fn display(&self, module_map: &ModuleMap) -> String { + let count = self.count; + + let mut stats: Vec<_> = self.map.iter().collect(); + + stats.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); + + stats + .iter() + .map(|(span, c)| { + format!( + "{:>4.2}% : {}", + (**c as f64) / (count as f64) * 100.0, + span.map(|s| s.resolve(module_map).to_string()) + .unwrap_or_else(|| String::from("???")), + ) + }) + .collect::>() + .join("\n") + } +} + +pub fn profile<'ast, T, I: IntoIterator>>( + prog: ProgIterator<'ast, T, I>, +) -> HeatMap { + prog.statements + .into_iter() + .fold(HeatMap::default(), |mut heat_map, s| match s { + Statement::Constraint(s) => { + heat_map.count += 1; + *heat_map.map.entry(s.span).or_default() += 1; + heat_map + } + _ => heat_map, + }) +} diff --git a/zokrates_test/tests/wasm.rs b/zokrates_test/tests/wasm.rs index cd9fc57a0..6253b39f5 100644 --- a/zokrates_test/tests/wasm.rs +++ b/zokrates_test/tests/wasm.rs @@ -17,9 +17,14 @@ use zokrates_proof_systems::groth16::G16; #[wasm_bindgen_test] fn generate_proof() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, - statements: vec![Statement::constraint(Variable::new(0), Variable::new(0))], + statements: vec![Statement::constraint( + Variable::new(0), + Variable::new(0), + None, + )], solvers: vec![], }; From 8c08164038f45848c769f69beea1918067d123fe Mon Sep 17 00:00:00 2001 From: Thibaut Schaeffer Date: Thu, 13 Apr 2023 17:26:01 +0200 Subject: [PATCH 31/32] Build with stable rust (#1288) make zokrates build on stable rust --- Cargo.lock | 5 +- changelogs/unreleased/1288-schaeff | 1 + clippy.toml | 1 - dev.Dockerfile | 6 +- rust-toolchain.toml | 2 +- zokrates_abi/src/lib.rs | 2 - .../src/boolean_array_comparator.rs | 18 +- zokrates_analysis/src/constant_resolver.rs | 18 +- .../src/flatten_complex_types.rs | 30 +- zokrates_analysis/src/lib.rs | 2 - zokrates_analysis/src/out_of_bounds.rs | 12 +- zokrates_analysis/src/panic_extractor.rs | 16 +- zokrates_analysis/src/propagation.rs | 104 ++-- zokrates_analysis/src/reducer/mod.rs | 66 +-- zokrates_analysis/src/reducer/shallow_ssa.rs | 52 +- zokrates_analysis/src/struct_concretizer.rs | 10 +- .../src/variable_write_remover.rs | 33 +- zokrates_ast/Cargo.toml | 1 + zokrates_ast/src/common/embed.rs | 10 +- zokrates_ast/src/common/expressions.rs | 6 +- zokrates_ast/src/flat/mod.rs | 14 +- zokrates_ast/src/ir/expression.rs | 1 + zokrates_ast/src/ir/mod.rs | 1 + zokrates_ast/src/lib.rs | 2 - zokrates_ast/src/typed/abi.rs | 2 +- zokrates_ast/src/typed/folder.rs | 29 +- zokrates_ast/src/typed/integer.rs | 31 +- zokrates_ast/src/typed/mod.rs | 189 ++++--- zokrates_ast/src/typed/result_folder.rs | 38 +- zokrates_ast/src/typed/types.rs | 70 ++- zokrates_ast/src/typed/utils/mod.rs | 6 +- zokrates_ast/src/untyped/from_ast.rs | 233 +++++---- zokrates_ast/src/untyped/types.rs | 2 +- zokrates_ast/src/zir/identifier.rs | 20 +- zokrates_ast/src/zir/mod.rs | 8 +- zokrates_ast/src/zir/uint.rs | 2 +- zokrates_book/src/gettingstarted.md | 2 +- zokrates_cli/src/bin.rs | 19 +- zokrates_cli/src/ops/compile.rs | 6 +- zokrates_cli/src/ops/compute_witness.rs | 8 +- zokrates_cli/src/ops/export_verifier.rs | 4 +- zokrates_cli/src/ops/generate_proof.rs | 6 +- zokrates_cli/src/ops/generate_smtlib2.rs | 2 +- zokrates_cli/src/ops/inspect.rs | 2 +- zokrates_cli/src/ops/mpc/beacon.rs | 4 +- zokrates_cli/src/ops/mpc/contribute.rs | 4 +- zokrates_cli/src/ops/mpc/export.rs | 2 +- zokrates_cli/src/ops/mpc/init.rs | 4 +- zokrates_cli/src/ops/mpc/verify.rs | 4 +- zokrates_cli/src/ops/print_proof.rs | 2 +- zokrates_cli/src/ops/profile.rs | 2 +- zokrates_cli/src/ops/setup.rs | 4 +- zokrates_cli/src/ops/verify.rs | 4 +- zokrates_cli/tests/integration.rs | 4 +- zokrates_codegen/src/lib.rs | 246 +++++---- zokrates_core/src/compile.rs | 24 +- zokrates_core/src/lib.rs | 2 - zokrates_core/src/optimizer/redefinition.rs | 1 - zokrates_core/src/semantics.rs | 475 +++++++++--------- zokrates_embed/src/ark.rs | 12 +- zokrates_js/src/lib.rs | 2 +- zokrates_parser/src/lib.rs | 1 + zokrates_pest_ast/src/lib.rs | 1 + zokrates_proof_systems/src/rng.rs | 2 +- zokrates_proof_systems/src/to_token.rs | 14 +- zokrates_test_derive/src/lib.rs | 2 +- 66 files changed, 984 insertions(+), 924 deletions(-) create mode 100644 changelogs/unreleased/1288-schaeff delete mode 100644 clippy.toml diff --git a/Cargo.lock b/Cargo.lock index 5073a553b..93cdda61a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2325,9 +2325,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.95" +version = "1.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744" +checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" dependencies = [ "indexmap", "itoa", @@ -3157,6 +3157,7 @@ dependencies = [ "byteorder", "cfg-if 0.1.10", "derivative", + "log", "num-bigint 0.2.6", "pairing_ce", "serde", diff --git a/changelogs/unreleased/1288-schaeff b/changelogs/unreleased/1288-schaeff new file mode 100644 index 000000000..a3758ebea --- /dev/null +++ b/changelogs/unreleased/1288-schaeff @@ -0,0 +1 @@ +Make ZoKrates build on stable rust \ No newline at end of file diff --git a/clippy.toml b/clippy.toml deleted file mode 100644 index 3f707a567..000000000 --- a/clippy.toml +++ /dev/null @@ -1 +0,0 @@ -blacklisted-names = [] \ No newline at end of file diff --git a/dev.Dockerfile b/dev.Dockerfile index 2bf8fed0e..bcb73b0d8 100644 --- a/dev.Dockerfile +++ b/dev.Dockerfile @@ -1,9 +1,9 @@ -FROM rustlang/rust:nightly +FROM rust:latest RUN useradd -u 1000 -m zokrates -COPY ./scripts/install_foundry_deb.sh /tmp/ -RUN /tmp/install_foundry_deb.sh +COPY ./scripts/install_foundry.sh /tmp/ +RUN /tmp/install_foundry.sh USER zokrates diff --git a/rust-toolchain.toml b/rust-toolchain.toml index cbfe2704f..292fe499e 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2022-07-01" +channel = "stable" diff --git a/zokrates_abi/src/lib.rs b/zokrates_abi/src/lib.rs index a85adb219..d12f65d8f 100644 --- a/zokrates_abi/src/lib.rs +++ b/zokrates_abi/src/lib.rs @@ -1,5 +1,3 @@ -#![feature(box_patterns, box_syntax)] - pub enum Inputs { Raw(Vec), Abi(Values), diff --git a/zokrates_analysis/src/boolean_array_comparator.rs b/zokrates_analysis/src/boolean_array_comparator.rs index f9696d276..72c158193 100644 --- a/zokrates_analysis/src/boolean_array_comparator.rs +++ b/zokrates_analysis/src/boolean_array_comparator.rs @@ -1,8 +1,8 @@ use zokrates_ast::{ common::WithSpan, typed::{ - folder::*, ArrayExpression, BooleanExpression, Conditional, ConditionalKind, Expr, - FieldElementExpression, Select, Type, TypedExpression, TypedProgram, UExpression, + folder::*, ArrayExpression, ArrayType, BooleanExpression, Conditional, ConditionalKind, + Expr, FieldElementExpression, Select, Type, TypedExpression, TypedProgram, UExpression, UExpressionInner, }, }; @@ -42,7 +42,7 @@ impl<'ast, T: Field> Folder<'ast, T> for BooleanArrayComparator { _ => unreachable!("array size should be known"), }; - let chunk_size = T::get_required_bits() as usize - 1; + let chunk_size = T::get_required_bits() - 1; let left_elements: Vec<_> = (0..len) .map(|i| BooleanExpression::select(*e.left.clone(), i as u32).span(span)) @@ -91,10 +91,10 @@ impl<'ast, T: Field> Folder<'ast, T> for BooleanArrayComparator { BooleanExpression::array_eq( ArrayExpression::value(left) - .annotate(Type::FieldElement, chunk_count as u32) + .annotate(ArrayType::new(Type::FieldElement, chunk_count as u32)) .span(span), ArrayExpression::value(right) - .annotate(Type::FieldElement, chunk_count as u32) + .annotate(ArrayType::new(Type::FieldElement, chunk_count as u32)) .span(span), ) } @@ -124,8 +124,8 @@ mod tests { // [x[0] ? 2**1 : 0 + x[1] ? 2**0 : 0] == [y[0] ? 2**1 : 0 + y[1] ? 2**0 : 0] // a single field is sufficient, as the prime we're working with is 3 bits long, so we can pack up to 2 bits - let x = a_id("x").annotate(Type::Boolean, 2u32); - let y = a_id("y").annotate(Type::Boolean, 2u32); + let x = a_id("x").annotate(ArrayType::new(Type::Boolean, 2u32)); + let y = a_id("y").annotate(ArrayType::new(Type::Boolean, 2u32)); let e: BooleanExpression = BooleanExpression::ArrayEq(BinaryExpression::new(x.clone(), y.clone())); @@ -152,8 +152,8 @@ mod tests { // should become // [x[0] ? 2**2 : 0 + x[1] ? 2**1 : 0, x[2] ? 2**0 : 0] == [y[0] ? 2**2 : 0 + y[1] ? 2**1 : 0 y[2] ? 2**0 : 0] - let x = a_id("x").annotate(Type::Boolean, 3u32); - let y = a_id("y").annotate(Type::Boolean, 3u32); + let x = a_id("x").annotate(ArrayType::new(Type::Boolean, 3u32)); + let y = a_id("y").annotate(ArrayType::new(Type::Boolean, 3u32)); let e: BooleanExpression = BooleanExpression::ArrayEq(BinaryExpression::new(x.clone(), y.clone())); diff --git a/zokrates_analysis/src/constant_resolver.rs b/zokrates_analysis/src/constant_resolver.rs index f39167d39..5ff42f4c2 100644 --- a/zokrates_analysis/src/constant_resolver.rs +++ b/zokrates_analysis/src/constant_resolver.rs @@ -113,7 +113,7 @@ mod tests { use zokrates_ast::typed::types::{DeclarationSignature, GTupleType}; use zokrates_ast::typed::{ DeclarationArrayType, DeclarationFunctionKey, DeclarationType, FieldElementExpression, - GType, Identifier, TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol, + Identifier, TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol, TypedStatement, }; use zokrates_field::Bn128Field; @@ -316,12 +316,12 @@ mod tests { FieldElementExpression::add( FieldElementExpression::select( ArrayExpression::identifier(Identifier::from(const_id.clone())) - .annotate(GType::FieldElement, 2u32), + .annotate(GArrayType::new(Type::FieldElement, 2u32)), UExpression::value(0u128).annotate(UBitwidth::B32), ), FieldElementExpression::select( ArrayExpression::identifier(Identifier::from(const_id.clone())) - .annotate(GType::FieldElement, 2u32), + .annotate(GArrayType::new(Type::FieldElement, 2u32)), UExpression::value(1u128).annotate(UBitwidth::B32), ), ) @@ -347,7 +347,7 @@ mod tests { FieldElementExpression::value(Bn128Field::from(2)).into(), FieldElementExpression::value(Bn128Field::from(2)).into(), ]) - .annotate(GType::FieldElement, 2u32), + .annotate(GArrayType::new(Type::FieldElement, 2u32)), ), DeclarationType::Array(DeclarationArrayType::new( DeclarationType::FieldElement, @@ -741,8 +741,9 @@ mod tests { main_baz_const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( TypedExpression::Array( - ArrayExpression::identifier(main_bar_const_id.clone().into()) - .annotate(Type::FieldElement, main_foo_const_id.clone()), + ArrayExpression::identifier(main_bar_const_id.clone().into()).annotate( + ArrayType::new(Type::FieldElement, main_foo_const_id.clone()), + ), ), DeclarationType::Array(DeclarationArrayType::new( DeclarationType::FieldElement, @@ -814,8 +815,9 @@ mod tests { main_baz_const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( TypedExpression::Array( - ArrayExpression::identifier(main_bar_const_id.into()) - .annotate(Type::FieldElement, main_foo_const_id.clone()), + ArrayExpression::identifier(main_bar_const_id.into()).annotate( + ArrayType::new(Type::FieldElement, main_foo_const_id.clone()), + ), ), DeclarationType::Array(DeclarationArrayType::new( DeclarationType::FieldElement, diff --git a/zokrates_analysis/src/flatten_complex_types.rs b/zokrates_analysis/src/flatten_complex_types.rs index 20cfa9436..64705d1ad 100644 --- a/zokrates_analysis/src/flatten_complex_types.rs +++ b/zokrates_analysis/src/flatten_complex_types.rs @@ -39,17 +39,14 @@ fn flatten_identifier_rec<'ast>( } typed::types::ConcreteType::Array(array_type) => (0..*array_type.size) .flat_map(|i| { - flatten_identifier_rec( - SourceIdentifier::Select(Box::new(id.clone()), i), - &array_type.ty, - ) + flatten_identifier_rec(SourceIdentifier::select(id.clone(), i), &array_type.ty) }) .collect(), typed::types::ConcreteType::Struct(members) => members .iter() .flat_map(|struct_member| { flatten_identifier_rec( - SourceIdentifier::Member(Box::new(id.clone()), struct_member.id.clone()), + SourceIdentifier::member(id.clone(), struct_member.id.clone()), &struct_member.ty, ) }) @@ -59,10 +56,7 @@ fn flatten_identifier_rec<'ast>( .iter() .enumerate() .flat_map(|(i, ty)| { - flatten_identifier_rec( - SourceIdentifier::Element(Box::new(id.clone()), i as u32), - ty, - ) + flatten_identifier_rec(SourceIdentifier::element(id.clone(), i as u32), ty) }) .collect(), } @@ -88,7 +82,7 @@ fn flatten_identifier_to_expression_rec<'ast, T: Field>( typed::ConcreteType::Array(array_type) => (0..*array_type.size) .flat_map(|i| { flatten_identifier_to_expression_rec( - SourceIdentifier::Select(box id.clone(), i), + SourceIdentifier::select(id.clone(), i), &array_type.ty, ) }) @@ -97,7 +91,7 @@ fn flatten_identifier_to_expression_rec<'ast, T: Field>( .iter() .flat_map(|struct_member| { flatten_identifier_to_expression_rec( - SourceIdentifier::Member(box id.clone(), struct_member.id.clone()), + SourceIdentifier::member(id.clone(), struct_member.id.clone()), &struct_member.ty, ) }) @@ -108,7 +102,7 @@ fn flatten_identifier_to_expression_rec<'ast, T: Field>( .enumerate() .flat_map(|(i, ty)| { flatten_identifier_to_expression_rec( - SourceIdentifier::Element(box id.clone(), i as u32), + SourceIdentifier::element(id.clone(), i as u32), ty, ) }) @@ -231,12 +225,12 @@ impl<'ast, T: Field> Flattener { fn fold_assignee(&mut self, a: typed::TypedAssignee<'ast, T>) -> Vec> { match a { typed::TypedAssignee::Identifier(v) => self.fold_variable(v), - typed::TypedAssignee::Select(box a, box i) => { + typed::TypedAssignee::Select(a, i) => { let count = match typed::ConcreteType::try_from(a.get_type()).unwrap() { typed::ConcreteType::Array(array_ty) => array_ty.ty.get_primitive_count(), _ => unreachable!(), }; - let a = self.fold_assignee(a); + let a = self.fold_assignee(*a); match i.as_inner() { typed::UExpressionInner::Value(index) => { @@ -245,7 +239,7 @@ impl<'ast, T: Field> Flattener { i => unreachable!("index {:?} not allowed, should be a constant", i), } } - typed::TypedAssignee::Member(box a, m) => { + typed::TypedAssignee::Member(a, m) => { let (offset, size) = match typed::ConcreteType::try_from(a.get_type()).unwrap() { typed::ConcreteType::Struct(struct_type) => struct_type @@ -263,11 +257,11 @@ impl<'ast, T: Field> Flattener { let size = size.unwrap(); - let a = self.fold_assignee(a); + let a = self.fold_assignee(*a); a[offset..offset + size].to_vec() } - typed::TypedAssignee::Element(box a, index) => { + typed::TypedAssignee::Element(a, index) => { let tuple_ty = typed::ConcreteTupleType::try_from( typed::ConcreteType::try_from(a.get_type()).unwrap(), ) @@ -282,7 +276,7 @@ impl<'ast, T: Field> Flattener { let size = &tuple_ty.elements[index as usize].get_primitive_count(); - let a = self.fold_assignee(a); + let a = self.fold_assignee(*a); a[offset..offset + size].to_vec() } diff --git a/zokrates_analysis/src/lib.rs b/zokrates_analysis/src/lib.rs index 539fe86cb..2cf0c24fc 100644 --- a/zokrates_analysis/src/lib.rs +++ b/zokrates_analysis/src/lib.rs @@ -1,5 +1,3 @@ -#![feature(box_patterns, box_syntax)] - //! Module containing static analysis //! //! @file mod.rs diff --git a/zokrates_analysis/src/out_of_bounds.rs b/zokrates_analysis/src/out_of_bounds.rs index 5be99af50..7ed8613a7 100644 --- a/zokrates_analysis/src/out_of_bounds.rs +++ b/zokrates_analysis/src/out_of_bounds.rs @@ -46,10 +46,10 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for OutOfBoundsChecker { a: TypedAssignee<'ast, T>, ) -> Result, Error> { match a { - TypedAssignee::Select(box array, box index) => { + TypedAssignee::Select(array, index) => { use zokrates_ast::typed::Typed; - let array = self.fold_assignee(array)?; + let array = self.fold_assignee(*array)?; let size = match array.get_type() { Type::Array(array_ty) => match array_ty.size.as_inner() { @@ -62,13 +62,13 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for OutOfBoundsChecker { match index.as_inner() { UExpressionInner::Value(i) if i.value >= size => Err(Error(format!( "Out of bounds write to `{}` because `{}` has size {}", - TypedAssignee::Select(box array.clone(), box index), + TypedAssignee::select(array.clone(), *index), array, size ))), - _ => Ok(TypedAssignee::Select( - box self.fold_assignee(array)?, - box self.fold_uint_expression(index)?, + _ => Ok(TypedAssignee::select( + array, + self.fold_uint_expression(*index)?, )), } } diff --git a/zokrates_analysis/src/panic_extractor.rs b/zokrates_analysis/src/panic_extractor.rs index 312abab05..dd2628cd5 100644 --- a/zokrates_analysis/src/panic_extractor.rs +++ b/zokrates_analysis/src/panic_extractor.rs @@ -1,6 +1,6 @@ use std::ops::*; use zokrates_ast::{ - common::{expressions::BinaryExpression, Fold, WithSpan}, + common::{Fold, WithSpan}, zir::{ folder::*, BooleanExpression, Conditional, ConditionalExpression, ConditionalOrExpression, Expr, FieldElementExpression, IfElseStatement, RuntimeError, UBitwidth, UExpression, @@ -160,14 +160,12 @@ impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> { ) -> BooleanExpression<'ast, T> { match e { // constant range checks are complete, so no panic needs to be extracted - e @ BooleanExpression::FieldLt(BinaryExpression { - left: box FieldElementExpression::Value(_), - .. - }) - | e @ BooleanExpression::FieldLt(BinaryExpression { - right: box FieldElementExpression::Value(_), - .. - }) => fold_boolean_expression_cases(self, e), + BooleanExpression::FieldLt(b) + if matches!(b.left.as_ref(), FieldElementExpression::Value(_)) + || matches!(b.right.as_ref(), FieldElementExpression::Value(_)) => + { + fold_boolean_expression_cases(self, BooleanExpression::FieldLt(b)) + } BooleanExpression::FieldLt(e) => { let span = e.get_span(); diff --git a/zokrates_analysis/src/propagation.rs b/zokrates_analysis/src/propagation.rs index 06e8c5b43..4fc9e6651 100644 --- a/zokrates_analysis/src/propagation.rs +++ b/zokrates_analysis/src/propagation.rs @@ -73,30 +73,28 @@ impl<'ast, T: Field> Propagator<'ast, T> { .get_mut(&var.id) .map(|c| Ok((var, c))) .unwrap_or(Err(var)), - TypedAssignee::Select(box assignee, box index) => { - match self.try_get_constant_mut(assignee) { - Ok((variable, constant)) => match index.as_inner() { - UExpressionInner::Value(n) => match constant { - TypedExpression::Array(a) => match a.as_inner_mut() { - ArrayExpressionInner::Value(value) => { - match value.value.get_mut(n.value as usize) { - Some(TypedExpressionOrSpread::Expression(ref mut e)) => { - Ok((variable, e)) - } - None => Err(variable), - _ => unreachable!(), + TypedAssignee::Select(assignee, index) => match self.try_get_constant_mut(assignee) { + Ok((variable, constant)) => match index.as_inner() { + UExpressionInner::Value(n) => match constant { + TypedExpression::Array(a) => match a.as_inner_mut() { + ArrayExpressionInner::Value(value) => { + match value.value.get_mut(n.value as usize) { + Some(TypedExpressionOrSpread::Expression(ref mut e)) => { + Ok((variable, e)) } + None => Err(variable), + _ => unreachable!(), } - _ => unreachable!("should be an array value"), - }, - _ => unreachable!("should be an array expression"), + } + _ => unreachable!("should be an array value"), }, - _ => Err(variable), + _ => unreachable!("should be an array expression"), }, - e => e, - } - } - TypedAssignee::Member(box assignee, m) => match self.try_get_constant_mut(assignee) { + _ => Err(variable), + }, + e => e, + }, + TypedAssignee::Member(assignee, m) => match self.try_get_constant_mut(assignee) { Ok((v, c)) => { let ty = assignee.get_type(); @@ -119,20 +117,18 @@ impl<'ast, T: Field> Propagator<'ast, T> { } e => e, }, - TypedAssignee::Element(box assignee, index) => { - match self.try_get_constant_mut(assignee) { - Ok((v, c)) => match c { - TypedExpression::Tuple(a) => match a.as_inner_mut() { - TupleExpressionInner::Value(value) => { - Ok((v, &mut value.value[*index as usize])) - } - _ => unreachable!("should be a tuple value"), - }, - _ => unreachable!("should be a tuple expression"), + TypedAssignee::Element(assignee, index) => match self.try_get_constant_mut(assignee) { + Ok((v, c)) => match c { + TypedExpression::Tuple(a) => match a.as_inner_mut() { + TupleExpressionInner::Value(value) => { + Ok((v, &mut value.value[*index as usize])) + } + _ => unreachable!("should be a tuple value"), }, - e => e, - } - } + _ => unreachable!("should be a tuple expression"), + }, + e => e, + }, } } } @@ -418,7 +414,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { .map(|v| BooleanExpression::value(v).into()) .collect::>(), ) - .annotate(Type::Boolean, bitwidth.to_usize() as u32) + .annotate(ArrayType::new(Type::Boolean, bitwidth.to_usize() as u32)) .into() } _ => unreachable!("should be a uint value"), @@ -526,7 +522,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { }) .collect::>(), ) - .annotate(Type::Boolean, bit_width) + .annotate(ArrayType::new(Type::Boolean, bit_width)) .span(span) .into(), )) @@ -1192,7 +1188,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { None => Ok(SelectOrExpression::Expression( E::select( ArrayExpressionInner::Identifier(id) - .annotate(inner_type, size.value as u32), + .annotate(ArrayType::new(inner_type, size.value as u32)), UExpressionInner::Value(n).annotate(UBitwidth::B32), ) .into_inner(), @@ -1200,7 +1196,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { } } (a, i) => Ok(SelectOrExpression::Select(SelectExpression::new( - a.annotate(inner_type, size.value as u32), + a.annotate(ArrayType::new(inner_type, size.value as u32)), i.annotate(UBitwidth::B32), ))), }, @@ -1706,7 +1702,7 @@ mod tests { FieldElementExpression::value(Bn128Field::from(2)).into(), FieldElementExpression::value(Bn128Field::from(3)).into(), ]) - .annotate(Type::FieldElement, 3u32), + .annotate(ArrayType::new(Type::FieldElement, 3u32)), UExpression::add(1u32.into(), 1u32.into()), ); @@ -1837,34 +1833,38 @@ mod tests { ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( FieldElementExpression::value(Bn128Field::from(2usize)).into(), )]) - .annotate(Type::FieldElement, 1u32), + .annotate(ArrayType::new(Type::FieldElement, 1u32)), ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( FieldElementExpression::value(Bn128Field::from(2usize)).into(), )]) - .annotate(Type::FieldElement, 1u32), + .annotate(ArrayType::new(Type::FieldElement, 1u32)), )); let e_constant_false = BooleanExpression::ArrayEq(BinaryExpression::new( ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( FieldElementExpression::value(Bn128Field::from(2usize)).into(), )]) - .annotate(Type::FieldElement, 1u32), + .annotate(ArrayType::new(Type::FieldElement, 1u32)), ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( FieldElementExpression::value(Bn128Field::from(4usize)).into(), )]) - .annotate(Type::FieldElement, 1u32), + .annotate(ArrayType::new(Type::FieldElement, 1u32)), )); let e_identifier_true: BooleanExpression = BooleanExpression::ArrayEq(BinaryExpression::new( - ArrayExpression::identifier("a".into()).annotate(Type::FieldElement, 1u32), - ArrayExpression::identifier("a".into()).annotate(Type::FieldElement, 1u32), + ArrayExpression::identifier("a".into()) + .annotate(ArrayType::new(Type::FieldElement, 1u32)), + ArrayExpression::identifier("a".into()) + .annotate(ArrayType::new(Type::FieldElement, 1u32)), )); let e_identifier_unchanged: BooleanExpression = BooleanExpression::ArrayEq(BinaryExpression::new( - ArrayExpression::identifier("a".into()).annotate(Type::FieldElement, 1u32), - ArrayExpression::identifier("b".into()).annotate(Type::FieldElement, 1u32), + ArrayExpression::identifier("a".into()) + .annotate(ArrayType::new(Type::FieldElement, 1u32)), + ArrayExpression::identifier("b".into()) + .annotate(ArrayType::new(Type::FieldElement, 1u32)), )); let e_non_canonical_true = BooleanExpression::ArrayEq(BinaryExpression::new( @@ -1872,14 +1872,14 @@ mod tests { ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( FieldElementExpression::value(Bn128Field::from(2usize)).into(), )]) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), )]) - .annotate(Type::FieldElement, 1u32), + .annotate(ArrayType::new(Type::FieldElement, 1u32)), ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( FieldElementExpression::value(Bn128Field::from(2usize)).into(), )]) - .annotate(Type::FieldElement, 1u32), + .annotate(ArrayType::new(Type::FieldElement, 1u32)), )); let e_non_canonical_false = BooleanExpression::ArrayEq(BinaryExpression::new( @@ -1887,14 +1887,14 @@ mod tests { ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( FieldElementExpression::value(Bn128Field::from(2usize)).into(), )]) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), )]) - .annotate(Type::FieldElement, 1u32), + .annotate(ArrayType::new(Type::FieldElement, 1u32)), ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( FieldElementExpression::value(Bn128Field::from(4usize)).into(), )]) - .annotate(Type::FieldElement, 1u32), + .annotate(ArrayType::new(Type::FieldElement, 1u32)), )); assert_eq!( diff --git a/zokrates_analysis/src/reducer/mod.rs b/zokrates_analysis/src/reducer/mod.rs index 3328bf5cb..c248cccf2 100644 --- a/zokrates_analysis/src/reducer/mod.rs +++ b/zokrates_analysis/src/reducer/mod.rs @@ -437,7 +437,7 @@ mod tests { use zokrates_ast::typed::types::DeclarationSignature; use zokrates_ast::typed::types::{DeclarationConstant, GTupleType}; use zokrates_ast::typed::{ - ArrayExpression, DeclarationFunctionKey, DeclarationType, DeclarationVariable, + ArrayExpression, ArrayType, DeclarationFunctionKey, DeclarationType, DeclarationVariable, FieldElementExpression, GenericIdentifier, Identifier, Select, TupleExpression, TupleType, Type, TypedExpression, TypedExpressionOrSpread, UBitwidth, Variable, }; @@ -650,7 +650,7 @@ mod tests { .into()], statements: vec![TypedStatement::ret( ArrayExpression::identifier("a".into()) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), )], signature: foo_signature.clone(), @@ -674,7 +674,7 @@ mod tests { ArrayExpression::value(vec![ FieldElementExpression::identifier("a".into()).into() ]) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::definition( @@ -684,10 +684,10 @@ mod tests { .signature(foo_signature.clone()), vec![None], vec![ArrayExpression::identifier("b".into()) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into()], ) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::definition( @@ -700,7 +700,7 @@ mod tests { (FieldElementExpression::identifier("a".into()) + FieldElementExpression::select( ArrayExpression::identifier("b".into()) - .annotate(Type::FieldElement, 1u32), + .annotate(ArrayType::new(Type::FieldElement, 1u32)), 0u32, )) .into(), @@ -750,28 +750,28 @@ mod tests { ArrayExpression::value(vec![ FieldElementExpression::identifier("a".into()).into() ]) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::definition( Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32) .into(), ArrayExpression::identifier("b".into()) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::definition( Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32) .into(), ArrayExpression::identifier(Identifier::from("a").in_frame(1)) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::ret( (FieldElementExpression::identifier("a".into()) + FieldElementExpression::select( ArrayExpression::identifier(Identifier::from("b").version(1)) - .annotate(Type::FieldElement, 1u32), + .annotate(ArrayType::new(Type::FieldElement, 1u32)), 0u32, )) .into(), @@ -850,7 +850,7 @@ mod tests { .into()], statements: vec![TypedStatement::ret( ArrayExpression::identifier("a".into()) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), )], signature: foo_signature.clone(), @@ -884,7 +884,7 @@ mod tests { ArrayExpression::value(vec![ FieldElementExpression::identifier("a".into()).into() ]) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::definition( @@ -894,10 +894,10 @@ mod tests { .signature(foo_signature.clone()), vec![None], vec![ArrayExpression::identifier("b".into()) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into()], ) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::definition( @@ -910,7 +910,7 @@ mod tests { (FieldElementExpression::identifier("a".into()) + FieldElementExpression::select( ArrayExpression::identifier("b".into()) - .annotate(Type::FieldElement, 1u32), + .annotate(ArrayType::new(Type::FieldElement, 1u32)), 0u32, )) .into(), @@ -960,28 +960,28 @@ mod tests { ArrayExpression::value(vec![ FieldElementExpression::identifier("a".into()).into() ]) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::definition( Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32) .into(), ArrayExpression::identifier("b".into()) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::definition( Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32) .into(), ArrayExpression::identifier(Identifier::from("a").in_frame(1)) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::ret( (FieldElementExpression::identifier("a".into()) + FieldElementExpression::select( ArrayExpression::identifier(Identifier::from("b").version(1)) - .annotate(Type::FieldElement, 1u32), + .annotate(ArrayType::new(Type::FieldElement, 1u32)), 0u32, )) .into(), @@ -1076,23 +1076,23 @@ mod tests { vec![ArrayExpression::value(vec![ TypedExpressionOrSpread::Spread( ArrayExpression::identifier("a".into()) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), FieldElementExpression::value(Bn128Field::from(0)).into(), ]) - .annotate( + .annotate(ArrayType::new( Type::FieldElement, UExpression::identifier("K".into()).annotate(UBitwidth::B32) + 1u32.into(), - ) + )) .into()], ) - .annotate( + .annotate(ArrayType::new( Type::FieldElement, UExpression::identifier("K".into()).annotate(UBitwidth::B32) + 1u32.into(), - ), + )), 0u32.into(), UExpression::identifier("K".into()).annotate(UBitwidth::B32), ) @@ -1100,10 +1100,10 @@ mod tests { ), TypedStatement::ret( ArrayExpression::identifier("ret".into()) - .annotate( + .annotate(ArrayType::new( Type::FieldElement, UExpression::identifier("K".into()).annotate(UBitwidth::B32), - ) + )) .into(), ), ], @@ -1121,10 +1121,10 @@ mod tests { .into()], statements: vec![TypedStatement::ret( ArrayExpression::identifier("a".into()) - .annotate( + .annotate(ArrayType::new( Type::FieldElement, UExpression::identifier("K".into()).annotate(UBitwidth::B32), - ) + )) .into(), )], signature: bar_signature.clone(), @@ -1143,10 +1143,10 @@ mod tests { Bn128Field::from(1), ) .into()]) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into()], ) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::ret( @@ -1257,7 +1257,7 @@ mod tests { .into()], statements: vec![TypedStatement::ret( ArrayExpression::identifier("a".into()) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), )], signature: foo_signature.clone(), @@ -1273,10 +1273,10 @@ mod tests { .signature(foo_signature.clone()), vec![None], vec![ArrayExpression::value(vec![]) - .annotate(Type::FieldElement, 0u32) + .annotate(ArrayType::new(Type::FieldElement, 0u32)) .into()], ) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::ret( diff --git a/zokrates_analysis/src/reducer/shallow_ssa.rs b/zokrates_analysis/src/reducer/shallow_ssa.rs index 8bb244a73..46f9cd962 100644 --- a/zokrates_analysis/src/reducer/shallow_ssa.rs +++ b/zokrates_analysis/src/reducer/shallow_ssa.rs @@ -103,15 +103,15 @@ impl<'ast> ShallowTransformer<'ast> { ) -> TypedAssignee<'ast, T> { match a { TypedAssignee::Identifier(v) => TypedAssignee::Identifier(self.fold_variable(v)), - TypedAssignee::Select(box a, box index) => TypedAssignee::Select( - box self.fold_assignee_no_ssa_increase(a), - box self.fold_uint_expression(index), + TypedAssignee::Select(a, index) => TypedAssignee::select( + self.fold_assignee_no_ssa_increase(*a), + self.fold_uint_expression(*index), ), - TypedAssignee::Member(box s, m) => { - TypedAssignee::Member(box self.fold_assignee_no_ssa_increase(s), m) + TypedAssignee::Member(s, m) => { + TypedAssignee::member(self.fold_assignee_no_ssa_increase(*s), m) } - TypedAssignee::Element(box s, index) => { - TypedAssignee::Element(box self.fold_assignee_no_ssa_increase(s), index) + TypedAssignee::Element(s, index) => { + TypedAssignee::element(self.fold_assignee_no_ssa_increase(*s), index) } } } @@ -409,7 +409,7 @@ mod tests { FieldElementExpression::value(Bn128Field::from(1)).into(), FieldElementExpression::value(Bn128Field::from(1)).into(), ]) - .annotate(Type::FieldElement, 2u32) + .annotate(ArrayType::new(Type::FieldElement, 2u32)) .into(), ); @@ -425,15 +425,19 @@ mod tests { FieldElementExpression::value(Bn128Field::from(1)).into(), FieldElementExpression::value(Bn128Field::from(1)).into() ]) - .annotate(Type::FieldElement, 2u32) + .annotate(ArrayType::new(Type::FieldElement, 2u32)) .into() )] ); let s: TypedStatement = TypedStatement::definition( TypedAssignee::Select( - box TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)), - box UExpression::from(1u32), + Box::new(TypedAssignee::Identifier(Variable::array( + "a", + Type::FieldElement, + 2u32, + ))), + Box::new(UExpression::from(1u32)), ), FieldElementExpression::value(Bn128Field::from(2)).into(), ); @@ -461,16 +465,19 @@ mod tests { FieldElementExpression::value(Bn128Field::from(0)).into(), FieldElementExpression::value(Bn128Field::from(1)).into(), ]) - .annotate(Type::FieldElement, 2u32) + .annotate(ArrayType::new(Type::FieldElement, 2u32)) .into(), ArrayExpression::value(vec![ FieldElementExpression::value(Bn128Field::from(2)).into(), FieldElementExpression::value(Bn128Field::from(3)).into(), ]) - .annotate(Type::FieldElement, 2u32) + .annotate(ArrayType::new(Type::FieldElement, 2u32)) .into(), ]) - .annotate(Type::array((Type::FieldElement, 2u32)), 2u32) + .annotate(ArrayType::new( + Type::array((Type::FieldElement, 2u32)), + 2u32, + )) .into(), ); @@ -486,30 +493,33 @@ mod tests { FieldElementExpression::value(Bn128Field::from(0)).into(), FieldElementExpression::value(Bn128Field::from(1)).into(), ]) - .annotate(Type::FieldElement, 2u32) + .annotate(ArrayType::new(Type::FieldElement, 2u32)) .into(), ArrayExpression::value(vec![ FieldElementExpression::value(Bn128Field::from(2)).into(), FieldElementExpression::value(Bn128Field::from(3)).into(), ]) - .annotate(Type::FieldElement, 2u32) + .annotate(ArrayType::new(Type::FieldElement, 2u32)) .into(), ]) - .annotate(Type::array((Type::FieldElement, 2u32)), 2u32) + .annotate(ArrayType::new( + Type::array((Type::FieldElement, 2u32)), + 2u32 + )) .into(), )] ); let s: TypedStatement = TypedStatement::definition( - TypedAssignee::Select( - box TypedAssignee::Identifier(Variable::new("a", array_of_array_ty.clone())), - box UExpression::from(1u32), + TypedAssignee::select( + TypedAssignee::Identifier(Variable::new("a", array_of_array_ty.clone())), + UExpression::from(1u32), ), ArrayExpression::value(vec![ FieldElementExpression::value(Bn128Field::from(4)).into(), FieldElementExpression::value(Bn128Field::from(5)).into(), ]) - .annotate(Type::FieldElement, 2u32) + .annotate(ArrayType::new(Type::FieldElement, 2u32)) .into(), ); diff --git a/zokrates_analysis/src/struct_concretizer.rs b/zokrates_analysis/src/struct_concretizer.rs index 932a92670..6e57c0484 100644 --- a/zokrates_analysis/src/struct_concretizer.rs +++ b/zokrates_analysis/src/struct_concretizer.rs @@ -70,7 +70,7 @@ impl<'ast, T: Field> Folder<'ast, T> for StructConcretizer<'ast, T> { .collect(), generics: concrete_generics .into_iter() - .map(|g| Some(DeclarationConstant::Concrete(g as u32))) + .map(|g| Some(DeclarationConstant::Concrete(g))) .collect(), ..ty } @@ -82,9 +82,9 @@ impl<'ast, T: Field> Folder<'ast, T> for StructConcretizer<'ast, T> { ) -> DeclarationArrayType<'ast, T> { let size = ty.size.map_concrete(&self.generics).unwrap(); - DeclarationArrayType { - size: box DeclarationConstant::Concrete(size), - ty: box self.fold_declaration_type(*ty.ty), - } + DeclarationArrayType::new( + self.fold_declaration_type(*ty.ty), + DeclarationConstant::Concrete(size), + ) } } diff --git a/zokrates_analysis/src/variable_write_remover.rs b/zokrates_analysis/src/variable_write_remover.rs index 9defeeaf5..2d9ce8685 100644 --- a/zokrates_analysis/src/variable_write_remover.rs +++ b/zokrates_analysis/src/variable_write_remover.rs @@ -56,7 +56,7 @@ impl<'ast> VariableWriteRemover { let tail = indices; match head { - Access::Select(box head) => { + Access::Select(head) => { statements.insert(TypedStatement::assertion( BooleanExpression::uint_lt( head.clone(), @@ -266,7 +266,7 @@ impl<'ast> VariableWriteRemover { }) .collect::>(), ) - .annotate(inner_ty.clone(), size) + .annotate(ArrayType::new(inner_ty.clone(), size)) .into() } _ => unreachable!(), @@ -525,27 +525,28 @@ impl<'ast> VariableWriteRemover { #[derive(Clone, Debug)] enum Access<'ast, T: Field> { - Select(Box>), + Select(UExpression<'ast, T>), Member(MemberId), Element(u32), } + /// Turn an assignee into its representation as a base variable and a list accesses /// a[2][3][4] -> (a, [2, 3, 4]) fn linear(a: TypedAssignee) -> (Variable, Vec>) { match a { TypedAssignee::Identifier(v) => (v, vec![]), - TypedAssignee::Select(box array, box index) => { - let (v, mut indices) = linear(array); - indices.push(Access::Select(box index)); + TypedAssignee::Select(array, index) => { + let (v, mut indices) = linear(*array); + indices.push(Access::Select(*index)); (v, indices) } - TypedAssignee::Member(box s, m) => { - let (v, mut indices) = linear(s); + TypedAssignee::Member(s, m) => { + let (v, mut indices) = linear(*s); indices.push(Access::Member(m)); (v, indices) } - TypedAssignee::Element(box s, i) => { - let (v, mut indices) = linear(s); + TypedAssignee::Element(s, i) => { + let (v, mut indices) = linear(*s); indices.push(Access::Element(i)); (v, indices) } @@ -555,12 +556,12 @@ fn linear(a: TypedAssignee) -> (Variable, Vec>) { fn is_constant(assignee: &TypedAssignee) -> bool { match assignee { TypedAssignee::Identifier(_) => true, - TypedAssignee::Select(box assignee, box index) => match index.as_inner() { + TypedAssignee::Select(assignee, index) => match index.as_inner() { UExpressionInner::Value(_) => is_constant(assignee), _ => false, }, - TypedAssignee::Member(box assignee, _) => is_constant(assignee), - TypedAssignee::Element(box assignee, _) => is_constant(assignee), + TypedAssignee::Member(ref assignee, _) => is_constant(assignee), + TypedAssignee::Element(ref assignee, _) => is_constant(assignee), } } @@ -616,7 +617,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for VariableWriteRemover { .annotate(bitwidth) .into(), Type::Array(array_type) => ArrayExpression::identifier(variable.id.clone()) - .annotate(*array_type.ty, *array_type.size) + .annotate(array_type) .into(), Type::Struct(members) => StructExpression::identifier(variable.id.clone()) .annotate(members) @@ -633,9 +634,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for VariableWriteRemover { let indices = indices .into_iter() .map(|a| match a { - Access::Select(box i) => { - Ok(Access::Select(box self.fold_uint_expression(i)?)) - } + Access::Select(i) => Ok(Access::Select(self.fold_uint_expression(i)?)), a => Ok(a), }) .collect::>()?; diff --git a/zokrates_ast/Cargo.toml b/zokrates_ast/Cargo.toml index 39ba46946..38d79d775 100644 --- a/zokrates_ast/Cargo.toml +++ b/zokrates_ast/Cargo.toml @@ -9,6 +9,7 @@ bellman = ["zokrates_field/bellman", "pairing_ce", "zokrates_embed/bellman"] ark = ["ark-bls12-377", "zokrates_embed/ark"] [dependencies] +log = "0.4" byteorder = "1.4.3" zokrates_pest_ast = { version = "0.3.0", path = "../zokrates_pest_ast" } cfg-if = "0.1" diff --git a/zokrates_ast/src/common/embed.rs b/zokrates_ast/src/common/embed.rs index 18e902716..de5d3ad07 100644 --- a/zokrates_ast/src/common/embed.rs +++ b/zokrates_ast/src/common/embed.rs @@ -301,7 +301,7 @@ impl FlatEmbed { ); assert_eq!(gen.len(), assignment.0.len()); - gen.map(|g| *assignment.0.get(&g).unwrap() as u32).collect() + gen.map(|g| *assignment.0.get(&g).unwrap()).collect() } pub fn id(&self) -> &'static str { @@ -355,15 +355,11 @@ pub fn sha256_round<'ast, T: Field>( let cs_indices = 0..variable_count; // indices of the arguments to the function // apply an offset of `variable_count` to get the indice of our dummy `input` argument - let input_argument_indices: Vec<_> = input_indices - .clone() - .into_iter() - .map(|i| i + variable_count) - .collect(); + let input_argument_indices: Vec<_> = + input_indices.clone().map(|i| i + variable_count).collect(); // apply an offset of `variable_count` to get the indice of our dummy `current_hash` argument let current_hash_argument_indices: Vec<_> = current_hash_indices .clone() - .into_iter() .map(|i| i + variable_count) .collect(); // define parameters to the function based on the variables diff --git a/zokrates_ast/src/common/expressions.rs b/zokrates_ast/src/common/expressions.rs index 4b61b1f82..01481ce6a 100644 --- a/zokrates_ast/src/common/expressions.rs +++ b/zokrates_ast/src/common/expressions.rs @@ -23,8 +23,8 @@ impl BinaryExpression { pub fn new(left: L, right: R) -> Self { Self { span: None, - left: box left, - right: box right, + left: Box::new(left), + right: Box::new(right), operator: PhantomData, output: PhantomData, } @@ -72,7 +72,7 @@ impl UnaryExpression { pub fn new(inner: In) -> Self { Self { span: None, - inner: box inner, + inner: Box::new(inner), operator: PhantomData, output: PhantomData, } diff --git a/zokrates_ast/src/flat/mod.rs b/zokrates_ast/src/flat/mod.rs index e1a86e7a9..44bd42573 100644 --- a/zokrates_ast/src/flat/mod.rs +++ b/zokrates_ast/src/flat/mod.rs @@ -386,16 +386,10 @@ impl FlatExpression { FlatExpression::Add(ref e) => e.left.is_linear() && e.right.is_linear(), FlatExpression::Sub(ref e) => e.left.is_linear() && e.right.is_linear(), FlatExpression::Mult(ref e) => matches!( - (&e.left, &e.right), - (box FlatExpression::Value(_), box FlatExpression::Value(_)) - | ( - box FlatExpression::Value(_), - box FlatExpression::Identifier(_) - ) - | ( - box FlatExpression::Identifier(_), - box FlatExpression::Value(_) - ) + (&*e.left, &*e.right), + (FlatExpression::Value(_), FlatExpression::Value(_)) + | (FlatExpression::Value(_), FlatExpression::Identifier(_)) + | (FlatExpression::Identifier(_), FlatExpression::Value(_)) ), } } diff --git a/zokrates_ast/src/ir/expression.rs b/zokrates_ast/src/ir/expression.rs index 5394b6eac..b3dd59478 100644 --- a/zokrates_ast/src/ir/expression.rs +++ b/zokrates_ast/src/ir/expression.rs @@ -29,6 +29,7 @@ impl WithSpan for QuadComb { } impl QuadComb { + #[allow(clippy::result_large_err)] pub fn try_linear(self) -> Result, Self> { // identify `(k * ~ONE) * (lincomb)` and `(lincomb) * (k * ~ONE)` and return (k * lincomb) // if not, error out with the input diff --git a/zokrates_ast/src/ir/mod.rs b/zokrates_ast/src/ir/mod.rs index 2adbd802c..e3aecebf1 100644 --- a/zokrates_ast/src/ir/mod.rs +++ b/zokrates_ast/src/ir/mod.rs @@ -115,6 +115,7 @@ impl<'ast, T: Field> fmt::Display for BlockStatement<'ast, T> { } } +#[allow(clippy::large_enum_variant)] #[derive(Debug, Serialize, Deserialize, Clone, Derivative)] #[derivative(Hash, PartialEq, Eq)] pub enum Statement<'ast, T> { diff --git a/zokrates_ast/src/lib.rs b/zokrates_ast/src/lib.rs index 797b85619..084173f84 100644 --- a/zokrates_ast/src/lib.rs +++ b/zokrates_ast/src/lib.rs @@ -1,5 +1,3 @@ -#![feature(box_patterns, box_syntax)] - pub mod common; pub mod flat; pub mod ir; diff --git a/zokrates_ast/src/typed/abi.rs b/zokrates_ast/src/typed/abi.rs index 7c9d995fa..88a0da9e4 100644 --- a/zokrates_ast/src/typed/abi.rs +++ b/zokrates_ast/src/typed/abi.rs @@ -22,7 +22,7 @@ impl Abi { ConcreteSignature { generics: vec![], inputs: self.inputs.iter().map(|i| i.ty.clone()).collect(), - output: box self.output.clone(), + output: Box::new(self.output.clone()), } } } diff --git a/zokrates_ast/src/typed/folder.rs b/zokrates_ast/src/typed/folder.rs index bc05d026f..424109bf3 100644 --- a/zokrates_ast/src/typed/folder.rs +++ b/zokrates_ast/src/typed/folder.rs @@ -167,8 +167,8 @@ pub trait Folder<'ast, T: Field>: Sized { fn fold_array_type(&mut self, t: ArrayType<'ast, T>) -> ArrayType<'ast, T> { ArrayType { - ty: box self.fold_type(*t.ty), - size: box self.fold_uint_expression(*t.size), + ty: Box::new(self.fold_type(*t.ty)), + size: Box::new(self.fold_uint_expression(*t.size)), } } @@ -189,7 +189,7 @@ pub trait Folder<'ast, T: Field>: Sized { .members .into_iter() .map(|m| StructMember { - ty: box self.fold_type(*m.ty), + ty: Box::new(self.fold_type(*m.ty)), ..m }) .collect(), @@ -213,8 +213,8 @@ pub trait Folder<'ast, T: Field>: Sized { t: DeclarationArrayType<'ast, T>, ) -> DeclarationArrayType<'ast, T> { DeclarationArrayType { - ty: box self.fold_declaration_type(*t.ty), - size: box self.fold_declaration_constant(*t.size), + ty: Box::new(self.fold_declaration_type(*t.ty)), + size: Box::new(self.fold_declaration_constant(*t.size)), } } @@ -245,7 +245,7 @@ pub trait Folder<'ast, T: Field>: Sized { .members .into_iter() .map(|m| DeclarationStructMember { - ty: box self.fold_declaration_type(*m.ty), + ty: Box::new(self.fold_declaration_type(*m.ty)), ..m }) .collect(), @@ -1559,7 +1559,7 @@ fn fold_signature<'ast, T: Field, F: Folder<'ast, T>>( .into_iter() .map(|o| f.fold_declaration_type(o)) .collect(), - output: box f.fold_declaration_type(*s.output), + output: Box::new(f.fold_declaration_type(*s.output)), } } @@ -1584,7 +1584,7 @@ pub fn fold_array_expression<'ast, T: Field, F: Folder<'ast, T>>( ArrayExpression { inner: f.fold_array_expression_inner(&ty, e.inner), - ty: box ty, + ty: Box::new(ty), } } @@ -1709,12 +1709,13 @@ pub fn fold_assignee<'ast, T: Field, F: Folder<'ast, T>>( ) -> TypedAssignee<'ast, T> { match a { TypedAssignee::Identifier(v) => TypedAssignee::Identifier(f.fold_variable(v)), - TypedAssignee::Select(box a, box index) => { - TypedAssignee::Select(box f.fold_assignee(a), box f.fold_uint_expression(index)) - } - TypedAssignee::Member(box s, m) => TypedAssignee::Member(box f.fold_assignee(s), m), - TypedAssignee::Element(box s, index) => { - TypedAssignee::Element(box f.fold_assignee(s), index) + TypedAssignee::Select(a, index) => TypedAssignee::Select( + Box::new(f.fold_assignee(*a)), + Box::new(f.fold_uint_expression(*index)), + ), + TypedAssignee::Member(s, m) => TypedAssignee::Member(Box::new(f.fold_assignee(*s)), m), + TypedAssignee::Element(s, index) => { + TypedAssignee::Element(Box::new(f.fold_assignee(*s)), index) } } } diff --git a/zokrates_ast/src/typed/integer.rs b/zokrates_ast/src/typed/integer.rs index 8f8b11915..d5ba9ad58 100644 --- a/zokrates_ast/src/typed/integer.rs +++ b/zokrates_ast/src/typed/integer.rs @@ -110,7 +110,7 @@ impl<'ast, T: Clone> IntegerInference for StructType<'ast, T> { .zip(other.members.into_iter()) .map(|(m_t, m_u)| match m_t.ty.get_common_pattern(*m_u.ty) { Ok(ty) => DeclarationStructMember { - ty: box ty, + ty: Box::new(ty), id: m_t.id, }, Err(..) => unreachable!( @@ -590,7 +590,8 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> { }) .collect::, _>>()?; Ok(FieldElementExpression::select( - ArrayExpression::value(values).annotate(Type::FieldElement, size), + ArrayExpression::value(values) + .annotate(ArrayType::new(Type::FieldElement, size)), index, )) } @@ -712,7 +713,8 @@ impl<'ast, T: Field> UExpression<'ast, T> { }) .collect::, _>>()?; Ok(UExpression::select( - ArrayExpression::value(values).annotate(Type::Uint(*bitwidth), size), + ArrayExpression::value(values) + .annotate(ArrayType::new(Type::Uint(*bitwidth), size)), index, )) } @@ -772,12 +774,16 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> { let inner_ty = res.value[0].get_type().0; - Ok(ArrayExpressionInner::Value(res).annotate(inner_ty, *array_ty.size)) + let array_ty = ArrayType::new(inner_ty, *array_ty.size); + + Ok(ArrayExpressionInner::Value(res).annotate(array_ty)) } ArrayExpressionInner::Repeat(r) => { match &*target_array_ty.ty { GType::Int => { - Ok(ArrayExpressionInner::Repeat(r).annotate(Type::Int, *array_ty.size)) + let array_ty = ArrayType::new(Type::Int, *array_ty.size); + + Ok(ArrayExpressionInner::Repeat(r).annotate(array_ty)) } // try to align the repeated element to the target type t => TypedExpression::align_to_type(*r.e, t) @@ -785,16 +791,16 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> { let ty = e.get_type().clone(); ArrayExpressionInner::Repeat(RepeatExpression::new(e, *r.count)) - .annotate(ty, *array_ty.size) + .annotate(ArrayType::new(ty, *array_ty.size)) }) .map_err(|(e, _)| e), } } a => { if *target_array_ty.ty == *array_ty.ty { - Ok(a.annotate(*array_ty.ty, *array_ty.size)) + Ok(a.annotate(array_ty)) } else { - Err(a.annotate(*array_ty.ty, *array_ty.size).into()) + Err(a.annotate(array_ty).into()) } } } @@ -920,11 +926,12 @@ mod tests { let n: IntExpression = BigUint::from(42usize).into(); let n_a: ArrayExpression = ArrayExpressionInner::Value(ValueExpression::new(vec![n.clone().into()])) - .annotate(Type::Int, 1u32); + .annotate(ArrayType::new(Type::Int, 1u32)); let t: FieldElementExpression = Bn128Field::from(42).into(); let t_a: ArrayExpression = ArrayExpressionInner::Value(ValueExpression::new(vec![t.clone().into()])) - .annotate(Type::FieldElement, 1u32); + .annotate(ArrayType::new(Type::FieldElement, 1u32)); + let i: UExpression = 42u32.into(); let c: BooleanExpression = true.into(); @@ -981,11 +988,11 @@ mod tests { let n: IntExpression = BigUint::from(42usize).into(); let n_a: ArrayExpression = ArrayExpressionInner::Value(ValueExpression::new(vec![n.clone().into()])) - .annotate(Type::Int, 1u32); + .annotate(ArrayType::new(Type::Int, 1u32)); let t: UExpression = 42u32.into(); let t_a: ArrayExpression = ArrayExpressionInner::Value(ValueExpression::new(vec![t.clone().into()])) - .annotate(Type::Uint(UBitwidth::B32), 1u32); + .annotate(ArrayType::new(Type::Uint(UBitwidth::B32), 1u32)); let i: UExpression = 0u32.into(); let c: BooleanExpression = true.into(); diff --git a/zokrates_ast/src/typed/mod.rs b/zokrates_ast/src/typed/mod.rs index 53e2096c5..77789fe06 100644 --- a/zokrates_ast/src/typed/mod.rs +++ b/zokrates_ast/src/typed/mod.rs @@ -405,6 +405,20 @@ pub enum TypedAssignee<'ast, T> { Element(Box>, u32), } +impl<'ast, T> TypedAssignee<'ast, T> { + pub fn select(self, index: UExpression<'ast, T>) -> Self { + Self::Select(Box::new(self), Box::new(index)) + } + + pub fn member(self, member: MemberId) -> Self { + Self::Member(Box::new(self), member) + } + + pub fn element(self, index: u32) -> Self { + Self::Element(Box::new(self), index) + } +} + #[derive(Clone, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)] pub struct TypedSpread<'ast, T> { pub array: ArrayExpression<'ast, T>, @@ -1072,7 +1086,7 @@ impl<'ast, T, E> BlockExpression<'ast, T, E> { BlockExpression { span: None, statements, - value: box value, + value: Box::new(value), } } } @@ -1172,7 +1186,7 @@ impl<'ast, T, E> MemberExpression<'ast, T, E> { pub fn new(struc: StructExpression<'ast, T>, id: MemberId) -> Self { MemberExpression { span: None, - struc: box struc, + struc: Box::new(struc), id, ty: PhantomData, } @@ -1200,8 +1214,8 @@ impl<'ast, T, E> SelectExpression<'ast, T, E> { pub fn new(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self { SelectExpression { span: None, - array: box array, - index: box index, + array: Box::new(array), + index: Box::new(index), ty: PhantomData, } } @@ -1228,7 +1242,7 @@ impl<'ast, T, E> ElementExpression<'ast, T, E> { pub fn new(tuple: TupleExpression<'ast, T>, index: u32) -> Self { ElementExpression { span: None, - tuple: box tuple, + tuple: Box::new(tuple), index, ty: PhantomData, } @@ -1268,9 +1282,9 @@ impl<'ast, T, E> ConditionalExpression<'ast, T, E> { ) -> Self { ConditionalExpression { span: None, - condition: box condition, - consequence: box consequence, - alternative: box alternative, + condition: Box::new(condition), + consequence: Box::new(consequence), + alternative: Box::new(alternative), kind, } } @@ -1387,9 +1401,9 @@ impl<'ast, T: Field> From> for TupleExpression<'ast, T> { _ => unreachable!(), } } - TypedAssignee::Select(box a, box index) => TupleExpression::select(a.into(), index), - TypedAssignee::Member(box a, id) => TupleExpression::member(a.into(), id), - TypedAssignee::Element(box a, index) => TupleExpression::element(a.into(), index), + TypedAssignee::Select(a, index) => TupleExpression::select((*a).into(), *index), + TypedAssignee::Member(a, id) => TupleExpression::member((*a).into(), id), + TypedAssignee::Element(a, index) => TupleExpression::element((*a).into(), index), } } } @@ -1404,9 +1418,9 @@ impl<'ast, T: Field> From> for StructExpression<'ast, T> _ => unreachable!(), } } - TypedAssignee::Select(box a, box index) => StructExpression::select(a.into(), index), - TypedAssignee::Member(box a, id) => StructExpression::member(a.into(), id), - TypedAssignee::Element(box a, index) => StructExpression::element(a.into(), index), + TypedAssignee::Select(a, index) => StructExpression::select((*a).into(), *index), + TypedAssignee::Member(a, id) => StructExpression::member((*a).into(), id), + TypedAssignee::Element(a, index) => StructExpression::element((*a).into(), index), } } } @@ -1417,13 +1431,13 @@ impl<'ast, T: Field> From> for ArrayExpression<'ast, T> { TypedAssignee::Identifier(v) => { let inner = ArrayExpression::identifier(v.id); match v.ty { - GType::Array(array_ty) => inner.annotate(*array_ty.ty, *array_ty.size), + GType::Array(array_ty) => inner.annotate(array_ty), _ => unreachable!(), } } - TypedAssignee::Select(box a, box index) => ArrayExpression::select(a.into(), index), - TypedAssignee::Member(box a, id) => ArrayExpression::member(a.into(), id), - TypedAssignee::Element(box a, index) => ArrayExpression::element(a.into(), index), + TypedAssignee::Select(a, index) => ArrayExpression::select((*a).into(), *index), + TypedAssignee::Member(a, id) => ArrayExpression::member((*a).into(), id), + TypedAssignee::Element(a, index) => ArrayExpression::element((*a).into(), index), } } } @@ -1432,13 +1446,9 @@ impl<'ast, T: Field> From> for FieldElementExpression<'as fn from(assignee: TypedAssignee<'ast, T>) -> Self { match assignee { TypedAssignee::Identifier(v) => FieldElementExpression::identifier(v.id), - TypedAssignee::Element(box a, index) => { - FieldElementExpression::element(a.into(), index) - } - TypedAssignee::Member(box a, id) => FieldElementExpression::member(a.into(), id), - TypedAssignee::Select(box a, box index) => { - FieldElementExpression::select(a.into(), index) - } + TypedAssignee::Element(a, index) => FieldElementExpression::element((*a).into(), index), + TypedAssignee::Member(a, id) => FieldElementExpression::member((*a).into(), id), + TypedAssignee::Select(a, index) => FieldElementExpression::select((*a).into(), *index), } } } @@ -1737,11 +1747,7 @@ impl<'ast, T: Field> ArrayValueExpression<'ast, T> { } a => (0..size.value) .map(|i| { - Some(U::select( - a.clone() - .annotate(*array_ty.ty.clone(), *array_ty.size.clone()), - i as u32, - )) + Some(U::select(a.clone().annotate(array_ty.clone()), i as u32)) }) .collect(), } @@ -1780,13 +1786,9 @@ pub enum ArrayExpressionInner<'ast, T> { } impl<'ast, T> ArrayExpressionInner<'ast, T> { - pub fn annotate>>( - self, - ty: Type<'ast, T>, - size: S, - ) -> ArrayExpression<'ast, T> { + pub fn annotate(self, ty: ArrayType<'ast, T>) -> ArrayExpression<'ast, T> { ArrayExpression { - ty: box (ty, size.into()).into(), + ty: Box::new(ty), inner: self, } } @@ -1808,13 +1810,15 @@ impl<'ast, T: Clone> ArrayExpression<'ast, T> { ) -> Self { let inner = array.inner_type().clone(); let size = to.clone() - from.clone(); - ArrayExpressionInner::Slice(SliceExpression::new(array, from, to)).annotate(inner, size) + let array_ty = ArrayType::new(inner, size); + ArrayExpressionInner::Slice(SliceExpression::new(array, from, to)).annotate(array_ty) } pub fn repeat(e: TypedExpression<'ast, T>, count: UExpression<'ast, T>) -> Self { let inner = e.get_type().clone(); let size = count.clone(); - ArrayExpressionInner::Repeat(RepeatExpression::new(e, count)).annotate(inner, size) + let array_ty = ArrayType::new(inner, size); + ArrayExpressionInner::Repeat(RepeatExpression::new(e, count)).annotate(array_ty) } } @@ -2205,9 +2209,7 @@ impl<'ast, T: Field> From> for TypedExpression<'ast, T> { match v.get_type() { Type::FieldElement => FieldElementExpression::identifier(v.id).into(), Type::Boolean => BooleanExpression::identifier(v.id).into(), - Type::Array(ty) => ArrayExpression::identifier(v.id) - .annotate(*ty.ty, *ty.size) - .into(), + Type::Array(ty) => ArrayExpression::identifier(v.id).annotate(ty).into(), Type::Struct(ty) => StructExpression::identifier(v.id).annotate(ty).into(), Type::Tuple(ty) => TupleExpression::identifier(v.id).annotate(ty).into(), Type::Uint(w) => UExpression::identifier(v.id).annotate(w).into(), @@ -2835,7 +2837,7 @@ impl<'ast, T: fmt::Display + Clone> Expr<'ast, T> for StructExpression<'ast, T> } } -impl<'ast, T: fmt::Display + Clone> Expr<'ast, T> for ArrayExpression<'ast, T> { +impl<'ast, T: Clone + fmt::Display> Expr<'ast, T> for ArrayExpression<'ast, T> { type Inner = ArrayExpressionInner<'ast, T>; type Ty = ArrayType<'ast, T>; type ConcreteTy = ConcreteArrayType; @@ -3040,22 +3042,21 @@ impl<'ast, T> Conditional<'ast, T> for UExpression<'ast, T> { } } -impl<'ast, T: Clone> Conditional<'ast, T> for ArrayExpression<'ast, T> { +impl<'ast, T: Clone + fmt::Display> Conditional<'ast, T> for ArrayExpression<'ast, T> { fn conditional( condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self, kind: ConditionalKind, ) -> Self { - let ty = consequence.inner_type().clone(); - let size = consequence.size(); + let ty = consequence.ty().clone(); ArrayExpressionInner::Conditional(ConditionalExpression::new( condition, consequence, alternative, kind, )) - .annotate(ty, size) + .annotate(ty) } } @@ -3144,13 +3145,9 @@ impl<'ast, T: Clone> Select<'ast, T> for UExpression<'ast, T> { impl<'ast, T: Clone> Select<'ast, T> for ArrayExpression<'ast, T> { fn select>>(array: ArrayExpression<'ast, T>, index: I) -> Self { - let (ty, size) = match array.inner_type() { - Type::Array(array_type) => (array_type.ty.clone(), array_type.size.clone()), - _ => unreachable!(), - }; + let array_ty = array.inner_type().clone().try_into().unwrap(); - ArrayExpressionInner::Select(SelectExpression::new(array, index.into())) - .annotate(*ty, *size) + ArrayExpressionInner::Select(SelectExpression::new(array, index.into())).annotate(array_ty) } } @@ -3194,56 +3191,60 @@ impl<'ast, T> Member<'ast, T> for BooleanExpression<'ast, T> { impl<'ast, T: Clone> Member<'ast, T> for UExpression<'ast, T> { fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { - let ty = s.ty().members.iter().find(|member| id == member.id); - let bitwidth = match ty { - Some(crate::typed::types::StructMember { - ty: box Type::Uint(bitwidth), - .. - }) => *bitwidth, - _ => unreachable!(), - }; + let ty = *s + .ty() + .members + .iter() + .find(|member| id == member.id) + .unwrap() + .ty + .clone(); + let bitwidth: UBitwidth = ty.try_into().unwrap(); UExpressionInner::Member(MemberExpression::new(s, id)).annotate(bitwidth) } } impl<'ast, T: Clone> Member<'ast, T> for ArrayExpression<'ast, T> { fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { - let ty = s.ty().members.iter().find(|member| id == member.id); - let (ty, size) = match ty { - Some(crate::typed::types::StructMember { - ty: box Type::Array(array_ty), - .. - }) => (*array_ty.ty.clone(), array_ty.size.clone()), - _ => unreachable!(), - }; - ArrayExpressionInner::Member(MemberExpression::new(s, id)).annotate(ty, *size) + let ty = *s + .ty() + .members + .iter() + .find(|member| id == member.id) + .unwrap() + .ty + .clone(); + let array_ty: ArrayType<'ast, T> = ty.try_into().unwrap(); + ArrayExpressionInner::Member(MemberExpression::new(s, id)).annotate(array_ty) } } impl<'ast, T: Clone> Member<'ast, T> for StructExpression<'ast, T> { fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { - let ty = s.ty().members.iter().find(|member| id == member.id); - let struct_ty = match ty { - Some(crate::typed::types::StructMember { - ty: box Type::Struct(struct_ty), - .. - }) => struct_ty.clone(), - _ => unreachable!(), - }; + let ty = *s + .ty() + .members + .iter() + .find(|member| id == member.id) + .unwrap() + .ty + .clone(); + let struct_ty = ty.try_into().unwrap(); StructExpressionInner::Member(MemberExpression::new(s, id)).annotate(struct_ty) } } impl<'ast, T: Clone> Member<'ast, T> for TupleExpression<'ast, T> { fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { - let ty = s.ty().members.iter().find(|member| id == member.id); - let tuple_ty = match ty { - Some(crate::typed::types::StructMember { - ty: box Type::Tuple(tuple_ty), - .. - }) => tuple_ty.clone(), - _ => unreachable!(), - }; + let ty = *s + .ty() + .members + .iter() + .find(|member| id == member.id) + .unwrap() + .ty + .clone(); + let tuple_ty = ty.try_into().unwrap(); TupleExpressionInner::Member(MemberExpression::new(s, id)).annotate(tuple_ty) } } @@ -3277,12 +3278,9 @@ impl<'ast, T: Clone> Element<'ast, T> for UExpression<'ast, T> { impl<'ast, T: Clone> Element<'ast, T> for ArrayExpression<'ast, T> { fn element(s: TupleExpression<'ast, T>, id: u32) -> Self { - let ty = &s.ty().elements[id as usize]; - let (ty, size) = match ty { - Type::Array(array_ty) => (*array_ty.ty.clone(), array_ty.size.clone()), - _ => unreachable!(), - }; - ArrayExpressionInner::Element(ElementExpression::new(s, id)).annotate(ty, *size) + let ty = s.ty().elements[id as usize].clone(); + let array_ty = ty.try_into().unwrap(); + ArrayExpressionInner::Element(ElementExpression::new(s, id)).annotate(array_ty) } } @@ -3442,8 +3440,7 @@ impl<'ast, T: Field> Block<'ast, T> for UExpression<'ast, T> { impl<'ast, T: Field> Block<'ast, T> for ArrayExpression<'ast, T> { fn block(statements: Vec>, value: Self) -> Self { let array_ty = value.ty().clone(); - ArrayExpressionInner::Block(BlockExpression::new(statements, value)) - .annotate(*array_ty.ty, *array_ty.size) + ArrayExpressionInner::Block(BlockExpression::new(statements, value)).annotate(array_ty) } } @@ -3561,7 +3558,7 @@ impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> { .map(|e| e.into()) .collect::>(), ) - .annotate(*array_ty.ty, *array_ty.size), + .annotate(array_ty), ArrayExpressionInner::Slice(e) => { let from = match e.from.into_inner() { UExpressionInner::Value(from) => from.value as usize, @@ -3586,7 +3583,7 @@ impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> { .take(to - from) .collect::>(), ) - .annotate(*array_ty.ty, *array_ty.size) + .annotate(array_ty) } ArrayExpressionInner::Repeat(e) => { let count = match e.count.into_inner() { @@ -3597,7 +3594,7 @@ impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> { let e = e.e.into_canonical_constant(); ArrayExpression::value(vec![TypedExpressionOrSpread::Expression(e); count]) - .annotate(*array_ty.ty, *array_ty.size) + .annotate(array_ty) } _ => unreachable!(), } diff --git a/zokrates_ast/src/typed/result_folder.rs b/zokrates_ast/src/typed/result_folder.rs index 531d45275..d8541add0 100644 --- a/zokrates_ast/src/typed/result_folder.rs +++ b/zokrates_ast/src/typed/result_folder.rs @@ -334,8 +334,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized { t: ArrayType<'ast, T>, ) -> Result, Self::Error> { Ok(ArrayType { - ty: box self.fold_type(*t.ty)?, - size: box self.fold_uint_expression(*t.size)?, + ty: Box::new(self.fold_type(*t.ty)?), + size: Box::new(self.fold_uint_expression(*t.size)?), }) } @@ -367,8 +367,10 @@ pub trait ResultFolder<'ast, T: Field>: Sized { .into_iter() .map(|m| { let id = m.id; - self.fold_type(*m.ty) - .map(|ty| StructMember { ty: box ty, id }) + self.fold_type(*m.ty).map(|ty| StructMember { + ty: Box::new(ty), + id, + }) }) .collect::>()?, ..t @@ -394,8 +396,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized { t: DeclarationArrayType<'ast, T>, ) -> Result, Self::Error> { Ok(DeclarationArrayType { - ty: box self.fold_declaration_type(*t.ty)?, - size: box self.fold_declaration_constant(*t.size)?, + ty: Box::new(self.fold_declaration_type(*t.ty)?), + size: Box::new(self.fold_declaration_constant(*t.size)?), }) } @@ -428,7 +430,10 @@ pub trait ResultFolder<'ast, T: Field>: Sized { .map(|m| { let id = m.id; self.fold_declaration_type(*m.ty) - .map(|ty| DeclarationStructMember { ty: box ty, id }) + .map(|ty| DeclarationStructMember { + ty: Box::new(ty), + id, + }) }) .collect::>()?, ..t @@ -958,14 +963,15 @@ pub fn fold_assignee<'ast, T: Field, F: ResultFolder<'ast, T>>( ) -> Result, F::Error> { match a { TypedAssignee::Identifier(v) => Ok(TypedAssignee::Identifier(f.fold_variable(v)?)), - TypedAssignee::Select(box a, box index) => Ok(TypedAssignee::Select( - box f.fold_assignee(a)?, - box f.fold_uint_expression(index)?, + TypedAssignee::Select(a, index) => Ok(TypedAssignee::Select( + Box::new(f.fold_assignee(*a)?), + Box::new(f.fold_uint_expression(*index)?), + )), + TypedAssignee::Member(s, m) => Ok(TypedAssignee::Member(Box::new(f.fold_assignee(*s)?), m)), + TypedAssignee::Element(s, index) => Ok(TypedAssignee::Element( + Box::new(f.fold_assignee(*s)?), + index, )), - TypedAssignee::Member(box s, m) => Ok(TypedAssignee::Member(box f.fold_assignee(s)?, m)), - TypedAssignee::Element(box s, index) => { - Ok(TypedAssignee::Element(box f.fold_assignee(s)?, index)) - } } } @@ -1683,7 +1689,7 @@ pub fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>( .into_iter() .map(|o| f.fold_declaration_type(o)) .collect::>()?, - output: box f.fold_declaration_type(*s.output)?, + output: Box::new(f.fold_declaration_type(*s.output)?), }) } @@ -1725,7 +1731,7 @@ pub fn fold_array_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( Ok(ArrayExpression { inner: f.fold_array_expression_inner(&ty, e.inner)?, - ty: box ty, + ty: Box::new(ty), }) } diff --git a/zokrates_ast/src/typed/types.rs b/zokrates_ast/src/typed/types.rs index 8ae15de8b..ece3547e7 100644 --- a/zokrates_ast/src/typed/types.rs +++ b/zokrates_ast/src/typed/types.rs @@ -282,7 +282,7 @@ impl<'ast, T> TryInto for DeclarationConstant<'ast, T> { pub type MemberId = String; -#[allow(clippy::derive_hash_xor_eq)] +#[allow(clippy::derived_hash_with_manual_eq)] #[derive(Debug, Clone, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] pub struct GStructMember { #[serde(rename = "name")] @@ -306,7 +306,7 @@ fn try_from_g_struct_member, U>( ) -> Result, SpecializationError> { Ok(GStructMember { id: t.id, - ty: box try_from_g_type(*t.ty)?, + ty: Box::new(try_from_g_type(*t.ty)?), }) } @@ -324,7 +324,7 @@ impl<'ast, T> From for StructMember<'ast, T> { } } -#[allow(clippy::derive_hash_xor_eq)] +#[allow(clippy::derived_hash_with_manual_eq)] #[derive(Clone, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Debug)] pub struct GArrayType { pub size: Box, @@ -372,8 +372,8 @@ fn try_from_g_array_type, U>( t: GArrayType, ) -> Result, SpecializationError> { Ok(GArrayType { - size: box (*t.size).try_into().map_err(|_| SpecializationError)?, - ty: box try_from_g_type(*t.ty)?, + size: Box::new((*t.size).try_into().map_err(|_| SpecializationError)?), + ty: Box::new(try_from_g_type(*t.ty)?), }) } @@ -391,7 +391,7 @@ impl<'ast, T> From for ArrayType<'ast, T> { } } -#[allow(clippy::derive_hash_xor_eq)] +#[allow(clippy::derived_hash_with_manual_eq)] #[derive(Clone, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Debug)] pub struct GTupleType { pub elements: Vec>, @@ -470,6 +470,42 @@ impl TryFrom> for GTupleType { } } +impl TryFrom> for GStructType { + type Error = (); + + fn try_from(t: GType) -> Result { + if let GType::Struct(t) = t { + Ok(t) + } else { + Err(()) + } + } +} + +impl TryFrom> for GArrayType { + type Error = (); + + fn try_from(t: GType) -> Result { + if let GType::Array(t) = t { + Ok(t) + } else { + Err(()) + } + } +} + +impl TryFrom> for UBitwidth { + type Error = (); + + fn try_from(t: GType) -> Result { + if let GType::Uint(t) = t { + Ok(t) + } else { + Err(()) + } + } +} + #[derive(Debug, Clone, Hash, Serialize, Deserialize, PartialOrd, Ord, Eq, PartialEq)] pub struct StructLocation { #[serde(skip)] @@ -643,7 +679,7 @@ impl fmt::Display for UBitwidth { } } -#[allow(clippy::derive_hash_xor_eq)] +#[allow(clippy::derived_hash_with_manual_eq)] #[derive(Clone, Eq, Hash, PartialOrd, Ord, Debug)] pub enum GType { FieldElement, @@ -821,8 +857,8 @@ impl<'ast, T> From for DeclarationType<'ast, T> { impl> From<(GType, U)> for GArrayType { fn from(tup: (GType, U)) -> Self { GArrayType { - ty: box tup.0, - size: box tup.1.into(), + ty: Box::new(tup.0), + size: Box::new(tup.1.into()), } } } @@ -830,8 +866,8 @@ impl> From<(GType, U)> for GArrayType { impl GArrayType { pub fn new>(ty: GType, size: U) -> Self { GArrayType { - ty: box ty, - size: box size.into(), + ty: Box::new(ty), + size: Box::new(size.into()), } } } @@ -1208,8 +1244,12 @@ pub fn specialize_declaration_type< .into_iter() .map(|m| { let id = m.id; - specialize_declaration_type(*m.ty, &inside_generics) - .map(|ty| GStructMember { ty: box ty, id }) + specialize_declaration_type(*m.ty, &inside_generics).map(|ty| { + GStructMember { + ty: Box::new(ty), + id, + } + }) }) .collect::>()?, generics: s0 @@ -1250,7 +1290,7 @@ pub mod signature { Self { generics: vec![], inputs: vec![], - output: box GType::Tuple(GTupleType::new(vec![])), + output: Box::new(GType::Tuple(GTupleType::new(vec![]))), } } } @@ -1399,7 +1439,7 @@ pub mod signature { .into_iter() .map(try_from_g_type) .collect::>()?, - output: box try_from_g_type(*t.output)?, + output: Box::new(try_from_g_type(*t.output)?), }) } diff --git a/zokrates_ast/src/typed/utils/mod.rs b/zokrates_ast/src/typed/utils/mod.rs index 36b6d6aea..ad2c406f0 100644 --- a/zokrates_ast/src/typed/utils/mod.rs +++ b/zokrates_ast/src/typed/utils/mod.rs @@ -1,6 +1,6 @@ use super::{ ArrayExpression, ArrayExpressionInner, BooleanExpression, Conditional, ConditionalKind, Expr, - FieldElementExpression, Id, Identifier, Select, Typed, TypedExpression, + FieldElementExpression, GArrayType, Id, Identifier, Select, Typed, TypedExpression, TypedExpressionOrSpread, UBitwidth, UExpression, ValueExpression, }; @@ -23,13 +23,15 @@ pub fn a< values: [E; N], ) -> ArrayExpression<'ast, T> { let ty = values[0].get_type(); + + let array_ty = GArrayType::new(ty, N as u32); ArrayExpression::value( values .into_iter() .map(|e| TypedExpressionOrSpread::Expression(e.into())) .collect(), ) - .annotate(ty, N as u32) + .annotate(array_ty) } pub fn u_32<'ast, T: Field, U: TryInto>(v: U) -> UExpression<'ast, T> { diff --git a/zokrates_ast/src/untyped/from_ast.rs b/zokrates_ast/src/untyped/from_ast.rs index 88c12d6ce..78e02b7d0 100644 --- a/zokrates_ast/src/untyped/from_ast.rs +++ b/zokrates_ast/src/untyped/from_ast.rs @@ -390,85 +390,85 @@ impl<'ast> From> for untyped::ExpressionNode<'ast> use crate::untyped::NodeValue; match expression.op { pest::BinaryOperator::Add => untyped::Expression::Add( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Sub => untyped::Expression::Sub( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Mul => untyped::Expression::Mult( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Div => untyped::Expression::Div( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Rem => untyped::Expression::Rem( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Eq => untyped::Expression::Eq( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Lt => untyped::Expression::Lt( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Lte => untyped::Expression::Le( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Gt => untyped::Expression::Gt( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Gte => untyped::Expression::Ge( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::And => untyped::Expression::And( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Or => untyped::Expression::Or( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Pow => untyped::Expression::Pow( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::BitXor => untyped::Expression::BitXor( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::LeftShift => untyped::Expression::LeftShift( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::RightShift => untyped::Expression::RightShift( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::BitAnd => untyped::Expression::BitAnd( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::BitOr => untyped::Expression::BitOr( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), // rewrite (a != b)` as `!(a == b)` - pest::BinaryOperator::NotEq => untyped::Expression::Not( - box untyped::Expression::Eq( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + pest::BinaryOperator::NotEq => untyped::Expression::Not(Box::new( + untyped::Expression::Eq( + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ) .span(expression.span.clone()), - ), + )), } .span(expression.span) } @@ -477,22 +477,22 @@ impl<'ast> From> for untyped::ExpressionNode<'ast> impl<'ast> From> for untyped::ExpressionNode<'ast> { fn from(expression: pest::IfElseExpression<'ast>) -> untyped::ExpressionNode<'ast> { use crate::untyped::NodeValue; - untyped::Expression::Conditional(box ConditionalExpression { - condition: box untyped::ExpressionNode::from(*expression.condition), + untyped::Expression::Conditional(Box::new(ConditionalExpression { + condition: Box::new(untyped::ExpressionNode::from(*expression.condition)), consequence_statements: expression .consequence_statements .into_iter() .map(untyped::StatementNode::from) .collect(), - consequence: box untyped::ExpressionNode::from(*expression.consequence), + consequence: Box::new(untyped::ExpressionNode::from(*expression.consequence)), alternative_statements: expression .alternative_statements .into_iter() .map(untyped::StatementNode::from) .collect(), - alternative: box untyped::ExpressionNode::from(*expression.alternative), + alternative: Box::new(untyped::ExpressionNode::from(*expression.alternative)), kind: untyped::ConditionalKind::IfElse, - }) + })) .span(expression.span) } } @@ -500,14 +500,14 @@ impl<'ast> From> for untyped::ExpressionNode<'ast> impl<'ast> From> for untyped::ExpressionNode<'ast> { fn from(expression: pest::TernaryExpression<'ast>) -> untyped::ExpressionNode<'ast> { use crate::untyped::NodeValue; - untyped::Expression::Conditional(box ConditionalExpression { - condition: box untyped::ExpressionNode::from(*expression.condition), + untyped::Expression::Conditional(Box::new(ConditionalExpression { + condition: Box::new(untyped::ExpressionNode::from(*expression.condition)), consequence_statements: vec![], - consequence: box untyped::ExpressionNode::from(*expression.consequence), + consequence: Box::new(untyped::ExpressionNode::from(*expression.consequence)), alternative_statements: vec![], - alternative: box untyped::ExpressionNode::from(*expression.alternative), + alternative: Box::new(untyped::ExpressionNode::from(*expression.alternative)), kind: untyped::ConditionalKind::Ternary, - }) + })) .span(expression.span) } } @@ -617,7 +617,8 @@ impl<'ast> From> for untyped::ExpressionN let value = untyped::ExpressionNode::from(*initializer.value); let count = untyped::ExpressionNode::from(*initializer.count); - untyped::Expression::ArrayInitializer(box value, box count).span(initializer.span) + untyped::Expression::ArrayInitializer(Box::new(value), Box::new(count)) + .span(initializer.span) } } @@ -674,18 +675,20 @@ impl<'ast> From> for untyped::ExpressionNode<'ast> ) .span(a.span), pest::Access::Select(a) => untyped::Expression::Select( - box acc, - box untyped::RangeOrExpression::from(a.expression), + Box::new(acc), + Box::new(untyped::RangeOrExpression::from(a.expression)), ) .span(a.span), pest::Access::Dot(m) => match m.inner { pest::IdentifierOrDecimal::Identifier(id) => { - untyped::Expression::Member(box acc, box id.span.as_str()).span(m.span) - } - pest::IdentifierOrDecimal::Decimal(id) => { - untyped::Expression::Element(box acc, id.span.as_str().parse().unwrap()) + untyped::Expression::Member(Box::new(acc), Box::new(id.span.as_str())) .span(m.span) } + pest::IdentifierOrDecimal::Decimal(id) => untyped::Expression::Element( + Box::new(acc), + id.span.as_str().parse().unwrap(), + ) + .span(m.span), }, }) } @@ -783,15 +786,15 @@ impl<'ast> From> for untyped::AssigneeNode<'ast> { assignee.accesses.into_iter().fold(a, |acc, s| { match s { pest::AssigneeAccess::Select(s) => untyped::Assignee::Select( - box acc, - box untyped::RangeOrExpression::from(s.expression), + Box::new(acc), + Box::new(untyped::RangeOrExpression::from(s.expression)), ), pest::AssigneeAccess::Dot(a) => match a.inner { pest::IdentifierOrDecimal::Identifier(id) => { - untyped::Assignee::Member(box acc, box id.span.as_str()) + untyped::Assignee::Member(Box::new(acc), Box::new(id.span.as_str())) } pest::IdentifierOrDecimal::Decimal(id) => { - untyped::Assignee::Element(box acc, id.span.as_str().parse().unwrap()) + untyped::Assignee::Element(Box::new(acc), id.span.as_str().parse().unwrap()) } }, } @@ -1032,29 +1035,33 @@ mod tests { ( "field[2]", untyped::UnresolvedType::Array( - box untyped::UnresolvedType::FieldElement.mock(), + Box::new(untyped::UnresolvedType::FieldElement.mock()), untyped::Expression::IntConstant(2usize.into()).mock(), ), ), ( "field[2][3]", untyped::UnresolvedType::Array( - box untyped::UnresolvedType::Array( - box untyped::UnresolvedType::FieldElement.mock(), - untyped::Expression::IntConstant(3usize.into()).mock(), - ) - .mock(), + Box::new( + untyped::UnresolvedType::Array( + Box::new(untyped::UnresolvedType::FieldElement.mock()), + untyped::Expression::IntConstant(3usize.into()).mock(), + ) + .mock(), + ), untyped::Expression::IntConstant(2usize.into()).mock(), ), ), ( "bool[2][3u32]", untyped::UnresolvedType::Array( - box untyped::UnresolvedType::Array( - box untyped::UnresolvedType::Boolean.mock(), - untyped::Expression::U32Constant(3u32).mock(), - ) - .mock(), + Box::new( + untyped::UnresolvedType::Array( + Box::new(untyped::UnresolvedType::Boolean.mock()), + untyped::Expression::U32Constant(3u32).mock(), + ) + .mock(), + ), untyped::Expression::IntConstant(2usize.into()).mock(), ), ), @@ -1099,59 +1106,59 @@ mod tests { ( "a[3]", untyped::Expression::Select( - box untyped::Expression::Identifier("a").into(), - box untyped::RangeOrExpression::Expression( + Box::new(untyped::Expression::Identifier("a").into()), + Box::new(untyped::RangeOrExpression::Expression( untyped::Expression::IntConstant(3usize.into()).into(), ), - ), + )), ), ( "a[3][4]", untyped::Expression::Select( - box untyped::Expression::Select( - box untyped::Expression::Identifier("a").into(), - box untyped::RangeOrExpression::Expression( + Box::new(untyped::Expression::Select( + Box::new(untyped::Expression::Identifier("a").into()), + Box::new(untyped::RangeOrExpression::Expression( untyped::Expression::IntConstant(3usize.into()).into(), - ), + )), ) - .into(), - box untyped::RangeOrExpression::Expression( + .into()), + Box::new(untyped::RangeOrExpression::Expression( untyped::Expression::IntConstant(4usize.into()).into(), - ), + )), ), ), ( "a(3)[4]", untyped::Expression::Select( - box untyped::Expression::FunctionCall( - box untyped::Expression::Identifier("a").mock(), + Box::new(untyped::Expression::FunctionCall( + Box::new(untyped::Expression::Identifier("a").mock()), None, vec![untyped::Expression::IntConstant(3usize.into()).into()], ) - .into(), - box untyped::RangeOrExpression::Expression( + .into()), + Box::new(untyped::RangeOrExpression::Expression( untyped::Expression::IntConstant(4usize.into()).into(), - ), + )), ), ), ( "a(3)[4][5]", untyped::Expression::Select( - box untyped::Expression::Select( - box untyped::Expression::FunctionCall( - box untyped::Expression::Identifier("a").mock(), + Box::new(untyped::Expression::Select( + Box::new(untyped::Expression::FunctionCall( + Box::new(untyped::Expression::Identifier("a").mock()), None, vec![untyped::Expression::IntConstant(3usize.into()).into()], ) - .into(), - box untyped::RangeOrExpression::Expression( + .into()), + Box::new(untyped::RangeOrExpression::Expression( untyped::Expression::IntConstant(4usize.into()).into(), - ), + )), ) - .into(), - box untyped::RangeOrExpression::Expression( + .into()), + Box::new(untyped::RangeOrExpression::Expression( untyped::Expression::IntConstant(5usize.into()).into(), - ), + )), ), ), ]; @@ -1172,13 +1179,15 @@ mod tests { assert_eq!( untyped::Module::from(ast), wrap(untyped::Expression::FunctionCall( - box untyped::Expression::Select( - box untyped::Expression::Identifier("a").mock(), - box untyped::RangeOrExpression::Expression( - untyped::Expression::IntConstant(2u32.into()).mock() + Box::new( + untyped::Expression::Select( + Box::new(untyped::Expression::Identifier("a").mock()), + Box::new(untyped::RangeOrExpression::Expression( + untyped::Expression::IntConstant(2u32.into()).mock() + )) ) - ) - .mock(), + .mock() + ), None, vec![untyped::Expression::IntConstant(3u32.into()).mock()], )) @@ -1194,12 +1203,14 @@ mod tests { assert_eq!( untyped::Module::from(ast), wrap(untyped::Expression::FunctionCall( - box untyped::Expression::FunctionCall( - box untyped::Expression::Identifier("a").mock(), - None, - vec![untyped::Expression::IntConstant(2u32.into()).mock()] - ) - .mock(), + Box::new( + untyped::Expression::FunctionCall( + Box::new(untyped::Expression::Identifier("a").mock()), + None, + vec![untyped::Expression::IntConstant(2u32.into()).mock()] + ) + .mock() + ), None, vec![untyped::Expression::IntConstant(3u32.into()).mock()], )) @@ -1261,10 +1272,10 @@ mod tests { span: span.clone(), }), expression: pest::Expression::Postfix(pest::PostfixExpression { - base: box pest::Expression::Identifier(pest::IdentifierExpression { + base: Box::new(pest::Expression::Identifier(pest::IdentifierExpression { value: String::from("foo"), span: span.clone(), - }), + })), accesses: vec![pest::Access::Call(pest::CallAccess { explicit_generics: None, arguments: pest::Arguments { diff --git a/zokrates_ast/src/untyped/types.rs b/zokrates_ast/src/untyped/types.rs index 0f6c4ba45..f2a87ad86 100644 --- a/zokrates_ast/src/untyped/types.rs +++ b/zokrates_ast/src/untyped/types.rs @@ -69,7 +69,7 @@ impl<'ast> fmt::Display for UnresolvedType<'ast> { impl<'ast> UnresolvedType<'ast> { pub fn array(ty: UnresolvedTypeNode<'ast>, size: ExpressionNode<'ast>) -> Self { - UnresolvedType::Array(box ty, size) + UnresolvedType::Array(Box::new(ty), size) } } diff --git a/zokrates_ast/src/zir/identifier.rs b/zokrates_ast/src/zir/identifier.rs index a7ae6b959..3d56b07ef 100644 --- a/zokrates_ast/src/zir/identifier.rs +++ b/zokrates_ast/src/zir/identifier.rs @@ -26,6 +26,20 @@ impl<'ast> Identifier<'ast> { } } +impl<'ast> SourceIdentifier<'ast> { + pub fn select(self, index: u32) -> Self { + Self::Select(Box::new(self), index) + } + + pub fn member(self, member: MemberId) -> Self { + Self::Member(Box::new(self), member) + } + + pub fn element(self, index: u32) -> Self { + Self::Element(Box::new(self), index) + } +} + impl<'ast> fmt::Display for Identifier<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -39,9 +53,9 @@ impl<'ast> fmt::Display for SourceIdentifier<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { SourceIdentifier::Basic(i) => write!(f, "{}", i), - SourceIdentifier::Select(box i, index) => write!(f, "{}~{}", i, index), - SourceIdentifier::Member(box i, m) => write!(f, "{}.{}", i, m), - SourceIdentifier::Element(box i, index) => write!(f, "{}.{}", i, index), + SourceIdentifier::Select(i, index) => write!(f, "{}~{}", i, index), + SourceIdentifier::Member(i, m) => write!(f, "{}.{}", i, m), + SourceIdentifier::Element(i, index) => write!(f, "{}.{}", i, index), } } } diff --git a/zokrates_ast/src/zir/mod.rs b/zokrates_ast/src/zir/mod.rs index dcb4cefaf..a7cc08019 100644 --- a/zokrates_ast/src/zir/mod.rs +++ b/zokrates_ast/src/zir/mod.rs @@ -490,9 +490,9 @@ impl<'ast, T, E> ConditionalExpression<'ast, T, E> { pub fn new(condition: BooleanExpression<'ast, T>, consequence: E, alternative: E) -> Self { ConditionalExpression { span: None, - condition: box condition, - consequence: box consequence, - alternative: box alternative, + condition: Box::new(condition), + consequence: Box::new(consequence), + alternative: Box::new(alternative), } } } @@ -534,7 +534,7 @@ impl<'ast, T, E> SelectExpression<'ast, T, E> { SelectExpression { span: None, array, - index: box index, + index: Box::new(index), } } } diff --git a/zokrates_ast/src/zir/uint.rs b/zokrates_ast/src/zir/uint.rs index dfb0b3ec9..3ace2a152 100644 --- a/zokrates_ast/src/zir/uint.rs +++ b/zokrates_ast/src/zir/uint.rs @@ -153,7 +153,7 @@ impl UMetadata { } pub fn bitwidth(&self) -> u32 { - self.max.bits() as u32 + self.max.bits() } // issue the metadata for a parameter of a given bitwidth diff --git a/zokrates_book/src/gettingstarted.md b/zokrates_book/src/gettingstarted.md index 93caf96c2..d0aa512f3 100644 --- a/zokrates_book/src/gettingstarted.md +++ b/zokrates_book/src/gettingstarted.md @@ -25,7 +25,7 @@ You can build ZoKrates from [source](https://github.com/ZoKrates/ZoKrates/) with git clone https://github.com/ZoKrates/ZoKrates cd ZoKrates export ZOKRATES_STDLIB=$PWD/zokrates_stdlib/stdlib -cargo +nightly build -p zokrates_cli --release +cargo build -p zokrates_cli --release cd target/release ``` diff --git a/zokrates_cli/src/bin.rs b/zokrates_cli/src/bin.rs index a0aa12882..68e7d5ec4 100644 --- a/zokrates_cli/src/bin.rs +++ b/zokrates_cli/src/bin.rs @@ -1,5 +1,3 @@ -#![feature(panic_info_message)] -#![feature(backtrace)] // // @file bin.rs // @author Jacob Eberhardt @@ -86,21 +84,8 @@ fn cli() -> Result<(), String> { } fn panic_hook(pi: &std::panic::PanicInfo) { - let location = pi - .location() - .map(|l| format!("({})", l)) - .unwrap_or_default(); - - let message = pi - .message() - .map(|m| format!("{}", m)) - .or_else(|| pi.payload().downcast_ref::<&str>().map(|p| p.to_string())); - - if let Some(s) = message { - println!("{} {}", s, location); - } else { - println!("The compiler unexpectedly panicked {}", location); - } + println!("The compiler unexpectedly panicked"); + println!("{}", pi); #[cfg(debug_assertions)] { diff --git a/zokrates_cli/src/ops/compile.rs b/zokrates_cli/src/ops/compile.rs index 35d47b3f8..c7bc25a15 100644 --- a/zokrates_cli/src/ops/compile.rs +++ b/zokrates_cli/src/ops/compile.rs @@ -145,10 +145,10 @@ fn cli_compile(sub_matches: &ArgMatches) -> Result<(), String> { // serialize flattened program and write to binary file log::debug!("Serialize program"); - let bin_output_file = File::create(&bin_output_path) + let bin_output_file = File::create(bin_output_path) .map_err(|why| format!("Could not create {}: {}", bin_output_path.display(), why))?; - let r1cs_output_file = File::create(&r1cs_output_path) + let r1cs_output_file = File::create(r1cs_output_path) .map_err(|why| format!("Could not create {}: {}", r1cs_output_path.display(), why))?; let mut bin_writer = BufWriter::new(bin_output_file); @@ -170,7 +170,7 @@ fn cli_compile(sub_matches: &ArgMatches) -> Result<(), String> { Ok(constraint_count) => { // serialize ABI spec and write to JSON file log::debug!("Serialize ABI"); - let abi_spec_file = File::create(&abi_spec_path) + let abi_spec_file = File::create(abi_spec_path) .map_err(|why| format!("Could not create {}: {}", abi_spec_path.display(), why))?; let mut writer = BufWriter::new(abi_spec_file); diff --git a/zokrates_cli/src/ops/compute_witness.rs b/zokrates_cli/src/ops/compute_witness.rs index 0a886523b..466c88dfc 100644 --- a/zokrates_cli/src/ops/compute_witness.rs +++ b/zokrates_cli/src/ops/compute_witness.rs @@ -77,7 +77,7 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { // read compiled program let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open {}: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open {}: {}", path.display(), why))?; let mut reader = BufReader::new(file); @@ -106,7 +106,7 @@ fn cli_compute<'a, T: Field, I: Iterator>>( let signature = match is_abi { true => { let path = Path::new(sub_matches.value_of("abi-spec").unwrap()); - let file = File::open(&path) + let file = File::open(path) .map_err(|why| format!("Could not open {}: {}", path.display(), why))?; let mut reader = BufReader::new(file); @@ -190,7 +190,7 @@ fn cli_compute<'a, T: Field, I: Iterator>>( // write witness to file let output_path = Path::new(sub_matches.value_of("output").unwrap()); - let output_file = File::create(&output_path) + let output_file = File::create(output_path) .map_err(|why| format!("Could not create {}: {}", output_path.display(), why))?; let writer = BufWriter::new(output_file); @@ -214,7 +214,7 @@ fn cli_compute<'a, T: Field, I: Iterator>>( // write circom witness to file let wtns_path = Path::new(sub_matches.value_of("circom-witness").unwrap()); - let wtns_file = File::create(&wtns_path) + let wtns_file = File::create(wtns_path) .map_err(|why| format!("Could not create {}: {}", output_path.display(), why))?; let mut writer = BufWriter::new(wtns_file); diff --git a/zokrates_cli/src/ops/export_verifier.rs b/zokrates_cli/src/ops/export_verifier.rs index 08b35a44a..f343a7374 100644 --- a/zokrates_cli/src/ops/export_verifier.rs +++ b/zokrates_cli/src/ops/export_verifier.rs @@ -35,7 +35,7 @@ pub fn subcommand() -> App<'static, 'static> { pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { let vk_path = Path::new(sub_matches.value_of("input").unwrap()); - let vk_file = File::open(&vk_path) + let vk_file = File::open(vk_path) .map_err(|why| format!("Could not open {}: {}", vk_path.display(), why))?; // deserialize vk to JSON @@ -84,7 +84,7 @@ fn cli_export_verifier App<'static, 'static> { pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { let program_path = Path::new(sub_matches.value_of("input").unwrap()); - let program_file = File::open(&program_path) + let program_file = File::open(program_path) .map_err(|why| format!("Could not open {}: {}", program_path.display(), why))?; let mut reader = BufReader::new(program_file); @@ -160,7 +160,7 @@ fn cli_generate_proof< // deserialize witness let witness_path = Path::new(sub_matches.value_of("witness").unwrap()); - let witness_file = File::open(&witness_path) + let witness_file = File::open(witness_path) .map_err(|why| format!("Could not open {}: {}", witness_path.display(), why))?; let witness_reader = BufReader::new(witness_file); @@ -171,7 +171,7 @@ fn cli_generate_proof< let pk_path = Path::new(sub_matches.value_of("proving-key-path").unwrap()); let proof_path = Path::new(sub_matches.value_of("proof-path").unwrap()); - let pk_file = File::open(&pk_path) + let pk_file = File::open(pk_path) .map_err(|why| format!("Could not open {}: {}", pk_path.display(), why))?; let pk_reader = BufReader::new(pk_file); diff --git a/zokrates_cli/src/ops/generate_smtlib2.rs b/zokrates_cli/src/ops/generate_smtlib2.rs index ac58f8e56..389372df9 100644 --- a/zokrates_cli/src/ops/generate_smtlib2.rs +++ b/zokrates_cli/src/ops/generate_smtlib2.rs @@ -35,7 +35,7 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { // read compiled program let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open {}: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open {}: {}", path.display(), why))?; let mut reader = BufReader::new(file); diff --git a/zokrates_cli/src/ops/inspect.rs b/zokrates_cli/src/ops/inspect.rs index b8ca37545..1db259b6a 100644 --- a/zokrates_cli/src/ops/inspect.rs +++ b/zokrates_cli/src/ops/inspect.rs @@ -31,7 +31,7 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { // read compiled program let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; let mut reader = BufReader::new(file); diff --git a/zokrates_cli/src/ops/mpc/beacon.rs b/zokrates_cli/src/ops/mpc/beacon.rs index 619a95ed9..71412868b 100644 --- a/zokrates_cli/src/ops/mpc/beacon.rs +++ b/zokrates_cli/src/ops/mpc/beacon.rs @@ -73,7 +73,7 @@ fn cli_mpc_beacon, B: MpcBack ) -> Result<(), String> { let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; let mut reader = BufReader::new(file); @@ -134,7 +134,7 @@ fn cli_mpc_beacon, B: MpcBack }; let output_path = Path::new(sub_matches.value_of("output").unwrap()); - let output_file = File::create(&output_path) + let output_file = File::create(output_path) .map_err(|why| format!("Could not create `{}`: {}", output_path.display(), why))?; let mut writer = BufWriter::new(output_file); diff --git a/zokrates_cli/src/ops/mpc/contribute.rs b/zokrates_cli/src/ops/mpc/contribute.rs index fcb7e5461..96128768c 100644 --- a/zokrates_cli/src/ops/mpc/contribute.rs +++ b/zokrates_cli/src/ops/mpc/contribute.rs @@ -70,12 +70,12 @@ pub fn cli_mpc_contribute< ) -> Result<(), String> { let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; let mut reader = BufReader::new(file); let output_path = Path::new(sub_matches.value_of("output").unwrap()); - let output_file = File::create(&output_path) + let output_file = File::create(output_path) .map_err(|why| format!("Could not create `{}`: {}", output_path.display(), why))?; let mut writer = BufWriter::new(output_file); diff --git a/zokrates_cli/src/ops/mpc/export.rs b/zokrates_cli/src/ops/mpc/export.rs index cf640fa1c..87c780052 100644 --- a/zokrates_cli/src/ops/mpc/export.rs +++ b/zokrates_cli/src/ops/mpc/export.rs @@ -66,7 +66,7 @@ pub fn cli_mpc_export, B: Mpc ) -> Result<(), String> { let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; let mut reader = BufReader::new(file); diff --git a/zokrates_cli/src/ops/mpc/init.rs b/zokrates_cli/src/ops/mpc/init.rs index 92a972482..72e2283c6 100644 --- a/zokrates_cli/src/ops/mpc/init.rs +++ b/zokrates_cli/src/ops/mpc/init.rs @@ -46,7 +46,7 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { // read compiled program let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; let mut reader = BufReader::new(file); @@ -76,7 +76,7 @@ fn cli_mpc_init< let mut radix_reader = BufReader::new(radix_file); let output_path = Path::new(sub_matches.value_of("output").unwrap()); - let output_file = File::create(&output_path) + let output_file = File::create(output_path) .map_err(|why| format!("Could not create `{}`: {}", output_path.display(), why))?; let mut writer = BufWriter::new(output_file); diff --git a/zokrates_cli/src/ops/mpc/verify.rs b/zokrates_cli/src/ops/mpc/verify.rs index d0520e64c..ef0371127 100644 --- a/zokrates_cli/src/ops/mpc/verify.rs +++ b/zokrates_cli/src/ops/mpc/verify.rs @@ -46,7 +46,7 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { // read compiled program let path = Path::new(sub_matches.value_of("circuit").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; let mut reader = BufReader::new(file); @@ -71,7 +71,7 @@ fn cli_mpc_verify< let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; let reader = BufReader::new(file); diff --git a/zokrates_cli/src/ops/print_proof.rs b/zokrates_cli/src/ops/print_proof.rs index a6df505af..cf0d25a7c 100644 --- a/zokrates_cli/src/ops/print_proof.rs +++ b/zokrates_cli/src/ops/print_proof.rs @@ -38,7 +38,7 @@ pub fn subcommand() -> App<'static, 'static> { pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { let proof_path = Path::new(sub_matches.value_of("proof-path").unwrap()); - let proof_file = File::open(&proof_path) + let proof_file = File::open(proof_path) .map_err(|why| format!("Could not open {}: {}", proof_path.display(), why))?; // deserialize proof to JSON diff --git a/zokrates_cli/src/ops/profile.rs b/zokrates_cli/src/ops/profile.rs index 79e0b1f71..a866beee8 100644 --- a/zokrates_cli/src/ops/profile.rs +++ b/zokrates_cli/src/ops/profile.rs @@ -26,7 +26,7 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { // read compiled program let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; let mut reader = BufReader::new(file); diff --git a/zokrates_cli/src/ops/setup.rs b/zokrates_cli/src/ops/setup.rs index 27ddc94c1..25bca25c8 100644 --- a/zokrates_cli/src/ops/setup.rs +++ b/zokrates_cli/src/ops/setup.rs @@ -95,7 +95,7 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { // read compiled program let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Couldn't open {}: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Couldn't open {}: {}", path.display(), why))?; let mut reader = BufReader::new(file); let prog = ProgEnum::deserialize(&mut reader)?; @@ -146,7 +146,7 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { #[cfg(feature = "ark")] Parameters(BackendParameter::Ark, _, SchemeParameter::MARLIN) => { let setup_path = Path::new(sub_matches.value_of("universal-setup-path").unwrap()); - let setup_file = File::open(&setup_path) + let setup_file = File::open(setup_path) .map_err(|why| format!("Couldn't open {}: {}\nExpected an universal setup, make sure `zokrates universal-setup` was run`", setup_path.display(), why))?; let mut reader = BufReader::new(setup_file); diff --git a/zokrates_cli/src/ops/verify.rs b/zokrates_cli/src/ops/verify.rs index 24a96c387..568b13383 100644 --- a/zokrates_cli/src/ops/verify.rs +++ b/zokrates_cli/src/ops/verify.rs @@ -51,7 +51,7 @@ pub fn subcommand() -> App<'static, 'static> { pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { let vk_path = Path::new(sub_matches.value_of("verification-key-path").unwrap()); - let vk_file = File::open(&vk_path) + let vk_file = File::open(vk_path) .map_err(|why| format!("Could not open {}: {}", vk_path.display(), why))?; // deserialize vk to JSON @@ -60,7 +60,7 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { .map_err(|why| format!("Could not deserialize verification key: {}", why))?; let proof_path = Path::new(sub_matches.value_of("proof-path").unwrap()); - let proof_file = File::open(&proof_path) + let proof_file = File::open(proof_path) .map_err(|why| format!("Could not open {}: {}", proof_path.display(), why))?; // deserialize proof to JSON diff --git a/zokrates_cli/tests/integration.rs b/zokrates_cli/tests/integration.rs index 628dfaf18..68bc0abf6 100644 --- a/zokrates_cli/tests/integration.rs +++ b/zokrates_cli/tests/integration.rs @@ -313,7 +313,7 @@ mod integration { .unwrap(); // load the expected witness - let expected_witness_file = File::open(&expected_witness_path).unwrap(); + let expected_witness_file = File::open(expected_witness_path).unwrap(); let expected_witness: Witness = helpers::parse_witness_json(expected_witness_file).unwrap(); @@ -638,7 +638,7 @@ mod integration { .unwrap(); // load the expected smtlib2 - let mut expected_smtlib2_file = File::open(&expected_smtlib2_path).unwrap(); + let mut expected_smtlib2_file = File::open(expected_smtlib2_path).unwrap(); let mut expected_smtlib2 = String::new(); expected_smtlib2_file .read_to_string(&mut expected_smtlib2) diff --git a/zokrates_codegen/src/lib.rs b/zokrates_codegen/src/lib.rs index 8484d8b62..46faf587e 100644 --- a/zokrates_codegen/src/lib.rs +++ b/zokrates_codegen/src/lib.rs @@ -1,5 +1,3 @@ -#![feature(box_patterns, box_syntax)] - //! Module containing the `Flattener` to process a program that is R1CS-able. //! //! @file flatten.rs @@ -11,10 +9,7 @@ mod utils; use self::utils::flat_expression_from_bits; use zokrates_ast::{ - common::{ - expressions::{BinaryExpression, UnaryExpression, ValueExpression}, - Span, - }, + common::{expressions::ValueExpression, Span}, zir::{ canonicalizer::ZirCanonicalizer, ConditionalExpression, Expr, Folder, SelectExpression, ShouldReduce, UMetadata, ZirAssemblyStatement, ZirExpressionList, ZirProgram, @@ -2658,133 +2653,71 @@ impl<'ast, T: Field> Flattener<'ast, T> { error.into(), ) } - // `!(x == 0)` can be asserted by giving the inverse of `x` - BooleanExpression::Not(UnaryExpression { - inner: - box BooleanExpression::UintEq(BinaryExpression { - left: - box UExpression { - inner: - UExpressionInner::Value(ValueExpression { - value: 0, .. - }), + BooleanExpression::Not(u) => { + let inner_span = u.get_span(); + + match *u.inner { + BooleanExpression::UintEq(b) => { + if let UExpressionInner::Value(ValueExpression { + value: 0, .. + }) = b.left.inner + { + let x = self + .flatten_uint_expression(statements_flattened, *b.right) + .get_field_unchecked(); + self.enforce_not_zero_assertion(statements_flattened, x) + } else if let UExpressionInner::Value(ValueExpression { + value: 0, + .. + }) = b.right.inner + { + let x = self + .flatten_uint_expression(statements_flattened, *b.left) + .get_field_unchecked(); + self.enforce_not_zero_assertion(statements_flattened, x) + } else { + self.enforce_naive_assertion( + statements_flattened, + BooleanExpression::not(BooleanExpression::UintEq(b)), + error, + ); + } + } + BooleanExpression::FieldEq(b) => match (*b.left, *b.right) { + ( + FieldElementExpression::Value(ValueExpression { + value: zero, .. - }, - right: box x, - .. - }), - .. - }) - | BooleanExpression::Not(UnaryExpression { - inner: - box BooleanExpression::UintEq(BinaryExpression { - right: - box UExpression { - inner: - UExpressionInner::Value(ValueExpression { - value: 0, .. - }), + }), + x, + ) + | ( + x, + FieldElementExpression::Value(ValueExpression { + value: zero, .. - }, - left: box x, - .. - }), - .. - }) => { - let x = self - .flatten_uint_expression(statements_flattened, x) - .get_field_unchecked(); - - // introduce intermediate variable - let x_id = self.define(x, statements_flattened); - - // check that `x` is not 0 by giving its inverse - let invx = self.use_sym(); - - // # invx = 1/x - statements_flattened.push_back(FlatStatement::Directive( - FlatDirective::new( - vec![invx], - Solver::Div, - vec![FlatExpression::value(T::one()), x_id.into()], - ), - )); - - // assert(invx * x == 1) - statements_flattened.push_back(FlatStatement::condition( - FlatExpression::value(T::one()), - FlatExpression::mul(invx.into(), x_id.into()), - RuntimeError::Inverse, - )); - } - // `!(x == 0)` can be asserted by giving the inverse of `x` - BooleanExpression::Not(UnaryExpression { - inner: - box BooleanExpression::FieldEq( - BinaryExpression { - left: box FieldElementExpression::Value(zero), - right: box x, - .. - }, - .., - ), - .. - }) - | BooleanExpression::Not(UnaryExpression { - inner: - box BooleanExpression::FieldEq( - BinaryExpression { - left: box x, - right: box FieldElementExpression::Value(zero), - .. - }, - .., - ), - .. - }) if zero.value == T::from(0) => { - let x = self.flatten_field_expression(statements_flattened, x); - - // introduce intermediate variable - let x_id = self.define(x, statements_flattened); - - // check that `x` is not 0 by giving its inverse - let invx = self.use_sym(); - - // # invx = 1/x - statements_flattened.push_back(FlatStatement::Directive( - FlatDirective::new( - vec![invx], - Solver::Div, - vec![FlatExpression::value(T::one()), x_id.into()], + }), + ) if zero == T::from(0) => { + let x = self.flatten_field_expression(statements_flattened, x); + self.enforce_not_zero_assertion(statements_flattened, x) + } + (left, right) => self.enforce_naive_assertion( + statements_flattened, + BooleanExpression::not( + BooleanExpression::field_eq(left, right).span(inner_span), + ) + .span(span), + error, + ), + }, + e => self.enforce_naive_assertion( + statements_flattened, + BooleanExpression::not(e.span(inner_span)).span(span), + error, ), - )); - - // assert(invx * x == 1) - statements_flattened.push_back(FlatStatement::condition( - FlatExpression::value(T::one()), - FlatExpression::mul(invx.into(), x_id.into()), - RuntimeError::Inverse, - )); - } - e => { - // naive approach: flatten the boolean to a single field element and constrain it to 1 - let e = self.flatten_boolean_expression(statements_flattened, e); - - if e.is_linear() { - statements_flattened.push_back(FlatStatement::condition( - e, - FlatExpression::value(T::from(1)), - error.into(), - )); - } else { - // swap so that left side is linear - statements_flattened.push_back(FlatStatement::condition( - FlatExpression::value(T::from(1)), - e, - error.into(), - )); } } + e => self.enforce_naive_assertion(statements_flattened, e, error), } } ZirStatement::MultipleDefinition(s) => { @@ -2868,6 +2801,63 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened.set_span(span_backup); } + fn enforce_naive_assertion( + &mut self, + statements_flattened: &mut FlatStatements<'ast, T>, + e: BooleanExpression<'ast, T>, + error: zokrates_ast::zir::RuntimeError, + ) { + // naive approach: flatten the boolean to a single field element and constrain it to 1 + let e = self.flatten_boolean_expression(statements_flattened, e); + + if e.is_linear() { + statements_flattened.push_back(FlatStatement::condition( + e, + FlatExpression::value(T::from(1)), + error.into(), + )); + } else { + // swap so that left side is linear + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::value(T::from(1)), + e, + error.into(), + )); + } + } + + /// Enforce that x is not zero + /// + /// # Arguments + /// + /// * `statements_flattened` - `FlatStatements<'ast, T>` Vector where new flattened statements can be added. + /// * `x` - `FlatExpression` The expression to be constrained to not be zero. + fn enforce_not_zero_assertion( + &mut self, + statements_flattened: &mut FlatStatements<'ast, T>, + x: FlatExpression, + ) { + // introduce intermediate variable + let x_id = self.define(x, statements_flattened); + + // check that `x` is not 0 by giving its inverse + let invx = self.use_sym(); + + // # invx = 1/x + statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new( + vec![invx], + Solver::Div, + vec![FlatExpression::value(T::one()), x_id.into()], + ))); + + // assert(invx * x == 1) + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::value(T::one()), + FlatExpression::mul(invx.into(), x_id.into()), + RuntimeError::Inverse, + )); + } + /// Flattens an equality assertion, enforcing it in the circuit. /// /// # Arguments diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index d1a57705b..4a0505b7d 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -153,18 +153,12 @@ impl fmt::Display for CompileErrorInner { CompileErrorInner::ParserError(ref e) => write!(f, "\n\t{}", e), CompileErrorInner::MacroError(ref e) => write!(f, "\n\t{}", e), CompileErrorInner::SemanticError(ref e) => { - let location = e - .pos() - .map(|p| format!("{}", p.from)) - .unwrap_or_else(|| "".to_string()); + let location = e.pos().map(|p| format!("{}", p.from)).unwrap_or_default(); write!(f, "{}\n\t{}", location, e.message()) } CompileErrorInner::ReadError(ref e) => write!(f, "\n\t{}", e), CompileErrorInner::ImportError(ref e) => { - let location = e - .span() - .map(|p| format!("{}", p.from)) - .unwrap_or_else(|| "".to_string()); + let location = e.span().map(|p| format!("{}", p.from)).unwrap_or_default(); write!(f, "{}\n\t{}", location, e.message()) } CompileErrorInner::AnalysisError(ref e) => write!(f, "\n\t{}", e), @@ -327,7 +321,7 @@ mod test { assert!(e.0[0] .value() .to_string() - .contains(&"Cannot resolve import without a resolver")); + .contains("Cannot resolve import without a resolver")); } #[test] @@ -448,15 +442,15 @@ struct Bar { field a; } vec![], vec![ConcreteStructMember { id: "b".into(), - ty: box ConcreteType::Struct(ConcreteStructType::new( + ty: Box::new(ConcreteType::Struct(ConcreteStructType::new( "bar".into(), "Bar".into(), vec![], - vec![ConcreteStructMember { - id: "a".into(), - ty: box ConcreteType::FieldElement - }] - )) + vec![ConcreteStructMember::new( + "a".into(), + ConcreteType::FieldElement + )] + ))) }] )) }], diff --git a/zokrates_core/src/lib.rs b/zokrates_core/src/lib.rs index 51de21db8..874b996ac 100644 --- a/zokrates_core/src/lib.rs +++ b/zokrates_core/src/lib.rs @@ -1,5 +1,3 @@ -#![feature(box_patterns, box_syntax)] - pub mod compile; pub mod imports; mod macros; diff --git a/zokrates_core/src/optimizer/redefinition.rs b/zokrates_core/src/optimizer/redefinition.rs index 770e9c025..55880465a 100644 --- a/zokrates_core/src/optimizer/redefinition.rs +++ b/zokrates_core/src/optimizer/redefinition.rs @@ -63,7 +63,6 @@ impl RedefinitionOptimizer { .into_iter() .chain(p.arguments.iter().map(|p| p.id)) .chain(p.returns()) - .into_iter() .collect(), } } diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index a0d7c5081..5be636cbb 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -372,7 +372,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(TypedProgram { main: program.main, - module_map: ModuleMap::new(state.typed_modules.iter().map(|(id, _)| id).cloned()), + module_map: ModuleMap::new(state.typed_modules.keys().cloned()), modules: state.typed_modules, }) } @@ -804,7 +804,6 @@ impl<'ast, T: Field> Checker<'ast, T> { .get(&import.module_id) .unwrap() .functions_iter() - .into_iter() .filter(|d| d.key.id == import.symbol_id) .map(|d| DeclarationFunctionKey { module: import.module_id.to_path_buf(), @@ -1052,7 +1051,6 @@ impl<'ast, T: Field> Checker<'ast, T> { fn check_single_main(module: &TypedModule) -> Result<(), ErrorInner> { match module .functions_iter() - .into_iter() .filter(|d| d.key.id == "main") .count() { @@ -1220,7 +1218,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } if !found_return { - match (&*s.output).is_empty_tuple() { + match (*s.output).is_empty_tuple() { true => statements_checked.push( TypedStatement::ret(TypedExpression::empty_tuple()).with_span(span), ), @@ -1769,9 +1767,9 @@ impl<'ast, T: Field> Checker<'ast, T> { ) -> Result, ErrorInner> { match expr.value { // for function calls, check the rhs with the expected type - Expression::FunctionCall(box fun_id_expression, generics, arguments) => self + Expression::FunctionCall(fun_id_expression, generics, arguments) => self .check_function_call_expression( - fun_id_expression, + *fun_id_expression, generics, arguments, Some(return_type), @@ -1932,11 +1930,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let e_checked = e .map(|e| { match e.value { - Expression::FunctionCall( - box fun_id_expression, - generics, - arguments, - ) => { + Expression::FunctionCall(fun_id_expression, generics, arguments) => { let ty = zokrates_ast::typed::types::try_from_g_type( return_type.clone(), ) @@ -1944,7 +1938,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .unwrap(); self.check_function_call_expression( - fun_id_expression, + *fun_id_expression, generics, arguments, ty, @@ -2156,13 +2150,13 @@ impl<'ast, T: Field> Checker<'ast, T> { message: format!("Variable `{}` is undeclared", variable_name), }), }, - Assignee::Select(box assignee, box index) => { - let checked_assignee = self.check_assignee(assignee, module_id, types)?; + Assignee::Select(assignee, index) => { + let checked_assignee = self.check_assignee(*assignee, module_id, types)?; let ty = checked_assignee.get_type(); match ty { Type::Array(..) => { - let checked_index = match index { + let checked_index = match *index { RangeOrExpression::Expression(e) => { self.check_expression(e, module_id, types)? } @@ -2184,10 +2178,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }, )?; - Ok(TypedAssignee::Select( - box checked_assignee, - box checked_typed_index, - )) + Ok(TypedAssignee::select(checked_assignee, checked_typed_index)) } ty => Err(ErrorInner { span: Some(span), @@ -2198,13 +2189,13 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Assignee::Member(box assignee, box member) => { - let checked_assignee = self.check_assignee(assignee, module_id, types)?; + Assignee::Member(assignee, member) => { + let checked_assignee = self.check_assignee(*assignee, module_id, types)?; let ty = checked_assignee.get_type(); match &ty { - Type::Struct(members) => match members.iter().find(|m| m.id == member) { - Some(_) => Ok(TypedAssignee::Member(box checked_assignee, member.into())), + Type::Struct(members) => match members.iter().find(|m| m.id == *member) { + Some(_) => Ok(TypedAssignee::member(checked_assignee, (*member).into())), None => Err(ErrorInner { span: Some(span), message: format!( @@ -2229,13 +2220,13 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Assignee::Element(box assignee, index) => { - let checked_assignee = self.check_assignee(assignee, module_id, types)?; + Assignee::Element(assignee, index) => { + let checked_assignee = self.check_assignee(*assignee, module_id, types)?; let ty = checked_assignee.get_type(); match &ty { Type::Tuple(tuple_ty) => match tuple_ty.elements.get(index as usize) { - Some(_) => Ok(TypedAssignee::Element(box checked_assignee, index)), + Some(_) => Ok(TypedAssignee::element(checked_assignee, index)), None => Err(ErrorInner { span: Some(span), message: format!( @@ -2413,7 +2404,7 @@ impl<'ast, T: Field> Checker<'ast, T> { function_key, generics_checked, arguments_checked, - ).annotate(*array_ty.ty, *array_ty.size).into()), + ).annotate(array_ty).into()), Type::Tuple(tuple_ty) => Ok(TupleExpression::function_call( function_key, generics_checked, @@ -2462,7 +2453,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(FieldElementExpression::identifier(id.into()).into()) } Type::Array(array_type) => Ok(ArrayExpression::identifier(id.into()) - .annotate(*array_type.ty, *array_type.size) + .annotate(array_type) .into()), Type::Struct(members) => Ok(StructExpression::identifier(id.into()) .annotate(members) @@ -2479,9 +2470,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Add(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Add(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; use self::TypedExpression::*; @@ -2515,9 +2506,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Sub(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Sub(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; use self::TypedExpression::*; @@ -2544,9 +2535,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Mult(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Mult(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; use self::TypedExpression::*; @@ -2579,9 +2570,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Div(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Div(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; use self::TypedExpression::*; @@ -2612,9 +2603,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Rem(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Rem(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2641,9 +2632,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Pow(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Pow(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let e1_checked = match FieldElementExpression::try_from_typed(e1_checked) { Ok(e) => e.into(), @@ -2669,8 +2660,8 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Neg(box e) => { - let e = self.check_expression(e, module_id, types)?; + Expression::Neg(e) => { + let e = self.check_expression(*e, module_id, types)?; match e { TypedExpression::Int(e) => Ok(IntExpression::neg(e).into()), @@ -2688,8 +2679,8 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Pos(box e) => { - let e = self.check_expression(e, module_id, types)?; + Expression::Pos(e) => { + let e = self.check_expression(*e, module_id, types)?; match e { TypedExpression::Int(e) => Ok(IntExpression::pos(e).into()), @@ -2707,7 +2698,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Conditional(box conditional) => { + Expression::Conditional(conditional) => { let condition_checked = self.check_expression(*conditional.condition, module_id, types)?; @@ -2800,18 +2791,18 @@ impl<'ast, T: Field> Checker<'ast, T> { Expression::U16Constant(n) => Ok(UExpression::value(n.into()).annotate(16).with_span(span).into()), Expression::U32Constant(n) => Ok(UExpression::value(n.into()).annotate(32).with_span(span).into()), Expression::U64Constant(n) => Ok(UExpression::value(n.into()).annotate(64).with_span(span).into()), - Expression::FunctionCall(box fun_id_expression, generics, arguments) => self + Expression::FunctionCall(fun_id_expression, generics, arguments) => self .check_function_call_expression( - fun_id_expression, + *fun_id_expression, generics, arguments, None, module_id, types, ), - Expression::Lt(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Lt(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2859,9 +2850,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Le(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Le(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2909,9 +2900,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Eq(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Eq(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -2960,9 +2951,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Ge(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Ge(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -3010,9 +3001,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Gt(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Gt(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -3060,10 +3051,10 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Select(box array, box index) => { - let array = self.check_expression(array, module_id, types)?; + Expression::Select(array, index) => { + let array = self.check_expression(*array, module_id, types)?; - match index { + match *index { RangeOrExpression::Range(r) => { match array { TypedExpression::Array(array) => { @@ -3108,7 +3099,7 @@ impl<'ast, T: Field> Checker<'ast, T> { from.clone(), to.clone(), )) - .annotate(inner_type, UExpression::floor_sub(to, from)) + .annotate(ArrayType::new(inner_type, UExpression::floor_sub(to, from))) .into()) } e => Err(ErrorInner { @@ -3163,8 +3154,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } } - Expression::Element(box e, index) => { - let e = self.check_expression(e, module_id, types)?; + Expression::Element(e, index) => { + let e = self.check_expression(*e, module_id, types)?; match e { TypedExpression::Tuple(t) => { let ty = t.ty().elements.get(index as usize); @@ -3201,13 +3192,13 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Member(box e, box id) => { - let e = self.check_expression(e, module_id, types)?; + Expression::Member(e, id) => { + let e = self.check_expression(*e, module_id, types)?; match e { TypedExpression::Struct(s) => { // check that the struct has that field and return the type if it does - let ty = s.ty().iter().find(|m| m.id == id).map(|m| *m.ty.clone()); + let ty = s.ty().iter().find(|m| m.id == *id).map(|m| *m.ty.clone()); match ty { Some(ty) => match ty { @@ -3326,7 +3317,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok( ArrayExpression::value(unwrapped_expressions_or_spreads) - .annotate(inferred_type, size) + .annotate(ArrayType::new(inferred_type, size)) .into(), ) } @@ -3338,11 +3329,11 @@ impl<'ast, T: Field> Checker<'ast, T> { let ty = TupleType::new(elements.iter().map(|e| e.get_type()).collect()); Ok(TupleExpression::value(elements).annotate(ty).into()) } - Expression::ArrayInitializer(box e, box count) => { - let e = self.check_expression(e, module_id, types)?; + Expression::ArrayInitializer(e, count) => { + let e = self.check_expression(*e, module_id, types)?; let ty = e.get_type(); - let count = self.check_expression(count, module_id, types)?; + let count = self.check_expression(*count, module_id, types)?; let count = UExpression::try_from_typed(count, &UBitwidth::B32).map_err(|e| { ErrorInner { @@ -3356,7 +3347,7 @@ impl<'ast, T: Field> Checker<'ast, T> { })?; Ok(ArrayExpressionInner::Repeat(RepeatExpression::new(e, count.clone())) - .annotate(ty, count) + .annotate(ArrayType::new(ty, count)) .into()) } Expression::InlineStruct(id, inline_members) => { @@ -3467,10 +3458,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ), }) } else { - Ok(StructMember { - id: m.id.clone(), - ty: box v.get_type().clone(), - }) + Ok(StructMember::new(m.id.clone(), v.get_type().clone())) } }) .collect::, _>>()?; @@ -3488,9 +3476,9 @@ impl<'ast, T: Field> Checker<'ast, T> { .annotate(inferred_struct_type) .into()) } - Expression::And(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::And(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -3522,9 +3510,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Or(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Or(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; match (e1_checked, e2_checked) { (TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => { Ok(BooleanExpression::bitor(e1, e2).into()) @@ -3539,9 +3527,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::LeftShift(box e1, box e2) => { - let e1 = self.check_expression(e1, module_id, types)?; - let e2 = self.check_expression(e2, module_id, types)?; + Expression::LeftShift(e1, e2) => { + let e1 = self.check_expression(*e1, module_id, types)?; + let e2 = self.check_expression(*e2, module_id, types)?; let e2 = UExpression::try_from_typed(e2, &UBitwidth::B32).map_err(|e| ErrorInner { @@ -3568,9 +3556,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::RightShift(box e1, box e2) => { - let e1 = self.check_expression(e1, module_id, types)?; - let e2 = self.check_expression(e2, module_id, types)?; + Expression::RightShift(e1, e2) => { + let e1 = self.check_expression(*e1, module_id, types)?; + let e2 = self.check_expression(*e2, module_id, types)?; let e2 = UExpression::try_from_typed(e2, &UBitwidth::B32).map_err(|e| ErrorInner { @@ -3599,9 +3587,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::BitOr(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::BitOr(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -3633,9 +3621,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::BitAnd(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::BitAnd(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -3667,9 +3655,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::BitXor(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::BitXor(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, @@ -3701,8 +3689,8 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Not(box e) => { - let e_checked = self.check_expression(e, module_id, types)?; + Expression::Not(e) => { + let e_checked = self.check_expression(*e, module_id, types)?; match e_checked { TypedExpression::Int(e) => Ok(IntExpression::not(e).into()), TypedExpression::Boolean(e) => Ok(BooleanExpression::not(e).into()), @@ -3766,7 +3754,7 @@ mod tests { // The value of `P - 1` is a valid field literal let expr = Expression::FieldConstant(Bn128Field::max_value().to_biguint()).mock(); assert!(Checker::::default() - .check_expression(expr, &*MODULE_ID, &TypeMap::new()) + .check_expression(expr, &MODULE_ID, &TypeMap::new()) .is_ok()); } @@ -3777,7 +3765,7 @@ mod tests { let expr = Expression::FieldConstant(value).mock(); assert!(Checker::::default() - .check_expression(expr, &*MODULE_ID, &TypeMap::new()) + .check_expression(expr, &MODULE_ID, &TypeMap::new()) .is_err()); } } @@ -3800,7 +3788,7 @@ mod tests { ]) .mock(); assert!(Checker::::default() - .check_expression(a, &*MODULE_ID, &types) + .check_expression(a, &MODULE_ID, &types) .is_err()); // [[0f], [0f, 0f]] @@ -3820,7 +3808,7 @@ mod tests { ]) .mock(); assert!(Checker::::default() - .check_expression(a, &*MODULE_ID, &types) + .check_expression(a, &MODULE_ID, &types) .is_ok()); // [[0f], true] @@ -3836,7 +3824,7 @@ mod tests { ]) .mock(); assert!(Checker::::default() - .check_expression(a, &*MODULE_ID, &types) + .check_expression(a, &MODULE_ID, &types) .is_err()); } } @@ -4051,7 +4039,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] + checker.check_module(&MODULE_ID, &mut state).unwrap_err()[0] .inner .message, "foo conflicts with another symbol" @@ -4134,7 +4122,7 @@ mod tests { ); let mut checker: Checker = Checker::default(); - assert!(checker.check_module(&*MODULE_ID, &mut state).is_ok()); + assert!(checker.check_module(&MODULE_ID, &mut state).is_ok()); } mod generics { @@ -4176,7 +4164,7 @@ mod tests { ); let mut checker: Checker = Checker::default(); - assert!(checker.check_module(&*MODULE_ID, &mut state).is_ok()); + assert!(checker.check_module(&MODULE_ID, &mut state).is_ok()); } #[test] @@ -4233,7 +4221,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] + checker.check_module(&MODULE_ID, &mut state).unwrap_err()[0] .inner .message, "Undeclared symbol `P`" @@ -4273,7 +4261,7 @@ mod tests { ); let mut checker: Checker = Checker::default(); - assert_eq!(checker.check_module(&*MODULE_ID, &mut state), Ok(())); + assert_eq!(checker.check_module(&MODULE_ID, &mut state), Ok(())); assert!(state .typed_modules .get(&*MODULE_ID) @@ -4324,7 +4312,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] + checker.check_module(&MODULE_ID, &mut state).unwrap_err()[0] .inner .message, "foo conflicts with another symbol" @@ -4368,7 +4356,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] + checker.check_module(&MODULE_ID, &mut state).unwrap_err()[0] .inner .message, "foo conflicts with another symbol" @@ -4420,7 +4408,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] + checker.check_module(&MODULE_ID, &mut state).unwrap_err()[0] .inner .message, "foo conflicts with another symbol" @@ -4469,7 +4457,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] + checker.check_module(&MODULE_ID, &mut state).unwrap_err()[0] .inner .message, "foo conflicts with another symbol" @@ -4498,12 +4486,12 @@ mod tests { let state = State::new(modules, (*MODULE_ID).clone()); let signature = UnresolvedSignature::new().inputs(vec![UnresolvedType::Array( - box UnresolvedType::FieldElement.mock(), + Box::new(UnresolvedType::FieldElement.mock()), Expression::Identifier("K").mock(), ) .mock()]); assert_eq!( - Checker::::default().check_signature(signature, &*MODULE_ID, &state), + Checker::::default().check_signature(signature, &MODULE_ID, &state), Err(vec![ErrorInner { span: Some(SourceSpan::mock()), message: "Undeclared symbol `K`".to_string() @@ -4520,27 +4508,31 @@ mod tests { let signature = UnresolvedSignature::new() .generics(vec!["K".mock(), "L".mock(), "M".mock()]) .inputs(vec![UnresolvedType::Array( - box UnresolvedType::Array( - box UnresolvedType::FieldElement.mock(), - Expression::Identifier("K").mock(), - ) - .mock(), + Box::new( + UnresolvedType::Array( + Box::new(UnresolvedType::FieldElement.mock()), + Expression::Identifier("K").mock(), + ) + .mock(), + ), Expression::Identifier("L").mock(), ) .mock()]) .output( UnresolvedType::Array( - box UnresolvedType::Array( - box UnresolvedType::FieldElement.mock(), - Expression::Identifier("L").mock(), - ) - .mock(), + Box::new( + UnresolvedType::Array( + Box::new(UnresolvedType::FieldElement.mock()), + Expression::Identifier("L").mock(), + ) + .mock(), + ), Expression::Identifier("K").mock(), ) .mock(), ); assert_eq!( - Checker::::default().check_signature(signature, &*MODULE_ID, &state), + Checker::::default().check_signature(signature, &MODULE_ID, &state), Ok(DeclarationSignature::new() .inputs(vec![DeclarationType::array(( DeclarationType::array(( @@ -4575,7 +4567,7 @@ mod tests { checker.enter_scope(); assert_eq!( - checker.check_statement(statement, &*MODULE_ID, &TypeMap::new()), + checker.check_statement(statement, &MODULE_ID, &TypeMap::new()), Err(vec![ErrorInner { span: Some(SourceSpan::mock()), message: "Identifier \"b\" is undefined".into() @@ -4606,7 +4598,7 @@ mod tests { let mut checker: Checker = new_with_args(scope, HashSet::new()); checker.enter_scope(); assert_eq!( - checker.check_statement(statement, &*MODULE_ID, &TypeMap::new()), + checker.check_statement(statement, &MODULE_ID, &TypeMap::new()), Ok(TypedStatement::definition( typed::Variable::field_element("a").into(), FieldElementExpression::identifier("b".into()).into() @@ -4675,7 +4667,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_module(&*MODULE_ID, &mut state), + checker.check_module(&MODULE_ID, &mut state), Err(vec![Error { inner: ErrorInner { span: Some(SourceSpan::mock()), @@ -4771,7 +4763,7 @@ mod tests { ); let mut checker: Checker = Checker::default(); - assert!(checker.check_module(&*MODULE_ID, &mut state).is_ok()); + assert!(checker.check_module(&MODULE_ID, &mut state).is_ok()); } #[test] @@ -4805,7 +4797,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_function("foo", foo, &*MODULE_ID, &state), + checker.check_function("foo", foo, &MODULE_ID, &state), Err(vec![ErrorInner { span: Some(SourceSpan::mock()), message: "Identifier \"i\" is undefined".into() @@ -4884,7 +4876,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_function("foo", foo, &*MODULE_ID, &state), + checker.check_function("foo", foo, &MODULE_ID, &state), Ok(foo_checked) ); } @@ -4902,8 +4894,12 @@ mod tests { let bar_statements: Vec = vec![ Statement::Definition( untyped::Variable::immutable("a", UnresolvedType::FieldElement.mock()).mock(), - Expression::FunctionCall(box Expression::Identifier("foo").mock(), None, vec![]) - .mock(), + Expression::FunctionCall( + Box::new(Expression::Identifier("foo").mock()), + None, + vec![], + ) + .mock(), ) .mock(), Statement::Return(None).mock(), @@ -4931,7 +4927,7 @@ mod tests { let mut checker: Checker = new_with_args(Scope::default(), functions); assert_eq!( - checker.check_function("bar", bar, &*MODULE_ID, &state), + checker.check_function("bar", bar, &MODULE_ID, &state), Err(vec![ErrorInner { span: Some(SourceSpan::mock()), message: @@ -4951,8 +4947,12 @@ mod tests { let bar_statements: Vec = vec![ Statement::Definition( untyped::Variable::immutable("a", UnresolvedType::FieldElement.mock()).mock(), - Expression::FunctionCall(box Expression::Identifier("foo").mock(), None, vec![]) - .mock(), + Expression::FunctionCall( + Box::new(Expression::Identifier("foo").mock()), + None, + vec![], + ) + .mock(), ) .mock(), Statement::Return(None).mock(), @@ -4970,7 +4970,7 @@ mod tests { let mut checker: Checker = new_with_args(Scope::default(), HashSet::new()); assert_eq!( - checker.check_function("bar", bar, &*MODULE_ID, &state), + checker.check_function("bar", bar, &MODULE_ID, &state), Err(vec![ErrorInner { span: Some(SourceSpan::mock()), @@ -5007,8 +5007,12 @@ mod tests { let main_statements: Vec = vec![ Statement::Assignment( Assignee::Identifier("a").mock(), - Expression::FunctionCall(box Expression::Identifier("foo").mock(), None, vec![]) - .mock(), + Expression::FunctionCall( + Box::new(Expression::Identifier("foo").mock()), + None, + vec![], + ) + .mock(), ) .mock(), Statement::Return(None).mock(), @@ -5045,7 +5049,7 @@ mod tests { let mut checker: Checker = new_with_args(Scope::default(), HashSet::new()); assert_eq!( - checker.check_module(&*MODULE_ID, &mut state), + checker.check_module(&MODULE_ID, &mut state), Err(vec![Error { inner: ErrorInner { span: Some(SourceSpan::mock()), @@ -5099,14 +5103,18 @@ mod tests { .mock(), Statement::Assignment( Assignee::Select( - box Assignee::Identifier("a").mock(), - box RangeOrExpression::Expression( + Box::new(Assignee::Identifier("a").mock()), + Box::new(RangeOrExpression::Expression( untyped::Expression::IntConstant(0usize.into()).mock(), - ), + )), + ) + .mock(), + Expression::FunctionCall( + Box::new(Expression::Identifier("foo").mock()), + None, + vec![], ) .mock(), - Expression::FunctionCall(box Expression::Identifier("foo").mock(), None, vec![]) - .mock(), ) .mock(), Statement::Return(None).mock(), @@ -5142,7 +5150,7 @@ mod tests { ); let mut checker: Checker = new_with_args(Scope::default(), HashSet::new()); - assert!(checker.check_module(&*MODULE_ID, &mut state).is_ok()); + assert!(checker.check_module(&MODULE_ID, &mut state).is_ok()); } #[test] @@ -5156,13 +5164,15 @@ mod tests { let bar_statements: Vec = vec![ Statement::Assertion( Expression::Eq( - box Expression::IntConstant(1usize.into()).mock(), - box Expression::FunctionCall( - box Expression::Identifier("foo").mock(), - None, - vec![], - ) - .mock(), + Box::new(Expression::IntConstant(1usize.into()).mock()), + Box::new( + Expression::FunctionCall( + Box::new(Expression::Identifier("foo").mock()), + None, + vec![], + ) + .mock(), + ), ) .mock(), None, @@ -5183,7 +5193,7 @@ mod tests { let mut checker: Checker = new_with_args(Scope::default(), HashSet::new()); assert_eq!( - checker.check_function("bar", bar, &*MODULE_ID, &state), + checker.check_function("bar", bar, &MODULE_ID, &state), Err(vec![ErrorInner { span: Some(SourceSpan::mock()), @@ -5216,7 +5226,7 @@ mod tests { let mut checker: Checker = new_with_args(Scope::default(), HashSet::new()); assert_eq!( - checker.check_function("bar", bar, &*MODULE_ID, &state), + checker.check_function("bar", bar, &MODULE_ID, &state), Err(vec![ErrorInner { span: Some(SourceSpan::mock()), message: "Identifier \"a\" is undefined".into() @@ -5254,7 +5264,7 @@ mod tests { let mut checker: Checker = new_with_args(Scope::default(), HashSet::new()); assert_eq!( checker - .check_function("main", f, &*MODULE_ID, &state) + .check_function("main", f, &MODULE_ID, &state) .unwrap_err()[0] .message, "Duplicate name in function definition: `a` was previously declared as an argument, a generic parameter or a constant" @@ -5356,7 +5366,7 @@ mod tests { untyped::Expression::IntConstant(2usize.into()).mock(), ) .mock(), - &*MODULE_ID, + &MODULE_ID, &TypeMap::new(), ); let s2_checked: Result, Vec> = checker @@ -5367,7 +5377,7 @@ mod tests { untyped::Expression::IntConstant(2usize.into()).mock(), ) .mock(), - &*MODULE_ID, + &MODULE_ID, &TypeMap::new(), ); assert!(s2_checked.is_ok()); @@ -5388,7 +5398,7 @@ mod tests { untyped::Expression::IntConstant(2usize.into()).mock(), ) .mock(), - &*MODULE_ID, + &MODULE_ID, &TypeMap::new(), ); let s2_checked: Result, Vec> = checker @@ -5398,7 +5408,7 @@ mod tests { untyped::Expression::BooleanConstant(true).mock(), ) .mock(), - &*MODULE_ID, + &MODULE_ID, &TypeMap::new(), ); assert!(s2_checked.is_ok()); @@ -5507,7 +5517,7 @@ mod tests { .into_iter() .map(|s| { checker - .check_statement(s, &*MODULE_ID, &TypeMap::default()) + .check_statement(s, &MODULE_ID, &TypeMap::default()) .unwrap() }) .collect(); @@ -5539,7 +5549,7 @@ mod tests { let mut checker: Checker = Checker::default(); - checker.check_module(&*MODULE_ID, &mut state).unwrap(); + checker.check_module(&MODULE_ID, &mut state).unwrap(); (checker, state) } @@ -5567,7 +5577,7 @@ mod tests { Checker::::default().check_struct_type_declaration( "Foo".into(), declaration, - &*MODULE_ID, + &MODULE_ID, &state ), Ok(expected_type) @@ -5611,7 +5621,7 @@ mod tests { Checker::::default().check_struct_type_declaration( "Foo".into(), declaration, - &*MODULE_ID, + &MODULE_ID, &state ), Ok(expected_type) @@ -5646,7 +5656,7 @@ mod tests { .check_struct_type_declaration( "Foo".into(), declaration, - &*MODULE_ID, + &MODULE_ID, &state ) .unwrap_err()[0] @@ -5703,7 +5713,7 @@ mod tests { ); assert!(Checker::default() - .check_module(&*MODULE_ID, &mut state) + .check_module(&MODULE_ID, &mut state) .is_ok()); assert_eq!( state @@ -5763,7 +5773,7 @@ mod tests { ); assert!(Checker::default() - .check_module(&*MODULE_ID, &mut state) + .check_module(&MODULE_ID, &mut state) .is_err()); } @@ -5797,7 +5807,7 @@ mod tests { ); assert!(Checker::default() - .check_module(&*MODULE_ID, &mut state) + .check_module(&MODULE_ID, &mut state) .is_err()); } @@ -5849,7 +5859,7 @@ mod tests { ); assert!(Checker::default() - .check_module(&*MODULE_ID, &mut state) + .check_module(&MODULE_ID, &mut state) .is_err()); } } @@ -5879,7 +5889,7 @@ mod tests { assert_eq!( checker.check_type( UnresolvedType::User("Foo".into(), None).mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ), Ok(Type::Struct(StructType::new( @@ -5894,7 +5904,7 @@ mod tests { checker .check_type( UnresolvedType::User("Bar".into(), None).mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ) .unwrap_err() @@ -5927,15 +5937,17 @@ mod tests { assert_eq!( checker.check_expression( Expression::Member( - box Expression::InlineStruct( - "Foo".into(), - vec![("foo", Expression::IntConstant(42usize.into()).mock())] - ) - .mock(), + Box::new( + Expression::InlineStruct( + "Foo".into(), + vec![("foo", Expression::IntConstant(42usize.into()).mock())] + ) + .mock() + ), "foo".into() ) .mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ), Ok(FieldElementExpression::member( @@ -5975,15 +5987,20 @@ mod tests { checker .check_expression( Expression::Member( - box Expression::InlineStruct( - "Foo".into(), - vec![("foo", Expression::IntConstant(42usize.into()).mock())] - ) - .mock(), + Box::new( + Expression::InlineStruct( + "Foo".into(), + vec![( + "foo", + Expression::IntConstant(42usize.into()).mock() + )] + ) + .mock() + ), "bar".into() ) .mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ) .unwrap_err() @@ -6018,7 +6035,7 @@ mod tests { vec![("foo", Expression::IntConstant(42usize.into()).mock())] ) .mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ) .unwrap_err() @@ -6060,7 +6077,7 @@ mod tests { ] ) .mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ), Ok(StructExpression::value(vec![ @@ -6113,7 +6130,7 @@ mod tests { ] ) .mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ), Ok(StructExpression::value(vec![ @@ -6164,7 +6181,7 @@ mod tests { vec![("foo", Expression::IntConstant(42usize.into()).mock())] ) .mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ) .unwrap_err() @@ -6212,7 +6229,7 @@ mod tests { )] ) .mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ).unwrap_err() .message, @@ -6230,7 +6247,7 @@ mod tests { ] ) .mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ) .unwrap_err() @@ -6290,7 +6307,7 @@ mod tests { main.value.statements = vec![Statement::Return(Some( Expression::FunctionCall( - box Expression::Identifier("foo").mock(), + Box::new(Expression::Identifier("foo").mock()), None, vec![Expression::IntConstant(0usize.into()).mock()], ) @@ -6354,13 +6371,13 @@ mod tests { Expression::FieldConstant(42u32.into()).mock(), ) .mock(), - &*MODULE_ID, + &MODULE_ID, &TypeMap::new(), ) .unwrap(); assert_eq!( - checker.check_assignee(a, &*MODULE_ID, &TypeMap::new()), + checker.check_assignee(a, &MODULE_ID, &TypeMap::new()), Ok(TypedAssignee::Identifier(typed::Variable::new( "a", Type::FieldElement, @@ -6373,8 +6390,10 @@ mod tests { // field[3] a = [1, 2, 3] // a[2] = 42 let a = Assignee::Select( - box Assignee::Identifier("a").mock(), - box RangeOrExpression::Expression(Expression::IntConstant(2usize.into()).mock()), + Box::new(Assignee::Identifier("a").mock()), + Box::new(RangeOrExpression::Expression( + Expression::IntConstant(2usize.into()).mock(), + )), ) .mock(); @@ -6405,19 +6424,19 @@ mod tests { .mock(), ) .mock(), - &*MODULE_ID, + &MODULE_ID, &TypeMap::new(), ) .unwrap(); assert_eq!( - checker.check_assignee(a, &*MODULE_ID, &TypeMap::new()), + checker.check_assignee(a, &MODULE_ID, &TypeMap::new()), Ok(TypedAssignee::Select( - box TypedAssignee::Identifier(typed::Variable::new( + Box::new(TypedAssignee::Identifier(typed::Variable::new( "a", Type::array((Type::FieldElement, 3u32)), - )), - box 2u32.into() + ))), + Box::new(2u32.into()) )) ); } @@ -6427,14 +6446,18 @@ mod tests { // field[1][1] a = [[1]] // a[0][0] let a: AssigneeNode = Assignee::Select( - box Assignee::Select( - box Assignee::Identifier("a").mock(), - box RangeOrExpression::Expression( - Expression::IntConstant(0usize.into()).mock(), - ), - ) - .mock(), - box RangeOrExpression::Expression(Expression::IntConstant(0usize.into()).mock()), + Box::new( + Assignee::Select( + Box::new(Assignee::Identifier("a").mock()), + Box::new(RangeOrExpression::Expression( + Expression::IntConstant(0usize.into()).mock(), + )), + ) + .mock(), + ), + Box::new(RangeOrExpression::Expression( + Expression::IntConstant(0usize.into()).mock(), + )), ) .mock(); @@ -6465,22 +6488,22 @@ mod tests { .mock(), ) .mock(), - &*MODULE_ID, + &MODULE_ID, &TypeMap::new(), ) .unwrap(); assert_eq!( - checker.check_assignee(a, &*MODULE_ID, &TypeMap::new()), + checker.check_assignee(a, &MODULE_ID, &TypeMap::new()), Ok(TypedAssignee::Select( - box TypedAssignee::Select( - box TypedAssignee::Identifier(typed::Variable::new( + Box::new(TypedAssignee::Select( + Box::new(TypedAssignee::Identifier(typed::Variable::new( "a", Type::array((Type::array((Type::FieldElement, 1u32)), 1u32)), - )), - box 0u32.into() - ), - box 0u32.into() + ))), + Box::new(0u32.into()) + )), + Box::new(0u32.into()) )) ); } diff --git a/zokrates_embed/src/ark.rs b/zokrates_embed/src/ark.rs index 40ce41ddd..fa8544962 100644 --- a/zokrates_embed/src/ark.rs +++ b/zokrates_embed/src/ark.rs @@ -279,8 +279,8 @@ fn var_to_index(var: &FpVar, offset: usize) -> usize { fn new_g1(flat: &[T]) -> G1 { assert_eq!(flat.len(), 2); G1::new( - BLS12Fq::from_str(&*flat[0].to_dec_string()).unwrap(), - BLS12Fq::from_str(&*flat[1].to_dec_string()).unwrap(), + BLS12Fq::from_str(&flat[0].to_dec_string()).unwrap(), + BLS12Fq::from_str(&flat[1].to_dec_string()).unwrap(), false, ) } @@ -290,12 +290,12 @@ fn new_g2(flat: &[T]) -> G2 { assert_eq!(flat.len(), 4); G2::new( BLS12Fq2::new( - BLS12Fq::from_str(&*flat[0].to_dec_string()).unwrap(), - BLS12Fq::from_str(&*flat[1].to_dec_string()).unwrap(), + BLS12Fq::from_str(&flat[0].to_dec_string()).unwrap(), + BLS12Fq::from_str(&flat[1].to_dec_string()).unwrap(), ), BLS12Fq2::new( - BLS12Fq::from_str(&*flat[2].to_dec_string()).unwrap(), - BLS12Fq::from_str(&*flat[3].to_dec_string()).unwrap(), + BLS12Fq::from_str(&flat[2].to_dec_string()).unwrap(), + BLS12Fq::from_str(&flat[3].to_dec_string()).unwrap(), ), false, ) diff --git a/zokrates_js/src/lib.rs b/zokrates_js/src/lib.rs index ed3f26502..bc7a666c6 100644 --- a/zokrates_js/src/lib.rs +++ b/zokrates_js/src/lib.rs @@ -153,7 +153,7 @@ impl<'a> Resolver for JsResolver<'a> { Some(Component::Normal(_)) => { let path_normalized = normalize_path(path); let source = STDLIB - .get(&path_normalized.to_str().unwrap()) + .get(path_normalized.to_str().unwrap()) .ok_or_else(|| { Error::new(format!( "module `{}` not found in stdlib", diff --git a/zokrates_parser/src/lib.rs b/zokrates_parser/src/lib.rs index b4f9cb85c..eb61ce2d8 100644 --- a/zokrates_parser/src/lib.rs +++ b/zokrates_parser/src/lib.rs @@ -12,6 +12,7 @@ use pest::Parser; #[grammar = "zokrates.pest"] struct ZoKratesParser; +#[allow(clippy::result_large_err)] pub fn parse(input: &str) -> Result, Error> { ZoKratesParser::parse(Rule::file, input) } diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 65268b304..8772e82f4 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -1173,6 +1173,7 @@ impl fmt::Display for Error { } } +#[allow(clippy::result_large_err)] pub fn generate_ast(input: &str) -> Result { let parse_tree = parse(input).map_err(Error)?; Ok(Prog::from(parse_tree).0) diff --git a/zokrates_proof_systems/src/rng.rs b/zokrates_proof_systems/src/rng.rs index 9e2d51df5..814217f27 100644 --- a/zokrates_proof_systems/src/rng.rs +++ b/zokrates_proof_systems/src/rng.rs @@ -5,7 +5,7 @@ use rand_0_8::{rngs::StdRng, SeedableRng}; pub fn get_rng_from_entropy(entropy: &str) -> StdRng { let h = { let mut h = Blake2b::default(); - h.input(&entropy.as_bytes()); + h.input(entropy.as_bytes()); h.result() }; diff --git a/zokrates_proof_systems/src/to_token.rs b/zokrates_proof_systems/src/to_token.rs index fbefdc400..a2be8a963 100644 --- a/zokrates_proof_systems/src/to_token.rs +++ b/zokrates_proof_systems/src/to_token.rs @@ -8,8 +8,8 @@ use super::{ /// Helper methods for parsing group structure pub fn encode_g1_element(g: &G1Affine) -> (U256, U256) { ( - U256::from(&hex::decode(&g.0.trim_start_matches("0x")).unwrap()[..]), - U256::from(&hex::decode(&g.1.trim_start_matches("0x")).unwrap()[..]), + U256::from(&hex::decode(g.0.trim_start_matches("0x")).unwrap()[..]), + U256::from(&hex::decode(g.1.trim_start_matches("0x")).unwrap()[..]), ) } @@ -17,12 +17,12 @@ pub fn encode_g2_element(g: &G2Affine) -> ((U256, U256), (U256, U256)) { match g { G2Affine::Fq2(g) => ( ( - U256::from(&hex::decode(&g.0 .0.trim_start_matches("0x")).unwrap()[..]), - U256::from(&hex::decode(&g.0 .1.trim_start_matches("0x")).unwrap()[..]), + U256::from(&hex::decode(g.0 .0.trim_start_matches("0x")).unwrap()[..]), + U256::from(&hex::decode(g.0 .1.trim_start_matches("0x")).unwrap()[..]), ), ( - U256::from(&hex::decode(&g.1 .0.trim_start_matches("0x")).unwrap()[..]), - U256::from(&hex::decode(&g.1 .1.trim_start_matches("0x")).unwrap()[..]), + U256::from(&hex::decode(g.1 .0.trim_start_matches("0x")).unwrap()[..]), + U256::from(&hex::decode(g.1 .1.trim_start_matches("0x")).unwrap()[..]), ), ), _ => unreachable!(), @@ -30,7 +30,7 @@ pub fn encode_g2_element(g: &G2Affine) -> ((U256, U256), (U256, U256)) { } pub fn encode_fr_element(f: &Fr) -> U256 { - U256::from(&hex::decode(&f.trim_start_matches("0x")).unwrap()[..]) + U256::from(&hex::decode(f.trim_start_matches("0x")).unwrap()[..]) } pub trait ToToken: SolidityCompatibleScheme { diff --git a/zokrates_test_derive/src/lib.rs b/zokrates_test_derive/src/lib.rs index 5f7f10f9a..110a02753 100644 --- a/zokrates_test_derive/src/lib.rs +++ b/zokrates_test_derive/src/lib.rs @@ -9,7 +9,7 @@ pub fn write_tests(base: &str) { let base = Path::new(&base); let out_dir = env::var("OUT_DIR").unwrap(); let destination = Path::new(&out_dir).join("tests.rs"); - let test_file = File::create(&destination).unwrap(); + let test_file = File::create(destination).unwrap(); let mut writer = BufWriter::new(test_file); for p in glob(base.join("**/*.json").to_str().unwrap()).unwrap() { From 210285a07d9bef182bfadc17aa0b42ebb3130b0a Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 13 Apr 2023 17:37:10 +0200 Subject: [PATCH 32/32] bump versions, generate changelog --- CHANGELOG.md | 12 +++++++++ Cargo.lock | 38 ++++++++++++++--------------- changelogs/unreleased/1268-dark64 | 1 - changelogs/unreleased/1285-schaeff | 1 - changelogs/unreleased/1287-Turupawn | 1 - changelogs/unreleased/1288-schaeff | 1 - changelogs/unreleased/1289-dark64 | 1 - zokrates_abi/Cargo.toml | 2 +- zokrates_analysis/Cargo.toml | 2 +- zokrates_ark/Cargo.toml | 2 +- zokrates_ast/Cargo.toml | 2 +- zokrates_bellman/Cargo.toml | 2 +- zokrates_circom/Cargo.toml | 2 +- zokrates_cli/Cargo.toml | 2 +- zokrates_codegen/Cargo.toml | 2 +- zokrates_core/Cargo.toml | 2 +- zokrates_core_test/Cargo.toml | 2 +- zokrates_embed/Cargo.toml | 2 +- zokrates_field/Cargo.toml | 2 +- zokrates_interpreter/Cargo.toml | 2 +- zokrates_js/Cargo.toml | 2 +- zokrates_js/package.json | 2 +- zokrates_parser/Cargo.toml | 2 +- zokrates_pest_ast/Cargo.toml | 2 +- zokrates_proof_systems/Cargo.toml | 2 +- zokrates_test/Cargo.toml | 2 +- zokrates_test_derive/Cargo.toml | 2 +- 27 files changed, 51 insertions(+), 44 deletions(-) delete mode 100644 changelogs/unreleased/1268-dark64 delete mode 100644 changelogs/unreleased/1285-schaeff delete mode 100644 changelogs/unreleased/1287-Turupawn delete mode 100644 changelogs/unreleased/1288-schaeff delete mode 100644 changelogs/unreleased/1289-dark64 diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ea84c10a..0c24b138c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file. ## [Unreleased] https://github.com/Zokrates/ZoKrates/compare/latest...develop +## [0.8.6] - 2023-04-13 + +### Release +- https://github.com/Zokrates/ZoKrates/releases/tag/0.8.6 + +### Changes +- Make ZoKrates build on stable rust (#1288, @schaeff) +- Introduce sourcemaps, introduce `inspect` command to identify costly parts of the source (#1285, @schaeff) +- Change witness format to binary, optimize backend integration code to improve proving time (#1289, @dark64) +- Fixed precedence issue on Sudoku example. (#1287, @Turupawn) +- Reduce compiled program size by deduplicating assembly solvers (#1268, @dark64) + ## [0.8.5] - 2023-03-28 ### Release diff --git a/Cargo.lock b/Cargo.lock index 93cdda61a..e74fd1a9d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3090,7 +3090,7 @@ dependencies = [ [[package]] name = "zokrates_abi" -version = "0.1.7" +version = "0.1.8" dependencies = [ "serde", "serde_derive", @@ -3101,7 +3101,7 @@ dependencies = [ [[package]] name = "zokrates_analysis" -version = "0.1.1" +version = "0.1.2" dependencies = [ "cfg-if 0.1.10", "csv", @@ -3124,7 +3124,7 @@ dependencies = [ [[package]] name = "zokrates_ark" -version = "0.1.2" +version = "0.1.3" dependencies = [ "ark-bls12-377", "ark-bn254", @@ -3151,7 +3151,7 @@ dependencies = [ [[package]] name = "zokrates_ast" -version = "0.1.5" +version = "0.1.6" dependencies = [ "ark-bls12-377", "byteorder", @@ -3170,7 +3170,7 @@ dependencies = [ [[package]] name = "zokrates_bellman" -version = "0.1.1" +version = "0.1.2" dependencies = [ "bellman_ce", "getrandom", @@ -3187,7 +3187,7 @@ dependencies = [ [[package]] name = "zokrates_circom" -version = "0.1.2" +version = "0.1.3" dependencies = [ "bellman_ce", "byteorder", @@ -3200,7 +3200,7 @@ dependencies = [ [[package]] name = "zokrates_cli" -version = "0.8.5" +version = "0.8.6" dependencies = [ "assert_cli", "blake2 0.8.1", @@ -3243,7 +3243,7 @@ dependencies = [ [[package]] name = "zokrates_codegen" -version = "0.1.1" +version = "0.1.2" dependencies = [ "zokrates_ast", "zokrates_common", @@ -3261,7 +3261,7 @@ dependencies = [ [[package]] name = "zokrates_core" -version = "0.7.4" +version = "0.7.5" dependencies = [ "cfg-if 0.1.10", "csv", @@ -3286,7 +3286,7 @@ dependencies = [ [[package]] name = "zokrates_core_test" -version = "0.2.10" +version = "0.2.11" dependencies = [ "zokrates_test", "zokrates_test_derive", @@ -3294,7 +3294,7 @@ dependencies = [ [[package]] name = "zokrates_embed" -version = "0.1.9" +version = "0.1.10" dependencies = [ "ark-bls12-377", "ark-bw6-761", @@ -3312,7 +3312,7 @@ dependencies = [ [[package]] name = "zokrates_field" -version = "0.5.3" +version = "0.5.4" dependencies = [ "ark-bls12-377", "ark-bls12-381", @@ -3344,7 +3344,7 @@ dependencies = [ [[package]] name = "zokrates_interpreter" -version = "0.1.3" +version = "0.1.4" dependencies = [ "ark-bls12-377", "num", @@ -3360,7 +3360,7 @@ dependencies = [ [[package]] name = "zokrates_js" -version = "1.1.6" +version = "1.1.7" dependencies = [ "console_error_panic_hook", "getrandom", @@ -3390,7 +3390,7 @@ dependencies = [ [[package]] name = "zokrates_parser" -version = "0.3.3" +version = "0.3.4" dependencies = [ "glob 0.2.11", "pest", @@ -3399,7 +3399,7 @@ dependencies = [ [[package]] name = "zokrates_pest_ast" -version = "0.3.1" +version = "0.3.2" dependencies = [ "from-pest", "glob 0.2.11", @@ -3418,7 +3418,7 @@ dependencies = [ [[package]] name = "zokrates_proof_systems" -version = "0.1.1" +version = "0.1.2" dependencies = [ "blake2 0.8.1", "byteorder", @@ -3444,7 +3444,7 @@ dependencies = [ [[package]] name = "zokrates_test" -version = "0.2.1" +version = "0.2.2" dependencies = [ "getrandom", "rand 0.8.5", @@ -3466,7 +3466,7 @@ dependencies = [ [[package]] name = "zokrates_test_derive" -version = "0.0.1" +version = "0.0.2" dependencies = [ "glob 0.3.1", ] diff --git a/changelogs/unreleased/1268-dark64 b/changelogs/unreleased/1268-dark64 deleted file mode 100644 index 7d2f33daa..000000000 --- a/changelogs/unreleased/1268-dark64 +++ /dev/null @@ -1 +0,0 @@ -Reduce compiled program size by deduplicating assembly solvers \ No newline at end of file diff --git a/changelogs/unreleased/1285-schaeff b/changelogs/unreleased/1285-schaeff deleted file mode 100644 index aa8478a07..000000000 --- a/changelogs/unreleased/1285-schaeff +++ /dev/null @@ -1 +0,0 @@ -Introduce sourcemaps, introduce `inspect` command to identify costly parts of the source \ No newline at end of file diff --git a/changelogs/unreleased/1287-Turupawn b/changelogs/unreleased/1287-Turupawn deleted file mode 100644 index 7cb706c0d..000000000 --- a/changelogs/unreleased/1287-Turupawn +++ /dev/null @@ -1 +0,0 @@ -Fixed precedence issue on Sudoku example. diff --git a/changelogs/unreleased/1288-schaeff b/changelogs/unreleased/1288-schaeff deleted file mode 100644 index a3758ebea..000000000 --- a/changelogs/unreleased/1288-schaeff +++ /dev/null @@ -1 +0,0 @@ -Make ZoKrates build on stable rust \ No newline at end of file diff --git a/changelogs/unreleased/1289-dark64 b/changelogs/unreleased/1289-dark64 deleted file mode 100644 index 950a2f588..000000000 --- a/changelogs/unreleased/1289-dark64 +++ /dev/null @@ -1 +0,0 @@ -Change witness format to binary, optimize backend integration code to improve proving time \ No newline at end of file diff --git a/zokrates_abi/Cargo.toml b/zokrates_abi/Cargo.toml index 953aa53f9..7ee7e15ab 100644 --- a/zokrates_abi/Cargo.toml +++ b/zokrates_abi/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_abi" -version = "0.1.7" +version = "0.1.8" authors = ["Thibaut Schaeffer "] edition = "2018" diff --git a/zokrates_analysis/Cargo.toml b/zokrates_analysis/Cargo.toml index f93dc7d8d..911790400 100644 --- a/zokrates_analysis/Cargo.toml +++ b/zokrates_analysis/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_analysis" -version = "0.1.1" +version = "0.1.2" edition = "2021" [features] diff --git a/zokrates_ark/Cargo.toml b/zokrates_ark/Cargo.toml index 97bd86920..8aec37846 100644 --- a/zokrates_ark/Cargo.toml +++ b/zokrates_ark/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_ark" -version = "0.1.2" +version = "0.1.3" edition = "2021" [features] diff --git a/zokrates_ast/Cargo.toml b/zokrates_ast/Cargo.toml index 38d79d775..1e8e47c9f 100644 --- a/zokrates_ast/Cargo.toml +++ b/zokrates_ast/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_ast" -version = "0.1.5" +version = "0.1.6" edition = "2021" [features] diff --git a/zokrates_bellman/Cargo.toml b/zokrates_bellman/Cargo.toml index fb4c70fee..0c832a678 100644 --- a/zokrates_bellman/Cargo.toml +++ b/zokrates_bellman/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_bellman" -version = "0.1.1" +version = "0.1.2" edition = "2021" [features] diff --git a/zokrates_circom/Cargo.toml b/zokrates_circom/Cargo.toml index 6a0118e31..04de4d523 100644 --- a/zokrates_circom/Cargo.toml +++ b/zokrates_circom/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_circom" -version = "0.1.2" +version = "0.1.3" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/zokrates_cli/Cargo.toml b/zokrates_cli/Cargo.toml index 90066f8fc..8eda11008 100644 --- a/zokrates_cli/Cargo.toml +++ b/zokrates_cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_cli" -version = "0.8.5" +version = "0.8.6" authors = ["Jacob Eberhardt ", "Dennis Kuhnert ", "Thibaut Schaeffer "] repository = "https://github.com/Zokrates/ZoKrates.git" edition = "2018" diff --git a/zokrates_codegen/Cargo.toml b/zokrates_codegen/Cargo.toml index 1d9ec3bc2..30fe981c0 100644 --- a/zokrates_codegen/Cargo.toml +++ b/zokrates_codegen/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_codegen" -version = "0.1.1" +version = "0.1.2" edition = "2021" [features] diff --git a/zokrates_core/Cargo.toml b/zokrates_core/Cargo.toml index 2b2530047..e443f1768 100644 --- a/zokrates_core/Cargo.toml +++ b/zokrates_core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_core" -version = "0.7.4" +version = "0.7.5" edition = "2021" authors = ["Jacob Eberhardt ", "Dennis Kuhnert "] repository = "https://github.com/Zokrates/ZoKrates" diff --git a/zokrates_core_test/Cargo.toml b/zokrates_core_test/Cargo.toml index 19c802668..3730edd67 100644 --- a/zokrates_core_test/Cargo.toml +++ b/zokrates_core_test/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_core_test" -version = "0.2.10" +version = "0.2.11" authors = ["schaeff "] edition = "2018" diff --git a/zokrates_embed/Cargo.toml b/zokrates_embed/Cargo.toml index 147a2f0cc..fcff9510e 100644 --- a/zokrates_embed/Cargo.toml +++ b/zokrates_embed/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_embed" -version = "0.1.9" +version = "0.1.10" authors = ["schaeff "] edition = "2018" diff --git a/zokrates_field/Cargo.toml b/zokrates_field/Cargo.toml index 8bcde430d..c73d38591 100644 --- a/zokrates_field/Cargo.toml +++ b/zokrates_field/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_field" -version = "0.5.3" +version = "0.5.4" authors = ["Thibaut Schaeffer ", "Guillaume Ballet "] edition = "2018" diff --git a/zokrates_interpreter/Cargo.toml b/zokrates_interpreter/Cargo.toml index 270436685..30c9f272d 100644 --- a/zokrates_interpreter/Cargo.toml +++ b/zokrates_interpreter/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_interpreter" -version = "0.1.3" +version = "0.1.4" edition = "2021" [features] diff --git a/zokrates_js/Cargo.toml b/zokrates_js/Cargo.toml index 243f12a80..4840d1c48 100644 --- a/zokrates_js/Cargo.toml +++ b/zokrates_js/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_js" -version = "1.1.6" +version = "1.1.7" authors = ["Darko Macesic"] edition = "2018" diff --git a/zokrates_js/package.json b/zokrates_js/package.json index 02b6e919e..14ca7ef7e 100644 --- a/zokrates_js/package.json +++ b/zokrates_js/package.json @@ -1,6 +1,6 @@ { "name": "zokrates-js", - "version": "1.1.6", + "version": "1.1.7", "module": "index.js", "main": "index-node.js", "description": "JavaScript bindings for ZoKrates", diff --git a/zokrates_parser/Cargo.toml b/zokrates_parser/Cargo.toml index 2928dd0fe..4c143fb5f 100644 --- a/zokrates_parser/Cargo.toml +++ b/zokrates_parser/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_parser" -version = "0.3.3" +version = "0.3.4" authors = ["JacobEberhardt "] edition = "2018" diff --git a/zokrates_pest_ast/Cargo.toml b/zokrates_pest_ast/Cargo.toml index 6e94a080d..a71ef66cd 100644 --- a/zokrates_pest_ast/Cargo.toml +++ b/zokrates_pest_ast/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_pest_ast" -version = "0.3.1" +version = "0.3.2" authors = ["schaeff "] edition = "2018" diff --git a/zokrates_proof_systems/Cargo.toml b/zokrates_proof_systems/Cargo.toml index 55fc91c82..b0b245101 100644 --- a/zokrates_proof_systems/Cargo.toml +++ b/zokrates_proof_systems/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_proof_systems" -version = "0.1.1" +version = "0.1.2" edition = "2021" [dependencies] diff --git a/zokrates_test/Cargo.toml b/zokrates_test/Cargo.toml index 0a601c703..45e01ce8a 100644 --- a/zokrates_test/Cargo.toml +++ b/zokrates_test/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_test" -version = "0.2.1" +version = "0.2.2" authors = ["schaeff "] edition = "2018" diff --git a/zokrates_test_derive/Cargo.toml b/zokrates_test_derive/Cargo.toml index 753297675..5d67294b4 100644 --- a/zokrates_test_derive/Cargo.toml +++ b/zokrates_test_derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_test_derive" -version = "0.0.1" +version = "0.0.2" authors = ["schaeff "] edition = "2018"