Skip to content

Commit

Permalink
more fixes for mlflow logging support
Browse files Browse the repository at this point in the history
  • Loading branch information
leifdenby committed Oct 3, 2024
1 parent a921e35 commit 0f30259
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy as np
import pytorch_lightning as pl
import torch
import wandb

# Local
from .. import metrics, vis
Expand Down Expand Up @@ -467,9 +466,9 @@ def plot_examples(self, batch, n_examples, prediction=None):
]

example_i = self.plotted_examples
wandb.log(
self.logger.log_image(
{
f"{var_name}_example_{example_i}": wandb.Image(fig)
f"{var_name}_example_{example_i}": fig
for var_name, fig in zip(
self._datastore.get_vars_names("state"), var_figs
)
Expand All @@ -483,13 +482,15 @@ def plot_examples(self, batch, n_examples, prediction=None):
torch.save(
pred_slice.cpu(),
os.path.join(
wandb.run.dir, f"example_pred_{self.plotted_examples}.pt"
self.logger.save_dir,
f"example_pred_{self.plotted_examples}.pt",
),
)
torch.save(
target_slice.cpu(),
os.path.join(
wandb.run.dir, f"example_target_{self.plotted_examples}.pt"
self.logger.save_dir,
f"example_target_{self.plotted_examples}.pt",
),
)

Expand All @@ -510,16 +511,16 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
datastore=self._datastore,
)
full_log_name = f"{prefix}_{metric_name}"
log_dict[full_log_name] = wandb.Image(metric_fig)
log_dict[full_log_name] = metric_fig

if prefix == "test":
# Save pdf
metric_fig.savefig(
os.path.join(wandb.run.dir, f"{full_log_name}.pdf")
os.path.join(self.logger.save_dir, f"{full_log_name}.pdf")
)
# Save errors also as csv
np.savetxt(
os.path.join(wandb.run.dir, f"{full_log_name}.csv"),
os.path.join(self.logger.save_dir, f"{full_log_name}.csv"),
metric_tensor.cpu().numpy(),
delimiter=",",
)
Expand Down Expand Up @@ -568,7 +569,7 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
)

if self.trainer.is_global_zero and not self.trainer.sanity_checking:
wandb.log(log_dict) # Log all
self.logger.log_image(log_dict) # Log all
plt.close("all") # Close all figs

def on_test_epoch_end(self):
Expand Down Expand Up @@ -599,9 +600,9 @@ def on_test_epoch_end(self):
)
]

# log all to same wandb key, sequentially
# log all to same key, sequentially
for fig in loss_map_figs:
wandb.log({"test_loss": wandb.Image(fig)})
self.logger.log_image({"test_loss": fig})

# also make without title and save as pdf
pdf_loss_map_figs = [
Expand All @@ -610,14 +611,16 @@ def on_test_epoch_end(self):
)
for loss_map in mean_spatial_loss
]
pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
pdf_loss_maps_dir = os.path.join(
self.logger.save_dir, "spatial_loss_maps"
)
os.makedirs(pdf_loss_maps_dir, exist_ok=True)
for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs):
fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf"))
# save mean spatial loss as .pt file also
torch.save(
mean_spatial_loss.cpu(),
os.path.join(wandb.run.dir, "mean_spatial_loss.pt"),
os.path.join(self.logger.save_dir, "mean_spatial_loss.pt"),
)

self.spatial_loss_maps.clear()
Expand Down

0 comments on commit 0f30259

Please sign in to comment.