From 360d7194793bb83e5d983934c8889cfb8cf1dbc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 23 May 2024 15:04:29 +0000 Subject: [PATCH] use thunder module reference for no bwd sync --- tests/test_thunder_ddp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_thunder_ddp.py b/tests/test_thunder_ddp.py index 957602e8e7..75227cb301 100644 --- a/tests/test_thunder_ddp.py +++ b/tests/test_thunder_ddp.py @@ -42,14 +42,14 @@ def test_no_backward_sync(choice): if "thunder" in choice: import thunder - model = thunder.jit(model) + cmodel = model = thunder.jit(model) model = fabric.setup(model) # 6 iters, 3 grad accumulation iters for i, enabled in enumerate((True, True, False, True, True, False), 1): x = torch.tensor([i * (fabric.local_rank + 1)], device=fabric.device, dtype=torch.float32) - with fabric.no_backward_sync(model, enabled): + with fabric.no_backward_sync(cmodel, enabled): y = model(x) fabric.backward(y.sum()) if not enabled: