Skip to content

Commit

Permalink
Adds optional per-element linear reference coefficients to EquiformerV2 (#584)
Browse files Browse the repository at this point in the history

* OC22 dataloader bugfix when used with linref: pyg objects have `x` and `y` attributes, but they are set to None

* Updates EquiformerV2 model to optionally ship with linref coefficients

* Some more details of how the linear ref energies are computed

* Don't autocast to float16 for total energy predictions

* debugging circleci
  • Loading branch information
abhshkdz authored Oct 4, 2023
1 parent 6d19ba1 commit 0b44322
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ocpmodels/datasets/oc22_lmdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ def __getitem__(self, idx):
fid = fid.item()
data_object.fid = fid

if hasattr(data_object, "y_relaxed"):
if getattr(data_object, "y_relaxed", None) is not None:
attr = "y_relaxed"
elif hasattr(data_object, "y"):
elif getattr(data_object, "y", None) is not None:
attr = "y"
# if targets are not available, test data is being used
else:
Expand Down
46 changes: 46 additions & 0 deletions ocpmodels/models/equiformer_v2/equiformer_v2_oc20.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ class EquiformerV2_OC20(BaseModel):
weight_init (str): ['normal', 'uniform'] initialization of weights of linear layers except those in radial functions
enforce_max_neighbors_strictly (bool): When edges are subselected based on the `max_neighbors` arg, arbitrarily select amongst equidistant / degenerate edges to have exactly the correct number.
avg_num_nodes (float): Average number of nodes per graph
avg_degree (float): Average degree of nodes in the graph
use_energy_lin_ref (bool): Whether to add the per-atom energy references during prediction.
During training and validation, this should be kept `False` since we use the `lin_ref` parameter in the OC22 dataloader to subtract the per-atom linear references from the energy targets.
During prediction (where we don't have energy targets), this can be set to `True` to add the per-atom linear references to the predicted energies.
load_energy_lin_ref (bool): Whether to add nn.Parameters for the per-element energy references.
This additional flag is there to ensure compatibility when strict-loading checkpoints, since the `use_energy_lin_ref` flag can be either True or False even if the model is trained with linear references.
You can't have use_energy_lin_ref = True and load_energy_lin_ref = False, since the model will not have the parameters for the linear references. All other combinations are fine.
"""

def __init__(
Expand Down Expand Up @@ -141,6 +150,8 @@ def __init__(
enforce_max_neighbors_strictly: bool = True,
avg_num_nodes: Optional[float] = None,
avg_degree: Optional[float] = None,
use_energy_lin_ref: Optional[bool] = False,
load_energy_lin_ref: Optional[bool] = False,
):
super().__init__()

Expand Down Expand Up @@ -202,6 +213,12 @@ def __init__(
self.avg_num_nodes = avg_num_nodes or _AVG_NUM_NODES
self.avg_degree = avg_degree or _AVG_DEGREE

self.use_energy_lin_ref = use_energy_lin_ref
self.load_energy_lin_ref = load_energy_lin_ref
assert not (
self.use_energy_lin_ref and not self.load_energy_lin_ref
), "You can't have use_energy_lin_ref = True and load_energy_lin_ref = False, since the model will not have the parameters for the linear references. All other combinations are fine."

self.weight_init = weight_init
assert self.weight_init in ["normal", "uniform"]

Expand Down Expand Up @@ -371,6 +388,12 @@ def __init__(
alpha_drop=0.0,
)

if self.load_energy_lin_ref:
self.energy_lin_ref = nn.Parameter(
torch.zeros(self.max_num_elements),
requires_grad=False,
)

self.apply(self._init_weights)
self.apply(self._uniform_init_rad_func_linear_weights)

Expand Down Expand Up @@ -487,6 +510,29 @@ def forward(self, data):
energy.index_add_(0, data.batch, node_energy.view(-1))
energy = energy / self.avg_num_nodes

# Add the per-atom linear references to the energy.
if self.use_energy_lin_ref and self.load_energy_lin_ref:
# During training, target E = (E_DFT - E_ref - E_mean) / E_std, and
# during inference, \hat{E_DFT} = \hat{E} * E_std + E_ref + E_mean
# where
#
# E_DFT = raw DFT energy,
# E_ref = reference energy,
# E_mean = normalizer mean,
# E_std = normalizer std,
# \hat{E} = predicted energy,
# \hat{E_DFT} = predicted DFT energy.
#
# We can also write this as
# \hat{E_DFT} = E_std * (\hat{E} + E_ref / E_std) + E_mean,
# which is why we save E_ref / E_std as the linear reference.
with torch.cuda.amp.autocast(False):
energy = energy.to(self.energy_lin_ref.dtype).index_add(
0,
data.batch,
self.energy_lin_ref[atomic_numbers],
)

###############################################################
# Force estimation
###############################################################
Expand Down

0 comments on commit 0b44322

Please sign in to comment.