Bugfix: Corrected loss computation with chunksize > 0 (#751)
Co-authored-by: Adrian Wälchli <[email protected]>
codiceSpaghetti and awaelchli authored Nov 20, 2023
1 parent 759fcc6 commit 7b867e9
Showing 2 changed files with 16 additions and 4 deletions.
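The bug: both chunked code paths of chunked_cross_entropy averaged the per-element losses with .mean(). Since cross_entropy with reduction="none" assigns masked positions (target == ignore_index, here -1) a loss of exactly 0, those positions still inflated the denominator, biasing the loss low whenever any target was masked. The fix divides the summed loss by the count of non-masked targets, matching the behavior of reduction="mean". A standalone repro sketch follows the lit_gpt/utils.py diff.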
8 changes: 6 additions & 2 deletions lit_gpt/utils.py
@@ -262,7 +262,9 @@ def chunked_cross_entropy(
             torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
             for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
         ]
-        return torch.cat(loss_chunks).mean()
+        non_masked_elems = (targets != -1).sum()
+        mean_loss = torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
+        return mean_loss
 
     # no chunking at all
     logits = logits.reshape(-1, logits.size(-1))
@@ -277,7 +279,9 @@ def chunked_cross_entropy(
         torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
         for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
     ]
-    return torch.cat(loss_chunks).mean()
+    non_masked_elems = (targets != -1).sum()
+    mean_loss = torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
+    return mean_loss
 
 
 def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
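A minimal standalone repro of the mismatch (an illustrative sketch, not code from this repository; the shapes and masked positions are arbitrary):

# Repro: .mean() over per-element losses miscounts masked positions.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, V = 16, 10                            # 16 flattened positions, 10 classes
logits = torch.randn(N, V)
targets = torch.randint(0, V, (N,))
targets[:3] = -1                         # mask the first three positions

# Reference: reduction="mean" averages over non-masked targets only.
reference = F.cross_entropy(logits, targets, ignore_index=-1)

# Per-element losses, as the chunked path computes them; masked positions
# get a loss of exactly 0 but are still counted by .mean().
per_elem = F.cross_entropy(logits, targets, ignore_index=-1, reduction="none")
buggy = per_elem.mean()                  # divides by N, underestimates the loss

# The fix: divide the summed loss by the number of non-masked targets.
non_masked_elems = (targets != -1).sum()
fixed = per_elem.sum() / max(1, non_masked_elems)

assert torch.allclose(fixed, reference)  # matches PyTorch's own reduction
print(f"buggy={buggy:.4f}  fixed={fixed:.4f}  reference={reference:.4f}")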
12 changes: 10 additions & 2 deletions tests/test_utils.py
@@ -94,15 +94,23 @@ def test_incremental_write(tmp_path):
 
 
 @pytest.mark.parametrize("B", (1, 2))
-def test_chunked_cross_entropy(B):
+@pytest.mark.parametrize("with_ignore_index", (True, False))
+def test_chunked_cross_entropy(with_ignore_index, B):
     from lit_gpt.utils import chunked_cross_entropy
 
     V = 50
     T = 25
     regular_logits = torch.randn(B, T, V)
     targets = torch.randint(0, V, (B, T))
 
-    baseline_loss = F.cross_entropy(regular_logits.reshape(-1, regular_logits.size(-1)), targets.reshape(-1))
+    if with_ignore_index:
+        targets[:, [1, 4, 10, 19]] = -1
+
+    baseline_loss = F.cross_entropy(
+        regular_logits.reshape(-1, regular_logits.size(-1)),
+        targets.reshape(-1),
+        ignore_index=(-1 if with_ignore_index else -100),
+    )
     regular_loss = chunked_cross_entropy(regular_logits, targets, chunk_size=0)
     assert torch.equal(baseline_loss, regular_loss)
     assert regular_loss.numel() == 1
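Note on the test's else branch: -100 is the default ignore_index of torch.nn.functional.cross_entropy, so passing it when with_ignore_index is False leaves every target (drawn from [0, V)) unmasked, while chunked_cross_entropy still uses ignore_index=-1 internally; with no -1 targets present, the two computations stay directly comparable in both parametrizations.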
