Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refact: the DPA2 descriptor #3758

Merged
merged 42 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
7d82945
feat: Support `stripped_type_embedding` in PT/DP
iProzd Apr 25, 2024
a230198
Update train-se-atten.md
iProzd Apr 25, 2024
5157781
Update graph.py
iProzd Apr 25, 2024
f780d58
Update deepmd/utils/argcheck.py
iProzd Apr 25, 2024
3b3d25e
Update deepmd/pt/model/descriptor/se_atten.py
iProzd Apr 25, 2024
cf841f2
Update deepmd/tf/descriptor/se_a.py
iProzd Apr 25, 2024
a9e24d9
Update deepmd/tf/descriptor/se_a.py
iProzd Apr 25, 2024
764cab7
Update deepmd/tf/descriptor/se_atten.py
iProzd Apr 25, 2024
0b9cea1
Update deepmd/tf/descriptor/se_atten.py
iProzd Apr 25, 2024
30a594a
Merge branch 'devel' into add_strip_dpa1
iProzd Apr 25, 2024
1e86b75
Update docs
iProzd Apr 25, 2024
f3056ee
resolve conversations
iProzd Apr 26, 2024
4e231e4
rf dpa2 with identity implement
iProzd Apr 28, 2024
b7af498
Update test_dpa2.py
iProzd Apr 30, 2024
61d9794
Add residual support
iProzd May 7, 2024
0e4fe1c
rm bn support
iProzd May 7, 2024
7a1095c
Add numpy impl for DPA2
iProzd May 7, 2024
0924505
Merge branch 'devel' into rf_dpa2_consist
iProzd May 7, 2024
b19a0e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2024
c952798
update argcheck
iProzd May 7, 2024
bd1d5d9
Fix uts
iProzd May 8, 2024
0ebadc2
Update test_permutation.py
iProzd May 8, 2024
fe6ed6e
fix uts
iProzd May 8, 2024
104527d
Update test_dpa2.py
iProzd May 8, 2024
cbba7a7
Merge branch 'devel' into rf_dpa2_consist
iProzd May 9, 2024
e1270bd
Update argcheck.py
iProzd May 9, 2024
ceaaa07
Update se_atten.py
iProzd May 9, 2024
d2bcdbf
Fix typo
iProzd May 9, 2024
385e1f7
revert 'nf' to 'nb'
iProzd May 9, 2024
d1e38ad
Update repformers.py
iProzd May 9, 2024
9d0ad7f
Update repformers.py
iProzd May 9, 2024
375c03e
mv symmetrization_op into static
iProzd May 9, 2024
2f280e6
Merge branch 'devel' into rf_dpa2_consist
iProzd May 9, 2024
d85eef0
Update test_descriptor_dpa2.py
iProzd May 9, 2024
244c8e5
Update dpa2.md
iProzd May 9, 2024
e9fe376
separate args for repinit and repformers
iProzd May 9, 2024
515c534
Update repformer_layer.py
iProzd May 9, 2024
bd25aa6
Update repformer_layer.py
iProzd May 9, 2024
f17f40f
Update repformer_layer.py
iProzd May 9, 2024
a8c89dc
Update repformer_layer.py
iProzd May 9, 2024
e329e30
Update repformer_layer.py
iProzd May 9, 2024
e223336
Merge branch 'devel' into rf_dpa2_consist
iProzd May 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from .dpa1 import (
DescrptDPA1,
)
from .dpa2 import (
DescrptDPA2,
)
from .hybrid import (
DescrptHybrid,
)
Expand All @@ -19,6 +22,7 @@
"DescrptSeA",
"DescrptSeR",
"DescrptDPA1",
"DescrptDPA2",
"DescrptHybrid",
"make_base_descriptor",
]
127 changes: 127 additions & 0 deletions deepmd/dpmodel/descriptor/descriptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from abc import (
ABC,
abstractmethod,
)
from typing import (
Callable,
Dict,
List,
Optional,
Union,
)

import numpy as np

from deepmd.utils.env_mat_stat import (
StatItem,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.plugin import (
make_plugin_registry,
)

log = logging.getLogger(__name__)


class DescriptorBlock(ABC, make_plugin_registry("DescriptorBlock")):
"""The building block of descriptor.
Given the input descriptor, provide with the atomic coordinates,
atomic types and neighbor list, calculate the new descriptor.
"""

local_cluster = False

def __new__(cls, *args, **kwargs):
if cls is DescriptorBlock:
try:
descrpt_type = kwargs["type"]
except KeyError:
raise KeyError("the type of DescriptorBlock should be set by `type`")
cls = cls.get_class_by_type(descrpt_type)

Check warning on line 44 in deepmd/dpmodel/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/descriptor.py#L40-L44

Added lines #L40 - L44 were not covered by tests
return super().__new__(cls)

@abstractmethod
def get_rcut(self) -> float:
"""Returns the cut-off radius."""
pass

Check warning on line 50 in deepmd/dpmodel/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/descriptor.py#L50

Added line #L50 was not covered by tests

@abstractmethod
def get_nsel(self) -> int:
"""Returns the number of selected atoms in the cut-off radius."""
pass

Check warning on line 55 in deepmd/dpmodel/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/descriptor.py#L55

Added line #L55 was not covered by tests

@abstractmethod
def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
pass

Check warning on line 60 in deepmd/dpmodel/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/descriptor.py#L60

Added line #L60 was not covered by tests

@abstractmethod
def get_ntypes(self) -> int:
"""Returns the number of element types."""
pass

Check warning on line 65 in deepmd/dpmodel/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/descriptor.py#L65

Added line #L65 was not covered by tests

@abstractmethod
def get_dim_out(self) -> int:
"""Returns the output dimension."""
pass

Check warning on line 70 in deepmd/dpmodel/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/descriptor.py#L70

Added line #L70 was not covered by tests

@abstractmethod
def get_dim_in(self) -> int:
"""Returns the input dimension."""
pass

Check warning on line 75 in deepmd/dpmodel/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/descriptor.py#L75

Added line #L75 was not covered by tests

@abstractmethod
def get_dim_emb(self) -> int:
"""Returns the embedding dimension."""
pass

Check warning on line 80 in deepmd/dpmodel/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/descriptor.py#L80

Added line #L80 was not covered by tests

def compute_input_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""
Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.

Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
path : Optional[DPPath]
The path to the stat file.

"""
raise NotImplementedError

Check warning on line 103 in deepmd/dpmodel/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/descriptor.py#L103

Added line #L103 was not covered by tests

def get_stats(self) -> Dict[str, StatItem]:
"""Get the statistics of the descriptor."""
raise NotImplementedError

Check warning on line 107 in deepmd/dpmodel/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/descriptor.py#L107

Added line #L107 was not covered by tests

def share_params(self, base_class, shared_level, resume=False):
"""
Share the parameters of self to the base_class with shared_level during multitask training.
If not start from checkpoint (resume is False),
some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes.
"""
raise NotImplementedError

Check warning on line 115 in deepmd/dpmodel/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/descriptor.py#L115

Added line #L115 was not covered by tests

@abstractmethod
def call(
self,
nlist: np.ndarray,
extended_coord: np.ndarray,
extended_atype: np.ndarray,
extended_atype_embd: Optional[np.ndarray] = None,
mapping: Optional[np.ndarray] = None,
):
"""Calculate DescriptorBlock."""
pass

Check warning on line 127 in deepmd/dpmodel/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/descriptor.py#L127

Added line #L127 was not covered by tests
Loading