Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pt: Add support for dipole and polar training #3380

Merged
merged 11 commits into from
Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,12 @@
if self.atom_excl is not None:
atom_mask = self.atom_excl.build_type_exclude_mask(atype)
for kk in ret_dict.keys():
ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None]
out_shape = ret_dict[kk].shape
ret_dict[kk] = (

Check warning on line 82 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L81-L82

Added lines #L81 - L82 were not covered by tests
ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
* atom_mask[:, :, None]
).reshape(out_shape)
ret_dict["mask"] = atom_mask

Check warning on line 86 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L86

Added line #L86 was not covered by tests

return ret_dict

Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@
aparam=ap,
do_atomic_virial=do_atomic_virial,
)
nf, nloc = nlist.shape[:2]
if "mask" in model_predict_lower:
model_predict_lower["mask"] = model_predict_lower["mask"][:, :nloc]

Check warning on line 164 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L162-L164

Added lines #L162 - L164 were not covered by tests
model_predict = communicate_extended_output(
model_predict_lower,
self.model_output_def(),
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
"""
model_ret = dict(fit_ret.items())
for kk, vv in fit_ret.items():
if kk in ["mask"]:
continue

Check warning on line 32 in deepmd/dpmodel/model/transform_output.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/transform_output.py#L31-L32

Added lines #L31 - L32 were not covered by tests
iProzd marked this conversation as resolved.
Show resolved Hide resolved
vdef = fit_output_def[kk]
shap = vdef.shape
atom_axis = -(len(shap) + 1)
Expand Down Expand Up @@ -59,6 +61,8 @@

"""
new_ret = {}
if "mask" in model_ret:
new_ret["mask"] = model_ret["mask"]

Check warning on line 65 in deepmd/dpmodel/model/transform_output.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/transform_output.py#L64-L65

Added lines #L64 - L65 were not covered by tests
for kk in model_output_def.keys_outp():
vv = model_ret[kk]
vdef = model_output_def[kk]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
from .loss import (
TaskLoss,
)
from .tensor import (

Check warning on line 11 in deepmd/pt/loss/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/__init__.py#L11

Added line #L11 was not covered by tests
TensorLoss,
)

__all__ = [
"DenoiseLoss",
"EnergyStdLoss",
"TensorLoss",
"TaskLoss",
]
162 changes: 162 additions & 0 deletions deepmd/pt/loss/tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (

Check warning on line 2 in deepmd/pt/loss/tensor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/tensor.py#L2

Added line #L2 was not covered by tests
List,
)

import torch

Check warning on line 6 in deepmd/pt/loss/tensor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/tensor.py#L6

Added line #L6 was not covered by tests

from deepmd.pt.loss.loss import (

Check warning on line 8 in deepmd/pt/loss/tensor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/tensor.py#L8

Added line #L8 was not covered by tests
TaskLoss,
)
from deepmd.pt.utils import (

Check warning on line 11 in deepmd/pt/loss/tensor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/tensor.py#L11

Added line #L11 was not covered by tests
env,
)
from deepmd.utils.data import (

Check warning on line 14 in deepmd/pt/loss/tensor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/tensor.py#L14

Added line #L14 was not covered by tests
DataRequirementItem,
)


class TensorLoss(TaskLoss):
    def __init__(
        self,
        tensor_name: str,
        tensor_size: int,
        label_name: str,
        pref_atomic: float = 0.0,
        pref: float = 0.0,
        inference=False,
        **kwargs,
    ):
        r"""Construct a loss for local (per-atom) and global tensors.

        Parameters
        ----------
        tensor_name : str
            The name of the tensor in the model predictions to compute the loss.
        tensor_size : int
            The size (dimension) of the tensor.
        label_name : str
            The name of the tensor in the labels to compute the loss.
        pref_atomic : float
            The prefactor of the weight of atomic loss. It should be larger than or equal to 0.
        pref : float
            The prefactor of the weight of global loss. It should be larger than or equal to 0.
        inference : bool
            If true, it will output all losses found in output, ignoring the pre-factors.
        **kwargs
            Other keyword arguments.
        """
        super().__init__()
        self.tensor_name = tensor_name
        self.tensor_size = tensor_size
        self.label_name = label_name
        self.local_weight = pref_atomic
        self.global_weight = pref
        self.inference = inference

        assert (
            self.local_weight >= 0.0 and self.global_weight >= 0.0
        ), "Can not assign negative weight to `pref` and `pref_atomic`"
        # In inference mode both terms are reported regardless of prefactors.
        self.has_local_weight = self.local_weight > 0.0 or inference
        self.has_global_weight = self.global_weight > 0.0 or inference
        # NOTE: the message is the second operand of `assert` directly;
        # wrapping it in AssertionError(...) would nest the exception.
        assert (
            self.has_local_weight or self.has_global_weight
        ), "Can not assign zero weight both to `pref` and `pref_atomic`"

    def forward(self, model_pred, label, natoms, learning_rate=0.0, mae=False):
        """Return loss on local and global tensors.

        Parameters
        ----------
        model_pred : dict[str, torch.Tensor]
            Model predictions.
        label : dict[str, torch.Tensor]
            Labels.
        natoms : int
            The local atom number.

        Returns
        -------
        loss: torch.Tensor
            Loss for model to minimize.
        more_loss: dict[str, torch.Tensor]
            Other losses for display.
        """
        del learning_rate, mae  # unused by this loss
        loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
        more_loss = {}
        if (
            self.has_local_weight
            and self.tensor_name in model_pred
            and "atomic_" + self.label_name in label
        ):
            local_tensor_pred = model_pred[self.tensor_name].reshape(
                [-1, natoms, self.tensor_size]
            )
            local_tensor_label = label["atomic_" + self.label_name].reshape(
                [-1, natoms, self.tensor_size]
            )
            diff = (local_tensor_pred - local_tensor_label).reshape(
                [-1, self.tensor_size]
            )
            if "mask" in model_pred:
                # Keep only rows for real (non-excluded) atoms.
                diff = diff[model_pred["mask"].reshape([-1]).bool()]
            l2_local_loss = torch.mean(torch.square(diff))
            if not self.inference:
                more_loss[f"l2_local_{self.tensor_name}_loss"] = l2_local_loss.detach()
            loss += self.local_weight * l2_local_loss
            rmse_local = l2_local_loss.sqrt()
            more_loss[f"rmse_local_{self.tensor_name}"] = rmse_local.detach()
        if (
            self.has_global_weight
            and "global_" + self.tensor_name in model_pred
            and self.label_name in label
        ):
            global_tensor_pred = model_pred["global_" + self.tensor_name].reshape(
                [-1, self.tensor_size]
            )
            global_tensor_label = label[self.label_name].reshape([-1, self.tensor_size])
            diff = global_tensor_pred - global_tensor_label
            if "mask" in model_pred:
                # Weight each frame's squared error by its real-atom count,
                # normalized by the total atom count over the batch.
                atom_num = model_pred["mask"].sum(-1, keepdim=True)
                l2_global_loss = torch.mean(
                    torch.sum(torch.square(diff) * atom_num, dim=0) / atom_num.sum()
                )
                atom_num = torch.mean(atom_num.float())
            else:
                atom_num = natoms
                l2_global_loss = torch.mean(torch.square(diff))
            if not self.inference:
                more_loss[
                    f"l2_global_{self.tensor_name}_loss"
                ] = l2_global_loss.detach()
            loss += self.global_weight * l2_global_loss
            # Report the per-atom RMSE for the global term.
            rmse_global = l2_global_loss.sqrt() / atom_num
            more_loss[f"rmse_global_{self.tensor_name}"] = rmse_global.detach()
        return loss, more_loss

    @property
    def label_requirement(self) -> List[DataRequirementItem]:
        """Return data label requirements needed for this loss calculation."""
        label_requirement = []
        if self.has_local_weight:
            label_requirement.append(
                DataRequirementItem(
                    "atomic_" + self.label_name,
                    ndof=self.tensor_size,
                    atomic=True,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_global_weight:
            label_requirement.append(
                DataRequirementItem(
                    self.label_name,
                    ndof=self.tensor_size,
                    atomic=False,
                    must=False,
                    high_prec=False,
                )
            )
        return label_requirement

Check warning on line 162 in deepmd/pt/loss/tensor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/tensor.py#L162

Added line #L162 was not covered by tests
7 changes: 6 additions & 1 deletion deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,12 @@
if self.atom_excl is not None:
atom_mask = self.atom_excl(atype)
for kk in ret_dict.keys():
ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None]
out_shape = ret_dict[kk].shape
ret_dict[kk] = (

Check warning on line 93 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L92-L93

Added lines #L92 - L93 were not covered by tests
ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
* atom_mask[:, :, None]
).reshape(out_shape)
ret_dict["mask"] = atom_mask

Check warning on line 97 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L97

Added line #L97 was not covered by tests

return ret_dict

Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
fitting_net["type"] = fitting_net.get("type", "ener")
fitting_net["ntypes"] = descriptor.get_ntypes()
fitting_net["mixed_types"] = descriptor.mixed_types()
fitting_net["embedding_width"] = descriptor.get_dim_out()
fitting_net["embedding_width"] = descriptor.get_dim_emb()

Check warning on line 97 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L97

Added line #L97 was not covered by tests
fitting_net["dim_descrpt"] = descriptor.get_dim_out()
grad_force = "direct" not in fitting_net["type"]
if not grad_force:
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/dipole_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
model_predict["atom_virial"] = model_ret["dipole_derv_c"].squeeze(
-3
)
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]

Check warning on line 54 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L53-L54

Added lines #L53 - L54 were not covered by tests
else:
model_predict = model_ret
model_predict["updated_coord"] += coord
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
else:
model_predict["force"] = model_ret["dforce"]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]

