diff --git a/docs/tutorials/qml/ml_tools.md b/docs/tutorials/qml/ml_tools.md index 019224ae..5efb6cb7 100644 --- a/docs/tutorials/qml/ml_tools.md +++ b/docs/tutorials/qml/ml_tools.md @@ -59,7 +59,7 @@ use `train_with_grad` as example but the code can be used directly with the grad As every other training routine commonly used in Machine Learning, it requires `model`, `data` and an `optimizer` as input arguments. However, in addition, it requires a `loss_fn` and a `TrainConfig`. -A `loss_fn` is required to be a function which expects both a model and data and returns a tuple of (loss, metrics: ``), where `metrics` is a dict of scalars which can be customized too. +A `loss_fn` is required to be a function which expects both a model and data and returns a tuple of (loss, metrics: ``, ...), where `metrics` is a dict of scalars which can be customized too. It can optionally also return additional values which are utilised by the corresponding user-provided `optimize_step` function inside `train_with_grad`. ```python exec="on" source="material-block" import torch diff --git a/qadence/ml_tools/train_grad.py b/qadence/ml_tools/train_grad.py index 8cd32e8d..86c9ed8e 100644 --- a/qadence/ml_tools/train_grad.py +++ b/qadence/ml_tools/train_grad.py @@ -48,7 +48,7 @@ def train( the model optimizer: The optimizer to use. config: `TrainConfig` with additional training options. - loss_fn: Loss function returning (loss: float, metrics: dict[str, float]) + loss_fn: Loss function returning (loss: float, metrics: dict[str, float], ...) device: String defining device to train on, pass 'cuda' for GPU. optimize_step: Customizable optimization callback which is called at every iteration.= The function must have the signature `optimize_step(model, @@ -224,7 +224,7 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d if iteration % config.val_every == 0: xs = next(dl_iter_val) xs_to_device = data_to_device(xs, device=device, dtype=data_dtype) - val_loss, _ = loss_fn(model, xs_to_device) + val_loss, *_ = loss_fn(model, xs_to_device) if config.validation_criterion(val_loss, best_val_loss, config.val_epsilon): # type: ignore[misc] best_val_loss = val_loss if config.folder and config.checkpoint_best_only: @@ -245,7 +245,7 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d try: xs = next(dl_iter) if dataloader is not None else None # type: ignore[arg-type] xs_to_device = data_to_device(xs, device=device, dtype=data_dtype) - loss, metrics = loss_fn(model, xs_to_device) + loss, metrics, *_ = loss_fn(model, xs_to_device) if iteration % config.print_every == 0 and config.verbose: print_metrics(loss, metrics, iteration)