Skip to content

Commit

Permalink
Fix T5 G2P Input and Output Types (#9224) (#9269)
Browse files Browse the repository at this point in the history
* fix t5 g2p model



* Apply isort and black reformatting



---------

Signed-off-by: Jason <[email protected]>
Signed-off-by: blisc <[email protected]>
Co-authored-by: Jason <[email protected]>
Co-authored-by: blisc <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
  • Loading branch information
4 people authored Jun 6, 2024
1 parent 95ca2f4 commit 3b758de
Showing 1 changed file with 33 additions and 19 deletions.
52 changes: 33 additions & 19 deletions nemo/collections/tts/g2p/models/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,23 @@ class T5G2PModel(G2PModel, Exportable):

@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
return {
"input_ids": NeuralType(('B', 'T'), TokenIndex()),
"attention_mask": NeuralType(('B', 'T'), MaskType(), optional=True),
"labels": NeuralType(('B', 'T'), LabelsType()),
}
if self._input_types is None:
return {
"input_ids": NeuralType(('B', 'T'), TokenIndex()),
"attention_mask": NeuralType(('B', 'T'), MaskType(), optional=True),
"labels": NeuralType(('B', 'T'), LabelsType()),
}
return self._input_types

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
return {"loss": NeuralType((), LossType())}
if self._output_types is None:
return {"loss": NeuralType((), LossType())}
return self._output_types

def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self._input_types = None
self._output_types = None
self.world_size = 1
if trainer is not None:
self.world_size = trainer.num_nodes * trainer.num_devices
Expand Down Expand Up @@ -91,7 +97,11 @@ def forward(self, input_ids, attention_mask, labels):
# ===== Training Functions ===== #
def training_step(self, batch, batch_idx):
input_ids, attention_mask, labels = batch
train_loss = self.forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels,)
train_loss = self.forward(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)

self.log('train_loss', train_loss)
return train_loss
Expand Down Expand Up @@ -126,7 +136,10 @@ def _setup_infer_dataloader(self, cfg) -> 'torch.utils.data.DataLoader':

# Functions for inference
@torch.no_grad()
def _infer(self, config: DictConfig,) -> List[int]:
def _infer(
self,
config: DictConfig,
) -> List[int]:
"""
Runs model inference.
Expand Down Expand Up @@ -161,7 +174,11 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0, split="val"):
input_ids, attention_mask, labels = batch

# Get loss from forward step
val_loss = self.forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels,)
val_loss = self.forward(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)

# Get preds from generate function and calculate PER
labels_str = self._tokenizer.batch_decode(
Expand Down Expand Up @@ -287,15 +304,8 @@ def _prepare_for_export(self, **kwargs):
}

def _export_teardown(self):
self._input_types = self._output_types = None

@property
def input_types(self):
return self._input_types

@property
def output_types(self):
return self._output_types
self._input_types = None
self._output_types = None

def input_example(self, max_batch=1, max_dim=44):
"""
Expand All @@ -307,7 +317,11 @@ def input_example(self, max_batch=1, max_dim=44):
sentence = "Kupil sem si bicikel in mu zamenjal stol."
input_ids = [sentence]
input_encoding = self._tokenizer(
input_ids, padding='longest', max_length=self.max_source_len, truncation=True, return_tensors='pt',
input_ids,
padding='longest',
max_length=self.max_source_len,
truncation=True,
return_tensors='pt',
)
return (input_encoding.input_ids,)

Expand Down

0 comments on commit 3b758de

Please sign in to comment.