Skip to content

Commit

Permalink
dtype
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Mar 2, 2024
1 parent 80337f7 commit 9e55925
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion deepmd/dpmodel/utils/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def compute_smooth_weight(
uu = (distance - rmin) / (rmax - rmin)
with np.errstate(invalid="ignore"):
vv = uu * uu * uu * (-6.0 * uu * uu + 15.0 * uu - 10.0) + 1.0
return np.where(mid_mask, vv, min_mask)
return np.where(mid_mask, vv, min_mask.astype(distance.dtype))

Check warning on line 27 in deepmd/dpmodel/utils/env_mat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/env_mat.py#L23-L27

Added lines #L23 - L27 were not covered by tests


def _make_env_mat(
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def compute_smooth_weight(distance, rmin: float, rmax: float):
mid_mask = torch.logical_not(torch.logical_or(min_mask, max_mask))
uu = (distance - rmin) / (rmax - rmin)
vv = uu * uu * uu * (-6 * uu * uu + 15 * uu - 10) + 1
return torch.where(mid_mask, vv, min_mask)
return torch.where(mid_mask, vv, min_mask.to(dtype=distance.dtype))

Check warning on line 236 in deepmd/pt/utils/preprocess.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/preprocess.py#L236

Added line #L236 was not covered by tests


def make_env_mat(
Expand Down

0 comments on commit 9e55925

Please sign in to comment.