diff --git a/qadence/ml_tools/config.py b/qadence/ml_tools/config.py index 36877c91..89053cad 100644 --- a/qadence/ml_tools/config.py +++ b/qadence/ml_tools/config.py @@ -2,15 +2,12 @@ import datetime import os -from dataclasses import dataclass +from dataclasses import dataclass, field from logging import getLogger from pathlib import Path from typing import Callable, Optional from uuid import uuid4 -from matplotlib.figure import Figure -from torch.nn import Module - from qadence.types import ExperimentTrackingTool logger = getLogger(__name__) @@ -38,7 +35,7 @@ class TrainConfig: """Write tensorboard logs.""" checkpoint_every: int = 5000 """Write model/optimizer checkpoint.""" - plot_every: int | None = None + plot_every: int = 5000 """Write figures. NOTE: currently only works with mlflow. @@ -64,7 +61,7 @@ class TrainConfig: """The tracking tool of choice.""" hyperparams: dict | None = None """Hyperparameters to track.""" - plotting_functions: tuple[Callable[[Module, int], tuple[str, Figure]]] | None = None + plotting_functions: tuple[Callable] = field(default_factory=tuple) # type: ignore """Functions for in-train plotting.""" # mlflow_callbacks: list[Callable] = [write_mlflow_figure(), write_x()] @@ -82,10 +79,10 @@ def __post_init__(self) -> None: self.trainstop_criterion = lambda x: x <= self.max_iter if self.validation_criterion is None: self.validation_criterion = lambda x: False - if self.plot_every and self.tracking_tool != ExperimentTrackingTool.MLFLOW: - raise NotImplementedError("In-training plots are only available with mlflow tracking.") - if self.plot_every is not None and self.plotting_functions is None: - logger.warning("Plots tracking is required, but no plotting functions are provided.") + if self.plotting_functions and self.tracking_tool != ExperimentTrackingTool.MLFLOW: + logger.warning("In-training plots are only available with mlflow tracking.") + if not self.plotting_functions and self.tracking_tool == ExperimentTrackingTool.MLFLOW: + logger.warning("Tracking with mlflow, but no plotting functions provided.") @dataclass diff --git a/qadence/ml_tools/printing.py b/qadence/ml_tools/printing.py index 09ee39d8..ed833307 100644 --- a/qadence/ml_tools/printing.py +++ b/qadence/ml_tools/printing.py @@ -2,7 +2,6 @@ from typing import Any, Callable -from matplotlib.figure import Figure from torch.nn import Module from torch.utils.tensorboard import SummaryWriter @@ -30,9 +29,13 @@ def log_hyperparams_tensorboard(writer: SummaryWriter, hyperparams: dict, metric def plot_tensorboard( - writer: SummaryWriter, iteration: int, plotting_functions: tuple[Callable] + writer: SummaryWriter, + model: Module, + iteration: int, + plotting_functions: tuple[Callable], ) -> None: - raise NotImplementedError("Plot logging with tensorboard is not implemented") + # TODO: implement me + pass def write_mlflow(writer: Any, loss: float | None, metrics: dict, iteration: int) -> None: @@ -45,10 +48,10 @@ def log_hyperparams_mlflow(writer: Any, hyperparams: dict, metrics: dict) -> Non def plot_mlflow( - writer: SummaryWriter, + writer: Any, model: Module, iteration: int, - plotting_functions: tuple[Callable[[Module, int], tuple[str, Figure]]], + plotting_functions: tuple[Callable], ) -> None: for pf in plotting_functions: descr, fig = pf(model, iteration)