Skip to content

Commit

Permalink
test get_latest_checkpoint_name() for legacy ckpts
Browse files Browse the repository at this point in the history
  • Loading branch information
inafergra committed Jul 1, 2024
1 parent 6635131 commit fd3d013
Showing 1 changed file with 4 additions and 18 deletions.
22 changes: 4 additions & 18 deletions tests/ml_tools/test_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,38 +180,24 @@ def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict
assert torch.allclose(loaded_model.expectation(inputs), model.expectation(inputs))


def test_random_basicqQM_save_load_legacy_ckpts(
BasicQuantumModel: QuantumModel, tmp_path: Path
) -> None:
def test_basicqQM_save_load_legacy_ckpts(BasicQuantumModel: QuantumModel, tmp_path: Path) -> None:
model = BasicQuantumModel
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
ps0 = get_parameters(model)
write_legacy_checkpoint(tmp_path, model, optimizer, 1)
loaded_model, optimizer, _ = load_checkpoint(
tmp_path,
model,
optimizer,
"model_QuantumModel_ckpt_001.pt",
"opt_Adam_ckpt_001.pt",
)
loaded_model, optimizer, _ = load_checkpoint(tmp_path, model, optimizer)
ps1 = get_parameters(loaded_model)
assert not torch.all(torch.isnan(loaded_model.expectation({})))
assert torch.allclose(ps0, ps1)


def test_random_basicqQNN_save_load_legacy_ckpts(BasicQNN: QNN, tmp_path: Path) -> None:
def test_basicqQNN_save_load_legacy_ckpts(BasicQNN: QNN, tmp_path: Path) -> None:
model = BasicQNN
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
inputs = rand_featureparameters(model, 1)
ps0 = get_parameters(model)
write_legacy_checkpoint(tmp_path, model, optimizer, 1)
loaded_model, optimizer, _ = load_checkpoint(
tmp_path,
model,
optimizer,
"model_QNN_ckpt_001.pt",
"opt_Adam_ckpt_001.pt",
)
loaded_model, optimizer, _ = load_checkpoint(tmp_path, model, optimizer)
ps1 = get_parameters(loaded_model)
assert not torch.all(torch.isnan(model.expectation(inputs)))
assert torch.allclose(ps0, ps1)

0 comments on commit fd3d013

Please sign in to comment.