From 785627f6cceae28528a127b7acbcb7ed07c069ba Mon Sep 17 00:00:00 2001 From: TrAyZeN Date: Wed, 19 Jun 2024 14:25:31 +0200 Subject: [PATCH] Add argmax_by util --- src/cpa.rs | 22 +++------------------- src/cpa_normal.rs | 5 +---- src/dpa.rs | 18 ++---------------- src/util.rs | 20 ++++++++++++++++++-- 4 files changed, 24 insertions(+), 41 deletions(-) diff --git a/src/cpa.rs b/src/cpa.rs index 01c2a64..5718d03 100644 --- a/src/cpa.rs +++ b/src/cpa.rs @@ -1,4 +1,4 @@ -use crate::util::{argsort_by, max_per_row}; +use crate::util::{argmax_by, argsort_by, max_per_row}; use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis}; use rayon::{ iter::ParallelBridge, @@ -52,8 +52,6 @@ where /// [^1]: #[derive(Debug)] pub struct Cpa { - /// Guess range upper excluded bound - pub(crate) guess_range: usize, /// Pearson correlation coefficients pub(crate) corr: Array2, } @@ -73,18 +71,7 @@ impl Cpa { /// Return the guess with the highest Pearson correlation coefficient. pub fn best_guess(&self) -> usize { - let max_corr = self.max_corr(); - - let mut best_guess_corr = 0.0; - let mut best_guess = 0; - for guess in 0..self.guess_range { - if max_corr[guess] > best_guess_corr { - best_guess_corr = max_corr[guess]; - best_guess = guess; - } - } - - best_guess + argmax_by(self.max_corr().view(), f32::total_cmp) } /// Return the maximum Pearson correlation coefficient for each guess. @@ -213,10 +200,7 @@ where } } - Cpa { - guess_range: self.guess_range, - corr, - } + Cpa { corr } } fn sum_mult(&self, a: ArrayView1, b: ArrayView1) -> usize { diff --git a/src/cpa_normal.rs b/src/cpa_normal.rs index a5984a7..3bee39a 100644 --- a/src/cpa_normal.rs +++ b/src/cpa_normal.rs @@ -166,10 +166,7 @@ where } } - Cpa { - guess_range: self.guess_range, - corr, - } + Cpa { corr } } /// Determine if two [`CpaProcessor`] are compatible for addition. diff --git a/src/dpa.rs b/src/dpa.rs index 4ffd1f2..cc06ee3 100644 --- a/src/dpa.rs +++ b/src/dpa.rs @@ -2,7 +2,7 @@ use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis}; use rayon::iter::{ParallelBridge, ParallelIterator}; use std::{iter::zip, marker::PhantomData, ops::Add}; -use crate::util::{argsort_by, max_per_row}; +use crate::util::{argmax_by, argsort_by, max_per_row}; /// Compute the [`Dpa`] of the given traces using [`DpaProcessor`]. /// @@ -47,8 +47,6 @@ where /// [^1]: #[derive(Debug)] pub struct Dpa { - /// Guess range upper excluded bound - guess_range: usize, differential_curves: Array2, } @@ -67,18 +65,7 @@ impl Dpa { /// Return the guess with the highest differential peak. pub fn best_guess(&self) -> usize { - let max_corr = self.max_differential_curves(); - - let mut best_guess_value = 0.0; - let mut best_guess = 0; - for guess in 0..self.guess_range { - if max_corr[guess] > best_guess_value { - best_guess_value = max_corr[guess]; - best_guess = guess; - } - } - - best_guess + argmax_by(self.max_differential_curves().view(), f32::total_cmp) } /// Return the maximum differential peak for each guess. @@ -170,7 +157,6 @@ where } Dpa { - guess_range: self.guess_range, differential_curves, } } diff --git a/src/util.rs b/src/util.rs index f500ef4..04c9670 100644 --- a/src/util.rs +++ b/src/util.rs @@ -2,7 +2,7 @@ use std::{cmp::Ordering, io::Read, path::Path}; -use ndarray::{Array, Array1, Array2, ArrayView2, Axis}; +use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis}; use ndarray_npy::{ read_npy, write_npy, ReadNpyError, ReadableElement, WritableElement, WriteNpyError, }; @@ -84,7 +84,7 @@ pub fn max_per_row(arr: ArrayView2) -> Array1 { .collect() } -/// Return the indices that would sort the given array with a comparator function. +/// Return the indices that would sort the given array with a comparison function. pub fn argsort_by(data: &[T], compare: F) -> Vec where F: Fn(&T, &T) -> Ordering, @@ -95,3 +95,19 @@ where indices } + +/// Return the index of the maximum value in the given array. +pub fn argmax_by(array: ArrayView1, compare: F) -> usize +where + F: Fn(&T, &T) -> Ordering, +{ + let mut idx_max = 0; + + for i in 0..array.len() { + if compare(&array[i], &array[idx_max]).is_gt() { + idx_max = i; + } + } + + idx_max +}