diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index f89c920effbd..08af7c8c0617 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -6485,7 +6485,7 @@ jobs: Speech_Checkpoints_tests: needs: [cicd-test-container-setup] runs-on: self-hosted-azure - timeout-minutes: 10 + timeout-minutes: 20 container: image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} options: diff --git a/examples/nlp/language_modeling/conf/megatron_chatglm_config.yaml b/examples/nlp/language_modeling/conf/megatron_chatglm_config.yaml index 5c1191dbe64e..84fbd1b801d4 100644 --- a/examples/nlp/language_modeling/conf/megatron_chatglm_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_chatglm_config.yaml @@ -81,7 +81,7 @@ model: position_embedding_type: 'rope' # Position embedding type. Options ['learned_absolute', 'rope'] rotary_percentage: 0.5 # If using position_embedding_type=rope, then the per head dim is multiplied by this. For chatglm2, it is 0.5 (https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L754) rotary_interleaved: True # chatglm2 use interleaved rotary embedding - apply_rope_fusion: True + apply_rope_fusion: False attention_type: 'multihead' # Attention type. Options ['multihead'] share_embeddings_and_output_weights: False # Share embedding and output layer weights. overlap_p2p_comm: False # Overlap p2p communication with computes. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1 diff --git a/examples/nlp/language_modeling/conf/megatron_falcon_config.yaml b/examples/nlp/language_modeling/conf/megatron_falcon_config.yaml index f5746433cc78..8905abaf3ac2 100644 --- a/examples/nlp/language_modeling/conf/megatron_falcon_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_falcon_config.yaml @@ -113,7 +113,7 @@ model: bias_dropout_add_fusion: False # Use a kernel that fuses the bias addition, dropout and residual connection addition. masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. get_attention_mask_from_fusion: True # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages. - apply_rope_fusion: True # Use a kernel to add rotary positional embeddings. Only used if position_embedding_type=rope + apply_rope_fusion: False # Use a kernel to add rotary positional embeddings. Only used if position_embedding_type=rope # Miscellaneous diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index aa43dfe7e53e..0295f96db838 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -130,7 +130,7 @@ model: bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition. masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. get_attention_mask_from_fusion: True # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages. - apply_rope_fusion: True # Use a kernel to add rotary positional embeddings. Only used if position_embedding_type=rope + apply_rope_fusion: False # Use a kernel to add rotary positional embeddings. Only used if position_embedding_type=rope # Miscellaneous diff --git a/examples/nlp/language_modeling/conf/megatron_llama_config.yaml b/examples/nlp/language_modeling/conf/megatron_llama_config.yaml index 38ed239ec6e1..965b511fc7e7 100644 --- a/examples/nlp/language_modeling/conf/megatron_llama_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_llama_config.yaml @@ -112,7 +112,7 @@ model: bias_dropout_add_fusion: False # Use a kernel that fuses the bias addition, dropout and residual connection addition. masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. get_attention_mask_from_fusion: True # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages. - apply_rope_fusion: True # Use a kernel to add rotary positional embeddings. Only used if position_embedding_type=rope + apply_rope_fusion: False # Use a kernel to add rotary positional embeddings. Only used if position_embedding_type=rope # Miscellaneous diff --git a/examples/nlp/language_modeling/conf/megatron_starcoder_config.yaml b/examples/nlp/language_modeling/conf/megatron_starcoder_config.yaml index b170e82ca983..355e575a6d59 100644 --- a/examples/nlp/language_modeling/conf/megatron_starcoder_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_starcoder_config.yaml @@ -117,7 +117,7 @@ model: bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition. masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. get_attention_mask_from_fusion: True # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages. - apply_rope_fusion: True # Use a kernel to add rotary positional embeddings. Only used if position_embedding_type=rope + apply_rope_fusion: False # Use a kernel to add rotary positional embeddings. Only used if position_embedding_type=rope # Miscellaneous seed: 1234 diff --git a/examples/nlp/language_modeling/megatron_gpt_eval.py b/examples/nlp/language_modeling/megatron_gpt_eval.py index 01c56f1e3269..f3413a5fa92e 100644 --- a/examples/nlp/language_modeling/megatron_gpt_eval.py +++ b/examples/nlp/language_modeling/megatron_gpt_eval.py @@ -148,7 +148,9 @@ def __init__(self, sentences): super().__init__() self.sentences = sentences - def __len__(self,): + def __len__( + self, + ): return len(self.sentences) def __getitem__(self, idx): @@ -173,7 +175,9 @@ def main(cfg) -> None: callbacks.append(CustomProgressBar()) # trainer required for restoring model parallel models trainer = Trainer( - strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)), **cfg.trainer, callbacks=callbacks, + strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)), + **cfg.trainer, + callbacks=callbacks, ) if cfg.gpt_model_file is not None: @@ -224,6 +228,7 @@ def main(cfg) -> None: pretrained_cfg.activations_checkpoint_method = None pretrained_cfg.precision = trainer.precision pretrained_cfg["use_flash_attention"] = cfg.inference.get("use_flash_attention", False) + pretrained_cfg["apply_rope_fusion"] = False if pretrained_cfg.get('mcore_gpt', False): # with dist checkpointing we can use the model parallel config specified by the user pretrained_cfg.tensor_model_parallel_size = cfg.tensor_model_parallel_size diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index f9413a4dd738..b11d744a7e6a 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -21,6 +21,7 @@ import torch from omegaconf import DictConfig, OmegaConf, open_dict from pytorch_lightning import Trainer +from torch.utils.data import DataLoader from nemo.collections.asr.data.audio_to_text_lhotse_prompted import ( PromptedAudioToTextLhotseDataset, @@ -156,7 +157,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.transf_encoder = EncDecMultiTaskModel.from_config_dict(transf_encoder_cfg_dict) # Initialize weights - std_init_range = 1 / self.cfg.model_defaults.lm_enc_hidden ** 0.5 + std_init_range = 1 / self.cfg.model_defaults.lm_enc_hidden**0.5 self.transf_encoder.apply(lambda module: transformer_weights_init(module, std_init_range)) transf_decoder_cfg_dict = cfg.transf_decoder @@ -182,7 +183,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight # Initialize weights - std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden ** 0.5 + std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden**0.5 self.transf_decoder.apply(lambda module: transformer_weights_init(module, std_init_range)) self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range)) @@ -347,7 +348,7 @@ def change_vocabulary( self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight # Initialize weights of token classifier - std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden ** 0.5 + std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden**0.5 self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range)) # Setup Decoding class @@ -387,7 +388,7 @@ def change_vocabulary( @torch.no_grad() def transcribe( self, - audio: Union[List[str], str], + audio: Union[str, List[str], np.ndarray, DataLoader], batch_size: int = 4, return_hypotheses: bool = False, task: Optional[str] = None, @@ -403,7 +404,8 @@ def transcribe( """ Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. Args: - audio: (a list) of paths to audio files. \ + audio: (a single or list) of paths to audio files or a np.ndarray audio array. + Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. batch_size: (int) batch size to use during inference. diff --git a/nemo/collections/asr/models/classification_models.py b/nemo/collections/asr/models/classification_models.py index c1294de5bdc0..7b226f59e364 100644 --- a/nemo/collections/asr/models/classification_models.py +++ b/nemo/collections/asr/models/classification_models.py @@ -15,7 +15,6 @@ import copy import json import os -import tempfile from abc import abstractmethod from dataclasses import dataclass, field from math import ceil, floor @@ -24,6 +23,7 @@ import torch from omegaconf import DictConfig, ListConfig, OmegaConf from pytorch_lightning import Trainer +from torch.utils.data import DataLoader from torchmetrics import Accuracy from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError @@ -169,7 +169,8 @@ def forward( if not has_processed_signal: processed_signal, processed_signal_length = self.preprocessor( - input_signal=input_signal, length=input_signal_length, + input_signal=input_signal, + length=input_signal_length, ) # Crop or pad is always applied if self.crop_or_pad is not None: @@ -355,7 +356,7 @@ def _setup_feature_label_dataloader(self, config: DictConfig) -> torch.utils.dat @torch.no_grad() def transcribe( self, - audio: List[str], + audio: Union[List[str], DataLoader], batch_size: int = 4, logprobs=None, override_config: Optional[ClassificationInferConfig] | Optional[RegressionInferConfig] = None, @@ -364,7 +365,8 @@ def transcribe( Generate class labels for provided audio files. Use this method for debugging and prototyping. Args: - audio: (a single or list) of paths to audio files or a np.ndarray audio sample. \ + audio: (a single or list) of paths to audio files or a np.ndarray audio array. + Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is approximately 1 second. batch_size: (int) batch size to use during inference. \ Bigger will result in better throughput performance but would use more memory. @@ -952,7 +954,10 @@ def _setup_dataloader_from_config(self, config: DictConfig): shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0 dataset = audio_to_label_dataset.get_tarred_audio_multi_label_dataset( - cfg=config, shuffle_n=shuffle_n, global_rank=self.global_rank, world_size=self.world_size, + cfg=config, + shuffle_n=shuffle_n, + global_rank=self.global_rank, + world_size=self.world_size, ) shuffle = False if hasattr(dataset, 'collate_fn'): @@ -1022,7 +1027,8 @@ def forward( if not has_processed_signal: processed_signal, processed_signal_length = self.preprocessor( - input_signal=input_signal, length=input_signal_length, + input_signal=input_signal, + length=input_signal_length, ) # Crop or pad is always applied @@ -1124,7 +1130,7 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): def reshape_labels(self, logits, labels, logits_len, labels_len): """ Reshape labels to match logits shape. For example, each label is expected to cover a 40ms frame, while each frme prediction from the - model covers 20ms. If labels are shorter than logits, labels are repeated, otherwise labels are folded and argmax is applied to obtain + model covers 20ms. If labels are shorter than logits, labels are repeated, otherwise labels are folded and argmax is applied to obtain the label of each frame. When lengths of labels and logits are not factors of each other, labels are truncated or padded with zeros. The ratio_threshold=0.2 is used to determine whether to pad or truncate labels, where the value 0.2 is not important as in real cases the ratio is very close to either ceil(ratio) or floor(ratio). We use 0.2 here for easier unit-testing. This implementation does not allow frame length diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index 4df02b1177cd..177da81f85f2 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -22,6 +22,7 @@ import torch from omegaconf import DictConfig, OmegaConf, open_dict from pytorch_lightning import Trainer +from torch.utils.data import DataLoader from tqdm.auto import tqdm from nemo.collections.asr.data import audio_to_text_dataset @@ -119,7 +120,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): def transcribe( self, - audio: Union[str, List[str], torch.Tensor, np.ndarray], + audio: Union[str, List[str], torch.Tensor, np.ndarray, DataLoader], batch_size: int = 4, return_hypotheses: bool = False, num_workers: int = 0, @@ -135,7 +136,8 @@ def transcribe( Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. Args: - audio: (a single or list) of paths to audio files or a np.ndarray audio array. \ + audio: (a single or list) of paths to audio files or a np.ndarray audio array. + Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. batch_size: (int) batch size to use during inference. @@ -493,7 +495,8 @@ def forward( if not has_processed_signal: processed_signal, processed_signal_length = self.preprocessor( - input_signal=input_signal, length=input_signal_length, + input_signal=input_signal, + length=input_signal_length, ) if self.spec_augmentation is not None and self.training: @@ -579,7 +582,9 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len) transcribed_texts, _ = self.wer.decoding.ctc_decoder_predictions_tensor( - decoder_outputs=log_probs, decoder_lengths=encoded_len, return_hypotheses=False, + decoder_outputs=log_probs, + decoder_lengths=encoded_len, + return_hypotheses=False, ) sample_id = sample_id.cpu().detach().numpy() @@ -601,11 +606,19 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0): log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len ) loss_value, metrics = self.add_interctc_losses( - loss_value, transcript, transcript_len, compute_wer=True, log_wer_num_denom=True, log_prefix="val_", + loss_value, + transcript, + transcript_len, + compute_wer=True, + log_wer_num_denom=True, + log_prefix="val_", ) self.wer.update( - predictions=log_probs, targets=transcript, targets_lengths=transcript_len, predictions_lengths=encoded_len, + predictions=log_probs, + targets=transcript, + targets_lengths=transcript_len, + predictions_lengths=encoded_len, ) wer, wer_num, wer_denom = self.wer.compute() self.wer.reset() @@ -677,7 +690,9 @@ def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> Gen logits_len = outputs.pop('logits_len') current_hypotheses, all_hyp = self.decoding.ctc_decoder_predictions_tensor( - logits, decoder_lengths=logits_len, return_hypotheses=trcfg.return_hypotheses, + logits, + decoder_lengths=logits_len, + return_hypotheses=trcfg.return_hypotheses, ) if trcfg.return_hypotheses: if logits.is_cuda: diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index 3eaab9961ef8..9a5c4188aebd 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -111,7 +111,8 @@ def transcribe( Args: - audio: (a list) of paths to audio files. \ + audio: (a single or list) of paths to audio files or a np.ndarray audio array. + Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. batch_size: (int) batch size to use during inference. \ @@ -182,7 +183,9 @@ def _transcribe_output_processing( encoded_len = outputs.pop('encoded_len') best_hyp, all_hyp = self.ctc_decoding.ctc_decoder_predictions_tensor( - logits, encoded_len, return_hypotheses=trcfg.return_hypotheses, + logits, + encoded_len, + return_hypotheses=trcfg.return_hypotheses, ) logits = logits.cpu() @@ -554,7 +557,10 @@ def validation_pass(self, batch, batch_idx, dataloader_idx): loss_value = (1 - self.ctc_loss_weight) * loss_value + self.ctc_loss_weight * ctc_loss tensorboard_logs['val_loss'] = loss_value self.ctc_wer.update( - predictions=log_probs, targets=transcript, targets_lengths=transcript_len, predictions_lengths=encoded_len, + predictions=log_probs, + targets=transcript, + targets_lengths=transcript_len, + predictions_lengths=encoded_len, ) ctc_wer, ctc_wer_num, ctc_wer_denom = self.ctc_wer.compute() self.ctc_wer.reset() diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index 386f2a915142..cb2505fbadbf 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -13,16 +13,15 @@ # limitations under the License. import copy -import json import os -import tempfile from math import ceil from typing import Any, Dict, List, Optional, Tuple, Union +import numpy as np import torch from omegaconf import DictConfig, OmegaConf, open_dict from pytorch_lightning import Trainer -from tqdm.auto import tqdm +from torch.utils.data import DataLoader from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text import _AudioTextDataset @@ -101,7 +100,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.cfg.decoding = self.set_decoding_type_according_to_loss(self.cfg.decoding) # Setup decoding objects self.decoding = RNNTDecoding( - decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary, + decoding_cfg=self.cfg.decoding, + decoder=self.decoder, + joint=self.joint, + vocabulary=self.joint.vocabulary, ) # Setup WER calculation self.wer = WER( @@ -236,7 +238,7 @@ def set_decoding_type_according_to_loss(self, decoding_cfg): @torch.no_grad() def transcribe( self, - audio: List[str], + audio: Union[str, List[str], np.ndarray, DataLoader], batch_size: int = 4, return_hypotheses: bool = False, partial_hypothesis: Optional[List['Hypothesis']] = None, @@ -250,7 +252,8 @@ def transcribe( Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. Args: - audio: (a list) of paths to audio files. \ + audio: (a single or list) of paths to audio files or a np.ndarray audio array. + Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. batch_size: (int) batch size to use during inference. \ @@ -338,7 +341,10 @@ def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[Di decoding_cfg = self.set_decoding_type_according_to_loss(decoding_cfg) self.decoding = RNNTDecoding( - decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary, + decoding_cfg=decoding_cfg, + decoder=self.decoder, + joint=self.joint, + vocabulary=self.joint.vocabulary, ) self.wer = WER( @@ -394,7 +400,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig): decoding_cfg = self.set_decoding_type_according_to_loss(decoding_cfg) self.decoding = RNNTDecoding( - decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary, + decoding_cfg=decoding_cfg, + decoder=self.decoder, + joint=self.joint, + vocabulary=self.joint.vocabulary, ) self.wer = WER( @@ -649,7 +658,8 @@ def forward( if not has_processed_signal: processed_signal, processed_signal_length = self.preprocessor( - input_signal=input_signal, length=input_signal_length, + input_signal=input_signal, + length=input_signal_length, ) # Spec augment is not applied during evaluation/testing diff --git a/nemo/collections/asr/models/slu_models.py b/nemo/collections/asr/models/slu_models.py index 1303bbfde7ea..c599b7f4272a 100644 --- a/nemo/collections/asr/models/slu_models.py +++ b/nemo/collections/asr/models/slu_models.py @@ -13,15 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import os -import tempfile from math import ceil from typing import Any, Dict, List, Optional, Union import torch from omegaconf import DictConfig, OmegaConf, open_dict -from tqdm.auto import tqdm +from torch.utils.data import DataLoader from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs @@ -190,7 +188,8 @@ def forward( if not has_processed_signal: processed_signal, processed_signal_length = self.preprocessor( - input_signal=input_signal, length=input_signal_length, + input_signal=input_signal, + length=input_signal_length, ) if self.spec_augmentation is not None and self.training: @@ -278,7 +277,8 @@ def predict( if not has_processed_signal: processed_signal, processed_signal_length = self.preprocessor( - input_signal=input_signal, length=input_signal_length, + input_signal=input_signal, + length=input_signal_length, ) if self.spec_augmentation is not None and self.training: @@ -560,7 +560,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo @torch.no_grad() def transcribe( self, - audio: List[str], + audio: Union[List[str], DataLoader], batch_size: int = 4, return_hypotheses: bool = False, num_workers: int = 0, @@ -571,7 +571,8 @@ def transcribe( Use this method for debugging and prototyping. Args: - audio: (a list) of paths to audio files. \ + audio: (a single or list) of paths to audio files or a np.ndarray audio array. + Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. batch_size: (int) batch size to use during inference. diff --git a/nemo/collections/asr/models/transformer_bpe_models.py b/nemo/collections/asr/models/transformer_bpe_models.py index 21a5f34b3038..e7e67f8fbb2f 100644 --- a/nemo/collections/asr/models/transformer_bpe_models.py +++ b/nemo/collections/asr/models/transformer_bpe_models.py @@ -24,6 +24,7 @@ import torch.distributed as dist from omegaconf import DictConfig, OmegaConf, open_dict from pytorch_lightning import Trainer +from torch.utils.data import DataLoader from torchmetrics.text import SacreBLEUScore from tqdm.auto import tqdm @@ -141,7 +142,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): num_layers=self.cfg.head.num_layers, ) self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight - std_init_range = 1 / self.transf_decoder.hidden_size ** 0.5 + std_init_range = 1 / self.transf_decoder.hidden_size**0.5 self.transf_decoder.apply(lambda module: transformer_weights_init(module, std_init_range)) self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range)) @@ -174,7 +175,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): @torch.no_grad() def transcribe( self, - audio: List[str], + audio: Union[List[str], DataLoader], batch_size: int = 4, return_hypotheses: bool = False, num_workers: int = 0, @@ -185,7 +186,8 @@ def transcribe( """ Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. Args: - audio: (a list) of paths to audio files. \ + audio: (a list) of paths to audio files. + Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. batch_size: (int) batch size to use during inference. @@ -225,7 +227,9 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): config, global_rank=self.global_rank, world_size=self.world_size, - dataset=LhotseSpeechToTextBpeDataset(tokenizer=self.tokenizer,), + dataset=LhotseSpeechToTextBpeDataset( + tokenizer=self.tokenizer, + ), ) dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config( diff --git a/nemo/collections/asr/parts/mixins/transcription.py b/nemo/collections/asr/parts/mixins/transcription.py index 5a71679607be..cd3f609781ca 100644 --- a/nemo/collections/asr/parts/mixins/transcription.py +++ b/nemo/collections/asr/parts/mixins/transcription.py @@ -186,7 +186,7 @@ class TranscriptionMixin(ABC): @torch.no_grad() def transcribe( self, - audio: Union[str, List[str], np.ndarray], + audio: Union[str, List[str], np.ndarray, DataLoader], batch_size: int = 4, return_hypotheses: bool = False, num_workers: int = 0, @@ -201,6 +201,7 @@ def transcribe( Args: audio: (a single or list) of paths to audio files or a np.ndarray audio array. + Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. But it is possible to pass a few hours long file if enough GPU memory is available. batch_size: (int) batch size to use during inference. @@ -368,7 +369,11 @@ def transcribe_generator(self, audio, override_config: Optional[TranscribeConfig with tempfile.TemporaryDirectory() as tmpdir: transcribe_cfg._internal.temp_dir = tmpdir - dataloader = self._transcribe_input_processing(audio, transcribe_cfg) + # Create a DataLoader if not already present + if not isinstance(audio, DataLoader): + dataloader = self._transcribe_input_processing(audio, transcribe_cfg) + else: + dataloader = audio if hasattr(transcribe_cfg, 'verbose'): verbose = transcribe_cfg.verbose diff --git a/nemo/collections/asr/parts/submodules/ctc_decoding.py b/nemo/collections/asr/parts/submodules/ctc_decoding.py index 70d63c0f8c6f..d2bfb629293e 100644 --- a/nemo/collections/asr/parts/submodules/ctc_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_decoding.py @@ -213,20 +213,20 @@ def __init__(self, decoding_cfg, blank_id: int): self.batch_dim_index = self.cfg.get('batch_dim_index', 0) self.word_seperator = self.cfg.get('word_seperator', ' ') - possible_strategies = ['greedy', 'greedy_batched', 'beam', 'pyctcdecode', 'flashlight'] + possible_strategies = ['greedy', 'greedy_batch', 'beam', 'pyctcdecode', 'flashlight'] if self.cfg.strategy not in possible_strategies: raise ValueError(f"Decoding strategy must be one of {possible_strategies}. Given {self.cfg.strategy}") # Update preserve alignments if self.preserve_alignments is None: - if self.cfg.strategy in ['greedy', 'greedy_batched']: + if self.cfg.strategy in ['greedy', 'greedy_batch']: self.preserve_alignments = self.cfg.greedy.get('preserve_alignments', False) else: self.preserve_alignments = self.cfg.beam.get('preserve_alignments', False) # Update compute timestamps if self.compute_timestamps is None: - if self.cfg.strategy in ['greedy', 'greedy_batched']: + if self.cfg.strategy in ['greedy', 'greedy_batch']: self.compute_timestamps = self.cfg.greedy.get('compute_timestamps', False) elif self.cfg.strategy in ['beam']: self.compute_timestamps = self.cfg.beam.get('compute_timestamps', False) @@ -234,10 +234,10 @@ def __init__(self, decoding_cfg, blank_id: int): # initialize confidence-related fields self._init_confidence(self.cfg.get('confidence_cfg', None)) - # Confidence estimation is not implemented for strategies other than `greedy` and `greedy_batched` + # Confidence estimation is not implemented for strategies other than `greedy` and `greedy_batch` if ( not self.preserve_frame_confidence - and self.cfg.strategy not in ('greedy', 'greedy_batched') + and self.cfg.strategy not in ('greedy', 'greedy_batch') and self.cfg.beam.get('preserve_frame_confidence', False) ): raise NotImplementedError(f"Confidence calculation is not supported for strategy `{self.cfg.strategy}`") @@ -247,11 +247,6 @@ def __init__(self, decoding_cfg, blank_id: int): self.compute_timestamps |= self.preserve_frame_confidence if self.cfg.strategy == 'greedy': - logging.warning( - "CTC decoding strategy 'greedy' is slower than 'greedy_batched', which implements the same exact interface. Consider changing your strategy to 'greedy_batched' for a free performance improvement.", - mode=logging_mode.ONCE, - ) - self.decoding = ctc_greedy_decoding.GreedyCTCInfer( blank_id=self.blank_id, preserve_alignments=self.preserve_alignments, @@ -260,7 +255,7 @@ def __init__(self, decoding_cfg, blank_id: int): confidence_method_cfg=self.confidence_method_cfg, ) - elif self.cfg.strategy == "greedy_batched": + elif self.cfg.strategy == "greedy_batch": self.decoding = ctc_greedy_decoding.GreedyBatchedCTCInfer( blank_id=self.blank_id, preserve_alignments=self.preserve_alignments, @@ -1023,7 +1018,9 @@ class CTCDecoding(AbstractCTCDecoding): """ def __init__( - self, decoding_cfg, vocabulary, + self, + decoding_cfg, + vocabulary, ): blank_id = len(vocabulary) self.vocabulary = vocabulary @@ -1300,7 +1297,7 @@ def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]: @dataclass class CTCDecodingConfig: - strategy: str = "greedy_batched" + strategy: str = "greedy_batch" # preserve decoding alignments preserve_alignments: Optional[bool] = None diff --git a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py index 1ef26cd7adf3..d0063ee81150 100644 --- a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py @@ -22,10 +22,13 @@ from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin from nemo.core.classes import Typing, typecheck from nemo.core.neural_types import HypothesisType, LengthsType, LogprobsType, NeuralType -from nemo.utils import logging +from nemo.utils import logging, logging_mode -def pack_hypotheses(hypotheses: List[rnnt_utils.Hypothesis], logitlen: torch.Tensor,) -> List[rnnt_utils.Hypothesis]: +def pack_hypotheses( + hypotheses: List[rnnt_utils.Hypothesis], + logitlen: torch.Tensor, +) -> List[rnnt_utils.Hypothesis]: if logitlen is not None: if hasattr(logitlen, 'cpu'): @@ -108,8 +111,7 @@ class GreedyCTCInfer(Typing, ConfidenceMethodMixin): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" # Input can be of dimension - # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels] @@ -120,8 +122,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return {"predictions": [NeuralType(elements_type=HypothesisType())]} def __init__( @@ -145,7 +146,9 @@ def __init__( @typecheck() def forward( - self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor, + self, + decoder_output: torch.Tensor, + decoder_lengths: torch.Tensor, ): """Returns a list of hypotheses given an input batch of the encoder hidden embedding. Output token is generated auto-repressively. @@ -158,6 +161,12 @@ def forward( Returns: packed list containing batch number of sentences (Hypotheses). """ + + logging.warning( + "CTC decoding strategy 'greedy' is slower than 'greedy_batch', which implements the same exact interface. Consider changing your strategy to 'greedy_batch' for a free performance improvement.", + mode=logging_mode.ONCE, + ) + with torch.inference_mode(): hypotheses = [] # Process each sequence independently @@ -324,8 +333,7 @@ class GreedyBatchedCTCInfer(Typing, ConfidenceMethodMixin): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" # Input can be of dimension - # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels] @@ -336,8 +344,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return {"predictions": [NeuralType(elements_type=HypothesisType())]} def __init__( @@ -361,7 +368,9 @@ def __init__( @typecheck() def forward( - self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor, + self, + decoder_output: torch.Tensor, + decoder_lengths: torch.Tensor, ): """Returns a list of hypotheses given an input batch of the encoder hidden embedding. Output token is generated auto-repressively. diff --git a/nemo/collections/multimodal/parts/utils.py b/nemo/collections/multimodal/parts/utils.py index 223fe22bd00a..f9d6ed5250f6 100644 --- a/nemo/collections/multimodal/parts/utils.py +++ b/nemo/collections/multimodal/parts/utils.py @@ -138,7 +138,8 @@ def load_nemo_model_weights(nemo_path, sharded_state_dict=None): tmp_model_weights_dir = os.path.splitext(tmp_model_weights_ckpt)[0] assert os.path.isdir(tmp_model_weights_dir), f'Expected {tmp_model_weights_dir} to be a directory.' checkpoint = dist_checkpointing.load( - sharded_state_dict=checkpoint, checkpoint_dir=tmp_model_weights_dir, + sharded_state_dict=checkpoint, + checkpoint_dir=tmp_model_weights_dir, ) state_dict = checkpoint["state_dict"] @@ -149,7 +150,9 @@ def load_nemo_model_weights(nemo_path, sharded_state_dict=None): def setup_trainer_and_models_for_inference( - model_provider: Any, cfg: DictConfig, model_cfg_modifier: Callable, + model_provider: Any, + cfg: DictConfig, + model_cfg_modifier: Callable, ): """ Set up a trainer and NeMo model for inference. @@ -172,7 +175,10 @@ def setup_trainer_and_models_for_inference( # Use the NLPDDPStrategy for the distributed data parallel strategy. # We don't use DDP for async grad allreduce and don't find unused parameters. - strategy = NLPDDPStrategy(no_ddp_communication_hook=True, find_unused_parameters=False,) + strategy = NLPDDPStrategy( + no_ddp_communication_hook=True, + find_unused_parameters=False, + ) # Set up the trainer with the specified plugins and strategy. trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer) @@ -215,7 +221,9 @@ def setup_trainer_and_models_for_inference( ) model = model_provider.load_from_checkpoint( - single_model_cfg.restore_from_path, hparams_file=cfg.model.get("hparams_file"), trainer=trainer, + single_model_cfg.restore_from_path, + hparams_file=cfg.model.get("hparams_file"), + trainer=trainer, ) models.append(model) @@ -239,7 +247,9 @@ def dummy(): def setup_trainer_and_model_for_inference( - model_provider: Any, cfg: DictConfig, model_cfg_modifier: Callable, + model_provider: Any, + cfg: DictConfig, + model_cfg_modifier: Callable, ) -> Tuple[Trainer, Any]: """ Set up a trainer and NeMo model for inference. @@ -261,7 +271,10 @@ def setup_trainer_and_model_for_inference( # Use the NLPDDPStrategy for the distributed data parallel strategy. # We don't use DDP for async grad allreduce and don't find unused parameters. - strategy = NLPDDPStrategy(no_ddp_communication_hook=True, find_unused_parameters=False,) + strategy = NLPDDPStrategy( + no_ddp_communication_hook=True, + find_unused_parameters=False, + ) # Set up the trainer with the specified plugins and strategy. trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer) @@ -299,7 +312,9 @@ def setup_trainer_and_model_for_inference( ) model = model_provider.load_from_checkpoint( - cfg.model.restore_from_path, hparams_file=cfg.model.get("hparams_file"), trainer=trainer, + cfg.model.restore_from_path, + hparams_file=cfg.model.get("hparams_file"), + trainer=trainer, ) else: @@ -335,7 +350,9 @@ def create_neva_model_and_processor(cfg): or cfg.get('pipeline_model_parallel_split_rank', -1) < 0 ): model_config = MegatronNevaModel.restore_from( - restore_path=cfg.neva_model_file, trainer=trainer, return_config=True, + restore_path=cfg.neva_model_file, + trainer=trainer, + return_config=True, ) with open_dict(cfg): @@ -366,6 +383,7 @@ def create_neva_model_and_processor(cfg): neva_cfg.activations_checkpoint_method = None neva_cfg.precision = trainer.precision neva_cfg.mm_cfg.llm.from_pretrained = cfg.get('base_model_file', None) + neva_cfg.apply_rope_fusion = False neva_cfg.fp8 = False # neva_cfg.mm_cfg.vision_encoder.from_pretrained = None diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py index d477b337cd29..f20035ad5738 100644 --- a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py +++ b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py @@ -61,9 +61,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): assert ( self.cfg.get("post_process", False) is False ), "post_process must be False to get hidden states in the loss_func" - assert ( - self.cfg.get('apply_rope_fusion', True) is False - ), "RoPE fusion should be set to False for MegatronGPTEmbeddingModel" def model_provider_func(self, pre_process, post_process): # (@adithyare) We need post_process to be False to get hidden states in the loss_func @@ -255,7 +252,14 @@ def gather_and_maybe_write_predictions(self, output, data_cfg, mode, averaged_me gathered_output_batches = [None for _ in range(parallel_state.get_data_parallel_world_size())] torch.distributed.all_gather_object( gathered_output_batches, - [{'q_hs': batch['q_hs'], 'd_hs': batch['d_hs'], 'metadata': batch['metadata'],} for batch in output], + [ + { + 'q_hs': batch['q_hs'], + 'd_hs': batch['d_hs'], + 'metadata': batch['metadata'], + } + for batch in output + ], group=parallel_state.get_data_parallel_group(), ) @@ -272,7 +276,11 @@ def gather_and_maybe_write_predictions(self, output, data_cfg, mode, averaged_me l_d_hs = listify(batch['d_hs']) l_m = batch['metadata'] assert len(l_m) == len(l_q_hs) == len(l_d_hs) - for q_hs, d_hs, metadata in zip(l_q_hs, l_d_hs, l_m,): + for q_hs, d_hs, metadata in zip( + l_q_hs, + l_d_hs, + l_m, + ): total_size += 1 if not metadata.get("__AUTOGENERATED__", False): deduplicated_outputs['q_hs'].append(q_hs) @@ -326,10 +334,10 @@ def write_embeddings_to_file(self, outputs, output_file_path, d_idx): def local_validation_step(self, dataloader_iter): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. """ # Check if iterator is exhausted # dataloader_iter, done = self._val_iterator_done(dataloader_iter) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 8b7c7a38045c..a27f9fd5e5e4 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -246,12 +246,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True): self.use_fsdp = cfg.get('fsdp', False) def setup_transformer_engine_tp_groups(self): - """ This should be called after model parallel groups have been initialized - and only needs to be called when using Transformer Engine. + """This should be called after model parallel groups have been initialized + and only needs to be called when using Transformer Engine. """ for module in self.get_model_module_list(): """Set TP group - Copied from: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py#L398 + Copied from: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py#L398 """ # Deep iterate but skip self to avoid infinite recursion. for index, child in enumerate(module.modules()): @@ -262,14 +262,14 @@ def setup_transformer_engine_tp_groups(self): child.set_tensor_parallel_group(tp_group) def setup_transformer_engine_cp_groups(self): - """ This should be called after context parallel groups have been initialized - and only needs to be called when using Transformer Engine. + """This should be called after context parallel groups have been initialized + and only needs to be called when using Transformer Engine. """ cp_stream = torch.cuda.Stream() for module in self.get_model_module_list(): """Set context parallel running - Copied from: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py + Copied from: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py """ # Deep iterate but skip self to avoid infinite recursion. for index, child in enumerate(module.modules()): @@ -283,11 +283,11 @@ def setup_transformer_engine_cp_groups(self): ) def _wrap_model_for_O2(self): - """ Wraps self.model in a float16 wrapper if the model is using megatron amp O2. - Args: - model: The model to wrap. Can be a list of modules or a single module. - Returns: - The wrapped model. Returns a list of wrapped modules or a single wrapped module. + """Wraps self.model in a float16 wrapper if the model is using megatron amp O2. + Args: + model: The model to wrap. Can be a list of modules or a single module. + Returns: + The wrapped model. Returns a list of wrapped modules or a single wrapped module. """ is_mcore_model = self.__dict__.get('mcore_gpt', False) or self.__dict__.get('mcore_bert', False) @@ -450,10 +450,10 @@ def on_validation_end(self) -> None: gc.collect() def build_transformer_config(self) -> TransformerConfig: - """ Builds the megatron core transformer config for the model. - For attributes in the nemo model config that are the same - as the megatron core TransformerConfig, we will use the value from the nemo model config. - For attributes in TransformerConfig that are not in the nemo model config, we add custom logic. + """Builds the megatron core transformer config for the model. + For attributes in the nemo model config that are the same + as the megatron core TransformerConfig, we will use the value from the nemo model config. + For attributes in TransformerConfig that are not in the nemo model config, we add custom logic. """ # create a dictionary copy of the model config @@ -509,7 +509,7 @@ def build_transformer_config(self) -> TransformerConfig: bias_dropout_fusion = self.cfg.get('bias_dropout_add_fusion', True) - apply_rope_fusion = self.cfg.get('apply_rope_fusion', True) + apply_rope_fusion = self.cfg.get('apply_rope_fusion', False) # TODO: need to check if recompute APIs are matching up properly recompute_granularity = self.cfg.get('activations_checkpoint_granularity', None) @@ -601,7 +601,7 @@ def get_parameters_with_grad(self): def configure_gradient_clipping(self, *args, **kwargs): """PTL hook to configure gradients. - We use gradient clipping implementation from megatron-lm. + We use gradient clipping implementation from megatron-lm. """ clip_val = self.trainer.gradient_clip_val if clip_val is None: @@ -627,13 +627,17 @@ def configure_gradient_clipping(self, *args, **kwargs): parameters = self._optimizer.get_parameters_with_grad() else: parameters = self.get_parameters_with_grad() - grad_norm = clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val, use_fsdp=self.use_fsdp,) + grad_norm = clip_grad_norm_fp32( + parameters=parameters, + max_norm=clip_val, + use_fsdp=self.use_fsdp, + ) self.log('grad_norm', grad_norm, rank_zero_only=True, batch_size=1) def allreduce_gradients(self): """Reduce gradients across data parallel ranks. - Modified from megatron-lm: https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/model/distributed.py#L188 + Modified from megatron-lm: https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/model/distributed.py#L188 """ # Bucketize and all-reduce buckets = {} @@ -732,7 +736,9 @@ def on_validation_batch_end(self, outputs, batch: Any, batch_idx: int, dataloade self.validation_global_step += 1 def setup_optimization( - self, optim_config: Optional[Union[DictConfig, Dict]] = None, optim_kwargs: Optional[Dict[str, Any]] = None, + self, + optim_config: Optional[Union[DictConfig, Dict]] = None, + optim_kwargs: Optional[Dict[str, Any]] = None, ): # Ensure `max_steps` is set correctly optim_config = self._optim_config_copy(optim_config) @@ -913,8 +919,8 @@ def _extract_consumed_samples_from_ckpt(self, ckpt_path): return init_consumed_samples def _validate_and_override_config(self): - """ Certain configurations might be incompatible or discouraged. - We can check for them here and override if necessary. + """Certain configurations might be incompatible or discouraged. + We can check for them here and override if necessary. """ app_state = AppState() @@ -1093,9 +1099,9 @@ def _get_total_params_across_model_parallel_groups_enc_dec(self, model): return num_parameters_on_device, total_num_parameters def build_model_parallel_config(self) -> ModelParallelConfig: - """ For attributes in the nemo model config that are the same as the - megatron core ModelParallelConfig we will use the value from the nemo config. - For attributes in ModelParallelConfig that are not in the nemo model config, we add custom logic. + """For attributes in the nemo model config that are the same as the + megatron core ModelParallelConfig we will use the value from the nemo config. + For attributes in ModelParallelConfig that are not in the nemo model config, we add custom logic. """ cfg = OmegaConf.to_container(self.cfg, resolve=True) @@ -1116,9 +1122,9 @@ def build_model_parallel_config(self) -> ModelParallelConfig: "async_tensor_model_parallel_allreduce": self.cfg.get('tensor_model_parallel_world_size', 1) > 1 and not self.cfg.get('sequence_parallel', False), "pipeline_dtype": pipeline_dtype, - "grad_scale_func": self.trainer.precision_plugin.scaler.scale - if self.trainer.precision in ["16", "16-mixed"] - else None, + "grad_scale_func": ( + self.trainer.precision_plugin.scaler.scale if self.trainer.precision in ["16", "16-mixed"] else None + ), "enable_autocast": not megatron_amp_O2 and self.torch_dtype in [torch.bfloat16, torch.float16], "autocast_dtype": self.autocast_dtype, "variable_seq_lengths": False, # set dynamically during training @@ -1230,7 +1236,7 @@ def find_frozen_submodules(model): return frozen_submodule_names, frozen_submodules if self.use_fsdp: - """ Top-evel FSDP model sharding """ + """Top-evel FSDP model sharding""" # Shard the top-level model hierarchically. We shard the strategy-unwrapped model not # to lose the structure of non-FSDP wrapped parameters (e.g, embedding) # TODO: Currently the main parameter data type is kept in fp32 (when O2=False). This needs to be diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index 32bd7e6c1154..50ad38978476 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -136,7 +136,8 @@ def hook(*unused): self._grad_copy(param) if self.overlap_grad_sync and not getattr(param, '_disable_overlap_grad_sync', False): self._try_start_bucket_grad_sync( - params=[param], ignore_last_bucket=need_to_initialize, + params=[param], + ignore_last_bucket=need_to_initialize, ) return hook @@ -167,10 +168,14 @@ def init_params( # Initialize FP8 and non-FP8 tensors separately if any(is_float8tensor(param) for param in params): super().init_params( - filter(is_float8tensor, params), param_sync_dtype=torch.uint8, **kwargs, + filter(is_float8tensor, params), + param_sync_dtype=torch.uint8, + **kwargs, ) super().init_params( - params, param_sync_dtype=param_sync_dtype, **kwargs, + params, + param_sync_dtype=param_sync_dtype, + **kwargs, ) def init_params_bucket( @@ -200,7 +205,10 @@ def init_params_bucket( params = remaining_params start_bucket_id = len(self.state["buckets"]) super().init_params_bucket( - fp32_params, grad_sync_dtype=torch.float32, param_sync_dtype=param_sync_dtype, **kwargs, + fp32_params, + grad_sync_dtype=torch.float32, + param_sync_dtype=param_sync_dtype, + **kwargs, ) end_bucket_id = len(self.state["buckets"]) fp32_buckets = self.state["buckets"][start_bucket_id:end_bucket_id] @@ -216,7 +224,10 @@ def init_params_bucket( params = remaining_params start_bucket_id = len(self.state["buckets"]) super().init_params_bucket( - fp8_params, grad_sync_dtype=grad_sync_dtype, param_sync_dtype=torch.uint8, **kwargs, + fp8_params, + grad_sync_dtype=grad_sync_dtype, + param_sync_dtype=torch.uint8, + **kwargs, ) end_bucket_id = len(self.state["buckets"]) fp8_buckets = self.state["buckets"][start_bucket_id:end_bucket_id] @@ -225,12 +236,18 @@ def init_params_bucket( normal_buckets = [] start_bucket_id = len(self.state["buckets"]) super().init_params_bucket( - params, grad_sync_dtype=grad_sync_dtype, param_sync_dtype=param_sync_dtype, **kwargs, + params, + grad_sync_dtype=grad_sync_dtype, + param_sync_dtype=param_sync_dtype, + **kwargs, ) end_bucket_id = len(self.state["buckets"]) normal_buckets = self.state["buckets"][start_bucket_id:end_bucket_id] - def add_param_to_bucket(param: torch.nn.Parameter, bucket: self.StateBucket,) -> None: + def add_param_to_bucket( + param: torch.nn.Parameter, + bucket: self.StateBucket, + ) -> None: """Add trivial param fragment to bucket""" param_fragments = self.state[param]["fragments"] param_group_id = param_fragments[0].param_group_id @@ -283,7 +300,11 @@ def _init_param_state( # Initialize non-FP8 params as usual if not is_float8tensor(param): super()._init_param_state( - param, param_group_id, param_id, param_sync_dtype=param_sync_dtype, **kwargs, + param, + param_group_id, + param_id, + param_sync_dtype=param_sync_dtype, + **kwargs, ) # Return immediately if already initialized @@ -293,7 +314,11 @@ def _init_param_state( # Initialize with FP32 copy of param fp32_param = param.float() super()._init_param_state( - fp32_param, param_group_id, param_id, param_sync_dtype=torch.uint8, **kwargs, + fp32_param, + param_group_id, + param_id, + param_sync_dtype=torch.uint8, + **kwargs, ) self.state[param].update(self.state[fp32_param]) del self.state[fp32_param] @@ -360,7 +385,9 @@ def init_param_buffer(self) -> None: # Copy values into param buffer _multi_tensor_copy( - param_flat_views, param_buffer_views, dummy_overflow_buf=self._dummy_overflow_buf, + param_flat_views, + param_buffer_views, + dummy_overflow_buf=self._dummy_overflow_buf, ) # Make all params a view into the param buffer @@ -393,7 +420,10 @@ def zero_grad(self, *args, **kwargs) -> None: param.main_grad = self.grad_buffer_view(param) def grad_norm( - self, parameters: Optional[Iterable[torch.nn.Parameter]] = None, norm_type: float = 2.0, force: bool = False, + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None, + norm_type: float = 2.0, + force: bool = False, ) -> torch.Tensor: assert norm_type == 2 @@ -411,7 +441,8 @@ def grad_norm( # Sum over all procs to get grad norm torch.distributed.all_reduce( - grad_norm_sq, op=torch.distributed.ReduceOp.SUM, + grad_norm_sq, + op=torch.distributed.ReduceOp.SUM, ) self._grad_norm = grad_norm_sq.sqrt() @@ -479,7 +510,9 @@ def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.Paramet # Copy data from parameter buckets to parameters _multi_tensor_copy( - buffers_in, buffers_out, dummy_overflow_buf=self._dummy_overflow_buf, + buffers_in, + buffers_out, + dummy_overflow_buf=self._dummy_overflow_buf, ) # Update transpose caches @@ -487,7 +520,7 @@ def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.Paramet for param in params: if is_float8tensor(param): param._reset_caches() - param.transpose(update_cache=True) + param.transpose_2d(cache=True) param._lazy_transpose_cache = True @torch.no_grad() @@ -570,11 +603,15 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA packed_scales = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device) packed_scale_views = [packed_scales[i].view(1) for i in range(num_fp8_params)] _multi_tensor_copy( - scales, packed_scale_views, dummy_overflow_buf=self._dummy_overflow_buf, + scales, + packed_scale_views, + dummy_overflow_buf=self._dummy_overflow_buf, ) torch.reciprocal(packed_scales, out=packed_scales) _multi_tensor_copy( - packed_scale_views, scale_invs, dummy_overflow_buf=self._dummy_overflow_buf, + packed_scale_views, + scale_invs, + dummy_overflow_buf=self._dummy_overflow_buf, ) # Reduce amaxes @@ -582,13 +619,19 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA packed_amaxes = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device) packed_amax_views = [packed_amaxes[i].view(1) for i in range(num_fp8_params)] _multi_tensor_copy( - amaxes, packed_amax_views, dummy_overflow_buf=self._dummy_overflow_buf, + amaxes, + packed_amax_views, + dummy_overflow_buf=self._dummy_overflow_buf, ) torch.distributed.all_reduce( - packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.distributed_process_group, + packed_amaxes, + op=torch.distributed.ReduceOp.MAX, + group=self.distributed_process_group, ) _multi_tensor_copy( - packed_amax_views, amaxes, dummy_overflow_buf=self._dummy_overflow_buf, + packed_amax_views, + amaxes, + dummy_overflow_buf=self._dummy_overflow_buf, ) # Reset @@ -602,7 +645,8 @@ def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None optimizer_state_dict = self.state_dict() id_to_sharded_param_map = get_param_id_to_sharded_param_map( - model_sharded_state_dict=model_sharded_state_dict, optim_params_iter=self.parameters(), + model_sharded_state_dict=model_sharded_state_dict, + optim_params_iter=self.parameters(), ) # Convert state step = optimizer_state_dict['state'].pop('step') diff --git a/tests/collections/asr/decoding/test_ctc_decoding.py b/tests/collections/asr/decoding/test_ctc_decoding.py index 8eceb822fd38..ea2cdea58119 100644 --- a/tests/collections/asr/decoding/test_ctc_decoding.py +++ b/tests/collections/asr/decoding/test_ctc_decoding.py @@ -90,7 +90,9 @@ def test_constructor_subword(self, tmp_tokenizer): assert decoding is not None @pytest.mark.unit - def test_char_decoding_greedy_forward(self,): + def test_char_decoding_greedy_forward( + self, + ): cfg = CTCDecodingConfig(strategy='greedy') vocab = char_vocabulary() decoding = CTCDecoding(decoding_cfg=cfg, vocabulary=vocab) @@ -206,7 +208,7 @@ def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps, ) unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer) - cfg.strategy = 'greedy_batched' + cfg.strategy = 'greedy_batch' batched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer) torch.manual_seed(1) @@ -243,7 +245,7 @@ def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps, def test_batched_decoding_labels(self, tmp_tokenizer, timestamps): cfg = CTCBPEDecodingConfig(strategy='greedy', compute_timestamps=timestamps) unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer) - cfg.strategy = 'greedy_batched' + cfg.strategy = 'greedy_batch' batched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer) torch.manual_seed(1) diff --git a/tests/collections/asr/mixins/test_transcription.py b/tests/collections/asr/mixins/test_transcription.py index 794213c72397..1a6f38681d0c 100644 --- a/tests/collections/asr/mixins/test_transcription.py +++ b/tests/collections/asr/mixins/test_transcription.py @@ -22,6 +22,7 @@ import torch from torch.utils.data import DataLoader, Dataset +from nemo.collections.asr.data.audio_to_text import _speech_collate_fn from nemo.collections.asr.models import ASRModel from nemo.collections.asr.parts.mixins import TranscribeConfig, TranscriptionMixin from nemo.collections.asr.parts.mixins.transcription import GenericTranscriptionType @@ -121,6 +122,27 @@ def _transcribe_on_end(self, trcfg: TranscribeConfig): self.flag_end = True +class DummyDataset(Dataset): + def __init__(self, audio_tensors: List[str], config: Dict = None): + self.audio_tensors = audio_tensors + self.config = config + + def __getitem__(self, index): + data = self.audio_tensors[index] + samples = torch.tensor(data) + # Calculate seq length + seq_len = torch.tensor(samples.shape[0], dtype=torch.long) + + # Dummy text tokens + text_tokens = torch.tensor([0], dtype=torch.long) + text_tokens_len = torch.tensor(1, dtype=torch.long) + + return (samples, seq_len, text_tokens, text_tokens_len) + + def __len__(self): + return len(self.audio_tensors) + + @pytest.fixture() def dummy_model(): return TranscribableDummy() @@ -326,3 +348,27 @@ def test_transcribe_multiple_tensor(self, test_data_dir): assert len(outputs) == 2 assert isinstance(outputs[0], str) assert isinstance(outputs[1], str) + + @pytest.mark.with_downloads() + @pytest.mark.unit + def test_transcribe_dataloader(self, test_data_dir): + model = ASRModel.from_pretrained("stt_en_conformer_ctc_small") + + # Load audio file + import soundfile as sf + + audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav") + audio, sr = sf.read(audio_file, dtype='float32') + + audio_file2 = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an152-mwhw-b.wav") + audio2, sr = sf.read(audio_file2, dtype='float32') + + dataset = DummyDataset([audio, audio2]) + collate_fn = lambda x: _speech_collate_fn(x, pad_id=0) + dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0, collate_fn=collate_fn) + + # DataLoader test + outputs = model.transcribe(dataloader, batch_size=1) + assert len(outputs) == 2 + assert isinstance(outputs[0], str) + assert isinstance(outputs[1], str) diff --git a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py index 2005c0e8d41c..0d7c555ee778 100644 --- a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py +++ b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py @@ -269,7 +269,7 @@ def test_vocab_change(self, test_data_dir, asr_model): def test_decoding_change(self, asr_model): assert asr_model.decoding is not None assert isinstance(asr_model.decoding, CTCBPEDecoding) - assert asr_model.decoding.cfg.strategy == "greedy_batched" + assert asr_model.decoding.cfg.strategy == "greedy_batch" assert asr_model.decoding.preserve_alignments is False assert asr_model.decoding.compute_timestamps is False @@ -309,7 +309,10 @@ def test_ASRDatasetConfig_for_AudioToBPEDataset(self): REMAP_ARGS = {'trim_silence': 'trim', 'labels': 'tokenizer'} result = assert_dataclass_signature_match( - audio_to_text.AudioToBPEDataset, configs.ASRDatasetConfig, ignore_args=IGNORE_ARGS, remap_args=REMAP_ARGS, + audio_to_text.AudioToBPEDataset, + configs.ASRDatasetConfig, + ignore_args=IGNORE_ARGS, + remap_args=REMAP_ARGS, ) signatures_match, cls_subset, dataclass_subset = result diff --git a/tests/collections/asr/test_asr_ctcencdec_model.py b/tests/collections/asr/test_asr_ctcencdec_model.py index d2587913b879..28a07fd54663 100644 --- a/tests/collections/asr/test_asr_ctcencdec_model.py +++ b/tests/collections/asr/test_asr_ctcencdec_model.py @@ -150,7 +150,7 @@ def test_vocab_change(self, asr_model): def test_decoding_change(self, asr_model): assert asr_model.decoding is not None assert isinstance(asr_model.decoding, CTCDecoding) - assert asr_model.decoding.cfg.strategy == "greedy_batched" + assert asr_model.decoding.cfg.strategy == "greedy_batch" assert asr_model.decoding.preserve_alignments is False assert asr_model.decoding.compute_timestamps is False @@ -279,7 +279,10 @@ def test_ASRDatasetConfig_for_AudioToCharDataset(self): REMAP_ARGS = {'trim_silence': 'trim'} result = assert_dataclass_signature_match( - audio_to_text.AudioToCharDataset, configs.ASRDatasetConfig, ignore_args=IGNORE_ARGS, remap_args=REMAP_ARGS, + audio_to_text.AudioToCharDataset, + configs.ASRDatasetConfig, + ignore_args=IGNORE_ARGS, + remap_args=REMAP_ARGS, ) signatures_match, cls_subset, dataclass_subset = result diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py index 994d832ec6e5..1743acc6878c 100644 --- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py +++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py @@ -64,12 +64,18 @@ def hybrid_asr_model(test_data_dir): decoder = { '_target_': 'nemo.collections.asr.modules.RNNTDecoder', - 'prednet': {'pred_hidden': model_defaults['pred_hidden'], 'pred_rnn_layers': 1,}, + 'prednet': { + 'pred_hidden': model_defaults['pred_hidden'], + 'pred_rnn_layers': 1, + }, } joint = { '_target_': 'nemo.collections.asr.modules.RNNTJoint', - 'jointnet': {'joint_hidden': 32, 'activation': 'relu',}, + 'jointnet': { + 'joint_hidden': 32, + 'activation': 'relu', + }, } decoding = {'strategy': 'greedy_batch', 'greedy': {'max_symbols': 30}} @@ -111,7 +117,8 @@ def hybrid_asr_model(test_data_dir): class TestEncDecHybridRNNTCTCBPEModel: @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.with_downloads() @pytest.mark.unit @@ -125,7 +132,8 @@ def test_constructor(self, hybrid_asr_model): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_forward(self, hybrid_asr_model): @@ -160,7 +168,8 @@ def test_forward(self, hybrid_asr_model): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_save_restore_artifact(self, hybrid_asr_model): @@ -178,7 +187,8 @@ def test_save_restore_artifact(self, hybrid_asr_model): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_save_restore_artifact_spe(self, hybrid_asr_model, test_data_dir): @@ -224,7 +234,8 @@ def test_save_restore_artifact_agg(self, hybrid_asr_model, test_data_dir): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_vocab_change(self, test_data_dir, hybrid_asr_model): @@ -255,7 +266,8 @@ def test_vocab_change(self, test_data_dir, hybrid_asr_model): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_decoding_change(self, hybrid_asr_model): @@ -297,7 +309,7 @@ def test_decoding_change(self, hybrid_asr_model): assert hybrid_asr_model.ctc_decoding is not None assert isinstance(hybrid_asr_model.ctc_decoding, CTCBPEDecoding) - assert hybrid_asr_model.ctc_decoding.cfg.strategy == "greedy_batched" + assert hybrid_asr_model.ctc_decoding.cfg.strategy == "greedy_batch" assert hybrid_asr_model.ctc_decoding.preserve_alignments is False assert hybrid_asr_model.ctc_decoding.compute_timestamps is False @@ -309,7 +321,8 @@ def test_decoding_change(self, hybrid_asr_model): assert hybrid_asr_model.cur_decoder == "ctc" @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_decoding_type_change(self, hybrid_asr_model): diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py index 923263787def..a0d5627f1a65 100644 --- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py +++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py @@ -117,7 +117,8 @@ def hybrid_asr_model(): class TestEncDecHybridRNNTCTCModel: @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_constructor(self, hybrid_asr_model): @@ -129,7 +130,8 @@ def test_constructor(self, hybrid_asr_model): assert isinstance(instance2, EncDecHybridRNNTCTCModel) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_forward(self, hybrid_asr_model): @@ -163,7 +165,8 @@ def test_forward(self, hybrid_asr_model): assert diff <= 1e-6 @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_vocab_change(self, hybrid_asr_model): @@ -186,10 +189,12 @@ def test_vocab_change(self, hybrid_asr_model): assert hybrid_asr_model.ctc_decoder.vocabulary == hybrid_asr_model.joint.vocabulary @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_decoding_change(self, hybrid_asr_model): @@ -231,7 +236,7 @@ def test_decoding_change(self, hybrid_asr_model): assert hybrid_asr_model.ctc_decoding is not None assert isinstance(hybrid_asr_model.ctc_decoding, CTCDecoding) - assert hybrid_asr_model.ctc_decoding.cfg.strategy == "greedy_batched" + assert hybrid_asr_model.ctc_decoding.cfg.strategy == "greedy_batch" assert hybrid_asr_model.ctc_decoding.preserve_alignments is False assert hybrid_asr_model.ctc_decoding.compute_timestamps is False @@ -242,7 +247,8 @@ def test_decoding_change(self, hybrid_asr_model): assert hybrid_asr_model.ctc_decoding.compute_timestamps is True @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_decoding_type_change(self, hybrid_asr_model): @@ -306,7 +312,8 @@ def test_BeamRNNTInferConfig(self): assert dataclass_subset is None @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( @@ -349,11 +356,13 @@ def test_greedy_decoding(self, greedy_class, loop_labels: Optional[bool]): _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( - "greedy_class", [greedy_decode.GreedyRNNTInfer], + "greedy_class", + [greedy_decode.GreedyRNNTInfer], ) def test_greedy_multi_decoding(self, greedy_class): token_list = [" ", "a", "b", "c"] @@ -386,7 +395,8 @@ def test_greedy_multi_decoding(self, greedy_class): _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len, partial_hypotheses=partial_hyp) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( @@ -430,11 +440,13 @@ def test_greedy_decoding_stateless_decoder(self, greedy_class, loop_labels: Opti _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( - "greedy_class", [greedy_decode.GreedyRNNTInfer], + "greedy_class", + [greedy_decode.GreedyRNNTInfer], ) def test_greedy_multi_decoding_stateless_decoder(self, greedy_class): token_list = [" ", "a", "b", "c"] @@ -467,7 +479,8 @@ def test_greedy_multi_decoding_stateless_decoder(self, greedy_class): _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len, partial_hypotheses=partial_hyp) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( @@ -522,7 +535,8 @@ def test_greedy_decoding_preserve_alignment(self, greedy_class, loop_labels: Opt assert torch.is_tensor(label) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( @@ -556,7 +570,12 @@ def test_beam_decoding(self, beam_config): decoder = RNNTDecoder(prednet_cfg, vocab_size) joint_net = RNNTJoint(jointnet_cfg, vocab_size, vocabulary=token_list) - beam = beam_decode.BeamRNNTInfer(decoder, joint_net, beam_size=beam_size, **beam_config,) + beam = beam_decode.BeamRNNTInfer( + decoder, + joint_net, + beam_size=beam_size, + **beam_config, + ) # (B, D, T) enc_out = torch.randn(1, encoder_output_size, 30) @@ -566,12 +585,16 @@ def test_beam_decoding(self, beam_config): _ = beam(encoder_output=enc_out, encoded_lengths=enc_len) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( "beam_config", - [{"search_type": "greedy"}, {"search_type": "default", "score_norm": False, "return_best_hypothesis": False},], + [ + {"search_type": "greedy"}, + {"search_type": "default", "score_norm": False, "return_best_hypothesis": False}, + ], ) def test_beam_decoding_preserve_alignments(self, beam_config): token_list = [" ", "a", "b", "c"] @@ -616,7 +639,8 @@ def test_beam_decoding_preserve_alignments(self, beam_config): assert torch.is_tensor(label) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( @@ -659,7 +683,8 @@ def test_greedy_decoding_SampledRNNTJoint(self, greedy_class, loop_labels: Optio _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( @@ -693,7 +718,12 @@ def test_beam_decoding_SampledRNNTJoint(self, beam_config): decoder = RNNTDecoder(prednet_cfg, vocab_size) joint_net = SampledRNNTJoint(jointnet_cfg, vocab_size, n_samples=2, vocabulary=token_list) - beam = beam_decode.BeamRNNTInfer(decoder, joint_net, beam_size=beam_size, **beam_config,) + beam = beam_decode.BeamRNNTInfer( + decoder, + joint_net, + beam_size=beam_size, + **beam_config, + ) # (B, D, T) enc_out = torch.randn(1, encoder_output_size, 30) diff --git a/tutorials/asr/ASR_Context_Biasing.ipynb b/tutorials/asr/ASR_Context_Biasing.ipynb index c632205311c0..75385234ce29 100644 --- a/tutorials/asr/ASR_Context_Biasing.ipynb +++ b/tutorials/asr/ASR_Context_Biasing.ipynb @@ -259,6 +259,7 @@ "execution_count": null, "id": "d34ee0ba", "metadata": { + "collapsed": true, "jupyter": { "outputs_hidden": true }, @@ -717,6 +718,28 @@ "The context graph consists of a composition of a prefix tree (Trie) with the CTC transition topology for words and phrases from the context-biasing list. We use a BPE tokenizer from the target ASR model for word segmentation." ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "55a36a27-919c-4d64-9163-b0b2c9dca15e", + "metadata": {}, + "outputs": [], + "source": [ + "# install graphviz from source in case of local run (not Google Colab)\n", + "# this may take about 5-10 minutes\n", + "# make sure that env variables have been set\n", + "\n", + "if not IN_COLAB:\n", + "\n", + " os.environ['DEBIAN_FRONTEND'] = 'noninteractive'\n", + " os.environ['TZ'] = 'Etc/UTC'\n", + "\n", + " !echo $DEBIAN_FRONTEND\n", + " !echo $TZ\n", + "\n", + " !{NEMO_DIR_PATH}/scripts/installers/install_graphviz.sh" + ] + }, { "cell_type": "code", "execution_count": null, @@ -750,23 +773,6 @@ "context_graph.draw()" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "e1c57878", - "metadata": {}, - "outputs": [], - "source": [ - "# install graphviz from source if you have problems with graph picture\n", - "# set instal_graphviz = True\n", - "# this may take about 5-10 minutes\n", - "\n", - "instal_graphviz = False\n", - "\n", - "if instal_graphviz:\n", - " !{NEMO_DIR_PATH}/scripts/installers/install_graphviz.sh" - ] - }, { "cell_type": "markdown", "id": "04a6f4be", diff --git a/tutorials/asr/Online_Offline_Microphone_VAD_Demo.ipynb b/tutorials/asr/Online_Offline_Microphone_VAD_Demo.ipynb index b4d1ddf825d4..97e8d273f2fd 100644 --- a/tutorials/asr/Online_Offline_Microphone_VAD_Demo.ipynb +++ b/tutorials/asr/Online_Offline_Microphone_VAD_Demo.ipynb @@ -638,7 +638,7 @@ " ax2.set_ylabel('Preds and Probas')\n", " \n", " \n", - "ax = plt.subplot(num+1,1,i+2)\n", + "ax = plt.subplot(num+1,1,num+1)\n", "S = librosa.feature.melspectrogram(y=audio, sr=sample_rate, n_mels=64, fmax=8000)\n", "S_dB = librosa.power_to_db(S, ref=np.max)\n", "librosa.display.specshow(S_dB, x_axis='time', y_axis='mel', sr=sample_rate, fmax=8000)\n",