
use number of partitions for buffer allocation instead of partition size #339

Closed · wants to merge 2 commits
6 changes: 5 additions & 1 deletion air/src/options.rs
@@ -383,7 +383,11 @@ impl PartitionOptions {
self.min_partition_size as usize,
);

base_elements_per_partition.div(E::EXTENSION_DEGREE)
base_elements_per_partition.div_ceil(E::EXTENSION_DEGREE)
Collaborator:
We may need to make this a bit more sophisticated because I think it will produce rather suboptimal results in some situations. For example:

  • num_columns = 7
  • E::EXTENSION_DEGREE = 3
  • self.num_partitions = 4
  • self.min_partition_size = 8

If I did my math correctly, with this approach, we'd get a partition size of 3 columns, which would imply 9 base field columns per partition. This would require 2 permutations of the hash function in each partition.

The previous approach would have actually resulted in a better outcome here (i.e., partition size 2, so the 4 partitions would have 2, 2, 2, 1 columns). But this result would have been technically incorrect because we'd have 6 base field elements per partition, which would be smaller than min_partition_size.

Maybe instead of min_partition_size we should be specifying the number of base elements that can be absorbed per permutation and then we can adjust this algorithm to output more optimal results.
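The arithmetic in this example can be checked with a small standalone sketch. The helpers below are illustrative only (they are not the crate's API) and mirror the two rounding strategies under discussion: the `div_ceil`-based size from this PR versus the previous `div`-based size.

```rust
// Illustrative reconstruction of the two rounding strategies; not the winterfell API.

/// Partition size (in extension-field columns) using round-up, as in this PR.
fn partition_size_ceil(
    num_columns: usize,
    ext_degree: usize,
    num_partitions: usize,
    min_partition_size: usize,
) -> usize {
    let base = (num_columns * ext_degree)
        .div_ceil(num_partitions)
        .max(min_partition_size);
    // rounding up can overshoot: 8 base elements -> 3 columns -> 9 base elements
    base.div_ceil(ext_degree)
}

/// Partition size using round-down, as before this PR.
fn partition_size_floor(
    num_columns: usize,
    ext_degree: usize,
    num_partitions: usize,
    min_partition_size: usize,
) -> usize {
    let base = (num_columns * ext_degree)
        .div_ceil(num_partitions)
        .max(min_partition_size);
    // rounding down can violate min_partition_size: 2 columns -> 6 base elements
    base / ext_degree
}

fn main() {
    // The reviewer's example: 7 columns, degree-3 extension, 4 partitions, min size 8.
    assert_eq!(partition_size_ceil(7, 3, 4, 8), 3);
    assert_eq!(partition_size_floor(7, 3, 4, 8), 2);
}
```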

}

pub fn num_partitons(&self) -> usize {
self.num_partitions as usize
}
}

7 changes: 2 additions & 5 deletions prover/src/lib.rs
@@ -555,11 +555,8 @@ pub trait Prover {
log_domain_size = domain_size.ilog2()
)
.in_scope(|| {
let commitment = composed_evaluations.commit_to_rows::<Self::HashFn, Self::VC>(
self.options()
.partition_options()
.partition_size::<E>(num_constraint_composition_columns),
);
let commitment = composed_evaluations
.commit_to_rows::<Self::HashFn, Self::VC>(self.options().partition_options());
ConstraintCommitment::new(composed_evaluations, commitment)
});

8 changes: 6 additions & 2 deletions prover/src/matrix/row_matrix.rs
@@ -5,6 +5,7 @@

use alloc::vec::Vec;

use air::PartitionOptions;
use crypto::{ElementHasher, VectorCommitment};
use math::{fft, FieldElement, StarkField};
#[cfg(feature = "concurrent")]
@@ -180,14 +181,17 @@ impl<E: FieldElement> RowMatrix<E> {
/// * A vector commitment is computed for the resulting vector using the specified vector
/// commitment scheme.
/// * The resulting vector commitment is returned as the commitment to the entire matrix.
pub fn commit_to_rows<H, V>(&self, partition_size: usize) -> V
pub fn commit_to_rows<H, V>(&self, partition_options: PartitionOptions) -> V
where
H: ElementHasher<BaseField = E::BaseField>,
V: VectorCommitment<H>,
{
// allocate vector to store row hashes
let mut row_hashes = unsafe { uninit_vector::<H::Digest>(self.num_rows()) };

let partition_size = partition_options.partition_size::<E>(self.num_cols());
let num_partitions = partition_options.num_partitons();
Comment on lines +192 to +193
Collaborator:

Question: should this be the "specified" number of partitions or the "implied" number of partitions? For example, let's say we have 7 columns in the degree 2 extension field and the specified number of partitions is 4 with min partition size being 8.

With these parameters, the implied number of partitions is actually 2 (because partition size would be 4 columns and there are 7 columns total). So, would we want to use 2 or 4 for the number of partitions?
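The "implied" count in this example can be reproduced with a short sketch. The helpers below are hypothetical (not the winterfell API): `partition_size` mirrors the rounding in options.rs above, and `implied_num_partitions` is simply the number of chunks `row.chunks(partition_size)` would actually yield.

```rust
// Illustrative only; these helpers are not part of the winterfell API.

/// Partition size in extension-field columns, mirroring the PR's rounding.
fn partition_size(
    num_columns: usize,
    ext_degree: usize,
    num_partitions: usize,
    min_partition_size: usize,
) -> usize {
    let base = (num_columns * ext_degree)
        .div_ceil(num_partitions)
        .max(min_partition_size);
    base.div_ceil(ext_degree)
}

/// The number of chunks `row.chunks(partition_size)` actually produces.
fn implied_num_partitions(num_columns: usize, partition_size: usize) -> usize {
    num_columns.div_ceil(partition_size)
}

fn main() {
    // 7 columns in a degree-2 extension, 4 specified partitions, min size 8.
    let size = partition_size(7, 2, 4, 8);
    assert_eq!(size, 4); // 4 columns per partition
    // Only 2 partitions are actually produced, not the specified 4.
    assert_eq!(implied_num_partitions(7, size), 2);
}
```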


if partition_size == self.num_cols() * E::EXTENSION_DEGREE {
// iterate though matrix rows, hashing each row
batch_iter_mut!(
@@ -205,7 +209,7 @@ impl<E: FieldElement> RowMatrix<E> {
&mut row_hashes,
128, // min batch size
|batch: &mut [H::Digest], batch_offset: usize| {
let mut buffer = vec![H::Digest::default(); partition_size];
let mut buffer = vec![H::Digest::default(); num_partitions];
for (i, row_hash) in batch.iter_mut().enumerate() {
self.row(batch_offset + i)
.chunks(partition_size)
18 changes: 7 additions & 11 deletions prover/src/trace/trace_lde/default/mod.rs
@@ -43,7 +43,7 @@ pub struct DefaultTraceLde<
aux_segment_oracles: Option<V>,
blowup: usize,
trace_info: TraceInfo,
partition_option: PartitionOptions,
partition_options: PartitionOptions,
_h: PhantomData<H>,
}

@@ -64,14 +64,14 @@
trace_info: &TraceInfo,
main_trace: &ColMatrix<E::BaseField>,
domain: &StarkDomain<E::BaseField>,
partition_option: PartitionOptions,
partition_options: PartitionOptions,
) -> (Self, TracePolyTable<E>) {
// extend the main execution trace and build a commitment to the extended trace
let (main_segment_lde, main_segment_vector_com, main_segment_polys) =
build_trace_commitment::<E, E::BaseField, H, V>(
main_trace,
domain,
partition_option.partition_size::<E::BaseField>(main_trace.num_cols()),
Contributor:
Compilation error. Needs to be partition_options.

partition_options.partition_size::<E::BaseField>(main_trace.num_cols()),
);

let trace_poly_table = TracePolyTable::new(main_segment_polys);
@@ -82,7 +82,7 @@
aux_segment_oracles: None,
blowup: domain.trace_to_lde_blowup(),
trace_info: trace_info.clone(),
partition_option,
partition_options,
_h: PhantomData,
};

@@ -148,11 +148,7 @@
) -> (ColMatrix<E>, H::Digest) {
// extend the auxiliary trace segment and build a commitment to the extended trace
let (aux_segment_lde, aux_segment_oracles, aux_segment_polys) =
build_trace_commitment::<E, E, H, Self::VC>(
aux_trace,
domain,
self.partition_option.partition_size::<E>(aux_trace.num_cols()),
);
build_trace_commitment::<E, E, H, Self::VC>(aux_trace, domain, self.partition_options);

// check errors
assert!(
@@ -276,7 +272,7 @@
fn build_trace_commitment<E, F, H, V>(
trace: &ColMatrix<F>,
domain: &StarkDomain<E::BaseField>,
partition_size: usize,
partition_options: PartitionOptions,
) -> (RowMatrix<F>, V, ColMatrix<F>)
where
E: FieldElement,
@@ -306,7 +302,7 @@
// build trace commitment
let commitment_domain_size = trace_lde.num_rows();
let trace_vector_com = info_span!("compute_execution_trace_commitment", commitment_domain_size)
.in_scope(|| trace_lde.commit_to_rows::<H, V>(partition_size));
.in_scope(|| trace_lde.commit_to_rows::<H, V>(partition_options));
assert_eq!(trace_vector_com.domain_len(), commitment_domain_size);

(trace_lde, trace_vector_com, trace_polys)
15 changes: 10 additions & 5 deletions verifier/src/channel.rs
@@ -36,6 +36,7 @@ pub struct VerifierChannel<
constraint_commitment: H::Digest,
constraint_queries: Option<ConstraintQueries<E, H, V>>,
// partition sizes for the rows of main, auxiliary and constraint traces rows
num_partitions: usize,
partition_size_main: usize,
partition_size_aux: usize,
partition_size_constraint: usize,
@@ -120,6 +121,7 @@
.map_err(|err| VerifierError::ProofDeserializationError(err.to_string()))?;

// --- compute the partition size for each trace ------------------------------------------
let num_partitions = partition_options.num_partitons();
let partition_size_main = partition_options
.partition_size::<E::BaseField>(air.context().trace_info().main_trace_width());
let partition_size_aux =
@@ -135,6 +137,7 @@
constraint_commitment,
constraint_queries: Some(constraint_queries),
// num partitions used in commitment
num_partitions,
partition_size_main,
partition_size_aux,
partition_size_constraint,
@@ -211,7 +214,9 @@
let items: Vec<H::Digest> = queries
.main_states
.rows()
.map(|row| hash_row::<H, E::BaseField>(row, self.partition_size_main))
.map(|row| {
hash_row::<H, E::BaseField>(row, self.partition_size_main, self.num_partitions)
})
.collect();

<V as VectorCommitment<H>>::verify_many(
@@ -225,7 +230,7 @@
if let Some(ref aux_states) = queries.aux_states {
let items: Vec<H::Digest> = aux_states
.rows()
.map(|row| hash_row::<H, E>(row, self.partition_size_aux))
.map(|row| hash_row::<H, E>(row, self.partition_size_aux, self.num_partitions))
.collect();

<V as VectorCommitment<H>>::verify_many(
@@ -252,7 +257,7 @@
let items: Vec<H::Digest> = queries
.evaluations
.rows()
.map(|row| hash_row::<H, E>(row, self.partition_size_constraint))
.map(|row| hash_row::<H, E>(row, self.partition_size_constraint, self.num_partitions))
.collect();

<V as VectorCommitment<H>>::verify_many(
@@ -437,15 +442,15 @@
// ================================================================================================

/// Hashes a row of a trace in batches where each batch is of size at most `partition_size`.
fn hash_row<H, E>(row: &[E], partition_size: usize) -> H::Digest
fn hash_row<H, E>(row: &[E], partition_size: usize, num_partitions: usize) -> H::Digest
where
E: FieldElement,
H: ElementHasher<BaseField = E::BaseField>,
{
if partition_size == row.len() * E::EXTENSION_DEGREE {
H::hash_elements(row)
} else {
let mut buffer = vec![H::Digest::default(); partition_size];
let mut buffer = vec![H::Digest::default(); num_partitions];

row.chunks(partition_size)
.zip(buffer.iter_mut())
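The shape of this change can be sketched with a toy version of the partitioned hashing in `hash_row`. A plain `u64` mixing function stands in for `H::Digest` (this is illustrative, not the crate's hashing): the scratch buffer holds one digest per partition, so its length must be the number of chunks that `row.chunks(partition_size)` yields, which is why the allocation switches from `partition_size` to `num_partitions`.

```rust
// Toy stand-in for H::hash_elements / H::Digest; illustrative only.
fn toy_hash(elems: &[u64]) -> u64 {
    elems
        .iter()
        .fold(0u64, |acc, e| acc.wrapping_mul(31).wrapping_add(*e))
}

/// Hashes a row in chunks of at most `partition_size`, then hashes the
/// per-partition digests into a single row digest.
fn hash_row_partitioned(row: &[u64], partition_size: usize, num_partitions: usize) -> u64 {
    // one slot per partition -- length is the chunk count, not partition_size
    let mut buffer = vec![0u64; num_partitions];
    for (chunk, slot) in row.chunks(partition_size).zip(buffer.iter_mut()) {
        *slot = toy_hash(chunk);
    }
    toy_hash(&buffer)
}

fn main() {
    let row = [1u64, 2, 3, 4, 5, 6, 7];
    // 7 elements with partition size 4 -> chunks of 4 and 3 -> 2 partitions
    assert_eq!(row.chunks(4).count(), 2);
    let digest = hash_row_partitioned(&row, 4, 2);
    // deterministic: same inputs, same digest
    assert_eq!(digest, hash_row_partitioned(&row, 4, 2));
}
```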