Skip to content

Commit

Permalink
[HigherOrderOp] makes control flow operators respect global decomp table (pytorch#120412)
Browse files Browse the repository at this point in the history

A follow-up to @zou3519's comment on pytorch#120366. We create a helper method for this purpose.

Pull Request resolved: pytorch#120412
Approved by: https://github.com/zou3519
  • Loading branch information
ydwu4 authored and pytorchmergebot committed Feb 23, 2024
1 parent 156954d commit 8f4ffd3
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 26 deletions.
18 changes: 3 additions & 15 deletions torch/_higher_order_ops/cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,16 @@
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
_maybe_run_with_interpreter,
_set_compilation_env,
autograd_not_implemented,
reenter_make_fx,
UnsupportedAliasMutationException,
)

from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
make_fx,
ProxyTorchDispatchMode,
track_tensor_tree,
)
Expand Down Expand Up @@ -157,19 +156,8 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
pre_dispatch = getattr(proxy_mode, "pre_dispatch", False)

with disable_proxy_modes_tracing():
# We'll use the current decomposition table to make sure operatos in subgraphs are
# decomposed properly.
decomp_table = torch.fx.experimental.proxy_tensor.CURRENT_DECOMPOSITION_TABLE
true_graph = make_fx(
_maybe_run_with_interpreter(true_fn),
decomposition_table=decomp_table,
pre_dispatch=pre_dispatch,
)(*operands)
false_graph = make_fx(
_maybe_run_with_interpreter(false_fn),
decomposition_table=decomp_table,
pre_dispatch=pre_dispatch,
)(*operands)
true_graph = reenter_make_fx(true_fn, pre_dispatch)(*operands)
false_graph = reenter_make_fx(false_fn, pre_dispatch)(*operands)

true_outs = []
false_outs = []
Expand Down
6 changes: 4 additions & 2 deletions torch/_higher_order_ops/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
reenter_make_fx,
UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
Expand Down Expand Up @@ -228,8 +229,9 @@ def trace_map(proxy_mode, func_overload, f, xs, pos_args):

example_input = _unstack_pytree(xs)[0]
body_graph = f
if not isinstance(body_graph, torch.fx.GraphModule):
body_graph = make_fx(body_graph)(*example_input, *pos_args)

pre_dispatch = getattr(proxy_mode, "pre_dispatch", False)
body_graph = reenter_make_fx(body_graph, pre_dispatch)(*example_input, *pos_args)

next_name = None
i = 0
Expand Down
12 changes: 12 additions & 0 deletions torch/_higher_order_ops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,18 @@ def graph_with_interpreter(*args):
return maybe_interpreted_fn


# We'll use the current decomposition table to make sure operators in subgraphs are
# decomposed properly.
# We also need to maybe run with interpreter for propagating stack_trace
def reenter_make_fx(fn, pre_dispatch=False):
decomp_table = torch.fx.experimental.proxy_tensor.CURRENT_DECOMPOSITION_TABLE
return make_fx(
_maybe_run_with_interpreter(fn),
decomposition_table=decomp_table,
pre_dispatch=pre_dispatch,
)


@contextmanager
def _set_compilation_env():
_old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
Expand Down
12 changes: 3 additions & 9 deletions torch/_higher_order_ops/while_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,15 @@
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
_maybe_run_with_interpreter,
_set_compilation_env,
autograd_not_implemented,
reenter_make_fx,
UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
make_fx,
ProxyTorchDispatchMode,
track_tensor_tree,
)
Expand Down Expand Up @@ -159,14 +158,9 @@ def _is_boolean_scalar_tensor(pred):
def while_loop_tracing(mode, cond_fn, body_fn, operands):
def _trace_while_loop(proxy_mode, while_loop_op, cond_fn, body_fn, operands):
pre_dispatch = getattr(proxy_mode, "pre_dispatch", False)

with disable_proxy_modes_tracing():
cond_graph = make_fx(
_maybe_run_with_interpreter(cond_fn), pre_dispatch=pre_dispatch
)(*operands)
body_graph = make_fx(
_maybe_run_with_interpreter(body_fn), pre_dispatch=pre_dispatch
)(*operands)
cond_graph = reenter_make_fx(cond_fn, pre_dispatch)(*operands)
body_graph = reenter_make_fx(body_fn, pre_dispatch)(*operands)

next_name = None
i = 0
Expand Down

0 comments on commit 8f4ffd3

Please sign in to comment.