Skip to content

Commit

Permalink
migration to PTL 2.0 for spellmapper model (NVIDIA#7924)
Browse files Browse the repository at this point in the history
* migrate to PTL 2.0

Signed-off-by: Alexandra Antonova <[email protected]>

* add difference between val and test

Signed-off-by: Alexandra Antonova <[email protected]>

* fix warning about potential uninitialized variable

Signed-off-by: Alexandra Antonova <[email protected]>

---------

Signed-off-by: Alexandra Antonova <[email protected]>
  • Loading branch information
bene-ges authored Dec 28, 2023
1 parent c2da018 commit 060c926
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@

@hydra_runner(config_path="conf", config_name="spellchecking_asr_customization_config")
def main(cfg: DictConfig) -> None:
# PTL 2.0 has find_unused_parameters as False by default, so its required to set it to True
# when there are unused parameters like here
if cfg.trainer.strategy == 'ddp':
cfg.trainer.strategy = "ddp_find_unused_parameters_true"
logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')

# Train the model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def training_step(self, batch, batch_idx):
return {'loss': loss, 'lr': lr}

# Validation and Testing
def validation_step(self, batch, batch_idx):
def validation_step(self, batch, batch_idx, split="val"):
"""
Lightning calls this inside the validation loop with the data from the validation dataloader
passed in as `batch`.
Expand Down Expand Up @@ -271,15 +271,25 @@ def validation_step(self, batch, batch_idx):
torch.tensor(span_predictions).to(self.device), torch.tensor(span_labels).to(self.device)
)

val_loss = self.loss_fn(logits=logits, labels=labels, loss_mask=labels_mask)
return {'val_loss': val_loss}
loss = self.loss_fn(logits=logits, labels=labels, loss_mask=labels_mask)

if split == 'val':
self.validation_step_outputs.append({f'{split}_loss': loss})
elif split == 'test':
self.test_step_outputs.append({f'{split}_loss': loss})

def validation_epoch_end(self, outputs):
return {f'{split}_loss': loss}

def on_validation_epoch_end(self):
"""
Called at the end of validation to aggregate outputs.
:param outputs: list of individual outputs of each validation step.
"""
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
split = "test" if self.trainer.testing else "val"
if split == 'val':
avg_loss = torch.stack([x[f'{split}_loss'] for x in self.validation_step_outputs]).mean()
else:
avg_loss = torch.stack([x[f'{split}_loss'] for x in self.test_step_outputs]).mean()

# Calculate metrics and classification report
# Note that in our task recall = accuracy, and the recall column is the per class accuracy
Expand All @@ -288,8 +298,8 @@ def validation_epoch_end(self, outputs):
logging.info("Total tag accuracy: " + str(tag_accuracy))
logging.info(tag_report)

self.log('val_loss', avg_loss, prog_bar=True)
self.log('tag accuracy', tag_accuracy)
self.log(f"{split}_loss", avg_loss, prog_bar=True)
self.log(f"{split}_tag_accuracy", tag_accuracy)

self.tag_classification_report.reset()

Expand All @@ -298,14 +308,14 @@ def test_step(self, batch, batch_idx):
Lightning calls this inside the test loop with the data from the test dataloader
passed in as `batch`.
"""
return self.validation_step(batch, batch_idx)
return self.validation_step(batch, batch_idx, split="test")

def test_epoch_end(self, outputs):
def on_test_epoch_end(self):
"""
Called at the end of test to aggregate outputs.
:param outputs: list of individual outputs of each test step.
"""
return self.validation_epoch_end(outputs)
return self.on_validation_epoch_end()

# Functions for inference

Expand Down

0 comments on commit 060c926

Please sign in to comment.