pt: Add support for dipole and polar training #3380

Merged: 11 commits, Mar 3, 2024

6 changes: 5 additions & 1 deletion deepmd/dpmodel/atomic_model/base_atomic_model.py
@@ -101,7 +101,11 @@ def forward_common_atomic(
         if self.atom_excl is not None:
             atom_mask = self.atom_excl.build_type_exclude_mask(atype)
             for kk in ret_dict.keys():
-                ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None]
+                out_shape = ret_dict[kk].shape
+                ret_dict[kk] = (
+                    ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
+                    * atom_mask[:, :, None]
+                ).reshape(out_shape)
             ret_dict["mask"] = atom_mask

         return ret_dict
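Dipole and polar models emit per-atom outputs with extra trailing dimensions, so the old plain multiply fails for, e.g., a [nf, nloc, 3, 3] polarizability; flattening to [nf, nloc, -1] first lets the [nf, nloc, 1] atom mask broadcast regardless of output rank. A standalone NumPy sketch of the trick, with toy shapes and values:

```python
import numpy as np

# A per-atom polarizability of shape [nf, nloc, 3, 3] cannot be multiplied
# by a [nf, nloc, 1] mask directly, so flatten the trailing dims first.
nf, nloc = 2, 4
ret_dict = {"polar": np.random.rand(nf, nloc, 3, 3)}
atom_mask = np.array([[1, 1, 0, 1], [1, 0, 1, 1]])  # 0 marks excluded types

for kk in ret_dict.keys():
    out_shape = ret_dict[kk].shape
    ret_dict[kk] = (
        ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
        * atom_mask[:, :, None]
    ).reshape(out_shape)

assert ret_dict["polar"][0, 2].sum() == 0.0  # the excluded atom is zeroed
```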
12 changes: 6 additions & 6 deletions deepmd/dpmodel/infer/deep_eval.py
@@ -123,16 +123,16 @@
     @property
     def model_type(self) -> Type["DeepEvalWrapper"]:
         """The evaluator of the model type."""
-        model_type = self.dp.model_output_type()
-        if model_type == "energy":
+        model_output_type = self.dp.model_output_type()
+        if "energy" in model_output_type:
             return DeepPot
-        elif model_type == "dos":
+        elif "dos" in model_output_type:
             return DeepDOS
-        elif model_type == "dipole":
+        elif "dipole" in model_output_type:
             return DeepDipole
-        elif model_type == "polar":
+        elif "polar" in model_output_type:
             return DeepPolar
-        elif model_type == "wfc":
+        elif "wfc" in model_output_type:
             return DeepWFC
         else:
             raise RuntimeError("Unknown model type")
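Because model_output_type() now returns a list (see make_model.py below), the dispatch switches from string equality to membership tests, so a model that defines several OUT variables no longer trips the old single-output assumption. A standalone sketch of the new dispatch logic, using a toy function and hypothetical output lists:

```python
from typing import List

def pick_evaluator(model_output_type: List[str]) -> str:
    # Membership test instead of equality: works when the model
    # defines more than one OUT-category variable.
    if "energy" in model_output_type:
        return "DeepPot"
    elif "dipole" in model_output_type:
        return "DeepDipole"
    elif "polar" in model_output_type:
        return "DeepPolar"
    raise RuntimeError("Unknown model type")

print(pick_evaluator(["polar", "global_polar"]))  # DeepPolar
print(pick_evaluator(["dipole"]))                 # DeepDipole
```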
2 changes: 1 addition & 1 deletion deepmd/dpmodel/model/base_model.py
@@ -89,7 +89,7 @@ def is_aparam_nall(self) -> bool:
         """

     @abstractmethod
-    def model_output_type(self) -> str:
+    def model_output_type(self) -> List[str]:
         """Get the output type for the model."""

     @abstractmethod
9 changes: 2 additions & 7 deletions deepmd/dpmodel/model/make_model.py
@@ -76,7 +76,7 @@ def model_output_def(self):
         """Get the output def for the model."""
         return ModelOutputDef(self.atomic_output_def())

-    def model_output_type(self) -> str:
+    def model_output_type(self) -> List[str]:
         """Get the output type for the model."""
         output_def = self.model_output_def()
         var_defs = output_def.var_defs
@@ -85,12 +85,7 @@ def model_output_type(self) -> str:
             for kk, vv in var_defs.items()
             if vv.category == OutputVariableCategory.OUT
         ]
-        if len(vars) == 1:
-            return vars[0]
-        elif len(vars) == 0:
-            raise ValueError("No valid output type found")
-        else:
-            raise ValueError(f"Multiple valid output types found: {vars}")
+        return vars

     def call(
         self,
5 changes: 5 additions & 0 deletions deepmd/dpmodel/output_def.py
@@ -193,6 +193,11 @@ def __init__(
     ):
         self.name = name
         self.shape = list(shape)
+        # jit doesn't support math.prod(self.shape)
+        self.output_size = 1
+        len_shape = len(self.shape)
+        for i in range(len_shape):
+            self.output_size *= self.shape[i]
         self.atomic = atomic
         self.reduciable = reduciable
         self.r_differentiable = r_differentiable
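The unrolled loop exists only because TorchScript cannot compile math.prod; the result is the flattened output size, e.g. 9 for a [3, 3] polarizability. A quick equivalence check in plain Python:

```python
import math

shape = [3, 3]  # e.g. the per-atom polarizability

# What eager Python would write:
assert math.prod(shape) == 9

# The JIT-friendly unrolled version from the diff:
output_size = 1
for i in range(len(shape)):
    output_size *= shape[i]
assert output_size == math.prod(shape)
```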
14 changes: 7 additions & 7 deletions deepmd/pt/infer/deep_eval.py
@@ -148,18 +148,18 @@
     @property
     def model_type(self) -> "DeepEvalWrapper":
         """The evaluator of the model type."""
-        model_type = self.dp.model["Default"].model_output_type()
-        if model_type == "energy":
+        model_output_type = self.dp.model["Default"].model_output_type()
+        if "energy" in model_output_type:
             return DeepPot
-        elif model_type == "dos":
+        elif "dos" in model_output_type:
             return DeepDOS
-        elif model_type == "dipole":
+        elif "dipole" in model_output_type:
             return DeepDipole
-        elif model_type == "polar":
+        elif "polar" in model_output_type:
             return DeepPolar
-        elif model_type == "global_polar":
+        elif "global_polar" in model_output_type:
             return DeepGlobalPolar
-        elif model_type == "wfc":
+        elif "wfc" in model_output_type:
             return DeepWFC
         else:
             raise RuntimeError("Unknown model type")
4 changes: 4 additions & 0 deletions deepmd/pt/loss/__init__.py
@@ -8,9 +8,13 @@
 from .loss import (
     TaskLoss,
 )
+from .tensor import (
+    TensorLoss,
+)

 __all__ = [
     "DenoiseLoss",
     "EnergyStdLoss",
+    "TensorLoss",
     "TaskLoss",
 ]
162 changes: 162 additions & 0 deletions deepmd/pt/loss/tensor.py
@@ -0,0 +1,162 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    List,
)

import torch

from deepmd.pt.loss.loss import (
    TaskLoss,
)
from deepmd.pt.utils import (
    env,
)
from deepmd.utils.data import (
    DataRequirementItem,
)


class TensorLoss(TaskLoss):
    def __init__(
        self,
        tensor_name: str,
        tensor_size: int,
        label_name: str,
        pref_atomic: float = 0.0,
        pref: float = 0.0,
        inference=False,
        **kwargs,
    ):
        r"""Construct a loss for local and global tensors.

        Parameters
        ----------
        tensor_name : str
            The name of the tensor in the model predictions to compute the loss.
        tensor_size : int
            The size (dimension) of the tensor.
        label_name : str
            The name of the tensor in the labels to compute the loss.
        pref_atomic : float
            The prefactor of the weight of the atomic loss. It should be larger than or equal to 0.
        pref : float
            The prefactor of the weight of the global loss. It should be larger than or equal to 0.
        inference : bool
            If true, all losses found in the output are reported, ignoring the prefactors.
        **kwargs
            Other keyword arguments.
        """
        super().__init__()
        self.tensor_name = tensor_name
        self.tensor_size = tensor_size
        self.label_name = label_name
        self.local_weight = pref_atomic
        self.global_weight = pref
        self.inference = inference

        assert (
            self.local_weight >= 0.0 and self.global_weight >= 0.0
        ), "Cannot assign negative weight to `pref` and `pref_atomic`"
        self.has_local_weight = self.local_weight > 0.0 or inference
        self.has_global_weight = self.global_weight > 0.0 or inference
        assert self.has_local_weight or self.has_global_weight, AssertionError(
            "Cannot assign zero weight to both `pref` and `pref_atomic`"
        )

    def forward(self, model_pred, label, natoms, learning_rate=0.0, mae=False):
        """Return loss on local and global tensors.

        Parameters
        ----------
        model_pred : dict[str, torch.Tensor]
            Model predictions.
        label : dict[str, torch.Tensor]
            Labels.
        natoms : int
            The local atom number.

        Returns
        -------
        loss: torch.Tensor
            Loss for model to minimize.
        more_loss: dict[str, torch.Tensor]
            Other losses for display.
        """
        del learning_rate, mae
        loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
        more_loss = {}
        if (
            self.has_local_weight
            and self.tensor_name in model_pred
            and "atomic_" + self.label_name in label
        ):
            local_tensor_pred = model_pred[self.tensor_name].reshape(
                [-1, natoms, self.tensor_size]
            )
            local_tensor_label = label["atomic_" + self.label_name].reshape(
                [-1, natoms, self.tensor_size]
            )
            diff = (local_tensor_pred - local_tensor_label).reshape(
                [-1, self.tensor_size]
            )
            if "mask" in model_pred:
                diff = diff[model_pred["mask"].reshape([-1]).bool()]
            l2_local_loss = torch.mean(torch.square(diff))
            if not self.inference:
                more_loss[f"l2_local_{self.tensor_name}_loss"] = l2_local_loss.detach()
            loss += self.local_weight * l2_local_loss
            rmse_local = l2_local_loss.sqrt()
            more_loss[f"rmse_local_{self.tensor_name}"] = rmse_local.detach()
        if (
            self.has_global_weight
            and "global_" + self.tensor_name in model_pred
            and self.label_name in label
        ):
            global_tensor_pred = model_pred["global_" + self.tensor_name].reshape(
                [-1, self.tensor_size]
            )
            global_tensor_label = label[self.label_name].reshape([-1, self.tensor_size])
            diff = global_tensor_pred - global_tensor_label
            if "mask" in model_pred:
                # weight each frame by its unmasked atom count
                atom_num = model_pred["mask"].sum(-1, keepdim=True)
                l2_global_loss = torch.mean(
                    torch.sum(torch.square(diff) * atom_num, dim=0) / atom_num.sum()
                )
                atom_num = torch.mean(atom_num.float())
            else:
                atom_num = natoms
                l2_global_loss = torch.mean(torch.square(diff))
            if not self.inference:
                more_loss[
                    f"l2_global_{self.tensor_name}_loss"
                ] = l2_global_loss.detach()
            loss += self.global_weight * l2_global_loss
            rmse_global = l2_global_loss.sqrt() / atom_num
            more_loss[f"rmse_global_{self.tensor_name}"] = rmse_global.detach()
        return loss, more_loss

    @property
    def label_requirement(self) -> List[DataRequirementItem]:
        """Return data label requirements needed for this loss calculation."""
        label_requirement = []
        if self.has_local_weight:
            label_requirement.append(
                DataRequirementItem(
                    "atomic_" + self.label_name,
                    ndof=self.tensor_size,
                    atomic=True,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_global_weight:
            label_requirement.append(
                DataRequirementItem(
                    self.label_name,
                    ndof=self.tensor_size,
                    atomic=False,
                    must=False,
                    high_prec=False,
                )
            )
        return label_requirement
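A minimal usage sketch for the TensorLoss above, with random dipole-shaped data (tensor_size = 3); it assumes env.DEVICE resolves to CPU, and the "mask" entry mirrors what the pt models below forward, but all values are made up for illustration:

```python
import torch

from deepmd.pt.loss import TensorLoss  # exported via deepmd/pt/loss/__init__.py above

loss_fn = TensorLoss(
    tensor_name="dipole",
    tensor_size=3,
    label_name="dipole",
    pref_atomic=1.0,  # enables the per-atom (local) term
    pref=1.0,         # enables the frame-reduced (global) term
)

nframes, natoms = 2, 5
mask = torch.ones(nframes, natoms)
mask[0, -1] = 0.0  # pretend the last atom of frame 0 has an excluded type

model_pred = {
    "dipole": torch.rand(nframes, natoms, 3),
    "global_dipole": torch.rand(nframes, 3),
    "mask": mask,  # masked atoms are dropped from the local loss
}
label = {
    "atomic_dipole": torch.rand(nframes, natoms, 3),  # per-atom labels
    "dipole": torch.rand(nframes, 3),                 # global labels
}

loss, more_loss = loss_fn(model_pred, label, natoms)
print(loss)               # scalar to minimize
print(sorted(more_loss))  # l2/rmse entries for the local and global terms
```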
6 changes: 5 additions & 1 deletion deepmd/pt/model/atomic_model/base_atomic_model.py
@@ -112,7 +112,11 @@ def forward_common_atomic(
         if self.atom_excl is not None:
             atom_mask = self.atom_excl(atype)
             for kk in ret_dict.keys():
-                ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None]
+                out_shape = ret_dict[kk].shape
+                ret_dict[kk] = (
+                    ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
+                    * atom_mask[:, :, None]
+                ).reshape(out_shape)
             ret_dict["mask"] = atom_mask

         return ret_dict
2 changes: 1 addition & 1 deletion deepmd/pt/model/model/__init__.py
@@ -94,7 +94,7 @@ def get_model(model_params):
     fitting_net["type"] = fitting_net.get("type", "ener")
     fitting_net["ntypes"] = descriptor.get_ntypes()
     fitting_net["mixed_types"] = descriptor.mixed_types()
-    fitting_net["embedding_width"] = descriptor.get_dim_out()
+    fitting_net["embedding_width"] = descriptor.get_dim_emb()
     fitting_net["dim_descrpt"] = descriptor.get_dim_out()
     grad_force = "direct" not in fitting_net["type"]
     if not grad_force:
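This one-line fix matters for tensor fittings: embedding_width should carry the descriptor's embedding dimension from get_dim_emb(), which generally differs from the per-atom output width get_dim_out() that feeds dim_descrpt. A toy sketch of the distinction; the class and widths below are hypothetical, and only the two dict keys come from the diff:

```python
# Hypothetical descriptor with distinct widths (numbers made up).
class ToyDescriptor:
    def get_dim_out(self) -> int:
        return 400  # per-atom feature width, consumed as dim_descrpt

    def get_dim_emb(self) -> int:
        return 100  # embedding width, consumed as embedding_width

descriptor = ToyDescriptor()
fitting_net = {"type": "dipole"}
fitting_net["embedding_width"] = descriptor.get_dim_emb()  # was get_dim_out() before this PR
fitting_net["dim_descrpt"] = descriptor.get_dim_out()
print(fitting_net)  # {'type': 'dipole', 'embedding_width': 100, 'dim_descrpt': 400}
```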
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/dipole_model.py
@@ -50,6 +50,8 @@ def forward(
                     model_predict["atom_virial"] = model_ret["dipole_derv_c"].squeeze(
                         -3
                     )
+            if "mask" in model_ret:
+                model_predict["mask"] = model_ret["mask"]
         else:
             model_predict = model_ret
             model_predict["updated_coord"] += coord
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/dp_zbl_model.py
@@ -63,6 +63,8 @@ def forward(
                 model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
         else:
             model_predict["force"] = model_ret["dforce"]
+        if "mask" in model_ret:
+            model_predict["mask"] = model_ret["mask"]
         return model_predict

     @torch.jit.export
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/ener_model.py
@@ -52,6 +52,8 @@ def forward(
                     )
             else:
                 model_predict["force"] = model_ret["dforce"]
+            if "mask" in model_ret:
+                model_predict["mask"] = model_ret["mask"]
         else:
             model_predict = model_ret
             model_predict["updated_coord"] += coord
9 changes: 2 additions & 7 deletions deepmd/pt/model/model/make_model.py
@@ -75,7 +75,7 @@ def model_output_def(self):
         return ModelOutputDef(self.atomic_output_def())

     @torch.jit.export
-    def model_output_type(self) -> str:
+    def model_output_type(self) -> List[str]:
         """Get the output type for the model."""
         output_def = self.model_output_def()
         var_defs = output_def.var_defs
@@ -86,12 +86,7 @@ def model_output_type(self) -> str:
             # .value is critical for JIT
             if vv.category == OutputVariableCategory.OUT.value:
                 vars.append(kk)
-        if len(vars) == 1:
-            return vars[0]
-        elif len(vars) == 0:
-            raise ValueError("No valid output type found")
-        else:
-            raise ValueError(f"Multiple valid output types found: {vars}")
+        return vars

     # cannot use the name forward. torch script does not work
     def forward_common(
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/polar_model.py
@@ -42,6 +42,8 @@ def forward(
             model_predict = {}
             model_predict["polar"] = model_ret["polar"]
             model_predict["global_polar"] = model_ret["polar_redu"]
+            if "mask" in model_ret:
+                model_predict["mask"] = model_ret["mask"]
         else:
             model_predict = model_ret
             model_predict["updated_coord"] += coord
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/dipole.py
@@ -157,7 +157,7 @@ def compute_output_stats(
             The path to the stat file.

         """
-        raise NotImplementedError
+        pass

     def forward(
         self,