diff --git a/benches/dpa.rs b/benches/dpa.rs index f5620fe..2fa2b10 100644 --- a/benches/dpa.rs +++ b/benches/dpa.rs @@ -1,12 +1,12 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use muscat::distinguishers::dpa::{dpa, Dpa, DpaProcessor}; use muscat::leakage::sbox; -use ndarray::{Array1, Array2}; +use ndarray::{Array1, Array2, ArrayView1}; use ndarray_rand::rand::{rngs::StdRng, SeedableRng}; use ndarray_rand::rand_distr::Uniform; use ndarray_rand::RandomExt; -fn selection_function(metadata: Array1, guess: usize) -> bool { +fn selection_function(metadata: ArrayView1, guess: usize) -> bool { usize::from(sbox(metadata[1] ^ guess as u8)) & 1 == 1 } @@ -14,7 +14,7 @@ fn dpa_sequential(traces: &Array2, plaintexts: &Array2) -> Dpa { let mut dpa = DpaProcessor::new(traces.shape()[1], 256, selection_function); for i in 0..traces.shape()[0] { - dpa.update(traces.row(i), plaintexts.row(i).to_owned()); + dpa.update(traces.row(i), plaintexts.row(i)); } dpa.finalize() @@ -26,8 +26,7 @@ fn dpa_parallel(traces: &Array2, plaintexts: &Array2) -> Dpa { plaintexts .rows() .into_iter() - .map(|x| x.to_owned()) - .collect::>>() + .collect::>>() .view(), 256, selection_function, diff --git a/src/distinguishers/cpa.rs b/src/distinguishers/cpa.rs index fb57f9a..92752a9 100644 --- a/src/distinguishers/cpa.rs +++ b/src/distinguishers/cpa.rs @@ -282,3 +282,36 @@ where } } } + +#[cfg(test)] +mod tests { + use super::{cpa, CpaProcessor}; + use ndarray::array; + + #[test] + fn test_cpa_helper() { + let traces = array![ + [77usize, 137, 51, 91], + [72, 61, 91, 83], + [39, 49, 52, 23], + [26, 114, 63, 45], + [30, 8, 97, 91], + [13, 68, 7, 45], + [17, 181, 60, 34], + [43, 88, 76, 78], + [0, 36, 35, 0], + [93, 191, 49, 26], + ]; + let plaintexts = array![[1usize], [3], [1], [2], [3], [2], [2], [1], [3], [1]]; + + let leakage_model = |value, guess| value ^ guess; + let mut processor = CpaProcessor::new(traces.shape()[1], 256, 0, leakage_model); + for i in 0..traces.shape()[0] { + processor.update(traces.row(i), plaintexts.row(i)); + } + assert_eq!( + processor.finalize().corr(), + cpa(traces.view(), plaintexts.view(), 256, 0, leakage_model, 2).corr() + ); + } +} diff --git a/src/distinguishers/cpa_normal.rs b/src/distinguishers/cpa_normal.rs index b77728c..0b08588 100644 --- a/src/distinguishers/cpa_normal.rs +++ b/src/distinguishers/cpa_normal.rs @@ -245,3 +245,48 @@ where } } } + +#[cfg(test)] +mod tests { + use std::iter::zip; + + use super::{cpa, CpaProcessor}; + use ndarray::{array, ArrayView1, Axis}; + + #[test] + fn test_cpa_helper() { + let traces = array![ + [77usize, 137, 51, 91], + [72, 61, 91, 83], + [39, 49, 52, 23], + [26, 114, 63, 45], + [30, 8, 97, 91], + [13, 68, 7, 45], + [17, 181, 60, 34], + [43, 88, 76, 78], + [0, 36, 35, 0], + [93, 191, 49, 26], + ]; + let plaintexts = array![[1usize], [3], [1], [2], [3], [2], [2], [1], [3], [1]]; + + let leakage_model = |plaintext: ArrayView1, guess| plaintext[0] ^ guess; + let mut processor = CpaProcessor::new(traces.shape()[1], 1, 256, leakage_model); + for (trace, plaintext) in zip( + traces.axis_chunks_iter(Axis(0), 1), + plaintexts.axis_chunks_iter(Axis(0), 1), + ) { + processor.update(trace.map(|&x| x as f32).view(), plaintext.view()); + } + assert_eq!( + processor.finalize().corr(), + cpa( + traces.map(|&x| x as f32).view(), + plaintexts.view(), + 256, + leakage_model, + 2 + ) + .corr() + ); + } +} diff --git a/src/distinguishers/dpa.rs b/src/distinguishers/dpa.rs index c0f899e..1aa305c 100644 --- a/src/distinguishers/dpa.rs +++ b/src/distinguishers/dpa.rs @@ -245,3 +245,48 @@ where } } } + +#[cfg(test)] +mod tests { + use super::{dpa, DpaProcessor}; + use ndarray::{array, Array1, ArrayView1}; + + #[test] + fn test_dpa_helper() { + let traces = array![ + [77usize, 137, 51, 91], + [72, 61, 91, 83], + [39, 49, 52, 23], + [26, 114, 63, 45], + [30, 8, 97, 91], + [13, 68, 7, 45], + [17, 181, 60, 34], + [43, 88, 76, 78], + [0, 36, 35, 0], + [93, 191, 49, 26], + ]; + let plaintexts = array![[1], [3], [1], [2], [3], [2], [2], [1], [3], [1]]; + + let selection_function = + |plaintext: ArrayView1, guess| (plaintext[0] as usize ^ guess) & 1 == 1; + let mut processor = DpaProcessor::new(traces.shape()[1], 256, selection_function); + for i in 0..traces.shape()[0] { + processor.update(traces.row(i).map(|&x| x as f32).view(), plaintexts.row(i)); + } + assert_eq!( + processor.finalize().differential_curves(), + dpa( + traces.view().map(|&x| x as f32).view(), + plaintexts + .rows() + .into_iter() + .collect::>>() + .view(), + 256, + selection_function, + 2 + ) + .differential_curves() + ); + } +} diff --git a/src/leakage_detection.rs b/src/leakage_detection.rs index 8773dac..be54398 100644 --- a/src/leakage_detection.rs +++ b/src/leakage_detection.rs @@ -63,7 +63,7 @@ where || SnrProcessor::new(traces.shape()[1], classes), |mut snr, (batch_idx, trace_batch)| { for i in 0..trace_batch.shape()[0] { - snr.process(trace_batch.row(i), get_class(batch_idx + i)); + snr.process(trace_batch.row(i), get_class(batch_idx * batch_size + i)); } snr }, @@ -317,12 +317,34 @@ impl Add for TTestProcessor { #[cfg(test)] mod tests { - use super::{ttest, TTestProcessor}; + use super::{snr, ttest, SnrProcessor, TTestProcessor}; use ndarray::array; + #[test] + fn test_snr_helper() { + let traces = array![ + [77, 137, 51, 91], + [72, 61, 91, 83], + [39, 49, 52, 23], + [26, 114, 63, 45], + [30, 8, 97, 91], + [13, 68, 7, 45], + [17, 181, 60, 34], + [43, 88, 76, 78], + [0, 36, 35, 0], + [93, 191, 49, 26], + ]; + let classes = [1, 3, 1, 2, 3, 2, 2, 1, 3, 1]; + + let mut processor = SnrProcessor::new(traces.shape()[1], 256); + for (trace, class) in std::iter::zip(traces.rows(), classes.iter()) { + processor.process(trace, *class); + } + assert_eq!(processor.snr(), snr(traces.view(), 256, |i| classes[i], 2)); + } + #[test] fn test_ttest() { - let mut processor = TTestProcessor::new(4); let traces = [ array![77, 137, 51, 91], array![72, 61, 91, 83], @@ -335,9 +357,12 @@ mod tests { array![0, 36, 35, 0], array![93, 191, 49, 26], ]; + + let mut processor = TTestProcessor::new(4); for (i, trace) in traces.iter().enumerate() { processor.process(trace.view(), i % 3 == 0); } + assert_eq!( processor.ttest(), array![ @@ -351,7 +376,6 @@ mod tests { #[test] fn test_ttest_helper() { - let mut processor = TTestProcessor::new(4); let traces = array![ [77, 137, 51, 91], [72, 61, 91, 83], @@ -366,6 +390,8 @@ mod tests { ]; let trace_classes = array![true, false, false, true, false, false, true, false, false, true]; + + let mut processor = TTestProcessor::new(4); for (i, trace) in traces.rows().into_iter().enumerate() { processor.process(trace, trace_classes[i]); }