Skip to content

Commit

Permalink
Add an option to run PBC in single system mode (#795)
Browse files Browse the repository at this point in the history
* do pbc per system

* add option to use single system pbc

* remove comments

* integrate use_pbc_single to all the models in repo; add test
  • Loading branch information
misko authored Aug 6, 2024
1 parent 214522d commit b2eebb6
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 36 deletions.
49 changes: 43 additions & 6 deletions src/fairchem/core/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ def generate_graph(
use_pbc=None,
otf_graph=None,
enforce_max_neighbors_strictly=None,
use_pbc_single=False,
):
cutoff = cutoff or self.cutoff
max_neighbors = max_neighbors or self.max_neighbors
use_pbc = use_pbc or self.use_pbc
use_pbc_single = use_pbc_single or self.use_pbc_single
otf_graph = otf_graph or self.otf_graph

if enforce_max_neighbors_strictly is not None:
Expand Down Expand Up @@ -84,12 +86,47 @@ def generate_graph(

if use_pbc:
if otf_graph:
edge_index, cell_offsets, neighbors = radius_graph_pbc(
data,
cutoff,
max_neighbors,
enforce_max_neighbors_strictly,
)
if use_pbc_single:
(
edge_index_per_system,
cell_offsets_per_system,
neighbors_per_system,
) = list(
zip(
*[
radius_graph_pbc(
data[idx],
cutoff,
max_neighbors,
enforce_max_neighbors_strictly,
)
for idx in range(len(data))
]
)
)

# atom indexs in the edge_index need to be offset
atom_index_offset = data.natoms.cumsum(dim=0).roll(1)
atom_index_offset[0] = 0
edge_index = torch.hstack(
[
edge_index_per_system[idx] + atom_index_offset[idx]
for idx in range(len(data))
]
)
cell_offsets = torch.vstack(cell_offsets_per_system)
neighbors = torch.hstack(neighbors_per_system)
else:
## TODO this is the original call, but blows up with memory
## using two different samples
## sid='mp-675045-mp-675045-0-7' (MPTRAJ)
## sid='75396' (OC22)
edge_index, cell_offsets, neighbors = radius_graph_pbc(
data,
cutoff,
max_neighbors,
enforce_max_neighbors_strictly,
)

out = get_pbc_distances(
data.pos,
Expand Down
36 changes: 16 additions & 20 deletions src/fairchem/core/models/dimenet_plus_plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,16 +352,13 @@ def forward(
)
}
if self.regress_forces:
outputs["forces"] = (
-1
* (
torch.autograd.grad(
outputs["energy"],
data.pos,
grad_outputs=torch.ones_like(outputs["energy"]),
create_graph=True,
)[0]
)
outputs["forces"] = -1 * (
torch.autograd.grad(
outputs["energy"],
data.pos,
grad_outputs=torch.ones_like(outputs["energy"]),
create_graph=True,
)[0]
)
return outputs

Expand All @@ -371,6 +368,7 @@ class DimeNetPlusPlusWrap(DimeNetPlusPlus, GraphModelMixin):
def __init__(
self,
use_pbc: bool = True,
use_pbc_single: bool = False,
regress_forces: bool = True,
hidden_channels: int = 128,
num_blocks: int = 4,
Expand All @@ -388,6 +386,7 @@ def __init__(
) -> None:
self.regress_forces = regress_forces
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single
self.cutoff = cutoff
self.otf_graph = otf_graph
self.max_neighbors = 50
Expand Down Expand Up @@ -466,16 +465,13 @@ def forward(self, data):
outputs = {"energy": energy}

if self.regress_forces:
forces = (
-1
* (
torch.autograd.grad(
energy,
data.pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
)
forces = -1 * (
torch.autograd.grad(
energy,
data.pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
)
outputs["forces"] = forces

Expand Down
2 changes: 2 additions & 0 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class EquiformerV2(nn.Module, GraphModelMixin):
def __init__(
self,
use_pbc: bool = True,
use_pbc_single: bool = False,
regress_forces: bool = True,
otf_graph: bool = True,
max_neighbors: int = 500,
Expand Down Expand Up @@ -169,6 +170,7 @@ def __init__(
raise ImportError

self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single
self.regress_forces = regress_forces
self.otf_graph = otf_graph
self.max_neighbors = max_neighbors
Expand Down
3 changes: 3 additions & 0 deletions src/fairchem/core/models/escn/escn.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class eSCN(nn.Module, GraphModelMixin):
Args:
use_pbc (bool): Use periodic boundary conditions
use_pbc_single (bool): Process batch PBC graphs one at a time
regress_forces (bool): Compute forces
otf_graph (bool): Compute graph On The Fly (OTF)
max_neighbors (int): Maximum number of neighbors per atom
Expand All @@ -69,6 +70,7 @@ class eSCN(nn.Module, GraphModelMixin):
def __init__(
self,
use_pbc: bool = True,
use_pbc_single: bool = False,
regress_forces: bool = True,
otf_graph: bool = False,
max_neighbors: int = 40,
Expand Down Expand Up @@ -100,6 +102,7 @@ def __init__(

self.regress_forces = regress_forces
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single
self.cutoff = cutoff
self.otf_graph = otf_graph
self.show_timing_info = show_timing_info
Expand Down
2 changes: 2 additions & 0 deletions src/fairchem/core/models/gemnet/gemnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
extensive: bool = True,
otf_graph: bool = False,
use_pbc: bool = True,
use_pbc_single: bool = False,
output_init: str = "HeOrthogonal",
activation: str = "swish",
num_elements: int = 83,
Expand All @@ -143,6 +144,7 @@ def __init__(
self.regress_forces = regress_forces
self.otf_graph = otf_graph
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single

# GemNet variants
self.direct_forces = direct_forces
Expand Down
2 changes: 2 additions & 0 deletions src/fairchem/core/models/gemnet_gp/gemnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
extensive: bool = True,
otf_graph: bool = False,
use_pbc: bool = True,
use_pbc_single: bool = False,
output_init: str = "HeOrthogonal",
activation: str = "swish",
scale_num_blocks: bool = False,
Expand Down Expand Up @@ -142,6 +143,7 @@ def __init__(
self.regress_forces = regress_forces
self.otf_graph = otf_graph
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single

# GemNet variants
self.direct_forces = direct_forces
Expand Down
4 changes: 4 additions & 0 deletions src/fairchem/core/models/gemnet_oc/gemnet_oc.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ class GemNetOC(nn.Module, GraphModelMixin):
If False predict forces based on negative gradient of energy potential.
use_pbc: bool
Whether to use periodic boundary conditions.
use_pbc_single:
Process batch PBC graphs one at a time
scale_backprop_forces: bool
Whether to scale up the energy and then scales down the forces
to prevent NaNs and infs in backpropagated forces.
Expand Down Expand Up @@ -203,6 +205,7 @@ def __init__(
regress_forces: bool = True,
direct_forces: bool = False,
use_pbc: bool = True,
use_pbc_single: bool = False,
scale_backprop_forces: bool = False,
cutoff: float = 6.0,
cutoff_qint: float | None = None,
Expand Down Expand Up @@ -269,6 +272,7 @@ def __init__(
)
self.enforce_max_neighbors_strictly = enforce_max_neighbors_strictly
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single

self.direct_forces = direct_forces
self.forces_coupled = forces_coupled
Expand Down
2 changes: 2 additions & 0 deletions src/fairchem/core/models/painn/painn.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
regress_forces: bool = True,
direct_forces: bool = True,
use_pbc: bool = True,
use_pbc_single: bool = False,
otf_graph: bool = True,
num_elements: int = 83,
scale_file: str | None = None,
Expand All @@ -92,6 +93,7 @@ def __init__(
self.direct_forces = direct_forces
self.otf_graph = otf_graph
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single

# Borrowed from GemNet.
self.symmetric_edge_symmetrization = False
Expand Down
20 changes: 10 additions & 10 deletions src/fairchem/core/models/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class SchNetWrap(SchNet, GraphModelMixin):
Args:
use_pbc (bool, optional): If set to :obj:`True`, account for periodic boundary conditions.
(default: :obj:`True`)
use_pbc_single (bool,optional): Process batch PBC graphs one at a time
regress_forces (bool, optional): If set to :obj:`True`, predict forces by differentiating
energy with respect to positions.
(default: :obj:`True`)
Expand All @@ -52,6 +53,7 @@ class SchNetWrap(SchNet, GraphModelMixin):
def __init__(
self,
use_pbc: bool = True,
use_pbc_single: bool = False,
regress_forces: bool = True,
otf_graph: bool = False,
hidden_channels: int = 128,
Expand All @@ -64,6 +66,7 @@ def __init__(
self.num_targets = 1
self.regress_forces = regress_forces
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single
self.cutoff = cutoff
self.otf_graph = otf_graph
self.max_neighbors = 50
Expand Down Expand Up @@ -111,16 +114,13 @@ def forward(self, data):
outputs = {"energy": energy}

if self.regress_forces:
forces = (
-1
* (
torch.autograd.grad(
energy,
data.pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
)
forces = -1 * (
torch.autograd.grad(
energy,
data.pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
)
outputs["forces"] = forces

Expand Down
3 changes: 3 additions & 0 deletions src/fairchem/core/models/scn/scn.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class SphericalChannelNetwork(nn.Module, GraphModelMixin):
Args:
use_pbc (bool): Use periodic boundary conditions
use_pbc_single (bool): Process batch PBC graphs one at a time
regress_forces (bool): Compute forces
otf_graph (bool): Compute graph On The Fly (OTF)
max_num_neighbors (int): Maximum number of neighbors per atom
Expand Down Expand Up @@ -76,6 +77,7 @@ class SphericalChannelNetwork(nn.Module, GraphModelMixin):
def __init__(
self,
use_pbc: bool = True,
use_pbc_single: bool = True,
regress_forces: bool = True,
otf_graph: bool = False,
max_num_neighbors: int = 20,
Expand Down Expand Up @@ -107,6 +109,7 @@ def __init__(

self.regress_forces = regress_forces
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single
self.cutoff = cutoff
self.otf_graph = otf_graph
self.show_timing_info = show_timing_info
Expand Down
19 changes: 19 additions & 0 deletions tests/core/e2e/test_s2ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,25 @@ def test_train_and_predict(
otf_norms=otf_norms,
)

def test_use_pbc_single(self, configs, tutorial_val_src, torch_deterministic):
with tempfile.TemporaryDirectory() as tempdirname:
tempdir = Path(tempdirname)
extra_args = {"seed": 0}
_ = _run_main(
rundir=str(tempdir),
update_dict_with={
"optim": {"max_epochs": 1},
"model": {"use_pbc_single": True},
"dataset": oc20_lmdb_train_and_val_from_paths(
train_src=str(tutorial_val_src),
val_src=str(tutorial_val_src),
test_src=str(tutorial_val_src),
),
},
update_run_args_with=extra_args,
input_yaml=configs["equiformer_v2"],
)

@pytest.mark.parametrize(
("world_size", "ddp"),
[
Expand Down

0 comments on commit b2eebb6

Please sign in to comment.