Add trainable settings for pt (#3371)
Signed-off-by: Duo <[email protected]>
iProzd authored Mar 2, 2024
1 parent ee8b82b commit bdea3ce
Showing 9 changed files with 50 additions and 30 deletions.
5 changes: 3 additions & 2 deletions deepmd/pt/model/descriptor/dpa1.py
@@ -72,8 +72,6 @@ def __init__(
raise NotImplementedError("type_one_side is not supported.")
if precision != "default" and precision != "float64":
raise NotImplementedError("precision is not supported.")
if not trainable:
raise NotImplementedError("trainable == False is not supported.")
if exclude_types is not None and exclude_types != []:
raise NotImplementedError("exclude_types is not supported.")
if stripped_type_embedding:
@@ -108,6 +106,9 @@ def __init__(
self.type_embedding = TypeEmbedNet(ntypes, tebd_dim)
self.tebd_dim = tebd_dim
self.concat_output_tebd = concat_output_tebd
# set trainable
for param in self.parameters():
param.requires_grad = trainable

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
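The core of this change is the same freezing pattern in each descriptor and in the fitting net: once the submodules are built, every parameter's `requires_grad` flag is set from the `trainable` argument. A minimal sketch of the idea, using an illustrative `ToyDescriptor` that is not part of deepmd:

```python
import torch


class ToyDescriptor(torch.nn.Module):
    """Illustrative stand-in for a descriptor that accepts a `trainable` flag."""

    def __init__(self, trainable: bool = True):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(4, 8),
            torch.nn.Tanh(),
            torch.nn.Linear(8, 4),
        )
        # Same pattern as the diff above: toggle every parameter after construction.
        for param in self.parameters():
            param.requires_grad = trainable


frozen = ToyDescriptor(trainable=False)
# Frozen parameters receive no gradients and are left untouched by optimizers.
print(all(not p.requires_grad for p in frozen.parameters()))  # True
```

The same `for param in self.parameters(): param.requires_grad = trainable` block is added to dpa1, dpa2, se_a, se_r, and the general fitting class below.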
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
@@ -77,6 +77,7 @@ def __init__(
repformer_update_style: str = "res_avg",
repformer_set_davg_zero: bool = True, # TODO
repformer_add_type_ebd_to_seq: bool = False,
trainable: bool = True,
type: Optional[
str
] = None, # work around the bad design in get_trainer and DpLoaderSet!
@@ -172,6 +173,8 @@ def __init__(
repformers block: set the avg to zero in statistics
repformer_add_type_ebd_to_seq : bool
repformers block: concatenate the type embedding at the output.
trainable : bool
Whether the parameters in the descriptor are trainable.
Returns
-------
@@ -251,6 +254,9 @@ def __init__(
self.rcut = self.repinit.get_rcut()
self.ntypes = ntypes
self.sel = self.repinit.sel
# set trainable
for param in self.parameters():
param.requires_grad = trainable

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
@@ -312,6 +312,7 @@ def __init__(
exclude_types: List[Tuple[int, int]] = [],
old_impl: bool = False,
type_one_side: bool = True,
trainable: bool = True,
**kwargs,
):
"""Construct an embedding net of type `se_a`.
@@ -384,6 +385,9 @@ def __init__(
)
self.filter_layers = filter_layers
self.stats = None
# set trainable
for param in self.parameters():
param.requires_grad = trainable

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/se_r.py
@@ -65,6 +65,7 @@ def __init__(
resnet_dt: bool = False,
exclude_types: List[Tuple[int, int]] = [],
old_impl: bool = False,
trainable: bool = True,
**kwargs,
):
super().__init__()
@@ -112,6 +113,9 @@ def __init__(
)
self.filter_layers = filter_layers
self.stats = None
# set trainable
for param in self.parameters():
param.requires_grad = trainable

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
16 changes: 15 additions & 1 deletion deepmd/pt/model/task/fitting.py
@@ -7,6 +7,7 @@
from typing import (
List,
Optional,
Union,
)

import numpy as np
@@ -239,6 +240,10 @@ class GeneralFitting(Fitting):
Random seed.
exclude_types: List[int]
Atomic contributions of the excluded atom types are set zero.
trainable : Union[List[bool], bool]
Whether the parameters in the fitting net are trainable.
Note that this currently supports only a single setting for all parameters in the fitting net.
If given as a List[bool], the net is trainable only when every entry is True.
remove_vaccum_contribution: List[bool], optional
Remove vacuum contribution before the bias is added. The list is assigned to each
type. For `mixed_types` provide `[True]`, otherwise it should be a list of the same
@@ -261,6 +266,7 @@ def __init__(
rcond: Optional[float] = None,
seed: Optional[int] = None,
exclude_types: List[int] = [],
trainable: Union[bool, List[bool]] = True,
remove_vaccum_contribution: Optional[List[bool]] = None,
**kwargs,
):
@@ -279,6 +285,11 @@ def __init__(
self.rcond = rcond
# order matters, should be placed after the assignment of ntypes
self.reinit_exclude(exclude_types)
self.trainable = trainable
# TODO: add support for per-layer trainable settings
self.trainable = (
all(self.trainable) if isinstance(self.trainable, list) else self.trainable
)
self.remove_vaccum_contribution = remove_vaccum_contribution

net_dim_out = self._net_out_dim()
@@ -353,6 +364,9 @@ def __init__(

if seed is not None:
torch.manual_seed(seed)
# set trainable
for param in self.parameters():
param.requires_grad = self.trainable

def reinit_exclude(
self,
@@ -394,7 +408,7 @@ def serialize(self) -> dict:
# "spin": self.spin ,
## NOTICE: not supported by far
"tot_ener_zero": False,
"trainable": [True] * (len(self.neuron) + 1),
"trainable": [self.trainable] * (len(self.neuron) + 1),
"layer_name": None,
"use_aparam_as_mask": False,
"spin": None,
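For the fitting net, `trainable` may also arrive as a per-layer list (the TF-style argument documented in argcheck.py below); the PyTorch implementation collapses it to a single flag and expands it back in `serialize()`. A rough sketch of that round trip, with hypothetical helper names standing in for the real class methods:

```python
from typing import List, Union


def collapse_trainable(trainable: Union[bool, List[bool]]) -> bool:
    # Per-layer control is not supported yet: a list is reduced with all(),
    # so the net is trainable only when every entry is True.
    return all(trainable) if isinstance(trainable, list) else trainable


def serialize_trainable(trainable: bool, neuron: List[int]) -> List[bool]:
    # serialize() emits one flag per hidden layer plus one for the output layer.
    return [trainable] * (len(neuron) + 1)


print(collapse_trainable([True, False, True]))             # False
print(serialize_trainable(True, neuron=[240, 240, 240]))   # [True, True, True, True]
```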
3 changes: 0 additions & 3 deletions deepmd/pt/train/training.py
@@ -430,9 +430,6 @@ def get_loss(loss_params, start_lr, _ntypes):
frz_model = torch.jit.load(init_frz_model, map_location=DEVICE)
self.model.load_state_dict(frz_model.state_dict())

# Set trainable params
self.wrapper.set_trainable_params()

# Multi-task share params
if shared_links is not None:
self.wrapper.share_params(shared_links, resume=resuming or self.rank != 0)
22 changes: 0 additions & 22 deletions deepmd/pt/train/wrapper.py
@@ -60,28 +60,6 @@ def __init__(
self.loss[task_key] = loss[task_key]
self.inference_only = self.loss is None

def set_trainable_params(self):
supported_types = ["descriptor", "fitting_net"]
for model_item in self.model:
for net_type in supported_types:
trainable = True
if not self.multi_task:
if net_type in self.model_params:
trainable = self.model_params[net_type].get("trainable", True)
else:
if net_type in self.model_params["model_dict"][model_item]:
trainable = self.model_params["model_dict"][model_item][
net_type
].get("trainable", True)
if (
hasattr(self.model[model_item], net_type)
and getattr(self.model[model_item], net_type) is not None
):
for param in (
self.model[model_item].__getattr__(net_type).parameters()
):
param.requires_grad = trainable

def share_params(self, shared_links, resume=False):
"""
Share the parameters of classes following rules defined in shared_links during multitask training.
4 changes: 2 additions & 2 deletions deepmd/utils/argcheck.py
@@ -885,9 +885,9 @@ def fitting_ener():
doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.'
doc_precision = f"The precision of the fitting net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision."
doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection'
doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be\n\n\
doc_trainable = f"Whether the parameters in the fitting net are trainable. This option can be\n\n\
- bool: True if all parameters of the fitting net are trainable, False otherwise.\n\n\
- list of bool: Specifies whether each layer is trainable. Since the fitting net is composed of hidden layers followed by an output layer, the length of this list should be equal to len(`neuron`)+1."
- list of bool{doc_only_tf_supported}: Specifies whether each layer is trainable. Since the fitting net is composed of hidden layers followed by an output layer, the length of this list should be equal to len(`neuron`)+1."
doc_rcond = "The condition number used to determine the initial energy shift for each type of atoms. See `rcond` in :py:meth:`numpy.linalg.lstsq` for more details."
doc_seed = "Random seed for parameter initialization of the fitting net"
doc_atom_ener = "Specify the atomic energy in vacuum for each type"
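With the argument documented here, a user of the PyTorch backend can freeze either network from the training input. A sketch of an input fragment, written as a Python dict for brevity; only the keys relevant to this change are shown, and the surrounding keys (descriptor `type`, `neuron`, and the rest of a full input) are a typical assumption rather than something prescribed by this diff:

```python
# Fragment of a training input for the pt backend, as a Python dict.
input_fragment = {
    "model": {
        "descriptor": {
            "type": "se_e2_a",
            "trainable": False,  # freeze all descriptor parameters
        },
        "fitting_net": {
            "neuron": [240, 240, 240],
            "trainable": True,   # keep the fitting net trainable
        },
    },
}
```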
16 changes: 16 additions & 0 deletions source/tests/pt/test_training.py
@@ -10,6 +10,8 @@
Path,
)

import torch

from deepmd.pt.entrypoints.main import (
get_trainer,
)
@@ -28,6 +30,20 @@ def test_dp_train(self):
trainer.run()
self.tearDown()

def test_trainable(self):
fix_params = deepcopy(self.config)
fix_params["model"]["descriptor"]["trainable"] = False
fix_params["model"]["fitting_net"]["trainable"] = False
trainer_fix = get_trainer(fix_params)
model_dict_before_training = deepcopy(trainer_fix.model.state_dict())
trainer_fix.run()
model_dict_after_training = deepcopy(trainer_fix.model.state_dict())
for key in model_dict_before_training:
torch.testing.assert_allclose(
model_dict_before_training[key], model_dict_after_training[key]
)
self.tearDown()

def tearDown(self):
for f in os.listdir("."):
if f.startswith("model") and f.endswith(".pt"):
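The test above verifies freezing by comparing `state_dict()` before and after training. A quicker complementary check, shown only as a sketch and assuming a trainer built as in `test_trainable`, is to inspect `requires_grad` directly:

```python
import torch


def all_frozen(model: torch.nn.Module) -> bool:
    # True when every parameter has requires_grad == False.
    return all(not p.requires_grad for p in model.parameters())


# Hypothetical usage inside the test above, before trainer_fix.run():
#     assert all_frozen(trainer_fix.model)
```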
