diff --git a/qadence/ml_tools/printing.py b/qadence/ml_tools/printing.py index 9ae63677..2983312d 100644 --- a/qadence/ml_tools/printing.py +++ b/qadence/ml_tools/printing.py @@ -82,7 +82,7 @@ def plot_mlflow( def log_model_mlflow( writer: Any, model: Module, - dataloader: Union[None, DataLoader, DictDataLoader], + dataloader: DataLoader | DictDataLoader | None ) -> None: if dataloader is not None: xs: InputData