diff --git a/src/fairchem/core/_cli.py b/src/fairchem/core/_cli.py
index f1270496a..47f5cb281 100644
--- a/src/fairchem/core/_cli.py
+++ b/src/fairchem/core/_cli.py
@@ -15,6 +15,7 @@
 from submitit.helpers import Checkpointable, DelayedSubmission
 from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 
+from fairchem.core.common.distutils import init_local_distributed_process_group
 from fairchem.core.common.flags import flags
 from fairchem.core.common.utils import (
     build_config,
@@ -94,7 +95,7 @@ def main():
         logging.info(f"Experiment log saved to: {log_file}")
 
     else:  # Run locally on a single node, n-processes
-        if args.distributed:
+        if args.num_gpus > 1:
             logging.info(
                 f"Running in distributed local mode with {args.num_gpus} ranks"
             )
@@ -116,10 +117,8 @@ def main():
             )
             elastic_launch(launch_config, runner_wrapper)(args.distributed, config)
         else:
-            logging.info("Running in non-distributed local mode")
-            assert (
-                args.num_gpus == 1
-            ), "Can only run with a single gpu in non distributed local mode, use --distributed flag instead if using >1 gpu"
+            logging.info("Running in local mode")
+            init_local_distributed_process_group(backend='nccl')
             runner_wrapper(args.distributed, config)
diff --git a/src/fairchem/core/models/escn/escn_exportable.py b/src/fairchem/core/models/escn/escn_exportable.py
index dfeb4f0b2..bccc936a4 100644
--- a/src/fairchem/core/models/escn/escn_exportable.py
+++ b/src/fairchem/core/models/escn/escn_exportable.py
@@ -268,10 +268,8 @@ def forward(self, data):
                     graph.edge_index,
                     wigner,
                 )
 
-                # Residual layer for all layers past the first
-                x_xessage = x_message + x_message_new
-
+                x_message = x_message + x_message_new
             else:
                 # No residual for the first layer
                 x_message = self.layer_blocks[i](
@@ -599,11 +597,11 @@ def forward(
         x_target = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid)
 
         # Rotate back the irreps
-        wigner_inv = torch.transpose(wigner, 1, 2).contiguous()
+        wigner_inv = torch.transpose(wigner, 1, 2).contiguous().detach()
         x_target = torch.bmm(wigner_inv[:, :, self.out_mask], x_target)
 
         # Compute the sum of the incoming neighboring messages for each target node
-        new_embedding = torch.fill(x.clone(), 0)
+        new_embedding = torch.zeros(x.shape, dtype=x_target.dtype, device=x_target.device)
         new_embedding.index_add_(0, edge_index[1], x_target)
 
         return new_embedding
diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py
index 6918d45fc..7931785ef 100644
--- a/src/fairchem/core/trainers/base_trainer.py
+++ b/src/fairchem/core/trainers/base_trainer.py
@@ -551,11 +551,10 @@ def load_model(self) -> None:
                 device_ids=None if self.cpu else [self.device],
             )
 
-        torch._dynamo.config.optimize_ddp = False
-        torch._dynamo.config.assume_static_by_default = False
-        torch._dynamo.config.automatic_dynamic_shapes = True
-
         if self.config["optim"].get("compiles"):
+            torch._dynamo.config.optimize_ddp = False
+            torch._dynamo.config.assume_static_by_default = False
+            torch._dynamo.config.automatic_dynamic_shapes = True
             os.environ["TORCH_LOGS"] = "recompiles"
             self.model = torch.compile(self.model, dynamic=True)
             torch._dynamo.config.optimize_ddp = False
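Note on the base_trainer.py hunk above: the torch._dynamo settings are now applied only when compilation is requested, so eager-mode runs are left untouched. A minimal sketch of the resulting control flow follows; the helper name maybe_compile and its boolean argument are illustrative, only the config values come from the diff itself.

import os

import torch


def maybe_compile(model: torch.nn.Module, compiles: bool) -> torch.nn.Module:
    """Apply dynamo settings and torch.compile only when compilation is enabled."""
    if compiles:
        # Mirror the flags set inside the `if self.config["optim"].get("compiles")` branch.
        torch._dynamo.config.optimize_ddp = False
        torch._dynamo.config.assume_static_by_default = False
        torch._dynamo.config.automatic_dynamic_shapes = True
        os.environ["TORCH_LOGS"] = "recompiles"  # surface recompilation events in the logs
        model = torch.compile(model, dynamic=True)
    return model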
diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py
index 0ced35bef..f9c0ecdbe 100644
--- a/src/fairchem/core/trainers/ocp_trainer.py
+++ b/src/fairchem/core/trainers/ocp_trainer.py
@@ -12,6 +12,7 @@
 from collections import defaultdict
 from itertools import chain
 from typing import TYPE_CHECKING
+import time
 
 import numpy as np
 import torch
@@ -139,6 +140,7 @@ def train(self, disable_eval_tqdm: bool = False) -> None:
         # Calculate start_epoch from step instead of loading the epoch number
         # to prevent inconsistencies due to different batch size in checkpoint.
         start_epoch = self.step // len(self.train_loader)
+        previous_wall_time = time.time()
 
         for epoch_int in range(start_epoch, self.config["optim"]["max_epochs"]):
             skip_steps = self.step % len(self.train_loader)
@@ -182,6 +184,9 @@ def train(self, disable_eval_tqdm: bool = False) -> None:
                     self.step % self.config["cmd"]["print_every"] == 0
                     and distutils.is_master()
                 ):
+                    time_delta = time.time() - previous_wall_time
+                    previous_wall_time = time.time()
+                    log_dict.update({'step_per_s' : self.config["cmd"]["print_every"] / time_delta})
                     log_str = [f"{k}: {v:.2e}" for k, v in log_dict.items()]
                     logging.info(", ".join(log_str))
                     self.metrics = {}
diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py
index e653ef405..96d3c0e04 100644
--- a/tests/core/models/test_escn_compiles.py
+++ b/tests/core/models/test_escn_compiles.py
@@ -85,7 +85,7 @@ def init(backend: str):
     init_local_distributed_process_group(backend=backend)
 
 class TestESCNCompiles:
-    def test_escn_baseline_cpu(self, tol=1e-5):
+    def test_escn_baseline_cpu(self, tol=1e-8):
         init('gloo')
         data = load_data()
         data = data_list_collater([data])
@@ -98,7 +98,7 @@ def test_escn_baseline_cpu(self, tol=1e-5):
         assert torch.allclose(base_output["forces"].mean(0), export_output["forces"].mean(0), atol=tol)
 
     @skip_if_no_cuda
-    def test_escn_baseline_cuda(self, tol=1e-5):
+    def test_escn_baseline_cuda(self, tol=1e-8):
         init('nccl')
         data = load_data()
         data = data_list_collater([data]).to("cuda")
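Note on the escn_exportable.py aggregation change: the zero buffer is now allocated explicitly with the dtype and device of the rotated messages before the index_add_ scatter, instead of via torch.fill(x.clone(), 0). A minimal, self-contained sketch of that pattern; shapes and tensor contents are illustrative, only the variable names follow the diff.

import torch

# Illustrative sizes; in the model these come from the graph and the spherical basis.
num_nodes, num_edges, channels = 4, 6, 8

x = torch.randn(num_nodes, channels)          # node embeddings (only its shape is reused)
x_target = torch.randn(num_edges, channels)   # per-edge messages after rotating back the irreps
edge_index = torch.randint(0, num_nodes, (2, num_edges))  # row 1 holds each edge's target node

# Allocate zeros matching the message dtype/device, then scatter-add messages onto their target nodes.
new_embedding = torch.zeros(x.shape, dtype=x_target.dtype, device=x_target.device)
new_embedding.index_add_(0, edge_index[1], x_target)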