Commit

Merge branch 'main' into ifg/fix_ckpt_match
inafergra authored Jul 1, 2024
2 parents fd3d013 + 70e12d4 commit ffaa662
Showing 4 changed files with 52 additions and 14 deletions.
7 changes: 5 additions & 2 deletions qadence/ml_tools/printing.py

@@ -11,8 +11,11 @@ def print_metrics(loss: float | None, metrics: dict, iteration: int) -> None:
     print(msg)


-def write_tensorboard(writer: SummaryWriter, loss: float, metrics: dict, iteration: int) -> None:
-    writer.add_scalar("loss", loss, iteration)
+def write_tensorboard(
+    writer: SummaryWriter, loss: float = None, metrics: dict = {}, iteration: int = 0
+) -> None:
+    if loss is not None:
+        writer.add_scalar("loss", loss, iteration)
     for key, arg in metrics.items():
         writer.add_scalar(key, arg, iteration)

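The helper's new defaults make the loss optional, so validation-only metrics can be logged without a (possibly meaningless) loss scalar. A minimal usage sketch, assuming write_tensorboard is importable from the module path above; the writer, log directory and scalar values are illustrative:

    from torch.utils.tensorboard import SummaryWriter

    from qadence.ml_tools.printing import write_tensorboard

    writer = SummaryWriter(log_dir="/tmp/example_run")  # hypothetical log directory

    # Regular training step: a loss value is available, so both the "loss"
    # scalar and the metrics are written.
    write_tensorboard(writer, loss=0.73, metrics={"mse": 0.73}, iteration=10)

    # Validation-only call: passing None skips the "loss" scalar and logs just
    # the metrics, instead of writing math.nan as the old code path did.
    write_tensorboard(writer, None, {"val_loss": 0.42}, iteration=10)

    writer.close()
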
47 changes: 41 additions & 6 deletions qadence/ml_tools/train_grad.py

@@ -153,12 +153,32 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d
     data_dtype = float64 if dtype == complex128 else float32

     best_val_loss = math.inf

     with progress:
         dl_iter = iter(dataloader) if dataloader is not None else None
-        if perform_val:
-            dl_iter_val = iter(val_dataloader) if val_dataloader is not None else None
+
+        # Initial validation evaluation
+        try:
+            if perform_val:
+                dl_iter_val = iter(val_dataloader) if val_dataloader is not None else None
+                xs = next(dl_iter_val)
+                xs_to_device = data_to_device(xs, device=device, dtype=data_dtype)
+                best_val_loss, metrics = loss_fn(model, xs_to_device)
+
+                metrics["val_loss"] = best_val_loss
+                write_tensorboard(writer, None, metrics, init_iter)
+
+                if config.folder:
+                    if config.checkpoint_best_only:
+                        write_checkpoint(config.folder, model, optimizer, iteration="best")
+                    else:
+                        write_checkpoint(config.folder, model, optimizer, init_iter)
+
+        except KeyboardInterrupt:
+            logger.info("Terminating training gracefully after the current iteration.")
+
         # outer epoch loop
+        init_iter += 1
         for iteration in progress.track(range(init_iter, init_iter + config.max_iter)):
             try:
                 # in case there is not data needed by the model
@@ -192,10 +212,13 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d
                 )

                 if iteration % config.print_every == 0 and config.verbose:
-                    print_metrics(loss, metrics, iteration)
+                    # Note that the loss returned by optimize_step
+                    # is the value before doing the training step
+                    # which is printed accordingly by the previous iteration number
+                    print_metrics(loss, metrics, iteration - 1)

                 if iteration % config.write_every == 0:
-                    write_tensorboard(writer, loss, metrics, iteration)
+                    write_tensorboard(writer, loss, metrics, iteration - 1)

                 if perform_val:
                     if iteration % config.val_every == 0:
@@ -207,7 +230,7 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d
                             if config.folder and config.checkpoint_best_only:
                                 write_checkpoint(config.folder, model, optimizer, iteration="best")
                         metrics["val_loss"] = val_loss
-                        write_tensorboard(writer, math.nan, metrics, iteration)
+                        write_tensorboard(writer, None, metrics, iteration)

                 if config.folder:
                     if iteration % config.checkpoint_every == 0 and not config.checkpoint_best_only:
@@ -217,7 +240,19 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d
                 logger.info("Terminating training gracefully after the current iteration.")
                 break

-    # Final writing and checkpointing
+    # Handling printing the last training loss
+    # as optimize_step does not give the loss value at the last iteration
+    try:
+        xs = next(dl_iter) if dataloader is not None else None  # type: ignore[arg-type]
+        xs_to_device = data_to_device(xs, device=device, dtype=data_dtype)
+        loss, metrics = loss_fn(model, xs_to_device)
+        if iteration % config.print_every == 0 and config.verbose:
+            print_metrics(loss, metrics, iteration)
+
+    except KeyboardInterrupt:
+        logger.info("Terminating training gracefully after the current iteration.")
+
+    # Final printing, writing and checkpointing
     if config.folder and not config.checkpoint_best_only:
         write_checkpoint(config.folder, model, optimizer, iteration)
     write_tensorboard(writer, loss, metrics, iteration)
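Taken together, the train_grad.py changes add one validation evaluation before the loop, shift the printed and written iteration index by one, and add a final loss evaluation after the loop. A simplified sketch of the resulting number of loss_fn calls, consistent with the updated test expectations below (this models the call pattern only, not the actual training code):

    import itertools

    def simulated_loss_calls(max_iter: int, perform_val: bool, val_every: int = 1) -> int:
        # Count how many times a counting loss_fn would be invoked during training.
        cnt = itertools.count()
        if perform_val:
            next(cnt)  # initial validation evaluation before the loop
        for iteration in range(1, max_iter + 1):
            next(cnt)  # training loss computed inside optimize_step
            if perform_val and iteration % val_every == 0:
                next(cnt)  # periodic validation evaluation
        next(cnt)  # final loss evaluation after the loop
        return next(cnt)

    assert simulated_loss_calls(100, perform_val=False) == 100 + 1
    assert simulated_loss_calls(100, perform_val=True, val_every=10) == 2 + 100 + 100 // 10
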
10 changes: 5 additions & 5 deletions tests/ml_tools/test_train.py

@@ -81,7 +81,7 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d
     n_epochs = 100
     config = TrainConfig(folder=tmp_path, max_iter=n_epochs, checkpoint_every=100, write_every=100)
     train_with_grad(model, data, optimizer, config, loss_fn=loss_fn)
-    assert next(cnt) == n_epochs
+    assert next(cnt) == (n_epochs + 1)

     x = torch.rand(5, 1)
     assert torch.allclose(torch.sin(x), model(x), rtol=1e-1, atol=1e-1)

@@ -110,7 +110,7 @@ def loss_fn(model: torch.nn.Module, xs: Any = None) -> tuple[torch.Tensor, dict]
         write_every=100,
     )
     train_with_grad(model, data, optimizer, config, loss_fn=loss_fn)
-    assert next(cnt) == n_epochs
+    assert next(cnt) == (n_epochs + 1)

     out = model()
     assert torch.allclose(out, torch.zeros(1), atol=1e-2, rtol=1e-2)

@@ -139,7 +139,7 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d
         folder=tmp_path, max_iter=n_epochs, print_every=10, checkpoint_every=100, write_every=100
     )
     train_with_grad(model, data, optimizer, config, loss_fn=loss_fn)
-    assert next(cnt) == n_epochs
+    assert next(cnt) == (n_epochs + 1)

     x = torch.rand(5, 1)
     assert torch.allclose(torch.sin(x), model(x), rtol=1e-1, atol=1e-1)

@@ -173,7 +173,7 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d
     )
     data = to_dataloader(x, y, batch_size=batch_size, infinite=True)
     model, _ = train_with_grad(model, data, optimizer, config, loss_fn=loss_fn, dtype=dtype)
-    assert next(cnt) == n_epochs
+    assert next(cnt) == (n_epochs + 1)

     x = torch.rand(5, 1, dtype=torch.float32)
     assert torch.allclose(torch.sin(x), model(x), rtol=1e-1, atol=1e-1)

@@ -279,7 +279,7 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d

     config = get_train_config_validation(tmp_path, n_epochs, checkpoint_every, val_every)
     train_with_grad(model, data, optimizer, config, loss_fn=loss_fn)
-    assert next(cnt) == n_epochs + n_epochs // val_every
+    assert next(cnt) == 2 + n_epochs + n_epochs // val_every

     files = [f for f in os.listdir(tmp_path) if f.endswith(".pt") and "model" in f]
     # Ideally it can be ensured if the (only) saved checkpoint is indeed the best,
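For context, the counter these asserts read is assumed to work roughly as in the sketch below (reconstructed from the assertions, not copied from the test file): loss_fn ticks an itertools counter on every call, so next(cnt) after training equals the total number of loss evaluations.

    from itertools import count

    import torch

    cnt = count()
    criterion = torch.nn.MSELoss()

    def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, dict]:
        next(cnt)  # one tick per loss evaluation
        x, y = data[0], data[1]
        out = model(x)
        return criterion(out, y), {}
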
2 changes: 1 addition & 1 deletion tests/qadence/test_error_models.py

@@ -92,7 +92,7 @@ def test_bitstring_corruption_mixed_bitflips(
     corrupted_counters = [bs_corruption(err_idx=err_idx, sample=sample)]
     for noiseless, noisy in zip(counters, corrupted_counters):
         assert sum(noisy.values()) == n_shots
-        assert js_divergence(noiseless, noisy) > 0.0
+        assert js_divergence(noiseless, noisy) >= 0.0


 @pytest.mark.flaky(max_runs=5)
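The relaxed assertion covers the case where the corrupted sample happens to coincide with the noiseless one. A small standalone sketch (not the qadence implementation) of why the Jensen-Shannon divergence of two count distributions is always >= 0 and exactly 0 when they are identical:

    import math
    from collections import Counter

    def js_divergence_sketch(p_counts: Counter, q_counts: Counter) -> float:
        # Normalize the counts into probability distributions over the union of keys.
        keys = set(p_counts) | set(q_counts)
        p_total, q_total = sum(p_counts.values()), sum(q_counts.values())
        p = {k: p_counts.get(k, 0) / p_total for k in keys}
        q = {k: q_counts.get(k, 0) / q_total for k in keys}
        m = {k: 0.5 * (p[k] + q[k]) for k in keys}

        def kl(a: dict, b: dict) -> float:
            return sum(a[k] * math.log(a[k] / b[k]) for k in keys if a[k] > 0)

        return 0.5 * kl(p, m) + 0.5 * kl(q, m)

    identical = Counter({"00": 50, "11": 50})
    assert js_divergence_sketch(identical, identical) == 0.0  # no effective corruption
    assert js_divergence_sketch(identical, Counter({"00": 60, "11": 40})) > 0.0
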
