diff --git a/.github/workflows/cpu-tests.yml b/.github/workflows/cpu-tests.yml index ce7016c672..5f43634679 100644 --- a/.github/workflows/cpu-tests.yml +++ b/.github/workflows/cpu-tests.yml @@ -37,7 +37,7 @@ jobs: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/litgpt/adapter_v2.py b/litgpt/adapter_v2.py index a77f2c94d5..c43120a974 100644 --- a/litgpt/adapter_v2.py +++ b/litgpt/adapter_v2.py @@ -9,7 +9,7 @@ """ from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -80,6 +80,33 @@ def __init__(self, config: Config) -> None: self.max_seq_length = self.config.block_size self.mask_cache: Optional[torch.Tensor] = None + def forward( + self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0 + ) -> Union[torch.Tensor, List[torch.Tensor]]: + T = idx.size(1) + if self.max_seq_length < T: + raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") + + if input_pos is not None: # use the kv cache + cos = self.cos.index_select(0, input_pos) + sin = self.sin.index_select(0, input_pos) + if self.mask_cache is None: + raise TypeError("You need to call `gpt.set_kv_cache()`") + mask = self.mask_cache.index_select(2, input_pos) + else: + cos = self.cos[:T] + sin = self.sin[:T] + mask = None + + x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + for block in self.transformer.h: + x = block(x, cos, sin, mask, input_pos) + x = self.transformer.ln_f(x) + if lm_head_chunk_size > 0: + # chunk the lm head logits to reduce the peak memory used by autograd + return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)] + return self.lm_head(x) # (b, t, vocab_size) + @classmethod def from_name(cls, name: str, **kwargs: Any) -> Self: return cls(Config.from_name(name, **kwargs))