From 0bb29f945079ac4c83d674f7b3ff755cfb5396cf Mon Sep 17 00:00:00 2001
From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com>
Date: Mon, 20 Nov 2023 19:31:59 +0000
Subject: [PATCH] [dynamo] Guard on `HAS_GRAPH_BREAKS` if graph breaks are
 present (i.e. cache miss if compiled object requires nopython) (#114073)

Fixes https://github.com/pytorch/pytorch/issues/114059
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114073
Approved by: https://github.com/ezyang
---
 test/dynamo/test_recompiles.py    | 75 +++++++++++++++++++++++++++++++
 torch/_dynamo/eval_frame.py       | 17 +++++--
 torch/_dynamo/guards.py           |  8 ++++
 torch/_dynamo/output_graph.py     | 13 +++++-
 torch/_dynamo/symbolic_convert.py |  1 +
 5 files changed, 110 insertions(+), 4 deletions(-)

diff --git a/test/dynamo/test_recompiles.py b/test/dynamo/test_recompiles.py
index 171bf9020d7287..96499fb0c98c8e 100644
--- a/test/dynamo/test_recompiles.py
+++ b/test/dynamo/test_recompiles.py
@@ -356,6 +356,81 @@ def forward(self, x):
             model(x)
         self.assertEqual(counter.frame_count, 2)
 
+    def test_forbid_nopython_has_graph_break_cache_hit(self):
+        from torch._dynamo.eval_frame import _debug_get_cache_entry_list, innermost_fn
+
+        for create_functions in [
+            lambda f, cnt: (
+                torch.compile(f, backend=cnt),
+                torch.compile(f, backend=cnt, fullgraph=True),
+            ),
+            lambda f, cnt: (
+                torch._dynamo.optimize(backend=cnt)(f),
+                torch._dynamo.optimize(backend=cnt, nopython=True)(f),
+            ),
+        ]:
+            torch._dynamo.reset()
+
+            def fn(x):
+                if len(x.size()) == 1:
+                    x = x + 2
+                    torch._dynamo.graph_break()
+                    return x + 1
+                else:
+                    return x + 1
+
+            cnt = torch._dynamo.testing.CompileCounter()
+
+            opt_fn, nopython_fn = create_functions(fn, cnt)
+
+            with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "graph_break"):
+                nopython_fn(torch.zeros(1))
+            self.assertEqual(cnt.frame_count, 0)
+
+            opt_fn(torch.zeros(1))
+            self.assertEqual(cnt.frame_count, 2)
+
+            cache_entries = _debug_get_cache_entry_list(innermost_fn(opt_fn))
+            self.assertEqual(len(cache_entries), 1)
+            # guarded code with graph break has `___needs_nopython` guard
+            self.assertTrue(
+                any(
+                    "___needs_nopython" in part
+                    for part in cache_entries[0].check_fn.code_parts
+                )
+            )
+
+            with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "graph_break"):
+                nopython_fn(torch.zeros(1))
+            self.assertEqual(cnt.frame_count, 2)
+
+            opt_fn(torch.zeros(1))
+            self.assertEqual(cnt.frame_count, 2)
+
+            nopython_fn(torch.zeros((1, 2)))
+            self.assertEqual(cnt.frame_count, 3)
+
+            cache_entries = _debug_get_cache_entry_list(innermost_fn(opt_fn))
+            self.assertEqual(len(cache_entries), 2)
+            # nopython function with no graph break does not have `___needs_nopython` guard
+            self.assertFalse(
+                any(
+                    "___needs_nopython" in part
+                    for part in cache_entries[0].check_fn.code_parts
+                )
+            )
+            # previous guarded code with graph break still has `___needs_nopython` guard
+            self.assertTrue(
+                any(
+                    "___needs_nopython" in part
+                    for part in cache_entries[1].check_fn.code_parts
+                )
+            )
+
+            # nopython does not recompile - manages to hit cache entry with no graph breaks
+            nopython_fn(torch.zeros((1, 2)))
+            self.assertEqual(cnt.frame_count, 3)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 3636fc3581616f..b8c95fb168d87a 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -285,19 +285,22 @@ def _maybe_init_guarded_config_cache():
     if not hasattr(config_cache, "saved_config_and_hash"):
         # Optional[ConfigAndHash]
         config_cache.saved_config_and_hash = None
+        config_cache.nopython = None
 
 
 @contextlib.contextmanager
 def restore_guarded_dynamo_config(
-    first_ctx: bool, saved_config_and_hash: ConfigAndHash
+    first_ctx: bool, saved_config_and_hash: ConfigAndHash, nopython: bool
 ):
     _maybe_init_guarded_config_cache()
     # Set exactly once from top-level compile
     is_top_level = False
     try:
         if first_ctx and config_cache.saved_config_and_hash is None:
+            assert config_cache.nopython is None
             is_top_level = True
             config_cache.saved_config_and_hash = saved_config_and_hash
+            config_cache.nopython = nopython
             log.debug(
                 "Setting top-level compile config hash: %s",
                 saved_config_and_hash.hash.hex(),
@@ -312,6 +315,7 @@ def restore_guarded_dynamo_config(
                 config_cache.saved_config_and_hash.hash.hex(),
             )
             config_cache.saved_config_and_hash = None
+            config_cache.nopython = None
 
 
 def _get_config_and_hash(dynamic=None):
@@ -344,6 +348,7 @@ def __init__(
         dynamic=None,
         compiler_config=None,
         save_config=True,
+        nopython=False,
     ):
         super().__init__()
         assert callable(callback) or callback is False or callback is None
@@ -355,6 +360,7 @@ def __init__(
         self.dynamic = dynamic
         self.compiler_config = compiler_config
         self.save_config = save_config and first_ctx
+        self.nopython = nopython
         if self.save_config:
             self.save_and_hash_config()
         patch_fn()
@@ -382,7 +388,7 @@ def __enter__(self):
             self.backend_ctx.__enter__()
         if self.save_config:
             self.dynamo_config_ctx = restore_guarded_dynamo_config(
-                self.first_ctx, self.saved_config_and_hash
+                self.first_ctx, self.saved_config_and_hash, self.nopython
             )
             self.dynamo_config_ctx.__enter__()
 
@@ -475,7 +481,7 @@ def _fn(*args, **kwargs):
                 backend_ctx.__enter__()
             if self.save_config:
                 dynamo_config_ctx = restore_guarded_dynamo_config(
-                    self.first_ctx, self.saved_config_and_hash
+                    self.first_ctx, self.saved_config_and_hash, self.nopython
                 )
                 dynamo_config_ctx.__enter__()
             try:
@@ -554,6 +560,7 @@ def __init__(
         dynamic=None,
         save_config=True,
         compiler_config=None,
+        nopython=False,
    ):
         def on_enter():
             install_generation_tagging_init()
@@ -567,6 +574,7 @@ def on_enter():
             dynamic=dynamic,
             compiler_config=compiler_config,
             save_config=save_config,
+            nopython=nopython,
         )
 
 
@@ -656,6 +664,7 @@ def _optimize_catch_errors(
     dynamic=None,
     compiler_config=None,
     save_config=True,
+    nopython=False,
 ):
     return OptimizeContext(
         catch_errors_wrapper(compile_fn, hooks),
@@ -664,6 +673,7 @@ def _optimize_catch_errors(
         dynamic=dynamic,
         compiler_config=compiler_config,
         save_config=save_config,
+        nopython=nopython,
     )
 
 
@@ -1469,6 +1479,7 @@ def optimize_assert(
         backend_ctx_ctor,
         dynamic=dynamic,
         save_config=save_config,
+        nopython=True,
     )
 
 
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index 3cbf90dfe5d9be..1b068402019b0b 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -110,6 +110,7 @@ def uninteresting_files():
     "___compile_config_hash": (
         lambda: torch._dynamo.eval_frame.get_saved_else_current_config_hash().hex()
     ),
+    "___needs_nopython": (lambda: torch._dynamo.eval_frame.config_cache.nopython),
     "___odict_getitem": collections.OrderedDict.__getitem__,
     "___dict_param_key_ids": dict_param_key_ids,
     "___dict_const_keys": dict_const_keys,
@@ -611,6 +612,13 @@ def CONFIG_HASH_MATCH(self, guard: Guard):
         self.config_hash = config_hash
         self._produce_guard_code(guard, code)
 
+    def HAS_GRAPH_BREAK(self, guard: Guard):
+        # If this compiled entry has a graph break / is not a single graph, it is a cache miss
+        # if the compiled object needs nopython. We only need to install this guard if
+        # there is a graph break.
+        code = ["not ___needs_nopython()"]
+        self._produce_guard_code(guard, code)
+
     def SHAPE_ENV(self, guard: Guard):
         # Let's handle ShapeEnv guards. To do this, we will resolve
         # shape variables to sources from tracked_fakes. This must happen after
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 23be43203b18dd..f17e3e0f9e220e 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -372,6 +372,9 @@ def init_ambient_guards(self):
 
         self.guards.add(GlobalStateSource().make_guard(GuardBuilder.CONFIG_HASH_MATCH))
 
+    def guard_has_graph_break(self):
+        self.guards.add(GlobalStateSource().make_guard(GuardBuilder.HAS_GRAPH_BREAK))
+
     def add_cleanup_hook(self, fn: Callable[[], Any]):
         self.cleanup_hooks.append(fn)
 
@@ -769,7 +772,11 @@ def register_leaf_name(leaf_name):
             raise AssertionError("unreachable")
 
     def compile_subgraph(
-        self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None
+        self,
+        tx,
+        partial_convert=False,
+        reason: Optional[GraphCompileReason] = None,
+        compile_return_value=False,
     ):
         """
         Generate a subgraph to continue execution on user code.
@@ -783,6 +790,10 @@ def compile_subgraph(
         self.compile_subgraph_reason = reason
         self.should_exit = True
 
+        if not compile_return_value:
+            # invalid graph to be cache hit for nopython
+            self.guard_has_graph_break()
+
         log.debug("COMPILING GRAPH due to %s", reason)
 
         if not all(block.can_restore() for block in tx.block_stack):
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 32469a9bcb71aa..5189a6df7b01f9 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -2240,6 +2240,7 @@ def RETURN_VALUE(self, inst):
             reason=GraphCompileReason(
                 "return_value", [self.frame_summary()], graph_break=False
             ),
+            compile_return_value=True,
         )
         self.output.add_output_instructions([create_instruction("RETURN_VALUE")])
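Usage note (not part of the diff to apply): the sketch below is a minimal, stand-alone
illustration of the behavior the new `___needs_nopython` guard enforces, adapted from
the test added above. It assumes a torch build that already includes this patch;
`CompileCounter` is the same test-only backend the new test uses.

    import torch
    import torch._dynamo
    import torch._dynamo.testing

    def fn(x):
        if len(x.size()) == 1:
            x = x + 2
            torch._dynamo.graph_break()  # forces a graph break on this path
            return x + 1
        return x + 1

    cnt = torch._dynamo.testing.CompileCounter()
    opt_fn = torch.compile(fn, backend=cnt)                       # graph breaks allowed
    nopython_fn = torch.compile(fn, backend=cnt, fullgraph=True)  # graph breaks forbidden

    # Compiles two graphs around the break and caches an entry that carries the
    # `___needs_nopython` guard.
    opt_fn(torch.zeros(1))

    # Without the guard, this call could hit the cached graph-break entry and silently
    # succeed despite fullgraph=True; with the guard it is a cache miss and raises.
    try:
        nopython_fn(torch.zeros(1))
    except torch._dynamo.exc.Unsupported:
        print("fullgraph=True correctly rejected the cached graph-break entry")

Because only nopython callers miss on the guard, `opt_fn` keeps hitting the cached
entry without recompiling, as the frame-count assertions in the test verify.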