Skip to content

Commit

Permalink
Better tests and in place operations
Browse files Browse the repository at this point in the history
  • Loading branch information
Pratyush committed Jan 3, 2024
1 parent 715d94a commit b03717a
Show file tree
Hide file tree
Showing 9 changed files with 306 additions and 89 deletions.
11 changes: 9 additions & 2 deletions src/boolean/not.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,17 @@ use super::Boolean;

impl<F: Field> Boolean<F> {
fn _not(&self) -> Result<Self, SynthesisError> {
let mut result = self.clone();
result.not_in_place()?;
Ok(result)
}

pub fn not_in_place(&mut self) -> Result<(), SynthesisError> {
match *self {
Boolean::Constant(c) => Ok(Boolean::Constant(!c)),
Boolean::Var(ref v) => Ok(Boolean::Var(v.not().unwrap())),
Boolean::Constant(ref mut c) => *c = !*c,
Boolean::Var(ref mut v) => *v = v.not()?,
}
Ok(())
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/cmp.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use ark_ff::Field;
use ark_relations::r1cs::SynthesisError;

use crate::boolean::Boolean;
use crate::{boolean::Boolean, R1CSVar};

/// Specifies how to generate constraints for comparing two variables.
pub trait CmpGadget<F: Field> {
pub trait CmpGadget<F: Field>: R1CSVar<F> {
fn is_gt(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
other.is_lt(self)
}
Expand Down
98 changes: 76 additions & 22 deletions src/uint/and.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@ use super::*;
impl<const N: usize, T: PrimUInt, F: Field> UInt<N, T, F> {
fn _and(&self, other: &Self) -> Result<Self, SynthesisError> {
let mut result = self.clone();
for (a, b) in result.bits.iter_mut().zip(&other.bits) {
result._and_in_place(other)?;
Ok(result)
}

fn _and_in_place(&mut self, other: &Self) -> Result<(), SynthesisError> {
for (a, b) in self.bits.iter_mut().zip(&other.bits) {
*a &= b;
}
result.value = self.value.and_then(|a| Some(a & other.value?));
Ok(result)
self.value = self.value.and_then(|a| Some(a & other.value?));
Ok(())
}
}

Expand Down Expand Up @@ -70,8 +75,9 @@ impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd<&'a Self> for UInt<N, T,
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: &Self) -> Self::Output {
self._and(&other).unwrap()
fn bitand(mut self, other: &Self) -> Self::Output {
self._and_in_place(other).unwrap();
self
}
}

Expand Down Expand Up @@ -102,7 +108,7 @@ impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd<UInt<N, T, F>> for &'a UI
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: UInt<N, T, F>) -> Self::Output {
self._and(&other).unwrap()
other & self
}
}

Expand Down Expand Up @@ -133,7 +139,7 @@ impl<const N: usize, T: PrimUInt, F: Field> BitAnd<Self> for UInt<N, T, F> {
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: Self) -> Self::Output {
self._and(&other).unwrap()
self & &other
}
}

Expand All @@ -142,7 +148,7 @@ impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd<T> for UInt<N, T, F> {

#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: T) -> Self::Output {
self._and(&UInt::constant(other)).unwrap()
self & UInt::constant(other)
}
}

Expand All @@ -151,7 +157,7 @@ impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd<&'a T> for UInt<N, T, F>

#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: &'a T) -> Self::Output {
self._and(&Self::constant(*other)).unwrap()
self & UInt::constant(*other)
}
}

Expand All @@ -160,7 +166,7 @@ impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd<&'a T> for &'a UInt<N, T,

#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: &'a T) -> Self::Output {
self._and(&UInt::constant(*other)).unwrap()
self & UInt::constant(*other)
}
}

Expand All @@ -169,11 +175,10 @@ impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd<T> for &'a UInt<N, T, F>

#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: T) -> Self::Output {
self._and(&UInt::constant(other)).unwrap()
self & UInt::constant(other)
}
}


impl<const N: usize, T: PrimUInt, F: Field> BitAndAssign<Self> for UInt<N, T, F> {
/// Sets `self = self & other`.
///
Expand All @@ -200,8 +205,7 @@ impl<const N: usize, T: PrimUInt, F: Field> BitAndAssign<Self> for UInt<N, T, F>
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand_assign(&mut self, other: Self) {
let result = self._and(&other).unwrap();
*self = result;
self._and_in_place(&other).unwrap();
}
}

