From a6d6041da9b25e4acbf2a01598894bc8e66464a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Fri, 15 Nov 2024 22:57:57 +0100 Subject: [PATCH] LessThan circuits and padding (#14) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Replace verbose fold. Add rayon `into_par_iter`. Fix tabulation Add missing import * Add benchmarks Fix random `Fp` slice creation Fix formatting * Fix random `Fp` vector creation * Remove duplicate poly definitions * Co-authored-by: Cesar Descalzo Blanco Remove unnecessary `PhantomData` Co-authored-by: Antonio Mejías Gil Progress until 08/03 Co-authored-by: Antonio Mejías Gil constructing range check, wip working on padding println nightmare, issue apparently fixed Remove redundant padding. Fix import. code cleanup further cleanup found better way to fix padding the 'better way' to fix padding actually broke other tests - reverted This reverts commit d4727ba905065c3f5662699d48c2efb0a3e87feb. moved example to examples folder, refactored, tested many cases * fix minor clippy warnings * Add LessThanCircuitSafe example. Pass arguments as field elements instead of unsigned integers. * added some documentation * added message * Remove unnecesary u64 hints. Cleanup get_msb_index. * fixed clippy warning * tweaked documentation * Co-authored-by: Cesar Descalzo Blanco Remove unnecessary `PhantomData` Co-authored-by: Antonio Mejías Gil Progress until 08/03 Co-authored-by: Antonio Mejías Gil constructing range check, wip working on padding println nightmare, issue apparently fixed Remove redundant padding. Fix import. code cleanup further cleanup found better way to fix padding the 'better way' to fix padding actually broke other tests - reverted This reverts commit d4727ba905065c3f5662699d48c2efb0a3e87feb. moved example to examples folder, refactored, tested many cases * Add LessThanCircuitSafe example. Pass arguments as field elements instead of unsigned integers. * added some documentation * added message * Remove unnecesary u64 hints. Cleanup get_msb_index. * fixed clippy warning * tweaked documentation * removed unused * fixed clippy warning --------- Co-authored-by: Cesar Descalzo Blanco Co-authored-by: Cesar Descalzo Blanco Co-authored-by: mmagician --- README.md | 6 +- examples/less_than.rs | 236 ++++++++++++++++++++++++++++++++ src/bellpepper/shape_cs.rs | 9 +- src/bellpepper/test_shape_cs.rs | 9 +- src/lib.rs | 11 +- src/provider/keccak.rs | 2 +- src/spartan/math.rs | 7 - src/spartan/snark.rs | 28 ++++ 8 files changed, 282 insertions(+), 26 deletions(-) create mode 100644 examples/less_than.rs diff --git a/README.md b/README.md index 6bdfe16..36c7ed4 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,10 @@ trademarks or logos is subject to and must follow Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party's policies. +## Examples + +Run `cargo run --example EXAMPLE_NAME` to run the corresponding example. Leave `EXAMPLE_NAME` empty for a list of available examples. + ## Benchmarks -Run `cargo bench` to run all benchmarks. Run `cargo bench --benches BENCH_NAME` to run a specific benchmark. \ No newline at end of file +Run `cargo bench` to run all benchmarks. Run `cargo bench --benches BENCH_NAME` to run a specific benchmark. diff --git a/examples/less_than.rs b/examples/less_than.rs new file mode 100644 index 0000000..a3bac59 --- /dev/null +++ b/examples/less_than.rs @@ -0,0 +1,236 @@ +use bellpepper_core::{ + boolean::AllocatedBit, num::AllocatedNum, Circuit, ConstraintSystem, LinearCombination, + SynthesisError, +}; +use ff::{PrimeField, PrimeFieldBits}; +use pasta_curves::Fq; +use spartan2::{ + errors::SpartanError, + traits::{snark::RelaxedR1CSSNARKTrait, Group}, + SNARK, +}; + +fn num_to_bits_le_bounded>( + cs: &mut CS, + n: AllocatedNum, + num_bits: u8, +) -> Result, SynthesisError> { + let opt_bits = match n.get_value() { + Some(v) => v + .to_le_bits() + .into_iter() + .take(num_bits as usize) + .map(Some) + .collect::>>(), + None => vec![None; num_bits as usize], + }; + + // Add one witness per input bit in little-endian bit order + let bits_circuit = opt_bits.into_iter() + .enumerate() + // AllocateBit enforces the value to be 0 or 1 at the constraint level + .map(|(i, b)| AllocatedBit::alloc(cs.namespace(|| format!("b_{}", i)), b).unwrap()) + .collect::>(); + + let mut weighted_sum_lc = LinearCombination::zero(); + let mut pow2 = F::ONE; + + for bit in bits_circuit.iter() { + weighted_sum_lc = weighted_sum_lc + (pow2, bit.get_variable()); + pow2 = pow2.double(); + } + + cs.enforce( + || "bit decomposition check", + |lc| lc + &weighted_sum_lc, + |lc| lc + CS::one(), + |lc| lc + n.get_variable(), + ); + + Ok(bits_circuit) +} + +fn get_msb_index(n: F) -> u8 { + n.to_le_bits() + .into_iter() + .enumerate() + .rev() + .find(|(_, b)| *b) + .unwrap() + .0 as u8 +} + +// Constrains `input` < `bound`, where the LHS is a witness and the RHS is a +// constant. The bound must fit into `num_bits` bits (this is asserted in the +// circuit constructor). +// Important: it must be checked elsewhere (in an overarching circuit) that the +// input fits into `num_bits` bits - this is NOT constrained by this circuit +// in order to avoid duplications (hence "unsafe"). Cf. LessThanCircuitSafe for +// a safe version. +#[derive(Clone, Debug)] +struct LessThanCircuitUnsafe { + bound: F, // Will be a constant in the constraits, not a variable + input: F, // Will be an input/output variable + num_bits: u8, +} + +impl LessThanCircuitUnsafe { + fn new(bound: F, input: F, num_bits: u8) -> Self { + assert!(get_msb_index(bound) < num_bits); + Self { + bound, + input, + num_bits, + } + } +} + +impl Circuit for LessThanCircuitUnsafe { + fn synthesize>(self, cs: &mut CS) -> Result<(), SynthesisError> { + assert!(F::NUM_BITS > self.num_bits as u32 + 1); + + let input = AllocatedNum::alloc(cs.namespace(|| "input"), || Ok(self.input))?; + + let shifted_diff = AllocatedNum::alloc(cs.namespace(|| "shifted_diff"), || { + Ok(self.input + F::from(1 << self.num_bits) - self.bound) + })?; + + cs.enforce( + || "shifted_diff_computation", + |lc| lc + input.get_variable() + (F::from(1 << self.num_bits) - self.bound, CS::one()), + |lc: LinearCombination| lc + CS::one(), + |lc| lc + shifted_diff.get_variable(), + ); + + let shifted_diff_bits = num_to_bits_le_bounded::(cs, shifted_diff, self.num_bits + 1)?; + + // Check that the last (i.e. most sifnificant) bit is 0 + cs.enforce( + || "bound_check", + |lc| lc + shifted_diff_bits[self.num_bits as usize].get_variable(), + |lc| lc + CS::one(), + |lc| lc + (F::ZERO, CS::one()), + ); + + Ok(()) + } +} + +// Constrains `input` < `bound`, where the LHS is a witness and the RHS is a +// constant. The bound must fit into `num_bits` bits (this is asserted in the +// circuit constructor). +// Furthermore, the input must fit into `num_bits`, which is enforced at the +// constraint level. +#[derive(Clone, Debug)] +struct LessThanCircuitSafe { + bound: F, + input: F, + num_bits: u8, +} + +impl LessThanCircuitSafe { + fn new(bound: F, input: F, num_bits: u8) -> Self { + assert!(get_msb_index(bound) < num_bits); + Self { + bound, + input, + num_bits, + } + } +} + +impl Circuit for LessThanCircuitSafe { + fn synthesize>(self, cs: &mut CS) -> Result<(), SynthesisError> { + let input = AllocatedNum::alloc(cs.namespace(|| "input"), || Ok(self.input))?; + + // Perform the input bit decomposition check + num_to_bits_le_bounded::(cs, input, self.num_bits)?; + + // Entering a new namespace to prefix variables in the + // LessThanCircuitUnsafe, thus avoiding name clashes + cs.push_namespace(|| "less_than_safe"); + + LessThanCircuitUnsafe { + bound: self.bound, + input: self.input, + num_bits: self.num_bits, + } + .synthesize(cs) + } +} + +fn verify_circuit_unsafe>( + bound: G::Scalar, + input: G::Scalar, + num_bits: u8, +) -> Result<(), SpartanError> { + let circuit = LessThanCircuitUnsafe::new(bound, input, num_bits); + + // produce keys + let (pk, vk) = SNARK::>::setup(circuit.clone()).unwrap(); + + // produce a SNARK + let snark = SNARK::prove(&pk, circuit).unwrap(); + + // verify the SNARK + snark.verify(&vk, &[]) +} + +fn verify_circuit_safe>( + bound: G::Scalar, + input: G::Scalar, + num_bits: u8, +) -> Result<(), SpartanError> { + let circuit = LessThanCircuitSafe::new(bound, input, num_bits); + + // produce keys + let (pk, vk) = SNARK::>::setup(circuit.clone()).unwrap(); + + // produce a SNARK + let snark = SNARK::prove(&pk, circuit).unwrap(); + + // verify the SNARK + snark.verify(&vk, &[]) +} + +fn main() { + type G = pasta_curves::pallas::Point; + type EE = spartan2::provider::ipa_pc::EvaluationEngine; + type S = spartan2::spartan::snark::RelaxedR1CSSNARK; + + println!("Executing unsafe circuit..."); + //Typical example, ok + assert!(verify_circuit_unsafe::(Fq::from(17), Fq::from(9), 10).is_ok()); + // Typical example, err + assert!(verify_circuit_unsafe::(Fq::from(17), Fq::from(20), 10).is_err()); + // Edge case, err + assert!(verify_circuit_unsafe::(Fq::from(4), Fq::from(4), 10).is_err()); + // Edge case, ok + assert!(verify_circuit_unsafe::(Fq::from(4), Fq::from(3), 10).is_ok()); + // Minimum number of bits for the bound, ok + assert!(verify_circuit_unsafe::(Fq::from(4), Fq::from(3), 3).is_ok()); + // Insufficient number of bits for the input, but this is not detected by the + // unsafety of the circuit (compare with the last example below) + // Note that -Fq::one() is corresponds to q - 1 > bound + assert!(verify_circuit_unsafe::(Fq::from(4), -Fq::one(), 3).is_ok()); + + println!("Unsafe circuit OK"); + + println!("Executing safe circuit..."); + // Typical example, ok + assert!(verify_circuit_safe::(Fq::from(17), Fq::from(9), 10).is_ok()); + // Typical example, err + assert!(verify_circuit_safe::(Fq::from(17), Fq::from(20), 10).is_err()); + // Edge case, err + assert!(verify_circuit_safe::(Fq::from(4), Fq::from(4), 10).is_err()); + // Edge case, ok + assert!(verify_circuit_safe::(Fq::from(4), Fq::from(3), 10).is_ok()); + // Minimum number of bits for the bound, ok + assert!(verify_circuit_safe::(Fq::from(4), Fq::from(3), 3).is_ok()); + // Insufficient number of bits for the input, err (compare with the last example + // above). + // Note that -Fq::one() is corresponds to q - 1 > bound + assert!(verify_circuit_safe::(Fq::from(4), -Fq::one(), 3).is_err()); + + println!("Safe circuit OK"); +} diff --git a/src/bellpepper/shape_cs.rs b/src/bellpepper/shape_cs.rs index 5a50649..51646ff 100644 --- a/src/bellpepper/shape_cs.rs +++ b/src/bellpepper/shape_cs.rs @@ -15,8 +15,8 @@ struct OrderedVariable(Variable); #[derive(Debug)] enum NamedObject { - Constraint(usize), - Var(Variable), + Constraint, + Var, Namespace, } @@ -222,7 +222,7 @@ where { fn default() -> Self { let mut map = HashMap::new(); - map.insert("ONE".into(), NamedObject::Var(ShapeCS::::one())); + map.insert("ONE".into(), NamedObject::Var); ShapeCS { named_objects: map, current_namespace: vec![], @@ -272,8 +272,7 @@ where LC: FnOnce(LinearCombination) -> LinearCombination, { let path = compute_path(&self.current_namespace, &annotation().into()); - let index = self.constraints.len(); - self.set_named_obj(path.clone(), NamedObject::Constraint(index)); + self.set_named_obj(path.clone(), NamedObject::Constraint); let a = a(LinearCombination::zero()); let b = b(LinearCombination::zero()); diff --git a/src/bellpepper/test_shape_cs.rs b/src/bellpepper/test_shape_cs.rs index cf43a7a..988835f 100644 --- a/src/bellpepper/test_shape_cs.rs +++ b/src/bellpepper/test_shape_cs.rs @@ -16,8 +16,8 @@ struct OrderedVariable(Variable); #[derive(Debug)] enum NamedObject { - Constraint(usize), - Var(Variable), + Constraint, + Var, Namespace, } @@ -224,7 +224,7 @@ where { fn default() -> Self { let mut map = HashMap::new(); - map.insert("ONE".into(), NamedObject::Var(TestShapeCS::::one())); + map.insert("ONE".into(), NamedObject::Var); TestShapeCS { named_objects: map, current_namespace: vec![], @@ -274,8 +274,7 @@ where LC: FnOnce(LinearCombination) -> LinearCombination, { let path = compute_path(&self.current_namespace, &annotation().into()); - let index = self.constraints.len(); - self.set_named_obj(path.clone(), NamedObject::Constraint(index)); + self.set_named_obj(path.clone(), NamedObject::Constraint); let a = a(LinearCombination::zero()); let b = b(LinearCombination::zero()); diff --git a/src/lib.rs b/src/lib.rs index 5a60502..b55f0f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,6 +76,7 @@ impl, C: Circuit> SNARK Result<(ProverKey, VerifierKey), SpartanError> { let (pk, vk) = S::setup(circuit)?; + Ok((ProverKey { pk }, VerifierKey { vk })) } @@ -108,15 +109,12 @@ mod tests { use super::*; use crate::provider::{bn256_grumpkin::bn256, secp_secq::secp256k1}; use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; - use core::marker::PhantomData; use ff::PrimeField; #[derive(Clone, Debug, Default)] - struct CubicCircuit { - _p: PhantomData, - } + struct CubicCircuit {} - impl Circuit for CubicCircuit + impl Circuit for CubicCircuit where F: PrimeField, { @@ -178,8 +176,7 @@ mod tests { let circuit = CubicCircuit::default(); // produce keys - let (pk, vk) = - SNARK::::Scalar>>::setup(circuit.clone()).unwrap(); + let (pk, vk) = SNARK::::setup(circuit.clone()).unwrap(); // produce a SNARK let res = SNARK::prove(&pk, circuit); diff --git a/src/provider/keccak.rs b/src/provider/keccak.rs index 849ae30..2046072 100644 --- a/src/provider/keccak.rs +++ b/src/provider/keccak.rs @@ -87,7 +87,7 @@ impl TranscriptEngineTrait for Keccak256Transcript { fn absorb>(&mut self, label: &'static [u8], o: &T) { self.transcript.update(label); - self.transcript.update(&o.to_transcript_bytes()); + self.transcript.update(o.to_transcript_bytes()); } fn dom_sep(&mut self, bytes: &'static [u8]) { diff --git a/src/spartan/math.rs b/src/spartan/math.rs index 691fec5..44c9f2c 100644 --- a/src/spartan/math.rs +++ b/src/spartan/math.rs @@ -1,16 +1,9 @@ pub trait Math { - fn pow2(self) -> usize; fn get_bits(self, num_bits: usize) -> Vec; fn log_2(self) -> usize; } impl Math for usize { - #[inline] - fn pow2(self) -> usize { - let base: usize = 2; - base.pow(self as u32) - } - /// Returns the `num_bits` from n in a canonical order fn get_bits(self, num_bits: usize) -> Vec { (0..num_bits) diff --git a/src/spartan/snark.rs b/src/spartan/snark.rs index a1ad696..8520650 100644 --- a/src/spartan/snark.rs +++ b/src/spartan/snark.rs @@ -101,6 +101,26 @@ impl> RelaxedR1CSSNARKTrait for Relaxe ) -> Result<(Self::ProverKey, Self::VerifierKey), SpartanError> { let mut cs: ShapeCS = ShapeCS::new(); let _ = circuit.synthesize(&mut cs); + + // Padding the ShapeCS: constraints (rows) and variables (columns) + let num_constraints = cs.num_constraints(); + + (num_constraints..num_constraints.next_power_of_two()).for_each(|i| { + cs.enforce( + || format!("padding_constraint_{i}"), + |lc| lc, + |lc| lc, + |lc| lc, + ) + }); + + let num_vars = cs.num_aux(); + + (num_vars..num_vars.next_power_of_two()).for_each(|i| { + cs.alloc(|| format!("padding_var_{i}"), || Ok(G::Scalar::ZERO)) + .unwrap(); + }); + let (S, ck) = cs.r1cs_shape(); let (pk_ee, vk_ee) = EE::setup(&ck); @@ -121,6 +141,14 @@ impl> RelaxedR1CSSNARKTrait for Relaxe let mut cs: SatisfyingAssignment = SatisfyingAssignment::new(); let _ = circuit.synthesize(&mut cs); + // Padding variables + let num_vars = cs.aux_slice().len(); + + (num_vars..num_vars.next_power_of_two()).for_each(|i| { + cs.alloc(|| format!("padding_var_{i}"), || Ok(G::Scalar::ZERO)) + .unwrap(); + }); + let (u, w) = cs .r1cs_instance_and_witness(&pk.S, &pk.ck) .map_err(|_e| SpartanError::UnSat)?;