Add ut and doc
iProzd committed Mar 1, 2024
1 parent 7940242 commit f50afb5
Showing 3 changed files with 24 additions and 5 deletions.
9 changes: 6 additions & 3 deletions deepmd/pt/model/task/fitting.py
@@ -7,6 +7,7 @@
 from typing import (
     List,
     Optional,
+    Union,
 )
 
 import numpy as np
 
@@ -239,8 +240,10 @@ class GeneralFitting(Fitting):
         Random seed.
     exclude_types: List[int]
         Atomic contributions of the excluded atom types are set zero.
-    trainable : bool
-        If the parameters in the fitting net are trainable.
+    trainable : Union[List[bool], bool]
+        If the parameters in the fitting net are trainable.
+        Currently this only supports setting all the parameters in the fitting net to one state.
+        When given as a List[bool], the net is trainable only if every entry is True.
     remove_vaccum_contribution: List[bool], optional
         Remove vaccum contribution before the bias is added. The list assigned each
         type. For `mixed_types` provide `[True]`, otherwise it should be a list of the same
 
@@ -263,7 +266,7 @@ def __init__(
         rcond: Optional[float] = None,
         seed: Optional[int] = None,
         exclude_types: List[int] = [],
-        trainable: bool = True,
+        trainable: Union[bool, List[bool]] = True,
         remove_vaccum_contribution: Optional[List[bool]] = None,
         **kwargs,
     ):
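A minimal sketch, assuming (not quoting) the PyTorch implementation, of how a Union[bool, List[bool]] trainable flag can be collapsed to a single state and applied to a module; the helper names resolve_trainable and apply_trainable are hypothetical:

from typing import List, Union

import torch


def resolve_trainable(trainable: Union[bool, List[bool]]) -> bool:
    # Per the docstring above, a per-layer list is honored only as a whole:
    # the net is trainable only if every entry is True.
    if isinstance(trainable, list):
        return all(trainable)
    return trainable


def apply_trainable(module: torch.nn.Module, trainable: Union[bool, List[bool]]) -> None:
    # Freezing amounts to toggling requires_grad on every parameter.
    requires_grad = resolve_trainable(trainable)
    for param in module.parameters():
        param.requires_grad = requires_grad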
4 changes: 2 additions & 2 deletions deepmd/utils/argcheck.py
@@ -885,9 +885,9 @@ def fitting_ener():
     doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.'
     doc_precision = f"The precision of the fitting net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision."
     doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection'
-    doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be\n\n\
doc_trainable = f"Whether the parameters in the fitting net are trainable. This option can be\n\n\

Check warning on line 888 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L888

Added line #L888 was not covered by tests
 - bool: True if all parameters of the fitting net are trainable, False otherwise.\n\n\
-- list of bool: Specifies if each layer is trainable. Since the fitting net is composed by hidden layers followed by a output layer, the length of this list should be equal to len(`neuron`)+1."
+- list of bool{doc_only_tf_supported}: Specifies if each layer is trainable. Since the fitting net is composed of hidden layers followed by an output layer, the length of this list should be equal to len(`neuron`)+1."
     doc_rcond = "The condition number used to determine the initial energy shift for each type of atoms. See `rcond` in :py:meth:`numpy.linalg.lstsq` for more details."
     doc_seed = "Random seed for parameter initialization of the fitting net"
     doc_atom_ener = "Specify the atomic energy in vacuum for each type"
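For context, doc_trainable is consumed by the dargs-based argument schema built in fitting_ener(). A minimal sketch of such a declaration, assuming a plain dargs Argument (the exact call in the repo may differ):

from dargs import Argument

# Shortened stand-in for the full doc_trainable string above.
doc_trainable = "Whether the parameters in the fitting net are trainable."

# dargs accepts a list of dtypes and type-checks with isinstance, so plain
# `list` stands in for List[bool] here.
trainable_arg = Argument(
    "trainable",
    [bool, list],
    optional=True,
    default=True,
    doc=doc_trainable,
)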
16 changes: 16 additions & 0 deletions source/tests/pt/test_training.py
@@ -10,6 +10,8 @@
     Path,
 )
 
+import torch
+
 from deepmd.pt.entrypoints.main import (
     get_trainer,
 )
 
@@ -28,6 +30,20 @@ def test_dp_train(self):
         trainer.run()
         self.tearDown()
 
+    def test_trainable(self):
+        fix_params = deepcopy(self.config)
+        fix_params["model"]["descriptor"]["trainable"] = False
+        fix_params["model"]["fitting_net"]["trainable"] = False
+        trainer_fix = get_trainer(fix_params)
+        model_dict_before_training = deepcopy(trainer_fix.model.state_dict())
+        trainer_fix.run()
+        model_dict_after_training = deepcopy(trainer_fix.model.state_dict())
+        for key in model_dict_before_training:
+            torch.testing.assert_allclose(
+                model_dict_before_training[key], model_dict_after_training[key]
+            )
+        self.tearDown()
+
     def tearDown(self):
         for f in os.listdir("."):
             if f.startswith("model") and f.endswith(".pt"):
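The new test freezes both the descriptor and the fitting net, runs training, and asserts that every entry of the model state dict is unchanged afterwards. (In newer PyTorch releases torch.testing.assert_allclose is deprecated in favor of torch.testing.assert_close.) As a usage sketch, the two accepted forms of the fitting-net trainable key; the config dict below is a stand-in, not a complete deepmd-kit input:

# Hedged sketch; a real input needs the full model/training sections.
config = {"model": {"descriptor": {}, "fitting_net": {"neuron": [240, 240, 240]}}}

# Form 1: a single flag freezes (or unfreezes) the whole fitting net.
config["model"]["fitting_net"]["trainable"] = False

# Form 2: a per-layer list of len(neuron) + 1 entries (hidden layers plus the
# output layer). The PyTorch backend honors it only as a whole: the net is
# trainable only if every entry is True.
config["model"]["fitting_net"]["trainable"] = [True, True, True, False]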
