From 74e32c8a4fe368d7e66948a8e1258fd40ad0586c Mon Sep 17 00:00:00 2001 From: Adi Renduchintala Date: Wed, 10 Jul 2024 12:26:29 -0400 Subject: [PATCH] Contrastive Reranker/Reward model (#9171) * wip contrastive reranker Signed-off-by: arendu * wip Signed-off-by: arendu * wip Signed-off-by: arendu * working reranker training and validation Signed-off-by: arendu * default peft for reranker Signed-off-by: arendu * validation time update Signed-off-by: arendu * reranker test Signed-off-by: arendu * reranker inference Signed-off-by: arendu * reranker inference Signed-off-by: arendu * Apply isort and black reformatting Signed-off-by: arendu * updates Signed-off-by: arendu * Apply isort and black reformatting Signed-off-by: arendu * updates Signed-off-by: arendu * Apply isort and black reformatting Signed-off-by: arendu * also can support rlhf style reward model loss Signed-off-by: arendu * Apply isort and black reformatting Signed-off-by: arendu * Apply isort and black reformatting Signed-off-by: arendu * typo in cicd Signed-off-by: arendu --------- Signed-off-by: arendu Signed-off-by: arendu Signed-off-by: Adi Renduchintala Co-authored-by: arendu --- .github/workflows/cicd-main.yml | 41 +++ ...megatron_gpt_embedder_generate_config.yaml | 1 - .../megatron_gpt_embedder_tuning_config.yaml | 2 +- .../megatron_gpt_reranker_tuning_config.yaml | 222 +++++++++++++ .../megatron_gpt_embedding_generate.py | 5 +- .../megatron_gpt_reranker_finetuning.py | 76 +++++ .../megatron_gpt_reranker_generate.py | 138 ++++++++ .../tuning/megatron_gpt_finetuning.py | 2 +- .../gpt_embedding_dataset.py | 139 +++++++- .../megatron_gpt_embedding_model.py | 48 +-- .../megatron_gpt_reranker_model.py | 301 ++++++++++++++++++ .../language_modeling/megatron_gpt_model.py | 58 ++-- .../common/megatron/adapters/mcore_mixins.py | 33 ++ .../megatron/adapters/parallel_adapters.py | 65 +++- .../nlp/parts/mixins/nlp_adapter_mixins.py | 17 +- nemo/collections/nlp/parts/peft_config.py | 18 ++ 16 files changed, 1115 insertions(+), 51 deletions(-) create mode 100644 examples/nlp/information_retrieval/conf/megatron_gpt_reranker_tuning_config.yaml create mode 100644 examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py create mode 100644 examples/nlp/information_retrieval/megatron_gpt_reranker_generate.py create mode 100644 nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index bd794f59ae32..10cd8d1e6561 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -3198,6 +3198,47 @@ jobs: - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" if: "failure()" + L2_Megatron_GPT_Reranker: + needs: [cicd-test-container-setup] + runs-on: self-hosted-azure + timeout-minutes: 10 + container: + image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + options: + # --user 0:128 + --device=/dev/nvidia0 + --gpus all + --shm-size=8g + --env TRANSFORMERS_OFFLINE=0 + --env HYDRA_FULL_ERROR=1 + --volume /mnt/datadrive/TestData:/home/TestData + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - run: | + rm -rf /home/TestData/nlp/megatron_ir/working_dir + + python examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py \ + exp_manager.exp_dir='/home/TestData/nlp/megatron_ir/working_dir' \ + model.global_batch_size=4 \ + model.micro_batch_size=4 \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + trainer.max_epochs=null \ + trainer.max_steps=20 \ + trainer.val_check_interval=10 \ + model.restore_from_path='/home/TestData/nlp/megatron_gpt/mcore_45M/megatron_llama.nemo' \ + model.peft.lora_tuning.adapter_dim=8 \ + model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_ir/train.jsonl] \ + model.data.validation_ds.write_embeddings_to_file=True \ + model.data.validation_ds.output_file_path_prefix='/home/TestData/nlp/megatron_ir/working_dir/val_embs' \ + model.data.train_ds.file_names=[/home/TestData/nlp/megatron_ir/train.jsonl] + + + rm -rf /home/TestData/nlp/megatron_ir/working_dir + - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + if: "failure()" + L2_Megatron_GPT_Embedding: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml diff --git a/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_generate_config.yaml b/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_generate_config.yaml index 1a81d21dd9a8..e407aec167e9 100644 --- a/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_generate_config.yaml +++ b/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_generate_config.yaml @@ -120,7 +120,6 @@ model: tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre data: - return_output_tensors: True test_ds: query_file_names: ??? # Path to a list of JSONL files corresponding to the query data. Data format is identical to validation_ds. doc_file_names: ??? # Path to a list of JSONL files corresponding to the doc data. Data format is identical to validation_ds. diff --git a/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_tuning_config.yaml b/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_tuning_config.yaml index 6677dc2ed46c..1c2db1a862f4 100644 --- a/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_tuning_config.yaml +++ b/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_tuning_config.yaml @@ -84,6 +84,7 @@ model: use_flash_attention: True precision: bf16 apply_rope_fusion: False + reward_model_loss: False # Set this to true to perform RLHF style reward model loss -log(sigmoid(accept_logit - reject_logit)) peft: peft_scheme: "lora" # can be either adapter,ia3, or ptuning @@ -126,7 +127,6 @@ model: tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre data: - return_output_tensors: True train_ds: # Example of how to specify paths to multiple datasets # file_names: diff --git a/examples/nlp/information_retrieval/conf/megatron_gpt_reranker_tuning_config.yaml b/examples/nlp/information_retrieval/conf/megatron_gpt_reranker_tuning_config.yaml new file mode 100644 index 000000000000..863b5fb475a0 --- /dev/null +++ b/examples/nlp/information_retrieval/conf/megatron_gpt_reranker_tuning_config.yaml @@ -0,0 +1,222 @@ +name: megatron_gpt_peft_reranker_tuning + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: null + max_steps: 20000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 # frequency with which training steps are logged + val_check_interval: ${trainer.max_steps} # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: null + num_sanity_val_steps: 0 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{consumed_samples}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + +model: + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + global_batch_size: 128 + micro_batch_size: 4 + restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: True + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: selective # 'selective' or 'full' + activations_checkpoint_method: uniform # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + temperature: 0.02 + num_soft_negatives: 0 # Number of soft negatives to use for contrastive loss,it should be max(batch_size - 1), 0 means use hard negatives only + use_all_possible_negatives: False # If True, use all possible negatives for contrastive loss, otherwise use num_soft_negatives, if num_soft_negatives is 0, use hard negatives only + post_process: False # should be False. + apply_rope_fusion: False + transformer_engine: True # required to be True for newer versions of Megatron-LM based models + mcore_gpt: True # required to be True for newer versions of Megatron-LM based models + use_flash_attention: True + precision: bf16 + + peft: + peft_scheme: "mlp_head,lora" # can be either adapter,ia3, or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + lora_tuning: + target_modules: ['attention_qkv', 'attention_dense', 'mlp_fc1', 'mlp_fc2'] # + adapter_dim: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + # Instead of using the GPT LM Head, we can use a custom head for the reranking task + mlp_head_tuning: + out_features: 1 + + ia3_tuning: + layer_selection: null # selects in which layers to add ia3 adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + + selective_tuning: + tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre + + data: + train_ds: + # Example of how to specify paths to multiple datasets + # file_names: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'} + file_names: ??? # Path to a list of JSONL files corresponding to the source data. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 0 + memmap_workers: 2 + pin_memory: True + max_seq_length: 512 # Even if the base model can handle longer sequences, 512 is generally a good choice for training efficiency. + min_seq_length: 1 + drop_last: True + # Example of how to specify concat_sampling_probabilities + # concat_sampling_probabilities: + # - 0.5 + # - 0.25 + # - 0.25 + concat_sampling_probabilities: + - 1.0 + label_key: 'output' + add_eos: True + add_bos: False + index_mapping_dir: null # Path to a directory to write index mapping files. + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + validation_ds: + file_names: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: ["validation"] # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + memmap_workers: ${model.data.train_ds.memmap_workers} + pin_memory: True + max_seq_length: ${model.data.train_ds.max_seq_length} + min_seq_length: 1 + drop_last: False + label_key: ${model.data.train_ds.label_key} + add_eos: ${model.data.train_ds.add_eos} + add_bos: ${model.data.train_ds.add_bos} + write_embeddings_to_file: False + output_file_path_prefix: "validation_rankings" # Prefix of the file to write predictions to. + index_mapping_dir: null # Path to a directory to write index mapping files. + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + test_ds: + file_names: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + memmap_workers: ${model.data.train_ds.memmap_workers} + pin_memory: True + max_seq_length: ${model.data.train_ds.max_seq_length} + min_seq_length: 1 + drop_last: False + add_eos: ${model.data.train_ds.add_eos} + add_bos: ${model.data.train_ds.add_bos} + write_predictions_to_file: True + output_file_path_prefix: "test_embeddings" # Prefix of the file to write predictions to. + index_mapping_dir: null # Path to a directory to write index mapping files. + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 50 + min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1 + constant_steps: 0 # Constant steps should also be 0 when min_lr=0 + monitor: val_loss + reduce_on_plateau: false \ No newline at end of file diff --git a/examples/nlp/information_retrieval/megatron_gpt_embedding_generate.py b/examples/nlp/information_retrieval/megatron_gpt_embedding_generate.py index 8cddcebbab62..d66ddb339773 100644 --- a/examples/nlp/information_retrieval/megatron_gpt_embedding_generate.py +++ b/examples/nlp/information_retrieval/megatron_gpt_embedding_generate.py @@ -68,7 +68,9 @@ def use_inference_server(cfg, model, trainer): web_ui = get_demo loop = asyncio.new_event_loop() thread = threading.Thread( - target=web_ui, daemon=True, args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop), + target=web_ui, + daemon=True, + args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop), ) thread.start() server = MegatronServer(model.cuda()) @@ -93,7 +95,6 @@ def main(cfg) -> None: model_cfg = MegatronGPTEmbeddingModel.merge_inference_cfg(cfg.model.restore_from_path, cfg) with open_dict(model_cfg): - model_cfg.data.return_output_tensors = True model_cfg.post_process = False model = MegatronGPTEmbeddingModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer) diff --git a/examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py b/examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py new file mode 100644 index 000000000000..cf65840bb843 --- /dev/null +++ b/examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import MutableMapping + +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf +from pytorch_lightning.loggers import WandbLogger + +from nemo.collections.nlp.models.information_retrieval.megatron_gpt_reranker_model import MegatronGPTRerankerModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + +mp.set_start_method("spawn", force=True) + + +def flatten_dict(d: MutableMapping, parent_key: str = '', sep: str = '.') -> MutableMapping: + items = [] + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, MutableMapping): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +@hydra_runner(config_path="conf", config_name="megatron_gpt_reranker_tuning_config") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer() + exp_manager(trainer, cfg.exp_manager) + + model_cfg = MegatronGPTRerankerModel.merge_cfg_with(cfg.model.restore_from_path, cfg) + if trainer.global_rank == 0: + for logger in trainer.loggers: + if isinstance(logger, WandbLogger): + fd = flatten_dict(dict(model_cfg), sep="/") + logger.experiment.config.update(fd) + model = MegatronGPTRerankerModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer) + peft_cfg_cls_lst = [PEFT_CONFIG_MAP[s] for s in cfg.model.peft.peft_scheme.split(",")] + peft_cfg_cls = [_peft_cfg(model_cfg) for _peft_cfg in peft_cfg_cls_lst] + + if cfg.model.peft.restore_from_path is not None: + # initialize peft weights from a checkpoint instead of randomly + # This is not the same as resume training because optimizer states are not restored. + logging.info("PEFT Weights will be loaded from", cfg.model.peft.restore_from_path) + model.load_adapters(cfg.model.peft.restore_from_path, peft_cfg_cls) + elif peft_cfg_cls is not None: + logging.info("Adding adapter weights to the model for PEFT") + # model.add_adapter(peft_cfg_cls(model_cfg)) + model.add_adapter(peft_cfg_cls) + else: + logging.info(f"Running full finetuning since no peft scheme is given.\n{model.summarize()}") + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/information_retrieval/megatron_gpt_reranker_generate.py b/examples/nlp/information_retrieval/megatron_gpt_reranker_generate.py new file mode 100644 index 000000000000..a91449c3deda --- /dev/null +++ b/examples/nlp/information_retrieval/megatron_gpt_reranker_generate.py @@ -0,0 +1,138 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +import os +import threading +from functools import partial + +import torch +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf, open_dict + +from nemo.collections.nlp.models.information_retrieval.megatron_gpt_reranker_model import MegatronGPTRerankerModel +from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer +from nemo.collections.nlp.modules.common.text_generation_utils import generate +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.model_utils import inject_model_parallel_rank + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + +mp.set_start_method("spawn", force=True) + + +def use_inference_server(cfg, model, trainer): + if not HAVE_MEGATRON_CORE: + raise ValueError('Megatron-core needs to be installed to use this feature!') + + from nemo.collections.nlp.modules.common.megatron_web_server import get_chatbot_demo, get_demo + + trainer.test(model, dataloaders=None) + + if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0: + if cfg.web_server: + if cfg.chat: + defaults = { + 'user': cfg.chatbot_config.user, + 'assistant': cfg.chatbot_config.assistant, + 'system': cfg.chatbot_config.system, + } + web_ui = partial( + get_chatbot_demo, + defaults=defaults, + value=cfg.chatbot_config.value, + attributes=cfg.chatbot_config.attributes, + ) + else: + web_ui = get_demo + loop = asyncio.new_event_loop() + thread = threading.Thread( + target=web_ui, + daemon=True, + args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop), + ) + thread.start() + server = MegatronServer(model.cuda()) + server.run("0.0.0.0", port=cfg.port) + + while True: + choice = torch.cuda.LongTensor(1) + torch.distributed.broadcast(choice, 0) + if choice[0].item() == 0: + generate(model.cuda()) + + +@hydra_runner(config_path="conf", config_name="megatron_gpt_reranker_generate_config") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f"\n{OmegaConf.to_yaml(cfg)}") + trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer() + + if cfg.model.peft.restore_from_path: + model_cfg = MegatronGPTRerankerModel.merge_inference_cfg(cfg.model.peft.restore_from_path, cfg) + else: + model_cfg = MegatronGPTRerankerModel.merge_inference_cfg(cfg.model.restore_from_path, cfg) + + with open_dict(model_cfg): + model_cfg.post_process = False + + model = MegatronGPTRerankerModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer) + + if cfg.model.peft.restore_from_path: + model.load_adapters(cfg.model.peft.restore_from_path) + elif cfg.model.peft.restore_from_ckpt.checkpoint_dir and cfg.model.peft.restore_from_ckpt.checkpoint_name: + peft_cfg_cls_lst = [PEFT_CONFIG_MAP[s] for s in cfg.model.peft.peft_scheme.split(",")] + peft_cfg_cls = [_peft_cfg(model_cfg) for _peft_cfg in peft_cfg_cls_lst] + + checkpoint_path = os.path.join( + cfg.model.peft.restore_from_ckpt.checkpoint_dir, cfg.model.peft.restore_from_ckpt.checkpoint_name + ) + # checkpoint_path is a dir in case of distributed checkpointing + if not os.path.isdir(checkpoint_path): + # legacy checkpoint needs model parallel rank injection + checkpoint_path = inject_model_parallel_rank( + os.path.join( + cfg.model.peft.restore_from_ckpt.checkpoint_dir, cfg.model.peft.restore_from_ckpt.checkpoint_name + ) + ) + model.load_adapters(checkpoint_path, peft_cfgs=peft_cfg_cls) + else: + raise NotImplementedError("distributed checkpointing of PEFT weights is not supported") + + model.freeze() + logging.info(f"Freezing parameters for PEFT eval:\n{model.summarize()}") + + if not cfg.model.get('use_flash_attention', False): + cfg.inference.compute_attention_mask = True + config = OmegaConf.to_container(cfg.inference, resolve=True) + model.set_inference_config(config) + + if not cfg.server: + trainer.test(model) + else: + use_inference_server(cfg, model, trainer) + + +if __name__ == "__main__": + main() diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py b/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py index aaa087a46623..bfe8ea35960e 100644 --- a/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py +++ b/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nemo/collections/nlp/data/information_retrieval/gpt_embedding_dataset.py b/nemo/collections/nlp/data/information_retrieval/gpt_embedding_dataset.py index e697d5ec3bf6..3a2a8152313e 100644 --- a/nemo/collections/nlp/data/information_retrieval/gpt_embedding_dataset.py +++ b/nemo/collections/nlp/data/information_retrieval/gpt_embedding_dataset.py @@ -27,7 +27,7 @@ from nemo.core.classes import Dataset from nemo.utils import logging -__all__ = ['GPTEmbeddingDataset'] +__all__ = ['GPTEmbeddingDataset', 'GPTRerankerDataset'] class GPTEmbeddingDataset(Dataset): @@ -49,7 +49,7 @@ def __init__( data_type: str = 'train', # train, query or doc ): """ - file_path: Path to a JSONL dataset with (query,pos_doc,neg_doc) triplets in jsonl format. + file_path: Path to a JSONL dataset with (query,pos_doc,neg_doc) triplets in jsonl format. tokenizer: Tokenizer for the dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece). max_seq_length (int): maximum sequence length for each dataset examples. Examples will either be truncated to fit this length or dropped if they cannot be truncated. min_seq_length (int): min length of each data example in the dataset. Data examples will be dropped if they do not meet the min length requirements. @@ -279,3 +279,138 @@ def collate_fn(self, batch): } return processed_batch + + +class GPTRerankerDataset(GPTEmbeddingDataset): + def __init__( + self, + file_path: str, + tokenizer: TokenizerSpec, + max_seq_length: int = 1024, + min_seq_length: int = 1, + add_bos: bool = False, + add_eos: bool = True, + max_num_samples: int = None, + seed: int = 1234, + index_mapping_dir: str = None, + virtual_tokens: int = 0, + memmap_workers: Optional[int] = None, + truncation_method: str = 'right', + special_tokens: Optional[Mapping[str, str]] = None, # special tokens, a dictory of {token_type: token} + data_type: str = 'train', # train, query or doc + ): + """ + file_path: Path to a JSONL dataset with (query,pos_doc,neg_doc) triplets in jsonl format. + tokenizer: Tokenizer for the dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece). + max_seq_length (int): maximum sequence length for each dataset examples. Examples will either be truncated to fit this length or dropped if they cannot be truncated. + min_seq_length (int): min length of each data example in the dataset. Data examples will be dropped if they do not meet the min length requirements. + add_bos (bool): Whether to add a beginning of sentence token to each data example + add_eos (bool): Whether to add an end of sentence token to each data example + seed: Random seed for data shuffling. + max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded. + index_mapping_dir: Directory to save the index mapping to. If None, will write to the same folder as the dataset. + truncation_method: Truncation from which position. Options: ['left', 'right'] + special_tokens: special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '', 'turn_start': '', 'label_start': '', 'end_of_turn': '\n', "end_of_name": "\n"} + """ + super().__init__( + file_path=file_path, + tokenizer=tokenizer, + max_seq_length=max_seq_length, + min_seq_length=min_seq_length, + add_bos=add_bos, + add_eos=add_eos, + max_num_samples=max_num_samples, + seed=seed, + index_mapping_dir=index_mapping_dir, + virtual_tokens=virtual_tokens, + memmap_workers=memmap_workers, + truncation_method=truncation_method, + special_tokens=special_tokens, + data_type=data_type, + ) + + def _process_example(self, example): + """ + Create an example by concatenating text and answer. + Truncation is carried out when needed, but it is performed only on the prompt side. + BOS, EOS, and SEP, are added if specified. + """ + metadata = {k: v for k, v in example.items()} + if self.data_type == 'train': + qd = self.tokenizer.text_to_ids( + "query: " + example['query'].strip() + " passage: " + example['pos_doc'].strip() + ) + qnd = self.tokenizer.text_to_ids( + "query: " + example['query'].strip() + " passage: " + example['neg_doc'].strip() + ) + else: + qd = self.tokenizer.text_to_ids( + "query: " + example['query'].strip() + " passage: " + example['pos_doc'].strip() + ) + qnd = [] + + if self.virtual_tokens: + # (@adithyare) we are going to insert "pad/eos" tokens in the beginning of the text and context + # these pad/eos tokens are placeholders for virtual tokens for ptuning (if used) + qd = [self.tokenizer.eos_id] * self.virtual_tokens + qd # type: ignore + qnd = [self.tokenizer.eos_id] * self.virtual_tokens + qnd # type: ignore + + if self.add_bos: + qd = [self.tokenizer.bos_id] + qd # type: ignore + qnd = [self.tokenizer.bos_id] + qnd # type: ignore + + # TODO: (@adithyare) should probably add a warning before truncation + qd = qd[: self.max_seq_length - 1] + qnd = qnd[: self.max_seq_length - 1] + + if self.add_eos: + qd = qd + [self.tokenizer.eos_id] # type: ignore + qnd = qnd + [self.tokenizer.eos_id] # type: ignore + + processed_example = { + 'query_pos_doc': qd, + 'query_neg_doc': qnd, + 'metadata': metadata, + } + + return processed_example + + def collate_fn(self, batch): + input_ids = [] + metadata = [] + lengths = [] + max_length = -1 + for item in batch: + metadata.append(item['metadata']) + if self.data_type == 'train': + input_ids.append(item['query_pos_doc']) + lengths.append(len(item['query_pos_doc'])) + input_ids.append(item['query_neg_doc']) + lengths.append(len(item['query_neg_doc'])) + max_length = max(max_length, len(item['query_pos_doc']), len(item['query_neg_doc'])) + else: + input_ids.append(item['query_pos_doc']) + lengths.append(len(item['query_pos_doc'])) + max_length = max(max_length, len(item['query_pos_doc'])) + + max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, 16)) + assert max_length <= self.max_seq_length + + attention_mask = [self._create_attention_mask(max_length) for _ in input_ids] + attention_mask = torch.stack(attention_mask) + position_ids = [list(range(max_length)) for _ in input_ids] + position_ids = torch.LongTensor(position_ids) + input_ids = torch.LongTensor( + self._collate_item(input_ids, max_length=max_length, pad_id=self.tokenizer.eos_id) + ) + lengths = torch.LongTensor(lengths) - 1 # subtract 1 to account for the eos token + + processed_batch = { + 'tokens': input_ids, + 'attention_mask': attention_mask, + 'loss_mask': lengths, + 'position_ids': position_ids, + 'metadata': metadata, + } + + return processed_batch diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py index 67fd2b1b6c62..c7565f45358e 100644 --- a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py +++ b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py @@ -36,11 +36,6 @@ except (ImportError, ModuleNotFoundError): HAVE_MEGATRON_CORE = False -try: - - HAVE_APEX = True -except (ImportError, ModuleNotFoundError): - HAVE_APEX = False def listify(tensor): @@ -52,6 +47,17 @@ def listify(tensor): return l_tensor +def _gather_global_inbatch_representations(local_eos_tensor): + local_eos_tensor = local_eos_tensor.contiguous() + global_eos_tensors = [ + torch.zeros_like(local_eos_tensor) for _ in range(parallel_state.get_data_parallel_world_size()) + ] + torch.distributed.all_gather(global_eos_tensors, local_eos_tensor, group=parallel_state.get_data_parallel_group()) + global_eos_tensors[parallel_state.get_data_parallel_rank()] = local_eos_tensor + global_eos_tensors = torch.cat(global_eos_tensors, dim=0) + return global_eos_tensors + + class MegatronGPTEmbeddingModel(MegatronGPTSFTModel): def __init__(self, cfg: DictConfig, trainer: Trainer): super().__init__(cfg, trainer=trainer) @@ -412,25 +418,20 @@ def inference_loss_func(self, loss_mask, num_valid_tokens_in_ub, eos_tensors): hs = eos_tensors hs = torch.nn.functional.normalize(hs, dim=1) _blank = torch.zeros(1, device=hs.device, dtype=hs.dtype)[0] - return _blank, hs, hs, _blank, _blank, _blank - - def _gather_global_inbatch_representations(self, local_eos_tensor): - local_eos_tensor = local_eos_tensor.contiguous() - global_eos_tensors = [ - torch.zeros_like(local_eos_tensor) for _ in range(parallel_state.get_data_parallel_world_size()) - ] - torch.distributed.all_gather( - global_eos_tensors, local_eos_tensor, group=parallel_state.get_data_parallel_group() - ) - global_eos_tensors[parallel_state.get_data_parallel_rank()] = local_eos_tensor - global_eos_tensors = torch.cat(global_eos_tensors, dim=0) - return global_eos_tensors + return { + "loss": _blank, + "query_hs": hs, + "pos_doc_hs": hs, + "pos_cs": _blank, + "neg_cs": _blank, + "diff_cs": _blank, + } def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor): idx = torch.arange(output_tensor.shape[1], device=output_tensor.device) eos_tensors = output_tensor[loss_mask, idx, :] if self.global_inbatch_negatives and self.trainer.training: - eos_tensors = self._gather_global_inbatch_representations(eos_tensors) + eos_tensors = _gather_global_inbatch_representations(eos_tensors) if not self.trainer.training: return self.inference_loss_func(loss_mask, num_valid_tokens_in_ub, eos_tensors) bs = eos_tensors.shape[0] // 3 @@ -464,4 +465,11 @@ def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor): query_hs = query_hs.clone().detach() pos_doc_hs = pos_doc_hs.clone().detach() diff_cs = pos_cs - neg_cs - return loss, query_hs, pos_doc_hs, pos_cs, neg_cs, diff_cs + return { + "loss": loss, + "query_hs": query_hs, + "pos_doc_hs": pos_doc_hs, + "pos_cs": pos_cs, + "neg_cs": neg_cs, + "diff_cs": diff_cs, + } diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py new file mode 100644 index 000000000000..e316871fe607 --- /dev/null +++ b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py @@ -0,0 +1,301 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import os + +import numpy as np +import torch +from omegaconf import DictConfig, ListConfig +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.data.information_retrieval.gpt_embedding_dataset import GPTRerankerDataset +from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( + get_datasets_weights_and_num_samples, +) +from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset +from nemo.collections.nlp.models.information_retrieval.megatron_gpt_embedding_model import ( + MegatronGPTEmbeddingModel, + _gather_global_inbatch_representations, +) +from nemo.utils import logging + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + +def listify(tensor): + l_tensor = [] + for t in tensor: + for rid in range(t.shape[0]): + r = t[rid, :].unsqueeze(0).cpu() + l_tensor.append(r) + return l_tensor + + +class MegatronGPTRerankerModel(MegatronGPTEmbeddingModel): + def __init__(self, cfg: DictConfig, trainer: Trainer): + self.reward_model_loss = cfg.get("reward_model_loss", False) + super().__init__(cfg, trainer=trainer) + + def model_provider_func(self, pre_process, post_process): + # (@adithyare) We need post_process to be False to get hidden states in the loss_func + return super().model_provider_func(pre_process, post_process=False) + + def maybe_setup_test(self): + if hasattr(self.cfg.data, 'test_ds') and self.cfg.data.test_ds.get('file_names', None) is not None: + self._test_dl = self.setup_eval_dataloader(self._test_ds, self.cfg.data.test_ds) + return + + def maybe_build_test(self): + if hasattr(self.cfg.data, 'test_ds') and self.cfg.data.test_ds.get('file_names', None) is not None: + logging.info('Building GPT Reranker test datasets.') + # Wrap this in a list since the general finetuning parent class supports multi-validation. + self._test_ds = self._build_dataset(self.cfg.data.test_ds, is_train=False) + + def _build_dataset(self, data_cfg, is_train=True): + packed_sequence = data_cfg.get("packed_sequence", False) + + # Determine if we are using a single dataset or a list of datasets. + if is_train: + # Construct the data prefix list for `get_datasets_weights_and_num_samples()` + # that is of the format [weight1,file_name1,weight2,file_name2,...] + if data_cfg.concat_sampling_probabilities is None or not isinstance( + data_cfg.concat_sampling_probabilities, ListConfig + ): + raise ValueError( + ( + f"concat_sampling_probabilities must be a ListConfig with the same number of files in file_names." + f"Found: {data_cfg.concat_sampling_probabilities}" + ) + ) + + if len(data_cfg.get('concat_sampling_probabilities', None)) != len(data_cfg.file_names): + raise ValueError( + ( + f"concat_sampling_probabilities must be of the same size as file_names.", + f"Provided size {len(data_cfg.concat_sampling_probabilities)}, number of datasets {len(data_cfg.file_names)}", + ) + ) + + data_prefix = [] + for weight, prefix in zip(data_cfg.concat_sampling_probabilities, data_cfg.file_names): + data_prefix.append(weight) + data_prefix.append(prefix) + + if self.trainer.max_steps is None or self.trainer.max_steps <= 0: + raise ValueError( + f'Trainer max_steps must be set to a positive integer. Found {self.trainer.max_steps}' + ) + num_train_samples = [self.trainer.max_steps * data_cfg.global_batch_size] + _, _, num_train_samples_per_dataset = get_datasets_weights_and_num_samples(data_prefix, num_train_samples) + num_train_samples_after_blend = sum([x[0] for x in num_train_samples_per_dataset]) + else: + num_train_samples_per_dataset = [[None]] * len(data_cfg.file_names) + + # Check dataset max_seq_legnth and max_position_embeddings size + if ( + self.cfg.get('position_embedding_type', None) in [None, 'learned_absolute'] + and data_cfg.max_seq_length > self.cfg.max_position_embeddings + ): + logging.warning( + f"Set dataset max_seq_length to max_position_embeddings {self.cfg.max_position_embeddings} if using learned_absolute position embedding" + ) + data_cfg.max_seq_length = self.cfg.max_position_embeddings + + # TE requires that the first input dim is divisible by 8 and the second by 16 for fp8 + # When using sequence parallel, sequence will further be split by TP size + pad_seq_length_to_mult = ( + 8 * self.cfg.get('tensor_model_parallel_size', 1) if self.cfg.get('sequence_parallel', False) else 16 + ) + pad_seq_length_to_mult *= self.cfg.get('context_parallel_size', 1) + + datasets = [] + for file_path, num_samples in zip(data_cfg.file_names, num_train_samples_per_dataset): + dataset = GPTRerankerDataset( + file_path=file_path, + tokenizer=self.tokenizer, + max_seq_length=data_cfg.max_seq_length, + min_seq_length=data_cfg.min_seq_length, + add_bos=data_cfg.get('add_bos', False), + add_eos=data_cfg.get('add_eos', True), + max_num_samples=num_samples[0], + seed=data_cfg.get('seed', 1234), + index_mapping_dir=data_cfg.get('index_mapping_dir', None), + virtual_tokens=self.virtual_tokens, + memmap_workers=data_cfg.get( + 'memmap_workers', None + ), # used to set num. of workers to create the memmap index files + truncation_method=data_cfg.get( + 'truncation_method', 'right' + ), # used to choose truncation method. Options: ['random', 'left', 'right'] + special_tokens=self.cfg.data.get( + 'chat_prompt_tokens', None + ), # special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '', 'turn_start': '', 'label_start': '', 'end_of_turn': '\n', "end_of_name": "\n"} + data_type="train" if is_train else "validation", + ) + datasets.append(dataset) + if is_train: + if packed_sequence: + num_train_samples_after_blend = sum(len(dataset) for dataset in datasets) + dataset = BlendableDataset( + datasets=datasets, weights=data_cfg.concat_sampling_probabilities, size=num_train_samples_after_blend + ) + return dataset + else: + return datasets + + def training_step_fwd_bwd_step_call(self, dataloader_iter, forward_only): + loss_mean, non_loss_tensors = self.fwd_bwd_step(dataloader_iter, forward_only) + logit_diff = non_loss_tensors['logit_diff'][0].item() + self.log("logit_diff", logit_diff, prog_bar=True, rank_zero_only=True, batch_size=1) + return loss_mean + + def inference_step_validation_call(self, batch, batch_idx, data_cfg, dataloader_idx=0): + metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + loss, non_loss_tensors = self.local_validation_step(itertools.chain([dataloader_idx], [batch])) + outputs = { + 'loss': loss, + 'metadata': metadata, # [dict] + 'query_pos_doc_logit': non_loss_tensors['query_pos_doc_logit'], # [batch_size, hidden_size] + } + return outputs + + def inference_loss_func(self, loss_mask, num_valid_tokens_in_ub, eos_tensors): + query_pos_doc_hs = eos_tensors + _blank = torch.zeros(1, device=query_pos_doc_hs.device, dtype=query_pos_doc_hs.dtype)[0] + return { + "loss": _blank, + "query_pos_doc_logit": query_pos_doc_hs, + "query_neg_doc_logit": _blank, + "logit_diff": _blank, + } + + def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor): + idx = torch.arange(output_tensor.shape[1], device=output_tensor.device) + eos_tensors = output_tensor[loss_mask, idx, :] # (bs x 1) + if self.global_inbatch_negatives and self.trainer.training: + eos_tensors = _gather_global_inbatch_representations(eos_tensors) + if not self.trainer.training: + return self.inference_loss_func(loss_mask, num_valid_tokens_in_ub, eos_tensors) + bs = eos_tensors.shape[0] // 2 + query_pos_doc_hs = eos_tensors[::2, :] # every second tensor from idx 0 is a query w pos_doc (bs x 1) + query_neg_doc_hs = eos_tensors[1::2, :] # every second tensor from idx 1 is a query w negative doc (bs x 1) + + if self.reward_model_loss: + loss = -torch.nn.functional.logsigmoid(query_pos_doc_hs - query_neg_doc_hs).mean() + else: + cs = torch.cat([query_pos_doc_hs, query_neg_doc_hs], dim=1) # (bs x 2) + cs = cs / self.temperature + labels = torch.zeros(bs, device=cs.device).long() + loss = torch.nn.functional.cross_entropy(cs, labels) + + cp_size = self.cfg.get('context_parallel_size', 1) + if cp_size > 1: + torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group()) + query_pos_doc_hs = query_pos_doc_hs.clone().detach() + query_neg_doc_hs = query_neg_doc_hs.clone().detach() + logit_diffs = torch.mean(query_pos_doc_hs - query_neg_doc_hs) + return { + "loss": loss, + "query_pos_doc_logit": query_pos_doc_hs, + "query_neg_doc_logit": query_neg_doc_hs, + "logit_diff": logit_diffs, + } + + def gather_and_maybe_write_predictions(self, output, data_cfg, mode, averaged_metric, dataloader_idx=0): + if not data_cfg.get("write_embeddings_to_file", False): + return True + gathered_output_batches = [None for _ in range(parallel_state.get_data_parallel_world_size())] + torch.distributed.all_gather_object( + gathered_output_batches, + [ + { + 'query_pos_doc_logit': batch['query_pos_doc_logit'], + 'metadata': batch['metadata'], + } + for batch in output + ], + group=parallel_state.get_data_parallel_group(), + ) + + # Remove duplicate examples due to distributed sampler. + deduplicated_outputs = { + 'query_pos_doc_logit': [], + 'metadata': [], + } + total_size, skipped = 0, 0 + for rank in range(0, parallel_state.get_data_parallel_world_size()): + for batch in gathered_output_batches[rank]: + l_q_hs = listify(batch['query_pos_doc_logit']) + l_m = batch['metadata'] + assert len(l_m) == len(l_q_hs) + for q_hs, metadata in zip( + l_q_hs, + l_m, + ): + total_size += 1 + if not metadata.get("__AUTOGENERATED__", False): + deduplicated_outputs['query_pos_doc_logit'].append(q_hs) + deduplicated_outputs['metadata'].append(metadata) + else: + skipped += 1 + + logging.info( + f"{total_size-skipped} deduplicated outputs in dataloader:{dataloader_idx}, (skipped {skipped} autogenerated examples)." + ) + # Compute metric score + metric_name = self.val_metric_name if mode == 'validation' else self.test_metric_name + assert metric_name == "loss", "Only loss is supported for now." + # avg_pos_cs = torch.tensor(deduplicated_outputs['avg_pos_cs']).mean().item() + # avg_neg_cs = torch.tensor(deduplicated_outputs['avg_neg_cs']).mean().item() + # diff_cs = torch.tensor(deduplicated_outputs['diff_cs']).mean().item() + # self.log('val_avg_pos_cs', avg_pos_cs, prog_bar=True, rank_zero_only=True, batch_size=1) + # self.log('val_avg_neg_cs', avg_neg_cs, prog_bar=True, rank_zero_only=True, batch_size=1) + # self.log('val_diff_cs', diff_cs, prog_bar=True, rank_zero_only=True, batch_size=1) + + # Write predictions to file + if self.global_rank == 0 and data_cfg.get("write_embeddings_to_file", False): + logging.info( + f"Total deduplicated inference data size: {total_size} to {len(deduplicated_outputs['metadata'])}" + ) + + # Check if the user provided a prefix path to the file(s) they want to write. + if not hasattr(data_cfg, "output_file_path_prefix") or data_cfg.output_file_path_prefix is None: + raise ValueError( + f"Cannot write predictions to file when output_file_path_prefix is not set or present in the yaml config file." + ) + # (@adithyare) We are not using the log key to write the embeddings to file + filename_log_key = self._determine_log_key(data_cfg, dataloader_idx, None, mode) + consumed_samples = self._compute_consumed_samples_after_training_step() + fldr_path = f"{data_cfg.output_file_path_prefix}/consumed_samples{consumed_samples}/{filename_log_key}" + self.write_embeddings_to_file(deduplicated_outputs, fldr_path, dataloader_idx) + return deduplicated_outputs, total_size + + def write_embeddings_to_file(self, outputs, output_file_path, d_idx): + hs = torch.cat(outputs['query_pos_doc_logit'], dim=0) + hs_npy = hs.float().numpy() + emb_fldr = f"{output_file_path}" + os.makedirs(emb_fldr, exist_ok=True) + with open(f"{output_file_path}/logits.ids", "w") as f: + for m in outputs['metadata']: + f.write(f"{m['query_id'].strip()} {m['doc_id']}\n") + np.save(f"{emb_fldr}/logits.npy", hs_npy) + return True diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 4f9722d900f6..69cd06021f50 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -391,7 +391,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0))) self.loss_broadcast_src_rank = None data_cfg = cfg.get('data', {}) - self.return_output_tensors = data_cfg.get('return_output_tensors', False) self.validation_drop_last = data_cfg.get('validation_drop_last', True) self.sample_weight = data_cfg.get('sample_weight', 'token') self.validation_param_sync_overlap = self.cfg.get('validation_param_sync_overlap', False) @@ -1275,24 +1274,47 @@ def loss_func(output_tensor): # Loss for a micro-batch (ub) loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor) cp_size = parallel_state.get_context_parallel_world_size() - if self.return_output_tensors: + if isinstance(loss_for_ub, dict): # TODO: need a better way to check if loss_func is returning more stuff than just loss... (@adithyare) - loss_for_ub, q_hs, d_hs, pos_cs, neg_cs, diff_cs = loss_for_ub - reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) - pos_cs = average_losses_across_data_parallel_group([pos_cs]) - neg_cs = average_losses_across_data_parallel_group([neg_cs]) - diff_cs = average_losses_across_data_parallel_group([diff_cs]) - return ( - loss_for_ub * cp_size, - { - 'avg': reduced_loss, - 'query_hs': q_hs, - 'doc_hs': d_hs, - 'avg_pos_cs': pos_cs, - 'avg_neg_cs': neg_cs, - 'diff_cs': diff_cs, - }, - ) + + if set(loss_for_ub.keys()) == set( + ["loss", "query_hs", "pos_doc_hs", "pos_cs", "neg_cs", "diff_cs"] + ): # (adithyare) this check will be True for GPT Embedding models + loss = loss_for_ub['loss'] + reduced_loss = average_losses_across_data_parallel_group([loss]) + pos_cs = average_losses_across_data_parallel_group([loss_for_ub['pos_cs']]) + neg_cs = average_losses_across_data_parallel_group([loss_for_ub['neg_cs']]) + diff_cs = average_losses_across_data_parallel_group([loss_for_ub['diff_cs']]) + return ( + loss * cp_size, + { + 'avg': reduced_loss, + 'query_hs': loss_for_ub['query_hs'], + 'doc_hs': loss_for_ub['pos_doc_hs'], + 'avg_pos_cs': pos_cs, + 'avg_neg_cs': neg_cs, + 'diff_cs': diff_cs, + }, + ) + elif set(loss_for_ub.keys()) == set( + ["loss", "query_pos_doc_logit", "query_neg_doc_logit", "logit_diff"] + ): # (adithyare) this check will be True for GPT Reranker models + + loss = loss_for_ub['loss'] + reduced_loss = average_losses_across_data_parallel_group([loss]) + logit_diff = average_losses_across_data_parallel_group([loss_for_ub['logit_diff']]) + return ( + loss * cp_size, + { + 'avg': reduced_loss, + 'query_pos_doc_logit': loss_for_ub['query_pos_doc_logit'], + 'query_neg_doc_logit': loss_for_ub['query_neg_doc_logit'], + 'logit_diff': logit_diff, + }, + ) + else: + raise RuntimeError(f"Dict loss_for_ub has unknown key set {loss_for_ub.keys()}") + elif validation_step and not self.validation_drop_last: num_valid_tokens_in_ub = batch['num_valid_tokens_in_ub'] if loss_for_ub.isnan(): diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py index 2f00f5907ad8..48b6afa788ae 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py @@ -14,17 +14,21 @@ import torch import torch.nn.functional as F +from megatron.core import InferenceParams from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb +from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.custom_layers.transformer_engine import SplitAlongDim from megatron.core.transformer.mlp import MLP from megatron.core.transformer.moe.experts import SequentialMLP +from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_layer import TransformerLayer from megatron.core.utils import make_viewless_tensor +from torch import Tensor from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ( AdapterName, @@ -37,6 +41,7 @@ LoraMoeHto4HAdapterConfig, LoraUnfusedHto4HAdapterConfig, LoraUnfusedKQVAdapterConfig, + MLPHeadAdapterConfig, MLPInfusedAdapterConfig, ParallelLinearAdapterConfig, PromptEncoderAdapterConfig, @@ -61,6 +66,34 @@ def mcore_register_adapters(self): raise NotImplementedError("Mcore mixins should implement setup_adapters on a subclass of MyBase") +class MCoreTransformerBlockMixin(TransformerBlock, MCoreAdapterModuleMixin): + def mcore_register_adapters(self): + """ + Setup NeMo (canonical) Adapter to this MCore layer. + """ + self.set_accepted_adapter_types([MLPHeadAdapterConfig._target_]) + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + context: Tensor = None, + context_mask: Tensor = None, + rotary_pos_emb: Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + ): + hidden_states = super().forward( + hidden_states, attention_mask, context, context_mask, rotary_pos_emb, inference_params, packed_seq_params + ) + + mlp_head_adapter = self.get_adapter_module(AdapterName.MLP_HEAD_ADAPTER) + if mlp_head_adapter and self.adapter_cfg[AdapterName.MLP_HEAD_ADAPTER]['enabled']: + hidden_states = mlp_head_adapter(hidden_states) + + return hidden_states + + class MCoreSelfAttentionMixin(SelfAttention, MCoreAdapterModuleMixin): def mcore_register_adapters(self): """ diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py index 9ab1da7136a1..8d2d77c55cf2 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py @@ -77,6 +77,7 @@ class AdapterName(str, enum.Enum): PTUNING_ADAPTER = "ptuning_adapter" LORA_KQV_ADAPTER = "lora_kqv_adapter" LORA_UNFUSED_KQV_ADAPTER = "lora_unfused_kqv_adapter" + MLP_HEAD_ADAPTER = "mlp_head_adapter" LORA_KV_ADAPTER = "lora_kv_adapter" LORA_Q_ADAPTER = "lora_q_adapter" MM_LINEAR_ADAPTER = "mm_linear_adapter" @@ -388,6 +389,57 @@ class ParallelLinearAdapterConfig(AdapterConfig): _target_: str = "{0}.{1}".format(ParallelLinearAdapter.__module__, ParallelLinearAdapter.__name__) +class MLPHeadAdapter(nn.Module, AdapterModuleUtil): + def __init__( + self, + in_features: int, + out_features: int, + input_is_parallel: bool = False, + model_parallel_config: Optional[ModelParallelConfig] = None, + **kwargs, + ): + super().__init__() + if model_parallel_config is None: + model_parallel_config = ModelParallelConfig() + self._sequence_parallel = model_parallel_config.sequence_parallel + model_parallel_config.sequence_parallel = False # SP is irrelevant for the lora linear layer + + if input_is_parallel: + self.linear = RowParallelLinear( + in_features, + out_features, + config=model_parallel_config, + input_is_parallel=True, + skip_bias_add=True, + bias=False, + init_method=init.xavier_normal_, + ) + else: + self.linear = ColumnParallelLinear( + in_features, + out_features, + config=model_parallel_config, + bias=False, + gather_output=True, + init_method=init.xavier_normal_, + disable_grad_reduce=self._sequence_parallel, + ) + + # Setup adapter strategy + self.setup_adapter_strategy(adapter_mixin_strategies.ReturnResultAdapterStrategy()) + + def forward(self, x): + x, _ = self.linear(x) + return x + + +@dataclass +class MLPHeadAdapterConfig(AdapterConfig): + in_features: int + out_features: int + _target_: str = "{0}.{1}".format(MLPHeadAdapter.__module__, MLPHeadAdapter.__name__) + + class LoraKQVAdapter(ParallelLinearAdapter): """ Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes @@ -777,14 +829,21 @@ def set_inference_table(self, prompt_representation: torch.Tensor): self.is_inference_ready = True return True - def clear_inference_table(self): + def clear_inference_table( + self, + ): self.inference_table.fill_(0.0) self.is_inference_ready = False - def get_inference_table(self): + def get_inference_table( + self, + ): return self.inference_table.data - def inner_forward(self): + def inner_forward( + self, + ): + input_embeds = self.embedding(self.indices).unsqueeze(0) intermediate_parallel, bias_parallel = self.first(input_embeds) intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel) diff --git a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py index 2bacaf52e3f8..90b3912784c8 100644 --- a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py +++ b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py @@ -30,8 +30,13 @@ HAVE_MEGATRON_CORE = False -from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import PromptEncoderAdapterConfig +from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ( + MLPHeadAdapterConfig, + PromptEncoderAdapterConfig, +) + from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector + from nemo.collections.nlp.parts.peft_config import ( PEFT_CONFIG_MAP, CanonicalAdaptersPEFTConfig, @@ -168,7 +173,11 @@ def _check_and_add_peft_cfg(self, peft_cfg): for adapter_name, adapter_cfg in peft_cfg.get_config_dict().items(): # self.mcore_gpt means is GPT and not T5 - if hasattr(self, 'mcore_gpt') and not isinstance(adapter_cfg, PromptEncoderAdapterConfig): + if ( + hasattr(self, 'mcore_gpt') + and not isinstance(adapter_cfg, PromptEncoderAdapterConfig) + and not isinstance(adapter_cfg, MLPHeadAdapterConfig) + ): if layer_selection is not None: logging.info( f"Layer selection {layer_selection} is enabled for the current model (" @@ -351,8 +360,10 @@ def load_adapters( assert filepath.endswith( '.nemo' ), "Inferring peft scheme is only supported for .nemo checkpoints. Please supply the `peft_cfgs` argument." - peft_cfgs = [PEFT_CONFIG_MAP[conf.peft.peft_scheme](conf)] + peft_cfg_cls_lst = [PEFT_CONFIG_MAP[s] for s in conf.peft.peft_scheme.split(",")] + peft_cfgs = [_peft_cfg(conf) for _peft_cfg in peft_cfg_cls_lst] if getattr(self, 'megatron_amp_O2', False): + state_dict = {replace_prefix(k, 'model.', 'model.module.'): v for k, v in state_dict.items()} self.add_adapter(peft_cfgs) if not self.ptuning_only_and_non_first_stage: diff --git a/nemo/collections/nlp/parts/peft_config.py b/nemo/collections/nlp/parts/peft_config.py index 726ca33611d7..25f303fc22fb 100644 --- a/nemo/collections/nlp/parts/peft_config.py +++ b/nemo/collections/nlp/parts/peft_config.py @@ -24,6 +24,7 @@ MCoreMLPMixin, MCoreSelfAttentionMixin, MCoreSequentialMLPMixin, + MCoreTransformerBlockMixin, MCoreTransformerLayerMixin, ) except (ImportError, ModuleNotFoundError): @@ -41,6 +42,7 @@ LoraMoeHto4HAdapterConfig, LoraUnfusedHto4HAdapterConfig, LoraUnfusedKQVAdapterConfig, + MLPHeadAdapterConfig, MLPInfusedAdapterConfig, ParallelLinearAdapterConfig, ParallelLinearAdapterWeightTyingConfig, @@ -127,6 +129,21 @@ def __init__(self, cfg): self.tunable_base_param_names = selective_cfg.get("tunable_base_param_names", []) +class MLPHeadPEFTConfig(PEFTConfig): + def __init__(self, cfg): + config_args = {"in_features": cfg.hidden_size, "out_features": cfg.peft.mlp_head_tuning.out_features} + mlp_head_cfg = MLPHeadAdapterConfig(**config_args) + + name_key_to_cfg = { + AdapterName.MLP_HEAD_ADAPTER: mlp_head_cfg, + } + self.name_key_to_mcore_mixins = { + AdapterName.MLP_HEAD_ADAPTER: [("decoder", MCoreTransformerBlockMixin)], + } + + super().__init__(cfg.peft.mlp_head_tuning, name_key_to_cfg) + + class LoraPEFTConfig(PEFTConfig): def __init__(self, cfg): lora_cfg = cfg.peft.lora_tuning @@ -401,6 +418,7 @@ def __init__(self, cfg): "ia3": IA3PEFTConfig, "ptuning": PtuningPEFTConfig, "lora": LoraPEFTConfig, + "mlp_head": MLPHeadPEFTConfig, "qlora": QLoraPEFTConfig, "selective": SelectivePEFTConfig, 'none': None,