[cond] add a set_ and data mutation expected failure test (pytorch#128457)

A follow-up to the discussion in pytorch#126936.

Cond errors out early because of a graph break triggered by `DelayGraphBreakVariable`, which is created for `aten.set_` [here](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/variables/tensor.py#L366-L376).

We may need to revisit this test if we allow graph breaks in higher-order ops.
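For context, a minimal eager-mode sketch (not part of this commit; plain tensors, no autograd, no compilation) of the `set_` + view-mutation pattern the branch bodies below exercise: `set_` rebinds a tensor to another storage, while a previously created view keeps aliasing the old one.

```python
import torch

a = torch.ones(3, 3)
tmp = torch.full((3, 3), 5.0)

a_view = a.view(-1)  # aliases a's *original* storage
a.set_(tmp)          # a now shares tmp's storage, size, and strides
a_view.mul_(2)       # mutates the old storage; a and tmp are unaffected
print(a + tmp)       # all 10s: the view mutation is no longer visible through a
```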

Pull Request resolved: pytorch#128457
Approved by: https://github.com/zou3519
ydwu4 authored and pytorchmergebot committed Jun 13, 2024
1 parent c472cec commit c1cd946
Showing 1 changed file with 68 additions and 0 deletions.

test/functorch/test_control_flow.py (68 additions, 0 deletions)
@@ -2599,6 +2599,74 @@ def wrapper(x):
        res = torch.vmap(wrapper)(a)
        self.assertEqual(res, a + 1)

    def test_cond_trace_set__and_mutate_input(self):
        def f(a, tmp):
            a_view = a.view(-1)
            with torch.no_grad():
                a.set_(tmp)
                a_view.mul_(2)
            return a + tmp

        inp = torch.ones(3, 3, requires_grad=True)
        tmp = torch.ones(3, 3, requires_grad=True)
        # graph break: torch._dynamo.exc.Unsupported: call_function DelayGraphBreakVariable() [TensorVariable()] {}
        # due to set_
        with self.assertRaisesRegex(
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            "Cond doesn't work unless it is captured completely with torch.compile",
        ):
            torch.cond(inp.sum() > 0, f, f, (inp, tmp))

    def test_cond_trace_set__and_mutate_intermediate(self):
        def f(a, tmp):
            a = a.clone()
            a_view = a.view(-1)
            tmp = tmp.clone()
            with torch.no_grad():
                a.set_(tmp)
                a_view.mul_(2)
            return a + tmp

        inp = torch.ones(3, 3, requires_grad=True)
        tmp = torch.ones(3, 3, requires_grad=True)

        class Mod(torch.nn.Module):
            def forward(self, inp: torch.Tensor, tmp: torch.Tensor) -> torch.Tensor:
                return torch.cond(inp.sum() > 0, f, f, (inp, tmp))

        with self.assertRaisesRegex(
            RuntimeError, "cannot mutate tensors with frozen storage"
        ):
            out = torch.compile(Mod(), backend="aot_eager")(inp, tmp)

        with self.assertRaisesRegex(
            RuntimeError, "cannot mutate tensors with frozen storage"
        ):
            out = torch.compile(Mod(), backend="inductor")(inp, tmp)

        from torch._dynamo.testing import EagerAndRecordGraphs

        backend = EagerAndRecordGraphs()
        out = torch.compile(Mod(), backend=backend)(inp, tmp)
        self.assertExpectedInline(
            backend.graphs[0].cond_true_0.code.strip("\n"),
            """\
def forward(self, l_inp_, l_tmp_):
    l_inp__1 = l_inp_
    l_tmp__1 = l_tmp_
    clone = l_inp__1.clone(); l_inp__1 = None
    view = clone.view(-1)
    clone_1 = l_tmp__1.clone(); l_tmp__1 = None
    _set_grad_enabled = torch._C._set_grad_enabled(False)
    set_ = clone.set_(clone_1)
    mul_ = view.mul_(2); view = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
    add = clone + clone_1; clone = clone_1 = None
    return (add,)
""",
        )
        self.assertEqual(out, f(inp, tmp))


instantiate_parametrized_tests(TestControlFlowTraced)
