pt: fix se_a type_one_side performance degradation (#3361)
The code in this PR is ugly, but applying a mask causes a performance
degradation of ~3 ms/step.

When applying a mask, `aten::nonzero` has a high host time, as it causes
host-device synchronization:

![image](https://github.com/deepmodeling/deepmd-kit/assets/9496702/86b3518c-206d-410d-928e-2f605746147c)

After fixing:

![image](https://github.com/deepmodeling/deepmd-kit/assets/9496702/af9e86fa-7908-4bbb-ace7-58b4602e167f)

See pytorch/pytorch#12461 for more
information.
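
A minimal sketch of the difference, assuming small made-up shapes (`nfnl = 6`, `nnei = 4`) and a hypothetical column range `1:3` in place of `self.sec[ii] : self.sec[ii + 1]`. It only shows that an all-True boolean mask and a plain slice select the same values; the timing difference itself needs a profiler on a CUDA device to observe.

```python
import torch

# Illustrative sizes only; not taken from the model.
nfnl, nnei = 6, 4
dmatrix = torch.arange(nfnl * nnei * 4, dtype=torch.float32).reshape(nfnl, nnei, 4)

# Old type_one_side path: an all-True boolean mask over the first axis.
ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device)

# Boolean-mask indexing has a data-dependent output shape, so the mask is
# first converted to indices (the aten::nonzero call seen in the profile).
masked = dmatrix[ti_mask, 1:3, :]

# The same selection written with explicit indices from nonzero.
row_idx = ti_mask.nonzero(as_tuple=True)[0]
explicit = dmatrix[row_idx, 1:3, :]

# New path: a plain slice, whose output shape is known without reading the data.
sliced = dmatrix[:, 1:3, :]

same_values = torch.equal(masked, sliced)
print(same_values, tuple(masked.shape))
```

Since the mask is all True in the `type_one_side` case, dropping it changes no results, only how the rows are selected.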

Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz authored Feb 29, 2024
1 parent 2bee22c commit 48c8818
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions deepmd/pt/model/descriptor/se_a.py
@@ -472,23 +472,34 @@ def forward(
             if self.type_one_side:
                 ii = embedding_idx
                 # torch.jit is not happy with slice(None)
-                ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device)
+                # ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device)
+                # applying a mask seems to cause performance degradation
+                ti_mask = None
             else:
                 # ti: center atom type, ii: neighbor type...
                 ii = embedding_idx // self.ntypes
                 ti = embedding_idx % self.ntypes
                 ti_mask = atype.ravel().eq(ti)
             # nfnl x nt
-            mm = exclude_mask[ti_mask, self.sec[ii] : self.sec[ii + 1]]
+            if ti_mask is not None:
+                mm = exclude_mask[ti_mask, self.sec[ii] : self.sec[ii + 1]]
+            else:
+                mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]]
             # nfnl x nt x 4
-            rr = dmatrix[ti_mask, self.sec[ii] : self.sec[ii + 1], :]
+            if ti_mask is not None:
+                rr = dmatrix[ti_mask, self.sec[ii] : self.sec[ii + 1], :]
+            else:
+                rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :]
             rr = rr * mm[:, :, None]
             ss = rr[:, :, :1]
             # nfnl x nt x ng
             gg = ll.forward(ss)
             # nfnl x 4 x ng
             gr = torch.matmul(rr.permute(0, 2, 1), gg)
-            xyz_scatter[ti_mask] += gr
+            if ti_mask is not None:
+                xyz_scatter[ti_mask] += gr
+            else:
+                xyz_scatter += gr

         xyz_scatter /= self.nnei
         xyz_scatter_1 = xyz_scatter.permute(0, 2, 1)
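
The three `if ti_mask is not None` branches in this change repeat one pattern: index with the mask when there is one, otherwise take every row. As a rough sketch only, the pattern could be factored into two helpers. `select_rows` and `accumulate` are hypothetical names that do not exist in deepmd-kit, the shapes are made up, and whether such helpers would satisfy `torch.jit` (the reason the commit avoids `slice(None)`) is not checked here.

```python
from typing import Optional

import torch


def select_rows(
    x: torch.Tensor, mask: Optional[torch.Tensor], start: int, stop: int
) -> torch.Tensor:
    """Hypothetical helper: take columns start:stop from masked rows, or all rows."""
    if mask is not None:
        return x[mask, start:stop]
    return x[:, start:stop]


def accumulate(
    out: torch.Tensor, mask: Optional[torch.Tensor], update: torch.Tensor
) -> torch.Tensor:
    """Hypothetical helper: add update into masked rows of out, or into all rows."""
    if mask is not None:
        out[mask] += update
    else:
        out += update
    return out


# Illustrative sizes only.
nfnl, nnei, ng = 5, 3, 2
dmatrix = torch.randn(nfnl, nnei, 4)
atype = torch.tensor([0, 1, 0, 1, 1])

# type_one_side: no mask, every row is used.
rr_all = select_rows(dmatrix, None, 0, 2)
out_all = accumulate(torch.zeros(nfnl, 4, ng), None, torch.ones(nfnl, 4, ng))

# Otherwise: only rows whose center atom type matches (three rows here).
mask = atype.eq(1)
rr_sel = select_rows(dmatrix, mask, 0, 2)
out_sel = accumulate(torch.zeros(nfnl, 4, ng), mask, torch.ones(3, 4, ng))
```

The commit keeps the branches inline, which leaves the diff small and every indexing expression visible next to its shape comment.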