Skip to content

Commit

Permalink
convert loss to float
Browse files Browse the repository at this point in the history
  • Loading branch information
inafergra committed Jul 18, 2024
1 parent 01d67dd commit cabef29
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions qadence/ml_tools/train_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d
xs = next(dl_iter) if dataloader is not None else None # type: ignore[arg-type]
xs_to_device = data_to_device(xs, device=device, dtype=data_dtype)
loss, metrics, *_ = loss_fn(model, xs_to_device)
if dataloader is None:
loss = loss.item()
if iteration % config.print_every == 0 and config.verbose:
print_metrics(loss, metrics, iteration)

Expand Down

0 comments on commit cabef29

Please sign in to comment.