
Commit

Use input embeddings instead of input ids when provided
2015aroras committed Feb 9, 2024
1 parent dc1545d · commit 75e1476
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions olmo/model.py
@@ -1177,21 +1177,21 @@ def forward(
         if past_key_values:
             assert len(past_key_values) == self.config.n_layers

-        batch_size, seq_len = input_ids.size()
+        batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
         if past_key_values is None:
             past_length = 0
         else:
             past_length = past_key_values[0][0].size(-2)

         # Get embeddings of input.
         # shape: (batch_size, seq_len, d_model)
-        x = self.transformer.wte(input_ids)  # type: ignore
+        x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings  # type: ignore

         if not (self.config.alibi or self.config.rope):
             # Get positional embeddings.
             # shape: (1, seq_len)
             pos = torch.arange(
-                past_length, past_length + seq_len, dtype=torch.long, device=input_ids.device
+                past_length, past_length + seq_len, dtype=torch.long, device=x.device
             ).unsqueeze(0)
             # shape: (1, seq_len, d_model)
             pos_emb = self.transformer.wpe(pos)  # type: ignore
@@ -1232,7 +1232,7 @@ def forward(
             if attention_mask is not None:
                 mask_len = attention_mask.shape[-1]
             elif past_key_values is not None:
-                mask_len = past_key_values[0][0].shape[-2] + input_ids.shape[-1]
+                mask_len = past_key_values[0][0].shape[-2] + seq_len
             attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)

             # Add in the masking bias.
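
For context, the sketch below shows how a caller might exercise this change by computing token embeddings up front and handing them to the model, so the forward pass uses them directly instead of looking up input_ids. Only transformer.wte and the input_embeddings name come from the diff above; the forward() signature is not shown here, so the keyword call and the forward_with_embeddings helper are assumptions, not the repository's documented API.

import torch

def forward_with_embeddings(model, input_ids: torch.LongTensor):
    # Hypothetical helper: run an OLMo-style model from precomputed embeddings.
    # Assumes forward() accepts an `input_embeddings` keyword, as this commit
    # enables; the exact signature is not part of the diff above.

    # Look up token embeddings with the same table the model uses internally.
    # shape: (batch_size, seq_len, d_model)
    input_embeddings = model.transformer.wte(input_ids)

    # The embeddings could be modified here before the transformer runs
    # (e.g. injecting soft-prompt vectors); they are left unchanged in this sketch.

    # With input_embeddings provided, the diff makes the batch/seq shapes, the
    # embedding lookup, and the positional-embedding device come from the
    # embeddings (x) rather than from input_ids.
    return model(input_ids=input_ids, input_embeddings=input_embeddings)

Per the diff, input_ids is ignored for the embedding lookup and the shape/device computations whenever input_embeddings is not None, so passing it here matters only if other parts of forward() still require it.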
