From 7b6add579c2b0d149c8e6cb0c50fd4488bc7983d Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Wed, 1 Nov 2023 00:19:54 -0700 Subject: [PATCH] Revert "Try to bring back compile" This reverts commit f3491db93856de6b4c63e385e39cdf580f887290. --- olmo/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/olmo/model.py b/olmo/model.py index 724b1e35a..ff4a0ec3c 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -12,6 +12,7 @@ from collections.abc import MutableMapping from functools import partial from typing import ( + Callable, Dict, Iterable, List, @@ -838,6 +839,7 @@ def __init__(self, config: ModelConfig, init_params: bool = True): ) self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None + self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config) if not ( 0 < self.config.block_group_size <= self.config.n_layers @@ -1006,8 +1008,6 @@ def forward( if past_key_values: assert len(past_key_values) == self.config.n_layers - checkpoint = activation_checkpoint_function(self.config) - batch_size, seq_len = input_ids.size() if past_key_values is None: past_length = 0 @@ -1096,7 +1096,7 @@ def forward( ) ): # shape: (batch_size, seq_len, d_model) - x, cache = checkpoint( + x, cache = self._activation_checkpoint_fn( block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache ) else: