diff --git a/nemo/collections/nlp/models/text_normalization_as_tagging/thutmose_tagger.py b/nemo/collections/nlp/models/text_normalization_as_tagging/thutmose_tagger.py
index be82d6a31582..f6e5c155646d 100644
--- a/nemo/collections/nlp/models/text_normalization_as_tagging/thutmose_tagger.py
+++ b/nemo/collections/nlp/models/text_normalization_as_tagging/thutmose_tagger.py
@@ -236,14 +236,15 @@ def validation_step(self, batch, batch_idx):
         val_loss_tag = self.loss_fn(logits=tag_logits, labels=tag_labels, loss_mask=labels_mask)
         val_loss_semiotic = self.loss_fn(logits=semiotic_logits, labels=semiotic_labels, loss_mask=labels_mask)
         val_loss = val_loss_tag + val_loss_semiotic
+        self.validation_step_outputs.append({'val_loss': val_loss})
         return {'val_loss': val_loss}
 
-    def on_validation_epoch_end(self, outputs):
+    def on_validation_epoch_end(self):
         """
-        Called at the end of validation to aggregate outputs.
-        :param outputs: list of individual outputs of each validation step.
+        Called at the end of validation to aggregate outputs
+        collected in self.validation_step_outputs.
         """
-        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+        avg_loss = torch.stack([x['val_loss'] for x in self.validation_step_outputs]).mean()
 
         # calculate metrics and classification report
         # In our task recall = accuracy, and the recall column - is the per class accuracy
@@ -269,6 +270,8 @@ def on_validation_epoch_end(self, outputs):
         self.tag_multiword_classification_report.reset()
         self.semiotic_classification_report.reset()
 
+        self.validation_step_outputs.clear()  # free memory
+
     def test_step(self, batch, batch_idx):
         """
         Lightning calls this inside the test loop with the data from the test dataloader
@@ -276,12 +279,12 @@ def test_step(self, batch, batch_idx):
         """
         return self.validation_step(batch, batch_idx)
 
-    def on_test_epoch_end(self, outputs):
+    def on_test_epoch_end(self):
         """
-        Called at the end of test to aggregate outputs.
-        :param outputs: list of individual outputs of each test step.
+        Called at the end of test to aggregate outputs
+        collected in self.validation_step_outputs.
         """
-        return self.on_validation_epoch_end(outputs)
+        return self.on_validation_epoch_end()
 
     # Functions for inference
     @torch.no_grad()
diff --git a/nemo/collections/tts/g2p/models/t5.py b/nemo/collections/tts/g2p/models/t5.py
index 16f1f1933fb0..b41fcf1d5945 100644
--- a/nemo/collections/tts/g2p/models/t5.py
+++ b/nemo/collections/tts/g2p/models/t5.py
@@ -170,7 +170,18 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0, split="val"):
         )
         generated_str, _, _ = self._generate_predictions(input_ids=input_ids, model_max_target_len=self.max_target_len)
         per = word_error_rate(hypotheses=generated_str, references=labels_str, use_cer=True)
-        return {f"{split}_loss": val_loss, 'per': per}
+        output = {f"{split}_loss": val_loss, 'per': per}
+        if split == 'val':
+            if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
+                self.validation_step_outputs[dataloader_idx].append(output)
+            else:
+                self.validation_step_outputs.append(output)
+        else:
+            if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
+                self.test_step_outputs[dataloader_idx].append(output)
+            else:
+                self.test_step_outputs.append(output)
+        return output
 
     def test_step(self, batch, batch_idx, dataloader_idx=0):
         """
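
Note on the pattern: both files apply the same PyTorch Lightning 2.0 migration. Lightning 2.0 removed the `outputs` argument from `on_validation_epoch_end`/`on_test_epoch_end`, so each `LightningModule` must buffer its own step outputs, aggregate them in the epoch-end hook, and clear the buffer so results do not accumulate across epochs. The sketch below illustrates the single-dataloader form of this pattern; `ToyModel`, its linear layer, and the MSE loss are illustrative assumptions, not code from this PR.

```python
# Minimal sketch of the Lightning >= 2.0 buffering pattern used in this PR.
# ToyModel and its loss are placeholders, not NeMo code.
import torch
import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        # Replaces the `outputs` list that Lightning < 2.0 passed to the hook.
        self.validation_step_outputs = []

    def validation_step(self, batch, batch_idx):
        x, y = batch
        val_loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # Buffer the per-step result for aggregation at epoch end.
        self.validation_step_outputs.append({'val_loss': val_loss})
        return {'val_loss': val_loss}

    def on_validation_epoch_end(self):
        avg_loss = torch.stack([x['val_loss'] for x in self.validation_step_outputs]).mean()
        self.log('val_loss', avg_loss)
        self.validation_step_outputs.clear()  # free memory across epochs

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```

The t5.py hunk extends the same idea to multiple dataloaders: when `trainer.val_dataloaders` (or `test_dataloaders`) holds more than one loader, the buffer is a list of per-dataloader lists indexed by `dataloader_idx`, and the epoch-end hook can then aggregate each sublist separately.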