diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 6b6b1de7225953..e2894c2fa6052a 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -7596,76 +7596,6 @@ def inner(a, b, res_dtype): torch.set_default_dtype(torch.double) foo() - def test_no_recompile_inner_function(self): - def forward(inp): - def g(y): - return inp + y - - print("graph break") - return g(torch.rand([1])) - - cnts = torch._dynamo.testing.CompileCounter() - opt_fn = torch._dynamo.optimize(cnts)(forward) - - input = torch.rand([2]) - _ = opt_fn(input) - _ = opt_fn(input) - _ = opt_fn(input) - # Should not have recompiled - self.assertEqual(cnts.frame_count, 1) - - def test_no_recompile_inner_lambda(self): - def forward(inp): - g = lambda y: inp + y - print("graph break") - return g(torch.rand([1])) - - cnts = torch._dynamo.testing.CompileCounter() - opt_fn = torch._dynamo.optimize(cnts)(forward) - - input = torch.rand([2]) - _ = opt_fn(input) - _ = opt_fn(input) - _ = opt_fn(input) - # Should not have recompiled - self.assertEqual(cnts.frame_count, 1) - - def test_complex_closure(self): - @torch.compile - def forward(y): - def a(): - def x(z): - return y + z - - return x - - return a() - - input1 = torch.rand([2]) - input2 = torch.rand([2]) - res = forward(input1)(input2) - self.assertTrue(same(res, input1 + input2)) - - def test_non_inlined_closure(self): - @torch.compile() - def program(x, y): - one = lambda x, y: x + y - - def inner(): - # Force no inlining - torch._dynamo.graph_break() - return one(x, y) - - res = inner() - one = lambda x, y: x - y - res += inner() - return res - - input1 = torch.randn(1) - input2 = torch.randn(1) - - self.assertTrue(same(program(input1, input2), input1 + input1)) - class TestTracer(JitTestCase): def test_jit_save(self): diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index a1ea485c3ac2fc..fce2b9ce09e5b6 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -1529,15 +1529,9 @@ def foo(x, true_fn, false_fn): inp = torch.ones(3, 4) exp_out = inp.sin() iter_n = torch._dynamo.config.cache_size_limit + 1 - - # Need this because Dynamo checks lambda code ID not object itself. - def make_dummy_fn(op): - exec(f"temp = lambda x: x.{op}()") - return locals()["temp"] - for _ in range(iter_n): # each lambda has a different object id thus fails the guard - self.assertEqual(foo(inp, make_dummy_fn("cos"), make_dummy_fn("sin")), exp_out) + self.assertEqual(foo(inp, lambda x: x.cos(), lambda x: x.sin()), exp_out) self.assertEqual(counters["stats"]["calls_captured"], iter_n) self.assertEqual(counters["stats"]["unique_graphs"], iter_n) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 0c4c95e998fab6..971306c432321e 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -182,7 +182,7 @@ class GuardCodeList: class GuardBuilder(GuardBuilderBase): def __init__( self, - id_ref: Callable[[Any], str], + id_ref: Callable[[Type[object]], str], source_ref: Callable[[Source], str], lookup_weakrefs: Callable[[Type[object]], ReferenceType[object]], user_scope: Optional[Dict[str, object]], @@ -489,20 +489,6 @@ def FUNCTION_MATCH(self, guard: Guard): if guard.is_local(): return self.ID_MATCH(guard) - def CLOSURE_MATCH(self, guard: Guard): - """matches a closure by __code__ id.""" - if guard.is_local(): - val = self.get(guard.name) - # Strictly only want user-defined functions - if type(val) == types.FunctionType and hasattr(val, "__code__"): - ref = self.arg_ref(guard) - code = [ - f"___check_obj_id(getattr({ref}, '__code__', None), {self.id_ref(val.__code__)})", - ] - self._produce_guard_code(guard, code) - else: - self.FUNCTION_MATCH(guard) - def BUILTIN_MATCH(self, guard: Guard): return self.FUNCTION_MATCH(guard) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index d7ebccb769191a..b4aab738ce50ac 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -338,7 +338,7 @@ def _id_dispatch(cls): lambda self, value: LambdaVariable( InspectSignatureVariable.create, source=self.source, - guards=self.make_guards(GuardBuilder.CLOSURE_MATCH), + guards=self.make_guards(GuardBuilder.FUNCTION_MATCH), ), ), (comptime, lambda self, value: ComptimeVariable()), @@ -562,7 +562,7 @@ def index_source(key): return UserFunctionVariable( value, source=self.source, - guards=make_guards(GuardBuilder.CLOSURE_MATCH), + guards=make_guards(GuardBuilder.FUNCTION_MATCH), ) elif istype(value, (types.ModuleType, replay_record.DummyModule)): return PythonModuleVariable(