From fd3d0135af9763c3368b622698bd9104b7e78647 Mon Sep 17 00:00:00 2001 From: Ignacio Date: Mon, 1 Jul 2024 18:48:04 +0200 Subject: [PATCH] test get_latest_checkpoint_name() for legacy ckpts --- tests/ml_tools/test_checkpointing.py | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/tests/ml_tools/test_checkpointing.py b/tests/ml_tools/test_checkpointing.py index 5dfe97306..e8f89a92e 100644 --- a/tests/ml_tools/test_checkpointing.py +++ b/tests/ml_tools/test_checkpointing.py @@ -180,38 +180,24 @@ def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict assert torch.allclose(loaded_model.expectation(inputs), model.expectation(inputs)) -def test_random_basicqQM_save_load_legacy_ckpts( - BasicQuantumModel: QuantumModel, tmp_path: Path -) -> None: +def test_basicqQM_save_load_legacy_ckpts(BasicQuantumModel: QuantumModel, tmp_path: Path) -> None: model = BasicQuantumModel optimizer = torch.optim.Adam(model.parameters(), lr=0.1) ps0 = get_parameters(model) write_legacy_checkpoint(tmp_path, model, optimizer, 1) - loaded_model, optimizer, _ = load_checkpoint( - tmp_path, - model, - optimizer, - "model_QuantumModel_ckpt_001.pt", - "opt_Adam_ckpt_001.pt", - ) + loaded_model, optimizer, _ = load_checkpoint(tmp_path, model, optimizer) ps1 = get_parameters(loaded_model) assert not torch.all(torch.isnan(loaded_model.expectation({}))) assert torch.allclose(ps0, ps1) -def test_random_basicqQNN_save_load_legacy_ckpts(BasicQNN: QNN, tmp_path: Path) -> None: +def test_basicqQNN_save_load_legacy_ckpts(BasicQNN: QNN, tmp_path: Path) -> None: model = BasicQNN optimizer = torch.optim.Adam(model.parameters(), lr=0.1) inputs = rand_featureparameters(model, 1) ps0 = get_parameters(model) write_legacy_checkpoint(tmp_path, model, optimizer, 1) - loaded_model, optimizer, _ = load_checkpoint( - tmp_path, - model, - optimizer, - "model_QNN_ckpt_001.pt", - "opt_Adam_ckpt_001.pt", - ) + loaded_model, optimizer, _ = load_checkpoint(tmp_path, model, optimizer) ps1 = get_parameters(loaded_model) assert not torch.all(torch.isnan(model.expectation(inputs))) assert torch.allclose(ps0, ps1)