Make balance_gradient preserved in export (pytorch#120332)
Summary: We can only avoid decomposing (i.e., preserve) CompositeImplicit custom ops that are functional. Judging from the implementation, this op is functional, so the fix is just to correct its schema; see the schema sketch below.

Test Plan: CI

Differential Revision: D54019265

Pull Request resolved: pytorch#120332
Approved by: https://github.com/zhxchen17
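For context, a "functional" op is one whose schema declares no input mutation or aliasing. Below is a minimal sketch of such a schema via torch.library, with a hypothetical "demo" namespace and implementation; this is not the real balance_gradient registration.

import torch

# Hypothetical "demo" namespace, for illustration only.
lib = torch.library.Library("demo", "DEF")

# Functional schema: no "(a!)" mutation/aliasing annotations, and a
# fresh Tensor is returned. Only ops like this are eligible to be
# preserved (not decomposed) during export.
lib.define("balance_gradient(Tensor x, float scale) -> Tensor")

def balance_gradient_impl(x, scale):
    return x * scale

# CompositeImplicitAutograd: the op is implemented in terms of other
# differentiable ops, so autograd works without a dedicated backward.
lib.impl("balance_gradient", balance_gradient_impl, "CompositeImplicitAutograd")

# Non-functional counterpart, for contrast: "Tensor(a!)" marks an
# in-place mutation, which would force decomposition during export.
lib.define("balance_gradient_(Tensor(a!) x, float scale) -> Tensor(a!)")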
tugsbayasgalan authored and pytorchmergebot committed Feb 23, 2024
1 parent 182ed1e commit 8646872
Showing 2 changed files with 6 additions and 4 deletions.
3 changes: 0 additions & 3 deletions torch/_subclasses/functional_tensor.py
@@ -314,10 +314,7 @@ def wrap(x):
                 return FunctionalTensor(x)
             return x
 
-        any_functional_inputs = False
-
         def unwrap(x):
-            any_functional_inputs = True
             return x.elem
 
         from torch._higher_order_ops.auto_functionalize import (
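A note on why the removed flag was dead code: the assignment inside unwrap rebinds a fresh local variable on each call (there is no nonlocal declaration), so the outer any_functional_inputs could never actually flip to True, and it was never read anyway. A minimal illustration of the scoping pitfall:

def outer():
    flag = False

    def unwrap():
        # Without "nonlocal flag" this binds a new local variable;
        # the enclosing "flag" is left untouched.
        flag = True

    unwrap()
    return flag

print(outer())  # prints False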
7 changes: 6 additions & 1 deletion torch/export/_trace.py
@@ -710,7 +710,12 @@ def forward(self, *args, **kwargs):
             fake_mode, _get_params_buffers(f)
         )
         ep_non_strict = _export_non_strict(
-            f, fake_args, fake_kwargs, fake_params_buffers, transform=_tuplify_outputs
+            f,
+            fake_args,
+            fake_kwargs,
+            fake_params_buffers,
+            pre_dispatch=pre_dispatch,
+            transform=_tuplify_outputs,
         )
         range_constraints, equality_constraints = make_constraints(
             fake_mode, src_equalities, original_signature, ep_non_strict.gm
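The substance of this hunk is simply threading the caller's pre_dispatch flag through to _export_non_strict instead of silently dropping it. A toy sketch of the bug pattern, with hypothetical stand-in functions rather than the real torch.export internals:

def _export_non_strict(f, args, kwargs, params_buffers,
                       *, pre_dispatch=False, transform=lambda out: out):
    # Stand-in: report which IR level we would trace at.
    level = "pre-dispatch ATen" if pre_dispatch else "core ATen"
    print(f"tracing {f.__name__} at the {level} level")

def _export(f, args, kwargs, *, pre_dispatch=False):
    # Before the fix: pre_dispatch was accepted here but never
    # forwarded, so the non-strict path always used the default.
    _export_non_strict(f, args, kwargs, {}, pre_dispatch=pre_dispatch)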
