Support passing the ignore_index in chunked_cross_entropy (#904)
awaelchli authored Jan 29, 2024
1 parent 0e085e5 commit 00defde
Showing 2 changed files with 21 additions and 16 deletions.
17 changes: 10 additions & 7 deletions lit_gpt/utils.py
@@ -227,7 +227,10 @@ def __exit__(self, type, value, traceback):
 
 
 def chunked_cross_entropy(
-    logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128
+    logits: Union[torch.Tensor, List[torch.Tensor]],
+    targets: torch.Tensor,
+    chunk_size: int = 128,
+    ignore_index: int = -1,
 ) -> torch.Tensor:
     # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
     # the memory usage in fine-tuning settings with low number of parameters.
@@ -241,32 +244,32 @@ def chunked_cross_entropy(
         # don't want to chunk cross entropy
         if chunk_size == 0:
             logits = torch.cat(logits, dim=1)
             logits = logits.reshape(-1, logits.size(-1))
             targets = targets.reshape(-1)
-            return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
+            return torch.nn.functional.cross_entropy(logits, targets, ignore_index=ignore_index)
 
         # chunk cross entropy
         logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]
         target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)]
         loss_chunks = [
-            torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
+            torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none")
             for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
         ]
-        non_masked_elems = (targets != -1).sum()
+        non_masked_elems = (targets != ignore_index).sum()
         return torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
 
     # no chunking at all
     logits = logits.reshape(-1, logits.size(-1))
     targets = targets.reshape(-1)
     if chunk_size == 0:
-        return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
+        return torch.nn.functional.cross_entropy(logits, targets, ignore_index=ignore_index)
 
     # lm_head wasn't chunked, chunk cross entropy
     logit_chunks = logits.split(chunk_size)
     target_chunks = targets.split(chunk_size)
     loss_chunks = [
-        torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
+        torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none")
         for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
     ]
-    non_masked_elems = (targets != -1).sum()
+    non_masked_elems = (targets != ignore_index).sum()
    return torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
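
A minimal usage sketch of the new parameter (my addition, not part of the commit). It mirrors the updated test below: mask a few target positions with -100, then check that the chunked loss matches an unchunked baseline.

    import torch
    import torch.nn.functional as F

    from lit_gpt.utils import chunked_cross_entropy

    B, T, V = 2, 25, 50
    logits = torch.randn(B, T, V)
    targets = torch.randint(0, V, (B, T))
    targets[:, :5] = -100  # mask the first five positions of every sequence

    # The chunked loss should match plain cross entropy over the kept positions.
    loss = chunked_cross_entropy(logits, targets, chunk_size=10, ignore_index=-100)
    baseline = F.cross_entropy(logits.reshape(-1, V), targets.reshape(-1), ignore_index=-100)
    torch.testing.assert_close(loss, baseline)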


20 changes: 11 additions & 9 deletions tests/test_utils.py
@@ -96,39 +96,41 @@ def test_incremental_write(tmp_path):
 
 
 @pytest.mark.parametrize("B", (1, 2))
-@pytest.mark.parametrize("with_ignore_index", (True, False))
-def test_chunked_cross_entropy(with_ignore_index, B):
+@pytest.mark.parametrize("ignore_index", (None, -1, -2, -100))
+def test_chunked_cross_entropy(ignore_index, B):
     from lit_gpt.utils import chunked_cross_entropy
 
     V = 50
     T = 25
     regular_logits = torch.randn(B, T, V)
     targets = torch.randint(0, V, (B, T))
 
-    if with_ignore_index:
-        targets[:, [1, 4, 10, 19]] = -1
+    if ignore_index is not None:
+        targets[:, [1, 4, 10, 19]] = ignore_index
 
     baseline_loss = F.cross_entropy(
         regular_logits.reshape(-1, regular_logits.size(-1)),
         targets.reshape(-1),
-        ignore_index=(-1 if with_ignore_index else -100),
+        ignore_index=(ignore_index if ignore_index is not None else -100),
     )
-    regular_loss = chunked_cross_entropy(regular_logits, targets, chunk_size=0)
+
+    ignore_index = ignore_index if ignore_index is not None else -1
+    regular_loss = chunked_cross_entropy(regular_logits, targets, chunk_size=0, ignore_index=ignore_index)
     assert torch.equal(baseline_loss, regular_loss)
     assert regular_loss.numel() == 1
 
-    chunked_loss = chunked_cross_entropy(regular_logits, targets, chunk_size=10)
+    chunked_loss = chunked_cross_entropy(regular_logits, targets, chunk_size=10, ignore_index=ignore_index)
     torch.testing.assert_close(chunked_loss, regular_loss)
     torch.testing.assert_close(chunked_loss, baseline_loss)
 
     logit_chunk_size = 6
     assert T % logit_chunk_size != 0  # ensure leftover
     chunked_logits = list(regular_logits.split(logit_chunk_size, dim=1))
-    chunked_loss = chunked_cross_entropy(chunked_logits, targets, chunk_size=0)
+    chunked_loss = chunked_cross_entropy(chunked_logits, targets, chunk_size=0, ignore_index=ignore_index)
     torch.testing.assert_close(chunked_loss, regular_loss)
     torch.testing.assert_close(chunked_loss, baseline_loss)
 
-    chunked_loss = chunked_cross_entropy(chunked_logits, targets, chunk_size=10)
+    chunked_loss = chunked_cross_entropy(chunked_logits, targets, chunk_size=10, ignore_index=ignore_index)
     torch.testing.assert_close(chunked_loss, regular_loss)
     torch.testing.assert_close(chunked_loss, baseline_loss)
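
One behavior worth noting (my observation, not stated in the commit): because ignored positions contribute a zero loss under reduction="none", the max(1, non_masked_elems) guard makes an all-masked batch return 0.0 where a plain mean-reduced cross entropy returns NaN. A small sketch, under the same assumptions as above:

    import torch
    import torch.nn.functional as F

    from lit_gpt.utils import chunked_cross_entropy

    logits = torch.randn(1, 8, 50)
    targets = torch.full((1, 8), -100)  # every position masked

    print(chunked_cross_entropy(logits, targets, ignore_index=-100))  # tensor(0.)
    print(F.cross_entropy(logits.reshape(-1, 50), targets.reshape(-1), ignore_index=-100))  # tensor(nan)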

