Skip to content

Commit

Permalink
Fix FactorGraph serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaëtan Cassiers committed Feb 20, 2023
1 parent d43cb29 commit 146bd46
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/scalib_ext/scalib-py/src/factor_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,15 @@ impl FactorGraph {
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
Ok(PyBytes::new(py, &serialize(&self.inner.as_deref()).unwrap()).to_object(py))
let to_ser: Option<&sasca::FactorGraph> = self.inner.as_deref();
Ok(PyBytes::new(py, &serialize(&to_ser).unwrap()).to_object(py))
}

pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.inner = Some(Arc::new(deserialize(s.as_bytes()).unwrap()));
let deser: Option<sasca::FactorGraph> = deserialize(s.as_bytes()).unwrap();
self.inner = deser.map(Arc::new);
Ok(())
}
Err(e) => Err(e),
Expand Down
31 changes: 31 additions & 0 deletions tests/test_factorgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,37 @@ def make_distri(nc, n):
return normalize_distr(np.random.randint(1, 10000000, (n, nc)).astype(np.float64))


def test_copy_fg():
graph = """
NC 2
PROPERTY s1: x = a^b
VAR MULTI x
VAR MULTI a
VAR MULTI b
"""
graph = FactorGraph(graph)
graph2 = copy.deepcopy(graph)


def test_copy_bp():
graph = """
NC 2
PROPERTY s1: x = a^b
VAR MULTI x
VAR MULTI a
VAR MULTI b
"""
graph = FactorGraph(graph)
n = 5
bp_state = BPState(graph, n)
distri_a = make_distri(2, 5)
distri_b = make_distri(2, 5)
bp_state.set_evidence("a", distri_a)
bp_state.set_evidence("b", distri_b)
bp_state.bp_loopy(1, initialize_states=True)
bp_state2 = copy.deepcopy(bp_state)


def test_table():
"""
Test Table lookup
Expand Down

0 comments on commit 146bd46

Please sign in to comment.