Option for partitioned trace commitment #336

Merged 14 commits on Oct 24, 2024
72 changes: 69 additions & 3 deletions air/src/options.rs
@@ -4,9 +4,10 @@
// LICENSE file in the root directory of this source tree.

use alloc::vec::Vec;
use core::{cmp, ops::Div};

use fri::FriOptions;
use math::{StarkField, ToElements};
use math::{FieldElement, StarkField, ToElements};
use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

// CONSTANTS
@@ -74,6 +75,16 @@ pub enum FieldExtension {
/// is the hash function used in the protocol. The soundness of a STARK proof is limited by the
/// collision resistance of the hash function used by the protocol. For example, if a hash function
/// with 128-bit collision resistance is used, soundness of a STARK proof cannot exceed 128 bits.
///
/// In addition to the above, the parameter `num_partitions` specifies the number of partitions
/// into which each trace committed to during proof generation is split.
/// More precisely, taking the main segment trace as an example, the prover splits the main
/// segment trace into `num_partitions` parts and then hashes each part row-wise, resulting in
/// `num_partitions` digests per row of the trace. The prover finally combines the
/// `num_partitions` digests (per row) into a single digest (per row), at which point the vector
/// commitment scheme can be applied.
/// When `num_partitions` is equal to `1`, the prover simply hashes each row in one go,
/// producing one digest per row of the trace.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ProofOptions {
num_queries: u8,
Expand All @@ -82,6 +93,8 @@ pub struct ProofOptions {
field_extension: FieldExtension,
fri_folding_factor: u8,
fri_remainder_max_degree: u8,
num_partitions: u8,
min_partition_size: u8,
}
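For orientation, the partitioning scheme described in the doc comment above can be summarized with a short sketch. This is not code from this PR: the helper name, crate paths, and trait bounds are assumptions based on the `ElementHasher`/`Hasher` traits and the new `merge_many` method appearing elsewhere in this diff.

```rust
use winter_crypto::{ElementHasher, Hasher};
use winter_math::FieldElement;

// Commit to a single trace row under partitioning (illustrative sketch only).
fn hash_row_partitioned<E, H>(row: &[E], partition_size: usize) -> H::Digest
where
    E: FieldElement,
    H: ElementHasher<BaseField = E::BaseField>,
{
    assert!(partition_size > 0, "partition size must be greater than 0");

    // Hash each partition of the row separately: one digest per partition.
    let partition_digests: Vec<H::Digest> = row
        .chunks(partition_size)
        .map(|chunk| H::hash_elements(chunk))
        .collect();

    if partition_digests.len() == 1 {
        // A single partition means the whole row was hashed in one go, so the
        // resulting digest is the row digest (the `num_partitions == 1` case).
        partition_digests.into_iter().next().unwrap()
    } else {
        // Otherwise, combine the per-partition digests into one digest per row;
        // the vector commitment is then built over these row digests.
        H::merge_many(&partition_digests)
    }
}
```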

// PROOF OPTIONS IMPLEMENTATION
@@ -108,14 +121,47 @@ impl ProofOptions {
/// - `grinding_factor` is greater than 32.
/// - `fri_folding_factor` is not 2, 4, 8, or 16.
/// - `fri_remainder_max_degree` is greater than 255 or is not a power of two minus 1.
#[rustfmt::skip]
pub const fn new(
num_queries: usize,
blowup_factor: usize,
grinding_factor: u32,
field_extension: FieldExtension,
fri_folding_factor: usize,
fri_remainder_max_degree: usize,
) -> ProofOptions {
Self::with_num_partitions(
num_queries,
blowup_factor,
grinding_factor,
field_extension,
fri_folding_factor,
fri_remainder_max_degree,
1,
1,
)
}

/// Returns a new instance of [ProofOptions] struct constructed from the specified parameters.
///
/// # Panics
/// Panics if:
/// - `num_queries` is zero or greater than 255.
/// - `blowup_factor` is smaller than 2, greater than 128, or is not a power of two.
/// - `grinding_factor` is greater than 32.
/// - `fri_folding_factor` is not 2, 4, 8, or 16.
/// - `fri_remainder_max_degree` is greater than 255 or is not a power of two minus 1.
/// - `num_partitions` is zero or greater than 16.
#[rustfmt::skip]
#[allow(clippy::too_many_arguments)]
pub const fn with_num_partitions(
num_queries: usize,
blowup_factor: usize,
grinding_factor: u32,
field_extension: FieldExtension,
fri_folding_factor: usize,
fri_remainder_max_degree: usize,
num_partitions: usize,
min_partition_size: usize,
) -> ProofOptions {
// TODO: return errors instead of panicking
assert!(num_queries > 0, "number of queries must be greater than 0");
Expand All @@ -140,13 +186,18 @@ impl ProofOptions {
"FRI polynomial remainder degree cannot be greater than 255"
);

assert!(num_partitions > 0, "number of partitions must be greater than 0");
assert!(num_partitions <= 16, "number of partitions cannot be greater than 16");

ProofOptions {
num_queries: num_queries as u8,
blowup_factor: blowup_factor as u8,
grinding_factor: grinding_factor as u8,
field_extension,
fri_folding_factor: fri_folding_factor as u8,
fri_remainder_max_degree: fri_remainder_max_degree as u8,
num_partitions: num_partitions as u8,
min_partition_size: min_partition_size as u8,
}
}

@@ -206,6 +257,17 @@ impl ProofOptions {
let remainder_max_degree = self.fri_remainder_max_degree as usize;
FriOptions::new(self.blowup_factor(), folding_factor, remainder_max_degree)
}

/// Returns the size of each partition used when committing to the main and auxiliary traces as
/// well as the constraint evaluation trace.
pub fn partition_size<E: FieldElement>(&self, num_columns: usize) -> usize {
let base_elements_per_partition = cmp::max(
(num_columns * E::EXTENSION_DEGREE).div_ceil(self.num_partitions as usize),
self.min_partition_size as usize,
);

base_elements_per_partition.div(E::EXTENSION_DEGREE)
}
}
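As a sanity check on the arithmetic in `partition_size` above, the following is a standalone re-derivation with concrete numbers; it is not library code, and the example values are made up for illustration.

```rust
// Mirrors the computation in `ProofOptions::partition_size` (illustrative only).
fn partition_size(
    num_columns: usize,
    extension_degree: usize,
    num_partitions: usize,
    min_partition_size: usize,
) -> usize {
    // Total base-field elements per row of the (possibly extension-field) trace.
    let base_elements_per_row = num_columns * extension_degree;

    // Split evenly across partitions, but never go below the configured minimum.
    let base_elements_per_partition =
        core::cmp::max(base_elements_per_row.div_ceil(num_partitions), min_partition_size);

    // Convert back to the number of (extension-field) columns per partition.
    base_elements_per_partition / extension_degree
}

fn main() {
    // 20 quadratic-extension columns split into 4 partitions with a minimum of
    // 8 base elements per partition: ceil(40 / 4) = 10, then 10 / 2 = 5 columns.
    assert_eq!(partition_size(20, 2, 4, 8), 5);

    // With a larger minimum, the minimum wins: max(10, 16) = 16, then 16 / 2 = 8.
    assert_eq!(partition_size(20, 2, 4, 16), 8);
}
```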

impl<E: StarkField> ToElements<E> for ProofOptions {
@@ -233,6 +295,8 @@ impl Serializable for ProofOptions {
target.write(self.field_extension);
target.write_u8(self.fri_folding_factor);
target.write_u8(self.fri_remainder_max_degree);
target.write_u8(self.num_partitions);
target.write_u8(self.min_partition_size);
}
}

@@ -242,13 +306,15 @@ impl Deserializable for ProofOptions {
/// # Errors
/// Returns an error if valid proof options could not be read from the specified `source`.
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
Ok(ProofOptions::new(
Ok(ProofOptions::with_num_partitions(
source.read_u8()? as usize,
source.read_u8()? as usize,
source.read_u8()? as u32,
FieldExtension::read_from(source)?,
source.read_u8()? as usize,
source.read_u8()? as usize,
source.read_u8()? as usize,
source.read_u8()? as usize,
))
}
}
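Since two new fields now ride along in the serialized form, a quick round-trip sketch may help confirm that the read order matches the write order. This is not part of the PR; the crate paths (`winter_air`, `winter_utils`) and parameter values are assumptions for illustration.

```rust
use winter_air::{FieldExtension, ProofOptions};
use winter_utils::{Deserializable, Serializable, SliceReader};

fn options_roundtrip() {
    // 4 partitions with a minimum partition size of 8 base elements.
    let options = ProofOptions::with_num_partitions(
        42, 8, 16, FieldExtension::Quadratic, 8, 255, 4, 8,
    );

    // `num_partitions` and `min_partition_size` are written last and read last.
    let bytes = options.to_bytes();
    let mut reader = SliceReader::new(&bytes);
    let parsed = ProofOptions::read_from(&mut reader).unwrap();

    assert_eq!(options, parsed);
}
```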
9 changes: 9 additions & 0 deletions crypto/src/hash/blake/mod.rs
@@ -34,6 +34,10 @@ impl<B: StarkField> Hasher for Blake3_256<B> {
ByteDigest(blake3::hash(ByteDigest::digests_as_bytes(values)).into())
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
ByteDigest(blake3::hash(ByteDigest::digests_as_bytes(values)).into())
}

fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
let mut data = [0; 40];
data[..32].copy_from_slice(&seed.0);
@@ -84,6 +88,11 @@ impl<B: StarkField> Hasher for Blake3_192<B> {
ByteDigest(result.as_bytes()[..24].try_into().unwrap())
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
let result = blake3::hash(ByteDigest::digests_as_bytes(values));
ByteDigest(result.as_bytes()[..24].try_into().unwrap())
}

fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
let mut data = [0; 32];
data[..24].copy_from_slice(&seed.0);
24 changes: 24 additions & 0 deletions crypto/src/hash/blake/tests.rs
@@ -5,8 +5,10 @@

use math::{fields::f62::BaseElement, FieldElement};
use rand_utils::rand_array;
use utils::Deserializable;

use super::{Blake3_256, ElementHasher, Hasher};
use crate::hash::{Blake3_192, ByteDigest};

#[test]
fn hash_padding() {
@@ -29,3 +31,25 @@ fn hash_elements_padding() {
let r2 = Blake3_256::hash_elements(&e2);
assert_ne!(r1, r2);
}

#[test]
fn merge_vs_merge_many_256() {
let digest_0 = ByteDigest::read_from_bytes(&[1_u8; 32]).unwrap();
let digest_1 = ByteDigest::read_from_bytes(&[2_u8; 32]).unwrap();

let r1 = Blake3_256::<BaseElement>::merge(&[digest_0, digest_1]);
let r2 = Blake3_256::<BaseElement>::merge_many(&[digest_0, digest_1]);

assert_eq!(r1, r2)
}

#[test]
fn merge_vs_merge_many_192() {
let digest_0 = ByteDigest::read_from_bytes(&[1_u8; 24]).unwrap();
let digest_1 = ByteDigest::read_from_bytes(&[2_u8; 24]).unwrap();

let r1 = Blake3_192::<BaseElement>::merge(&[digest_0, digest_1]);
let r2 = Blake3_192::<BaseElement>::merge_many(&[digest_0, digest_1]);

assert_eq!(r1, r2)
}
3 changes: 3 additions & 0 deletions crypto/src/hash/mod.rs
@@ -42,6 +42,9 @@ pub trait Hasher {
/// Merkle trees.
fn merge(values: &[Self::Digest; 2]) -> Self::Digest;

/// Returns a hash of many digests.
fn merge_many(values: &[Self::Digest]) -> Self::Digest;

/// Returns hash(`seed` || `value`). This method is intended for use in PRNG and PoW contexts.
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest;
}
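For a quick sense of how the new `merge_many` is meant to be used, here is a small usage sketch, e.g. combining per-partition digests of a trace row into a single row digest. The crate paths, the choice of hasher and field, and the helper are assumptions for illustration; the two-digest equality asserted at the end is the property checked by the `merge_vs_merge_many` tests added in this diff.

```rust
use winter_crypto::{hashers::Blake3_256, ElementHasher, Hasher};
use winter_math::fields::f62::BaseElement;

type H = Blake3_256<BaseElement>;

fn combine_partition_digests() {
    // Pretend these are the per-partition digests of a single trace row.
    let d0 = H::hash_elements(&[BaseElement::new(1), BaseElement::new(2)]);
    let d1 = H::hash_elements(&[BaseElement::new(3), BaseElement::new(4)]);
    let d2 = H::hash_elements(&[BaseElement::new(5), BaseElement::new(6)]);

    // One call combines any number of partition digests into a row digest.
    let _row_digest = H::merge_many(&[d0, d1, d2]);

    // For exactly two digests, the byte-oriented hashers in this diff produce
    // the same result as the existing two-input `merge` (see the added tests).
    assert_eq!(H::merge(&[d0, d1]), H::merge_many(&[d0, d1]));
}
```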
4 changes: 4 additions & 0 deletions crypto/src/hash/rescue/rp62_248/mod.rs
@@ -165,6 +165,10 @@ impl Hasher for Rp62_248 {
ElementDigest::new(state[..DIGEST_SIZE].try_into().unwrap())
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
Self::hash_elements(ElementDigest::digests_as_elements(values))
}

fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
// initialize the state as follows:
// - seed is copied into the first 4 elements of the state.
14 changes: 14 additions & 0 deletions crypto/src/hash/rescue/rp62_248/tests.rs
@@ -83,6 +83,20 @@ fn hash_elements_vs_merge() {
assert_eq!(m_result, h_result);
}

#[test]
fn merge_vs_merge_many() {
let elements: [BaseElement; 8] = rand_array();

let digests: [ElementDigest; 2] = [
ElementDigest::new(elements[..4].try_into().unwrap()),
ElementDigest::new(elements[4..].try_into().unwrap()),
];

let m_result = Rp62_248::merge(&digests);
let h_result = Rp62_248::merge_many(&digests);
assert_eq!(m_result, h_result);
}

#[test]
fn hash_elements_vs_merge_with_int() {
let seed = ElementDigest::new(rand_array());
4 changes: 4 additions & 0 deletions crypto/src/hash/rescue/rp64_256/mod.rs
@@ -191,6 +191,10 @@ impl Hasher for Rp64_256 {
ElementDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
Self::hash_elements(ElementDigest::digests_as_elements(values))
}

fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
// initialize the state as follows:
// - seed is copied into the first 4 elements of the rate portion of the state.
14 changes: 14 additions & 0 deletions crypto/src/hash/rescue/rp64_256/tests.rs
@@ -118,6 +118,20 @@ fn hash_elements_vs_merge() {
assert_eq!(m_result, h_result);
}

#[test]
fn merge_vs_merge_many() {
let elements: [BaseElement; 8] = rand_array();

let digests: [ElementDigest; 2] = [
ElementDigest::new(elements[..4].try_into().unwrap()),
ElementDigest::new(elements[4..].try_into().unwrap()),
];

let m_result = Rp64_256::merge(&digests);
let h_result = Rp64_256::merge_many(&digests);
assert_eq!(m_result, h_result);
}

#[test]
fn hash_elements_vs_merge_with_int() {
let seed = ElementDigest::new(rand_array());
4 changes: 4 additions & 0 deletions crypto/src/hash/rescue/rp64_256_jive/mod.rs
@@ -195,6 +195,10 @@ impl Hasher for RpJive64_256 {
Self::apply_jive_summation(&initial_state, &state)
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
Self::hash_elements(ElementDigest::digests_as_elements(values))
}

// We do not rely on the sponge construction to build our compression function. Instead, we use
// the Jive compression mode designed in https://eprint.iacr.org/2022/840.pdf.
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
4 changes: 4 additions & 0 deletions crypto/src/hash/sha/mod.rs
@@ -31,6 +31,10 @@ impl<B: StarkField> Hasher for Sha3_256<B> {
ByteDigest(sha3::Sha3_256::digest(ByteDigest::digests_as_bytes(values)).into())
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
ByteDigest(sha3::Sha3_256::digest(ByteDigest::digests_as_bytes(values)).into())
}

fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
let mut data = [0; 40];
data[..32].copy_from_slice(&seed.0);
3 changes: 2 additions & 1 deletion examples/src/fibonacci/fib2/prover.rs
@@ -77,8 +77,9 @@ where
trace_info: &TraceInfo,
main_trace: &ColMatrix<Self::BaseField>,
domain: &StarkDomain<Self::BaseField>,
num_partitions: usize,
) -> (Self::TraceLde<E>, TracePolyTable<E>) {
DefaultTraceLde::new(trace_info, main_trace, domain)
DefaultTraceLde::new(trace_info, main_trace, domain, num_partitions)
}

fn new_evaluator<'a, E: FieldElement<BaseField = Self::BaseField>>(
3 changes: 2 additions & 1 deletion examples/src/fibonacci/fib8/prover.rs
@@ -92,8 +92,9 @@ where
trace_info: &TraceInfo,
main_trace: &ColMatrix<Self::BaseField>,
domain: &StarkDomain<Self::BaseField>,
num_partitions: usize,
) -> (Self::TraceLde<E>, TracePolyTable<E>) {
DefaultTraceLde::new(trace_info, main_trace, domain)
DefaultTraceLde::new(trace_info, main_trace, domain, num_partitions)
}

fn new_evaluator<'a, E: FieldElement<BaseField = Self::BaseField>>(
3 changes: 2 additions & 1 deletion examples/src/fibonacci/fib_small/prover.rs
@@ -82,8 +82,9 @@ where
trace_info: &TraceInfo,
main_trace: &ColMatrix<Self::BaseField>,
domain: &StarkDomain<Self::BaseField>,
num_partitions: usize,
) -> (Self::TraceLde<E>, TracePolyTable<E>) {
DefaultTraceLde::new(trace_info, main_trace, domain)
DefaultTraceLde::new(trace_info, main_trace, domain, num_partitions)
}

fn new_evaluator<'a, E: FieldElement<BaseField = Self::BaseField>>(
3 changes: 2 additions & 1 deletion examples/src/fibonacci/mulfib2/prover.rs
@@ -73,8 +73,9 @@ where
trace_info: &TraceInfo,
main_trace: &ColMatrix<Self::BaseField>,
domain: &StarkDomain<Self::BaseField>,
num_partitions: usize,
) -> (Self::TraceLde<E>, TracePolyTable<E>) {
DefaultTraceLde::new(trace_info, main_trace, domain)
DefaultTraceLde::new(trace_info, main_trace, domain, num_partitions)
}

fn new_evaluator<'a, E: FieldElement<BaseField = Self::BaseField>>(
3 changes: 2 additions & 1 deletion examples/src/fibonacci/mulfib8/prover.rs
@@ -85,8 +85,9 @@ where
trace_info: &TraceInfo,
main_trace: &ColMatrix<Self::BaseField>,
domain: &StarkDomain<Self::BaseField>,
num_partitions: usize,
) -> (Self::TraceLde<E>, TracePolyTable<E>) {
DefaultTraceLde::new(trace_info, main_trace, domain)
DefaultTraceLde::new(trace_info, main_trace, domain, num_partitions)
}

fn new_evaluator<'a, E: FieldElement<BaseField = Self::BaseField>>(
3 changes: 2 additions & 1 deletion examples/src/lamport/aggregate/prover.rs
@@ -121,8 +121,9 @@ where
trace_info: &TraceInfo,
main_trace: &ColMatrix<Self::BaseField>,
domain: &StarkDomain<Self::BaseField>,
num_partitions: usize,
) -> (Self::TraceLde<E>, TracePolyTable<E>) {
DefaultTraceLde::new(trace_info, main_trace, domain)
DefaultTraceLde::new(trace_info, main_trace, domain, num_partitions)
}

fn new_evaluator<'a, E: FieldElement<BaseField = Self::BaseField>>(
3 changes: 2 additions & 1 deletion examples/src/lamport/threshold/prover.rs
@@ -163,8 +163,9 @@ where
trace_info: &TraceInfo,
main_trace: &ColMatrix<Self::BaseField>,
domain: &StarkDomain<Self::BaseField>,
num_partitions: usize,
) -> (Self::TraceLde<E>, TracePolyTable<E>) {
DefaultTraceLde::new(trace_info, main_trace, domain)
DefaultTraceLde::new(trace_info, main_trace, domain, num_partitions)
}

fn new_evaluator<'a, E: FieldElement<BaseField = Self::BaseField>>(
3 changes: 2 additions & 1 deletion examples/src/merkle/prover.rs
@@ -128,8 +128,9 @@ where
trace_info: &TraceInfo,
main_trace: &ColMatrix<Self::BaseField>,
domain: &StarkDomain<Self::BaseField>,
num_partitions: usize,
) -> (Self::TraceLde<E>, TracePolyTable<E>) {
DefaultTraceLde::new(trace_info, main_trace, domain)
DefaultTraceLde::new(trace_info, main_trace, domain, num_partitions)
}

fn new_evaluator<'a, E: FieldElement<BaseField = Self::BaseField>>(