Skip to content

Commit

Permalink
[pytree] support collections.defaultdict type for Python pytree (pyto…
Browse files Browse the repository at this point in the history
…rch#113255)

Pull Request resolved: pytorch#113255
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#112485
  • Loading branch information
XuehaiPan authored and pytorchmergebot committed Nov 30, 2023
1 parent baeb070 commit 2ab2e8e
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 50 deletions.
197 changes: 160 additions & 37 deletions test/test_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
import re
import unittest
from collections import namedtuple, OrderedDict, UserDict
from collections import defaultdict, namedtuple, OrderedDict, UserDict

import torch
import torch.utils._cxx_pytree as cxx_pytree
Expand Down Expand Up @@ -187,68 +187,68 @@ def run_test_with_leaf(leaf):
subtest(
(
py_pytree,
lambda lst: py_pytree.TreeSpec(
list, None, [py_pytree.LeafSpec() for _ in lst]
lambda tup: py_pytree.TreeSpec(
tuple, None, [py_pytree.LeafSpec() for _ in tup]
),
),
name="py",
),
subtest(
(cxx_pytree, lambda lst: cxx_pytree.tree_structure([0] * len(lst))),
(cxx_pytree, lambda tup: cxx_pytree.tree_structure((0,) * len(tup))),
name="cxx",
),
],
)
def test_flatten_unflatten_list(self, pytree_impl, gen_expected_fn):
def run_test(lst):
expected_spec = gen_expected_fn(lst)
values, treespec = pytree_impl.tree_flatten(lst)
def test_flatten_unflatten_tuple(self, pytree_impl, gen_expected_fn):
def run_test(tup):
expected_spec = gen_expected_fn(tup)
values, treespec = pytree_impl.tree_flatten(tup)
self.assertTrue(isinstance(values, list))
self.assertEqual(values, lst)
self.assertEqual(values, list(tup))
self.assertEqual(treespec, expected_spec)

unflattened = pytree_impl.tree_unflatten(values, treespec)
self.assertEqual(unflattened, lst)
self.assertTrue(isinstance(unflattened, list))
self.assertEqual(unflattened, tup)
self.assertTrue(isinstance(unflattened, tuple))

run_test([])
run_test([1.0, 2])
run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11])
run_test(())
run_test((1.0,))
run_test((1.0, 2))
run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11))

@parametrize(
"pytree_impl,gen_expected_fn",
[
subtest(
(
py_pytree,
lambda tup: py_pytree.TreeSpec(
tuple, None, [py_pytree.LeafSpec() for _ in tup]
lambda lst: py_pytree.TreeSpec(
list, None, [py_pytree.LeafSpec() for _ in lst]
),
),
name="py",
),
subtest(
(cxx_pytree, lambda tup: cxx_pytree.tree_structure((0,) * len(tup))),
(cxx_pytree, lambda lst: cxx_pytree.tree_structure([0] * len(lst))),
name="cxx",
),
],
)
def test_flatten_unflatten_tuple(self, pytree_impl, gen_expected_fn):
def run_test(tup):
expected_spec = gen_expected_fn(tup)
values, treespec = pytree_impl.tree_flatten(tup)
def test_flatten_unflatten_list(self, pytree_impl, gen_expected_fn):
def run_test(lst):
expected_spec = gen_expected_fn(lst)
values, treespec = pytree_impl.tree_flatten(lst)
self.assertTrue(isinstance(values, list))
self.assertEqual(values, list(tup))
self.assertEqual(values, lst)
self.assertEqual(treespec, expected_spec)

unflattened = pytree_impl.tree_unflatten(values, treespec)
self.assertEqual(unflattened, tup)
self.assertTrue(isinstance(unflattened, tuple))
self.assertEqual(unflattened, lst)
self.assertTrue(isinstance(unflattened, list))

run_test(())
run_test((1.0,))
run_test((1.0, 2))
run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11))
run_test([])
run_test([1.0, 2])
run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11])

@parametrize(
"pytree_impl,gen_expected_fn",
Expand Down Expand Up @@ -316,7 +316,7 @@ def run_test(dct):
),
],
)
def test_flatten_unflatten_odict(self, pytree_impl, gen_expected_fn):
def test_flatten_unflatten_ordereddict(self, pytree_impl, gen_expected_fn):
def run_test(odict):
expected_spec = gen_expected_fn(odict)
values, treespec = pytree_impl.tree_flatten(odict)
Expand All @@ -335,6 +335,50 @@ def run_test(odict):
od["a"] = torch.tensor(3.14)
run_test(od)

@parametrize(
"pytree_impl,gen_expected_fn",
[
subtest(
(
py_pytree,
lambda ddct: py_pytree.TreeSpec(
defaultdict,
[ddct.default_factory, list(ddct.keys())],
[py_pytree.LeafSpec() for _ in ddct.values()],
),
),
name="py",
),
subtest(
(
cxx_pytree,
lambda ddct: cxx_pytree.tree_structure(
defaultdict(ddct.default_factory, dict.fromkeys(ddct, 0))
),
),
name="cxx",
),
],
)
def test_flatten_unflatten_defaultdict(self, pytree_impl, gen_expected_fn):
def run_test(ddct):
expected_spec = gen_expected_fn(ddct)
values, treespec = pytree_impl.tree_flatten(ddct)
self.assertTrue(isinstance(values, list))
self.assertEqual(values, list(ddct.values()))
self.assertEqual(treespec, expected_spec)

