Skip to content

Commit

Permalink
Update repformer_layer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed May 9, 2024
1 parent f17f40f commit a8c89dc
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,11 +859,11 @@ def _cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor:
nb, nloc, _, ng2 = h2g2.shape
# nb x nloc x 3 x axis
h2g2m = torch.split(h2g2, axis_neuron, dim=-1)[0]
# nb x nloc x axis_neuron x ng
grrg = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1)
# nb x nloc x (axis_neuron x ng)
grrg = grrg.view(nb, nloc, axis_neuron * ng2)
return grrg
# nb x nloc x axis x ng2
g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1)
# nb x nloc x (axisxng2)
g1_13 = g1_13.view(nb, nloc, axis_neuron * ng2)
return g1_13

def symmetrization_op(
self,
Expand Down Expand Up @@ -905,11 +905,11 @@ def symmetrization_op(
# h2: nb x nloc x nnei x 3
# msk: nb x nloc x nnei
nb, nloc, nnei, _ = g2.shape
# nb x nloc x 3 x ng
# nb x nloc x 3 x ng2
h2g2 = self._cal_hg(g2, h2, nlist_mask, sw, smooth=smooth, epsilon=epsilon)
# nb x nloc x (axis_neuron x ng2)
grrg = self._cal_grrg(h2g2, axis_neuron)
return grrg
# nb x nloc x (axisxng2)
g1_13 = self._cal_grrg(h2g2, axis_neuron)
return g1_13

def _update_g2_g1g1(
self,
Expand Down

0 comments on commit a8c89dc

Please sign in to comment.