Check warning on line 67 in deepmd/pt/model/model/dp_zbl_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_zbl_model.py#L66-L67

Added lines #L66 - L67 were not covered by tests
return model_predict

@torch.jit.export
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
)
else:
model_predict["force"] = model_ret["dforce"]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]

Check warning on line 56 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L55-L56

Added lines #L55 - L56 were not covered by tests
else:
model_predict = model_ret
model_predict["updated_coord"] += coord
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@
fparam=fp,
aparam=ap,
)
nf, nloc = nlist.shape[:2]
if "mask" in model_predict_lower:
model_predict_lower["mask"] = model_predict_lower["mask"][:, :nloc]

Check warning on line 159 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L157-L159

Added lines #L157 - L159 were not covered by tests
model_predict = communicate_extended_output(
model_predict_lower,
self.model_output_def(),
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/polar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
model_predict = {}
model_predict["polar"] = model_ret["polar"]
model_predict["global_polar"] = model_ret["polar_redu"]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]

Check warning on line 46 in deepmd/pt/model/model/polar_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/polar_model.py#L45-L46

Added lines #L45 - L46 were not covered by tests
else:
model_predict = model_ret
model_predict["updated_coord"] += coord
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@
redu_prec = env.GLOBAL_PT_ENER_FLOAT_PRECISION
model_ret = dict(fit_ret.items())
for kk, vv in fit_ret.items():
if kk in ["mask"]:
continue

Check warning on line 158 in deepmd/pt/model/model/transform_output.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/transform_output.py#L157-L158

Added lines #L157 - L158 were not covered by tests
vdef = fit_output_def[kk]
shap = vdef.shape
atom_axis = -(len(shap) + 1)
Expand Down Expand Up @@ -192,6 +194,8 @@
"""
redu_prec = env.GLOBAL_PT_ENER_FLOAT_PRECISION
new_ret = {}
if "mask" in model_ret:
new_ret["mask"] = model_ret["mask"]

Check warning on line 198 in deepmd/pt/model/model/transform_output.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/transform_output.py#L197-L198

Added lines #L197 - L198 were not covered by tests
for kk in model_output_def.keys_outp():
vv = model_ret[kk]
vdef = model_output_def[kk]
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@
The path to the stat file.

"""
raise NotImplementedError
pass

Check warning on line 160 in deepmd/pt/model/task/dipole.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dipole.py#L160

Added line #L160 was not covered by tests

def forward(
self,
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@
The path to the stat file.

"""
raise NotImplementedError
pass

Check warning on line 187 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L187

Added line #L187 was not covered by tests

def forward(
self,
Expand Down
Loading
Loading