From 201d02ef77f5ec2cfb15386fceeb956dfa669b43 Mon Sep 17 00:00:00 2001
From: chilli
Date: Fri, 6 Oct 2023 21:14:33 -0700
Subject: [PATCH] stop non-differentiable values from being materialized in
 aotautograd (#110721)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110721
Approved by: https://github.com/bdhirsh
ghstack dependencies: #110720
---
 torch/_functorch/aot_autograd.py | 104 +++++++++++++++----------------
 1 file changed, 50 insertions(+), 54 deletions(-)

diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index fecc86d3a5e2c3..880f5371315f5e 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -980,13 +980,14 @@ def inner(*flat_args):
         f_input_tangents = [
             inp
             for inp, info in zip(flat_f_args, input_info)
-            if info.mutates_data
+            if info.mutates_data and info.requires_grad
         ]
         f_output_tangents = [
             o
             for o, info in zip(flat_f_outs, output_info)
             if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]
             and issubclass(info.raw_type, torch.Tensor)
+            and info.requires_grad
         ]
         # intermediate bases are also included in the backward graph
         f_tangents = f_input_tangents + f_output_tangents + intermediate_bases
@@ -1256,7 +1257,7 @@ def inner_fn(*args):

         # Also return a boolean mask specifying which outputs to this function will be used as tangents
         mutated_inputs_grad_mask = [
-            meta.input_info[meta.mutated_inp_indices[i]].mutates_data
+            meta.input_info[meta.mutated_inp_indices[i]].mutates_data and meta.input_info[meta.mutated_inp_indices[i]].requires_grad
             for (i, x) in enumerate(mutated_inputs_to_return)
         ]

@@ -1268,6 +1269,7 @@ def inner_fn(*args):
             # Also, only tensor outputs should participate in the backward
             # (in particular, Symint outputs in the forward graph shouldn't get tangents)
             and issubclass(meta.output_info[i].raw_type, torch.Tensor)
+            and meta.output_info[i].requires_grad
             for (i, x) in enumerate(outs)
         ]

@@ -2125,13 +2127,16 @@ def create_synthetic_base_metadata(
         for o in m.output_info]

-    num_outer_mutated_data_inps = len([x for x in m.input_info if x.mutates_data])
-    inner_mutated_data_inps = [x for inner_idx, x in enumerate(inner_args) if input_infos[inner_idx].mutates_data]
+    inner_mutated_tangents = [
+        x
+        for inner_idx, x in enumerate(inner_args)
+        if input_infos[inner_idx].mutates_data and input_infos[inner_idx].requires_grad
+    ]
     requires_grad_info = mutated_inp_require_grad_info + output_grad_info + input_metadata_mutation_grad_info
     output_info = existing_output_infos + input_metadata_output_info
     # Regenerate traced tangents to include mutated inputs including synthetic bases
-    traced_tangents = inner_mutated_data_inps + m.traced_tangents[num_outer_mutated_data_inps:]
+    traced_tangents = inner_mutated_tangents + m.traced_tangents[len(inner_mutated_tangents):]
     return ViewAndMutationMeta(
         input_info=input_infos,
@@ -2498,7 +2503,7 @@ def wrapped_flat_fn(*args):
         )(*flat_args_with_synthetic_bases)
         assert ref_fw_metadata == fw_metadata_updated, (
             f'ref_metadata={pprint.pformat(partial_asdict(ref_fw_metadata))}, '
-            f'actual_metadata={pprint.pformat(partial_asdict(fw_metadata_updated))}'
+            f'\nactual_metadata={pprint.pformat(partial_asdict(fw_metadata_updated))}'
         )

     compiled_fn = compiler_fn(wrapped_flat_fn, flat_args_with_synthetic_bases, aot_config, fw_metadata=fw_metadata_updated)
@@ -3148,6 +3153,7 @@ def forward(ctx, *deduped_flat_tensor_args):
                 and not raw_returns_meta[i].requires_grad
             ]
             ctx.mark_non_differentiable(*fw_outs_not_requiring_grad)
+            ctx._materialize_non_diff_grads = False

             functionalized_rng_runtime_epilogue(
                 CompiledFunction.metadata,
@@ -3175,56 +3181,46 @@ def backward(ctx, *flat_args):

             assert len(flat_args) == expected_grad_outs
             out_info = CompiledFunction.metadata.output_info
-            if (
-                CompiledFunction.metadata.num_mutated_metadata_only_inputs > 0
-                or CompiledFunction.metadata.num_outputs_aliased > 0
-            ):
-                inp_tangents, out_tangents, intermediate_base_tangents = (
-                    flat_args[0:num_mutated_inps],
-                    flat_args[num_mutated_inps:num_mutated_inps + CompiledFunction.metadata.num_outputs],
-                    flat_args[num_mutated_inps + CompiledFunction.metadata.num_outputs:],
-                )
-                # input_info contains info on *every* input,
-                # But in the backward(), we are only given grad outputs for every mutated input.
-                # We then need to filter out the grad outputs that correspond to metadata-only mutations.
-                mutated_inp_indices = CompiledFunction.metadata.mutated_inp_indices
-                input_info = CompiledFunction.metadata.input_info
-                assert len(inp_tangents) == len(mutated_inp_indices)
-                inp_tangents_filtered = [
-                    x
-                    for x, info_idx in zip(inp_tangents, mutated_inp_indices)
-                    if input_info[info_idx].mutates_data
-                ]
-                # We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates
-                out_tangents_filtered = [
-                    x
-                    for x, info in zip(out_tangents, out_info)
-                    if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]
-                    and issubclass(info.raw_type, torch.Tensor)
-                ]
-                # intermediate bases always require gradients, and always participate in the backward graph.
-                flat_bw_args = itertools.chain(inp_tangents_filtered, out_tangents_filtered, intermediate_base_tangents)
-
-                # sanity asserts
-                # metadata_only_inps = [
-                #     x for x, info_idx in zip(inp_tangents, mutated_inp_indices)
-                #     if not input_info[info_idx].mutates_data
-                # ]
-                # aliased_outputs = [
-                #     x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias]
-                # assert all(x is None for x in metadata_only_inps)
-                # assert all(x is None for x in aliased_outputs)
-            else:
-                # filter out non-tensor grad_outputs (aka due to ints being returned as outputs in the forward)
-                num_mutated_inps = CompiledFunction.metadata.num_mutated_inputs
-                mutated_inp_args = flat_args[:num_mutated_inps] if num_mutated_inps > 0 else []
-                user_tangents = flat_args[num_mutated_inps:]
-                assert len(user_tangents) == len(out_info)
-                filtered_user_tangents = [x for x, info in zip(user_tangents, out_info) if issubclass(info.raw_type, torch.Tensor)]
-                flat_bw_args = tuple(mutated_inp_args) + tuple(filtered_user_tangents)
+
+            inp_tangents, out_tangents, intermediate_base_tangents = (
+                flat_args[0:num_mutated_inps],
+                flat_args[num_mutated_inps:num_mutated_inps + CompiledFunction.metadata.num_outputs],
+                flat_args[num_mutated_inps + CompiledFunction.metadata.num_outputs:],
+            )
+            # input_info contains info on *every* input,
+            # But in the backward(), we are only given grad outputs for every mutated input
+            # We then need to filter out the grad outputs that correspond to metadata-only mutations or don't require grad
+            mutated_inp_indices = CompiledFunction.metadata.mutated_inp_indices
+            input_info = CompiledFunction.metadata.input_info
+            assert len(inp_tangents) == len(mutated_inp_indices)
+            inp_tangents_filtered = [
+                x
+                for x, info_idx in zip(inp_tangents, mutated_inp_indices)
+                if input_info[info_idx].mutates_data and input_info[info_idx].requires_grad
+            ]
+            # We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates
+            out_tangents_filtered = [
+                x
+                for x, info in zip(out_tangents, out_info)
+                if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]
+                and issubclass(info.raw_type, torch.Tensor)
+                and info.requires_grad
+            ]
+            # intermediate bases always require gradients, and always participate in the backward graph.
+            flat_bw_args_with_grads = itertools.chain(inp_tangents_filtered, out_tangents_filtered, intermediate_base_tangents)
+
+            # sanity asserts
+            # metadata_only_inps = [
+            #     x for x, info_idx in zip(inp_tangents, mutated_inp_indices)
+            #     if not input_info[info_idx].mutates_data
+            # ]
+            # aliased_outputs = [
+            #     x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias]
+            # assert all(x is None for x in metadata_only_inps)
+            # assert all(x is None for x in aliased_outputs)

             contiguous_args = [
-                t.contiguous() if torch.is_tensor(t) else t for t in flat_bw_args
+                t.contiguous() if torch.is_tensor(t) else t for t in flat_bw_args_with_grads
             ]

             rng_args = []
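
Note on the mechanism behind the forward() change above: ctx.mark_non_differentiable() tells autograd which forward outputs never receive gradients, and the internal flag ctx._materialize_non_diff_grads = False additionally stops the engine from materializing zero tensors for those slots, so their grad outputs reach backward() as None. What follows is a minimal, self-contained sketch of that behavior, not part of the patch: the TwoOutputs Function and its shapes are hypothetical and invented purely for illustration; only ctx.mark_non_differentiable and ctx._materialize_non_diff_grads are taken from the patch itself.

import torch


class TwoOutputs(torch.autograd.Function):
    """Hypothetical Function: one differentiable output, one integer (non-differentiable) output."""

    @staticmethod
    def forward(ctx, x):
        out = x * 2
        aux = torch.zeros_like(x, dtype=torch.int64)  # integer output: never differentiable
        ctx.mark_non_differentiable(aux)
        # Same internal flag the patch sets in CompiledFunction.forward:
        # don't materialize zero grads for the non-differentiable output.
        ctx._materialize_non_diff_grads = False
        return out, aux

    @staticmethod
    def backward(ctx, grad_out, grad_aux):
        # With the flag set, grad_aux arrives as None instead of a zero tensor,
        # so no memory or compute is spent on a gradient nobody uses.
        assert grad_aux is None
        return grad_out * 2


x = torch.randn(4, requires_grad=True)
out, aux = TwoOutputs.apply(x)
out.sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2.]); aux contributed no materialized gradient

In CompiledFunction.backward above, the new requires_grad filters (inp_tangents_filtered / out_tangents_filtered) then drop those None slots entirely, so non-differentiable values never become tangents of the compiled backward graph.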