From 437c838c98755642d1e30db2b129db3d2bebe2d3 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 16 Jan 2024 19:42:17 -0800 Subject: [PATCH] Clone then in place ops --- olmo/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/olmo/train.py b/olmo/train.py index dc81a08ec..038ece47c 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -471,14 +471,14 @@ def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = Chec def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor: # Labels are just input IDs shifted to the left (first item is ignored). labels, label_mask, attention_mask = ( - batch["input_ids"], + batch["input_ids"].clone(), batch.get("label_mask"), batch.get("attention_mask"), ) if label_mask is not None: - labels = labels.masked_fill(~label_mask, -100) + labels.masked_fill_(~label_mask, -100) if attention_mask is not None: - labels = labels.masked_fill(attention_mask == 0.0, -100) + labels.masked_fill_(attention_mask == 0.0, -100) return labels[..., 1:].contiguous() def model_forward(