Fix validation in G2PModel and ThutmoseTaggerModel (#7597)
Signed-off-by: Abhishree <[email protected]>
athitten authored Oct 2, 2023
1 parent d0acb40 commit 9285e22
Showing 2 changed files with 19 additions and 5 deletions.
7 additions & 4 deletions: ThutmoseTaggerModel
@@ -236,14 +236,15 @@ def validation_step(self, batch, batch_idx):
         val_loss_tag = self.loss_fn(logits=tag_logits, labels=tag_labels, loss_mask=labels_mask)
         val_loss_semiotic = self.loss_fn(logits=semiotic_logits, labels=semiotic_labels, loss_mask=labels_mask)
         val_loss = val_loss_tag + val_loss_semiotic
+        self.validation_step_outputs.append(val_loss)
         return {'val_loss': val_loss}
 
-    def on_validation_epoch_end(self, outputs):
+    def on_validation_epoch_end(self):
         """
         Called at the end of validation to aggregate outputs.
         :param outputs: list of individual outputs of each validation step.
         """
-        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+        avg_loss = torch.stack([x['val_loss'] for x in self.validation_step_outputs]).mean()
 
         # calculate metrics and classification report
         # In our task recall = accuracy, and the recall column - is the per class accuracy
@@ -269,19 +270,21 @@ def on_validation_epoch_end(self, outputs):
         self.tag_multiword_classification_report.reset()
         self.semiotic_classification_report.reset()
 
+        self.validation_step_outputs.clear()  # free memory
+
     def test_step(self, batch, batch_idx):
         """
         Lightning calls this inside the test loop with the data from the test dataloader
         passed in as `batch`.
         """
         return self.validation_step(batch, batch_idx)
 
-    def on_test_epoch_end(self, outputs):
+    def on_test_epoch_end(self):
         """
         Called at the end of test to aggregate outputs.
         :param outputs: list of individual outputs of each test step.
         """
-        return self.on_validation_epoch_end(outputs)
+        return self.on_validation_epoch_end()
 
     # Functions for inference
     @torch.no_grad()
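The change above follows the PyTorch Lightning 2.x convention: on_validation_epoch_end and on_test_epoch_end no longer receive the list of step outputs, so the module caches each step's result in self.validation_step_outputs and clears the cache once the epoch-level metrics are computed. Below is a minimal sketch of that pattern outside NeMo, assuming pytorch_lightning 2.x; ToyModule and its single linear layer are made up for illustration.

import torch
import pytorch_lightning as pl


class ToyModule(pl.LightningModule):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        self.validation_step_outputs = []  # cache that replaces the removed `outputs` argument

    def validation_step(self, batch, batch_idx):
        x, y = batch
        val_loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.validation_step_outputs.append(val_loss)  # accumulate per-step losses
        return {'val_loss': val_loss}

    def on_validation_epoch_end(self):
        # aggregate what the steps cached, log it, then free the memory
        avg_loss = torch.stack(self.validation_step_outputs).mean()
        self.log('val_loss', avg_loss, sync_dist=True)
        self.validation_step_outputs.clear()

The tagger model above applies the same bookkeeping: append in validation_step, aggregate and clear in on_validation_epoch_end, and let the test hooks delegate to their validation counterparts.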
13 changes: 12 additions & 1 deletion nemo/collections/tts/g2p/models/t5.py
@@ -170,7 +170,18 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0, split="val"):
         )
         generated_str, _, _ = self._generate_predictions(input_ids=input_ids, model_max_target_len=self.max_target_len)
         per = word_error_rate(hypotheses=generated_str, references=labels_str, use_cer=True)
-        return {f"{split}_loss": val_loss, 'per': per}
+        output = {f"{split}_loss": val_loss, 'per': per}
+        if split == 'val':
+            if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
+                self.validation_step_outputs[dataloader_idx].append(output)
+            else:
+                self.validation_step_outputs.append(output)
+        else:
+            if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
+                self.test_step_outputs[dataloader_idx].append(output)
+            else:
+                self.test_step_outputs.append(output)
+        return output
 
     def test_step(self, batch, batch_idx, dataloader_idx=0):
         """
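In the G2P model, validation_step now routes each output dict either into a flat validation_step_outputs / test_step_outputs list (single dataloader) or into the sub-list indexed by dataloader_idx (when several dataloaders are attached to the trainer), so the matching epoch-end hook has to aggregate both shapes. The following is a hedged sketch of such an aggregation, not the actual NeMo implementation; it assumes the method lives on a LightningModule and that 'per' is a plain float in each output dict.

import torch


def on_validation_epoch_end(self):  # illustrative method body for a LightningModule
    # Normalize both shapes to "one list of output dicts per dataloader".
    outputs_per_dl = (
        self.validation_step_outputs
        if self.validation_step_outputs and isinstance(self.validation_step_outputs[0], list)
        else [self.validation_step_outputs]
    )
    for idx, outputs in enumerate(outputs_per_dl):
        if not outputs:
            continue
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_per = sum(x["per"] for x in outputs) / len(outputs)
        suffix = f"_dl{idx}" if len(outputs_per_dl) > 1 else ""
        self.log(f"val_loss{suffix}", avg_loss, sync_dist=True)
        self.log(f"val_per{suffix}", avg_per, sync_dist=True)
        outputs.clear()  # free memory, as in the tagger model above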
