From 375c03ee37bf7a534c28035c559ff17f256574de Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 9 May 2024 16:51:45 +0800 Subject: [PATCH] mv symmetrization_op into static --- deepmd/pt/model/descriptor/repformer_layer.py | 260 +++++++++--------- deepmd/pt/model/descriptor/repformers.py | 5 +- 2 files changed, 133 insertions(+), 132 deletions(-) diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index 8af81520dd..af436ca96d 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -137,134 +137,6 @@ def _apply_switch(gg: torch.Tensor, sw: torch.Tensor) -> torch.Tensor: return gg * sw.unsqueeze(-1) -def _cal_hg( - g: torch.Tensor, - h: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, - smooth: bool = True, - epsilon: float = 1e-4, -) -> torch.Tensor: - """ - Calculate the transposed rotation matrix. - - Parameters - ---------- - g - Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng. - h - Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. - nlist_mask - Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. - sw - The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, - and remains 0 beyond rcut, with shape nf x nloc x nnei. - smooth - Whether to use smoothness in processes such as attention weights calculation. - epsilon - Protection of 1./nnei. - - Returns - ------- - hg - The transposed rotation matrix, with shape nf x nloc x 3 x ng. - """ - # g: nf x nloc x nnei x ng - # h: nf x nloc x nnei x 3 - # msk: nf x nloc x nnei - nf, nloc, nnei, _ = g.shape - ng = g.shape[-1] - # nf x nloc x nnei x ng - g = _apply_nlist_mask(g, nlist_mask) - if not smooth: - # nf x nloc - # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy - invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g), dim=-1)) - # nf x nloc x 1 x 1 - invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) - else: - g = _apply_switch(g, sw) - invnnei = (1.0 / float(nnei)) * torch.ones( - (nf, nloc, 1, 1), dtype=g.dtype, device=g.device - ) - # nf x nloc x 3 x ng - hg = torch.matmul(torch.transpose(h, -1, -2), g) * invnnei - return hg - - -def _cal_grrg(hg: torch.Tensor, axis_neuron: int) -> torch.Tensor: - """ - Calculate the atomic invariant rep. - - Parameters - ---------- - hg - The transposed rotation matrix, with shape nf x nloc x 3 x ng. - axis_neuron - Size of the submatrix. - - Returns - ------- - grrg - Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng) - """ - # nf x nloc x 3 x ng - nf, nloc, _, ng = hg.shape - # nf x nloc x 3 x axis - hgm = torch.split(hg, axis_neuron, dim=-1)[0] - # nf x nloc x axis_neuron x ng - grrg = torch.matmul(torch.transpose(hgm, -1, -2), hg) / (3.0**1) - # nf x nloc x (axis_neuron x ng) - grrg = grrg.view(nf, nloc, axis_neuron * ng) - return grrg - - -def symmetrization_op( - g: torch.Tensor, - h: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, - axis_neuron: int, - smooth: bool = True, - epsilon: float = 1e-4, -) -> torch.Tensor: - """ - Symmetrization operator to obtain atomic invariant rep. - - Parameters - ---------- - g - Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng. - h - Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. - nlist_mask - Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. - sw - The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, - and remains 0 beyond rcut, with shape nf x nloc x nnei. - axis_neuron - Size of the submatrix. - smooth - Whether to use smoothness in processes such as attention weights calculation. - epsilon - Protection of 1./nnei. - - Returns - ------- - grrg - Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng) - """ - # g: nf x nloc x nnei x ng - # h: nf x nloc x nnei x 3 - # msk: nf x nloc x nnei - nf, nloc, nnei, _ = g.shape - # nf x nloc x 3 x ng - hg = _cal_hg(g, h, nlist_mask, sw, smooth=smooth, epsilon=epsilon) - # nf x nloc x (axis_neuron x ng) - grrg = _cal_grrg(hg, axis_neuron) - return grrg - - class Atten2Map(torch.nn.Module): def __init__( self, @@ -845,6 +717,134 @@ def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: ret += g2d return ret + @staticmethod + def _cal_hg( + g: torch.Tensor, + h: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + smooth: bool = True, + epsilon: float = 1e-4, + ) -> torch.Tensor: + """ + Calculate the transposed rotation matrix. + + Parameters + ---------- + g + Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng. + h + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nf x nloc x nnei. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. + + Returns + ------- + hg + The transposed rotation matrix, with shape nf x nloc x 3 x ng. + """ + # g: nf x nloc x nnei x ng + # h: nf x nloc x nnei x 3 + # msk: nf x nloc x nnei + nf, nloc, nnei, _ = g.shape + ng = g.shape[-1] + # nf x nloc x nnei x ng + g = _apply_nlist_mask(g, nlist_mask) + if not smooth: + # nf x nloc + # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy + invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g), dim=-1)) + # nf x nloc x 1 x 1 + invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) + else: + g = _apply_switch(g, sw) + invnnei = (1.0 / float(nnei)) * torch.ones( + (nf, nloc, 1, 1), dtype=g.dtype, device=g.device + ) + # nf x nloc x 3 x ng + hg = torch.matmul(torch.transpose(h, -1, -2), g) * invnnei + return hg + + @staticmethod + def _cal_grrg(hg: torch.Tensor, axis_neuron: int) -> torch.Tensor: + """ + Calculate the atomic invariant rep. + + Parameters + ---------- + hg + The transposed rotation matrix, with shape nf x nloc x 3 x ng. + axis_neuron + Size of the submatrix. + + Returns + ------- + grrg + Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng) + """ + # nf x nloc x 3 x ng + nf, nloc, _, ng = hg.shape + # nf x nloc x 3 x axis + hgm = torch.split(hg, axis_neuron, dim=-1)[0] + # nf x nloc x axis_neuron x ng + grrg = torch.matmul(torch.transpose(hgm, -1, -2), hg) / (3.0**1) + # nf x nloc x (axis_neuron x ng) + grrg = grrg.view(nf, nloc, axis_neuron * ng) + return grrg + + def symmetrization_op( + self, + g: torch.Tensor, + h: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + axis_neuron: int, + smooth: bool = True, + epsilon: float = 1e-4, + ) -> torch.Tensor: + """ + Symmetrization operator to obtain atomic invariant rep. + + Parameters + ---------- + g + Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng. + h + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nf x nloc x nnei. + axis_neuron + Size of the submatrix. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. + + Returns + ------- + grrg + Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng) + """ + # g: nf x nloc x nnei x ng + # h: nf x nloc x nnei x 3 + # msk: nf x nloc x nnei + nf, nloc, nnei, _ = g.shape + # nf x nloc x 3 x ng + hg = self._cal_hg(g, h, nlist_mask, sw, smooth=smooth, epsilon=epsilon) + # nf x nloc x (axis_neuron x ng) + grrg = self._cal_grrg(hg, axis_neuron) + return grrg + def _update_h2( self, h2: torch.Tensor, @@ -1027,7 +1027,7 @@ def forward( if self.update_g1_has_grrg: g1_mlp.append( - symmetrization_op( + self.symmetrization_op( g2, h2, nlist_mask, @@ -1041,7 +1041,7 @@ def forward( if self.update_g1_has_drrd: assert gg1 is not None g1_mlp.append( - symmetrization_op( + self.symmetrization_op( gg1, h2, nlist_mask, diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index e352c6b40c..f03f15096e 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -40,7 +40,6 @@ from .repformer_layer import ( RepformerLayer, - _cal_hg, ) from .repformer_layer_old_impl import RepformerLayer as RepformerLayerOld @@ -486,7 +485,9 @@ def forward( ) # nb x nloc x 3 x ng2 - h2g2 = _cal_hg(g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon) + h2g2 = RepformerLayer._cal_hg( + g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon + ) # (nb x nloc) x ng2 x 3 rot_mat = torch.permute(h2g2, (0, 1, 3, 2))