Skip to content

Commit

Permalink
stop non-differentiable values from being materialized in aotautograd (pytorch#110721)
Browse files Browse the repository at this point in the history

Pull Request resolved: pytorch#110721
Approved by: https://github.com/bdhirsh
ghstack dependencies: pytorch#110720
  • Loading branch information
Chillee authored and pytorchmergebot committed Oct 9, 2023
1 parent c596db7 commit 201d02e
Showing 1 changed file with 50 additions and 54 deletions.
104 changes: 50 additions & 54 deletions torch/_functorch/aot_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,13 +980,14 @@ def inner(*flat_args):
f_input_tangents = [
inp
for inp, info in zip(flat_f_args, input_info)
if info.mutates_data
if info.mutates_data and info.requires_grad
]
f_output_tangents = [
o
for o, info in zip(flat_f_outs, output_info)
if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]
and issubclass(info.raw_type, torch.Tensor)
and info.requires_grad
]
# intermediate bases are also included in the backward graph
f_tangents = f_input_tangents + f_output_tangents + intermediate_bases
Expand Down Expand Up @@ -1256,7 +1257,7 @@ def inner_fn(*args):

# Also return a boolean mask specifying which outputs to this function will be used as tangents
mutated_inputs_grad_mask = [
meta.input_info[meta.mutated_inp_indices[i]].mutates_data
meta.input_info[meta.mutated_inp_indices[i]].mutates_data and meta.input_info[meta.mutated_inp_indices[i]].requires_grad
for (i, x) in enumerate(mutated_inputs_to_return)
]

Expand All @@ -1268,6 +1269,7 @@ def inner_fn(*args):
# Also, only tensor outputs should participate in the backward
# (in particular, Symint outputs in the forward graph shouldn't get tangents)
and issubclass(meta.output_info[i].raw_type, torch.Tensor)
and meta.output_info[i].requires_grad
for (i, x) in enumerate(outs)
]

Expand Down Expand Up @@ -2125,13 +2127,16 @@ def create_synthetic_base_metadata(

for o in m.output_info]

num_outer_mutated_data_inps = len([x for x in m.input_info if x.mutates_data])
inner_mutated_data_inps = [x for inner_idx, x in enumerate(inner_args) if input_infos[inner_idx].mutates_data]
inner_mutated_tangents = [
x
for inner_idx, x in enumerate(inner_args)
if input_infos[inner_idx].mutates_data and input_infos[inner_idx].requires_grad
]

requires_grad_info = mutated_inp_require_grad_info + output_grad_info + input_metadata_mutation_grad_info
output_info = existing_output_infos + input_metadata_output_info
# Regenerate traced tangents to include mutated inputs including synthetic bases
traced_tangents = inner_mutated_data_inps + m.traced_tangents[num_outer_mutated_data_inps:]
traced_tangents = inner_mutated_tangents + m.traced_tangents[len(inner_mutated_tangents):]

return ViewAndMutationMeta(
input_info=input_infos,
Expand Down Expand Up @@ -2498,7 +2503,7 @@ def wrapped_flat_fn(*args):
)(*flat_args_with_synthetic_bases)
assert ref_fw_metadata == fw_metadata_updated, (
f'ref_metadata={pprint.pformat(partial_asdict(ref_fw_metadata))}, '
f'actual_metadata={pprint.pformat(partial_asdict(fw_metadata_updated))}'
f'\nactual_metadata={pprint.pformat(partial_asdict(fw_metadata_updated))}'
)

compiled_fn = compiler_fn(wrapped_flat_fn, flat_args_with_synthetic_bases, aot_config, fw_metadata=fw_metadata_updated)
Expand Down Expand Up @@ -3148,6 +3153,7 @@ def forward(ctx, *deduped_flat_tensor_args):
and not raw_returns_meta[i].requires_grad
]
ctx.mark_non_differentiable(*fw_outs_not_requiring_grad)
ctx._materialize_non_diff_grads = False

