From 8a9adbbc3631f73f5e40c2356a558e5a4a482449 Mon Sep 17 00:00:00 2001 From: rayg1234 <7001989+rayg1234@users.noreply.github.com> Date: Fri, 18 Oct 2024 15:48:40 -0700 Subject: [PATCH] Add hydra entrypoint (#867) * format * working relaxations on local, add submitit * small fixes * set logging * add reqs * add tests * add tests * move hydra package * revert some changes to predict * cpu tests * add cleanup * typo * add check for dist initialized * update test * wrap feature flag for local mode as well --- packages/fairchem-core/pyproject.toml | 3 +- src/fairchem/core/_cli.py | 14 +- src/fairchem/core/_cli_hydra.py | 126 ++++++++++++++++++ src/fairchem/core/common/distutils.py | 49 ++++++- src/fairchem/core/common/flags.py | 6 +- .../relaxation/optimizers/lbfgs_torch.py | 2 +- src/fairchem/core/common/utils.py | 6 + src/fairchem/core/components/runner.py | 42 ++++++ tests/core/test_hydra_cli.py | 50 +++++++ tests/core/test_hydra_cli.yml | 4 + 10 files changed, 287 insertions(+), 15 deletions(-) create mode 100644 src/fairchem/core/_cli_hydra.py create mode 100644 src/fairchem/core/components/runner.py create mode 100644 tests/core/test_hydra_cli.py create mode 100644 tests/core/test_hydra_cli.yml diff --git a/packages/fairchem-core/pyproject.toml b/packages/fairchem-core/pyproject.toml index 5c2ab55a62..ee92db45df 100644 --- a/packages/fairchem-core/pyproject.toml +++ b/packages/fairchem-core/pyproject.toml @@ -22,7 +22,8 @@ dependencies = [ "requests", "orjson", "tqdm", - "submitit" + "submitit", + "hydra-core" ] [project.optional-dependencies] # add optional dependencies to be installed as pip install fairchem.core[dev] diff --git a/src/fairchem/core/_cli.py b/src/fairchem/core/_cli.py index cd3960998e..b188739d14 100644 --- a/src/fairchem/core/_cli.py +++ b/src/fairchem/core/_cli.py @@ -9,14 +9,13 @@ import copy import logging -import os from typing import TYPE_CHECKING from submitit import AutoExecutor from submitit.helpers import Checkpointable, DelayedSubmission -from torch.distributed.elastic.utils.distributed import get_free_port from torch.distributed.launcher.api import LaunchConfig, elastic_launch +from fairchem.core.common import distutils from fairchem.core.common.flags import flags from fairchem.core.common.utils import ( build_config, @@ -69,6 +68,12 @@ def main( parser: argparse.ArgumentParser = flags.get_parser() args, override_args = parser.parse_known_args() + if args.hydra: + from fairchem.core._cli_hydra import main + + main(args, override_args) + return + # TODO: rename num_gpus -> num_ranks everywhere assert ( args.num_gpus > 0 @@ -126,10 +131,7 @@ def main( logging.info( "Running in local mode without elastic launch (single gpu only)" ) - os.environ["MASTER_ADDR"] = "localhost" - os.environ["LOCAL_RANK"] = "0" - os.environ["RANK"] = "0" - os.environ["MASTER_PORT"] = str(get_free_port()) + distutils.setup_env_local() runner_wrapper(config) diff --git a/src/fairchem/core/_cli_hydra.py b/src/fairchem/core/_cli_hydra.py new file mode 100644 index 0000000000..79992616df --- /dev/null +++ b/src/fairchem/core/_cli_hydra.py @@ -0,0 +1,126 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+""" + +from __future__ import annotations + +import logging +import os +from typing import TYPE_CHECKING + +import hydra + +if TYPE_CHECKING: + import argparse + + from omegaconf import DictConfig + +from submitit import AutoExecutor +from submitit.helpers import Checkpointable, DelayedSubmission +from torch.distributed.launcher.api import LaunchConfig, elastic_launch + +from fairchem.core.common import distutils +from fairchem.core.common.flags import flags +from fairchem.core.common.utils import get_timestamp_uid, setup_env_vars, setup_imports +from fairchem.core.components.runner import Runner + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class Submitit(Checkpointable): + def __call__(self, dict_config: DictConfig, cli_args: argparse.Namespace) -> None: + self.config = dict_config + self.cli_args = cli_args + # TODO: setup_imports is not needed if we stop instantiating models with Registry. + setup_imports() + setup_env_vars() + try: + distutils.setup(map_cli_args_to_dist_config(cli_args)) + runner: Runner = hydra.utils.instantiate(dict_config.runner) + runner.load_state() + runner.run() + finally: + distutils.cleanup() + + def checkpoint(self, *args, **kwargs): + logging.info("Submitit checkpointing callback is triggered") + new_runner = Runner() + new_runner.save_state() + logging.info("Submitit checkpointing callback is completed") + return DelayedSubmission(new_runner, self.config) + + +def map_cli_args_to_dist_config(cli_args: argparse.Namespace) -> dict: + return { + "world_size": cli_args.num_nodes * cli_args.num_gpus, + "distributed_backend": "gloo" if cli_args.cpu else "nccl", + "submit": cli_args.submit, + "summit": None, + "cpu": cli_args.cpu, + "use_cuda_visibile_devices": True, + } + + +def get_hydra_config_from_yaml( + config_yml: str, overrides_args: list[str] +) -> DictConfig: + # Load the configuration from the file + os.environ["HYDRA_FULL_ERROR"] = "1" + config_directory = os.path.dirname(os.path.abspath(config_yml)) + config_name = os.path.basename(config_yml) + hydra.initialize_config_dir(config_directory) + return hydra.compose(config_name=config_name, overrides=overrides_args) + + +def runner_wrapper(config: DictConfig, cli_args: argparse.Namespace): + Submitit()(config, cli_args) + + +# this is meant as a future replacement for the main entrypoint +def main( + args: argparse.Namespace | None = None, override_args: list[str] | None = None +): + if args is None: + parser: argparse.ArgumentParser = flags.get_parser() + args, override_args = parser.parse_known_args() + + cfg = get_hydra_config_from_yaml(args.config_yml, override_args) + timestamp_id = get_timestamp_uid() + log_dir = os.path.join(args.run_dir, timestamp_id, "logs") + if args.submit: # Run on cluster + executor = AutoExecutor(folder=log_dir, slurm_max_num_timeout=3) + executor.update_parameters( + name=args.identifier, + mem_gb=args.slurm_mem, + timeout_min=args.slurm_timeout * 60, + slurm_partition=args.slurm_partition, + gpus_per_node=args.num_gpus, + cpus_per_task=8, + tasks_per_node=args.num_gpus, + nodes=args.num_nodes, + slurm_qos=args.slurm_qos, + slurm_account=args.slurm_account, + ) + job = executor.submit(runner_wrapper, cfg, args) + logger.info( + f"Submitted job id: {timestamp_id}, slurm id: {job.job_id}, logs: {log_dir}" + ) + else: + if args.num_gpus > 1: + logger.info(f"Running in local mode with {args.num_gpus} ranks") + launch_config = LaunchConfig( + min_nodes=1, + max_nodes=1, + nproc_per_node=args.num_gpus, + rdzv_backend="c10d", + max_restarts=0, + ) 
+ elastic_launch(launch_config, runner_wrapper)(cfg, args) + else: + logger.info("Running in local mode without elastic launch") + distutils.setup_env_local() + runner_wrapper(cfg, args) diff --git a/src/fairchem/core/common/distutils.py b/src/fairchem/core/common/distutils.py index 2e232ae7ba..604f969a86 100644 --- a/src/fairchem/core/common/distutils.py +++ b/src/fairchem/core/common/distutils.py @@ -21,6 +21,7 @@ T = TypeVar("T") DISTRIBUTED_PORT = 13356 +CURRENT_DEVICE_STR = "CURRRENT_DEVICE" def os_environ_get_or_throw(x: str) -> str: @@ -72,7 +73,15 @@ def setup(config) -> None: logging.info( f"local rank: {config['local_rank']}, visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}" ) - torch.cuda.set_device(config["local_rank"]) + + # In the new hydra runners, we setup the device for each rank as either cuda:0 or cpu + # after this point, the local rank should either be using "cpu" or "cuda" + if config.get("use_cuda_visibile_devices"): + assign_device_for_local_rank(config["cpu"], config["local_rank"]) + else: + # in the old code, all ranks can see all devices but need to be assigned a device equal to their local rank + # this is dangerous and should be deprecated + torch.cuda.set_device(config["local_rank"]) dist.init_process_group( backend="nccl", @@ -110,11 +119,10 @@ def setup(config) -> None: assert ( config["world_size"] == 1 ), "Can only setup master address and port at this point for a single rank, otherwise we assume the processes and the comm addr/port have already been setup" - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = str(get_free_port()) - os.environ["LOCAL_RANK"] = "0" - os.environ["RANK"] = "0" + setup_env_local() config["local_rank"] = int(os.environ.get("LOCAL_RANK")) + if config.get("use_cuda_visibile_devices"): + assign_device_for_local_rank(config["cpu"], config["local_rank"]) dist.init_process_group( backend=config["distributed_backend"], rank=int(os.environ.get("RANK")), @@ -124,7 +132,8 @@ def setup(config) -> None: def cleanup() -> None: - dist.destroy_process_group() + if dist.is_initialized(): + dist.destroy_process_group() def initialized() -> bool: @@ -210,3 +219,31 @@ def gather_objects(data: T, group: dist.ProcessGroup = dist.group.WORLD) -> list output = [None for _ in range(get_world_size())] if is_master() else None dist.gather_object(data, output, group=group, dst=0) return output + + +def assign_device_for_local_rank(cpu: bool, local_rank: int): + if cpu: + os.environ[CURRENT_DEVICE_STR] = "cpu" + else: + # assert the cuda device to be the local rank + os.environ[CURRENT_DEVICE_STR] = "cuda" + os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank) + + +def get_device_for_local_rank(): + cur_dev_env = os.environ.get(CURRENT_DEVICE_STR) + if cur_dev_env is not None: + return cur_dev_env + else: + device = "cuda" if torch.cuda.available() else "cpu" + logging.warn( + f"{CURRENT_DEVICE_STR} env variable not found, defaulting to {device}" + ) + return device + + +def setup_env_local(): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["LOCAL_RANK"] = "0" + os.environ["RANK"] = "0" + os.environ["MASTER_PORT"] = str(get_free_port()) diff --git a/src/fairchem/core/common/flags.py b/src/fairchem/core/common/flags.py index 65a9243f27..7753a736f3 100644 --- a/src/fairchem/core/common/flags.py +++ b/src/fairchem/core/common/flags.py @@ -27,7 +27,6 @@ def add_core_args(self) -> None: self.parser.add_argument( "--mode", choices=["train", "predict", "run-relaxations", "validate"], - required=True, help="Whether to train the model, 
make predictions, or to run relaxations", ) self.parser.add_argument( @@ -121,6 +120,11 @@ def add_core_args(self) -> None: self.parser.add_argument( "--cpu", action="store_true", help="Run CPU only training" ) + self.parser.add_argument( + "--hydra", + action="store_true", + help="Use hydra configs instead (in development)", + ) self.parser.add_argument( "--num-nodes", default=1, diff --git a/src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py b/src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py index 7ad3745ae6..a90f0dce5b 100644 --- a/src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py +++ b/src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py @@ -52,7 +52,7 @@ def __init__( self.traj_dir = traj_dir self.traj_names = traj_names self.early_stop_batch = early_stop_batch - self.otf_graph = model.model._unwrapped_model.otf_graph + self.otf_graph = True assert not self.traj_dir or ( traj_dir and len(traj_names) ), "Trajectory names should be specified to save trajectories" diff --git a/src/fairchem/core/common/utils.py b/src/fairchem/core/common/utils.py index e9d498153a..acd1255e07 100644 --- a/src/fairchem/core/common/utils.py +++ b/src/fairchem/core/common/utils.py @@ -10,6 +10,7 @@ import ast import collections import copy +import datetime import errno import importlib import itertools @@ -26,6 +27,7 @@ from itertools import product from pathlib import Path from typing import TYPE_CHECKING, Any +from uuid import uuid4 import numpy as np import torch @@ -1446,3 +1448,7 @@ def load_model_and_weights_from_checkpoint(checkpoint_path: str) -> nn.Module: matched_dict = match_state_dict(model.state_dict(), checkpoint["state_dict"]) load_state_dict(model, matched_dict, strict=True) return model + + +def get_timestamp_uid() -> str: + return datetime.datetime.now().strftime("%Y%m-%d%H-%M%S-") + str(uuid4())[:4] diff --git a/src/fairchem/core/components/runner.py b/src/fairchem/core/components/runner.py new file mode 100644 index 0000000000..d33377f9f2 --- /dev/null +++ b/src/fairchem/core/components/runner.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from typing import Any + + +class Runner(metaclass=ABCMeta): + """ + Represents an abstraction over things that run in a loop and can save/load state. + ie: Trainers, Validators, Relaxation all fall in this category. 
+ This allows us to decouple away from a monolithic trainer class + """ + + @abstractmethod + def run(self) -> Any: + raise NotImplementedError + + @abstractmethod + def save_state(self) -> None: + raise NotImplementedError + + @abstractmethod + def load_state(self) -> None: + raise NotImplementedError + + +# Used for testing +class MockRunner(Runner): + def __init__(self, x: int, y: int): + self.x = x + self.y = y + + def run(self) -> Any: + if self.x * self.y > 1000: + raise ValueError("sum is greater than 1000!") + return self.x + self.y + + def save_state(self) -> None: + pass + + def load_state(self) -> None: + pass diff --git a/tests/core/test_hydra_cli.py b/tests/core/test_hydra_cli.py new file mode 100644 index 0000000000..b2619fdb7b --- /dev/null +++ b/tests/core/test_hydra_cli.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import sys + +import hydra +import pytest + +from fairchem.core._cli import main +from fairchem.core.common import distutils + + +def test_hydra_cli(): + distutils.cleanup() + hydra.core.global_hydra.GlobalHydra.instance().clear() + sys_args = ["--hydra", "--config-yml", "tests/core/test_hydra_cli.yml", "--cpu"] + sys.argv[1:] = sys_args + main() + + +def test_hydra_cli_throws_error(): + distutils.cleanup() + hydra.core.global_hydra.GlobalHydra.instance().clear() + sys_args = [ + "--hydra", + "--cpu", + "--config-yml", + "tests/core/test_hydra_cli.yml", + "runner.x=1000", + "runner.y=5", + ] + sys.argv[1:] = sys_args + with pytest.raises(ValueError) as error_info: + main() + assert "sum is greater than 1000" in str(error_info.value) + + +def test_hydra_cli_throws_error_on_invalid_inputs(): + distutils.cleanup() + hydra.core.global_hydra.GlobalHydra.instance().clear() + sys_args = [ + "--hydra", + "--cpu", + "--config-yml", + "tests/core/test_hydra_cli.yml", + "runner.x=1000", + "runner.z=5", # z is not a valid input argument to runner + ] + sys.argv[1:] = sys_args + with pytest.raises(hydra.errors.ConfigCompositionException): + main() diff --git a/tests/core/test_hydra_cli.yml b/tests/core/test_hydra_cli.yml new file mode 100644 index 0000000000..2064dd986f --- /dev/null +++ b/tests/core/test_hydra_cli.yml @@ -0,0 +1,4 @@ +runner: + _target_: fairchem.core.components.runner.MockRunner + x: 10 + y: 23
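
Usage sketch (illustration only, not part of the diff above): a minimal way to drive the new --hydra entrypoint, mirroring tests/core/test_hydra_cli.py and tests/core/test_hydra_cli.yml. The SumRunner class and the temporary config file below are hypothetical; only Runner, main, and the --hydra / --config-yml / --cpu flags come from this change.

from __future__ import annotations

import sys
import tempfile

from fairchem.core._cli import main
from fairchem.core.components.runner import Runner


class SumRunner(Runner):
    """Hypothetical Runner: adds the two numbers supplied by the hydra config."""

    def __init__(self, x: int, y: int):
        self.x = x
        self.y = y

    def run(self):
        return self.x + self.y

    def save_state(self) -> None:
        pass

    def load_state(self) -> None:
        pass


if __name__ == "__main__":
    # Hydra config targeting the runner above, written to a temp file so the
    # sketch is self-contained; _target_ resolves to the class in this module.
    config_yml = tempfile.NamedTemporaryFile("w", suffix=".yml", delete=False)
    config_yml.write("runner:\n  _target_: __main__.SumRunner\n  x: 10\n  y: 23\n")
    config_yml.close()

    # Same flags the new test passes: main() dispatches to _cli_hydra.main,
    # composes the config, and runs the runner locally on CPU (single rank).
    sys.argv[1:] = ["--hydra", "--config-yml", config_yml.name, "--cpu"]
    main()

The config can equally target fairchem.core.components.runner.MockRunner, exactly as the bundled tests/core/test_hydra_cli.yml does; any hydra override such as runner.x=5 can be appended to the command line, as the error-path tests show.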