Skip to content

Commit

Permalink
Fix uts
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 28, 2024
1 parent 75da5b1 commit 6c171c5
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 28 deletions.
4 changes: 3 additions & 1 deletion deepmd/dpmodel/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@
# use "class" to resolve "Variable not allowed in type expression"
@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel), BaseModel):
pass
def data_requirement(self) -> dict:
"""Get the data requirement for the model."""
raise NotImplementedError

Check warning on line 19 in deepmd/dpmodel/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/dp_model.py#L19

Added line #L19 was not covered by tests
10 changes: 8 additions & 2 deletions deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
Callable,
List,
Optional,
Union,
)

import torch
Expand Down Expand Up @@ -91,7 +93,7 @@ def __init__(
self.r_differentiable = r_differentiable
self.c_differentiable = c_differentiable
super().__init__(
var_name="dipole",
var_name="dipole" if "var_name" not in kwargs else kwargs.pop("var_name"),
ntypes=ntypes,
dim_descrpt=dim_descrpt,
neuron=neuron,
Expand Down Expand Up @@ -134,7 +136,11 @@ def output_def(self) -> FittingOutputDef:
]
)

def compute_output_stats(self, merged, stat_file_path: Optional[DPPath] = None):
def compute_output_stats(

Check warning on line 139 in deepmd/pt/model/task/dipole.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dipole.py#L139

Added line #L139 was not covered by tests
self,
merged: Union[Callable, List[dict]],
stat_file_path: Optional[DPPath] = None,
):
raise NotImplementedError

Check warning on line 144 in deepmd/pt/model/task/dipole.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dipole.py#L144

Added line #L144 was not covered by tests

def forward(
Expand Down
14 changes: 12 additions & 2 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import copy
import logging
from typing import (
Callable,
List,
Optional,
Tuple,
Union,
)

import numpy as np
Expand Down Expand Up @@ -138,13 +140,21 @@ def serialize(self) -> dict:
data["atom_ener"] = self.atom_ener
return data

def compute_output_stats(self, merged, stat_file_path: Optional[DPPath] = None):
def compute_output_stats(

Check warning on line 143 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L143

Added line #L143 was not covered by tests
self,
merged: Union[Callable, List[dict]],
stat_file_path: Optional[DPPath] = None,
):
if stat_file_path is not None:
stat_file_path = stat_file_path / "bias_atom_e"
if stat_file_path is not None and stat_file_path.is_file():
bias_atom_e = stat_file_path.load_numpy()
else:
sampled = merged()
if callable(merged):

Check warning on line 153 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L153

Added line #L153 was not covered by tests
# only get data for once
sampled = merged()

Check warning on line 155 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L155

Added line #L155 was not covered by tests
else:
sampled = merged
energy = [item["energy"] for item in sampled]
data_mixed_type = "real_natoms_vec" in sampled[0]
if data_mixed_type:
Expand Down
14 changes: 12 additions & 2 deletions deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
Callable,
List,
Optional,
Union,
Expand All @@ -24,6 +25,9 @@
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.path import (

Check warning on line 28 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L28

Added line #L28 was not covered by tests
DPPath,
)

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -72,7 +76,6 @@ class PolarFittingNet(GeneralFitting):

def __init__(
self,
var_name: str,
ntypes: int,
dim_descrpt: int,
embedding_width: int,
Expand Down Expand Up @@ -112,7 +115,7 @@ def __init__(
).view(ntypes, 1)
self.shift_diag = shift_diag
super().__init__(
var_name=var_name,
var_name="polar" if "var_name" not in kwargs else kwargs.pop("var_name"),
ntypes=ntypes,
dim_descrpt=dim_descrpt,
neuron=neuron,
Expand Down Expand Up @@ -160,6 +163,13 @@ def output_def(self) -> FittingOutputDef:
]
)

def compute_output_stats(

Check warning on line 166 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L166

Added line #L166 was not covered by tests
self,
merged: Union[Callable, List[dict]],
stat_file_path: Optional[DPPath] = None,
):
raise NotImplementedError

Check warning on line 171 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L171

Added line #L171 was not covered by tests

def forward(
self,
descriptor: torch.Tensor,
Expand Down
4 changes: 4 additions & 0 deletions source/tests/pt/model/test_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
op_module,
)

from ..test_stat import (
energy_data_requirement,
)
from .test_embedding_net import (
get_single_batch,
)
Expand Down Expand Up @@ -114,6 +117,7 @@ def setUp(self):
self.systems[0],
model_config["type_map"],
)
ds.add_data_requirement(energy_data_requirement)
self.np_batch, self.pt_batch = get_single_batch(ds)
self.sec = np.cumsum(self.sel)
self.ntypes = len(self.sel)
Expand Down
14 changes: 7 additions & 7 deletions source/tests/pt/model/test_dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ def test_consistency(
)
ret2 = ft2(rd0, atype, gr, fparam=ifp, aparam=iap)
np.testing.assert_allclose(
to_numpy_array(ret0["foo"]),
ret1["foo"],
to_numpy_array(ret0["dipole"]),
ret1["dipole"],
)
np.testing.assert_allclose(
to_numpy_array(ret0["foo"]),
to_numpy_array(ret2["foo"]),
to_numpy_array(ret0["dipole"]),
to_numpy_array(ret2["dipole"]),
)

