diff --git a/src/fairchem/core/_cli_hydra.py b/src/fairchem/core/_cli_hydra.py
index 79992616d..6ca8b7d6f 100644
--- a/src/fairchem/core/_cli_hydra.py
+++ b/src/fairchem/core/_cli_hydra.py
@@ -18,6 +18,9 @@
     from omegaconf import DictConfig
 
+    from fairchem.core.components.runner import Runner
+
+
 from submitit import AutoExecutor
 from submitit.helpers import Checkpointable, DelayedSubmission
 from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 
@@ -25,7 +28,6 @@
 from fairchem.core.common import distutils
 from fairchem.core.common.flags import flags
 from fairchem.core.common.utils import get_timestamp_uid, setup_env_vars, setup_imports
-from fairchem.core.components.runner import Runner
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -40,18 +42,18 @@ def __call__(self, dict_config: DictConfig, cli_args: argparse.Namespace) -> Non
         setup_env_vars()
         try:
             distutils.setup(map_cli_args_to_dist_config(cli_args))
-            runner: Runner = hydra.utils.instantiate(dict_config.runner)
-            runner.load_state()
-            runner.run()
+            self.runner: Runner = hydra.utils.instantiate(dict_config.runner)
+            self.runner.load_state()
+            self.runner.run()
         finally:
             distutils.cleanup()
 
     def checkpoint(self, *args, **kwargs):
         logging.info("Submitit checkpointing callback is triggered")
-        new_runner = Runner()
-        new_runner.save_state()
+        new_runner = Submitit()
+        self.runner.save_state()
         logging.info("Submitit checkpointing callback is completed")
-        return DelayedSubmission(new_runner, self.config)
+        return DelayedSubmission(new_runner, self.config, self.cli_args)
 
 
 def map_cli_args_to_dist_config(cli_args: argparse.Namespace) -> dict:
@@ -111,7 +113,17 @@ def main(
         )
     else:
         if args.num_gpus > 1:
-            logger.info(f"Running in local mode with {args.num_gpus} ranks")
+            logging.info(f"Running in local mode with {args.num_gpus} ranks")
+            # HACK to disable multiprocess dataloading in local mode
+            # there is an open issue where LMDB's environment cannot be pickled and used
+            # during torch multiprocessing https://github.com/pytorch/examples/issues/526
+            # this HACK only works for a training submission where the config is passed in here
+            if "optim" in cfg and "num_workers" in cfg["optim"]:
+                cfg["optim"]["num_workers"] = 0
+                logging.info(
+                    "WARNING: running in local mode, setting dataloading num_workers to 0, see https://github.com/pytorch/examples/issues/526"
+                )
+
             launch_config = LaunchConfig(
                 min_nodes=1,
                 max_nodes=1,
diff --git a/src/fairchem/core/common/utils.py b/src/fairchem/core/common/utils.py
index 20b6e6922..8e9e3ceab 100644
--- a/src/fairchem/core/common/utils.py
+++ b/src/fairchem/core/common/utils.py
@@ -716,7 +716,8 @@ def radius_graph_pbc(
 
     # Tensor of unit cells
     cells_per_dim = [
-        torch.arange(-rep, rep + 1, device=device, dtype=torch.float) for rep in max_rep
+        torch.arange(-rep.item(), rep.item() + 1, device=device, dtype=torch.float)
+        for rep in max_rep
     ]
     unit_cell = torch.cartesian_prod(*cells_per_dim)
     num_cells = len(unit_cell)
diff --git a/tests/core/test_hydra_cli.yml b/tests/core/test_hydra_cli.yml
index 2064dd986..59e0a158a 100644
--- a/tests/core/test_hydra_cli.yml
+++ b/tests/core/test_hydra_cli.yml
@@ -1,4 +1,4 @@
 runner:
   _target_: fairchem.core.components.runner.MockRunner
   x: 10
-  y: 23
+  y: 23
\ No newline at end of file
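
Illustrative sketch (not part of the patch): the submitit requeue pattern that the Submitit.checkpoint() fix above relies on. When a SLURM job is preempted or times out, submitit calls checkpoint(*args, **kwargs) on the submitted callable with the original submission arguments and requeues whatever DelayedSubmission it returns, so the callback must hand back a fresh callable plus those arguments rather than trying to instantiate the Runner base class. The names below (Task, submitit_logs, the toy config) are hypothetical.

from submitit import AutoExecutor
from submitit.helpers import Checkpointable, DelayedSubmission


class Task(Checkpointable):
    def __call__(self, config: dict) -> None:
        self.config = config
        print(f"running with {config}")  # stand-in for long-running work

    def checkpoint(self, *args, **kwargs) -> DelayedSubmission:
        # Requeue a fresh instance with the original submission arguments,
        # mirroring the patched Submitit.checkpoint(), which returns
        # DelayedSubmission(Submitit(), self.config, self.cli_args).
        return DelayedSubmission(Task(), *args, **kwargs)


if __name__ == "__main__":
    executor = AutoExecutor(folder="submitit_logs")  # falls back to a local executor off-SLURM
    executor.update_parameters(timeout_min=5)
    job = executor.submit(Task(), {"lr": 1e-3})
    job.result()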