From 1f0f631f612b3cf273b3b84fe3c8094367348354 Mon Sep 17 00:00:00 2001 From: rayg1234 <7001989+rayg1234@users.noreply.github.com> Date: Mon, 19 Aug 2024 16:16:05 -0700 Subject: [PATCH 1/6] Fuse all hydras (#814) * make hydra compat with multitask * remove finetune hydra * fix tests * update comment * update logic slightly * ruff * fix test * add map location * ruff * fix tests * get device from input in hydra * update ocp_hydra_example.yml * add logging * update --- configs/ocp_hydra_example.yml | 2 + src/fairchem/core/common/utils.py | 19 ++ src/fairchem/core/models/base.py | 119 +++++++----- src/fairchem/core/models/finetune_hydra.py | 177 ------------------ src/fairchem/core/trainers/base_trainer.py | 8 - src/fairchem/core/trainers/ocp_trainer.py | 15 +- tests/core/e2e/test_e2e_finetune_hydra.py | 49 +++-- tests/core/e2e/test_s2efs.py | 2 +- .../models/test_configs/test_dpp_hydra.yml | 3 + .../test_configs/test_equiformerv2_hydra.yml | 2 + .../models/test_configs/test_escn_hydra.yml | 2 + .../test_configs/test_finetune_hydra.yml | 2 + .../test_configs/test_gemnet_dt_hydra.yml | 3 + .../test_gemnet_dt_hydra_grad.yml | 3 + .../test_configs/test_gemnet_oc_hydra.yml | 2 + .../test_gemnet_oc_hydra_grad.yml | 3 + .../models/test_configs/test_painn_hydra.yml | 2 + tests/core/models/test_equiformer_v2.py | 4 +- 18 files changed, 146 insertions(+), 271 deletions(-) delete mode 100644 src/fairchem/core/models/finetune_hydra.py diff --git a/configs/ocp_hydra_example.yml b/configs/ocp_hydra_example.yml index dbcadeff3..10373ad61 100755 --- a/configs/ocp_hydra_example.yml +++ b/configs/ocp_hydra_example.yml @@ -22,9 +22,11 @@ logger: wandb outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True diff --git a/src/fairchem/core/common/utils.py b/src/fairchem/core/common/utils.py index 669449f0b..e762dfeb5 100644 --- a/src/fairchem/core/common/utils.py +++ b/src/fairchem/core/common/utils.py @@ -10,6 +10,7 @@ import ast import collections import copy +import errno import importlib import itertools import json @@ -38,6 +39,7 @@ from torch_scatter import scatter, segment_coo, segment_csr import fairchem.core +from fairchem.core.common.registry import registry from fairchem.core.modules.loss import AtomwiseL2Loss, L2MAELoss if TYPE_CHECKING: @@ -1370,3 +1372,20 @@ def get_loss_module(loss_name): raise NotImplementedError(f"Unknown loss function name: {loss_name}") return loss_fn + + +def load_model_and_weights_from_checkpoint(checkpoint_path: str) -> nn.Module: + if not os.path.isfile(checkpoint_path): + raise FileNotFoundError( + errno.ENOENT, "Checkpoint file not found", checkpoint_path + ) + logging.info(f"Loading checkpoint from: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu")) + # this assumes the checkpont also contains the config with the full model in it + # TODO: need to schematize how we save and load the config from checkpoint + config = checkpoint["config"]["model"] + name = config.pop("name") + model = registry.get_model_class(name)(**config) + matched_dict = match_state_dict(model.state_dict(), checkpoint["state_dict"]) + load_state_dict(model, matched_dict, strict=True) + return model diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index c070fea4e..480ee7d02 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -9,7 +9,7 @@ import copy import logging -from abc import ABC, ABCMeta, abstractmethod 
+from abc import ABCMeta, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING @@ -21,6 +21,7 @@ from fairchem.core.common.utils import ( compute_neighbors, get_pbc_distances, + load_model_and_weights_from_checkpoint, radius_graph_pbc, ) @@ -232,64 +233,79 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: return -class HydraInterface(ABC): - # a hydra has a backbone and heads - @abstractmethod - def get_backbone(self) -> BackboneInterface: - raise not NotImplementedError - - @abstractmethod - def get_heads(self) -> dict[str, HeadInterface]: - raise not NotImplementedError - - @registry.register_model("hydra") -class HydraModel(nn.Module, GraphModelMixin, HydraInterface): +class HydraModel(nn.Module, GraphModelMixin): def __init__( self, - backbone: dict, - heads: dict, + backbone: dict | None = None, + heads: dict | None = None, + finetune_config: dict | None = None, otf_graph: bool = True, + pass_through_head_outputs: bool = False, ): super().__init__() + self.device = None self.otf_graph = otf_graph - self.device = "cpu" - # make a copy so we don't modify the original config - backbone = copy.deepcopy(backbone) - heads = copy.deepcopy(heads) - - backbone_model_name = backbone.pop("model") - self.backbone: BackboneInterface = registry.get_model_class( - backbone_model_name - )( - **backbone, - ) - - # Iterate through outputs_cfg and create heads - self.output_heads: dict[str, HeadInterface] = {} - - head_names_sorted = sorted(heads.keys()) - for head_name in head_names_sorted: - head_config = heads[head_name] - if "module" not in head_config: - raise ValueError( - f"{head_name} head does not specify module to use for the head" - ) - - module_name = head_config.pop("module") - self.output_heads[head_name] = registry.get_model_class(module_name)( - self.backbone, - **head_config, + # This is required for hydras with models that have multiple outputs per head, since we will deprecate + # the old config system at some point, this will prevent the need to make major modifications to the trainer + # because they all expect the name of the outputs directly instead of the head_name.property_name + self.pass_through_head_outputs = pass_through_head_outputs + + # if finetune_config is provided, then attempt to load the model from the given finetune checkpoint + starting_model = None + if finetune_config is not None: + starting_model: HydraModel = load_model_and_weights_from_checkpoint(finetune_config["starting_checkpoint"]) + logging.info(f"Found and loaded fine-tuning checkpoint: {finetune_config['starting_checkpoint']} (Note we are NOT loading the training state from this checkpoint, only parts of the model and weights)") + assert isinstance(starting_model, HydraModel), "Can only finetune starting from other hydra models!" 
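+        # Resolve the backbone and heads below: use the explicitly configured modules when given,
+        # otherwise fall back to the corresponding modules from the fine-tuning checkpoint loaded above.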
+ + if backbone is not None: + backbone = copy.deepcopy(backbone) + backbone_model_name = backbone.pop("model") + self.backbone: BackboneInterface = registry.get_model_class( + backbone_model_name + )( + **backbone, ) + elif starting_model is not None: + self.backbone = starting_model.backbone + logging.info(f"User did not specify a backbone, using the backbone from the starting checkpoint {self.backbone}") + else: + raise RuntimeError("Backbone not specified and not found in the starting checkpoint") + + if heads is not None: + heads = copy.deepcopy(heads) + # Iterate through outputs_cfg and create heads + self.output_heads: dict[str, HeadInterface] = {} + + head_names_sorted = sorted(heads.keys()) + assert len(set(head_names_sorted)) == len(head_names_sorted), "Head names must be unique!" + for head_name in head_names_sorted: + head_config = heads[head_name] + if "module" not in head_config: + raise ValueError( + f"{head_name} head does not specify module to use for the head" + ) - self.output_heads = torch.nn.ModuleDict(self.output_heads) + module_name = head_config.pop("module") + self.output_heads[head_name] = registry.get_model_class(module_name)( + self.backbone, + **head_config, + ) - def to(self, *args, **kwargs): - if "device" in kwargs: - self.device = kwargs["device"] - return super().to(*args, **kwargs) + self.output_heads = torch.nn.ModuleDict(self.output_heads) + elif starting_model is not None: + self.output_heads = starting_model.output_heads + logging.info(f"User did not specify heads, using the output heads from the starting checkpoint {self.output_heads}") + else: + raise RuntimeError("Heads not specified and not found in the starting checkpoint") def forward(self, data: Batch): + # lazily get device from input to use with amp, at least one input must be a tensor to figure out it's device + if not self.device: + device_from_tensors = {x.device.type for x in data.values() if isinstance(x, torch.Tensor)} + assert len(device_from_tensors) == 1, f"all inputs must be on the same device, found the following devices {device_from_tensors}" + self.device = device_from_tensors.pop() + emb = self.backbone(data) # Predict all output properties for all structures in the batch for now. 
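        # Unless pass_through_head_outputs is set, each head's output is nested under its head name
        # (e.g. out["energy"]["energy"]); the trainer then selects the tensor named by the "property"
        # field of the matching entry in the outputs config.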
out = {} @@ -297,12 +313,11 @@ def forward(self, data: Batch): with torch.autocast( device_type=self.device, enabled=self.output_heads[k].use_amp ): - out.update(self.output_heads[k](data, emb)) + if self.pass_through_head_outputs: + out.update(self.output_heads[k](data, emb)) + else: + out[k] = self.output_heads[k](data, emb) return out - def get_backbone(self) -> BackboneInterface: - return self.backbone - def get_heads(self) -> dict[str, HeadInterface]: - return self.output_heads diff --git a/src/fairchem/core/models/finetune_hydra.py b/src/fairchem/core/models/finetune_hydra.py deleted file mode 100644 index 6c271e24e..000000000 --- a/src/fairchem/core/models/finetune_hydra.py +++ /dev/null @@ -1,177 +0,0 @@ -from __future__ import annotations - -import copy -import errno -import logging -import os -from enum import Enum -from typing import TYPE_CHECKING - -import torch -from torch import nn - -from fairchem.core.common.registry import registry -from fairchem.core.common.utils import load_state_dict, match_state_dict -from fairchem.core.models.base import BackboneInterface, HeadInterface, HydraInterface - -if TYPE_CHECKING: - from torch_geometric.data import Batch - -FTHYDRA_NAME = "finetune_hydra" - -class FineTuneMode(Enum): - # in DATA_ONLY, we load the entire model and only finetune on new data - DATA_ONLY = 1 - # in this mode, we only load the Backbone and feed the output of the backbone - # to new heads that are specified - RETAIN_BACKBONE_ONLY = 2 - - -def get_model_config_from_checkpoint(checkpoint_path: str) -> dict: - if not os.path.isfile(checkpoint_path): - raise FileNotFoundError( - errno.ENOENT, "Checkpoint file not found", checkpoint_path - ) - checkpoint = torch.load(checkpoint_path) - return checkpoint["config"]["model"] - - -def load_hydra_model(checkpoint_path: str) -> HydraInterface: - if not os.path.isfile(checkpoint_path): - raise FileNotFoundError( - errno.ENOENT, "Checkpoint file not found", checkpoint_path - ) - logging.info(f"Loading checkpoint from: {checkpoint_path}") - checkpoint = torch.load(checkpoint_path) - config = checkpoint["config"]["model"] - name = config.pop("name") - hydra_model = registry.get_model_class(name)(**config) - assert isinstance( - hydra_model, HydraInterface - ), "Can only load models with the HydraInterface" - matched_dict = match_state_dict(hydra_model.state_dict(), checkpoint["state_dict"]) - load_state_dict(hydra_model, matched_dict, strict=True) - return hydra_model - - -class FTConfig: - FT_CONFIG_NAME = "finetune_config" - STARTING_CHECKPOINT = "starting_checkpoint" - STARTING_MODEL = "starting_model" - MODE = "mode" - HEADS = "heads" - - def __init__(self, config: dict): - self.config = config - self._mode = FineTuneMode[self.config[FTConfig.MODE]] - assert ( - (FTConfig.STARTING_CHECKPOINT in self.config) - or (FTConfig.STARTING_MODEL in self.config) - ), "Either a starting checkpoint or a starting model must be provided!" - assert FTConfig.MODE in self.config - if self._mode == FineTuneMode.RETAIN_BACKBONE_ONLY: - # in this mode, we keep the backbone but attach new output heads specified in head config - assert ( - FTConfig.HEADS in self.config - ), "heads cannot be empty when using RETAIN_BACKBONE_ONLY mode!" 
- - def load_model(self) -> nn.Module: - # if provided a hydra config to start, build from the starting hydra model - # this assumes the weights are loaded from the state_dict in the checkpoint.pt file instead - # so no actual weights are loaded here - if FTConfig.STARTING_MODEL in self.config: - # register model from hydra_config - config_copy = copy.deepcopy(self.config[FTConfig.STARTING_MODEL]) - name = config_copy.pop("name") - hydra_model = registry.get_model_class(name)(**config_copy) - # if provided a checkpoint to start then load the model and weights from the given checkpoint - # this happens used in the beginning of a finetuning run - elif FTConfig.STARTING_CHECKPOINT in self.config: - hydra_model: HydraInterface = load_hydra_model( - self.config[FTConfig.STARTING_CHECKPOINT] - ) - assert isinstance(hydra_model, HydraInterface) - - num_params = sum(p.numel() for p in hydra_model.parameters()) - logging.info(f"Loaded Original hydra model with {num_params} params") - return hydra_model - - def get_standalone_config(self) -> dict: - # replace a config with a checkpoint with one that has the model config only - # this is required for standalone prediction (so we don't need to ship the original checkpoint), - # multi-round finetuning, and better robustness - standalone_config = { - "name": FTHYDRA_NAME, - FTConfig.FT_CONFIG_NAME: self.config, - } - if FTConfig.STARTING_CHECKPOINT in self.config: - # modify the config to store the original model config inside model attrs - # so we dont need the checkpoint again when loading from checkpoint - new_config = copy.deepcopy(self.config) - new_config[FTConfig.STARTING_MODEL] = ( - get_model_config_from_checkpoint( - self.config[FTConfig.STARTING_CHECKPOINT] - ) - ) - standalone_config[FTConfig.FT_CONFIG_NAME] = new_config - return standalone_config - - @property - def mode(self) -> FineTuneMode: - return self._mode - - @property - def head_config(self) -> dict: - return copy.deepcopy(self.config[FTConfig.HEADS]) - - -@registry.register_model(FTHYDRA_NAME) -class FineTuneHydra(nn.Module, HydraInterface): - def __init__(self, finetune_config: dict): - super().__init__() - ft_config = FTConfig(finetune_config) - logging.info(f"Initializing FineTuneHydra model in {ft_config.mode} mode") - hydra_model: HydraInterface = ft_config.load_model() - self.backbone: BackboneInterface = hydra_model.get_backbone() - - if ft_config.mode == FineTuneMode.DATA_ONLY: - # in this mode, we just use the model as is and train on it with new data - self.output_heads: dict[str, HeadInterface] = hydra_model.get_heads() - elif ft_config.mode == FineTuneMode.RETAIN_BACKBONE_ONLY: - # in this mode, we keep the backbone but attach new output heads specified in head config - self.output_heads: dict[str, HeadInterface] = {} - heads_config = ft_config.head_config - head_names_sorted = sorted(heads_config.keys()) - for head_name in head_names_sorted: - head_config = heads_config[head_name] - if "module" not in head_config: - raise ValueError( - f"{head_name} head does not specify module to use for the head" - ) - - module_name = head_config.pop("module") - self.output_heads[head_name] = registry.get_model_class(module_name)( - self.backbone, - **head_config, - ) - num_params = sum( - p.numel() for p in self.output_heads[head_name].parameters() - ) - logging.info( - f"Attaching new output head: {module_name} with {num_params} params" - ) - self.output_heads = torch.nn.ModuleDict(self.output_heads) - - - def forward(self, data: Batch): - emb = self.backbone(data) - out = {} - 
for k in self.output_heads: - out.update(self.output_heads[k](data, emb)) - return out - - def get_backbone(self) -> BackboneInterface: - return self.backbone - - def get_heads(self) -> dict[str, HeadInterface]: - return self.output_heads diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index a31421a75..d84a2c12f 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -44,7 +44,6 @@ update_config, ) from fairchem.core.datasets.base_dataset import create_dataset -from fairchem.core.models.finetune_hydra import FineTuneHydra, FTConfig from fairchem.core.modules.evaluator import Evaluator from fairchem.core.modules.exponential_moving_average import ExponentialMovingAverage from fairchem.core.modules.loss import DDPLoss @@ -716,13 +715,6 @@ def save( training_state: bool = True, ) -> str | None: if not self.is_debug and distutils.is_master(): - # if we are using a FineTune-able model, then we need to modify the config to remove - # the original starting checkpoint so it can be loaded standalone, can move this to save function - if isinstance(self.model, FineTuneHydra): - self.config["model"] = FTConfig( - self.config["model"][FTConfig.FT_CONFIG_NAME] - ).get_standalone_config() - state = { "state_dict": self.model.state_dict(), "normalizers": { diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 662341bdc..0ced35bef 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -251,9 +251,18 @@ def _forward(self, batch): for target_key in self.output_targets: ### Target property is a direct output of the model if target_key in out: - pred = out[target_key] - ## Target property is a derived output of the model. Construct the - ## parent property + if isinstance(out[target_key], torch.Tensor): + pred = out[target_key] + elif isinstance(out[target_key], dict): + # if output is a nested dictionary (in the case of hydra models), we attempt to retrieve it using the property name + # ie: "output_head_name.property" + assert "property" in self.output_targets[target_key], \ + f"we need to know which property to match the target to, please specify the property field in the task config, current config: {self.output_targets[target_key]}" + property = self.output_targets[target_key]["property"] + pred = out[target_key][property] + + ## TODO: deprecate the following logic? + ## Otherwise, assume target property is a derived output of the model. 
Construct the parent property else: _max_rank = 0 for subtarget_key in self.output_targets[target_key]["decomposition"]: diff --git a/tests/core/e2e/test_e2e_finetune_hydra.py b/tests/core/e2e/test_e2e_finetune_hydra.py index 91f2abd49..df9def3b0 100644 --- a/tests/core/e2e/test_e2e_finetune_hydra.py +++ b/tests/core/e2e/test_e2e_finetune_hydra.py @@ -8,8 +8,6 @@ import torch from test_e2e_commons import _run_main, oc20_lmdb_train_and_val_from_paths -from fairchem.core.models.finetune_hydra import FTHYDRA_NAME, FineTuneMode, FTConfig - @pytest.fixture() def tutorial_val_src(tutorial_dataset_path): @@ -49,7 +47,7 @@ def run_main_with_ft_hydra(tempdir: str, yaml: str, data_src: str, run_args: dict, - ft_config: str, + model_config: str, output_checkpoint: str): _run_main( tempdir, @@ -68,10 +66,7 @@ def run_main_with_ft_hydra(tempdir: str, test_src=str(data_src), otf_norms=False, ), - "model": { - "name": FTHYDRA_NAME, - FTConfig.FT_CONFIG_NAME: ft_config, - } + "model": model_config, }, update_run_args_with=run_args, save_checkpoint_to=output_checkpoint, @@ -87,9 +82,9 @@ def test_finetune_hydra_retain_backbone(tutorial_val_src): with tempfile.TemporaryDirectory() as ft_temp_dir: ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml") ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt") - ft_config = { - "mode": FineTuneMode.RETAIN_BACKBONE_ONLY.name, - "starting_checkpoint": starting_ckpt, + model_config = { + "name" : "hydra", + "finetune_config": {'starting_checkpoint': starting_ckpt}, "heads": { "energy": { "module": "equiformer_v2_energy_head" @@ -103,12 +98,12 @@ def test_finetune_hydra_retain_backbone(tutorial_val_src): yaml = ft_yml, data_src = tutorial_val_src, run_args = {"seed": 1000}, - ft_config = ft_config, + model_config = model_config, output_checkpoint = ck_ft_path) assert os.path.isfile(ck_ft_path) ft_ckpt = torch.load(ck_ft_path) assert "config" in ft_ckpt - assert ft_ckpt["config"]["model"]["name"] == FTHYDRA_NAME + assert ft_ckpt["config"]["model"]["name"] == "hydra" # check that the backbone weights are the same, and other weights are not the same new_state_dict = ft_ckpt["state_dict"] for key in new_state_dict: @@ -128,28 +123,26 @@ def test_finetune_hydra_data_only(tutorial_val_src): with tempfile.TemporaryDirectory() as ft_temp_dir: ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml") ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt") - ft_config = { - "mode": FineTuneMode.DATA_ONLY.name, - "starting_checkpoint": starting_ckpt, + model_config = { + "name" : "hydra", + "finetune_config": {'starting_checkpoint': starting_ckpt}, } run_main_with_ft_hydra(tempdir = ft_temp_dir, yaml = ft_yml, data_src = tutorial_val_src, run_args = {"seed": 1000}, - ft_config = ft_config, + model_config = model_config, output_checkpoint = ck_ft_path) assert os.path.isfile(ck_ft_path) ft_ckpt = torch.load(ck_ft_path) assert "config" in ft_ckpt config_model = ft_ckpt["config"]["model"] - assert config_model["name"] == FTHYDRA_NAME + assert config_model["name"] == "hydra" # check that the entire model weights are the same new_state_dict = ft_ckpt["state_dict"] assert len(new_state_dict) == len(old_state_dict) for key in new_state_dict: assert torch.allclose(new_state_dict[key], old_state_dict[key]) - # check the new checkpoint contains a hydra model - assert FTConfig.STARTING_MODEL in config_model[FTConfig.FT_CONFIG_NAME] def test_finetune_from_finetunehydra(tutorial_val_src): @@ -159,15 +152,15 @@ def 
test_finetune_from_finetunehydra(tutorial_val_src): with tempfile.TemporaryDirectory() as finetune_run1_dir: ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml") ck_ft_path = os.path.join(finetune_run1_dir, "checkpoint_ft.pt") - ft_config_1 = { - "mode": FineTuneMode.DATA_ONLY.name, - "starting_checkpoint": starting_ckpt, + model_config_1 = { + "name" : "hydra", + "finetune_config": {'starting_checkpoint': starting_ckpt}, } run_main_with_ft_hydra(tempdir = finetune_run1_dir, yaml = ft_yml, data_src = tutorial_val_src, run_args = {"seed": 1000}, - ft_config = ft_config_1, + model_config = model_config_1, output_checkpoint = ck_ft_path) assert os.path.isfile(ck_ft_path) @@ -175,20 +168,20 @@ def test_finetune_from_finetunehydra(tutorial_val_src): ######################################################################################## with tempfile.TemporaryDirectory() as finetune_run2_dir: ck_ft2_path = os.path.join(finetune_run2_dir, "checkpoint_ft.pt") - ft_config_2 = { - "mode": FineTuneMode.DATA_ONLY.name, - "starting_checkpoint": ck_ft_path, + model_config_2 = { + "name" : "hydra", + "finetune_config": {'starting_checkpoint': ck_ft_path}, } run_main_with_ft_hydra(tempdir = finetune_run2_dir, yaml = ft_yml, data_src = tutorial_val_src, run_args = {"seed": 1000}, - ft_config = ft_config_2, + model_config = model_config_2, output_checkpoint = ck_ft2_path) ft_ckpt2 = torch.load(ck_ft2_path) assert "config" in ft_ckpt2 config_model = ft_ckpt2["config"]["model"] - assert config_model["name"] == FTHYDRA_NAME + assert config_model["name"] == "hydra" old_state_dict = torch.load(ck_ft_path)["state_dict"] new_state_dict = ft_ckpt2["state_dict"] # the state dicts should still be identical because we made the LR = 0.0 diff --git a/tests/core/e2e/test_s2efs.py b/tests/core/e2e/test_s2efs.py index 94b0862ed..037979e60 100644 --- a/tests/core/e2e/test_s2efs.py +++ b/tests/core/e2e/test_s2efs.py @@ -44,7 +44,7 @@ def test_smoke_s2efs_predict( {"forces": {"fn": "l2mae", "coefficient": 100}}, {"stress": {"fn": "mae", "coefficient": 100}}, ], - "outputs": {"stress": {"level": "system", "irrep_dim": 2}}, + "outputs": {"stress": {"level": "system", "irrep_dim": 2, "property": "stress"}}, "evaluation_metrics": {"metrics": {"stress": ["mae"]}}, "dataset": { "train": { diff --git a/tests/core/models/test_configs/test_dpp_hydra.yml b/tests/core/models/test_configs/test_dpp_hydra.yml index e41a39141..24f13e6ad 100755 --- a/tests/core/models/test_configs/test_dpp_hydra.yml +++ b/tests/core/models/test_configs/test_dpp_hydra.yml @@ -2,9 +2,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True @@ -52,6 +54,7 @@ model: heads: energy: module: dimenetplusplus_energy_and_force_head + pass_through_head_outputs: True # *** Important note *** # The total number of gpus used for this run was 256. 
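The fine-tuning tests above drive both flows through the fused hydra model. As a minimal sketch of the same construction outside the trainer, assuming a placeholder checkpoint path and reusing the equiformer_v2 head modules from the retain-backbone test:

from fairchem.core.common.registry import registry

# Reuse the checkpoint's backbone and heads (the old DATA_ONLY mode): only finetune_config is given.
data_only_model = registry.get_model_class("hydra")(
    finetune_config={"starting_checkpoint": "/path/to/checkpoint.pt"},  # placeholder path
)

# Keep only the checkpoint's backbone and attach freshly initialized heads
# (the old RETAIN_BACKBONE_ONLY mode).
new_heads_model = registry.get_model_class("hydra")(
    finetune_config={"starting_checkpoint": "/path/to/checkpoint.pt"},  # placeholder path
    heads={
        "energy": {"module": "equiformer_v2_energy_head"},
        "forces": {"module": "equiformer_v2_force_head"},
    },
)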
diff --git a/tests/core/models/test_configs/test_equiformerv2_hydra.yml b/tests/core/models/test_configs/test_equiformerv2_hydra.yml index 1852799f5..9747eec80 100644 --- a/tests/core/models/test_configs/test_equiformerv2_hydra.yml +++ b/tests/core/models/test_configs/test_equiformerv2_hydra.yml @@ -2,9 +2,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True diff --git a/tests/core/models/test_configs/test_escn_hydra.yml b/tests/core/models/test_configs/test_escn_hydra.yml index c51d46fc3..8d730dad0 100644 --- a/tests/core/models/test_configs/test_escn_hydra.yml +++ b/tests/core/models/test_configs/test_escn_hydra.yml @@ -2,9 +2,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True diff --git a/tests/core/models/test_configs/test_finetune_hydra.yml b/tests/core/models/test_configs/test_finetune_hydra.yml index a5f1dc51b..e1be6d20b 100644 --- a/tests/core/models/test_configs/test_finetune_hydra.yml +++ b/tests/core/models/test_configs/test_finetune_hydra.yml @@ -2,9 +2,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True diff --git a/tests/core/models/test_configs/test_gemnet_dt_hydra.yml b/tests/core/models/test_configs/test_gemnet_dt_hydra.yml index 036ed689f..4b5c239fc 100644 --- a/tests/core/models/test_configs/test_gemnet_dt_hydra.yml +++ b/tests/core/models/test_configs/test_gemnet_dt_hydra.yml @@ -2,9 +2,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True @@ -73,6 +75,7 @@ model: forces: module: gemnet_t_force_head + optim: batch_size: 8 eval_batch_size: 8 diff --git a/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml b/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml index 358dd1c86..e8e4eca9c 100644 --- a/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml +++ b/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml @@ -2,9 +2,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True @@ -71,6 +73,7 @@ model: heads: energy_and_forces: module: gemnet_t_energy_and_grad_force_head + pass_through_head_outputs: True optim: batch_size: 8 diff --git a/tests/core/models/test_configs/test_gemnet_oc_hydra.yml b/tests/core/models/test_configs/test_gemnet_oc_hydra.yml index 716718e3e..a58d328bd 100644 --- a/tests/core/models/test_configs/test_gemnet_oc_hydra.yml +++ b/tests/core/models/test_configs/test_gemnet_oc_hydra.yml @@ -2,9 +2,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True diff --git a/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml b/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml index 90001488b..88cb493f4 100644 --- a/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml +++ b/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml @@ -3,9 +3,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True @@ -96,6 +98,7 @@ model: energy: module: gemnet_oc_energy_and_grad_force_head 
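      # this head predicts energy and gradient-based forces together, which is why
      # pass_through_head_outputs is enabled for this model further below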
num_global_out_layers: 2 + pass_through_head_outputs: True optim: batch_size: 5 diff --git a/tests/core/models/test_configs/test_painn_hydra.yml b/tests/core/models/test_configs/test_painn_hydra.yml index 2c4731742..a6b26c5d4 100644 --- a/tests/core/models/test_configs/test_painn_hydra.yml +++ b/tests/core/models/test_configs/test_painn_hydra.yml @@ -2,9 +2,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True diff --git a/tests/core/models/test_equiformer_v2.py b/tests/core/models/test_equiformer_v2.py index 54d58db1c..2f0903608 100644 --- a/tests/core/models/test_equiformer_v2.py +++ b/tests/core/models/test_equiformer_v2.py @@ -264,12 +264,12 @@ def test_eqv2_hydra_activation_checkpoint(): # way to do this is save the rng state and reset it after stepping the first model start_rng_state = torch.random.get_rng_state() outputs_no_ac = no_ac_model(inputs) - torch.autograd.backward(outputs_no_ac["energy"].sum() + outputs_no_ac["forces"].sum()) + torch.autograd.backward(outputs_no_ac["energy"]["energy"].sum() + outputs_no_ac["forces"]["forces"].sum()) # reset the rng state to the beginning torch.random.set_rng_state(start_rng_state) outptuts_ac = ac_model(inputs) - torch.autograd.backward(outptuts_ac["energy"].sum() + outptuts_ac["forces"].sum()) + torch.autograd.backward(outptuts_ac["energy"]["energy"].sum() + outptuts_ac["forces"]["forces"].sum()) # assert all the gradients are identical between the model with checkpointing and no checkpointing ac_model_grad_dict = {name:p.grad for name, p in ac_model.named_parameters() if p.grad is not None} From 427fb8d0da3c807c4d7567c5b4acb737549c5b64 Mon Sep 17 00:00:00 2001 From: Misko Date: Tue, 20 Aug 2024 17:03:24 -0700 Subject: [PATCH 2/6] refactor and deprecate old equiformerv2 (#812) * refactor and deprecate old equiformerv2 * remove equiv2_backbone_and_heads * lint fixes * remove backbone and heads model * fix merge * split up tests * update * add in missing file --- .../core/models/equiformer_v2/__init__.py | 2 +- .../models/equiformer_v2/equiformer_v2.py | 350 ++------- .../equiformer_v2/equiformer_v2_deprecated.py | 681 ++++++++++++++++++ ...mbr => test_equiformer_v2_deprecated.ambr} | 8 +- tests/core/models/test_equiformer_v2.py | 167 +---- .../models/test_equiformer_v2_deprecated.py | 161 +++++ 6 files changed, 926 insertions(+), 443 deletions(-) create mode 100644 src/fairchem/core/models/equiformer_v2/equiformer_v2_deprecated.py rename tests/core/models/__snapshots__/{test_equiformer_v2.ambr => test_equiformer_v2_deprecated.ambr} (84%) create mode 100644 tests/core/models/test_equiformer_v2_deprecated.py diff --git a/src/fairchem/core/models/equiformer_v2/__init__.py b/src/fairchem/core/models/equiformer_v2/__init__.py index 720f890f6..918f0c617 100644 --- a/src/fairchem/core/models/equiformer_v2/__init__.py +++ b/src/fairchem/core/models/equiformer_v2/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations -from .equiformer_v2 import EquiformerV2 +from .equiformer_v2_deprecated import EquiformerV2 __all__ = ["EquiformerV2"] diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py index b78f43597..978d4c226 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -11,7 +11,10 @@ from fairchem.core.common import gp_utils from fairchem.core.common.registry import registry from 
fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface +from fairchem.core.models.base import ( + GraphModelMixin, + HeadInterface, +) from fairchem.core.models.scn.smearing import GaussianSmearing with contextlib.suppress(ImportError): @@ -77,8 +80,8 @@ def eqv2_uniform_init_linear_weights(m): torch.nn.init.uniform_(m.weight, -std, std) -@registry.register_model("equiformer_v2") -class EquiformerV2(nn.Module, GraphModelMixin): +@registry.register_model("equiformer_v2_backbone") +class EquiformerV2Backbone(nn.Module, GraphModelMixin): """ Equiformer with graph attention built upon SO(2) convolution and feedforward network built upon S2 activation @@ -380,43 +383,6 @@ def __init__( lmax=max(self.lmax_list), num_channels=self.sphere_channels, ) - self.energy_block = FeedForwardNetwork( - self.sphere_channels, - self.ffn_hidden_channels, - 1, - self.lmax_list, - self.mmax_list, - self.SO3_grid, - self.ffn_activation, - self.use_gate_act, - self.use_grid_mlp, - self.use_sep_s2_act, - ) - if self.regress_forces: - self.force_block = SO2EquivariantGraphAttention( - self.sphere_channels, - self.attn_hidden_channels, - self.num_heads, - self.attn_alpha_channels, - self.attn_value_channels, - 1, - self.lmax_list, - self.mmax_list, - self.SO3_rotation, - self.mappingReduced, - self.SO3_grid, - self.max_num_elements, - self.edge_channels_list, - self.block_use_atom_edge_embedding, - self.use_m_share_rad, - self.attn_activation, - self.use_s2_act_attn, - self.use_attn_renorm, - self.use_gate_act, - self.use_sep_s2_act, - alpha_drop=0.0, - ) - if self.load_energy_lin_ref: self.energy_lin_ref = nn.Parameter( torch.zeros(self.max_num_elements), @@ -425,44 +391,8 @@ def __init__( self.apply(partial(eqv2_init_weights, weight_init=self.weight_init)) - def _init_gp_partitions( - self, - atomic_numbers_full, - data_batch_full, - edge_index, - edge_distance, - edge_distance_vec, - ): - """Graph Parallel - This creates the required partial tensors for each rank given the full tensors. - The tensors are split on the dimension along the node index using node_partition. - """ - node_partition = gp_utils.scatter_to_model_parallel_region( - torch.arange(len(atomic_numbers_full)).to(self.device) - ) - edge_partition = torch.where( - torch.logical_and( - edge_index[1] >= node_partition.min(), - edge_index[1] <= node_partition.max(), # TODO: 0 or 1? 
- ) - )[0] - edge_index = edge_index[:, edge_partition] - edge_distance = edge_distance[edge_partition] - edge_distance_vec = edge_distance_vec[edge_partition] - atomic_numbers = atomic_numbers_full[node_partition] - data_batch = data_batch_full[node_partition] - node_offset = node_partition.min().item() - return ( - atomic_numbers, - data_batch, - node_offset, - edge_index, - edge_distance, - edge_distance_vec, - ) - @conditional_grad(torch.enable_grad()) - def forward(self, data): + def forward(self, data: Batch) -> dict[str, torch.Tensor]: self.batch_size = len(data.natoms) self.dtype = data.pos.dtype self.device = data.pos.device @@ -574,75 +504,67 @@ def forward(self, data): ############################################################### for i in range(self.num_layers): - x = self.blocks[i]( - x, # SO3_Embedding - graph.atomic_numbers_full, - graph.edge_distance, - graph.edge_index, - batch=data_batch, # for GraphDropPath - node_offset=graph.node_offset, - ) + if self.activation_checkpoint: + x = torch.utils.checkpoint.checkpoint( + self.blocks[i], + x, # SO3_Embedding + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + data_batch, # for GraphDropPath + graph.node_offset, + use_reentrant=not self.training, + ) + else: + x = self.blocks[i]( + x, # SO3_Embedding + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + batch=data_batch, # for GraphDropPath + node_offset=graph.node_offset, + ) # Final layer norm x.embedding = self.norm(x.embedding) - ############################################################### - # Energy estimation - ############################################################### - node_energy = self.energy_block(x) - node_energy = node_energy.embedding.narrow(1, 0, 1) - if gp_utils.initialized(): - node_energy = gp_utils.gather_from_model_parallel_region(node_energy, dim=0) - energy = torch.zeros( - len(data.natoms), - device=node_energy.device, - dtype=node_energy.dtype, - ) - energy.index_add_(0, graph.batch_full, node_energy.view(-1)) - energy = energy / self.avg_num_nodes - - # Add the per-atom linear references to the energy. - if self.use_energy_lin_ref and self.load_energy_lin_ref: - # During training, target E = (E_DFT - E_ref - E_mean) / E_std, and - # during inference, \hat{E_DFT} = \hat{E} * E_std + E_ref + E_mean - # where - # - # E_DFT = raw DFT energy, - # E_ref = reference energy, - # E_mean = normalizer mean, - # E_std = normalizer std, - # \hat{E} = predicted energy, - # \hat{E_DFT} = predicted DFT energy. - # - # We can also write this as - # \hat{E_DFT} = E_std * (\hat{E} + E_ref / E_std) + E_mean, - # which is why we save E_ref / E_std as the linear reference. - with torch.cuda.amp.autocast(False): - energy = energy.to(self.energy_lin_ref.dtype).index_add( - 0, - graph.batch_full, - self.energy_lin_ref[graph.atomic_numbers_full], - ) + return {"node_embedding": x, "graph": graph} - outputs = {"energy": energy} - ############################################################### - # Force estimation - ############################################################### - if self.regress_forces: - forces = self.force_block( - x, - graph.atomic_numbers_full, - graph.edge_distance, - graph.edge_index, - node_offset=graph.node_offset, + def _init_gp_partitions( + self, + atomic_numbers_full, + data_batch_full, + edge_index, + edge_distance, + edge_distance_vec, + ): + """Graph Parallel + This creates the required partial tensors for each rank given the full tensors. 
+ The tensors are split on the dimension along the node index using node_partition. + """ + node_partition = gp_utils.scatter_to_model_parallel_region( + torch.arange(len(atomic_numbers_full)).to(self.device) + ) + edge_partition = torch.where( + torch.logical_and( + edge_index[1] >= node_partition.min(), + edge_index[1] <= node_partition.max(), # TODO: 0 or 1? ) - forces = forces.embedding.narrow(1, 1, 3) - forces = forces.view(-1, 3).contiguous() - if gp_utils.initialized(): - forces = gp_utils.gather_from_model_parallel_region(forces, dim=0) - outputs["forces"] = forces - - return outputs + )[0] + edge_index = edge_index[:, edge_partition] + edge_distance = edge_distance[edge_partition] + edge_distance_vec = edge_distance_vec[edge_partition] + atomic_numbers = atomic_numbers_full[node_partition] + data_batch = data_batch_full[node_partition] + node_offset = node_partition.min().item() + return ( + atomic_numbers, + data_batch, + node_offset, + edge_index, + edge_distance, + edge_distance_vec, + ) # Initialize the edge rotation matrics def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec): @@ -683,154 +605,6 @@ def no_weight_decay(self) -> set: return set(no_wd_list) -@registry.register_model("equiformer_v2_backbone") -class EquiformerV2Backbone(EquiformerV2, BackboneInterface): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # TODO remove these once we deprecate/stop-inheriting EquiformerV2 class - self.energy_block = None - self.force_block = None - - @conditional_grad(torch.enable_grad()) - def forward(self, data: Batch) -> dict[str, torch.Tensor]: - self.batch_size = len(data.natoms) - self.dtype = data.pos.dtype - self.device = data.pos.device - atomic_numbers = data.atomic_numbers.long() - graph = self.generate_graph( - data, - enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly, - ) - - data_batch = data.batch - if gp_utils.initialized(): - ( - atomic_numbers, - data_batch, - node_offset, - edge_index, - edge_distance, - edge_distance_vec, - ) = self._init_gp_partitions( - graph.atomic_numbers_full, - graph.batch_full, - graph.edge_index, - graph.edge_distance, - graph.edge_distance_vec, - ) - graph.node_offset = node_offset - graph.edge_index = edge_index - graph.edge_distance = edge_distance - graph.edge_distance_vec = edge_distance_vec - - ############################################################### - # Entering Graph Parallel Region - # after this point, if using gp, then node, edge tensors are split - # across the graph parallel ranks, some full tensors such as - # atomic_numbers_full are required because we need to index into the - # full graph when computing edge embeddings or reducing nodes from neighbors - # - # all tensors that do not have the suffix "_full" refer to the partial tensors. 
- # if not using gp, the full values are equal to the partial values - # ie: atomic_numbers_full == atomic_numbers - ############################################################### - - ############################################################### - # Initialize data structures - ############################################################### - - # Compute 3x3 rotation matrix per edge - edge_rot_mat = self._init_edge_rot_mat( - data, graph.edge_index, graph.edge_distance_vec - ) - - # Initialize the WignerD matrices and other values for spherical harmonic calculations - for i in range(self.num_resolutions): - self.SO3_rotation[i].set_wigner(edge_rot_mat) - - ############################################################### - # Initialize node embeddings - ############################################################### - - # Init per node representations using an atomic number based embedding - x = SO3_Embedding( - len(atomic_numbers), - self.lmax_list, - self.sphere_channels, - self.device, - self.dtype, - ) - - offset_res = 0 - offset = 0 - # Initialize the l = 0, m = 0 coefficients for each resolution - for i in range(self.num_resolutions): - if self.num_resolutions == 1: - x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers) - else: - x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[ - :, offset : offset + self.sphere_channels - ] - offset = offset + self.sphere_channels - offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) - - # Edge encoding (distance and atom edge) - graph.edge_distance = self.distance_expansion(graph.edge_distance) - if self.share_atom_edge_embedding and self.use_atom_edge_embedding: - source_element = graph.atomic_numbers_full[ - graph.edge_index[0] - ] # Source atom atomic number - target_element = graph.atomic_numbers_full[ - graph.edge_index[1] - ] # Target atom atomic number - source_embedding = self.source_embedding(source_element) - target_embedding = self.target_embedding(target_element) - graph.edge_distance = torch.cat( - (graph.edge_distance, source_embedding, target_embedding), dim=1 - ) - - # Edge-degree embedding - edge_degree = self.edge_degree_embedding( - graph.atomic_numbers_full, - graph.edge_distance, - graph.edge_index, - len(atomic_numbers), - graph.node_offset, - ) - x.embedding = x.embedding + edge_degree.embedding - - ############################################################### - # Update spherical node embeddings - ############################################################### - - for i in range(self.num_layers): - if self.activation_checkpoint: - x = torch.utils.checkpoint.checkpoint( - self.blocks[i], - x, # SO3_Embedding - graph.atomic_numbers_full, - graph.edge_distance, - graph.edge_index, - data_batch, # for GraphDropPath - graph.node_offset, - use_reentrant=not self.training, - ) - else: - x = self.blocks[i]( - x, # SO3_Embedding - graph.atomic_numbers_full, - graph.edge_distance, - graph.edge_index, - batch=data_batch, # for GraphDropPath - node_offset=graph.node_offset, - ) - - # Final layer norm - x.embedding = self.norm(x.embedding) - - return {"node_embedding": x, "graph": graph} - - @registry.register_model("equiformer_v2_energy_head") class EquiformerV2EnergyHead(nn.Module, HeadInterface): def __init__(self, backbone): diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2_deprecated.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2_deprecated.py new file mode 100644 index 000000000..0dedecd86 --- /dev/null +++ 
b/src/fairchem/core/models/equiformer_v2/equiformer_v2_deprecated.py @@ -0,0 +1,681 @@ +from __future__ import annotations + +import contextlib +import logging +import math + +import torch +import torch.nn as nn + +from fairchem.core.common import gp_utils +from fairchem.core.common.registry import registry +from fairchem.core.common.utils import conditional_grad +from fairchem.core.models.base import GraphModelMixin +from fairchem.core.models.scn.smearing import GaussianSmearing + +with contextlib.suppress(ImportError): + pass + + + +from .edge_rot_mat import init_edge_rot_mat +from .gaussian_rbf import GaussianRadialBasisLayer +from .input_block import EdgeDegreeEmbedding +from .layer_norm import ( + EquivariantLayerNormArray, + EquivariantLayerNormArraySphericalHarmonics, + EquivariantRMSNormArraySphericalHarmonics, + EquivariantRMSNormArraySphericalHarmonicsV2, + get_normalization_layer, +) +from .module_list import ModuleListInfo +from .radial_function import RadialFunction +from .so3 import ( + CoefficientMappingModule, + SO3_Embedding, + SO3_Grid, + SO3_LinearV2, + SO3_Rotation, +) +from .transformer_block import ( + FeedForwardNetwork, + SO2EquivariantGraphAttention, + TransBlockV2, +) + +# Statistics of IS2RE 100K +_AVG_NUM_NODES = 77.81317 +_AVG_DEGREE = 23.395238876342773 # IS2RE: 100k, max_radius = 5, max_neighbors = 100 + + +@registry.register_model("equiformer_v2") +class EquiformerV2(nn.Module, GraphModelMixin): + """ + THIS CLASS HAS BEEN DEPRECATED! Please use "EquiformerV2BackboneAndHeads" + + Equiformer with graph attention built upon SO(2) convolution and feedforward network built upon S2 activation + + Args: + use_pbc (bool): Use periodic boundary conditions + use_pbc_single (bool): Process batch PBC graphs one at a time + regress_forces (bool): Compute forces + otf_graph (bool): Compute graph On The Fly (OTF) + max_neighbors (int): Maximum number of neighbors per atom + max_radius (float): Maximum distance between nieghboring atoms in Angstroms + max_num_elements (int): Maximum atomic number + + num_layers (int): Number of layers in the GNN + sphere_channels (int): Number of spherical channels (one set per resolution) + attn_hidden_channels (int): Number of hidden channels used during SO(2) graph attention + num_heads (int): Number of attention heads + attn_alpha_head (int): Number of channels for alpha vector in each attention head + attn_value_head (int): Number of channels for value vector in each attention head + ffn_hidden_channels (int): Number of hidden channels used during feedforward network + norm_type (str): Type of normalization layer (['layer_norm', 'layer_norm_sh', 'rms_norm_sh']) + + lmax_list (int): List of maximum degree of the spherical harmonics (1 to 10) + mmax_list (int): List of maximum order of the spherical harmonics (0 to lmax) + grid_resolution (int): Resolution of SO3_Grid + + num_sphere_samples (int): Number of samples used to approximate the integration of the sphere in the output blocks + + edge_channels (int): Number of channels for the edge invariant features + use_atom_edge_embedding (bool): Whether to use atomic embedding along with relative distance for edge scalar features + share_atom_edge_embedding (bool): Whether to share `atom_edge_embedding` across all blocks + use_m_share_rad (bool): Whether all m components within a type-L vector of one channel share radial function weights + distance_function ("gaussian", "sigmoid", "linearsigmoid", "silu"): Basis function used for distances + + attn_activation (str): Type of activation 
function for SO(2) graph attention + use_s2_act_attn (bool): Whether to use attention after S2 activation. Otherwise, use the same attention as Equiformer + use_attn_renorm (bool): Whether to re-normalize attention weights + ffn_activation (str): Type of activation function for feedforward network + use_gate_act (bool): If `True`, use gate activation. Otherwise, use S2 activation + use_grid_mlp (bool): If `True`, use projecting to grids and performing MLPs for FFNs. + use_sep_s2_act (bool): If `True`, use separable S2 activation when `use_gate_act` is False. + + alpha_drop (float): Dropout rate for attention weights + drop_path_rate (float): Drop path rate + proj_drop (float): Dropout rate for outputs of attention and FFN in Transformer blocks + + weight_init (str): ['normal', 'uniform'] initialization of weights of linear layers except those in radial functions + enforce_max_neighbors_strictly (bool): When edges are subselected based on the `max_neighbors` arg, arbitrarily select amongst equidistant / degenerate edges to have exactly the correct number. + avg_num_nodes (float): Average number of nodes per graph + avg_degree (float): Average degree of nodes in the graph + + use_energy_lin_ref (bool): Whether to add the per-atom energy references during prediction. + During training and validation, this should be kept `False` since we use the `lin_ref` parameter in the OC22 dataloader to subtract the per-atom linear references from the energy targets. + During prediction (where we don't have energy targets), this can be set to `True` to add the per-atom linear references to the predicted energies. + load_energy_lin_ref (bool): Whether to add nn.Parameters for the per-element energy references. + This additional flag is there to ensure compatibility when strict-loading checkpoints, since the `use_energy_lin_ref` flag can be either True or False even if the model is trained with linear references. + You can't have use_energy_lin_ref = True and load_energy_lin_ref = False, since the model will not have the parameters for the linear references. All other combinations are fine. 
+ """ + + def __init__( + self, + use_pbc: bool = True, + use_pbc_single: bool = False, + regress_forces: bool = True, + otf_graph: bool = True, + max_neighbors: int = 500, + max_radius: float = 5.0, + max_num_elements: int = 90, + num_layers: int = 12, + sphere_channels: int = 128, + attn_hidden_channels: int = 128, + num_heads: int = 8, + attn_alpha_channels: int = 32, + attn_value_channels: int = 16, + ffn_hidden_channels: int = 512, + norm_type: str = "rms_norm_sh", + lmax_list: list[int] | None = None, + mmax_list: list[int] | None = None, + grid_resolution: int | None = None, + num_sphere_samples: int = 128, + edge_channels: int = 128, + use_atom_edge_embedding: bool = True, + share_atom_edge_embedding: bool = False, + use_m_share_rad: bool = False, + distance_function: str = "gaussian", + num_distance_basis: int = 512, + attn_activation: str = "scaled_silu", + use_s2_act_attn: bool = False, + use_attn_renorm: bool = True, + ffn_activation: str = "scaled_silu", + use_gate_act: bool = False, + use_grid_mlp: bool = False, + use_sep_s2_act: bool = True, + alpha_drop: float = 0.1, + drop_path_rate: float = 0.05, + proj_drop: float = 0.0, + weight_init: str = "normal", + enforce_max_neighbors_strictly: bool = True, + avg_num_nodes: float | None = None, + avg_degree: float | None = None, + use_energy_lin_ref: bool | None = False, + load_energy_lin_ref: bool | None = False, + ): + logging.warning( + "equiformer_v2 (EquiformerV2) class is deprecaed in favor of equiformer_v2_backbone_and_heads (EquiformerV2BackboneAndHeads)" + ) + if mmax_list is None: + mmax_list = [2] + if lmax_list is None: + lmax_list = [6] + super().__init__() + + import sys + + if "e3nn" not in sys.modules: + logging.error("You need to install e3nn==0.4.4 to use EquiformerV2.") + raise ImportError + + self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single + self.regress_forces = regress_forces + self.otf_graph = otf_graph + self.max_neighbors = max_neighbors + self.max_radius = max_radius + self.cutoff = max_radius + self.max_num_elements = max_num_elements + + self.num_layers = num_layers + self.sphere_channels = sphere_channels + self.attn_hidden_channels = attn_hidden_channels + self.num_heads = num_heads + self.attn_alpha_channels = attn_alpha_channels + self.attn_value_channels = attn_value_channels + self.ffn_hidden_channels = ffn_hidden_channels + self.norm_type = norm_type + + self.lmax_list = lmax_list + self.mmax_list = mmax_list + self.grid_resolution = grid_resolution + + self.num_sphere_samples = num_sphere_samples + + self.edge_channels = edge_channels + self.use_atom_edge_embedding = use_atom_edge_embedding + self.share_atom_edge_embedding = share_atom_edge_embedding + if self.share_atom_edge_embedding: + assert self.use_atom_edge_embedding + self.block_use_atom_edge_embedding = False + else: + self.block_use_atom_edge_embedding = self.use_atom_edge_embedding + self.use_m_share_rad = use_m_share_rad + self.distance_function = distance_function + self.num_distance_basis = num_distance_basis + + self.attn_activation = attn_activation + self.use_s2_act_attn = use_s2_act_attn + self.use_attn_renorm = use_attn_renorm + self.ffn_activation = ffn_activation + self.use_gate_act = use_gate_act + self.use_grid_mlp = use_grid_mlp + self.use_sep_s2_act = use_sep_s2_act + + self.alpha_drop = alpha_drop + self.drop_path_rate = drop_path_rate + self.proj_drop = proj_drop + + self.avg_num_nodes = avg_num_nodes or _AVG_NUM_NODES + self.avg_degree = avg_degree or _AVG_DEGREE + + self.use_energy_lin_ref = 
use_energy_lin_ref + self.load_energy_lin_ref = load_energy_lin_ref + assert not ( + self.use_energy_lin_ref and not self.load_energy_lin_ref + ), "You can't have use_energy_lin_ref = True and load_energy_lin_ref = False, since the model will not have the parameters for the linear references. All other combinations are fine." + + self.weight_init = weight_init + assert self.weight_init in ["normal", "uniform"] + + self.enforce_max_neighbors_strictly = enforce_max_neighbors_strictly + + self.device = "cpu" # torch.cuda.current_device() + + self.grad_forces = False + self.num_resolutions: int = len(self.lmax_list) + self.sphere_channels_all: int = self.num_resolutions * self.sphere_channels + + # Weights for message initialization + self.sphere_embedding = nn.Embedding( + self.max_num_elements, self.sphere_channels_all + ) + + # Initialize the function used to measure the distances between atoms + assert self.distance_function in [ + "gaussian", + ] + if self.distance_function == "gaussian": + self.distance_expansion = GaussianSmearing( + 0.0, + self.cutoff, + 600, + 2.0, + ) + # self.distance_expansion = GaussianRadialBasisLayer(num_basis=self.num_distance_basis, cutoff=self.max_radius) + else: + raise ValueError + + # Initialize the sizes of radial functions (input channels and 2 hidden channels) + self.edge_channels_list = [int(self.distance_expansion.num_output)] + [ + self.edge_channels + ] * 2 + + # Initialize atom edge embedding + if self.share_atom_edge_embedding and self.use_atom_edge_embedding: + self.source_embedding = nn.Embedding( + self.max_num_elements, self.edge_channels_list[-1] + ) + self.target_embedding = nn.Embedding( + self.max_num_elements, self.edge_channels_list[-1] + ) + self.edge_channels_list[0] = ( + self.edge_channels_list[0] + 2 * self.edge_channels_list[-1] + ) + else: + self.source_embedding, self.target_embedding = None, None + + # Initialize the module that compute WignerD matrices and other values for spherical harmonic calculations + self.SO3_rotation = nn.ModuleList() + for i in range(self.num_resolutions): + self.SO3_rotation.append(SO3_Rotation(self.lmax_list[i])) + + # Initialize conversion between degree l and order m layouts + self.mappingReduced = CoefficientMappingModule(self.lmax_list, self.mmax_list) + + # Initialize the transformations between spherical and grid representations + self.SO3_grid = ModuleListInfo( + f"({max(self.lmax_list)}, {max(self.lmax_list)})" + ) + for lval in range(max(self.lmax_list) + 1): + SO3_m_grid = nn.ModuleList() + for m in range(max(self.lmax_list) + 1): + SO3_m_grid.append( + SO3_Grid( + lval, + m, + resolution=self.grid_resolution, + normalization="component", + ) + ) + self.SO3_grid.append(SO3_m_grid) + + # Edge-degree embedding + self.edge_degree_embedding = EdgeDegreeEmbedding( + self.sphere_channels, + self.lmax_list, + self.mmax_list, + self.SO3_rotation, + self.mappingReduced, + self.max_num_elements, + self.edge_channels_list, + self.block_use_atom_edge_embedding, + rescale_factor=self.avg_degree, + ) + + # Initialize the blocks for each layer of EquiformerV2 + self.blocks = nn.ModuleList() + for _ in range(self.num_layers): + block = TransBlockV2( + self.sphere_channels, + self.attn_hidden_channels, + self.num_heads, + self.attn_alpha_channels, + self.attn_value_channels, + self.ffn_hidden_channels, + self.sphere_channels, + self.lmax_list, + self.mmax_list, + self.SO3_rotation, + self.mappingReduced, + self.SO3_grid, + self.max_num_elements, + self.edge_channels_list, + 
self.block_use_atom_edge_embedding, + self.use_m_share_rad, + self.attn_activation, + self.use_s2_act_attn, + self.use_attn_renorm, + self.ffn_activation, + self.use_gate_act, + self.use_grid_mlp, + self.use_sep_s2_act, + self.norm_type, + self.alpha_drop, + self.drop_path_rate, + self.proj_drop, + ) + self.blocks.append(block) + + # Output blocks for energy and forces + self.norm = get_normalization_layer( + self.norm_type, + lmax=max(self.lmax_list), + num_channels=self.sphere_channels, + ) + self.energy_block = FeedForwardNetwork( + self.sphere_channels, + self.ffn_hidden_channels, + 1, + self.lmax_list, + self.mmax_list, + self.SO3_grid, + self.ffn_activation, + self.use_gate_act, + self.use_grid_mlp, + self.use_sep_s2_act, + ) + if self.regress_forces: + self.force_block = SO2EquivariantGraphAttention( + self.sphere_channels, + self.attn_hidden_channels, + self.num_heads, + self.attn_alpha_channels, + self.attn_value_channels, + 1, + self.lmax_list, + self.mmax_list, + self.SO3_rotation, + self.mappingReduced, + self.SO3_grid, + self.max_num_elements, + self.edge_channels_list, + self.block_use_atom_edge_embedding, + self.use_m_share_rad, + self.attn_activation, + self.use_s2_act_attn, + self.use_attn_renorm, + self.use_gate_act, + self.use_sep_s2_act, + alpha_drop=0.0, + ) + + if self.load_energy_lin_ref: + self.energy_lin_ref = nn.Parameter( + torch.zeros(self.max_num_elements), + requires_grad=False, + ) + + self.apply(self._init_weights) + self.apply(self._uniform_init_rad_func_linear_weights) + + def _init_gp_partitions( + self, + atomic_numbers_full, + data_batch_full, + edge_index, + edge_distance, + edge_distance_vec, + ): + """Graph Parallel + This creates the required partial tensors for each rank given the full tensors. + The tensors are split on the dimension along the node index using node_partition. + """ + node_partition = gp_utils.scatter_to_model_parallel_region( + torch.arange(len(atomic_numbers_full)).to(self.device) + ) + edge_partition = torch.where( + torch.logical_and( + edge_index[1] >= node_partition.min(), + edge_index[1] <= node_partition.max(), # TODO: 0 or 1? 
+ ) + )[0] + edge_index = edge_index[:, edge_partition] + edge_distance = edge_distance[edge_partition] + edge_distance_vec = edge_distance_vec[edge_partition] + atomic_numbers = atomic_numbers_full[node_partition] + data_batch = data_batch_full[node_partition] + node_offset = node_partition.min().item() + return ( + atomic_numbers, + data_batch, + node_offset, + edge_index, + edge_distance, + edge_distance_vec, + ) + + @conditional_grad(torch.enable_grad()) + def forward(self, data): + self.batch_size = len(data.natoms) + self.dtype = data.pos.dtype + self.device = data.pos.device + atomic_numbers = data.atomic_numbers.long() + graph = self.generate_graph( + data, + enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly, + ) + + data_batch = data.batch + if gp_utils.initialized(): + ( + atomic_numbers, + data_batch, + node_offset, + edge_index, + edge_distance, + edge_distance_vec, + ) = self._init_gp_partitions( + graph.atomic_numbers_full, + graph.batch_full, + graph.edge_index, + graph.edge_distance, + graph.edge_distance_vec, + ) + graph.node_offset = node_offset + graph.edge_index = edge_index + graph.edge_distance = edge_distance + graph.edge_distance_vec = edge_distance_vec + + ############################################################### + # Entering Graph Parallel Region + # after this point, if using gp, then node, edge tensors are split + # across the graph parallel ranks, some full tensors such as + # atomic_numbers_full are required because we need to index into the + # full graph when computing edge embeddings or reducing nodes from neighbors + # + # all tensors that do not have the suffix "_full" refer to the partial tensors. + # if not using gp, the full values are equal to the partial values + # ie: atomic_numbers_full == atomic_numbers + ############################################################### + + ############################################################### + # Initialize data structures + ############################################################### + + # Compute 3x3 rotation matrix per edge + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) + + # Initialize the WignerD matrices and other values for spherical harmonic calculations + for i in range(self.num_resolutions): + self.SO3_rotation[i].set_wigner(edge_rot_mat) + + ############################################################### + # Initialize node embeddings + ############################################################### + + # Init per node representations using an atomic number based embedding + x = SO3_Embedding( + len(atomic_numbers), + self.lmax_list, + self.sphere_channels, + self.device, + self.dtype, + ) + + offset_res = 0 + offset = 0 + # Initialize the l = 0, m = 0 coefficients for each resolution + for i in range(self.num_resolutions): + if self.num_resolutions == 1: + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers) + else: + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[ + :, offset : offset + self.sphere_channels + ] + offset = offset + self.sphere_channels + offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) + + # Edge encoding (distance and atom edge) + graph.edge_distance = self.distance_expansion(graph.edge_distance) + if self.share_atom_edge_embedding and self.use_atom_edge_embedding: + source_element = graph.atomic_numbers_full[ + graph.edge_index[0] + ] # Source atom atomic number + target_element = graph.atomic_numbers_full[ + graph.edge_index[1] + ] # Target atom 
atomic number + source_embedding = self.source_embedding(source_element) + target_embedding = self.target_embedding(target_element) + graph.edge_distance = torch.cat( + (graph.edge_distance, source_embedding, target_embedding), dim=1 + ) + + # Edge-degree embedding + edge_degree = self.edge_degree_embedding( + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + len(atomic_numbers), + graph.node_offset, + ) + x.embedding = x.embedding + edge_degree.embedding + + ############################################################### + # Update spherical node embeddings + ############################################################### + + for i in range(self.num_layers): + x = self.blocks[i]( + x, # SO3_Embedding + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + batch=data_batch, # for GraphDropPath + node_offset=graph.node_offset, + ) + + # Final layer norm + x.embedding = self.norm(x.embedding) + + ############################################################### + # Energy estimation + ############################################################### + node_energy = self.energy_block(x) + node_energy = node_energy.embedding.narrow(1, 0, 1) + if gp_utils.initialized(): + node_energy = gp_utils.gather_from_model_parallel_region(node_energy, dim=0) + energy = torch.zeros( + len(data.natoms), + device=node_energy.device, + dtype=node_energy.dtype, + ) + energy.index_add_(0, graph.batch_full, node_energy.view(-1)) + energy = energy / self.avg_num_nodes + + # Add the per-atom linear references to the energy. + if self.use_energy_lin_ref and self.load_energy_lin_ref: + # During training, target E = (E_DFT - E_ref - E_mean) / E_std, and + # during inference, \hat{E_DFT} = \hat{E} * E_std + E_ref + E_mean + # where + # + # E_DFT = raw DFT energy, + # E_ref = reference energy, + # E_mean = normalizer mean, + # E_std = normalizer std, + # \hat{E} = predicted energy, + # \hat{E_DFT} = predicted DFT energy. + # + # We can also write this as + # \hat{E_DFT} = E_std * (\hat{E} + E_ref / E_std) + E_mean, + # which is why we save E_ref / E_std as the linear reference. 
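+ # Illustrative example (made-up numbers, not from any real checkpoint):
+ # with E_std = 2.0, E_mean = 0.5 and a stored reference E_ref / E_std = 1.5
+ # (so E_ref = 3.0), a prediction \hat{E} = 0.25 denormalizes to
+ # \hat{E_DFT} = 2.0 * (0.25 + 1.5) + 0.5 = 4.0,
+ # which matches E_std * \hat{E} + E_ref + E_mean = 0.5 + 3.0 + 0.5 = 4.0.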
+ with torch.cuda.amp.autocast(False): + energy = energy.to(self.energy_lin_ref.dtype).index_add( + 0, + graph.batch_full, + self.energy_lin_ref[graph.atomic_numbers_full], + ) + + outputs = {"energy": energy} + ############################################################### + # Force estimation + ############################################################### + if self.regress_forces: + forces = self.force_block( + x, + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + node_offset=graph.node_offset, + ) + forces = forces.embedding.narrow(1, 1, 3) + forces = forces.view(-1, 3).contiguous() + if gp_utils.initialized(): + forces = gp_utils.gather_from_model_parallel_region(forces, dim=0) + outputs["forces"] = forces + + return outputs + + # Initialize the edge rotation matrics + def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec): + return init_edge_rot_mat(edge_distance_vec) + + @property + def num_params(self): + return sum(p.numel() for p in self.parameters()) + + def _init_weights(self, m): + if isinstance(m, (torch.nn.Linear, SO3_LinearV2)): + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + if self.weight_init == "normal": + std = 1 / math.sqrt(m.in_features) + torch.nn.init.normal_(m.weight, 0, std) + + elif isinstance(m, torch.nn.LayerNorm): + torch.nn.init.constant_(m.bias, 0) + torch.nn.init.constant_(m.weight, 1.0) + + def _uniform_init_rad_func_linear_weights(self, m): + if isinstance(m, RadialFunction): + m.apply(self._uniform_init_linear_weights) + + def _uniform_init_linear_weights(self, m): + if isinstance(m, torch.nn.Linear): + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + std = 1 / math.sqrt(m.in_features) + torch.nn.init.uniform_(m.weight, -std, std) + + @torch.jit.ignore + def no_weight_decay(self) -> set: + no_wd_list = [] + named_parameters_list = [name for name, _ in self.named_parameters()] + for module_name, module in self.named_modules(): + if isinstance( + module, + ( + torch.nn.Linear, + SO3_LinearV2, + torch.nn.LayerNorm, + EquivariantLayerNormArray, + EquivariantLayerNormArraySphericalHarmonics, + EquivariantRMSNormArraySphericalHarmonics, + EquivariantRMSNormArraySphericalHarmonicsV2, + GaussianRadialBasisLayer, + ), + ): + for parameter_name, _ in module.named_parameters(): + if ( + isinstance(module, (torch.nn.Linear, SO3_LinearV2)) + and "weight" in parameter_name + ): + continue + global_parameter_name = module_name + "." 
+ parameter_name + assert global_parameter_name in named_parameters_list + no_wd_list.append(global_parameter_name) + + return set(no_wd_list) diff --git a/tests/core/models/__snapshots__/test_equiformer_v2.ambr b/tests/core/models/__snapshots__/test_equiformer_v2_deprecated.ambr similarity index 84% rename from tests/core/models/__snapshots__/test_equiformer_v2.ambr rename to tests/core/models/__snapshots__/test_equiformer_v2_deprecated.ambr index 03be8ebda..d374d616e 100644 --- a/tests/core/models/__snapshots__/test_equiformer_v2.ambr +++ b/tests/core/models/__snapshots__/test_equiformer_v2_deprecated.ambr @@ -6,7 +6,7 @@ # --- # name: TestEquiformerV2.test_ddp.1 Approx( - array([0.12408741], dtype=float32), + array([0.12408739], dtype=float32), rtol=0.001, atol=0.001 ) @@ -19,7 +19,7 @@ # --- # name: TestEquiformerV2.test_ddp.3 Approx( - array([ 1.4928594e-03, -7.4167736e-05, 2.9909366e-03], dtype=float32), + array([ 1.4928584e-03, -7.4167408e-05, 2.9909366e-03], dtype=float32), rtol=0.001, atol=0.001 ) @@ -31,7 +31,7 @@ # --- # name: TestEquiformerV2.test_energy_force_shape.1 Approx( - array([0.12408741], dtype=float32), + array([0.12408739], dtype=float32), rtol=0.001, atol=0.001 ) @@ -44,7 +44,7 @@ # --- # name: TestEquiformerV2.test_energy_force_shape.3 Approx( - array([ 1.4928594e-03, -7.4167736e-05, 2.9909366e-03], dtype=float32), + array([ 1.4928584e-03, -7.4167408e-05, 2.9909366e-03], dtype=float32), rtol=0.001, atol=0.001 ) diff --git a/tests/core/models/test_equiformer_v2.py b/tests/core/models/test_equiformer_v2.py index 2f0903608..1abe78a35 100644 --- a/tests/core/models/test_equiformer_v2.py +++ b/tests/core/models/test_equiformer_v2.py @@ -7,25 +7,14 @@ from __future__ import annotations -import copy -import io import os from pathlib import Path -import pytest -import requests import torch import yaml from ase.io import read -from torch.nn.parallel.distributed import DistributedDataParallel from fairchem.core.common.registry import registry -from fairchem.core.common.test_utils import ( - PGConfig, - init_pg_and_rank_and_launch_test, - spawn_multi_process, -) -from fairchem.core.common.utils import load_state_dict, setup_imports from fairchem.core.datasets import data_list_collater from fairchem.core.models.equiformer_v2.so3 import ( CoefficientMappingModule, @@ -34,139 +23,6 @@ from fairchem.core.preprocessing import AtomsToGraphs -@pytest.fixture(scope="class") -def load_data(request): - atoms = read( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "atoms.json"), - index=0, - format="json", - ) - a2g = AtomsToGraphs( - max_neigh=200, - radius=6, - r_edges=False, - r_fixed=True, - ) - data_list = a2g.convert_all([atoms]) - request.cls.data = data_list[0] - - -def _load_model(): - torch.manual_seed(4) - setup_imports() - - # download and load weights. 
- checkpoint_url = "https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_06/oc20/s2ef/eq2_31M_ec4_allmd.pt" - - # load buffer into memory as a stream - # and then load it with torch.load - r = requests.get(checkpoint_url, stream=True) - r.raise_for_status() - checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu")) - - model = registry.get_model_class("equiformer_v2")( - use_pbc=True, - regress_forces=True, - otf_graph=True, - max_neighbors=20, - max_radius=12.0, - max_num_elements=90, - num_layers=8, - sphere_channels=128, - attn_hidden_channels=64, - num_heads=8, - attn_alpha_channels=64, - attn_value_channels=16, - ffn_hidden_channels=128, - norm_type="layer_norm_sh", - lmax_list=[4], - mmax_list=[2], - grid_resolution=18, - num_sphere_samples=128, - edge_channels=128, - use_atom_edge_embedding=True, - distance_function="gaussian", - num_distance_basis=512, - attn_activation="silu", - use_s2_act_attn=False, - ffn_activation="silu", - use_gate_act=False, - use_grid_mlp=True, - alpha_drop=0.1, - drop_path_rate=0.1, - proj_drop=0.0, - weight_init="uniform", - ) - - new_dict = {k[len("module.") * 2 :]: v for k, v in checkpoint["state_dict"].items()} - load_state_dict(model, new_dict) - - # Precision errors between mac vs. linux compound with multiple layers, - # so we explicitly set the number of layers to 1 (instead of all 8). - # The other alternative is to have different snapshots for mac vs. linux. - model.num_layers = 1 - return model - - -@pytest.fixture(scope="class") -def load_model(request): - request.cls.model = _load_model() - - -def _runner(data): - # serializing the model through python multiprocess results in precision errors, so we get a fresh model here - model = _load_model() - ddp_model = DistributedDataParallel(model) - outputs = ddp_model(data_list_collater([data])) - return {k: v.detach() for k, v in outputs.items()} - - -@pytest.mark.usefixtures("load_data") -@pytest.mark.usefixtures("load_model") -class TestEquiformerV2: - def test_energy_force_shape(self, snapshot): - # Recreate the Data object to only keep the necessary features. - data = self.data - model = copy.deepcopy(self.model) - - # Pass it through the model. 
- outputs = model(data_list_collater([data])) - print(outputs) - energy, forces = outputs["energy"], outputs["forces"] - - assert snapshot == energy.shape - assert snapshot == pytest.approx(energy.detach()) - - assert snapshot == forces.shape - assert snapshot == pytest.approx(forces.detach().mean(0)) - - def test_ddp(self, snapshot): - data_dist = self.data.clone().detach() - config = PGConfig(backend="gloo", world_size=1, gp_group_size=1, use_gp=False) - output = spawn_multi_process( - config, _runner, init_pg_and_rank_and_launch_test, data_dist - ) - assert len(output) == 1 - energy, forces = output[0]["energy"], output[0]["forces"] - assert snapshot == energy.shape - assert snapshot == pytest.approx(energy.detach()) - assert snapshot == forces.shape - assert snapshot == pytest.approx(forces.detach().mean(0)) - - def test_gp(self, snapshot): - data_dist = self.data.clone().detach() - config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) - output = spawn_multi_process( - config, _runner, init_pg_and_rank_and_launch_test, data_dist - ) - assert len(output) == 2 - energy, forces = output[0]["energy"], output[0]["forces"] - assert snapshot == energy.shape - assert snapshot == pytest.approx(energy.detach()) - assert snapshot == forces.shape - assert snapshot == pytest.approx(forces.detach().mean(0)) - - class TestMPrimaryLPrimary: def test_mprimary_lprimary_mappings(self): def sign(x): @@ -236,12 +92,17 @@ def sign(x): def _load_hydra_model(): torch.manual_seed(4) - with open(Path("tests/core/models/test_configs/test_equiformerv2_hydra.yml")) as yaml_file: + with open( + Path("tests/core/models/test_configs/test_equiformerv2_hydra.yml") + ) as yaml_file: yaml_config = yaml.safe_load(yaml_file) - model = registry.get_model_class("hydra")(yaml_config["model"]["backbone"],yaml_config["model"]["heads"]) + model = registry.get_model_class("hydra")( + yaml_config["model"]["backbone"], yaml_config["model"]["heads"] + ) model.backbone.num_layers = 1 return model + def test_eqv2_hydra_activation_checkpoint(): atoms = read( os.path.join(os.path.dirname(os.path.abspath(__file__)), "atoms.json"), @@ -258,7 +119,7 @@ def test_eqv2_hydra_activation_checkpoint(): inputs = data_list_collater(data_list) no_ac_model = _load_hydra_model() ac_model = _load_hydra_model() - ac_model.backbone.activation_checkpoint=True + ac_model.backbone.activation_checkpoint = True # to do this test we need both models to have the exact same state and the only # way to do this is save the rng state and reset it after stepping the first model @@ -272,7 +133,13 @@ def test_eqv2_hydra_activation_checkpoint(): torch.autograd.backward(outptuts_ac["energy"]["energy"].sum() + outptuts_ac["forces"]["forces"].sum()) # assert all the gradients are identical between the model with checkpointing and no checkpointing - ac_model_grad_dict = {name:p.grad for name, p in ac_model.named_parameters() if p.grad is not None} - no_ac_model_grad_dict = {name:p.grad for name, p in no_ac_model.named_parameters() if p.grad is not None} + ac_model_grad_dict = { + name: p.grad for name, p in ac_model.named_parameters() if p.grad is not None + } + no_ac_model_grad_dict = { + name: p.grad for name, p in no_ac_model.named_parameters() if p.grad is not None + } for name in no_ac_model_grad_dict: - assert torch.allclose(no_ac_model_grad_dict[name], ac_model_grad_dict[name], atol=1e-4) + assert torch.allclose( + no_ac_model_grad_dict[name], ac_model_grad_dict[name], atol=1e-4 + ) diff --git 
a/tests/core/models/test_equiformer_v2_deprecated.py b/tests/core/models/test_equiformer_v2_deprecated.py new file mode 100644 index 000000000..a42257c65 --- /dev/null +++ b/tests/core/models/test_equiformer_v2_deprecated.py @@ -0,0 +1,161 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import copy +import io +import os + +import pytest +import requests +import torch +from ase.io import read +from torch.nn.parallel.distributed import DistributedDataParallel + +from fairchem.core.common.registry import registry +from fairchem.core.common.test_utils import ( + PGConfig, + init_pg_and_rank_and_launch_test, + spawn_multi_process, +) +from fairchem.core.common.utils import load_state_dict, setup_imports +from fairchem.core.datasets import data_list_collater +from fairchem.core.preprocessing import AtomsToGraphs + + +@pytest.fixture(scope="class") +def load_data(request): + atoms = read( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "atoms.json"), + index=0, + format="json", + ) + a2g = AtomsToGraphs( + max_neigh=200, + radius=6, + r_edges=False, + r_fixed=True, + ) + data_list = a2g.convert_all([atoms]) + request.cls.data = data_list[0] + + +def _load_model(): + torch.manual_seed(4) + setup_imports() + + # download and load weights. + checkpoint_url = "https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_06/oc20/s2ef/eq2_31M_ec4_allmd.pt" + + # load buffer into memory as a stream + # and then load it with torch.load + r = requests.get(checkpoint_url, stream=True) + r.raise_for_status() + checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu")) + + model = registry.get_model_class("equiformer_v2")( + use_pbc=True, + regress_forces=True, + otf_graph=True, + max_neighbors=20, + max_radius=12.0, + max_num_elements=90, + num_layers=8, + sphere_channels=128, + attn_hidden_channels=64, + num_heads=8, + attn_alpha_channels=64, + attn_value_channels=16, + ffn_hidden_channels=128, + norm_type="layer_norm_sh", + lmax_list=[4], + mmax_list=[2], + grid_resolution=18, + num_sphere_samples=128, + edge_channels=128, + use_atom_edge_embedding=True, + distance_function="gaussian", + num_distance_basis=512, + attn_activation="silu", + use_s2_act_attn=False, + ffn_activation="silu", + use_gate_act=False, + use_grid_mlp=True, + alpha_drop=0.1, + drop_path_rate=0.1, + proj_drop=0.0, + weight_init="uniform", + ) + + new_dict = {k[len("module.") * 2 :]: v for k, v in checkpoint["state_dict"].items()} + load_state_dict(model, new_dict) + + # Precision errors between mac vs. linux compound with multiple layers, + # so we explicitly set the number of layers to 1 (instead of all 8). + # The other alternative is to have different snapshots for mac vs. linux. 
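+ # Note: reassigning num_layers after loading only shortens the forward loop over
+ # self.blocks; the weights for all 8 layers were already loaded above, so the
+ # checkpoint still loads without key mismatches.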
+ model.num_layers = 1 + return model + + +@pytest.fixture(scope="class") +def load_model(request): + request.cls.model = _load_model() + + +def _runner(data): + # serializing the model through python multiprocess results in precision errors, so we get a fresh model here + model = _load_model() + ddp_model = DistributedDataParallel(model) + outputs = ddp_model(data_list_collater([data])) + return {k: v.detach() for k, v in outputs.items()} + + +@pytest.mark.usefixtures("load_data") +@pytest.mark.usefixtures("load_model") +class TestEquiformerV2: + def test_energy_force_shape(self, snapshot): + # Recreate the Data object to only keep the necessary features. + data = self.data + model = copy.deepcopy(self.model) + + # Pass it through the model. + outputs = model(data_list_collater([data])) + print(outputs) + energy, forces = outputs["energy"], outputs["forces"] + + assert snapshot == energy.shape + assert snapshot == pytest.approx(energy.detach()) + + assert snapshot == forces.shape + assert snapshot == pytest.approx(forces.detach().mean(0)) + + def test_ddp(self, snapshot): + data_dist = self.data.clone().detach() + config = PGConfig(backend="gloo", world_size=1, gp_group_size=1, use_gp=False) + output = spawn_multi_process( + config, _runner, init_pg_and_rank_and_launch_test, data_dist + ) + assert len(output) == 1 + energy, forces = output[0]["energy"], output[0]["forces"] + assert snapshot == energy.shape + assert snapshot == pytest.approx(energy.detach()) + assert snapshot == forces.shape + assert snapshot == pytest.approx(forces.detach().mean(0)) + + def test_gp(self, snapshot): + data_dist = self.data.clone().detach() + config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) + output = spawn_multi_process( + config, _runner, init_pg_and_rank_and_launch_test, data_dist + ) + assert len(output) == 2 + energy, forces = output[0]["energy"], output[0]["forces"] + assert snapshot == energy.shape + assert snapshot == pytest.approx(energy.detach()) + assert snapshot == forces.shape + assert snapshot == pytest.approx(forces.detach().mean(0)) From 3899aac954032c3eaa433907d5411d2f3f6eea24 Mon Sep 17 00:00:00 2001 From: Misko Date: Tue, 20 Aug 2024 19:58:22 -0700 Subject: [PATCH 3/6] FM-v4 branch into main (#752) * Update BalancedBatchSampler to use datasets' `data_sizes` method Replace BalancedBatchSampler's `force_balancing` and `throw_on_error` parameters with `on_error` * Remove python 3.10 syntax * Documentation * Added set_epoch method * Format * Changed "resolved dataset" message to be a debug log to reduce log spam * Minor changes to support multitask * add in pickle data set; add in stat functions for combining mean and variance * checksums for equiformer * detach compute metrics and add checksum function for linear layer * change name to dataset_configs * add seed option * remove pickle dataset * remove pickle dataset * add experimental datatransform to ase_dataset * clean up batchsampler and tests * base dataset class * move lin_ref to base dataset * inherit basedataset for ase dataset * filter indices prop * updated import for ase dataset * added create_dataset fn * yaml load fix * create dataset function instead of filtering in base * remove filtered_indices * make create_dataset and LMDBDatabase importable from datasets * create_dataset cleanup * test create_dataset * use metadata.natoms directly and add it to subset * use self.indices to handle shard * rename _data_sizes * fix Subset of metadata * fix up to be mergeable * merge in monorepo * small fix for 
import and keyerror * minor change to metadata, added full path option * import updates * minor fix to base dataset * skip force_balance and seed * adding get_metadata to base_dataset * implement get_metadata for datasets; add tests for max_atoms and balanced partitioning * a[:len(a)+1] does not throw error, change to check for this * bug fix for base_dataset * max atoms branch * fix typo * do pbc per system * add option to use single system pbc * add multiple mapping * lint and github workflow fixes * track parent checkpoint for logger grouping * add generator to basedataset * check path relative to yaml file * add load and exit flag to base_trainer * add in merge mean and std code to utils * add log when passing through mean or computing; check other paths for includes * add qos flag * use slurm_qos instead of qos * fix includes * fix set init * adding new notebook for using fairchem models with NEBs without CatTSunami enumeration (#764) * adding new notebook for using fairchem models with NEBs * adding md tutorials * blocking code cells that arent needed or take too long * remove files with diff whitespace * add resolution flag to escn * try to revert oxides * revert typing * remove white space * extra line never reached * move out of fmv4 into dev * move avg num nodes * optional import from experimental * fix lint * add comments, refactor common trainer args in a single dictionary * add comments, refactor common trainer args in a single dictionary * remove parent --------- Co-authored-by: Nima Shoghi Co-authored-by: Nima Shoghi Co-authored-by: Abhishek Das Co-authored-by: lbluque Co-authored-by: Brandon Co-authored-by: Muhammed Shuaibi Co-authored-by: Ray Gao Co-authored-by: Brook Wander <73855115+brookwander@users.noreply.github.com> Co-authored-by: Muhammed Shuaibi <45150244+mshuaibii@users.noreply.github.com> --- src/fairchem/core/common/utils.py | 149 +++++++++++++++------ src/fairchem/core/datasets/_utils.py | 21 ++- src/fairchem/core/modules/loss.py | 3 +- src/fairchem/core/modules/transforms.py | 6 + src/fairchem/core/trainers/base_trainer.py | 34 +++-- tests/core/common/test_yaml_loader.py | 37 ++++- 6 files changed, 187 insertions(+), 63 deletions(-) diff --git a/src/fairchem/core/common/utils.py b/src/fairchem/core/common/utils.py index e762dfeb5..955ea1e06 100644 --- a/src/fairchem/core/common/utils.py +++ b/src/fairchem/core/common/utils.py @@ -92,6 +92,15 @@ def save_checkpoint( return filename +multitask_required_keys = { + "tasks", + "datasets", + "combined_dataset", + "model", + "optim", +} + + class Complete: def __call__(self, data): device = data.edge_index.device @@ -393,48 +402,83 @@ def create_dict_from_args(args: list, sep: str = "."): return return_dict -def load_config(path: str, previous_includes: list | None = None): - if previous_includes is None: - previous_includes = [] +# given a filename and set of paths , return the full file path +def find_relative_file_in_paths(filename, include_paths): + if os.path.exists(filename): + return filename + for path in include_paths: + include_filename = os.path.join(path, filename) + if os.path.exists(include_filename): + return include_filename + raise ValueError(f"Cannot find include YML {filename}") + + +def load_config( + path: str, + files_previously_included: list | None = None, + include_paths: list | None = None, +): + """ + Load a given config with any defined imports + + When imports are present this is a recursive function called on imports. 
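+ For example, a config containing
+     includes:
+         - base.yml
+ will have base.yml (a hypothetical include) loaded first and merged underneath the
+ current file's own keys; each include is resolved as given, then relative to the
+ including file's directory, then against any extra include_paths.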
+ To prevent any cyclic imports we keep track of already imported yml files + using files_previously_included + """ + if include_paths is None: + include_paths = [] + if files_previously_included is None: + files_previously_included = [] path = Path(path) - if path in previous_includes: + if path in files_previously_included: raise ValueError( - f"Cyclic config include detected. {path} included in sequence {previous_includes}." + f"Cyclic config include detected. {path} included in sequence {files_previously_included}." ) - previous_includes = [*previous_includes, path] + files_previously_included = [*files_previously_included, path] with open(path) as fp: - direct_config = yaml.load(fp, Loader=UniqueKeyLoader) + current_config = yaml.load(fp, Loader=UniqueKeyLoader) # Load config from included files. - includes = direct_config.pop("includes") if "includes" in direct_config else [] - if not isinstance(includes, list): - raise AttributeError(f"Includes must be a list, '{type(includes)}' provided") + includes_listed_in_config = ( + current_config.pop("includes") if "includes" in current_config else [] + ) + if not isinstance(includes_listed_in_config, list): + raise AttributeError( + f"Includes must be a list, '{type(includes_listed_in_config)}' provided" + ) - config = {} + config_from_includes = {} duplicates_warning = [] duplicates_error = [] - - for include in includes: + for include in includes_listed_in_config: + include_filename = find_relative_file_in_paths( + include, [os.path.dirname(path), *include_paths] + ) include_config, inc_dup_warning, inc_dup_error = load_config( - include, previous_includes + include_filename, files_previously_included ) duplicates_warning += inc_dup_warning duplicates_error += inc_dup_error # Duplicates between includes causes an error - config, merge_dup_error = merge_dicts(config, include_config) + config_from_includes, merge_dup_error = merge_dicts( + config_from_includes, include_config + ) duplicates_error += merge_dup_error # Duplicates between included and main file causes warnings - config, merge_dup_warning = merge_dicts(config, direct_config) + config_from_includes, merge_dup_warning = merge_dicts( + config_from_includes, current_config + ) duplicates_warning += merge_dup_warning + return config_from_includes, duplicates_warning, duplicates_error - return config, duplicates_warning, duplicates_error - -def build_config(args, args_override): - config, duplicates_warning, duplicates_error = load_config(args.config_yml) +def build_config(args, args_override, include_paths=None): + config, duplicates_warning, duplicates_error = load_config( + args.config_yml, include_paths=include_paths + ) if len(duplicates_warning) > 0: logging.warning( f"Overwritten config parameters from included configs " @@ -999,34 +1043,53 @@ class _TrainingContext: task_name = "s2ef" elif trainer_name in ["energy", "equiformerv2_energy"]: task_name = "is2re" + elif "multitask" in trainer_name: + task_name = "multitask" else: task_name = "ocp" trainer_cls = registry.get_trainer_class(trainer_name) assert trainer_cls is not None, "Trainer not found" - trainer = trainer_cls( - task=config.get("task", {}), - model=config["model"], - outputs=config.get("outputs", {}), - dataset=config["dataset"], - optimizer=config["optim"], - loss_functions=config.get("loss_functions", {}), - evaluation_metrics=config.get("evaluation_metrics", {}), - identifier=config["identifier"], - timestamp_id=config.get("timestamp_id", None), - run_dir=config.get("run_dir", "./"), - 
is_debug=config.get("is_debug", False), - print_every=config.get("print_every", 10), - seed=config.get("seed", 0), - logger=config.get("logger", "wandb"), - local_rank=config["local_rank"], - amp=config.get("amp", False), - cpu=config.get("cpu", False), - slurm=config.get("slurm", {}), - noddp=config.get("noddp", False), - name=task_name, - gp_gpus=config.get("gp_gpus"), - ) + + trainer_config = { + "model": config["model"], + "optimizer": config["optim"], + "identifier": config["identifier"], + "timestamp_id": config.get("timestamp_id", None), + "run_dir": config.get("run_dir", "./"), + "is_debug": config.get("is_debug", False), + "print_every": config.get("print_every", 10), + "seed": config.get("seed", 0), + "logger": config.get("logger", "wandb"), + "local_rank": config["local_rank"], + "amp": config.get("amp", False), + "cpu": config.get("cpu", False), + "slurm": config.get("slurm", {}), + "noddp": config.get("noddp", False), + "name": task_name, + "gp_gpus": config.get("gp_gpus"), + } + + if task_name == "multitask": + trainer_config.update( + { + "tasks": config.get("tasks", {}), + "dataset_configs": config["datasets"], + "combined_dataset_config": config.get("combined_dataset", {}), + "evaluations": config.get("evaluations", {}), + } + ) + else: + trainer_config.update( + { + "task": config.get("task", {}), + "outputs": config.get("outputs", {}), + "dataset": config["dataset"], + "loss_functions": config.get("loss_functions", {}), + "evaluation_metrics": config.get("evaluation_metrics", {}), + } + ) + trainer = trainer_cls(**trainer_config) task_cls = registry.get_task_class(config["mode"]) assert task_cls is not None, "Task not found" diff --git a/src/fairchem/core/datasets/_utils.py b/src/fairchem/core/datasets/_utils.py index 7572eb3ca..6d5c947e2 100644 --- a/src/fairchem/core/datasets/_utils.py +++ b/src/fairchem/core/datasets/_utils.py @@ -13,19 +13,32 @@ from torch_geometric.data import Data -def rename_data_object_keys(data_object: Data, key_mapping: dict[str, str]) -> Data: +def rename_data_object_keys( + data_object: Data, key_mapping: dict[str, str | list[str]] +) -> Data: """Rename data object keys Args: data_object: data object key_mapping: dictionary specifying keys to rename and new names {prev_key: new_key} + + new_key can be a list of new keys, for example, + prev_key: energy + new_key: [common_energy, oc20_energy] + + This is currently required when we use a single target/label for multiple tasks """ for _property in key_mapping: # catch for test data not containing labels if _property in data_object: - new_property = key_mapping[_property] - if new_property not in data_object: + list_of_new_keys = key_mapping[_property] + if isinstance(list_of_new_keys, str): + list_of_new_keys = [list_of_new_keys] + for new_property in list_of_new_keys: + if new_property == _property: + continue + assert new_property not in data_object data_object[new_property] = data_object[_property] + if _property not in list_of_new_keys: del data_object[_property] - return data_object diff --git a/src/fairchem/core/modules/loss.py b/src/fairchem/core/modules/loss.py index e818704bf..b737e79e6 100644 --- a/src/fairchem/core/modules/loss.py +++ b/src/fairchem/core/modules/loss.py @@ -20,7 +20,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor): return torch.mean(dists) elif self.reduction == "sum": return torch.sum(dists) - return None + + return dists class AtomwiseL2Loss(nn.Module): diff --git a/src/fairchem/core/modules/transforms.py b/src/fairchem/core/modules/transforms.py 
index 52675fd28..50f59fd8b 100644 --- a/src/fairchem/core/modules/transforms.py +++ b/src/fairchem/core/modules/transforms.py @@ -8,6 +8,12 @@ if TYPE_CHECKING: from torch_geometric.data import Data +from contextlib import suppress + +with suppress(ImportError): + # TODO remove this in favor of a better solution + # We should never be importing * from a module + from fairchem.experimental.foundation_models.multi_task_dataloader.transforms.data_object import * # noqa class DataTransforms: diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index d84a2c12f..94becb924 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -13,6 +13,7 @@ import logging import os import random +import sys from abc import ABC, abstractmethod from itertools import chain from typing import TYPE_CHECKING @@ -232,6 +233,8 @@ def load(self) -> None: self.load_loss() self.load_optimizer() self.load_extras() + if self.config["optim"].get("load_datasets_and_model_then_exit", False): + sys.exit(0) def set_seed(self, seed) -> None: # https://pytorch.org/docs/stable/notes/randomness.html @@ -792,6 +795,22 @@ def update_best( disable_tqdm=disable_eval_tqdm, ) + def _aggregate_metrics(self, metrics): + aggregated_metrics = {} + for k in metrics: + aggregated_metrics[k] = { + "total": distutils.all_reduce( + metrics[k]["total"], average=False, device=self.device + ), + "numel": distutils.all_reduce( + metrics[k]["numel"], average=False, device=self.device + ), + } + aggregated_metrics[k]["metric"] = ( + aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"] + ) + return aggregated_metrics + @torch.no_grad() def validate(self, split: str = "val", disable_tqdm: bool = False): ensure_fitted(self._unwrapped_model, warn=True) @@ -833,20 +852,7 @@ def validate(self, split: str = "val", disable_tqdm: bool = False): metrics = self._compute_metrics(out, batch, evaluator, metrics) metrics = evaluator.update("loss", loss.item(), metrics) - aggregated_metrics = {} - for k in metrics: - aggregated_metrics[k] = { - "total": distutils.all_reduce( - metrics[k]["total"], average=False, device=self.device - ), - "numel": distutils.all_reduce( - metrics[k]["numel"], average=False, device=self.device - ), - } - aggregated_metrics[k]["metric"] = ( - aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"] - ) - metrics = aggregated_metrics + metrics = self._aggregate_metrics(metrics) log_dict = {k: metrics[k]["metric"] for k in metrics} log_dict.update({"epoch": self.epoch}) diff --git a/tests/core/common/test_yaml_loader.py b/tests/core/common/test_yaml_loader.py index 0e1ead3f9..e9cfdf356 100644 --- a/tests/core/common/test_yaml_loader.py +++ b/tests/core/common/test_yaml_loader.py @@ -1,11 +1,12 @@ from __future__ import annotations +import os import tempfile import pytest import yaml -from fairchem.core.common.utils import UniqueKeyLoader +from fairchem.core.common.utils import UniqueKeyLoader, load_config @pytest.fixture(scope="class") @@ -32,6 +33,14 @@ def valid_yaml_config(): """ +@pytest.fixture(scope="class") +def include_path_in_yaml_config(): + return """ +includes: + - other.yml +""" + + def test_invalid_config(invalid_yaml_config): with tempfile.NamedTemporaryFile(delete=False) as fp: fp.write(invalid_yaml_config.encode()) @@ -49,3 +58,29 @@ def test_valid_config(valid_yaml_config): with open(fname) as fp: yaml.load(fp, Loader=UniqueKeyLoader) + + +def test_load_config_with_include_path(include_path_in_yaml_config, 
valid_yaml_config): + with tempfile.TemporaryDirectory() as tempdirname: + + this_yml_path = f"{tempdirname}/this.yml" + with open(this_yml_path, "w") as fp: + fp.write(include_path_in_yaml_config) + + # the include does not exist throw an error! + with pytest.raises(ValueError): + load_config(this_yml_path) + + other_yml_path = f"{tempdirname}/subfolder" + os.mkdir(other_yml_path) + other_yml_full_filename = f"{other_yml_path}/other.yml" + with open(other_yml_full_filename, "w") as fp: + fp.write(valid_yaml_config) + + # the include does not exist throw an error! + with pytest.raises(ValueError): + load_config(this_yml_path) + + # the include does not exist throw an error! + loaded_config = load_config(this_yml_path, include_paths=[other_yml_path]) + assert set(loaded_config[0].keys()) == set(["key1", "key2"]) From df8933067743f46cc1c5aae176e3474007080527 Mon Sep 17 00:00:00 2001 From: rayg1234 <7001989+rayg1234@users.noreply.github.com> Date: Wed, 21 Aug 2024 16:04:19 -0700 Subject: [PATCH 4/6] update to use abs run_dir paths by default (#820) --- src/fairchem/core/common/flags.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/fairchem/core/common/flags.py b/src/fairchem/core/common/flags.py index ac4bd2f84..266e1e640 100644 --- a/src/fairchem/core/common/flags.py +++ b/src/fairchem/core/common/flags.py @@ -8,6 +8,7 @@ from __future__ import annotations import argparse +import os from pathlib import Path @@ -48,7 +49,7 @@ def add_core_args(self) -> None: ) self.parser.add_argument( "--run-dir", - default="./", + default=os.path.abspath("./"), type=str, help="Directory to store checkpoint/log/result directory", ) From 1bee0d71cc96e15293c49382c6ad80840ca5dc57 Mon Sep 17 00:00:00 2001 From: Misko Date: Wed, 21 Aug 2024 17:07:47 -0700 Subject: [PATCH 5/6] Add check to max num atoms (#817) * add assert for max_num_atoms * add test to make sure we are properly checking for max_num_elements * fix post merge --- .../models/equiformer_v2/equiformer_v2.py | 3 +++ src/fairchem/core/models/escn/escn.py | 3 +++ tests/core/e2e/test_s2ef.py | 20 +++++++++++++++++++ 3 files changed, 26 insertions(+) diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py index 978d4c226..61b62be16 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -397,6 +397,9 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: self.dtype = data.pos.dtype self.device = data.pos.device atomic_numbers = data.atomic_numbers.long() + assert ( + atomic_numbers.max().item() < self.max_num_elements + ), "Atomic number exceeds that given in model config" graph = self.generate_graph( data, enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly, diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index 54b1992f4..6eb95947a 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -235,6 +235,9 @@ def forward(self, data): start_time = time.time() atomic_numbers = data.atomic_numbers.long() + assert ( + atomic_numbers.max().item() < self.max_num_elements + ), "Atomic number exceeds that given in model config" num_atoms = len(atomic_numbers) graph = self.generate_graph(data) diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 6b83749c0..2f7dfa373 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -170,6 +170,26 @@ def 
test_use_pbc_single(self, configs, tutorial_val_src, torch_deterministic): input_yaml=configs["equiformer_v2"], ) + def test_max_num_atoms(self, configs, tutorial_val_src, torch_deterministic): + with tempfile.TemporaryDirectory() as tempdirname: + tempdir = Path(tempdirname) + extra_args = {"seed": 0} + with pytest.raises(AssertionError): + _ = _run_main( + rundir=str(tempdir), + update_dict_with={ + "optim": {"max_epochs": 1}, + "model": {"backbone": {"max_num_elements": 2}}, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + update_run_args_with=extra_args, + input_yaml=configs["equiformer_v2_hydra"], + ) + @pytest.mark.parametrize( ("world_size", "ddp"), [ From c2b0c304982d6e4c22e7df95101fdd6edca7a7f7 Mon Sep 17 00:00:00 2001 From: Misko Date: Wed, 21 Aug 2024 19:46:49 -0700 Subject: [PATCH 6/6] fix gemnet scaling factors fit.py and add a test (#819) * fix gemnet fit and add test * only save factors, not the whole model! --- src/fairchem/core/modules/scaling/fit.py | 64 ++++++++++------------ tests/core/e2e/test_e2e_commons.py | 19 ++++--- tests/core/e2e/test_s2ef.py | 67 +++++++++++++++++++++++- 3 files changed, 104 insertions(+), 46 deletions(-) diff --git a/src/fairchem/core/modules/scaling/fit.py b/src/fairchem/core/modules/scaling/fit.py index 63c36e8f5..4bfc4bb62 100644 --- a/src/fairchem/core/modules/scaling/fit.py +++ b/src/fairchem/core/modules/scaling/fit.py @@ -37,18 +37,9 @@ def _train_batch(trainer: BaseTrainer, batch) -> None: del out, loss -def main(*, num_batches: int = 16) -> None: - # region args/config setup - setup_logging() - - parser = flags.get_parser() - args, override_args = parser.parse_known_args() - _config = build_config(args, override_args) - _config["logger"] = "wandb" - # endregion +def compute_scaling_factors(config, num_batches: int = 16) -> None: - assert not args.distributed, "This doesn't work with DDP" - with new_trainer_context(args=args, config=_config) as ctx: + with new_trainer_context(config=config) as ctx: config = ctx.config trainer = ctx.trainer @@ -61,8 +52,8 @@ def main(*, num_batches: int = 16) -> None: logging.info(f"Input checkpoint path: {ckpt_file}, {ckpt_file.exists()=}") model: nn.Module = trainer.model - val_loader = trainer.val_loader - assert val_loader is not None, "Val dataset is required for making predictions" + data_loader = trainer.train_loader + assert data_loader is not None, "Train set required to load batches" if ckpt_file.exists(): trainer.load_checkpoint(checkpoint_path=str(ckpt_file)) @@ -122,15 +113,8 @@ def main(*, num_batches: int = 16) -> None: sys.exit(-1) # endregion - # region get the output path - out_path = Path( - _prefilled_input( - "Enter output path for fitted scale factors: ", - prefill=str(ckpt_file), - ) - ) - if out_path.exists(): - logging.warning(f"Already found existing file: {out_path}") + if ckpt_file.exists(): + logging.warning(f"Already found existing file: {ckpt_file}") flag = input( "Do you want to continue and overwrite existing file (1), " "or exit (2)? 
" @@ -142,7 +126,7 @@ def main(*, num_batches: int = 16) -> None: sys.exit() logging.info( - f"Output path for fitted scale factors: {out_path}, {out_path.exists()=}" + f"Output path for fitted scale factors: {ckpt_file}, {ckpt_file.exists()=}" ) # endregion @@ -175,7 +159,7 @@ def index_fn(name: str = name) -> None: module.initialize_(index_fn=index_fn) # single pass through network - _train_batch(trainer, next(iter(val_loader))) + _train_batch(trainer, next(iter(data_loader))) # sort the scale factors by their computation order sorted_factors = sorted( @@ -200,7 +184,7 @@ def index_fn(name: str = name) -> None: logging.info(f"Fitting {name}...") with module.fit_context_(): - for batch in islice(val_loader, num_batches): + for batch in islice(data_loader, num_batches): _train_batch(trainer, batch) stats, ratio, value = module.fit_() @@ -216,19 +200,27 @@ def index_fn(name: str = name) -> None: assert module.fitted, f"{name} is not fitted" # region save the scale factors to the checkpoint file - trainer.config["cmd"]["checkpoint_dir"] = out_path.parent + trainer.config["cmd"]["checkpoint_dir"] = ckpt_file.parent trainer.is_debug = False - out_file = trainer.save( - metrics=None, - checkpoint_file=out_path.name, - training_state=False, + + torch.save( + { + x[0].replace(".scale_factor", ""): x[1] + for x in trainer.model.to("cpu").named_parameters() + if ".scale_" in x[0] + }, + str(ckpt_file), ) - assert out_file is not None, "Failed to save checkpoint" - out_file = Path(out_file) - assert out_file.exists(), f"Failed to save checkpoint to {out_file}" - # endregion - logging.info(f"Saved results to: {out_file}") + logging.info(f"Saved results to: {ckpt_file}") if __name__ == "__main__": - main() + # region args/config setup + setup_logging() + + parser = flags.get_parser() + args, override_args = parser.parse_known_args() + assert not args.distributed, "This doesn't work with DDP" + config = build_config(args, override_args) + + compute_scaling_factors(config) diff --git a/tests/core/e2e/test_e2e_commons.py b/tests/core/e2e/test_e2e_commons.py index ff3ea3634..ef2b860bf 100644 --- a/tests/core/e2e/test_e2e_commons.py +++ b/tests/core/e2e/test_e2e_commons.py @@ -93,6 +93,16 @@ def merge_dictionary(d, u): return d +def update_yaml_with_dict(input_yaml, output_yaml, update_dict_with): + with open(input_yaml) as yaml_file: + yaml_config = yaml.safe_load(yaml_file) + if update_dict_with is not None: + yaml_config = merge_dictionary(yaml_config, update_dict_with) + yaml_config["backend"] = "gloo" + with open(str(output_yaml), "w") as yaml_file: + yaml.dump(yaml_config, yaml_file) + + def _run_main( rundir, input_yaml, @@ -103,14 +113,7 @@ def _run_main( world_size=0, ): config_yaml = Path(rundir) / "train_and_val_on_val.yml" - - with open(input_yaml) as yaml_file: - yaml_config = yaml.safe_load(yaml_file) - if update_dict_with is not None: - yaml_config = merge_dictionary(yaml_config, update_dict_with) - yaml_config["backend"] = "gloo" - with open(str(config_yaml), "w") as yaml_file: - yaml.dump(yaml_config, yaml_file) + update_yaml_with_dict(input_yaml, config_yaml, update_dict_with) run_args = { "run_dir": rundir, "logdir": f"{rundir}/logs", diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 2f7dfa373..10e3203c9 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -8,11 +8,19 @@ import numpy as np import numpy.testing as npt import pytest -from test_e2e_commons import _run_main, oc20_lmdb_train_and_val_from_paths +from fairchem.core._cli import 
Runner +from fairchem.core.modules.scaling.fit import compute_scaling_factors +from test_e2e_commons import ( + _run_main, + oc20_lmdb_train_and_val_from_paths, + update_yaml_with_dict, +) -from fairchem.core.common.utils import setup_logging +from fairchem.core.common.utils import build_config, setup_logging from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes +from fairchem.core.common.flags import flags + setup_logging() @@ -98,6 +106,61 @@ def smoke_test_train( energy_from_train, energy_from_checkpoint, rtol=1e-6, atol=1e-6 ) + def test_gemnet_fit_scaling(self, configs, tutorial_val_src): + + with tempfile.TemporaryDirectory() as tempdirname: + # (1) generate scaling factors for gemnet config + config_yaml = f"{tempdirname}/train_and_val_on_val.yml" + scaling_pt = f"{tempdirname}/scaling.pt" + # run + parser = flags.get_parser() + args, override_args = parser.parse_known_args( + [ + "--mode", + "train", + "--seed", + "100", + "--config-yml", + config_yaml, + "--cpu", + "--checkpoint", + scaling_pt, + ] + ) + update_yaml_with_dict( + configs["gemnet_oc"], + config_yaml, + update_dict_with={ + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + ) + config = build_config(args, override_args) + + # (2) if existing scaling factors are present remove them + if "scale_file" in config["model"]: + config["model"].pop("scale_file") + + compute_scaling_factors(config) + + # (3) try to run the config with the newly generated scaling factors + _ = _run_main( + rundir=tempdirname, + update_dict_with={ + "optim": {"max_epochs": 1}, + "model": {"use_pbc_single": True, "scale_file": scaling_pt}, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + input_yaml=configs["gemnet_oc"], + ) + # not all models are tested with otf normalization estimation # only gemnet_oc, escn, equiformer, and their hydra versions @pytest.mark.parametrize(