Skip to content

Commit

Permalink
Update test_descriptor_dpa2.py
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed May 9, 2024
1 parent 2f280e6 commit d85eef0
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion source/tests/common/dpmodel/test_descriptor_dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,13 @@ def test_self_consistency(
em1 = DescrptDPA2.deserialize(em0.serialize())
mm0 = em0.call(self.coord_ext, self.atype_ext, self.nlist, self.mapping)
mm1 = em1.call(self.coord_ext, self.atype_ext, self.nlist, self.mapping)
for ii in [0, 1, 4]:
desired_shape = [
(nf, nloc, em0.get_dim_out()), # descriptor
(nf, nloc, em0.get_dim_emb(), 3), # rot_mat
(nf, nloc, nnei // 2, em0.repformers.g2_dim), # g2
(nf, nloc, nnei // 2, 3), # h2
(nf, nloc, nnei // 2), # sw
]
for ii in [0, 1, 2, 3, 4]:
np.testing.assert_equal(mm0[ii].shape, desired_shape[ii])
np.testing.assert_allclose(mm0[ii], mm1[ii])

0 comments on commit d85eef0

Please sign in to comment.