diff --git a/tests/ml_tools/test_checkpointing.py b/tests/ml_tools/test_checkpointing.py index 61e0a96a..5ba78560 100644 --- a/tests/ml_tools/test_checkpointing.py +++ b/tests/ml_tools/test_checkpointing.py @@ -79,7 +79,12 @@ def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict assert torch.allclose(loaded_model.expectation({}), model.expectation({})) -def test_check_ckpts_exist(BasicQuantumModel: QuantumModel, tmp_path: Path) -> None: +@pytest.mark.parametrize( + "tool", [ExperimentTrackingTool.TENSORBOARD, ExperimentTrackingTool.MLFLOW] +) +def test_check_ckpts_exist( + tool: ExperimentTrackingTool, BasicQuantumModel: QuantumModel, tmp_path: Path +) -> None: data = dataloader() model = BasicQuantumModel cnt = count() @@ -92,7 +97,9 @@ def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict loss = criterion(out, torch.rand(1)) return loss, {} - config = TrainConfig(folder=tmp_path, max_iter=10, checkpoint_every=1, write_every=1) + config = TrainConfig( + folder=tmp_path, max_iter=10, checkpoint_every=1, write_every=1, tracking_tool=tool + ) train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) ckpts = [tmp_path / Path(f"model_QuantumModel_ckpt_00{i}_device_cpu.pt") for i in range(1, 9)] assert all(os.path.isfile(ckpt) for ckpt in ckpts) @@ -189,8 +196,11 @@ def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict assert torch.allclose(loaded_model.expectation(inputs), model.expectation(inputs)) +@pytest.mark.parametrize( + "tool", [ExperimentTrackingTool.TENSORBOARD, ExperimentTrackingTool.MLFLOW] +) def test_check_transformedmodule_ckpts_exist( - BasicTransformedModule: TransformedModule, tmp_path: Path + tool: ExperimentTrackingTool, BasicTransformedModule: TransformedModule, tmp_path: Path ) -> None: data = dataloader() model = BasicTransformedModule @@ -205,7 +215,9 @@ def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict loss = criterion(out, torch.rand(1)) return loss, {} - config = TrainConfig(folder=tmp_path, max_iter=10, checkpoint_every=1, write_every=1) + config = TrainConfig( + folder=tmp_path, max_iter=10, checkpoint_every=1, write_every=1, tracking_tool=tool + ) train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) ckpts = [ tmp_path / Path(f"model_TransformedModule_ckpt_00{i}_device_cpu.pt") for i in range(1, 9)