Skip to content

Commit

Permalink
Merge branch 'main' into rank2-head
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque authored Aug 8, 2024
2 parents 91d3830 + 21d936f commit c3b6d29
Show file tree
Hide file tree
Showing 27 changed files with 597 additions and 280 deletions.
59 changes: 52 additions & 7 deletions configs/ocp_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dataset:
# Can use 'single_point_lmdb' or 'trajectory_lmdb' for backward compatibility.
# 'single_point_lmdb' was for training IS2RE models, and 'trajectory_lmdb' was
# for training S2EF models.
format: lmdb # 'lmdb' or 'oc22_lmdb'
format: lmdb # 'lmdb', 'oc22_lmdb', or 'ase_db'
# Directory containing training set LMDBs
src: data/s2ef/all/train/
# If we want to rename a target value stored in the data object, specify the mapping here.
Expand All @@ -34,9 +34,11 @@ dataset:
irrep_dim: 0
anisotropic_stress:
irrep_dim: 2
# If we want to normalize targets, i.e. subtract the mean and
# divide by standard deviation, then specify the 'mean' and 'stdev' here.
# If we want to normalize targets, there are a couple of ways to specify normalization values.
# normalization values are applied as: (target - mean) / rmsd
# Note root mean squared difference (rmsd) is equal to stdev if mean != 0, and equal to rms if mean == 0.
# Statistics will by default be applied to the validation and test set.
# 1) specify the 'mean' and 'stdev' explicitly here.
normalizer:
energy:
mean: -0.7554450631141663
Expand All @@ -49,17 +51,60 @@ dataset:
stdev: 674.1657344451734
anisotropic_stress:
stdev: 143.72764771869745
# 2) Estimate the values on-the-fly (OTF) from training data
normalizer:
fit:
targets:
forces: { mean: 0.0 } # values can be explicitly set, i.e., if you need RMS forces instead of stdev forces
stress_isotropic: { } # to estimate both mean and rmsd set to {} or None
stress_anisotropic: { }
batch_size: 64
num_batches: 5000 # if num_batches is not given, the whole dataset will be used
# 3) Specify a single .pt file with dict of target names and Normalizer modules
# (this is the format that OTF values are saved in)
# see Normalizer module in fairchem.core.modules.normalization.normalizer
normalizer:
file: normalizers.pt
# 4) specify an individual file either .pt or .npz with keys 'mean' and 'rmsd' or 'stdev'
normalizer:
energy:
file: energy_norm.pt
forces:
file: forces_norm.npz
isotropic_stress:
file: isostress_norm.npz
anisotropic_stress:
file: anisostress_norm.npz
# If we want to train on total energies and use a per-element linear reference
# normalization scheme, we can either estimate those from the data or specify the path to a file with the per-element references.
# 1) Fit element references from data
element_references:
fit:
targets:
- energy
batch_size: 64
num_batches: 5000 # if num_batches is not given, the whole dataset will be used
# 2) Specify a file with the key energy and a LinearReference object. This is the format OTF references are saved in.
# see fairchem.core.modules.normalization.element_references for references.
element_references:
file: element_references.pt
# 3) Legacy files in npz format can be specified as well. They must have the element references
# under the key coeff
element_references:
energy:
file: element_ref.npz
# 4) backwards compatibility only, linear references can be set as follows. Setting the references
# file as follows is a legacy setting and only works with oc22_lmdb and ase_lmdb datasets
lin_ref: element_ref.npz

# If we want to train OC20 on total energy, a path to OC20 reference
# energies `oc20_ref` must be specified to unreference existing OC20 data.
# download at https://dl.fbaipublicfiles.com/opencatalystproject/data/oc22/oc20_ref.pkl
# Also, train_on_oc20_total_energies must be set to True
# OC22 defaults to total energy, so these flags are not necessary.
train_on_oc20_total_energies: False # True or False
oc20_ref: None # path to oc20_ref
# If we want to train on total energies and use a linear reference
# normalization scheme, we must specify the path to the per-element
# coefficients in a `.npz` format.
lin_ref: False # True or False

val:
# Directory containing val set LMDBs
src: data/s2ef/all/val_id/
Expand Down
12 changes: 6 additions & 6 deletions src/fairchem/core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@
DEFAULT_ENV_VARS = {
# Expandable segments is a new cuda feature that helps with memory fragmentation during frequent allocations (i.e., in the case of variable batch sizes).
# see https://pytorch.org/docs/stable/notes/cuda.html.
"PYTORCH_CUDA_ALLOC_CONF" : "expandable_segments:True",
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
}


