Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add singleton logger #873

Merged
merged 4 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions src/fairchem/core/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,89 @@ def log_summary(self, summary_dict: dict[str, Any]) -> None:

def log_artifact(self, name: str, type: str, file_location: str) -> None:
logging.warning("log_artifact for Tensorboard not supported")


class WandBSingletonLogger:
"""
Singleton version of wandb logger, this forces a single instance of the logger to be created and used from anywhere in the code (not just from the trainer).
This will replace the original WandBLogger.

We initialize wandb instance somewhere in the trainer/runner globally:

WandBSingletonLogger.init_wandb(...)

Then from anywhere in the code we can fetch the singleton instance and log to wandb,
note this allows you to log without knowing explicitly which step you are on
see: https://docs.wandb.ai/ref/python/log/#the-wb-step for more details

WandBSingletonLogger.get_instance().log({"some_value": value}, commit=False)
"""

_instance = None

def __init__(self):
raise RuntimeError("Call get_instance() instead")

@classmethod
def init_wandb(
cls,
config: dict,
run_id: str,
run_name: str,
log_dir: str,
project: str,
entity: str,
group: str | None = None,
) -> None:
wandb.init(
config=config,
id=run_id,
name=run_name,
dir=log_dir,
project=project,
entity=entity,
resume="allow",
group=group,
)

@classmethod
def get_instance(cls):
assert wandb.run is not None, "wandb is not initialized, call init_wandb first!"
if cls._instance is None:
cls._instance = cls.__new__(cls)
return cls._instance

def watch(self, model, log_freq: int = 1000) -> None:
wandb.watch(model, log_freq=log_freq)

def log(
self, update_dict: dict, step: int | None = None, commit=False, split: str = ""
) -> None:
# HACK: this is really ugly logic here for backward compat but we should get rid of.
# the split string shouldn't inserted here
if split != "":
new_dict = {}
for key in update_dict:
new_dict[f"{split}/{key}"] = update_dict[key]
update_dict = new_dict

# if step is not specified, wandb will use an auto-incremented step: https://docs.wandb.ai/ref/python/log/
# otherwise the user must increment it manually (not recommended)
wandb.log(update_dict, step=step, commit=commit)

def log_plots(self, plots, caption: str = "") -> None:
assert isinstance(plots, list)
plots = [wandb.Image(x, caption=caption) for x in plots]
wandb.log({"data": plots})

def log_summary(self, summary_dict: dict[str, Any]):
for k, v in summary_dict.items():
wandb.run.summary[k] = v

def mark_preempting(self) -> None:
wandb.mark_preempting()

def log_artifact(self, name: str, type: str, file_location: str) -> None:
art = wandb.Artifact(name=name, type=type)
art.add_file(file_location)
art.save()
15 changes: 14 additions & 1 deletion src/fairchem/core/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from fairchem.core import __version__
from fairchem.core.common import distutils, gp_utils
from fairchem.core.common.data_parallel import BalancedBatchSampler
from fairchem.core.common.logger import WandBSingletonLogger
from fairchem.core.common.registry import registry
from fairchem.core.common.slurm import (
add_timestamp_id_to_submission_pickle,
Expand Down Expand Up @@ -275,7 +276,19 @@ def load_logger(self) -> None:
logger_name = logger if isinstance(logger, str) else logger["name"]
assert logger_name, "Specify logger name"

self.logger = registry.get_logger_class(logger_name)(self.config)
if logger_name == "wandb_singleton":
WandBSingletonLogger.init_wandb(
config=self.config,
run_id=self.config["cmd"]["timestamp_id"],
run_name=self.config["cmd"]["identifier"],
log_dir=self.config["cmd"]["logs_dir"],
project=self.config["logger"]["project"],
entity=self.config["logger"]["entity"],
group=self.config["logger"].get("group", ""),
)
self.logger = WandBSingletonLogger.get_instance()
else:
self.logger = registry.get_logger_class(logger_name)(self.config)

def get_sampler(
self, dataset, batch_size: int, shuffle: bool
Expand Down