Commit 396a7b0: tests

awaelchli committed Apr 5, 2024 · 1 parent 2c4d7b5

Showing 1 changed file with 6 additions and 2 deletions.
tests/test_pretrain.py (8 changes: 6 additions & 2 deletions)
@@ -47,7 +47,7 @@ def test_pretrain(_, tmp_path):
     if torch.distributed.get_rank() == 0:
         # tmp_path is not the same across all ranks, run assert only on rank 0
         out_dir_contents = set(os.listdir(out_dir))
-        checkpoint_dirs = {"step-00000001", "step-00000002", "step-00000003", "step-00000004"}
+        checkpoint_dirs = {"step-00000001", "step-00000002", "step-00000003", "step-00000004", "final"}
         assert checkpoint_dirs.issubset(out_dir_contents)
         assert all((out_dir / p).is_dir() for p in checkpoint_dirs)
         for checkpoint_dir in checkpoint_dirs:
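
The assertion above runs only on rank 0 because tmp_path is rank-local. A minimal sketch of that pattern, using a hypothetical `check_checkpoints` helper that is not part of the commit and assuming the process group may or may not be initialized:

import os
from pathlib import Path

import torch

def check_checkpoints(out_dir: Path) -> None:
    # Hypothetical helper illustrating the assertion pattern from the hunk above.
    # The new "final" directory written at the end of pretraining sits next to the step checkpoints.
    expected = {"step-00000001", "step-00000002", "step-00000003", "step-00000004", "final"}
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        contents = set(os.listdir(out_dir))
        assert expected.issubset(contents)  # extra files in out_dir (logs, configs) are tolerated
        assert all((out_dir / name).is_dir() for name in expected)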
@@ -69,7 +69,11 @@ def test_pretrain(_, tmp_path):
 # Set CUDA_VISIBLE_DEVICES for FSDP hybrid-shard, if fewer GPUs are used than are available
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
 @mock.patch("litgpt.pretrain.L.Fabric.load_raw")
-def test_initial_checkpoint_dir(load_mock, tmp_path):
+# If we were to use `save_hyperparameters()`, we would have to patch `sys.argv`, otherwise the
+# CLI would capture the pytest args. Unfortunately, patching `sys.argv` would mess with
+# subprocess launching, so we mock `save_hyperparameters()` instead.
+@mock.patch("litgpt.pretrain.save_hyperparameters")
+def test_initial_checkpoint_dir(_, load_mock, tmp_path):
     model_config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)

     dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]])
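For context on the changed signature: with stacked `mock.patch` decorators, the decorator closest to the function is applied first and supplies the first injected argument, which is why the new `save_hyperparameters` patch adds a leading `_` while `load_mock` still receives the `L.Fabric.load_raw` mock. A short illustrative sketch of that ordering (the test body is elided; not part of the commit):

from unittest import mock

@mock.patch("litgpt.pretrain.L.Fabric.load_raw")      # applied second -> second injected argument
@mock.patch("litgpt.pretrain.save_hyperparameters")   # closest to the function -> first injected argument
def test_initial_checkpoint_dir(_, load_mock, tmp_path):
    # `_` receives the save_hyperparameters mock (only needed to silence it); `load_mock`
    # receives the load_raw mock; pytest fixtures such as tmp_path come after the injected mocks.
    ...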
