diff --git a/test/test_pytree.py b/test/test_pytree.py index 76a16c827c7c5..d546e33d55175 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -3,7 +3,7 @@ import inspect import re import unittest -from collections import namedtuple, OrderedDict, UserDict +from collections import defaultdict, namedtuple, OrderedDict, UserDict import torch import torch.utils._cxx_pytree as cxx_pytree @@ -187,33 +187,34 @@ def run_test_with_leaf(leaf): subtest( ( py_pytree, - lambda lst: py_pytree.TreeSpec( - list, None, [py_pytree.LeafSpec() for _ in lst] + lambda tup: py_pytree.TreeSpec( + tuple, None, [py_pytree.LeafSpec() for _ in tup] ), ), name="py", ), subtest( - (cxx_pytree, lambda lst: cxx_pytree.tree_structure([0] * len(lst))), + (cxx_pytree, lambda tup: cxx_pytree.tree_structure((0,) * len(tup))), name="cxx", ), ], ) - def test_flatten_unflatten_list(self, pytree_impl, gen_expected_fn): - def run_test(lst): - expected_spec = gen_expected_fn(lst) - values, treespec = pytree_impl.tree_flatten(lst) + def test_flatten_unflatten_tuple(self, pytree_impl, gen_expected_fn): + def run_test(tup): + expected_spec = gen_expected_fn(tup) + values, treespec = pytree_impl.tree_flatten(tup) self.assertTrue(isinstance(values, list)) - self.assertEqual(values, lst) + self.assertEqual(values, list(tup)) self.assertEqual(treespec, expected_spec) unflattened = pytree_impl.tree_unflatten(values, treespec) - self.assertEqual(unflattened, lst) - self.assertTrue(isinstance(unflattened, list)) + self.assertEqual(unflattened, tup) + self.assertTrue(isinstance(unflattened, tuple)) - run_test([]) - run_test([1.0, 2]) - run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11]) + run_test(()) + run_test((1.0,)) + run_test((1.0, 2)) + run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11)) @parametrize( "pytree_impl,gen_expected_fn", @@ -221,34 +222,33 @@ def run_test(lst): subtest( ( py_pytree, - lambda tup: py_pytree.TreeSpec( - tuple, None, [py_pytree.LeafSpec() for _ in tup] + lambda lst: py_pytree.TreeSpec( + list, None, [py_pytree.LeafSpec() for _ in lst] ), ), name="py", ), subtest( - (cxx_pytree, lambda tup: cxx_pytree.tree_structure((0,) * len(tup))), + (cxx_pytree, lambda lst: cxx_pytree.tree_structure([0] * len(lst))), name="cxx", ), ], ) - def test_flatten_unflatten_tuple(self, pytree_impl, gen_expected_fn): - def run_test(tup): - expected_spec = gen_expected_fn(tup) - values, treespec = pytree_impl.tree_flatten(tup) + def test_flatten_unflatten_list(self, pytree_impl, gen_expected_fn): + def run_test(lst): + expected_spec = gen_expected_fn(lst) + values, treespec = pytree_impl.tree_flatten(lst) self.assertTrue(isinstance(values, list)) - self.assertEqual(values, list(tup)) + self.assertEqual(values, lst) self.assertEqual(treespec, expected_spec) unflattened = pytree_impl.tree_unflatten(values, treespec) - self.assertEqual(unflattened, tup) - self.assertTrue(isinstance(unflattened, tuple)) + self.assertEqual(unflattened, lst) + self.assertTrue(isinstance(unflattened, list)) - run_test(()) - run_test((1.0,)) - run_test((1.0, 2)) - run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11)) + run_test([]) + run_test([1.0, 2]) + run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11]) @parametrize( "pytree_impl,gen_expected_fn", @@ -316,7 +316,7 @@ def run_test(dct): ), ], ) - def test_flatten_unflatten_odict(self, pytree_impl, gen_expected_fn): + def test_flatten_unflatten_ordereddict(self, pytree_impl, gen_expected_fn): def run_test(odict): expected_spec = gen_expected_fn(odict) values, treespec = pytree_impl.tree_flatten(odict) @@ -335,6 +335,50 @@ def run_test(odict): od["a"] = torch.tensor(3.14) run_test(od) + @parametrize( + "pytree_impl,gen_expected_fn", + [ + subtest( + ( + py_pytree, + lambda ddct: py_pytree.TreeSpec( + defaultdict, + [ddct.default_factory, list(ddct.keys())], + [py_pytree.LeafSpec() for _ in ddct.values()], + ), + ), + name="py", + ), + subtest( + ( + cxx_pytree, + lambda ddct: cxx_pytree.tree_structure( + defaultdict(ddct.default_factory, dict.fromkeys(ddct, 0)) + ), + ), + name="cxx", + ), + ], + ) + def test_flatten_unflatten_defaultdict(self, pytree_impl, gen_expected_fn): + def run_test(ddct): + expected_spec = gen_expected_fn(ddct) + values, treespec = pytree_impl.tree_flatten(ddct) + self.assertTrue(isinstance(values, list)) + self.assertEqual(values, list(ddct.values())) + self.assertEqual(treespec, expected_spec) + + unflattened = pytree_impl.tree_unflatten(values, treespec) + self.assertEqual(unflattened, ddct) + self.assertEqual(unflattened.default_factory, ddct.default_factory) + self.assertTrue(isinstance(unflattened, defaultdict)) + + run_test(defaultdict(list, {})) + run_test(defaultdict(int, {"a": 1})) + run_test(defaultdict(int, {"abcdefg": torch.randn(2, 3)})) + run_test(defaultdict(int, {1: torch.randn(2, 3)})) + run_test(defaultdict(int, {"a": 1, "b": 2, "c": torch.randn(2, 3)})) + @parametrize( "pytree_impl", [ @@ -612,29 +656,55 @@ def test_treespec_repr_dynamo(self): @parametrize( "spec", [ + # py_pytree.tree_structure([]) py_pytree.TreeSpec(list, None, []), + # py_pytree.tree_structure(()) py_pytree.TreeSpec(tuple, None, []), + # py_pytree.tree_structure({}) py_pytree.TreeSpec(dict, [], []), + # py_pytree.tree_structure([0]) py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]), + # py_pytree.tree_structure([0, 1]) py_pytree.TreeSpec( - list, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] + list, + None, + [ + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + ], ), + # py_pytree.tree_structure((0, 1, 2)) py_pytree.TreeSpec( tuple, None, - [py_pytree.LeafSpec(), py_pytree.LeafSpec(), py_pytree.LeafSpec()], + [ + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + ], ), + # py_pytree.tree_structure({"a": 0, "b": 1, "c": 2}) py_pytree.TreeSpec( dict, ["a", "b", "c"], - [py_pytree.LeafSpec(), py_pytree.LeafSpec(), py_pytree.LeafSpec()], + [ + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + ], ), + # py_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})]) py_pytree.TreeSpec( OrderedDict, ["a", "b", "c"], [ py_pytree.TreeSpec( - tuple, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] + tuple, + None, + [ + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + ], ), py_pytree.LeafSpec(), py_pytree.TreeSpec( @@ -648,6 +718,7 @@ def test_treespec_repr_dynamo(self): ), ], ), + # py_pytree.tree_structure([(0, 1, [2, 3])]) py_pytree.TreeSpec( list, None, @@ -670,12 +741,44 @@ def test_treespec_repr_dynamo(self): ), ], ), + # py_pytree.tree_structure(defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}})) + py_pytree.TreeSpec( + defaultdict, + [list, ["a", "b", "c"]], + [ + py_pytree.TreeSpec( + list, + None, + [ + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + ], + ), + py_pytree.TreeSpec( + list, + None, + [ + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + ], + ), + py_pytree.TreeSpec(dict, [], []), + ], + ), ], ) def test_pytree_serialize(self, spec): + # Ensure that the spec is valid + self.assertEqual( + spec, + py_pytree.tree_structure( + py_pytree.tree_unflatten([0] * spec.num_leaves, spec) + ), + ) + serialized_spec = py_pytree.treespec_dumps(spec) - self.assertTrue(isinstance(serialized_spec, str)) - self.assertTrue(spec == py_pytree.treespec_loads(serialized_spec)) + self.assertIsInstance(serialized_spec, str) + self.assertEqual(spec, py_pytree.treespec_loads(serialized_spec)) def test_pytree_serialize_namedtuple(self): Point = namedtuple("Point", ["x", "y"]) @@ -791,6 +894,7 @@ def test_pytree_serialize_bad_protocol(self): py_pytree.treespec_loads(bad_protocol_serialized_spec) def test_saved_serialized(self): + # py_pytree.tree_structure(OrderedDict([(1, (0, 1)), (2, 2), (3, {4: 3, 5: 4, 6: 5})])) complicated_spec = py_pytree.TreeSpec( OrderedDict, [1, 2, 3], @@ -810,6 +914,15 @@ def test_saved_serialized(self): ), ], ) + # Ensure that the spec is valid + self.assertEqual( + complicated_spec, + py_pytree.tree_structure( + py_pytree.tree_unflatten( + [0] * complicated_spec.num_leaves, complicated_spec + ) + ), + ) serialized_spec = py_pytree.treespec_dumps(complicated_spec) saved_spec = ( @@ -865,12 +978,22 @@ def test_treespec_repr_dynamo(self): OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})]) ), cxx_pytree.tree_structure([(0, 1, [2, 3])]), + cxx_pytree.tree_structure( + defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}}) + ), ], ) def test_pytree_serialize(self, spec): + self.assertEqual( + spec, + cxx_pytree.tree_structure( + cxx_pytree.tree_unflatten([0] * spec.num_leaves, spec) + ), + ) + serialized_spec = cxx_pytree.treespec_dumps(spec) - self.assertTrue(isinstance(serialized_spec, str)) - self.assertTrue(spec == cxx_pytree.treespec_loads(serialized_spec)) + self.assertIsInstance(serialized_spec, str) + self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec)) def test_pytree_serialize_namedtuple(self): spec = cxx_pytree.tree_structure(GlobalPoint(0, 1)) diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index d884f2e0668d4..4d3e51b990354 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -16,14 +16,16 @@ """ import dataclasses +import importlib import json import threading import warnings -from collections import deque, namedtuple, OrderedDict +from collections import defaultdict, deque, namedtuple, OrderedDict from typing import ( Any, Callable, cast, + DefaultDict, Dict, Iterable, List, @@ -326,22 +328,68 @@ def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context: return context -def _odict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Context]: +def _ordereddict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Context]: return list(d.values()), list(d.keys()) -def _odict_unflatten( +def _ordereddict_unflatten( values: Iterable[Any], context: Context, ) -> GenericOrderedDict[Any, Any]: return OrderedDict((key, value) for key, value in zip(context, values)) +_odict_flatten = _ordereddict_flatten +_odict_unflatten = _ordereddict_unflatten + + +def _defaultdict_flatten(d: DefaultDict[Any, Any]) -> Tuple[List[Any], Context]: + values, dict_context = _dict_flatten(d) + return values, [d.default_factory, dict_context] + + +def _defaultdict_unflatten( + values: Iterable[Any], + context: Context, +) -> DefaultDict[Any, Any]: + default_factory, dict_context = context + return defaultdict(default_factory, _dict_unflatten(values, dict_context)) + + +def _defaultdict_serialize(context: Context) -> DumpableContext: + default_factory, dict_context = context + json_defaultdict = { + "default_factory_module": default_factory.__module__, + "default_factory_name": default_factory.__qualname__, + "dict_context": dict_context, + } + return json_defaultdict + + +def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context: + assert isinstance(dumpable_context, dict) + assert set(dumpable_context) == { + "default_factory_module", + "default_factory_name", + "dict_context", + } + + default_factory_module = dumpable_context["default_factory_module"] + default_factory_name = dumpable_context["default_factory_name"] + assert isinstance(default_factory_module, str) + assert isinstance(default_factory_name, str) + module = importlib.import_module(default_factory_module) + default_factory = getattr(module, default_factory_name) + + dict_context = dumpable_context["dict_context"] + return [default_factory, dict_context] + + _private_register_pytree_node( - dict, - _dict_flatten, - _dict_unflatten, - serialized_type_name="builtins.dict", + tuple, + _tuple_flatten, + _tuple_unflatten, + serialized_type_name="builtins.tuple", ) _private_register_pytree_node( list, @@ -350,10 +398,10 @@ def _odict_unflatten( serialized_type_name="builtins.list", ) _private_register_pytree_node( - tuple, - _tuple_flatten, - _tuple_unflatten, - serialized_type_name="builtins.tuple", + dict, + _dict_flatten, + _dict_unflatten, + serialized_type_name="builtins.dict", ) _private_register_pytree_node( namedtuple, # type: ignore[arg-type] @@ -365,10 +413,18 @@ def _odict_unflatten( ) _private_register_pytree_node( OrderedDict, - _odict_flatten, - _odict_unflatten, + _ordereddict_flatten, + _ordereddict_unflatten, serialized_type_name="collections.OrderedDict", ) +_private_register_pytree_node( + defaultdict, + _defaultdict_flatten, + _defaultdict_unflatten, + serialized_type_name="collections.defaultdict", + to_dumpable_context=_defaultdict_serialize, + from_dumpable_context=_defaultdict_deserialize, +) # h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple