
Commit

Use pytree.arg_tree_leaves everywhere (pytorch#112394)
Pull Request resolved: pytorch#112394
Approved by: https://github.com/lezcano
ghstack dependencies: pytorch#112391, pytorch#112392, pytorch#112393
peterbell10 authored and pytorchmergebot committed Oct 31, 2023
1 parent 046c0c6 commit 66c32d0
Showing 40 changed files with 72 additions and 77 deletions.
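The change is mechanical: call sites that flattened function arguments with `pytree.tree_leaves(args) + pytree.tree_leaves(kwargs)` or `pytree.tree_leaves((args, kwargs))` now make a single `pytree.arg_tree_leaves(*args, **kwargs)` call. A minimal sketch of the equivalence is below; it assumes a PyTorch build that already contains `arg_tree_leaves` (introduced by the earlier PRs in this stack), and `arg_tree_leaves_sketch` is an illustrative stand-in, not the actual implementation.

```python
import torch.utils._pytree as pytree

def arg_tree_leaves_sketch(*args, **kwargs):
    # Illustrative stand-in: flatten each positional and keyword argument in
    # call order and concatenate the leaves.
    leaves = []
    for a in args:
        leaves.extend(pytree.tree_leaves(a))
    for v in kwargs.values():
        leaves.extend(pytree.tree_leaves(v))
    return leaves

args, kwargs = ([1, (2, 3)],), {"scale": {"gamma": 4}}

old_style = pytree.tree_leaves(args) + pytree.tree_leaves(kwargs)  # pre-commit idiom
new_style = pytree.arg_tree_leaves(*args, **kwargs)                # post-commit idiom

assert old_style == new_style == arg_tree_leaves_sketch(*args, **kwargs) == [1, 2, 3, 4]
```

The `pytree.tree_leaves((args, kwargs))` form seen in several hunks below should flatten to the same leaves in the same order, which is why the replacement is behaviour-preserving.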
2 changes: 1 addition & 1 deletion test/distributed/_tensor/test_view_ops.py
@@ -133,7 +133,7 @@ def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh):
 spec = ops[op]
 rules = spec.dim_map(*args, **kwargs)
 outputs = op(*args, **kwargs)
-flat_args = pytree.tree_leaves(args)
+flat_args = pytree.arg_tree_leaves(*args)
 in_shape = flat_args[0].shape

 no_shard_dims = set()
6 changes: 2 additions & 4 deletions test/inductor/test_torchinductor.py
@@ -187,9 +187,7 @@ def int(self):

 def compute_grads(args, kwrags, results, grads):
 def gather_leaf_tensors(args, kwargs):
-args = pytree.tree_leaves(args)
-kwargs = pytree.tree_leaves(kwargs)
-args = args + kwargs
+args = pytree.arg_tree_leaves(*args, **kwargs)
 leaf_tensors = [
 arg for arg in args if isinstance(arg, torch.Tensor) and arg.requires_grad
 ]
@@ -7652,7 +7650,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
 nonlocal max_live_tensors

 kwargs = kwargs if kwargs else {}
-for arg in pytree.tree_leaves((args, kwargs)):
+for arg in pytree.arg_tree_leaves(*args, **kwargs):
 if isinstance(arg, torch.Tensor):
 live_tensors[arg] = True

5 changes: 2 additions & 3 deletions test/test_decomp.py
@@ -428,9 +428,8 @@ def test_unsupported(t):
 else:
 return False

-flat_args = pytree.tree_leaves(args)
-flat_kwargs = pytree.tree_leaves(kwargs)
-return any(test_unsupported(x) for x in itertools.chain(flat_args, flat_kwargs))
+flat_args = pytree.arg_tree_leaves(*args, **kwargs)
+return any(test_unsupported(x) for x in flat_args)


 core_backward_failures = {
2 changes: 1 addition & 1 deletion test/test_ops.py
@@ -2114,7 +2114,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):

 if torch.Tag.pointwise in func.tags:
 shapes = []
-for inp in pytree.tree_leaves((args, kwargs)):
+for inp in pytree.arg_tree_leaves(*args, **kwargs):
 if isinstance(inp, torch.Tensor):
 shapes.append(inp.shape)

2 changes: 1 addition & 1 deletion test/test_python_dispatch.py
@@ -2196,7 +2196,7 @@ def to_subclass(t: torch.Tensor):

 result_test = op(*args_subclass, **kwargs_subclass)

-args_ref_flat = pytree.tree_leaves((args, kwargs))
+args_ref_flat = pytree.arg_tree_leaves(*args, **kwargs)
 args_ref_flat_tensors = [x for x in args_ref_flat if isinstance(x, torch.Tensor)]

 args_test_flat = pytree.tree_leaves((args_subclass, kwargs_subclass))
2 changes: 1 addition & 1 deletion torch/_decomp/decompositions.py
@@ -50,7 +50,7 @@ def type_casts(
 @functools.wraps(f)
 def inner(*args, **kwargs):
 flat_args = [
-x for x in pytree.tree_leaves((args, kwargs)) if isinstance(x, Tensor)
+x for x in pytree.arg_tree_leaves(*args, **kwargs) if isinstance(x, Tensor)
 ]
 computation_dtype, result_dtype = utils.elementwise_dtypes(
 *flat_args, type_promotion_kind=type_promotion
2 changes: 1 addition & 1 deletion torch/_dynamo/output_graph.py
@@ -1412,7 +1412,7 @@ def create_node(
 self, op, target, args=None, kwargs=None, name=None, type_expr=None
 ):
 if self.parent is not None:
-flat_args = pytree.tree_leaves((args, kwargs))
+flat_args = pytree.arg_tree_leaves(*args, **kwargs)
 for arg in flat_args:
 if not isinstance(arg, torch.fx.Node):
 continue
4 changes: 2 additions & 2 deletions torch/_dynamo/variables/torch_function.py
@@ -83,7 +83,7 @@ def build_torch_function_fn(tx, value, source):

 def can_dispatch_torch_function(tx, args, kwargs):
 if tx.output.torch_function_enabled:
-all_args = pytree.tree_leaves(args) + pytree.tree_leaves(kwargs)
+all_args = pytree.arg_tree_leaves(*args, **kwargs)
 return any(isinstance(arg, TensorWithTFOverrideVariable) for arg in all_args)
 else:
 return False
@@ -92,7 +92,7 @@ def can_dispatch_torch_function(tx, args, kwargs):
 def dispatch_torch_function(tx, fn, args, kwargs):
 """Gathers all args that are TensorWithTFOverrideVariable and dispatches based on the ordering in _get_overloaded_args"""

-all_args = pytree.tree_leaves(args) + pytree.tree_leaves(kwargs)
+all_args = pytree.arg_tree_leaves(*args, **kwargs)
 overloaded_args = _get_overloaded_args(
 [arg for arg in all_args if isinstance(arg, TensorWithTFOverrideVariable)],
 lambda x: x.class_type,
2 changes: 1 addition & 1 deletion torch/_export/__init__.py
@@ -947,7 +947,7 @@ def aot_compile(
 # We want to export to Torch IR here to utilize the pre_grad passes in
 # inductor, which run on Torch IR.
 gm = _export_to_torch_ir(f, args, kwargs, constraints)
-flat_example_inputs = pytree.tree_leaves(combine_args_kwargs(args, kwargs))
+flat_example_inputs = pytree.arg_tree_leaves(*args, **kwargs or {})

 with torch.no_grad():
 so_path = torch._inductor.aot_compile(gm, flat_example_inputs, options)  # type: ignore[arg-type]
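One detail worth noting in the `aot_compile` hunk above: `kwargs` can be `None` at that call site, so the new call spells it `pytree.arg_tree_leaves(*args, **kwargs or {})`. A small illustration with hypothetical values:

```python
import torch.utils._pytree as pytree

def flatten_example_inputs(args, kwargs=None):
    # `kwargs or {}` guards against kwargs being None, mirroring the call site above.
    return pytree.arg_tree_leaves(*args, **kwargs or {})

assert flatten_example_inputs((1, [2, 3])) == [1, 2, 3]
assert flatten_example_inputs((1,), {"x": 2}) == [1, 2]
```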
6 changes: 3 additions & 3 deletions torch/_functorch/aot_autograd.py
@@ -973,7 +973,7 @@ def _get_hints(exprs):
 return exprs

 def requires_subclass_dispatch(args, fw_metadata: ViewAndMutationMeta) -> bool:
-args_flattened = pytree.tree_leaves(args)
+args_flattened = pytree.arg_tree_leaves(*args)
 any_subclass_args = any(is_traceable_wrapper_subclass(x) for x in args_flattened if isinstance(x, Tensor))
 any_subclass_outputs = any(is_traceable_wrapper_subclass(x) for x in fw_metadata.traced_tangents if isinstance(x, Tensor))
 # This tells us whether or not we need to perform any unwrapping/wrapping of tensor subclasses at runtime.
@@ -4623,7 +4623,7 @@ def aot_function(
 def returned_function(*args, **kwargs):
 nonlocal cached_res
 # Now flatten the tensor args
-flat_args = pytree.tree_leaves((args, kwargs))
+flat_args = pytree.arg_tree_leaves(*args, **kwargs)

 # Compile the function and save it in the cache
 if cached_res is None:
@@ -4973,7 +4973,7 @@ def flattened_joint(*args):
 return *fw_outs, *output_gradients
 fx_g = make_fx(flattened_joint)(*full_args)

-user_args_flat = pytree.tree_leaves(args)
+user_args_flat = pytree.arg_tree_leaves(*args)
 return fx_g, create_graph_signature(
 fx_g,
 metadata,
6 changes: 3 additions & 3 deletions torch/_functorch/autograd_function.py
@@ -162,8 +162,8 @@ def jvp(ctx, *tangents):
 # Mode-only functorch will greatly simplify this logic.
 def wrap_outputs_maintaining_identity(
 outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS):
-flat_unwrapped_inputs = pytree.tree_leaves(unwrapped_inputs)
-flat_orig_inputs = pytree.tree_leaves(orig_inputs)
+flat_unwrapped_inputs = pytree.arg_tree_leaves(*unwrapped_inputs)
+flat_orig_inputs = pytree.arg_tree_leaves(*orig_inputs)

 unwrapped_input_to_orig_input = {
 id(unwrapped): orig
@@ -451,7 +451,7 @@ def get_out_dims():
 # the corresponding in_dims with None.
 def get_tangents_in_dims(input_dims, tangents):
 flat_in_dims, spec = pytree.tree_flatten(input_dims)
-flat_tangents = pytree.tree_leaves(tangents)
+flat_tangents = pytree.arg_tree_leaves(*tangents)
 result = [None if tangent is None else in_dim
 for in_dim, tangent in zip(flat_in_dims, flat_tangents)]
 return pytree.tree_unflatten(result, spec)
8 changes: 4 additions & 4 deletions torch/_functorch/eager_transforms.py
@@ -1513,10 +1513,10 @@ def wrapped(*args, **kwargs):
 func_args = _wrap_all_tensors_to_functional(args, func_level)
 func_kwargs = _wrap_all_tensors_to_functional(kwargs, func_level)

-flattened_unwrapped_args = pytree.tree_leaves(args)
-flattened_wrapped_args = pytree.tree_leaves(func_args)
-flattened_unwrapped_kwargs = pytree.tree_leaves(kwargs)
-flattened_wrapped_kwargs = pytree.tree_leaves(func_kwargs)
+flattened_unwrapped_args = pytree.arg_tree_leaves(*args)
+flattened_wrapped_args = pytree.arg_tree_leaves(*func_args)
+flattened_unwrapped_kwargs = pytree.arg_tree_leaves(**kwargs)
+flattened_wrapped_kwargs = pytree.arg_tree_leaves(**func_kwargs)

 func_outputs = func(*func_args, **func_kwargs)
 outputs = _unwrap_all_tensors_from_functional(func_outputs, reapply_views=reapply_views)
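As the `eager_transforms.py` hunk above shows, the helper also accepts keyword arguments alone (`arg_tree_leaves(**kwargs)`), flattening just the keyword values. A tiny illustration with hypothetical values:

```python
import torch.utils._pytree as pytree

# Only keyword arguments: the leaves of each value are concatenated in order.
kwargs = {"weight": [1.0, 2.0], "bias": 3.0}
assert pytree.arg_tree_leaves(**kwargs) == [1.0, 2.0, 3.0]
```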
4 changes: 2 additions & 2 deletions torch/_functorch/partitioners.py
@@ -79,7 +79,7 @@ def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs):
 elif node.op == 'placeholder':
 env[node] = InvalidNode
 elif node.op == 'call_function':
-all_args = pytree.tree_leaves((node.args, node.kwargs))
+all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs)
 all_args = [isinstance(env[x], InvalidNodeBase) for x in all_args if isinstance(x, fx.Node)]
 if any(all_args):
 env[node] = InvalidNode
@@ -124,7 +124,7 @@ def _is_fwd_seed_offset(node):


 def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs):
-outputs = pytree.tree_leaves([node.args for node in joint_module.graph.nodes if node.op == 'output'])
+outputs = pytree.arg_tree_leaves(*(node.args for node in joint_module.graph.nodes if node.op == 'output'))
 fwd_outputs = outputs[:num_fwd_outputs]
 bwd_outputs = outputs[num_fwd_outputs:]
 return fwd_outputs, bwd_outputs
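In the `_extract_fwd_bwd_outputs` hunk above, the old code flattened a list of each output node's `args` tuple; the new code unpacks a generator of those tuples as positional arguments, which should yield the same flat list. A sketch with hypothetical node args:

```python
import torch.utils._pytree as pytree

# Hypothetical stand-in for [node.args for node in graph.nodes if node.op == 'output']
output_args = [(10, 11), (12,)]

old = pytree.tree_leaves(output_args)                    # flatten the list of tuples
new = pytree.arg_tree_leaves(*(a for a in output_args))  # unpack each tuple as a positional arg
assert old == new == [10, 11, 12]
```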
4 changes: 2 additions & 2 deletions torch/_higher_order_ops/cond.py
@@ -195,8 +195,8 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
 if node.op == "output":
 false_outs.extend(node.args)

-flat_true_outs = pytree.tree_leaves(true_outs)
-flat_false_outs = pytree.tree_leaves(false_outs)
+flat_true_outs = pytree.arg_tree_leaves(*true_outs)
+flat_false_outs = pytree.arg_tree_leaves(*false_outs)
 if len(flat_true_outs) != len(flat_false_outs):
 raise torch._dynamo.exc.CondOpArgsMismatchError(
 f"Expected to return same number of outputs but got:"
2 changes: 1 addition & 1 deletion torch/_higher_order_ops/out_dtype.py
@@ -127,7 +127,7 @@ def is_int_mm(op, output_dtype, args):


 def out_dtype_fallback(op, output_dtype, *args):
-flat_inputs = pytree.tree_leaves(args) + [torch.ones(1, dtype=output_dtype)]
+flat_inputs = pytree.arg_tree_leaves(*args) + [torch.ones(1, dtype=output_dtype)]
 promote_dtype: torch.dtype = elementwise_dtypes(
 *flat_inputs,
 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
2 changes: 1 addition & 1 deletion torch/_higher_order_ops/utils.py
@@ -23,7 +23,7 @@ def autograd_not_implemented_inner(
 """
 with torch._C._AutoDispatchBelowAutograd():
 result = operator(*args, **kwargs)
-flat_operands = pytree.tree_leaves(args)
+flat_operands = pytree.arg_tree_leaves(*args)
 if torch.is_grad_enabled() and any(
 f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
 ):
4 changes: 2 additions & 2 deletions torch/_inductor/compile_fx.py
@@ -1120,7 +1120,7 @@ def fw_compiler_base(
 if config.keep_output_stride:
 *_, model_outputs_node = model.graph.nodes
 assert model_outputs_node.op == "output"
-model_outputs = pytree.tree_leaves(model_outputs_node.args)
+model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
 num_model_outputs = len(model_outputs)

 context = torch._guards.TracingContext.get()
@@ -1343,7 +1343,7 @@ def forward(self, *args):
 @functools.wraps(compiled_fn)
 def wrapper(*args):
 # note this doesn't check the spec, assuming it is the same
-return compiled_fn(*pytree.tree_leaves(args))
+return compiled_fn(*pytree.arg_tree_leaves(*args))

 return wrapper

4 changes: 2 additions & 2 deletions torch/_inductor/constant_folding.py
@@ -97,7 +97,7 @@ def set_env(arg):
 return super().run_node(node)

 args, kwargs = self.fetch_args_kwargs_from_env(node)
-flattened_inputs = pytree.tree_leaves((args, kwargs))
+flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)

 if self.unknown_value in flattened_inputs:
 return self.unknown_value
@@ -137,7 +137,7 @@ def set_env(arg):

 self.add_node_replacement(node, out)

-flattened_node_inps = pytree.tree_leaves((node.args, node.kwargs))
+flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)

 for n in flattened_node_inps:
 if not isinstance(n, torch.fx.Node):
4 changes: 3 additions & 1 deletion torch/_inductor/freezing.py
@@ -124,7 +124,9 @@ def __init__(self, elem, name: Optional[str], mod):
 @classmethod
 def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
 erased_tensors = [
-e for e in pytree.tree_leaves((args, kwargs)) if isinstance(e, ErasedTensor)
+e
+for e in pytree.arg_tree_leaves(*args, **kwargs)
+if isinstance(e, ErasedTensor)
 ]
 assert len(erased_tensors) > 0
 e = erased_tensors[0]
4 changes: 3 additions & 1 deletion torch/_inductor/fx_utils.py
@@ -171,6 +171,8 @@ def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, Tuple[Any], Dict[str,
 First value returns a boolean if any of the input nodes don't have a faketensor.
 """
 args, kwargs = tree_map(get_fake, (x.args, x.kwargs))
-if any(isinstance(a, torch.fx.Node) for a in pytree.tree_leaves((args, kwargs))):
+if any(
+isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs)
+):
 return False, args, kwargs
 return True, args, kwargs
2 changes: 1 addition & 1 deletion torch/_inductor/lowering.py
@@ -1653,7 +1653,7 @@ def check_skip_condition(node, parent, is_output):
 return False

 # only skip codegen if there is a cpu output, not input
-for arg in pytree.tree_leaves((node.args, node.kwargs)):
+for arg in pytree.arg_tree_leaves(*node.args, **node.kwargs):
 if check_skip_condition(arg, node, is_output=False):
 return True

5 changes: 1 addition & 4 deletions torch/_ops.py
@@ -376,10 +376,7 @@ def name(self):


 def _to_flat_tuple(args, kwargs):
-flat_args = torch.utils._pytree.tree_leaves(args)
-flat_kwargs = torch.utils._pytree.tree_leaves(kwargs)
-flat_all = flat_args + flat_kwargs
-return flat_all
+return torch.utils._pytree.arg_tree_leaves(*args, **kwargs)


 def _compute_keyset(args, kwargs, non_fallthrough_keys):
2 changes: 1 addition & 1 deletion torch/_prims_common/wrappers.py
@@ -118,7 +118,7 @@ def _fn(*args, **kwargs):
 if x in bound.arguments.keys()
 )

-flattened_type_promoting_args = pytree.tree_leaves(type_promoting_args)
+flattened_type_promoting_args = pytree.arg_tree_leaves(*type_promoting_args)
 compute_dtype, result_dtype = utils.elementwise_dtypes(
 *flattened_type_promoting_args,
 type_promotion_kind=self.type_promotion_kind,
6 changes: 2 additions & 4 deletions torch/_subclasses/fake_tensor.py
@@ -1177,9 +1177,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
 return NotImplemented

 fake_mode = None
-for arg in itertools.chain(
-pytree.tree_leaves(args), pytree.tree_leaves(kwargs)
-):
+for arg in pytree.arg_tree_leaves(*args, **kwargs):
 if isinstance(arg, FakeTensor):
 fake_mode = arg.fake_mode
 break
@@ -1900,7 +1898,7 @@ def to_real_tensor(e):
 tensor_impls = set()
 storages = set()

-for e in pytree.tree_leaves((args, kwargs)):
+for e in pytree.arg_tree_leaves(*args, **kwargs):
 if isinstance(e, torch.Tensor):
 if not e.is_sparse:
 storages.add(e._typed_storage()._cdata)
4 changes: 2 additions & 2 deletions torch/_subclasses/functional_tensor.py
@@ -372,8 +372,8 @@ def inner(*args, **kwargs):
 func_args = pytree.tree_map_only(torch.Tensor, to_fun, args)
 func_kwargs = pytree.tree_map_only(torch.Tensor, to_fun, kwargs)

-flattened_wrapped_args = pytree.tree_leaves(func_args)
-flattened_wrapped_kwargs = pytree.tree_leaves(func_kwargs)
+flattened_wrapped_args = pytree.arg_tree_leaves(*func_args)
+flattened_wrapped_kwargs = pytree.arg_tree_leaves(**func_kwargs)

 disable_above = torch._C._ExcludeDispatchKeyGuard(
 torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
4 changes: 2 additions & 2 deletions torch/cuda/graphs.py
@@ -293,7 +293,7 @@ def make_graphed_callables(
 + ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have "
 + "``requires_grad=False``."
 )
-flatten_arg = _pytree.tree_leaves(args)
+flatten_arg = _pytree.arg_tree_leaves(*args)
 flatten_sample_args.append(tuple(flatten_arg))
 assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
 "In the beta API, sample_args "
@@ -443,7 +443,7 @@ def functionalized(*user_args):
 # Runs the autograd function with inputs == all inputs to the graph that might require grad
 # (explicit user args + module parameters)
 # Assumes module params didn't change since capture.
-flatten_user_args = _pytree.tree_leaves(user_args)
+flatten_user_args = _pytree.arg_tree_leaves(*user_args)
 out = Graphed.apply(*(tuple(flatten_user_args) + module_params))
 return _pytree.tree_unflatten(out, output_unflatten_spec)

8 changes: 5 additions & 3 deletions torch/distributed/_spmd/api.py
@@ -302,7 +302,7 @@ def _dedup_collectives(gm: fx.GraphModule) -> fx.GraphModule:

 for node in gm.graph.nodes:
 # replace all args with the results from the first unique comm op
-args = pytree.tree_leaves(node.args)
+args = pytree.arg_tree_leaves(*node.args)

 if node.target in DEDUP_TARGETS:
 args_key = (node.target, *args)
@@ -340,7 +340,7 @@ def _compile(
 # FIXME(@mrshenli): support multiple Optiimzer instances
 # FIXME(@mrshenli): need to broadcast model to sync parameters
 mod, opt = None, None
-for arg in pytree.tree_leaves(list(args) + list(kwargs.values())):
+for arg in pytree.arg_tree_leaves(*args, **kwargs):
 if isinstance(arg, nn.Module):
 assert mod is None, "Only support single nn.Module for now"
 mod = arg
@@ -539,7 +539,9 @@ def wrapper(*args, **kwargs):
 compiled_obj = _compile(func, module_override, mode, *args, **kwargs)
 wrapper.__dict__[COMPILED_OBJECT_KEY] = compiled_obj

-flat_inps = compiled_obj.flat_state + pytree.tree_leaves([args, kwargs])
+flat_inps = compiled_obj.flat_state + pytree.arg_tree_leaves(
+*args, **kwargs
+)

 with torch.no_grad():
 # N.B.: we don't need autograd as backward has already been
(Diffs for the remaining changed files were not loaded on this page.)

0 comments on commit 66c32d0
