
Commit

Merge branch 'devel' into add_trainable
Signed-off-by: Duo <[email protected]>
iProzd authored Mar 1, 2024
2 parents cbb7499 + ee8b82b commit 7940242
Showing 60 changed files with 1,628 additions and 418 deletions.
8 changes: 8 additions & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
@@ -127,6 +127,14 @@ def mixed_types(self):
"""
return any(descrpt.mixed_types() for descrpt in self.descrpt_list)

def share_params(self, base_class, shared_level, resume=False):
"""
Share the parameters of self with the base_class at the given shared_level during multitask training.
If not starting from a checkpoint (resume is False),
some separated parameters (e.g. mean and stddev) will be re-calculated across the different classes.
"""
raise NotImplementedError

def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
"""Update mean and stddev for descriptor elements."""
for descrpt in self.descrpt_list:
15 changes: 14 additions & 1 deletion deepmd/dpmodel/descriptor/make_base_descriptor.py
@@ -4,8 +4,10 @@
abstractmethod,
)
from typing import (
Callable,
List,
Optional,
Union,
)

from deepmd.common import (
@@ -84,8 +86,19 @@ def mixed_types(self) -> bool:
"""
pass

@abstractmethod
def share_params(self, base_class, shared_level, resume=False):
"""
Share the parameters of self with the base_class at the given shared_level during multitask training.
If not starting from a checkpoint (resume is False),
some separated parameters (e.g. mean and stddev) will be re-calculated across the different classes.
"""
pass

def compute_input_stats(
self, merged: List[dict], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError
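Note: in the dpmodel (NumPy) backend the new share_params hook is only declared — abstract here, raising NotImplementedError in the concrete descriptors above and below — and the actual sharing logic lives in the framework backends. As a rough, self-contained illustration of the contract the docstring describes (the class and attribute names below are hypothetical, not deepmd API), a toy descriptor might share its statistics like this:

import numpy as np


class ToyDescriptor:
    """Hypothetical stand-in, not deepmd API; illustrates the share_params contract."""

    def __init__(self, ntypes: int) -> None:
        # Statistics normally produced by compute_input_stats.
        self.mean = np.zeros(ntypes)
        self.stddev = np.ones(ntypes)

    def share_params(self, base_class, shared_level, resume=False):
        # Only full sharing (shared_level == 0) is sketched here.
        if shared_level != 0:
            raise NotImplementedError
        # Point this task's statistics at the base class's arrays so that both
        # tasks reference one set of parameters.  When resume is False the base
        # class is expected to hold freshly re-calculated statistics; when
        # resuming from a checkpoint the stored statistics are simply reused.
        self.mean = base_class.mean
        self.stddev = base_class.stddev


base, other = ToyDescriptor(2), ToyDescriptor(2)
other.share_params(base, shared_level=0)
assert other.mean is base.mean  # both tasks now see the same statistics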
9 changes: 9 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
@@ -46,6 +46,7 @@


@BaseDescriptor.register("se_e2_a")
@BaseDescriptor.register("se_a")
class DescrptSeA(NativeOP, BaseDescriptor):
r"""DeepPot-SE constructed from all information (both angular and radial) of
atomic configurations. The embedding takes the distance between atoms as input.
@@ -243,6 +244,14 @@ def mixed_types(self):
"""
return False

def share_params(self, base_class, shared_level, resume=False):
"""
Share the parameters of self with the base_class at the given shared_level during multitask training.
If not starting from a checkpoint (resume is False),
some separated parameters (e.g. mean and stddev) will be re-calculated across the different classes.
"""
raise NotImplementedError

def get_ntypes(self) -> int:
"""Returns the number of element types."""
return self.ntypes
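The stacked registrations above make "se_a" a working alias of "se_e2_a": each register(name) call records the class and returns it unchanged, so the decorators compose. A minimal hedged sketch of that pattern — the real deepmd plugin registry is more elaborate, and the names below are illustrative only:

from typing import Callable, Dict, Type

_REGISTRY: Dict[str, Type] = {}


def register(name: str) -> Callable[[Type], Type]:
    """Record the decorated class under `name` and return it unchanged."""

    def decorator(cls: Type) -> Type:
        _REGISTRY[name] = cls
        return cls

    return decorator


@register("se_e2_a")
@register("se_a")          # legacy alias pointing at the same class
class DescrptSeAExample:   # hypothetical stand-in for DescrptSeA
    pass


assert _REGISTRY["se_a"] is _REGISTRY["se_e2_a"]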
8 changes: 8 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
@@ -203,6 +203,14 @@ def mixed_types(self):
"""
return False

def share_params(self, base_class, shared_level, resume=False):
"""
Share the parameters of self with the base_class at the given shared_level during multitask training.
If not starting from a checkpoint (resume is False),
some separated parameters (e.g. mean and stddev) will be re-calculated across the different classes.
"""
raise NotImplementedError

def get_ntypes(self) -> int:
"""Returns the number of element types."""
return self.ntypes
33 changes: 32 additions & 1 deletion deepmd/dpmodel/fitting/general_fitting.py
@@ -73,7 +73,10 @@ class GeneralFitting(NativeOP, BaseFitting):
different fitting nets for different atom types.
exclude_types: List[int]
Atomic contributions of the excluded atom types are set zero.
remove_vaccum_contribution: List[bool], optional
Remove the vacuum contribution before the bias is added. The list is assigned per
atom type. For `mixed_types`, provide `[True]`; otherwise it should be a list of the
same length as `ntypes`, indicating whether to remove the vacuum contribution for each atom type.
"""

def __init__(
@@ -95,6 +98,7 @@ def __init__(
spin: Any = None,
mixed_types: bool = True,
exclude_types: List[int] = [],
remove_vaccum_contribution: Optional[List[bool]] = None,
):
self.var_name = var_name
self.ntypes = ntypes
@@ -119,6 +123,7 @@
self.exclude_types = exclude_types
if self.spin is not None:
raise NotImplementedError("spin is not supported")
self.remove_vaccum_contribution = remove_vaccum_contribution

self.emask = AtomExcludeMask(self.ntypes, self.exclude_types)

@@ -298,6 +303,14 @@ def _call_common(
"which is not consistent with {self.dim_descrpt}."
)
xx = descriptor
if self.remove_vaccum_contribution is not None:
# TODO: Ideally, the input for vacuum should be computed;
# we consider it as always zero for convenience.
# Needs a compute_input_stats for vacuum passed from the
# descriptor.
xx_zeros = np.zeros_like(xx)
else:
xx_zeros = None
# check fparam dim, concatenate to input descriptor
if self.numb_fparam > 0:
assert fparam is not None, "fparam should not be None"
@@ -312,6 +325,11 @@
[xx, fparam],
axis=-1,
)
if xx_zeros is not None:
xx_zeros = np.concatenate(
[xx_zeros, fparam],
axis=-1,
)
# check aparam dim, concatenate to input descriptor
if self.numb_aparam > 0:
assert aparam is not None, "aparam should not be None"
@@ -326,6 +344,11 @@
[xx, aparam],
axis=-1,
)
if xx_zeros is not None:
xx_zeros = np.concatenate(
[xx_zeros, aparam],
axis=-1,
)

# calculate the prediction
if not self.mixed_types:
@@ -335,11 +358,19 @@
(atype == type_i).reshape([nf, nloc, 1]), [1, 1, net_dim_out]
)
atom_property = self.nets[(type_i,)](xx)
if self.remove_vaccum_contribution is not None and not (
len(self.remove_vaccum_contribution) > type_i
and not self.remove_vaccum_contribution[type_i]
):
assert xx_zeros is not None
atom_property -= self.nets[(type_i,)](xx_zeros)
atom_property = atom_property + self.bias_atom_e[type_i]
atom_property = atom_property * mask
outs = outs + atom_property # Shape is [nframes, natoms[0], 1]
else:
outs = self.nets[()](xx) + self.bias_atom_e[atype]
if xx_zeros is not None:
outs -= self.nets[()](xx_zeros)
# nf x nloc
exclude_mask = self.emask.build_type_exclude_mask(atype)
# nf x nloc x nod
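To see what the remove_vaccum_contribution branch computes: the all-zero descriptor stands in for the (uncomputed) vacuum input, its network output is subtracted, and only then is the per-type bias added, so an empty environment maps to the bias alone. A minimal hedged NumPy sketch with a toy affine "net" in place of the real fitting network:

import numpy as np

rng = np.random.default_rng(0)
nloc, dim_descrpt, dim_out = 4, 8, 1

# Toy affine stand-in for one fitting-network branch.
w = rng.normal(size=(dim_descrpt, dim_out))
b = rng.normal(size=(dim_out,))


def net(x):
    return x @ w + b


xx = rng.normal(size=(nloc, dim_descrpt))  # descriptor output for nloc atoms
xx_zeros = np.zeros_like(xx)               # zeros approximate the vacuum input
bias_atom_e = 0.5                          # per-type energy bias (scalar here)

outs = net(xx) - net(xx_zeros) + bias_atom_e
# A zero descriptor now yields exactly the bias:
assert np.allclose(net(xx_zeros) - net(xx_zeros) + bias_atom_e, bias_atom_e)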
5 changes: 3 additions & 2 deletions deepmd/dpmodel/fitting/invar_fitting.py
@@ -136,8 +136,6 @@ def __init__(
raise NotImplementedError("use_aparam_as_mask is not implemented")
if layer_name is not None:
raise NotImplementedError("layer_name is not implemented")
if atom_ener is not None and atom_ener != []:
raise NotImplementedError("atom_ener is not implemented")

self.dim_out = dim_out
self.atom_ener = atom_ener
@@ -159,6 +157,9 @@
spin=spin,
mixed_types=mixed_types,
exclude_types=exclude_types,
remove_vaccum_contribution=None
if atom_ener is None or len([x for x in atom_ener if x is not None]) == 0
else [x is not None for x in atom_ener],
)

def serialize(self) -> dict:
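The ternary above turns atom_ener into the per-type mask expected by remove_vaccum_contribution: None (or a list with no concrete entries) disables the feature, and otherwise each type is flagged by whether its atom_ener entry was set. A small illustration of that expression, with arbitrarily chosen values:

def to_mask(atom_ener):
    # Mirrors the expression passed to GeneralFitting.__init__ above.
    return (
        None
        if atom_ener is None or len([x for x in atom_ener if x is not None]) == 0
        else [x is not None for x in atom_ener]
    )


print(to_mask(None))               # None  -> feature disabled
print(to_mask([None, None]))       # None  -> no concrete atom_ener given
print(to_mask([None, 0.0, -1.5]))  # [False, True, True]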
1 change: 1 addition & 0 deletions deepmd/dpmodel/model/dp_model.py
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

from deepmd.dpmodel.atomic_model import (
DPAtomicModel,
)
2 changes: 1 addition & 1 deletion deepmd/main.py
@@ -240,7 +240,7 @@ def main_parser() -> argparse.ArgumentParser:
"--output",
type=str,
default="out.json",
help="(Supported backend: TensorFlow) The output file of the parameters used in training.",
help="The output file of the parameters used in training.",
)
parser_train.add_argument(
"--skip-neighbor-stat",
82 changes: 31 additions & 51 deletions deepmd/pt/entrypoints/main.py
@@ -53,9 +53,6 @@
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
)
from deepmd.pt.utils.stat import (
make_stat_input,
)
from deepmd.utils.argcheck import (
normalize,
)
@@ -81,10 +78,6 @@ def get_trainer(
shared_links=None,
):
multi_task = "model_dict" in config.get("model", {})
# argcheck
if not multi_task:
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config)

# Initialize DDP
local_rank = os.environ.get("LOCAL_RANK")
@@ -104,36 +97,23 @@
config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None)

def prepare_trainer_input_single(
model_params_single, data_dict_single, loss_dict_single, suffix=""
model_params_single, data_dict_single, loss_dict_single, suffix="", rank=0
):
training_dataset_params = data_dict_single["training_data"]
type_split = False
if model_params_single["descriptor"]["type"] in ["se_e2_a"]:
type_split = True
validation_dataset_params = data_dict_single["validation_data"]
validation_dataset_params = data_dict_single.get("validation_data", None)
validation_systems = (
validation_dataset_params["systems"] if validation_dataset_params else None
)
training_systems = training_dataset_params["systems"]
validation_systems = validation_dataset_params["systems"]

# noise params
noise_settings = None
if loss_dict_single.get("type", "ener") == "denoise":
noise_settings = {
"noise_type": loss_dict_single.pop("noise_type", "uniform"),
"noise": loss_dict_single.pop("noise", 1.0),
"noise_mode": loss_dict_single.pop("noise_mode", "fix_num"),
"mask_num": loss_dict_single.pop("mask_num", 8),
"mask_prob": loss_dict_single.pop("mask_prob", 0.15),
"same_mask": loss_dict_single.pop("same_mask", False),
"mask_coord": loss_dict_single.pop("mask_coord", False),
"mask_type": loss_dict_single.pop("mask_type", False),
"max_fail_num": loss_dict_single.pop("max_fail_num", 10),
"mask_type_idx": len(model_params_single["type_map"]) - 1,
}
# noise_settings = None

# stat files
stat_file_path_single = data_dict_single.get("stat_file", None)
if stat_file_path_single is not None:
if rank != 0:
stat_file_path_single = None
elif stat_file_path_single is not None:
if Path(stat_file_path_single).is_dir():
raise ValueError(
f"stat_file should be a file, not a directory: {stat_file_path_single}"
@@ -144,71 +124,63 @@ def prepare_trainer_input_single(
stat_file_path_single = DPPath(stat_file_path_single, "a")

# validation and training data
validation_data_single = DpLoaderSet(
validation_systems,
validation_dataset_params["batch_size"],
model_params_single,
validation_data_single = (
DpLoaderSet(
validation_systems,
validation_dataset_params["batch_size"],
model_params_single,
)
if validation_systems
else None
)
if ckpt or finetune_model:
train_data_single = DpLoaderSet(
training_systems,
training_dataset_params["batch_size"],
model_params_single,
)
sampled_single = None
else:
train_data_single = DpLoaderSet(
training_systems,
training_dataset_params["batch_size"],
model_params_single,
)
data_stat_nbatch = model_params_single.get("data_stat_nbatch", 10)
sampled_single = make_stat_input(
train_data_single.systems,
train_data_single.dataloaders,
data_stat_nbatch,
)
if noise_settings is not None:
train_data_single = DpLoaderSet(
training_systems,
training_dataset_params["batch_size"],
model_params_single,
)
return (
train_data_single,
validation_data_single,
sampled_single,
stat_file_path_single,
)

rank = dist.get_rank() if dist.is_initialized() else 0
if not multi_task:
(
train_data,
validation_data,
sampled,
stat_file_path,
) = prepare_trainer_input_single(
config["model"], config["training"], config["loss"]
config["model"],
config["training"],
config["loss"],
rank=rank,
)
else:
train_data, validation_data, sampled, stat_file_path = {}, {}, {}, {}
train_data, validation_data, stat_file_path = {}, {}, {}
for model_key in config["model"]["model_dict"]:
(
train_data[model_key],
validation_data[model_key],
sampled[model_key],
stat_file_path[model_key],
) = prepare_trainer_input_single(
config["model"]["model_dict"][model_key],
config["training"]["data_dict"][model_key],
config["loss_dict"][model_key],
suffix=f"_{model_key}",
rank=rank,
)

trainer = training.Trainer(
config,
train_data,
sampled=sampled,
stat_file_path=stat_file_path,
validation_data=validation_data,
init_model=init_model,
@@ -260,6 +232,11 @@ def train(FLAGS):
if multi_task:
config["model"], shared_links = preprocess_shared_params(config["model"])

# argcheck
if not multi_task:
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config)

# do neighbor stat
if not FLAGS.skip_neighbor_stat:
log.info(
@@ -281,6 +258,9 @@
fake_global_jdata, config["model"]["model_dict"][model_item]
)

with open(FLAGS.output, "w") as fp:
json.dump(config, fp, indent=4)

trainer = get_trainer(
config,
FLAGS.init_model,
(The diffs for the remaining changed files are not shown.)

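Summarizing the deepmd/pt/entrypoints/main.py changes above: argument checking and the dump of the resolved config to FLAGS.output move into train(), the pre-sampled statistics (sampled / make_stat_input) are dropped from trainer construction, validation data becomes optional, and the statistics file is only attached on rank 0. A condensed hedged sketch of the last two points — the helper below simplifies the diff's prepare_trainer_input_single and is not the actual function:

import torch.distributed as dist


def resolve_inputs(data_dict_single: dict, rank: int):
    """Condensed sketch of the rank gating and optional validation handling."""
    # Validation data is now optional: a missing key means no validation loader.
    validation_dataset_params = data_dict_single.get("validation_data", None)
    validation_systems = (
        validation_dataset_params["systems"] if validation_dataset_params else None
    )
    # Only rank 0 keeps a stat_file path, mirroring `if rank != 0: ... = None` above.
    stat_file_path = data_dict_single.get("stat_file", None)
    if rank != 0:
        stat_file_path = None
    return validation_systems, stat_file_path


rank = dist.get_rank() if dist.is_initialized() else 0
print(resolve_inputs({"training_data": {"systems": ["sys0"]}, "stat_file": "stats.hdf5"}, rank))
# -> (None, "stats.hdf5") on rank 0; (None, None) on every other rank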