Skip to content

Commit

Permalink
pt: add fparam/aparam data requirements (#3386)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored Mar 2, 2024
1 parent d2b18b2 commit 7aee42c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
21 changes: 21 additions & 0 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.data import (
DataRequirementItem,
)

if torch.__version__.startswith("2"):
import torch._dynamo
Expand Down Expand Up @@ -200,6 +203,24 @@ def get_single_model(
_training_data.add_data_requirement(_data_requirement)
if _validation_data is not None:
_validation_data.add_data_requirement(_data_requirement)
if model.get_dim_fparam() > 0:
fparam_requirement_items = [
DataRequirementItem(
"fparam", model.get_dim_fparam(), atomic=False, must=True
)
]
_training_data.add_data_requirement(fparam_requirement_items)
if _validation_data is not None:
_validation_data.add_data_requirement(fparam_requirement_items)
if model.get_dim_aparam() > 0:
aparam_requirement_items = [
DataRequirementItem(
"aparam", model.get_dim_aparam(), atomic=True, must=True
)
]
_training_data.add_data_requirement(aparam_requirement_items)
if _validation_data is not None:
_validation_data.add_data_requirement(aparam_requirement_items)
if not resuming and self.rank == 0:

@functools.lru_cache
Expand Down
3 changes: 3 additions & 0 deletions examples/fparam/train/input.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
{
"_comment1": " model parameters",
"model": {
"type_map": [
"Be"
],
"data_stat_nbatch": 1,
"descriptor": {
"type": "se_a",
Expand Down
22 changes: 22 additions & 0 deletions source/tests/pt/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,28 @@ def tearDown(self) -> None:
DPTrainTest.tearDown(self)


class TestFparam(unittest.TestCase, DPTrainTest):
    """Test if `fparam` can be loaded correctly."""

    def setUp(self):
        # Load the base water/se_atten config and point both the training
        # and validation sections at the same single data system.
        config_path = str(Path(__file__).parent / "water/se_atten.json")
        with open(config_path) as fin:
            self.config = json.load(fin)
        systems = [str(Path(__file__).parent / "water/data/data_0")]
        for section in ("training_data", "validation_data"):
            self.config["training"][section]["systems"] = systems
        # Swap in the se_e2_a model and enable a single frame parameter.
        model_cfg = deepcopy(model_se_e2_a)
        model_cfg["fitting_net"]["numb_fparam"] = 1
        self.config["model"] = model_cfg
        # Keep the run minimal: one step, checkpoint every step.
        self.config["training"]["numb_steps"] = 1
        self.config["training"]["save_freq"] = 1
        # Fabricate fparam.npy by reusing the per-frame energies, which have
        # the right shape for a dim-1 frame parameter.
        self.set_path = Path(__file__).parent / "water/data/data_0" / "set.000"
        shutil.copyfile(self.set_path / "energy.npy", self.set_path / "fparam.npy")

    def tearDown(self) -> None:
        # Drop the fabricated fixture, then run the shared cleanup.
        (self.set_path / "fparam.npy").unlink(missing_ok=True)
        DPTrainTest.tearDown(self)


class TestEnergyModelDPA1(unittest.TestCase, DPTrainTest):
def setUp(self):
input_json = str(Path(__file__).parent / "water/se_atten.json")
Expand Down

0 comments on commit 7aee42c

Please sign in to comment.