Skip to content

Commit

Permalink
Allow to pass different trace and plaintext types
Browse files Browse the repository at this point in the history
  • Loading branch information
TrAyZeN committed Dec 5, 2024
1 parent 8057a65 commit e0ab261
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
10 changes: 6 additions & 4 deletions src/distinguishers/cpa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::{iter::zip, ops::Add};
/// use ndarray::array;
///
/// let traces = array![
/// [77, 137, 51, 91],
/// [77u8, 137, 51, 91],
/// [72, 61, 91, 83],
/// [39, 49, 52, 23],
/// [26, 114, 63, 45],
Expand Down Expand Up @@ -44,16 +44,17 @@ use std::{iter::zip, ops::Add};
/// # Panics
/// - Panic if `traces.shape()[0] != plaintexts.shape()[0]`
/// - Panic if `batch_size` is 0.
pub fn cpa<T, F>(
pub fn cpa<T, P, F>(
traces: ArrayView2<T>,
plaintexts: ArrayView2<T>,
plaintexts: ArrayView2<P>,
guess_range: usize,
target_byte: usize,
leakage_func: F,
batch_size: usize,
) -> Cpa
where
T: Into<usize> + Copy + Sync,
P: Into<usize> + Copy + Sync,
F: Fn(usize, usize) -> usize + Send + Sync + Copy,
{
assert_eq!(traces.shape()[0], plaintexts.shape()[0]);
Expand Down Expand Up @@ -171,9 +172,10 @@ where

/// # Panics
/// Panic in debug if `trace.shape()[0] != self.num_samples`.
pub fn update<T>(&mut self, trace: ArrayView1<T>, plaintext: ArrayView1<T>)
pub fn update<T, P>(&mut self, trace: ArrayView1<T>, plaintext: ArrayView1<P>)
where
T: Into<usize> + Copy,
P: Into<usize> + Copy,
{
debug_assert_eq!(trace.shape()[0], self.num_samples);

Expand Down
10 changes: 5 additions & 5 deletions src/distinguishers/cpa_normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ use crate::distinguishers::cpa::Cpa;
/// # Panics
/// - Panic if `traces.shape()[0] != plaintexts.shape()[0]`
/// - Panic if `batch_size` is 0.
pub fn cpa<T, U, F>(
pub fn cpa<T, P, F>(
traces: ArrayView2<T>,
plaintexts: ArrayView2<U>,
plaintexts: ArrayView2<P>,
guess_range: usize,
leakage_func: F,
batch_size: usize,
) -> Cpa
where
T: Into<f32> + Copy + Sync,
U: Into<usize> + Copy + Sync,
P: Into<usize> + Copy + Sync,
F: Fn(ArrayView1<usize>, usize) -> usize + Send + Sync + Copy,
{
assert_eq!(traces.shape()[0], plaintexts.shape()[0]);
Expand Down Expand Up @@ -127,10 +127,10 @@ where
/// # Panics
/// - Panic in debug if `trace_batch.shape()[0] != plaintext_batch.shape()[0]`.
/// - Panic in debug if `trace_batch.shape()[1] != self.num_samples`.
pub fn update<T, U>(&mut self, trace_batch: ArrayView2<T>, plaintext_batch: ArrayView2<U>)
pub fn update<T, P>(&mut self, trace_batch: ArrayView2<T>, plaintext_batch: ArrayView2<P>)
where
T: Into<f32> + Copy,
U: Into<usize> + Copy,
P: Into<usize> + Copy,
{
debug_assert_eq!(trace_batch.shape()[0], plaintext_batch.shape()[0]);
debug_assert_eq!(trace_batch.shape()[1], self.num_samples);
Expand Down
2 changes: 1 addition & 1 deletion src/distinguishers/dpa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ use crate::util::{argmax_by, argsort_by, max_per_row};
///
/// # Panics
/// Panic if `batch_size` is not strictly positive.
pub fn dpa<M, T, F>(
pub fn dpa<T, M, F>(
traces: ArrayView2<T>,
metadata: ArrayView1<M>,
guess_range: usize,
Expand Down

0 comments on commit e0ab261

Please sign in to comment.