diff --git a/README.md b/README.md index 6bdfe16..36c7ed4 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,10 @@ trademarks or logos is subject to and must follow Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party's policies. +## Examples + +Run `cargo run --example EXAMPLE_NAME` to run the corresponding example. Leave `EXAMPLE_NAME` empty for a list of available examples. + ## Benchmarks -Run `cargo bench` to run all benchmarks. Run `cargo bench --benches BENCH_NAME` to run a specific benchmark. \ No newline at end of file +Run `cargo bench` to run all benchmarks. Run `cargo bench --benches BENCH_NAME` to run a specific benchmark. diff --git a/examples/less_than.rs b/examples/less_than.rs new file mode 100644 index 0000000..a3bac59 --- /dev/null +++ b/examples/less_than.rs @@ -0,0 +1,236 @@ +use bellpepper_core::{ + boolean::AllocatedBit, num::AllocatedNum, Circuit, ConstraintSystem, LinearCombination, + SynthesisError, +}; +use ff::{PrimeField, PrimeFieldBits}; +use pasta_curves::Fq; +use spartan2::{ + errors::SpartanError, + traits::{snark::RelaxedR1CSSNARKTrait, Group}, + SNARK, +}; + +fn num_to_bits_le_bounded>( + cs: &mut CS, + n: AllocatedNum, + num_bits: u8, +) -> Result, SynthesisError> { + let opt_bits = match n.get_value() { + Some(v) => v + .to_le_bits() + .into_iter() + .take(num_bits as usize) + .map(Some) + .collect::>>(), + None => vec![None; num_bits as usize], + }; + + // Add one witness per input bit in little-endian bit order + let bits_circuit = opt_bits.into_iter() + .enumerate() + // AllocateBit enforces the value to be 0 or 1 at the constraint level + .map(|(i, b)| AllocatedBit::alloc(cs.namespace(|| format!("b_{}", i)), b).unwrap()) + .collect::>(); + + let mut weighted_sum_lc = LinearCombination::zero(); + let mut pow2 = F::ONE; + + for bit in bits_circuit.iter() { + weighted_sum_lc = weighted_sum_lc + (pow2, bit.get_variable()); + pow2 = pow2.double(); + } + + cs.enforce( + || "bit decomposition check", + |lc| lc + &weighted_sum_lc, + |lc| lc + CS::one(), + |lc| lc + n.get_variable(), + ); + + Ok(bits_circuit) +} + +fn get_msb_index(n: F) -> u8 { + n.to_le_bits() + .into_iter() + .enumerate() + .rev() + .find(|(_, b)| *b) + .unwrap() + .0 as u8 +} + +// Constrains `input` < `bound`, where the LHS is a witness and the RHS is a +// constant. The bound must fit into `num_bits` bits (this is asserted in the +// circuit constructor). +// Important: it must be checked elsewhere (in an overarching circuit) that the +// input fits into `num_bits` bits - this is NOT constrained by this circuit +// in order to avoid duplications (hence "unsafe"). Cf. LessThanCircuitSafe for +// a safe version. +#[derive(Clone, Debug)] +struct LessThanCircuitUnsafe { + bound: F, // Will be a constant in the constraits, not a variable + input: F, // Will be an input/output variable + num_bits: u8, +} + +impl LessThanCircuitUnsafe { + fn new(bound: F, input: F, num_bits: u8) -> Self { + assert!(get_msb_index(bound) < num_bits); + Self { + bound, + input, + num_bits, + } + } +} + +impl Circuit for LessThanCircuitUnsafe { + fn synthesize>(self, cs: &mut CS) -> Result<(), SynthesisError> { + assert!(F::NUM_BITS > self.num_bits as u32 + 1); + + let input = AllocatedNum::alloc(cs.namespace(|| "input"), || Ok(self.input))?; + + let shifted_diff = AllocatedNum::alloc(cs.namespace(|| "shifted_diff"), || { + Ok(self.input + F::from(1 << self.num_bits) - self.bound) + })?; + + cs.enforce( + || "shifted_diff_computation", + |lc| lc + input.get_variable() + (F::from(1 << self.num_bits) - self.bound, CS::one()), + |lc: LinearCombination| lc + CS::one(), + |lc| lc + shifted_diff.get_variable(), + ); + + let shifted_diff_bits = num_to_bits_le_bounded::(cs, shifted_diff, self.num_bits + 1)?; + + // Check that the last (i.e. most sifnificant) bit is 0 + cs.enforce( + || "bound_check", + |lc| lc + shifted_diff_bits[self.num_bits as usize].get_variable(), + |lc| lc + CS::one(), + |lc| lc + (F::ZERO, CS::one()), + ); + + Ok(()) + } +} + +// Constrains `input` < `bound`, where the LHS is a witness and the RHS is a +// constant. The bound must fit into `num_bits` bits (this is asserted in the +// circuit constructor). +// Furthermore, the input must fit into `num_bits`, which is enforced at the +// constraint level. +#[derive(Clone, Debug)] +struct LessThanCircuitSafe { + bound: F, + input: F, + num_bits: u8, +} + +impl LessThanCircuitSafe { + fn new(bound: F, input: F, num_bits: u8) -> Self { + assert!(get_msb_index(bound) < num_bits); + Self { + bound, + input, + num_bits, + } + } +} + +impl Circuit for LessThanCircuitSafe { + fn synthesize>(self, cs: &mut CS) -> Result<(), SynthesisError> { + let input = AllocatedNum::alloc(cs.namespace(|| "input"), || Ok(self.input))?; + + // Perform the input bit decomposition check + num_to_bits_le_bounded::(cs, input, self.num_bits)?; + + // Entering a new namespace to prefix variables in the + // LessThanCircuitUnsafe, thus avoiding name clashes + cs.push_namespace(|| "less_than_safe"); + + LessThanCircuitUnsafe { + bound: self.bound, + input: self.input, + num_bits: self.num_bits, + } + .synthesize(cs) + } +} + +fn verify_circuit_unsafe>( + bound: G::Scalar, + input: G::Scalar, + num_bits: u8, +) -> Result<(), SpartanError> { + let circuit = LessThanCircuitUnsafe::new(bound, input, num_bits); + + // produce keys + let (pk, vk) = SNARK::>::setup(circuit.clone()).unwrap(); + + // produce a SNARK + let snark = SNARK::prove(&pk, circuit).unwrap(); + + // verify the SNARK + snark.verify(&vk, &[]) +} + +fn verify_circuit_safe>( + bound: G::Scalar, + input: G::Scalar, + num_bits: u8, +) -> Result<(), SpartanError> { + let circuit = LessThanCircuitSafe::new(bound, input, num_bits); + + // produce keys + let (pk, vk) = SNARK::>::setup(circuit.clone()).unwrap(); + + // produce a SNARK + let snark = SNARK::prove(&pk, circuit).unwrap(); + + // verify the SNARK + snark.verify(&vk, &[]) +} + +fn main() { + type G = pasta_curves::pallas::Point; + type EE = spartan2::provider::ipa_pc::EvaluationEngine; + type S = spartan2::spartan::snark::RelaxedR1CSSNARK; + + println!("Executing unsafe circuit..."); + //Typical example, ok + assert!(verify_circuit_unsafe::(Fq::from(17), Fq::from(9), 10).is_ok()); + // Typical example, err + assert!(verify_circuit_unsafe::(Fq::from(17), Fq::from(20), 10).is_err()); + // Edge case, err + assert!(verify_circuit_unsafe::(Fq::from(4), Fq::from(4), 10).is_err()); + // Edge case, ok + assert!(verify_circuit_unsafe::(Fq::from(4), Fq::from(3), 10).is_ok()); + // Minimum number of bits for the bound, ok + assert!(verify_circuit_unsafe::(Fq::from(4), Fq::from(3), 3).is_ok()); + // Insufficient number of bits for the input, but this is not detected by the + // unsafety of the circuit (compare with the last example below) + // Note that -Fq::one() is corresponds to q - 1 > bound + assert!(verify_circuit_unsafe::(Fq::from(4), -Fq::one(), 3).is_ok()); + + println!("Unsafe circuit OK"); + + println!("Executing safe circuit..."); + // Typical example, ok + assert!(verify_circuit_safe::(Fq::from(17), Fq::from(9), 10).is_ok()); + // Typical example, err + assert!(verify_circuit_safe::(Fq::from(17), Fq::from(20), 10).is_err()); + // Edge case, err + assert!(verify_circuit_safe::(Fq::from(4), Fq::from(4), 10).is_err()); + // Edge case, ok + assert!(verify_circuit_safe::(Fq::from(4), Fq::from(3), 10).is_ok()); + // Minimum number of bits for the bound, ok + assert!(verify_circuit_safe::(Fq::from(4), Fq::from(3), 3).is_ok()); + // Insufficient number of bits for the input, err (compare with the last example + // above). + // Note that -Fq::one() is corresponds to q - 1 > bound + assert!(verify_circuit_safe::(Fq::from(4), -Fq::one(), 3).is_err()); + + println!("Safe circuit OK"); +} diff --git a/src/bellpepper/shape_cs.rs b/src/bellpepper/shape_cs.rs index 5a50649..51646ff 100644 --- a/src/bellpepper/shape_cs.rs +++ b/src/bellpepper/shape_cs.rs @@ -15,8 +15,8 @@ struct OrderedVariable(Variable); #[derive(Debug)] enum NamedObject { - Constraint(usize), - Var(Variable), + Constraint, + Var, Namespace, } @@ -222,7 +222,7 @@ where { fn default() -> Self { let mut map = HashMap::new(); - map.insert("ONE".into(), NamedObject::Var(ShapeCS::::one())); + map.insert("ONE".into(), NamedObject::Var); ShapeCS { named_objects: map, current_namespace: vec![], @@ -272,8 +272,7 @@ where LC: FnOnce(LinearCombination) -> LinearCombination, { let path = compute_path(&self.current_namespace, &annotation().into()); - let index = self.constraints.len(); - self.set_named_obj(path.clone(), NamedObject::Constraint(index)); + self.set_named_obj(path.clone(), NamedObject::Constraint); let a = a(LinearCombination::zero()); let b = b(LinearCombination::zero()); diff --git a/src/bellpepper/test_shape_cs.rs b/src/bellpepper/test_shape_cs.rs index cf43a7a..988835f 100644 --- a/src/bellpepper/test_shape_cs.rs +++ b/src/bellpepper/test_shape_cs.rs @@ -16,8 +16,8 @@ struct OrderedVariable(Variable); #[derive(Debug)] enum NamedObject { - Constraint(usize), - Var(Variable), + Constraint, + Var, Namespace, } @@ -224,7 +224,7 @@ where { fn default() -> Self { let mut map = HashMap::new(); - map.insert("ONE".into(), NamedObject::Var(TestShapeCS::::one())); + map.insert("ONE".into(), NamedObject::Var); TestShapeCS { named_objects: map, current_namespace: vec![], @@ -274,8 +274,7 @@ where LC: FnOnce(LinearCombination) -> LinearCombination, { let path = compute_path(&self.current_namespace, &annotation().into()); - let index = self.constraints.len(); - self.set_named_obj(path.clone(), NamedObject::Constraint(index)); + self.set_named_obj(path.clone(), NamedObject::Constraint); let a = a(LinearCombination::zero()); let b = b(LinearCombination::zero()); diff --git a/src/lib.rs b/src/lib.rs index 5a60502..b55f0f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,6 +76,7 @@ impl, C: Circuit> SNARK Result<(ProverKey, VerifierKey), SpartanError> { let (pk, vk) = S::setup(circuit)?; + Ok((ProverKey { pk }, VerifierKey { vk })) } @@ -108,15 +109,12 @@ mod tests { use super::*; use crate::provider::{bn256_grumpkin::bn256, secp_secq::secp256k1}; use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; - use core::marker::PhantomData; use ff::PrimeField; #[derive(Clone, Debug, Default)] - struct CubicCircuit { - _p: PhantomData, - } + struct CubicCircuit {} - impl Circuit for CubicCircuit + impl Circuit for CubicCircuit where F: PrimeField, { @@ -178,8 +176,7 @@ mod tests { let circuit = CubicCircuit::default(); // produce keys - let (pk, vk) = - SNARK::::Scalar>>::setup(circuit.clone()).unwrap(); + let (pk, vk) = SNARK::::setup(circuit.clone()).unwrap(); // produce a SNARK let res = SNARK::prove(&pk, circuit); diff --git a/src/provider/keccak.rs b/src/provider/keccak.rs index 849ae30..2046072 100644 --- a/src/provider/keccak.rs +++ b/src/provider/keccak.rs @@ -87,7 +87,7 @@ impl TranscriptEngineTrait for Keccak256Transcript { fn absorb>(&mut self, label: &'static [u8], o: &T) { self.transcript.update(label); - self.transcript.update(&o.to_transcript_bytes()); + self.transcript.update(o.to_transcript_bytes()); } fn dom_sep(&mut self, bytes: &'static [u8]) { diff --git a/src/spartan/math.rs b/src/spartan/math.rs index 691fec5..44c9f2c 100644 --- a/src/spartan/math.rs +++ b/src/spartan/math.rs @@ -1,16 +1,9 @@ pub trait Math { - fn pow2(self) -> usize; fn get_bits(self, num_bits: usize) -> Vec; fn log_2(self) -> usize; } impl Math for usize { - #[inline] - fn pow2(self) -> usize { - let base: usize = 2; - base.pow(self as u32) - } - /// Returns the `num_bits` from n in a canonical order fn get_bits(self, num_bits: usize) -> Vec { (0..num_bits) diff --git a/src/spartan/snark.rs b/src/spartan/snark.rs index a1ad696..8520650 100644 --- a/src/spartan/snark.rs +++ b/src/spartan/snark.rs @@ -101,6 +101,26 @@ impl> RelaxedR1CSSNARKTrait for Relaxe ) -> Result<(Self::ProverKey, Self::VerifierKey), SpartanError> { let mut cs: ShapeCS = ShapeCS::new(); let _ = circuit.synthesize(&mut cs); + + // Padding the ShapeCS: constraints (rows) and variables (columns) + let num_constraints = cs.num_constraints(); + + (num_constraints..num_constraints.next_power_of_two()).for_each(|i| { + cs.enforce( + || format!("padding_constraint_{i}"), + |lc| lc, + |lc| lc, + |lc| lc, + ) + }); + + let num_vars = cs.num_aux(); + + (num_vars..num_vars.next_power_of_two()).for_each(|i| { + cs.alloc(|| format!("padding_var_{i}"), || Ok(G::Scalar::ZERO)) + .unwrap(); + }); + let (S, ck) = cs.r1cs_shape(); let (pk_ee, vk_ee) = EE::setup(&ck); @@ -121,6 +141,14 @@ impl> RelaxedR1CSSNARKTrait for Relaxe let mut cs: SatisfyingAssignment = SatisfyingAssignment::new(); let _ = circuit.synthesize(&mut cs); + // Padding variables + let num_vars = cs.aux_slice().len(); + + (num_vars..num_vars.next_power_of_two()).for_each(|i| { + cs.alloc(|| format!("padding_var_{i}"), || Ok(G::Scalar::ZERO)) + .unwrap(); + }); + let (u, w) = cs .r1cs_instance_and_witness(&pk.S, &pk.ck) .map_err(|_e| SpartanError::UnSat)?;