From dc1545d48301155bd3a25fec020873f8b68073da Mon Sep 17 00:00:00 2001
From: Shane A
Date: Fri, 9 Feb 2024 14:41:23 -0800
Subject: [PATCH 1/5] Pass input embeddings from HF OLMo to inner model forward

---
 hf_olmo/modeling_olmo.py | 2 ++
 olmo/model.py            | 3 +++
 2 files changed, 5 insertions(+)

diff --git a/hf_olmo/modeling_olmo.py b/hf_olmo/modeling_olmo.py
index 8856be8ad..2c32f43c8 100644
--- a/hf_olmo/modeling_olmo.py
+++ b/hf_olmo/modeling_olmo.py
@@ -48,6 +48,7 @@ def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None, init_params
     def forward(
         self,
         input_ids: torch.LongTensor = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -64,6 +65,7 @@ def forward(
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model.forward(
             input_ids=input_ids,
+            input_embeddings=inputs_embeds,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
             use_cache=use_cache,
diff --git a/olmo/model.py b/olmo/model.py
index cc621a37b..04a4764e0 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -1137,6 +1137,7 @@ def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.
     def forward(
         self,
         input_ids: torch.LongTensor,
+        input_embeddings: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         attention_bias: Optional[torch.Tensor] = None,
         past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
@@ -1145,6 +1146,8 @@ def forward(
     ) -> OlmoOutput:
         """
         :param input_ids: A tensor of shape `(batch_size, seq_len)`.
+        :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
+            embeddings. When provided, it is treated as the output of the input embedding layer.
         :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
             which input IDs are masked. A `1` value in the mask means that
             the corresponding input ID should *not* be ignored. A `0` means

From 75e1476c790af8998a82cbb5e44d5c64be40113c Mon Sep 17 00:00:00 2001
From: Shane A
Date: Fri, 9 Feb 2024 14:42:50 -0800
Subject: [PATCH 2/5] Use input embeddings instead of input ids when provided

---
 olmo/model.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/olmo/model.py b/olmo/model.py
index 04a4764e0..3d8c24d77 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -1177,7 +1177,7 @@ def forward(
         if past_key_values:
             assert len(past_key_values) == self.config.n_layers

-        batch_size, seq_len = input_ids.size()
+        batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
         if past_key_values is None:
             past_length = 0
         else:
@@ -1185,13 +1185,13 @@ def forward(

         # Get embeddings of input.
         # shape: (batch_size, seq_len, d_model)
-        x = self.transformer.wte(input_ids)  # type: ignore
+        x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings  # type: ignore

         if not (self.config.alibi or self.config.rope):
             # Get positional embeddings.
             # shape: (1, seq_len)
             pos = torch.arange(
-                past_length, past_length + seq_len, dtype=torch.long, device=input_ids.device
+                past_length, past_length + seq_len, dtype=torch.long, device=x.device
             ).unsqueeze(0)
             # shape: (1, seq_len, d_model)
             pos_emb = self.transformer.wpe(pos)  # type: ignore
@@ -1232,7 +1232,7 @@ def forward(
             if attention_mask is not None:
                 mask_len = attention_mask.shape[-1]
             elif past_key_values is not None:
-                mask_len = past_key_values[0][0].shape[-2] + input_ids.shape[-1]
+                mask_len = past_key_values[0][0].shape[-2] + seq_len
             attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)

             # Add in the masking bias.

From 3d537587973bceeb8f745bcb43078d3fbd3d4772 Mon Sep 17 00:00:00 2001
From: Shane A
Date: Fri, 9 Feb 2024 14:50:21 -0800
Subject: [PATCH 3/5] Run Ruff

---
 olmo/model.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/olmo/model.py b/olmo/model.py
index 3d8c24d77..6d2d7dee4 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -1190,9 +1190,7 @@ def forward(
         if not (self.config.alibi or self.config.rope):
             # Get positional embeddings.
             # shape: (1, seq_len)
-            pos = torch.arange(
-                past_length, past_length + seq_len, dtype=torch.long, device=x.device
-            ).unsqueeze(0)
+            pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
             # shape: (1, seq_len, d_model)
             pos_emb = self.transformer.wpe(pos)  # type: ignore
             x = pos_emb + x
@@ -1473,7 +1471,7 @@ def generate(
         tokens_generated = 0

         def flatten_past_key_values(
-            past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]
+            past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
         ) -> Dict[str, torch.Tensor]:
             out = {}
             for i, (key, value) in enumerate(past_key_values):
@@ -1482,7 +1480,7 @@ def flatten_past_key_values(
             return out

         def unflatten_past_key_values(
-            past_key_values: Dict[str, torch.Tensor]
+            past_key_values: Dict[str, torch.Tensor],
         ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
             out = []
             for i in range(self.config.n_layers):

From faccd949e5b725c201c7266520a6044f0c3142a3 Mon Sep 17 00:00:00 2001
From: Shane A
Date: Fri, 9 Feb 2024 14:53:11 -0800
Subject: [PATCH 4/5] Update Changelog

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ba363cd07..a1c89f419 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Fixed

 - Fixed default value of `--tokenizer` argument to `scripts/prepare_tulu_data.py` to be an absolute path, not relative path, the script can be run from other directories.
+- Added the option to directly pass input embeddings to `OLMo` and `OLMoForCausalLM`.

 ## [v0.2.4](https://github.com/allenai/OLMo/releases/tag/v0.2.4) - 2024-02-02

From d9c09937e4ad3daef97d1cfdaf2bd948d9316296 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Sun, 11 Feb 2024 12:01:55 -0800
Subject: [PATCH 5/5] Require Python>=3.9 for now

See #446.
---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index db9af8201..ba451325f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ description = "Open Language Model (OLMo)"
 authors = [
     { name = "Allen Institute for Artificial Intelligence", email = "olmo@allenai.org" }
 ]
-requires-python = ">=3.8"
+requires-python = ">=3.9"
 license = { file = "LICENSE" }
 dependencies = [
     "numpy",
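
The patches above add an embedding pass-through: `OLMoForCausalLM.forward` accepts `inputs_embeds` and hands it to `Olmo.forward` as `input_embeddings`, where it replaces the `wte` token-embedding lookup. The Python sketch below shows roughly how that path could be exercised once the patches are applied. It is illustrative only and not part of the patch series: the checkpoint path is a placeholder, and it assumes the repo's `hf_olmo` package is importable.

import torch
from hf_olmo import OLMoForCausalLM  # assumes the repo's hf_olmo package is on the path

# Placeholder path; substitute a real OLMo checkpoint directory or model ID.
model = OLMoForCausalLM.from_pretrained("path/to/olmo-checkpoint")
model.eval()

input_ids = torch.tensor([[101, 2023, 2003, 1037]])  # arbitrary in-vocabulary token IDs

with torch.no_grad():
    # Build embeddings manually from the inner model's token embedding table (wte) ...
    inputs_embeds = model.model.transformer.wte(input_ids)
    # ... and pass them instead of input_ids. With the patches applied, the wrapper
    # forwards them as `input_embeddings`, so the wte lookup inside Olmo.forward is skipped.
    outputs = model(inputs_embeds=inputs_embeds)

print(outputs.logits.shape)  # (batch_size, seq_len, vocab_size)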