-
Notifications
You must be signed in to change notification settings - Fork 526
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add universal Python inference interface DeepPot (#3164)
Need discussion for other classes. --------- Signed-off-by: Jinzhe Zeng <[email protected]>
- Loading branch information
Showing
5 changed files
with
194 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from .deep_pot import ( | ||
DeepPot, | ||
) | ||
|
||
__all__ = ["DeepPot"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from enum import ( | ||
Enum, | ||
) | ||
|
||
|
||
class DPBackend(Enum): | ||
"""DeePMD-kit backend.""" | ||
|
||
TensorFlow = 1 | ||
PyTorch = 2 | ||
Paddle = 3 | ||
Unknown = 4 | ||
|
||
|
||
def detect_backend(filename: str) -> DPBackend: | ||
"""Detect the backend of the given model file. | ||
Parameters | ||
---------- | ||
filename : str | ||
The model file name | ||
""" | ||
if filename.endswith(".pb"): | ||
return DPBackend.TensorFlow | ||
elif filename.endswith(".pth") or filename.endswith(".pt"): | ||
return DPBackend.PyTorch | ||
elif filename.endswith(".pdmodel"): | ||
return DPBackend.Paddle | ||
return DPBackend.Unknown | ||
|
||
|
||
__all__ = ["DPBackend", "detect_backend"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from abc import ( | ||
ABC, | ||
abstractmethod, | ||
) | ||
from typing import ( | ||
List, | ||
Optional, | ||
Tuple, | ||
Union, | ||
) | ||
|
||
import numpy as np | ||
|
||
from deepmd_utils.utils.batch_size import ( | ||
AutoBatchSize, | ||
) | ||
|
||
from .backend import ( | ||
DPBackend, | ||
detect_backend, | ||
) | ||
|
||
|
||
class DeepPot(ABC): | ||
"""Potential energy model. | ||
Parameters | ||
---------- | ||
model_file : Path | ||
The name of the frozen model file. | ||
auto_batch_size : bool or int or AutoBatchSize, default: True | ||
If True, automatic batch size will be used. If int, it will be used | ||
as the initial batch size. | ||
neighbor_list : ase.neighborlist.NewPrimitiveNeighborList, optional | ||
The ASE neighbor list class to produce the neighbor list. If None, the | ||
neighbor list will be built natively in the model. | ||
""" | ||
|
||
@abstractmethod | ||
def __init__( | ||
self, | ||
model_file, | ||
*args, | ||
auto_batch_size: Union[bool, int, AutoBatchSize] = True, | ||
neighbor_list=None, | ||
**kwargs, | ||
) -> None: | ||
pass | ||
|
||
def __new__(cls, model_file: str, *args, **kwargs): | ||
if cls is DeepPot: | ||
backend = detect_backend(model_file) | ||
if backend == DPBackend.TensorFlow: | ||
from deepmd.infer.deep_pot import DeepPot as DeepPotTF | ||
|
||
return super().__new__(DeepPotTF) | ||
elif backend == DPBackend.PyTorch: | ||
from deepmd_pt.infer.deep_eval import DeepPot as DeepPotPT | ||
|
||
return super().__new__(DeepPotPT) | ||
else: | ||
raise NotImplementedError("Unsupported backend: " + str(backend)) | ||
return super().__new__(cls) | ||
|
||
@abstractmethod | ||
def eval( | ||
self, | ||
coords: np.ndarray, | ||
cells: np.ndarray, | ||
atom_types: List[int], | ||
atomic: bool = False, | ||
fparam: Optional[np.ndarray] = None, | ||
aparam: Optional[np.ndarray] = None, | ||
efield: Optional[np.ndarray] = None, | ||
mixed_type: bool = False, | ||
) -> Tuple[np.ndarray, ...]: | ||
"""Evaluate energy, force, and virial. If atomic is True, | ||
also return atomic energy and atomic virial. | ||
Parameters | ||
---------- | ||
coords : np.ndarray | ||
The coordinates of the atoms, in shape (nframes, natoms, 3). | ||
cells : np.ndarray | ||
The cell vectors of the system, in shape (nframes, 9). If the system | ||
is not periodic, set it to None. | ||
atom_types : List[int] | ||
The types of the atoms. If mixed_type is False, the shape is (natoms,); | ||
otherwise, the shape is (nframes, natoms). | ||
atomic : bool, optional | ||
Whether to return atomic energy and atomic virial, by default False. | ||
fparam : np.ndarray, optional | ||
The frame parameters, by default None. | ||
aparam : np.ndarray, optional | ||
The atomic parameters, by default None. | ||
efield : np.ndarray, optional | ||
The electric field, by default None. | ||
mixed_type : bool, optional | ||
Whether the system contains mixed atom types, by default False. | ||
Returns | ||
------- | ||
energy | ||
The energy of the system, in shape (nframes,). | ||
force | ||
The force of the system, in shape (nframes, natoms, 3). | ||
virial | ||
The virial of the system, in shape (nframes, 9). | ||
atomic_energy | ||
The atomic energy of the system, in shape (nframes, natoms). Only returned | ||
when atomic is True. | ||
atomic_virial | ||
The atomic virial of the system, in shape (nframes, natoms, 9). Only returned | ||
when atomic is True. | ||
""" | ||
# This method has been used by: | ||
# documentation python.md | ||
# dp model_devi: +fparam, +aparam, +mixed_type | ||
# dp test: +atomic, +fparam, +aparam, +efield, +mixed_type | ||
# finetune: +mixed_type | ||
# dpdata | ||
# ase | ||
|
||
|
||
__all__ = ["DeepPot"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
"""Unit tests for the universal Python inference interface.""" | ||
|
||
import os | ||
import unittest | ||
|
||
from common import ( | ||
tests_path, | ||
) | ||
|
||
from deepmd.infer.deep_pot import DeepPot as DeepPotTF | ||
from deepmd.utils.convert import ( | ||
convert_pbtxt_to_pb, | ||
) | ||
from deepmd_utils.infer.deep_pot import DeepPot as DeepPot | ||
|
||
|
||
class TestUniversalInfer(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
convert_pbtxt_to_pb( | ||
str(tests_path / os.path.join("infer", "deeppot-r.pbtxt")), "deeppot.pb" | ||
) | ||
|
||
def test_deep_pot(self): | ||
dp = DeepPot("deeppot.pb") | ||
self.assertIsInstance(dp, DeepPotTF) |