diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index b38d309fd7..7dff9078c5 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -54,6 +54,17 @@ def __init__( ): super().__init__(type_map, **kwargs) super().init_out_stat() + + # check all sub models are of mixed type. + model_mixed_type = [] + for m in models: + if not m.mixed_types(): + model_mixed_type.append(m) + if len(model_mixed_type) > 0: + raise ValueError( + f"LinearAtomicModel only supports AtomicModel of mixed type, the following models are not mixed type: {model_mixed_type}." + ) + self.models = models sub_model_type_maps = [md.get_type_map() for md in models] err_msg = [] diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index bf03a68f31..d21c65c257 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -61,6 +61,17 @@ def __init__( ): super().__init__(type_map, **kwargs) super().init_out_stat() + + # check all sub models are of mixed type. + model_mixed_type = [] + for m in models: + if not m.mixed_types(): + model_mixed_type.append(m) + if len(model_mixed_type) > 0: + raise ValueError( + f"LinearAtomicModel only supports AtomicModel of mixed type, the following models are not mixed type: {model_mixed_type}." + ) + self.models = torch.nn.ModuleList(models) sub_model_type_maps = [md.get_type_map() for md in models] err_msg = [] diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 958f3b4963..5de0aeffab 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -464,7 +464,7 @@ def forward( protection=self.env_protection, ) nlist_mask = nlist != -1 - nlist[nlist == -1] = 0 + nlist = torch.where(nlist == -1, 0, nlist) sw = torch.squeeze(sw, -1) # beyond the cutoff sw should be 0.0 sw = sw.masked_fill(~nlist_mask, 0.0) diff --git a/source/tests/common/dpmodel/test_linear_atomic_model.py b/source/tests/common/dpmodel/test_linear_atomic_model.py index 832d1de106..b7bf310676 100644 --- a/source/tests/common/dpmodel/test_linear_atomic_model.py +++ b/source/tests/common/dpmodel/test_linear_atomic_model.py @@ -15,8 +15,8 @@ from deepmd.dpmodel.atomic_model.pairtab_atomic_model import ( PairTabAtomicModel, ) -from deepmd.dpmodel.descriptor.se_e2_a import ( - DescrptSeA, +from deepmd.dpmodel.descriptor import ( + DescrptDPA1, ) from deepmd.dpmodel.fitting.invar_fitting import ( InvarFitting, @@ -39,10 +39,11 @@ def test_pairwise(self, mock_loadtxt): extended_atype = np.array([[0, 0]]) nlist = np.array([[[1], [-1]]]) - ds = DescrptSeA( + ds = DescrptDPA1( rcut_smth=0.3, rcut=0.4, sel=[3], + ntypes=2, ) ft = InvarFitting( "energy", @@ -134,10 +135,11 @@ def setUp(self, mock_loadtxt): [0.02, 0.25, 0.4, 0.75], ] ) - ds = DescrptSeA( + ds = DescrptDPA1( self.rcut, self.rcut_smth, - self.sel, + sum(self.sel), + self.nt, ) ft = InvarFitting( "energy", diff --git a/source/tests/pt/model/test_linear_atomic_model.py b/source/tests/pt/model/test_linear_atomic_model.py index 7f24ffdc53..7104095250 100644 --- a/source/tests/pt/model/test_linear_atomic_model.py +++ b/source/tests/pt/model/test_linear_atomic_model.py @@ -15,8 +15,8 @@ DPZBLLinearEnergyAtomicModel, PairTabAtomicModel, ) -from deepmd.pt.model.descriptor.se_a import ( - DescrptSeA, +from deepmd.pt.model.descriptor import ( + DescrptDPA1, ) from deepmd.pt.model.model import ( DPZBLModel, @@ -55,10 +55,11 @@ def test_pairwise(self, mock_loadtxt): extended_atype = torch.tensor([[0, 0]], device=env.DEVICE) nlist = torch.tensor([[[1], [-1]]], device=env.DEVICE) - ds = DescrptSeA( + ds = DescrptDPA1( rcut_smth=0.3, rcut=0.4, sel=[3], + ntypes=2, ).to(env.DEVICE) ft = InvarFitting( "energy", @@ -128,10 +129,11 @@ def setUp(self, mock_loadtxt): [0.02, 0.25, 0.4, 0.75], ] ) - ds = DescrptSeA( + ds = DescrptDPA1( self.rcut, self.rcut_smth, - self.sel, + sum(self.sel), + self.nt, ).to(env.DEVICE) ft = InvarFitting( "energy", diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py index a76bbe4246..b2bfcb8a4d 100644 --- a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -68,14 +68,22 @@ "sw_rmin": 0.2, "sw_rmax": 1.0, "descriptor": { - "type": "se_e2_a", - "sel": [46, 92, 4], - "rcut_smth": 0.50, - "rcut": 6.00, + "type": "se_atten", + "sel": 40, + "rcut_smth": 0.5, + "rcut": 4.0, "neuron": [25, 50, 100], - "resnet_dt": False, "axis_neuron": 16, - "seed": 1, + "attn": 64, + "attn_layer": 2, + "attn_dotr": True, + "attn_mask": False, + "activation_function": "tanh", + "scaling_factor": 1.0, + "normalize": False, + "temperature": 1.0, + "set_davg_zero": True, + "type_one_side": True, }, "fitting_net": { "neuron": [24, 24, 24],