# copied from https://stackoverflow.com/questions/33490870/parsing-yaml-in-python-detect-duplicated-keys
# prevents loading YAMLS where keys have been overwritten
class UniqueKeyLoader(yaml.SafeLoader):
Expand Down Expand Up @@ -265,8 +266,8 @@ def _import_local_file(path: Path, *, project_root: Path) -> None:
:type project_root: Path
"""

path = path.resolve()
project_root = project_root.parent.resolve()
path = path.absolute()
project_root = project_root.parent.absolute()

module_name = ".".join(
path.absolute().relative_to(project_root.absolute()).with_suffix("").parts
Expand All @@ -285,7 +286,7 @@ def setup_experimental_imports(project_root: Path) -> None:
:param project_root: The root directory of the project (i.e., the "ocp" folder)
"""
experimental_dir = (project_root / "experimental").resolve()
experimental_dir = (project_root / "experimental").absolute()
if not experimental_dir.exists() or not experimental_dir.is_dir():
return

Expand All @@ -298,8 +299,7 @@ def setup_experimental_imports(project_root: Path) -> None:

for inc_dir in include_dirs:
experimental_files.extend(
f.resolve().absolute()
for f in (experimental_dir / inc_dir).rglob("*.py")
f.absolute() for f in (experimental_dir / inc_dir).rglob("*.py")
)

for f in experimental_files:
Expand Down
49 changes: 43 additions & 6 deletions src/fairchem/core/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ def generate_graph(
use_pbc=None,
otf_graph=None,
enforce_max_neighbors_strictly=None,
use_pbc_single=False,
):
cutoff = cutoff or self.cutoff
max_neighbors = max_neighbors or self.max_neighbors
use_pbc = use_pbc or self.use_pbc
use_pbc_single = use_pbc_single or self.use_pbc_single
otf_graph = otf_graph or self.otf_graph

