
Commit

Merge branch 'main' into single-ruff-config
lbluque authored Jul 11, 2024
2 parents 5b6b42e + 51a439e commit ba248cb
Showing 6 changed files with 90 additions and 32 deletions.
12 changes: 4 additions & 8 deletions configs/odac/s2ef/eqv2_153M.yml
@@ -45,8 +45,8 @@ model:

  weight_init: 'uniform' # ['uniform', 'normal']

-  norm_scale_nodes: 192.561
-  norm_scale_degree: 21.024127419363214
+  avg_num_nodes: 192.561
+  avg_degree: 21.024127419363214

optim:
  batch_size: 1
@@ -59,12 +59,8 @@ optim:
  optimizer: AdamW
  optimizer_params:
    weight_decay: 0.3
-  scheduler: LambdaLR
-  scheduler_params:
-    lambda_type: cosine
-    warmup_factor: 0.2
-    warmup_epochs: 0.01
-    lr_min_factor: 0.01
+  scheduler: CosineAnnealingLR
+  T_max: 1600000

  max_epochs: 1
  clip_grad_norm: 100
8 changes: 2 additions & 6 deletions configs/odac/s2ef/eqv2_31M.yml
@@ -67,12 +67,8 @@ optim:
  optimizer: AdamW
  optimizer_params:
    weight_decay: 0.3
-  scheduler: LambdaLR
-  scheduler_params:
-    lambda_type: cosine
-    warmup_factor: 0.2
-    warmup_epochs: 0.01
-    lr_min_factor: 0.01
+  scheduler: CosineAnnealingLR
+  T_max: 1600000

  max_epochs: 3
  clip_grad_norm: 100
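Both ODAC EquiformerV2 S2EF configs replace the hand-rolled cosine LambdaLR schedule (and its warmup/lr_min factors) with a CosineAnnealingLR schedule and T_max: 1600000. A minimal sketch of what cosine annealing does in plain PyTorch, outside the fairchem trainer's scheduler wiring; the model, learning rate, and loop length are illustrative, and only weight_decay is taken from the config:

```python
import torch

# Toy model/optimizer standing in for the ones fairchem builds from the config.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4, weight_decay=0.3)

# CosineAnnealingLR decays the learning rate along a cosine curve over T_max steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1_600_000)

for step in range(3):
    optimizer.zero_grad()
    model(torch.randn(4, 8)).sum().backward()
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())
```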
10 changes: 10 additions & 0 deletions src/fairchem/core/common/logger.py
@@ -56,6 +56,9 @@ def mark_preempting(self) -> None:
    def log_summary(self, summary_dict: dict[str, Any]) -> None:
        pass

+    @abstractmethod
+    def log_artifact(self, name: str, type: str, file_location: str) -> None:
+        pass

@registry.register_logger("wandb")
class WandBLogger(Logger):
@@ -101,6 +104,10 @@ def log_summary(self, summary_dict: dict[str, Any]):
    def mark_preempting(self) -> None:
        wandb.mark_preempting()

+    def log_artifact(self, name: str, type: str, file_location: str) -> None:
+        art = wandb.Artifact(name=name, type=type)
+        art.add_file(file_location)
+        art.save()

@registry.register_logger("tensorboard")
class TensorboardLogger(Logger):
@@ -130,3 +137,6 @@ def log_plots(self, plots) -> None:

    def log_summary(self, summary_dict: dict[str, Any]) -> None:
        logging.warning("log_summary for Tensorboard not supported")
+
+    def log_artifact(self, name: str, type: str, file_location: str) -> None:
+        logging.warning("log_artifact for Tensorboard not supported")
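Every Logger now exposes log_artifact: WandBLogger uploads the file as a W&B artifact, while TensorboardLogger only logs a warning. A standalone sketch of the underlying wandb calls that WandBLogger.log_artifact wraps; the project name, artifact name, and file path are made up for the example:

```python
from pathlib import Path

import wandb

# Create a small file to upload; in practice this would be e.g. a profiler trace.
trace_path = Path("/tmp/example_trace.json")
trace_path.write_text("{}")

wandb.init(project="example-project")

# Same calls as WandBLogger.log_artifact(name=..., type=..., file_location=...).
art = wandb.Artifact(name="example_trace", type="profile")
art.add_file(str(trace_path))
art.save()

wandb.finish()
```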
50 changes: 50 additions & 0 deletions src/fairchem/core/common/profiler_utils.py
@@ -0,0 +1,50 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING

import torch

from fairchem.core.common import distutils

if TYPE_CHECKING:
    from fairchem.core.common.logger import Logger

def get_default_profiler_handler(run_id: str, output_dir: str, logger: Logger):
    """Get a standard callback handle for the pytorch profiler"""

    def trace_handler(p):
        if distutils.is_master():
            trace_name = f"{run_id}_rank_{distutils.get_rank()}.pt.trace.json"
            output_path = os.path.join(output_dir, trace_name)
            print(f"Saving trace in {output_path}")
            p.export_chrome_trace(output_path)
            if logger:
                logger.log_artifact(name=trace_name, type="profile", file_location=output_path)
    return trace_handler

def get_profile_schedule(wait: int = 5, warmup: int = 5, active: int = 2):
    """Get a profile schedule and total number of steps to run
    check pytorch docs on the meaning of these parameters:
    https://pytorch.org/docs/stable/profiler.html#torch.profiler.schedule
    Example usage:
    ```
    trace_handler = get_default_profiler_handler(run_id = self.config["cmd"]["timestamp_id"],
                                                 output_dir = self.config["cmd"]["results_dir"],
                                                 logger = self.logger)
    profile_schedule, total_profile_steps = get_profile_schedule()
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=profile_schedule,
        on_trace_ready=trace_handler
    ) as p:
        for i in steps:
            <code block to profile>
            if i < total_profile_steps:
                p.step()
    ```
    """
    total_profile_steps = wait + warmup + active
    profile_schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active)

    return profile_schedule, total_profile_steps
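The docstring's pseudo-snippet, filled in as a standalone sketch: the workload, run id, and output directory below are made up, logger is left as None so no artifact is uploaded, and a single non-distributed process is assumed so distutils.is_master() evaluates to true.

```python
import torch
from torch.profiler import ProfilerActivity, profile

from fairchem.core.common.profiler_utils import (
    get_default_profiler_handler,
    get_profile_schedule,
)

# Hypothetical run id / output dir; inside the trainer these come from self.config.
trace_handler = get_default_profiler_handler(
    run_id="example_run", output_dir="/tmp", logger=None
)
profile_schedule, total_profile_steps = get_profile_schedule(wait=1, warmup=1, active=2)

model = torch.nn.Linear(128, 128)
x = torch.randn(64, 128)

with profile(
    activities=[ProfilerActivity.CPU],  # add ProfilerActivity.CUDA on GPU machines
    schedule=profile_schedule,
    on_trace_ready=trace_handler,
) as p:
    for i in range(10):
        model(x).sum().backward()
        if i < total_profile_steps:
            p.step()
```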
4 changes: 2 additions & 2 deletions src/fairchem/core/models/pretrained_models.yml
@@ -60,8 +60,8 @@
"PaiNN-S2EF-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/PaiNN.pt"
"GemNet-OC-S2EF-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/Gemnet-OC.pt"
"eSCN-S2EF-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/eSCN.pt"
"EquiformerV2-S2EF-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20231116/eqv2_31M.pt"
"EquiformerV2-Large-S2EF-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/Equiformer_V2_Large.pt"
"EquiformerV2-S2EF-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20240709/eqv2_31M.pt"
"EquiformerV2-Large-S2EF-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20240709/Equiformer_V2_Large.pt"
"Gemnet-OC-IS2RE-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/Gemnet-OC_Direct.pt"
"eSCN-IS2RE-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/eSCN_Direct.pt"
"EquiformerV2-IS2RE-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/Equiformer_V2_Direct.pt"
38 changes: 22 additions & 16 deletions src/fairchem/core/trainers/base_trainer.py
@@ -162,6 +162,7 @@ def __init__(
self.config["dataset"] = dataset.get("train", None)
self.config["val_dataset"] = dataset.get("val", None)
self.config["test_dataset"] = dataset.get("test", None)
self.config["relax_dataset"] = dataset.get("relax", None)
else:
self.config["dataset"] = dataset

@@ -339,22 +340,27 @@ def load_datasets(self) -> None:
                self.test_sampler,
            )

-        # load relaxation dataset
-        if "relax_dataset" in self.config["task"]:
-            self.relax_dataset = registry.get_dataset_class("lmdb")(
-                self.config["task"]["relax_dataset"]
-            )
-            self.relax_sampler = self.get_sampler(
-                self.relax_dataset,
-                self.config["optim"].get(
-                    "eval_batch_size", self.config["optim"]["batch_size"]
-                ),
-                shuffle=False,
-            )
-            self.relax_loader = self.get_dataloader(
-                self.relax_dataset,
-                self.relax_sampler,
-            )
+        if self.config.get("relax_dataset", None):
+            if self.config["relax_dataset"].get("use_train_settings", True):
+                relax_config = self.config["dataset"].copy()
+                relax_config.update(self.config["relax_dataset"])
+            else:
+                relax_config = self.config["relax_dataset"]
+
+            self.relax_dataset = registry.get_dataset_class(
+                relax_config.get("format", "lmdb")
+            )(relax_config)
+            self.relax_sampler = self.get_sampler(
+                self.relax_dataset,
+                self.config["optim"].get(
+                    "eval_batch_size", self.config["optim"]["batch_size"]
+                ),
+                shuffle=False,
+            )
+            self.relax_loader = self.get_dataloader(
+                self.relax_dataset,
+                self.relax_sampler,
+            )

    def load_task(self):
        # Normalizer for the dataset.
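load_datasets now picks up the relaxation dataset from the top-level dataset dict (populated from the "relax" key in __init__) instead of task.relax_dataset, and by default copies the train dataset settings before applying any relax-specific overrides. A hypothetical dataset dict illustrating the new layout; the paths and format are made up:

```python
# Hypothetical dataset dict passed to the trainer; paths and format are illustrative.
dataset = {
    "train": {"format": "lmdb", "src": "/data/odac/s2ef/train"},
    "val": {"format": "lmdb", "src": "/data/odac/s2ef/val"},
    # With use_train_settings (default True) the relax config starts from a copy
    # of the train config and only overrides the keys given here.
    "relax": {"src": "/data/odac/is2rs/relax", "use_train_settings": True},
}
```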
