[Feature] Allow loss_fn in train_grad to return arbitrary number of values. (#499)
smitchaudhary authored Jul 16, 2024
1 parent 1d7e859 commit 01d67dd
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/qml/ml_tools.md
@@ -59,7 +59,7 @@ use `train_with_grad` as example but the code can be used directly with the grad
As every other training routine commonly used in Machine Learning, it requires
`model`, `data` and an `optimizer` as input arguments.
However, in addition, it requires a `loss_fn` and a `TrainConfig`.
- A `loss_fn` is required to be a function which expects both a model and data and returns a tuple of (loss, metrics: `<dict>`), where `metrics` is a dict of scalars which can be customized too.
+ A `loss_fn` is required to be a function which expects both a model and data and returns a tuple of (loss, metrics: `<dict>`, ...), where `metrics` is a dict of scalars which can be customized too. It can optionally also return additional values which are utilised by the corresponding user-provided `optimize_step` function inside `train_with_grad`.

```python exec="on" source="material-block"
import torch
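For illustration, a minimal sketch of a `loss_fn` that uses the new contract, returning one extra value after `(loss, metrics)`. The toy model, data shapes and the `aux` value are assumptions for the example only, not part of the commit:

```python
import torch

# Minimal sketch of a loss_fn honouring the new contract: anything returned
# after (loss, metrics) is ignored by train_with_grad's own unpacking and only
# matters to a matching custom optimize_step. Model, data and `aux` are illustrative.
def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple:
    out = model(data)
    loss = out.pow(2).mean()            # toy objective
    metrics = {"mse": loss.item()}      # dict of scalar metrics
    aux = out.detach()                  # extra value beyond (loss, metrics)
    return loss, metrics, aux

loss, metrics, *rest = loss_fn(torch.nn.Linear(4, 1), torch.rand(10, 4))
```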
6 changes: 3 additions & 3 deletions qadence/ml_tools/train_grad.py
@@ -48,7 +48,7 @@ def train(
the model
optimizer: The optimizer to use.
config: `TrainConfig` with additional training options.
- loss_fn: Loss function returning (loss: float, metrics: dict[str, float])
+ loss_fn: Loss function returning (loss: float, metrics: dict[str, float], ...)
device: String defining device to train on, pass 'cuda' for GPU.
optimize_step: Customizable optimization callback which is called at every iteration.
The function must have the signature `optimize_step(model,
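For illustration, a hypothetical custom `optimize_step` that consumes the extra values returned by such a `loss_fn`. The parameter list here is an assumption made for the sketch; the library's actual default signature is truncated in the hunk above:

```python
# Hypothetical custom optimization callback that consumes the extra values
# returned by loss_fn; the parameter list is an assumption for illustration.
def my_optimize_step(model, optimizer, loss_fn, xs):
    optimizer.zero_grad()
    loss, metrics, *extra = loss_fn(model, xs)   # tolerates two or more return values
    loss.backward()
    optimizer.step()
    if extra:
        metrics = {**metrics, "n_extra_outputs": len(extra)}  # e.g. surface them for logging
    return loss, metrics
```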
@@ -224,7 +224,7 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d
if iteration % config.val_every == 0:
xs = next(dl_iter_val)
xs_to_device = data_to_device(xs, device=device, dtype=data_dtype)
- val_loss, _ = loss_fn(model, xs_to_device)
+ val_loss, *_ = loss_fn(model, xs_to_device)
if config.validation_criterion(val_loss, best_val_loss, config.val_epsilon): # type: ignore[misc]
best_val_loss = val_loss
if config.folder and config.checkpoint_best_only:
@@ -245,7 +245,7 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d
try:
xs = next(dl_iter) if dataloader is not None else None # type: ignore[arg-type]
xs_to_device = data_to_device(xs, device=device, dtype=data_dtype)
- loss, metrics = loss_fn(model, xs_to_device)
+ loss, metrics, *_ = loss_fn(model, xs_to_device)
if iteration % config.print_every == 0 and config.verbose:
print_metrics(loss, metrics, iteration)

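For context, a standalone sketch of the unpacking behaviour these hunks rely on: `loss, metrics, *_ = ...` accepts a `loss_fn` returning two or more values, whereas the old two-name unpacking raised on anything beyond two:

```python
# Quick illustration of why the star-unpack keeps the loop backwards compatible:
# a two-value loss_fn and a three-value loss_fn both satisfy `loss, metrics, *_ = ...`.
def two_values():
    return 0.5, {"mse": 0.5}

def three_values():
    return 0.5, {"mse": 0.5}, "extra"

for fn in (two_values, three_values):
    loss, metrics, *_ = fn()      # never raises, extras land in `_`
    print(loss, metrics)

# The old form `loss, metrics = three_values()` would raise
# "ValueError: too many values to unpack (expected 2)".
```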