diff --git a/tests/test_thunder_fsdp.py b/tests/test_thunder_fsdp.py index 321cdac7a6..94d874520d 100644 --- a/tests/test_thunder_fsdp.py +++ b/tests/test_thunder_fsdp.py @@ -263,8 +263,6 @@ def set_up_planner(self, state_dict, metadata, is_coordinator): @RunIf(min_cuda_gpus=2, thunder=True, standalone=True) def test_save_load_sharded_checkpoint(tmp_path): - pytest.skip("Temporarily disabled, often exceeds 5 min timeout") - strategy = ThunderFSDPStrategy(state_dict_type="sharded", broadcast_from=0) fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy) fabric.launch() @@ -274,8 +272,10 @@ def test_save_load_sharded_checkpoint(tmp_path): # save a sharded model model = fabric.setup(model) + print("Past fabric.setup()") state = {"model": model, "stateful": StatefulThing(), "primitive": 123} fabric.save(tmp_path, state) + print("Past fabric.save()") # assert the file contents if fabric.global_rank == 0: @@ -295,17 +295,21 @@ def test_save_load_sharded_checkpoint(tmp_path): } } torch.testing.assert_close(checkpoint["model"], expected) + print("Past rank0 save checks()") # load its weights into a different sharded model model = MyModel(4) model = fabric.setup(model) + print("Past fabric.setup() 2") state = {"model": model, "stateful": StatefulThing(), "primitive": 321} fabric.load(tmp_path, state) + print("Past fabric.load()") from thunder.distributed import _unshard_params # unshard this model's parameters to compare with the original state dict before sharding _unshard_params(model, model.process_group_for_ddp, True) + print("Past unshard_params") # we loaded rank 0's weights, so this would fail in the other ranks if fabric.global_rank == 0: actual = model.state_dict() @@ -313,6 +317,7 @@ def test_save_load_sharded_checkpoint(tmp_path): assert actual["buf"].device.type == "cuda" actual["buf"] = actual["buf"].to(device="cpu") torch.testing.assert_close(actual, expected) + print("Past rank0 shard checks()") assert state["primitive"] == 123