if enforce_max_neighbors_strictly is not None:
Expand Down Expand Up @@ -84,12 +86,47 @@ def generate_graph(

if use_pbc:
if otf_graph:
edge_index, cell_offsets, neighbors = radius_graph_pbc(
data,
cutoff,
max_neighbors,
enforce_max_neighbors_strictly,
)
if use_pbc_single:
(
edge_index_per_system,
cell_offsets_per_system,
neighbors_per_system,
) = list(
zip(
*[
radius_graph_pbc(
data[idx],
cutoff,
max_neighbors,
enforce_max_neighbors_strictly,
)
for idx in range(len(data))
]
)
)

# atom indices in the edge_index need to be offset
atom_index_offset = data.natoms.cumsum(dim=0).roll(1)
atom_index_offset[0] = 0
edge_index = torch.hstack(
[
edge_index_per_system[idx] + atom_index_offset[idx]
for idx in range(len(data))
]
)
cell_offsets = torch.vstack(cell_offsets_per_system)
neighbors = torch.hstack(neighbors_per_system)
else:
## TODO this is the original call, but blows up with memory
## using two different samples
## sid='mp-675045-mp-675045-0-7' (MPTRAJ)
## sid='75396' (OC22)
edge_index, cell_offsets, neighbors = radius_graph_pbc(
data,
cutoff,
max_neighbors,
enforce_max_neighbors_strictly,
)

out = get_pbc_distances(
data.pos,
Expand Down
36 changes: 16 additions & 20 deletions src/fairchem/core/models/dimenet_plus_plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,16 +352,13 @@ def forward(
)
}
if self.regress_forces:
outputs["forces"] = (
-1
* (
torch.autograd.grad(
outputs["energy"],
data.pos,
grad_outputs=torch.ones_like(outputs["energy"]),
create_graph=True,
)[0]
)
outputs["forces"] = -1 * (
torch.autograd.grad(
outputs["energy"],
data.pos,
grad_outputs=torch.ones_like(outputs["energy"]),
create_graph=True,
)[0]
)
return outputs

Expand All @@ -371,6 +368,7 @@ class DimeNetPlusPlusWrap(DimeNetPlusPlus, GraphModelMixin):
def __init__(
self,
use_pbc: bool = True,
use_pbc_single: bool = False,
regress_forces: bool = True,
hidden_channels: int = 128,
num_blocks: int = 4,
Expand All @@ -388,6 +386,7 @@ def __init__(
) -> None:
self.regress_forces = regress_forces
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single
self.cutoff = cutoff
self.otf_graph = otf_graph
self.max_neighbors = 50
Expand Down Expand Up @@ -466,16 +465,13 @@ def forward(self, data):
outputs = {"energy": energy}

if self.regress_forces:
forces = (
-1
* (
torch.autograd.grad(
energy,
data.pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
)
forces = -1 * (
torch.autograd.grad(
energy,
data.pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
)
outputs["forces"] = forces

Expand Down
14 changes: 14 additions & 0 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class EquiformerV2(nn.Module, GraphModelMixin):
Args:
use_pbc (bool): Use periodic boundary conditions
use_pbc_single (bool): Process batch PBC graphs one at a time
regress_forces (bool): Compute forces
otf_graph (bool): Compute graph On The Fly (OTF)
max_neighbors (int): Maximum number of neighbors per atom
Expand Down Expand Up @@ -116,6 +117,7 @@ class EquiformerV2(nn.Module, GraphModelMixin):
def __init__(
self,
use_pbc: bool = True,
use_pbc_single: bool = False,
regress_forces: bool = True,
otf_graph: bool = True,
max_neighbors: int = 500,
Expand Down Expand Up @@ -169,6 +171,7 @@ def __init__(
raise ImportError

self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single
self.regress_forces = regress_forces
self.otf_graph = otf_graph
self.max_neighbors = max_neighbors
Expand Down Expand Up @@ -681,6 +684,12 @@ def no_weight_decay(self) -> set:

@registry.register_model("equiformer_v2_backbone")
class EquiformerV2Backbone(EquiformerV2, BackboneInterface):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO remove these once we deprecate/stop-inheriting EquiformerV2 class
self.energy_block = None
self.force_block = None

@conditional_grad(torch.enable_grad())
def forward(self, data: Batch) -> dict[str, torch.Tensor]:
self.batch_size = len(data.natoms)
Expand Down Expand Up @@ -813,6 +822,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
class EquiformerV2EnergyHead(nn.Module, HeadInterface):
def __init__(self, backbone):
super().__init__()

self.avg_num_nodes = backbone.avg_num_nodes
self.energy_block = FeedForwardNetwork(
backbone.sphere_channels,
Expand All @@ -826,6 +836,8 @@ def __init__(self, backbone):
backbone.use_grid_mlp,
backbone.use_sep_s2_act,
)
self.apply(backbone._init_weights)
self.apply(backbone._uniform_init_rad_func_linear_weights)

def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]):
node_energy = self.energy_block(emb["node_embedding"])
Expand Down Expand Up @@ -869,6 +881,8 @@ def __init__(self, backbone):
backbone.use_sep_s2_act,
alpha_drop=0.0,
)
self.apply(backbone._init_weights)
self.apply(backbone._uniform_init_rad_func_linear_weights)

def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
forces = self.force_block(
Expand Down
7 changes: 5 additions & 2 deletions src/fairchem/core/models/escn/escn.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class eSCN(nn.Module, GraphModelMixin):
Args:
use_pbc (bool): Use periodic boundary conditions
use_pbc_single (bool): Process batch PBC graphs one at a time
regress_forces (bool): Compute forces
otf_graph (bool): Compute graph On The Fly (OTF)
max_neighbors (int): Maximum number of neighbors per atom
Expand All @@ -69,6 +70,7 @@ class eSCN(nn.Module, GraphModelMixin):
def __init__(
self,
use_pbc: bool = True,
use_pbc_single: bool = False,
regress_forces: bool = True,
otf_graph: bool = False,
max_neighbors: int = 40,
Expand Down Expand Up @@ -100,6 +102,7 @@ def __init__(

self.regress_forces = regress_forces
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single
self.cutoff = cutoff
self.otf_graph = otf_graph
self.show_timing_info = show_timing_info
Expand Down Expand Up @@ -527,7 +530,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
class eSCNEnergyHead(nn.Module, HeadInterface):
def __init__(self, backbone):
super().__init__()

backbone.energy_block = None
# Output blocks for energy and forces
self.energy_block = EnergyBlock(
backbone.sphere_channels_all, backbone.num_sphere_samples, backbone.act
Expand All @@ -547,7 +550,7 @@ def forward(
class eSCNForceHead(nn.Module, HeadInterface):
def __init__(self, backbone):
super().__init__()

backbone.force_block = None
self.force_block = ForceBlock(
backbone.sphere_channels_all, backbone.num_sphere_samples, backbone.act
)
Expand Down
2 changes: 2 additions & 0 deletions src/fairchem/core/models/gemnet/gemnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
extensive: bool = True,
otf_graph: bool = False,
use_pbc: bool = True,
use_pbc_single: bool = False,
output_init: str = "HeOrthogonal",
activation: str = "swish",
num_elements: int = 83,
Expand All @@ -143,6 +144,7 @@ def __init__(
self.regress_forces = regress_forces
self.otf_graph = otf_graph
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single

# GemNet variants
self.direct_forces = direct_forces
Expand Down
2 changes: 2 additions & 0 deletions src/fairchem/core/models/gemnet_gp/gemnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
extensive: bool = True,
otf_graph: bool = False,
use_pbc: bool = True,
use_pbc_single: bool = False,
output_init: str = "HeOrthogonal",
activation: str = "swish",
scale_num_blocks: bool = False,
Expand Down Expand Up @@ -142,6 +143,7 @@ def __init__(
self.regress_forces = regress_forces
self.otf_graph = otf_graph
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single

# GemNet variants
self.direct_forces = direct_forces
Expand Down
Loading

0 comments on commit c3b6d29

Please sign in to comment.