From b2eebb6a0f8ab06fef504fc2989daf9216816eab Mon Sep 17 00:00:00 2001 From: Misko Date: Mon, 5 Aug 2024 19:02:52 -0700 Subject: [PATCH] Add an option to run PBC in single system mode (#795) * do pbc per system * add option to use single system pbc * remove comments * integrate use_pbc_single to all the models in repo; add test --- src/fairchem/core/models/base.py | 49 ++++++++++++++++--- src/fairchem/core/models/dimenet_plus_plus.py | 36 ++++++-------- .../models/equiformer_v2/equiformer_v2.py | 2 + src/fairchem/core/models/escn/escn.py | 3 ++ src/fairchem/core/models/gemnet/gemnet.py | 2 + src/fairchem/core/models/gemnet_gp/gemnet.py | 2 + .../core/models/gemnet_oc/gemnet_oc.py | 4 ++ src/fairchem/core/models/painn/painn.py | 2 + src/fairchem/core/models/schnet.py | 20 ++++---- src/fairchem/core/models/scn/scn.py | 3 ++ tests/core/e2e/test_s2ef.py | 19 +++++++ 11 files changed, 106 insertions(+), 36 deletions(-) diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index eb8c9d543..8ce8f3fcb 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -53,10 +53,12 @@ def generate_graph( use_pbc=None, otf_graph=None, enforce_max_neighbors_strictly=None, + use_pbc_single=False, ): cutoff = cutoff or self.cutoff max_neighbors = max_neighbors or self.max_neighbors use_pbc = use_pbc or self.use_pbc + use_pbc_single = use_pbc_single or self.use_pbc_single otf_graph = otf_graph or self.otf_graph if enforce_max_neighbors_strictly is not None: @@ -84,12 +86,47 @@ def generate_graph( if use_pbc: if otf_graph: - edge_index, cell_offsets, neighbors = radius_graph_pbc( - data, - cutoff, - max_neighbors, - enforce_max_neighbors_strictly, - ) + if use_pbc_single: + ( + edge_index_per_system, + cell_offsets_per_system, + neighbors_per_system, + ) = list( + zip( + *[ + radius_graph_pbc( + data[idx], + cutoff, + max_neighbors, + enforce_max_neighbors_strictly, + ) + for idx in range(len(data)) + ] + ) + ) + + # atom indexs in 
the edge_index need to be offset + atom_index_offset = data.natoms.cumsum(dim=0).roll(1) + atom_index_offset[0] = 0 + edge_index = torch.hstack( + [ + edge_index_per_system[idx] + atom_index_offset[idx] + for idx in range(len(data)) + ] + ) + cell_offsets = torch.vstack(cell_offsets_per_system) + neighbors = torch.hstack(neighbors_per_system) + else: + ## TODO this is the original call, but blows up with memory + ## using two different samples + ## sid='mp-675045-mp-675045-0-7' (MPTRAJ) + ## sid='75396' (OC22) + edge_index, cell_offsets, neighbors = radius_graph_pbc( + data, + cutoff, + max_neighbors, + enforce_max_neighbors_strictly, + ) out = get_pbc_distances( data.pos, diff --git a/src/fairchem/core/models/dimenet_plus_plus.py b/src/fairchem/core/models/dimenet_plus_plus.py index aa08ea067..f55544826 100644 --- a/src/fairchem/core/models/dimenet_plus_plus.py +++ b/src/fairchem/core/models/dimenet_plus_plus.py @@ -352,16 +352,13 @@ def forward( ) } if self.regress_forces: - outputs["forces"] = ( - -1 - * ( - torch.autograd.grad( - outputs["energy"], - data.pos, - grad_outputs=torch.ones_like(outputs["energy"]), - create_graph=True, - )[0] - ) + outputs["forces"] = -1 * ( + torch.autograd.grad( + outputs["energy"], + data.pos, + grad_outputs=torch.ones_like(outputs["energy"]), + create_graph=True, + )[0] ) return outputs @@ -371,6 +368,7 @@ class DimeNetPlusPlusWrap(DimeNetPlusPlus, GraphModelMixin): def __init__( self, use_pbc: bool = True, + use_pbc_single: bool = False, regress_forces: bool = True, hidden_channels: int = 128, num_blocks: int = 4, @@ -388,6 +386,7 @@ def __init__( ) -> None: self.regress_forces = regress_forces self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.cutoff = cutoff self.otf_graph = otf_graph self.max_neighbors = 50 @@ -466,16 +465,13 @@ def forward(self, data): outputs = {"energy": energy} if self.regress_forces: - forces = ( - -1 - * ( - torch.autograd.grad( - energy, - data.pos, - 
grad_outputs=torch.ones_like(energy), - create_graph=True, - )[0] - ) + forces = -1 * ( + torch.autograd.grad( + energy, + data.pos, + grad_outputs=torch.ones_like(energy), + create_graph=True, + )[0] ) outputs["forces"] = forces diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py index e2625eada..06a0280e9 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -116,6 +116,7 @@ class EquiformerV2(nn.Module, GraphModelMixin): def __init__( self, use_pbc: bool = True, + use_pbc_single: bool = False, regress_forces: bool = True, otf_graph: bool = True, max_neighbors: int = 500, @@ -169,6 +170,7 @@ def __init__( raise ImportError self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.regress_forces = regress_forces self.otf_graph = otf_graph self.max_neighbors = max_neighbors diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index dfa872c39..d6367fa9a 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -47,6 +47,7 @@ class eSCN(nn.Module, GraphModelMixin): Args: use_pbc (bool): Use periodic boundary conditions + use_pbc_single (bool): Process batch PBC graphs one at a time regress_forces (bool): Compute forces otf_graph (bool): Compute graph On The Fly (OTF) max_neighbors (int): Maximum number of neighbors per atom @@ -69,6 +70,7 @@ class eSCN(nn.Module, GraphModelMixin): def __init__( self, use_pbc: bool = True, + use_pbc_single: bool = False, regress_forces: bool = True, otf_graph: bool = False, max_neighbors: int = 40, @@ -100,6 +102,7 @@ def __init__( self.regress_forces = regress_forces self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.cutoff = cutoff self.otf_graph = otf_graph self.show_timing_info = show_timing_info diff --git a/src/fairchem/core/models/gemnet/gemnet.py 
b/src/fairchem/core/models/gemnet/gemnet.py index 59b3eda08..f5537b953 100644 --- a/src/fairchem/core/models/gemnet/gemnet.py +++ b/src/fairchem/core/models/gemnet/gemnet.py @@ -118,6 +118,7 @@ def __init__( extensive: bool = True, otf_graph: bool = False, use_pbc: bool = True, + use_pbc_single: bool = False, output_init: str = "HeOrthogonal", activation: str = "swish", num_elements: int = 83, @@ -143,6 +144,7 @@ def __init__( self.regress_forces = regress_forces self.otf_graph = otf_graph self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single # GemNet variants self.direct_forces = direct_forces diff --git a/src/fairchem/core/models/gemnet_gp/gemnet.py b/src/fairchem/core/models/gemnet_gp/gemnet.py index a75756dcc..97af540de 100644 --- a/src/fairchem/core/models/gemnet_gp/gemnet.py +++ b/src/fairchem/core/models/gemnet_gp/gemnet.py @@ -114,6 +114,7 @@ def __init__( extensive: bool = True, otf_graph: bool = False, use_pbc: bool = True, + use_pbc_single: bool = False, output_init: str = "HeOrthogonal", activation: str = "swish", scale_num_blocks: bool = False, @@ -142,6 +143,7 @@ def __init__( self.regress_forces = regress_forces self.otf_graph = otf_graph self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single # GemNet variants self.direct_forces = direct_forces diff --git a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py index 0aea3d81b..c9dd9e13e 100644 --- a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py +++ b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py @@ -108,6 +108,8 @@ class GemNetOC(nn.Module, GraphModelMixin): If False predict forces based on negative gradient of energy potential. use_pbc: bool Whether to use periodic boundary conditions. + use_pbc_single: + Process batch PBC graphs one at a time scale_backprop_forces: bool Whether to scale up the energy and then scales down the forces to prevent NaNs and infs in backpropagated forces. 
@@ -203,6 +205,7 @@ def __init__( regress_forces: bool = True, direct_forces: bool = False, use_pbc: bool = True, + use_pbc_single: bool = False, scale_backprop_forces: bool = False, cutoff: float = 6.0, cutoff_qint: float | None = None, @@ -269,6 +272,7 @@ def __init__( ) self.enforce_max_neighbors_strictly = enforce_max_neighbors_strictly self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.direct_forces = direct_forces self.forces_coupled = forces_coupled diff --git a/src/fairchem/core/models/painn/painn.py b/src/fairchem/core/models/painn/painn.py index ec9e9f465..33425e8d8 100644 --- a/src/fairchem/core/models/painn/painn.py +++ b/src/fairchem/core/models/painn/painn.py @@ -73,6 +73,7 @@ def __init__( regress_forces: bool = True, direct_forces: bool = True, use_pbc: bool = True, + use_pbc_single: bool = False, otf_graph: bool = True, num_elements: int = 83, scale_file: str | None = None, @@ -92,6 +93,7 @@ def __init__( self.direct_forces = direct_forces self.otf_graph = otf_graph self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single # Borrowed from GemNet. self.symmetric_edge_symmetrization = False diff --git a/src/fairchem/core/models/schnet.py b/src/fairchem/core/models/schnet.py index 5ca70a354..878aee746 100644 --- a/src/fairchem/core/models/schnet.py +++ b/src/fairchem/core/models/schnet.py @@ -30,6 +30,7 @@ class SchNetWrap(SchNet, GraphModelMixin): Args: use_pbc (bool, optional): If set to :obj:`True`, account for periodic boundary conditions. (default: :obj:`True`) + use_pbc_single (bool,optional): Process batch PBC graphs one at a time regress_forces (bool, optional): If set to :obj:`True`, predict forces by differentiating energy with respect to positions. 
(default: :obj:`True`) @@ -52,6 +53,7 @@ class SchNetWrap(SchNet, GraphModelMixin): def __init__( self, use_pbc: bool = True, + use_pbc_single: bool = False, regress_forces: bool = True, otf_graph: bool = False, hidden_channels: int = 128, @@ -64,6 +66,7 @@ def __init__( self.num_targets = 1 self.regress_forces = regress_forces self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.cutoff = cutoff self.otf_graph = otf_graph self.max_neighbors = 50 @@ -111,16 +114,13 @@ def forward(self, data): outputs = {"energy": energy} if self.regress_forces: - forces = ( - -1 - * ( - torch.autograd.grad( - energy, - data.pos, - grad_outputs=torch.ones_like(energy), - create_graph=True, - )[0] - ) + forces = -1 * ( + torch.autograd.grad( + energy, + data.pos, + grad_outputs=torch.ones_like(energy), + create_graph=True, + )[0] ) outputs["forces"] = forces diff --git a/src/fairchem/core/models/scn/scn.py b/src/fairchem/core/models/scn/scn.py index 84806e19e..299fa4858 100644 --- a/src/fairchem/core/models/scn/scn.py +++ b/src/fairchem/core/models/scn/scn.py @@ -39,6 +39,7 @@ class SphericalChannelNetwork(nn.Module, GraphModelMixin): Args: use_pbc (bool): Use periodic boundary conditions + use_pbc_single (bool): Process batch PBC graphs one at a time regress_forces (bool): Compute forces otf_graph (bool): Compute graph On The Fly (OTF) max_num_neighbors (int): Maximum number of neighbors per atom @@ -76,6 +77,7 @@ class SphericalChannelNetwork(nn.Module, GraphModelMixin): def __init__( self, use_pbc: bool = True, + use_pbc_single: bool = False, regress_forces: bool = True, otf_graph: bool = False, max_num_neighbors: int = 20, @@ -107,6 +109,7 @@ def __init__( self.regress_forces = regress_forces self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.cutoff = cutoff self.otf_graph = otf_graph self.show_timing_info = show_timing_info diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 54055d0c3..9a68c4771 100644 ---
a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -319,6 +319,25 @@ def test_train_and_predict( otf_norms=otf_norms, ) + def test_use_pbc_single(self, configs, tutorial_val_src, torch_deterministic): + with tempfile.TemporaryDirectory() as tempdirname: + tempdir = Path(tempdirname) + extra_args = {"seed": 0} + _ = _run_main( + rundir=str(tempdir), + update_dict_with={ + "optim": {"max_epochs": 1}, + "model": {"use_pbc_single": True}, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + update_run_args_with=extra_args, + input_yaml=configs["equiformer_v2"], + ) + @pytest.mark.parametrize( ("world_size", "ddp"), [