diff --git a/crates/seec/examples/bristol.rs b/crates/seec/examples/bristol.rs index f2ced0b..5c58e24 100644 --- a/crates/seec/examples/bristol.rs +++ b/crates/seec/examples/bristol.rs @@ -7,13 +7,19 @@ use anyhow::{Context, Result}; use clap::{Args, Parser}; -use seec::bench::{BenchParty, ServerTlsConfig}; +use rand::distributions::Standard; +use rand::prelude::Distribution; +use seec::bench::{BenchParty, BenchProtocol, ServerTlsConfig}; use seec::circuit::base_circuit::Load; use seec::circuit::{BaseCircuit, ExecutableCircuit}; -use seec::protocols::boolean_gmw::BooleanGmw; +use seec::gate::base::BaseGate; +use seec::protocols::{Gate, Share}; +use seec::protocols::{aby2, aby2::BooleanAby2, boolean_gmw::BooleanGmw, Protocol}; use seec::secret::inputs; -use seec::SubCircuitOutput; -use seec::{BooleanGate, CircuitBuilder}; +use seec::{bristol, SubCircuitOutput}; +use seec::{BooleanGate, CircuitBuilder,}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; use std::fs::File; use std::io; use std::io::{stdout, BufReader, BufWriter, Write}; @@ -48,6 +54,9 @@ struct CompileArgs { /// Circuit in bristol format circuit: PathBuf, + + #[clap(long)] + aby2: bool } #[derive(Args, Debug)] @@ -94,6 +103,9 @@ struct ExecuteArgs { /// Circuit to execute. Must be compiled beforehand circuit: PathBuf, + + #[clap(long)] + aby2: bool } #[tokio::main] @@ -101,8 +113,19 @@ async fn main() -> Result<()> { let prog_args = ProgArgs::parse(); init_tracing(&prog_args).context("failed to init logging")?; match prog_args { - ProgArgs::Compile(args) => compile(args).context("failed to compile circuit"), - ProgArgs::Execute(args) => execute(args).await.context("failed to execute circuit"), + ProgArgs::Compile(args) => { + if args.aby2 { + compile_aby2(args).context("failed to compile aby2 circuit") + } else { + compile(args).context("failed to compile circuit") + } + }, + ProgArgs::Execute(args) => if args.aby2 { + execute::(args).await.context("failed to execute circuit") + + } else { + execute::(args).await.context("failed to execute circuit") + }, } } @@ -145,6 +168,22 @@ fn compile(compile_args: CompileArgs) -> Result<()> { Ok(()) } +fn compile_aby2(compile_args: CompileArgs) -> Result<()> { + assert!(compile_args.simd.is_none()); + let mut bc: BaseCircuit = + BaseCircuit::load_bristol(&compile_args.circuit, Load::Circuit) + .expect("failed to load bristol circuit"); + + let mut circ = ExecutableCircuit::DynLayers(bc.into()); + if !compile_args.dyn_layers { + circ = circ.precompute_layers(); + } + let out = + BufWriter::new(File::create(&compile_args.output).context("failed to create output file")?); + bincode::serialize_into(out, &circ).context("failed to serialize circuit")?; + Ok(()) +} + impl ProgArgs { fn log(&self) -> Option<&PathBuf> { match self { @@ -154,17 +193,24 @@ impl ProgArgs { } } -async fn execute(execute_args: ExecuteArgs) -> Result<()> { +async fn execute

(execute_args: ExecuteArgs) -> Result<()> +where + P: BenchProtocol, +

::Gate: Gate + DeserializeOwned + Serialize + DeserializeOwned + From> + for<'a> From<&'a bristol::Gate> + From>, + Standard: Distribution, + P::Share: Share, + +{ let circ_name = execute_args .circuit .file_stem() .unwrap() .to_string_lossy() .to_string(); - let circuit = load_circ(&execute_args).context("failed to load circuit")?; + let circuit = load_circ::(&execute_args).context("failed to load circuit")?; let create_party = |id, circ| { - let mut party = BenchParty::::new(id) + let mut party = BenchParty::::new(id) .explicit_circuit(circ) .repeat(execute_args.repeat) .insecure_setup(execute_args.insecure_setup) @@ -222,7 +268,7 @@ async fn execute(execute_args: ExecuteArgs) -> Result<()> { Ok(()) } -fn load_circ(args: &ExecuteArgs) -> Result> { +fn load_circ(args: &ExecuteArgs) -> Result> where G: Gate + DeserializeOwned + Serialize + for<'a> From<&'a bristol::Gate> + From> { let res = bincode::deserialize_from(BufReader::new( File::open(&args.circuit).context("Failed to open circuit file")?, )); diff --git a/crates/seec/src/bench.rs b/crates/seec/src/bench.rs index 92a881b..3104d93 100644 --- a/crates/seec/src/bench.rs +++ b/crates/seec/src/bench.rs @@ -5,13 +5,17 @@ //! `crates/seec/examples/bristol.rs` binary. use crate::circuit::{ExecutableCircuit, GateIdx}; -use crate::executor::{Executor, Input, Message}; +use crate::executor::{DynFDSetup, Executor, Input, Message}; +use crate::mul_triple; use crate::mul_triple::storage::MTStorage; use crate::mul_triple::{boolean, MTProvider}; +use crate::protocols; +#[cfg(feature = "aby2")] +use crate::protocols::aby2::{AbySetupMsg, AbySetupProvider, BooleanAby2, DeltaSharing}; use crate::protocols::boolean_gmw::BooleanGmw; use crate::protocols::mixed_gmw::{Mixed, MixedGmw}; -use crate::protocols::{mixed_gmw, Protocol, Ring, Share, ShareStorage}; -use crate::utils::{BoxError, ErasedError}; +use crate::protocols::{mixed_gmw, FunctionDependentSetup, Protocol, Ring, Share, ShareStorage}; +use crate::utils::BoxError; use crate::CircuitBuilder; use anyhow::{anyhow, Context}; use bitvec::view::BitViewSized; @@ -19,6 +23,7 @@ use rand::distributions::{Distribution, Standard}; use rand::rngs::OsRng; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; +use remoc::RemoteSend; use seec_channel::util::{Phase, RunResult, Statistics}; use seec_channel::{sub_channels_for, Channel, Sender}; use serde::{Deserialize, Serialize}; @@ -29,27 +34,41 @@ use std::io::BufReader; use std::net::SocketAddr; use std::path::{Path, PathBuf}; use std::time::Duration; -use zappot::ot_ext::ExtOTMsg; type DynMTP

= Box::SetupStorage, Error = BoxError> + Send + 'static>; -pub trait BenchProtocol: Protocol + Default + Debug { +pub trait BenchProtocol: Protocol + Default + Debug + 'static { + const FUNCTION_DEPENDENT_SETUP: bool; + type SetupMsg: RemoteSend; + fn insecure_setup() -> DynMTP; - fn ot_setup(ch: Channel>) -> DynMTP; + fn ot_setup(ch: Channel) -> DynMTP; fn stored(path: &Path) -> DynMTP; + + fn fd_setup( + party_id: usize, + ch: Channel, + ) -> DynFDSetup<'static, Self, Idx> { + panic!("Needs to be implemented for Protocols with FUNCTION_DEPENDENT_SETUP = true") + } } impl BenchProtocol for BooleanGmw { + const FUNCTION_DEPENDENT_SETUP: bool = false; + type SetupMsg = mul_triple::boolean::ot_ext::DefaultMsg; + fn insecure_setup() -> DynMTP { - Box::new(ErasedError(boolean::InsecureMTProvider::default())) + Box::new(mul_triple::ErasedError( + boolean::InsecureMTProvider::default(), + )) } - fn ot_setup(ch: Channel>) -> DynMTP { + fn ot_setup(ch: Channel) -> DynMTP { let ot_sender = zappot::ot_ext::Sender::default(); let ot_recv = zappot::ot_ext::Receiver::default(); let mtp = boolean::OtMTProvider::new(OsRng, ot_sender, ot_recv, ch.0, ch.1); - Box::new(ErasedError(mtp)) + Box::new(mul_triple::ErasedError(mtp)) } fn stored(path: &Path) -> DynMTP { @@ -64,17 +83,54 @@ where Standard: Distribution, [R; 1]: BitViewSized, { + const FUNCTION_DEPENDENT_SETUP: bool = false; + type SetupMsg = (); + fn insecure_setup() -> DynMTP { mixed_gmw::InsecureMixedSetup::default().into_dyn() } - fn ot_setup(_ch: Channel>) -> DynMTP { + fn ot_setup(_ch: Channel) -> DynMTP { + todo!() + } + + fn stored(_path: &Path) -> DynMTP { + todo!() + } +} + +#[cfg(feature = "aby2")] +impl Default for BooleanAby2 { + fn default() -> Self { + BooleanAby2::new(DeltaSharing::insecure_default()) + } +} + +#[cfg(feature = "aby2")] +impl BenchProtocol for BooleanAby2 { + const FUNCTION_DEPENDENT_SETUP: bool = true; + type SetupMsg = AbySetupMsg; + + fn insecure_setup() -> DynMTP { + todo!() + } + + fn ot_setup(_ch: Channel) -> DynMTP { todo!() } fn stored(_path: &Path) -> DynMTP { todo!() } + + fn fd_setup( + party_id: usize, + ch: Channel, + ) -> DynFDSetup<'static, Self, Idx> { + let setup = + AbySetupProvider::new(party_id, boolean::InsecureMTProvider::default(), ch.0, ch.1); + Box::new(protocols::ErasedError(setup)) + } } // TODO this is wrong to just always generate arith shares, so it lives here in the bench API @@ -234,15 +290,10 @@ where .with_sleep(self.sleep_after_phase) .without_unaccounted(true); - let (ot_ch, mut exec_ch) = sub_channels_for!( - &mut sender, - &mut receiver, - 128, - Sender, - Message

- ) - .await - .context("Establishing sub channels")?; + let (setup_ch, mut exec_ch) = + sub_channels_for!(&mut sender, &mut receiver, 128, P::SetupMsg, Message

) + .await + .context("Establishing sub channels")?; let circ = match &self.circ { Some(circ) => circ, @@ -257,26 +308,30 @@ where } } }; - - let mut mtp = match (self.insecure_setup, &self.stored_mts) { - (false, None) => P::ot_setup(ot_ch), - (true, None) => P::insecure_setup(), - (false, Some(path)) => P::stored(path), - (true, Some(_)) => unreachable!("ensure via setters"), + let setup = if !P::FUNCTION_DEPENDENT_SETUP { + let mut mtp = match (self.insecure_setup, &self.stored_mts) { + (false, None) => P::ot_setup(setup_ch), + (true, None) => P::insecure_setup(), + (false, Some(path)) => P::stored(path), + (true, Some(_)) => unreachable!("ensure via setters"), + }; + let mts_needed = circ.interactive_count_times_simd(); + if !self.interleave_setup { + statistics + .record(Phase::Mts, mtp.precompute_mts(mts_needed)) + .await + .map_err(|err| anyhow!(err)) + .context("MT precomputation failed")?; + } + Box::new(mtp) + } else { + P::fd_setup(self.id, setup_ch) }; - let mts_needed = circ.interactive_count_times_simd(); - if !self.interleave_setup { - statistics - .record(Phase::Mts, mtp.precompute_mts(mts_needed)) - .await - .map_err(|err| anyhow!(err)) - .context("MT precomputation failed")?; - } let mut executor = statistics .record( Phase::FunctionDependentSetup, - Executor::::new(circ, self.id, mtp), + Executor::::new(circ, self.id, setup), ) .await .context("Failed to create executor")?; diff --git a/crates/seec/src/mul_triple/mod.rs b/crates/seec/src/mul_triple/mod.rs index 76ce1e8..3d3f087 100644 --- a/crates/seec/src/mul_triple/mod.rs +++ b/crates/seec/src/mul_triple/mod.rs @@ -3,7 +3,7 @@ use crate::circuit::ExecutableCircuit; use crate::executor::GateOutputs; use crate::protocols::{FunctionDependentSetup, Protocol}; -use crate::utils::{BoxError, ErasedError}; +use crate::utils::BoxError; use async_trait::async_trait; use std::error::Error; @@ -32,6 +32,8 @@ pub trait MTProvider { } } +pub struct ErasedError(pub I); + #[async_trait] impl MTProvider for &mut Mtp { type Output = Mtp::Output; diff --git a/crates/seec/src/private_test_utils.rs b/crates/seec/src/private_test_utils.rs index aecce56..2bd3357 100644 --- a/crates/seec/src/private_test_utils.rs +++ b/crates/seec/src/private_test_utils.rs @@ -40,7 +40,7 @@ use crate::protocols::boolean_gmw::{BooleanGmw, XorSharing}; use crate::protocols::mixed_gmw::{self, MixedGmw, MixedShareStorage, MixedSharing}; use crate::protocols::{FunctionDependentSetup, Protocol, Ring, ScalarDim, Share, Sharing}; -pub trait ProtocolTestExt: Protocol + Default { +pub trait ProtocolTestExt: Protocol + Default + 'static { type InsecureSetup: FunctionDependentSetup + Default + Clone diff --git a/crates/seec/src/protocols/aby2.rs b/crates/seec/src/protocols/aby2.rs index 7af7a02..78140e9 100644 --- a/crates/seec/src/protocols/aby2.rs +++ b/crates/seec/src/protocols/aby2.rs @@ -14,6 +14,7 @@ use crate::{bristol, executor, CircuitBuilder}; use ahash::AHashMap; use async_trait::async_trait; use itertools::Itertools; +use rand::distributions::{Distribution, Standard}; use rand::{Rng, SeedableRng}; use rand_chacha::ChaChaRng; use seec_channel::multi::{MultiReceiver, MultiSender}; @@ -62,7 +63,7 @@ pub enum Msg { Delta { delta: Vec }, } -#[derive(Clone, PartialOrd, Ord, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialOrd, Ord, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)] pub enum BooleanGate { Base(BaseGate), And { n: u8 }, @@ -441,6 +442,15 @@ impl super::Share for Share { type SimdShare = ShareVec; } +impl Distribution for Standard { + fn sample(&self, rng: &mut RNG) -> Share { + Share { + public: rng.gen(), + private: rng.gen(), + } + } +} + impl From> for BooleanGate { fn from(base_gate: BaseGate) -> Self { Self::Base(base_gate) diff --git a/crates/seec/src/protocols/mod.rs b/crates/seec/src/protocols/mod.rs index 3d0f4c7..5ea6a88 100644 --- a/crates/seec/src/protocols/mod.rs +++ b/crates/seec/src/protocols/mod.rs @@ -333,6 +333,28 @@ where } } +#[async_trait] +impl<'c, P, Idx> FunctionDependentSetup + for Box + Send + 'c> +where + P: Protocol, + Idx: Sync, +{ + type Error = BoxError; + + async fn setup( + &mut self, + shares: &GateOutputs, + circuit: &ExecutableCircuit, + ) -> Result<(), Self::Error> { + self.setup(shares, circuit).await + } + + async fn request_setup_output(&mut self, count: usize) -> Result { + self.request_setup_output(count).await + } +} + #[derive(Copy, Clone, PartialOrd, Ord, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)] pub struct ScalarDim; diff --git a/crates/seec/src/utils.rs b/crates/seec/src/utils.rs index ca2b627..72ba7f1 100644 --- a/crates/seec/src/utils.rs +++ b/crates/seec/src/utils.rs @@ -131,8 +131,6 @@ impl BitVecExt for BitVec { } } -pub struct ErasedError(pub I); - #[derive(Debug)] pub struct BoxError(pub Box);