From 7d96d1c56f9317d8da298a2a676aec509d51c165 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20Fern=C3=A1ndez=20Gra=C3=B1a?= <51716758+inafergra@users.noreply.github.com> Date: Thu, 18 Jul 2024 12:49:32 +0200 Subject: [PATCH] [Bug] Fix printing format error in train_grad (#504) --- qadence/ml_tools/train_grad.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/qadence/ml_tools/train_grad.py b/qadence/ml_tools/train_grad.py index 86c9ed8e..6ddfb903 100644 --- a/qadence/ml_tools/train_grad.py +++ b/qadence/ml_tools/train_grad.py @@ -246,6 +246,8 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d xs = next(dl_iter) if dataloader is not None else None # type: ignore[arg-type] xs_to_device = data_to_device(xs, device=device, dtype=data_dtype) loss, metrics, *_ = loss_fn(model, xs_to_device) + if dataloader is None: + loss = loss.item() if iteration % config.print_every == 0 and config.verbose: print_metrics(loss, metrics, iteration)