Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dzejkop/migrate-codes-streaming-impl #78

Merged
merged 10 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 0 additions & 78 deletions bin/migrate-codes/iris_db.rs

This file was deleted.

205 changes: 157 additions & 48 deletions bin/migrate-codes/migrate.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
#![allow(clippy::type_complexity)]

use clap::Parser;
use iris_db::IrisCodeEntry;
use eyre::ContextCompat;
use futures::{pin_mut, Stream, StreamExt};
use indicatif::{ProgressBar, ProgressStyle};
use mpc::bits::Bits;
use mpc::db::Db;
use mpc::distance::EncodedBits;
use mpc::iris_db::{
FinalResult, IrisCodeEntry, IrisDb, SideResult, FINAL_RESULT_STATUS,
};
use mpc::template::Template;
use rand::thread_rng;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use telemetry_batteries::tracing::stdout::StdoutBattery;

use crate::iris_db::IrisDb;
use crate::mpc_db::MPCDb;

mod iris_db;
mod mpc_db;

#[derive(Parser)]
Expand All @@ -32,12 +37,19 @@ pub struct Args {
#[clap(alias = "rp", long, env)]
pub right_participant_db: Vec<String>,

#[clap(long, env, default_value = "10000")]
pub batch_size: usize,
/// Batch size for encoding shares
#[clap(short, long, default_value = "100")]
batch_size: usize,

/// If set to true, no migration or creation of the database will occur on the Postgres side
#[clap(long)]
no_migrate_or_create: bool,
}

