Skip to content

Commit

Permalink
Merge pull request #125 from firedrakeproject/fix-listtensor-pickle
Browse files Browse the repository at this point in the history
Fix ListTensor pickling
  • Loading branch information
miklos1 authored May 30, 2017
2 parents 361c1b0 + 78f1af1 commit 7cbae93
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
3 changes: 3 additions & 0 deletions gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,9 @@ def children(self):
def shape(self):
return self.array.shape

def __reduce__(self):
return type(self), (self.array,)

def reconstruct(self, *args):
return ListTensor(asarray(args).reshape(self.array.shape))

Expand Down
8 changes: 8 additions & 0 deletions tests/test_pickle_gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ def test_pickle_gem(protocol):
assert repr(expr) == repr(unpickled)


@pytest.mark.parametrize('protocol', range(3))
def test_listtensor(protocol):
expr = gem.ListTensor([gem.Variable('x', ()), gem.Zero()])

unpickled = pickle.loads(pickle.dumps(expr, protocol))
assert expr == unpickled


if __name__ == "__main__":
import os
import sys
Expand Down

0 comments on commit 7cbae93

Please sign in to comment.