From f27a982ed192f8039d5a30b51208847c9c1b0fc3 Mon Sep 17 00:00:00 2001
From: Huy Vu <86480512+huvunvidia@users.noreply.github.com>
Date: Mon, 28 Oct 2024 13:05:32 -0400
Subject: [PATCH 1/4] Generalizing Inference pipeline in NeMo 2.0 to support encoder-decoder models (#10924)

* initial commit

* adding example t5_generate.py

* workable inference code

* updating code

* update code

* workable solution for T5 tokenizer (we add 100 sentinel tokens when initializing tokenizer through setting config, instead of adding after initialization)

* separate autotokenizer's changes to another PR

* cleaning code

* addressing Marc's comments

* addressing Marc's reviews

* update code after merge

* small fix

* Apply isort and black reformatting

Signed-off-by: huvunvidia

---------

Signed-off-by: huvunvidia
Co-authored-by: Huy Vu2
Co-authored-by: root
Co-authored-by: huvunvidia
---
 .../tokenizers/huggingface/auto_tokenizer.py  |   7 +-
 nemo/collections/llm/api.py                   |   4 +
 nemo/collections/llm/gpt/model/base.py        |  22 ++++
 nemo/collections/llm/inference/base.py        |  63 +++++-----
 nemo/collections/llm/t5/model/t5.py           |  22 ++++
 scripts/llm/t5_generate.py                    | 108 ++++++++++++++++++
 6 files changed, 196 insertions(+), 30 deletions(-)
 create mode 100644 scripts/llm/t5_generate.py

diff --git a/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py b/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py
index 439322b8e810..43d377b73f34 100644
--- a/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py
+++ b/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py
@@ -224,9 +224,12 @@ def text_to_ids(self, text):
         ids = self.tokens_to_ids(tokens)
         return ids
 
-    def ids_to_text(self, ids):
+    def ids_to_text(self, ids, remove_special_tokens=True):
         tokens = self.ids_to_tokens(ids)
-        tokens_clean = [t for t in tokens if t not in self.tokenizer.all_special_tokens]
+        if remove_special_tokens:
+            tokens_clean = [t for t in tokens if t not in self.tokenizer.all_special_tokens]
+        else:
+            tokens_clean = tokens
         text = self.tokens_to_text(tokens_clean)
         return text
 
diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py
index a9b3d4361f5b..c4913e07da9b 100644
--- a/nemo/collections/llm/api.py
+++ b/nemo/collections/llm/api.py
@@ -437,7 +437,9 @@ def generate(
     path: Union[Path, str],
     prompts: list[str],
     trainer: nl.Trainer,
+    encoder_prompts: Optional[list[str]] = None,
     params_dtype: torch.dtype = torch.bfloat16,
+    add_BOS: bool = False,
     max_batch_size: int = 4,
     random_seed: Optional[int] = None,
     inference_batch_times_seqlen_threshold: int = 1000,
@@ -456,6 +458,8 @@ def generate(
         model=inference_wrapped_model,
         tokenizer=mcore_tokenizer,
         prompts=prompts,
+        encoder_prompts=encoder_prompts,
+        add_BOS=add_BOS,
         max_batch_size=max_batch_size,
         random_seed=random_seed,
         inference_params=inference_params,
diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py
index c7a6e01c673e..a7823c9bee80 100644
--- a/nemo/collections/llm/gpt/model/base.py
+++ b/nemo/collections/llm/gpt/model/base.py
@@ -18,6 +18,8 @@
 import pytorch_lightning as L
 import torch
 import torch.distributed
+from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
+from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
 from megatron.core.optimizer import OptimizerConfig
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_config import TransformerConfig
@@ -310,6 +312,26 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor:
 
         return self.forward_step(batch)
 
+    def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_threshold) -> torch.Tensor:
+        # This is to get the MCore model required in GPTInferenceWrapper.
+        mcore_model = self.module
+        while mcore_model:
+            if type(mcore_model) is MCoreGPTModel:
+                break
+            mcore_model = getattr(mcore_model, "module", None)
+        if mcore_model is None or type(mcore_model) is not MCoreGPTModel:
+            raise ValueError("Exact McoreGPTModel instance not found in the model structure.")
+
+        inference_wrapper_config = InferenceWrapperConfig(
+            hidden_size=mcore_model.config.hidden_size,
+            params_dtype=params_dtype,
+            inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
+            padded_vocab_size=self.tokenizer.vocab_size,
+        )
+
+        model_inference_wrapper = GPTInferenceWrapper(mcore_model, inference_wrapper_config)
+        return model_inference_wrapper
+
     @property
     def training_loss_reduction(self) -> MaskedTokenLossReduction:
         if not self._training_loss_reduction:
diff --git a/nemo/collections/llm/inference/base.py b/nemo/collections/llm/inference/base.py
index 9c4da7940b70..f3d202451c60 100644
--- a/nemo/collections/llm/inference/base.py
+++ b/nemo/collections/llm/inference/base.py
@@ -21,12 +21,16 @@
 import torch.distributed
 from megatron.core.inference.common_inference_params import CommonInferenceParams
 from megatron.core.inference.engines.mcore_engine import MCoreEngine
-from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
-from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
+from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
+    AbstractModelInferenceWrapper,
+)
+from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import (
+    EncoderDecoderTextGenerationController,
+)
 from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
     SimpleTextGenerationController,
 )
-from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel
+from megatron.core.transformer.module import MegatronModule
 from pytorch_lightning.trainer.states import TrainerFn
 
 import nemo.lightning as nl
@@ -37,19 +41,31 @@
 from nemo.lightning.pytorch.strategies.utils import RestoreConfig
 
 
-# We need this wrapper since mcore generate uses tokenizer.detokenize, tokenizer.tokenize to encode and decode prompts
+# We need this wrapper since mcore generate uses methods/properties such as tokenizer.detokenize, tokenizer.tokenize, tokenizer.bos, tokenizer.pad, etc. to encode and decode prompts
 class MCoreTokenizerWrappper:
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
         self.eod = tokenizer.eod
         self.vocab_size = tokenizer.vocab_size
 
-    def detokenize(self, tokens):
-        return self.tokenizer.ids_to_text(tokens)
+    def detokenize(self, tokens, remove_special_tokens=False):
+        return self.tokenizer.ids_to_text(tokens, remove_special_tokens)
 
     def tokenize(self, prompt):
         return self.tokenizer.text_to_ids(prompt)
 
+    @property
+    def additional_special_tokens_ids(self):
+        return self.tokenizer.additional_special_tokens_ids
+
+    @property
+    def bos(self):
+        return self.tokenizer.bos_id
+
+    @property
+    def pad(self):
+        return self.tokenizer.pad_id
+
 
 # TODO: Move to lightning Fabric API.
 def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.LightningModule):
@@ -101,41 +117,30 @@ def setup_model_and_tokenizer(
     trainer: nl.Trainer,
     params_dtype: torch.dtype = torch.bfloat16,
     inference_batch_times_seqlen_threshold: int = 1000,
-) -> tuple[MCoreGPTModel, MCoreTokenizerWrappper]:
+) -> tuple[MegatronModule, MCoreTokenizerWrappper]:
     model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(path), subpath="model")
     _setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)
 
-    # This is to get the MCore model required in GPTInferenceWrapper.
-    mcore_model = model
-    while mcore_model:
-        if type(mcore_model) is MCoreGPTModel:
-            break
-        mcore_model = getattr(mcore_model, "module", None)
-    if mcore_model is None or type(mcore_model) is not MCoreGPTModel:
-        raise ValueError("Exact McoreGPTModel instance not found in the model structure.")
-
-    inference_wrapped_model = GPTInferenceWrapper(
-        mcore_model,
-        InferenceWrapperConfig(
-            hidden_size=mcore_model.config.hidden_size,
-            params_dtype=params_dtype,
-            inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
-            padded_vocab_size=model.tokenizer.vocab_size,
-        ),
-    )
-
+    inference_wrapped_model = model.get_inference_wrapper(params_dtype, inference_batch_times_seqlen_threshold)
     return inference_wrapped_model, MCoreTokenizerWrappper(model.tokenizer)
 
 
 def generate(
-    model: GPTInferenceWrapper,
+    model: AbstractModelInferenceWrapper,
     tokenizer: MCoreTokenizerWrappper,
     prompts: list[str],
+    encoder_prompts: Optional[list[str]] = None,
+    add_BOS: bool = False,
     max_batch_size: int = 4,
     random_seed: Optional[int] = None,
     inference_params: Optional[CommonInferenceParams] = None,
 ) -> dict:
-    text_generation_controller = SimpleTextGenerationController(inference_wrapped_model=model, tokenizer=tokenizer)
+    if encoder_prompts is not None:
+        text_generation_controller = EncoderDecoderTextGenerationController(
+            inference_wrapped_model=model, tokenizer=tokenizer
+        )
+    else:
+        text_generation_controller = SimpleTextGenerationController(inference_wrapped_model=model, tokenizer=tokenizer)
     mcore_engine = MCoreEngine(
         text_generation_controller=text_generation_controller, max_batch_size=max_batch_size, random_seed=random_seed
     )
@@ -144,6 +149,8 @@ def generate(
 
     results = mcore_engine.generate(
         prompts=prompts,
+        add_BOS=add_BOS,
+        encoder_prompts=encoder_prompts,
         common_inference_params=common_inference_params,
     )
 
diff --git a/nemo/collections/llm/t5/model/t5.py b/nemo/collections/llm/t5/model/t5.py
index fa4095d75588..e6970cba3dd8 100644
--- a/nemo/collections/llm/t5/model/t5.py
+++ b/nemo/collections/llm/t5/model/t5.py
@@ -19,6 +19,8 @@
 import pytorch_lightning as L
 import torch
 import torch.distributed
+from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
+from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import T5InferenceWrapper
 from megatron.core.optimizer import OptimizerConfig
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_config import TransformerConfig
@@ -258,6 +260,26 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor:
 
         return self.forward_step(batch)
 
+    def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_threshold) -> torch.Tensor:
+        # This is to get the MCore model required in T5InferenceWrapper.
+        mcore_model = self.module
+        while mcore_model:
+            if type(mcore_model) is MCoreT5Model:
+                break
+            mcore_model = getattr(mcore_model, "module", None)
+        if mcore_model is None or type(mcore_model) is not MCoreT5Model:
+            raise ValueError("Exact MCoreT5Model instance not found in the model structure.")
+
+        inference_wrapper_config = InferenceWrapperConfig(
+            hidden_size=mcore_model.config.hidden_size,
+            params_dtype=params_dtype,
+            inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
+            padded_vocab_size=self.tokenizer.vocab_size,
+        )
+
+        model_inference_wrapper = T5InferenceWrapper(mcore_model, inference_wrapper_config)
+        return model_inference_wrapper
+
     @property
     def training_loss_reduction(self) -> MaskedTokenLossReduction:
         if not self._training_loss_reduction:
diff --git a/scripts/llm/t5_generate.py b/scripts/llm/t5_generate.py
new file mode 100644
index 000000000000..917fca6e1dfe
--- /dev/null
+++ b/scripts/llm/t5_generate.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# NOTE: This script is just an example of using NeMo checkpoints for generating outputs and is subject to change without notice.
+
+import argparse
+
+import torch
+import torch.distributed
+from megatron.core.inference.common_inference_params import CommonInferenceParams
+
+import nemo.lightning as nl
+from nemo.collections.llm import api
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='Generate text with a T5 model trained using NeMo 2.0')
+    parser.add_argument('--devices', type=int, help="Number of devices to use for inference.")
+    parser.add_argument('--checkpoint-path', type=str, help="Path to trained model.")
+    parser.add_argument("--temperature", type=float, default=1.0, help='Sampling temperature.')
+    parser.add_argument("--top_k", type=int, default=1, help='Top k sampling.')
+    parser.add_argument("--top_p", type=float, default=0.0, help='Top p sampling.')
+    parser.add_argument(
+        "--num-tokens-to-generate", type=int, default=30, help='Number of tokens to generate for each prompt.'
+    )
+    parser.add_argument(
+        "--prompts",
+        metavar='N',
+        type=str,
+        nargs='+',
+        help='Prompts with each prompt within quotes and separated by space.',
+    )
+    parser.add_argument(
+        "--encoder-prompts",
+        metavar='N',
+        type=str,
+        nargs='+',
+        help='Encoder input prompts with each prompt within quotes and separated by space.',
+    )
+    parser.add_argument("--max-batch-size", type=int, default=1, help='Max number of prompts to process at once.')
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+
+    args = get_args()
+
+    strategy = nl.MegatronStrategy(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=1,
+        context_parallel_size=1,
+        sequence_parallel=False,
+        setup_optimizers=False,
+        store_optimizer_states=False,
+    )
+
+    trainer = nl.Trainer(
+        accelerator="gpu",
+        devices=args.devices,
+        num_nodes=1,
+        strategy=strategy,
+        plugins=nl.MegatronMixedPrecision(
+            precision="bf16-mixed",
+            params_dtype=torch.bfloat16,
+            pipeline_dtype=torch.bfloat16,
+            autocast_enabled=False,
+            grad_reduce_in_fp32=False,
+        ),
+    )
+    prompts = [
+        "",
+        "",
+        "",
+    ]
+    encoder_prompts = [
+        "Hello, how are ?",
+        "How many r's are in the 'strawberry'?",
+        "Which number is ? 10.119 10.19?",
+    ]
+
+    results = api.generate(
+        path=args.checkpoint_path,
+        prompts=prompts,
+        encoder_prompts=encoder_prompts,
+        trainer=trainer,
+        add_BOS=True,
+        inference_params=CommonInferenceParams(
+            temperature=args.temperature, top_k=args.top_k, num_tokens_to_generate=args.num_tokens_to_generate
+        ),
+        text_only=True,
+    )
+    if torch.distributed.get_rank() == 0:
+        for i, r in enumerate(results):
+            print(prompts[i])
+            print("*" * 50)
+            print(r)
+            print("\n\n")

From 869625e59e1058a87fd6cd5947e36bcf21d3da4e Mon Sep 17 00:00:00 2001
From: guyueh1 <140554423+guyueh1@users.noreply.github.com>
Date: Mon, 28 Oct 2024 10:56:16 -0700
Subject: [PATCH 2/4] [Bug fix] use default_factory to provide default value in dataclasses in energon MultiModalSampleConfig (#11041)

Signed-off-by: Guyue Huang
Co-authored-by: Guyue Huang
---
 nemo/collections/multimodal/data/energon/config.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/multimodal/data/energon/config.py b/nemo/collections/multimodal/data/energon/config.py
index ab8ecf9fc06d..45ca8e9db800 100644
--- a/nemo/collections/multimodal/data/energon/config.py
+++ b/nemo/collections/multimodal/data/energon/config.py
@@ -64,7 +64,7 @@ class LLaVATemplateConfig(BaseConversationTemplateConfig):
 
 @dataclass
 class MultiModalSampleConfig:
-    image_token: ImageToken = ImageToken()
+    image_token: ImageToken = field(default_factory=ImageToken)
     ignore_place_holder: int = -100
-    conversation_template_config: LLaVATemplateConfig = LLaVATemplateConfig()
+    conversation_template_config: LLaVATemplateConfig = field(default_factory=LLaVATemplateConfig)
     image_following_text: bool = True

From 60cce8d4470bd7ab0f8fb84e150fe082d45c5b6a Mon Sep 17 00:00:00 2001
From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Date: Tue, 29 Oct 2024 00:48:16 +0530
Subject: [PATCH 3/4] Update ModelOpt Width Pruning example defaults (#10902)

* update width pruning example defaults

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>

* Update Dockerfile.ci

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>

* Undo CI version update

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>

---------

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
---
 .../conf/megatron_gpt_prune.yaml              | 19 +++++-----
 .../language_modeling/megatron_gpt_prune.py   | 29 ++++++++++---------
 2 files changed, 26 insertions(+), 22 deletions(-)

diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml
index cb26d5744b5b..f174aafed0ee 100644
--- a/examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml
@@ -23,19 +23,22 @@ trainer:
 model:
   tensor_model_parallel_size: 1 # Pruning currently only supports tensor_model_parallel_size=1
   pipeline_model_parallel_size: 1
-  restore_from_path: llama3.1-8b-base.nemo # Nemo file path
+  sequence_parallel: false # Sequence parallelism is not supported with pipeline parallelism
+  restore_from_path: llama3.1-8b-instruct.nemo # Nemo file path
 
   ## Activation Checkpoint
   activations_checkpoint_granularity: null # 'selective' or 'full'
   activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'
 
 prune:
-  calib_dataset: cnn_dailymail # wikitext, cnn_dailymail, or a local dataset
-  num_calib_size: 512 # number of samples used for calibration
-  ffn_hidden_size: 3584 # ffn_hidden_size in the pruned model, ffn_hidden_size // 4
-  num_attention_heads: 8 # num_attention_heads in the pruned model, num_attention_heads // 4
-  num_query_groups: 4 # num_query_groups in the pruned model, num_query_groups // 2
-  hidden_size: 2048 # hidden_size in the pruned model, hidden_size // 2
+  calib_dataset: wikitext # wikitext, cnn_dailymail, or a local dataset
+  num_calib_size: 1024 # number of samples used for calibration
+  # pruning constraints (null means no pruning)
+  ffn_hidden_size: 9216 # ffn_hidden_size in the pruned model
+  num_attention_heads: null # num_attention_heads in the pruned model
+  num_query_groups: null # num_query_groups in the pruned model
+  hidden_size: 3072 # hidden_size (embedding size) in the pruned model
+  num_layers: null # num_layers (depth) in the pruned model
 
 export:
-  save_path: llama3.1-8b-base-pruned.nemo # Path where the pruned model will be saved
+  save_path: llama3.1-8b-instruct-pruned.nemo # Path where the pruned model will be saved
diff --git a/examples/nlp/language_modeling/megatron_gpt_prune.py b/examples/nlp/language_modeling/megatron_gpt_prune.py
index b9bf8edbfb1a..de12b861a1c0 100644
--- a/examples/nlp/language_modeling/megatron_gpt_prune.py
+++ b/examples/nlp/language_modeling/megatron_gpt_prune.py
@@ -36,23 +36,23 @@ Example usage:
 ```
 python examples/nlp/language_modeling/megatron_gpt_prune.py \
-    model.restore_from_path=llama3.1-8b-base.nemo \
+    model.restore_from_path=llama3.1-8b-instruct.nemo \
     model.tensor_model_parallel_size=1 \
     model.pipeline_model_parallel_size=8 \
     trainer.num_nodes=1 \
     trainer.precision=bf16 \
     trainer.devices=8 \
-    prune.ffn_hidden_size=3584 \
-    prune.num_attention_heads=8 \
-    prune.num_query_groups=4 \
-    prune.hidden_size=2048 \
-    export.save_path=llama3.1-8b-base-pruned.nemo
+    prune.ffn_hidden_size=9216 \
+    prune.num_attention_heads=null \
+    prune.num_query_groups=null \
+    prune.hidden_size=3072 \
+    export.save_path=llama3.1-8b-instruct-pruned.nemo
 ```
 where tensor_model_parallel_size must be 1 because of the current prune API limitation
 """
 
 
-def get_calib_data_iter(data="cnn_dailymail", batch_size=64, calib_size=512, max_sequence_length=512):
+def get_calib_data_iter(data="wikitext", batch_size=64, calib_size=512, max_sequence_length=512):
     if data == "wikitext":
         dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
         text_column = "text"
@@ -73,18 +73,12 @@ def get_calib_data_iter(data="cnn_dailymail", batch_size=64, calib_size=512, max_sequence_length=512):
 
 @hydra_runner(config_path="conf", config_name="megatron_gpt_prune")
 def main(cfg) -> None:
-    if not torch.cuda.is_available():
-        raise EnvironmentError("GPU is required for the pruning.")
-
     # Overwrite model config with the one from the model checkpoint and apply pruning modifications
     model_cfg = load_config(cfg.model.restore_from_path)
     model_cfg.update(cfg.model)
     model_cfg.name = "modelopt"  # Use modelopt transformer spec for pruning
 
     assert cfg.model.tensor_model_parallel_size == 1, "Pruning currently only supports tensor_model_parallel_size=1"
-    assert (
-        not hasattr(cfg.model, "sequence_parallel") or not cfg.model.sequence_parallel
-    ), "Pruning currently does not support sequence parallelism"
 
     trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer)
     model = MegatronGPTModel.restore_from(
@@ -112,7 +106,13 @@ def forward_loop(model):
         constraints={
             "export_config": {
                 k: cfg.prune.get(k)
-                for k in ["ffn_hidden_size", "num_attention_heads", "num_query_groups", "hidden_size"]
+                for k in [
+                    "ffn_hidden_size",
+                    "num_attention_heads",
+                    "num_query_groups",
+                    "hidden_size",
+                    "num_layers",
+                ]
                 if cfg.prune.get(k) is not None
             },
         },
@@ -121,6 +121,7 @@ def forward_loop(model):
     )
 
     model_pruned.save_to(cfg.export.save_path)
+    print(f"Pruned model saved to {cfg.export.save_path}")
 
 
 if __name__ == '__main__':

From 5d3dadb419463a1feea6cb1f517d24c708c8f9ea Mon Sep 17 00:00:00 2001
From: Michal Futrega
Date: Mon, 28 Oct 2024 22:34:32 +0100
Subject: [PATCH 4/4] fix: Resolve mutable default issue in MultiModalSampleConfig dataclass (#11061)
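
A minimal standalone sketch (not taken from the patches above) of the dataclass pitfall that PATCH 2/4 fixes and PATCH 4/4 follows up on: a class instance used directly as a dataclass field default is rejected at class-definition time on Python 3.11+ ("mutable default ... is not allowed"), and on older versions it is silently shared by every instance, while field(default_factory=...) builds a fresh object per instance. ImageTokenDemo and SampleConfigDemo are hypothetical stand-ins for ImageToken and MultiModalSampleConfig.

    from dataclasses import dataclass, field


    @dataclass
    class ImageTokenDemo:  # hypothetical stand-in for ImageToken
        token_str: str = "<image>"
        token_id: int = -200


    @dataclass
    class SampleConfigDemo:  # hypothetical stand-in for MultiModalSampleConfig
        # image_token: ImageTokenDemo = ImageTokenDemo()  # ValueError on Python 3.11+ (mutable default)
        image_token: ImageTokenDemo = field(default_factory=ImageTokenDemo)  # fresh ImageTokenDemo per instance


    a, b = SampleConfigDemo(), SampleConfigDemo()
    a.image_token.token_id = 0
    assert b.image_token.token_id == -200  # b keeps its own default; nothing is shared between configs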