Skip to content

Commit

Permalink
[dynamo] Guard on HAS_GRAPH_BREAKS if graph breaks are present (i.e. cache miss if compiled object requires nopython) (pytorch#114073)
Browse files Browse the repository at this point in the history

Fixes pytorch#114059

Pull Request resolved: pytorch#114073
Approved by: https://github.com/ezyang
  • Loading branch information
jon-chuang authored and pytorchmergebot committed Nov 20, 2023
1 parent 2b4c489 commit 0bb29f9
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 4 deletions.
75 changes: 75 additions & 0 deletions test/dynamo/test_recompiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,81 @@ def forward(self, x):
model(x)
self.assertEqual(counter.frame_count, 2)

def test_forbid_nopython_has_graph_break_cache_hit(self):
    """A nopython (fullgraph) compile must not cache-hit an entry compiled
    with graph breaks, while a graph-break-free entry is shared by both.

    Exercises both public entry points (``torch.compile(fullgraph=...)`` and
    ``torch._dynamo.optimize(nopython=...)``) and inspects the installed
    guards via dynamo's private cache-debugging helpers.
    """
    from torch._dynamo.eval_frame import _debug_get_cache_entry_list, innermost_fn

    # Each factory returns (plain compiled fn, nopython compiled fn) sharing
    # one CompileCounter backend, so frame counts accumulate across both.
    for create_functions in [
        lambda f, cnt: (
            torch.compile(f, backend=cnt),
            torch.compile(f, backend=cnt, fullgraph=True),
        ),
        lambda f, cnt: (
            torch._dynamo.optimize(backend=cnt)(f),
            torch._dynamo.optimize(backend=cnt, nopython=True)(f),
        ),
    ]:
        torch._dynamo.reset()

        # 1-D input takes the branch containing an explicit graph break;
        # higher-rank input takes the break-free branch.
        def fn(x):
            if len(x.size()) == 1:
                x = x + 2
                torch._dynamo.graph_break()
                return x + 1
            else:
                return x + 1

        cnt = torch._dynamo.testing.CompileCounter()

        opt_fn, nopython_fn = create_functions(fn, cnt)

        # nopython on the graph-break branch raises before compiling anything.
        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "graph_break"):
            nopython_fn(torch.zeros(1))
        self.assertEqual(cnt.frame_count, 0)

        # Plain compile splits at the break: two frames compiled.
        opt_fn(torch.zeros(1))
        self.assertEqual(cnt.frame_count, 2)

        cache_entries = _debug_get_cache_entry_list(innermost_fn(opt_fn))
        self.assertEqual(len(cache_entries), 1)
        # guarded code with graph break has `___needs_nopython` guard
        self.assertTrue(
            any(
                "___needs_nopython" in part
                for part in cache_entries[0].check_fn.code_parts
            )
        )

        # nopython still raises (does NOT hit the graph-break entry) and
        # compiles nothing new.
        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "graph_break"):
            nopython_fn(torch.zeros(1))
        self.assertEqual(cnt.frame_count, 2)

        # Plain compile hits its own cached entry: no recompile.
        opt_fn(torch.zeros(1))
        self.assertEqual(cnt.frame_count, 2)

        # Break-free branch compiles cleanly under nopython: one new frame.
        nopython_fn(torch.zeros((1, 2)))
        self.assertEqual(cnt.frame_count, 3)

        cache_entries = _debug_get_cache_entry_list(innermost_fn(opt_fn))
        self.assertEqual(len(cache_entries), 2)
        # nopython function with no graph break does not have `___needs_nopython` guard
        self.assertFalse(
            any(
                "___needs_nopython" in part
                for part in cache_entries[0].check_fn.code_parts
            )
        )
        # previous guarded code with graph break still has `___needs_nopython` guard
        self.assertTrue(
            any(
                "___needs_nopython" in part
                for part in cache_entries[1].check_fn.code_parts
            )
        )

        # nopython does not recompile - manages to hit cache entry with no graph breaks
        nopython_fn(torch.zeros((1, 2)))
        self.assertEqual(cnt.frame_count, 3)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down
17 changes: 14 additions & 3 deletions torch/_dynamo/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,19 +285,22 @@ def _maybe_init_guarded_config_cache():
if not hasattr(config_cache, "saved_config_and_hash"):
# Optional[ConfigAndHash]
config_cache.saved_config_and_hash = None
config_cache.nopython = None


