From 7b867e96cc4215144b34c9e756b79a36e9209fc6 Mon Sep 17 00:00:00 2001
From: Alessio Serra
Date: Mon, 20 Nov 2023 22:57:18 +0100
Subject: [PATCH] Bugfix: Corrected loss computation with chunksize > 0 (#751)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 lit_gpt/utils.py    |  8 ++++++--
 tests/test_utils.py | 12 ++++++++++--
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/lit_gpt/utils.py b/lit_gpt/utils.py
index c24c124dff..9e8e952378 100644
--- a/lit_gpt/utils.py
+++ b/lit_gpt/utils.py
@@ -262,7 +262,9 @@ def chunked_cross_entropy(
             torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
             for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
         ]
-        return torch.cat(loss_chunks).mean()
+        non_masked_elems = (targets != -1).sum()
+        mean_loss = torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
+        return mean_loss
 
     # no chunking at all
     logits = logits.reshape(-1, logits.size(-1))
@@ -277,7 +279,9 @@ def chunked_cross_entropy(
         torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
         for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
     ]
-    return torch.cat(loss_chunks).mean()
+    non_masked_elems = (targets != -1).sum()
+    mean_loss = torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
+    return mean_loss
 
 
 def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
diff --git a/tests/test_utils.py b/tests/test_utils.py
index a704267b08..97117ccadb 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -94,7 +94,8 @@ def test_incremental_write(tmp_path):
 
 
 @pytest.mark.parametrize("B", (1, 2))
-def test_chunked_cross_entropy(B):
+@pytest.mark.parametrize("with_ignore_index", (True, False))
+def test_chunked_cross_entropy(with_ignore_index, B):
     from lit_gpt.utils import chunked_cross_entropy
 
     V = 50
@@ -102,7 +103,14 @@ def test_chunked_cross_entropy(B):
     regular_logits = torch.randn(B, T, V)
     targets = torch.randint(0, V, (B, T))
-    baseline_loss = F.cross_entropy(regular_logits.reshape(-1, regular_logits.size(-1)), targets.reshape(-1))
+    if with_ignore_index:
+        targets[:, [1, 4, 10, 19]] = -1
+
+    baseline_loss = F.cross_entropy(
+        regular_logits.reshape(-1, regular_logits.size(-1)),
+        targets.reshape(-1),
+        ignore_index=(-1 if with_ignore_index else -100),
+    )
     regular_loss = chunked_cross_entropy(regular_logits, targets, chunk_size=0)
     assert torch.equal(baseline_loss, regular_loss)
     assert regular_loss.numel() == 1
 