Skip to content

Commit

Permalink
Add hydra entrypoint (#867)
Browse files Browse the repository at this point in the history
* format

* working relaxations on local, add submitit

* small fixes

* set logging

* add reqs

* add tests

* add tests

* move hydra package

* revert some changes to predict

* cpu tests

* add cleanup

* typo

* add check for dist initialized

* update test

* wrap feature flag for local mode as well
  • Loading branch information
rayg1234 authored Oct 18, 2024
1 parent 4164cc0 commit 8a9adbb
Show file tree
Hide file tree
Showing 10 changed files with 287 additions and 15 deletions.
3 changes: 2 additions & 1 deletion packages/fairchem-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ dependencies = [
"requests",
"orjson",
"tqdm",
"submitit"
"submitit",
"hydra-core"
]

[project.optional-dependencies] # add optional dependencies to be installed as pip install fairchem.core[dev]
Expand Down
14 changes: 8 additions & 6 deletions src/fairchem/core/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@

import copy
import logging
import os
from typing import TYPE_CHECKING

from submitit import AutoExecutor
from submitit.helpers import Checkpointable, DelayedSubmission
from torch.distributed.elastic.utils.distributed import get_free_port
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

from fairchem.core.common import distutils
from fairchem.core.common.flags import flags
from fairchem.core.common.utils import (
build_config,
Expand Down Expand Up @@ -69,6 +68,12 @@ def main(
parser: argparse.ArgumentParser = flags.get_parser()
args, override_args = parser.parse_known_args()

if args.hydra:
from fairchem.core._cli_hydra import main

main(args, override_args)
return

# TODO: rename num_gpus -> num_ranks everywhere
assert (
args.num_gpus > 0
Expand Down Expand Up @@ -126,10 +131,7 @@ def main(
logging.info(
"Running in local mode without elastic launch (single gpu only)"
)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
os.environ["MASTER_PORT"] = str(get_free_port())
distutils.setup_env_local()
runner_wrapper(config)


Expand Down
126 changes: 126 additions & 0 deletions src/fairchem/core/_cli_hydra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING

import hydra

if TYPE_CHECKING:
import argparse

from omegaconf import DictConfig

from submitit import AutoExecutor
from submitit.helpers import Checkpointable, DelayedSubmission
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

from fairchem.core.common import distutils
from fairchem.core.common.flags import flags
from fairchem.core.common.utils import get_timestamp_uid, setup_env_vars, setup_imports
from fairchem.core.components.runner import Runner

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Submitit(Checkpointable):
def __call__(self, dict_config: DictConfig, cli_args: argparse.Namespace) -> None:
self.config = dict_config
self.cli_args = cli_args
# TODO: setup_imports is not needed if we stop instantiating models with Registry.
setup_imports()
setup_env_vars()
try:
distutils.setup(map_cli_args_to_dist_config(cli_args))
runner: Runner = hydra.utils.instantiate(dict_config.runner)
runner.load_state()
runner.run()
finally:
distutils.cleanup()

def checkpoint(self, *args, **kwargs):
logging.info("Submitit checkpointing callback is triggered")
new_runner = Runner()
new_runner.save_state()
logging.info("Submitit checkpointing callback is completed")
return DelayedSubmission(new_runner, self.config)


def map_cli_args_to_dist_config(cli_args: argparse.Namespace) -> dict:
return {
"world_size": cli_args.num_nodes * cli_args.num_gpus,
"distributed_backend": "gloo" if cli_args.cpu else "nccl",
"submit": cli_args.submit,
"summit": None,
"cpu": cli_args.cpu,
"use_cuda_visibile_devices": True,
}


def get_hydra_config_from_yaml(
config_yml: str, overrides_args: list[str]
) -> DictConfig:
# Load the configuration from the file
os.environ["HYDRA_FULL_ERROR"] = "1"
config_directory = os.path.dirname(os.path.abspath(config_yml))
config_name = os.path.basename(config_yml)
hydra.initialize_config_dir(config_directory)
return hydra.compose(config_name=config_name, overrides=overrides_args)


def runner_wrapper(config: DictConfig, cli_args: argparse.Namespace):
Submitit()(config, cli_args)


# this is meant as a future replacement for the main entrypoint
def main(
args: argparse.Namespace | None = None, override_args: list[str] | None = None
):
if args is None:
parser: argparse.ArgumentParser = flags.get_parser()
args, override_args = parser.parse_known_args()

cfg = get_hydra_config_from_yaml(args.config_yml, override_args)
timestamp_id = get_timestamp_uid()
log_dir = os.path.join(args.run_dir, timestamp_id, "logs")
if args.submit: # Run on cluster
executor = AutoExecutor(folder=log_dir, slurm_max_num_timeout=3)
executor.update_parameters(
name=args.identifier,
mem_gb=args.slurm_mem,
timeout_min=args.slurm_timeout * 60,
slurm_partition=args.slurm_partition,
gpus_per_node=args.num_gpus,
cpus_per_task=8,
tasks_per_node=args.num_gpus,
nodes=args.num_nodes,
slurm_qos=args.slurm_qos,
slurm_account=args.slurm_account,
)
job = executor.submit(runner_wrapper, cfg, args)
logger.info(
f"Submitted job id: {timestamp_id}, slurm id: {job.job_id}, logs: {log_dir}"
)
else:
if args.num_gpus > 1:
logger.info(f"Running in local mode with {args.num_gpus} ranks")
launch_config = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=args.num_gpus,
rdzv_backend="c10d",
max_restarts=0,
)
elastic_launch(launch_config, runner_wrapper)(cfg, args)
else:
logger.info("Running in local mode without elastic launch")
distutils.setup_env_local()
runner_wrapper(cfg, args)
49 changes: 43 additions & 6 deletions src/fairchem/core/common/distutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

T = TypeVar("T")
DISTRIBUTED_PORT = 13356
CURRENT_DEVICE_STR = "CURRRENT_DEVICE"


def os_environ_get_or_throw(x: str) -> str:
Expand Down Expand Up @@ -72,7 +73,15 @@ def setup(config) -> None:
logging.info(
f"local rank: {config['local_rank']}, visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}"
)
torch.cuda.set_device(config["local_rank"])

# In the new hydra runners, we setup the device for each rank as either cuda:0 or cpu
# after this point, the local rank should either be using "cpu" or "cuda"
if config.get("use_cuda_visibile_devices"):
assign_device_for_local_rank(config["cpu"], config["local_rank"])
else:
# in the old code, all ranks can see all devices but need to be assigned a device equal to their local rank
# this is dangerous and should be deprecated
torch.cuda.set_device(config["local_rank"])

dist.init_process_group(
backend="nccl",
Expand Down Expand Up @@ -110,11 +119,10 @@ def setup(config) -> None:
assert (
config["world_size"] == 1
), "Can only setup master address and port at this point for a single rank, otherwise we assume the processes and the comm addr/port have already been setup"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(get_free_port())
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
setup_env_local()
config["local_rank"] = int(os.environ.get("LOCAL_RANK"))
if config.get("use_cuda_visibile_devices"):
assign_device_for_local_rank(config["cpu"], config["local_rank"])
dist.init_process_group(
backend=config["distributed_backend"],
rank=int(os.environ.get("RANK")),
Expand All @@ -124,7 +132,8 @@ def setup(config) -> None:


def cleanup() -> None:
dist.destroy_process_group()
if dist.is_initialized():
dist.destroy_process_group()


def initialized() -> bool:
Expand Down Expand Up @@ -210,3 +219,31 @@ def gather_objects(data: T, group: dist.ProcessGroup = dist.group.WORLD) -> list
output = [None for _ in range(get_world_size())] if is_master() else None
dist.gather_object(data, output, group=group, dst=0)
return output


def assign_device_for_local_rank(cpu: bool, local_rank: int):
if cpu:
os.environ[CURRENT_DEVICE_STR] = "cpu"
else:
# assert the cuda device to be the local rank
os.environ[CURRENT_DEVICE_STR] = "cuda"
os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank)


def get_device_for_local_rank():
cur_dev_env = os.environ.get(CURRENT_DEVICE_STR)
if cur_dev_env is not None:
return cur_dev_env
else:
device = "cuda" if torch.cuda.available() else "cpu"
logging.warn(
f"{CURRENT_DEVICE_STR} env variable not found, defaulting to {device}"
)
return device


def setup_env_local():
os.environ["MASTER_ADDR"] = "localhost"
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
os.environ["MASTER_PORT"] = str(get_free_port())
6 changes: 5 additions & 1 deletion src/fairchem/core/common/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def add_core_args(self) -> None:
self.parser.add_argument(
"--mode",
choices=["train", "predict", "run-relaxations", "validate"],
required=True,
help="Whether to train the model, make predictions, or to run relaxations",
)
self.parser.add_argument(
Expand Down Expand Up @@ -121,6 +120,11 @@ def add_core_args(self) -> None:
self.parser.add_argument(
"--cpu", action="store_true", help="Run CPU only training"
)
self.parser.add_argument(
"--hydra",
action="store_true",
help="Use hydra configs instead (in development)",
)
self.parser.add_argument(
"--num-nodes",
default=1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
self.traj_dir = traj_dir
self.traj_names = traj_names
self.early_stop_batch = early_stop_batch
self.otf_graph = model.model._unwrapped_model.otf_graph
self.otf_graph = True
assert not self.traj_dir or (
traj_dir and len(traj_names)
), "Trajectory names should be specified to save trajectories"
Expand Down
6 changes: 6 additions & 0 deletions src/fairchem/core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import ast
import collections
import copy
import datetime
import errno
import importlib
import itertools
Expand All @@ -26,6 +27,7 @@
from itertools import product
from pathlib import Path
from typing import TYPE_CHECKING, Any
from uuid import uuid4

import numpy as np
import torch
Expand Down Expand Up @@ -1446,3 +1448,7 @@ def load_model_and_weights_from_checkpoint(checkpoint_path: str) -> nn.Module:
matched_dict = match_state_dict(model.state_dict(), checkpoint["state_dict"])
load_state_dict(model, matched_dict, strict=True)
return model


def get_timestamp_uid() -> str:
return datetime.datetime.now().strftime("%Y%m-%d%H-%M%S-") + str(uuid4())[:4]
42 changes: 42 additions & 0 deletions src/fairchem/core/components/runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

from abc import ABCMeta, abstractmethod
from typing import Any


class Runner(metaclass=ABCMeta):
"""
Represents an abstraction over things that run in a loop and can save/load state.
ie: Trainers, Validators, Relaxation all fall in this category.
This allows us to decouple away from a monolithic trainer class
"""

@abstractmethod
def run(self) -> Any:
raise NotImplementedError

@abstractmethod
def save_state(self) -> None:
raise NotImplementedError

@abstractmethod
def load_state(self) -> None:
raise NotImplementedError


# Used for testing
class MockRunner(Runner):
def __init__(self, x: int, y: int):
self.x = x
self.y = y

def run(self) -> Any:
if self.x * self.y > 1000:
raise ValueError("sum is greater than 1000!")
return self.x + self.y

def save_state(self) -> None:
pass

def load_state(self) -> None:
pass
Loading

0 comments on commit 8a9adbb

Please sign in to comment.