From 85d8756a8121bb4bb021c4d2b89bc5f290e7c571 Mon Sep 17 00:00:00 2001 From: yaoyu-33 <54727607+yaoyu-33@users.noreply.github.com> Date: Wed, 31 Jan 2024 12:04:14 -0800 Subject: [PATCH] Add Bert HF checkpoint converter (#8088) * Add Bert HF checkpoint converter Signed-off-by: yaoyu-33 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Reformat Signed-off-by: yaoyu-33 * Add BERT ONNX export * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add NeMo BERT to HF BERT script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Clean code Signed-off-by: yaoyu-33 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update argument names Signed-off-by: yaoyu-33 * Update build_transformer_config in Bert Signed-off-by: yaoyu-33 --------- Signed-off-by: yaoyu-33 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Bobby Chen --- .../conf/megatron_bert_config.yaml | 8 +- .../language_modeling/megatron/bert_model.py | 10 + .../language_modeling/megatron_bert_model.py | 63 ++++ .../modules/common/megatron/transformer.py | 27 +- .../convert_bert_hf_to_nemo.py | 289 ++++++++++++++++++ .../convert_bert_nemo_to_hf.py | 269 ++++++++++++++++ .../export_nemo_bert_to_onnx.py | 83 +++++ 7 files changed, 745 insertions(+), 4 deletions(-) create mode 100644 scripts/nlp_language_modeling/convert_bert_hf_to_nemo.py create mode 100644 scripts/nlp_language_modeling/convert_bert_nemo_to_hf.py create mode 100644 scripts/nlp_language_modeling/export_nemo_bert_to_onnx.py diff --git a/examples/nlp/language_modeling/conf/megatron_bert_config.yaml b/examples/nlp/language_modeling/conf/megatron_bert_config.yaml index d388fe35b963..b3e3912fffd4 100644 --- a/examples/nlp/language_modeling/conf/megatron_bert_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_bert_config.yaml @@ -2,7 +2,7 @@ name: megatron_bert restore_from_path: null # used when starting from a .nemo file trainer: - devices: 2 + devices: 1 num_nodes: 1 accelerator: gpu precision: 16 @@ -56,15 +56,19 @@ model: hidden_size: 768 ffn_hidden_size: 3072 # Transformer FFN hidden size. Usually 4 * hidden_size. num_attention_heads: 12 + skip_head: False + transformer_block_type: post_ln init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') hidden_dropout: 0.1 # Dropout probability for hidden state transformer. kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: layernorm layernorm_epsilon: 1e-5 make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. pre_process: True # add embedding post_process: True # add pooler bert_binary_head: True # BERT binary head + megatron_legacy: False tokenizer: library: 'megatron' @@ -128,7 +132,7 @@ model: # - /raid/data/pile/my-gpt3_00_text_document # - .5 # - /raid/data/pile/my-gpt3_01_text_document - data_prefix: ??? + data_prefix: [1.0, /path/to/data] index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix data_impl: mmap splits_string: 900,50,50 diff --git a/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py index 22cfd7fb8efa..7e928a4e893b 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py @@ -65,6 +65,9 @@ def bert_extended_attention_mask(attention_mask): # [b, 1, s, s] extended_attention_mask = attention_mask_bss.unsqueeze(1) + # HF Masking is equivalent to the one below + # extended_attention_mask = (attention_mask.unsqueeze(1) * torch.ones_like(attention_mask).unsqueeze(2)).unsqueeze(1) + # Convert attention mask to binary: extended_attention_mask = extended_attention_mask < 0.5 @@ -182,12 +185,15 @@ def __init__( activations_checkpoint_num_layers=1, activations_checkpoint_layers_per_pipeline=None, layernorm_epsilon=1e-5, + normalization='layernorm', + transformer_block_type='pre_ln', masked_softmax_fusion=False, bias_gelu_fusion=True, bias_dropout_add_fusion=True, openai_gelu=False, onnx_safe=False, add_binary_head=True, + skip_head=False, megatron_legacy=False, sequence_parallel=False, position_embedding_type='learned_absolute', @@ -229,6 +235,8 @@ def __init__( activations_checkpoint_num_layers=activations_checkpoint_num_layers, activations_checkpoint_layers_per_pipeline=activations_checkpoint_layers_per_pipeline, layernorm_epsilon=layernorm_epsilon, + normalization=normalization, + transformer_block_type=transformer_block_type, masked_softmax_fusion=masked_softmax_fusion, bias_activation_fusion=bias_gelu_fusion, bias_dropout_add_fusion=bias_dropout_add_fusion, @@ -242,6 +250,8 @@ def __init__( init_method=init_method_normal(init_method_std), vocab_size=vocab_size, hidden_size=hidden_size ) + if skip_head: + self.post_process = False if self.post_process: self.lm_head = BertLMHead( config, diff --git a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py index e4ae0f87d353..bef13367eb10 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py @@ -184,10 +184,13 @@ def model_provider_func(self, pre_process, post_process): ), layernorm_epsilon=cfg.get('layernorm_epsilon', 1e-5), masked_softmax_fusion=cfg.get('masked_softmax_fusion', True), + normalization=cfg.get('normalization', 'layernorm'), + transformer_block_type=cfg.get('transformer_block_type', 'pre_ln'), bias_gelu_fusion=cfg.get('bias_gelu_fusion', True), bias_dropout_add_fusion=cfg.get("bias_dropout_add_fusion", True), onnx_safe=cfg.get('onnx_safe', False), add_binary_head=cfg.bert_binary_head, + skip_head=cfg.get('skip_head', False), megatron_legacy=cfg.get('megatron_legacy', False), position_embedding_type=self.cfg.get("position_embedding_type", "learned_absolute"), ) @@ -1034,5 +1037,65 @@ def build_transformer_config(self) -> TransformerConfig: """ activation = self.cfg.get('activation', 'gelu') assert activation == 'gelu', "Only gelu activation is support for BERT at the moment." + + normalization = self.cfg.get('normalization', 'layernorm') + + layernorm_zero_centered_gamma = self.cfg.get('normalization', 'layernorm') == 'layernorm1p' + if normalization == 'layernorm': + normalization = 'LayerNorm' + elif normalization == 'rmsnorm': + normalization = 'RMSNorm' + elif normalization == 'layernorm1p': + normalization = 'LayerNorm' + layernorm_zero_centered_gamma = True + else: + logging.warning( + f"The normalization type: {normalization} might not be supported in megatron core." + f"Supported types are LayerNorm and RMSNorm." + ) + + # any configs that are not in the nemo model config will be added here + model_specific_configs = { + 'layernorm_zero_centered_gamma': layernorm_zero_centered_gamma, + 'normalization': normalization, + } + transformer_config = super().build_transformer_config() + + for key, value in model_specific_configs.items(): + setattr(transformer_config, key, value) + + # pass mcore customization configs directly to mcore + mcore_customization_config_dict = self.cfg.get('mcore_customization_config', {}) + for key, value in mcore_customization_config_dict.items(): + setattr(transformer_config, key, value) + return transformer_config + + +class MegatronBertTextEmbeddingModel(MegatronBertModel): + """ + Megatron Bert Text Embedding. + Model returns [batch, hidden] shape + """ + + def average_pool(self, last_hidden_states, attention_mask): + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + + def forward( + self, + input_ids, + attention_mask, + token_type_ids, + lm_labels=None, + checkpoint_activations_all_layers=None, + model=None, + ): + outputs = super().forward( + input_ids, attention_mask, token_type_ids, lm_labels, checkpoint_activations_all_layers, model + ) + embeddings = self.average_pool(outputs[0], attention_mask) + embeddings = F.normalize(embeddings, p=2, dim=1) + + return embeddings diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index ca8c0ecafefd..9e9c7b526782 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -625,7 +625,6 @@ def forward( ) output = bias_dropout_add_func(mlp_output, mlp_bias, residual, self.hidden_dropout) - # print(f"Layer: {self.layer_number} MLP + Dropout + Residual checksum {output.sum()}") if self.transformer_block_type == 'post_ln': output = self.post_attention_layernorm(output) @@ -1158,6 +1157,27 @@ def build_layer(layer_number): offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)]) + if self.pre_process and self.transformer_block_type == 'post_ln': + # Final layer norm before output. + if normalization == 'layernorm': + self.initial_layernorm = get_layer_norm( + hidden_size, layernorm_epsilon, persist_layer_norm, sequence_parallel=config.sequence_parallel + ) + + elif normalization == 'layernorm1p': + self.initial_layernorm = LayerNorm1P( + hidden_size, layernorm_epsilon, sequence_parallel_enabled=config.sequence_parallel + ) + elif normalization == 'low_precision_layernorm': + self.initial_layernorm = LPLayerNorm(hidden_size, layernorm_epsilon) + else: + self.initial_layernorm = MixedFusedRMSNorm(hidden_size, layernorm_epsilon) + # for architectures such as MPT, there is no bias term even on the layernorms + # this code allows us to remove the bias terms from the layernorm module + # so that we can support MPT. However, certain apex-based LNs don't support + # removing bias, so we also have to check for that + if not bias and normalization not in ['layernorm', 'layernorm1p']: + remove_bias_from_layernorm(self.initial_layernorm) if self.post_process and self.transformer_block_type != 'post_ln': # Final layer norm before output. @@ -1435,7 +1455,10 @@ def forward( 'get_key_value does not work with ' 'activation checkpointing' ) - if not self.pre_process: + if self.pre_process: + if self.transformer_block_type == 'post_ln': + hidden_states = self.initial_layernorm(hidden_states) + else: # See set_input_tensor() hidden_states = self.input_tensor diff --git a/scripts/nlp_language_modeling/convert_bert_hf_to_nemo.py b/scripts/nlp_language_modeling/convert_bert_hf_to_nemo.py new file mode 100644 index 000000000000..cc9483b68c8a --- /dev/null +++ b/scripts/nlp_language_modeling/convert_bert_hf_to_nemo.py @@ -0,0 +1,289 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example to run this conversion script: +``` + python convert_bert_hf_to_nemo.py \ + --input_name_or_path "thenlper/gte-large" \ + --output_path /path/to/output/nemo/file.nemo \ + --precision 32 +``` +""" + +import os +from argparse import ArgumentParser + +import torch +import torch.nn.functional as F +from omegaconf import OmegaConf +from transformers import AutoModel, AutoTokenizer + +from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.utils import logging + + +def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + + +def create_rename_keys(num_hidden_layers): + rename_keys = [] + for i in range(num_hidden_layers): + # encoder layers: attention mechanism, 2 feedforward neural networks, and 2 layernorms + rename_keys.extend( + [ + ( + f"encoder.layer.{i}.attention.self.query.weight", + f"model.language_model.encoder.layers.{i}.self_attention.query.weight", + ), + ( + f"encoder.layer.{i}.attention.self.query.bias", + f"model.language_model.encoder.layers.{i}.self_attention.query.bias", + ), + ( + f"encoder.layer.{i}.attention.self.key.weight", + f"model.language_model.encoder.layers.{i}.self_attention.key.weight", + ), + ( + f"encoder.layer.{i}.attention.self.key.bias", + f"model.language_model.encoder.layers.{i}.self_attention.key.bias", + ), + ( + f"encoder.layer.{i}.attention.self.value.weight", + f"model.language_model.encoder.layers.{i}.self_attention.value.weight", + ), + ( + f"encoder.layer.{i}.attention.self.value.bias", + f"model.language_model.encoder.layers.{i}.self_attention.value.bias", + ), + ( + f"encoder.layer.{i}.attention.output.dense.weight", + f"model.language_model.encoder.layers.{i}.self_attention.dense.weight", + ), + ( + f"encoder.layer.{i}.attention.output.dense.bias", + f"model.language_model.encoder.layers.{i}.self_attention.dense.bias", + ), + ( + f"encoder.layer.{i}.attention.output.LayerNorm.weight", + f"model.language_model.encoder.layers.{i}.input_layernorm.weight", + ), + ( + f"encoder.layer.{i}.attention.output.LayerNorm.bias", + f"model.language_model.encoder.layers.{i}.input_layernorm.bias", + ), + ( + f"encoder.layer.{i}.intermediate.dense.weight", + f"model.language_model.encoder.layers.{i}.mlp.dense_h_to_4h.weight", + ), + ( + f"encoder.layer.{i}.intermediate.dense.bias", + f"model.language_model.encoder.layers.{i}.mlp.dense_h_to_4h.bias", + ), + ( + f"encoder.layer.{i}.output.dense.weight", + f"model.language_model.encoder.layers.{i}.mlp.dense_4h_to_h.weight", + ), + ( + f"encoder.layer.{i}.output.dense.bias", + f"model.language_model.encoder.layers.{i}.mlp.dense_4h_to_h.bias", + ), + ( + f"encoder.layer.{i}.output.LayerNorm.weight", + f"model.language_model.encoder.layers.{i}.post_attention_layernorm.weight", + ), + ( + f"encoder.layer.{i}.output.LayerNorm.bias", + f"model.language_model.encoder.layers.{i}.post_attention_layernorm.bias", + ), + ] + ) + + # Non-layer dependent keys + rename_keys.extend( + [ + ("embeddings.word_embeddings.weight", "model.language_model.embedding.word_embeddings.weight"), + ("embeddings.position_embeddings.weight", "model.language_model.embedding.position_embeddings.weight"), + ("embeddings.token_type_embeddings.weight", "model.language_model.embedding.tokentype_embeddings.weight"), + ("embeddings.LayerNorm.weight", "model.language_model.encoder.initial_layernorm.weight"), + ("embeddings.LayerNorm.bias", "model.language_model.encoder.initial_layernorm.bias"), + ("pooler.dense.weight", "model.language_model.pooler.dense.weight"), + ("pooler.dense.bias", "model.language_model.pooler.dense.bias"), + ] + ) + + return rename_keys + + +def rename_model_keys(model_state_dict, rename_keys): + """ + Rename keys in the model's state dictionary based on the provided mappings. + + Parameters: + model_state_dict (dict): The state dictionary of the model. + rename_keys (list): A list of tuples with the mapping (old_key, new_key). + + Returns: + dict: A new state dictionary with updated key names. + """ + + # Create a new state dictionary with updated key names + new_state_dict = {} + + # Track keys from the original state dict to ensure all are processed + remaining_keys = set(model_state_dict.keys()) + + # Iterate over the rename mappings + for old_key, new_key in rename_keys: + if old_key in model_state_dict: + # Rename the key and remove it from the tracking set + new_state_dict[new_key] = model_state_dict[old_key] + remaining_keys.remove(old_key) + else: + print(f"Warning: Key '{old_key}' not found in the model state dictionary.") + + # Check if any keys were not converted from old to new + for old_key in remaining_keys: + print(f"Warning: Key '{old_key}' was not converted.") + + return new_state_dict + + +def adjust_tensor_shapes(model, nemo_state_dict): + """ + Adapt tensor shapes in the state dictionary to ensure compatibility with a different model structure. + + Parameters: + nemo_state_dict (dict): The state dictionary of the model. + + Returns: + dict: The updated state dictionary with modified tensor shapes for compatibility. + """ + + # Note: For 'key' and 'value' weights and biases, NeMo uses a consolidated tensor 'query_key_value'. + for key_ in list(nemo_state_dict.keys()): + if "self_attention.query" in key_: + key_q = key_ + key_k = key_.replace('self_attention.query', 'self_attention.key') + key_v = key_.replace('self_attention.query', 'self_attention.value') + key_new = key_.replace('self_attention.query', 'self_attention.query_key_value') + value_new = torch.concat((nemo_state_dict[key_q], nemo_state_dict[key_k], nemo_state_dict[key_v]), dim=0) + nemo_state_dict[key_new] = value_new + del nemo_state_dict[key_q], nemo_state_dict[key_k], nemo_state_dict[key_v] + + # Padding to new vocab size + original_embedding = nemo_state_dict['model.language_model.embedding.word_embeddings.weight'] + vocab_size = original_embedding.size(0) + if model.padded_vocab_size > vocab_size: + zeros_to_add = torch.zeros( + model.padded_vocab_size - vocab_size, + original_embedding.size(1), + dtype=original_embedding.dtype, + device=original_embedding.device, + ) + # Concatenate the two tensors along rows + padded_embedding = torch.cat([original_embedding, zeros_to_add], dim=0) + nemo_state_dict['model.language_model.embedding.word_embeddings.weight'] = padded_embedding + + return nemo_state_dict + + +def adjust_nemo_config(model_config, ref_config): + model_config.tokenizer["type"] = "intfloat/e5-large-unsupervised" # ref_config["_input_name_or_path"] + model_config["num_layers"] = ref_config["num_hidden_layers"] + model_config["hidden_size"] = ref_config["hidden_size"] + model_config["ffn_hidden_size"] = ref_config["intermediate_size"] + model_config["num_attention_heads"] = ref_config["num_attention_heads"] + model_config["layernorm_epsilon"] = ref_config["layer_norm_eps"] + model_config["normalization"] = "layernorm" + model_config["transformer_block_type"] = "post_ln" + model_config["apply_query_key_layer_scaling"] = False + model_config["skip_head"] = True + model_config["megatron_legacy"] = True + return model_config + + +def get_args(): + parser = ArgumentParser() + parser.add_argument("--input_name_or_path", type=str, default="thenlper/gte-large") + parser.add_argument( + "--hparams_file", + type=str, + default=f"{os.path.dirname(__file__)}/../../examples/nlp/language_modeling/conf/megatron_bert_config.yaml", + required=False, + help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml", + ) + parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.") + parser.add_argument( + "--precision", type=str, default="32", choices=["bf16", "32"], help="Precision for checkpoint weights saved" + ) + + args = parser.parse_args() + return args + + +def convert(args): + logging.info(f"Loading checkpoint from HF: `{args.input_name_or_path}`") + hf_tokenizer = AutoTokenizer.from_pretrained(args.input_name_or_path) + hf_model = AutoModel.from_pretrained(args.input_name_or_path) + + nemo_config = OmegaConf.load(args.hparams_file) + nemo_config.model = adjust_nemo_config(nemo_config.model, hf_model.config.to_dict()) + + nemo_config.trainer["precision"] = args.precision + trainer = MegatronTrainerBuilder(nemo_config).create_trainer() + model = MegatronBertModel(nemo_config.model, trainer) + + old_state_dict = hf_model.state_dict() + rename_keys = create_rename_keys(nemo_config.model.num_layers) + new_state_dict = rename_model_keys(model_state_dict=old_state_dict, rename_keys=rename_keys) + nemo_state_dict = adjust_tensor_shapes(model, new_state_dict) + model.load_state_dict(nemo_state_dict, strict=True) + + logging.info(f'=' * 50) + # Verifications + input_texts = [ + 'query: how much protein should a female eat', + 'query: summit define', + "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.", + "passage: Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.", + ] + + # Tokenize the input texts + batch_dict = hf_tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt') + batch_dict_cuda = {k: v.cuda() for k, v in batch_dict.items()} + hf_model = hf_model.cuda().eval() + model = model.eval() + with torch.no_grad(): + hf_outputs = hf_model(**batch_dict_cuda) + embeddings_hf = average_pool(hf_outputs.last_hidden_state, batch_dict_cuda['attention_mask']) + embeddings_hf = F.normalize(embeddings_hf, p=2, dim=1) + + outputs = model(**batch_dict_cuda) + embeddings = average_pool(outputs[0], batch_dict_cuda['attention_mask']) + embeddings = F.normalize(embeddings, p=2, dim=1) + # Print difference between two embeddings + print("Difference between reference embedding and converted embedding results:") + print(embeddings - embeddings_hf) + + model.save_to(args.output_path) + logging.info(f'NeMo model saved to: {args.output_path}') + + +if __name__ == '__main__': + args = get_args() + convert(args) diff --git a/scripts/nlp_language_modeling/convert_bert_nemo_to_hf.py b/scripts/nlp_language_modeling/convert_bert_nemo_to_hf.py new file mode 100644 index 000000000000..e970ea29fca2 --- /dev/null +++ b/scripts/nlp_language_modeling/convert_bert_nemo_to_hf.py @@ -0,0 +1,269 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example to run this conversion script: +``` + python convert_bert_hf_to_nemo.py \ + --input_name_or_path /path/to/input/nemo/file.nemo \ + --output_path /path/to/output/huggingface/file \ + --precision 32 +``` +""" + +from argparse import ArgumentParser + +import torch +import torch.nn.functional as F +from pytorch_lightning import Trainer +from transformers import AutoTokenizer, BertConfig, BertModel + +from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy +from nemo.utils import logging + + +def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + + +def create_rename_keys(num_hidden_layers): + rename_keys = [] + for i in range(num_hidden_layers): + # encoder layers: attention mechanism, 2 feedforward neural networks, and 2 layernorms + rename_keys.extend( + [ + ( + f"encoder.layer.{i}.attention.self.query.weight", + f"model.language_model.encoder.layers.{i}.self_attention.query.weight", + ), + ( + f"encoder.layer.{i}.attention.self.query.bias", + f"model.language_model.encoder.layers.{i}.self_attention.query.bias", + ), + ( + f"encoder.layer.{i}.attention.self.key.weight", + f"model.language_model.encoder.layers.{i}.self_attention.key.weight", + ), + ( + f"encoder.layer.{i}.attention.self.key.bias", + f"model.language_model.encoder.layers.{i}.self_attention.key.bias", + ), + ( + f"encoder.layer.{i}.attention.self.value.weight", + f"model.language_model.encoder.layers.{i}.self_attention.value.weight", + ), + ( + f"encoder.layer.{i}.attention.self.value.bias", + f"model.language_model.encoder.layers.{i}.self_attention.value.bias", + ), + ( + f"encoder.layer.{i}.attention.output.dense.weight", + f"model.language_model.encoder.layers.{i}.self_attention.dense.weight", + ), + ( + f"encoder.layer.{i}.attention.output.dense.bias", + f"model.language_model.encoder.layers.{i}.self_attention.dense.bias", + ), + ( + f"encoder.layer.{i}.attention.output.LayerNorm.weight", + f"model.language_model.encoder.layers.{i}.input_layernorm.weight", + ), + ( + f"encoder.layer.{i}.attention.output.LayerNorm.bias", + f"model.language_model.encoder.layers.{i}.input_layernorm.bias", + ), + ( + f"encoder.layer.{i}.intermediate.dense.weight", + f"model.language_model.encoder.layers.{i}.mlp.dense_h_to_4h.weight", + ), + ( + f"encoder.layer.{i}.intermediate.dense.bias", + f"model.language_model.encoder.layers.{i}.mlp.dense_h_to_4h.bias", + ), + ( + f"encoder.layer.{i}.output.dense.weight", + f"model.language_model.encoder.layers.{i}.mlp.dense_4h_to_h.weight", + ), + ( + f"encoder.layer.{i}.output.dense.bias", + f"model.language_model.encoder.layers.{i}.mlp.dense_4h_to_h.bias", + ), + ( + f"encoder.layer.{i}.output.LayerNorm.weight", + f"model.language_model.encoder.layers.{i}.post_attention_layernorm.weight", + ), + ( + f"encoder.layer.{i}.output.LayerNorm.bias", + f"model.language_model.encoder.layers.{i}.post_attention_layernorm.bias", + ), + ] + ) + + # Non-layer dependent keys + rename_keys.extend( + [ + ("embeddings.word_embeddings.weight", "model.language_model.embedding.word_embeddings.weight"), + ("embeddings.position_embeddings.weight", "model.language_model.embedding.position_embeddings.weight"), + ("embeddings.token_type_embeddings.weight", "model.language_model.embedding.tokentype_embeddings.weight"), + ("embeddings.LayerNorm.weight", "model.language_model.encoder.initial_layernorm.weight"), + ("embeddings.LayerNorm.bias", "model.language_model.encoder.initial_layernorm.bias"), + ("pooler.dense.weight", "model.language_model.pooler.dense.weight"), + ("pooler.dense.bias", "model.language_model.pooler.dense.bias"), + ] + ) + + return rename_keys + + +def rename_model_keys(model_state_dict, rename_keys): + """ + Rename keys in the model's state dictionary based on the provided mappings. + + Parameters: + model_state_dict (dict): The state dictionary of the model. + rename_keys (list): A list of tuples with the mapping (new_key, old_key). + + Returns: + dict: A new state dictionary with updated key names. + """ + + # Create a new state dictionary with updated key names + new_state_dict = {} + + # Track keys from the original state dict to ensure all are processed + remaining_keys = set(model_state_dict.keys()) + + # Iterate over the rename mappings + for new_key, old_key in rename_keys: + if old_key in model_state_dict: + # Rename the key and remove it from the tracking set + new_state_dict[new_key] = model_state_dict[old_key] + remaining_keys.remove(old_key) + else: + print(f"Warning: Key '{old_key}' not found in the model state dictionary.") + + # Check if any keys were not converted from old to new + for old_key in remaining_keys: + print(f"Warning: Key '{old_key}' was not converted.") + + return new_state_dict + + +def adjust_tensor_shapes(model_state_dict): + """ + Adapt tensor shapes in the state dictionary to ensure compatibility with a different model structure. + + Parameters: + nemo_state_dict (dict): The state dictionary of the model. + + Returns: + dict: The updated state dictionary with modified tensor shapes for compatibility. + """ + + # Note: For 'key' and 'value' weights and biases, NeMo uses a consolidated tensor 'query_key_value'. + for key_ in list(model_state_dict.keys()): + if "self_attention.query_key_value" in key_: + key_q = key_.replace('self_attention.query_key_value', 'self_attention.query') + key_k = key_.replace('self_attention.query_key_value', 'self_attention.key') + key_v = key_.replace('self_attention.query_key_value', 'self_attention.value') + local_dim = model_state_dict[key_].shape[0] // 3 + q, k, v = model_state_dict[key_].split(local_dim) + model_state_dict[key_q] = q + model_state_dict[key_k] = k + model_state_dict[key_v] = v + del model_state_dict[key_] + + return model_state_dict + + +def convert_config(ref_config, hf_state_dict): + vocab_size = hf_state_dict['embeddings.word_embeddings.weight'].shape[0] + new_config = { + "vocab_size": vocab_size, + "num_hidden_layers": ref_config["num_layers"], + "hidden_size": ref_config["hidden_size"], + "intermediate_size": ref_config["ffn_hidden_size"], + "num_attention_heads": ref_config["num_attention_heads"], + "layer_norm_eps": ref_config["layernorm_epsilon"], + "max_position_embeddings": ref_config["max_position_embeddings"], + } + hf_config = BertConfig(**new_config) + return hf_config + + +def get_args(): + parser = ArgumentParser() + parser.add_argument( + "--input_name_or_path", type=str, required=True, help="Path to .nemo file", + ) + parser.add_argument( + "--output_path", type=str, required=True, help="Output HF model path", + ) + + args = parser.parse_args() + return args + + +def convert(args): + logging.info(f"Loading checkpoint from: `{args.input_name_or_path}`") + dummy_trainer = Trainer(devices=1, accelerator='cpu', strategy=NLPDDPStrategy()) + nemo_model = MegatronBertModel.restore_from(args.input_name_or_path, trainer=dummy_trainer) + nemo_config = nemo_model.cfg + + old_state_dict = nemo_model.state_dict() + rename_keys = create_rename_keys(nemo_config.num_layers) + new_state_dict = adjust_tensor_shapes(old_state_dict) + hf_state_dict = rename_model_keys(model_state_dict=new_state_dict, rename_keys=rename_keys) + + hf_config = convert_config(nemo_config, hf_state_dict) + hf_model = BertModel(hf_config) + + hf_model.load_state_dict(hf_state_dict, strict=True) + + logging.info(f'=' * 50) + # Verifications + input_texts = [ + 'query: how much protein should a female eat', + 'query: summit define', + "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.", + "passage: Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.", + ] + + # Tokenize the input texts + hf_tokenizer = AutoTokenizer.from_pretrained(nemo_config.tokenizer["type"]) + batch_dict = hf_tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt') + batch_dict_cuda = {k: v.cuda() for k, v in batch_dict.items()} + hf_model = hf_model.cuda().eval() + nemo_model = nemo_model.eval() + with torch.no_grad(): + hf_outputs = hf_model(**batch_dict_cuda) + embeddings_hf = average_pool(hf_outputs.last_hidden_state, batch_dict_cuda['attention_mask']) + embeddings_hf = F.normalize(embeddings_hf, p=2, dim=1) + + outputs = nemo_model(**batch_dict_cuda) + embeddings = average_pool(outputs[0], batch_dict_cuda['attention_mask']) + embeddings = F.normalize(embeddings, p=2, dim=1) + # Print difference between two embeddings + print("Difference between reference embedding and converted embedding results:") + print(embeddings - embeddings_hf) + + hf_model.save_pretrained(args.output_path) + logging.info(f'Full HF model model saved to: {args.output_path}') + + +if __name__ == '__main__': + args = get_args() + convert(args) diff --git a/scripts/nlp_language_modeling/export_nemo_bert_to_onnx.py b/scripts/nlp_language_modeling/export_nemo_bert_to_onnx.py new file mode 100644 index 000000000000..c6b3f351cc07 --- /dev/null +++ b/scripts/nlp_language_modeling/export_nemo_bert_to_onnx.py @@ -0,0 +1,83 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from argparse import ArgumentParser + +import torch +from omegaconf import OmegaConf + +from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertTextEmbeddingModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.utils import logging + + +def get_args(): + parser = ArgumentParser() + parser.add_argument("--nemo_path", type=str, required=True) + parser.add_argument( + "--hparams_file", + type=str, + default=f"{os.path.dirname(__file__)}/../../examples/nlp/language_modeling/conf/megatron_bert_config.yaml", + required=False, + help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml", + ) + parser.add_argument( + "--onnx_path", type=str, default="bert.onnx", required=False, help="Path to output .nemo file." + ) + parser.add_argument( + "--precision", type=str, default="32", choices=["bf16", "32"], help="Precision for checkpoint weights saved" + ) + + args = parser.parse_args() + return args + + +def export(args): + nemo_config = OmegaConf.load(args.hparams_file) + nemo_config.trainer["precision"] = args.precision + + trainer = MegatronTrainerBuilder(nemo_config).create_trainer() + model = MegatronBertTextEmbeddingModel.restore_from(args.nemo_path, trainer=trainer) + + hf_tokenizer = model.tokenizer.tokenizer + + logging.info(f'=' * 50) + # Verifications + input_texts = [ + 'query: how much protein should a female eat', + 'query: summit define', + "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.", + "passage: Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.", + ] + + # Tokenize the input texts + batch_dict = hf_tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt') + batch_dict_cuda = {k: v.cuda() for k, v in batch_dict.items()} + model = model.eval() + + input_names = ["input_ids", "attention_mask", "token_type_ids"] + output_names = ["outputs"] + export_input = tuple([batch_dict_cuda[name] for name in input_names]) + + torch.onnx.export( + model, export_input, args.onnx_path, verbose=False, input_names=input_names, output_names=output_names, + ) + logging.info(f'NeMo model saved to: {args.onnx_path}') + + +if __name__ == '__main__': + args = get_args() + export(args)