From c1cd946818442aca8c7f812b16d187ce1586c3bc Mon Sep 17 00:00:00 2001
From: Yidi Wu
Date: Wed, 12 Jun 2024 13:56:27 -0700
Subject: [PATCH] [cond] add a set_ and data mutation expected failure test
 (#128457)

A follow-up to the discussion in
https://github.com/pytorch/pytorch/pull/126936.

Cond errors out early because of a graph break triggered by
DelayGraphBreakVariable, which is created due to `aten.set_`
[here](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/variables/tensor.py#L366-L376).
We may need to revisit this test if we ever allow graph breaks inside
higher-order ops. A minimal standalone sketch of the failing pattern
follows the diff below.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128457
Approved by: https://github.com/zou3519
---
 test/functorch/test_control_flow.py | 68 +++++++++++++++++++++++++++++
 1 file changed, 68 insertions(+)

diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py
index f538c5af78cef..63d00cf61bad8 100644
--- a/test/functorch/test_control_flow.py
+++ b/test/functorch/test_control_flow.py
@@ -2599,6 +2599,74 @@ def wrapper(x):
 
         res = torch.vmap(wrapper)(a)
         self.assertEqual(res, a + 1)
 
+    def test_cond_trace_set__and_mutate_input(self):
+        def f(a, tmp):
+            a_view = a.view(-1)
+            with torch.no_grad():
+                a.set_(tmp)
+                a_view.mul_(2)
+            return a + tmp
+
+        inp = torch.ones(3, 3, requires_grad=True)
+        tmp = torch.ones(3, 3, requires_grad=True)
+        # graph break: torch._dynamo.exc.Unsupported: call_function DelayGraphBreakVariable() [TensorVariable()] {}
+        # due to set_
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "Cond doesn't work unless it is captured completely with torch.compile",
+        ):
+            torch.cond(inp.sum() > 0, f, f, (inp, tmp))
+
+    def test_cond_trace_set__and_mutate_intermediate(self):
+        def f(a, tmp):
+            a = a.clone()
+            a_view = a.view(-1)
+            tmp = tmp.clone()
+            with torch.no_grad():
+                a.set_(tmp)
+                a_view.mul_(2)
+            return a + tmp
+
+        inp = torch.ones(3, 3, requires_grad=True)
+        tmp = torch.ones(3, 3, requires_grad=True)
+
+        class Mod(torch.nn.Module):
+            def forward(self, inp: torch.Tensor, tmp: torch.Tensor) -> torch.Tensor:
+                return torch.cond(inp.sum() > 0, f, f, (inp, tmp))
+
+        with self.assertRaisesRegex(
+            RuntimeError, "cannot mutate tensors with frozen storage"
+        ):
+            out = torch.compile(Mod(), backend="aot_eager")(inp, tmp)
+
+        with self.assertRaisesRegex(
+            RuntimeError, "cannot mutate tensors with frozen storage"
+        ):
+            out = torch.compile(Mod(), backend="inductor")(inp, tmp)
+
+        from torch._dynamo.testing import EagerAndRecordGraphs
+
+        backend = EagerAndRecordGraphs()
+        out = torch.compile(Mod(), backend=backend)(inp, tmp)
+        self.assertExpectedInline(
+            backend.graphs[0].cond_true_0.code.strip("\n"),
+            """\
+def forward(self, l_inp_, l_tmp_):
+    l_inp__1 = l_inp_
+    l_tmp__1 = l_tmp_
+    clone = l_inp__1.clone();  l_inp__1 = None
+    view = clone.view(-1)
+    clone_1 = l_tmp__1.clone();  l_tmp__1 = None
+    _set_grad_enabled = torch._C._set_grad_enabled(False)
+    set_ = clone.set_(clone_1)
+    mul_ = view.mul_(2);  view = None
+    _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
+    add = clone + clone_1;  clone = clone_1 = None
+    return (add,)
+""",
+        )
+        self.assertEqual(out, f(inp, tmp))
+
 
 instantiate_parametrized_tests(TestControlFlowTraced)
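
A minimal standalone sketch of the pattern exercised by
test_cond_trace_set__and_mutate_input above: calling `Tensor.set_` on a branch
input makes Dynamo create a DelayGraphBreakVariable (i.e. a graph break), and
because cond must be captured as a single graph it raises
UncapturedHigherOrderOpError. The helper name `branch` is illustrative, and the
exact exception message may differ across PyTorch versions.

    import torch
    import torch._dynamo.exc

    def branch(a, tmp):
        a_view = a.view(-1)
        with torch.no_grad():
            # aten.set_ is what triggers the DelayGraphBreakVariable graph break
            a.set_(tmp)
            a_view.mul_(2)
        return a + tmp

    inp = torch.ones(3, 3, requires_grad=True)
    tmp = torch.ones(3, 3, requires_grad=True)

    try:
        torch.cond(inp.sum() > 0, branch, branch, (inp, tmp))
    except torch._dynamo.exc.UncapturedHigherOrderOpError as e:
        # Expected today: "Cond doesn't work unless it is captured completely
        # with torch.compile"
        print(f"cond raised as expected: {e}")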