diff --git a/nemo/collections/tts/g2p/models/t5.py b/nemo/collections/tts/g2p/models/t5.py
index 25f63d8d858a..19f976081687 100644
--- a/nemo/collections/tts/g2p/models/t5.py
+++ b/nemo/collections/tts/g2p/models/t5.py
@@ -46,17 +46,23 @@ class T5G2PModel(G2PModel, Exportable):
 
     @property
     def input_types(self) -> Optional[Dict[str, NeuralType]]:
-        return {
-            "input_ids": NeuralType(('B', 'T'), TokenIndex()),
-            "attention_mask": NeuralType(('B', 'T'), MaskType(), optional=True),
-            "labels": NeuralType(('B', 'T'), LabelsType()),
-        }
+        if self._input_types is None:
+            return {
+                "input_ids": NeuralType(('B', 'T'), TokenIndex()),
+                "attention_mask": NeuralType(('B', 'T'), MaskType(), optional=True),
+                "labels": NeuralType(('B', 'T'), LabelsType()),
+            }
+        return self._input_types
 
     @property
     def output_types(self) -> Optional[Dict[str, NeuralType]]:
-        return {"loss": NeuralType((), LossType())}
+        if self._output_types is None:
+            return {"loss": NeuralType((), LossType())}
+        return self._output_types
 
     def __init__(self, cfg: DictConfig, trainer: Trainer = None):
+        self._input_types = None
+        self._output_types = None
         self.world_size = 1
         if trainer is not None:
             self.world_size = trainer.num_nodes * trainer.num_devices
@@ -91,7 +97,11 @@ def forward(self, input_ids, attention_mask, labels):
     # ===== Training Functions ===== #
     def training_step(self, batch, batch_idx):
         input_ids, attention_mask, labels = batch
-        train_loss = self.forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels,)
+        train_loss = self.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            labels=labels,
+        )
 
         self.log('train_loss', train_loss)
         return train_loss
@@ -126,7 +136,10 @@ def _setup_infer_dataloader(self, cfg) -> 'torch.utils.data.DataLoader':
 
     # Functions for inference
     @torch.no_grad()
-    def _infer(self, config: DictConfig,) -> List[int]:
+    def _infer(
+        self,
+        config: DictConfig,
+    ) -> List[int]:
         """
         Runs model inference.
 
@@ -161,7 +174,11 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0, split="val"):
         input_ids, attention_mask, labels = batch
 
         # Get loss from forward step
-        val_loss = self.forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels,)
+        val_loss = self.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            labels=labels,
+        )
 
         # Get preds from generate function and calculate PER
         labels_str = self._tokenizer.batch_decode(
@@ -287,15 +304,8 @@ def _prepare_for_export(self, **kwargs):
         }
 
     def _export_teardown(self):
-        self._input_types = self._output_types = None
-
-    @property
-    def input_types(self):
-        return self._input_types
-
-    @property
-    def output_types(self):
-        return self._output_types
+        self._input_types = None
+        self._output_types = None
 
     def input_example(self, max_batch=1, max_dim=44):
         """
@@ -307,7 +317,11 @@ def input_example(self, max_batch=1, max_dim=44):
         sentence = "Kupil sem si bicikel in mu zamenjal stol."
         input_ids = [sentence]
         input_encoding = self._tokenizer(
-            input_ids, padding='longest', max_length=self.max_source_len, truncation=True, return_tensors='pt',
+            input_ids,
+            padding='longest',
+            max_length=self.max_source_len,
+            truncation=True,
+            return_tensors='pt',
         )
 
         return (input_encoding.input_ids,)
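
The net effect of the diff is to merge the two `input_types`/`output_types` property definitions (the typed ones at the top of the class and the export-only ones that previously sat next to `_export_teardown`) into a single pair: each property returns its default `NeuralType` dict until `_prepare_for_export` installs an override in `self._input_types`/`self._output_types`, and `_export_teardown` clears the override again. A minimal, self-contained sketch of that pattern follows; the class and type values are hypothetical stand-ins, not the NeMo API.

```python
# Sketch of the "default types with an export-time override" pattern
# used in the diff above. Names and type values are illustrative only.


class ExportableTypes:
    def __init__(self):
        # Override starts unset, so the default below is used until export.
        self._input_types = None

    @property
    def input_types(self):
        if self._input_types is None:
            # Default types used for training / validation.
            return {"input_ids": "token index", "attention_mask": "mask"}
        return self._input_types

    def _prepare_for_export(self):
        # Export installs a simplified signature (hypothetical example).
        self._input_types = {"input_ids": "token index"}

    def _export_teardown(self):
        # Restore the default behaviour after export finishes.
        self._input_types = None


m = ExportableTypes()
assert "attention_mask" in m.input_types      # defaults before export
m._prepare_for_export()
assert "attention_mask" not in m.input_types  # override active during export
m._export_teardown()
assert "attention_mask" in m.input_types      # defaults restored afterwards
```

Keeping a single property per name also avoids the situation where a later redefinition in the class body silently replaces the earlier typed property, which is why the duplicate `@property` blocks are removed rather than kept alongside the new ones.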