Revert "[Dynamo] Trace enter/exit of TorchFunctionModes (pytorch#135422
Browse files Browse the repository at this point in the history
…)" (pytorch#136590)

This reverts commit 7743149.

Reverts
* pytorch#135503
* pytorch#135502
* pytorch#135422

With this revert, the test below passes. Previously, the getitem would remain a getitem node in the FX graph; with the now-reverted changes, fake tensor propagation fails, reporting that `.item()` is called. It appears that torch function dispatch is not triggered during fake tensor propagation (see the illustrative sketch after the repro).

```
import torch
from torch.nn.attention.flex_attention import BlockMask, _mask_mod_signature, _score_mod_signature, flex_attention
from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.virtualized import ops
from torch.nn.attention.flex_attention import create_block_mask

torch.set_default_device('cuda')

flex_attention = torch.compile(flex_attention, dynamic=False)

prefix_lengths = torch.arange(8)
def prefix_lm(b, h, q, kv):
    return prefix_lengths[b] >= kv

mask = create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True)
```
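
For readers less familiar with torch function modes, the sketch below (not part of this commit) illustrates what it means for an operation such as `Tensor.__getitem__` to be intercepted by a `__torch_function__` mode; `LoggingMode` is a hypothetical helper used only for illustration.

```python
# Illustrative sketch only, not from this PR: a TorchFunctionMode that logs
# every call it intercepts. With the mode active, indexing a tensor routes
# Tensor.__getitem__ through __torch_function__ instead of running directly.
import torch
from torch.overrides import BaseTorchFunctionMode


class LoggingMode(BaseTorchFunctionMode):  # hypothetical, for illustration
    def __torch_function__(self, func, types, args=(), kwargs=None):
        print(f"intercepted: {func}")
        return super().__torch_function__(func, types, args, kwargs)


with LoggingMode():
    t = torch.arange(8)
    idx = t[3]  # Tensor.__getitem__ is dispatched through LoggingMode
```

As described above, with the reverted commits in place the getitem no longer stays in the FX graph and fake tensor propagation reports a `.item()` call instead.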

Pull Request resolved: pytorch#136590
Approved by: https://github.com/Chillee
anijain2305 authored and pytorchmergebot committed Sep 25, 2024
1 parent 529b6ab commit 289df45
Showing 18 changed files with 203 additions and 338 deletions.
153 changes: 65 additions & 88 deletions test/dynamo/test_modes.py
@@ -1,4 +1,5 @@
# Owner(s): ["module: dynamo"]
from unittest.mock import patch

import torch
import torch._dynamo.test_case
@@ -106,6 +107,70 @@ def fn(x):
fn(inp)
self.assertEqual(cnt.frame_count, 4)

def _run_ignored_mode_types_test(self):
class IgnoredMode(BaseTorchFunctionMode):
pass

cnt = torch._dynamo.testing.CompileCounter()

@torch.compile(backend=cnt.__call__, fullgraph=True)
def fn(x):
return x + 1

inp = torch.ones(2, 2)

with patch(
"torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode}
):
# initial compile
fn(inp)

# no recompile, mode ignored
# note: the ref stack is length 0, and the stack we are checking against has length 2
# we want to check both ref stack len > runtime stack, and ref stack len < runtime stack
with IgnoredMode(), IgnoredMode():
fn(inp)

self.assertEqual(cnt.frame_count, 1)

# recompile due to new mode on the stack
with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
fn(inp)

self.assertEqual(cnt.frame_count, 2)

# recompile
# tests both ref stack len > runtime stack len for the above guard check
# and ref stack len < runtime stack len for the initial zero mode case
with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode():
fn(inp)

self.assertEqual(cnt.frame_count, 3)

# no recompile
with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
fn(inp)

self.assertEqual(cnt.frame_count, 3)

# This is tricky, basically the ignored modes are baked into the guard
# IgnoredMode will be ignored forever by that guard.
# This is okay since we don't expect to be modifying IGNORED_MODES
# in the middle of execution except for the purposes of testing.
torch._dynamo.reset()

with IgnoredMode():
fn(inp)

self.assertEqual(cnt.frame_count, 4)

@torch._dynamo.config.patch("enable_cpp_guard_manager", False)
def test_torch_function_mode_guards_ignored_types_py(self):
self._run_ignored_mode_types_test()

def test_torch_function_mode_guards_ignored_types_cpp(self):
self._run_ignored_mode_types_test()

@torch._dynamo.config.patch("enable_cpp_guard_manager", False)
def test_torch_function_mode_guards_py(self):
self._run_torch_function_mode_guard_test()
@@ -396,94 +461,6 @@ def fn(x):

self.assertEqual(expected, actual)

def test_torch_function_mode_enter_exit(self):
def fn(x, y):
with TestMode():
o = torch.add(x, 3)

return torch.add(o, y)

inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn, fullgraph=True)

expected = fn(*inp)
actual = fn_opt(*inp)

self.assertEqual(expected, actual)

def test_torch_function_mode_graph_break(self):
def fn(x, y):
with TestMode():
torch._dynamo.graph_break()
o = torch.add(x, 3)

return torch.add(o, y)

inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)

expected = fn(*inp)
actual = fn_opt(*inp)

self.assertEqual(expected, actual)

def test_torch_function_mode_and_pop_graph_break(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)

return torch.add(o, y)

inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)

expected = fn(*inp)
actual = fn_opt(*inp)

self.assertEqual(expected, actual)

def test_torch_function_mode_restore_on_exc(self):
@torch._dynamo.disable()
def err():
raise RuntimeError("test")

@torch.compile()
def fn(x):
with TestMode():
x += 1
err()
x += 2
return x

try:
fn(torch.ones(2, 2))
except RuntimeError:
pass
self.assertEqual(_len_torch_function_stack(), 0)

def test_torch_function_mode_and_pop_graph_break_mutation(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
z.y = 5
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)
o = torch.mul(o, z.y)

return torch.add(o, y)

inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)

expected = fn(*inp)
actual = fn_opt(*inp)

self.assertEqual(expected, actual)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
2 changes: 1 addition & 1 deletion torch/_C/_dynamo/guards.pyi
@@ -67,7 +67,7 @@ class GuardManager:
) -> None: ...
def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
def add_torch_function_mode_stack_guard(
self, initial_stack, verbose_code_parts: list[str]
self, initial_stack, ignored_types, verbose_code_parts: list[str]
) -> None: ...

class RootGuardManager(GuardManager):
6 changes: 1 addition & 5 deletions torch/_dynamo/convert_frame.py
@@ -112,7 +112,6 @@
troubleshooting_url,
write_record_to_file,
)
from .variables.torch_function import torch_function_mode_stack_state_mgr


np: Optional[ModuleType]
@@ -211,18 +210,15 @@ def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
prior_fwd_from_src = torch.fx.graph_module._forward_from_src
torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
cleanup = setup_compile_debug()

exit_stack = contextlib.ExitStack()
exit_stack.enter_context(
torch.fx._symbolic_trace._maybe_revert_all_patches()
)
exit_stack.enter_context(torch_function_mode_stack_state_mgr)
try:
return fn(*args, **kwargs)
finally:
cleanup.close()
assert (
torch._C._len_torch_function_stack() == 0
), "Torch function mode stack state changed while dynamo tracing, please report a bug"
exit_stack.close()
torch._C._set_grad_enabled(prior_grad_mode)
torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
11 changes: 9 additions & 2 deletions torch/_dynamo/guards.py
@@ -2344,12 +2344,15 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn):
)

if config.enable_cpp_guard_manager:
from .variables.torch_function import IGNORED_MODES

# Insert the global_state guard
assert self.guard_manager # to make mypy happy
self.guard_manager.root.add_global_state_guard(["___check_global_state()"])

self.guard_manager.root.add_torch_function_mode_stack_guard(
self.torch_function_mode_stack,
list(IGNORED_MODES),
["___check_torch_function_mode_stack()"],
)
# Clear references to torch_function modes held in the list
@@ -2656,14 +2659,18 @@ def is_recompiles_verbose_enabled():
# this will only be used if cpp guards are disabled
def make_torch_function_mode_stack_guard(intial_stack):
types = [type(x) for x in intial_stack]
from .variables.torch_function import IGNORED_MODES

def check_torch_function_mode_stack():
cur_stack = get_torch_function_mode_stack()

if len(cur_stack) != len(types):
types_ = [ty for ty in types if ty not in IGNORED_MODES]
cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES]

if len(cur_stack_) != len(types_):
return False

for ty, mode in zip(types, cur_stack):
for ty, mode in zip(types_, cur_stack_):
if ty != type(mode):
return False

6 changes: 3 additions & 3 deletions torch/_dynamo/output_graph.py
@@ -78,6 +78,7 @@
get_instruction_source_311,
get_locals_to_steal,
get_static_address_type,
get_torch_function_mode_stack,
graph_break_reasons,
increment_op_count,
lazy_format_graph_code,
@@ -249,7 +250,6 @@ def __init__(
local_scope: Scope,
global_scope: Scope,
f_code,
torch_function_mode_stack,
):
super().__init__()
self.tracers = [SubgraphTracer(self, export_root=export)]
@@ -368,7 +368,7 @@ def __init__(
# This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
# This records the initial torch function mode stack for guarding
self.torch_function_mode_stack = torch_function_mode_stack
self.torch_function_mode_stack = get_torch_function_mode_stack()

# Tracks if the output graph has a user defined allowed function in the
# graph. This is used later to determine if we should fallback to eager
@@ -1020,7 +1020,7 @@ def append_prefix_insts():
prefix_insts.clear()

for block in reversed(tx.block_stack):
block.exit(tx, is_graph_break=reason.graph_break)
block.exit(tx)

self.cleanup_graph()
tx.prune_dead_locals()
20 changes: 0 additions & 20 deletions torch/_dynamo/polyfills/__init__.py
@@ -25,26 +25,6 @@
sys as sys,
)

from torch.overrides import BaseTorchFunctionMode


# These classes handle support for TorchFunctionModes across
# graph breaks
# Today the TorchFunctionMode enter (for the classes we support)
# simply pushes the mode onto the stack. Since after this occurs
# the stack is mutated, and we replay these mutations, we don't need
# any cleanup logic to be run once the graph break occurs, we simply replay
# these mutations to ensure at the graph break the torch function mode stack is correct
# and reconstruct the torch function mode stack normally
# when we compile the resume function on the other side of the break.
# However, to ensure we exit properly
# in the resume function, we need to re-enter the contexts as we do other contexts.
# These contexts do nothing on enter, but provide the correct exit logic to ensure
# the stack state is correct.
class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
def __enter__(self):
pass


def index(iterator, item, start=0, end=None):
from itertools import islice
41 changes: 21 additions & 20 deletions torch/_dynamo/resume_execution.py
@@ -90,26 +90,27 @@ class ReenterWith:
stack_index: int
target_values: Optional[Tuple[Any, ...]] = None

def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
"""
Codegen based off of:
try:
(rest)
except:
(restore previous tf mode stack)
raise
"""
from .variables.torch_function import get_prev_stack_var_name

setup_try_except, epilogue = _bytecode_from_template_with_split(
_try_except_tf_mode_template,
self.stack_index,
varname_map={"stack_var_name": get_prev_stack_var_name()},
)
cleanup[:] = epilogue + cleanup

return setup_try_except
# TODO(mlazos) - Uncomment with the reland of torch function mode support
# def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
# """
# Codegen based off of:
# try:
# (rest)
# except:
# (restore previous tf mode stack)
# raise

# """
# from .variables.torch_function import get_prev_stack_var_name

# setup_try_except, epilogue = _bytecode_from_template_with_split(
# _try_except_tf_mode_template,
# self.stack_index,
# varname_map={"stack_var_name": get_prev_stack_var_name()},
# )
# cleanup[:] = epilogue + cleanup

# return setup_try_except

# If we do not want to destroy the stack, we can do the same thing as a
# `SETUP_WITH` block, only that we store the context manager in a local_symbol
11 changes: 0 additions & 11 deletions torch/_dynamo/side_effects.py
@@ -623,22 +623,11 @@ def codegen_update_mutated(self, cg: PyCodegen):
elif isinstance(
var, variables.torch_function.TorchFunctionModeStackVariable
):
# Needed in the finally block for stack restoration
cg.add_push_null(
lambda: cg.load_import_from(
utils.__name__, "get_torch_function_mode_stack"
)
)
cg.call_function(0, False)
name = variables.torch_function.get_prev_stack_var_name()
cg.code_options["co_varnames"] += (name,)
cg.append_output(create_instruction("STORE_FAST", argval=name))
cg.add_push_null(
lambda: cg.load_import_from(
utils.__name__, "set_torch_function_mode_stack"
)
)

cg.foreach(var.symbolic_stack)
cg.append_output(
create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))