diff --git a/src/scalib_ext/scalib-py/src/factor_graph.rs b/src/scalib_ext/scalib-py/src/factor_graph.rs index ee8cf907..b467b390 100644 --- a/src/scalib_ext/scalib-py/src/factor_graph.rs +++ b/src/scalib_ext/scalib-py/src/factor_graph.rs @@ -58,13 +58,15 @@ impl FactorGraph { } pub fn __getstate__(&self, py: Python) -> PyResult { - Ok(PyBytes::new(py, &serialize(&self.inner.as_deref()).unwrap()).to_object(py)) + let to_ser: Option<&sasca::FactorGraph> = self.inner.as_deref(); + Ok(PyBytes::new(py, &serialize(&to_ser).unwrap()).to_object(py)) } pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { match state.extract::<&PyBytes>(py) { Ok(s) => { - self.inner = Some(Arc::new(deserialize(s.as_bytes()).unwrap())); + let deser: Option = deserialize(s.as_bytes()).unwrap(); + self.inner = deser.map(Arc::new); Ok(()) } Err(e) => Err(e), diff --git a/tests/test_factorgraph.py b/tests/test_factorgraph.py index 1d47e7ef..5fe8f948 100644 --- a/tests/test_factorgraph.py +++ b/tests/test_factorgraph.py @@ -13,6 +13,37 @@ def make_distri(nc, n): return normalize_distr(np.random.randint(1, 10000000, (n, nc)).astype(np.float64)) +def test_copy_fg(): + graph = """ + NC 2 + PROPERTY s1: x = a^b + VAR MULTI x + VAR MULTI a + VAR MULTI b + """ + graph = FactorGraph(graph) + graph2 = copy.deepcopy(graph) + + +def test_copy_bp(): + graph = """ + NC 2 + PROPERTY s1: x = a^b + VAR MULTI x + VAR MULTI a + VAR MULTI b + """ + graph = FactorGraph(graph) + n = 5 + bp_state = BPState(graph, n) + distri_a = make_distri(2, 5) + distri_b = make_distri(2, 5) + bp_state.set_evidence("a", distri_a) + bp_state.set_evidence("b", distri_b) + bp_state.bp_loopy(1, initialize_states=True) + bp_state2 = copy.deepcopy(bp_state) + + def test_table(): """ Test Table lookup