unflattened = pytree_impl.tree_unflatten(values, treespec)
self.assertEqual(unflattened, ddct)
self.assertEqual(unflattened.default_factory, ddct.default_factory)
self.assertTrue(isinstance(unflattened, defaultdict))

run_test(defaultdict(list, {}))
run_test(defaultdict(int, {"a": 1}))
run_test(defaultdict(int, {"abcdefg": torch.randn(2, 3)}))
run_test(defaultdict(int, {1: torch.randn(2, 3)}))
run_test(defaultdict(int, {"a": 1, "b": 2, "c": torch.randn(2, 3)}))

@parametrize(
"pytree_impl",
[
Expand Down Expand Up @@ -612,29 +656,55 @@ def test_treespec_repr_dynamo(self):
@parametrize(
"spec",
[
# py_pytree.tree_structure([])
py_pytree.TreeSpec(list, None, []),
# py_pytree.tree_structure(())
py_pytree.TreeSpec(tuple, None, []),
# py_pytree.tree_structure({})
py_pytree.TreeSpec(dict, [], []),
# py_pytree.tree_structure([0])
py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
# py_pytree.tree_structure([0, 1])
py_pytree.TreeSpec(
list, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
list,
None,
[
py_pytree.LeafSpec(),
py_pytree.LeafSpec(),
],
),
# py_pytree.tree_structure((0, 1, 2))
py_pytree.TreeSpec(
tuple,
None,
[py_pytree.LeafSpec(), py_pytree.LeafSpec(), py_pytree.LeafSpec()],
[
py_pytree.LeafSpec(),
py_pytree.LeafSpec(),
py_pytree.LeafSpec(),
],
),
# py_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
py_pytree.TreeSpec(
dict,
["a", "b", "c"],
[py_pytree.LeafSpec(), py_pytree.LeafSpec(), py_pytree.LeafSpec()],
[
py_pytree.LeafSpec(),
py_pytree.LeafSpec(),
py_pytree.LeafSpec(),
],
),
# py_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
py_pytree.TreeSpec(
OrderedDict,
["a", "b", "c"],
[
py_pytree.TreeSpec(
tuple, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
tuple,
None,
[
py_pytree.LeafSpec(),
py_pytree.LeafSpec(),
],
),
py_pytree.LeafSpec(),
py_pytree.TreeSpec(
Expand All @@ -648,6 +718,7 @@ def test_treespec_repr_dynamo(self):
),
],
),
# py_pytree.tree_structure([(0, 1, [2, 3])])
py_pytree.TreeSpec(
list,
None,
Expand All @@ -670,12 +741,44 @@ def test_treespec_repr_dynamo(self):
),
],
),
# py_pytree.tree_structure(defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}}))
py_pytree.TreeSpec(
defaultdict,
[list, ["a", "b", "c"]],
[
py_pytree.TreeSpec(
list,
None,
[
py_pytree.LeafSpec(),
py_pytree.LeafSpec(),
],
),
py_pytree.TreeSpec(
list,
None,
[
py_pytree.LeafSpec(),
py_pytree.LeafSpec(),
],
),
py_pytree.TreeSpec(dict, [], []),
],
),
],
)
def test_pytree_serialize(self, spec):
# Ensure that the spec is valid
self.assertEqual(
spec,
py_pytree.tree_structure(
py_pytree.tree_unflatten([0] * spec.num_leaves, spec)
),
)

serialized_spec = py_pytree.treespec_dumps(spec)
self.assertTrue(isinstance(serialized_spec, str))
self.assertTrue(spec == py_pytree.treespec_loads(serialized_spec))
self.assertIsInstance(serialized_spec, str)
self.assertEqual(spec, py_pytree.treespec_loads(serialized_spec))

def test_pytree_serialize_namedtuple(self):
Point = namedtuple("Point", ["x", "y"])
Expand Down Expand Up @@ -791,6 +894,7 @@ def test_pytree_serialize_bad_protocol(self):
py_pytree.treespec_loads(bad_protocol_serialized_spec)

def test_saved_serialized(self):
# py_pytree.tree_structure(OrderedDict([(1, (0, 1)), (2, 2), (3, {4: 3, 5: 4, 6: 5})]))
complicated_spec = py_pytree.TreeSpec(
OrderedDict,
[1, 2, 3],
Expand All @@ -810,6 +914,15 @@ def test_saved_serialized(self):
),
],
)
# Ensure that the spec is valid
self.assertEqual(
complicated_spec,
py_pytree.tree_structure(
py_pytree.tree_unflatten(
[0] * complicated_spec.num_leaves, complicated_spec
)
),
)

serialized_spec = py_pytree.treespec_dumps(complicated_spec)
saved_spec = (
Expand Down Expand Up @@ -865,12 +978,22 @@ def test_treespec_repr_dynamo(self):
OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
),
cxx_pytree.tree_structure([(0, 1, [2, 3])]),
cxx_pytree.tree_structure(
defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}})
),
],
)
def test_pytree_serialize(self, spec):
self.assertEqual(
spec,
cxx_pytree.tree_structure(
cxx_pytree.tree_unflatten([0] * spec.num_leaves, spec)
),
)

serialized_spec = cxx_pytree.treespec_dumps(spec)
self.assertTrue(isinstance(serialized_spec, str))
self.assertTrue(spec == cxx_pytree.treespec_loads(serialized_spec))
self.assertIsInstance(serialized_spec, str)
self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec))

def test_pytree_serialize_namedtuple(self):
spec = cxx_pytree.tree_structure(GlobalPoint(0, 1))
Expand Down
Loading

0 comments on commit 2ab2e8e

Please sign in to comment.