Skip to content

Commit

Permalink
use thunder module reference for no bwd sync
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed May 23, 2024
1 parent bba376c commit 360d719
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/test_thunder_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ def test_no_backward_sync(choice):
if "thunder" in choice:
import thunder

model = thunder.jit(model)
cmodel = model = thunder.jit(model)
model = fabric.setup(model)

# 6 iters, 3 grad accumulation iters
for i, enabled in enumerate((True, True, False, True, True, False), 1):
x = torch.tensor([i * (fabric.local_rank + 1)], device=fabric.device, dtype=torch.float32)

with fabric.no_backward_sync(model, enabled):
with fabric.no_backward_sync(cmodel, enabled):
y = model(x)
fabric.backward(y.sum())
if not enabled:
Expand Down

0 comments on commit 360d719

Please sign in to comment.