Skip to content

Commit

Permalink
Use faster fold + reduce_with in helpers
Browse files Browse the repository at this point in the history
From benchmarks it increases performance by around 10% for cpa and
by 30-60% for snr.
  • Loading branch information
TrAyZeN authored and kingofpayne committed May 7, 2024
1 parent 7d63ba6 commit 655dc03
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 42 deletions.
23 changes: 11 additions & 12 deletions src/cpa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,24 @@ where
assert_eq!(leakages.shape()[0], plaintexts.shape()[0]);
assert!(chunk_size > 0);

// From benchmarks fold + reduce_with is faster than map + reduce/reduce_with and fold + reduce
zip(
leakages.axis_chunks_iter(Axis(0), chunk_size),
plaintexts.axis_chunks_iter(Axis(0), chunk_size),
)
.par_bridge()
.map(|(leakages_chunk, plaintexts_chunk)| {
let mut cpa =
CpaProcessor::new(leakages.shape()[1], guess_range, target_byte, leakage_func);

for i in 0..leakages_chunk.shape()[0] {
cpa.update(leakages_chunk.row(i), plaintexts_chunk.row(i));
}

cpa
})
.reduce(
.fold(
|| CpaProcessor::new(leakages.shape()[1], guess_range, target_byte, leakage_func),
|a, b| a + b,
|mut cpa, (leakages_chunk, plaintexts_chunk)| {
for i in 0..leakages_chunk.shape()[0] {
cpa.update(leakages_chunk.row(i), plaintexts_chunk.row(i));
}

cpa
},
)
.reduce_with(|a, b| a + b)
.unwrap()
.finalize()
}

Expand Down
15 changes: 8 additions & 7 deletions src/cpa_normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@ where
plaintexts.axis_chunks_iter(Axis(0), chunk_size),
)
.par_bridge()
.map(|(leakages_chunk, plaintexts_chunk)| {
let mut cpa = CpaProcessor::new(leakages.shape()[1], chunk_size, guess_range, leakage_func);
cpa.update(leakages_chunk, plaintexts_chunk);
cpa
})
.reduce(
.fold(
|| CpaProcessor::new(leakages.shape()[1], chunk_size, guess_range, leakage_func),
|x, y| x + y,
|mut cpa, (leakages_chunk, plaintexts_chunk)| {
cpa.update(leakages_chunk, plaintexts_chunk);

cpa
},
)
.reduce_with(|x, y| x + y)
.unwrap()
.finalize()
}

Expand Down
25 changes: 14 additions & 11 deletions src/dpa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::{iter::zip, ops::Add};

use crate::util::max_per_row;

/// # Panics
/// Panics if `chunk_size` is not strictly positive.
pub fn dpa<M, T>(
leakages: ArrayView2<T>,
metadata: ArrayView1<M>,
Expand All @@ -15,24 +17,25 @@ where
T: Into<f32> + Copy + Sync,
M: Clone + Sync,
{
assert!(chunk_size > 0);

zip(
leakages.axis_chunks_iter(Axis(0), chunk_size),
metadata.axis_chunks_iter(Axis(0), chunk_size),
)
.par_bridge()
.map(|(leakages_chunk, metadata_chunk)| {
let mut dpa = DpaProcessor::new(leakages.shape()[1], guess_range, leakage_func);

for i in 0..leakages_chunk.shape()[0] {
dpa.update(leakages_chunk.row(i), metadata_chunk[i].clone());
}

dpa
})
.reduce(
.fold(
|| DpaProcessor::new(leakages.shape()[1], guess_range, leakage_func),
|a, b| a + b,
|mut dpa, (leakages_chunk, metadata_chunk)| {
for i in 0..leakages_chunk.shape()[0] {
dpa.update(leakages_chunk.row(i), metadata_chunk[i].clone());
}

dpa
},
)
.reduce_with(|a, b| a + b)
.unwrap()
.finalize()
}

Expand Down
26 changes: 14 additions & 12 deletions src/processors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,22 @@ where
{
assert!(chunk_size > 0);

// From benchmarks fold + reduce_with is faster than map + reduce/reduce_with and fold + reduce
leakages
.axis_chunks_iter(Axis(0), chunk_size)
.enumerate()
.par_bridge()
.map(|(chunk_idx, leakages_chunk)| {
let mut snr = Snr::new(leakages.shape()[1], classes);

for i in 0..leakages_chunk.shape()[0] {
snr.process(leakages_chunk.row(i), get_class(chunk_idx + i));
}

snr
})
.reduce(|| Snr::new(leakages.shape()[1], classes), |a, b| a + b)
.fold(
|| Snr::new(leakages.shape()[1], classes),
|mut snr, (chunk_idx, leakages_chunk)| {
for i in 0..leakages_chunk.shape()[0] {
snr.process(leakages_chunk.row(i), get_class(chunk_idx + i));
}
snr
},
)
.reduce_with(|a, b| a + b)
.unwrap()
.snr()
}

Expand Down Expand Up @@ -169,14 +171,14 @@ impl Snr {

let class_sum = self.classes_sum.slice(s![class, ..]);
for i in 0..size {
acc[i] += (class_sum[i] as f64).powf(2.0) / (self.classes_count[class] as f64);
acc[i] += (class_sum[i] as f64).powi(2) / (self.classes_count[class] as f64);
}
}

let var = self.mean_var.var();
let mean = self.mean_var.mean();
// V[E[L|X]]
let velx = (acc / self.mean_var.count as f64) - mean.mapv(|x| x.powf(2.0));
let velx = (acc / self.mean_var.count as f64) - mean.mapv(|x| x.powi(2));
1f64 / (var / velx - 1f64)
}
}
Expand Down

0 comments on commit 655dc03

Please sign in to comment.