Skip to content

Commit

Permalink
try to fix the tests
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Mar 3, 2024
1 parent 54f5826 commit 8636dc1
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
5 changes: 5 additions & 0 deletions deepmd/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,8 @@ class DataRequirementItem:
default value of data
dtype : np.dtype, optional
the dtype of data, overwrites `high_prec` if provided
output_natoms_for_type_sel : bool, optional
if True and type_sel is True, the atomic dimension will be natoms instead of nsel
"""

def __init__(
Expand All @@ -754,6 +756,7 @@ def __init__(
repeat: int = 1,
default: float = 0.0,
dtype: Optional[np.dtype] = None,
output_natoms_for_type_sel: bool = False,
) -> None:
self.key = key
self.ndof = ndof
Expand All @@ -764,6 +767,7 @@ def __init__(
self.repeat = repeat
self.default = default
self.dtype = dtype
self.output_natoms_for_type_sel = output_natoms_for_type_sel
self.dict = self.to_dict()

def to_dict(self) -> dict:
Expand All @@ -777,6 +781,7 @@ def to_dict(self) -> dict:
"repeat": self.repeat,
"default": self.default,
"dtype": self.dtype,
"output_natoms_for_type_sel": self.output_natoms_for_type_sel,
}

def __getitem__(self, key: str):
Expand Down
9 changes: 8 additions & 1 deletion deepmd/utils/data_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,10 @@ def add_dict(self, adict: dict) -> None:
type_sel=adict[kk]["type_sel"],
repeat=adict[kk]["repeat"],
default=adict[kk]["default"],
output_natoms_for_type_sel=adict[kk]["output_natoms_for_type_sel"],
dtype=adict[kk].get("dtype"),
output_natoms_for_type_sel=adict[kk].get(
"output_natoms_for_type_sel", False
),
)

def add(
Expand All @@ -306,6 +309,7 @@ def add(
type_sel: Optional[List[int]] = None,
repeat: int = 1,
default: float = 0.0,
dtype: Optional[np.dtype] = None,
output_natoms_for_type_sel: bool = False,
):
"""Add a data item that to be loaded.
Expand All @@ -331,6 +335,8 @@ def add(
The data will be repeated `repeat` times.
default, default=0.
Default value of data
dtype
The dtype of data, overwrites `high_prec` if provided
output_natoms_for_type_sel : bool
If True and type_sel is True, the atomic dimension will be natoms instead of nsel
"""
Expand All @@ -344,6 +350,7 @@ def add(
repeat=repeat,
type_sel=type_sel,
default=default,
dtype=dtype,
output_natoms_for_type_sel=output_natoms_for_type_sel,
)

Expand Down

0 comments on commit 8636dc1

Please sign in to comment.