diff --git a/tests/test_thunder_pretrain.py b/tests/test_thunder_pretrain.py index 30f9d71afb..2668c7e270 100644 --- a/tests/test_thunder_pretrain.py +++ b/tests/test_thunder_pretrain.py @@ -37,6 +37,7 @@ def test_pretrain(tmp_path, monkeypatch): out_dir=out_dir, train=TrainArgs(global_batch_size=2, max_tokens=16, save_interval=1, micro_batch_size=1, max_norm=1.0), eval=EvalArgs(interval=1, max_iters=1), + optimizer: Union[str, Dict] = "AdamW", ) out_dir_contents = set(os.listdir(out_dir))