-
Notifications
You must be signed in to change notification settings - Fork 525
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: manage testing models in a standard way (#4028)
Fix #2103. Migrate three models (se_e2_a, se_e2_r, and fparam_aparam) for the Python unit tests. Fix several bugs. Old files are kept until the C++ tests are also migrated. Note that several models (for example, the dipole model due to #3672) cannot be serialized yet. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Summary by CodeRabbit - **New Features** - Introduced a structured framework for managing and testing models with YAML files. - Added comprehensive configurations for energy calculations and molecular simulations in YAML format. - Implemented new test cases for the `DeepPot` and `DeepPotNeighborList` classes. - **Bug Fixes** - Improved robustness in tensor reshaping, resolving potential dimension mismatches. - **Tests** - Enhanced unit tests with a case-based approach for better adaptability and maintainability. - Consolidated tests by relocating obsolete classes to streamline the test suite. - **Chores** - Updated deserialization functions for better type safety and input handling. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <[email protected]>
- Loading branch information
Showing
22 changed files
with
8,964 additions
and
1,723 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
deepmd_test_models*/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
"""Manage testing models in a standard way. | ||
For each model, a YAML file ending with `-testcase.yaml` must be given. It should contains the following keys: | ||
- `key`: The key of the model. | ||
- `filename`: The path to the model file. | ||
- `ntypes`: The number of atomic types. | ||
- `rcut`: The cutoff radius. | ||
- `type_map`: The mapping between atomic types and atomic names. | ||
- `dim_fparam`: The number of frame parameters. | ||
- `dim_aparam`: The number of atomic parameters. | ||
- `results`: A list of results. Each result should contain the following keys: | ||
- `atype`: The atomic types. | ||
- `coord`: The atomic coordinates. | ||
- `box`: The simulation box. | ||
- `atomic_energy` or `energy` (optional): The atomic energies or the total energy. | ||
- `force` (optional): The atomic forces. | ||
- `atomic_virial` or `virial` (optional): The atomic virials or the total virial. | ||
""" | ||
|
||
import tempfile | ||
from functools import ( | ||
lru_cache, | ||
) | ||
from pathlib import ( | ||
Path, | ||
) | ||
from typing import ( | ||
Dict, | ||
Optional, | ||
) | ||
|
||
import numpy as np | ||
import yaml | ||
|
||
from deepmd.entrypoints.convert_backend import ( | ||
convert_backend, | ||
) | ||
|
||
this_directory = Path(__file__).parent.resolve() | ||
# create a temporary directory under this directory | ||
# to store the temporary model files | ||
# it will be deleted when the program exits | ||
tempdir = tempfile.TemporaryDirectory(dir=this_directory, prefix="deepmd_test_models_") | ||
|
||
|
||
class Result: | ||
"""Test results. | ||
Parameters | ||
---------- | ||
data : dict | ||
Dictionary containing the results. | ||
Attributes | ||
---------- | ||
atype : np.ndarray | ||
The atomic types. | ||
nloc : int | ||
The number of atoms. | ||
coord : np.ndarray | ||
The atomic coordinates. | ||
box : np.ndarray | ||
The simulation box. | ||
atomic_energy : np.ndarray | ||
The atomic energies. | ||
energy : np.ndarray | ||
The total energy. | ||
force : np.ndarray | ||
The atomic forces. | ||
atomic_virial : np.ndarray | ||
The atomic virials. | ||
virial : np.ndarray | ||
The total virial. | ||
""" | ||
|
||
def __init__(self, data: dict) -> None: | ||
self.atype = np.array(data["atype"], dtype=np.int64) | ||
self.nloc = self.atype.size | ||
self.coord = np.array(data["coord"], dtype=np.float64).reshape(self.nloc, 3) | ||
if data["box"] is not None: | ||
self.box = np.array(data["box"], dtype=np.float64).reshape(3, 3) | ||
else: | ||
self.box = None | ||
if "fparam" in data: | ||
self.fparam = np.array(data["fparam"], dtype=np.float64).ravel() | ||
else: | ||
self.fparam = None | ||
if "aparam" in data: | ||
self.aparam = np.array(data["aparam"], dtype=np.float64).reshape( | ||
self.nloc, -1 | ||
) | ||
else: | ||
self.aparam = None | ||
if "atomic_energy" in data: | ||
self.atomic_energy = np.array( | ||
data["atomic_energy"], dtype=np.float64 | ||
).reshape(self.nloc, 1) | ||
self.energy = np.sum(self.atomic_energy, axis=0) | ||
elif "energy" in data: | ||
self.atomic_energy = None | ||
self.energy = np.array(data["energy"], dtype=np.float64).reshape(1) | ||
else: | ||
self.atomic_energy = None | ||
self.energy = None | ||
if "force" in data: | ||
self.force = np.array(data["force"], dtype=np.float64).reshape(self.nloc, 3) | ||
else: | ||
self.force = None | ||
if "atomic_virial" in data: | ||
self.atomic_virial = np.array( | ||
data["atomic_virial"], dtype=np.float64 | ||
).reshape(self.nloc, 9) | ||
self.virial = np.sum(self.atomic_virial, axis=0) | ||
elif "virial" in data: | ||
self.atomic_virial = None | ||
self.virial = np.array(data["virial"], dtype=np.float64).reshape(9) | ||
else: | ||
self.atomic_virial = None | ||
self.virial = None | ||
if "descriptor" in data: | ||
self.descriptor = np.array(data["descriptor"], dtype=np.float64).reshape( | ||
self.nloc, -1 | ||
) | ||
else: | ||
self.descriptor = None | ||
|
||
|
||
class Case: | ||
"""Test case. | ||
Parameters | ||
---------- | ||
filename : str | ||
The path to the test case file. | ||
""" | ||
|
||
def __init__(self, filename: str): | ||
with open(filename) as file: | ||
config = yaml.safe_load(file) | ||
self.key = config["key"] | ||
self.filename = str(Path(filename).parent / config["filename"]) | ||
self.results = [Result(data) for data in config["results"]] | ||
self.ntypes = config["ntypes"] | ||
self.rcut = config["rcut"] | ||
self.type_map = config["type_map"] | ||
self.dim_fparam = config["dim_fparam"] | ||
self.dim_aparam = config["dim_aparam"] | ||
|
||
@lru_cache | ||
def get_model(self, suffix: str, out_file: Optional[str] = None) -> str: | ||
"""Get the model file with the specified suffix. | ||
Parameters | ||
---------- | ||
suffix : str | ||
The suffix of the model file. | ||
out_file : str, optional | ||
The path to the output model file. If not given, a temporary file will be created. | ||
Returns | ||
------- | ||
str | ||
The path to the model file. | ||
""" | ||
# generate a temporary model file | ||
if out_file is None: | ||
out_file = tempfile.NamedTemporaryFile( | ||
suffix=suffix, dir=tempdir.name, delete=False, prefix=self.key + "_" | ||
).name | ||
convert_backend(INPUT=self.filename, OUTPUT=out_file) | ||
return out_file | ||
|
||
|
||
@lru_cache | ||
def get_cases() -> Dict[str, Case]: | ||
"""Get all test cases. | ||
Returns | ||
------- | ||
Dict[str, Case] | ||
A dictionary containing all test cases. | ||
Examples | ||
-------- | ||
To get a specific case: | ||
>>> get_cases()["se_e2_a"] | ||
""" | ||
cases = {} | ||
for ff in this_directory.glob("*-testcase.yaml"): | ||
case = Case(ff) | ||
cases[case.key] = case | ||
return cases |
Oops, something went wrong.