diff --git a/qadence/ml_tools/printing.py b/qadence/ml_tools/printing.py index f9fcabad..7d301d04 100644 --- a/qadence/ml_tools/printing.py +++ b/qadence/ml_tools/printing.py @@ -16,7 +16,7 @@ logger = getLogger(__name__) PlottingFunction = Callable[[Module, int], tuple[str, Figure]] -InputData = Union[Tensor, dict[str, Tensor]] +InputData = Tensor | dict[str, Tensor] def print_metrics(loss: float | None, metrics: dict, iteration: int) -> None: