diff --git a/source/tests/common/dpmodel/test_descriptor_dpa2.py b/source/tests/common/dpmodel/test_descriptor_dpa2.py index 4df01c61ad..3ae0689dad 100644 --- a/source/tests/common/dpmodel/test_descriptor_dpa2.py +++ b/source/tests/common/dpmodel/test_descriptor_dpa2.py @@ -45,5 +45,13 @@ def test_self_consistency( em1 = DescrptDPA2.deserialize(em0.serialize()) mm0 = em0.call(self.coord_ext, self.atype_ext, self.nlist, self.mapping) mm1 = em1.call(self.coord_ext, self.atype_ext, self.nlist, self.mapping) - for ii in [0, 1, 4]: + desired_shape = [ + (nf, nloc, em0.get_dim_out()), # descriptor + (nf, nloc, em0.get_dim_emb(), 3), # rot_mat + (nf, nloc, nnei // 2, em0.repformers.g2_dim), # g2 + (nf, nloc, nnei // 2, 3), # h2 + (nf, nloc, nnei // 2), # sw + ] + for ii in [0, 1, 2, 3, 4]: + np.testing.assert_equal(mm0[ii].shape, desired_shape[ii]) np.testing.assert_allclose(mm0[ii], mm1[ii])