From 3fbe2d095c58c085f635b1c0de9201c391c4082e Mon Sep 17 00:00:00 2001 From: Kasper Hintz Date: Thu, 3 Oct 2024 12:07:26 +0000 Subject: [PATCH] Make wandb work again with pytorch_lightning.logger --- neural_lam/models/ar_model.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 54d01e20..45cbd247 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -568,8 +568,16 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix): ) ) + # Ensure that log_dict has structure for logging as dict(str, plt.Figure) + assert all( + isinstance(key, str) and isinstance(value, plt.Figure) + for key, value in log_dict.items() + ) + if self.trainer.is_global_zero and not self.trainer.sanity_checking: - self.logger.log_image(log_dict) # Log all + for key, figure in log_dict.items(): + self.logger.log_image(key=key, images=[figure]) + plt.close("all") # Close all figs def on_test_epoch_end(self):