#[tokio::main]
async fn main() -> eyre::Result<()> {
dotenv::dotenv().ok();

let _shutdown_tracing_provider = StdoutBattery::init();

let args = Args::parse();
Expand All @@ -56,34 +68,63 @@ async fn main() -> eyre::Result<()> {
args.left_participant_db,
args.right_coordinator_db,
args.right_participant_db,
args.no_migrate_or_create,
)
.await?;

let iris_db = IrisDb::new(args.iris_code_db).await?;

let iris_code_entries = iris_db.get_iris_code_snapshot().await?;
let (left_templates, right_templates) =
extract_templates(iris_code_entries);
let latest_serial_id = mpc_db.fetch_latest_serial_id().await?;
tracing::info!("Latest serial id {latest_serial_id}");

let mut next_serial_id = mpc_db.fetch_latest_serial_id().await? + 1;
// Cleanup items with larger ids
// as they might be assigned new values in the future
mpc_db.prune_items(latest_serial_id).await?;
iris_db.prune_final_results(latest_serial_id).await?;

let left_data = encode_shares(left_templates, num_participants as usize)?;
let right_data = encode_shares(right_templates, num_participants as usize)?;
let first_unsynced_iris_serial_id = if let Some(final_result) = iris_db
.get_final_result_by_serial_id(latest_serial_id)
.await?
{
iris_db
.get_entry_by_signup_id(&final_result.signup_id)
.await?
.context("Could not find iris code entry")?
.serial_id
} else {
0
};

insert_masks_and_shares(
&left_data,
&mpc_db.left_coordinator_db,
&mpc_db.left_participant_dbs,
)
.await?;
let num_iris_codes = iris_db
.count_whitelisted_iris_codes(first_unsynced_iris_serial_id)
.await?;
tracing::info!("Processing {} iris codes", num_iris_codes);

insert_masks_and_shares(
&right_data,
&mpc_db.right_coordinator_db,
&mpc_db.right_participant_dbs,
let pb =
ProgressBar::new(num_iris_codes).with_message("Migrating iris codes");
let pb_style = ProgressStyle::default_bar()
.template("{spinner:.green} {msg} [{elapsed_precise}] [{wide_bar:.green}] {pos:>7}/{len:7} ({eta})")
.expect("Could not create progress bar");
pb.set_style(pb_style);

let iris_code_entries = iris_db
.stream_whitelisted_iris_codes(first_unsynced_iris_serial_id)
.await?
.chunks(args.batch_size)
.map(|chunk| chunk.into_iter().collect::<Result<Vec<_>, _>>());

handle_templates_stream(
iris_code_entries,
&iris_db,
&mpc_db,
num_participants as usize,
latest_serial_id,
&pb,
)
.await?;

pb.finish();

Ok(())
}

Expand All @@ -95,7 +136,7 @@ pub struct MPCIrisData {
pub fn encode_shares(
template_data: Vec<(usize, Template)>,
num_participants: usize,
) -> eyre::Result<(Vec<(usize, Bits, Box<[EncodedBits]>)>)> {
) -> eyre::Result<Vec<(usize, Bits, Box<[EncodedBits]>)>> {
let iris_data = template_data
.into_par_iter()
.map(|(serial_id, template)| {
Expand All @@ -111,32 +152,100 @@ pub fn encode_shares(
Ok(iris_data)
}

pub fn extract_templates(
iris_code_snapshot: Vec<IrisCodeEntry>,
) -> (Vec<(usize, Template)>, Vec<(usize, Template)>) {
let (left_templates, right_templates) = iris_code_snapshot
.into_iter()
.map(|entry| {
(
(
entry.mpc_serial_id as usize,
Template {
code: entry.iris_code_left,
mask: entry.mask_code_left,
},
),
(
entry.mpc_serial_id as usize,
Template {
code: entry.iris_code_right,
mask: entry.mask_code_right,
},
),
)
})
.unzip();
async fn handle_templates_stream(
iris_code_entries: impl Stream<
Item = mongodb::error::Result<Vec<IrisCodeEntry>>,
>,
iris_db: &IrisDb,
mpc_db: &MPCDb,
num_participants: usize,
mut latest_serial_id: u64,
pb: &ProgressBar,
) -> eyre::Result<()> {
pin_mut!(iris_code_entries);

// Consume the stream
while let Some(entries) = iris_code_entries.next().await {
let entries = entries?;

let count = entries.len() as u64;

let left_data: Vec<_> = entries
.iter()
.enumerate()
.map(|(idx, entry)| {
let template = Template {
code: entry.iris_code_left,
mask: entry.mask_code_left,
};

(latest_serial_id as usize + 1 + idx, template)
})
.collect();

let right_data: Vec<_> = entries
.iter()
.enumerate()
.map(|(idx, entry)| {
let template = Template {
code: entry.iris_code_right,
mask: entry.mask_code_right,
};

(latest_serial_id as usize + 1 + idx, template)
})
.collect();

let left = handle_side_data_chunk(
left_data,
num_participants,
&mpc_db.left_coordinator_db,
&mpc_db.left_participant_dbs,
);

let right = handle_side_data_chunk(
right_data,
num_participants,
&mpc_db.right_coordinator_db,
&mpc_db.right_participant_dbs,
);

let final_results: Vec<_> = entries
.iter()
.enumerate()
.map(|(idx, entry)| FinalResult {
status: FINAL_RESULT_STATUS.to_string(),
serial_id: latest_serial_id + 1 + idx as u64,
signup_id: entry.signup_id.clone(),
unique: true,
right_result: SideResult {},
left_result: SideResult {},
})
.collect();

let results = iris_db.save_final_results(&final_results);

futures::try_join!(left, right, results)?;

latest_serial_id += count;
pb.inc(count);
}

(left_templates, right_templates)
Ok(())
}

async fn handle_side_data_chunk(
templates: Vec<(usize, Template)>,
num_participants: usize,
coordinator_db: &Db,
participant_dbs: &[Db],
) -> eyre::Result<()> {
let left_data = encode_shares(templates, num_participants)?;

insert_masks_and_shares(&left_data, coordinator_db, participant_dbs)
.await?;

Ok(())
}

async fn insert_masks_and_shares(
Expand All @@ -147,7 +256,7 @@ async fn insert_masks_and_shares(
// Insert masks
let left_masks: Vec<_> = data
.iter()
.map(|(serial_id, mask, _)| (*serial_id as u64, mask.clone()))
.map(|(serial_id, mask, _)| (*serial_id as u64, *mask))
.collect();

coordinator_db.insert_masks(&left_masks).await?;
Expand Down
Loading
Loading