diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 47dba909b3..f79916b36e 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -7,6 +7,7 @@ from typing import ( List, Optional, + Union, ) import numpy as np @@ -239,8 +240,10 @@ class GeneralFitting(Fitting): Random seed. exclude_types: List[int] Atomic contributions of the excluded atom types are set zero. - trainable : bool - If the parameters in the fitting net are trainable. + trainable : Union[List[bool], bool] + If the parameters in the fitting net are trainable. + Currently this only supports setting all the parameters in the fitting net to the same trainable state. + When given as a List[bool], the fitting net is trainable only if all elements of the list are True. remove_vaccum_contribution: List[bool], optional Remove vaccum contribution before the bias is added. The list assigned each type. For `mixed_types` provide `[True]`, otherwise it should be a list of the same @@ -263,7 +266,7 @@ def __init__( rcond: Optional[float] = None, seed: Optional[int] = None, exclude_types: List[int] = [], - trainable: bool = True, + trainable: Union[bool, List[bool]] = True, remove_vaccum_contribution: Optional[List[bool]] = None, **kwargs, ): diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 89b341491e..1f0064c460 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -885,9 +885,9 @@ def fitting_ener(): doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.' doc_precision = f"The precision of the fitting net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." 
doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection' - doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be\n\n\ + doc_trainable = f"Whether the parameters in the fitting net are trainable. This option can be\n\n\ - bool: True if all parameters of the fitting net are trainable, False otherwise.\n\n\ -- list of bool: Specifies if each layer is trainable. Since the fitting net is composed by hidden layers followed by a output layer, the length of this list should be equal to len(`neuron`)+1." +- list of bool{doc_only_tf_supported}: Specifies if each layer is trainable. Since the fitting net is composed of hidden layers followed by an output layer, the length of this list should be equal to len(`neuron`)+1." doc_rcond = "The condition number used to determine the inital energy shift for each type of atoms. See `rcond` in :py:meth:`numpy.linalg.lstsq` for more details." doc_seed = "Random seed for parameter initialization of the fitting net" doc_atom_ener = "Specify the atomic energy in vacuum for each type" diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index 4e73fc4f8a..f2a081610a 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -10,6 +10,8 @@ Path, ) +import torch + from deepmd.pt.entrypoints.main import ( get_trainer, ) @@ -28,6 +30,20 @@ def test_dp_train(self): trainer.run() self.tearDown() + def test_trainable(self): + fix_params = deepcopy(self.config) + fix_params["model"]["descriptor"]["trainable"] = False + fix_params["model"]["fitting_net"]["trainable"] = False + trainer_fix = get_trainer(fix_params) + model_dict_before_training = deepcopy(trainer_fix.model.state_dict()) + trainer_fix.run() + model_dict_after_training = deepcopy(trainer_fix.model.state_dict()) + for key in model_dict_before_training: + torch.testing.assert_close( + model_dict_before_training[key], model_dict_after_training[key] + ) + self.tearDown() + def 
tearDown(self): for f in os.listdir("."): if f.startswith("model") and f.endswith(".pt"):