diff --git a/lit_gpt/utils.py b/lit_gpt/utils.py index c24c124dff..9e8e952378 100644 --- a/lit_gpt/utils.py +++ b/lit_gpt/utils.py @@ -262,7 +262,9 @@ def chunked_cross_entropy( torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none") for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) ] - return torch.cat(loss_chunks).mean() + non_masked_elems = (targets != -1).sum() + mean_loss = torch.cat(loss_chunks).sum() / max(1, non_masked_elems) + return mean_loss # no chunking at all logits = logits.reshape(-1, logits.size(-1)) @@ -277,7 +279,9 @@ def chunked_cross_entropy( torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none") for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) ] - return torch.cat(loss_chunks).mean() + non_masked_elems = (targets != -1).sum() + mean_loss = torch.cat(loss_chunks).sum() / max(1, non_masked_elems) + return mean_loss def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict: diff --git a/tests/test_utils.py b/tests/test_utils.py index a704267b08..97117ccadb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -94,7 +94,8 @@ def test_incremental_write(tmp_path): @pytest.mark.parametrize("B", (1, 2)) -def test_chunked_cross_entropy(B): +@pytest.mark.parametrize("with_ignore_index", (True, False)) +def test_chunked_cross_entropy(with_ignore_index, B): from lit_gpt.utils import chunked_cross_entropy V = 50 @@ -102,7 +103,14 @@ def test_chunked_cross_entropy(B): regular_logits = torch.randn(B, T, V) targets = torch.randint(0, V, (B, T)) - baseline_loss = F.cross_entropy(regular_logits.reshape(-1, regular_logits.size(-1)), targets.reshape(-1)) + if with_ignore_index: + targets[:, [1, 4, 10, 19]] = -1 + + baseline_loss = F.cross_entropy( + regular_logits.reshape(-1, regular_logits.size(-1)), + targets.reshape(-1), + ignore_index=(-1 if with_ignore_index else -100), + ) regular_loss = chunked_cross_entropy(regular_logits, targets, chunk_size=0) assert torch.equal(baseline_loss, regular_loss) assert regular_loss.numel() == 1