Skip to content

Commit

Permalink
Merge branch 'devel' into rf_dpa2_consist
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd authored May 9, 2024
2 parents 375c03e + 3b59e2b commit 2f280e6
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 17 deletions.
11 changes: 11 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ def __init__(
):
super().__init__(type_map, **kwargs)
super().init_out_stat()

# check all sub models are of mixed type.
model_mixed_type = []
for m in models:
if not m.mixed_types():
model_mixed_type.append(m)
if len(model_mixed_type) > 0:
raise ValueError(
f"LinearAtomicModel only supports AtomicModel of mixed type, the following models are not mixed type: {model_mixed_type}."
)

self.models = models
sub_model_type_maps = [md.get_type_map() for md in models]
err_msg = []
Expand Down
11 changes: 11 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ def __init__(
):
super().__init__(type_map, **kwargs)
super().init_out_stat()

# check all sub models are of mixed type.
model_mixed_type = []
for m in models:
if not m.mixed_types():
model_mixed_type.append(m)
if len(model_mixed_type) > 0:
raise ValueError(
f"LinearAtomicModel only supports AtomicModel of mixed type, the following models are not mixed type: {model_mixed_type}."
)

self.models = torch.nn.ModuleList(models)
sub_model_type_maps = [md.get_type_map() for md in models]
err_msg = []
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def forward(
protection=self.env_protection,
)
nlist_mask = nlist != -1
nlist[nlist == -1] = 0
nlist = torch.where(nlist == -1, 0, nlist)
sw = torch.squeeze(sw, -1)
# beyond the cutoff sw should be 0.0
sw = sw.masked_fill(~nlist_mask, 0.0)
Expand Down
12 changes: 7 additions & 5 deletions source/tests/common/dpmodel/test_linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from deepmd.dpmodel.atomic_model.pairtab_atomic_model import (
PairTabAtomicModel,
)
from deepmd.dpmodel.descriptor.se_e2_a import (
DescrptSeA,
from deepmd.dpmodel.descriptor import (
DescrptDPA1,
)
from deepmd.dpmodel.fitting.invar_fitting import (
InvarFitting,
Expand All @@ -39,10 +39,11 @@ def test_pairwise(self, mock_loadtxt):
extended_atype = np.array([[0, 0]])
nlist = np.array([[[1], [-1]]])

ds = DescrptSeA(
ds = DescrptDPA1(
rcut_smth=0.3,
rcut=0.4,
sel=[3],
ntypes=2,
)
ft = InvarFitting(
"energy",
Expand Down Expand Up @@ -134,10 +135,11 @@ def setUp(self, mock_loadtxt):
[0.02, 0.25, 0.4, 0.75],
]
)
ds = DescrptSeA(
ds = DescrptDPA1(
self.rcut,
self.rcut_smth,
self.sel,
sum(self.sel),
self.nt,
)
ft = InvarFitting(
"energy",
Expand Down
12 changes: 7 additions & 5 deletions source/tests/pt/model/test_linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
DPZBLLinearEnergyAtomicModel,
PairTabAtomicModel,
)
from deepmd.pt.model.descriptor.se_a import (
DescrptSeA,
from deepmd.pt.model.descriptor import (
DescrptDPA1,
)
from deepmd.pt.model.model import (
DPZBLModel,
Expand Down Expand Up @@ -55,10 +55,11 @@ def test_pairwise(self, mock_loadtxt):
extended_atype = torch.tensor([[0, 0]], device=env.DEVICE)
nlist = torch.tensor([[[1], [-1]]], device=env.DEVICE)

ds = DescrptSeA(
ds = DescrptDPA1(
rcut_smth=0.3,
rcut=0.4,
sel=[3],
ntypes=2,
).to(env.DEVICE)
ft = InvarFitting(
"energy",
Expand Down Expand Up @@ -128,10 +129,11 @@ def setUp(self, mock_loadtxt):
[0.02, 0.25, 0.4, 0.75],
]
)
ds = DescrptSeA(
ds = DescrptDPA1(
self.rcut,
self.rcut_smth,
self.sel,
sum(self.sel),
self.nt,
).to(env.DEVICE)
ft = InvarFitting(
"energy",
Expand Down
20 changes: 14 additions & 6 deletions source/tests/pt/model/test_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,22 @@
"sw_rmin": 0.2,
"sw_rmax": 1.0,
"descriptor": {
"type": "se_e2_a",
"sel": [46, 92, 4],
"rcut_smth": 0.50,
"rcut": 6.00,
"type": "se_atten",
"sel": 40,
"rcut_smth": 0.5,
"rcut": 4.0,
"neuron": [25, 50, 100],
"resnet_dt": False,
"axis_neuron": 16,
"seed": 1,
"attn": 64,
"attn_layer": 2,
"attn_dotr": True,
"attn_mask": False,
"activation_function": "tanh",
"scaling_factor": 1.0,
"normalize": False,
"temperature": 1.0,
"set_davg_zero": True,
"type_one_side": True,
},
"fitting_net": {
"neuron": [24, 24, 24],
Expand Down

0 comments on commit 2f280e6

Please sign in to comment.