Skip to content

Commit

Permalink
Perf: replace unnecessary torch.split with indexing (#4505)
Browse files Browse the repository at this point in the history
Some operations only use the first segment of the result tensor of
`torch.split`. In this case, all the other segments are created and
discarded. This slightly adds an overhead to the training process.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Bug Fixes**
- Simplified tensor slicing operations in the `RepformerLayer` class and
the `nlist_distinguish_types` function, enhancing readability and
performance.
  
- **Documentation**
- Updated comments for clarity regarding tensor shapes in the
`RepformerLayer` class.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
caic99 authored Dec 25, 2024
1 parent beeb3d9 commit 3cecca4
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,7 @@ def _cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor:
# nb x nloc x 3 x ng2
nb, nloc, _, ng2 = h2g2.shape
# nb x nloc x 3 x axis
h2g2m = torch.split(h2g2, axis_neuron, dim=-1)[0]
h2g2m = h2g2[..., :axis_neuron]
# nb x nloc x axis x ng2
g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1)
# nb x nloc x (axisxng2)
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def nlist_distinguish_types(
inlist = torch.gather(nlist, 2, imap)
inlist = inlist.masked_fill(~(pick_mask.to(torch.bool)), -1)
# nloc x nsel[ii]
ret_nlist.append(torch.split(inlist, [ss, snsel - ss], dim=-1)[0])
ret_nlist.append(inlist[..., :ss])
return torch.concat(ret_nlist, dim=-1)


Expand Down

0 comments on commit 3cecca4

Please sign in to comment.