Merge branch 'main' into uniform_initialization
misko authored Aug 15, 2024
2 parents 2ec121c + ef2a4bc commit 1467a5c
Showing 4 changed files with 93 additions and 21 deletions.
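In short, this merge brings two changes from main into uniform_initialization: an optional activation_checkpoint flag for EquiformerV2 (used by both the backbone blocks and the force head), a resolution override for eSCN's SO3_Grid, and a test that verifies gradients are identical with and without activation checkpointing.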
56 changes: 41 additions & 15 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
@@ -157,6 +157,7 @@ def __init__(
        avg_degree: float | None = None,
        use_energy_lin_ref: bool | None = False,
        load_energy_lin_ref: bool | None = False,
        activation_checkpoint: bool | None = False,
    ):
        if mmax_list is None:
            mmax_list = [2]
@@ -170,6 +171,7 @@ def __init__(
            logging.error("You need to install e3nn==0.4.4 to use EquiformerV2.")
            raise ImportError

        self.activation_checkpoint = activation_checkpoint
        self.use_pbc = use_pbc
        self.use_pbc_single = use_pbc_single
        self.regress_forces = regress_forces
@@ -805,14 +807,26 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
        ###############################################################

        for i in range(self.num_layers):
            x = self.blocks[i](
                x,  # SO3_Embedding
                graph.atomic_numbers_full,
                graph.edge_distance,
                graph.edge_index,
                batch=data_batch,  # for GraphDropPath
                node_offset=graph.node_offset,
            )
            if self.activation_checkpoint:
                x = torch.utils.checkpoint.checkpoint(
                    self.blocks[i],
                    x,  # SO3_Embedding
                    graph.atomic_numbers_full,
                    graph.edge_distance,
                    graph.edge_index,
                    data_batch,  # for GraphDropPath
                    graph.node_offset,
                    use_reentrant=not self.training,
                )
            else:
                x = self.blocks[i](
                    x,  # SO3_Embedding
                    graph.atomic_numbers_full,
                    graph.edge_distance,
                    graph.edge_index,
                    batch=data_batch,  # for GraphDropPath
                    node_offset=graph.node_offset,
                )

        # Final layer norm
        x.embedding = self.norm(x.embedding)
@@ -860,6 +874,7 @@ class EquiformerV2ForceHead(nn.Module, HeadInterface):
    def __init__(self, backbone):
        super().__init__()

        self.activation_checkpoint = backbone.activation_checkpoint
        self.force_block = SO2EquivariantGraphAttention(
            backbone.sphere_channels,
            backbone.attn_hidden_channels,
@@ -887,13 +902,24 @@ def __init__(self, backbone):
        self.apply(backbone._uniform_init_rad_func_linear_weights)

    def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
        forces = self.force_block(
            emb["node_embedding"],
            emb["graph"].atomic_numbers_full,
            emb["graph"].edge_distance,
            emb["graph"].edge_index,
            node_offset=emb["graph"].node_offset,
        )
        if self.activation_checkpoint:
            forces = torch.utils.checkpoint.checkpoint(
                self.force_block,
                emb["node_embedding"],
                emb["graph"].atomic_numbers_full,
                emb["graph"].edge_distance,
                emb["graph"].edge_index,
                emb["graph"].node_offset,
                use_reentrant=not self.training,
            )
        else:
            forces = self.force_block(
                emb["node_embedding"],
                emb["graph"].atomic_numbers_full,
                emb["graph"].edge_distance,
                emb["graph"].edge_index,
                node_offset=emb["graph"].node_offset,
            )
        # keep only the three l=1 coefficients of the output embedding, i.e. the Cartesian force vector
        forces = forces.embedding.narrow(1, 1, 3)
        forces = forces.view(-1, 3).contiguous()
        if gp_utils.initialized():
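Both hunks above follow the same pattern: standard PyTorch activation checkpointing, where torch.utils.checkpoint.checkpoint re-runs a block's forward pass during backward instead of holding its intermediate activations in memory. A minimal, self-contained sketch of the idea (the toy MLP below is only an illustration, not fairchem code):

# Minimal sketch of activation checkpointing; the toy MLP stands in for one
# EquiformerV2 block and is not part of fairchem.
import torch
import torch.nn as nn
import torch.utils.checkpoint

block = nn.Sequential(nn.Linear(16, 64), nn.SiLU(), nn.Linear(64, 16))
x = torch.randn(8, 16, requires_grad=True)

use_activation_checkpoint = True
if use_activation_checkpoint:
    # Inputs to the wrapped callable are passed positionally; the reentrant
    # implementation does not forward keyword arguments, which is presumably why
    # the diff passes data_batch and node_offset positionally.
    y = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
else:
    y = block(x)

y.sum().backward()  # gradients match the non-checkpointed path

Memory savings grow with the number of checkpointed blocks, at the cost of one extra forward pass per block during backpropagation.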
3 changes: 2 additions & 1 deletion src/fairchem/core/models/escn/escn.py
@@ -87,6 +87,7 @@ def __init__(
        basis_width_scalar: float = 1.0,
        distance_resolution: float = 0.02,
        show_timing_info: bool = False,
        resolution: int | None = None,
    ) -> None:
        if mmax_list is None:
            mmax_list = [2]
@@ -176,7 +177,7 @@ def __init__(
        for lval in range(max(self.lmax_list) + 1):
            SO3_m_grid = nn.ModuleList()
            for m in range(max(self.lmax_list) + 1):
                SO3_m_grid.append(SO3_Grid(lval, m))
                SO3_m_grid.append(SO3_Grid(lval, m, resolution=resolution))

            self.SO3_grid.append(SO3_m_grid)

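The new resolution argument is simply threaded through to every SO3_Grid in the table built here. A rough usage sketch, assuming fairchem is installed and using only the constructor shown in this commit (the lmax and resolution values are illustrative):

# Sketch of how eSCN builds its grid table with the new argument.
import torch.nn as nn

from fairchem.core.models.escn.so3 import SO3_Grid

lmax_list = [6]
resolution = None  # None keeps the defaults derived from lmax/mmax; an int pins both grid dimensions

SO3_grid = nn.ModuleList()
for lval in range(max(lmax_list) + 1):
    SO3_m_grid = nn.ModuleList()
    for m in range(max(lmax_list) + 1):
        SO3_m_grid.append(SO3_Grid(lval, m, resolution=resolution))
    SO3_grid.append(SO3_m_grid)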
9 changes: 4 additions & 5 deletions src/fairchem/core/models/escn/so3.py
@@ -452,11 +452,7 @@ class SO3_Grid(torch.nn.Module):
        mmax (int): Maximum order of the spherical harmonics
    """

    def __init__(
        self,
        lmax: int,
        mmax: int,
    ) -> None:
    def __init__(self, lmax: int, mmax: int, resolution: int | None = None) -> None:
        super().__init__()
        self.lmax = lmax
        self.mmax = mmax
@@ -465,6 +461,9 @@ def __init__(
            self.long_resolution = 2 * (self.mmax + 1) + 1
        else:
            self.long_resolution = 2 * (self.mmax) + 1
        if resolution:
            # an explicit resolution overrides both grid dimensions
            self.long_resolution = resolution
            self.lat_resolution = resolution

        self.initialized = False

46 changes: 46 additions & 0 deletions tests/core/models/test_equiformer_v2.py
@@ -10,10 +10,12 @@
import copy
import io
import os
from pathlib import Path

import pytest
import requests
import torch
import yaml
from ase.io import read
from torch.nn.parallel.distributed import DistributedDataParallel

@@ -230,3 +232,47 @@ def sign(x):
    embedding._l_primary(c)
    lp = embedding.embedding.clone()
    (test_matrix_lp == lp).all()


def _load_hydra_model():
    torch.manual_seed(4)
    with open(Path("tests/core/models/test_configs/test_equiformerv2_hydra.yml")) as yaml_file:
        yaml_config = yaml.safe_load(yaml_file)
    model = registry.get_model_class("hydra")(yaml_config["model"]["backbone"], yaml_config["model"]["heads"])
    model.backbone.num_layers = 1
    return model

def test_eqv2_hydra_activation_checkpoint():
    atoms = read(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), "atoms.json"),
        index=0,
        format="json",
    )
    a2g = AtomsToGraphs(
        max_neigh=200,
        radius=6,
        r_edges=False,
        r_fixed=True,
    )
    data_list = a2g.convert_all([atoms])
    inputs = data_list_collater(data_list)
    no_ac_model = _load_hydra_model()
    ac_model = _load_hydra_model()
    ac_model.backbone.activation_checkpoint = True

    # Both models need exactly the same state for this comparison, so save the RNG
    # state and restore it after stepping the first model.
    start_rng_state = torch.random.get_rng_state()
    outputs_no_ac = no_ac_model(inputs)
    torch.autograd.backward(outputs_no_ac["energy"].sum() + outputs_no_ac["forces"].sum())

    # reset the rng state to the beginning
    torch.random.set_rng_state(start_rng_state)
    outputs_ac = ac_model(inputs)
    torch.autograd.backward(outputs_ac["energy"].sum() + outputs_ac["forces"].sum())

    # assert the gradients are identical between the model with checkpointing and the one without
    ac_model_grad_dict = {name: p.grad for name, p in ac_model.named_parameters() if p.grad is not None}
    no_ac_model_grad_dict = {name: p.grad for name, p in no_ac_model.named_parameters() if p.grad is not None}
    for name in no_ac_model_grad_dict:
        assert torch.allclose(no_ac_model_grad_dict[name], ac_model_grad_dict[name], atol=1e-4)
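Assuming the test layout shown in this diff, the new test can be selected on its own with pytest, e.g. pytest tests/core/models/test_equiformer_v2.py -k activation_checkpoint from the repository root.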
