Skip to content

Commit

Permalink
unittest for the compression of smooth se_atten descriptor (#2916)
Browse files Browse the repository at this point in the history
Co-authored-by: Han Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 13, 2023
1 parent 8bc4e3f commit fc09e77
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
44 changes: 37 additions & 7 deletions source/tests/test_model_compression_se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,31 @@ def _subprocess_run(command):
# - type embedding FP32, se_atten FP64
# - type embedding FP32, se_atten FP32
tests = [
{"se_atten precision": "float64", "type embedding precision": "float64"},
{"se_atten precision": "float64", "type embedding precision": "float32"},
{"se_atten precision": "float32", "type embedding precision": "float64"},
{"se_atten precision": "float32", "type embedding precision": "float32"},
{
"se_atten precision": "float64",
"type embedding precision": "float64",
"smooth_type_embdding": True,
},
{
"se_atten precision": "float64",
"type embedding precision": "float64",
"smooth_type_embdding": False,
},
{
"se_atten precision": "float64",
"type embedding precision": "float32",
"smooth_type_embdding": True,
},
{
"se_atten precision": "float32",
"type embedding precision": "float64",
"smooth_type_embdding": True,
},
{
"se_atten precision": "float32",
"type embedding precision": "float32",
"smooth_type_embdding": True,
},
]


Expand All @@ -73,6 +94,9 @@ def _init_models():
jdata["model"]["descriptor"]["stripped_type_embedding"] = True
jdata["model"]["descriptor"]["sel"] = 120
jdata["model"]["descriptor"]["attn_layer"] = 0
jdata["model"]["descriptor"]["smooth_type_embdding"] = tests[i][
"smooth_type_embdding"
]
jdata["model"]["type_embedding"] = {}
jdata["model"]["type_embedding"]["precision"] = tests[i][
"type embedding precision"
Expand Down Expand Up @@ -479,9 +503,15 @@ def test_1frame(self):
self.assertEqual(ff1.shape, (nframes, natoms, 3))
self.assertEqual(vv1.shape, (nframes, 9))
# check values
np.testing.assert_almost_equal(ff0, ff1, default_places)
np.testing.assert_almost_equal(ee0, ee1, default_places)
np.testing.assert_almost_equal(vv0, vv1, default_places)
np.testing.assert_almost_equal(
ff0, ff1, default_places, err_msg=str(tests[i])
)
np.testing.assert_almost_equal(
ee0, ee1, default_places, err_msg=str(tests[i])
)
np.testing.assert_almost_equal(
vv0, vv1, default_places, err_msg=str(tests[i])
)

def test_1frame_atm(self):
for i in range(len(tests)):
Expand Down
2 changes: 2 additions & 0 deletions source/tests/test_model_se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,8 @@ def test_smoothness_of_stripped_type_embedding_smooth_model(self):
jdata["model"]["descriptor"]["stripped_type_embedding"] = True
jdata["model"]["descriptor"]["smooth_type_embdding"] = True
jdata["model"]["descriptor"]["attn_layer"] = 1
jdata["model"]["descriptor"]["rcut"] = 6.0
jdata["model"]["descriptor"]["rcut_smth"] = 4.0
descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True)
jdata["model"]["fitting_net"]["descrpt"] = descrpt
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)
Expand Down

0 comments on commit fc09e77

Please sign in to comment.