functionalized_rng_runtime_epilogue(
CompiledFunction.metadata,
Expand Down Expand Up @@ -3175,56 +3181,46 @@ def backward(ctx, *flat_args):

assert len(flat_args) == expected_grad_outs
out_info = CompiledFunction.metadata.output_info
if (
CompiledFunction.metadata.num_mutated_metadata_only_inputs > 0
or CompiledFunction.metadata.num_outputs_aliased > 0
):
inp_tangents, out_tangents, intermediate_base_tangents = (
flat_args[0:num_mutated_inps],
flat_args[num_mutated_inps:num_mutated_inps + CompiledFunction.metadata.num_outputs],
flat_args[num_mutated_inps + CompiledFunction.metadata.num_outputs:],
)
# input_info contains info on *every* input,
# But in the backward(), we are only given grad outputs for every mutated input.
# We then need to filter out the grad outputs that correspond to metadata-only mutations.
mutated_inp_indices = CompiledFunction.metadata.mutated_inp_indices
input_info = CompiledFunction.metadata.input_info
assert len(inp_tangents) == len(mutated_inp_indices)
inp_tangents_filtered = [
x
for x, info_idx in zip(inp_tangents, mutated_inp_indices)
if input_info[info_idx].mutates_data
]
# We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates
out_tangents_filtered = [
x
for x, info in zip(out_tangents, out_info)
if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]
and issubclass(info.raw_type, torch.Tensor)
]
# intermediate bases always require gradients, and always participate in the backward graph.
flat_bw_args = itertools.chain(inp_tangents_filtered, out_tangents_filtered, intermediate_base_tangents)

# sanity asserts
# metadata_only_inps = [
# x for x, info_idx in zip(inp_tangents, mutated_inp_indices)
# if not input_info[info_idx].mutates_data
# ]
# aliased_outputs = [
# x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias]
# assert all(x is None for x in metadata_only_inps)
# assert all(x is None for x in aliased_outputs)
else:
# filter out non-tensor grad_outputs (aka due to ints being returned as outputs in the forward)
num_mutated_inps = CompiledFunction.metadata.num_mutated_inputs
mutated_inp_args = flat_args[:num_mutated_inps] if num_mutated_inps > 0 else []
user_tangents = flat_args[num_mutated_inps:]
assert len(user_tangents) == len(out_info)
filtered_user_tangents = [x for x, info in zip(user_tangents, out_info) if issubclass(info.raw_type, torch.Tensor)]
flat_bw_args = tuple(mutated_inp_args) + tuple(filtered_user_tangents)

inp_tangents, out_tangents, intermediate_base_tangents = (
flat_args[0:num_mutated_inps],
flat_args[num_mutated_inps:num_mutated_inps + CompiledFunction.metadata.num_outputs],
flat_args[num_mutated_inps + CompiledFunction.metadata.num_outputs:],
)
# input_info contains info on *every* input,
# But in the backward(), we are only given grad outputs for every mutated input
# We then need to filter out the grad outputs that correspond to metadata-only mutations or don't require grad
mutated_inp_indices = CompiledFunction.metadata.mutated_inp_indices
input_info = CompiledFunction.metadata.input_info
assert len(inp_tangents) == len(mutated_inp_indices)
inp_tangents_filtered = [
x
for x, info_idx in zip(inp_tangents, mutated_inp_indices)
if input_info[info_idx].mutates_data and input_info[info_idx].requires_grad
]
# We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates
out_tangents_filtered = [
x
for x, info in zip(out_tangents, out_info)
if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]
and issubclass(info.raw_type, torch.Tensor)
and info.requires_grad
]
# intermediate bases always require gradients, and always participate in the backward graph.
flat_bw_args_with_grads = itertools.chain(inp_tangents_filtered, out_tangents_filtered, intermediate_base_tangents)

# sanity asserts
# metadata_only_inps = [
# x for x, info_idx in zip(inp_tangents, mutated_inp_indices)
# if not input_info[info_idx].mutates_data
# ]
# aliased_outputs = [
# x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias]
# assert all(x is None for x in metadata_only_inps)
# assert all(x is None for x in aliased_outputs)

contiguous_args = [
t.contiguous() if torch.is_tensor(t) else t for t in flat_bw_args
t.contiguous() if torch.is_tensor(t) else t for t in flat_bw_args_with_grads
]

rng_args = []
Expand Down

0 comments on commit 201d02e

Please sign in to comment.