Skip to content

Commit

Permalink
mv symmetrization_op into static
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed May 9, 2024
1 parent 9d0ad7f commit 375c03e
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 132 deletions.
260 changes: 130 additions & 130 deletions deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,134 +137,6 @@ def _apply_switch(gg: torch.Tensor, sw: torch.Tensor) -> torch.Tensor:
return gg * sw.unsqueeze(-1)


def _cal_hg(
g: torch.Tensor,
h: torch.Tensor,
nlist_mask: torch.Tensor,
sw: torch.Tensor,
smooth: bool = True,
epsilon: float = 1e-4,
) -> torch.Tensor:
"""
Calculate the transposed rotation matrix.
Parameters
----------
g
Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng.
h
Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3.
nlist_mask
Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei.
sw
The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut,
and remains 0 beyond rcut, with shape nf x nloc x nnei.
smooth
Whether to use smoothness in processes such as attention weights calculation.
epsilon
Protection of 1./nnei.
Returns
-------
hg
The transposed rotation matrix, with shape nf x nloc x 3 x ng.
"""
# g: nf x nloc x nnei x ng
# h: nf x nloc x nnei x 3
# msk: nf x nloc x nnei
nf, nloc, nnei, _ = g.shape
ng = g.shape[-1]
# nf x nloc x nnei x ng
g = _apply_nlist_mask(g, nlist_mask)
if not smooth:
# nf x nloc
# must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy
invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g), dim=-1))
# nf x nloc x 1 x 1
invnnei = invnnei.unsqueeze(-1).unsqueeze(-1)
else:
g = _apply_switch(g, sw)
invnnei = (1.0 / float(nnei)) * torch.ones(
(nf, nloc, 1, 1), dtype=g.dtype, device=g.device
)
# nf x nloc x 3 x ng
hg = torch.matmul(torch.transpose(h, -1, -2), g) * invnnei
return hg


def _cal_grrg(hg: torch.Tensor, axis_neuron: int) -> torch.Tensor:
"""
Calculate the atomic invariant rep.
Parameters
----------
hg
The transposed rotation matrix, with shape nf x nloc x 3 x ng.
axis_neuron
Size of the submatrix.
Returns
-------
grrg
Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng)
"""
# nf x nloc x 3 x ng
nf, nloc, _, ng = hg.shape
# nf x nloc x 3 x axis
hgm = torch.split(hg, axis_neuron, dim=-1)[0]
# nf x nloc x axis_neuron x ng
grrg = torch.matmul(torch.transpose(hgm, -1, -2), hg) / (3.0**1)
# nf x nloc x (axis_neuron x ng)
grrg = grrg.view(nf, nloc, axis_neuron * ng)
return grrg


def symmetrization_op(
g: torch.Tensor,
h: torch.Tensor,
nlist_mask: torch.Tensor,
sw: torch.Tensor,
axis_neuron: int,
smooth: bool = True,
epsilon: float = 1e-4,
) -> torch.Tensor:
"""
Symmetrization operator to obtain atomic invariant rep.
Parameters
----------
g
Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng.
h
Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3.
nlist_mask
Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei.
sw
The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut,
and remains 0 beyond rcut, with shape nf x nloc x nnei.
axis_neuron
Size of the submatrix.
smooth
Whether to use smoothness in processes such as attention weights calculation.
epsilon
Protection of 1./nnei.
Returns
-------
grrg
Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng)
"""
# g: nf x nloc x nnei x ng
# h: nf x nloc x nnei x 3
# msk: nf x nloc x nnei
nf, nloc, nnei, _ = g.shape
# nf x nloc x 3 x ng
hg = _cal_hg(g, h, nlist_mask, sw, smooth=smooth, epsilon=epsilon)
# nf x nloc x (axis_neuron x ng)
grrg = _cal_grrg(hg, axis_neuron)
return grrg


