Skip to content

Commit

Permalink
LessThan circuits and padding (#14)
Browse files Browse the repository at this point in the history
* Replace verbose fold. Add rayon `into_par_iter`.

Fix tabulation

Add missing import

* Add benchmarks

Fix random `Fp` slice creation

Fix formatting

* Fix random `Fp` vector creation

* Remove duplicate poly definitions

* Co-authored-by: Cesar Descalzo Blanco <[email protected]>

Remove unnecessary `PhantomData`
Co-authored-by: Antonio Mejías Gil <[email protected]>

Progress until 08/03
Co-authored-by: Antonio Mejías Gil <[email protected]>

constructing range check, wip

working on padding

println nightmare, issue apparently fixed

Remove redundant padding. Fix import.

code cleanup

further cleanup

found better way to fix padding

the 'better way' to fix padding actually broke other tests - reverted

This reverts commit d4727ba.

moved example to examples folder, refactored, tested many cases

* fix minor clippy warnings

* Add LessThanCircuitSafe example. Pass arguments as field elements instead of unsigned integers.

* added some documentation

* added message

* Remove unnecesary u64 hints. Cleanup get_msb_index.

* fixed clippy warning

* tweaked documentation

* Co-authored-by: Cesar Descalzo Blanco <[email protected]>

Remove unnecessary `PhantomData`
Co-authored-by: Antonio Mejías Gil <[email protected]>

Progress until 08/03
Co-authored-by: Antonio Mejías Gil <[email protected]>

constructing range check, wip

working on padding

println nightmare, issue apparently fixed

Remove redundant padding. Fix import.

code cleanup

further cleanup

found better way to fix padding

the 'better way' to fix padding actually broke other tests - reverted

This reverts commit d4727ba.

moved example to examples folder, refactored, tested many cases

* Add LessThanCircuitSafe example. Pass arguments as field elements instead of unsigned integers.

* added some documentation

* added message

* Remove unnecesary u64 hints. Cleanup get_msb_index.

* fixed clippy warning

* tweaked documentation

* removed unused

* fixed clippy warning

---------

Co-authored-by: Cesar Descalzo Blanco <[email protected]>
Co-authored-by: Cesar Descalzo Blanco <[email protected]>
Co-authored-by: mmagician <[email protected]>
  • Loading branch information
4 people authored Nov 15, 2024
1 parent afbf72b commit a6d6041
Show file tree
Hide file tree
Showing 8 changed files with 282 additions and 26 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ trademarks or logos is subject to and must follow
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
Any use of third-party trademarks or logos are subject to those third-party's policies.

## Examples

Run `cargo run --example EXAMPLE_NAME` to run the corresponding example. Leave `EXAMPLE_NAME` empty for a list of available examples.

## Benchmarks

Run `cargo bench` to run all benchmarks. Run `cargo bench --benches BENCH_NAME` to run a specific benchmark.
Run `cargo bench` to run all benchmarks. Run `cargo bench --benches BENCH_NAME` to run a specific benchmark.
236 changes: 236 additions & 0 deletions examples/less_than.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
use bellpepper_core::{
boolean::AllocatedBit, num::AllocatedNum, Circuit, ConstraintSystem, LinearCombination,
SynthesisError,
};
use ff::{PrimeField, PrimeFieldBits};
use pasta_curves::Fq;
use spartan2::{
errors::SpartanError,
traits::{snark::RelaxedR1CSSNARKTrait, Group},
SNARK,
};

fn num_to_bits_le_bounded<F: PrimeField + PrimeFieldBits, CS: ConstraintSystem<F>>(
cs: &mut CS,
n: AllocatedNum<F>,
num_bits: u8,
) -> Result<Vec<AllocatedBit>, SynthesisError> {
let opt_bits = match n.get_value() {
Some(v) => v
.to_le_bits()
.into_iter()
.take(num_bits as usize)
.map(Some)
.collect::<Vec<Option<bool>>>(),
None => vec![None; num_bits as usize],
};

// Add one witness per input bit in little-endian bit order
let bits_circuit = opt_bits.into_iter()
.enumerate()
// AllocateBit enforces the value to be 0 or 1 at the constraint level
.map(|(i, b)| AllocatedBit::alloc(cs.namespace(|| format!("b_{}", i)), b).unwrap())
.collect::<Vec<AllocatedBit>>();

let mut weighted_sum_lc = LinearCombination::zero();
let mut pow2 = F::ONE;

for bit in bits_circuit.iter() {
weighted_sum_lc = weighted_sum_lc + (pow2, bit.get_variable());
pow2 = pow2.double();
}

cs.enforce(
|| "bit decomposition check",
|lc| lc + &weighted_sum_lc,
|lc| lc + CS::one(),
|lc| lc + n.get_variable(),
);

Ok(bits_circuit)
}

fn get_msb_index<F: PrimeField + PrimeFieldBits>(n: F) -> u8 {
n.to_le_bits()
.into_iter()
.enumerate()
.rev()
.find(|(_, b)| *b)
.unwrap()
.0 as u8
}

// Constrains `input` < `bound`, where the LHS is a witness and the RHS is a
// constant. The bound must fit into `num_bits` bits (this is asserted in the
// circuit constructor).
// Important: it must be checked elsewhere (in an overarching circuit) that the
// input fits into `num_bits` bits - this is NOT constrained by this circuit
// in order to avoid duplications (hence "unsafe"). Cf. LessThanCircuitSafe for
// a safe version.
#[derive(Clone, Debug)]
struct LessThanCircuitUnsafe<F: PrimeField> {
bound: F, // Will be a constant in the constraits, not a variable
input: F, // Will be an input/output variable
num_bits: u8,
}

impl<F: PrimeField + PrimeFieldBits> LessThanCircuitUnsafe<F> {
fn new(bound: F, input: F, num_bits: u8) -> Self {
assert!(get_msb_index(bound) < num_bits);
Self {
bound,
input,
num_bits,
}
}
}

impl<F: PrimeField + PrimeFieldBits> Circuit<F> for LessThanCircuitUnsafe<F> {
fn synthesize<CS: ConstraintSystem<F>>(self, cs: &mut CS) -> Result<(), SynthesisError> {
assert!(F::NUM_BITS > self.num_bits as u32 + 1);

let input = AllocatedNum::alloc(cs.namespace(|| "input"), || Ok(self.input))?;

let shifted_diff = AllocatedNum::alloc(cs.namespace(|| "shifted_diff"), || {
Ok(self.input + F::from(1 << self.num_bits) - self.bound)
})?;

cs.enforce(
|| "shifted_diff_computation",
|lc| lc + input.get_variable() + (F::from(1 << self.num_bits) - self.bound, CS::one()),
|lc: LinearCombination<F>| lc + CS::one(),
|lc| lc + shifted_diff.get_variable(),
);

let shifted_diff_bits = num_to_bits_le_bounded::<F, CS>(cs, shifted_diff, self.num_bits + 1)?;

// Check that the last (i.e. most sifnificant) bit is 0
cs.enforce(
|| "bound_check",
|lc| lc + shifted_diff_bits[self.num_bits as usize].get_variable(),
|lc| lc + CS::one(),
|lc| lc + (F::ZERO, CS::one()),
);

Ok(())
}
}

// Constrains `input` < `bound`, where the LHS is a witness and the RHS is a
// constant. The bound must fit into `num_bits` bits (this is asserted in the
// circuit constructor).
// Furthermore, the input must fit into `num_bits`, which is enforced at the
// constraint level.
#[derive(Clone, Debug)]
struct LessThanCircuitSafe<F: PrimeField + PrimeFieldBits> {
bound: F,
input: F,
num_bits: u8,
}

impl<F: PrimeField + PrimeFieldBits> LessThanCircuitSafe<F> {
fn new(bound: F, input: F, num_bits: u8) -> Self {
assert!(get_msb_index(bound) < num_bits);
Self {
bound,
input,
num_bits,
}
}
}

impl<F: PrimeField + PrimeFieldBits> Circuit<F> for LessThanCircuitSafe<F> {
fn synthesize<CS: ConstraintSystem<F>>(self, cs: &mut CS) -> Result<(), SynthesisError> {
let input = AllocatedNum::alloc(cs.namespace(|| "input"), || Ok(self.input))?;

// Perform the input bit decomposition check
num_to_bits_le_bounded::<F, CS>(cs, input, self.num_bits)?;

// Entering a new namespace to prefix variables in the
// LessThanCircuitUnsafe, thus avoiding name clashes
cs.push_namespace(|| "less_than_safe");

LessThanCircuitUnsafe {
bound: self.bound,
input: self.input,
num_bits: self.num_bits,
}
.synthesize(cs)
}
}

fn verify_circuit_unsafe<G: Group, S: RelaxedR1CSSNARKTrait<G>>(
bound: G::Scalar,
input: G::Scalar,
num_bits: u8,
) -> Result<(), SpartanError> {
let circuit = LessThanCircuitUnsafe::new(bound, input, num_bits);

// produce keys
let (pk, vk) = SNARK::<G, S, LessThanCircuitUnsafe<_>>::setup(circuit.clone()).unwrap();

// produce a SNARK
let snark = SNARK::prove(&pk, circuit).unwrap();

// verify the SNARK
snark.verify(&vk, &[])
}

fn verify_circuit_safe<G: Group, S: RelaxedR1CSSNARKTrait<G>>(
bound: G::Scalar,
input: G::Scalar,
num_bits: u8,
) -> Result<(), SpartanError> {
let circuit = LessThanCircuitSafe::new(bound, input, num_bits);

// produce keys
let (pk, vk) = SNARK::<G, S, LessThanCircuitSafe<_>>::setup(circuit.clone()).unwrap();

// produce a SNARK
let snark = SNARK::prove(&pk, circuit).unwrap();

// verify the SNARK
snark.verify(&vk, &[])
}

fn main() {
type G = pasta_curves::pallas::Point;
type EE = spartan2::provider::ipa_pc::EvaluationEngine<G>;
type S = spartan2::spartan::snark::RelaxedR1CSSNARK<G, EE>;

println!("Executing unsafe circuit...");
//Typical example, ok
assert!(verify_circuit_unsafe::<G, S>(Fq::from(17), Fq::from(9), 10).is_ok());
// Typical example, err
assert!(verify_circuit_unsafe::<G, S>(Fq::from(17), Fq::from(20), 10).is_err());
// Edge case, err
assert!(verify_circuit_unsafe::<G, S>(Fq::from(4), Fq::from(4), 10).is_err());
// Edge case, ok
assert!(verify_circuit_unsafe::<G, S>(Fq::from(4), Fq::from(3), 10).is_ok());
// Minimum number of bits for the bound, ok
assert!(verify_circuit_unsafe::<G, S>(Fq::from(4), Fq::from(3), 3).is_ok());
// Insufficient number of bits for the input, but this is not detected by the
// unsafety of the circuit (compare with the last example below)
// Note that -Fq::one() is corresponds to q - 1 > bound
assert!(verify_circuit_unsafe::<G, S>(Fq::from(4), -Fq::one(), 3).is_ok());

println!("Unsafe circuit OK");

println!("Executing safe circuit...");
// Typical example, ok
assert!(verify_circuit_safe::<G, S>(Fq::from(17), Fq::from(9), 10).is_ok());
// Typical example, err
assert!(verify_circuit_safe::<G, S>(Fq::from(17), Fq::from(20), 10).is_err());
// Edge case, err
assert!(verify_circuit_safe::<G, S>(Fq::from(4), Fq::from(4), 10).is_err());
// Edge case, ok
assert!(verify_circuit_safe::<G, S>(Fq::from(4), Fq::from(3), 10).is_ok());
// Minimum number of bits for the bound, ok
assert!(verify_circuit_safe::<G, S>(Fq::from(4), Fq::from(3), 3).is_ok());
// Insufficient number of bits for the input, err (compare with the last example
// above).
// Note that -Fq::one() is corresponds to q - 1 > bound
assert!(verify_circuit_safe::<G, S>(Fq::from(4), -Fq::one(), 3).is_err());

println!("Safe circuit OK");
}
9 changes: 4 additions & 5 deletions src/bellpepper/shape_cs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ struct OrderedVariable(Variable);

#[derive(Debug)]
enum NamedObject {
Constraint(usize),
Var(Variable),
Constraint,
Var,
Namespace,
}

Expand Down Expand Up @@ -222,7 +222,7 @@ where
{
fn default() -> Self {
let mut map = HashMap::new();
map.insert("ONE".into(), NamedObject::Var(ShapeCS::<G>::one()));
map.insert("ONE".into(), NamedObject::Var);
ShapeCS {
named_objects: map,
current_namespace: vec![],
Expand Down Expand Up @@ -272,8 +272,7 @@ where
LC: FnOnce(LinearCombination<G::Scalar>) -> LinearCombination<G::Scalar>,
{
let path = compute_path(&self.current_namespace, &annotation().into());
let index = self.constraints.len();
self.set_named_obj(path.clone(), NamedObject::Constraint(index));
self.set_named_obj(path.clone(), NamedObject::Constraint);

let a = a(LinearCombination::zero());
let b = b(LinearCombination::zero());
Expand Down
9 changes: 4 additions & 5 deletions src/bellpepper/test_shape_cs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ struct OrderedVariable(Variable);

#[derive(Debug)]
enum NamedObject {
Constraint(usize),
Var(Variable),
Constraint,
Var,
Namespace,
}

Expand Down Expand Up @@ -224,7 +224,7 @@ where
{
fn default() -> Self {
let mut map = HashMap::new();
map.insert("ONE".into(), NamedObject::Var(TestShapeCS::<G>::one()));
map.insert("ONE".into(), NamedObject::Var);
TestShapeCS {
named_objects: map,
current_namespace: vec![],
Expand Down Expand Up @@ -274,8 +274,7 @@ where
LC: FnOnce(LinearCombination<G::Scalar>) -> LinearCombination<G::Scalar>,
{
let path = compute_path(&self.current_namespace, &annotation().into());
let index = self.constraints.len();
self.set_named_obj(path.clone(), NamedObject::Constraint(index));
self.set_named_obj(path.clone(), NamedObject::Constraint);

let a = a(LinearCombination::zero());
let b = b(LinearCombination::zero());
Expand Down
11 changes: 4 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ impl<G: Group, S: RelaxedR1CSSNARKTrait<G>, C: Circuit<G::Scalar>> SNARK<G, S, C
/// Produces prover and verifier keys for the direct SNARK
pub fn setup(circuit: C) -> Result<(ProverKey<G, S>, VerifierKey<G, S>), SpartanError> {
let (pk, vk) = S::setup(circuit)?;

Ok((ProverKey { pk }, VerifierKey { vk }))
}

Expand Down Expand Up @@ -108,15 +109,12 @@ mod tests {
use super::*;
use crate::provider::{bn256_grumpkin::bn256, secp_secq::secp256k1};
use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError};
use core::marker::PhantomData;
use ff::PrimeField;

#[derive(Clone, Debug, Default)]
struct CubicCircuit<F: PrimeField> {
_p: PhantomData<F>,
}
struct CubicCircuit {}

impl<F> Circuit<F> for CubicCircuit<F>
impl<F> Circuit<F> for CubicCircuit
where
F: PrimeField,
{
Expand Down Expand Up @@ -178,8 +176,7 @@ mod tests {
let circuit = CubicCircuit::default();

// produce keys
let (pk, vk) =
SNARK::<G, S, CubicCircuit<<G as Group>::Scalar>>::setup(circuit.clone()).unwrap();
let (pk, vk) = SNARK::<G, S, CubicCircuit>::setup(circuit.clone()).unwrap();

// produce a SNARK
let res = SNARK::prove(&pk, circuit);
Expand Down
2 changes: 1 addition & 1 deletion src/provider/keccak.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl<G: Group> TranscriptEngineTrait<G> for Keccak256Transcript<G> {

fn absorb<T: TranscriptReprTrait<G>>(&mut self, label: &'static [u8], o: &T) {
self.transcript.update(label);
self.transcript.update(&o.to_transcript_bytes());
self.transcript.update(o.to_transcript_bytes());
}

fn dom_sep(&mut self, bytes: &'static [u8]) {
Expand Down
7 changes: 0 additions & 7 deletions src/spartan/math.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
pub trait Math {
fn pow2(self) -> usize;
fn get_bits(self, num_bits: usize) -> Vec<bool>;
fn log_2(self) -> usize;
}

impl Math for usize {
#[inline]
fn pow2(self) -> usize {
let base: usize = 2;
base.pow(self as u32)
}

/// Returns the `num_bits` from n in a canonical order
fn get_bits(self, num_bits: usize) -> Vec<bool> {
(0..num_bits)
Expand Down
Loading

0 comments on commit a6d6041

Please sign in to comment.