diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 6c29636d6d..8a211c977d 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -472,23 +472,34 @@ def forward( if self.type_one_side: ii = embedding_idx # torch.jit is not happy with slice(None) - ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device) + # ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device) + # applying a mask seems to cause performance degradation + ti_mask = None else: # ti: center atom type, ii: neighbor type... ii = embedding_idx // self.ntypes ti = embedding_idx % self.ntypes ti_mask = atype.ravel().eq(ti) # nfnl x nt - mm = exclude_mask[ti_mask, self.sec[ii] : self.sec[ii + 1]] + if ti_mask is not None: + mm = exclude_mask[ti_mask, self.sec[ii] : self.sec[ii + 1]] + else: + mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]] # nfnl x nt x 4 - rr = dmatrix[ti_mask, self.sec[ii] : self.sec[ii + 1], :] + if ti_mask is not None: + rr = dmatrix[ti_mask, self.sec[ii] : self.sec[ii + 1], :] + else: + rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :] rr = rr * mm[:, :, None] ss = rr[:, :, :1] # nfnl x nt x ng gg = ll.forward(ss) # nfnl x 4 x ng gr = torch.matmul(rr.permute(0, 2, 1), gg) - xyz_scatter[ti_mask] += gr + if ti_mask is not None: + xyz_scatter[ti_mask] += gr + else: + xyz_scatter += gr xyz_scatter /= self.nnei xyz_scatter_1 = xyz_scatter.permute(0, 2, 1)