Expand Down Expand Up @@ -231,8 +235,21 @@ impl<'a, const N: usize, T: PrimUInt, F: Field> BitAndAssign<&'a Self> for UInt<
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand_assign(&mut self, other: &'a Self) {
let result = self._and(other).unwrap();
*self = result;
self._and_in_place(&other).unwrap();
}
}

impl<const N: usize, T: PrimUInt, F: Field> BitAndAssign<T> for UInt<N, T, F> {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand_assign(&mut self, other: T) {
*self &= &Self::constant(other);
}
}

impl<'a, const N: usize, T: PrimUInt, F: Field> BitAndAssign<&'a T> for UInt<N, T, F> {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand_assign(&mut self, other: &'a T) {
*self &= &Self::constant(*other);
}
}

Expand All @@ -242,7 +259,7 @@ mod tests {
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive, run_binary_random},
uint::test_utils::{run_binary_exhaustive_both, run_binary_random_both},
R1CSVar,
};
use ark_ff::PrimeField;
Expand Down Expand Up @@ -273,28 +290,65 @@ mod tests {
Ok(())
}

fn uint_and_native<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
b: T,
) -> Result<(), SynthesisError> {
let cs = a.cs();
let computed = &a & b;
let expected_mode = if a.is_constant() {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let expected =
UInt::<N, T, F>::new_variable(cs.clone(), || Ok(a.value()? & b), expected_mode)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !a.is_constant() {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}

#[test]
fn u8_and() {
run_binary_exhaustive(uint_and::<u8, 8, Fr>).unwrap()
run_binary_exhaustive_both(uint_and::<u8, 8, Fr>, uint_and_native::<u8, 8, Fr>).unwrap()
}

#[test]
fn u16_and() {
run_binary_random::<1000, 16, _, _>(uint_and::<u16, 16, Fr>).unwrap()
run_binary_random_both::<1000, 16, _, _>(
uint_and::<u16, 16, Fr>,
uint_and_native::<u16, 16, Fr>,
)
.unwrap()
}

#[test]
fn u32_and() {
run_binary_random::<1000, 32, _, _>(uint_and::<u32, 32, Fr>).unwrap()
run_binary_random_both::<1000, 32, _, _>(
uint_and::<u32, 32, Fr>,
uint_and_native::<u32, 32, Fr>,
)
.unwrap()
}

#[test]
fn u64_and() {
run_binary_random::<1000, 64, _, _>(uint_and::<u64, 64, Fr>).unwrap()
run_binary_random_both::<1000, 64, _, _>(
uint_and::<u64, 64, Fr>,
uint_and_native::<u64, 64, Fr>,
)
.unwrap()
}

#[test]
fn u128_and() {
run_binary_random::<1000, 128, _, _>(uint_and::<u128, 128, Fr>).unwrap()
run_binary_random_both::<1000, 128, _, _>(
uint_and::<u128, 128, Fr>,
uint_and_native::<u128, 128, Fr>,
)
.unwrap()
}
}
18 changes: 12 additions & 6 deletions src/uint/not.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@ use super::*;
impl<const N: usize, T: PrimUInt, F: Field> UInt<N, T, F> {
fn _not(&self) -> Result<Self, SynthesisError> {
let mut result = self.clone();
for a in &mut result.bits {
*a = !&*a
}
result.value = self.value.map(Not::not);
result._not_in_place()?;
Ok(result)
}

fn _not_in_place(&mut self) -> Result<(), SynthesisError> {
for a in &mut self.bits {
a.not_in_place()?;
}
self.value = self.value.map(Not::not);
Ok(())
}
}

impl<'a, const N: usize, T: PrimUInt, F: Field> Not for &'a UInt<N, T, F> {
Expand Down Expand Up @@ -67,8 +72,9 @@ impl<'a, const N: usize, T: PrimUInt, F: Field> Not for UInt<N, T, F> {
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self))]
fn not(self) -> Self::Output {
self._not().unwrap()
fn not(mut self) -> Self::Output {
self._not_in_place().unwrap();
self
}
}

Expand Down
Loading

0 comments on commit b03717a

Please sign in to comment.