def test_jit(
Expand Down Expand Up @@ -206,7 +206,7 @@ def test_rot(self):
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap)
res.append(ret0["foo"])
res.append(ret0["dipole"])

np.testing.assert_allclose(
to_numpy_array(res[1]), to_numpy_array(torch.matmul(res[0], rmat))
Expand Down Expand Up @@ -241,7 +241,7 @@ def test_permu(self):
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0)
res.append(ret0["foo"])
res.append(ret0["dipole"])

np.testing.assert_allclose(
to_numpy_array(res[0][:, idx_perm]), to_numpy_array(res[1])
Expand Down Expand Up @@ -281,7 +281,7 @@ def test_trans(self):
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0)
res.append(ret0["foo"])
res.append(ret0["dipole"])

np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))

Expand Down
5 changes: 5 additions & 0 deletions source/tests/pt/model/test_embedding_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
)
from deepmd.tf.descriptor import DescrptSeA as DescrptSeA_tf

from ..test_stat import (
energy_data_requirement,
)

CUR_DIR = os.path.dirname(__file__)


Expand Down Expand Up @@ -128,6 +132,7 @@ def setUp(self):
self.systems[0],
model_config["type_map"],
)
ds.add_data_requirement(energy_data_requirement)
self.filter_neuron = model_config["descriptor"]["neuron"]
self.axis_neuron = model_config["descriptor"]["axis_neuron"]
self.np_batch, self.torch_batch = get_single_batch(ds)
Expand Down
23 changes: 9 additions & 14 deletions source/tests/pt/model/test_polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def test_consistency(
[None, self.scale],
):
ft0 = PolarFittingNet(
"foo",
self.nt,
self.dd0.dim_out,
embedding_width=self.dd0.get_dim_emb(),
Expand Down Expand Up @@ -113,16 +112,16 @@ def test_consistency(
aparam=to_numpy_array(iap),
)
np.testing.assert_allclose(
to_numpy_array(ret0["foo"]),
ret1["foo"],
to_numpy_array(ret0["polar"]),
ret1["polar"],
)
np.testing.assert_allclose(
to_numpy_array(ret0["foo"]),
to_numpy_array(ret2["foo"]),
to_numpy_array(ret0["polar"]),
to_numpy_array(ret2["polar"]),
)
np.testing.assert_allclose(
to_numpy_array(ret0["foo"]),
ret3["foo"],
to_numpy_array(ret0["polar"]),
ret3["polar"],
)

def test_jit(
Expand All @@ -135,7 +134,6 @@ def test_jit(
[True, False],
):
ft0 = PolarFittingNet(
"foo",
self.nt,
self.dd0.dim_out,
embedding_width=self.dd0.get_dim_emb(),
Expand Down Expand Up @@ -177,7 +175,6 @@ def test_rot(self):
[None, self.scale],
):
ft0 = PolarFittingNet(
"foo",
self.nt,
self.dd0.dim_out, # dim_descrpt
embedding_width=self.dd0.get_dim_emb(),
Expand Down Expand Up @@ -220,7 +217,7 @@ def test_rot(self):
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap)
res.append(ret0["foo"])
res.append(ret0["polar"])
np.testing.assert_allclose(
to_numpy_array(res[1]),
to_numpy_array(
Expand All @@ -235,7 +232,6 @@ def test_permu(self):
coord = torch.matmul(self.coord, self.cell)
for fit_diag, scale in itertools.product([True, False], [None, self.scale]):
ft0 = PolarFittingNet(
"foo",
self.nt,
self.dd0.dim_out,
embedding_width=self.dd0.get_dim_emb(),
Expand Down Expand Up @@ -264,7 +260,7 @@ def test_permu(self):
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=None, aparam=None)
res.append(ret0["foo"])
res.append(ret0["polar"])

np.testing.assert_allclose(
to_numpy_array(res[0][:, idx_perm]),
Expand All @@ -281,7 +277,6 @@ def test_trans(self):
)
for fit_diag, scale in itertools.product([True, False], [None, self.scale]):
ft0 = PolarFittingNet(
"foo",
self.nt,
self.dd0.dim_out,
embedding_width=self.dd0.get_dim_emb(),
Expand Down Expand Up @@ -309,7 +304,7 @@ def test_trans(self):
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0)
res.append(ret0["foo"])
res.append(ret0["polar"])

np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))

Expand Down
4 changes: 4 additions & 0 deletions source/tests/pt/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from .model.test_embedding_net import (
get_single_batch,
)
from .test_stat import (
energy_data_requirement,
)

CUR_DIR = os.path.dirname(__file__)

Expand All @@ -47,6 +50,7 @@ def get_batch():
if isinstance(systems, str):
systems = expand_sys_str(systems)
dataset = DeepmdDataSetForLoader(systems[0], model_config["type_map"])
dataset.add_data_requirement(energy_data_requirement)
np_batch, pt_batch = get_single_batch(dataset)
return np_batch, pt_batch

Expand Down
35 changes: 35 additions & 0 deletions source/tests/pt/test_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,40 @@

CUR_DIR = os.path.dirname(__file__)

energy_data_requirement = {
"energy": {
"ndof": 1,
"atomic": False,
"must": False,
"high_prec": True,
},
"force": {
"ndof": 3,
"atomic": True,
"must": False,
"high_prec": False,
},
"virial": {
"ndof": 9,
"atomic": False,
"must": False,
"high_prec": False,
},
"atom_ener": {
"ndof": 1,
"atomic": True,
"must": False,
"high_prec": False,
},
"atom_pref": {
"ndof": 1,
"atomic": True,
"must": False,
"high_prec": False,
"repeat": 3,
},
}


def compare(ut, base, given):
if isinstance(base, list):
Expand Down Expand Up @@ -111,6 +145,7 @@ def setUp(self):
self.filter_neuron = model_config["descriptor"]["neuron"]
self.axis_neuron = model_config["descriptor"]["axis_neuron"]
self.n_neuron = model_config["fitting_net"]["neuron"]
self.my_dataset.add_data_requirement(energy_data_requirement)

self.my_sampled = my_make(
self.my_dataset.systems, self.my_dataset.dataloaders, self.data_stat_nbatch
Expand Down

0 comments on commit 6c171c5

Please sign in to comment.