diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml new file mode 100644 index 0000000000..fb9c18faa1 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yaml @@ -0,0 +1,80 @@ +name: Report a Bug +description: FAIR-Chem bug report +labels: bug +body: + - type: input + id: python-version + attributes: + label: Python version + description: Use `python --version` to get Python version + placeholder: ex. Python 3.11.5 + validations: + required: true + + - type: input + id: fairchem-version + attributes: + label: fairchem-core version + description: Use `pip show fairchem-core | grep Version` to get fairchem-core version + placeholder: ex. 1.2.1 + validations: + required: true + + - type: input + id: torch-version + attributes: + label: pytorch version + description: Use `pip show torch | grep Version` to get pytorch version + placeholder: ex. 2.4.0 + validations: + required: true + + - type: input + id: cuda-version + attributes: + label: cuda version + description: Use `python -c 'import torch; cuda=torch.cuda.is_available(); print(cuda,torch.version.cuda if cuda else None);'` to get cuda version + placeholder: ex. 12.1 + validations: + required: true + + - type: input + id: os + attributes: + label: Operating system version + placeholder: ex. Ubuntu 22.04 LTS + validations: + required: false + + - type: textarea + id: code-snippet + attributes: + label: Minimal example + description: Please provide a minimal code snippet to reproduce this bug. + render: Python + validations: + required: false + + - type: textarea + id: current-behavior + attributes: + label: Current behavior + description: What behavior do you see? + validations: + required: true + + - type: textarea + id: expected-behavior + attributes: + label: Expected Behavior + description: What did you expect to see? + validations: + required: true + + - type: textarea + id: files + attributes: + label: Relevant files to reproduce this bug + description: Please upload relevant files to help reproduce this bug, or logs if helpful. + validations: + required: false \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/misc.yaml b/.github/ISSUE_TEMPLATE/misc.yaml new file mode 100644 index 0000000000..19849ac65d --- /dev/null +++ b/.github/ISSUE_TEMPLATE/misc.yaml @@ -0,0 +1,10 @@ +name: Other +description: A report is not a bug exactly + +body: + - type: textarea + attributes: + label: What would you like to report? + description: A clear and concise description of what you would like to report. + validations: + required: true diff --git a/configs/omat24/all/eqV2_153M.yml b/configs/omat24/all/eqV2_153M.yml index dffd4ec346..bf87a9e3ea 100644 --- a/configs/omat24/all/eqV2_153M.yml +++ b/configs/omat24/all/eqV2_153M.yml @@ -43,7 +43,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 2.5 - forces: fn: l2mae coefficient: 20 diff --git a/configs/omat24/all/eqV2_31M.yml b/configs/omat24/all/eqV2_31M.yml index 95ff1f89f8..902a58b31c 100644 --- a/configs/omat24/all/eqV2_31M.yml +++ b/configs/omat24/all/eqV2_31M.yml @@ -44,7 +44,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 2.5 - forces: fn: l2mae coefficient: 20 diff --git a/configs/omat24/all/eqV2_86M.yml b/configs/omat24/all/eqV2_86M.yml index 30167e81a5..154ed57cb4 100644 --- a/configs/omat24/all/eqV2_86M.yml +++ b/configs/omat24/all/eqV2_86M.yml @@ -43,7 +43,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 2.5 - forces: fn: l2mae coefficient: 20 diff --git a/configs/omat24/finetune/eqV2_153M_ft_salexmptrj.yml b/configs/omat24/finetune/eqV2_153M_ft_salexmptrj.yml index d02b83760d..bd04e683d8 100644 --- a/configs/omat24/finetune/eqV2_153M_ft_salexmptrj.yml +++ b/configs/omat24/finetune/eqV2_153M_ft_salexmptrj.yml @@ -45,7 +45,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 2.5 - forces: fn: l2mae coefficient: 10 diff --git a/configs/omat24/finetune/eqV2_31M_ft_salexmptrj.yml b/configs/omat24/finetune/eqV2_31M_ft_salexmptrj.yml index 146a153125..36e89c66bd 100644 --- a/configs/omat24/finetune/eqV2_31M_ft_salexmptrj.yml +++ b/configs/omat24/finetune/eqV2_31M_ft_salexmptrj.yml @@ -45,7 +45,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 2.5 - forces: fn: l2mae coefficient: 10 diff --git a/configs/omat24/finetune/eqV2_86M_ft_salexmptrj.yml b/configs/omat24/finetune/eqV2_86M_ft_salexmptrj.yml index 8976ffa9a7..8e230aa1a6 100644 --- a/configs/omat24/finetune/eqV2_86M_ft_salexmptrj.yml +++ b/configs/omat24/finetune/eqV2_86M_ft_salexmptrj.yml @@ -43,7 +43,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 2.5 - forces: fn: l2mae coefficient: 10 diff --git a/configs/omat24/mptrj/eqV2_153M_dens_mptrj.yml b/configs/omat24/mptrj/eqV2_153M_dens_mptrj.yml index 050d5921dd..435309a224 100644 --- a/configs/omat24/mptrj/eqV2_153M_dens_mptrj.yml +++ b/configs/omat24/mptrj/eqV2_153M_dens_mptrj.yml @@ -45,7 +45,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 2.5 - forces: fn: l2mae coefficient: 20 diff --git a/configs/omat24/mptrj/eqV2_31M_dens_mptrj.yml b/configs/omat24/mptrj/eqV2_31M_dens_mptrj.yml index 818eaeb093..78ccb4b92e 100644 --- a/configs/omat24/mptrj/eqV2_31M_dens_mptrj.yml +++ b/configs/omat24/mptrj/eqV2_31M_dens_mptrj.yml @@ -45,7 +45,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 5 - forces: fn: l2mae coefficient: 20 @@ -148,17 +148,16 @@ model: use_force_encoding: True use_noise_schedule_sigma_encoding: False - use_denoising_energy: True - use_denoising_stress: False - heads: energy: - module: fairchem.core.models.equiformer_v2.equiformer_v2_dens.DeNSEnergyHead + module: fairchem.core.models.equiformer_v2.equiformer_v2_dens.DeNSScalarHead + use_denoising: True forces: - module: fairchem.core.models.equiformer_v2.equiformer_v2_dens.DeNSForceHead + module: fairchem.core.models.equiformer_v2.equiformer_v2_dens.DeNSVectorHead stress: module: fairchem.core.models.equiformer_v2.equiformer_v2_dens.DeNSRank2Head output_name: stress use_source_target_embedding: True decompose: True + use_denoising: False diff --git a/configs/omat24/mptrj/eqV2_31M_mptrj.yml b/configs/omat24/mptrj/eqV2_31M_mptrj.yml index c9ae7c84d9..7f4c83cf68 100644 --- a/configs/omat24/mptrj/eqV2_31M_mptrj.yml +++ b/configs/omat24/mptrj/eqV2_31M_mptrj.yml @@ -45,7 +45,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 5 - forces: fn: l2mae coefficient: 20 diff --git a/configs/omat24/mptrj/eqV2_86M_dens_mptrj.yml b/configs/omat24/mptrj/eqV2_86M_dens_mptrj.yml index f931ee78aa..47f0958856 100644 --- a/configs/omat24/mptrj/eqV2_86M_dens_mptrj.yml +++ b/configs/omat24/mptrj/eqV2_86M_dens_mptrj.yml @@ -45,7 +45,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 2.5 - forces: fn: l2mae coefficient: 20 diff --git a/docs/core/model_checkpoints.md b/docs/core/model_checkpoints.md index 6cfba70555..342498e14e 100644 --- a/docs/core/model_checkpoints.md +++ b/docs/core/model_checkpoints.md @@ -149,7 +149,7 @@ OC22 dataset or pretrained models, as well as the original paper for each model: | GemNet-OC-S2EF-ODAC | GemNet-OC | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/Gemnet-OC.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/s2ef/gemnet-oc.yml) | | eSCN-S2EF-ODAC | eSCN | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/eSCN.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/s2ef/eSCN.yml) | | EquiformerV2-S2EF-ODAC | EquiformerV2 | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231116/eqv2_31M.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/s2ef/eqv2_31M.yml) | -| EquiformerV2-Large-S2EF-ODAC | EquiformerV2 (Large) | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/Equiformer_V2_Large.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/s2ef/eqv2_153M.yml) | +| EquiformerV2-Large-S2EF-ODAC | EquiformerV2 (Large) | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231116/Equiformer_V2_Large.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/s2ef/eqv2_153M.yml) | ## IS2RE Direct models @@ -157,7 +157,7 @@ OC22 dataset or pretrained models, as well as the original paper for each model: |-------------------------|--------------|--- | --- | | Gemnet-OC-IS2RE-ODAC | Gemnet-OC | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/Gemnet-OC_Direct.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/is2re/gemnet-oc.yml) | | eSCN-IS2RE-ODAC | eSCN | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/eSCN_Direct.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/is2re/eSCN.yml) | -| EquiformerV2-IS2RE-ODAC | EquiformerV2 | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/Equiformer_V2_Direct.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/is2re/eqv2_31M.yml) | +| EquiformerV2-IS2RE-ODAC | EquiformerV2 | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231116/Equiformer_V2_Direct.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/is2re/eqv2_31M.yml) | The models in the table above were trained to predict relaxed energy directly. Relaxed energies can also be predicted by running structural relaxations using the S2EF models from the previous section. diff --git a/packages/fairchem-data-om/pyproject.toml b/packages/fairchem-data-om/pyproject.toml index 94ae5454b8..06f4a47d9e 100644 --- a/packages/fairchem-data-om/pyproject.toml +++ b/packages/fairchem-data-om/pyproject.toml @@ -30,9 +30,6 @@ git_describe_command = 'git describe --tags --match fairchem_data_om-*' [tool.hatch.build] directory = "../../dist-data-om" -[tool.hatch.build] -directory = "../../dist" - [tool.hatch.build.targets.sdist] only-include = ["src/fairchem/data/om"] diff --git a/src/fairchem/core/_cli_hydra.py b/src/fairchem/core/_cli_hydra.py index 79992616df..9279c8daf3 100644 --- a/src/fairchem/core/_cli_hydra.py +++ b/src/fairchem/core/_cli_hydra.py @@ -12,12 +12,16 @@ from typing import TYPE_CHECKING import hydra +from omegaconf import OmegaConf if TYPE_CHECKING: import argparse from omegaconf import DictConfig + from fairchem.core.components.runner import Runner + + from submitit import AutoExecutor from submitit.helpers import Checkpointable, DelayedSubmission from torch.distributed.launcher.api import LaunchConfig, elastic_launch @@ -25,36 +29,54 @@ from fairchem.core.common import distutils from fairchem.core.common.flags import flags from fairchem.core.common.utils import get_timestamp_uid, setup_env_vars, setup_imports -from fairchem.core.components.runner import Runner logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) class Submitit(Checkpointable): - def __call__(self, dict_config: DictConfig, cli_args: argparse.Namespace) -> None: + def __call__(self, dict_config: DictConfig) -> None: self.config = dict_config - self.cli_args = cli_args # TODO: setup_imports is not needed if we stop instantiating models with Registry. setup_imports() setup_env_vars() - try: - distutils.setup(map_cli_args_to_dist_config(cli_args)) - runner: Runner = hydra.utils.instantiate(dict_config.runner) - runner.load_state() - runner.run() - finally: - distutils.cleanup() - - def checkpoint(self, *args, **kwargs): + distutils.setup(map_cli_args_to_dist_config(dict_config.cli_args)) + self._init_logger() + runner: Runner = hydra.utils.instantiate(dict_config.runner) + runner.load_state() + runner.run() + distutils.cleanup() + + def _init_logger(self) -> None: + # optionally instantiate a singleton wandb logger, intentionally only supporting the new wandb logger + # don't start logger if in debug mode + if ( + "logger" in self.config + and distutils.is_master() + and not self.config.cli_args.debug + ): + # get a partial function from the config and instantiate wandb with it + logger_initializer = hydra.utils.instantiate(self.config.logger) + simple_config = OmegaConf.to_container( + self.config, resolve=True, throw_on_missing=True + ) + logger_initializer( + config=simple_config, + run_id=self.config.cli_args.timestamp_id, + run_name=self.config.cli_args.identifier, + log_dir=self.config.cli_args.logdir, + ) + + def checkpoint(self, *args, **kwargs) -> DelayedSubmission: + # TODO: this is yet to be tested properly logging.info("Submitit checkpointing callback is triggered") - new_runner = Runner() - new_runner.save_state() + new_runner = Submitit() + self.runner.save_state() logging.info("Submitit checkpointing callback is completed") - return DelayedSubmission(new_runner, self.config) + return DelayedSubmission(new_runner, self.config, self.cli_args) -def map_cli_args_to_dist_config(cli_args: argparse.Namespace) -> dict: +def map_cli_args_to_dist_config(cli_args: DictConfig) -> dict: return { "world_size": cli_args.num_nodes * cli_args.num_gpus, "distributed_backend": "gloo" if cli_args.cpu else "nccl", @@ -76,8 +98,8 @@ def get_hydra_config_from_yaml( return hydra.compose(config_name=config_name, overrides=overrides_args) -def runner_wrapper(config: DictConfig, cli_args: argparse.Namespace): - Submitit()(config, cli_args) +def runner_wrapper(config: DictConfig): + Submitit()(config) # this is meant as a future replacement for the main entrypoint @@ -91,6 +113,11 @@ def main( cfg = get_hydra_config_from_yaml(args.config_yml, override_args) timestamp_id = get_timestamp_uid() log_dir = os.path.join(args.run_dir, timestamp_id, "logs") + # override timestamp id and logdir + args.timestamp_id = timestamp_id + args.logdir = log_dir + os.makedirs(log_dir) + OmegaConf.update(cfg, "cli_args", vars(args), force_add=True) if args.submit: # Run on cluster executor = AutoExecutor(folder=log_dir, slurm_max_num_timeout=3) executor.update_parameters( @@ -105,13 +132,23 @@ def main( slurm_qos=args.slurm_qos, slurm_account=args.slurm_account, ) - job = executor.submit(runner_wrapper, cfg, args) + job = executor.submit(runner_wrapper, cfg) logger.info( f"Submitted job id: {timestamp_id}, slurm id: {job.job_id}, logs: {log_dir}" ) else: if args.num_gpus > 1: - logger.info(f"Running in local mode with {args.num_gpus} ranks") + logging.info(f"Running in local mode with {args.num_gpus} ranks") + # HACK to disable multiprocess dataloading in local mode + # there is an open issue where LMDB's environment cannot be pickled and used + # during torch multiprocessing https://github.com/pytorch/examples/issues/526 + # this HACK only works for a training submission where the config is passed in here + if "optim" in cfg and "num_workers" in cfg["optim"]: + cfg["optim"]["num_workers"] = 0 + logging.info( + "WARNING: running in local mode, setting dataloading num_workers to 0, see https://github.com/pytorch/examples/issues/526" + ) + launch_config = LaunchConfig( min_nodes=1, max_nodes=1, @@ -119,8 +156,8 @@ def main( rdzv_backend="c10d", max_restarts=0, ) - elastic_launch(launch_config, runner_wrapper)(cfg, args) + elastic_launch(launch_config, runner_wrapper)(cfg) else: logger.info("Running in local mode without elastic launch") distutils.setup_env_local() - runner_wrapper(cfg, args) + runner_wrapper(cfg) diff --git a/src/fairchem/core/common/utils.py b/src/fairchem/core/common/utils.py index 20b6e69227..8e9e3ceab6 100644 --- a/src/fairchem/core/common/utils.py +++ b/src/fairchem/core/common/utils.py @@ -716,7 +716,8 @@ def radius_graph_pbc( # Tensor of unit cells cells_per_dim = [ - torch.arange(-rep, rep + 1, device=device, dtype=torch.float) for rep in max_rep + torch.arange(-rep.item(), rep.item() + 1, device=device, dtype=torch.float) + for rep in max_rep ] unit_cell = torch.cartesian_prod(*cells_per_dim) num_cells = len(unit_cell) diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2_dens.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2_dens.py new file mode 100644 index 0000000000..3ad881d639 --- /dev/null +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2_dens.py @@ -0,0 +1,586 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import math +from functools import partial +from typing import TYPE_CHECKING, Literal + +import torch + +from fairchem.core.common import gp_utils +from fairchem.core.common.registry import registry +from fairchem.core.common.utils import conditional_grad + +try: + from e3nn import o3 +except ImportError: + import contextlib + + contextlib.suppress(ImportError) + +from fairchem.core.models.base import GraphData, HeadInterface +from fairchem.core.models.equiformer_v2.equiformer_v2 import ( + EquiformerV2Backbone, + eqv2_init_weights, +) +from fairchem.core.models.equiformer_v2.heads.rank2 import ( + Rank2SymmetricTensorHead, +) +from fairchem.core.models.equiformer_v2.so3 import SO3_Embedding, SO3_LinearV2 +from fairchem.core.models.equiformer_v2.transformer_block import ( + FeedForwardNetwork, + SO2EquivariantGraphAttention, +) + +if TYPE_CHECKING: + from torch_geometric.data.batch import Batch + + from fairchem.core.models.base import BackboneInterface + + +@registry.register_model("equiformer_v2_dens_backbone") +class EqV2DeNSBackbone(EquiformerV2Backbone): + """ + DeNS extra Args: + use_force_encoding (bool): For ablation study, whether to encode forces during denoising positions. Default: True. + use_noise_schedule_sigma_encoding (bool): For ablation study, whether to encode the sigma (sampled std of Gaussian noises) during + denoising positions when `fixed_noise_std` = False in config files. Default: False. + """ + + def __init__( + self, + use_pbc: bool = True, + use_pbc_single: bool = False, + regress_forces: bool = True, + otf_graph: bool = True, + max_neighbors: int = 500, + max_radius: float = 5.0, + max_num_elements: int = 90, + num_layers: int = 12, + sphere_channels: int = 128, + attn_hidden_channels: int = 128, + num_heads: int = 8, + attn_alpha_channels: int = 32, + attn_value_channels: int = 16, + ffn_hidden_channels: int = 512, + norm_type: str = "rms_norm_sh", + lmax_list: list[int] | None = None, + mmax_list: list[int] | None = None, + grid_resolution: int | None = None, + num_sphere_samples: int = 128, + edge_channels: int = 128, + use_atom_edge_embedding: bool = True, + share_atom_edge_embedding: bool = False, + use_m_share_rad: bool = False, + distance_function: str = "gaussian", + num_distance_basis: int = 512, + attn_activation: str = "scaled_silu", + use_s2_act_attn: bool = False, + use_attn_renorm: bool = True, + ffn_activation: str = "scaled_silu", + use_gate_act: bool = False, + use_grid_mlp: bool = False, + use_sep_s2_act: bool = True, + alpha_drop: float = 0.1, + drop_path_rate: float = 0.05, + proj_drop: float = 0.0, + weight_init: str = "normal", + enforce_max_neighbors_strictly: bool = True, + avg_num_nodes: float | None = None, + avg_degree: float | None = None, + use_energy_lin_ref: bool | None = False, + load_energy_lin_ref: bool | None = False, + activation_checkpoint: bool | None = False, + use_force_encoding=True, + use_noise_schedule_sigma_encoding: bool = False, + ): + if mmax_list is None: + mmax_list = [2] + if lmax_list is None: + lmax_list = [6] + super().__init__( + use_pbc, + use_pbc_single, + regress_forces, + otf_graph, + max_neighbors, + max_radius, + max_num_elements, + num_layers, + sphere_channels, + attn_hidden_channels, + num_heads, + attn_alpha_channels, + attn_value_channels, + ffn_hidden_channels, + norm_type, + lmax_list, + mmax_list, + grid_resolution, + num_sphere_samples, + edge_channels, + use_atom_edge_embedding, + share_atom_edge_embedding, + use_m_share_rad, + distance_function, + num_distance_basis, + attn_activation, + use_s2_act_attn, + use_attn_renorm, + ffn_activation, + use_gate_act, + use_grid_mlp, + use_sep_s2_act, + alpha_drop, + drop_path_rate, + proj_drop, + weight_init, + enforce_max_neighbors_strictly, + avg_num_nodes, + avg_degree, + use_energy_lin_ref, + load_energy_lin_ref, + activation_checkpoint, + ) + + # for denoising position + self.use_force_encoding = use_force_encoding + self.use_noise_schedule_sigma_encoding = use_noise_schedule_sigma_encoding + + # for denoising position, encode node-wise forces as node features + self.irreps_sh = o3.Irreps.spherical_harmonics(lmax=max(self.lmax_list), p=1) + self.force_embedding = SO3_LinearV2( + in_features=1, out_features=self.sphere_channels, lmax=max(self.lmax_list) + ) + + if self.use_noise_schedule_sigma_encoding: + self.noise_schedule_sigma_embedding = torch.nn.Linear( + in_features=1, out_features=self.sphere_channels + ) + + self.apply(partial(eqv2_init_weights, weight_init=self.weight_init)) + + @conditional_grad(torch.enable_grad()) + def forward(self, data) -> dict[str, torch.Tensor]: + self.batch_size = len(data.natoms) + self.dtype = data.pos.dtype + self.device = data.pos.device + num_atoms = len(data.atomic_numbers) + atomic_numbers = data.atomic_numbers.long() + graph = self.generate_graph( + data, + enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly, + ) + + data_batch = data.batch + if gp_utils.initialized(): + ( + atomic_numbers, + data_batch, + node_offset, + edge_index, + edge_distance, + edge_distance_vec, + ) = self._init_gp_partitions( + graph.atomic_numbers_full, + graph.batch_full, + graph.edge_index, + graph.edge_distance, + graph.edge_distance_vec, + ) + graph.node_offset = node_offset + graph.edge_index = edge_index + graph.edge_distance = edge_distance + graph.edge_distance_vec = edge_distance_vec + + ############################################################### + # Entering Graph Parallel Region + # after this point, if using gp, then node, edge tensors are split + # across the graph parallel ranks, some full tensors such as + # atomic_numbers_full are required because we need to index into the + # full graph when computing edge embeddings or reducing nodes from neighbors + # + # all tensors that do not have the suffix "_full" refer to the partial tensors. + # if not using gp, the full values are equal to the partial values + # ie: atomic_numbers_full == atomic_numbers + ############################################################### + + ############################################################### + # Initialize data structures + ############################################################### + + # Compute 3x3 rotation matrix per edge + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) + + # Initialize the WignerD matrices and other values for spherical harmonic calculations + for i in range(self.num_resolutions): + self.SO3_rotation[i].set_wigner(edge_rot_mat) + + ############################################################### + # Initialize node embeddings + ############################################################### + + # Init per node representations using an atomic number based embedding + x = SO3_Embedding( + len(atomic_numbers), + self.lmax_list, + self.sphere_channels, + self.device, + self.dtype, + ) + + offset_res = 0 + offset = 0 + # Initialize the l = 0, m = 0 coefficients for each resolution + for i in range(self.num_resolutions): + if self.num_resolutions == 1: + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers) + else: + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[ + :, offset : offset + self.sphere_channels + ] + offset = offset + self.sphere_channels + offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) + + ################## + ### DeNS Start ### + ################## + + # Node-wise force encoding during denoising positions + force_embedding = SO3_Embedding( + num_atoms, self.lmax_list, 1, self.device, self.dtype + ) + if hasattr(data, "denoising_pos_forward") and data.denoising_pos_forward: + assert hasattr(data, "forces") + force_data = data.forces + force_sh = o3.spherical_harmonics( + l=self.irreps_sh, + x=force_data, + normalize=True, + normalization="component", + ) + force_sh = force_sh.view(num_atoms, (max(self.lmax_list) + 1) ** 2, 1) + force_norm = force_data.norm(dim=-1, keepdim=True) + if hasattr(data, "noise_mask"): + noise_mask_tensor = data.noise_mask.view(-1, 1, 1) + force_sh = force_sh * noise_mask_tensor + else: + force_sh = torch.zeros( + (num_atoms, (max(self.lmax_list) + 1) ** 2, 1), + dtype=data.pos.dtype, + device=data.pos.device, + ) + force_norm = torch.zeros( + (num_atoms, 1), dtype=data.pos.dtype, device=data.pos.device + ) + + if not self.use_force_encoding: + # for ablation study, we enforce the force encoding to be zero. + force_sh = torch.zeros( + (num_atoms, (max(self.lmax_list) + 1) ** 2, 1), + dtype=data.pos.dtype, + device=data.pos.device, + ) + force_norm = torch.zeros( + (num_atoms, 1), dtype=data.pos.dtype, device=data.pos.device + ) + + force_norm = force_norm.view(-1, 1, 1) + force_norm = force_norm / math.sqrt( + 3.0 + ) # since we use `component` normalization + force_embedding.embedding = force_sh * force_norm + + force_embedding = self.force_embedding(force_embedding) + x.embedding = x.embedding + force_embedding.embedding + + # noise schedule sigma encoding + if self.use_noise_schedule_sigma_encoding: + if hasattr(data, "denoising_pos_forward") and data.denoising_pos_forward: + assert hasattr(data, "sigmas") + sigmas = data.sigmas + else: + sigmas = torch.zeros( + (num_atoms, 1), dtype=data.pos.dtype, device=data.pos.device + ) + noise_schedule_sigma_enbedding = self.noise_schedule_sigma_embedding(sigmas) + x.embedding[:, 0, :] = x.embedding[:, 0, :] + noise_schedule_sigma_enbedding + + ################## + ### DeNS End ### + ################## + + # Edge encoding (distance and atom edge) + graph.edge_distance = self.distance_expansion(graph.edge_distance) + if self.share_atom_edge_embedding and self.use_atom_edge_embedding: + source_element = graph.atomic_numbers_full[ + graph.edge_index[0] + ] # Source atom atomic number + target_element = graph.atomic_numbers_full[ + graph.edge_index[1] + ] # Target atom atomic number + source_embedding = self.source_embedding(source_element) + target_embedding = self.target_embedding(target_element) + graph.edge_distance = torch.cat( + (graph.edge_distance, source_embedding, target_embedding), dim=1 + ) + + # Edge-degree embedding + edge_degree = self.edge_degree_embedding( + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + len(atomic_numbers), + graph.node_offset, + ) + x.embedding = x.embedding + edge_degree.embedding + + ############################################################### + # Update spherical node embeddings + ############################################################### + + for i in range(self.num_layers): + if self.activation_checkpoint: + x = torch.utils.checkpoint.checkpoint( + self.blocks[i], + x, # SO3_Embedding + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + data_batch, # for GraphDropPath + graph.node_offset, + use_reentrant=not self.training, + ) + else: + x = self.blocks[i]( + x, # SO3_Embedding + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + batch=data_batch, # for GraphDropPath + node_offset=graph.node_offset, + ) + + # Final layer norm + x.embedding = self.norm(x.embedding) + + return {"node_embedding": x, "graph": graph} + + +@registry.register_model("eqV2_DeNS_scalar_head") +class DeNSScalarHead(torch.nn.Module, HeadInterface): + def __init__( + self, + backbone: BackboneInterface, + output_name: str = "energy", + reduce: Literal["sum", "mean"] = "sum", + use_denoising: bool = True, + ): + """ + Args: + backbone: Model backbone + output_name: property output name + reduce: reduction, mean or sum. Use mean for intensive properties and sum for extensive ones. + use_denoising: For ablation study, whether to predict the energy of the original structure given + a corrupted structure. If `False`, we zero out the energy prediction. Default: True. + """ + super().__init__() + self.reduce = reduce + self.avg_num_nodes = backbone.avg_num_nodes + self.scalar_block = FeedForwardNetwork( + backbone.sphere_channels, + backbone.ffn_hidden_channels, + 1, + backbone.lmax_list, + backbone.mmax_list, + backbone.SO3_grid, + backbone.ffn_activation, + backbone.use_gate_act, + backbone.use_grid_mlp, + backbone.use_sep_s2_act, + ) + self.output_name = output_name + self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init)) + self.use_denoising = use_denoising + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor | GraphData] + ) -> dict[str, torch.Tensor]: + node_out = self.scalar_block(emb["node_embedding"]) + node_out = node_out.embedding.narrow(1, 0, 1) + if gp_utils.initialized(): + node_out = gp_utils.gather_from_model_parallel_region(node_out, dim=0) + output_scalar = torch.zeros( + len(data.natoms), + device=node_out.device, + dtype=node_out.dtype, + ) + + output_scalar.index_add_(0, data.batch, node_out.view(-1)) + + if ( + hasattr(data, "denoising_pos_forward") + and data.denoising_pos_forward + and not self.use_denoising + ): + output_scalar = output_scalar * 0.0 + + if self.reduce == "sum": + return {self.output_name: output_scalar / self.avg_num_nodes} + elif self.reduce == "mean": + return {self.output_name: output_scalar / data.natoms} + else: + raise ValueError( + f"reduce can only be sum or mean, user provided: {self.reduce}" + ) + + +@registry.register_model("eqV2_DeNS_vector_head") +class DeNSVectorHead(torch.nn.Module, HeadInterface): + def __init__(self, backbone: BackboneInterface, output_name: str = "forces"): + super().__init__() + + self.output_name = output_name + self.activation_checkpoint = backbone.activation_checkpoint + + self.vector_block = SO2EquivariantGraphAttention( + backbone.sphere_channels, + backbone.attn_hidden_channels, + backbone.num_heads, + backbone.attn_alpha_channels, + backbone.attn_value_channels, + 1, + backbone.lmax_list, + backbone.mmax_list, + backbone.SO3_rotation, + backbone.mappingReduced, + backbone.SO3_grid, + backbone.max_num_elements, + backbone.edge_channels_list, + backbone.block_use_atom_edge_embedding, + backbone.use_m_share_rad, + backbone.attn_activation, + backbone.use_s2_act_attn, + backbone.use_attn_renorm, + backbone.use_gate_act, + backbone.use_sep_s2_act, + alpha_drop=0.0, + ) + + self.denoising_pos_block = SO2EquivariantGraphAttention( + backbone.sphere_channels, + backbone.attn_hidden_channels, + backbone.num_heads, + backbone.attn_alpha_channels, + backbone.attn_value_channels, + 1, + backbone.lmax_list, + backbone.mmax_list, + backbone.SO3_rotation, + backbone.mappingReduced, + backbone.SO3_grid, + backbone.max_num_elements, + backbone.edge_channels_list, + backbone.block_use_atom_edge_embedding, + backbone.use_m_share_rad, + backbone.attn_activation, + backbone.use_s2_act_attn, + backbone.use_attn_renorm, + backbone.use_gate_act, + backbone.use_sep_s2_act, + alpha_drop=0.0, + ) + self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init)) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + if self.activation_checkpoint: + output_vector = torch.utils.checkpoint.checkpoint( + self.vector_block, + emb["node_embedding"], + emb["graph"].atomic_numbers_full, + emb["graph"].edge_distance, + emb["graph"].edge_index, + emb["graph"].node_offset, + use_reentrant=not self.training, + ) + denoising_pos_vec = torch.utils.checkpoint.checkpoint( + self.denoising_pos_block, + emb["node_embedding"], + emb["graph"].atomic_numbers_full, + emb["graph"].edge_distance, + emb["graph"].edge_index, + emb["graph"].node_offset, + use_reentrant=not self.training, + ) + else: + output_vector = self.vector_block( + emb["node_embedding"], + emb["graph"].atomic_numbers_full, + emb["graph"].edge_distance, + emb["graph"].edge_index, + node_offset=emb["graph"].node_offset, + ) + denoising_pos_vec = self.denoising_pos_block( + emb["node_embedding"], + emb["graph"].atomic_numbers_full, + emb["graph"].edge_distance, + emb["graph"].edge_index, + node_offset=emb["graph"].node_offset, + ) + output_vector = output_vector.embedding.narrow(1, 1, 3) + output_vector = output_vector.view(-1, 3).contiguous() + denoising_pos_vec = denoising_pos_vec.embedding.narrow(1, 1, 3) + denoising_pos_vec = denoising_pos_vec.view(-1, 3) + if gp_utils.initialized(): + output_vector = gp_utils.gather_from_model_parallel_region( + output_vector, dim=0 + ) + denoising_pos_vec = gp_utils.gather_from_model_parallel_region( + denoising_pos_vec, dim=0 + ) + + if hasattr(data, "denoising_pos_forward") and data.denoising_pos_forward: + if hasattr(data, "noise_mask"): + noise_mask_tensor = data.noise_mask.view(-1, 1) + output_vector = ( + denoising_pos_vec * noise_mask_tensor + + output_vector * (~noise_mask_tensor) + ) + else: + output_vector = denoising_pos_vec + 0 * output_vector + else: + output_vector = 0 * denoising_pos_vec + output_vector + + return {self.output_name: output_vector} + + +@registry.register_model("dens_rank2_symmetric_head") +class DeNSRank2Head(Rank2SymmetricTensorHead): + def __init__( + self, backbone: BackboneInterface, *args, use_denoising: bool = True, **kwargs + ): + super().__init__(backbone, *args, **kwargs) + self.use_denoising = use_denoising + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + output = super().forward(data, emb) + if ( + hasattr(data, "denoising_pos_forward") + and data.denoising_pos_forward + and not self.use_denoising + ): + for k in output: + output[k] = output[k] * 0.0 + return output diff --git a/src/fairchem/core/models/equiformer_v2/trainers/dens_trainer.py b/src/fairchem/core/models/equiformer_v2/trainers/dens_trainer.py new file mode 100644 index 0000000000..11735d7bb9 --- /dev/null +++ b/src/fairchem/core/models/equiformer_v2/trainers/dens_trainer.py @@ -0,0 +1,854 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import logging +from collections import defaultdict +from dataclasses import dataclass +from functools import cached_property +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch +import torch_geometric +from tqdm import tqdm + +from fairchem.core.common import distutils +from fairchem.core.common.registry import registry +from fairchem.core.modules.evaluator import mae +from fairchem.core.modules.normalization.normalizer import Normalizer +from fairchem.core.modules.scaling.util import ensure_fitted + +from .forces_trainer import EquiformerV2ForcesTrainer + +if TYPE_CHECKING: + from fairchem.core.modules.evaluator import Evaluator + + +@dataclass +class DenoisingPosParams: + prob: float = 0.0 + fixed_noise_std: bool = False + std: float = None + num_steps: int = None + std_low: float = None + std_high: float = None + corrupt_ratio: float = None + all_atoms: bool = False + denoising_pos_coefficient: float = None + + +def add_gaussian_noise_to_position(batch, std, corrupt_ratio=None, all_atoms=False): + """ + 1. Update `pos` in `batch`. + 2. Add `noise_vec` to `batch`, which will serve as the target for denoising positions. + 3. Add `denoising_pos_forward` to switch to denoising mode during training. + 4. Add `noise_mask` for partially corrupted structures when `corrupt_ratio` is not None. + 5. If `all_atoms` == True, we add noise to all atoms including fixed ones. + 6. Check whether `batch` has `md`. We do not add noise to structures from MD split. + """ + noise_vec = torch.zeros_like(batch.pos) + noise_vec = noise_vec.normal_(mean=0.0, std=std) + + if corrupt_ratio is not None: + noise_mask = torch.rand( + (batch.pos.shape[0]), + dtype=batch.pos.dtype, + device=batch.pos.device, + ) + noise_mask = noise_mask < corrupt_ratio + noise_vec[(~noise_mask)] *= 0 + batch.noise_mask = noise_mask + + # Not add noise to structures from MD split + if hasattr(batch, "md"): + batch_index = batch.batch + md_index = batch.md.bool() + md_index = md_index[batch_index] + noise_mask = ~md_index + noise_vec[(~noise_mask)] *= 0 + if hasattr(batch, "noise_mask"): + batch.noise_mask = batch.noise_mask * noise_mask + else: + batch.noise_mask = noise_mask + + pos = batch.pos + new_pos = pos + noise_vec + if all_atoms: + batch.pos = new_pos + else: + free_mask = batch.fixed == 0.0 + batch.pos[free_mask] = new_pos[free_mask] + + batch.noise_vec = noise_vec + batch.denoising_pos_forward = True + + return batch + + +def add_gaussian_noise_schedule_to_position( + batch, std_low, std_high, num_steps, corrupt_ratio=None, all_atoms=False +): + """ + 1. Similar to above, update positions in batch with gaussian noise, but + additionally, also save the sigmas the noise vectors are sampled from. + 2. Add `noise_mask` for partially corrupted structures when `corrupt_ratio` + is not None. + 3. If `all_atoms` == True, we add noise to all atoms including fixed ones. + 4. Check whether `batch` has `md`. We do not add noise to structures from MD split. + """ + sigmas = torch.tensor( + np.exp(np.linspace(np.log(std_low), np.log(std_high), num_steps)), + dtype=torch.float32, + ) + # select a sigma for each structure, and project it all atoms in the structure. + ts = torch.randint(0, num_steps, size=(batch.natoms.size(0),)) + batch.sigmas = sigmas[ts][batch.batch][:, None] # (natoms, 1) + noise_vec = torch.zeros_like(batch.pos) + noise_vec = noise_vec.normal_() * batch.sigmas + + if corrupt_ratio is not None: + noise_mask = torch.rand( + (batch.pos.shape[0]), + dtype=batch.pos.dtype, + device=batch.pos.device, + ) + noise_mask = noise_mask < corrupt_ratio + # noise_vec[(~noise_mask)] *= 0 + batch.noise_mask = noise_mask + + # Not add noise to structures from MD split + if hasattr(batch, "md"): + batch_index = batch.batch + md_index = batch.md.bool() + md_index = md_index[batch_index] + noise_mask = ~md_index + # noise_vec[(~noise_mask)] *= 0 + if hasattr(batch, "noise_mask"): + batch.noise_mask = batch.noise_mask * noise_mask + else: + batch.noise_mask = noise_mask + + if hasattr(batch, "noise_mask"): + noise_vec[(~batch.noise_mask)] *= 0 + + # only add noise to free atoms + pos = batch.pos + new_pos = pos + noise_vec + if all_atoms: + batch.pos = new_pos + else: + free_mask = batch.fixed == 0.0 + batch.pos[free_mask] = new_pos[free_mask] + + batch.noise_vec = noise_vec + batch.denoising_pos_forward = True + + return batch + + +def denoising_pos_eval( + evaluator: Evaluator, + prediction: dict[str, torch.Tensor], + target: dict[str, torch.Tensor], + denoising_targets: tuple[str], + prev_metrics: dict[str, torch.Tensor] | None = None, + denoising_pos_forward: bool = False, +): + """ + 1. Overwrite the original Evaluator.eval() here: https://github.com/Open-Catalyst-Project/ocp/blob/5a7738f9aa80b1a9a7e0ca15e33938b4d2557edd/ocpmodels/modules/evaluator.py#L69-L81 + 2. This is to make sure we separate forces MAE and denoising positions MAE. + """ + + if not denoising_pos_forward: + return evaluator.eval(prediction, target, prev_metrics) + + metrics = prev_metrics + for target_name in denoising_targets: + res = mae(prediction, target, target_name) + metrics = evaluator.update(f"denoising_{target_name}_mae", res, metrics) + + if target.get("noise_mask") is None: + # Only update`denoising_pos_mae` during denoising positions if not using partially corrupted structures + res = mae(prediction, target, "forces") + metrics = evaluator.update("denoising_pos_mae", res, metrics) + else: # Update `denoising_pos_mae` and `denoising_force_mae` if using partially corrupted structures + # separate S2EF and denoising positions results based on `noise_mask` + target_tensor = target["forces"] + prediction_tensor = prediction["forces"] + noise_mask = target["noise_mask"] + s2ef_index = torch.where(noise_mask == 0) + s2ef_prediction = {"forces": prediction_tensor[s2ef_index]} + s2ef_target = {"forces": target_tensor[s2ef_index]} + res = mae(s2ef_prediction, s2ef_target, "forces") + if res["numel"] != 0: + metrics = evaluator.update("denoising_force_mae", res, metrics) + denoising_pos_index = torch.where(noise_mask == 1) + denoising_pos_prediction = {"forces": prediction_tensor[denoising_pos_index]} + denoising_pos_target = {"forces": target_tensor[denoising_pos_index]} + res = mae(denoising_pos_prediction, denoising_pos_target, "forces") + if res["numel"] != 0: + metrics = evaluator.update("denoising_pos_mae", res, metrics) + return metrics + + +def compute_atomwise_denoising_pos_and_force_hybrid_loss( + pred, target, noise_mask, force_mult, denoising_pos_mult, mask=None +): + loss = torch.norm(pred - target, p=2, dim=-1, keepdim=True) + force_index = torch.where(noise_mask == 0) + denoising_pos_index = torch.where(noise_mask == 1) + mult_tensor = torch.ones_like(loss) + mult_tensor[force_index] *= force_mult + mult_tensor[denoising_pos_index] *= denoising_pos_mult + loss = loss * mult_tensor + if mask is not None: + loss = loss[mask] + return torch.mean(loss) + + +@registry.register_trainer("equiformerv2_dens") +class DenoisingForcesTrainer(EquiformerV2ForcesTrainer): + """ + 1. We add a denoising objective to the original S2EF task. + 2. The denoising objective is that we take as input + atom types, node-wise forces and 3D coordinates perturbed with Gaussian noises and then + output energy of the original structure (3D coordinates without any perturbation) and + the node-wise noises added to the original structure. + 3. This should make models leverage more from training data and enable data augmentation for + the S2EF task. + 4. We should only modify the training part. + 5. For normalizing the outputs of noise prediction, if we use `fixed_noise_std = True`, we use + `std` for the normalization factor. Otherwise, we use `std_high` when `fixed_noise_std = False`. + + Args: + task (dict): Task configuration. + model (dict): Model configuration. + outputs (dict): Dictionary of model output configuration. + dataset (dict): Dataset configuration. The dataset needs to be a SinglePointLMDB dataset. + optimizer (dict): Optimizer configuration. + loss_functions (dict): Loss function configuration. + evaluation_metrics (dict): Evaluation metrics configuration. + identifier (str): Experiment identifier that is appended to log directory. + run_dir (str, optional): Path to the run directory where logs are to be saved. + (default: :obj:`None`) + timestamp_id (str, optional): timestamp identifier. + run_dir (str, optional): Run directory used to save checkpoints and results. + is_debug (bool, optional): Run in debug mode. + (default: :obj:`False`) + print_every (int, optional): Frequency of printing logs. + (default: :obj:`100`) + seed (int, optional): Random number seed. + (default: :obj:`None`) + logger (str, optional): Type of logger to be used. + (default: :obj:`wandb`) + local_rank (int, optional): Local rank of the process, only applicable for distributed training. + (default: :obj:`0`) + amp (bool, optional): Run using automatic mixed precision. + (default: :obj:`False`) + cpu (bool): If True will run on CPU. Default is False, will attempt to use cuda. + name (str): Trainer name. + slurm (dict): Slurm configuration. Currently just for keeping track. + (default: :obj:`{}`) + gp_gpus (int, optional): Number of graph parallel GPUs. + inference_only (bool): If true trainer will be loaded for inference only. + (ie datasets, optimizer, schedular, etc, will not be instantiated) + """ + + def __init__( + self, + task: dict[str, str | Any], + model: dict[str, Any], + outputs: dict[str, str | int], + dataset: dict[str, str | float], + optimizer: dict[str, str | float], + loss_functions: dict[str, str | float], + evaluation_metrics: dict[str, str], + identifier: str, + # TODO: dealing with local rank is dangerous + # T201111838 remove this and use CUDA_VISIBILE_DEVICES instead so trainers don't need to know about which devie to use + local_rank: int, + timestamp_id: str | None = None, + run_dir: str | None = None, + is_debug: bool = False, + print_every: int = 100, + seed: int | None = None, + logger: str = "wandb", + amp: bool = False, + cpu: bool = False, + name: str = "ocp", + slurm: dict | None = None, + gp_gpus: int | None = None, + inference_only: bool = False, + ): + if slurm is None: + slurm = {} + super().__init__( + task=task, + model=model, + outputs=outputs, + dataset=dataset, + optimizer=optimizer, + loss_functions=loss_functions, + evaluation_metrics=evaluation_metrics, + identifier=identifier, + timestamp_id=timestamp_id, + run_dir=run_dir, + is_debug=is_debug, + print_every=print_every, + seed=seed, + logger=logger, + local_rank=local_rank, + amp=amp, + cpu=cpu, + slurm=slurm, + name=name, + gp_gpus=gp_gpus, + inference_only=inference_only, + ) + + # for denoising positions + self.use_denoising_pos = self.config["optim"]["use_denoising_pos"] + self.denoising_pos_params = DenoisingPosParams( + **self.config["optim"]["denoising_pos_params"] + ) + self.denoising_pos_params.denoising_pos_coefficient = self.config["optim"][ + "denoising_pos_coefficient" + ] + self.normalizers["denoising_pos_target"] = Normalizer( + mean=0.0, + rmsd=( + self.denoising_pos_params.std + if self.denoising_pos_params.fixed_noise_std + else self.denoising_pos_params.std_high + ), + ) + self.normalizers["denoising_pos_target"].to(self.device) + + @cached_property + def denoising_targets(self): + return tuple( + head.output_name + for head in self._unwrapped_model.output_heads.values() + if getattr(head, "use_denoising", False) + ) + + def train(self, disable_eval_tqdm=False): + ensure_fitted(self._unwrapped_model, warn=True) + + eval_every = self.config["optim"].get("eval_every", len(self.train_loader)) + checkpoint_every = self.config["optim"].get("checkpoint_every", eval_every) + primary_metric = self.evaluation_metrics.get( + "primary_metric", self.evaluator.task_primary_metric[self.name] + ) + if not hasattr(self, "primary_metric") or self.primary_metric != primary_metric: + self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0 + else: + primary_metric = self.primary_metric + self.metrics = {} + + # Calculate start_epoch from step instead of loading the epoch number + # to prevent inconsistencies due to different batch size in checkpoint. + start_epoch = self.step // len(self.train_loader) + + for epoch_int in range(start_epoch, self.config["optim"]["max_epochs"]): + skip_steps = self.step % len(self.train_loader) + self.train_sampler.set_epoch_and_start_iteration(epoch_int, skip_steps) + train_loader_iter = iter(self.train_loader) + + for i in range(skip_steps, len(self.train_loader)): + self.epoch = epoch_int + (i + 1) / len(self.train_loader) + self.step = epoch_int * len(self.train_loader) + i + 1 + self.model.train() + + # Get a batch. + batch = next(train_loader_iter) + + # for denoising positions + if ( + self.use_denoising_pos + and np.random.rand() < self.denoising_pos_params.prob + ): + if self.denoising_pos_params.fixed_noise_std: + batch = add_gaussian_noise_to_position( + batch, + std=self.denoising_pos_params.std, + corrupt_ratio=self.denoising_pos_params.corrupt_ratio, + all_atoms=self.denoising_pos_params.all_atoms, + ) + else: + batch = add_gaussian_noise_schedule_to_position( + batch, + std_low=self.denoising_pos_params.std_low, + std_high=self.denoising_pos_params.std_high, + num_steps=self.denoising_pos_params.num_steps, + corrupt_ratio=self.denoising_pos_params.corrupt_ratio, + all_atoms=self.denoising_pos_params.all_atoms, + ) + + # Forward, loss, backward. #TODO update this with new signatures + with torch.cuda.amp.autocast(enabled=self.scaler is not None): + out = self._forward(batch) + loss = self._compute_loss(out, batch) + + # Compute metrics. + self.metrics = self._compute_metrics( + out, + batch, + self.evaluator, + self.metrics, + ) + self.metrics = self.evaluator.update("loss", loss.item(), self.metrics) + + loss = self.scaler.scale(loss) if self.scaler else loss + self._backward(loss) + + # Log metrics. + log_dict = {k: self.metrics[k]["metric"] for k in self.metrics} + log_dict.update( + { + "lr": self.scheduler.get_lr(), + "epoch": self.epoch, + "step": self.step, + } + ) + if ( + self.step % self.config["cmd"]["print_every"] == 0 + or i == 0 + or i == (len(self.train_loader) - 1) + ) and distutils.is_master(): + log_str = [f"{k}: {v:.2e}" for k, v in log_dict.items()] + logging.info(", ".join(log_str)) + self.metrics = {} + + if self.logger is not None: + self.logger.log( + log_dict, + step=self.step, + split="train", + ) + + if checkpoint_every != -1 and self.step % checkpoint_every == 0: + self.save(checkpoint_file="checkpoint.pt", training_state=True) + + # Evaluate on val set every `eval_every` iterations. + if self.step % eval_every == 0 or i == (len(self.train_loader) - 1): + if self.val_loader is not None: + if i == (len(self.train_loader) - 1): + self.save( + checkpoint_file="checkpoint.pt", + training_state=True, + ) + + val_metrics = self.validate( + split="val", disable_tqdm=disable_eval_tqdm + ) + self.update_best( + primary_metric, + val_metrics, + disable_eval_tqdm=disable_eval_tqdm, + ) + + if self.config["task"].get("eval_relaxations", False): + if "relax_dataset" not in self.config["task"]: + logging.warning( + "Cannot evaluate relaxations, relax_dataset not specified" + ) + else: + self.run_relaxations() + + if self.scheduler.scheduler_type == "ReduceLROnPlateau": + if self.step % eval_every == 0: + self.scheduler.step( + metrics=val_metrics[primary_metric]["metric"], + ) + else: + self.scheduler.step() + + torch.cuda.empty_cache() + + if checkpoint_every == -1: + self.save(checkpoint_file="checkpoint.pt", training_state=True) + + def _compute_loss(self, out, batch): + batch_size = batch.natoms.numel() + fixed = batch.fixed + mask = fixed == 0 + + loss = [] + for loss_fn in self.loss_functions: + target_name, loss_info = loss_fn + + if target_name == "forces" and batch.get("denoising_pos_forward", False): + denoising_pos_target = batch.noise_vec + if self.normalizers.get("denoising_pos_target", False): + denoising_pos_target = self.normalizers[ + "denoising_pos_target" + ].norm(denoising_pos_target) + + if hasattr(batch, "noise_mask"): + # for partially corrupted structures + target = batch.forces + if self.normalizers.get("forces", False): + target = self.normalizers["forces"].norm(target) + noise_mask = batch.noise_mask.view(-1, 1) + target = denoising_pos_target * noise_mask + target * (~noise_mask) + else: + target = denoising_pos_target + + pred = out["forces"] + natoms = batch.natoms + natoms = torch.repeat_interleave(natoms, natoms) + + force_mult = loss_info["coefficient"] + denoising_pos_mult = self.denoising_pos_params.denoising_pos_coefficient + + if ( + self.output_targets[target_name]["level"] == "atom" + and self.output_targets[target_name]["train_on_free_atoms"] + ): + # If `all_atoms` == True when training on only free atoms, + # we also add noise to and denoise fixed atoms. + if self.denoising_pos_params.all_atoms: + if hasattr(batch, "noise_mask"): + mask = mask.view(-1, 1) | noise_mask + else: + mask = torch.ones_like( + mask, dtype=torch.bool, device=mask.device + ).view(-1, 1) + + if hasattr(batch, "noise_mask"): + # for partially corrupted structures + loss.append( + compute_atomwise_denoising_pos_and_force_hybrid_loss( + pred=pred, + target=target, + noise_mask=noise_mask, + force_mult=force_mult, + denoising_pos_mult=denoising_pos_mult, + mask=mask, + ) + ) + else: + target = target[mask] + pred = pred[mask] + natoms = natoms[mask] + + loss.append( + denoising_pos_mult + * loss_info["fn"]( + pred, + target, + natoms=natoms, + ) + ) + else: + if hasattr(batch, "noise_mask"): + # for partially corrupted structures + loss.append( + compute_atomwise_denoising_pos_and_force_hybrid_loss( + pred=pred, + target=target, + noise_mask=noise_mask, + force_mult=force_mult, + denoising_pos_mult=denoising_pos_mult, + mask=None, + ) + ) + else: + loss.append( + denoising_pos_mult + * loss_info["fn"]( + pred, + target, + natoms=natoms, + ) + ) + + else: + target = batch[target_name] + pred = out[target_name] + natoms = batch.natoms + natoms = torch.repeat_interleave(natoms, natoms) + + if ( + self.output_targets[target_name]["level"] == "atom" + and self.output_targets[target_name]["train_on_free_atoms"] + ): + target = target[mask] + pred = pred[mask] + natoms = natoms[mask] + + num_atoms_in_batch = natoms.numel() + + ### reshape accordingly: num_atoms_in_batch, -1 or num_systems_in_batch, -1 + if self.output_targets[target_name]["level"] == "atom": + target = target.view(num_atoms_in_batch, -1) + else: + target = target.view(batch_size, -1) + + # to keep the loss coefficient weights balanced we remove linear references + # subtract element references from target data + if target_name in self.elementrefs: + target = self.elementrefs[target_name].dereference(target, batch) + # normalize the targets data + if target_name in self.normalizers: + target = self.normalizers[target_name].norm(target) + + mult = loss_info["coefficient"] + + loss.append( + mult + * loss_info["fn"]( + pred, + target, + natoms=batch.natoms, + ) + ) + + # Sanity check to make sure the compute graph is correct. + for lc in loss: + assert hasattr(lc, "grad_fn") + + return sum(loss) + + def _compute_metrics(self, out, batch, evaluator, metrics=None): + if metrics is None: + metrics = {} + # this function changes the values in the out dictionary, + # make a copy instead of changing them in the callers version + out = {k: v.clone() for k, v in out.items()} + + natoms = batch.natoms + batch_size = natoms.numel() + + ### Retrieve free atoms + fixed = batch.fixed + mask = fixed == 0 + + s_idx = 0 + natoms_free = [] + for _natoms in natoms: + natoms_free.append(torch.sum(mask[s_idx : s_idx + _natoms]).item()) + s_idx += _natoms + natoms = torch.LongTensor(natoms_free).to(self.device) + + denoising_pos_forward = bool(batch.get("denoising_pos_forward", False)) + + targets = {} + for target_name in self.output_targets: + num_atoms_in_batch = batch.natoms.sum() + + if denoising_pos_forward and target_name == "forces": + if hasattr(batch, "noise_mask"): + force_target = batch.forces + denoising_pos_target = batch.noise_vec + noise_mask = batch.noise_mask + s2ef_index = torch.where(noise_mask == 0) + denoising_pos_index = torch.where(noise_mask == 1) + noise_mask_tensor = noise_mask.view(-1, 1) + targets["forces"] = ( + denoising_pos_target * noise_mask_tensor + + force_target * (~noise_mask_tensor) + ) + targets["noise_mask"] = noise_mask + else: + targets["forces"] = batch.noise_vec + + if "denoising_pos_target" in self.normalizers: + if hasattr(batch, "noise_mask"): + out["forces"][denoising_pos_index] = self.normalizers[ + "denoising_pos_target" + ].denorm(out["forces"][denoising_pos_index]) + else: + out["forces"] = self.normalizers["denoising_pos_target"].denorm( + out["forces"] + ) + + if hasattr(batch, "noise_mask"): + out["forces"][s2ef_index] = self.normalizers["forces"].denorm( + out["forces"][s2ef_index] + ) + + if ( + self.output_targets[target_name]["level"] == "atom" + and self.output_targets[target_name]["eval_on_free_atoms"] + ): + if self.denoising_pos_params.all_atoms: + if hasattr(batch, "noise_mask"): + mask = mask | noise_mask + else: + mask = torch.ones_like( + mask, dtype=torch.bool, device=mask.device + ) + + targets["forces"] = targets["forces"][mask] + out["forces"] = out["forces"][mask] + num_atoms_in_batch = natoms.sum() + if "noise_mask" in targets: + targets["noise_mask"] = targets["noise_mask"][mask] + else: + target = batch[target_name] + + if ( + self.output_targets[target_name]["level"] == "atom" + and self.output_targets[target_name]["eval_on_free_atoms"] + ): + target = target[mask] + out[target_name] = out[target_name][mask] + num_atoms_in_batch = natoms.sum() + + ### reshape accordingly: num_atoms_in_batch, -1 or num_systems_in_batch, -1 + if self.output_targets[target_name]["level"] == "atom": + target = target.view(num_atoms_in_batch, -1) + else: + target = target.view(batch_size, -1) + + out[target_name] = self._denorm_preds( + target_name, out[target_name], batch + ) + targets[target_name] = target + + targets["natoms"] = natoms + out["natoms"] = natoms + + return denoising_pos_eval( + evaluator, + out, + targets, + denoising_targets=self.denoising_targets, + prev_metrics=metrics, + denoising_pos_forward=denoising_pos_forward, + ) + + @torch.no_grad() + def predict( + self, + data_loader, + per_image: bool = True, + results_file: str | None = None, + disable_tqdm: bool = False, + ): + if self.is_debug and per_image: + raise FileNotFoundError("Predictions require debug mode to be turned off.") + + ensure_fitted(self._unwrapped_model, warn=True) + + if distutils.is_master() and not disable_tqdm: + logging.info("Predicting on test.") + assert isinstance( + data_loader, + ( + torch.utils.data.dataloader.DataLoader, + torch_geometric.data.Batch, + ), + ) + rank = distutils.get_rank() + + if isinstance(data_loader, torch_geometric.data.Batch): + data_loader = [data_loader] + + self.model.eval() + if self.ema is not None: + self.ema.store() + self.ema.copy_to() + + predictions = defaultdict(list) + + for _, batch in tqdm( + enumerate(data_loader), + total=len(data_loader), + position=rank, + desc=f"device {rank}", + disable=disable_tqdm, + ): + with torch.cuda.amp.autocast(enabled=self.scaler is not None): + out = self._forward(batch) + + for key in out: + out[key] = out[key].float() + + for target_key in self.config["outputs"]: + pred = self._denorm_preds(target_key, out[target_key], batch) + + if per_image: + ### Save outputs in desired precision, default float16 + if ( + self.config["outputs"][target_key].get( + "prediction_dtype", "float16" + ) + == "float32" + or self.config["task"].get("prediction_dtype", "float16") + == "float32" + or self.config["task"].get("dataset", "lmdb") == "oc22_lmdb" + ): + dtype = torch.float32 + else: + dtype = torch.float16 + + pred = pred.detach().cpu().to(dtype) + + ### Split predictions into per-image predictions + if self.config["outputs"][target_key]["level"] == "atom": + batch_natoms = batch.natoms + batch_fixed = batch.fixed + per_image_pred = torch.split(pred, batch_natoms.tolist()) + + ### Save out only free atom, EvalAI does not need fixed atoms + _per_image_fixed = torch.split( + batch_fixed, batch_natoms.tolist() + ) + _per_image_free_preds = [ + _pred[(fixed == 0).tolist()].numpy() + for _pred, fixed in zip(per_image_pred, _per_image_fixed) + ] + _chunk_idx = np.array( + [free_pred.shape[0] for free_pred in _per_image_free_preds] + ) + per_image_pred = _per_image_free_preds + ### Assumes system level properties are of the same dimension + else: + per_image_pred = pred.numpy() + _chunk_idx = None + + predictions[f"{target_key}"].extend(per_image_pred) + ### Backwards compatibility, retain 'chunk_idx' for forces. + if _chunk_idx is not None: + if target_key == "forces": + predictions["chunk_idx"].extend(_chunk_idx) + else: + predictions[f"{target_key}_chunk_idx"].extend(_chunk_idx) + else: + predictions[f"{target_key}"] = pred.detach() + + if not per_image: + return predictions + + ### Get unique system identifiers + sids = ( + batch.sid.tolist() if isinstance(batch.sid, torch.Tensor) else batch.sid + ) + ## Support naming structure for OC20 S2EF + if "fid" in batch: + fids = ( + batch.fid.tolist() + if isinstance(batch.fid, torch.Tensor) + else batch.fid + ) + systemids = [f"{sid}_{fid}" for sid, fid in zip(sids, fids)] + else: + systemids = [f"{sid}" for sid in sids] + + predictions["ids"].extend(systemids) + + self.save_results(predictions, results_file) + + if self.ema: + self.ema.restore() + + return predictions diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index acd18d6e1c..90cdce0e58 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -101,6 +101,7 @@ def __init__( self.cpu = cpu self.epoch = 0 self.step = 0 + self.ema = None if torch.cuda.is_available() and not self.cpu: logging.info(f"local rank base: {local_rank}") @@ -617,7 +618,7 @@ def load_checkpoint( "Loading checkpoint in inference-only mode, not loading keys associated with trainer state!" ) - if "ema" in checkpoint and checkpoint["ema"] is not None: + if "ema" in checkpoint and checkpoint["ema"] is not None and self.ema: self.ema.load_state_dict(checkpoint["ema"]) else: self.ema = None diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 152cddd93b..49e9eec637 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -47,7 +47,7 @@ class OCPTrainer(BaseTrainer): Args: task (dict): Task configuration. model (dict): Model configuration. - outputs (dict): Output property configuration. + outputs (dict): Dictionary of model output configuration. dataset (dict): Dataset configuration. The dataset needs to be a SinglePointLMDB dataset. optimizer (dict): Optimizer configuration. loss_functions (dict): Loss function configuration. @@ -55,6 +55,8 @@ class OCPTrainer(BaseTrainer): identifier (str): Experiment identifier that is appended to log directory. run_dir (str, optional): Path to the run directory where logs are to be saved. (default: :obj:`None`) + timestamp_id (str, optional): timestamp identifier. + run_dir (str, optional): Run directory used to save checkpoints and results. is_debug (bool, optional): Run in debug mode. (default: :obj:`False`) print_every (int, optional): Frequency of printing logs. @@ -63,10 +65,17 @@ class OCPTrainer(BaseTrainer): (default: :obj:`None`) logger (str, optional): Type of logger to be used. (default: :obj:`wandb`) + local_rank (int, optional): Local rank of the process, only applicable for distributed training. + (default: :obj:`0`) amp (bool, optional): Run using automatic mixed precision. (default: :obj:`False`) + cpu (bool): If True will run on CPU. Default is False, will attempt to use cuda. + name (str): Trainer name. slurm (dict): Slurm configuration. Currently just for keeping track. (default: :obj:`{}`) + gp_gpus (int, optional): Number of graph parallel GPUs. + inference_only (bool): If true trainer will be loaded for inference only. + (ie datasets, optimizer, schedular, etc, will not be instantiated) """ def __init__( @@ -91,7 +100,7 @@ def __init__( amp: bool = False, cpu: bool = False, name: str = "ocp", - slurm=None, + slurm: dict | None = None, gp_gpus: int | None = None, inference_only: bool = False, ): @@ -260,6 +269,7 @@ def _forward(self, batch): ), f"we need to know which property to match the target to, please specify the property field in the task config, current config: {self.output_targets[target_key]}" prop = self.output_targets[target_key]["property"] pred = out[target_key][prop] + # TODO clean up this logic to reconstruct a tensor from its predicted decomposition elif "decomposition" in self.output_targets[target_key]: _max_rank = 0 diff --git a/tests/core/test_hydra_cli.yml b/tests/core/test_hydra_cli.yml index 2064dd986f..59e0a158ae 100644 --- a/tests/core/test_hydra_cli.yml +++ b/tests/core/test_hydra_cli.yml @@ -1,4 +1,4 @@ runner: _target_: fairchem.core.components.runner.MockRunner x: 10 - y: 23 + y: 23 \ No newline at end of file