Skip to content

Commit

Permalink
Clone then in place ops
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Jan 17, 2024
1 parent 829f090 commit 437c838
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions olmo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,14 +471,14 @@ def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = Chec
def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
    """Build the label tensor for a batch from its input IDs.

    Labels are just the input IDs shifted to the left (the first item along
    the last dimension is ignored). Positions excluded by either mask are
    overwritten with ``-100`` before the shift.

    Args:
        batch: Mapping that must contain ``"input_ids"`` and may contain
            ``"label_mask"`` and/or ``"attention_mask"``. Both masks must be
            broadcastable against ``input_ids``; ``label_mask`` must be a
            boolean tensor because it is inverted with ``~``.

    Returns:
        A contiguous tensor holding ``input_ids[..., 1:]`` with masked
        positions replaced by ``-100``.

    Raises:
        KeyError: If ``batch`` has no ``"input_ids"`` entry.
    """
    # Clone first so the in-place fills below never modify the caller's
    # ``batch["input_ids"]``.
    labels = batch["input_ids"].clone()
    label_mask = batch.get("label_mask")
    attention_mask = batch.get("attention_mask")

    # NOTE(review): -100 is presumably the ignore index of the downstream
    # loss function; confirm against the loss configuration.
    if label_mask is not None:
        # Positions where ``label_mask`` is False are masked out.
        labels.masked_fill_(~label_mask, -100)
    if attention_mask is not None:
        # Positions where ``attention_mask`` equals zero are masked out.
        labels.masked_fill_(attention_mask == 0.0, -100)

    # Drop the first position along the last dimension.
    return labels[..., 1:].contiguous()

def model_forward(
Expand Down

0 comments on commit 437c838

Please sign in to comment.