Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug: fix spin nlist in spin_model #3718

Merged
merged 8 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions deepmd/pt/model/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,20 @@ def extend_nlist(extended_atype, nlist):
nlist_shift = nlist + nall
nlist[~nlist_mask] = -1
nlist_shift[~nlist_mask] = -1
self_spin = torch.arange(0, nloc, dtype=nlist.dtype, device=nlist.device) + nall
self_spin = self_spin.view(1, -1, 1).expand(nframes, -1, -1)
# self spin + real neighbor + virtual neighbor
self_real = (
torch.arange(0, nloc, dtype=nlist.dtype, device=nlist.device)
.view(1, -1, 1)
.expand(nframes, -1, -1)
)
self_spin = self_real + nall
# real atom's neighbors: self spin + real neighbor + virtual neighbor
# nf x nloc x (1 + nnei + nnei)
real_nlist = torch.cat([self_spin, nlist, nlist_shift], dim=-1)
# spin atom's neighbors: real + real neighbor + virtual neighbor
# nf x nloc x (1 + nnei + nnei)
extended_nlist = torch.cat([self_spin, nlist, nlist_shift], dim=-1)
spin_nlist = torch.cat([self_real, nlist, nlist_shift], dim=-1)
# nf x (nloc + nloc) x (1 + nnei + nnei)
extended_nlist = torch.cat(
[extended_nlist, -1 * torch.ones_like(extended_nlist)], dim=-2
)
extended_nlist = torch.cat([real_nlist, spin_nlist], dim=-2)
# update the index for switch
first_part_index = (nloc <= extended_nlist) & (extended_nlist < nall)
second_part_index = (nall <= extended_nlist) & (extended_nlist < (nall + nloc))
Expand Down
25 changes: 20 additions & 5 deletions source/tests/pt/model/test_forward_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,13 @@ class TestEnergyModelSeA(unittest.TestCase, ForwardLowerTest):
def setUp(self):
self.prec = 1e-10
model_params = copy.deepcopy(model_se_e2_a)
self.type_split = False
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelDPA1(unittest.TestCase, ForwardLowerTest):
def setUp(self):
self.prec = 1e-10
model_params = copy.deepcopy(model_dpa1)
self.type_split = True
self.model = get_model(model_params).to(env.DEVICE)


Expand All @@ -151,15 +149,13 @@ def setUp(self):
"repinit_nsel"
]
model_params = copy.deepcopy(model_dpa2)
self.type_split = True
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelZBL(unittest.TestCase, ForwardLowerTest):
def setUp(self):
self.prec = 1e-10
model_params = copy.deepcopy(model_zbl)
self.type_split = False
self.model = get_model(model_params).to(env.DEVICE)


Expand All @@ -168,7 +164,26 @@ def setUp(self):
# still need to figure out why only 1e-5 rtol and atol
self.prec = 1e-5
model_params = copy.deepcopy(model_spin)
self.type_split = False
self.test_spin = True
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelSpinDPA1(unittest.TestCase, ForwardLowerTest):
def setUp(self):
# still need to figure out why only 1e-4 rtol and atol
self.prec = 1e-4
model_params = copy.deepcopy(model_spin)
model_params["descriptor"] = copy.deepcopy(model_dpa1)["descriptor"]
self.test_spin = True
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelSpinDPA2(unittest.TestCase, ForwardLowerTest):
def setUp(self):
# still need to figure out why only 1e-4 rtol and atol
self.prec = 1e-4
iProzd marked this conversation as resolved.
Show resolved Hide resolved
model_params = copy.deepcopy(model_spin)
model_params["descriptor"] = copy.deepcopy(model_dpa2)["descriptor"]
self.test_spin = True
self.model = get_model(model_params).to(env.DEVICE)

Expand Down
Loading