From 7a23bfa3969da3acb60a3f00a5191652833ca880 Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Mon, 13 May 2024 13:53:19 -0400 Subject: [PATCH 01/18] Change FIM Dataset Random Seed Init (#9165) * change seed to dataset init * Apply isort and black reformatting Signed-off-by: suiyoubi --------- Signed-off-by: suiyoubi Co-authored-by: suiyoubi --- .../megatron/gpt_fim_dataset.py | 50 ++++++++++--------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_fim_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_fim_dataset.py index 474761c41d67b..358dbc22a2cd6 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_fim_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_fim_dataset.py @@ -17,6 +17,7 @@ import numpy as np from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults +from nemo.utils import logging try: from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig @@ -36,8 +37,8 @@ class GPTFIMDatasetConfig(GPTDatasetConfig): """Configuration object for Megatron Core GPT FIM datasets - Attributes: - fim: fill in the middle parameters config + Attributes: + fim: fill in the middle parameters config """ def __init__(self, fim, **kwargs): @@ -79,6 +80,27 @@ def __init__( super().__init__(indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config) self.indexed_dataset = indexed_dataset + self.np_rng = np.random.RandomState(seed=self.config.random_seed) + logging.info(f"Initialized FIM RNG with seed = {self.config.random_seed}") + # get FIM params + self.fim_rate = self.config.fim.get('rate', 0.5) + self.fim_spm_rate = self.config.fim.get('spm_rate', 0.5) + self.fragment_fim_rate = self.config.fim.get('fragment_rate', 0.5) + split_sample = self.config.fim.get('split_sample', None) + self.fim_split_sample = self.config.tokenizer.tokens_to_ids(split_sample) if split_sample else None + self.no_fim_prefix = self.config.fim.get('no_prefix', None) + + # get extra tokens ids + fim_tokens = self.config.fim.extra_tokens + fim_tokens = [fim_tokens.prefix, fim_tokens.middle, fim_tokens.suffix, fim_tokens.pad, fim_tokens.eod] + fim_tokens_ids = self.config.tokenizer.tokens_to_ids(fim_tokens) + ( + self.prefix_tok_id, + self.middle_tok_id, + self.suffix_tok_id, + self.pad_tok_id, + self.eod_tok_id, + ) = fim_tokens_ids def _query_document_sample_shuffle_indices(self, idx: int) -> Tuple[np.ndarray, np.ndarray]: """Get the text (token ids) and document ids for a given index @@ -126,29 +148,9 @@ def _query_document_sample_shuffle_indices(self, idx: int) -> Tuple[np.ndarray, sample = np.concatenate(sample_parts) - # get FIM params - self.fim_rate = self.config.fim.get('rate', 0.5) - self.fim_spm_rate = self.config.fim.get('spm_rate', 0.5) - self.fragment_fim_rate = self.config.fim.get('fragment_rate', 0.5) - split_sample = self.config.fim.get('split_sample', None) - self.fim_split_sample = self.config.tokenizer.tokens_to_ids(split_sample) if split_sample else None - self.no_fim_prefix = self.config.fim.get('no_prefix', None) - - # get extra tokens ids - fim_tokens = self.config.fim.extra_tokens - fim_tokens = [fim_tokens.prefix, fim_tokens.middle, fim_tokens.suffix, fim_tokens.pad, fim_tokens.eod] - fim_tokens_ids = self.config.tokenizer.tokens_to_ids(fim_tokens) - ( - self.prefix_tok_id, - self.middle_tok_id, - self.suffix_tok_id, - self.pad_tok_id, - self.eod_tok_id, - ) = fim_tokens_ids - sample_len = sample.shape[0] 
segment_breaks = np.argwhere(sample == self.eod_tok_id) - np_rng = np.random.RandomState(seed=self.config.random_seed) + np_rng = self.np_rng if segment_breaks.shape != (0, 1): # then there is an EOD token in this example curr_start_position = 0 @@ -245,7 +247,7 @@ def _permute( no_fim_prefix=None, ): """ - Take in a sample (np array w/ size (0,chunklength)) and perform a FIM transformation on it. + Take in a sample (np array w/ size (0,chunklength)) and perform a FIM transformation on it. Maintain the same sample length (if transform creates a few extra tokens, drop them). """ if np_rng.binomial(1, fim_rate): # sample bernoulli dist From 43686ecef00837bca9a1c63e64759dc57d4fe2f7 Mon Sep 17 00:00:00 2001 From: Pablo Garay Date: Mon, 13 May 2024 15:40:54 -0700 Subject: [PATCH 02/18] increase time limit for Speech_Checkpoints_tests (#9186) --- .github/workflows/cicd-main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index ef646ab92e7ba..4652e4d19f897 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -6484,7 +6484,7 @@ jobs: Speech_Checkpoints_tests: needs: [cicd-test-container-setup] runs-on: self-hosted-azure - timeout-minutes: 10 + timeout-minutes: 20 container: image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} options: From 467d94b7b9ab796b49025487edc05e635e0f8a94 Mon Sep 17 00:00:00 2001 From: gdengk <160076886+gdengk@users.noreply.github.com> Date: Mon, 13 May 2024 15:58:56 -0700 Subject: [PATCH 03/18] fix ep rank (#9161) Signed-off-by: Gao Deng --- nemo/collections/nlp/modules/common/megatron/megatron_init.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py index 5d5b65b360eec..341e534bcd898 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -315,7 +315,7 @@ def fake_initialize_model_parallel( if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1: for ranks in rank_generator.get_ranks('ep', independent_ep=True): if rank in ranks: - expert_model_parallel_rank = list(ranks).index(rank) // tensor_model_parallel_size + expert_model_parallel_rank = list(ranks).index(rank) # Build the pipeline model-parallel groups and embedding groups # (first and last rank in each pipeline model-parallel group). 
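The one-line fix in PATCH 03/18 above derives the expert-parallel rank directly from the rank's position inside its 'ep' group instead of dividing that position by the tensor-model-parallel size. A minimal, hypothetical sketch of the resulting lookup (the group layout in the example is illustrative only, not taken from Megatron-Core's rank generator):

def expert_model_parallel_rank(rank, ep_group):
    """Position of `rank` inside its expert-parallel group, i.e. its EP rank."""
    # Assumption: with independent expert parallelism the 'ep' group lists exactly one
    # global rank per EP peer, so the index itself is the EP rank; the old
    # `index // tensor_model_parallel_size` under-counted it.
    return list(ep_group).index(rank)

# Hypothetical layout: global ranks 1, 3, 5, 7 form one EP group of size 4.
assert expert_model_parallel_rank(5, [1, 3, 5, 7]) == 2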
From 77090d4e5e218261b1fe6b3a931d16f4083f2d53 Mon Sep 17 00:00:00 2001 From: meatybobby Date: Mon, 13 May 2024 16:14:34 -0700 Subject: [PATCH 04/18] TRTLLM new API support (#9003) * Add trtllm checkpoint * Change model config * fix no query_group * Using build API * Change export to new API * Update generate API * Fix runtime config * Fix for llama * Fix for ptuning * Fix TP issue * Change TP rank for building weight dict * Add lora config * add prompt embedding table config * Fix PP isue * PP layers fix * Fix no prompt task ids * Add bos for Gemma * Add multi block mode * Embedding and layernorm for PP * MPI multiprocess support for multinode * Only output text on first rank * Change to ModelRunnerCpp * Add falcon * Add rotary_pct default value * Falcon fix * Add MOE config * Fix MOE weight dict * Clean code * Add rotary_base * Fix MOE config * Fix falcon new architecture * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Gemma 7B * Add rotary_scaling * Apply isort and black reformatting Signed-off-by: oyilmaz-nvidia --------- Signed-off-by: oyilmaz-nvidia Co-authored-by: abharwani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Onur Yilmaz <35306097+oyilmaz-nvidia@users.noreply.github.com> Co-authored-by: oyilmaz-nvidia Co-authored-by: Eric Harper --- nemo/export/tensorrt_llm.py | 138 ++++++++------ nemo/export/trt_llm/decoder/__init__.py | 8 + nemo/export/trt_llm/nemo/convert.py | 71 ++++--- nemo/export/trt_llm/nemo/nemo_ckpt_convert.py | 39 ++-- nemo/export/trt_llm/nemo_utils.py | 180 +++++++++++++++++- nemo/export/trt_llm/tensorrt_llm_build.py | 90 ++++++++- nemo/export/trt_llm/tensorrt_llm_run.py | 130 ++++++------- scripts/export/export_to_trt_llm.py | 12 +- 8 files changed, 468 insertions(+), 200 deletions(-) diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index 033044b3b3285..af4f1b6699ee3 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -30,9 +30,10 @@ from nemo.export.tarutils import TarPath, unpack_tarball from nemo.export.trt_llm.model_config_trt import model_config_to_tensorrt_llm from nemo.export.trt_llm.nemo.nemo_ckpt_convert import build_tokenizer -from nemo.export.trt_llm.nemo_utils import get_tokenzier, nemo_llm_model_to_model_config, nemo_llm_to_model_config +from nemo.export.trt_llm.nemo_utils import get_tokenzier, nemo_llm_model_to_model_config, nemo_to_trtllm_config from nemo.export.trt_llm.qnemo import qnemo_to_tensorrt_llm from nemo.export.trt_llm.qnemo.tokenizer_utils import get_nmt_tokenizer +from nemo.export.trt_llm.tensorrt_llm_build import build_and_save_engine from nemo.export.trt_llm.tensorrt_llm_run import generate, generate_streaming, load, load_refit from nemo.export.trt_llm.utils import is_nemo_file @@ -115,6 +116,7 @@ def export( max_output_token: int = 256, max_batch_size: int = 8, max_prompt_embedding_table_size=None, + use_parallel_embedding: bool = False, use_inflight_batching: bool = False, enable_context_fmha: bool = True, paged_kv_cache: bool = False, @@ -188,65 +190,70 @@ def export( self.model = None - tmp_dir = tempfile.TemporaryDirectory() - nemo_export_dir = Path(tmp_dir.name) + if tensorrt_llm.mpi_rank() == 0: + tmp_dir = tempfile.TemporaryDirectory() + nemo_export_dir = Path(tmp_dir.name) - if nemo_checkpoint_path.endswith("qnemo"): - if os.path.isdir(nemo_checkpoint_path): - nemo_export_dir = nemo_checkpoint_path + if nemo_checkpoint_path.endswith("qnemo"): + if 
os.path.isdir(nemo_checkpoint_path): + nemo_export_dir = nemo_checkpoint_path + else: + unpack_tarball(nemo_checkpoint_path, tmp_dir.name) + nemo_checkpoint_path = tmp_dir.name + self.tokenizer = get_nmt_tokenizer(nemo_checkpoint_path) + + qnemo_to_tensorrt_llm( + nemo_checkpoint_path=nemo_checkpoint_path, + engine_dir=self.model_dir, + max_input_len=max_input_token, + max_output_len=max_output_token, + max_batch_size=max_batch_size, + max_prompt_embedding_table_size=max_prompt_embedding_table_size, + lora_target_modules=lora_target_modules, + ) else: - unpack_tarball(nemo_checkpoint_path, tmp_dir.name) - nemo_checkpoint_path = tmp_dir.name - self.tokenizer = get_nmt_tokenizer(nemo_checkpoint_path) - - qnemo_to_tensorrt_llm( - nemo_checkpoint_path=nemo_checkpoint_path, - engine_dir=self.model_dir, - max_input_len=max_input_token, - max_output_len=max_output_token, - max_batch_size=max_batch_size, - max_prompt_embedding_table_size=max_prompt_embedding_table_size, - lora_target_modules=lora_target_modules, - ) - else: - model_configs, self.tokenizer = nemo_llm_to_model_config( - in_file=nemo_checkpoint_path, - decoder_type=model_type, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - pipeline_parallel_size=pipeline_parallel_size, - nemo_export_dir=nemo_export_dir, - save_nemo_model_config=save_nemo_model_config, - ) + weights_dicts, model_configs, self.tokenizer = nemo_to_trtllm_config( + in_file=nemo_checkpoint_path, + decoder_type=model_type, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + pipeline_parallel_size=pipeline_parallel_size, + use_parallel_embedding=use_parallel_embedding, + nemo_export_dir=nemo_export_dir, + save_nemo_model_config=save_nemo_model_config, + ) - model_config_to_tensorrt_llm( - model_configs, - self.model_dir, - world_size=tensor_parallel_size * pipeline_parallel_size, - max_input_len=max_input_token, - max_output_len=max_output_token, - max_batch_size=max_batch_size, - max_prompt_embedding_table_size=max_prompt_embedding_table_size, - use_inflight_batching=use_inflight_batching, - paged_kv_cache=paged_kv_cache, - enable_context_fmha=enable_context_fmha, - enable_multi_block_mode=enable_multi_block_mode, - use_lora_plugin=use_lora_plugin, - lora_target_modules=lora_target_modules, - max_lora_rank=max_lora_rank, - ) + for weight_dict, model_config in zip(weights_dicts, model_configs): + build_and_save_engine( + max_input_len=max_input_token, + max_output_len=max_output_token, + max_batch_size=max_batch_size, + model_config=model_config, + model_weights=weight_dict, + model_dir=self.model_dir, + model_type=model_type, + lora_ckpt_list=self.lora_ckpt_list, + use_lora_plugin=use_lora_plugin, + max_lora_rank=max_lora_rank, + lora_target_modules=lora_target_modules, + max_prompt_embedding_table_size=max_prompt_embedding_table_size, + enable_multi_block_mode=enable_multi_block_mode, + ) - tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model") - if os.path.exists(tokenizer_path): - shutil.copy(tokenizer_path, self.model_dir) - else: - self.tokenizer.save_pretrained(os.path.join(self.model_dir, 'huggingface_tokenizer')) + tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model") + if os.path.exists(tokenizer_path): + shutil.copy(tokenizer_path, self.model_dir) + else: + self.tokenizer.save_pretrained(os.path.join(self.model_dir, 'huggingface_tokenizer')) + + nemo_model_config = os.path.join(nemo_export_dir, "model_config.yaml") + if os.path.exists(nemo_model_config): + shutil.copy(nemo_model_config, self.model_dir) - 
nemo_model_config = os.path.join(nemo_export_dir, "model_config.yaml") - if os.path.exists(nemo_model_config): - shutil.copy(nemo_model_config, self.model_dir) + tmp_dir.cleanup() - tmp_dir.cleanup() + if tensorrt_llm.mpi_world_size() > 1: + tensorrt_llm.mpi_barrier() if load_model: self._load() @@ -279,7 +286,9 @@ def build( # Build or refit TRT-LLM engine from a nemo model. model_configs = nemo_llm_model_to_model_config( - nemo_model=nemo_model, decoder_type=model_type, nemo_model_config=nemo_model_config, + nemo_model=nemo_model, + decoder_type=model_type, + nemo_model_config=nemo_model_config, ) model_config_to_tensorrt_llm( @@ -298,7 +307,9 @@ def build( ) def refit( - self, nemo_model, nemo_model_config, + self, + nemo_model, + nemo_model_config, ): assert self.use_refit, "TRT-LLM model must be built() with refit=True" @@ -329,7 +340,6 @@ def forward( output_log_probs: bool = False, **sampling_kwargs, ): - """ Exports nemo checkpoints to TensorRT-LLM. @@ -394,7 +404,7 @@ def forward( ), "Task: {0} doesn't exist in the task list.".format(task_ids[i]) input_task_ids.append(self.task_ids[task_ids[i]]) if not streaming: - if torch.distributed.is_initialized(): + if torch.distributed.is_initialized() or tensorrt_llm.mpi_world_size() > 1: multiprocessed_env = True else: multiprocessed_env = False @@ -478,7 +488,7 @@ def get_hidden_size(self): if self.config is None: return None else: - return self.config["builder_config"]["hidden_size"] + return self.config["pretrained_config"]["hidden_size"] @property def get_triton_input(self): @@ -665,7 +675,9 @@ def _get_prompt_embedding_table_ckpt(self, prompt_embeddings_checkpoint_path): return weights.cpu().detach() def _get_prompt_embedding_table( - self, prompt_embeddings_table=None, prompt_embeddings_checkpoint_path=None, + self, + prompt_embeddings_table=None, + prompt_embeddings_checkpoint_path=None, ): if prompt_embeddings_table is not None and prompt_embeddings_checkpoint_path is not None: LOGGER.warning( @@ -694,15 +706,15 @@ def _get_prompt_embedding_table( raise TypeError(prompt_embeddings_checkpoint_path + " is not a nemo file.") prompt_embeddings_table = self._get_prompt_embedding_table_ckpt(prompt_embeddings_checkpoint_path) - dtype = self.config['builder_config']['precision'] + dtype = self.config['pretrained_config']['dtype'] prompt_embeddings_table = prompt_embeddings_table.to( dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype) ).cuda() - if prompt_embeddings_table.size(dim=1) != self.config["builder_config"]["hidden_size"]: + if prompt_embeddings_table.size(dim=1) != self.config["pretrained_config"]["hidden_size"]: raise Exception( "Hidden dimension of the model is {0} and does not match with the dimension of the prompt table.".format( - self.config["builder_config"]["hidden_size"] + self.config["pretrained_config"]["hidden_size"] ) ) diff --git a/nemo/export/trt_llm/decoder/__init__.py b/nemo/export/trt_llm/decoder/__init__.py index 5fe749408cb9e..b5e22b5e513e1 100644 --- a/nemo/export/trt_llm/decoder/__init__.py +++ b/nemo/export/trt_llm/decoder/__init__.py @@ -40,6 +40,14 @@ DECODER_GEMMA: GemmaDecoderLayerConfigBuilder, } +DECODER_MODEL_TYPE = { + DECODER_GPT2: 'GPTForCausalLM', + DECODER_GPTNEXT: 'GPTForCausalLM', + DECODER_LLAMA: 'LLaMAForCausalLM', + DECODER_GEMMA: 'GemmaForCausalLM', + DECODER_FALCON: 'FalconForCausalLM', +} + def build_decoder_layer_config(layer, decoder: str, dtype=trt.float16, rank=0, tensor_parallel=1): """Builds the decoder layer config with the input torch module.""" diff --git 
a/nemo/export/trt_llm/nemo/convert.py b/nemo/export/trt_llm/nemo/convert.py index 09476da6b939f..7598b3f6825f7 100644 --- a/nemo/export/trt_llm/nemo/convert.py +++ b/nemo/export/trt_llm/nemo/convert.py @@ -39,12 +39,12 @@ def gpu_map_location(storage, loc): def save_val(val, dir, key, tp_num=None): - suffix = "bin" if tp_num is None else f"{tp_num}.bin" + suffix = "" if tp_num is None else f".{tp_num}.bin" # Transpose linear layer weights to the correct shape. if len(val.shape) >= 2: val = np.ascontiguousarray(np.transpose(val.reshape(val.shape[0], -1), [1, 0])) global weights_dict - weights_dict[f"model.{key}.{suffix}"] = val + weights_dict[f"{key}{suffix}"] = val def save_split(split_vals, dir, key, i, split_factor): @@ -55,10 +55,10 @@ def save_split(split_vals, dir, key, i, split_factor): def save_expert_split(split_vals, dir, key, i, split_factor): for j, val in enumerate(split_vals): tp_num = i * split_factor + j - suffix = "bin" if tp_num is None else f"{tp_num}.bin" + suffix = "" if tp_num is None else f".{tp_num}.bin" global weights_dict - weights_dict[f"model.{key}.{suffix}"] = val + weights_dict[f"{key}{suffix}"] = val def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False): @@ -183,6 +183,9 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only" + layer_num = key.split(".")[1] + layer_prefix = f'transformer.layers.{layer_num}' + if not isinstance(vals, list): vals = [vals] @@ -210,12 +213,27 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t or "final_layernorm.bias" in key ): # shared weights, only need to convert the weights of rank 0 - if "post_self_attn_layernorm.weight" in key: - key = key.replace("post_self_attn_layernorm.weight", "post_attention_layernorm.weight") - elif "mlp.linear_fc2.bias" in key: - key = key.replace("mlp.linear_fc2.bias", "mlp.dense_4h_to_h.bias") - elif "attention.linear_proj.bias" in key: - key = key.replace("attention.linear_proj.bias", "attention.dense.bias") + if "post_self_attn_layernorm" in key or "post_attention_layernorm" in key: + if key.endswith('weight'): + key = f'{layer_prefix}.post_layernorm.weight' + else: + key = f'{layer_prefix}.post_layernorm.bias' + elif "mlp.linear_fc2.bias" in key or "mlp.dense_4h_to_h.bias" in key: + key = f'{layer_prefix}.mlp.proj.bias' + elif "attention.linear_proj.bias" in key or "attention.dense.bias" in key: + key = f'{layer_prefix}.attention.dense.bias' + elif "final_layernorm" in key: + key = key.replace("final_layernorm", "transformer.ln_f") + elif "input_layernorm" in key: + if key.endswith('weight'): + key = f'{layer_prefix}.input_layernorm.weight' + else: + key = f'{layer_prefix}.input_layernorm.bias' + elif "pre_mlp_layernorm" in key: + if key.endswith('weight'): + key = f'{layer_prefix}.post_layernorm.weight' + else: + key = f'{layer_prefix}.post_layernorm.bias' if tp_rank == 0: save_val(vals[0], saved_dir, key) @@ -228,10 +246,10 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t cat_dim = 0 val = np.concatenate(vals, axis=cat_dim) split_vals = np.split(val, split_factor, axis=cat_dim) - if "attention.linear_proj.weight" in key: - key = key.replace("attention.linear_proj.weight", "attention.dense.weight") - elif "mlp.linear_fc2.weight" in key: - key = key.replace("mlp.linear_fc2.weight", "mlp.dense_4h_to_h.weight") + if "attention.linear_proj.weight" in key or "attention.dense.weight" in key: + key = 
f'{layer_prefix}.attention.dense.weight' + elif "mlp.linear_fc2.weight" in key or "mlp.dense_4h_to_h.weight" in key: + key = f'{layer_prefix}.mlp.proj.weight' save_split(split_vals, saved_dir, key, tp_rank, split_factor) if act_range is not None and int8_outputs == "all": base_key = key.replace(".weight", "") @@ -251,8 +269,10 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t val = np.concatenate(vals, axis=cat_dim) split_vals = np.split(val, split_factor, axis=cat_dim) - if "mlp.linear_fc1" in key: - key = key.replace("mlp.linear_fc1", "mlp.dense_h_to_4h") + if key.endswith("weight"): + key = f'{layer_prefix}.mlp.fc.weight' + else: + key = f'{layer_prefix}.mlp.fc.bias' save_split(split_vals, saved_dir, key, tp_rank, split_factor) if act_range is not None and int8_outputs == "all": base_key = key.replace(".weight", "") @@ -261,8 +281,10 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t if split_gated_activation: assert not save_int8 - prefix, dot, suffix = key.rpartition(".") - key = prefix + ".gate" + dot + suffix + if key.endswith("weight"): + key = f'{layer_prefix}.mlp.gate.weight' + else: + key = f'{layer_prefix}.mlp.gate.bias' gate = np.concatenate(gates, axis=cat_dim) split_vals = np.split(gate, split_factor, axis=cat_dim) @@ -279,9 +301,6 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) elif "attention.query_key_value.bias" in key or "attention.linear_qkv.bias" in key: - if "attention.linear_qkv.bias" in key: - key = key.replace("attention.linear_qkv.bias", "attention.query_key_value.bias") - qkv_hidden_dim = vals[0].shape[0] size_per_head = qkv_hidden_dim // (num_attention_heads + 2 * num_kv_heads) q_num = num_attention_heads // num_kv_heads @@ -304,6 +323,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t np.concatenate([q_split[i].reshape(-1), k_split[i].reshape(-1), v_split[i].reshape(-1)], axis=0) for i in range(split_factor) ] + key = f'{layer_prefix}.attention.qkv.bias' save_split(split_vals, saved_dir, key, tp_rank, split_factor) elif "attention.query_key_value.weight" in key or "attention.linear_qkv.weight" in key: @@ -342,8 +362,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t for i in range(split_factor) ] - if "attention.linear_qkv.weight" in key: - key = key.replace("attention.linear_qkv.weight", "attention.query_key_value.weight") + key = f'{layer_prefix}.attention.qkv.weight' save_split(split_vals, saved_dir, key, tp_rank, split_factor) if save_int8: base_key = key.replace(".weight", "") @@ -366,8 +385,8 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t pass elif "mlp.router.weight" in key: val = np.concatenate(vals, axis=1) - split_vals = np.split(val, split_factor, axis=0) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + key = f'{layer_prefix}.mlp.router.weight' + save_val(val, saved_dir, key) elif "experts.linear_fc1.weight" in key: cat_dim = -1 val = np.concatenate(vals, axis=cat_dim) @@ -378,12 +397,14 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t split_w3s = np.split(w3, split_factor, axis=1) split_vals = [np.concatenate(item, axis=1) for item in zip(split_w3s, split_w1s)] + key = f'{layer_prefix}.mlp.experts_weight_1' save_expert_split(split_vals, saved_dir, key, tp_rank, split_factor) elif "experts.linear_fc2.weight" in key: 
cat_dim = -1 val = np.concatenate(vals, axis=cat_dim) split_vals = np.split(val, split_factor, axis=cat_dim) + key = f'{layer_prefix}.mlp.experts_weight_2' save_expert_split(split_vals, saved_dir, key, tp_rank, split_factor) else: print(f"[WARNING] {key} not handled by converter") diff --git a/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py b/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py index d9135d5c0c21a..44133de381bd3 100644 --- a/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py +++ b/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py @@ -27,7 +27,7 @@ import tensorstore # This is important even though not used. Otherwise zarr raises error. import torch import zarr -from tensorrt_llm._utils import np_bfloat16, str_dtype_to_torch, torch_to_numpy +from tensorrt_llm._utils import np_bfloat16, pad_vocab_size, str_dtype_to_torch, torch_to_numpy from tqdm import tqdm from transformers import AutoTokenizer, GPT2Tokenizer, LlamaConfig @@ -174,6 +174,7 @@ def convert_dist_checkpoint(unpacked_checkpoints_dir: UnpackedNemoCheckpointDir, multi_query_mode = nemo_model_config.get("multi_query_mode", False) num_attention_heads = nemo_model_config["num_attention_heads"] kv_channels = nemo_model_config.get("kv_channels", None) + use_parallel_embedding = args.use_parallel_embedding if num_kv_heads == 0: if multi_query_mode: num_kv_heads = 1 @@ -191,6 +192,7 @@ def convert_dist_checkpoint(unpacked_checkpoints_dir: UnpackedNemoCheckpointDir, "kv_channels": kv_channels, "use_attention_nemo_shape": True, "transpose_weights": True, + "use_parallel_embedding": use_parallel_embedding, } # split_factor: in how many parts a TP training node is split @@ -202,22 +204,30 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): if has_position_embedding: val = model[get_layer_name("position_embedding", prefix)] val = torch_to_numpy(val.to(storage_type).cpu()) - model_level_weights["model.wpe.bin"].append(val) + model_level_weights["transformer.position_embedding.weight"].append(val) if pp_idx == 0: val = model.get("state_dict", model)[get_layer_name("word_embedding", prefix)] if embedding_scaling: val = val * float(math.sqrt(hidden_size)) + vocab_size = val.shape[0] + if use_parallel_embedding: + # Pad vocab_size first + if vocab_size % inference_tp_size != 0: + vocab_size_padded = pad_vocab_size(vocab_size, inference_tp_size) + pad_width = vocab_size_padded - vocab_size + val = torch.nn.functional.pad(val, (0, 0, 0, pad_width), value=0) + val = torch_to_numpy(val.to(storage_type).cpu()) - model_level_weights["model.wte.bin"].append(val) + model_level_weights["transformer.vocab_embedding.weight"].append(val) if share_embeddings_and_output: val = model.get("state_dict", model)[get_layer_name("word_embedding", prefix)] val = torch_to_numpy(val.to(storage_type).cpu()) - model_level_weights["model.lm_head.weight.bin"].append(val) + model_level_weights["lm_head.weight"].append(val) if has_lm_head and pp_idx == training_pp_size - 1: val = model.get("state_dict", model)[get_layer_name("output_layer", prefix)] val = torch_to_numpy(val.to(storage_type).cpu()) - model_level_weights["model.lm_head.weight.bin"].append(val) + model_level_weights["lm_head.weight"].append(val) weights_dict = {} @@ -280,7 +290,6 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): model_level_weights[key] = np.concatenate(values, axis=0) weights_dict[key] = model_level_weights[key] - vocab_size = model_level_weights["model.wte.bin"].shape[0] if nemo_model_config["tokenizer"].get("library", None) == "huggingface": tokenizer = 
AutoTokenizer.from_pretrained( @@ -293,23 +302,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): tokenizer_config["model"] = os.path.join(out_dir, "tokenizer.model") tokenizer = build_tokenizer(tokenizer_config) - llm_config = nemo_to_llm_config( - nemo_model_config, vocab_size, tokenizer.eos_token_id, tokenizer.bos_token_id, args.decoder_type, - ) - - llm_config.is_mcore = is_mcore - - config = configparser.ConfigParser() - decoder_name_dict = {"llama": "llama", "falcon": "falcon"} - model_name = decoder_name_dict[args.decoder_type] if args.decoder_type in decoder_name_dict else "gpt" - - config[model_name] = {k: str(v) for k, v in vars(llm_config).items()} - config[model_name]["storage_dtype"] = args.storage_type - config_path = out_dir / "config.ini" - with config_path.open("w") as config_file: - config.write(config_file) - - return weights_dict, llm_config, tokenizer + return weights_dict, nemo_model_config, tokenizer @torch.no_grad() diff --git a/nemo/export/trt_llm/nemo_utils.py b/nemo/export/trt_llm/nemo_utils.py index ee2073fa518d0..d735cab36b006 100644 --- a/nemo/export/trt_llm/nemo_utils.py +++ b/nemo/export/trt_llm/nemo_utils.py @@ -28,9 +28,14 @@ import numpy as np import tensorrt_llm from tensorrt_llm import str_dtype_to_trt -from transformers import AutoTokenizer, LlamaConfig, PretrainedConfig, PreTrainedTokenizer +from tensorrt_llm._utils import pad_vocab_size +from tensorrt_llm.functional import non_gated_version +from tensorrt_llm.layers import MoeConfig +from tensorrt_llm.models.modeling_utils import PretrainedConfig +from transformers import AutoTokenizer, LlamaConfig, PreTrainedTokenizer from nemo.export.tarutils import TarPath +from nemo.export.trt_llm.decoder import DECODER_MODEL_TYPE from nemo.export.trt_llm.model_config import ( LAYERNORM_DEFAULT, LAYERNORM_RMS, @@ -56,6 +61,7 @@ def _nemo_llm_decode( storage_type: str = "bfloat16", load_checkpoints_on_gpu: bool = False, decoder_type: str = "gptnext", + use_parallel_embedding: bool = False, save_nemo_model_config: bool = False, ) -> Tuple[Dict[str, np.ndarray], PretrainedConfig, PreTrainedTokenizer]: """Decodes the NEMO file and returns the weights dict, llm config and tokenizer.""" @@ -67,6 +73,7 @@ def _nemo_llm_decode( args.load_checkpoints_on_gpu = load_checkpoints_on_gpu args.verbose = False args.decoder_type = decoder_type + args.use_parallel_embedding = use_parallel_embedding if not os.path.exists(in_file): LOGGER.error("%s does not exist", in_file) @@ -194,7 +201,9 @@ def nemo_llm_to_model_config( def to_word_list_format( - word_dict: List[List[str]], tokenizer=None, ref_str="", + word_dict: List[List[str]], + tokenizer=None, + ref_str="", ): ''' format of word_dict @@ -250,7 +259,10 @@ def to_word_list_format( def nemo_llm_model_to_model_config( - nemo_model: str, decoder_type: str, nemo_model_config: str, dtype_str: str = "float32", + nemo_model: str, + decoder_type: str, + nemo_model_config: str, + dtype_str: str = "float32", ) -> Tuple[List[ModelConfig], PreTrainedTokenizer]: """Converts the NEMO model object and construct the `ModelConfig` before tensorrt_llm deployment.""" from megatron.core import parallel_state @@ -297,8 +309,8 @@ def nemo_llm_model_to_model_config( LOGGER.info( f'''Resharing: Rank {tensorrt_llm.mpi_rank()} mapping: - tp_rank {parallel_state.get_tensor_model_parallel_rank()} -> {model_config.mapping.tp_rank}, - pp_rank {parallel_state.get_pipeline_model_parallel_rank()} -> {model_config.mapping.pp_rank}, + tp_rank {parallel_state.get_tensor_model_parallel_rank()} 
-> {model_config.mapping.tp_rank}, + pp_rank {parallel_state.get_pipeline_model_parallel_rank()} -> {model_config.mapping.pp_rank}, tp_group {model_config.mapping.tp_group}''' ) @@ -321,3 +333,161 @@ def nemo_llm_model_to_model_config( model_config.lm_head.weight = lm_head_weight return [model_config] + + +def nemo_to_trtllm_config( + in_file: str, + decoder_type: str, + nemo_export_dir: Union[str, Path], + dtype: str = "bfloat16", + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + use_parallel_embedding: bool = False, + save_nemo_model_config: bool = False, +) -> Tuple[List[Dict], List[PretrainedConfig], PreTrainedTokenizer]: + """Converts the NEMO file and construct the `PretrainedConfig` before tensorrt_llm deployment.""" + dtype_str = dtype + + weights_dict, nemo_model_config, tokenizer = _nemo_llm_decode( + in_file=in_file, + out_dir=nemo_export_dir, + tensor_parallelism=tensor_parallel_size, + processes=1, + storage_type=dtype_str, + use_parallel_embedding=use_parallel_embedding, + load_checkpoints_on_gpu=False, + decoder_type=decoder_type, + save_nemo_model_config=save_nemo_model_config, + ) + + world_size = tensor_parallel_size * pipeline_parallel_size + + lm_head_weight = weights_dict["lm_head.weight"] + + vocab_size = weights_dict["transformer.vocab_embedding.weight"].shape[0] + vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size) + + if vocab_size_padded != vocab_size: + pad_width = vocab_size_padded - vocab_size + lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)), "constant", constant_values=0) + + hidden_act = nemo_model_config.get('activation') + hidden_act = ( + hidden_act.split("-")[-1] if nemo_model_config.get('num_moe_experts', 0) else non_gated_version(hidden_act) + ) + + config = { + 'architecture': DECODER_MODEL_TYPE[decoder_type], + 'dtype': dtype_str, + 'num_hidden_layers': nemo_model_config.get('num_layers'), + 'num_attention_heads': nemo_model_config.get('num_attention_heads'), + 'num_key_value_heads': nemo_model_config.get('num_query_groups', nemo_model_config['num_attention_heads']), + 'head_size': nemo_model_config.get('kv_channels'), + 'hidden_size': nemo_model_config.get('hidden_size'), + 'intermediate_size': nemo_model_config.get('ffn_hidden_size'), + 'norm_epsilon': nemo_model_config.get('layernorm_epsilon'), + 'vocab_size': vocab_size_padded, + 'position_embedding_type': ( + "rope_gpt_neox" if nemo_model_config.get('position_embedding_type') == "rope" else "learned_absolute" + ), + 'max_position_embeddings': nemo_model_config.get('max_position_embeddings'), + 'hidden_act': hidden_act, + 'use_parallel_embedding': use_parallel_embedding, + 'embedding_sharding_dim': 0, + 'share_embedding_table': False, + 'quantization': { + 'quant_algo': None, + 'kv_cache_quant_algo': None, + }, + 'bias': nemo_model_config.get('bias'), + 'apply_query_key_layer_scaling': False, + 'rotary_pct': nemo_model_config.get('rotary_percentage', 1.0), + 'rotary_base': nemo_model_config.get('rotary_base', 10000), + 'moe_num_experts': nemo_model_config.get('num_moe_experts', 0), + 'moe_top_k': nemo_model_config.get('moe_router_topk'), + 'moe_normalization_mode': nemo_model_config.get( + 'moe_renorm_mode', MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE + ), + 'moe_tp_mode': nemo_model_config.get('moe_tp_mode', MoeConfig.ParallelismMode.TENSOR_PARALLEL), + 'logits_dtype': 'float32', + 'world_size': world_size, + 'tp_size': tensor_parallel_size, + 'pp_size': pipeline_parallel_size, + } + + model_configs = [] + weights_dicts = [] + 
num_layers = nemo_model_config.get('num_layers') + rotary_scaling = nemo_model_config.get("seq_len_interpolation_factor") + + if decoder_type == "falcon": + config["new_decoder_architecture"] = False if num_layers == 32 else True + config["parallel_attention"] = True + if rotary_scaling is not None: + config["rotary_scaling"] = {"type": "linear", "factor": float(rotary_scaling)} + + pp_key = { + "transformer.vocab_embedding.weight", + "transformer.position_embedding.weight", + "lm_head.weight", + "transformer.ln_f.weight", + "transformer.ln_f.bias", + } + + for i in range(world_size): + mapping = tensorrt_llm.Mapping( + world_size=world_size, rank=i, tp_size=tensor_parallel_size, pp_size=pipeline_parallel_size + ) + layers_range = mapping.pp_layers(num_layers) + + weights_dict_local = {} + for k, v in weights_dict.items(): + if k in pp_key: + continue + new_key = k + if new_key.endswith(".bin"): # TP split + if new_key.endswith(f"{mapping.tp_rank}.bin"): + new_key = new_key.replace(f".{mapping.tp_rank}.bin", "") + if "layers" in new_key: # PP + layer_num = int(new_key.split(".")[2]) + if layer_num in layers_range: + new_key = new_key.replace(f"layers.{layer_num}", f"layers.{layer_num-layers_range[0]}") + if config.get("new_decoder_architecture", False) and "post_layernorm" in new_key: + new_key = new_key.replace("post_layernorm", "mlp_layernorm") + weights_dict_local[new_key] = v + + if mapping.is_first_pp_rank(): + embedding_weight = ( + np.ascontiguousarray( + split(weights_dict["transformer.vocab_embedding.weight"], mapping.tp_size, mapping.tp_rank) + ) + if use_parallel_embedding + else weights_dict["transformer.vocab_embedding.weight"] + ) + + weights_dict_local["transformer.vocab_embedding.weight"] = embedding_weight + + pos_embedding_weight = weights_dict.get("transformer.position_embedding.weight") + if pos_embedding_weight is not None: + if use_parallel_embedding: + pos_embedding_weight = np.ascontiguousarray( + split(pos_embedding_weight, mapping.tp_size, mapping.tp_rank) + ) + weights_dict_local["transformer.position_embedding.weight"] = pos_embedding_weight + + if mapping.is_last_pp_rank(): + weights_dict_local["lm_head.weight"] = np.ascontiguousarray( + split(lm_head_weight, mapping.tp_size, mapping.tp_rank) + ) + weights_dict_local["transformer.ln_f.weight"] = weights_dict["transformer.ln_f.weight"] + + ln_f_bias = weights_dict.get("transformer.ln_f.bias") + if ln_f_bias is not None: + weights_dict_local["transformer.ln_f.bias"] = ln_f_bias + + model_config = PretrainedConfig(**config) + model_config.mapping = mapping + model_configs.append(model_config) + weights_dicts.append(weights_dict_local) + + return weights_dicts, model_configs, tokenizer diff --git a/nemo/export/trt_llm/tensorrt_llm_build.py b/nemo/export/trt_llm/tensorrt_llm_build.py index 3ad27a2eb9a68..ac8d9094ea32f 100644 --- a/nemo/export/trt_llm/tensorrt_llm_build.py +++ b/nemo/export/trt_llm/tensorrt_llm_build.py @@ -25,10 +25,13 @@ import torch from tensorrt_llm import str_dtype_to_trt from tensorrt_llm._utils import np_dtype_to_trt -from tensorrt_llm.builder import Builder +from tensorrt_llm.builder import BuildConfig, Builder +from tensorrt_llm.commands.build import build as build_trtllm from tensorrt_llm.logger import logger -from tensorrt_llm.models.modeling_utils import add_lora +from tensorrt_llm.lora_manager import LoraBuildConfig +from tensorrt_llm.models.modeling_utils import add_lora, optimize_model, preprocess_weights from tensorrt_llm.network import net_guard +from tensorrt_llm.plugin import 
PluginConfig from tensorrt_llm.plugin.plugin import ContextFMHAType from tensorrt_llm.quantization import QuantMode @@ -57,11 +60,11 @@ def serialize_engine(engine, path): def refit_runtime_engine(params, cuda_engine): ''' - @brief: Inplace refit one TensorRT cuda engine using weights from the network, - user should guarantee that the engine is built with REFIT flag, and the network has the same structure with the engine. - @param engine_buffer: A serialized TensorRT engine. - @param network: Network object. - @return: A serialized TRT engine if refit successfully, None otherwise + @brief: Inplace refit one TensorRT cuda engine using weights from the network, + user should guarantee that the engine is built with REFIT flag, and the network has the same structure with the engine. + @param engine_buffer: A serialized TensorRT engine. + @param network: Network object. + @return: A serialized TRT engine if refit successfully, None otherwise ''' logger.info(f'Refit runtime engine') tik = time.time() @@ -88,7 +91,11 @@ def refit_runtime_engine(params, cuda_engine): def build_rank_engine( - tensorrt_llm_gpt, builder: Builder, builder_config: tensorrt_llm.builder.BuilderConfig, engine_name, args, + tensorrt_llm_gpt, + builder: Builder, + builder_config: tensorrt_llm.builder.BuilderConfig, + engine_name, + args, ): str_dtype_to_trt(args.dtype) @@ -348,3 +355,70 @@ def build( tok = time.time() t = time.strftime("%H:%M:%S", time.gmtime(tok - tik)) logger.info(f"Total time of building all {args.mapping.world_size} engines: {t}") + + +def build_and_save_engine( + max_input_len=1024, + max_output_len=1024, + max_batch_size=4, + model_dir=None, + model_weights=None, + model_config=None, + model_type='gpt', + lora_ckpt_list=None, + use_lora_plugin=None, + max_lora_rank=64, + lora_target_modules=None, + max_prompt_embedding_table_size=0, + enable_multi_block_mode: bool = False, +): + try: + model_cls = getattr(tensorrt_llm.models, model_config.architecture) + except: + raise AttributeError(f"Could not find TRTLLM model type: {model_type}!") + + logger.set_level("info") + str_dtype = model_config.dtype + plugin_config = PluginConfig() + plugin_config.set_gpt_attention_plugin(dtype=str_dtype) + plugin_config.set_gemm_plugin(dtype=str_dtype) + plugin_config.set_plugin("multi_block_mode", enable_multi_block_mode) + max_num_tokens = max_batch_size * max_input_len + + build_dict = { + 'max_input_len': max_input_len, + 'max_output_len': max_output_len, + 'max_batch_size': max_batch_size, + 'max_beam_width': 1, + 'max_num_tokens': max_num_tokens, + 'opt_num_tokens': None, + 'max_prompt_embedding_table_size': max_prompt_embedding_table_size, + 'gather_context_logits': False, + 'gather_generation_logits': False, + 'strongly_typed': False, + 'builder_opt': None, + } + build_config = BuildConfig.from_dict(build_dict, plugin_config=plugin_config) + + if use_lora_plugin is not None: + build_config.plugin_config.set_lora_plugin(use_lora_plugin) + lora_config = LoraBuildConfig( + lora_dir=lora_ckpt_list, + lora_ckpt_source='nemo', + max_lora_rank=max_lora_rank, + lora_target_modules=lora_target_modules, + ) + build_config.lora_config = lora_config + + model = model_cls.from_config(model_config) + model = optimize_model( + model, + use_parallel_embedding=model_config.use_parallel_embedding, + share_embedding_table=model_config.share_embedding_table, + ) + preprocess_weights(model_weights, model_config) + model.load(model_weights) + engine = build_trtllm(model, build_config) + engine.save(model_dir) + + return engine diff 
--git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py index c490f37e1fc44..92fc36272f7c6 100644 --- a/nemo/export/trt_llm/tensorrt_llm_run.py +++ b/nemo/export/trt_llm/tensorrt_llm_run.py @@ -26,7 +26,7 @@ from tensorrt_llm.logger import logger from tensorrt_llm.lora_manager import LoraManager from tensorrt_llm.quantization import QuantMode -from tensorrt_llm.runtime import ModelConfig, SamplingConfig +from tensorrt_llm.runtime import ModelConfig, ModelRunnerCpp, SamplingConfig from transformers import PreTrainedTokenizer from nemo.export.trt_llm.tensor_utils import get_tensor_parallel_group @@ -55,7 +55,7 @@ class TensorrtLLMHostContext: class TensorrtLLMWorkerContext: """The MPI worker side context for TRT LLM inference.""" - decoder: tensorrt_llm.runtime.GenerationSession = None + decoder: ModelRunnerCpp = None sampling_config: SamplingConfig = None lora_manager: LoraManager = None max_batch_size: int = 0 @@ -135,42 +135,38 @@ def _load(tokenizer: PreTrainedTokenizer, engine_dir, lora_ckpt_list=None, num_b engine_dir = Path(engine_dir) config_path = engine_dir / "config.json" - model_config, world_size, tp_size, pp_size, dtype, max_input_len, max_batch_size = _read_config(config_path) + # model_config, world_size, tp_size, pp_size, dtype, max_input_len, max_batch_size = _read_config(config_path) - runtime_rank = tensorrt_llm.mpi_rank() + with open(config_path, "r") as f: + config = json.load(f) - assert runtime_rank < torch.cuda.device_count(), f"Rank {runtime_rank} out of bound" - runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank, tp_size=tp_size, pp_size=pp_size) + max_batch_size = config["build_config"]["max_batch_size"] + max_input_len = config["build_config"]["max_input_len"] + max_output_len = config["build_config"]["max_output_len"] + max_beam_width = config["build_config"]["max_beam_width"] - torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node) - engine_name = get_engine_name(MODEL_NAME, dtype, tp_size, pp_size, runtime_rank) - serialize_path = os.path.join(engine_dir, engine_name) - logger.info(f"Reading from serialize path {serialize_path}") + runtime_rank = tensorrt_llm.mpi_rank() - with open(serialize_path, "rb") as f: - engine_buffer = f.read() - decoder = tensorrt_llm.runtime.GenerationSession( - model_config, engine_buffer, runtime_mapping, debug_mode=False + decoder = ModelRunnerCpp.from_dir( + engine_dir=engine_dir, + lora_dir=lora_ckpt_list, + lora_ckpt_source="nemo", + rank=runtime_rank, + max_batch_size=max_batch_size, + max_input_len=max_input_len, + max_output_len=max_output_len, + max_beam_width=max_beam_width, + debug_mode=False, ) sampling_config = SamplingConfig( end_id=tokenizer.eos_token_id, pad_id=tokenizer.eos_token_id, num_beams=num_beams ) - if decoder.use_lora_plugin: - lora_manager = LoraManager() - if lora_ckpt_list is not None: - lora_manager.load_from_nemo( - model_files=lora_ckpt_list, model_config=model_config, runtime_mapping=runtime_mapping, - ) - else: - lora_manager = None - # Initialize the global context so it can be used during `run` API. 
global tensorrt_llm_worker_context tensorrt_llm_worker_context.decoder = decoder tensorrt_llm_worker_context.sampling_config = sampling_config - tensorrt_llm_worker_context.lora_manager = lora_manager tensorrt_llm_worker_context.max_batch_size = max_batch_size tensorrt_llm_worker_context.max_input_len = max_input_len @@ -207,7 +203,6 @@ def _forward( decoder = tensorrt_llm_worker_context.decoder assert decoder is not None, "Invalid worker context, decoder is not loaded." sampling_config = tensorrt_llm_worker_context.sampling_config - lora_manager = tensorrt_llm_worker_context.lora_manager max_batch_size = tensorrt_llm_worker_context.max_batch_size max_input_len = tensorrt_llm_worker_context.max_input_len @@ -217,60 +212,36 @@ def _forward( max_length = max(input_lengths) assert max_length <= max_input_len, f"input length {max_length} exceedng max input length {max_input_len}" pad_id = sampling_config.pad_id - - if decoder.remove_input_padding: - line_encoded = torch.concat(input_tensors).cuda() - else: - line_encoded = torch.nested.to_padded_tensor( - torch.nested.nested_tensor(input_tensors, dtype=torch.int32), pad_id - ).cuda() - - input_lengths = torch.tensor(input_lengths, dtype=torch.int32).cuda() - - if prompt_table is None: - ptuning_args = [] - else: - if task_vocab_size is None: - raise Exception("task_vocab_size cannot be None") - - task_vocab_size = torch.tensor([task_vocab_size], dtype=torch.int32, device="cuda") - task_ids = torch.tensor(task_ids, dtype=torch.int32, device="cuda") - prompt_table = prompt_table.cuda() - ptuning_args = [prompt_table, task_ids, task_vocab_size] + end_id = sampling_config.end_id + num_beams = sampling_config.num_beams with torch.no_grad(): - sampling_config.top_k = top_k - sampling_config.top_p = top_p - sampling_config.temperature = temperature - for key, param in sampling_kwargs.items(): - # set any additional SamplingConfig kwargs - setattr(sampling_config, key, param) - - decoder.setup( - batch_size, - max_context_length=max_length, - max_new_tokens=max_output_len, - lora_manager=lora_manager, - lora_uids=lora_uids, - ) + prompt_tasks = None if task_ids is None else ",".join(str(task) for task in task_ids) - outputs = decoder.decode( - line_encoded, - input_lengths, - sampling_config, - *ptuning_args, + outputs = decoder.generate( + input_tensors, + max_new_tokens=max_output_len, + end_id=end_id, + pad_id=pad_id, + temperature=temperature, + top_k=top_k, + top_p=top_p, + num_beams=num_beams, stop_words_list=stop_words_list, bad_words_list=bad_words_list, - no_repeat_ngram_size=no_repeat_ngram_size, + lora_uids=lora_uids, + prompt_table=prompt_table, + prompt_tasks=prompt_tasks, streaming=streaming, output_sequence_lengths=True, return_dict=True, ) + torch.cuda.synchronize() runtime_rank = tensorrt_llm.mpi_rank() if runtime_rank == 0 or multiprocessed_env: - return outputs, decoder.log_probs + return outputs else: return None @@ -290,10 +261,14 @@ def load( config_path = os.path.join(engine_dir, "config.json") with open(config_path, "r") as f: config = json.load(f) - world_size = config["builder_config"]["world_size"] + world_size = config["pretrained_config"]["mapping"]["world_size"] if world_size == 1: _load(tokenizer, engine_dir, lora_ckpt_list, num_beams) executor = None + elif tensorrt_llm.mpi_world_size() > 1: + _load(tokenizer, engine_dir, lora_ckpt_list, num_beams) + executor = None + tensorrt_llm.mpi_barrier() else: executor = MPIPoolExecutor(max_workers=world_size) futures = [] @@ -303,9 +278,9 @@ def load( for future in futures: 
future.result() - max_batch_size = config["builder_config"]["max_batch_size"] - max_input_len = config["builder_config"]["max_input_len"] - add_bos = config["builder_config"]["add_bos"] + max_batch_size = config["build_config"]["max_batch_size"] + max_input_len = config["build_config"]["max_input_len"] + add_bos = True if config["pretrained_config"]["architecture"] == "GemmaForCausalLM" else False return TensorrtLLMHostContext( executor=executor, @@ -355,7 +330,10 @@ def load_refit( # Manipulate the tensorrt_llm mapping to make it compatible with the multiprocessed env. assert tensorrt_llm.mpi_world_size() == torch.distributed.get_world_size(), "MPI world size mismatch" runtime_mapping = tensorrt_llm.Mapping( - world_size=tensorrt_llm.mpi_world_size(), rank=runtime_rank, tp_size=tensorrt_llm.mpi_world_size(), pp_size=1, + world_size=tensorrt_llm.mpi_world_size(), + rank=runtime_rank, + tp_size=tensorrt_llm.mpi_world_size(), + pp_size=1, ) engine_name = get_engine_name( @@ -386,7 +364,9 @@ def load_refit( lora_manager = LoraManager() if lora_ckpt_list is not None: lora_manager.load_from_nemo( - model_files=lora_ckpt_list, model_config=model_config, runtime_mapping=runtime_mapping, + model_files=lora_ckpt_list, + model_config=model_config, + runtime_mapping=runtime_mapping, ) else: lora_manager = None @@ -576,7 +556,7 @@ def generate( if no_repeat_ngram_size is not None: no_repeat_ngram_size = torch.IntTensor(no_repeat_ngram_size).to(torch.cuda.current_device()) - outputs, log_probs = forward( + outputs = forward( input_tensors=input_tensors, max_output_len=max_output_len, host_context=host_context, @@ -596,6 +576,8 @@ def generate( **sampling_kwargs, ) assert outputs is not None + if tensorrt_llm.mpi_rank() != 0: + return None output_ids = outputs['output_ids'] sequence_lengths = outputs['sequence_lengths'] @@ -656,7 +638,7 @@ def generate_streaming( if no_repeat_ngram_size is not None: no_repeat_ngram_size = torch.IntTensor(no_repeat_ngram_size).to(torch.cuda.current_device()) - outputs, log_probs = forward( + outputs = forward( input_tensors=input_tensors, max_output_len=max_output_len, host_context=host_context, diff --git a/scripts/export/export_to_trt_llm.py b/scripts/export/export_to_trt_llm.py index 9798473dd880e..5e5833444f658 100644 --- a/scripts/export/export_to_trt_llm.py +++ b/scripts/export/export_to_trt_llm.py @@ -78,7 +78,6 @@ def get_args(argv): '--use_lora_plugin', nargs='?', const=None, - default=False, choices=['float16', 'float32', 'bfloat16'], help="Activates the lora plugin which enables embedding sharing.", ) @@ -86,7 +85,16 @@ def get_args(argv): '--lora_target_modules', nargs='+', default=None, - choices=["attn_qkv", "attn_q", "attn_k", "attn_v", "attn_dense", "mlp_h_to_4h", "mlp_gate", "mlp_4h_to_h",], + choices=[ + "attn_qkv", + "attn_q", + "attn_k", + "attn_v", + "attn_dense", + "mlp_h_to_4h", + "mlp_gate", + "mlp_4h_to_h", + ], help="Add lora in which modules. 
Only be activated when use_lora_plugin is enabled.", ) parser.add_argument( From b1628cf231eff0ca96a94b5f840b0dcbb7f2d667 Mon Sep 17 00:00:00 2001 From: Ali Taghibakhshi <71892896+JRD971000@users.noreply.github.com> Date: Mon, 13 May 2024 22:35:05 -0500 Subject: [PATCH 05/18] Alit/optim 8k (#9166) * fix fuser issue with dynamo * optimized 4k seq len * optim 8k * add checkpointing * add ckpt arg * fix minor bug * minor fix * more optimized chkpting * Apply isort and black reformatting Signed-off-by: JRD971000 * addressing comments * Apply isort and black reformatting Signed-off-by: JRD971000 --------- Signed-off-by: JRD971000 Co-authored-by: Ali Taghibakhshi Co-authored-by: JRD971000 --- .../conf/megatron_griffin_config.yaml | 1 + .../megatron_griffin_finetuning_config.yaml | 1 + .../megatron_griffin_generate_config.yaml | 2 +- .../megatron/griffin/griffin_block.py | 165 +++++++++++++-- .../megatron/griffin/griffin_model.py | 26 ++- .../megatron/griffin/recurrent_layer.py | 20 +- .../megatron/griffin/recurrent_module.py | 194 ++++++++++++------ .../megatron_griffin_model.py | 20 +- 8 files changed, 324 insertions(+), 105 deletions(-) diff --git a/examples/nlp/language_modeling/conf/megatron_griffin_config.yaml b/examples/nlp/language_modeling/conf/megatron_griffin_config.yaml index c080ff846ba12..1d36204931623 100644 --- a/examples/nlp/language_modeling/conf/megatron_griffin_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_griffin_config.yaml @@ -108,6 +108,7 @@ model: # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. # 'full' will checkpoint the entire transformer layer. activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_recurrent: False # If set to True, the checkpointing is only done for rglru and conv1d and not for attention and mlp layers activations_checkpoint_method: null # 'uniform', 'block' # 'uniform' divides the total number of transformer layers and checkpoints the input activation # of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model. diff --git a/examples/nlp/language_modeling/conf/megatron_griffin_finetuning_config.yaml b/examples/nlp/language_modeling/conf/megatron_griffin_finetuning_config.yaml index e144c784fb0c6..f92f971eb0594 100644 --- a/examples/nlp/language_modeling/conf/megatron_griffin_finetuning_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_griffin_finetuning_config.yaml @@ -117,6 +117,7 @@ model: # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. # 'full' will checkpoint the entire transformer layer. activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_recurrent: False # If set to True, the checkpointing is only done for rglru and conv1d and not for attention and mlp layers activations_checkpoint_method: null # 'uniform', 'block' activations_checkpoint_method: null # 'uniform', 'block' # 'uniform' divides the total number of transformer layers and checkpoints the input activation # of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model. 
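The activations_checkpoint_recurrent flag added to the Griffin configs above selects checkpointing of only the recurrent ops (rglru and conv1d); when it stays False and full recomputation is requested, the _checkpointed_forward path added later in this patch re-runs whole layer chunks during backward. A simplified, hypothetical sketch of that uniform chunked-recompute pattern, using plain torch.utils.checkpoint rather than the Megatron-Core tensor_parallel.checkpoint helper and omitting attention masks and rotary embeddings:

import torch
from torch.utils.checkpoint import checkpoint

def checkpointed_forward(layers, hidden_states, recompute_num_layers=2):
    """Run `layers` in fixed-size chunks, storing only each chunk's input activation."""

    def make_chunk(start, end):
        def run(x):
            for layer in layers[start:end]:
                x = layer(x)
            return x
        return run

    for start in range(0, len(layers), recompute_num_layers):
        end = min(start + recompute_num_layers, len(layers))
        # Intermediate activations inside the chunk are recomputed in the backward pass.
        hidden_states = checkpoint(make_chunk(start, end), hidden_states, use_reentrant=False)
    return hidden_states

# Toy usage (illustrative shapes only):
layers = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(4)])
out = checkpointed_forward(layers, torch.randn(2, 16, requires_grad=True))
out.sum().backward()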
diff --git a/examples/nlp/language_modeling/conf/megatron_griffin_generate_config.yaml b/examples/nlp/language_modeling/conf/megatron_griffin_generate_config.yaml index b09cce5671c91..e22b615d48aa9 100644 --- a/examples/nlp/language_modeling/conf/megatron_griffin_generate_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_griffin_generate_config.yaml @@ -121,7 +121,7 @@ model: # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. # 'full' will checkpoint the entire transformer layer. activations_checkpoint_granularity: null # 'selective' or 'full' - activations_checkpoint_method: null # 'uniform', 'block' + activations_checkpoint_recurrent: False # If set to True, the checkpointing is only done for rglru and conv1d and not for attention and mlp layers activations_checkpoint_method: null # 'uniform', 'block' # 'uniform' divides the total number of transformer layers and checkpoints the input activation # of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model. # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity diff --git a/nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_block.py b/nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_block.py index 3fc26a51f3c1a..d8954ad1b3c3e 100755 --- a/nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_block.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_block.py @@ -11,17 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -from megatron.core.models.common.language_module.language_module import LanguageModule -from megatron.core.transformer.custom_layers.transformer_engine import TENorm -from megatron.core.transformer.spec_utils import build_module -from megatron.core.transformer.transformer_config import TransformerConfig -from torch import nn - +from torch import Tensor, nn from nemo.collections.nlp.models.language_modeling.megatron.griffin.griffin_layer_spec import ( griffin_mqa_layer_with_transformer_engine_spec, griffin_recurrent_layer_with_transformer_engine_spec, ) +from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults + +try: + from megatron.core import parallel_state, tensor_parallel + from megatron.core.models.common.language_module.language_module import LanguageModule + from megatron.core.packed_seq_params import PackedSeqParams + from megatron.core.transformer.custom_layers.transformer_engine import TENorm, te_checkpoint + from megatron.core.transformer.spec_utils import build_module + from megatron.core.transformer.transformer_config import TransformerConfig + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + TransformerConfig = ApexGuardDefaults + HAVE_MEGATRON_CORE = False def get_griffin_layers(num_layers): @@ -41,16 +50,22 @@ def get_griffin_layers(num_layers): def create_block( - config, layer_spec, layer_idx, + config, + layer_spec, + layer_idx, ): - block = build_module(layer_spec, config,) + block = build_module( + layer_spec, + config, + ) block.layer_number = layer_idx + 1 return block class GriffinStack(LanguageModule): def __init__( - self, config: TransformerConfig, + self, + config: TransformerConfig, ): super().__init__(config) @@ -58,17 +73,139 @@ def __init__( self.griffin_layers = get_griffin_layers(self.config.num_layers) self.layers = nn.ModuleList( - [create_block(self.config, layer_spec, layer_idx=i,) for i, layer_spec in enumerate(self.griffin_layers)] + [ + create_block( + self.config, + layer_spec, + layer_idx=i, + ) + for i, layer_spec in enumerate(self.griffin_layers) + ] ) self.final_layernorm = TENorm( - config=self.config, hidden_size=self.config.hidden_size, eps=self.config.layernorm_epsilon, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, ) + self.num_layers = len(self.layers) + + def _get_layer(self, layer_number: int): + return self.layers[layer_number] + + def _checkpointed_forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + context: Tensor = None, + context_mask: Tensor = None, + rotary_pos_emb: Tensor = None, + packed_seq_params: PackedSeqParams = None, + ): + """Forward method with activation checkpointing.""" + + def custom(start: int, end: int): + def custom_forward( + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + packed_seq_params, + ): + for index in range(start, end): + layer = self._get_layer(index) + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + inference_params=None, + packed_seq_params=packed_seq_params, + ) + return hidden_states, context + + return custom_forward + + def checkpoint_handler(forward_func): + if self.config.fp8: + return te_checkpoint( + forward_func, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), + hidden_states, + attention_mask, + context, + 
context_mask, + rotary_pos_emb, + packed_seq_params, + ) + else: + return tensor_parallel.checkpoint( + forward_func, + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + packed_seq_params, + ) + + if self.config.recompute_method == 'uniform': + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + l = 0 + while l < self.num_layers: + hidden_states, context = checkpoint_handler(custom(l, l + self.config.recompute_num_layers)) + + l += self.config.recompute_num_layers + + elif self.config.recompute_method == 'block': + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + recompute_skip_num_layers = 0 + for l in range(self.num_layers): + # Skip recomputation when input grad computation is not needed. + # Need to have at least one input tensor with gradient computation + # for re-enterant autograd engine. + if self.config.fp8 and not hidden_states.requires_grad: + recompute_skip_num_layers += 1 + if l >= recompute_skip_num_layers and l < self.config.recompute_num_layers + recompute_skip_num_layers: + hidden_states, context = checkpoint_handler(custom(l, l + 1)) + else: + hidden_states, context = custom(l, l + 1)( + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + packed_seq_params, + ) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states def forward(self, hidden_states, attention_mask, rotary_pos_emb): - for layer in self.layers: + if ( + self.config.recompute_granularity == 'full' + and self.training + and not self.config.activations_checkpoint_recurrent + ): + hidden_states = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + ) + else: + for layer in self.layers: - hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb) + hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb) hidden_states = self.final_layernorm(hidden_states) diff --git a/nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_model.py b/nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_model.py index 4531b64d1d969..7a327a3a35cbf 100755 --- a/nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/griffin/griffin_model.py @@ -13,15 +13,23 @@ # limitations under the License. 
import math - import torch -from megatron.core import tensor_parallel -from megatron.core.jit import jit_fuser -from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding -from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding -from megatron.core.models.common.language_module.language_module import LanguageModule -from megatron.core.transformer.transformer_config import TransformerConfig -from torch import Tensor, nn +from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults + +try: + from megatron.core import tensor_parallel + from megatron.core.jit import jit_fuser + from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding + from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding + from megatron.core.models.common.language_module.language_module import LanguageModule + from megatron.core.transformer.transformer_config import TransformerConfig + from torch import Tensor, nn + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + TransformerConfig = ApexGuardDefaults + HAVE_MEGATRON_CORE = False from nemo.collections.nlp.models.language_modeling.megatron.griffin.griffin_block import GriffinStack @@ -142,7 +150,7 @@ def forward( position_ids: Tensor = None, attention_mask: Tensor = None, labels: Tensor = None, - **extra_arg + **extra_arg, ): if input_ids is None: input_ids = self.input_tensor diff --git a/nemo/collections/nlp/models/language_modeling/megatron/griffin/recurrent_layer.py b/nemo/collections/nlp/models/language_modeling/megatron/griffin/recurrent_layer.py index 8263f54889a0f..3a33f8966fd20 100755 --- a/nemo/collections/nlp/models/language_modeling/megatron/griffin/recurrent_layer.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/griffin/recurrent_layer.py @@ -14,13 +14,21 @@ from dataclasses import dataclass from typing import Union - -from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.spec_utils import ModuleSpec, build_module -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import make_viewless_tensor from torch import Tensor +from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults + +try: + from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp + from megatron.core.transformer.module import MegatronModule + from megatron.core.transformer.spec_utils import ModuleSpec, build_module + from megatron.core.transformer.transformer_config import TransformerConfig + from megatron.core.utils import make_viewless_tensor + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + TransformerConfig = ApexGuardDefaults + HAVE_MEGATRON_CORE = False @dataclass diff --git a/nemo/collections/nlp/models/language_modeling/megatron/griffin/recurrent_module.py b/nemo/collections/nlp/models/language_modeling/megatron/griffin/recurrent_module.py index d91c077189177..033d3abec732d 100755 --- a/nemo/collections/nlp/models/language_modeling/megatron/griffin/recurrent_module.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/griffin/recurrent_module.py @@ -17,33 +17,50 @@ from typing import Union import torch +import torch._dynamo from accelerated_scan.triton import scan from causal_conv1d import causal_conv1d_fn from einops import rearrange -from 
megatron.core.fusions.fused_bias_gelu import bias_gelu_impl -from megatron.core.jit import jit_fuser -from megatron.core.transformer.identity_op import IdentityOp -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.spec_utils import ModuleSpec, build_module -from megatron.core.transformer.transformer_config import TransformerConfig from torch import nn +from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults + +try: + from megatron.core import tensor_parallel + from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl + from megatron.core.jit import jit_fuser + from megatron.core.transformer.identity_op import IdentityOp + from megatron.core.transformer.module import MegatronModule + from megatron.core.transformer.spec_utils import ModuleSpec, build_module + from megatron.core.transformer.transformer_config import TransformerConfig + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + TransformerConfig = ApexGuardDefaults + HAVE_MEGATRON_CORE = False + +torch._dynamo.config.suppress_errors = True + # Class copied from https://github.com/google-deepmind/recurrentgemma class BlockDiagonalLinear(nn.Module): """Block-diagonal linear layer.""" def __init__( - self, width: int, num_blocks: int, w_init_variance_scale: float = 1.0, + self, + width: int, + num_blocks: int, + w_init_variance_scale: float = 1.0, ): """Initializes the BlockDiagonalLinear. - Args: - width: The number of dimensions of the input and output. - num_blocks: The number of diagonal blocks in the layer. - w_init_variance_scale: A parameters that scales the variance of the - initialization of the weights. - """ + Args: + width: The number of dimensions of the input and output. + num_blocks: The number of diagonal blocks in the layer. + w_init_variance_scale: A parameters that scales the variance of the + initialization of the weights. + """ super().__init__() self.width = width self.num_blocks = num_blocks @@ -62,25 +79,46 @@ def w_init_(self, w: torch.Tensor) -> None: std = math.sqrt(self.w_init_variance_scale / self.block_width) torch.nn.init.normal_(w, mean=0.0, std=std) - def forward(self, x): - """Calls the BlockDiagonalLinear.""" - # Split x to blocks. - bs, seq_l = x.shape[0], x.shape[1] + @jit_fuser + def _fused_pre_reshape_(self, x, bs, seq_l): x = ( x.reshape(bs, seq_l, self.num_blocks, self.block_width) .permute(2, 0, 1, 3) .reshape(self.num_blocks, bs * seq_l, self.block_width) ) - x = (torch.bmm(x, self.w).permute(1, 0, 2) + self.b).reshape(bs, seq_l, self.num_blocks * self.block_width) - out = torch.sigmoid(x) - return out + return x + + @jit_fuser + def _post_add_reshape_sigmoid_(self, x, bs, seq_l): + x = (x.permute(1, 0, 2) + self.b).reshape(bs, seq_l, self.num_blocks * self.block_width) + x = torch.sigmoid(x) + return x + + def forward(self, x): + """Calls the BlockDiagonalLinear.""" + # Split x to blocks. 
+ bs, seq_l = x.shape[0], x.shape[1] + x = self._fused_pre_reshape_(x, bs, seq_l) + + x = torch.bmm(x, self.w) + x = self._post_add_reshape_sigmoid_(x, bs, seq_l) + + return x # Class copied from https://github.com/google-deepmind/recurrentgemma @jit_fuser -def _scan_preprocess_(a, x, reset): +def _scan_preprocess_(x, gate_a, gate_x, reset, a_params): + + log_a = -8.0 * gate_a * nn.functional.softplus(a_params) + a = torch.exp(log_a) + gated_x = x * gate_x + multiplier = torch.sqrt((1 - torch.exp(2 * log_a)) + 1e-6) + multiplier = reset + (1 - reset) * multiplier + x = gated_x * multiplier.type(x.dtype) + assert x.ndim == 3 assert a.shape == x.shape[-a.ndim :] assert a.dtype == x.dtype @@ -94,38 +132,54 @@ def _scan_preprocess_(a, x, reset): a = a.permute(0, 2, 1) x = x.contiguous() a = a.contiguous() + return a, x def rnn_scan( - x, a, reset, + x, + gate_a, + gate_x, + reset, + a_params, + # x, a, reset, ): """Runs the recurrence of a linear RNN. - Args: - x: The input sequence. - a: The diagonal of the recurrence matrix `A`. - reset: Indicator of document boundaries, e.g. when to reset the hidden - state of the RNN. - h0: The initial hidden state. - - Returns: - The output of the linear recurrence. - """ - a, x = _scan_preprocess_(a, x, reset) + Args: + x: The input sequence. + a: The diagonal of the recurrence matrix `A`. + reset: Indicator of document boundaries, e.g. when to reset the hidden + state of the RNN. + h0: The initial hidden state. + + Returns: + The output of the linear recurrence. + """ + + a, x = _scan_preprocess_(x, gate_a, gate_x, reset, a_params) + y = scan(a.float(), x.float()).type_as(x) + y = y.permute(0, 2, 1) + return y, None # Class copied from https://github.com/google-deepmind/recurrentgemma -def rnn_param_init(*, width: int, min_rad: float, max_rad: float, transform: str = "softplus",) -> torch.Tensor: +def rnn_param_init( + *, + width: int, + min_rad: float, + max_rad: float, + transform: str = "softplus", +) -> torch.Tensor: """Initializes the `A` parameter of the RG-LRU uniformly on a ring.""" unif = torch.rand(width) # Proportional to area in a ring. - a_real = 0.5 * torch.log(unif * (max_rad ** 2 - min_rad ** 2) + min_rad ** 2 + 1e-8) + a_real = 0.5 * torch.log(unif * (max_rad**2 - min_rad**2) + min_rad**2 + 1e-8) if transform == "softplus": # Inverse transform. @@ -141,17 +195,20 @@ class RGLRU(nn.Module): """A Real-Gated Linear Recurrent Unit (RG-LRU) layer.""" def __init__( - self, width: int, num_heads: int, w_init_variance_scale: float = 1.0, + self, + width: int, + num_heads: int, + w_init_variance_scale: float = 1.0, ): """Initializes the RG-LRU. - Args: - width: The number of dimensions of the input and output. - num_heads: The number of diagonal blocks in the input and A gate layers. - w_init_variance_scale: Initialization parameter for the - BlockDiagonalLinear layers of the gates. See the `BlockDiagonalLinear` - layer for details. - """ + Args: + width: The number of dimensions of the input and output. + num_heads: The number of diagonal blocks in the input and A gate layers. + w_init_variance_scale: Initialization parameter for the + BlockDiagonalLinear layers of the gates. See the `BlockDiagonalLinear` + layer for details. + """ super().__init__() self.width = width self.num_heads = num_heads @@ -160,7 +217,9 @@ def __init__( # Parameters and layers. 
self.a_param = nn.Parameter(self.a_param_init) self.input_gate = BlockDiagonalLinear( - width=self.width, num_blocks=self.num_heads, w_init_variance_scale=w_init_variance_scale, + width=self.width, + num_blocks=self.num_heads, + w_init_variance_scale=w_init_variance_scale, ) self.a_gate = BlockDiagonalLinear( width=self.width, num_blocks=self.num_heads, w_init_variance_scale=self.w_init_variance_scale @@ -184,18 +243,22 @@ def _fused_pst_gates_(self, x, gate_a, gate_x, reset): return normalized_x, a def __call__( - self, x, segment_pos, prev_h, + self, + x, + segment_pos, + prev_h, ): """Calls the RG-LRU. - Args: - x: Sequence of input activations. - segment_pos: Position of each token in the sequence. - prev_h: The previous hidden state of the RG-LRU. + Args: + x: Sequence of input activations. + segment_pos: Position of each token in the sequence. + prev_h: The previous hidden state of the RG-LRU. + + Returns: + Output of the block together with the updated hidden state. + """ - Returns: - Output of the block together with the updated hidden state. - """ for param in self.parameters(): param.data_ptr() @@ -207,9 +270,7 @@ def __call__( gate_x = self.input_gate(x) gate_a = self.a_gate(x) - # Compute the parameter `A` of the recurrence. - normalized_x, a = self._fused_pst_gates_(x, gate_a, gate_x, reset) - y, last_h = rnn_scan(x=normalized_x, a=a, reset=reset) + y, last_h = rnn_scan(x, gate_a, gate_x, reset, self.a_param) return y, last_h @@ -230,11 +291,17 @@ def __init__(self, config, width, temporal_width): ) def forward( - self, x, segment_pos=None, prev_x=None, + self, + x, + segment_pos=None, + prev_x=None, ): x = x.permute(0, 2, 1) output = causal_conv1d_fn( - x=x, weight=rearrange(self.conv_1d.weight, "d 1 w -> d w"), bias=self.conv_1d.bias, activation=None, + x=x, + weight=rearrange(self.conv_1d.weight, "d 1 w -> d w"), + bias=self.conv_1d.bias, + activation=None, ).permute(0, 2, 1) return output, None @@ -314,6 +381,11 @@ def __init__( submodules.rg_lru, width=self.config.hidden_size, num_heads=self.config.num_attention_heads ) + def checkpoint_handler(self, forward_func, x, segment_pos, prev_x): + return tensor_parallel.checkpoint( + forward_func, self.config.distribute_saved_activations, x, segment_pos, prev_x + ) + def forward(self, hidden_states, attention_mask=None, rotary_pos_emb=None): segment_pos = torch.arange(hidden_states.shape[0]).unsqueeze(0).repeat(hidden_states.shape[1], 1).cuda() @@ -326,9 +398,13 @@ def forward(self, hidden_states, attention_mask=None, rotary_pos_emb=None): x = _fused_permute_add_(x_intermidiate_parallel, x_bias_parallel) - x, _ = self.conv_1d(x=x, segment_pos=segment_pos, prev_x=None) + if self.config.activations_checkpoint_recurrent and self.training: + x, _ = self.checkpoint_handler(self.conv_1d, x=x, segment_pos=segment_pos, prev_x=None) + x, _ = self.checkpoint_handler(self.rg_lru, x=x, segment_pos=segment_pos, prev_x=None) - x, _ = self.rg_lru(x=x, segment_pos=segment_pos, prev_h=None,) + else: + x, _ = self.conv_1d(x=x, segment_pos=segment_pos, prev_x=None) + x, _ = self.rg_lru(x=x, segment_pos=segment_pos, prev_h=None) x = _fused_permute_mult_(x, y) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_griffin_model.py b/nemo/collections/nlp/models/language_modeling/megatron_griffin_model.py index 20ad376b8f98c..1e5a2f0c15c04 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_griffin_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_griffin_model.py @@ -18,15 +18,6 @@ from 
nemo.collections.nlp.models.language_modeling.megatron.griffin.griffin_model import GriffinModel from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel -from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults - -try: - - HAVE_MEGATRON_CORE = True - -except (ImportError, ModuleNotFoundError): - TransformerConfig = ApexGuardDefaults - HAVE_MEGATRON_CORE = False class MegatronGriffinModel(MegatronGPTModel): @@ -35,13 +26,6 @@ class MegatronGriffinModel(MegatronGPTModel): """ def __init__(self, cfg: DictConfig, trainer: Trainer): - if not HAVE_MEGATRON_CORE: - raise ImportError( - "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." - ) - - # build the transformer config - # TODO: add type hint once pip package is out self.vocab_size = cfg.get('vocab_size', 256000) self.cfg = cfg @@ -70,8 +54,12 @@ def forward(self, input_ids, position_ids=None, attention_mask=None, labels=None def build_transformer_config(self): transformer_config = super().build_transformer_config() + transformer_config.activations_checkpoint_recurrent = self.cfg.get('activations_checkpoint_recurrent', False) transformer_config.gated_linear_unit = self.cfg.get('gated_linear_unit', True) transformer_config.layernorm_zero_centered_gamma = self.cfg.get('layernorm_zero_centered_gamma', True) + assert ( + not transformer_config.activations_checkpoint_recurrent or not transformer_config.recompute_granularity + ), "Either the recurrent checkpoiting or the full/custom checkpointing should be set" return transformer_config From 93907f000dbaeed899556c5ae224557172233412 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 14 May 2024 12:01:13 -0400 Subject: [PATCH 06/18] Bucketing duration bins: less optimal but instant init when not provided + fixes in estimation script (#9157) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Bucketing duration bins: less optimal but instant init when not provided + fixes in estimation script Signed-off-by: Piotr Żelasko * Fix CPU mem hungriness Signed-off-by: Piotr Żelasko * Make estimate duration bins work for every kind of manifest Signed-off-by: Piotr Żelasko * Support more type of inputs Signed-off-by: Piotr Żelasko * fixes Signed-off-by: Piotr Żelasko * msg Signed-off-by: Piotr Żelasko * fix Signed-off-by: Piotr Żelasko * fix Signed-off-by: Piotr Żelasko * Apply isort and black reformatting Signed-off-by: pablo-garay --------- Signed-off-by: Piotr Żelasko Signed-off-by: pablo-garay Co-authored-by: Pablo Garay Co-authored-by: pablo-garay --- nemo/collections/common/data/lhotse/cutset.py | 84 ++++++++++++++----- .../common/data/lhotse/dataloader.py | 62 ++++++++++++-- .../common/data/lhotse/nemo_adapters.py | 32 ++++--- .../convert_to_tarred_audio_dataset.py | 2 +- .../estimate_duration_bins.py | 53 ++++++++---- 5 files changed, 177 insertions(+), 56 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index cb2efe0312d2c..775395400d8e8 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -127,7 +127,7 @@ def read_dataset_config(config) -> tuple[CutSet, bool]: "shard_seed": config.shard_seed, "text_field": config.text_field, "lang_field": config.lang_field, - "missing_sampling_rate_ok": config.missing_sampling_rate_ok, + "metadata_only": config.metadata_only, 
"max_open_streams": config.max_open_streams, } input_cfg = config.input_cfg @@ -164,7 +164,10 @@ def parse_group(grp_cfg: DictConfig, propagate_attrs: dict) -> [CutSet, bool]: is_tarred = True cuts = read_txt_pair_paths(grp_cfg) elif grp_cfg.type == "group": - cuts, is_tarred = parse_and_combine_datasets(grp_cfg.input_cfg, propagate_attrs=propagate_attrs,) + cuts, is_tarred = parse_and_combine_datasets( + grp_cfg.input_cfg, + propagate_attrs=propagate_attrs, + ) else: raise ValueError(f"Unrecognized group: {grp_cfg.type}") # Attach extra tags to every utterance dynamically, if provided. @@ -176,7 +179,10 @@ def parse_group(grp_cfg: DictConfig, propagate_attrs: dict) -> [CutSet, bool]: def read_txt_paths(config: DictConfig) -> CutSet: return CutSet( LhotseTextAdapter( - paths=config.paths, language=config.language, shuffle_shards=config.shuffle, shard_seed=config.shard_seed, + paths=config.paths, + language=config.language, + shuffle_shards=config.shuffle, + shard_seed=config.shard_seed, ) ).repeat() @@ -238,6 +244,7 @@ def parse_and_combine_datasets( weights=weights if weights else None, max_open_streams=propagate_attrs["max_open_streams"], seed=propagate_attrs["shard_seed"], + metadata_only=propagate_attrs["metadata_only"], ) else: (cuts,) = cuts @@ -261,11 +268,16 @@ def read_lhotse_manifest(config, is_tarred: bool) -> CutSet: # - integer means we'll set a specific seed in every worker, and data would be duplicated across them. # This is mostly useful for unit testing or debugging. shard_seed = config.shard_seed + metadata_only = config.metadata_only if config.get("cuts_path") is not None: warnings.warn("Note: lhotse.cuts_path will be ignored because lhotse.shar_path was provided.") if isinstance(config.shar_path, (str, Path)): logging.info(f"Initializing Lhotse Shar CutSet (tarred) from a single data source: '{config.shar_path}'") - cuts = CutSet.from_shar(in_dir=config.shar_path, shuffle_shards=True, seed=shard_seed).repeat() + cuts = CutSet.from_shar( + **_resolve_shar_inputs(config.shar_path, metadata_only), shuffle_shards=True, seed=shard_seed + ) + if not metadata_only: + cuts = cuts.repeat() else: # Multiple datasets in Lhotse Shar format: we will dynamically multiplex them # with probability approximately proportional to their size @@ -278,7 +290,9 @@ def read_lhotse_manifest(config, is_tarred: bool) -> CutSet: for item in config.shar_path: if isinstance(item, (str, Path)): path = item - cs = CutSet.from_shar(in_dir=path, shuffle_shards=True, seed=shard_seed) + cs = CutSet.from_shar( + **_resolve_shar_inputs(path, metadata_only), shuffle_shards=True, seed=shard_seed + ) weight = len(cs) else: assert isinstance(item, Sequence) and len(item) == 2 and isinstance(item[1], (int, float)), ( @@ -288,11 +302,19 @@ def read_lhotse_manifest(config, is_tarred: bool) -> CutSet: f"We got: '{item}'" ) path, weight = item - cs = CutSet.from_shar(in_dir=path, shuffle_shards=True, seed=shard_seed) + cs = CutSet.from_shar( + **_resolve_shar_inputs(path, metadata_only), shuffle_shards=True, seed=shard_seed + ) logging.info(f"- {path=} {weight=}") - cutsets.append(cs.repeat()) + cutsets.append(cs) weights.append(weight) - cuts = mux(*cutsets, weights=weights, max_open_streams=config.max_open_streams, seed=config.shard_seed) + cuts = mux( + *cutsets, + weights=weights, + max_open_streams=config.max_open_streams, + seed=config.shard_seed, + metadata_only=metadata_only, + ) else: # Regular Lhotse manifest points to individual audio files (like native NeMo manifest). 
path = config.cuts_path @@ -300,6 +322,13 @@ def read_lhotse_manifest(config, is_tarred: bool) -> CutSet: return cuts +def _resolve_shar_inputs(path: str | Path, only_metadata: bool) -> dict: + if only_metadata: + return dict(fields={"cuts": sorted(Path(path).glob("cuts.*"))}) + else: + return dict(in_dir=path) + + def resolve_relative_paths(cut: Cut, manifest_path: str) -> Cut: if isinstance(cut, PaddingCut): return cut @@ -352,20 +381,24 @@ def read_nemo_manifest(config, is_tarred: bool) -> CutSet: common_kwargs = { "text_field": config.text_field, "lang_field": config.lang_field, + "shuffle_shards": config.shuffle, + "shard_seed": config.shard_seed, } # The option below is to allow a special case of NeMo manifest iteration as Lhotse CutSet - # without performing any I/O. NeMo manifests typically don't have sampling_rate information required by Lhotse. - # This is useful for utility scripts that iterate metadata and estimate optimal batching settings. - notar_kwargs = {"missing_sampling_rate_ok": config.missing_sampling_rate_ok} + # without performing any I/O. NeMo manifests typically don't have sampling_rate information required by Lhotse, + # so lhotse has to look up the headers of audio files to fill it on-the-fly. + # (this only has an impact on non-tarred data; tarred data is read into memory anyway). + # This is useful for utility scripts that iterate metadata and estimate optimal batching settings + # and other data statistics. + notar_kwargs = {"metadata_only": config.metadata_only} + metadata_only = config.metadata_only if isinstance(config.manifest_filepath, (str, Path)): logging.info(f"Initializing Lhotse CutSet from a single NeMo manifest (tarred): '{config.manifest_filepath}'") - if is_tarred: + if is_tarred and not metadata_only: cuts = CutSet( LazyNeMoTarredIterator( config.manifest_filepath, tar_paths=config.tarred_audio_filepaths, - shuffle_shards=config.shuffle, - shard_seed=config.shard_seed, **common_kwargs, ) ).repeat() @@ -393,12 +426,10 @@ def read_nemo_manifest(config, is_tarred: bool) -> CutSet: for manifest_info, (tar_path,) in zip(config.manifest_filepath, tar_paths): # First, convert manifest_path[+tar_path] to an iterator. manifest_path = manifest_info[0] - if is_tarred: + if is_tarred and not metadata_only: nemo_iter = LazyNeMoTarredIterator( manifest_path=manifest_path, tar_paths=tar_path, - shuffle_shards=config.shuffle, - shard_seed=config.shard_seed, **common_kwargs, ) else: @@ -431,12 +462,22 @@ def read_nemo_manifest(config, is_tarred: bool) -> CutSet: cutsets.append(CutSet(nemo_iter)) weights.append(weight) # Finally, we multiplex the dataset streams to mix the data. - cuts = mux(*cutsets, weights=weights, max_open_streams=config.max_open_streams, seed=config.shard_seed) + cuts = mux( + *cutsets, + weights=weights, + max_open_streams=config.max_open_streams, + seed=config.shard_seed, + metadata_only=metadata_only, + ) return cuts def mux( - *cutsets: CutSet, weights: list[int | float], max_open_streams: int | None = None, seed: str | int = "trng" + *cutsets: CutSet, + weights: list[int | float], + max_open_streams: int | None = None, + seed: str | int = "trng", + metadata_only: bool = False, ) -> CutSet: """ Helper function to call the right multiplexing method flavour in lhotse. @@ -444,9 +485,12 @@ def mux( it will select a more appropriate multiplexing strategy. 
""" if max_open_streams is not None: + assert not metadata_only, "max_open_streams and metadata_only options are not compatible" cuts = CutSet.infinite_mux(*cutsets, weights=weights, seed=seed, max_open_streams=max_open_streams) else: - cuts = CutSet.mux(*[cs.repeat() for cs in cutsets], weights=weights, seed=seed) + if not metadata_only: + cutsets = [cs.repeat() for cs in cutsets] + cuts = CutSet.mux(*cutsets, weights=weights, seed=seed) return cuts diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 9efd6444aecdd..32bbc1f3e8f4b 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -95,7 +95,9 @@ class LhotseDataLoadingConfig: # 4. Optional Lhotse data augmentation. # a. On-the-fly noise/audio mixing. - noise_path: Any | None = None # str | dict where dict can have any of keys: manifest_filepath, tarred_audio_filepaths, cuts_path, shar_path + noise_path: Any | None = ( + None # str | dict where dict can have any of keys: manifest_filepath, tarred_audio_filepaths, cuts_path, shar_path + ) noise_snr: tuple[float, float] = (10.0, 20.0) noise_mix_prob: float = 0.5 # b. On-the-fly 3-way speed perturbation. @@ -114,7 +116,9 @@ class LhotseDataLoadingConfig: cut_into_windows_duration: Optional[float] = None # set this to enable cut_into_windows_hop: Optional[float] = None # III) common options - keep_excessive_supervisions: bool = True # when a cut is truncated in the middle of a supervision, should we keep them. + keep_excessive_supervisions: bool = ( + True # when a cut is truncated in the middle of a supervision, should we keep them. + ) # e. RIR augmentation (synthetic RIR if rir_path is None) # at the moment supports only Lhotse recording manifests, e.g. https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/rir_noise.py rir_enabled: bool = False @@ -126,11 +130,15 @@ class LhotseDataLoadingConfig: lang_field: str = "lang" # key to read the language tag from # Enables iteration of NeMo non-tarred manifests that don't have a "sampling_rate" key without performing any I/O. # Note that this will not allow actual dataloading; it's only for manifest iteration as Lhotse objects. - missing_sampling_rate_ok: bool = False + metadata_only: bool = False def get_lhotse_dataloader_from_config( - config: DictConfig, global_rank: int, world_size: int, dataset: torch.utils.data.Dataset, tokenizer=None, + config: DictConfig, + global_rank: int, + world_size: int, + dataset: torch.utils.data.Dataset, + tokenizer=None, ) -> torch.utils.data.DataLoader: """ Set up a Lhotse training dataloder. @@ -205,7 +213,11 @@ def get_lhotse_dataloader_from_config( # and applying it here (before sampler/dataset) ensures optimal # bucket allocation. 
if config.perturb_speed: - cuts = CutSet.mux(cuts, cuts.perturb_speed(0.9), cuts.perturb_speed(1.1),) + cuts = CutSet.mux( + cuts, + cuts.perturb_speed(0.9), + cuts.perturb_speed(1.1), + ) # 2.d: truncation/slicing if config.truncate_duration is not None: @@ -249,6 +261,7 @@ def get_lhotse_dataloader_from_config( f"Creating a Lhotse DynamicBucketingSampler " f"(max_batch_duration={config.batch_duration} max_batch_size={config.batch_size})" ) + # Determine the bucket duration bins sampler = DynamicBucketingSampler( cuts, constraint=constraint, @@ -257,7 +270,7 @@ def get_lhotse_dataloader_from_config( shuffle_buffer_size=config.shuffle_buffer_size, seed=config.shard_seed, num_buckets=config.num_buckets, - duration_bins=config.bucket_duration_bins, + duration_bins=determine_bucket_duration_bins(config), num_cuts_for_bins_estimate=config.num_cuts_for_bins_estimate, buffer_size=config.bucket_buffer_size, rank=0 if is_tarred else global_rank, @@ -291,7 +304,10 @@ def get_lhotse_dataloader_from_config( # object with texts joined by a whitespace so that "regular" dataset classes don't # have to add a special support for multi-supervision cuts. sampler = sampler.map( - CutConcatenate(gap=config.concatenate_gap_seconds, duration_factor=config.concatenate_duration_factor,) + CutConcatenate( + gap=config.concatenate_gap_seconds, + duration_factor=config.concatenate_duration_factor, + ) ) if config.db_norm is not None: sampler = sampler.map(partial(_normalize_loudness, db_norm=config.db_norm)) @@ -326,12 +342,38 @@ def get_lhotse_dataloader_from_config( # the meta-data to Dataset, which performs the actual I/O inside its __getitem__ method. dloader_kwargs = dict(dataset=dataset, sampler=sampler) dloader = torch.utils.data.DataLoader( - **dloader_kwargs, batch_size=None, num_workers=config.num_workers, pin_memory=config.pin_memory, + **dloader_kwargs, + batch_size=None, + num_workers=config.num_workers, + pin_memory=config.pin_memory, ) return dloader +def determine_bucket_duration_bins(config): + if config.bucket_duration_bins is not None: + # Bucket duration bins are provided: just use them. + return config.bucket_duration_bins + # Bucket duration bins are not set. + if config.use_multimodal_sampling: + # For multimodal sampling it's currently impossible to define a linspace over durations + # because the buckets are counted in the number of tokens. + # The bins will be auto-estimated by lhotse at the cost of a slight lag in the training start. + return None + elif config.max_duration is not None and config.max_duration < float("inf"): + # If max duration is provided, we can use that to compute uniformly distant bucket bins. + # This is not optimal but should be close enough for users who didn't want to estimate these up-front. + begin = config.min_duration if config.min_duration is not None and config.min_duration > 0 else 0.0 + end = config.max_duration + return np.linspace(begin, end, config.num_buckets + 1)[1:-1].tolist() + else: + # If we don't know max_duration, we can't guess a reasonable estimate of the upper bound of + # durations. + # The bins will be auto-estimated by lhotse at the cost of a slight lag in the training start. + return None + + def make_structured_with_schema_warnings(config: DictConfig) -> DictConfig: """ Checks the schema and fills missing default option values. 
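For reference, a minimal sketch (not part of the patch) of the uniform fallback that determine_bucket_duration_bins above applies when no explicit bins are given but max_duration is known; the numeric values below are made up:

    import numpy as np

    num_buckets = 30
    min_duration, max_duration = 0.1, 40.0  # assumed example values
    # num_buckets + 1 evenly spaced edges; drop the two outer edges,
    # matching determine_bucket_duration_bins() above.
    bucket_duration_bins = np.linspace(min_duration, max_duration, num_buckets + 1)[1:-1].tolist()
    assert len(bucket_duration_bins) == num_buckets - 1
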
@@ -377,7 +419,9 @@ class MultimodalSamplingConstraint(SamplingConstraint): def __post_init__(self): self._internal = TokenConstraint( - max_tokens=self.batch_tokens, max_examples=self.batch_size, quadratic_length=self.quadratic_factor, + max_tokens=self.batch_tokens, + max_examples=self.batch_size, + quadratic_length=self.quadratic_factor, ) def add(self, example: Any) -> None: diff --git a/nemo/collections/common/data/lhotse/nemo_adapters.py b/nemo/collections/common/data/lhotse/nemo_adapters.py index b8769b041b4f3..b2ca1186c8e30 100644 --- a/nemo/collections/common/data/lhotse/nemo_adapters.py +++ b/nemo/collections/common/data/lhotse/nemo_adapters.py @@ -49,7 +49,7 @@ class LazyNeMoIterator: .. caution:: We will perform some I/O (as much as required by soundfile.info) to discover the sampling rate of the audio file. If this is not acceptable, convert the manifest to Lhotse format which contains - sampling rate info. For pure metadata iteration purposes we also provide a ``missing_sampling_rate_ok`` flag that + sampling rate info. For pure metadata iteration purposes we also provide a ``metadata_only`` flag that will create only partially valid Lhotse objects (with metadata related to sampling rate / num samples missing). Example:: @@ -62,16 +62,23 @@ def __init__( path: str | Path, text_field: str = "text", lang_field: str = "lang", - missing_sampling_rate_ok: bool = False, + metadata_only: bool = False, + shuffle_shards: bool = False, + shard_seed: int | Literal["randomized", "trng"] = "trng", ) -> None: - self.source = LazyJsonlIterator(path) + self.path = path + self.shuffle_shards = shuffle_shards + self.shard_seed = shard_seed + paths = expand_sharded_filepaths(path) + if len(paths) == 1: + self.source = LazyJsonlIterator(paths[0]) + else: + self.source = LazyIteratorChain( + *(LazyJsonlIterator(p) for p in paths), shuffle_iters=self.shuffle_shards, seed=self.shard_seed + ) self.text_field = text_field self.lang_field = lang_field - self.missing_sampling_rate_ok = missing_sampling_rate_ok - - @property - def path(self) -> str | Path: - return self.source.path + self.metadata_only = metadata_only def __iter__(self) -> Generator[Cut, None, None]: for data in self.source: @@ -104,7 +111,12 @@ def __len__(self) -> int: def __add__(self, other): return LazyIteratorChain(self, other) - def _create_recording(self, audio_path: str, duration: float, sampling_rate: int | None = None,) -> Recording: + def _create_recording( + self, + audio_path: str, + duration: float, + sampling_rate: int | None = None, + ) -> Recording: if sampling_rate is not None: # TODO(pzelasko): It will only work with single-channel audio in the current shape. 
return Recording( @@ -115,7 +127,7 @@ def _create_recording(self, audio_path: str, duration: float, sampling_rate: int duration=duration, channel_ids=[0], ) - elif self.missing_sampling_rate_ok: + elif self.metadata_only: return Recording( id=audio_path, sources=[AudioSource(type="file", channels=[0], source=audio_path)], diff --git a/scripts/speech_recognition/convert_to_tarred_audio_dataset.py b/scripts/speech_recognition/convert_to_tarred_audio_dataset.py index f0c7847b8c9b4..c3b5cef57cbca 100644 --- a/scripts/speech_recognition/convert_to_tarred_audio_dataset.py +++ b/scripts/speech_recognition/convert_to_tarred_audio_dataset.py @@ -412,7 +412,7 @@ def estimate_dynamic_bucketing_duration_bins(self, manifest_path: str, num_bucke from lhotse.dataset.sampling.dynamic_bucketing import estimate_duration_buckets from nemo.collections.common.data.lhotse.nemo_adapters import LazyNeMoIterator - cuts = CutSet(LazyNeMoIterator(manifest_path, missing_sampling_rate_ok=True)) + cuts = CutSet(LazyNeMoIterator(manifest_path, metadata_only=True)) bins = estimate_duration_buckets(cuts, num_buckets=num_buckets) print( f"Note: we estimated the optimal bucketing duration bins for {num_buckets} buckets. " diff --git a/scripts/speech_recognition/estimate_duration_bins.py b/scripts/speech_recognition/estimate_duration_bins.py index 687c2af59ad20..cca1017317729 100644 --- a/scripts/speech_recognition/estimate_duration_bins.py +++ b/scripts/speech_recognition/estimate_duration_bins.py @@ -13,6 +13,10 @@ # limitations under the License. import argparse +from itertools import islice +from pathlib import Path + +from lhotse.cut import Cut from lhotse.dataset.sampling.dynamic_bucketing import estimate_duration_buckets from omegaconf import OmegaConf @@ -23,14 +27,18 @@ def parse_args(): parser = argparse.ArgumentParser( description="Estimate duration bins for Lhotse dynamic bucketing using a sample of the input dataset. " - "The dataset is read either from one or more manifest files and supports data weighting." + "The dataset is read either from one or more manifest files and supports data weighting.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "input", - help='Same input format as in model configs under model.train_ds.manifest_filepath. Options: ' - '1) "path.json"; ' - '2) "[[path1.json],[path2.json],...]"; ' - '3) "[[path1.json,weight1],[path2.json,weight2],...]"', + help='Data input. Options: ' + '1) "path.json" - any single NeMo manifest; ' + '2) "[[path1.json],[path2.json],...]" - any collection of NeMo manifests; ' + '3) "[[path1.json,weight1],[path2.json,weight2],...]" - any collection of weighted NeMo manifests; ' + '4) "input_cfg.yaml" - a new option supporting input configs, same as in model training \'input_cfg\' arg; ' + '5) "path/to/shar_data" - a path to Lhotse Shar data directory; ' + '6) "key=val" - in case none of the previous variants cover your case: "key" is the key you\'d use in NeMo training config with its corresponding value ', ) parser.add_argument("-b", "--buckets", type=int, default=30, help="The desired number of buckets.") parser.add_argument( @@ -38,7 +46,8 @@ def parse_args(): "--num_examples", type=int, default=-1, - help="The number of examples (utterances) to estimate the bins. -1 means use all data.", + help="The number of examples (utterances) to estimate the bins. 
-1 means use all data " + "(be careful: it could be iterated over infinitely).", ) parser.add_argument( "-l", @@ -62,25 +71,36 @@ def parse_args(): def main(): args = parse_args() + if '=' in args.input: + inp_arg = args.input + elif args.input.endswith(".yaml"): + inp_arg = f"input_cfg={args.input}" + elif Path(args.input).is_dir(): + inp_arg = f"shar_path={args.input}" + else: + inp_arg = f"manifest_filepath={args.input}" config = OmegaConf.merge( OmegaConf.structured(LhotseDataLoadingConfig), - OmegaConf.from_dotlist([f"manifest_filepath={args.input}", "missing_sampling_rate_ok=true"]), + OmegaConf.from_dotlist([inp_arg, "metadata_only=true"]), ) cuts, _ = read_cutset_from_config(config) min_dur, max_dur = args.min_duration, args.max_duration - discarded, tot = 0, 0 + nonaudio, discarded, tot = 0, 0, 0 def duration_ok(cut) -> bool: - nonlocal discarded, tot - ans = min_dur <= cut.duration <= max_dur - if not ans: - discarded += 1 + nonlocal nonaudio, discarded, tot tot += 1 - return ans + if not isinstance(cut, Cut): + nonaudio += 1 + return False + if not (min_dur <= cut.duration <= max_dur): + discarded += 1 + return False + return True cuts = cuts.filter(duration_ok) if (N := args.num_examples) > 0: - cuts = cuts.subset(first=N) + cuts = islice(cuts, N) duration_bins = estimate_duration_buckets(cuts, num_buckets=args.buckets) duration_bins = f"[{','.join(str(round(b, ndigits=5)) for b in duration_bins)}]" if args.quiet: @@ -89,11 +109,12 @@ def duration_ok(cut) -> bool: if discarded: ratio = discarded / tot print(f"Note: we discarded {discarded}/{tot} ({ratio:.2%}) utterances due to min/max duration filtering.") + if nonaudio: + print(f"Note: we discarded {nonaudio} non-audio examples found during iteration.") + print(f"Used {tot - nonaudio - discarded} examples for the estimation.") print("Use the following options in your config:") print(f"\tnum_buckets={args.buckets}") print(f"\tbucket_duration_bins={duration_bins}") - print("Computing utterance duration distribution...") - cuts.describe() # prints a nice table with duration stats + other info if __name__ == "__main__": From acbd4e00ae2618c36ed9dad265d339e77a57832a Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 14 May 2024 21:55:34 +0400 Subject: [PATCH 07/18] Enable CUDA graphs by default only for transcription (#9196) (#9197) * Enable CUDA graphs only for transcription. Sync streams before capture. 
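As a usage sketch (illustrative, mirroring the transcribe_speech.py change below): training and validation keep the new conservative default, while offline transcription opts back into CUDA graph decoding explicitly:

    from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
    from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInferConfig

    # use_cuda_graph_decoder now defaults to False in the greedy decoding configs;
    # the transcription scripts enable it explicitly for inference speed.
    rnnt_decoding = RNNTDecodingConfig(
        fused_batch_size=-1,
        greedy=GreedyBatchedRNNTInferConfig(use_cuda_graph_decoder=True),
    )
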
--------- Signed-off-by: Vladimir Bataev --- examples/asr/transcribe_speech.py | 17 +- examples/asr/transcribe_speech_parallel.py | 6 +- .../asr/parts/submodules/rnnt_decoding.py | 15 +- .../parts/submodules/rnnt_greedy_decoding.py | 20 +- .../submodules/rnnt_loop_labels_computer.py | 121 +++++++--- .../submodules/tdt_loop_labels_computer.py | 225 +++++++++++------- 6 files changed, 266 insertions(+), 138 deletions(-) diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index c8372c422e7ba..1763c20358050 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -29,6 +29,7 @@ from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig +from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInferConfig from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.collections.asr.parts.utils.transcribe_utils import ( @@ -121,9 +122,9 @@ class TranscriptionConfig: pretrained_name: Optional[str] = None # Name of a pretrained model audio_dir: Optional[str] = None # Path to a directory which contains audio files dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest - channel_selector: Optional[ - Union[int, str] - ] = None # Used to select a single channel from multichannel audio, or use average across channels + channel_selector: Optional[Union[int, str]] = ( + None # Used to select a single channel from multichannel audio, or use average across channels + ) audio_key: str = 'audio_filepath' # Used to override the default audio key in dataset_manifest eval_config_yaml: Optional[str] = None # Path to a yaml file of config of evaluation presort_manifest: bool = True # Significant inference speedup on short-form data due to padding reduction @@ -161,7 +162,10 @@ class TranscriptionConfig: ctc_decoding: CTCDecodingConfig = CTCDecodingConfig() # Decoding strategy for RNNT models - rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1) + # enable CUDA graphs for transcription + rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig( + fused_batch_size=-1, greedy=GreedyBatchedRNNTInferConfig(use_cuda_graph_decoder=True) + ) # Decoding strategy for AED models multitask_decoding: MultiTaskDecodingConfig = MultiTaskDecodingConfig() @@ -407,7 +411,10 @@ def autocast(dtype=None): override_cfg.augmentor = augmentor override_cfg.text_field = cfg.gt_text_attr_name override_cfg.lang_field = cfg.gt_lang_attr_name - transcriptions = asr_model.transcribe(audio=filepaths, override_config=override_cfg,) + transcriptions = asr_model.transcribe( + audio=filepaths, + override_config=override_cfg, + ) if cfg.dataset_manifest is not None: logging.info(f"Finished transcribing from manifest file: {cfg.dataset_manifest}") diff --git a/examples/asr/transcribe_speech_parallel.py b/examples/asr/transcribe_speech_parallel.py index c0af8f97146a5..df2f310728511 100644 --- a/examples/asr/transcribe_speech_parallel.py +++ b/examples/asr/transcribe_speech_parallel.py @@ -84,6 +84,7 @@ from nemo.collections.asr.models import ASRModel, EncDecHybridRNNTCTCModel from nemo.collections.asr.models.configs.asr_models_config import ASRDatasetConfig from nemo.collections.asr.parts.submodules.rnnt_decoding import 
RNNTDecodingConfig +from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInferConfig from nemo.core.config import TrainerConfig, hydra_runner from nemo.utils import logging from nemo.utils.get_rank import is_global_rank_zero @@ -100,7 +101,10 @@ class ParallelTranscriptionConfig: use_cer: bool = False # decoding strategy for RNNT models - rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig() + # enable CUDA graphs for transcription + rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig( + fused_batch_size=-1, greedy=GreedyBatchedRNNTInferConfig(use_cuda_graph_decoder=True) + ) # decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models decoder_type: Optional[str] = None diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index 5fa225864f8c7..2416d916ac136 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -331,7 +331,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): preserve_frame_confidence=self.preserve_frame_confidence, confidence_method_cfg=self.confidence_method_cfg, loop_labels=self.cfg.greedy.get('loop_labels', True), - use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True), + use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False), ) else: self.decoding = rnnt_greedy_decoding.GreedyBatchedTDTInfer( @@ -347,7 +347,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): preserve_frame_confidence=self.preserve_frame_confidence, include_duration_confidence=self.tdt_include_duration_confidence, confidence_method_cfg=self.confidence_method_cfg, - use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True), + use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False), ) else: @@ -1175,7 +1175,11 @@ class RNNTDecoding(AbstractRNNTDecoding): """ def __init__( - self, decoding_cfg, decoder, joint, vocabulary, + self, + decoding_cfg, + decoder, + joint, + vocabulary, ): # we need to ensure blank is the last token in the vocab for the case of RNNT and Multi-blank RNNT. blank_id = len(vocabulary) + joint.num_extra_outputs @@ -1186,7 +1190,10 @@ def __init__( self.labels_map = dict([(i, vocabulary[i]) for i in range(len(vocabulary))]) super(RNNTDecoding, self).__init__( - decoding_cfg=decoding_cfg, decoder=decoder, joint=joint, blank_id=blank_id, + decoding_cfg=decoding_cfg, + decoder=decoder, + joint=joint, + blank_id=blank_id, ) if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer): diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index b2fa9b85b5fda..fa7a5cc95fec5 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -45,7 +45,10 @@ from nemo.utils import logging -def pack_hypotheses(hypotheses: List[rnnt_utils.Hypothesis], logitlen: torch.Tensor,) -> List[rnnt_utils.Hypothesis]: +def pack_hypotheses( + hypotheses: List[rnnt_utils.Hypothesis], + logitlen: torch.Tensor, +) -> List[rnnt_utils.Hypothesis]: if hasattr(logitlen, 'cpu'): logitlen_cpu = logitlen.to('cpu') @@ -139,8 +142,7 @@ class _GreedyRNNTInfer(Typing, ConfidenceMethodMixin): @property def input_types(self): - """Returns definitions of module input ports. 
- """ + """Returns definitions of module input ports.""" return { "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), "encoded_lengths": NeuralType(tuple('B'), LengthsType()), @@ -149,8 +151,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return {"predictions": [NeuralType(elements_type=HypothesisType())]} def __init__( @@ -578,6 +579,7 @@ class GreedyBatchedRNNTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs): (evaluating Joint multiple times in inner loop); It uses a minimal possible amount of calls to prediction network (with maximum possible batch size), which makes it especially useful for scaling the prediction network. + use_cuda_graph_decoder: if CUDA graphs should be enabled for decoding (currently recommended only for inference) """ def __init__( @@ -590,7 +592,7 @@ def __init__( preserve_frame_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, loop_labels: bool = True, - use_cuda_graph_decoder: bool = True, + use_cuda_graph_decoder: bool = False, ): super().__init__( decoder_model=decoder_model, @@ -2358,7 +2360,7 @@ class GreedyBatchedRNNTInferConfig: tdt_include_duration_confidence: bool = False confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig()) loop_labels: bool = True - use_cuda_graph_decoder: bool = True + use_cuda_graph_decoder: bool = False def __post_init__(self): # OmegaConf.structured ensures that post_init check is always executed @@ -2695,6 +2697,8 @@ class GreedyBatchedTDTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs): Supported values: - 'lin' for using the linear mapping. - 'exp' for using exponential mapping with linear shift. 
+ + use_cuda_graph_decoder: if CUDA graphs should be enabled for decoding (currently recommended only for inference) """ def __init__( @@ -2708,7 +2712,7 @@ def __init__( preserve_frame_confidence: bool = False, include_duration_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, - use_cuda_graph_decoder: bool = True, + use_cuda_graph_decoder: bool = False, ): super().__init__( decoder_model=decoder_model, diff --git a/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py index b920dba09cfd4..718deb7a409c4 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py @@ -112,7 +112,9 @@ def __init__( self.max_time = max_time self.encoder_output_projected = torch.zeros( - (self.batch_size, self.max_time, encoder_dim), dtype=float_dtype, device=self.device, + (self.batch_size, self.max_time, encoder_dim), + dtype=float_dtype, + device=self.device, ) self.encoder_output_length = torch.zeros((self.batch_size,), dtype=torch.long, device=self.device) @@ -288,7 +290,9 @@ def reset_cuda_graphs_state(self): self.separate_graphs = None def loop_labels_torch( - self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, + self, + encoder_output: torch.Tensor, + encoder_output_length: torch.Tensor, ) -> Tuple[rnnt_utils.BatchedHyps, Optional[rnnt_utils.BatchedAlignments], Any]: """ Pure PyTorch implementation @@ -361,7 +365,8 @@ def loop_labels_torch( # blank label in `labels` tensor means "end of hypothesis" (for this index) logits = ( self.joint.joint_after_projection( - encoder_output_projected[batch_indices, safe_time_indices].unsqueeze(1), decoder_output, + encoder_output_projected[batch_indices, safe_time_indices].unsqueeze(1), + decoder_output, ) .squeeze(1) .squeeze(1) @@ -378,9 +383,11 @@ def loop_labels_torch( time_indices=time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)).to(dtype=float_dtype) - if self.preserve_frame_confidence - else None, + confidence=( + self._get_confidence_tensor(F.log_softmax(logits, dim=-1)).to(dtype=float_dtype) + if self.preserve_frame_confidence + else None + ), ) # advance_mask is a mask for current batch for searching non-blank labels; @@ -397,7 +404,8 @@ def loop_labels_torch( torch.where(advance_mask, time_indices, time_indices_current_labels, out=time_indices_current_labels) logits = ( self.joint.joint_after_projection( - encoder_output_projected[batch_indices, safe_time_indices].unsqueeze(1), decoder_output, + encoder_output_projected[batch_indices, safe_time_indices].unsqueeze(1), + decoder_output, ) .squeeze(1) .squeeze(1) @@ -416,9 +424,11 @@ def loop_labels_torch( time_indices=time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=more_labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)).to(dtype=float_dtype) - if self.preserve_frame_confidence - else None, + confidence=( + self._get_confidence_tensor(F.log_softmax(logits, dim=-1)).to(dtype=float_dtype) + if self.preserve_frame_confidence + else None + ), ) blank_mask = labels == self._blank_index @@ -432,19 +442,27 @@ def loop_labels_torch( # this seems to be redundant, but used in the `loop_frames` output torch.ne(active_mask, 
active_mask_prev, out=became_inactive_mask) self.decoder.batch_replace_states_mask( - src_states=state, dst_states=last_decoder_state, mask=became_inactive_mask, + src_states=state, + dst_states=last_decoder_state, + mask=became_inactive_mask, ) # store hypotheses if self.max_symbols is not None: # pre-allocated memory, no need for checks batched_hyps.add_results_masked_no_checks_( - active_mask, labels, time_indices_current_labels, scores, + active_mask, + labels, + time_indices_current_labels, + scores, ) else: # auto-adjusted storage batched_hyps.add_results_masked_( - active_mask, labels, time_indices_current_labels, scores, + active_mask, + labels, + time_indices_current_labels, + scores, ) # stage 4: to avoid looping, go to next frame after max_symbols emission @@ -455,7 +473,8 @@ def loop_labels_torch( active_mask, torch.logical_and( torch.logical_and( - labels != self._blank_index, batched_hyps.last_timestep_lasts >= self.max_symbols, + labels != self._blank_index, + batched_hyps.last_timestep_lasts >= self.max_symbols, ), batched_hyps.last_timestep == time_indices, ), @@ -470,7 +489,9 @@ def loop_labels_torch( return batched_hyps, None, last_decoder_state def loop_labels_cuda_graphs( - self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, + self, + encoder_output: torch.Tensor, + encoder_output_length: torch.Tensor, ) -> Tuple[rnnt_utils.BatchedHyps, Optional[rnnt_utils.BatchedAlignments], Any]: """ Implementation with CUDA graphs. @@ -565,7 +586,9 @@ def _create_inner_while_loop_kernel(cls): return run_nvrtc(kernel_string, b"inner_find_non_blank_conditional", cls.CUDA_PROGRAM_NAME) def _graph_reinitialize( - self, encoder_output_projected: torch.Tensor, encoder_output_length: torch.Tensor, + self, + encoder_output_projected: torch.Tensor, + encoder_output_length: torch.Tensor, ): batch_size, max_time, encoder_dim = encoder_output_projected.shape @@ -602,25 +625,34 @@ def _partial_graphs_compile(self): """Compile decoding by parts""" # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. 
stream_for_graph = torch.cuda.Stream(self.state.device) + stream_for_graph.wait_stream(torch.cuda.default_stream(self.state.device)) self.separate_graphs = SeparateGraphsLoopLabels() - with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( - self.separate_graphs.before_outer_loop, stream=stream_for_graph + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph(self.separate_graphs.before_outer_loop, stream=stream_for_graph), ): self._before_outer_loop() - with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( - self.separate_graphs.before_inner_loop, stream=stream_for_graph + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph(self.separate_graphs.before_inner_loop, stream=stream_for_graph), ): self._before_inner_loop_get_decoder_output() self._before_inner_loop_get_joint_output() - with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( - self.separate_graphs.inner_loop_code, stream=stream_for_graph + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph(self.separate_graphs.inner_loop_code, stream=stream_for_graph), ): self._inner_loop_code() - with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( - self.separate_graphs.after_inner_loop, stream=stream_for_graph + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph(self.separate_graphs.after_inner_loop, stream=stream_for_graph), ): self._after_inner_loop() @@ -628,9 +660,12 @@ def _full_graph_compile(self): """Compile full graph for decoding""" # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. 
stream_for_graph = torch.cuda.Stream(self.state.device) + stream_for_graph.wait_stream(torch.cuda.default_stream(self.state.device)) self.full_graph = torch.cuda.CUDAGraph() - with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( - self.full_graph, stream=stream_for_graph + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph(self.full_graph, stream=stream_for_graph), ): self._before_outer_loop() @@ -644,7 +679,8 @@ def _full_graph_compile(self): outer_loop_kernel = self._create_outer_while_loop_kernel() active_mask_any_ptr = np.array([self.state.active_mask_any.data_ptr()], dtype=np.uint64) outer_loop_args = np.array( - [outer_loop_conditional_handle.getPtr(), active_mask_any_ptr.ctypes.data], dtype=np.uint64, + [outer_loop_conditional_handle.getPtr(), active_mask_any_ptr.ctypes.data], + dtype=np.uint64, ) # loop while there are active utterances with with_conditional_node( @@ -657,7 +693,11 @@ def _full_graph_compile(self): (inner_loop_conditional_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) advance_mask_any_ptr = np.array([self.state.advance_mask_any.data_ptr()], dtype=np.uint64) inner_loop_args = np.array( - [inner_loop_conditional_handle.getPtr(), advance_mask_any_ptr.ctypes.data,], dtype=np.uint64, + [ + inner_loop_conditional_handle.getPtr(), + advance_mask_any_ptr.ctypes.data, + ], + dtype=np.uint64, ) with with_conditional_node( inner_while_loop_kernel, inner_loop_args, inner_loop_conditional_handle, device=self.state.device @@ -734,9 +774,11 @@ def _before_inner_loop_get_joint_output(self): time_indices=self.state.time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=self.state.labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)).to(dtype=float_dtype) - if self.preserve_frame_confidence - else None, + confidence=( + self._get_confidence_tensor(F.log_softmax(logits, dim=-1)).to(dtype=float_dtype) + if self.preserve_frame_confidence + else None + ), ) # advance_mask is a mask for current batch for searching non-blank labels; @@ -785,9 +827,11 @@ def _inner_loop_code(self): time_indices=self.state.time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=more_labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)).to(dtype=float_dtype) - if self.preserve_frame_confidence - else None, + confidence=( + self._get_confidence_tensor(F.log_softmax(logits, dim=-1)).to(dtype=float_dtype) + if self.preserve_frame_confidence + else None + ), ) # blank_mask = self.labels == self._blank_index @@ -813,7 +857,10 @@ def _after_inner_loop(self): ) self.state.batched_hyps.add_results_masked_no_checks_( - self.state.active_mask, self.state.labels, self.state.time_indices_current_labels, self.state.scores, + self.state.active_mask, + self.state.labels, + self.state.time_indices_current_labels, + self.state.scores, ) # stage 4: to avoid looping, go to next frame after max_symbols emission @@ -837,7 +884,9 @@ def _after_inner_loop(self): torch.any(self.state.active_mask, out=self.state.active_mask_any) def __call__( - self, x: torch.Tensor, out_len: torch.Tensor, + self, + x: torch.Tensor, + out_len: torch.Tensor, ) -> Tuple[rnnt_utils.BatchedHyps, Optional[rnnt_utils.BatchedAlignments], Any]: if self.cuda_graphs_mode is not None and x.device.type == "cuda": return self.loop_labels_cuda_graphs(encoder_output=x, 
encoder_output_length=out_len) diff --git a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py index 4e514966db2b9..7ad7065e019c1 100644 --- a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py +++ b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py @@ -117,7 +117,9 @@ def __init__( self.max_time = max_time self.encoder_output_projected = torch.zeros( - (self.batch_size, self.max_time, encoder_dim), dtype=float_dtype, device=self.device, + (self.batch_size, self.max_time, encoder_dim), + dtype=float_dtype, + device=self.device, ) self.encoder_output_length = torch.zeros((self.batch_size,), dtype=torch.long, device=self.device) @@ -301,7 +303,9 @@ def force_cuda_graphs_mode(self, mode: Optional[Union[str, CudaGraphsMode]]): self.state = None def loop_labels_torch( - self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, + self, + encoder_output: torch.Tensor, + encoder_output_length: torch.Tensor, ) -> Tuple[rnnt_utils.BatchedHyps, Optional[rnnt_utils.BatchedAlignments], Any]: """ Pure PyTorch implementation @@ -379,7 +383,8 @@ def loop_labels_torch( # blank label in `labels` tensor means "end of hypothesis" (for this index) logits = ( self.joint.joint_after_projection( - encoder_output_projected[batch_indices, safe_time_indices].unsqueeze(1), decoder_output, + encoder_output_projected[batch_indices, safe_time_indices].unsqueeze(1), + decoder_output, ) .squeeze(1) .squeeze(1) @@ -400,23 +405,27 @@ def loop_labels_torch( time_indices=time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=labels if self.preserve_alignments else None, - confidence=torch.stack( - ( - self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( - dtype=float_dtype + confidence=( + torch.stack( + ( + self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( + dtype=float_dtype + ), + self._get_confidence_tensor(F.log_softmax(logits[:, -num_durations:], dim=-1)).to( + dtype=float_dtype + ), ), - self._get_confidence_tensor(F.log_softmax(logits[:, -num_durations:], dim=-1)).to( + dim=-1, + ) + if self.include_duration_confidence + else ( + self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( dtype=float_dtype - ), - ), - dim=-1, - ) - if self.include_duration_confidence - else self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( - dtype=float_dtype - ) - if self.preserve_frame_confidence - else None, + ) + if self.preserve_frame_confidence + else None + ) + ), ) # advance_mask is a mask for current batch for searching non-blank labels; @@ -433,7 +442,8 @@ def loop_labels_torch( torch.where(advance_mask, time_indices, time_indices_current_labels, out=time_indices_current_labels) logits = ( self.joint.joint_after_projection( - encoder_output_projected[batch_indices, safe_time_indices].unsqueeze(1), decoder_output, + encoder_output_projected[batch_indices, safe_time_indices].unsqueeze(1), + decoder_output, ) .squeeze(1) .squeeze(1) @@ -454,23 +464,27 @@ def loop_labels_torch( time_indices=time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=more_labels if self.preserve_alignments else None, - confidence=torch.stack( - ( - self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( - dtype=float_dtype + confidence=( + torch.stack( + ( + self._get_confidence_tensor(F.log_softmax(logits[:, 
:-num_durations], dim=-1)).to( + dtype=float_dtype + ), + self._get_confidence_tensor(F.log_softmax(logits[:, -num_durations:], dim=-1)).to( + dtype=float_dtype + ), ), - self._get_confidence_tensor(F.log_softmax(logits[:, -num_durations:], dim=-1)).to( + dim=-1, + ) + if self.include_duration_confidence + else ( + self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( dtype=float_dtype - ), - ), - dim=-1, - ) - if self.include_duration_confidence - else self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( - dtype=float_dtype - ) - if self.preserve_frame_confidence - else None, + ) + if self.preserve_frame_confidence + else None + ) + ), ) blank_mask = labels == self._blank_index @@ -487,19 +501,27 @@ def loop_labels_torch( # this seems to be redundant, but used in the `loop_frames` output torch.ne(active_mask, active_mask_prev, out=became_inactive_mask) self.decoder.batch_replace_states_mask( - src_states=state, dst_states=last_decoder_state, mask=became_inactive_mask, + src_states=state, + dst_states=last_decoder_state, + mask=became_inactive_mask, ) # store hypotheses if self.max_symbols is not None: # pre-allocated memory, no need for checks batched_hyps.add_results_masked_no_checks_( - active_mask, labels, time_indices_current_labels, scores, + active_mask, + labels, + time_indices_current_labels, + scores, ) else: # auto-adjusted storage batched_hyps.add_results_masked_( - active_mask, labels, time_indices_current_labels, scores, + active_mask, + labels, + time_indices_current_labels, + scores, ) # stage 4: to avoid looping, go to next frame after max_symbols emission @@ -510,7 +532,8 @@ def loop_labels_torch( active_mask, torch.logical_and( torch.logical_and( - labels != self._blank_index, batched_hyps.last_timestep_lasts >= self.max_symbols, + labels != self._blank_index, + batched_hyps.last_timestep_lasts >= self.max_symbols, ), batched_hyps.last_timestep == time_indices, ), @@ -525,7 +548,9 @@ def loop_labels_torch( return batched_hyps, None, last_decoder_state def loop_labels_cuda_graphs( - self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, + self, + encoder_output: torch.Tensor, + encoder_output_length: torch.Tensor, ) -> Tuple[rnnt_utils.BatchedHyps, Optional[rnnt_utils.BatchedAlignments], Any]: """ Implementation with CUDA graphs. @@ -620,7 +645,9 @@ def _create_inner_while_loop_kernel(cls): return run_nvrtc(kernel_string, b"inner_find_non_blank_conditional", cls.CUDA_PROGRAM_NAME) def _graph_reinitialize( - self, encoder_output_projected: torch.Tensor, encoder_output_length: torch.Tensor, + self, + encoder_output_projected: torch.Tensor, + encoder_output_length: torch.Tensor, ): batch_size, max_time, encoder_dim = encoder_output_projected.shape @@ -659,25 +686,34 @@ def _partial_graphs_compile(self): """Compile decoding by parts""" # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. 
stream_for_graph = torch.cuda.Stream(self.state.device) + stream_for_graph.wait_stream(torch.cuda.default_stream(self.state.device)) self.separate_graphs = SeparateGraphsLoopLabels() - with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( - self.separate_graphs.before_outer_loop, stream=stream_for_graph + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph(self.separate_graphs.before_outer_loop, stream=stream_for_graph), ): self._before_outer_loop() - with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( - self.separate_graphs.before_inner_loop, stream=stream_for_graph + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph(self.separate_graphs.before_inner_loop, stream=stream_for_graph), ): self._before_inner_loop_get_decoder_output() self._before_inner_loop_get_joint_output() - with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( - self.separate_graphs.inner_loop_code, stream=stream_for_graph + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph(self.separate_graphs.inner_loop_code, stream=stream_for_graph), ): self._inner_loop_code() - with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( - self.separate_graphs.after_inner_loop, stream=stream_for_graph + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph(self.separate_graphs.after_inner_loop, stream=stream_for_graph), ): self._after_inner_loop() @@ -685,9 +721,12 @@ def _full_graph_compile(self): """Compile full graph for decoding""" # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. 
stream_for_graph = torch.cuda.Stream(self.state.device) + stream_for_graph.wait_stream(torch.cuda.default_stream(self.state.device)) self.full_graph = torch.cuda.CUDAGraph() - with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( - self.full_graph, stream=stream_for_graph + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph(self.full_graph, stream=stream_for_graph), ): self._before_outer_loop() @@ -700,7 +739,8 @@ def _full_graph_compile(self): outer_loop_kernel = self._create_outer_while_loop_kernel() active_mask_any_ptr = np.array([self.state.active_mask_any.data_ptr()], dtype=np.uint64) outer_loop_args = np.array( - [outer_loop_conditional_handle.getPtr(), active_mask_any_ptr.ctypes.data], dtype=np.uint64, + [outer_loop_conditional_handle.getPtr(), active_mask_any_ptr.ctypes.data], + dtype=np.uint64, ) # loop while there are active utterances @@ -714,7 +754,11 @@ def _full_graph_compile(self): (inner_loop_conditional_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) advance_mask_any_ptr = np.array([self.state.advance_mask_any.data_ptr()], dtype=np.uint64) inner_loop_args = np.array( - [inner_loop_conditional_handle.getPtr(), advance_mask_any_ptr.ctypes.data,], dtype=np.uint64, + [ + inner_loop_conditional_handle.getPtr(), + advance_mask_any_ptr.ctypes.data, + ], + dtype=np.uint64, ) # while self.advance_mask_any.item(): @@ -797,23 +841,27 @@ def _before_inner_loop_get_joint_output(self): time_indices=self.state.time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=self.state.labels if self.preserve_alignments else None, - confidence=torch.stack( - ( + confidence=( + torch.stack( + ( + self._get_confidence_tensor( + F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) + ).to(dtype=float_dtype), + self._get_confidence_tensor( + F.log_softmax(logits[:, -self.state.all_durations.shape[0] :], dim=-1) + ).to(dtype=float_dtype), + ), + dim=-1, + ) + if self.include_duration_confidence + else ( self._get_confidence_tensor( F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) - ).to(dtype=float_dtype), - self._get_confidence_tensor( - F.log_softmax(logits[:, -self.state.all_durations.shape[0] :], dim=-1) - ).to(dtype=float_dtype), - ), - dim=-1, - ) - if self.include_duration_confidence - else self._get_confidence_tensor( - F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) - ).to(dtype=float_dtype) - if self.preserve_frame_confidence - else None, + ).to(dtype=float_dtype) + if self.preserve_frame_confidence + else None + ) + ), ) # advance_mask is a mask for current batch for searching non-blank labels; @@ -864,23 +912,27 @@ def _inner_loop_code(self): time_indices=self.state.time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=more_labels if self.preserve_alignments else None, - confidence=torch.stack( - ( + confidence=( + torch.stack( + ( + self._get_confidence_tensor( + F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) + ).to(dtype=float_dtype), + self._get_confidence_tensor( + F.log_softmax(logits[:, -self.state.all_durations.shape[0] :], dim=-1) + ).to(dtype=float_dtype), + ), + dim=-1, + ) + if self.include_duration_confidence + else ( self._get_confidence_tensor( F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) - ).to(dtype=float_dtype), - self._get_confidence_tensor( - F.log_softmax(logits[:, -self.state.all_durations.shape[0] 
:], dim=-1) - ).to(dtype=float_dtype), - ), - dim=-1, - ) - if self.include_duration_confidence - else self._get_confidence_tensor( - F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) - ).to(dtype=float_dtype) - if self.preserve_frame_confidence - else None, + ).to(dtype=float_dtype) + if self.preserve_frame_confidence + else None + ) + ), ) # blank_mask = self.labels == self._blank_index @@ -913,7 +965,10 @@ def _after_inner_loop(self): ) self.state.batched_hyps.add_results_masked_no_checks_( - self.state.active_mask, self.state.labels, self.state.time_indices_current_labels, self.state.scores, + self.state.active_mask, + self.state.labels, + self.state.time_indices_current_labels, + self.state.scores, ) # stage 4: to avoid looping, go to next frame after max_symbols emission @@ -937,7 +992,9 @@ def _after_inner_loop(self): torch.any(self.state.active_mask, out=self.state.active_mask_any) def __call__( - self, x: torch.Tensor, out_len: torch.Tensor, + self, + x: torch.Tensor, + out_len: torch.Tensor, ) -> Tuple[rnnt_utils.BatchedHyps, Optional[rnnt_utils.BatchedAlignments], Any]: if self.cuda_graphs_mode is not None and x.device.type == "cuda": return self.loop_labels_cuda_graphs(encoder_output=x, encoder_output_length=out_len) From 4167641fae262b4f6b6828498b65aa148511c51c Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 14 May 2024 14:15:17 -0400 Subject: [PATCH 08/18] move tts fixtures (#9183) * move tts fixtures Signed-off-by: Jason * Apply isort and black reformatting Signed-off-by: blisc --------- Signed-off-by: Jason Signed-off-by: blisc Co-authored-by: blisc --- .../tts.py => collections/tts/conftest.py} | 0 tests/conftest.py | 13 +++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) rename tests/{fixtures/tts.py => collections/tts/conftest.py} (100%) diff --git a/tests/fixtures/tts.py b/tests/collections/tts/conftest.py similarity index 100% rename from tests/fixtures/tts.py rename to tests/collections/tts/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py index 5069890e48405..6298ed051c68f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,8 +25,6 @@ import pytest -from tests.fixtures.tts import * - # Those variables probably should go to main NeMo configuration file (config.yaml). __TEST_DATA_FILENAME = "test_data.tar.gz" __TEST_DATA_URL = "https://github.com/NVIDIA/NeMo/releases/download/v1.0.0rc1/" @@ -68,7 +66,7 @@ def pytest_addoption(parser): @pytest.fixture def device(request): - """ Simple fixture returning string denoting the device [CPU | GPU] """ + """Simple fixture returning string denoting the device [CPU | GPU]""" if request.config.getoption("--cpu"): return "CPU" else: @@ -193,13 +191,16 @@ def pytest_configure(config): If file absent or sizes not equal, function downloads the archive from github and unpacks it. """ config.addinivalue_line( - "markers", "run_only_on(device): runs the test only on a given device [CPU | GPU]", + "markers", + "run_only_on(device): runs the test only on a given device [CPU | GPU]", ) config.addinivalue_line( - "markers", "with_downloads: runs the test using data present in tests/.data", + "markers", + "with_downloads: runs the test using data present in tests/.data", ) config.addinivalue_line( - "markers", "nightly: runs the nightly test for QA.", + "markers", + "nightly: runs the nightly test for QA.", ) # Test dir and archive filepath. 
test_dir = join(dirname(__file__), __TEST_DATA_SUBDIR) From 4d574fe493df9f7e86629d2a0afe880f1a52764d Mon Sep 17 00:00:00 2001 From: Adi Renduchintala Date: Tue, 14 May 2024 13:23:32 -0700 Subject: [PATCH 09/18] enable matryoshka embedding learning (#9130) * enable matryoshka embedding learning Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply isort and black reformatting Signed-off-by: arendu --------- Signed-off-by: arendu Signed-off-by: arendu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: arendu --- .../megatron_gpt_embedding_model.py | 49 +++++++++++++++---- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py index d477b337cd299..389c90d7f97c2 100644 --- a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py +++ b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py @@ -58,6 +58,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.temperature = self.cfg.get('temperature', 0.02) self.use_all_possible_negatives = self.cfg.get("use_all_possible_negatives", True) self.global_inbatch_negatives = self.cfg.get("global_inbatch_negatives", True) + if self.cfg.get("do_mrl", False): + min_mrl = self.cfg.get("min_mrl_dim", int(np.log2(32))) - 1 + max_mrl = int(np.log2(self.cfg.hidden_size // 2)) + self.mrl_dims = [2**i for i in range(max_mrl, min_mrl, -1)] + else: + self.mrl_dims = [] + assert ( self.cfg.get("post_process", False) is False ), "post_process must be False to get hidden states in the loss_func" @@ -255,7 +262,14 @@ def gather_and_maybe_write_predictions(self, output, data_cfg, mode, averaged_me gathered_output_batches = [None for _ in range(parallel_state.get_data_parallel_world_size())] torch.distributed.all_gather_object( gathered_output_batches, - [{'q_hs': batch['q_hs'], 'd_hs': batch['d_hs'], 'metadata': batch['metadata'],} for batch in output], + [ + { + 'q_hs': batch['q_hs'], + 'd_hs': batch['d_hs'], + 'metadata': batch['metadata'], + } + for batch in output + ], group=parallel_state.get_data_parallel_group(), ) @@ -272,7 +286,11 @@ def gather_and_maybe_write_predictions(self, output, data_cfg, mode, averaged_me l_d_hs = listify(batch['d_hs']) l_m = batch['metadata'] assert len(l_m) == len(l_q_hs) == len(l_d_hs) - for q_hs, d_hs, metadata in zip(l_q_hs, l_d_hs, l_m,): + for q_hs, d_hs, metadata in zip( + l_q_hs, + l_d_hs, + l_m, + ): total_size += 1 if not metadata.get("__AUTOGENERATED__", False): deduplicated_outputs['q_hs'].append(q_hs) @@ -326,10 +344,10 @@ def write_embeddings_to_file(self, outputs, output_file_path, d_idx): def local_validation_step(self, dataloader_iter): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. 
""" # Check if iterator is exhausted # dataloader_iter, done = self._val_iterator_done(dataloader_iter) @@ -377,7 +395,7 @@ def local_validation_step(self, dataloader_iter): return loss, non_loss_tensors - def constrastive_scores(self, pos_doc_hs, neg_doc_hs, query_hs, bs, use_all_possible_negatives=False): + def constrastive_scores(self, pos_doc_hs, neg_doc_hs, query_hs, bs, temperature, use_all_possible_negatives=False): all_doc_hs = torch.cat([pos_doc_hs, neg_doc_hs], dim=0) # (2bs) x hidden_size cs = torch.mm(query_hs, all_doc_hs.transpose(0, 1)) # (bs) x (2bs) pos_cs = cs[:, :bs].diag() @@ -389,6 +407,8 @@ def constrastive_scores(self, pos_doc_hs, neg_doc_hs, query_hs, bs, use_all_poss cs = torch.cat([pos_cs.unsqueeze(1), neg_cs.unsqueeze(1)], dim=1) pos_cs = pos_cs.clone().detach().mean() neg_cs = neg_cs.clone().detach().mean() + cs = cs.clamp(-1.0, 1.0) + cs = cs / temperature return cs, pos_cs, neg_cs, labels def inference_loss_func(self, loss_mask, num_valid_tokens_in_ub, eos_tensors): @@ -426,11 +446,20 @@ def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor): neg_doc_hs = torch.nn.functional.normalize(neg_doc_hs, dim=1) cs, pos_cs, neg_cs, labels = self.constrastive_scores( - pos_doc_hs, neg_doc_hs, query_hs, bs, self.use_all_possible_negatives + pos_doc_hs, neg_doc_hs, query_hs, bs, self.temperature, self.use_all_possible_negatives ) - cs = cs.clamp(-1.0, 1.0) - cs = cs / self.temperature loss = torch.nn.functional.cross_entropy(cs, labels) + if self.mrl_dims: + for dim in self.mrl_dims: + cs_dim, _, _, _ = self.constrastive_scores( + pos_doc_hs[:, :dim], + neg_doc_hs[:, :dim], + query_hs[:, :dim], + bs, + self.temperature, + self.use_all_possible_negatives, + ) + loss += torch.nn.functional.cross_entropy(cs_dim, labels) cp_size = self.cfg.get('context_parallel_size', 1) if cp_size > 1: From 5df8e11255802a2ce2f33db6362e60990e215b64 Mon Sep 17 00:00:00 2001 From: yaoyu-33 <54727607+yaoyu-33@users.noreply.github.com> Date: Tue, 14 May 2024 15:16:21 -0700 Subject: [PATCH 10/18] Add guards to SD imports (#9158) * Add guards to SD imports Signed-off-by: yaoyu-33 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: yaoyu-33 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../modules/imagen/diffusionmodules/layers.py | 9 +++++++- .../modules/stable_diffusion/attention.py | 17 ++++++++++++--- .../diffusionmodules/model.py | 9 +++++++- .../diffusionmodules/openaimodel.py | 21 +++++++++++++++---- .../stable_diffusion/diffusionmodules/util.py | 9 +++++++- 5 files changed, 55 insertions(+), 10 deletions(-) diff --git a/nemo/collections/multimodal/modules/imagen/diffusionmodules/layers.py b/nemo/collections/multimodal/modules/imagen/diffusionmodules/layers.py index 72e70250f0d73..f5beca436ecfc 100644 --- a/nemo/collections/multimodal/modules/imagen/diffusionmodules/layers.py +++ b/nemo/collections/multimodal/modules/imagen/diffusionmodules/layers.py @@ -43,7 +43,14 @@ import torch as th import torch.nn as nn import torch.nn.functional as F -from apex.contrib.group_norm import GroupNorm + +try: + from apex.contrib.group_norm import GroupNorm + + OPT_GROUP_NORM = True +except Exception: + print('Fused optimized group norm has not been installed.') + OPT_GROUP_NORM = False def conv_nd(dims, *args, **kwargs): diff --git a/nemo/collections/multimodal/modules/stable_diffusion/attention.py b/nemo/collections/multimodal/modules/stable_diffusion/attention.py index 
f5689c706e2c8..c70b59d394817 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/attention.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/attention.py @@ -17,7 +17,6 @@ import torch import torch.nn.functional as F -from apex.contrib.group_norm import GroupNorm from einops import rearrange, repeat from torch import einsum, nn from torch._dynamo import disable @@ -25,9 +24,13 @@ if os.environ.get("USE_NATIVE_GROUP_NORM", "0") == "1": from nemo.gn_native import GroupNormNormlization as GroupNorm else: - from apex.contrib.group_norm import GroupNorm + try: + from apex.contrib.group_norm import GroupNorm -from transformer_engine.pytorch.module import LayerNormLinear, LayerNormMLP + OPT_GROUP_NORM = True + except Exception: + print('Fused optimized group norm has not been installed.') + OPT_GROUP_NORM = False from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import checkpoint from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ( @@ -37,6 +40,14 @@ from nemo.core import adapter_mixins from nemo.utils import logging +try: + from transformer_engine.pytorch.module import LayerNormLinear, LayerNormMLP + + HAVE_TE = True + +except (ImportError, ModuleNotFoundError): + HAVE_TE = False + def check_cuda(): if not torch.cuda.is_available(): diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/model.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/model.py index 7fc5c208004f7..644efafaf06a5 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/model.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/model.py @@ -17,12 +17,19 @@ import numpy as np import torch import torch.nn as nn -from apex.contrib.group_norm import GroupNorm from einops import rearrange from nemo.collections.multimodal.modules.stable_diffusion.attention import LinearAttention from nemo.collections.multimodal.parts.stable_diffusion.utils import instantiate_from_config +try: + from apex.contrib.group_norm import GroupNorm + + OPT_GROUP_NORM = True +except Exception: + print('Fused optimized group norm has not been installed.') + OPT_GROUP_NORM = False + def get_timestep_embedding(timesteps, embedding_dim): """ diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py index b610f921a22a8..3e301f0b8fc19 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py @@ -26,10 +26,6 @@ import torch.nn as nn import torch.nn.functional as F -# FP8 related import -import transformer_engine -from apex.contrib.group_norm import GroupNorm - from nemo.collections.multimodal.modules.stable_diffusion.attention import SpatialTransformer from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import ( avg_pool_nd, @@ -45,6 +41,23 @@ ) from nemo.utils import logging +try: + # FP8 related import + import transformer_engine + + HAVE_TE = True + +except (ImportError, ModuleNotFoundError): + HAVE_TE = False + +try: + from apex.contrib.group_norm import GroupNorm + + OPT_GROUP_NORM = True +except Exception: + print('Fused optimized group norm has not been installed.') + OPT_GROUP_NORM = False + def convert_module_to_dtype(module, dtype, enable_norm_layers=False): # Convert module 
parameters to dtype diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py index 3b446f4a42c33..53f9669a0b8f3 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py @@ -29,11 +29,18 @@ import numpy as np import torch import torch.nn as nn -from apex.contrib.group_norm import GroupNorm from einops import repeat from torch._dynamo import disable from torch.cuda.amp import custom_bwd, custom_fwd +try: + from apex.contrib.group_norm import GroupNorm + + OPT_GROUP_NORM = True +except Exception: + print('Fused optimized group norm has not been installed.') + OPT_GROUP_NORM = False + def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): if schedule == "linear": From c2daa916b6454fe568706b4ab5da06500e2c6728 Mon Sep 17 00:00:00 2001 From: mikolajblaz Date: Wed, 15 May 2024 13:57:18 +0200 Subject: [PATCH 11/18] Implement async distributed checkpoint save (#9028) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Prevent duplicated checkpoints Signed-off-by: Mikołaj Błaż * Introduce DistributedCheckpointIO Signed-off-by: Mikołaj Błaż * Fix DistCkptIO usage Signed-off-by: Mikołaj Błaż * Use NeMo logger Signed-off-by: Mikołaj Błaż * [DCIO] Fix save_to dist ckpt path Signed-off-by: Mikołaj Błaż * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add versioning to save_to Signed-off-by: Mikołaj Błaż * Add versioning logic to all .nemo files Signed-off-by: Mikołaj Błaż * Add versioning test Signed-off-by: Mikołaj Błaż * Add dist-ckpt test Signed-off-by: Mikołaj Błaż * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Mikołaj Błaż * Rename existing ckpts instead of using different name Signed-off-by: Mikołaj Błaż * Add comment Signed-off-by: Mikołaj Błaż * Use dist ckpt flag in all methods Signed-off-by: Mikołaj Błaż * Improve error msg Signed-off-by: Mikołaj Błaż * Add dist ckpt unit tests Signed-off-by: Mikołaj Błaż * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix load_checkpoint Signed-off-by: Mikołaj Błaż * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Mikołaj Błaż * Fix auto-issues Signed-off-by: Mikołaj Błaż * Fix ckpt_dir var Signed-off-by: Mikołaj Błaż * Restore skipping behavior The fix from prevent-duplicated-checkpoints is required to skip the checkpoints Signed-off-by: Mikołaj Błaż * Fix steps on single-GPU machine Signed-off-by: Mikołaj Błaż * Run dist-ckpt test on GPU Signed-off-by: Mikołaj Błaż * Add docs Signed-off-by: Mikołaj Błaż * Apply black Signed-off-by: Mikołaj Błaż * Prevent saving last for non-equal val intervals Signed-off-by: Mikołaj Błaż * Move checkpoint on rank 0 Signed-off-by: Mikołaj Błaż * Fix num steps in tests Signed-off-by: Mikołaj Błaż * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Mikołaj Błaż * Add async ckpt implementation Signed-off-by: Mikołaj Błaż * Abstract AsyncFinalizableCheckpointIO away Signed-off-by: Mikołaj Błaż * Change async_save flag location Signed-off-by: Mikołaj Błaż * Add debug info Signed-off-by: Mikołaj Błaż * Apply 
formatting Signed-off-by: Mikołaj Błaż * Handle multiple async saves Signed-off-by: Mikołaj Błaż * Apply formatting Signed-off-by: Mikołaj Błaż * Move finalization calls to a callback Signed-off-by: Mikołaj Błaż * Avoid deadlock in teardown Signed-off-by: Mikołaj Błaż * Adjust to MCore implementation Signed-off-by: Mikołaj Błaż * Add notes and copyrights Signed-off-by: Mikołaj Błaż * Apply formatting Signed-off-by: Mikołaj Błaż * Fix async_request attribute Signed-off-by: Mikołaj Błaż * Add MCore import guards Signed-off-by: Mikołaj Błaż * Add async test Signed-off-by: Mikołaj Błaż * Fix finalize_fn arg Signed-off-by: Mikołaj Błaż * Add docs Signed-off-by: Mikołaj Błaż * Remove checkpoints from accurate steps Signed-off-by: Mikołaj Błaż * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix MCore class usage Signed-off-by: Mikołaj Błaż * Update docs Signed-off-by: Mikołaj Błaż * Fix logger usage Signed-off-by: Mikołaj Błaż * Fix rebase Signed-off-by: Mikołaj Błaż * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix code scan issues Signed-off-by: Mikołaj Błaż * Remove unsused import Signed-off-by: Mikołaj Błaż * Use dist-ckpt for Bert Signed-off-by: Mikołaj Błaż * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix load checkpoint return val Signed-off-by: Mikołaj Błaż * Use dist-ckpt based on sharded_state_dict Signed-off-by: Mikołaj Błaż * Add async logging Signed-off-by: Mikołaj Błaż * Remove deprecated argument Signed-off-by: Mikołaj Błaż * Use correct checkpoint_io Signed-off-by: Mikołaj Błaż * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bad merge Signed-off-by: Mikołaj Błaż * Improve debug msg Signed-off-by: Mikołaj Błaż * Run async test on GPU Signed-off-by: Mikołaj Błaż * Fix async ckpt unit test Signed-off-by: Mikołaj Błaż * Apply isort and black reformatting Signed-off-by: mikolajblaz * Clarify async logs Signed-off-by: Mikołaj Błaż * Add schema print Signed-off-by: Mikołaj Błaż --------- Signed-off-by: Mikołaj Błaż Signed-off-by: mikolajblaz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../conf/megatron_gpt_config.yaml | 1 + .../nlp/parts/megatron_trainer_builder.py | 62 +++- nemo/collections/nlp/parts/nlp_overrides.py | 107 ++++--- nemo/utils/callbacks/checkpointing_context.py | 0 nemo/utils/callbacks/dist_ckpt_io.py | 221 ++++++++++++- nemo/utils/callbacks/nemo_model_checkpoint.py | 100 +++++- nemo/utils/callbacks/torch_dist_async.py | 298 ++++++++++++++++++ nemo/utils/exp_manager.py | 27 +- tests/core/test_dist_ckpt.py | 99 +++++- 9 files changed, 806 insertions(+), 109 deletions(-) create mode 100644 nemo/utils/callbacks/checkpointing_context.py create mode 100644 nemo/utils/callbacks/torch_dist_async.py diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index aa43dfe7e53e4..20e20744833c4 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -52,6 +52,7 @@ exp_manager: save_nemo_on_train_end: False # not recommended when training large models on clusters with short time limits filename: 'megatron_gpt--{val_loss:.2f}-{step}-{consumed_samples}' model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, 
${model.pipeline_model_parallel_size}} + async_save: False # Set to True to enable async checkpoint save. Currently works only with distributed checkpoints model: # use GPTModel from megatron.core diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py index a97b9301fb266..e1a780f09756c 100644 --- a/nemo/collections/nlp/parts/megatron_trainer_builder.py +++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py @@ -13,8 +13,9 @@ # limitations under the License. import sys -from typing import Union +from typing import Optional, Union +from lightning_fabric.utilities.exceptions import MisconfigurationException from omegaconf import DictConfig from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelSummary @@ -31,7 +32,11 @@ PipelineMixedPrecisionPlugin, ) from nemo.utils import logging -from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO +from nemo.utils.callbacks.dist_ckpt_io import ( + AsyncFinalizableCheckpointIO, + AsyncFinalizerCallback, + DistributedCheckpointIO, +) class MegatronTrainerBuilder: @@ -51,7 +56,10 @@ def _training_strategy(self) -> Union[NLPDDPStrategy, NLPFSDPStrategy]: _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive) if _IS_INTERACTIVE and self.cfg.trainer.devices == 1: logging.info("Detected interactive environment, using NLPDDPStrategyNotebook") - return NLPDDPStrategyNotebook(no_ddp_communication_hook=True, find_unused_parameters=False,) + return NLPDDPStrategyNotebook( + no_ddp_communication_hook=True, + find_unused_parameters=False, + ) if self.cfg.model.get('fsdp', False): assert ( @@ -89,7 +97,7 @@ def _grad_scaler(self) -> GradScaler: Returns a scaler for precision plugins. """ return GradScaler( - init_scale=self.cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=self.cfg.model.get('native_amp_init_scale', 2**32), growth_interval=self.cfg.model.get('native_amp_growth_interval', 1000), hysteresis=self.cfg.model.get('hysteresis', 2), ) @@ -137,19 +145,41 @@ def _plugins(self) -> list: use_dist_ckpt = not self.cfg.model.get('fsdp', False) and ( self.cfg.model.get('mcore_gpt', False) or self.cfg.model.get('mcore_bert', False) ) + async_save = self.cfg.exp_manager.checkpoint_callback_params.get('async_save', False) if use_dist_ckpt: - plugins.append(DistributedCheckpointIO.from_config(self.cfg.model)) + checkpoint_io = DistributedCheckpointIO.from_config(self.cfg.model, async_save) + if async_save: + checkpoint_io = AsyncFinalizableCheckpointIO(checkpoint_io) + plugins.append(checkpoint_io) + elif async_save: + raise MisconfigurationException( + 'exp_manager.checkpoint_callback_params.async_save=True without' + 'distributed checkpoints is currently not supported' + ) return plugins + def _callbacks(self, callbacks: Optional[list]) -> list: + """ + Returns: + callbacks: list of callbacks passed to Trainer.callbacks. + """ + if callbacks is None: + callbacks = [] + # enable_progress_bar is True by default. 
If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks + if 'enable_progress_bar' not in self.cfg.trainer or self.cfg.trainer.enable_progress_bar: + callbacks.append(CustomProgressBar()) + + if self.cfg.exp_manager.checkpoint_callback_params.get('async_save', False): + callbacks.append(AsyncFinalizerCallback()) + return callbacks + def create_trainer(self, callbacks=None) -> Trainer: # cfg.trainer.precision becomes None in Trainer if precision_plugins exist since both precision plugins and precision precision = self.cfg.trainer.precision strategy = self._training_strategy() plugins = self._plugins() - # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks - if 'enable_progress_bar' not in self.cfg.trainer or self.cfg.trainer.enable_progress_bar: - callbacks = [CustomProgressBar()] + callbacks = self._callbacks(callbacks) trainer = Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=callbacks) # Restore the precision value after Trainer is built. self.cfg.trainer.precision = precision @@ -161,7 +191,7 @@ class MegatronBertTrainerBuilder(MegatronTrainerBuilder): def _grad_scaler(self) -> GradScaler: return GradScaler( - init_scale=self.cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=self.cfg.model.get('native_amp_init_scale', 2**32), growth_interval=self.cfg.model.get('native_amp_growth_interval', 1000), ) @@ -169,13 +199,15 @@ def _grad_scaler(self) -> GradScaler: class MegatronT5TrainerBuilder(MegatronTrainerBuilder): """Builder for T5 model Trainer with overrides.""" - def create_trainer(self) -> Trainer: + def _callbacks(self, callbacks: Optional[list]) -> list: + callbacks = super()._callbacks(callbacks) + callbacks.append(ModelSummary(max_depth=3)) + return callbacks + + def create_trainer(self, callbacks=None) -> Trainer: strategy = self._training_strategy() plugins = self._plugins() - callbacks = [ModelSummary(max_depth=3)] - # enable_progress_bar is True by default. 
If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks - if 'enable_progress_bar' not in self.cfg.trainer or self.cfg.trainer.enable_progress_bar: - callbacks.append(CustomProgressBar()) + callbacks = self._callbacks(callbacks) return Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=callbacks) @@ -207,7 +239,7 @@ class MegatronLMPPTrainerBuilder(MegatronTrainerBuilder): def _grad_scaler(self) -> GradScaler: return GradScaler( - init_scale=self.cfg.model.get("native_amp_init_scale", 2 ** 32), + init_scale=self.cfg.model.get("native_amp_init_scale", 2**32), growth_interval=self.cfg.model.get("native_amp_growth_interval", 1000), hysteresis=self.cfg.model.get("hysteresis", 2), enabled=False if self.cfg.model.pipeline_model_parallel_size > 1 else True, diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 1c68ebff81210..65ffb7df47f46 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -35,6 +35,7 @@ from pytorch_lightning.loops.fetchers import _DataFetcher from pytorch_lightning.plugins import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.plugins.precision.fsdp import FSDPPrecision from pytorch_lightning.strategies import DDPStrategy, FSDPStrategy @@ -120,7 +121,7 @@ def init_model_parallel( sharp: bool, nccl_communicator_config_path: str = None, distributed_timeout_minutes: int = 30 ) -> None: - """ Initializes Megatron-LM model parallel if using model parallelism. + """Initializes Megatron-LM model parallel if using model parallelism. Args: sharp: Apply SHARP to NCCL data-parallel communication. @@ -164,7 +165,7 @@ def init_model_parallel( class NLPDDPStrategy(DDPStrategy): - """ DDP plugin for Pytorch Lightning. Needed to customize DDP for model parallel models. + """DDP plugin for Pytorch Lightning. Needed to customize DDP for model parallel models. Args: no_ddp_communication_hook: Disable DDP communication hook when using AMP-O2 @@ -231,8 +232,8 @@ def setup_distributed(self, global_rank: int = None, world_size: int = None) -> ) def configure_ddp(self): - """ Override LightningModule ddp if using model parallel. - Sets find_unused_parameters to False to use activation-checkpoint-recomputation. + """Override LightningModule ddp if using model parallel. + Sets find_unused_parameters to False to use activation-checkpoint-recomputation. 
""" if (hasattr(self.model, 'megatron_amp_O2') and self.model.megatron_amp_O2) or ( @@ -362,9 +363,6 @@ def save_checkpoint( unsharded_optim_state=checkpoint['optimizer_states'][0] ) checkpoint['optimizer_states'] = [sharded_optim_state] - # dist_checkpointing expects a directory so we will name the directory - # using the path with the file extension removed - checkpoint_dir = ckpt_to_dir(filepath) # remove device state_dict checkpoint['state_dict'] = OrderedDict([]) @@ -406,7 +404,7 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr self.lightning_module.load_state_dict(checkpoint["state_dict"], strict=strict) def _fix_tensors_device(self, ckpt: Dict) -> Dict: - """ Ensure checkpoint tensors are on the correct device.""" + """Ensure checkpoint tensors are on the correct device.""" assert torch.cuda.is_initialized(), (torch.cuda.is_available(), torch.cuda.is_initialized()) cur_dev = torch.device("cuda", index=torch.cuda.current_device()) @@ -418,10 +416,10 @@ def _fix_device(t): return dict_list_map_outplace(_fix_device, ckpt) def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: - """ PTL method which we override to integrate distributed checkpoints for model parallel models. - In order to load distributed checkpoints we need to provide the sharded_state_dict to - the distributed load function. We get the sharded_state_dict from self.lightning_module - which makes it convenient to have the loading logic happen at the strategy level. + """PTL method which we override to integrate distributed checkpoints for model parallel models. + In order to load distributed checkpoints we need to provide the sharded_state_dict to + the distributed load function. We get the sharded_state_dict from self.lightning_module + which makes it convenient to have the loading logic happen at the strategy level. """ fs = get_filesystem(checkpoint_path) @@ -466,7 +464,10 @@ def remove_checkpoint(self, filepath: Union[str, Path]) -> None: @property def use_distributed_checkpointing(self): - has_dist_ckpt_io = HAVE_MEGATRON_CORE and isinstance(self.checkpoint_io, DistributedCheckpointIO) + checkpoint_io = self.checkpoint_io + while isinstance(checkpoint_io, _WrappingCheckpointIO): + checkpoint_io = checkpoint_io.checkpoint_io + has_dist_ckpt_io = HAVE_MEGATRON_CORE and isinstance(checkpoint_io, DistributedCheckpointIO) has_sharded_state_dict = ( hasattr(self.lightning_module, 'sharded_state_dict') and self.lightning_module.sharded_state_dict() is not None @@ -500,15 +501,15 @@ def distributed_sampler_kwargs(self): @property def restore_checkpoint_after_setup(self) -> bool: - """ This needs to be True for distributed checkpointing because - we require the model to have configured the optimizer before - deserializing the checkpoint. + """This needs to be True for distributed checkpointing because + we require the model to have configured the optimizer before + deserializing the checkpoint. """ return True class NLPDDPStrategyNotebook(NLPDDPStrategy): - """ Version of NLPDDPStrategy to be used in a Jupyter Notebook + """Version of NLPDDPStrategy to be used in a Jupyter Notebook A large portion of Megatron code has DDP dependency, so it has been necessary to use NLPDDPStrategy even for single-GPU training (e.g. in a Jupyter notebook) A PTL 2.0 changes has prevented DDPStrategy to be used in a notebook. 
@@ -546,7 +547,7 @@ def _get_full_state_dict_context(module: torch.nn.Module, rank0_only: bool = Fal class NLPFSDPStrategy(FSDPStrategy): - """ FSDP plugin for Pytorch Lightning with the support for tensor-parallelism. + """FSDP plugin for Pytorch Lightning with the support for tensor-parallelism. Args: sharding_strategy: FSDP parameter sharding strategy. @@ -639,7 +640,11 @@ def _set_mixed_precision_recipe( reduce_dtype = utils_funcs.torch_dtype_from_precision(grad_reduce_dtype, None) if set_buffer_dtype is not None: buffer_dtype = utils_funcs.torch_dtype_from_precision(buffer_dtype, None) - return MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype,) + return MixedPrecision( + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + buffer_dtype=buffer_dtype, + ) def setup_environment(self) -> None: """ @@ -750,7 +755,9 @@ def _get_osd(opt_state): with FSDP.summon_full_params(self.model, writeback=True, rank0_only=False): # rekey the osd stored from non-FSDP model rekeyed_osd = FSDP.rekey_optim_state_dict( - temp_osd, OptimStateKeyType.PARAM_NAME, self.model, + temp_osd, + OptimStateKeyType.PARAM_NAME, + self.model, ) temp_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, self.model) except Exception as e: @@ -758,7 +765,9 @@ def _get_osd(opt_state): exit(1) # Shard optimizer state dict sharded_osd = FSDP.optim_state_dict_to_load( - optim_state_dict=temp_osd, model=self.model, optim=optimizer, + optim_state_dict=temp_osd, + model=self.model, + optim=optimizer, ) optimizer.load_state_dict(sharded_osd) @@ -767,9 +776,9 @@ def _get_osd(opt_state): def save_checkpoint( self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None ) -> None: - """ Store checkpoints - 1. In case of sharded checkpoint, all ranks store unique checkpoints. - 2. In case of non-sharded checkpoint, all data-parallel rank 0 store checkpoints. + """Store checkpoints + 1. In case of sharded checkpoint, all ranks store unique checkpoints. + 2. In case of non-sharded checkpoint, all data-parallel rank 0 store checkpoints. """ app_state = AppState() filepath = inject_model_parallel_rank(filepath, fsdp_sharded_ckpt=self.sharded_checkpoint) @@ -780,8 +789,7 @@ def save_checkpoint( self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options) def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: - """ Load checkpoints - """ + """Load checkpoints""" # 1. Load normal or FSDP-sharded checkpoints. fs = get_filesystem(checkpoint_path) @@ -798,8 +806,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: return checkpoint def remove_checkpoint(self, filepath: Union[str, Path]) -> None: - """ Remove checkpoints - """ + """Remove checkpoints""" # legacy checkpoint logic, does not use megatron core app_state = AppState() # PTL override to accomodate model parallel checkpoints @@ -814,9 +821,9 @@ def remove_checkpoint(self, filepath: Union[str, Path]) -> None: @property def restore_checkpoint_after_setup(self) -> bool: - """ When loading FSDP-sharded checkpoint, need to restore checkpoint after configuring - FSDP sharding to match FSDP-sharded format between the checkpoint and the current - model and optimizer. + """When loading FSDP-sharded checkpoint, need to restore checkpoint after configuring + FSDP sharding to match FSDP-sharded format between the checkpoint and the current + model and optimizer. 
""" return True @@ -915,7 +922,8 @@ def dummy(): else: # move weights to the tmpdir for tp_rank, pp_rank in itertools.product( - range(app_state.tensor_model_parallel_size), range(app_state.pipeline_model_parallel_size), + range(app_state.tensor_model_parallel_size), + range(app_state.pipeline_model_parallel_size), ): os.makedirs(os.path.join(tmpdir, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}')) mp_model_weights = os.path.join( @@ -1000,6 +1008,7 @@ def modify_state_dict(self, conf, state_dict): loaded_keys = state_dict.keys() if 'model.model.diffusion_model.input_blocks.1.0.in_layers.2.weight' in loaded_keys: new_state_dict = {} + # GroupNormOpt fuses activation function to one layer, thus the indexing of weights are shifted for following def should_process(key): base_str = "model.model.diffusion_model." @@ -1110,7 +1119,13 @@ def restore_from( # Get path where the command is executed - the artifacts will be "retrieved" there # (original .nemo behavior) loaded_params = super().load_config_and_state_dict( - calling_cls, restore_path, override_config_path, map_location, strict, return_config, trainer, + calling_cls, + restore_path, + override_config_path, + map_location, + strict, + return_config, + trainer, ) if not isinstance(loaded_params, tuple) or return_config is True: return loaded_params @@ -1165,12 +1180,12 @@ def dummy(): class PipelineMixedPrecisionPlugin(MixedPrecisionPlugin): - """ Overrides PTL autocasting to not wrap training/val/test_step. - We do this because we have the megatron-core fwd/bwd functions in training_step. - This means .backward is being called in training_step so we do not want the whole - step wrapped in autocast. + """Overrides PTL autocasting to not wrap training/val/test_step. + We do this because we have the megatron-core fwd/bwd functions in training_step. + This means .backward is being called in training_step so we do not want the whole + step wrapped in autocast. - We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions. + We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions. """ def __init__( @@ -1206,12 +1221,12 @@ def forward_context(self) -> Generator[None, None, None]: class FSDPMixedPrecisionPlugin(FSDPPrecision): - """ Overrides PTL autocasting to not wrap training/val/test_step. - We do this because we have the megatron-core fwd/bwd functions in training_step. - This means .backward is being called in training_step so we do not want the whole - step wrapped in autocast. + """Overrides PTL autocasting to not wrap training/val/test_step. + We do this because we have the megatron-core fwd/bwd functions in training_step. + This means .backward is being called in training_step so we do not want the whole + step wrapped in autocast. - We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions. + We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions. """ def __init__( @@ -1246,7 +1261,7 @@ class GradScaler(torch.cuda.amp.GradScaler): def __init__( self, - init_scale=2.0 ** 16, + init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, @@ -1500,7 +1515,7 @@ def optimizer_step( @contextmanager def forward_context(self) -> Generator[None, None, None]: - """ No explicit precision casting. Inputs are supposed to be manually casted """ + """No explicit precision casting. 
Inputs are supposed to be manually casted""" try: yield finally: @@ -1508,7 +1523,7 @@ def forward_context(self) -> Generator[None, None, None]: class GlobalBatchDataFetcher(_DataFetcher): - """ Overrides PTL DataFetcher. Used to fetch global batches.""" + """Overrides PTL DataFetcher. Used to fetch global batches.""" def __init__(self, prefetch_batches: int = 0, store_on_device: bool = False) -> None: diff --git a/nemo/utils/callbacks/checkpointing_context.py b/nemo/utils/callbacks/checkpointing_context.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py index 2e695dd7bbaa0..905de4eb35670 100644 --- a/nemo/utils/callbacks/dist_ckpt_io.py +++ b/nemo/utils/callbacks/dist_ckpt_io.py @@ -1,41 +1,217 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import shutil +from abc import ABC, abstractmethod +from contextlib import contextmanager +from time import time from typing import Any, Dict, Optional +import pytorch_lightning as pl from lightning_fabric.plugins import CheckpointIO from lightning_fabric.utilities.cloud_io import get_filesystem from lightning_fabric.utilities.types import _PATH -from megatron.core import dist_checkpointing -from megatron.core.dist_checkpointing.strategies import tensorstore +from pytorch_lightning import Callback +from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO from nemo.utils import logging +try: + from megatron.core import dist_checkpointing + from megatron.core.dist_checkpointing.strategies import tensorstore + + from nemo.utils.callbacks.torch_dist_async import AsyncCallsQueue, AsyncRequest, TorchDistAsyncSaveShardedStrategy + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError) as IMPORT_ERROR_EXC: + + HAVE_MEGATRON_CORE = False + IMPORT_ERROR = "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + + +@contextmanager +def _debug_time(name: str): + """Simple context manager for timing functions/code blocks.""" + start = time() + try: + yield + finally: + logging.debug(f'{name} took {time() - start:.3f}s') + + +class AsyncCompatibleCheckpointIO(CheckpointIO, ABC): + """CheckpointIO that can be used together with async saving. + + Differs from the regular CheckpointIO only by the `save_checkpoint` + return type. The `save_checkpoint` method itself is synchronous, but returns + callbacks that can be performed asynchronously. + """ + + @abstractmethod + def save_checkpoint( + self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None + ) -> 'AsyncRequest': + raise NotImplementedError + -class DistributedCheckpointIO(CheckpointIO): - """ CheckpointIO for a distributed checkpoint format. +class AsyncFinalizableCheckpointIO(_WrappingCheckpointIO): + """CheckpointIO wrapper for async checkpoint saving and synchronous finalization. 
+ + Runs main part of the checkpoint save in a separate process (not thread as the PTL + AsyncCheckpointIO does). Allows to perform a (synchronous) finalization + function after all ranks finish checkpoint saving. + + NOTE: for correctness, this plugin must be used together with the + AsyncFinalizerCallback callback which performs the finalization checks. + + Args: + checkpoint_io (CheckpointIO): wrapped checkpoint_io object. Must be + of type AsyncCompatibleCheckpointIO. + Requires the underlying checkpoint_io.save_checkpoint to return save_fn, save_args, finalize_fn. + """ + + def __init__(self, checkpoint_io: AsyncCompatibleCheckpointIO) -> None: + if not HAVE_MEGATRON_CORE: + raise ImportError(IMPORT_ERROR) from IMPORT_ERROR_EXC + if not isinstance(checkpoint_io, AsyncCompatibleCheckpointIO): + raise ValueError(f'Incompatible wrapped checkpoint_io type: {type(checkpoint_io)}') + + super().__init__(checkpoint_io) + self.async_calls_queue = AsyncCallsQueue() + + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + """Executes async request returned from the underlying checkpoint_io asynchronously. + + Requires the underlying checkpoint_io.save_checkpoint to return an AsyncRequest. + It is then applied with `self.async_calls_queue` asynchronously. + + Args: + checkpoint (Dict[str, Any]): checkpoint to save. Passed to underlying + checkpoint_io without modifications. + path (_PATH): path to save the checkpoint. Passed to underlying + checkpoint_io without modifications. + storage_options (Any, optional): storage control modifiers. This class + consumed the `finalize_fn` parameter (if any), which is expected to be + a callback and is appended to async finalization functions. + + Applies underlying checkpoint_io finalize callback first, then the external one (postfix order). + """ + external_finalize_fn = (storage_options or {}).pop('finalize_fn', None) + assert isinstance(self.checkpoint_io, AsyncCompatibleCheckpointIO), type(self.checkpoint_io) + async_request = self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options) + if external_finalize_fn is not None: + async_request.add_finalize_fn(external_finalize_fn) + call_idx = self.async_calls_queue.schedule_async_request(async_request) + logging.debug(f'Scheduled an async call #{call_idx}') + + @_debug_time('AsyncFinalizableCheckpointIO.maybe_finalize_save_checkpoint') + def maybe_finalize_save_checkpoint(self, blocking: bool = False): + """Performs checkpoint finalization (if possible). + + Args: + blocking (bool, optional): if True, waits until all async saves are + completed. Otherwise, finalizes only those async calls which are + already done on all ranks. Defaults to False. + """ + call_idx_finalized = self.async_calls_queue.maybe_finalize_async_calls(blocking) + if call_idx_finalized: + logging.debug(f'Finalized async calls: {[f"#{idx}" for idx in call_idx_finalized]}') + return len(call_idx_finalized) > 0 + + def teardown(self) -> None: + """Warns if there are any pending checkpoint saves.""" + super().teardown() + if self.async_calls_queue.get_num_unfinalized_calls() > 0: + # Can't do finalization now because some ranks might be lost + logging.warning('Some async checkpoint saves might be not finalized properly.') + + +class AsyncFinalizerCallback(Callback): + """Callback which finalizes async saves initiated by the AsyncFinalizableCheckpointIO. + + Tries to perform non-blocking finalization on train_batch_end and train_epoch_end. 
+ On train_end performs a blocking finalization of all pending checkpoints. + """ + + def on_train_batch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None: + self._get_checkpoint_io(trainer).maybe_finalize_save_checkpoint(blocking=False) + + def on_train_epoch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None: + self._get_checkpoint_io(trainer).maybe_finalize_save_checkpoint(blocking=False) + + def on_train_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None: + checkpoint_io = self._get_checkpoint_io(trainer) + if checkpoint_io.async_calls_queue.get_num_unfinalized_calls() > 0: + logging.info('Pending async checkpoint saves. Finalizing them synchronously now') + self._get_checkpoint_io(trainer).maybe_finalize_save_checkpoint(blocking=True) + + def _get_checkpoint_io(self, trainer) -> AsyncFinalizableCheckpointIO: + checkpoint_io = trainer.strategy.checkpoint_io + if not isinstance(checkpoint_io, AsyncFinalizableCheckpointIO): + raise ValueError(f'Async finalizer requires an async compatible CheckpointIO, got: {checkpoint_io}') + return checkpoint_io + + +class DistributedCheckpointIO(AsyncCompatibleCheckpointIO): + """CheckpointIO for a distributed checkpoint format. Args: save_ckpt_format (str): Distributed checkpoint format to use for checkpoint saving. load_directly_on_device (bool, optional): if True, loads the weights directly on GPU. Has effect only for `zarr` based checkpoints (PyT Distributed always loads on device). Defaults to True. + async_save (bool): whether to save asynchronously. Should be set to True if + this class will be wrapped with AsyncFinalizableCheckpointIO. """ - def __init__(self, save_ckpt_format: str, load_directly_on_device: bool = True): + def __init__( + self, + save_ckpt_format: str, + load_directly_on_device: bool = True, + async_save: bool = False, + ): super().__init__() + if not HAVE_MEGATRON_CORE: + raise ImportError(IMPORT_ERROR) from IMPORT_ERROR_EXC + self.save_ckpt_format = save_ckpt_format self.load_directly_on_device = load_directly_on_device - - self.save_sharded_strategy = self.determine_dist_ckpt_save_strategy() + self.async_save = async_save + self.save_sharded_strategy = self._determine_dist_ckpt_save_strategy() @classmethod - def from_config(cls, model_cfg): + def from_config(cls, model_cfg: dict, async_save: bool = False): + """Instantiates a DistributedCheckpointIO from a config dict. + + Args: + model_cfg (dict): model config dict. Most of the configuration + is extracted from this config. + async_save (bool, optional): async_save flag is not part of the model config, + it should be provided separately. Defaults to False. + """ return cls( save_ckpt_format=model_cfg.get('dist_ckpt_format', 'zarr'), load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True), + async_save=async_save, ) - def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: - """ Saves a distributed checkpoint. Creates the checkpoint root directory if doesn't exist. + @_debug_time('DistributedCheckpointIO.save_checkpoint') + def save_checkpoint( + self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None + ) -> Optional['AsyncRequest']: + """Saves a distributed checkpoint. Creates the checkpoint root directory if doesn't exist. 
Args: checkpoint (Dict[str, Any]): sharded state dict to save @@ -48,11 +224,19 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio dist_checkpointing.save( sharded_state_dict=checkpoint, checkpoint_dir=path, sharded_strategy=self.save_sharded_strategy ) + if not self.async_save: + return None + # NOTE: this logic will be simplified in MCore v0.7 + assert self.save_sharded_strategy.async_request is not None + async_request = self.save_sharded_strategy.async_request + self.save_sharded_strategy.async_request = None + return async_request + @_debug_time('DistributedCheckpointIO.load_checkpoint') def load_checkpoint( self, path: _PATH, map_location: Optional[Any] = None, sharded_state_dict: Dict[str, Any] = None ) -> Dict[str, Any]: - """ Loads a distributed checkpoint. + """Loads a distributed checkpoint. Args: path (_PATH): checkpoint directory @@ -79,18 +263,25 @@ def load_checkpoint( sharded_state_dict=sharded_state_dict, checkpoint_dir=path, sharded_strategy=sharded_strategy ) + @_debug_time('DistributedCheckpointIO.remove_checkpoint') def remove_checkpoint(self, path: _PATH) -> None: - """ Remove a distributed checkpoint. + """Remove a distributed checkpoint. Due to potentially large number of files, the implementation remove the whole directory at once. """ shutil.rmtree(path, ignore_errors=True) - def determine_dist_ckpt_save_strategy(self): - """ Determine the saving strategy based on storage config. + def _determine_dist_ckpt_save_strategy(self): + """Determine the saving strategy based on constructor args. - For now only decides the checkpoint format. + If self.async_save is True instantiates an async PyT Dist strategy, + otherwise relies on MCore to create a proper strategy based on ckpt format. """ save_strategy = (self.save_ckpt_format, 1) + if self.async_save: + if save_strategy[0] != 'torch_dist': + raise ValueError('Async dist-ckpt save supported only for torch_dist format') + save_strategy = TorchDistAsyncSaveShardedStrategy('torch_dist', 1) + logging.info(f'Using {save_strategy} dist-ckpt save strategy.') return save_strategy diff --git a/nemo/utils/callbacks/nemo_model_checkpoint.py b/nemo/utils/callbacks/nemo_model_checkpoint.py index f8bdb9d9b2941..15e8a4e21f55b 100644 --- a/nemo/utils/callbacks/nemo_model_checkpoint.py +++ b/nemo/utils/callbacks/nemo_model_checkpoint.py @@ -21,19 +21,21 @@ import pytorch_lightning import torch +from _weakref import proxy from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint, _is_local_file_protocol from pytorch_lightning.utilities import rank_zero_info from nemo.collections.common.callbacks import EMA from nemo.utils import logging from nemo.utils.app_state import AppState +from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO from nemo.utils.get_rank import is_global_rank_zero from nemo.utils.model_utils import ckpt_to_dir, inject_model_parallel_rank, uninject_model_parallel_rank class NeMoModelCheckpoint(ModelCheckpoint): - """ Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end. - Extends Lightning's on_save_checkpoint func to save the .nemo file. Saves the .nemo file based + """Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end. + Extends Lightning's on_save_checkpoint func to save the .nemo file. Saves the .nemo file based on the best checkpoint saved (according to the monitor value). Also contains func to save the EMA copy of the model. 
""" @@ -48,6 +50,7 @@ def __init__( postfix: str = ".nemo", n_resume: bool = False, model_parallel_size: int = None, + async_save: bool = False, # controls only finalize callbacks **kwargs, ): # Parse and store "extended" parameters: save_best model and postfix. @@ -64,6 +67,13 @@ def __init__( self.postfix = postfix self.previous_best_path = "" self.model_parallel_size = model_parallel_size + self.async_save = async_save + self.async_finalize_cb = None + # Checkpoints which removal is deferred until async save is done. + # Each element of `deferred_ckpts_to_remove` is a growing list + # that `self._remove_checkpoint` adds to. Once `self._save_checkpoint` + # is called, the last element is frozen and a new element is added. + self.deferred_ckpts_to_remove: List[List[str]] = [] # `prefix` is deprecated if 'prefix' in kwargs: @@ -262,7 +272,7 @@ def on_train_end(self, trainer, pl_module): pl_module.save_to(save_path=self._format_nemo_checkpoint_name()) def _backup_existing_nemo_ckpt(self, trainer) -> str: - """ Search for an available name with version infix and rename existing checkpoint. + """Search for an available name with version infix and rename existing checkpoint. NOTE: this behavior is slightly different from regular checkpoints. PTL creates new regular checkpoint with the first available name. @@ -330,15 +340,15 @@ def _ema_callback(self, trainer: 'pytorch_lightning.Trainer') -> Optional[EMA]: @staticmethod def format_checkpoint_unfinished_marker_path(checkpoint_path: Union[Path, str]) -> Path: - """ Format the path to the unfinished checkpoint marker file. - + """Format the path to the unfinished checkpoint marker file. + If the marker file exists, corresponding checkpoint is considered unfinished/incomplete. NOTE: Marker path for the EMA checkpoint part is the same as for the original checkpoint. - + Args: checkpoint_path: Path to the checkpoint file or dir. Does not need to exist. - + Returns: Path to the unfinished checkpoint marker file. """ @@ -350,7 +360,7 @@ def format_checkpoint_unfinished_marker_path(checkpoint_path: Union[Path, str]) @staticmethod def is_checkpoint_unfinished(checkpoint_path: Union[Path, str]) -> bool: - """ Check if the checkpoint is unfinished. + """Check if the checkpoint is unfinished. Args: checkpoint_path: Path to the checkpoint file or dir. @@ -363,7 +373,7 @@ def is_checkpoint_unfinished(checkpoint_path: Union[Path, str]) -> bool: @staticmethod def set_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barrier_after=False) -> None: - """ Marks given checkpoint as unfinished. + """Marks given checkpoint as unfinished. Args: checkpoint_filepath: Path to the checkpoint file or dir. 
@@ -409,6 +419,8 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) self.set_checkpoint_unfinished_marker(filepath, barrier_after=True) ema_callback = self._ema_callback(trainer) if ema_callback is not None: + if self.async_save: + raise ValueError('async_save with EMA not supported') with ema_callback.save_original_optimizer_state(trainer): super()._save_checkpoint(trainer, filepath) @@ -418,13 +430,71 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) if self.verbose: rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}") super()._save_checkpoint(trainer, filepath) + self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) else: - super()._save_checkpoint(trainer, filepath) - # barrier_before=True, so all ranks synchronize before removing the unfinished checkpoint marker - # we don't want to remove the marker until all checkpointing is done. - self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) + # Async save passed the finalization function to checkpoint_io, + # sync save calls the finalization function immediately after save. + finalize_fn = self._get_finalize_save_checkpoint_callback(trainer, filepath, trainer.global_step) + if self.async_save: + checkpoint_io = trainer.strategy.checkpoint_io + if not isinstance(checkpoint_io, AsyncFinalizableCheckpointIO): + raise ValueError('Async save requires async compatible CheckpointIO') + storage_options = dict(finalize_fn=finalize_fn) + # Each upcoming ckpt removal request will be executed as part of this save finalization + self.deferred_ckpts_to_remove.append([]) + else: + storage_options = None + trainer.save_checkpoint(filepath, self.save_weights_only, storage_options=storage_options) + if self.async_save: + logging.info(f'Scheduled async checkpoint save for {filepath}') + else: + finalize_fn() + + def _get_finalize_save_checkpoint_callback( + self, trainer: 'pytorch_lightning.Trainer', filepath: str, global_step: int + ): + """Creates a callback that can be used to finalize async (and sync) ckpt saves.""" - def _remove_checkpoint(self, trainer: "pytorch_lightning.Trainer", filepath: str) -> None: + def _cb(): + logging.debug(f'Finalize callback called for step {global_step}, filepath {filepath}') + self._last_global_step_saved = global_step + self._last_checkpoint_saved = filepath + + # notify loggers + if trainer.is_global_zero: + for logger in trainer.loggers: + logger.after_save_checkpoint(proxy(self)) + + # barrier_before=True, so all ranks synchronize before removing the unfinished checkpoint marker + # we don't want to remove the marker until all checkpointing is done. + self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) + + if not self.async_save: + return + + logging.info(f'Async checkpoint save for step {global_step} ({filepath}) finalized successfully.') + + # Remove checkpoints marked for removal by `self._remove_checkpoint` + # For each finalization there is exactly one entry in self.deferred_ckpts_to_remove + assert self.deferred_ckpts_to_remove + ckpts_to_remove = self.deferred_ckpts_to_remove.pop(0) + logging.debug(f'Checkpoints to remove: {ckpts_to_remove}') + for ckpt_to_remove in ckpts_to_remove: + self._remove_checkpoint(trainer, ckpt_to_remove, override_async=True) + + return _cb + + def _remove_checkpoint(self, trainer: "pytorch_lightning.Trainer", filepath: str, override_async=False) -> None: + """Performs checkpoint removal or deferred removal. 
+ + With async save, `self._remove_checkpoint` is called before the checkpoint + is actually finished so we can't remove it. Instead we add it to + `self.deferred_ckpts_to_remove` for future removal. + """ + if self.async_save and not override_async: + # Register checkpoint removal in the last (active) checkpoint removal list + self.deferred_ckpts_to_remove[-1].append(filepath) + return # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed. # if anything goes wrong during removal, we should be able to detect that data is incomplete. self.set_checkpoint_unfinished_marker(filepath, barrier_after=True) @@ -499,7 +569,7 @@ def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, curren A checkpoint won't be deleted if any of the cases apply: - The previous checkpoint is the same as the current checkpoint (means the old was already overwritten by new) - The previous checkpoint is not in the current checkpoint directory and the filesystem is local - - The previous checkpoint is the checkpoint the Trainer resumed from and the filesystem is local + - The previous checkpoint is the checkpoint the Trainer resumed from and the filesystem is local and the resumed from checkpoint is not the last checkpoint """ if previous == current: diff --git a/nemo/utils/callbacks/torch_dist_async.py b/nemo/utils/callbacks/torch_dist_async.py new file mode 100644 index 0000000000000..1cd226af9cdbe --- /dev/null +++ b/nemo/utils/callbacks/torch_dist_async.py @@ -0,0 +1,298 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import deque +from pathlib import Path +from time import time +from typing import Callable, List, NamedTuple, Optional, Tuple + +import torch +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.dist_checkpointing.strategies.filesystem_async import FileSystemWriterAsync +from megatron.core.dist_checkpointing.strategies.state_dict_saver import ( + save_state_dict_async_finalize, + save_state_dict_async_plan, +) +from megatron.core.dist_checkpointing.strategies.torch import ( + MCoreSavePlanner, + TorchDistSaveShardedStrategy, + _replace_state_dict_keys_with_sharded_keys, + mcore_to_pyt_state_dict, +) +from torch import multiprocessing as mp + +from nemo.utils import logging + + +class TorchDistAsyncSaveShardedStrategy(TorchDistSaveShardedStrategy): + """Async save strategy for the PyT Distributed format. + + NOTE: this class will be removed and replaced with an MCore version + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.async_request = None + + def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + """Translates MCore ShardedTensors to PyT ShardedTensors and saves in PyT Distributed format. 
+ + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to save + checkpoint_dir (Path): checkpoint directory + + Returns: None + """ + # Translate the state dict + ( + sharded_state_dict, + flat_mapping, + rename_mapping, + ) = _replace_state_dict_keys_with_sharded_keys(sharded_state_dict, self.keep_only_main_replica) + pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, False) + # Use PyT saving mechanism + writer = FileSystemWriterAsync(checkpoint_dir, thread_count=self.thread_count) + + save_state_dict_ret = save_state_dict_async_plan( + pyt_state_dict, + writer, + None, + planner=MCoreSavePlanner(), + ) + self.async_request = self._get_save_and_finalize_callbacks(writer, save_state_dict_ret) + return self.async_request + + def _get_save_and_finalize_callbacks(self, writer, save_state_dict_ret): + save_fn_args = writer.get_save_function_and_args() + if save_fn_args is None: # this check can be removed with MCore v0.7 + save_fn_args = None, () + save_fn, save_args = save_fn_args + + def finalize_fn(): + save_state_dict_async_finalize(*save_state_dict_ret) + torch.distributed.barrier() + + return AsyncRequest(save_fn, save_args, [finalize_fn]) + + +class AsyncRequest(NamedTuple): + """Represents an async request that needs to be scheduled for execution. + + NOTE: this class will be removed and replaced with an MCore version + + Args: + async_fn (Callable, optional): async function to call. None represents noop. + async_fn_args (Tuple): args to pass to `async_fn`. + finalize_fns (List[Callable]): list of functions to call to finalize the request. + These functions will be called synchronously after `async_fn` is done + *on all ranks*. + """ + + async_fn: Optional[Callable] + async_fn_args: Tuple + finalize_fns: List[Callable] + is_frozen: bool = False + + def add_finalize_fn(self, fn: Callable) -> None: + """Adds a new finalize function to the request. + + Args: + fn (Callable): function to add to the async request. This function + will be called *after* existing finalization functions. + + Returns: + None + """ + if self.is_frozen: + raise RuntimeError('Cannot add finalization functions to a frozen AsyncRequest') + self.finalize_fns.append(fn) + + def execute_sync(self) -> None: + """Helper to synchronously execute the request. + + This logic is equivalent to what should happen in case of the async call. + """ + if self.async_fn is not None: + self.async_fn(*self.async_fn_args) + torch.distributed.barrier() + for finalize_fn in self.finalize_fns: + finalize_fn() + + def freeze(self) -> 'AsyncRequest': + """Freezes the async request, disallowing adding new finalization functions. + + Returns: + AsyncRequest: new async request with all same fields except for the + `is_frozen` flag. + """ + return self._replace(is_frozen=True) + + +class DistributedAsyncCaller: + """Wrapper around mp.Process that ensures correct semantic of distributed finalization. + + NOTE: this class will be removed and replaced with an MCore version + + Starts process asynchronously and allows checking if all processes on all ranks are done. + """ + + def __init__(self): + self.process: Optional[mp.Process] = None + self.start_time: Optional[float] = None + + def schedule_async_call( + self, + async_fn: Optional[Callable], + save_args: Tuple, + ) -> None: + """Spawn a process with `async_fn` as the target. + + This method must be called on all ranks. + + Args: + async_fn (Callable, optional): async function to call. If None, + no process will be started. + save_args (Tuple): async function args. 
+ """ + if async_fn is None: + return # nothing to do + torch.cuda.synchronize() + ctx = mp.get_context('fork') + self.start_time = time() + self.process = ctx.Process( + target=async_fn, + args=save_args, + ) + self.process.start() + + def is_current_async_call_done(self, blocking=False) -> bool: + """Check if async save is finished on all ranks. + + For semantic correctness, requires rank synchronization in each check. + This method must be called on all ranks. + + Args: + blocking (bool, optional): if True, will wait until the call is done + on all ranks. Otherwise, returns immediately if at least one rank + is still active. Defaults to False. + + Returns: + bool: True if all ranks are done (immediately of after active wait + if `blocking` is True), False if at least one rank is still active. + """ + # The following takes the same overhead as torch.distributed.barrier (single integer all-reduce) + is_alive = int(self.process.is_alive()) if self.process is not None else 0 + ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device()) + logging.debug(f"[rank {torch.distributed.get_rank()}] DistributedAsyncCaller is_alive:{is_alive}") + torch.distributed.all_reduce(ten) + if ten[0] > 0 and not blocking: + return False + else: + if self.process is not None: + logging.debug(f"rank: {torch.distributed.get_rank()}, joining self.process") + self.process.join() + self.process = None + + logging.debug( + f"DistributedAsyncCaller: Async process join finished after {time() - self.start_time:.2f}s from forking" + ) + self.start_time = None + return True + + +class _ActiveAsyncRequest(NamedTuple): + """Helper to represent an active async call. + + NOTE: this class will be removed and replaced with an MCore version + + Args: + idx (int): index of the call (starting from 0) + async_caller (DistributedAsyncCaller): async caller instance that represents + the async process handling the async request + async_request (AsyncRequest): async request that is being called + """ + + idx: int + async_caller: DistributedAsyncCaller + async_request: AsyncRequest + + +class AsyncCallsQueue: + """Manages a queue of async calls. + + NOTE: this class will be removed and replaced with an MCore version + + Allows adding a new async call with `schedule_async_request` and finalizing + active calls with `maybe_finalize_async_calls`. + """ + + def __init__(self): + self.async_calls: deque[_ActiveAsyncRequest] = deque([]) + self.call_idx: int = -1 + + def schedule_async_request(self, async_request: AsyncRequest) -> int: + """Start a new async call and add it to a queue of active async calls. + + This method must be called on all ranks. + + Args: + async_request (AsyncRequest): async request to start. + + Returns: + int: index of the async call that was started. + This can help the user keep track of the async calls. + """ + self.call_idx += 1 + async_caller = DistributedAsyncCaller() + async_request = async_request.freeze() + async_caller.schedule_async_call(async_request.async_fn, async_request.async_fn_args) + self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, async_request)) + return self.call_idx + + def maybe_finalize_async_calls(self, blocking=False) -> List[int]: + """Finalizes all available calls. + + This method must be called on all ranks. + + Args: + blocking (bool, optional): if True, will wait until all active requests + are done. Otherwise, finalizes only the async request that already + finished. Defaults to False. 
+ Returns: + List[int]: list of indices (as returned by `schedule_async_request`) + of async calls that have been successfully finalized. + """ + call_idx_finalized = [] + while self.async_calls: + next_async_done = self.async_calls[0].async_caller.is_current_async_call_done(blocking) + if not next_async_done: + break + call_idx, _, async_request = self.async_calls.popleft() + for finalize_fn in async_request.finalize_fns: + finalize_fn() + ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device()) + torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX) + assert ( + ten.item() == call_idx + ), 'Unmatched async calls. That probably means not all ranks are participating in async finalization' + call_idx_finalized.append(call_idx) + return call_idx_finalized + + def get_num_unfinalized_calls(self): + """Get the number of active async calls.""" + return len(self.async_calls) + + def close(self): + """Finalize all calls upon closing.""" + self.maybe_finalize_async_calls(blocking=True) diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 5c7cac5a9a556..9e8b55eade1fb 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -51,11 +51,11 @@ class NotFoundError(NeMoBaseException): - """ Raised when a file or folder is not found""" + """Raised when a file or folder is not found""" class LoggerMisconfigurationError(NeMoBaseException): - """ Raised when a mismatch between trainer.logger and exp_manager occurs""" + """Raised when a mismatch between trainer.logger and exp_manager occurs""" def __init__(self, message): message = ( @@ -66,7 +66,7 @@ def __init__(self, message): class CheckpointMisconfigurationError(NeMoBaseException): - """ Raised when a mismatch between trainer.callbacks and exp_manager occurs""" + """Raised when a mismatch between trainer.callbacks and exp_manager occurs""" @dataclass @@ -106,6 +106,7 @@ class CallbackParams: save_nemo_on_train_end: Optional[bool] = True # Whether to automatically save .nemo file durin on_train_end hook model_parallel_size: Optional[int] = None # tensor parallel size * pipeline parallel size save_on_train_epoch_end: Optional[bool] = False # Save after training, not after validation + async_save: Optional[bool] = False # save the checkpoint asynchronously @dataclass @@ -128,8 +129,7 @@ class EMAParams: @dataclass class ExpManagerConfig: - """Experiment Manager config for validation of passed arguments. - """ + """Experiment Manager config for validation of passed arguments.""" # Log dir creation parameters explicit_log_dir: Optional[str] = None @@ -313,7 +313,7 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo Set this to True if you are using DDP with many GPUs and do not want many log files in your exp dir. - log_global_rank_0_only (bool): Whether to only create log files for global rank 0. Defaults to False. Set this to True if you are using DDP with many GPUs and do not want many log files in your exp dir. - - max_time (str): The maximum wall clock time *per run*. This is intended to be used on clusters where you want + - max_time (str): The maximum wall clock time *per run*. This is intended to be used on clusters where you want a checkpoint to be saved after this specified time and be able to resume from that checkpoint. Defaults to None. - seconds_to_sleep (float): seconds to sleep non rank 0 processes for. 
Used to give enough time for rank 0 to initialize @@ -336,6 +336,10 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo # Ensure passed cfg is compliant with ExpManagerConfig schema = OmegaConf.structured(ExpManagerConfig) + # TODO: remove this check + if is_global_rank_zero(): + logging.info('ExpManager schema') + logging.info(schema) if isinstance(cfg, dict): cfg = OmegaConf.create(cfg) elif not isinstance(cfg, DictConfig): @@ -681,7 +685,7 @@ def check_resume( def check_explicit_log_dir( trainer: 'pytorch_lightning.Trainer', explicit_log_dir: Union[Path, str], exp_dir: str, name: str, version: str ) -> Tuple[Path, str, str, str]: - """ Checks that the passed arguments are compatible with explicit_log_dir. + """Checks that the passed arguments are compatible with explicit_log_dir. Returns: log_dir (Path): the log_dir @@ -918,7 +922,7 @@ def configure_checkpointing( params: 'DictConfig', create_preemption_callback: bool, ): - """ Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint + """Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint callback """ for callback in trainer.callbacks: @@ -995,7 +999,12 @@ def check_slurm(trainer): class StatelessTimer(Timer): """Extension of PTL timers to be per run.""" - def __init__(self, duration: timedelta = None, interval: str = Interval.step, verbose: bool = True,) -> None: + def __init__( + self, + duration: timedelta = None, + interval: str = Interval.step, + verbose: bool = True, + ) -> None: super().__init__(duration, interval, verbose) # Override PTL Timer's state dict to not store elapsed time information so that we can restore and continue training. diff --git a/tests/core/test_dist_ckpt.py b/tests/core/test_dist_ckpt.py index b6dc5ca89d3ec..8fe21a3168540 100644 --- a/tests/core/test_dist_ckpt.py +++ b/tests/core/test_dist_ckpt.py @@ -1,6 +1,7 @@ import os import types from pathlib import Path +from typing import Any, Dict import pytest import pytorch_lightning as pl @@ -9,7 +10,19 @@ from pytorch_lightning.demos.boring_classes import BoringModel from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy -from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO +from nemo.utils.callbacks.dist_ckpt_io import ( + AsyncFinalizableCheckpointIO, + AsyncFinalizerCallback, + DistributedCheckpointIO, +) + +try: + from megatron.core.dist_checkpointing import ShardedTensor + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + HAVE_MEGATRON_CORE = False class ExampleModel(BoringModel): @@ -19,7 +32,13 @@ def on_validation_epoch_end(self) -> None: class ExampleMCoreModel(ExampleModel): def sharded_state_dict(self): - return {'a': 3} + return { + 'a': ShardedTensor.from_rank_offsets('a', self.layer.weight, replica_id=torch.distributed.get_rank()), + 'const': 3, + } + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + checkpoint['sharded_state_dict'] = self.sharded_state_dict() class MockDistributedCheckpointIO(DistributedCheckpointIO): @@ -42,17 +61,22 @@ def save_checkpoint(self, *args, **kwargs) -> None: def _get_last_checkpoint_dir(root_dir: Path, model: pl.LightningModule, suffix: str = '') -> Path: steps = len(model.train_dataloader().dataset) * model.trainer.max_epochs // torch.distributed.get_world_size() - return root_dir / 'checkpoints' / f'epoch=1-step={steps}{suffix}' + return root_dir / 'checkpoints' / 
f'epoch={model.trainer.max_epochs - 1}-step={steps}{suffix}' + + +def _get_nlp_strategy_without_optimizer_state(): + strategy = NLPDDPStrategy() + # this ensures optimizer sharded state creation is skipped + strategy.optimizer_sharded_state_dict = types.MethodType( + lambda self, unsharded_optim_state: unsharded_optim_state, strategy + ) + return strategy class TestDistCkptIO: @pytest.mark.run_only_on('GPU') def test_dist_ckpt_io_called_for_mcore_models(self, tmp_path): - strategy = NLPDDPStrategy() - # skip optimizer sharded state creation: - strategy.optimizer_sharded_state_dict = types.MethodType( - lambda self, unsharded_optim_state: unsharded_optim_state, strategy - ) + strategy = _get_nlp_strategy_without_optimizer_state() checkpoint_io = MockDistributedCheckpointIO('xxx') test_trainer = pl.Trainer( @@ -70,7 +94,7 @@ def test_dist_ckpt_io_called_for_mcore_models(self, tmp_path): assert checkpoint_io.save_checkpoint_called_args is not None (state_dict, path), _ = checkpoint_io.save_checkpoint_called_args # Ckpt path doesn't contain the .ckpt suffix - assert path.name == _get_last_checkpoint_dir(tmp_path, model).name, len(test_trainer.strategy.parallel_devices) + assert path.name == _get_last_checkpoint_dir(tmp_path, model).name @pytest.mark.run_only_on('GPU') def test_dist_ckpt_path_not_executed_for_non_core_models(self, tmp_path): @@ -96,3 +120,60 @@ def test_dist_ckpt_path_not_executed_for_non_core_models(self, tmp_path): assert os.path.basename(path) == _get_last_checkpoint_dir(tmp_path, model, suffix='.ckpt').name else: assert checkpoint_io.save_checkpoint_called_args is None + + +class TestAsyncSave: + @pytest.mark.run_only_on('GPU') + def test_async_save_produces_same_checkpoints_as_sync(self, tmp_path): + strategy = _get_nlp_strategy_without_optimizer_state() + sync_checkpoint_io = DistributedCheckpointIO('torch_dist') + async_checkpoint_io = AsyncFinalizableCheckpointIO(DistributedCheckpointIO('torch_dist', async_save=True)) + + model = ExampleMCoreModel() + + # dummy_trainer just to initialize NCCL + dummy_trainer = pl.Trainer( + enable_checkpointing=False, + logger=False, + max_epochs=1, + strategy=_get_nlp_strategy_without_optimizer_state(), + plugins=[sync_checkpoint_io], + ) + dummy_trainer.fit(model) + tmp_path = strategy.broadcast(tmp_path) + + sync_ckpt_dir = tmp_path / 'sync_checkpoints' + async_ckpt_dir = tmp_path / 'async_checkpoints' + + sync_test_trainer = pl.Trainer( + enable_checkpointing=True, + logger=False, + max_epochs=1, + strategy=_get_nlp_strategy_without_optimizer_state(), + plugins=[sync_checkpoint_io], + default_root_dir=sync_ckpt_dir, + ) + sync_test_trainer.fit(model) + + async_test_trainer = pl.Trainer( + enable_checkpointing=True, + logger=False, + max_epochs=1, + strategy=_get_nlp_strategy_without_optimizer_state(), + plugins=[async_checkpoint_io], + callbacks=AsyncFinalizerCallback(), + default_root_dir=async_ckpt_dir, + ) + async_test_trainer.fit(model) + + # Load and compare checkpoints + checkpoint = {'sharded_state_dict': model.sharded_state_dict()} + sync_state_dict = sync_checkpoint_io.load_checkpoint( + _get_last_checkpoint_dir(sync_ckpt_dir, model), sharded_state_dict=checkpoint + ) + async_state_dict = async_checkpoint_io.load_checkpoint( + _get_last_checkpoint_dir(async_ckpt_dir, model), sharded_state_dict=checkpoint + ) + + assert sync_state_dict['sharded_state_dict']['const'] == async_state_dict['sharded_state_dict']['const'] + assert torch.all(sync_state_dict['sharded_state_dict']['a'] == async_state_dict['sharded_state_dict']['a']) 
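Taken together, the classes introduced in this patch are wired up exactly as in the test above: a DistributedCheckpointIO created with async_save=True, wrapped in AsyncFinalizableCheckpointIO, and paired with an AsyncFinalizerCallback on the trainer. A minimal sketch of that wiring (assuming a model that exposes a Megatron-Core sharded_state_dict, like the ExampleMCoreModel used in the test; in regular NeMo training the exp_manager `async_save` callback parameter is the config-driven entry point):

import pytorch_lightning as pl

from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.utils.callbacks.dist_ckpt_io import (
    AsyncFinalizableCheckpointIO,
    AsyncFinalizerCallback,
    DistributedCheckpointIO,
)

# Async save is only supported for the 'torch_dist' checkpoint format;
# other formats raise a ValueError in _determine_dist_ckpt_save_strategy.
checkpoint_io = AsyncFinalizableCheckpointIO(
    DistributedCheckpointIO(save_ckpt_format='torch_dist', async_save=True)
)

trainer = pl.Trainer(
    strategy=NLPDDPStrategy(),
    plugins=[checkpoint_io],
    # Required for correctness: finalizes pending async saves on
    # train_batch_end / train_epoch_end and blocks on train_end.
    callbacks=[AsyncFinalizerCallback()],
    max_epochs=1,
)
# trainer.fit(model)  # checkpoints requested via trainer.save_checkpoint are now saved asynchronously

In the exp_manager flow, the NeMoModelCheckpoint async_save flag additionally defers checkpoint removals (deferred_ckpts_to_remove) until the corresponding async save is finalized, so stale checkpoints are only deleted once the save they supersede has completed on all ranks.
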
From 1de4b49d46da12e86716f4c30dac9d01590cb1ae Mon Sep 17 00:00:00 2001 From: mikolajblaz Date: Wed, 15 May 2024 13:57:43 +0200 Subject: [PATCH 12/18] Fix incorrect checkpoint removal logic (#9192) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix incorrect if logic Signed-off-by: Mikołaj Błaż * Apply isort and black reformatting Signed-off-by: mikolajblaz --------- Signed-off-by: Mikołaj Błaż Signed-off-by: mikolajblaz --- nemo/collections/nlp/parts/nlp_overrides.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 65ffb7df47f46..079732f6b9c5d 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -450,8 +450,9 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: def remove_checkpoint(self, filepath: Union[str, Path]) -> None: # check if filepath is a distributed checkpoint - if self.use_distributed_checkpointing and self.is_global_zero: - self.checkpoint_io.remove_checkpoint(ckpt_to_dir(filepath)) + if self.use_distributed_checkpointing: + if self.is_global_zero: + self.checkpoint_io.remove_checkpoint(ckpt_to_dir(filepath)) # legacy checkpoint logic, does not use megatron core else: From 6cb618a81d9239611da22e9ef23d075498d18336 Mon Sep 17 00:00:00 2001 From: Jan Lasek Date: Wed, 15 May 2024 15:01:42 +0200 Subject: [PATCH 13/18] Update to using Model Optimizer (formerly AMMO) in PTQ workflow (#9178) * Update PTQ to use nvidia-modelopt Signed-off-by: Jan Lasek * Restore PTQ tests Signed-off-by: Jan Lasek * Update docs Signed-off-by: Jan Lasek * Comment on apply_rope_fusion Signed-off-by: Jan Lasek * Support for calibration PP > 1 Signed-off-by: Jan Lasek * Apply isort and black reformatting Signed-off-by: janekl * Fix cicd-main.yml indent Signed-off-by: Jan Lasek * Set data/tensor parallel groups Signed-off-by: Jan Lasek * Install only torch dependecies Signed-off-by: Jan Lasek * Follow up on recent modelopt changes Signed-off-by: Jan Lasek * Model support matrix Signed-off-by: Jan Lasek * Apply isort and black reformatting Signed-off-by: janekl * Rename PTQ script as it should be model-agnostic Signed-off-by: Jan Lasek * Remove unused import Signed-off-by: Jan Lasek * Update setup instructions Signed-off-by: Jan Lasek --------- Signed-off-by: Jan Lasek Signed-off-by: janekl Co-authored-by: janekl --- .github/workflows/cicd-main.yml | 135 +++++++++--------- Dockerfile | 2 - docs/source/nlp/quantization.rst | 48 ++++++- docs/source/starthere/intro.rst | 6 +- ...zation.yaml => megatron_quantization.yaml} | 0 ...antization.py => megatron_quantization.py} | 6 +- ...mmo_spec.py => gpt_layer_modelopt_spec.py} | 11 +- .../language_modeling/megatron_gpt_model.py | 4 +- nemo/export/quantize/quantizer.py | 111 +++++++++----- 9 files changed, 204 insertions(+), 119 deletions(-) rename examples/nlp/language_modeling/conf/{megatron_llama_quantization.yaml => megatron_quantization.yaml} (100%) rename examples/nlp/language_modeling/{megatron_llama_quantization.py => megatron_quantization.py} (92%) rename nemo/collections/nlp/models/language_modeling/megatron/{gpt_layer_ammo_spec.py => gpt_layer_modelopt_spec.py} (91%) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 4652e4d19f897..291eeaed7f895 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -132,8 +132,8 @@ jobs: apt-get update && apt-get 
install libsox-fmt-all -y && \ popd - # AMMO installation - pip install nvidia-ammo~=0.9.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir + # ModelOpt installation + pip install nvidia-modelopt[torch]~=0.11.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir # PyTorch Lightning version python -c "import pytorch_lightning; print(pytorch_lightning.__version__)" @@ -394,7 +394,7 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - run: | - python examples/nlp/language_modeling/megatron_llama_quantization.py \ + python examples/nlp/language_modeling/megatron_quantization.py \ model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ quantization.algorithm=null \ model_save=/home/TestData/nlp/megatron_llama/ci_baseline @@ -403,69 +403,70 @@ jobs: - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" if: "failure()" - # L2_PTQ_Llama2_FP8: - # needs: [cicd-test-container-setup] - # runs-on: self-hosted-azure - # timeout-minutes: 10 - # container: - # image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} - # options: - # # --user 0:128 - # --device=/dev/nvidia0 - # --gpus all - # --shm-size=8g - # --env TRANSFORMERS_OFFLINE=0 - # --env HYDRA_FULL_ERROR=1 - # --volume /mnt/datadrive/TestData:/home/TestData - # steps: - # - name: Checkout repository - # uses: actions/checkout@v4 - # - run: | - # python examples/nlp/language_modeling/megatron_llama_quantization.py \ - # model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ - # tensor_model_parallel_size=2 \ - # trainer.devices=2 \ - # quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \ - # quantization.algorithm=fp8 \ - # quantization.num_calib_size=8 \ - # inference.batch_size=2 \ - # export.inference_tensor_parallel=2 \ - # model_save=/home/TestData/nlp/megatron_llama/ci_fp8.qnemo - - # rm -rf /home/TestData/nlp/megatron_llama/ci_fp8.qnemo - # - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" - # if: "failure()" - - # L2_PTQ_Llama2_INT8_SQ: - # needs: [cicd-test-container-setup] - # runs-on: self-hosted-azure - # timeout-minutes: 10 - # container: - # image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} - # options: - # # --user 0:128 - # --device=/dev/nvidia0 - # --gpus all - # --shm-size=8g - # --env TRANSFORMERS_OFFLINE=0 - # --env HYDRA_FULL_ERROR=1 - # --volume /mnt/datadrive/TestData:/home/TestData - # steps: - # - name: Checkout repository - # uses: actions/checkout@v4 - # - run: | - # python examples/nlp/language_modeling/megatron_llama_quantization.py \ - # model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ - # quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \ - # quantization.algorithm=int8_sq \ - # quantization.num_calib_size=8 \ - # inference.batch_size=2 \ - # model_save=/home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo - - # rm -rf /home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo - # - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" - # if: "failure()" - + L2_PTQ_Llama2_FP8: + needs: [cicd-test-container-setup] + runs-on: self-hosted-azure + timeout-minutes: 10 + container: + image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + options: + # --user 0:128 + --device=/dev/nvidia0 + --gpus all + --shm-size=8g + --env TRANSFORMERS_OFFLINE=0 + --env HYDRA_FULL_ERROR=1 + --volume /mnt/datadrive/TestData:/home/TestData + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - run: | + python examples/nlp/language_modeling/megatron_quantization.py \ + 
model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ + tensor_model_parallel_size=2 \ + trainer.devices=2 \ + quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \ + quantization.algorithm=fp8 \ + quantization.num_calib_size=8 \ + inference.batch_size=2 \ + export.inference_tensor_parallel=2 \ + model_save=/home/TestData/nlp/megatron_llama/ci_fp8.qnemo + + rm -rf /home/TestData/nlp/megatron_llama/ci_fp8.qnemo + - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + if: "failure()" + + L2_PTQ_Llama2_INT8_SQ: + needs: [cicd-test-container-setup] + runs-on: self-hosted-azure + timeout-minutes: 10 + container: + image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + options: + # --user 0:128 + --device=/dev/nvidia0 + --gpus all + --shm-size=8g + --env TRANSFORMERS_OFFLINE=0 + --env HYDRA_FULL_ERROR=1 + --volume /mnt/datadrive/TestData:/home/TestData + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - run: | + python examples/nlp/language_modeling/megatron_quantization.py \ + model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ + quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \ + quantization.algorithm=int8_sq \ + quantization.num_calib_size=8 \ + inference.batch_size=2 \ + model_save=/home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo + + rm -rf /home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo + - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + if: "failure()" + + # TODO: investigate int4_awq stuck issues and restore the test #L2_PTQ_Llama2_INT4_AWQ: # needs: [cicd-test-container-setup] # runs-on: self-hosted-azure @@ -484,7 +485,7 @@ jobs: # - name: Checkout repository # uses: actions/checkout@v4 # - run: | - # python examples/nlp/language_modeling/megatron_llama_quantization.py \ + # python examples/nlp/language_modeling/megatron_quantization.py \ # model_file=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ # tensor_model_parallel_size=1 \ # trainer.devices=1 \ diff --git a/Dockerfile b/Dockerfile index 396645d37019f..c270487842449 100644 --- a/Dockerfile +++ b/Dockerfile @@ -133,8 +133,6 @@ RUN for f in $(ls requirements*.txt); do pip3 install --disable-pip-version-chec RUN pip install flash-attn # install numba for latest containers RUN pip install numba>=0.57.1 -# install ammo -RUN pip install nvidia-ammo~=0.9.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir # copy nemo source into a scratch image FROM scratch as nemo-src diff --git a/docs/source/nlp/quantization.rst b/docs/source/nlp/quantization.rst index afe2e9eccbca2..cc40b6a972a22 100644 --- a/docs/source/nlp/quantization.rst +++ b/docs/source/nlp/quantization.rst @@ -10,7 +10,7 @@ PTQ enables deploying a model in a low-precision format -- FP8, INT4, or INT8 -- Model quantization has two primary benefits: reduced model memory requirements and increased inference throughput. -In NeMo, quantization is enabled by the Nvidia AMMO library -- a unified algorithmic model optimization & deployment toolkit. +In NeMo, quantization is enabled by the `NVIDIA TensorRT Model Optimizer (ModelOpt) `_ library -- a library to quantize and compress deep learning models for optimized inference on GPUs. The quantization process consists of the following steps: @@ -18,10 +18,52 @@ The quantization process consists of the following steps: 2. Calibrating the model to obtain appropriate algorithm-specific scaling factors 3. 
Producing an output directory or .qnemo tarball with model config (json), quantized weights (safetensors) and tokenizer config (yaml). -Loading models requires using an AMMO spec defined in `megatron.core.inference.gpt.model_specs.py `_ module. Typically the calibration step is lightweight and uses a small dataset to obtain appropriate statistics for scaling tensors. The output directory produced (or a .qnemo tarball) is ready to be used to build a serving engine with the Nvidia TensorRT-LLM library. The engine build step is also available in NeMo project in ``nemo.deploy`` and ``nemo.export`` modules. +Loading models requires using an ModelOpt spec defined in `nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec `_ module. Typically the calibration step is lightweight and uses a small dataset to obtain appropriate statistics for scaling tensors. The output directory produced (or a .qnemo tarball) is ready to be used to build a serving engine with the Nvidia TensorRT-LLM library. The engine build step is also available in NeMo project in ``nemo.deploy`` and ``nemo.export`` modules. Quantization algorithm can also be conveniently set to ``"null"`` to perform only the weights export step using default precision for TensorRT-LLM deployment. This is useful to obtain baseline performance and accuracy results for comparison. +Support Matrix +^^^^^^^^^^^^^^ + +Table below presents verified model support matrix for popular LLM architectures. Each model entry also optionally provides a download link to a corresponding Nemo checkpoint for testing purposes. Support for other model families is experimental. + +.. list-table:: Model Support Matrix + :widths: 15 15 15 15 + :header-rows: 1 + + * - **Model Family** + - **FP8** + - **INT8_SQ** + - **INT4_AWQ** + * - Llama (1, 2, 3) + - ✅ + - ✅ + - ✅ + * - Mistral + - ✅ + - ✅ + - ✅ + * - `GPT-3 `_ + - ✅ + - ✅ + - ✅ + * - `Nemotron-3 8b `_ + - ✅ + - ✅ + - ✅ + * - Nemotron-4 15b + - ✅ + - ✅ + - ✅ + * - StarCoder 2 + - ✅ + - ✅ + - ✅ + * - Gemma + - ✅ + - ✅ + - ✅ + Example ^^^^^^^ @@ -31,7 +73,7 @@ The script must be launched correctly with the number of processes equal to tens .. code-block:: bash - torchrun --nproc-per-node 8 examples/nlp/language_modeling/megatron_llama_quantization.py \ + torchrun --nproc-per-node 8 examples/nlp/language_modeling/megatron_quantization.py \ model_file=llama2-70b-base-bf16.nemo \ tensor_model_parallel_size=8 \ pipeline_model_parallel_size=1 \ diff --git a/docs/source/starthere/intro.rst b/docs/source/starthere/intro.rst index 63fdcfb0406e5..ebbe1551c39ee 100644 --- a/docs/source/starthere/intro.rst +++ b/docs/source/starthere/intro.rst @@ -96,13 +96,13 @@ This section details the steps to clone and install the Megatron Core. git checkout a5415fcfacef2a37416259bd38b7c4b673583675 && \ pip install . -AMMO Installation +Model Optimizer Installation -This final step involves installing the AMMO package. +This final step involves installing the Model Optimizer package. .. code-block:: bash - pip install nvidia-ammo~=0.7.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir + pip install nvidia-modelopt[torch]~=0.11.0 --extra-index-url https://pypi.nvidia.com .. 
code-block:: bash diff --git a/examples/nlp/language_modeling/conf/megatron_llama_quantization.yaml b/examples/nlp/language_modeling/conf/megatron_quantization.yaml similarity index 100% rename from examples/nlp/language_modeling/conf/megatron_llama_quantization.yaml rename to examples/nlp/language_modeling/conf/megatron_quantization.yaml diff --git a/examples/nlp/language_modeling/megatron_llama_quantization.py b/examples/nlp/language_modeling/megatron_quantization.py similarity index 92% rename from examples/nlp/language_modeling/megatron_llama_quantization.py rename to examples/nlp/language_modeling/megatron_quantization.py index 92ead6b4ed699..d4d6a8b6b9174 100644 --- a/examples/nlp/language_modeling/megatron_llama_quantization.py +++ b/examples/nlp/language_modeling/megatron_quantization.py @@ -25,12 +25,12 @@ Nemo quantization example script. Please consult nemo.export.quantize.Quantizer class -and examples/nlp/language_modeling/conf/megatron_llama_quantization.yaml config on available quantization methods, +and examples/nlp/language_modeling/conf/megatron_quantization.yaml config on available quantization methods, models supported as well as how to set up data and inference for calibration (with defaults recommended). Example usage: ``` -python examples/nlp/language_modeling/megatron_llama_quantization.py \ +python examples/nlp/language_modeling/megatron_quantization.py \ model_file=llama2-7b-fp16.nemo \ model_save=llama2-7b-fp8.qnemo \ quantization.algorithm=fp8 \ @@ -59,7 +59,7 @@ def get_calib_dataloader(data="cnn_dailymail", batch_size=64, calib_size=512, ma yield batch -@hydra_runner(config_path="conf", config_name="megatron_llama_quantization") +@hydra_runner(config_path="conf", config_name="megatron_quantization") def main(cfg) -> None: if not torch.cuda.is_available(): raise EnvironmentError("GPU is required for the inference.") diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_ammo_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py similarity index 91% rename from nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_ammo_spec.py rename to nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py index e51ecaba463ac..f9ba58736cbd3 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_ammo_spec.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py @@ -36,8 +36,9 @@ HAVE_MEGATRON_CORE = False IMPORT_ERROR = e -# Use this spec for AMMO PTQ and TensorRT-LLM export -def get_gpt_layer_ammo_spec() -> ModuleSpec: + +# Use this spec for Model Optimizer PTQ and TensorRT-LLM export +def get_gpt_layer_modelopt_spec() -> ModuleSpec: """Mix the native spec with TENorm. 
This is essentially the native local spec except for the layernorm implementation @@ -65,7 +66,11 @@ def get_gpt_layer_ammo_spec() -> ModuleSpec: self_attn_bda=get_bias_dropout_add, pre_mlp_layernorm=TENorm, mlp=ModuleSpec( - module=MLP, submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear,), + module=MLP, + submodules=MLPSubmodules( + linear_fc1=ColumnParallelLinear, + linear_fc2=RowParallelLinear, + ), ), mlp_bda=get_bias_dropout_add, # Map TE-layernorm-fusion keys back diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 536fc5bff7c89..3660a5145b102 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -41,7 +41,7 @@ from nemo.collections.nlp.models.language_modeling.megatron.gpt_full_te_layer_autocast_spec import ( get_gpt_full_te_layer_autocast_spec, ) -from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_ammo_spec import get_gpt_layer_ammo_spec +from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel from nemo.collections.nlp.modules.common.megatron.build_model import build_model @@ -154,7 +154,7 @@ def get_specs(spec_name, num_experts=None, moe_grouped_gemm=False, use_te=True): "te_gpt": get_gpt_layer_with_transformer_engine_spec(num_experts, moe_grouped_gemm), "megatron_falcon_gpt": get_falcon_layer_spec(), "megatron_gpt_full_te_layer_autocast": get_gpt_full_te_layer_autocast_spec(), - "ammo": get_gpt_layer_ammo_spec(), + "modelopt": get_gpt_layer_modelopt_spec(), } if spec_name not in name_spec_dict: raise ValueError(f"Spec name '{spec_name}' is not recognized.") diff --git a/nemo/export/quantize/quantizer.py b/nemo/export/quantize/quantizer.py index 783f47a08e79d..4748f4957a52b 100644 --- a/nemo/export/quantize/quantizer.py +++ b/nemo/export/quantize/quantizer.py @@ -18,11 +18,12 @@ import torch import torch.distributed as dist -from megatron.core import parallel_state +from megatron.core import mpu, parallel_state from megatron.core.transformer.module import Float16Module from omegaconf import OmegaConf from omegaconf.omegaconf import DictConfig, open_dict from pytorch_lightning.trainer.trainer import Trainer +from tqdm import tqdm from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector @@ -32,18 +33,18 @@ from nemo.utils.model_utils import load_config, save_artifacts, unwrap_model try: - import ammo.torch.quantization as atq - from ammo.torch.export import export_tensorrt_llm_checkpoint + import modelopt.torch.quantization as mtq + from modelopt.torch.export import export_tensorrt_llm_checkpoint + from modelopt.torch.utils.distributed import set_data_parallel_group, set_tensor_parallel_group - HAVE_AMMO = True + HAVE_MODELOPT = True except (ImportError, ModuleNotFoundError) as e: - HAVE_AMMO = False - HAVE_AMMO_ERROR = e + HAVE_MODELOPT = False + HAVE_MODELOPT_ERROR = e class Quantizer: - """ Post-training quantization of Nemo checkpoints. @@ -63,9 +64,9 @@ class Quantizer: model families is experimental and might not be fully supported. 
Available quantization methods are listed in QUANT_CFG_CHOICES dictionary below. - Please consult AMMO documentation for details. You can also inspect different choices in - examples/nlp/language_modeling/conf/megatron_llama_quantization.yaml for quantization algorithms and - calibration data as well as recommended settings. + Please consult Model Optimizer documentation https://nvidia.github.io/TensorRT-Model-Optimizer/ for details. + You can also inspect different choices in examples/nlp/language_modeling/conf/megatron_quantization.yaml + for quantization algorithms and calibration data as well as recommended settings. Quantization algorithm can also be conveniently set to 'null' to perform only weights export step for TensorRT-LLM deployment. This is useful to getting baseline results for a full-precision model. @@ -78,14 +79,14 @@ def __init__( export_config: DictConfig, trainer_config: DictConfig, ): - if not HAVE_AMMO: - raise RuntimeError("nvidia-ammo is needed to use Quantizer") from HAVE_AMMO_ERROR + if not HAVE_MODELOPT: + raise RuntimeError("nvidia-modelopt is needed to use Quantizer") from HAVE_MODELOPT_ERROR QUANT_CFG_CHOICES = { - "int8": atq.INT8_DEFAULT_CFG, - "int8_sq": atq.INT8_SMOOTHQUANT_CFG, - "fp8": atq.FP8_DEFAULT_CFG, - "int4_awq": atq.INT4_AWQ_CFG, - "w4a8_awq": atq.W4A8_AWQ_BETA_CFG, + "int8": mtq.INT8_DEFAULT_CFG, + "int8_sq": mtq.INT8_SMOOTHQUANT_CFG, + "fp8": mtq.FP8_DEFAULT_CFG, + "int4_awq": mtq.INT4_AWQ_CFG, + "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG, } SUPPORTED_DTYPE = [16, "16", "bf16"] # Default precision for non-quantized layers assert export_config.dtype in SUPPORTED_DTYPE @@ -95,25 +96,30 @@ def __init__( self.export_config = export_config self.trainer_config = trainer_config if quantization_config.algorithm is not None: - atq_config = QUANT_CFG_CHOICES[quantization_config.algorithm] + quant_cfg = QUANT_CFG_CHOICES[quantization_config.algorithm] if "awq" in quantization_config.algorithm: - weight_quantizer = atq_config["quant_cfg"]["*weight_quantizer"] + weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] if isinstance(weight_quantizer, list): weight_quantizer = weight_quantizer[0] weight_quantizer["block_sizes"][-1] = quantization_config.awq_block_size # Always turn on FP8 kv cache to save memory footprint. # For int8_sq, we use int8 kv cache. - atq_config["quant_cfg"]["*output_quantizer"] = { + # TODO: Investigate why enabling FP8 kv cache will cause accuracy regressions for Nemotron. 
+ enable_quant_kv_cache = ( + "int8" not in quantization_config.algorithm and export_config.decoder_type != "gptnext" + ) + print(f'{"Enable" if enable_quant_kv_cache else "Disable"} KV cache quantization') + quant_cfg["quant_cfg"]["*output_quantizer"] = { "num_bits": 8 if quantization_config.algorithm == "int8_sq" else (4, 3), "axis": None, - "enable": export_config.decoder_type != "gptnext", + "enable": enable_quant_kv_cache, } - self.atq_config = atq_config + self.quant_cfg = quant_cfg else: - self.atq_config = None + self.quant_cfg = None def _load_model( self, @@ -121,14 +127,17 @@ def _load_model( tensor_model_parallel_size: Optional[int] = None, pipeline_model_parallel_size: Optional[int] = None, ): - """Load model using AMMO layer spec for quantization.""" + """Load model using ModelOpt layer spec for quantization.""" model_cfg = self._load_and_modify_config(model_file, tensor_model_parallel_size, pipeline_model_parallel_size) trainer = Trainer(strategy=NLPDDPStrategy(), **self.trainer_config) connector = NLPSaveRestoreConnector() model = MegatronGPTModel.restore_from( - restore_path=model_file, trainer=trainer, override_config_path=model_cfg, save_restore_connector=connector, + restore_path=model_file, + trainer=trainer, + override_config_path=model_cfg, + save_restore_connector=connector, ) model.freeze() @@ -144,7 +153,8 @@ def _load_model( return model - def _check_ddp_initialized(self, model): + @staticmethod + def _check_ddp_initialized(model): if not parallel_state.is_initialized(): def dummy(): @@ -154,8 +164,11 @@ def dummy(): model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer) model.trainer.strategy.setup_environment() + set_data_parallel_group(mpu.get_data_parallel_group()) + set_tensor_parallel_group(mpu.get_tensor_model_parallel_group()) + + @staticmethod def _load_and_modify_config( - self, model_file: str, tensor_model_parallel_size: Optional[int] = None, pipeline_model_parallel_size: Optional[int] = None, @@ -170,12 +183,35 @@ def _load_and_modify_config( model_cfg.tensor_model_parallel_size = tensor_model_parallel_size if pipeline_model_parallel_size is not None: model_cfg.pipeline_model_parallel_size = pipeline_model_parallel_size - # Only custom AMMO spec is supported for PTQ: this custom spec is largely based on local Megatron-LM + # Only custom ModelOpt spec is supported for PTQ: this custom spec is largely based on local Megatron-LM # layer definitions to avoid Transformer Engine implementations that are currently not supported. - model_cfg.name = "ammo" + # This layer spec also requires RoPE fusion to be disabled for tensor view operations in attention + # layer implementation from megatron/core/transformer/dot_product_attention.py to be functional. 
+ model_cfg.name = "modelopt" + model_cfg.apply_rope_fusion = False return model_cfg + @staticmethod + def _sample_output(model): + """Generate sample output for a model instance.""" + if torch.distributed.get_rank() == 0: + print("Generating sample output for a model...") + + response = model.generate( + inputs=[ + "Born in north-east France, Soyer trained as a", + "Born in California, Soyer trained as a", + ], + length_params={ + "max_length": 100, + "min_length": 100, + }, + ) + + if torch.distributed.get_rank() == 0: + print(f'Example NeMo output after PTQ: {response["sentences"]}"') + def quantize( self, model_file: str, @@ -191,13 +227,12 @@ def quantize( model.set_inference_config(OmegaConf.to_container(self.inference_config)) - def forward_loop(): - for i, batch in enumerate(dataloader): - if dist.get_rank() == 0: - print(f"Calibrating batch {i}") + def forward_loop(model): + print("Calibrating the model...") + for i, batch in enumerate(tqdm(dataloader)): model.predict_step(batch, i) - model = atq.quantize(model, self.atq_config, forward_loop) + model = mtq.quantize(model, self.quant_cfg, forward_loop) if self.export_config == "gptnext": # We found squared_relu may have an under-calibration problem. @@ -207,12 +242,12 @@ def forward_loop(): maxbound = 448 elif self.quantization_config.quantization.algorithm == "int8_sq": maxbound = 127 - model = atq.postprocess_amax( + model = mtq.postprocess_amax( model, "*input_quantizer", lambda amax: torch.clamp(amax, min=0.01 * maxbound) ) if dist.get_rank() == 0: - atq.print_quant_summary(model) + mtq.print_quant_summary(model) return model @@ -220,6 +255,8 @@ def export(self, model, model_save: str): """Export model to '.qnemo' format for TensorRT-LLM engine build.""" torch_dtype = torch_dtype_from_precision(self.export_config.dtype) + self._sample_output(model) + if model.cfg.megatron_amp_O2: model.model = unwrap_model(model.model, Float16Module) @@ -239,6 +276,8 @@ def export(self, model, model_save: str): export_dir=export_dir, inference_tensor_parallel=self.export_config.inference_tensor_parallel, inference_pipeline_parallel=self.export_config.inference_pipeline_parallel, + use_nfs_workspace=self.export_config.inference_pipeline_parallel == 1 + and model.cfg.pipeline_model_parallel_size > 1, ) dist.barrier() # Wait until all ranks complete export_model_config step if dist.get_rank() == 0: From 061cc452cf6c6b8687093799b9d048e55aad5fd8 Mon Sep 17 00:00:00 2001 From: Alessandro Morari Date: Wed, 15 May 2024 16:43:25 -0400 Subject: [PATCH 14/18] GPU-based vectorized Specaug Version 2 (#9155) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * GPU-based vectorized SpecAug Signed-off-by: Piotr Żelasko * Wider dtypes for specaug mask bounds computation Signed-off-by: Piotr Żelasko * fast spec augmentation v2 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Removed randint code, added comments * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed padding coverage bug, fixed long casting bug, fixed comments Signed-off-by: Alessandro Morari * fixed bug due to using freq_axis with length Signed-off-by: Alessandro Morari * Added tests for vectorized spectrogram augmentation Signed-off-by: Alessandro Morari * Apply isort and black reformatting Signed-off-by: pzelasko --------- Signed-off-by: Piotr Żelasko Signed-off-by: Alessandro Morari Signed-off-by: pzelasko Co-authored-by: Piotr Żelasko 
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: pzelasko --- .../asr/modules/audio_preprocessing.py | 239 +++++++++--------- .../asr/parts/submodules/spectr_augment.py | 118 ++++++++- tests/collections/asr/test_asr_modules.py | 35 ++- 3 files changed, 261 insertions(+), 131 deletions(-) diff --git a/nemo/collections/asr/modules/audio_preprocessing.py b/nemo/collections/asr/modules/audio_preprocessing.py index 643bc4a69d69d..d45c0acf314fb 100644 --- a/nemo/collections/asr/modules/audio_preprocessing.py +++ b/nemo/collections/asr/modules/audio_preprocessing.py @@ -66,8 +66,8 @@ class AudioPreprocessor(NeuralModule, ABC): """ - An interface for Neural Modules that performs audio pre-processing, - transforming the wav files to features. + An interface for Neural Modules that performs audio pre-processing, + transforming the wav files to features. """ def __init__(self, win_length, hop_length): @@ -101,72 +101,72 @@ def get_features(self, input_signal, length): class AudioToMelSpectrogramPreprocessor(AudioPreprocessor, Exportable): """Featurizer module that converts wavs to mel spectrograms. - Args: - sample_rate (int): Sample rate of the input audio data. - Defaults to 16000 - window_size (float): Size of window for fft in seconds - Defaults to 0.02 - window_stride (float): Stride of window for fft in seconds - Defaults to 0.01 - n_window_size (int): Size of window for fft in samples - Defaults to None. Use one of window_size or n_window_size. - n_window_stride (int): Stride of window for fft in samples - Defaults to None. Use one of window_stride or n_window_stride. - window (str): Windowing function for fft. can be one of ['hann', - 'hamming', 'blackman', 'bartlett'] - Defaults to "hann" - normalize (str): Can be one of ['per_feature', 'all_features']; all - other options disable feature normalization. 'all_features' - normalizes the entire spectrogram to be mean 0 with std 1. - 'pre_features' normalizes per channel / freq instead. - Defaults to "per_feature" - n_fft (int): Length of FT window. If None, it uses the smallest power - of 2 that is larger than n_window_size. - Defaults to None - preemph (float): Amount of pre emphasis to add to audio. Can be - disabled by passing None. - Defaults to 0.97 - features (int): Number of mel spectrogram freq bins to output. - Defaults to 64 - lowfreq (int): Lower bound on mel basis in Hz. - Defaults to 0 - highfreq (int): Lower bound on mel basis in Hz. - Defaults to None - log (bool): Log features. - Defaults to True - log_zero_guard_type(str): Need to avoid taking the log of zero. There - are two options: "add" or "clamp". - Defaults to "add". - log_zero_guard_value(float, or str): Add or clamp requires the number - to add with or clamp to. log_zero_guard_value can either be a float - or "tiny" or "eps". torch.finfo is used if "tiny" or "eps" is - passed. - Defaults to 2**-24. - dither (float): Amount of white-noise dithering. - Defaults to 1e-5 - pad_to (int): Ensures that the output size of the time dimension is - a multiple of pad_to. - Defaults to 16 - frame_splicing (int): Defaults to 1 - exact_pad (bool): If True, sets stft center to False and adds padding, such that num_frames = audio_length - // hop_length. Defaults to False. - pad_value (float): The value that shorter mels are padded with. - Defaults to 0 - mag_power (float): The power that the linear spectrogram is raised to - prior to multiplication with mel basis. 
- Defaults to 2 for a power spec - rng : Random number generator - nb_augmentation_prob (float) : Probability with which narrowband augmentation would be applied to - samples in the batch. - Defaults to 0.0 - nb_max_freq (int) : Frequency above which all frequencies will be masked for narrowband augmentation. - Defaults to 4000 - use_torchaudio: Whether to use the `torchaudio` implementation. - mel_norm: Normalization used for mel filterbank weights. - Defaults to 'slaney' (area normalization) - stft_exact_pad: Deprecated argument, kept for compatibility with older checkpoints. - stft_conv: Deprecated argument, kept for compatibility with older checkpoints. - """ + Args: + sample_rate (int): Sample rate of the input audio data. + Defaults to 16000 + window_size (float): Size of window for fft in seconds + Defaults to 0.02 + window_stride (float): Stride of window for fft in seconds + Defaults to 0.01 + n_window_size (int): Size of window for fft in samples + Defaults to None. Use one of window_size or n_window_size. + n_window_stride (int): Stride of window for fft in samples + Defaults to None. Use one of window_stride or n_window_stride. + window (str): Windowing function for fft. can be one of ['hann', + 'hamming', 'blackman', 'bartlett'] + Defaults to "hann" + normalize (str): Can be one of ['per_feature', 'all_features']; all + other options disable feature normalization. 'all_features' + normalizes the entire spectrogram to be mean 0 with std 1. + 'pre_features' normalizes per channel / freq instead. + Defaults to "per_feature" + n_fft (int): Length of FT window. If None, it uses the smallest power + of 2 that is larger than n_window_size. + Defaults to None + preemph (float): Amount of pre emphasis to add to audio. Can be + disabled by passing None. + Defaults to 0.97 + features (int): Number of mel spectrogram freq bins to output. + Defaults to 64 + lowfreq (int): Lower bound on mel basis in Hz. + Defaults to 0 + highfreq (int): Lower bound on mel basis in Hz. + Defaults to None + log (bool): Log features. + Defaults to True + log_zero_guard_type(str): Need to avoid taking the log of zero. There + are two options: "add" or "clamp". + Defaults to "add". + log_zero_guard_value(float, or str): Add or clamp requires the number + to add with or clamp to. log_zero_guard_value can either be a float + or "tiny" or "eps". torch.finfo is used if "tiny" or "eps" is + passed. + Defaults to 2**-24. + dither (float): Amount of white-noise dithering. + Defaults to 1e-5 + pad_to (int): Ensures that the output size of the time dimension is + a multiple of pad_to. + Defaults to 16 + frame_splicing (int): Defaults to 1 + exact_pad (bool): If True, sets stft center to False and adds padding, such that num_frames = audio_length + // hop_length. Defaults to False. + pad_value (float): The value that shorter mels are padded with. + Defaults to 0 + mag_power (float): The power that the linear spectrogram is raised to + prior to multiplication with mel basis. + Defaults to 2 for a power spec + rng : Random number generator + nb_augmentation_prob (float) : Probability with which narrowband augmentation would be applied to + samples in the batch. + Defaults to 0.0 + nb_max_freq (int) : Frequency above which all frequencies will be masked for narrowband augmentation. + Defaults to 4000 + use_torchaudio: Whether to use the `torchaudio` implementation. + mel_norm: Normalization used for mel filterbank weights. 
+ Defaults to 'slaney' (area normalization) + stft_exact_pad: Deprecated argument, kept for compatibility with older checkpoints. + stft_conv: Deprecated argument, kept for compatibility with older checkpoints. + """ def save_to(self, save_path: str): pass @@ -177,8 +177,7 @@ def restore_from(cls, restore_path: str): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "input_signal": NeuralType(('B', 'T'), AudioSignal(freq=self._sample_rate)), "length": NeuralType( @@ -218,7 +217,7 @@ def __init__( highfreq=None, log=True, log_zero_guard_type="add", - log_zero_guard_value=2 ** -24, + log_zero_guard_value=2**-24, dither=1e-5, pad_to=16, frame_splicing=1, @@ -335,8 +334,7 @@ class AudioToMFCCPreprocessor(AudioPreprocessor): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "input_signal": NeuralType(('B', 'T'), AudioSignal(freq=self._sample_rate)), "length": NeuralType(tuple('B'), LengthsType()), @@ -344,8 +342,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "processed_signal": NeuralType(('B', 'D', 'T'), MFCCSpectrogramType()), "processed_length": NeuralType(tuple('B'), LengthsType()), @@ -463,12 +460,14 @@ class SpectrogramAugmentation(NeuralModule): rect_time (int): maximum size of cut rectangles along the time dimension Defaults to 25. + use_numba_spec_augment: use numba code for Spectrogram augmentation + use_vectorized_spec_augment: use vectorized code for Spectrogram augmentation + """ @property def input_types(self): - """Returns definitions of module input types - """ + """Returns definitions of module input types""" return { "input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType()), "length": NeuralType(tuple('B'), LengthsType()), @@ -476,8 +475,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output types - """ + """Returns definitions of module output types""" return {"augmented_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())} def __init__( @@ -491,12 +489,18 @@ def __init__( rect_freq=20, rng=None, mask_value=0.0, - use_numba_spec_augment: bool = True, + use_vectorized_spec_augment: bool = True, + use_numba_spec_augment: bool = False, ): super().__init__() if rect_masks > 0: - self.spec_cutout = SpecCutout(rect_masks=rect_masks, rect_time=rect_time, rect_freq=rect_freq, rng=rng,) + self.spec_cutout = SpecCutout( + rect_masks=rect_masks, + rect_time=rect_time, + rect_freq=rect_freq, + rng=rng, + ) # self.spec_cutout.to(self._device) else: self.spec_cutout = lambda input_spec: input_spec @@ -508,6 +512,7 @@ def __init__( time_width=time_width, rng=rng, mask_value=mask_value, + use_vectorized_code=use_vectorized_spec_augment, ) else: self.spec_augment = lambda input_spec, length: input_spec @@ -541,26 +546,25 @@ def forward(self, input_spec, length): class MaskedPatchAugmentation(NeuralModule): """ - Zeroes out fixed size time patches of the spectrogram. - All samples in batch are guaranteed to have the same amount of masked time steps. - Optionally also performs frequency masking in the same way as SpecAugment. - Args: - patch_size (int): up to how many time steps does one patch consist of. - Defaults to 48. - mask_patches (float): how many patches should be masked in each sample. 
- if >= 1., interpreted as number of patches (after converting to int) - if <1., interpreted as fraction of total tokens to be masked (number of patches is rounded up) - Defaults to 10. - freq_masks (int): how many frequency segments should be cut. - Defaults to 0. - freq_width (int): maximum number of frequencies to be cut in a segment. - Defaults to 0. + Zeroes out fixed size time patches of the spectrogram. + All samples in batch are guaranteed to have the same amount of masked time steps. + Optionally also performs frequency masking in the same way as SpecAugment. + Args: + patch_size (int): up to how many time steps does one patch consist of. + Defaults to 48. + mask_patches (float): how many patches should be masked in each sample. + if >= 1., interpreted as number of patches (after converting to int) + if <1., interpreted as fraction of total tokens to be masked (number of patches is rounded up) + Defaults to 10. + freq_masks (int): how many frequency segments should be cut. + Defaults to 0. + freq_width (int): maximum number of frequencies to be cut in a segment. + Defaults to 0. """ @property def input_types(self): - """Returns definitions of module input types - """ + """Returns definitions of module input types""" return { "input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType()), "length": NeuralType(tuple('B'), LengthsType()), @@ -568,12 +572,15 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output types - """ + """Returns definitions of module output types""" return {"augmented_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())} def __init__( - self, patch_size: int = 48, mask_patches: float = 10.0, freq_masks: int = 0, freq_width: int = 0, + self, + patch_size: int = 48, + mask_patches: float = 10.0, + freq_masks: int = 0, + freq_width: int = 0, ): super().__init__() self.patch_size = patch_size @@ -586,7 +593,12 @@ def __init__( raise ValueError('mask_patches cannot be negative') if freq_masks > 0: - self.spec_augment = SpecAugment(freq_masks=freq_masks, time_masks=0, freq_width=freq_width, time_width=0,) + self.spec_augment = SpecAugment( + freq_masks=freq_masks, + time_masks=0, + freq_width=freq_width, + time_width=0, + ) else: self.spec_augment = None @@ -676,8 +688,7 @@ def forward(self, input_signal, length): @property def input_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "input_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()), "length": NeuralType(tuple('B'), LengthsType()), @@ -685,8 +696,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()), "processed_length": NeuralType(tuple('B'), LengthsType()), @@ -754,8 +764,7 @@ def num_subbands(self) -> int: @property def input_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "input": NeuralType(('B', 'C', 'T'), AudioSignal()), "input_length": NeuralType(('B',), LengthsType(), optional=True), @@ -763,8 +772,7 @@ def input_types(self) -> Dict[str, NeuralType]: @property def output_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. 
- """ + """Returns definitions of module output ports.""" return { "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), "output_length": NeuralType(('B',), LengthsType()), @@ -835,7 +843,7 @@ class SpectrogramToAudio(NeuralModule): fft_length: length of FFT hop_length: length of hops/shifts of the sliding window magnitude_power: Transform magnitude of the spectrogram as x^(1/magnitude_power). - scale: Spectrogram will be scaled with 1/scale before the inverse transform. + scale: Spectrogram will be scaled with 1/scale before the inverse transform. """ def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0): @@ -878,8 +886,7 @@ def num_subbands(self) -> int: @property def input_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), "input_length": NeuralType(('B',), LengthsType(), optional=True), @@ -887,8 +894,7 @@ def input_types(self) -> Dict[str, NeuralType]: @property def output_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "output": NeuralType(('B', 'C', 'T'), AudioSignal()), "output_length": NeuralType(('B',), LengthsType()), @@ -968,7 +974,7 @@ class AudioToMelSpectrogramPreprocessorConfig: highfreq: Optional[int] = None log: bool = True log_zero_guard_type: str = "add" - log_zero_guard_value: float = 2 ** -24 + log_zero_guard_value: float = 2**-24 dither: float = 1e-5 pad_to: int = 16 frame_splicing: int = 1 @@ -1015,7 +1021,8 @@ class SpectrogramAugmentationConfig: rect_freq: int = 0 mask_value: float = 0 rng: Optional[Any] = None # random.Random() type - use_numba_spec_augment: bool = True + use_numba_spec_augment: bool = False + use_vectorized_spec_augment: bool = True @dataclass diff --git a/nemo/collections/asr/parts/submodules/spectr_augment.py b/nemo/collections/asr/parts/submodules/spectr_augment.py index 9b379ce10f375..5bc7104816afd 100644 --- a/nemo/collections/asr/parts/submodules/spectr_augment.py +++ b/nemo/collections/asr/parts/submodules/spectr_augment.py @@ -38,12 +38,18 @@ class SpecAugment(nn.Module, Typing): to be cut in one segment. If a float value, defines maximum percentage of timesteps that are cut adaptively. + use_vectorized_code - GPU-based implementation with batched masking and GPU rng, + setting it to False reverts to the legacy implementation. 
+ Fast implementation is inspired by torchaudio: + https://github.com/pytorch/audio/blob/ea437b31ce316ea3d66fe73768c0dcb94edb79ad/src/torchaudio/functional/functional.py#L816 """ + FREQ_AXIS = 1 # Frequency axis in the spectrogram tensor + TIME_AXIS = 2 # Time axis in the spectrogram tensor + @property def input_types(self): - """Returns definitions of module input types - """ + """Returns definitions of module input types""" return { "input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType()), "length": NeuralType(tuple('B'), LengthsType()), @@ -51,12 +57,18 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output types - """ + """Returns definitions of module output types""" return {"augmented_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())} def __init__( - self, freq_masks=0, time_masks=0, freq_width=10, time_width=10, rng=None, mask_value=0.0, + self, + freq_masks: int = 0, + time_masks: int = 0, + freq_width: int = 10, + time_width: int | float = 10, + rng: random.Random | None = None, + mask_value: float = 0.0, + use_vectorized_code: bool = True, ): super().__init__() @@ -69,6 +81,7 @@ def __init__( self.time_width = time_width self.mask_value = mask_value + self.use_vectorized_code = use_vectorized_code if isinstance(time_width, int): self.adaptive_temporal_width = False @@ -81,6 +94,12 @@ def __init__( @typecheck() @torch.no_grad() def forward(self, input_spec, length): + if self.use_vectorized_code: + return self._forward_vectorized(input_spec, length) + else: + return self._forward_legacy(input_spec, length) + + def _forward_legacy(self, input_spec, length): batch_size, num_freq_bins, _ = input_spec.shape # Move lengths to CPU before repeated indexing lengths_cpu = length.cpu().numpy() @@ -112,6 +131,89 @@ def forward(self, input_spec, length): masked_spec = input_spec.masked_fill(mask=fill_mask, value=self.mask_value) return masked_spec + def _forward_vectorized(self, input_spec: torch.Tensor, length: torch.Tensor) -> torch.Tensor: + # time masks + input_spec = self._apply_masks( + input_spec=input_spec, + num_masks=self.time_masks, + length=length, + width=self.time_width, + axis=self.TIME_AXIS, + mask_value=self.mask_value, + ) + # freq masks + input_spec = self._apply_masks( + input_spec=input_spec, + num_masks=self.freq_masks, + length=length, + width=self.freq_width, + axis=self.FREQ_AXIS, + mask_value=self.mask_value, + ) + return input_spec + + def _apply_masks( + self, + input_spec: torch.Tensor, + num_masks: int, + length: torch.Tensor, + width: int | float, + mask_value: float, + axis: int, + ) -> torch.Tensor: + + assert axis in ( + self.FREQ_AXIS, + self.TIME_AXIS, + ), f"Axis can be only be equal to frequency \ + ({self.FREQ_AXIS}) or time ({self.TIME_AXIS}). Received: {axis=}" + assert not ( + isinstance(width, float) and axis == self.FREQ_AXIS + ), "Float width supported \ + only with time axis." + + batch_size = input_spec.shape[0] + axis_length = input_spec.shape[axis] + + # If width is float then it is transformed into a tensor + if axis == self.TIME_AXIS and isinstance(width, float): + width = torch.clamp(width * length, max=axis_length).unsqueeze(1) + + # Generate [0-1) random numbers and then scale the tensors. + # Use float32 dtype for begin/end mask markers before they are quantized to long. 
+ mask_width = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32) * width + mask_width = mask_width.long() + mask_start = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32) + + if axis == self.TIME_AXIS: + # length can only be used for the time axis + mask_start = mask_start * (length.unsqueeze(1) - mask_width) + else: + mask_start = mask_start * (axis_length - mask_width) + + mask_start = mask_start.long() + mask_end = mask_start + mask_width + + # Create mask values using vectorized indexing + indices = torch.arange(axis_length, device=input_spec.device) + # Create a mask_tensor with all the indices. + # The mask_tensor shape is (batch_size, num_masks, axis_length). + mask_tensor = (indices >= mask_start.unsqueeze(-1)) & (indices < mask_end.unsqueeze(-1)) + + # Reduce masks to one mask + mask_tensor = mask_tensor.any(dim=1) + + # Create a final mask that aligns with the full tensor + mask = torch.zeros_like(input_spec, dtype=torch.bool) + if axis == self.TIME_AXIS: + mask_ranges = mask_tensor[:, None, :] + else: # axis == self.FREQ_AXIS + mask_ranges = mask_tensor[:, :, None] + mask[:, :, :] = mask_ranges + + # Apply the mask value + return input_spec.masked_fill(mask=mask, value=mask_value) + class SpecCutout(nn.Module, Typing): """ @@ -126,14 +228,12 @@ class SpecCutout(nn.Module, Typing): @property def input_types(self): - """Returns definitions of module input types - """ + """Returns definitions of module input types""" return {"input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())} @property def output_types(self): - """Returns definitions of module output types - """ + """Returns definitions of module output types""" return {"augmented_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())} def __init__(self, rect_masks=0, rect_time=5, rect_freq=20, rng=None): diff --git a/tests/collections/asr/test_asr_modules.py b/tests/collections/asr/test_asr_modules.py index b47a72fe0476e..1a845232b2a73 100644 --- a/tests/collections/asr/test_asr_modules.py +++ b/tests/collections/asr/test_asr_modules.py @@ -69,10 +69,28 @@ def test_AudioToMelSpectrogramPreprocessor_batch(self): assert diff <= 1e-3 @pytest.mark.unit - def test_SpectrogramAugmentationr(self): + def test_SpectrogramAugmentationr_legacy(self): # Make sure constructor works instance1 = modules.SpectrogramAugmentation( - freq_masks=10, time_masks=3, rect_masks=3, use_numba_spec_augment=False + freq_masks=10, time_masks=3, rect_masks=3, use_numba_spec_augment=False, use_vectorized_spec_augment=False + ) + assert isinstance(instance1, modules.SpectrogramAugmentation) + + # Make sure forward doesn't throw with expected input + instance0 = modules.AudioToMelSpectrogramPreprocessor(dither=0) + input_signal = torch.randn(size=(4, 512)) + length = torch.randint(low=161, high=500, size=[4]) + res0 = instance0(input_signal=input_signal, length=length) + res = instance1(input_spec=res0[0], length=length) + + assert res.shape == res0[0].shape + + @pytest.mark.unit + @pytest.mark.run_only_on('GPU') + def test_SpectrogramAugmentationr_vectorized(self): + # Make sure constructor works + instance1 = modules.SpectrogramAugmentation( + freq_masks=10, time_masks=3, rect_masks=3, use_numba_spec_augment=False, use_vectorized_spec_augment=True ) assert isinstance(instance1, modules.SpectrogramAugmentation) @@ -97,7 +115,7 @@ def test_SpectrogramAugmentationr_numba_kernel(self, caplog): # Make sure constructor works instance1 = modules.SpectrogramAugmentation( - freq_masks=10, 
time_masks=3, rect_masks=3, use_numba_spec_augment=True + freq_masks=10, time_masks=3, rect_masks=3, use_numba_spec_augment=True, use_vectorized_spec_augment=False ) assert isinstance(instance1, modules.SpectrogramAugmentation) @@ -120,7 +138,8 @@ def test_SpectrogramAugmentationr_numba_kernel(self, caplog): def test_SpectrogramAugmentationr_config(self): # Test that dataclass matches signature of module result = config_utils.assert_dataclass_signature_match( - modules.SpectrogramAugmentation, modules.audio_preprocessing.SpectrogramAugmentationConfig, + modules.SpectrogramAugmentation, + modules.audio_preprocessing.SpectrogramAugmentationConfig, ) signatures_match, cls_subset, dataclass_subset = result @@ -178,7 +197,8 @@ def test_MaskedPatchAugmentation(self): def test_MaskedPatchAugmentation_config(self): # Test that dataclass matches signature of module result = config_utils.assert_dataclass_signature_match( - modules.MaskedPatchAugmentation, modules.audio_preprocessing.MaskedPatchAugmentationConfig, + modules.MaskedPatchAugmentation, + modules.audio_preprocessing.MaskedPatchAugmentationConfig, ) signatures_match, cls_subset, dataclass_subset = result @@ -195,7 +215,10 @@ def test_RNNTDecoder(self): pred_config = OmegaConf.create( { '_target_': 'nemo.collections.asr.modules.RNNTDecoder', - 'prednet': {'pred_hidden': 32, 'pred_rnn_layers': 1,}, + 'prednet': { + 'pred_hidden': 32, + 'pred_rnn_layers': 1, + }, 'vocab_size': vocab_size, 'blank_as_pad': True, } From 964ea3cb5faab50791d08226ec49741418774aa8 Mon Sep 17 00:00:00 2001 From: Pablo Garay Date: Wed, 15 May 2024 21:57:22 -0700 Subject: [PATCH 15/18] run_cicd_for_release_branches_also (#9213) --- .github/workflows/cicd-main.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 291eeaed7f895..8430dae564184 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -15,7 +15,9 @@ name: "CICD NeMo" on: pull_request: - branches: [ "main" ] + branches: + - 'main' + - 'r**' types: [ labeled ] concurrency: From d0a453531e686cc7d126600b42fb3d385b20a6ae Mon Sep 17 00:00:00 2001 From: Jan Lasek Date: Thu, 16 May 2024 18:11:45 +0200 Subject: [PATCH 16/18] Update nemo.export module for quantized models (#9218) * Remove config aligner - no longer needed after TRT-LLM 0.9 update Signed-off-by: Jan Lasek * Change default export precision to bf16 (more frequent) Signed-off-by: Jan Lasek * Specify gpt_attention_plugin Signed-off-by: Jan Lasek --------- Signed-off-by: Jan Lasek --- .../conf/megatron_quantization.yaml | 2 +- nemo/export/trt_llm/qnemo/__init__.py | 1 - nemo/export/trt_llm/qnemo/align_config.py | 46 ------------------- .../trt_llm/qnemo/qnemo_to_tensorrt_llm.py | 40 ++-------------- 4 files changed, 5 insertions(+), 84 deletions(-) delete mode 100644 nemo/export/trt_llm/qnemo/align_config.py diff --git a/examples/nlp/language_modeling/conf/megatron_quantization.yaml b/examples/nlp/language_modeling/conf/megatron_quantization.yaml index 79a5bfbd8fe6a..88d10ae0a66cd 100644 --- a/examples/nlp/language_modeling/conf/megatron_quantization.yaml +++ b/examples/nlp/language_modeling/conf/megatron_quantization.yaml @@ -31,7 +31,7 @@ export: decoder_type: llama # gptnext, gpt2, llama inference_tensor_parallel: 1 # Default using 1 TP for inference inference_pipeline_parallel: 1 # Default using 1 PP for inference - dtype: 16 # Default precision data type + dtype: bf16 # Default precision data type model_file: llama2-7b-fp16.nemo # 
Nemo file path model_save: llama2-7b-fp8.qnemo # Path where the quantized model will be saved diff --git a/nemo/export/trt_llm/qnemo/__init__.py b/nemo/export/trt_llm/qnemo/__init__.py index 77832d749b662..59b9eb8ae6a6a 100644 --- a/nemo/export/trt_llm/qnemo/__init__.py +++ b/nemo/export/trt_llm/qnemo/__init__.py @@ -12,5 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .align_config import align_config from .qnemo_to_tensorrt_llm import qnemo_to_tensorrt_llm diff --git a/nemo/export/trt_llm/qnemo/align_config.py b/nemo/export/trt_llm/qnemo/align_config.py deleted file mode 100644 index abc53224e4b30..0000000000000 --- a/nemo/export/trt_llm/qnemo/align_config.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -from typing import Any, Dict - - -def align_config(config_trtllm_build: Dict[str, Any]) -> Dict[str, Any]: - """Function to align config produced by trtllm-build API for consistency - with how ModelConfig from tensorrt_llm.runtime is used in the project. - """ - config = {} - - config_trtllm_build = copy.deepcopy(config_trtllm_build) - - # Builder config - config["builder_config"] = {} - config["builder_config"]["name"] = "NeMo" - config["builder_config"].update(config_trtllm_build["build_config"]) - config["builder_config"].update(config_trtllm_build["pretrained_config"]) - - # Plugin config - config["plugin_config"] = config["builder_config"].pop("plugin_config") - - # Parallelism config - config["builder_config"]["world_size"] = config["builder_config"]["mapping"]["world_size"] - config["builder_config"]["tensor_parallel"] = config["builder_config"]["mapping"]["tp_size"] - config["builder_config"]["pipeline_parallel"] = config["builder_config"]["mapping"]["pp_size"] - - # Other parameters - config["builder_config"]["num_heads"] = config_trtllm_build["pretrained_config"]["num_attention_heads"] - config["builder_config"]["num_layers"] = config_trtllm_build["pretrained_config"]["num_hidden_layers"] - config["builder_config"]["add_bos"] = False - config["builder_config"]["precision"] = config["builder_config"]["dtype"] - return config diff --git a/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py b/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py index 4e74d8e5fb58b..b7e2f7bc29739 100644 --- a/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py +++ b/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py @@ -15,13 +15,10 @@ import json import os import subprocess -from typing import List, Optional -from nemo.export.trt_llm.qnemo import align_config -from nemo.export.trt_llm.tensorrt_llm_build import MODEL_NAME, get_engine_name +from typing import List, Optional CONFIG_NAME = "config.json" -CONFIG_TRTLLM_BUILD_NAME = "config_trtllm_build.json" def qnemo_to_tensorrt_llm( @@ -34,6 +31,7 @@ def qnemo_to_tensorrt_llm( lora_target_modules: Optional[List[str]] = None, ): """Build TRT-LLM engine via trtllm-build CLI API in 
a subprocess.""" + assert not lora_target_modules, f"LoRA is not supported for quantized checkpoints, got {lora_target_modules}" print( "Note that setting n_gpus, tensor_parallel_size and pipeline_parallel_size parameters" " for quantized models is possible only on export step via nemo.export.quantize module." @@ -58,6 +56,8 @@ def qnemo_to_tensorrt_llm( str(max_prompt_embedding_table_size), "--gemm_plugin", model_config["dtype"], + "--gpt_attention_plugin", + model_config["dtype"], "--strongly_typed", "--use_custom_all_reduce", "disable", @@ -75,35 +75,3 @@ def qnemo_to_tensorrt_llm( print("Building engine done. Full logs are:") print(result.stdout.decode()) - - # Alignment to make nemo-fw tensorrt_llm.runtime ModelConfig definition compatible with config - # produced by trtllm-build API. The new config is saved as "config.json" while the source build - # config is saved as "config_trtllm_build.json" in the engine directory for reference. - os.rename(os.path.join(engine_dir, CONFIG_NAME), os.path.join(engine_dir, CONFIG_TRTLLM_BUILD_NAME)) - with open(os.path.join(engine_dir, CONFIG_TRTLLM_BUILD_NAME), "r") as f: - config_trtllm_build = json.load(f) - - config = align_config(config_trtllm_build) - - # Other parameters - assert lora_target_modules is None - config["builder_config"]["lora_target_modules"] = lora_target_modules - - with open(os.path.join(engine_dir, CONFIG_NAME), "w") as f: - json.dump(config, f, indent=2) - - # Rename for consistency with how engine is run later - for i in range(config["builder_config"]["world_size"]): - os.rename( - os.path.join(engine_dir, f"rank{i}.engine"), - os.path.join( - engine_dir, - get_engine_name( - MODEL_NAME, - config["builder_config"]["precision"], - config["builder_config"]["tensor_parallel"], - config["builder_config"]["pipeline_parallel"], - i, - ), - ), - ) From b489fba96227657b3d04ab71e390cb017bbcf685 Mon Sep 17 00:00:00 2001 From: jgerh <163925524+jgerh@users.noreply.github.com> Date: Thu, 16 May 2024 10:26:29 -0700 Subject: [PATCH 17/18] Update index.rst (#9080) Removed best-practices.rst file Signed-off-by: jgerh <163925524+jgerh@users.noreply.github.com> Co-authored-by: Eric Harper --- docs/source/index.rst | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 82d3359480caa..eb586f7498423 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -41,7 +41,6 @@ For quick guides and tutorials, see the "Getting started" section below. :titlesonly: starthere/intro - starthere/best-practices starthere/tutorials For more information, browse the developer docs for your area of interest in the contents section below or on the left sidebar. 
@@ -86,4 +85,4 @@ For more information, browse the developer docs for your area of interest in the :name: Speech AI Tools :titlesonly: - tools/intro \ No newline at end of file + tools/intro From 526b6ade4bb078635d88feff76b5941c24db9e66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?oliver=20k=C3=B6nig?= Date: Thu, 16 May 2024 20:45:05 +0200 Subject: [PATCH 18/18] ci: Speeding NeMo-CI up by using caching (#9174) * build: Add `Dockerfile.ci` Signed-off-by: Oliver Koenig * ci: Build, push, and test ci image Signed-off-by: Oliver Koenig * chore: Disable cache dir for NeMo reinstall Signed-off-by: Oliver Koenig * revert: Modify `reinstall.sh` Signed-off-by: Oliver Koenig * fix: install modelopt[torch] instead of ammo Signed-off-by: Oliver Koenig * deduplicate requirements Signed-off-by: Oliver Koenig * make mcore/datasets Signed-off-by: Oliver Koenig --------- Signed-off-by: Oliver Koenig --- .github/workflows/cicd-main.yml | 123 ++++++++++---------------------- Dockerfile.ci | 74 +++++++++++++++++++ 2 files changed, 112 insertions(+), 85 deletions(-) create mode 100644 Dockerfile.ci diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 8430dae564184..ed2fc9f71f49a 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -75,92 +75,45 @@ jobs: uses: actions/checkout@v4 with: path: ${{ github.run_id }} - - - name: Container setup - run: | - # Pull base PyTorch container - docker pull nvcr.io/nvidia/pytorch:24.02-py3 - docker run --device=/dev/nvidia0 --gpus all --shm-size=8g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --env PYTHONUNBUFFERED=1 --volume ${{ github.workspace }}/${{ github.run_id }}:/workspace --volume /mnt/datadrive/TestData:/home/TestData nvcr.io/nvidia/pytorch:24.02-py3 /bin/bash -c ' - set -x - - # PyTorch version - python -c "import torch; print(torch.__version__)" - python -c "import torchvision; print(torchvision.__version__)" - - # Install test requirements - apt-get update && apt-get install -y bc && pip install -r requirements/requirements_test.txt && pip install -r requirements/requirements_lightning.txt - - # Code formatting checks - python setup.py style - - # Copyright Headers check - python tests/check_copyright_header.py --dir . - - # NeMo Installation - ./reinstall.sh release - - # Transformer Engine installation - git clone https://github.com/NVIDIA/TransformerEngine.git && \ - pushd TransformerEngine && \ - git fetch origin bfe21c3d68b0a9951e5716fb520045db53419c5e && \ - git checkout FETCH_HEAD && \ - git submodule init && git submodule update && \ - NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install . && \ - popd - - # Apex installation - git clone https://github.com/NVIDIA/apex.git && \ - pushd apex && \ - git checkout 810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c && \ - cp -R apex /usr/local/lib/python3.10/dist-packages && \ - popd - - # pip package should be working with main, if not we can update the commit here - # until the pip package is updated - # Megatron Core installation - git clone https://github.com/NVIDIA/Megatron-LM.git && \ - pushd Megatron-LM && \ - git checkout c90aa1671fc0b97f80fa6c3bb892ce6f8e88e7c9 && \ - pip install . 
&& \ - pushd megatron/core/datasets && \ - make && \ - popd && \ - popd - export PYTHONPATH="${PYTHONPATH}:/workspace/Megatron-LM" - - # Install only for test: L2: Segmentation Tool - pushd tools/ctc_segmentation && \ - pip install -r requirements.txt && \ - apt-get update && apt-get install libsox-fmt-all -y && \ - popd - - # ModelOpt installation - pip install nvidia-modelopt[torch]~=0.11.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir - - # PyTorch Lightning version - python -c "import pytorch_lightning; print(pytorch_lightning.__version__)" - - # PyTorch Lightning DDP Checks - CUDA_VISIBLE_DEVICES="0,1" python "tests/core_ptl/check_for_ranks.py" - - # Basic Import Checks - python -c "import nemo.collections.asr as nemo_asr" - python -c "import nemo.collections.nlp as nemo_nlp" - python -c "import nemo.collections.tts as nemo_tts" - - # set permission - chmod 777 -R /workspace - ' - ### \'\' - - - name: Push container to registry for future use + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + with: + # We use `docker` driver as this speeds things up for + # trivial (non-multi-stage) builds. + driver: docker + + - name: Build and push + uses: docker/build-push-action@v5 + with: + file: Dockerfile.ci + push: true + cache-from: nemoci.azurecr.io/nemo_container:latest + cache-to: type=inline + tags: | + nemoci.azurecr.io/nemo_container_${{ github.run_id }} + nemoci.azurecr.io/nemo_container:latest + + - name: Run some checks run: | - # Push container - echo "Docker: List containers" && docker ps -a - DOCKER_COMMIT=$(docker ps --latest --quiet) # latest container - docker commit $DOCKER_COMMIT nemoci.azurecr.io/nemo_container_${{ github.run_id }} - docker tag nemoci.azurecr.io/nemo_container_${{ github.run_id }} nemoci.azurecr.io/nemo_container_${{ github.run_id }} - docker push nemoci.azurecr.io/nemo_container_${{ github.run_id }} + docker run --rm --device=/dev/nvidia0 --gpus all --shm-size=8g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --env PYTHONUNBUFFERED=1 nemoci.azurecr.io/nemo_container_${{ github.run_id }} bash -c '\ + # PyTorch Lightning version + python -c "import pytorch_lightning; print(pytorch_lightning.__version__)" + + # PyTorch Lightning DDP Checks + CUDA_VISIBLE_DEVICES="0,1" python "tests/core_ptl/check_for_ranks.py" + + # Basic Import Checks + python -c "import nemo.collections.asr as nemo_asr" + python -c "import nemo.collections.nlp as nemo_nlp" + python -c "import nemo.collections.tts as nemo_tts" + + python setup.py style + python tests/check_copyright_header.py --dir . + + # These checks are not crucial + exit 0 + ' # - name: Build and push to local registry # uses: docker/build-push-action@v5 diff --git a/Dockerfile.ci b/Dockerfile.ci new file mode 100644 index 0000000000000..5b2cd8d6eb616 --- /dev/null +++ b/Dockerfile.ci @@ -0,0 +1,74 @@ +# syntax=docker/dockerfile:1-labs + +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.02-py3 + +FROM ${BASE_IMAGE} + +ENV TRANSFORMERS_OFFLINE=0 +ENV HYDRA_FULL_ERROR=1 +ENV PYTHONUNBUFFERED=1 + +# APT packages +RUN <<"EOF" bash -ex +apt-get update +apt-get install -y bc libsox-fmt-all -y +apt-get clean +EOF + +WORKDIR /workspace + +# Install NeMo requirements +ARG TE_TAG=bfe21c3d68b0a9951e5716fb520045db53419c5e +ARG MODELOPT_VERSION=0.11.0 +ARG MCORE_TAG=c90aa1671fc0b97f80fa6c3bb892ce6f8e88e7c9 +ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c +RUN \ +--mount=type=bind,source=requirements,target=requirements \ +--mount=type=bind,source=tools,target=tools \ +--mount=type=bind,source=setup.py,target=setup.py \ +--mount=type=bind,source=nemo/package_info.py,target=nemo/package_info.py \ +--mount=type=bind,source=nemo/__init__.py,target=nemo/__init__.py <<"EOF" bash -ex +pip install --no-cache-dir --no-build-isolation --extra-index-url https://pypi.nvidia.com \ +"transformer-engine @ git+https://github.com/NVIDIA/TransformerEngine.git@${TE_TAG}" \ +"megatron_core @ git+https://github.com/NVIDIA/Megatron-LM.git@${MCORE_TAG}" \ +"nvidia-modelopt[torch]~=${MODELOPT_VERSION}" \ +"apex @ git+https://github.com/NVIDIA/apex.git@${APEX_TAG}" \ +-r tools/ctc_segmentation/requirements.txt \ +".[all]" + +# Megatron Core installation +git clone https://github.com/NVIDIA/Megatron-LM.git && \ +pushd Megatron-LM && \ +git checkout ${MCORE_TAG} && \ + pushd megatron/core/datasets && \ + make && \ + popd && \ +popd +export PYTHONPATH="${PYTHONPATH}:/workspace/Megatron-LM" +EOF + +# Copy over NeMo code +COPY ./ ./ +RUN <<"EOF" bash -ex +pip install --no-cache-dir --no-build-isolation ".[all]" + +# set permission +chmod 777 -R /workspace +EOF + +ENV PYTHONPATH="${PYTHONPATH}:/workspace/Megatron-LM" +
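
As a usage note for the CI caching patch above: the sketch below shows one way the new `Dockerfile.ci` image could be built and smoke-tested locally, mirroring the `cache-from` flow and the import checks in the updated `cicd-main.yml`. The registry name `nemoci.azurecr.io/nemo_container:latest` and the check commands are taken from the workflow; the local tag `nemo_container:local` and the exact invocation are illustrative assumptions, not a documented NeMo procedure.

```bash
# Illustrative local build of the CI image introduced in this patch series.
# Assumes Docker with Buildx and GPU access; names mirror the workflow above.
docker buildx build \
  --file Dockerfile.ci \
  --cache-from nemoci.azurecr.io/nemo_container:latest \
  --tag nemo_container:local \
  .

# Quick smoke test, similar to the "Run some checks" step in cicd-main.yml.
docker run --rm --gpus all nemo_container:local bash -c '
  python -c "import pytorch_lightning; print(pytorch_lightning.__version__)"
  python -c "import nemo.collections.asr as nemo_asr"
  python -c "import nemo.collections.nlp as nemo_nlp"
  python -c "import nemo.collections.tts as nemo_tts"
'
```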