Skip to content

Commit

Permalink
Debug FSDP test
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed May 6, 2024
1 parent f334378 commit 3a04c3c
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tests/test_thunder_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,6 @@ def set_up_planner(self, state_dict, metadata, is_coordinator):

@RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
def test_save_load_sharded_checkpoint(tmp_path):
pytest.skip("Temporarily disabled, often exceeds 5 min timeout")

strategy = ThunderFSDPStrategy(state_dict_type="sharded", broadcast_from=0)
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
Expand All @@ -274,8 +272,10 @@ def test_save_load_sharded_checkpoint(tmp_path):

# save a sharded model
model = fabric.setup(model)
print("Past fabric.setup()")
state = {"model": model, "stateful": StatefulThing(), "primitive": 123}
fabric.save(tmp_path, state)
print("Past fabric.save()")

# assert the file contents
if fabric.global_rank == 0:
Expand All @@ -295,24 +295,29 @@ def test_save_load_sharded_checkpoint(tmp_path):
}
}
torch.testing.assert_close(checkpoint["model"], expected)
print("Past rank0 save checks()")

# load its weights into a different sharded model
model = MyModel(4)
model = fabric.setup(model)
print("Past fabric.setup() 2")
state = {"model": model, "stateful": StatefulThing(), "primitive": 321}
fabric.load(tmp_path, state)
print("Past fabric.load()")

from thunder.distributed import _unshard_params

# unshard this model's parameters to compare with the original state dict before sharding
_unshard_params(model, model.process_group_for_ddp, True)
print("Past unshard_params")
# we loaded rank 0's weights, so this would fail in the other ranks
if fabric.global_rank == 0:
actual = model.state_dict()
# `_unshard_params` doesn't offload buffers at the moment
assert actual["buf"].device.type == "cuda"
actual["buf"] = actual["buf"].to(device="cpu")
torch.testing.assert_close(actual, expected)
print("Past rank0 shard checks()")
assert state["primitive"] == 123


Expand Down

0 comments on commit 3a04c3c

Please sign in to comment.