Skip to content

Commit

Permalink
add trainer update
Browse files Browse the repository at this point in the history
  • Loading branch information
rayg1234 committed Oct 9, 2024
1 parent 7a883e7 commit 5a9eb39
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/fairchem/core/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from fairchem.core import __version__
from fairchem.core.common import distutils, gp_utils
from fairchem.core.common.data_parallel import BalancedBatchSampler
from fairchem.core.common.logger import WandBSingletonLogger
from fairchem.core.common.registry import registry
from fairchem.core.common.slurm import (
add_timestamp_id_to_submission_pickle,
Expand Down Expand Up @@ -275,7 +276,19 @@ def load_logger(self) -> None:
logger_name = logger if isinstance(logger, str) else logger["name"]
assert logger_name, "Specify logger name"

self.logger = registry.get_logger_class(logger_name)(self.config)
if logger_name == "wandb_singleton":
WandBSingletonLogger.init_wandb(
config=self.config,
run_id=self.config["cmd"]["timestamp_id"],
run_name=self.config["cmd"]["identifier"],
log_dir=self.config["cmd"]["logs_dir"],
project=self.config["logger"]["project"],
entity=self.config["logger"]["entity"],
group=self.config["logger"].get("group", ""),
)
self.logger = WandBSingletonLogger.get_instance()
else:
self.logger = registry.get_logger_class(logger_name)(self.config)

def get_sampler(
self, dataset, batch_size: int, shuffle: bool
Expand Down

0 comments on commit 5a9eb39

Please sign in to comment.