Skip to content

Commit

Permalink
add universal Python inference interface DeepPot (#3164)
Browse files Browse the repository at this point in the history
Need discussion for other classes.

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Jan 24, 2024
1 parent 0f9c6eb commit 04c414a
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 1 deletion.
3 changes: 2 additions & 1 deletion deepmd/infer/deep_pot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from deepmd.utils.sess import (
run_sess,
)
from deepmd_utils.infer.deep_pot import DeepPot as DeepPotBase

if TYPE_CHECKING:
from pathlib import (
Expand All @@ -35,7 +36,7 @@
log = logging.getLogger(__name__)


class DeepPot(DeepEval):
class DeepPot(DeepEval, DeepPotBase):
"""Constructor.
Parameters
Expand Down
6 changes: 6 additions & 0 deletions deepmd_utils/infer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .deep_pot import (
DeepPot,
)

__all__ = ["DeepPot"]
33 changes: 33 additions & 0 deletions deepmd_utils/infer/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from enum import (
Enum,
)


class DPBackend(Enum):
"""DeePMD-kit backend."""

TensorFlow = 1
PyTorch = 2
Paddle = 3
Unknown = 4


def detect_backend(filename: str) -> DPBackend:
"""Detect the backend of the given model file.
Parameters
----------
filename : str
The model file name
"""
if filename.endswith(".pb"):
return DPBackend.TensorFlow
elif filename.endswith(".pth") or filename.endswith(".pt"):
return DPBackend.PyTorch
elif filename.endswith(".pdmodel"):
return DPBackend.Paddle
return DPBackend.Unknown


__all__ = ["DPBackend", "detect_backend"]
126 changes: 126 additions & 0 deletions deepmd_utils/infer/deep_pot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
ABC,
abstractmethod,
)
from typing import (
List,
Optional,
Tuple,
Union,
)

import numpy as np

from deepmd_utils.utils.batch_size import (
AutoBatchSize,
)

from .backend import (
DPBackend,
detect_backend,
)


class DeepPot(ABC):
"""Potential energy model.
Parameters
----------
model_file : Path
The name of the frozen model file.
auto_batch_size : bool or int or AutoBatchSize, default: True
If True, automatic batch size will be used. If int, it will be used
as the initial batch size.
neighbor_list : ase.neighborlist.NewPrimitiveNeighborList, optional
The ASE neighbor list class to produce the neighbor list. If None, the
neighbor list will be built natively in the model.
"""

@abstractmethod
def __init__(
self,
model_file,
*args,
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list=None,
**kwargs,
) -> None:
pass

def __new__(cls, model_file: str, *args, **kwargs):
if cls is DeepPot:
backend = detect_backend(model_file)
if backend == DPBackend.TensorFlow:
from deepmd.infer.deep_pot import DeepPot as DeepPotTF

return super().__new__(DeepPotTF)
elif backend == DPBackend.PyTorch:
from deepmd_pt.infer.deep_eval import DeepPot as DeepPotPT

return super().__new__(DeepPotPT)
else:
raise NotImplementedError("Unsupported backend: " + str(backend))
return super().__new__(cls)

@abstractmethod
def eval(
self,
coords: np.ndarray,
cells: np.ndarray,
atom_types: List[int],
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
efield: Optional[np.ndarray] = None,
mixed_type: bool = False,
) -> Tuple[np.ndarray, ...]:
"""Evaluate energy, force, and virial. If atomic is True,
also return atomic energy and atomic virial.
Parameters
----------
coords : np.ndarray
The coordinates of the atoms, in shape (nframes, natoms, 3).
cells : np.ndarray
The cell vectors of the system, in shape (nframes, 9). If the system
is not periodic, set it to None.
atom_types : List[int]
The types of the atoms. If mixed_type is False, the shape is (natoms,);
otherwise, the shape is (nframes, natoms).
atomic : bool, optional
Whether to return atomic energy and atomic virial, by default False.
fparam : np.ndarray, optional
The frame parameters, by default None.
aparam : np.ndarray, optional
The atomic parameters, by default None.
efield : np.ndarray, optional
The electric field, by default None.
mixed_type : bool, optional
Whether the system contains mixed atom types, by default False.
Returns
-------
energy
The energy of the system, in shape (nframes,).
force
The force of the system, in shape (nframes, natoms, 3).
virial
The virial of the system, in shape (nframes, 9).
atomic_energy
The atomic energy of the system, in shape (nframes, natoms). Only returned
when atomic is True.
atomic_virial
The atomic virial of the system, in shape (nframes, natoms, 9). Only returned
when atomic is True.
"""
# This method has been used by:
# documentation python.md
# dp model_devi: +fparam, +aparam, +mixed_type
# dp test: +atomic, +fparam, +aparam, +efield, +mixed_type
# finetune: +mixed_type
# dpdata
# ase


__all__ = ["DeepPot"]
27 changes: 27 additions & 0 deletions source/tests/test_uni_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Unit tests for the universal Python inference interface."""

import os
import unittest

from common import (
tests_path,
)

from deepmd.infer.deep_pot import DeepPot as DeepPotTF
from deepmd.utils.convert import (
convert_pbtxt_to_pb,
)
from deepmd_utils.infer.deep_pot import DeepPot as DeepPot


class TestUniversalInfer(unittest.TestCase):
@classmethod
def setUpClass(cls):
convert_pbtxt_to_pb(
str(tests_path / os.path.join("infer", "deeppot-r.pbtxt")), "deeppot.pb"
)

def test_deep_pot(self):
dp = DeepPot("deeppot.pb")
self.assertIsInstance(dp, DeepPotTF)

0 comments on commit 04c414a

Please sign in to comment.