Skip to content

Commit

Permalink
Fix current ml tracking tests
Browse files Browse the repository at this point in the history
  • Loading branch information
debrevitatevitae committed Jul 10, 2024
1 parent 72f1098 commit 936021e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
17 changes: 7 additions & 10 deletions qadence/ml_tools/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@

import datetime
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from logging import getLogger
from pathlib import Path
from typing import Callable, Optional
from uuid import uuid4

from matplotlib.figure import Figure
from torch.nn import Module

from qadence.types import ExperimentTrackingTool

logger = getLogger(__name__)
Expand Down Expand Up @@ -38,7 +35,7 @@ class TrainConfig:
"""Write tensorboard logs."""
checkpoint_every: int = 5000
"""Write model/optimizer checkpoint."""
plot_every: int | None = None
plot_every: int = 5000
"""Write figures.
NOTE: currently only works with mlflow.
Expand All @@ -64,7 +61,7 @@ class TrainConfig:
"""The tracking tool of choice."""
hyperparams: dict | None = None
"""Hyperparameters to track."""
plotting_functions: tuple[Callable[[Module, int], tuple[str, Figure]]] | None = None
plotting_functions: tuple[Callable] = field(default_factory=tuple) # type: ignore
"""Functions for in-train plotting."""

# mlflow_callbacks: list[Callable] = [write_mlflow_figure(), write_x()]
Expand All @@ -82,10 +79,10 @@ def __post_init__(self) -> None:
self.trainstop_criterion = lambda x: x <= self.max_iter
if self.validation_criterion is None:
self.validation_criterion = lambda x: False
if self.plot_every and self.tracking_tool != ExperimentTrackingTool.MLFLOW:
raise NotImplementedError("In-training plots are only available with mlflow tracking.")
if self.plot_every is not None and self.plotting_functions is None:
logger.warning("Plots tracking is required, but no plotting functions are provided.")
if self.plotting_functions and self.tracking_tool != ExperimentTrackingTool.MLFLOW:
logger.warning("In-training plots are only available with mlflow tracking.")
if not self.plotting_functions and self.tracking_tool == ExperimentTrackingTool.MLFLOW:
logger.warning("Tracking with mlflow, but no plotting functions provided.")


@dataclass
Expand Down
13 changes: 8 additions & 5 deletions qadence/ml_tools/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from typing import Any, Callable

from matplotlib.figure import Figure
from torch.nn import Module
from torch.utils.tensorboard import SummaryWriter

Expand Down Expand Up @@ -30,9 +29,13 @@ def log_hyperparams_tensorboard(writer: SummaryWriter, hyperparams: dict, metric


def plot_tensorboard(
writer: SummaryWriter, iteration: int, plotting_functions: tuple[Callable]
writer: SummaryWriter,
model: Module,
iteration: int,
plotting_functions: tuple[Callable],
) -> None:
raise NotImplementedError("Plot logging with tensorboard is not implemented")
# TODO: implement me
pass


def write_mlflow(writer: Any, loss: float | None, metrics: dict, iteration: int) -> None:
Expand All @@ -45,10 +48,10 @@ def log_hyperparams_mlflow(writer: Any, hyperparams: dict, metrics: dict) -> Non


def plot_mlflow(
writer: SummaryWriter,
writer: Any,
model: Module,
iteration: int,
plotting_functions: tuple[Callable[[Module, int], tuple[str, Figure]]],
plotting_functions: tuple[Callable],
) -> None:
for pf in plotting_functions:
descr, fig = pf(model, iteration)
Expand Down

0 comments on commit 936021e

Please sign in to comment.