class Atten2Map(torch.nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -845,6 +717,134 @@ def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int:
ret += g2d
return ret

@staticmethod
def _cal_hg(
g: torch.Tensor,
h: torch.Tensor,
nlist_mask: torch.Tensor,
sw: torch.Tensor,
smooth: bool = True,
epsilon: float = 1e-4,
) -> torch.Tensor:
"""
Calculate the transposed rotation matrix.
Parameters
----------
g
Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng.
h
Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3.
nlist_mask
Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei.
sw
The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut,
and remains 0 beyond rcut, with shape nf x nloc x nnei.
smooth
Whether to use smoothness in processes such as attention weights calculation.
epsilon
Protection of 1./nnei.
Returns
-------
hg
The transposed rotation matrix, with shape nf x nloc x 3 x ng.
"""
# g: nf x nloc x nnei x ng
# h: nf x nloc x nnei x 3
# msk: nf x nloc x nnei
nf, nloc, nnei, _ = g.shape
ng = g.shape[-1]
# nf x nloc x nnei x ng
g = _apply_nlist_mask(g, nlist_mask)
if not smooth:
# nf x nloc
# must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy
invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g), dim=-1))
# nf x nloc x 1 x 1
invnnei = invnnei.unsqueeze(-1).unsqueeze(-1)
else:
g = _apply_switch(g, sw)
invnnei = (1.0 / float(nnei)) * torch.ones(
(nf, nloc, 1, 1), dtype=g.dtype, device=g.device
)
# nf x nloc x 3 x ng
hg = torch.matmul(torch.transpose(h, -1, -2), g) * invnnei
return hg

@staticmethod
def _cal_grrg(hg: torch.Tensor, axis_neuron: int) -> torch.Tensor:
"""
Calculate the atomic invariant rep.
Parameters
----------
hg
The transposed rotation matrix, with shape nf x nloc x 3 x ng.
axis_neuron
Size of the submatrix.
Returns
-------
grrg
Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng)
"""
# nf x nloc x 3 x ng
nf, nloc, _, ng = hg.shape
# nf x nloc x 3 x axis
hgm = torch.split(hg, axis_neuron, dim=-1)[0]
# nf x nloc x axis_neuron x ng
grrg = torch.matmul(torch.transpose(hgm, -1, -2), hg) / (3.0**1)
# nf x nloc x (axis_neuron x ng)
grrg = grrg.view(nf, nloc, axis_neuron * ng)
return grrg

def symmetrization_op(
self,
g: torch.Tensor,
h: torch.Tensor,
nlist_mask: torch.Tensor,
sw: torch.Tensor,
axis_neuron: int,
smooth: bool = True,
epsilon: float = 1e-4,
) -> torch.Tensor:
"""
Symmetrization operator to obtain atomic invariant rep.
Parameters
----------
g
Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng.
h
Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3.
nlist_mask
Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei.
sw
The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut,
and remains 0 beyond rcut, with shape nf x nloc x nnei.
axis_neuron
Size of the submatrix.
smooth
Whether to use smoothness in processes such as attention weights calculation.
epsilon
Protection of 1./nnei.
Returns
-------
grrg
Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng)
"""
# g: nf x nloc x nnei x ng
# h: nf x nloc x nnei x 3
# msk: nf x nloc x nnei
nf, nloc, nnei, _ = g.shape
# nf x nloc x 3 x ng
hg = self._cal_hg(g, h, nlist_mask, sw, smooth=smooth, epsilon=epsilon)
# nf x nloc x (axis_neuron x ng)
grrg = self._cal_grrg(hg, axis_neuron)
return grrg

def _update_h2(
self,
h2: torch.Tensor,
Expand Down Expand Up @@ -1027,7 +1027,7 @@ def forward(

if self.update_g1_has_grrg:
g1_mlp.append(
symmetrization_op(
self.symmetrization_op(
g2,
h2,
nlist_mask,
Expand All @@ -1041,7 +1041,7 @@ def forward(
if self.update_g1_has_drrd:
assert gg1 is not None
g1_mlp.append(
symmetrization_op(
self.symmetrization_op(
gg1,
h2,
nlist_mask,
Expand Down
5 changes: 3 additions & 2 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@

from .repformer_layer import (
RepformerLayer,
_cal_hg,
)
from .repformer_layer_old_impl import RepformerLayer as RepformerLayerOld

Expand Down Expand Up @@ -486,7 +485,9 @@ def forward(
)

# nb x nloc x 3 x ng2
h2g2 = _cal_hg(g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon)
h2g2 = RepformerLayer._cal_hg(
g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon
)
# (nb x nloc) x ng2 x 3
rot_mat = torch.permute(h2g2, (0, 1, 3, 2))

Expand Down

0 comments on commit 375c03e

Please sign in to comment.