Skip to content

Commit

Permalink
support mlflow system metrics logging
Browse files Browse the repository at this point in the history
  • Loading branch information
khintz committed Oct 7, 2024
1 parent 27408f2 commit e61a9e7
Showing 1 changed file with 7 additions and 31 deletions.
38 changes: 7 additions & 31 deletions neural_lam/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from lightning_fabric.utilities import seed
from loguru import logger

import mlflow

# Local
from . import utils
from .config import load_config_and_datastore
Expand All @@ -24,29 +26,22 @@

class CustomMLFlowLogger(pl.loggers.MLFlowLogger):

def __init__(self, experiment_name, tracking_uri):
super().__init__(experiment_name=experiment_name, tracking_uri=tracking_uri)
mlflow.start_run(run_id=self.run_id, log_system_metrics=True)


def log_image(self, key, images):
import mlflow
import io
from PIL import Image

# Retrieve the active run ID from the logger
run_id = self.run_id
# Ensure mlflow uses the same run
mlflow.start_run(run_id=run_id)

# Need to save the image to a temporary file, then log that file
# mlflow.log_image, should do this automatically, but it doesn't work
temporary_image = f"{key}.png"
images[0].savefig(temporary_image)

img = Image.open(temporary_image)
print(images)
print(images[0])
mlflow.log_image(img, f"{key}.png")

#mlflow.log_figure(images[0], key)
mlflow.end_run()


def _setup_training_logger(config, datastore, args, run_name):
if config.training.logger == "wandb":
Expand All @@ -69,33 +64,14 @@ def _setup_training_logger(config, datastore, args, run_name):
experiment_name=args.wandb_project,
tracking_uri=url,
)
print(logger)
logger.log_hyperparams(
dict(training=vars(args), datastore=datastore._config)
)
print("Logged hyperparams")
print(run_name)

print(logger.__str__)
# logger.log_image = log_image

return logger


# def log_image(key, images):
# import mlflow

# # Log the image
# # https://learn.microsoft.com/en-us/azure/machine-learning/how-to-log-view-metrics?view=azureml-api-2&tabs=interactive#log-images
# # For mlflow a matplotlib figure should use log_figure instead of log_image
# # Need to save the image to a temporary file, then log that file
# # mlflow.log_image, should do this automatically, but it doesn't work
# temporary_image = f"/tmp/key.png"
# images[0].savefig(temporary_image)

# mlflow.log_figure(temporary_image, key)
# mlflow.log_figure(images[0], key)


@logger.catch
def main(input_args=None):
Expand Down

0 comments on commit e61a9e7

Please sign in to comment.