From 49da5425ba7420947fdb848357ab133b1b5ddc9e Mon Sep 17 00:00:00 2001 From: Jonathan Hoyland Date: Mon, 25 Nov 2024 14:53:36 +0000 Subject: [PATCH] Make max_batch_size `Option`al --- crates/dapf/src/acceptance/mod.rs | 3 +- crates/daphne-server/tests/e2e/test_runner.rs | 6 +-- .../src/test_route_types.rs | 3 +- crates/daphne/src/lib.rs | 6 +-- crates/daphne/src/messages/taskprov.rs | 39 +++++++++++++++---- crates/daphne/src/roles/mod.rs | 12 ++++-- crates/daphne/src/taskprov.rs | 27 ++++++------- 7 files changed, 64 insertions(+), 32 deletions(-) diff --git a/crates/dapf/src/acceptance/mod.rs b/crates/dapf/src/acceptance/mod.rs index 49df4db3a..0a46706b1 100644 --- a/crates/dapf/src/acceptance/mod.rs +++ b/crates/dapf/src/acceptance/mod.rs @@ -41,6 +41,7 @@ use rand::{rngs, Rng}; use std::{ convert::TryFrom, env, + num::NonZeroU32, ops::Range, path::PathBuf, sync::atomic::{AtomicUsize, Ordering}, @@ -383,7 +384,7 @@ impl Test { lifetime: 60, min_batch_size: reports_per_batch.try_into().unwrap(), query: DapQueryConfig::FixedSize { - max_batch_size: Some(reports_per_batch.try_into().unwrap()), + max_batch_size: NonZeroU32::new(reports_per_batch.try_into().unwrap()), }, vdaf: self.vdaf_config, ..Default::default() diff --git a/crates/daphne-server/tests/e2e/test_runner.rs b/crates/daphne-server/tests/e2e/test_runner.rs index 1fabcde14..aeb6bad2c 100644 --- a/crates/daphne-server/tests/e2e/test_runner.rs +++ b/crates/daphne-server/tests/e2e/test_runner.rs @@ -23,7 +23,7 @@ use serde_json::json; use std::time::SystemTime; use std::{ any::{self, Any}, - num::NonZeroUsize, + num::{NonZeroU32, NonZeroUsize}, ops::Range, }; use tokio::time::timeout; @@ -31,7 +31,7 @@ use url::Url; const VDAF_CONFIG: &VdafConfig = &VdafConfig::Prio3(Prio3Config::Sum { bits: 10 }); pub(crate) const MIN_BATCH_SIZE: u64 = 10; -pub(crate) const MAX_BATCH_SIZE: u64 = 12; +pub(crate) const MAX_BATCH_SIZE: u32 = 12; pub(crate) const TIME_PRECISION: Duration = 3600; // seconds #[derive(Deserialize)] @@ -66,7 +66,7 @@ impl TestRunner { Self::with( version, &DapQueryConfig::FixedSize { - max_batch_size: Some(MAX_BATCH_SIZE), + max_batch_size: Some(NonZeroU32::new(MAX_BATCH_SIZE).unwrap()), }, ) .await diff --git a/crates/daphne-service-utils/src/test_route_types.rs b/crates/daphne-service-utils/src/test_route_types.rs index 659118899..904a9296f 100644 --- a/crates/daphne-service-utils/src/test_route_types.rs +++ b/crates/daphne-service-utils/src/test_route_types.rs @@ -10,6 +10,7 @@ use daphne::{ vdaf::{Prio3Config, VdafConfig}, }; use serde::{Deserialize, Serialize}; +use std::num::NonZeroU32; use url::Url; #[derive(Deserialize)] @@ -86,7 +87,7 @@ pub struct InternalTestAddTask { pub query_type: u8, pub min_batch_size: u64, #[serde(skip_serializing_if = "Option::is_none")] - pub max_batch_size: Option, + pub max_batch_size: Option, pub time_precision: Duration, pub collector_hpke_config: String, // base64url pub task_expiration: Time, diff --git a/crates/daphne/src/lib.rs b/crates/daphne/src/lib.rs index 2ccd22733..e081434f1 100644 --- a/crates/daphne/src/lib.rs +++ b/crates/daphne/src/lib.rs @@ -83,7 +83,7 @@ use std::{ cmp::{max, min}, collections::{HashMap, HashSet}, fmt::Debug, - num::NonZeroUsize, + num::{NonZeroU32, NonZeroUsize}, str::FromStr, }; use url::Url; @@ -237,7 +237,7 @@ pub enum DapQueryConfig { /// Aggregators are meant to stop aggregating reports when this limit is reached. FixedSize { #[serde(default)] - max_batch_size: Option, + max_batch_size: Option, }, } @@ -730,7 +730,7 @@ impl DapTaskConfig { DapQueryConfig::FixedSize { max_batch_size: Some(max_batch_size), } => { - if report_count > max_batch_size { + if report_count > u64::from(max_batch_size.get()) { return Err(DapAbort::InvalidBatchSize { detail: format!( "Report count ({report_count}) exceeds maximum ({max_batch_size})" diff --git a/crates/daphne/src/messages/taskprov.rs b/crates/daphne/src/messages/taskprov.rs index 8d26cd21f..d833cbe50 100644 --- a/crates/daphne/src/messages/taskprov.rs +++ b/crates/daphne/src/messages/taskprov.rs @@ -15,6 +15,7 @@ use prio::codec::{ }; use serde::{Deserialize, Serialize}; use std::io::{Cursor, Read}; +use std::num::NonZeroU32; use super::{ decode_base64url_vec, decode_u16_prefixed, encode_base64url, encode_u16_prefixed, TaskId, @@ -318,7 +319,7 @@ impl Decode for UrlBytes { #[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)] pub enum QueryConfigVar { TimeInterval, - FixedSize { max_batch_size: u32 }, + FixedSize { max_batch_size: Option }, NotImplemented { typ: u8, param: Vec }, } @@ -363,7 +364,13 @@ impl ParameterizedEncode for QueryConfig { match &self.var { QueryConfigVar::TimeInterval => (), QueryConfigVar::FixedSize { max_batch_size } => { - max_batch_size.encode(bytes)?; + match version { + DapVersion::Draft09 => match max_batch_size { + Some(x) => x.get().encode(bytes)?, + None => 0u32.encode(bytes)?, + }, + DapVersion::Latest => (), + }; } QueryConfigVar::NotImplemented { typ: _, param } => { bytes.extend_from_slice(param); @@ -393,7 +400,10 @@ impl ParameterizedDecode<(DapVersion, Option)> for QueryConfig { match (bytes_left, query_type) { (.., QUERY_TYPE_TIME_INTERVAL) => QueryConfigVar::TimeInterval, (.., QUERY_TYPE_FIXED_SIZE) => QueryConfigVar::FixedSize { - max_batch_size: u32::decode(bytes)?, + max_batch_size: match version { + DapVersion::Draft09 => NonZeroU32::new(u32::decode(bytes)?), + DapVersion::Latest => None, + }, }, (Some(bytes_left), ..) => { let mut param = vec![0; bytes_left - fixed_size]; @@ -549,7 +559,12 @@ mod tests { DapVersion::Latest => 1, }, min_batch_size: 55, - var: QueryConfigVar::FixedSize { max_batch_size: 57 }, + var: QueryConfigVar::FixedSize { + max_batch_size: match version { + DapVersion::Draft09 => Some(NonZeroU32::new(57).unwrap()), + DapVersion::Latest => None, + }, + }, }, task_expiration: 23_232_232_232, vdaf_config: VdafConfig { @@ -576,8 +591,8 @@ mod tests { 101, 46, 99, 111, 109, 47, 118, 48, 50, 0, 42, 104, 116, 116, 112, 115, 58, 47, 47, 115, 111, 109, 101, 115, 101, 114, 118, 105, 99, 101, 46, 99, 108, 111, 117, 100, 102, 108, 97, 114, 101, 114, 101, 115, 101, 97, 114, 99, 104, 46, 99, 111, 109, 0, - 17, 0, 0, 0, 0, 0, 188, 79, 242, 0, 0, 0, 55, 2, 0, 0, 0, 57, 0, 0, 0, 5, 104, 191, - 187, 40, 0, 11, 0, 1, 1, 255, 255, 0, 0, 0, 1, 134, 159, + 13, 0, 0, 0, 0, 0, 188, 79, 242, 0, 0, 0, 55, 2, 0, 0, 0, 5, 104, 191, 187, 40, 0, + 11, 0, 1, 1, 255, 255, 0, 0, 0, 1, 134, 159, ] .as_slice(), }; @@ -613,7 +628,10 @@ mod tests { }, min_batch_size: 12_345_678, var: QueryConfigVar::FixedSize { - max_batch_size: 12_345_678, + max_batch_size: match version { + DapVersion::Draft09 => Some(NonZeroU32::new(12_345_678).unwrap()), + DapVersion::Latest => None, + }, }, }; let encoded = query_config.get_encoded_with_param(&version).unwrap(); @@ -783,7 +801,12 @@ mod tests { DapVersion::Latest => 1, }, min_batch_size: 55, - var: QueryConfigVar::FixedSize { max_batch_size: 57 }, + var: QueryConfigVar::FixedSize { + max_batch_size: match version { + DapVersion::Draft09 => Some(NonZeroU32::new(57).unwrap()), + DapVersion::Latest => None, + }, + }, }, task_expiration: 23_232_232_232, vdaf_config: VdafConfig { diff --git a/crates/daphne/src/roles/mod.rs b/crates/daphne/src/roles/mod.rs index 3afc2ca1b..073f95188 100644 --- a/crates/daphne/src/roles/mod.rs +++ b/crates/daphne/src/roles/mod.rs @@ -153,7 +153,13 @@ mod test { #[cfg(feature = "experimental")] use prio::{idpf::IdpfInput, vdaf::poplar1::Poplar1AggregationParam}; use rand::{thread_rng, Rng}; - use std::{collections::HashMap, num::NonZeroUsize, sync::Arc, time::SystemTime, vec}; + use std::{ + collections::HashMap, + num::{NonZeroU32, NonZeroUsize}, + sync::Arc, + time::SystemTime, + vec, + }; use url::Url; pub(super) struct TestData { @@ -235,7 +241,7 @@ mod test { not_after: now + Self::TASK_TIME_PRECISION, min_batch_size: 1, query: DapQueryConfig::FixedSize { - max_batch_size: Some(2), + max_batch_size: Some(NonZeroU32::new(2).unwrap()), }, vdaf: vdaf_config, vdaf_verify_key: vdaf_config.gen_verify_key(), @@ -1416,7 +1422,7 @@ mod test { version, min_batch_size: 1, query: DapQueryConfig::FixedSize { - max_batch_size: Some(2), + max_batch_size: Some(NonZeroU32::new(2).unwrap()), }, vdaf: vdaf_config, ..Default::default() diff --git a/crates/daphne/src/taskprov.rs b/crates/daphne/src/taskprov.rs index bb8cb4b11..5d70cf38c 100644 --- a/crates/daphne/src/taskprov.rs +++ b/crates/daphne/src/taskprov.rs @@ -122,12 +122,9 @@ fn url_from_bytes(task_id: &TaskId, url_bytes: &[u8]) -> Result { impl DapQueryConfig { fn try_from_taskprov(task_id: &TaskId, var: QueryConfigVar) -> Result { match var { - QueryConfigVar::FixedSize { max_batch_size: 0 } => Ok(DapQueryConfig::FixedSize { - max_batch_size: None, - }), - QueryConfigVar::FixedSize { max_batch_size } => Ok(DapQueryConfig::FixedSize { - max_batch_size: Some(max_batch_size.into()), - }), + QueryConfigVar::FixedSize { max_batch_size } => { + Ok(DapQueryConfig::FixedSize { max_batch_size }) + } QueryConfigVar::TimeInterval => Ok(DapQueryConfig::TimeInterval), QueryConfigVar::NotImplemented { typ, .. } => Err(DapAbort::InvalidTask { detail: format!("unimplemented query type ({typ})"), @@ -359,9 +356,7 @@ impl TryFrom<&DapQueryConfig> for messages::taskprov::QueryConfigVar { DapQueryConfig::TimeInterval => messages::taskprov::QueryConfigVar::TimeInterval, DapQueryConfig::FixedSize { max_batch_size } => { messages::taskprov::QueryConfigVar::FixedSize { - max_batch_size: max_batch_size.unwrap_or(0).try_into().map_err(|_| { - fatal_error!(err = "task max batch size is too large for taskprov") - })?, + max_batch_size: *max_batch_size, } } }) @@ -448,7 +443,7 @@ impl TryFrom<&DapTaskConfig> for messages::taskprov::TaskprovAdvertisement { #[cfg(test)] mod test { - use std::num::NonZeroUsize; + use std::num::{NonZeroU32, NonZeroUsize}; use prio::codec::ParameterizedEncode; @@ -477,7 +472,9 @@ mod test { time_precision: 3600, max_batch_query_count: 1, min_batch_size: 1, - var: messages::taskprov::QueryConfigVar::FixedSize { max_batch_size: 2 }, + var: messages::taskprov::QueryConfigVar::FixedSize { + max_batch_size: Some(NonZeroU32::new(2).unwrap()), + }, }, task_expiration: 1337, vdaf_config: messages::taskprov::VdafConfig { @@ -557,7 +554,9 @@ mod test { time_precision: 3600, max_batch_query_count: 1, min_batch_size: 1, - var: messages::taskprov::QueryConfigVar::FixedSize { max_batch_size: 2 }, + var: messages::taskprov::QueryConfigVar::FixedSize { + max_batch_size: Some(NonZeroU32::new(2).unwrap()), + }, }, task_expiration: 0, vdaf_config: messages::taskprov::VdafConfig { @@ -622,7 +621,9 @@ mod test { time_precision: 3600, max_batch_query_count: 1, min_batch_size: 1, - var: messages::taskprov::QueryConfigVar::FixedSize { max_batch_size: 2 }, + var: messages::taskprov::QueryConfigVar::FixedSize { + max_batch_size: Some(NonZeroU32::new(2).unwrap()), + }, }, task_expiration: 0, vdaf_config: messages::taskprov::VdafConfig {