diff --git a/olmo/model.py b/olmo/model.py
index 04a4764e0..3d8c24d77 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -1177,7 +1177,7 @@ def forward(
         if past_key_values:
             assert len(past_key_values) == self.config.n_layers
 
-        batch_size, seq_len = input_ids.size()
+        batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
         if past_key_values is None:
             past_length = 0
         else:
@@ -1185,13 +1185,13 @@ def forward(
 
         # Get embeddings of input.
         # shape: (batch_size, seq_len, d_model)
-        x = self.transformer.wte(input_ids)  # type: ignore
+        x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings  # type: ignore
 
         if not (self.config.alibi or self.config.rope):
             # Get positional embeddings.
             # shape: (1, seq_len)
             pos = torch.arange(
-                past_length, past_length + seq_len, dtype=torch.long, device=input_ids.device
+                past_length, past_length + seq_len, dtype=torch.long, device=x.device
             ).unsqueeze(0)
             # shape: (1, seq_len, d_model)
             pos_emb = self.transformer.wpe(pos)  # type: ignore
@@ -1232,7 +1232,7 @@ def forward(
             if attention_mask is not None:
                 mask_len = attention_mask.shape[-1]
             elif past_key_values is not None:
-                mask_len = past_key_values[0][0].shape[-2] + input_ids.shape[-1]
+                mask_len = past_key_values[0][0].shape[-2] + seq_len
             attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
 
             # Add in the masking bias.
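
For reference, a minimal usage sketch of the new `input_embeddings` path (not part of the diff): it assumes `model` is an already-instantiated OLMo model from `olmo/model.py`, that the embedding width matches `model.config.d_model`, and that the unchanged `forward` signature (not shown here) tolerates `input_ids=None` when embeddings are supplied.

import torch

# Hypothetical setup: a batch of precomputed embeddings (e.g. soft prompts)
# standing in for token ids; the last dimension must equal model.config.d_model.
batch_size, seq_len, d_model = 2, 16, model.config.d_model
embeddings = torch.randn(batch_size, seq_len, d_model)

# With the change above, the forward pass skips the wte lookup, derives
# batch_size/seq_len from the embeddings, and builds positional ids on the
# embeddings' device, so no input_ids tensor is needed for those steps.
output = model(input_ids=None, input_embeddings=embeddings)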