From 06b5d66823278cfc92ced4704933a997055dbcb1 Mon Sep 17 00:00:00 2001 From: anuroopsriram Date: Tue, 9 Jul 2024 16:06:48 -0700 Subject: [PATCH 1/3] Updated ODAC checkpoints & configs (#755) * Updated ODAC configs * Updated ODAC checkpoints --- configs/odac/s2ef/eqv2_153M.yml | 12 ++++-------- configs/odac/s2ef/eqv2_31M.yml | 8 ++------ src/fairchem/core/models/pretrained_models.yml | 4 ++-- 3 files changed, 8 insertions(+), 16 deletions(-) diff --git a/configs/odac/s2ef/eqv2_153M.yml b/configs/odac/s2ef/eqv2_153M.yml index 75f51fd89..ef5039aac 100755 --- a/configs/odac/s2ef/eqv2_153M.yml +++ b/configs/odac/s2ef/eqv2_153M.yml @@ -45,8 +45,8 @@ model: weight_init: 'uniform' # ['uniform', 'normal'] - norm_scale_nodes: 192.561 - norm_scale_degree: 21.024127419363214 + avg_num_nodes: 192.561 + avg_degree: 21.024127419363214 optim: batch_size: 1 @@ -59,12 +59,8 @@ optim: optimizer: AdamW optimizer_params: weight_decay: 0.3 - scheduler: LambdaLR - scheduler_params: - lambda_type: cosine - warmup_factor: 0.2 - warmup_epochs: 0.01 - lr_min_factor: 0.01 + scheduler: CosineAnnealingLR + T_max: 1600000 max_epochs: 1 clip_grad_norm: 100 diff --git a/configs/odac/s2ef/eqv2_31M.yml b/configs/odac/s2ef/eqv2_31M.yml index 94d00d336..a557368e8 100644 --- a/configs/odac/s2ef/eqv2_31M.yml +++ b/configs/odac/s2ef/eqv2_31M.yml @@ -67,12 +67,8 @@ optim: optimizer: AdamW optimizer_params: weight_decay: 0.3 - scheduler: LambdaLR - scheduler_params: - lambda_type: cosine - warmup_factor: 0.2 - warmup_epochs: 0.01 - lr_min_factor: 0.01 + scheduler: CosineAnnealingLR + T_max: 1600000 max_epochs: 3 clip_grad_norm: 100 diff --git a/src/fairchem/core/models/pretrained_models.yml b/src/fairchem/core/models/pretrained_models.yml index e3f31500f..9a5cda1bd 100644 --- a/src/fairchem/core/models/pretrained_models.yml +++ b/src/fairchem/core/models/pretrained_models.yml @@ -60,8 +60,8 @@ "PaiNN-S2EF-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/PaiNN.pt" "GemNet-OC-S2EF-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/Gemnet-OC.pt" "eSCN-S2EF-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/eSCN.pt" -"EquiformerV2-S2EF-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20231116/eqv2_31M.pt" -"EquiformerV2-Large-S2EF-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/Equiformer_V2_Large.pt" +"EquiformerV2-S2EF-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20240709/eqv2_31M.pt" +"EquiformerV2-Large-S2EF-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20240709/Equiformer_V2_Large.pt" "Gemnet-OC-IS2RE-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/Gemnet-OC_Direct.pt" "eSCN-IS2RE-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/eSCN_Direct.pt" "EquiformerV2-IS2RE-ODAC": "https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/Equiformer_V2_Direct.pt" From 712e723076d1976a9643b2de44f5adf459bdb978 Mon Sep 17 00:00:00 2001 From: rayg1234 <7001989+rayg1234@users.noreply.github.com> Date: Wed, 10 Jul 2024 15:59:41 -0700 Subject: [PATCH 2/3] Add utils to help run torch profiling (#754) * add profiler to ocp trainer * add profiler utils * type check lint * fix typo * remove it from ocp_trainer * add comment --- src/fairchem/core/common/logger.py | 10 +++++ src/fairchem/core/common/profiler_utils.py | 50 ++++++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 src/fairchem/core/common/profiler_utils.py diff --git a/src/fairchem/core/common/logger.py b/src/fairchem/core/common/logger.py index 
fd52756e2..c1d501920 100644 --- a/src/fairchem/core/common/logger.py +++ b/src/fairchem/core/common/logger.py @@ -56,6 +56,9 @@ def mark_preempting(self) -> None: def log_summary(self, summary_dict: dict[str, Any]) -> None: pass + @abstractmethod + def log_artifact(self, name: str, type: str, file_location: str) -> None: + pass @registry.register_logger("wandb") class WandBLogger(Logger): @@ -101,6 +104,10 @@ def log_summary(self, summary_dict: dict[str, Any]): def mark_preempting(self) -> None: wandb.mark_preempting() + def log_artifact(self, name: str, type: str, file_location: str) -> None: + art = wandb.Artifact(name=name, type=type) + art.add_file(file_location) + art.save() @registry.register_logger("tensorboard") class TensorboardLogger(Logger): @@ -130,3 +137,6 @@ def log_plots(self, plots) -> None: def log_summary(self, summary_dict: dict[str, Any]) -> None: logging.warning("log_summary for Tensorboard not supported") + + def log_artifact(self, name: str, type: str, file_location: str) -> None: + logging.warning("log_artifact for Tensorboard not supported") diff --git a/src/fairchem/core/common/profiler_utils.py b/src/fairchem/core/common/profiler_utils.py new file mode 100644 index 000000000..0828cb673 --- /dev/null +++ b/src/fairchem/core/common/profiler_utils.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +import torch + +from fairchem.core.common import distutils + +if TYPE_CHECKING: + from fairchem.core.common.logger import Logger + +def get_default_profiler_handler(run_id: str, output_dir: str, logger: Logger): + """Get a standard callback handle for the pytorch profiler""" + + def trace_handler(p): + if distutils.is_master(): + trace_name = f"{run_id}_rank_{distutils.get_rank()}.pt.trace.json" + output_path = os.path.join(output_dir, trace_name) + print(f"Saving trace in {output_path}") + p.export_chrome_trace(output_path) + if logger: + logger.log_artifact(name=trace_name, type="profile", file_location=output_path) + return trace_handler + +def get_profile_schedule(wait: int = 5, warmup: int = 5, active: int = 2): + """Get a profile schedule and total number of steps to run + check pytorch docs on the meaning of these paramters: + https://pytorch.org/docs/stable/profiler.html#torch.profiler.schedule + Example usage: + ``` + trace_handler = get_default_profiler_handler(run_id = self.config["cmd"]["timestamp_id"], + output_dir = self.config["cmd"]["results_dir"], + logger = self.logger) + profile_schedule, total_profile_steps = get_profile_schedule() + + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=profile_schedule, + on_trace_ready=trace_handler + ) as p: + for i in steps: + + if i < total_profile_steps: + p.step() + """ + total_profile_steps = wait + warmup + active + profile_schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active) + + return profile_schedule, total_profile_steps From 51a439ea227c2efc72dcac03d6c26a543eddd75b Mon Sep 17 00:00:00 2001 From: anuroopsriram Date: Wed, 10 Jul 2024 16:16:04 -0700 Subject: [PATCH 3/3] Make relaxation data more general (#714) * Changes to relaxation dataset * Changes to relaxation dataset --------- Co-authored-by: Muhammed Shuaibi <45150244+mshuaibii@users.noreply.github.com> Co-authored-by: Luis Barroso-Luque --- src/fairchem/core/trainers/base_trainer.py | 38 +++++++++++++--------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/src/fairchem/core/trainers/base_trainer.py 
b/src/fairchem/core/trainers/base_trainer.py index 0da40320a..92952b805 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -162,6 +162,7 @@ def __init__( self.config["dataset"] = dataset.get("train", None) self.config["val_dataset"] = dataset.get("val", None) self.config["test_dataset"] = dataset.get("test", None) + self.config["relax_dataset"] = dataset.get("relax", None) else: self.config["dataset"] = dataset @@ -339,22 +340,27 @@ def load_datasets(self) -> None: self.test_sampler, ) - # load relaxation dataset - if "relax_dataset" in self.config["task"]: - self.relax_dataset = registry.get_dataset_class("lmdb")( - self.config["task"]["relax_dataset"] - ) - self.relax_sampler = self.get_sampler( - self.relax_dataset, - self.config["optim"].get( - "eval_batch_size", self.config["optim"]["batch_size"] - ), - shuffle=False, - ) - self.relax_loader = self.get_dataloader( - self.relax_dataset, - self.relax_sampler, - ) + if self.config.get("relax_dataset", None): + if self.config["relax_dataset"].get("use_train_settings", True): + relax_config = self.config["dataset"].copy() + relax_config.update(self.config["relax_dataset"]) + else: + relax_config = self.config["relax_dataset"] + + self.relax_dataset = registry.get_dataset_class( + relax_config.get("format", "lmdb") + )(relax_config) + self.relax_sampler = self.get_sampler( + self.relax_dataset, + self.config["optim"].get( + "eval_batch_size", self.config["optim"]["batch_size"] + ), + shuffle=False, + ) + self.relax_loader = self.get_dataloader( + self.relax_dataset, + self.relax_sampler, + ) def load_task(self): # Normalizer for the dataset.
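
A note on the scheduler change in patch 1/3 above: the ODAC EquiformerV2 configs move from a hand-rolled warmup+cosine `LambdaLR` to PyTorch's built-in `CosineAnnealingLR` annealed over 1,600,000 steps. The sketch below shows roughly what the new setting corresponds to in plain PyTorch; it assumes the trainer forwards `T_max` to `torch.optim.lr_scheduler.CosineAnnealingLR`, and the model and learning rate here are placeholders rather than values taken from the configs.

```python
import torch

# Stand-in model and optimizer; lr is a placeholder, weight_decay mirrors the configs above.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.3)

# Old configs: LambdaLR with warmup_epochs/warmup_factor/lr_min_factor driving a cosine lambda.
# New configs: a single cosine annealing schedule over 1.6M optimizer steps, with no explicit warmup.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1_600_000)

for _ in range(3):  # placeholder training loop: step the schedule once per optimizer step
    optimizer.step()
    scheduler.step()
```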
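
Patch 2/3 adds `profiler_utils.py`, whose docstring example assumes a trainer context. A more self-contained usage sketch follows; `run_id` and `output_dir` are placeholder values, `logger=None` simply skips the artifact upload, and it assumes a single non-distributed process so `distutils.is_master()` holds.

```python
import torch
from torch.profiler import ProfilerActivity, profile

from fairchem.core.common.profiler_utils import (
    get_default_profiler_handler,
    get_profile_schedule,
)

# Chrome traces land in output_dir as "<run_id>_rank_<rank>.pt.trace.json".
trace_handler = get_default_profiler_handler(run_id="debug", output_dir="/tmp", logger=None)
profile_schedule, total_profile_steps = get_profile_schedule(wait=1, warmup=1, active=2)

model = torch.nn.Linear(16, 16)  # stand-in workload
x = torch.randn(4, 16)

with profile(
    activities=[ProfilerActivity.CPU],  # add ProfilerActivity.CUDA when a GPU is available
    schedule=profile_schedule,
    on_trace_ready=trace_handler,
) as p:
    for _ in range(total_profile_steps):
        model(x)
        p.step()  # advance the wait/warmup/active schedule once per iteration
```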
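
Patch 3/3 generalizes the relaxation dataset: `relax` now lives under the top-level `dataset` entry instead of `task.relax_dataset`, and by default it inherits the train dataset settings unless `use_train_settings: false` is set. Below is a small sketch of that merge behaviour, using placeholder paths and keys rather than a real ODAC config.

```python
# Placeholder config mirroring the new layout: relax sits next to train/val/test.
dataset = {
    "train": {"format": "lmdb", "src": "data/train/", "key_mapping": {"y": "energy"}},
    "val": {"src": "data/val/"},
    "relax": {"src": "data/relax/"},  # no use_train_settings key -> defaults to True
}

relax_config = dataset["relax"]
if relax_config.get("use_train_settings", True):
    # Mirror the copy/update in load_datasets: start from the train settings,
    # then let relax-specific keys override them.
    merged = dataset["train"].copy()
    merged.update(dataset["relax"])
    relax_config = merged

print(relax_config)
# {'format': 'lmdb', 'src': 'data/relax/', 'key_mapping': {'y': 'energy'}}
```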