reduced variables are converted to float64
Han Wang committed Feb 28, 2024
1 parent 9bc9a27 commit b48c5ee
Showing 6 changed files with 69 additions and 20 deletions.
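In short: output_type_cast (in both the numpy and PyTorch model makers) now treats output variables produced by a reduction (OutputVariableOperation.REDU, e.g. the total energy) specially, casting them to the global energy precision, float64 by default, while all other outputs are still cast back to the caller's input precision. The tests are updated to assert this dtype split, keyed on the "_redu" suffix of the output names.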
deepmd/dpmodel/common.py (2 changes: 1 addition & 1 deletion)
@@ -6,7 +6,7 @@
 
 import numpy as np
 
-from deepmd.common import (
+from deepmd.env import (
     GLOBAL_NP_FLOAT_PRECISION,
 )
deepmd/dpmodel/model/make_model.py (27 changes: 22 additions & 5 deletions)
@@ -9,6 +9,7 @@
 import numpy as np
 
 from deepmd.dpmodel.common import (
+    GLOBAL_ENER_FLOAT_PRECISION,
     GLOBAL_NP_FLOAT_PRECISION,
     PRECISION_DICT,
     RESERVED_PRECISON_DICT,
@@ -17,6 +18,8 @@
 from deepmd.dpmodel.output_def import (
     ModelOutputDef,
     OutputVariableCategory,
+    OutputVariableOperation,
+    check_operation_applied,
 )
 from deepmd.dpmodel.utils import (
     build_neighbor_list,
@@ -67,6 +70,7 @@ def __init__(
         self.precision_dict = PRECISION_DICT
         self.reverse_precision_dict = RESERVED_PRECISON_DICT
         self.global_np_float_precision = GLOBAL_NP_FLOAT_PRECISION
+        self.global_ener_float_precision = GLOBAL_ENER_FLOAT_PRECISION
 
     def model_output_def(self):
         """Get the output def for the model."""
@@ -272,13 +276,26 @@ def output_type_cast(
         input_prec: str,
     ) -> Dict[str, np.ndarray]:
         """Convert the model output to the input prec."""
-        if (
+        do_cast = (
             input_prec
             != self.reverse_precision_dict[self.global_np_float_precision]
-        ):
-            pp = self.precision_dict[input_prec]
-            for kk, vv in model_ret.items():
-                model_ret[kk] = vv.astype(pp) if vv is not None else None
+        )
+        pp = self.precision_dict[input_prec]
+        odef = self.model_output_def()
+        for kk in odef.keys():
+            if kk not in model_ret.keys():
+                # do not return energy_derv_c if not do_atomic_virial
+                continue
+            if check_operation_applied(odef[kk], OutputVariableOperation.REDU):
+                model_ret[kk] = (
+                    model_ret[kk].astype(self.global_ener_float_precision)
+                    if model_ret[kk] is not None
+                    else None
+                )
+            elif do_cast:
+                model_ret[kk] = (
+                    model_ret[kk].astype(pp) if model_ret[kk] is not None else None
+                )
         return model_ret
 
     def format_nlist(
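A minimal, self-contained sketch of the new casting rule above. The real method walks the model output def; here a plain set of key names stands in for check_operation_applied(odef[kk], OutputVariableOperation.REDU), the precision dicts are trimmed to the two float types, and the do_cast short-circuit is dropped for brevity (assumptions for illustration, not the actual deepmd-kit API):

import numpy as np

GLOBAL_ENER_FLOAT_PRECISION = np.float64
PRECISION_DICT = {"float32": np.float32, "float64": np.float64}
REDUCED_KEYS = {"energy_redu"}  # hypothetical stand-in for the REDU flag

def output_type_cast(model_ret, input_prec):
    """Cast reduced outputs to float64, everything else to input_prec."""
    pp = PRECISION_DICT[input_prec]
    for kk, vv in model_ret.items():
        if vv is None:
            continue
        if kk in REDUCED_KEYS:
            # reduced variables are always promoted to the energy precision
            model_ret[kk] = vv.astype(GLOBAL_ENER_FLOAT_PRECISION)
        else:
            model_ret[kk] = vv.astype(pp)
    return model_ret

ret = output_type_cast(
    {"energy": np.zeros(4, np.float64), "energy_redu": np.zeros(1, np.float64)},
    "float32",
)
assert ret["energy"].dtype == np.float32       # follows the input precision
assert ret["energy_redu"].dtype == np.float64  # reduced: kept in float64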
deepmd/pt/model/model/make_model.py (27 changes: 22 additions & 5 deletions)
@@ -13,12 +13,15 @@
 )
 from deepmd.dpmodel.output_def import (
     OutputVariableCategory,
+    OutputVariableOperation,
+    check_operation_applied,
 )
 from deepmd.pt.model.model.transform_output import (
     communicate_extended_output,
     fit_output_to_model_output,
 )
 from deepmd.pt.utils.env import (
+    GLOBAL_PT_ENER_FLOAT_PRECISION,
     GLOBAL_PT_FLOAT_PRECISION,
     PRECISION_DICT,
     RESERVED_PRECISON_DICT,
@@ -65,6 +68,7 @@ def __init__(
         self.precision_dict = PRECISION_DICT
         self.reverse_precision_dict = RESERVED_PRECISON_DICT
         self.global_pt_float_precision = GLOBAL_PT_FLOAT_PRECISION
+        self.global_pt_ener_float_precision = GLOBAL_PT_ENER_FLOAT_PRECISION
 
     def model_output_def(self):
         """Get the output def for the model."""
@@ -272,13 +276,26 @@ def output_type_cast(
         input_prec: str,
     ) -> Dict[str, torch.Tensor]:
         """Convert the model output to the input prec."""
-        if (
+        do_cast = (
             input_prec
             != self.reverse_precision_dict[self.global_pt_float_precision]
-        ):
-            pp = self.precision_dict[input_prec]
-            for kk, vv in model_ret.items():
-                model_ret[kk] = vv.to(pp) if vv is not None else None
+        )
+        pp = self.precision_dict[input_prec]
+        odef = self.model_output_def()
+        for kk in odef.keys():
+            if kk not in model_ret.keys():
+                # do not return energy_derv_c if not do_atomic_virial
+                continue
+            if check_operation_applied(odef[kk], OutputVariableOperation.REDU):
+                model_ret[kk] = (
+                    model_ret[kk].to(self.global_pt_ener_float_precision)
+                    if model_ret[kk] is not None
+                    else None
+                )
+            elif do_cast:
+                model_ret[kk] = (
+                    model_ret[kk].to(pp) if model_ret[kk] is not None else None
+                )
        return model_ret
 
     @torch.jit.export
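The PyTorch path mirrors the numpy one line for line, with torch's .to() in place of numpy's astype(). A two-line sketch of what the REDU branch applies (the dtype constant is assumed to be the default double-precision build setting):

import torch

GLOBAL_PT_ENER_FLOAT_PRECISION = torch.float64  # assumed default

vv = torch.zeros(5, dtype=torch.float32)
vv = vv.to(GLOBAL_PT_ENER_FLOAT_PRECISION)  # promote a reduced output
assert vv.dtype == torch.float64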
deepmd/pt/utils/env.py (3 changes: 3 additions & 0 deletions)
@@ -42,6 +42,9 @@
     "int64": torch.int64,
 }
 GLOBAL_PT_FLOAT_PRECISION = PRECISION_DICT[np.dtype(GLOBAL_NP_FLOAT_PRECISION).name]
+GLOBAL_PT_ENER_FLOAT_PRECISION = PRECISION_DICT[
+    np.dtype(GLOBAL_ENER_FLOAT_PRECISION).name
+]
 PRECISION_DICT["default"] = GLOBAL_PT_FLOAT_PRECISION
 # cannot automatically generated
 RESERVED_PRECISON_DICT = {
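How the new constant resolves, as a sketch with trimmed stand-ins for env.py's dicts (assuming the usual double-precision default, where GLOBAL_ENER_FLOAT_PRECISION is np.float64):

import numpy as np
import torch

PRECISION_DICT = {"float32": torch.float32, "float64": torch.float64}
GLOBAL_ENER_FLOAT_PRECISION = np.float64  # assumed deepmd.env default

# np.dtype(...).name turns the numpy type into the "float64" dict key
GLOBAL_PT_ENER_FLOAT_PRECISION = PRECISION_DICT[
    np.dtype(GLOBAL_ENER_FLOAT_PRECISION).name
]
assert GLOBAL_PT_ENER_FLOAT_PRECISION is torch.float64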
source/tests/common/dpmodel/test_dp_model.py (18 changes: 11 additions & 7 deletions)
@@ -3,9 +3,6 @@
 
 import numpy as np
 
-from deepmd.dpmodel.common import (
-    RESERVED_PRECISON_DICT,
-)
 from deepmd.dpmodel.descriptor import (
     DescrptSeA,
 )
@@ -86,9 +83,11 @@ def test_prec_consistency(self):
         for ii in model_l_ret_32.keys():
             if model_l_ret_32[ii] is None:
                 continue
-            self.assertEqual(
-                model_l_ret_32[ii].dtype.name, RESERVED_PRECISON_DICT[np.float32]
-            )
+            if ii[-4:] == "redu":
+                self.assertEqual(model_l_ret_32[ii].dtype, np.float64)
+            else:
+                self.assertEqual(model_l_ret_32[ii].dtype, np.float32)
+            self.assertEqual(model_l_ret_64[ii].dtype, np.float64)
             np.testing.assert_allclose(
                 model_l_ret_32[ii],
                 model_l_ret_64[ii],
@@ -135,7 +134,12 @@ def test_prec_consistency(self):
         for ii in model_l_ret_32.keys():
             if model_l_ret_32[ii] is None:
                 continue
-            self.assertEqual(model_l_ret_32[ii].dtype, np.float32)
+            if ii[-4:] == "redu":
+                self.assertEqual(model_l_ret_32[ii].dtype, np.float64)
+            else:
+                self.assertEqual(model_l_ret_32[ii].dtype, np.float32)
+            self.assertEqual(model_l_ret_64[ii].dtype, np.float64)
+            self.assertEqual(model_l_ret_64[ii].dtype, np.float64)
             np.testing.assert_allclose(
                 model_l_ret_32[ii],
                 model_l_ret_64[ii],
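The tests identify reduced outputs purely by the "_redu" suffix that the output defs append to reduced quantities. A quick check of that convention (key names here are illustrative):

for ii in ["energy", "energy_redu", "energy_derv_r", "energy_derv_c_redu"]:
    is_reduced = ii[-4:] == "redu"
    print(ii, "-> float64 expected" if is_reduced else "-> input precision expected")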
source/tests/pt/model/test_dp_model.py (12 changes: 10 additions & 2 deletions)
@@ -223,7 +223,11 @@ def test_prec_consistency(self):
         model_l_ret_32 = md1.forward_common(*args32, fparam=fparam, aparam=aparam)
 
         for ii in model_l_ret_32.keys():
-            self.assertEqual(model_l_ret_32[ii].dtype, torch.float32)
+            if ii[-4:] == "redu":
+                self.assertEqual(model_l_ret_32[ii].dtype, torch.float64)
+            else:
+                self.assertEqual(model_l_ret_32[ii].dtype, torch.float32)
+            self.assertEqual(model_l_ret_64[ii].dtype, torch.float64)
             np.testing.assert_allclose(
                 to_numpy_array(model_l_ret_32[ii]),
                 to_numpy_array(model_l_ret_64[ii]),
@@ -351,7 +355,11 @@ def test_prec_consistency(self):
         model_l_ret_32 = md1.forward_common_lower(*args32, fparam=fparam, aparam=aparam)
 
         for ii in model_l_ret_32.keys():
-            self.assertEqual(model_l_ret_32[ii].dtype, torch.float32)
+            if ii[-4:] == "redu":
+                self.assertEqual(model_l_ret_32[ii].dtype, torch.float64)
+            else:
+                self.assertEqual(model_l_ret_32[ii].dtype, torch.float32)
+            self.assertEqual(model_l_ret_64[ii].dtype, torch.float64)
             np.testing.assert_allclose(
                 to_numpy_array(model_l_ret_32[ii]),
                 to_numpy_array(model_l_ret_64[ii]),
