Skip to content

Commit

Permalink
Allow configuring RNG (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dzejkop authored Feb 29, 2024
1 parent 36b3edc commit cbc96ef
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 14 deletions.
8 changes: 4 additions & 4 deletions bin/e2e/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,11 @@ pub async fn seed_db_sync(
participant_db_sync_queues: &[&str],
template: Template,
serial_id: u64,
rng: &mut impl Rng,
) -> eyre::Result<()> {
tracing::info!("Encoding shares");
let shares: Box<[EncodedBits]> = mpc::distance::encode(&template)
.share(participant_db_sync_queues.len());
.share(participant_db_sync_queues.len(), rng);

let coordinator_payload =
serde_json::to_string(&vec![coordinator::DbSyncPayload {
Expand Down Expand Up @@ -209,9 +210,8 @@ async fn wait_for_empty_queue(
}
}

pub fn generate_random_string(len: usize) -> String {
rand::thread_rng()
.sample_iter(&Alphanumeric)
pub fn generate_random_string(len: usize, rng: &mut impl Rng) -> String {
rng.sample_iter(&Alphanumeric)
.take(len)
.map(char::from)
.collect()
Expand Down
9 changes: 8 additions & 1 deletion bin/e2e/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use eyre::ContextCompat;
use mpc::config::{load_config, AwsConfig, DbConfig};
use mpc::coordinator::UniquenessCheckResult;
use mpc::db::Db;
use mpc::rng_source::RngSource;
use mpc::template::{Bits, Template};
use mpc::utils::aws::{self, sqs_client_from_config};
use serde::Deserialize;
Expand Down Expand Up @@ -44,6 +45,9 @@ struct Args {
/// The path to the signup sequence file to use
#[clap(short, long, default_value = "bin/e2e/signup_sequence.json")]
signup_sequence: String,

#[clap(short, long, env, default_value = "thread")]
rng: RngSource,
}

#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -74,6 +78,8 @@ async fn main() -> eyre::Result<()> {
tracing::warn!("AWS_DEFAULT_REGION not set");
}

let mut rng = args.rng.to_rng();

let _shutdown_tracing_provider = StdoutBattery::init();

tracing::info!("Loading config");
Expand Down Expand Up @@ -132,7 +138,7 @@ async fn main() -> eyre::Result<()> {
&sqs_client,
&config.coordinator_queue.query_queue,
&element.signup_id,
&common::generate_random_string(4),
&common::generate_random_string(4, &mut rng),
)
.await?;

Expand Down Expand Up @@ -177,6 +183,7 @@ async fn main() -> eyre::Result<()> {
&participant_db_sync_queues,
template,
next_serial_id,
&mut rng,
)
.await?;

Expand Down
8 changes: 6 additions & 2 deletions bin/utils/seed_iris_db.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use clap::Args;
use mpc::bits::Bits;
use mpc::rng_source::RngSource;
use mpc::template::Template;
use rand::{thread_rng, Rng};
use rand::Rng;
use serde::{Deserialize, Serialize};

use crate::generate_random_string;
Expand All @@ -19,6 +20,9 @@ pub struct SeedIrisDb {

#[clap(short, long, default_value = "10000")]
pub batch_size: usize,

#[clap(short, long, env, default_value = "thread")]
pub rng: RngSource,
}

pub async fn seed_iris_db(args: &SeedIrisDb) -> eyre::Result<()> {
Expand All @@ -30,7 +34,7 @@ pub async fn seed_iris_db(args: &SeedIrisDb) -> eyre::Result<()> {

let iris_db = client.database(DATABASE_NAME);

let mut rng = thread_rng();
let mut rng = args.rng.to_rng();

tracing::info!("Generating codes");
let left_templates = (0..args.num_templates)
Expand Down
12 changes: 8 additions & 4 deletions bin/utils/seed_mpc_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use clap::Args;
use indicatif::ProgressBar;
use mpc::config::DbConfig;
use mpc::db::Db;
use mpc::rng_source::RngSource;
use mpc::template::Template;
use rand::{thread_rng, Rng};
use rand::Rng;

#[derive(Debug, Clone, Args)]
pub struct SeedMPCDb {
Expand All @@ -18,6 +19,9 @@ pub struct SeedMPCDb {

#[clap(short, long, default_value = "10000")]
pub batch_size: usize,

#[clap(short, long, env, default_value = "thread")]
pub rng: RngSource,
}

pub async fn seed_mpc_db(args: &SeedMPCDb) -> eyre::Result<()> {
Expand All @@ -31,7 +35,7 @@ pub async fn seed_mpc_db(args: &SeedMPCDb) -> eyre::Result<()> {
let pb = ProgressBar::new(args.num_templates as u64)
.with_message("Generating templates");

let mut rng = thread_rng();
let mut rng = args.rng.to_rng();

for _ in 0..args.num_templates {
templates.push(rng.gen());
Expand Down Expand Up @@ -81,8 +85,8 @@ pub async fn seed_mpc_db(args: &SeedMPCDb) -> eyre::Result<()> {
let pb = ProgressBar::new(chunk.len() as u64)
.with_message("Encoding shares");
for (offset, template) in chunk.iter().enumerate() {
let shares =
mpc::distance::encode(template).share(participant_dbs.len());
let shares = mpc::distance::encode(template)
.share(participant_dbs.len(), &mut rng);

let id = offset + (idx * args.batch_size);

Expand Down
7 changes: 4 additions & 3 deletions src/encoded_bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use bytemuck::{cast_slice_mut, Pod, Zeroable};
use rand::distributions::{Distribution, Standard};
use rand::{thread_rng, Rng};
use rand::Rng;
use serde::de::Error as _;
use serde::{Deserialize, Deserializer, Serialize};

Expand All @@ -23,11 +23,10 @@ unsafe impl Pod for EncodedBits {}

impl EncodedBits {
/// Generate secret shares from this bitvector.
pub fn share(&self, n: usize) -> Box<[EncodedBits]> {
pub fn share(&self, n: usize, rng: &mut impl Rng) -> Box<[EncodedBits]> {
assert!(n > 0);

// Create `n - 1` random shares.
let mut rng = thread_rng();
let mut result: Box<[EncodedBits]> =
iter::repeat_with(|| rng.gen::<EncodedBits>())
.take(n - 1)
Expand Down Expand Up @@ -210,6 +209,8 @@ impl Serialize for EncodedBits {

#[cfg(test)]
mod tests {
use rand::thread_rng;

use super::*;

#[test]
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ pub mod distance;
pub mod encoded_bits;
pub mod health_check;
pub mod participant;
pub mod rng_source;
pub mod template;
pub mod utils;
73 changes: 73 additions & 0 deletions src/rng_source.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use std::fmt;
use std::str::FromStr;

use rand::{thread_rng, RngCore, SeedableRng};
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[serde(tag = "kind")]
pub enum RngSource {
Thread,
Small(u64),
Std(u64),
}

impl RngSource {
pub fn to_rng(&self) -> Box<dyn RngCore> {
match self {
RngSource::Thread => Box::new(thread_rng()),
RngSource::Small(seed) => {
let rng: rand::rngs::SmallRng =
SeedableRng::seed_from_u64(*seed);
Box::new(rng)
}
RngSource::Std(seed) => {
let rng: rand::rngs::StdRng = SeedableRng::seed_from_u64(*seed);
Box::new(rng)
}
}
}
}

impl fmt::Display for RngSource {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
RngSource::Thread => write!(f, "thread"),
RngSource::Small(seed) => write!(f, "small:{}", seed),
RngSource::Std(seed) => write!(f, "std:{}", seed),
}
}
}

impl FromStr for RngSource {
type Err = eyre::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
if s == "thread" {
Ok(RngSource::Thread)
} else if s.starts_with("small:") {
let seed = s.trim_start_matches("small:").parse()?;
Ok(RngSource::Small(seed))
} else if s.starts_with("std:") {
let seed = s.trim_start_matches("std:").parse()?;
Ok(RngSource::Std(seed))
} else {
Err(eyre::eyre!("Invalid RngSource: {}", s))
}
}
}

#[cfg(test)]
mod tests {
use test_case::test_case;

use super::*;

#[test_case("thread" => RngSource::Thread)]
#[test_case("std:42" => RngSource::Std(42))]
#[test_case("small:42" => RngSource::Small(42))]
fn serialization_round_trip(s: &str) -> RngSource {
s.parse().unwrap()
}
}

0 comments on commit cbc96ef

Please sign in to comment.