pt: Add support for dipole and polar training (#3380)
Signed-off-by: Duo <[email protected]>
iProzd authored Mar 3, 2024
1 parent 8d0e3ba commit d4ac864
Showing 90 changed files with 771 additions and 59 deletions.
6 changes: 5 additions & 1 deletion deepmd/dpmodel/atomic_model/base_atomic_model.py

@@ -101,7 +101,11 @@ def forward_common_atomic(
         if self.atom_excl is not None:
             atom_mask = self.atom_excl.build_type_exclude_mask(atype)
             for kk in ret_dict.keys():
-                ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None]
+                out_shape = ret_dict[kk].shape
+                ret_dict[kk] = (
+                    ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
+                    * atom_mask[:, :, None]
+                ).reshape(out_shape)
             ret_dict["mask"] = atom_mask

         return ret_dict

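Why the reshape: atom_mask has shape [nf, nloc], so the old atom_mask[:, :, None] broadcast only lined up with rank-3 outputs such as atomic energy [nf, nloc, 1]; a rank-4 output like the per-atom polarizability [nf, nloc, 3, 3] would either fail to broadcast or mask the wrong axis. A minimal NumPy sketch of the fix (illustrative names and sizes, not code from this commit):

import numpy as np

nf, nloc = 2, 4
atom_mask = np.array([[1, 1, 0, 1], [1, 0, 1, 1]])  # [nf, nloc], 0 = excluded type
polar = np.random.rand(nf, nloc, 3, 3)              # rank-4 per-atom output

# Flatten the trailing output axes, mask along the atom axis, restore the shape.
out_shape = polar.shape
masked = (polar.reshape(nf, nloc, -1) * atom_mask[:, :, None]).reshape(out_shape)
assert (masked[0, 2] == 0.0).all()  # the excluded atom's whole 3x3 block is zeroed
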
12 changes: 6 additions & 6 deletions deepmd/dpmodel/infer/deep_eval.py

@@ -123,16 +123,16 @@ def get_dim_aparam(self) -> int:
     @property
     def model_type(self) -> Type["DeepEvalWrapper"]:
         """The evaluator of the model type."""
-        model_type = self.dp.model_output_type()
-        if model_type == "energy":
+        model_output_type = self.dp.model_output_type()
+        if "energy" in model_output_type:
             return DeepPot
-        elif model_type == "dos":
+        elif "dos" in model_output_type:
             return DeepDOS
-        elif model_type == "dipole":
+        elif "dipole" in model_output_type:
             return DeepDipole
-        elif model_type == "polar":
+        elif "polar" in model_output_type:
             return DeepPolar
-        elif model_type == "wfc":
+        elif "wfc" in model_output_type:
             return DeepWFC
         else:
             raise RuntimeError("Unknown model type")

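Because model_output_type() now returns a list of every OUT-category variable (see make_model.py below) instead of a single string, the dispatch switches from equality to membership tests. A hypothetical multi-head model makes the difference visible; under the old code such a model raised "Multiple valid output types found":

# Hypothetical return value, for illustration only.
model_output_type = ["energy", "dipole"]

# Old check: model_output_type == "energy" is False for a list.
# New check: membership, so the first matching branch above wins.
assert "energy" in model_output_type
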
2 changes: 1 addition & 1 deletion deepmd/dpmodel/model/base_model.py

@@ -89,7 +89,7 @@ def is_aparam_nall(self) -> bool:
         """

     @abstractmethod
-    def model_output_type(self) -> str:
+    def model_output_type(self) -> List[str]:
         """Get the output type for the model."""

     @abstractmethod

9 changes: 2 additions & 7 deletions deepmd/dpmodel/model/make_model.py

@@ -76,7 +76,7 @@ def model_output_def(self):
         """Get the output def for the model."""
         return ModelOutputDef(self.atomic_output_def())

-    def model_output_type(self) -> str:
+    def model_output_type(self) -> List[str]:
         """Get the output type for the model."""
         output_def = self.model_output_def()
         var_defs = output_def.var_defs
@@ -85,12 +85,7 @@ def model_output_type(self) -> str:
             for kk, vv in var_defs.items()
             if vv.category == OutputVariableCategory.OUT
         ]
-        if len(vars) == 1:
-            return vars[0]
-        elif len(vars) == 0:
-            raise ValueError("No valid output type found")
-        else:
-            raise ValueError(f"Multiple valid output types found: {vars}")
+        return vars

     def call(
         self,

5 changes: 5 additions & 0 deletions deepmd/dpmodel/output_def.py

@@ -193,6 +193,11 @@ def __init__(
     ):
         self.name = name
         self.shape = list(shape)
+        # jit doesn't support math.prod(self.shape)
+        self.output_size = 1
+        len_shape = len(self.shape)
+        for i in range(len_shape):
+            self.output_size *= self.shape[i]
         self.atomic = atomic
         self.reduciable = reduciable
         self.r_differentiable = r_differentiable

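The loop is a TorchScript-friendly substitute for math.prod, which the JIT compiler cannot script, as the added comment notes; eager-mode Python would compute the same product in one call. A quick sketch of the equivalence (the [3, 3] shape is just an example, e.g. a per-atom polarizability):

import math

shape = [3, 3]
assert math.prod(shape) == 9  # fine in eager mode, not scriptable

output_size = 1               # the jit-safe form used above
for i in range(len(shape)):
    output_size *= shape[i]
assert output_size == 9
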
14 changes: 7 additions & 7 deletions deepmd/pt/infer/deep_eval.py

@@ -148,18 +148,18 @@ def get_dim_aparam(self) -> int:
     @property
     def model_type(self) -> "DeepEvalWrapper":
         """The evaluator of the model type."""
-        model_type = self.dp.model["Default"].model_output_type()
-        if model_type == "energy":
+        model_output_type = self.dp.model["Default"].model_output_type()
+        if "energy" in model_output_type:
             return DeepPot
-        elif model_type == "dos":
+        elif "dos" in model_output_type:
             return DeepDOS
-        elif model_type == "dipole":
+        elif "dipole" in model_output_type:
             return DeepDipole
-        elif model_type == "polar":
+        elif "polar" in model_output_type:
             return DeepPolar
-        elif model_type == "global_polar":
+        elif "global_polar" in model_output_type:
             return DeepGlobalPolar
-        elif model_type == "wfc":
+        elif "wfc" in model_output_type:
             return DeepWFC
         else:
             raise RuntimeError("Unknown model type")

4 changes: 4 additions & 0 deletions deepmd/pt/loss/__init__.py

@@ -8,9 +8,13 @@
 from .loss import (
     TaskLoss,
 )
+from .tensor import (
+    TensorLoss,
+)

 __all__ = [
     "DenoiseLoss",
     "EnergyStdLoss",
+    "TensorLoss",
     "TaskLoss",
 ]

162 changes: 162 additions & 0 deletions deepmd/pt/loss/tensor.py

@@ -0,0 +1,162 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    List,
+)
+
+import torch
+
+from deepmd.pt.loss.loss import (
+    TaskLoss,
+)
+from deepmd.pt.utils import (
+    env,
+)
+from deepmd.utils.data import (
+    DataRequirementItem,
+)
+
+
+class TensorLoss(TaskLoss):
+    def __init__(
+        self,
+        tensor_name: str,
+        tensor_size: int,
+        label_name: str,
+        pref_atomic: float = 0.0,
+        pref: float = 0.0,
+        inference=False,
+        **kwargs,
+    ):
+        r"""Construct a loss for local and global tensors.
+
+        Parameters
+        ----------
+        tensor_name : str
+            The name of the tensor in the model predictions to compute the loss.
+        tensor_size : int
+            The size (dimension) of the tensor.
+        label_name : str
+            The name of the tensor in the labels to compute the loss.
+        pref_atomic : float
+            The prefactor of the weight of atomic loss. It should be larger than or equal to 0.
+        pref : float
+            The prefactor of the weight of global loss. It should be larger than or equal to 0.
+        inference : bool
+            If true, it will output all losses found in output, ignoring the pre-factors.
+        **kwargs
+            Other keyword arguments.
+        """
+        super().__init__()
+        self.tensor_name = tensor_name
+        self.tensor_size = tensor_size
+        self.label_name = label_name
+        self.local_weight = pref_atomic
+        self.global_weight = pref
+        self.inference = inference
+
+        assert (
+            self.local_weight >= 0.0 and self.global_weight >= 0.0
+        ), "Can not assign negative weight to `pref` and `pref_atomic`"
+        self.has_local_weight = self.local_weight > 0.0 or inference
+        self.has_global_weight = self.global_weight > 0.0 or inference
+        assert self.has_local_weight or self.has_global_weight, AssertionError(
+            "Can not assign zero weight both to `pref` and `pref_atomic`"
+        )
+
+    def forward(self, model_pred, label, natoms, learning_rate=0.0, mae=False):
+        """Return loss on local and global tensors.
+
+        Parameters
+        ----------
+        model_pred : dict[str, torch.Tensor]
+            Model predictions.
+        label : dict[str, torch.Tensor]
+            Labels.
+        natoms : int
+            The local atom number.
+
+        Returns
+        -------
+        loss: torch.Tensor
+            Loss for model to minimize.
+        more_loss: dict[str, torch.Tensor]
+            Other losses for display.
+        """
+        del learning_rate, mae
+        loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
+        more_loss = {}
+        if (
+            self.has_local_weight
+            and self.tensor_name in model_pred
+            and "atomic_" + self.label_name in label
+        ):
+            local_tensor_pred = model_pred[self.tensor_name].reshape(
+                [-1, natoms, self.tensor_size]
+            )
+            local_tensor_label = label["atomic_" + self.label_name].reshape(
+                [-1, natoms, self.tensor_size]
+            )
+            diff = (local_tensor_pred - local_tensor_label).reshape(
+                [-1, self.tensor_size]
+            )
+            if "mask" in model_pred:
+                diff = diff[model_pred["mask"].reshape([-1]).bool()]
+            l2_local_loss = torch.mean(torch.square(diff))
+            if not self.inference:
+                more_loss[f"l2_local_{self.tensor_name}_loss"] = l2_local_loss.detach()
+            loss += self.local_weight * l2_local_loss
+            rmse_local = l2_local_loss.sqrt()
+            more_loss[f"rmse_local_{self.tensor_name}"] = rmse_local.detach()
+        if (
+            self.has_global_weight
+            and "global_" + self.tensor_name in model_pred
+            and self.label_name in label
+        ):
+            global_tensor_pred = model_pred["global_" + self.tensor_name].reshape(
+                [-1, self.tensor_size]
+            )
+            global_tensor_label = label[self.label_name].reshape([-1, self.tensor_size])
+            diff = global_tensor_pred - global_tensor_label
+            if "mask" in model_pred:
+                atom_num = model_pred["mask"].sum(-1, keepdim=True)
+                l2_global_loss = torch.mean(
+                    torch.sum(torch.square(diff) * atom_num, dim=0) / atom_num.sum()
+                )
+                atom_num = torch.mean(atom_num.float())
+            else:
+                atom_num = natoms
+                l2_global_loss = torch.mean(torch.square(diff))
+            if not self.inference:
+                more_loss[
+                    f"l2_global_{self.tensor_name}_loss"
+                ] = l2_global_loss.detach()
+            loss += self.global_weight * l2_global_loss
+            rmse_global = l2_global_loss.sqrt() / atom_num
+            more_loss[f"rmse_global_{self.tensor_name}"] = rmse_global.detach()
+        return loss, more_loss
+
+    @property
+    def label_requirement(self) -> List[DataRequirementItem]:
+        """Return data label requirements needed for this loss calculation."""
+        label_requirement = []
+        if self.has_local_weight:
+            label_requirement.append(
+                DataRequirementItem(
+                    "atomic_" + self.label_name,
+                    ndof=self.tensor_size,
+                    atomic=True,
+                    must=False,
+                    high_prec=False,
+                )
+            )
+        if self.has_global_weight:
+            label_requirement.append(
+                DataRequirementItem(
+                    self.label_name,
+                    ndof=self.tensor_size,
+                    atomic=False,
+                    must=False,
+                    high_prec=False,
+                )
+            )
+        return label_requirement
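
A minimal driving sketch for the new loss (CPU tensors, random data, and the "polarizability" label name are illustrative assumptions; shapes follow the reshape calls above):

import torch
from deepmd.pt.loss import TensorLoss

loss_fn = TensorLoss(
    tensor_name="polar",
    tensor_size=9,                 # flattened 3x3 polarizability
    label_name="polarizability",
    pref_atomic=1.0,               # weight on the per-atom term
    pref=1.0,                      # weight on the global (reduced) term
)
nf, natoms = 2, 5
model_pred = {
    "polar": torch.rand(nf, natoms, 9),
    "global_polar": torch.rand(nf, 9),
}
label = {
    "atomic_polarizability": torch.rand(nf, natoms, 9),  # local branch
    "polarizability": torch.rand(nf, 9),                 # global branch
}
loss, more_loss = loss_fn(model_pred, label, natoms)
print(sorted(more_loss))  # l2/rmse terms for the local and global branches

Note the naming convention the two branches rely on: the per-atom label is "atomic_" + label_name, while the global prediction is "global_" + tensor_name.
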
6 changes: 5 additions & 1 deletion deepmd/pt/model/atomic_model/base_atomic_model.py

@@ -112,7 +112,11 @@ def forward_common_atomic(
         if self.atom_excl is not None:
            atom_mask = self.atom_excl(atype)
             for kk in ret_dict.keys():
-                ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None]
+                out_shape = ret_dict[kk].shape
+                ret_dict[kk] = (
+                    ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
+                    * atom_mask[:, :, None]
+                ).reshape(out_shape)
             ret_dict["mask"] = atom_mask

         return ret_dict

2 changes: 1 addition & 1 deletion deepmd/pt/model/model/__init__.py

@@ -94,7 +94,7 @@ def get_model(model_params):
     fitting_net["type"] = fitting_net.get("type", "ener")
     fitting_net["ntypes"] = descriptor.get_ntypes()
     fitting_net["mixed_types"] = descriptor.mixed_types()
-    fitting_net["embedding_width"] = descriptor.get_dim_out()
+    fitting_net["embedding_width"] = descriptor.get_dim_emb()
     fitting_net["dim_descrpt"] = descriptor.get_dim_out()
     grad_force = "direct" not in fitting_net["type"]
     if not grad_force:

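This one-line fix matters for the tensor fittings: embedding_width is the size of the rotation-equivariant basis the dipole/polar heads contract against (descriptor.get_dim_emb()), while dim_descrpt is the per-atom scalar feature size (descriptor.get_dim_out()), and for many descriptors the two differ. A hedged shape sketch of the contraction, not the fitting-net implementation:

import torch

nf, nloc, dim_descrpt, dim_emb = 1, 4, 64, 32
feats = torch.rand(nf, nloc, dim_descrpt)      # scalar features, get_dim_out()
gr = torch.rand(nf, nloc, dim_emb, 3)          # equivariant basis, get_dim_emb()

head = torch.nn.Linear(dim_descrpt, dim_emb)   # output width must be dim_emb
coeff = head(feats)                            # [nf, nloc, dim_emb]
dipole = (coeff.unsqueeze(-1) * gr).sum(-2)    # [nf, nloc, 3]
# With the old embedding_width = get_dim_out(), the head would emit
# dim_descrpt coefficients and could not contract against gr's dim_emb axis.
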
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/dipole_model.py

@@ -50,6 +50,8 @@ def forward(
                 model_predict["atom_virial"] = model_ret["dipole_derv_c"].squeeze(
                     -3
                 )
+            if "mask" in model_ret:
+                model_predict["mask"] = model_ret["mask"]
         else:
             model_predict = model_ret
             model_predict["updated_coord"] += coord

2 changes: 2 additions & 0 deletions deepmd/pt/model/model/dp_zbl_model.py

@@ -63,6 +63,8 @@ def forward(
                 model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
         else:
             model_predict["force"] = model_ret["dforce"]
+        if "mask" in model_ret:
+            model_predict["mask"] = model_ret["mask"]
         return model_predict

     @torch.jit.export

2 changes: 2 additions & 0 deletions deepmd/pt/model/model/ener_model.py

@@ -52,6 +52,8 @@ def forward(
                 )
             else:
                 model_predict["force"] = model_ret["dforce"]
+            if "mask" in model_ret:
+                model_predict["mask"] = model_ret["mask"]
         else:
             model_predict = model_ret
             model_predict["updated_coord"] += coord

9 changes: 2 additions & 7 deletions deepmd/pt/model/model/make_model.py

@@ -75,7 +75,7 @@ def model_output_def(self):
         return ModelOutputDef(self.atomic_output_def())

     @torch.jit.export
-    def model_output_type(self) -> str:
+    def model_output_type(self) -> List[str]:
         """Get the output type for the model."""
         output_def = self.model_output_def()
         var_defs = output_def.var_defs
@@ -86,12 +86,7 @@ def model_output_type(self) -> str:
             # .value is critical for JIT
             if vv.category == OutputVariableCategory.OUT.value:
                 vars.append(kk)
-        if len(vars) == 1:
-            return vars[0]
-        elif len(vars) == 0:
-            raise ValueError("No valid output type found")
-        else:
-            raise ValueError(f"Multiple valid output types found: {vars}")
+        return vars

     # cannot use the name forward. torch script does not work
     def forward_common(

2 changes: 2 additions & 0 deletions deepmd/pt/model/model/polar_model.py

@@ -42,6 +42,8 @@ def forward(
             model_predict = {}
             model_predict["polar"] = model_ret["polar"]
             model_predict["global_polar"] = model_ret["polar_redu"]
+            if "mask" in model_ret:
+                model_predict["mask"] = model_ret["mask"]
         else:
             model_predict = model_ret
             model_predict["updated_coord"] += coord

2 changes: 1 addition & 1 deletion deepmd/pt/model/task/dipole.py

@@ -157,7 +157,7 @@ def compute_output_stats(
             The path to the stat file.
         """
-        raise NotImplementedError
+        pass

     def forward(
         self,