Skip to content

Commit

Permalink
fix for eval
Browse files (browse the repository at this point in the history)
  • Loading branch information
haeggee committed Sep 6, 2024
1 parent a45dc35 commit bb768bb
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions src/nanotron/models/gpt3_moe.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -95,7 +95,7 @@ def forward(
self,
hidden_states: torch.Tensor | TensorPointer,
sequence_mask: torch.Tensor | TensorPointer,
- aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]],
+ aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]] = None,
) -> dict[str, torch.Tensor | TensorPointer]:

residual = hidden_states
Expand All @@ -119,9 +119,10 @@ def forward(
mlp_output = self.ff(hidden_states=hidden_states)
hidden_states = mlp_output["hidden_states"]

- for key, value in mlp_output.items():
-     if key != "hidden_states":
-         aux_losses[key] = aux_losses[key] + value
+ if aux_losses is not None:
+     for key, value in mlp_output.items():
+         if key != "hidden_states":
+             aux_losses[key] = aux_losses[key] + value

if self.training:
with branch_random_state(
Expand Down Expand Up @@ -171,7 +172,7 @@ def forward(
self,
input_ids: torch.Tensor | TensorPointer, # [batch_size, seq_length]
input_mask: torch.Tensor | TensorPointer, # [batch_size, seq_length]
- aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]],
+ aux_losses: Optional[Dict[str, Union[torch.Tensor, TensorPointer]]] = None,
):
# all tensors are optional as most ranks don't need anything from the dataloader.

Expand Down Expand Up @@ -199,7 +200,10 @@ def forward(

fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"]

- return {"sharded_logits": fp32_sharded_logits, "aux_losses": hidden_encoder_states["aux_losses"]}
+ if aux_losses is not None:
+     return {"sharded_logits": fp32_sharded_logits, "aux_losses": hidden_encoder_states["aux_losses"]}
+ else:
+     return fp32_sharded_logits


class GPT3MoEForTraining(GPT3ForTraining):
Expand Down

0 comments on commit bb768bb

Please sign in to comment.