Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
robinhundt committed Aug 14, 2024
1 parent 45f3dec commit 5b78fd4
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 49 deletions.
66 changes: 56 additions & 10 deletions crates/seec/examples/bristol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@
use anyhow::{Context, Result};
use clap::{Args, Parser};
use seec::bench::{BenchParty, ServerTlsConfig};
use rand::distributions::Standard;
use rand::prelude::Distribution;
use seec::bench::{BenchParty, BenchProtocol, ServerTlsConfig};
use seec::circuit::base_circuit::Load;
use seec::circuit::{BaseCircuit, ExecutableCircuit};
use seec::protocols::boolean_gmw::BooleanGmw;
use seec::gate::base::BaseGate;
use seec::protocols::{Gate, Share};
use seec::protocols::{aby2, aby2::BooleanAby2, boolean_gmw::BooleanGmw, Protocol};
use seec::secret::inputs;
use seec::SubCircuitOutput;
use seec::{BooleanGate, CircuitBuilder};
use seec::{bristol, SubCircuitOutput};
use seec::{BooleanGate, CircuitBuilder,};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io;
use std::io::{stdout, BufReader, BufWriter, Write};
Expand Down Expand Up @@ -48,6 +54,9 @@ struct CompileArgs {

/// Circuit in bristol format
circuit: PathBuf,

#[clap(long)]
aby2: bool
}

#[derive(Args, Debug)]
Expand Down Expand Up @@ -94,15 +103,29 @@ struct ExecuteArgs {

/// Circuit to execute. Must be compiled beforehand
circuit: PathBuf,

#[clap(long)]
aby2: bool
}

#[tokio::main]
async fn main() -> Result<()> {
let prog_args = ProgArgs::parse();
init_tracing(&prog_args).context("failed to init logging")?;
match prog_args {
ProgArgs::Compile(args) => compile(args).context("failed to compile circuit"),
ProgArgs::Execute(args) => execute(args).await.context("failed to execute circuit"),
ProgArgs::Compile(args) => {
if args.aby2 {
compile_aby2(args).context("failed to compile aby2 circuit")
} else {
compile(args).context("failed to compile circuit")
}
},
ProgArgs::Execute(args) => if args.aby2 {
execute::<BooleanAby2>(args).await.context("failed to execute circuit")

} else {
execute::<BooleanGmw>(args).await.context("failed to execute circuit")
},
}
}

Expand Down Expand Up @@ -145,6 +168,22 @@ fn compile(compile_args: CompileArgs) -> Result<()> {
Ok(())
}

fn compile_aby2(compile_args: CompileArgs) -> Result<()> {
assert!(compile_args.simd.is_none());
let mut bc: BaseCircuit<bool, aby2::BooleanGate> =
BaseCircuit::load_bristol(&compile_args.circuit, Load::Circuit)
.expect("failed to load bristol circuit");

let mut circ = ExecutableCircuit::DynLayers(bc.into());
if !compile_args.dyn_layers {
circ = circ.precompute_layers();
}
let out =
BufWriter::new(File::create(&compile_args.output).context("failed to create output file")?);
bincode::serialize_into(out, &circ).context("failed to serialize circuit")?;
Ok(())
}

