Skip to content

Commit

Permalink
Add option for partitioned trace commitment (#336)
Browse files Browse the repository at this point in the history
  • Loading branch information
Al-Kindi-0 authored Oct 24, 2024
1 parent 9a9377f commit 6204d61
Show file tree
Hide file tree
Showing 32 changed files with 369 additions and 81 deletions.
2 changes: 1 addition & 1 deletion air/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ mod errors;
pub use errors::AssertionError;

mod options;
pub use options::{FieldExtension, ProofOptions};
pub use options::{FieldExtension, PartitionOptions, ProofOptions};

mod air;
pub use air::{
Expand Down
114 changes: 103 additions & 11 deletions air/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
// LICENSE file in the root directory of this source tree.

use alloc::vec::Vec;
use core::{cmp, ops::Div};

use fri::FriOptions;
use math::{StarkField, ToElements};
use math::{FieldElement, StarkField, ToElements};
use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

// CONSTANTS
Expand Down Expand Up @@ -74,6 +75,17 @@ pub enum FieldExtension {
/// is the hash function used in the protocol. The soundness of a STARK proof is limited by the
/// collision resistance of the hash function used by the protocol. For example, if a hash function
/// with 128-bit collision resistance is used, soundness of a STARK proof cannot exceed 128 bits.
///
/// In addition to the above, the parameter `num_partitions` is used in order to specify the number
/// of partitions each of the traces committed to during proof generation is split into, and
/// the parameter `min_partition_size` gives a lower bound on the size of each such partition.
/// More precisely, and taking the main segment trace as an example, the prover will split the main
/// segment trace into `num_partitions` parts each of size at least `min_partition_size`. The prover
/// will then proceed to hash each part row-wise resulting in `num_partitions` digests per row of
/// the trace. The prover finally combines the `num_partitions` digests (per row) into one digest
/// (per row) and at this point the vector commitment scheme can be called.
/// In the case when `num_partitions` is equal to `1` the prover will just hash each row in one go
/// producing one digest per row of the trace.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ProofOptions {
num_queries: u8,
Expand All @@ -82,6 +94,7 @@ pub struct ProofOptions {
field_extension: FieldExtension,
fri_folding_factor: u8,
fri_remainder_max_degree: u8,
partition_options: PartitionOptions,
}

// PROOF OPTIONS IMPLEMENTATION
Expand All @@ -108,7 +121,6 @@ impl ProofOptions {
/// - `grinding_factor` is greater than 32.
/// - `fri_folding_factor` is not 2, 4, 8, or 16.
/// - `fri_remainder_max_degree` is greater than 255 or is not a power of two minus 1.
#[rustfmt::skip]
pub const fn new(
num_queries: usize,
blowup_factor: usize,
Expand All @@ -125,11 +137,20 @@ impl ProofOptions {
assert!(blowup_factor >= MIN_BLOWUP_FACTOR, "blowup factor cannot be smaller than 2");
assert!(blowup_factor <= MAX_BLOWUP_FACTOR, "blowup factor cannot be greater than 128");

assert!(grinding_factor <= MAX_GRINDING_FACTOR, "grinding factor cannot be greater than 32");
assert!(
grinding_factor <= MAX_GRINDING_FACTOR,
"grinding factor cannot be greater than 32"
);

assert!(fri_folding_factor.is_power_of_two(), "FRI folding factor must be a power of 2");
assert!(fri_folding_factor >= FRI_MIN_FOLDING_FACTOR, "FRI folding factor cannot be smaller than 2");
assert!(fri_folding_factor <= FRI_MAX_FOLDING_FACTOR, "FRI folding factor cannot be greater than 16");
assert!(
fri_folding_factor >= FRI_MIN_FOLDING_FACTOR,
"FRI folding factor cannot be smaller than 2"
);
assert!(
fri_folding_factor <= FRI_MAX_FOLDING_FACTOR,
"FRI folding factor cannot be greater than 16"
);

assert!(
(fri_remainder_max_degree + 1).is_power_of_two(),
Expand All @@ -140,16 +161,33 @@ impl ProofOptions {
"FRI polynomial remainder degree cannot be greater than 255"
);

ProofOptions {
Self {
num_queries: num_queries as u8,
blowup_factor: blowup_factor as u8,
grinding_factor: grinding_factor as u8,
field_extension,
fri_folding_factor: fri_folding_factor as u8,
fri_remainder_max_degree: fri_remainder_max_degree as u8,
partition_options: PartitionOptions::new(1, 1),
}
}

/// Updates the provided [`ProofOptions`] instance with the specified partition parameters.
///
/// # Panics
/// Panics if:
/// - `num_partitions` is zero or greater than 16.
/// - `min_partition_size` is zero or greater than 256.
pub const fn with_partitions(
    mut self,
    num_partitions: usize,
    min_partition_size: usize,
) -> ProofOptions {
    // validation of the partition parameters is delegated to `PartitionOptions::new`
    let partition_options = PartitionOptions::new(num_partitions, min_partition_size);
    self.partition_options = partition_options;

    self
}

// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------

Expand Down Expand Up @@ -206,6 +244,11 @@ impl ProofOptions {
let remainder_max_degree = self.fri_remainder_max_degree as usize;
FriOptions::new(self.blowup_factor(), folding_factor, remainder_max_degree)
}

/// Returns the [`PartitionOptions`] used in this instance of proof options.
pub fn partition_options(&self) -> PartitionOptions {
    self.partition_options
}
}

impl<E: StarkField> ToElements<E> for ProofOptions {
Expand Down Expand Up @@ -233,6 +276,8 @@ impl Serializable for ProofOptions {
target.write(self.field_extension);
target.write_u8(self.fri_folding_factor);
target.write_u8(self.fri_remainder_max_degree);
target.write_u8(self.partition_options.num_partitions);
target.write_u8(self.partition_options.min_partition_size);
}
}

Expand All @@ -242,14 +287,15 @@ impl Deserializable for ProofOptions {
/// # Errors
/// Returns an error if valid proof options could not be read from the specified `source`.
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
    // read the base protocol parameters in the same order in which they were serialized
    let num_queries = source.read_u8()? as usize;
    let blowup_factor = source.read_u8()? as usize;
    let grinding_factor = source.read_u8()? as u32;
    let field_extension = FieldExtension::read_from(source)?;
    let fri_folding_factor = source.read_u8()? as usize;
    let fri_remainder_max_degree = source.read_u8()? as usize;

    let options = ProofOptions::new(
        num_queries,
        blowup_factor,
        grinding_factor,
        field_extension,
        fri_folding_factor,
        fri_remainder_max_degree,
    );

    // the partition parameters are serialized last, after the base parameters
    let num_partitions = source.read_u8()? as usize;
    let min_partition_size = source.read_u8()? as usize;
    Ok(options.with_partitions(num_partitions, min_partition_size))
}
}

Expand All @@ -272,9 +318,6 @@ impl FieldExtension {
}
}

// SERIALIZATION
// ================================================================================================

impl Serializable for FieldExtension {
/// Serializes `self` and writes the resulting bytes into the `target`.
fn write_into<W: ByteWriter>(&self, target: &mut W) {
Expand All @@ -301,6 +344,55 @@ impl Deserializable for FieldExtension {
}
}

// PARTITION OPTION IMPLEMENTATION
// ================================================================================================

/// Defines the parameters used when committing to the traces generated during the protocol.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct PartitionOptions {
    // number of partitions each committed trace is split into; constrained to [1, 16] by
    // `PartitionOptions::new`
    num_partitions: u8,
    // lower bound on the number of base field elements held by each partition; constrained to
    // be at least 1 by `PartitionOptions::new`
    min_partition_size: u8,
}

impl PartitionOptions {
    /// Returns a new instance of [`PartitionOptions`].
    ///
    /// # Panics
    /// Panics if:
    /// - `num_partitions` is zero or greater than 16.
    /// - `min_partition_size` is zero or greater than 256.
    pub const fn new(num_partitions: usize, min_partition_size: usize) -> Self {
        assert!(num_partitions >= 1, "number of partitions must be greater than or equal to 1");
        assert!(num_partitions <= 16, "number of partitions must be smaller than or equal to 16");

        assert!(
            min_partition_size >= 1,
            "smallest partition size must be greater than or equal to 1"
        );
        assert!(
            min_partition_size <= 256,
            "smallest partition size must be smaller than or equal to 256"
        );

        // NOTE(review): both values are stored as `u8`, so `min_partition_size == 256` (which the
        // assert above allows) would silently truncate to 0 in the cast below; consider capping
        // the bound at 255 or widening the field — TODO confirm the intended upper bound.
        Self {
            num_partitions: num_partitions as u8,
            min_partition_size: min_partition_size as u8,
        }
    }

    /// Returns the size, in field elements of type `E`, of each partition used when committing
    /// to the main and auxiliary traces as well as the constraint evaluation trace.
    ///
    /// The `num_columns` columns (each contributing `E::EXTENSION_DEGREE` base field elements
    /// per row) are spread as evenly as possible across `num_partitions` partitions, while
    /// ensuring that each partition holds at least `min_partition_size` base field elements.
    pub fn partition_size<E: FieldElement>(&self, num_columns: usize) -> usize {
        // number of base field elements per partition: an even split of all base elements,
        // rounded up, but never below the configured minimum
        let base_elements_per_partition = cmp::max(
            (num_columns * E::EXTENSION_DEGREE).div_ceil(self.num_partitions as usize),
            self.min_partition_size as usize,
        );

        // convert back from base field elements to elements of type E
        base_elements_per_partition.div(E::EXTENSION_DEGREE)
    }
}

impl Default for PartitionOptions {
fn default() -> Self {
Self { num_partitions: 1, min_partition_size: 1 }
}
}

// TESTS
// ================================================================================================

Expand Down
9 changes: 9 additions & 0 deletions crypto/src/hash/blake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ impl<B: StarkField> Hasher for Blake3_256<B> {
ByteDigest(blake3::hash(ByteDigest::digests_as_bytes(values)).into())
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
    // hash the concatenated bytes of all digests in a single pass
    let bytes = ByteDigest::digests_as_bytes(values);
    ByteDigest(blake3::hash(bytes).into())
}

fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
let mut data = [0; 40];
data[..32].copy_from_slice(&seed.0);
Expand Down Expand Up @@ -84,6 +88,11 @@ impl<B: StarkField> Hasher for Blake3_192<B> {
ByteDigest(result.as_bytes()[..24].try_into().unwrap())
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
    // hash the concatenated digest bytes, then truncate the output to 24 bytes (192 bits)
    let bytes = ByteDigest::digests_as_bytes(values);
    let hash = blake3::hash(bytes);
    ByteDigest(hash.as_bytes()[..24].try_into().unwrap())
}

fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
let mut data = [0; 32];
data[..24].copy_from_slice(&seed.0);
Expand Down
24 changes: 24 additions & 0 deletions crypto/src/hash/blake/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

use math::{fields::f62::BaseElement, FieldElement};
use rand_utils::rand_array;
use utils::Deserializable;

use super::{Blake3_256, ElementHasher, Hasher};
use crate::hash::{Blake3_192, ByteDigest};

#[test]
fn hash_padding() {
Expand All @@ -29,3 +31,25 @@ fn hash_elements_padding() {
let r2 = Blake3_256::hash_elements(&e2);
assert_ne!(r1, r2);
}

#[test]
fn merge_vs_merge_many_256() {
    // hashing exactly two digests via merge() and merge_many() must produce the same result
    let digests = [
        ByteDigest::read_from_bytes(&[1_u8; 32]).unwrap(),
        ByteDigest::read_from_bytes(&[2_u8; 32]).unwrap(),
    ];

    assert_eq!(
        Blake3_256::<BaseElement>::merge(&digests),
        Blake3_256::<BaseElement>::merge_many(&digests)
    );
}

#[test]
fn merge_vs_merge_many_192() {
    // hashing exactly two digests via merge() and merge_many() must produce the same result
    let digests = [
        ByteDigest::read_from_bytes(&[1_u8; 24]).unwrap(),
        ByteDigest::read_from_bytes(&[2_u8; 24]).unwrap(),
    ];

    assert_eq!(
        Blake3_192::<BaseElement>::merge(&digests),
        Blake3_192::<BaseElement>::merge_many(&digests)
    );
}
3 changes: 3 additions & 0 deletions crypto/src/hash/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ pub trait Hasher {
/// Merkle trees.
fn merge(values: &[Self::Digest; 2]) -> Self::Digest;

/// Returns a hash of many digests.
///
/// Unlike [`Hasher::merge`], which operates on exactly two digests, this method accepts an
/// arbitrary number of digests and reduces all of them to a single digest.
fn merge_many(values: &[Self::Digest]) -> Self::Digest;

/// Returns hash(`seed` || `value`). This method is intended for use in PRNG and PoW contexts.
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest;
}
Expand Down
4 changes: 4 additions & 0 deletions crypto/src/hash/rescue/rp62_248/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ impl Hasher for Rp62_248 {
ElementDigest::new(state[..DIGEST_SIZE].try_into().unwrap())
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
    // re-interpret the digests as a flat sequence of field elements and hash them
    let elements = ElementDigest::digests_as_elements(values);
    Self::hash_elements(elements)
}

fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
// initialize the state as follows:
// - seed is copied into the first 4 elements of the state.
Expand Down
14 changes: 14 additions & 0 deletions crypto/src/hash/rescue/rp62_248/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,20 @@ fn hash_elements_vs_merge() {
assert_eq!(m_result, h_result);
}

#[test]
fn merge_vs_merge_many() {
    // merging two digests via merge() and merge_many() must produce the same result
    let elements: [BaseElement; 8] = rand_array();
    let (first, second) = elements.split_at(4);

    let digests: [ElementDigest; 2] = [
        ElementDigest::new(first.try_into().unwrap()),
        ElementDigest::new(second.try_into().unwrap()),
    ];

    assert_eq!(Rp62_248::merge(&digests), Rp62_248::merge_many(&digests));
}

#[test]
fn hash_elements_vs_merge_with_int() {
let seed = ElementDigest::new(rand_array());
Expand Down
4 changes: 4 additions & 0 deletions crypto/src/hash/rescue/rp64_256/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ impl Hasher for Rp64_256 {
ElementDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
    // re-interpret the digests as a flat sequence of field elements and hash them
    let elements = ElementDigest::digests_as_elements(values);
    Self::hash_elements(elements)
}

fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
// initialize the state as follows:
// - seed is copied into the first 4 elements of the rate portion of the state.
Expand Down
14 changes: 14 additions & 0 deletions crypto/src/hash/rescue/rp64_256/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,20 @@ fn hash_elements_vs_merge() {
assert_eq!(m_result, h_result);
}

#[test]
fn merge_vs_merge_many() {
    // merging two digests via merge() and merge_many() must produce the same result
    let elements: [BaseElement; 8] = rand_array();
    let (first, second) = elements.split_at(4);

    let digests: [ElementDigest; 2] = [
        ElementDigest::new(first.try_into().unwrap()),
        ElementDigest::new(second.try_into().unwrap()),
    ];

    assert_eq!(Rp64_256::merge(&digests), Rp64_256::merge_many(&digests));
}

#[test]
fn hash_elements_vs_merge_with_int() {
let seed = ElementDigest::new(rand_array());
Expand Down
4 changes: 4 additions & 0 deletions crypto/src/hash/rescue/rp64_256_jive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ impl Hasher for RpJive64_256 {
Self::apply_jive_summation(&initial_state, &state)
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
    // view the digests as one contiguous run of field elements and hash it
    let elements = ElementDigest::digests_as_elements(values);
    Self::hash_elements(elements)
}

// We do not rely on the sponge construction to build our compression function. Instead, we use
// the Jive compression mode designed in https://eprint.iacr.org/2022/840.pdf.
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
Expand Down
4 changes: 4 additions & 0 deletions crypto/src/hash/sha/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ impl<B: StarkField> Hasher for Sha3_256<B> {
ByteDigest(sha3::Sha3_256::digest(ByteDigest::digests_as_bytes(values)).into())
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
    // hash the concatenated bytes of all digests with SHA3-256 in a single pass
    let bytes = ByteDigest::digests_as_bytes(values);
    ByteDigest(sha3::Sha3_256::digest(bytes).into())
}

fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
let mut data = [0; 40];
data[..32].copy_from_slice(&seed.0);
Expand Down
7 changes: 4 additions & 3 deletions examples/src/fibonacci/fib2/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

use winterfell::{
crypto::MerkleTree, matrix::ColMatrix, AuxRandElements, ConstraintCompositionCoefficients,
DefaultConstraintEvaluator, DefaultTraceLde, StarkDomain, Trace, TraceInfo, TracePolyTable,
TraceTable,
DefaultConstraintEvaluator, DefaultTraceLde, PartitionOptions, StarkDomain, Trace, TraceInfo,
TracePolyTable, TraceTable,
};

use super::{
Expand Down Expand Up @@ -77,8 +77,9 @@ where
trace_info: &TraceInfo,
main_trace: &ColMatrix<Self::BaseField>,
domain: &StarkDomain<Self::BaseField>,
partition_option: PartitionOptions,
) -> (Self::TraceLde<E>, TracePolyTable<E>) {
DefaultTraceLde::new(trace_info, main_trace, domain)
DefaultTraceLde::new(trace_info, main_trace, domain, partition_option)
}

fn new_evaluator<'a, E: FieldElement<BaseField = Self::BaseField>>(
Expand Down
Loading

0 comments on commit 6204d61

Please sign in to comment.