From 289df45cee05edb0872b0df0f8a93c8d1dd5b5ca Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Wed, 25 Sep 2024 09:36:53 -0700
Subject: [PATCH] Revert "[Dynamo] Trace enter/exit of TorchFunctionModes (#135422)" (#136590)

This reverts commit 7743149b2be4a9eba7e0997ccdc6abe552bec266.

Reverts
* https://github.com/pytorch/pytorch/pull/135503
* https://github.com/pytorch/pytorch/pull/135502
* https://github.com/pytorch/pytorch/pull/135422

This revert makes the test below pass. Earlier, the getitem would stay as a getitem node in the FX graph, but now fake tensor propagation fails, reporting that .item() is called. It seems that torch function is not being triggered during fake tensor propagation.

```
import torch
from torch.nn.attention.flex_attention import BlockMask, _mask_mod_signature, _score_mod_signature, flex_attention
from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.virtualized import ops
from torch.nn.attention.flex_attention import create_block_mask

torch.set_default_device('cuda')

flex_attention = torch.compile(flex_attention, dynamic=False)

prefix_lengths = torch.arange(8)

def prefix_lm(b, h, q, kv):
    return prefix_lengths[b] >= kv

mask = create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136590
Approved by: https://github.com/Chillee
---
 test/dynamo/test_modes.py | 153 +++++++++-------
 torch/_C/_dynamo/guards.pyi | 2 +-
 torch/_dynamo/convert_frame.py | 6 +-
 torch/_dynamo/guards.py | 11 +-
 torch/_dynamo/output_graph.py | 6 +-
 torch/_dynamo/polyfills/__init__.py | 20 ---
 torch/_dynamo/resume_execution.py | 41 +++---
 torch/_dynamo/side_effects.py | 11 --
 torch/_dynamo/symbolic_convert.py | 30 ++---
 torch/_dynamo/testing.py | 1 -
 torch/_dynamo/trace_rules.py | 2 +-
 torch/_dynamo/utils.py | 10 +-
 torch/_dynamo/variables/builder.py | 20 ++-
 torch/_dynamo/variables/ctx_manager.py | 12 --
 torch/_dynamo/variables/torch.py | 21 +--
 torch/_dynamo/variables/torch_function.py | 112 ++--------------
 torch/_dynamo/variables/user_defined.py | 14 +-
 torch/csrc/dynamo/guards.cpp | 69 ++++++++--
 18 files changed, 203 insertions(+), 338 deletions(-)

diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py
index 4d1f2bbea389e..fa4c23fd320dd 100644
--- a/test/dynamo/test_modes.py
+++ b/test/dynamo/test_modes.py
@@ -1,4 +1,5 @@
 # Owner(s): ["module: dynamo"]
+from unittest.mock import patch
 
 import torch
 import torch._dynamo.test_case
@@ -106,6 +107,70 @@ def fn(x):
         fn(inp)
         self.assertEqual(cnt.frame_count, 4)
 
+    def _run_ignored_mode_types_test(self):
+        class IgnoredMode(BaseTorchFunctionMode):
+            pass
+
+        cnt = torch._dynamo.testing.CompileCounter()
+
+        @torch.compile(backend=cnt.__call__, fullgraph=True)
+        def fn(x):
+            return x + 1
+
+        inp = torch.ones(2, 2)
+
+        with patch(
+            "torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode}
+        ):
+            # initial compile
+            fn(inp)
+
+            # no recompile, mode ignored
+            # note: the ref stack is length 0, and the stack we are checking against has length 2
+            # we want to check both ref stack len > runtime stack, and ref stack len < runtime stack
+            with IgnoredMode(), IgnoredMode():
+                fn(inp)
+
+            self.assertEqual(cnt.frame_count, 1)
+
+            # recompile due to new mode on the stack
+            with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
+                fn(inp)
+
+            self.assertEqual(cnt.frame_count, 2)
+
+            # recompile
+            # tests both ref stack len > runtime stack len for the above guard check
+            # and ref stack len < runtime stack len for the
initial zero mode case + with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode(): + fn(inp) + + self.assertEqual(cnt.frame_count, 3) + + # no recompile + with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): + fn(inp) + + self.assertEqual(cnt.frame_count, 3) + + # This is tricky, basically the ignored modes are baked into the guard + # IgnoredMode will be ignored forever by that guard. + # This is okay since we don't expect to be modifying IGNORED_MODES + # in the middle of execution except for the purposes of testing. + torch._dynamo.reset() + + with IgnoredMode(): + fn(inp) + + self.assertEqual(cnt.frame_count, 4) + + @torch._dynamo.config.patch("enable_cpp_guard_manager", False) + def test_torch_function_mode_guards_ignored_types_py(self): + self._run_ignored_mode_types_test() + + def test_torch_function_mode_guards_ignored_types_cpp(self): + self._run_ignored_mode_types_test() + @torch._dynamo.config.patch("enable_cpp_guard_manager", False) def test_torch_function_mode_guards_py(self): self._run_torch_function_mode_guard_test() @@ -396,94 +461,6 @@ def fn(x): self.assertEqual(expected, actual) - def test_torch_function_mode_enter_exit(self): - def fn(x, y): - with TestMode(): - o = torch.add(x, 3) - - return torch.add(o, y) - - inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) - fn_opt = torch.compile(fn, fullgraph=True) - - expected = fn(*inp) - actual = fn_opt(*inp) - - self.assertEqual(expected, actual) - - def test_torch_function_mode_graph_break(self): - def fn(x, y): - with TestMode(): - torch._dynamo.graph_break() - o = torch.add(x, 3) - - return torch.add(o, y) - - inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) - fn_opt = torch.compile(fn) - - expected = fn(*inp) - actual = fn_opt(*inp) - - self.assertEqual(expected, actual) - - def test_torch_function_mode_and_pop_graph_break(self): - def fn(x, y): - with TestMode(): - z = _pop_torch_function_stack() - torch._dynamo.graph_break() - _push_on_torch_function_stack(z) - o = torch.add(x, 3) - - return torch.add(o, y) - - inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) - fn_opt = torch.compile(fn) - - expected = fn(*inp) - actual = fn_opt(*inp) - - self.assertEqual(expected, actual) - - def test_torch_function_mode_restore_on_exc(self): - @torch._dynamo.disable() - def err(): - raise RuntimeError("test") - - @torch.compile() - def fn(x): - with TestMode(): - x += 1 - err() - x += 2 - return x - - try: - fn(torch.ones(2, 2)) - except RuntimeError: - pass - self.assertEqual(_len_torch_function_stack(), 0) - - def test_torch_function_mode_and_pop_graph_break_mutation(self): - def fn(x, y): - with TestMode(): - z = _pop_torch_function_stack() - z.y = 5 - torch._dynamo.graph_break() - _push_on_torch_function_stack(z) - o = torch.add(x, 3) - o = torch.mul(o, z.y) - - return torch.add(o, y) - - inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) - fn_opt = torch.compile(fn) - - expected = fn(*inp) - actual = fn_opt(*inp) - - self.assertEqual(expected, actual) - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index 918d913068e6d..d6aab0d547f72 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -67,7 +67,7 @@ class GuardManager: ) -> None: ... def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ... 
def add_torch_function_mode_stack_guard( - self, initial_stack, verbose_code_parts: list[str] + self, initial_stack, ignored_types, verbose_code_parts: list[str] ) -> None: ... class RootGuardManager(GuardManager): diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 1ae1f0c6fab31..c685ca5bdafda 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -112,7 +112,6 @@ troubleshooting_url, write_record_to_file, ) -from .variables.torch_function import torch_function_mode_stack_state_mgr np: Optional[ModuleType] @@ -211,18 +210,15 @@ def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: prior_fwd_from_src = torch.fx.graph_module._forward_from_src torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result cleanup = setup_compile_debug() + exit_stack = contextlib.ExitStack() exit_stack.enter_context( torch.fx._symbolic_trace._maybe_revert_all_patches() ) - exit_stack.enter_context(torch_function_mode_stack_state_mgr) try: return fn(*args, **kwargs) finally: cleanup.close() - assert ( - torch._C._len_torch_function_stack() == 0 - ), "Torch function mode stack state changed while dynamo tracing, please report a bug" exit_stack.close() torch._C._set_grad_enabled(prior_grad_mode) torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 3dcc9a032f208..1544bdd6dc697 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2344,12 +2344,15 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): ) if config.enable_cpp_guard_manager: + from .variables.torch_function import IGNORED_MODES + # Insert the global_state guard assert self.guard_manager # to make mypy happy self.guard_manager.root.add_global_state_guard(["___check_global_state()"]) self.guard_manager.root.add_torch_function_mode_stack_guard( self.torch_function_mode_stack, + list(IGNORED_MODES), ["___check_torch_function_mode_stack()"], ) # Clear references to torch_function modes held in the list @@ -2656,14 +2659,18 @@ def is_recompiles_verbose_enabled(): # this will only be used if cpp guards are disabled def make_torch_function_mode_stack_guard(intial_stack): types = [type(x) for x in intial_stack] + from .variables.torch_function import IGNORED_MODES def check_torch_function_mode_stack(): cur_stack = get_torch_function_mode_stack() - if len(cur_stack) != len(types): + types_ = [ty for ty in types if ty not in IGNORED_MODES] + cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES] + + if len(cur_stack_) != len(types_): return False - for ty, mode in zip(types, cur_stack): + for ty, mode in zip(types_, cur_stack_): if ty != type(mode): return False diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 76be81a088c3c..ed429f7cab002 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -78,6 +78,7 @@ get_instruction_source_311, get_locals_to_steal, get_static_address_type, + get_torch_function_mode_stack, graph_break_reasons, increment_op_count, lazy_format_graph_code, @@ -249,7 +250,6 @@ def __init__( local_scope: Scope, global_scope: Scope, f_code, - torch_function_mode_stack, ): super().__init__() self.tracers = [SubgraphTracer(self, export_root=export)] @@ -368,7 +368,7 @@ def __init__( # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled() # This records the initial 
torch function mode stack for guarding - self.torch_function_mode_stack = torch_function_mode_stack + self.torch_function_mode_stack = get_torch_function_mode_stack() # Tracks if the output graph has a user defined allowed function in the # graph. This is used later to determine if we should fallback to eager @@ -1020,7 +1020,7 @@ def append_prefix_insts(): prefix_insts.clear() for block in reversed(tx.block_stack): - block.exit(tx, is_graph_break=reason.graph_break) + block.exit(tx) self.cleanup_graph() tx.prune_dead_locals() diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 5b2812bc08c9e..2b3f38920a0c2 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -25,26 +25,6 @@ sys as sys, ) -from torch.overrides import BaseTorchFunctionMode - - -# These classes handle support for TorchFunctionModes across -# graph breaks -# Today the TorchFunctionMode enter (for the classes we support) -# simply pushes the mode onto the stack. Since after this occurs -# the stack is mutated, and we replay these mutations, we don't need -# any cleanup logic to be run once the graph break occurs, we simply replay -# these mutations to ensure at the graph break the torch function mode stack is correct -# and reconstruct the torch function mode stack normally -# when we compile the resume function on the other side of the break. -# However, to ensure we exit properly -# in the resume function, we need to re-enter the contexts as we do other contexts. -# These contexts do nothing on enter, but provide the correct exit logic to ensure -# the stack state is correct. -class NoEnterTorchFunctionMode(BaseTorchFunctionMode): - def __enter__(self): - pass - def index(iterator, item, start=0, end=None): from itertools import islice diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index 00d971dffb17e..e9dedfb84a99c 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -90,26 +90,27 @@ class ReenterWith: stack_index: int target_values: Optional[Tuple[Any, ...]] = None - def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]): - """ - Codegen based off of: - try: - (rest) - except: - (restore previous tf mode stack) - raise - - """ - from .variables.torch_function import get_prev_stack_var_name - - setup_try_except, epilogue = _bytecode_from_template_with_split( - _try_except_tf_mode_template, - self.stack_index, - varname_map={"stack_var_name": get_prev_stack_var_name()}, - ) - cleanup[:] = epilogue + cleanup - - return setup_try_except + # TODO(mlazos) - Uncomment with the reland of torch function mode support + # def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]): + # """ + # Codegen based off of: + # try: + # (rest) + # except: + # (restore previous tf mode stack) + # raise + + # """ + # from .variables.torch_function import get_prev_stack_var_name + + # setup_try_except, epilogue = _bytecode_from_template_with_split( + # _try_except_tf_mode_template, + # self.stack_index, + # varname_map={"stack_var_name": get_prev_stack_var_name()}, + # ) + # cleanup[:] = epilogue + cleanup + + # return setup_try_except # If we do not want to destroy the stack, we can do the same thing as a # `SETUP_WITH` block, only that we store the context manager in a local_symbol diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 3b83eb17d407f..e26e50b09687b 100644 --- a/torch/_dynamo/side_effects.py +++ 
b/torch/_dynamo/side_effects.py @@ -623,22 +623,11 @@ def codegen_update_mutated(self, cg: PyCodegen): elif isinstance( var, variables.torch_function.TorchFunctionModeStackVariable ): - # Needed in the finally block for stack restoration - cg.add_push_null( - lambda: cg.load_import_from( - utils.__name__, "get_torch_function_mode_stack" - ) - ) - cg.call_function(0, False) - name = variables.torch_function.get_prev_stack_var_name() - cg.code_options["co_varnames"] += (name,) - cg.append_output(create_instruction("STORE_FAST", argval=name)) cg.add_push_null( lambda: cg.load_import_from( utils.__name__, "set_torch_function_mode_stack" ) ) - cg.foreach(var.symbolic_stack) cg.append_output( create_instruction("BUILD_LIST", arg=len(var.symbolic_stack)) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 829a831d66562..5181152fc5049 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -267,12 +267,13 @@ def resume_fn(self): else: return ReenterWith(self.stack_index) - def exit(self, tx, is_graph_break): + def exit(self, tx): + if hasattr(self, "graph_break") and isinstance( + self.with_context, TorchFunctionModeVariable + ): + return assert self.with_context is not None - if ( - is_graph_break and self.with_context.exit_on_graph_break() - ) or not is_graph_break: - return self.with_context.exit(tx) + return self.with_context.exit(tx) class ReturnValueOp(Exception): @@ -638,17 +639,10 @@ def handle_graph_break( cleanup: List[Instruction] = [] # Reconstruct the context variable CLASS in the block stack for b in self.block_stack: - # Don't exit any modes we have entered, - # output bytecode will mutate the tf mode stack accordingly - if isinstance(b.with_context, TorchFunctionModeVariable): - cg.extend_output( - b.resume_fn().try_except_torch_function_mode( - cg.code_options, cleanup - ) - ) - continue assert b.with_context is not None - assert isinstance(b.with_context, (ContextWrappingVariable)) + assert isinstance( + b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable) + ) b.with_context.reconstruct_type(cg) cg.extend_output(b.resume_fn().try_finally(cg.code_options, cleanup)) self.output.add_output_instructions(cg.get_instructions()) @@ -2301,10 +2295,7 @@ def setup_or_before_with(self, inst): ): unimplemented(f"{inst.opname} {ctx}") - if ( - isinstance(ctx, GenericContextWrappingVariable) - and not ctx.supports_graph_breaks() - ): + if isinstance(ctx, GenericContextWrappingVariable): self.generic_context_manager_depth += 1 # Need this redundant check for mypy @@ -2677,7 +2668,6 @@ def __init__( local_scope=f_locals, global_scope=f_globals, f_code=f_code, - torch_function_mode_stack=torch_function_mode_stack, ), instructions=instructions, f_locals=f_locals, diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 2e998873d0c41..3a5abaebd4573 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -187,7 +187,6 @@ def insert_nops(instructions: List[Any], code_options: Any) -> None: local_scope=locals(), global_scope=globals(), f_code=frame.f_code, - torch_function_mode_stack=[], ) return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0)) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index ccb4f357d9ed9..32c2964bb1abf 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -304,7 +304,6 @@ "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, 
"torch.cuda._get_device_properties": TorchInGraphFunctionVariable, "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable, - "torch.set_default_device": UserFunctionVariable, "torch.sparse_bsc_tensor": SkipFunctionVariable, "torch.sparse_bsr_tensor": SkipFunctionVariable, "torch.sparse_csc_tensor": SkipFunctionVariable, @@ -2802,6 +2801,7 @@ "torch.random.initial_seed", "torch.random.seed", "torch.return_types.pytree_register_structseq", + "torch.set_default_device", "torch.set_default_dtype", "torch.set_default_tensor_type", "torch.set_deterministic_debug_mode", diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 72c2e1a6343ac..5257f0cce9887 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -3097,10 +3097,16 @@ def is_parameter_freezing(): return torch._inductor.config.freezing and not torch.is_grad_enabled() -def get_torch_function_mode_stack(): - return [ +def get_torch_function_mode_stack(filter_ignored=True): + from .variables.torch_function import IGNORED_MODES + + stack = [ get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) ] + if filter_ignored: + stack = [mode for mode in stack if type(mode) not in IGNORED_MODES] + + return stack def get_torch_function_mode_stack_at(ind): diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index b8444cba4bbb3..9e02093d873aa 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -204,7 +204,6 @@ from .torch_function import ( build_torch_function_fn, TensorWithTFOverrideVariable, - torch_function_mode_stack_state_mgr, TorchFunctionModeVariable, ) from .user_defined import ( @@ -1671,16 +1670,15 @@ def wrap_numpy_ndarray(self, value): # but warning is not the end of the world assert isinstance(value.base, np.nditer) - with torch_function_mode_stack_state_mgr.temp_restore_stack(): - try: - tensor_value = _util._try_convert_to_tensor(value) - if readonly: - from torch._prims_common import clone_preserve_strides - - tensor_value = clone_preserve_strides(tensor_value) - except NotImplementedError as e: - # failed to convert to tensor, graph break - unimplemented(str(e)) + try: + tensor_value = _util._try_convert_to_tensor(value) + if readonly: + from torch._prims_common import clone_preserve_strides + + tensor_value = clone_preserve_strides(tensor_value) + except NotImplementedError as e: + # failed to convert to tensor, graph break + unimplemented(str(e)) # We do this because we want the full behavior of guarding the numpy ndarray as if it were # a tensor. 
It's a little annoying to make a VT to throw out, but there's so many side effects here diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index e19c4e254c647..8c4eb3dc4e715 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -125,12 +125,6 @@ def call_function( if isinstance(args[0], UserFunctionVariable): return WrappedUserFunctionVariable(args[0], self) - def supports_graph_breaks(self): - return True - - def exit_on_graph_break(self): - return True - class GenericContextWrappingVariable(UserDefinedObjectVariable): # Some methods in ContextWrappingVariable assumes the arguments are @@ -189,12 +183,6 @@ def exit(self, tx: "InstructionTranslator", *args): tx.generic_context_manager_depth -= 1 return x - def supports_graph_breaks(self): - return False - - def exit_on_graph_break(self): - return True - class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): """represents torch grad requries grad""" diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index e27611e96b3ed..478ab82de4334 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -160,17 +160,7 @@ def get_overridable_functions(): from torch.overrides import get_overridable_functions as get_overridable_functions_ - funcs = set(chain(*get_overridable_functions_().values())) - more = { - torch.ones, - torch.ones_like, - torch.zeros, - torch.zeros_like, - torch.empty, - torch.full, - } - funcs.update(more) - return funcs + return set(chain(*get_overridable_functions_().values())) class BaseTorchVariable(VariableTracker): @@ -846,13 +836,6 @@ def handle_len_torch_function( len(tx.symbolic_torch_function_state.mode_stack) ) - @register(torch._C._get_function_stack_at) - def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs): - assert len(args) == 1 and not kwargs - ind = args[0].as_python_constant() - assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack) - return tx.symbolic_torch_function_state.mode_stack[ind] - @register(torch.set_default_device) def handle_set_default_device( self, tx: "InstructionTranslator", *args, **kwargs @@ -870,7 +853,7 @@ def handle_set_default_device( else: TorchFunctionModeStackVariable.register_device_context_insertion(tx) - return ConstantVariable.create(None) + return None return handlers diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index ffb3d27d4d703..6e52cc688e030 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -2,35 +2,22 @@ import collections import contextlib -import functools import inspect from typing import Deque, Dict, List, TYPE_CHECKING import torch._C import torch.utils._pytree as pytree from torch._guards import Source -from torch.overrides import ( - _get_overloaded_args, - get_default_nowrap_functions, - TorchFunctionMode, -) +from torch.overrides import _get_overloaded_args, get_default_nowrap_functions from torch.utils._device import DeviceContext from ..exc import unimplemented from ..guards import GuardBuilder, install_guard -from ..polyfills import NoEnterTorchFunctionMode from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource -from ..utils import ( - class_has_getattribute, - clear_torch_function_mode_stack, - get_safe_global_name, - has_torch_function, - is_tensor_base_attr_getter, - set_torch_function_mode_stack, -) +from ..utils 
import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter from .base import VariableTracker from .constant import ConstantVariable -from .ctx_manager import GenericContextWrappingVariable +from .ctx_manager import ContextWrappingVariable from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable @@ -69,38 +56,11 @@ if is_tensor_base_attr_getter(fn) ] - -@functools.lru_cache(None) -def get_prev_stack_var_name(): - from ..bytecode_transformation import unique_id - - return unique_id("___prev_torch_function_mode_stack") - - -# Used to clear/restore the python torch function mode stack and temporarily restore it as needed -class TorchFunctionModeStackStateManager: - def __init__(self): - self.stack = [] - - def __enter__(self): - self.stack = torch.overrides._get_current_function_mode_stack() - clear_torch_function_mode_stack() - - def __exit__(self, exc_type, exc_value, traceback): - set_torch_function_mode_stack(self.stack) - self.stack = [] - - @contextlib.contextmanager - def temp_restore_stack(self): - prev = torch.overrides._get_current_function_mode_stack() - set_torch_function_mode_stack(self.stack) - try: - yield - finally: - set_torch_function_mode_stack(prev) - - -torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager() +# Today set default device is placed in the graph and guarded on separately +# so we should not trace through it. In the future we can trace it once +# mode tracing is implemented and not put in the graph, but this is more +# of a BE project and can be evaluated later +IGNORED_MODES = {DeviceContext} class SymbolicTorchFunctionState: @@ -229,26 +189,9 @@ def get_mode_index(cls, ind): return ind + cls.offset -class TorchFunctionModeVariable(GenericContextWrappingVariable): - @staticmethod - def is_supported_torch_function_mode(ty): - # Supported in this sense means we can support graph breaks under the - # context. 
- # We are able to trace custom modes but if there are graph breaks under them - # and they have a custom __enter__/__exit__ we don't handle this for the - # same reason we don't handle generic context managers: there may be side effects - # that are now affected by executing the funtion across two frames instead of one - # Today we support the enter/exit of the default TorchFunctionMode as well as - # DeviceContext (which is used for set_default_device) - return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or ( - not class_has_getattribute(ty) - and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__ - and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__ - ) - +class TorchFunctionModeVariable(ContextWrappingVariable): def __init__(self, value, source=None, **kwargs): - if value is not None: - super().__init__(value, **kwargs) + super().__init__(value, **kwargs) self.value = value self.cm_obj = value # needed for BC with calling enter from CM code self.source = source @@ -278,39 +221,8 @@ def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwar kwargs, ) - def enter(self, tx): - from .torch import TorchInGraphFunctionVariable - - if isinstance(self.value, NoEnterTorchFunctionMode): - return ConstantVariable.create(None) - - TorchInGraphFunctionVariable( - torch._C._push_on_torch_function_stack - ).call_function(tx, [self], {}) - return ConstantVariable.create(None) - - def exit(self, tx: "InstructionTranslator", *args): - from .torch import TorchInGraphFunctionVariable - - TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function( - tx, [], {} - ) - return ConstantVariable.create(None) - - def reconstruct_type(self, codegen): - ty = NoEnterTorchFunctionMode - codegen( - AttrSource( - codegen.tx.import_source(ty.__module__), - ty.__name__, - ) - ) - - def supports_graph_breaks(self): - return True - - def exit_on_graph_break(self): - return False + def _call_func(self, tx: "InstructionTranslator", values): + unimplemented("enter/exit for torch function mode NYI") def _get_all_args(args, kwargs): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index c2d0ea0d1c345..62e03c6da0ac9 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -417,22 +417,10 @@ def call_function( and self.source and not is_forbidden_context_manager(self.value) ): - from torch.overrides import TorchFunctionMode - from .ctx_manager import GenericContextWrappingVariable - from .torch_function import TorchFunctionModeVariable - - if issubclass( - self.value, TorchFunctionMode - ) and TorchFunctionModeVariable.is_supported_torch_function_mode( - self.value - ): - var_cls = TorchFunctionModeVariable - else: - var_cls = GenericContextWrappingVariable cm_obj = tx.output.side_effects.track_object_new( - self.source, self.value, var_cls, {} + self.source, self.value, GenericContextWrappingVariable, {} ) cm_obj.call_method(tx, "__init__", args, kwargs) return cm_obj diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 4a4b72a0267b4..7c4ed25367e06 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2537,40 +2537,90 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { public: TORCH_FUNCTION_MODE_STACK( const py::list& initial_stack, + const py::list& ignored_types, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)), _ref_stack() { + : 
LeafGuard(std::move(verbose_code_parts)), + _ref_stack(), + _ignored_types() { Py_ssize_t len = PyList_Size(initial_stack.ptr()); for (Py_ssize_t idx = 0; idx < len; idx++) { PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref auto type = Py_TYPE(mode); this->_ref_stack.push_back(type); } + + len = PyList_Size(ignored_types.ptr()); + for (Py_ssize_t idx = 0; idx < len; idx++) { + PyObject* type_obj = + PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref + if (PyType_Check(type_obj) == 0) { + PyErr_SetString( + PyExc_TypeError, "ignored_types should contain a list of types"); + return; + } + PyTypeObject* type = (PyTypeObject*)type_obj; + this->_ignored_types.insert(type); + } } bool check_nopybind(PyObject* value) override { // Ignore value arg, only used to satisfy the interface - const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len(); + size_t ref_ind = 0; + const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); const size_t ref_stack_size = this->_ref_stack.size(); - if (len != ref_stack_size) { - return false; + int64_t idx = 0; + while ((idx < len) && (ref_ind < ref_stack_size)) { + std::shared_ptr mode = + at::impl::PythonTorchFunctionTLS::get_stack_at(idx); + + PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); + bool act_ignored = this->_ignored_types.count(mode_type) > 0; + bool ref_ignored = + this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0; + // skip ignored types + if (act_ignored && ref_ignored) { + idx++; + ref_ind++; + continue; + } else if (ref_ignored) { + ref_ind++; + continue; + } else if (act_ignored) { + idx++; + continue; + } + // if we already have more non-ignored modes than the ref stack + // or if the mode doesn't match at the current index, return false + else if (mode_type != _ref_stack.at(ref_ind)) { + return false; + } + ref_ind++; + idx++; } - for (int64_t idx = 0; (size_t)idx < len; idx++) { + for (; ref_ind < ref_stack_size; ref_ind++) { + if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) { + return false; + } + } + + for (; idx < len; idx++) { std::shared_ptr mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx); PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); - if (mode_type != _ref_stack.at(idx)) { + if (!(this->_ignored_types.count(mode_type) > 0)) { return false; } } - return true; + return ref_ind == ref_stack_size && idx == len; } private: std::vector _ref_stack; + std::set _ignored_types; }; class TENSOR_MATCH : public LeafGuard { @@ -3735,7 +3785,7 @@ PyObject* torch_c_dynamo_guards_init() { LeafGuard, std::shared_ptr>( py_m, "TORCH_FUNCTION_MODE_STACK") - .def(py::init()) + .def(py::init()) .def("__call__", &TORCH_FUNCTION_MODE_STACK::check); py::class_>( py_m, "DATA_PTR_MATCH") @@ -3972,9 +4022,10 @@ PyObject* torch_c_dynamo_guards_init() { "add_torch_function_mode_stack_guard", [](GuardManager& self, const py::list& initial_stack, + const py::list& ignored_types, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( - initial_stack, std::move(verbose_code_parts))); + initial_stack, ignored_types, std::move(verbose_code_parts))); }) .def( "add_data_ptr_guard",