From bb768bb0ae1cb8f12550719c5c89a3c1f54b321c Mon Sep 17 00:00:00 2001
From: Alexander Hagele
Date: Fri, 6 Sep 2024 09:50:13 +0200
Subject: [PATCH] fix for eval

---
 src/nanotron/models/gpt3_moe.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/nanotron/models/gpt3_moe.py b/src/nanotron/models/gpt3_moe.py
index 1915136c..06a624ac 100644
--- a/src/nanotron/models/gpt3_moe.py
+++ b/src/nanotron/models/gpt3_moe.py
@@ -95,7 +95,7 @@ def forward(
         self,
         hidden_states: torch.Tensor | TensorPointer,
         sequence_mask: torch.Tensor | TensorPointer,
-        aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]],
+        aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]] = None,
     ) -> dict[str, torch.Tensor | TensorPointer]:
         residual = hidden_states
 
@@ -119,9 +119,10 @@ def forward(
         mlp_output = self.ff(hidden_states=hidden_states)
         hidden_states = mlp_output["hidden_states"]
 
-        for key, value in mlp_output.items():
-            if key != "hidden_states":
-                aux_losses[key] = aux_losses[key] + value
+        if aux_losses is not None:
+            for key, value in mlp_output.items():
+                if key != "hidden_states":
+                    aux_losses[key] = aux_losses[key] + value
 
         if self.training:
             with branch_random_state(
@@ -171,7 +172,7 @@ def forward(
         self,
         input_ids: torch.Tensor | TensorPointer,  # [batch_size, seq_length]
         input_mask: torch.Tensor | TensorPointer,  # [batch_size, seq_length]
-        aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]],
+        aux_losses: Optional[Dict[str, Union[torch.Tensor, TensorPointer]]] = None,
     ):
         # all tensors are optional as most ranks don't need anything from the dataloader.
 
@@ -199,7 +200,10 @@ def forward(
 
         fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"]
 
-        return {"sharded_logits": fp32_sharded_logits, "aux_losses": hidden_encoder_states["aux_losses"]}
+        if aux_losses is not None:
+            return {"sharded_logits": fp32_sharded_logits, "aux_losses": hidden_encoder_states["aux_losses"]}
+        else:
+            return fp32_sharded_logits
 
 
 class GPT3MoEForTraining(GPT3ForTraining):
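
Note (not part of the commit): the patch makes `aux_losses` optional on both `forward` signatures, so evaluation/generation code, which does not track MoE auxiliary losses, can call the model with only `input_ids` and `input_mask` and receive the fp32 sharded logits, while the training path keeps passing the dict and gets it back with the per-block terms accumulated. Below is a minimal, self-contained sketch of that calling convention under stated assumptions: `ToyMoEModel`, the loss keys, and all shapes are illustrative stand-ins, not the nanotron `GPT3MoEModel`.

import torch
from torch import nn
from typing import Dict, Optional

class ToyMoEModel(nn.Module):
    """Toy stand-in (assumed, NOT the nanotron model) that only mirrors the
    calling convention this patch gives GPT3MoEModel.forward: aux_losses
    defaults to None and the return type depends on whether it was passed."""

    def __init__(self, vocab_size: int = 128, hidden: int = 32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        input_mask: torch.Tensor,  # unused in this toy, kept for the signature
        aux_losses: Optional[Dict[str, torch.Tensor]] = None,
    ):
        hidden_states = self.embed(input_ids)        # [batch, seq, hidden]
        logits = self.lm_head(hidden_states)         # [batch, seq, vocab]
        if aux_losses is not None:
            # Training path: the MoE blocks would add their balancing terms
            # into this dict; here we just fake a zero contribution.
            aux_losses["load_balancing_loss"] = aux_losses["load_balancing_loss"] + 0.0
            return {"sharded_logits": logits, "aux_losses": aux_losses}
        # Eval/generation path (what the patch enables): logits only.
        return logits

model = ToyMoEModel()
input_ids = torch.randint(0, 128, (2, 16))                      # [batch_size, seq_length]
input_mask = torch.ones_like(input_ids, dtype=torch.bool)

# Training-style call: aux_losses is provided and returned accumulated.
out = model(input_ids, input_mask, aux_losses={"load_balancing_loss": torch.zeros(1)})
logits, aux = out["sharded_logits"], out["aux_losses"]

# Eval-style call (the case this patch fixes): no aux_losses, plain logits back.
with torch.no_grad():
    logits = model(input_ids, input_mask)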