From 075fa6f0970dd0f8cecca3e006ce8583098a4279 Mon Sep 17 00:00:00 2001
From: Matthias Seeger
Date: Sat, 7 Dec 2024 22:25:45 +0100
Subject: [PATCH] Small change

---
 litgpt/model.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/litgpt/model.py b/litgpt/model.py
index 6e1e5d0c2f..cd5f1ab2b5 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -80,6 +80,8 @@ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
             if input_pos.dim() > 2:
                 # otherwise, things go wrong in `apply_rope`
                 raise ValueError("input_pos must have 1 or 2 dimensions")
+            if input_pos.shape[-1] != T:
+                raise ValueError(f"input_pos.shape[-1] = {input_pos.shape[-1]} != {T} = idx.shape[1]")
             cos = batched_index_select(self.cos, 0, input_pos)
             sin = batched_index_select(self.sin, 0, input_pos)
             if self.mask_cache is None:
@@ -496,10 +498,11 @@ def batched_index_select(t, dim, idx):
     res = torch.index_select(t, dim, idx.reshape(-1))  # flat index
     # split out single batch idx
     res = res.view(*t.shape[:dim], -1, idx_size, *t.shape[dim + 1 :])
-    # move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors
-    dims = [dim] + list(range(res.dim()))
-    del dims[dim + 1]
-    res = res.permute(dims)
+    if dim > 0:
+        # move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors
+        dims = [dim] + list(range(res.dim()))
+        del dims[dim + 1]
+        res = res.permute(dims)
     # unflatten batch dims
     res = res.view(*batch_shape, *res.shape[1:])
     return res
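
For context, a minimal sketch of what the two changes above do, not part of the patch itself. When dim == 0, the old code computed dims = [0] + [0, 1, 2, ...] and then deleted dims[1], leaving the identity permutation, so the new `if dim > 0` guard skips a no-op permute rather than changing behavior; the new check in forward rejects an input_pos whose trailing dimension disagrees with idx.shape[1] before it reaches the RoPE lookups. The sketch assumes litgpt is installed so batched_index_select can be imported from litgpt.model; the tensor values are illustrative only.

    import torch
    from litgpt.model import batched_index_select

    # RoPE-style cache of shape (max_seq_len, head_dim)
    cos = torch.arange(40.0).view(10, 4)

    # per-sample positions for a batch of 2 sequences of length T = 3,
    # as passed for self.cos / self.sin in forward
    input_pos = torch.tensor([[0, 1, 2], [3, 4, 5]])

    # batched select along dim 0: one (T, head_dim) slice per batch entry;
    # rolling axis 0 to the front is the identity, so no permute is needed
    out = batched_index_select(cos, 0, input_pos)
    print(out.shape)  # torch.Size([2, 3, 4])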