diff --git a/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml b/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml new file mode 100644 index 000000000000..76ce1c0150c7 --- /dev/null +++ b/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml @@ -0,0 +1,277 @@ +# It contains the default values for training an autoregressive FastConformer-Transformer AED model with sub-word encoding. + +# Architecture and training config: +# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective +# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches. +# Here are the recommended configs for different variants of FastConformer-Transformer, other parameters are the same as in this config file. +# One extra (linear projection) layer is added between FastConformer encoder and Transformer decoder if they have different hidden sizes +# It is recommended to initialize FastConformer with ASR/SSL pre-trained encoder for better accuracy and faster convergence + +name: "FastConformer-Transformer-MultiTask" + +# Note: for larger models (1B+ params) initializing from a pretrained encoder +# may help (or even be required to) stabilize the training. +init_from_nemo_model: null + +model: + _target_: nemo.collections.asr.models.EncDecMultiTaskModel + sample_rate: 16000 + label_smoothing: 0.0 + context_len_for_AR_decoding: 5 # Length of input prompt tokens. For example, in Canary models, we use [BOS,src_lang,task,tgt_lang,pnc] and thus the length is 5 + log_prediction: true # enables logging sample predictions in the output during training + + # Important ! Set the prompt format to the class you need + prompt_format: ??? # Options supported: ["canary"] + + model_defaults: + asr_enc_hidden: 1024 + lm_enc_hidden: 512 + lm_dec_hidden: 1024 + + train_ds: + use_lhotse: true + tarred_audio_filepaths: null + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + shuffle: true + num_workers: 8 + # To understand the settings below, please refer to Lhotse Dataloading documentation: + # https://github.com/NVIDIA/NeMo/blob/main/docs/source/asr/datasets.rst#lhotse-dataloading + # You can also check the following configuration dataclass: + # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/data/lhotse/dataloader.py#L36 + batch_size: None + batch_duration: 360 + quadratic_duration: 15 + use_bucketing: True + num_buckets: 20 + bucket_buffer_size: 20000 + shuffle_buffer_size: 10000 + + validation_ds: + use_lhotse: true + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 8 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 4 + pin_memory: true + use_start_end_token: true + use_bucketing: false + + test_ds: + use_lhotse: true + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 8 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 4 + pin_memory: true + use_start_end_token: true + use_bucketing: false + + # recommend small vocab size of 128 or 256 when using 4x sub-sampling + # you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py + tokenizer: + dir: null # Null for aggregate tokenizers + type: agg # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) or `agg` for aggregate tokenizers + langs: + spl_tokens: # special tokens model + dir: ??? 
+ type: bpe + en: # English tokenizer (example, replace with whichever language you would like or add tokenizers to add tokenizer for additional languages) + dir: ??? + type: bpe + + custom_tokenizer: + _target_: nemo.collections.common.tokenizers.canary_tokenizer.CanaryTokenizer # Can be replaced with other tokenizer for different prompt formats + tokenizers: null # Filled at runtime by all the tokenizers inside the aggregate tokenizer + + # Audio Preprocessor + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 0.00001 + pad_to: 0 + pad_value: 0.0 + + # SpecAugment is applied either in the model or in the data layer + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + # you may use lower time_masks for smaller models to have a faster convergence + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + # FastConformer Encoder + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 24 + d_model: ${model.model_defaults.asr_enc_hidden} + + # Sub-sampling params + subsampling: dw_striding # vggnet or striding, vggnet may give better results but needs more memory + subsampling_factor: 8 # must be power of 2 + subsampling_conv_channels: 256 # -1 sets it to d_model + causal_downsampling: false + reduction: null + reduction_position: null + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + xscaling: false # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: batch_norm + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # Optional Transformer Encoder sandwitched between ASR Encoder and Transformer Ddcoder. 
+ # Only used if num_layers > 0 + transf_encoder: + _target_: nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder + num_layers: 0 + hidden_size: ${model.model_defaults.lm_enc_hidden} + inner_size: ${multiply:${model.model_defaults.lm_enc_hidden}, 4} + num_attention_heads: 8 + ffn_dropout: 0.1 + attn_score_dropout: 0.1 + attn_layer_dropout: 0.1 + mask_future: False + pre_ln: True + pre_ln_final_layer_norm: True + + transf_decoder: + _target_: nemo.collections.asr.modules.transformer.get_nemo_transformer + model_name: null + pretrained: false + encoder: null + pre_ln_final_layer_norm: true + + config_dict: + max_sequence_length: 512 + num_token_types: 0 + embedding_dropout: 0.1 + learn_positional_encodings: false + hidden_size: ${model.model_defaults.lm_dec_hidden} + inner_size: ${multiply:${model.model_defaults.lm_dec_hidden}, 4} + num_layers: 24 + num_attention_heads: 8 + ffn_dropout: 0.1 + attn_score_dropout: 0.1 + attn_layer_dropout: 0.1 + hidden_act: relu + pre_ln: true + vocab_size: None # Will be set by the model at runtime + + # Label Prediction Head (Token Classifier) + head: + _target_: nemo.collections.asr.parts.submodules.token_classifier.TokenClassifier + num_layers: 1 + activation: relu + log_softmax: true + hidden_size: ${model.transf_decoder.config_dict.hidden_size} + num_classes: None # Will be set by the model at runtime + dropout: 0.0 + use_transformer_init: true + + # Decoding Strategy + decoding: + strategy: beam + return_best_hypothesis: true # Returns the most probably hypothesis after beam search + + beam: + beam_size: 1 + len_pen: 0.0 + max_generation_delta: 50 + + # Loss Config + loss: + _target_: nemo.collections.common.losses.smoothed_cross_entropy.SmoothedCrossEntropyLoss + label_smoothing: ${model.label_smoothing} + pad_id: null + + optim: + name: adamw + lr: 3e-4 + # optimizer arguments + betas: [0.9, 0.98] + # less necessity for weight_decay as we already have large augmentations with SpecAug + # you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used + # weight decay of 0.0 with lr of 2.0 also works fine + weight_decay: 1e-3 + + # scheduler setup + sched: + name: InverseSquareRootAnnealing + # scheduler config override + warmup_steps: 2500 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: 100000 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + precision: 16 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 100 # Interval of logging. 
+ enable_progress_bar: True + num_sanity_val_steps: 2 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_sacreBLEU" + mode: "max" + save_top_k: 3 + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + # you need to set these two to True to continue the training + resume_if_exists: true + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/asr/conf/speech_translation/fast-conformer_transformer.yaml b/examples/asr/conf/speech_translation/fast-conformer_transformer.yaml index a1610364f526..6b6dbf129c54 100644 --- a/examples/asr/conf/speech_translation/fast-conformer_transformer.yaml +++ b/examples/asr/conf/speech_translation/fast-conformer_transformer.yaml @@ -176,7 +176,7 @@ model: min_lr: 1e-6 trainer: - gpus: -1 # number of GPUs, -1 would use all available GPUs + devices: -1 # number of GPUs, -1 would use all available GPUs num_nodes: 1 max_epochs: 100 max_steps: -1 # computed at runtime if not set diff --git a/examples/asr/speech_multitask/speech_to_text_aed.py b/examples/asr/speech_multitask/speech_to_text_aed.py new file mode 100644 index 000000000000..813ce03e8f38 --- /dev/null +++ b/examples/asr/speech_multitask/speech_to_text_aed.py @@ -0,0 +1,81 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
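+
+# Training entry point for the multitask AED model (EncDecMultiTaskModel), e.g. Canary-style
+# joint ASR / speech translation with a FastConformer encoder and a Transformer decoder.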
+ +""" +# Training the model +```sh +python speech_to_text_aed.py \ + # (Optional: --config-path= --config-name=) \ + model.train_ds.tarred_audio_filepaths= \ + model.train_ds.manifest_filepath= \ + model.train_ds.batch_duration=360 \ + model.train_ds.num_buckets=30 \ + model.train_ds.bucket_duration_bins= \ + model.validation_ds.manifest_filepath= \ + model.test_ds.manifest_filepath= \ + model.model_defaults.asr_enc_hidden=1024 \ + model.model_defaults.lm_enc_hidden=512 \ + model.model_defaults.lm_dec_hidden=1024 \ + model.tokenizer.langs.spl_tokens.dir= \ + model.tokenizer.langs.spl_tokens.type=bpe \ + model.tokenizer.langs.en.dir= \ + model.tokenizer.langs.en.type=bpe \ + model.prompt_format="canary" \ + trainer.devices=-1 \ + trainer.accelerator="ddp" \ + trainer.max_steps=100000 \ + +trainer.limit_train_batches=20000 \ + trainer.val_check_interval=5000 \ + +trainer.use_distributed_sampler=false \ + model.optim.name="adamw" \ + model.optim.lr=0.001 \ + model.optim.betas=[0.9,0.999] \ + model.optim.weight_decay=0.0001 \ + model.optim.sched.warmup_steps=2000 \ + exp_manager.create_wandb_logger=True \ + exp_manager.wandb_logger_kwargs.name="" \ + exp_manager.wandb_logger_kwargs.project="" +``` + + +""" + +import pytorch_lightning as pl +from omegaconf import OmegaConf + +from nemo.collections.asr.models import EncDecMultiTaskModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="../conf/speech_multitask/", config_name="fast-conformer_aed") +def main(cfg): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + aed_model = EncDecMultiTaskModel(cfg=cfg.model, trainer=trainer) + + # Initialize the weights of the model from another model, if provided via config + aed_model.maybe_init_from_pretrained_checkpoint(cfg) + trainer.fit(aed_model) + + if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: + if aed_model.prepare_test(trainer): + trainer.test(aed_model) + + +if __name__ == '__main__': + main() diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index df0f61121f00..79d6abfc154a 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -21,9 +21,10 @@ import torch from omegaconf import OmegaConf, open_dict -from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel +from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel, EncDecMultiTaskModel from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig +from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis @@ -154,6 +155,9 @@ class TranscriptionConfig: # Decoding strategy for RNNT models rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1) + # Decoding strategy for AED models + multitask_decoding: MultiTaskDecodingConfig = MultiTaskDecodingConfig() + # decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models decoder_type: Optional[str] = None # 
att_context_size can be set for cache-aware streaming models with multiple look-aheads @@ -178,6 +182,11 @@ class TranscriptionConfig: # key for groundtruth text in manifest gt_text_attr_name: str = "text" + # Use model's transcribe() function instead of transcribe_partial_audio() by default + # Only use transcribe_partial_audio() when the audio is too long to fit in memory + # Your manifest input should have `offset` field to use transcribe_partial_audio() + allow_partial_transcribe: bool = False + @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis]]: @@ -257,7 +266,11 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis # Setup decoding strategy if hasattr(asr_model, 'change_decoding_strategy'): - if cfg.decoder_type is not None: + if isinstance(asr_model.decoding, MultiTaskDecoding): + cfg.multitask_decoding.compute_langs = cfg.compute_langs + cfg.multitask_decoding.preserve_alignments = cfg.preserve_alignment + asr_model.change_decoding_strategy(cfg.multitask_decoding) + elif cfg.decoder_type is not None: # TODO: Support compute_langs in CTC eventually if cfg.compute_langs and cfg.decoder_type == 'ctc': raise ValueError("CTC models do not support `compute_langs` at the moment") @@ -298,8 +311,18 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis else: cfg.decoding = cfg.rnnt_decoding - # prepare audio filepaths and decide wether it's partial audio - filepaths, partial_audio = prepare_audio_data(cfg) + if isinstance(asr_model, EncDecMultiTaskModel): + # Special case for EncDecMultiTaskModel, where the input manifest is directly passed into the model's transcribe() function + partial_audio = False + filepaths = cfg.dataset_manifest + assert cfg.dataset_manifest is not None + else: + # prepare audio filepaths and decide wether it's partial audio + filepaths, partial_audio = prepare_audio_data(cfg) + + if not cfg.allow_partial_transcribe: + # by defatul, use model's transcribe() function, unless partial audio is required + partial_audio = False # setup AMP (optional) if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py new file mode 100644 index 000000000000..cb981fafe08c --- /dev/null +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -0,0 +1,181 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
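+
+# Lhotse dataset wrapper and prompt-formatting helpers for multitask AED models: tokenized
+# supervisions are extended with task/prompt tokens (e.g. the Canary format) before collation.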
+from typing import Callable, Sequence + +import torch.utils.data +from lhotse import CutSet +from lhotse.cut import MixedCut, MonoCut +from lhotse.dataset import AudioSamples +from lhotse.dataset.collation import collate_vectors + +from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper +from nemo.collections.common.tokenizers import CanaryTokenizer, TokenizerSpec + + +class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset): + """ + This dataset is based on :class:`~nemo.collections.asr.data.audio_to_text_lhotse.LhotseSpeechToTextBpeDataset`. + It is a Lhotse-style dataset that converts a mini-batch of Cuts into tensors. + The main difference from ``LhotseSpeechToTextBpeDataset`` is that we introduce + a special prompt format for multitask encoder-decoder models. + + To perform the prompt formatting, we accept a ``prompt_format_fn``. + It's expected to accept: + * a ``CutSet`` which it will internally iterate over for utterances, and + * a ``TokenizerWrapper`` object that will be internally used to tokenize the utterances + + Tokenized utterances will be extended with special prompt tokens according to ``prompt_format_fn`` logic. + We support cuts with multiple supervision segments -- their tokenized texts will be concatenated before we add the prompt tokens. + This is useful, for example, in code-switched scenarios where each segment is spoken in a different language. + """ + + def __init__( + self, tokenizer: TokenizerSpec, prompt_format_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]] + ): + super().__init__() + self.tokenizer = TokenizerWrapper(tokenizer) + self.load_audio = AudioSamples(fault_tolerant=True) + self.padding_value = self.tokenizer._tokenizer.pad_id + self.prompt_format_fn = prompt_format_fn + + def __getitem__(self, cuts: CutSet) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + audio, audio_lens, cuts = self.load_audio(cuts) + + tokens = self.prompt_format_fn(cuts, self.tokenizer) + tokens = [torch.as_tensor(t) for t in tokens] + token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long) + tokens = collate_vectors(tokens, padding_value=self.padding_value) + + return audio, audio_lens, tokens, token_lens + + +# Mapping from a string name to a known prompt formatter function. +PROMPT_FORMAT_FNS = {} + + +def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]): + """ + Decorator for registering prompt functions under a name. + + Example:: + + >>> @registered_prompt_format_fn + ... def my_prompt(cuts, tokenizer): + ... pass + ... + ... prompt_fn = get_prompt_format_fn("my_prompt") + """ + global PROMPT_FORMAT_FNS + + PROMPT_FORMAT_FNS[prompt_fn.__name__] = prompt_fn + return prompt_fn + + +def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]: + if name not in PROMPT_FORMAT_FNS: + raise ValueError( + f"Unknown prompt format function name: {name} " f"(must be one of: {list(PROMPT_FORMAT_FNS.keys())}" + ) + return PROMPT_FORMAT_FNS[name] + + +@registered_prompt_format_fn +def canary(cuts: CutSet, tokenizer: TokenizerWrapper) -> Sequence[Sequence[int]]: + """ + Prepend and append control tokens to the token sequence as per Canary format. + + We use the following special tokens: + * <|startoftranscript|> + * <|transcribe|> + * <|translate|> + * <|nopnc|> + * <|pnc|> + * <|endoftext|> + * <|LANG|> - for each supported language, where LANG is a 2-char language code. 
+ * <|nospeech|> + + The prompt format syntax is as follows: + + <|startoftranscript|> [ <|nospeech|> | <|LANG|> [ <|transcribe|> | <|translate|> ] <|LANG|> [ <|pnc|> | <|nopnc|> ] TEXT <|endoftext|> ] + + Where expression ``[ a | b ]`` denotes expression ``a`` or expression ``b``, and can be nested. + Note that ``<|LANG|>`` appears twice: the first occurrence is for the "source" language + (i.e., spoken language in the recording) and the second occurrence is for the "target" language + (i.e., the language in which we are going to output the text). + """ + + assert isinstance( + tokenizer._tokenizer, CanaryTokenizer + ), "To use 'canary' prompt format, you must use the CanaryTokenizer." + tokenizer = tokenizer._tokenizer + + canary_tokens = [] + for cut in cuts: + if isinstance(cut, MixedCut): + cut = cut._first_non_padding_cut + assert isinstance(cut, MonoCut), "Expected MonoCut." + + # Actual tokenization. If a cut has multiple supervisions, we'll stitch their tokenized texts together. + tokens = sum((tokenizer.text_to_ids(sup.text, sup.language) for sup in cut.supervisions), start=[]) + + # bos + prompted_tokens = [tokenizer.bos_id] + + if len(tokens) == 0: + # no speech token + prompted_tokens.append(tokenizer.nospeech_id) + else: + # first, validate the utterance + missing_keys = [k for k in ("source_lang", "target_lang", "taskname", "pnc") if k not in cut.custom] + if missing_keys: + raise RuntimeError( + f"We found cut with ID {cut.id} that is missing the following keys: {missing_keys}" + f"Please ensure that every utterance in the input manifests contains these keys." + ) + + # src_lang_id/no_speech + src_lang_id = tokenizer.to_language_id(cut.custom['source_lang']) + prompted_tokens.append(src_lang_id) + + # task + task = cut.custom['taskname'] + if task == 'asr': + prompted_tokens.append(tokenizer.transcribe_id) + elif task == 's2t_translation': + prompted_tokens.append(tokenizer.translate_id) + else: + raise ValueError(f"Unknown task: {task} for cut ID: {cut.id}") + + # tgt_lang_id + tgt_lang_id = tokenizer.to_language_id(cut.custom['target_lang']) + prompted_tokens.append(tgt_lang_id) + + # PnC + pnc = f"{cut.custom['pnc']}".lower().strip() # to account for bool or str + if pnc in {'yes', 'true'}: + prompted_tokens.append(tokenizer.pnc_id) + elif pnc in {'no', 'false'}: + prompted_tokens.append(tokenizer.nopnc_id) + else: + raise ValueError(f"Unknown value for key 'pnc': {pnc} for cut ID: {cut.id}") + + # text + prompted_tokens.extend(tokens) + + # eos + prompted_tokens.append(tokenizer.eos_id) + + canary_tokens.append(prompted_tokens) + + return canary_tokens diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index 34f2c4f62e29..2b5659066daa 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from nemo.collections.asr.models.aed_multitask_models import EncDecMultiTaskModel from nemo.collections.asr.models.asr_model import ASRModel from nemo.collections.asr.models.audio_to_audio_model import AudioToAudioModel from nemo.collections.asr.models.classification_models import EncDecClassificationModel, EncDecFrameClassificationModel diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py new file mode 100644 index 000000000000..ede61f8a1ff9 --- /dev/null +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -0,0 +1,710 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import json +import os +import tempfile +from math import ceil +from typing import Dict, List, Optional, Union + +import editdistance +import torch +import torch.distributed as dist +from omegaconf import DictConfig, OmegaConf, open_dict +from pytorch_lightning import Trainer +from torchmetrics.text import SacreBLEUScore +from tqdm.auto import tqdm + +from nemo.collections.asr.data.audio_to_text_lhotse_prompted import ( + PromptedAudioToTextLhotseDataset, + get_prompt_format_fn, +) +from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel +from nemo.collections.asr.parts.mixins import ASRBPEMixin +from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig +from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier +from nemo.collections.asr.parts.utils import manifest_utils +from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +from nemo.collections.common import tokenizers +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.common.metrics import GlobalAverageLossMetric +from nemo.collections.common.parts import transformer_weights_init +from nemo.collections.common.parts.preprocessing.manifest import get_full_path +from nemo.core.classes.common import typecheck +from nemo.core.neural_types import ( + AudioSignal, + ChannelType, + LabelsType, + LengthsType, + LogprobsType, + MaskType, + NeuralType, + SpectrogramType, +) +from nemo.utils import logging, model_utils + +__all__ = ['EncDecMultiTaskModel'] + + +def lens_to_mask(lens, max_length): + batch_size = lens.shape[0] + mask = torch.arange(max_length).repeat(batch_size, 1).to(lens.device) < lens[:, None] + return mask + + +class EncDecMultiTaskModel(ASRModel, ExportableEncDecModel, ASRBPEMixin): + """Base class for AED multi-task models""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + + # Convert to Hydra 1.0 compatible DictConfig + cfg = model_utils.convert_model_config_to_dict_config(cfg) + cfg = model_utils.maybe_update_config_version(cfg) + + if 'tokenizer' not in cfg: + raise ValueError("`cfg` must have `tokenizer` config to create a tokenizer !") + + # Setup the tokenizer + self._setup_tokenizer(cfg.tokenizer) + + 
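+        # The tokenizer set up above is typically an aggregate tokenizer (see the example config),
+        # whose special-token sub-tokenizer is consumed by the prompt format selected below.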
# Assert config has "prompt_format" + if "prompt_format" not in cfg: + raise ValueError("`cfg` must have `prompt_format` config to create a multi task model !") + self.prompt_format = cfg.prompt_format + + super().__init__(cfg=cfg, trainer=trainer) + + # Setup audio preprocessor + self.preprocessor = EncDecMultiTaskModel.from_config_dict(self.cfg.preprocessor) + + # Setup audio encoder + self.encoder = EncDecMultiTaskModel.from_config_dict(self.cfg.encoder) + + # Assert config has `model_defaults` + if 'model_defaults' not in self.cfg: + raise ValueError("`cfg` must have `model_defaults` config to create a model !") + if "asr_enc_hidden" not in self.cfg.model_defaults: + raise ValueError("`cfg.model_defaults` must have `asr_enc_hidden` key !") + if "lm_enc_hidden" not in self.cfg.model_defaults: + raise ValueError("`cfg.model_defaults` must have `lm_enc_hidden` key !") + if "lm_dec_hidden" not in self.cfg.model_defaults: + raise ValueError("`cfg.model_defaults` must have `lm_dec_hidden` key !") + + # Add projection layer if encoder and decoder differ in hidden size + asr_enc_hidden_size = self.cfg.model_defaults.asr_enc_hidden + decoder_hidden_size = self.cfg.model_defaults.lm_dec_hidden + if asr_enc_hidden_size != decoder_hidden_size: + self.encoder_decoder_proj = torch.nn.Linear(asr_enc_hidden_size, decoder_hidden_size) + else: + self.encoder_decoder_proj = torch.nn.Identity() + + transf_encoder_cfg_dict = self.cfg.get('transf_encoder', None) + + # Whether to add Transformer Encoder block between Conformer and Transformer Decoder + self.use_transf_encoder = False + if transf_encoder_cfg_dict is not None and transf_encoder_cfg_dict['num_layers'] > 0: + self.use_transf_encoder = True + + self.transf_encoder = EncDecMultiTaskModel.from_config_dict(transf_encoder_cfg_dict) + + # Initialize weights + std_init_range = 1 / self.cfg.model_defaults.lm_enc_hidden ** 0.5 + self.transf_encoder.apply(lambda module: transformer_weights_init(module, std_init_range)) + + transf_decoder_cfg_dict = cfg.transf_decoder + + # Transformer decoder + vocab_size = 8 * ceil(self.tokenizer.vocab_size / 8) + + # Auto inject vocab size for `get_transformer` + with open_dict(transf_decoder_cfg_dict): + if 'config_dict' in transf_decoder_cfg_dict: + transf_decoder_cfg_dict['config_dict']['vocab_size'] = vocab_size + + self.transf_decoder = EncDecMultiTaskModel.from_config_dict(transf_decoder_cfg_dict) + + # Setup token classifier + with open_dict(self.cfg.head): + self.cfg.head.num_classes = vocab_size + + self.log_softmax = EncDecMultiTaskModel.from_config_dict(self.cfg.head) + + # Weight tying - if using TokenClassifier only + if isinstance(self.log_softmax, TokenClassifier): + self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight + + # Initialize weights + std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden ** 0.5 + self.transf_decoder.apply(lambda module: transformer_weights_init(module, std_init_range)) + self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range)) + + # Setup decoding objects + decoding_cfg = self.cfg.get('decoding', None) + + # In case decoding config not found, use default config + if decoding_cfg is None: + decoding_cfg = OmegaConf.structured(MultiTaskDecodingConfig) + with open_dict(self.cfg): + self.cfg.decoding = decoding_cfg + + self.decoding = MultiTaskDecoding( + decoding_cfg=self.cfg.decoding, + transformer_decoder=self.transf_decoder, + log_softmax_module=self.log_softmax, + tokenizer=self.tokenizer, + ) + + 
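+        # Number of prompt tokens fed to the decoder as fixed context before autoregressive
+        # generation starts, e.g. [BOS, src_lang, task, tgt_lang, pnc] -> 5 for the Canary format.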
self.context_len_for_AR_decoding = self.cfg.get("context_len_for_AR_decoding", 5) + + # Define autoregressive CE loss + with open_dict(self.cfg.loss): + self.cfg.loss.pad_id = self.tokenizer.pad_id + + self.loss = EncDecMultiTaskModel.from_config_dict(self.cfg.loss) + + if hasattr(self.cfg, 'spec_augment') and self.cfg.spec_augment is not None: + self.spec_augmentation = EncDecMultiTaskModel.from_config_dict(self.cfg.spec_augment) + else: + self.spec_augmentation = None + + self.val_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True) + + def change_decoding_strategy(self, decoding_cfg: DictConfig): + """ + Changes decoding strategy used during Multi Task decoding process. + + Args: + decoding_cfg: A config for the decoder, which is optional. If the decoding type + needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. + """ + if decoding_cfg is None: + # Assume same decoding config as before + logging.info("No `decoding_cfg` passed when changing decoding strategy, using internal config") + decoding_cfg = self.cfg.decoding + + # Assert the decoding config with all hyper parameters + decoding_cls = OmegaConf.structured(MultiTaskDecodingConfig) + decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls)) + decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) + + self.decoding = MultiTaskDecoding( + decoding_cfg=decoding_cfg, + transformer_decoder=self.transf_decoder, + log_softmax_module=self.log_softmax, + tokenizer=self.tokenizer, + ) + + # Update config + with open_dict(self.cfg.decoding): + self.cfg.decoding = decoding_cfg + + logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") + + @torch.no_grad() + def transcribe( + self, + paths2audio_files: Union[List[str], str], + batch_size: int = 4, + logprobs: Optional[bool] = None, + return_hypotheses: bool = False, + num_workers: int = 0, + channel_selector: Optional[ChannelSelectorType] = None, + augmentor: DictConfig = None, + verbose: bool = True, + ) -> List[str]: + """ + Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. + Args: + paths2audio_files: (a list) of paths to audio files. \ + Recommended length per file is between 5 and 25 seconds. \ + But it is possible to pass a few hours long file if enough GPU memory is available. + batch_size: (int) batch size to use during inference. + Bigger will result in better throughput performance but would use more memory. + return_hypotheses: (bool) Either return hypotheses or text + With hypotheses can do some postprocessing like getting timestamp or rescoring + num_workers: (int) number of workers for DataLoader + channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. + augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. 
+ verbose: (bool) whether to display tqdm progress bar + Returns: + A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files + """ + + # get ready for new transcribe API + if logprobs is not None: + logging.warning("logprobs is deprecated, please use return_hypotheses instead") + return_hypotheses = logprobs + audio = paths2audio_files + + if audio is None or len(audio) == 0: + return {} + + if return_hypotheses: + logging.warning("return_hypotheses=True is currently not supported, returning text instead.") + + manifest_path = None + if isinstance(audio, list): + logging.debug(f"Found 'paths2audio_files' to be a list of {len(audio)} items.") + logging.debug(f"Assuming each item in 'audio' is a path to audio file.") + + if isinstance(self.tokenizer, tokenizers.AggregateTokenizer): + primary_language = self.tokenizer.langs[0] + logging.debug(f"Transcribing with default setting of {primary_language}.") + + elif isinstance(audio, str): + logging.debug(f"Found 'paths2audio_files' to be a string. Assuming it is a path to manifest file.") + assert os.path.exists(audio), f"File {audio} doesn't exist" + assert audio.endswith('.json') or audio.endswith('.jsonl'), f"File {audio} must be a json or jsonl file" + + # load json lines + manifest_path = audio # need to save this as we are overwriting paths2audio_files in nextline + audio = manifest_utils.read_manifest(manifest_path) + + def _may_be_make_dict_and_fix_paths(json_items, manifest_path): + out_json_items = [] + for item in json_items: + if isinstance(item, str): + # assume it is a path to audio file + entry = { + 'audio_filepath': item, + 'duration': 100000, + 'source_lang': 'en', + 'taskname': 'asr', + 'target_lang': 'en', + 'pnc': 'yes', + 'answer': 'nothing', + } + elif isinstance(item, dict): + entry = item + entry['audio_filepath'] = get_full_path(entry['audio_filepath'], manifest_file=manifest_path) + else: + raise ValueError(f"Expected str or dict, got {type(item)}") + out_json_items.append(entry) + return out_json_items + + paths2audio_files = _may_be_make_dict_and_fix_paths(audio, manifest_path) + + if num_workers is None: + num_workers = min(batch_size, os.cpu_count() - 1) + + # We will store transcriptions here + hypotheses = [] + + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + dither_value = self.preprocessor.featurizer.dither + pad_to_value = self.preprocessor.featurizer.pad_to + + try: + self.preprocessor.featurizer.dither = 0.0 + self.preprocessor.featurizer.pad_to = 0 + # Switch model to evaluation mode + self.eval() + # Freeze the encoder and decoder modules + self.encoder.freeze() + self.transf_decoder.freeze() + logging_level = logging.get_verbosity() + logging.set_verbosity(logging.WARNING) + # Work in tmp directory - will store manifest file there + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp: + for audio_file in paths2audio_files: + fp.write(json.dumps(audio_file) + '\n') + + config = { + 'paths2audio_files': paths2audio_files, + 'batch_size': batch_size, + 'temp_dir': tmpdir, + 'num_workers': num_workers, + 'channel_selector': channel_selector, + } + + if augmentor: + config['augmentor'] = augmentor + + temporary_datalayer = self._setup_transcribe_dataloader(config) + for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=not verbose): + log_probs, encoded_len, enc_states, enc_mask = self.forward( + input_signal=test_batch[0].to(device), 
input_signal_length=test_batch[1].to(device) + ) + + beam_hypotheses = self.decoding.decode_predictions_tensor( + encoder_hidden_states=enc_states, + encoder_input_mask=enc_mask, + decoder_input_ids=test_batch[2][:, : self.context_len_for_AR_decoding].to(device) + if self.context_len_for_AR_decoding > 0 + else None, + return_hypotheses=False, + )[0] + + beam_hypotheses = [self.decoding.strip_special_tokens(text) for text in beam_hypotheses] + + hypotheses += beam_hypotheses + + del test_batch, log_probs, encoded_len, enc_states, enc_mask + finally: + # set mode back to its original value + self.train(mode=mode) + self.preprocessor.featurizer.dither = dither_value + self.preprocessor.featurizer.pad_to = pad_to_value + if mode is True: + self.encoder.unfreeze() + self.transf_decoder.unfreeze() + logging.set_verbosity(logging_level) + + return hypotheses + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + assert config.get("use_lhotse", False), ( + "Multi-task model only supports dataloading with Lhotse. " + "Please set config.{train,validation,test}_ds.use_lhotse=True" + ) + return get_lhotse_dataloader_from_config( + config, + global_rank=self.global_rank, + world_size=self.world_size, + dataset=PromptedAudioToTextLhotseDataset( + tokenizer=self.tokenizer, prompt_format_fn=get_prompt_format_fn(self.prompt_format), + ), + ) + + def setup_training_data(self, train_data_config: Optional[DictConfig]): + + # create audio-only data loader + self._update_dataset_config(dataset_name='train', config=train_data_config) + self._train_dl = self._setup_dataloader_from_config(config=train_data_config) + + # Need to set this because if using an IterableDataset, the length of the + # dataloader is the total number of samples rather than the number of batches, + # and this messes up the tqdm progress bar. So we set the number of steps manually + # (to the correct number) to fix this. + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, + # i.e. <= # training batches, and don't change it. Otherwise, adjust + # batches accordingly if it's a float (including 1.0). + if self._trainer is not None and isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches + * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) + ) + elif self._trainer is None: + logging.warning( + "Model Trainer was not set before constructing the dataset, incorrect number of " + "training batches will be used. Please set the trainer and rebuild the dataset." + ) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the validation data loader via a Dict-like object. + Args: + val_data_config: A config that contains the information regarding construction + of an ASR Training dataset. 
+ Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text_lhotse_prompted.PromptedAudioToTextLhotseDataset` + """ + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the test data loader via a Dict-like object. + Args: + test_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text_lhotse_prompted.PromptedAudioToTextLhotseDataset` + """ + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + self._test_dl = self._setup_dataloader_from_config(config=test_data_config) + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.preprocessor, '_sample_rate'): + input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + else: + input_signal_eltype = AudioSignal() + return { + "input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True), + "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), + "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "transcript": NeuralType(('B', 'T'), LabelsType(), optional=True), + "transcript_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "sample_id": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + "transf_log_probs": NeuralType(('B', 'T', 'D'), LogprobsType()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + "encoder_states": NeuralType(('B', 'T', 'D'), ChannelType()), + "encoder_mask": NeuralType(('B', 'T'), MaskType()), + } + + @typecheck() + def forward( + self, + input_signal=None, + input_signal_length=None, + processed_signal=None, + processed_signal_length=None, + transcript=None, + transcript_length=None, + ): + """ + Forward pass of the model. + Args: + input_signal: Tensor that represents a batch of raw audio signals, + of shape [B, T]. T here represents timesteps, with 1 second of audio represented as + `self.sample_rate` number of floating point values. + input_signal_length: Vector of length B, that contains the individual lengths of the audio + sequences. + processed_signal: Tensor that represents a batch of processed audio signals, + of shape (B, D, T). + processed_signal_length: Vector of length B, that contains the individual lengths of the + processed audio sequences. + Returns: + A tuple of 3 elements - + 1) The log probabilities tensor of shape [B, T, D]. + 2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B]. 
+ 3) The greedy token predictions of the model of shape [B, T] (via argmax) + """ + has_input_signal = input_signal is not None and input_signal_length is not None + has_processed_signal = processed_signal is not None and processed_signal_length is not None + if (has_input_signal ^ has_processed_signal) == False: + raise ValueError( + f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive " + " with ``processed_signal`` and ``processed_signal_len`` arguments." + ) + + if not has_processed_signal: + processed_signal, processed_signal_length = self.preprocessor( + input_signal=input_signal, length=input_signal_length + ) + + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) + + encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length) + + enc_states = encoded.permute(0, 2, 1) + enc_states = self.encoder_decoder_proj(enc_states) + enc_mask = lens_to_mask(encoded_len, enc_states.shape[1]).to(enc_states.dtype) + if self.use_transf_encoder: + enc_states = self.transf_encoder(encoder_states=enc_states, encoder_mask=enc_mask) + + transf_log_probs = None + if transcript is not None: + dec_mask = lens_to_mask(transcript_length, transcript.shape[1]).to(transcript.dtype) + dec_states = self.transf_decoder( + input_ids=transcript, decoder_mask=dec_mask, encoder_embeddings=enc_states, encoder_mask=enc_mask + ) + transf_log_probs = self.log_softmax(hidden_states=dec_states) + + return transf_log_probs, encoded_len, enc_states, enc_mask + + def compute_loss( + self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | None + ) -> torch.Tensor: + """ + Run forward pass through the model and compute the loss. + + Args: + batch: a tuple of 4 tensors (signal, signal_len, tokens, tokens_len) as returned + by :class:`~nemo.collections.asr.data.audio_to_text_lhotse_prompted.PromptedAudioToTextLhotseDataset`. + When batch is ``None``, we'll return a zero tensor. + Returns: + The computed loss value as a single-element tensor. 
+ """ + + if batch is None: + return torch.tensor([0.0]) + + signal, signal_len, transcript, transcript_len = batch + input_ids, labels = transcript[:, :-1], transcript[:, 1:] + + transf_log_probs, encoded_len, enc_states, enc_mask = self.forward( + input_signal=signal, + input_signal_length=signal_len, + transcript=input_ids, + transcript_length=transcript_len, + ) + + transf_loss = self.loss(log_probs=transf_log_probs, labels=labels) + + return transf_loss + + # PTL-specific methods + def training_step(self, batch, batch_nb): + + audio_loss = self.compute_loss(batch) + + tensorboard_logs = { + 'train_loss': audio_loss, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + } + + return {'loss': audio_loss, 'log': tensorboard_logs} + + def validation_step(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"): + signal, signal_len, transcript, transcript_len = batch + input_ids, labels = transcript[:, :-1], transcript[:, 1:] + + transf_log_probs, encoded_len, enc_states, enc_mask = self.forward( + input_signal=signal, + input_signal_length=signal_len, + transcript=input_ids, + transcript_length=transcript_len, + ) + + beam_hypotheses = self.decoding.decode_predictions_tensor( + encoder_hidden_states=enc_states, + encoder_input_mask=enc_mask, + decoder_input_ids=input_ids[:, : self.context_len_for_AR_decoding] + if self.context_len_for_AR_decoding > 0 + else None, + return_hypotheses=False, + )[0] + + transf_loss = self.loss(log_probs=transf_log_probs, labels=labels) + + ground_truths = [self.tokenizer.ids_to_text(sent) for sent in transcript.detach().cpu().tolist()] + translations = [hyp for hyp in beam_hypotheses] + + self.val_loss(loss=transf_loss, num_measurements=transf_log_probs.shape[0] * transf_log_probs.shape[1]) + + output_dict = { + f'{eval_mode}_loss': transf_loss, + 'translations': [self.decoding.strip_special_tokens(t) for t in translations], + 'ground_truths': [self.decoding.strip_special_tokens(g) for g in ground_truths], + } + + if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: + self.validation_step_outputs[dataloader_idx].append(output_dict) + else: + self.validation_step_outputs.append(output_dict) + + return output_dict + + def test_step(self, batch, batch_idx, dataloader_idx=0): + return self.validation_step(batch, batch_idx, dataloader_idx, eval_mode="test") + + def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0, eval_mode: str = "val"): + """ + Called at the end of validation to aggregate outputs. + :param outputs: list of individual outputs of each validation step. 
+ """ + if not outputs: + return + + if isinstance(outputs[0], dict): + outputs = [outputs] + + for output in outputs: + eval_loss = getattr(self, 'val_loss').compute() + translations = list(itertools.chain(*[x['translations'] for x in output])) + ground_truths = list(itertools.chain(*[x['ground_truths'] for x in output])) + + # Gather translations and ground truths from all workers + tr_and_gt = [None for _ in range(self.world_size)] + # we also need to drop pairs where ground truth is an empty string + if self.world_size > 1: + dist.all_gather_object( + tr_and_gt, [(t, g) for (t, g) in zip(translations, ground_truths) if g.strip() != ''] + ) + else: + tr_and_gt[0] = [(t, g) for (t, g) in zip(translations, ground_truths) if g.strip() != ''] + + if self.global_rank == 0: + _translations = [] + _ground_truths = [] + for rank in range(0, self.world_size): + _translations += [t for (t, g) in tr_and_gt[rank]] + _ground_truths += [g for (t, g) in tr_and_gt[rank]] + + sacre_bleu = SacreBLEUScore()(_translations, [[x] for x in _ground_truths]).item() + sb_score = sacre_bleu * self.world_size + + wer_scores, wer_words = 0, 0 + for h, r in zip(_translations, _ground_truths): + wer_words += len(r.split()) + wer_scores += editdistance.eval(h.split(), r.split()) + wer_score = 1.0 * wer_scores * self.world_size / wer_words + + else: + sb_score = 0.0 + wer_score = 0.0 + + # logging here only. + dataloader_prefix = self.get_validation_dataloader_prefix(dataloader_idx) + self.log(f"{dataloader_prefix}{eval_mode}_loss", eval_loss, sync_dist=True) + self.log(f"{dataloader_prefix}{eval_mode}_sacreBLEU", sb_score, sync_dist=True) + self.log(f"{dataloader_prefix}{eval_mode}_WER", wer_score, sync_dist=True) + + # in multi-validation case, anything after first one will become NaN + # as we are resetting the metric here. + # TODO: fix this, (not sure which hook will be ideal for this) + self.val_loss.reset() + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + return self.multi_validation_epoch_end(outputs, dataloader_idx, eval_mode="test") + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + Returns: + A pytorch DataLoader for the given audio file(s). 
+ """ + batch_size = min(config['batch_size'], len(config['paths2audio_files'])) + dl_config = { + 'manifest_filepath': os.path.join(config['temp_dir'], 'manifest.json'), + 'sample_rate': self.preprocessor._sample_rate, + 'batch_size': batch_size, + 'trim_silence': False, + 'shuffle': False, + 'num_workers': min(batch_size, os.cpu_count() - 1), + 'pin_memory': True, + 'use_lhotse': True, + 'use_bucketing': False, + 'drop_last': False, + 'text_field': 'answer', + 'lang_field': 'target_lang', + } + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config)) + return temporary_datalayer diff --git a/nemo/collections/asr/models/transformer_bpe_models.py b/nemo/collections/asr/models/transformer_bpe_models.py index e1bef823f5f5..b526867b6927 100644 --- a/nemo/collections/asr/models/transformer_bpe_models.py +++ b/nemo/collections/asr/models/transformer_bpe_models.py @@ -22,15 +22,22 @@ import editdistance import torch import torch.distributed as dist -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig, OmegaConf, open_dict from pytorch_lightning import Trainer +from torchmetrics.text import SacreBLEUScore from tqdm.auto import tqdm from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel +from nemo.collections.asr.modules.transformer import ( + BeamSearchSequenceGenerator, + TransformerEncoder, + get_nemo_transformer, +) from nemo.collections.asr.parts.mixins import ASRBPEMixin +from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.common.losses import SmoothedCrossEntropyLoss @@ -49,18 +56,6 @@ ) from nemo.utils import logging -try: - from sacrebleu import corpus_bleu - - from nemo.collections.nlp.modules.common import TokenClassifier - from nemo.collections.nlp.modules.common.lm_utils import get_transformer - from nemo.collections.nlp.modules.common.transformer import BeamSearchSequenceGenerator, TransformerEncoder - - NLP_AVAILABLE = True -except (ImportError, ModuleNotFoundError): - NLP_AVAILABLE = False - logging.warning("Could not import NeMo NLP collection which is required for speech translation model.") - __all__ = ['EncDecTransfModelBPE'] @@ -123,10 +118,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): vocab_size = 8 * ceil(self.tokenizer.vocab_size / 8) transf_decoder_cfg_dict['vocab_size'] = vocab_size library = transf_decoder_cfg_dict.pop('library', 'nemo') + if library != 'nemo': + raise ValueError(f"Currently only 'nemo' library is supported for Transformer decoder. 
Got {library}") model_name = transf_decoder_cfg_dict.pop('model_name', None) pretrained = transf_decoder_cfg_dict.pop('pretrained', False) - self.transf_decoder = get_transformer( - library=library, + self.transf_decoder = get_nemo_transformer( model_name=model_name, pretrained=pretrained, config_dict=transf_decoder_cfg_dict, @@ -141,6 +137,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): log_softmax=self.cfg.head.log_softmax, dropout=self.cfg.head.dropout, use_transformer_init=self.cfg.head.use_transformer_init, + num_layers=self.cfg.head.num_layers, ) self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight std_init_range = 1 / self.transf_decoder.hidden_size ** 0.5 @@ -290,9 +287,18 @@ def transcribe( return hypotheses - def _setup_dataloader_from_config(self, config: Optional[Dict]): + def _update_default_values(self, config: DictConfig): + if self.training: # don't do anything for training + return config + with open_dict(config): + for k, v in self.cfg.train_ds.items(): + if k not in config: + config[k] = v + return config + def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): + config = self._update_default_values(config) return get_lhotse_dataloader_from_config( config, global_rank=self.global_rank, @@ -586,8 +592,8 @@ def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0, eval_mode _translations += [t for (t, g) in tr_and_gt[rank]] _ground_truths += [g for (t, g) in tr_and_gt[rank]] - sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="13a") - sb_score = sacre_bleu.score * self.world_size + sacre_bleu = SacreBLEUScore()(_translations, [[x] for x in _ground_truths]).item() + sb_score = sacre_bleu * self.world_size wer_scores, wer_words = 0, 0 for h, r in zip(_translations, _ground_truths): diff --git a/nemo/collections/asr/modules/transformer/__init__.py b/nemo/collections/asr/modules/transformer/__init__.py index 4da13981c4b7..dc392de020b0 100644 --- a/nemo/collections/asr/modules/transformer/__init__.py +++ b/nemo/collections/asr/modules/transformer/__init__.py @@ -12,10 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from nemo.collections.asr.modules.transformer.bridge_encoders import * -from nemo.collections.asr.modules.transformer.perceiver_encoders import * -from nemo.collections.asr.modules.transformer.transformer_bottleneck import * -from nemo.collections.asr.modules.transformer.transformer_decoders import * -from nemo.collections.asr.modules.transformer.transformer_encoders import * -from nemo.collections.asr.modules.transformer.transformer_generators import * -from nemo.collections.asr.modules.transformer.transformer_modules import * +from nemo.collections.asr.modules.transformer.bridge_encoders import BridgeEncoder +from nemo.collections.asr.modules.transformer.perceiver_encoders import PerceiverEncoder +from nemo.collections.asr.modules.transformer.transformer_bottleneck import ( + NeMoTransformerBottleneckConfig, + NeMoTransformerBottleneckDecoderConfig, + NeMoTransformerBottleneckEncoderConfig, + TransformerBottleneckEncoderNM, +) +from nemo.collections.asr.modules.transformer.transformer_decoders import TransformerDecoder +from nemo.collections.asr.modules.transformer.transformer_encoders import TransformerEncoder +from nemo.collections.asr.modules.transformer.transformer_generators import ( + BeamSearchSequenceGenerator, + BeamSearchSequenceGeneratorWithLanguageModel, + EnsembleBeamSearchSequenceGenerator, + GreedySequenceGenerator, + TopKSequenceGenerator, +) +from nemo.collections.asr.modules.transformer.transformer_modules import AttentionBridge, TransformerEmbedding +from nemo.collections.asr.modules.transformer.transformer_utils import get_nemo_transformer diff --git a/nemo/collections/asr/modules/transformer/transformer_decoders.py b/nemo/collections/asr/modules/transformer/transformer_decoders.py index cfe2e9229f92..a5b2c299393c 100644 --- a/nemo/collections/asr/modules/transformer/transformer_decoders.py +++ b/nemo/collections/asr/modules/transformer/transformer_decoders.py @@ -180,31 +180,34 @@ def forward( decoder_attn_mask = form_attention_mask(decoder_mask, diagonal=self.diagonal) encoder_attn_mask = form_attention_mask(encoder_mask) memory_states = self._get_memory_states(decoder_states, decoder_mems_list, 0) - if return_mems_as_list: - cached_mems_list = [memory_states] - else: - cached_mems_list = memory_states.unsqueeze(0) + if return_mems: + if return_mems_as_list: + cached_mems_list = [memory_states] + else: + cached_mems_list = memory_states.unsqueeze(0) for i, layer in enumerate(self.layers): decoder_states = layer(decoder_states, decoder_attn_mask, memory_states, encoder_states, encoder_attn_mask) memory_states = self._get_memory_states(decoder_states, decoder_mems_list, i + 1) - if return_mems_as_list: - cached_mems_list.append(memory_states) - else: - cached_mems_list = torch.cat((cached_mems_list, memory_states.unsqueeze(0)), dim=0) + if return_mems: + if return_mems_as_list: + cached_mems_list.append(memory_states) + else: + cached_mems_list = torch.cat((cached_mems_list, memory_states.unsqueeze(0)), dim=0) if self.final_layer_norm is not None: decoder_states = self.final_layer_norm(decoder_states) memory_states = self._get_memory_states(decoder_states, decoder_mems_list, i + 2) - if return_mems_as_list: - cached_mems_list.append(memory_states) - else: - cached_mems_list = torch.cat((cached_mems_list, memory_states.unsqueeze(0)), dim=0) + if return_mems: + if return_mems_as_list: + cached_mems_list.append(memory_states) + else: + cached_mems_list = torch.cat((cached_mems_list, memory_states.unsqueeze(0)), dim=0) if return_mems: return cached_mems_list else: - 
return cached_mems_list[-1] + return memory_states def input_example(self, max_batch=1, max_dim=256): """ diff --git a/nemo/collections/asr/modules/transformer/transformer_generators.py b/nemo/collections/asr/modules/transformer/transformer_generators.py index 6e17151dcd1b..0a30efcee272 100644 --- a/nemo/collections/asr/modules/transformer/transformer_generators.py +++ b/nemo/collections/asr/modules/transformer/transformer_generators.py @@ -301,7 +301,7 @@ def _forward( scores, prefixes = scores.view(-1, 1), prefixes.view(-1, 1) # repeat init target prefixes and cached memory states beam_size times - prefixes = torch.cat((tgt.repeat(1, self.beam_size).view(-1, 1), prefixes), dim=1) + prefixes = torch.cat((tgt.repeat(1, self.beam_size).view(-1, tgt.shape[1]), prefixes), dim=1) for j in range(len(decoder_mems_list)): decoder_mems_list[j] = decoder_mems_list[j].repeat(self.beam_size, 1, 1) diff --git a/nemo/collections/asr/modules/transformer/transformer_utils.py b/nemo/collections/asr/modules/transformer/transformer_utils.py index a74a30368568..da9ffb8fbd00 100644 --- a/nemo/collections/asr/modules/transformer/transformer_utils.py +++ b/nemo/collections/asr/modules/transformer/transformer_utils.py @@ -20,6 +20,8 @@ from nemo.collections.asr.modules.transformer.transformer import TransformerDecoderNM, TransformerEncoderNM from nemo.collections.asr.modules.transformer.transformer_bottleneck import TransformerBottleneckEncoderNM +__all__ = ['get_nemo_transformer'] + def get_nemo_transformer( model_name: Optional[str] = None, @@ -130,45 +132,3 @@ def get_nemo_transformer( ) return model - - -# def get_huggingface_transformer( -# model_name: Optional[str] = None, -# pretrained: bool = False, -# config_dict: Optional[Union[dict, DictConfig]] = None, -# encoder: bool = True, -# ) -> Union[HuggingFaceEncoderModule, HuggingFaceDecoderModule]: - -# if encoder: -# model = HuggingFaceEncoderModule(model_name, pretrained, config_dict) -# else: -# model = HuggingFaceDecoderModule(model_name, pretrained, config_dict) - -# return model - - -def get_megatron_transformer( - model_name: Optional[str] = None, - pretrained: bool = True, - config_dict: Optional[Union[dict, DictConfig]] = None, - encoder: bool = True, - checkpoint_file: str = None, -) -> None: - - raise ValueError( - "megatron-lm bert encoders are deprecated in NeMo 1.5.0. Please use NeMo 1.4.0 until megatron bert support is added again." 
- ) - - # vocab_file = config_dict.pop('vocab_file', None) - # if encoder: - # model = MegatronEncoderModule( - # model_name=model_name, - # pretrained=pretrained, - # config_dict=config_dict, - # checkpoint_file=checkpoint_file, - # vocab_file=vocab_file, - # ) - # else: - # raise ValueError('Megatron decoders are not currently supported.') - - # return model diff --git a/nemo/collections/asr/parts/mixins/mixins.py b/nemo/collections/asr/parts/mixins/mixins.py index f3efea9e40f1..eeac9d3c78ad 100644 --- a/nemo/collections/asr/parts/mixins/mixins.py +++ b/nemo/collections/asr/parts/mixins/mixins.py @@ -214,7 +214,13 @@ def _setup_aggregate_tokenizer(self, tokenizer_cfg: DictConfig): self.AGGREGATE_TOKENIZERS_DICT_PREFIX ][lang]['type'] - self.tokenizer = tokenizers.AggregateTokenizer(tokenizers_dict) + if "custom_tokenizer" in tokenizer_cfg: + # Class which implements this is usually a ModelPT, has access to Serializable mixin by extension + self.tokenizer = self.from_config_dict( + {"_target_": tokenizer_cfg["custom_tokenizer"]["_target_"], "tokenizers": tokenizers_dict} + ) + else: + self.tokenizer = tokenizers.AggregateTokenizer(tokenizers_dict) def _make_tokenizer(self, tokenizer_cfg: DictConfig, lang=None): diff --git a/nemo/collections/asr/parts/submodules/classifier.py b/nemo/collections/asr/parts/submodules/classifier.py new file mode 100644 index 000000000000..7d9e42593c1c --- /dev/null +++ b/nemo/collections/asr/parts/submodules/classifier.py @@ -0,0 +1,85 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional + +import torch +from torch import nn as nn + +from nemo.collections.common.parts import transformer_weights_init +from nemo.core.classes import Exportable, NeuralModule +from nemo.core.neural_types import ChannelType, NeuralType + +__all__ = ['Classifier'] + + +class Classifier(NeuralModule, Exportable): + """ + A baseclass for modules to perform various classification tasks. + """ + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + """ + Returns definitions of module input ports. + We implement it here since all NLP classifiers have the same inputs + """ + return {"hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())} + + def __init__(self, hidden_size: int, dropout: float = 0.0,) -> None: + """ + Initializes the Classifier base module. 
+ Args: + hidden_size: the size of the hidden dimension + dropout: dropout to apply to the input hidden states + """ + super().__init__() + self._hidden_size = hidden_size + self.dropout = nn.Dropout(dropout) + + def post_init(self, use_transformer_init: bool): + """ + Common post-processing to be called at the end of concrete Classifiers init methods + Args: + use_transformer_init : whether or not to apply transformer_weights_init + """ + if use_transformer_init: + self.apply(lambda module: transformer_weights_init(module, xavier=False)) + + def input_example(self, max_batch=1, max_dim=256): + """ + Generates input examples for tracing etc. + Returns: + A tuple of input examples. + """ + sample = next(self.parameters()) + example = torch.randn(max_batch, max_dim, self._hidden_size).to(sample.device).to(sample.dtype) + return tuple([example]) + + def save_to(self, save_path: str): + """ + Saves the module to the specified path. + Args: + save_path: Path to where to save the module. + """ + pass + + @classmethod + def restore_from(cls, restore_path: str): + """ + Restores the module from the specified path. + Args: + restore_path: Path to restore the module from. + """ + pass diff --git a/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py b/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py new file mode 100644 index 000000000000..fc8d66f81719 --- /dev/null +++ b/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py @@ -0,0 +1,218 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
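The Classifier base added above only provides shared dropout, transformer-style weight init, and export/tracing plumbing; concrete heads are expected to add their own projection and call post_init. A minimal, hypothetical subclass to illustrate the intended pattern (names and sizes invented, not part of this PR):

import torch
from torch import nn

from nemo.collections.asr.parts.submodules.classifier import Classifier


class LinearHead(Classifier):  # hypothetical example head
    def __init__(self, hidden_size: int, num_classes: int, dropout: float = 0.1):
        super().__init__(hidden_size=hidden_size, dropout=dropout)
        self.proj = nn.Linear(hidden_size, num_classes)
        self.post_init(use_transformer_init=True)  # applies transformer_weights_init

    def forward(self, hidden_states):
        # hidden_states: [B, T, D] -> logits: [B, T, num_classes]
        return self.proj(self.dropout(hidden_states))


head = LinearHead(hidden_size=1024, num_classes=512)
logits = head(torch.randn(2, 10, 1024))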
+ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Optional + +import torch + +from nemo.collections.asr.modules.transformer import BeamSearchSequenceGenerator +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.core import Typing, typecheck +from nemo.core.neural_types import ChannelType, HypothesisType, LabelsType, MaskType, NeuralType +from nemo.utils import logging + + +def pack_hypotheses( + hypotheses: List[Hypothesis], beam_hypotheses: torch.Tensor, scores: List[Optional[float]] +) -> List[Hypothesis]: + + for idx, hyp in enumerate(hypotheses): # type: Hypothesis + if scores[idx] is not None: + hyp.score = scores[idx] + + hypi = beam_hypotheses[idx] + if torch.is_tensor(hypi): + hyp.y_sequence = hypi.long() + else: + hyp.y_sequence = torch.tensor(hypi, dtype=torch.long) + + if hyp.dec_state is not None: + hyp.dec_state = _states_to_device(hyp.dec_state) + + return hypotheses + + +def _states_to_device(dec_state, device='cpu'): + if torch.is_tensor(dec_state): + dec_state = dec_state.to(device) + + elif isinstance(dec_state, (list, tuple)): + dec_state = tuple(_states_to_device(dec_i, device) for dec_i in dec_state) + + return dec_state + + +class AEDBeamInfer(ABC): + def __init__( + self, + transformer_decoder: torch.nn.Module, + log_softmax_module: torch.nn.Module, + tokenizer: TokenizerSpec, + search_type: str = 'default', + return_best_hypothesis: bool = True, + preserve_alignments: bool = False, + ): + super().__init__() + + self.transformer_decoder = transformer_decoder + self.log_softmax_module = log_softmax_module + self.tokenizer = tokenizer + + self.search_type = search_type + self.return_best_hypothesis = return_best_hypothesis + self.preserve_alignments = preserve_alignments + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + @abstractmethod + def forward( + self, + encoder_hidden_states: torch.Tensor, + encoder_input_mask: torch.Tensor, + decoder_input_ids: Optional[torch.Tensor] = None, + return_scores: bool = False, + partial_hypotheses: Optional[List[Hypothesis]] = None, + ): + raise NotImplementedError() + + def set_decoding_type(self, decoding_type: str): + self.decoding_type = decoding_type + + +class TransformerAEDBeamInfer(AEDBeamInfer, Typing): + """A beam decoder engine for AED Transformer models. + + Provides a common abstraction for batch level beam decoding. + + """ + + @property + def input_types(self): + """Returns definitions of module input ports. + """ + # Input can be of dimention - + # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels] + + return { + "encoder_hidden_states": NeuralType(tuple(('B', 'T', 'D')), ChannelType()), + "encoder_input_mask": NeuralType(tuple(('B', 'T')), MaskType()), + "decoder_input_ids": NeuralType(('B', 'T'), LabelsType()), + "return_scores": NeuralType(optional=True), + "partial_hypotheses": NeuralType(optional=True), + } + + @property + def output_types(self): + """Returns definitions of module output ports. 
+ """ + return {"predictions": [NeuralType(elements_type=HypothesisType())]} + + def __init__( + self, + transformer_decoder: torch.nn.Module, + log_softmax_module: torch.nn.Module, + tokenizer: TokenizerSpec, + search_type: str = 'default', + beam_size: int = 1, + length_penalty: float = 0.0, + max_generation_delta: int = 50, + return_best_hypothesis: bool = True, + preserve_alignments: bool = False, + ): + super().__init__( + transformer_decoder=transformer_decoder, + log_softmax_module=log_softmax_module, + tokenizer=tokenizer, + search_type=search_type, + return_best_hypothesis=return_best_hypothesis, + preserve_alignments=preserve_alignments, + ) + + self.beam_search = BeamSearchSequenceGenerator( + embedding=transformer_decoder.embedding, + decoder=transformer_decoder.decoder, + log_softmax=log_softmax_module, + max_sequence_length=transformer_decoder.max_sequence_length, + beam_size=beam_size, + bos=tokenizer.bos_id, + pad=tokenizer.pad_id, + eos=tokenizer.eos_id, + len_pen=length_penalty, + max_delta_length=max_generation_delta, + ) + + self.preserve_alignments = preserve_alignments + if self.preserve_alignments: + logging.info( + "Preservation of alignments was requested but {} does not implement it.".format( + self.__class__.__name__ + ) + ) + + @typecheck() + def forward( + self, + encoder_hidden_states: torch.Tensor, + encoder_input_mask: torch.Tensor, + decoder_input_ids: Optional[torch.Tensor] = None, + return_scores: bool = False, + partial_hypotheses: Optional[List[Hypothesis]] = None, + ): + """Returns a list of hypotheses given an input batch of the encoder hidden embedding. + Output token is generated auto-repressively. + + Args: + decoder_output: A tensor of size (batch, timesteps, features) or (batch, timesteps) (each timestep is a label). + decoder_lengths: list of int representing the length of each sequence + output sequence. + + Returns: + packed list containing batch number of sentences (Hypotheses). + """ + with torch.inference_mode(): + hypotheses = [ + Hypothesis(score=0.0, y_sequence=[], timestep=[]) for _ in range(encoder_hidden_states.shape[0]) + ] + beam_hypotheses = self.beam_search( + encoder_hidden_states=encoder_hidden_states, + encoder_input_mask=encoder_input_mask, + decoder_input_ids=decoder_input_ids, + return_beam_scores=return_scores, + ) + + if return_scores: + _, beam_scores, beam_hypotheses = beam_hypotheses + beam_scores = beam_scores.detach().cpu() + else: + beam_scores = [None for _ in range(len(beam_hypotheses))] + beam_hypotheses = beam_hypotheses.detach().cpu() + + # Pack results into Hypotheses + packed_result = pack_hypotheses(hypotheses, beam_hypotheses, beam_scores) + + return (packed_result,) + + +@dataclass +class AEDBeamInferConfig: + beam_size: int = 5 + search_type: str = 'default' + len_pen: float = 1.0 + max_generation_delta: int = 20 + return_best_hypothesis: bool = True + preserve_alignments: bool = False diff --git a/nemo/collections/asr/parts/submodules/multitask_decoding.py b/nemo/collections/asr/parts/submodules/multitask_decoding.py new file mode 100644 index 000000000000..ac21484f58da --- /dev/null +++ b/nemo/collections/asr/parts/submodules/multitask_decoding.py @@ -0,0 +1,492 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass, field, is_dataclass +from typing import List, Optional, Tuple, Union + +import torch +from omegaconf import OmegaConf + +from nemo.collections.asr.parts.submodules.multitask_beam_decoding import ( + AEDBeamInfer, + AEDBeamInferConfig, + TransformerAEDBeamInfer, +) +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses +from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.utils import logging + + +class AbstractMultiTaskDecoding(ABC): + """ + Used for performing AED auto-regressive decoding of the Multi task model given the encoder state. + + Args: + decoding_cfg: A dict-like object which contains the following key-value pairs. + strategy: str value which represents the type of decoding that can occur. + Possible values are : + - greedy, greedy_batch (for greedy decoding). + - beam, tsd, alsd (for beam search decoding). + + return_scores: bool flag which determines whether to return the scores of the hypotheses. + + compute_langs: a bool flag, which allows to compute language id (LID) information per token, + word, and the entire sample (most likely language id). The LIDS will be available + in the returned Hypothesis object as a dictionary + + compute_hypothesis_token_set: A bool flag, which determines whether to compute a list of decoded + tokens as well as the decoded string. Default is False in order to avoid double decoding + unless required. + + preserve_alignments: Bool flag which preserves the history of logprobs generated during + decoding (sample / batched). When set to true, the Hypothesis will contain + the non-null value for `alignments` in it. Here, `alignments` is a List of List of + Tuple(Tensor (of length V + 1), Tensor(scalar, label after argmax)). + + In order to obtain this hypothesis, please utilize `rnnt_decoder_predictions_tensor` function + with the `return_hypotheses` flag set to True. + + The config may further contain the following sub-dictionaries: + "greedy": + max_symbols: int, describing the maximum number of target tokens to decode per + timestep during greedy decoding. Setting to larger values allows longer sentences + to be decoded, at the cost of increased execution time. + preserve_frame_confidence: Same as above, overrides above value. + confidence_method_cfg: Same as above, overrides confidence_cfg.method_cfg. + + "beam": + beam_size: int, defining the beam size for beam search. Must be >= 1. + If beam_size == 1, will perform cached greedy search. This might be slightly different + results compared to the greedy search above. + + length_penalty: float, length penalty for beam search decoding. Must be >= 0.0. + + max_generation_delta: int, maximum number of additional target tokens to generate + + return_best_hypothesis: optional bool, whether to return just the best hypothesis or all of the + hypotheses after beam search has concluded. This flag is set by default. 
+ + + transformer_decoder: Transformer decoder module. + log_softmax_module: Log Softmax projection module to the vocab size. + tokenizer: Aggregate Tokenizer. + """ + + def __init__( + self, + decoding_cfg, + transformer_decoder: torch.nn.Module, + log_softmax_module: torch.nn.Module, + tokenizer: TokenizerSpec, + ): + super().__init__() + + # Convert dataclass to config object + if is_dataclass(decoding_cfg): + decoding_cfg = OmegaConf.structured(decoding_cfg) + + self.cfg = decoding_cfg + + self.preserve_alignments = self.cfg.get('preserve_alignments', None) + self.compute_langs = self.cfg.get('compute_langs', False) + self.return_scores = self.cfg.get('return_scores', False) + self.compute_hypothesis_token_set = self.cfg.get('compute_hypothesis_token_set', False) + + possible_strategies = ['greedy', 'greedy_batch', 'beam'] + if self.cfg.strategy not in possible_strategies: + raise ValueError(f"Decoding strategy must be one of {possible_strategies}") + + # Update preserve alignments + if self.preserve_alignments is None: + if self.cfg.strategy in ['greedy', 'greedy_batch']: + self.preserve_alignments = self.cfg.greedy.get('preserve_alignments', False) + + elif self.cfg.strategy in ['beam']: + self.preserve_alignments = self.cfg.beam.get('preserve_alignments', False) + + if self.cfg.strategy == 'greedy' or self.cfg.strategy == 'greedy_batch': + + # self.decoding = None + raise NotImplementedError("Greedy decoding is not implemented yet.") + + elif self.cfg.strategy == 'beam': + + self.decoding = TransformerAEDBeamInfer( + transformer_decoder=transformer_decoder, + log_softmax_module=log_softmax_module, + tokenizer=tokenizer, + search_type=self.cfg.beam.get('search_type', 'default'), + beam_size=self.cfg.beam.beam_size, + length_penalty=self.cfg.beam.get('length_penalty', 0.0), + max_generation_delta=self.cfg.beam.get('max_generation_delta', 50), + return_best_hypothesis=self.cfg.beam.get('return_best_hypothesis', True), + preserve_alignments=self.preserve_alignments, + ) + + else: + + raise ValueError( + f"Incorrect decoding strategy provided. Must be one of {possible_strategies}\n" + f"but was provided {self.cfg.strategy}" + ) + + def decode_predictions_tensor( + self, + encoder_hidden_states: torch.Tensor, + encoder_input_mask: torch.Tensor, + decoder_input_ids: Optional[torch.Tensor] = None, + return_hypotheses: bool = False, + partial_hypotheses: Optional[List[Hypothesis]] = None, + ) -> Tuple[List[str], Optional[List[List[str]]], Optional[Union[Hypothesis, NBestHypotheses]]]: + """ + Decode an encoder output by autoregressive decoding of the Decoder+Joint networks. + + Args: + encoder_output: torch.Tensor of shape [B, D, T]. + encoded_lengths: torch.Tensor containing lengths of the padded encoder outputs. Shape [B]. + return_hypotheses: bool. If set to True it will return list of Hypothesis or NBestHypotheses + + Returns: + If `return_best_hypothesis` is set: + A tuple (hypotheses, None): + hypotheses - list of Hypothesis (best hypothesis per sample). + Look at rnnt_utils.Hypothesis for more information. + + If `return_best_hypothesis` is not set: + A tuple(hypotheses, all_hypotheses) + hypotheses - list of Hypothesis (best hypothesis per sample). + Look at rnnt_utils.Hypothesis for more information. + all_hypotheses - list of NBestHypotheses. Each NBestHypotheses further contains a sorted + list of all the hypotheses of the model per sample. + Look at rnnt_utils.NBestHypotheses for more information. 
+ """ + # Compute hypotheses + with torch.inference_mode(): + hypotheses_list = self.decoding( + encoder_hidden_states=encoder_hidden_states, + encoder_input_mask=encoder_input_mask, + decoder_input_ids=decoder_input_ids, + return_scores=self.return_scores, + partial_hypotheses=partial_hypotheses, + ) # type: [List[Hypothesis]] + + # extract the hypotheses + hypotheses_list = hypotheses_list[0] # type: List[Hypothesis] + + prediction_list = hypotheses_list + + if isinstance(prediction_list[0], NBestHypotheses): + hypotheses = [] + all_hypotheses = [] + + for nbest_hyp in prediction_list: # type: NBestHypotheses + n_hyps = nbest_hyp.n_best_hypotheses # Extract all hypotheses for this sample + decoded_hyps = self.decode_hypothesis(n_hyps) + + hypotheses.append(decoded_hyps[0]) # best hypothesis + all_hypotheses.append(decoded_hyps) + + if return_hypotheses: + return hypotheses, all_hypotheses + + best_hyp_text = [h.text for h in hypotheses] + all_hyp_text = [h.text for hh in all_hypotheses for h in hh] + return best_hyp_text, all_hyp_text + + else: + hypotheses = self.decode_hypothesis(prediction_list) + + if return_hypotheses: + return hypotheses, None + + best_hyp_text = [h.text for h in hypotheses] + return best_hyp_text, None + + def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hypothesis, NBestHypotheses]]: + """ + Decode a list of hypotheses into a list of strings. + + Args: + hypotheses_list: List of Hypothesis. + + Returns: + A list of strings. + """ + for ind in range(len(hypotheses_list)): + # Extract the integer encoded hypothesis + prediction = hypotheses_list[ind].y_sequence + + if type(prediction) != list: + prediction = prediction.tolist() + + hypothesis = self.decode_tokens_to_str(prediction) + + if self.compute_hypothesis_token_set: + hypotheses_list[ind].tokens = self.decode_ids_to_tokens(prediction) + + # De-tokenize the integer tokens + hypotheses_list[ind].text = hypothesis + + return hypotheses_list + + @abstractmethod + def decode_tokens_to_str(self, tokens: List[int]) -> str: + """ + Implemented by subclass in order to decoder a token id list into a string. + + Args: + tokens: List of int representing the token ids. + + Returns: + A decoded string. + """ + raise NotImplementedError() + + @abstractmethod + def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]: + """ + Implemented by subclass in order to decode a token id list into a token list. + A token list is the string representation of each token id. + + Args: + tokens: List of int representing the token ids. + + Returns: + A list of decoded tokens. + """ + raise NotImplementedError() + + @abstractmethod + def decode_tokens_to_lang(self, tokens: List[int]) -> str: + """ + Implemented by subclass in order to + compute the most likely language ID (LID) string given the tokens. + + Args: + tokens: List of int representing the token ids. + + Returns: + A decoded LID string. + """ + raise NotImplementedError() + + @abstractmethod + def decode_ids_to_langs(self, tokens: List[int]) -> List[str]: + """ + Implemented by subclass in order to + decode a token id list into language ID (LID) list. + + Args: + tokens: List of int representing the token ids. + + Returns: + A list of decoded LIDS. 
+ """ + raise NotImplementedError() + + def strip_special_tokens(self, text: str): + """ + assuming all special tokens are of format + Note that if any label/pred is of format , it will be stripped + """ + assert isinstance(text, str), f"Expected str, got {type(text)}" + text = re.sub(r'<[^>]+>', '', text) + # strip spaces at the beginning and end; + # this is training data artifact, will be fixed in future (@kpuvvada) + return text.strip() + + +class MultiTaskDecoding(AbstractMultiTaskDecoding): + """ + Used for performing AED auto-regressive decoding of the Multi task model given the encoder state. + + Args: + decoding_cfg: A dict-like object which contains the following key-value pairs. + strategy: str value which represents the type of decoding that can occur. + Possible values are : + - greedy, greedy_batch (for greedy decoding). + - beam, tsd, alsd (for beam search decoding). + + return_scores: bool flag which determines whether to return the scores of the hypotheses. + + compute_langs: a bool flag, which allows to compute language id (LID) information per token, + word, and the entire sample (most likely language id). The LIDS will be available + in the returned Hypothesis object as a dictionary + + compute_hypothesis_token_set: A bool flag, which determines whether to compute a list of decoded + tokens as well as the decoded string. Default is False in order to avoid double decoding + unless required. + + preserve_alignments: Bool flag which preserves the history of logprobs generated during + decoding (sample / batched). When set to true, the Hypothesis will contain + the non-null value for `alignments` in it. Here, `alignments` is a List of List of + Tuple(Tensor (of length V + 1), Tensor(scalar, label after argmax)). + + In order to obtain this hypothesis, please utilize `rnnt_decoder_predictions_tensor` function + with the `return_hypotheses` flag set to True. + + The config may further contain the following sub-dictionaries: + "greedy": + max_symbols: int, describing the maximum number of target tokens to decode per + timestep during greedy decoding. Setting to larger values allows longer sentences + to be decoded, at the cost of increased execution time. + preserve_frame_confidence: Same as above, overrides above value. + confidence_method_cfg: Same as above, overrides confidence_cfg.method_cfg. + + "beam": + beam_size: int, defining the beam size for beam search. Must be >= 1. + If beam_size == 1, will perform cached greedy search. This might be slightly different + results compared to the greedy search above. + + length_penalty: float, length penalty for beam search decoding. Must be >= 0.0. + + max_generation_delta: int, maximum number of additional target tokens to generate + + return_best_hypothesis: optional bool, whether to return just the best hypothesis or all of the + hypotheses after beam search has concluded. This flag is set by default. + + + transformer_decoder: Transformer decoder module. + log_softmax_module: Log Softmax projection module to the vocab size. + tokenizer: TokenizerSpec. 
+ """ + + def __init__( + self, + decoding_cfg, + transformer_decoder: torch.nn.Module, + log_softmax_module: torch.nn.Module, + tokenizer: TokenizerSpec, + ): + self.tokenizer = tokenizer + + super().__init__( + decoding_cfg=decoding_cfg, + transformer_decoder=transformer_decoder, + log_softmax_module=log_softmax_module, + tokenizer=tokenizer, + ) + + if isinstance(self.decoding, AEDBeamInfer): + self.decoding.set_decoding_type('subword') + + def decode_tokens_to_str(self, tokens: List[int]) -> str: + """ + Implemented by subclass in order to decoder a token list into a string. + + Args: + tokens: List of int representing the token ids. + + Returns: + A decoded string. + """ + hypothesis = self.tokenizer.ids_to_text(tokens) + return hypothesis + + def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]: + """ + Implemented by subclass in order to decode a token id list into a token list. + A token list is the string representation of each token id. + + Args: + tokens: List of int representing the token ids. + + Returns: + A list of decoded tokens. + """ + token_list = self.tokenizer.ids_to_tokens(tokens) + return token_list + + def decode_tokens_to_lang(self, tokens: List[int]) -> str: + """ + Compute the most likely language ID (LID) string given the tokens. + + Args: + tokens: List of int representing the token ids. + + Returns: + A decoded LID string. + """ + lang = self.tokenizer.ids_to_lang(tokens) + return lang + + def decode_ids_to_langs(self, tokens: List[int]) -> List[str]: + """ + Decode a token id list into language ID (LID) list. + + Args: + tokens: List of int representing the token ids. + + Returns: + A list of decoded LIDS. + """ + lang_list = self.tokenizer.ids_to_text_and_langs(tokens) + return lang_list + + def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hypothesis, NBestHypotheses]]: + """ + Decode a list of hypotheses into a list of strings. + Overrides the super() method optionally adding lang information + + Args: + hypotheses_list: List of Hypothesis. + + Returns: + A list of strings. 
+ """ + hypotheses = super().decode_hypothesis(hypotheses_list) + if self.compute_langs: + if isinstance(self.tokenizer, AggregateTokenizer): + for ind in range(len(hypotheses_list)): + # Extract the integer encoded hypothesis + prediction = hypotheses_list[ind].y_sequence + + if type(prediction) != list: + prediction = prediction.tolist() + + hypotheses[ind].langs = self.decode_tokens_to_lang(prediction) + hypotheses[ind].langs_chars = self.decode_ids_to_langs(prediction) + else: + logging.warning( + "Ignoring request for lang output in hypotheses since the model does not use an aggregate tokenizer" + ) + + return hypotheses + + +@dataclass +class MultiTaskDecodingConfig: + strategy: str = "beam" + + compute_hypothesis_token_set: bool = False + + # preserve decoding alignments + preserve_alignments: Optional[bool] = None + + # compute language IDs + compute_langs: bool = False + + # greedy decoding config + # greedy: rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig = field( + # default_factory=rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig + # ) + + # beam decoding config + beam: AEDBeamInferConfig = field(default_factory=lambda: AEDBeamInferConfig(beam_size=4)) + + # can be used to change temperature for decoding + temperature: float = 1.0 diff --git a/nemo/collections/asr/parts/submodules/token_classifier.py b/nemo/collections/asr/parts/submodules/token_classifier.py new file mode 100644 index 000000000000..4061d19d9015 --- /dev/null +++ b/nemo/collections/asr/parts/submodules/token_classifier.py @@ -0,0 +1,164 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, Optional + +from torch import nn as nn + +from nemo.collections.asr.parts.submodules.classifier import Classifier +from nemo.collections.common.parts import MultiLayerPerceptron +from nemo.core.classes import typecheck +from nemo.core.neural_types import LogitsType, LogprobsType, NeuralType + +__all__ = ['BertPretrainingTokenClassifier', 'TokenClassifier'] + +ACT2FN = {"gelu": nn.functional.gelu, "relu": nn.functional.relu} + + +@dataclass +class TokenClassifierConfig: + num_layers: int = 1 + activation: str = 'relu' + log_softmax: bool = True + dropout: float = 0.0 + use_transformer_init: bool = True + + +class TokenClassifier(Classifier): + """ + A module to perform token level classification tasks such as Named entity recognition. + """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """ + Returns definitions of module output ports. + """ + if not self.log_softmax: + return {"logits": NeuralType(('B', 'T', 'C'), LogitsType())} + else: + return {"log_probs": NeuralType(('B', 'T', 'C'), LogprobsType())} + + def __init__( + self, + hidden_size: int, + num_classes: int, + num_layers: int = 1, + activation: str = 'relu', + log_softmax: bool = True, + dropout: float = 0.0, + use_transformer_init: bool = True, + ) -> None: + + """ + Initializes the Token Classifier module. 
+ + Args: + hidden_size: the size of the hidden dimension + num_classes: number of classes + num_layers: number of fully connected layers in the multilayer perceptron (MLP) + activation: activation to usee between fully connected layers in the MLP + log_softmax: whether to apply softmax to the output of the MLP + dropout: dropout to apply to the input hidden states + use_transformer_init: whether to initialize the weights of the classifier head with the same approach used in Transformer + """ + super().__init__(hidden_size=hidden_size, dropout=dropout) + self.log_softmax = log_softmax + self.mlp = MultiLayerPerceptron( + hidden_size, num_classes, num_layers=num_layers, activation=activation, log_softmax=log_softmax + ) + self.post_init(use_transformer_init=use_transformer_init) + + @typecheck() + def forward(self, hidden_states): + """ + Performs the forward step of the module. + Args: + hidden_states: batch of hidden states (for example, from the BERT encoder module) + [BATCH_SIZE x SEQ_LENGTH x HIDDEN_SIZE] + Returns: logits value for each class [BATCH_SIZE x SEQ_LENGTH x NUM_CLASSES] + """ + hidden_states = self.dropout(hidden_states) + logits = self.mlp(hidden_states) + return logits + + +class BertPretrainingTokenClassifier(Classifier): + """ + A module to perform token level classification tasks for Bert pretraining. + """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """ + Returns definitions of module output ports. + """ + if not self.log_softmax: + return {"logits": NeuralType(('B', 'T', 'C'), LogitsType())} + else: + return {"log_probs": NeuralType(('B', 'T', 'C'), LogprobsType())} + + def __init__( + self, + hidden_size: int, + num_classes: int, + num_layers: int = 1, + activation: str = 'relu', + log_softmax: bool = True, + dropout: float = 0.0, + use_transformer_init: bool = True, + ) -> None: + + """ + Initializes the Token Classifier module. + + Args: + hidden_size: the size of the hidden dimension + num_classes: number of classes + num_layers: number of fully connected layers in the multilayer perceptron (MLP) + activation: activation to usee between fully connected layers in the MLP + log_softmax: whether to apply softmax to the output of the MLP + dropout: dropout to apply to the input hidden states + use_transformer_init: whether to initialize the weights of the classifier head with the same approach used in Transformer + """ + super().__init__(hidden_size=hidden_size, dropout=dropout) + + self.log_softmax = log_softmax + + if activation not in ACT2FN: + raise ValueError(f'activation "{activation}" not found') + self.dense = nn.Linear(hidden_size, hidden_size) + self.act = ACT2FN[activation] + self.norm = nn.LayerNorm(hidden_size, eps=1e-12) + self.mlp = MultiLayerPerceptron( + hidden_size, num_classes, num_layers=num_layers, activation=activation, log_softmax=log_softmax + ) + self.post_init(use_transformer_init=use_transformer_init) + + @typecheck() + def forward(self, hidden_states): + """ + Performs the forward step of the module. 
+ Args: + hidden_states: batch of hidden states (for example, from the BERT encoder module) + [BATCH_SIZE x SEQ_LENGTH x HIDDEN_SIZE] + Returns: logits value for each class [BATCH_SIZE x SEQ_LENGTH x NUM_CLASSES] + """ + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = self.act(hidden_states) + transform = self.norm(hidden_states) + logits = self.mlp(transform) + return logits diff --git a/nemo/collections/asr/parts/utils/eval_utils.py b/nemo/collections/asr/parts/utils/eval_utils.py index 3a9cc1d15766..5584a5047178 100644 --- a/nemo/collections/asr/parts/utils/eval_utils.py +++ b/nemo/collections/asr/parts/utils/eval_utils.py @@ -13,6 +13,7 @@ # limitations under the License. import json +import re from typing import Optional, Tuple, Union from torchmetrics.text import SacreBLEUScore @@ -27,6 +28,55 @@ 'rouge': ROUGEScore, } +from omegaconf import DictConfig + + +def flatten_dict_config(config: DictConfig, parent_key='', sep='.', join='\n') -> str: + """ + Flatten a DictConfig object into a string of parameter names and their values. + + Args: + config (DictConfig): The input DictConfig object. + parent_key (str): The parent key for nested configurations. + sep (str): Separator between keys. + + Returns: + str: Flattened string of parameter names and their values. + """ + items = [] + for k, v in config.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, DictConfig): + items.extend(flatten_dict_config(v, new_key, sep=sep, join=join).split(join)) + else: + items.append(f"{new_key}={v}") + return join.join(items) + + +def get_hydra_override_from_config(config: Optional[DictConfig] = None, exclude_keys: Optional[list] = None) -> str: + """ + Flatten a DictConfig object into a string of hydra overrides for commandline, for example: + >>> config = OmegaConf.create({"foo": {"bar": 1, "baz": 2}}) + >>> get_hydra_override_from_config(config) + "++foo.bar=1 ++foo.baz=2" + """ + if not config: + return "" + join = '\n' + overrides = flatten_dict_config(config, join=join).split(join) + if exclude_keys: + overrides = [x for x in overrides if not any([y == x.split("=")[0] for y in exclude_keys])] + param_str = " ".join([f"++{x}" for x in overrides]) + return param_str + + +def strip_spaces_before_punctuations(text: str) -> str: + """ + Remove spaces before punctuations, e.g. 
"hello , world" -> "hello, world" + """ + result = re.sub(r'(\w)\s+([.,;!?])', r'\1\2', text) + return result + def remove_punctuations(text: str, punctuations: Optional[Union[list, str]] = None) -> str: """ @@ -115,6 +165,7 @@ def cal_write_wer( ignore_capitalization: bool = False, ignore_punctuation: bool = False, punctuations: Optional[list] = None, + strip_punc_space: bool = False, ) -> Tuple[str, dict, str]: """ Calculate wer, inserion, deletion and substitution rate based on groundtruth text and pred_text_attr_name (pred_text) @@ -147,6 +198,9 @@ def cal_write_wer( if ignore_punctuation: ref = remove_punctuations(ref, punctuations=punctuations) hyp = remove_punctuations(hyp, punctuations=punctuations) + elif strip_punc_space: + ref = strip_spaces_before_punctuations(ref) + hyp = strip_spaces_before_punctuations(hyp) if ignore_capitalization: ref = ref.lower() @@ -201,6 +255,7 @@ def cal_write_text_metric( punctuations: Optional[list] = None, metric: str = 'bleu', metric_args: Optional[dict] = None, + strip_punc_space: bool = False, ): samples = [] hyps = [] @@ -229,6 +284,9 @@ def cal_write_text_metric( if ignore_punctuation: ref = remove_punctuations(ref, punctuations=punctuations) hyp = remove_punctuations(hyp, punctuations=punctuations) + elif strip_punc_space: + ref = strip_spaces_before_punctuations(ref) + hyp = strip_spaces_before_punctuations(hyp) if ignore_capitalization: ref = ref.lower() diff --git a/nemo/collections/asr/parts/utils/transcribe_utils.py b/nemo/collections/asr/parts/utils/transcribe_utils.py index dcf0ea8b0770..17be7d1e33d2 100644 --- a/nemo/collections/asr/parts/utils/transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/transcribe_utils.py @@ -16,6 +16,7 @@ import os import re from dataclasses import dataclass +from pathlib import Path from typing import List, Optional, Tuple, Union import torch @@ -313,6 +314,8 @@ def write_transcription( else: raise TypeError + # create output dir if not exists + Path(cfg.output_filename).parent.mkdir(parents=True, exist_ok=True) with open(cfg.output_filename, 'w', encoding='utf-8', newline='\n') as f: if cfg.audio_dir is not None: for idx, transcription in enumerate(best_hyps): # type: rnnt_utils.Hypothesis or str diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 5b5e0c1fa9de..37b664725ddd 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -16,10 +16,12 @@ import random import warnings from dataclasses import dataclass +from functools import partial from typing import Any, Callable, Optional import torch from lhotse import CutSet +from lhotse.cut import Cut from lhotse.dataset import ( CutConcatenate, DynamicBucketingSampler, @@ -62,8 +64,8 @@ class LhotseDataLoadingConfig: bucket_duration_bins: list[float] | None = None bucket_buffer_size: int = 10000 # d. Other Lhotse sampling options. 
- shuffle_buffer_size: int = 10000 - drop_last: bool = True + shuffle_buffer_size: int | None = 10000 + drop_last: bool = False shard_seed: int | str = "trng" max_open_streams: int | None = None @@ -72,7 +74,7 @@ class LhotseDataLoadingConfig: sample_rate: int = 16000 min_duration: float | None = -1 max_duration: float | None = float("inf") - seed: int = 0 + seed: int | str = "randomized" # int | "randomized" | "trng"; the latter two are lazily resolved by Lhotse in dloading worker processes num_workers: int = 0 pin_memory: bool = False @@ -123,12 +125,7 @@ def get_lhotse_dataloader_from_config( cuts = cuts.resample(config.sample_rate) # Duration filtering, same as native NeMo dataloaders. - min_dur, max_dur = config.min_duration, config.max_duration - cuts = cuts.filter(lambda c: min_dur <= c.duration <= max_dur) - - # Safeguard against utterances with identical IDs across different datasets - # that would make Lhotse complain otherwise. - cuts = cuts.modify_ids(create_id_randomizer(config.seed)) + cuts = cuts.filter(DurationFilter(config.min_duration, config.max_duration)) # 2. Optional augmentations. # 2.a. Noise mixing. @@ -204,9 +201,9 @@ def get_lhotse_dataloader_from_config( CutConcatenate(gap=config.concatenate_gap_seconds, duration_factor=config.concatenate_duration_factor,) ) if config.db_norm is not None: - sampler = sampler.map(lambda cuts: cuts.normalize_loudness(config.db_norm, mix_first=False)) + sampler = sampler.map(partial(_normalize_loudness, db_norm=config.db_norm)) if config.concatenate_merge_supervisions: - sampler = sampler.map(lambda cuts: cuts.merge_supervisions()) + sampler = sampler.map(_merge_supervisions) # 4. Creating dataloader. if is_tarred: @@ -234,16 +231,6 @@ def get_lhotse_dataloader_from_config( return dloader -def create_id_randomizer(seed: int = 0) -> Callable[[str], str]: - rng = random.Random(seed) - max_sfx = 2 ** 20 - 1 - - def add_random_suffix(cut_id: str) -> str: - return f"{cut_id}-rnd{rng.randint(0, max_sfx):07d}" - - return add_random_suffix - - def make_structured_with_schema_warnings(config: DictConfig) -> DictConfig: """ Checks the schema and fills missing default option values. @@ -264,3 +251,28 @@ def make_structured_with_schema_warnings(config: DictConfig) -> DictConfig: config = OmegaConf.masked_copy(config, list(supported_keys)) return OmegaConf.merge(default, config) + + +# The helper callables below exist to avoid passing lambdas into lhotse CutSet map/filter methods. +# Lambdas are not serializable across processes by pickle. +# Note: lhotse offers LHOTSE_DILL_ENABLED=1 and ``lhotse.lazy.set_dill_enabled(True)`` +# to support pickling lambdas if its ever truly necessary. 
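The comment above is the whole motivation for DurationFilter and the module-level helpers that follow: dataloader worker processes receive these callables via pickle, and a plain lambda cannot be pickled. A quick illustration, independent of lhotse:

import pickle
from functools import partial


def scale(x, factor):
    return x * factor


pickle.dumps(partial(scale, factor=2.0))  # fine: partial over a module-level function
try:
    pickle.dumps(lambda x: x * 2.0)       # raises: lambdas are not picklable by default
except (pickle.PicklingError, AttributeError) as err:
    print(type(err).__name__, err)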
+ + +class DurationFilter: + """Callable, returns ``True`` if a cut's duration is in range [d_min, d_max] and ``False`` otherwise.""" + + def __init__(self, d_min: float, d_max: float) -> None: + self.d_min = d_min + self.d_max = d_max + + def __call__(self, cut: Cut) -> bool: + return self.d_min <= cut.duration <= self.d_max + + +def _normalize_loudness(cuts: CutSet, db_norm: float) -> CutSet: + return cuts.normalize_loudness(target=db_norm, mix_first=False) + + +def _merge_supervisions(cuts: CutSet) -> CutSet: + return cuts.merge_supervisions() diff --git a/nemo/collections/common/data/lhotse/nemo_adapters.py b/nemo/collections/common/data/lhotse/nemo_adapters.py index 989469f7f409..25476585ede1 100644 --- a/nemo/collections/common/data/lhotse/nemo_adapters.py +++ b/nemo/collections/common/data/lhotse/nemo_adapters.py @@ -18,17 +18,18 @@ import tarfile from io import BytesIO from pathlib import Path -from typing import Iterable, List +from typing import Generator, Iterable, List import soundfile from cytoolz import groupby from lhotse import AudioSource, Recording, SupervisionSegment -from lhotse.lazy import ImitatesDict, LazyIteratorChain, LazyJsonlIterator +from lhotse.cut import Cut +from lhotse.lazy import LazyIteratorChain, LazyJsonlIterator from lhotse.serialization import open_best from lhotse.utils import compute_num_samples -class LazyNeMoIterator(ImitatesDict): +class LazyNeMoIterator: """ ``LazyNeMoIterator`` reads a NeMo (non-tarred) JSON manifest and converts it on the fly to an ``Iterable[Cut]``. It's used to create a ``lhotse.CutSet``. @@ -62,7 +63,7 @@ def __init__(self, path: str | Path, text_field: str = "text", lang_field: str = def path(self) -> str | Path: return self.source.path - def __iter__(self): + def __iter__(self) -> Generator[Cut, None, None]: for data in self.source: audio_path = data.pop("audio_filepath") duration = data.pop("duration") @@ -94,7 +95,7 @@ def __add__(self, other): return LazyIteratorChain(self, other) -class LazyNeMoTarredIterator(ImitatesDict): +class LazyNeMoTarredIterator: """ ``LazyNeMoTarredIterator`` reads a NeMo tarred JSON manifest and converts it on the fly to an ``Iterable[Cut]``. It's used to create a ``lhotse.CutSet``. 
@@ -189,7 +190,7 @@ def _validate(self) -> None: def shard_ids(self) -> List[int]: return sorted(self.shard_id_to_manifest.keys()) - def __iter__(self): + def __iter__(self) -> Generator[Cut, None, None]: shard_ids = self.shard_ids if self.shuffle_shards: diff --git a/nemo/collections/common/tokenizers/__init__.py b/nemo/collections/common/tokenizers/__init__.py index f46e3b150acc..7dc3387c005f 100644 --- a/nemo/collections/common/tokenizers/__init__.py +++ b/nemo/collections/common/tokenizers/__init__.py @@ -14,6 +14,7 @@ from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer from nemo.collections.common.tokenizers.bytelevel_tokenizers import ByteLevelTokenizer +from nemo.collections.common.tokenizers.canary_tokenizer import CanaryTokenizer from nemo.collections.common.tokenizers.char_tokenizer import CharTokenizer from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo.collections.common.tokenizers.regex_tokenizer import RegExTokenizer diff --git a/nemo/collections/common/tokenizers/canary_tokenizer.py b/nemo/collections/common/tokenizers/canary_tokenizer.py new file mode 100644 index 000000000000..b812cdb46dd5 --- /dev/null +++ b/nemo/collections/common/tokenizers/canary_tokenizer.py @@ -0,0 +1,121 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
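Once the special-token model is built with build_special_tokenizer (defined below) and combined with per-language BPE models, prompt token ids can be read directly off the aggregate tokenizer. A hypothetical sketch; the English tokenizer path is a placeholder and the prompt layout is only indicative:

from nemo.collections.common.tokenizers.canary_tokenizer import CanaryTokenizer
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer

spl = CanaryTokenizer.build_special_tokenizer("/tmp/spl_tokens")  # writes tokenizer.model under /tmp/spl_tokens
en = SentencePieceTokenizer("/path/to/en/tokenizer.model")        # placeholder: any trained BPE model

tokenizer = CanaryTokenizer({"spl_tokens": spl, "en": en})
prompt = [
    tokenizer.bos_id,                # <|startoftranscript|>
    tokenizer.to_language_id("en"),  # source language
    tokenizer.transcribe_id,         # task
    tokenizer.to_language_id("en"),  # target language
    tokenizer.pnc_id,                # punctuation & capitalization on
]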
+ +from functools import cached_property +from pathlib import Path +from typing import Dict + +from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer +from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model + +__all__ = ['CanaryTokenizer'] + + +LANGUAGES = { + "en": "english", + "de": "german", + "es": "spanish", + "fr": "french", +} + +TO_LANGUAGE_CODE = { + **{language: code for code, language in LANGUAGES.items()}, +} + +SPECIAL_TOKENS = [ + "", + "<|endoftext|>", + "<|startoftranscript|>", + *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())], + "<|transcribe|>", + "<|translate|>", + "<|nopnc|>", + "<|pnc|>", + "<|nospeech|>", +] + +UNUSED_SPECIAL_TOKENS = [f"<|spltoken{i}|>" for i in range(18)] + + +class CanaryTokenizer(AggregateTokenizer): + """ + Thin wrapper around AggregateTokenizer to provide quick access to special tokens + """ + + def __init__(self, tokenizers: Dict): + super().__init__(tokenizers) + + # for easy access of special tokens + special_tokens: Dict[str, int] = {} + for special in SPECIAL_TOKENS: + special_tokens[special] = self.token_to_id(special, lang_id='spl_tokens') + + self.special_tokens = special_tokens + + @cached_property + def eos_id(self) -> int: + return self.special_tokens["<|endoftext|>"] + + @cached_property + def bos_id(self) -> int: + return self.special_tokens["<|startoftranscript|>"] + + @cached_property + def transcribe_id(self) -> int: + return self.special_tokens["<|transcribe|>"] + + @cached_property + def translate_id(self) -> int: + return self.special_tokens["<|translate|>"] + + @cached_property + def nopnc_id(self) -> int: + return self.special_tokens["<|nopnc|>"] + + @cached_property + def pnc_id(self) -> int: + return self.special_tokens["<|pnc|>"] + + @cached_property + def nospeech_id(self) -> int: + return self.special_tokens["<|nospeech|>"] + + @cached_property + def pad_id(self) -> int: + return self.special_tokens[""] + + def to_language_id(self, language): + if token_id := self.special_tokens.get(f"<|{language}|>", None): + return token_id + + raise KeyError(f"Language {language} not found in tokenizer.") + + @staticmethod + def build_special_tokenizer(output_dir: str | Path) -> SentencePieceTokenizer: + output_dir = Path(output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + text_path = output_dir / "train_text.txt" + all_tokens = SPECIAL_TOKENS + UNUSED_SPECIAL_TOKENS + train_text = "\n".join(all_tokens) + text_path.write_text(train_text) + model_path = output_dir / "tokenizer.model" + create_spt_model( + str(text_path), + vocab_size=32, + sample_size=-1, + do_lower_case=False, + output_dir=str(output_dir), + user_defined_symbols=all_tokens, + ) + return SentencePieceTokenizer(str(model_path)) diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py index 1bac077bde3c..7454ec5b97cc 100644 --- a/nemo/core/classes/modelPT.py +++ b/nemo/core/classes/modelPT.py @@ -42,6 +42,10 @@ __all__ = ['ModelPT'] +# multiple interpolated values in the config +OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) + + class ModelPT(LightningModule, Model): """ Interface for Pytorch-lightning based NeMo models diff --git a/nemo/core/config/hydra_runner.py b/nemo/core/config/hydra_runner.py index 9cabc45042f7..604d2134f66b 100644 --- a/nemo/core/config/hydra_runner.py +++ b/nemo/core/config/hydra_runner.py @@ -45,7 +45,7 @@ def _get_gpu_name(): OmegaConf.register_new_resolver("gpu_name", _get_gpu_name) # multiple 
interpolated values in the config -OmegaConf.register_new_resolver("multiply", lambda x, y: x * y) +OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) def hydra_runner( diff --git a/requirements/requirements_asr.txt b/requirements/requirements_asr.txt index 671f06f3dcca..6df223209cc1 100644 --- a/requirements/requirements_asr.txt +++ b/requirements/requirements_asr.txt @@ -5,7 +5,7 @@ ipywidgets jiwer kaldi-python-io kaldiio -lhotse==1.19.2 +lhotse>=1.20.0 librosa>=0.10.0 marshmallow matplotlib diff --git a/tests/collections/asr/test_custom_tokenizer.py b/tests/collections/asr/test_custom_tokenizer.py new file mode 100644 index 000000000000..79cb6255fb31 --- /dev/null +++ b/tests/collections/asr/test_custom_tokenizer.py @@ -0,0 +1,83 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock + +import pytest +import sentencepiece as spm +from omegaconf import OmegaConf + +from nemo.collections.asr.parts.mixins import ASRBPEMixin +from nemo.collections.common.tokenizers.canary_tokenizer import SPECIAL_TOKENS, UNUSED_SPECIAL_TOKENS, CanaryTokenizer +from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model +from nemo.core import Serialization + + +@pytest.fixture(scope="session") +def special_tokenizer_path(tmp_path_factory) -> str: + tmpdir = tmp_path_factory.mktemp("spl_tokens") + CanaryTokenizer.build_special_tokenizer(tmpdir) + return str(tmpdir) + + +@pytest.fixture(scope="session") +def lang_tokenizer_path(tmp_path_factory) -> str: + tmpdir = tmp_path_factory.mktemp("klingon_tokens") + text_path = tmpdir / "text.txt" + text_path.write_text("a\nb\nc\nd\n") + create_spt_model(text_path, vocab_size=8, sample_size=-1, do_lower_case=False, output_dir=str(tmpdir)) + return str(tmpdir) + + +def test_canary_tokenizer_build_special_tokenizer(tmp_path): + tokenizer = CanaryTokenizer.build_special_tokenizer(tmp_path) + expected_tokens = [""] + SPECIAL_TOKENS + UNUSED_SPECIAL_TOKENS + ["▁"] + tokens = [] + for i in range(tokenizer.tokenizer.vocab_size()): + tokens.append(tokenizer.tokenizer.IdToPiece(i)) + assert expected_tokens == tokens + + +def test_canary_tokenizer_init_from_cfg(special_tokenizer_path, lang_tokenizer_path): + class DummyModel(ASRBPEMixin, Serialization): + pass + + model = DummyModel() + model.register_artifact = Mock(side_effect=lambda self, x: x) + config = OmegaConf.create( + { + "type": "agg", + "dir": None, + "langs": { + "spl_tokens": {"dir": special_tokenizer_path, "type": "bpe"}, + "en": {"dir": lang_tokenizer_path, "type": "bpe"}, + }, + "custom_tokenizer": {"_target_": "nemo.collections.common.tokenizers.canary_tokenizer.CanaryTokenizer",}, + } + ) + model._setup_aggregate_tokenizer(config) + tokenizer = model.tokenizer + + assert isinstance(tokenizer, CanaryTokenizer) + assert len(tokenizer.tokenizers_dict) == 2 + assert set(tokenizer.tokenizers_dict.keys()) == {"spl_tokens", "en"} + + assert 
diff --git a/tools/asr_evaluator/asr_evaluator.py b/tools/asr_evaluator/asr_evaluator.py
index 882def194754..5aa97b36d885 100644
--- a/tools/asr_evaluator/asr_evaluator.py
+++ b/tools/asr_evaluator/asr_evaluator.py
@@ -75,6 +75,7 @@ def main(cfg):
             ignore_capitalization=cfg.analyst.metric_calculator.get("ignore_capitalization", False),
             ignore_punctuation=cfg.analyst.metric_calculator.get("ignore_punctuation", False),
             punctuations=cfg.analyst.metric_calculator.get("punctuations", None),
+            strip_punc_space=cfg.analyst.metric_calculator.get("strip_punc_space", False),
         )
     else:
         output_manifest_w_wer, total_res, eval_metric = cal_write_text_metric(
@@ -87,6 +88,7 @@
             punctuations=cfg.analyst.metric_calculator.get("punctuations", None),
             metric=cfg.analyst.metric_calculator.get("metric", "bleu"),
             metric_args=cfg.analyst.metric_calculator.get("metric_args", None),
+            strip_punc_space=cfg.analyst.metric_calculator.get("strip_punc_space", False),
         )
 
     with open_dict(cfg):
diff --git a/tools/asr_evaluator/conf/eval.yaml b/tools/asr_evaluator/conf/eval.yaml
index 883d5a2b1f08..5721d8f19ba6 100644
--- a/tools/asr_evaluator/conf/eval.yaml
+++ b/tools/asr_evaluator/conf/eval.yaml
@@ -37,7 +37,12 @@ engine:
       max_snr_db: 15
       rng: *random_seed
-
+  transcribe_params:
+    # Put additional overrides for params in TranscriptionConfig used by transcribe_speech.py here
+    # Don't put the following fields here: 'calculate_wer', 'model_path', 'pretrained_name', 'dataset_manifest',
+    # 'output_filename', 'batch_size', 'num_workers', 'random_seed', 'eval_config_yaml', 'decoder_type'
+    allow_partial_transcribe: False # only set to True if your audio is too long and the manifest contains 'offset'
+
 
 analyst:
   metric_calculator:
     exist_pred_manifest: null # specify the previously generated manifest will skip engine
@@ -48,6 +53,7 @@ analyst:
     ignore_capitalization: False
     ignore_punctuation: False
     punctuations: null # a string of punctuations to remove when ignore_punctuation=True. if not set, default to '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~'
+    strip_punc_space: False # strip spaces before punctuation, e.g., "I do ." -> "I do."
 
   metadata:
     duration:
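`strip_punc_space` removes spaces that appear immediately before punctuation in the text being scored, so WER/BLEU are not penalized for tokenization artifacts like "I do .". The helper below only approximates that behavior for illustration; the actual logic lives in NeMo's eval utilities, not in this sketch:

```python
import re

# Default punctuation set quoted in the config above.
PUNCTUATIONS = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~'


def strip_punc_space(text: str) -> str:
    # Drop whitespace that directly precedes a punctuation mark.
    return re.sub(rf"\s+([{re.escape(PUNCTUATIONS)}])", r"\1", text)


print(strip_punc_space("I do ."))  # -> "I do."
```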
diff --git a/tools/asr_evaluator/utils.py b/tools/asr_evaluator/utils.py
index 1702dc3caf53..d240ca84e5b9 100644
--- a/tools/asr_evaluator/utils.py
+++ b/tools/asr_evaluator/utils.py
@@ -18,6 +18,7 @@ from typing import Tuple
 
 from omegaconf import DictConfig, OmegaConf, open_dict
 
+from nemo.collections.asr.parts.utils.eval_utils import get_hydra_override_from_config
 from nemo.utils import logging
 
 
@@ -183,7 +184,20 @@ def run_offline_inference(cfg: DictConfig) -> DictConfig:
         OmegaConf.save(cfg, f)
         f.seek(0)  # reset file pointer
         script_path = Path(__file__).parents[2] / "examples" / "asr" / "transcribe_speech.py"
-
+        # some keys to ignore when generating hydra overrides
+        exclude_keys = [
+            'calculate_wer',
+            'model_path',
+            'pretrained_name',
+            'dataset_manifest',
+            'output_filename',
+            'batch_size',
+            'num_workers',
+            'random_seed',
+            'eval_config_yaml',
+            'decoder_type',
+        ]
+        hydra_overrides = get_hydra_override_from_config(cfg.get("transcribe_params", None), exclude_keys=exclude_keys)
         # If need to change other config such as decoding strategy, could either:
         # 1) change TranscriptionConfig on top of the executed scripts such as transcribe_speech.py in examples/asr, or
         # 2) add command as "rnnt_decoding.strategy=greedy_batch " to below script
@@ -198,7 +212,7 @@
             f"num_workers={cfg.test_ds.num_workers} "
             f"random_seed={cfg.random_seed} "
             f"eval_config_yaml={f.name} "
-            f"decoder_type={cfg.inference.decoder_type} ",
+            f"decoder_type={cfg.inference.decoder_type} {hydra_overrides}",
             shell=True,
             check=True,
         )
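For context, `get_hydra_override_from_config` (imported above from `eval_utils`) turns the `transcribe_params` section into extra `key=value` overrides appended to the `transcribe_speech.py` command, skipping `exclude_keys`. The sketch below only approximates that behavior under those assumptions; it is not the actual helper:

```python
from typing import Optional, Sequence

from omegaconf import DictConfig, OmegaConf


def overrides_from_transcribe_params(params: Optional[DictConfig], exclude_keys: Sequence[str] = ()) -> str:
    # Illustrative approximation: flatten top-level keys into hydra-style overrides.
    if params is None:
        return ""
    flat = OmegaConf.to_container(params, resolve=True)
    return " ".join(f"{k}={v}" for k, v in flat.items() if k not in exclude_keys)


cfg = OmegaConf.create({"transcribe_params": {"allow_partial_transcribe": False}})
print(overrides_from_transcribe_params(cfg.get("transcribe_params", None)))  # -> "allow_partial_transcribe=False"
```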