diff --git a/benches/cpa.rs b/benches/cpa.rs
index 04e7581..7161351 100644
--- a/benches/cpa.rs
+++ b/benches/cpa.rs
@@ -13,16 +13,17 @@ pub fn leakage_model(value: usize, guess: usize) -> usize {
 }
 
 fn cpa_sequential(traces: &Array2, plaintexts: &Array2) -> Cpa {
-    let mut cpa = CpaProcessor::new(traces.shape()[1], 256, 0, leakage_model);
+    let mut cpa = CpaProcessor::new(traces.shape()[1], 256, 0);
 
     for i in 0..traces.shape()[0] {
         cpa.update(
             traces.row(i).map(|&x| x as usize).view(),
             plaintexts.row(i).map(|&y| y as usize).view(),
+            leakage_model,
         );
     }
 
-    cpa.finalize()
+    cpa.finalize(leakage_model)
 }
 
 pub fn leakage_model_normal(value: ArrayView1, guess: usize) -> usize {
@@ -32,14 +33,17 @@ pub fn leakage_model_normal(value: ArrayView1, guess: usize) -> usize {
 }
 
 fn cpa_normal_sequential(traces: &Array2, plaintexts: &Array2) -> Cpa {
     let batch_size = 500;
-    let mut cpa =
-        cpa_normal::CpaProcessor::new(traces.shape()[1], batch_size, 256, leakage_model_normal);
+    let mut cpa = cpa_normal::CpaProcessor::new(traces.shape()[1], batch_size, 256);
 
     for (trace_batch, plaintext_batch) in zip(
         traces.axis_chunks_iter(Axis(0), batch_size),
         plaintexts.axis_chunks_iter(Axis(0), batch_size),
     ) {
-        cpa.update(trace_batch.map(|&x| x as f32).view(), plaintext_batch);
+        cpa.update(
+            trace_batch.map(|&x| x as f32).view(),
+            plaintext_batch,
+            leakage_model_normal,
+        );
     }
 
     cpa.finalize()
diff --git a/benches/dpa.rs b/benches/dpa.rs
index 2fa2b10..d5acc0b 100644
--- a/benches/dpa.rs
+++ b/benches/dpa.rs
@@ -11,10 +11,10 @@ fn selection_function(metadata: ArrayView1, guess: usize) -> bool {
 }
 
 fn dpa_sequential(traces: &Array2, plaintexts: &Array2) -> Dpa {
-    let mut dpa = DpaProcessor::new(traces.shape()[1], 256, selection_function);
+    let mut dpa = DpaProcessor::new(traces.shape()[1], 256);
 
     for i in 0..traces.shape()[0] {
-        dpa.update(traces.row(i), plaintexts.row(i));
+        dpa.update(traces.row(i), plaintexts.row(i), selection_function);
     }
 
     dpa.finalize()
diff --git a/examples/cpa.rs b/examples/cpa.rs
index 5cf2c33..dca19cb 100644
--- a/examples/cpa.rs
+++ b/examples/cpa.rs
@@ -35,20 +35,17 @@ fn cpa() -> Result<()> {
         .progress_with(progress_bar(len_traces))
         .par_bridge()
         .map(|row_number| {
-            let mut cpa = CpaProcessor::new(size, batch, guess_range, leakage_model);
+            let mut cpa = CpaProcessor::new(size, batch, guess_range);
             let range_rows = row_number..row_number + batch;
             let range_samples = start_sample..end_sample;
             let sample_traces = traces
                 .slice(s![range_rows.clone(), range_samples])
                 .map(|l| *l as f32);
             let sample_metadata = plaintext.slice(s![range_rows, ..]).map(|p| *p as usize);
-            cpa.update(sample_traces.view(), sample_metadata.view());
+            cpa.update(sample_traces.view(), sample_metadata.view(), leakage_model);
             cpa
         })
-        .reduce(
-            || CpaProcessor::new(size, batch, guess_range, leakage_model),
-            |x, y| x + y,
-        );
+        .reduce(|| CpaProcessor::new(size, batch, guess_range), |x, y| x + y);
 
     let cpa = cpa_parallel.finalize();
     println!("Guessed key = {}", cpa.best_guess());
@@ -69,7 +66,7 @@ fn success() -> Result<()> {
     let nfiles = 13; // Number of files in the directory. TBD: Automating this value
     let rank_traces = 1000;
 
-    let mut cpa = CpaProcessor::new(size, batch, guess_range, leakage_model);
+    let mut cpa = CpaProcessor::new(size, batch, guess_range);
 
     let mut rank = Array1::zeros(guess_range);
     let mut processed_traces = 0;
@@ -87,7 +84,7 @@ fn success() -> Result<()> {
                 .map(|l| *l as f32);
             let sample_metadata = plaintext.slice(s![range_rows, range_metadata]);
 
-            cpa.update(sample_traces.view(), sample_metadata);
+            cpa.update(sample_traces.view(), sample_metadata, leakage_model);
             processed_traces += sample_traces.len();
             if processed_traces % rank_traces == 0 {
                 // rank can be saved to get its evolution
diff --git a/examples/cpa_partioned.rs b/examples/cpa_partioned.rs
index c0f14f3..e4c7489 100644
--- a/examples/cpa_partioned.rs
+++ b/examples/cpa_partioned.rs
@@ -34,21 +34,22 @@ fn cpa() -> Result<()> {
         })
         .par_bridge()
         .map(|batch| {
-            let mut c = CpaProcessor::new(size, guess_range, target_byte, leakage_model);
+            let mut c = CpaProcessor::new(size, guess_range, target_byte);
             for i in 0..batch.0.shape()[0] {
                 c.update(
                     batch.0.row(i).map(|x| *x as usize).view(),
                     batch.1.row(i).map(|y| *y as usize).view(),
+                    leakage_model,
                 );
             }
 
             c
         })
         .reduce(
-            || CpaProcessor::new(size, guess_range, target_byte, leakage_model),
+            || CpaProcessor::new(size, guess_range, target_byte),
             |a, b| a + b,
         );
 
-    let cpa_result = cpa.finalize();
+    let cpa_result = cpa.finalize(leakage_model);
     println!("Guessed key = {}", cpa_result.best_guess());
 
     // save corr key curves in npy
diff --git a/examples/dpa.rs b/examples/dpa.rs
index 1c909fc..9fb4871 100644
--- a/examples/dpa.rs
+++ b/examples/dpa.rs
@@ -25,14 +25,18 @@ fn dpa() -> Result<()> {
     let traces = read_array2_from_npy_file::(&dir_l)?;
     let plaintext = read_array2_from_npy_file::(&dir_p)?;
     let len_traces = 20000; //traces.shape()[0];
-    let mut dpa_proc = DpaProcessor::new(size, guess_range, selection_function);
+    let mut dpa_proc = DpaProcessor::new(size, guess_range);
     for i in (0..len_traces).progress() {
         let tmp_trace = traces
             .row(i)
             .slice(s![start_sample..end_sample])
            .mapv(|t| t as f32);
         let tmp_metadata = plaintext.row(i);
-        dpa_proc.update(tmp_trace.view(), tmp_metadata.to_owned());
+        dpa_proc.update(
+            tmp_trace.view(),
+            tmp_metadata.to_owned(),
+            selection_function,
+        );
     }
     let dpa = dpa_proc.finalize();
     println!("Guessed key = {:02x}", dpa.best_guess());
@@ -53,7 +57,7 @@ fn dpa_success() -> Result<()> {
     let traces = read_array2_from_npy_file::(&dir_l)?;
     let plaintext = read_array2_from_npy_file::(&dir_p)?;
     let len_traces = traces.shape()[0];
-    let mut dpa_proc = DpaProcessor::new(size, guess_range, selection_function);
+    let mut dpa_proc = DpaProcessor::new(size, guess_range);
     let rank_traces: usize = 100;
 
     let mut rank = Array1::zeros(guess_range);
@@ -63,7 +67,7 @@ fn dpa_success() -> Result<()> {
             .slice(s![start_sample..end_sample])
             .mapv(|t| t as f32);
         let tmp_metadata = plaintext.row(i).to_owned();
-        dpa_proc.update(tmp_trace.view(), tmp_metadata);
+        dpa_proc.update(tmp_trace.view(), tmp_metadata, selection_function);
 
         if i % rank_traces == 0 {
             // rank can be saved to get its evolution
@@ -102,18 +106,15 @@ fn dpa_parallel() -> Result<()> {
                     .slice(s![range_rows..range_rows + batch, ..])
                     .to_owned();
 
-                let mut dpa_inner = DpaProcessor::new(size, guess_range, selection_function);
+                let mut dpa_inner = DpaProcessor::new(size, guess_range);
                 for i in 0..batch {
                     let trace = tmp_traces.row(i);
                     let metadata = tmp_metadata.row(i).to_owned();
-                    dpa_inner.update(trace, metadata);
+                    dpa_inner.update(trace, metadata, selection_function);
                 }
 
                 dpa_inner
             })
-            .reduce(
-                || DpaProcessor::new(size, guess_range, selection_function),
-                |x, y| x + y,
-            )
+            .reduce(|| DpaProcessor::new(size, guess_range), |x, y| x + y)
            .finalize();
 
     println!("{:2x}", dpa.best_guess());
diff --git a/examples/rank.rs b/examples/rank.rs
index 3429b6e..f67ef1d 100644
--- a/examples/rank.rs
+++ b/examples/rank.rs
@@ -22,7 +22,7 @@ fn rank() -> Result<()> {
     let folder = String::from("../../data");
     let nfiles = 5;
     let batch_size = 3000;
-    let mut rank = CpaProcessor::new(size, guess_range, target_byte, leakage_model);
+    let mut rank = CpaProcessor::new(size, guess_range, target_byte);
     for file in (0..nfiles).progress_with(progress_bar(nfiles)) {
         let dir_l = format!("{folder}/l{file}.npy");
         let dir_p = format!("{folder}/p{file}.npy");
@@ -37,24 +37,25 @@ fn rank() -> Result<()> {
             let x = (0..batch_size)
                 .par_bridge()
                 .fold(
-                    || CpaProcessor::new(size, guess_range, target_byte, leakage_model),
+                    || CpaProcessor::new(size, guess_range, target_byte),
                     |mut r, n| {
                         r.update(
                             l_sample.row(n).map(|l| *l as usize).view(),
                             p_sample.row(n).map(|p| *p as usize).view(),
+                            leakage_model,
                         );
                         r
                     },
                 )
                 .reduce(
-                    || CpaProcessor::new(size, guess_range, target_byte, leakage_model),
+                    || CpaProcessor::new(size, guess_range, target_byte),
                     |lhs, rhs| lhs + rhs,
                 );
             rank = rank + x;
         }
     }
 
-    let rank = rank.finalize();
+    let rank = rank.finalize(leakage_model);
 
     // save rank key curves in npy
     save_array("../results/rank.npy", &rank.rank().map(|&x| x as u64))?;
diff --git a/src/distinguishers/cpa.rs b/src/distinguishers/cpa.rs
index 9ca331d..aa7dfad 100644
--- a/src/distinguishers/cpa.rs
+++ b/src/distinguishers/cpa.rs
@@ -53,7 +53,7 @@ pub fn cpa(
     plaintexts: ArrayView2,
     guess_range: usize,
     target_byte: usize,
-    leakage_func: F,
+    leakage_model: F,
     batch_size: usize,
 ) -> Cpa
 where
@@ -70,10 +70,10 @@ where
     )
     .par_bridge()
     .fold(
-        || CpaProcessor::new(traces.shape()[1], guess_range, target_byte, leakage_func),
+        || CpaProcessor::new(traces.shape()[1], guess_range, target_byte),
         |mut cpa, (trace_batch, plaintext_batch)| {
             for i in 0..trace_batch.shape()[0] {
-                cpa.update(trace_batch.row(i), plaintext_batch.row(i));
+                cpa.update(trace_batch.row(i), plaintext_batch.row(i), leakage_model);
             }
 
             cpa
@@ -81,7 +81,7 @@
     )
     .reduce_with(|a, b| a + b)
     .unwrap()
-    .finalize()
+    .finalize(leakage_model)
 }
 
 /// Result of the CPA[^1] on some traces.
@@ -122,11 +122,8 @@ impl Cpa {
 /// It implements algorithm 4 from [^1].
 ///
 /// [^1]:
-#[derive(Debug, PartialEq)]
-pub struct CpaProcessor
-where
-    F: Fn(usize, usize) -> usize,
-{
+#[derive(Debug, PartialEq, Serialize, Deserialize)]
+pub struct CpaProcessor {
     /// Number of samples per trace
     num_samples: usize,
     /// Target byte index in a block
@@ -144,22 +141,12 @@
     /// Sum of traces per plaintext used
     /// See 4.3 in
     plaintext_sum_traces: Array2,
-    /// Leakage model
-    leakage_func: F,
     /// Number of traces processed
     num_traces: usize,
 }
 
-impl CpaProcessor
-where
-    F: Fn(usize, usize) -> usize + Sync,
-{
-    pub fn new(
-        num_samples: usize,
-        guess_range: usize,
-        target_byte: usize,
-        leakage_func: F,
-    ) -> Self {
+impl CpaProcessor {
+    pub fn new(num_samples: usize, guess_range: usize, target_byte: usize) -> Self {
         Self {
             num_samples,
             target_byte,
@@ -169,16 +156,16 @@
             guess_sum_traces: Array1::zeros(guess_range),
             guess_sum_squares_traces: Array1::zeros(guess_range),
             plaintext_sum_traces: Array2::zeros((guess_range, num_samples)),
-            leakage_func,
             num_traces: 0,
         }
     }
 
     /// # Panics
     /// Panic in debug if `trace.shape()[0] != self.num_samples`.
-    pub fn update(&mut self, trace: ArrayView1, plaintext: ArrayView1)
+    pub fn update(&mut self, trace: ArrayView1, plaintext: ArrayView1, leakage_model: F)
     where
         T: Into + Copy,
+        F: Fn(usize, usize) -> usize,
     {
         debug_assert_eq!(trace.shape()[0], self.num_samples);
@@ -191,7 +178,7 @@
         }
 
         for guess in 0..self.guess_range {
-            let value = (self.leakage_func)(plaintext[self.target_byte].into(), guess);
+            let value = leakage_model(plaintext[self.target_byte].into(), guess);
             self.guess_sum_traces[guess] += value;
             self.guess_sum_squares_traces[guess] += value * value;
         }
@@ -200,13 +187,16 @@
     }
 
     /// Finalize the calculation after feeding the overall traces.
-    pub fn finalize(&self) -> Cpa {
+    pub fn finalize(&self, leakage_model: F) -> Cpa
+    where
+        F: Fn(usize, usize) -> usize,
+    {
         let mut modeled_leakages = Array1::zeros(self.guess_range);
         let mut corr = Array2::zeros((self.guess_range, self.num_samples));
 
         for guess in 0..self.guess_range {
             for u in 0..self.guess_range {
-                modeled_leakages[u] = (self.leakage_func)(u, guess);
+                modeled_leakages[u] = leakage_model(u, guess);
             }
 
             let mean_key = self.guess_sum_traces[guess] as f32 / self.num_traces as f32;
@@ -252,7 +242,7 @@
     /// change between versions.
     pub fn save>(&self, path: P) -> Result<(), Error> {
         let file = File::create(path)?;
-        serde_json::to_writer(file, &CpaProcessorSerdeAdapter::from(self))?;
+        serde_json::to_writer(file, self)?;
 
         Ok(())
     }
@@ -262,19 +252,16 @@
     /// # Warning
     /// The file format is not stable as muscat is active development. Thus, the format might
     /// change between versions.
-    pub fn load>(path: P, leakage_func: F) -> Result {
+    pub fn load>(path: P) -> Result {
         let file = File::open(path)?;
-        let p: CpaProcessorSerdeAdapter = serde_json::from_reader(file)?;
+        let p: CpaProcessor = serde_json::from_reader(file)?;
 
-        Ok(p.with(leakage_func))
+        Ok(p)
     }
 
     /// Determine if two [`CpaProcessor`] are compatible for addition.
     ///
     /// If they were created with the same parameters, they are compatible.
-    ///
-    /// Note: [`CpaProcessor::leakage_func`] cannot be checked for equality, but they must have the
-    /// same leakage functions in order to be compatible.
     fn is_compatible_with(&self, other: &Self) -> bool {
         self.num_samples == other.num_samples
             && self.target_byte == other.target_byte
@@ -282,10 +269,7 @@ where
     }
 }
 
-impl Add for CpaProcessor
-where
-    F: Fn(usize, usize) -> usize + Sync,
-{
+impl Add for CpaProcessor {
     type Output = Self;
 
     /// Merge computations of two [`CpaProcessor`]. Processors need to be compatible to be merged
@@ -306,73 +290,14 @@
             guess_sum_traces: self.guess_sum_traces + rhs.guess_sum_traces,
             guess_sum_squares_traces: self.guess_sum_squares_traces + rhs.guess_sum_squares_traces,
             plaintext_sum_traces: self.plaintext_sum_traces + rhs.plaintext_sum_traces,
-            leakage_func: self.leakage_func,
             num_traces: self.num_traces + rhs.num_traces,
         }
     }
 }
 
-/// This type implements [`Deserialize`] on the subset of fields of [`CpaProcessor`] that are
-/// serializable.
-///
-/// [`CpaProcessor`] cannot implement [`Deserialize`] for every type `F` as it does not
-/// implement [`Default`]. One solution would be to erase the type, but that would add an
-/// indirection which could hurt the performance (not benchmarked though).
-#[derive(Serialize, Deserialize)]
-pub struct CpaProcessorSerdeAdapter {
-    num_samples: usize,
-    target_byte: usize,
-    guess_range: usize,
-    sum_traces: Array1,
-    sum_square_traces: Array1,
-    guess_sum_traces: Array1,
-    guess_sum_squares_traces: Array1,
-    plaintext_sum_traces: Array2,
-    num_traces: usize,
-}
-
-impl CpaProcessorSerdeAdapter {
-    pub fn with(self, leakage_func: F) -> CpaProcessor
-    where
-        F: Fn(usize, usize) -> usize,
-    {
-        CpaProcessor {
-            num_samples: self.num_samples,
-            target_byte: self.target_byte,
-            guess_range: self.guess_range,
-            sum_traces: self.sum_traces,
-            sum_square_traces: self.sum_square_traces,
-            guess_sum_traces: self.guess_sum_traces,
-            guess_sum_squares_traces: self.guess_sum_squares_traces,
-            plaintext_sum_traces: self.plaintext_sum_traces,
-            leakage_func,
-            num_traces: self.num_traces,
-        }
-    }
-}
-
-impl From<&CpaProcessor> for CpaProcessorSerdeAdapter
-where
-    F: Fn(usize, usize) -> usize,
-{
-    fn from(processor: &CpaProcessor) -> Self {
-        Self {
-            num_samples: processor.num_samples,
-            target_byte: processor.target_byte,
-            guess_range: processor.guess_range,
-            sum_traces: processor.sum_traces.clone(),
-            sum_square_traces: processor.sum_square_traces.clone(),
-            guess_sum_traces: processor.guess_sum_traces.clone(),
-            guess_sum_squares_traces: processor.guess_sum_squares_traces.clone(),
-            plaintext_sum_traces: processor.plaintext_sum_traces.clone(),
-            num_traces: processor.num_traces,
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
-    use super::{cpa, CpaProcessor, CpaProcessorSerdeAdapter};
+    use super::{cpa, CpaProcessor};
     use ndarray::array;
     use serde::Deserialize;
@@ -393,12 +318,12 @@
         let plaintexts = array![[1usize], [3], [1], [2], [3], [2], [2], [1], [3], [1]];
         let leakage_model = |value, guess| value ^ guess;
 
-        let mut processor = CpaProcessor::new(traces.shape()[1], 256, 0, leakage_model);
+        let mut processor = CpaProcessor::new(traces.shape()[1], 256, 0);
         for i in 0..traces.shape()[0] {
-            processor.update(traces.row(i), plaintexts.row(i));
+            processor.update(traces.row(i), plaintexts.row(i), leakage_model);
         }
         assert_eq!(
-            processor.finalize().corr(),
+            processor.finalize(leakage_model).corr(),
             cpa(traces.view(), plaintexts.view(), 256, 0, leakage_model, 2).corr()
         );
     }
@@ -420,17 +345,14 @@
         let plaintexts = array![[1usize], [3], [1], [2], [3], [2], [2], [1], [3], [1]];
         let leakage_model = |value, guess| value ^ guess;
 
-        let mut processor = CpaProcessor::new(traces.shape()[1], 256, 0, leakage_model);
+        let mut processor = CpaProcessor::new(traces.shape()[1], 256, 0);
         for i in 0..traces.shape()[0] {
-            processor.update(traces.row(i), plaintexts.row(i));
+            processor.update(traces.row(i), plaintexts.row(i), leakage_model);
         }
 
-        let serialized =
-            serde_json::to_string(&CpaProcessorSerdeAdapter::from(&processor)).unwrap();
+        let serialized = serde_json::to_string(&processor).unwrap();
         let mut deserializer = serde_json::Deserializer::from_str(serialized.as_str());
-        let restored_processor = CpaProcessorSerdeAdapter::deserialize(&mut deserializer)
-            .unwrap()
-            .with(leakage_model);
+        let restored_processor = CpaProcessor::deserialize(&mut deserializer).unwrap();
 
         assert_eq!(processor.num_samples, restored_processor.num_samples);
         assert_eq!(processor.target_byte, restored_processor.target_byte);
diff --git a/src/distinguishers/cpa_normal.rs b/src/distinguishers/cpa_normal.rs
index 379ff78..9a19ca2 100644
--- a/src/distinguishers/cpa_normal.rs
+++ b/src/distinguishers/cpa_normal.rs
@@ -47,7 +47,7 @@ pub fn cpa(
     traces: ArrayView2,
     plaintexts: ArrayView2,
     guess_range: usize,
-    leakage_func: F,
+    leakage_model: F,
     batch_size: usize,
 ) -> Cpa
 where
@@ -64,9 +64,9 @@ where
     )
     .par_bridge()
     .fold(
-        || CpaProcessor::new(traces.shape()[1], batch_size, guess_range, leakage_func),
+        || CpaProcessor::new(traces.shape()[1], batch_size, guess_range),
         |mut cpa, (trace_batch, plaintext_batch)| {
-            cpa.update(trace_batch, plaintext_batch);
+            cpa.update(trace_batch, plaintext_batch, leakage_model);
 
             cpa
         },
@@ -79,10 +79,8 @@
 /// A processor that computes the [`Cpa`] of the given traces.
 ///
 /// [^1]:
-pub struct CpaProcessor
-where
-    F: Fn(ArrayView1, usize) -> usize,
-{
+#[derive(Serialize, Deserialize)]
+pub struct CpaProcessor {
     /// Number of samples per trace
     num_samples: usize,
     /// Guess range upper excluded bound
@@ -99,17 +97,12 @@
     cov: Array2,
     /// Batch size
     batch_size: usize,
-    /// Leakage model
-    leakage_func: F,
     /// Number of traces processed
     num_traces: usize,
 }
 
-impl CpaProcessor
-where
-    F: Fn(ArrayView1, usize) -> usize,
-{
-    pub fn new(num_samples: usize, batch_size: usize, guess_range: usize, leakage_func: F) -> Self {
+impl CpaProcessor {
+    pub fn new(num_samples: usize, batch_size: usize, guess_range: usize) -> Self {
         Self {
             num_samples,
             guess_range,
@@ -120,7 +113,6 @@
             values: Array2::zeros((batch_size, guess_range)),
             cov: Array2::zeros((guess_range, num_samples)),
             batch_size,
-            leakage_func,
             num_traces: 0,
         }
     }
@@ -128,10 +120,15 @@
     /// # Panics
     /// - Panic in debug if `trace_batch.shape()[0] != plaintext_batch.shape()[0]`.
     /// - Panic in debug if `trace_batch.shape()[1] != self.num_samples`.
-    pub fn update(&mut self, trace_batch: ArrayView2, plaintext_batch: ArrayView2)
-    where
+    pub fn update(
+        &mut self,
+        trace_batch: ArrayView2,
+        plaintext_batch: ArrayView2,
+        leakage_model: F,
+    ) where
         T: Into + Copy,
         U: Into + Copy,
+        F: Fn(ArrayView1, usize) -> usize,
     {
         debug_assert_eq!(trace_batch.shape()[0], plaintext_batch.shape()[0]);
         debug_assert_eq!(trace_batch.shape()[1], self.num_samples);
@@ -141,23 +138,31 @@
         let trace_batch = trace_batch.mapv(|t| t.into());
         let plaintext_batch = plaintext_batch.mapv(|m| m.into());
 
-        self.update_values(trace_batch.view(), plaintext_batch.view(), self.guess_range);
+        self.update_values(
+            trace_batch.view(),
+            plaintext_batch.view(),
+            self.guess_range,
+            leakage_model,
+        );
         self.update_key_leakages(trace_batch.view(), self.guess_range);
 
         self.num_traces += self.batch_size;
     }
 
-    fn update_values(
+    fn update_values(
         /* This function generates the values and cov arrays */
         &mut self,
         trace: ArrayView2,
         metadata: ArrayView2,
         guess_range: usize,
-    ) {
+        leakage_model: F,
+    ) where
+        F: Fn(ArrayView1, usize) -> usize,
+    {
         for row in 0..self.batch_size {
             for guess in 0..guess_range {
                 let pass_to_leakage = metadata.row(row);
-                self.values[[row, guess]] = (self.leakage_func)(pass_to_leakage, guess) as f32;
+                self.values[[row, guess]] = leakage_model(pass_to_leakage, guess) as f32;
             }
         }
@@ -210,7 +215,7 @@
     /// change between versions.
     pub fn save>(&self, path: P) -> Result<(), Error> {
         let file = File::create(path)?;
-        serde_json::to_writer(file, &CpaProcessorSerdeAdapter::from(self))?;
+        serde_json::to_writer(file, self)?;
 
         Ok(())
     }
@@ -220,19 +225,16 @@
     /// # Warning
     /// The file format is not stable as muscat is active development. Thus, the format might
     /// change between versions.
-    pub fn load>(path: P, leakage_func: F) -> Result {
+    pub fn load>(path: P) -> Result {
         let file = File::open(path)?;
-        let p: CpaProcessorSerdeAdapter = serde_json::from_reader(file)?;
+        let p: CpaProcessor = serde_json::from_reader(file)?;
 
-        Ok(p.with(leakage_func))
+        Ok(p)
     }
 
     /// Determine if two [`CpaProcessor`] are compatible for addition.
     ///
     /// If they were created with the same parameters, they are compatible.
-    ///
-    /// Note: [`CpaProcessor::leakage_func`] cannot be checked for equality, but they must have the
-    /// same leakage functions in order to be compatible.
     fn is_compatible_with(&self, other: &Self) -> bool {
         self.num_samples == other.num_samples
             && self.batch_size == other.batch_size
@@ -240,10 +242,7 @@
     }
 }
 
-impl Add for CpaProcessor
-where
-    F: Fn(ArrayView1, usize) -> usize,
-{
+impl Add for CpaProcessor {
     type Output = Self;
 
     /// Merge computations of two [`CpaProcessor`]. Processors need to be compatible to be merged
@@ -265,78 +264,16 @@
             values: self.values + rhs.values,
             cov: self.cov + rhs.cov,
             batch_size: self.batch_size,
-            leakage_func: self.leakage_func,
             num_traces: self.num_traces + rhs.num_traces,
         }
     }
 }
 
-/// This type implements [`Deserialize`] on the subset of fields of [`CpaProcessor`] that are
-/// serializable.
-///
-/// [`CpaProcessor`] cannot implement [`Deserialize`] for every type `F` as it does not
-/// implement [`Default`]. One solution would be to erase the type, but that would add an
-/// indirection which could hurt the performance (not benchmarked though).
-#[derive(Serialize, Deserialize)]
-pub struct CpaProcessorSerdeAdapter {
-    num_samples: usize,
-    guess_range: usize,
-    sum_traces: Array1,
-    sum_traces2: Array1,
-    guess_sum_traces: Array1,
-    guess_sum_traces2: Array1,
-    values: Array2,
-    cov: Array2,
-    batch_size: usize,
-    num_traces: usize,
-}
-
-impl CpaProcessorSerdeAdapter {
-    pub fn with(self, leakage_func: F) -> CpaProcessor
-    where
-        F: Fn(ArrayView1, usize) -> usize,
-    {
-        CpaProcessor {
-            num_samples: self.num_samples,
-            guess_range: self.guess_range,
-            sum_traces: self.sum_traces,
-            sum_traces2: self.sum_traces2,
-            guess_sum_traces: self.guess_sum_traces,
-            guess_sum_traces2: self.guess_sum_traces2,
-            values: self.values,
-            cov: self.cov,
-            batch_size: self.batch_size,
-            leakage_func,
-            num_traces: self.num_traces,
-        }
-    }
-}
-
-impl From<&CpaProcessor> for CpaProcessorSerdeAdapter
-where
-    F: Fn(ArrayView1, usize) -> usize,
-{
-    fn from(processor: &CpaProcessor) -> Self {
-        Self {
-            num_samples: processor.num_samples,
-            guess_range: processor.guess_range,
-            sum_traces: processor.sum_traces.clone(),
-            sum_traces2: processor.sum_traces2.clone(),
-            guess_sum_traces: processor.guess_sum_traces.clone(),
-            guess_sum_traces2: processor.guess_sum_traces2.clone(),
-            values: processor.values.clone(),
-            cov: processor.cov.clone(),
-            batch_size: processor.batch_size,
-            num_traces: processor.num_traces,
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use std::iter::zip;
 
-    use super::{cpa, CpaProcessor, CpaProcessorSerdeAdapter};
+    use super::{cpa, CpaProcessor};
     use ndarray::{array, ArrayView1, Axis};
     use serde::Deserialize;
@@ -357,12 +294,16 @@
         let plaintexts = array![[1usize], [3], [1], [2], [3], [2], [2], [1], [3], [1]];
         let leakage_model = |plaintext: ArrayView1, guess| plaintext[0] ^ guess;
 
-        let mut processor = CpaProcessor::new(traces.shape()[1], 1, 256, leakage_model);
+        let mut processor = CpaProcessor::new(traces.shape()[1], 1, 256);
         for (trace, plaintext) in zip(
             traces.axis_chunks_iter(Axis(0), 1),
             plaintexts.axis_chunks_iter(Axis(0), 1),
         ) {
-            processor.update(trace.map(|&x| x as f32).view(), plaintext.view());
+            processor.update(
+                trace.map(|&x| x as f32).view(),
+                plaintext.view(),
+                leakage_model,
+            );
         }
         assert_eq!(
             processor.finalize().corr(),
@@ -394,20 +335,22 @@
         let plaintexts = array![[1usize], [3], [1], [2], [3], [2], [2], [1], [3], [1]];
         let leakage_model = |plaintext: ArrayView1, guess| plaintext[0] ^ guess;
 
-        let mut processor = CpaProcessor::new(traces.shape()[1], 1, 256, leakage_model);
+        let mut processor = CpaProcessor::new(traces.shape()[1], 1, 256);
         for (trace, plaintext) in zip(
             traces.axis_chunks_iter(Axis(0), 1),
             plaintexts.axis_chunks_iter(Axis(0), 1),
         ) {
-            processor.update(trace.map(|&x| x as f32).view(), plaintext.view());
+            processor.update(
+                trace.map(|&x| x as f32).view(),
+                plaintext.view(),
+                leakage_model,
+            );
         }
 
-        let serialized =
-            serde_json::to_string(&CpaProcessorSerdeAdapter::from(&processor)).unwrap();
-        let mut deserializer = serde_json::Deserializer::from_str(serialized.as_str());
-        let restored_processor = CpaProcessorSerdeAdapter::deserialize(&mut deserializer)
-            .unwrap()
-            .with(leakage_model);
+        let serialized = serde_json::to_string(&processor).unwrap();
+        let mut deserializer: serde_json::Deserializer> =
+            serde_json::Deserializer::from_str(serialized.as_str());
+        let restored_processor = CpaProcessor::deserialize(&mut deserializer).unwrap();
 
         assert_eq!(processor.num_samples, restored_processor.num_samples);
         assert_eq!(processor.guess_range, restored_processor.guess_range);
diff --git a/src/distinguishers/dpa.rs b/src/distinguishers/dpa.rs
index 32cf54a..21a115f 100644
--- a/src/distinguishers/dpa.rs
+++ b/src/distinguishers/dpa.rs
@@ -76,10 +76,14 @@ where
     )
     .par_bridge()
     .fold(
-        || DpaProcessor::new(traces.shape()[1], guess_range, selection_function),
+        || DpaProcessor::new(traces.shape()[1], guess_range),
         |mut dpa, (trace_batch, metadata_batch)| {
             for i in 0..trace_batch.shape()[0] {
-                dpa.update(trace_batch.row(i), metadata_batch[i].clone());
+                dpa.update(
+                    trace_batch.row(i),
+                    metadata_batch[i].clone(),
+                    selection_function,
+                );
             }
 
             dpa
@@ -126,10 +130,8 @@ impl Dpa {
 ///
 /// [^1]:
 /// [^2]:
-pub struct DpaProcessor
-where
-    F: Fn(M, usize) -> bool,
-{
+#[derive(Serialize, Deserialize)]
+pub struct DpaProcessor {
     /// Number of samples per trace
     num_samples: usize,
     /// Guess range upper excluded bound
@@ -142,18 +144,16 @@ where
     count_0: Array1,
     /// Number of traces processed for which the selection function equals true
     count_1: Array1,
-    selection_function: F,
     /// Number of traces processed
     num_traces: usize,
     _metadata: PhantomData,
 }
 
-impl DpaProcessor
+impl DpaProcessor
 where
     M: Clone,
-    F: Fn(M, usize) -> bool,
 {
-    pub fn new(num_samples: usize, guess_range: usize, selection_function: F) -> Self {
+    pub fn new(num_samples: usize, guess_range: usize) -> Self {
         Self {
             num_samples,
             guess_range,
@@ -161,7 +161,6 @@ where
             sum_1: Array2::zeros((guess_range, num_samples)),
             count_0: Array1::zeros(guess_range),
             count_1: Array1::zeros(guess_range),
-            selection_function,
             num_traces: 0,
             _metadata: PhantomData,
         }
     }
@@ -169,14 +168,15 @@ where
     /// # Panics
     /// Panic in debug if `trace.shape()[0] != self.num_samples`.
-    pub fn update(&mut self, trace: ArrayView1, metadata: M)
+    pub fn update(&mut self, trace: ArrayView1, metadata: M, selection_function: F)
     where
         T: Into + Copy,
+        F: Fn(M, usize) -> bool,
     {
         debug_assert_eq!(trace.shape()[0], self.num_samples);
 
         for guess in 0..self.guess_range {
-            if (self.selection_function)(metadata.clone(), guess) {
+            if selection_function(metadata.clone(), guess) {
                 for i in 0..self.num_samples {
                     self.sum_1[[guess, i]] += trace[i].into();
                 }
@@ -216,7 +216,7 @@ where
     /// change between versions.
     pub fn save>(&self, path: P) -> Result<(), Error> {
         let file = File::create(path)?;
-        serde_json::to_writer(file, &DpaProcessorSerdeAdapter::from(self))?;
+        serde_json::to_writer(file, self)?;
 
         Ok(())
     }
@@ -226,27 +226,23 @@ where
     /// # Warning
     /// The file format is not stable as muscat is active development. Thus, the format might
     /// change between versions.
-    pub fn load>(path: P, selection_function: F) -> Result {
+    pub fn load>(path: P) -> Result {
         let file = File::open(path)?;
-        let p: DpaProcessorSerdeAdapter = serde_json::from_reader(file)?;
+        let p: DpaProcessor = serde_json::from_reader(file)?;
 
-        Ok(p.with(selection_function))
+        Ok(p)
     }
 
     /// Determine if two [`DpaProcessor`] are compatible for addition.
     ///
     /// If they were created with the same parameters, they are compatible.
-    ///
-    /// Note: [`DpaProcessor::selection_function`] cannot be checked for equality, but they must
-    /// have the same selection functions in order to be compatible.
     fn is_compatible_with(&self, other: &Self) -> bool {
         self.num_samples == other.num_samples && self.guess_range == other.guess_range
     }
 }
 
-impl Add for DpaProcessor
+impl Add for DpaProcessor
 where
-    F: Fn(M, usize) -> bool,
     M: Clone,
 {
     type Output = Self;
@@ -267,69 +263,15 @@ where
             sum_1: self.sum_1 + rhs.sum_1,
             count_0: self.count_0 + rhs.count_0,
             count_1: self.count_1 + rhs.count_1,
-            selection_function: self.selection_function,
             num_traces: self.num_traces + rhs.num_traces,
             _metadata: PhantomData,
         }
     }
 }
 
-/// This type implements [`Deserialize`] on the subset of fields of [`DpaProcessor`] that are
-/// serializable.
-///
-/// [`DpaProcessor`] cannot implement [`Deserialize`] for every type `F` as it does not
-/// implement [`Default`]. One solution would be to erase the type, but that would add an
-/// indirection which could hurt the performance (not benchmarked though).
-#[derive(Serialize, Deserialize)]
-pub struct DpaProcessorSerdeAdapter {
-    num_samples: usize,
-    guess_range: usize,
-    sum_0: Array2,
-    sum_1: Array2,
-    count_0: Array1,
-    count_1: Array1,
-    num_traces: usize,
-}
-
-impl DpaProcessorSerdeAdapter {
-    pub fn with(self, selection_function: F) -> DpaProcessor
-    where
-        F: Fn(M, usize) -> bool,
-    {
-        DpaProcessor {
-            num_samples: self.num_samples,
-            guess_range: self.guess_range,
-            sum_0: self.sum_0,
-            sum_1: self.sum_1,
-            count_0: self.count_0,
-            count_1: self.count_1,
-            selection_function,
-            num_traces: self.num_traces,
-            _metadata: PhantomData,
-        }
-    }
-}
-
-impl From<&DpaProcessor> for DpaProcessorSerdeAdapter
-where
-    F: Fn(M, usize) -> bool,
-{
-    fn from(processor: &DpaProcessor) -> Self {
-        Self {
-            num_samples: processor.num_samples,
-            guess_range: processor.guess_range,
-            sum_0: processor.sum_0.clone(),
-            sum_1: processor.sum_1.clone(),
-            count_0: processor.count_0.clone(),
-            count_1: processor.count_1.clone(),
-            num_traces: processor.num_traces,
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
-    use super::{dpa, DpaProcessor, DpaProcessorSerdeAdapter};
+    use super::{dpa, DpaProcessor};
     use ndarray::{array, Array1, ArrayView1};
     use serde::Deserialize;
@@ -351,9 +293,13 @@
         let selection_function =
             |plaintext: ArrayView1, guess| (plaintext[0] as usize ^ guess) & 1 == 1;
 
-        let mut processor = DpaProcessor::new(traces.shape()[1], 256, selection_function);
+        let mut processor = DpaProcessor::new(traces.shape()[1], 256);
         for i in 0..traces.shape()[0] {
-            processor.update(traces.row(i).map(|&x| x as f32).view(), plaintexts.row(i));
+            processor.update(
+                traces.row(i).map(|&x| x as f32).view(),
+                plaintexts.row(i),
+                selection_function,
+            );
         }
         assert_eq!(
             processor.finalize().differential_curves(),
@@ -390,17 +336,19 @@
         let selection_function =
             |plaintext: ArrayView1, guess| (plaintext[0] as usize ^ guess) & 1 == 1;
 
-        let mut processor = DpaProcessor::new(traces.shape()[1], 256, selection_function);
+        let mut processor = DpaProcessor::new(traces.shape()[1], 256);
         for i in 0..traces.shape()[0] {
-            processor.update(traces.row(i).map(|&x| x as f32).view(), plaintexts.row(i));
+            processor.update(
+                traces.row(i).map(|&x| x as f32).view(),
+                plaintexts.row(i),
+                selection_function,
+            );
         }
 
-        let serialized =
-            serde_json::to_string(&DpaProcessorSerdeAdapter::from(&processor)).unwrap();
+        let serialized = serde_json::to_string(&processor).unwrap();
         let mut deserializer = serde_json::Deserializer::from_str(serialized.as_str());
-        let restored_processor = DpaProcessorSerdeAdapter::deserialize(&mut deserializer)
-            .unwrap()
-            .with(selection_function);
+        let restored_processor =
+            DpaProcessor::>::deserialize(&mut deserializer).unwrap();
 
         assert_eq!(processor.num_samples, restored_processor.num_samples);
         assert_eq!(processor.guess_range, restored_processor.guess_range);
diff --git a/src/leakage_detection.rs b/src/leakage_detection.rs
index eec26fa..4cbe45d 100644
--- a/src/leakage_detection.rs
+++ b/src/leakage_detection.rs
@@ -105,6 +105,7 @@ impl SnrProcessor {
     /// - Panics in debug if the length of the trace is different from the size of [`SnrProcessor`].
     pub fn process + Copy>(&mut self, trace: ArrayView1, class: usize) {
         debug_assert!(trace.len() == self.size());
+        debug_assert!(class < self.num_classes());
 
         self.mean_var.process(trace);