diff --git a/tests/test_pretrain.py b/tests/test_pretrain.py index 9a3e11624c..61c67608af 100644 --- a/tests/test_pretrain.py +++ b/tests/test_pretrain.py @@ -69,9 +69,7 @@ def test_pretrain(_, tmp_path): # Set CUDA_VISIBLE_DEVICES for FSDP hybrid-shard, if fewer GPUs are used than are available @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) @mock.patch("litgpt.pretrain.L.Fabric.load_raw") -# If we were to use `save_hyperparameters()`, we would have to patch `sys.argv` or otherwise -# the CLI would capture pytest args, but unfortunately patching would mess with subprocess -# launching, so we need to mock `save_hyperparameters()` +# See the comment in `test_pretrain` for why we need to mock `save_hyperparameters()` @mock.patch("litgpt.pretrain.save_hyperparameters") def test_initial_checkpoint_dir(_, load_mock, tmp_path): model_config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)