Skip to content

Commit

Permalink
chore: improve type anotations in deepmd.infer (#3792)
Browse files Browse the repository at this point in the history
Fix several incorrect type anotations.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced flexibility in function and method parameters, allowing for
more versatile use cases.
  
- **Improvements**
- Streamlined type annotations for improved code maintainability and
readability.
- Updated import statements for better module organization and
efficiency.

- **Bug Fixes**
- Corrected parameter types to ensure proper handling of optional and
varied input types.

These changes aim to improve the overall usability and robustness of the
application, making it more adaptable to different scenarios.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored May 17, 2024
1 parent d62a41f commit 42724ce
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 43 deletions.
8 changes: 4 additions & 4 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ def __init__(
self,
model_file: str,
output_def: ModelOutputDef,
*args: List[Any],
*args: Any,
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
):
self.output_def = output_def
self.model_path = model_file
Expand Down Expand Up @@ -161,12 +161,12 @@ def get_ntypes_spin(self):
def eval(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Dict[str, np.ndarray]:
"""Evaluate the energy, force and virial by using this DP.
Expand Down
9 changes: 1 addition & 8 deletions deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,7 @@
)

if TYPE_CHECKING:
from deepmd.tf.infer import (
DeepDipole,
DeepDOS,
DeepPolar,
DeepPot,
DeepWFC,
)
from deepmd.tf.infer.deep_tensor import (
from deepmd.infer.deep_tensor import (
DeepTensor,
)

Expand Down
3 changes: 1 addition & 2 deletions deepmd/infer/deep_dos.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Dict,
List,
Optional,
Tuple,
Expand Down Expand Up @@ -70,7 +69,7 @@ def eval(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Tuple[np.ndarray, ...]:
"""Evaluate energy, force, and virial. If atomic is True,
also return atomic energy and atomic virial.
Expand Down
23 changes: 12 additions & 11 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
List,
Optional,
Tuple,
Type,
Union,
)

Expand Down Expand Up @@ -82,10 +83,10 @@ def __init__(
self,
model_file: str,
output_def: ModelOutputDef,
*args: List[Any],
*args: Any,
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
pass

Expand All @@ -99,12 +100,12 @@ def __new__(cls, model_file: str, *args, **kwargs):
def eval(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Dict[str, np.ndarray]:
"""Evaluate the energy, force and virial by using this DP.
Expand Down Expand Up @@ -166,13 +167,13 @@ def get_dim_aparam(self) -> int:
def eval_descriptor(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
efield: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> np.ndarray:
"""Evaluate descriptors by using this DP.
Expand Down Expand Up @@ -246,11 +247,11 @@ def _check_mixed_types(self, atom_types: np.ndarray) -> bool:
# assume mixed_types if there are virtual types, even when
# the atom types of all frames are the same
return False
return np.all(np.equal(atom_types, atom_types[0]))
return np.all(np.equal(atom_types, atom_types[0])).item()

@property
@abstractmethod
def model_type(self) -> "DeepEval":
def model_type(self) -> Type["DeepEval"]:
"""The the evaluator of the model type."""

@abstractmethod
Expand Down Expand Up @@ -316,10 +317,10 @@ def __new__(cls, model_file: str, *args, **kwargs):
def __init__(
self,
model_file: str,
*args: List[Any],
*args: Any,
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
self.deep_eval = DeepEvalBackend(
model_file,
Expand Down Expand Up @@ -387,7 +388,7 @@ def eval_descriptor(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> np.ndarray:
"""Evaluate descriptors by using this DP.
Expand Down
2 changes: 1 addition & 1 deletion deepmd/infer/deep_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def eval(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: dict,
**kwargs,
) -> np.ndarray:
"""Evaluate the model.
Expand Down
47 changes: 45 additions & 2 deletions deepmd/infer/deep_pot.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
overload,
)

import numpy as np
Expand Down Expand Up @@ -89,6 +90,48 @@ def output_def_mag(self) -> ModelOutputDef:
)
)

@overload
def eval(
self,
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: Union[List[int], np.ndarray],
atomic: Literal[True],
fparam: Optional[np.ndarray],
aparam: Optional[np.ndarray],
mixed_type: bool,
**kwargs: Any,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
pass

@overload
def eval(
self,
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: Union[List[int], np.ndarray],
atomic: Literal[False],
fparam: Optional[np.ndarray],
aparam: Optional[np.ndarray],
mixed_type: bool,
**kwargs: Any,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
pass

@overload
def eval(
self,
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: Union[List[int], np.ndarray],
atomic: bool,
fparam: Optional[np.ndarray],
aparam: Optional[np.ndarray],
mixed_type: bool,
**kwargs: Any,
) -> Tuple[np.ndarray, ...]:
pass

def eval(
self,
coords: np.ndarray,
Expand All @@ -98,7 +141,7 @@ def eval(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Tuple[np.ndarray, ...]:
"""Evaluate energy, force, and virial. If atomic is True,
also return atomic energy and atomic virial.
Expand Down
14 changes: 11 additions & 3 deletions deepmd/infer/model_devi.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def calc_model_devi_f(
fs: np.ndarray,
real_f: Optional[np.ndarray] = None,
relative: Optional[float] = None,
atomic: Literal[False] = False,
atomic: Literal[False] = ...,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: ...


Expand All @@ -37,11 +37,19 @@ def calc_model_devi_f(
fs: np.ndarray,
real_f: Optional[np.ndarray] = None,
relative: Optional[float] = None,
*,
atomic: Literal[True],
atomic: Literal[True] = ...,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ...


@overload
def calc_model_devi_f(
fs: np.ndarray,
real_f: Optional[np.ndarray] = None,
relative: Optional[float] = None,
atomic: bool = False,
) -> Tuple[np.ndarray, ...]: ...


def calc_model_devi_f(
fs: np.ndarray,
real_f: Optional[np.ndarray] = None,
Expand Down
11 changes: 6 additions & 5 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
List,
Optional,
Tuple,
Type,
Union,
)

Expand Down Expand Up @@ -87,11 +88,11 @@ def __init__(
self,
model_file: str,
output_def: ModelOutputDef,
*args: List[Any],
*args: Any,
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
head: Optional[str] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
):
self.output_def = output_def
self.model_path = model_file
Expand Down Expand Up @@ -165,7 +166,7 @@ def get_dim_aparam(self) -> int:
return self.dp.model["Default"].get_dim_aparam()

@property
def model_type(self) -> "DeepEvalWrapper":
def model_type(self) -> Type["DeepEvalWrapper"]:
"""The the evaluator of the model type."""
model_output_type = self.dp.model["Default"].model_output_type()
if "energy" in model_output_type:
Expand Down Expand Up @@ -211,12 +212,12 @@ def get_has_spin(self):
def eval(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Dict[str, np.ndarray]:
"""Evaluate the energy, force and virial by using this DP.
Expand Down
11 changes: 6 additions & 5 deletions deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
List,
Optional,
Tuple,
Type,
Union,
)

Expand Down Expand Up @@ -262,7 +263,7 @@ def _init_attr(self):

@property
@lru_cache(maxsize=None)
def model_type(self) -> "DeepEvalWrapper":
def model_type(self) -> Type["DeepEvalWrapper"]:
"""Get type of model.
:type:str
Expand Down Expand Up @@ -693,13 +694,13 @@ def _get_natoms_and_nframes(
def eval(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
efield: Optional[np.ndarray] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Dict[str, np.ndarray]:
"""Evaluate the energy, force and virial by using this DP.
Expand Down Expand Up @@ -1023,7 +1024,7 @@ def _get_output_shape(self, odef, nframes, natoms):
def eval_descriptor(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
Expand Down Expand Up @@ -1080,7 +1081,7 @@ def eval_descriptor(
def _eval_descriptor_inner(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
Expand Down
4 changes: 2 additions & 2 deletions deepmd/tf/infer/deep_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def get_dim_aparam(self) -> int:
def eval(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: List[int],
atomic: bool = True,
fparam: Optional[np.ndarray] = None,
Expand Down Expand Up @@ -276,7 +276,7 @@ def eval(
def eval_full(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: List[int],
atomic: bool = False,
fparam: Optional[np.array] = None,
Expand Down

0 comments on commit 42724ce

Please sign in to comment.