impl ProgArgs {
fn log(&self) -> Option<&PathBuf> {
match self {
Expand All @@ -154,17 +193,24 @@ impl ProgArgs {
}
}

async fn execute(execute_args: ExecuteArgs) -> Result<()> {
async fn execute<P>(execute_args: ExecuteArgs) -> Result<()>
where
P: BenchProtocol<Plain = bool>,
<P as Protocol>::Gate: Gate<bool> + DeserializeOwned + Serialize + DeserializeOwned + From<BaseGate<bool>> + for<'a> From<&'a bristol::Gate> + From<BaseGate<bool>>,
Standard: Distribution<P::Share>,
P::Share: Share<SimdShare = P::ShareStorage>,

{
let circ_name = execute_args
.circuit
.file_stem()
.unwrap()
.to_string_lossy()
.to_string();
let circuit = load_circ(&execute_args).context("failed to load circuit")?;
let circuit = load_circ::<P::Gate>(&execute_args).context("failed to load circuit")?;

let create_party = |id, circ| {
let mut party = BenchParty::<BooleanGmw, u32>::new(id)
let mut party = BenchParty::<P, u32>::new(id)
.explicit_circuit(circ)
.repeat(execute_args.repeat)
.insecure_setup(execute_args.insecure_setup)
Expand Down Expand Up @@ -222,7 +268,7 @@ async fn execute(execute_args: ExecuteArgs) -> Result<()> {
Ok(())
}

fn load_circ(args: &ExecuteArgs) -> Result<ExecutableCircuit<bool, BooleanGate, u32>> {
fn load_circ<G>(args: &ExecuteArgs) -> Result<ExecutableCircuit<bool, G, u32>> where G: Gate<bool> + DeserializeOwned + Serialize + for<'a> From<&'a bristol::Gate> + From<BaseGate<bool>> {
let res = bincode::deserialize_from(BufReader::new(
File::open(&args.circuit).context("Failed to open circuit file")?,
));
Expand Down
123 changes: 89 additions & 34 deletions crates/seec/src/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,25 @@
//! `crates/seec/examples/bristol.rs` binary.
use crate::circuit::{ExecutableCircuit, GateIdx};
use crate::executor::{Executor, Input, Message};
use crate::executor::{DynFDSetup, Executor, Input, Message};
use crate::mul_triple;
use crate::mul_triple::storage::MTStorage;
use crate::mul_triple::{boolean, MTProvider};
use crate::protocols;
#[cfg(feature = "aby2")]
use crate::protocols::aby2::{AbySetupMsg, AbySetupProvider, BooleanAby2, DeltaSharing};
use crate::protocols::boolean_gmw::BooleanGmw;
use crate::protocols::mixed_gmw::{Mixed, MixedGmw};
use crate::protocols::{mixed_gmw, Protocol, Ring, Share, ShareStorage};
use crate::utils::{BoxError, ErasedError};
use crate::protocols::{mixed_gmw, FunctionDependentSetup, Protocol, Ring, Share, ShareStorage};
use crate::utils::BoxError;
use crate::CircuitBuilder;
use anyhow::{anyhow, Context};
use bitvec::view::BitViewSized;
use rand::distributions::{Distribution, Standard};
use rand::rngs::OsRng;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use remoc::RemoteSend;
use seec_channel::util::{Phase, RunResult, Statistics};
use seec_channel::{sub_channels_for, Channel, Sender};
use serde::{Deserialize, Serialize};
Expand All @@ -29,27 +34,41 @@ use std::io::BufReader;
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::time::Duration;
use zappot::ot_ext::ExtOTMsg;

type DynMTP<P> =
Box<dyn MTProvider<Output = <P as Protocol>::SetupStorage, Error = BoxError> + Send + 'static>;

pub trait BenchProtocol: Protocol + Default + Debug {
pub trait BenchProtocol: Protocol + Default + Debug + 'static {
const FUNCTION_DEPENDENT_SETUP: bool;
type SetupMsg: RemoteSend;

fn insecure_setup() -> DynMTP<Self>;
fn ot_setup(ch: Channel<Sender<ExtOTMsg>>) -> DynMTP<Self>;
fn ot_setup(ch: Channel<Self::SetupMsg>) -> DynMTP<Self>;
fn stored(path: &Path) -> DynMTP<Self>;

fn fd_setup<Idx: GateIdx>(
party_id: usize,
ch: Channel<Self::SetupMsg>,
) -> DynFDSetup<'static, Self, Idx> {
panic!("Needs to be implemented for Protocols with FUNCTION_DEPENDENT_SETUP = true")
}
}

impl BenchProtocol for BooleanGmw {
const FUNCTION_DEPENDENT_SETUP: bool = false;
type SetupMsg = mul_triple::boolean::ot_ext::DefaultMsg;

fn insecure_setup() -> DynMTP<Self> {
Box::new(ErasedError(boolean::InsecureMTProvider::default()))
Box::new(mul_triple::ErasedError(
boolean::InsecureMTProvider::default(),
))
}

fn ot_setup(ch: Channel<Sender<ExtOTMsg>>) -> DynMTP<Self> {
fn ot_setup(ch: Channel<Self::SetupMsg>) -> DynMTP<Self> {
let ot_sender = zappot::ot_ext::Sender::default();
let ot_recv = zappot::ot_ext::Receiver::default();
let mtp = boolean::OtMTProvider::new(OsRng, ot_sender, ot_recv, ch.0, ch.1);
Box::new(ErasedError(mtp))
Box::new(mul_triple::ErasedError(mtp))
}

fn stored(path: &Path) -> DynMTP<Self> {
Expand All @@ -64,17 +83,54 @@ where
Standard: Distribution<R>,
[R; 1]: BitViewSized,
{
const FUNCTION_DEPENDENT_SETUP: bool = false;
type SetupMsg = ();

fn insecure_setup() -> DynMTP<Self> {
mixed_gmw::InsecureMixedSetup::default().into_dyn()
}

fn ot_setup(_ch: Channel<Sender<ExtOTMsg>>) -> DynMTP<Self> {
fn ot_setup(_ch: Channel<Self::SetupMsg>) -> DynMTP<Self> {
todo!()
}

fn stored(_path: &Path) -> DynMTP<Self> {
todo!()
}
}

#[cfg(feature = "aby2")]
impl Default for BooleanAby2 {
fn default() -> Self {
BooleanAby2::new(DeltaSharing::insecure_default())
}
}

#[cfg(feature = "aby2")]
impl BenchProtocol for BooleanAby2 {
const FUNCTION_DEPENDENT_SETUP: bool = true;
type SetupMsg = AbySetupMsg;

fn insecure_setup() -> DynMTP<Self> {
todo!()
}

fn ot_setup(_ch: Channel<Self::SetupMsg>) -> DynMTP<Self> {
todo!()
}

fn stored(_path: &Path) -> DynMTP<Self> {
todo!()
}

fn fd_setup<Idx: GateIdx>(
party_id: usize,
ch: Channel<Self::SetupMsg>,
) -> DynFDSetup<'static, Self, Idx> {
let setup =
AbySetupProvider::new(party_id, boolean::InsecureMTProvider::default(), ch.0, ch.1);
Box::new(protocols::ErasedError(setup))
}
}

// TODO this is wrong to just always generate arith shares, so it lives here in the bench API
Expand Down Expand Up @@ -234,15 +290,10 @@ where
.with_sleep(self.sleep_after_phase)
.without_unaccounted(true);

let (ot_ch, mut exec_ch) = sub_channels_for!(
&mut sender,
&mut receiver,
128,
Sender<ExtOTMsg>,
Message<P>
)
.await
.context("Establishing sub channels")?;
let (setup_ch, mut exec_ch) =
sub_channels_for!(&mut sender, &mut receiver, 128, P::SetupMsg, Message<P>)
.await
.context("Establishing sub channels")?;

let circ = match &self.circ {
Some(circ) => circ,
Expand All @@ -257,26 +308,30 @@ where
}
}
};

let mut mtp = match (self.insecure_setup, &self.stored_mts) {
(false, None) => P::ot_setup(ot_ch),
(true, None) => P::insecure_setup(),
(false, Some(path)) => P::stored(path),
(true, Some(_)) => unreachable!("ensure via setters"),
let setup = if !P::FUNCTION_DEPENDENT_SETUP {
let mut mtp = match (self.insecure_setup, &self.stored_mts) {
(false, None) => P::ot_setup(setup_ch),
(true, None) => P::insecure_setup(),
(false, Some(path)) => P::stored(path),
(true, Some(_)) => unreachable!("ensure via setters"),
};
let mts_needed = circ.interactive_count_times_simd();
if !self.interleave_setup {
statistics
.record(Phase::Mts, mtp.precompute_mts(mts_needed))
.await
.map_err(|err| anyhow!(err))
.context("MT precomputation failed")?;
}
Box::new(mtp)
} else {
P::fd_setup(self.id, setup_ch)
};
let mts_needed = circ.interactive_count_times_simd();
if !self.interleave_setup {
statistics
.record(Phase::Mts, mtp.precompute_mts(mts_needed))
.await
.map_err(|err| anyhow!(err))
.context("MT precomputation failed")?;
}

let mut executor = statistics
.record(
Phase::FunctionDependentSetup,
Executor::<P, Idx>::new(circ, self.id, mtp),
Executor::<P, Idx>::new(circ, self.id, setup),
)
.await
.context("Failed to create executor")?;
Expand Down
4 changes: 3 additions & 1 deletion crates/seec/src/mul_triple/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use crate::circuit::ExecutableCircuit;
use crate::executor::GateOutputs;
use crate::protocols::{FunctionDependentSetup, Protocol};
use crate::utils::{BoxError, ErasedError};
use crate::utils::BoxError;
use async_trait::async_trait;
use std::error::Error;

Expand Down Expand Up @@ -32,6 +32,8 @@ pub trait MTProvider {
}
}

pub struct ErasedError<I>(pub I);

#[async_trait]
impl<Mtp: MTProvider + Send> MTProvider for &mut Mtp {
type Output = Mtp::Output;
Expand Down
2 changes: 1 addition & 1 deletion crates/seec/src/private_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use crate::protocols::boolean_gmw::{BooleanGmw, XorSharing};
use crate::protocols::mixed_gmw::{self, MixedGmw, MixedShareStorage, MixedSharing};
use crate::protocols::{FunctionDependentSetup, Protocol, Ring, ScalarDim, Share, Sharing};

pub trait ProtocolTestExt: Protocol + Default {
pub trait ProtocolTestExt: Protocol + Default + 'static {
type InsecureSetup<Idx: GateIdx>: FunctionDependentSetup<Self, Idx, Error = Infallible>
+ Default
+ Clone
Expand Down
12 changes: 11 additions & 1 deletion crates/seec/src/protocols/aby2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::{bristol, executor, CircuitBuilder};
use ahash::AHashMap;
use async_trait::async_trait;
use itertools::Itertools;
use rand::distributions::{Distribution, Standard};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaChaRng;
use seec_channel::multi::{MultiReceiver, MultiSender};
Expand Down Expand Up @@ -62,7 +63,7 @@ pub enum Msg {
Delta { delta: Vec<u8> },
}

#[derive(Clone, PartialOrd, Ord, PartialEq, Eq, Hash, Debug)]
#[derive(Clone, PartialOrd, Ord, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
pub enum BooleanGate {
Base(BaseGate<bool, ScalarDim>),
And { n: u8 },
Expand Down Expand Up @@ -441,6 +442,15 @@ impl super::Share for Share {
type SimdShare = ShareVec;
}

impl Distribution<Share> for Standard {
fn sample<RNG: Rng + ?Sized>(&self, rng: &mut RNG) -> Share {
Share {
public: rng.gen(),
private: rng.gen(),
}
}
}

impl From<BaseGate<bool>> for BooleanGate {
fn from(base_gate: BaseGate<bool>) -> Self {
Self::Base(base_gate)
Expand Down
Loading

0 comments on commit 5b78fd4

Please sign in to comment.