Skip to content

Commit

Permalink
Small change
Browse files Browse the repository at this point in the history
  • Loading branch information
mseeger committed Dec 7, 2024
1 parent bb86cb0 commit 075fa6f
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions litgpt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -
if input_pos.dim() > 2:
# otherwise, things go wrong in `apply_rope`
raise ValueError("input_pos must have 1 or 2 dimensions")
if input_pos.shape[-1] != T:
raise ValueError(f"input_pos.shape[-1] = {input_pos.shape[-1]} != {T} = idx.shape[1]")
cos = batched_index_select(self.cos, 0, input_pos)
sin = batched_index_select(self.sin, 0, input_pos)
if self.mask_cache is None:
Expand Down Expand Up @@ -496,10 +498,11 @@ def batched_index_select(t, dim, idx):
res = torch.index_select(t, dim, idx.reshape(-1)) # flat index
# split out single batch idx
res = res.view(*t.shape[:dim], -1, idx_size, *t.shape[dim + 1 :])
# move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors
dims = [dim] + list(range(res.dim()))
del dims[dim + 1]
res = res.permute(dims)
if dim > 0:
# move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors
dims = [dim] + list(range(res.dim()))
del dims[dim + 1]
res = res.permute(dims)
# unflatten batch dims
res = res.view(*batch_shape, *res.shape[1:])
return res
Expand Down

0 comments on commit 075fa6f

Please sign in to comment.