Skip to content

Commit

Permalink
Add argmax_by util
Browse files Browse the repository at this point in the history
  • Loading branch information
TrAyZeN authored and kingofpayne committed Jul 15, 2024
1 parent d3606db commit 785627f
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 41 deletions.
22 changes: 3 additions & 19 deletions src/cpa.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::util::{argsort_by, max_per_row};
use crate::util::{argmax_by, argsort_by, max_per_row};
use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
use rayon::{
iter::ParallelBridge,
Expand Down Expand Up @@ -52,8 +52,6 @@ where
/// [^1]: <https://www.iacr.org/archive/ches2004/31560016/31560016.pdf>
#[derive(Debug)]
pub struct Cpa {
/// Guess range upper excluded bound
pub(crate) guess_range: usize,
/// Pearson correlation coefficients
pub(crate) corr: Array2<f32>,
}
Expand All @@ -73,18 +71,7 @@ impl Cpa {

/// Return the guess with the highest Pearson correlation coefficient.
pub fn best_guess(&self) -> usize {
let max_corr = self.max_corr();

let mut best_guess_corr = 0.0;
let mut best_guess = 0;
for guess in 0..self.guess_range {
if max_corr[guess] > best_guess_corr {
best_guess_corr = max_corr[guess];
best_guess = guess;
}
}

best_guess
argmax_by(self.max_corr().view(), f32::total_cmp)
}

/// Return the maximum Pearson correlation coefficient for each guess.
Expand Down Expand Up @@ -213,10 +200,7 @@ where
}
}

Cpa {
guess_range: self.guess_range,
corr,
}
Cpa { corr }
}

fn sum_mult(&self, a: ArrayView1<usize>, b: ArrayView1<usize>) -> usize {
Expand Down
5 changes: 1 addition & 4 deletions src/cpa_normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,7 @@ where
}
}

Cpa {
guess_range: self.guess_range,
corr,
}
Cpa { corr }
}

/// Determine if two [`CpaProcessor`] are compatible for addition.
Expand Down
18 changes: 2 additions & 16 deletions src/dpa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use rayon::iter::{ParallelBridge, ParallelIterator};
use std::{iter::zip, marker::PhantomData, ops::Add};

use crate::util::{argsort_by, max_per_row};
use crate::util::{argmax_by, argsort_by, max_per_row};

/// Compute the [`Dpa`] of the given traces using [`DpaProcessor`].
///
Expand Down Expand Up @@ -47,8 +47,6 @@ where
/// [^1]: <https://paulkocher.com/doc/DifferentialPowerAnalysis.pdf>
#[derive(Debug)]
pub struct Dpa {
/// Guess range upper excluded bound
guess_range: usize,
differential_curves: Array2<f32>,
}

Expand All @@ -67,18 +65,7 @@ impl Dpa {

/// Return the guess with the highest differential peak.
pub fn best_guess(&self) -> usize {
let max_corr = self.max_differential_curves();

let mut best_guess_value = 0.0;
let mut best_guess = 0;
for guess in 0..self.guess_range {
if max_corr[guess] > best_guess_value {
best_guess_value = max_corr[guess];
best_guess = guess;
}
}

best_guess
argmax_by(self.max_differential_curves().view(), f32::total_cmp)
}

/// Return the maximum differential peak for each guess.
Expand Down Expand Up @@ -170,7 +157,6 @@ where
}

Dpa {
guess_range: self.guess_range,
differential_curves,
}
}
Expand Down
20 changes: 18 additions & 2 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use std::{cmp::Ordering, io::Read, path::Path};

use ndarray::{Array, Array1, Array2, ArrayView2, Axis};
use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis};
use ndarray_npy::{
read_npy, write_npy, ReadNpyError, ReadableElement, WritableElement, WriteNpyError,
};
Expand Down Expand Up @@ -84,7 +84,7 @@ pub fn max_per_row(arr: ArrayView2<f32>) -> Array1<f32> {
.collect()
}

/// Return the indices that would sort the given array with a comparator function.
/// Return the indices that would sort the given array with a comparison function.
pub fn argsort_by<T, F>(data: &[T], compare: F) -> Vec<usize>
where
F: Fn(&T, &T) -> Ordering,
Expand All @@ -95,3 +95,19 @@ where

indices
}

/// Return the index of the maximum value in the given array.
pub fn argmax_by<T, F>(array: ArrayView1<T>, compare: F) -> usize
where
F: Fn(&T, &T) -> Ordering,
{
let mut idx_max = 0;

for i in 0..array.len() {
if compare(&array[i], &array[idx_max]).is_gt() {
idx_max = i;
}
}

idx_max
}

0 comments on commit 785627f

Please sign in to comment.