diff --git a/deepmd/dpmodel/descriptor/__init__.py b/deepmd/dpmodel/descriptor/__init__.py index bbc332588c..563fb9d149 100644 --- a/deepmd/dpmodel/descriptor/__init__.py +++ b/deepmd/dpmodel/descriptor/__init__.py @@ -2,6 +2,9 @@ from .dpa1 import ( DescrptDPA1, ) +from .dpa2 import ( + DescrptDPA2, +) from .hybrid import ( DescrptHybrid, ) @@ -19,6 +22,7 @@ "DescrptSeA", "DescrptSeR", "DescrptDPA1", + "DescrptDPA2", "DescrptHybrid", "make_base_descriptor", ] diff --git a/deepmd/dpmodel/descriptor/descriptor.py b/deepmd/dpmodel/descriptor/descriptor.py new file mode 100644 index 0000000000..444df1abf8 --- /dev/null +++ b/deepmd/dpmodel/descriptor/descriptor.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from abc import ( + ABC, + abstractmethod, +) +from typing import ( + Callable, + Dict, + List, + Optional, + Union, +) + +import numpy as np + +from deepmd.utils.env_mat_stat import ( + StatItem, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.plugin import ( + make_plugin_registry, +) + +log = logging.getLogger(__name__) + + +class DescriptorBlock(ABC, make_plugin_registry("DescriptorBlock")): + """The building block of descriptor. + Given the input descriptor, provide with the atomic coordinates, + atomic types and neighbor list, calculate the new descriptor. + """ + + local_cluster = False + + def __new__(cls, *args, **kwargs): + if cls is DescriptorBlock: + try: + descrpt_type = kwargs["type"] + except KeyError: + raise KeyError("the type of DescriptorBlock should be set by `type`") + cls = cls.get_class_by_type(descrpt_type) + return super().__new__(cls) + + @abstractmethod + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + pass + + @abstractmethod + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + pass + + @abstractmethod + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + pass + + @abstractmethod + def get_ntypes(self) -> int: + """Returns the number of element types.""" + pass + + @abstractmethod + def get_dim_out(self) -> int: + """Returns the output dimension.""" + pass + + @abstractmethod + def get_dim_in(self) -> int: + """Returns the input dimension.""" + pass + + @abstractmethod + def get_dim_emb(self) -> int: + """Returns the embedding dimension.""" + pass + + def compute_input_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + path: Optional[DPPath] = None, + ): + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ + raise NotImplementedError + + def get_stats(self) -> Dict[str, StatItem]: + """Get the statistics of the descriptor.""" + raise NotImplementedError + + def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + raise NotImplementedError + + @abstractmethod + def call( + self, + nlist: np.ndarray, + extended_coord: np.ndarray, + extended_atype: np.ndarray, + extended_atype_embd: Optional[np.ndarray] = None, + mapping: Optional[np.ndarray] = None, + ): + """Calculate DescriptorBlock.""" + pass diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 39d773e3c6..e5e79a8984 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -28,6 +28,7 @@ from typing import ( Any, + Callable, List, Optional, Tuple, @@ -49,6 +50,9 @@ from .base_descriptor import ( BaseDescriptor, ) +from .descriptor import ( + DescriptorBlock, +) def np_softmax(x, axis=-1): @@ -224,7 +228,7 @@ def __init__( attn_layer: int = 2, attn_dotr: bool = True, attn_mask: bool = False, - exclude_types: List[List[int]] = [], + exclude_types: List[Tuple[int, int]] = [], env_protection: float = 0.0, set_davg_zero: bool = False, activation_function: str = "tanh", @@ -256,6 +260,292 @@ def __init__( if ln_eps is None: ln_eps = 1e-5 + self.se_atten = DescrptBlockSeAtten( + rcut, + rcut_smth, + sel, + ntypes, + neuron=neuron, + axis_neuron=axis_neuron, + tebd_dim=tebd_dim, + tebd_input_mode=tebd_input_mode, + set_davg_zero=set_davg_zero, + attn=attn, + attn_layer=attn_layer, + attn_dotr=attn_dotr, + attn_mask=False, + activation_function=activation_function, + precision=precision, + resnet_dt=resnet_dt, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + smooth=smooth_type_embedding, + type_one_side=type_one_side, + exclude_types=exclude_types, + env_protection=env_protection, + trainable_ln=trainable_ln, + ln_eps=ln_eps, + ) + self.type_embedding = TypeEmbedNet( + ntypes=ntypes, + neuron=[tebd_dim], + padding=True, + activation_function="Linear", + precision=precision, + ) + self.tebd_dim = tebd_dim + self.concat_output_tebd = concat_output_tebd + self.trainable = trainable + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.se_atten.get_rcut() + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return self.se_atten.get_nsel() + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.se_atten.get_sel() + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.se_atten.get_ntypes() + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + ret = self.se_atten.get_dim_out() + if self.concat_output_tebd: + ret += self.tebd_dim + return ret + + def get_dim_emb(self) -> int: + return self.se_atten.dim_emb + + def mixed_types(self) -> bool: + """If true, the discriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the discriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return self.se_atten.mixed_types() + + def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + raise NotImplementedError + + @property + def dim_out(self): + return self.get_dim_out() + + @property + def dim_emb(self): + return self.get_dim_emb() + + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): + """Update mean and stddev for descriptor elements.""" + raise NotImplementedError + + def set_stat_mean_and_stddev( + self, + mean: np.ndarray, + stddev: np.ndarray, + ) -> None: + self.se_atten.mean = mean + self.se_atten.stddev = stddev + + def call( + self, + coord_ext, + atype_ext, + nlist, + mapping: Optional[np.ndarray] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nallx3) + atype_ext + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping from extended to lcoal region. not used by this descriptor. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + this descriptor returns None + h2 + The rotationally equivariant pair-partical representation. + this descriptor returns None + sw + The smooth switch function. + """ + del mapping + nf, nloc, nnei = nlist.shape + nall = coord_ext.reshape(nf, -1).shape[1] // 3 + # nf x nall x tebd_dim + atype_embd_ext = self.type_embedding.call()[atype_ext] + # nfnl x tebd_dim + atype_embd = atype_embd_ext[:, :nloc, :] + grrg, g2, h2, rot_mat, sw = self.se_atten( + nlist, + coord_ext, + atype_ext, + atype_embd_ext, + mapping=None, + ) + # nf x nloc x (ng x ng1 + tebd_dim) + if self.concat_output_tebd: + grrg = np.concatenate([grrg, atype_embd.reshape(nf, nloc, -1)], axis=-1) + return grrg, rot_mat, None, None, sw + + def serialize(self) -> dict: + """Serialize the descriptor to dict.""" + obj = self.se_atten + data = { + "@class": "Descriptor", + "type": "dpa1", + "@version": 1, + "rcut": obj.rcut, + "rcut_smth": obj.rcut_smth, + "sel": obj.sel, + "ntypes": obj.ntypes, + "neuron": obj.neuron, + "axis_neuron": obj.axis_neuron, + "tebd_dim": obj.tebd_dim, + "tebd_input_mode": obj.tebd_input_mode, + "set_davg_zero": obj.set_davg_zero, + "attn": obj.attn, + "attn_layer": obj.attn_layer, + "attn_dotr": obj.attn_dotr, + "attn_mask": False, + "activation_function": obj.activation_function, + "resnet_dt": obj.resnet_dt, + "scaling_factor": obj.scaling_factor, + "normalize": obj.normalize, + "temperature": obj.temperature, + "trainable_ln": obj.trainable_ln, + "ln_eps": obj.ln_eps, + "smooth_type_embedding": obj.smooth, + "type_one_side": obj.type_one_side, + "concat_output_tebd": self.concat_output_tebd, + # make deterministic + "precision": np.dtype(PRECISION_DICT[obj.precision]).name, + "embeddings": obj.embeddings.serialize(), + "attention_layers": obj.dpa1_attention.serialize(), + "env_mat": obj.env_mat.serialize(), + "type_embedding": self.type_embedding.serialize(), + "exclude_types": obj.exclude_types, + "env_protection": obj.env_protection, + "@variables": { + "davg": obj["davg"], + "dstd": obj["dstd"], + }, + ## to be updated when the options are supported. + "trainable": self.trainable, + "spin": None, + } + if obj.tebd_input_mode in ["strip"]: + data.update({"embeddings_strip": obj.embeddings_strip.serialize()}) + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA1": + """Deserialize from dict.""" + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + type_embedding = data.pop("type_embedding") + attention_layers = data.pop("attention_layers") + env_mat = data.pop("env_mat") + tebd_input_mode = data["tebd_input_mode"] + if tebd_input_mode in ["strip"]: + embeddings_strip = data.pop("embeddings_strip") + else: + embeddings_strip = None + obj = cls(**data) + + obj.se_atten["davg"] = variables["davg"] + obj.se_atten["dstd"] = variables["dstd"] + obj.se_atten.embeddings = NetworkCollection.deserialize(embeddings) + if tebd_input_mode in ["strip"]: + obj.se_atten.embeddings_strip = NetworkCollection.deserialize( + embeddings_strip + ) + obj.type_embedding = TypeEmbedNet.deserialize(type_embedding) + obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize( + attention_layers + ) + return obj + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + local_jdata_cpy = local_jdata.copy() + return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, True) + + +@DescriptorBlock.register("se_atten") +class DescrptBlockSeAtten(NativeOP, DescriptorBlock): + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: Union[List[int], int], + ntypes: int, + neuron: List[int] = [25, 50, 100], + axis_neuron: int = 8, + tebd_dim: int = 8, + tebd_input_mode: str = "concat", + resnet_dt: bool = False, + type_one_side: bool = False, + attn: int = 128, + attn_layer: int = 2, + attn_dotr: bool = True, + attn_mask: bool = False, + exclude_types: List[Tuple[int, int]] = [], + env_protection: float = 0.0, + set_davg_zero: bool = False, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + scaling_factor=1.0, + normalize: bool = True, + temperature: Optional[float] = None, + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + smooth: bool = True, + ) -> None: self.rcut = rcut self.rcut_smth = rcut_smth if isinstance(sel, int): @@ -269,13 +559,13 @@ def __init__( self.tebd_dim = tebd_dim self.tebd_input_mode = tebd_input_mode self.resnet_dt = resnet_dt - self.trainable = trainable self.trainable_ln = trainable_ln self.ln_eps = ln_eps self.type_one_side = type_one_side self.attn = attn self.attn_layer = attn_layer self.attn_dotr = attn_dotr + self.attn_mask = attn_mask self.exclude_types = exclude_types self.env_protection = env_protection self.set_davg_zero = set_davg_zero @@ -284,18 +574,10 @@ def __init__( self.scaling_factor = scaling_factor self.normalize = normalize self.temperature = temperature - self.smooth = smooth_type_embedding - self.concat_output_tebd = concat_output_tebd + self.smooth = smooth # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) - self.type_embedding = TypeEmbedNet( - ntypes=self.ntypes, - neuron=[self.tebd_dim], - padding=True, - activation_function="Linear", - precision=precision, - ) self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2 if self.tebd_input_mode in ["concat"]: self.embd_input_dim = 1 + self.tebd_dim_input @@ -345,52 +627,55 @@ def __init__( wanted_shape = (self.ntypes, self.nnei, 4) self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection) - self.davg = np.zeros(wanted_shape, dtype=PRECISION_DICT[self.precision]) - self.dstd = np.ones(wanted_shape, dtype=PRECISION_DICT[self.precision]) + self.mean = np.zeros(wanted_shape, dtype=PRECISION_DICT[self.precision]) + self.stddev = np.ones(wanted_shape, dtype=PRECISION_DICT[self.precision]) self.orig_sel = self.sel + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_in(self) -> int: + """Returns the input dimension.""" + return self.dim_in + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.dim_out + + def get_dim_emb(self) -> int: + """Returns the output dimension of embedding.""" + return self.filter_neuron[-1] + def __setitem__(self, key, value): if key in ("avg", "data_avg", "davg"): - self.davg = value + self.mean = value elif key in ("std", "data_std", "dstd"): - self.dstd = value + self.stddev = value else: raise KeyError(key) def __getitem__(self, key): if key in ("avg", "data_avg", "davg"): - return self.davg + return self.mean elif key in ("std", "data_std", "dstd"): - return self.dstd + return self.stddev else: raise KeyError(key) - @property - def dim_out(self): - """Returns the output dimension of this descriptor.""" - return self.get_dim_out() - - def get_dim_out(self): - """Returns the output dimension of this descriptor.""" - return ( - self.neuron[-1] * self.axis_neuron + self.tebd_dim - if self.concat_output_tebd - else self.neuron[-1] * self.axis_neuron - ) - - def get_dim_emb(self): - """Returns the embedding (g2) dimension of this descriptor.""" - return self.neuron[-1] - - def get_rcut(self): - """Returns cutoff radius.""" - return self.rcut - - def get_sel(self): - """Returns cutoff radius.""" - return self.sel - - def mixed_types(self): + def mixed_types(self) -> bool: """If true, the discriptor 1. assumes total number of atoms aligned across frames; 2. requires a neighbor list that does not distinguish different atomic types. @@ -402,22 +687,40 @@ def mixed_types(self): """ return True - def share_params(self, base_class, shared_level, resume=False): - """ - Share the parameters of self to the base_class with shared_level during multitask training. - If not start from checkpoint (resume is False), - some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. - """ - raise NotImplementedError + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.filter_neuron[-1] * self.axis_neuron - def get_ntypes(self) -> int: - """Returns the number of element types.""" - return self.ntypes + @property + def dim_in(self): + """Returns the atomic input dimension of this descriptor.""" + return self.tebd_dim - def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): - """Update mean and stddev for descriptor elements.""" + @property + def dim_emb(self): + """Returns the output dimension of embedding.""" + return self.get_dim_emb() + + def compute_input_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + path: Optional[DPPath] = None, + ): + """Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.""" raise NotImplementedError + def get_stats(self): + """Get the statistics of the descriptor.""" + raise NotImplementedError + + def reinit_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def cal_g( self, ss, @@ -441,53 +744,17 @@ def cal_g_strip( gg = self.embeddings_strip[embedding_idx].call(ss) return gg - def reinit_exclude( - self, - exclude_types: List[Tuple[int, int]] = [], - ): - self.exclude_types = exclude_types - self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) - def call( self, - coord_ext, - atype_ext, - nlist, + nlist: np.ndarray, + coord_ext: np.ndarray, + atype_ext: np.ndarray, + atype_embd_ext: Optional[np.ndarray] = None, mapping: Optional[np.ndarray] = None, ): - """Compute the descriptor. - - Parameters - ---------- - coord_ext - The extended coordinates of atoms. shape: nf x (nallx3) - atype_ext - The extended aotm types. shape: nf x nall - nlist - The neighbor list. shape: nf x nloc x nnei - mapping - The index mapping from extended to lcoal region. not used by this descriptor. - - Returns - ------- - descriptor - The descriptor. shape: nf x nloc x (ng x axis_neuron) - gr - The rotationally equivariant and permutationally invariant single particle - representation. shape: nf x nloc x ng x 3 - g2 - The rotationally invariant pair-partical representation. - this descriptor returns None - h2 - The rotationally equivariant pair-partical representation. - this descriptor returns None - sw - The smooth switch function. - """ - del mapping # nf x nloc x nnei x 4 - dmatrix, sw = self.env_mat.call( - coord_ext, atype_ext, nlist, self.davg, self.dstd + dmatrix, diff, sw = self.env_mat.call( + coord_ext, atype_ext, nlist, self.mean, self.stddev ) nf, nloc, nnei, _ = dmatrix.shape exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) @@ -497,10 +764,6 @@ def call( dmatrix = dmatrix.reshape(nf * nloc, nnei, 4) # nfnl x nnei x 1 sw = sw.reshape(nf * nloc, nnei, 1) - - # add type embedding into input - # nf x nall x tebd_dim - atype_embd_ext = self.type_embedding.call()[atype_ext] # nfnl x tebd_dim atype_embd = atype_embd_ext[:, :nloc, :].reshape(nf * nloc, -1) # nfnl x nnei x tebd_dim @@ -515,7 +778,6 @@ def call( atype_embd_nlist = np.take_along_axis(atype_embd_ext, index, axis=1).reshape( nf * nloc, nnei, self.tebd_dim ) - ng = self.neuron[-1] # nfnl x nnei exclude_mask = exclude_mask.reshape(nf * nloc, nnei) @@ -572,102 +834,13 @@ def call( grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron).astype( GLOBAL_NP_FLOAT_PRECISION ) - # nf x nloc x (ng x ng1 + tebd_dim) - if self.concat_output_tebd: - grrg = np.concatenate([grrg, atype_embd.reshape(nf, nloc, -1)], axis=-1) - gr = gr.reshape(nf, nloc, *gr.shape[1:]) - return grrg, gr[..., 1:], None, None, sw - - def serialize(self) -> dict: - """Serialize the descriptor to dict.""" - data = { - "@class": "Descriptor", - "type": "dpa1", - "@version": 1, - "rcut": self.rcut, - "rcut_smth": self.rcut_smth, - "sel": self.sel, - "ntypes": self.ntypes, - "neuron": self.neuron, - "axis_neuron": self.axis_neuron, - "tebd_dim": self.tebd_dim, - "tebd_input_mode": self.tebd_input_mode, - "set_davg_zero": self.set_davg_zero, - "attn": self.attn, - "attn_layer": self.attn_layer, - "attn_dotr": self.attn_dotr, - "attn_mask": False, - "activation_function": self.activation_function, - "resnet_dt": self.resnet_dt, - "scaling_factor": self.scaling_factor, - "normalize": self.normalize, - "temperature": self.temperature, - "trainable_ln": self.trainable_ln, - "ln_eps": self.ln_eps, - "smooth_type_embedding": self.smooth, - "type_one_side": self.type_one_side, - "concat_output_tebd": self.concat_output_tebd, - # make deterministic - "precision": np.dtype(PRECISION_DICT[self.precision]).name, - "embeddings": self.embeddings.serialize(), - "attention_layers": self.dpa1_attention.serialize(), - "env_mat": self.env_mat.serialize(), - "type_embedding": self.type_embedding.serialize(), - "exclude_types": self.exclude_types, - "env_protection": self.env_protection, - "@variables": { - "davg": self.davg, - "dstd": self.dstd, - }, - ## to be updated when the options are supported. - "trainable": True, - "spin": None, - } - if self.tebd_input_mode in ["strip"]: - data.update({"embeddings_strip": self.embeddings_strip.serialize()}) - return data - - @classmethod - def deserialize(cls, data: dict) -> "DescrptDPA1": - """Deserialize from dict.""" - data = data.copy() - check_version_compatibility(data.pop("@version"), 1, 1) - data.pop("@class") - data.pop("type") - variables = data.pop("@variables") - embeddings = data.pop("embeddings") - type_embedding = data.pop("type_embedding") - attention_layers = data.pop("attention_layers") - env_mat = data.pop("env_mat") - tebd_input_mode = data["tebd_input_mode"] - if tebd_input_mode in ["strip"]: - embeddings_strip = data.pop("embeddings_strip") - else: - embeddings_strip = None - obj = cls(**data) - - obj["davg"] = variables["davg"] - obj["dstd"] = variables["dstd"] - obj.embeddings = NetworkCollection.deserialize(embeddings) - if tebd_input_mode in ["strip"]: - obj.embeddings_strip = NetworkCollection.deserialize(embeddings_strip) - obj.type_embedding = TypeEmbedNet.deserialize(type_embedding) - obj.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers) - return obj - - @classmethod - def update_sel(cls, global_jdata: dict, local_jdata: dict): - """Update the selection and perform neighbor statistics. - - Parameters - ---------- - global_jdata : dict - The global data, containing the training section - local_jdata : dict - The local data refer to the current class - """ - local_jdata_cpy = local_jdata.copy() - return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, True) + return ( + grrg.reshape(-1, nloc, self.filter_neuron[-1] * self.axis_neuron), + gg.reshape(-1, nloc, self.nnei, self.filter_neuron[-1]), + dmatrix.reshape(-1, nloc, self.nnei, 4)[..., 1:], + gr[..., 1:].reshape(-1, nloc, self.filter_neuron[-1], 3), + sw, + ) class NeighborGatedAttention(NativeOP): @@ -850,7 +1023,7 @@ def call( sw: Optional[np.ndarray] = None, ): residual = x - x = self.attention_layer(x, nei_mask, input_r=input_r, sw=sw) + x, _ = self.attention_layer(x, nei_mask, input_r=input_r, sw=sw) x = residual + x x = self.attn_layer_norm(x) return x @@ -903,6 +1076,7 @@ def __init__( nnei: int, embed_dim: int, hidden_dim: int, + num_heads: int = 1, dotr: bool = False, do_mask: bool = False, scaling_factor: float = 1.0, @@ -912,11 +1086,14 @@ def __init__( smooth: bool = True, precision: str = DEFAULT_PRECISION, ): - """Construct a neighbor-wise attention net.""" + """Construct a multi-head neighbor-wise attention net.""" super().__init__() + assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads" self.nnei = nnei self.embed_dim = embed_dim self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads self.dotr = dotr self.do_mask = do_mask self.bias = bias @@ -924,10 +1101,11 @@ def __init__( self.scaling_factor = scaling_factor self.temperature = temperature self.precision = precision - if temperature is None: - self.scaling = (self.hidden_dim * scaling_factor) ** -0.5 - else: - self.scaling = temperature + self.scaling = ( + (self.head_dim * scaling_factor) ** -0.5 + if temperature is None + else temperature + ) self.normalize = normalize self.in_proj = NativeLayer( embed_dim, @@ -948,41 +1126,55 @@ def call(self, query, nei_mask, input_r=None, sw=None, attnw_shift=20.0): # Linear projection q, k, v = np.split(self.in_proj(query), 3, axis=-1) # Reshape and normalize - q = q.reshape(-1, self.nnei, self.hidden_dim) - k = k.reshape(-1, self.nnei, self.hidden_dim) - v = v.reshape(-1, self.nnei, self.hidden_dim) + # (nf x nloc) x num_heads x nnei x head_dim + q = q.reshape(-1, self.nnei, self.num_heads, self.head_dim).transpose( + 0, 2, 1, 3 + ) + k = k.reshape(-1, self.nnei, self.num_heads, self.head_dim).transpose( + 0, 2, 1, 3 + ) + v = v.reshape(-1, self.nnei, self.num_heads, self.head_dim).transpose( + 0, 2, 1, 3 + ) if self.normalize: q = np_normalize(q, axis=-1) k = np_normalize(k, axis=-1) v = np_normalize(v, axis=-1) q = q * self.scaling # Attention weights - attn_weights = q @ k.transpose(0, 2, 1) + # (nf x nloc) x num_heads x nnei x nnei + attn_weights = q @ k.transpose(0, 1, 3, 2) nei_mask = nei_mask.reshape(-1, self.nnei) if self.smooth: - sw = sw.reshape(-1, self.nnei) - attn_weights = (attn_weights + attnw_shift) * sw[:, None, :] * sw[ - :, :, None + sw = sw.reshape(-1, 1, self.nnei) + attn_weights = (attn_weights + attnw_shift) * sw[:, :, :, None] * sw[ + :, :, None, : ] - attnw_shift else: - attn_weights = np.where(nei_mask[:, None, :], attn_weights, -np.inf) + attn_weights = np.where(nei_mask[:, None, None, :], attn_weights, -np.inf) attn_weights = np_softmax(attn_weights, axis=-1) - attn_weights = np.where(nei_mask[:, :, None], attn_weights, 0.0) + attn_weights = np.where(nei_mask[:, None, :, None], attn_weights, 0.0) if self.smooth: - attn_weights = attn_weights * sw[:, None, :] * sw[:, :, None] + attn_weights = attn_weights * sw[:, :, :, None] * sw[:, :, None, :] if self.dotr: - angular_weight = input_r @ input_r.transpose(0, 2, 1) + angular_weight = (input_r @ input_r.transpose(0, 2, 1)).reshape( + -1, 1, self.nnei, self.nnei + ) attn_weights = attn_weights * angular_weight # Output projection + # (nf x nloc) x num_heads x nnei x head_dim o = attn_weights @ v + # (nf x nloc) x nnei x (num_heads x head_dim) + o = o.transpose(0, 2, 1, 3).reshape(-1, self.nnei, self.hidden_dim) output = self.out_proj(o) - return output + return output, attn_weights def serialize(self): return { "nnei": self.nnei, "embed_dim": self.embed_dim, "hidden_dim": self.hidden_dim, + "num_heads": self.num_heads, "dotr": self.dotr, "do_mask": self.do_mask, "scaling_factor": self.scaling_factor, diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py new file mode 100644 index 0000000000..b76530cf2f --- /dev/null +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -0,0 +1,675 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np + +from deepmd.dpmodel.utils.network import ( + Identity, + NativeLayer, +) +from deepmd.dpmodel.utils.nlist import ( + build_multiple_neighbor_list, + get_multiple_nlist_key, +) +from deepmd.dpmodel.utils.type_embed import ( + TypeEmbedNet, +) +from deepmd.dpmodel.utils.update_sel import ( + UpdateSel, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +try: + from deepmd._version import version as __version__ +except ImportError: + __version__ = "unknown" + +from typing import ( + List, + Optional, + Tuple, +) + +from deepmd.dpmodel import ( + NativeOP, +) +from deepmd.dpmodel.utils import ( + EnvMat, + NetworkCollection, +) + +from .base_descriptor import ( + BaseDescriptor, +) +from .dpa1 import ( + DescrptBlockSeAtten, +) +from .repformers import ( + DescrptBlockRepformers, + RepformerLayer, +) + + +@BaseDescriptor.register("dpa2") +class DescrptDPA2(NativeOP, BaseDescriptor): + def __init__( + self, + # args for repinit + ntypes: int, + repinit_rcut: float, + repinit_rcut_smth: float, + repinit_nsel: int, + repformer_rcut: float, + repformer_rcut_smth: float, + repformer_nsel: int, + # kwargs for repinit + repinit_neuron: List[int] = [25, 50, 100], + repinit_axis_neuron: int = 16, + repinit_tebd_dim: int = 8, + repinit_tebd_input_mode: str = "concat", + repinit_set_davg_zero: bool = True, + repinit_activation_function="tanh", + repinit_resnet_dt: bool = False, + repinit_type_one_side: bool = False, + # kwargs for repformer + repformer_nlayers: int = 3, + repformer_g1_dim: int = 128, + repformer_g2_dim: int = 16, + repformer_axis_neuron: int = 4, + repformer_direct_dist: bool = False, + repformer_update_g1_has_conv: bool = True, + repformer_update_g1_has_drrd: bool = True, + repformer_update_g1_has_grrg: bool = True, + repformer_update_g1_has_attn: bool = True, + repformer_update_g2_has_g1g1: bool = True, + repformer_update_g2_has_attn: bool = True, + repformer_update_h2: bool = False, + repformer_attn1_hidden: int = 64, + repformer_attn1_nhead: int = 4, + repformer_attn2_hidden: int = 16, + repformer_attn2_nhead: int = 4, + repformer_attn2_has_gate: bool = False, + repformer_activation_function: str = "tanh", + repformer_update_style: str = "res_avg", + repformer_update_residual: float = 0.001, + repformer_update_residual_init: str = "norm", + repformer_set_davg_zero: bool = True, + repformer_trainable_ln: bool = True, + repformer_ln_eps: Optional[float] = 1e-5, + # kwargs for descriptor + concat_output_tebd: bool = True, + precision: str = "float64", + smooth: bool = True, + exclude_types: List[Tuple[int, int]] = [], + env_protection: float = 0.0, + trainable: bool = True, + seed: Optional[int] = None, + add_tebd_to_repinit_out: bool = False, + ): + r"""The DPA-2 descriptor. see https://arxiv.org/abs/2312.15492. + + Parameters + ---------- + repinit_rcut : float + (Used in the repinit block.) + The cut-off radius. + repinit_rcut_smth : float + (Used in the repinit block.) + Where to start smoothing. For example the 1/r term is smoothed from rcut to rcut_smth. + repinit_nsel : int + (Used in the repinit block.) + Maximally possible number of selected neighbors. + repinit_neuron : list, optional + (Used in the repinit block.) + Number of neurons in each hidden layers of the embedding net. + When two layers are of the same size or one layer is twice as large as the previous layer, + a skip connection is built. + repinit_axis_neuron : int, optional + (Used in the repinit block.) + Size of the submatrix of G (embedding matrix). + repinit_tebd_dim : int, optional + (Used in the repinit block.) + The dimension of atom type embedding. + repinit_tebd_input_mode : str, optional + (Used in the repinit block.) + The input mode of the type embedding. Supported modes are ['concat', 'strip']. + repinit_set_davg_zero : bool, optional + (Used in the repinit block.) + Set the normalization average to zero. + repinit_activation_function : str, optional + (Used in the repinit block.) + The activation function in the embedding net. + repinit_resnet_dt : bool, optional + (Used in the repinit block.) + Whether to use a "Timestep" in the skip connection. + repinit_type_one_side : bool, optional + (Used in the repinit block.) + Whether to use one-side type embedding. + repformer_rcut : float + (Used in the repformer block.) + The cut-off radius. + repformer_rcut_smth : float + (Used in the repformer block.) + Where to start smoothing. For example the 1/r term is smoothed from rcut to rcut_smth. + repformer_nsel : int + (Used in the repformer block.) + Maximally possible number of selected neighbors. + repformer_nlayers : int, optional + (Used in the repformer block.) + Number of repformer layers. + repformer_g1_dim : int, optional + (Used in the repformer block.) + Dimension of the first graph convolution layer. + repformer_g2_dim : int, optional + (Used in the repformer block.) + Dimension of the second graph convolution layer. + repformer_axis_neuron : int, optional + (Used in the repformer block.) + Size of the submatrix of G (embedding matrix). + repformer_direct_dist : bool, optional + (Used in the repformer block.) + Whether to use direct distance information (1/r term) in the repformer block. + repformer_update_g1_has_conv : bool, optional + (Used in the repformer block.) + Whether to update the g1 rep with convolution term. + repformer_update_g1_has_drrd : bool, optional + (Used in the repformer block.) + Whether to update the g1 rep with the drrd term. + repformer_update_g1_has_grrg : bool, optional + (Used in the repformer block.) + Whether to update the g1 rep with the grrg term. + repformer_update_g1_has_attn : bool, optional + (Used in the repformer block.) + Whether to update the g1 rep with the localized self-attention. + repformer_update_g2_has_g1g1 : bool, optional + (Used in the repformer block.) + Whether to update the g2 rep with the g1xg1 term. + repformer_update_g2_has_attn : bool, optional + (Used in the repformer block.) + Whether to update the g2 rep with the gated self-attention. + repformer_update_h2 : bool, optional + (Used in the repformer block.) + Whether to update the h2 rep. + repformer_attn1_hidden : int, optional + (Used in the repformer block.) + The hidden dimension of localized self-attention to update the g1 rep. + repformer_attn1_nhead : int, optional + (Used in the repformer block.) + The number of heads in localized self-attention to update the g1 rep. + repformer_attn2_hidden : int, optional + (Used in the repformer block.) + The hidden dimension of gated self-attention to update the g2 rep. + repformer_attn2_nhead : int, optional + (Used in the repformer block.) + The number of heads in gated self-attention to update the g2 rep. + repformer_attn2_has_gate : bool, optional + (Used in the repformer block.) + Whether to use gate in the gated self-attention to update the g2 rep. + repformer_activation_function : str, optional + (Used in the repformer block.) + The activation function in the embedding net. + repformer_update_style : str, optional + (Used in the repformer block.) + Style to update a representation. + Supported options are: + -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) + -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) + where `r1`, `r2` ... `r3` are residual weights defined by `repformer_update_residual` + and `repformer_update_residual_init`. + repformer_update_residual : float, optional + (Used in the repformer block.) + When update using residual mode, the initial std of residual vector weights. + repformer_update_residual_init : str, optional + (Used in the repformer block.) + When update using residual mode, the initialization mode of residual vector weights. + repformer_set_davg_zero : bool, optional + (Used in the repformer block.) + Set the normalization average to zero. + repformer_trainable_ln : bool, optional + (Used in the repformer block.) + Whether to use trainable shift and scale weights in layer normalization. + repformer_ln_eps : float, optional + (Used in the repformer block.) + The epsilon value for layer normalization. + concat_output_tebd : bool, optional + Whether to concat type embedding at the output of the descriptor. + precision : str, optional + The precision of the embedding net parameters. + smooth : bool, optional + Whether to use smoothness in processes such as attention weights calculation. + exclude_types : List[List[int]], optional + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + trainable : bool, optional + If the parameters are trainable. + seed : int, optional + (Unused yet) Random seed for parameter initialization. + add_tebd_to_repinit_out : bool, optional + Whether to add type embedding to the output representation from repinit before inputting it into repformer. + + Returns + ------- + descriptor: torch.Tensor + the descriptor of shape nf x nloc x g1_dim. + invariant single-atom representation. + g2: torch.Tensor + invariant pair-atom representation. + h2: torch.Tensor + equivariant pair-atom representation. + rot_mat: torch.Tensor + rotation matrix for equivariant fittings + sw: torch.Tensor + The switch function for decaying inverse distance. + + """ + # to keep consistent with default value in this backends + if repformer_ln_eps is None: + repformer_ln_eps = 1e-5 + self.repinit = DescrptBlockSeAtten( + repinit_rcut, + repinit_rcut_smth, + repinit_nsel, + ntypes, + attn_layer=0, + neuron=repinit_neuron, + axis_neuron=repinit_axis_neuron, + tebd_dim=repinit_tebd_dim, + tebd_input_mode=repinit_tebd_input_mode, + set_davg_zero=repinit_set_davg_zero, + exclude_types=exclude_types, + env_protection=env_protection, + activation_function=repinit_activation_function, + precision=precision, + resnet_dt=repinit_resnet_dt, + smooth=smooth, + type_one_side=repinit_type_one_side, + ) + self.repformers = DescrptBlockRepformers( + repformer_rcut, + repformer_rcut_smth, + repformer_nsel, + ntypes, + nlayers=repformer_nlayers, + g1_dim=repformer_g1_dim, + g2_dim=repformer_g2_dim, + axis_neuron=repformer_axis_neuron, + direct_dist=repformer_direct_dist, + update_g1_has_conv=repformer_update_g1_has_conv, + update_g1_has_drrd=repformer_update_g1_has_drrd, + update_g1_has_grrg=repformer_update_g1_has_grrg, + update_g1_has_attn=repformer_update_g1_has_attn, + update_g2_has_g1g1=repformer_update_g2_has_g1g1, + update_g2_has_attn=repformer_update_g2_has_attn, + update_h2=repformer_update_h2, + attn1_hidden=repformer_attn1_hidden, + attn1_nhead=repformer_attn1_nhead, + attn2_hidden=repformer_attn2_hidden, + attn2_nhead=repformer_attn2_nhead, + attn2_has_gate=repformer_attn2_has_gate, + activation_function=repformer_activation_function, + update_style=repformer_update_style, + update_residual=repformer_update_residual, + update_residual_init=repformer_update_residual_init, + set_davg_zero=repformer_set_davg_zero, + smooth=smooth, + exclude_types=exclude_types, + env_protection=env_protection, + precision=precision, + trainable_ln=repformer_trainable_ln, + ln_eps=repformer_ln_eps, + ) + self.type_embedding = TypeEmbedNet( + ntypes=ntypes, + neuron=[repinit_tebd_dim], + padding=True, + activation_function="Linear", + precision=precision, + ) + self.concat_output_tebd = concat_output_tebd + self.precision = precision + self.smooth = smooth + self.exclude_types = exclude_types + self.env_protection = env_protection + self.trainable = trainable + self.add_tebd_to_repinit_out = add_tebd_to_repinit_out + + if self.repinit.dim_out == self.repformers.dim_in: + self.g1_shape_tranform = Identity() + else: + self.g1_shape_tranform = NativeLayer( + self.repinit.dim_out, + self.repformers.dim_in, + bias=False, + precision=precision, + ) + self.tebd_transform = None + if self.add_tebd_to_repinit_out: + self.tebd_transform = NativeLayer( + repinit_tebd_dim, + self.repformers.dim_in, + bias=False, + precision=precision, + ) + assert self.repinit.rcut > self.repformers.rcut + assert self.repinit.sel[0] > self.repformers.sel[0] + + self.tebd_dim = repinit_tebd_dim + self.rcut = self.repinit.get_rcut() + self.ntypes = ntypes + self.sel = self.repinit.sel + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_out(self) -> int: + """Returns the output dimension of this descriptor.""" + ret = self.repformers.dim_out + if self.concat_output_tebd: + ret += self.tebd_dim + return ret + + def get_dim_emb(self) -> int: + """Returns the embedding dimension of this descriptor.""" + return self.repformers.dim_emb + + def mixed_types(self) -> bool: + """If true, the discriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the discriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + raise NotImplementedError + + @property + def dim_out(self): + return self.get_dim_out() + + @property + def dim_emb(self): + """Returns the embedding dimension g2.""" + return self.get_dim_emb() + + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): + """Update mean and stddev for descriptor elements.""" + raise NotImplementedError + + def call( + self, + coord_ext: np.ndarray, + atype_ext: np.ndarray, + nlist: np.ndarray, + mapping: Optional[np.ndarray] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nallx3) + atype_ext + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping, maps extended region index to local region. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + sw + The smooth switch function. shape: nf x nloc x nnei + + """ + nframes, nloc, nnei = nlist.shape + nall = coord_ext.reshape(nframes, -1).shape[1] // 3 + # nlists + nlist_dict = build_multiple_neighbor_list( + coord_ext, + nlist, + [self.repformers.get_rcut(), self.repinit.get_rcut()], + [self.repformers.get_nsel(), self.repinit.get_nsel()], + ) + # repinit + g1_ext = self.type_embedding.call()[atype_ext] + g1_inp = g1_ext[:, :nloc, :] + g1, _, _, _, _ = self.repinit( + nlist_dict[ + get_multiple_nlist_key(self.repinit.get_rcut(), self.repinit.get_nsel()) + ], + coord_ext, + atype_ext, + g1_ext, + mapping, + ) + # linear to change shape + g1 = self.g1_shape_tranform(g1) + if self.add_tebd_to_repinit_out: + assert self.tebd_transform is not None + g1 = g1 + self.tebd_transform(g1_inp) + # mapping g1 + assert mapping is not None + mapping_ext = np.tile(mapping.reshape(nframes, nall, 1), (1, 1, g1.shape[-1])) + g1_ext = np.take_along_axis(g1, mapping_ext, axis=1) + # repformer + g1, g2, h2, rot_mat, sw = self.repformers( + nlist_dict[ + get_multiple_nlist_key( + self.repformers.get_rcut(), self.repformers.get_nsel() + ) + ], + coord_ext, + atype_ext, + g1_ext, + mapping, + ) + if self.concat_output_tebd: + g1 = np.concatenate([g1, g1_inp], axis=-1) + return g1, rot_mat, g2, h2, sw + + def serialize(self) -> dict: + repinit = self.repinit + repformers = self.repformers + data = { + "@class": "Descriptor", + "type": "dpa2", + "@version": 1, + "ntypes": self.ntypes, + "repinit_rcut": repinit.rcut, + "repinit_rcut_smth": repinit.rcut_smth, + "repinit_nsel": repinit.sel, + "repformer_rcut": repformers.rcut, + "repformer_rcut_smth": repformers.rcut_smth, + "repformer_nsel": repformers.sel, + "repinit_neuron": repinit.neuron, + "repinit_axis_neuron": repinit.axis_neuron, + "repinit_tebd_dim": repinit.tebd_dim, + "repinit_tebd_input_mode": repinit.tebd_input_mode, + "repinit_set_davg_zero": repinit.set_davg_zero, + "repinit_activation_function": repinit.activation_function, + "repinit_resnet_dt": repinit.resnet_dt, + "repinit_type_one_side": repinit.type_one_side, + "repformer_nlayers": repformers.nlayers, + "repformer_g1_dim": repformers.g1_dim, + "repformer_g2_dim": repformers.g2_dim, + "repformer_axis_neuron": repformers.axis_neuron, + "repformer_direct_dist": repformers.direct_dist, + "repformer_update_g1_has_conv": repformers.update_g1_has_conv, + "repformer_update_g1_has_drrd": repformers.update_g1_has_drrd, + "repformer_update_g1_has_grrg": repformers.update_g1_has_grrg, + "repformer_update_g1_has_attn": repformers.update_g1_has_attn, + "repformer_update_g2_has_g1g1": repformers.update_g2_has_g1g1, + "repformer_update_g2_has_attn": repformers.update_g2_has_attn, + "repformer_update_h2": repformers.update_h2, + "repformer_attn1_hidden": repformers.attn1_hidden, + "repformer_attn1_nhead": repformers.attn1_nhead, + "repformer_attn2_hidden": repformers.attn2_hidden, + "repformer_attn2_nhead": repformers.attn2_nhead, + "repformer_attn2_has_gate": repformers.attn2_has_gate, + "repformer_activation_function": repformers.activation_function, + "repformer_update_style": repformers.update_style, + "repformer_set_davg_zero": repformers.set_davg_zero, + "repformer_trainable_ln": repformers.trainable_ln, + "repformer_ln_eps": repformers.ln_eps, + "concat_output_tebd": self.concat_output_tebd, + "precision": self.precision, + "smooth": self.smooth, + "exclude_types": self.exclude_types, + "env_protection": self.env_protection, + "trainable": self.trainable, + "add_tebd_to_repinit_out": self.add_tebd_to_repinit_out, + "type_embedding": self.type_embedding.serialize(), + "g1_shape_tranform": self.g1_shape_tranform.serialize(), + } + if self.add_tebd_to_repinit_out: + data.update( + { + "tebd_transform": self.tebd_transform.serialize(), + } + ) + repinit_variable = { + "embeddings": repinit.embeddings.serialize(), + "env_mat": EnvMat(repinit.rcut, repinit.rcut_smth).serialize(), + "@variables": { + "davg": repinit["davg"], + "dstd": repinit["dstd"], + }, + } + if repinit.tebd_input_mode in ["strip"]: + repinit_variable.update( + {"embeddings_strip": repinit.embeddings_strip.serialize()} + ) + repformers_variable = { + "g2_embd": repformers.g2_embd.serialize(), + "repformer_layers": [layer.serialize() for layer in repformers.layers], + "env_mat": EnvMat(repformers.rcut, repformers.rcut_smth).serialize(), + "@variables": { + "davg": repformers["davg"], + "dstd": repformers["dstd"], + }, + } + data.update( + { + "repinit": repinit_variable, + "repformers": repformers_variable, + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA2": + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + repinit_variable = data.pop("repinit").copy() + repformers_variable = data.pop("repformers").copy() + type_embedding = data.pop("type_embedding") + g1_shape_tranform = data.pop("g1_shape_tranform") + tebd_transform = data.pop("tebd_transform", None) + add_tebd_to_repinit_out = data["add_tebd_to_repinit_out"] + obj = cls(**data) + obj.type_embedding = TypeEmbedNet.deserialize(type_embedding) + if add_tebd_to_repinit_out: + assert isinstance(tebd_transform, dict) + obj.tebd_transform = NativeLayer.deserialize(tebd_transform) + if obj.repinit.dim_out != obj.repformers.dim_in: + obj.g1_shape_tranform = NativeLayer.deserialize(g1_shape_tranform) + + # deserialize repinit + statistic_repinit = repinit_variable.pop("@variables") + env_mat = repinit_variable.pop("env_mat") + tebd_input_mode = data["repinit_tebd_input_mode"] + obj.repinit.embeddings = NetworkCollection.deserialize( + repinit_variable.pop("embeddings") + ) + if tebd_input_mode in ["strip"]: + obj.repinit.embeddings_strip = NetworkCollection.deserialize( + repinit_variable.pop("embeddings_strip") + ) + obj.repinit["davg"] = statistic_repinit["davg"] + obj.repinit["dstd"] = statistic_repinit["dstd"] + + # deserialize repformers + statistic_repformers = repformers_variable.pop("@variables") + env_mat = repformers_variable.pop("env_mat") + repformer_layers = repformers_variable.pop("repformer_layers") + obj.repformers.g2_embd = NativeLayer.deserialize( + repformers_variable.pop("g2_embd") + ) + obj.repformers["davg"] = statistic_repformers["davg"] + obj.repformers["dstd"] = statistic_repformers["dstd"] + obj.repformers.layers = [ + RepformerLayer.deserialize(layer) for layer in repformer_layers + ] + return obj + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + local_jdata_cpy = local_jdata.copy() + update_sel = UpdateSel() + local_jdata_cpy = update_sel.update_one_sel( + global_jdata, + local_jdata_cpy, + True, + rcut_key="repinit_rcut", + sel_key="repinit_nsel", + ) + local_jdata_cpy = update_sel.update_one_sel( + global_jdata, + local_jdata_cpy, + True, + rcut_key="repformer_rcut", + sel_key="repformer_nsel", + ) + return local_jdata_cpy diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py new file mode 100644 index 0000000000..f0275b44b1 --- /dev/null +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -0,0 +1,1637 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np + +from deepmd.dpmodel.utils.network import ( + LayerNorm, + NativeLayer, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +try: + from deepmd._version import version as __version__ +except ImportError: + __version__ = "unknown" + +from typing import ( + Callable, + List, + Optional, + Tuple, + Union, +) + +from deepmd.dpmodel import ( + PRECISION_DICT, + NativeOP, +) +from deepmd.dpmodel.utils import ( + EnvMat, + PairExcludeMask, +) +from deepmd.dpmodel.utils.network import ( + get_activation_fn, +) + +from .descriptor import ( + DescriptorBlock, +) +from .dpa1 import ( + np_softmax, +) + + +@DescriptorBlock.register("se_repformer") +@DescriptorBlock.register("se_uni") +class DescrptBlockRepformers(NativeOP, DescriptorBlock): + def __init__( + self, + rcut, + rcut_smth, + sel: int, + ntypes: int, + nlayers: int = 3, + g1_dim=128, + g2_dim=16, + axis_neuron: int = 4, + direct_dist: bool = False, + update_g1_has_conv: bool = True, + update_g1_has_drrd: bool = True, + update_g1_has_grrg: bool = True, + update_g1_has_attn: bool = True, + update_g2_has_g1g1: bool = True, + update_g2_has_attn: bool = True, + update_h2: bool = False, + attn1_hidden: int = 64, + attn1_nhead: int = 4, + attn2_hidden: int = 16, + attn2_nhead: int = 4, + attn2_has_gate: bool = False, + activation_function: str = "tanh", + update_style: str = "res_avg", + update_residual: float = 0.001, + update_residual_init: str = "norm", + set_davg_zero: bool = True, + smooth: bool = True, + exclude_types: List[Tuple[int, int]] = [], + env_protection: float = 0.0, + precision: str = "float64", + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + ): + r""" + The repformer descriptor block. + + Parameters + ---------- + rcut : float + The cut-off radius. + rcut_smth : float + Where to start smoothing. For example the 1/r term is smoothed from rcut to rcut_smth. + sel : int + Maximally possible number of selected neighbors. + ntypes : int + Number of element types + nlayers : int, optional + Number of repformer layers. + g1_dim : int, optional + Dimension of the first graph convolution layer. + g2_dim : int, optional + Dimension of the second graph convolution layer. + axis_neuron : int, optional + Size of the submatrix of G (embedding matrix). + direct_dist : bool, optional + Whether to use direct distance information (1/r term) in the repformer block. + update_g1_has_conv : bool, optional + Whether to update the g1 rep with convolution term. + update_g1_has_drrd : bool, optional + Whether to update the g1 rep with the drrd term. + update_g1_has_grrg : bool, optional + Whether to update the g1 rep with the grrg term. + update_g1_has_attn : bool, optional + Whether to update the g1 rep with the localized self-attention. + update_g2_has_g1g1 : bool, optional + Whether to update the g2 rep with the g1xg1 term. + update_g2_has_attn : bool, optional + Whether to update the g2 rep with the gated self-attention. + update_h2 : bool, optional + Whether to update the h2 rep. + attn1_hidden : int, optional + The hidden dimension of localized self-attention to update the g1 rep. + attn1_nhead : int, optional + The number of heads in localized self-attention to update the g1 rep. + attn2_hidden : int, optional + The hidden dimension of gated self-attention to update the g2 rep. + attn2_nhead : int, optional + The number of heads in gated self-attention to update the g2 rep. + attn2_has_gate : bool, optional + Whether to use gate in the gated self-attention to update the g2 rep. + activation_function : str, optional + The activation function in the embedding net. + update_style : str, optional + Style to update a representation. + Supported options are: + -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) + -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) + where `r1`, `r2` ... `r3` are residual weights defined by `update_residual` + and `update_residual_init`. + update_residual : float, optional + When update using residual mode, the initial std of residual vector weights. + update_residual_init : str, optional + When update using residual mode, the initialization mode of residual vector weights. + set_davg_zero : bool, optional + Set the normalization average to zero. + precision : str, optional + The precision of the embedding net parameters. + smooth : bool, optional + Whether to use smoothness in processes such as attention weights calculation. + exclude_types : List[List[int]], optional + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + trainable_ln : bool, optional + Whether to use trainable shift and scale weights in layer normalization. + ln_eps : float, optional + The epsilon value for layer normalization. + """ + super().__init__() + self.rcut = rcut + self.rcut_smth = rcut_smth + self.ntypes = ntypes + self.nlayers = nlayers + sel = [sel] if isinstance(sel, int) else sel + self.nnei = sum(sel) + self.ndescrpt = self.nnei * 4 # use full descriptor. + assert len(sel) == 1 + self.sel = sel + self.sec = self.sel + self.split_sel = self.sel + self.axis_neuron = axis_neuron + self.set_davg_zero = set_davg_zero + self.g1_dim = g1_dim + self.g2_dim = g2_dim + self.update_g1_has_conv = update_g1_has_conv + self.update_g1_has_drrd = update_g1_has_drrd + self.update_g1_has_grrg = update_g1_has_grrg + self.update_g1_has_attn = update_g1_has_attn + self.update_g2_has_g1g1 = update_g2_has_g1g1 + self.update_g2_has_attn = update_g2_has_attn + self.update_h2 = update_h2 + self.attn1_hidden = attn1_hidden + self.attn1_nhead = attn1_nhead + self.attn2_has_gate = attn2_has_gate + self.attn2_hidden = attn2_hidden + self.attn2_nhead = attn2_nhead + self.activation_function = activation_function + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.direct_dist = direct_dist + self.act = get_activation_fn(self.activation_function) + self.smooth = smooth + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) + self.env_protection = env_protection + self.precision = precision + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.epsilon = 1e-4 + + self.g2_embd = NativeLayer(1, self.g2_dim, precision=precision) + layers = [] + for ii in range(nlayers): + layers.append( + RepformerLayer( + self.rcut, + self.rcut_smth, + self.sel, + self.ntypes, + self.g1_dim, + self.g2_dim, + axis_neuron=self.axis_neuron, + update_chnnl_2=(ii != nlayers - 1), + update_g1_has_conv=self.update_g1_has_conv, + update_g1_has_drrd=self.update_g1_has_drrd, + update_g1_has_grrg=self.update_g1_has_grrg, + update_g1_has_attn=self.update_g1_has_attn, + update_g2_has_g1g1=self.update_g2_has_g1g1, + update_g2_has_attn=self.update_g2_has_attn, + update_h2=self.update_h2, + attn1_hidden=self.attn1_hidden, + attn1_nhead=self.attn1_nhead, + attn2_has_gate=self.attn2_has_gate, + attn2_hidden=self.attn2_hidden, + attn2_nhead=self.attn2_nhead, + activation_function=self.activation_function, + update_style=self.update_style, + update_residual=self.update_residual, + update_residual_init=self.update_residual_init, + smooth=self.smooth, + trainable_ln=self.trainable_ln, + ln_eps=self.ln_eps, + precision=precision, + ) + ) + self.layers = layers + + wanted_shape = (self.ntypes, self.nnei, 4) + self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection) + self.mean = np.zeros(wanted_shape, dtype=PRECISION_DICT[self.precision]) + self.stddev = np.ones(wanted_shape, dtype=PRECISION_DICT[self.precision]) + self.orig_sel = self.sel + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_in(self) -> int: + """Returns the input dimension.""" + return self.dim_in + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.dim_out + + def get_dim_emb(self) -> int: + """Returns the embedding dimension g2.""" + return self.g2_dim + + def __setitem__(self, key, value): + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + + def mixed_types(self) -> bool: + """If true, the discriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the discriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.g1_dim + + @property + def dim_in(self): + """Returns the atomic input dimension of this descriptor.""" + return self.g1_dim + + @property + def dim_emb(self): + """Returns the embedding dimension g2.""" + return self.get_dim_emb() + + def compute_input_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + path: Optional[DPPath] = None, + ): + """Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.""" + raise NotImplementedError + + def get_stats(self): + """Get the statistics of the descriptor.""" + raise NotImplementedError + + def reinit_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + + def call( + self, + nlist: np.ndarray, + coord_ext: np.ndarray, + atype_ext: np.ndarray, + atype_embd_ext: Optional[np.ndarray] = None, + mapping: Optional[np.ndarray] = None, + ): + exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) + nlist = nlist * exclude_mask + # nf x nloc x nnei x 4 + dmatrix, diff, sw = self.env_mat.call( + coord_ext, atype_ext, nlist, self.mean, self.stddev + ) + nf, nloc, nnei, _ = dmatrix.shape + # nf x nloc x nnei + nlist_mask = nlist != -1 + # nf x nloc x nnei + sw = sw.reshape(nf, nloc, nnei) + sw = np.where(nlist_mask, sw, 0.0) + # nf x nloc x tebd_dim + atype_embd = atype_embd_ext[:, :nloc, :] + assert list(atype_embd.shape) == [nf, nloc, self.g1_dim] + + g1 = self.act(atype_embd) + # nf x nloc x nnei x 1, nf x nloc x nnei x 3 + if not self.direct_dist: + g2, h2 = np.split(dmatrix, [1], axis=-1) + else: + g2, h2 = np.linalg.norm(diff, axis=-1, keepdims=True), diff + g2 = g2 / self.rcut + h2 = h2 / self.rcut + # nf x nloc x nnei x ng2 + g2 = self.act(self.g2_embd(g2)) + # set all padding positions to index of 0 + # if a neighbor is real or not is indicated by nlist_mask + nlist[nlist == -1] = 0 + # nf x nall x ng1 + mapping = np.tile(mapping.reshape(nf, -1, 1), (1, 1, self.g1_dim)) + for idx, ll in enumerate(self.layers): + # g1: nf x nloc x ng1 + # g1_ext: nf x nall x ng1 + g1_ext = np.take_along_axis(g1, mapping, axis=1) + g1, g2, h2 = ll.call( + g1_ext, + g2, + h2, + nlist, + nlist_mask, + sw, + ) + + # nf x nloc x 3 x ng2 + h2g2 = _cal_hg(g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon) + # (nf x nloc) x ng2 x 3 + rot_mat = np.transpose(h2g2, (0, 1, 3, 2)) + return g1, g2, h2, rot_mat.reshape(-1, nloc, self.dim_emb, 3), sw + + +# translated by GPT and modified +def get_residual( + _dim: int, + _scale: float, + _mode: str = "norm", + trainable: bool = True, + precision: str = "float64", +) -> np.ndarray: + """ + Get residual tensor for one update vector. + + Parameters + ---------- + _dim : int + The dimension of the update vector. + _scale + The initial scale of the residual tensor. See `_mode` for details. + _mode + The mode of residual initialization for the residual tensor. + - "norm" (default): init residual using normal with `_scale` std. + - "const": init residual using element-wise constants of `_scale`. + trainable + Whether the residual tensor is trainable. + precision + The precision of the residual tensor. + """ + residual = np.zeros(_dim, dtype=PRECISION_DICT[precision]) + rng = np.random.default_rng() + if trainable: + if _mode == "norm": + residual = rng.normal(scale=_scale, size=_dim).astype( + PRECISION_DICT[precision] + ) + elif _mode == "const": + residual.fill(_scale) + else: + raise RuntimeError(f"Unsupported initialization mode '{_mode}'!") + return residual + + +def _make_nei_g1( + g1_ext: np.ndarray, + nlist: np.ndarray, +) -> np.ndarray: + """ + Make neighbor-wise atomic invariant rep. + + Parameters + ---------- + g1_ext + Extended atomic invariant rep, with shape [nf, nall, ng1]. + nlist + Neighbor list, with shape [nf, nloc, nnei]. + + Returns + ------- + gg1: np.ndarray + Neighbor-wise atomic invariant rep, with shape [nf, nloc, nnei, ng1]. + """ + # nlist: nf x nloc x nnei + nf, nloc, nnei = nlist.shape + # g1_ext: nf x nall x ng1 + ng1 = g1_ext.shape[-1] + # index: nf x (nloc x nnei) x ng1 + index = np.tile(nlist.reshape(nf, nloc * nnei, 1), (1, 1, ng1)) + # gg1 : nf x (nloc x nnei) x ng1 + gg1 = np.take_along_axis(g1_ext, index, axis=1) + # gg1 : nf x nloc x nnei x ng1 + gg1 = gg1.reshape(nf, nloc, nnei, ng1) + return gg1 + + +def _apply_nlist_mask( + gg: np.ndarray, + nlist_mask: np.ndarray, +) -> np.ndarray: + """ + Apply nlist mask to neighbor-wise rep tensors. + + Parameters + ---------- + gg + Neighbor-wise rep tensors, with shape [nf, nloc, nnei, d]. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape [nf, nloc, nnei]. + """ + masked_gg = np.where(nlist_mask[:, :, :, None], gg, 0.0) + return masked_gg + + +def _apply_switch(gg: np.ndarray, sw: np.ndarray) -> np.ndarray: + """ + Apply switch function to neighbor-wise rep tensors. + + Parameters + ---------- + gg + Neighbor-wise rep tensors, with shape [nf, nloc, nnei, d]. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape [nf, nloc, nnei]. + """ + # gg: nf x nloc x nnei x d + # sw: nf x nloc x nnei + return gg * sw[:, :, :, None] + + +def _cal_hg( + g: np.ndarray, + h: np.ndarray, + nlist_mask: np.ndarray, + sw: np.ndarray, + smooth: bool = True, + epsilon: float = 1e-4, +) -> np.ndarray: + """ + Calculate the transposed rotation matrix. + + Parameters + ---------- + g + Neighbor-wise/Pair-wise invariant rep tensors, with shape [nf, nloc, nnei, ng]. + h + Neighbor-wise/Pair-wise equivariant rep tensors, with shape [nf, nloc, nnei, 3]. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape [nf, nloc, nnei]. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape [nf, nloc, nnei]. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. + + Returns + ------- + hg + The transposed rotation matrix, with shape [nf, nloc, 3, ng]. + """ + # g: nf x nloc x nnei x ng + # h: nf x nloc x nnei x 3 + # msk: nf x nloc x nnei + nf, nloc, nnei, _ = g.shape + ng = g.shape[-1] + # nf x nloc x nnei x ng + g = _apply_nlist_mask(g, nlist_mask) + if not smooth: + # nf x nloc + invnnei = 1.0 / (epsilon + np.sum(nlist_mask, axis=-1)) + # nf x nloc x 1 x 1 + invnnei = invnnei[:, :, np.newaxis, np.newaxis] + else: + g = _apply_switch(g, sw) + invnnei = (1.0 / float(nnei)) * np.ones((nf, nloc, 1, 1), dtype=g.dtype) + # nf x nloc x 3 x ng + hg = np.matmul(np.transpose(h, axes=(0, 1, 3, 2)), g) * invnnei + return hg + + +def _cal_grrg(hg: np.ndarray, axis_neuron: int) -> np.ndarray: + """ + Calculate the atomic invariant rep. + + Parameters + ---------- + hg + The transposed rotation matrix, with shape [nf, nloc, 3, ng]. + axis_neuron + Size of the submatrix. + + Returns + ------- + grrg + Atomic invariant rep, with shape [nf, nloc, (axis_neuron * ng)]. + """ + # nf x nloc x 3 x ng + nf, nloc, _, ng = hg.shape + # nf x nloc x 3 x axis + hgm = np.split(hg, [axis_neuron], axis=-1)[0] + # nf x nloc x axis_neuron x ng + grrg = np.matmul(np.transpose(hgm, axes=(0, 1, 3, 2)), hg) / (3.0**1) + # nf x nloc x (axis_neuron * ng) + grrg = grrg.reshape(nf, nloc, axis_neuron * ng) + return grrg + + +def symmetrization_op( + g: np.ndarray, + h: np.ndarray, + nlist_mask: np.ndarray, + sw: np.ndarray, + axis_neuron: int, + smooth: bool = True, + epsilon: float = 1e-4, +) -> np.ndarray: + """ + Symmetrization operator to obtain atomic invariant rep. + + Parameters + ---------- + g + Neighbor-wise/Pair-wise invariant rep tensors, with shape [nf, nloc, nnei, ng]. + h + Neighbor-wise/Pair-wise equivariant rep tensors, with shape [nf, nloc, nnei, 3]. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape [nf, nloc, nnei]. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape [nf, nloc, nnei]. + axis_neuron + Size of the submatrix. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. + + Returns + ------- + grrg + Atomic invariant rep, with shape [nf, nloc, (axis_neuron * ng)]. + """ + # g: nf x nloc x nnei x ng + # h: nf x nloc x nnei x 3 + # msk: nf x nloc x nnei + nf, nloc, nnei, _ = g.shape + # nf x nloc x 3 x ng + hg = _cal_hg(g, h, nlist_mask, sw, smooth=smooth, epsilon=epsilon) + # nf x nloc x (axis_neuron x ng) + grrg = _cal_grrg(hg, axis_neuron) + return grrg + + +class Atten2Map(NativeOP): + def __init__( + self, + input_dim: int, + hidden_dim: int, + head_num: int, + has_gate: bool = False, # apply gate to attn map + smooth: bool = True, + attnw_shift: float = 20.0, + precision: str = "float64", + ): + """Return neighbor-wise multi-head self-attention maps, with gate mechanism.""" + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.head_num = head_num + self.mapqk = NativeLayer( + input_dim, hidden_dim * 2 * head_num, bias=False, precision=precision + ) + self.has_gate = has_gate + self.smooth = smooth + self.attnw_shift = attnw_shift + self.precision = precision + + def call( + self, + g2: np.ndarray, # nf x nloc x nnei x ng2 + h2: np.ndarray, # nf x nloc x nnei x 3 + nlist_mask: np.ndarray, # nf x nloc x nnei + sw: np.ndarray, # nf x nloc x nnei + ) -> np.ndarray: + ( + nf, + nloc, + nnei, + _, + ) = g2.shape + nd, nh = self.hidden_dim, self.head_num + # nf x nloc x nnei x nd x (nh x 2) + g2qk = self.mapqk(g2).reshape(nf, nloc, nnei, nd, nh * 2) + # nf x nloc x (nh x 2) x nnei x nd + g2qk = np.transpose(g2qk, (0, 1, 4, 2, 3)) + # nf x nloc x nh x nnei x nd + g2q, g2k = np.split(g2qk, [nh], axis=2) + # g2q = np.linalg.norm(g2q, axis=-1) + # g2k = np.linalg.norm(g2k, axis=-1) + # nf x nloc x nh x nnei x nnei + attnw = np.matmul(g2q, np.transpose(g2k, axes=(0, 1, 2, 4, 3))) / nd**0.5 + if self.has_gate: + gate = np.matmul(h2, np.transpose(h2, axes=(0, 1, 3, 2))).reshape( + nf, nloc, 1, nnei, nnei + ) + attnw = attnw * gate + # mask the attenmap, nf x nloc x 1 x 1 x nnei + attnw_mask = ~np.expand_dims(np.expand_dims(nlist_mask, axis=2), axis=2) + # mask the attenmap, nf x nloc x 1 x nnei x 1 + attnw_mask_c = ~np.expand_dims(np.expand_dims(nlist_mask, axis=2), axis=-1) + if self.smooth: + attnw = (attnw + self.attnw_shift) * sw[:, :, None, :, None] * sw[ + :, :, None, None, : + ] - self.attnw_shift + else: + attnw = np.where(attnw_mask, -np.inf, attnw) + attnw = np_softmax(attnw, axis=-1) + attnw = np.where(attnw_mask, 0.0, attnw) + # nf x nloc x nh x nnei x nnei + attnw = np.where(attnw_mask_c, 0.0, attnw) + if self.smooth: + attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :] + # nf x nloc x nnei x nnei + h2h2t = np.matmul(h2, np.transpose(h2, axes=(0, 1, 3, 2))) / 3.0**0.5 + # nf x nloc x nh x nnei x nnei + ret = attnw * h2h2t[:, :, None, :, :] + # ret = np.exp(g2qk - np.max(g2qk, axis=-1, keepdims=True)) + # nf x nloc x nnei x nnei x nh + ret = np.transpose(ret, (0, 1, 3, 4, 2)) + return ret + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "Atten2Map", + "@version": 1, + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "head_num": self.head_num, + "has_gate": self.has_gate, + "smooth": self.smooth, + "attnw_shift": self.attnw_shift, + "precision": self.precision, + "mapqk": self.mapqk.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "Atten2Map": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + mapqk = data.pop("mapqk") + obj = cls(**data) + obj.mapqk = NativeLayer.deserialize(mapqk) + return obj + + +class Atten2MultiHeadApply(NativeOP): + def __init__( + self, + input_dim: int, + head_num: int, + precision: str = "float64", + ): + super().__init__() + self.input_dim = input_dim + self.head_num = head_num + self.mapv = NativeLayer( + input_dim, input_dim * head_num, bias=False, precision=precision + ) + self.head_map = NativeLayer( + input_dim * head_num, input_dim, precision=precision + ) + self.precision = precision + + def call( + self, + AA: np.ndarray, # nf x nloc x nnei x nnei x nh + g2: np.ndarray, # nf x nloc x nnei x ng2 + ) -> np.ndarray: + nf, nloc, nnei, ng2 = g2.shape + nh = self.head_num + # nf x nloc x nnei x ng2 x nh + g2v = self.mapv(g2).reshape(nf, nloc, nnei, ng2, nh) + # nf x nloc x nh x nnei x ng2 + g2v = np.transpose(g2v, (0, 1, 4, 2, 3)) + # g2v = np.linalg.norm(g2v, axis=-1) + # nf x nloc x nh x nnei x nnei + AA = np.transpose(AA, (0, 1, 4, 2, 3)) + # nf x nloc x nh x nnei x ng2 + ret = np.matmul(AA, g2v) + # nf x nloc x nnei x ng2 x nh + ret = np.transpose(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, (ng2 * nh)) + # nf x nloc x nnei x ng2 + return self.head_map(ret) + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "Atten2MultiHeadApply", + "@version": 1, + "input_dim": self.input_dim, + "head_num": self.head_num, + "precision": self.precision, + "mapv": self.mapv.serialize(), + "head_map": self.head_map.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "Atten2MultiHeadApply": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + mapv = data.pop("mapv") + head_map = data.pop("head_map") + obj = cls(**data) + obj.mapv = NativeLayer.deserialize(mapv) + obj.head_map = NativeLayer.deserialize(head_map) + return obj + + +class Atten2EquiVarApply(NativeOP): + def __init__( + self, + input_dim: int, + head_num: int, + precision: str = "float64", + ): + super().__init__() + self.input_dim = input_dim + self.head_num = head_num + self.head_map = NativeLayer(head_num, 1, bias=False, precision=precision) + self.precision = precision + + def call( + self, + AA: np.ndarray, # nf x nloc x nnei x nnei x nh + h2: np.ndarray, # nf x nloc x nnei x 3 + ) -> np.ndarray: + nf, nloc, nnei, _ = h2.shape + nh = self.head_num + # nf x nloc x nh x nnei x nnei + AA = np.transpose(AA, (0, 1, 4, 2, 3)) + h2m = np.expand_dims(h2, axis=2) + # nf x nloc x nh x nnei x 3 + h2m = np.tile(h2m, (1, 1, nh, 1, 1)) + # nf x nloc x nh x nnei x 3 + ret = np.matmul(AA, h2m) + # nf x nloc x nnei x 3 x nh + ret = np.transpose(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, 3, nh) + # nf x nloc x nnei x 3 + return np.squeeze(self.head_map(ret), axis=-1) + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "Atten2EquiVarApply", + "@version": 1, + "input_dim": self.input_dim, + "head_num": self.head_num, + "precision": self.precision, + "head_map": self.head_map.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "Atten2EquiVarApply": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + head_map = data.pop("head_map") + obj = cls(**data) + obj.head_map = NativeLayer.deserialize(head_map) + return obj + + +class LocalAtten(NativeOP): + def __init__( + self, + input_dim: int, + hidden_dim: int, + head_num: int, + smooth: bool = True, + attnw_shift: float = 20.0, + precision: str = "float64", + ): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.head_num = head_num + self.mapq = NativeLayer( + input_dim, hidden_dim * 1 * head_num, bias=False, precision=precision + ) + self.mapkv = NativeLayer( + input_dim, + (hidden_dim + input_dim) * head_num, + bias=False, + precision=precision, + ) + self.head_map = NativeLayer( + input_dim * head_num, input_dim, precision=precision + ) + self.smooth = smooth + self.attnw_shift = attnw_shift + self.precision = precision + + def call( + self, + g1: np.ndarray, # nf x nloc x ng1 + gg1: np.ndarray, # nf x nloc x nnei x ng1 + nlist_mask: np.ndarray, # nf x nloc x nnei + sw: np.ndarray, # nf x nloc x nnei + ) -> np.ndarray: + nf, nloc, nnei = nlist_mask.shape + ni, nd, nh = self.input_dim, self.hidden_dim, self.head_num + assert ni == g1.shape[-1] + assert ni == gg1.shape[-1] + # nf x nloc x nd x nh + g1q = self.mapq(g1).reshape(nf, nloc, nd, nh) + # nf x nloc x nh x nd + g1q = np.transpose(g1q, (0, 1, 3, 2)) + # nf x nloc x nnei x (nd+ni) x nh + gg1kv = self.mapkv(gg1).reshape(nf, nloc, nnei, nd + ni, nh) + gg1kv = np.transpose(gg1kv, (0, 1, 4, 2, 3)) + # nf x nloc x nh x nnei x nd, nf x nloc x nh x nnei x ng1 + gg1k, gg1v = np.split(gg1kv, [nd], axis=-1) + + # nf x nloc x nh x 1 x nnei + attnw = ( + np.matmul( + np.expand_dims(g1q, axis=-2), np.transpose(gg1k, axes=(0, 1, 2, 4, 3)) + ) + / nd**0.5 + ) + # nf x nloc x nh x nnei + attnw = np.squeeze(attnw, axis=-2) + # mask the attenmap, nf x nloc x 1 x nnei + attnw_mask = ~np.expand_dims(nlist_mask, axis=-2) + # nf x nloc x nh x nnei + if self.smooth: + attnw = (attnw + self.attnw_shift) * np.expand_dims( + sw, axis=-2 + ) - self.attnw_shift + else: + attnw = np.where(attnw_mask, -np.inf, attnw) + attnw = np_softmax(attnw, axis=-1) + attnw = np.where(attnw_mask, 0.0, attnw) + if self.smooth: + attnw = attnw * np.expand_dims(sw, axis=-2) + + # nf x nloc x nh x ng1 + ret = ( + np.matmul(np.expand_dims(attnw, axis=-2), gg1v) + .squeeze(-2) + .reshape(nf, nloc, nh * ni) + ) + # nf x nloc x ng1 + ret = self.head_map(ret) + return ret + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "LocalAtten", + "@version": 1, + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "head_num": self.head_num, + "smooth": self.smooth, + "attnw_shift": self.attnw_shift, + "precision": self.precision, + "mapq": self.mapq.serialize(), + "mapkv": self.mapkv.serialize(), + "head_map": self.head_map.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "LocalAtten": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + mapq = data.pop("mapq") + mapkv = data.pop("mapkv") + head_map = data.pop("head_map") + obj = cls(**data) + obj.mapq = NativeLayer.deserialize(mapq) + obj.mapkv = NativeLayer.deserialize(mapkv) + obj.head_map = NativeLayer.deserialize(head_map) + return obj + + +class RepformerLayer(NativeOP): + def __init__( + self, + rcut, + rcut_smth, + sel: int, + ntypes: int, + g1_dim=128, + g2_dim=16, + axis_neuron: int = 4, + update_chnnl_2: bool = True, + update_g1_has_conv: bool = True, + update_g1_has_drrd: bool = True, + update_g1_has_grrg: bool = True, + update_g1_has_attn: bool = True, + update_g2_has_g1g1: bool = True, + update_g2_has_attn: bool = True, + update_h2: bool = False, + attn1_hidden: int = 64, + attn1_nhead: int = 4, + attn2_hidden: int = 16, + attn2_nhead: int = 4, + attn2_has_gate: bool = False, + activation_function: str = "tanh", + update_style: str = "res_avg", + update_residual: float = 0.001, + update_residual_init: str = "norm", + smooth: bool = True, + precision: str = "float64", + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + ): + super().__init__() + self.epsilon = 1e-4 # protection of 1./nnei + self.rcut = rcut + self.rcut_smth = rcut_smth + self.ntypes = ntypes + sel = [sel] if isinstance(sel, int) else sel + self.nnei = sum(sel) + assert len(sel) == 1 + self.sel = sel + self.sec = self.sel + self.axis_neuron = axis_neuron + self.activation_function = activation_function + self.act = get_activation_fn(self.activation_function) + self.update_g1_has_grrg = update_g1_has_grrg + self.update_g1_has_drrd = update_g1_has_drrd + self.update_g1_has_conv = update_g1_has_conv + self.update_g1_has_attn = update_g1_has_attn + self.update_chnnl_2 = update_chnnl_2 + self.update_g2_has_g1g1 = update_g2_has_g1g1 if self.update_chnnl_2 else False + self.update_g2_has_attn = update_g2_has_attn if self.update_chnnl_2 else False + self.update_h2 = update_h2 if self.update_chnnl_2 else False + del update_g2_has_g1g1, update_g2_has_attn, update_h2 + self.attn1_hidden = attn1_hidden + self.attn1_nhead = attn1_nhead + self.attn2_hidden = attn2_hidden + self.attn2_nhead = attn2_nhead + self.attn2_has_gate = attn2_has_gate + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.smooth = smooth + self.g1_dim = g1_dim + self.g2_dim = g2_dim + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.precision = precision + + assert update_residual_init in [ + "norm", + "const", + ], "'update_residual_init' only support 'norm' or 'const'!" + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.g1_residual = [] + self.g2_residual = [] + self.h2_residual = [] + + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + + g1_in_dim = self.cal_1_dim(g1_dim, g2_dim, self.axis_neuron) + self.linear1 = NativeLayer(g1_in_dim, g1_dim, precision=precision) + self.linear2 = None + self.proj_g1g2 = None + self.proj_g1g1g2 = None + self.attn2g_map = None + self.attn2_mh_apply = None + self.attn2_lm = None + self.attn2_ev_apply = None + self.loc_attn = None + + if self.update_chnnl_2: + self.linear2 = NativeLayer(g2_dim, g2_dim, precision=precision) + if self.update_style == "res_residual": + self.g2_residual.append( + get_residual( + g2_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + if self.update_g1_has_conv: + self.proj_g1g2 = NativeLayer( + g1_dim, g2_dim, bias=False, precision=precision + ) + if self.update_g2_has_g1g1: + self.proj_g1g1g2 = NativeLayer( + g1_dim, g2_dim, bias=False, precision=precision + ) + if self.update_style == "res_residual": + self.g2_residual.append( + get_residual( + g2_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + if self.update_g2_has_attn or self.update_h2: + self.attn2g_map = Atten2Map( + g2_dim, + attn2_hidden, + attn2_nhead, + attn2_has_gate, + self.smooth, + precision=precision, + ) + if self.update_g2_has_attn: + self.attn2_mh_apply = Atten2MultiHeadApply( + g2_dim, attn2_nhead, precision=precision + ) + self.attn2_lm = LayerNorm( + g2_dim, eps=ln_eps, trainable=trainable_ln, precision=precision + ) + if self.update_style == "res_residual": + self.g2_residual.append( + get_residual( + g2_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + + if self.update_h2: + self.attn2_ev_apply = Atten2EquiVarApply( + g2_dim, attn2_nhead, precision=precision + ) + if self.update_style == "res_residual": + self.h2_residual.append( + get_residual( + 1, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + if self.update_g1_has_attn: + self.loc_attn = LocalAtten( + g1_dim, attn1_hidden, attn1_nhead, self.smooth, precision=precision + ) + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + + def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: + ret = g1d + if self.update_g1_has_grrg: + ret += g2d * ax + if self.update_g1_has_drrd: + ret += g1d * ax + if self.update_g1_has_conv: + ret += g2d + return ret + + def _update_h2( + self, + h2: np.ndarray, + attn: np.ndarray, + ) -> np.ndarray: + """ + Calculate the attention weights update for pair-wise equivariant rep. + + Parameters + ---------- + h2 + Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. + attn + Attention weights from g2 attention, with shape nf x nloc x nnei x nnei x nh2. + """ + assert self.attn2_ev_apply is not None + # nf x nloc x nnei x nh2 + h2_1 = self.attn2_ev_apply(attn, h2) + return h2_1 + + def _update_g1_conv( + self, + gg1: np.ndarray, + g2: np.ndarray, + nlist_mask: np.ndarray, + sw: np.ndarray, + ) -> np.ndarray: + """ + Calculate the convolution update for atomic invariant rep. + + Parameters + ---------- + gg1 + Neighbor-wise atomic invariant rep, with shape nf x nloc x nnei x ng1. + g2 + Pair invariant rep, with shape nf x nloc x nnei x ng2. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nf x nloc x nnei. + """ + assert self.proj_g1g2 is not None + nf, nloc, nnei, _ = g2.shape + ng1 = gg1.shape[-1] + ng2 = g2.shape[-1] + # gg1 : nf x nloc x nnei x ng2 + gg1 = self.proj_g1g2(gg1).reshape(nf, nloc, nnei, ng2) + # nf x nloc x nnei x ng2 + gg1 = _apply_nlist_mask(gg1, nlist_mask) + if not self.smooth: + # normalized by number of neighbors, not smooth + # nf x nloc + invnnei = 1.0 / (self.epsilon + np.sum(nlist_mask, axis=-1)) + # nf x nloc x 1 + invnnei = invnnei[:, :, np.newaxis] + else: + gg1 = _apply_switch(gg1, sw) + invnnei = (1.0 / float(nnei)) * np.ones((nf, nloc, 1), dtype=gg1.dtype) + # nf x nloc x ng2 + g1_11 = np.sum(g2 * gg1, axis=2) * invnnei + return g1_11 + + def _update_g2_g1g1( + self, + g1: np.ndarray, # nf x nloc x ng1 + gg1: np.ndarray, # nf x nloc x nnei x ng1 + nlist_mask: np.ndarray, # nf x nloc x nnei + sw: np.ndarray, # nf x nloc x nnei + ) -> np.ndarray: + """ + Update the g2 using element-wise dot g1_i * g1_j. + + Parameters + ---------- + g1 + Atomic invariant rep, with shape nf x nloc x ng1. + gg1 + Neighbor-wise atomic invariant rep, with shape nf x nloc x nnei x ng1. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nf x nloc x nnei. + """ + ret = np.expand_dims(g1, axis=-2) * gg1 + # nf x nloc x nnei x ng1 + ret = _apply_nlist_mask(ret, nlist_mask) + if self.smooth: + ret = _apply_switch(ret, sw) + return ret + + def call( + self, + g1_ext: np.ndarray, # nf x nall x ng1 + g2: np.ndarray, # nf x nloc x nnei x ng2 + h2: np.ndarray, # nf x nloc x nnei x 3 + nlist: np.ndarray, # nf x nloc x nnei + nlist_mask: np.ndarray, # nf x nloc x nnei + sw: np.ndarray, # switch func, nf x nloc x nnei + ): + """ + Parameters + ---------- + g1_ext : nf x nall x ng1 extended single-atom chanel + g2 : nf x nloc x nnei x ng2 pair-atom channel, invariant + h2 : nf x nloc x nnei x 3 pair-atom channel, equivariant + nlist : nf x nloc x nnei neighbor list (padded neis are set to 0) + nlist_mask : nf x nloc x nnei masks of the neighbor list. real nei 1 otherwise 0 + sw : nf x nloc x nnei switch function + + Returns + ------- + g1: nf x nloc x ng1 updated single-atom chanel + g2: nf x nloc x nnei x ng2 updated pair-atom channel, invariant + h2: nf x nloc x nnei x 3 updated pair-atom channel, equivariant + """ + cal_gg1 = ( + self.update_g1_has_drrd + or self.update_g1_has_conv + or self.update_g1_has_attn + or self.update_g2_has_g1g1 + ) + + nf, nloc, nnei, _ = g2.shape + nall = g1_ext.shape[1] + g1, _ = np.split(g1_ext, [nloc], axis=1) + assert (nf, nloc) == g1.shape[:2] + assert (nf, nloc, nnei) == h2.shape[:3] + + g2_update: List[np.ndarray] = [g2] + h2_update: List[np.ndarray] = [h2] + g1_update: List[np.ndarray] = [g1] + g1_mlp: List[np.ndarray] = [g1] + + if cal_gg1: + gg1 = _make_nei_g1(g1_ext, nlist) + else: + gg1 = None + + if self.update_chnnl_2: + # mlp(g2) + assert self.linear2 is not None + # nf x nloc x nnei x ng2 + g2_1 = self.act(self.linear2(g2)) + g2_update.append(g2_1) + + if self.update_g2_has_g1g1: + # linear(g1_i * g1_j) + assert gg1 is not None + assert self.proj_g1g1g2 is not None + g2_update.append( + self.proj_g1g1g2(self._update_g2_g1g1(g1, gg1, nlist_mask, sw)) + ) + + if self.update_g2_has_attn or self.update_h2: + # gated_attention(g2, h2) + assert self.attn2g_map is not None + # nf x nloc x nnei x nnei x nh + AAg = self.attn2g_map(g2, h2, nlist_mask, sw) + + if self.update_g2_has_attn: + assert self.attn2_mh_apply is not None + assert self.attn2_lm is not None + # nf x nloc x nnei x ng2 + g2_2 = self.attn2_mh_apply(AAg, g2) + g2_2 = self.attn2_lm(g2_2) + g2_update.append(g2_2) + + if self.update_h2: + # linear_head(attention_weights * h2) + h2_update.append(self._update_h2(h2, AAg)) + + if self.update_g1_has_conv: + assert gg1 is not None + g1_mlp.append(self._update_g1_conv(gg1, g2, nlist_mask, sw)) + + if self.update_g1_has_grrg: + g1_mlp.append( + symmetrization_op( + g2, + h2, + nlist_mask, + sw, + self.axis_neuron, + smooth=self.smooth, + epsilon=self.epsilon, + ) + ) + + if self.update_g1_has_drrd: + assert gg1 is not None + g1_mlp.append( + symmetrization_op( + gg1, + h2, + nlist_mask, + sw, + self.axis_neuron, + smooth=self.smooth, + epsilon=self.epsilon, + ) + ) + + # nf x nloc x [ng1+ng2+(axisxng2)+(axisxng1)] + # conv grrg drrd + g1_1 = self.act(self.linear1(np.concatenate(g1_mlp, axis=-1))) + g1_update.append(g1_1) + + if self.update_g1_has_attn: + assert gg1 is not None + assert self.loc_attn is not None + g1_update.append(self.loc_attn(g1, gg1, nlist_mask, sw)) + + # update + if self.update_chnnl_2: + g2_new = self.list_update(g2_update, "g2") + h2_new = self.list_update(h2_update, "h2") + else: + g2_new, h2_new = g2, h2 + g1_new = self.list_update(g1_update, "g1") + return g1_new, g2_new, h2_new + + def list_update_res_avg( + self, + update_list: List[np.ndarray], + ) -> np.ndarray: + nitem = len(update_list) + uu = update_list[0] + for ii in range(1, nitem): + uu = uu + update_list[ii] + return uu / (float(nitem) ** 0.5) + + def list_update_res_incr(self, update_list: List[np.ndarray]) -> np.ndarray: + nitem = len(update_list) + uu = update_list[0] + scale = 1.0 / (float(nitem - 1) ** 0.5) if nitem > 1 else 0.0 + for ii in range(1, nitem): + uu = uu + scale * update_list[ii] + return uu + + def list_update_res_residual( + self, update_list: List[np.ndarray], update_name: str = "g1" + ) -> np.ndarray: + nitem = len(update_list) + uu = update_list[0] + if update_name == "g1": + for ii, vv in enumerate(self.g1_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "g2": + for ii, vv in enumerate(self.g2_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "h2": + for ii, vv in enumerate(self.h2_residual): + uu = uu + vv * update_list[ii + 1] + else: + raise NotImplementedError + return uu + + def list_update( + self, update_list: List[np.ndarray], update_name: str = "g1" + ) -> np.ndarray: + if self.update_style == "res_avg": + return self.list_update_res_avg(update_list) + elif self.update_style == "res_incr": + return self.list_update_res_incr(update_list) + elif self.update_style == "res_residual": + return self.list_update_res_residual(update_list, update_name=update_name) + else: + raise RuntimeError(f"unknown update style {self.update_style}") + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + data = { + "@class": "RepformerLayer", + "@version": 1, + "rcut": self.rcut, + "rcut_smth": self.rcut_smth, + "sel": self.sel, + "ntypes": self.ntypes, + "g1_dim": self.g1_dim, + "g2_dim": self.g2_dim, + "axis_neuron": self.axis_neuron, + "update_chnnl_2": self.update_chnnl_2, + "update_g1_has_conv": self.update_g1_has_conv, + "update_g1_has_drrd": self.update_g1_has_drrd, + "update_g1_has_grrg": self.update_g1_has_grrg, + "update_g1_has_attn": self.update_g1_has_attn, + "update_g2_has_g1g1": self.update_g2_has_g1g1, + "update_g2_has_attn": self.update_g2_has_attn, + "update_h2": self.update_h2, + "attn1_hidden": self.attn1_hidden, + "attn1_nhead": self.attn1_nhead, + "attn2_hidden": self.attn2_hidden, + "attn2_nhead": self.attn2_nhead, + "attn2_has_gate": self.attn2_has_gate, + "activation_function": self.activation_function, + "update_style": self.update_style, + "smooth": self.smooth, + "precision": self.precision, + "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, + "linear1": self.linear1.serialize(), + } + if self.update_chnnl_2: + data.update( + { + "linear2": self.linear2.serialize(), + } + ) + if self.update_g1_has_conv: + data.update( + { + "proj_g1g2": self.proj_g1g2.serialize(), + } + ) + if self.update_g2_has_g1g1: + data.update( + { + "proj_g1g1g2": self.proj_g1g1g2.serialize(), + } + ) + if self.update_g2_has_attn or self.update_h2: + data.update( + { + "attn2g_map": self.attn2g_map.serialize(), + } + ) + if self.update_g2_has_attn: + data.update( + { + "attn2_mh_apply": self.attn2_mh_apply.serialize(), + "attn2_lm": self.attn2_lm.serialize(), + } + ) + + if self.update_h2: + data.update( + { + "attn2_ev_apply": self.attn2_ev_apply.serialize(), + } + ) + if self.update_g1_has_attn: + data.update( + { + "loc_attn": self.loc_attn.serialize(), + } + ) + if self.update_style == "res_residual": + data.update( + { + "g1_residual": self.g1_residual, + "g2_residual": self.g2_residual, + "h2_residual": self.h2_residual, + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "RepformerLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + linear1 = data.pop("linear1") + update_chnnl_2 = data["update_chnnl_2"] + update_g1_has_conv = data["update_g1_has_conv"] + update_g2_has_g1g1 = data["update_g2_has_g1g1"] + update_g2_has_attn = data["update_g2_has_attn"] + update_h2 = data["update_h2"] + update_g1_has_attn = data["update_g1_has_attn"] + update_style = data["update_style"] + + linear2 = data.pop("linear2", None) + proj_g1g2 = data.pop("proj_g1g2", None) + proj_g1g1g2 = data.pop("proj_g1g1g2", None) + attn2g_map = data.pop("attn2g_map", None) + attn2_mh_apply = data.pop("attn2_mh_apply", None) + attn2_lm = data.pop("attn2_lm", None) + attn2_ev_apply = data.pop("attn2_ev_apply", None) + loc_attn = data.pop("loc_attn", None) + g1_residual = data.pop("g1_residual", []) + g2_residual = data.pop("g2_residual", []) + h2_residual = data.pop("h2_residual", []) + + obj = cls(**data) + obj.linear1 = NativeLayer.deserialize(linear1) + if update_chnnl_2: + assert isinstance(linear2, dict) + obj.linear2 = NativeLayer.deserialize(linear2) + if update_g1_has_conv: + assert isinstance(proj_g1g2, dict) + obj.proj_g1g2 = NativeLayer.deserialize(proj_g1g2) + if update_g2_has_g1g1: + assert isinstance(proj_g1g1g2, dict) + obj.proj_g1g1g2 = NativeLayer.deserialize(proj_g1g1g2) + if update_g2_has_attn or update_h2: + assert isinstance(attn2g_map, dict) + obj.attn2g_map = Atten2Map.deserialize(attn2g_map) + if update_g2_has_attn: + assert isinstance(attn2_mh_apply, dict) + assert isinstance(attn2_lm, dict) + obj.attn2_mh_apply = Atten2MultiHeadApply.deserialize(attn2_mh_apply) + obj.attn2_lm = LayerNorm.deserialize(attn2_lm) + if update_h2: + assert isinstance(attn2_ev_apply, dict) + obj.attn2_ev_apply = Atten2EquiVarApply.deserialize(attn2_ev_apply) + if update_g1_has_attn: + assert isinstance(loc_attn, dict) + obj.loc_attn = LocalAtten.deserialize(loc_attn) + if update_style == "res_residual": + obj.g1_residual = g1_residual + obj.g2_residual = g2_residual + obj.h2_residual = h2_residual + return obj diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 8d926034dd..c50fdb4cb3 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -321,7 +321,9 @@ def call( """ del mapping # nf x nloc x nnei x 4 - rr, ww = self.env_mat.call(coord_ext, atype_ext, nlist, self.davg, self.dstd) + rr, diff, ww = self.env_mat.call( + coord_ext, atype_ext, nlist, self.davg, self.dstd + ) nf, nloc, nnei, _ = rr.shape sec = np.append([0], np.cumsum(self.sel)) diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index 9c9b4e096e..6b50c3ba68 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -270,7 +270,7 @@ def call( """ del mapping # nf x nloc x nnei x 1 - rr, ww = self.env_mat.call( + rr, diff, ww = self.env_mat.call( coord_ext, atype_ext, nlist, self.davg, self.dstd, True ) nf, nloc, nnei, _ = rr.shape diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index 0c2ca43c40..94cf3a7c21 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -109,17 +109,19 @@ def call( ------- env_mat The environment matrix. shape: nf x nloc x nnei x (4 or 1) + diff + The relative coordinate of neighbors. shape: nf x nloc x nnei x 3 switch The value of switch function. shape: nf x nloc x nnei """ - em, sw = self._call(nlist, coord_ext, radial_only) + em, diff, sw = self._call(nlist, coord_ext, radial_only) nf, nloc, nnei = nlist.shape atype = atype_ext[:, :nloc] if davg is not None: em -= davg[atype] if dstd is not None: em /= dstd[atype] - return em, sw + return em, diff, sw def _call(self, nlist, coord_ext, radial_only): em, diff, ww = _make_env_mat( @@ -130,7 +132,7 @@ def _call(self, nlist, coord_ext, radial_only): radial_only=radial_only, protection=self.protection, ) - return em, ww + return em, diff, ww def serialize( self, diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 319f8a0dbd..ae557326ff 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -33,6 +33,25 @@ ) +class Identity(NativeOP): + def __init__(self): + super().__init__() + + def call(self, x: np.ndarray) -> np.ndarray: + """The Identity operation layer.""" + return x + + def serialize(self) -> dict: + return { + "@class": "Identity", + "@version": 1, + } + + @classmethod + def deserialize(cls, data: dict) -> "Identity": + return Identity() + + class NativeLayer(NativeOP): """Native representation of a layer. diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 5aae848aa4..d586fc988f 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -80,7 +80,7 @@ def get_dim_out(self) -> int: @abstractmethod def get_dim_in(self) -> int: - """Returns the output dimension.""" + """Returns the input dimension.""" pass @abstractmethod diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 0b28e388d5..e955fab048 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -273,6 +273,7 @@ def __init__( self.type_embedding = TypeEmbedNet(ntypes, tebd_dim, precision=precision) self.tebd_dim = tebd_dim self.concat_output_tebd = concat_output_tebd + self.trainable = trainable # set trainable for param in self.parameters(): param.requires_grad = trainable @@ -418,8 +419,7 @@ def serialize(self) -> dict: "davg": obj["davg"].detach().cpu().numpy(), "dstd": obj["dstd"].detach().cpu().numpy(), }, - ## to be updated when the options are supported. - "trainable": True, + "trainable": self.trainable, "spin": None, } if obj.tebd_input_mode in ["strip"]: diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 2bf4d193f3..28ff1b6848 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -10,10 +10,18 @@ import torch -from deepmd.pt.model.network.network import ( +from deepmd.dpmodel.utils import EnvMat as DPEnvMat +from deepmd.pt.model.network.mlp import ( Identity, - Linear, + MLPLayer, + NetworkCollection, +) +from deepmd.pt.model.network.network import ( TypeEmbedNet, + TypeEmbedNetConsistent, +) +from deepmd.pt.utils import ( + env, ) from deepmd.pt.utils.nlist import ( build_multiple_neighbor_list, @@ -22,13 +30,22 @@ from deepmd.pt.utils.update_sel import ( UpdateSel, ) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.version import ( + check_version_compatibility, +) from .base_descriptor import ( BaseDescriptor, ) +from .repformer_layer import ( + RepformerLayer, +) from .repformers import ( DescrptBlockRepformers, ) @@ -38,9 +55,10 @@ @BaseDescriptor.register("dpa2") -class DescrptDPA2(torch.nn.Module, BaseDescriptor): +class DescrptDPA2(BaseDescriptor, torch.nn.Module): def __init__( self, + # args for repinit ntypes: int, repinit_rcut: float, repinit_rcut_smth: float, @@ -48,21 +66,21 @@ def __init__( repformer_rcut: float, repformer_rcut_smth: float, repformer_nsel: int, - # kwargs - tebd_dim: int = 8, - concat_output_tebd: bool = True, + # kwargs for repinit repinit_neuron: List[int] = [25, 50, 100], repinit_axis_neuron: int = 16, - repinit_set_davg_zero: bool = True, # TODO - repinit_activation="tanh", - # repinit still unclear: - # ffn, ffn_embed_dim, scaling_factor, normalize, + repinit_tebd_dim: int = 8, + repinit_tebd_input_mode: str = "concat", + repinit_set_davg_zero: bool = True, + repinit_activation_function="tanh", + repinit_resnet_dt: bool = False, + repinit_type_one_side: bool = False, + # kwargs for repformer repformer_nlayers: int = 3, repformer_g1_dim: int = 128, repformer_g2_dim: int = 16, - repformer_axis_dim: int = 4, - repformer_do_bn_mode: str = "no", - repformer_bn_momentum: float = 0.1, + repformer_axis_neuron: int = 4, + repformer_direct_dist: bool = False, repformer_update_g1_has_conv: bool = True, repformer_update_g1_has_drrd: bool = True, repformer_update_g1_has_grrg: bool = True, @@ -75,113 +93,168 @@ def __init__( repformer_attn2_hidden: int = 16, repformer_attn2_nhead: int = 4, repformer_attn2_has_gate: bool = False, - repformer_activation: str = "tanh", + repformer_activation_function: str = "tanh", repformer_update_style: str = "res_avg", - repformer_set_davg_zero: bool = True, # TODO - repformer_add_type_ebd_to_seq: bool = False, + repformer_update_residual: float = 0.001, + repformer_update_residual_init: str = "norm", + repformer_set_davg_zero: bool = True, + repformer_trainable_ln: bool = True, + repformer_ln_eps: Optional[float] = 1e-5, + # kwargs for descriptor + concat_output_tebd: bool = True, + precision: str = "float64", + smooth: bool = True, + exclude_types: List[Tuple[int, int]] = [], env_protection: float = 0.0, trainable: bool = True, - exclude_types: List[Tuple[int, int]] = [], - type: Optional[ - str - ] = None, # work around the bad design in get_trainer and DpLoaderSet! - rcut: Optional[ - float - ] = None, # work around the bad design in get_trainer and DpLoaderSet! - rcut_smth: Optional[ - float - ] = None, # work around the bad design in get_trainer and DpLoaderSet! - sel: Optional[ - int - ] = None, # work around the bad design in get_trainer and DpLoaderSet! + seed: Optional[int] = None, + add_tebd_to_repinit_out: bool = False, + old_impl: bool = False, ): r"""The DPA-2 descriptor. see https://arxiv.org/abs/2312.15492. Parameters ---------- - ntypes : int - Number of atom types repinit_rcut : float - The cut-off radius of the repinit block + (Used in the repinit block.) + The cut-off radius. repinit_rcut_smth : float - From this position the inverse distance smoothly decays - to 0 at the cut-off. Use in the repinit block. + (Used in the repinit block.) + Where to start smoothing. For example the 1/r term is smoothed from rcut to rcut_smth. repinit_nsel : int - Maximally possible number of neighbors for repinit block. + (Used in the repinit block.) + Maximally possible number of selected neighbors. + repinit_neuron : list, optional + (Used in the repinit block.) + Number of neurons in each hidden layers of the embedding net. + When two layers are of the same size or one layer is twice as large as the previous layer, + a skip connection is built. + repinit_axis_neuron : int, optional + (Used in the repinit block.) + Size of the submatrix of G (embedding matrix). + repinit_tebd_dim : int, optional + (Used in the repinit block.) + The dimension of atom type embedding. + repinit_tebd_input_mode : str, optional + (Used in the repinit block.) + The input mode of the type embedding. Supported modes are ['concat', 'strip']. + repinit_set_davg_zero : bool, optional + (Used in the repinit block.) + Set the normalization average to zero. + repinit_activation_function : str, optional + (Used in the repinit block.) + The activation function in the embedding net. + repinit_resnet_dt : bool, optional + (Used in the repinit block.) + Whether to use a "Timestep" in the skip connection. + repinit_type_one_side : bool, optional + (Used in the repinit block.) + Whether to use one-side type embedding. repformer_rcut : float - The cut-off radius of the repformer block + (Used in the repformer block.) + The cut-off radius. repformer_rcut_smth : float - From this position the inverse distance smoothly decays - to 0 at the cut-off. Use in the repformer block. + (Used in the repformer block.) + Where to start smoothing. For example the 1/r term is smoothed from rcut to rcut_smth. repformer_nsel : int - Maximally possible number of neighbors for repformer block. - tebd_dim : int - The dimension of atom type embedding - concat_output_tebd : bool + (Used in the repformer block.) + Maximally possible number of selected neighbors. + repformer_nlayers : int, optional + (Used in the repformer block.) + Number of repformer layers. + repformer_g1_dim : int, optional + (Used in the repformer block.) + Dimension of the first graph convolution layer. + repformer_g2_dim : int, optional + (Used in the repformer block.) + Dimension of the second graph convolution layer. + repformer_axis_neuron : int, optional + (Used in the repformer block.) + Size of the submatrix of G (embedding matrix). + repformer_direct_dist : bool, optional + (Used in the repformer block.) + Whether to use direct distance information (1/r term) in the repformer block. + repformer_update_g1_has_conv : bool, optional + (Used in the repformer block.) + Whether to update the g1 rep with convolution term. + repformer_update_g1_has_drrd : bool, optional + (Used in the repformer block.) + Whether to update the g1 rep with the drrd term. + repformer_update_g1_has_grrg : bool, optional + (Used in the repformer block.) + Whether to update the g1 rep with the grrg term. + repformer_update_g1_has_attn : bool, optional + (Used in the repformer block.) + Whether to update the g1 rep with the localized self-attention. + repformer_update_g2_has_g1g1 : bool, optional + (Used in the repformer block.) + Whether to update the g2 rep with the g1xg1 term. + repformer_update_g2_has_attn : bool, optional + (Used in the repformer block.) + Whether to update the g2 rep with the gated self-attention. + repformer_update_h2 : bool, optional + (Used in the repformer block.) + Whether to update the h2 rep. + repformer_attn1_hidden : int, optional + (Used in the repformer block.) + The hidden dimension of localized self-attention to update the g1 rep. + repformer_attn1_nhead : int, optional + (Used in the repformer block.) + The number of heads in localized self-attention to update the g1 rep. + repformer_attn2_hidden : int, optional + (Used in the repformer block.) + The hidden dimension of gated self-attention to update the g2 rep. + repformer_attn2_nhead : int, optional + (Used in the repformer block.) + The number of heads in gated self-attention to update the g2 rep. + repformer_attn2_has_gate : bool, optional + (Used in the repformer block.) + Whether to use gate in the gated self-attention to update the g2 rep. + repformer_activation_function : str, optional + (Used in the repformer block.) + The activation function in the embedding net. + repformer_update_style : str, optional + (Used in the repformer block.) + Style to update a representation. + Supported options are: + -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) + -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) + where `r1`, `r2` ... `r3` are residual weights defined by `repformer_update_residual` + and `repformer_update_residual_init`. + repformer_update_residual : float, optional + (Used in the repformer block.) + When update using residual mode, the initial std of residual vector weights. + repformer_update_residual_init : str, optional + (Used in the repformer block.) + When update using residual mode, the initialization mode of residual vector weights. + repformer_set_davg_zero : bool, optional + (Used in the repformer block.) + Set the normalization average to zero. + repformer_trainable_ln : bool, optional + (Used in the repformer block.) + Whether to use trainable shift and scale weights in layer normalization. + repformer_ln_eps : float, optional + (Used in the repformer block.) + The epsilon value for layer normalization. + concat_output_tebd : bool, optional Whether to concat type embedding at the output of the descriptor. - repinit_neuron : List[int] - repinit block: the number of neurons in the embedding net. - repinit_axis_neuron : int - repinit block: the number of dimension of split in the - symmetrization op. - repinit_activation : str - repinit block: the activation function in the embedding net - repformer_nlayers : int - repformers block: the number of repformer layers - repformer_g1_dim : int - repformers block: the dimension of single-atom rep - repformer_g2_dim : int - repformers block: the dimension of invariant pair-atom rep - repformer_axis_dim : int - repformers block: the number of dimension of split in the - symmetrization ops. - repformer_do_bn_mode : bool - repformers block: do batch norm in the repformer layers - repformer_bn_momentum : float - repformers block: moment in the batch normalization - repformer_update_g1_has_conv : bool - repformers block: update the g1 rep with convolution term - repformer_update_g1_has_drrd : bool - repformers block: update the g1 rep with the drrd term - repformer_update_g1_has_grrg : bool - repformers block: update the g1 rep with the grrg term - repformer_update_g1_has_attn : bool - repformers block: update the g1 rep with the localized - self-attention - repformer_update_g2_has_g1g1 : bool - repformers block: update the g2 rep with the g1xg1 term - repformer_update_g2_has_attn : bool - repformers block: update the g2 rep with the gated self-attention - repformer_update_h2 : bool - repformers block: update the h2 rep - repformer_attn1_hidden : int - repformers block: the hidden dimension of localized self-attention - repformer_attn1_nhead : int - repformers block: the number of heads in localized self-attention - repformer_attn2_hidden : int - repformers block: the hidden dimension of gated self-attention - repformer_attn2_nhead : int - repformers block: the number of heads in gated self-attention - repformer_attn2_has_gate : bool - repformers block: has gate in the gated self-attention - repformer_activation : str - repformers block: the activation function in the MLPs. - repformer_update_style : str - repformers block: style of update a rep. - can be res_avg or res_incr. - res_avg updates a rep `u` with: - u = 1/\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) - res_incr updates a rep `u` with: - u = u + 1/\sqrt{n} (u_1 + u_2 + ... + u_n) - repformer_set_davg_zero : bool - repformers block: set the avg to zero in statistics - repformer_add_type_ebd_to_seq : bool - repformers block: concatenate the type embedding at the output. - trainable : bool - If the parameters in the descriptor are trainable. - exclude_types : List[Tuple[int, int]] = [], + precision : str, optional + The precision of the embedding net parameters. + smooth : bool, optional + Whether to use smoothness in processes such as attention weights calculation. + exclude_types : List[List[int]], optional The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + trainable : bool, optional + If the parameters are trainable. + seed : int, optional + (Unused yet) Random seed for parameter initialization. + add_tebd_to_repinit_out : bool, optional + Whether to add type embedding to the output representation from repinit before inputting it into repformer. Returns ------- @@ -199,7 +272,9 @@ def __init__( """ super().__init__() - del type, rcut, rcut_smth, sel + # to keep consistent with default value in this backends + if repformer_ln_eps is None: + repformer_ln_eps = 1e-5 self.repinit = DescrptBlockSeAtten( repinit_rcut, repinit_rcut_smth, @@ -208,13 +283,16 @@ def __init__( attn_layer=0, neuron=repinit_neuron, axis_neuron=repinit_axis_neuron, - tebd_dim=tebd_dim, - tebd_input_mode="concat", - # tebd_input_mode='dot_residual_s', + tebd_dim=repinit_tebd_dim, + tebd_input_mode=repinit_tebd_input_mode, set_davg_zero=repinit_set_davg_zero, exclude_types=exclude_types, env_protection=env_protection, - activation_function=repinit_activation, + activation_function=repinit_activation_function, + precision=precision, + resnet_dt=repinit_resnet_dt, + smooth=smooth, + type_one_side=repinit_type_one_side, ) self.repformers = DescrptBlockRepformers( repformer_rcut, @@ -224,10 +302,8 @@ def __init__( nlayers=repformer_nlayers, g1_dim=repformer_g1_dim, g2_dim=repformer_g2_dim, - axis_dim=repformer_axis_dim, - direct_dist=False, - do_bn_mode=repformer_do_bn_mode, - bn_momentum=repformer_bn_momentum, + axis_neuron=repformer_axis_neuron, + direct_dist=repformer_direct_dist, update_g1_has_conv=repformer_update_g1_has_conv, update_g1_has_drrd=repformer_update_g1_has_drrd, update_g1_has_grrg=repformer_update_g1_has_grrg, @@ -240,28 +316,52 @@ def __init__( attn2_hidden=repformer_attn2_hidden, attn2_nhead=repformer_attn2_nhead, attn2_has_gate=repformer_attn2_has_gate, - activation_function=repformer_activation, + activation_function=repformer_activation_function, update_style=repformer_update_style, + update_residual=repformer_update_residual, + update_residual_init=repformer_update_residual_init, set_davg_zero=repformer_set_davg_zero, - smooth=True, - add_type_ebd_to_seq=repformer_add_type_ebd_to_seq, + smooth=smooth, exclude_types=exclude_types, env_protection=env_protection, + precision=precision, + trainable_ln=repformer_trainable_ln, + ln_eps=repformer_ln_eps, + old_impl=old_impl, + ) + self.type_embedding = TypeEmbedNet( + ntypes, repinit_tebd_dim, precision=precision ) - self.type_embedding = TypeEmbedNet(ntypes, tebd_dim) + self.concat_output_tebd = concat_output_tebd + self.precision = precision + self.smooth = smooth + self.exclude_types = exclude_types + self.env_protection = env_protection + self.trainable = trainable + self.add_tebd_to_repinit_out = add_tebd_to_repinit_out + if self.repinit.dim_out == self.repformers.dim_in: self.g1_shape_tranform = Identity() else: - self.g1_shape_tranform = Linear( + self.g1_shape_tranform = MLPLayer( self.repinit.dim_out, self.repformers.dim_in, bias=False, + precision=precision, init="glorot", ) + self.tebd_transform = None + if self.add_tebd_to_repinit_out: + self.tebd_transform = MLPLayer( + repinit_tebd_dim, + self.repformers.dim_in, + bias=False, + precision=precision, + ) assert self.repinit.rcut > self.repformers.rcut assert self.repinit.sel[0] > self.repformers.sel[0] - self.concat_output_tebd = concat_output_tebd - self.tebd_dim = tebd_dim + + self.tebd_dim = repinit_tebd_dim self.rcut = self.repinit.get_rcut() self.ntypes = ntypes self.sel = self.repinit.sel @@ -382,13 +482,146 @@ def compute_input_stats( descrpt.compute_input_stats(merged, path) def serialize(self) -> dict: - """Serialize the obj to dict.""" - raise NotImplementedError + repinit = self.repinit + repformers = self.repformers + data = { + "@class": "Descriptor", + "type": "dpa2", + "@version": 1, + "ntypes": self.ntypes, + "repinit_rcut": repinit.rcut, + "repinit_rcut_smth": repinit.rcut_smth, + "repinit_nsel": repinit.sel, + "repformer_rcut": repformers.rcut, + "repformer_rcut_smth": repformers.rcut_smth, + "repformer_nsel": repformers.sel, + "repinit_neuron": repinit.neuron, + "repinit_axis_neuron": repinit.axis_neuron, + "repinit_tebd_dim": repinit.tebd_dim, + "repinit_tebd_input_mode": repinit.tebd_input_mode, + "repinit_set_davg_zero": repinit.set_davg_zero, + "repinit_activation_function": repinit.activation_function, + "repinit_resnet_dt": repinit.resnet_dt, + "repinit_type_one_side": repinit.type_one_side, + "repformer_nlayers": repformers.nlayers, + "repformer_g1_dim": repformers.g1_dim, + "repformer_g2_dim": repformers.g2_dim, + "repformer_axis_neuron": repformers.axis_neuron, + "repformer_direct_dist": repformers.direct_dist, + "repformer_update_g1_has_conv": repformers.update_g1_has_conv, + "repformer_update_g1_has_drrd": repformers.update_g1_has_drrd, + "repformer_update_g1_has_grrg": repformers.update_g1_has_grrg, + "repformer_update_g1_has_attn": repformers.update_g1_has_attn, + "repformer_update_g2_has_g1g1": repformers.update_g2_has_g1g1, + "repformer_update_g2_has_attn": repformers.update_g2_has_attn, + "repformer_update_h2": repformers.update_h2, + "repformer_attn1_hidden": repformers.attn1_hidden, + "repformer_attn1_nhead": repformers.attn1_nhead, + "repformer_attn2_hidden": repformers.attn2_hidden, + "repformer_attn2_nhead": repformers.attn2_nhead, + "repformer_attn2_has_gate": repformers.attn2_has_gate, + "repformer_activation_function": repformers.activation_function, + "repformer_update_style": repformers.update_style, + "repformer_set_davg_zero": repformers.set_davg_zero, + "repformer_trainable_ln": repformers.trainable_ln, + "repformer_ln_eps": repformers.ln_eps, + "concat_output_tebd": self.concat_output_tebd, + "precision": self.precision, + "smooth": self.smooth, + "exclude_types": self.exclude_types, + "env_protection": self.env_protection, + "trainable": self.trainable, + "add_tebd_to_repinit_out": self.add_tebd_to_repinit_out, + "type_embedding": self.type_embedding.embedding.serialize(), + "g1_shape_tranform": self.g1_shape_tranform.serialize(), + } + if self.add_tebd_to_repinit_out: + data.update( + { + "tebd_transform": self.tebd_transform.serialize(), + } + ) + repinit_variable = { + "embeddings": repinit.filter_layers.serialize(), + "env_mat": DPEnvMat(repinit.rcut, repinit.rcut_smth).serialize(), + "@variables": { + "davg": to_numpy_array(repinit["davg"]), + "dstd": to_numpy_array(repinit["dstd"]), + }, + } + if repinit.tebd_input_mode in ["strip"]: + repinit_variable.update( + {"embeddings_strip": repinit.filter_layers_strip.serialize()} + ) + repformers_variable = { + "g2_embd": repformers.g2_embd.serialize(), + "repformer_layers": [layer.serialize() for layer in repformers.layers], + "env_mat": DPEnvMat(repformers.rcut, repformers.rcut_smth).serialize(), + "@variables": { + "davg": to_numpy_array(repformers["davg"]), + "dstd": to_numpy_array(repformers["dstd"]), + }, + } + data.update( + { + "repinit": repinit_variable, + "repformers": repformers_variable, + } + ) + return data @classmethod - def deserialize(cls) -> "DescrptDPA2": - """Deserialize from a dict.""" - raise NotImplementedError + def deserialize(cls, data: dict) -> "DescrptDPA2": + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + repinit_variable = data.pop("repinit").copy() + repformers_variable = data.pop("repformers").copy() + type_embedding = data.pop("type_embedding") + g1_shape_tranform = data.pop("g1_shape_tranform") + tebd_transform = data.pop("tebd_transform", None) + add_tebd_to_repinit_out = data["add_tebd_to_repinit_out"] + obj = cls(**data) + obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize( + type_embedding + ) + if add_tebd_to_repinit_out: + assert isinstance(tebd_transform, dict) + obj.tebd_transform = MLPLayer.deserialize(tebd_transform) + if obj.repinit.dim_out != obj.repformers.dim_in: + obj.g1_shape_tranform = MLPLayer.deserialize(g1_shape_tranform) + + def t_cvt(xx): + return torch.tensor(xx, dtype=obj.repinit.prec, device=env.DEVICE) + + # deserialize repinit + statistic_repinit = repinit_variable.pop("@variables") + env_mat = repinit_variable.pop("env_mat") + tebd_input_mode = data["repinit_tebd_input_mode"] + obj.repinit.filter_layers = NetworkCollection.deserialize( + repinit_variable.pop("embeddings") + ) + if tebd_input_mode in ["strip"]: + obj.repinit.filter_layers_strip = NetworkCollection.deserialize( + repinit_variable.pop("embeddings_strip") + ) + obj.repinit["davg"] = t_cvt(statistic_repinit["davg"]) + obj.repinit["dstd"] = t_cvt(statistic_repinit["dstd"]) + + # deserialize repformers + statistic_repformers = repformers_variable.pop("@variables") + env_mat = repformers_variable.pop("env_mat") + repformer_layers = repformers_variable.pop("repformer_layers") + obj.repformers.g2_embd = MLPLayer.deserialize( + repformers_variable.pop("g2_embd") + ) + obj.repformers["davg"] = t_cvt(statistic_repformers["davg"]) + obj.repformers["dstd"] = t_cvt(statistic_repformers["dstd"]) + obj.repformers.layers = torch.nn.ModuleList( + [RepformerLayer.deserialize(layer) for layer in repformer_layers] + ) + return obj def forward( self, @@ -402,9 +635,9 @@ def forward( Parameters ---------- - coord_ext + extended_coord The extended coordinates of atoms. shape: nf x (nallx3) - atype_ext + extended_atype The extended aotm types. shape: nf x nall nlist The neighbor list. shape: nf x nloc x nnei @@ -453,6 +686,9 @@ def forward( ) # linear to change shape g1 = self.g1_shape_tranform(g1) + if self.add_tebd_to_repinit_out: + assert self.tebd_transform is not None + g1 = g1 + self.tebd_transform(g1_inp) # mapping g1 if comm_dict is None: assert mapping is not None diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index a58d6b0e2c..8397b4b421 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -1,32 +1,93 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( - Callable, List, + Optional, ) import torch +import torch.nn as nn -from deepmd.pt.model.network.network import ( - SimpleLinear, +from deepmd.pt.model.network.layernorm import ( + LayerNorm, +) +from deepmd.pt.model.network.mlp import ( + MLPLayer, ) from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) from deepmd.pt.utils.utils import ( ActivationFn, + to_numpy_array, + to_torch_tensor, +) +from deepmd.utils.version import ( + check_version_compatibility, ) -def torch_linear(*args, **kwargs): - return torch.nn.Linear( - *args, **kwargs, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE +def get_residual( + _dim: int, + _scale: float, + _mode: str = "norm", + trainable: bool = True, + precision: str = "float64", +) -> torch.Tensor: + r""" + Get residual tensor for one update vector. + + Parameters + ---------- + _dim : int + The dimension of the update vector. + _scale + The initial scale of the residual tensor. See `_mode` for details. + _mode + The mode of residual initialization for the residual tensor. + - "norm" (default): init residual using normal with `_scale` std. + - "const": init residual using element-wise constants of `_scale`. + trainable + Whether the residual tensor is trainable. + precision + The precision of the residual tensor. + """ + residual = nn.Parameter( + data=torch.zeros(_dim, dtype=PRECISION_DICT[precision], device=env.DEVICE), + requires_grad=trainable, ) + if _mode == "norm": + nn.init.normal_(residual.data, std=_scale) + elif _mode == "const": + nn.init.constant_(residual.data, val=_scale) + else: + raise RuntimeError(f"Unsupported initialization mode '{_mode}'!") + return residual +# common ops def _make_nei_g1( g1_ext: torch.Tensor, nlist: torch.Tensor, ) -> torch.Tensor: + """ + Make neighbor-wise atomic invariant rep. + + Parameters + ---------- + g1_ext + Extended atomic invariant rep, with shape nb x nall x ng1. + nlist + Neighbor list, with shape nb x nloc x nnei. + + Returns + ------- + gg1: torch.Tensor + Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1. + + """ # nlist: nb x nloc x nnei nb, nloc, nnei = nlist.shape # g1_ext: nb x nall x ng1 @@ -44,51 +105,61 @@ def _apply_nlist_mask( gg: torch.Tensor, nlist_mask: torch.Tensor, ) -> torch.Tensor: - # gg: nf x nloc x nnei x ng + """ + Apply nlist mask to neighbor-wise rep tensors. + + Parameters + ---------- + gg + Neighbor-wise rep tensors, with shape nf x nloc x nnei x d. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + """ + # gg: nf x nloc x nnei x d # msk: nf x nloc x nnei return gg.masked_fill(~nlist_mask.unsqueeze(-1), 0.0) def _apply_switch(gg: torch.Tensor, sw: torch.Tensor) -> torch.Tensor: - # gg: nf x nloc x nnei x ng - # sw: nf x nloc x nnei - return gg * sw.unsqueeze(-1) - + """ + Apply switch function to neighbor-wise rep tensors. -def _apply_h_norm( - hh: torch.Tensor, # nf x nloc x nnei x 3 -) -> torch.Tensor: - """Normalize h by the std of vector length. - do not have an idea if this is a good way. + Parameters + ---------- + gg + Neighbor-wise rep tensors, with shape nf x nloc x nnei x d. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nf x nloc x nnei. """ - nf, nl, nnei, _ = hh.shape - # nf x nloc x nnei - normh = torch.linalg.norm(hh, dim=-1) - # nf x nloc - std = torch.std(normh, dim=-1) - # nf x nloc x nnei x 3 - hh = hh[:, :, :, :] / (1.0 + std[:, :, None, None]) - return hh + # gg: nf x nloc x nnei x d + # sw: nf x nloc x nnei + return gg * sw.unsqueeze(-1) class Atten2Map(torch.nn.Module): def __init__( self, - ni: int, - nd: int, - nh: int, + input_dim: int, + hidden_dim: int, + head_num: int, has_gate: bool = False, # apply gate to attn map smooth: bool = True, attnw_shift: float = 20.0, + precision: str = "float64", ): + """Return neighbor-wise multi-head self-attention maps, with gate mechanism.""" super().__init__() - self.ni = ni - self.nd = nd - self.nh = nh - self.mapqk = SimpleLinear(ni, nd * 2 * nh, bias=False) + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.head_num = head_num + self.mapqk = MLPLayer( + input_dim, hidden_dim * 2 * head_num, bias=False, precision=precision + ) self.has_gate = has_gate self.smooth = smooth self.attnw_shift = attnw_shift + self.precision = precision def forward( self, @@ -103,7 +174,7 @@ def forward( nnei, _, ) = g2.shape - nd, nh = self.nd, self.nh + nd, nh = self.hidden_dim, self.head_num # nb x nloc x nnei x nd x (nh x 2) g2qk = self.mapqk(g2).view(nb, nloc, nnei, nd, nh * 2) # nb x nloc x (nh x 2) x nnei x nd @@ -151,18 +222,60 @@ def forward( ret = torch.permute(ret, (0, 1, 3, 4, 2)) return ret + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "Atten2Map", + "@version": 1, + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "head_num": self.head_num, + "has_gate": self.has_gate, + "smooth": self.smooth, + "attnw_shift": self.attnw_shift, + "precision": self.precision, + "mapqk": self.mapqk.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "Atten2Map": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + mapqk = data.pop("mapqk") + obj = cls(**data) + obj.mapqk = MLPLayer.deserialize(mapqk) + return obj + class Atten2MultiHeadApply(torch.nn.Module): def __init__( self, - ni: int, - nh: int, + input_dim: int, + head_num: int, + precision: str = "float64", ): super().__init__() - self.ni = ni - self.nh = nh - self.mapv = SimpleLinear(ni, ni * nh, bias=False) - self.head_map = SimpleLinear(ni * nh, ni) + self.input_dim = input_dim + self.head_num = head_num + self.mapv = MLPLayer( + input_dim, input_dim * head_num, bias=False, precision=precision + ) + self.head_map = MLPLayer(input_dim * head_num, input_dim, precision=precision) + self.precision = precision def forward( self, @@ -170,7 +283,7 @@ def forward( g2: torch.Tensor, # nf x nloc x nnei x ng2 ) -> torch.Tensor: nf, nloc, nnei, ng2 = g2.shape - nh = self.nh + nh = self.head_num # nf x nloc x nnei x ng2 x nh g2v = self.mapv(g2).view(nf, nloc, nnei, ng2, nh) # nf x nloc x nh x nnei x ng2 @@ -185,17 +298,56 @@ def forward( # nf x nloc x nnei x ng2 return self.head_map(ret) + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "Atten2MultiHeadApply", + "@version": 1, + "input_dim": self.input_dim, + "head_num": self.head_num, + "precision": self.precision, + "mapv": self.mapv.serialize(), + "head_map": self.head_map.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "Atten2MultiHeadApply": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + mapv = data.pop("mapv") + head_map = data.pop("head_map") + obj = cls(**data) + obj.mapv = MLPLayer.deserialize(mapv) + obj.head_map = MLPLayer.deserialize(head_map) + return obj + class Atten2EquiVarApply(torch.nn.Module): def __init__( self, - ni: int, - nh: int, + input_dim: int, + head_num: int, + precision: str = "float64", ): super().__init__() - self.ni = ni - self.nh = nh - self.head_map = SimpleLinear(nh, 1, bias=False) + self.input_dim = input_dim + self.head_num = head_num + self.head_map = MLPLayer(head_num, 1, bias=False, precision=precision) + self.precision = precision def forward( self, @@ -203,7 +355,7 @@ def forward( h2: torch.Tensor, # nf x nloc x nnei x 3 ) -> torch.Tensor: nf, nloc, nnei, _ = h2.shape - nh = self.nh + nh = self.head_num # nf x nloc x nh x nnei x nnei AA = torch.permute(AA, (0, 1, 4, 2, 3)) h2m = torch.unsqueeze(h2, dim=2) @@ -216,25 +368,68 @@ def forward( # nf x nloc x nnei x 3 return torch.squeeze(self.head_map(ret), dim=-1) + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "Atten2EquiVarApply", + "@version": 1, + "input_dim": self.input_dim, + "head_num": self.head_num, + "precision": self.precision, + "head_map": self.head_map.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "Atten2EquiVarApply": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + head_map = data.pop("head_map") + obj = cls(**data) + obj.head_map = MLPLayer.deserialize(head_map) + return obj + class LocalAtten(torch.nn.Module): def __init__( self, - ni: int, - nd: int, - nh: int, + input_dim: int, + hidden_dim: int, + head_num: int, smooth: bool = True, attnw_shift: float = 20.0, + precision: str = "float64", ): super().__init__() - self.ni = ni - self.nd = nd - self.nh = nh - self.mapq = SimpleLinear(ni, nd * 1 * nh, bias=False) - self.mapkv = SimpleLinear(ni, (nd + ni) * nh, bias=False) - self.head_map = SimpleLinear(ni * nh, ni) + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.head_num = head_num + self.mapq = MLPLayer( + input_dim, hidden_dim * 1 * head_num, bias=False, precision=precision + ) + self.mapkv = MLPLayer( + input_dim, + (hidden_dim + input_dim) * head_num, + bias=False, + precision=precision, + ) + self.head_map = MLPLayer(input_dim * head_num, input_dim, precision=precision) self.smooth = smooth self.attnw_shift = attnw_shift + self.precision = precision def forward( self, @@ -244,7 +439,7 @@ def forward( sw: torch.Tensor, # nb x nloc x nnei ) -> torch.Tensor: nb, nloc, nnei = nlist_mask.shape - ni, nd, nh = self.ni, self.nd, self.nh + ni, nd, nh = self.input_dim, self.hidden_dim, self.head_num assert ni == g1.shape[-1] assert ni == gg1.shape[-1] # nb x nloc x nd x nh @@ -287,6 +482,49 @@ def forward( ret = self.head_map(ret) return ret + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "LocalAtten", + "@version": 1, + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "head_num": self.head_num, + "smooth": self.smooth, + "attnw_shift": self.attnw_shift, + "precision": self.precision, + "mapq": self.mapq.serialize(), + "mapkv": self.mapkv.serialize(), + "head_map": self.head_map.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "LocalAtten": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + mapq = data.pop("mapq") + mapkv = data.pop("mapkv") + head_map = data.pop("head_map") + obj = cls(**data) + obj.mapq = MLPLayer.deserialize(mapq) + obj.mapkv = MLPLayer.deserialize(mapkv) + obj.head_map = MLPLayer.deserialize(head_map) + return obj + class RepformerLayer(torch.nn.Module): def __init__( @@ -297,10 +535,8 @@ def __init__( ntypes: int, g1_dim=128, g2_dim=16, - axis_dim: int = 4, + axis_neuron: int = 4, update_chnnl_2: bool = True, - do_bn_mode: str = "no", - bn_momentum: float = 0.1, update_g1_has_conv: bool = True, update_g1_has_drrd: bool = True, update_g1_has_grrg: bool = True, @@ -315,8 +551,12 @@ def __init__( attn2_has_gate: bool = False, activation_function: str = "tanh", update_style: str = "res_avg", - set_davg_zero: bool = True, # TODO + update_residual: float = 0.001, + update_residual_init: str = "norm", smooth: bool = True, + precision: str = "float64", + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, ): super().__init__() self.epsilon = 1e-4 # protection of 1./nnei @@ -326,12 +566,10 @@ def __init__( sel = [sel] if isinstance(sel, int) else sel self.nnei = sum(sel) assert len(sel) == 1 - self.sel = torch.tensor(sel, device=env.DEVICE) + self.sel = sel self.sec = self.sel - self.axis_dim = axis_dim - self.set_davg_zero = set_davg_zero - self.do_bn_mode = do_bn_mode - self.bn_momentum = bn_momentum + self.axis_neuron = axis_neuron + self.activation_function = activation_function self.act = ActivationFn(activation_function) self.update_g1_has_grrg = update_g1_has_grrg self.update_g1_has_drrd = update_g1_has_drrd @@ -342,58 +580,132 @@ def __init__( self.update_g2_has_attn = update_g2_has_attn if self.update_chnnl_2 else False self.update_h2 = update_h2 if self.update_chnnl_2 else False del update_g2_has_g1g1, update_g2_has_attn, update_h2 + self.attn1_hidden = attn1_hidden + self.attn1_nhead = attn1_nhead + self.attn2_hidden = attn2_hidden + self.attn2_nhead = attn2_nhead + self.attn2_has_gate = attn2_has_gate self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init self.smooth = smooth self.g1_dim = g1_dim self.g2_dim = g2_dim + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.precision = precision + + assert update_residual_init in [ + "norm", + "const", + ], "'update_residual_init' only support 'norm' or 'const'!" + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.g1_residual = [] + self.g2_residual = [] + self.h2_residual = [] + + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) - g1_in_dim = self.cal_1_dim(g1_dim, g2_dim, self.axis_dim) - self.linear1 = SimpleLinear(g1_in_dim, g1_dim) + g1_in_dim = self.cal_1_dim(g1_dim, g2_dim, self.axis_neuron) + self.linear1 = MLPLayer(g1_in_dim, g1_dim, precision=precision) self.linear2 = None self.proj_g1g2 = None self.proj_g1g1g2 = None self.attn2g_map = None self.attn2_mh_apply = None self.attn2_lm = None - self.attn2h_map = None self.attn2_ev_apply = None self.loc_attn = None if self.update_chnnl_2: - self.linear2 = SimpleLinear(g2_dim, g2_dim) + self.linear2 = MLPLayer(g2_dim, g2_dim, precision=precision) + if self.update_style == "res_residual": + self.g2_residual.append( + get_residual( + g2_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) if self.update_g1_has_conv: - self.proj_g1g2 = SimpleLinear(g1_dim, g2_dim, bias=False) + self.proj_g1g2 = MLPLayer(g1_dim, g2_dim, bias=False, precision=precision) if self.update_g2_has_g1g1: - self.proj_g1g1g2 = SimpleLinear(g1_dim, g2_dim, bias=False) - if self.update_g2_has_attn: + self.proj_g1g1g2 = MLPLayer(g1_dim, g2_dim, bias=False, precision=precision) + if self.update_style == "res_residual": + self.g2_residual.append( + get_residual( + g2_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + if self.update_g2_has_attn or self.update_h2: self.attn2g_map = Atten2Map( - g2_dim, attn2_hidden, attn2_nhead, attn2_has_gate, self.smooth - ) - self.attn2_mh_apply = Atten2MultiHeadApply(g2_dim, attn2_nhead) - self.attn2_lm = torch.nn.LayerNorm( g2_dim, - elementwise_affine=True, - device=env.DEVICE, - dtype=env.GLOBAL_PT_FLOAT_PRECISION, - ) - if self.update_h2: - self.attn2h_map = Atten2Map( - g2_dim, attn2_hidden, attn2_nhead, attn2_has_gate, self.smooth + attn2_hidden, + attn2_nhead, + attn2_has_gate, + self.smooth, + precision=precision, ) - self.attn2_ev_apply = Atten2EquiVarApply(g2_dim, attn2_nhead) + if self.update_g2_has_attn: + self.attn2_mh_apply = Atten2MultiHeadApply( + g2_dim, attn2_nhead, precision=precision + ) + self.attn2_lm = LayerNorm( + g2_dim, eps=ln_eps, trainable=trainable_ln, precision=precision + ) + if self.update_style == "res_residual": + self.g2_residual.append( + get_residual( + g2_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + + if self.update_h2: + self.attn2_ev_apply = Atten2EquiVarApply( + g2_dim, attn2_nhead, precision=precision + ) + if self.update_style == "res_residual": + self.h2_residual.append( + get_residual( + 1, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) if self.update_g1_has_attn: - self.loc_attn = LocalAtten(g1_dim, attn1_hidden, attn1_nhead, self.smooth) - - if self.do_bn_mode == "uniform": - self.bn1 = self._bn_layer() - self.bn2 = self._bn_layer() - elif self.do_bn_mode == "component": - self.bn1 = self._bn_layer(nf=g1_dim) - self.bn2 = self._bn_layer(nf=g2_dim) - elif self.do_bn_mode == "no": - self.bn1, self.bn2 = None, None - else: - raise RuntimeError(f"unknown bn_mode {self.do_bn_mode}") + self.loc_attn = LocalAtten( + g1_dim, attn1_hidden, attn1_nhead, self.smooth, precision=precision + ) + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + + self.g1_residual = nn.ParameterList(self.g1_residual) + self.g2_residual = nn.ParameterList(self.g2_residual) + self.h2_residual = nn.ParameterList(self.h2_residual) def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: ret = g1d @@ -407,21 +719,22 @@ def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: def _update_h2( self, - g2: torch.Tensor, h2: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, + attn: torch.Tensor, ) -> torch.Tensor: - assert self.attn2h_map is not None + """ + Calculate the attention weights update for pair-wise equivariant rep. + + Parameters + ---------- + h2 + Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. + attn + Attention weights from g2 attention, with shape nf x nloc x nnei x nnei x nh2. + """ assert self.attn2_ev_apply is not None - nb, nloc, nnei, _ = g2.shape - # # nb x nloc x nnei x nh2 - # h2_1 = self.attn2_ev_apply(AA, h2) - # h2_update.append(h2_1) - # nb x nloc x nnei x nnei x nh - AAh = self.attn2h_map(g2, h2, nlist_mask, sw) - # nb x nloc x nnei x nh2 - h2_1 = self.attn2_ev_apply(AAh, h2) + # nf x nloc x nnei x nh2 + h2_1 = self.attn2_ev_apply(attn, h2) return h2_1 def _update_g1_conv( @@ -431,6 +744,21 @@ def _update_g1_conv( nlist_mask: torch.Tensor, sw: torch.Tensor, ) -> torch.Tensor: + """ + Calculate the convolution update for atomic invariant rep. + + Parameters + ---------- + gg1 + Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1. + g2 + Pair invariant rep, with shape nb x nloc x nnei x ng2. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + """ assert self.proj_g1g2 is not None nb, nloc, nnei, _ = g2.shape ng1 = gg1.shape[-1] @@ -442,71 +770,145 @@ def _update_g1_conv( if not self.smooth: # normalized by number of neighbors, not smooth # nb x nloc x 1 - invnnei = 1.0 / (self.epsilon + torch.sum(nlist_mask, dim=-1)).unsqueeze(-1) + # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy + invnnei = 1.0 / ( + self.epsilon + torch.sum(nlist_mask.type_as(gg1), dim=-1) + ).unsqueeze(-1) else: gg1 = _apply_switch(gg1, sw) invnnei = (1.0 / float(nnei)) * torch.ones( - (nb, nloc, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=gg1.device + (nb, nloc, 1), dtype=gg1.dtype, device=gg1.device ) # nb x nloc x ng2 g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei return g1_11 - def _cal_h2g2( - self, + @staticmethod + def _cal_hg( g2: torch.Tensor, h2: torch.Tensor, nlist_mask: torch.Tensor, sw: torch.Tensor, + smooth: bool = True, + epsilon: float = 1e-4, ) -> torch.Tensor: - # g2: nf x nloc x nnei x ng2 - # h2: nf x nloc x nnei x 3 - # msk: nf x nloc x nnei + """ + Calculate the transposed rotation matrix. + + Parameters + ---------- + g2 + Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng2. + h2 + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. + + Returns + ------- + hg + The transposed rotation matrix, with shape nb x nloc x 3 x ng2. + """ + # g2: nb x nloc x nnei x ng2 + # h2: nb x nloc x nnei x 3 + # msk: nb x nloc x nnei nb, nloc, nnei, _ = g2.shape ng2 = g2.shape[-1] # nb x nloc x nnei x ng2 g2 = _apply_nlist_mask(g2, nlist_mask) - if not self.smooth: + if not smooth: # nb x nloc - invnnei = 1.0 / (self.epsilon + torch.sum(nlist_mask, dim=-1)) + # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy + invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g2), dim=-1)) # nb x nloc x 1 x 1 invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) else: g2 = _apply_switch(g2, sw) invnnei = (1.0 / float(nnei)) * torch.ones( - (nb, nloc, 1, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=g2.device + (nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device ) # nb x nloc x 3 x ng2 h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei return h2g2 - def _cal_grrg(self, h2g2: torch.Tensor) -> torch.Tensor: + @staticmethod + def _cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor: + """ + Calculate the atomic invariant rep. + + Parameters + ---------- + h2g2 + The transposed rotation matrix, with shape nb x nloc x 3 x ng2. + axis_neuron + Size of the submatrix. + + Returns + ------- + grrg + Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng2) + """ # nb x nloc x 3 x ng2 nb, nloc, _, ng2 = h2g2.shape # nb x nloc x 3 x axis - h2g2m = torch.split(h2g2, self.axis_dim, dim=-1)[0] + h2g2m = torch.split(h2g2, axis_neuron, dim=-1)[0] # nb x nloc x axis x ng2 g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1) # nb x nloc x (axisxng2) - g1_13 = g1_13.view(nb, nloc, self.axis_dim * ng2) + g1_13 = g1_13.view(nb, nloc, axis_neuron * ng2) return g1_13 - def _update_g1_grrg( + def symmetrization_op( self, g2: torch.Tensor, h2: torch.Tensor, nlist_mask: torch.Tensor, sw: torch.Tensor, + axis_neuron: int, + smooth: bool = True, + epsilon: float = 1e-4, ) -> torch.Tensor: - # g2: nf x nloc x nnei x ng2 - # h2: nf x nloc x nnei x 3 - # msk: nf x nloc x nnei + """ + Symmetrization operator to obtain atomic invariant rep. + + Parameters + ---------- + g2 + Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng2. + h2 + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + axis_neuron + Size of the submatrix. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. + + Returns + ------- + grrg + Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng2) + """ + # g2: nb x nloc x nnei x ng2 + # h2: nb x nloc x nnei x 3 + # msk: nb x nloc x nnei nb, nloc, nnei, _ = g2.shape - ng2 = g2.shape[-1] # nb x nloc x 3 x ng2 - h2g2 = self._cal_h2g2(g2, h2, nlist_mask, sw) + h2g2 = self._cal_hg(g2, h2, nlist_mask, sw, smooth=smooth, epsilon=epsilon) # nb x nloc x (axisxng2) - g1_13 = self._cal_grrg(h2g2) + g1_13 = self._cal_grrg(h2g2, axis_neuron) return g1_13 def _update_g2_g1g1( @@ -516,6 +918,21 @@ def _update_g2_g1g1( nlist_mask: torch.Tensor, # nb x nloc x nnei sw: torch.Tensor, # nb x nloc x nnei ) -> torch.Tensor: + """ + Update the g2 using element-wise dot g1_i * g1_j. + + Parameters + ---------- + g1 + Atomic invariant rep, with shape nb x nloc x ng1. + gg1 + Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + """ ret = g1.unsqueeze(-2) * gg1 # nb x nloc x nnei x ng1 ret = _apply_nlist_mask(ret, nlist_mask) @@ -523,73 +940,6 @@ def _update_g2_g1g1( ret = _apply_switch(ret, sw) return ret - def _apply_bn( - self, - bn_number: int, - gg: torch.Tensor, - ): - if self.do_bn_mode == "uniform": - return self._apply_bn_uni(bn_number, gg) - elif self.do_bn_mode == "component": - return self._apply_bn_comp(bn_number, gg) - else: - return gg - - def _apply_nb_1(self, bn_number: int, gg: torch.Tensor) -> torch.Tensor: - nb, nl, nf = gg.shape - gg = gg.view([nb, 1, nl * nf]) - if bn_number == 1: - assert self.bn1 is not None - gg = self.bn1(gg) - else: - assert self.bn2 is not None - gg = self.bn2(gg) - return gg.view([nb, nl, nf]) - - def _apply_nb_2( - self, - bn_number: int, - gg: torch.Tensor, - ) -> torch.Tensor: - nb, nl, nnei, nf = gg.shape - gg = gg.view([nb, 1, nl * nnei * nf]) - if bn_number == 1: - assert self.bn1 is not None - gg = self.bn1(gg) - else: - assert self.bn2 is not None - gg = self.bn2(gg) - return gg.view([nb, nl, nnei, nf]) - - def _apply_bn_uni( - self, - bn_number: int, - gg: torch.Tensor, - mode: str = "1", - ) -> torch.Tensor: - if len(gg.shape) == 3: - return self._apply_nb_1(bn_number, gg) - elif len(gg.shape) == 4: - return self._apply_nb_2(bn_number, gg) - else: - raise RuntimeError(f"unsupported input shape {gg.shape}") - - def _apply_bn_comp( - self, - bn_number: int, - gg: torch.Tensor, - ) -> torch.Tensor: - ss = gg.shape - nf = ss[-1] - gg = gg.view([-1, nf]) - if bn_number == 1: - assert self.bn1 is not None - gg = self.bn1(gg).view(ss) - else: - assert self.bn2 is not None - gg = self.bn2(gg).view(ss) - return gg - def forward( self, g1_ext: torch.Tensor, # nf x nall x ng1 @@ -627,16 +977,6 @@ def forward( g1, _ = torch.split(g1_ext, [nloc, nall - nloc], dim=1) assert (nb, nloc) == g1.shape[:2] assert (nb, nloc, nnei) == h2.shape[:3] - ng1 = g1.shape[-1] - ng2 = g2.shape[-1] - nh2 = h2.shape[-1] - - if self.bn1 is not None: - g1 = self._apply_bn(1, g1) - if self.bn2 is not None: - g2 = self._apply_bn(2, g2) - if self.update_h2: - h2 = _apply_h_norm(h2) g2_update: List[torch.Tensor] = [g2] h2_update: List[torch.Tensor] = [h2] @@ -649,42 +989,68 @@ def forward( gg1 = None if self.update_chnnl_2: - # nb x nloc x nnei x ng2 + # mlp(g2) assert self.linear2 is not None + # nb x nloc x nnei x ng2 g2_1 = self.act(self.linear2(g2)) g2_update.append(g2_1) if self.update_g2_has_g1g1: + # linear(g1_i * g1_j) assert gg1 is not None assert self.proj_g1g1g2 is not None g2_update.append( self.proj_g1g1g2(self._update_g2_g1g1(g1, gg1, nlist_mask, sw)) ) - if self.update_g2_has_attn: + if self.update_g2_has_attn or self.update_h2: + # gated_attention(g2, h2) assert self.attn2g_map is not None - assert self.attn2_mh_apply is not None - assert self.attn2_lm is not None # nb x nloc x nnei x nnei x nh AAg = self.attn2g_map(g2, h2, nlist_mask, sw) - # nb x nloc x nnei x ng2 - g2_2 = self.attn2_mh_apply(AAg, g2) - g2_2 = self.attn2_lm(g2_2) - g2_update.append(g2_2) - if self.update_h2: - h2_update.append(self._update_h2(g2, h2, nlist_mask, sw)) + if self.update_g2_has_attn: + assert self.attn2_mh_apply is not None + assert self.attn2_lm is not None + # nb x nloc x nnei x ng2 + g2_2 = self.attn2_mh_apply(AAg, g2) + g2_2 = self.attn2_lm(g2_2) + g2_update.append(g2_2) + + if self.update_h2: + # linear_head(attention_weights * h2) + h2_update.append(self._update_h2(h2, AAg)) if self.update_g1_has_conv: assert gg1 is not None g1_mlp.append(self._update_g1_conv(gg1, g2, nlist_mask, sw)) if self.update_g1_has_grrg: - g1_mlp.append(self._update_g1_grrg(g2, h2, nlist_mask, sw)) + g1_mlp.append( + self.symmetrization_op( + g2, + h2, + nlist_mask, + sw, + self.axis_neuron, + smooth=self.smooth, + epsilon=self.epsilon, + ) + ) if self.update_g1_has_drrd: assert gg1 is not None - g1_mlp.append(self._update_g1_grrg(gg1, h2, nlist_mask, sw)) + g1_mlp.append( + self.symmetrization_op( + gg1, + h2, + nlist_mask, + sw, + self.axis_neuron, + smooth=self.smooth, + epsilon=self.epsilon, + ) + ) # nb x nloc x [ng1+ng2+(axisxng2)+(axisxng1)] # conv grrg drrd @@ -698,11 +1064,11 @@ def forward( # update if self.update_chnnl_2: - g2_new = self.list_update(g2_update) - h2_new = self.list_update(h2_update) + g2_new = self.list_update(g2_update, "g2") + h2_new = self.list_update(h2_update, "h2") else: g2_new, h2_new = g2, h2 - g1_new = self.list_update(g1_update) + g1_new = self.list_update(g1_update, "g1") return g1_new, g2_new, h2_new @torch.jit.export @@ -726,24 +1092,194 @@ def list_update_res_incr(self, update_list: List[torch.Tensor]) -> torch.Tensor: return uu @torch.jit.export - def list_update(self, update_list: List[torch.Tensor]) -> torch.Tensor: + def list_update_res_residual( + self, update_list: List[torch.Tensor], update_name: str = "g1" + ) -> torch.Tensor: + nitem = len(update_list) + uu = update_list[0] + # make jit happy + if update_name == "g1": + for ii, vv in enumerate(self.g1_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "g2": + for ii, vv in enumerate(self.g2_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "h2": + for ii, vv in enumerate(self.h2_residual): + uu = uu + vv * update_list[ii + 1] + else: + raise NotImplementedError + return uu + + @torch.jit.export + def list_update( + self, update_list: List[torch.Tensor], update_name: str = "g1" + ) -> torch.Tensor: if self.update_style == "res_avg": return self.list_update_res_avg(update_list) elif self.update_style == "res_incr": return self.list_update_res_incr(update_list) + elif self.update_style == "res_residual": + return self.list_update_res_residual(update_list, update_name=update_name) else: raise RuntimeError(f"unknown update style {self.update_style}") - def _bn_layer( - self, - nf: int = 1, - ) -> Callable: - return torch.nn.BatchNorm1d( - nf, - eps=1e-5, - momentum=self.bn_momentum, - affine=False, - track_running_stats=True, - device=env.DEVICE, - dtype=env.GLOBAL_PT_FLOAT_PRECISION, - ) + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + data = { + "@class": "RepformerLayer", + "@version": 1, + "rcut": self.rcut, + "rcut_smth": self.rcut_smth, + "sel": self.sel, + "ntypes": self.ntypes, + "g1_dim": self.g1_dim, + "g2_dim": self.g2_dim, + "axis_neuron": self.axis_neuron, + "update_chnnl_2": self.update_chnnl_2, + "update_g1_has_conv": self.update_g1_has_conv, + "update_g1_has_drrd": self.update_g1_has_drrd, + "update_g1_has_grrg": self.update_g1_has_grrg, + "update_g1_has_attn": self.update_g1_has_attn, + "update_g2_has_g1g1": self.update_g2_has_g1g1, + "update_g2_has_attn": self.update_g2_has_attn, + "update_h2": self.update_h2, + "attn1_hidden": self.attn1_hidden, + "attn1_nhead": self.attn1_nhead, + "attn2_hidden": self.attn2_hidden, + "attn2_nhead": self.attn2_nhead, + "attn2_has_gate": self.attn2_has_gate, + "activation_function": self.activation_function, + "update_style": self.update_style, + "smooth": self.smooth, + "precision": self.precision, + "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, + "linear1": self.linear1.serialize(), + } + if self.update_chnnl_2: + data.update( + { + "linear2": self.linear2.serialize(), + } + ) + if self.update_g1_has_conv: + data.update( + { + "proj_g1g2": self.proj_g1g2.serialize(), + } + ) + if self.update_g2_has_g1g1: + data.update( + { + "proj_g1g1g2": self.proj_g1g1g2.serialize(), + } + ) + if self.update_g2_has_attn or self.update_h2: + data.update( + { + "attn2g_map": self.attn2g_map.serialize(), + } + ) + if self.update_g2_has_attn: + data.update( + { + "attn2_mh_apply": self.attn2_mh_apply.serialize(), + "attn2_lm": self.attn2_lm.serialize(), + } + ) + + if self.update_h2: + data.update( + { + "attn2_ev_apply": self.attn2_ev_apply.serialize(), + } + ) + if self.update_g1_has_attn: + data.update( + { + "loc_attn": self.loc_attn.serialize(), + } + ) + if self.update_style == "res_residual": + data.update( + { + "g1_residual": [to_numpy_array(t) for t in self.g1_residual], + "g2_residual": [to_numpy_array(t) for t in self.g2_residual], + "h2_residual": [to_numpy_array(t) for t in self.h2_residual], + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "RepformerLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + linear1 = data.pop("linear1") + update_chnnl_2 = data["update_chnnl_2"] + update_g1_has_conv = data["update_g1_has_conv"] + update_g2_has_g1g1 = data["update_g2_has_g1g1"] + update_g2_has_attn = data["update_g2_has_attn"] + update_h2 = data["update_h2"] + update_g1_has_attn = data["update_g1_has_attn"] + update_style = data["update_style"] + + linear2 = data.pop("linear2", None) + proj_g1g2 = data.pop("proj_g1g2", None) + proj_g1g1g2 = data.pop("proj_g1g1g2", None) + attn2g_map = data.pop("attn2g_map", None) + attn2_mh_apply = data.pop("attn2_mh_apply", None) + attn2_lm = data.pop("attn2_lm", None) + attn2_ev_apply = data.pop("attn2_ev_apply", None) + loc_attn = data.pop("loc_attn", None) + g1_residual = data.pop("g1_residual", []) + g2_residual = data.pop("g2_residual", []) + h2_residual = data.pop("h2_residual", []) + + obj = cls(**data) + obj.linear1 = MLPLayer.deserialize(linear1) + if update_chnnl_2: + assert isinstance(linear2, dict) + obj.linear2 = MLPLayer.deserialize(linear2) + if update_g1_has_conv: + assert isinstance(proj_g1g2, dict) + obj.proj_g1g2 = MLPLayer.deserialize(proj_g1g2) + if update_g2_has_g1g1: + assert isinstance(proj_g1g1g2, dict) + obj.proj_g1g1g2 = MLPLayer.deserialize(proj_g1g1g2) + if update_g2_has_attn or update_h2: + assert isinstance(attn2g_map, dict) + obj.attn2g_map = Atten2Map.deserialize(attn2g_map) + if update_g2_has_attn: + assert isinstance(attn2_mh_apply, dict) + assert isinstance(attn2_lm, dict) + obj.attn2_mh_apply = Atten2MultiHeadApply.deserialize(attn2_mh_apply) + obj.attn2_lm = LayerNorm.deserialize(attn2_lm) + if update_h2: + assert isinstance(attn2_ev_apply, dict) + obj.attn2_ev_apply = Atten2EquiVarApply.deserialize(attn2_ev_apply) + if update_g1_has_attn: + assert isinstance(loc_attn, dict) + obj.loc_attn = LocalAtten.deserialize(loc_attn) + if update_style == "res_residual": + for ii, t in enumerate(obj.g1_residual): + t.data = to_torch_tensor(g1_residual[ii]) + for ii, t in enumerate(obj.g2_residual): + t.data = to_torch_tensor(g2_residual[ii]) + for ii, t in enumerate(obj.h2_residual): + t.data = to_torch_tensor(h2_residual[ii]) + return obj diff --git a/deepmd/pt/model/descriptor/repformer_layer_old_impl.py b/deepmd/pt/model/descriptor/repformer_layer_old_impl.py new file mode 100644 index 0000000000..af9c2e0981 --- /dev/null +++ b/deepmd/pt/model/descriptor/repformer_layer_old_impl.py @@ -0,0 +1,745 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + List, +) + +import torch + +from deepmd.pt.model.network.network import ( + SimpleLinear, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + ActivationFn, +) + + +def _make_nei_g1( + g1_ext: torch.Tensor, + nlist: torch.Tensor, +) -> torch.Tensor: + # nlist: nb x nloc x nnei + nb, nloc, nnei = nlist.shape + # g1_ext: nb x nall x ng1 + ng1 = g1_ext.shape[-1] + # index: nb x (nloc x nnei) x ng1 + index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1) + # gg1 : nb x (nloc x nnei) x ng1 + gg1 = torch.gather(g1_ext, dim=1, index=index) + # gg1 : nb x nloc x nnei x ng1 + gg1 = gg1.view(nb, nloc, nnei, ng1) + return gg1 + + +def _apply_nlist_mask( + gg: torch.Tensor, + nlist_mask: torch.Tensor, +) -> torch.Tensor: + # gg: nf x nloc x nnei x ng + # msk: nf x nloc x nnei + return gg.masked_fill(~nlist_mask.unsqueeze(-1), 0.0) + + +def _apply_switch(gg: torch.Tensor, sw: torch.Tensor) -> torch.Tensor: + # gg: nf x nloc x nnei x ng + # sw: nf x nloc x nnei + return gg * sw.unsqueeze(-1) + + +def _apply_h_norm( + hh: torch.Tensor, # nf x nloc x nnei x 3 +) -> torch.Tensor: + """Normalize h by the std of vector length. + do not have an idea if this is a good way. + """ + nf, nl, nnei, _ = hh.shape + # nf x nloc x nnei + normh = torch.linalg.norm(hh, dim=-1) + # nf x nloc + std = torch.std(normh, dim=-1) + # nf x nloc x nnei x 3 + hh = hh[:, :, :, :] / (1.0 + std[:, :, None, None]) + return hh + + +class Atten2Map(torch.nn.Module): + def __init__( + self, + ni: int, + nd: int, + nh: int, + has_gate: bool = False, # apply gate to attn map + smooth: bool = True, + attnw_shift: float = 20.0, + ): + super().__init__() + self.ni = ni + self.nd = nd + self.nh = nh + self.mapqk = SimpleLinear(ni, nd * 2 * nh, bias=False) # todo + self.has_gate = has_gate + self.smooth = smooth + self.attnw_shift = attnw_shift + + def forward( + self, + g2: torch.Tensor, # nb x nloc x nnei x ng2 + h2: torch.Tensor, # nb x nloc x nnei x 3 + nlist_mask: torch.Tensor, # nb x nloc x nnei + sw: torch.Tensor, # nb x nloc x nnei + ) -> torch.Tensor: + ( + nb, + nloc, + nnei, + _, + ) = g2.shape + nd, nh = self.nd, self.nh + # nb x nloc x nnei x nd x (nh x 2) + g2qk = self.mapqk(g2).view(nb, nloc, nnei, nd, nh * 2) + # nb x nloc x (nh x 2) x nnei x nd + g2qk = torch.permute(g2qk, (0, 1, 4, 2, 3)) + # nb x nloc x nh x nnei x nd + g2q, g2k = torch.split(g2qk, nh, dim=2) + # g2q = torch.nn.functional.normalize(g2q, dim=-1) + # g2k = torch.nn.functional.normalize(g2k, dim=-1) + # nb x nloc x nh x nnei x nnei + attnw = torch.matmul(g2q, torch.transpose(g2k, -1, -2)) / nd**0.5 + if self.has_gate: + gate = torch.matmul(h2, torch.transpose(h2, -1, -2)).unsqueeze(-3) + attnw = attnw * gate + # mask the attenmap, nb x nloc x 1 x 1 x nnei + attnw_mask = ~nlist_mask.unsqueeze(2).unsqueeze(2) + # mask the attenmap, nb x nloc x 1 x nnei x 1 + attnw_mask_c = ~nlist_mask.unsqueeze(2).unsqueeze(-1) + if self.smooth: + attnw = (attnw + self.attnw_shift) * sw[:, :, None, :, None] * sw[ + :, :, None, None, : + ] - self.attnw_shift + else: + attnw = attnw.masked_fill( + attnw_mask, + float("-inf"), + ) + attnw = torch.softmax(attnw, dim=-1) + attnw = attnw.masked_fill( + attnw_mask, + 0.0, + ) + # nb x nloc x nh x nnei x nnei + attnw = attnw.masked_fill( + attnw_mask_c, + 0.0, + ) + if self.smooth: + attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :] + # nb x nloc x nnei x nnei + h2h2t = torch.matmul(h2, torch.transpose(h2, -1, -2)) / 3.0**0.5 + # nb x nloc x nh x nnei x nnei + ret = attnw * h2h2t[:, :, None, :, :] + # ret = torch.softmax(g2qk, dim=-1) + # nb x nloc x nnei x nnei x nh + ret = torch.permute(ret, (0, 1, 3, 4, 2)) + return ret + + +class Atten2MultiHeadApply(torch.nn.Module): + def __init__( + self, + ni: int, + nh: int, + ): + super().__init__() + self.ni = ni + self.nh = nh + self.mapv = SimpleLinear(ni, ni * nh, bias=False) + self.head_map = SimpleLinear(ni * nh, ni) + + def forward( + self, + AA: torch.Tensor, # nf x nloc x nnei x nnei x nh + g2: torch.Tensor, # nf x nloc x nnei x ng2 + ) -> torch.Tensor: + nf, nloc, nnei, ng2 = g2.shape + nh = self.nh + # nf x nloc x nnei x ng2 x nh + g2v = self.mapv(g2).view(nf, nloc, nnei, ng2, nh) + # nf x nloc x nh x nnei x ng2 + g2v = torch.permute(g2v, (0, 1, 4, 2, 3)) + # g2v = torch.nn.functional.normalize(g2v, dim=-1) + # nf x nloc x nh x nnei x nnei + AA = torch.permute(AA, (0, 1, 4, 2, 3)) + # nf x nloc x nh x nnei x ng2 + ret = torch.matmul(AA, g2v) + # nf x nloc x nnei x ng2 x nh + ret = torch.permute(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, (ng2 * nh)) + # nf x nloc x nnei x ng2 + return self.head_map(ret) + + +class Atten2EquiVarApply(torch.nn.Module): + def __init__( + self, + ni: int, + nh: int, + ): + super().__init__() + self.ni = ni + self.nh = nh + self.head_map = SimpleLinear(nh, 1, bias=False) + + def forward( + self, + AA: torch.Tensor, # nf x nloc x nnei x nnei x nh + h2: torch.Tensor, # nf x nloc x nnei x 3 + ) -> torch.Tensor: + nf, nloc, nnei, _ = h2.shape + nh = self.nh + # nf x nloc x nh x nnei x nnei + AA = torch.permute(AA, (0, 1, 4, 2, 3)) + h2m = torch.unsqueeze(h2, dim=2) + # nf x nloc x nh x nnei x 3 + h2m = torch.tile(h2m, [1, 1, nh, 1, 1]) + # nf x nloc x nh x nnei x 3 + ret = torch.matmul(AA, h2m) + # nf x nloc x nnei x 3 x nh + ret = torch.permute(ret, (0, 1, 3, 4, 2)).view(nf, nloc, nnei, 3, nh) + # nf x nloc x nnei x 3 + return torch.squeeze(self.head_map(ret), dim=-1) + + +class LocalAtten(torch.nn.Module): + def __init__( + self, + ni: int, + nd: int, + nh: int, + smooth: bool = True, + attnw_shift: float = 20.0, + ): + super().__init__() + self.ni = ni + self.nd = nd + self.nh = nh + self.mapq = SimpleLinear(ni, nd * 1 * nh, bias=False) + self.mapkv = SimpleLinear(ni, (nd + ni) * nh, bias=False) + self.head_map = SimpleLinear(ni * nh, ni) + self.smooth = smooth + self.attnw_shift = attnw_shift + + def forward( + self, + g1: torch.Tensor, # nb x nloc x ng1 + gg1: torch.Tensor, # nb x nloc x nnei x ng1 + nlist_mask: torch.Tensor, # nb x nloc x nnei + sw: torch.Tensor, # nb x nloc x nnei + ) -> torch.Tensor: + nb, nloc, nnei = nlist_mask.shape + ni, nd, nh = self.ni, self.nd, self.nh + assert ni == g1.shape[-1] + assert ni == gg1.shape[-1] + # nb x nloc x nd x nh + g1q = self.mapq(g1).view(nb, nloc, nd, nh) + # nb x nloc x nh x nd + g1q = torch.permute(g1q, (0, 1, 3, 2)) + # nb x nloc x nnei x (nd+ni) x nh + gg1kv = self.mapkv(gg1).view(nb, nloc, nnei, nd + ni, nh) + gg1kv = torch.permute(gg1kv, (0, 1, 4, 2, 3)) + # nb x nloc x nh x nnei x nd, nb x nloc x nh x nnei x ng1 + gg1k, gg1v = torch.split(gg1kv, [nd, ni], dim=-1) + + # nb x nloc x nh x 1 x nnei + attnw = torch.matmul(g1q.unsqueeze(-2), torch.transpose(gg1k, -1, -2)) / nd**0.5 + # nb x nloc x nh x nnei + attnw = attnw.squeeze(-2) + # mask the attenmap, nb x nloc x 1 x nnei + attnw_mask = ~nlist_mask.unsqueeze(-2) + # nb x nloc x nh x nnei + if self.smooth: + attnw = (attnw + self.attnw_shift) * sw.unsqueeze(-2) - self.attnw_shift + else: + attnw = attnw.masked_fill( + attnw_mask, + float("-inf"), + ) + attnw = torch.softmax(attnw, dim=-1) + attnw = attnw.masked_fill( + attnw_mask, + 0.0, + ) + if self.smooth: + attnw = attnw * sw.unsqueeze(-2) + + # nb x nloc x nh x ng1 + ret = ( + torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni) + ) + # nb x nloc x ng1 + ret = self.head_map(ret) + return ret + + +class RepformerLayer(torch.nn.Module): + def __init__( + self, + rcut, + rcut_smth, + sel: int, + ntypes: int, + g1_dim=128, + g2_dim=16, + axis_neuron: int = 4, + update_chnnl_2: bool = True, + do_bn_mode: str = "no", + bn_momentum: float = 0.1, + update_g1_has_conv: bool = True, + update_g1_has_drrd: bool = True, + update_g1_has_grrg: bool = True, + update_g1_has_attn: bool = True, + update_g2_has_g1g1: bool = True, + update_g2_has_attn: bool = True, + update_h2: bool = False, + attn1_hidden: int = 64, + attn1_nhead: int = 4, + attn2_hidden: int = 16, + attn2_nhead: int = 4, + attn2_has_gate: bool = False, + activation_function: str = "tanh", + update_style: str = "res_avg", + set_davg_zero: bool = True, # TODO + smooth: bool = True, + ): + super().__init__() + self.epsilon = 1e-4 # protection of 1./nnei + self.rcut = rcut + self.rcut_smth = rcut_smth + self.ntypes = ntypes + sel = [sel] if isinstance(sel, int) else sel + self.nnei = sum(sel) + assert len(sel) == 1 + self.sel = torch.tensor(sel, device=env.DEVICE) + self.sec = self.sel + self.axis_neuron = axis_neuron + self.set_davg_zero = set_davg_zero + self.do_bn_mode = do_bn_mode + self.bn_momentum = bn_momentum + self.act = ActivationFn(activation_function) + self.update_g1_has_grrg = update_g1_has_grrg + self.update_g1_has_drrd = update_g1_has_drrd + self.update_g1_has_conv = update_g1_has_conv + self.update_g1_has_attn = update_g1_has_attn + self.update_chnnl_2 = update_chnnl_2 + self.update_g2_has_g1g1 = update_g2_has_g1g1 if self.update_chnnl_2 else False + self.update_g2_has_attn = update_g2_has_attn if self.update_chnnl_2 else False + self.update_h2 = update_h2 if self.update_chnnl_2 else False + del update_g2_has_g1g1, update_g2_has_attn, update_h2 + self.update_style = update_style + self.smooth = smooth + self.g1_dim = g1_dim + self.g2_dim = g2_dim + + g1_in_dim = self.cal_1_dim(g1_dim, g2_dim, self.axis_neuron) + self.linear1 = SimpleLinear(g1_in_dim, g1_dim) + self.linear2 = None + self.proj_g1g2 = None + self.proj_g1g1g2 = None + self.attn2g_map = None + self.attn2_mh_apply = None + self.attn2_lm = None + self.attn2h_map = None + self.attn2_ev_apply = None + self.loc_attn = None + + if self.update_chnnl_2: + self.linear2 = SimpleLinear(g2_dim, g2_dim) + if self.update_g1_has_conv: + self.proj_g1g2 = SimpleLinear(g1_dim, g2_dim, bias=False) + if self.update_g2_has_g1g1: + self.proj_g1g1g2 = SimpleLinear(g1_dim, g2_dim, bias=False) + if self.update_g2_has_attn: + self.attn2g_map = Atten2Map( + g2_dim, attn2_hidden, attn2_nhead, attn2_has_gate, self.smooth + ) + self.attn2_mh_apply = Atten2MultiHeadApply(g2_dim, attn2_nhead) + self.attn2_lm = torch.nn.LayerNorm( + g2_dim, + elementwise_affine=True, + device=env.DEVICE, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + ) + if self.update_h2: + self.attn2h_map = Atten2Map( + g2_dim, attn2_hidden, attn2_nhead, attn2_has_gate, self.smooth + ) + self.attn2_ev_apply = Atten2EquiVarApply(g2_dim, attn2_nhead) + if self.update_g1_has_attn: + self.loc_attn = LocalAtten(g1_dim, attn1_hidden, attn1_nhead, self.smooth) + + if self.do_bn_mode == "uniform": + self.bn1 = self._bn_layer() + self.bn2 = self._bn_layer() + elif self.do_bn_mode == "component": + self.bn1 = self._bn_layer(nf=g1_dim) + self.bn2 = self._bn_layer(nf=g2_dim) + elif self.do_bn_mode == "no": + self.bn1, self.bn2 = None, None + else: + raise RuntimeError(f"unknown bn_mode {self.do_bn_mode}") + + def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: + ret = g1d + if self.update_g1_has_grrg: + ret += g2d * ax + if self.update_g1_has_drrd: + ret += g1d * ax + if self.update_g1_has_conv: + ret += g2d + return ret + + def _update_h2( + self, + g2: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + ) -> torch.Tensor: + assert self.attn2h_map is not None + assert self.attn2_ev_apply is not None + nb, nloc, nnei, _ = g2.shape + # # nb x nloc x nnei x nh2 + # h2_1 = self.attn2_ev_apply(AA, h2) + # h2_update.append(h2_1) + # nb x nloc x nnei x nnei x nh + AAh = self.attn2h_map(g2, h2, nlist_mask, sw) + # nb x nloc x nnei x nh2 + h2_1 = self.attn2_ev_apply(AAh, h2) + return h2_1 + + def _update_g1_conv( + self, + gg1: torch.Tensor, + g2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + ) -> torch.Tensor: + assert self.proj_g1g2 is not None + nb, nloc, nnei, _ = g2.shape + ng1 = gg1.shape[-1] + ng2 = g2.shape[-1] + # gg1 : nb x nloc x nnei x ng2 + gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2) + # nb x nloc x nnei x ng2 + gg1 = _apply_nlist_mask(gg1, nlist_mask) + if not self.smooth: + # normalized by number of neighbors, not smooth + # nb x nloc x 1 + invnnei = 1.0 / ( + self.epsilon + torch.sum(nlist_mask.type_as(gg1), dim=-1) + ).unsqueeze(-1) + else: + gg1 = _apply_switch(gg1, sw) + invnnei = (1.0 / float(nnei)) * torch.ones( + (nb, nloc, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=gg1.device + ) + # nb x nloc x ng2 + g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei + return g1_11 + + def _cal_h2g2( + self, + g2: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + ) -> torch.Tensor: + # g2: nf x nloc x nnei x ng2 + # h2: nf x nloc x nnei x 3 + # msk: nf x nloc x nnei + nb, nloc, nnei, _ = g2.shape + ng2 = g2.shape[-1] + # nb x nloc x nnei x ng2 + g2 = _apply_nlist_mask(g2, nlist_mask) + if not self.smooth: + # nb x nloc + invnnei = 1.0 / (self.epsilon + torch.sum(nlist_mask.type_as(g2), dim=-1)) + # nb x nloc x 1 x 1 + invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) + else: + g2 = _apply_switch(g2, sw) + invnnei = (1.0 / float(nnei)) * torch.ones( + (nb, nloc, 1, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=g2.device + ) + # nb x nloc x 3 x ng2 + h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei + return h2g2 + + def _cal_grrg(self, h2g2: torch.Tensor) -> torch.Tensor: + # nb x nloc x 3 x ng2 + nb, nloc, _, ng2 = h2g2.shape + # nb x nloc x 3 x axis + h2g2m = torch.split(h2g2, self.axis_neuron, dim=-1)[0] + # nb x nloc x axis x ng2 + g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1) + # nb x nloc x (axisxng2) + g1_13 = g1_13.view(nb, nloc, self.axis_neuron * ng2) + return g1_13 + + def _update_g1_grrg( + self, + g2: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + ) -> torch.Tensor: + # g2: nf x nloc x nnei x ng2 + # h2: nf x nloc x nnei x 3 + # msk: nf x nloc x nnei + nb, nloc, nnei, _ = g2.shape + ng2 = g2.shape[-1] + # nb x nloc x 3 x ng2 + h2g2 = self._cal_h2g2(g2, h2, nlist_mask, sw) + # nb x nloc x (axisxng2) + g1_13 = self._cal_grrg(h2g2) + return g1_13 + + def _update_g2_g1g1( + self, + g1: torch.Tensor, # nb x nloc x ng1 + gg1: torch.Tensor, # nb x nloc x nnei x ng1 + nlist_mask: torch.Tensor, # nb x nloc x nnei + sw: torch.Tensor, # nb x nloc x nnei + ) -> torch.Tensor: + ret = g1.unsqueeze(-2) * gg1 + # nb x nloc x nnei x ng1 + ret = _apply_nlist_mask(ret, nlist_mask) + if self.smooth: + ret = _apply_switch(ret, sw) + return ret + + def _apply_bn( + self, + bn_number: int, + gg: torch.Tensor, + ): + if self.do_bn_mode == "uniform": + return self._apply_bn_uni(bn_number, gg) + elif self.do_bn_mode == "component": + return self._apply_bn_comp(bn_number, gg) + else: + return gg + + def _apply_nb_1(self, bn_number: int, gg: torch.Tensor) -> torch.Tensor: + nb, nl, nf = gg.shape + gg = gg.view([nb, 1, nl * nf]) + if bn_number == 1: + assert self.bn1 is not None + gg = self.bn1(gg) + else: + assert self.bn2 is not None + gg = self.bn2(gg) + return gg.view([nb, nl, nf]) + + def _apply_nb_2( + self, + bn_number: int, + gg: torch.Tensor, + ) -> torch.Tensor: + nb, nl, nnei, nf = gg.shape + gg = gg.view([nb, 1, nl * nnei * nf]) + if bn_number == 1: + assert self.bn1 is not None + gg = self.bn1(gg) + else: + assert self.bn2 is not None + gg = self.bn2(gg) + return gg.view([nb, nl, nnei, nf]) + + def _apply_bn_uni( + self, + bn_number: int, + gg: torch.Tensor, + mode: str = "1", + ) -> torch.Tensor: + if len(gg.shape) == 3: + return self._apply_nb_1(bn_number, gg) + elif len(gg.shape) == 4: + return self._apply_nb_2(bn_number, gg) + else: + raise RuntimeError(f"unsupported input shape {gg.shape}") + + def _apply_bn_comp( + self, + bn_number: int, + gg: torch.Tensor, + ) -> torch.Tensor: + ss = gg.shape + nf = ss[-1] + gg = gg.view([-1, nf]) + if bn_number == 1: + assert self.bn1 is not None + gg = self.bn1(gg).view(ss) + else: + assert self.bn2 is not None + gg = self.bn2(gg).view(ss) + return gg + + def forward( + self, + g1_ext: torch.Tensor, # nf x nall x ng1 + g2: torch.Tensor, # nf x nloc x nnei x ng2 + h2: torch.Tensor, # nf x nloc x nnei x 3 + nlist: torch.Tensor, # nf x nloc x nnei + nlist_mask: torch.Tensor, # nf x nloc x nnei + sw: torch.Tensor, # switch func, nf x nloc x nnei + ): + """ + Parameters + ---------- + g1_ext : nf x nall x ng1 extended single-atom chanel + g2 : nf x nloc x nnei x ng2 pair-atom channel, invariant + h2 : nf x nloc x nnei x 3 pair-atom channel, equivariant + nlist : nf x nloc x nnei neighbor list (padded neis are set to 0) + nlist_mask : nf x nloc x nnei masks of the neighbor list. real nei 1 otherwise 0 + sw : nf x nloc x nnei switch function + + Returns + ------- + g1: nf x nloc x ng1 updated single-atom chanel + g2: nf x nloc x nnei x ng2 updated pair-atom channel, invariant + h2: nf x nloc x nnei x 3 updated pair-atom channel, equivariant + """ + cal_gg1 = ( + self.update_g1_has_drrd + or self.update_g1_has_conv + or self.update_g1_has_attn + or self.update_g2_has_g1g1 + ) + + nb, nloc, nnei, _ = g2.shape + nall = g1_ext.shape[1] + g1, _ = torch.split(g1_ext, [nloc, nall - nloc], dim=1) + assert (nb, nloc) == g1.shape[:2] + assert (nb, nloc, nnei) == h2.shape[:3] + ng1 = g1.shape[-1] + ng2 = g2.shape[-1] + nh2 = h2.shape[-1] + + if self.bn1 is not None: + g1 = self._apply_bn(1, g1) + if self.bn2 is not None: + g2 = self._apply_bn(2, g2) + if self.update_h2: + h2 = _apply_h_norm(h2) + + g2_update: List[torch.Tensor] = [g2] + h2_update: List[torch.Tensor] = [h2] + g1_update: List[torch.Tensor] = [g1] + g1_mlp: List[torch.Tensor] = [g1] + + if cal_gg1: + gg1 = _make_nei_g1(g1_ext, nlist) + else: + gg1 = None + + if self.update_chnnl_2: + # nb x nloc x nnei x ng2 + assert self.linear2 is not None + g2_1 = self.act(self.linear2(g2)) + g2_update.append(g2_1) + + if self.update_g2_has_g1g1: + assert gg1 is not None + assert self.proj_g1g1g2 is not None + g2_update.append( + self.proj_g1g1g2(self._update_g2_g1g1(g1, gg1, nlist_mask, sw)) + ) + + if self.update_g2_has_attn: + assert self.attn2g_map is not None + assert self.attn2_mh_apply is not None + assert self.attn2_lm is not None + # nb x nloc x nnei x nnei x nh + AAg = self.attn2g_map(g2, h2, nlist_mask, sw) + # nb x nloc x nnei x ng2 + g2_2 = self.attn2_mh_apply(AAg, g2) + g2_2 = self.attn2_lm(g2_2) + g2_update.append(g2_2) + + if self.update_h2: + h2_update.append(self._update_h2(g2, h2, nlist_mask, sw)) + + if self.update_g1_has_conv: + assert gg1 is not None + g1_mlp.append(self._update_g1_conv(gg1, g2, nlist_mask, sw)) + + if self.update_g1_has_grrg: + g1_mlp.append(self._update_g1_grrg(g2, h2, nlist_mask, sw)) + + if self.update_g1_has_drrd: + assert gg1 is not None + g1_mlp.append(self._update_g1_grrg(gg1, h2, nlist_mask, sw)) + + # nb x nloc x [ng1+ng2+(axisxng2)+(axisxng1)] + # conv grrg drrd + g1_1 = self.act(self.linear1(torch.cat(g1_mlp, dim=-1))) + g1_update.append(g1_1) + + if self.update_g1_has_attn: + assert gg1 is not None + assert self.loc_attn is not None + g1_update.append(self.loc_attn(g1, gg1, nlist_mask, sw)) + + # update + if self.update_chnnl_2: + g2_new = self.list_update(g2_update) + h2_new = self.list_update(h2_update) + else: + g2_new, h2_new = g2, h2 + g1_new = self.list_update(g1_update) + return g1_new, g2_new, h2_new + + @torch.jit.export + def list_update_res_avg( + self, + update_list: List[torch.Tensor], + ) -> torch.Tensor: + nitem = len(update_list) + uu = update_list[0] + for ii in range(1, nitem): + uu = uu + update_list[ii] + return uu / (float(nitem) ** 0.5) + + @torch.jit.export + def list_update_res_incr(self, update_list: List[torch.Tensor]) -> torch.Tensor: + nitem = len(update_list) + uu = update_list[0] + scale = 1.0 / (float(nitem - 1) ** 0.5) if nitem > 1 else 0.0 + for ii in range(1, nitem): + uu = uu + scale * update_list[ii] + return uu + + @torch.jit.export + def list_update(self, update_list: List[torch.Tensor]) -> torch.Tensor: + if self.update_style == "res_avg": + return self.list_update_res_avg(update_list) + elif self.update_style == "res_incr": + return self.list_update_res_incr(update_list) + else: + raise RuntimeError(f"unknown update style {self.update_style}") + + def _bn_layer( + self, + nf: int = 1, + ) -> Callable: + return torch.nn.BatchNorm1d( + nf, + eps=1e-5, + momentum=self.bn_momentum, + affine=False, + track_running_stats=True, + device=env.DEVICE, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + ) diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index c91ca8056b..2d6a61d264 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -16,8 +16,8 @@ from deepmd.pt.model.descriptor.env_mat import ( prod_env_mat, ) -from deepmd.pt.model.network.network import ( - SimpleLinear, +from deepmd.pt.model.network.mlp import ( + MLPLayer, ) from deepmd.pt.utils import ( env, @@ -41,18 +41,7 @@ from .repformer_layer import ( RepformerLayer, ) - -mydtype = env.GLOBAL_PT_FLOAT_PRECISION -mydev = env.DEVICE - - -def torch_linear(*args, **kwargs): - return torch.nn.Linear(*args, **kwargs, dtype=mydtype, device=mydev) - - -simple_linear = SimpleLinear -mylinear = simple_linear - +from .repformer_layer_old_impl import RepformerLayer as RepformerLayerOld if not hasattr(torch.ops.deepmd, "border_op"): @@ -87,10 +76,8 @@ def __init__( nlayers: int = 3, g1_dim=128, g2_dim=16, - axis_dim: int = 4, + axis_neuron: int = 4, direct_dist: bool = False, - do_bn_mode: str = "no", - bn_momentum: float = 0.1, update_g1_has_conv: bool = True, update_g1_has_drrd: bool = True, update_g1_has_grrg: bool = True, @@ -105,24 +92,96 @@ def __init__( attn2_has_gate: bool = False, activation_function: str = "tanh", update_style: str = "res_avg", - set_davg_zero: bool = True, # TODO + update_residual: float = 0.001, + update_residual_init: str = "norm", + set_davg_zero: bool = True, smooth: bool = True, - add_type_ebd_to_seq: bool = False, exclude_types: List[Tuple[int, int]] = [], env_protection: float = 0.0, - type: Optional[str] = None, + precision: str = "float64", + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + old_impl: bool = False, ): - """ - smooth: - If strictly smooth, cannot be used with update_g1_has_attn - add_type_ebd_to_seq: - At the presence of seq_input (optional input to forward), - whether or not add an type embedding to seq_input. - If no seq_input is given, it has no effect. + r""" + The repformer descriptor block. + + Parameters + ---------- + rcut : float + The cut-off radius. + rcut_smth : float + Where to start smoothing. For example the 1/r term is smoothed from rcut to rcut_smth. + sel : int + Maximally possible number of selected neighbors. + ntypes : int + Number of element types + nlayers : int, optional + Number of repformer layers. + g1_dim : int, optional + Dimension of the first graph convolution layer. + g2_dim : int, optional + Dimension of the second graph convolution layer. + axis_neuron : int, optional + Size of the submatrix of G (embedding matrix). + direct_dist : bool, optional + Whether to use direct distance information (1/r term) in the repformer block. + update_g1_has_conv : bool, optional + Whether to update the g1 rep with convolution term. + update_g1_has_drrd : bool, optional + Whether to update the g1 rep with the drrd term. + update_g1_has_grrg : bool, optional + Whether to update the g1 rep with the grrg term. + update_g1_has_attn : bool, optional + Whether to update the g1 rep with the localized self-attention. + update_g2_has_g1g1 : bool, optional + Whether to update the g2 rep with the g1xg1 term. + update_g2_has_attn : bool, optional + Whether to update the g2 rep with the gated self-attention. + update_h2 : bool, optional + Whether to update the h2 rep. + attn1_hidden : int, optional + The hidden dimension of localized self-attention to update the g1 rep. + attn1_nhead : int, optional + The number of heads in localized self-attention to update the g1 rep. + attn2_hidden : int, optional + The hidden dimension of gated self-attention to update the g2 rep. + attn2_nhead : int, optional + The number of heads in gated self-attention to update the g2 rep. + attn2_has_gate : bool, optional + Whether to use gate in the gated self-attention to update the g2 rep. + activation_function : str, optional + The activation function in the embedding net. + update_style : str, optional + Style to update a representation. + Supported options are: + -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) + -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) + where `r1`, `r2` ... `r3` are residual weights defined by `update_residual` + and `update_residual_init`. + update_residual : float, optional + When update using residual mode, the initial std of residual vector weights. + update_residual_init : str, optional + When update using residual mode, the initialization mode of residual vector weights. + set_davg_zero : bool, optional + Set the normalization average to zero. + precision : str, optional + The precision of the embedding net parameters. + smooth : bool, optional + Whether to use smoothness in processes such as attention weights calculation. + exclude_types : List[List[int]], optional + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + trainable_ln : bool, optional + Whether to use trainable shift and scale weights in layer normalization. + ln_eps : float, optional + The epsilon value for layer normalization. """ super().__init__() - del type - self.epsilon = 1e-4 # protection of 1./nnei self.rcut = rcut self.rcut_smth = rcut_smth self.ntypes = ntypes @@ -134,54 +193,111 @@ def __init__( self.sel = sel self.sec = self.sel self.split_sel = self.sel - self.axis_dim = axis_dim + self.axis_neuron = axis_neuron self.set_davg_zero = set_davg_zero self.g1_dim = g1_dim self.g2_dim = g2_dim - self.act = ActivationFn(activation_function) + self.update_g1_has_conv = update_g1_has_conv + self.update_g1_has_drrd = update_g1_has_drrd + self.update_g1_has_grrg = update_g1_has_grrg + self.update_g1_has_attn = update_g1_has_attn + self.update_g2_has_g1g1 = update_g2_has_g1g1 + self.update_g2_has_attn = update_g2_has_attn + self.update_h2 = update_h2 + self.attn1_hidden = attn1_hidden + self.attn1_nhead = attn1_nhead + self.attn2_has_gate = attn2_has_gate + self.attn2_hidden = attn2_hidden + self.attn2_nhead = attn2_nhead + self.activation_function = activation_function + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init self.direct_dist = direct_dist - self.add_type_ebd_to_seq = add_type_ebd_to_seq + self.act = ActivationFn(activation_function) + self.smooth = smooth # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) self.env_protection = env_protection + self.precision = precision + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.epsilon = 1e-4 + self.old_impl = old_impl - self.g2_embd = mylinear(1, self.g2_dim) + self.g2_embd = MLPLayer(1, self.g2_dim, precision=precision) layers = [] for ii in range(nlayers): - layers.append( - RepformerLayer( - rcut, - rcut_smth, - sel, - ntypes, - self.g1_dim, - self.g2_dim, - axis_dim=self.axis_dim, - update_chnnl_2=(ii != nlayers - 1), - do_bn_mode=do_bn_mode, - bn_momentum=bn_momentum, - update_g1_has_conv=update_g1_has_conv, - update_g1_has_drrd=update_g1_has_drrd, - update_g1_has_grrg=update_g1_has_grrg, - update_g1_has_attn=update_g1_has_attn, - update_g2_has_g1g1=update_g2_has_g1g1, - update_g2_has_attn=update_g2_has_attn, - update_h2=update_h2, - attn1_hidden=attn1_hidden, - attn1_nhead=attn1_nhead, - attn2_has_gate=attn2_has_gate, - attn2_hidden=attn2_hidden, - attn2_nhead=attn2_nhead, - activation_function=activation_function, - update_style=update_style, - smooth=smooth, + if self.old_impl: + layers.append( + RepformerLayerOld( + self.rcut, + self.rcut_smth, + self.sel, + self.ntypes, + self.g1_dim, + self.g2_dim, + axis_neuron=self.axis_neuron, + update_chnnl_2=(ii != nlayers - 1), + update_g1_has_conv=self.update_g1_has_conv, + update_g1_has_drrd=self.update_g1_has_drrd, + update_g1_has_grrg=self.update_g1_has_grrg, + update_g1_has_attn=self.update_g1_has_attn, + update_g2_has_g1g1=self.update_g2_has_g1g1, + update_g2_has_attn=self.update_g2_has_attn, + update_h2=self.update_h2, + attn1_hidden=self.attn1_hidden, + attn1_nhead=self.attn1_nhead, + attn2_has_gate=self.attn2_has_gate, + attn2_hidden=self.attn2_hidden, + attn2_nhead=self.attn2_nhead, + activation_function=self.activation_function, + update_style=self.update_style, + smooth=self.smooth, + ) + ) + else: + layers.append( + RepformerLayer( + self.rcut, + self.rcut_smth, + self.sel, + self.ntypes, + self.g1_dim, + self.g2_dim, + axis_neuron=self.axis_neuron, + update_chnnl_2=(ii != nlayers - 1), + update_g1_has_conv=self.update_g1_has_conv, + update_g1_has_drrd=self.update_g1_has_drrd, + update_g1_has_grrg=self.update_g1_has_grrg, + update_g1_has_attn=self.update_g1_has_attn, + update_g2_has_g1g1=self.update_g2_has_g1g1, + update_g2_has_attn=self.update_g2_has_attn, + update_h2=self.update_h2, + attn1_hidden=self.attn1_hidden, + attn1_nhead=self.attn1_nhead, + attn2_has_gate=self.attn2_has_gate, + attn2_hidden=self.attn2_hidden, + attn2_nhead=self.attn2_nhead, + activation_function=self.activation_function, + update_style=self.update_style, + update_residual=self.update_residual, + update_residual_init=self.update_residual_init, + smooth=self.smooth, + trainable_ln=self.trainable_ln, + ln_eps=self.ln_eps, + precision=precision, + ) ) - ) self.layers = torch.nn.ModuleList(layers) - sshape = (self.ntypes, self.nnei, 4) - mean = torch.zeros(sshape, dtype=mydtype, device=mydev) - stddev = torch.ones(sshape, dtype=mydtype, device=mydev) + wanted_shape = (self.ntypes, self.nnei, 4) + mean = torch.zeros( + wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) + stddev = torch.ones( + wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) self.stats = None @@ -214,6 +330,22 @@ def get_dim_emb(self) -> int: """Returns the embedding dimension g2.""" return self.g2_dim + def __setitem__(self, key, value): + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + def mixed_types(self) -> bool: """If true, the discriptor 1. assumes total number of atoms aligned across frames; @@ -263,6 +395,9 @@ def forward( nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 atype = extended_atype[:, :nloc] + # nb x nloc x nnei + exclude_mask = self.emask(nlist, extended_atype) + nlist = nlist * exclude_mask # nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1 dmatrix, diff, sw = prod_env_mat( extended_coord, @@ -345,9 +480,10 @@ def forward( sw, ) - # uses the last layer. # nb x nloc x 3 x ng2 - h2g2 = ll._cal_h2g2(g2, h2, nlist_mask, sw) + h2g2 = RepformerLayer._cal_hg( + g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon + ) # (nb x nloc) x ng2 x 3 rot_mat = torch.permute(h2g2, (0, 1, 3, 2)) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index d573ba9b7f..5de0aeffab 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -299,7 +299,7 @@ def get_ntypes(self) -> int: return self.ntypes def get_dim_in(self) -> int: - """Returns the output dimension.""" + """Returns the input dimension.""" return self.dim_in def get_dim_out(self) -> int: @@ -761,7 +761,7 @@ def forward( sw: Optional[torch.Tensor] = None, ): residual = x - x = self.attention_layer(x, nei_mask, input_r=input_r, sw=sw) + x, _ = self.attention_layer(x, nei_mask, input_r=input_r, sw=sw) x = residual + x x = self.attn_layer_norm(x) return x @@ -814,6 +814,7 @@ def __init__( nnei: int, embed_dim: int, hidden_dim: int, + num_heads: int = 1, dotr: bool = False, do_mask: bool = False, scaling_factor: float = 1.0, @@ -823,11 +824,14 @@ def __init__( smooth: bool = True, precision: str = DEFAULT_PRECISION, ): - """Construct a neighbor-wise attention net.""" + """Construct a multi-head neighbor-wise attention net.""" super().__init__() + assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads" self.nnei = nnei self.embed_dim = embed_dim self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads self.dotr = dotr self.do_mask = do_mask self.bias = bias @@ -835,10 +839,11 @@ def __init__( self.scaling_factor = scaling_factor self.temperature = temperature self.precision = precision - if temperature is None: - self.scaling = (self.hidden_dim * scaling_factor) ** -0.5 - else: - self.scaling = temperature + self.scaling = ( + (self.head_dim * scaling_factor) ** -0.5 + if temperature is None + else temperature + ) self.normalize = normalize self.in_proj = MLPLayer( embed_dim, @@ -867,7 +872,7 @@ def forward( sw: Optional[torch.Tensor] = None, attnw_shift: float = 20.0, ): - """Compute the gated self-attention. + """Compute the multi-head gated self-attention. Parameters ---------- @@ -883,43 +888,66 @@ def forward( The attention weight shift to preserve smoothness when doing padding before softmax. """ q, k, v = self.in_proj(query).chunk(3, dim=-1) - # [nframes * nloc, nnei, hidden_dim] - q = q.view(-1, self.nnei, self.hidden_dim) - k = k.view(-1, self.nnei, self.hidden_dim) - v = v.view(-1, self.nnei, self.hidden_dim) + + # Reshape for multi-head attention: (nf x nloc) x num_heads x nnei x head_dim + q = q.view(-1, self.nnei, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(-1, self.nnei, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(-1, self.nnei, self.num_heads, self.head_dim).transpose(1, 2) + if self.normalize: q = torch_func.normalize(q, dim=-1) k = torch_func.normalize(k, dim=-1) v = torch_func.normalize(v, dim=-1) + q = q * self.scaling - k = k.transpose(1, 2) - # [nframes * nloc, nnei, nnei] - attn_weights = torch.bmm(q, k) - # [nframes * nloc, nnei] + # (nf x nloc) x num_heads x head_dim x nnei + k = k.transpose(-2, -1) + + # Compute attention scores + # (nf x nloc) x num_heads x nnei x nnei + attn_weights = torch.matmul(q, k) + # (nf x nloc) x nnei nei_mask = nei_mask.view(-1, self.nnei) + if self.smooth: - # [nframes * nloc, nnei] assert sw is not None - sw = sw.view([-1, self.nnei]) - attn_weights = (attn_weights + attnw_shift) * sw[:, :, None] * sw[ - :, None, : + # (nf x nloc) x 1 x nnei + sw = sw.view(-1, 1, self.nnei) + attn_weights = (attn_weights + attnw_shift) * sw[:, :, :, None] * sw[ + :, :, None, : ] - attnw_shift else: + # (nf x nloc) x 1 x 1 x nnei attn_weights = attn_weights.masked_fill( - ~nei_mask.unsqueeze(1), float("-inf") + ~nei_mask.unsqueeze(1).unsqueeze(1), float("-inf") ) + attn_weights = torch_func.softmax(attn_weights, dim=-1) - attn_weights = attn_weights.masked_fill(~nei_mask.unsqueeze(-1), 0.0) + attn_weights = attn_weights.masked_fill( + ~nei_mask.unsqueeze(1).unsqueeze(-1), 0.0 + ) if self.smooth: assert sw is not None - attn_weights = attn_weights * sw[:, :, None] * sw[:, None, :] + attn_weights = attn_weights * sw[:, :, :, None] * sw[:, :, None, :] + if self.dotr: + # (nf x nloc) x nnei x 3 assert input_r is not None, "input_r must be provided when dotr is True!" - angular_weight = torch.bmm(input_r, input_r.transpose(1, 2)) + # (nf x nloc) x 1 x nnei x nnei + angular_weight = torch.matmul(input_r, input_r.transpose(-2, -1)).view( + -1, 1, self.nnei, self.nnei + ) attn_weights = attn_weights * angular_weight - o = torch.bmm(attn_weights, v) + + # Apply attention to values + # (nf x nloc) x nnei x (num_heads x head_dim) + o = ( + torch.matmul(attn_weights, v) + .transpose(1, 2) + .reshape(-1, self.nnei, self.hidden_dim) + ) output = self.out_proj(o) - return output + return output, attn_weights def serialize(self) -> dict: """Serialize the networks to a dict. @@ -929,12 +957,11 @@ def serialize(self) -> dict: dict The serialized networks. """ - # network_type_map_inv = {v: k for k, v in self.NETWORK_TYPE_MAP.items()} - # network_type_name = network_type_map_inv[self.network_type] return { "nnei": self.nnei, "embed_dim": self.embed_dim, "hidden_dim": self.hidden_dim, + "num_heads": self.num_heads, "dotr": self.dotr, "do_mask": self.do_mask, "scaling_factor": self.scaling_factor, diff --git a/deepmd/pt/model/network/layernorm.py b/deepmd/pt/model/network/layernorm.py index 27b9808010..7c58e248ba 100644 --- a/deepmd/pt/model/network/layernorm.py +++ b/deepmd/pt/model/network/layernorm.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np import torch import torch.nn as nn @@ -15,14 +16,14 @@ to_torch_tensor, ) -from .mlp import ( - MLPLayer, -) - device = env.DEVICE -class LayerNorm(MLPLayer): +def empty_t(shape, precision): + return torch.empty(shape, dtype=precision, device=device) + + +class LayerNorm(nn.Module): def __init__( self, num_in, @@ -33,24 +34,22 @@ def __init__( precision: str = DEFAULT_PRECISION, trainable: bool = True, ): + super().__init__() self.eps = eps self.uni_init = uni_init self.num_in = num_in - super().__init__( - num_in=1, - num_out=num_in, - bias=True, - use_timestep=False, - activation_function=None, - resnet=False, - bavg=bavg, - stddev=stddev, - precision=precision, + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.matrix = nn.Parameter(data=empty_t((num_in,), self.prec)) + self.bias = nn.Parameter( + data=empty_t([num_in], self.prec), ) - self.matrix = torch.nn.Parameter(self.matrix.squeeze(0)) if self.uni_init: nn.init.ones_(self.matrix.data) nn.init.zeros_(self.bias.data) + else: + nn.init.normal_(self.bias.data, mean=bavg, std=stddev) + nn.init.normal_(self.matrix.data, std=stddev / np.sqrt(self.num_in)) self.trainable = trainable if not self.trainable: self.matrix.requires_grad = False @@ -75,8 +74,11 @@ def forward( yy: torch.Tensor The output. """ - mean = xx.mean(dim=-1, keepdim=True) - variance = xx.var(dim=-1, unbiased=False, keepdim=True) + # mean = xx.mean(dim=-1, keepdim=True) + # variance = xx.var(dim=-1, unbiased=False, keepdim=True) + # The following operation is the same as above, but will not raise error when using jit model to inference. + # See https://github.com/pytorch/pytorch/issues/85792 + variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True) yy = (xx - mean) / torch.sqrt(variance + self.eps) if self.matrix is not None and self.bias is not None: yy = yy * self.matrix + self.bias diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index 762461111e..5bd1fb0484 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -44,6 +44,28 @@ def empty_t(shape, precision): return torch.empty(shape, dtype=precision, device=device) +class Identity(nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, + xx: torch.Tensor, + ) -> torch.Tensor: + """The Identity operation layer.""" + return xx + + def serialize(self) -> dict: + return { + "@class": "Identity", + "@version": 1, + } + + @classmethod + def deserialize(cls, data: dict) -> "Identity": + return Identity() + + class MLPLayer(nn.Module): def __init__( self, @@ -56,31 +78,47 @@ def __init__( bavg: float = 0.0, stddev: float = 1.0, precision: str = DEFAULT_PRECISION, + init: str = "default", ): super().__init__() # only use_timestep when skip connection is established. self.use_timestep = use_timestep and ( num_out == num_in or num_out == num_in * 2 ) + self.num_in = num_in + self.num_out = num_out self.activate_name = activation_function self.activate = ActivationFn(self.activate_name) self.precision = precision self.prec = PRECISION_DICT[self.precision] self.matrix = nn.Parameter(data=empty_t((num_in, num_out), self.prec)) - nn.init.normal_(self.matrix.data, std=stddev / np.sqrt(num_out + num_in)) if bias: self.bias = nn.Parameter( data=empty_t([num_out], self.prec), ) - nn.init.normal_(self.bias.data, mean=bavg, std=stddev) else: self.bias = None if self.use_timestep: self.idt = nn.Parameter(data=empty_t([num_out], self.prec)) - nn.init.normal_(self.idt.data, mean=0.1, std=0.001) else: self.idt = None self.resnet = resnet + if init == "default": + self._default_normal_init(bavg=bavg, stddev=stddev) + elif init == "trunc_normal": + self._trunc_normal_init(1.0) + elif init == "relu": + self._trunc_normal_init(2.0) + elif init == "glorot": + self._glorot_uniform_init() + elif init == "gating": + self._zero_init(self.use_bias) + elif init == "kaiming_normal": + self._normal_init() + elif init == "final": + self._zero_init(False) + else: + raise ValueError(f"Unknown initialization method: {init}") def check_type_consistency(self): precision = self.precision @@ -90,8 +128,8 @@ def check_var(var): # assertion "float64" == "double" would fail assert PRECISION_DICT[var.dtype.name] is PRECISION_DICT[precision] - check_var(self.w) - check_var(self.b) + check_var(self.matrix) + check_var(self.bias) check_var(self.idt) def dim_in(self) -> int: @@ -100,6 +138,36 @@ def dim_in(self) -> int: def dim_out(self) -> int: return self.matrix.shape[1] + def _default_normal_init(self, bavg: float = 0.0, stddev: float = 1.0): + nn.init.normal_( + self.matrix.data, std=stddev / np.sqrt(self.num_out + self.num_in) + ) + if self.bias is not None: + nn.init.normal_(self.bias.data, mean=bavg, std=stddev) + if self.idt is not None: + nn.init.normal_(self.idt.data, mean=0.1, std=0.001) + + def _trunc_normal_init(self, scale=1.0): + # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) + TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978 + _, fan_in = self.matrix.shape + scale = scale / max(1, fan_in) + std = (scale**0.5) / TRUNCATED_NORMAL_STDDEV_FACTOR + nn.init.trunc_normal_(self.matrix, mean=0.0, std=std) + + def _glorot_uniform_init(self): + nn.init.xavier_uniform_(self.matrix, gain=1) + + def _zero_init(self, use_bias=True): + with torch.no_grad(): + self.matrix.fill_(0.0) + if use_bias and self.bias is not None: + with torch.no_grad(): + self.bias.fill_(1.0) + + def _normal_init(self): + nn.init.kaiming_normal_(self.matrix, nonlinearity="linear") + def forward( self, xx: torch.Tensor, diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index bc7315e66a..dcf3f3c24a 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -2232,6 +2232,11 @@ def update_attention_layers_serialize(self, data: dict): new_dict["attention_layers"][layer_idx]["attention_layer"].update( update_info ) + new_dict["attention_layers"][layer_idx]["attention_layer"].update( + { + "num_heads": 1, + } + ) return new_dict @classmethod diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index f6dbada56c..fb0e0855b8 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -281,7 +281,7 @@ def descrpt_se_a_args(): float, optional=True, default=0.0, - doc=doc_only_tf_supported + doc_env_protection, + doc=doc_only_pt_supported + doc_env_protection, ), Argument( "set_davg_zero", bool, optional=True, default=False, doc=doc_set_davg_zero @@ -360,6 +360,7 @@ def descrpt_se_r_args(): doc_seed = "Random seed for parameter initialization" doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used" + doc_env_protection = "Protection parameter to prevent division by zero errors during environment matrix calculations. For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection." return [ Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel), @@ -392,6 +393,13 @@ def descrpt_se_r_args(): Argument( "set_davg_zero", bool, optional=True, default=False, doc=doc_set_davg_zero ), + Argument( + "env_protection", + float, + optional=True, + default=0.0, + doc=doc_only_pt_supported + doc_env_protection, + ), ] @@ -423,25 +431,15 @@ def descrpt_se_atten_common_args(): doc_neuron = "Number of neurons in each hidden layers of the embedding net. When two layers are of the same size or one layer is twice as large as the previous layer, a skip connection is built." doc_axis_neuron = "Size of the submatrix of G (embedding matrix)." doc_activation_function = f'The activation function in the embedding net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.' - doc_resnet_dt = ( - doc_only_tf_supported + 'Whether to use a "Timestep" in the skip connection' - ) - doc_type_one_side = ( - doc_only_tf_supported - + r"If 'False', type embeddings of both neighbor and central atoms are considered. If 'True', only type embeddings of neighbor atoms are considered. Default is 'False'." - ) - doc_precision = ( - doc_only_tf_supported - + f"The precision of the embedding net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." - ) + doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection' + doc_type_one_side = r"If 'False', type embeddings of both neighbor and central atoms are considered. If 'True', only type embeddings of neighbor atoms are considered. Default is 'False'." + doc_precision = f"The precision of the embedding net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." doc_trainable = ( doc_only_tf_supported + "If the parameters in the embedding net is trainable" ) doc_seed = "Random seed for parameter initialization" - doc_exclude_types = ( - doc_only_tf_supported - + "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." - ) + doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." + doc_env_protection = "Protection parameter to prevent division by zero errors during environment matrix calculations. For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection." doc_attn = "The length of hidden vectors in attention layers" doc_attn_layer = "The number of attention layers. Note that model compression of `se_atten` is only enabled when attn_layer==0 and tebd_input_mode=='strip'" doc_attn_dotr = "Whether to do dot product with the normalized relative coordinates" @@ -485,6 +483,13 @@ def descrpt_se_atten_common_args(): default=[], doc=doc_exclude_types, ), + Argument( + "env_protection", + float, + optional=True, + default=0.0, + doc=doc_only_pt_supported + doc_env_protection, + ), Argument("attn", int, optional=True, default=128, doc=doc_attn), Argument("attn_layer", int, optional=True, default=2, doc=doc_attn_layer), Argument("attn_dotr", bool, optional=True, default=True, doc=doc_attn_dotr), @@ -612,87 +617,128 @@ def descrpt_se_atten_v2_args(): @descrpt_args_plugin.register("dpa2", doc=doc_only_pt_supported) def descrpt_dpa2_args(): - # Generate by GitHub Copilot - doc_repinit_rcut = "The cut-off radius of the repinit block" - doc_repinit_rcut_smth = "From this position the inverse distance smoothly decays to 0 at the cut-off. Use in the repinit block." - doc_repinit_nsel = "Maximally possible number of neighbors for repinit block." - doc_repformer_rcut = "The cut-off radius of the repformer block" - doc_repformer_rcut_smth = "From this position the inverse distance smoothly decays to 0 at the cut-off. Use in the repformer block." - doc_repformer_nsel = "Maximally possible number of neighbors for repformer block." - doc_tebd_dim = "The dimension of atom type embedding" - doc_concat_output_tebd = ( - "Whether to concat type embedding at the output of the descriptor." + # repinit args + doc_repinit = "(Used in the repinit block.) " + doc_repinit_rcut = f"{doc_repinit}The cut-off radius." + doc_repinit_rcut_smth = f"{doc_repinit}Where to start smoothing. For example the 1/r term is smoothed from `rcut` to `rcut_smth`." + doc_repinit_nsel = f"{doc_repinit}Maximally possible number of selected neighbors." + doc_repinit_neuron = ( + f"{doc_repinit}Number of neurons in each hidden layers of the embedding net." + f"When two layers are of the same size or one layer is twice as large as the previous layer, " + f"a skip connection is built." ) - doc_repinit_neuron = "repinit block: the number of neurons in the embedding net." doc_repinit_axis_neuron = ( - "repinit block: the number of dimension of split in the symmetrization op." + f"{doc_repinit}Size of the submatrix of G (embedding matrix)." + ) + doc_repinit_tebd_dim = f"{doc_repinit}The dimension of atom type embedding." + doc_repinit_tebd_input_mode = ( + f"{doc_repinit}The input mode of the type embedding. Supported modes are ['concat', 'strip']." + "- 'concat': Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. " + "When `type_one_side` is False, the input is `input_ij = concat([r_ij, tebd_j, tebd_i])`. When `type_one_side` is True, the input is `input_ij = concat([r_ij, tebd_j])`. " + "The output is `out_ij = embeding(input_ij)` for the pair-wise representation of atom i with neighbor j." + "- 'strip': Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. " + f"When `type_one_side` is False, the input is `input_t = concat([tebd_j, tebd_i])`. {doc_only_pt_supported} When `type_one_side` is True, the input is `input_t = tebd_j`. " + "The output is `out_ij = embeding_t(input_t) * embeding_s(r_ij) + embeding_s(r_ij)` for the pair-wise representation of atom i with neighbor j." + ) + doc_repinit_set_davg_zero = ( + f"{doc_repinit}Set the normalization average to zero. " + f"This option should be set when `atom_ener` in the energy fitting is used." + ) + doc_repinit_activation_function = f"{doc_repinit}The activation function in the embedding net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())}." + doc_repinit_type_one_side = ( + f"{doc_repinit}" + + r"If true, the embedding network parameters vary by types of neighbor atoms only, so there will be $N_\text{types}$ sets of embedding network parameters. Otherwise, the embedding network parameters vary by types of centric atoms and types of neighbor atoms, so there will be $N_\text{types}^2$ sets of embedding network parameters." + ) + doc_repinit_resnet_dt = ( + f'{doc_repinit}Whether to use a "Timestep" in the skip connection.' ) - doc_repinit_activation = ( - "repinit block: the activation function in the embedding net" + + # repformer args + doc_repformer = "(Used in the repformer block.) " + doc_repformer_rcut = f"{doc_repformer}The cut-off radius." + doc_repformer_rcut_smth = f"{doc_repformer}Where to start smoothing. For example the 1/r term is smoothed from `rcut` to `rcut_smth`." + doc_repformer_nsel = ( + f"{doc_repformer}Maximally possible number of selected neighbors." + ) + doc_repformer_nlayers = f"{doc_repformer}The number of repformer layers." + doc_repformer_g1_dim = ( + f"{doc_repformer}The dimension of invariant single-atom representation." ) - doc_repformer_nlayers = "repformers block: the number of repformer layers" - doc_repformer_g1_dim = "repformers block: the dimension of single-atom rep" - doc_repformer_g2_dim = "repformers block: the dimension of invariant pair-atom rep" - doc_repformer_axis_dim = ( - "repformers block: the number of dimension of split in the symmetrization ops." + doc_repformer_g2_dim = ( + f"{doc_repformer}The dimension of invariant pair-atom representation." ) - doc_repformer_do_bn_mode = "repformers block: do batch norm in the repformer layers" - doc_repformer_bn_momentum = "repformers block: moment in the batch normalization" + doc_repformer_axis_neuron = f"{doc_repformer}The number of dimension of submatrix in the symmetrization ops." + doc_repformer_direct_dist = f"{doc_repformer}Whether or not use direct distance as input for the embedding net to get g2 instead of smoothed 1/r." doc_repformer_update_g1_has_conv = ( - "repformers block: update the g1 rep with convolution term" + f"{doc_repformer}Update the g1 rep with convolution term." ) doc_repformer_update_g1_has_drrd = ( - "repformers block: update the g1 rep with the drrd term" + f"{doc_repformer}Update the g1 rep with the drrd term." ) doc_repformer_update_g1_has_grrg = ( - "repformers block: update the g1 rep with the grrg term" + f"{doc_repformer}Update the g1 rep with the grrg term." ) doc_repformer_update_g1_has_attn = ( - "repformers block: update the g1 rep with the localized self-attention" + f"{doc_repformer}Update the g1 rep with the localized self-attention." ) doc_repformer_update_g2_has_g1g1 = ( - "repformers block: update the g2 rep with the g1xg1 term" + f"{doc_repformer}Update the g2 rep with the g1xg1 term." ) doc_repformer_update_g2_has_attn = ( - "repformers block: update the g2 rep with the gated self-attention" - ) - doc_repformer_update_h2 = "repformers block: update the h2 rep" - doc_repformer_attn1_hidden = ( - "repformers block: the hidden dimension of localized self-attention" - ) - doc_repformer_attn1_nhead = ( - "repformers block: the number of heads in localized self-attention" - ) - doc_repformer_attn2_hidden = ( - "repformers block: the hidden dimension of gated self-attention" - ) - doc_repformer_attn2_nhead = ( - "repformers block: the number of heads in gated self-attention" + f"{doc_repformer}Update the g2 rep with the gated self-attention." + ) + doc_repformer_update_h2 = f"{doc_repformer}Update the h2 rep." + doc_repformer_attn1_hidden = f"{doc_repformer}The hidden dimension of localized self-attention to update the g1 rep." + doc_repformer_attn1_nhead = f"{doc_repformer}The number of heads in localized self-attention to update the g1 rep." + doc_repformer_attn2_hidden = f"{doc_repformer}The hidden dimension of gated self-attention to update the g2 rep." + doc_repformer_attn2_nhead = f"{doc_repformer}The number of heads in gated self-attention to update the g2 rep." + doc_repformer_attn2_has_gate = f"{doc_repformer}Whether to use gate in the gated self-attention to update the g2 rep." + doc_repformer_activation_function = f"{doc_repformer}The activation function in the embedding net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())}." + doc_repformer_update_style = ( + f"{doc_repformer}Style to update a representation. " + f"Supported options are: " + "-'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) " + "-'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n)" + "-'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) " + "where `r1`, `r2` ... `r3` are residual weights defined by `repformer_update_residual` " + "and `repformer_update_residual_init`." + ) + doc_repformer_update_residual = ( + f"{doc_repformer}When update using residual mode, " + "the initial std of residual vector weights." + ) + doc_repformer_update_residual_init = ( + f"{doc_repformer}When update using residual mode, " + "the initialization mode of residual vector weights." + "Supported modes are: ['norm', 'const']." + ) + doc_repformer_set_davg_zero = ( + f"{doc_repformer}Set the normalization average to zero. " + f"This option should be set when `atom_ener` in the energy fitting is used." + ) + doc_repformer_trainable_ln = ( + "Whether to use trainable shift and scale weights in layer normalization." ) - doc_repformer_attn2_has_gate = ( - "repformers block: has gate in the gated self-attention" + doc_repformer_ln_eps = "The epsilon value for layer normalization. The default value for TensorFlow is set to 1e-3 to keep consistent with keras while set to 1e-5 in PyTorch and DP implementation." + + # descriptor args + doc_concat_output_tebd = ( + "Whether to concat type embedding at the output of the descriptor." ) - doc_repformer_activation = "repformers block: the activation function in the MLPs." - doc_repformer_update_style = "repformers block: style of update a rep. can be res_avg or res_incr. res_avg updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) res_incr updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n)" - doc_repformer_set_davg_zero = "repformers block: set the avg to zero in statistics" - doc_repformer_add_type_ebd_to_seq = ( - "repformers block: concatenate the type embedding at the output" + doc_precision = f"The precision of the embedding net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." + doc_smooth = ( + "Whether to use smoothness in processes such as attention weights calculation." ) + doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." + doc_env_protection = "Protection parameter to prevent division by zero errors during environment matrix calculations. For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection." + doc_trainable = "If the parameters in the embedding net is trainable." + doc_seed = "Random seed for parameter initialization." + doc_add_tebd_to_repinit_out = "Add type embedding to the output representation from repinit before inputting it into repformer." return [ + # repinit args Argument("repinit_rcut", float, doc=doc_repinit_rcut), Argument("repinit_rcut_smth", float, doc=doc_repinit_rcut_smth), Argument("repinit_nsel", int, doc=doc_repinit_nsel), - Argument("repformer_rcut", float, doc=doc_repformer_rcut), - Argument("repformer_rcut_smth", float, doc=doc_repformer_rcut_smth), - Argument("repformer_nsel", int, doc=doc_repformer_nsel), - Argument("tebd_dim", int, optional=True, default=8, doc=doc_tebd_dim), - Argument( - "concat_output_tebd", - bool, - optional=True, - default=True, - doc=doc_concat_output_tebd, - ), Argument( "repinit_neuron", list, @@ -707,14 +753,54 @@ def descrpt_dpa2_args(): default=16, doc=doc_repinit_axis_neuron, ), - Argument("repinit_set_davg_zero", bool, optional=True, default=True), Argument( - "repinit_activation", + "repinit_tebd_dim", + int, + optional=True, + default=8, + alias=["tebd_dim"], + doc=doc_repinit_tebd_dim, + ), + Argument( + "repinit_tebd_input_mode", + str, + optional=True, + default="concat", + doc=doc_repinit_tebd_input_mode, + ), + Argument( + "repinit_set_davg_zero", + bool, + optional=True, + default=True, + doc=doc_repinit_set_davg_zero, + ), + Argument( + "repinit_activation_function", str, optional=True, default="tanh", - doc=doc_repinit_activation, + alias=["repinit_activation"], + doc=doc_repinit_activation_function, ), + Argument( + "repinit_type_one_side", + bool, + optional=True, + default=False, + doc=doc_repinit_type_one_side, + ), + Argument( + "repinit_resnet_dt", + bool, + optional=True, + default=False, + doc=doc_repinit_resnet_dt, + ), + # repformer args + Argument("repformer_rcut", float, doc=doc_repformer_rcut), + Argument("repformer_rcut_smth", float, doc=doc_repformer_rcut_smth), + Argument("repformer_nsel", int, doc=doc_repformer_nsel), Argument( "repformer_nlayers", int, @@ -733,25 +819,19 @@ def descrpt_dpa2_args(): "repformer_g2_dim", int, optional=True, default=16, doc=doc_repformer_g2_dim ), Argument( - "repformer_axis_dim", + "repformer_axis_neuron", int, optional=True, default=4, - doc=doc_repformer_axis_dim, - ), - Argument( - "repformer_do_bn_mode", - str, - optional=True, - default="no", - doc=doc_repformer_do_bn_mode, + alias=["repformer_axis_dim"], + doc=doc_repformer_axis_neuron, ), Argument( - "repformer_bn_momentum", - float, + "repformer_direct_dist", + bool, optional=True, - default=0.1, - doc=doc_repformer_bn_momentum, + default=False, + doc=doc_repformer_direct_dist, ), Argument( "repformer_update_g1_has_conv", @@ -838,11 +918,12 @@ def descrpt_dpa2_args(): doc=doc_repformer_attn2_has_gate, ), Argument( - "repformer_activation", + "repformer_activation_function", str, optional=True, default="tanh", - doc=doc_repformer_activation, + alias=["repformer_activation"], + doc=doc_repformer_activation_function, ), Argument( "repformer_update_style", @@ -851,6 +932,20 @@ def descrpt_dpa2_args(): default="res_avg", doc=doc_repformer_update_style, ), + Argument( + "repformer_update_residual", + float, + optional=True, + default=0.001, + doc=doc_repformer_update_residual, + ), + Argument( + "repformer_update_residual_init", + str, + optional=True, + default="norm", + doc=doc_repformer_update_residual_init, + ), Argument( "repformer_set_davg_zero", bool, @@ -859,11 +954,52 @@ def descrpt_dpa2_args(): doc=doc_repformer_set_davg_zero, ), Argument( - "repformer_add_type_ebd_to_seq", + "repformer_trainable_ln", + bool, + optional=True, + default=True, + doc=doc_repformer_trainable_ln, + ), + Argument( + "repformer_ln_eps", + float, + optional=True, + default=None, + doc=doc_repformer_ln_eps, + ), + # descriptor args + Argument( + "concat_output_tebd", + bool, + optional=True, + default=True, + doc=doc_concat_output_tebd, + ), + Argument("precision", str, optional=True, default="default", doc=doc_precision), + Argument("smooth", bool, optional=True, default=False, doc=doc_smooth), + Argument( + "exclude_types", + List[List[int]], + optional=True, + default=[], + doc=doc_exclude_types, + ), + Argument( + "env_protection", + float, + optional=True, + default=0.0, + doc=doc_only_pt_supported + doc_env_protection, + ), + Argument("trainable", bool, optional=True, default=True, doc=doc_trainable), + Argument("seed", [int, None], optional=True, doc=doc_seed), + Argument( + "add_tebd_to_repinit_out", bool, optional=True, default=False, - doc=doc_repformer_add_type_ebd_to_seq, + alias=["repformer_add_type_ebd_to_seq"], + doc=doc_add_tebd_to_repinit_out, ), ] diff --git a/doc/model/dpa2.md b/doc/model/dpa2.md index e295f6b6bb..0d17dfd475 100644 --- a/doc/model/dpa2.md +++ b/doc/model/dpa2.md @@ -1,5 +1,9 @@ -# Descriptor DPA-2 {{ pytorch_icon }} +# Descriptor DPA-2 {{ pytorch_icon }} {{ dpmodel_icon }} :::{note} -**Supported backends**: PyTorch {{ pytorch_icon }} +**Supported backends**: PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }} ::: + +The DPA-2 model implementation. See https://arxiv.org/abs/2312.15492 for more details. + +Training example: `examples/water/dpa2/input_torch.json`. diff --git a/examples/water/dpa2/input_torch.json b/examples/water/dpa2/input_torch.json index e94086b898..7096e1d40b 100644 --- a/examples/water/dpa2/input_torch.json +++ b/examples/water/dpa2/input_torch.json @@ -7,7 +7,7 @@ ], "descriptor": { "type": "dpa2", - "tebd_dim": 8, + "repinit_tebd_dim": 8, "repinit_rcut": 9.0, "repinit_rcut_smth": 8.0, "repinit_nsel": 120, @@ -20,7 +20,7 @@ 100 ], "repinit_axis_neuron": 12, - "repinit_activation": "tanh", + "repinit_activation_function": "tanh", "repformer_nlayers": 12, "repformer_g1_dim": 128, "repformer_g2_dim": 32, @@ -28,7 +28,7 @@ "repformer_attn2_nhead": 4, "repformer_attn1_hidden": 128, "repformer_attn1_nhead": 4, - "repformer_axis_dim": 4, + "repformer_axis_neuron": 4, "repformer_update_h2": false, "repformer_update_g1_has_conv": true, "repformer_update_g1_has_grrg": true, @@ -37,7 +37,7 @@ "repformer_update_g2_has_g1g1": true, "repformer_update_g2_has_attn": true, "repformer_attn2_has_gate": true, - "repformer_add_type_ebd_to_seq": false + "add_tebd_to_repinit_out": false }, "fitting_net": { "neuron": [ diff --git a/source/tests/common/dpmodel/case_single_frame_with_nlist.py b/source/tests/common/dpmodel/case_single_frame_with_nlist.py index 828e090cad..f674c7a3a9 100644 --- a/source/tests/common/dpmodel/case_single_frame_with_nlist.py +++ b/source/tests/common/dpmodel/case_single_frame_with_nlist.py @@ -41,9 +41,11 @@ def setUp(self): ).reshape([1, self.nall, 3]) self.coord = self.coord_ext[:, : self.nloc, :] self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall]) + self.mapping = np.array([0, 1, 2, 0], dtype=int).reshape([1, self.nall]) self.atype = self.atype_ext[:, : self.nloc] # sel = [5, 2] self.sel = [5, 2] + self.sel_mix = [7] self.nlist = np.array( [ [1, 3, -1, -1, -1, 2, -1], @@ -66,6 +68,9 @@ def setUp(self): self.atype_ext = np.concatenate( [self.atype_ext, self.atype_ext[:, self.perm]], axis=0 ) + self.mapping = np.concatenate( + [self.mapping, self.mapping[:, self.perm]], axis=0 + ) # permute the nlist nlist1 = self.nlist[:, self.perm[: self.nloc], :] mask = nlist1 == -1 diff --git a/source/tests/common/dpmodel/test_descriptor_dpa2.py b/source/tests/common/dpmodel/test_descriptor_dpa2.py new file mode 100644 index 0000000000..3ae0689dad --- /dev/null +++ b/source/tests/common/dpmodel/test_descriptor_dpa2.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np + +from deepmd.dpmodel.descriptor import ( + DescrptDPA2, +) + +from .case_single_frame_with_nlist import ( + TestCaseSingleFrameWithNlist, +) + + +class TestDescrptDPA2(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_self_consistency( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + davg_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd = 0.1 + np.abs(dstd) + dstd_2 = 0.1 + np.abs(dstd_2) + + em0 = DescrptDPA2( + ntypes=self.nt, + repinit_rcut=self.rcut, + repinit_rcut_smth=self.rcut_smth, + repinit_nsel=self.sel_mix, + repformer_rcut=self.rcut / 2, + repformer_rcut_smth=self.rcut_smth, + repformer_nsel=nnei // 2, + ) + + em0.repinit.mean = davg + em0.repinit.stddev = dstd + em0.repformers.mean = davg_2 + em0.repformers.stddev = dstd_2 + em1 = DescrptDPA2.deserialize(em0.serialize()) + mm0 = em0.call(self.coord_ext, self.atype_ext, self.nlist, self.mapping) + mm1 = em1.call(self.coord_ext, self.atype_ext, self.nlist, self.mapping) + desired_shape = [ + (nf, nloc, em0.get_dim_out()), # descriptor + (nf, nloc, em0.get_dim_emb(), 3), # rot_mat + (nf, nloc, nnei // 2, em0.repformers.g2_dim), # g2 + (nf, nloc, nnei // 2, 3), # h2 + (nf, nloc, nnei // 2), # sw + ] + for ii in [0, 1, 2, 3, 4]: + np.testing.assert_equal(mm0[ii].shape, desired_shape[ii]) + np.testing.assert_allclose(mm0[ii], mm1[ii]) diff --git a/source/tests/common/dpmodel/test_env_mat.py b/source/tests/common/dpmodel/test_env_mat.py index 7e1ce7cddd..bfeb0d69d9 100644 --- a/source/tests/common/dpmodel/test_env_mat.py +++ b/source/tests/common/dpmodel/test_env_mat.py @@ -26,7 +26,12 @@ def test_self_consistency( dstd = 0.1 + np.abs(dstd) em0 = EnvMat(self.rcut, self.rcut_smth) em1 = EnvMat.deserialize(em0.serialize()) - mm0, ww0 = em0.call(self.coord_ext, self.atype_ext, self.nlist, davg, dstd) - mm1, ww1 = em1.call(self.coord_ext, self.atype_ext, self.nlist, davg, dstd) + mm0, diff0, ww0 = em0.call( + self.coord_ext, self.atype_ext, self.nlist, davg, dstd + ) + mm1, diff1, ww1 = em1.call( + self.coord_ext, self.atype_ext, self.nlist, davg, dstd + ) np.testing.assert_allclose(mm0, mm1) + np.testing.assert_allclose(diff0, diff1) np.testing.assert_allclose(ww0, ww1) diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index 9c8c1cea7f..13ceef84ab 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -74,7 +74,7 @@ def eval_dp_descriptor( dp_obj.get_sel(), distinguish_types=(not mixed_types), ) - return dp_obj(ext_coords, ext_atype, nlist=nlist) + return dp_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) def eval_pt_descriptor( self, pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False @@ -95,5 +95,5 @@ def eval_pt_descriptor( ) return [ x.detach().cpu().numpy() if torch.is_tensor(x) else x - for x in pt_obj(ext_coords, ext_atype, nlist=nlist) + for x in pt_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) ] diff --git a/source/tests/consistent/descriptor/test_dpa2.py b/source/tests/consistent/descriptor/test_dpa2.py new file mode 100644 index 0000000000..dd868dbd59 --- /dev/null +++ b/source/tests/consistent/descriptor/test_dpa2.py @@ -0,0 +1,380 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, + Tuple, +) + +import numpy as np +from dargs import ( + Argument, +) + +from deepmd.dpmodel.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2DP +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_PT, + CommonTest, + parameterized, +) +from .common import ( + DescriptorTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2PT +else: + DescrptDPA2PT = None + +# not implemented +DescrptDPA2TF = None + +from deepmd.utils.argcheck import ( + descrpt_se_atten_args, +) + + +@parameterized( + ("concat", "strip"), # repinit_tebd_input_mode + (True,), # repinit_set_davg_zero + (False,), # repinit_type_one_side + (True, False), # repformer_direct_dist + (True,), # repformer_update_g1_has_conv + (True,), # repformer_update_g1_has_drrd + (True,), # repformer_update_g1_has_grrg + (True,), # repformer_update_g1_has_attn + (True,), # repformer_update_g2_has_g1g1 + (True,), # repformer_update_g2_has_attn + (False,), # repformer_update_h2 + (True, False), # repformer_attn2_has_gate + ("res_avg", "res_residual"), # repformer_update_style + ("norm", "const"), # repformer_update_residual_init + (True,), # repformer_set_davg_zero + (True,), # repformer_trainable_ln + (1e-5,), # repformer_ln_eps + (True, False), # smooth + ([], [[0, 1]]), # exclude_types + ("float64",), # precision + (True, False), # add_tebd_to_repinit_out +) +class TestDPA2(CommonTest, DescriptorTest, unittest.TestCase): + @property + def data(self) -> dict: + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repinit_type_one_side, + repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, + smooth, + exclude_types, + precision, + add_tebd_to_repinit_out, + ) = self.param + return { + "ntypes": self.ntypes, + "repinit_rcut": 6.00, + "repinit_rcut_smth": 5.80, + "repinit_nsel": 10, + "repformer_rcut": 4.00, + "repformer_rcut_smth": 3.50, + "repformer_nsel": 8, + # kwargs for repinit + "repinit_neuron": [6, 12, 24], + "repinit_axis_neuron": 3, + "repinit_tebd_dim": 4, + "repinit_tebd_input_mode": repinit_tebd_input_mode, + "repinit_set_davg_zero": repinit_set_davg_zero, + "repinit_activation_function": "tanh", + "repinit_type_one_side": repinit_type_one_side, + # kwargs for repformer + "repformer_nlayers": 3, + "repformer_g1_dim": 20, + "repformer_g2_dim": 10, + "repformer_axis_neuron": 3, + "repformer_direct_dist": repformer_direct_dist, + "repformer_update_g1_has_conv": repformer_update_g1_has_conv, + "repformer_update_g1_has_drrd": repformer_update_g1_has_drrd, + "repformer_update_g1_has_grrg": repformer_update_g1_has_grrg, + "repformer_update_g1_has_attn": repformer_update_g1_has_attn, + "repformer_update_g2_has_g1g1": repformer_update_g2_has_g1g1, + "repformer_update_g2_has_attn": repformer_update_g2_has_attn, + "repformer_update_h2": repformer_update_h2, + "repformer_attn1_hidden": 12, + "repformer_attn1_nhead": 2, + "repformer_attn2_hidden": 10, + "repformer_attn2_nhead": 2, + "repformer_attn2_has_gate": repformer_attn2_has_gate, + "repformer_activation_function": "tanh", + "repformer_update_style": repformer_update_style, + "repformer_update_residual": 0.001, + "repformer_update_residual_init": repformer_update_residual_init, + "repformer_set_davg_zero": True, + "repformer_trainable_ln": repformer_trainable_ln, + "repformer_ln_eps": repformer_ln_eps, + # kwargs for descriptor + "concat_output_tebd": True, + "precision": precision, + "smooth": smooth, + "exclude_types": exclude_types, + "env_protection": 0.0, + "trainable": True, + "add_tebd_to_repinit_out": add_tebd_to_repinit_out, + } + + @property + def skip_pt(self) -> bool: + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repinit_type_one_side, + repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, + smooth, + exclude_types, + precision, + add_tebd_to_repinit_out, + ) = self.param + return CommonTest.skip_pt + + @property + def skip_dp(self) -> bool: + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repinit_type_one_side, + repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, + smooth, + exclude_types, + precision, + add_tebd_to_repinit_out, + ) = self.param + return CommonTest.skip_pt + + @property + def skip_tf(self) -> bool: + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repinit_type_one_side, + repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, + smooth, + exclude_types, + precision, + add_tebd_to_repinit_out, + ) = self.param + return True + + tf_class = DescrptDPA2TF + dp_class = DescrptDPA2DP + pt_class = DescrptDPA2PT + args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False)) + + def setUp(self): + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repinit_type_one_side, + repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, + smooth, + exclude_types, + precision, + add_tebd_to_repinit_out, + ) = self.param + + def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]: + return self.build_tf_descriptor( + obj, + self.natoms, + self.coords, + self.atype, + self.box, + suffix, + ) + + def eval_dp(self, dp_obj: Any) -> Any: + return self.eval_dp_descriptor( + dp_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return self.eval_pt_descriptor( + pt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]: + return (ret[0],) + + @property + def rtol(self) -> float: + """Relative tolerance for comparing the return value.""" + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repinit_type_one_side, + repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, + smooth, + exclude_types, + precision, + add_tebd_to_repinit_out, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") + + @property + def atol(self) -> float: + """Absolute tolerance for comparing the return value.""" + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repinit_type_one_side, + repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, + smooth, + exclude_types, + precision, + add_tebd_to_repinit_out, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") diff --git a/source/tests/pt/model/models/dpa2.json b/source/tests/pt/model/models/dpa2.json index 8b9c735851..966e298dfa 100644 --- a/source/tests/pt/model/models/dpa2.json +++ b/source/tests/pt/model/models/dpa2.json @@ -17,7 +17,7 @@ 8 ], "repinit_axis_neuron": 4, - "repinit_activation": "tanh", + "repinit_activation_function": "tanh", "repformer_nlayers": 12, "repformer_g1_dim": 8, "repformer_g2_dim": 5, @@ -25,7 +25,7 @@ "repformer_attn2_nhead": 1, "repformer_attn1_hidden": 5, "repformer_attn1_nhead": 1, - "repformer_axis_dim": 4, + "repformer_axis_neuron": 4, "repformer_update_h2": false, "repformer_update_g1_has_conv": true, "repformer_update_g1_has_grrg": true, @@ -34,7 +34,7 @@ "repformer_update_g2_has_g1g1": true, "repformer_update_g2_has_attn": true, "repformer_attn2_has_gate": true, - "repformer_add_type_ebd_to_seq": false + "add_tebd_to_repinit_out": false }, "fitting_net": { "neuron": [ diff --git a/source/tests/pt/model/models/dpa2.pth b/source/tests/pt/model/models/dpa2.pth index 26d6155272..d31f69742a 100644 Binary files a/source/tests/pt/model/models/dpa2.pth and b/source/tests/pt/model/models/dpa2.pth differ diff --git a/source/tests/pt/model/test_descriptor_dpa2.py b/source/tests/pt/model/test_descriptor_dpa2.py index 5c9e58b1f1..240871f2d7 100644 --- a/source/tests/pt/model/test_descriptor_dpa2.py +++ b/source/tests/pt/model/test_descriptor_dpa2.py @@ -124,9 +124,9 @@ def test_descriptor(self): target_dict = des.state_dict() source_dict = torch.load(self.file_model_param) # type_embd of repformer is removed - source_dict.pop("descriptor_list.1.type_embd.embedding.weight") + source_dict.pop("type_embedding.embedding.embedding_net.layers.0.bias") type_embd_dict = torch.load(self.file_type_embed) - target_dict = translate_hybrid_and_type_embd_dicts_to_dpa2( + target_dict = translate_type_embd_dicts_to_dpa2( target_dict, source_dict, type_embd_dict, @@ -176,7 +176,7 @@ def test_descriptor(self): self.assertEqual(descriptor.shape[-1], des.get_dim_out()) -def translate_hybrid_and_type_embd_dicts_to_dpa2( +def translate_type_embd_dicts_to_dpa2( target_dict, source_dict, type_embd_dict, @@ -184,11 +184,8 @@ def translate_hybrid_and_type_embd_dicts_to_dpa2( all_keys = list(target_dict.keys()) record = [False for ii in all_keys] for kk, vv in source_dict.items(): - tk = kk.replace("descriptor_list.1", "repformers") - tk = tk.replace("descriptor_list.0", "repinit") - tk = tk.replace("sequential_transform.0", "g1_shape_tranform") - record[all_keys.index(tk)] = True - target_dict[tk] = vv + record[all_keys.index(kk)] = True + target_dict[kk] = vv assert len(type_embd_dict.keys()) == 2 it = iter(type_embd_dict.keys()) for _ in range(2): diff --git a/source/tests/pt/model/test_dpa1.py b/source/tests/pt/model/test_dpa1.py index c1b6f97b26..01cbf259a7 100644 --- a/source/tests/pt/model/test_dpa1.py +++ b/source/tests/pt/model/test_dpa1.py @@ -39,12 +39,12 @@ def test_consistency( dstd = rng.normal(size=(self.nt, nnei, 4)) dstd = 0.1 + np.abs(dstd) - for idt, prec, sm, to, tm in itertools.product( + for idt, sm, to, tm, prec in itertools.product( [False, True], # resnet_dt - ["float64", "float32"], # precision [False, True], # smooth_type_embedding [False, True], # type_one_side ["concat", "strip"], # tebd_input_mode + ["float64", "float32"], # precision ): dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) @@ -65,7 +65,7 @@ def test_consistency( old_impl=False, ).to(env.DEVICE) dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) - dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) rd0, _, _, _, _ = dd0( torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), diff --git a/source/tests/pt/model/test_dpa2.py b/source/tests/pt/model/test_dpa2.py new file mode 100644 index 0000000000..09331ed852 --- /dev/null +++ b/source/tests/pt/model/test_dpa2.py @@ -0,0 +1,334 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor.dpa2 import DescrptDPA2 as DPDescrptDPA2 +from deepmd.pt.model.descriptor.dpa2 import ( + DescrptDPA2, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from .test_mlp import ( + get_tols, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestDescrptDPA2(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ): + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + davg_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd = 0.1 + np.abs(dstd) + dstd_2 = 0.1 + np.abs(dstd_2) + + for ( + riti, + riz, + rp1c, + rp1d, + rp1g, + rp1a, + rp2g, + rp2a, + rph, + rp2gate, + rus, + rpz, + sm, + prec, + ) in itertools.product( + ["concat", "strip"], # repinit_tebd_input_mode + [ + True, + ], # repinit_set_davg_zero + [True, False], # repformer_update_g1_has_conv + [True, False], # repformer_update_g1_has_drrd + [True, False], # repformer_update_g1_has_grrg + [True, False], # repformer_update_g1_has_attn + [True, False], # repformer_update_g2_has_g1g1 + [True, False], # repformer_update_g2_has_attn + [ + False, + ], # repformer_update_h2 + [True, False], # repformer_attn2_has_gate + ["res_avg", "res_residual"], # repformer_update_style + [ + True, + ], # repformer_set_davg_zero + [True, False], # smooth + ["float64"], # precision + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + + # dpa2 new impl + dd0 = DescrptDPA2( + self.nt, + repinit_rcut=self.rcut, + repinit_rcut_smth=self.rcut_smth, + repinit_nsel=self.sel_mix, + repformer_rcut=self.rcut / 2, + repformer_rcut_smth=self.rcut_smth, + repformer_nsel=nnei // 2, + # kwargs for repinit + repinit_tebd_input_mode=riti, + repinit_set_davg_zero=riz, + # kwargs for repformer + repformer_nlayers=3, + repformer_g1_dim=20, + repformer_g2_dim=10, + repformer_axis_neuron=4, + repformer_update_g1_has_conv=rp1c, + repformer_update_g1_has_drrd=rp1d, + repformer_update_g1_has_grrg=rp1g, + repformer_update_g1_has_attn=rp1a, + repformer_update_g2_has_g1g1=rp2g, + repformer_update_g2_has_attn=rp2a, + repformer_update_h2=rph, + repformer_attn1_hidden=20, + repformer_attn1_nhead=2, + repformer_attn2_hidden=10, + repformer_attn2_nhead=2, + repformer_attn2_has_gate=rp2gate, + repformer_update_style=rus, + repformer_set_davg_zero=rpz, + # kwargs for descriptor + smooth=sm, + exclude_types=[], + add_tebd_to_repinit_out=False, + precision=prec, + old_impl=False, + ).to(env.DEVICE) + + dd0.repinit.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.repinit.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + dd0.repformers.mean = torch.tensor(davg_2, dtype=dtype, device=env.DEVICE) + dd0.repformers.stddev = torch.tensor(dstd_2, dtype=dtype, device=env.DEVICE) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + torch.tensor(self.mapping, dtype=int, device=env.DEVICE), + ) + # serialization + dd1 = DescrptDPA2.deserialize(dd0.serialize()) + rd1, _, _, _, _ = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + torch.tensor(self.mapping, dtype=int, device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + # dp impl + dd2 = DPDescrptDPA2.deserialize(dd0.serialize()) + rd2, _, _, _, _ = dd2.call( + self.coord_ext, self.atype_ext, self.nlist, self.mapping + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + ) + # old impl + if prec == "float64" and rus == "res_avg": + dd3 = DescrptDPA2( + self.nt, + repinit_rcut=self.rcut, + repinit_rcut_smth=self.rcut_smth, + repinit_nsel=self.sel_mix, + repformer_rcut=self.rcut / 2, + repformer_rcut_smth=self.rcut_smth, + repformer_nsel=nnei // 2, + # kwargs for repinit + repinit_tebd_input_mode=riti, + repinit_set_davg_zero=riz, + # kwargs for repformer + repformer_nlayers=3, + repformer_g1_dim=20, + repformer_g2_dim=10, + repformer_axis_neuron=4, + repformer_update_g1_has_conv=rp1c, + repformer_update_g1_has_drrd=rp1d, + repformer_update_g1_has_grrg=rp1g, + repformer_update_g1_has_attn=rp1a, + repformer_update_g2_has_g1g1=rp2g, + repformer_update_g2_has_attn=rp2a, + repformer_update_h2=rph, + repformer_attn1_hidden=20, + repformer_attn1_nhead=2, + repformer_attn2_hidden=10, + repformer_attn2_nhead=2, + repformer_attn2_has_gate=rp2gate, + repformer_update_style="res_avg", + repformer_set_davg_zero=rpz, + # kwargs for descriptor + smooth=sm, + exclude_types=[], + add_tebd_to_repinit_out=False, + precision=prec, + old_impl=True, + ).to(env.DEVICE) + dd0_state_dict = dd0.state_dict() + dd3_state_dict = dd3.state_dict() + for i in list(dd0_state_dict.keys()): + if ".bias" in i and ( + ".linear1." in i or ".linear2." in i or ".head_map." in i + ): + dd0_state_dict[i] = dd0_state_dict[i].unsqueeze(0) + if ".attn2_lm.matrix" in i: + dd0_state_dict[ + i.replace(".attn2_lm.matrix", ".attn2_lm.weight") + ] = dd0_state_dict.pop(i) + + dd3.load_state_dict(dd0_state_dict) + rd3, _, _, _, _ = dd3( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + torch.tensor(self.mapping, dtype=int, device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd3.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + + def test_jit( + self, + ): + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + davg_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd = 0.1 + np.abs(dstd) + + for ( + riti, + riz, + rp1c, + rp1d, + rp1g, + rp1a, + rp2g, + rp2a, + rph, + rp2gate, + rus, + rpz, + sm, + prec, + ) in itertools.product( + ["concat", "strip"], # repinit_tebd_input_mode + [ + True, + ], # repinit_set_davg_zero + [ + True, + ], # repformer_update_g1_has_conv + [ + True, + ], # repformer_update_g1_has_drrd + [ + True, + ], # repformer_update_g1_has_grrg + [ + True, + ], # repformer_update_g1_has_attn + [ + True, + ], # repformer_update_g2_has_g1g1 + [ + True, + ], # repformer_update_g2_has_attn + [ + False, + ], # repformer_update_h2 + [ + True, + ], # repformer_attn2_has_gate + ["res_avg", "res_residual"], # repformer_update_style + [ + True, + ], # repformer_set_davg_zero + [ + True, + ], # smooth + ["float64"], # precision + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + + # dpa2 new impl + dd0 = DescrptDPA2( + self.nt, + repinit_rcut=self.rcut, + repinit_rcut_smth=self.rcut_smth, + repinit_nsel=self.sel_mix, + repformer_rcut=self.rcut / 2, + repformer_rcut_smth=self.rcut_smth, + repformer_nsel=nnei // 2, + # kwargs for repinit + repinit_tebd_input_mode=riti, + repinit_set_davg_zero=riz, + # kwargs for repformer + repformer_nlayers=3, + repformer_g1_dim=20, + repformer_g2_dim=10, + repformer_axis_neuron=4, + repformer_update_g1_has_conv=rp1c, + repformer_update_g1_has_drrd=rp1d, + repformer_update_g1_has_grrg=rp1g, + repformer_update_g1_has_attn=rp1a, + repformer_update_g2_has_g1g1=rp2g, + repformer_update_g2_has_attn=rp2a, + repformer_update_h2=rph, + repformer_attn1_hidden=20, + repformer_attn1_nhead=2, + repformer_attn2_hidden=10, + repformer_attn2_nhead=2, + repformer_attn2_has_gate=rp2gate, + repformer_update_style=rus, + repformer_set_davg_zero=rpz, + # kwargs for descriptor + smooth=sm, + exclude_types=[], + add_tebd_to_repinit_out=False, + precision=prec, + old_impl=False, + ).to(env.DEVICE) + + dd0.repinit.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.repinit.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + dd0.repformers.mean = torch.tensor(davg_2, dtype=dtype, device=env.DEVICE) + dd0.repformers.stddev = torch.tensor(dstd_2, dtype=dtype, device=env.DEVICE) + model = torch.jit.script(dd0) diff --git a/source/tests/pt/model/test_env_mat.py b/source/tests/pt/model/test_env_mat.py index 24ed886b86..84099cddaf 100644 --- a/source/tests/pt/model/test_env_mat.py +++ b/source/tests/pt/model/test_env_mat.py @@ -33,6 +33,7 @@ def setUp(self): dtype=np.float64, ).reshape([1, self.nall, 3]) self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall]) + self.mapping = np.array([0, 1, 2, 0], dtype=int).reshape([1, self.nall]) # sel = [5, 2] self.sel = [5, 2] self.sel_mix = [7] @@ -57,6 +58,10 @@ def setUp(self): self.atype_ext = np.concatenate( [self.atype_ext, self.atype_ext[:, self.perm]], axis=0 ) + self.mapping = np.concatenate( + [self.mapping, self.mapping[:, self.perm]], axis=0 + ) + # permute the nlist nlist1 = self.nlist[:, self.perm[: self.nloc], :] mask = nlist1 == -1 @@ -156,8 +161,10 @@ def test_consistency( dstd = rng.normal(size=(self.nt, nnei, 4)) dstd = 0.1 + np.abs(dstd) em0 = EnvMat(self.rcut, self.rcut_smth) - mm0, ww0 = em0.call(self.coord_ext, self.atype_ext, self.nlist, davg, dstd) - mm1, _, ww1 = prod_env_mat( + mm0, diff0, ww0 = em0.call( + self.coord_ext, self.atype_ext, self.nlist, davg, dstd + ) + mm1, diff1, ww1 = prod_env_mat( torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), torch.tensor(self.nlist, dtype=int, device=env.DEVICE), torch.tensor(self.atype_ext[:, :nloc], dtype=int, device=env.DEVICE), @@ -167,5 +174,6 @@ def test_consistency( self.rcut_smth, ) np.testing.assert_allclose(mm0, mm1.detach().cpu().numpy()) + np.testing.assert_allclose(diff0, diff1.detach().cpu().numpy()) np.testing.assert_allclose(ww0, ww1.detach().cpu().numpy()) np.testing.assert_allclose(mm0[0][self.perm[: self.nloc]], mm0[1]) diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py index a54d3159e2..b2bfcb8a4d 100644 --- a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -130,7 +130,7 @@ "repformer_nsel": 40, "repinit_neuron": [2, 4, 8], "repinit_axis_neuron": 4, - "repinit_activation": "tanh", + "repinit_activation_function": "tanh", "repformer_nlayers": 12, "repformer_g1_dim": 8, "repformer_g2_dim": 5, @@ -138,7 +138,7 @@ "repformer_attn2_nhead": 1, "repformer_attn1_hidden": 5, "repformer_attn1_nhead": 1, - "repformer_axis_dim": 4, + "repformer_axis_neuron": 4, "repformer_update_h2": False, "repformer_update_g1_has_conv": True, "repformer_update_g1_has_grrg": True, @@ -147,7 +147,7 @@ "repformer_update_g2_has_g1g1": True, "repformer_update_g2_has_attn": True, "repformer_attn2_has_gate": True, - "repformer_add_type_ebd_to_seq": False, + "add_tebd_to_repinit_out": False, }, "fitting_net": { "neuron": [24, 24], @@ -215,7 +215,7 @@ "repformer_nsel": 10, "repinit_neuron": [2, 4, 8], "repinit_axis_neuron": 4, - "repinit_activation": "tanh", + "repinit_activation_function": "tanh", "repformer_nlayers": 12, "repformer_g1_dim": 8, "repformer_g2_dim": 5, @@ -223,7 +223,7 @@ "repformer_attn2_nhead": 1, "repformer_attn1_hidden": 5, "repformer_attn1_nhead": 1, - "repformer_axis_dim": 4, + "repformer_axis_neuron": 4, "repformer_update_h2": False, "repformer_update_g1_has_conv": True, "repformer_update_g1_has_grrg": True, @@ -232,7 +232,7 @@ "repformer_update_g2_has_g1g1": True, "repformer_update_g2_has_attn": True, "repformer_attn2_has_gate": True, - "repformer_add_type_ebd_to_seq": False, + "add_tebd_to_repinit_out": False, }, ], }, diff --git a/source/tests/pt/model/test_unused_params.py b/source/tests/pt/model/test_unused_params.py index a3c93cbe68..730eb00a0b 100644 --- a/source/tests/pt/model/test_unused_params.py +++ b/source/tests/pt/model/test_unused_params.py @@ -41,8 +41,6 @@ def test_unused(self): # skip the case g2 is not envolved continue model = copy.deepcopy(model_dpa2) - model["descriptor"]["rcut"] = model["descriptor"]["repinit_rcut"] - model["descriptor"]["sel"] = model["descriptor"]["repinit_nsel"] model["descriptor"]["repformer_nlayers"] = 2 # model["descriptor"]["combine_grrg"] = cmbg2 model["descriptor"]["repformer_update_g1_has_conv"] = conv