Skip to content

Commit

Permalink
Revert "[Dynamo] Match closures by code ID (pytorch#109427)"
Browse files Browse the repository at this point in the history
This reverts commit 3de0857.

Reverted pytorch#109427 on behalf of https://github.com/voznesenskym due to Fails test `PYTORCH_TEST_WITH_DYNAMO=1 python test_ops.py -k test_out_warning__refs_cat_cpu ([comment](pytorch#109427 (comment)))
  • Loading branch information
pytorchmergebot committed Sep 26, 2023
1 parent a740969 commit 194d9aa
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 94 deletions.
70 changes: 0 additions & 70 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7596,76 +7596,6 @@ def inner(a, b, res_dtype):
torch.set_default_dtype(torch.double)
foo()

def test_no_recompile_inner_function(self):
def forward(inp):
def g(y):
return inp + y

print("graph break")
return g(torch.rand([1]))

cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(forward)

input = torch.rand([2])
_ = opt_fn(input)
_ = opt_fn(input)
_ = opt_fn(input)
# Should not have recompiled
self.assertEqual(cnts.frame_count, 1)

def test_no_recompile_inner_lambda(self):
def forward(inp):
g = lambda y: inp + y
print("graph break")
return g(torch.rand([1]))

cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(forward)

input = torch.rand([2])
_ = opt_fn(input)
_ = opt_fn(input)
_ = opt_fn(input)
# Should not have recompiled
self.assertEqual(cnts.frame_count, 1)

def test_complex_closure(self):
@torch.compile
def forward(y):
def a():
def x(z):
return y + z

return x

return a()

input1 = torch.rand([2])
input2 = torch.rand([2])
res = forward(input1)(input2)
self.assertTrue(same(res, input1 + input2))

def test_non_inlined_closure(self):
@torch.compile()
def program(x, y):
one = lambda x, y: x + y

def inner():
# Force no inlining
torch._dynamo.graph_break()
return one(x, y)

res = inner()
one = lambda x, y: x - y
res += inner()
return res

input1 = torch.randn(1)
input2 = torch.randn(1)

self.assertTrue(same(program(input1, input2), input1 + input1))


class TestTracer(JitTestCase):
def test_jit_save(self):
Expand Down
8 changes: 1 addition & 7 deletions test/functorch/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1529,15 +1529,9 @@ def foo(x, true_fn, false_fn):
inp = torch.ones(3, 4)
exp_out = inp.sin()
iter_n = torch._dynamo.config.cache_size_limit + 1

# Need this because Dynamo checks lambda code ID not object itself.
def make_dummy_fn(op):
exec(f"temp = lambda x: x.{op}()")
return locals()["temp"]

for _ in range(iter_n):
# each lambda has a different object id thus fails the guard
self.assertEqual(foo(inp, make_dummy_fn("cos"), make_dummy_fn("sin")), exp_out)
self.assertEqual(foo(inp, lambda x: x.cos(), lambda x: x.sin()), exp_out)
self.assertEqual(counters["stats"]["calls_captured"], iter_n)
self.assertEqual(counters["stats"]["unique_graphs"], iter_n)

Expand Down
16 changes: 1 addition & 15 deletions torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class GuardCodeList:
class GuardBuilder(GuardBuilderBase):
def __init__(
self,
id_ref: Callable[[Any], str],
id_ref: Callable[[Type[object]], str],
source_ref: Callable[[Source], str],
lookup_weakrefs: Callable[[Type[object]], ReferenceType[object]],
user_scope: Optional[Dict[str, object]],
Expand Down Expand Up @@ -489,20 +489,6 @@ def FUNCTION_MATCH(self, guard: Guard):
if guard.is_local():
return self.ID_MATCH(guard)

def CLOSURE_MATCH(self, guard: Guard):
"""matches a closure by __code__ id."""
if guard.is_local():
val = self.get(guard.name)
# Strictly only want user-defined functions
if type(val) == types.FunctionType and hasattr(val, "__code__"):
ref = self.arg_ref(guard)
code = [
f"___check_obj_id(getattr({ref}, '__code__', None), {self.id_ref(val.__code__)})",
]
self._produce_guard_code(guard, code)
else:
self.FUNCTION_MATCH(guard)

def BUILTIN_MATCH(self, guard: Guard):
return self.FUNCTION_MATCH(guard)

Expand Down
4 changes: 2 additions & 2 deletions torch/_dynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def _id_dispatch(cls):
lambda self, value: LambdaVariable(
InspectSignatureVariable.create,
source=self.source,
guards=self.make_guards(GuardBuilder.CLOSURE_MATCH),
guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
),
),
(comptime, lambda self, value: ComptimeVariable()),
Expand Down Expand Up @@ -562,7 +562,7 @@ def index_source(key):
return UserFunctionVariable(
value,
source=self.source,
guards=make_guards(GuardBuilder.CLOSURE_MATCH),
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)
elif istype(value, (types.ModuleType, replay_record.DummyModule)):
return PythonModuleVariable(
Expand Down

0 comments on commit 194d9aa

Please sign in to comment.