Skip to content

Commit

Permalink
Revert "Try to bring back compile"
Browse files Browse the repository at this point in the history
This reverts commit f3491db.
  • Loading branch information
dirkgr committed Nov 1, 2023
1 parent f52220f commit 7b6add5
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions olmo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from collections.abc import MutableMapping
from functools import partial
from typing import (
Callable,
Dict,
Iterable,
List,
Expand Down Expand Up @@ -838,6 +839,7 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
)

self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)

if not (
0 < self.config.block_group_size <= self.config.n_layers
Expand Down Expand Up @@ -1006,8 +1008,6 @@ def forward(
if past_key_values:
assert len(past_key_values) == self.config.n_layers

checkpoint = activation_checkpoint_function(self.config)

batch_size, seq_len = input_ids.size()
if past_key_values is None:
past_length = 0
Expand Down Expand Up @@ -1096,7 +1096,7 @@ def forward(
)
):
# shape: (batch_size, seq_len, d_model)
x, cache = checkpoint(
x, cache = self._activation_checkpoint_fn(
block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
)
else:
Expand Down

0 comments on commit 7b6add5

Please sign in to comment.