Log num params to W&B (#741)
* log num params

* typo

* check logger exists first

* update tb loggers
rayg1234 authored Jun 26, 2024
1 parent c001d79 commit 3d04ca9
Showing 2 changed files with 15 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/fairchem/core/common/logger.py
@@ -9,6 +9,7 @@
 
 import logging
 from abc import ABC, abstractmethod
+from typing import Any
 
 import torch
 import wandb
@@ -51,6 +52,10 @@ def log_plots(self, plots) -> None:
     def mark_preempting(self) -> None:
         pass
 
+    @abstractmethod
+    def log_summary(self, summary_dict: dict[str, Any]) -> None:
+        pass
+
 
 @registry.register_logger("wandb")
 class WandBLogger(Logger):
@@ -89,6 +94,10 @@ def log_plots(self, plots, caption: str = "") -> None:
         plots = [wandb.Image(x, caption=caption) for x in plots]
         wandb.log({"data": plots})
 
+    def log_summary(self, summary_dict: dict[str, Any]):
+        for k, v in summary_dict.items():
+            wandb.run.summary[k] = v
+
     def mark_preempting(self) -> None:
         wandb.mark_preempting()
 
@@ -114,7 +123,10 @@ def log(self, update_dict, step: int, split: str = ""):
             self.writer.add_scalar(key, update_dict[key], step)
 
     def mark_preempting(self) -> None:
-        pass
+        logging.warning("mark_preempting for Tensorboard not supported")
 
     def log_plots(self, plots) -> None:
-        pass
+        logging.warning("log_plots for Tensorboard not supported")
+
+    def log_summary(self, summary_dict: dict[str, Any]) -> None:
+        logging.warning("log_summary for Tensorboard not supported")
1 change: 1 addition & 0 deletions src/fairchem/core/trainers/base_trainer.py
@@ -430,6 +430,7 @@ def load_model(self) -> None:
 
         if self.logger is not None:
             self.logger.watch(self.model)
+            self.logger.log_summary({"num_params": self.model.num_params})
 
         if distutils.initialized() and not self.config["noddp"]:
             self.model = DistributedDataParallel(self.model, device_ids=[self.device])
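
The trainer reads `self.model.num_params`, and the diff does not show how that attribute is defined. A property like the one below is one common way to expose a parameter count on a `torch.nn.Module`; this is a hedged sketch with a made-up `ExampleModel`, not the fairchem implementation.

```python
import torch.nn as nn

class ExampleModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(16, 4)

    @property
    def num_params(self) -> int:
        # Total parameter count across all submodules.
        return sum(p.numel() for p in self.parameters())

model = ExampleModel()
print(model.num_params)  # 16 * 4 + 4 = 68
```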
