diff --git a/examples/nlp/language_modeling/megatron_retro_eval.py b/examples/nlp/language_modeling/megatron_retro_eval.py
index e1abbfa5cb40..0f1ee65ce0fa 100644
--- a/examples/nlp/language_modeling/megatron_retro_eval.py
+++ b/examples/nlp/language_modeling/megatron_retro_eval.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,26 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import asyncio
-import datetime
 import os
-import threading
-from functools import partial
-
-import torch
-from omegaconf import OmegaConf, open_dict
-from pytorch_lightning.trainer.trainer import Trainer
-from torch.utils.data import DataLoader, Dataset
-
-from nemo.collections.nlp.models.language_modeling.megatron_retro_model import MegatronRetroModel
-from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
-from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer
-from nemo.collections.nlp.modules.common.text_generation_utils import generate
+
+from examples.nlp.language_modeling.megatron_gpt_eval import RequestDataSet
+from omegaconf.omegaconf import OmegaConf, open_dict
+from pytorch_lightning import Trainer
+from torch.utils.data import DataLoader
+
+from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel
 from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam
-from nemo.collections.nlp.parts.nlp_overrides import CustomProgressBar, NLPDDPStrategy, NLPSaveRestoreConnector
+from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
 from nemo.core.config import hydra_runner
-from nemo.utils.app_state import AppState
-from nemo.utils.model_utils import inject_model_parallel_rank
 
 try:
     from megatron.core import parallel_state
@@ -43,164 +34,62 @@ HAVE_MEGATRON_CORE = False
 
 """
-This is the script to run Retro text generation.
+This is the script to run RETRO Model text generation.
 
 Usage:
-    Assume the model has TP=1, PP=1 in the following use cases.
-    a. run greedy inference from a nemo file:
+    Assume the model has TP=1, PP=1
+    run greedy inference from a nemo file:
         python megatron_retro_eval.py \
-            gpt_model_file=PATH_TO_MODEL \
-            inference.greedy=True \
-            inference.add_BOS=True \
             trainer.devices=1 \
             trainer.num_nodes=1 \
+            trainer.accelerator=gpu \
+            trainer.precision=16 \
+            inference.tokens_to_generate=128 \
+            inference.greedy=True \
+            retro_model_file=path_to_retro_nemo_file \
             tensor_model_parallel_size=-1 \
             pipeline_model_parallel_size=-1 \
-            prompts=[prompt1, prompt2] \
-            inference.retro_inference.retro_num_neighbors=2 \
-            neighbors=[[prompt1_neighbor1, prompt1_neighbor2], [prompt2_neighbor1, prompt2_neighbor2]]
-
-
-        ```
+            retrieval_service.faiss_devices='0' \
+            retrieval_service.faiss_index=path_to_faiss_index \
+            retrieval_service.retrieval_index=path_to_retrieval_dataset \
+            retrieval_service.neighbors=20
 """
 
-if not torch.cuda.is_available():
-    raise EnvironmentError("GPU is needed for the inference")
-
-
-class RequestDataSet(Dataset):
-    def __init__(self, sentences, neighbors):
-        super().__init__()
-        self.sentences = sentences
-        self.neighbors = neighbors
-
-    def __len__(self,):
-        return len(self.sentences)
-
-    def __getitem__(self, idx):
-        return {'prompts': self.sentences[idx], 'neighbors': self.neighbors[idx]}
-
-
-def remove_padded_prompts(response, nb_paddings):
-    result = {}
-    for k, v in response.items():
-        if v != None and (type(v) is list or type(v) is torch.Tensor):
-            v = v[:-nb_paddings]
-        result[k] = v
-    return result
-
 
 @hydra_runner(config_path="conf", config_name="megatron_retro_inference")
 def main(cfg) -> None:
+    trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer)
 
-    # trainer required for restoring model parallel models
-    trainer = Trainer(
-        strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)),
-        **cfg.trainer,
-        callbacks=[CustomProgressBar()],
-    )
+    model_path = cfg.retro_model_file
 
-    if cfg.retro_model_file is not None:
-        if (
-            cfg.tensor_model_parallel_size < 0
-            or cfg.pipeline_model_parallel_size < 0
-            or cfg.get('pipeline_model_parallel_split_rank', -1) < 0
-        ):
-            save_restore_connector = NLPSaveRestoreConnector()
-            if os.path.isdir(cfg.retro_model_file):
-                save_restore_connector.model_extracted_dir = cfg.retro_model_file
-            model_config = MegatronRetroModel.restore_from(
-                restore_path=cfg.retro_model_file,
-                trainer=trainer,
-                return_config=True,
-                save_restore_connector=save_restore_connector,
-            )
-
-            # with dist checkpointing we don't need to set this
-            if not model_config.get('mcore_gpt', False):
-                with open_dict(cfg):
-                    cfg.tensor_model_parallel_size = model_config.get('tensor_model_parallel_size', 1)
-                    cfg.pipeline_model_parallel_size = model_config.get('pipeline_model_parallel_size', 1)
-                    cfg.pipeline_model_parallel_split_rank = model_config.get('pipeline_model_parallel_split_rank', 0)
-
-    assert (
-        cfg.trainer.devices * cfg.trainer.num_nodes
-        == cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
-    ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size"
-
-    if cfg.retro_model_file:
-        save_restore_connector = NLPSaveRestoreConnector()
-        if os.path.isdir(cfg.retro_model_file):
-            save_restore_connector.model_extracted_dir = cfg.retro_model_file
-
-        pretrained_cfg = MegatronRetroModel.restore_from(
-            restore_path=cfg.retro_model_file,
-            trainer=trainer,
-            return_config=True,
-            save_restore_connector=save_restore_connector,
-        )
-        OmegaConf.set_struct(pretrained_cfg, True)
-        with open_dict(pretrained_cfg):
-            pretrained_cfg.sequence_parallel = False
-            pretrained_cfg.activations_checkpoint_granularity = None
-            pretrained_cfg.activations_checkpoint_method = None
-            pretrained_cfg.precision = trainer.precision
-            pretrained_cfg["use_flash_attention"] = cfg.inference.get("use_flash_attention", False)
-            if pretrained_cfg.get('mcore_gpt', False):
-                # with dist checkpointing we can use the model parallel config specified by the user
-                pretrained_cfg.tensor_model_parallel_size = cfg.tensor_model_parallel_size
-                pretrained_cfg.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size
-            if trainer.precision == "16":
-                pretrained_cfg.megatron_amp_O2 = False
-            elif trainer.precision in ['bf16', 'bf16-mixed'] and cfg.get('megatron_amp_O2', False):
-                pretrained_cfg.megatron_amp_O2 = True
-        model = MegatronRetroModel.restore_from(
-            restore_path=cfg.retro_model_file,
-            trainer=trainer,
-            override_config_path=pretrained_cfg,
-            save_restore_connector=save_restore_connector,
-            map_location=f'cuda:{trainer.local_rank}',  # map_location is needed for converted models
-        )
-    elif cfg.checkpoint_dir:
-        app_state = AppState()
-        if cfg.tensor_model_parallel_size > 1 or cfg.pipeline_model_parallel_size > 1:
-            app_state.model_parallel_size = cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
-            app_state.tensor_model_parallel_size = cfg.tensor_model_parallel_size
-            app_state.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size
-            (
-                app_state.tensor_model_parallel_rank,
-                app_state.pipeline_model_parallel_rank,
-                app_state.model_parallel_size,
-                app_state.data_parallel_size,
-                app_state.pipeline_model_parallel_split_rank,
-                app_state.virtual_pipeline_model_parallel_rank,
-            ) = fake_initialize_model_parallel(
-                world_size=app_state.model_parallel_size,
-                rank=trainer.global_rank,
-                tensor_model_parallel_size_=cfg.tensor_model_parallel_size,
-                pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size,
-                pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank,
-            )
-        checkpoint_path = os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name)
-        # checkpoint_path is a dir in case of distributed checkpointing
-        if not os.path.isdir(checkpoint_path):
-            # legacy checkpoint needs model parallel rank injection
-            checkpoint_path = inject_model_parallel_rank(os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name))
-        model = MegatronRetroModel.load_from_checkpoint(checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer)
-    else:
-        raise ValueError("need at least a nemo file or checkpoint dir")
+    save_restore_connector = NLPSaveRestoreConnector()
 
-    # # DEBUGGING
-    # print("RETRO model loaded: ")
-    # print(model)
+    if os.path.isdir(model_path):
+        save_restore_connector.model_extracted_dir = model_path
 
-    model.freeze()
+    model_cfg = MegatronRetrievalModel.restore_from(
+        model_path, trainer=trainer, return_config=True, save_restore_connector=save_restore_connector,
+    )
 
-    # Have to turn off activations_checkpoint_method for inference
-    try:
-        model.model.language_model.encoder.activations_checkpoint_method = None
-    except AttributeError:
-        pass
+    with open_dict(model_cfg):
+        model_cfg.precision = trainer.precision
+        model_cfg.sequence_parallel = False
+        model_cfg.activations_checkpoint_granularity = None
+        model_cfg.activations_checkpoint_method = None
+
+    if (
+        cfg.tensor_model_parallel_size < 0
+        or cfg.pipeline_model_parallel_size < 0
+        or cfg.get('pipeline_model_parallel_split_rank', -1) < 0
+    ):
+        with open_dict(cfg):
+            cfg.tensor_model_parallel_size = model_cfg.get('tensor_model_parallel_size', 1)
+            cfg.pipeline_model_parallel_size = model_cfg.get('pipeline_model_parallel_size', 1)
+            cfg.pipeline_model_parallel_split_rank = model_cfg.get('pipeline_model_parallel_split_rank', 0)
+
+    model = MegatronRetrievalModel.restore_from(
+        model_path, trainer=trainer, save_restore_connector=save_restore_connector, override_config_path=model_cfg,
+    )
 
     length_params: LengthParam = {
         "max_length": cfg.inference.tokens_to_generate,
@@ -216,39 +105,40 @@ def main(cfg) -> None:
         "add_BOS": cfg.inference.add_BOS,
         "all_probs": cfg.inference.all_probs,
         "compute_logprob": cfg.inference.compute_logprob,
-        "end_strings": cfg.inference.end_strings,
     }
 
-    # # DEBUGGING
-    # # Turn off first method for now, because both use text_generation_utils.generate(), and first method is more complicated
-    # # First method of running text generation, call model.generate method
-    # response = model.generate(
-    #     inputs=OmegaConf.to_container(cfg.prompts), length_params=length_params, sampling_params=sampling_params
-    # )
-
-    # print("***************************")
-    # print(response)
-    # print("***************************")
-
+    # check whether the DDP is initialized
+    if parallel_state.is_unitialized():
 
-    # DEBUGGING
-    cfg.prompts = ["Hi, my name is Huy. What's your name?", "Today looks like a nice day. Do you think so?"]
-    cfg.neighbors = [["I am 28 years old.","Your name is Jeff."], ["You also think today is a nice day.","Yesterday was very nice too."]]
-
+        def dummy():
+            return
+
+        if model.trainer.strategy.launcher is not None:
+            model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
+        model.trainer.strategy.setup_environment()
 
-    # Second method of running text generation, call trainer.predict [recommended]
-    bs = 2
-    prompts = cfg.prompts
-    neighbors = cfg.neighbors
-    ds = RequestDataSet(prompts, neighbors)
-    request_dl = DataLoader(dataset=ds, batch_size=bs)
     config = OmegaConf.to_container(cfg.inference)
-    model.set_inference_config(config)
-    response = trainer.predict(model, request_dl)
+    retrieval_service = OmegaConf.to_container(cfg.retrieval_service)
+    model.set_inference_config(config, retrieval_service)
+
+    if not cfg.use_predict_method:
+        # First method of running text generation, call model.generate method
+        response = model.generate(
+            inputs=OmegaConf.to_container(cfg.prompts),
+            length_params=length_params,
+            sampling_params=sampling_params,
+            strategy=model.inference_strategy,
+        )
+    else:
+        # Second method of running text generation, call trainer.predict
+        ds = RequestDataSet(OmegaConf.to_container(cfg.prompts))
+        request_dl = DataLoader(dataset=ds, batch_size=cfg.inference_batch_size)
+        response = trainer.predict(model, request_dl)
 
     print("***************************")
     print(response)
     print("***************************")
 
+
 if __name__ == '__main__':
-    main()  # noqa pylint: disable=no-value-for-parameter
+    main()
\ No newline at end of file
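
Note on the RequestDataSet import: the updated script pulls RequestDataSet from examples/nlp/language_modeling/megatron_gpt_eval.py instead of defining its own prompts-plus-neighbors dataset, so the trainer.predict path only wraps plain prompt strings and retrieval neighbors are supplied through the retrieval_service config. A minimal sketch of such a prompt-only dataset is shown below; it is assumed to mirror the definition in megatron_gpt_eval.py and is illustrative only, not part of this diff.

    from torch.utils.data import Dataset

    class RequestDataSet(Dataset):
        # Wraps a list of prompt strings so a DataLoader can feed them to trainer.predict().
        def __init__(self, sentences):
            super().__init__()
            self.sentences = sentences

        def __len__(self):
            return len(self.sentences)

        def __getitem__(self, idx):
            # Each item is the raw prompt string; batching is handled by the DataLoader.
            return self.sentences[idx]

As used in the predict branch above: ds = RequestDataSet(OmegaConf.to_container(cfg.prompts)); request_dl = DataLoader(dataset=ds, batch_size=cfg.inference_batch_size); response = trainer.predict(model, request_dl).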