diff --git a/deepmd/dpmodel/model/dp_model.py b/deepmd/dpmodel/model/dp_model.py index 804ce51dfd..705750414b 100644 --- a/deepmd/dpmodel/model/dp_model.py +++ b/deepmd/dpmodel/model/dp_model.py @@ -14,4 +14,6 @@ # use "class" to resolve "Variable not allowed in type expression" @BaseModel.register("standard") class DPModel(make_model(DPAtomicModel), BaseModel): - pass + def data_requirement(self) -> dict: + """Get the data requirement for the model.""" + raise NotImplementedError diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index 68ce09a080..08a3673a8c 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging from typing import ( + Callable, List, Optional, + Union, ) import torch @@ -91,7 +93,7 @@ def __init__( self.r_differentiable = r_differentiable self.c_differentiable = c_differentiable super().__init__( - var_name="dipole", + var_name="dipole" if "var_name" not in kwargs else kwargs.pop("var_name"), ntypes=ntypes, dim_descrpt=dim_descrpt, neuron=neuron, @@ -134,7 +136,11 @@ def output_def(self) -> FittingOutputDef: ] ) - def compute_output_stats(self, merged, stat_file_path: Optional[DPPath] = None): + def compute_output_stats( + self, + merged: Union[Callable, List[dict]], + stat_file_path: Optional[DPPath] = None, + ): raise NotImplementedError def forward( diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index ed9d517763..55ee79db25 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -2,9 +2,11 @@ import copy import logging from typing import ( + Callable, List, Optional, Tuple, + Union, ) import numpy as np @@ -138,13 +140,21 @@ def serialize(self) -> dict: data["atom_ener"] = self.atom_ener return data - def compute_output_stats(self, merged, stat_file_path: Optional[DPPath] = None): + def compute_output_stats( + self, + merged: Union[Callable, List[dict]], + stat_file_path: Optional[DPPath] = None, + ): if stat_file_path is not None: stat_file_path = stat_file_path / "bias_atom_e" if stat_file_path is not None and stat_file_path.is_file(): bias_atom_e = stat_file_path.load_numpy() else: - sampled = merged() + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged energy = [item["energy"] for item in sampled] data_mixed_type = "real_natoms_vec" in sampled[0] if data_mixed_type: diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 1bc4798c48..0fe817084e 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging from typing import ( + Callable, List, Optional, Union, @@ -24,6 +25,9 @@ from deepmd.pt.utils.utils import ( to_numpy_array, ) +from deepmd.utils.path import ( + DPPath, +) log = logging.getLogger(__name__) @@ -72,7 +76,6 @@ class PolarFittingNet(GeneralFitting): def __init__( self, - var_name: str, ntypes: int, dim_descrpt: int, embedding_width: int, @@ -112,7 +115,7 @@ def __init__( ).view(ntypes, 1) self.shift_diag = shift_diag super().__init__( - var_name=var_name, + var_name="polar" if "var_name" not in kwargs else kwargs.pop("var_name"), ntypes=ntypes, dim_descrpt=dim_descrpt, neuron=neuron, @@ -160,6 +163,13 @@ def output_def(self) -> FittingOutputDef: ] ) + def compute_output_stats( + self, + merged: Union[Callable, List[dict]], + stat_file_path: Optional[DPPath] = None, + ): + raise NotImplementedError + def forward( self, descriptor: torch.Tensor, diff --git a/source/tests/pt/model/test_descriptor.py b/source/tests/pt/model/test_descriptor.py index ffad27201a..7d21d1c13d 100644 --- a/source/tests/pt/model/test_descriptor.py +++ b/source/tests/pt/model/test_descriptor.py @@ -38,6 +38,9 @@ op_module, ) +from ..test_stat import ( + energy_data_requirement, +) from .test_embedding_net import ( get_single_batch, ) @@ -114,6 +117,7 @@ def setUp(self): self.systems[0], model_config["type_map"], ) + ds.add_data_requirement(energy_data_requirement) self.np_batch, self.pt_batch = get_single_batch(ds) self.sec = np.cumsum(self.sel) self.ntypes = len(self.sel) diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index 83054f1042..fa4be9171c 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -114,12 +114,12 @@ def test_consistency( ) ret2 = ft2(rd0, atype, gr, fparam=ifp, aparam=iap) np.testing.assert_allclose( - to_numpy_array(ret0["foo"]), - ret1["foo"], + to_numpy_array(ret0["dipole"]), + ret1["dipole"], ) np.testing.assert_allclose( - to_numpy_array(ret0["foo"]), - to_numpy_array(ret2["foo"]), + to_numpy_array(ret0["dipole"]), + to_numpy_array(ret2["dipole"]), ) def test_jit( @@ -206,7 +206,7 @@ def test_rot(self): ) ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap) - res.append(ret0["foo"]) + res.append(ret0["dipole"]) np.testing.assert_allclose( to_numpy_array(res[1]), to_numpy_array(torch.matmul(res[0], rmat)) @@ -241,7 +241,7 @@ def test_permu(self): ) ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) - res.append(ret0["foo"]) + res.append(ret0["dipole"]) np.testing.assert_allclose( to_numpy_array(res[0][:, idx_perm]), to_numpy_array(res[1]) @@ -281,7 +281,7 @@ def test_trans(self): ) ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) - res.append(ret0["foo"]) + res.append(ret0["dipole"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) diff --git a/source/tests/pt/model/test_embedding_net.py b/source/tests/pt/model/test_embedding_net.py index 87e8a97444..a1895718dd 100644 --- a/source/tests/pt/model/test_embedding_net.py +++ b/source/tests/pt/model/test_embedding_net.py @@ -39,6 +39,10 @@ ) from deepmd.tf.descriptor import DescrptSeA as DescrptSeA_tf +from ..test_stat import ( + energy_data_requirement, +) + CUR_DIR = os.path.dirname(__file__) @@ -128,6 +132,7 @@ def setUp(self): self.systems[0], model_config["type_map"], ) + ds.add_data_requirement(energy_data_requirement) self.filter_neuron = model_config["descriptor"]["neuron"] self.axis_neuron = model_config["descriptor"]["axis_neuron"] self.np_batch, self.torch_batch = get_single_batch(ds) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index f76a9e28ac..3b55f8bc05 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -67,7 +67,6 @@ def test_consistency( [None, self.scale], ): ft0 = PolarFittingNet( - "foo", self.nt, self.dd0.dim_out, embedding_width=self.dd0.get_dim_emb(), @@ -113,16 +112,16 @@ def test_consistency( aparam=to_numpy_array(iap), ) np.testing.assert_allclose( - to_numpy_array(ret0["foo"]), - ret1["foo"], + to_numpy_array(ret0["polar"]), + ret1["polar"], ) np.testing.assert_allclose( - to_numpy_array(ret0["foo"]), - to_numpy_array(ret2["foo"]), + to_numpy_array(ret0["polar"]), + to_numpy_array(ret2["polar"]), ) np.testing.assert_allclose( - to_numpy_array(ret0["foo"]), - ret3["foo"], + to_numpy_array(ret0["polar"]), + ret3["polar"], ) def test_jit( @@ -135,7 +134,6 @@ def test_jit( [True, False], ): ft0 = PolarFittingNet( - "foo", self.nt, self.dd0.dim_out, embedding_width=self.dd0.get_dim_emb(), @@ -177,7 +175,6 @@ def test_rot(self): [None, self.scale], ): ft0 = PolarFittingNet( - "foo", self.nt, self.dd0.dim_out, # dim_descrpt embedding_width=self.dd0.get_dim_emb(), @@ -220,7 +217,7 @@ def test_rot(self): ) ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap) - res.append(ret0["foo"]) + res.append(ret0["polar"]) np.testing.assert_allclose( to_numpy_array(res[1]), to_numpy_array( @@ -235,7 +232,6 @@ def test_permu(self): coord = torch.matmul(self.coord, self.cell) for fit_diag, scale in itertools.product([True, False], [None, self.scale]): ft0 = PolarFittingNet( - "foo", self.nt, self.dd0.dim_out, embedding_width=self.dd0.get_dim_emb(), @@ -264,7 +260,7 @@ def test_permu(self): ) ret0 = ft0(rd0, extended_atype, gr0, fparam=None, aparam=None) - res.append(ret0["foo"]) + res.append(ret0["polar"]) np.testing.assert_allclose( to_numpy_array(res[0][:, idx_perm]), @@ -281,7 +277,6 @@ def test_trans(self): ) for fit_diag, scale in itertools.product([True, False], [None, self.scale]): ft0 = PolarFittingNet( - "foo", self.nt, self.dd0.dim_out, embedding_width=self.dd0.get_dim_emb(), @@ -309,7 +304,7 @@ def test_trans(self): ) ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) - res.append(ret0["foo"]) + res.append(ret0["polar"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) diff --git a/source/tests/pt/test_loss.py b/source/tests/pt/test_loss.py index e117c7f05a..484d62a3ad 100644 --- a/source/tests/pt/test_loss.py +++ b/source/tests/pt/test_loss.py @@ -28,6 +28,9 @@ from .model.test_embedding_net import ( get_single_batch, ) +from .test_stat import ( + energy_data_requirement, +) CUR_DIR = os.path.dirname(__file__) @@ -47,6 +50,7 @@ def get_batch(): if isinstance(systems, str): systems = expand_sys_str(systems) dataset = DeepmdDataSetForLoader(systems[0], model_config["type_map"]) + dataset.add_data_requirement(energy_data_requirement) np_batch, pt_batch = get_single_batch(dataset) return np_batch, pt_batch diff --git a/source/tests/pt/test_stat.py b/source/tests/pt/test_stat.py index 318b2e042f..54810fcc8f 100644 --- a/source/tests/pt/test_stat.py +++ b/source/tests/pt/test_stat.py @@ -47,6 +47,40 @@ CUR_DIR = os.path.dirname(__file__) +energy_data_requirement = { + "energy": { + "ndof": 1, + "atomic": False, + "must": False, + "high_prec": True, + }, + "force": { + "ndof": 3, + "atomic": True, + "must": False, + "high_prec": False, + }, + "virial": { + "ndof": 9, + "atomic": False, + "must": False, + "high_prec": False, + }, + "atom_ener": { + "ndof": 1, + "atomic": True, + "must": False, + "high_prec": False, + }, + "atom_pref": { + "ndof": 1, + "atomic": True, + "must": False, + "high_prec": False, + "repeat": 3, + }, +} + def compare(ut, base, given): if isinstance(base, list): @@ -111,6 +145,7 @@ def setUp(self): self.filter_neuron = model_config["descriptor"]["neuron"] self.axis_neuron = model_config["descriptor"]["axis_neuron"] self.n_neuron = model_config["fitting_net"]["neuron"] + self.my_dataset.add_data_requirement(energy_data_requirement) self.my_sampled = my_make( self.my_dataset.systems, self.my_dataset.dataloaders, self.data_stat_nbatch