From 504c11c79cf116a0a701083af8f23fd50e0816e0 Mon Sep 17 00:00:00 2001 From: TrAyZeN Date: Wed, 13 Nov 2024 11:17:36 +0100 Subject: [PATCH] Add save and load helpers --- Cargo.toml | 4 +-- src/distinguishers/cpa.rs | 31 +++++++++++++++++-- src/distinguishers/cpa_normal.rs | 28 +++++++++++++++-- src/distinguishers/dpa.rs | 31 +++++++++++++++++-- src/error.rs | 10 ++++++ src/leakage_detection.rs | 52 ++++++++++++++++++++++++++++++-- src/lib.rs | 3 ++ 7 files changed, 149 insertions(+), 10 deletions(-) create mode 100644 src/error.rs diff --git a/Cargo.toml b/Cargo.toml index b715908..f302cce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ readme = "README.md" [features] progress_bar = ["dep:indicatif"] -quicklog = ["dep:thiserror"] +quicklog = [] [dependencies] serde_json = "1.0.132" @@ -22,7 +22,7 @@ rayon = "1.10.0" indicatif = { version = "0.17.8", optional = true } ndarray-npy ="0.9.1" itertools = "0.13.0" -thiserror = { version = "1.0.58", optional = true } +thiserror = { version = "1.0.58" } dtw = { git = "https://github.com/Ledger-Donjon/dtw.git", rev = "0f8d7ec3bbdf2ca4ec8ea35feddb8d1db73e7d54" } num-traits = "0.2.19" serde = { version = "1.0.214", features = ["derive"] } diff --git a/src/distinguishers/cpa.rs b/src/distinguishers/cpa.rs index 8c4eba4..c97bcfc 100644 --- a/src/distinguishers/cpa.rs +++ b/src/distinguishers/cpa.rs @@ -1,11 +1,14 @@ -use crate::util::{argmax_by, argsort_by, max_per_row}; +use crate::{ + util::{argmax_by, argsort_by, max_per_row}, + Error, +}; use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis}; use rayon::{ iter::ParallelBridge, prelude::{IntoParallelIterator, ParallelIterator}, }; use serde::{Deserialize, Serialize}; -use std::{iter::zip, ops::Add}; +use std::{fs::File, iter::zip, ops::Add, path::Path}; /// Compute the [`Cpa`] of the given traces using [`CpaProcessor`]. /// @@ -244,6 +247,30 @@ where a.dot(&b) } + /// Save the [`CpaProcessor`] to a file. + /// + /// # Warning + /// The file format is not stable as muscat is active development. Thus, the format might + /// change between versions. + pub fn save>(&self, path: P) -> Result<(), Error> { + let file = File::create(path)?; + serde_json::to_writer(file, &CpaProcessorSerdeAdapter::from(self))?; + + Ok(()) + } + + /// Load a [`CpaProcessor`] from a file. + /// + /// # Warning + /// The file format is not stable as muscat is active development. Thus, the format might + /// change between versions. + pub fn load>(path: P, leakage_func: F) -> Result { + let file = File::open(path)?; + let p: CpaProcessorSerdeAdapter = serde_json::from_reader(file)?; + + Ok(p.with(leakage_func)) + } + /// Determine if two [`CpaProcessor`] are compatible for addition. /// /// If they were created with the same parameters, they are compatible. diff --git a/src/distinguishers/cpa_normal.rs b/src/distinguishers/cpa_normal.rs index c91b4d6..87224cf 100644 --- a/src/distinguishers/cpa_normal.rs +++ b/src/distinguishers/cpa_normal.rs @@ -1,9 +1,9 @@ use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis}; use rayon::iter::{ParallelBridge, ParallelIterator}; use serde::{Deserialize, Serialize}; -use std::{iter::zip, ops::Add}; +use std::{fs::File, iter::zip, ops::Add, path::Path}; -use crate::distinguishers::cpa::Cpa; +use crate::{distinguishers::cpa::Cpa, Error}; /// Compute the [`Cpa`] of the given traces using [`CpaProcessor`]. /// @@ -203,6 +203,30 @@ where Cpa { corr } } + /// Save the [`CpaProcessor`] to a file. + /// + /// # Warning + /// The file format is not stable as muscat is active development. Thus, the format might + /// change between versions. + pub fn save>(&self, path: P) -> Result<(), Error> { + let file = File::create(path)?; + serde_json::to_writer(file, &CpaProcessorSerdeAdapter::from(self))?; + + Ok(()) + } + + /// Load a [`CpaProcessor`] from a file. + /// + /// # Warning + /// The file format is not stable as muscat is active development. Thus, the format might + /// change between versions. + pub fn load>(path: P, leakage_func: F) -> Result { + let file = File::open(path)?; + let p: CpaProcessorSerdeAdapter = serde_json::from_reader(file)?; + + Ok(p.with(leakage_func)) + } + /// Determine if two [`CpaProcessor`] are compatible for addition. /// /// If they were created with the same parameters, they are compatible. diff --git a/src/distinguishers/dpa.rs b/src/distinguishers/dpa.rs index 4f834f5..2a9d700 100644 --- a/src/distinguishers/dpa.rs +++ b/src/distinguishers/dpa.rs @@ -1,9 +1,12 @@ use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis}; use rayon::iter::{ParallelBridge, ParallelIterator}; use serde::{Deserialize, Serialize}; -use std::{iter::zip, marker::PhantomData, ops::Add}; +use std::{fs::File, iter::zip, marker::PhantomData, ops::Add, path::Path}; -use crate::util::{argmax_by, argsort_by, max_per_row}; +use crate::{ + util::{argmax_by, argsort_by, max_per_row}, + Error, +}; /// Compute the [`Dpa`] of the given traces using [`DpaProcessor`]. /// @@ -206,6 +209,30 @@ where } } + /// Save the [`DpaProcessor`] to a file. + /// + /// # Warning + /// The file format is not stable as muscat is active development. Thus, the format might + /// change between versions. + pub fn save>(&self, path: P) -> Result<(), Error> { + let file = File::create(path)?; + serde_json::to_writer(file, &DpaProcessorSerdeAdapter::from(self))?; + + Ok(()) + } + + /// Load a [`DpaProcessor`] from a file. + /// + /// # Warning + /// The file format is not stable as muscat is active development. Thus, the format might + /// change between versions. + pub fn load>(path: P, selection_function: F) -> Result { + let file = File::open(path)?; + let p: DpaProcessorSerdeAdapter = serde_json::from_reader(file)?; + + Ok(p.with(selection_function)) + } + /// Determine if two [`DpaProcessor`] are compatible for addition. /// /// If they were created with the same parameters, they are compatible. diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..e7db06b --- /dev/null +++ b/src/error.rs @@ -0,0 +1,10 @@ +use std::io; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum Error { + #[error("Failed to save/load muscat data")] + SaveLoadError(#[from] serde_json::Error), + #[error(transparent)] + IoError(#[from] io::Error), +} diff --git a/src/leakage_detection.rs b/src/leakage_detection.rs index dd06f19..eec26fa 100644 --- a/src/leakage_detection.rs +++ b/src/leakage_detection.rs @@ -1,9 +1,9 @@ //! Leakage detection methods -use crate::processors::MeanVar; +use crate::{processors::MeanVar, Error}; use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis}; use rayon::iter::{ParallelBridge, ParallelIterator}; use serde::{Deserialize, Serialize}; -use std::{iter::zip, ops::Add}; +use std::{fs::File, iter::zip, ops::Add, path::Path}; /// Compute the SNR of the given traces using [`SnrProcessor`]. /// @@ -150,6 +150,30 @@ impl SnrProcessor { self.classes_count.len() } + /// Save the [`SnrProcessor`] to a file. + /// + /// # Warning + /// The file format is not stable as muscat is active development. Thus, the format might + /// change between versions. + pub fn save>(&self, path: P) -> Result<(), Error> { + let file = File::create(path)?; + serde_json::to_writer(file, self)?; + + Ok(()) + } + + /// Load a [`SnrProcessor`] from a file. + /// + /// # Warning + /// The file format is not stable as muscat is active development. Thus, the format might + /// change between versions. + pub fn load>(path: P) -> Result { + let file = File::open(path)?; + let p = serde_json::from_reader(file)?; + + Ok(p) + } + /// Determine if two [`SnrProcessor`] are compatible for addition. /// /// If they were created with the same parameters, they are compatible. @@ -289,6 +313,30 @@ impl TTestProcessor { self.mean_var_1.size() } + /// Save the [`TTestProcessor`] to a file. + /// + /// # Warning + /// The file format is not stable as muscat is active development. Thus, the format might + /// change between versions. + pub fn save>(&self, path: P) -> Result<(), Error> { + let file = File::create(path)?; + serde_json::to_writer(file, self)?; + + Ok(()) + } + + /// Load a [`TTestProcessor`] from a file. + /// + /// # Warning + /// The file format is not stable as muscat is active development. Thus, the format might + /// change between versions. + pub fn load>(path: P) -> Result { + let file = File::open(path)?; + let p = serde_json::from_reader(file)?; + + Ok(p) + } + /// Determine if two [`TTestProcessor`] are compatible for addition. /// /// If they were created with the same parameters, they are compatible. diff --git a/src/lib.rs b/src/lib.rs index fb8ba05..ec27cba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod distinguishers; +pub mod error; pub mod leakage_detection; pub mod leakage_model; pub mod preprocessors; @@ -6,5 +7,7 @@ pub mod processors; pub mod trace; pub mod util; +pub use crate::error::Error; + #[cfg(feature = "quicklog")] pub mod quicklog;