@contextlib.contextmanager
def restore_guarded_dynamo_config(
first_ctx: bool, saved_config_and_hash: ConfigAndHash
first_ctx: bool, saved_config_and_hash: ConfigAndHash, nopython: bool
):
_maybe_init_guarded_config_cache()
# Set exactly once from top-level compile
is_top_level = False
try:
if first_ctx and config_cache.saved_config_and_hash is None:
assert config_cache.nopython is None
is_top_level = True
config_cache.saved_config_and_hash = saved_config_and_hash
config_cache.nopython = nopython
log.debug(
"Setting top-level compile config hash: %s",
saved_config_and_hash.hash.hex(),
Expand All @@ -312,6 +315,7 @@ def restore_guarded_dynamo_config(
config_cache.saved_config_and_hash.hash.hex(),
)
config_cache.saved_config_and_hash = None
config_cache.nopython = None


def _get_config_and_hash(dynamic=None):
Expand Down Expand Up @@ -344,6 +348,7 @@ def __init__(
dynamic=None,
compiler_config=None,
save_config=True,
nopython=False,
):
super().__init__()
assert callable(callback) or callback is False or callback is None
Expand All @@ -355,6 +360,7 @@ def __init__(
self.dynamic = dynamic
self.compiler_config = compiler_config
self.save_config = save_config and first_ctx
self.nopython = nopython
if self.save_config:
self.save_and_hash_config()
patch_fn()
Expand Down Expand Up @@ -382,7 +388,7 @@ def __enter__(self):
self.backend_ctx.__enter__()
if self.save_config:
self.dynamo_config_ctx = restore_guarded_dynamo_config(
self.first_ctx, self.saved_config_and_hash
self.first_ctx, self.saved_config_and_hash, self.nopython
)
self.dynamo_config_ctx.__enter__()

Expand Down Expand Up @@ -475,7 +481,7 @@ def _fn(*args, **kwargs):
backend_ctx.__enter__()
if self.save_config:
dynamo_config_ctx = restore_guarded_dynamo_config(
self.first_ctx, self.saved_config_and_hash
self.first_ctx, self.saved_config_and_hash, self.nopython
)
dynamo_config_ctx.__enter__()
try:
Expand Down Expand Up @@ -554,6 +560,7 @@ def __init__(
dynamic=None,
save_config=True,
compiler_config=None,
nopython=False,
):
def on_enter():
install_generation_tagging_init()
Expand All @@ -567,6 +574,7 @@ def on_enter():
dynamic=dynamic,
compiler_config=compiler_config,
save_config=save_config,
nopython=nopython,
)


Expand Down Expand Up @@ -656,6 +664,7 @@ def _optimize_catch_errors(
dynamic=None,
compiler_config=None,
save_config=True,
nopython=False,
):
return OptimizeContext(
catch_errors_wrapper(compile_fn, hooks),
Expand All @@ -664,6 +673,7 @@ def _optimize_catch_errors(
dynamic=dynamic,
compiler_config=compiler_config,
save_config=save_config,
nopython=nopython,
)


Expand Down Expand Up @@ -1469,6 +1479,7 @@ def optimize_assert(
backend_ctx_ctor,
dynamic=dynamic,
save_config=save_config,
nopython=True,
)


Expand Down
8 changes: 8 additions & 0 deletions torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def uninteresting_files():
"___compile_config_hash": (
lambda: torch._dynamo.eval_frame.get_saved_else_current_config_hash().hex()
),
"___needs_nopython": (lambda: torch._dynamo.eval_frame.config_cache.nopython),
"___odict_getitem": collections.OrderedDict.__getitem__,
"___dict_param_key_ids": dict_param_key_ids,
"___dict_const_keys": dict_const_keys,
Expand Down Expand Up @@ -611,6 +612,13 @@ def CONFIG_HASH_MATCH(self, guard: Guard):
self.config_hash = config_hash
self._produce_guard_code(guard, code)

def HAS_GRAPH_BREAK(self, guard: Guard):
    """Install a guard that fails whenever the active compile is nopython.

    This guard is only installed on compiled entries that contain a graph
    break: such an entry must never be a cache hit for a nopython
    (fullgraph) compilation, so it checks at lookup time that the current
    compiled object does not require nopython.
    """
    self._produce_guard_code(guard, ["not ___needs_nopython()"])

def SHAPE_ENV(self, guard: Guard):
# Let's handle ShapeEnv guards. To do this, we will resolve
# shape variables to sources from tracked_fakes. This must happen after
Expand Down
13 changes: 12 additions & 1 deletion torch/_dynamo/output_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,9 @@ def init_ambient_guards(self):

self.guards.add(GlobalStateSource().make_guard(GuardBuilder.CONFIG_HASH_MATCH))

def guard_has_graph_break(self):
    # Record that this graph contains a graph break by installing the
    # HAS_GRAPH_BREAK guard, making the entry a cache miss for any
    # nopython (fullgraph) compilation.
    break_guard = GlobalStateSource().make_guard(GuardBuilder.HAS_GRAPH_BREAK)
    self.guards.add(break_guard)

def add_cleanup_hook(self, fn: Callable[[], Any]):
self.cleanup_hooks.append(fn)

Expand Down Expand Up @@ -769,7 +772,11 @@ def register_leaf_name(leaf_name):
raise AssertionError("unreachable")

def compile_subgraph(
self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None
self,
tx,
partial_convert=False,
reason: Optional[GraphCompileReason] = None,
compile_return_value=False,
):
"""
Generate a subgraph to continue execution on user code.
Expand All @@ -783,6 +790,10 @@ def compile_subgraph(
self.compile_subgraph_reason = reason
self.should_exit = True

if not compile_return_value:
# invalid graph to be cache hit for nopython
self.guard_has_graph_break()

log.debug("COMPILING GRAPH due to %s", reason)

if not all(block.can_restore() for block in tx.block_stack):
Expand Down
1 change: 1 addition & 0 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2240,6 +2240,7 @@ def RETURN_VALUE(self, inst):
reason=GraphCompileReason(
"return_value", [self.frame_summary()], graph_break=False
),
compile_return_value=True,
)
self.output.add_output_instructions([create_instruction("RETURN_VALUE")])

Expand Down

0 comments on commit 0bb29f9

Please sign in to comment.