Skip to content

Commit

Permalink
Run all current checkpointing tests with the MLflow tracking tool too
Browse files Browse the repository at this point in the history
  • Loading branch information
debrevitatevitae committed Jul 10, 2024
1 parent 8d630cc commit eaab95a
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions tests/ml_tools/test_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,12 @@ def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict
assert torch.allclose(loaded_model.expectation({}), model.expectation({}))


def test_check_ckpts_exist(BasicQuantumModel: QuantumModel, tmp_path: Path) -> None:
@pytest.mark.parametrize(
"tool", [ExperimentTrackingTool.TENSORBOARD, ExperimentTrackingTool.MLFLOW]
)
def test_check_ckpts_exist(
tool: ExperimentTrackingTool, BasicQuantumModel: QuantumModel, tmp_path: Path
) -> None:
data = dataloader()
model = BasicQuantumModel
cnt = count()
Expand All @@ -92,7 +97,9 @@ def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict
loss = criterion(out, torch.rand(1))
return loss, {}

config = TrainConfig(folder=tmp_path, max_iter=10, checkpoint_every=1, write_every=1)
config = TrainConfig(
folder=tmp_path, max_iter=10, checkpoint_every=1, write_every=1, tracking_tool=tool
)
train_with_grad(model, data, optimizer, config, loss_fn=loss_fn)
ckpts = [tmp_path / Path(f"model_QuantumModel_ckpt_00{i}_device_cpu.pt") for i in range(1, 9)]
assert all(os.path.isfile(ckpt) for ckpt in ckpts)
Expand Down Expand Up @@ -189,8 +196,11 @@ def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict
assert torch.allclose(loaded_model.expectation(inputs), model.expectation(inputs))


@pytest.mark.parametrize(
"tool", [ExperimentTrackingTool.TENSORBOARD, ExperimentTrackingTool.MLFLOW]
)
def test_check_transformedmodule_ckpts_exist(
BasicTransformedModule: TransformedModule, tmp_path: Path
tool: ExperimentTrackingTool, BasicTransformedModule: TransformedModule, tmp_path: Path
) -> None:
data = dataloader()
model = BasicTransformedModule
Expand All @@ -205,7 +215,9 @@ def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict
loss = criterion(out, torch.rand(1))
return loss, {}

config = TrainConfig(folder=tmp_path, max_iter=10, checkpoint_every=1, write_every=1)
config = TrainConfig(
folder=tmp_path, max_iter=10, checkpoint_every=1, write_every=1, tracking_tool=tool
)
train_with_grad(model, data, optimizer, config, loss_fn=loss_fn)
ckpts = [
tmp_path / Path(f"model_TransformedModule_ckpt_00{i}_device_cpu.pt") for i in range(1, 9)
Expand Down

0 comments on commit eaab95a

Please sign in to comment.