Skip to content

Commit

Permalink
Merge branch 'devel' into enerhess
Browse files Browse the repository at this point in the history
Signed-off-by: Anchor Yu <[email protected]>
  • Loading branch information
1azyking authored Dec 25, 2024
2 parents 5f07d2b + beeb3d9 commit e95c140
Show file tree
Hide file tree
Showing 77 changed files with 5,664 additions and 217 deletions.
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
from .polar_atomic_model import (
DPPolarAtomicModel,
)
from .property_atomic_model import (
DPPropertyAtomicModel,
)

__all__ = [
"BaseAtomicModel",
Expand All @@ -50,6 +53,7 @@
"DPDipoleAtomicModel",
"DPEnergyAtomicModel",
"DPPolarAtomicModel",
"DPPropertyAtomicModel",
"DPZBLLinearEnergyAtomicModel",
"LinearEnergyAtomicModel",
"PairTabAtomicModel",
Expand Down
24 changes: 24 additions & 0 deletions deepmd/dpmodel/atomic_model/property_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import numpy as np

from deepmd.dpmodel.fitting.property_fitting import (
PropertyFittingNet,
)
Expand All @@ -15,3 +17,25 @@ def __init__(self, descriptor, fitting, type_map, **kwargs):
"fitting must be an instance of PropertyFittingNet for DPPropertyAtomicModel"
)
super().__init__(descriptor, fitting, type_map, **kwargs)

