From b48c5eec26cc85d96759a5cae30faf5c1718ab5b Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 28 Feb 2024 23:36:50 +0800 Subject: [PATCH] reduced variables are converted to float64 --- deepmd/dpmodel/common.py | 2 +- deepmd/dpmodel/model/make_model.py | 27 ++++++++++++++++---- deepmd/pt/model/model/make_model.py | 27 ++++++++++++++++---- deepmd/pt/utils/env.py | 3 +++ source/tests/common/dpmodel/test_dp_model.py | 18 ++++++++----- source/tests/pt/model/test_dp_model.py | 12 +++++++-- 6 files changed, 69 insertions(+), 20 deletions(-) diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index e33143adde..b9af55940c 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -6,7 +6,7 @@ import numpy as np -from deepmd.common import ( +from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, ) diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 50bd36888d..1261906148 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -9,6 +9,7 @@ import numpy as np from deepmd.dpmodel.common import ( + GLOBAL_ENER_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION, PRECISION_DICT, RESERVED_PRECISON_DICT, @@ -17,6 +18,8 @@ from deepmd.dpmodel.output_def import ( ModelOutputDef, OutputVariableCategory, + OutputVariableOperation, + check_operation_applied, ) from deepmd.dpmodel.utils import ( build_neighbor_list, @@ -67,6 +70,7 @@ def __init__( self.precision_dict = PRECISION_DICT self.reverse_precision_dict = RESERVED_PRECISON_DICT self.global_np_float_precision = GLOBAL_NP_FLOAT_PRECISION + self.global_ener_float_precision = GLOBAL_ENER_FLOAT_PRECISION def model_output_def(self): """Get the output def for the model.""" @@ -272,13 +276,26 @@ def output_type_cast( input_prec: str, ) -> Dict[str, np.ndarray]: """Convert the model output to the input prec.""" - if ( + do_cast = ( input_prec != self.reverse_precision_dict[self.global_np_float_precision] - ): - pp = self.precision_dict[input_prec] - for kk, vv in model_ret.items(): - model_ret[kk] = vv.astype(pp) if vv is not None else None + ) + pp = self.precision_dict[input_prec] + odef = self.model_output_def() + for kk in odef.keys(): + if kk not in model_ret.keys(): + # do not return energy_derv_c if not do_atomic_virial + continue + if check_operation_applied(odef[kk], OutputVariableOperation.REDU): + model_ret[kk] = ( + model_ret[kk].astype(self.global_ener_float_precision) + if model_ret[kk] is not None + else None + ) + elif do_cast: + model_ret[kk] = ( + model_ret[kk].astype(pp) if model_ret[kk] is not None else None + ) return model_ret def format_nlist( diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 05ccac08f4..3efd3fb046 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -13,12 +13,15 @@ ) from deepmd.dpmodel.output_def import ( OutputVariableCategory, + OutputVariableOperation, + check_operation_applied, ) from deepmd.pt.model.model.transform_output import ( communicate_extended_output, fit_output_to_model_output, ) from deepmd.pt.utils.env import ( + GLOBAL_PT_ENER_FLOAT_PRECISION, GLOBAL_PT_FLOAT_PRECISION, PRECISION_DICT, RESERVED_PRECISON_DICT, @@ -65,6 +68,7 @@ def __init__( self.precision_dict = PRECISION_DICT self.reverse_precision_dict = RESERVED_PRECISON_DICT self.global_pt_float_precision = GLOBAL_PT_FLOAT_PRECISION + self.global_pt_ener_float_precision = GLOBAL_PT_ENER_FLOAT_PRECISION def model_output_def(self): """Get the output def for the model.""" @@ -272,13 +276,26 @@ def output_type_cast( input_prec: str, ) -> Dict[str, torch.Tensor]: """Convert the model output to the input prec.""" - if ( + do_cast = ( input_prec != self.reverse_precision_dict[self.global_pt_float_precision] - ): - pp = self.precision_dict[input_prec] - for kk, vv in model_ret.items(): - model_ret[kk] = vv.to(pp) if vv is not None else None + ) + pp = self.precision_dict[input_prec] + odef = self.model_output_def() + for kk in odef.keys(): + if kk not in model_ret.keys(): + # do not return energy_derv_c if not do_atomic_virial + continue + if check_operation_applied(odef[kk], OutputVariableOperation.REDU): + model_ret[kk] = ( + model_ret[kk].to(self.global_pt_ener_float_precision) + if model_ret[kk] is not None + else None + ) + elif do_cast: + model_ret[kk] = ( + model_ret[kk].to(pp) if model_ret[kk] is not None else None + ) return model_ret @torch.jit.export diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index 7383cf5c49..5616f0a9a5 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -42,6 +42,9 @@ "int64": torch.int64, } GLOBAL_PT_FLOAT_PRECISION = PRECISION_DICT[np.dtype(GLOBAL_NP_FLOAT_PRECISION).name] +GLOBAL_PT_ENER_FLOAT_PRECISION = PRECISION_DICT[ + np.dtype(GLOBAL_ENER_FLOAT_PRECISION).name +] PRECISION_DICT["default"] = GLOBAL_PT_FLOAT_PRECISION # cannot automatically generated RESERVED_PRECISON_DICT = { diff --git a/source/tests/common/dpmodel/test_dp_model.py b/source/tests/common/dpmodel/test_dp_model.py index acbf77f6e9..c3de1f4cdf 100644 --- a/source/tests/common/dpmodel/test_dp_model.py +++ b/source/tests/common/dpmodel/test_dp_model.py @@ -3,9 +3,6 @@ import numpy as np -from deepmd.dpmodel.common import ( - RESERVED_PRECISON_DICT, -) from deepmd.dpmodel.descriptor import ( DescrptSeA, ) @@ -86,9 +83,11 @@ def test_prec_consistency(self): for ii in model_l_ret_32.keys(): if model_l_ret_32[ii] is None: continue - self.assertEqual( - model_l_ret_32[ii].dtype.name, RESERVED_PRECISON_DICT[np.float32] - ) + if ii[-4:] == "redu": + self.assertEqual(model_l_ret_32[ii].dtype, np.float64) + else: + self.assertEqual(model_l_ret_32[ii].dtype, np.float32) + self.assertEqual(model_l_ret_64[ii].dtype, np.float64) np.testing.assert_allclose( model_l_ret_32[ii], model_l_ret_64[ii], @@ -135,7 +134,12 @@ def test_prec_consistency(self): for ii in model_l_ret_32.keys(): if model_l_ret_32[ii] is None: continue - self.assertEqual(model_l_ret_32[ii].dtype, np.float32) + if ii[-4:] == "redu": + self.assertEqual(model_l_ret_32[ii].dtype, np.float64) + else: + self.assertEqual(model_l_ret_32[ii].dtype, np.float32) + self.assertEqual(model_l_ret_64[ii].dtype, np.float64) + self.assertEqual(model_l_ret_64[ii].dtype, np.float64) np.testing.assert_allclose( model_l_ret_32[ii], model_l_ret_64[ii], diff --git a/source/tests/pt/model/test_dp_model.py b/source/tests/pt/model/test_dp_model.py index 13accf8bb8..840ba284e2 100644 --- a/source/tests/pt/model/test_dp_model.py +++ b/source/tests/pt/model/test_dp_model.py @@ -223,7 +223,11 @@ def test_prec_consistency(self): model_l_ret_32 = md1.forward_common(*args32, fparam=fparam, aparam=aparam) for ii in model_l_ret_32.keys(): - self.assertEqual(model_l_ret_32[ii].dtype, torch.float32) + if ii[-4:] == "redu": + self.assertEqual(model_l_ret_32[ii].dtype, torch.float64) + else: + self.assertEqual(model_l_ret_32[ii].dtype, torch.float32) + self.assertEqual(model_l_ret_64[ii].dtype, torch.float64) np.testing.assert_allclose( to_numpy_array(model_l_ret_32[ii]), to_numpy_array(model_l_ret_64[ii]), @@ -351,7 +355,11 @@ def test_prec_consistency(self): model_l_ret_32 = md1.forward_common_lower(*args32, fparam=fparam, aparam=aparam) for ii in model_l_ret_32.keys(): - self.assertEqual(model_l_ret_32[ii].dtype, torch.float32) + if ii[-4:] == "redu": + self.assertEqual(model_l_ret_32[ii].dtype, torch.float64) + else: + self.assertEqual(model_l_ret_32[ii].dtype, torch.float32) + self.assertEqual(model_l_ret_64[ii].dtype, torch.float64) np.testing.assert_allclose( to_numpy_array(model_l_ret_32[ii]), to_numpy_array(model_l_ret_64[ii]),