From cbeb1d56255c6ca7ec39d98edaeff1c97f832c65 Mon Sep 17 00:00:00 2001 From: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Date: Sun, 3 Mar 2024 00:35:32 +0800 Subject: [PATCH] add mask to atomic model output when an atomic type exclusion presents. (#3389) This PR also - introduce atomic_output_def used wraps the fitting_output_def. - atomic_output_def will be used by make_model - add missing ut for pair and atom exclusions in dpmodel See also #3357 --------- Co-authored-by: Han Wang --- .../dpmodel/atomic_model/base_atomic_model.py | 24 +++++++ .../atomic_model/make_base_atomic_model.py | 11 +++- deepmd/dpmodel/descriptor/se_e2_a.py | 12 +++- deepmd/dpmodel/fitting/general_fitting.py | 12 +++- deepmd/dpmodel/model/make_model.py | 4 +- .../model/atomic_model/base_atomic_model.py | 24 +++++++ deepmd/pt/model/model/make_hessian_model.py | 10 +-- deepmd/pt/model/model/make_model.py | 4 +- .../common/dpmodel/test_dp_atomic_model.py | 65 +++++++++++++++++++ source/tests/pt/model/test_dp_atomic_model.py | 27 ++++++++ .../tests/pt/model/test_make_hessian_model.py | 4 +- 11 files changed, 180 insertions(+), 17 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index 09d33203a1..b8c4902d68 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -8,6 +8,10 @@ import numpy as np +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + OutputVariableDef, +) from deepmd.dpmodel.utils import ( AtomExcludeMask, PairExcludeMask, @@ -50,6 +54,25 @@ def reinit_pair_exclude( else: self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types) + def atomic_output_def(self) -> FittingOutputDef: + old_def = self.fitting_output_def() + if self.atom_excl is None: + return old_def + else: + old_list = list(old_def.get_data().values()) + return FittingOutputDef( + old_list # noqa:RUF005 + + [ + OutputVariableDef( + name="mask", + shape=[1], + reduciable=False, + r_differentiable=False, + c_differentiable=False, + ) + ] + ) + def forward_common_atomic( self, extended_coord: np.ndarray, @@ -79,6 +102,7 @@ def forward_common_atomic( atom_mask = self.atom_excl.build_type_exclude_mask(atype) for kk in ret_dict.keys(): ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None] + ret_dict["mask"] = atom_mask return ret_dict diff --git a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py index df6e39dd2e..5548147d54 100644 --- a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py @@ -36,9 +36,18 @@ class BAM(ABC): @abstractmethod def fitting_output_def(self) -> FittingOutputDef: - """Get the fitting output def.""" + """Get the output def of developer implemented atomic models.""" pass + def atomic_output_def(self) -> FittingOutputDef: + """Get the output def of the atomic model. + + By default it is the same as FittingOutputDef, but it + allows model level wrapper of the output defined by the developer. + + """ + return self.fitting_output_def() + @abstractmethod def get_rcut(self) -> float: """Get the cut-off radius.""" diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 891f308edc..a068a2e366 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -26,6 +26,7 @@ Any, List, Optional, + Tuple, ) from deepmd.dpmodel import ( @@ -168,12 +169,12 @@ def __init__( self.resnet_dt = resnet_dt self.trainable = trainable self.type_one_side = type_one_side - self.exclude_types = exclude_types self.set_davg_zero = set_davg_zero self.activation_function = activation_function self.precision = precision self.spin = spin - self.emask = PairExcludeMask(self.ntypes, self.exclude_types) + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) in_dim = 1 # not considiering type embedding self.embeddings = NetworkCollection( @@ -271,6 +272,13 @@ def cal_g( gg = self.embeddings[embedding_idx].call(ss) return gg + def reinit_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def call( self, coord_ext, diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 5b4ca195b5..c004814b60 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -120,13 +120,12 @@ def __init__( self.use_aparam_as_mask = use_aparam_as_mask self.spin = spin self.mixed_types = mixed_types - self.exclude_types = exclude_types + # order matters, should be place after the assignment of ntypes + self.reinit_exclude(exclude_types) if self.spin is not None: raise NotImplementedError("spin is not supported") self.remove_vaccum_contribution = remove_vaccum_contribution - self.emask = AtomExcludeMask(self.ntypes, self.exclude_types) - net_dim_out = self._net_out_dim() # init constants self.bias_atom_e = np.zeros([self.ntypes, net_dim_out]) @@ -214,6 +213,13 @@ def __getitem__(self, key): else: raise KeyError(key) + def reinit_exclude( + self, + exclude_types: List[int] = [], + ): + self.exclude_types = exclude_types + self.emask = AtomExcludeMask(self.ntypes, self.exclude_types) + def serialize(self) -> dict: """Serialize the fitting to dict.""" return { diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index e8b1ecc390..f30f6a4021 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -74,7 +74,7 @@ def __init__( def model_output_def(self): """Get the output def for the model.""" - return ModelOutputDef(self.fitting_output_def()) + return ModelOutputDef(self.atomic_output_def()) def model_output_type(self) -> str: """Get the output type for the model.""" @@ -223,7 +223,7 @@ def call_lower( ) model_predict = fit_output_to_model_output( atomic_ret, - self.fitting_output_def(), + self.atomic_output_def(), cc_ext, do_atomic_virial=do_atomic_virial, ) diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index f8b737b58e..8827e3f18b 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -13,6 +13,10 @@ from deepmd.dpmodel.atomic_model import ( make_base_atomic_model, ) +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + OutputVariableDef, +) from deepmd.pt.utils import ( AtomExcludeMask, PairExcludeMask, @@ -60,6 +64,25 @@ def reinit_pair_exclude( def get_model_def_script(self) -> str: return self.model_def_script + def atomic_output_def(self) -> FittingOutputDef: + old_def = self.fitting_output_def() + if self.atom_excl is None: + return old_def + else: + old_list = list(old_def.get_data().values()) + return FittingOutputDef( + old_list # noqa:RUF005 + + [ + OutputVariableDef( + name="mask", + shape=[1], + reduciable=False, + r_differentiable=False, + c_differentiable=False, + ) + ] + ) + def forward_common_atomic( self, extended_coord: torch.Tensor, @@ -90,6 +113,7 @@ def forward_common_atomic( atom_mask = self.atom_excl(atype) for kk in ret_dict.keys(): ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None] + ret_dict["mask"] = atom_mask return ret_dict diff --git a/deepmd/pt/model/model/make_hessian_model.py b/deepmd/pt/model/model/make_hessian_model.py index 0ed14b1931..9588348f53 100644 --- a/deepmd/pt/model/model/make_hessian_model.py +++ b/deepmd/pt/model/model/make_hessian_model.py @@ -25,7 +25,7 @@ def make_hessian_model(T_Model): Parameters ---------- T_Model - The model. Should provide the `forward_common` and `fitting_output_def` methods + The model. Should provide the `forward_common` and `atomic_output_def` methods Returns ------- @@ -43,7 +43,7 @@ def __init__( *args, **kwargs, ) - self.hess_fitting_def = copy.deepcopy(super().fitting_output_def()) + self.hess_fitting_def = copy.deepcopy(super().atomic_output_def()) def requires_hessian( self, @@ -56,7 +56,7 @@ def requires_hessian( if kk in keys: self.hess_fitting_def[kk].r_hessian = True - def fitting_output_def(self): + def atomic_output_def(self): """Get the fitting output def.""" return self.hess_fitting_def @@ -102,7 +102,7 @@ def forward_common( aparam=aparam, do_atomic_virial=do_atomic_virial, ) - vdef = self.fitting_output_def() + vdef = self.atomic_output_def() hess_yes = [vdef[kk].r_hessian for kk in vdef.keys()] if any(hess_yes): hess = self._cal_hessian_all( @@ -128,7 +128,7 @@ def _cal_hessian_all( box = box.view([nf, 9]) if box is not None else None fparam = fparam.view([nf, -1]) if fparam is not None else None aparam = aparam.view([nf, nloc, -1]) if aparam is not None else None - fdef = self.fitting_output_def() + fdef = self.atomic_output_def() # keys of values that require hessian hess_keys: List[str] = [] for kk in fdef.keys(): diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 4f35acb60e..60b71400fb 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -72,7 +72,7 @@ def __init__( def model_output_def(self): """Get the output def for the model.""" - return ModelOutputDef(self.fitting_output_def()) + return ModelOutputDef(self.atomic_output_def()) @torch.jit.export def model_output_type(self) -> str: @@ -218,7 +218,7 @@ def forward_common_lower( ) model_predict = fit_output_to_model_output( atomic_ret, - self.fitting_output_def(), + self.atomic_output_def(), cc_ext, do_atomic_virial=do_atomic_virial, ) diff --git a/source/tests/common/dpmodel/test_dp_atomic_model.py b/source/tests/common/dpmodel/test_dp_atomic_model.py index f97299cf72..ac49280b82 100644 --- a/source/tests/common/dpmodel/test_dp_atomic_model.py +++ b/source/tests/common/dpmodel/test_dp_atomic_model.py @@ -50,3 +50,68 @@ def test_self_consistency( ret1 = md1.forward_common_atomic(self.coord_ext, self.atype_ext, self.nlist) np.testing.assert_allclose(ret0["energy"], ret1["energy"]) + + def test_excl_consistency(self): + type_map = ["foo", "bar"] + + # test the case of exclusion + for atom_excl, pair_excl in itertools.product([[], [1]], [[], [[0, 1]]]): + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + ) + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ) + md1 = DPAtomicModel.deserialize(md0.serialize()) + + md0.reinit_atom_exclude(atom_excl) + md0.reinit_pair_exclude(pair_excl) + # hacking! + md1.descriptor.reinit_exclude(pair_excl) + md1.fitting.reinit_exclude(atom_excl) + + # check energy consistency + args = [self.coord_ext, self.atype_ext, self.nlist] + ret0 = md0.forward_common_atomic(*args) + ret1 = md1.forward_common_atomic(*args) + np.testing.assert_allclose( + ret0["energy"], + ret1["energy"], + ) + + # check output def + out_names = [vv.name for vv in md0.atomic_output_def().get_data().values()] + if atom_excl == []: + self.assertEqual(out_names, ["energy"]) + else: + self.assertEqual(out_names, ["energy", "mask"]) + for ii in md0.atomic_output_def().get_data().values(): + if ii.name == "mask": + self.assertEqual(ii.shape, [1]) + self.assertFalse(ii.reduciable) + self.assertFalse(ii.r_differentiable) + self.assertFalse(ii.c_differentiable) + + # check mask + if atom_excl == []: + pass + elif atom_excl == [1]: + self.assertIn("mask", ret0.keys()) + expected = np.array([1, 1, 0], dtype=int) + expected = np.concatenate( + [expected, expected[self.perm[: self.nloc]]] + ).reshape(2, 3) + np.testing.assert_array_equal(ret0["mask"], expected) + else: + raise ValueError(f"not expected atom_excl {atom_excl}") diff --git a/source/tests/pt/model/test_dp_atomic_model.py b/source/tests/pt/model/test_dp_atomic_model.py index 88bb3ab763..6daaeef2ef 100644 --- a/source/tests/pt/model/test_dp_atomic_model.py +++ b/source/tests/pt/model/test_dp_atomic_model.py @@ -152,6 +152,7 @@ def test_excl_consistency(self): md1.descriptor.reinit_exclude(pair_excl) md1.fitting_net.reinit_exclude(atom_excl) + # check energy consistency args = [ to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] @@ -162,3 +163,29 @@ def test_excl_consistency(self): to_numpy_array(ret0["energy"]), to_numpy_array(ret1["energy"]), ) + + # check output def + out_names = [vv.name for vv in md0.atomic_output_def().get_data().values()] + if atom_excl == []: + self.assertEqual(out_names, ["energy"]) + else: + self.assertEqual(out_names, ["energy", "mask"]) + for ii in md0.atomic_output_def().get_data().values(): + if ii.name == "mask": + self.assertEqual(ii.shape, [1]) + self.assertFalse(ii.reduciable) + self.assertFalse(ii.r_differentiable) + self.assertFalse(ii.c_differentiable) + + # check mask + if atom_excl == []: + pass + elif atom_excl == [1]: + self.assertIn("mask", ret0.keys()) + expected = np.array([1, 1, 0], dtype=int) + expected = np.concatenate( + [expected, expected[self.perm[: self.nloc]]] + ).reshape(2, 3) + np.testing.assert_array_equal(to_numpy_array(ret0["mask"]), expected) + else: + raise ValueError(f"not expected atom_excl {atom_excl}") diff --git a/source/tests/pt/model/test_make_hessian_model.py b/source/tests/pt/model/test_make_hessian_model.py index 1fb7e6f53a..7d9ae2b810 100644 --- a/source/tests/pt/model/test_make_hessian_model.py +++ b/source/tests/pt/model/test_make_hessian_model.py @@ -166,8 +166,8 @@ def setUp(self): self.model_hess.requires_hessian("energy") def test_output_def(self): - self.assertTrue(self.model_hess.fitting_output_def()["energy"].r_hessian) - self.assertFalse(self.model_valu.fitting_output_def()["energy"].r_hessian) + self.assertTrue(self.model_hess.atomic_output_def()["energy"].r_hessian) + self.assertFalse(self.model_valu.atomic_output_def()["energy"].r_hessian) self.assertTrue(self.model_hess.model_output_def()["energy"].r_hessian) self.assertEqual( self.model_hess.model_output_def()["energy_derv_r_derv_r"].category,