From cabef298ef8da4dd781330a6840069697654f49e Mon Sep 17 00:00:00 2001 From: Ignacio Date: Thu, 18 Jul 2024 12:05:30 +0200 Subject: [PATCH] convert loss to float --- qadence/ml_tools/train_grad.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/qadence/ml_tools/train_grad.py b/qadence/ml_tools/train_grad.py index 86c9ed8e..6ddfb903 100644 --- a/qadence/ml_tools/train_grad.py +++ b/qadence/ml_tools/train_grad.py @@ -246,6 +246,8 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d xs = next(dl_iter) if dataloader is not None else None # type: ignore[arg-type] xs_to_device = data_to_device(xs, device=device, dtype=data_dtype) loss, metrics, *_ = loss_fn(model, xs_to_device) + if dataloader is None: + loss = loss.item() if iteration % config.print_every == 0 and config.verbose: print_metrics(loss, metrics, iteration)