pt: refact training code #3359

Merged · 47 commits · Mar 1, 2024
Changes from 30 commits

Commits
3812866  Fix single-task training&data stat (iProzd, Feb 28, 2024)
08e18fe  Merge branch 'devel' into train_rf (iProzd, Feb 28, 2024)
ae27607  Fix EnergyFittingNetDirect (iProzd, Feb 28, 2024)
7f573ab  Merge branch 'devel' into train_rf (iProzd, Feb 28, 2024)
f9265d5  Add data_requirement for dataloader (iProzd, Feb 28, 2024)
f8d2980  Merge branch 'devel' into train_rf (iProzd, Feb 28, 2024)
c9eb767  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 28, 2024)
00105c7  Update make_base_descriptor.py (iProzd, Feb 28, 2024)
5a9df83  Update typing (iProzd, Feb 28, 2024)
75da5b1  Update training.py (iProzd, Feb 28, 2024)
6c171c5  Fix uts (iProzd, Feb 28, 2024)
2e87e1d  Fix uts (iProzd, Feb 28, 2024)
eb8094d  Merge branch 'devel' into train_rf (iProzd, Feb 28, 2024)
2618d98  Support multi-task training (iProzd, Feb 28, 2024)
f1585b2  Take advice from QL scan (iProzd, Feb 28, 2024)
463f9fb  Support no validation (iProzd, Feb 28, 2024)
e8575af  Update se_r.py (iProzd, Feb 28, 2024)
66d03b8  omit data prob log (iProzd, Feb 28, 2024)
e9e0d95  omit seed log (iProzd, Feb 28, 2024)
90be50e  Merge branch 'devel' into train_rf (iProzd, Feb 29, 2024)
ab35653  Add fparam and aparam (iProzd, Feb 29, 2024)
64d6079  Add type hint for `Callable` (iProzd, Feb 29, 2024)
6020a2b  Fix nopbc (iProzd, Feb 29, 2024)
5db7883  Add DataRequirementItem (iProzd, Feb 29, 2024)
c03a5ba  Merge branch 'devel' into train_rf (iProzd, Feb 29, 2024)
cce52da  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 29, 2024)
18cbf9e  Merge branch 'devel' into train_rf (iProzd, Feb 29, 2024)
cdcfcb2  Fix neighbor-stat for multitask (#31) (iProzd, Feb 29, 2024)
a7d44d1  Revert "Fix neighbor-stat for multitask (#31)" (iProzd, Feb 29, 2024)
fdca653  Move label requirement to loss func (iProzd, Feb 29, 2024)
525ce93  resolve conversations (iProzd, Feb 29, 2024)
46ee16c  set label_requirement abstractmethod (iProzd, Feb 29, 2024)
9d18dc4  make label_requirement dynamic (iProzd, Feb 29, 2024)
ad7227d  update docs (iProzd, Feb 29, 2024)
35598d2  replace lazy with functools.lru_cache (iProzd, Feb 29, 2024)
c0a0cfc  Update training.py (iProzd, Feb 29, 2024)
d50e2a2  Merge branch 'devel' into train_rf (iProzd, Feb 29, 2024)
66edca5  Update deepmd/pt/train/training.py (wanghan-iapcm, Feb 29, 2024)
d5a1549  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 29, 2024)
c51f865  Merge branch 'devel' into train_rf (iProzd, Feb 29, 2024)
e17546a  Update test_multitask.py (iProzd, Feb 29, 2024)
1debf4f  Fix h5py files in multitask DDP (iProzd, Feb 29, 2024)
db31edc  FIx h5py file read block (iProzd, Feb 29, 2024)
60dda49  Merge branch 'devel' into train_rf (iProzd, Mar 1, 2024)
3dfc31e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 1, 2024)
615446f  Update hybrid.py (iProzd, Mar 1, 2024)
e26c118  Update hybrid.py (iProzd, Mar 1, 2024)
11 changes: 10 additions & 1 deletion deepmd/dpmodel/descriptor/make_base_descriptor.py
@@ -4,8 +4,10 @@
     abstractmethod,
 )
 from typing import (
+    Callable,
     List,
     Optional,
+    Union,
 )
 
 from deepmd.common import (
@@ -84,8 +86,15 @@
         """
         pass
 
+    @abstractmethod
+    def share_params(self, base_class, shared_level, resume=False):
+        """Share the parameters of self to the base_class with shared_level."""
+        pass
+
     def compute_input_stats(
-        self, merged: List[dict], path: Optional[DPPath] = None
+        self,
+        merged: Union[Callable[[], List[dict]], List[dict]],
+        path: Optional[DPPath] = None,
     ):
         """Update mean and stddev for descriptor elements."""
         raise NotImplementedError
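The `Union[Callable[[], List[dict]], List[dict]]` signature lets callers hand the statistics code either the sampled frames themselves or a zero-argument function that produces them on demand. A minimal sketch of how an implementation might normalize the two cases; the helper name `_resolve_merged` is illustrative, not part of this PR:

```python
from typing import Callable, List, Union


def _resolve_merged(
    merged: Union[Callable[[], List[dict]], List[dict]],
) -> List[dict]:
    # Invoke the lazy sampler only when statistics must actually be
    # computed; already-materialized lists pass through unchanged.
    return merged() if callable(merged) else merged
```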
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
@@ -243,6 +243,10 @@
         """
         return False
 
+    def share_params(self, base_class, shared_level, resume=False):
+        """Share the parameters of self to the base_class with shared_level."""
+        raise NotImplementedError
+
     def get_ntypes(self) -> int:
         """Returns the number of element types."""
         return self.ntypes
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
@@ -203,6 +203,10 @@
         """
         return False
 
+    def share_params(self, base_class, shared_level, resume=False):
+        """Share the parameters of self to the base_class with shared_level."""
+        raise NotImplementedError
+
     def get_ntypes(self) -> int:
         """Returns the number of element types."""
         return self.ntypes
1 change: 1 addition & 0 deletions deepmd/dpmodel/model/dp_model.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+
 from deepmd.dpmodel.atomic_model import (
     DPAtomicModel,
 )
58 changes: 13 additions & 45 deletions deepmd/pt/entrypoints/main.py
@@ -50,9 +50,6 @@
 from deepmd.pt.utils.multi_task import (
     preprocess_shared_params,
 )
-from deepmd.pt.utils.stat import (
-    make_stat_input,
-)
 from deepmd.utils.argcheck import (
     normalize,
 )
@@ -97,7 +94,6 @@
         multi_task=multi_task,
         model_branch=model_branch,
     )
-    config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None)
     shared_links = None
     if multi_task:
         config["model"], shared_links = preprocess_shared_params(config["model"])
@@ -109,26 +105,11 @@
         type_split = False
         if model_params_single["descriptor"]["type"] in ["se_e2_a"]:
             type_split = True
-        validation_dataset_params = data_dict_single["validation_data"]
+        validation_dataset_params = data_dict_single.get("validation_data", None)
+        validation_systems = (
+            validation_dataset_params["systems"] if validation_dataset_params else None
+        )
         training_systems = training_dataset_params["systems"]
-        validation_systems = validation_dataset_params["systems"]
-
-        # noise params
-        noise_settings = None
-        if loss_dict_single.get("type", "ener") == "denoise":
-            noise_settings = {
-                "noise_type": loss_dict_single.pop("noise_type", "uniform"),
-                "noise": loss_dict_single.pop("noise", 1.0),
-                "noise_mode": loss_dict_single.pop("noise_mode", "fix_num"),
-                "mask_num": loss_dict_single.pop("mask_num", 8),
-                "mask_prob": loss_dict_single.pop("mask_prob", 0.15),
-                "same_mask": loss_dict_single.pop("same_mask", False),
-                "mask_coord": loss_dict_single.pop("mask_coord", False),
-                "mask_type": loss_dict_single.pop("mask_type", False),
-                "max_fail_num": loss_dict_single.pop("max_fail_num", 10),
-                "mask_type_idx": len(model_params_single["type_map"]) - 1,
-            }
+        # noise_settings = None
 
         # stat files
         stat_file_path_single = data_dict_single.get("stat_file", None)
@@ -143,59 +124,47 @@
             stat_file_path_single = DPPath(stat_file_path_single, "a")
 
         # validation and training data
-        validation_data_single = DpLoaderSet(
-            validation_systems,
-            validation_dataset_params["batch_size"],
-            model_params_single,
+        validation_data_single = (
+            DpLoaderSet(
+                validation_systems,
+                validation_dataset_params["batch_size"],
+                model_params_single,
+            )
+            if validation_systems
+            else None
         )
-        if ckpt or finetune_model:
-            train_data_single = DpLoaderSet(
-                training_systems,
-                training_dataset_params["batch_size"],
-                model_params_single,
-            )
-            sampled_single = None
-        else:
-            train_data_single = DpLoaderSet(
-                training_systems,
-                training_dataset_params["batch_size"],
-                model_params_single,
-            )
-            data_stat_nbatch = model_params_single.get("data_stat_nbatch", 10)
-            sampled_single = make_stat_input(
-                train_data_single.systems,
-                train_data_single.dataloaders,
-                data_stat_nbatch,
-            )
-            if noise_settings is not None:
-                train_data_single = DpLoaderSet(
-                    training_systems,
-                    training_dataset_params["batch_size"],
-                    model_params_single,
-                )
+        train_data_single = DpLoaderSet(
+            training_systems,
+            training_dataset_params["batch_size"],
+            model_params_single,
+        )
         return (
             train_data_single,
             validation_data_single,
-            sampled_single,
             stat_file_path_single,
         )
 
     if not multi_task:
         (
             train_data,
             validation_data,
-            sampled,
             stat_file_path,
         ) = prepare_trainer_input_single(
             config["model"], config["training"], config["loss"]
         )
     else:
-        train_data, validation_data, sampled, stat_file_path = {}, {}, {}, {}
+        train_data, validation_data, stat_file_path = {}, {}, {}
         for model_key in config["model"]["model_dict"]:
             (
                 train_data[model_key],
                 validation_data[model_key],
-                sampled[model_key],
                 stat_file_path[model_key],
             ) = prepare_trainer_input_single(
                 config["model"]["model_dict"][model_key],
@@ -207,7 +176,6 @@
     trainer = training.Trainer(
         config,
         train_data,
-        sampled=sampled,
        stat_file_path=stat_file_path,
        validation_data=validation_data,
        init_model=init_model,
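The switch to `data_dict_single.get("validation_data", None)` makes the validation block optional, per commit 463f9fb ("Support no validation"). A minimal sketch of that behavior; the config fragment and its paths are illustrative, and the two lookups mirror the diff above:

```python
# Hypothetical single-task "training" section with no validation block.
data_dict_single = {
    "training_data": {"systems": ["path/to/train_system"], "batch_size": "auto"},
    # note: no "validation_data" key at all
}

validation_dataset_params = data_dict_single.get("validation_data", None)
validation_systems = (
    validation_dataset_params["systems"] if validation_dataset_params else None
)
assert validation_systems is None  # the validation loader is then skipped
```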
50 changes: 50 additions & 0 deletions deepmd/pt/loss/ener.py
@@ -1,4 +1,8 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    List,
+)
+
 import torch
 import torch.nn.functional as F
 
@@ -11,6 +15,9 @@
 from deepmd.pt.utils.env import (
     GLOBAL_PT_FLOAT_PRECISION,
 )
+from deepmd.utils.data import (
+    DataRequirementItem,
+)
 
 
 class EnergyStdLoss(TaskLoss):
@@ -153,3 +160,46 @@
         if not self.inference:
             more_loss["rmse"] = torch.sqrt(loss.detach())
         return loss, more_loss
+
+    @property
+    def label_requirement(self) -> List[DataRequirementItem]:
+        """Return data label requirements needed for this loss calculation."""
+        data_requirement = [
+            DataRequirementItem(
+                "energy",
+                ndof=1,
+                atomic=False,
+                must=False,
+                high_prec=True,
+            ),
+            DataRequirementItem(
+                "force",
+                ndof=3,
+                atomic=True,
+                must=False,
+                high_prec=False,
+            ),
+            DataRequirementItem(
+                "virial",
+                ndof=9,
+                atomic=False,
+                must=False,
+                high_prec=False,
+            ),
+            DataRequirementItem(
+                "atom_ener",
+                ndof=1,
+                atomic=True,
+                must=False,
+                high_prec=False,
+            ),
+            DataRequirementItem(
+                "atom_pref",
+                ndof=1,
+                atomic=True,
+                must=False,
+                high_prec=False,
+                repeat=3,
+            ),
+        ]
+        return data_requirement
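Each loss now advertises the labels it consumes, which is how the dataloader learns what to read (commits f9265d5 and fdca653). A sketch of consuming the property; the constructor argument and the attribute names on `DataRequirementItem` are assumptions, not verbatim from this PR:

```python
# Illustrative wiring between a loss and a dataset (names assumed).
loss = EnergyStdLoss(starter_learning_rate=1.0e-3)
for item in loss.label_requirement:
    # Each DataRequirementItem carries the metadata a DpLoaderSet-style
    # dataset needs to register the label for loading.
    print(item.key, item.ndof, item.atomic)
```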
13 changes: 13 additions & 0 deletions deepmd/pt/loss/loss.py
@@ -1,6 +1,14 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    List,
+)
+
 import torch
 
+from deepmd.utils.data import (
+    DataRequirementItem,
+)
+
 
 class TaskLoss(torch.nn.Module):
     def __init__(self, **kwargs):
@@ -10,3 +18,8 @@
     def forward(self, model_pred, label, natoms, learning_rate):
         """Return loss ."""
         raise NotImplementedError
+
+    @property
+    def label_requirement(self) -> List[DataRequirementItem]:
+        """Return data label requirements needed for this loss calculation."""
+        raise NotImplementedError
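With `label_requirement` on the base class, every concrete loss must declare its own labels. An illustrative subclass, a minimal sketch rather than code from this PR:

```python
from typing import List

from deepmd.utils.data import DataRequirementItem


class ToyEnergyLoss(TaskLoss):
    """Illustrative loss: mean-squared error on total energy only."""

    def forward(self, model_pred, label, natoms, learning_rate):
        diff = model_pred["energy"] - label["energy"]
        return (diff**2).mean(), {}

    @property
    def label_requirement(self) -> List[DataRequirementItem]:
        # Declares the single label this loss reads from each frame.
        return [
            DataRequirementItem(
                "energy", ndof=1, atomic=False, must=True, high_prec=True
            )
        ]
```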
17 changes: 5 additions & 12 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
@@ -18,9 +18,6 @@
 from deepmd.pt.model.task.base_fitting import (
     BaseFitting,
 )
-from deepmd.pt.utils.utils import (
-    dict_to_device,
-)
 from deepmd.utils.path import (
     DPPath,
 )
@@ -185,7 +182,7 @@
 
     def compute_or_load_stat(
         self,
-        sampled,
+        sampled_func,
         stat_file_path: Optional[DPPath] = None,
     ):
         """
@@ -198,22 +195,18 @@
 
         Parameters
         ----------
-        sampled
-            The sampled data frames from different data systems.
+        sampled_func
+            The lazy sampled function to get data frames from different data systems.
         stat_file_path
             The dictionary of paths to the statistics files.
         """
         if stat_file_path is not None and self.type_map is not None:
             # descriptors and fitting net with different type_map
             # should not share the same parameters
             stat_file_path /= " ".join(self.type_map)
-        for data_sys in sampled:
-            dict_to_device(data_sys)
-        if sampled is None:
-            sampled = []
-        self.descriptor.compute_input_stats(sampled, stat_file_path)
+        self.descriptor.compute_input_stats(sampled_func, stat_file_path)
         if self.fitting_net is not None:
-            self.fitting_net.compute_output_stats(sampled, stat_file_path)
+            self.fitting_net.compute_output_stats(sampled_func, stat_file_path)
 
     @torch.jit.export
     def get_dim_fparam(self) -> int:
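Commit 35598d2 ("replace lazy with functools.lru_cache") is what makes `sampled_func` cheap to pass around: the sampling callable is cached, so frames are drawn at most once, and never if statistics are loaded from file. A sketch of the pattern; the `make_stat_input` call and its arguments mirror the code deleted from main.py, while `train_data`, `data_stat_nbatch`, and `atomic_model` are illustrative names:

```python
from functools import lru_cache

from deepmd.pt.utils.stat import make_stat_input


@lru_cache
def sampled_func():
    # Runs at most once; skipped entirely when stats come from stat_file_path.
    return make_stat_input(
        train_data.systems, train_data.dataloaders, data_stat_nbatch
    )


atomic_model.compute_or_load_stat(sampled_func, stat_file_path=stat_file_path)
```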
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/__init__.py
@@ -1,4 +1,7 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from .base_descriptor import (
+    BaseDescriptor,
+)
 from .descriptor import (
     DescriptorBlock,
     make_default_type_embedding,
@@ -31,6 +34,7 @@
 )
 
 __all__ = [
+    "BaseDescriptor",
     "DescriptorBlock",
     "make_default_type_embedding",
     "DescrptBlockSeA",
8 changes: 7 additions & 1 deletion deepmd/pt/model/descriptor/descriptor.py
@@ -5,9 +5,11 @@
     abstractmethod,
 )
 from typing import (
+    Callable,
     Dict,
     List,
     Optional,
+    Union,
 )
 
 import torch
@@ -86,7 +88,11 @@
         """Returns the embedding dimension."""
         pass
 
-    def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
+    def compute_input_stats(
+        self,
+        merged: Union[Callable[[], List[dict]], List[dict]],
+        path: Optional[DPPath] = None,
+    ):
         """Update mean and stddev for DescriptorBlock elements."""
         raise NotImplementedError