Chunked CrossEntropyLoss for AdapterV2 (#1194)
Andrei-Aksionov authored Mar 26, 2024
1 parent bbae7af commit 8bb5e7b
Showing 2 changed files with 29 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cpu-tests.yml
@@ -37,7 +37,7 @@ jobs:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
-       uses: actions/setup-python@v4
+       uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
29 changes: 28 additions & 1 deletion litgpt/adapter_v2.py
@@ -9,7 +9,7 @@
"""

from dataclasses import dataclass
-from typing import Any, Dict, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
@@ -80,6 +80,33 @@ def __init__(self, config: Config) -> None:
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None

+    def forward(
+        self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        T = idx.size(1)
+        if self.max_seq_length < T:
+            raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
+
+        if input_pos is not None:  # use the kv cache
+            cos = self.cos.index_select(0, input_pos)
+            sin = self.sin.index_select(0, input_pos)
+            if self.mask_cache is None:
+                raise TypeError("You need to call `gpt.set_kv_cache()`")
+            mask = self.mask_cache.index_select(2, input_pos)
+        else:
+            cos = self.cos[:T]
+            sin = self.sin[:T]
+            mask = None
+
+        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
+        for block in self.transformer.h:
+            x = block(x, cos, sin, mask, input_pos)
+        x = self.transformer.ln_f(x)
+        if lm_head_chunk_size > 0:
+            # chunk the lm head logits to reduce the peak memory used by autograd
+            return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
+        return self.lm_head(x)  # (b, t, vocab_size)

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        return cls(Config.from_name(name, **kwargs))
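The returned list of per-chunk logits pairs with a chunked cross-entropy on the loss side: computing the loss one chunk at a time keeps autograd from materializing the full (b, t, vocab_size) log-softmax in one piece, which is the peak-memory saving the inline comment refers to. The sketch below illustrates the idea under the assumption that the targets are split along the sequence dimension with the same chunk size; it is not the exact utility the repository ships (litgpt has its own `chunked_cross_entropy` helper), and all names here are illustrative.

# A minimal sketch of a chunked cross-entropy, assuming targets are split
# with the same chunk size as the logits (illustrative, not litgpt's exact helper).
from typing import List

import torch
import torch.nn.functional as F

def chunked_cross_entropy_sketch(
    logit_chunks: List[torch.Tensor],
    target_chunks: List[torch.Tensor],
    ignore_index: int = -100,
) -> torch.Tensor:
    # Reduce each chunk to a summed loss so the large (b * chunk_len, vocab_size)
    # log-softmax intermediates are materialized one chunk at a time.
    chunk_losses = [
        F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),  # (b * chunk_len, vocab_size)
            targets.reshape(-1),                  # (b * chunk_len,)
            ignore_index=ignore_index,
            reduction="sum",
        )
        for logits, targets in zip(logit_chunks, target_chunks)
    ]
    # Average over the non-ignored target positions across all chunks.
    n_valid = sum((targets != ignore_index).sum() for targets in target_chunks)
    return torch.stack(chunk_losses).sum() / n_valid

A hypothetical training step would then consume the chunked forward output directly (`model`, `input_ids`, and `targets` are assumed; for a causal LM, `targets` is `input_ids` shifted one position left):

logit_chunks = model(input_ids, lm_head_chunk_size=128)  # List[torch.Tensor]
target_chunks = list(targets.split(128, dim=1))
loss = chunked_cross_entropy_sketch(logit_chunks, target_chunks)
loss.backward()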
