diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index e40b2f4e53..194c6b1e24 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -741,6 +741,8 @@ class DataRequirementItem: default value of data dtype : np.dtype, optional the dtype of data, overwrites `high_prec` if provided + output_natoms_for_type_sel : bool, optional + if True and type_sel is True, the atomic dimension will be natoms instead of nsel """ def __init__( @@ -754,6 +756,7 @@ def __init__( repeat: int = 1, default: float = 0.0, dtype: Optional[np.dtype] = None, + output_natoms_for_type_sel: bool = False, ) -> None: self.key = key self.ndof = ndof @@ -764,6 +767,7 @@ def __init__( self.repeat = repeat self.default = default self.dtype = dtype + self.output_natoms_for_type_sel = output_natoms_for_type_sel self.dict = self.to_dict() def to_dict(self) -> dict: @@ -777,6 +781,7 @@ def to_dict(self) -> dict: "repeat": self.repeat, "default": self.default, "dtype": self.dtype, + "output_natoms_for_type_sel": self.output_natoms_for_type_sel, } def __getitem__(self, key: str): diff --git a/deepmd/utils/data_system.py b/deepmd/utils/data_system.py index 369857b2d1..0c74abfed1 100644 --- a/deepmd/utils/data_system.py +++ b/deepmd/utils/data_system.py @@ -293,7 +293,10 @@ def add_dict(self, adict: dict) -> None: type_sel=adict[kk]["type_sel"], repeat=adict[kk]["repeat"], default=adict[kk]["default"], - output_natoms_for_type_sel=adict[kk]["output_natoms_for_type_sel"], + dtype=adict[kk].get("dtype"), + output_natoms_for_type_sel=adict[kk].get( + "output_natoms_for_type_sel", False + ), ) def add( @@ -306,6 +309,7 @@ def add( type_sel: Optional[List[int]] = None, repeat: int = 1, default: float = 0.0, + dtype: Optional[np.dtype] = None, output_natoms_for_type_sel: bool = False, ): """Add a data item that to be loaded. @@ -331,6 +335,8 @@ def add( The data will be repeated `repeat` times. default, default=0. Default value of data + dtype + The dtype of data, overwrites `high_prec` if provided output_natoms_for_type_sel : bool If True and type_sel is True, the atomic dimension will be natoms instead of nsel """ @@ -344,6 +350,7 @@ def add( repeat=repeat, type_sel=type_sel, default=default, + dtype=dtype, output_natoms_for_type_sel=output_natoms_for_type_sel, )