Skip to content

Commit

Permalink
support deeppolar; support ase
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Dec 8, 2023
1 parent e4a9e75 commit a3660e7
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 8 deletions.
5 changes: 4 additions & 1 deletion deepmd/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class DP(Calculator):
type_dict : Dict[str, int], optional
mapping of element types and their numbers, best left None and the calculator
will infer this information from model, by default None
neighbor_list : ase.neighborlist.NeighborList, optional
The neighbor list object. If None, then build the native neighbor list.
Examples
--------
Expand Down Expand Up @@ -83,10 +85,11 @@ def __init__(
model: Union[str, "Path"],
label: str = "DP",
type_dict: Optional[Dict[str, int]] = None,
neighbor_list=None,
**kwargs,
) -> None:
Calculator.__init__(self, label=label, **kwargs)
self.dp = DeepPotential(str(Path(model).resolve()))
self.dp = DeepPotential(str(Path(model).resolve()), neighbor_list=neighbor_list)
if type_dict:
self.type_dict = type_dict
else:
Expand Down
7 changes: 7 additions & 0 deletions deepmd/infer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def DeepPotential(
load_prefix: str = "load",
default_tf_graph: bool = False,
input_map: Optional[dict] = None,
neighbor_list=None,
) -> Union[DeepDipole, DeepGlobalPolar, DeepPolar, DeepPot, DeepDOS, DeepWFC]:
"""Factory function that will inialize appropriate potential read from `model_file`.
Expand All @@ -71,6 +72,8 @@ def DeepPotential(
If uses the default tf graph, otherwise build a new tf graph for evaluation
input_map : dict, optional
The input map for tf.import_graph_def. Only work with default tf graph
neighbor_list : ase.neighborlist.NeighborList, optional
The neighbor list object. If None, then build the native neighbor list.
Returns
-------
Expand All @@ -97,6 +100,7 @@ def DeepPotential(
load_prefix=load_prefix,
default_tf_graph=default_tf_graph,
input_map=input_map,
neighbor_list=neighbor_list,
)
elif model_type == "dos":
dp = DeepDOS(
Expand All @@ -111,20 +115,23 @@ def DeepPotential(
load_prefix=load_prefix,
default_tf_graph=default_tf_graph,
input_map=input_map,
neighbor_list=neighbor_list,
)
elif model_type == "polar":
dp = DeepPolar(
mf,
load_prefix=load_prefix,
default_tf_graph=default_tf_graph,
input_map=input_map,
neighbor_list=neighbor_list,
)
elif model_type == "global_polar":
dp = DeepGlobalPolar(
mf,
load_prefix=load_prefix,
default_tf_graph=default_tf_graph,
input_map=input_map,
neighbor_list=neighbor_list,
)
elif model_type == "wfc":
dp = DeepWFC(
Expand Down
13 changes: 12 additions & 1 deletion deepmd/infer/deep_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class DeepPolar(DeepTensor):
If uses the default tf graph, otherwise build a new tf graph for evaluation
input_map : dict, optional
The input map for tf.import_graph_def. Only work with default tf graph
neighbor_list : ase.neighborlist.NeighborList, optional
The neighbor list object. If None, then build the native neighbor list.
Warnings
--------
Expand All @@ -44,6 +46,7 @@ def __init__(
load_prefix: str = "load",
default_tf_graph: bool = False,
input_map: Optional[dict] = None,
neighbor_list=None,
) -> None:
# use this in favor of dict update to move attribute from class to
# instance namespace
Expand All @@ -61,6 +64,7 @@ def __init__(
load_prefix=load_prefix,
default_tf_graph=default_tf_graph,
input_map=input_map,
neighbor_list=neighbor_list,
)

def get_dim_fparam(self) -> int:
Expand All @@ -83,10 +87,16 @@ class DeepGlobalPolar(DeepTensor):
The prefix in the load computational graph
default_tf_graph : bool
If uses the default tf graph, otherwise build a new tf graph for evaluation
neighbor_list : ase.neighborlist.NeighborList, optional
The neighbor list object. If None, then build the native neighbor list.
"""

def __init__(
self, model_file: str, load_prefix: str = "load", default_tf_graph: bool = False
self,
model_file: str,
load_prefix: str = "load",
default_tf_graph: bool = False,
neighbor_list=None,
) -> None:
self.tensors.update(
{
Expand All @@ -101,6 +111,7 @@ def __init__(
model_file,
load_prefix=load_prefix,
default_tf_graph=default_tf_graph,
neighbor_list=None,
)

def eval(
Expand Down
4 changes: 4 additions & 0 deletions source/tests/test_deepdipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,6 +1034,10 @@ def test_1frame_full_atm_shuffle(self):
)


@unittest.skipIf(
parse_version(tf.__version__) < parse_version("1.15"),
f"The current tf version {tf.__version__} is too low to run the new testing model.",
)
class TestDeepDipoleNewPBCNeighborList(TestDeepDipoleNewPBC):
@classmethod
def setUpClass(cls):
Expand Down
34 changes: 28 additions & 6 deletions source/tests/test_deeppolar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import unittest

import ase.neighborlist
import numpy as np
from common import (
tests_path,
Expand Down Expand Up @@ -980,12 +981,6 @@ def test_1frame_full_atm(self):
self.coords, self.box, self.atype, atomic=True
)

# print the values
for dd in (at, ff, av):
print("\n\n")
print(", ".join(f"{i:.18e}" for i in dd.reshape(-1)))
print("\n\n")

# check shape of the returns
nframes = 1
natoms = len(self.atype)
Expand Down Expand Up @@ -1088,3 +1083,30 @@ def test_2frame_full_atm(self):
np.testing.assert_almost_equal(
vv.reshape([-1]), expected_gv.reshape([-1]), decimal=default_places
)


@unittest.skipIf(
parse_version(tf.__version__) < parse_version("1.15"),
f"The current tf version {tf.__version__} is too low to run the new testing model.",
)
class TestDeepPolarNewPBCNeighborList(unittest.TestCase):
@classmethod
def setUpClass(cls):
convert_pbtxt_to_pb(
str(tests_path / os.path.join("infer", "deeppolar_new.pbtxt")),
"deeppolar_new.pb",
)
cls.dp = DeepPolar(
"deeppolar_new.pb",
neighbor_list=ase.neighborlist.NewPrimitiveNeighborList(
cutoffs=6, bothways=True
),
)

@unittest.skip("multiple frames not supported")
def test_2frame_full_atm(self):
pass

@unittest.skip("multiple frames not supported")
def test_2frame_old_atm(self):
pass

0 comments on commit a3660e7

Please sign in to comment.