changes to pipeline for backward through aux losses
haeggee committed Aug 5, 2024
1 parent 2efffb8 commit 57c58b7
Showing 1 changed file with 19 additions and 13 deletions.
src/nanotron/parallel/pipeline_parallel/engine.py (32 changes: 19 additions & 13 deletions)
@@ -47,14 +47,18 @@ def forward(
         if not isinstance(output, dict):
             output = {"loss": output}
 
-        # We normalize our loss
-        if not isinstance(output["loss"], TensorPointer):
-            output["loss"] = output["loss"] / self.nb_microbatches
-
-        # Add output as activations that require backward pass
-        if not isinstance(output["loss"], TensorPointer):
-            assert output["loss"].requires_grad
-            state.register_activation_requiring_backward(output["loss"])
+        for k, v in output.items():
+            if not isinstance(v, TensorPointer):
+                output[k] = v / self.nb_microbatches
+
+        # the outputs are either
+        # - token prediction loss ["loss"]
+        # - auxiliary losses ["load_balancing_loss", "z_loss"]
+        # that we need to backpropagate through, so register activations
+        for loss_key, output_tensor in output.items():
+            if not isinstance(output_tensor, TensorPointer):
+                assert output_tensor.requires_grad
+                state.register_activation_requiring_backward(output_tensor)
         return output
 
     @staticmethod
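
To see what the new loops expect from the model, here is a minimal, self-contained sketch in plain PyTorch (not nanotron code; the loss key names follow the comment in the diff, and nb_microbatches stands in for self.nb_microbatches). Every non-TensorPointer entry is scaled by the number of microbatches and must still require grad so it can be registered for the backward pass:

    import torch

    nb_microbatches = 4  # stand-in for self.nb_microbatches

    # A MoE-style model can return the main loss plus auxiliary losses,
    # all scalar tensors attached to the autograd graph.
    w = torch.ones(3, requires_grad=True)
    output = {
        "loss": (w * 2).sum(),
        "load_balancing_loss": (w ** 2).sum() * 0.01,
        "z_loss": (w ** 2).sum() * 0.001,
    }

    # Mirror of the new forward logic: normalize every loss entry, then
    # register it for backward (here we simply collect it in a list).
    registered = []
    for k, v in output.items():
        output[k] = v / nb_microbatches
        assert output[k].requires_grad
        registered.append(output[k])

    # The pipeline later backpropagates through each registered loss;
    # summing them here triggers the equivalent single backward.
    torch.stack(registered).sum().backward()
    print(w.grad)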
@@ -154,7 +158,7 @@ def validate_batch_iter(
                 if not isinstance(output, dict):
                     output = {"loss": output}
 
-                # Store the loss for each microbatch
+                # Store the loss(es) for each microbatch
                 if not isinstance(output["loss"], TensorPointer):
                     output = {k: v.detach() for k, v in output.items()}
                 outputs.append(output)
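
Because validation now detaches every entry, each stored per-microbatch dict can carry several losses. A hypothetical aggregation helper (the function name and averaging convention are assumptions, not nanotron's API) would average each key over the microbatches that reported it:

    from typing import Dict, List

    import torch

    def aggregate_validation_outputs(outputs: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Average every detached loss key over the microbatches that reported it."""
        per_key: Dict[str, List[torch.Tensor]] = {}
        for output in outputs:
            for k, v in output.items():
                if isinstance(v, torch.Tensor):  # skip TensorPointer-like entries
                    per_key.setdefault(k, []).append(v)
        return {k: torch.stack(vs).mean() for k, vs in per_key.items()}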
@@ -269,8 +273,9 @@ def train_batch_iter(
                     send_activation()
 
                 # Store the loss for each microbatch
-                if not isinstance(output["loss"], TensorPointer):
-                    output = {k: v.detach() for k, v in output.items()}
+                for k, v in output.items():
+                    if not isinstance(v, TensorPointer):
+                        output[k] = v.detach()
                 outputs.append(output)
 
             for micro_batch in batch:
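
The detach is now decided per key rather than from output["loss"] alone: on a given pipeline rank, some entries of the dict may be TensorPointers referring to tensors held by another rank while others are local tensors. An illustrative sketch of that situation (TensorPointerStub and the rank layout are assumptions for the example, not nanotron code):

    from dataclasses import dataclass

    import torch

    @dataclass
    class TensorPointerStub:
        """Stand-in for nanotron's TensorPointer (a reference to a tensor on another rank)."""
        group_rank: int

    output = {
        "loss": TensorPointerStub(group_rank=3),                       # held by another pipeline rank
        "load_balancing_loss": torch.tensor(0.1, requires_grad=True),  # available locally
    }

    # Per-key check: only local tensors are detached before being stored.
    for k, v in output.items():
        if not isinstance(v, TensorPointerStub):
            output[k] = v.detach()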
@@ -282,8 +287,9 @@
                     output = {"loss": output}
 
                 # Store the loss for each microbatch
-                if not isinstance(output["loss"], TensorPointer):
-                    output = {k: v.detach() for k, v in output.items()}
+                for k, v in output.items():
+                    if not isinstance(v, TensorPointer):
+                        output[k] = v.detach()
                 outputs.append(output)
 
                 # One backward
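
Since every loss entry was already divided by nb_microbatches in forward, running one backward per microbatch accumulates the same gradients as a single backward through the batch-mean loss. A plain-PyTorch sanity check of that equivalence (illustrative only, not the engine's code):

    import torch

    nb_microbatches = 4
    w = torch.ones(3, requires_grad=True)
    data = [torch.randn(3) for _ in range(nb_microbatches)]

    # One backward per microbatch, each loss pre-divided by nb_microbatches.
    for x in data:
        main_loss = (w * x).sum()
        aux_loss = (w * x).pow(2).sum()  # stand-in for an auxiliary loss
        ((main_loss + aux_loss) / nb_microbatches).backward()
    grad_accumulated = w.grad.clone()

    # Single backward through the mean of the same total loss over all microbatches.
    w.grad = None
    total = sum((w * x).sum() + (w * x).pow(2).sum() for x in data) / nb_microbatches
    total.backward()

    assert torch.allclose(grad_accumulated, w.grad)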
