From 8f779babf33203f0ea42ebfcb3edc92fde5742d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20Kami=C5=84ski?= <67481570+Laplasjan107@users.noreply.github.com> Date: Mon, 25 Nov 2024 13:24:15 +0100 Subject: [PATCH] Add sample generate to PTQ for NeMo 2.0 (#11339) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Initial commit Signed-off-by: Piotr Kaminski * Remove leftover print Signed-off-by: Piotr Kaminski * Apply isort and black reformatting Signed-off-by: Laplasjan107 * Fix docs and type annotations Signed-off-by: Piotr Kaminski * Apply isort and black reformatting Signed-off-by: Laplasjan107 * Applied code review suggestions Signed-off-by: Piotr Kaminski * Apply isort and black reformatting Signed-off-by: Laplasjan107 * Fix _get_decoder_type parameter Signed-off-by: Piotr Kamiński <67481570+Laplasjan107@users.noreply.github.com> --------- Signed-off-by: Piotr Kaminski Signed-off-by: Laplasjan107 Signed-off-by: Piotr Kamiński <67481570+Laplasjan107@users.noreply.github.com> Co-authored-by: Piotr Kaminski Co-authored-by: Laplasjan107 --- .../collections/llm/quantization/quantizer.py | 113 ++++++++++-------- nemo/collections/llm/quantization/utils.py | 32 ++++- scripts/llm/ptq.py | 9 ++ 3 files changed, 103 insertions(+), 51 deletions(-) diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index 45f72f06741e..d41ba39f39ea 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -24,10 +24,12 @@ from tqdm import tqdm from nemo.collections import llm -from nemo.lightning.ckpt_utils import CONTEXT_PATH +from nemo.collections.llm.inference import MCoreTokenizerWrappper, generate +from nemo.lightning.ckpt_utils import ckpt_to_context_subdir +from nemo.lightning.megatron_parallel import MegatronParallel from nemo.utils import logging -from .utils import get_unwrapped_mcore_model +from .utils import get_modelopt_decoder_type, get_unwrapped_mcore_model try: import modelopt.torch.quantization as mtq @@ -83,35 +85,12 @@ class ExportConfig: decoder_type: Optional[str] = None inference_tensor_parallel: int = 1 inference_pipeline_parallel: int = 1 + generate_sample: bool = False def __post_init__(self): self.path = Path(self.path) -def get_modelopt_decoder_type(config: llm.GPTConfig) -> str: - """Infers the modelopt decoder type from GPTConfig class.""" - mapping = [ - (llm.Baichuan2Config, "baichuan"), - (llm.ChatGLMConfig, "chatglm"), - (llm.GemmaConfig, "gemma"), - (llm.LlamaConfig, "llama"), - (llm.MistralConfig7B, "llama"), - (llm.MixtralConfig, "llama"), - (llm.NemotronConfig, "gptnext"), - (llm.Qwen2Config, "qwen"), - # TODO: (llm.StarcoderConfig, ""), - (llm.Starcoder2Config, "gptnext"), - ] - - for config_class, decoder_type in mapping: - if isinstance(config, config_class): - return decoder_type - - logging.warning("Could not directly infer the decoder type") - # TODO: Add a reasonable behavior for GPTConfig (for instance based on position_embedding_type) - return "llama" - - class Quantizer: """Post-training quantization (PTQ) and TensorRT-LLM export of NeMo 2.0 checkpoints. @@ -146,16 +125,37 @@ def __init__(self, quantization_config: QuantizationConfig, export_config: Expor assert dtype in SUPPORTED_DTYPE, f"Unsupported export dtype: {dtype}" self.torch_dtype = torch_dtype_from_precision(dtype) - def _setup(self, model: llm.GPTModel) -> None: + @staticmethod + def _setup(model: MegatronParallel) -> None: """Setup model for quantization.""" # TODO: disable activation checkpointing model.config.vocab_size = model.tokenizer.vocab_size model.freeze() - def _get_decoder_type(self, config: llm.GPTConfig): - return self.export_config.decoder_type or get_modelopt_decoder_type(config) + def _get_decoder_type(self, model: MegatronParallel): + if self.export_config.decoder_type is not None: + return self.export_config.decoder_type + unwrapped_model = model + while not isinstance(unwrapped_model, llm.GPTModel): + unwrapped_model = unwrapped_model.module + + return get_modelopt_decoder_type(unwrapped_model) + + @staticmethod + def _generate_sample(model: MegatronParallel): + prompts = ["Born in north-east France, Soyer trained as a", "Born in California, Soyer trained as a"] + + mcore_tokenizer = MCoreTokenizerWrappper(model.tokenizer) + mcore_inference = model.get_inference_wrapper( + params_dtype=torch.bfloat16, inference_batch_times_seqlen_threshold=30 + ) + + generated = [r.generated_text for r in generate(mcore_inference, mcore_tokenizer, prompts)] + outputs = [prompt + generation for prompt, generation in zip(prompts, generated)] + + logging.info(f'Sample generation after PTQ (with prompts): {outputs}') - def quantize(self, model: llm.GPTModel, forward_loop=None): + def quantize(self, model: MegatronParallel, forward_loop=None): """Quantize the model and calibrate using given forward loop.""" if forward_loop is None: get_dataloader = create_data_iterator_getter( @@ -185,7 +185,7 @@ def quantize(self, model: llm.GPTModel, forward_loop=None): self._setup(model) unwrapped_model = get_unwrapped_mcore_model(model) - decoder_type = self._get_decoder_type(unwrapped_model.config) + decoder_type = self._get_decoder_type(model) quant_cfg = QUANT_CFG_CHOICES[algorithm] if "awq" in algorithm: weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] @@ -230,6 +230,10 @@ def quantize(self, model: llm.GPTModel, forward_loop=None): if dist.get_rank() == 0: mtq.print_quant_summary(unwrapped_model) + if self.export_config.generate_sample: + logging.info("Generating a sample output after model quantization.") + self._generate_sample(model) + return model def create_megatron_forward_loop( @@ -266,21 +270,34 @@ def loop(model): return loop - def export(self, model: llm.GPTModel, model_dir: str) -> None: + @staticmethod + def _validate_quantized_checkpoint(checkpoint_dir: Path, tensor_parallelism_size: int) -> bool: + """Basic validation of the model structure.""" + + saved_config = (checkpoint_dir / 'config.json').exists() + saved_weights = True + for i in range(tensor_parallelism_size): + saved_weights &= (checkpoint_dir / f'rank{i}.safetensors').exists() + + export_successful = saved_config and saved_weights + if not export_successful: + logging.error("Failed to export the quantized model.") + return export_successful + + def export(self, model: MegatronParallel, model_dir: str) -> None: """Export model to a TensorRT-LLM checkpoint.""" - assert self.export_config is not None, "Export config is not set" - # TODO: Add sample generate - # TODO: Support megatron_amp_O2 export_dir = self.export_config.path + inference_tp = self.export_config.inference_tensor_parallel + inference_pp = self.export_config.inference_pipeline_parallel use_nfs_workspace = model.config.pipeline_model_parallel_size > 1 export_tensorrt_llm_checkpoint( model=get_unwrapped_mcore_model(model), - decoder_type=self._get_decoder_type(model.config), + decoder_type=self._get_decoder_type(model), dtype=self.torch_dtype, export_dir=export_dir, - inference_tensor_parallel=self.export_config.inference_tensor_parallel, - inference_pipeline_parallel=self.export_config.inference_pipeline_parallel, + inference_tensor_parallel=inference_tp, + inference_pipeline_parallel=inference_pp, use_nfs_workspace=use_nfs_workspace, ) dist.barrier() @@ -288,14 +305,13 @@ def export(self, model: llm.GPTModel, model_dir: str) -> None: # Save the model context in order to restore its tokenizer later. The destination # path is "nemo_context" as this name is used in nemo.export to setup tokenizer. if dist.get_rank() == 0: + assert self._validate_quantized_checkpoint(export_dir, inference_tp) shutil.copytree( - os.path.join(model_dir, CONTEXT_PATH), + ckpt_to_context_subdir(model_dir), os.path.join(export_dir, "nemo_context"), dirs_exist_ok=True, ) - logging.info("Model context saved.") - - logging.info(f"Export succeeded, model has been exported to {export_dir}.") + logging.info(f"Export succeeded, model has been exported to {export_dir}.") def get_calib_data_iter( @@ -323,7 +339,7 @@ def get_calib_data_iter( def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration_size): """Create a function that provides iterator over a given dataset.""" - def _iterator(): + def _get_iterator(): CHARACTERS_PER_TOKEN = 4 dataloader = get_calib_data_iter( @@ -332,14 +348,13 @@ def _iterator(): batch_size=batch_size, calib_size=calibration_size, ) + + data = [] for batch in dataloader: batch = [model.tokenizer.text_to_ids(text)[:seq_len] for text in batch] batch = [ids + (seq_len - len(ids)) * [model.tokenizer.eos] for ids in batch] - yield torch.tensor(batch, device=model.device) + data.append(torch.tensor(batch, device=model.device)) - def _iterator_getter(): - dataloader = _iterator() - dataloader = [data for data in dataloader] - return iter(tqdm(dataloader)) + return iter(tqdm(data)) - return _iterator_getter + return _get_iterator diff --git a/nemo/collections/llm/quantization/utils.py b/nemo/collections/llm/quantization/utils.py index bdfccb208d06..20739c872e80 100644 --- a/nemo/collections/llm/quantization/utils.py +++ b/nemo/collections/llm/quantization/utils.py @@ -23,8 +23,33 @@ from nemo.utils import logging +def get_modelopt_decoder_type(model: llm.GPTModel) -> str: + """Infers the modelopt decoder type from GPTModel subclass.""" + mapping = [ + (llm.Baichuan2Model, "baichuan"), + (llm.ChatGLMModel, "chatglm"), + (llm.Gemma2Model, "gemma2"), + (llm.GemmaModel, "gemma"), + (llm.LlamaModel, "llama"), + (llm.MistralModel, "llama"), + (llm.MixtralModel, "llama"), + (llm.NemotronModel, "gptnext"), + (llm.Qwen2Model, "qwen"), + (llm.StarcoderModel, "gptnext"), + (llm.Starcoder2Model, "gptnext"), + (llm.Phi3Model, "phi3"), + ] + + for config_class, decoder_type in mapping: + if isinstance(model, config_class): + return decoder_type + + logging.warning("Could not infer the decoder type") + return None + + def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: - """Modify model config for TensorRT Model Optimizer""" + """Modify model config for TensorRT-Model-Optimizer quantization""" from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import ( get_gpt_layer_modelopt_spec, @@ -46,7 +71,9 @@ def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: def load_with_modelopt_layer_spec( nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1, inference_only: bool = True ): - # TODO: setting ddp="pytorch" with manually deleting model.optim is a hackish way to disable DDP initialization. Needs a systematic solution. + """Loads a model from a NeMo 2.0 checkpoint using modelopt layer spec.""" + # TODO: setting ddp="pytorch" and deleting model.optim is a hackish way to disable DDP initialization. + # Needs a systematic solution. if inference_only: strategy = nl.MegatronStrategy( tensor_model_parallel_size=calib_tp, @@ -81,6 +108,7 @@ def load_with_modelopt_layer_spec( def get_unwrapped_mcore_model(model): + """Unwraps NeMo 2.0 to base MCore model.""" from megatron.core.models.gpt import GPTModel as MCoreGPTModel unwrapped_model = model diff --git a/scripts/llm/ptq.py b/scripts/llm/ptq.py index c04d32290e5f..2afe38c37b4d 100644 --- a/scripts/llm/ptq.py +++ b/scripts/llm/ptq.py @@ -17,6 +17,8 @@ def get_args(): + """Parses PTQ arguments""" + parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="NeMo PTQ argument parser", @@ -58,6 +60,10 @@ def get_args(): type=str, help='Calibration dataset to be used. Should be \"wikitext\", \"cnn_dailymail\" or path to a local .json file', ) + parser.add_argument( + '--generate_sample', help='Generate sample model output after performing PTQ', action='store_true' + ) + parser.set_defaults(generate_sample=False) args = parser.parse_args() if args.output_path is None: @@ -68,6 +74,8 @@ def get_args(): def main(): + """Example NeMo 2.0 Post Training Quantization workflow""" + args = get_args() quantization_config = quantization.QuantizationConfig( @@ -87,6 +95,7 @@ def main(): inference_tensor_parallel=args.tensor_parallelism_size, inference_pipeline_parallel=args.pipeline_parallelism_size, dtype=args.dtype, + generate_sample=args.generate_sample, ) quantizer = quantization.Quantizer(quantization_config, export_config)