From 194d9aa0f2ca9ee2c996af5c2839ed1dc3f06e46 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 26 Sep 2023 18:54:36 +0000 Subject: [PATCH] Revert "[Dynamo] Match closures by code ID (#109427)" This reverts commit 3de08575031bc0ea770b5935dec13046d8ba7992. Reverted https://github.com/pytorch/pytorch/pull/109427 on behalf of https://github.com/voznesenskym due to Fails test `PYTORCH_TEST_WITH_DYNAMO=1 python test_ops.py -k test_out_warning__refs_cat_cpu ([comment](https://github.com/pytorch/pytorch/pull/109427#issuecomment-1736101561)) --- test/dynamo/test_misc.py | 70 ----------------------------- test/functorch/test_control_flow.py | 8 +--- torch/_dynamo/guards.py | 16 +------ torch/_dynamo/variables/builder.py | 4 +- 4 files changed, 4 insertions(+), 94 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 6b6b1de722595..e2894c2fa6052 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -7596,76 +7596,6 @@ def inner(a, b, res_dtype): torch.set_default_dtype(torch.double) foo() - def test_no_recompile_inner_function(self): - def forward(inp): - def g(y): - return inp + y - - print("graph break") - return g(torch.rand([1])) - - cnts = torch._dynamo.testing.CompileCounter() - opt_fn = torch._dynamo.optimize(cnts)(forward) - - input = torch.rand([2]) - _ = opt_fn(input) - _ = opt_fn(input) - _ = opt_fn(input) - # Should not have recompiled - self.assertEqual(cnts.frame_count, 1) - - def test_no_recompile_inner_lambda(self): - def forward(inp): - g = lambda y: inp + y - print("graph break") - return g(torch.rand([1])) - - cnts = torch._dynamo.testing.CompileCounter() - opt_fn = torch._dynamo.optimize(cnts)(forward) - - input = torch.rand([2]) - _ = opt_fn(input) - _ = opt_fn(input) - _ = opt_fn(input) - # Should not have recompiled - self.assertEqual(cnts.frame_count, 1) - - def test_complex_closure(self): - @torch.compile - def forward(y): - def a(): - def x(z): - return y + z - - return x - - return a() - - input1 = torch.rand([2]) - input2 = torch.rand([2]) - res = forward(input1)(input2) - self.assertTrue(same(res, input1 + input2)) - - def test_non_inlined_closure(self): - @torch.compile() - def program(x, y): - one = lambda x, y: x + y - - def inner(): - # Force no inlining - torch._dynamo.graph_break() - return one(x, y) - - res = inner() - one = lambda x, y: x - y - res += inner() - return res - - input1 = torch.randn(1) - input2 = torch.randn(1) - - self.assertTrue(same(program(input1, input2), input1 + input1)) - class TestTracer(JitTestCase): def test_jit_save(self): diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index a1ea485c3ac2f..fce2b9ce09e5b 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -1529,15 +1529,9 @@ def foo(x, true_fn, false_fn): inp = torch.ones(3, 4) exp_out = inp.sin() iter_n = torch._dynamo.config.cache_size_limit + 1 - - # Need this because Dynamo checks lambda code ID not object itself. - def make_dummy_fn(op): - exec(f"temp = lambda x: x.{op}()") - return locals()["temp"] - for _ in range(iter_n): # each lambda has a different object id thus fails the guard - self.assertEqual(foo(inp, make_dummy_fn("cos"), make_dummy_fn("sin")), exp_out) + self.assertEqual(foo(inp, lambda x: x.cos(), lambda x: x.sin()), exp_out) self.assertEqual(counters["stats"]["calls_captured"], iter_n) self.assertEqual(counters["stats"]["unique_graphs"], iter_n) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 0c4c95e998fab..971306c432321 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -182,7 +182,7 @@ class GuardCodeList: class GuardBuilder(GuardBuilderBase): def __init__( self, - id_ref: Callable[[Any], str], + id_ref: Callable[[Type[object]], str], source_ref: Callable[[Source], str], lookup_weakrefs: Callable[[Type[object]], ReferenceType[object]], user_scope: Optional[Dict[str, object]], @@ -489,20 +489,6 @@ def FUNCTION_MATCH(self, guard: Guard): if guard.is_local(): return self.ID_MATCH(guard) - def CLOSURE_MATCH(self, guard: Guard): - """matches a closure by __code__ id.""" - if guard.is_local(): - val = self.get(guard.name) - # Strictly only want user-defined functions - if type(val) == types.FunctionType and hasattr(val, "__code__"): - ref = self.arg_ref(guard) - code = [ - f"___check_obj_id(getattr({ref}, '__code__', None), {self.id_ref(val.__code__)})", - ] - self._produce_guard_code(guard, code) - else: - self.FUNCTION_MATCH(guard) - def BUILTIN_MATCH(self, guard: Guard): return self.FUNCTION_MATCH(guard) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index d7ebccb769191..b4aab738ce50a 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -338,7 +338,7 @@ def _id_dispatch(cls): lambda self, value: LambdaVariable( InspectSignatureVariable.create, source=self.source, - guards=self.make_guards(GuardBuilder.CLOSURE_MATCH), + guards=self.make_guards(GuardBuilder.FUNCTION_MATCH), ), ), (comptime, lambda self, value: ComptimeVariable()), @@ -562,7 +562,7 @@ def index_source(key): return UserFunctionVariable( value, source=self.source, - guards=make_guards(GuardBuilder.CLOSURE_MATCH), + guards=make_guards(GuardBuilder.FUNCTION_MATCH), ) elif istype(value, (types.ModuleType, replay_record.DummyModule)): return PythonModuleVariable(