diff --git a/Cargo.toml b/Cargo.toml index 63c9389a..b13a1e32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,9 +29,12 @@ rand = { version = "0.8.5", default-features = false, features = ["small_rng"] } rand_distr = { version = "0.4", optional = true } serde = { version = "1", features = ["derive"], optional = true } +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +typetag = { version = "0.2", optional = true } + [features] default = [] -serde = ["dep:serde"] +serde = ["dep:serde", "dep:typetag"] ndarray-bindings = ["dep:ndarray"] datasets = ["dep:rand_distr", "std_rand", "serde"] std_rand = ["rand/std_rng", "rand/std"] diff --git a/src/svm/mod.rs b/src/svm/mod.rs index ef0f0033..b2bd79cb 100644 --- a/src/svm/mod.rs +++ b/src/svm/mod.rs @@ -30,8 +30,6 @@ pub mod svr; use core::fmt::Debug; -#[cfg(feature = "serde")] -use serde::ser::{SerializeStruct, Serializer}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -40,36 +38,20 @@ use crate::linalg::basic::arrays::{Array1, ArrayView1}; /// Defines a kernel function. /// This is a object-safe trait. -pub trait Kernel { +#[cfg_attr( + all(feature = "serde", not(target_arch = "wasm32")), + typetag::serde(tag = "type") +)] +pub trait Kernel: Debug { #[allow(clippy::ptr_arg)] /// Apply kernel function to x_i and x_j fn apply(&self, x_i: &Vec, x_j: &Vec) -> Result; - /// Return a serializable name - fn name(&self) -> &'static str; -} - -impl Debug for dyn Kernel { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "Kernel") - } -} - -#[cfg(feature = "serde")] -impl Serialize for dyn Kernel { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let mut s = serializer.serialize_struct("Kernel", 1)?; - s.serialize_field("type", &self.name())?; - s.end() - } } /// Pre-defined kernel functions #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] -pub struct Kernels {} +pub struct Kernels; impl Kernels { /// Return a default linear @@ -211,15 +193,14 @@ impl SigmoidKernel { } } +#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)] impl Kernel for LinearKernel { fn apply(&self, x_i: &Vec, x_j: &Vec) -> Result { Ok(x_i.dot(x_j)) } - fn name(&self) -> &'static str { - "Linear" - } } +#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)] impl Kernel for RBFKernel { fn apply(&self, x_i: &Vec, x_j: &Vec) -> Result { if self.gamma.is_none() { @@ -231,11 +212,9 @@ impl Kernel for RBFKernel { let v_diff = x_i.sub(x_j); Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp()) } - fn name(&self) -> &'static str { - "RBF" - } } +#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)] impl Kernel for PolynomialKernel { fn apply(&self, x_i: &Vec, x_j: &Vec) -> Result { if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() { @@ -247,11 +226,9 @@ impl Kernel for PolynomialKernel { let dot = x_i.dot(x_j); Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap())) } - fn name(&self) -> &'static str { - "Polynomial" - } } +#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)] impl Kernel for SigmoidKernel { fn apply(&self, x_i: &Vec, x_j: &Vec) -> Result { if self.gamma.is_none() || self.coef0.is_none() { @@ -263,9 +240,6 @@ impl Kernel for SigmoidKernel { let dot = x_i.dot(x_j); Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh()) } - fn name(&self) -> &'static str { - "Sigmoid" - } } #[cfg(test)] diff --git a/src/svm/svc.rs b/src/svm/svc.rs index 74998f57..8cd5d5b9 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -100,8 +100,11 @@ pub struct SVCParameters>, /// Unused parameter. m: PhantomData<(X, Y, TY)>, @@ -1085,7 +1088,7 @@ mod tests { wasm_bindgen_test::wasm_bindgen_test )] #[test] - #[cfg(feature = "serde")] + #[cfg(all(feature = "serde", not(target_arch = "wasm32")))] fn svc_serde() { let x = DenseMatrix::from_2d_array(&[ &[5.1, 3.5, 1.4, 0.2], @@ -1119,8 +1122,9 @@ mod tests { let svc = SVC::fit(&x, &y, ¶ms).unwrap(); // serialization - let serialized_svc = &serde_json::to_string(&svc).unwrap(); + let deserialized_svc: SVC = + serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap(); - println!("{:?}", serialized_svc); + assert_eq!(svc, deserialized_svc); } } diff --git a/src/svm/svr.rs b/src/svm/svr.rs index 8d49525b..bf53e723 100644 --- a/src/svm/svr.rs +++ b/src/svm/svr.rs @@ -92,8 +92,11 @@ pub struct SVRParameters { pub c: T, /// Tolerance for stopping criterion. pub tol: T, - #[cfg_attr(feature = "serde", serde(skip_deserializing))] /// The kernel function. + #[cfg_attr( + all(feature = "serde", target_arch = "wasm32"), + serde(skip_serializing, skip_deserializing) + )] pub kernel: Option>, } @@ -668,7 +671,7 @@ mod tests { wasm_bindgen_test::wasm_bindgen_test )] #[test] - #[cfg(feature = "serde")] + #[cfg(all(feature = "serde", not(target_arch = "wasm32")))] fn svr_serde() { let x = DenseMatrix::from_2d_array(&[ &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], @@ -699,13 +702,9 @@ mod tests { let svr = SVR::fit(&x, &y, ¶ms).unwrap(); - let serialized = &serde_json::to_string(&svr).unwrap(); - - println!("{}", &serialized); - - // let deserialized_svr: SVR, LinearKernel> = - // serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap(); + let deserialized_svr: SVR, _> = + serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap(); - // assert_eq!(svr, deserialized_svr); + assert_eq!(svr, deserialized_svr); } }