pt: refact training code #3359

Merged 47 commits on Mar 1, 2024

Changes shown below are from 10 of the 47 commits.

Commits (47)
3812866  Fix single-task training&data stat  (iProzd, Feb 28, 2024)
08e18fe  Merge branch 'devel' into train_rf  (iProzd, Feb 28, 2024)
ae27607  Fix EnergyFittingNetDirect  (iProzd, Feb 28, 2024)
7f573ab  Merge branch 'devel' into train_rf  (iProzd, Feb 28, 2024)
f9265d5  Add data_requirement for dataloader  (iProzd, Feb 28, 2024)
f8d2980  Merge branch 'devel' into train_rf  (iProzd, Feb 28, 2024)
c9eb767  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Feb 28, 2024)
00105c7  Update make_base_descriptor.py  (iProzd, Feb 28, 2024)
5a9df83  Update typing  (iProzd, Feb 28, 2024)
75da5b1  Update training.py  (iProzd, Feb 28, 2024)
6c171c5  Fix uts  (iProzd, Feb 28, 2024)
2e87e1d  Fix uts  (iProzd, Feb 28, 2024)
eb8094d  Merge branch 'devel' into train_rf  (iProzd, Feb 28, 2024)
2618d98  Support multi-task training  (iProzd, Feb 28, 2024)
f1585b2  Take advice from QL scan  (iProzd, Feb 28, 2024)
463f9fb  Support no validation  (iProzd, Feb 28, 2024)
e8575af  Update se_r.py  (iProzd, Feb 28, 2024)
66d03b8  omit data prob log  (iProzd, Feb 28, 2024)
e9e0d95  omit seed log  (iProzd, Feb 28, 2024)
90be50e  Merge branch 'devel' into train_rf  (iProzd, Feb 29, 2024)
ab35653  Add fparam and aparam  (iProzd, Feb 29, 2024)
64d6079  Add type hint for `Callable`  (iProzd, Feb 29, 2024)
6020a2b  Fix nopbc  (iProzd, Feb 29, 2024)
5db7883  Add DataRequirementItem  (iProzd, Feb 29, 2024)
c03a5ba  Merge branch 'devel' into train_rf  (iProzd, Feb 29, 2024)
cce52da  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Feb 29, 2024)
18cbf9e  Merge branch 'devel' into train_rf  (iProzd, Feb 29, 2024)
cdcfcb2  Fix neighbor-stat for multitask (#31)  (iProzd, Feb 29, 2024)
a7d44d1  Revert "Fix neighbor-stat for multitask (#31)"  (iProzd, Feb 29, 2024)
fdca653  Move label requirement to loss func  (iProzd, Feb 29, 2024)
525ce93  resolve conversations  (iProzd, Feb 29, 2024)
46ee16c  set label_requirement abstractmethod  (iProzd, Feb 29, 2024)
9d18dc4  make label_requirement dynamic  (iProzd, Feb 29, 2024)
ad7227d  update docs  (iProzd, Feb 29, 2024)
35598d2  replace lazy with functools.lru_cache  (iProzd, Feb 29, 2024)
c0a0cfc  Update training.py  (iProzd, Feb 29, 2024)
d50e2a2  Merge branch 'devel' into train_rf  (iProzd, Feb 29, 2024)
66edca5  Update deepmd/pt/train/training.py  (wanghan-iapcm, Feb 29, 2024)
d5a1549  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Feb 29, 2024)
c51f865  Merge branch 'devel' into train_rf  (iProzd, Feb 29, 2024)
e17546a  Update test_multitask.py  (iProzd, Feb 29, 2024)
1debf4f  Fix h5py files in multitask DDP  (iProzd, Feb 29, 2024)
db31edc  FIx h5py file read block  (iProzd, Feb 29, 2024)
60dda49  Merge branch 'devel' into train_rf  (iProzd, Mar 1, 2024)
3dfc31e  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Mar 1, 2024)
615446f  Update hybrid.py  (iProzd, Mar 1, 2024)
e26c118  Update hybrid.py  (iProzd, Mar 1, 2024)
4 changes: 3 additions & 1 deletion deepmd/dpmodel/descriptor/make_base_descriptor.py
@@ -4,8 +4,10 @@
     abstractmethod,
 )
 from typing import (
+    Callable,
     List,
     Optional,
+    Union,
 )

 from deepmd.common import (
@@ -85,7 +87,7 @@ def mixed_types(self) -> bool:
         pass

     def compute_input_stats(
-        self, merged: List[dict], path: Optional[DPPath] = None
+        self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
     ):
         """Update mean and stddev for descriptor elements."""
         raise NotImplementedError
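With this signature a caller can hand the descriptor either an already-sampled list of frames or a zero-argument callable that produces one on demand. A minimal sketch of that dispatch, mirroring the pattern added elsewhere in this PR (the toy frame data below is hypothetical, not deepmd-kit's actual batch layout):

    from typing import Callable, List, Union


    def _resolve_merged(merged: Union[Callable, List[dict]]) -> List[dict]:
        # A lazy callable is only evaluated when statistics actually need to be
        # computed; a plain list is used as-is.
        return merged() if callable(merged) else merged


    # Hypothetical usage: both call styles resolve to the same data.
    _frames = [{"coord": [[0.0, 0.0, 0.0]], "atype": [0]}]
    assert _resolve_merged(_frames) == _frames
    assert _resolve_merged(lambda: _frames) == _frames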
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/base_model.py
@@ -92,6 +92,10 @@ def is_aparam_nall(self) -> bool:
     def model_output_type(self) -> str:
         """Get the output type for the model."""

+    @abstractmethod
+    def data_requirement(self) -> dict:
+        """Get the data requirement for the model."""
+
     @abstractmethod
     def serialize(self) -> dict:
         """Serialize the model.
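A concrete model would override this to declare which labels it needs from the dataset. The sketch below only illustrates the shape of such a dict; the actual keys and fields used by deepmd-kit (see the DataRequirementItem and label_requirement commits) may differ:

    class ToyEnergyModel:
        # Hypothetical: each entry describes one label the dataloader must provide.
        def data_requirement(self) -> dict:
            return {
                "energy": {"ndof": 1, "atomic": False, "must": False, "high_prec": True},
                "force": {"ndof": 3, "atomic": True, "must": False, "high_prec": False},
            }


    print(ToyEnergyModel().data_requirement())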
41 changes: 1 addition & 40 deletions deepmd/pt/entrypoints/main.py
@@ -47,9 +47,6 @@
 from deepmd.pt.utils.multi_task import (
     preprocess_shared_params,
 )
-from deepmd.pt.utils.stat import (
-    make_stat_input,
-)
 from deepmd.utils.path import (
     DPPath,
 )
@@ -83,7 +80,6 @@
         multi_task=multi_task,
         model_branch=model_branch,
     )
-    config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None)
     shared_links = None
     if multi_task:
         config["model"], shared_links = preprocess_shared_params(config["model"])
@@ -98,24 +94,6 @@
         validation_dataset_params = data_dict_single["validation_data"]
         training_systems = training_dataset_params["systems"]
         validation_systems = validation_dataset_params["systems"]
-
-        # noise params
-        noise_settings = None
-        if loss_dict_single.get("type", "ener") == "denoise":
-            noise_settings = {
-                "noise_type": loss_dict_single.pop("noise_type", "uniform"),
-                "noise": loss_dict_single.pop("noise", 1.0),
-                "noise_mode": loss_dict_single.pop("noise_mode", "fix_num"),
-                "mask_num": loss_dict_single.pop("mask_num", 8),
-                "mask_prob": loss_dict_single.pop("mask_prob", 0.15),
-                "same_mask": loss_dict_single.pop("same_mask", False),
-                "mask_coord": loss_dict_single.pop("mask_coord", False),
-                "mask_type": loss_dict_single.pop("mask_type", False),
-                "max_fail_num": loss_dict_single.pop("max_fail_num", 10),
-                "mask_type_idx": len(model_params_single["type_map"]) - 1,
-            }
-        # noise_settings = None
-
         # stat files
         stat_file_path_single = data_dict_single.get("stat_file", None)
         if stat_file_path_single is not None:
@@ -140,48 +118,32 @@
                 training_dataset_params["batch_size"],
                 model_params_single,
             )
-            sampled_single = None
         else:
             train_data_single = DpLoaderSet(
                 training_systems,
                 training_dataset_params["batch_size"],
                 model_params_single,
             )
-            data_stat_nbatch = model_params_single.get("data_stat_nbatch", 10)
-            sampled_single = make_stat_input(
-                train_data_single.systems,
-                train_data_single.dataloaders,
-                data_stat_nbatch,
-            )
-            if noise_settings is not None:
-                train_data_single = DpLoaderSet(
-                    training_systems,
-                    training_dataset_params["batch_size"],
-                    model_params_single,
-                )
         return (
             train_data_single,
             validation_data_single,
-            sampled_single,
             stat_file_path_single,
         )

     if not multi_task:
         (
             train_data,
             validation_data,
-            sampled,
             stat_file_path,
         ) = prepare_trainer_input_single(
             config["model"], config["training"], config["loss"]
         )
     else:
-        train_data, validation_data, sampled, stat_file_path = {}, {}, {}, {}
+        train_data, validation_data, stat_file_path = {}, {}, {}
         for model_key in config["model"]["model_dict"]:
             (
                 train_data[model_key],
                 validation_data[model_key],
-                sampled[model_key],
                 stat_file_path[model_key],
             ) = prepare_trainer_input_single(
                 config["model"]["model_dict"][model_key],
@@ -193,7 +155,6 @@
     trainer = training.Trainer(
         config,
         train_data,
-        sampled=sampled,
         stat_file_path=stat_file_path,
         validation_data=validation_data,
         init_model=init_model,
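With the eager make_stat_input call removed from input preparation, statistics data can instead be drawn lazily and at most once, in the spirit of the "replace lazy with functools.lru_cache" commit. A rough sketch under that assumption (build_lazy_sampler and its arguments are placeholders, not functions added by this PR):

    from functools import lru_cache


    def build_lazy_sampler(systems, dataloaders, nbatches=10):
        # dataloaders is a placeholder here; the real helper would draw
        # nbatches batches per system from the training dataloaders.
        @lru_cache(maxsize=1)
        def sampled_func():
            return [{"system": s, "nbatches": nbatches} for s in systems]

        return sampled_func


    sampler = build_lazy_sampler(["sys_A", "sys_B"], dataloaders=None)
    frames = sampler()        # the expensive pass over the data happens here
    frames_again = sampler()  # cached by lru_cache; nothing is re-read
    assert frames is frames_again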
17 changes: 5 additions & 12 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
@@ -18,9 +18,6 @@
 from deepmd.pt.model.task.base_fitting import (
     BaseFitting,
 )
-from deepmd.pt.utils.utils import (
-    dict_to_device,
-)
 from deepmd.utils.path import (
     DPPath,
 )
@@ -170,7 +167,7 @@

     def compute_or_load_stat(
         self,
-        sampled,
+        sampled_func,
         stat_file_path: Optional[DPPath] = None,
     ):
         """
@@ -183,22 +180,18 @@

         Parameters
         ----------
-        sampled
-            The sampled data frames from different data systems.
+        sampled_func
+            The lazy sampled function to get data frames from different data systems.
         stat_file_path
             The dictionary of paths to the statistics files.
         """
         if stat_file_path is not None and self.type_map is not None:
             # descriptors and fitting net with different type_map
             # should not share the same parameters
             stat_file_path /= " ".join(self.type_map)
-        for data_sys in sampled:
-            dict_to_device(data_sys)
-        if sampled is None:
-            sampled = []
-        self.descriptor.compute_input_stats(sampled, stat_file_path)
+        self.descriptor.compute_input_stats(sampled_func, stat_file_path)
         if self.fitting_net is not None:
-            self.fitting_net.compute_output_stats(sampled, stat_file_path)
+            self.fitting_net.compute_output_stats(sampled_func, stat_file_path)

     @torch.jit.export
     def get_dim_fparam(self) -> int:
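A hedged sketch of what a sampled_func handed to compute_or_load_stat might look like: the systems and frames are toy stand-ins, and the commented call only shows the assumed wiring, not the exact trainer code:

    def sampled_func():
        # Minimal stand-in for "get data frames from different data systems":
        # a couple of frames per toy system, drawn only when actually called.
        systems = {"water": [{"frame": 0}, {"frame": 1}], "ice": [{"frame": 0}]}
        return [frame for frames in systems.values() for frame in frames]


    # Hypothetical call; the model forwards the same callable to its descriptor
    # and fitting net, which may skip calling it if cached statistics exist.
    # model.compute_or_load_stat(sampled_func, stat_file_path)
    print(len(sampled_func()))  # 3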
6 changes: 5 additions & 1 deletion deepmd/pt/model/descriptor/descriptor.py
@@ -5,9 +5,11 @@
     abstractmethod,
 )
 from typing import (
+    Callable,
     Dict,
     List,
     Optional,
+    Union,
 )

 import torch
@@ -86,7 +88,9 @@
         """Returns the embedding dimension."""
         pass

-    def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
+    def compute_input_stats(
+        self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
+    ):
         """Update mean and stddev for DescriptorBlock elements."""
         raise NotImplementedError
6 changes: 5 additions & 1 deletion deepmd/pt/model/descriptor/dpa1.py
@@ -1,7 +1,9 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from typing import (
+    Callable,
     List,
     Optional,
+    Union,
 )

 import torch
@@ -128,7 +130,9 @@
     def dim_emb(self):
         return self.get_dim_emb()

-    def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
+    def compute_input_stats(
+        self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
+    ):
         return self.se_atten.compute_input_stats(merged, path)

     def serialize(self) -> dict:
15 changes: 6 additions & 9 deletions deepmd/pt/model/descriptor/dpa2.py
@@ -1,7 +1,9 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from typing import (
+    Callable,
     List,
     Optional,
+    Union,
 )

 import torch
@@ -295,16 +297,11 @@
         """Returns the embedding dimension g2."""
         return self.get_dim_emb()

-    def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
+    def compute_input_stats(
+        self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
+    ):
         for ii, descrpt in enumerate([self.repinit, self.repformers]):
-            merged_tmp = [
-                {
-                    key: item[key] if not isinstance(item[key], list) else item[key][ii]
-                    for key in item
-                }
-                for item in merged
-            ]
-            descrpt.compute_input_stats(merged_tmp, path)
+            descrpt.compute_input_stats(merged, path)

     def serialize(self) -> dict:
         """Serialize the obj to dict."""
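For reference, the removed comprehension gave each sub-descriptor (repinit, repformers) its own slice of list-valued entries while sharing scalar ones; after this change both receive merged unchanged. A toy reproduction of the old splitting, with illustrative keys and values only:

    # Toy frame: "coord" is shared, "nlist" has one entry per sub-descriptor.
    merged = [{"coord": "shared", "nlist": ["nlist_repinit", "nlist_repformers"]}]

    for ii, name in enumerate(["repinit", "repformers"]):
        merged_tmp = [
            {key: item[key] if not isinstance(item[key], list) else item[key][ii] for key in item}
            for item in merged
        ]
        print(name, merged_tmp)
    # repinit    [{'coord': 'shared', 'nlist': 'nlist_repinit'}]
    # repformers [{'coord': 'shared', 'nlist': 'nlist_repformers'}]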
16 changes: 7 additions & 9 deletions deepmd/pt/model/descriptor/hybrid.py
@@ -1,7 +1,9 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from typing import (
+    Callable,
     List,
     Optional,
+    Union,
 )

 import torch
@@ -157,17 +159,13 @@
         else:
             raise NotImplementedError

-    def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
+    def compute_input_stats(
+        self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
+    ):
         """Update mean and stddev for descriptor elements."""
         for ii, descrpt in enumerate(self.descriptor_list):
-            merged_tmp = [
-                {
-                    key: item[key] if not isinstance(item[key], list) else item[key][ii]
-                    for key in item
-                }
-                for item in merged
-            ]
-            descrpt.compute_input_stats(merged_tmp, path)
+            # need support for hybrid descriptors
+            descrpt.compute_input_stats(merged, path)

     def forward(
         self,
16 changes: 14 additions & 2 deletions deepmd/pt/model/descriptor/repformers.py
@@ -1,8 +1,10 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from typing import (
+    Callable,
     Dict,
     List,
     Optional,
+    Union,
 )

 import torch
@@ -278,12 +280,22 @@

         return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw

-    def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
+    def compute_input_stats(
+        self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
+    ):
         """Update mean and stddev for descriptor elements."""
         env_mat_stat = EnvMatStatSe(self)
         if path is not None:
             path = path / env_mat_stat.get_hash()
-        env_mat_stat.load_or_compute_stats(merged, path)
+        if path is None or not path.is_dir():
+            if callable(merged):
+                # only get data for once
+                sampled = merged()
+            else:
+                sampled = merged
+        else:
+            sampled = []
+        env_mat_stat.load_or_compute_stats(sampled, path)
         self.stats = env_mat_stat.stats
         mean, stddev = env_mat_stat()
         if not self.set_davg_zero:
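The added guard means merged() is only evaluated when no statistics are cached under path. A toy version of that control flow, using pathlib in place of DPPath:

    from pathlib import Path
    from typing import Callable, List, Optional, Union


    def resolve_samples(
        merged: Union[Callable[[], List[dict]], List[dict]],
        path: Optional[Path] = None,
    ) -> List[dict]:
        # Mirror of the pattern above: skip sampling when cached stats exist.
        if path is None or not path.is_dir():
            return merged() if callable(merged) else merged
        return []  # stats will be loaded from path; no frames needed


    print(len(resolve_samples(lambda: [{"frame": 0}])))     # 1: data is sampled
    print(len(resolve_samples([{"frame": 0}], Path("."))))  # 0: "." exists, load from cache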
20 changes: 17 additions & 3 deletions deepmd/pt/model/descriptor/se_a.py
@@ -1,11 +1,13 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 import itertools
 from typing import (
+    Callable,
     ClassVar,
     Dict,
     List,
     Optional,
     Tuple,
+    Union,
 )

 import numpy as np
@@ -129,7 +131,9 @@
         """Returns the output dimension of this descriptor."""
         return self.sea.dim_out

-    def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
+    def compute_input_stats(
+        self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
+    ):
         """Update mean and stddev for descriptor elements."""
         return self.sea.compute_input_stats(merged, path)

@@ -387,12 +391,22 @@
         else:
             raise KeyError(key)

-    def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
+    def compute_input_stats(
+        self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
+    ):
         """Update mean and stddev for descriptor elements."""
         env_mat_stat = EnvMatStatSe(self)
         if path is not None:
             path = path / env_mat_stat.get_hash()
-        env_mat_stat.load_or_compute_stats(merged, path)
+        if path is None or not path.is_dir():
+            if callable(merged):
+                # only get data for once
+                sampled = merged()
+            else:
+                sampled = merged
+        else:
+            sampled = []
+        env_mat_stat.load_or_compute_stats(sampled, path)
         self.stats = env_mat_stat.stats
         mean, stddev = env_mat_stat()
         if not self.set_davg_zero: