diff --git a/src/fairchem/core/_cli_hydra.py b/src/fairchem/core/_cli_hydra.py index 6ca8b7d6f..9279c8daf 100644 --- a/src/fairchem/core/_cli_hydra.py +++ b/src/fairchem/core/_cli_hydra.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING import hydra +from omegaconf import OmegaConf if TYPE_CHECKING: import argparse @@ -34,21 +35,40 @@ class Submitit(Checkpointable): - def __call__(self, dict_config: DictConfig, cli_args: argparse.Namespace) -> None: + def __call__(self, dict_config: DictConfig) -> None: self.config = dict_config - self.cli_args = cli_args # TODO: setup_imports is not needed if we stop instantiating models with Registry. setup_imports() setup_env_vars() - try: - distutils.setup(map_cli_args_to_dist_config(cli_args)) - self.runner: Runner = hydra.utils.instantiate(dict_config.runner) - self.runner.load_state() - self.runner.run() - finally: - distutils.cleanup() - - def checkpoint(self, *args, **kwargs): + distutils.setup(map_cli_args_to_dist_config(dict_config.cli_args)) + self._init_logger() + runner: Runner = hydra.utils.instantiate(dict_config.runner) + runner.load_state() + runner.run() + distutils.cleanup() + + def _init_logger(self) -> None: + # optionally instantiate a singleton wandb logger, intentionally only supporting the new wandb logger + # don't start logger if in debug mode + if ( + "logger" in self.config + and distutils.is_master() + and not self.config.cli_args.debug + ): + # get a partial function from the config and instantiate wandb with it + logger_initializer = hydra.utils.instantiate(self.config.logger) + simple_config = OmegaConf.to_container( + self.config, resolve=True, throw_on_missing=True + ) + logger_initializer( + config=simple_config, + run_id=self.config.cli_args.timestamp_id, + run_name=self.config.cli_args.identifier, + log_dir=self.config.cli_args.logdir, + ) + + def checkpoint(self, *args, **kwargs) -> DelayedSubmission: + # TODO: this is yet to be tested properly logging.info("Submitit checkpointing callback is triggered") new_runner = Submitit() self.runner.save_state() @@ -56,7 +76,7 @@ def checkpoint(self, *args, **kwargs): return DelayedSubmission(new_runner, self.config, self.cli_args) -def map_cli_args_to_dist_config(cli_args: argparse.Namespace) -> dict: +def map_cli_args_to_dist_config(cli_args: DictConfig) -> dict: return { "world_size": cli_args.num_nodes * cli_args.num_gpus, "distributed_backend": "gloo" if cli_args.cpu else "nccl", @@ -78,8 +98,8 @@ def get_hydra_config_from_yaml( return hydra.compose(config_name=config_name, overrides=overrides_args) -def runner_wrapper(config: DictConfig, cli_args: argparse.Namespace): - Submitit()(config, cli_args) +def runner_wrapper(config: DictConfig): + Submitit()(config) # this is meant as a future replacement for the main entrypoint @@ -93,6 +113,11 @@ def main( cfg = get_hydra_config_from_yaml(args.config_yml, override_args) timestamp_id = get_timestamp_uid() log_dir = os.path.join(args.run_dir, timestamp_id, "logs") + # override timestamp id and logdir + args.timestamp_id = timestamp_id + args.logdir = log_dir + os.makedirs(log_dir) + OmegaConf.update(cfg, "cli_args", vars(args), force_add=True) if args.submit: # Run on cluster executor = AutoExecutor(folder=log_dir, slurm_max_num_timeout=3) executor.update_parameters( @@ -107,7 +132,7 @@ def main( slurm_qos=args.slurm_qos, slurm_account=args.slurm_account, ) - job = executor.submit(runner_wrapper, cfg, args) + job = executor.submit(runner_wrapper, cfg) logger.info( f"Submitted job id: {timestamp_id}, slurm id: {job.job_id}, logs: {log_dir}" ) @@ -131,8 +156,8 @@ def main( rdzv_backend="c10d", max_restarts=0, ) - elastic_launch(launch_config, runner_wrapper)(cfg, args) + elastic_launch(launch_config, runner_wrapper)(cfg) else: logger.info("Running in local mode without elastic launch") distutils.setup_env_local() - runner_wrapper(cfg, args) + runner_wrapper(cfg)