def apply_out_stat(
self,
ret: dict[str, np.ndarray],
atype: np.ndarray,
):
"""Apply the stat to each atomic output.
In property fitting, each output will be multiplied by label std and then plus the label average value.
Parameters
----------
ret
The returned dict by the forward_atomic method
atype
The atom types. nf x nloc. It is useless in property fitting.
"""
out_bias, out_std = self._fetch_out_stat(self.bias_keys)
for kk in self.bias_keys:
ret[kk] = ret[kk] * out_std[kk][0] + out_bias[kk][0]
return ret
18 changes: 9 additions & 9 deletions deepmd/dpmodel/fitting/property_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,9 @@ class PropertyFittingNet(InvarFitting):
this list is of length :math:`N_l + 1`, specifying if the hidden layers and the output layer are trainable.
intensive
Whether the fitting property is intensive.
bias_method
The method of applying the bias to each atomic output, user can select 'normal' or 'no_bias'.
If 'normal' is used, the computed bias will be added to the atomic output.
If 'no_bias' is used, no bias will be added to the atomic output.
property_name:
The name of fitting property, which should be consistent with the property name in the dataset.
If the data file is named `humo.npy`, this parameter should be "humo".
resnet_dt
Time-step `dt` in the resnet construction:
:math:`y = x + dt * \phi (Wx + b)`
Expand Down Expand Up @@ -74,7 +73,7 @@ def __init__(
rcond: Optional[float] = None,
trainable: Union[bool, list[bool]] = True,
intensive: bool = False,
bias_method: str = "normal",
property_name: str = "property",
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
Expand All @@ -89,9 +88,8 @@ def __init__(
) -> None:
self.task_dim = task_dim
self.intensive = intensive
self.bias_method = bias_method
super().__init__(
var_name="property",
var_name=property_name,
ntypes=ntypes,
dim_descrpt=dim_descrpt,
dim_out=task_dim,
Expand All @@ -113,9 +111,9 @@ def __init__(
@classmethod
def deserialize(cls, data: dict) -> "PropertyFittingNet":
data = data.copy()
check_version_compatibility(data.pop("@version"), 3, 1)
check_version_compatibility(data.pop("@version"), 4, 1)
data.pop("dim_out")
data.pop("var_name")
data["property_name"] = data.pop("var_name")
data.pop("tot_ener_zero")
data.pop("layer_name")
data.pop("use_aparam_as_mask", None)
Expand All @@ -131,6 +129,8 @@ def serialize(self) -> dict:
**InvarFitting.serialize(self),
"type": "property",
"task_dim": self.task_dim,
"intensive": self.intensive,
}
dd["@version"] = 4

return dd
6 changes: 3 additions & 3 deletions deepmd/dpmodel/model/property_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
DPAtomicModel,
from deepmd.dpmodel.atomic_model import (
DPPropertyAtomicModel,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
Expand All @@ -13,7 +13,7 @@
make_model,
)

DPPropertyModel_ = make_model(DPAtomicModel)
DPPropertyModel_ = make_model(DPPropertyAtomicModel)


@BaseModel.register("property")
Expand Down
20 changes: 14 additions & 6 deletions deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,9 +814,17 @@ def test_property(
tuple[list[np.ndarray], list[int]]
arrays with results and their shapes
"""
data.add("property", dp.task_dim, atomic=False, must=True, high_prec=True)
var_name = dp.get_var_name()
assert isinstance(var_name, str)
data.add(var_name, dp.task_dim, atomic=False, must=True, high_prec=True)
if has_atom_property:
data.add("atom_property", dp.task_dim, atomic=True, must=False, high_prec=True)
data.add(
f"atom_{var_name}",
dp.task_dim,
atomic=True,
must=False,
high_prec=True,
)

if dp.get_dim_fparam() > 0:
data.add(
Expand Down Expand Up @@ -867,12 +875,12 @@ def test_property(
aproperty = ret[1]
aproperty = aproperty.reshape([numb_test, natoms * dp.task_dim])

diff_property = property - test_data["property"][:numb_test]
diff_property = property - test_data[var_name][:numb_test]
mae_property = mae(diff_property)
rmse_property = rmse(diff_property)

if has_atom_property:
diff_aproperty = aproperty - test_data["atom_property"][:numb_test]
diff_aproperty = aproperty - test_data[f"atom_{var_name}"][:numb_test]
mae_aproperty = mae(diff_aproperty)
rmse_aproperty = rmse(diff_aproperty)

Expand All @@ -889,7 +897,7 @@ def test_property(
detail_path = Path(detail_file)

for ii in range(numb_test):
test_out = test_data["property"][ii].reshape(-1, 1)
test_out = test_data[var_name][ii].reshape(-1, 1)
pred_out = property[ii].reshape(-1, 1)

frame_output = np.hstack((test_out, pred_out))
Expand All @@ -903,7 +911,7 @@ def test_property(

if has_atom_property:
for ii in range(numb_test):
test_out = test_data["atom_property"][ii].reshape(-1, 1)
test_out = test_data[f"atom_{var_name}"][ii].reshape(-1, 1)
pred_out = aproperty[ii].reshape(-1, 1)

frame_output = np.hstack((test_out, pred_out))
Expand Down
6 changes: 4 additions & 2 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ class DeepEvalBackend(ABC):
"dipole_derv_c_redu": "virial",
"dos": "atom_dos",
"dos_redu": "dos",
"property": "atom_property",
"property_redu": "property",
"mask_mag": "mask_mag",
"mask": "mask",
# old models in v1
Expand Down Expand Up @@ -281,6 +279,10 @@ def get_has_hessian(self):
"""Check if the model has hessian."""
return False

def get_var_name(self) -> str:
"""Get the name of the fitting property."""
raise NotImplementedError

@abstractmethod
def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model. Only used in old implement."""
Expand Down
44 changes: 33 additions & 11 deletions deepmd/infer/deep_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,41 @@ class DeepProperty(DeepEval):
Keyword arguments.
"""

@property
def output_def(self) -> ModelOutputDef:
"""Get the output definition of this model."""
return ModelOutputDef(
"""
Get the output definition of this model.
But in property_fitting, the output definition is not known until the model is loaded.
So we need to rewrite the output definition after the model is loaded.
See detail in change_output_def.
"""
pass

def change_output_def(self) -> None:
"""
Change the output definition of this model.
In property_fitting, the output definition is known after the model is loaded.
We need to rewrite the output definition and related information.
"""
self.output_def = ModelOutputDef(
FittingOutputDef(
[
OutputVariableDef(
"property",
shape=[-1],
self.get_var_name(),
shape=[self.get_task_dim()],
reducible=True,
atomic=True,
intensive=self.get_intensive(),
),
]
)
)

def change_output_def(self) -> None:
self.output_def["property"].shape = self.task_dim
self.output_def["property"].intensive = self.get_intensive()
self.deep_eval.output_def = self.output_def
self.deep_eval._OUTDEF_DP2BACKEND[self.get_var_name()] = (
f"atom_{self.get_var_name()}"
)
self.deep_eval._OUTDEF_DP2BACKEND[f"{self.get_var_name()}_redu"] = (
self.get_var_name()
)

@property
def task_dim(self) -> int:
Expand Down Expand Up @@ -120,10 +136,12 @@ def eval(
aparam=aparam,
**kwargs,
)
atomic_property = results["property"].reshape(
atomic_property = results[self.get_var_name()].reshape(
nframes, natoms, self.get_task_dim()
)
property = results["property_redu"].reshape(nframes, self.get_task_dim())
property = results[f"{self.get_var_name()}_redu"].reshape(
nframes, self.get_task_dim()
)

if atomic:
return (
Expand All @@ -141,5 +159,9 @@ def get_intensive(self) -> bool:
"""Get whether the property is intensive."""
return self.deep_eval.get_intensive()

def get_var_name(self) -> str:
"""Get the name of the fitting property."""
return self.deep_eval.get_var_name()


__all__ = ["DeepProperty"]
2 changes: 1 addition & 1 deletion deepmd/pd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def prepare_trainer_input_single(

# validation and training data
# avoid the same batch sequence among devices
rank_seed = (seed + rank) % (2**32) if seed is not None else None
rank_seed = [rank, seed % (2**32)] if seed is not None else None
validation_data_single = (
DpLoaderSet(
validation_systems,
Expand Down
14 changes: 14 additions & 0 deletions deepmd/pd/model/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,34 @@
DescrptBlockSeAtten,
DescrptDPA1,
)
from .dpa2 import (
DescrptDPA2,
)
from .env_mat import (
prod_env_mat,
)
from .repformers import (
DescrptBlockRepformers,
)
from .se_a import (
DescrptBlockSeA,
DescrptSeA,
)
from .se_t_tebd import (
DescrptBlockSeTTebd,
DescrptSeTTebd,
)

__all__ = [
"BaseDescriptor",
"DescriptorBlock",
"DescrptBlockRepformers",
"DescrptBlockSeA",
"DescrptBlockSeAtten",
"DescrptBlockSeTTebd",
"DescrptDPA1",
"DescrptDPA2",
"DescrptSeA",
"DescrptSeTTebd",
"prod_env_mat",
]
Loading

0 comments on commit e95c140

Please sign in to comment.