diff --git a/tests/test_thunder_ddp.py b/tests/test_thunder_ddp.py index 957602e8e7..75227cb301 100644 --- a/tests/test_thunder_ddp.py +++ b/tests/test_thunder_ddp.py @@ -42,14 +42,14 @@ def test_no_backward_sync(choice): if "thunder" in choice: import thunder - model = thunder.jit(model) + cmodel = model = thunder.jit(model) model = fabric.setup(model) # 6 iters, 3 grad accumulation iters for i, enabled in enumerate((True, True, False, True, True, False), 1): x = torch.tensor([i * (fabric.local_rank + 1)], device=fabric.device, dtype=torch.float32) - with fabric.no_backward_sync(model, enabled): + with fabric.no_backward_sync(cmodel, enabled): y = model(x) fabric.backward(y.sum()) if not enabled: