diff --git a/src/fairchem/core/common/logger.py b/src/fairchem/core/common/logger.py index 97199c15c..cab3e7be4 100644 --- a/src/fairchem/core/common/logger.py +++ b/src/fairchem/core/common/logger.py @@ -148,3 +148,89 @@ def log_summary(self, summary_dict: dict[str, Any]) -> None: def log_artifact(self, name: str, type: str, file_location: str) -> None: logging.warning("log_artifact for Tensorboard not supported") + + +class WandBSingletonLogger: + """ + Singleton version of wandb logger, this forces a single instance of the logger to be created and used from anywhere in the code (not just from the trainer). + This will replace the original WandBLogger. + + We initialize wandb instance somewhere in the trainer/runner globally: + + WandBSingletonLogger.init_wandb(...) + + Then from anywhere in the code we can fetch the singleton instance and log to wandb, + note this allows you to log without knowing explicitly which step you are on + see: https://docs.wandb.ai/ref/python/log/#the-wb-step for more details + + WandBSingletonLogger.get_instance().log({"some_value": value}, commit=False) + """ + + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + @classmethod + def init_wandb( + cls, + config: dict, + run_id: str, + run_name: str, + log_dir: str, + project: str, + entity: str, + group: str | None = None, + ) -> None: + wandb.init( + config=config, + id=run_id, + name=run_name, + dir=log_dir, + project=project, + entity=entity, + resume="allow", + group=group, + ) + + @classmethod + def get_instance(cls): + assert wandb.run is not None, "wandb is not initialized, call init_wandb first!" + if cls._instance is None: + cls._instance = cls.__new__(cls) + return cls._instance + + def watch(self, model, log_freq: int = 1000) -> None: + wandb.watch(model, log_freq=log_freq) + + def log( + self, update_dict: dict, step: int | None = None, commit=False, split: str = "" + ) -> None: + # HACK: this is really ugly logic here for backward compat but we should get rid of. + # the split string shouldn't inserted here + if split != "": + new_dict = {} + for key in update_dict: + new_dict[f"{split}/{key}"] = update_dict[key] + update_dict = new_dict + + # if step is not specified, wandb will use an auto-incremented step: https://docs.wandb.ai/ref/python/log/ + # otherwise the user must increment it manually (not recommended) + wandb.log(update_dict, step=step, commit=commit) + + def log_plots(self, plots, caption: str = "") -> None: + assert isinstance(plots, list) + plots = [wandb.Image(x, caption=caption) for x in plots] + wandb.log({"data": plots}) + + def log_summary(self, summary_dict: dict[str, Any]): + for k, v in summary_dict.items(): + wandb.run.summary[k] = v + + def mark_preempting(self) -> None: + wandb.mark_preempting() + + def log_artifact(self, name: str, type: str, file_location: str) -> None: + art = wandb.Artifact(name=name, type=type) + art.add_file(file_location) + art.save() diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 9da73f723..0ead76b2c 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -31,6 +31,7 @@ from fairchem.core import __version__ from fairchem.core.common import distutils, gp_utils from fairchem.core.common.data_parallel import BalancedBatchSampler +from fairchem.core.common.logger import WandBSingletonLogger from fairchem.core.common.registry import registry from fairchem.core.common.slurm import ( add_timestamp_id_to_submission_pickle, @@ -275,7 +276,19 @@ def load_logger(self) -> None: logger_name = logger if isinstance(logger, str) else logger["name"] assert logger_name, "Specify logger name" - self.logger = registry.get_logger_class(logger_name)(self.config) + if logger_name == "wandb_singleton": + WandBSingletonLogger.init_wandb( + config=self.config, + run_id=self.config["cmd"]["timestamp_id"], + run_name=self.config["cmd"]["identifier"], + log_dir=self.config["cmd"]["logs_dir"], + project=self.config["logger"]["project"], + entity=self.config["logger"]["entity"], + group=self.config["logger"].get("group", ""), + ) + self.logger = WandBSingletonLogger.get_instance() + else: + self.logger = registry.get_logger_class(logger_name)(self.config) def get_sampler( self, dataset, batch_size: int, shuffle: bool