From 7d82945b5c089c5d83a8f124cda61fe66a76aad6 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 25 Apr 2024 21:56:14 +0800 Subject: [PATCH 01/37] feat: Support `stripped_type_embedding` in PT/DP --- deepmd/dpmodel/descriptor/dpa1.py | 84 +++++-- deepmd/pt/model/descriptor/dpa1.py | 31 +-- deepmd/pt/model/descriptor/se_atten.py | 55 ++++- deepmd/tf/descriptor/se_a.py | 9 +- deepmd/tf/descriptor/se_a_ebd_v2.py | 4 +- deepmd/tf/descriptor/se_a_mask.py | 4 +- deepmd/tf/descriptor/se_atten.py | 213 ++++++++++++++++-- deepmd/tf/descriptor/se_atten_v2.py | 2 +- deepmd/tf/nvnmd/data/data.py | 2 +- deepmd/tf/utils/graph.py | 76 ++++--- deepmd/utils/argcheck.py | 40 +++- doc/model/train-se-atten.md | 2 +- .../water/se_atten_dpa1_compat/input.json | 2 +- .../tests/consistent/descriptor/test_dpa1.py | 3 +- source/tests/pt/model/test_dpa1.py | 24 +- source/tests/tf/test_data_large_batch.py | 4 +- source/tests/tf/test_descrpt_se_atten.py | 4 +- source/tests/tf/test_finetune_se_atten.py | 6 +- .../tests/tf/test_init_frz_model_se_atten.py | 6 +- .../tf/test_model_compression_se_atten.py | 4 +- source/tests/tf/test_model_se_atten.py | 12 +- 21 files changed, 447 insertions(+), 140 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index a551a57628..2f3494c2b0 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -135,8 +135,9 @@ class DescrptDPA1(NativeOP, BaseDescriptor): tebd_dim: int Dimension of the type embedding tebd_input_mode: str - The way to mix the type embeddings. Supported options are `concat`. - (TODO need to support stripped_type_embedding option) + The input mode of the type embedding. Supported modes are [`concat`, `strip`]. + - `concat`: Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. + - `strip`: Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. resnet_dt: bool Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) @@ -189,9 +190,6 @@ class DescrptDPA1(NativeOP, BaseDescriptor): Limitations ----------- - The currently implementation does not support the following features - 1. tebd_input_mode != 'concat' - The currently implementation will not support the following deprecated features 1. spin is not None 2. attn_mask == True @@ -243,9 +241,6 @@ def __init__( raise NotImplementedError( "old implementation of attn_mask is not supported." 
) - # TODO - if tebd_input_mode != "concat": - raise NotImplementedError("tebd_input_mode != 'concat' not implemented") # to keep consistent with default value in this backends if ln_eps is None: ln_eps = 1e-5 @@ -290,25 +285,41 @@ def __init__( activation_function="Linear", precision=precision, ) + if not self.type_one_side: + self.tebd_dim_input = self.tebd_dim * 2 + else: + self.tebd_dim_input = self.tebd_dim if self.tebd_input_mode in ["concat"]: - if not self.type_one_side: - in_dim = 1 + self.tebd_dim * 2 - else: - in_dim = 1 + self.tebd_dim + self.embd_input_dim = 1 + self.tebd_dim_input else: - in_dim = 1 + self.embd_input_dim = 1 self.embeddings = NetworkCollection( ndim=0, ntypes=self.ntypes, network_type="embedding_network", ) self.embeddings[0] = EmbeddingNet( - in_dim, + self.embd_input_dim, self.neuron, self.activation_function, self.resnet_dt, self.precision, ) + if self.tebd_input_mode in ["strip"]: + self.embeddings_strip = NetworkCollection( + ndim=0, + ntypes=self.ntypes, + network_type="embedding_network", + ) + self.embeddings_strip[0] = EmbeddingNet( + self.tebd_dim_input, + self.neuron, + self.activation_function, + self.resnet_dt, + self.precision, + ) + else: + self.embeddings_strip = None self.dpa1_attention = NeighborGatedAttention( self.attn_layer, self.nnei, @@ -410,6 +421,18 @@ def cal_g( gg = self.embeddings[embedding_idx].call(ss) return gg + def cal_g_strip( + self, + ss, + embedding_idx, + ): + assert self.embeddings_strip is not None + nfnl, nnei = ss.shape[0:2] + ss = ss.reshape(nfnl, nnei, -1) + # nfnl x nnei x ng + gg = self.embeddings_strip[embedding_idx].call(ss) + return gg + def reinit_exclude( self, exclude_types: List[Tuple[int, int]] = [], @@ -500,11 +523,28 @@ def call( else: # nfnl x nnei x (1 + tebd_dim) ss = np.concatenate([ss, atype_embd_nlist], axis=-1) + # calculate gg + # nfnl x nnei x ng + gg = self.cal_g(ss, 0) + elif self.tebd_input_mode in ["strip"]: + # nfnl x nnei x ng + gg_s = self.cal_g(ss, 0) + assert self.embeddings_strip is not None + if not self.type_one_side: + # nfnl x nnei x (tebd_dim * 2) + tt = np.concatenate([atype_embd_nlist, atype_embd_nnei], axis=-1) + else: + # nfnl x nnei x tebd_dim + tt = atype_embd_nlist + # nfnl x nnei x ng + gg_t = self.cal_g_strip(tt, 0) + if self.smooth: + gg_t = gg_t * sw.reshape(-1, self.nnei, 1) + # nfnl x nnei x ng + gg = gg_s * gg_t + gg_s else: raise NotImplementedError - # calculate gg - gg = self.cal_g(ss, 0) input_r = dmatrix.reshape(-1, nnei, 4)[:, :, 1:4] / ( np.linalg.norm( dmatrix.reshape(-1, nnei, 4)[:, :, 1:4], axis=-1, keepdims=True @@ -531,7 +571,7 @@ def call( def serialize(self) -> dict: """Serialize the descriptor to dict.""" - return { + data = { "@class": "Descriptor", "type": "dpa1", "@version": 1, @@ -574,6 +614,9 @@ def serialize(self) -> dict: "trainable": True, "spin": None, } + if self.tebd_input_mode in ["strip"]: + data.update({"embeddings_strip": self.embeddings_strip.serialize()}) + return data @classmethod def deserialize(cls, data: dict) -> "DescrptDPA1": @@ -587,11 +630,18 @@ def deserialize(cls, data: dict) -> "DescrptDPA1": type_embedding = data.pop("type_embedding") attention_layers = data.pop("attention_layers") env_mat = data.pop("env_mat") + tebd_input_mode = data["tebd_input_mode"] + if tebd_input_mode in ["strip"]: + embeddings_strip = data.pop("embeddings_strip") + else: + embeddings_strip = None obj = cls(**data) obj["davg"] = variables["davg"] obj["dstd"] = variables["dstd"] obj.embeddings = NetworkCollection.deserialize(embeddings) + if 
tebd_input_mode in ["strip"]: + obj.embeddings_strip = NetworkCollection.deserialize(embeddings_strip) obj.type_embedding = TypeEmbedNet.deserialize(type_embedding) obj.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers) return obj diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 852e08403c..42e74cacc1 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -117,8 +117,9 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module): tebd_dim: int Dimension of the type embedding tebd_input_mode: str - The way to mix the type embeddings. Supported options are `concat`. - (TODO need to support stripped_type_embedding option) + The input mode of the type embedding. Supported modes are [`concat`, `strip`]. + - `concat`: Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. + - `strip`: Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. resnet_dt: bool Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) @@ -171,9 +172,6 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module): Limitations ----------- - The currently implementation does not support the following features - 1. tebd_input_mode != 'concat' - The currently implementation will not support the following deprecated features 1. spin is not None 2. attn_mask == True @@ -195,8 +193,7 @@ def __init__( axis_neuron: int = 16, tebd_dim: int = 8, tebd_input_mode: str = "concat", - # set_davg_zero: bool = False, - set_davg_zero: bool = True, # TODO + set_davg_zero: bool = True, attn: int = 128, attn_layer: int = 2, attn_dotr: bool = True, @@ -216,24 +213,18 @@ def __init__( smooth_type_embedding: bool = True, type_one_side: bool = False, # not implemented - stripped_type_embedding: bool = False, spin=None, type: Optional[str] = None, seed: Optional[int] = None, old_impl: bool = False, ): super().__init__() - if stripped_type_embedding: - raise NotImplementedError("stripped_type_embedding is not supported.") if spin is not None: raise NotImplementedError("old implementation of spin is not supported.") if attn_mask: raise NotImplementedError( "old implementation of attn_mask is not supported." 
) - # TODO - if tebd_input_mode != "concat": - raise NotImplementedError("tebd_input_mode != 'concat' not implemented") # to keep consistent with default value in this backends if ln_eps is None: ln_eps = 1e-5 @@ -376,7 +367,7 @@ def set_stat_mean_and_stddev( def serialize(self) -> dict: obj = self.se_atten - return { + data = { "@class": "Descriptor", "type": "dpa1", "@version": 1, @@ -419,6 +410,9 @@ def serialize(self) -> dict: "trainable": True, "spin": None, } + if obj.tebd_input_mode in ["strip"]: + data.update({"embeddings_strip": obj.filter_layers_strip.serialize()}) + return data @classmethod def deserialize(cls, data: dict) -> "DescrptDPA1": @@ -431,6 +425,11 @@ def deserialize(cls, data: dict) -> "DescrptDPA1": type_embedding = data.pop("type_embedding") attention_layers = data.pop("attention_layers") env_mat = data.pop("env_mat") + tebd_input_mode = data["tebd_input_mode"] + if tebd_input_mode in ["strip"]: + embeddings_strip = data.pop("embeddings_strip") + else: + embeddings_strip = None obj = cls(**data) def t_cvt(xx): @@ -442,6 +441,10 @@ def t_cvt(xx): obj.se_atten["davg"] = t_cvt(variables["davg"]) obj.se_atten["dstd"] = t_cvt(variables["dstd"]) obj.se_atten.filter_layers = NetworkCollection.deserialize(embeddings) + if tebd_input_mode in ["strip"]: + obj.se_atten.filter_layers_strip = NetworkCollection.deserialize( + embeddings_strip + ) obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize( attention_layers ) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 66da86ce29..d261312f98 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -106,8 +106,9 @@ def __init__( tebd_dim : int Dimension of the type embedding tebd_input_mode : str - The way to mix the type embeddings. Supported options are `concat`. - (TODO need to support stripped_type_embedding option) + The input mode of the type embedding. Supported modes are [`concat`, `strip`]. + - `concat`: Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. + - `strip`: Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. resnet_dt : bool Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) @@ -191,6 +192,9 @@ def __init__( # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) if self.old_impl: + assert self.tebd_input_mode in [ + "concat" + ], "Old implementation does not support tebd_input_mode != 'concat'." 
self.dpa1_attention = NeighborWiseAttention( self.attn_layer, self.nnei, @@ -230,16 +234,18 @@ def __init__( ) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) + if not self.type_one_side: + self.tebd_dim_input = self.tebd_dim * 2 + else: + self.tebd_dim_input = self.tebd_dim if self.tebd_input_mode in ["concat"]: - if not self.type_one_side: - self.embd_input_dim = 1 + self.tebd_dim * 2 - else: - self.embd_input_dim = 1 + self.tebd_dim + self.embd_input_dim = 1 + self.tebd_dim_input else: self.embd_input_dim = 1 self.filter_layers_old = None self.filter_layers = None + self.filter_layers_strip = None if self.old_impl: filter_layers = [] one = TypeFilter( @@ -265,6 +271,18 @@ def __init__( resnet_dt=self.resnet_dt, ) self.filter_layers = filter_layers + if self.tebd_input_mode in ["strip"]: + filter_layers_strip = NetworkCollection( + ndim=0, ntypes=self.ntypes, network_type="embedding_network" + ) + filter_layers_strip[0] = EmbeddingNet( + self.tebd_dim_input, + self.filter_neuron, + activation_function=self.activation_function, + precision=self.precision, + resnet_dt=self.resnet_dt, + ) + self.filter_layers_strip = filter_layers_strip self.stats = None def get_rcut(self) -> float: @@ -498,19 +516,36 @@ def forward( rr = dmatrix rr = rr * exclude_mask[:, :, None] ss = rr[:, :, :1] + nlist_tebd = atype_tebd_nlist.reshape(nfnl, nnei, self.tebd_dim) + atype_tebd = atype_tebd_nnei.reshape(nfnl, nnei, self.tebd_dim) if self.tebd_input_mode in ["concat"]: - nlist_tebd = atype_tebd_nlist.reshape(nfnl, nnei, self.tebd_dim) - atype_tebd = atype_tebd_nnei.reshape(nfnl, nnei, self.tebd_dim) if not self.type_one_side: # nfnl x nnei x (1 + tebd_dim * 2) ss = torch.concat([ss, nlist_tebd, atype_tebd], dim=2) else: # nfnl x nnei x (1 + tebd_dim) ss = torch.concat([ss, nlist_tebd], dim=2) + # nfnl x nnei x ng + gg = self.filter_layers.networks[0](ss) + elif self.tebd_input_mode in ["strip"]: + # nfnl x nnei x ng + gg_s = self.filter_layers.networks[0](ss) + assert self.filter_layers_strip is not None + if not self.type_one_side: + # nfnl x nnei x (tebd_dim * 2) + tt = torch.concat([nlist_tebd, atype_tebd], dim=2) + else: + # nfnl x nnei x tebd_dim + tt = nlist_tebd + # nfnl x nnei x ng + gg_t = self.filter_layers_strip.networks[0](tt) + if self.smooth: + gg_t = gg_t * sw.reshape(-1, self.nnei, 1) + # nfnl x nnei x ng + gg = gg_s * gg_t + gg_s else: raise NotImplementedError - # nfnl x nnei x ng - gg = self.filter_layers._networks[0](ss) + input_r = torch.nn.functional.normalize( dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 ) diff --git a/deepmd/tf/descriptor/se_a.py b/deepmd/tf/descriptor/se_a.py index 4f7897e76c..31e3cc92f8 100644 --- a/deepmd/tf/descriptor/se_a.py +++ b/deepmd/tf/descriptor/se_a.py @@ -183,7 +183,7 @@ def __init__( uniform_seed: bool = False, multi_task: bool = False, spin: Optional[Spin] = None, - stripped_type_embedding: bool = False, + tebd_input_mode: str = "concat", env_protection: float = 0.0, # not implement!! 
**kwargs, ) -> None: @@ -194,6 +194,8 @@ def __init__( ) if env_protection != 0.0: raise NotImplementedError("env_protection != 0.0 is not supported.") + # to be compat with old option of `stripped_type_embedding` + stripped_type_embedding = tebd_input_mode == "strip" self.sel_a = sel self.rcut_r = rcut self.rcut_r_smth = rcut_smth @@ -1056,7 +1058,7 @@ def _filter_lower( ) if self.compress: raise RuntimeError( - "compression of type embedded descriptor is not supported when stripped_type_embedding == False" + "compression of type embedded descriptor is not supported when tebd_input_mode != 'strip'" ) # natom x 4 x outputs_size if nvnmd_cfg.enable: @@ -1361,7 +1363,6 @@ def init_variables( graph_def, suffix, get_extra_embedding_net_suffix(self.type_one_side), - self.layer_size, ) ) @@ -1426,7 +1427,7 @@ def serialize(self, suffix: str = "") -> dict: ) if self.stripped_type_embedding: raise NotImplementedError( - "stripped_type_embedding is unsupported by the native model" + "tebd_input_mode=='strip' is unsupported by the native model" ) if (self.original_sel != self.sel_a).any(): raise NotImplementedError( diff --git a/deepmd/tf/descriptor/se_a_ebd_v2.py b/deepmd/tf/descriptor/se_a_ebd_v2.py index 0d2acbc9d5..9b92931b7f 100644 --- a/deepmd/tf/descriptor/se_a_ebd_v2.py +++ b/deepmd/tf/descriptor/se_a_ebd_v2.py @@ -24,7 +24,7 @@ class DescrptSeAEbdV2(DescrptSeA): r"""A compressible se_a_ebd model. - This model is a warpper for DescriptorSeA, which set stripped_type_embedding=True. + This model is a warpper for DescriptorSeA, which set tebd_input_mode='strip'. """ def __init__( @@ -65,6 +65,6 @@ def __init__( uniform_seed=uniform_seed, multi_task=multi_task, spin=spin, - stripped_type_embedding=True, + tebd_input_mode="strip", **kwargs, ) diff --git a/deepmd/tf/descriptor/se_a_mask.py b/deepmd/tf/descriptor/se_a_mask.py index d1ae5d7bad..e78dfba461 100644 --- a/deepmd/tf/descriptor/se_a_mask.py +++ b/deepmd/tf/descriptor/se_a_mask.py @@ -128,7 +128,7 @@ def __init__( activation_function: str = "tanh", precision: str = "default", uniform_seed: bool = False, - stripped_type_embedding: bool = False, + tebd_input_mode: str = "concat", **kwargs, ) -> None: """Constructor.""" @@ -160,6 +160,8 @@ def __init__( # numb of neighbors and numb of descrptors self.nnei_a = np.cumsum(self.sel_a)[-1] self.nnei = self.nnei_a + # to be compat with old option of `stripped_type_embedding` + stripped_type_embedding = tebd_input_mode == "strip" self.stripped_type_embedding = stripped_type_embedding self.ndescrpt_a = self.nnei_a * 4 diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 0ba426ee4b..8d93d9659a 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -20,8 +20,10 @@ EnvMat, ) from deepmd.dpmodel.utils.network import ( + EmbeddingNet, LayerNorm, NativeLayer, + NetworkCollection, ) from deepmd.tf.common import ( cast_precision, @@ -116,7 +118,9 @@ class DescrptSeAtten(DescrptSeA): seed: int, Optional Random seed for initializing the network parameters. type_one_side: bool - Try to build N_types embedding nets. Otherwise, building N_types^2 embedding nets + If 'False', type embeddings of both neighbor and central atoms are considered. + If 'True', only type embeddings of neighbor atoms are considered. + Default is 'False'. exclude_types : List[List[int]] The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1. 
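The pattern repeated in the constructors above, `stripped_type_embedding = tebd_input_mode == "strip"`, keeps the legacy boolean alive internally while the user-facing option becomes the string `tebd_input_mode`; the `backend_compat` hook added to `deepmd/utils/argcheck.py` later in this patch applies the same translation to old input files. A minimal sketch of that intended config mapping (the `descriptor` dict and its values are illustrative, not taken from a real `input.json`):

```python
# Illustrative only: mirrors the translation performed by backend_compat()
# in deepmd/utils/argcheck.py (added later in this patch).
old_descriptor = {
    "type": "se_atten",
    "stripped_type_embedding": True,  # legacy boolean option
    "attn_layer": 0,
}

descriptor = dict(old_descriptor)
if descriptor.pop("stripped_type_embedding", False):
    # the real backend_compat raises ValueError if an explicit, conflicting
    # tebd_input_mode is already present in the config
    descriptor.setdefault("tebd_input_mode", "strip")

assert descriptor == {"type": "se_atten", "attn_layer": 0, "tebd_input_mode": "strip"}
```

Configs that omit the legacy flag (or set it to `False`) are left untouched, matching the fall-through branch of `backend_compat`.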
@@ -140,9 +144,11 @@ class DescrptSeAtten(DescrptSeA): The epsilon value for layer normalization. multi_task: bool If the model has multi fitting nets to train. - stripped_type_embedding: bool - Whether to strip the type embedding into a separated embedding network. - Default value will be True in `se_atten_v2` descriptor. + tebd_input_mode: str + The input mode of the type embedding. Supported modes are [`concat`, `strip`]. + - `concat`: Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. + - `strip`: Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. + Default value will be `strip` in `se_atten_v2` descriptor. smooth_type_embedding: bool Whether to use smooth process in attention weights calculation. And when using stripped type embedding, whether to dot smooth factor on the network output of type embedding @@ -177,8 +183,8 @@ def __init__( attn_dotr: bool = True, attn_mask: bool = False, multi_task: bool = False, - stripped_type_embedding: bool = False, smooth_type_embedding: bool = False, + tebd_input_mode: str = "concat", # not implemented scaling_factor=1.0, normalize=True, @@ -189,6 +195,8 @@ def __init__( env_protection: float = 0.0, # not implement!! **kwargs, ) -> None: + # to be compat with old option of `stripped_type_embedding` + stripped_type_embedding = tebd_input_mode == "strip" if not set_davg_zero and not ( stripped_type_embedding and smooth_type_embedding ): @@ -239,6 +247,7 @@ def __init__( if ntypes == 0: raise ValueError("`model/type_map` is not set or empty!") self.stripped_type_embedding = stripped_type_embedding + self.tebd_input_mode = tebd_input_mode self.smooth = smooth_type_embedding self.trainable_ln = trainable_ln self.ln_eps = ln_eps @@ -1372,7 +1381,6 @@ def compat_ln_pattern(old_key): graph_def, suffix, get_extra_embedding_net_suffix(type_one_side=False), - self.layer_size, ) ) @@ -1581,6 +1589,89 @@ def serialize_attention_layers( ) return data + def serialize_network_strip( + self, + ntypes: int, + ndim: int, + in_dim: int, + neuron: List[int], + activation_function: str, + resnet_dt: bool, + variables: dict, + suffix: str = "", + type_one_side: bool = False, + ) -> dict: + """Serialize network. + + Parameters + ---------- + ntypes : int + The number of types + ndim : int + The dimension of elements + in_dim : int + The input dimension + neuron : List[int] + The neuron list + activation_function : str + The activation function + resnet_dt : bool + Whether to use resnet + variables : dict + The input variables + suffix : str, optional + The suffix of the scope + type_one_side : bool, optional + If 'False', type embeddings of both neighbor and central atoms are considered. + If 'True', only type embeddings of neighbor atoms are considered. + Default is 'False'. + + Returns + ------- + dict + The converted network data + """ + assert ndim == 0, "only supports descriptors with type embedding!" 
+ embeddings = NetworkCollection( + ntypes=ntypes, + ndim=ndim, + network_type="embedding_network", + ) + name_suffix = get_extra_embedding_net_suffix(type_one_side=type_one_side) + embedding_net_pattern_strip = str( + rf"filter_type_(all)/(matrix)_(\d+){name_suffix}|" + rf"filter_type_(all)/(bias)_(\d+){name_suffix}|" + rf"filter_type_(all)/(idt)_(\d+){name_suffix}|" + )[:-1] + if suffix != "": + embedding_net_pattern = ( + embedding_net_pattern_strip.replace("/(idt)", suffix + "/(idt)") + .replace("/(bias)", suffix + "/(bias)") + .replace("/(matrix)", suffix + "/(matrix)") + ) + else: + embedding_net_pattern = embedding_net_pattern_strip + for key, value in variables.items(): + m = re.search(embedding_net_pattern, key) + m = [mm for mm in m.groups() if mm is not None] + layer_idx = int(m[2]) - 1 + weight_name = m[1] + network_idx = () + if embeddings[network_idx] is None: + # initialize the network if it is not initialized + embeddings[network_idx] = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + activation_function=activation_function, + resnet_dt=resnet_dt, + precision=self.precision.name, + ) + assert embeddings[network_idx] is not None + if weight_name == "idt": + value = value.ravel() + embeddings[network_idx][layer_idx][weight_name] = value + return embeddings.serialize() + @classmethod def deserialize_attention_layers(cls, data: dict, suffix: str = "") -> dict: """Deserialize attention layers. @@ -1657,6 +1748,53 @@ def deserialize_attention_layers(cls, data: dict, suffix: str = "") -> dict: ] = layer_norm["matrix"] return attention_layer_variables + @classmethod + def deserialize_network_strip( + cls, data: dict, suffix: str = "", type_one_side: bool = False + ) -> dict: + """Deserialize network. + + Parameters + ---------- + data : dict + The input network data + suffix : str, optional + The suffix of the scope + type_one_side : bool, optional + If 'False', type embeddings of both neighbor and central atoms are considered. + If 'True', only type embeddings of neighbor atoms are considered. + Default is 'False'. + + Returns + ------- + variables : dict + The input variables + """ + embedding_net_variables = {} + embeddings = NetworkCollection.deserialize(data) + assert embeddings.ndim == 0, "only supports descriptors with type embedding!" + name_suffix = get_extra_embedding_net_suffix(type_one_side=type_one_side) + net_idx = () + network = embeddings[net_idx] + assert network is not None + for layer_idx, layer in enumerate(network.layers): + embedding_net_variables[ + f"filter_type_all{suffix}/matrix_{layer_idx + 1}{name_suffix}" + ] = layer.w + embedding_net_variables[ + f"filter_type_all{suffix}/bias_{layer_idx + 1}{name_suffix}" + ] = layer.b + if layer.idt is not None: + embedding_net_variables[ + f"filter_type_all{suffix}/idt_{layer_idx + 1}{name_suffix}" + ] = layer.idt.reshape(1, -1) + else: + # prevent keyError + embedding_net_variables[ + f"filter_type_all{suffix}/idt_{layer_idx + 1}{name_suffix}" + ] = 0.0 + return embedding_net_variables + @classmethod def deserialize(cls, data: dict, suffix: str = ""): """Deserialize the model. 
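Both helpers above rely on a fixed naming convention for the strip ("extra") embedding-net variables in the TF graph: `filter_type_all{suffix}/{matrix|bias|idt}_{layer}{name_suffix}`, where `name_suffix` comes from `get_extra_embedding_net_suffix` and is `_one_side_ebd` or `_two_side_ebd` depending on `type_one_side`. A small sketch of the key layout that `deserialize_network_strip` emits (the layer count and empty scope suffix are illustrative assumptions):

```python
# Illustrative only: reproduces the variable-name layout assumed by
# serialize_network_strip / deserialize_network_strip in this patch.
def strip_variable_keys(n_layers: int, suffix: str = "", type_one_side: bool = False):
    name_suffix = "_one_side_ebd" if type_one_side else "_two_side_ebd"
    keys = []
    for layer in range(1, n_layers + 1):  # layers are 1-indexed in the TF graph
        for weight in ("matrix", "bias", "idt"):
            keys.append(f"filter_type_all{suffix}/{weight}_{layer}{name_suffix}")
    return keys

# e.g. a three-layer strip embedding net with an empty scope suffix:
# ['filter_type_all/matrix_1_two_side_ebd', 'filter_type_all/bias_1_two_side_ebd',
#  'filter_type_all/idt_1_two_side_ebd', 'filter_type_all/matrix_2_two_side_ebd', ...]
print(strip_variable_keys(3))
```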
@@ -1685,6 +1823,11 @@ def deserialize(cls, data: dict, suffix: str = ""): ) data.pop("env_mat") variables = data.pop("@variables") + tebd_input_mode = data["tebd_input_mode"] + if tebd_input_mode in ["strip"]: + raise NotImplementedError( + "deserialization is unsupported by the native model when tebd_input_mode=='strip'" + ) descriptor = cls(**data) descriptor.embedding_net_variables = embedding_net_variables descriptor.attention_layer_variables = attention_layer_variables @@ -1713,9 +1856,15 @@ def serialize(self, suffix: str = "") -> dict: raise NotImplementedError( "Not implemented in class %s" % self.__class__.__name__ ) - if self.stripped_type_embedding: + if self.stripped_type_embedding and type(self) is not DescrptDPA1Compat: + # only DescrptDPA1Compat can serialize when tebd_input_mode=='strip' + raise NotImplementedError( + "serialization is unsupported by the native model when tebd_input_mode=='strip'" + ) + # todo support serialization when tebd_input_mode=='strip' and type_one_side is True + if self.stripped_type_embedding and self.type_one_side: raise NotImplementedError( - "stripped_type_embedding is unsupported by the native model" + "serialization is unsupported when tebd_input_mode=='strip' and type_one_side is True" ) if (self.original_sel != self.sel_a).any(): raise NotImplementedError( @@ -1728,7 +1877,7 @@ def serialize(self, suffix: str = "") -> dict: assert self.davg is not None assert self.dstd is not None - return { + data = { "@class": "Descriptor", "type": "se_atten", "@version": 1, @@ -1746,6 +1895,7 @@ def serialize(self, suffix: str = "") -> dict: "activation_function": self.activation_function_name, "resnet_dt": self.filter_resnet_dt, "smooth_type_embedding": self.smooth, + "tebd_input_mode": self.tebd_input_mode, "trainable_ln": self.trainable_ln, "ln_eps": self.ln_eps, "precision": self.filter_precision.name, @@ -1785,6 +1935,27 @@ def serialize(self, suffix: str = "") -> dict: "type_one_side": self.type_one_side, "spin": self.spin, } + if self.tebd_input_mode in ["strip"]: + assert ( + type(self) is DescrptDPA1Compat + ), "only DescrptDPA1Compat can serialize when tebd_input_mode=='strip'" + data.update( + { + "embeddings_strip": self.serialize_network_strip( + ntypes=self.ntypes, + ndim=0, + in_dim=2 + * self.tebd_dim, # only DescrptDPA1Compat has this attribute + neuron=self.filter_neuron, + activation_function=self.activation_function_name, + resnet_dt=self.filter_resnet_dt, + variables=self.two_side_embeeding_net_variables, + suffix=suffix, + type_one_side=self.type_one_side, + ) + } + ) + return data class DescrptDPA1Compat(DescrptSeAtten): @@ -1810,8 +1981,9 @@ class DescrptDPA1Compat(DescrptSeAtten): tebd_dim: int Dimension of the type embedding tebd_input_mode: str - (Only support `concat` to keep consistent with other backend references.) - The way to mix the type embeddings. + The input mode of the type embedding. Supported modes are [`concat`, `strip`]. + - `concat`: Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. + - `strip`: Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. resnet_dt: bool Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) @@ -1902,10 +2074,6 @@ def __init__( seed: Optional[int] = None, uniform_seed: bool = False, ) -> None: - if tebd_input_mode != "concat": - raise NotImplementedError( - "Only support tebd_input_mode == `concat` in this version." 
- ) if not normalize: raise NotImplementedError("Only support normalize == True in this version.") if temperature != 1.0: @@ -1943,14 +2111,13 @@ def __init__( attn_dotr=attn_dotr, attn_mask=attn_mask, multi_task=True, - stripped_type_embedding=False, trainable_ln=trainable_ln, ln_eps=ln_eps, smooth_type_embedding=smooth_type_embedding, + tebd_input_mode=tebd_input_mode, env_protection=env_protection, ) self.tebd_dim = tebd_dim - self.tebd_input_mode = tebd_input_mode self.scaling_factor = scaling_factor self.normalize = normalize self.temperature = temperature @@ -2089,9 +2256,20 @@ def deserialize(cls, data: dict, suffix: str = ""): data.pop("env_mat") variables = data.pop("@variables") type_embedding = data.pop("type_embedding") + tebd_input_mode = data["tebd_input_mode"] + type_one_side = data["type_one_side"] + if tebd_input_mode in ["strip"]: + two_side_embeeding_net_variables = cls.deserialize_network_strip( + data.pop("embeddings_strip"), + suffix=suffix, + type_one_side=type_one_side, + ) + else: + two_side_embeeding_net_variables = None descriptor = cls(**data) descriptor.embedding_net_variables = embedding_net_variables descriptor.attention_layer_variables = attention_layer_variables + descriptor.two_side_embeeding_net_variables = two_side_embeeding_net_variables descriptor.davg = variables["davg"].reshape( descriptor.ntypes, descriptor.ndescrpt ) @@ -2121,7 +2299,6 @@ def serialize(self, suffix: str = "") -> dict: { "type": "dpa1", "tebd_dim": self.tebd_dim, - "tebd_input_mode": self.tebd_input_mode, "scaling_factor": self.scaling_factor, "normalize": self.normalize, "temperature": self.temperature, diff --git a/deepmd/tf/descriptor/se_atten_v2.py b/deepmd/tf/descriptor/se_atten_v2.py index 01c4d93ad8..61e672788e 100644 --- a/deepmd/tf/descriptor/se_atten_v2.py +++ b/deepmd/tf/descriptor/se_atten_v2.py @@ -109,7 +109,7 @@ def __init__( attn_dotr=attn_dotr, attn_mask=attn_mask, multi_task=multi_task, - stripped_type_embedding=True, + tebd_input_mode="strip", smooth_type_embedding=True, **kwargs, ) diff --git a/deepmd/tf/nvnmd/data/data.py b/deepmd/tf/nvnmd/data/data.py index 9e6dd4cc89..7f2c9ef5e9 100644 --- a/deepmd/tf/nvnmd/data/data.py +++ b/deepmd/tf/nvnmd/data/data.py @@ -332,7 +332,7 @@ "descriptor": { "seed": 1, "type": "se_atten", - "stripped_type_embedding": True, + "tebd_input_mode": "strip", "sel": 128, "rcut": 7.0, "rcut_smth": 0.5, diff --git a/deepmd/tf/utils/graph.py b/deepmd/tf/utils/graph.py index 53de9c9ce2..287c1575d5 100644 --- a/deepmd/tf/utils/graph.py +++ b/deepmd/tf/utils/graph.py @@ -216,60 +216,72 @@ def get_extra_embedding_net_suffix(type_one_side: bool): return extra_suffix -def get_variables_from_graph_def_as_numpy_array(graph_def: tf.GraphDef, pattern: str): - """Get variables from the given tf.GraphDef object, with numpy array returns. +def get_extra_embedding_net_nodes_from_graph_def( + graph_def: tf.GraphDef, + suffix: str = "", + extra_suffix: str = "", +) -> Dict: + """Get the extra embedding net nodes with the given tf.GraphDef object. 
Parameters ---------- graph_def The input tf.GraphDef object - pattern : str - The name of variable + suffix : str, optional + The scope suffix + extra_suffix : str + The extra scope suffix Returns ------- - np.ndarray - The numpy array of the variable + Dict + The embedding net nodes within the given tf.GraphDef object """ - node = get_pattern_nodes_from_graph_def(graph_def, pattern)[pattern] - return tf.make_ndarray(node) + embedding_net_pattern_strip = str( + rf"filter_type_(all)/(matrix)_(\d+){extra_suffix}|" + rf"filter_type_(all)/(bias)_(\d+){extra_suffix}|" + rf"filter_type_(all)/(idt)_(\d+){extra_suffix}|" + )[:-1] + if suffix != "": + embedding_net_pattern_strip = ( + embedding_net_pattern_strip.replace("/(idt)", suffix + "/(idt)") + .replace("/(bias)", suffix + "/(bias)") + .replace("/(matrix)", suffix + "/(matrix)") + ) + else: + embedding_net_pattern_strip = embedding_net_pattern_strip + + embedding_net_nodes_strip = get_pattern_nodes_from_graph_def( + graph_def, embedding_net_pattern_strip + ) + return embedding_net_nodes_strip def get_extra_embedding_net_variables_from_graph_def( - graph_def: tf.GraphDef, suffix: str, extra_suffix: str, layer_size: int -): - """Get extra embedding net variables from the given tf.GraphDef object. - The "extra embedding net" means the embedding net with only type embeddings input, - which occurs in "se_atten_v2" and "se_a_ebd_v2" descriptor. + graph_def: tf.GraphDef, + suffix: str = "", + extra_suffix: str = "", +) -> Dict: + """Get the embedding net variables with the given tf.GraphDef object. Parameters ---------- graph_def The input tf.GraphDef object - suffix : str - The "common" suffix in the descriptor - extra_suffix : str - This value depends on the value of "type_one_side". - It should always be "_one_side_ebd" or "_two_side_ebd" - layer_size : int - The layer size of the embedding net + suffix : str, optional + The suffix of the scope + extra_suffix + The extra scope suffix Returns ------- Dict - The extra embedding net variables within the given tf.GraphDef object + The embedding net variables within the given tf.GraphDef object """ - extra_embedding_net_variables = {} - for i in range(1, layer_size + 1): - matrix_pattern = f"filter_type_all{suffix}/matrix_{i}{extra_suffix}" - extra_embedding_net_variables[matrix_pattern] = ( - get_variables_from_graph_def_as_numpy_array(graph_def, matrix_pattern) - ) - bias_pattern = f"filter_type_all{suffix}/bias_{i}{extra_suffix}" - extra_embedding_net_variables[bias_pattern] = ( - get_variables_from_graph_def_as_numpy_array(graph_def, bias_pattern) - ) - return extra_embedding_net_variables + extra_embedding_net_nodes = get_extra_embedding_net_nodes_from_graph_def( + graph_def, extra_suffix=extra_suffix, suffix=suffix + ) + return convert_tensor_to_ndarray_in_dict(extra_embedding_net_nodes) def get_embedding_net_variables(model_file: str, suffix: str = "") -> Dict: diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 94c225010a..a2f2b65eb5 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -424,7 +424,7 @@ def descrpt_se_atten_common_args(): + "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." ) doc_attn = "The length of hidden vectors in attention layers" - doc_attn_layer = "The number of attention layers. 
Note that model compression of `se_atten` is only enabled when attn_layer==0 and stripped_type_embedding is True" + doc_attn_layer = "The number of attention layers. Note that model compression of `se_atten` is only enabled when attn_layer==0 and tebd_input_mode=='strip'" doc_attn_dotr = "Whether to do dot product with the normalized relative coordinates" doc_attn_mask = "Whether to do mask on the diagonal in the attention matrix" @@ -475,7 +475,6 @@ def descrpt_se_atten_common_args(): @descrpt_args_plugin.register("se_atten", alias=["dpa1"]) def descrpt_se_atten_args(): - doc_stripped_type_embedding = "Whether to strip the type embedding into a separated embedding network. Setting it to `False` will fall back to the previous version of `se_atten` which is non-compressible." doc_smooth_type_embedding = f"Whether to use smooth process in attention weights calculation. {doc_only_tf_supported} When using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True." doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `se_atten` descriptor or `atom_ener` in the energy fitting is used" doc_trainable_ln = ( @@ -495,17 +494,18 @@ def descrpt_se_atten_args(): doc_concat_output_tebd = ( "Whether to concat type embedding at the output of the descriptor." ) - doc_deprecated = "This feature will be removed in a future release." + doc_tebd_input_mode = ( + "The input mode of the type embedding. Supported modes are [`concat`, `strip`]." + "- `concat`: Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. " + "When `type_one_side` is False, the input is `input_ij = concat([r_ij, tebd_j, tebd_i])`. When `type_one_side` is True, the input is `input_ij = concat([r_ij, tebd_j])`. " + "The output is `out_ij = embeding(input_ij)` for the pair-wise representation of atom i with neighbor j." + "- `strip`: Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. " + f"When `type_one_side` is False, the input is `input_t = concat([tebd_j, tebd_i])`. {doc_only_pt_supported} When `type_one_side` is True, the input is `input_t = tebd_j`. " + "The output is `out_ij = embeding_t(input_t) * embeding_s(r_ij) + embeding_s(r_ij)` for the pair-wise representation of atom i with neighbor j." + ) return [ *descrpt_se_atten_common_args(), - Argument( - "stripped_type_embedding", - bool, - optional=True, - default=False, - doc=doc_only_tf_supported + doc_stripped_type_embedding, - ), Argument( "smooth_type_embedding", bool, @@ -534,7 +534,7 @@ def descrpt_se_atten_args(): str, optional=True, default="concat", - doc=doc_only_pt_supported + doc_deprecated, + doc=doc_tebd_input_mode, ), Argument( "scaling_factor", @@ -2311,6 +2311,23 @@ def gen_args(**kwargs) -> List[Argument]: ] +def backend_compat(data): + data = data.copy() + # stripped_type_embedding in old DescrptSeAtten + if data["model"]["descriptor"].get("type", "se_e2_a") == "se_atten" and data[ + "model" + ]["descriptor"].pop("stripped_type_embedding", False): + if "tebd_input_mode" not in data["model"]["descriptor"]: + data["model"]["descriptor"]["tebd_input_mode"] = "strip" + elif data["model"]["descriptor"]["tebd_input_mode"] != "strip": + raise ValueError( + "When setting stripped_type_embedding == True, tebd_input_mode should be 'strip'!" 
+ ) + else: + pass + return data + + def normalize_multi_task(data): # single-task or multi-task mode if data["model"].get("type", "standard") not in ("standard", "multi"): @@ -2507,6 +2524,7 @@ def normalize_fitting_weight(fitting_keys, data_keys, fitting_weight=None): def normalize(data): data = normalize_multi_task(data) + data = backend_compat(data) base = Argument("base", dict, gen_args()) data = base.normalize_value(data, trim_pattern="_*") diff --git a/doc/model/train-se-atten.md b/doc/model/train-se-atten.md index 59333eb0da..6e70caa8ea 100644 --- a/doc/model/train-se-atten.md +++ b/doc/model/train-se-atten.md @@ -156,7 +156,7 @@ An example of the DPA-1 descriptor is provided as follows We highly recommend using the version 2.0 of the attention-based descriptor `"se_atten_v2"`, which is inherited from `"se_atten"` but with the following parameter modifications: ```json - "stripped_type_embedding": true, + "tebd_input_mode": 'strip', "smooth_type_embedding": true, "set_davg_zero": false ``` diff --git a/examples/water/se_atten_dpa1_compat/input.json b/examples/water/se_atten_dpa1_compat/input.json index 90c597e586..3018096ae5 100644 --- a/examples/water/se_atten_dpa1_compat/input.json +++ b/examples/water/se_atten_dpa1_compat/input.json @@ -7,7 +7,7 @@ ], "descriptor": { "type": "se_atten", - "stripped_type_embedding": false, + "tebd_input_mode": "concat", "sel": 120, "rcut_smth": 0.50, "rcut": 6.00, diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index c0ca46c91e..a2d4ca074f 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -40,7 +40,7 @@ @parameterized( (4,), # tebd_dim - ("concat",), # tebd_input_mode + ("concat", "strip"), # tebd_input_mode (True,), # resnet_dt (True, False), # type_one_side (20,), # attn @@ -181,6 +181,7 @@ def skip_tf(self) -> bool: or not normalize or temperature != 1.0 or (excluded_types != [] and attn_layer > 0) + or (type_one_side and tebd_input_mode == "strip") # not consistent yet ) tf_class = DescrptDPA1TF diff --git a/source/tests/pt/model/test_dpa1.py b/source/tests/pt/model/test_dpa1.py index 7567f18593..c1b6f97b26 100644 --- a/source/tests/pt/model/test_dpa1.py +++ b/source/tests/pt/model/test_dpa1.py @@ -39,11 +39,12 @@ def test_consistency( dstd = rng.normal(size=(self.nt, nnei, 4)) dstd = 0.1 + np.abs(dstd) - for idt, prec, sm, to in itertools.product( + for idt, prec, sm, to, tm in itertools.product( [False, True], # resnet_dt ["float64", "float32"], # precision [False, True], # smooth_type_embedding [False, True], # type_one_side + ["concat", "strip"], # tebd_input_mode ): dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) @@ -60,6 +61,7 @@ def test_consistency( resnet_dt=idt, smooth_type_embedding=sm, type_one_side=to, + tebd_input_mode=tm, old_impl=False, ).to(env.DEVICE) dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) @@ -98,7 +100,7 @@ def test_consistency( err_msg=err_msg, ) # old impl - if idt is False and prec == "float64" and to is False: + if idt is False and prec == "float64" and to is False and tm == "concat": dd3 = DescrptDPA1( self.rcut, self.rcut_smth, @@ -163,16 +165,21 @@ def test_jit( dstd = rng.normal(size=(self.nt, nnei, 4)) dstd = 0.1 + np.abs(dstd) - for idt, prec, sm, to in itertools.product( - [False, True], - ["float64", "float32"], - [False, True], - [False, True], + for idt, prec, sm, to, tm in itertools.product( + [ + False, + ], # resnet_dt + [ + "float64", + 
], # precision + [False, True], # smooth_type_embedding + [False, True], # type_one_side + ["concat", "strip"], # tebd_input_mode ): dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) err_msg = f"idt={idt} prec={prec}" - # sea new impl + # dpa1 new impl dd0 = DescrptDPA1( self.rcut, self.rcut_smth, @@ -182,6 +189,7 @@ def test_jit( resnet_dt=idt, smooth_type_embedding=sm, type_one_side=to, + tebd_input_mode=tm, old_impl=False, ) dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) diff --git a/source/tests/tf/test_data_large_batch.py b/source/tests/tf/test_data_large_batch.py index dad6bbf252..1232f8b1db 100644 --- a/source/tests/tf/test_data_large_batch.py +++ b/source/tests/tf/test_data_large_batch.py @@ -309,7 +309,7 @@ def test_stripped_data_mixed_type(self): jdata["model"]["descriptor"].pop("type", None) jdata["model"]["descriptor"]["ntypes"] = 2 - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True) jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes() jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out() @@ -507,7 +507,7 @@ def test_compressible_data_mixed_type(self): jdata["model"]["descriptor"].pop("type", None) jdata["model"]["descriptor"]["ntypes"] = 2 - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 0 descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True) jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes() diff --git a/source/tests/tf/test_descrpt_se_atten.py b/source/tests/tf/test_descrpt_se_atten.py index 7a1bfd18f6..da5eb05650 100644 --- a/source/tests/tf/test_descrpt_se_atten.py +++ b/source/tests/tf/test_descrpt_se_atten.py @@ -421,7 +421,7 @@ def test_stripped_type_embedding_descriptor_two_sides(self): "resnet_dt": False, "seed": 1, } - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" # init models typeebd = TypeEmbedNet( @@ -588,7 +588,7 @@ def test_compressible_descriptor_two_sides(self): jdata["model"]["descriptor"]["neuron"] = [5, 5, 5] jdata["model"]["descriptor"]["axis_neuron"] = 2 jdata["model"]["descriptor"]["attn_layer"] = 0 - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" typeebd_param = { "neuron": [5], "resnet_dt": False, diff --git a/source/tests/tf/test_finetune_se_atten.py b/source/tests/tf/test_finetune_se_atten.py index 40fc5b68a3..ebb858b0bb 100644 --- a/source/tests/tf/test_finetune_se_atten.py +++ b/source/tests/tf/test_finetune_se_atten.py @@ -146,15 +146,15 @@ def setUpClass(cls) -> None: if not parse_version(tf.__version__) < parse_version("1.15"): def previous_se_atten(jdata): - jdata["model"]["descriptor"]["stripped_type_embedding"] = False + jdata["model"]["descriptor"]["tebd_input_mode"] = "concat" jdata["model"]["descriptor"]["attn_layer"] = 2 def stripped_model(jdata): - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 2 def compressible_model(jdata): - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 0 models = 
[previous_se_atten, stripped_model, compressible_model] diff --git a/source/tests/tf/test_init_frz_model_se_atten.py b/source/tests/tf/test_init_frz_model_se_atten.py index a114deffc8..25f629511c 100644 --- a/source/tests/tf/test_init_frz_model_se_atten.py +++ b/source/tests/tf/test_init_frz_model_se_atten.py @@ -136,15 +136,15 @@ def _init_models(model_setup, i): if not parse_version(tf.__version__) < parse_version("1.15"): def previous_se_atten(jdata): - jdata["model"]["descriptor"]["stripped_type_embedding"] = False + jdata["model"]["descriptor"]["tebd_input_mode"] = "concat" jdata["model"]["descriptor"]["attn_layer"] = 2 def stripped_model(jdata): - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 2 def compressible_model(jdata): - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 0 diff --git a/source/tests/tf/test_model_compression_se_atten.py b/source/tests/tf/test_model_compression_se_atten.py index 03ddedad39..29328056c3 100644 --- a/source/tests/tf/test_model_compression_se_atten.py +++ b/source/tests/tf/test_model_compression_se_atten.py @@ -79,7 +79,7 @@ def _init_models(): jdata["model"]["descriptor"] = {} jdata["model"]["descriptor"]["type"] = "se_atten" jdata["model"]["descriptor"]["precision"] = tests[i]["se_atten precision"] - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["sel"] = 120 jdata["model"]["descriptor"]["attn_layer"] = 0 jdata["model"]["descriptor"]["smooth_type_embedding"] = tests[i][ @@ -128,7 +128,7 @@ def _init_models_exclude_types(): jdata["model"]["descriptor"]["type"] = "se_atten" jdata["model"]["descriptor"]["exclude_types"] = [[0, 1]] jdata["model"]["descriptor"]["precision"] = tests[i]["se_atten precision"] - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["sel"] = 120 jdata["model"]["descriptor"]["attn_layer"] = 0 jdata["model"]["type_embedding"] = {} diff --git a/source/tests/tf/test_model_se_atten.py b/source/tests/tf/test_model_se_atten.py index d75dc0cfff..a4b6575ecb 100644 --- a/source/tests/tf/test_model_se_atten.py +++ b/source/tests/tf/test_model_se_atten.py @@ -293,7 +293,7 @@ def test_compressible_model(self): jdata["model"]["descriptor"].pop("type", None) jdata["model"]["descriptor"]["ntypes"] = 2 - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 0 descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True) jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes() @@ -461,7 +461,7 @@ def test_compressible_exclude_types(self): # successful descrpt = DescrptSeAtten(ntypes=ntypes, **jdata["model"]["descriptor"]) typeebd_param = jdata["model"]["type_embedding"] - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 0 typeebd = TypeEmbedNet( ntypes=descrpt.get_ntypes(), @@ -524,7 +524,7 @@ def test_stripped_type_embedding_model(self): jdata["model"]["descriptor"].pop("type", None) jdata["model"]["descriptor"]["ntypes"] = 2 - 
jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 2 descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True) jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes() @@ -695,7 +695,7 @@ def test_stripped_type_embedding_exclude_types(self): # successful descrpt = DescrptSeAtten(ntypes=ntypes, **jdata["model"]["descriptor"]) typeebd_param = jdata["model"]["type_embedding"] - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 2 typeebd = TypeEmbedNet( ntypes=descrpt.get_ntypes(), @@ -763,7 +763,7 @@ def test_smoothness_of_stripped_type_embedding_smooth_model(self): jdata["model"]["descriptor"].pop("type", None) jdata["model"]["descriptor"]["ntypes"] = 2 - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["smooth_type_embedding"] = True jdata["model"]["descriptor"]["attn_layer"] = 1 jdata["model"]["descriptor"]["rcut"] = 6.0 @@ -909,7 +909,7 @@ def test_smoothness_of_stripped_type_embedding_smooth_model_excluded_types(self) jdata["model"]["descriptor"].pop("type", None) jdata["model"]["descriptor"]["ntypes"] = 2 - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["smooth_type_embedding"] = True jdata["model"]["descriptor"]["attn_layer"] = 1 jdata["model"]["descriptor"]["rcut"] = 6.0 From a2301985f50d711e621a6f0f2b0cf14d91bd2d66 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 25 Apr 2024 21:58:37 +0800 Subject: [PATCH 02/37] Update train-se-atten.md --- doc/model/train-se-atten.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/model/train-se-atten.md b/doc/model/train-se-atten.md index 6e70caa8ea..89c2068ab2 100644 --- a/doc/model/train-se-atten.md +++ b/doc/model/train-se-atten.md @@ -156,7 +156,7 @@ An example of the DPA-1 descriptor is provided as follows We highly recommend using the version 2.0 of the attention-based descriptor `"se_atten_v2"`, which is inherited from `"se_atten"` but with the following parameter modifications: ```json - "tebd_input_mode": 'strip', + "tebd_input_mode": "strip", "smooth_type_embedding": true, "set_davg_zero": false ``` From 515778100ca2c217b1ed72eb4013c1af50c3d40c Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 25 Apr 2024 22:13:38 +0800 Subject: [PATCH 03/37] Update graph.py --- deepmd/tf/utils/graph.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/deepmd/tf/utils/graph.py b/deepmd/tf/utils/graph.py index 287c1575d5..a891506e95 100644 --- a/deepmd/tf/utils/graph.py +++ b/deepmd/tf/utils/graph.py @@ -248,8 +248,6 @@ def get_extra_embedding_net_nodes_from_graph_def( .replace("/(bias)", suffix + "/(bias)") .replace("/(matrix)", suffix + "/(matrix)") ) - else: - embedding_net_pattern_strip = embedding_net_pattern_strip embedding_net_nodes_strip = get_pattern_nodes_from_graph_def( graph_def, embedding_net_pattern_strip From f780d58c48eade7e72c968d14ee9182a2e5bcc55 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 25 Apr 2024 22:16:26 +0800 Subject: [PATCH 04/37] Update deepmd/utils/argcheck.py Co-authored-by: coderabbitai[bot] 
<136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com> --- deepmd/utils/argcheck.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index a2f2b65eb5..aa6019b13e 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2321,7 +2321,7 @@ def backend_compat(data): data["model"]["descriptor"]["tebd_input_mode"] = "strip" elif data["model"]["descriptor"]["tebd_input_mode"] != "strip": raise ValueError( - "When setting stripped_type_embedding == True, tebd_input_mode should be 'strip'!" + "Conflict detected: 'stripped_type_embedding' is set to True, but 'tebd_input_mode' is not 'strip'. Please ensure 'tebd_input_mode' is set to 'strip' when 'stripped_type_embedding' is True." ) else: pass From 3b3d25e04b1c35fd7a22baa6c5df36005ac5cd67 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 25 Apr 2024 22:17:20 +0800 Subject: [PATCH 05/37] Update deepmd/pt/model/descriptor/se_atten.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com> --- deepmd/pt/model/descriptor/se_atten.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index d261312f98..fa3c6c1e3c 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -234,10 +234,7 @@ def __init__( ) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) - if not self.type_one_side: - self.tebd_dim_input = self.tebd_dim * 2 - else: - self.tebd_dim_input = self.tebd_dim + self.tebd_dim_input = self.tebd_dim * (2 if not self.type_one_side else 1) if self.tebd_input_mode in ["concat"]: self.embd_input_dim = 1 + self.tebd_dim_input else: From cf841f21c37f7090a6a1f7521198c6cd306170b8 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 25 Apr 2024 22:18:34 +0800 Subject: [PATCH 06/37] Update deepmd/tf/descriptor/se_a.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com> --- deepmd/tf/descriptor/se_a.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/tf/descriptor/se_a.py b/deepmd/tf/descriptor/se_a.py index 31e3cc92f8..cc94e65990 100644 --- a/deepmd/tf/descriptor/se_a.py +++ b/deepmd/tf/descriptor/se_a.py @@ -1058,7 +1058,7 @@ def _filter_lower( ) if self.compress: raise RuntimeError( - "compression of type embedded descriptor is not supported when tebd_input_mode != 'strip'" + "compression of type embedded descriptor is not supported when tebd_input_mode is not set to 'strip'" ) # natom x 4 x outputs_size if nvnmd_cfg.enable: From a9e24d9f08075d2d92ca57e718d8b34920dd36e7 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 25 Apr 2024 22:19:20 +0800 Subject: [PATCH 07/37] Update deepmd/tf/descriptor/se_a.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com> --- deepmd/tf/descriptor/se_a.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/tf/descriptor/se_a.py b/deepmd/tf/descriptor/se_a.py index cc94e65990..f8acd1dba1 100644 --- a/deepmd/tf/descriptor/se_a.py +++ b/deepmd/tf/descriptor/se_a.py @@ -1427,7 
+1427,7 @@ def serialize(self, suffix: str = "") -> dict: ) if self.stripped_type_embedding: raise NotImplementedError( - "tebd_input_mode=='strip' is unsupported by the native model" + "Serialization is unsupported when tebd_input_mode is set to 'strip'" ) if (self.original_sel != self.sel_a).any(): raise NotImplementedError( From 764cab7e544855a04cfa0a9106aac44c7df9160c Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 25 Apr 2024 22:23:51 +0800 Subject: [PATCH 08/37] Update deepmd/tf/descriptor/se_atten.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com> --- deepmd/tf/descriptor/se_atten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 8d93d9659a..70602f4c85 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -195,7 +195,7 @@ def __init__( env_protection: float = 0.0, # not implement!! **kwargs, ) -> None: - # to be compat with old option of `stripped_type_embedding` + # Ensure compatibility with the deprecated `stripped_type_embedding` option. stripped_type_embedding = tebd_input_mode == "strip" if not set_davg_zero and not ( stripped_type_embedding and smooth_type_embedding From 0b9cea16cc31b876c6ab4e7f30c39c23e1af847f Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 25 Apr 2024 22:24:43 +0800 Subject: [PATCH 09/37] Update deepmd/tf/descriptor/se_atten.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com> --- deepmd/tf/descriptor/se_atten.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 70602f4c85..8cdf626cf3 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -1825,8 +1825,8 @@ def deserialize(cls, data: dict, suffix: str = ""): variables = data.pop("@variables") tebd_input_mode = data["tebd_input_mode"] if tebd_input_mode in ["strip"]: - raise NotImplementedError( - "deserialization is unsupported by the native model when tebd_input_mode=='strip'" + raise ValueError( + "Deserialization is unsupported for `tebd_input_mode='strip'` in the native model." ) descriptor = cls(**data) descriptor.embedding_net_variables = embedding_net_variables From 1e86b75bd8ccc4d02d71b5e9a71a5416ede2eca6 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 26 Apr 2024 00:40:24 +0800 Subject: [PATCH 10/37] Update docs --- deepmd/dpmodel/descriptor/dpa1.py | 6 +++--- deepmd/pt/model/descriptor/dpa1.py | 6 +++--- deepmd/pt/model/descriptor/se_atten.py | 6 +++--- deepmd/tf/descriptor/se_atten.py | 12 ++++++------ deepmd/utils/argcheck.py | 14 ++++++++------ 5 files changed, 23 insertions(+), 21 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 2f3494c2b0..4bef4a20ef 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -135,9 +135,9 @@ class DescrptDPA1(NativeOP, BaseDescriptor): tebd_dim: int Dimension of the type embedding tebd_input_mode: str - The input mode of the type embedding. Supported modes are [`concat`, `strip`]. - - `concat`: Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. 
- - `strip`: Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. + The input mode of the type embedding. Supported modes are ["concat", "strip"]. + - "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. + - "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. resnet_dt: bool Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 42e74cacc1..5caf583a23 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -117,9 +117,9 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module): tebd_dim: int Dimension of the type embedding tebd_input_mode: str - The input mode of the type embedding. Supported modes are [`concat`, `strip`]. - - `concat`: Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. - - `strip`: Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. + The input mode of the type embedding. Supported modes are ["concat", "strip"]. + - "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. + - "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. resnet_dt: bool Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index fa3c6c1e3c..335f914349 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -106,9 +106,9 @@ def __init__( tebd_dim : int Dimension of the type embedding tebd_input_mode : str - The input mode of the type embedding. Supported modes are [`concat`, `strip`]. - - `concat`: Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. - - `strip`: Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. + The input mode of the type embedding. Supported modes are ["concat", "strip"]. + - "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. + - "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. resnet_dt : bool Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 8cdf626cf3..d91a490855 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -145,9 +145,9 @@ class DescrptSeAtten(DescrptSeA): multi_task: bool If the model has multi fitting nets to train. tebd_input_mode: str - The input mode of the type embedding. Supported modes are [`concat`, `strip`]. - - `concat`: Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. - - `strip`: Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. + The input mode of the type embedding. Supported modes are ["concat", "strip"]. 
+ - "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. + - "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. Default value will be `strip` in `se_atten_v2` descriptor. smooth_type_embedding: bool Whether to use smooth process in attention weights calculation. @@ -1981,9 +1981,9 @@ class DescrptDPA1Compat(DescrptSeAtten): tebd_dim: int Dimension of the type embedding tebd_input_mode: str - The input mode of the type embedding. Supported modes are [`concat`, `strip`]. - - `concat`: Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. - - `strip`: Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. + The input mode of the type embedding. Supported modes are ["concat", "strip"]. + - "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. + - "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. resnet_dt: bool Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index aa6019b13e..1e24091018 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -495,11 +495,11 @@ def descrpt_se_atten_args(): "Whether to concat type embedding at the output of the descriptor." ) doc_tebd_input_mode = ( - "The input mode of the type embedding. Supported modes are [`concat`, `strip`]." - "- `concat`: Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. " + "The input mode of the type embedding. Supported modes are ['concat', 'strip']." + "- 'concat': Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. " "When `type_one_side` is False, the input is `input_ij = concat([r_ij, tebd_j, tebd_i])`. When `type_one_side` is True, the input is `input_ij = concat([r_ij, tebd_j])`. " "The output is `out_ij = embeding(input_ij)` for the pair-wise representation of atom i with neighbor j." - "- `strip`: Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. " + "- 'strip': Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. " f"When `type_one_side` is False, the input is `input_t = concat([tebd_j, tebd_i])`. {doc_only_pt_supported} When `type_one_side` is True, the input is `input_t = tebd_j`. " "The output is `out_ij = embeding_t(input_t) * embeding_s(r_ij) + embeding_s(r_ij)` for the pair-wise representation of atom i with neighbor j." 
) @@ -2314,9 +2314,11 @@ def gen_args(**kwargs) -> List[Argument]: def backend_compat(data): data = data.copy() # stripped_type_embedding in old DescrptSeAtten - if data["model"]["descriptor"].get("type", "se_e2_a") == "se_atten" and data[ - "model" - ]["descriptor"].pop("stripped_type_embedding", False): + if ( + "descriptor" in data["model"] + and data["model"]["descriptor"].get("type", "se_e2_a") == "se_atten" + and data["model"]["descriptor"].pop("stripped_type_embedding", False) + ): if "tebd_input_mode" not in data["model"]["descriptor"]: data["model"]["descriptor"]["tebd_input_mode"] = "strip" elif data["model"]["descriptor"]["tebd_input_mode"] != "strip": From f3056ee5003b2d33a1394857ee380c0364f405f4 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 26 Apr 2024 15:30:51 +0800 Subject: [PATCH 11/37] resolve conversations --- deepmd/dpmodel/descriptor/dpa1.py | 5 +---- deepmd/pt/model/descriptor/se_atten.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 4bef4a20ef..bdffd74fbe 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -285,10 +285,7 @@ def __init__( activation_function="Linear", precision=precision, ) - if not self.type_one_side: - self.tebd_dim_input = self.tebd_dim * 2 - else: - self.tebd_dim_input = self.tebd_dim + self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2 if self.tebd_input_mode in ["concat"]: self.embd_input_dim = 1 + self.tebd_dim_input else: diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 335f914349..9bf4788bf2 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -234,7 +234,7 @@ def __init__( ) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) - self.tebd_dim_input = self.tebd_dim * (2 if not self.type_one_side else 1) + self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2 if self.tebd_input_mode in ["concat"]: self.embd_input_dim = 1 + self.tebd_dim_input else: From 4e231e45cbf0c157c21eb32a78bc94a4083e23bf Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Sun, 28 Apr 2024 16:16:24 +0800 Subject: [PATCH 12/37] rf dpa2 with identity implement --- deepmd/dpmodel/descriptor/descriptor.py | 127 +++ deepmd/dpmodel/descriptor/dpa1.py | 619 ++++++++++----- deepmd/pt/model/descriptor/dpa1.py | 4 +- deepmd/pt/model/descriptor/dpa2.py | 453 ++++++++--- deepmd/pt/model/descriptor/repformer_layer.py | 578 +++++++++++--- .../descriptor/repformer_layer_old_impl.py | 743 ++++++++++++++++++ deepmd/pt/model/descriptor/repformers.py | 265 +++++-- deepmd/pt/model/descriptor/se_atten.py | 241 +++++- deepmd/pt/model/network/mlp.py | 78 +- deepmd/tf/descriptor/se_atten.py | 5 + deepmd/utils/argcheck.py | 276 +++++-- source/tests/pt/model/test_dpa1.py | 6 +- source/tests/pt/model/test_dpa2.py | 258 ++++++ source/tests/pt/model/test_env_mat.py | 5 + 14 files changed, 3030 insertions(+), 628 deletions(-) create mode 100644 deepmd/dpmodel/descriptor/descriptor.py create mode 100644 deepmd/pt/model/descriptor/repformer_layer_old_impl.py create mode 100644 source/tests/pt/model/test_dpa2.py diff --git a/deepmd/dpmodel/descriptor/descriptor.py b/deepmd/dpmodel/descriptor/descriptor.py new file mode 100644 index 0000000000..f11f406399 --- /dev/null +++ b/deepmd/dpmodel/descriptor/descriptor.py @@ -0,0 
+1,127 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from abc import ( + ABC, + abstractmethod, +) +from typing import ( + Callable, + Dict, + List, + Optional, + Union, +) + +import numpy as np + +from deepmd.utils.env_mat_stat import ( + StatItem, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.plugin import ( + make_plugin_registry, +) + +log = logging.getLogger(__name__) + + +class DescriptorBlock(ABC, make_plugin_registry("DescriptorBlock")): + """The building block of descriptor. + Given the input descriptor, provide with the atomic coordinates, + atomic types and neighbor list, calculate the new descriptor. + """ + + local_cluster = False + + def __new__(cls, *args, **kwargs): + if cls is DescriptorBlock: + try: + descrpt_type = kwargs["type"] + except KeyError: + raise KeyError("the type of DescriptorBlock should be set by `type`") + cls = cls.get_class_by_type(descrpt_type) + return super().__new__(cls) + + @abstractmethod + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + pass + + @abstractmethod + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + pass + + @abstractmethod + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + pass + + @abstractmethod + def get_ntypes(self) -> int: + """Returns the number of element types.""" + pass + + @abstractmethod + def get_dim_out(self) -> int: + """Returns the output dimension.""" + pass + + @abstractmethod + def get_dim_in(self) -> int: + """Returns the output dimension.""" + pass + + @abstractmethod + def get_dim_emb(self) -> int: + """Returns the embedding dimension.""" + pass + + def compute_input_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + path: Optional[DPPath] = None, + ): + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ + raise NotImplementedError + + def get_stats(self) -> Dict[str, StatItem]: + """Get the statistics of the descriptor.""" + raise NotImplementedError + + def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. 
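As a usage sketch for the plugin registry above: a hypothetical third-party block (the name `MyBlock` and the type string `"my_block"` are made up here) registers itself with the decorator and can then be resolved back through the base class, which is also how `__new__` dispatches on the `type` keyword. This assumes the module added by this patch is importable.

from deepmd.dpmodel.descriptor.descriptor import DescriptorBlock

@DescriptorBlock.register("my_block")   # hypothetical type name
class MyBlock(DescriptorBlock):
    # trivial stand-ins for the abstract interface; real blocks compute these
    def get_rcut(self): return 6.0
    def get_nsel(self): return 120
    def get_sel(self): return [120]
    def get_ntypes(self): return 2
    def get_dim_out(self): return 64
    def get_dim_in(self): return 0
    def get_dim_emb(self): return 64
    def call(self, nlist, extended_coord, extended_atype,
             extended_atype_embd=None, mapping=None):
        raise NotImplementedError

# the registry maps the type string back to the registered class
assert DescriptorBlock.get_class_by_type("my_block") is MyBlock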
+ """ + raise NotImplementedError + + @abstractmethod + def call( + self, + nlist: np.ndarray, + extended_coord: np.ndarray, + extended_atype: np.ndarray, + extended_atype_embd: Optional[np.ndarray] = None, + mapping: Optional[np.ndarray] = None, + ): + """Calculate DescriptorBlock.""" + pass diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index bdffd74fbe..4c1d3c6a0a 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -28,6 +28,7 @@ from typing import ( Any, + Callable, List, Optional, Tuple, @@ -49,6 +50,9 @@ from .base_descriptor import ( BaseDescriptor, ) +from .descriptor import ( + DescriptorBlock, +) def np_softmax(x, axis=-1): @@ -245,6 +249,292 @@ def __init__( if ln_eps is None: ln_eps = 1e-5 + self.se_atten = DescrptBlockSeAtten( + rcut, + rcut_smth, + sel, + ntypes, + neuron=neuron, + axis_neuron=axis_neuron, + tebd_dim=tebd_dim, + tebd_input_mode=tebd_input_mode, + set_davg_zero=set_davg_zero, + attn=attn, + attn_layer=attn_layer, + attn_dotr=attn_dotr, + attn_mask=False, + activation_function=activation_function, + precision=precision, + resnet_dt=resnet_dt, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + smooth=smooth_type_embedding, + type_one_side=type_one_side, + exclude_types=exclude_types, + env_protection=env_protection, + trainable_ln=trainable_ln, + ln_eps=ln_eps, + ) + self.type_embedding = TypeEmbedNet( + ntypes=ntypes, + neuron=[tebd_dim], + padding=True, + activation_function="Linear", + precision=precision, + ) + self.tebd_dim = tebd_dim + self.concat_output_tebd = concat_output_tebd + self.trainable = trainable + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.se_atten.get_rcut() + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return self.se_atten.get_nsel() + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.se_atten.get_sel() + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.se_atten.get_ntypes() + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + ret = self.se_atten.get_dim_out() + if self.concat_output_tebd: + ret += self.tebd_dim + return ret + + def get_dim_emb(self) -> int: + return self.se_atten.dim_emb + + def mixed_types(self) -> bool: + """If true, the discriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the discriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return self.se_atten.mixed_types() + + def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. 
+ """ + raise NotImplementedError + + @property + def dim_out(self): + return self.get_dim_out() + + @property + def dim_emb(self): + return self.get_dim_emb() + + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): + """Update mean and stddev for descriptor elements.""" + raise NotImplementedError + + def set_stat_mean_and_stddev( + self, + mean: np.ndarray, + stddev: np.ndarray, + ) -> None: + self.se_atten.mean = mean + self.se_atten.stddev = stddev + + def call( + self, + coord_ext, + atype_ext, + nlist, + mapping: Optional[np.ndarray] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nallx3) + atype_ext + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping from extended to lcoal region. not used by this descriptor. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + this descriptor returns None + h2 + The rotationally equivariant pair-partical representation. + this descriptor returns None + sw + The smooth switch function. + """ + del mapping + nf, nloc, nnei = nlist.shape + nall = coord_ext.reshape(nf, -1).shape[1] // 3 + # nf x nall x tebd_dim + atype_embd_ext = self.type_embedding.call()[atype_ext] + # nfnl x tebd_dim + atype_embd = atype_embd_ext[:, :nloc, :] + grrg, g2, h2, rot_mat, sw = self.se_atten( + nlist, + coord_ext, + atype_ext, + atype_embd_ext, + mapping=None, + ) + # nf x nloc x (ng x ng1 + tebd_dim) + if self.concat_output_tebd: + grrg = np.concatenate([grrg, atype_embd.reshape(nf, nloc, -1)], axis=-1) + return grrg, rot_mat, None, None, sw + + def serialize(self) -> dict: + """Serialize the descriptor to dict.""" + obj = self.se_atten + data = { + "@class": "Descriptor", + "type": "dpa1", + "@version": 1, + "rcut": obj.rcut, + "rcut_smth": obj.rcut_smth, + "sel": obj.sel, + "ntypes": obj.ntypes, + "neuron": obj.neuron, + "axis_neuron": obj.axis_neuron, + "tebd_dim": obj.tebd_dim, + "tebd_input_mode": obj.tebd_input_mode, + "set_davg_zero": obj.set_davg_zero, + "attn": obj.attn, + "attn_layer": obj.attn_layer, + "attn_dotr": obj.attn_dotr, + "attn_mask": False, + "activation_function": obj.activation_function, + "resnet_dt": obj.resnet_dt, + "scaling_factor": obj.scaling_factor, + "normalize": obj.normalize, + "temperature": obj.temperature, + "trainable_ln": obj.trainable_ln, + "ln_eps": obj.ln_eps, + "smooth_type_embedding": obj.smooth, + "type_one_side": obj.type_one_side, + "concat_output_tebd": self.concat_output_tebd, + # make deterministic + "precision": np.dtype(PRECISION_DICT[obj.precision]).name, + "embeddings": obj.embeddings.serialize(), + "attention_layers": obj.dpa1_attention.serialize(), + "env_mat": obj.env_mat.serialize(), + "type_embedding": self.type_embedding.serialize(), + "exclude_types": obj.exclude_types, + "env_protection": obj.env_protection, + "@variables": { + "davg": obj["davg"], + "dstd": obj["dstd"], + }, + ## to be updated when the options are supported. 
+ "trainable": self.trainable, + "spin": None, + } + if obj.tebd_input_mode in ["strip"]: + data.update({"embeddings_strip": obj.embeddings_strip.serialize()}) + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA1": + """Deserialize from dict.""" + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + type_embedding = data.pop("type_embedding") + attention_layers = data.pop("attention_layers") + env_mat = data.pop("env_mat") + tebd_input_mode = data["tebd_input_mode"] + if tebd_input_mode in ["strip"]: + embeddings_strip = data.pop("embeddings_strip") + else: + embeddings_strip = None + obj = cls(**data) + + obj.se_atten["davg"] = variables["davg"] + obj.se_atten["dstd"] = variables["dstd"] + obj.se_atten.embeddings = NetworkCollection.deserialize(embeddings) + if tebd_input_mode in ["strip"]: + obj.se_atten.embeddings_strip = NetworkCollection.deserialize( + embeddings_strip + ) + obj.type_embedding = TypeEmbedNet.deserialize(type_embedding) + obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize( + attention_layers + ) + return obj + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + local_jdata_cpy = local_jdata.copy() + return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, True) + + +@DescriptorBlock.register("se_atten") +class DescrptBlockSeAtten(NativeOP, DescriptorBlock): + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: Union[List[int], int], + ntypes: int, + neuron: List[int] = [25, 50, 100], + axis_neuron: int = 8, + tebd_dim: int = 8, + tebd_input_mode: str = "concat", + resnet_dt: bool = False, + type_one_side: bool = False, + attn: int = 128, + attn_layer: int = 2, + attn_dotr: bool = True, + attn_mask: bool = False, + exclude_types: List[List[int]] = [], + env_protection: float = 0.0, + set_davg_zero: bool = False, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + scaling_factor=1.0, + normalize: bool = True, + temperature: Optional[float] = None, + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + smooth: bool = True, + ) -> None: self.rcut = rcut self.rcut_smth = rcut_smth if isinstance(sel, int): @@ -258,13 +548,13 @@ def __init__( self.tebd_dim = tebd_dim self.tebd_input_mode = tebd_input_mode self.resnet_dt = resnet_dt - self.trainable = trainable self.trainable_ln = trainable_ln self.ln_eps = ln_eps self.type_one_side = type_one_side self.attn = attn self.attn_layer = attn_layer self.attn_dotr = attn_dotr + self.attn_mask = attn_mask self.exclude_types = exclude_types self.env_protection = env_protection self.set_davg_zero = set_davg_zero @@ -273,18 +563,10 @@ def __init__( self.scaling_factor = scaling_factor self.normalize = normalize self.temperature = temperature - self.smooth = smooth_type_embedding - self.concat_output_tebd = concat_output_tebd + self.smooth = smooth # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) - self.type_embedding = TypeEmbedNet( - ntypes=self.ntypes, - neuron=[self.tebd_dim], - padding=True, - activation_function="Linear", - precision=precision, - ) self.tebd_dim_input = self.tebd_dim if 
self.type_one_side else self.tebd_dim * 2 if self.tebd_input_mode in ["concat"]: self.embd_input_dim = 1 + self.tebd_dim_input @@ -334,52 +616,55 @@ def __init__( wanted_shape = (self.ntypes, self.nnei, 4) self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection) - self.davg = np.zeros(wanted_shape, dtype=PRECISION_DICT[self.precision]) - self.dstd = np.ones(wanted_shape, dtype=PRECISION_DICT[self.precision]) + self.mean = np.zeros(wanted_shape, dtype=PRECISION_DICT[self.precision]) + self.stddev = np.ones(wanted_shape, dtype=PRECISION_DICT[self.precision]) self.orig_sel = self.sel + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_in(self) -> int: + """Returns the output dimension.""" + return self.dim_in + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.dim_out + + def get_dim_emb(self) -> int: + """Returns the output dimension of embedding.""" + return self.filter_neuron[-1] + def __setitem__(self, key, value): if key in ("avg", "data_avg", "davg"): - self.davg = value + self.mean = value elif key in ("std", "data_std", "dstd"): - self.dstd = value + self.stddev = value else: raise KeyError(key) def __getitem__(self, key): if key in ("avg", "data_avg", "davg"): - return self.davg + return self.mean elif key in ("std", "data_std", "dstd"): - return self.dstd + return self.stddev else: raise KeyError(key) - @property - def dim_out(self): - """Returns the output dimension of this descriptor.""" - return self.get_dim_out() - - def get_dim_out(self): - """Returns the output dimension of this descriptor.""" - return ( - self.neuron[-1] * self.axis_neuron + self.tebd_dim - if self.concat_output_tebd - else self.neuron[-1] * self.axis_neuron - ) - - def get_dim_emb(self): - """Returns the embedding (g2) dimension of this descriptor.""" - return self.neuron[-1] - - def get_rcut(self): - """Returns cutoff radius.""" - return self.rcut - - def get_sel(self): - """Returns cutoff radius.""" - return self.sel - - def mixed_types(self): + def mixed_types(self) -> bool: """If true, the discriptor 1. assumes total number of atoms aligned across frames; 2. requires a neighbor list that does not distinguish different atomic types. @@ -391,22 +676,40 @@ def mixed_types(self): """ return True - def share_params(self, base_class, shared_level, resume=False): - """ - Share the parameters of self to the base_class with shared_level during multitask training. - If not start from checkpoint (resume is False), - some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. 
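The `__setitem__`/`__getitem__` pair above keeps the historical statistic keys working after the rename to `mean`/`stddev`. A minimal sketch of the intended access pattern, assuming this patch is applied; the cutoff, selection and type count are dummy values:

import numpy as np
from deepmd.dpmodel.descriptor.dpa1 import DescrptBlockSeAtten

block = DescrptBlockSeAtten(rcut=6.0, rcut_smth=0.5, sel=120, ntypes=2)
davg = np.zeros((2, 120, 4))     # shape follows wanted_shape = (ntypes, nnei, 4)
block["davg"] = davg             # "avg", "data_avg" and "davg" all write block.mean
assert block["std"] is block.stddev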
- """ - raise NotImplementedError + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.filter_neuron[-1] * self.axis_neuron - def get_ntypes(self) -> int: - """Returns the number of element types.""" - return self.ntypes + @property + def dim_in(self): + """Returns the atomic input dimension of this descriptor.""" + return self.tebd_dim - def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): - """Update mean and stddev for descriptor elements.""" + @property + def dim_emb(self): + """Returns the output dimension of embedding.""" + return self.get_dim_emb() + + def compute_input_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + path: Optional[DPPath] = None, + ): + """Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.""" raise NotImplementedError + def get_stats(self): + """Get the statistics of the descriptor.""" + raise NotImplementedError + + def reinit_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def cal_g( self, ss, @@ -430,53 +733,17 @@ def cal_g_strip( gg = self.embeddings_strip[embedding_idx].call(ss) return gg - def reinit_exclude( - self, - exclude_types: List[Tuple[int, int]] = [], - ): - self.exclude_types = exclude_types - self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) - def call( self, - coord_ext, - atype_ext, - nlist, + nlist: np.ndarray, + coord_ext: np.ndarray, + atype_ext: np.ndarray, + atype_embd_ext: Optional[np.ndarray] = None, mapping: Optional[np.ndarray] = None, ): - """Compute the descriptor. - - Parameters - ---------- - coord_ext - The extended coordinates of atoms. shape: nf x (nallx3) - atype_ext - The extended aotm types. shape: nf x nall - nlist - The neighbor list. shape: nf x nloc x nnei - mapping - The index mapping from extended to lcoal region. not used by this descriptor. - - Returns - ------- - descriptor - The descriptor. shape: nf x nloc x (ng x axis_neuron) - gr - The rotationally equivariant and permutationally invariant single particle - representation. shape: nf x nloc x ng x 3 - g2 - The rotationally invariant pair-partical representation. - this descriptor returns None - h2 - The rotationally equivariant pair-partical representation. - this descriptor returns None - sw - The smooth switch function. 
- """ - del mapping # nf x nloc x nnei x 4 dmatrix, sw = self.env_mat.call( - coord_ext, atype_ext, nlist, self.davg, self.dstd + coord_ext, atype_ext, nlist, self.mean, self.stddev ) nf, nloc, nnei, _ = dmatrix.shape exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) @@ -486,10 +753,6 @@ def call( dmatrix = dmatrix.reshape(nf * nloc, nnei, 4) # nfnl x nnei x 1 sw = sw.reshape(nf * nloc, nnei, 1) - - # add type embedding into input - # nf x nall x tebd_dim - atype_embd_ext = self.type_embedding.call()[atype_ext] # nfnl x tebd_dim atype_embd = atype_embd_ext[:, :nloc, :].reshape(nf * nloc, -1) # nfnl x nnei x tebd_dim @@ -504,7 +767,6 @@ def call( atype_embd_nlist = np.take_along_axis(atype_embd_ext, index, axis=1).reshape( nf * nloc, nnei, self.tebd_dim ) - ng = self.neuron[-1] # nfnl x nnei exclude_mask = exclude_mask.reshape(nf * nloc, nnei) @@ -561,101 +823,13 @@ def call( grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron).astype( GLOBAL_NP_FLOAT_PRECISION ) - # nf x nloc x (ng x ng1 + tebd_dim) - if self.concat_output_tebd: - grrg = np.concatenate([grrg, atype_embd.reshape(nf, nloc, -1)], axis=-1) - return grrg, gr[..., 1:], None, None, sw - - def serialize(self) -> dict: - """Serialize the descriptor to dict.""" - data = { - "@class": "Descriptor", - "type": "dpa1", - "@version": 1, - "rcut": self.rcut, - "rcut_smth": self.rcut_smth, - "sel": self.sel, - "ntypes": self.ntypes, - "neuron": self.neuron, - "axis_neuron": self.axis_neuron, - "tebd_dim": self.tebd_dim, - "tebd_input_mode": self.tebd_input_mode, - "set_davg_zero": self.set_davg_zero, - "attn": self.attn, - "attn_layer": self.attn_layer, - "attn_dotr": self.attn_dotr, - "attn_mask": False, - "activation_function": self.activation_function, - "resnet_dt": self.resnet_dt, - "scaling_factor": self.scaling_factor, - "normalize": self.normalize, - "temperature": self.temperature, - "trainable_ln": self.trainable_ln, - "ln_eps": self.ln_eps, - "smooth_type_embedding": self.smooth, - "type_one_side": self.type_one_side, - "concat_output_tebd": self.concat_output_tebd, - # make deterministic - "precision": np.dtype(PRECISION_DICT[self.precision]).name, - "embeddings": self.embeddings.serialize(), - "attention_layers": self.dpa1_attention.serialize(), - "env_mat": self.env_mat.serialize(), - "type_embedding": self.type_embedding.serialize(), - "exclude_types": self.exclude_types, - "env_protection": self.env_protection, - "@variables": { - "davg": self.davg, - "dstd": self.dstd, - }, - ## to be updated when the options are supported. 
- "trainable": True, - "spin": None, - } - if self.tebd_input_mode in ["strip"]: - data.update({"embeddings_strip": self.embeddings_strip.serialize()}) - return data - - @classmethod - def deserialize(cls, data: dict) -> "DescrptDPA1": - """Deserialize from dict.""" - data = data.copy() - check_version_compatibility(data.pop("@version"), 1, 1) - data.pop("@class") - data.pop("type") - variables = data.pop("@variables") - embeddings = data.pop("embeddings") - type_embedding = data.pop("type_embedding") - attention_layers = data.pop("attention_layers") - env_mat = data.pop("env_mat") - tebd_input_mode = data["tebd_input_mode"] - if tebd_input_mode in ["strip"]: - embeddings_strip = data.pop("embeddings_strip") - else: - embeddings_strip = None - obj = cls(**data) - - obj["davg"] = variables["davg"] - obj["dstd"] = variables["dstd"] - obj.embeddings = NetworkCollection.deserialize(embeddings) - if tebd_input_mode in ["strip"]: - obj.embeddings_strip = NetworkCollection.deserialize(embeddings_strip) - obj.type_embedding = TypeEmbedNet.deserialize(type_embedding) - obj.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers) - return obj - - @classmethod - def update_sel(cls, global_jdata: dict, local_jdata: dict): - """Update the selection and perform neighbor statistics. - - Parameters - ---------- - global_jdata : dict - The global data, containing the training section - local_jdata : dict - The local data refer to the current class - """ - local_jdata_cpy = local_jdata.copy() - return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, True) + return ( + grrg.reshape(-1, nloc, self.filter_neuron[-1] * self.axis_neuron), + gg.reshape(-1, nloc, self.nnei, self.filter_neuron[-1]), + dmatrix.reshape(-1, nloc, self.nnei, 4)[..., 1:], + gr[..., 1:].reshape(-1, nloc, self.filter_neuron[-1], 3), + sw, + ) class NeighborGatedAttention(NativeOP): @@ -838,7 +1012,7 @@ def call( sw: Optional[np.ndarray] = None, ): residual = x - x = self.attention_layer(x, nei_mask, input_r=input_r, sw=sw) + x, _ = self.attention_layer(x, nei_mask, input_r=input_r, sw=sw) x = residual + x x = self.attn_layer_norm(x) return x @@ -891,6 +1065,7 @@ def __init__( nnei: int, embed_dim: int, hidden_dim: int, + num_heads: int = 1, dotr: bool = False, do_mask: bool = False, scaling_factor: float = 1.0, @@ -900,11 +1075,14 @@ def __init__( smooth: bool = True, precision: str = DEFAULT_PRECISION, ): - """Construct a neighbor-wise attention net.""" + """Construct a multi-head neighbor-wise attention net.""" super().__init__() + assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads" self.nnei = nnei self.embed_dim = embed_dim self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads self.dotr = dotr self.do_mask = do_mask self.bias = bias @@ -912,10 +1090,11 @@ def __init__( self.scaling_factor = scaling_factor self.temperature = temperature self.precision = precision - if temperature is None: - self.scaling = (self.hidden_dim * scaling_factor) ** -0.5 - else: - self.scaling = temperature + self.scaling = ( + (self.head_dim * scaling_factor) ** -0.5 + if temperature is None + else temperature + ) self.normalize = normalize self.in_proj = NativeLayer( embed_dim, @@ -936,41 +1115,55 @@ def call(self, query, nei_mask, input_r=None, sw=None, attnw_shift=20.0): # Linear projection q, k, v = np.split(self.in_proj(query), 3, axis=-1) # Reshape and normalize - q = q.reshape(-1, self.nnei, self.hidden_dim) - k = k.reshape(-1, self.nnei, self.hidden_dim) - 
v = v.reshape(-1, self.nnei, self.hidden_dim) + # (nf x nloc) x num_heads x nnei x head_dim + q = q.reshape(-1, self.nnei, self.num_heads, self.head_dim).transpose( + 0, 2, 1, 3 + ) + k = k.reshape(-1, self.nnei, self.num_heads, self.head_dim).transpose( + 0, 2, 1, 3 + ) + v = v.reshape(-1, self.nnei, self.num_heads, self.head_dim).transpose( + 0, 2, 1, 3 + ) if self.normalize: q = np_normalize(q, axis=-1) k = np_normalize(k, axis=-1) v = np_normalize(v, axis=-1) q = q * self.scaling # Attention weights - attn_weights = q @ k.transpose(0, 2, 1) + # (nf x nloc) x num_heads x nnei x nnei + attn_weights = q @ k.transpose(0, 1, 3, 2) nei_mask = nei_mask.reshape(-1, self.nnei) if self.smooth: - sw = sw.reshape(-1, self.nnei) - attn_weights = (attn_weights + attnw_shift) * sw[:, None, :] * sw[ - :, :, None + sw = sw.reshape(-1, 1, self.nnei) + attn_weights = (attn_weights + attnw_shift) * sw[:, :, :, None] * sw[ + :, :, None, : ] - attnw_shift else: - attn_weights = np.where(nei_mask[:, None, :], attn_weights, -np.inf) + attn_weights = np.where(nei_mask[:, None, None, :], attn_weights, -np.inf) attn_weights = np_softmax(attn_weights, axis=-1) - attn_weights = np.where(nei_mask[:, :, None], attn_weights, 0.0) + attn_weights = np.where(nei_mask[:, None, :, None], attn_weights, 0.0) if self.smooth: - attn_weights = attn_weights * sw[:, None, :] * sw[:, :, None] + attn_weights = attn_weights * sw[:, :, :, None] * sw[:, :, None, :] if self.dotr: - angular_weight = input_r @ input_r.transpose(0, 2, 1) + angular_weight = (input_r @ input_r.transpose(0, 2, 1)).reshape( + -1, 1, self.nnei, self.nnei + ) attn_weights = attn_weights * angular_weight # Output projection + # (nf x nloc) x num_heads x nnei x head_dim o = attn_weights @ v + # (nf x nloc) x nnei x (num_heads x head_dim) + o = o.transpose(0, 2, 1, 3).reshape(-1, self.nnei, self.hidden_dim) output = self.out_proj(o) - return output + return output, attn_weights def serialize(self): return { "nnei": self.nnei, "embed_dim": self.embed_dim, "hidden_dim": self.hidden_dim, + "num_heads": self.num_heads, "dotr": self.dotr, "do_mask": self.do_mask, "scaling_factor": self.scaling_factor, diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 5caf583a23..52895d3f72 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -261,6 +261,7 @@ def __init__( self.type_embedding = TypeEmbedNet(ntypes, tebd_dim, precision=precision) self.tebd_dim = tebd_dim self.concat_output_tebd = concat_output_tebd + self.trainable = trainable # set trainable for param in self.parameters(): param.requires_grad = trainable @@ -406,8 +407,7 @@ def serialize(self) -> dict: "davg": obj["davg"].detach().cpu().numpy(), "dstd": obj["dstd"].detach().cpu().numpy(), }, - ## to be updated when the options are supported. 
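Relating to the multi-head rewrite of GatedAttentionLayer shown above, a small NumPy sketch of the shape bookkeeping (the sizes are illustrative; hidden_dim must be divisible by num_heads):

import numpy as np

nf_nloc, nnei, hidden_dim, num_heads = 4, 10, 32, 4
head_dim = hidden_dim // num_heads                      # 8

q = np.random.rand(nf_nloc, nnei, hidden_dim)
# split the hidden dimension into heads: (nf x nloc) x num_heads x nnei x head_dim
qh = q.reshape(nf_nloc, nnei, num_heads, head_dim).transpose(0, 2, 1, 3)
kh = vh = qh                                            # reuse for brevity
attn = qh @ kh.transpose(0, 1, 3, 2)                    # (nf x nloc) x num_heads x nnei x nnei
o = attn @ vh                                           # (nf x nloc) x num_heads x nnei x head_dim
# merge the heads back: (nf x nloc) x nnei x (num_heads x head_dim)
out = o.transpose(0, 2, 1, 3).reshape(nf_nloc, nnei, hidden_dim)
print(attn.shape, out.shape)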
- "trainable": True, + "trainable": self.trainable, "spin": None, } if obj.tebd_input_mode in ["strip"]: diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index fb792a51e2..842c924662 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -9,10 +9,18 @@ import torch -from deepmd.pt.model.network.network import ( +from deepmd.dpmodel.utils import EnvMat as DPEnvMat +from deepmd.pt.model.network.mlp import ( Identity, - Linear, + MLPLayer, + NetworkCollection, +) +from deepmd.pt.model.network.network import ( TypeEmbedNet, + TypeEmbedNetConsistent, +) +from deepmd.pt.utils import ( + env, ) from deepmd.pt.utils.nlist import ( build_multiple_neighbor_list, @@ -24,10 +32,16 @@ from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.version import ( + check_version_compatibility, +) from .base_descriptor import ( BaseDescriptor, ) +from .repformer_layer import ( + RepformerLayer, +) from .repformers import ( DescrptBlockRepformers, ) @@ -37,9 +51,10 @@ @BaseDescriptor.register("dpa2") -class DescrptDPA2(torch.nn.Module, BaseDescriptor): +class DescrptDPA2(BaseDescriptor, torch.nn.Module): def __init__( self, + # args for repinit ntypes: int, repinit_rcut: float, repinit_rcut_smth: float, @@ -47,19 +62,19 @@ def __init__( repformer_rcut: float, repformer_rcut_smth: float, repformer_nsel: int, - # kwargs - tebd_dim: int = 8, - concat_output_tebd: bool = True, + # kwargs for repinit repinit_neuron: List[int] = [25, 50, 100], repinit_axis_neuron: int = 16, - repinit_set_davg_zero: bool = True, # TODO - repinit_activation="tanh", - # repinit still unclear: - # ffn, ffn_embed_dim, scaling_factor, normalize, + repinit_tebd_dim: int = 8, + repinit_tebd_input_mode: str = "concat", + repinit_set_davg_zero: bool = True, + repinit_activation_function="tanh", + # kwargs for repformer repformer_nlayers: int = 3, repformer_g1_dim: int = 128, repformer_g2_dim: int = 16, - repformer_axis_dim: int = 4, + repformer_axis_neuron: int = 4, + repformer_direct_dist: bool = False, repformer_do_bn_mode: str = "no", repformer_bn_momentum: float = 0.1, repformer_update_g1_has_conv: bool = True, @@ -74,113 +89,164 @@ def __init__( repformer_attn2_hidden: int = 16, repformer_attn2_nhead: int = 4, repformer_attn2_has_gate: bool = False, - repformer_activation: str = "tanh", + repformer_activation_function: str = "tanh", repformer_update_style: str = "res_avg", - repformer_set_davg_zero: bool = True, # TODO - repformer_add_type_ebd_to_seq: bool = False, + repformer_set_davg_zero: bool = True, + # kwargs for descriptor + concat_output_tebd: bool = True, + precision: str = "float64", + smooth: bool = True, + exclude_types: List[Tuple[int, int]] = [], env_protection: float = 0.0, trainable: bool = True, - exclude_types: List[Tuple[int, int]] = [], - type: Optional[ - str - ] = None, # work around the bad design in get_trainer and DpLoaderSet! - rcut: Optional[ - float - ] = None, # work around the bad design in get_trainer and DpLoaderSet! - rcut_smth: Optional[ - float - ] = None, # work around the bad design in get_trainer and DpLoaderSet! - sel: Optional[ - int - ] = None, # work around the bad design in get_trainer and DpLoaderSet! + seed: Optional[int] = None, + resnet_dt: bool = False, + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + type_one_side: bool = False, + add_tebd_to_repinit_out: bool = False, + old_impl: bool = False, ): r"""The DPA-2 descriptor. see https://arxiv.org/abs/2312.15492. 
Parameters ---------- - ntypes : int - Number of atom types repinit_rcut : float - The cut-off radius of the repinit block + (Used in the repinit block.) + The cut-off radius. repinit_rcut_smth : float - From this position the inverse distance smoothly decays - to 0 at the cut-off. Use in the repinit block. + (Used in the repinit block.) + Where to start smoothing. For example the 1/r term is smoothed from rcut to rcut_smth. repinit_nsel : int - Maximally possible number of neighbors for repinit block. + (Used in the repinit block.) + Maximally possible number of selected neighbors. + repinit_neuron : list, optional + (Used in the repinit block.) + Number of neurons in each hidden layers of the embedding net. + When two layers are of the same size or one layer is twice as large as the previous layer, + a skip connection is built. + repinit_axis_neuron : int, optional + (Used in the repinit block.) + Size of the submatrix of G (embedding matrix). + repinit_tebd_dim : int, optional + (Used in the repinit block.) + The dimension of atom type embedding. + repinit_tebd_input_mode : str, optional + (Used in the repinit block.) + The input mode of the type embedding. Supported modes are ['concat', 'strip']. + repinit_set_davg_zero : bool, optional + (Used in the repinit block.) + Set the normalization average to zero. + repinit_activation_function : str, optional + (Used in the repinit block.) + The activation function in the embedding net. repformer_rcut : float - The cut-off radius of the repformer block + (Used in the repformer block.) + The cut-off radius. repformer_rcut_smth : float - From this position the inverse distance smoothly decays - to 0 at the cut-off. Use in the repformer block. + (Used in the repformer block.) + Where to start smoothing. For example the 1/r term is smoothed from rcut to rcut_smth. repformer_nsel : int - Maximally possible number of neighbors for repformer block. - tebd_dim : int - The dimension of atom type embedding - concat_output_tebd : bool + (Used in the repformer block.) + Maximally possible number of selected neighbors. + repformer_nlayers : int, optional + (Used in the repformer block.) + Number of repformer layers. + repformer_g1_dim : int, optional + (Used in the repformer block.) + Dimension of the first graph convolution layer. + repformer_g2_dim : int, optional + (Used in the repformer block.) + Dimension of the second graph convolution layer. + repformer_axis_neuron : int, optional + (Used in the repformer block.) + Size of the submatrix of G (embedding matrix). + repformer_direct_dist : bool, optional + (Used in the repformer block.) + Whether to use direct distance information (1/r term) in the repformer block. + repformer_do_bn_mode : str, optional + (Used in the repformer block.) + The mode to do batch normalization in the repformer layers. Supported modes are: + -'no': Not do batch normalization. + -'uniform': Do batch normalization using scalar running momentum and learnable gamma/beta (num_features=1). + -'component': Do batch normalization using vector running momentum and learnable gamma/beta (num_features=d). + repformer_bn_momentum : float, optional + (Used in the repformer block.) + Momentum used in the batch normalization. + repformer_update_g1_has_conv : bool, optional + (Used in the repformer block.) + Whether to update the g1 rep with convolution term. + repformer_update_g1_has_drrd : bool, optional + (Used in the repformer block.) + Whether to update the g1 rep with the drrd term. 
+ repformer_update_g1_has_grrg : bool, optional + (Used in the repformer block.) + Whether to update the g1 rep with the grrg term. + repformer_update_g1_has_attn : bool, optional + (Used in the repformer block.) + Whether to update the g1 rep with the localized self-attention. + repformer_update_g2_has_g1g1 : bool, optional + (Used in the repformer block.) + Whether to update the g2 rep with the g1xg1 term. + repformer_update_g2_has_attn : bool, optional + (Used in the repformer block.) + Whether to update the g2 rep with the gated self-attention. + repformer_update_h2 : bool, optional + (Used in the repformer block.) + Whether to update the h2 rep. + repformer_attn1_hidden : int, optional + (Used in the repformer block.) + The hidden dimension of localized self-attention to update the g1 rep. + repformer_attn1_nhead : int, optional + (Used in the repformer block.) + The number of heads in localized self-attention to update the g1 rep. + repformer_attn2_hidden : int, optional + (Used in the repformer block.) + The hidden dimension of gated self-attention to update the g2 rep. + repformer_attn2_nhead : int, optional + (Used in the repformer block.) + The number of heads in gated self-attention to update the g2 rep. + repformer_attn2_has_gate : bool, optional + (Used in the repformer block.) + Whether to use gate in the gated self-attention to update the g2 rep. + repformer_activation_function : str, optional + (Used in the repformer block.) + The activation function in the embedding net. + repformer_update_style : str, optional + (Used in the repformer block.) + Style to update a representation. + Supported options are: + -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) + -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + repformer_set_davg_zero : bool, optional + (Used in the repformer block.) + Set the normalization average to zero. + concat_output_tebd : bool, optional Whether to concat type embedding at the output of the descriptor. - repinit_neuron : List[int] - repinit block: the number of neurons in the embedding net. - repinit_axis_neuron : int - repinit block: the number of dimension of split in the - symmetrization op. - repinit_activation : str - repinit block: the activation function in the embedding net - repformer_nlayers : int - repformers block: the number of repformer layers - repformer_g1_dim : int - repformers block: the dimension of single-atom rep - repformer_g2_dim : int - repformers block: the dimension of invariant pair-atom rep - repformer_axis_dim : int - repformers block: the number of dimension of split in the - symmetrization ops. 
- repformer_do_bn_mode : bool - repformers block: do batch norm in the repformer layers - repformer_bn_momentum : float - repformers block: moment in the batch normalization - repformer_update_g1_has_conv : bool - repformers block: update the g1 rep with convolution term - repformer_update_g1_has_drrd : bool - repformers block: update the g1 rep with the drrd term - repformer_update_g1_has_grrg : bool - repformers block: update the g1 rep with the grrg term - repformer_update_g1_has_attn : bool - repformers block: update the g1 rep with the localized - self-attention - repformer_update_g2_has_g1g1 : bool - repformers block: update the g2 rep with the g1xg1 term - repformer_update_g2_has_attn : bool - repformers block: update the g2 rep with the gated self-attention - repformer_update_h2 : bool - repformers block: update the h2 rep - repformer_attn1_hidden : int - repformers block: the hidden dimension of localized self-attention - repformer_attn1_nhead : int - repformers block: the number of heads in localized self-attention - repformer_attn2_hidden : int - repformers block: the hidden dimension of gated self-attention - repformer_attn2_nhead : int - repformers block: the number of heads in gated self-attention - repformer_attn2_has_gate : bool - repformers block: has gate in the gated self-attention - repformer_activation : str - repformers block: the activation function in the MLPs. - repformer_update_style : str - repformers block: style of update a rep. - can be res_avg or res_incr. - res_avg updates a rep `u` with: - u = 1/\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) - res_incr updates a rep `u` with: - u = u + 1/\sqrt{n} (u_1 + u_2 + ... + u_n) - repformer_set_davg_zero : bool - repformers block: set the avg to zero in statistics - repformer_add_type_ebd_to_seq : bool - repformers block: concatenate the type embedding at the output. - trainable : bool - If the parameters in the descriptor are trainable. - exclude_types : List[Tuple[int, int]] = [], + precision : str, optional + The precision of the embedding net parameters. + smooth : bool, optional + Whether to use smoothness in processes such as attention weights calculation. + exclude_types : List[List[int]], optional The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + trainable : bool, optional + If the parameters are trainable. + seed : int, optional + (Unused yet) Random seed for parameter initialization. + resnet_dt : bool, optional + Whether to use a "Timestep" in the skip connection. + trainable_ln : bool, optional + Whether to use trainable shift and scale weights in layer normalization. + ln_eps : float, optional + The epsilon value for layer normalization. + type_one_side : bool, optional + Whether to use one-side type embedding. + add_tebd_to_repinit_out : bool, optional + Whether to add type embedding to the output representation from repinit before inputting it into repformer. 
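Following the renamed arguments documented above, a minimal construction sketch (assuming this patch is applied; the cutoffs, selections and type count are illustrative, chosen so that the repinit cutoff and selection are larger than the repformer ones, as asserted in __init__):

from deepmd.pt.model.descriptor.dpa2 import DescrptDPA2

descrpt = DescrptDPA2(
    ntypes=2,
    repinit_rcut=9.0,            # repinit uses the larger cutoff ...
    repinit_rcut_smth=8.0,
    repinit_nsel=120,            # ... and the larger neighbor selection
    repformer_rcut=4.0,
    repformer_rcut_smth=3.5,
    repformer_nsel=40,
    repinit_tebd_dim=8,          # was `tebd_dim` before this refactor
    repformer_axis_neuron=4,     # was `repformer_axis_dim`
    precision="float64",
)
print(descrpt.rcut)              # the descriptor cutoff is taken from the repinit block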
Returns ------- @@ -198,7 +264,6 @@ def __init__( """ super().__init__() - del type, rcut, rcut_smth, sel self.repinit = DescrptBlockSeAtten( repinit_rcut, repinit_rcut_smth, @@ -207,13 +272,18 @@ def __init__( attn_layer=0, neuron=repinit_neuron, axis_neuron=repinit_axis_neuron, - tebd_dim=tebd_dim, - tebd_input_mode="concat", - # tebd_input_mode='dot_residual_s', + tebd_dim=repinit_tebd_dim, + tebd_input_mode=repinit_tebd_input_mode, set_davg_zero=repinit_set_davg_zero, exclude_types=exclude_types, env_protection=env_protection, - activation_function=repinit_activation, + activation_function=repinit_activation_function, + precision=precision, + resnet_dt=resnet_dt, + trainable_ln=trainable_ln, + ln_eps=ln_eps, + smooth=smooth, + type_one_side=type_one_side, ) self.repformers = DescrptBlockRepformers( repformer_rcut, @@ -223,8 +293,8 @@ def __init__( nlayers=repformer_nlayers, g1_dim=repformer_g1_dim, g2_dim=repformer_g2_dim, - axis_dim=repformer_axis_dim, - direct_dist=False, + axis_neuron=repformer_axis_neuron, + direct_dist=repformer_direct_dist, do_bn_mode=repformer_do_bn_mode, bn_momentum=repformer_bn_momentum, update_g1_has_conv=repformer_update_g1_has_conv, @@ -239,28 +309,47 @@ def __init__( attn2_hidden=repformer_attn2_hidden, attn2_nhead=repformer_attn2_nhead, attn2_has_gate=repformer_attn2_has_gate, - activation_function=repformer_activation, + activation_function=repformer_activation_function, update_style=repformer_update_style, set_davg_zero=repformer_set_davg_zero, - smooth=True, - add_type_ebd_to_seq=repformer_add_type_ebd_to_seq, + smooth=smooth, exclude_types=exclude_types, env_protection=env_protection, + precision=precision, + resnet_dt=resnet_dt, + trainable_ln=trainable_ln, + ln_eps=ln_eps, + old_impl=old_impl, ) - self.type_embedding = TypeEmbedNet(ntypes, tebd_dim) + self.type_embedding = TypeEmbedNet( + ntypes, repinit_tebd_dim, precision=precision + ) + self.concat_output_tebd = concat_output_tebd + self.precision = precision + self.smooth = smooth + self.exclude_types = exclude_types + self.env_protection = env_protection + self.trainable = trainable + self.resnet_dt = resnet_dt + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.type_one_side = type_one_side + self.add_tebd_to_repinit_out = add_tebd_to_repinit_out + if self.repinit.dim_out == self.repformers.dim_in: self.g1_shape_tranform = Identity() else: - self.g1_shape_tranform = Linear( + self.g1_shape_tranform = MLPLayer( self.repinit.dim_out, self.repformers.dim_in, bias=False, + precision=precision, init="glorot", ) assert self.repinit.rcut > self.repformers.rcut assert self.repinit.sel[0] > self.repformers.sel[0] - self.concat_output_tebd = concat_output_tebd - self.tebd_dim = tebd_dim + + self.tebd_dim = repinit_tebd_dim self.rcut = self.repinit.get_rcut() self.ntypes = ntypes self.sel = self.repinit.sel @@ -381,13 +470,137 @@ def compute_input_stats( descrpt.compute_input_stats(merged, path) def serialize(self) -> dict: - """Serialize the obj to dict.""" - raise NotImplementedError + repinit = self.repinit + repformers = self.repformers + data = { + "@class": "Descriptor", + "type": "dpa2", + "@version": 1, + "ntypes": self.ntypes, + "repinit_rcut": repinit.rcut, + "repinit_rcut_smth": repinit.rcut_smth, + "repinit_nsel": repinit.sel, + "repformer_rcut": repformers.rcut, + "repformer_rcut_smth": repformers.rcut_smth, + "repformer_nsel": repformers.sel, + "repinit_neuron": repinit.neuron, + "repinit_axis_neuron": repinit.axis_neuron, + "repinit_tebd_dim": repinit.tebd_dim, + 
"repinit_tebd_input_mode": repinit.tebd_input_mode, + "repinit_set_davg_zero": repinit.set_davg_zero, + "repinit_activation_function": repinit.activation_function, + "repformer_nlayers": repformers.nlayers, + "repformer_g1_dim": repformers.g1_dim, + "repformer_g2_dim": repformers.g2_dim, + "repformer_axis_neuron": repformers.axis_neuron, + "repformer_direct_dist": repformers.direct_dist, + "repformer_do_bn_mode": repformers.do_bn_mode, + "repformer_bn_momentum": repformers.bn_momentum, + "repformer_update_g1_has_conv": repformers.update_g1_has_conv, + "repformer_update_g1_has_drrd": repformers.update_g1_has_drrd, + "repformer_update_g1_has_grrg": repformers.update_g1_has_grrg, + "repformer_update_g1_has_attn": repformers.update_g1_has_attn, + "repformer_update_g2_has_g1g1": repformers.update_g2_has_g1g1, + "repformer_update_g2_has_attn": repformers.update_g2_has_attn, + "repformer_update_h2": repformers.update_h2, + "repformer_attn1_hidden": repformers.attn1_hidden, + "repformer_attn1_nhead": repformers.attn1_nhead, + "repformer_attn2_hidden": repformers.attn2_hidden, + "repformer_attn2_nhead": repformers.attn2_nhead, + "repformer_attn2_has_gate": repformers.attn2_has_gate, + "repformer_activation_function": repformers.activation_function, + "repformer_update_style": repformers.update_style, + "repformer_set_davg_zero": repformers.set_davg_zero, + "concat_output_tebd": self.concat_output_tebd, + "precision": self.precision, + "smooth": self.smooth, + "exclude_types": self.exclude_types, + "env_protection": self.env_protection, + "trainable": self.trainable, + "resnet_dt": self.resnet_dt, + "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, + "type_one_side": self.type_one_side, + "add_tebd_to_repinit_out": self.add_tebd_to_repinit_out, + "type_embedding": self.type_embedding.embedding.serialize(), + "g1_shape_tranform": self.g1_shape_tranform.serialize(), + } + repinit_variable = { + "embeddings": repinit.filter_layers.serialize(), + "env_mat": DPEnvMat(repinit.rcut, repinit.rcut_smth).serialize(), + "@variables": { + "davg": repinit["davg"].detach().cpu().numpy(), + "dstd": repinit["dstd"].detach().cpu().numpy(), + }, + } + if repinit.tebd_input_mode in ["strip"]: + repinit_variable.update( + {"embeddings_strip": repinit.filter_layers_strip.serialize()} + ) + repformers_variable = { + "g2_embd": repformers.g2_embd.serialize(), + "repformer_layers": [layer.serialize() for layer in repformers.layers], + "env_mat": DPEnvMat(repformers.rcut, repformers.rcut_smth).serialize(), + "@variables": { + "davg": repformers["davg"].detach().cpu().numpy(), + "dstd": repformers["dstd"].detach().cpu().numpy(), + }, + } + data.update( + { + "repinit": repinit_variable, + "repformers": repformers_variable, + } + ) + return data @classmethod - def deserialize(cls) -> "DescrptDPA2": - """Deserialize from a dict.""" - raise NotImplementedError + def deserialize(cls, data: dict) -> "DescrptDPA2": + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + repinit_variable = data.pop("repinit").copy() + repformers_variable = data.pop("repformers").copy() + type_embedding = data.pop("type_embedding") + g1_shape_tranform = data.pop("g1_shape_tranform") + obj = cls(**data) + obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize( + type_embedding + ) + if obj.repinit.dim_out != obj.repformers.dim_in: + obj.g1_shape_tranform = MLPLayer.deserialize(g1_shape_tranform) + + def t_cvt(xx): + return torch.tensor(xx, dtype=obj.repinit.prec, 
device=env.DEVICE) + + # deserialize repinit + statistic_repinit = repinit_variable.pop("@variables") + env_mat = repinit_variable.pop("env_mat") + tebd_input_mode = data["repinit_tebd_input_mode"] + obj.repinit.filter_layers = NetworkCollection.deserialize( + repinit_variable.pop("embeddings") + ) + if tebd_input_mode in ["strip"]: + obj.repinit.filter_layers_strip = NetworkCollection.deserialize( + repinit_variable.pop("embeddings_strip") + ) + obj.repinit["davg"] = t_cvt(statistic_repinit["davg"]) + obj.repinit["dstd"] = t_cvt(statistic_repinit["dstd"]) + + # deserialize repformers + statistic_repformers = repformers_variable.pop("@variables") + env_mat = repformers_variable.pop("env_mat") + repformer_layers = repformers_variable.pop("repformer_layers") + obj.repformers.g2_embd = MLPLayer.deserialize( + repformers_variable.pop("g2_embd") + ) + obj.repformers["davg"] = t_cvt(statistic_repformers["davg"]) + obj.repformers["dstd"] = t_cvt(statistic_repformers["dstd"]) + obj.repformers.layers = torch.nn.ModuleList( + [RepformerLayer.deserialize(layer) for layer in repformer_layers] + ) + return obj def forward( self, diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index a58d6b0e2c..9ce590f10e 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -2,12 +2,16 @@ from typing import ( Callable, List, + Optional, ) import torch -from deepmd.pt.model.network.network import ( - SimpleLinear, +from deepmd.pt.model.network.layernorm import ( + LayerNorm, +) +from deepmd.pt.model.network.mlp import ( + MLPLayer, ) from deepmd.pt.utils import ( env, @@ -15,27 +19,41 @@ from deepmd.pt.utils.utils import ( ActivationFn, ) +from deepmd.utils.version import ( + check_version_compatibility, +) -def torch_linear(*args, **kwargs): - return torch.nn.Linear( - *args, **kwargs, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) - - +# common ops def _make_nei_g1( g1_ext: torch.Tensor, nlist: torch.Tensor, ) -> torch.Tensor: - # nlist: nb x nloc x nnei + """ + Make neighbor-wise atomic invariant rep. + + Parameters + ---------- + g1_ext + Extended atomic invariant rep, with shape nf x nall x ng1. + nlist + Neighbor list, with shape nf x nloc x nnei. + + Returns + ------- + gg1: torch.Tensor + Neighbor-wise atomic invariant rep, with shape nf x nloc x nnei x ng1. + + """ + # nlist: nf x nloc x nnei nb, nloc, nnei = nlist.shape - # g1_ext: nb x nall x ng1 + # g1_ext: nf x nall x ng1 ng1 = g1_ext.shape[-1] - # index: nb x (nloc x nnei) x ng1 + # index: nf x (nloc x nnei) x ng1 index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1) - # gg1 : nb x (nloc x nnei) x ng1 + # gg1 : nf x (nloc x nnei) x ng1 gg1 = torch.gather(g1_ext, dim=1, index=index) - # gg1 : nb x nloc x nnei x ng1 + # gg1 : nf x nloc x nnei x ng1 gg1 = gg1.view(nb, nloc, nnei, ng1) return gg1 @@ -44,82 +62,92 @@ def _apply_nlist_mask( gg: torch.Tensor, nlist_mask: torch.Tensor, ) -> torch.Tensor: - # gg: nf x nloc x nnei x ng + """ + Apply nlist mask to neighbor-wise rep tensors. + + Parameters + ---------- + gg + Neighbor-wise rep tensors, with shape nf x nloc x nnei x d. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. 
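Illustrative sketch (not part of the patch): the gather trick documented for `_make_nei_g1` above can be checked in isolation; all shapes and values below are made up for illustration and assume only PyTorch.

import torch

# build neighbor-wise reps so that gg1[f, i, j] == g1_ext[f, nlist[f, i, j]]
nf, nall, nloc, nnei, ng1 = 1, 6, 4, 3, 8
g1_ext = torch.rand(nf, nall, ng1)
nlist = torch.randint(0, nall, (nf, nloc, nnei))
index = nlist.reshape(nf, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1)
gg1 = torch.gather(g1_ext, dim=1, index=index).view(nf, nloc, nnei, ng1)
assert torch.equal(gg1[0, 2, 1], g1_ext[0, nlist[0, 2, 1]])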
+ """ + # gg: nf x nloc x nnei x d # msk: nf x nloc x nnei return gg.masked_fill(~nlist_mask.unsqueeze(-1), 0.0) def _apply_switch(gg: torch.Tensor, sw: torch.Tensor) -> torch.Tensor: - # gg: nf x nloc x nnei x ng - # sw: nf x nloc x nnei - return gg * sw.unsqueeze(-1) - + """ + Apply switch function to neighbor-wise rep tensors. -def _apply_h_norm( - hh: torch.Tensor, # nf x nloc x nnei x 3 -) -> torch.Tensor: - """Normalize h by the std of vector length. - do not have an idea if this is a good way. + Parameters + ---------- + gg + Neighbor-wise rep tensors, with shape nf x nloc x nnei x d. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nf x nloc x nnei. """ - nf, nl, nnei, _ = hh.shape - # nf x nloc x nnei - normh = torch.linalg.norm(hh, dim=-1) - # nf x nloc - std = torch.std(normh, dim=-1) - # nf x nloc x nnei x 3 - hh = hh[:, :, :, :] / (1.0 + std[:, :, None, None]) - return hh + # gg: nf x nloc x nnei x d + # sw: nf x nloc x nnei + return gg * sw.unsqueeze(-1) class Atten2Map(torch.nn.Module): def __init__( self, - ni: int, - nd: int, - nh: int, + input_dim: int, + hidden_dim: int, + head_num: int, has_gate: bool = False, # apply gate to attn map smooth: bool = True, attnw_shift: float = 20.0, + precision: str = "float64", ): + """Return neighbor-wise multi-head self-attention maps, with gate mechanism.""" super().__init__() - self.ni = ni - self.nd = nd - self.nh = nh - self.mapqk = SimpleLinear(ni, nd * 2 * nh, bias=False) + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.head_num = head_num + self.mapqk = MLPLayer( + input_dim, hidden_dim * 2 * head_num, bias=False, precision=precision + ) self.has_gate = has_gate self.smooth = smooth self.attnw_shift = attnw_shift + self.precision = precision def forward( self, - g2: torch.Tensor, # nb x nloc x nnei x ng2 - h2: torch.Tensor, # nb x nloc x nnei x 3 - nlist_mask: torch.Tensor, # nb x nloc x nnei - sw: torch.Tensor, # nb x nloc x nnei + g2: torch.Tensor, # nf x nloc x nnei x ng2 + h2: torch.Tensor, # nf x nloc x nnei x 3 + nlist_mask: torch.Tensor, # nf x nloc x nnei + sw: torch.Tensor, # nf x nloc x nnei ) -> torch.Tensor: ( - nb, + nf, nloc, nnei, _, ) = g2.shape - nd, nh = self.nd, self.nh - # nb x nloc x nnei x nd x (nh x 2) - g2qk = self.mapqk(g2).view(nb, nloc, nnei, nd, nh * 2) - # nb x nloc x (nh x 2) x nnei x nd + nd, nh = self.hidden_dim, self.head_num + # nf x nloc x nnei x nd x (nh x 2) + g2qk = self.mapqk(g2).view(nf, nloc, nnei, nd, nh * 2) + # nf x nloc x (nh x 2) x nnei x nd g2qk = torch.permute(g2qk, (0, 1, 4, 2, 3)) - # nb x nloc x nh x nnei x nd + # nf x nloc x nh x nnei x nd g2q, g2k = torch.split(g2qk, nh, dim=2) # g2q = torch.nn.functional.normalize(g2q, dim=-1) # g2k = torch.nn.functional.normalize(g2k, dim=-1) - # nb x nloc x nh x nnei x nnei + # nf x nloc x nh x nnei x nnei attnw = torch.matmul(g2q, torch.transpose(g2k, -1, -2)) / nd**0.5 if self.has_gate: gate = torch.matmul(h2, torch.transpose(h2, -1, -2)).unsqueeze(-3) attnw = attnw * gate - # mask the attenmap, nb x nloc x 1 x 1 x nnei + # mask the attenmap, nf x nloc x 1 x 1 x nnei attnw_mask = ~nlist_mask.unsqueeze(2).unsqueeze(2) - # mask the attenmap, nb x nloc x 1 x nnei x 1 + # mask the attenmap, nf x nloc x 1 x nnei x 1 attnw_mask_c = ~nlist_mask.unsqueeze(2).unsqueeze(-1) if self.smooth: attnw = (attnw + self.attnw_shift) * sw[:, :, None, :, None] * sw[ @@ -135,34 +163,76 @@ def forward( attnw_mask, 0.0, ) - # 
nb x nloc x nh x nnei x nnei + # nf x nloc x nh x nnei x nnei attnw = attnw.masked_fill( attnw_mask_c, 0.0, ) if self.smooth: attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :] - # nb x nloc x nnei x nnei + # nf x nloc x nnei x nnei h2h2t = torch.matmul(h2, torch.transpose(h2, -1, -2)) / 3.0**0.5 - # nb x nloc x nh x nnei x nnei + # nf x nloc x nh x nnei x nnei ret = attnw * h2h2t[:, :, None, :, :] # ret = torch.softmax(g2qk, dim=-1) - # nb x nloc x nnei x nnei x nh + # nf x nloc x nnei x nnei x nh ret = torch.permute(ret, (0, 1, 3, 4, 2)) return ret + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "Atten2Map", + "@version": 1, + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "head_num": self.head_num, + "has_gate": self.has_gate, + "smooth": self.smooth, + "attnw_shift": self.attnw_shift, + "precision": self.precision, + "mapqk": self.mapqk.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "Atten2Map": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + mapqk = data.pop("mapqk") + obj = cls(**data) + obj.mapqk = MLPLayer.deserialize(mapqk) + return obj + class Atten2MultiHeadApply(torch.nn.Module): def __init__( self, - ni: int, - nh: int, + input_dim: int, + head_num: int, + precision: str = "float64", ): super().__init__() - self.ni = ni - self.nh = nh - self.mapv = SimpleLinear(ni, ni * nh, bias=False) - self.head_map = SimpleLinear(ni * nh, ni) + self.input_dim = input_dim + self.head_num = head_num + self.mapv = MLPLayer( + input_dim, input_dim * head_num, bias=False, precision=precision + ) + self.head_map = MLPLayer(input_dim * head_num, input_dim, precision=precision) + self.precision = precision def forward( self, @@ -170,7 +240,7 @@ def forward( g2: torch.Tensor, # nf x nloc x nnei x ng2 ) -> torch.Tensor: nf, nloc, nnei, ng2 = g2.shape - nh = self.nh + nh = self.head_num # nf x nloc x nnei x ng2 x nh g2v = self.mapv(g2).view(nf, nloc, nnei, ng2, nh) # nf x nloc x nh x nnei x ng2 @@ -185,17 +255,56 @@ def forward( # nf x nloc x nnei x ng2 return self.head_map(ret) + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "Atten2MultiHeadApply", + "@version": 1, + "input_dim": self.input_dim, + "head_num": self.head_num, + "precision": self.precision, + "mapv": self.mapv.serialize(), + "head_map": self.head_map.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "Atten2MultiHeadApply": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. 
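Illustrative sketch (not part of the patch): the smooth branch of the attention masking used in `Atten2Map.forward` shifts the logits before applying the switch function, so padded neighbors decay to zero without a hard -inf mask; shapes below are made up.

import torch

nf, nloc, nh, nnei = 1, 2, 2, 5
attnw = torch.rand(nf, nloc, nh, nnei, nnei)
sw = torch.rand(nf, nloc, nnei)  # switch function, 0 for padded neighbors
attnw_shift = 20.0
# shift, scale by the switch on both neighbor axes, then shift back
attnw = (attnw + attnw_shift) * sw[:, :, None, :, None] * sw[:, :, None, None, :] - attnw_shift
attnw = torch.softmax(attnw, dim=-1)
# rescale after softmax so fully padded rows and columns go smoothly to zero
attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :]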
+ """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + mapv = data.pop("mapv") + head_map = data.pop("head_map") + obj = cls(**data) + obj.mapv = MLPLayer.deserialize(mapv) + obj.head_map = MLPLayer.deserialize(head_map) + return obj + class Atten2EquiVarApply(torch.nn.Module): def __init__( self, - ni: int, - nh: int, + input_dim: int, + head_num: int, + precision: str = "float64", ): super().__init__() - self.ni = ni - self.nh = nh - self.head_map = SimpleLinear(nh, 1, bias=False) + self.input_dim = input_dim + self.head_num = head_num + self.head_map = MLPLayer(head_num, 1, bias=False, precision=precision) + self.precision = precision def forward( self, @@ -203,7 +312,7 @@ def forward( h2: torch.Tensor, # nf x nloc x nnei x 3 ) -> torch.Tensor: nf, nloc, nnei, _ = h2.shape - nh = self.nh + nh = self.head_num # nf x nloc x nh x nnei x nnei AA = torch.permute(AA, (0, 1, 4, 2, 3)) h2m = torch.unsqueeze(h2, dim=2) @@ -216,42 +325,85 @@ def forward( # nf x nloc x nnei x 3 return torch.squeeze(self.head_map(ret), dim=-1) + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "Atten2EquiVarApply", + "@version": 1, + "input_dim": self.input_dim, + "head_num": self.head_num, + "precision": self.precision, + "head_map": self.head_map.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "Atten2EquiVarApply": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + head_map = data.pop("head_map") + obj = cls(**data) + obj.head_map = MLPLayer.deserialize(head_map) + return obj + class LocalAtten(torch.nn.Module): def __init__( self, - ni: int, - nd: int, - nh: int, + input_dim: int, + hidden_dim: int, + head_num: int, smooth: bool = True, attnw_shift: float = 20.0, + precision: str = "float64", ): super().__init__() - self.ni = ni - self.nd = nd - self.nh = nh - self.mapq = SimpleLinear(ni, nd * 1 * nh, bias=False) - self.mapkv = SimpleLinear(ni, (nd + ni) * nh, bias=False) - self.head_map = SimpleLinear(ni * nh, ni) + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.head_num = head_num + self.mapq = MLPLayer( + input_dim, hidden_dim * 1 * head_num, bias=False, precision=precision + ) + self.mapkv = MLPLayer( + input_dim, + (hidden_dim + input_dim) * head_num, + bias=False, + precision=precision, + ) + self.head_map = MLPLayer(input_dim * head_num, input_dim, precision=precision) self.smooth = smooth self.attnw_shift = attnw_shift + self.precision = precision def forward( self, - g1: torch.Tensor, # nb x nloc x ng1 - gg1: torch.Tensor, # nb x nloc x nnei x ng1 - nlist_mask: torch.Tensor, # nb x nloc x nnei - sw: torch.Tensor, # nb x nloc x nnei + g1: torch.Tensor, # nf x nloc x ng1 + gg1: torch.Tensor, # nf x nloc x nnei x ng1 + nlist_mask: torch.Tensor, # nf x nloc x nnei + sw: torch.Tensor, # nf x nloc x nnei ) -> torch.Tensor: nb, nloc, nnei = nlist_mask.shape - ni, nd, nh = self.ni, self.nd, self.nh + ni, nd, nh = self.input_dim, self.hidden_dim, self.head_num assert ni == g1.shape[-1] assert ni == gg1.shape[-1] - # nb x nloc x nd x nh + # nf x nloc x nd x nh g1q = self.mapq(g1).view(nb, nloc, nd, nh) - # nb x nloc x nh x nd + # nf x nloc x nh x nd g1q = torch.permute(g1q, (0, 1, 3, 2)) - # nb x nloc x nnei x (nd+ni) x nh + # nf x nloc x nnei x 
(nd+ni) x nh gg1kv = self.mapkv(gg1).view(nb, nloc, nnei, nd + ni, nh) gg1kv = torch.permute(gg1kv, (0, 1, 4, 2, 3)) # nb x nloc x nh x nnei x nd, nb x nloc x nh x nnei x ng1 @@ -287,6 +439,49 @@ def forward( ret = self.head_map(ret) return ret + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "LocalAtten", + "@version": 1, + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "head_num": self.head_num, + "smooth": self.smooth, + "attnw_shift": self.attnw_shift, + "precision": self.precision, + "mapq": self.mapq.serialize(), + "mapkv": self.mapkv.serialize(), + "head_map": self.head_map.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "LocalAtten": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + mapq = data.pop("mapq") + mapkv = data.pop("mapkv") + head_map = data.pop("head_map") + obj = cls(**data) + obj.mapq = MLPLayer.deserialize(mapq) + obj.mapkv = MLPLayer.deserialize(mapkv) + obj.head_map = MLPLayer.deserialize(head_map) + return obj + class RepformerLayer(torch.nn.Module): def __init__( @@ -297,7 +492,7 @@ def __init__( ntypes: int, g1_dim=128, g2_dim=16, - axis_dim: int = 4, + axis_neuron: int = 4, update_chnnl_2: bool = True, do_bn_mode: str = "no", bn_momentum: float = 0.1, @@ -315,8 +510,10 @@ def __init__( attn2_has_gate: bool = False, activation_function: str = "tanh", update_style: str = "res_avg", - set_davg_zero: bool = True, # TODO smooth: bool = True, + precision: str = "float64", + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, ): super().__init__() self.epsilon = 1e-4 # protection of 1./nnei @@ -326,12 +523,12 @@ def __init__( sel = [sel] if isinstance(sel, int) else sel self.nnei = sum(sel) assert len(sel) == 1 - self.sel = torch.tensor(sel, device=env.DEVICE) + self.sel = sel self.sec = self.sel - self.axis_dim = axis_dim - self.set_davg_zero = set_davg_zero + self.axis_neuron = axis_neuron self.do_bn_mode = do_bn_mode self.bn_momentum = bn_momentum + self.activation_function = activation_function self.act = ActivationFn(activation_function) self.update_g1_has_grrg = update_g1_has_grrg self.update_g1_has_drrd = update_g1_has_drrd @@ -342,13 +539,21 @@ def __init__( self.update_g2_has_attn = update_g2_has_attn if self.update_chnnl_2 else False self.update_h2 = update_h2 if self.update_chnnl_2 else False del update_g2_has_g1g1, update_g2_has_attn, update_h2 + self.attn1_hidden = attn1_hidden + self.attn1_nhead = attn1_nhead + self.attn2_hidden = attn2_hidden + self.attn2_nhead = attn2_nhead + self.attn2_has_gate = attn2_has_gate self.update_style = update_style self.smooth = smooth self.g1_dim = g1_dim self.g2_dim = g2_dim + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.precision = precision - g1_in_dim = self.cal_1_dim(g1_dim, g2_dim, self.axis_dim) - self.linear1 = SimpleLinear(g1_in_dim, g1_dim) + g1_in_dim = self.cal_1_dim(g1_dim, g2_dim, self.axis_neuron) + self.linear1 = MLPLayer(g1_in_dim, g1_dim, precision=precision) self.linear2 = None self.proj_g1g2 = None self.proj_g1g1g2 = None @@ -360,29 +565,42 @@ def __init__( self.loc_attn = None if self.update_chnnl_2: - self.linear2 = SimpleLinear(g2_dim, g2_dim) + self.linear2 = MLPLayer(g2_dim, g2_dim, precision=precision) if self.update_g1_has_conv: - self.proj_g1g2 = 
SimpleLinear(g1_dim, g2_dim, bias=False) + self.proj_g1g2 = MLPLayer(g1_dim, g2_dim, bias=False, precision=precision) if self.update_g2_has_g1g1: - self.proj_g1g1g2 = SimpleLinear(g1_dim, g2_dim, bias=False) + self.proj_g1g1g2 = MLPLayer(g1_dim, g2_dim, bias=False, precision=precision) if self.update_g2_has_attn: self.attn2g_map = Atten2Map( - g2_dim, attn2_hidden, attn2_nhead, attn2_has_gate, self.smooth - ) - self.attn2_mh_apply = Atten2MultiHeadApply(g2_dim, attn2_nhead) - self.attn2_lm = torch.nn.LayerNorm( g2_dim, - elementwise_affine=True, - device=env.DEVICE, - dtype=env.GLOBAL_PT_FLOAT_PRECISION, + attn2_hidden, + attn2_nhead, + attn2_has_gate, + self.smooth, + precision=precision, + ) + self.attn2_mh_apply = Atten2MultiHeadApply( + g2_dim, attn2_nhead, precision=precision + ) + self.attn2_lm = LayerNorm( + g2_dim, eps=ln_eps, trainable=trainable_ln, precision=precision ) if self.update_h2: self.attn2h_map = Atten2Map( - g2_dim, attn2_hidden, attn2_nhead, attn2_has_gate, self.smooth + g2_dim, + attn2_hidden, + attn2_nhead, + attn2_has_gate, + self.smooth, + precision=precision, + ) + self.attn2_ev_apply = Atten2EquiVarApply( + g2_dim, attn2_nhead, precision=precision ) - self.attn2_ev_apply = Atten2EquiVarApply(g2_dim, attn2_nhead) if self.update_g1_has_attn: - self.loc_attn = LocalAtten(g1_dim, attn1_hidden, attn1_nhead, self.smooth) + self.loc_attn = LocalAtten( + g1_dim, attn1_hidden, attn1_nhead, self.smooth, precision=precision + ) if self.do_bn_mode == "uniform": self.bn1 = self._bn_layer() @@ -446,7 +664,7 @@ def _update_g1_conv( else: gg1 = _apply_switch(gg1, sw) invnnei = (1.0 / float(nnei)) * torch.ones( - (nb, nloc, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=gg1.device + (nb, nloc, 1), dtype=gg1.dtype, device=gg1.device ) # nb x nloc x ng2 g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei @@ -474,7 +692,7 @@ def _cal_h2g2( else: g2 = _apply_switch(g2, sw) invnnei = (1.0 / float(nnei)) * torch.ones( - (nb, nloc, 1, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=g2.device + (nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device ) # nb x nloc x 3 x ng2 h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei @@ -484,11 +702,11 @@ def _cal_grrg(self, h2g2: torch.Tensor) -> torch.Tensor: # nb x nloc x 3 x ng2 nb, nloc, _, ng2 = h2g2.shape # nb x nloc x 3 x axis - h2g2m = torch.split(h2g2, self.axis_dim, dim=-1)[0] + h2g2m = torch.split(h2g2, self.axis_neuron, dim=-1)[0] # nb x nloc x axis x ng2 g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1) # nb x nloc x (axisxng2) - g1_13 = g1_13.view(nb, nloc, self.axis_dim * ng2) + g1_13 = g1_13.view(nb, nloc, self.axis_neuron * ng2) return g1_13 def _update_g1_grrg( @@ -635,8 +853,6 @@ def forward( g1 = self._apply_bn(1, g1) if self.bn2 is not None: g2 = self._apply_bn(2, g2) - if self.update_h2: - h2 = _apply_h_norm(h2) g2_update: List[torch.Tensor] = [g2] h2_update: List[torch.Tensor] = [h2] @@ -747,3 +963,143 @@ def _bn_layer( device=env.DEVICE, dtype=env.GLOBAL_PT_FLOAT_PRECISION, ) + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. 
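Illustrative sketch (not part of the patch): the serialize/deserialize pair added to `RepformerLayer` is meant to round-trip a layer through a plain dict; the hyper-parameters below are made up and assume the PyTorch backend of deepmd-kit is importable.

from deepmd.pt.model.descriptor.repformer_layer import RepformerLayer

layer = RepformerLayer(rcut=4.0, rcut_smth=0.5, sel=[20], ntypes=2, g1_dim=32, g2_dim=8)
data = layer.serialize()  # plain dict; optional sub-networks are included only when enabled
restored = RepformerLayer.deserialize(data)
assert restored.g1_dim == layer.g1_dim and restored.g2_dim == layer.g2_dim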
+ """ + data = { + "@class": "RepformerLayer", + "@version": 1, + "rcut": self.rcut, + "rcut_smth": self.rcut_smth, + "sel": self.sel, + "ntypes": self.ntypes, + "g1_dim": self.g1_dim, + "g2_dim": self.g2_dim, + "axis_neuron": self.axis_neuron, + "update_chnnl_2": self.update_chnnl_2, + "do_bn_mode": self.do_bn_mode, + "bn_momentum": self.bn_momentum, + "update_g1_has_conv": self.update_g1_has_conv, + "update_g1_has_drrd": self.update_g1_has_drrd, + "update_g1_has_grrg": self.update_g1_has_grrg, + "update_g1_has_attn": self.update_g1_has_attn, + "update_g2_has_g1g1": self.update_g2_has_g1g1, + "update_g2_has_attn": self.update_g2_has_attn, + "update_h2": self.update_h2, + "attn1_hidden": self.attn1_hidden, + "attn1_nhead": self.attn1_nhead, + "attn2_hidden": self.attn2_hidden, + "attn2_nhead": self.attn2_nhead, + "attn2_has_gate": self.attn2_has_gate, + "activation_function": self.activation_function, + "update_style": self.update_style, + "smooth": self.smooth, + "precision": self.precision, + "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, + "linear1": self.linear1.serialize(), + } + if self.update_chnnl_2: + data.update( + { + "linear2": self.linear2.serialize(), + } + ) + if self.update_g1_has_conv: + data.update( + { + "proj_g1g2": self.proj_g1g2.serialize(), + } + ) + if self.update_g2_has_g1g1: + data.update( + { + "proj_g1g1g2": self.proj_g1g1g2.serialize(), + } + ) + if self.update_g2_has_attn: + data.update( + { + "attn2g_map": self.attn2g_map.serialize(), + "attn2_mh_apply": self.attn2_mh_apply.serialize(), + "attn2_lm": self.attn2_lm.serialize(), + } + ) + if self.update_h2: + data.update( + { + "attn2h_map": self.attn2h_map.serialize(), + "attn2_ev_apply": self.attn2_ev_apply.serialize(), + } + ) + if self.update_g1_has_attn: + data.update( + { + "loc_attn": self.loc_attn.serialize(), + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "RepformerLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. 
+ """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + linear1 = data.pop("linear1") + update_chnnl_2 = data["update_chnnl_2"] + update_g1_has_conv = data["update_g1_has_conv"] + update_g2_has_g1g1 = data["update_g2_has_g1g1"] + update_g2_has_attn = data["update_g2_has_attn"] + update_h2 = data["update_h2"] + update_g1_has_attn = data["update_g1_has_attn"] + + linear2 = data.pop("linear2", None) + proj_g1g2 = data.pop("proj_g1g2", None) + proj_g1g1g2 = data.pop("proj_g1g1g2", None) + attn2g_map = data.pop("attn2g_map", None) + attn2_mh_apply = data.pop("attn2_mh_apply", None) + attn2_lm = data.pop("attn2_lm", None) + attn2h_map = data.pop("attn2h_map", None) + attn2_ev_apply = data.pop("attn2_ev_apply", None) + loc_attn = data.pop("loc_attn", None) + + obj = cls(**data) + obj.linear1 = MLPLayer.deserialize(linear1) + if update_chnnl_2: + assert isinstance(linear2, dict) + obj.linear2 = MLPLayer.deserialize(linear2) + if update_g1_has_conv: + assert isinstance(proj_g1g2, dict) + obj.proj_g1g2 = MLPLayer.deserialize(proj_g1g2) + if update_g2_has_g1g1: + assert isinstance(proj_g1g1g2, dict) + obj.proj_g1g1g2 = MLPLayer.deserialize(proj_g1g1g2) + if update_g2_has_attn: + assert isinstance(attn2g_map, dict) + assert isinstance(attn2_mh_apply, dict) + assert isinstance(attn2_lm, dict) + obj.attn2g_map = Atten2Map.deserialize(attn2g_map) + obj.attn2_mh_apply = Atten2MultiHeadApply.deserialize(attn2_mh_apply) + obj.attn2_lm = LayerNorm.deserialize(attn2_lm) + if update_h2: + assert isinstance(attn2h_map, dict) + assert isinstance(attn2_ev_apply, dict) + obj.attn2h_map = Atten2Map.deserialize(attn2h_map) + obj.attn2_ev_apply = Atten2EquiVarApply.deserialize(attn2_ev_apply) + if update_g1_has_attn: + assert isinstance(loc_attn, dict) + obj.loc_attn = LocalAtten.deserialize(loc_attn) + return obj diff --git a/deepmd/pt/model/descriptor/repformer_layer_old_impl.py b/deepmd/pt/model/descriptor/repformer_layer_old_impl.py new file mode 100644 index 0000000000..ab39fbb830 --- /dev/null +++ b/deepmd/pt/model/descriptor/repformer_layer_old_impl.py @@ -0,0 +1,743 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + List, +) + +import torch + +from deepmd.pt.model.network.network import ( + SimpleLinear, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + ActivationFn, +) + + +def _make_nei_g1( + g1_ext: torch.Tensor, + nlist: torch.Tensor, +) -> torch.Tensor: + # nlist: nb x nloc x nnei + nb, nloc, nnei = nlist.shape + # g1_ext: nb x nall x ng1 + ng1 = g1_ext.shape[-1] + # index: nb x (nloc x nnei) x ng1 + index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1) + # gg1 : nb x (nloc x nnei) x ng1 + gg1 = torch.gather(g1_ext, dim=1, index=index) + # gg1 : nb x nloc x nnei x ng1 + gg1 = gg1.view(nb, nloc, nnei, ng1) + return gg1 + + +def _apply_nlist_mask( + gg: torch.Tensor, + nlist_mask: torch.Tensor, +) -> torch.Tensor: + # gg: nf x nloc x nnei x ng + # msk: nf x nloc x nnei + return gg.masked_fill(~nlist_mask.unsqueeze(-1), 0.0) + + +def _apply_switch(gg: torch.Tensor, sw: torch.Tensor) -> torch.Tensor: + # gg: nf x nloc x nnei x ng + # sw: nf x nloc x nnei + return gg * sw.unsqueeze(-1) + + +def _apply_h_norm( + hh: torch.Tensor, # nf x nloc x nnei x 3 +) -> torch.Tensor: + """Normalize h by the std of vector length. + do not have an idea if this is a good way. 
+ """ + nf, nl, nnei, _ = hh.shape + # nf x nloc x nnei + normh = torch.linalg.norm(hh, dim=-1) + # nf x nloc + std = torch.std(normh, dim=-1) + # nf x nloc x nnei x 3 + hh = hh[:, :, :, :] / (1.0 + std[:, :, None, None]) + return hh + + +class Atten2Map(torch.nn.Module): + def __init__( + self, + ni: int, + nd: int, + nh: int, + has_gate: bool = False, # apply gate to attn map + smooth: bool = True, + attnw_shift: float = 20.0, + ): + super().__init__() + self.ni = ni + self.nd = nd + self.nh = nh + self.mapqk = SimpleLinear(ni, nd * 2 * nh, bias=False) # todo + self.has_gate = has_gate + self.smooth = smooth + self.attnw_shift = attnw_shift + + def forward( + self, + g2: torch.Tensor, # nb x nloc x nnei x ng2 + h2: torch.Tensor, # nb x nloc x nnei x 3 + nlist_mask: torch.Tensor, # nb x nloc x nnei + sw: torch.Tensor, # nb x nloc x nnei + ) -> torch.Tensor: + ( + nb, + nloc, + nnei, + _, + ) = g2.shape + nd, nh = self.nd, self.nh + # nb x nloc x nnei x nd x (nh x 2) + g2qk = self.mapqk(g2).view(nb, nloc, nnei, nd, nh * 2) + # nb x nloc x (nh x 2) x nnei x nd + g2qk = torch.permute(g2qk, (0, 1, 4, 2, 3)) + # nb x nloc x nh x nnei x nd + g2q, g2k = torch.split(g2qk, nh, dim=2) + # g2q = torch.nn.functional.normalize(g2q, dim=-1) + # g2k = torch.nn.functional.normalize(g2k, dim=-1) + # nb x nloc x nh x nnei x nnei + attnw = torch.matmul(g2q, torch.transpose(g2k, -1, -2)) / nd**0.5 + if self.has_gate: + gate = torch.matmul(h2, torch.transpose(h2, -1, -2)).unsqueeze(-3) + attnw = attnw * gate + # mask the attenmap, nb x nloc x 1 x 1 x nnei + attnw_mask = ~nlist_mask.unsqueeze(2).unsqueeze(2) + # mask the attenmap, nb x nloc x 1 x nnei x 1 + attnw_mask_c = ~nlist_mask.unsqueeze(2).unsqueeze(-1) + if self.smooth: + attnw = (attnw + self.attnw_shift) * sw[:, :, None, :, None] * sw[ + :, :, None, None, : + ] - self.attnw_shift + else: + attnw = attnw.masked_fill( + attnw_mask, + float("-inf"), + ) + attnw = torch.softmax(attnw, dim=-1) + attnw = attnw.masked_fill( + attnw_mask, + 0.0, + ) + # nb x nloc x nh x nnei x nnei + attnw = attnw.masked_fill( + attnw_mask_c, + 0.0, + ) + if self.smooth: + attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :] + # nb x nloc x nnei x nnei + h2h2t = torch.matmul(h2, torch.transpose(h2, -1, -2)) / 3.0**0.5 + # nb x nloc x nh x nnei x nnei + ret = attnw * h2h2t[:, :, None, :, :] + # ret = torch.softmax(g2qk, dim=-1) + # nb x nloc x nnei x nnei x nh + ret = torch.permute(ret, (0, 1, 3, 4, 2)) + return ret + + +class Atten2MultiHeadApply(torch.nn.Module): + def __init__( + self, + ni: int, + nh: int, + ): + super().__init__() + self.ni = ni + self.nh = nh + self.mapv = SimpleLinear(ni, ni * nh, bias=False) + self.head_map = SimpleLinear(ni * nh, ni) + + def forward( + self, + AA: torch.Tensor, # nf x nloc x nnei x nnei x nh + g2: torch.Tensor, # nf x nloc x nnei x ng2 + ) -> torch.Tensor: + nf, nloc, nnei, ng2 = g2.shape + nh = self.nh + # nf x nloc x nnei x ng2 x nh + g2v = self.mapv(g2).view(nf, nloc, nnei, ng2, nh) + # nf x nloc x nh x nnei x ng2 + g2v = torch.permute(g2v, (0, 1, 4, 2, 3)) + # g2v = torch.nn.functional.normalize(g2v, dim=-1) + # nf x nloc x nh x nnei x nnei + AA = torch.permute(AA, (0, 1, 4, 2, 3)) + # nf x nloc x nh x nnei x ng2 + ret = torch.matmul(AA, g2v) + # nf x nloc x nnei x ng2 x nh + ret = torch.permute(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, (ng2 * nh)) + # nf x nloc x nnei x ng2 + return self.head_map(ret) + + +class Atten2EquiVarApply(torch.nn.Module): + def __init__( + self, + ni: int, + nh: int, + ): + 
super().__init__() + self.ni = ni + self.nh = nh + self.head_map = SimpleLinear(nh, 1, bias=False) + + def forward( + self, + AA: torch.Tensor, # nf x nloc x nnei x nnei x nh + h2: torch.Tensor, # nf x nloc x nnei x 3 + ) -> torch.Tensor: + nf, nloc, nnei, _ = h2.shape + nh = self.nh + # nf x nloc x nh x nnei x nnei + AA = torch.permute(AA, (0, 1, 4, 2, 3)) + h2m = torch.unsqueeze(h2, dim=2) + # nf x nloc x nh x nnei x 3 + h2m = torch.tile(h2m, [1, 1, nh, 1, 1]) + # nf x nloc x nh x nnei x 3 + ret = torch.matmul(AA, h2m) + # nf x nloc x nnei x 3 x nh + ret = torch.permute(ret, (0, 1, 3, 4, 2)).view(nf, nloc, nnei, 3, nh) + # nf x nloc x nnei x 3 + return torch.squeeze(self.head_map(ret), dim=-1) + + +class LocalAtten(torch.nn.Module): + def __init__( + self, + ni: int, + nd: int, + nh: int, + smooth: bool = True, + attnw_shift: float = 20.0, + ): + super().__init__() + self.ni = ni + self.nd = nd + self.nh = nh + self.mapq = SimpleLinear(ni, nd * 1 * nh, bias=False) + self.mapkv = SimpleLinear(ni, (nd + ni) * nh, bias=False) + self.head_map = SimpleLinear(ni * nh, ni) + self.smooth = smooth + self.attnw_shift = attnw_shift + + def forward( + self, + g1: torch.Tensor, # nb x nloc x ng1 + gg1: torch.Tensor, # nb x nloc x nnei x ng1 + nlist_mask: torch.Tensor, # nb x nloc x nnei + sw: torch.Tensor, # nb x nloc x nnei + ) -> torch.Tensor: + nb, nloc, nnei = nlist_mask.shape + ni, nd, nh = self.ni, self.nd, self.nh + assert ni == g1.shape[-1] + assert ni == gg1.shape[-1] + # nb x nloc x nd x nh + g1q = self.mapq(g1).view(nb, nloc, nd, nh) + # nb x nloc x nh x nd + g1q = torch.permute(g1q, (0, 1, 3, 2)) + # nb x nloc x nnei x (nd+ni) x nh + gg1kv = self.mapkv(gg1).view(nb, nloc, nnei, nd + ni, nh) + gg1kv = torch.permute(gg1kv, (0, 1, 4, 2, 3)) + # nb x nloc x nh x nnei x nd, nb x nloc x nh x nnei x ng1 + gg1k, gg1v = torch.split(gg1kv, [nd, ni], dim=-1) + + # nb x nloc x nh x 1 x nnei + attnw = torch.matmul(g1q.unsqueeze(-2), torch.transpose(gg1k, -1, -2)) / nd**0.5 + # nb x nloc x nh x nnei + attnw = attnw.squeeze(-2) + # mask the attenmap, nb x nloc x 1 x nnei + attnw_mask = ~nlist_mask.unsqueeze(-2) + # nb x nloc x nh x nnei + if self.smooth: + attnw = (attnw + self.attnw_shift) * sw.unsqueeze(-2) - self.attnw_shift + else: + attnw = attnw.masked_fill( + attnw_mask, + float("-inf"), + ) + attnw = torch.softmax(attnw, dim=-1) + attnw = attnw.masked_fill( + attnw_mask, + 0.0, + ) + if self.smooth: + attnw = attnw * sw.unsqueeze(-2) + + # nb x nloc x nh x ng1 + ret = ( + torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni) + ) + # nb x nloc x ng1 + ret = self.head_map(ret) + return ret + + +class RepformerLayer(torch.nn.Module): + def __init__( + self, + rcut, + rcut_smth, + sel: int, + ntypes: int, + g1_dim=128, + g2_dim=16, + axis_neuron: int = 4, + update_chnnl_2: bool = True, + do_bn_mode: str = "no", + bn_momentum: float = 0.1, + update_g1_has_conv: bool = True, + update_g1_has_drrd: bool = True, + update_g1_has_grrg: bool = True, + update_g1_has_attn: bool = True, + update_g2_has_g1g1: bool = True, + update_g2_has_attn: bool = True, + update_h2: bool = False, + attn1_hidden: int = 64, + attn1_nhead: int = 4, + attn2_hidden: int = 16, + attn2_nhead: int = 4, + attn2_has_gate: bool = False, + activation_function: str = "tanh", + update_style: str = "res_avg", + set_davg_zero: bool = True, # TODO + smooth: bool = True, + ): + super().__init__() + self.epsilon = 1e-4 # protection of 1./nnei + self.rcut = rcut + self.rcut_smth = rcut_smth + self.ntypes = ntypes + sel = 
[sel] if isinstance(sel, int) else sel + self.nnei = sum(sel) + assert len(sel) == 1 + self.sel = torch.tensor(sel, device=env.DEVICE) + self.sec = self.sel + self.axis_neuron = axis_neuron + self.set_davg_zero = set_davg_zero + self.do_bn_mode = do_bn_mode + self.bn_momentum = bn_momentum + self.act = ActivationFn(activation_function) + self.update_g1_has_grrg = update_g1_has_grrg + self.update_g1_has_drrd = update_g1_has_drrd + self.update_g1_has_conv = update_g1_has_conv + self.update_g1_has_attn = update_g1_has_attn + self.update_chnnl_2 = update_chnnl_2 + self.update_g2_has_g1g1 = update_g2_has_g1g1 if self.update_chnnl_2 else False + self.update_g2_has_attn = update_g2_has_attn if self.update_chnnl_2 else False + self.update_h2 = update_h2 if self.update_chnnl_2 else False + del update_g2_has_g1g1, update_g2_has_attn, update_h2 + self.update_style = update_style + self.smooth = smooth + self.g1_dim = g1_dim + self.g2_dim = g2_dim + + g1_in_dim = self.cal_1_dim(g1_dim, g2_dim, self.axis_neuron) + self.linear1 = SimpleLinear(g1_in_dim, g1_dim) + self.linear2 = None + self.proj_g1g2 = None + self.proj_g1g1g2 = None + self.attn2g_map = None + self.attn2_mh_apply = None + self.attn2_lm = None + self.attn2h_map = None + self.attn2_ev_apply = None + self.loc_attn = None + + if self.update_chnnl_2: + self.linear2 = SimpleLinear(g2_dim, g2_dim) + if self.update_g1_has_conv: + self.proj_g1g2 = SimpleLinear(g1_dim, g2_dim, bias=False) + if self.update_g2_has_g1g1: + self.proj_g1g1g2 = SimpleLinear(g1_dim, g2_dim, bias=False) + if self.update_g2_has_attn: + self.attn2g_map = Atten2Map( + g2_dim, attn2_hidden, attn2_nhead, attn2_has_gate, self.smooth + ) + self.attn2_mh_apply = Atten2MultiHeadApply(g2_dim, attn2_nhead) + self.attn2_lm = torch.nn.LayerNorm( + g2_dim, + elementwise_affine=True, + device=env.DEVICE, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + ) + if self.update_h2: + self.attn2h_map = Atten2Map( + g2_dim, attn2_hidden, attn2_nhead, attn2_has_gate, self.smooth + ) + self.attn2_ev_apply = Atten2EquiVarApply(g2_dim, attn2_nhead) + if self.update_g1_has_attn: + self.loc_attn = LocalAtten(g1_dim, attn1_hidden, attn1_nhead, self.smooth) + + if self.do_bn_mode == "uniform": + self.bn1 = self._bn_layer() + self.bn2 = self._bn_layer() + elif self.do_bn_mode == "component": + self.bn1 = self._bn_layer(nf=g1_dim) + self.bn2 = self._bn_layer(nf=g2_dim) + elif self.do_bn_mode == "no": + self.bn1, self.bn2 = None, None + else: + raise RuntimeError(f"unknown bn_mode {self.do_bn_mode}") + + def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: + ret = g1d + if self.update_g1_has_grrg: + ret += g2d * ax + if self.update_g1_has_drrd: + ret += g1d * ax + if self.update_g1_has_conv: + ret += g2d + return ret + + def _update_h2( + self, + g2: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + ) -> torch.Tensor: + assert self.attn2h_map is not None + assert self.attn2_ev_apply is not None + nb, nloc, nnei, _ = g2.shape + # # nb x nloc x nnei x nh2 + # h2_1 = self.attn2_ev_apply(AA, h2) + # h2_update.append(h2_1) + # nb x nloc x nnei x nnei x nh + AAh = self.attn2h_map(g2, h2, nlist_mask, sw) + # nb x nloc x nnei x nh2 + h2_1 = self.attn2_ev_apply(AAh, h2) + return h2_1 + + def _update_g1_conv( + self, + gg1: torch.Tensor, + g2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + ) -> torch.Tensor: + assert self.proj_g1g2 is not None + nb, nloc, nnei, _ = g2.shape + ng1 = gg1.shape[-1] + ng2 = g2.shape[-1] + # gg1 : nb x nloc x nnei x ng2 + gg1 = 
self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2) + # nb x nloc x nnei x ng2 + gg1 = _apply_nlist_mask(gg1, nlist_mask) + if not self.smooth: + # normalized by number of neighbors, not smooth + # nb x nloc x 1 + invnnei = 1.0 / (self.epsilon + torch.sum(nlist_mask, dim=-1)).unsqueeze(-1) + else: + gg1 = _apply_switch(gg1, sw) + invnnei = (1.0 / float(nnei)) * torch.ones( + (nb, nloc, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=gg1.device + ) + # nb x nloc x ng2 + g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei + return g1_11 + + def _cal_h2g2( + self, + g2: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + ) -> torch.Tensor: + # g2: nf x nloc x nnei x ng2 + # h2: nf x nloc x nnei x 3 + # msk: nf x nloc x nnei + nb, nloc, nnei, _ = g2.shape + ng2 = g2.shape[-1] + # nb x nloc x nnei x ng2 + g2 = _apply_nlist_mask(g2, nlist_mask) + if not self.smooth: + # nb x nloc + invnnei = 1.0 / (self.epsilon + torch.sum(nlist_mask, dim=-1)) + # nb x nloc x 1 x 1 + invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) + else: + g2 = _apply_switch(g2, sw) + invnnei = (1.0 / float(nnei)) * torch.ones( + (nb, nloc, 1, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=g2.device + ) + # nb x nloc x 3 x ng2 + h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei + return h2g2 + + def _cal_grrg(self, h2g2: torch.Tensor) -> torch.Tensor: + # nb x nloc x 3 x ng2 + nb, nloc, _, ng2 = h2g2.shape + # nb x nloc x 3 x axis + h2g2m = torch.split(h2g2, self.axis_neuron, dim=-1)[0] + # nb x nloc x axis x ng2 + g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1) + # nb x nloc x (axisxng2) + g1_13 = g1_13.view(nb, nloc, self.axis_neuron * ng2) + return g1_13 + + def _update_g1_grrg( + self, + g2: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + ) -> torch.Tensor: + # g2: nf x nloc x nnei x ng2 + # h2: nf x nloc x nnei x 3 + # msk: nf x nloc x nnei + nb, nloc, nnei, _ = g2.shape + ng2 = g2.shape[-1] + # nb x nloc x 3 x ng2 + h2g2 = self._cal_h2g2(g2, h2, nlist_mask, sw) + # nb x nloc x (axisxng2) + g1_13 = self._cal_grrg(h2g2) + return g1_13 + + def _update_g2_g1g1( + self, + g1: torch.Tensor, # nb x nloc x ng1 + gg1: torch.Tensor, # nb x nloc x nnei x ng1 + nlist_mask: torch.Tensor, # nb x nloc x nnei + sw: torch.Tensor, # nb x nloc x nnei + ) -> torch.Tensor: + ret = g1.unsqueeze(-2) * gg1 + # nb x nloc x nnei x ng1 + ret = _apply_nlist_mask(ret, nlist_mask) + if self.smooth: + ret = _apply_switch(ret, sw) + return ret + + def _apply_bn( + self, + bn_number: int, + gg: torch.Tensor, + ): + if self.do_bn_mode == "uniform": + return self._apply_bn_uni(bn_number, gg) + elif self.do_bn_mode == "component": + return self._apply_bn_comp(bn_number, gg) + else: + return gg + + def _apply_nb_1(self, bn_number: int, gg: torch.Tensor) -> torch.Tensor: + nb, nl, nf = gg.shape + gg = gg.view([nb, 1, nl * nf]) + if bn_number == 1: + assert self.bn1 is not None + gg = self.bn1(gg) + else: + assert self.bn2 is not None + gg = self.bn2(gg) + return gg.view([nb, nl, nf]) + + def _apply_nb_2( + self, + bn_number: int, + gg: torch.Tensor, + ) -> torch.Tensor: + nb, nl, nnei, nf = gg.shape + gg = gg.view([nb, 1, nl * nnei * nf]) + if bn_number == 1: + assert self.bn1 is not None + gg = self.bn1(gg) + else: + assert self.bn2 is not None + gg = self.bn2(gg) + return gg.view([nb, nl, nnei, nf]) + + def _apply_bn_uni( + self, + bn_number: int, + gg: torch.Tensor, + mode: str = "1", + ) -> torch.Tensor: + if len(gg.shape) == 3: + return self._apply_nb_1(bn_number, 
gg) + elif len(gg.shape) == 4: + return self._apply_nb_2(bn_number, gg) + else: + raise RuntimeError(f"unsupported input shape {gg.shape}") + + def _apply_bn_comp( + self, + bn_number: int, + gg: torch.Tensor, + ) -> torch.Tensor: + ss = gg.shape + nf = ss[-1] + gg = gg.view([-1, nf]) + if bn_number == 1: + assert self.bn1 is not None + gg = self.bn1(gg).view(ss) + else: + assert self.bn2 is not None + gg = self.bn2(gg).view(ss) + return gg + + def forward( + self, + g1_ext: torch.Tensor, # nf x nall x ng1 + g2: torch.Tensor, # nf x nloc x nnei x ng2 + h2: torch.Tensor, # nf x nloc x nnei x 3 + nlist: torch.Tensor, # nf x nloc x nnei + nlist_mask: torch.Tensor, # nf x nloc x nnei + sw: torch.Tensor, # switch func, nf x nloc x nnei + ): + """ + Parameters + ---------- + g1_ext : nf x nall x ng1 extended single-atom chanel + g2 : nf x nloc x nnei x ng2 pair-atom channel, invariant + h2 : nf x nloc x nnei x 3 pair-atom channel, equivariant + nlist : nf x nloc x nnei neighbor list (padded neis are set to 0) + nlist_mask : nf x nloc x nnei masks of the neighbor list. real nei 1 otherwise 0 + sw : nf x nloc x nnei switch function + + Returns + ------- + g1: nf x nloc x ng1 updated single-atom chanel + g2: nf x nloc x nnei x ng2 updated pair-atom channel, invariant + h2: nf x nloc x nnei x 3 updated pair-atom channel, equivariant + """ + cal_gg1 = ( + self.update_g1_has_drrd + or self.update_g1_has_conv + or self.update_g1_has_attn + or self.update_g2_has_g1g1 + ) + + nb, nloc, nnei, _ = g2.shape + nall = g1_ext.shape[1] + g1, _ = torch.split(g1_ext, [nloc, nall - nloc], dim=1) + assert (nb, nloc) == g1.shape[:2] + assert (nb, nloc, nnei) == h2.shape[:3] + ng1 = g1.shape[-1] + ng2 = g2.shape[-1] + nh2 = h2.shape[-1] + + if self.bn1 is not None: + g1 = self._apply_bn(1, g1) + if self.bn2 is not None: + g2 = self._apply_bn(2, g2) + if self.update_h2: + h2 = _apply_h_norm(h2) + + g2_update: List[torch.Tensor] = [g2] + h2_update: List[torch.Tensor] = [h2] + g1_update: List[torch.Tensor] = [g1] + g1_mlp: List[torch.Tensor] = [g1] + + if cal_gg1: + gg1 = _make_nei_g1(g1_ext, nlist) + else: + gg1 = None + + if self.update_chnnl_2: + # nb x nloc x nnei x ng2 + assert self.linear2 is not None + g2_1 = self.act(self.linear2(g2)) + g2_update.append(g2_1) + + if self.update_g2_has_g1g1: + assert gg1 is not None + assert self.proj_g1g1g2 is not None + g2_update.append( + self.proj_g1g1g2(self._update_g2_g1g1(g1, gg1, nlist_mask, sw)) + ) + + if self.update_g2_has_attn: + assert self.attn2g_map is not None + assert self.attn2_mh_apply is not None + assert self.attn2_lm is not None + # nb x nloc x nnei x nnei x nh + AAg = self.attn2g_map(g2, h2, nlist_mask, sw) + # nb x nloc x nnei x ng2 + g2_2 = self.attn2_mh_apply(AAg, g2) + g2_2 = self.attn2_lm(g2_2) + g2_update.append(g2_2) + + if self.update_h2: + h2_update.append(self._update_h2(g2, h2, nlist_mask, sw)) + + if self.update_g1_has_conv: + assert gg1 is not None + g1_mlp.append(self._update_g1_conv(gg1, g2, nlist_mask, sw)) + + if self.update_g1_has_grrg: + g1_mlp.append(self._update_g1_grrg(g2, h2, nlist_mask, sw)) + + if self.update_g1_has_drrd: + assert gg1 is not None + g1_mlp.append(self._update_g1_grrg(gg1, h2, nlist_mask, sw)) + + # nb x nloc x [ng1+ng2+(axisxng2)+(axisxng1)] + # conv grrg drrd + g1_1 = self.act(self.linear1(torch.cat(g1_mlp, dim=-1))) + g1_update.append(g1_1) + + if self.update_g1_has_attn: + assert gg1 is not None + assert self.loc_attn is not None + g1_update.append(self.loc_attn(g1, gg1, nlist_mask, sw)) + + # update + if 
self.update_chnnl_2: + g2_new = self.list_update(g2_update) + h2_new = self.list_update(h2_update) + else: + g2_new, h2_new = g2, h2 + g1_new = self.list_update(g1_update) + return g1_new, g2_new, h2_new + + @torch.jit.export + def list_update_res_avg( + self, + update_list: List[torch.Tensor], + ) -> torch.Tensor: + nitem = len(update_list) + uu = update_list[0] + for ii in range(1, nitem): + uu = uu + update_list[ii] + return uu / (float(nitem) ** 0.5) + + @torch.jit.export + def list_update_res_incr(self, update_list: List[torch.Tensor]) -> torch.Tensor: + nitem = len(update_list) + uu = update_list[0] + scale = 1.0 / (float(nitem - 1) ** 0.5) if nitem > 1 else 0.0 + for ii in range(1, nitem): + uu = uu + scale * update_list[ii] + return uu + + @torch.jit.export + def list_update(self, update_list: List[torch.Tensor]) -> torch.Tensor: + if self.update_style == "res_avg": + return self.list_update_res_avg(update_list) + elif self.update_style == "res_incr": + return self.list_update_res_incr(update_list) + else: + raise RuntimeError(f"unknown update style {self.update_style}") + + def _bn_layer( + self, + nf: int = 1, + ) -> Callable: + return torch.nn.BatchNorm1d( + nf, + eps=1e-5, + momentum=self.bn_momentum, + affine=False, + track_running_stats=True, + device=env.DEVICE, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + ) diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 16a38052b1..18ed502e59 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -16,8 +16,8 @@ from deepmd.pt.model.descriptor.env_mat import ( prod_env_mat, ) -from deepmd.pt.model.network.network import ( - SimpleLinear, +from deepmd.pt.model.network.mlp import ( + MLPLayer, ) from deepmd.pt.utils import ( env, @@ -41,17 +41,7 @@ from .repformer_layer import ( RepformerLayer, ) - -mydtype = env.GLOBAL_PT_FLOAT_PRECISION -mydev = env.DEVICE - - -def torch_linear(*args, **kwargs): - return torch.nn.Linear(*args, **kwargs, dtype=mydtype, device=mydev) - - -simple_linear = SimpleLinear -mylinear = simple_linear +from .repformer_layer_old_impl import RepformerLayer as RepformerLayerOld @DescriptorBlock.register("se_repformer") @@ -66,7 +56,7 @@ def __init__( nlayers: int = 3, g1_dim=128, g2_dim=16, - axis_dim: int = 4, + axis_neuron: int = 4, direct_dist: bool = False, do_bn_mode: str = "no", bn_momentum: float = 0.1, @@ -84,24 +74,97 @@ def __init__( attn2_has_gate: bool = False, activation_function: str = "tanh", update_style: str = "res_avg", - set_davg_zero: bool = True, # TODO + set_davg_zero: bool = True, smooth: bool = True, - add_type_ebd_to_seq: bool = False, exclude_types: List[Tuple[int, int]] = [], env_protection: float = 0.0, - type: Optional[str] = None, + precision: str = "float64", + resnet_dt: bool = False, + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + old_impl: bool = False, ): - """ - smooth: - If strictly smooth, cannot be used with update_g1_has_attn - add_type_ebd_to_seq: - At the presence of seq_input (optional input to forward), - whether or not add an type embedding to seq_input. - If no seq_input is given, it has no effect. + r""" + The repformer descriptor block. + + Parameters + ---------- + rcut : float + The cut-off radius. + rcut_smth : float + Where to start smoothing. For example the 1/r term is smoothed from rcut to rcut_smth. + sel : int + Maximally possible number of selected neighbors. 
+ ntypes : int + Number of element types + nlayers : int, optional + Number of repformer layers. + g1_dim : int, optional + Dimension of the first graph convolution layer. + g2_dim : int, optional + Dimension of the second graph convolution layer. + axis_neuron : int, optional + Size of the submatrix of G (embedding matrix). + direct_dist : bool, optional + Whether to use direct distance information (1/r term) in the repformer block. + do_bn_mode : str, optional + The mode to do batch normalization in the repformer layers. Supported modes are: + -'no': Not do batch normalization. + -'uniform': Do batch normalization using scalar running momentum and learnable gamma/beta (num_features=1). + -'component': Do batch normalization using vector running momentum and learnable gamma/beta (num_features=d). + bn_momentum : float, optional + Momentum used in the batch normalization. + update_g1_has_conv : bool, optional + Whether to update the g1 rep with convolution term. + update_g1_has_drrd : bool, optional + Whether to update the g1 rep with the drrd term. + update_g1_has_grrg : bool, optional + Whether to update the g1 rep with the grrg term. + update_g1_has_attn : bool, optional + Whether to update the g1 rep with the localized self-attention. + update_g2_has_g1g1 : bool, optional + Whether to update the g2 rep with the g1xg1 term. + update_g2_has_attn : bool, optional + Whether to update the g2 rep with the gated self-attention. + update_h2 : bool, optional + Whether to update the h2 rep. + attn1_hidden : int, optional + The hidden dimension of localized self-attention to update the g1 rep. + attn1_nhead : int, optional + The number of heads in localized self-attention to update the g1 rep. + attn2_hidden : int, optional + The hidden dimension of gated self-attention to update the g2 rep. + attn2_nhead : int, optional + The number of heads in gated self-attention to update the g2 rep. + attn2_has_gate : bool, optional + Whether to use gate in the gated self-attention to update the g2 rep. + activation_function : str, optional + The activation function in the embedding net. + update_style : str, optional + Style to update a representation. + Supported options are: + -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) + -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + set_davg_zero : bool, optional + Set the normalization average to zero. + precision : str, optional + The precision of the embedding net parameters. + smooth : bool, optional + Whether to use smoothness in processes such as attention weights calculation. + exclude_types : List[List[int]], optional + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + resnet_dt : bool, optional + Whether to use a "Timestep" in the skip connection. + trainable_ln : bool, optional + Whether to use trainable shift and scale weights in layer normalization. + ln_eps : float, optional + The epsilon value for layer normalization. 
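Illustrative sketch (not part of the patch): a tiny numeric rendering of the two update styles listed above, using made-up tensors.

import torch

u, u1, u2 = (torch.full((3,), v) for v in (1.0, 2.0, 3.0))
res_avg = (u + u1 + u2) / (3 ** 0.5)   # 'res_avg': divide the whole sum by sqrt(n+1), here n = 2 updates
res_incr = u + (u1 + u2) / (2 ** 0.5)  # 'res_incr': scale only the increments by 1/sqrt(n)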
""" super().__init__() - del type - self.epsilon = 1e-4 # protection of 1./nnei self.rcut = rcut self.rcut_smth = rcut_smth self.ntypes = ntypes @@ -113,54 +176,113 @@ def __init__( self.sel = sel self.sec = self.sel self.split_sel = self.sel - self.axis_dim = axis_dim + self.axis_neuron = axis_neuron self.set_davg_zero = set_davg_zero self.g1_dim = g1_dim self.g2_dim = g2_dim - self.act = ActivationFn(activation_function) + self.update_g1_has_conv = update_g1_has_conv + self.update_g1_has_drrd = update_g1_has_drrd + self.update_g1_has_grrg = update_g1_has_grrg + self.update_g1_has_attn = update_g1_has_attn + self.update_g2_has_g1g1 = update_g2_has_g1g1 + self.update_g2_has_attn = update_g2_has_attn + self.update_h2 = update_h2 + self.attn1_hidden = attn1_hidden + self.attn1_nhead = attn1_nhead + self.attn2_has_gate = attn2_has_gate + self.attn2_hidden = attn2_hidden + self.attn2_nhead = attn2_nhead + self.do_bn_mode = do_bn_mode + self.bn_momentum = bn_momentum + self.activation_function = activation_function + self.update_style = update_style self.direct_dist = direct_dist - self.add_type_ebd_to_seq = add_type_ebd_to_seq + self.act = ActivationFn(activation_function) + self.smooth = smooth # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) self.env_protection = env_protection + self.precision = precision + self.resnet_dt = resnet_dt + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.old_impl = old_impl - self.g2_embd = mylinear(1, self.g2_dim) + self.g2_embd = MLPLayer(1, self.g2_dim) layers = [] for ii in range(nlayers): - layers.append( - RepformerLayer( - rcut, - rcut_smth, - sel, - ntypes, - self.g1_dim, - self.g2_dim, - axis_dim=self.axis_dim, - update_chnnl_2=(ii != nlayers - 1), - do_bn_mode=do_bn_mode, - bn_momentum=bn_momentum, - update_g1_has_conv=update_g1_has_conv, - update_g1_has_drrd=update_g1_has_drrd, - update_g1_has_grrg=update_g1_has_grrg, - update_g1_has_attn=update_g1_has_attn, - update_g2_has_g1g1=update_g2_has_g1g1, - update_g2_has_attn=update_g2_has_attn, - update_h2=update_h2, - attn1_hidden=attn1_hidden, - attn1_nhead=attn1_nhead, - attn2_has_gate=attn2_has_gate, - attn2_hidden=attn2_hidden, - attn2_nhead=attn2_nhead, - activation_function=activation_function, - update_style=update_style, - smooth=smooth, + if self.old_impl: + layers.append( + RepformerLayerOld( + self.rcut, + self.rcut_smth, + self.sel, + self.ntypes, + self.g1_dim, + self.g2_dim, + axis_neuron=self.axis_neuron, + update_chnnl_2=(ii != nlayers - 1), + do_bn_mode=self.do_bn_mode, + bn_momentum=self.bn_momentum, + update_g1_has_conv=self.update_g1_has_conv, + update_g1_has_drrd=self.update_g1_has_drrd, + update_g1_has_grrg=self.update_g1_has_grrg, + update_g1_has_attn=self.update_g1_has_attn, + update_g2_has_g1g1=self.update_g2_has_g1g1, + update_g2_has_attn=self.update_g2_has_attn, + update_h2=self.update_h2, + attn1_hidden=self.attn1_hidden, + attn1_nhead=self.attn1_nhead, + attn2_has_gate=self.attn2_has_gate, + attn2_hidden=self.attn2_hidden, + attn2_nhead=self.attn2_nhead, + activation_function=self.activation_function, + update_style=self.update_style, + smooth=self.smooth, + ) + ) + else: + layers.append( + RepformerLayer( + self.rcut, + self.rcut_smth, + self.sel, + self.ntypes, + self.g1_dim, + self.g2_dim, + axis_neuron=self.axis_neuron, + update_chnnl_2=(ii != nlayers - 1), + do_bn_mode=self.do_bn_mode, + bn_momentum=self.bn_momentum, + update_g1_has_conv=self.update_g1_has_conv, + update_g1_has_drrd=self.update_g1_has_drrd, + 
update_g1_has_grrg=self.update_g1_has_grrg, + update_g1_has_attn=self.update_g1_has_attn, + update_g2_has_g1g1=self.update_g2_has_g1g1, + update_g2_has_attn=self.update_g2_has_attn, + update_h2=self.update_h2, + attn1_hidden=self.attn1_hidden, + attn1_nhead=self.attn1_nhead, + attn2_has_gate=self.attn2_has_gate, + attn2_hidden=self.attn2_hidden, + attn2_nhead=self.attn2_nhead, + activation_function=self.activation_function, + update_style=self.update_style, + smooth=self.smooth, + trainable_ln=self.trainable_ln, + ln_eps=self.ln_eps, + precision=precision, + ) ) - ) self.layers = torch.nn.ModuleList(layers) - sshape = (self.ntypes, self.nnei, 4) - mean = torch.zeros(sshape, dtype=mydtype, device=mydev) - stddev = torch.ones(sshape, dtype=mydtype, device=mydev) + wanted_shape = (self.ntypes, self.nnei, 4) + mean = torch.zeros( + wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) + stddev = torch.ones( + wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) self.stats = None @@ -193,6 +315,22 @@ def get_dim_emb(self) -> int: """Returns the embedding dimension g2.""" return self.g2_dim + def __setitem__(self, key, value): + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + def mixed_types(self) -> bool: """If true, the discriptor 1. assumes total number of atoms aligned across frames; @@ -240,7 +378,10 @@ def forward( nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 atype = extended_atype[:, :nloc] - # nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1 + # # nf x nloc x nnei + exclude_mask = self.emask(nlist, extended_atype) + nlist = nlist * exclude_mask + # nf x nloc x nnei x 4, nf x nloc x nnei x 3, nf x nloc x nnei x 1 dmatrix, diff, sw = prod_env_mat( extended_coord, nlist, diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 9bf4788bf2..6fedb60d38 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -761,7 +761,7 @@ def forward( sw: Optional[torch.Tensor] = None, ): residual = x - x = self.attention_layer(x, nei_mask, input_r=input_r, sw=sw) + x, _ = self.attention_layer(x, nei_mask, input_r=input_r, sw=sw) x = residual + x x = self.attn_layer_norm(x) return x @@ -808,12 +808,171 @@ def deserialize(cls, data: dict) -> "NeighborGatedAttentionLayer": return obj +# class GatedAttentionLayer(nn.Module): +# def __init__( +# self, +# nnei: int, +# embed_dim: int, +# hidden_dim: int, +# dotr: bool = False, +# do_mask: bool = False, +# scaling_factor: float = 1.0, +# normalize: bool = True, +# temperature: Optional[float] = None, +# bias: bool = True, +# smooth: bool = True, +# precision: str = DEFAULT_PRECISION, +# ): +# """Construct a neighbor-wise attention net.""" +# super().__init__() +# self.nnei = nnei +# self.embed_dim = embed_dim +# self.hidden_dim = hidden_dim +# self.dotr = dotr +# self.do_mask = do_mask +# self.bias = bias +# self.smooth = smooth +# self.scaling_factor = scaling_factor +# self.temperature = temperature +# self.precision = precision +# if temperature is None: +# self.scaling = (self.hidden_dim * scaling_factor) ** -0.5 +# 
else: +# self.scaling = temperature +# self.normalize = normalize +# self.in_proj = MLPLayer( +# embed_dim, +# hidden_dim * 3, +# bias=bias, +# use_timestep=False, +# bavg=0.0, +# stddev=1.0, +# precision=precision, +# ) +# self.out_proj = MLPLayer( +# hidden_dim, +# embed_dim, +# bias=bias, +# use_timestep=False, +# bavg=0.0, +# stddev=1.0, +# precision=precision, +# ) +# +# def forward( +# self, +# query, +# nei_mask, +# input_r: Optional[torch.Tensor] = None, +# sw: Optional[torch.Tensor] = None, +# attnw_shift: float = 20.0, +# ): +# """Compute the gated self-attention. +# +# Parameters +# ---------- +# query +# inputs with shape: (nf x nloc) x nnei x embed_dim. +# nei_mask +# neighbor mask, with paddings being 0. shape: (nf x nloc) x nnei. +# input_r +# normalized radial. shape: (nf x nloc) x nnei x 3. +# sw +# The smooth switch function. shape: (nf x nloc) x nnei +# attnw_shift : float +# The attention weight shift to preserve smoothness when doing padding before softmax. +# """ +# q, k, v = self.in_proj(query).chunk(3, dim=-1) +# # [nframes * nloc, nnei, hidden_dim] +# q = q.view(-1, self.nnei, self.hidden_dim) +# k = k.view(-1, self.nnei, self.hidden_dim) +# v = v.view(-1, self.nnei, self.hidden_dim) +# if self.normalize: +# q = torch_func.normalize(q, dim=-1) +# k = torch_func.normalize(k, dim=-1) +# v = torch_func.normalize(v, dim=-1) +# q = q * self.scaling +# k = k.transpose(1, 2) +# # [nframes * nloc, nnei, nnei] +# attn_weights = torch.bmm(q, k) +# # [nframes * nloc, nnei] +# nei_mask = nei_mask.view(-1, self.nnei) +# if self.smooth: +# # [nframes * nloc, nnei] +# assert sw is not None +# sw = sw.view([-1, self.nnei]) +# attn_weights = (attn_weights + attnw_shift) * sw[:, :, None] * sw[ +# :, None, : +# ] - attnw_shift +# else: +# attn_weights = attn_weights.masked_fill( +# ~nei_mask.unsqueeze(1), float("-inf") +# ) +# attn_weights = torch_func.softmax(attn_weights, dim=-1) +# attn_weights = attn_weights.masked_fill(~nei_mask.unsqueeze(-1), 0.0) +# if self.smooth: +# assert sw is not None +# attn_weights = attn_weights * sw[:, :, None] * sw[:, None, :] +# if self.dotr: +# assert input_r is not None, "input_r must be provided when dotr is True!" +# angular_weight = torch.bmm(input_r, input_r.transpose(1, 2)) +# attn_weights = attn_weights * angular_weight +# o = torch.bmm(attn_weights, v) +# output = self.out_proj(o) +# return output +# +# def serialize(self) -> dict: +# """Serialize the networks to a dict. +# +# Returns +# ------- +# dict +# The serialized networks. +# """ +# # network_type_map_inv = {v: k for k, v in self.NETWORK_TYPE_MAP.items()} +# # network_type_name = network_type_map_inv[self.network_type] +# return { +# "nnei": self.nnei, +# "embed_dim": self.embed_dim, +# "hidden_dim": self.hidden_dim, +# "dotr": self.dotr, +# "do_mask": self.do_mask, +# "scaling_factor": self.scaling_factor, +# "normalize": self.normalize, +# "temperature": self.temperature, +# "bias": self.bias, +# "smooth": self.smooth, +# "precision": self.precision, +# "in_proj": self.in_proj.serialize(), +# "out_proj": self.out_proj.serialize(), +# } +# +# @classmethod +# def deserialize(cls, data: dict) -> "GatedAttentionLayer": +# """Deserialize the networks from a dict. +# +# Parameters +# ---------- +# data : dict +# The dict to deserialize from. 
+# """ +# data = data.copy() +# in_proj = data.pop("in_proj") +# out_proj = data.pop("out_proj") +# obj = cls(**data) +# obj.in_proj = MLPLayer.deserialize(in_proj) +# obj.out_proj = MLPLayer.deserialize(out_proj) +# return obj +# + + class GatedAttentionLayer(nn.Module): def __init__( self, nnei: int, embed_dim: int, hidden_dim: int, + num_heads: int = 1, dotr: bool = False, do_mask: bool = False, scaling_factor: float = 1.0, @@ -821,13 +980,16 @@ def __init__( temperature: Optional[float] = None, bias: bool = True, smooth: bool = True, - precision: str = DEFAULT_PRECISION, + precision: str = "float", ): - """Construct a neighbor-wise attention net.""" + """Construct a multi-head neighbor-wise attention net.""" super().__init__() + assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads" self.nnei = nnei self.embed_dim = embed_dim self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads self.dotr = dotr self.do_mask = do_mask self.bias = bias @@ -835,10 +997,11 @@ def __init__( self.scaling_factor = scaling_factor self.temperature = temperature self.precision = precision - if temperature is None: - self.scaling = (self.hidden_dim * scaling_factor) ** -0.5 - else: - self.scaling = temperature + self.scaling = ( + (self.head_dim * scaling_factor) ** -0.5 + if temperature is None + else temperature + ) self.normalize = normalize self.in_proj = MLPLayer( embed_dim, @@ -867,7 +1030,7 @@ def forward( sw: Optional[torch.Tensor] = None, attnw_shift: float = 20.0, ): - """Compute the gated self-attention. + """Compute the multi-head gated self-attention. Parameters ---------- @@ -883,43 +1046,66 @@ def forward( The attention weight shift to preserve smoothness when doing padding before softmax. 
""" q, k, v = self.in_proj(query).chunk(3, dim=-1) - # [nframes * nloc, nnei, hidden_dim] - q = q.view(-1, self.nnei, self.hidden_dim) - k = k.view(-1, self.nnei, self.hidden_dim) - v = v.view(-1, self.nnei, self.hidden_dim) + + # Reshape for multi-head attention: (nf x nloc) x num_heads x nnei x head_dim + q = q.view(-1, self.nnei, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(-1, self.nnei, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(-1, self.nnei, self.num_heads, self.head_dim).transpose(1, 2) + if self.normalize: q = torch_func.normalize(q, dim=-1) k = torch_func.normalize(k, dim=-1) v = torch_func.normalize(v, dim=-1) + q = q * self.scaling - k = k.transpose(1, 2) - # [nframes * nloc, nnei, nnei] - attn_weights = torch.bmm(q, k) - # [nframes * nloc, nnei] + # (nf x nloc) x num_heads x head_dim x nnei + k = k.transpose(-2, -1) + + # Compute attention scores + # (nf x nloc) x num_heads x nnei x nnei + attn_weights = torch.matmul(q, k) + # (nf x nloc) x nnei nei_mask = nei_mask.view(-1, self.nnei) + if self.smooth: - # [nframes * nloc, nnei] assert sw is not None - sw = sw.view([-1, self.nnei]) - attn_weights = (attn_weights + attnw_shift) * sw[:, :, None] * sw[ - :, None, : + # (nf x nloc) x 1 x nnei + sw = sw.view(-1, 1, self.nnei) + attn_weights = (attn_weights + attnw_shift) * sw[:, :, :, None] * sw[ + :, :, None, : ] - attnw_shift else: + # (nf x nloc) x 1 x 1 x nnei attn_weights = attn_weights.masked_fill( - ~nei_mask.unsqueeze(1), float("-inf") + ~nei_mask.unsqueeze(1).unsqueeze(1), float("-inf") ) + attn_weights = torch_func.softmax(attn_weights, dim=-1) - attn_weights = attn_weights.masked_fill(~nei_mask.unsqueeze(-1), 0.0) + attn_weights = attn_weights.masked_fill( + ~nei_mask.unsqueeze(1).unsqueeze(-1), 0.0 + ) if self.smooth: assert sw is not None - attn_weights = attn_weights * sw[:, :, None] * sw[:, None, :] + attn_weights = attn_weights * sw[:, :, :, None] * sw[:, :, None, :] + if self.dotr: + # (nf x nloc) x nnei x 3 assert input_r is not None, "input_r must be provided when dotr is True!" - angular_weight = torch.bmm(input_r, input_r.transpose(1, 2)) + # (nf x nloc) x 1 x nnei x nnei + angular_weight = torch.matmul(input_r, input_r.transpose(-2, -1)).view( + -1, 1, self.nnei, self.nnei + ) attn_weights = attn_weights * angular_weight - o = torch.bmm(attn_weights, v) + + # Apply attention to values + # (nf x nloc) x nnei x (num_heads x head_dim) + o = ( + torch.matmul(attn_weights, v) + .transpose(1, 2) + .reshape(-1, self.nnei, self.hidden_dim) + ) output = self.out_proj(o) - return output + return output, attn_weights def serialize(self) -> dict: """Serialize the networks to a dict. @@ -929,12 +1115,11 @@ def serialize(self) -> dict: dict The serialized networks. 
""" - # network_type_map_inv = {v: k for k, v in self.NETWORK_TYPE_MAP.items()} - # network_type_name = network_type_map_inv[self.network_type] return { "nnei": self.nnei, "embed_dim": self.embed_dim, "hidden_dim": self.hidden_dim, + "num_heads": self.num_heads, "dotr": self.dotr, "do_mask": self.do_mask, "scaling_factor": self.scaling_factor, diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index 762461111e..5bd1fb0484 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -44,6 +44,28 @@ def empty_t(shape, precision): return torch.empty(shape, dtype=precision, device=device) +class Identity(nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, + xx: torch.Tensor, + ) -> torch.Tensor: + """The Identity operation layer.""" + return xx + + def serialize(self) -> dict: + return { + "@class": "Identity", + "@version": 1, + } + + @classmethod + def deserialize(cls, data: dict) -> "Identity": + return Identity() + + class MLPLayer(nn.Module): def __init__( self, @@ -56,31 +78,47 @@ def __init__( bavg: float = 0.0, stddev: float = 1.0, precision: str = DEFAULT_PRECISION, + init: str = "default", ): super().__init__() # only use_timestep when skip connection is established. self.use_timestep = use_timestep and ( num_out == num_in or num_out == num_in * 2 ) + self.num_in = num_in + self.num_out = num_out self.activate_name = activation_function self.activate = ActivationFn(self.activate_name) self.precision = precision self.prec = PRECISION_DICT[self.precision] self.matrix = nn.Parameter(data=empty_t((num_in, num_out), self.prec)) - nn.init.normal_(self.matrix.data, std=stddev / np.sqrt(num_out + num_in)) if bias: self.bias = nn.Parameter( data=empty_t([num_out], self.prec), ) - nn.init.normal_(self.bias.data, mean=bavg, std=stddev) else: self.bias = None if self.use_timestep: self.idt = nn.Parameter(data=empty_t([num_out], self.prec)) - nn.init.normal_(self.idt.data, mean=0.1, std=0.001) else: self.idt = None self.resnet = resnet + if init == "default": + self._default_normal_init(bavg=bavg, stddev=stddev) + elif init == "trunc_normal": + self._trunc_normal_init(1.0) + elif init == "relu": + self._trunc_normal_init(2.0) + elif init == "glorot": + self._glorot_uniform_init() + elif init == "gating": + self._zero_init(self.use_bias) + elif init == "kaiming_normal": + self._normal_init() + elif init == "final": + self._zero_init(False) + else: + raise ValueError(f"Unknown initialization method: {init}") def check_type_consistency(self): precision = self.precision @@ -90,8 +128,8 @@ def check_var(var): # assertion "float64" == "double" would fail assert PRECISION_DICT[var.dtype.name] is PRECISION_DICT[precision] - check_var(self.w) - check_var(self.b) + check_var(self.matrix) + check_var(self.bias) check_var(self.idt) def dim_in(self) -> int: @@ -100,6 +138,36 @@ def dim_in(self) -> int: def dim_out(self) -> int: return self.matrix.shape[1] + def _default_normal_init(self, bavg: float = 0.0, stddev: float = 1.0): + nn.init.normal_( + self.matrix.data, std=stddev / np.sqrt(self.num_out + self.num_in) + ) + if self.bias is not None: + nn.init.normal_(self.bias.data, mean=bavg, std=stddev) + if self.idt is not None: + nn.init.normal_(self.idt.data, mean=0.1, std=0.001) + + def _trunc_normal_init(self, scale=1.0): + # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) 
+ TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978 + _, fan_in = self.matrix.shape + scale = scale / max(1, fan_in) + std = (scale**0.5) / TRUNCATED_NORMAL_STDDEV_FACTOR + nn.init.trunc_normal_(self.matrix, mean=0.0, std=std) + + def _glorot_uniform_init(self): + nn.init.xavier_uniform_(self.matrix, gain=1) + + def _zero_init(self, use_bias=True): + with torch.no_grad(): + self.matrix.fill_(0.0) + if use_bias and self.bias is not None: + with torch.no_grad(): + self.bias.fill_(1.0) + + def _normal_init(self): + nn.init.kaiming_normal_(self.matrix, nonlinearity="linear") + def forward( self, xx: torch.Tensor, diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index d91a490855..5735a428e8 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -2225,6 +2225,11 @@ def update_attention_layers_serialize(self, data: dict): new_dict["attention_layers"][layer_idx]["attention_layer"].update( update_info ) + new_dict["attention_layers"][layer_idx]["attention_layer"].update( + { + "num_heads": 1, + } + ) return new_dict @classmethod diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 1e24091018..1763cc4ef8 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -262,7 +262,7 @@ def descrpt_se_a_args(): float, optional=True, default=0.0, - doc=doc_only_tf_supported + doc_env_protection, + doc=doc_only_pt_supported + doc_env_protection, ), Argument( "set_davg_zero", bool, optional=True, default=False, doc=doc_set_davg_zero @@ -341,6 +341,7 @@ def descrpt_se_r_args(): doc_seed = "Random seed for parameter initialization" doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used" + doc_env_protection = "Protection parameter to prevent division by zero errors during environment matrix calculations. For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection." return [ Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel), @@ -373,6 +374,13 @@ def descrpt_se_r_args(): Argument( "set_davg_zero", bool, optional=True, default=False, doc=doc_set_davg_zero ), + Argument( + "env_protection", + float, + optional=True, + default=0.0, + doc=doc_only_pt_supported + doc_env_protection, + ), ] @@ -404,25 +412,15 @@ def descrpt_se_atten_common_args(): doc_neuron = "Number of neurons in each hidden layers of the embedding net. When two layers are of the same size or one layer is twice as large as the previous layer, a skip connection is built." doc_axis_neuron = "Size of the submatrix of G (embedding matrix)." doc_activation_function = f'The activation function in the embedding net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.' - doc_resnet_dt = ( - doc_only_tf_supported + 'Whether to use a "Timestep" in the skip connection' - ) - doc_type_one_side = ( - doc_only_tf_supported - + r"If 'False', type embeddings of both neighbor and central atoms are considered. If 'True', only type embeddings of neighbor atoms are considered. Default is 'False'." 
- ) - doc_precision = ( - doc_only_tf_supported - + f"The precision of the embedding net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." - ) + doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection' + doc_type_one_side = r"If 'False', type embeddings of both neighbor and central atoms are considered. If 'True', only type embeddings of neighbor atoms are considered. Default is 'False'." + doc_precision = f"The precision of the embedding net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." doc_trainable = ( doc_only_tf_supported + "If the parameters in the embedding net is trainable" ) doc_seed = "Random seed for parameter initialization" - doc_exclude_types = ( - doc_only_tf_supported - + "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." - ) + doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." + doc_env_protection = "Protection parameter to prevent division by zero errors during environment matrix calculations. For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection." doc_attn = "The length of hidden vectors in attention layers" doc_attn_layer = "The number of attention layers. Note that model compression of `se_atten` is only enabled when attn_layer==0 and tebd_input_mode=='strip'" doc_attn_dotr = "Whether to do dot product with the normalized relative coordinates" @@ -466,6 +464,13 @@ def descrpt_se_atten_common_args(): default=[], doc=doc_exclude_types, ), + Argument( + "env_protection", + float, + optional=True, + default=0.0, + doc=doc_only_pt_supported + doc_env_protection, + ), Argument("attn", int, optional=True, default=128, doc=doc_attn), Argument("attn_layer", int, optional=True, default=2, doc=doc_attn_layer), Argument("attn_dotr", bool, optional=True, default=True, doc=doc_attn_dotr), @@ -475,7 +480,7 @@ def descrpt_se_atten_common_args(): @descrpt_args_plugin.register("se_atten", alias=["dpa1"]) def descrpt_se_atten_args(): - doc_smooth_type_embedding = f"Whether to use smooth process in attention weights calculation. {doc_only_tf_supported} When using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True." + doc_smooth_type_embedding = "Whether to use smooth process in attention weights calculation. When using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True." doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `se_atten` descriptor or `atom_ener` in the energy fitting is used" doc_trainable_ln = ( "Whether to use trainable shift and scale weights in layer normalization." @@ -580,87 +585,121 @@ def descrpt_se_atten_v2_args(): @descrpt_args_plugin.register("dpa2", doc=doc_only_pt_supported) def descrpt_dpa2_args(): - # Generate by GitHub Copilot - doc_repinit_rcut = "The cut-off radius of the repinit block" - doc_repinit_rcut_smth = "From this position the inverse distance smoothly decays to 0 at the cut-off. Use in the repinit block." 
- doc_repinit_nsel = "Maximally possible number of neighbors for repinit block." - doc_repformer_rcut = "The cut-off radius of the repformer block" - doc_repformer_rcut_smth = "From this position the inverse distance smoothly decays to 0 at the cut-off. Use in the repformer block." - doc_repformer_nsel = "Maximally possible number of neighbors for repformer block." - doc_tebd_dim = "The dimension of atom type embedding" - doc_concat_output_tebd = ( - "Whether to concat type embedding at the output of the descriptor." + # repinit args + doc_repinit = "(Used in the repinit block.) " + doc_repinit_rcut = f"{doc_repinit}The cut-off radius." + doc_repinit_rcut_smth = f"{doc_repinit}Where to start smoothing. For example the 1/r term is smoothed from `rcut` to `rcut_smth`." + doc_repinit_nsel = f"{doc_repinit}Maximally possible number of selected neighbors." + doc_repinit_neuron = ( + f"{doc_repinit}Number of neurons in each hidden layers of the embedding net." + f"When two layers are of the same size or one layer is twice as large as the previous layer, " + f"a skip connection is built." ) - doc_repinit_neuron = "repinit block: the number of neurons in the embedding net." doc_repinit_axis_neuron = ( - "repinit block: the number of dimension of split in the symmetrization op." + f"{doc_repinit}Size of the submatrix of G (embedding matrix)." + ) + doc_repinit_tebd_dim = f"{doc_repinit}The dimension of atom type embedding." + doc_repinit_tebd_input_mode = ( + f"{doc_repinit}The input mode of the type embedding. Supported modes are ['concat', 'strip']." + "- 'concat': Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. " + "When `type_one_side` is False, the input is `input_ij = concat([r_ij, tebd_j, tebd_i])`. When `type_one_side` is True, the input is `input_ij = concat([r_ij, tebd_j])`. " + "The output is `out_ij = embeding(input_ij)` for the pair-wise representation of atom i with neighbor j." + "- 'strip': Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. " + f"When `type_one_side` is False, the input is `input_t = concat([tebd_j, tebd_i])`. {doc_only_pt_supported} When `type_one_side` is True, the input is `input_t = tebd_j`. " + "The output is `out_ij = embeding_t(input_t) * embeding_s(r_ij) + embeding_s(r_ij)` for the pair-wise representation of atom i with neighbor j." + ) + doc_repinit_set_davg_zero = ( + f"{doc_repinit}Set the normalization average to zero. " + f"This option should be set when `atom_ener` in the energy fitting is used." + ) + doc_repinit_activation_function = f"{doc_repinit}The activation function in the embedding net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())}." + + # repformer args + doc_repformer = "(Used in the repformer block.) " + doc_repformer_rcut = f"{doc_repformer}The cut-off radius." + doc_repformer_rcut_smth = f"{doc_repformer}Where to start smoothing. For example the 1/r term is smoothed from `rcut` to `rcut_smth`." + doc_repformer_nsel = ( + f"{doc_repformer}Maximally possible number of selected neighbors." + ) + doc_repformer_nlayers = f"{doc_repformer}The number of repformer layers." + doc_repformer_g1_dim = ( + f"{doc_repformer}The dimension of invariant single-atom representation." + ) + doc_repformer_g2_dim = ( + f"{doc_repformer}The dimension of invariant pair-atom representation." 
) - doc_repinit_activation = ( - "repinit block: the activation function in the embedding net" + doc_repformer_axis_neuron = f"{doc_repformer}The number of dimension of submatrix in the symmetrization ops." + doc_repformer_direct_dist = f"{doc_repformer}Whether or not use direct distance as input for the embedding net to get g2 instead of smoothed 1/r." + doc_repformer_do_bn_mode = ( + f"{doc_repformer}The mode to do batch normalization in the repformer layers. " + f"Supported options are: " + f"-'no': Not do batch normalization." + f"-'uniform': Do batch normalization using scalar running momentum and learnable gamma/beta (num_features=1)." + f"-'component': Do batch normalization using vector running momentum and learnable gamma/beta (num_features=d)." ) - doc_repformer_nlayers = "repformers block: the number of repformer layers" - doc_repformer_g1_dim = "repformers block: the dimension of single-atom rep" - doc_repformer_g2_dim = "repformers block: the dimension of invariant pair-atom rep" - doc_repformer_axis_dim = ( - "repformers block: the number of dimension of split in the symmetrization ops." + doc_repformer_bn_momentum = ( + f"{doc_repformer}Momentum used in the batch normalization." ) - doc_repformer_do_bn_mode = "repformers block: do batch norm in the repformer layers" - doc_repformer_bn_momentum = "repformers block: moment in the batch normalization" doc_repformer_update_g1_has_conv = ( - "repformers block: update the g1 rep with convolution term" + f"{doc_repformer}Update the g1 rep with convolution term." ) doc_repformer_update_g1_has_drrd = ( - "repformers block: update the g1 rep with the drrd term" + f"{doc_repformer}Update the g1 rep with the drrd term." ) doc_repformer_update_g1_has_grrg = ( - "repformers block: update the g1 rep with the grrg term" + f"{doc_repformer}Update the g1 rep with the grrg term." ) doc_repformer_update_g1_has_attn = ( - "repformers block: update the g1 rep with the localized self-attention" + f"{doc_repformer}Update the g1 rep with the localized self-attention." ) doc_repformer_update_g2_has_g1g1 = ( - "repformers block: update the g2 rep with the g1xg1 term" + f"{doc_repformer}Update the g2 rep with the g1xg1 term." ) doc_repformer_update_g2_has_attn = ( - "repformers block: update the g2 rep with the gated self-attention" - ) - doc_repformer_update_h2 = "repformers block: update the h2 rep" - doc_repformer_attn1_hidden = ( - "repformers block: the hidden dimension of localized self-attention" - ) - doc_repformer_attn1_nhead = ( - "repformers block: the number of heads in localized self-attention" - ) - doc_repformer_attn2_hidden = ( - "repformers block: the hidden dimension of gated self-attention" - ) - doc_repformer_attn2_nhead = ( - "repformers block: the number of heads in gated self-attention" + f"{doc_repformer}Update the g2 rep with the gated self-attention." + ) + doc_repformer_update_h2 = f"{doc_repformer}Update the h2 rep." + doc_repformer_attn1_hidden = f"{doc_repformer}The hidden dimension of localized self-attention to update the g1 rep." + doc_repformer_attn1_nhead = f"{doc_repformer}The number of heads in localized self-attention to update the g1 rep." + doc_repformer_attn2_hidden = f"{doc_repformer}The hidden dimension of gated self-attention to update the g2 rep." + doc_repformer_attn2_nhead = f"{doc_repformer}The number of heads in gated self-attention to update the g2 rep." + doc_repformer_attn2_has_gate = f"{doc_repformer}Whether to use gate in the gated self-attention to update the g2 rep." 
+ doc_repformer_activation_function = f"{doc_repformer}The activation function in the embedding net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())}." + doc_repformer_update_style = ( + f"{doc_repformer}Style to update a representation. " + f"Supported options are: " + "-'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) " + "-'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n)" + ) + doc_repformer_set_davg_zero = ( + f"{doc_repformer}Set the normalization average to zero. " + f"This option should be set when `atom_ener` in the energy fitting is used." + ) + + # descriptor args + doc_concat_output_tebd = ( + "Whether to concat type embedding at the output of the descriptor." ) - doc_repformer_attn2_has_gate = ( - "repformers block: has gate in the gated self-attention" + doc_precision = f"The precision of the embedding net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." + doc_smooth = ( + "Whether to use smoothness in processes such as attention weights calculation." ) - doc_repformer_activation = "repformers block: the activation function in the MLPs." - doc_repformer_update_style = "repformers block: style of update a rep. can be res_avg or res_incr. res_avg updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) res_incr updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n)" - doc_repformer_set_davg_zero = "repformers block: set the avg to zero in statistics" - doc_repformer_add_type_ebd_to_seq = ( - "repformers block: concatenate the type embedding at the output" + doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." + doc_env_protection = "Protection parameter to prevent division by zero errors during environment matrix calculations. For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection." + doc_trainable = "If the parameters in the embedding net is trainable." + doc_seed = "Random seed for parameter initialization." + doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection.' + doc_trainable_ln = ( + "Whether to use trainable shift and scale weights in layer normalization." ) + doc_ln_eps = "The epsilon value for layer normalization. The default value for TensorFlow is set to 1e-3 to keep consistent with keras while set to 1e-5 in PyTorch and DP implementation." + doc_type_one_side = r"If true, the embedding network parameters vary by types of neighbor atoms only, so there will be $N_\text{types}$ sets of embedding network parameters. Otherwise, the embedding network parameters vary by types of centric atoms and types of neighbor atoms, so there will be $N_\text{types}^2$ sets of embedding network parameters." + doc_add_tebd_to_repinit_out = "Add type embedding to the output representation from repinit before inputting it into repformer." 
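The `env_protection` parameter documented above guards the 1/r term of the environment matrix against padded neighbors with zero distance. A minimal NumPy sketch of the idea; the function name and the exact place where the protection enters the real environment-matrix code are assumptions for illustration only:

import numpy as np

def protected_inverse_distance(rij, env_protection=0.0):
    """1/r with an additive protection so padded zero distances do not divide by zero."""
    r = np.linalg.norm(rij, axis=-1)
    return 1.0 / (r + env_protection)

# first neighbor is a padding entry sitting at the origin (r = 0)
rij = np.array([[0.0, 0.0, 0.0], [1.0, 2.0, 2.0]])
print(protected_inverse_distance(rij, env_protection=0.0))   # divide-by-zero warning, inf for the padded entry
print(protected_inverse_distance(rij, env_protection=1e-2))  # finite values everywhere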
return [ + # repinit args Argument("repinit_rcut", float, doc=doc_repinit_rcut), Argument("repinit_rcut_smth", float, doc=doc_repinit_rcut_smth), Argument("repinit_nsel", int, doc=doc_repinit_nsel), - Argument("repformer_rcut", float, doc=doc_repformer_rcut), - Argument("repformer_rcut_smth", float, doc=doc_repformer_rcut_smth), - Argument("repformer_nsel", int, doc=doc_repformer_nsel), - Argument("tebd_dim", int, optional=True, default=8, doc=doc_tebd_dim), - Argument( - "concat_output_tebd", - bool, - optional=True, - default=True, - doc=doc_concat_output_tebd, - ), Argument( "repinit_neuron", list, @@ -675,14 +714,39 @@ def descrpt_dpa2_args(): default=16, doc=doc_repinit_axis_neuron, ), - Argument("repinit_set_davg_zero", bool, optional=True, default=True), Argument( - "repinit_activation", + "repinit_tebd_dim", + int, + optional=True, + default=8, + doc=doc_repinit_tebd_dim, + ), + Argument( + "repinit_tebd_input_mode", + str, + optional=True, + default="concat", + doc=doc_repinit_tebd_input_mode, + ), + Argument( + "repinit_set_davg_zero", + bool, + optional=True, + default=True, + doc=doc_repinit_set_davg_zero, + ), + Argument( + "repinit_activation_function", str, optional=True, default="tanh", - doc=doc_repinit_activation, + alias=["repinit_activation"], + doc=doc_repinit_activation_function, ), + # repformer args + Argument("repformer_rcut", float, doc=doc_repformer_rcut), + Argument("repformer_rcut_smth", float, doc=doc_repformer_rcut_smth), + Argument("repformer_nsel", int, doc=doc_repformer_nsel), Argument( "repformer_nlayers", int, @@ -701,11 +765,19 @@ def descrpt_dpa2_args(): "repformer_g2_dim", int, optional=True, default=16, doc=doc_repformer_g2_dim ), Argument( - "repformer_axis_dim", + "repformer_axis_neuron", int, optional=True, default=4, - doc=doc_repformer_axis_dim, + alias=["repformer_axis_dim"], + doc=doc_repformer_axis_neuron, + ), + Argument( + "repformer_direct_dist", + bool, + optional=True, + default=False, + doc=doc_repformer_direct_dist, ), Argument( "repformer_do_bn_mode", @@ -806,11 +878,12 @@ def descrpt_dpa2_args(): doc=doc_repformer_attn2_has_gate, ), Argument( - "repformer_activation", + "repformer_activation_function", str, optional=True, default="tanh", - doc=doc_repformer_activation, + alias=["repformer_activation"], + doc=doc_repformer_activation_function, ), Argument( "repformer_update_style", @@ -826,12 +899,47 @@ def descrpt_dpa2_args(): default=True, doc=doc_repformer_set_davg_zero, ), + # descriptor args + Argument( + "concat_output_tebd", + bool, + optional=True, + default=True, + doc=doc_concat_output_tebd, + ), + Argument("precision", str, optional=True, default="default", doc=doc_precision), + Argument("smooth", bool, optional=True, default=False, doc=doc_smooth), + Argument( + "exclude_types", + List[List[int]], + optional=True, + default=[], + doc=doc_exclude_types, + ), + Argument( + "env_protection", + float, + optional=True, + default=0.0, + doc=doc_only_pt_supported + doc_env_protection, + ), + Argument("trainable", bool, optional=True, default=True, doc=doc_trainable), + Argument("seed", [int, None], optional=True, doc=doc_seed), + Argument("resnet_dt", bool, optional=True, default=False, doc=doc_resnet_dt), + Argument( + "trainable_ln", bool, optional=True, default=True, doc=doc_trainable_ln + ), + Argument("ln_eps", float, optional=True, default=None, doc=doc_ln_eps), + Argument( + "type_one_side", bool, optional=True, default=False, doc=doc_type_one_side + ), Argument( - "repformer_add_type_ebd_to_seq", + 
"add_tebd_to_repinit_out", bool, optional=True, default=False, - doc=doc_repformer_add_type_ebd_to_seq, + alias=["repformer_add_type_ebd_to_seq"], + doc=doc_add_tebd_to_repinit_out, ), ] diff --git a/source/tests/pt/model/test_dpa1.py b/source/tests/pt/model/test_dpa1.py index c1b6f97b26..01cbf259a7 100644 --- a/source/tests/pt/model/test_dpa1.py +++ b/source/tests/pt/model/test_dpa1.py @@ -39,12 +39,12 @@ def test_consistency( dstd = rng.normal(size=(self.nt, nnei, 4)) dstd = 0.1 + np.abs(dstd) - for idt, prec, sm, to, tm in itertools.product( + for idt, sm, to, tm, prec in itertools.product( [False, True], # resnet_dt - ["float64", "float32"], # precision [False, True], # smooth_type_embedding [False, True], # type_one_side ["concat", "strip"], # tebd_input_mode + ["float64", "float32"], # precision ): dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) @@ -65,7 +65,7 @@ def test_consistency( old_impl=False, ).to(env.DEVICE) dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) - dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) rd0, _, _, _, _ = dd0( torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), diff --git a/source/tests/pt/model/test_dpa2.py b/source/tests/pt/model/test_dpa2.py new file mode 100644 index 0000000000..4446151d30 --- /dev/null +++ b/source/tests/pt/model/test_dpa2.py @@ -0,0 +1,258 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +# from deepmd.dpmodel.descriptor.dpa2 import DescrptDPA2 as DPDescrptDPA2 +from deepmd.pt.model.descriptor.dpa2 import ( + DescrptDPA2, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from .test_mlp import ( + get_tols, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestDescrptDPA2(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ): + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + davg_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd = 0.1 + np.abs(dstd) + + for ( + riti, + riz, + rp1c, + rp1d, + rp1g, + rp1a, + rp2g, + rp2a, + rph, + rp2gate, + rpz, + sm, + prec, + ) in itertools.product( + ["concat", "strip"], # repinit_tebd_input_mode + [ + True, + ], # repinit_set_davg_zero + [True, False], # repformer_update_g1_has_conv + [True, False], # repformer_update_g1_has_drrd + [True, False], # repformer_update_g1_has_grrg + [True, False], # repformer_update_g1_has_attn + [True, False], # repformer_update_g2_has_g1g1 + [True, False], # repformer_update_g2_has_attn + [ + False, + ], # repformer_update_h2 + [True, False], # repformer_attn2_has_gate + [ + True, + ], # repformer_set_davg_zero + [True, False], # smooth + ["float64", "float32"], # precision + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + + # dpa2 new impl + dd0 = DescrptDPA2( + self.nt, + repinit_rcut=self.rcut, + repinit_rcut_smth=self.rcut_smth, + repinit_nsel=self.sel_mix, + repformer_rcut=self.rcut / 2, + repformer_rcut_smth=self.rcut_smth, + repformer_nsel=nnei // 2, + # kwargs for repinit + 
repinit_tebd_input_mode=riti, + repinit_set_davg_zero=riz, + # kwargs for repformer + repformer_nlayers=3, + repformer_g1_dim=20, + repformer_g2_dim=10, + repformer_axis_neuron=4, + repformer_update_g1_has_conv=rp1c, + repformer_update_g1_has_drrd=rp1d, + repformer_update_g1_has_grrg=rp1g, + repformer_update_g1_has_attn=rp1a, + repformer_update_g2_has_g1g1=rp2g, + repformer_update_g2_has_attn=rp2a, + repformer_update_h2=rph, + repformer_attn1_hidden=20, + repformer_attn1_nhead=2, + repformer_attn2_hidden=10, + repformer_attn2_nhead=2, + repformer_attn2_has_gate=rp2gate, + repformer_update_style="res_avg", + repformer_set_davg_zero=rpz, + # kwargs for descriptor + smooth=sm, + exclude_types=[], + add_tebd_to_repinit_out=False, + precision=prec, + old_impl=False, + ).to(env.DEVICE) + + dd0.repinit.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.repinit.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + dd0.repformers.mean = torch.tensor(davg_2, dtype=dtype, device=env.DEVICE) + dd0.repformers.stddev = torch.tensor(dstd_2, dtype=dtype, device=env.DEVICE) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + torch.tensor(self.mapping, dtype=int, device=env.DEVICE), + ) + # serialization + dd1 = DescrptDPA2.deserialize(dd0.serialize()) + rd1, _, _, _, _ = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + torch.tensor(self.mapping, dtype=int, device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + # dp impl + # dd2 = DPDescrptDPA1.deserialize(dd0.serialize()) + # rd2, _, _, _, _ = dd2.call( + # self.coord_ext, + # self.atype_ext, + # self.nlist, + # ) + # np.testing.assert_allclose( + # rd0.detach().cpu().numpy(), + # rd2, + # rtol=rtol, + # atol=atol, + # err_msg=err_msg, + # ) + # old impl + if prec == "float64" and rp1a is False and rp2a is False and rph is False: + dd3 = DescrptDPA2( + self.nt, + repinit_rcut=self.rcut, + repinit_rcut_smth=self.rcut_smth, + repinit_nsel=self.sel_mix, + repformer_rcut=self.rcut / 2, + repformer_rcut_smth=self.rcut_smth, + repformer_nsel=nnei // 2, + # kwargs for repinit + repinit_tebd_input_mode=riti, + repinit_set_davg_zero=riz, + # kwargs for repformer + repformer_nlayers=3, + repformer_g1_dim=20, + repformer_g2_dim=10, + repformer_axis_neuron=4, + repformer_update_g1_has_conv=rp1c, + repformer_update_g1_has_drrd=rp1d, + repformer_update_g1_has_grrg=rp1g, + repformer_update_g1_has_attn=rp1a, + repformer_update_g2_has_g1g1=rp2g, + repformer_update_g2_has_attn=rp2a, + repformer_update_h2=rph, + repformer_attn1_hidden=20, + repformer_attn1_nhead=2, + repformer_attn2_hidden=10, + repformer_attn2_nhead=2, + repformer_attn2_has_gate=rp2gate, + repformer_update_style="res_avg", + repformer_set_davg_zero=rpz, + # kwargs for descriptor + smooth=sm, + exclude_types=[], + add_tebd_to_repinit_out=False, + precision=prec, + old_impl=True, + ).to(env.DEVICE) + dd0_state_dict = dd0.state_dict() + dd3_state_dict = dd3.state_dict() + for i in dd0_state_dict: + if ".bias" in i and (".linear1." in i or ".linear2." 
in i): + dd0_state_dict[i] = dd0_state_dict[i].unsqueeze(0) + + dd3.load_state_dict(dd0_state_dict) + rd3, _, _, _, _ = dd3( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + torch.tensor(self.mapping, dtype=int, device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd3.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + + # def test_jit( + # self, + # ): + # rng = np.random.default_rng() + # nf, nloc, nnei = self.nlist.shape + # davg = rng.normal(size=(self.nt, nnei, 4)) + # dstd = rng.normal(size=(self.nt, nnei, 4)) + # dstd = 0.1 + np.abs(dstd) + # + # for idt, prec, sm, to, tm in itertools.product( + # [ + # False, + # ], # resnet_dt + # [ + # "float64", + # ], # precision + # [False, True], # smooth_type_embedding + # [False, True], # type_one_side + # ["concat", "strip"], # tebd_input_mode + # ): + # dtype = PRECISION_DICT[prec] + # rtol, atol = get_tols(prec) + # err_msg = f"idt={idt} prec={prec}" + # # dpa1 new impl + # dd0 = DescrptDPA2( + # self.rcut, + # self.rcut_smth, + # self.sel, + # self.nt, + # precision=prec, + # resnet_dt=idt, + # smooth_type_embedding=sm, + # type_one_side=to, + # tebd_input_mode=tm, + # old_impl=False, + # ) + # dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + # dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + # # dd1 = DescrptDPA1.deserialize(dd0.serialize()) + # model = torch.jit.script(dd0) + # # model = torch.jit.script(dd1) diff --git a/source/tests/pt/model/test_env_mat.py b/source/tests/pt/model/test_env_mat.py index 24ed886b86..cc7b426585 100644 --- a/source/tests/pt/model/test_env_mat.py +++ b/source/tests/pt/model/test_env_mat.py @@ -33,6 +33,7 @@ def setUp(self): dtype=np.float64, ).reshape([1, self.nall, 3]) self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall]) + self.mapping = np.array([0, 1, 2, 0], dtype=int).reshape([1, self.nall]) # sel = [5, 2] self.sel = [5, 2] self.sel_mix = [7] @@ -57,6 +58,10 @@ def setUp(self): self.atype_ext = np.concatenate( [self.atype_ext, self.atype_ext[:, self.perm]], axis=0 ) + self.mapping = np.concatenate( + [self.mapping, self.mapping[:, self.perm]], axis=0 + ) + # permute the nlist nlist1 = self.nlist[:, self.perm[: self.nloc], :] mask = nlist1 == -1 From b7af498bc7aba586c702a6effd76833999d07e50 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 30 Apr 2024 17:33:08 +0800 Subject: [PATCH 13/37] Update test_dpa2.py --- source/tests/pt/model/test_dpa2.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/source/tests/pt/model/test_dpa2.py b/source/tests/pt/model/test_dpa2.py index 4446151d30..49dc9f5e47 100644 --- a/source/tests/pt/model/test_dpa2.py +++ b/source/tests/pt/model/test_dpa2.py @@ -157,7 +157,7 @@ def test_consistency( # err_msg=err_msg, # ) # old impl - if prec == "float64" and rp1a is False and rp2a is False and rph is False: + if prec == "float64": dd3 = DescrptDPA2( self.nt, repinit_rcut=self.rcut, @@ -197,9 +197,15 @@ def test_consistency( ).to(env.DEVICE) dd0_state_dict = dd0.state_dict() dd3_state_dict = dd3.state_dict() - for i in dd0_state_dict: - if ".bias" in i and (".linear1." in i or ".linear2." in i): + for i in list(dd0_state_dict.keys()): + if ".bias" in i and ( + ".linear1." in i or ".linear2." in i or ".head_map." 
in i + ): dd0_state_dict[i] = dd0_state_dict[i].unsqueeze(0) + if ".attn2_lm.matrix" in i: + dd0_state_dict[ + i.replace(".attn2_lm.matrix", ".attn2_lm.weight") + ] = dd0_state_dict.pop(i) dd3.load_state_dict(dd0_state_dict) rd3, _, _, _, _ = dd3( From 61d9794d47548a013e763422bf67cf9170fc35ef Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 7 May 2024 19:54:25 +0800 Subject: [PATCH 14/37] Add residual support --- deepmd/pt/model/descriptor/dpa2.py | 40 +- deepmd/pt/model/descriptor/repformer_layer.py | 605 +++++++++++++----- deepmd/pt/model/descriptor/repformers.py | 32 +- deepmd/pt/model/network/layernorm.py | 38 +- deepmd/utils/argcheck.py | 26 + source/tests/pt/model/test_dpa2.py | 160 +++-- 6 files changed, 668 insertions(+), 233 deletions(-) diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 842c924662..c2eeef48e9 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -91,6 +91,8 @@ def __init__( repformer_attn2_has_gate: bool = False, repformer_activation_function: str = "tanh", repformer_update_style: str = "res_avg", + repformer_update_residual: float = 0.001, + repformer_update_residual_init: str = "norm", repformer_set_davg_zero: bool = True, # kwargs for descriptor concat_output_tebd: bool = True, @@ -218,6 +220,15 @@ def __init__( Supported options are: -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) + where `r1`, `r2` ... `r3` are residual weights defined by `repformer_update_residual` + and `repformer_update_residual_init`. + repformer_update_residual : float, optional + (Used in the repformer block.) + When update using residual mode, the initial std of residual vector weights. + repformer_update_residual_init : str, optional + (Used in the repformer block.) + When update using residual mode, the initialization mode of residual vector weights. repformer_set_davg_zero : bool, optional (Used in the repformer block.) Set the normalization average to zero. @@ -251,7 +262,7 @@ def __init__( Returns ------- descriptor: torch.Tensor - the descriptor of shape nb x nloc x g1_dim. + the descriptor of shape nf x nloc x g1_dim. invariant single-atom representation. g2: torch.Tensor invariant pair-atom representation. 
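The three update styles named in the docstring above reduce to simple weighted sums over the per-layer update terms. A minimal sketch that only makes the formulas concrete; `updates` (the terms u_1..u_n) and `residuals` (the per-term weights r_i) are hypothetical names, and in the patch the r_i are trainable nn.Parameter tensors built by `get_residual` rather than plain tensors:

import torch

def apply_update(u, updates, style, residuals=None):
    n = len(updates)
    if style == "res_avg":
        # u <- (u + u_1 + ... + u_n) / sqrt(n + 1)
        return (u + sum(updates)) / (n + 1) ** 0.5
    elif style == "res_incr":
        # u <- u + (u_1 + ... + u_n) / sqrt(n)
        return u + sum(updates) / n**0.5
    elif style == "res_residual":
        # u <- u + r_1 * u_1 + ... + r_n * u_n, with learnable per-term weights r_i
        assert residuals is not None and len(residuals) == n
        return u + sum(r * up for r, up in zip(residuals, updates))
    raise ValueError(f"unknown update style {style}")

u = torch.randn(2, 5, 8)                                  # e.g. nf x nloc x g1_dim
updates = [torch.randn_like(u) for _ in range(3)]
residuals = [torch.full((8,), 0.001) for _ in range(3)]   # scale ~ repformer_update_residual
out = apply_update(u, updates, "res_residual", residuals)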
@@ -264,6 +275,9 @@ def __init__( """ super().__init__() + # to keep consistent with default value in this backends + if ln_eps is None: + ln_eps = 1e-5 self.repinit = DescrptBlockSeAtten( repinit_rcut, repinit_rcut_smth, @@ -311,6 +325,8 @@ def __init__( attn2_has_gate=repformer_attn2_has_gate, activation_function=repformer_activation_function, update_style=repformer_update_style, + update_residual=repformer_update_residual, + update_residual_init=repformer_update_residual_init, set_davg_zero=repformer_set_davg_zero, smooth=smooth, exclude_types=exclude_types, @@ -346,6 +362,14 @@ def __init__( precision=precision, init="glorot", ) + self.tebd_transform = None + if self.add_tebd_to_repinit_out: + self.tebd_transform = MLPLayer( + repinit_tebd_dim, + self.repformers.dim_in, + bias=False, + precision=precision, + ) assert self.repinit.rcut > self.repformers.rcut assert self.repinit.sel[0] > self.repformers.sel[0] @@ -525,6 +549,12 @@ def serialize(self) -> dict: "type_embedding": self.type_embedding.embedding.serialize(), "g1_shape_tranform": self.g1_shape_tranform.serialize(), } + if self.add_tebd_to_repinit_out: + data.update( + { + "tebd_transform": self.tebd_transform.serialize(), + } + ) repinit_variable = { "embeddings": repinit.filter_layers.serialize(), "env_mat": DPEnvMat(repinit.rcut, repinit.rcut_smth).serialize(), @@ -564,10 +594,15 @@ def deserialize(cls, data: dict) -> "DescrptDPA2": repformers_variable = data.pop("repformers").copy() type_embedding = data.pop("type_embedding") g1_shape_tranform = data.pop("g1_shape_tranform") + tebd_transform = data.pop("tebd_transform", None) + add_tebd_to_repinit_out = data["add_tebd_to_repinit_out"] obj = cls(**data) obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize( type_embedding ) + if add_tebd_to_repinit_out: + assert isinstance(tebd_transform, dict) + obj.tebd_transform = MLPLayer.deserialize(tebd_transform) if obj.repinit.dim_out != obj.repformers.dim_in: obj.g1_shape_tranform = MLPLayer.deserialize(g1_shape_tranform) @@ -662,6 +697,9 @@ def forward( ) # linear to change shape g1 = self.g1_shape_tranform(g1) + if self.add_tebd_to_repinit_out: + assert self.tebd_transform is not None + g1 = g1 + self.tebd_transform(g1_inp) # mapping g1 assert mapping is not None mapping_ext = ( diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index 9ce590f10e..4480ed7e29 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -6,6 +6,7 @@ ) import torch +import torch.nn as nn from deepmd.pt.model.network.layernorm import ( LayerNorm, @@ -16,14 +17,57 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) from deepmd.pt.utils.utils import ( ActivationFn, + to_numpy_array, + to_torch_tensor, ) from deepmd.utils.version import ( check_version_compatibility, ) +def get_residual( + _dim: int, + _scale: float, + _mode: str = "norm", + trainable: bool = True, + precision: str = "float64", +) -> torch.Tensor: + r""" + Get residual tensor for one update vector. + + Parameters + ---------- + _dim : int + The dimension of the update vector. + _scale + The initial scale of the residual tensor. See `_mode` for details. + _mode + The mode of residual initialization for the residual tensor. + - "norm" (default): init residual using normal with `_scale` std. + - "const": init residual using element-wise constants of `_scale`. + trainable + Whether the residual tensor is trainable. 
+ precision + The precision of the residual tensor. + """ + residual = nn.Parameter( + data=torch.zeros(_dim, dtype=PRECISION_DICT[precision], device=env.DEVICE), + requires_grad=trainable, + ) + if _mode == "norm": + nn.init.normal_(residual.data, std=_scale) + elif _mode == "const": + nn.init.constant_(residual.data, val=_scale) + else: + raise RuntimeError(f"Unsupported initialization mode '{_mode}'!") + return residual + + # common ops def _make_nei_g1( g1_ext: torch.Tensor, @@ -46,15 +90,15 @@ def _make_nei_g1( """ # nlist: nf x nloc x nnei - nb, nloc, nnei = nlist.shape + nf, nloc, nnei = nlist.shape # g1_ext: nf x nall x ng1 ng1 = g1_ext.shape[-1] # index: nf x (nloc x nnei) x ng1 - index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1) + index = nlist.reshape(nf, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1) # gg1 : nf x (nloc x nnei) x ng1 gg1 = torch.gather(g1_ext, dim=1, index=index) # gg1 : nf x nloc x nnei x ng1 - gg1 = gg1.view(nb, nloc, nnei, ng1) + gg1 = gg1.view(nf, nloc, nnei, ng1) return gg1 @@ -94,6 +138,133 @@ def _apply_switch(gg: torch.Tensor, sw: torch.Tensor) -> torch.Tensor: return gg * sw.unsqueeze(-1) +def _cal_hg( + g: torch.Tensor, + h: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + smooth: bool = True, + epsilon: float = 1e-4, +) -> torch.Tensor: + """ + Calculate the transposed rotation matrix. + + Parameters + ---------- + g + Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng. + h + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nf x nloc x nnei. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. + + Returns + ------- + hg + The transposed rotation matrix, with shape nf x nloc x 3 x ng. + """ + # g: nf x nloc x nnei x ng + # h: nf x nloc x nnei x 3 + # msk: nf x nloc x nnei + nf, nloc, nnei, _ = g.shape + ng = g.shape[-1] + # nf x nloc x nnei x ng + g = _apply_nlist_mask(g, nlist_mask) + if not smooth: + # nf x nloc + invnnei = 1.0 / (epsilon + torch.sum(nlist_mask, dim=-1)) + # nf x nloc x 1 x 1 + invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) + else: + g = _apply_switch(g, sw) + invnnei = (1.0 / float(nnei)) * torch.ones( + (nf, nloc, 1, 1), dtype=g.dtype, device=g.device + ) + # nf x nloc x 3 x ng + hg = torch.matmul(torch.transpose(h, -1, -2), g) * invnnei + return hg + + +def _cal_grrg(hg: torch.Tensor, axis_neuron: int) -> torch.Tensor: + """ + Calculate the atomic invariant rep. + + Parameters + ---------- + hg + The transposed rotation matrix, with shape nf x nloc x 3 x ng. + axis_neuron + Size of the submatrix. 
+ + Returns + ------- + grrg + Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng) + """ + # nf x nloc x 3 x ng2 + nf, nloc, _, ng = hg.shape + # nf x nloc x 3 x axis + hgm = torch.split(hg, axis_neuron, dim=-1)[0] + # nf x nloc x axis_neuron x ng + grrg = torch.matmul(torch.transpose(hgm, -1, -2), hg) / (3.0**1) + # nf x nloc x (axis_neuron x ng) + grrg = grrg.view(nf, nloc, axis_neuron * ng) + return grrg + + +def symmetrization_op( + g: torch.Tensor, + h: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + axis_neuron: int, + smooth: bool = True, + epsilon: float = 1e-4, +) -> torch.Tensor: + """ + Symmetrization operator to obtain atomic invariant rep. + + Parameters + ---------- + g + Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng. + h + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nf x nloc x nnei. + axis_neuron + Size of the submatrix. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. + + Returns + ------- + grrg + Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng) + """ + # g: nf x nloc x nnei x ng + # h: nf x nloc x nnei x 3 + # msk: nf x nloc x nnei + nf, nloc, nnei, _ = g.shape + # nf x nloc x 3 x ng + hg = _cal_hg(g, h, nlist_mask, sw, smooth=smooth, epsilon=epsilon) + # nf x nloc x (axis_neuron x ng) + grrg = _cal_grrg(hg, axis_neuron) + return grrg + + class Atten2Map(torch.nn.Module): def __init__( self, @@ -395,27 +566,27 @@ def forward( nlist_mask: torch.Tensor, # nf x nloc x nnei sw: torch.Tensor, # nf x nloc x nnei ) -> torch.Tensor: - nb, nloc, nnei = nlist_mask.shape + nf, nloc, nnei = nlist_mask.shape ni, nd, nh = self.input_dim, self.hidden_dim, self.head_num assert ni == g1.shape[-1] assert ni == gg1.shape[-1] # nf x nloc x nd x nh - g1q = self.mapq(g1).view(nb, nloc, nd, nh) + g1q = self.mapq(g1).view(nf, nloc, nd, nh) # nf x nloc x nh x nd g1q = torch.permute(g1q, (0, 1, 3, 2)) # nf x nloc x nnei x (nd+ni) x nh - gg1kv = self.mapkv(gg1).view(nb, nloc, nnei, nd + ni, nh) + gg1kv = self.mapkv(gg1).view(nf, nloc, nnei, nd + ni, nh) gg1kv = torch.permute(gg1kv, (0, 1, 4, 2, 3)) - # nb x nloc x nh x nnei x nd, nb x nloc x nh x nnei x ng1 + # nf x nloc x nh x nnei x nd, nf x nloc x nh x nnei x ng1 gg1k, gg1v = torch.split(gg1kv, [nd, ni], dim=-1) - # nb x nloc x nh x 1 x nnei + # nf x nloc x nh x 1 x nnei attnw = torch.matmul(g1q.unsqueeze(-2), torch.transpose(gg1k, -1, -2)) / nd**0.5 - # nb x nloc x nh x nnei + # nf x nloc x nh x nnei attnw = attnw.squeeze(-2) - # mask the attenmap, nb x nloc x 1 x nnei + # mask the attenmap, nf x nloc x 1 x nnei attnw_mask = ~nlist_mask.unsqueeze(-2) - # nb x nloc x nh x nnei + # nf x nloc x nh x nnei if self.smooth: attnw = (attnw + self.attnw_shift) * sw.unsqueeze(-2) - self.attnw_shift else: @@ -431,11 +602,11 @@ def forward( if self.smooth: attnw = attnw * sw.unsqueeze(-2) - # nb x nloc x nh x ng1 + # nf x nloc x nh x ng1 ret = ( - torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni) + torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nf, nloc, nh * ni) ) - # nb x nloc x ng1 + # nf x nloc x ng1 ret = self.head_map(ret) return ret @@ -510,6 +681,8 @@ def __init__( 
attn2_has_gate: bool = False, activation_function: str = "tanh", update_style: str = "res_avg", + update_residual: float = 0.001, + update_residual_init: str = "norm", smooth: bool = True, precision: str = "float64", trainable_ln: bool = True, @@ -545,6 +718,8 @@ def __init__( self.attn2_nhead = attn2_nhead self.attn2_has_gate = attn2_has_gate self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init self.smooth = smooth self.g1_dim = g1_dim self.g2_dim = g2_dim @@ -552,6 +727,26 @@ def __init__( self.ln_eps = ln_eps self.precision = precision + assert update_residual_init in [ + "norm", + "const", + ], "'update_residual_init' only support 'norm' or 'const'!" + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.g1_residual = [] + self.g2_residual = [] + self.h2_residual = [] + + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + g1_in_dim = self.cal_1_dim(g1_dim, g2_dim, self.axis_neuron) self.linear1 = MLPLayer(g1_in_dim, g1_dim, precision=precision) self.linear2 = None @@ -560,17 +755,34 @@ def __init__( self.attn2g_map = None self.attn2_mh_apply = None self.attn2_lm = None - self.attn2h_map = None self.attn2_ev_apply = None self.loc_attn = None if self.update_chnnl_2: self.linear2 = MLPLayer(g2_dim, g2_dim, precision=precision) + if self.update_style == "res_residual": + self.g2_residual.append( + get_residual( + g2_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) if self.update_g1_has_conv: self.proj_g1g2 = MLPLayer(g1_dim, g2_dim, bias=False, precision=precision) if self.update_g2_has_g1g1: self.proj_g1g1g2 = MLPLayer(g1_dim, g2_dim, bias=False, precision=precision) - if self.update_g2_has_attn: + if self.update_style == "res_residual": + self.g2_residual.append( + get_residual( + g2_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + if self.update_g2_has_attn or self.update_h2: self.attn2g_map = Atten2Map( g2_dim, attn2_hidden, @@ -579,28 +791,49 @@ def __init__( self.smooth, precision=precision, ) - self.attn2_mh_apply = Atten2MultiHeadApply( - g2_dim, attn2_nhead, precision=precision - ) - self.attn2_lm = LayerNorm( - g2_dim, eps=ln_eps, trainable=trainable_ln, precision=precision - ) - if self.update_h2: - self.attn2h_map = Atten2Map( - g2_dim, - attn2_hidden, - attn2_nhead, - attn2_has_gate, - self.smooth, - precision=precision, - ) - self.attn2_ev_apply = Atten2EquiVarApply( - g2_dim, attn2_nhead, precision=precision - ) + if self.update_g2_has_attn: + self.attn2_mh_apply = Atten2MultiHeadApply( + g2_dim, attn2_nhead, precision=precision + ) + self.attn2_lm = LayerNorm( + g2_dim, eps=ln_eps, trainable=trainable_ln, precision=precision + ) + if self.update_style == "res_residual": + self.g2_residual.append( + get_residual( + g2_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + + if self.update_h2: + self.attn2_ev_apply = Atten2EquiVarApply( + g2_dim, attn2_nhead, precision=precision + ) + if self.update_style == "res_residual": + self.h2_residual.append( + get_residual( + 1, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) if self.update_g1_has_attn: self.loc_attn = LocalAtten( g1_dim, attn1_hidden, attn1_nhead, self.smooth, precision=precision ) + if self.update_style == "res_residual": + 
self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) if self.do_bn_mode == "uniform": self.bn1 = self._bn_layer() @@ -613,6 +846,10 @@ def __init__( else: raise RuntimeError(f"unknown bn_mode {self.do_bn_mode}") + self.g1_residual = nn.ParameterList(self.g1_residual) + self.g2_residual = nn.ParameterList(self.g2_residual) + self.h2_residual = nn.ParameterList(self.h2_residual) + def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: ret = g1d if self.update_g1_has_grrg: @@ -625,21 +862,22 @@ def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: def _update_h2( self, - g2: torch.Tensor, h2: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, + attn: torch.Tensor, ) -> torch.Tensor: - assert self.attn2h_map is not None + """ + Calculate the attention weights update for pair-wise equivariant rep. + + Parameters + ---------- + h2 + Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. + attn + Attention weights from g2 attention, with shape nf x nloc x nnei x nnei x nh2. + """ assert self.attn2_ev_apply is not None - nb, nloc, nnei, _ = g2.shape - # # nb x nloc x nnei x nh2 - # h2_1 = self.attn2_ev_apply(AA, h2) - # h2_update.append(h2_1) - # nb x nloc x nnei x nnei x nh - AAh = self.attn2h_map(g2, h2, nlist_mask, sw) - # nb x nloc x nnei x nh2 - h2_1 = self.attn2_ev_apply(AAh, h2) + # nf x nloc x nnei x nh2 + h2_1 = self.attn2_ev_apply(attn, h2) return h2_1 def _update_g1_conv( @@ -649,93 +887,66 @@ def _update_g1_conv( nlist_mask: torch.Tensor, sw: torch.Tensor, ) -> torch.Tensor: + """ + Calculate the convolution update for atomic invariant rep. + + Parameters + ---------- + gg1 + Neighbor-wise atomic invariant rep, with shape nf x nloc x nnei x ng1. + g2 + Pair invariant rep, with shape nf x nloc x nnei x ng2. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nf x nloc x nnei. 
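+
+        Returns
+        -------
+        g1_11
+            The convolution update for the atomic invariant rep g1, with shape nf x nloc x ng2.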
+ """ assert self.proj_g1g2 is not None - nb, nloc, nnei, _ = g2.shape + nf, nloc, nnei, _ = g2.shape ng1 = gg1.shape[-1] ng2 = g2.shape[-1] - # gg1 : nb x nloc x nnei x ng2 - gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2) - # nb x nloc x nnei x ng2 + # gg1 : nf x nloc x nnei x ng2 + gg1 = self.proj_g1g2(gg1).view(nf, nloc, nnei, ng2) + # nf x nloc x nnei x ng2 gg1 = _apply_nlist_mask(gg1, nlist_mask) if not self.smooth: # normalized by number of neighbors, not smooth - # nb x nloc x 1 + # nf x nloc x 1 invnnei = 1.0 / (self.epsilon + torch.sum(nlist_mask, dim=-1)).unsqueeze(-1) else: gg1 = _apply_switch(gg1, sw) invnnei = (1.0 / float(nnei)) * torch.ones( - (nb, nloc, 1), dtype=gg1.dtype, device=gg1.device + (nf, nloc, 1), dtype=gg1.dtype, device=gg1.device ) - # nb x nloc x ng2 + # nf x nloc x ng2 g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei return g1_11 - def _cal_h2g2( - self, - g2: torch.Tensor, - h2: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, - ) -> torch.Tensor: - # g2: nf x nloc x nnei x ng2 - # h2: nf x nloc x nnei x 3 - # msk: nf x nloc x nnei - nb, nloc, nnei, _ = g2.shape - ng2 = g2.shape[-1] - # nb x nloc x nnei x ng2 - g2 = _apply_nlist_mask(g2, nlist_mask) - if not self.smooth: - # nb x nloc - invnnei = 1.0 / (self.epsilon + torch.sum(nlist_mask, dim=-1)) - # nb x nloc x 1 x 1 - invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) - else: - g2 = _apply_switch(g2, sw) - invnnei = (1.0 / float(nnei)) * torch.ones( - (nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device - ) - # nb x nloc x 3 x ng2 - h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei - return h2g2 - - def _cal_grrg(self, h2g2: torch.Tensor) -> torch.Tensor: - # nb x nloc x 3 x ng2 - nb, nloc, _, ng2 = h2g2.shape - # nb x nloc x 3 x axis - h2g2m = torch.split(h2g2, self.axis_neuron, dim=-1)[0] - # nb x nloc x axis x ng2 - g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1) - # nb x nloc x (axisxng2) - g1_13 = g1_13.view(nb, nloc, self.axis_neuron * ng2) - return g1_13 - - def _update_g1_grrg( - self, - g2: torch.Tensor, - h2: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, - ) -> torch.Tensor: - # g2: nf x nloc x nnei x ng2 - # h2: nf x nloc x nnei x 3 - # msk: nf x nloc x nnei - nb, nloc, nnei, _ = g2.shape - ng2 = g2.shape[-1] - # nb x nloc x 3 x ng2 - h2g2 = self._cal_h2g2(g2, h2, nlist_mask, sw) - # nb x nloc x (axisxng2) - g1_13 = self._cal_grrg(h2g2) - return g1_13 - def _update_g2_g1g1( self, - g1: torch.Tensor, # nb x nloc x ng1 - gg1: torch.Tensor, # nb x nloc x nnei x ng1 - nlist_mask: torch.Tensor, # nb x nloc x nnei - sw: torch.Tensor, # nb x nloc x nnei + g1: torch.Tensor, # nf x nloc x ng1 + gg1: torch.Tensor, # nf x nloc x nnei x ng1 + nlist_mask: torch.Tensor, # nf x nloc x nnei + sw: torch.Tensor, # nf x nloc x nnei ) -> torch.Tensor: + """ + Update the g2 using element-wise dot g1_i * g1_j. + + Parameters + ---------- + g1 + Atomic invariant rep, with shape nf x nloc x ng1. + gg1 + Neighbor-wise atomic invariant rep, with shape nf x nloc x nnei x ng1. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nf x nloc x nnei. 
+ """ ret = g1.unsqueeze(-2) * gg1 - # nb x nloc x nnei x ng1 + # nf x nloc x nnei x ng1 ret = _apply_nlist_mask(ret, nlist_mask) if self.smooth: ret = _apply_switch(ret, sw) @@ -754,30 +965,30 @@ def _apply_bn( return gg def _apply_nb_1(self, bn_number: int, gg: torch.Tensor) -> torch.Tensor: - nb, nl, nf = gg.shape - gg = gg.view([nb, 1, nl * nf]) + nf, nl, nf = gg.shape + gg = gg.view([nf, 1, nl * nf]) if bn_number == 1: assert self.bn1 is not None gg = self.bn1(gg) else: assert self.bn2 is not None gg = self.bn2(gg) - return gg.view([nb, nl, nf]) + return gg.view([nf, nl, nf]) def _apply_nb_2( self, bn_number: int, gg: torch.Tensor, ) -> torch.Tensor: - nb, nl, nnei, nf = gg.shape - gg = gg.view([nb, 1, nl * nnei * nf]) + nf, nl, nnei, nf = gg.shape + gg = gg.view([nf, 1, nl * nnei * nf]) if bn_number == 1: assert self.bn1 is not None gg = self.bn1(gg) else: assert self.bn2 is not None gg = self.bn2(gg) - return gg.view([nb, nl, nnei, nf]) + return gg.view([nf, nl, nnei, nf]) def _apply_bn_uni( self, @@ -840,11 +1051,11 @@ def forward( or self.update_g2_has_g1g1 ) - nb, nloc, nnei, _ = g2.shape + nf, nloc, nnei, _ = g2.shape nall = g1_ext.shape[1] g1, _ = torch.split(g1_ext, [nloc, nall - nloc], dim=1) - assert (nb, nloc) == g1.shape[:2] - assert (nb, nloc, nnei) == h2.shape[:3] + assert (nf, nloc) == g1.shape[:2] + assert (nf, nloc, nnei) == h2.shape[:3] ng1 = g1.shape[-1] ng2 = g2.shape[-1] nh2 = h2.shape[-1] @@ -865,44 +1076,70 @@ def forward( gg1 = None if self.update_chnnl_2: - # nb x nloc x nnei x ng2 + # mlp(g2) assert self.linear2 is not None + # nf x nloc x nnei x ng2 g2_1 = self.act(self.linear2(g2)) g2_update.append(g2_1) if self.update_g2_has_g1g1: + # linear(g1_i * g1_j) assert gg1 is not None assert self.proj_g1g1g2 is not None g2_update.append( self.proj_g1g1g2(self._update_g2_g1g1(g1, gg1, nlist_mask, sw)) ) - if self.update_g2_has_attn: + if self.update_g2_has_attn or self.update_h2: + # gated_attention(g2, h2) assert self.attn2g_map is not None - assert self.attn2_mh_apply is not None - assert self.attn2_lm is not None - # nb x nloc x nnei x nnei x nh + # nf x nloc x nnei x nnei x nh AAg = self.attn2g_map(g2, h2, nlist_mask, sw) - # nb x nloc x nnei x ng2 - g2_2 = self.attn2_mh_apply(AAg, g2) - g2_2 = self.attn2_lm(g2_2) - g2_update.append(g2_2) - if self.update_h2: - h2_update.append(self._update_h2(g2, h2, nlist_mask, sw)) + if self.update_g2_has_attn: + assert self.attn2_mh_apply is not None + assert self.attn2_lm is not None + # nf x nloc x nnei x ng2 + g2_2 = self.attn2_mh_apply(AAg, g2) + g2_2 = self.attn2_lm(g2_2) + g2_update.append(g2_2) + + if self.update_h2: + # linear_head(attention_weights * h2) + h2_update.append(self._update_h2(h2, AAg)) if self.update_g1_has_conv: assert gg1 is not None g1_mlp.append(self._update_g1_conv(gg1, g2, nlist_mask, sw)) if self.update_g1_has_grrg: - g1_mlp.append(self._update_g1_grrg(g2, h2, nlist_mask, sw)) + g1_mlp.append( + symmetrization_op( + g2, + h2, + nlist_mask, + sw, + self.axis_neuron, + smooth=self.smooth, + epsilon=self.epsilon, + ) + ) if self.update_g1_has_drrd: assert gg1 is not None - g1_mlp.append(self._update_g1_grrg(gg1, h2, nlist_mask, sw)) + g1_mlp.append( + symmetrization_op( + gg1, + h2, + nlist_mask, + sw, + self.axis_neuron, + smooth=self.smooth, + epsilon=self.epsilon, + ) + ) - # nb x nloc x [ng1+ng2+(axisxng2)+(axisxng1)] + # nf x nloc x [ng1+ng2+(axisxng2)+(axisxng1)] # conv grrg drrd g1_1 = self.act(self.linear1(torch.cat(g1_mlp, dim=-1))) g1_update.append(g1_1) @@ -914,11 +1151,11 @@ def 
forward( # update if self.update_chnnl_2: - g2_new = self.list_update(g2_update) - h2_new = self.list_update(h2_update) + g2_new = self.list_update(g2_update, "g2") + h2_new = self.list_update(h2_update, "h2") else: g2_new, h2_new = g2, h2 - g1_new = self.list_update(g1_update) + g1_new = self.list_update(g1_update, "g1") return g1_new, g2_new, h2_new @torch.jit.export @@ -942,11 +1179,35 @@ def list_update_res_incr(self, update_list: List[torch.Tensor]) -> torch.Tensor: return uu @torch.jit.export - def list_update(self, update_list: List[torch.Tensor]) -> torch.Tensor: + def list_update_res_residual( + self, update_list: List[torch.Tensor], update_name: str = "g1" + ) -> torch.Tensor: + nitem = len(update_list) + uu = update_list[0] + # make jit happy + if update_name == "g1": + for ii, vv in enumerate(self.g1_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "g2": + for ii, vv in enumerate(self.g2_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "h2": + for ii, vv in enumerate(self.h2_residual): + uu = uu + vv * update_list[ii + 1] + else: + raise NotImplementedError + return uu + + @torch.jit.export + def list_update( + self, update_list: List[torch.Tensor], update_name: str = "g1" + ) -> torch.Tensor: if self.update_style == "res_avg": return self.list_update_res_avg(update_list) elif self.update_style == "res_incr": return self.list_update_res_incr(update_list) + elif self.update_style == "res_residual": + return self.list_update_res_residual(update_list, update_name=update_name) else: raise RuntimeError(f"unknown update style {self.update_style}") @@ -1023,25 +1284,38 @@ def serialize(self) -> dict: "proj_g1g1g2": self.proj_g1g1g2.serialize(), } ) - if self.update_g2_has_attn: + if self.update_g2_has_attn or self.update_h2: data.update( { "attn2g_map": self.attn2g_map.serialize(), - "attn2_mh_apply": self.attn2_mh_apply.serialize(), - "attn2_lm": self.attn2_lm.serialize(), } ) - if self.update_h2: + if self.update_g2_has_attn: + data.update( + { + "attn2_mh_apply": self.attn2_mh_apply.serialize(), + "attn2_lm": self.attn2_lm.serialize(), + } + ) + + if self.update_h2: + data.update( + { + "attn2_ev_apply": self.attn2_ev_apply.serialize(), + } + ) + if self.update_g1_has_attn: data.update( { - "attn2h_map": self.attn2h_map.serialize(), - "attn2_ev_apply": self.attn2_ev_apply.serialize(), + "loc_attn": self.loc_attn.serialize(), } ) - if self.update_g1_has_attn: + if self.update_style == "res_residual": data.update( { - "loc_attn": self.loc_attn.serialize(), + "g1_residual": [to_numpy_array(t) for t in self.g1_residual], + "g2_residual": [to_numpy_array(t) for t in self.g2_residual], + "h2_residual": [to_numpy_array(t) for t in self.h2_residual], } ) return data @@ -1065,6 +1339,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer": update_g2_has_attn = data["update_g2_has_attn"] update_h2 = data["update_h2"] update_g1_has_attn = data["update_g1_has_attn"] + update_style = data["update_style"] linear2 = data.pop("linear2", None) proj_g1g2 = data.pop("proj_g1g2", None) @@ -1072,9 +1347,11 @@ def deserialize(cls, data: dict) -> "RepformerLayer": attn2g_map = data.pop("attn2g_map", None) attn2_mh_apply = data.pop("attn2_mh_apply", None) attn2_lm = data.pop("attn2_lm", None) - attn2h_map = data.pop("attn2h_map", None) attn2_ev_apply = data.pop("attn2_ev_apply", None) loc_attn = data.pop("loc_attn", None) + g1_residual = data.pop("g1_residual", []) + g2_residual = data.pop("g2_residual", []) + h2_residual = data.pop("h2_residual", []) obj = 
cls(**data) obj.linear1 = MLPLayer.deserialize(linear1) @@ -1087,19 +1364,25 @@ def deserialize(cls, data: dict) -> "RepformerLayer": if update_g2_has_g1g1: assert isinstance(proj_g1g1g2, dict) obj.proj_g1g1g2 = MLPLayer.deserialize(proj_g1g1g2) - if update_g2_has_attn: + if update_g2_has_attn or update_h2: assert isinstance(attn2g_map, dict) - assert isinstance(attn2_mh_apply, dict) - assert isinstance(attn2_lm, dict) obj.attn2g_map = Atten2Map.deserialize(attn2g_map) - obj.attn2_mh_apply = Atten2MultiHeadApply.deserialize(attn2_mh_apply) - obj.attn2_lm = LayerNorm.deserialize(attn2_lm) - if update_h2: - assert isinstance(attn2h_map, dict) - assert isinstance(attn2_ev_apply, dict) - obj.attn2h_map = Atten2Map.deserialize(attn2h_map) - obj.attn2_ev_apply = Atten2EquiVarApply.deserialize(attn2_ev_apply) + if update_g2_has_attn: + assert isinstance(attn2_mh_apply, dict) + assert isinstance(attn2_lm, dict) + obj.attn2_mh_apply = Atten2MultiHeadApply.deserialize(attn2_mh_apply) + obj.attn2_lm = LayerNorm.deserialize(attn2_lm) + if update_h2: + assert isinstance(attn2_ev_apply, dict) + obj.attn2_ev_apply = Atten2EquiVarApply.deserialize(attn2_ev_apply) if update_g1_has_attn: assert isinstance(loc_attn, dict) obj.loc_attn = LocalAtten.deserialize(loc_attn) + if update_style == "res_residual": + for ii, t in enumerate(obj.g1_residual): + t.data = to_torch_tensor(g1_residual[ii]) + for ii, t in enumerate(obj.g2_residual): + t.data = to_torch_tensor(g2_residual[ii]) + for ii, t in enumerate(obj.h2_residual): + t.data = to_torch_tensor(h2_residual[ii]) return obj diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 18ed502e59..b726ec3945 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -40,6 +40,7 @@ from .repformer_layer import ( RepformerLayer, + _cal_hg, ) from .repformer_layer_old_impl import RepformerLayer as RepformerLayerOld @@ -74,6 +75,8 @@ def __init__( attn2_has_gate: bool = False, activation_function: str = "tanh", update_style: str = "res_avg", + update_residual: float = 0.001, + update_residual_init: str = "norm", set_davg_zero: bool = True, smooth: bool = True, exclude_types: List[Tuple[int, int]] = [], @@ -145,6 +148,13 @@ def __init__( Supported options are: -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) + where `r1`, `r2` ... `r3` are residual weights defined by `update_residual` + and `update_residual_init`. + update_residual : float, optional + When update using residual mode, the initial std of residual vector weights. + update_residual_init : str, optional + When update using residual mode, the initialization mode of residual vector weights. set_davg_zero : bool, optional Set the normalization average to zero. 
precision : str, optional @@ -196,6 +206,8 @@ def __init__( self.bn_momentum = bn_momentum self.activation_function = activation_function self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init self.direct_dist = direct_dist self.act = ActivationFn(activation_function) self.smooth = smooth @@ -206,6 +218,7 @@ def __init__( self.resnet_dt = resnet_dt self.trainable_ln = trainable_ln self.ln_eps = ln_eps + self.epsilon = 1e-4 self.old_impl = old_impl self.g2_embd = MLPLayer(1, self.g2_dim) @@ -268,6 +281,8 @@ def __init__( attn2_nhead=self.attn2_nhead, activation_function=self.activation_function, update_style=self.update_style, + update_residual=self.update_residual, + update_residual_init=self.update_residual_init, smooth=self.smooth, trainable_ln=self.trainable_ln, ln_eps=self.ln_eps, @@ -402,24 +417,24 @@ def forward( assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim] g1 = self.act(atype_embd) - # nb x nloc x nnei x 1, nb x nloc x nnei x 3 + # nf x nloc x nnei x 1, nf x nloc x nnei x 3 if not self.direct_dist: g2, h2 = torch.split(dmatrix, [1, 3], dim=-1) else: g2, h2 = torch.linalg.norm(diff, dim=-1, keepdim=True), diff g2 = g2 / self.rcut h2 = h2 / self.rcut - # nb x nloc x nnei x ng2 + # nf x nloc x nnei x ng2 g2 = self.act(self.g2_embd(g2)) # set all padding positions to index of 0 # if the a neighbor is real or not is indicated by nlist_mask nlist[nlist == -1] = 0 - # nb x nall x ng1 + # nf x nall x ng1 mapping = mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim) for idx, ll in enumerate(self.layers): - # g1: nb x nloc x ng1 - # g1_ext: nb x nall x ng1 + # g1: nf x nloc x ng1 + # g1_ext: nf x nall x ng1 g1_ext = torch.gather(g1, 1, mapping) g1, g2, h2 = ll.forward( g1_ext, @@ -430,10 +445,9 @@ def forward( sw, ) - # uses the last layer. 
- # nb x nloc x 3 x ng2 - h2g2 = ll._cal_h2g2(g2, h2, nlist_mask, sw) - # (nb x nloc) x ng2 x 3 + # nf x nloc x 3 x ng2 + h2g2 = _cal_hg(g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon) + # (nf x nloc) x ng2 x 3 rot_mat = torch.permute(h2g2, (0, 1, 3, 2)) return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw diff --git a/deepmd/pt/model/network/layernorm.py b/deepmd/pt/model/network/layernorm.py index 27b9808010..7c58e248ba 100644 --- a/deepmd/pt/model/network/layernorm.py +++ b/deepmd/pt/model/network/layernorm.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np import torch import torch.nn as nn @@ -15,14 +16,14 @@ to_torch_tensor, ) -from .mlp import ( - MLPLayer, -) - device = env.DEVICE -class LayerNorm(MLPLayer): +def empty_t(shape, precision): + return torch.empty(shape, dtype=precision, device=device) + + +class LayerNorm(nn.Module): def __init__( self, num_in, @@ -33,24 +34,22 @@ def __init__( precision: str = DEFAULT_PRECISION, trainable: bool = True, ): + super().__init__() self.eps = eps self.uni_init = uni_init self.num_in = num_in - super().__init__( - num_in=1, - num_out=num_in, - bias=True, - use_timestep=False, - activation_function=None, - resnet=False, - bavg=bavg, - stddev=stddev, - precision=precision, + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.matrix = nn.Parameter(data=empty_t((num_in,), self.prec)) + self.bias = nn.Parameter( + data=empty_t([num_in], self.prec), ) - self.matrix = torch.nn.Parameter(self.matrix.squeeze(0)) if self.uni_init: nn.init.ones_(self.matrix.data) nn.init.zeros_(self.bias.data) + else: + nn.init.normal_(self.bias.data, mean=bavg, std=stddev) + nn.init.normal_(self.matrix.data, std=stddev / np.sqrt(self.num_in)) self.trainable = trainable if not self.trainable: self.matrix.requires_grad = False @@ -75,8 +74,11 @@ def forward( yy: torch.Tensor The output. """ - mean = xx.mean(dim=-1, keepdim=True) - variance = xx.var(dim=-1, unbiased=False, keepdim=True) + # mean = xx.mean(dim=-1, keepdim=True) + # variance = xx.var(dim=-1, unbiased=False, keepdim=True) + # The following operation is the same as above, but will not raise error when using jit model to inference. + # See https://github.com/pytorch/pytorch/issues/85792 + variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True) yy = (xx - mean) / torch.sqrt(variance + self.eps) if self.matrix is not None and self.bias is not None: yy = yy * self.matrix + self.bias diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 1763cc4ef8..34ff91b5c3 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -670,6 +670,18 @@ def descrpt_dpa2_args(): f"Supported options are: " "-'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) " "-'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n)" + "-'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) " + "where `r1`, `r2` ... `r3` are residual weights defined by `repformer_update_residual` " + "and `repformer_update_residual_init`." + ) + doc_repformer_update_residual = ( + f"{doc_repformer}When update using residual mode, " + "the initial std of residual vector weights." + ) + doc_repformer_update_residual_init = ( + f"{doc_repformer}When update using residual mode, " + "the initialization mode of residual vector weights." + "Supported modes are: ['norm', 'const']." 
) doc_repformer_set_davg_zero = ( f"{doc_repformer}Set the normalization average to zero. " @@ -892,6 +904,20 @@ def descrpt_dpa2_args(): default="res_avg", doc=doc_repformer_update_style, ), + Argument( + "repformer_update_residual", + float, + optional=True, + default=0.001, + doc=doc_repformer_update_residual, + ), + Argument( + "repformer_update_residual_init", + str, + optional=True, + default="norm", + doc=doc_repformer_update_residual_init, + ), Argument( "repformer_set_davg_zero", bool, diff --git a/source/tests/pt/model/test_dpa2.py b/source/tests/pt/model/test_dpa2.py index 49dc9f5e47..a2ae57e549 100644 --- a/source/tests/pt/model/test_dpa2.py +++ b/source/tests/pt/model/test_dpa2.py @@ -52,6 +52,7 @@ def test_consistency( rp2a, rph, rp2gate, + rus, rpz, sm, prec, @@ -70,11 +71,12 @@ def test_consistency( False, ], # repformer_update_h2 [True, False], # repformer_attn2_has_gate + ["res_avg", "res_residual"], # repformer_update_style [ True, ], # repformer_set_davg_zero [True, False], # smooth - ["float64", "float32"], # precision + ["float64"], # precision ): dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) @@ -108,7 +110,7 @@ def test_consistency( repformer_attn2_hidden=10, repformer_attn2_nhead=2, repformer_attn2_has_gate=rp2gate, - repformer_update_style="res_avg", + repformer_update_style=rus, repformer_set_davg_zero=rpz, # kwargs for descriptor smooth=sm, @@ -157,7 +159,7 @@ def test_consistency( # err_msg=err_msg, # ) # old impl - if prec == "float64": + if prec == "float64" and rus == "res_avg": dd3 = DescrptDPA2( self.nt, repinit_rcut=self.rcut, @@ -221,44 +223,114 @@ def test_consistency( atol=atol, ) - # def test_jit( - # self, - # ): - # rng = np.random.default_rng() - # nf, nloc, nnei = self.nlist.shape - # davg = rng.normal(size=(self.nt, nnei, 4)) - # dstd = rng.normal(size=(self.nt, nnei, 4)) - # dstd = 0.1 + np.abs(dstd) - # - # for idt, prec, sm, to, tm in itertools.product( - # [ - # False, - # ], # resnet_dt - # [ - # "float64", - # ], # precision - # [False, True], # smooth_type_embedding - # [False, True], # type_one_side - # ["concat", "strip"], # tebd_input_mode - # ): - # dtype = PRECISION_DICT[prec] - # rtol, atol = get_tols(prec) - # err_msg = f"idt={idt} prec={prec}" - # # dpa1 new impl - # dd0 = DescrptDPA2( - # self.rcut, - # self.rcut_smth, - # self.sel, - # self.nt, - # precision=prec, - # resnet_dt=idt, - # smooth_type_embedding=sm, - # type_one_side=to, - # tebd_input_mode=tm, - # old_impl=False, - # ) - # dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) - # dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) - # # dd1 = DescrptDPA1.deserialize(dd0.serialize()) - # model = torch.jit.script(dd0) - # # model = torch.jit.script(dd1) + def test_jit( + self, + ): + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + davg_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd = 0.1 + np.abs(dstd) + + for ( + riti, + riz, + rp1c, + rp1d, + rp1g, + rp1a, + rp2g, + rp2a, + rph, + rp2gate, + rus, + rpz, + sm, + prec, + ) in itertools.product( + ["concat", "strip"], # repinit_tebd_input_mode + [ + True, + ], # repinit_set_davg_zero + [ + True, + ], # repformer_update_g1_has_conv + [ + True, + ], # repformer_update_g1_has_drrd + [ + True, + ], # repformer_update_g1_has_grrg + [ + True, + ], # repformer_update_g1_has_attn + [ + True, + ], # 
repformer_update_g2_has_g1g1 + [ + True, + ], # repformer_update_g2_has_attn + [ + False, + ], # repformer_update_h2 + [ + True, + ], # repformer_attn2_has_gate + ["res_avg", "res_residual"], # repformer_update_style + [ + True, + ], # repformer_set_davg_zero + [ + True, + ], # smooth + ["float64"], # precision + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + + # dpa2 new impl + dd0 = DescrptDPA2( + self.nt, + repinit_rcut=self.rcut, + repinit_rcut_smth=self.rcut_smth, + repinit_nsel=self.sel_mix, + repformer_rcut=self.rcut / 2, + repformer_rcut_smth=self.rcut_smth, + repformer_nsel=nnei // 2, + # kwargs for repinit + repinit_tebd_input_mode=riti, + repinit_set_davg_zero=riz, + # kwargs for repformer + repformer_nlayers=3, + repformer_g1_dim=20, + repformer_g2_dim=10, + repformer_axis_neuron=4, + repformer_update_g1_has_conv=rp1c, + repformer_update_g1_has_drrd=rp1d, + repformer_update_g1_has_grrg=rp1g, + repformer_update_g1_has_attn=rp1a, + repformer_update_g2_has_g1g1=rp2g, + repformer_update_g2_has_attn=rp2a, + repformer_update_h2=rph, + repformer_attn1_hidden=20, + repformer_attn1_nhead=2, + repformer_attn2_hidden=10, + repformer_attn2_nhead=2, + repformer_attn2_has_gate=rp2gate, + repformer_update_style=rus, + repformer_set_davg_zero=rpz, + # kwargs for descriptor + smooth=sm, + exclude_types=[], + add_tebd_to_repinit_out=False, + precision=prec, + old_impl=False, + ).to(env.DEVICE) + + dd0.repinit.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.repinit.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + dd0.repformers.mean = torch.tensor(davg_2, dtype=dtype, device=env.DEVICE) + dd0.repformers.stddev = torch.tensor(dstd_2, dtype=dtype, device=env.DEVICE) + model = torch.jit.script(dd0) From 0e4fe1c9001d12dc78ecc67bc75e2372526473cc Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 7 May 2024 20:02:56 +0800 Subject: [PATCH 15/37] rm bn support --- deepmd/pt/model/descriptor/dpa2.py | 15 --- deepmd/pt/model/descriptor/repformer_layer.py | 107 ------------------ deepmd/pt/model/descriptor/repformers.py | 15 --- deepmd/utils/argcheck.py | 24 ---- 4 files changed, 161 deletions(-) diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index c2eeef48e9..54bcce6969 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -75,8 +75,6 @@ def __init__( repformer_g2_dim: int = 16, repformer_axis_neuron: int = 4, repformer_direct_dist: bool = False, - repformer_do_bn_mode: str = "no", - repformer_bn_momentum: float = 0.1, repformer_update_g1_has_conv: bool = True, repformer_update_g1_has_drrd: bool = True, repformer_update_g1_has_grrg: bool = True, @@ -166,15 +164,6 @@ def __init__( repformer_direct_dist : bool, optional (Used in the repformer block.) Whether to use direct distance information (1/r term) in the repformer block. - repformer_do_bn_mode : str, optional - (Used in the repformer block.) - The mode to do batch normalization in the repformer layers. Supported modes are: - -'no': Not do batch normalization. - -'uniform': Do batch normalization using scalar running momentum and learnable gamma/beta (num_features=1). - -'component': Do batch normalization using vector running momentum and learnable gamma/beta (num_features=d). - repformer_bn_momentum : float, optional - (Used in the repformer block.) - Momentum used in the batch normalization. repformer_update_g1_has_conv : bool, optional (Used in the repformer block.) 
Whether to update the g1 rep with convolution term. @@ -309,8 +298,6 @@ def __init__( g2_dim=repformer_g2_dim, axis_neuron=repformer_axis_neuron, direct_dist=repformer_direct_dist, - do_bn_mode=repformer_do_bn_mode, - bn_momentum=repformer_bn_momentum, update_g1_has_conv=repformer_update_g1_has_conv, update_g1_has_drrd=repformer_update_g1_has_drrd, update_g1_has_grrg=repformer_update_g1_has_grrg, @@ -518,8 +505,6 @@ def serialize(self) -> dict: "repformer_g2_dim": repformers.g2_dim, "repformer_axis_neuron": repformers.axis_neuron, "repformer_direct_dist": repformers.direct_dist, - "repformer_do_bn_mode": repformers.do_bn_mode, - "repformer_bn_momentum": repformers.bn_momentum, "repformer_update_g1_has_conv": repformers.update_g1_has_conv, "repformer_update_g1_has_drrd": repformers.update_g1_has_drrd, "repformer_update_g1_has_grrg": repformers.update_g1_has_grrg, diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index 4480ed7e29..c839179ce6 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( - Callable, List, Optional, ) @@ -665,8 +664,6 @@ def __init__( g2_dim=16, axis_neuron: int = 4, update_chnnl_2: bool = True, - do_bn_mode: str = "no", - bn_momentum: float = 0.1, update_g1_has_conv: bool = True, update_g1_has_drrd: bool = True, update_g1_has_grrg: bool = True, @@ -699,8 +696,6 @@ def __init__( self.sel = sel self.sec = self.sel self.axis_neuron = axis_neuron - self.do_bn_mode = do_bn_mode - self.bn_momentum = bn_momentum self.activation_function = activation_function self.act = ActivationFn(activation_function) self.update_g1_has_grrg = update_g1_has_grrg @@ -835,17 +830,6 @@ def __init__( ) ) - if self.do_bn_mode == "uniform": - self.bn1 = self._bn_layer() - self.bn2 = self._bn_layer() - elif self.do_bn_mode == "component": - self.bn1 = self._bn_layer(nf=g1_dim) - self.bn2 = self._bn_layer(nf=g2_dim) - elif self.do_bn_mode == "no": - self.bn1, self.bn2 = None, None - else: - raise RuntimeError(f"unknown bn_mode {self.do_bn_mode}") - self.g1_residual = nn.ParameterList(self.g1_residual) self.g2_residual = nn.ParameterList(self.g2_residual) self.h2_residual = nn.ParameterList(self.h2_residual) @@ -952,73 +936,6 @@ def _update_g2_g1g1( ret = _apply_switch(ret, sw) return ret - def _apply_bn( - self, - bn_number: int, - gg: torch.Tensor, - ): - if self.do_bn_mode == "uniform": - return self._apply_bn_uni(bn_number, gg) - elif self.do_bn_mode == "component": - return self._apply_bn_comp(bn_number, gg) - else: - return gg - - def _apply_nb_1(self, bn_number: int, gg: torch.Tensor) -> torch.Tensor: - nf, nl, nf = gg.shape - gg = gg.view([nf, 1, nl * nf]) - if bn_number == 1: - assert self.bn1 is not None - gg = self.bn1(gg) - else: - assert self.bn2 is not None - gg = self.bn2(gg) - return gg.view([nf, nl, nf]) - - def _apply_nb_2( - self, - bn_number: int, - gg: torch.Tensor, - ) -> torch.Tensor: - nf, nl, nnei, nf = gg.shape - gg = gg.view([nf, 1, nl * nnei * nf]) - if bn_number == 1: - assert self.bn1 is not None - gg = self.bn1(gg) - else: - assert self.bn2 is not None - gg = self.bn2(gg) - return gg.view([nf, nl, nnei, nf]) - - def _apply_bn_uni( - self, - bn_number: int, - gg: torch.Tensor, - mode: str = "1", - ) -> torch.Tensor: - if len(gg.shape) == 3: - return self._apply_nb_1(bn_number, gg) - elif len(gg.shape) == 4: - return self._apply_nb_2(bn_number, gg) - else: - raise 
RuntimeError(f"unsupported input shape {gg.shape}") - - def _apply_bn_comp( - self, - bn_number: int, - gg: torch.Tensor, - ) -> torch.Tensor: - ss = gg.shape - nf = ss[-1] - gg = gg.view([-1, nf]) - if bn_number == 1: - assert self.bn1 is not None - gg = self.bn1(gg).view(ss) - else: - assert self.bn2 is not None - gg = self.bn2(gg).view(ss) - return gg - def forward( self, g1_ext: torch.Tensor, # nf x nall x ng1 @@ -1056,14 +973,6 @@ def forward( g1, _ = torch.split(g1_ext, [nloc, nall - nloc], dim=1) assert (nf, nloc) == g1.shape[:2] assert (nf, nloc, nnei) == h2.shape[:3] - ng1 = g1.shape[-1] - ng2 = g2.shape[-1] - nh2 = h2.shape[-1] - - if self.bn1 is not None: - g1 = self._apply_bn(1, g1) - if self.bn2 is not None: - g2 = self._apply_bn(2, g2) g2_update: List[torch.Tensor] = [g2] h2_update: List[torch.Tensor] = [h2] @@ -1211,20 +1120,6 @@ def list_update( else: raise RuntimeError(f"unknown update style {self.update_style}") - def _bn_layer( - self, - nf: int = 1, - ) -> Callable: - return torch.nn.BatchNorm1d( - nf, - eps=1e-5, - momentum=self.bn_momentum, - affine=False, - track_running_stats=True, - device=env.DEVICE, - dtype=env.GLOBAL_PT_FLOAT_PRECISION, - ) - def serialize(self) -> dict: """Serialize the networks to a dict. @@ -1244,8 +1139,6 @@ def serialize(self) -> dict: "g2_dim": self.g2_dim, "axis_neuron": self.axis_neuron, "update_chnnl_2": self.update_chnnl_2, - "do_bn_mode": self.do_bn_mode, - "bn_momentum": self.bn_momentum, "update_g1_has_conv": self.update_g1_has_conv, "update_g1_has_drrd": self.update_g1_has_drrd, "update_g1_has_grrg": self.update_g1_has_grrg, diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index b726ec3945..da22356814 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -59,8 +59,6 @@ def __init__( g2_dim=16, axis_neuron: int = 4, direct_dist: bool = False, - do_bn_mode: str = "no", - bn_momentum: float = 0.1, update_g1_has_conv: bool = True, update_g1_has_drrd: bool = True, update_g1_has_grrg: bool = True, @@ -110,13 +108,6 @@ def __init__( Size of the submatrix of G (embedding matrix). direct_dist : bool, optional Whether to use direct distance information (1/r term) in the repformer block. - do_bn_mode : str, optional - The mode to do batch normalization in the repformer layers. Supported modes are: - -'no': Not do batch normalization. - -'uniform': Do batch normalization using scalar running momentum and learnable gamma/beta (num_features=1). - -'component': Do batch normalization using vector running momentum and learnable gamma/beta (num_features=d). - bn_momentum : float, optional - Momentum used in the batch normalization. update_g1_has_conv : bool, optional Whether to update the g1 rep with convolution term. 
update_g1_has_drrd : bool, optional @@ -202,8 +193,6 @@ def __init__( self.attn2_has_gate = attn2_has_gate self.attn2_hidden = attn2_hidden self.attn2_nhead = attn2_nhead - self.do_bn_mode = do_bn_mode - self.bn_momentum = bn_momentum self.activation_function = activation_function self.update_style = update_style self.update_residual = update_residual @@ -235,8 +224,6 @@ def __init__( self.g2_dim, axis_neuron=self.axis_neuron, update_chnnl_2=(ii != nlayers - 1), - do_bn_mode=self.do_bn_mode, - bn_momentum=self.bn_momentum, update_g1_has_conv=self.update_g1_has_conv, update_g1_has_drrd=self.update_g1_has_drrd, update_g1_has_grrg=self.update_g1_has_grrg, @@ -265,8 +252,6 @@ def __init__( self.g2_dim, axis_neuron=self.axis_neuron, update_chnnl_2=(ii != nlayers - 1), - do_bn_mode=self.do_bn_mode, - bn_momentum=self.bn_momentum, update_g1_has_conv=self.update_g1_has_conv, update_g1_has_drrd=self.update_g1_has_drrd, update_g1_has_grrg=self.update_g1_has_grrg, diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 34ff91b5c3..71301ed3ff 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -630,16 +630,6 @@ def descrpt_dpa2_args(): ) doc_repformer_axis_neuron = f"{doc_repformer}The number of dimension of submatrix in the symmetrization ops." doc_repformer_direct_dist = f"{doc_repformer}Whether or not use direct distance as input for the embedding net to get g2 instead of smoothed 1/r." - doc_repformer_do_bn_mode = ( - f"{doc_repformer}The mode to do batch normalization in the repformer layers. " - f"Supported options are: " - f"-'no': Not do batch normalization." - f"-'uniform': Do batch normalization using scalar running momentum and learnable gamma/beta (num_features=1)." - f"-'component': Do batch normalization using vector running momentum and learnable gamma/beta (num_features=d)." - ) - doc_repformer_bn_momentum = ( - f"{doc_repformer}Momentum used in the batch normalization." - ) doc_repformer_update_g1_has_conv = ( f"{doc_repformer}Update the g1 rep with convolution term." 
) @@ -791,20 +781,6 @@ def descrpt_dpa2_args(): default=False, doc=doc_repformer_direct_dist, ), - Argument( - "repformer_do_bn_mode", - str, - optional=True, - default="no", - doc=doc_repformer_do_bn_mode, - ), - Argument( - "repformer_bn_momentum", - float, - optional=True, - default=0.1, - doc=doc_repformer_bn_momentum, - ), Argument( "repformer_update_g1_has_conv", bool, From 7a1095cf540c82a19807da644c4a34e23ebadc6f Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 8 May 2024 01:26:07 +0800 Subject: [PATCH 16/37] Add numpy impl for DPA2 --- deepmd/dpmodel/descriptor/__init__.py | 4 + deepmd/dpmodel/descriptor/dpa1.py | 6 +- deepmd/dpmodel/descriptor/dpa2.py | 678 +++++++ deepmd/dpmodel/descriptor/repformers.py | 1641 +++++++++++++++++ deepmd/dpmodel/descriptor/se_e2_a.py | 4 +- deepmd/dpmodel/descriptor/se_r.py | 2 +- deepmd/dpmodel/utils/env_mat.py | 8 +- deepmd/dpmodel/utils/network.py | 19 + deepmd/pt/model/descriptor/dpa2.py | 15 +- deepmd/pt/model/descriptor/repformer_layer.py | 10 +- deepmd/pt/model/descriptor/repformers.py | 2 +- .../dpmodel/case_single_frame_with_nlist.py | 5 + .../common/dpmodel/test_descriptor_dpa2.py | 49 + source/tests/common/dpmodel/test_env_mat.py | 9 +- source/tests/consistent/descriptor/common.py | 4 +- .../tests/consistent/descriptor/test_dpa2.py | 380 ++++ source/tests/pt/model/test_dpa2.py | 1 + source/tests/pt/model/test_env_mat.py | 7 +- 18 files changed, 2820 insertions(+), 24 deletions(-) create mode 100644 deepmd/dpmodel/descriptor/dpa2.py create mode 100644 deepmd/dpmodel/descriptor/repformers.py create mode 100644 source/tests/common/dpmodel/test_descriptor_dpa2.py create mode 100644 source/tests/consistent/descriptor/test_dpa2.py diff --git a/deepmd/dpmodel/descriptor/__init__.py b/deepmd/dpmodel/descriptor/__init__.py index bbc332588c..563fb9d149 100644 --- a/deepmd/dpmodel/descriptor/__init__.py +++ b/deepmd/dpmodel/descriptor/__init__.py @@ -2,6 +2,9 @@ from .dpa1 import ( DescrptDPA1, ) +from .dpa2 import ( + DescrptDPA2, +) from .hybrid import ( DescrptHybrid, ) @@ -19,6 +22,7 @@ "DescrptSeA", "DescrptSeR", "DescrptDPA1", + "DescrptDPA2", "DescrptHybrid", "make_base_descriptor", ] diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 4c1d3c6a0a..328d8dbbb6 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -222,7 +222,7 @@ def __init__( attn_layer: int = 2, attn_dotr: bool = True, attn_mask: bool = False, - exclude_types: List[List[int]] = [], + exclude_types: List[Tuple[int, int]] = [], env_protection: float = 0.0, set_davg_zero: bool = False, activation_function: str = "tanh", @@ -523,7 +523,7 @@ def __init__( attn_layer: int = 2, attn_dotr: bool = True, attn_mask: bool = False, - exclude_types: List[List[int]] = [], + exclude_types: List[Tuple[int, int]] = [], env_protection: float = 0.0, set_davg_zero: bool = False, activation_function: str = "tanh", @@ -742,7 +742,7 @@ def call( mapping: Optional[np.ndarray] = None, ): # nf x nloc x nnei x 4 - dmatrix, sw = self.env_mat.call( + dmatrix, diff, sw = self.env_mat.call( coord_ext, atype_ext, nlist, self.mean, self.stddev ) nf, nloc, nnei, _ = dmatrix.shape diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py new file mode 100644 index 0000000000..0d019590f6 --- /dev/null +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -0,0 +1,678 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np + +from deepmd.dpmodel.utils.network import ( + 
Identity, + NativeLayer, +) +from deepmd.dpmodel.utils.nlist import ( + build_multiple_neighbor_list, + get_multiple_nlist_key, +) +from deepmd.dpmodel.utils.type_embed import ( + TypeEmbedNet, +) +from deepmd.dpmodel.utils.update_sel import ( + UpdateSel, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +try: + from deepmd._version import version as __version__ +except ImportError: + __version__ = "unknown" + +from typing import ( + List, + Optional, + Tuple, +) + +from deepmd.dpmodel import ( + NativeOP, +) +from deepmd.dpmodel.utils import ( + EnvMat, + NetworkCollection, +) + +from .base_descriptor import ( + BaseDescriptor, +) +from .dpa1 import ( + DescrptBlockSeAtten, +) +from .repformers import ( + DescrptBlockRepformers, + RepformerLayer, +) + + +@BaseDescriptor.register("dpa2") +class DescrptDPA2(NativeOP, BaseDescriptor): + def __init__( + self, + # args for repinit + ntypes: int, + repinit_rcut: float, + repinit_rcut_smth: float, + repinit_nsel: int, + repformer_rcut: float, + repformer_rcut_smth: float, + repformer_nsel: int, + # kwargs for repinit + repinit_neuron: List[int] = [25, 50, 100], + repinit_axis_neuron: int = 16, + repinit_tebd_dim: int = 8, + repinit_tebd_input_mode: str = "concat", + repinit_set_davg_zero: bool = True, + repinit_activation_function="tanh", + # kwargs for repformer + repformer_nlayers: int = 3, + repformer_g1_dim: int = 128, + repformer_g2_dim: int = 16, + repformer_axis_neuron: int = 4, + repformer_direct_dist: bool = False, + repformer_update_g1_has_conv: bool = True, + repformer_update_g1_has_drrd: bool = True, + repformer_update_g1_has_grrg: bool = True, + repformer_update_g1_has_attn: bool = True, + repformer_update_g2_has_g1g1: bool = True, + repformer_update_g2_has_attn: bool = True, + repformer_update_h2: bool = False, + repformer_attn1_hidden: int = 64, + repformer_attn1_nhead: int = 4, + repformer_attn2_hidden: int = 16, + repformer_attn2_nhead: int = 4, + repformer_attn2_has_gate: bool = False, + repformer_activation_function: str = "tanh", + repformer_update_style: str = "res_avg", + repformer_update_residual: float = 0.001, + repformer_update_residual_init: str = "norm", + repformer_set_davg_zero: bool = True, + # kwargs for descriptor + concat_output_tebd: bool = True, + precision: str = "float64", + smooth: bool = True, + exclude_types: List[Tuple[int, int]] = [], + env_protection: float = 0.0, + trainable: bool = True, + seed: Optional[int] = None, + resnet_dt: bool = False, + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + type_one_side: bool = False, + add_tebd_to_repinit_out: bool = False, + ): + r"""The DPA-2 descriptor. see https://arxiv.org/abs/2312.15492. + + Parameters + ---------- + repinit_rcut : float + (Used in the repinit block.) + The cut-off radius. + repinit_rcut_smth : float + (Used in the repinit block.) + Where to start smoothing. For example the 1/r term is smoothed from rcut to rcut_smth. + repinit_nsel : int + (Used in the repinit block.) + Maximally possible number of selected neighbors. + repinit_neuron : list, optional + (Used in the repinit block.) + Number of neurons in each hidden layers of the embedding net. + When two layers are of the same size or one layer is twice as large as the previous layer, + a skip connection is built. + repinit_axis_neuron : int, optional + (Used in the repinit block.) + Size of the submatrix of G (embedding matrix). + repinit_tebd_dim : int, optional + (Used in the repinit block.) 
+ The dimension of atom type embedding. + repinit_tebd_input_mode : str, optional + (Used in the repinit block.) + The input mode of the type embedding. Supported modes are ['concat', 'strip']. + repinit_set_davg_zero : bool, optional + (Used in the repinit block.) + Set the normalization average to zero. + repinit_activation_function : str, optional + (Used in the repinit block.) + The activation function in the embedding net. + repformer_rcut : float + (Used in the repformer block.) + The cut-off radius. + repformer_rcut_smth : float + (Used in the repformer block.) + Where to start smoothing. For example the 1/r term is smoothed from rcut to rcut_smth. + repformer_nsel : int + (Used in the repformer block.) + Maximally possible number of selected neighbors. + repformer_nlayers : int, optional + (Used in the repformer block.) + Number of repformer layers. + repformer_g1_dim : int, optional + (Used in the repformer block.) + Dimension of the first graph convolution layer. + repformer_g2_dim : int, optional + (Used in the repformer block.) + Dimension of the second graph convolution layer. + repformer_axis_neuron : int, optional + (Used in the repformer block.) + Size of the submatrix of G (embedding matrix). + repformer_direct_dist : bool, optional + (Used in the repformer block.) + Whether to use direct distance information (1/r term) in the repformer block. + repformer_update_g1_has_conv : bool, optional + (Used in the repformer block.) + Whether to update the g1 rep with convolution term. + repformer_update_g1_has_drrd : bool, optional + (Used in the repformer block.) + Whether to update the g1 rep with the drrd term. + repformer_update_g1_has_grrg : bool, optional + (Used in the repformer block.) + Whether to update the g1 rep with the grrg term. + repformer_update_g1_has_attn : bool, optional + (Used in the repformer block.) + Whether to update the g1 rep with the localized self-attention. + repformer_update_g2_has_g1g1 : bool, optional + (Used in the repformer block.) + Whether to update the g2 rep with the g1xg1 term. + repformer_update_g2_has_attn : bool, optional + (Used in the repformer block.) + Whether to update the g2 rep with the gated self-attention. + repformer_update_h2 : bool, optional + (Used in the repformer block.) + Whether to update the h2 rep. + repformer_attn1_hidden : int, optional + (Used in the repformer block.) + The hidden dimension of localized self-attention to update the g1 rep. + repformer_attn1_nhead : int, optional + (Used in the repformer block.) + The number of heads in localized self-attention to update the g1 rep. + repformer_attn2_hidden : int, optional + (Used in the repformer block.) + The hidden dimension of gated self-attention to update the g2 rep. + repformer_attn2_nhead : int, optional + (Used in the repformer block.) + The number of heads in gated self-attention to update the g2 rep. + repformer_attn2_has_gate : bool, optional + (Used in the repformer block.) + Whether to use gate in the gated self-attention to update the g2 rep. + repformer_activation_function : str, optional + (Used in the repformer block.) + The activation function in the embedding net. + repformer_update_style : str, optional + (Used in the repformer block.) + Style to update a representation. + Supported options are: + -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) + -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... 
+ r3*u_n) + where `r1`, `r2` ... `r3` are residual weights defined by `repformer_update_residual` + and `repformer_update_residual_init`. + repformer_update_residual : float, optional + (Used in the repformer block.) + When update using residual mode, the initial std of residual vector weights. + repformer_update_residual_init : str, optional + (Used in the repformer block.) + When update using residual mode, the initialization mode of residual vector weights. + repformer_set_davg_zero : bool, optional + (Used in the repformer block.) + Set the normalization average to zero. + concat_output_tebd : bool, optional + Whether to concat type embedding at the output of the descriptor. + precision : str, optional + The precision of the embedding net parameters. + smooth : bool, optional + Whether to use smoothness in processes such as attention weights calculation. + exclude_types : List[List[int]], optional + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + trainable : bool, optional + If the parameters are trainable. + seed : int, optional + (Unused yet) Random seed for parameter initialization. + resnet_dt : bool, optional + Whether to use a "Timestep" in the skip connection. + trainable_ln : bool, optional + Whether to use trainable shift and scale weights in layer normalization. + ln_eps : float, optional + The epsilon value for layer normalization. + type_one_side : bool, optional + Whether to use one-side type embedding. + add_tebd_to_repinit_out : bool, optional + Whether to add type embedding to the output representation from repinit before inputting it into repformer. + + Returns + ------- + descriptor: torch.Tensor + the descriptor of shape nf x nloc x g1_dim. + invariant single-atom representation. + g2: torch.Tensor + invariant pair-atom representation. + h2: torch.Tensor + equivariant pair-atom representation. + rot_mat: torch.Tensor + rotation matrix for equivariant fittings + sw: torch.Tensor + The switch function for decaying inverse distance. 
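+
+        Examples
+        --------
+        A minimal construction sketch; the numerical values below are
+        illustrative only. Note that the repinit cut-off radius and selection
+        must be larger than those of the repformer block.
+
+        >>> descrpt = DescrptDPA2(
+        ...     ntypes=2,
+        ...     repinit_rcut=6.0,
+        ...     repinit_rcut_smth=0.5,
+        ...     repinit_nsel=100,
+        ...     repformer_rcut=4.0,
+        ...     repformer_rcut_smth=0.5,
+        ...     repformer_nsel=40,
+        ... )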
+ + """ + # to keep consistent with default value in this backends + if ln_eps is None: + ln_eps = 1e-5 + self.repinit = DescrptBlockSeAtten( + repinit_rcut, + repinit_rcut_smth, + repinit_nsel, + ntypes, + attn_layer=0, + neuron=repinit_neuron, + axis_neuron=repinit_axis_neuron, + tebd_dim=repinit_tebd_dim, + tebd_input_mode=repinit_tebd_input_mode, + set_davg_zero=repinit_set_davg_zero, + exclude_types=exclude_types, + env_protection=env_protection, + activation_function=repinit_activation_function, + precision=precision, + resnet_dt=resnet_dt, + trainable_ln=trainable_ln, + ln_eps=ln_eps, + smooth=smooth, + type_one_side=type_one_side, + ) + self.repformers = DescrptBlockRepformers( + repformer_rcut, + repformer_rcut_smth, + repformer_nsel, + ntypes, + nlayers=repformer_nlayers, + g1_dim=repformer_g1_dim, + g2_dim=repformer_g2_dim, + axis_neuron=repformer_axis_neuron, + direct_dist=repformer_direct_dist, + update_g1_has_conv=repformer_update_g1_has_conv, + update_g1_has_drrd=repformer_update_g1_has_drrd, + update_g1_has_grrg=repformer_update_g1_has_grrg, + update_g1_has_attn=repformer_update_g1_has_attn, + update_g2_has_g1g1=repformer_update_g2_has_g1g1, + update_g2_has_attn=repformer_update_g2_has_attn, + update_h2=repformer_update_h2, + attn1_hidden=repformer_attn1_hidden, + attn1_nhead=repformer_attn1_nhead, + attn2_hidden=repformer_attn2_hidden, + attn2_nhead=repformer_attn2_nhead, + attn2_has_gate=repformer_attn2_has_gate, + activation_function=repformer_activation_function, + update_style=repformer_update_style, + update_residual=repformer_update_residual, + update_residual_init=repformer_update_residual_init, + set_davg_zero=repformer_set_davg_zero, + smooth=smooth, + exclude_types=exclude_types, + env_protection=env_protection, + precision=precision, + resnet_dt=resnet_dt, + trainable_ln=trainable_ln, + ln_eps=ln_eps, + ) + self.type_embedding = TypeEmbedNet( + ntypes=ntypes, + neuron=[repinit_tebd_dim], + padding=True, + activation_function="Linear", + precision=precision, + ) + self.concat_output_tebd = concat_output_tebd + self.precision = precision + self.smooth = smooth + self.exclude_types = exclude_types + self.env_protection = env_protection + self.trainable = trainable + self.resnet_dt = resnet_dt + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.type_one_side = type_one_side + self.add_tebd_to_repinit_out = add_tebd_to_repinit_out + + if self.repinit.dim_out == self.repformers.dim_in: + self.g1_shape_tranform = Identity() + else: + self.g1_shape_tranform = NativeLayer( + self.repinit.dim_out, + self.repformers.dim_in, + bias=False, + precision=precision, + ) + self.tebd_transform = None + if self.add_tebd_to_repinit_out: + self.tebd_transform = NativeLayer( + repinit_tebd_dim, + self.repformers.dim_in, + bias=False, + precision=precision, + ) + assert self.repinit.rcut > self.repformers.rcut + assert self.repinit.sel[0] > self.repformers.sel[0] + + self.tebd_dim = repinit_tebd_dim + self.rcut = self.repinit.get_rcut() + self.ntypes = ntypes + self.sel = self.repinit.sel + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_out(self) -> int: + """Returns the output dimension of 
this descriptor.""" + ret = self.repformers.dim_out + if self.concat_output_tebd: + ret += self.tebd_dim + return ret + + def get_dim_emb(self) -> int: + """Returns the embedding dimension of this descriptor.""" + return self.repformers.dim_emb + + def mixed_types(self) -> bool: + """If true, the discriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the discriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + raise NotImplementedError + + @property + def dim_out(self): + return self.get_dim_out() + + @property + def dim_emb(self): + """Returns the embedding dimension g2.""" + return self.get_dim_emb() + + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): + """Update mean and stddev for descriptor elements.""" + raise NotImplementedError + + def call( + self, + coord_ext: np.ndarray, + atype_ext: np.ndarray, + nlist: np.ndarray, + mapping: Optional[np.ndarray] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nallx3) + atype_ext + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping, maps extended region index to local region. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + sw + The smooth switch function. 
shape: nf x nloc x nnei + + """ + nframes, nloc, nnei = nlist.shape + nall = coord_ext.reshape(nframes, -1).shape[1] // 3 + # nlists + nlist_dict = build_multiple_neighbor_list( + coord_ext, + nlist, + [self.repformers.get_rcut(), self.repinit.get_rcut()], + [self.repformers.get_nsel(), self.repinit.get_nsel()], + ) + # repinit + g1_ext = self.type_embedding.call()[atype_ext] + g1_inp = g1_ext[:, :nloc, :] + g1, _, _, _, _ = self.repinit( + nlist_dict[ + get_multiple_nlist_key(self.repinit.get_rcut(), self.repinit.get_nsel()) + ], + coord_ext, + atype_ext, + g1_ext, + mapping, + ) + # linear to change shape + g1 = self.g1_shape_tranform(g1) + if self.add_tebd_to_repinit_out: + assert self.tebd_transform is not None + g1 = g1 + self.tebd_transform(g1_inp) + # mapping g1 + assert mapping is not None + mapping_ext = np.tile(mapping.reshape(nframes, nall, 1), (1, 1, g1.shape[-1])) + g1_ext = np.take_along_axis(g1, mapping_ext, axis=1) + # repformer + g1, g2, h2, rot_mat, sw = self.repformers( + nlist_dict[ + get_multiple_nlist_key( + self.repformers.get_rcut(), self.repformers.get_nsel() + ) + ], + coord_ext, + atype_ext, + g1_ext, + mapping, + ) + if self.concat_output_tebd: + g1 = np.concatenate([g1, g1_inp], axis=-1) + return g1, rot_mat, g2, h2, sw + + def serialize(self) -> dict: + repinit = self.repinit + repformers = self.repformers + data = { + "@class": "Descriptor", + "type": "dpa2", + "@version": 1, + "ntypes": self.ntypes, + "repinit_rcut": repinit.rcut, + "repinit_rcut_smth": repinit.rcut_smth, + "repinit_nsel": repinit.sel, + "repformer_rcut": repformers.rcut, + "repformer_rcut_smth": repformers.rcut_smth, + "repformer_nsel": repformers.sel, + "repinit_neuron": repinit.neuron, + "repinit_axis_neuron": repinit.axis_neuron, + "repinit_tebd_dim": repinit.tebd_dim, + "repinit_tebd_input_mode": repinit.tebd_input_mode, + "repinit_set_davg_zero": repinit.set_davg_zero, + "repinit_activation_function": repinit.activation_function, + "repformer_nlayers": repformers.nlayers, + "repformer_g1_dim": repformers.g1_dim, + "repformer_g2_dim": repformers.g2_dim, + "repformer_axis_neuron": repformers.axis_neuron, + "repformer_direct_dist": repformers.direct_dist, + "repformer_update_g1_has_conv": repformers.update_g1_has_conv, + "repformer_update_g1_has_drrd": repformers.update_g1_has_drrd, + "repformer_update_g1_has_grrg": repformers.update_g1_has_grrg, + "repformer_update_g1_has_attn": repformers.update_g1_has_attn, + "repformer_update_g2_has_g1g1": repformers.update_g2_has_g1g1, + "repformer_update_g2_has_attn": repformers.update_g2_has_attn, + "repformer_update_h2": repformers.update_h2, + "repformer_attn1_hidden": repformers.attn1_hidden, + "repformer_attn1_nhead": repformers.attn1_nhead, + "repformer_attn2_hidden": repformers.attn2_hidden, + "repformer_attn2_nhead": repformers.attn2_nhead, + "repformer_attn2_has_gate": repformers.attn2_has_gate, + "repformer_activation_function": repformers.activation_function, + "repformer_update_style": repformers.update_style, + "repformer_set_davg_zero": repformers.set_davg_zero, + "concat_output_tebd": self.concat_output_tebd, + "precision": self.precision, + "smooth": self.smooth, + "exclude_types": self.exclude_types, + "env_protection": self.env_protection, + "trainable": self.trainable, + "resnet_dt": self.resnet_dt, + "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, + "type_one_side": self.type_one_side, + "add_tebd_to_repinit_out": self.add_tebd_to_repinit_out, + "type_embedding": self.type_embedding.serialize(), + "g1_shape_tranform": 
self.g1_shape_tranform.serialize(), + } + if self.add_tebd_to_repinit_out: + data.update( + { + "tebd_transform": self.tebd_transform.serialize(), + } + ) + repinit_variable = { + "embeddings": repinit.embeddings.serialize(), + "env_mat": EnvMat(repinit.rcut, repinit.rcut_smth).serialize(), + "@variables": { + "davg": repinit["davg"], + "dstd": repinit["dstd"], + }, + } + if repinit.tebd_input_mode in ["strip"]: + repinit_variable.update( + {"embeddings_strip": repinit.embeddings_strip.serialize()} + ) + repformers_variable = { + "g2_embd": repformers.g2_embd.serialize(), + "repformer_layers": [layer.serialize() for layer in repformers.layers], + "env_mat": EnvMat(repformers.rcut, repformers.rcut_smth).serialize(), + "@variables": { + "davg": repformers["davg"], + "dstd": repformers["dstd"], + }, + } + data.update( + { + "repinit": repinit_variable, + "repformers": repformers_variable, + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA2": + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + repinit_variable = data.pop("repinit").copy() + repformers_variable = data.pop("repformers").copy() + type_embedding = data.pop("type_embedding") + g1_shape_tranform = data.pop("g1_shape_tranform") + tebd_transform = data.pop("tebd_transform", None) + add_tebd_to_repinit_out = data["add_tebd_to_repinit_out"] + obj = cls(**data) + obj.type_embedding = TypeEmbedNet.deserialize(type_embedding) + if add_tebd_to_repinit_out: + assert isinstance(tebd_transform, dict) + obj.tebd_transform = NativeLayer.deserialize(tebd_transform) + if obj.repinit.dim_out != obj.repformers.dim_in: + obj.g1_shape_tranform = NativeLayer.deserialize(g1_shape_tranform) + + # deserialize repinit + statistic_repinit = repinit_variable.pop("@variables") + env_mat = repinit_variable.pop("env_mat") + tebd_input_mode = data["repinit_tebd_input_mode"] + obj.repinit.embeddings = NetworkCollection.deserialize( + repinit_variable.pop("embeddings") + ) + if tebd_input_mode in ["strip"]: + obj.repinit.embeddings_strip = NetworkCollection.deserialize( + repinit_variable.pop("embeddings_strip") + ) + obj.repinit["davg"] = statistic_repinit["davg"] + obj.repinit["dstd"] = statistic_repinit["dstd"] + + # deserialize repformers + statistic_repformers = repformers_variable.pop("@variables") + env_mat = repformers_variable.pop("env_mat") + repformer_layers = repformers_variable.pop("repformer_layers") + obj.repformers.g2_embd = NativeLayer.deserialize( + repformers_variable.pop("g2_embd") + ) + obj.repformers["davg"] = statistic_repformers["davg"] + obj.repformers["dstd"] = statistic_repformers["dstd"] + obj.repformers.layers = [ + RepformerLayer.deserialize(layer) for layer in repformer_layers + ] + return obj + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. 
+ + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + local_jdata_cpy = local_jdata.copy() + update_sel = UpdateSel() + local_jdata_cpy = update_sel.update_one_sel( + global_jdata, + local_jdata_cpy, + True, + rcut_key="repinit_rcut", + sel_key="repinit_nsel", + ) + local_jdata_cpy = update_sel.update_one_sel( + global_jdata, + local_jdata_cpy, + True, + rcut_key="repformer_rcut", + sel_key="repformer_nsel", + ) + return local_jdata_cpy diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py new file mode 100644 index 0000000000..bd2f97a29a --- /dev/null +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -0,0 +1,1641 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np + +from deepmd.dpmodel.utils.network import ( + LayerNorm, + NativeLayer, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +try: + from deepmd._version import version as __version__ +except ImportError: + __version__ = "unknown" + +from typing import ( + Callable, + List, + Optional, + Tuple, + Union, +) + +from deepmd.dpmodel import ( + PRECISION_DICT, + NativeOP, +) +from deepmd.dpmodel.utils import ( + EnvMat, + PairExcludeMask, +) +from deepmd.dpmodel.utils.network import ( + get_activation_fn, +) + +from .descriptor import ( + DescriptorBlock, +) +from .dpa1 import ( + np_softmax, +) + + +@DescriptorBlock.register("se_repformer") +@DescriptorBlock.register("se_uni") +class DescrptBlockRepformers(NativeOP, DescriptorBlock): + def __init__( + self, + rcut, + rcut_smth, + sel: int, + ntypes: int, + nlayers: int = 3, + g1_dim=128, + g2_dim=16, + axis_neuron: int = 4, + direct_dist: bool = False, + update_g1_has_conv: bool = True, + update_g1_has_drrd: bool = True, + update_g1_has_grrg: bool = True, + update_g1_has_attn: bool = True, + update_g2_has_g1g1: bool = True, + update_g2_has_attn: bool = True, + update_h2: bool = False, + attn1_hidden: int = 64, + attn1_nhead: int = 4, + attn2_hidden: int = 16, + attn2_nhead: int = 4, + attn2_has_gate: bool = False, + activation_function: str = "tanh", + update_style: str = "res_avg", + update_residual: float = 0.001, + update_residual_init: str = "norm", + set_davg_zero: bool = True, + smooth: bool = True, + exclude_types: List[Tuple[int, int]] = [], + env_protection: float = 0.0, + precision: str = "float64", + resnet_dt: bool = False, + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + ): + r""" + The repformer descriptor block. + + Parameters + ---------- + rcut : float + The cut-off radius. + rcut_smth : float + Where to start smoothing. For example the 1/r term is smoothed from rcut to rcut_smth. + sel : int + Maximally possible number of selected neighbors. + ntypes : int + Number of element types + nlayers : int, optional + Number of repformer layers. + g1_dim : int, optional + Dimension of the first graph convolution layer. + g2_dim : int, optional + Dimension of the second graph convolution layer. + axis_neuron : int, optional + Size of the submatrix of G (embedding matrix). + direct_dist : bool, optional + Whether to use direct distance information (1/r term) in the repformer block. + update_g1_has_conv : bool, optional + Whether to update the g1 rep with convolution term. + update_g1_has_drrd : bool, optional + Whether to update the g1 rep with the drrd term. 
+ update_g1_has_grrg : bool, optional + Whether to update the g1 rep with the grrg term. + update_g1_has_attn : bool, optional + Whether to update the g1 rep with the localized self-attention. + update_g2_has_g1g1 : bool, optional + Whether to update the g2 rep with the g1xg1 term. + update_g2_has_attn : bool, optional + Whether to update the g2 rep with the gated self-attention. + update_h2 : bool, optional + Whether to update the h2 rep. + attn1_hidden : int, optional + The hidden dimension of localized self-attention to update the g1 rep. + attn1_nhead : int, optional + The number of heads in localized self-attention to update the g1 rep. + attn2_hidden : int, optional + The hidden dimension of gated self-attention to update the g2 rep. + attn2_nhead : int, optional + The number of heads in gated self-attention to update the g2 rep. + attn2_has_gate : bool, optional + Whether to use gate in the gated self-attention to update the g2 rep. + activation_function : str, optional + The activation function in the embedding net. + update_style : str, optional + Style to update a representation. + Supported options are: + -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) + -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) + where `r1`, `r2` ... `r3` are residual weights defined by `update_residual` + and `update_residual_init`. + update_residual : float, optional + When update using residual mode, the initial std of residual vector weights. + update_residual_init : str, optional + When update using residual mode, the initialization mode of residual vector weights. + set_davg_zero : bool, optional + Set the normalization average to zero. + precision : str, optional + The precision of the embedding net parameters. + smooth : bool, optional + Whether to use smoothness in processes such as attention weights calculation. + exclude_types : List[List[int]], optional + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + resnet_dt : bool, optional + Whether to use a "Timestep" in the skip connection. + trainable_ln : bool, optional + Whether to use trainable shift and scale weights in layer normalization. + ln_eps : float, optional + The epsilon value for layer normalization. + """ + super().__init__() + self.rcut = rcut + self.rcut_smth = rcut_smth + self.ntypes = ntypes + self.nlayers = nlayers + sel = [sel] if isinstance(sel, int) else sel + self.nnei = sum(sel) + self.ndescrpt = self.nnei * 4 # use full descriptor. 
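For reference, the factor of 4 in `ndescrpt` is the per-neighbor width of the environment matrix: one smoothed radial component plus three direction components. This is exactly how `call` splits it further down (a sketch of the split that already appears below, not new behavior):

    # dmatrix: nf x nloc x nnei x 4
    g2, h2 = np.split(dmatrix, [1], axis=-1)  # g2: 1 radial channel, h2: 3 direction channels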
+ assert len(sel) == 1 + self.sel = sel + self.sec = self.sel + self.split_sel = self.sel + self.axis_neuron = axis_neuron + self.set_davg_zero = set_davg_zero + self.g1_dim = g1_dim + self.g2_dim = g2_dim + self.update_g1_has_conv = update_g1_has_conv + self.update_g1_has_drrd = update_g1_has_drrd + self.update_g1_has_grrg = update_g1_has_grrg + self.update_g1_has_attn = update_g1_has_attn + self.update_g2_has_g1g1 = update_g2_has_g1g1 + self.update_g2_has_attn = update_g2_has_attn + self.update_h2 = update_h2 + self.attn1_hidden = attn1_hidden + self.attn1_nhead = attn1_nhead + self.attn2_has_gate = attn2_has_gate + self.attn2_hidden = attn2_hidden + self.attn2_nhead = attn2_nhead + self.activation_function = activation_function + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.direct_dist = direct_dist + self.act = get_activation_fn(self.activation_function) + self.smooth = smooth + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) + self.env_protection = env_protection + self.precision = precision + self.resnet_dt = resnet_dt + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.epsilon = 1e-4 + + self.g2_embd = NativeLayer(1, self.g2_dim, precision=precision) + layers = [] + for ii in range(nlayers): + layers.append( + RepformerLayer( + self.rcut, + self.rcut_smth, + self.sel, + self.ntypes, + self.g1_dim, + self.g2_dim, + axis_neuron=self.axis_neuron, + update_chnnl_2=(ii != nlayers - 1), + update_g1_has_conv=self.update_g1_has_conv, + update_g1_has_drrd=self.update_g1_has_drrd, + update_g1_has_grrg=self.update_g1_has_grrg, + update_g1_has_attn=self.update_g1_has_attn, + update_g2_has_g1g1=self.update_g2_has_g1g1, + update_g2_has_attn=self.update_g2_has_attn, + update_h2=self.update_h2, + attn1_hidden=self.attn1_hidden, + attn1_nhead=self.attn1_nhead, + attn2_has_gate=self.attn2_has_gate, + attn2_hidden=self.attn2_hidden, + attn2_nhead=self.attn2_nhead, + activation_function=self.activation_function, + update_style=self.update_style, + update_residual=self.update_residual, + update_residual_init=self.update_residual_init, + smooth=self.smooth, + trainable_ln=self.trainable_ln, + ln_eps=self.ln_eps, + precision=precision, + ) + ) + self.layers = layers + + wanted_shape = (self.ntypes, self.nnei, 4) + self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection) + self.mean = np.zeros(wanted_shape, dtype=PRECISION_DICT[self.precision]) + self.stddev = np.ones(wanted_shape, dtype=PRECISION_DICT[self.precision]) + self.orig_sel = self.sel + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_in(self) -> int: + """Returns the output dimension.""" + return self.dim_in + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.dim_out + + def get_dim_emb(self) -> int: + """Returns the embedding dimension g2.""" + return self.g2_dim + + def __setitem__(self, key, value): + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + 
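The "davg"/"dstd" aliases handled here (and in `__getitem__` just below) are how external code injects the environment-matrix statistics. A minimal usage sketch, assuming the module path introduced by this patch:

    import numpy as np

    from deepmd.dpmodel.descriptor.repformers import DescrptBlockRepformers

    block = DescrptBlockRepformers(rcut=4.0, rcut_smth=3.5, sel=8, ntypes=2)
    block["davg"] = np.zeros((2, 8, 4))  # routed to block.mean by __setitem__
    block["dstd"] = np.ones((2, 8, 4))   # routed to block.stddev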
def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + + def mixed_types(self) -> bool: + """If true, the discriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the discriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.g1_dim + + @property + def dim_in(self): + """Returns the atomic input dimension of this descriptor.""" + return self.g1_dim + + @property + def dim_emb(self): + """Returns the embedding dimension g2.""" + return self.get_dim_emb() + + def compute_input_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + path: Optional[DPPath] = None, + ): + """Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.""" + raise NotImplementedError + + def get_stats(self): + """Get the statistics of the descriptor.""" + raise NotImplementedError + + def reinit_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + + def call( + self, + nlist: np.ndarray, + coord_ext: np.ndarray, + atype_ext: np.ndarray, + atype_embd_ext: Optional[np.ndarray] = None, + mapping: Optional[np.ndarray] = None, + ): + exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) + nlist = nlist * exclude_mask + # nf x nloc x nnei x 4 + dmatrix, diff, sw = self.env_mat.call( + coord_ext, atype_ext, nlist, self.mean, self.stddev + ) + nf, nloc, nnei, _ = dmatrix.shape + # nf x nloc x nnei + nlist_mask = nlist != -1 + # nf x nloc x nnei + sw = sw.reshape(nf, nloc, nnei) + sw = np.where(nlist_mask, sw, 0.0) + # nf x nloc x tebd_dim + atype_embd = atype_embd_ext[:, :nloc, :] + assert list(atype_embd.shape) == [nf, nloc, self.g1_dim] + + g1 = self.act(atype_embd) + # nf x nloc x nnei x 1, nf x nloc x nnei x 3 + if not self.direct_dist: + g2, h2 = np.split(dmatrix, [1], axis=-1) + else: + g2, h2 = np.linalg.norm(diff, axis=-1, keepdims=True), diff + g2 = g2 / self.rcut + h2 = h2 / self.rcut + # nf x nloc x nnei x ng2 + g2 = self.act(self.g2_embd(g2)) + # set all padding positions to index of 0 + # if a neighbor is real or not is indicated by nlist_mask + nlist[nlist == -1] = 0 + # nf x nall x ng1 + mapping = np.tile(mapping.reshape(nf, -1, 1), (1, 1, self.g1_dim)) + for idx, ll in enumerate(self.layers): + # g1: nf x nloc x ng1 + # g1_ext: nf x nall x ng1 + g1_ext = np.take_along_axis(g1, mapping, axis=1) + g1, g2, h2 = ll.call( + g1_ext, + g2, + h2, + nlist, + nlist_mask, + sw, + ) + + # nf x nloc x 3 x ng2 + h2g2 = _cal_hg(g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon) + # (nf x nloc) x ng2 x 3 + rot_mat = np.transpose(h2g2, (0, 1, 3, 2)) + return g1, g2, h2, rot_mat.reshape(-1, nloc, self.dim_emb, 3), sw + + +# translated by GPT and modified +def get_residual( + _dim: int, + _scale: float, + _mode: str = "norm", + trainable: bool = True, + precision: str = "float64", +) -> np.ndarray: + """ + Get residual tensor for one update vector. + + Parameters + ---------- + _dim : int + The dimension of the update vector. 
+ _scale + The initial scale of the residual tensor. See `_mode` for details. + _mode + The mode of residual initialization for the residual tensor. + - "norm" (default): init residual using normal with `_scale` std. + - "const": init residual using element-wise constants of `_scale`. + trainable + Whether the residual tensor is trainable. + precision + The precision of the residual tensor. + """ + residual = np.zeros(_dim, dtype=PRECISION_DICT[precision]) + rng = np.random.default_rng() + if trainable: + if _mode == "norm": + residual = rng.normal(scale=_scale, size=_dim).astype( + PRECISION_DICT[precision] + ) + elif _mode == "const": + residual.fill(_scale) + else: + raise RuntimeError(f"Unsupported initialization mode '{_mode}'!") + return residual + + +def _make_nei_g1( + g1_ext: np.ndarray, + nlist: np.ndarray, +) -> np.ndarray: + """ + Make neighbor-wise atomic invariant rep. + + Parameters + ---------- + g1_ext + Extended atomic invariant rep, with shape [nf, nall, ng1]. + nlist + Neighbor list, with shape [nf, nloc, nnei]. + + Returns + ------- + gg1: np.ndarray + Neighbor-wise atomic invariant rep, with shape [nf, nloc, nnei, ng1]. + """ + # nlist: nf x nloc x nnei + nf, nloc, nnei = nlist.shape + # g1_ext: nf x nall x ng1 + ng1 = g1_ext.shape[-1] + # index: nf x (nloc x nnei) x ng1 + index = np.tile(nlist.reshape(nf, nloc * nnei, 1), (1, 1, ng1)) + # gg1 : nf x (nloc x nnei) x ng1 + gg1 = np.take_along_axis(g1_ext, index, axis=1) + # gg1 : nf x nloc x nnei x ng1 + gg1 = gg1.reshape(nf, nloc, nnei, ng1) + return gg1 + + +def _apply_nlist_mask( + gg: np.ndarray, + nlist_mask: np.ndarray, +) -> np.ndarray: + """ + Apply nlist mask to neighbor-wise rep tensors. + + Parameters + ---------- + gg + Neighbor-wise rep tensors, with shape [nf, nloc, nnei, d]. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape [nf, nloc, nnei]. + """ + masked_gg = np.where(nlist_mask[:, :, :, None], gg, 0.0) + return masked_gg + + +def _apply_switch(gg: np.ndarray, sw: np.ndarray) -> np.ndarray: + """ + Apply switch function to neighbor-wise rep tensors. + + Parameters + ---------- + gg + Neighbor-wise rep tensors, with shape [nf, nloc, nnei, d]. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape [nf, nloc, nnei]. + """ + # gg: nf x nloc x nnei x d + # sw: nf x nloc x nnei + return gg * sw[:, :, :, None] + + +def _cal_hg( + g: np.ndarray, + h: np.ndarray, + nlist_mask: np.ndarray, + sw: np.ndarray, + smooth: bool = True, + epsilon: float = 1e-4, +) -> np.ndarray: + """ + Calculate the transposed rotation matrix. + + Parameters + ---------- + g + Neighbor-wise/Pair-wise invariant rep tensors, with shape [nf, nloc, nnei, ng]. + h + Neighbor-wise/Pair-wise equivariant rep tensors, with shape [nf, nloc, nnei, 3]. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape [nf, nloc, nnei]. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape [nf, nloc, nnei]. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. + + Returns + ------- + hg + The transposed rotation matrix, with shape [nf, nloc, 3, ng]. 
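In plain terms, the body that follows computes h^T @ g per atom, normalized either by nnei (smooth branch, after applying the switch function) or by the actual number of real neighbors (non-smooth branch). A sketch of the smooth branch only:

    # g is already masked and multiplied by sw at this point; h: nf x nloc x nnei x 3
    hg = np.matmul(np.transpose(h, (0, 1, 3, 2)), g) / nnei  # nf x nloc x 3 x ng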
+ """ + # g: nf x nloc x nnei x ng + # h: nf x nloc x nnei x 3 + # msk: nf x nloc x nnei + nf, nloc, nnei, _ = g.shape + ng = g.shape[-1] + # nf x nloc x nnei x ng + g = _apply_nlist_mask(g, nlist_mask) + if not smooth: + # nf x nloc + invnnei = 1.0 / (epsilon + np.sum(nlist_mask, axis=-1)) + # nf x nloc x 1 x 1 + invnnei = invnnei[:, :, np.newaxis, np.newaxis] + else: + g = _apply_switch(g, sw) + invnnei = (1.0 / float(nnei)) * np.ones((nf, nloc, 1, 1), dtype=g.dtype) + # nf x nloc x 3 x ng + hg = np.matmul(np.transpose(h, axes=(0, 1, 3, 2)), g) * invnnei + return hg + + +def _cal_grrg(hg: np.ndarray, axis_neuron: int) -> np.ndarray: + """ + Calculate the atomic invariant rep. + + Parameters + ---------- + hg + The transposed rotation matrix, with shape [nf, nloc, 3, ng]. + axis_neuron + Size of the submatrix. + + Returns + ------- + grrg + Atomic invariant rep, with shape [nf, nloc, (axis_neuron * ng)]. + """ + # nf x nloc x 3 x ng + nf, nloc, _, ng = hg.shape + # nf x nloc x 3 x axis + hgm = np.split(hg, [axis_neuron], axis=-1)[0] + # nf x nloc x axis_neuron x ng + grrg = np.matmul(np.transpose(hgm, axes=(0, 1, 3, 2)), hg) / (3.0**1) + # nf x nloc x (axis_neuron * ng) + grrg = grrg.reshape(nf, nloc, axis_neuron * ng) + return grrg + + +def symmetrization_op( + g: np.ndarray, + h: np.ndarray, + nlist_mask: np.ndarray, + sw: np.ndarray, + axis_neuron: int, + smooth: bool = True, + epsilon: float = 1e-4, +) -> np.ndarray: + """ + Symmetrization operator to obtain atomic invariant rep. + + Parameters + ---------- + g + Neighbor-wise/Pair-wise invariant rep tensors, with shape [nf, nloc, nnei, ng]. + h + Neighbor-wise/Pair-wise equivariant rep tensors, with shape [nf, nloc, nnei, 3]. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape [nf, nloc, nnei]. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape [nf, nloc, nnei]. + axis_neuron + Size of the submatrix. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. + + Returns + ------- + grrg + Atomic invariant rep, with shape [nf, nloc, (axis_neuron * ng)]. 
+ """ + # g: nf x nloc x nnei x ng + # h: nf x nloc x nnei x 3 + # msk: nf x nloc x nnei + nf, nloc, nnei, _ = g.shape + # nf x nloc x 3 x ng + hg = _cal_hg(g, h, nlist_mask, sw, smooth=smooth, epsilon=epsilon) + # nf x nloc x (axis_neuron x ng) + grrg = _cal_grrg(hg, axis_neuron) + return grrg + + +class Atten2Map(NativeOP): + def __init__( + self, + input_dim: int, + hidden_dim: int, + head_num: int, + has_gate: bool = False, # apply gate to attn map + smooth: bool = True, + attnw_shift: float = 20.0, + precision: str = "float64", + ): + """Return neighbor-wise multi-head self-attention maps, with gate mechanism.""" + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.head_num = head_num + self.mapqk = NativeLayer( + input_dim, hidden_dim * 2 * head_num, bias=False, precision=precision + ) + self.has_gate = has_gate + self.smooth = smooth + self.attnw_shift = attnw_shift + self.precision = precision + + def call( + self, + g2: np.ndarray, # nf x nloc x nnei x ng2 + h2: np.ndarray, # nf x nloc x nnei x 3 + nlist_mask: np.ndarray, # nf x nloc x nnei + sw: np.ndarray, # nf x nloc x nnei + ) -> np.ndarray: + ( + nf, + nloc, + nnei, + _, + ) = g2.shape + nd, nh = self.hidden_dim, self.head_num + # nf x nloc x nnei x nd x (nh x 2) + g2qk = self.mapqk(g2).reshape(nf, nloc, nnei, nd, nh * 2) + # nf x nloc x (nh x 2) x nnei x nd + g2qk = np.transpose(g2qk, (0, 1, 4, 2, 3)) + # nf x nloc x nh x nnei x nd + g2q, g2k = np.split(g2qk, [nh], axis=2) + # g2q = np.linalg.norm(g2q, axis=-1) + # g2k = np.linalg.norm(g2k, axis=-1) + # nf x nloc x nh x nnei x nnei + attnw = np.matmul(g2q, np.transpose(g2k, axes=(0, 1, 2, 4, 3))) / nd**0.5 + if self.has_gate: + gate = np.matmul(h2, np.transpose(h2, axes=(0, 1, 3, 2))).reshape( + nf, nloc, 1, nnei, nnei + ) + attnw = attnw * gate + # mask the attenmap, nf x nloc x 1 x 1 x nnei + attnw_mask = ~np.expand_dims(np.expand_dims(nlist_mask, axis=2), axis=2) + # mask the attenmap, nf x nloc x 1 x nnei x 1 + attnw_mask_c = ~np.expand_dims(np.expand_dims(nlist_mask, axis=2), axis=-1) + if self.smooth: + attnw = (attnw + self.attnw_shift) * sw[:, :, None, :, None] * sw[ + :, :, None, None, : + ] - self.attnw_shift + else: + attnw = np.where(attnw_mask, -np.inf, attnw) + attnw = np_softmax(attnw, axis=-1) + attnw = np.where(attnw_mask, 0.0, attnw) + # nf x nloc x nh x nnei x nnei + attnw = np.where(attnw_mask_c, 0.0, attnw) + if self.smooth: + attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :] + # nf x nloc x nnei x nnei + h2h2t = np.matmul(h2, np.transpose(h2, axes=(0, 1, 3, 2))) / 3.0**0.5 + # nf x nloc x nh x nnei x nnei + ret = attnw * h2h2t[:, :, None, :, :] + # ret = np.exp(g2qk - np.max(g2qk, axis=-1, keepdims=True)) + # nf x nloc x nnei x nnei x nh + ret = np.transpose(ret, (0, 1, 3, 4, 2)) + return ret + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "Atten2Map", + "@version": 1, + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "head_num": self.head_num, + "has_gate": self.has_gate, + "smooth": self.smooth, + "attnw_shift": self.attnw_shift, + "precision": self.precision, + "mapqk": self.mapqk.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "Atten2Map": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. 
+ """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + mapqk = data.pop("mapqk") + obj = cls(**data) + obj.mapqk = NativeLayer.deserialize(mapqk) + return obj + + +class Atten2MultiHeadApply(NativeOP): + def __init__( + self, + input_dim: int, + head_num: int, + precision: str = "float64", + ): + super().__init__() + self.input_dim = input_dim + self.head_num = head_num + self.mapv = NativeLayer( + input_dim, input_dim * head_num, bias=False, precision=precision + ) + self.head_map = NativeLayer( + input_dim * head_num, input_dim, precision=precision + ) + self.precision = precision + + def call( + self, + AA: np.ndarray, # nf x nloc x nnei x nnei x nh + g2: np.ndarray, # nf x nloc x nnei x ng2 + ) -> np.ndarray: + nf, nloc, nnei, ng2 = g2.shape + nh = self.head_num + # nf x nloc x nnei x ng2 x nh + g2v = self.mapv(g2).reshape(nf, nloc, nnei, ng2, nh) + # nf x nloc x nh x nnei x ng2 + g2v = np.transpose(g2v, (0, 1, 4, 2, 3)) + # g2v = np.linalg.norm(g2v, axis=-1) + # nf x nloc x nh x nnei x nnei + AA = np.transpose(AA, (0, 1, 4, 2, 3)) + # nf x nloc x nh x nnei x ng2 + ret = np.matmul(AA, g2v) + # nf x nloc x nnei x ng2 x nh + ret = np.transpose(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, (ng2 * nh)) + # nf x nloc x nnei x ng2 + return self.head_map(ret) + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "Atten2MultiHeadApply", + "@version": 1, + "input_dim": self.input_dim, + "head_num": self.head_num, + "precision": self.precision, + "mapv": self.mapv.serialize(), + "head_map": self.head_map.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "Atten2MultiHeadApply": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + mapv = data.pop("mapv") + head_map = data.pop("head_map") + obj = cls(**data) + obj.mapv = NativeLayer.deserialize(mapv) + obj.head_map = NativeLayer.deserialize(head_map) + return obj + + +class Atten2EquiVarApply(NativeOP): + def __init__( + self, + input_dim: int, + head_num: int, + precision: str = "float64", + ): + super().__init__() + self.input_dim = input_dim + self.head_num = head_num + self.head_map = NativeLayer(head_num, 1, bias=False, precision=precision) + self.precision = precision + + def call( + self, + AA: np.ndarray, # nf x nloc x nnei x nnei x nh + h2: np.ndarray, # nf x nloc x nnei x 3 + ) -> np.ndarray: + nf, nloc, nnei, _ = h2.shape + nh = self.head_num + # nf x nloc x nh x nnei x nnei + AA = np.transpose(AA, (0, 1, 4, 2, 3)) + h2m = np.expand_dims(h2, axis=2) + # nf x nloc x nh x nnei x 3 + h2m = np.tile(h2m, (1, 1, nh, 1, 1)) + # nf x nloc x nh x nnei x 3 + ret = np.matmul(AA, h2m) + # nf x nloc x nnei x 3 x nh + ret = np.transpose(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, 3, nh) + # nf x nloc x nnei x 3 + return np.squeeze(self.head_map(ret), axis=-1) + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "Atten2EquiVarApply", + "@version": 1, + "input_dim": self.input_dim, + "head_num": self.head_num, + "precision": self.precision, + "head_map": self.head_map.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "Atten2EquiVarApply": + """Deserialize the networks from a dict. 
+ + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + head_map = data.pop("head_map") + obj = cls(**data) + obj.head_map = NativeLayer.deserialize(head_map) + return obj + + +class LocalAtten(NativeOP): + def __init__( + self, + input_dim: int, + hidden_dim: int, + head_num: int, + smooth: bool = True, + attnw_shift: float = 20.0, + precision: str = "float64", + ): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.head_num = head_num + self.mapq = NativeLayer( + input_dim, hidden_dim * 1 * head_num, bias=False, precision=precision + ) + self.mapkv = NativeLayer( + input_dim, + (hidden_dim + input_dim) * head_num, + bias=False, + precision=precision, + ) + self.head_map = NativeLayer( + input_dim * head_num, input_dim, precision=precision + ) + self.smooth = smooth + self.attnw_shift = attnw_shift + self.precision = precision + + def call( + self, + g1: np.ndarray, # nf x nloc x ng1 + gg1: np.ndarray, # nf x nloc x nnei x ng1 + nlist_mask: np.ndarray, # nf x nloc x nnei + sw: np.ndarray, # nf x nloc x nnei + ) -> np.ndarray: + nf, nloc, nnei = nlist_mask.shape + ni, nd, nh = self.input_dim, self.hidden_dim, self.head_num + assert ni == g1.shape[-1] + assert ni == gg1.shape[-1] + # nf x nloc x nd x nh + g1q = self.mapq(g1).reshape(nf, nloc, nd, nh) + # nf x nloc x nh x nd + g1q = np.transpose(g1q, (0, 1, 3, 2)) + # nf x nloc x nnei x (nd+ni) x nh + gg1kv = self.mapkv(gg1).reshape(nf, nloc, nnei, nd + ni, nh) + gg1kv = np.transpose(gg1kv, (0, 1, 4, 2, 3)) + # nf x nloc x nh x nnei x nd, nf x nloc x nh x nnei x ng1 + gg1k, gg1v = np.split(gg1kv, [nd], axis=-1) + + # nf x nloc x nh x 1 x nnei + attnw = ( + np.matmul( + np.expand_dims(g1q, axis=-2), np.transpose(gg1k, axes=(0, 1, 2, 4, 3)) + ) + / nd**0.5 + ) + # nf x nloc x nh x nnei + attnw = np.squeeze(attnw, axis=-2) + # mask the attenmap, nf x nloc x 1 x nnei + attnw_mask = ~np.expand_dims(nlist_mask, axis=-2) + # nf x nloc x nh x nnei + if self.smooth: + attnw = (attnw + self.attnw_shift) * np.expand_dims( + sw, axis=-2 + ) - self.attnw_shift + else: + attnw = np.where(attnw_mask, -np.inf, attnw) + attnw = np_softmax(attnw, axis=-1) + attnw = np.where(attnw_mask, 0.0, attnw) + if self.smooth: + attnw = attnw * np.expand_dims(sw, axis=-2) + + # nf x nloc x nh x ng1 + ret = ( + np.matmul(np.expand_dims(attnw, axis=-2), gg1v) + .squeeze(-2) + .reshape(nf, nloc, nh * ni) + ) + # nf x nloc x ng1 + ret = self.head_map(ret) + return ret + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "LocalAtten", + "@version": 1, + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "head_num": self.head_num, + "smooth": self.smooth, + "attnw_shift": self.attnw_shift, + "precision": self.precision, + "mapq": self.mapq.serialize(), + "mapkv": self.mapkv.serialize(), + "head_map": self.head_map.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "LocalAtten": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. 
+ """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + mapq = data.pop("mapq") + mapkv = data.pop("mapkv") + head_map = data.pop("head_map") + obj = cls(**data) + obj.mapq = NativeLayer.deserialize(mapq) + obj.mapkv = NativeLayer.deserialize(mapkv) + obj.head_map = NativeLayer.deserialize(head_map) + return obj + + +class RepformerLayer(NativeOP): + def __init__( + self, + rcut, + rcut_smth, + sel: int, + ntypes: int, + g1_dim=128, + g2_dim=16, + axis_neuron: int = 4, + update_chnnl_2: bool = True, + update_g1_has_conv: bool = True, + update_g1_has_drrd: bool = True, + update_g1_has_grrg: bool = True, + update_g1_has_attn: bool = True, + update_g2_has_g1g1: bool = True, + update_g2_has_attn: bool = True, + update_h2: bool = False, + attn1_hidden: int = 64, + attn1_nhead: int = 4, + attn2_hidden: int = 16, + attn2_nhead: int = 4, + attn2_has_gate: bool = False, + activation_function: str = "tanh", + update_style: str = "res_avg", + update_residual: float = 0.001, + update_residual_init: str = "norm", + smooth: bool = True, + precision: str = "float64", + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + ): + super().__init__() + self.epsilon = 1e-4 # protection of 1./nnei + self.rcut = rcut + self.rcut_smth = rcut_smth + self.ntypes = ntypes + sel = [sel] if isinstance(sel, int) else sel + self.nnei = sum(sel) + assert len(sel) == 1 + self.sel = sel + self.sec = self.sel + self.axis_neuron = axis_neuron + self.activation_function = activation_function + self.act = get_activation_fn(self.activation_function) + self.update_g1_has_grrg = update_g1_has_grrg + self.update_g1_has_drrd = update_g1_has_drrd + self.update_g1_has_conv = update_g1_has_conv + self.update_g1_has_attn = update_g1_has_attn + self.update_chnnl_2 = update_chnnl_2 + self.update_g2_has_g1g1 = update_g2_has_g1g1 if self.update_chnnl_2 else False + self.update_g2_has_attn = update_g2_has_attn if self.update_chnnl_2 else False + self.update_h2 = update_h2 if self.update_chnnl_2 else False + del update_g2_has_g1g1, update_g2_has_attn, update_h2 + self.attn1_hidden = attn1_hidden + self.attn1_nhead = attn1_nhead + self.attn2_hidden = attn2_hidden + self.attn2_nhead = attn2_nhead + self.attn2_has_gate = attn2_has_gate + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.smooth = smooth + self.g1_dim = g1_dim + self.g2_dim = g2_dim + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.precision = precision + + assert update_residual_init in [ + "norm", + "const", + ], "'update_residual_init' only support 'norm' or 'const'!" 
+ self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.g1_residual = [] + self.g2_residual = [] + self.h2_residual = [] + + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + + g1_in_dim = self.cal_1_dim(g1_dim, g2_dim, self.axis_neuron) + self.linear1 = NativeLayer(g1_in_dim, g1_dim, precision=precision) + self.linear2 = None + self.proj_g1g2 = None + self.proj_g1g1g2 = None + self.attn2g_map = None + self.attn2_mh_apply = None + self.attn2_lm = None + self.attn2_ev_apply = None + self.loc_attn = None + + if self.update_chnnl_2: + self.linear2 = NativeLayer(g2_dim, g2_dim, precision=precision) + if self.update_style == "res_residual": + self.g2_residual.append( + get_residual( + g2_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + if self.update_g1_has_conv: + self.proj_g1g2 = NativeLayer( + g1_dim, g2_dim, bias=False, precision=precision + ) + if self.update_g2_has_g1g1: + self.proj_g1g1g2 = NativeLayer( + g1_dim, g2_dim, bias=False, precision=precision + ) + if self.update_style == "res_residual": + self.g2_residual.append( + get_residual( + g2_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + if self.update_g2_has_attn or self.update_h2: + self.attn2g_map = Atten2Map( + g2_dim, + attn2_hidden, + attn2_nhead, + attn2_has_gate, + self.smooth, + precision=precision, + ) + if self.update_g2_has_attn: + self.attn2_mh_apply = Atten2MultiHeadApply( + g2_dim, attn2_nhead, precision=precision + ) + self.attn2_lm = LayerNorm( + g2_dim, eps=ln_eps, trainable=trainable_ln, precision=precision + ) + if self.update_style == "res_residual": + self.g2_residual.append( + get_residual( + g2_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + + if self.update_h2: + self.attn2_ev_apply = Atten2EquiVarApply( + g2_dim, attn2_nhead, precision=precision + ) + if self.update_style == "res_residual": + self.h2_residual.append( + get_residual( + 1, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + if self.update_g1_has_attn: + self.loc_attn = LocalAtten( + g1_dim, attn1_hidden, attn1_nhead, self.smooth, precision=precision + ) + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + ) + ) + + def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: + ret = g1d + if self.update_g1_has_grrg: + ret += g2d * ax + if self.update_g1_has_drrd: + ret += g1d * ax + if self.update_g1_has_conv: + ret += g2d + return ret + + def _update_h2( + self, + h2: np.ndarray, + attn: np.ndarray, + ) -> np.ndarray: + """ + Calculate the attention weights update for pair-wise equivariant rep. + + Parameters + ---------- + h2 + Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. + attn + Attention weights from g2 attention, with shape nf x nloc x nnei x nnei x nh2. + """ + assert self.attn2_ev_apply is not None + # nf x nloc x nnei x nh2 + h2_1 = self.attn2_ev_apply(attn, h2) + return h2_1 + + def _update_g1_conv( + self, + gg1: np.ndarray, + g2: np.ndarray, + nlist_mask: np.ndarray, + sw: np.ndarray, + ) -> np.ndarray: + """ + Calculate the convolution update for atomic invariant rep. 
+ + Parameters + ---------- + gg1 + Neighbor-wise atomic invariant rep, with shape nf x nloc x nnei x ng1. + g2 + Pair invariant rep, with shape nf x nloc x nnei x ng2. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nf x nloc x nnei. + """ + assert self.proj_g1g2 is not None + nf, nloc, nnei, _ = g2.shape + ng1 = gg1.shape[-1] + ng2 = g2.shape[-1] + # gg1 : nf x nloc x nnei x ng2 + gg1 = self.proj_g1g2(gg1).reshape(nf, nloc, nnei, ng2) + # nf x nloc x nnei x ng2 + gg1 = _apply_nlist_mask(gg1, nlist_mask) + if not self.smooth: + # normalized by number of neighbors, not smooth + # nf x nloc + invnnei = 1.0 / (self.epsilon + np.sum(nlist_mask, axis=-1)) + # nf x nloc x 1 + invnnei = invnnei[:, :, np.newaxis] + else: + gg1 = _apply_switch(gg1, sw) + invnnei = (1.0 / float(nnei)) * np.ones((nf, nloc, 1), dtype=gg1.dtype) + # nf x nloc x ng2 + g1_11 = np.sum(g2 * gg1, axis=2) * invnnei + return g1_11 + + def _update_g2_g1g1( + self, + g1: np.ndarray, # nf x nloc x ng1 + gg1: np.ndarray, # nf x nloc x nnei x ng1 + nlist_mask: np.ndarray, # nf x nloc x nnei + sw: np.ndarray, # nf x nloc x nnei + ) -> np.ndarray: + """ + Update the g2 using element-wise dot g1_i * g1_j. + + Parameters + ---------- + g1 + Atomic invariant rep, with shape nf x nloc x ng1. + gg1 + Neighbor-wise atomic invariant rep, with shape nf x nloc x nnei x ng1. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nf x nloc x nnei. + """ + ret = np.expand_dims(g1, axis=-2) * gg1 + # nf x nloc x nnei x ng1 + ret = _apply_nlist_mask(ret, nlist_mask) + if self.smooth: + ret = _apply_switch(ret, sw) + return ret + + def call( + self, + g1_ext: np.ndarray, # nf x nall x ng1 + g2: np.ndarray, # nf x nloc x nnei x ng2 + h2: np.ndarray, # nf x nloc x nnei x 3 + nlist: np.ndarray, # nf x nloc x nnei + nlist_mask: np.ndarray, # nf x nloc x nnei + sw: np.ndarray, # switch func, nf x nloc x nnei + ): + """ + Parameters + ---------- + g1_ext : nf x nall x ng1 extended single-atom chanel + g2 : nf x nloc x nnei x ng2 pair-atom channel, invariant + h2 : nf x nloc x nnei x 3 pair-atom channel, equivariant + nlist : nf x nloc x nnei neighbor list (padded neis are set to 0) + nlist_mask : nf x nloc x nnei masks of the neighbor list. 
real nei 1 otherwise 0 + sw : nf x nloc x nnei switch function + + Returns + ------- + g1: nf x nloc x ng1 updated single-atom chanel + g2: nf x nloc x nnei x ng2 updated pair-atom channel, invariant + h2: nf x nloc x nnei x 3 updated pair-atom channel, equivariant + """ + cal_gg1 = ( + self.update_g1_has_drrd + or self.update_g1_has_conv + or self.update_g1_has_attn + or self.update_g2_has_g1g1 + ) + + nf, nloc, nnei, _ = g2.shape + nall = g1_ext.shape[1] + g1, _ = np.split(g1_ext, [nloc], axis=1) + assert (nf, nloc) == g1.shape[:2] + assert (nf, nloc, nnei) == h2.shape[:3] + + g2_update: List[np.ndarray] = [g2] + h2_update: List[np.ndarray] = [h2] + g1_update: List[np.ndarray] = [g1] + g1_mlp: List[np.ndarray] = [g1] + + if cal_gg1: + gg1 = _make_nei_g1(g1_ext, nlist) + else: + gg1 = None + + if self.update_chnnl_2: + # mlp(g2) + assert self.linear2 is not None + # nf x nloc x nnei x ng2 + g2_1 = self.act(self.linear2(g2)) + g2_update.append(g2_1) + + if self.update_g2_has_g1g1: + # linear(g1_i * g1_j) + assert gg1 is not None + assert self.proj_g1g1g2 is not None + g2_update.append( + self.proj_g1g1g2(self._update_g2_g1g1(g1, gg1, nlist_mask, sw)) + ) + + if self.update_g2_has_attn or self.update_h2: + # gated_attention(g2, h2) + assert self.attn2g_map is not None + # nf x nloc x nnei x nnei x nh + AAg = self.attn2g_map(g2, h2, nlist_mask, sw) + + if self.update_g2_has_attn: + assert self.attn2_mh_apply is not None + assert self.attn2_lm is not None + # nf x nloc x nnei x ng2 + g2_2 = self.attn2_mh_apply(AAg, g2) + g2_2 = self.attn2_lm(g2_2) + g2_update.append(g2_2) + + if self.update_h2: + # linear_head(attention_weights * h2) + h2_update.append(self._update_h2(h2, AAg)) + + if self.update_g1_has_conv: + assert gg1 is not None + g1_mlp.append(self._update_g1_conv(gg1, g2, nlist_mask, sw)) + + if self.update_g1_has_grrg: + g1_mlp.append( + symmetrization_op( + g2, + h2, + nlist_mask, + sw, + self.axis_neuron, + smooth=self.smooth, + epsilon=self.epsilon, + ) + ) + + if self.update_g1_has_drrd: + assert gg1 is not None + g1_mlp.append( + symmetrization_op( + gg1, + h2, + nlist_mask, + sw, + self.axis_neuron, + smooth=self.smooth, + epsilon=self.epsilon, + ) + ) + + # nf x nloc x [ng1+ng2+(axisxng2)+(axisxng1)] + # conv grrg drrd + g1_1 = self.act(self.linear1(np.concatenate(g1_mlp, axis=-1))) + g1_update.append(g1_1) + + if self.update_g1_has_attn: + assert gg1 is not None + assert self.loc_attn is not None + g1_update.append(self.loc_attn(g1, gg1, nlist_mask, sw)) + + # update + if self.update_chnnl_2: + g2_new = self.list_update(g2_update, "g2") + h2_new = self.list_update(h2_update, "h2") + else: + g2_new, h2_new = g2, h2 + g1_new = self.list_update(g1_update, "g1") + return g1_new, g2_new, h2_new + + def list_update_res_avg( + self, + update_list: List[np.ndarray], + ) -> np.ndarray: + nitem = len(update_list) + uu = update_list[0] + for ii in range(1, nitem): + uu = uu + update_list[ii] + return uu / (float(nitem) ** 0.5) + + def list_update_res_incr(self, update_list: List[np.ndarray]) -> np.ndarray: + nitem = len(update_list) + uu = update_list[0] + scale = 1.0 / (float(nitem - 1) ** 0.5) if nitem > 1 else 0.0 + for ii in range(1, nitem): + uu = uu + scale * update_list[ii] + return uu + + def list_update_res_residual( + self, update_list: List[np.ndarray], update_name: str = "g1" + ) -> np.ndarray: + nitem = len(update_list) + uu = update_list[0] + if update_name == "g1": + for ii, vv in enumerate(self.g1_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "g2": 
+ for ii, vv in enumerate(self.g2_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "h2": + for ii, vv in enumerate(self.h2_residual): + uu = uu + vv * update_list[ii + 1] + else: + raise NotImplementedError + return uu + + def list_update( + self, update_list: List[np.ndarray], update_name: str = "g1" + ) -> np.ndarray: + if self.update_style == "res_avg": + return self.list_update_res_avg(update_list) + elif self.update_style == "res_incr": + return self.list_update_res_incr(update_list) + elif self.update_style == "res_residual": + return self.list_update_res_residual(update_list, update_name=update_name) + else: + raise RuntimeError(f"unknown update style {self.update_style}") + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + data = { + "@class": "RepformerLayer", + "@version": 1, + "rcut": self.rcut, + "rcut_smth": self.rcut_smth, + "sel": self.sel, + "ntypes": self.ntypes, + "g1_dim": self.g1_dim, + "g2_dim": self.g2_dim, + "axis_neuron": self.axis_neuron, + "update_chnnl_2": self.update_chnnl_2, + "update_g1_has_conv": self.update_g1_has_conv, + "update_g1_has_drrd": self.update_g1_has_drrd, + "update_g1_has_grrg": self.update_g1_has_grrg, + "update_g1_has_attn": self.update_g1_has_attn, + "update_g2_has_g1g1": self.update_g2_has_g1g1, + "update_g2_has_attn": self.update_g2_has_attn, + "update_h2": self.update_h2, + "attn1_hidden": self.attn1_hidden, + "attn1_nhead": self.attn1_nhead, + "attn2_hidden": self.attn2_hidden, + "attn2_nhead": self.attn2_nhead, + "attn2_has_gate": self.attn2_has_gate, + "activation_function": self.activation_function, + "update_style": self.update_style, + "smooth": self.smooth, + "precision": self.precision, + "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, + "linear1": self.linear1.serialize(), + } + if self.update_chnnl_2: + data.update( + { + "linear2": self.linear2.serialize(), + } + ) + if self.update_g1_has_conv: + data.update( + { + "proj_g1g2": self.proj_g1g2.serialize(), + } + ) + if self.update_g2_has_g1g1: + data.update( + { + "proj_g1g1g2": self.proj_g1g1g2.serialize(), + } + ) + if self.update_g2_has_attn or self.update_h2: + data.update( + { + "attn2g_map": self.attn2g_map.serialize(), + } + ) + if self.update_g2_has_attn: + data.update( + { + "attn2_mh_apply": self.attn2_mh_apply.serialize(), + "attn2_lm": self.attn2_lm.serialize(), + } + ) + + if self.update_h2: + data.update( + { + "attn2_ev_apply": self.attn2_ev_apply.serialize(), + } + ) + if self.update_g1_has_attn: + data.update( + { + "loc_attn": self.loc_attn.serialize(), + } + ) + if self.update_style == "res_residual": + data.update( + { + "g1_residual": self.g1_residual, + "g2_residual": self.g2_residual, + "h2_residual": self.h2_residual, + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "RepformerLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. 
+ """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + linear1 = data.pop("linear1") + update_chnnl_2 = data["update_chnnl_2"] + update_g1_has_conv = data["update_g1_has_conv"] + update_g2_has_g1g1 = data["update_g2_has_g1g1"] + update_g2_has_attn = data["update_g2_has_attn"] + update_h2 = data["update_h2"] + update_g1_has_attn = data["update_g1_has_attn"] + update_style = data["update_style"] + + linear2 = data.pop("linear2", None) + proj_g1g2 = data.pop("proj_g1g2", None) + proj_g1g1g2 = data.pop("proj_g1g1g2", None) + attn2g_map = data.pop("attn2g_map", None) + attn2_mh_apply = data.pop("attn2_mh_apply", None) + attn2_lm = data.pop("attn2_lm", None) + attn2_ev_apply = data.pop("attn2_ev_apply", None) + loc_attn = data.pop("loc_attn", None) + g1_residual = data.pop("g1_residual", []) + g2_residual = data.pop("g2_residual", []) + h2_residual = data.pop("h2_residual", []) + + obj = cls(**data) + obj.linear1 = NativeLayer.deserialize(linear1) + if update_chnnl_2: + assert isinstance(linear2, dict) + obj.linear2 = NativeLayer.deserialize(linear2) + if update_g1_has_conv: + assert isinstance(proj_g1g2, dict) + obj.proj_g1g2 = NativeLayer.deserialize(proj_g1g2) + if update_g2_has_g1g1: + assert isinstance(proj_g1g1g2, dict) + obj.proj_g1g1g2 = NativeLayer.deserialize(proj_g1g1g2) + if update_g2_has_attn or update_h2: + assert isinstance(attn2g_map, dict) + obj.attn2g_map = Atten2Map.deserialize(attn2g_map) + if update_g2_has_attn: + assert isinstance(attn2_mh_apply, dict) + assert isinstance(attn2_lm, dict) + obj.attn2_mh_apply = Atten2MultiHeadApply.deserialize(attn2_mh_apply) + obj.attn2_lm = LayerNorm.deserialize(attn2_lm) + if update_h2: + assert isinstance(attn2_ev_apply, dict) + obj.attn2_ev_apply = Atten2EquiVarApply.deserialize(attn2_ev_apply) + if update_g1_has_attn: + assert isinstance(loc_attn, dict) + obj.loc_attn = LocalAtten.deserialize(loc_attn) + if update_style == "res_residual": + obj.g1_residual = g1_residual + obj.g2_residual = g2_residual + obj.h2_residual = h2_residual + return obj diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 8d926034dd..c50fdb4cb3 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -321,7 +321,9 @@ def call( """ del mapping # nf x nloc x nnei x 4 - rr, ww = self.env_mat.call(coord_ext, atype_ext, nlist, self.davg, self.dstd) + rr, diff, ww = self.env_mat.call( + coord_ext, atype_ext, nlist, self.davg, self.dstd + ) nf, nloc, nnei, _ = rr.shape sec = np.append([0], np.cumsum(self.sel)) diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index 9c9b4e096e..6b50c3ba68 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -270,7 +270,7 @@ def call( """ del mapping # nf x nloc x nnei x 1 - rr, ww = self.env_mat.call( + rr, diff, ww = self.env_mat.call( coord_ext, atype_ext, nlist, self.davg, self.dstd, True ) nf, nloc, nnei, _ = rr.shape diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index 0c2ca43c40..94cf3a7c21 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -109,17 +109,19 @@ def call( ------- env_mat The environment matrix. shape: nf x nloc x nnei x (4 or 1) + diff + The relative coordinate of neighbors. shape: nf x nloc x nnei x 3 switch The value of switch function. 
shape: nf x nloc x nnei """ - em, sw = self._call(nlist, coord_ext, radial_only) + em, diff, sw = self._call(nlist, coord_ext, radial_only) nf, nloc, nnei = nlist.shape atype = atype_ext[:, :nloc] if davg is not None: em -= davg[atype] if dstd is not None: em /= dstd[atype] - return em, sw + return em, diff, sw def _call(self, nlist, coord_ext, radial_only): em, diff, ww = _make_env_mat( @@ -130,7 +132,7 @@ def _call(self, nlist, coord_ext, radial_only): radial_only=radial_only, protection=self.protection, ) - return em, ww + return em, diff, ww def serialize( self, diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 319f8a0dbd..ae557326ff 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -33,6 +33,25 @@ ) +class Identity(NativeOP): + def __init__(self): + super().__init__() + + def call(self, x: np.ndarray) -> np.ndarray: + """The Identity operation layer.""" + return x + + def serialize(self) -> dict: + return { + "@class": "Identity", + "@version": 1, + } + + @classmethod + def deserialize(cls, data: dict) -> "Identity": + return Identity() + + class NativeLayer(NativeOP): """Native representation of a layer. diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 54bcce6969..00233df29b 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -29,6 +29,9 @@ from deepmd.pt.utils.update_sel import ( UpdateSel, ) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) from deepmd.utils.path import ( DPPath, ) @@ -544,8 +547,8 @@ def serialize(self) -> dict: "embeddings": repinit.filter_layers.serialize(), "env_mat": DPEnvMat(repinit.rcut, repinit.rcut_smth).serialize(), "@variables": { - "davg": repinit["davg"].detach().cpu().numpy(), - "dstd": repinit["dstd"].detach().cpu().numpy(), + "davg": to_numpy_array(repinit["davg"]), + "dstd": to_numpy_array(repinit["dstd"]), }, } if repinit.tebd_input_mode in ["strip"]: @@ -557,8 +560,8 @@ def serialize(self) -> dict: "repformer_layers": [layer.serialize() for layer in repformers.layers], "env_mat": DPEnvMat(repformers.rcut, repformers.rcut_smth).serialize(), "@variables": { - "davg": repformers["davg"].detach().cpu().numpy(), - "dstd": repformers["dstd"].detach().cpu().numpy(), + "davg": to_numpy_array(repformers["davg"]), + "dstd": to_numpy_array(repformers["dstd"]), }, } data.update( @@ -633,9 +636,9 @@ def forward( Parameters ---------- - coord_ext + extended_coord The extended coordinates of atoms. shape: nf x (nallx3) - atype_ext + extended_atype The extended aotm types. shape: nf x nall nlist The neighbor list. 
shape: nf x nloc x nnei diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index c839179ce6..3f304cc7a0 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -178,7 +178,8 @@ def _cal_hg( g = _apply_nlist_mask(g, nlist_mask) if not smooth: # nf x nloc - invnnei = 1.0 / (epsilon + torch.sum(nlist_mask, dim=-1)) + # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy + invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g), dim=-1)) # nf x nloc x 1 x 1 invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) else: @@ -207,7 +208,7 @@ def _cal_grrg(hg: torch.Tensor, axis_neuron: int) -> torch.Tensor: grrg Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng) """ - # nf x nloc x 3 x ng2 + # nf x nloc x 3 x ng nf, nloc, _, ng = hg.shape # nf x nloc x 3 x axis hgm = torch.split(hg, axis_neuron, dim=-1)[0] @@ -897,7 +898,10 @@ def _update_g1_conv( if not self.smooth: # normalized by number of neighbors, not smooth # nf x nloc x 1 - invnnei = 1.0 / (self.epsilon + torch.sum(nlist_mask, dim=-1)).unsqueeze(-1) + # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy + invnnei = 1.0 / ( + self.epsilon + torch.sum(nlist_mask.type_as(gg1), dim=-1) + ).unsqueeze(-1) else: gg1 = _apply_switch(gg1, sw) invnnei = (1.0 / float(nnei)) * torch.ones( diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index da22356814..55f6100b64 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -210,7 +210,7 @@ def __init__( self.epsilon = 1e-4 self.old_impl = old_impl - self.g2_embd = MLPLayer(1, self.g2_dim) + self.g2_embd = MLPLayer(1, self.g2_dim, precision=precision) layers = [] for ii in range(nlayers): if self.old_impl: diff --git a/source/tests/common/dpmodel/case_single_frame_with_nlist.py b/source/tests/common/dpmodel/case_single_frame_with_nlist.py index 828e090cad..f674c7a3a9 100644 --- a/source/tests/common/dpmodel/case_single_frame_with_nlist.py +++ b/source/tests/common/dpmodel/case_single_frame_with_nlist.py @@ -41,9 +41,11 @@ def setUp(self): ).reshape([1, self.nall, 3]) self.coord = self.coord_ext[:, : self.nloc, :] self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall]) + self.mapping = np.array([0, 1, 2, 0], dtype=int).reshape([1, self.nall]) self.atype = self.atype_ext[:, : self.nloc] # sel = [5, 2] self.sel = [5, 2] + self.sel_mix = [7] self.nlist = np.array( [ [1, 3, -1, -1, -1, 2, -1], @@ -66,6 +68,9 @@ def setUp(self): self.atype_ext = np.concatenate( [self.atype_ext, self.atype_ext[:, self.perm]], axis=0 ) + self.mapping = np.concatenate( + [self.mapping, self.mapping[:, self.perm]], axis=0 + ) # permute the nlist nlist1 = self.nlist[:, self.perm[: self.nloc], :] mask = nlist1 == -1 diff --git a/source/tests/common/dpmodel/test_descriptor_dpa2.py b/source/tests/common/dpmodel/test_descriptor_dpa2.py new file mode 100644 index 0000000000..4df01c61ad --- /dev/null +++ b/source/tests/common/dpmodel/test_descriptor_dpa2.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np + +from deepmd.dpmodel.descriptor import ( + DescrptDPA2, +) + +from .case_single_frame_with_nlist import ( + TestCaseSingleFrameWithNlist, +) + + +class TestDescrptDPA2(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + 
TestCaseSingleFrameWithNlist.setUp(self) + + def test_self_consistency( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + davg_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd = 0.1 + np.abs(dstd) + dstd_2 = 0.1 + np.abs(dstd_2) + + em0 = DescrptDPA2( + ntypes=self.nt, + repinit_rcut=self.rcut, + repinit_rcut_smth=self.rcut_smth, + repinit_nsel=self.sel_mix, + repformer_rcut=self.rcut / 2, + repformer_rcut_smth=self.rcut_smth, + repformer_nsel=nnei // 2, + ) + + em0.repinit.mean = davg + em0.repinit.stddev = dstd + em0.repformers.mean = davg_2 + em0.repformers.stddev = dstd_2 + em1 = DescrptDPA2.deserialize(em0.serialize()) + mm0 = em0.call(self.coord_ext, self.atype_ext, self.nlist, self.mapping) + mm1 = em1.call(self.coord_ext, self.atype_ext, self.nlist, self.mapping) + for ii in [0, 1, 4]: + np.testing.assert_allclose(mm0[ii], mm1[ii]) diff --git a/source/tests/common/dpmodel/test_env_mat.py b/source/tests/common/dpmodel/test_env_mat.py index 7e1ce7cddd..bfeb0d69d9 100644 --- a/source/tests/common/dpmodel/test_env_mat.py +++ b/source/tests/common/dpmodel/test_env_mat.py @@ -26,7 +26,12 @@ def test_self_consistency( dstd = 0.1 + np.abs(dstd) em0 = EnvMat(self.rcut, self.rcut_smth) em1 = EnvMat.deserialize(em0.serialize()) - mm0, ww0 = em0.call(self.coord_ext, self.atype_ext, self.nlist, davg, dstd) - mm1, ww1 = em1.call(self.coord_ext, self.atype_ext, self.nlist, davg, dstd) + mm0, diff0, ww0 = em0.call( + self.coord_ext, self.atype_ext, self.nlist, davg, dstd + ) + mm1, diff1, ww1 = em1.call( + self.coord_ext, self.atype_ext, self.nlist, davg, dstd + ) np.testing.assert_allclose(mm0, mm1) + np.testing.assert_allclose(diff0, diff1) np.testing.assert_allclose(ww0, ww1) diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index 9c8c1cea7f..13ceef84ab 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -74,7 +74,7 @@ def eval_dp_descriptor( dp_obj.get_sel(), distinguish_types=(not mixed_types), ) - return dp_obj(ext_coords, ext_atype, nlist=nlist) + return dp_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) def eval_pt_descriptor( self, pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False @@ -95,5 +95,5 @@ def eval_pt_descriptor( ) return [ x.detach().cpu().numpy() if torch.is_tensor(x) else x - for x in pt_obj(ext_coords, ext_atype, nlist=nlist) + for x in pt_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) ] diff --git a/source/tests/consistent/descriptor/test_dpa2.py b/source/tests/consistent/descriptor/test_dpa2.py new file mode 100644 index 0000000000..1313a6d727 --- /dev/null +++ b/source/tests/consistent/descriptor/test_dpa2.py @@ -0,0 +1,380 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, + Tuple, +) + +import numpy as np +from dargs import ( + Argument, +) + +from deepmd.dpmodel.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2DP +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_PT, + CommonTest, + parameterized, +) +from .common import ( + DescriptorTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2PT +else: + DescrptDPA2PT = None + +# not implemented +DescrptDPA2TF = None + +from deepmd.utils.argcheck import ( + 
descrpt_se_atten_args, +) + + +@parameterized( + ("concat", "strip"), # repinit_tebd_input_mode + (True,), # repinit_set_davg_zero + (True, False), # repformer_direct_dist + (True,), # repformer_update_g1_has_conv + (True,), # repformer_update_g1_has_drrd + (True,), # repformer_update_g1_has_grrg + (True,), # repformer_update_g1_has_attn + (True,), # repformer_update_g2_has_g1g1 + (True,), # repformer_update_g2_has_attn + (False,), # repformer_update_h2 + (True, False), # repformer_attn2_has_gate + ("res_avg", "res_residual"), # repformer_update_style + ("norm", "const"), # repformer_update_residual_init + (True,), # repformer_set_davg_zero + (True, False), # smooth + ([], [[0, 1]]), # exclude_types + ("float64",), # precision + (True,), # trainable_lns + (1e-5,), # ln_eps + (False,), # type_one_side + (True, False), # add_tebd_to_repinit_out +) +class TestDPA2(CommonTest, DescriptorTest, unittest.TestCase): + @property + def data(self) -> dict: + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + smooth, + exclude_types, + precision, + trainable_ln, + ln_eps, + type_one_side, + add_tebd_to_repinit_out, + ) = self.param + return { + "ntypes": self.ntypes, + "repinit_rcut": 6.00, + "repinit_rcut_smth": 5.80, + "repinit_nsel": 10, + "repformer_rcut": 4.00, + "repformer_rcut_smth": 3.50, + "repformer_nsel": 8, + # kwargs for repinit + "repinit_neuron": [6, 12, 24], + "repinit_axis_neuron": 3, + "repinit_tebd_dim": 4, + "repinit_tebd_input_mode": repinit_tebd_input_mode, + "repinit_set_davg_zero": repinit_set_davg_zero, + "repinit_activation_function": "tanh", + # kwargs for repformer + "repformer_nlayers": 3, + "repformer_g1_dim": 20, + "repformer_g2_dim": 10, + "repformer_axis_neuron": 3, + "repformer_direct_dist": repformer_direct_dist, + "repformer_update_g1_has_conv": repformer_update_g1_has_conv, + "repformer_update_g1_has_drrd": repformer_update_g1_has_drrd, + "repformer_update_g1_has_grrg": repformer_update_g1_has_grrg, + "repformer_update_g1_has_attn": repformer_update_g1_has_attn, + "repformer_update_g2_has_g1g1": repformer_update_g2_has_g1g1, + "repformer_update_g2_has_attn": repformer_update_g2_has_attn, + "repformer_update_h2": repformer_update_h2, + "repformer_attn1_hidden": 12, + "repformer_attn1_nhead": 2, + "repformer_attn2_hidden": 10, + "repformer_attn2_nhead": 2, + "repformer_attn2_has_gate": repformer_attn2_has_gate, + "repformer_activation_function": "tanh", + "repformer_update_style": repformer_update_style, + "repformer_update_residual": 0.001, + "repformer_update_residual_init": repformer_update_residual_init, + "repformer_set_davg_zero": True, + # kwargs for descriptor + "concat_output_tebd": True, + "precision": precision, + "smooth": smooth, + "exclude_types": exclude_types, + "env_protection": 0.0, + "trainable": True, + "trainable_ln": trainable_ln, + "ln_eps": ln_eps, + "type_one_side": type_one_side, + "add_tebd_to_repinit_out": add_tebd_to_repinit_out, + } + + @property + def skip_pt(self) -> bool: + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + 
repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + smooth, + exclude_types, + precision, + trainable_ln, + ln_eps, + type_one_side, + add_tebd_to_repinit_out, + ) = self.param + return CommonTest.skip_pt + + @property + def skip_dp(self) -> bool: + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + smooth, + exclude_types, + precision, + trainable_ln, + ln_eps, + type_one_side, + add_tebd_to_repinit_out, + ) = self.param + return CommonTest.skip_pt + + @property + def skip_tf(self) -> bool: + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + smooth, + exclude_types, + precision, + trainable_ln, + ln_eps, + type_one_side, + add_tebd_to_repinit_out, + ) = self.param + return True + + tf_class = DescrptDPA2TF + dp_class = DescrptDPA2DP + pt_class = DescrptDPA2PT + args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False)) + + def setUp(self): + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + smooth, + exclude_types, + precision, + trainable_ln, + ln_eps, + type_one_side, + add_tebd_to_repinit_out, + ) = self.param + + def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]: + return self.build_tf_descriptor( + obj, + self.natoms, + self.coords, + self.atype, + self.box, + suffix, + ) + + def eval_dp(self, dp_obj: Any) -> Any: + return self.eval_dp_descriptor( + dp_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return self.eval_pt_descriptor( + pt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]: + return (ret[0],) + + @property + def rtol(self) -> float: + """Relative tolerance for comparing the return value.""" + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + 
repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + smooth, + exclude_types, + precision, + trainable_ln, + ln_eps, + type_one_side, + add_tebd_to_repinit_out, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") + + @property + def atol(self) -> float: + """Absolute tolerance for comparing the return value.""" + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + smooth, + exclude_types, + precision, + trainable_ln, + ln_eps, + type_one_side, + add_tebd_to_repinit_out, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") diff --git a/source/tests/pt/model/test_dpa2.py b/source/tests/pt/model/test_dpa2.py index a2ae57e549..30907fdeaa 100644 --- a/source/tests/pt/model/test_dpa2.py +++ b/source/tests/pt/model/test_dpa2.py @@ -40,6 +40,7 @@ def test_consistency( davg_2 = rng.normal(size=(self.nt, nnei // 2, 4)) dstd_2 = rng.normal(size=(self.nt, nnei // 2, 4)) dstd = 0.1 + np.abs(dstd) + dstd_2 = 0.1 + np.abs(dstd_2) for ( riti, diff --git a/source/tests/pt/model/test_env_mat.py b/source/tests/pt/model/test_env_mat.py index cc7b426585..84099cddaf 100644 --- a/source/tests/pt/model/test_env_mat.py +++ b/source/tests/pt/model/test_env_mat.py @@ -161,8 +161,10 @@ def test_consistency( dstd = rng.normal(size=(self.nt, nnei, 4)) dstd = 0.1 + np.abs(dstd) em0 = EnvMat(self.rcut, self.rcut_smth) - mm0, ww0 = em0.call(self.coord_ext, self.atype_ext, self.nlist, davg, dstd) - mm1, _, ww1 = prod_env_mat( + mm0, diff0, ww0 = em0.call( + self.coord_ext, self.atype_ext, self.nlist, davg, dstd + ) + mm1, diff1, ww1 = prod_env_mat( torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), torch.tensor(self.nlist, dtype=int, device=env.DEVICE), torch.tensor(self.atype_ext[:, :nloc], dtype=int, device=env.DEVICE), @@ -172,5 +174,6 @@ def test_consistency( self.rcut_smth, ) np.testing.assert_allclose(mm0, mm1.detach().cpu().numpy()) + np.testing.assert_allclose(diff0, diff1.detach().cpu().numpy()) np.testing.assert_allclose(ww0, ww1.detach().cpu().numpy()) np.testing.assert_allclose(mm0[0][self.perm[: self.nloc]], mm0[1]) From b19a0e1322422d7a88e7800f98f491fcd42dc587 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 May 2024 17:36:57 +0000 Subject: [PATCH 17/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/model/descriptor/repformers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 8c4390ae2d..77f95f4b91 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ 
b/deepmd/pt/model/descriptor/repformers.py @@ -44,7 +44,6 @@ ) from .repformer_layer_old_impl import RepformerLayer as RepformerLayerOld - if not hasattr(torch.ops.deepmd, "border_op"): def border_op( From c9527988bad3119359e349fdd927aabe03e55ee9 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 8 May 2024 01:42:31 +0800 Subject: [PATCH 18/37] update argcheck --- deepmd/utils/argcheck.py | 1 + examples/water/dpa2/input_torch.json | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index b7fea1f39b..a5b687782c 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -753,6 +753,7 @@ def descrpt_dpa2_args(): int, optional=True, default=8, + alias=["tebd_dim"], doc=doc_repinit_tebd_dim, ), Argument( diff --git a/examples/water/dpa2/input_torch.json b/examples/water/dpa2/input_torch.json index e94086b898..7096e1d40b 100644 --- a/examples/water/dpa2/input_torch.json +++ b/examples/water/dpa2/input_torch.json @@ -7,7 +7,7 @@ ], "descriptor": { "type": "dpa2", - "tebd_dim": 8, + "repinit_tebd_dim": 8, "repinit_rcut": 9.0, "repinit_rcut_smth": 8.0, "repinit_nsel": 120, @@ -20,7 +20,7 @@ 100 ], "repinit_axis_neuron": 12, - "repinit_activation": "tanh", + "repinit_activation_function": "tanh", "repformer_nlayers": 12, "repformer_g1_dim": 128, "repformer_g2_dim": 32, @@ -28,7 +28,7 @@ "repformer_attn2_nhead": 4, "repformer_attn1_hidden": 128, "repformer_attn1_nhead": 4, - "repformer_axis_dim": 4, + "repformer_axis_neuron": 4, "repformer_update_h2": false, "repformer_update_g1_has_conv": true, "repformer_update_g1_has_grrg": true, @@ -37,7 +37,7 @@ "repformer_update_g2_has_g1g1": true, "repformer_update_g2_has_attn": true, "repformer_attn2_has_gate": true, - "repformer_add_type_ebd_to_seq": false + "add_tebd_to_repinit_out": false }, "fitting_net": { "neuron": [ From bd1d5d92a32b5b9abd0ef75594767e52614a5561 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 8 May 2024 10:19:25 +0800 Subject: [PATCH 19/37] Fix uts --- deepmd/pt/model/descriptor/repformer_layer_old_impl.py | 6 ++++-- source/tests/pt/model/test_permutation.py | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/model/descriptor/repformer_layer_old_impl.py b/deepmd/pt/model/descriptor/repformer_layer_old_impl.py index ab39fbb830..af9c2e0981 100644 --- a/deepmd/pt/model/descriptor/repformer_layer_old_impl.py +++ b/deepmd/pt/model/descriptor/repformer_layer_old_impl.py @@ -436,7 +436,9 @@ def _update_g1_conv( if not self.smooth: # normalized by number of neighbors, not smooth # nb x nloc x 1 - invnnei = 1.0 / (self.epsilon + torch.sum(nlist_mask, dim=-1)).unsqueeze(-1) + invnnei = 1.0 / ( + self.epsilon + torch.sum(nlist_mask.type_as(gg1), dim=-1) + ).unsqueeze(-1) else: gg1 = _apply_switch(gg1, sw) invnnei = (1.0 / float(nnei)) * torch.ones( @@ -462,7 +464,7 @@ def _cal_h2g2( g2 = _apply_nlist_mask(g2, nlist_mask) if not self.smooth: # nb x nloc - invnnei = 1.0 / (self.epsilon + torch.sum(nlist_mask, dim=-1)) + invnnei = 1.0 / (self.epsilon + torch.sum(nlist_mask.type_as(g2), dim=-1)) # nb x nloc x 1 x 1 invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) else: diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py index b4cd133200..c9977e3662 100644 --- a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -122,7 +122,7 @@ "repformer_nsel": 20, "repinit_neuron": [2, 
4, 8], "repinit_axis_neuron": 4, - "repinit_activation": "tanh", + "repinit_activation_function": "tanh", "repformer_nlayers": 12, "repformer_g1_dim": 8, "repformer_g2_dim": 5, @@ -130,7 +130,7 @@ "repformer_attn2_nhead": 1, "repformer_attn1_hidden": 5, "repformer_attn1_nhead": 1, - "repformer_axis_dim": 4, + "repformer_axis_neuron": 4, "repformer_update_h2": False, "repformer_update_g1_has_conv": True, "repformer_update_g1_has_grrg": True, @@ -139,7 +139,7 @@ "repformer_update_g2_has_g1g1": True, "repformer_update_g2_has_attn": True, "repformer_attn2_has_gate": True, - "repformer_add_type_ebd_to_seq": False, + "add_tebd_to_repinit_out": False, }, "fitting_net": { "neuron": [24, 24], From 0ebadc2f9771775fd7902506f72dad7122db203e Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 8 May 2024 10:31:06 +0800 Subject: [PATCH 20/37] Update test_permutation.py --- source/tests/pt/model/test_permutation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py index c9977e3662..f876b49b7f 100644 --- a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -207,7 +207,7 @@ "repformer_nsel": 10, "repinit_neuron": [2, 4, 8], "repinit_axis_neuron": 4, - "repinit_activation": "tanh", + "repinit_activation_function": "tanh", "repformer_nlayers": 12, "repformer_g1_dim": 8, "repformer_g2_dim": 5, @@ -215,7 +215,7 @@ "repformer_attn2_nhead": 1, "repformer_attn1_hidden": 5, "repformer_attn1_nhead": 1, - "repformer_axis_dim": 4, + "repformer_axis_neuron": 4, "repformer_update_h2": False, "repformer_update_g1_has_conv": True, "repformer_update_g1_has_grrg": True, @@ -224,7 +224,7 @@ "repformer_update_g2_has_g1g1": True, "repformer_update_g2_has_attn": True, "repformer_attn2_has_gate": True, - "repformer_add_type_ebd_to_seq": False, + "add_tebd_to_repinit_out": False, }, ], }, From fe6ed6e1c79f45ddb874149d4a2b3e375c785a60 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 8 May 2024 11:23:07 +0800 Subject: [PATCH 21/37] fix uts --- source/tests/pt/model/models/dpa2.json | 6 +++--- source/tests/pt/model/models/dpa2.pth | Bin 158910 -> 179645 bytes source/tests/pt/model/test_descriptor_dpa2.py | 13 +++++-------- source/tests/pt/model/test_unused_params.py | 2 -- 4 files changed, 8 insertions(+), 13 deletions(-) diff --git a/source/tests/pt/model/models/dpa2.json b/source/tests/pt/model/models/dpa2.json index 8b9c735851..966e298dfa 100644 --- a/source/tests/pt/model/models/dpa2.json +++ b/source/tests/pt/model/models/dpa2.json @@ -17,7 +17,7 @@ 8 ], "repinit_axis_neuron": 4, - "repinit_activation": "tanh", + "repinit_activation_function": "tanh", "repformer_nlayers": 12, "repformer_g1_dim": 8, "repformer_g2_dim": 5, @@ -25,7 +25,7 @@ "repformer_attn2_nhead": 1, "repformer_attn1_hidden": 5, "repformer_attn1_nhead": 1, - "repformer_axis_dim": 4, + "repformer_axis_neuron": 4, "repformer_update_h2": false, "repformer_update_g1_has_conv": true, "repformer_update_g1_has_grrg": true, @@ -34,7 +34,7 @@ "repformer_update_g2_has_g1g1": true, "repformer_update_g2_has_attn": true, "repformer_attn2_has_gate": true, - "repformer_add_type_ebd_to_seq": false + "add_tebd_to_repinit_out": false }, "fitting_net": { "neuron": [ diff --git a/source/tests/pt/model/models/dpa2.pth b/source/tests/pt/model/models/dpa2.pth index 26d6155272deec08544d2c14c294e94b8652824b..d31f69742adc2eeb854d7908c6c2391d6f256207 100644 GIT 
binary patch delta 41780 / delta 24724 for source/tests/pt/model/models/dpa2.pth (Bin 158910 -> 179645 bytes); base85-encoded binary payload trimmed
z3CYBZiY%V@38}zKi!3DTfYa0KdkGPnQreU0O?bUgw7vN$zjpJ}yIYqa-f9&6IQ%B) zx*U{|9?d|?&q{5k5z-y6II@m)f>33-hlg;yF`Z8}HjsAOjhnM_EljI1CeGev6i zJRt{V#BmUd)YMgKGwvYTl&49^9fWl6AtApf6GpzE>>?;f(Ipy2nER;Gm15z?A1Ay56D_F$fboTfJA10-bdv&54x zAytI5dL-ltLYj&tO_oJRRN Date: Wed, 8 May 2024 15:40:49 +0800 Subject: [PATCH 22/37] Update test_dpa2.py --- source/tests/pt/model/test_dpa2.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/source/tests/pt/model/test_dpa2.py b/source/tests/pt/model/test_dpa2.py index 30907fdeaa..09331ed852 100644 --- a/source/tests/pt/model/test_dpa2.py +++ b/source/tests/pt/model/test_dpa2.py @@ -5,7 +5,7 @@ import numpy as np import torch -# from deepmd.dpmodel.descriptor.dpa2 import DescrptDPA2 as DPDescrptDPA2 +from deepmd.dpmodel.descriptor.dpa2 import DescrptDPA2 as DPDescrptDPA2 from deepmd.pt.model.descriptor.dpa2 import ( DescrptDPA2, ) @@ -146,19 +146,16 @@ def test_consistency( atol=atol, ) # dp impl - # dd2 = DPDescrptDPA1.deserialize(dd0.serialize()) - # rd2, _, _, _, _ = dd2.call( - # self.coord_ext, - # self.atype_ext, - # self.nlist, - # ) - # np.testing.assert_allclose( - # rd0.detach().cpu().numpy(), - # rd2, - # rtol=rtol, - # atol=atol, - # err_msg=err_msg, - # ) + dd2 = DPDescrptDPA2.deserialize(dd0.serialize()) + rd2, _, _, _, _ = dd2.call( + self.coord_ext, self.atype_ext, self.nlist, self.mapping + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + ) # old impl if prec == "float64" and rus == "res_avg": dd3 = DescrptDPA2( From e1270bdd03b37d2b32d084fdc1174e813a6a804c Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 9 May 2024 14:23:41 +0800 Subject: [PATCH 23/37] Update argcheck.py --- deepmd/utils/argcheck.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index a5b687782c..cc18d8c6d1 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2452,25 +2452,6 @@ def gen_args(**kwargs) -> List[Argument]: ] -def backend_compat(data): - data = data.copy() - # stripped_type_embedding in old DescrptSeAtten - if ( - "descriptor" in data["model"] - and data["model"]["descriptor"].get("type", "se_e2_a") == "se_atten" - and data["model"]["descriptor"].pop("stripped_type_embedding", False) - ): - if "tebd_input_mode" not in data["model"]["descriptor"]: - data["model"]["descriptor"]["tebd_input_mode"] = "strip" - elif data["model"]["descriptor"]["tebd_input_mode"] != "strip": - raise ValueError( - "Conflict detected: 'stripped_type_embedding' is set to True, but 'tebd_input_mode' is not 'strip'. Please ensure 'tebd_input_mode' is set to 'strip' when 'stripped_type_embedding' is True." 
- ) - else: - pass - return data - - def normalize_multi_task(data): # single-task or multi-task mode if data["model"].get("type", "standard") not in ("standard", "multi"): @@ -2667,7 +2648,6 @@ def normalize_fitting_weight(fitting_keys, data_keys, fitting_weight=None): def normalize(data): data = normalize_multi_task(data) - data = backend_compat(data) base = Argument("base", dict, gen_args()) data = base.normalize_value(data, trim_pattern="_*") From ceaaa07e27ce7946bfcf71cfca70dac2cf34b694 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 9 May 2024 14:26:52 +0800 Subject: [PATCH 24/37] Update se_atten.py --- deepmd/pt/model/descriptor/se_atten.py | 160 +------------------------ 1 file changed, 1 insertion(+), 159 deletions(-) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 6fedb60d38..465bbb8308 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -808,164 +808,6 @@ def deserialize(cls, data: dict) -> "NeighborGatedAttentionLayer": return obj -# class GatedAttentionLayer(nn.Module): -# def __init__( -# self, -# nnei: int, -# embed_dim: int, -# hidden_dim: int, -# dotr: bool = False, -# do_mask: bool = False, -# scaling_factor: float = 1.0, -# normalize: bool = True, -# temperature: Optional[float] = None, -# bias: bool = True, -# smooth: bool = True, -# precision: str = DEFAULT_PRECISION, -# ): -# """Construct a neighbor-wise attention net.""" -# super().__init__() -# self.nnei = nnei -# self.embed_dim = embed_dim -# self.hidden_dim = hidden_dim -# self.dotr = dotr -# self.do_mask = do_mask -# self.bias = bias -# self.smooth = smooth -# self.scaling_factor = scaling_factor -# self.temperature = temperature -# self.precision = precision -# if temperature is None: -# self.scaling = (self.hidden_dim * scaling_factor) ** -0.5 -# else: -# self.scaling = temperature -# self.normalize = normalize -# self.in_proj = MLPLayer( -# embed_dim, -# hidden_dim * 3, -# bias=bias, -# use_timestep=False, -# bavg=0.0, -# stddev=1.0, -# precision=precision, -# ) -# self.out_proj = MLPLayer( -# hidden_dim, -# embed_dim, -# bias=bias, -# use_timestep=False, -# bavg=0.0, -# stddev=1.0, -# precision=precision, -# ) -# -# def forward( -# self, -# query, -# nei_mask, -# input_r: Optional[torch.Tensor] = None, -# sw: Optional[torch.Tensor] = None, -# attnw_shift: float = 20.0, -# ): -# """Compute the gated self-attention. -# -# Parameters -# ---------- -# query -# inputs with shape: (nf x nloc) x nnei x embed_dim. -# nei_mask -# neighbor mask, with paddings being 0. shape: (nf x nloc) x nnei. -# input_r -# normalized radial. shape: (nf x nloc) x nnei x 3. -# sw -# The smooth switch function. shape: (nf x nloc) x nnei -# attnw_shift : float -# The attention weight shift to preserve smoothness when doing padding before softmax. 
-# """ -# q, k, v = self.in_proj(query).chunk(3, dim=-1) -# # [nframes * nloc, nnei, hidden_dim] -# q = q.view(-1, self.nnei, self.hidden_dim) -# k = k.view(-1, self.nnei, self.hidden_dim) -# v = v.view(-1, self.nnei, self.hidden_dim) -# if self.normalize: -# q = torch_func.normalize(q, dim=-1) -# k = torch_func.normalize(k, dim=-1) -# v = torch_func.normalize(v, dim=-1) -# q = q * self.scaling -# k = k.transpose(1, 2) -# # [nframes * nloc, nnei, nnei] -# attn_weights = torch.bmm(q, k) -# # [nframes * nloc, nnei] -# nei_mask = nei_mask.view(-1, self.nnei) -# if self.smooth: -# # [nframes * nloc, nnei] -# assert sw is not None -# sw = sw.view([-1, self.nnei]) -# attn_weights = (attn_weights + attnw_shift) * sw[:, :, None] * sw[ -# :, None, : -# ] - attnw_shift -# else: -# attn_weights = attn_weights.masked_fill( -# ~nei_mask.unsqueeze(1), float("-inf") -# ) -# attn_weights = torch_func.softmax(attn_weights, dim=-1) -# attn_weights = attn_weights.masked_fill(~nei_mask.unsqueeze(-1), 0.0) -# if self.smooth: -# assert sw is not None -# attn_weights = attn_weights * sw[:, :, None] * sw[:, None, :] -# if self.dotr: -# assert input_r is not None, "input_r must be provided when dotr is True!" -# angular_weight = torch.bmm(input_r, input_r.transpose(1, 2)) -# attn_weights = attn_weights * angular_weight -# o = torch.bmm(attn_weights, v) -# output = self.out_proj(o) -# return output -# -# def serialize(self) -> dict: -# """Serialize the networks to a dict. -# -# Returns -# ------- -# dict -# The serialized networks. -# """ -# # network_type_map_inv = {v: k for k, v in self.NETWORK_TYPE_MAP.items()} -# # network_type_name = network_type_map_inv[self.network_type] -# return { -# "nnei": self.nnei, -# "embed_dim": self.embed_dim, -# "hidden_dim": self.hidden_dim, -# "dotr": self.dotr, -# "do_mask": self.do_mask, -# "scaling_factor": self.scaling_factor, -# "normalize": self.normalize, -# "temperature": self.temperature, -# "bias": self.bias, -# "smooth": self.smooth, -# "precision": self.precision, -# "in_proj": self.in_proj.serialize(), -# "out_proj": self.out_proj.serialize(), -# } -# -# @classmethod -# def deserialize(cls, data: dict) -> "GatedAttentionLayer": -# """Deserialize the networks from a dict. -# -# Parameters -# ---------- -# data : dict -# The dict to deserialize from. 
-# """ -# data = data.copy() -# in_proj = data.pop("in_proj") -# out_proj = data.pop("out_proj") -# obj = cls(**data) -# obj.in_proj = MLPLayer.deserialize(in_proj) -# obj.out_proj = MLPLayer.deserialize(out_proj) -# return obj -# - - class GatedAttentionLayer(nn.Module): def __init__( self, @@ -980,7 +822,7 @@ def __init__( temperature: Optional[float] = None, bias: bool = True, smooth: bool = True, - precision: str = "float", + precision: str = DEFAULT_PRECISION, ): """Construct a multi-head neighbor-wise attention net.""" super().__init__() From d2bcdbff5e01d8191748bd67f9f7a89b33446e22 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 9 May 2024 14:30:02 +0800 Subject: [PATCH 25/37] Fix typo --- deepmd/dpmodel/descriptor/descriptor.py | 2 +- deepmd/dpmodel/descriptor/dpa1.py | 2 +- deepmd/dpmodel/descriptor/repformers.py | 2 +- deepmd/pt/model/descriptor/descriptor.py | 2 +- deepmd/pt/model/descriptor/se_atten.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/deepmd/dpmodel/descriptor/descriptor.py b/deepmd/dpmodel/descriptor/descriptor.py index f11f406399..444df1abf8 100644 --- a/deepmd/dpmodel/descriptor/descriptor.py +++ b/deepmd/dpmodel/descriptor/descriptor.py @@ -71,7 +71,7 @@ def get_dim_out(self) -> int: @abstractmethod def get_dim_in(self) -> int: - """Returns the output dimension.""" + """Returns the input dimension.""" pass @abstractmethod diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 9d1dee2481..e5e79a8984 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -648,7 +648,7 @@ def get_ntypes(self) -> int: return self.ntypes def get_dim_in(self) -> int: - """Returns the output dimension.""" + """Returns the input dimension.""" return self.dim_in def get_dim_out(self) -> int: diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index bd2f97a29a..ca3651a979 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -268,7 +268,7 @@ def get_ntypes(self) -> int: return self.ntypes def get_dim_in(self) -> int: - """Returns the output dimension.""" + """Returns the input dimension.""" return self.dim_in def get_dim_out(self) -> int: diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 5aae848aa4..d586fc988f 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -80,7 +80,7 @@ def get_dim_out(self) -> int: @abstractmethod def get_dim_in(self) -> int: - """Returns the output dimension.""" + """Returns the input dimension.""" pass @abstractmethod diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 465bbb8308..958f3b4963 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -299,7 +299,7 @@ def get_ntypes(self) -> int: return self.ntypes def get_dim_in(self) -> int: - """Returns the output dimension.""" + """Returns the input dimension.""" return self.dim_in def get_dim_out(self) -> int: From 385e1f7a998c9883a784ad4180b5ce0335c4be05 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 9 May 2024 15:39:32 +0800 Subject: [PATCH 26/37] revert 'nf' to 'nb' --- deepmd/pt/model/descriptor/dpa2.py | 2 +- deepmd/pt/model/descriptor/repformer_layer.py | 144 +++++++++--------- deepmd/pt/model/descriptor/repformers.py | 12 +- 3 files changed, 79 
insertions(+), 79 deletions(-) diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index f23456f8e9..2e9e228bc9 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -255,7 +255,7 @@ def __init__( Returns ------- descriptor: torch.Tensor - the descriptor of shape nf x nloc x g1_dim. + the descriptor of shape nb x nloc x g1_dim. invariant single-atom representation. g2: torch.Tensor invariant pair-atom representation. diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index 3f304cc7a0..8af81520dd 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -78,26 +78,26 @@ def _make_nei_g1( Parameters ---------- g1_ext - Extended atomic invariant rep, with shape nf x nall x ng1. + Extended atomic invariant rep, with shape nb x nall x ng1. nlist - Neighbor list, with shape nf x nloc x nnei. + Neighbor list, with shape nb x nloc x nnei. Returns ------- gg1: torch.Tensor - Neighbor-wise atomic invariant rep, with shape nf x nloc x nnei x ng1. + Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1. """ - # nlist: nf x nloc x nnei - nf, nloc, nnei = nlist.shape - # g1_ext: nf x nall x ng1 + # nlist: nb x nloc x nnei + nb, nloc, nnei = nlist.shape + # g1_ext: nb x nall x ng1 ng1 = g1_ext.shape[-1] - # index: nf x (nloc x nnei) x ng1 - index = nlist.reshape(nf, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1) - # gg1 : nf x (nloc x nnei) x ng1 + # index: nb x (nloc x nnei) x ng1 + index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1) + # gg1 : nb x (nloc x nnei) x ng1 gg1 = torch.gather(g1_ext, dim=1, index=index) - # gg1 : nf x nloc x nnei x ng1 - gg1 = gg1.view(nf, nloc, nnei, ng1) + # gg1 : nb x nloc x nnei x ng1 + gg1 = gg1.view(nb, nloc, nnei, ng1) return gg1 @@ -291,34 +291,34 @@ def __init__( def forward( self, - g2: torch.Tensor, # nf x nloc x nnei x ng2 - h2: torch.Tensor, # nf x nloc x nnei x 3 - nlist_mask: torch.Tensor, # nf x nloc x nnei - sw: torch.Tensor, # nf x nloc x nnei + g2: torch.Tensor, # nb x nloc x nnei x ng2 + h2: torch.Tensor, # nb x nloc x nnei x 3 + nlist_mask: torch.Tensor, # nb x nloc x nnei + sw: torch.Tensor, # nb x nloc x nnei ) -> torch.Tensor: ( - nf, + nb, nloc, nnei, _, ) = g2.shape nd, nh = self.hidden_dim, self.head_num - # nf x nloc x nnei x nd x (nh x 2) - g2qk = self.mapqk(g2).view(nf, nloc, nnei, nd, nh * 2) - # nf x nloc x (nh x 2) x nnei x nd + # nb x nloc x nnei x nd x (nh x 2) + g2qk = self.mapqk(g2).view(nb, nloc, nnei, nd, nh * 2) + # nb x nloc x (nh x 2) x nnei x nd g2qk = torch.permute(g2qk, (0, 1, 4, 2, 3)) - # nf x nloc x nh x nnei x nd + # nb x nloc x nh x nnei x nd g2q, g2k = torch.split(g2qk, nh, dim=2) # g2q = torch.nn.functional.normalize(g2q, dim=-1) # g2k = torch.nn.functional.normalize(g2k, dim=-1) - # nf x nloc x nh x nnei x nnei + # nb x nloc x nh x nnei x nnei attnw = torch.matmul(g2q, torch.transpose(g2k, -1, -2)) / nd**0.5 if self.has_gate: gate = torch.matmul(h2, torch.transpose(h2, -1, -2)).unsqueeze(-3) attnw = attnw * gate - # mask the attenmap, nf x nloc x 1 x 1 x nnei + # mask the attenmap, nb x nloc x 1 x 1 x nnei attnw_mask = ~nlist_mask.unsqueeze(2).unsqueeze(2) - # mask the attenmap, nf x nloc x 1 x nnei x 1 + # mask the attenmap, nb x nloc x 1 x nnei x 1 attnw_mask_c = ~nlist_mask.unsqueeze(2).unsqueeze(-1) if self.smooth: attnw = (attnw + self.attnw_shift) * sw[:, :, None, :, None] * sw[ @@ -334,19 +334,19 @@ 
def forward( attnw_mask, 0.0, ) - # nf x nloc x nh x nnei x nnei + # nb x nloc x nh x nnei x nnei attnw = attnw.masked_fill( attnw_mask_c, 0.0, ) if self.smooth: attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :] - # nf x nloc x nnei x nnei + # nb x nloc x nnei x nnei h2h2t = torch.matmul(h2, torch.transpose(h2, -1, -2)) / 3.0**0.5 - # nf x nloc x nh x nnei x nnei + # nb x nloc x nh x nnei x nnei ret = attnw * h2h2t[:, :, None, :, :] # ret = torch.softmax(g2qk, dim=-1) - # nf x nloc x nnei x nnei x nh + # nb x nloc x nnei x nnei x nh ret = torch.permute(ret, (0, 1, 3, 4, 2)) return ret @@ -561,32 +561,32 @@ def __init__( def forward( self, - g1: torch.Tensor, # nf x nloc x ng1 - gg1: torch.Tensor, # nf x nloc x nnei x ng1 - nlist_mask: torch.Tensor, # nf x nloc x nnei - sw: torch.Tensor, # nf x nloc x nnei + g1: torch.Tensor, # nb x nloc x ng1 + gg1: torch.Tensor, # nb x nloc x nnei x ng1 + nlist_mask: torch.Tensor, # nb x nloc x nnei + sw: torch.Tensor, # nb x nloc x nnei ) -> torch.Tensor: - nf, nloc, nnei = nlist_mask.shape + nb, nloc, nnei = nlist_mask.shape ni, nd, nh = self.input_dim, self.hidden_dim, self.head_num assert ni == g1.shape[-1] assert ni == gg1.shape[-1] - # nf x nloc x nd x nh - g1q = self.mapq(g1).view(nf, nloc, nd, nh) - # nf x nloc x nh x nd + # nb x nloc x nd x nh + g1q = self.mapq(g1).view(nb, nloc, nd, nh) + # nb x nloc x nh x nd g1q = torch.permute(g1q, (0, 1, 3, 2)) - # nf x nloc x nnei x (nd+ni) x nh - gg1kv = self.mapkv(gg1).view(nf, nloc, nnei, nd + ni, nh) + # nb x nloc x nnei x (nd+ni) x nh + gg1kv = self.mapkv(gg1).view(nb, nloc, nnei, nd + ni, nh) gg1kv = torch.permute(gg1kv, (0, 1, 4, 2, 3)) - # nf x nloc x nh x nnei x nd, nf x nloc x nh x nnei x ng1 + # nb x nloc x nh x nnei x nd, nb x nloc x nh x nnei x ng1 gg1k, gg1v = torch.split(gg1kv, [nd, ni], dim=-1) - # nf x nloc x nh x 1 x nnei + # nb x nloc x nh x 1 x nnei attnw = torch.matmul(g1q.unsqueeze(-2), torch.transpose(gg1k, -1, -2)) / nd**0.5 - # nf x nloc x nh x nnei + # nb x nloc x nh x nnei attnw = attnw.squeeze(-2) - # mask the attenmap, nf x nloc x 1 x nnei + # mask the attenmap, nb x nloc x 1 x nnei attnw_mask = ~nlist_mask.unsqueeze(-2) - # nf x nloc x nh x nnei + # nb x nloc x nh x nnei if self.smooth: attnw = (attnw + self.attnw_shift) * sw.unsqueeze(-2) - self.attnw_shift else: @@ -602,11 +602,11 @@ def forward( if self.smooth: attnw = attnw * sw.unsqueeze(-2) - # nf x nloc x nh x ng1 + # nb x nloc x nh x ng1 ret = ( - torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nf, nloc, nh * ni) + torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni) ) - # nf x nloc x ng1 + # nb x nloc x ng1 ret = self.head_map(ret) return ret @@ -878,26 +878,26 @@ def _update_g1_conv( Parameters ---------- gg1 - Neighbor-wise atomic invariant rep, with shape nf x nloc x nnei x ng1. + Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1. g2 - Pair invariant rep, with shape nf x nloc x nnei x ng2. + Pair invariant rep, with shape nb x nloc x nnei x ng2. nlist_mask - Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. sw The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, - and remains 0 beyond rcut, with shape nf x nloc x nnei. + and remains 0 beyond rcut, with shape nb x nloc x nnei. 
""" assert self.proj_g1g2 is not None - nf, nloc, nnei, _ = g2.shape + nb, nloc, nnei, _ = g2.shape ng1 = gg1.shape[-1] ng2 = g2.shape[-1] - # gg1 : nf x nloc x nnei x ng2 - gg1 = self.proj_g1g2(gg1).view(nf, nloc, nnei, ng2) - # nf x nloc x nnei x ng2 + # gg1 : nb x nloc x nnei x ng2 + gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2) + # nb x nloc x nnei x ng2 gg1 = _apply_nlist_mask(gg1, nlist_mask) if not self.smooth: # normalized by number of neighbors, not smooth - # nf x nloc x 1 + # nb x nloc x 1 # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy invnnei = 1.0 / ( self.epsilon + torch.sum(nlist_mask.type_as(gg1), dim=-1) @@ -905,18 +905,18 @@ def _update_g1_conv( else: gg1 = _apply_switch(gg1, sw) invnnei = (1.0 / float(nnei)) * torch.ones( - (nf, nloc, 1), dtype=gg1.dtype, device=gg1.device + (nb, nloc, 1), dtype=gg1.dtype, device=gg1.device ) - # nf x nloc x ng2 + # nb x nloc x ng2 g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei return g1_11 def _update_g2_g1g1( self, - g1: torch.Tensor, # nf x nloc x ng1 - gg1: torch.Tensor, # nf x nloc x nnei x ng1 - nlist_mask: torch.Tensor, # nf x nloc x nnei - sw: torch.Tensor, # nf x nloc x nnei + g1: torch.Tensor, # nb x nloc x ng1 + gg1: torch.Tensor, # nb x nloc x nnei x ng1 + nlist_mask: torch.Tensor, # nb x nloc x nnei + sw: torch.Tensor, # nb x nloc x nnei ) -> torch.Tensor: """ Update the g2 using element-wise dot g1_i * g1_j. @@ -924,17 +924,17 @@ def _update_g2_g1g1( Parameters ---------- g1 - Atomic invariant rep, with shape nf x nloc x ng1. + Atomic invariant rep, with shape nb x nloc x ng1. gg1 - Neighbor-wise atomic invariant rep, with shape nf x nloc x nnei x ng1. + Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1. nlist_mask - Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. sw The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, - and remains 0 beyond rcut, with shape nf x nloc x nnei. + and remains 0 beyond rcut, with shape nb x nloc x nnei. 
""" ret = g1.unsqueeze(-2) * gg1 - # nf x nloc x nnei x ng1 + # nb x nloc x nnei x ng1 ret = _apply_nlist_mask(ret, nlist_mask) if self.smooth: ret = _apply_switch(ret, sw) @@ -972,11 +972,11 @@ def forward( or self.update_g2_has_g1g1 ) - nf, nloc, nnei, _ = g2.shape + nb, nloc, nnei, _ = g2.shape nall = g1_ext.shape[1] g1, _ = torch.split(g1_ext, [nloc, nall - nloc], dim=1) - assert (nf, nloc) == g1.shape[:2] - assert (nf, nloc, nnei) == h2.shape[:3] + assert (nb, nloc) == g1.shape[:2] + assert (nb, nloc, nnei) == h2.shape[:3] g2_update: List[torch.Tensor] = [g2] h2_update: List[torch.Tensor] = [h2] @@ -991,7 +991,7 @@ def forward( if self.update_chnnl_2: # mlp(g2) assert self.linear2 is not None - # nf x nloc x nnei x ng2 + # nb x nloc x nnei x ng2 g2_1 = self.act(self.linear2(g2)) g2_update.append(g2_1) @@ -1006,13 +1006,13 @@ def forward( if self.update_g2_has_attn or self.update_h2: # gated_attention(g2, h2) assert self.attn2g_map is not None - # nf x nloc x nnei x nnei x nh + # nb x nloc x nnei x nnei x nh AAg = self.attn2g_map(g2, h2, nlist_mask, sw) if self.update_g2_has_attn: assert self.attn2_mh_apply is not None assert self.attn2_lm is not None - # nf x nloc x nnei x ng2 + # nb x nloc x nnei x ng2 g2_2 = self.attn2_mh_apply(AAg, g2) g2_2 = self.attn2_lm(g2_2) g2_update.append(g2_2) @@ -1052,7 +1052,7 @@ def forward( ) ) - # nf x nloc x [ng1+ng2+(axisxng2)+(axisxng1)] + # nb x nloc x [ng1+ng2+(axisxng2)+(axisxng1)] # conv grrg drrd g1_1 = self.act(self.linear1(torch.cat(g1_mlp, dim=-1))) g1_update.append(g1_1) diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 77f95f4b91..b8a24945c0 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -428,14 +428,14 @@ def forward( atype_embd = extended_atype_embd assert isinstance(atype_embd, torch.Tensor) # for jit g1 = self.act(atype_embd) - # nf x nloc x nnei x 1, nf x nloc x nnei x 3 + # nb x nloc x nnei x 1, nb x nloc x nnei x 3 if not self.direct_dist: g2, h2 = torch.split(dmatrix, [1, 3], dim=-1) else: g2, h2 = torch.linalg.norm(diff, dim=-1, keepdim=True), diff g2 = g2 / self.rcut h2 = h2 / self.rcut - # nf x nloc x nnei x ng2 + # nb x nloc x nnei x ng2 g2 = self.act(self.g2_embd(g2)) # set all padding positions to index of 0 @@ -448,8 +448,8 @@ def forward( mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim) ) for idx, ll in enumerate(self.layers): - # g1: nf x nloc x ng1 - # g1_ext: nf x nall x ng1 + # g1: nb x nloc x ng1 + # g1_ext: nb x nall x ng1 if comm_dict is None: assert mapping is not None g1_ext = torch.gather(g1, 1, mapping) @@ -485,9 +485,9 @@ def forward( sw, ) - # nf x nloc x 3 x ng2 + # nb x nloc x 3 x ng2 h2g2 = _cal_hg(g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon) - # (nf x nloc) x ng2 x 3 + # (nb x nloc) x ng2 x 3 rot_mat = torch.permute(h2g2, (0, 1, 3, 2)) return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw From d1e38ad602a4f02fd0393ced0559fbe60eb74f62 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 9 May 2024 15:42:49 +0800 Subject: [PATCH 27/37] Update repformers.py --- deepmd/pt/model/descriptor/repformers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index b8a24945c0..b2df49c964 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -400,10 +400,10 @@ def forward( 
nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 atype = extended_atype[:, :nloc] - # # nf x nloc x nnei + # nf x nloc x nnei exclude_mask = self.emask(nlist, extended_atype) nlist = nlist * exclude_mask - # nf x nloc x nnei x 4, nf x nloc x nnei x 3, nf x nloc x nnei x 1 + # nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1 dmatrix, diff, sw = prod_env_mat( extended_coord, nlist, From 9d0ad7fae0bf426f032b6edd251581bd0f99f5b2 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 9 May 2024 15:43:46 +0800 Subject: [PATCH 28/37] Update repformers.py --- deepmd/pt/model/descriptor/repformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index b2df49c964..e352c6b40c 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -400,7 +400,7 @@ def forward( nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 atype = extended_atype[:, :nloc] - # nf x nloc x nnei + # nb x nloc x nnei exclude_mask = self.emask(nlist, extended_atype) nlist = nlist * exclude_mask # nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1 From 375c03ee37bf7a534c28035c559ff17f256574de Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 9 May 2024 16:51:45 +0800 Subject: [PATCH 29/37] mv symmetrization_op into static --- deepmd/pt/model/descriptor/repformer_layer.py | 260 +++++++++--------- deepmd/pt/model/descriptor/repformers.py | 5 +- 2 files changed, 133 insertions(+), 132 deletions(-) diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index 8af81520dd..af436ca96d 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -137,134 +137,6 @@ def _apply_switch(gg: torch.Tensor, sw: torch.Tensor) -> torch.Tensor: return gg * sw.unsqueeze(-1) -def _cal_hg( - g: torch.Tensor, - h: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, - smooth: bool = True, - epsilon: float = 1e-4, -) -> torch.Tensor: - """ - Calculate the transposed rotation matrix. - - Parameters - ---------- - g - Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng. - h - Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. - nlist_mask - Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. - sw - The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, - and remains 0 beyond rcut, with shape nf x nloc x nnei. - smooth - Whether to use smoothness in processes such as attention weights calculation. - epsilon - Protection of 1./nnei. - - Returns - ------- - hg - The transposed rotation matrix, with shape nf x nloc x 3 x ng. 
- """ - # g: nf x nloc x nnei x ng - # h: nf x nloc x nnei x 3 - # msk: nf x nloc x nnei - nf, nloc, nnei, _ = g.shape - ng = g.shape[-1] - # nf x nloc x nnei x ng - g = _apply_nlist_mask(g, nlist_mask) - if not smooth: - # nf x nloc - # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy - invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g), dim=-1)) - # nf x nloc x 1 x 1 - invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) - else: - g = _apply_switch(g, sw) - invnnei = (1.0 / float(nnei)) * torch.ones( - (nf, nloc, 1, 1), dtype=g.dtype, device=g.device - ) - # nf x nloc x 3 x ng - hg = torch.matmul(torch.transpose(h, -1, -2), g) * invnnei - return hg - - -def _cal_grrg(hg: torch.Tensor, axis_neuron: int) -> torch.Tensor: - """ - Calculate the atomic invariant rep. - - Parameters - ---------- - hg - The transposed rotation matrix, with shape nf x nloc x 3 x ng. - axis_neuron - Size of the submatrix. - - Returns - ------- - grrg - Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng) - """ - # nf x nloc x 3 x ng - nf, nloc, _, ng = hg.shape - # nf x nloc x 3 x axis - hgm = torch.split(hg, axis_neuron, dim=-1)[0] - # nf x nloc x axis_neuron x ng - grrg = torch.matmul(torch.transpose(hgm, -1, -2), hg) / (3.0**1) - # nf x nloc x (axis_neuron x ng) - grrg = grrg.view(nf, nloc, axis_neuron * ng) - return grrg - - -def symmetrization_op( - g: torch.Tensor, - h: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, - axis_neuron: int, - smooth: bool = True, - epsilon: float = 1e-4, -) -> torch.Tensor: - """ - Symmetrization operator to obtain atomic invariant rep. - - Parameters - ---------- - g - Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng. - h - Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. - nlist_mask - Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. - sw - The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, - and remains 0 beyond rcut, with shape nf x nloc x nnei. - axis_neuron - Size of the submatrix. - smooth - Whether to use smoothness in processes such as attention weights calculation. - epsilon - Protection of 1./nnei. - - Returns - ------- - grrg - Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng) - """ - # g: nf x nloc x nnei x ng - # h: nf x nloc x nnei x 3 - # msk: nf x nloc x nnei - nf, nloc, nnei, _ = g.shape - # nf x nloc x 3 x ng - hg = _cal_hg(g, h, nlist_mask, sw, smooth=smooth, epsilon=epsilon) - # nf x nloc x (axis_neuron x ng) - grrg = _cal_grrg(hg, axis_neuron) - return grrg - - class Atten2Map(torch.nn.Module): def __init__( self, @@ -845,6 +717,134 @@ def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: ret += g2d return ret + @staticmethod + def _cal_hg( + g: torch.Tensor, + h: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + smooth: bool = True, + epsilon: float = 1e-4, + ) -> torch.Tensor: + """ + Calculate the transposed rotation matrix. + + Parameters + ---------- + g + Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng. + h + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. 
+ sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nf x nloc x nnei. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. + + Returns + ------- + hg + The transposed rotation matrix, with shape nf x nloc x 3 x ng. + """ + # g: nf x nloc x nnei x ng + # h: nf x nloc x nnei x 3 + # msk: nf x nloc x nnei + nf, nloc, nnei, _ = g.shape + ng = g.shape[-1] + # nf x nloc x nnei x ng + g = _apply_nlist_mask(g, nlist_mask) + if not smooth: + # nf x nloc + # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy + invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g), dim=-1)) + # nf x nloc x 1 x 1 + invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) + else: + g = _apply_switch(g, sw) + invnnei = (1.0 / float(nnei)) * torch.ones( + (nf, nloc, 1, 1), dtype=g.dtype, device=g.device + ) + # nf x nloc x 3 x ng + hg = torch.matmul(torch.transpose(h, -1, -2), g) * invnnei + return hg + + @staticmethod + def _cal_grrg(hg: torch.Tensor, axis_neuron: int) -> torch.Tensor: + """ + Calculate the atomic invariant rep. + + Parameters + ---------- + hg + The transposed rotation matrix, with shape nf x nloc x 3 x ng. + axis_neuron + Size of the submatrix. + + Returns + ------- + grrg + Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng) + """ + # nf x nloc x 3 x ng + nf, nloc, _, ng = hg.shape + # nf x nloc x 3 x axis + hgm = torch.split(hg, axis_neuron, dim=-1)[0] + # nf x nloc x axis_neuron x ng + grrg = torch.matmul(torch.transpose(hgm, -1, -2), hg) / (3.0**1) + # nf x nloc x (axis_neuron x ng) + grrg = grrg.view(nf, nloc, axis_neuron * ng) + return grrg + + def symmetrization_op( + self, + g: torch.Tensor, + h: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + axis_neuron: int, + smooth: bool = True, + epsilon: float = 1e-4, + ) -> torch.Tensor: + """ + Symmetrization operator to obtain atomic invariant rep. + + Parameters + ---------- + g + Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng. + h + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nf x nloc x nnei. + axis_neuron + Size of the submatrix. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. 
+ + Returns + ------- + grrg + Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng) + """ + # g: nf x nloc x nnei x ng + # h: nf x nloc x nnei x 3 + # msk: nf x nloc x nnei + nf, nloc, nnei, _ = g.shape + # nf x nloc x 3 x ng + hg = self._cal_hg(g, h, nlist_mask, sw, smooth=smooth, epsilon=epsilon) + # nf x nloc x (axis_neuron x ng) + grrg = self._cal_grrg(hg, axis_neuron) + return grrg + def _update_h2( self, h2: torch.Tensor, @@ -1027,7 +1027,7 @@ def forward( if self.update_g1_has_grrg: g1_mlp.append( - symmetrization_op( + self.symmetrization_op( g2, h2, nlist_mask, @@ -1041,7 +1041,7 @@ def forward( if self.update_g1_has_drrd: assert gg1 is not None g1_mlp.append( - symmetrization_op( + self.symmetrization_op( gg1, h2, nlist_mask, diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index e352c6b40c..f03f15096e 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -40,7 +40,6 @@ from .repformer_layer import ( RepformerLayer, - _cal_hg, ) from .repformer_layer_old_impl import RepformerLayer as RepformerLayerOld @@ -486,7 +485,9 @@ def forward( ) # nb x nloc x 3 x ng2 - h2g2 = _cal_hg(g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon) + h2g2 = RepformerLayer._cal_hg( + g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon + ) # (nb x nloc) x ng2 x 3 rot_mat = torch.permute(h2g2, (0, 1, 3, 2)) From d85eef031dd9f5bfe66f2f844300dfe7eec6b3da Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 9 May 2024 17:24:58 +0800 Subject: [PATCH 30/37] Update test_descriptor_dpa2.py --- source/tests/common/dpmodel/test_descriptor_dpa2.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/source/tests/common/dpmodel/test_descriptor_dpa2.py b/source/tests/common/dpmodel/test_descriptor_dpa2.py index 4df01c61ad..3ae0689dad 100644 --- a/source/tests/common/dpmodel/test_descriptor_dpa2.py +++ b/source/tests/common/dpmodel/test_descriptor_dpa2.py @@ -45,5 +45,13 @@ def test_self_consistency( em1 = DescrptDPA2.deserialize(em0.serialize()) mm0 = em0.call(self.coord_ext, self.atype_ext, self.nlist, self.mapping) mm1 = em1.call(self.coord_ext, self.atype_ext, self.nlist, self.mapping) - for ii in [0, 1, 4]: + desired_shape = [ + (nf, nloc, em0.get_dim_out()), # descriptor + (nf, nloc, em0.get_dim_emb(), 3), # rot_mat + (nf, nloc, nnei // 2, em0.repformers.g2_dim), # g2 + (nf, nloc, nnei // 2, 3), # h2 + (nf, nloc, nnei // 2), # sw + ] + for ii in [0, 1, 2, 3, 4]: + np.testing.assert_equal(mm0[ii].shape, desired_shape[ii]) np.testing.assert_allclose(mm0[ii], mm1[ii]) From 244c8e5dbc9262db3d0b81d87745a5f7e179d34e Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 9 May 2024 17:29:05 +0800 Subject: [PATCH 31/37] Update dpa2.md --- doc/model/dpa2.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/doc/model/dpa2.md b/doc/model/dpa2.md index e295f6b6bb..0d17dfd475 100644 --- a/doc/model/dpa2.md +++ b/doc/model/dpa2.md @@ -1,5 +1,9 @@ -# Descriptor DPA-2 {{ pytorch_icon }} +# Descriptor DPA-2 {{ pytorch_icon }} {{ dpmodel_icon }} :::{note} -**Supported backends**: PyTorch {{ pytorch_icon }} +**Supported backends**: PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }} ::: + +The DPA-2 model implementation. See https://arxiv.org/abs/2312.15492 for more details. + +Training example: `examples/water/dpa2/input_torch.json`. 
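
The symmetrization operator that these repformer patches move into `RepformerLayer` reduces the pair-wise reps to an atomic invariant rep: the equivariant rep `h` is contracted against the invariant rep `g` (normalized by the neighbor count or the smooth switch) to form the transposed rotation matrix, and the first `axis_neuron` columns of that matrix are then contracted with the full matrix. The following is a minimal standalone PyTorch sketch of the same shape flow, following the shape comments above (nf x nloc x nnei x ng, etc.); the function name `toy_symmetrization_op` and the toy tensor sizes are illustrative only and are not part of the project code.

```python
import torch

def toy_symmetrization_op(g, h, nlist_mask, sw, axis_neuron, smooth=True, epsilon=1e-4):
    # g: nf x nloc x nnei x ng, h: nf x nloc x nnei x 3, nlist_mask/sw: nf x nloc x nnei
    nf, nloc, nnei, ng = g.shape
    g = g * nlist_mask.unsqueeze(-1)  # zero out padded neighbor slots
    if smooth:
        g = g * sw.unsqueeze(-1)      # smooth switch instead of a hard neighbor count
        invnnei = torch.full((nf, nloc, 1, 1), 1.0 / nnei, dtype=g.dtype)
    else:
        invnnei = 1.0 / (epsilon + nlist_mask.to(g.dtype).sum(-1))
        invnnei = invnnei[..., None, None]
    # transposed rotation matrix: nf x nloc x 3 x ng
    hg = torch.matmul(h.transpose(-1, -2), g) * invnnei
    # first axis_neuron columns, contracted back against hg: nf x nloc x axis_neuron x ng
    hgm = hg[..., :axis_neuron]
    grrg = torch.matmul(hgm.transpose(-1, -2), hg) / 3.0
    # atomic invariant rep: nf x nloc x (axis_neuron * ng)
    return grrg.reshape(nf, nloc, axis_neuron * ng)

# toy usage with illustrative sizes
nf, nloc, nnei, ng, axis = 1, 4, 8, 16, 4
g = torch.rand(nf, nloc, nnei, ng)
h = torch.rand(nf, nloc, nnei, 3)
mask = torch.ones(nf, nloc, nnei, dtype=torch.bool)
sw = torch.ones(nf, nloc, nnei)
print(toy_symmetrization_op(g, h, mask, sw, axis).shape)  # torch.Size([1, 4, 64])
```
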
From e9fe376e99f9b6f8dae08bbeaa22f3557101a3ad Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 9 May 2024 18:27:06 +0800 Subject: [PATCH 32/37] separate args for repinit and repformers --- deepmd/dpmodel/descriptor/dpa2.py | 55 +++++++++---------- deepmd/dpmodel/descriptor/repformers.py | 4 -- deepmd/pt/model/descriptor/dpa2.py | 55 +++++++++---------- deepmd/pt/model/descriptor/repformers.py | 4 -- deepmd/utils/argcheck.py | 53 +++++++++++++----- .../tests/consistent/descriptor/test_dpa2.py | 54 +++++++++--------- 6 files changed, 118 insertions(+), 107 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 0d019590f6..b76530cf2f 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -72,6 +72,8 @@ def __init__( repinit_tebd_input_mode: str = "concat", repinit_set_davg_zero: bool = True, repinit_activation_function="tanh", + repinit_resnet_dt: bool = False, + repinit_type_one_side: bool = False, # kwargs for repformer repformer_nlayers: int = 3, repformer_g1_dim: int = 128, @@ -95,6 +97,8 @@ def __init__( repformer_update_residual: float = 0.001, repformer_update_residual_init: str = "norm", repformer_set_davg_zero: bool = True, + repformer_trainable_ln: bool = True, + repformer_ln_eps: Optional[float] = 1e-5, # kwargs for descriptor concat_output_tebd: bool = True, precision: str = "float64", @@ -103,10 +107,6 @@ def __init__( env_protection: float = 0.0, trainable: bool = True, seed: Optional[int] = None, - resnet_dt: bool = False, - trainable_ln: bool = True, - ln_eps: Optional[float] = 1e-5, - type_one_side: bool = False, add_tebd_to_repinit_out: bool = False, ): r"""The DPA-2 descriptor. see https://arxiv.org/abs/2312.15492. @@ -142,6 +142,12 @@ def __init__( repinit_activation_function : str, optional (Used in the repinit block.) The activation function in the embedding net. + repinit_resnet_dt : bool, optional + (Used in the repinit block.) + Whether to use a "Timestep" in the skip connection. + repinit_type_one_side : bool, optional + (Used in the repinit block.) + Whether to use one-side type embedding. repformer_rcut : float (Used in the repformer block.) The cut-off radius. @@ -223,6 +229,12 @@ def __init__( repformer_set_davg_zero : bool, optional (Used in the repformer block.) Set the normalization average to zero. + repformer_trainable_ln : bool, optional + (Used in the repformer block.) + Whether to use trainable shift and scale weights in layer normalization. + repformer_ln_eps : float, optional + (Used in the repformer block.) + The epsilon value for layer normalization. concat_output_tebd : bool, optional Whether to concat type embedding at the output of the descriptor. precision : str, optional @@ -239,14 +251,6 @@ def __init__( If the parameters are trainable. seed : int, optional (Unused yet) Random seed for parameter initialization. - resnet_dt : bool, optional - Whether to use a "Timestep" in the skip connection. - trainable_ln : bool, optional - Whether to use trainable shift and scale weights in layer normalization. - ln_eps : float, optional - The epsilon value for layer normalization. - type_one_side : bool, optional - Whether to use one-side type embedding. add_tebd_to_repinit_out : bool, optional Whether to add type embedding to the output representation from repinit before inputting it into repformer. 
@@ -266,8 +270,8 @@ def __init__( """ # to keep consistent with default value in this backends - if ln_eps is None: - ln_eps = 1e-5 + if repformer_ln_eps is None: + repformer_ln_eps = 1e-5 self.repinit = DescrptBlockSeAtten( repinit_rcut, repinit_rcut_smth, @@ -283,11 +287,9 @@ def __init__( env_protection=env_protection, activation_function=repinit_activation_function, precision=precision, - resnet_dt=resnet_dt, - trainable_ln=trainable_ln, - ln_eps=ln_eps, + resnet_dt=repinit_resnet_dt, smooth=smooth, - type_one_side=type_one_side, + type_one_side=repinit_type_one_side, ) self.repformers = DescrptBlockRepformers( repformer_rcut, @@ -320,9 +322,8 @@ def __init__( exclude_types=exclude_types, env_protection=env_protection, precision=precision, - resnet_dt=resnet_dt, - trainable_ln=trainable_ln, - ln_eps=ln_eps, + trainable_ln=repformer_trainable_ln, + ln_eps=repformer_ln_eps, ) self.type_embedding = TypeEmbedNet( ntypes=ntypes, @@ -337,10 +338,6 @@ def __init__( self.exclude_types = exclude_types self.env_protection = env_protection self.trainable = trainable - self.resnet_dt = resnet_dt - self.trainable_ln = trainable_ln - self.ln_eps = ln_eps - self.type_one_side = type_one_side self.add_tebd_to_repinit_out = add_tebd_to_repinit_out if self.repinit.dim_out == self.repformers.dim_in: @@ -531,6 +528,8 @@ def serialize(self) -> dict: "repinit_tebd_input_mode": repinit.tebd_input_mode, "repinit_set_davg_zero": repinit.set_davg_zero, "repinit_activation_function": repinit.activation_function, + "repinit_resnet_dt": repinit.resnet_dt, + "repinit_type_one_side": repinit.type_one_side, "repformer_nlayers": repformers.nlayers, "repformer_g1_dim": repformers.g1_dim, "repformer_g2_dim": repformers.g2_dim, @@ -551,16 +550,14 @@ def serialize(self) -> dict: "repformer_activation_function": repformers.activation_function, "repformer_update_style": repformers.update_style, "repformer_set_davg_zero": repformers.set_davg_zero, + "repformer_trainable_ln": repformers.trainable_ln, + "repformer_ln_eps": repformers.ln_eps, "concat_output_tebd": self.concat_output_tebd, "precision": self.precision, "smooth": self.smooth, "exclude_types": self.exclude_types, "env_protection": self.env_protection, "trainable": self.trainable, - "resnet_dt": self.resnet_dt, - "trainable_ln": self.trainable_ln, - "ln_eps": self.ln_eps, - "type_one_side": self.type_one_side, "add_tebd_to_repinit_out": self.add_tebd_to_repinit_out, "type_embedding": self.type_embedding.serialize(), "g1_shape_tranform": self.g1_shape_tranform.serialize(), diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index ca3651a979..f0275b44b1 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -80,7 +80,6 @@ def __init__( exclude_types: List[Tuple[int, int]] = [], env_protection: float = 0.0, precision: str = "float64", - resnet_dt: bool = False, trainable_ln: bool = True, ln_eps: Optional[float] = 1e-5, ): @@ -157,8 +156,6 @@ def __init__( env_protection : float, optional Protection parameter to prevent division by zero errors during environment matrix calculations. For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. - resnet_dt : bool, optional - Whether to use a "Timestep" in the skip connection. trainable_ln : bool, optional Whether to use trainable shift and scale weights in layer normalization. 
ln_eps : float, optional @@ -203,7 +200,6 @@ def __init__( self.reinit_exclude(exclude_types) self.env_protection = env_protection self.precision = precision - self.resnet_dt = resnet_dt self.trainable_ln = trainable_ln self.ln_eps = ln_eps self.epsilon = 1e-4 diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 2e9e228bc9..28ff1b6848 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -73,6 +73,8 @@ def __init__( repinit_tebd_input_mode: str = "concat", repinit_set_davg_zero: bool = True, repinit_activation_function="tanh", + repinit_resnet_dt: bool = False, + repinit_type_one_side: bool = False, # kwargs for repformer repformer_nlayers: int = 3, repformer_g1_dim: int = 128, @@ -96,6 +98,8 @@ def __init__( repformer_update_residual: float = 0.001, repformer_update_residual_init: str = "norm", repformer_set_davg_zero: bool = True, + repformer_trainable_ln: bool = True, + repformer_ln_eps: Optional[float] = 1e-5, # kwargs for descriptor concat_output_tebd: bool = True, precision: str = "float64", @@ -104,10 +108,6 @@ def __init__( env_protection: float = 0.0, trainable: bool = True, seed: Optional[int] = None, - resnet_dt: bool = False, - trainable_ln: bool = True, - ln_eps: Optional[float] = 1e-5, - type_one_side: bool = False, add_tebd_to_repinit_out: bool = False, old_impl: bool = False, ): @@ -144,6 +144,12 @@ def __init__( repinit_activation_function : str, optional (Used in the repinit block.) The activation function in the embedding net. + repinit_resnet_dt : bool, optional + (Used in the repinit block.) + Whether to use a "Timestep" in the skip connection. + repinit_type_one_side : bool, optional + (Used in the repinit block.) + Whether to use one-side type embedding. repformer_rcut : float (Used in the repformer block.) The cut-off radius. @@ -225,6 +231,12 @@ def __init__( repformer_set_davg_zero : bool, optional (Used in the repformer block.) Set the normalization average to zero. + repformer_trainable_ln : bool, optional + (Used in the repformer block.) + Whether to use trainable shift and scale weights in layer normalization. + repformer_ln_eps : float, optional + (Used in the repformer block.) + The epsilon value for layer normalization. concat_output_tebd : bool, optional Whether to concat type embedding at the output of the descriptor. precision : str, optional @@ -241,14 +253,6 @@ def __init__( If the parameters are trainable. seed : int, optional (Unused yet) Random seed for parameter initialization. - resnet_dt : bool, optional - Whether to use a "Timestep" in the skip connection. - trainable_ln : bool, optional - Whether to use trainable shift and scale weights in layer normalization. - ln_eps : float, optional - The epsilon value for layer normalization. - type_one_side : bool, optional - Whether to use one-side type embedding. add_tebd_to_repinit_out : bool, optional Whether to add type embedding to the output representation from repinit before inputting it into repformer. 
@@ -269,8 +273,8 @@ def __init__( """ super().__init__() # to keep consistent with default value in this backends - if ln_eps is None: - ln_eps = 1e-5 + if repformer_ln_eps is None: + repformer_ln_eps = 1e-5 self.repinit = DescrptBlockSeAtten( repinit_rcut, repinit_rcut_smth, @@ -286,11 +290,9 @@ def __init__( env_protection=env_protection, activation_function=repinit_activation_function, precision=precision, - resnet_dt=resnet_dt, - trainable_ln=trainable_ln, - ln_eps=ln_eps, + resnet_dt=repinit_resnet_dt, smooth=smooth, - type_one_side=type_one_side, + type_one_side=repinit_type_one_side, ) self.repformers = DescrptBlockRepformers( repformer_rcut, @@ -323,9 +325,8 @@ def __init__( exclude_types=exclude_types, env_protection=env_protection, precision=precision, - resnet_dt=resnet_dt, - trainable_ln=trainable_ln, - ln_eps=ln_eps, + trainable_ln=repformer_trainable_ln, + ln_eps=repformer_ln_eps, old_impl=old_impl, ) self.type_embedding = TypeEmbedNet( @@ -337,10 +338,6 @@ def __init__( self.exclude_types = exclude_types self.env_protection = env_protection self.trainable = trainable - self.resnet_dt = resnet_dt - self.trainable_ln = trainable_ln - self.ln_eps = ln_eps - self.type_one_side = type_one_side self.add_tebd_to_repinit_out = add_tebd_to_repinit_out if self.repinit.dim_out == self.repformers.dim_in: @@ -504,6 +501,8 @@ def serialize(self) -> dict: "repinit_tebd_input_mode": repinit.tebd_input_mode, "repinit_set_davg_zero": repinit.set_davg_zero, "repinit_activation_function": repinit.activation_function, + "repinit_resnet_dt": repinit.resnet_dt, + "repinit_type_one_side": repinit.type_one_side, "repformer_nlayers": repformers.nlayers, "repformer_g1_dim": repformers.g1_dim, "repformer_g2_dim": repformers.g2_dim, @@ -524,16 +523,14 @@ def serialize(self) -> dict: "repformer_activation_function": repformers.activation_function, "repformer_update_style": repformers.update_style, "repformer_set_davg_zero": repformers.set_davg_zero, + "repformer_trainable_ln": repformers.trainable_ln, + "repformer_ln_eps": repformers.ln_eps, "concat_output_tebd": self.concat_output_tebd, "precision": self.precision, "smooth": self.smooth, "exclude_types": self.exclude_types, "env_protection": self.env_protection, "trainable": self.trainable, - "resnet_dt": self.resnet_dt, - "trainable_ln": self.trainable_ln, - "ln_eps": self.ln_eps, - "type_one_side": self.type_one_side, "add_tebd_to_repinit_out": self.add_tebd_to_repinit_out, "type_embedding": self.type_embedding.embedding.serialize(), "g1_shape_tranform": self.g1_shape_tranform.serialize(), diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index f03f15096e..2d6a61d264 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -99,7 +99,6 @@ def __init__( exclude_types: List[Tuple[int, int]] = [], env_protection: float = 0.0, precision: str = "float64", - resnet_dt: bool = False, trainable_ln: bool = True, ln_eps: Optional[float] = 1e-5, old_impl: bool = False, @@ -177,8 +176,6 @@ def __init__( env_protection : float, optional Protection parameter to prevent division by zero errors during environment matrix calculations. For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. - resnet_dt : bool, optional - Whether to use a "Timestep" in the skip connection. 
trainable_ln : bool, optional Whether to use trainable shift and scale weights in layer normalization. ln_eps : float, optional @@ -223,7 +220,6 @@ def __init__( self.reinit_exclude(exclude_types) self.env_protection = env_protection self.precision = precision - self.resnet_dt = resnet_dt self.trainable_ln = trainable_ln self.ln_eps = ln_eps self.epsilon = 1e-4 diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index cc18d8c6d1..fb0e0855b8 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -645,6 +645,13 @@ def descrpt_dpa2_args(): f"This option should be set when `atom_ener` in the energy fitting is used." ) doc_repinit_activation_function = f"{doc_repinit}The activation function in the embedding net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())}." + doc_repinit_type_one_side = ( + f"{doc_repinit}" + + r"If true, the embedding network parameters vary by types of neighbor atoms only, so there will be $N_\text{types}$ sets of embedding network parameters. Otherwise, the embedding network parameters vary by types of centric atoms and types of neighbor atoms, so there will be $N_\text{types}^2$ sets of embedding network parameters." + ) + doc_repinit_resnet_dt = ( + f'{doc_repinit}Whether to use a "Timestep" in the skip connection.' + ) # repformer args doc_repformer = "(Used in the repformer block.) " @@ -709,6 +716,10 @@ def descrpt_dpa2_args(): f"{doc_repformer}Set the normalization average to zero. " f"This option should be set when `atom_ener` in the energy fitting is used." ) + doc_repformer_trainable_ln = ( + "Whether to use trainable shift and scale weights in layer normalization." + ) + doc_repformer_ln_eps = "The epsilon value for layer normalization. The default value for TensorFlow is set to 1e-3 to keep consistent with keras while set to 1e-5 in PyTorch and DP implementation." # descriptor args doc_concat_output_tebd = ( @@ -722,12 +733,6 @@ def descrpt_dpa2_args(): doc_env_protection = "Protection parameter to prevent division by zero errors during environment matrix calculations. For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection." doc_trainable = "If the parameters in the embedding net is trainable." doc_seed = "Random seed for parameter initialization." - doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection.' - doc_trainable_ln = ( - "Whether to use trainable shift and scale weights in layer normalization." - ) - doc_ln_eps = "The epsilon value for layer normalization. The default value for TensorFlow is set to 1e-3 to keep consistent with keras while set to 1e-5 in PyTorch and DP implementation." - doc_type_one_side = r"If true, the embedding network parameters vary by types of neighbor atoms only, so there will be $N_\text{types}$ sets of embedding network parameters. Otherwise, the embedding network parameters vary by types of centric atoms and types of neighbor atoms, so there will be $N_\text{types}^2$ sets of embedding network parameters." doc_add_tebd_to_repinit_out = "Add type embedding to the output representation from repinit before inputting it into repformer." 
return [ # repinit args @@ -778,6 +783,20 @@ def descrpt_dpa2_args(): alias=["repinit_activation"], doc=doc_repinit_activation_function, ), + Argument( + "repinit_type_one_side", + bool, + optional=True, + default=False, + doc=doc_repinit_type_one_side, + ), + Argument( + "repinit_resnet_dt", + bool, + optional=True, + default=False, + doc=doc_repinit_resnet_dt, + ), # repformer args Argument("repformer_rcut", float, doc=doc_repformer_rcut), Argument("repformer_rcut_smth", float, doc=doc_repformer_rcut_smth), @@ -934,6 +953,20 @@ def descrpt_dpa2_args(): default=True, doc=doc_repformer_set_davg_zero, ), + Argument( + "repformer_trainable_ln", + bool, + optional=True, + default=True, + doc=doc_repformer_trainable_ln, + ), + Argument( + "repformer_ln_eps", + float, + optional=True, + default=None, + doc=doc_repformer_ln_eps, + ), # descriptor args Argument( "concat_output_tebd", @@ -960,14 +993,6 @@ def descrpt_dpa2_args(): ), Argument("trainable", bool, optional=True, default=True, doc=doc_trainable), Argument("seed", [int, None], optional=True, doc=doc_seed), - Argument("resnet_dt", bool, optional=True, default=False, doc=doc_resnet_dt), - Argument( - "trainable_ln", bool, optional=True, default=True, doc=doc_trainable_ln - ), - Argument("ln_eps", float, optional=True, default=None, doc=doc_ln_eps), - Argument( - "type_one_side", bool, optional=True, default=False, doc=doc_type_one_side - ), Argument( "add_tebd_to_repinit_out", bool, diff --git a/source/tests/consistent/descriptor/test_dpa2.py b/source/tests/consistent/descriptor/test_dpa2.py index 1313a6d727..dd868dbd59 100644 --- a/source/tests/consistent/descriptor/test_dpa2.py +++ b/source/tests/consistent/descriptor/test_dpa2.py @@ -40,6 +40,7 @@ @parameterized( ("concat", "strip"), # repinit_tebd_input_mode (True,), # repinit_set_davg_zero + (False,), # repinit_type_one_side (True, False), # repformer_direct_dist (True,), # repformer_update_g1_has_conv (True,), # repformer_update_g1_has_drrd @@ -52,12 +53,11 @@ ("res_avg", "res_residual"), # repformer_update_style ("norm", "const"), # repformer_update_residual_init (True,), # repformer_set_davg_zero + (True,), # repformer_trainable_ln + (1e-5,), # repformer_ln_eps (True, False), # smooth ([], [[0, 1]]), # exclude_types ("float64",), # precision - (True,), # trainable_lns - (1e-5,), # ln_eps - (False,), # type_one_side (True, False), # add_tebd_to_repinit_out ) class TestDPA2(CommonTest, DescriptorTest, unittest.TestCase): @@ -66,6 +66,7 @@ def data(self) -> dict: ( repinit_tebd_input_mode, repinit_set_davg_zero, + repinit_type_one_side, repformer_update_g1_has_conv, repformer_direct_dist, repformer_update_g1_has_drrd, @@ -78,12 +79,11 @@ def data(self) -> dict: repformer_update_style, repformer_update_residual_init, repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, smooth, exclude_types, precision, - trainable_ln, - ln_eps, - type_one_side, add_tebd_to_repinit_out, ) = self.param return { @@ -101,6 +101,7 @@ def data(self) -> dict: "repinit_tebd_input_mode": repinit_tebd_input_mode, "repinit_set_davg_zero": repinit_set_davg_zero, "repinit_activation_function": "tanh", + "repinit_type_one_side": repinit_type_one_side, # kwargs for repformer "repformer_nlayers": 3, "repformer_g1_dim": 20, @@ -124,6 +125,8 @@ def data(self) -> dict: "repformer_update_residual": 0.001, "repformer_update_residual_init": repformer_update_residual_init, "repformer_set_davg_zero": True, + "repformer_trainable_ln": repformer_trainable_ln, + "repformer_ln_eps": repformer_ln_eps, # kwargs 
for descriptor "concat_output_tebd": True, "precision": precision, @@ -131,9 +134,6 @@ def data(self) -> dict: "exclude_types": exclude_types, "env_protection": 0.0, "trainable": True, - "trainable_ln": trainable_ln, - "ln_eps": ln_eps, - "type_one_side": type_one_side, "add_tebd_to_repinit_out": add_tebd_to_repinit_out, } @@ -142,6 +142,7 @@ def skip_pt(self) -> bool: ( repinit_tebd_input_mode, repinit_set_davg_zero, + repinit_type_one_side, repformer_update_g1_has_conv, repformer_direct_dist, repformer_update_g1_has_drrd, @@ -154,12 +155,11 @@ def skip_pt(self) -> bool: repformer_update_style, repformer_update_residual_init, repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, smooth, exclude_types, precision, - trainable_ln, - ln_eps, - type_one_side, add_tebd_to_repinit_out, ) = self.param return CommonTest.skip_pt @@ -169,6 +169,7 @@ def skip_dp(self) -> bool: ( repinit_tebd_input_mode, repinit_set_davg_zero, + repinit_type_one_side, repformer_update_g1_has_conv, repformer_direct_dist, repformer_update_g1_has_drrd, @@ -181,12 +182,11 @@ def skip_dp(self) -> bool: repformer_update_style, repformer_update_residual_init, repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, smooth, exclude_types, precision, - trainable_ln, - ln_eps, - type_one_side, add_tebd_to_repinit_out, ) = self.param return CommonTest.skip_pt @@ -196,6 +196,7 @@ def skip_tf(self) -> bool: ( repinit_tebd_input_mode, repinit_set_davg_zero, + repinit_type_one_side, repformer_update_g1_has_conv, repformer_direct_dist, repformer_update_g1_has_drrd, @@ -208,12 +209,11 @@ def skip_tf(self) -> bool: repformer_update_style, repformer_update_residual_init, repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, smooth, exclude_types, precision, - trainable_ln, - ln_eps, - type_one_side, add_tebd_to_repinit_out, ) = self.param return True @@ -259,6 +259,7 @@ def setUp(self): ( repinit_tebd_input_mode, repinit_set_davg_zero, + repinit_type_one_side, repformer_update_g1_has_conv, repformer_direct_dist, repformer_update_g1_has_drrd, @@ -271,12 +272,11 @@ def setUp(self): repformer_update_style, repformer_update_residual_init, repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, smooth, exclude_types, precision, - trainable_ln, - ln_eps, - type_one_side, add_tebd_to_repinit_out, ) = self.param @@ -319,6 +319,7 @@ def rtol(self) -> float: ( repinit_tebd_input_mode, repinit_set_davg_zero, + repinit_type_one_side, repformer_update_g1_has_conv, repformer_direct_dist, repformer_update_g1_has_drrd, @@ -331,12 +332,11 @@ def rtol(self) -> float: repformer_update_style, repformer_update_residual_init, repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, smooth, exclude_types, precision, - trainable_ln, - ln_eps, - type_one_side, add_tebd_to_repinit_out, ) = self.param if precision == "float64": @@ -352,6 +352,7 @@ def atol(self) -> float: ( repinit_tebd_input_mode, repinit_set_davg_zero, + repinit_type_one_side, repformer_update_g1_has_conv, repformer_direct_dist, repformer_update_g1_has_drrd, @@ -364,12 +365,11 @@ def atol(self) -> float: repformer_update_style, repformer_update_residual_init, repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, smooth, exclude_types, precision, - trainable_ln, - ln_eps, - type_one_side, add_tebd_to_repinit_out, ) = self.param if precision == "float64": From 515c5344295249b3038f4654862f47b0d7f8fa8e Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 9 May 2024 
23:04:35 +0800 Subject: [PATCH 33/37] Update repformer_layer.py --- deepmd/pt/model/descriptor/repformer_layer.py | 132 +++++++++--------- 1 file changed, 66 insertions(+), 66 deletions(-) diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index af436ca96d..f22529c7f6 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -717,6 +717,72 @@ def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: ret += g2d return ret + def _update_h2( + self, + h2: torch.Tensor, + attn: torch.Tensor, + ) -> torch.Tensor: + """ + Calculate the attention weights update for pair-wise equivariant rep. + + Parameters + ---------- + h2 + Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. + attn + Attention weights from g2 attention, with shape nf x nloc x nnei x nnei x nh2. + """ + assert self.attn2_ev_apply is not None + # nf x nloc x nnei x nh2 + h2_1 = self.attn2_ev_apply(attn, h2) + return h2_1 + + def _update_g1_conv( + self, + gg1: torch.Tensor, + g2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + ) -> torch.Tensor: + """ + Calculate the convolution update for atomic invariant rep. + + Parameters + ---------- + gg1 + Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1. + g2 + Pair invariant rep, with shape nb x nloc x nnei x ng2. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + """ + assert self.proj_g1g2 is not None + nb, nloc, nnei, _ = g2.shape + ng1 = gg1.shape[-1] + ng2 = g2.shape[-1] + # gg1 : nb x nloc x nnei x ng2 + gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2) + # nb x nloc x nnei x ng2 + gg1 = _apply_nlist_mask(gg1, nlist_mask) + if not self.smooth: + # normalized by number of neighbors, not smooth + # nb x nloc x 1 + # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy + invnnei = 1.0 / ( + self.epsilon + torch.sum(nlist_mask.type_as(gg1), dim=-1) + ).unsqueeze(-1) + else: + gg1 = _apply_switch(gg1, sw) + invnnei = (1.0 / float(nnei)) * torch.ones( + (nb, nloc, 1), dtype=gg1.dtype, device=gg1.device + ) + # nb x nloc x ng2 + g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei + return g1_11 + @staticmethod def _cal_hg( g: torch.Tensor, @@ -845,72 +911,6 @@ def symmetrization_op( grrg = self._cal_grrg(hg, axis_neuron) return grrg - def _update_h2( - self, - h2: torch.Tensor, - attn: torch.Tensor, - ) -> torch.Tensor: - """ - Calculate the attention weights update for pair-wise equivariant rep. - - Parameters - ---------- - h2 - Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. - attn - Attention weights from g2 attention, with shape nf x nloc x nnei x nnei x nh2. - """ - assert self.attn2_ev_apply is not None - # nf x nloc x nnei x nh2 - h2_1 = self.attn2_ev_apply(attn, h2) - return h2_1 - - def _update_g1_conv( - self, - gg1: torch.Tensor, - g2: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, - ) -> torch.Tensor: - """ - Calculate the convolution update for atomic invariant rep. - - Parameters - ---------- - gg1 - Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1. - g2 - Pair invariant rep, with shape nb x nloc x nnei x ng2. 
- nlist_mask - Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. - sw - The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, - and remains 0 beyond rcut, with shape nb x nloc x nnei. - """ - assert self.proj_g1g2 is not None - nb, nloc, nnei, _ = g2.shape - ng1 = gg1.shape[-1] - ng2 = g2.shape[-1] - # gg1 : nb x nloc x nnei x ng2 - gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2) - # nb x nloc x nnei x ng2 - gg1 = _apply_nlist_mask(gg1, nlist_mask) - if not self.smooth: - # normalized by number of neighbors, not smooth - # nb x nloc x 1 - # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy - invnnei = 1.0 / ( - self.epsilon + torch.sum(nlist_mask.type_as(gg1), dim=-1) - ).unsqueeze(-1) - else: - gg1 = _apply_switch(gg1, sw) - invnnei = (1.0 / float(nnei)) * torch.ones( - (nb, nloc, 1), dtype=gg1.dtype, device=gg1.device - ) - # nb x nloc x ng2 - g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei - return g1_11 - def _update_g2_g1g1( self, g1: torch.Tensor, # nb x nloc x ng1 From bd25aa698281df7c85e4e173d8e78b6611abd31a Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 9 May 2024 23:09:07 +0800 Subject: [PATCH 34/37] Update repformer_layer.py --- deepmd/pt/model/descriptor/repformer_layer.py | 66 +++++++++---------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index f22529c7f6..1697591a42 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -798,14 +798,14 @@ def _cal_hg( Parameters ---------- g - Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng. + Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng. h - Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. nlist_mask - Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. sw The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, - and remains 0 beyond rcut, with shape nf x nloc x nnei. + and remains 0 beyond rcut, with shape nb x nloc x nnei. smooth Whether to use smoothness in processes such as attention weights calculation. epsilon @@ -814,27 +814,27 @@ def _cal_hg( Returns ------- hg - The transposed rotation matrix, with shape nf x nloc x 3 x ng. + The transposed rotation matrix, with shape nb x nloc x 3 x ng. 
""" - # g: nf x nloc x nnei x ng - # h: nf x nloc x nnei x 3 - # msk: nf x nloc x nnei - nf, nloc, nnei, _ = g.shape + # g: nb x nloc x nnei x ng + # h: nb x nloc x nnei x 3 + # msk: nb x nloc x nnei + nb, nloc, nnei, _ = g.shape ng = g.shape[-1] - # nf x nloc x nnei x ng + # nb x nloc x nnei x ng g = _apply_nlist_mask(g, nlist_mask) if not smooth: - # nf x nloc + # nb x nloc # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g), dim=-1)) - # nf x nloc x 1 x 1 + # nb x nloc x 1 x 1 invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) else: g = _apply_switch(g, sw) invnnei = (1.0 / float(nnei)) * torch.ones( - (nf, nloc, 1, 1), dtype=g.dtype, device=g.device + (nb, nloc, 1, 1), dtype=g.dtype, device=g.device ) - # nf x nloc x 3 x ng + # nb x nloc x 3 x ng hg = torch.matmul(torch.transpose(h, -1, -2), g) * invnnei return hg @@ -846,23 +846,23 @@ def _cal_grrg(hg: torch.Tensor, axis_neuron: int) -> torch.Tensor: Parameters ---------- hg - The transposed rotation matrix, with shape nf x nloc x 3 x ng. + The transposed rotation matrix, with shape nb x nloc x 3 x ng. axis_neuron Size of the submatrix. Returns ------- grrg - Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng) + Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng) """ - # nf x nloc x 3 x ng - nf, nloc, _, ng = hg.shape - # nf x nloc x 3 x axis + # nb x nloc x 3 x ng + nb, nloc, _, ng = hg.shape + # nb x nloc x 3 x axis hgm = torch.split(hg, axis_neuron, dim=-1)[0] - # nf x nloc x axis_neuron x ng + # nb x nloc x axis_neuron x ng grrg = torch.matmul(torch.transpose(hgm, -1, -2), hg) / (3.0**1) - # nf x nloc x (axis_neuron x ng) - grrg = grrg.view(nf, nloc, axis_neuron * ng) + # nb x nloc x (axis_neuron x ng) + grrg = grrg.view(nb, nloc, axis_neuron * ng) return grrg def symmetrization_op( @@ -881,14 +881,14 @@ def symmetrization_op( Parameters ---------- g - Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng. + Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng. h - Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. nlist_mask - Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. sw The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, - and remains 0 beyond rcut, with shape nf x nloc x nnei. + and remains 0 beyond rcut, with shape nb x nloc x nnei. axis_neuron Size of the submatrix. 
smooth @@ -899,15 +899,15 @@ def symmetrization_op( Returns ------- grrg - Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng) + Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng) """ - # g: nf x nloc x nnei x ng - # h: nf x nloc x nnei x 3 - # msk: nf x nloc x nnei - nf, nloc, nnei, _ = g.shape - # nf x nloc x 3 x ng + # g: nb x nloc x nnei x ng + # h: nb x nloc x nnei x 3 + # msk: nb x nloc x nnei + nb, nloc, nnei, _ = g.shape + # nb x nloc x 3 x ng hg = self._cal_hg(g, h, nlist_mask, sw, smooth=smooth, epsilon=epsilon) - # nf x nloc x (axis_neuron x ng) + # nb x nloc x (axis_neuron x ng) grrg = self._cal_grrg(hg, axis_neuron) return grrg From f17f40f51471a4d66d47f9b790d889116fb806a1 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 9 May 2024 23:35:56 +0800 Subject: [PATCH 35/37] Update repformer_layer.py --- deepmd/pt/model/descriptor/repformer_layer.py | 76 +++++++++---------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index 1697591a42..df91bcb8fa 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -785,8 +785,8 @@ def _update_g1_conv( @staticmethod def _cal_hg( - g: torch.Tensor, - h: torch.Tensor, + g2: torch.Tensor, + h2: torch.Tensor, nlist_mask: torch.Tensor, sw: torch.Tensor, smooth: bool = True, @@ -797,9 +797,9 @@ def _cal_hg( Parameters ---------- - g - Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng. - h + g2 + Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng2. + h2 Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. nlist_mask Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. @@ -814,61 +814,61 @@ def _cal_hg( Returns ------- hg - The transposed rotation matrix, with shape nb x nloc x 3 x ng. + The transposed rotation matrix, with shape nb x nloc x 3 x ng2. """ - # g: nb x nloc x nnei x ng - # h: nb x nloc x nnei x 3 + # g2: nb x nloc x nnei x ng2 + # h2: nb x nloc x nnei x 3 # msk: nb x nloc x nnei - nb, nloc, nnei, _ = g.shape - ng = g.shape[-1] - # nb x nloc x nnei x ng - g = _apply_nlist_mask(g, nlist_mask) + nb, nloc, nnei, _ = g2.shape + ng2 = g2.shape[-1] + # nb x nloc x nnei x ng2 + g2 = _apply_nlist_mask(g2, nlist_mask) if not smooth: # nb x nloc # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy - invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g), dim=-1)) + invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g2), dim=-1)) # nb x nloc x 1 x 1 invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) else: - g = _apply_switch(g, sw) + g2 = _apply_switch(g2, sw) invnnei = (1.0 / float(nnei)) * torch.ones( - (nb, nloc, 1, 1), dtype=g.dtype, device=g.device + (nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device ) # nb x nloc x 3 x ng - hg = torch.matmul(torch.transpose(h, -1, -2), g) * invnnei - return hg + h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei + return h2g2 @staticmethod - def _cal_grrg(hg: torch.Tensor, axis_neuron: int) -> torch.Tensor: + def _cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor: """ Calculate the atomic invariant rep. Parameters ---------- - hg - The transposed rotation matrix, with shape nb x nloc x 3 x ng. + h2g2 + The transposed rotation matrix, with shape nb x nloc x 3 x ng2. 
axis_neuron Size of the submatrix. Returns ------- grrg - Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng) + Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng2) """ - # nb x nloc x 3 x ng - nb, nloc, _, ng = hg.shape + # nb x nloc x 3 x ng2 + nb, nloc, _, ng2 = h2g2.shape # nb x nloc x 3 x axis - hgm = torch.split(hg, axis_neuron, dim=-1)[0] + h2g2m = torch.split(h2g2, axis_neuron, dim=-1)[0] # nb x nloc x axis_neuron x ng - grrg = torch.matmul(torch.transpose(hgm, -1, -2), hg) / (3.0**1) + grrg = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1) # nb x nloc x (axis_neuron x ng) - grrg = grrg.view(nb, nloc, axis_neuron * ng) + grrg = grrg.view(nb, nloc, axis_neuron * ng2) return grrg def symmetrization_op( self, - g: torch.Tensor, - h: torch.Tensor, + g2: torch.Tensor, + h2: torch.Tensor, nlist_mask: torch.Tensor, sw: torch.Tensor, axis_neuron: int, @@ -880,9 +880,9 @@ def symmetrization_op( Parameters ---------- - g - Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng. - h + g2 + Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng2. + h2 Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. nlist_mask Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. @@ -899,16 +899,16 @@ def symmetrization_op( Returns ------- grrg - Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng) + Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng2) """ - # g: nb x nloc x nnei x ng - # h: nb x nloc x nnei x 3 + # g2: nb x nloc x nnei x ng2 + # h2: nb x nloc x nnei x 3 # msk: nb x nloc x nnei - nb, nloc, nnei, _ = g.shape + nb, nloc, nnei, _ = g2.shape # nb x nloc x 3 x ng - hg = self._cal_hg(g, h, nlist_mask, sw, smooth=smooth, epsilon=epsilon) - # nb x nloc x (axis_neuron x ng) - grrg = self._cal_grrg(hg, axis_neuron) + h2g2 = self._cal_hg(g2, h2, nlist_mask, sw, smooth=smooth, epsilon=epsilon) + # nb x nloc x (axis_neuron x ng2) + grrg = self._cal_grrg(h2g2, axis_neuron) return grrg def _update_g2_g1g1( From a8c89dcc491a17b202b120a72c002b8480d8020a Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 9 May 2024 23:38:36 +0800 Subject: [PATCH 36/37] Update repformer_layer.py --- deepmd/pt/model/descriptor/repformer_layer.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index df91bcb8fa..f171b6919b 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -859,11 +859,11 @@ def _cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor: nb, nloc, _, ng2 = h2g2.shape # nb x nloc x 3 x axis h2g2m = torch.split(h2g2, axis_neuron, dim=-1)[0] - # nb x nloc x axis_neuron x ng - grrg = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1) - # nb x nloc x (axis_neuron x ng) - grrg = grrg.view(nb, nloc, axis_neuron * ng2) - return grrg + # nb x nloc x axis x ng2 + g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1) + # nb x nloc x (axisxng2) + g1_13 = g1_13.view(nb, nloc, axis_neuron * ng2) + return g1_13 def symmetrization_op( self, @@ -905,11 +905,11 @@ def symmetrization_op( # h2: nb x nloc x nnei x 3 # msk: nb x nloc x nnei nb, nloc, nnei, _ = g2.shape - # nb x nloc x 3 x ng + # nb x nloc x 3 x ng2 h2g2 = self._cal_hg(g2, h2, nlist_mask, sw, smooth=smooth, epsilon=epsilon) - # nb x nloc x 
(axis_neuron x ng2) - grrg = self._cal_grrg(h2g2, axis_neuron) - return grrg + # nb x nloc x (axisxng2) + g1_13 = self._cal_grrg(h2g2, axis_neuron) + return g1_13 def _update_g2_g1g1( self, From e329e309206dc0d4fdf34d4280cd30c06fc76051 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 9 May 2024 23:40:58 +0800 Subject: [PATCH 37/37] Update repformer_layer.py --- deepmd/pt/model/descriptor/repformer_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index f171b6919b..8397b4b421 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -834,7 +834,7 @@ def _cal_hg( invnnei = (1.0 / float(nnei)) * torch.ones( (nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device ) - # nb x nloc x 3 x ng + # nb x nloc x 3 x ng2 h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei return h2g2
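
A recurring pattern annotated by the nb/nf renames in these patches is the neighbor gather performed by `_make_nei_g1`: the extended atomic rep (nb x nall x ng1) is indexed by the neighbor list to produce a neighbor-wise rep (nb x nloc x nnei x ng1). Below is a minimal standalone sketch of that gather; the tensor sizes and toy data are illustrative only, not taken from the project.

```python
import torch

nb, nall, nloc, nnei, ng1 = 2, 6, 3, 4, 5
g1_ext = torch.arange(nb * nall * ng1, dtype=torch.float64).reshape(nb, nall, ng1)
nlist = torch.randint(0, nall, (nb, nloc, nnei))  # toy neighbor indices

# index: nb x (nloc x nnei) x ng1
index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1)
# gather along the atom axis, then reshape: nb x nloc x nnei x ng1
gg1 = torch.gather(g1_ext, dim=1, index=index).view(nb, nloc, nnei, ng1)

# each neighbor slot now holds the invariant rep of the atom it points to
assert torch.equal(gg1[0, 0, 0], g1_ext[0, nlist[0, 0, 0]])
```
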