Use pytree.tree_map_ everywhere (pytorch#112417)
Wherever we discard the output of `tree_map`, it's better to call `tree_map_`,
which doesn't unflatten the mapped results and so is a lot cheaper.
Pull Request resolved: pytorch#112417
Approved by: https://github.com/lezcano
ghstack dependencies: pytorch#112391, pytorch#112392, pytorch#112393, pytorch#112394
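
For context (not part of the original commit message), here is a minimal sketch of the difference, assuming only the public `tree_map`/`tree_map_` helpers from `torch.utils._pytree` that this diff already uses:

```python
# Illustrative sketch only (not from this commit). Both calls visit every
# leaf of the pytree, but `tree_map` also unflattens the per-leaf return
# values into a brand-new pytree, which is wasted work when the caller
# discards the result. `tree_map_` applies the function purely for its
# side effects and returns the original tree unchanged.
import torch
from torch.utils import _pytree as pytree

seen = []

def record(leaf):
    # Side effect only; the return value is irrelevant here.
    seen.append(leaf)

args = ([torch.randn(2), torch.randn(3)], {"w": torch.randn(4)})

pytree.tree_map(record, args)   # builds and immediately discards a pytree of Nones
pytree.tree_map_(record, args)  # just walks the leaves; returns `args` itself
```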
peterbell10 authored and pytorchmergebot committed Oct 31, 2023
1 parent 66c32d0 commit 0402492
Showing 11 changed files with 30 additions and 29 deletions.
4 changes: 2 additions & 2 deletions test/test_dynamic_shapes.py
@@ -36,7 +36,7 @@
     TestCase,
 )
 from torch.utils._python_dispatch import TorchDispatchMode
-from torch.utils._pytree import tree_map
+from torch.utils import _pytree as pytree
 from torch.utils._sympy.functions import FloorDiv, Mod

 aten = torch.ops.aten
@@ -48,7 +48,7 @@ def register_meta(op):
     def decorator(f):
         def add_func(op):
             meta_funcs[op] = f
-        tree_map(add_func, op)
+        pytree.tree_map_(add_func, op)
         return f
     return decorator

4 changes: 2 additions & 2 deletions torch/_decomp/__init__.py
@@ -8,7 +8,7 @@
 import torch.library
 from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket
 from torch._prims_common import CustomOutParamAnnotation
-from torch.utils._pytree import tree_map
+from torch.utils import _pytree as pytree

 __all__ = [
     "decomposition_table",
@@ -182,7 +182,7 @@ def register(op):
             _add_op_to_registry(registry, op, fn)

         # To handle allowing multiple aten_ops at once
-        tree_map(register, aten_op)
+        pytree.tree_map_(register, aten_op)
         return fn

     return decomposition_decorator
4 changes: 2 additions & 2 deletions torch/_meta_registrations.py
@@ -30,7 +30,7 @@
     out_wrapper,
 )
 from torch._refs import _broadcast_shapes, _maybe_broadcast
-from torch.utils._pytree import tree_map
+from torch.utils import _pytree as pytree


 aten = torch.ops.aten
@@ -45,7 +45,7 @@ def wrapper(fn):
         def register(op):
             _add_op_to_registry(meta_table, op, fn)

-        tree_map(register, op)
+        pytree.tree_map_(register, op)
         return fn

     return wrapper
4 changes: 2 additions & 2 deletions torch/_subclasses/fake_tensor.py
@@ -1253,8 +1253,8 @@ def merge_devices(t):
                 f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
             )

-        tree_map(merge_devices, args)
-        tree_map(merge_devices, kwargs)
+        pytree.tree_map_(merge_devices, args)
+        pytree.tree_map_(merge_devices, kwargs)

         # some functions that allow Python numbers to bind to Tensors
         # if we have failed to find a device, and we're running one of these operators,
6 changes: 3 additions & 3 deletions torch/cuda/_sanitizer.py
@@ -23,8 +23,8 @@

 import torch
 import torch.utils._cuda_trace as cuda_trace
+from torch.utils import _pytree as pytree
 from torch.utils._python_dispatch import TorchDispatchMode
-from torch.utils._pytree import tree_map


 DEFAULT_STREAM_ID = 0
@@ -509,15 +509,15 @@ def parse_inputs(
     ) -> None:
         for argument, value in zip_arguments(schema, args, kwargs):
             is_write = argument.alias_info is not None and argument.alias_info.is_write
-            tree_map(
+            pytree.tree_map_(
                 functools.partial(
                     self._handle_argument, is_write=is_write, name=argument.name
                 ),
                 value,
             )

     def parse_outputs(self, outputs: Any) -> None:
-        tree_map(
+        pytree.tree_map_(
             functools.partial(self._handle_argument, is_write=True, is_output=True),
             outputs,
         )
10 changes: 5 additions & 5 deletions torch/distributed/_shard/common_op_utils.py
@@ -1,5 +1,5 @@
 import torch
-from torch.utils._pytree import tree_map
+from torch.utils import _pytree as pytree
 from typing import Optional

 def _basic_validation(op, args=(), kwargs=None):
@@ -19,8 +19,8 @@ def is_distributed_tensor(e):
         if isinstance(e, ShardedTensor):
             has_distributed_tensor = True

-    tree_map(is_distributed_tensor, args)
-    tree_map(is_distributed_tensor, kwargs)
+    pytree.tree_map_(is_distributed_tensor, args)
+    pytree.tree_map_(is_distributed_tensor, kwargs)

     if not has_distributed_tensor:
         raise TypeError(
@@ -41,8 +41,8 @@ def validate_pg(e):
                 )
             cur_pg = e._process_group

-    tree_map(validate_pg, args)
-    tree_map(validate_pg, kwargs)
+    pytree.tree_map_(validate_pg, args)
+    pytree.tree_map_(validate_pg, kwargs)

 def _register_default_op(op, decorator):
     @decorator(op)
6 changes: 3 additions & 3 deletions torch/distributed/_shard/sharded_tensor/api.py
@@ -44,7 +44,7 @@
     build_global_metadata
 )
 from torch.distributed.remote_device import _remote_device
-from torch.utils._pytree import tree_map
+from torch.utils import _pytree as pytree

 # Tracking for sharded tensor objects.
 _sharded_tensor_lock = threading.Lock()
@@ -1137,8 +1137,8 @@ def find_sharded_tensor(e):
             if st_instance is None and isinstance(e, ShardedTensor):
                 st_instance = e

-        tree_map(find_sharded_tensor, args)
-        tree_map(find_sharded_tensor, kwargs)
+        pytree.tree_map_(find_sharded_tensor, args)
+        pytree.tree_map_(find_sharded_tensor, kwargs)

         if st_instance is not None:
             return dispatch(st_instance, func)
5 changes: 3 additions & 2 deletions torch/distributed/_spmd/comm_tensor.py
@@ -12,6 +12,7 @@
     set_proxy_slot,
     track_tensor_tree,
 )
+from torch.utils import _pytree as pytree
 from torch.utils._mode_utils import no_dispatch
 from torch.utils._pytree import tree_flatten, tree_map, tree_map_only

@@ -222,7 +223,7 @@ def set_work(work: torch.distributed._Work, e: Any):
                 # for it later to make sure the execution during tracing is
                 # correct. Also, remember comm is already launched
                 # args[0] is always the collection of output tensors
-                tree_map(partial(set_work, out[1]), args[0])
+                pytree.tree_map_(partial(set_work, out[1]), args[0])

                 # HACK: update the proxy on the input argument as this is an
                 # inplace collective communication.
@@ -235,7 +236,7 @@ def set_work(work: torch.distributed._Work, e: Any):
             else:
                 # in eager mode, simply remember work handle as an attribute
                 out = func(*unwrapped_args, **unwrapped_kwargs)
-                tree_map(partial(set_work, out[1]), args[0])
+                pytree.tree_map_(partial(set_work, out[1]), args[0])
                 return out
         else:
             if work is not None:
6 changes: 3 additions & 3 deletions torch/fx/passes/backends/cudagraphs.py
@@ -3,7 +3,7 @@
 from torch.fx.passes.operator_support import OperatorSupport
 from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
 from torch.fx.passes.fake_tensor_prop import FakeTensorProp
-from torch.utils._pytree import tree_map
+from torch.utils import _pytree as pytree

 import operator

@@ -30,9 +30,9 @@ def find_not_cuda(t):
             found_not_cuda = True

         for n in node.all_input_nodes:
-            tree_map(find_not_cuda, meta_fk(n.meta))
+            pytree.tree_map_(find_not_cuda, meta_fk(n.meta))

-        tree_map(find_not_cuda, meta_fk(node.meta))
+        pytree.tree_map_(find_not_cuda, meta_fk(node.meta))

         # NB: factory function is accounted for because the result would be
         # cpu or cuda
4 changes: 2 additions & 2 deletions torch/fx/passes/reinplace.py
@@ -2,7 +2,7 @@
 from torch.fx import Node
 from torch.fx._compatibility import compatibility
 from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
-from torch.utils._pytree import tree_map, tree_map_only
+from torch.utils._pytree import tree_map_only
 from torch.utils import _pytree as pytree
 from torch.multiprocessing.reductions import StorageWeakRef

@@ -483,7 +483,7 @@ def f(x):
         def _add_to_map(x):
             if isinstance(x, FakeTensor):
                 storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n)
-        tree_map(_add_to_map, n.meta['fake_result'])
+        pytree.tree_map_(_add_to_map, n.meta['fake_result'])

     # inplace-ify functional ops, subject to the constraints written below.
     all_later_view_inverse_nodes_to_delete = set()
6 changes: 3 additions & 3 deletions torch/testing/_internal/composite_compliance.py
@@ -238,9 +238,9 @@ def wrap(e):
             # have consistent metadata. If they don't have consistent metadata,
             # that means the operator did something fishy.
             check = partial(check_metadata_consistency, CCT=CompositeCompliantTensor)
-            tree_map(check, args)
-            tree_map(check, kwargs)
-            tree_map(check, rs)
+            pytree.tree_map_(check, args)
+            pytree.tree_map_(check, kwargs)
+            pytree.tree_map_(check, rs)
             return rs

     return CompositeCompliantTensor, CompositeCompliantTensorMode()
