diff --git a/configs/ocp_hydra_example.yml b/configs/ocp_hydra_example.yml index dbcadeff3..10373ad61 100755 --- a/configs/ocp_hydra_example.yml +++ b/configs/ocp_hydra_example.yml @@ -22,9 +22,11 @@ logger: wandb outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True diff --git a/src/fairchem/core/common/utils.py b/src/fairchem/core/common/utils.py index 669449f0b..e762dfeb5 100644 --- a/src/fairchem/core/common/utils.py +++ b/src/fairchem/core/common/utils.py @@ -10,6 +10,7 @@ import ast import collections import copy +import errno import importlib import itertools import json @@ -38,6 +39,7 @@ from torch_scatter import scatter, segment_coo, segment_csr import fairchem.core +from fairchem.core.common.registry import registry from fairchem.core.modules.loss import AtomwiseL2Loss, L2MAELoss if TYPE_CHECKING: @@ -1370,3 +1372,20 @@ def get_loss_module(loss_name): raise NotImplementedError(f"Unknown loss function name: {loss_name}") return loss_fn + + +def load_model_and_weights_from_checkpoint(checkpoint_path: str) -> nn.Module: + if not os.path.isfile(checkpoint_path): + raise FileNotFoundError( + errno.ENOENT, "Checkpoint file not found", checkpoint_path + ) + logging.info(f"Loading checkpoint from: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu")) + # this assumes the checkpont also contains the config with the full model in it + # TODO: need to schematize how we save and load the config from checkpoint + config = checkpoint["config"]["model"] + name = config.pop("name") + model = registry.get_model_class(name)(**config) + matched_dict = match_state_dict(model.state_dict(), checkpoint["state_dict"]) + load_state_dict(model, matched_dict, strict=True) + return model diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index c070fea4e..480ee7d02 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -9,7 +9,7 @@ import copy import logging -from abc import ABC, ABCMeta, abstractmethod +from abc import ABCMeta, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING @@ -21,6 +21,7 @@ from fairchem.core.common.utils import ( compute_neighbors, get_pbc_distances, + load_model_and_weights_from_checkpoint, radius_graph_pbc, ) @@ -232,64 +233,79 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: return -class HydraInterface(ABC): - # a hydra has a backbone and heads - @abstractmethod - def get_backbone(self) -> BackboneInterface: - raise not NotImplementedError - - @abstractmethod - def get_heads(self) -> dict[str, HeadInterface]: - raise not NotImplementedError - - @registry.register_model("hydra") -class HydraModel(nn.Module, GraphModelMixin, HydraInterface): +class HydraModel(nn.Module, GraphModelMixin): def __init__( self, - backbone: dict, - heads: dict, + backbone: dict | None = None, + heads: dict | None = None, + finetune_config: dict | None = None, otf_graph: bool = True, + pass_through_head_outputs: bool = False, ): super().__init__() + self.device = None self.otf_graph = otf_graph - self.device = "cpu" - # make a copy so we don't modify the original config - backbone = copy.deepcopy(backbone) - heads = copy.deepcopy(heads) - - backbone_model_name = backbone.pop("model") - self.backbone: BackboneInterface = registry.get_model_class( - backbone_model_name - )( - **backbone, - ) - - # Iterate through outputs_cfg and create heads - self.output_heads: dict[str, HeadInterface] = {} - - head_names_sorted = sorted(heads.keys()) - for head_name in head_names_sorted: - head_config = heads[head_name] - if "module" not in head_config: - raise ValueError( - f"{head_name} head does not specify module to use for the head" - ) - - module_name = head_config.pop("module") - self.output_heads[head_name] = registry.get_model_class(module_name)( - self.backbone, - **head_config, + # This is required for hydras with models that have multiple outputs per head, since we will deprecate + # the old config system at some point, this will prevent the need to make major modifications to the trainer + # because they all expect the name of the outputs directly instead of the head_name.property_name + self.pass_through_head_outputs = pass_through_head_outputs + + # if finetune_config is provided, then attempt to load the model from the given finetune checkpoint + starting_model = None + if finetune_config is not None: + starting_model: HydraModel = load_model_and_weights_from_checkpoint(finetune_config["starting_checkpoint"]) + logging.info(f"Found and loaded fine-tuning checkpoint: {finetune_config['starting_checkpoint']} (Note we are NOT loading the training state from this checkpoint, only parts of the model and weights)") + assert isinstance(starting_model, HydraModel), "Can only finetune starting from other hydra models!" + + if backbone is not None: + backbone = copy.deepcopy(backbone) + backbone_model_name = backbone.pop("model") + self.backbone: BackboneInterface = registry.get_model_class( + backbone_model_name + )( + **backbone, ) + elif starting_model is not None: + self.backbone = starting_model.backbone + logging.info(f"User did not specify a backbone, using the backbone from the starting checkpoint {self.backbone}") + else: + raise RuntimeError("Backbone not specified and not found in the starting checkpoint") + + if heads is not None: + heads = copy.deepcopy(heads) + # Iterate through outputs_cfg and create heads + self.output_heads: dict[str, HeadInterface] = {} + + head_names_sorted = sorted(heads.keys()) + assert len(set(head_names_sorted)) == len(head_names_sorted), "Head names must be unique!" + for head_name in head_names_sorted: + head_config = heads[head_name] + if "module" not in head_config: + raise ValueError( + f"{head_name} head does not specify module to use for the head" + ) - self.output_heads = torch.nn.ModuleDict(self.output_heads) + module_name = head_config.pop("module") + self.output_heads[head_name] = registry.get_model_class(module_name)( + self.backbone, + **head_config, + ) - def to(self, *args, **kwargs): - if "device" in kwargs: - self.device = kwargs["device"] - return super().to(*args, **kwargs) + self.output_heads = torch.nn.ModuleDict(self.output_heads) + elif starting_model is not None: + self.output_heads = starting_model.output_heads + logging.info(f"User did not specify heads, using the output heads from the starting checkpoint {self.output_heads}") + else: + raise RuntimeError("Heads not specified and not found in the starting checkpoint") def forward(self, data: Batch): + # lazily get device from input to use with amp, at least one input must be a tensor to figure out it's device + if not self.device: + device_from_tensors = {x.device.type for x in data.values() if isinstance(x, torch.Tensor)} + assert len(device_from_tensors) == 1, f"all inputs must be on the same device, found the following devices {device_from_tensors}" + self.device = device_from_tensors.pop() + emb = self.backbone(data) # Predict all output properties for all structures in the batch for now. out = {} @@ -297,12 +313,11 @@ def forward(self, data: Batch): with torch.autocast( device_type=self.device, enabled=self.output_heads[k].use_amp ): - out.update(self.output_heads[k](data, emb)) + if self.pass_through_head_outputs: + out.update(self.output_heads[k](data, emb)) + else: + out[k] = self.output_heads[k](data, emb) return out - def get_backbone(self) -> BackboneInterface: - return self.backbone - def get_heads(self) -> dict[str, HeadInterface]: - return self.output_heads diff --git a/src/fairchem/core/models/finetune_hydra.py b/src/fairchem/core/models/finetune_hydra.py deleted file mode 100644 index 6c271e24e..000000000 --- a/src/fairchem/core/models/finetune_hydra.py +++ /dev/null @@ -1,177 +0,0 @@ -from __future__ import annotations - -import copy -import errno -import logging -import os -from enum import Enum -from typing import TYPE_CHECKING - -import torch -from torch import nn - -from fairchem.core.common.registry import registry -from fairchem.core.common.utils import load_state_dict, match_state_dict -from fairchem.core.models.base import BackboneInterface, HeadInterface, HydraInterface - -if TYPE_CHECKING: - from torch_geometric.data import Batch - -FTHYDRA_NAME = "finetune_hydra" - -class FineTuneMode(Enum): - # in DATA_ONLY, we load the entire model and only finetune on new data - DATA_ONLY = 1 - # in this mode, we only load the Backbone and feed the output of the backbone - # to new heads that are specified - RETAIN_BACKBONE_ONLY = 2 - - -def get_model_config_from_checkpoint(checkpoint_path: str) -> dict: - if not os.path.isfile(checkpoint_path): - raise FileNotFoundError( - errno.ENOENT, "Checkpoint file not found", checkpoint_path - ) - checkpoint = torch.load(checkpoint_path) - return checkpoint["config"]["model"] - - -def load_hydra_model(checkpoint_path: str) -> HydraInterface: - if not os.path.isfile(checkpoint_path): - raise FileNotFoundError( - errno.ENOENT, "Checkpoint file not found", checkpoint_path - ) - logging.info(f"Loading checkpoint from: {checkpoint_path}") - checkpoint = torch.load(checkpoint_path) - config = checkpoint["config"]["model"] - name = config.pop("name") - hydra_model = registry.get_model_class(name)(**config) - assert isinstance( - hydra_model, HydraInterface - ), "Can only load models with the HydraInterface" - matched_dict = match_state_dict(hydra_model.state_dict(), checkpoint["state_dict"]) - load_state_dict(hydra_model, matched_dict, strict=True) - return hydra_model - - -class FTConfig: - FT_CONFIG_NAME = "finetune_config" - STARTING_CHECKPOINT = "starting_checkpoint" - STARTING_MODEL = "starting_model" - MODE = "mode" - HEADS = "heads" - - def __init__(self, config: dict): - self.config = config - self._mode = FineTuneMode[self.config[FTConfig.MODE]] - assert ( - (FTConfig.STARTING_CHECKPOINT in self.config) - or (FTConfig.STARTING_MODEL in self.config) - ), "Either a starting checkpoint or a starting model must be provided!" - assert FTConfig.MODE in self.config - if self._mode == FineTuneMode.RETAIN_BACKBONE_ONLY: - # in this mode, we keep the backbone but attach new output heads specified in head config - assert ( - FTConfig.HEADS in self.config - ), "heads cannot be empty when using RETAIN_BACKBONE_ONLY mode!" - - def load_model(self) -> nn.Module: - # if provided a hydra config to start, build from the starting hydra model - # this assumes the weights are loaded from the state_dict in the checkpoint.pt file instead - # so no actual weights are loaded here - if FTConfig.STARTING_MODEL in self.config: - # register model from hydra_config - config_copy = copy.deepcopy(self.config[FTConfig.STARTING_MODEL]) - name = config_copy.pop("name") - hydra_model = registry.get_model_class(name)(**config_copy) - # if provided a checkpoint to start then load the model and weights from the given checkpoint - # this happens used in the beginning of a finetuning run - elif FTConfig.STARTING_CHECKPOINT in self.config: - hydra_model: HydraInterface = load_hydra_model( - self.config[FTConfig.STARTING_CHECKPOINT] - ) - assert isinstance(hydra_model, HydraInterface) - - num_params = sum(p.numel() for p in hydra_model.parameters()) - logging.info(f"Loaded Original hydra model with {num_params} params") - return hydra_model - - def get_standalone_config(self) -> dict: - # replace a config with a checkpoint with one that has the model config only - # this is required for standalone prediction (so we don't need to ship the original checkpoint), - # multi-round finetuning, and better robustness - standalone_config = { - "name": FTHYDRA_NAME, - FTConfig.FT_CONFIG_NAME: self.config, - } - if FTConfig.STARTING_CHECKPOINT in self.config: - # modify the config to store the original model config inside model attrs - # so we dont need the checkpoint again when loading from checkpoint - new_config = copy.deepcopy(self.config) - new_config[FTConfig.STARTING_MODEL] = ( - get_model_config_from_checkpoint( - self.config[FTConfig.STARTING_CHECKPOINT] - ) - ) - standalone_config[FTConfig.FT_CONFIG_NAME] = new_config - return standalone_config - - @property - def mode(self) -> FineTuneMode: - return self._mode - - @property - def head_config(self) -> dict: - return copy.deepcopy(self.config[FTConfig.HEADS]) - - -@registry.register_model(FTHYDRA_NAME) -class FineTuneHydra(nn.Module, HydraInterface): - def __init__(self, finetune_config: dict): - super().__init__() - ft_config = FTConfig(finetune_config) - logging.info(f"Initializing FineTuneHydra model in {ft_config.mode} mode") - hydra_model: HydraInterface = ft_config.load_model() - self.backbone: BackboneInterface = hydra_model.get_backbone() - - if ft_config.mode == FineTuneMode.DATA_ONLY: - # in this mode, we just use the model as is and train on it with new data - self.output_heads: dict[str, HeadInterface] = hydra_model.get_heads() - elif ft_config.mode == FineTuneMode.RETAIN_BACKBONE_ONLY: - # in this mode, we keep the backbone but attach new output heads specified in head config - self.output_heads: dict[str, HeadInterface] = {} - heads_config = ft_config.head_config - head_names_sorted = sorted(heads_config.keys()) - for head_name in head_names_sorted: - head_config = heads_config[head_name] - if "module" not in head_config: - raise ValueError( - f"{head_name} head does not specify module to use for the head" - ) - - module_name = head_config.pop("module") - self.output_heads[head_name] = registry.get_model_class(module_name)( - self.backbone, - **head_config, - ) - num_params = sum( - p.numel() for p in self.output_heads[head_name].parameters() - ) - logging.info( - f"Attaching new output head: {module_name} with {num_params} params" - ) - self.output_heads = torch.nn.ModuleDict(self.output_heads) - - - def forward(self, data: Batch): - emb = self.backbone(data) - out = {} - for k in self.output_heads: - out.update(self.output_heads[k](data, emb)) - return out - - def get_backbone(self) -> BackboneInterface: - return self.backbone - - def get_heads(self) -> dict[str, HeadInterface]: - return self.output_heads diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index a31421a75..d84a2c12f 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -44,7 +44,6 @@ update_config, ) from fairchem.core.datasets.base_dataset import create_dataset -from fairchem.core.models.finetune_hydra import FineTuneHydra, FTConfig from fairchem.core.modules.evaluator import Evaluator from fairchem.core.modules.exponential_moving_average import ExponentialMovingAverage from fairchem.core.modules.loss import DDPLoss @@ -716,13 +715,6 @@ def save( training_state: bool = True, ) -> str | None: if not self.is_debug and distutils.is_master(): - # if we are using a FineTune-able model, then we need to modify the config to remove - # the original starting checkpoint so it can be loaded standalone, can move this to save function - if isinstance(self.model, FineTuneHydra): - self.config["model"] = FTConfig( - self.config["model"][FTConfig.FT_CONFIG_NAME] - ).get_standalone_config() - state = { "state_dict": self.model.state_dict(), "normalizers": { diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 662341bdc..0ced35bef 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -251,9 +251,18 @@ def _forward(self, batch): for target_key in self.output_targets: ### Target property is a direct output of the model if target_key in out: - pred = out[target_key] - ## Target property is a derived output of the model. Construct the - ## parent property + if isinstance(out[target_key], torch.Tensor): + pred = out[target_key] + elif isinstance(out[target_key], dict): + # if output is a nested dictionary (in the case of hydra models), we attempt to retrieve it using the property name + # ie: "output_head_name.property" + assert "property" in self.output_targets[target_key], \ + f"we need to know which property to match the target to, please specify the property field in the task config, current config: {self.output_targets[target_key]}" + property = self.output_targets[target_key]["property"] + pred = out[target_key][property] + + ## TODO: deprecate the following logic? + ## Otherwise, assume target property is a derived output of the model. Construct the parent property else: _max_rank = 0 for subtarget_key in self.output_targets[target_key]["decomposition"]: diff --git a/tests/core/e2e/test_e2e_finetune_hydra.py b/tests/core/e2e/test_e2e_finetune_hydra.py index 91f2abd49..df9def3b0 100644 --- a/tests/core/e2e/test_e2e_finetune_hydra.py +++ b/tests/core/e2e/test_e2e_finetune_hydra.py @@ -8,8 +8,6 @@ import torch from test_e2e_commons import _run_main, oc20_lmdb_train_and_val_from_paths -from fairchem.core.models.finetune_hydra import FTHYDRA_NAME, FineTuneMode, FTConfig - @pytest.fixture() def tutorial_val_src(tutorial_dataset_path): @@ -49,7 +47,7 @@ def run_main_with_ft_hydra(tempdir: str, yaml: str, data_src: str, run_args: dict, - ft_config: str, + model_config: str, output_checkpoint: str): _run_main( tempdir, @@ -68,10 +66,7 @@ def run_main_with_ft_hydra(tempdir: str, test_src=str(data_src), otf_norms=False, ), - "model": { - "name": FTHYDRA_NAME, - FTConfig.FT_CONFIG_NAME: ft_config, - } + "model": model_config, }, update_run_args_with=run_args, save_checkpoint_to=output_checkpoint, @@ -87,9 +82,9 @@ def test_finetune_hydra_retain_backbone(tutorial_val_src): with tempfile.TemporaryDirectory() as ft_temp_dir: ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml") ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt") - ft_config = { - "mode": FineTuneMode.RETAIN_BACKBONE_ONLY.name, - "starting_checkpoint": starting_ckpt, + model_config = { + "name" : "hydra", + "finetune_config": {'starting_checkpoint': starting_ckpt}, "heads": { "energy": { "module": "equiformer_v2_energy_head" @@ -103,12 +98,12 @@ def test_finetune_hydra_retain_backbone(tutorial_val_src): yaml = ft_yml, data_src = tutorial_val_src, run_args = {"seed": 1000}, - ft_config = ft_config, + model_config = model_config, output_checkpoint = ck_ft_path) assert os.path.isfile(ck_ft_path) ft_ckpt = torch.load(ck_ft_path) assert "config" in ft_ckpt - assert ft_ckpt["config"]["model"]["name"] == FTHYDRA_NAME + assert ft_ckpt["config"]["model"]["name"] == "hydra" # check that the backbone weights are the same, and other weights are not the same new_state_dict = ft_ckpt["state_dict"] for key in new_state_dict: @@ -128,28 +123,26 @@ def test_finetune_hydra_data_only(tutorial_val_src): with tempfile.TemporaryDirectory() as ft_temp_dir: ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml") ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt") - ft_config = { - "mode": FineTuneMode.DATA_ONLY.name, - "starting_checkpoint": starting_ckpt, + model_config = { + "name" : "hydra", + "finetune_config": {'starting_checkpoint': starting_ckpt}, } run_main_with_ft_hydra(tempdir = ft_temp_dir, yaml = ft_yml, data_src = tutorial_val_src, run_args = {"seed": 1000}, - ft_config = ft_config, + model_config = model_config, output_checkpoint = ck_ft_path) assert os.path.isfile(ck_ft_path) ft_ckpt = torch.load(ck_ft_path) assert "config" in ft_ckpt config_model = ft_ckpt["config"]["model"] - assert config_model["name"] == FTHYDRA_NAME + assert config_model["name"] == "hydra" # check that the entire model weights are the same new_state_dict = ft_ckpt["state_dict"] assert len(new_state_dict) == len(old_state_dict) for key in new_state_dict: assert torch.allclose(new_state_dict[key], old_state_dict[key]) - # check the new checkpoint contains a hydra model - assert FTConfig.STARTING_MODEL in config_model[FTConfig.FT_CONFIG_NAME] def test_finetune_from_finetunehydra(tutorial_val_src): @@ -159,15 +152,15 @@ def test_finetune_from_finetunehydra(tutorial_val_src): with tempfile.TemporaryDirectory() as finetune_run1_dir: ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml") ck_ft_path = os.path.join(finetune_run1_dir, "checkpoint_ft.pt") - ft_config_1 = { - "mode": FineTuneMode.DATA_ONLY.name, - "starting_checkpoint": starting_ckpt, + model_config_1 = { + "name" : "hydra", + "finetune_config": {'starting_checkpoint': starting_ckpt}, } run_main_with_ft_hydra(tempdir = finetune_run1_dir, yaml = ft_yml, data_src = tutorial_val_src, run_args = {"seed": 1000}, - ft_config = ft_config_1, + model_config = model_config_1, output_checkpoint = ck_ft_path) assert os.path.isfile(ck_ft_path) @@ -175,20 +168,20 @@ def test_finetune_from_finetunehydra(tutorial_val_src): ######################################################################################## with tempfile.TemporaryDirectory() as finetune_run2_dir: ck_ft2_path = os.path.join(finetune_run2_dir, "checkpoint_ft.pt") - ft_config_2 = { - "mode": FineTuneMode.DATA_ONLY.name, - "starting_checkpoint": ck_ft_path, + model_config_2 = { + "name" : "hydra", + "finetune_config": {'starting_checkpoint': ck_ft_path}, } run_main_with_ft_hydra(tempdir = finetune_run2_dir, yaml = ft_yml, data_src = tutorial_val_src, run_args = {"seed": 1000}, - ft_config = ft_config_2, + model_config = model_config_2, output_checkpoint = ck_ft2_path) ft_ckpt2 = torch.load(ck_ft2_path) assert "config" in ft_ckpt2 config_model = ft_ckpt2["config"]["model"] - assert config_model["name"] == FTHYDRA_NAME + assert config_model["name"] == "hydra" old_state_dict = torch.load(ck_ft_path)["state_dict"] new_state_dict = ft_ckpt2["state_dict"] # the state dicts should still be identical because we made the LR = 0.0 diff --git a/tests/core/e2e/test_s2efs.py b/tests/core/e2e/test_s2efs.py index 94b0862ed..037979e60 100644 --- a/tests/core/e2e/test_s2efs.py +++ b/tests/core/e2e/test_s2efs.py @@ -44,7 +44,7 @@ def test_smoke_s2efs_predict( {"forces": {"fn": "l2mae", "coefficient": 100}}, {"stress": {"fn": "mae", "coefficient": 100}}, ], - "outputs": {"stress": {"level": "system", "irrep_dim": 2}}, + "outputs": {"stress": {"level": "system", "irrep_dim": 2, "property": "stress"}}, "evaluation_metrics": {"metrics": {"stress": ["mae"]}}, "dataset": { "train": { diff --git a/tests/core/models/test_configs/test_dpp_hydra.yml b/tests/core/models/test_configs/test_dpp_hydra.yml index e41a39141..24f13e6ad 100755 --- a/tests/core/models/test_configs/test_dpp_hydra.yml +++ b/tests/core/models/test_configs/test_dpp_hydra.yml @@ -2,9 +2,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True @@ -52,6 +54,7 @@ model: heads: energy: module: dimenetplusplus_energy_and_force_head + pass_through_head_outputs: True # *** Important note *** # The total number of gpus used for this run was 256. diff --git a/tests/core/models/test_configs/test_equiformerv2_hydra.yml b/tests/core/models/test_configs/test_equiformerv2_hydra.yml index 1852799f5..9747eec80 100644 --- a/tests/core/models/test_configs/test_equiformerv2_hydra.yml +++ b/tests/core/models/test_configs/test_equiformerv2_hydra.yml @@ -2,9 +2,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True diff --git a/tests/core/models/test_configs/test_escn_hydra.yml b/tests/core/models/test_configs/test_escn_hydra.yml index c51d46fc3..8d730dad0 100644 --- a/tests/core/models/test_configs/test_escn_hydra.yml +++ b/tests/core/models/test_configs/test_escn_hydra.yml @@ -2,9 +2,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True diff --git a/tests/core/models/test_configs/test_finetune_hydra.yml b/tests/core/models/test_configs/test_finetune_hydra.yml index a5f1dc51b..e1be6d20b 100644 --- a/tests/core/models/test_configs/test_finetune_hydra.yml +++ b/tests/core/models/test_configs/test_finetune_hydra.yml @@ -2,9 +2,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True diff --git a/tests/core/models/test_configs/test_gemnet_dt_hydra.yml b/tests/core/models/test_configs/test_gemnet_dt_hydra.yml index 036ed689f..4b5c239fc 100644 --- a/tests/core/models/test_configs/test_gemnet_dt_hydra.yml +++ b/tests/core/models/test_configs/test_gemnet_dt_hydra.yml @@ -2,9 +2,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True @@ -73,6 +75,7 @@ model: forces: module: gemnet_t_force_head + optim: batch_size: 8 eval_batch_size: 8 diff --git a/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml b/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml index 358dd1c86..e8e4eca9c 100644 --- a/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml +++ b/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml @@ -2,9 +2,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True @@ -71,6 +73,7 @@ model: heads: energy_and_forces: module: gemnet_t_energy_and_grad_force_head + pass_through_head_outputs: True optim: batch_size: 8 diff --git a/tests/core/models/test_configs/test_gemnet_oc_hydra.yml b/tests/core/models/test_configs/test_gemnet_oc_hydra.yml index 716718e3e..a58d328bd 100644 --- a/tests/core/models/test_configs/test_gemnet_oc_hydra.yml +++ b/tests/core/models/test_configs/test_gemnet_oc_hydra.yml @@ -2,9 +2,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True diff --git a/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml b/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml index 90001488b..88cb493f4 100644 --- a/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml +++ b/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml @@ -3,9 +3,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True @@ -96,6 +98,7 @@ model: energy: module: gemnet_oc_energy_and_grad_force_head num_global_out_layers: 2 + pass_through_head_outputs: True optim: batch_size: 5 diff --git a/tests/core/models/test_configs/test_painn_hydra.yml b/tests/core/models/test_configs/test_painn_hydra.yml index 2c4731742..a6b26c5d4 100644 --- a/tests/core/models/test_configs/test_painn_hydra.yml +++ b/tests/core/models/test_configs/test_painn_hydra.yml @@ -2,9 +2,11 @@ trainer: forces outputs: energy: + property: energy shape: 1 level: system forces: + property: forces irrep_dim: 1 level: atom train_on_free_atoms: True diff --git a/tests/core/models/test_equiformer_v2.py b/tests/core/models/test_equiformer_v2.py index 54d58db1c..2f0903608 100644 --- a/tests/core/models/test_equiformer_v2.py +++ b/tests/core/models/test_equiformer_v2.py @@ -264,12 +264,12 @@ def test_eqv2_hydra_activation_checkpoint(): # way to do this is save the rng state and reset it after stepping the first model start_rng_state = torch.random.get_rng_state() outputs_no_ac = no_ac_model(inputs) - torch.autograd.backward(outputs_no_ac["energy"].sum() + outputs_no_ac["forces"].sum()) + torch.autograd.backward(outputs_no_ac["energy"]["energy"].sum() + outputs_no_ac["forces"]["forces"].sum()) # reset the rng state to the beginning torch.random.set_rng_state(start_rng_state) outptuts_ac = ac_model(inputs) - torch.autograd.backward(outptuts_ac["energy"].sum() + outptuts_ac["forces"].sum()) + torch.autograd.backward(outptuts_ac["energy"]["energy"].sum() + outptuts_ac["forces"]["forces"].sum()) # assert all the gradients are identical between the model with checkpointing and no checkpointing ac_model_grad_dict = {name:p.grad for name, p in ac_model.named_parameters() if p.grad is not None}