From 568eee9bcf81f1cda2234e07cee2188108200c75 Mon Sep 17 00:00:00 2001
From: Somshubra Majumdar
Date: Wed, 7 Feb 2024 12:34:43 -0800
Subject: [PATCH] Add change_vocabulary and save_tokenizers() support to
 Multitask ASR models (#8357)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add change_vocabulary and save_tokenizers() support

Signed-off-by: smajumdar

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Update nemo/collections/asr/models/aed_multitask_models.py

Co-authored-by: Piotr Żelasko
Signed-off-by: Somshubra Majumdar

---------

Signed-off-by: smajumdar
Signed-off-by: Somshubra Majumdar
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Piotr Żelasko
---
 .../asr/models/aed_multitask_models.py      | 135 +++++++++++++++++-
 nemo/collections/asr/parts/mixins/mixins.py |  95 +++++++++++-
 2 files changed, 228 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py
index 668789ebc156..5740a74bd876 100644
--- a/nemo/collections/asr/models/aed_multitask_models.py
+++ b/nemo/collections/asr/models/aed_multitask_models.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import itertools
 import os
 import tempfile
@@ -23,7 +24,7 @@
 import numpy as np
 import torch
 import torch.distributed as dist
-from omegaconf import DictConfig, OmegaConf, open_dict
+from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
 from pytorch_lightning import Trainer
 from torchmetrics.text import SacreBLEUScore
 from tqdm.auto import tqdm
@@ -247,6 +248,138 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig):
 
         logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}")
 
+    def change_vocabulary(
+        self,
+        new_tokenizer_dir: Union[str, DictConfig],
+        new_tokenizer_type: str,
+        decoding_cfg: Optional[DictConfig] = None,
+        prompt_format: Optional[str] = None,
+    ):
+        """
+        Changes the vocabulary used during the AED decoding process. Use this method when
+        fine-tuning from a pre-trained model. It changes only the decoder and leaves the
+        encoder and pre-processing modules unchanged. For example, you would use it if you
+        want to reuse a pretrained encoder when fine-tuning on data in another language, or
+        when the model needs to learn capitalization, punctuation and/or special characters.
+
+        Args:
+            new_tokenizer_dir: Directory path to the tokenizer, or a config for a new tokenizer
+                (if the tokenizer type is `agg`)
+            new_tokenizer_type: Type of tokenizer. Can be either `agg`, `bpe` or `wpe`.
+            decoding_cfg: A config for the decoding, which is optional. If the decoding type
+                needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here.
+            prompt_format: A string alias of the object that represents the prompt structure.
+                If not None, it will be used to update the prompt format.
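+
+        Example (illustrative usage; the checkpoint name, tokenizer directory and
+        tokenizer type below are placeholders):
+
+            model = EncDecMultiTaskModel.from_pretrained('nvidia/canary-1b')
+            # Placeholder path: point this at a tokenizer trained on the new data
+            model.change_vocabulary(
+                new_tokenizer_dir='/path/to/new_tokenizer_dir',
+                new_tokenizer_type='bpe',
+            )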
+        """
+        if isinstance(new_tokenizer_dir, (dict, DictConfig)):
+            if new_tokenizer_type == 'agg':
+                if not isinstance(new_tokenizer_dir, DictConfig):
+                    new_tokenizer_dir = OmegaConf.create(new_tokenizer_dir)
+
+                new_tokenizer_cfg = new_tokenizer_dir
+            else:
+                raise ValueError(
+                    f'New tokenizer dir should be a string unless the tokenizer type is `agg`, but this tokenizer type is: {new_tokenizer_type}'
+                )
+        else:
+            new_tokenizer_cfg = None
+
+        if new_tokenizer_cfg is not None:
+            tokenizer_cfg = new_tokenizer_cfg
+        else:
+            if not os.path.isdir(new_tokenizer_dir):
+                raise NotADirectoryError(
+                    f'New tokenizer dir must be a non-empty path to a directory, but instead got: {new_tokenizer_dir}'
+                )
+
+            if new_tokenizer_type.lower() not in ('bpe', 'wpe'):
+                raise ValueError('New tokenizer type must be either `bpe` or `wpe`')
+
+            tokenizer_cfg = OmegaConf.create({'dir': new_tokenizer_dir, 'type': new_tokenizer_type})
+
+        if prompt_format is None:
+            prompt_format = self.cfg.prompt_format
+
+        # Setup the tokenizer
+        self._setup_tokenizer(tokenizer_cfg)
+
+        # Retrieve the new vocabulary (token to id mapping) from the tokenizer
+        vocabulary = self.tokenizer.tokenizer.get_vocab()
+
+        # Setup Decoder
+        transf_decoder_cfg_dict = self.transf_decoder.to_config_dict()
+
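+        # The vocab size below is padded up to the nearest multiple of 8. Keeping the
+        # embedding/classifier width divisible by 8 is a common alignment choice for
+        # efficient mixed-precision (Tensor Core) kernels; it does not change which
+        # token ids are valid.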
+        vocab_size = 8 * ceil(self.tokenizer.vocab_size / 8)
+
+        # Auto inject vocab size for `get_transformer`
+        with open_dict(transf_decoder_cfg_dict):
+            if 'config_dict' in transf_decoder_cfg_dict:
+                transf_decoder_cfg_dict['config_dict']['vocab_size'] = vocab_size
+
+        original_decoder_state_dict = self.transf_decoder.state_dict()
+        self.transf_decoder = EncDecMultiTaskModel.from_config_dict(transf_decoder_cfg_dict)
+
+        # Partially load the original state dict into the new decoder
+        decoder_state_dict = self.transf_decoder.state_dict()
+        for og_key, og_value in original_decoder_state_dict.items():
+            if og_key in decoder_state_dict and og_value.shape == decoder_state_dict[og_key].shape:
+                decoder_state_dict[og_key] = og_value
+            else:
+                logging.warning(
+                    f"Skipping key `{og_key}` of the `transf_decoder` module from the original state dict due "
+                    f"to a shape mismatch after the change in vocabulary.\n"
+                    f"Original shape: {og_value.shape}, New shape: {decoder_state_dict[og_key].shape}"
+                )
+
+        self.transf_decoder.load_state_dict(decoder_state_dict)
+
+        # Setup token classifier
+        with open_dict(self.cfg.head):
+            self.cfg.head.num_classes = vocab_size
+
+        del self.log_softmax
+        self.log_softmax = EncDecMultiTaskModel.from_config_dict(self.cfg.head)
+
+        # Weight tying - if using TokenClassifier only
+        if isinstance(self.log_softmax, TokenClassifier):
+            self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight
+
+        # Initialize weights of token classifier
+        std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden ** 0.5
+        self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range))
+
+        # Setup Decoding class
+        if decoding_cfg is None:
+            # Assume the same decoding config as before
+            decoding_cfg = self.cfg.decoding
+
+        # Validate the decoding config against the structured config with all hyperparameters
+        decoding_cls = OmegaConf.structured(MultiTaskDecodingConfig)
+        decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
+        decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)
+
+        del self.decoding
+        self.decoding = MultiTaskDecoding(
+            decoding_cfg=decoding_cfg,
+            transformer_decoder=self.transf_decoder,
+            log_softmax_module=self.log_softmax,
+            tokenizer=self.tokenizer,
+        )
+
+        with open_dict(self.cfg.decoding):
+            self.cfg.decoding = decoding_cfg
+
+        # Setup loss
+        with open_dict(self.cfg.loss):
+            self.cfg.loss.pad_id = self.tokenizer.pad_id
+
+        del self.loss
+        self.loss = EncDecMultiTaskModel.from_config_dict(self.cfg.loss)
+
+        # Update config
+        with open_dict(self.cfg):
+            self.cfg.prompt_format = prompt_format
+
+        logging.info(f"Changed decoder to output a vocabulary of size {len(vocabulary)}.")
+
     @torch.no_grad()
     def transcribe(
         self,
diff --git a/nemo/collections/asr/parts/mixins/mixins.py b/nemo/collections/asr/parts/mixins/mixins.py
index eeac9d3c78ad..006f028a0a1d 100644
--- a/nemo/collections/asr/parts/mixins/mixins.py
+++ b/nemo/collections/asr/parts/mixins/mixins.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 import os
+import shutil
+import tarfile
 from abc import ABC, abstractmethod
 from typing import List
@@ -25,7 +27,7 @@
 from nemo.collections.asr.parts.utils import asr_module_utils
 from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
 from nemo.collections.common import tokenizers
-from nemo.utils import logging
+from nemo.utils import app_state, logging
 
 
 class ASRBPEMixin(ABC):
@@ -372,6 +374,97 @@ def _cleanup_aggregate_config_and_artifacts_if_needed(self):
                 if akey.startswith('tokenizer.' + self.AGGREGATE_TOKENIZERS_DICT_PREFIX + '.'):
                     self.artifacts.pop(akey)
 
+    def save_tokenizers(self, directory: str):
+        """
+        Save the model's tokenizer(s) to the specified directory.
+
+        Args:
+            directory: The directory to save the tokenizer(s) to.
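+
+        Example (illustrative usage; the import alias, checkpoint path and output
+        directory below are placeholders):
+
+            import nemo.collections.asr as nemo_asr
+
+            model = nemo_asr.models.ASRModel.restore_from('/path/to/model.nemo')
+            # Placeholder output directory; one subdirectory per language is created
+            # when the model uses an aggregate tokenizer
+            model.save_tokenizers('/path/to/exported_tokenizers')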
+        """
+        if not hasattr(self, 'cfg'):
+            raise RuntimeError(
+                "The model has not been initialized with a tokenizer yet. Please call the model's "
+                "__init__ and _setup_tokenizer methods first."
+            )
+
+        if self.tokenizer_type == 'agg':
+            for lang in self.tokenizer.langs:
+                subconfig = self.cfg.tokenizer.langs.get(lang)
+                new_dir = os.path.join(directory, lang)
+                self._extract_tokenizer_from_config(subconfig, new_dir)
+        else:
+            self._extract_tokenizer_from_config(self.cfg.tokenizer, directory)
+
+    def _extract_tokenizer_from_config(self, tokenizer_cfg: DictConfig, dir: str):
+        """
+        Extracts the tokenizer from the config and writes the tokenizer objects to `dir`.
+        Each file may come from a local path (new model init) or from a .nemo file (restored
+        model). If it is from a newly initialized model, the file is copied to `dir`. If it
+        is from a restored model, the file is extracted from the .nemo archive and then
+        copied to `dir`.
+
+        Args:
+            tokenizer_cfg: The tokenizer config to extract the tokenizer from.
+            dir: The directory to write the tokenizer objects to.
+        """
+        if not os.path.exists(dir):
+            os.makedirs(dir, exist_ok=True)
+
+        nemo_file_objects = []
+
+        for k, v in tokenizer_cfg.items():
+            # Check if the value is a filepath (new model init) or has a `nemo:` prefix (restored model)
+            if isinstance(v, str) and os.path.exists(v):
+                # local file from first instantiation
+                loc = shutil.copy2(v, dir)
+                logging.info(f"Saved {k} at {loc}")
+
+            if isinstance(v, str) and v.startswith('nemo:'):
+                nemo_object_name = v[5:]
+                nemo_file_objects.append(nemo_object_name)
+
+        if len(nemo_file_objects) > 0:
+            logging.debug(f"Copying the following nemo file objects to {dir}: {nemo_file_objects}")
+
+            if not hasattr(self, 'model_guid'):
+                raise ValueError(
+                    "The model does not have a model_guid attribute. "
+                    "Please ensure that the model has been restored from a .nemo file."
+                )
+
+            appstate = app_state.AppState()
+            restore_path = appstate.get_model_metadata_from_guid(self.model_guid).restoration_path
+            if restore_path is None:
+                raise ValueError(
+                    "The model has not been restored from a .nemo file. Cannot extract the tokenizer "
+                    "as the nemo file cannot be located."
+                )
+
+            # Read the nemo file without fully extracting all contents.
+            # We start with the assumption of an uncompressed tar, which
+            # should be true for versions 1.7.0 and above.
+            tar_header = "r:"
+            try:
+                tar_test = tarfile.open(restore_path, tar_header)
+                tar_test.close()
+            except tarfile.ReadError:
+                # can be an older checkpoint => try a compressed tar
+                tar_header = "r:gz"
+            tar = tarfile.open(restore_path, tar_header)
+
+            for nemo_object_name in nemo_file_objects:
+                members = [x for x in tar.getmembers() if nemo_object_name in x.name]
+                for member in members:
+                    tar.extract(member, dir)
+
+                    # Artifacts in a .nemo archive are stored with a GUID prefix,
+                    # e.g. `<guid>_tokenizer.model`; strip the prefix to recover
+                    # the original file name
+                    new_name = member.name.split("_")[1:]
+                    if len(new_name) > 1:
+                        new_name = "_".join(new_name)
+                    else:
+                        new_name = new_name[0]
+                    os.rename(os.path.join(dir, member.name), os.path.join(dir, new_name))
+
+                    logging.info(f"Saved {nemo_object_name} at {os.path.join(dir, new_name)}")
+
 
 class ASRModuleMixin(ASRAdapterModelMixin):
     """