From e61a9e79b9b34fb980fc015e02926510d06a6265 Mon Sep 17 00:00:00 2001 From: Kasper Hintz Date: Mon, 7 Oct 2024 10:44:51 +0000 Subject: [PATCH] support mlflow system metrics logging --- neural_lam/train_model.py | 38 +++++++------------------------------- 1 file changed, 7 insertions(+), 31 deletions(-) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index fd010f59..ccd74139 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -10,6 +10,8 @@ from lightning_fabric.utilities import seed from loguru import logger +import mlflow + # Local from . import utils from .config import load_config_and_datastore @@ -24,29 +26,22 @@ class CustomMLFlowLogger(pl.loggers.MLFlowLogger): + def __init__(self, experiment_name, tracking_uri): + super().__init__(experiment_name=experiment_name, tracking_uri=tracking_uri) + mlflow.start_run(run_id=self.run_id, log_system_metrics=True) + + def log_image(self, key, images): - import mlflow - import io from PIL import Image - # Retrieve the active run ID from the logger - run_id = self.run_id - # Ensure mlflow uses the same run - mlflow.start_run(run_id=run_id) - # Need to save the image to a temporary file, then log that file # mlflow.log_image, should do this automatically, but it doesn't work temporary_image = f"{key}.png" images[0].savefig(temporary_image) img = Image.open(temporary_image) - print(images) - print(images[0]) mlflow.log_image(img, f"{key}.png") - #mlflow.log_figure(images[0], key) - mlflow.end_run() - def _setup_training_logger(config, datastore, args, run_name): if config.training.logger == "wandb": @@ -69,33 +64,14 @@ def _setup_training_logger(config, datastore, args, run_name): experiment_name=args.wandb_project, tracking_uri=url, ) - print(logger) logger.log_hyperparams( dict(training=vars(args), datastore=datastore._config) ) print("Logged hyperparams") - print(run_name) - - print(logger.__str__) - # logger.log_image = log_image return logger -# def log_image(key, images): -# import mlflow - -# # Log the image -# # https://learn.microsoft.com/en-us/azure/machine-learning/how-to-log-view-metrics?view=azureml-api-2&tabs=interactive#log-images -# # For mlflow a matplotlib figure should use log_figure instead of log_image -# # Need to save the image to a temporary file, then log that file -# # mlflow.log_image, should do this automatically, but it doesn't work -# temporary_image = f"/tmp/key.png" -# images[0].savefig(temporary_image) - -# mlflow.log_figure(temporary_image, key) -# mlflow.log_figure(images[0], key) - @logger.catch def main(input_args=None):