Skip to content

Commit

Permalink
Handle kernel serialization (#232)
Browse files Browse the repository at this point in the history
* Handle kernel serialization
* Do not use typetag in WASM
* enable tests for serialization
* Update serde feature deps

Co-authored-by: Luis Moreno <[email protected]>
Co-authored-by: Lorenzo <[email protected]>
  • Loading branch information
3 people committed Nov 8, 2022
1 parent 7d87451 commit 62de25b
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 50 deletions.
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
rand_distr = { version = "0.4", optional = true }
serde = { version = "1", features = ["derive"], optional = true }

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
typetag = { version = "0.2", optional = true }

[features]
default = []
serde = ["dep:serde"]
serde = ["dep:serde", "dep:typetag"]
ndarray-bindings = ["dep:ndarray"]
datasets = ["dep:rand_distr", "std_rand", "serde"]
std_rand = ["rand/std_rng", "rand/std"]
Expand Down
46 changes: 10 additions & 36 deletions src/svm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ pub mod svr;

use core::fmt::Debug;

#[cfg(feature = "serde")]
use serde::ser::{SerializeStruct, Serializer};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

Expand All @@ -40,36 +38,20 @@ use crate::linalg::basic::arrays::{Array1, ArrayView1};

/// Defines a kernel function.
/// This is a object-safe trait.
pub trait Kernel {
#[cfg_attr(
all(feature = "serde", not(target_arch = "wasm32")),
typetag::serde(tag = "type")
)]
pub trait Kernel: Debug {
#[allow(clippy::ptr_arg)]
/// Apply kernel function to x_i and x_j
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
/// Return a serializable name
fn name(&self) -> &'static str;
}

impl Debug for dyn Kernel {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "Kernel<f64>")
}
}

#[cfg(feature = "serde")]
impl Serialize for dyn Kernel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut s = serializer.serialize_struct("Kernel", 1)?;
s.serialize_field("type", &self.name())?;
s.end()
}
}

/// Pre-defined kernel functions
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Kernels {}
pub struct Kernels;

impl Kernels {
/// Return a default linear
Expand Down Expand Up @@ -211,15 +193,14 @@ impl SigmoidKernel {
}
}

#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for LinearKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
Ok(x_i.dot(x_j))
}
fn name(&self) -> &'static str {
"Linear"
}
}

#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for RBFKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() {
Expand All @@ -231,11 +212,9 @@ impl Kernel for RBFKernel {
let v_diff = x_i.sub(x_j);
Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp())
}
fn name(&self) -> &'static str {
"RBF"
}
}

#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for PolynomialKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() {
Expand All @@ -247,11 +226,9 @@ impl Kernel for PolynomialKernel {
let dot = x_i.dot(x_j);
Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap()))
}
fn name(&self) -> &'static str {
"Polynomial"
}
}

#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for SigmoidKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() || self.coef0.is_none() {
Expand All @@ -263,9 +240,6 @@ impl Kernel for SigmoidKernel {
let dot = x_i.dot(x_j);
Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh())
}
fn name(&self) -> &'static str {
"Sigmoid"
}
}

#[cfg(test)]
Expand Down
12 changes: 8 additions & 4 deletions src/svm/svc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,11 @@ pub struct SVCParameters<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX
pub c: TX,
/// Tolerance for stopping criterion.
pub tol: TX,
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
/// The kernel function.
#[cfg_attr(
all(feature = "serde", target_arch = "wasm32"),
serde(skip_serializing, skip_deserializing)
)]
pub kernel: Option<Box<dyn Kernel>>,
/// Unused parameter.
m: PhantomData<(X, Y, TY)>,
Expand Down Expand Up @@ -1085,7 +1088,7 @@ mod tests {
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
#[cfg(feature = "serde")]
#[cfg(all(feature = "serde", not(target_arch = "wasm32")))]
fn svc_serde() {
let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
Expand Down Expand Up @@ -1119,8 +1122,9 @@ mod tests {
let svc = SVC::fit(&x, &y, &params).unwrap();

// serialization
let serialized_svc = &serde_json::to_string(&svc).unwrap();
let deserialized_svc: SVC<f64, i32, _, _> =
serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap();

println!("{:?}", serialized_svc);
assert_eq!(svc, deserialized_svc);
}
}
17 changes: 8 additions & 9 deletions src/svm/svr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,11 @@ pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
pub c: T,
/// Tolerance for stopping criterion.
pub tol: T,
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
/// The kernel function.
#[cfg_attr(
all(feature = "serde", target_arch = "wasm32"),
serde(skip_serializing, skip_deserializing)
)]
pub kernel: Option<Box<dyn Kernel>>,
}

Expand Down Expand Up @@ -668,7 +671,7 @@ mod tests {
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
#[cfg(feature = "serde")]
#[cfg(all(feature = "serde", not(target_arch = "wasm32")))]
fn svr_serde() {
let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
Expand Down Expand Up @@ -699,13 +702,9 @@ mod tests {

let svr = SVR::fit(&x, &y, &params).unwrap();

let serialized = &serde_json::to_string(&svr).unwrap();

println!("{}", &serialized);

// let deserialized_svr: SVR<f64, DenseMatrix<f64>, LinearKernel> =
// serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();
let deserialized_svr: SVR<f64, DenseMatrix<f64>, _> =
serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();

// assert_eq!(svr, deserialized_svr);
assert_eq!(svr, deserialized_svr);
}
}

0 comments on commit 62de25b

Please sign in to comment.