From 78f1af197444614c3485d7c4ff1bbac4e593775d Mon Sep 17 00:00:00 2001 From: Miklos Homolya Date: Wed, 24 May 2017 17:23:17 +0100 Subject: [PATCH] fix ListTensor pickling and add test case --- gem/gem.py | 3 +++ tests/test_pickle_gem.py | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/gem/gem.py b/gem/gem.py index 1f74200f..20fc7c4b 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -655,6 +655,9 @@ def children(self): def shape(self): return self.array.shape + def __reduce__(self): + return type(self), (self.array,) + def reconstruct(self, *args): return ListTensor(asarray(args).reshape(self.array.shape)) diff --git a/tests/test_pickle_gem.py b/tests/test_pickle_gem.py index f246a0cb..4d484a98 100644 --- a/tests/test_pickle_gem.py +++ b/tests/test_pickle_gem.py @@ -18,6 +18,14 @@ def test_pickle_gem(protocol): assert repr(expr) == repr(unpickled) +@pytest.mark.parametrize('protocol', range(3)) +def test_listtensor(protocol): + expr = gem.ListTensor([gem.Variable('x', ()), gem.Zero()]) + + unpickled = pickle.loads(pickle.dumps(expr, protocol)) + assert expr == unpickled + + if __name__ == "__main__": import os import sys