Commit

update mean reduction zloss to ignore labels == ignore_index vs. setting them to 0
jasonkrone committed Dec 4, 2024
1 parent b41634f commit 4b8109e
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion olmo/train.py
@@ -135,7 +135,8 @@ def cross_entropy_loss(

     z_squared = logits.logsumexp(-1).pow(2)
     if reduction == "mean":
-        z_squared = (z_squared * (labels != ignore_index)).mean()
+        mask = labels != ignore_index
+        z_squared = (z_squared * mask).sum() / mask.sum()
     elif reduction == "sum":
         z_squared = (z_squared * (labels != ignore_index)).sum()

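To make the effect of the change concrete, here is a minimal plain-Python sketch of the arithmetic (not the code in olmo/train.py, which operates on tensors). The helper names `logsumexp`, `z_squared_mean_old`, and `z_squared_mean_new` are hypothetical, and the sentinel value `-100` is an assumption; the diff only refers to `ignore_index`.

```python
import math

IGNORE_INDEX = -100  # assumed sentinel value; the diff only names `ignore_index`


def logsumexp(row):
    """Numerically stable log(sum(exp(x))) over one row of logits."""
    m = max(row)
    return m + math.log(sum(math.exp(x - m) for x in row))


def z_squared_mean_old(logits, labels, ignore_index=IGNORE_INDEX):
    """Behaviour before the commit: ignored positions are zeroed but still counted."""
    z_sq = [logsumexp(row) ** 2 for row in logits]
    masked = [z if y != ignore_index else 0.0 for z, y in zip(z_sq, labels)]
    return sum(masked) / len(masked)


def z_squared_mean_new(logits, labels, ignore_index=IGNORE_INDEX):
    """Behaviour after the commit: average over non-ignored positions only."""
    z_sq = [logsumexp(row) ** 2 for row in logits]
    kept = [z for z, y in zip(z_sq, labels) if y != ignore_index]
    # Like the diff, this does not guard against every label being ignored.
    return sum(kept) / len(kept)


# Four positions with identical logits; half of the labels are ignored.
logits = [[0.0, 0.0]] * 4
labels = [1, IGNORE_INDEX, 0, IGNORE_INDEX]

old = z_squared_mean_old(logits, labels)
new = z_squared_mean_new(logits, labels)
print(f"old mean: {old:.4f}, new mean: {new:.4f}")
```

With half of the positions ignored, the old reduction divides by all four positions and so reports half the value of the new one. More generally, the old mean shrinks in proportion to the fraction of ignored labels, which the change avoids.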

0 comments on commit 4b8109e
