raise precision before reduction
Han Wang committed Feb 28, 2024
1 parent 651c689 commit 9bc9a27
Showing 2 changed files with 16 additions and 4 deletions.
8 changes: 7 additions & 1 deletion deepmd/dpmodel/model/transform_output.py
@@ -5,6 +5,9 @@
 
 import numpy as np
 
+from deepmd.dpmodel.common import (
+    GLOBAL_ENER_FLOAT_PRECISION,
+)
 from deepmd.dpmodel.output_def import (
     FittingOutputDef,
     ModelOutputDef,
@@ -30,7 +33,10 @@ def fit_output_to_model_output(
         atom_axis = -(len(shap) + 1)
         if vdef.reduciable:
             kk_redu = get_reduce_name(kk)
-            model_ret[kk_redu] = np.sum(vv, axis=atom_axis)
+            # cast to energy precision before reduction
+            model_ret[kk_redu] = np.sum(
+                vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
+            )
         if vdef.r_differentiable:
             kk_derv_r, kk_derv_c = get_deriv_name(kk)
             # name-holders
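Why the cast matters: accumulating many float32 per-atom energies loses digits that a float64 accumulator keeps. A minimal sketch, not from the repository (np.float64 stands in for GLOBAL_ENER_FLOAT_PRECISION, which is double precision by default, and math.fsum provides an exactly rounded reference):

import math

import numpy as np

rng = np.random.default_rng(0)
natoms = 1_000_000
# per-atom energies in the model's working precision (float32 here)
atom_ener = rng.uniform(-2.0, -1.0, size=natoms).astype(np.float32)

e_low = np.sum(atom_ener, axis=-1)                      # reduce in float32
e_high = np.sum(atom_ener.astype(np.float64), axis=-1)  # cast, then reduce

ref = math.fsum(atom_ener.astype(np.float64))           # exactly rounded reference
print(abs(e_low - ref))   # off in the last float32 digits of a ~1e6 total
print(abs(e_high - ref))  # at the float64 rounding level

Casting before the reduction keeps the per-atom tensors in working precision; only the accumulated total is promoted to double.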
12 changes: 9 additions & 3 deletions deepmd/pt/model/model/transform_output.py
@@ -148,14 +148,16 @@ def fit_output_to_model_output(
         the model output.
     """
+    ## should have been GLOBAL_PT_ENER_FLOAT_PRECISION, but does not pass jit!!!
+    redu_prec = torch.float64
     model_ret = dict(fit_ret.items())
     for kk, vv in fit_ret.items():
         vdef = fit_output_def[kk]
         shap = vdef.shape
         atom_axis = -(len(shap) + 1)
         if vdef.reduciable:
             kk_redu = get_reduce_name(kk)
-            model_ret[kk_redu] = torch.sum(vv, dim=atom_axis)
+            model_ret[kk_redu] = torch.sum(vv.to(redu_prec), dim=atom_axis)
         if vdef.r_differentiable:
             kk_derv_r, kk_derv_c = get_deriv_name(kk)
             dr, dc = take_deriv(
@@ -171,7 +173,7 @@
                 assert dc is not None
                 model_ret[kk_derv_c] = dc
                 model_ret[kk_derv_c + "_redu"] = torch.sum(
-                    model_ret[kk_derv_c], dim=1
+                    model_ret[kk_derv_c].to(redu_prec), dim=1
                 )
     return model_ret
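
The hard-coded torch.float64 is a workaround: the inline comment notes that referencing GLOBAL_PT_ENER_FLOAT_PRECISION here does not pass TorchScript compilation, whereas a local dtype literal scripts cleanly. A minimal sketch of the scriptable pattern; the helper name and shapes are illustrative, not from the repository:

import torch


def reduce_atom_energy(atom_ener: torch.Tensor) -> torch.Tensor:
    # a local dtype literal is resolvable by TorchScript; per the commit's
    # comment, the module-level precision constant is not
    redu_prec = torch.float64
    return torch.sum(atom_ener.to(redu_prec), dim=-1)


scripted = torch.jit.script(reduce_atom_energy)
out = scripted(torch.ones(2, 100, dtype=torch.float32))
print(out.dtype)  # torch.float64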

@@ -186,6 +188,8 @@ def communicate_extended_output(
         local and ghost (extended) atoms to local atoms.
     """
+    ## should have been GLOBAL_PT_ENER_FLOAT_PRECISION, but does not pass jit!!!
+    redu_prec = torch.float64
     new_ret = {}
     for kk in model_output_def.keys_outp():
         vv = model_ret[kk]
@@ -235,7 +239,9 @@
                     src=model_ret[kk_derv_c],
                     reduce="sum",
                 )
-                new_ret[kk_derv_c + "_redu"] = torch.sum(new_ret[kk_derv_c], dim=1)
+                new_ret[kk_derv_c + "_redu"] = torch.sum(
+                    new_ret[kk_derv_c].to(redu_prec), dim=1
+                )
             if not do_atomic_virial:
                 # pop atomic virial, because it is not correctly calculated.
                 new_ret.pop(kk_derv_c)
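In communicate_extended_output the same cast is applied where the atomic virial, scattered from extended (local plus ghost) atoms back onto local atoms, is reduced over atoms. A rough self-contained sketch of that pattern; the shapes are made up, and scatter_add_ stands in for the reduce="sum" scatter used in the file:

import torch

# illustrative sizes: 1 frame, 6 extended atoms (4 local + 2 ghosts), 9 virial components
nf, nall, nloc = 1, 6, 4
extended_virial = torch.randn(nf, nall, 9, dtype=torch.float32)
# each extended atom's index in the local frame; ghosts 4 and 5 map to locals 0 and 2
mapping = torch.tensor([[0, 1, 2, 3, 0, 2]])

local_virial = torch.zeros(nf, nloc, 9, dtype=torch.float32)
local_virial.scatter_add_(
    1, mapping.unsqueeze(-1).expand(nf, nall, 9), extended_virial
)
# raise precision only for the reduction over atoms, as in the patch
virial_redu = torch.sum(local_virial.to(torch.float64), dim=1)
print(virial_redu.shape, virial_redu.dtype)  # torch.Size([1, 9]) torch.float64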
