compile works
rayg1234 committed Sep 2, 2024
1 parent 2250fa6 commit c0d8e41
Showing 5 changed files with 17 additions and 16 deletions.
9 changes: 4 additions & 5 deletions src/fairchem/core/_cli.py
@@ -15,6 +15,7 @@
 from submitit.helpers import Checkpointable, DelayedSubmission
 from torch.distributed.launcher.api import LaunchConfig, elastic_launch

+from fairchem.core.common.distutils import init_local_distributed_process_group
 from fairchem.core.common.flags import flags
 from fairchem.core.common.utils import (
     build_config,
@@ -94,7 +95,7 @@ def main():
         logging.info(f"Experiment log saved to: {log_file}")

     else:  # Run locally on a single node, n-processes
-        if args.distributed:
+        if args.num_gpus > 1:
             logging.info(
                 f"Running in distributed local mode with {args.num_gpus} ranks"
             )
@@ -116,10 +117,8 @@ def main():
             )
             elastic_launch(launch_config, runner_wrapper)(args.distributed, config)
         else:
-            logging.info("Running in non-distributed local mode")
-            assert (
-                args.num_gpus == 1
-            ), "Can only run with a single gpu in non distributed local mode, use --distributed flag instead if using >1 gpu"
+            logging.info("Running in local mode")
+            init_local_distributed_process_group(backend='nccl')
             runner_wrapper(args.distributed, config)


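In the local non-distributed branch, the CLI no longer asserts `num_gpus == 1`; it now brings up a single-rank process group via `init_local_distributed_process_group(backend='nccl')` so the same distributed code paths run locally. A minimal sketch of what such a helper typically does, assuming the standard torch.distributed environment-variable rendezvous (the real implementation in fairchem.core.common.distutils may choose ports and backends differently):

    import os
    import torch.distributed as dist

    def init_local_distributed_process_group(backend: str = "nccl") -> None:
        # Hypothetical sketch: rendezvous on localhost with a single rank so
        # DDP-style code runs without torchrun/elastic_launch.
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group(backend=backend, rank=0, world_size=1)
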
8 changes: 3 additions & 5 deletions src/fairchem/core/models/escn/escn_exportable.py
@@ -268,10 +268,8 @@ def forward(self, data):
                     graph.edge_index,
                     wigner,
                 )
-
                 # Residual layer for all layers past the first
-                x_xessage = x_message + x_message_new
-
+                x_message = x_message + x_message_new
             else:
                 # No residual for the first layer
                 x_message = self.layer_blocks[i](
@@ -599,11 +597,11 @@ def forward(
         x_target = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid)

         # Rotate back the irreps
-        wigner_inv = torch.transpose(wigner, 1, 2).contiguous()
+        wigner_inv = torch.transpose(wigner, 1, 2).contiguous().detach()
         x_target = torch.bmm(wigner_inv[:, :, self.out_mask], x_target)

         # Compute the sum of the incoming neighboring messages for each target node
-        new_embedding = torch.fill(x.clone(), 0)
+        new_embedding = torch.zeros(x.shape, dtype=x_target.dtype, device=x_target.device)
         new_embedding.index_add_(0, edge_index[1], x_target)

         return new_embedding
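Two export-oriented tweaks here: the inverse Wigner rotation is detached so it is treated as a constant, and the message accumulator is allocated with torch.zeros (explicit shape, dtype, and device) rather than torch.fill(x.clone(), 0), which traces more cleanly. A toy illustration of the index_add_ aggregation pattern, with made-up shapes rather than the model's real tensors:

    import torch

    num_nodes, num_edges, dim = 4, 6, 3
    messages = torch.randn(num_edges, dim)              # one message per edge
    target = torch.randint(0, num_nodes, (num_edges,))  # analogue of edge_index[1]

    # Zero-initialized accumulator with explicit dtype/device, then scatter-add
    # each edge message into its target node's row.
    out = torch.zeros((num_nodes, dim), dtype=messages.dtype, device=messages.device)
    out.index_add_(0, target, messages)
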
7 changes: 3 additions & 4 deletions src/fairchem/core/trainers/base_trainer.py
@@ -551,11 +551,10 @@ def load_model(self) -> None:
                 device_ids=None if self.cpu else [self.device],
             )

-        torch._dynamo.config.optimize_ddp = False
-        torch._dynamo.config.assume_static_by_default = False
-        torch._dynamo.config.automatic_dynamic_shapes = True
-
         if self.config["optim"].get("compiles"):
+            torch._dynamo.config.optimize_ddp = False
+            torch._dynamo.config.assume_static_by_default = False
+            torch._dynamo.config.automatic_dynamic_shapes = True
             os.environ["TORCH_LOGS"] = "recompiles"
             self.model = torch.compile(self.model, dynamic=True)
             torch._dynamo.config.optimize_ddp = False
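
With this change the TorchDynamo settings are applied only when compilation is actually requested in the optimizer config. A rough standalone sketch of the same gating pattern, using a hypothetical `enable` flag in place of self.config["optim"].get("compiles"):

    import os
    import torch

    def maybe_compile(model: torch.nn.Module, enable: bool) -> torch.nn.Module:
        # Sketch: leave dynamo config untouched unless compilation is enabled.
        if not enable:
            return model
        torch._dynamo.config.optimize_ddp = False             # sidestep DDP graph splitting
        torch._dynamo.config.assume_static_by_default = False
        torch._dynamo.config.automatic_dynamic_shapes = True  # let varying dims stay dynamic
        os.environ["TORCH_LOGS"] = "recompiles"               # log what triggers recompiles
        return torch.compile(model, dynamic=True)
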
5 changes: 5 additions & 0 deletions src/fairchem/core/trainers/ocp_trainer.py
@@ -12,6 +12,7 @@
 from collections import defaultdict
 from itertools import chain
 from typing import TYPE_CHECKING
+import time

 import numpy as np
 import torch
@@ -139,6 +140,7 @@ def train(self, disable_eval_tqdm: bool = False) -> None:
         # Calculate start_epoch from step instead of loading the epoch number
         # to prevent inconsistencies due to different batch size in checkpoint.
         start_epoch = self.step // len(self.train_loader)
+        previous_wall_time = time.time()

         for epoch_int in range(start_epoch, self.config["optim"]["max_epochs"]):
             skip_steps = self.step % len(self.train_loader)
@@ -182,6 +184,9 @@ def train(self, disable_eval_tqdm: bool = False) -> None:
                     self.step % self.config["cmd"]["print_every"] == 0
                     and distutils.is_master()
                 ):
+                    time_delta = time.time() - previous_wall_time
+                    previous_wall_time = time.time()
+                    log_dict.update({'step_per_s' : self.config["cmd"]["print_every"] / time_delta})
                     log_str = [f"{k}: {v:.2e}" for k, v in log_dict.items()]
                     logging.info(", ".join(log_str))
                     self.metrics = {}
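
The new `step_per_s` entry is just the logging interval divided by the wall time since the previous log line: with print_every = 100 and 20 seconds between prints, the logged value is 100 / 20 = 5 steps/s. A compact sketch of the same bookkeeping outside the trainer (names here are illustrative):

    import time

    print_every = 100
    previous_wall_time = time.time()

    for step in range(1, 1001):
        time.sleep(0.01)  # stand-in for one training step
        if step % print_every == 0:
            time_delta = time.time() - previous_wall_time
            previous_wall_time = time.time()
            print(f"step_per_s: {print_every / time_delta:.2f}")
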
4 changes: 2 additions & 2 deletions tests/core/models/test_escn_compiles.py
@@ -85,7 +85,7 @@ def init(backend: str):
     init_local_distributed_process_group(backend=backend)

 class TestESCNCompiles:
-    def test_escn_baseline_cpu(self, tol=1e-5):
+    def test_escn_baseline_cpu(self, tol=1e-8):
         init('gloo')
         data = load_data()
         data = data_list_collater([data])
@@ -98,7 +98,7 @@ def test_escn_baseline_cpu(self, tol=1e-5):
         assert torch.allclose(base_output["forces"].mean(0), export_output["forces"].mean(0), atol=tol)

     @skip_if_no_cuda
-    def test_escn_baseline_cuda(self, tol=1e-5):
+    def test_escn_baseline_cuda(self, tol=1e-8):
         init('nccl')
         data = load_data()
         data = data_list_collater([data]).to("cuda")
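
The baseline comparisons between the eager and exportable eSCN models now use a much tighter absolute tolerance (1e-8 instead of 1e-5). As a reminder of what that asserts, torch.allclose checks |a - b| <= atol + rtol * |b| elementwise, so the atol term now sits at roughly float32 round-off (the default rtol still applies unless overridden). A tiny illustration with made-up tensors:

    import torch

    a = torch.tensor([1.0, 2.0, 3.0])
    b = a + 5e-9   # within 1e-8 of a
    c = a + 5e-6   # within 1e-5 of a, but not within 1e-8

    # rtol is set to 0 here to isolate the absolute-tolerance check.
    assert torch.allclose(a, b, atol=1e-8, rtol=0)
    assert torch.allclose(a, c, atol=1e-5, rtol=0)
    assert not torch.allclose(a, c, atol=1e-8, rtol=0)
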
