From f02f36c6241202a741d917e888f203ae5fea6f63 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Thu, 26 Sep 2024 12:03:32 -0700 Subject: [PATCH 01/24] initial commit Signed-off-by: Piotr Kaminski --- examples/llm/quantization/ptq.py | 173 +++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 examples/llm/quantization/ptq.py diff --git a/examples/llm/quantization/ptq.py b/examples/llm/quantization/ptq.py new file mode 100644 index 000000000000..bc95e409731e --- /dev/null +++ b/examples/llm/quantization/ptq.py @@ -0,0 +1,173 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import sys + +import torch +from tqdm import tqdm +from datasets import load_dataset + +from nemo import lightning as nl +from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec +import nemo.collections.llm as llm + +HAS_MODELOPT = True +try: + import modelopt.torch.quantization as mtq + from modelopt.torch.export import export_tensorrt_llm_checkpoint +except: + HAS_MODELOPT = False + + +# Sample hyperparameters +MODEL_HPARAMS = { + "llama3": ('meta-llama/Meta-Llama-3-8B', "llama", llm.LlamaModel, llm.Llama3Config8B), + "llama3-70b": ('meta-llama/Meta-Llama-3-70B', "llama", llm.LlamaModel, llm.Llama3Config70B), + "mistral": ('mistralai/Mistral-7B-Instruct-v0.3', "llama", llm.MistralModel, llm.MistralConfig7B), + "gemma": ('google/gemma-2b', "gemma", llm.GemmaModel, llm.GemmaConfig2B), +} + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description="NeMo PTQ argument parser", + ) + parser.add_argument("-nc", "--nemo_checkpoint", type=str, help="Source NeMo 2.0 checkpoint") + parser.add_argument( + "-tps", + "--tensor_parallelism_size", + type=int, + default=1 + ) + parser.add_argument( + '-id', + '--model_id', + type=str, + required=True, + choices=list(MODEL_HPARAMS.keys()), + help='Model id for MODEL_HPARAMS map' + ) + parser.add_argument( + '-out', + '--output_path', + type=str, + default="", + help='Path for the exported engine' + ) + + args = parser.parse_args(sys.argv[1:]) + if args.output_path == "": + args.output_path = f'./trt_llm_fp8_ckpt-{args.model_id}-tp{args.tensor_parallelism_size}' + + args.name, args.modelopt_type, args.model, args.config = MODEL_HPARAMS[args.model_id] + return args + + +# TODO: Unify implementation with examples/nlp/language_modeling/megatron_gpt_ptq.py +def get_calib_data_iter(data="cnn_dailymail", batch_size=64, calib_size=512, max_sequence_length=512): + if data == "wikitext": + dataset = load_dataset("wikitext", "wikitext-103-v1", split="train") + text_column = "text" + elif data == "cnn_dailymail": + dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train") + text_column = "article" + else: + # Assume a local JSON dataset with a column named "text" + dataset = load_dataset("json", data_files=data, split="train") + text_column = 
"text" + calib_size = max(min(len(dataset), calib_size), batch_size) + for i in range(calib_size // batch_size): + batch = dataset[i * batch_size : (i + 1) * batch_size][text_column] + for j in range(len(batch)): + batch[j] = batch[j][:max_sequence_length] + yield batch + + +# TODO: generalize +def save_tokenizer_config(model_name: str, output_path: str): + from os.path import isfile + tokenizer_path = output_path + '/tokenizer_config.yaml' + if not isfile(tokenizer_path): + with open(tokenizer_path, 'w') as tokenizer_config: + tokenizer_config.write(f"library: huggingface\ntype: {model_name}\nuse_fast: true\nvocab_file: null\nmerge_file: null") + + +# TODO: use llm.generate (#10471) once merged +def forward_loop(model): + tokenizer = model.tokenizer + dataloader = get_calib_data_iter() + dataloader = [data for data in dataloader] + for batch in tqdm(dataloader): + batch = [tokenizer.text_to_ids(text) for text in batch] + max_len = max([len(text) for text in batch]) + batch = [ids + (max_len - len(ids)) * [tokenizer.eos] for ids in batch] + position_ids = torch.arange(max_len, device=model.device).expand((len(batch), max_len)) + batch = torch.tensor(batch, device=model.device) + model_input = { + "input_ids": batch, + "position_ids": position_ids, + "attention_mask": None, + } + model(**model_input) + + +if __name__ == '__main__': + if not HAS_MODELOPT: + print("Modelopt could not be imported") + exit(1) + + args = get_args() + + # TODO: make/extend the Quantizer class from nemo.export.quantizer + # Configure global state + trainer = nl.Trainer( + devices=args.tensor_parallelism_size, + strategy=nl.MegatronStrategy(tensor_model_parallel_size=args.tensor_parallelism_size), + plugins=nl.MegatronMixedPrecision(precision='16-mixed') + ) + fabric = trainer.to_fabric() + trainer.strategy.setup_environment() + + # Load model with modelopt layer spec + model = nl.io.load_context(args.nemo_checkpoint).model + model.config.transformer_layer_spec = get_gpt_layer_modelopt_spec() + + # TODO: [0] works only for PP = 1. Will be changed when PP support is added. 
+ model = fabric.load_model(args.nemo_checkpoint, model=model)[0] + + # TODO: allow other configs + atq_config = mtq.FP8_DEFAULT_CFG + enable_quant_kv_cache = True + print(f'{"Enable" if enable_quant_kv_cache else "Disable"} KV cache quantization') + atq_config["quant_cfg"]["*output_quantizer"] = { # type: ignore[index] + "num_bits": (4, 3), + "axis": None, + "enable": enable_quant_kv_cache, + } + + model = mtq.quantize(model, atq_config, forward_loop) + mtq.print_quant_summary(model) + export_tensorrt_llm_checkpoint( + model, + args.modelopt_type, + torch.float16, + export_dir=args.output_path, + inference_tensor_parallel=args.tensor_parallelism_size, + inference_pipeline_parallel=1, + use_nfs_workspace=False, + ) + save_tokenizer_config(args.name, args.output_path) + From b17fe3c7a7a3acd48eef481234ded079d9fc2131 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Mon, 30 Sep 2024 02:41:48 -0700 Subject: [PATCH 02/24] create Quantizer for NeMo 2.0 Signed-off-by: Piotr Kaminski --- examples/llm/quantization/ptq.py | 148 +------ nemo/collections/llm/__init__.py | 1 + nemo/collections/llm/quantization/__init__.py | 15 + .../collections/llm/quantization/quantizer.py | 395 ++++++++++++++++++ nemo/export/trt_llm/qnemo/tokenizer_utils.py | 7 +- 5 files changed, 433 insertions(+), 133 deletions(-) create mode 100644 nemo/collections/llm/quantization/__init__.py create mode 100644 nemo/collections/llm/quantization/quantizer.py diff --git a/examples/llm/quantization/ptq.py b/examples/llm/quantization/ptq.py index bc95e409731e..d961c35f8c7c 100644 --- a/examples/llm/quantization/ptq.py +++ b/examples/llm/quantization/ptq.py @@ -12,104 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse import sys - import torch from tqdm import tqdm -from datasets import load_dataset - -from nemo import lightning as nl -from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec -import nemo.collections.llm as llm - -HAS_MODELOPT = True -try: - import modelopt.torch.quantization as mtq - from modelopt.torch.export import export_tensorrt_llm_checkpoint -except: - HAS_MODELOPT = False - - -# Sample hyperparameters -MODEL_HPARAMS = { - "llama3": ('meta-llama/Meta-Llama-3-8B', "llama", llm.LlamaModel, llm.Llama3Config8B), - "llama3-70b": ('meta-llama/Meta-Llama-3-70B', "llama", llm.LlamaModel, llm.Llama3Config70B), - "mistral": ('mistralai/Mistral-7B-Instruct-v0.3', "llama", llm.MistralModel, llm.MistralConfig7B), - "gemma": ('google/gemma-2b', "gemma", llm.GemmaModel, llm.GemmaConfig2B), -} - - -def get_args(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - description="NeMo PTQ argument parser", - ) - parser.add_argument("-nc", "--nemo_checkpoint", type=str, help="Source NeMo 2.0 checkpoint") - parser.add_argument( - "-tps", - "--tensor_parallelism_size", - type=int, - default=1 - ) - parser.add_argument( - '-id', - '--model_id', - type=str, - required=True, - choices=list(MODEL_HPARAMS.keys()), - help='Model id for MODEL_HPARAMS map' - ) - parser.add_argument( - '-out', - '--output_path', - type=str, - default="", - help='Path for the exported engine' - ) - - args = parser.parse_args(sys.argv[1:]) - if args.output_path == "": - args.output_path = f'./trt_llm_fp8_ckpt-{args.model_id}-tp{args.tensor_parallelism_size}' - - args.name, args.modelopt_type, args.model, args.config = MODEL_HPARAMS[args.model_id] - return args +from 
nemo.collections.llm.quantization import Quantizer, get_calib_data_iter +# TODO: Support PP +# TODO: Inference TP/PP != Calibration TP/PP -# TODO: Unify implementation with examples/nlp/language_modeling/megatron_gpt_ptq.py -def get_calib_data_iter(data="cnn_dailymail", batch_size=64, calib_size=512, max_sequence_length=512): - if data == "wikitext": - dataset = load_dataset("wikitext", "wikitext-103-v1", split="train") - text_column = "text" - elif data == "cnn_dailymail": - dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train") - text_column = "article" - else: - # Assume a local JSON dataset with a column named "text" - dataset = load_dataset("json", data_files=data, split="train") - text_column = "text" - calib_size = max(min(len(dataset), calib_size), batch_size) - for i in range(calib_size // batch_size): - batch = dataset[i * batch_size : (i + 1) * batch_size][text_column] - for j in range(len(batch)): - batch[j] = batch[j][:max_sequence_length] - yield batch - - -# TODO: generalize -def save_tokenizer_config(model_name: str, output_path: str): - from os.path import isfile - tokenizer_path = output_path + '/tokenizer_config.yaml' - if not isfile(tokenizer_path): - with open(tokenizer_path, 'w') as tokenizer_config: - tokenizer_config.write(f"library: huggingface\ntype: {model_name}\nuse_fast: true\nvocab_file: null\nmerge_file: null") - - -# TODO: use llm.generate (#10471) once merged +# TODO: maybe use llm.generate (#10471) def forward_loop(model): tokenizer = model.tokenizer dataloader = get_calib_data_iter() dataloader = [data for data in dataloader] + for batch in tqdm(dataloader): batch = [tokenizer.text_to_ids(text) for text in batch] max_len = max([len(text) for text in batch]) @@ -124,50 +40,18 @@ def forward_loop(model): model(**model_input) -if __name__ == '__main__': - if not HAS_MODELOPT: - print("Modelopt could not be imported") - exit(1) - - args = get_args() - - # TODO: make/extend the Quantizer class from nemo.export.quantizer - # Configure global state - trainer = nl.Trainer( - devices=args.tensor_parallelism_size, - strategy=nl.MegatronStrategy(tensor_model_parallel_size=args.tensor_parallelism_size), - plugins=nl.MegatronMixedPrecision(precision='16-mixed') - ) - fabric = trainer.to_fabric() - trainer.strategy.setup_environment() +def main(): + parser = Quantizer.create_argparser() + params = parser.parse_args(sys.argv[1:]) + params = Quantizer.postprocess_argparse(params) - # Load model with modelopt layer spec - model = nl.io.load_context(args.nemo_checkpoint).model - model.config.transformer_layer_spec = get_gpt_layer_modelopt_spec() - - # TODO: [0] works only for PP = 1. Will be changed when PP support is added. 
- model = fabric.load_model(args.nemo_checkpoint, model=model)[0] + quantizer = Quantizer(params.quantization_config, params.export_config) + model = quantizer.load_quantizable_model(params.nemo_checkpoint, params.tensor_parallelism_size) - # TODO: allow other configs - atq_config = mtq.FP8_DEFAULT_CFG - enable_quant_kv_cache = True - print(f'{"Enable" if enable_quant_kv_cache else "Disable"} KV cache quantization') - atq_config["quant_cfg"]["*output_quantizer"] = { # type: ignore[index] - "num_bits": (4, 3), - "axis": None, - "enable": enable_quant_kv_cache, - } + if params.quant_algo != "no_quant": + model = quantizer.quantize(model, forward_loop) + quantizer.export(model) - model = mtq.quantize(model, atq_config, forward_loop) - mtq.print_quant_summary(model) - export_tensorrt_llm_checkpoint( - model, - args.modelopt_type, - torch.float16, - export_dir=args.output_path, - inference_tensor_parallel=args.tensor_parallelism_size, - inference_pipeline_parallel=1, - use_nfs_workspace=False, - ) - save_tokenizer_config(args.name, args.output_path) +if __name__ == '__main__': + main() diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index bc6f4dd9201e..f4ee03d4c911 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -102,6 +102,7 @@ gpt_forward_step, ) from nemo.collections.llm.t5.model import T5Config, T5Model, t5_data_step, t5_forward_step +from nemo.collections.llm.quantization import Quantizer, get_calib_data_iter __all__ = [ "MockDataModule", diff --git a/nemo/collections/llm/quantization/__init__.py b/nemo/collections/llm/quantization/__init__.py new file mode 100644 index 000000000000..3a8b2989a6a1 --- /dev/null +++ b/nemo/collections/llm/quantization/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .quantizer import Quantizer, get_calib_data_iter diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py new file mode 100644 index 000000000000..828ae64e483d --- /dev/null +++ b/nemo/collections/llm/quantization/quantizer.py @@ -0,0 +1,395 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import shutil +import os + +import torch +import torch.distributed as dist + +from nemo import lightning as nl +from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision +from nemo.utils import logging + +try: + import modelopt.torch.quantization as mtq + from modelopt.torch.export import export_tensorrt_llm_checkpoint + + QUANT_CFG_CHOICES = { + "int8": mtq.INT8_DEFAULT_CFG, + "int8_sq": mtq.INT8_SMOOTHQUANT_CFG, + "fp8": mtq.FP8_DEFAULT_CFG, + "int4_awq": mtq.INT4_AWQ_CFG, + "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG, + "int4": mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, + } + + HAVE_MODELOPT = True + +except (ImportError, ModuleNotFoundError) as e: + HAVE_MODELOPT = False + HAVE_MODELOPT_ERROR = e + + +SUPPORTED_DTYPE = [16, "16", "bf16"] # Default precision for non-quantized layers + +# TODO: delete +class config_dict(dict): + __getattr__ = dict.get + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +def _dict_to_config(config): + if isinstance(config, dict): + return config_dict(config) + return config + + +#### nemo.export.quantize.quantizer.Quantizer class for NeMo 2 +class Quantizer: + """Post-training quantization (PTQ) and TRT-LLM export of Nemo checkpoints. + + PTQ converts selected model layers to low-precision format (e.g., INT4, FP8) for efficient serving. + The process consist of several steps: + + 1. Loading a Nemo model from disk using appropriate parallelism strategy + 2. Calibrating the model to obtain appropriate algorithm-specific scaling factors + 3. Producing output directory or .qnemo tarball with model config (json), + quantized weights (safetensors) and tokenizer config (yaml). + + The output directory (or .qnemo file) produced is intended to be consumed by TensorRT-LLM toolbox + for efficient inference. This can be achieved using Nemo inference containers. + + Currently supported and tested model family is Llama2. Model type needs to be specified in + the quantization command with decoder_type parameter on exporting (see below). Quantizing other + model families is experimental and might not be fully supported. + + Available quantization methods are listed in `QUANT_CFG_CHOICES` dictionary above. + Please consult Model Optimizer documentation https://nvidia.github.io/TensorRT-Model-Optimizer/ for details. + You can also inspect different choices in examples/nlp/language_modeling/conf/megatron_gpt_ptq.yaml + for quantization algorithms and calibration data as well as recommended settings. + + Quantization algorithm can also be conveniently set to 'null' to perform only weights export step + for TensorRT-LLM deployment. This is useful to getting baseline results for a full-precision model. + """ + + def __init__(self, quantization_config, export_config): + """Initialize Quantizer with quantization and export configurations. + + Expected keys in `quantization_config`: + - algorithm: str + - awq_block_size: int (only for awq algorithms) + - sq_alpha: float (only for smooth quant algorithms) + - enable_kv_cache: bool (default: None i.e. 
auto-detect based on algorithm and decoder_type) + + Expected keys in `export_config`: + - dtype: str/int + - decoder_type: str + - inference_tensor_parallel: int + - inference_pipeline_parallel: int + - path: str + """ + if not HAVE_MODELOPT: + raise RuntimeError("nvidia-modelopt is needed to use Quantizer") from HAVE_MODELOPT_ERROR + if not torch.cuda.is_available(): + raise EnvironmentError("GPU is required for the quantization.") + + quantization_config = _dict_to_config(quantization_config) + export_config = _dict_to_config(export_config) + + self.quantization_config = quantization_config + self.export_config = export_config + + # Quantization sanity checks + assert ( + quantization_config.algorithm is None or quantization_config.algorithm in QUANT_CFG_CHOICES + ), f"Unsupported quantization algorithm: {quantization_config.algorithm}" + if quantization_config.algorithm is not None: + quant_cfg = QUANT_CFG_CHOICES[quantization_config.algorithm] + + if "awq" in quantization_config.algorithm: + weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] + if isinstance(weight_quantizer, list): + weight_quantizer = weight_quantizer[0] + weight_quantizer["block_sizes"][-1] = quantization_config.awq_block_size + + # Always turn on FP8 kv cache to save memory footprint. + # For int8_sq, we use int8 kv cache. + # TODO: Investigate why enabling FP8 kv cache will cause accuracy regressions for Nemotron. + enable_quant_kv_cache = quantization_config.get("enable_kv_cache", None) + if enable_quant_kv_cache is None: + enable_quant_kv_cache = ( + "int8" not in quantization_config.algorithm and quantization_config.decoder_type != "gptnext" + ) + logging.info(f'{"Enabled" if enable_quant_kv_cache else "Disabled"} KV cache quantization') + quant_cfg["quant_cfg"]["*output_quantizer"] = { + "num_bits": 8 if quantization_config.algorithm == "int8_sq" else (4, 3), + "axis": None, + "enable": enable_quant_kv_cache, + } + if quantization_config.algorithm == "int8_sq": + logging.info(f"Using int8_sq alpha = {quantization_config.sq_alpha}") + quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": quantization_config.sq_alpha} + + self.quant_cfg = quant_cfg + else: + self.quant_cfg = None + + # Export sanity checks + if export_config is not None: + assert export_config.dtype in SUPPORTED_DTYPE, f"Unsupported export dtype: {export_config.dtype}" + + self.nemo_checkpoint_path = None + + + def load_quantizable_model(self, nemo_checkpoint_path: str, tensor_parallelism_size: int = 1): + from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec + + self.nemo_checkpoint_path = nemo_checkpoint_path + + trainer = nl.Trainer( + devices=tensor_parallelism_size, + strategy=nl.MegatronStrategy( + tensor_model_parallel_size=tensor_parallelism_size, + pipeline_model_parallel_size=1, + ), + plugins=nl.MegatronMixedPrecision(precision='16-mixed'), + ) + fabric = trainer.to_fabric() + trainer.strategy.setup_environment() + + model = nl.io.load_context(nemo_checkpoint_path).model + model.config.transformer_layer_spec = get_gpt_layer_modelopt_spec() + model.config = self.modify_model_config(model.config) + + # TODO: [0] works only for PP=1 + model = fabric.load_model(nemo_checkpoint_path, model=model)[0] + model.freeze() + return model + + + # TODO: what happens with NeMo 2? 
+ # @staticmethod + # def _setup(model): + # """Setup model for quantization.""" + # try: + # model.model.module.language_model.encoder.activations_checkpoint_method = None + # except AttributeError: + # pass + + # if not parallel_state.is_initialized(): + + # def dummy(): + # return + + # if model.trainer.strategy.launcher is not None: + # model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer) + # model.trainer.strategy.setup_environment() + + @staticmethod + def create_argparser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description="NeMo PTQ argument parser", + ) + parser.add_argument("-nc", "--nemo_checkpoint", type=str, help="Source NeMo 2.0 checkpoint") + parser.add_argument("--decoder_type", type=str, help="Decoder type for TensorRT-Model-Optimizer") + parser.add_argument( + "-tps", + "--tensor_parallelism_size", + type=int, + default=1 + ) + parser.add_argument( + '-out', + '--output_path', + type=str, + help='Path for the exported engine' + ) + parser.add_argument( + '--quant_algo', + type=str, + default="no_quant", + choices=["no_quant", "int8", "int8_sq", "fp8", "int4_awq", "w4a8_awq", "int4"], + help='TensorRT-Model-Optimizer quantization algorithm' + ) + parser.add_argument( + '-awq_bs', + '--awq_block_size', + type=int, + default=128, + help='Block size for AWQ quantization algorithms' + ) + parser.add_argument( + '--sq_alpha', + type=float, + default=1.0, + help='Smooth-Quant alpha parameter' + ) + + return parser + + @staticmethod + def postprocess_argparse(args): + args.quantization_config = { + "algorithm": args.quant_algo, + "awq_block_size": args.awq_block_size, + "sq_alpha": args.sq_alpha, + "enable_kv_cache": None, + } + args.export_config = { + "path": args.output_path, + "decoder_type": args.decoder_type, + "inference_tensor_parallel": args.tensor_parallelism_size, + "inference_pipeline_parallel": 1, + "dtype": "bf16", + } + return args + + @staticmethod + def modify_model_config(model_cfg): + """Modify model config for quantization.""" + + # TODO: re-think + # with open_dict(model_cfg): + if model_cfg.sequence_parallel: + logging.warning("Disabling sequence parallelism for quantization...") + model_cfg.sequence_parallel = False + # Only custom ModelOpt spec is supported for Quantization: this custom spec is largely based on local Megatron-LM + # layer definitions to avoid Transformer Engine implementations that are currently not supported. + # This layer spec also requires RoPE fusion to be disabled for tensor view operations in attention + # layer implementation from megatron/core/transformer/dot_product_attention.py to be functional. 
+ model_cfg.name = "modelopt" + model_cfg.apply_rope_fusion = False + + return model_cfg + + # TODO: Add support for NeMo 2 + # @staticmethod + # def _sample_output(model: MegatronGPTModel): + # """Generate sample output for a model instance.""" + # logging.info("Generating sample output for the model...") + + # response = model.generate( + # inputs=[ + # "Born in north-east France, Soyer trained as a", + # "Born in California, Soyer trained as a", + # ], + # length_params={ + # "max_length": 100, + # "min_length": 100, + # }, + # ) + + # logging.info(f'Example NeMo output before export: {response["sentences"]}"') + + def quantize(self, model, forward_loop): + """Quantize the model and calibrate using given forward loop.""" + assert self.quant_cfg is not None, "Quantization algorithm is not set" + + logging.info(f"Quantizing model to {self.quantization_config.algorithm}...") + # self._setup(model) + + model = mtq.quantize(model, self.quant_cfg, forward_loop) + + if self.quantization_config.decoder_type == "gptnext": + # We found squared_relu may have an under-calibration problem. + # Clamp the scaling_factor with a min threshold to avoid under-calibration. + maxbound = 0 + if self.quantization_config.algorithm == "fp8": + maxbound = 448 + elif self.quantization_config.algorithm == "int8_sq": + maxbound = 127 + model = mtq.postprocess_amax( + model, "*input_quantizer", lambda amax: torch.clamp(amax, min=0.01 * maxbound) + ) + + if dist.get_rank() == 0: + mtq.print_quant_summary(model) + + return model + + def export(self, model, nemo_checkpoint_path = None): + """Export model to '.qnemo' format for TensorRT-LLM engine build.""" + assert self.export_config is not None, "Export config is not set" + torch_dtype = torch_dtype_from_precision(self.export_config.dtype) + + # TODO: add with generate + # if self.export_config.get("sample_output", True): + # self._sample_output(model) + + + # TODO: Support compressing to .qnemo + + # TODO: SUPPORT NeMo 2: + # if model.cfg.megatron_amp_O2: + # model.model = unwrap_model(model.model, Float16Module) + + # with export_handler as export_dir: + mtq.print_quant_summary(model) + export_dir = self.export_config.path + export_tensorrt_llm_checkpoint( + model=model, + decoder_type=self.export_config.decoder_type, + dtype=torch_dtype, + export_dir=export_dir, + inference_tensor_parallel=self.export_config.inference_tensor_parallel, + inference_pipeline_parallel=self.export_config.inference_pipeline_parallel, + + # TODO: What happens in NeMo 2? + # use_nfs_workspace=model.trainer.num_nodes > 1, + ) + dist.barrier() # Wait until all ranks complete export_model_config step + logging.info( + f"Exporting quantized weights, model artifacts, and tokenizer config to {self.export_config.path}..." + ) + + + if dist.get_rank() == 0: + self.nemo_checkpoint_path = nemo_checkpoint_path or self.nemo_checkpoint_path + + if self.nemo_checkpoint_path is not None: + tokenizer_src = os.path.join(self.nemo_checkpoint_path, 'nemo_tokenizer') + tokenizer_dst = os.path.join(self.export_config.path, 'tokenizer') + + if os.path.exists(tokenizer_src): + shutil.copytree(tokenizer_src, tokenizer_dst) + else: + print("Could not find copy tokenizer from NeMo checkpoint") + + # TODO Support for NeMo 2? 
+ # save_artifacts(model, export_dir) + + + +def get_calib_data_iter(data="cnn_dailymail", batch_size=64, calib_size=512, max_sequence_length=512): + from datasets import load_dataset + if data == "wikitext": + dataset = load_dataset("wikitext", "wikitext-103-v1", split="train") + text_column = "text" + elif data == "cnn_dailymail": + dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train") + text_column = "article" + else: + # Assume a local JSON dataset with a column named "text" + dataset = load_dataset("json", data_files=data, split="train") + text_column = "text" + calib_size = max(min(len(dataset), calib_size), batch_size) + for i in range(calib_size // batch_size): + batch = dataset[i * batch_size : (i + 1) * batch_size][text_column] + for j in range(len(batch)): + batch[j] = batch[j][:max_sequence_length] + yield batch diff --git a/nemo/export/trt_llm/qnemo/tokenizer_utils.py b/nemo/export/trt_llm/qnemo/tokenizer_utils.py index c3dd5c2befc9..acb47a40cf0a 100644 --- a/nemo/export/trt_llm/qnemo/tokenizer_utils.py +++ b/nemo/export/trt_llm/qnemo/tokenizer_utils.py @@ -23,11 +23,16 @@ # from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer TOKENIZER_CONFIG_FILE = "tokenizer_config.yaml" - +TOKENIZER_DIR = "tokenizer" def get_nmt_tokenizer(nemo_checkpoint_path: str): """Build tokenizer from Nemo tokenizer config.""" + tokenizer_dir = os.path.join(nemo_checkpoint_path, TOKENIZER_DIR) + if os.path.exists(tokenizer_dir): + print(f"Initializing tokenizer from {TOKENIZER_DIR} directory") + return AutoTokenizer.from_pretrained(tokenizer_dir) + print(f"Initializing tokenizer from {TOKENIZER_CONFIG_FILE}") tokenizer_cfg = OmegaConf.load(os.path.join(nemo_checkpoint_path, TOKENIZER_CONFIG_FILE)) From 6676ef6ba4d1ce2273cb674bdf5ec364ed1364a5 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Mon, 30 Sep 2024 10:07:43 -0700 Subject: [PATCH 03/24] refactor Signed-off-by: Piotr Kaminski --- examples/llm/quantization/ptq.py | 91 ++++- nemo/collections/llm/__init__.py | 2 + nemo/collections/llm/gpt/model/__init__.py | 2 + .../collections/llm/quantization/quantizer.py | 331 +++++++----------- 4 files changed, 208 insertions(+), 218 deletions(-) diff --git a/examples/llm/quantization/ptq.py b/examples/llm/quantization/ptq.py index d961c35f8c7c..02acc0484da6 100644 --- a/examples/llm/quantization/ptq.py +++ b/examples/llm/quantization/ptq.py @@ -12,13 +12,91 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import argparse import sys import torch from tqdm import tqdm from nemo.collections.llm.quantization import Quantizer, get_calib_data_iter -# TODO: Support PP + + # TODO: Inference TP/PP != Calibration TP/PP +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description="NeMo PTQ argument parser", + ) + parser.add_argument("-nc", "--nemo_checkpoint", type=str, help="Source NeMo 2.0 checkpoint") + parser.add_argument("--decoder_type", type=str, help="Decoder type for TensorRT-Model-Optimizer") + parser.add_argument( + "-tps", + "--tensor_parallelism_size", + type=int, + default=1 + ) + parser.add_argument( + '-out', + '--output_path', + type=str, + help='Path for the exported engine' + ) + parser.add_argument( + '-algo', + '--algorithm', + type=str, + default="no_quant", + choices=["no_quant", "int8", "int8_sq", "fp8", "int4_awq", "w4a8_awq", "int4"], + help='TensorRT-Model-Optimizer quantization algorithm' + ) + parser.add_argument( + '-awq_bs', + '--awq_block_size', + type=int, + default=128, + help='Block size for AWQ quantization algorithms' + ) + parser.add_argument( + '--sq_alpha', + type=float, + default=1.0, + help='Smooth-Quant alpha parameter' + ) + parser.add_argument( + '--enable_kv_cache', + type=bool, + help='Enables KV-cache quantization' + ) + parser.add_argument( + '-dt', + '--dtype', + default="bf16", + choices=["16", "bf16"], + help='Default precision for non-quantized layers' + ) + + return parser.parse_args(sys.argv[1:]) + + +def get_quantizer_config(args): + if args.output_path is None: + args.output_path = f"./trt_llm_{args.algorithm}_tp{args.tensor_parallelism_size}" + + quantization_config = { + "algorithm": None if args.algorithm == "no_quant" else args.algorithm, + "awq_block_size": args.awq_block_size, + "sq_alpha": args.sq_alpha, + "enable_kv_cache": args.enable_kv_cache, + } + + export_config = { + "path": args.output_path, + "decoder_type": args.decoder_type, + "inference_tensor_parallel": args.tensor_parallelism_size, + "inference_pipeline_parallel": 1, + "dtype": args.dtype, + } + return quantization_config, export_config + # TODO: maybe use llm.generate (#10471) def forward_loop(model): @@ -41,15 +119,12 @@ def forward_loop(model): def main(): - parser = Quantizer.create_argparser() - params = parser.parse_args(sys.argv[1:]) - params = Quantizer.postprocess_argparse(params) + params = get_args() + quantization_config, export_config = get_quantizer_config(params) - quantizer = Quantizer(params.quantization_config, params.export_config) + quantizer = Quantizer(quantization_config, export_config) model = quantizer.load_quantizable_model(params.nemo_checkpoint, params.tensor_parallelism_size) - - if params.quant_algo != "no_quant": - model = quantizer.quantize(model, forward_loop) + model = quantizer.quantize(model, forward_loop) quantizer.export(model) diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index f4ee03d4c911..a1ef79d75e32 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -70,6 +70,7 @@ MaskedTokenLossReduction, MistralConfig7B, MistralModel, + MixtralConfig, MixtralConfig8x3B, MixtralConfig8x7B, MixtralConfig8x22B, @@ -117,6 +118,7 @@ "MaskedTokenLossReduction", "MistralConfig7B", "MistralModel", + "MistralConfig", "MixtralConfig8x3B", "MixtralConfig8x7B", "MixtralConfig8x22B", diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index aa3615b3ddfd..95dc32781b9e 100644 
--- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -58,6 +58,7 @@ MixtralConfig8x3B, MixtralConfig8x7B, MixtralConfig8x22B, + MixtralConfig, MixtralModel, ) from nemo.collections.llm.gpt.model.nemotron import ( @@ -104,6 +105,7 @@ "MixtralConfig8x3B", "MixtralConfig8x7B", "MixtralConfig8x22B", + "MixtralConfig", "MixtralModel", "Starcoder2Config", "Starcoder2Model", diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index 828ae64e483d..d0b93225a002 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -12,16 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse import shutil import os +from typing import Optional import torch import torch.distributed as dist +from datasets import load_dataset from nemo import lightning as nl +from nemo.collections import llm from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision from nemo.utils import logging +from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec try: import modelopt.torch.quantization as mtq @@ -45,20 +48,32 @@ SUPPORTED_DTYPE = [16, "16", "bf16"] # Default precision for non-quantized layers -# TODO: delete -class config_dict(dict): - __getattr__ = dict.get - __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ +def get_modelopt_decoder_type(config: llm.GPTConfig) -> str: + """Infers the modelopt decoder type from GPTConfig class""" + mapping = [ + (llm.Baichuan2Config, "baichuan"), + (llm.ChatGLMConfig, "chatglm"), + (llm.GemmaConfig, "gemma"), + (llm.LlamaConfig, "llama"), + (llm.MistralConfig7B, "llama"), + (llm.MixtralConfig, "llama"), + (llm.NemotronConfig, "gptnext"), + # TODO: (llm.Qwen2Config, ""), + # (llm.StarcoderConfig, ""), + (llm.Starcoder2Config, "gptnext"), + ] -def _dict_to_config(config): - if isinstance(config, dict): - return config_dict(config) - return config + for config_class, decoder_type in mapping: + if isinstance(config, config_class): + return decoder_type + logging.warning("Could not directly infer the decoder type") + # TODO: Add a reasonable behavior for GPTConfig (for instance based on position_embedding_type) + return "llama" -#### nemo.export.quantize.quantizer.Quantizer class for NeMo 2 + +# TODO: Support PP class Quantizer: """Post-training quantization (PTQ) and TRT-LLM export of Nemo checkpoints. @@ -67,37 +82,30 @@ class Quantizer: 1. Loading a Nemo model from disk using appropriate parallelism strategy 2. Calibrating the model to obtain appropriate algorithm-specific scaling factors - 3. Producing output directory or .qnemo tarball with model config (json), - quantized weights (safetensors) and tokenizer config (yaml). - - The output directory (or .qnemo file) produced is intended to be consumed by TensorRT-LLM toolbox - for efficient inference. This can be achieved using Nemo inference containers. + 3. Producing output directory - Currently supported and tested model family is Llama2. Model type needs to be specified in - the quantization command with decoder_type parameter on exporting (see below). Quantizing other - model families is experimental and might not be fully supported. + The output directory produced is intended to be consumed by TensorRT-LLM toolbox + for efficient inference. This can be achieved using NeMo inference containers. 
Available quantization methods are listed in `QUANT_CFG_CHOICES` dictionary above. Please consult Model Optimizer documentation https://nvidia.github.io/TensorRT-Model-Optimizer/ for details. - You can also inspect different choices in examples/nlp/language_modeling/conf/megatron_gpt_ptq.yaml - for quantization algorithms and calibration data as well as recommended settings. - Quantization algorithm can also be conveniently set to 'null' to perform only weights export step + Quantization algorithm can also be conveniently set to None to perform only weights export step for TensorRT-LLM deployment. This is useful to getting baseline results for a full-precision model. """ - def __init__(self, quantization_config, export_config): + def __init__(self, quantization_config: dict, export_config: dict): """Initialize Quantizer with quantization and export configurations. Expected keys in `quantization_config`: - - algorithm: str + - algorithm: (optional) str - awq_block_size: int (only for awq algorithms) - sq_alpha: float (only for smooth quant algorithms) - enable_kv_cache: bool (default: None i.e. auto-detect based on algorithm and decoder_type) Expected keys in `export_config`: - dtype: str/int - - decoder_type: str + - decoder_type: (optional) str - inference_tensor_parallel: int - inference_pipeline_parallel: int - path: str @@ -107,59 +115,22 @@ def __init__(self, quantization_config, export_config): if not torch.cuda.is_available(): raise EnvironmentError("GPU is required for the quantization.") - quantization_config = _dict_to_config(quantization_config) - export_config = _dict_to_config(export_config) - self.quantization_config = quantization_config self.export_config = export_config + self.nemo_checkpoint_path = None - # Quantization sanity checks - assert ( - quantization_config.algorithm is None or quantization_config.algorithm in QUANT_CFG_CHOICES - ), f"Unsupported quantization algorithm: {quantization_config.algorithm}" - if quantization_config.algorithm is not None: - quant_cfg = QUANT_CFG_CHOICES[quantization_config.algorithm] - - if "awq" in quantization_config.algorithm: - weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] - if isinstance(weight_quantizer, list): - weight_quantizer = weight_quantizer[0] - weight_quantizer["block_sizes"][-1] = quantization_config.awq_block_size - - # Always turn on FP8 kv cache to save memory footprint. - # For int8_sq, we use int8 kv cache. - # TODO: Investigate why enabling FP8 kv cache will cause accuracy regressions for Nemotron. 
- enable_quant_kv_cache = quantization_config.get("enable_kv_cache", None) - if enable_quant_kv_cache is None: - enable_quant_kv_cache = ( - "int8" not in quantization_config.algorithm and quantization_config.decoder_type != "gptnext" - ) - logging.info(f'{"Enabled" if enable_quant_kv_cache else "Disabled"} KV cache quantization') - quant_cfg["quant_cfg"]["*output_quantizer"] = { - "num_bits": 8 if quantization_config.algorithm == "int8_sq" else (4, 3), - "axis": None, - "enable": enable_quant_kv_cache, - } - if quantization_config.algorithm == "int8_sq": - logging.info(f"Using int8_sq alpha = {quantization_config.sq_alpha}") - quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": quantization_config.sq_alpha} - - self.quant_cfg = quant_cfg - else: - self.quant_cfg = None + algorithm = quantization_config.get("algorithm", None) + dtype = export_config["dtype"] + # Quantization sanity checks + assert (algorithm is None or algorithm in QUANT_CFG_CHOICES), f"Unsupported quantization algorithm: {algorithm}" # Export sanity checks if export_config is not None: - assert export_config.dtype in SUPPORTED_DTYPE, f"Unsupported export dtype: {export_config.dtype}" + assert dtype in SUPPORTED_DTYPE, f"Unsupported export dtype: {dtype}" - self.nemo_checkpoint_path = None - def load_quantizable_model(self, nemo_checkpoint_path: str, tensor_parallelism_size: int = 1): - from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec - - self.nemo_checkpoint_path = nemo_checkpoint_path - + def load_quantizable_model(self, nemo_checkpoint_path: str, tensor_parallelism_size: int = 1) -> llm.GPTModel: trainer = nl.Trainer( devices=tensor_parallelism_size, strategy=nl.MegatronStrategy( @@ -177,94 +148,29 @@ def load_quantizable_model(self, nemo_checkpoint_path: str, tensor_parallelism_s # TODO: [0] works only for PP=1 model = fabric.load_model(nemo_checkpoint_path, model=model)[0] - model.freeze() - return model - - # TODO: what happens with NeMo 2? 
- # @staticmethod - # def _setup(model): - # """Setup model for quantization.""" - # try: - # model.model.module.language_model.encoder.activations_checkpoint_method = None - # except AttributeError: - # pass - - # if not parallel_state.is_initialized(): - - # def dummy(): - # return + self.nemo_checkpoint_path = nemo_checkpoint_path + return model - # if model.trainer.strategy.launcher is not None: - # model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer) - # model.trainer.strategy.setup_environment() + @staticmethod - def create_argparser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - description="NeMo PTQ argument parser", - ) - parser.add_argument("-nc", "--nemo_checkpoint", type=str, help="Source NeMo 2.0 checkpoint") - parser.add_argument("--decoder_type", type=str, help="Decoder type for TensorRT-Model-Optimizer") - parser.add_argument( - "-tps", - "--tensor_parallelism_size", - type=int, - default=1 - ) - parser.add_argument( - '-out', - '--output_path', - type=str, - help='Path for the exported engine' - ) - parser.add_argument( - '--quant_algo', - type=str, - default="no_quant", - choices=["no_quant", "int8", "int8_sq", "fp8", "int4_awq", "w4a8_awq", "int4"], - help='TensorRT-Model-Optimizer quantization algorithm' - ) - parser.add_argument( - '-awq_bs', - '--awq_block_size', - type=int, - default=128, - help='Block size for AWQ quantization algorithms' - ) - parser.add_argument( - '--sq_alpha', - type=float, - default=1.0, - help='Smooth-Quant alpha parameter' - ) - - return parser + def _setup(model: llm.GPTModel) -> None: + """Setup model for quantization.""" + model.freeze() - @staticmethod - def postprocess_argparse(args): - args.quantization_config = { - "algorithm": args.quant_algo, - "awq_block_size": args.awq_block_size, - "sq_alpha": args.sq_alpha, - "enable_kv_cache": None, - } - args.export_config = { - "path": args.output_path, - "decoder_type": args.decoder_type, - "inference_tensor_parallel": args.tensor_parallelism_size, - "inference_pipeline_parallel": 1, - "dtype": "bf16", - } - return args + # TODO: update for NeMo 2.0 + # try: + # model.model.module.language_model.encoder.activations_checkpoint_method = None + # except AttributeError: + # pass + + @staticmethod - def modify_model_config(model_cfg): + def modify_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: """Modify model config for quantization.""" - # TODO: re-think - # with open_dict(model_cfg): if model_cfg.sequence_parallel: logging.warning("Disabling sequence parallelism for quantization...") model_cfg.sequence_parallel = False @@ -274,45 +180,63 @@ def modify_model_config(model_cfg): # layer implementation from megatron/core/transformer/dot_product_attention.py to be functional. 
model_cfg.name = "modelopt" model_cfg.apply_rope_fusion = False - return model_cfg - # TODO: Add support for NeMo 2 - # @staticmethod - # def _sample_output(model: MegatronGPTModel): - # """Generate sample output for a model instance.""" - # logging.info("Generating sample output for the model...") - - # response = model.generate( - # inputs=[ - # "Born in north-east France, Soyer trained as a", - # "Born in California, Soyer trained as a", - # ], - # length_params={ - # "max_length": 100, - # "min_length": 100, - # }, - # ) - - # logging.info(f'Example NeMo output before export: {response["sentences"]}"') - - def quantize(self, model, forward_loop): + + + def _get_decoder_type(self, config: llm.GPTConfig): + return self.export_config.get("decoder_type", None) or get_modelopt_decoder_type(config) + + + + def quantize(self, model: llm.GPTConfig, forward_loop): """Quantize the model and calibrate using given forward loop.""" - assert self.quant_cfg is not None, "Quantization algorithm is not set" + algorithm = self.quantization_config["algorithm"] + if algorithm is None: + logging.info("Quantization algorithm set to None, returning the non-quantized model") + return model - logging.info(f"Quantizing model to {self.quantization_config.algorithm}...") - # self._setup(model) + logging.info(f"Quantizing model to {algorithm}...") + + self._setup(model) + decoder_type = self._get_decoder_type(model.config) + quant_cfg = QUANT_CFG_CHOICES[algorithm] + if "awq" in algorithm: + weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] + if isinstance(weight_quantizer, list): + weight_quantizer = weight_quantizer[0] + weight_quantizer["block_sizes"][-1] = self.quantization_config["awq_block_size"] + + # Always turn on FP8 kv cache to save memory footprint. + # For int8_sq, we use int8 kv cache. + # TODO: Investigate why enabling FP8 kv cache will cause accuracy regressions for Nemotron. + enable_quant_kv_cache = self.quantization_config.get("enable_kv_cache", None) + if enable_quant_kv_cache is None: + enable_quant_kv_cache = ( + "int8" not in algorithm and decoder_type != "gptnext" + ) + logging.info(f'{"Enabled" if enable_quant_kv_cache else "Disabled"} KV cache quantization') + quant_cfg["quant_cfg"]["*output_quantizer"] = { + "num_bits": 8 if algorithm == "int8_sq" else (4, 3), + "axis": None, + "enable": enable_quant_kv_cache, + } + if algorithm == "int8_sq": + sq_alpha = self.quantization_config["sq_alpha"] + logging.info(f"Using int8_sq alpha = {sq_alpha}") + quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": sq_alpha} - model = mtq.quantize(model, self.quant_cfg, forward_loop) - if self.quantization_config.decoder_type == "gptnext": + model = mtq.quantize(model, quant_cfg, forward_loop) + + if decoder_type == "gptnext": # We found squared_relu may have an under-calibration problem. # Clamp the scaling_factor with a min threshold to avoid under-calibration. 
- maxbound = 0 - if self.quantization_config.algorithm == "fp8": - maxbound = 448 - elif self.quantization_config.algorithm == "int8_sq": - maxbound = 127 + match algorithm: + case "fp8": maxbound = 448 + case "int8_sq": maxbound = 127 + case _: maxbound = 0 + model = mtq.postprocess_amax( model, "*input_quantizer", lambda amax: torch.clamp(amax, min=0.01 * maxbound) ) @@ -322,61 +246,48 @@ def quantize(self, model, forward_loop): return model - def export(self, model, nemo_checkpoint_path = None): + + def export(self, model: llm.GPTModel, nemo_checkpoint_path: Optional[str] = None) -> None: """Export model to '.qnemo' format for TensorRT-LLM engine build.""" assert self.export_config is not None, "Export config is not set" - torch_dtype = torch_dtype_from_precision(self.export_config.dtype) - - # TODO: add with generate - # if self.export_config.get("sample_output", True): - # self._sample_output(model) - - - # TODO: Support compressing to .qnemo - - # TODO: SUPPORT NeMo 2: + torch_dtype = torch_dtype_from_precision(self.export_config["dtype"]) + # TODO: Add sample generate + # TODO: Support NeMo 2: # if model.cfg.megatron_amp_O2: # model.model = unwrap_model(model.model, Float16Module) - # with export_handler as export_dir: - mtq.print_quant_summary(model) - export_dir = self.export_config.path + export_dir = self.export_config["path"] export_tensorrt_llm_checkpoint( model=model, - decoder_type=self.export_config.decoder_type, + decoder_type=self._get_decoder_type(model.config), dtype=torch_dtype, export_dir=export_dir, - inference_tensor_parallel=self.export_config.inference_tensor_parallel, - inference_pipeline_parallel=self.export_config.inference_pipeline_parallel, - - # TODO: What happens in NeMo 2? - # use_nfs_workspace=model.trainer.num_nodes > 1, + inference_tensor_parallel=self.export_config["inference_tensor_parallel"], + inference_pipeline_parallel=self.export_config["inference_pipeline_parallel"], + use_nfs_workspace=model.trainer._fabric.__io__.num_nodes > 1, # TODO: check it ) + dist.barrier() # Wait until all ranks complete export_model_config step logging.info( - f"Exporting quantized weights, model artifacts, and tokenizer config to {self.export_config.path}..." + f"Exporting quantized weights, model artifacts, and tokenizer config to {export_dir}..." ) - if dist.get_rank() == 0: self.nemo_checkpoint_path = nemo_checkpoint_path or self.nemo_checkpoint_path - + if self.nemo_checkpoint_path is not None: tokenizer_src = os.path.join(self.nemo_checkpoint_path, 'nemo_tokenizer') - tokenizer_dst = os.path.join(self.export_config.path, 'tokenizer') + tokenizer_dst = os.path.join(export_dir, 'tokenizer') - if os.path.exists(tokenizer_src): + if os.path.exists(tokenizer_src) and not os.path.exists(tokenizer_dst): shutil.copytree(tokenizer_src, tokenizer_dst) else: - print("Could not find copy tokenizer from NeMo checkpoint") - - # TODO Support for NeMo 2? 
- # save_artifacts(model, export_dir) + logging.info("Could not copy tokenizer from NeMo checkpoint") -def get_calib_data_iter(data="cnn_dailymail", batch_size=64, calib_size=512, max_sequence_length=512): - from datasets import load_dataset +def get_calib_data_iter(data: str = "cnn_dailymail", batch_size: int = 64, calib_size: int = 512, max_sequence_length: int = 512): + """Creates a sample data iterator for calibration""" if data == "wikitext": dataset = load_dataset("wikitext", "wikitext-103-v1", split="train") text_column = "text" From 9df6ac080dd35d8212de2400eb173dd567646fd1 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Fri, 4 Oct 2024 07:26:58 -0700 Subject: [PATCH 04/24] Call quantize on an unwrapped mcore model Signed-off-by: Piotr Kaminski --- examples/llm/quantization/ptq.py | 105 ++++++++++++++---- .../collections/llm/quantization/quantizer.py | 68 ++++++++---- 2 files changed, 129 insertions(+), 44 deletions(-) diff --git a/examples/llm/quantization/ptq.py b/examples/llm/quantization/ptq.py index 02acc0484da6..8ed4f4e7ee1a 100644 --- a/examples/llm/quantization/ptq.py +++ b/examples/llm/quantization/ptq.py @@ -20,7 +20,6 @@ from nemo.collections.llm.quantization import Quantizer, get_calib_data_iter -# TODO: Inference TP/PP != Calibration TP/PP def get_args(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -28,12 +27,30 @@ def get_args(): ) parser.add_argument("-nc", "--nemo_checkpoint", type=str, help="Source NeMo 2.0 checkpoint") parser.add_argument("--decoder_type", type=str, help="Decoder type for TensorRT-Model-Optimizer") + parser.add_argument( + "-ctp", + "--calib_tp", + type=int, + default=1 + ) + parser.add_argument( + "-cpp", + "--calib_pp", + type=int, + default=1 + ) parser.add_argument( "-tps", "--tensor_parallelism_size", type=int, default=1 ) + parser.add_argument( + "-pps", + "--pipeline_parallelism_size", + type=int, + default=1 + ) parser.add_argument( '-out', '--output_path', @@ -58,7 +75,7 @@ def get_args(): parser.add_argument( '--sq_alpha', type=float, - default=1.0, + default=0.5, help='Smooth-Quant alpha parameter' ) parser.add_argument( @@ -73,13 +90,42 @@ def get_args(): choices=["16", "bf16"], help='Default precision for non-quantized layers' ) + parser.add_argument( + '-bs', + '--batch_size', + default=64, + type=int, + help='Calibration batch size' + ) + parser.add_argument( + '-sl', + '--seq_len', + default=128, + type=int, + help='Length of the tokenized text' + ) + parser.add_argument( + '-calib_size', + '--calibration_dataset_size', + default=512, + type=int, + help='Size of calibration dataset' + ) + parser.add_argument( + '-calib_ds', + '--calibration_dataset', + default="cnn_dailymail", + choices=["wikitext", "cnn_dailymail"], + type=str, + help='Calibration dataset to be used' + ) return parser.parse_args(sys.argv[1:]) def get_quantizer_config(args): if args.output_path is None: - args.output_path = f"./trt_llm_{args.algorithm}_tp{args.tensor_parallelism_size}" + args.output_path = f"./qnemo_{args.algorithm}_tp{args.tensor_parallelism_size}_pp{args.pipeline_parallelism_size}" quantization_config = { "algorithm": None if args.algorithm == "no_quant" else args.algorithm, @@ -87,43 +133,54 @@ def get_quantizer_config(args): "sq_alpha": args.sq_alpha, "enable_kv_cache": args.enable_kv_cache, } - export_config = { "path": args.output_path, "decoder_type": args.decoder_type, "inference_tensor_parallel": args.tensor_parallelism_size, - "inference_pipeline_parallel": 1, + 
"inference_pipeline_parallel": args.pipeline_parallelism_size, "dtype": args.dtype, } + return quantization_config, export_config -# TODO: maybe use llm.generate (#10471) -def forward_loop(model): - tokenizer = model.tokenizer - dataloader = get_calib_data_iter() - dataloader = [data for data in dataloader] +def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration_size): + def _iterator(): + CHARACTERS_PER_TOKEN = 4 - for batch in tqdm(dataloader): - batch = [tokenizer.text_to_ids(text) for text in batch] - max_len = max([len(text) for text in batch]) - batch = [ids + (max_len - len(ids)) * [tokenizer.eos] for ids in batch] - position_ids = torch.arange(max_len, device=model.device).expand((len(batch), max_len)) - batch = torch.tensor(batch, device=model.device) - model_input = { - "input_ids": batch, - "position_ids": position_ids, - "attention_mask": None, - } - model(**model_input) + dataloader = get_calib_data_iter(data=dataset, max_sequence_length=CHARACTERS_PER_TOKEN*seq_len, batch_size=batch_size, calib_size=calibration_size) + for batch in dataloader: + batch = [model.tokenizer.text_to_ids(text)[:seq_len] for text in batch] + batch = [ids + (seq_len - len(ids)) * [model.tokenizer.eos] for ids in batch] + yield torch.tensor(batch, device=model.device) + + def _iterator_getter(): + dataloader = _iterator() + dataloader = [data for data in dataloader] + return iter(tqdm(dataloader)) + return _iterator_getter + def main(): params = get_args() quantization_config, export_config = get_quantizer_config(params) - quantizer = Quantizer(quantization_config, export_config) - model = quantizer.load_quantizable_model(params.nemo_checkpoint, params.tensor_parallelism_size) + model = quantizer.load_quantizable_model(params.nemo_checkpoint, params.calib_tp, params.calib_pp) + + get_dataloader = create_data_iterator_getter(model, + dataset=params.calibration_dataset, + seq_len=params.seq_len, + batch_size=params.batch_size, + calibration_size=params.calibration_dataset_size) + + forward_loop = quantizer.create_megatron_forward_loop( + get_dataloader, + num_batches=params.calibration_dataset_size // params.batch_size, + seq_length=params.seq_len, + micro_batch_size=params.batch_size, + ) + model = quantizer.quantize(model, forward_loop) quantizer.export(model) diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index d0b93225a002..d9dc9652a2ba 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -59,8 +59,8 @@ def get_modelopt_decoder_type(config: llm.GPTConfig) -> str: (llm.MistralConfig7B, "llama"), (llm.MixtralConfig, "llama"), (llm.NemotronConfig, "gptnext"), - # TODO: (llm.Qwen2Config, ""), - # (llm.StarcoderConfig, ""), + (llm.Qwen2Config, "qwen"), + # TODO: (llm.StarcoderConfig, ""), (llm.Starcoder2Config, "gptnext"), ] @@ -130,12 +130,12 @@ def __init__(self, quantization_config: dict, export_config: dict): - def load_quantizable_model(self, nemo_checkpoint_path: str, tensor_parallelism_size: int = 1) -> llm.GPTModel: + def load_quantizable_model(self, nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1) -> llm.GPTModel: trainer = nl.Trainer( - devices=tensor_parallelism_size, + devices=calib_tp * calib_pp, strategy=nl.MegatronStrategy( - tensor_model_parallel_size=tensor_parallelism_size, - pipeline_model_parallel_size=1, + tensor_model_parallel_size=calib_tp, + pipeline_model_parallel_size=calib_pp, ), 
plugins=nl.MegatronMixedPrecision(precision='16-mixed'), ) @@ -146,8 +146,7 @@ def load_quantizable_model(self, nemo_checkpoint_path: str, tensor_parallelism_s model.config.transformer_layer_spec = get_gpt_layer_modelopt_spec() model.config = self.modify_model_config(model.config) - # TODO: [0] works only for PP=1 - model = fabric.load_model(nemo_checkpoint_path, model=model)[0] + model = fabric.load_model(nemo_checkpoint_path, model=model) self.nemo_checkpoint_path = nemo_checkpoint_path return model @@ -157,6 +156,7 @@ def load_quantizable_model(self, nemo_checkpoint_path: str, tensor_parallelism_s @staticmethod def _setup(model: llm.GPTModel) -> None: """Setup model for quantization.""" + model.config.vocab_size = model.tokenizer.vocab_size model.freeze() # TODO: update for NeMo 2.0 @@ -183,22 +183,22 @@ def modify_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: return model_cfg - def _get_decoder_type(self, config: llm.GPTConfig): return self.export_config.get("decoder_type", None) or get_modelopt_decoder_type(config) - - def quantize(self, model: llm.GPTConfig, forward_loop): + def quantize(self, wrapped_model: llm.GPTConfig, forward_loop): """Quantize the model and calibrate using given forward loop.""" algorithm = self.quantization_config["algorithm"] if algorithm is None: logging.info("Quantization algorithm set to None, returning the non-quantized model") - return model + return wrapped_model.module logging.info(f"Quantizing model to {algorithm}...") - self._setup(model) + self._setup(wrapped_model) + model = wrapped_model.module.module + model.config.pipeline_dtype = wrapped_model.pipeline.dtype decoder_type = self._get_decoder_type(model.config) quant_cfg = QUANT_CFG_CHOICES[algorithm] if "awq" in algorithm: @@ -244,37 +244,65 @@ def quantize(self, model: llm.GPTConfig, forward_loop): if dist.get_rank() == 0: mtq.print_quant_summary(model) - return model + return wrapped_model + + + def create_megatron_forward_loop(self, get_dataloader, num_batches, seq_length=None, micro_batch_size=None, decoder_seq_length=None): + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + forward_backward_func = get_forward_backward_func() + + def forward_step_func(data_iterator, model): + data = next(data_iterator) + batch_len, seq_len = data.shape + position_ids = torch.arange(seq_len, device=data.device).expand((batch_len, seq_len)) + output_tensor = model(data, position_ids, None) + + def _mock_loss_function(tensor): + return 0, {} + return output_tensor, _mock_loss_function + + + def loop(model): + dataloader = get_dataloader() + forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=dataloader, + model=model, + num_microbatches=num_batches, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + decoder_seq_length=decoder_seq_length, + forward_only=True) + + return loop def export(self, model: llm.GPTModel, nemo_checkpoint_path: Optional[str] = None) -> None: - """Export model to '.qnemo' format for TensorRT-LLM engine build.""" assert self.export_config is not None, "Export config is not set" torch_dtype = torch_dtype_from_precision(self.export_config["dtype"]) # TODO: Add sample generate # TODO: Support NeMo 2: # if model.cfg.megatron_amp_O2: # model.model = unwrap_model(model.model, Float16Module) - export_dir = self.export_config["path"] + use_nfs_workspace = (model.trainer._fabric.__io__.num_nodes > 1) or (model.config.pipeline_model_parallel_size > 1) export_tensorrt_llm_checkpoint( - model=model, + 
model=model.module.module, decoder_type=self._get_decoder_type(model.config), dtype=torch_dtype, export_dir=export_dir, inference_tensor_parallel=self.export_config["inference_tensor_parallel"], inference_pipeline_parallel=self.export_config["inference_pipeline_parallel"], - use_nfs_workspace=model.trainer._fabric.__io__.num_nodes > 1, # TODO: check it + use_nfs_workspace=use_nfs_workspace, ) dist.barrier() # Wait until all ranks complete export_model_config step logging.info( - f"Exporting quantized weights, model artifacts, and tokenizer config to {export_dir}..." + f"Export succeeded, model has been exported to {export_dir}. Saving tokenizer if possible..." ) if dist.get_rank() == 0: self.nemo_checkpoint_path = nemo_checkpoint_path or self.nemo_checkpoint_path - if self.nemo_checkpoint_path is not None: tokenizer_src = os.path.join(self.nemo_checkpoint_path, 'nemo_tokenizer') tokenizer_dst = os.path.join(export_dir, 'tokenizer') From 30f5a978251e50a163f1d8bc760c6eb0ed1d96d8 Mon Sep 17 00:00:00 2001 From: Laplasjan107 Date: Fri, 4 Oct 2024 14:27:56 +0000 Subject: [PATCH 05/24] Apply isort and black reformatting Signed-off-by: Laplasjan107 --- examples/llm/quantization/ptq.py | 117 +++++------------- nemo/collections/llm/__init__.py | 2 +- nemo/collections/llm/gpt/model/__init__.py | 2 +- .../collections/llm/quantization/quantizer.py | 75 ++++++----- nemo/export/trt_llm/qnemo/tokenizer_utils.py | 1 + 5 files changed, 72 insertions(+), 125 deletions(-) diff --git a/examples/llm/quantization/ptq.py b/examples/llm/quantization/ptq.py index 8ed4f4e7ee1a..6e1d9d54e00b 100644 --- a/examples/llm/quantization/ptq.py +++ b/examples/llm/quantization/ptq.py @@ -27,89 +27,31 @@ def get_args(): ) parser.add_argument("-nc", "--nemo_checkpoint", type=str, help="Source NeMo 2.0 checkpoint") parser.add_argument("--decoder_type", type=str, help="Decoder type for TensorRT-Model-Optimizer") - parser.add_argument( - "-ctp", - "--calib_tp", - type=int, - default=1 - ) - parser.add_argument( - "-cpp", - "--calib_pp", - type=int, - default=1 - ) - parser.add_argument( - "-tps", - "--tensor_parallelism_size", - type=int, - default=1 - ) - parser.add_argument( - "-pps", - "--pipeline_parallelism_size", - type=int, - default=1 - ) - parser.add_argument( - '-out', - '--output_path', - type=str, - help='Path for the exported engine' - ) + parser.add_argument("-ctp", "--calib_tp", type=int, default=1) + parser.add_argument("-cpp", "--calib_pp", type=int, default=1) + parser.add_argument("-tps", "--tensor_parallelism_size", type=int, default=1) + parser.add_argument("-pps", "--pipeline_parallelism_size", type=int, default=1) + parser.add_argument('-out', '--output_path', type=str, help='Path for the exported engine') parser.add_argument( '-algo', '--algorithm', type=str, default="no_quant", choices=["no_quant", "int8", "int8_sq", "fp8", "int4_awq", "w4a8_awq", "int4"], - help='TensorRT-Model-Optimizer quantization algorithm' - ) - parser.add_argument( - '-awq_bs', - '--awq_block_size', - type=int, - default=128, - help='Block size for AWQ quantization algorithms' - ) - parser.add_argument( - '--sq_alpha', - type=float, - default=0.5, - help='Smooth-Quant alpha parameter' - ) - parser.add_argument( - '--enable_kv_cache', - type=bool, - help='Enables KV-cache quantization' - ) - parser.add_argument( - '-dt', - '--dtype', - default="bf16", - choices=["16", "bf16"], - help='Default precision for non-quantized layers' + help='TensorRT-Model-Optimizer quantization algorithm', ) parser.add_argument( - '-bs', - '--batch_size', 
- default=64, - type=int, - help='Calibration batch size' + '-awq_bs', '--awq_block_size', type=int, default=128, help='Block size for AWQ quantization algorithms' ) + parser.add_argument('--sq_alpha', type=float, default=0.5, help='Smooth-Quant alpha parameter') + parser.add_argument('--enable_kv_cache', type=bool, help='Enables KV-cache quantization') parser.add_argument( - '-sl', - '--seq_len', - default=128, - type=int, - help='Length of the tokenized text' + '-dt', '--dtype', default="bf16", choices=["16", "bf16"], help='Default precision for non-quantized layers' ) + parser.add_argument('-bs', '--batch_size', default=64, type=int, help='Calibration batch size') + parser.add_argument('-sl', '--seq_len', default=128, type=int, help='Length of the tokenized text') parser.add_argument( - '-calib_size', - '--calibration_dataset_size', - default=512, - type=int, - help='Size of calibration dataset' + '-calib_size', '--calibration_dataset_size', default=512, type=int, help='Size of calibration dataset' ) parser.add_argument( '-calib_ds', @@ -117,15 +59,17 @@ def get_args(): default="cnn_dailymail", choices=["wikitext", "cnn_dailymail"], type=str, - help='Calibration dataset to be used' + help='Calibration dataset to be used', ) - + return parser.parse_args(sys.argv[1:]) def get_quantizer_config(args): if args.output_path is None: - args.output_path = f"./qnemo_{args.algorithm}_tp{args.tensor_parallelism_size}_pp{args.pipeline_parallelism_size}" + args.output_path = ( + f"./qnemo_{args.algorithm}_tp{args.tensor_parallelism_size}_pp{args.pipeline_parallelism_size}" + ) quantization_config = { "algorithm": None if args.algorithm == "no_quant" else args.algorithm, @@ -148,31 +92,38 @@ def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration def _iterator(): CHARACTERS_PER_TOKEN = 4 - dataloader = get_calib_data_iter(data=dataset, max_sequence_length=CHARACTERS_PER_TOKEN*seq_len, batch_size=batch_size, calib_size=calibration_size) + dataloader = get_calib_data_iter( + data=dataset, + max_sequence_length=CHARACTERS_PER_TOKEN * seq_len, + batch_size=batch_size, + calib_size=calibration_size, + ) for batch in dataloader: batch = [model.tokenizer.text_to_ids(text)[:seq_len] for text in batch] batch = [ids + (seq_len - len(ids)) * [model.tokenizer.eos] for ids in batch] yield torch.tensor(batch, device=model.device) - + def _iterator_getter(): dataloader = _iterator() dataloader = [data for data in dataloader] return iter(tqdm(dataloader)) return _iterator_getter - + def main(): params = get_args() quantization_config, export_config = get_quantizer_config(params) quantizer = Quantizer(quantization_config, export_config) model = quantizer.load_quantizable_model(params.nemo_checkpoint, params.calib_tp, params.calib_pp) - - get_dataloader = create_data_iterator_getter(model, - dataset=params.calibration_dataset, - seq_len=params.seq_len, - batch_size=params.batch_size, - calibration_size=params.calibration_dataset_size) + + get_dataloader = create_data_iterator_getter( + model, + dataset=params.calibration_dataset, + seq_len=params.seq_len, + batch_size=params.batch_size, + calibration_size=params.calibration_dataset_size, + ) forward_loop = quantizer.create_megatron_forward_loop( get_dataloader, diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index a1ef79d75e32..4975263a23d5 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -102,8 +102,8 @@ gpt_data_step, gpt_forward_step, ) -from 
nemo.collections.llm.t5.model import T5Config, T5Model, t5_data_step, t5_forward_step from nemo.collections.llm.quantization import Quantizer, get_calib_data_iter +from nemo.collections.llm.t5.model import T5Config, T5Model, t5_data_step, t5_forward_step __all__ = [ "MockDataModule", diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index 95dc32781b9e..a60ad6b255a1 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -55,10 +55,10 @@ ) from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel from nemo.collections.llm.gpt.model.mixtral import ( + MixtralConfig, MixtralConfig8x3B, MixtralConfig8x7B, MixtralConfig8x22B, - MixtralConfig, MixtralModel, ) from nemo.collections.llm.gpt.model.nemotron import ( diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index d9dc9652a2ba..e53d5af9ccfa 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import shutil import os +import shutil from typing import Optional import torch @@ -22,9 +22,9 @@ from nemo import lightning as nl from nemo.collections import llm +from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision from nemo.utils import logging -from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec try: import modelopt.torch.quantization as mtq @@ -52,16 +52,16 @@ def get_modelopt_decoder_type(config: llm.GPTConfig) -> str: """Infers the modelopt decoder type from GPTConfig class""" mapping = [ - (llm.Baichuan2Config, "baichuan"), - (llm.ChatGLMConfig, "chatglm"), - (llm.GemmaConfig, "gemma"), - (llm.LlamaConfig, "llama"), - (llm.MistralConfig7B, "llama"), - (llm.MixtralConfig, "llama"), - (llm.NemotronConfig, "gptnext"), - (llm.Qwen2Config, "qwen"), + (llm.Baichuan2Config, "baichuan"), + (llm.ChatGLMConfig, "chatglm"), + (llm.GemmaConfig, "gemma"), + (llm.LlamaConfig, "llama"), + (llm.MistralConfig7B, "llama"), + (llm.MixtralConfig, "llama"), + (llm.NemotronConfig, "gptnext"), + (llm.Qwen2Config, "qwen"), # TODO: (llm.StarcoderConfig, ""), - (llm.Starcoder2Config, "gptnext"), + (llm.Starcoder2Config, "gptnext"), ] for config_class, decoder_type in mapping: @@ -123,20 +123,18 @@ def __init__(self, quantization_config: dict, export_config: dict): dtype = export_config["dtype"] # Quantization sanity checks - assert (algorithm is None or algorithm in QUANT_CFG_CHOICES), f"Unsupported quantization algorithm: {algorithm}" + assert algorithm is None or algorithm in QUANT_CFG_CHOICES, f"Unsupported quantization algorithm: {algorithm}" # Export sanity checks if export_config is not None: assert dtype in SUPPORTED_DTYPE, f"Unsupported export dtype: {dtype}" - - def load_quantizable_model(self, nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1) -> llm.GPTModel: trainer = nl.Trainer( devices=calib_tp * calib_pp, strategy=nl.MegatronStrategy( tensor_model_parallel_size=calib_tp, pipeline_model_parallel_size=calib_pp, - ), + ), plugins=nl.MegatronMixedPrecision(precision='16-mixed'), ) fabric = trainer.to_fabric() @@ -151,8 +149,6 @@ def load_quantizable_model(self, 
nemo_checkpoint_path: str, calib_tp: int = 1, c self.nemo_checkpoint_path = nemo_checkpoint_path return model - - @staticmethod def _setup(model: llm.GPTModel) -> None: """Setup model for quantization.""" @@ -165,8 +161,6 @@ def _setup(model: llm.GPTModel) -> None: # except AttributeError: # pass - - @staticmethod def modify_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: """Modify model config for quantization.""" @@ -182,11 +176,9 @@ def modify_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: model_cfg.apply_rope_fusion = False return model_cfg - def _get_decoder_type(self, config: llm.GPTConfig): return self.export_config.get("decoder_type", None) or get_modelopt_decoder_type(config) - def quantize(self, wrapped_model: llm.GPTConfig, forward_loop): """Quantize the model and calibrate using given forward loop.""" algorithm = self.quantization_config["algorithm"] @@ -195,7 +187,7 @@ def quantize(self, wrapped_model: llm.GPTConfig, forward_loop): return wrapped_model.module logging.info(f"Quantizing model to {algorithm}...") - + self._setup(wrapped_model) model = wrapped_model.module.module model.config.pipeline_dtype = wrapped_model.pipeline.dtype @@ -212,9 +204,7 @@ def quantize(self, wrapped_model: llm.GPTConfig, forward_loop): # TODO: Investigate why enabling FP8 kv cache will cause accuracy regressions for Nemotron. enable_quant_kv_cache = self.quantization_config.get("enable_kv_cache", None) if enable_quant_kv_cache is None: - enable_quant_kv_cache = ( - "int8" not in algorithm and decoder_type != "gptnext" - ) + enable_quant_kv_cache = "int8" not in algorithm and decoder_type != "gptnext" logging.info(f'{"Enabled" if enable_quant_kv_cache else "Disabled"} KV cache quantization') quant_cfg["quant_cfg"]["*output_quantizer"] = { "num_bits": 8 if algorithm == "int8_sq" else (4, 3), @@ -226,16 +216,18 @@ def quantize(self, wrapped_model: llm.GPTConfig, forward_loop): logging.info(f"Using int8_sq alpha = {sq_alpha}") quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": sq_alpha} - model = mtq.quantize(model, quant_cfg, forward_loop) if decoder_type == "gptnext": # We found squared_relu may have an under-calibration problem. # Clamp the scaling_factor with a min threshold to avoid under-calibration. 
match algorithm: - case "fp8": maxbound = 448 - case "int8_sq": maxbound = 127 - case _: maxbound = 0 + case "fp8": + maxbound = 448 + case "int8_sq": + maxbound = 127 + case _: + maxbound = 0 model = mtq.postprocess_amax( model, "*input_quantizer", lambda amax: torch.clamp(amax, min=0.01 * maxbound) @@ -246,9 +238,11 @@ def quantize(self, wrapped_model: llm.GPTConfig, forward_loop): return wrapped_model - - def create_megatron_forward_loop(self, get_dataloader, num_batches, seq_length=None, micro_batch_size=None, decoder_seq_length=None): + def create_megatron_forward_loop( + self, get_dataloader, num_batches, seq_length=None, micro_batch_size=None, decoder_seq_length=None + ): from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + forward_backward_func = get_forward_backward_func() def forward_step_func(data_iterator, model): @@ -259,8 +253,8 @@ def forward_step_func(data_iterator, model): def _mock_loss_function(tensor): return 0, {} - return output_tensor, _mock_loss_function + return output_tensor, _mock_loss_function def loop(model): dataloader = get_dataloader() @@ -272,11 +266,11 @@ def loop(model): seq_length=seq_length, micro_batch_size=micro_batch_size, decoder_seq_length=decoder_seq_length, - forward_only=True) + forward_only=True, + ) return loop - def export(self, model: llm.GPTModel, nemo_checkpoint_path: Optional[str] = None) -> None: assert self.export_config is not None, "Export config is not set" torch_dtype = torch_dtype_from_precision(self.export_config["dtype"]) @@ -285,7 +279,9 @@ def export(self, model: llm.GPTModel, nemo_checkpoint_path: Optional[str] = None # if model.cfg.megatron_amp_O2: # model.model = unwrap_model(model.model, Float16Module) export_dir = self.export_config["path"] - use_nfs_workspace = (model.trainer._fabric.__io__.num_nodes > 1) or (model.config.pipeline_model_parallel_size > 1) + use_nfs_workspace = (model.trainer._fabric.__io__.num_nodes > 1) or ( + model.config.pipeline_model_parallel_size > 1 + ) export_tensorrt_llm_checkpoint( model=model.module.module, decoder_type=self._get_decoder_type(model.config), @@ -297,9 +293,7 @@ def export(self, model: llm.GPTModel, nemo_checkpoint_path: Optional[str] = None ) dist.barrier() # Wait until all ranks complete export_model_config step - logging.info( - f"Export succeeded, model has been exported to {export_dir}. Saving tokenizer if possible..." - ) + logging.info(f"Export succeeded, model has been exported to {export_dir}. 
Saving tokenizer if possible...") if dist.get_rank() == 0: self.nemo_checkpoint_path = nemo_checkpoint_path or self.nemo_checkpoint_path @@ -313,8 +307,9 @@ def export(self, model: llm.GPTModel, nemo_checkpoint_path: Optional[str] = None logging.info("Could not copy tokenizer from NeMo checkpoint") - -def get_calib_data_iter(data: str = "cnn_dailymail", batch_size: int = 64, calib_size: int = 512, max_sequence_length: int = 512): +def get_calib_data_iter( + data: str = "cnn_dailymail", batch_size: int = 64, calib_size: int = 512, max_sequence_length: int = 512 +): """Creates a sample data iterator for calibration""" if data == "wikitext": dataset = load_dataset("wikitext", "wikitext-103-v1", split="train") diff --git a/nemo/export/trt_llm/qnemo/tokenizer_utils.py b/nemo/export/trt_llm/qnemo/tokenizer_utils.py index acb47a40cf0a..36efa9259f9d 100644 --- a/nemo/export/trt_llm/qnemo/tokenizer_utils.py +++ b/nemo/export/trt_llm/qnemo/tokenizer_utils.py @@ -25,6 +25,7 @@ TOKENIZER_CONFIG_FILE = "tokenizer_config.yaml" TOKENIZER_DIR = "tokenizer" + def get_nmt_tokenizer(nemo_checkpoint_path: str): """Build tokenizer from Nemo tokenizer config.""" From d228ef7168b959c042380ae8fd0d15c6ab571bb5 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Mon, 14 Oct 2024 10:21:01 -0700 Subject: [PATCH 06/24] Add tests, adjust unwrapping Signed-off-by: Piotr Kaminski --- .github/workflows/cicd-main.yml | 15 ++++++ .../collections/llm/quantization/quantizer.py | 46 +++++++++++-------- nemo/lightning/fabric/fabric.py | 3 +- nemo/lightning/fabric/plugins.py | 4 ++ nemo/lightning/io/api.py | 7 +-- tests/collections/llm/test_hf_import.py | 27 +++++++++++ 6 files changed, 79 insertions(+), 23 deletions(-) create mode 100644 tests/collections/llm/test_hf_import.py diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 800d91acb7ed..1c3e5c316248 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -5450,6 +5450,21 @@ jobs: SCRIPT: | bash tests/collections/llm/bitexact/mixtral/run.sh + L2_NeMo_2_PTQ_Llama2_FP8: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_PTQ_Llama2_FP8') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + python tests/collections/llm/test_hf_import.py --hf_model /home/TestData/nlp/megatron_llama/llama-ci-hf --output_path /tmp/nemo2_ckpt + + python examples/llm/quantization/ptq.py -nc /tmp/nemo2_ckpt -algo fp8 -out /tmp/nemo2_ptq_engine + + AFTER_SCRIPT: | + rm -rf /tmp/nemo2_ckpt + rm -rf /tmp/nemo2_ptq_engine + Nemo_CICD_Test: needs: - pre-flight diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index e53d5af9ccfa..9eac589202e2 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -127,6 +127,8 @@ def __init__(self, quantization_config: dict, export_config: dict): # Export sanity checks if export_config is not None: assert dtype in SUPPORTED_DTYPE, f"Unsupported export dtype: {dtype}" + self.torch_dtype = torch_dtype_from_precision(dtype) + def load_quantizable_model(self, nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1) -> llm.GPTModel: trainer = nl.Trainer( @@ -141,30 +143,26 @@ def load_quantizable_model(self, nemo_checkpoint_path: str, calib_tp: int = 1, c trainer.strategy.setup_environment() model = 
nl.io.load_context(nemo_checkpoint_path).model - model.config.transformer_layer_spec = get_gpt_layer_modelopt_spec() - model.config = self.modify_model_config(model.config) - + model.config = self.quantizable_model_config(model.config) model = fabric.load_model(nemo_checkpoint_path, model=model) self.nemo_checkpoint_path = nemo_checkpoint_path return model - @staticmethod - def _setup(model: llm.GPTModel) -> None: + + def _setup(self, model: llm.GPTModel) -> None: """Setup model for quantization.""" + # TODO: disable activation checkpointing model.config.vocab_size = model.tokenizer.vocab_size + model.config.pipeline_dtype = self.torch_dtype # TODO: for some reason model.pipeline.dtype does not work model.freeze() - # TODO: update for NeMo 2.0 - # try: - # model.model.module.language_model.encoder.activations_checkpoint_method = None - # except AttributeError: - # pass @staticmethod - def modify_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: + def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: """Modify model config for quantization.""" + model_cfg.transformer_layer_spec = get_gpt_layer_modelopt_spec() if model_cfg.sequence_parallel: logging.warning("Disabling sequence parallelism for quantization...") model_cfg.sequence_parallel = False @@ -176,21 +174,32 @@ def modify_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: model_cfg.apply_rope_fusion = False return model_cfg + def _get_decoder_type(self, config: llm.GPTConfig): return self.export_config.get("decoder_type", None) or get_modelopt_decoder_type(config) + + @staticmethod + def _get_unwrapped_mcore_model(model: llm.GPTModel): + from megatron.core.models.gpt import GPTModel as MCoreGPTModel + unwrapped_model = model + while not isinstance(unwrapped_model, MCoreGPTModel): + unwrapped_model = unwrapped_model.module + + return unwrapped_model + + def quantize(self, wrapped_model: llm.GPTConfig, forward_loop): """Quantize the model and calibrate using given forward loop.""" algorithm = self.quantization_config["algorithm"] if algorithm is None: logging.info("Quantization algorithm set to None, returning the non-quantized model") - return wrapped_model.module + return wrapped_model logging.info(f"Quantizing model to {algorithm}...") self._setup(wrapped_model) - model = wrapped_model.module.module - model.config.pipeline_dtype = wrapped_model.pipeline.dtype + model = self._get_unwrapped_mcore_model(wrapped_model) decoder_type = self._get_decoder_type(model.config) quant_cfg = QUANT_CFG_CHOICES[algorithm] if "awq" in algorithm: @@ -273,7 +282,6 @@ def loop(model): def export(self, model: llm.GPTModel, nemo_checkpoint_path: Optional[str] = None) -> None: assert self.export_config is not None, "Export config is not set" - torch_dtype = torch_dtype_from_precision(self.export_config["dtype"]) # TODO: Add sample generate # TODO: Support NeMo 2: # if model.cfg.megatron_amp_O2: @@ -283,9 +291,9 @@ def export(self, model: llm.GPTModel, nemo_checkpoint_path: Optional[str] = None model.config.pipeline_model_parallel_size > 1 ) export_tensorrt_llm_checkpoint( - model=model.module.module, + model=self._get_unwrapped_mcore_model(model), decoder_type=self._get_decoder_type(model.config), - dtype=torch_dtype, + dtype=self.torch_dtype, export_dir=export_dir, inference_tensor_parallel=self.export_config["inference_tensor_parallel"], inference_pipeline_parallel=self.export_config["inference_pipeline_parallel"], @@ -298,13 +306,13 @@ def export(self, model: llm.GPTModel, nemo_checkpoint_path: Optional[str] = None if 
dist.get_rank() == 0: self.nemo_checkpoint_path = nemo_checkpoint_path or self.nemo_checkpoint_path if self.nemo_checkpoint_path is not None: - tokenizer_src = os.path.join(self.nemo_checkpoint_path, 'nemo_tokenizer') + tokenizer_src = os.path.join(self.nemo_checkpoint_path, 'context', 'nemo_tokenizer') tokenizer_dst = os.path.join(export_dir, 'tokenizer') if os.path.exists(tokenizer_src) and not os.path.exists(tokenizer_dst): shutil.copytree(tokenizer_src, tokenizer_dst) else: - logging.info("Could not copy tokenizer from NeMo checkpoint") + logging.info("Could not copy tokenizer from the NeMo checkpoint") def get_calib_data_iter( diff --git a/nemo/lightning/fabric/fabric.py b/nemo/lightning/fabric/fabric.py index 60da546fd2b3..8de3dcc1fb9b 100644 --- a/nemo/lightning/fabric/fabric.py +++ b/nemo/lightning/fabric/fabric.py @@ -62,12 +62,13 @@ def load_model( from nemo.lightning.io import load_context + path = Path(path) if model is None: context = load_context(path) model = context.model dist_model = self.setup_module(model) - self.load(path, {"state_dict": dist_model}) + self.load(path / 'weights', {"state_dict": dist_model}) return dist_model diff --git a/nemo/lightning/fabric/plugins.py b/nemo/lightning/fabric/plugins.py index 513d6b86e62a..1876f5bd432b 100644 --- a/nemo/lightning/fabric/plugins.py +++ b/nemo/lightning/fabric/plugins.py @@ -122,6 +122,10 @@ def convert_module(self, module: nn.Module) -> nn.Module: This is optional and depends on the precision limitations during optimization. """ + from nemo.collections import llm + if isinstance(module, llm.GPTModel) and not hasattr(module, "module"): + return module + from megatron.core.transformer.module import Float16Module from megatron.core.utils import get_model_config diff --git a/nemo/lightning/io/api.py b/nemo/lightning/io/api.py index 1bbbe43f8df9..fb16667f7e6b 100644 --- a/nemo/lightning/io/api.py +++ b/nemo/lightning/io/api.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Callable, Optional, Type, TypeVar +from typing import Any, Callable, Optional, Type, TypeVar, Union import fiddle as fdl import pytorch_lightning as pl @@ -9,7 +9,7 @@ from nemo.lightning.io.pl import TrainerContext -def load_context(path: Path, subpath: Optional[str] = None) -> TrainerContext: +def load_context(path: Union[Path, str], subpath: Optional[str] = None) -> TrainerContext: """ Loads a TrainerContext from a json-file or directory. 
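Both load paths above assume the NeMo 2.0 checkpoint layout in which the serialized trainer context and the distributed weights live in separate subdirectories; the tokenizer copied during export comes from the same tree. The sketch below only illustrates that layout assumption (directory names are taken from the code in this patch, the root path is a placeholder), and the test added further down exercises the actual loading path end to end:

    from pathlib import Path

    ckpt = Path("/tmp/nemo2_ckpt")  # placeholder; e.g. the output of llm.import_ckpt(...)
    # <ckpt>/context                 -> model/trainer context, restored via load_context(path / 'context')
    # <ckpt>/context/nemo_tokenizer  -> tokenizer copied next to the exported TRT-LLM checkpoint
    # <ckpt>/weights                 -> distributed weights, restored via fabric.load(path / 'weights', ...)
    for sub in ("context", "weights"):
        assert (ckpt / sub).is_dir(), f"not a NeMo 2.0 checkpoint: missing {sub}/"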
@@ -29,7 +29,8 @@ def load_context(path: Path, subpath: Optional[str] = None) -> TrainerContext: checkpoint: TrainerContext = load_ckpt("/path/to/checkpoint", subpath="model.config") """ - return load(path, output_type=TrainerContext, subpath=subpath) + path = Path(path) + return load(path / 'context', output_type=TrainerContext, subpath=subpath) def model_importer(target: Type[ConnectorMixin], ext: str) -> Callable[[Type[ConnT]], Type[ConnT]]: diff --git a/tests/collections/llm/test_hf_import.py b/tests/collections/llm/test_hf_import.py new file mode 100644 index 000000000000..388f0a7fcdf8 --- /dev/null +++ b/tests/collections/llm/test_hf_import.py @@ -0,0 +1,27 @@ +import argparse + +from nemo.collections import llm +from nemo import lightning as nl + + +def get_args(): + parser = argparse.ArgumentParser(description='Test Llama2 7B model model conversion from HF') + parser.add_argument('--hf_model', type=str, help="Original HF model") + parser.add_argument('--output_path', type=str, help="NeMo 2.0 export path") + + return parser.parse_args() + +if __name__ == '__main__': + args = get_args() + + model = llm.LlamaModel(config=llm.Llama2Config7B) + nemo2_path = llm.import_ckpt(model, "hf://" + args.hf_model, output_path=args.output_path) + + trainer = nl.Trainer( + devices=1, + strategy=nl.MegatronStrategy(tensor_model_parallel_size=1), + plugins=nl.MegatronMixedPrecision(precision='fp16') + ) + fabric = trainer.to_fabric() + trainer.strategy.setup_environment() + model = fabric.load_model(nemo2_path) From 0ea0e0c563b4818c27cfb1505eb543c5545bb276 Mon Sep 17 00:00:00 2001 From: Laplasjan107 Date: Mon, 14 Oct 2024 17:22:52 +0000 Subject: [PATCH 07/24] Apply isort and black reformatting Signed-off-by: Laplasjan107 --- nemo/collections/llm/quantization/quantizer.py | 9 ++------- nemo/lightning/fabric/plugins.py | 1 + tests/collections/llm/test_hf_import.py | 5 +++-- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index 9eac589202e2..60187362cf16 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -129,7 +129,6 @@ def __init__(self, quantization_config: dict, export_config: dict): assert dtype in SUPPORTED_DTYPE, f"Unsupported export dtype: {dtype}" self.torch_dtype = torch_dtype_from_precision(dtype) - def load_quantizable_model(self, nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1) -> llm.GPTModel: trainer = nl.Trainer( devices=calib_tp * calib_pp, @@ -149,15 +148,13 @@ def load_quantizable_model(self, nemo_checkpoint_path: str, calib_tp: int = 1, c self.nemo_checkpoint_path = nemo_checkpoint_path return model - def _setup(self, model: llm.GPTModel) -> None: """Setup model for quantization.""" # TODO: disable activation checkpointing model.config.vocab_size = model.tokenizer.vocab_size - model.config.pipeline_dtype = self.torch_dtype # TODO: for some reason model.pipeline.dtype does not work + model.config.pipeline_dtype = self.torch_dtype # TODO: for some reason model.pipeline.dtype does not work model.freeze() - @staticmethod def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: """Modify model config for quantization.""" @@ -174,21 +171,19 @@ def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: model_cfg.apply_rope_fusion = False return model_cfg - def _get_decoder_type(self, config: llm.GPTConfig): return self.export_config.get("decoder_type", None) or 
get_modelopt_decoder_type(config) - @staticmethod def _get_unwrapped_mcore_model(model: llm.GPTModel): from megatron.core.models.gpt import GPTModel as MCoreGPTModel + unwrapped_model = model while not isinstance(unwrapped_model, MCoreGPTModel): unwrapped_model = unwrapped_model.module return unwrapped_model - def quantize(self, wrapped_model: llm.GPTConfig, forward_loop): """Quantize the model and calibrate using given forward loop.""" algorithm = self.quantization_config["algorithm"] diff --git a/nemo/lightning/fabric/plugins.py b/nemo/lightning/fabric/plugins.py index 1876f5bd432b..e87f2b864147 100644 --- a/nemo/lightning/fabric/plugins.py +++ b/nemo/lightning/fabric/plugins.py @@ -123,6 +123,7 @@ def convert_module(self, module: nn.Module) -> nn.Module: """ from nemo.collections import llm + if isinstance(module, llm.GPTModel) and not hasattr(module, "module"): return module diff --git a/tests/collections/llm/test_hf_import.py b/tests/collections/llm/test_hf_import.py index 388f0a7fcdf8..cca6a54fe960 100644 --- a/tests/collections/llm/test_hf_import.py +++ b/tests/collections/llm/test_hf_import.py @@ -1,7 +1,7 @@ import argparse -from nemo.collections import llm from nemo import lightning as nl +from nemo.collections import llm def get_args(): @@ -11,6 +11,7 @@ def get_args(): return parser.parse_args() + if __name__ == '__main__': args = get_args() @@ -20,7 +21,7 @@ def get_args(): trainer = nl.Trainer( devices=1, strategy=nl.MegatronStrategy(tensor_model_parallel_size=1), - plugins=nl.MegatronMixedPrecision(precision='fp16') + plugins=nl.MegatronMixedPrecision(precision='fp16'), ) fabric = trainer.to_fabric() trainer.strategy.setup_environment() From aeed59fd2e6a8ae918564250f5f6577b2df95db7 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Tue, 15 Oct 2024 01:03:05 -0700 Subject: [PATCH 08/24] fix export Signed-off-by: Piotr Kaminski --- nemo/collections/llm/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 4975263a23d5..250cb7b73f66 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -118,7 +118,7 @@ "MaskedTokenLossReduction", "MistralConfig7B", "MistralModel", - "MistralConfig", + "MixtralConfig", "MixtralConfig8x3B", "MixtralConfig8x7B", "MixtralConfig8x22B", From ab5beb3881f48061b00b9219062b08fe80572983 Mon Sep 17 00:00:00 2001 From: Laplasjan107 Date: Tue, 15 Oct 2024 08:07:28 +0000 Subject: [PATCH 09/24] Apply isort and black reformatting Signed-off-by: Laplasjan107 --- nemo/collections/common/parts/run_utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/nemo/collections/common/parts/run_utils.py b/nemo/collections/common/parts/run_utils.py index cdc5d46c50d0..61b8c26e222b 100644 --- a/nemo/collections/common/parts/run_utils.py +++ b/nemo/collections/common/parts/run_utils.py @@ -12,25 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os import json -import subprocess +import os import shlex - -from pathlib import Path -from functools import lru_cache -from omegaconf import OmegaConf, DictConfig +import subprocess from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path import nemo_run as run -from nemo_run.core.tunnel import LocalTunnel, SSHTunnel from nemo_run.config import NEMORUN_HOME from nemo_run.core.execution.docker import DockerExecutor from nemo_run.core.execution.slurm import SlurmJobDetails from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer +from nemo_run.core.tunnel import LocalTunnel, SSHTunnel +from omegaconf import DictConfig, OmegaConf from nemo.utils import logging + @lru_cache(maxsize=2) def get_tunnel(**ssh_tunnel): return SSHTunnel(**ssh_tunnel) From 332a6dc51c782f2716fa949cb140806725b457d8 Mon Sep 17 00:00:00 2001 From: artbataev Date: Tue, 15 Oct 2024 08:08:08 +0000 Subject: [PATCH 10/24] Apply isort and black reformatting Signed-off-by: artbataev --- nemo/collections/common/parts/run_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/nemo/collections/common/parts/run_utils.py b/nemo/collections/common/parts/run_utils.py index 61b8c26e222b..51c16b54c4f3 100644 --- a/nemo/collections/common/parts/run_utils.py +++ b/nemo/collections/common/parts/run_utils.py @@ -89,6 +89,7 @@ def get_mounts_from_config(cluster_config: dict, env_vars: dict = None): return mounts + def check_if_mounted(cluster_config, path_to_check): """Will check that path_to_check is referenced inside one of the mounts.""" for mount in get_mounts_from_config(cluster_config) + ['/nemo_run/code:/nemo_run/code']: @@ -330,8 +331,6 @@ def get_mounted_filepath(cluster_config: dict, filepath: str): return filepath - - def get_env_variables(cluster_config): """ Will get the environment variables from the cluster config and the user environment. @@ -570,7 +569,6 @@ def add_task( ) - def run_exp(exp, cluster_config, sequential=False): if cluster_config['executor'] == 'local': # locally we are always running sequentially - does that need to be changed? 
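Taken together, the PTQ pieces introduced so far compose into a single flow. The sketch below mirrors main() from examples/llm/quantization/ptq.py as it stands at this point in the series (dict-based configs and load_quantizable_model); the checkpoint and output paths are placeholders and the numbers are the script defaults, so treat it as an illustration rather than a separate API:

    import torch
    from nemo.collections.llm.quantization import Quantizer, get_calib_data_iter

    quantization_config = {
        "algorithm": "fp8",        # or "int8_sq", "int4_awq", ... as accepted by ptq.py
        "awq_block_size": 128,
        "sq_alpha": 0.5,
        "enable_kv_cache": None,   # None lets the Quantizer auto-detect per algorithm/decoder
    }
    export_config = {
        "path": "/tmp/qnemo_fp8_tp1_pp1",  # placeholder output directory
        "decoder_type": None,              # inferred from the model config when not set
        "inference_tensor_parallel": 1,
        "inference_pipeline_parallel": 1,
        "dtype": "bf16",
    }

    quantizer = Quantizer(quantization_config, export_config)
    model = quantizer.load_quantizable_model("/tmp/nemo2_ckpt", calib_tp=1, calib_pp=1)

    SEQ_LEN, BATCH_SIZE, CALIB_SIZE = 128, 64, 512

    def get_dataloader():
        # Tokenize and pad the calibration batches the same way create_data_iterator_getter() does.
        batches = []
        for batch in get_calib_data_iter(batch_size=BATCH_SIZE, calib_size=CALIB_SIZE,
                                         max_sequence_length=4 * SEQ_LEN):
            ids = [model.tokenizer.text_to_ids(text)[:SEQ_LEN] for text in batch]
            ids = [x + (SEQ_LEN - len(x)) * [model.tokenizer.eos] for x in ids]
            batches.append(torch.tensor(ids, device=model.device))
        return iter(batches)

    forward_loop = quantizer.create_megatron_forward_loop(
        get_dataloader,
        num_batches=CALIB_SIZE // BATCH_SIZE,
        seq_length=SEQ_LEN,
        micro_batch_size=BATCH_SIZE,
    )
    model = quantizer.quantize(model, forward_loop)
    quantizer.export(model)

The CI job added earlier drives the same flow through the CLI, e.g. python examples/llm/quantization/ptq.py -nc /tmp/nemo2_ckpt -algo fp8 -out /tmp/nemo2_ptq_engine.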
From c5a2a4d0269afac34405ad59567c67d8c96a6ad3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20Kami=C5=84ski?= <67481570+Laplasjan107@users.noreply.github.com> Date: Tue, 15 Oct 2024 22:21:48 +0200 Subject: [PATCH 11/24] Fix output_path argument for HF import MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr KamiƄski <67481570+Laplasjan107@users.noreply.github.com> --- tests/collections/llm/test_hf_import.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/collections/llm/test_hf_import.py b/tests/collections/llm/test_hf_import.py index cca6a54fe960..53232eb02bb2 100644 --- a/tests/collections/llm/test_hf_import.py +++ b/tests/collections/llm/test_hf_import.py @@ -1,4 +1,5 @@ import argparse +from pathlib import Path from nemo import lightning as nl from nemo.collections import llm @@ -16,7 +17,7 @@ def get_args(): args = get_args() model = llm.LlamaModel(config=llm.Llama2Config7B) - nemo2_path = llm.import_ckpt(model, "hf://" + args.hf_model, output_path=args.output_path) + nemo2_path = llm.import_ckpt(model, "hf://" + args.hf_model, output_path=Path(args.output_path)) trainer = nl.Trainer( devices=1, @@ -25,4 +26,4 @@ def get_args(): ) fabric = trainer.to_fabric() trainer.strategy.setup_environment() - model = fabric.load_model(nemo2_path) + fabric.load_model(nemo2_path) From 20db14b0473f010b429c2ee6810c4b6b266ffb68 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Mon, 21 Oct 2024 05:08:58 -0700 Subject: [PATCH 12/24] fix fabric ckpt loading Signed-off-by: Piotr Kaminski --- .github/workflows/cicd-main.yml | 1 + nemo/collections/llm/quantization/quantizer.py | 5 ++++- nemo/lightning/fabric/fabric.py | 5 +++-- nemo/lightning/io/api.py | 7 +++---- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index ac5d80bcafab..45e49a597f95 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -4470,6 +4470,7 @@ jobs: - L2_Speech_Transcription_Canary_Transcribe_Audio_Dir - L2_Megatron_GPT_Reranker - L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact + - L2_NeMo_2_PTQ_Llama2_FP8 if: always() runs-on: ubuntu-latest steps: diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index 60187362cf16..fe4573bfe652 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -13,6 +13,7 @@ # limitations under the License. 
import os +from pathlib import Path import shutil from typing import Optional @@ -21,6 +22,7 @@ from datasets import load_dataset from nemo import lightning as nl +from nemo.lightning.ckpt_utils import ckpt_to_context_subdir from nemo.collections import llm from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision @@ -141,7 +143,8 @@ def load_quantizable_model(self, nemo_checkpoint_path: str, calib_tp: int = 1, c fabric = trainer.to_fabric() trainer.strategy.setup_environment() - model = nl.io.load_context(nemo_checkpoint_path).model + model_path = Path(nemo_checkpoint_path) + model = nl.io.load_context(ckpt_to_context_subdir(model_path)).model model.config = self.quantizable_model_config(model.config) model = fabric.load_model(nemo_checkpoint_path, model=model) diff --git a/nemo/lightning/fabric/fabric.py b/nemo/lightning/fabric/fabric.py index 54b6e382e479..4038e3f75f9f 100644 --- a/nemo/lightning/fabric/fabric.py +++ b/nemo/lightning/fabric/fabric.py @@ -10,6 +10,7 @@ from typing_extensions import Self, override from nemo.lightning.io.mixin import IOMixin, serialization, track_io +from nemo.lightning.ckpt_utils import ckpt_to_context_subdir, ckpt_to_weights_subdir if TYPE_CHECKING: from megatron.core.optimizer import OptimizerConfig @@ -65,11 +66,11 @@ def load_model( path = Path(path) if model is None: - context = load_context(path) + context = load_context(ckpt_to_context_subdir(path)) model = context.model dist_model = self.setup_module(model) - self.load(path / 'weights', {"state_dict": dist_model}) + self.load(ckpt_to_weights_subdir(path), {"state_dict": dist_model}) return dist_model diff --git a/nemo/lightning/io/api.py b/nemo/lightning/io/api.py index 7c32b2cc8ed5..643b671d1d85 100644 --- a/nemo/lightning/io/api.py +++ b/nemo/lightning/io/api.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Callable, Optional, Type, TypeVar, Union +from typing import Callable, Optional, Type import pytorch_lightning as pl @@ -7,7 +7,7 @@ from nemo.lightning.io.pl import TrainerContext -def load_context(path: Union[Path, str], subpath: Optional[str] = None) -> TrainerContext: +def load_context(path: Path, subpath: Optional[str] = None) -> TrainerContext: """ Loads a TrainerContext from a json-file or directory. @@ -27,8 +27,7 @@ def load_context(path: Union[Path, str], subpath: Optional[str] = None) -> Train checkpoint: TrainerContext = load_ckpt("/path/to/checkpoint", subpath="model.config") """ - path = Path(path) - return load(path / 'context', output_type=TrainerContext, subpath=subpath) + return load(path, output_type=TrainerContext, subpath=subpath) def model_importer(target: Type[ConnectorMixin], ext: str) -> Callable[[Type[ConnT]], Type[ConnT]]: From dc493c2932e0ea80c6203a12870e0a8a1509da43 Mon Sep 17 00:00:00 2001 From: Laplasjan107 Date: Mon, 21 Oct 2024 12:10:23 +0000 Subject: [PATCH 13/24] Apply isort and black reformatting Signed-off-by: Laplasjan107 --- nemo/collections/llm/quantization/quantizer.py | 4 ++-- nemo/lightning/fabric/fabric.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index fe4573bfe652..61e3dd8e1597 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -13,8 +13,8 @@ # limitations under the License. 
import os -from pathlib import Path import shutil +from pathlib import Path from typing import Optional import torch @@ -22,10 +22,10 @@ from datasets import load_dataset from nemo import lightning as nl -from nemo.lightning.ckpt_utils import ckpt_to_context_subdir from nemo.collections import llm from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision +from nemo.lightning.ckpt_utils import ckpt_to_context_subdir from nemo.utils import logging try: diff --git a/nemo/lightning/fabric/fabric.py b/nemo/lightning/fabric/fabric.py index 4038e3f75f9f..ddad49f7d211 100644 --- a/nemo/lightning/fabric/fabric.py +++ b/nemo/lightning/fabric/fabric.py @@ -6,11 +6,10 @@ import lightning_fabric as lb import pytorch_lightning as pl from torch import nn - from typing_extensions import Self, override -from nemo.lightning.io.mixin import IOMixin, serialization, track_io from nemo.lightning.ckpt_utils import ckpt_to_context_subdir, ckpt_to_weights_subdir +from nemo.lightning.io.mixin import IOMixin, serialization, track_io if TYPE_CHECKING: from megatron.core.optimizer import OptimizerConfig From 23fc9c24dc692e0bb405b861bfc8cd2132db307e Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Mon, 21 Oct 2024 08:37:26 -0700 Subject: [PATCH 14/24] code review suggestions Signed-off-by: Piotr Kaminski --- .github/workflows/cicd-main.yml | 2 +- examples/llm/quantization/ptq.py | 71 +++++---- nemo/collections/llm/quantization/__init__.py | 17 ++- .../collections/llm/quantization/quantizer.py | 139 ++++++------------ nemo/collections/llm/quantization/utils.py | 63 ++++++++ 5 files changed, 163 insertions(+), 129 deletions(-) create mode 100644 nemo/collections/llm/quantization/utils.py diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 45e49a597f95..f5b4735123c5 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -4317,7 +4317,7 @@ jobs: SCRIPT: | python tests/collections/llm/test_hf_import.py --hf_model /home/TestData/nlp/megatron_llama/llama-ci-hf --output_path /tmp/nemo2_ckpt - python examples/llm/quantization/ptq.py -nc /tmp/nemo2_ckpt -algo fp8 -out /tmp/nemo2_ptq_engine + python examples/llm/quantization/ptq.py -nc /tmp/nemo2_ckpt -algo fp8 -out /tmp/nemo2_ptq_engine -calib_ds /home/TestData/nlp/test_quantization/test.json AFTER_SCRIPT: | rm -rf /tmp/nemo2_ckpt diff --git a/examples/llm/quantization/ptq.py b/examples/llm/quantization/ptq.py index 6e1d9d54e00b..5d92864224e2 100644 --- a/examples/llm/quantization/ptq.py +++ b/examples/llm/quantization/ptq.py @@ -17,7 +17,7 @@ import torch from tqdm import tqdm -from nemo.collections.llm.quantization import Quantizer, get_calib_data_iter +from nemo.collections.llm import quantization def get_args(): @@ -44,7 +44,9 @@ def get_args(): '-awq_bs', '--awq_block_size', type=int, default=128, help='Block size for AWQ quantization algorithms' ) parser.add_argument('--sq_alpha', type=float, default=0.5, help='Smooth-Quant alpha parameter') - parser.add_argument('--enable_kv_cache', type=bool, help='Enables KV-cache quantization') + parser.add_argument('--enable_kv_cache', help='Enables KV-cache quantization', action='store_true') + parser.add_argument('--disable_kv_cache', dest='enable_kv_cache', action='store_false') + parser.set_defaults(enable_kv_cache=True) parser.add_argument( '-dt', '--dtype', default="bf16", choices=["16", "bf16"], help='Default precision for 
non-quantized layers' ) @@ -57,42 +59,23 @@ def get_args(): '-calib_ds', '--calibration_dataset', default="cnn_dailymail", - choices=["wikitext", "cnn_dailymail"], type=str, - help='Calibration dataset to be used', + help='Calibration dataset to be used. Should be \"wikitext\", \"cnn_dailymail\" or path to a local .json file', ) - return parser.parse_args(sys.argv[1:]) - - -def get_quantizer_config(args): + args = parser.parse_args() if args.output_path is None: args.output_path = ( f"./qnemo_{args.algorithm}_tp{args.tensor_parallelism_size}_pp{args.pipeline_parallelism_size}" ) - - quantization_config = { - "algorithm": None if args.algorithm == "no_quant" else args.algorithm, - "awq_block_size": args.awq_block_size, - "sq_alpha": args.sq_alpha, - "enable_kv_cache": args.enable_kv_cache, - } - export_config = { - "path": args.output_path, - "decoder_type": args.decoder_type, - "inference_tensor_parallel": args.tensor_parallelism_size, - "inference_pipeline_parallel": args.pipeline_parallelism_size, - "dtype": args.dtype, - } - - return quantization_config, export_config + return args def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration_size): def _iterator(): CHARACTERS_PER_TOKEN = 4 - dataloader = get_calib_data_iter( + dataloader = quantization.get_calib_data_iter( data=dataset, max_sequence_length=CHARACTERS_PER_TOKEN * seq_len, batch_size=batch_size, @@ -112,24 +95,40 @@ def _iterator_getter(): def main(): - params = get_args() - quantization_config, export_config = get_quantizer_config(params) - quantizer = Quantizer(quantization_config, export_config) - model = quantizer.load_quantizable_model(params.nemo_checkpoint, params.calib_tp, params.calib_pp) + args = get_args() + + quantization_config = quantization.QuantizationConfig( + algorithm=None if args.algorithm == "no_quant" else args.algorithm, + awq_block_size=args.awq_block_size, + sq_alpha=args.sq_alpha, + enable_kv_cache=args.enable_kv_cache + ) + + export_config = quantization.ExportConfig( + path=args.output_path, + decoder_type=args.decoder_type, + inference_tensor_parallel=args.tensor_parallelism_size, + inference_pipeline_parallel=args.pipeline_parallelism_size, + dtype=args.dtype, + ) + + + quantizer = quantization.Quantizer(quantization_config, export_config) + model = quantization.load_with_modelopt_layer_spec(args.nemo_checkpoint, args.calib_tp, args.calib_pp) get_dataloader = create_data_iterator_getter( model, - dataset=params.calibration_dataset, - seq_len=params.seq_len, - batch_size=params.batch_size, - calibration_size=params.calibration_dataset_size, + dataset=args.calibration_dataset, + seq_len=args.seq_len, + batch_size=args.batch_size, + calibration_size=args.calibration_dataset_size, ) forward_loop = quantizer.create_megatron_forward_loop( get_dataloader, - num_batches=params.calibration_dataset_size // params.batch_size, - seq_length=params.seq_len, - micro_batch_size=params.batch_size, + num_batches=args.calibration_dataset_size // args.batch_size, + seq_length=args.seq_len, + micro_batch_size=args.batch_size, ) model = quantizer.quantize(model, forward_loop) diff --git a/nemo/collections/llm/quantization/__init__.py b/nemo/collections/llm/quantization/__init__.py index 3a8b2989a6a1..c0dd4d0f7497 100644 --- a/nemo/collections/llm/quantization/__init__.py +++ b/nemo/collections/llm/quantization/__init__.py @@ -12,4 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .quantizer import Quantizer, get_calib_data_iter +from .quantizer import ( + Quantizer, + QuantizationConfig, + ExportConfig, + get_calib_data_iter, +) + +from .utils import load_with_modelopt_layer_spec + +__all__ = [ + "Quantizer", + "QuantizationConfig", + "ExportConfig", + "get_calib_data_iter", + "load_with_modelopt_layer_spec" +] diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index 61e3dd8e1597..10c975be3314 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -12,21 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass import os import shutil -from pathlib import Path -from typing import Optional +from typing import Optional, Union import torch import torch.distributed as dist from datasets import load_dataset -from nemo import lightning as nl from nemo.collections import llm -from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision -from nemo.lightning.ckpt_utils import ckpt_to_context_subdir from nemo.utils import logging +from .utils import get_unwrapped_mcore_model try: import modelopt.torch.quantization as mtq @@ -50,6 +48,21 @@ SUPPORTED_DTYPE = [16, "16", "bf16"] # Default precision for non-quantized layers +@dataclass +class QuantizationConfig: + algorithm: Optional[str] = "fp8" # one of QUANT_CFG_CHOICES keys + awq_block_size: int = 128 + sq_alpha: float = 0.5 + enable_kv_cache: bool = True + +@dataclass +class ExportConfig: + path: str + dtype: Union[str, int] = "bf16" + decoder_type: Optional[str] = None + inference_tensor_parallel: int = 1 + inference_pipeline_parallel: int = 1 + def get_modelopt_decoder_type(config: llm.GPTConfig) -> str: """Infers the modelopt decoder type from GPTConfig class""" @@ -75,9 +88,8 @@ def get_modelopt_decoder_type(config: llm.GPTConfig) -> str: return "llama" -# TODO: Support PP class Quantizer: - """Post-training quantization (PTQ) and TRT-LLM export of Nemo checkpoints. + """Post-training quantization (PTQ) and TRT-LLM export of NeMo 2.0 checkpoints. PTQ converts selected model layers to low-precision format (e.g., INT4, FP8) for efficient serving. The process consist of several steps: @@ -96,60 +108,27 @@ class Quantizer: for TensorRT-LLM deployment. This is useful to getting baseline results for a full-precision model. """ - def __init__(self, quantization_config: dict, export_config: dict): - """Initialize Quantizer with quantization and export configurations. - - Expected keys in `quantization_config`: - - algorithm: (optional) str - - awq_block_size: int (only for awq algorithms) - - sq_alpha: float (only for smooth quant algorithms) - - enable_kv_cache: bool (default: None i.e. auto-detect based on algorithm and decoder_type) - - Expected keys in `export_config`: - - dtype: str/int - - decoder_type: (optional) str - - inference_tensor_parallel: int - - inference_pipeline_parallel: int - - path: str - """ + def __init__(self, quantization_config: QuantizationConfig, export_config: ExportConfig): + """Initialize Quantizer with quantization and export configurations. 
""" + if not HAVE_MODELOPT: raise RuntimeError("nvidia-modelopt is needed to use Quantizer") from HAVE_MODELOPT_ERROR if not torch.cuda.is_available(): raise EnvironmentError("GPU is required for the quantization.") - self.quantization_config = quantization_config - self.export_config = export_config + self.quantization_config: QuantizationConfig = quantization_config + self.export_config: ExportConfig = export_config self.nemo_checkpoint_path = None - algorithm = quantization_config.get("algorithm", None) - dtype = export_config["dtype"] + algorithm = quantization_config.algorithm + dtype = export_config.dtype - # Quantization sanity checks + # Export and Quantization config sanity checks assert algorithm is None or algorithm in QUANT_CFG_CHOICES, f"Unsupported quantization algorithm: {algorithm}" - # Export sanity checks if export_config is not None: assert dtype in SUPPORTED_DTYPE, f"Unsupported export dtype: {dtype}" self.torch_dtype = torch_dtype_from_precision(dtype) - def load_quantizable_model(self, nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1) -> llm.GPTModel: - trainer = nl.Trainer( - devices=calib_tp * calib_pp, - strategy=nl.MegatronStrategy( - tensor_model_parallel_size=calib_tp, - pipeline_model_parallel_size=calib_pp, - ), - plugins=nl.MegatronMixedPrecision(precision='16-mixed'), - ) - fabric = trainer.to_fabric() - trainer.strategy.setup_environment() - - model_path = Path(nemo_checkpoint_path) - model = nl.io.load_context(ckpt_to_context_subdir(model_path)).model - model.config = self.quantizable_model_config(model.config) - model = fabric.load_model(nemo_checkpoint_path, model=model) - - self.nemo_checkpoint_path = nemo_checkpoint_path - return model def _setup(self, model: llm.GPTModel) -> None: """Setup model for quantization.""" @@ -158,58 +137,34 @@ def _setup(self, model: llm.GPTModel) -> None: model.config.pipeline_dtype = self.torch_dtype # TODO: for some reason model.pipeline.dtype does not work model.freeze() - @staticmethod - def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: - """Modify model config for quantization.""" - - model_cfg.transformer_layer_spec = get_gpt_layer_modelopt_spec() - if model_cfg.sequence_parallel: - logging.warning("Disabling sequence parallelism for quantization...") - model_cfg.sequence_parallel = False - # Only custom ModelOpt spec is supported for Quantization: this custom spec is largely based on local Megatron-LM - # layer definitions to avoid Transformer Engine implementations that are currently not supported. - # This layer spec also requires RoPE fusion to be disabled for tensor view operations in attention - # layer implementation from megatron/core/transformer/dot_product_attention.py to be functional. 
- model_cfg.name = "modelopt" - model_cfg.apply_rope_fusion = False - return model_cfg def _get_decoder_type(self, config: llm.GPTConfig): - return self.export_config.get("decoder_type", None) or get_modelopt_decoder_type(config) - - @staticmethod - def _get_unwrapped_mcore_model(model: llm.GPTModel): - from megatron.core.models.gpt import GPTModel as MCoreGPTModel - - unwrapped_model = model - while not isinstance(unwrapped_model, MCoreGPTModel): - unwrapped_model = unwrapped_model.module + return self.export_config.decoder_type or get_modelopt_decoder_type(config) - return unwrapped_model - def quantize(self, wrapped_model: llm.GPTConfig, forward_loop): + def quantize(self, model: llm.GPTModel, forward_loop): """Quantize the model and calibrate using given forward loop.""" - algorithm = self.quantization_config["algorithm"] + algorithm = self.quantization_config.algorithm if algorithm is None: logging.info("Quantization algorithm set to None, returning the non-quantized model") - return wrapped_model + return model logging.info(f"Quantizing model to {algorithm}...") - self._setup(wrapped_model) - model = self._get_unwrapped_mcore_model(wrapped_model) - decoder_type = self._get_decoder_type(model.config) + self._setup(model) + unwrapped_model = get_unwrapped_mcore_model(model) + decoder_type = self._get_decoder_type(unwrapped_model.config) quant_cfg = QUANT_CFG_CHOICES[algorithm] if "awq" in algorithm: weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] if isinstance(weight_quantizer, list): weight_quantizer = weight_quantizer[0] - weight_quantizer["block_sizes"][-1] = self.quantization_config["awq_block_size"] + weight_quantizer["block_sizes"][-1] = self.quantization_config.awq_block_size # Always turn on FP8 kv cache to save memory footprint. # For int8_sq, we use int8 kv cache. # TODO: Investigate why enabling FP8 kv cache will cause accuracy regressions for Nemotron. - enable_quant_kv_cache = self.quantization_config.get("enable_kv_cache", None) + enable_quant_kv_cache = self.quantization_config.enable_kv_cache if enable_quant_kv_cache is None: enable_quant_kv_cache = "int8" not in algorithm and decoder_type != "gptnext" logging.info(f'{"Enabled" if enable_quant_kv_cache else "Disabled"} KV cache quantization') @@ -219,11 +174,11 @@ def quantize(self, wrapped_model: llm.GPTConfig, forward_loop): "enable": enable_quant_kv_cache, } if algorithm == "int8_sq": - sq_alpha = self.quantization_config["sq_alpha"] + sq_alpha = self.quantization_config.sq_alpha logging.info(f"Using int8_sq alpha = {sq_alpha}") quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": sq_alpha} - model = mtq.quantize(model, quant_cfg, forward_loop) + unwrapped_model = mtq.quantize(unwrapped_model, quant_cfg, forward_loop) if decoder_type == "gptnext": # We found squared_relu may have an under-calibration problem. 
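For context on the clamping in the hunk below: maxbound is the largest magnitude representable in the target format (448 for FP8 E4M3, 127 for signed INT8), so under-calibrated activation amax values are floored at 1% of that range. A standalone sketch of the same post-processing, with the constants spelled out:

    import torch

    def clamp_amax(amax: torch.Tensor, algorithm: str) -> torch.Tensor:
        # 448.0 is the largest finite value of FP8 E4M3; 127.0 the largest of signed INT8.
        # Other algorithms fall through to 0.0, which is effectively a no-op since amax is non-negative.
        maxbound = {"fp8": 448.0, "int8_sq": 127.0}.get(algorithm, 0.0)
        return torch.clamp(amax, min=0.01 * maxbound)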
@@ -236,14 +191,15 @@ def quantize(self, wrapped_model: llm.GPTConfig, forward_loop): case _: maxbound = 0 - model = mtq.postprocess_amax( - model, "*input_quantizer", lambda amax: torch.clamp(amax, min=0.01 * maxbound) + unwrapped_model = mtq.postprocess_amax( + unwrapped_model, "*input_quantizer", lambda amax: torch.clamp(amax, min=0.01 * maxbound) ) if dist.get_rank() == 0: - mtq.print_quant_summary(model) + mtq.print_quant_summary(unwrapped_model) + + return model - return wrapped_model def create_megatron_forward_loop( self, get_dataloader, num_batches, seq_length=None, micro_batch_size=None, decoder_seq_length=None @@ -278,23 +234,24 @@ def loop(model): return loop + def export(self, model: llm.GPTModel, nemo_checkpoint_path: Optional[str] = None) -> None: assert self.export_config is not None, "Export config is not set" # TODO: Add sample generate # TODO: Support NeMo 2: # if model.cfg.megatron_amp_O2: # model.model = unwrap_model(model.model, Float16Module) - export_dir = self.export_config["path"] + export_dir = self.export_config.path use_nfs_workspace = (model.trainer._fabric.__io__.num_nodes > 1) or ( model.config.pipeline_model_parallel_size > 1 ) export_tensorrt_llm_checkpoint( - model=self._get_unwrapped_mcore_model(model), + model=get_unwrapped_mcore_model(model), decoder_type=self._get_decoder_type(model.config), dtype=self.torch_dtype, export_dir=export_dir, - inference_tensor_parallel=self.export_config["inference_tensor_parallel"], - inference_pipeline_parallel=self.export_config["inference_pipeline_parallel"], + inference_tensor_parallel=self.export_config.inference_tensor_parallel, + inference_pipeline_parallel=self.export_config.inference_pipeline_parallel, use_nfs_workspace=use_nfs_workspace, ) diff --git a/nemo/collections/llm/quantization/utils.py b/nemo/collections/llm/quantization/utils.py new file mode 100644 index 000000000000..5e8ef955414a --- /dev/null +++ b/nemo/collections/llm/quantization/utils.py @@ -0,0 +1,63 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +from nemo.collections import llm +from nemo import lightning as nl +from nemo.utils import logging +from nemo.lightning.ckpt_utils import ckpt_to_context_subdir +from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec + +def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: + """Modify model config for quantization.""" + + model_cfg.transformer_layer_spec = get_gpt_layer_modelopt_spec() + if model_cfg.sequence_parallel: + logging.warning("Disabling sequence parallelism for quantization...") + model_cfg.sequence_parallel = False + # Only custom ModelOpt spec is supported for Quantization: this custom spec is largely based on local Megatron-LM + # layer definitions to avoid Transformer Engine implementations that are currently not supported. 
+ # This layer spec also requires RoPE fusion to be disabled for tensor view operations in attention + # layer implementation from megatron/core/transformer/dot_product_attention.py to be functional. + model_cfg.name = "modelopt" + model_cfg.apply_rope_fusion = False + return model_cfg + + +def load_with_modelopt_layer_spec(nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1) -> llm.GPTModel: + trainer = nl.Trainer( + devices=calib_tp * calib_pp, + strategy=nl.MegatronStrategy( + tensor_model_parallel_size=calib_tp, + pipeline_model_parallel_size=calib_pp, + ), + plugins=nl.MegatronMixedPrecision(precision='16-mixed'), + ) + fabric = trainer.to_fabric() + fabric.launch() + + model_path = Path(nemo_checkpoint_path) + model = nl.io.load_context(ckpt_to_context_subdir(model_path)).model + model.config = quantizable_model_config(model.config) + return fabric.load_model(nemo_checkpoint_path, model=model) + + +def get_unwrapped_mcore_model(model: llm.GPTModel): + from megatron.core.models.gpt import GPTModel as MCoreGPTModel + unwrapped_model = model + while not isinstance(unwrapped_model, MCoreGPTModel): + unwrapped_model = unwrapped_model.module + + return unwrapped_model From 40ebf78be4766a5bdfe6118ea402ac21fad0b0f8 Mon Sep 17 00:00:00 2001 From: Laplasjan107 Date: Mon, 21 Oct 2024 15:38:31 +0000 Subject: [PATCH 15/24] Apply isort and black reformatting Signed-off-by: Laplasjan107 --- examples/llm/quantization/ptq.py | 3 +-- nemo/collections/llm/quantization/__init__.py | 16 ++-------------- nemo/collections/llm/quantization/quantizer.py | 14 ++++++-------- nemo/collections/llm/quantization/utils.py | 8 +++++--- 4 files changed, 14 insertions(+), 27 deletions(-) diff --git a/examples/llm/quantization/ptq.py b/examples/llm/quantization/ptq.py index 5d92864224e2..0c79fb1bfe7a 100644 --- a/examples/llm/quantization/ptq.py +++ b/examples/llm/quantization/ptq.py @@ -101,7 +101,7 @@ def main(): algorithm=None if args.algorithm == "no_quant" else args.algorithm, awq_block_size=args.awq_block_size, sq_alpha=args.sq_alpha, - enable_kv_cache=args.enable_kv_cache + enable_kv_cache=args.enable_kv_cache, ) export_config = quantization.ExportConfig( @@ -112,7 +112,6 @@ def main(): dtype=args.dtype, ) - quantizer = quantization.Quantizer(quantization_config, export_config) model = quantization.load_with_modelopt_layer_spec(args.nemo_checkpoint, args.calib_tp, args.calib_pp) diff --git a/nemo/collections/llm/quantization/__init__.py b/nemo/collections/llm/quantization/__init__.py index c0dd4d0f7497..d118e9adf454 100644 --- a/nemo/collections/llm/quantization/__init__.py +++ b/nemo/collections/llm/quantization/__init__.py @@ -12,19 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .quantizer import ( - Quantizer, - QuantizationConfig, - ExportConfig, - get_calib_data_iter, -) - +from .quantizer import ExportConfig, QuantizationConfig, Quantizer, get_calib_data_iter from .utils import load_with_modelopt_layer_spec -__all__ = [ - "Quantizer", - "QuantizationConfig", - "ExportConfig", - "get_calib_data_iter", - "load_with_modelopt_layer_spec" -] +__all__ = ["Quantizer", "QuantizationConfig", "ExportConfig", "get_calib_data_iter", "load_with_modelopt_layer_spec"] diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index 10c975be3314..b60740db3748 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass import os import shutil +from dataclasses import dataclass from typing import Optional, Union import torch @@ -24,6 +24,7 @@ from nemo.collections import llm from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision from nemo.utils import logging + from .utils import get_unwrapped_mcore_model try: @@ -48,13 +49,15 @@ SUPPORTED_DTYPE = [16, "16", "bf16"] # Default precision for non-quantized layers + @dataclass class QuantizationConfig: - algorithm: Optional[str] = "fp8" # one of QUANT_CFG_CHOICES keys + algorithm: Optional[str] = "fp8" # one of QUANT_CFG_CHOICES keys awq_block_size: int = 128 sq_alpha: float = 0.5 enable_kv_cache: bool = True + @dataclass class ExportConfig: path: str @@ -109,7 +112,7 @@ class Quantizer: """ def __init__(self, quantization_config: QuantizationConfig, export_config: ExportConfig): - """Initialize Quantizer with quantization and export configurations. 
""" + """Initialize Quantizer with quantization and export configurations.""" if not HAVE_MODELOPT: raise RuntimeError("nvidia-modelopt is needed to use Quantizer") from HAVE_MODELOPT_ERROR @@ -129,7 +132,6 @@ def __init__(self, quantization_config: QuantizationConfig, export_config: Expor assert dtype in SUPPORTED_DTYPE, f"Unsupported export dtype: {dtype}" self.torch_dtype = torch_dtype_from_precision(dtype) - def _setup(self, model: llm.GPTModel) -> None: """Setup model for quantization.""" # TODO: disable activation checkpointing @@ -137,11 +139,9 @@ def _setup(self, model: llm.GPTModel) -> None: model.config.pipeline_dtype = self.torch_dtype # TODO: for some reason model.pipeline.dtype does not work model.freeze() - def _get_decoder_type(self, config: llm.GPTConfig): return self.export_config.decoder_type or get_modelopt_decoder_type(config) - def quantize(self, model: llm.GPTModel, forward_loop): """Quantize the model and calibrate using given forward loop.""" algorithm = self.quantization_config.algorithm @@ -200,7 +200,6 @@ def quantize(self, model: llm.GPTModel, forward_loop): return model - def create_megatron_forward_loop( self, get_dataloader, num_batches, seq_length=None, micro_batch_size=None, decoder_seq_length=None ): @@ -234,7 +233,6 @@ def loop(model): return loop - def export(self, model: llm.GPTModel, nemo_checkpoint_path: Optional[str] = None) -> None: assert self.export_config is not None, "Export config is not set" # TODO: Add sample generate diff --git a/nemo/collections/llm/quantization/utils.py b/nemo/collections/llm/quantization/utils.py index 5e8ef955414a..2573da5586ca 100644 --- a/nemo/collections/llm/quantization/utils.py +++ b/nemo/collections/llm/quantization/utils.py @@ -14,11 +14,12 @@ from pathlib import Path -from nemo.collections import llm from nemo import lightning as nl -from nemo.utils import logging -from nemo.lightning.ckpt_utils import ckpt_to_context_subdir +from nemo.collections import llm from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec +from nemo.lightning.ckpt_utils import ckpt_to_context_subdir +from nemo.utils import logging + def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: """Modify model config for quantization.""" @@ -56,6 +57,7 @@ def load_with_modelopt_layer_spec(nemo_checkpoint_path: str, calib_tp: int = 1, def get_unwrapped_mcore_model(model: llm.GPTModel): from megatron.core.models.gpt import GPTModel as MCoreGPTModel + unwrapped_model = model while not isinstance(unwrapped_model, MCoreGPTModel): unwrapped_model = unwrapped_model.module From 93d7d667bbe5084c978c8194bc21c2a2e60ec8d8 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Mon, 21 Oct 2024 08:47:19 -0700 Subject: [PATCH 16/24] remove unused import Signed-off-by: Piotr Kaminski --- examples/llm/quantization/ptq.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/llm/quantization/ptq.py b/examples/llm/quantization/ptq.py index 0c79fb1bfe7a..97385a68a907 100644 --- a/examples/llm/quantization/ptq.py +++ b/examples/llm/quantization/ptq.py @@ -13,7 +13,6 @@ # limitations under the License. 
import argparse -import sys import torch from tqdm import tqdm From cf99e13ce06f521316c6f16cab017b5be7b0e3e0 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Tue, 22 Oct 2024 01:25:44 -0700 Subject: [PATCH 17/24] use cnn dataset in github ci Signed-off-by: Piotr Kaminski --- .github/workflows/cicd-main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index f5b4735123c5..45e49a597f95 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -4317,7 +4317,7 @@ jobs: SCRIPT: | python tests/collections/llm/test_hf_import.py --hf_model /home/TestData/nlp/megatron_llama/llama-ci-hf --output_path /tmp/nemo2_ckpt - python examples/llm/quantization/ptq.py -nc /tmp/nemo2_ckpt -algo fp8 -out /tmp/nemo2_ptq_engine -calib_ds /home/TestData/nlp/test_quantization/test.json + python examples/llm/quantization/ptq.py -nc /tmp/nemo2_ckpt -algo fp8 -out /tmp/nemo2_ptq_engine AFTER_SCRIPT: | rm -rf /tmp/nemo2_ckpt From c3e6296d28a212812606e48b870755f74c2bef3b Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Tue, 22 Oct 2024 05:44:01 -0700 Subject: [PATCH 18/24] applied code review Signed-off-by: Piotr Kaminski --- .github/workflows/cicd-main.yml | 2 +- nemo/collections/llm/quantization/quantizer.py | 18 ++++++------------ .../llm/quantization => scripts/llm}/ptq.py | 4 ++-- 3 files changed, 9 insertions(+), 15 deletions(-) rename {examples/llm/quantization => scripts/llm}/ptq.py (98%) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 45e49a597f95..71650d350e2e 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -4317,7 +4317,7 @@ jobs: SCRIPT: | python tests/collections/llm/test_hf_import.py --hf_model /home/TestData/nlp/megatron_llama/llama-ci-hf --output_path /tmp/nemo2_ckpt - python examples/llm/quantization/ptq.py -nc /tmp/nemo2_ckpt -algo fp8 -out /tmp/nemo2_ptq_engine + python scripts/llm/ptq.py -nc /tmp/nemo2_ckpt -algo fp8 -out /tmp/nemo2_ptq_engine AFTER_SCRIPT: | rm -rf /tmp/nemo2_ckpt diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index b60740db3748..afe994bdd85b 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -55,7 +55,7 @@ class QuantizationConfig: algorithm: Optional[str] = "fp8" # one of QUANT_CFG_CHOICES keys awq_block_size: int = 128 sq_alpha: float = 0.5 - enable_kv_cache: bool = True + enable_kv_cache: Optional[bool] = None @dataclass @@ -121,11 +121,9 @@ def __init__(self, quantization_config: QuantizationConfig, export_config: Expor self.quantization_config: QuantizationConfig = quantization_config self.export_config: ExportConfig = export_config - self.nemo_checkpoint_path = None algorithm = quantization_config.algorithm dtype = export_config.dtype - # Export and Quantization config sanity checks assert algorithm is None or algorithm in QUANT_CFG_CHOICES, f"Unsupported quantization algorithm: {algorithm}" if export_config is not None: @@ -233,7 +231,7 @@ def loop(model): return loop - def export(self, model: llm.GPTModel, nemo_checkpoint_path: Optional[str] = None) -> None: + def export(self, model: llm.GPTModel) -> None: assert self.export_config is not None, "Export config is not set" # TODO: Add sample generate # TODO: Support NeMo 2: @@ -257,15 +255,11 @@ def export(self, model: llm.GPTModel, nemo_checkpoint_path: Optional[str] = None logging.info(f"Export succeeded, 
model has been exported to {export_dir}. Saving tokenizer if possible...") if dist.get_rank() == 0: - self.nemo_checkpoint_path = nemo_checkpoint_path or self.nemo_checkpoint_path - if self.nemo_checkpoint_path is not None: - tokenizer_src = os.path.join(self.nemo_checkpoint_path, 'context', 'nemo_tokenizer') + try: tokenizer_dst = os.path.join(export_dir, 'tokenizer') - - if os.path.exists(tokenizer_src) and not os.path.exists(tokenizer_dst): - shutil.copytree(tokenizer_src, tokenizer_dst) - else: - logging.info("Could not copy tokenizer from the NeMo checkpoint") + model.tokenizer.tokenizer.save_pretrained(tokenizer_dst) + except Exception as err: + logging.warning("Could not save the tokenizer: " + str(err)) def get_calib_data_iter( diff --git a/examples/llm/quantization/ptq.py b/scripts/llm/ptq.py similarity index 98% rename from examples/llm/quantization/ptq.py rename to scripts/llm/ptq.py index 97385a68a907..623a6199b401 100644 --- a/examples/llm/quantization/ptq.py +++ b/scripts/llm/ptq.py @@ -35,7 +35,7 @@ def get_args(): '-algo', '--algorithm', type=str, - default="no_quant", + default="fp8", choices=["no_quant", "int8", "int8_sq", "fp8", "int4_awq", "w4a8_awq", "int4"], help='TensorRT-Model-Optimizer quantization algorithm', ) @@ -45,7 +45,7 @@ def get_args(): parser.add_argument('--sq_alpha', type=float, default=0.5, help='Smooth-Quant alpha parameter') parser.add_argument('--enable_kv_cache', help='Enables KV-cache quantization', action='store_true') parser.add_argument('--disable_kv_cache', dest='enable_kv_cache', action='store_false') - parser.set_defaults(enable_kv_cache=True) + parser.set_defaults(enable_kv_cache=None) parser.add_argument( '-dt', '--dtype', default="bf16", choices=["16", "bf16"], help='Default precision for non-quantized layers' ) From b8a530c9cedcd0a0c2814304fee68c5e9fb2ba66 Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Wed, 23 Oct 2024 03:19:23 -0700 Subject: [PATCH 19/24] code review changes Signed-off-by: Piotr Kaminski --- .../collections/llm/quantization/quantizer.py | 24 +++++++++---------- nemo/collections/llm/quantization/utils.py | 4 ++-- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index afe994bdd85b..b16b60dd2a31 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -13,7 +13,6 @@ # limitations under the License. import os -import shutil from dataclasses import dataclass from typing import Optional, Union @@ -22,7 +21,6 @@ from datasets import load_dataset from nemo.collections import llm -from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision from nemo.utils import logging from .utils import get_unwrapped_mcore_model @@ -52,7 +50,15 @@ @dataclass class QuantizationConfig: - algorithm: Optional[str] = "fp8" # one of QUANT_CFG_CHOICES keys + """Quantization parameters. + + Available quantization methods are listed in `QUANT_CFG_CHOICES` dictionary above. + Please consult Model Optimizer documentation https://nvidia.github.io/TensorRT-Model-Optimizer/ for details. + + Quantization algorithm can also be conveniently set to None to perform only weights export step + for TensorRT-LLM deployment. This is useful to getting baseline results for a full-precision model. 
+ """ + algorithm: Optional[str] = "fp8" awq_block_size: int = 128 sq_alpha: float = 0.5 enable_kv_cache: Optional[bool] = None @@ -60,6 +66,7 @@ class QuantizationConfig: @dataclass class ExportConfig: + """Inference configuration for the quantized TensorRT-LLM engine""" path: str dtype: Union[str, int] = "bf16" decoder_type: Optional[str] = None @@ -103,16 +110,11 @@ class Quantizer: The output directory produced is intended to be consumed by TensorRT-LLM toolbox for efficient inference. This can be achieved using NeMo inference containers. - - Available quantization methods are listed in `QUANT_CFG_CHOICES` dictionary above. - Please consult Model Optimizer documentation https://nvidia.github.io/TensorRT-Model-Optimizer/ for details. - - Quantization algorithm can also be conveniently set to None to perform only weights export step - for TensorRT-LLM deployment. This is useful to getting baseline results for a full-precision model. """ def __init__(self, quantization_config: QuantizationConfig, export_config: ExportConfig): """Initialize Quantizer with quantization and export configurations.""" + from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision if not HAVE_MODELOPT: raise RuntimeError("nvidia-modelopt is needed to use Quantizer") from HAVE_MODELOPT_ERROR @@ -234,9 +236,7 @@ def loop(model): def export(self, model: llm.GPTModel) -> None: assert self.export_config is not None, "Export config is not set" # TODO: Add sample generate - # TODO: Support NeMo 2: - # if model.cfg.megatron_amp_O2: - # model.model = unwrap_model(model.model, Float16Module) + # TODO: Support megatron_amp_O2 export_dir = self.export_config.path use_nfs_workspace = (model.trainer._fabric.__io__.num_nodes > 1) or ( model.config.pipeline_model_parallel_size > 1 diff --git a/nemo/collections/llm/quantization/utils.py b/nemo/collections/llm/quantization/utils.py index 2573da5586ca..0793a03461f6 100644 --- a/nemo/collections/llm/quantization/utils.py +++ b/nemo/collections/llm/quantization/utils.py @@ -16,14 +16,14 @@ from nemo import lightning as nl from nemo.collections import llm -from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec from nemo.lightning.ckpt_utils import ckpt_to_context_subdir from nemo.utils import logging def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: - """Modify model config for quantization.""" + """Modify model config for TensorRT Model Optimizer""" + from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec model_cfg.transformer_layer_spec = get_gpt_layer_modelopt_spec() if model_cfg.sequence_parallel: logging.warning("Disabling sequence parallelism for quantization...") From 371b4221e600056619b63b80f663f8dde4095a79 Mon Sep 17 00:00:00 2001 From: Laplasjan107 Date: Wed, 23 Oct 2024 10:23:15 +0000 Subject: [PATCH 20/24] Apply isort and black reformatting Signed-off-by: Laplasjan107 --- nemo/collections/llm/quantization/quantizer.py | 2 ++ nemo/collections/llm/quantization/utils.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index b16b60dd2a31..b2544ae4ecf9 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -58,6 +58,7 @@ class QuantizationConfig: Quantization algorithm can also be conveniently set to None to perform only weights export 
step for TensorRT-LLM deployment. This is useful to getting baseline results for a full-precision model. """ + algorithm: Optional[str] = "fp8" awq_block_size: int = 128 sq_alpha: float = 0.5 @@ -67,6 +68,7 @@ class QuantizationConfig: @dataclass class ExportConfig: """Inference configuration for the quantized TensorRT-LLM engine""" + path: str dtype: Union[str, int] = "bf16" decoder_type: Optional[str] = None diff --git a/nemo/collections/llm/quantization/utils.py b/nemo/collections/llm/quantization/utils.py index 0793a03461f6..f5be5812ef05 100644 --- a/nemo/collections/llm/quantization/utils.py +++ b/nemo/collections/llm/quantization/utils.py @@ -23,7 +23,10 @@ def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: """Modify model config for TensorRT Model Optimizer""" - from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec + from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import ( + get_gpt_layer_modelopt_spec, + ) + model_cfg.transformer_layer_spec = get_gpt_layer_modelopt_spec() if model_cfg.sequence_parallel: logging.warning("Disabling sequence parallelism for quantization...") From d7e54c28f559032113d3f358c37f4d28ed914f2a Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Wed, 23 Oct 2024 04:53:29 -0700 Subject: [PATCH 21/24] simplify interface for data iterator Signed-off-by: Piotr Kaminski --- nemo/collections/llm/quantization/__init__.py | 4 +- .../collections/llm/quantization/quantizer.py | 49 ++++++++++++++++++- scripts/llm/ptq.py | 49 ++----------------- 3 files changed, 55 insertions(+), 47 deletions(-) diff --git a/nemo/collections/llm/quantization/__init__.py b/nemo/collections/llm/quantization/__init__.py index d118e9adf454..5984a00b65bc 100644 --- a/nemo/collections/llm/quantization/__init__.py +++ b/nemo/collections/llm/quantization/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .quantizer import ExportConfig, QuantizationConfig, Quantizer, get_calib_data_iter +from .quantizer import ExportConfig, QuantizationConfig, Quantizer, get_calib_data_iter, create_data_iterator_getter from .utils import load_with_modelopt_layer_spec -__all__ = ["Quantizer", "QuantizationConfig", "ExportConfig", "get_calib_data_iter", "load_with_modelopt_layer_spec"] +__all__ = ["Quantizer", "QuantizationConfig", "ExportConfig", "get_calib_data_iter", "load_with_modelopt_layer_spec", "create_data_iterator_getter"] diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index b2544ae4ecf9..0333e35c781a 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -14,6 +14,7 @@ import os from dataclasses import dataclass +from tqdm import tqdm from typing import Optional, Union import torch @@ -64,6 +65,12 @@ class QuantizationConfig: sq_alpha: float = 0.5 enable_kv_cache: Optional[bool] = None + calibration_dataset: str = "cnn_dailymail" + calibration_dataset_size: int = 512 + calibration_batch_size: int = 64 + calibration_seq_len: int = 128 + + @dataclass class ExportConfig: @@ -144,8 +151,25 @@ def _setup(self, model: llm.GPTModel) -> None: def _get_decoder_type(self, config: llm.GPTConfig): return self.export_config.decoder_type or get_modelopt_decoder_type(config) - def quantize(self, model: llm.GPTModel, forward_loop): + def quantize(self, model: llm.GPTModel, forward_loop = None): """Quantize the model and calibrate using given forward loop.""" + if forward_loop is None: + get_dataloader = create_data_iterator_getter( + model, + dataset=self.quantization_config.calibration_dataset, + seq_len=self.quantization_config.calibration_seq_len, + batch_size=self.quantization_config.calibration_batch_size, + calibration_size=self.quantization_config.calibration_dataset_size, + ) + + number_of_batches = self.quantization_config.calibration_dataset_size // self.quantization_config.calibration_batch_size + forward_loop = self.create_megatron_forward_loop( + get_dataloader, + num_batches=number_of_batches, + seq_length=self.quantization_config.calibration_seq_len, + micro_batch_size=self.quantization_config.calibration_batch_size, + ) + algorithm = self.quantization_config.algorithm if algorithm is None: logging.info("Quantization algorithm set to None, returning the non-quantized model") @@ -284,3 +308,26 @@ def get_calib_data_iter( for j in range(len(batch)): batch[j] = batch[j][:max_sequence_length] yield batch + + +def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration_size): + def _iterator(): + CHARACTERS_PER_TOKEN = 4 + + dataloader = get_calib_data_iter( + data=dataset, + max_sequence_length=CHARACTERS_PER_TOKEN * seq_len, + batch_size=batch_size, + calib_size=calibration_size, + ) + for batch in dataloader: + batch = [model.tokenizer.text_to_ids(text)[:seq_len] for text in batch] + batch = [ids + (seq_len - len(ids)) * [model.tokenizer.eos] for ids in batch] + yield torch.tensor(batch, device=model.device) + + def _iterator_getter(): + dataloader = _iterator() + dataloader = [data for data in dataloader] + return iter(tqdm(dataloader)) + + return _iterator_getter diff --git a/scripts/llm/ptq.py b/scripts/llm/ptq.py index 623a6199b401..382744e96b3a 100644 --- a/scripts/llm/ptq.py +++ b/scripts/llm/ptq.py @@ -13,12 +13,8 @@ # limitations under the License. 
import argparse -import torch -from tqdm import tqdm - from nemo.collections.llm import quantization - def get_args(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -70,29 +66,6 @@ def get_args(): return args -def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration_size): - def _iterator(): - CHARACTERS_PER_TOKEN = 4 - - dataloader = quantization.get_calib_data_iter( - data=dataset, - max_sequence_length=CHARACTERS_PER_TOKEN * seq_len, - batch_size=batch_size, - calib_size=calibration_size, - ) - for batch in dataloader: - batch = [model.tokenizer.text_to_ids(text)[:seq_len] for text in batch] - batch = [ids + (seq_len - len(ids)) * [model.tokenizer.eos] for ids in batch] - yield torch.tensor(batch, device=model.device) - - def _iterator_getter(): - dataloader = _iterator() - dataloader = [data for data in dataloader] - return iter(tqdm(dataloader)) - - return _iterator_getter - - def main(): args = get_args() @@ -101,6 +74,10 @@ def main(): awq_block_size=args.awq_block_size, sq_alpha=args.sq_alpha, enable_kv_cache=args.enable_kv_cache, + calibration_dataset=args.calibration_dataset, + calibration_dataset_size=args.calibration_dataset_size, + calibration_batch_size=args.batch_size, + calibration_seq_len=args.seq_len, ) export_config = quantization.ExportConfig( @@ -113,23 +90,7 @@ def main(): quantizer = quantization.Quantizer(quantization_config, export_config) model = quantization.load_with_modelopt_layer_spec(args.nemo_checkpoint, args.calib_tp, args.calib_pp) - - get_dataloader = create_data_iterator_getter( - model, - dataset=args.calibration_dataset, - seq_len=args.seq_len, - batch_size=args.batch_size, - calibration_size=args.calibration_dataset_size, - ) - - forward_loop = quantizer.create_megatron_forward_loop( - get_dataloader, - num_batches=args.calibration_dataset_size // args.batch_size, - seq_length=args.seq_len, - micro_batch_size=args.batch_size, - ) - - model = quantizer.quantize(model, forward_loop) + model = quantizer.quantize(model) quantizer.export(model) From 6fbab589d42bacc25ba0f48a826e51b529006931 Mon Sep 17 00:00:00 2001 From: Laplasjan107 Date: Wed, 23 Oct 2024 11:54:30 +0000 Subject: [PATCH 22/24] Apply isort and black reformatting Signed-off-by: Laplasjan107 --- nemo/collections/llm/quantization/__init__.py | 11 +++++++++-- nemo/collections/llm/quantization/quantizer.py | 9 +++++---- scripts/llm/ptq.py | 1 + 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/nemo/collections/llm/quantization/__init__.py b/nemo/collections/llm/quantization/__init__.py index 5984a00b65bc..c6c690b3e801 100644 --- a/nemo/collections/llm/quantization/__init__.py +++ b/nemo/collections/llm/quantization/__init__.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .quantizer import ExportConfig, QuantizationConfig, Quantizer, get_calib_data_iter, create_data_iterator_getter +from .quantizer import ExportConfig, QuantizationConfig, Quantizer, create_data_iterator_getter, get_calib_data_iter from .utils import load_with_modelopt_layer_spec -__all__ = ["Quantizer", "QuantizationConfig", "ExportConfig", "get_calib_data_iter", "load_with_modelopt_layer_spec", "create_data_iterator_getter"] +__all__ = [ + "Quantizer", + "QuantizationConfig", + "ExportConfig", + "get_calib_data_iter", + "load_with_modelopt_layer_spec", + "create_data_iterator_getter", +] diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index 0333e35c781a..ede83a1b4c0e 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -14,12 +14,12 @@ import os from dataclasses import dataclass -from tqdm import tqdm from typing import Optional, Union import torch import torch.distributed as dist from datasets import load_dataset +from tqdm import tqdm from nemo.collections import llm from nemo.utils import logging @@ -71,7 +71,6 @@ class QuantizationConfig: calibration_seq_len: int = 128 - @dataclass class ExportConfig: """Inference configuration for the quantized TensorRT-LLM engine""" @@ -151,7 +150,7 @@ def _setup(self, model: llm.GPTModel) -> None: def _get_decoder_type(self, config: llm.GPTConfig): return self.export_config.decoder_type or get_modelopt_decoder_type(config) - def quantize(self, model: llm.GPTModel, forward_loop = None): + def quantize(self, model: llm.GPTModel, forward_loop=None): """Quantize the model and calibrate using given forward loop.""" if forward_loop is None: get_dataloader = create_data_iterator_getter( @@ -162,7 +161,9 @@ def quantize(self, model: llm.GPTModel, forward_loop = None): calibration_size=self.quantization_config.calibration_dataset_size, ) - number_of_batches = self.quantization_config.calibration_dataset_size // self.quantization_config.calibration_batch_size + number_of_batches = ( + self.quantization_config.calibration_dataset_size // self.quantization_config.calibration_batch_size + ) forward_loop = self.create_megatron_forward_loop( get_dataloader, num_batches=number_of_batches, diff --git a/scripts/llm/ptq.py b/scripts/llm/ptq.py index 382744e96b3a..0fd2c5682e8a 100644 --- a/scripts/llm/ptq.py +++ b/scripts/llm/ptq.py @@ -15,6 +15,7 @@ import argparse from nemo.collections.llm import quantization + def get_args(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, From 5626ec3a94368b01936b07815c145dac1b35c87d Mon Sep 17 00:00:00 2001 From: Piotr Kaminski Date: Wed, 23 Oct 2024 05:48:32 -0700 Subject: [PATCH 23/24] (partial) PP fix Signed-off-by: Piotr Kaminski --- nemo/collections/llm/quantization/quantizer.py | 1 - nemo/collections/llm/quantization/utils.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index ede83a1b4c0e..15367cb25aba 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -144,7 +144,6 @@ def _setup(self, model: llm.GPTModel) -> None: """Setup model for quantization.""" # TODO: disable activation checkpointing model.config.vocab_size = model.tokenizer.vocab_size - model.config.pipeline_dtype = self.torch_dtype # TODO: for some reason model.pipeline.dtype does not work model.freeze() def 
_get_decoder_type(self, config: llm.GPTConfig): diff --git a/nemo/collections/llm/quantization/utils.py b/nemo/collections/llm/quantization/utils.py index f5be5812ef05..a38e1d67fcad 100644 --- a/nemo/collections/llm/quantization/utils.py +++ b/nemo/collections/llm/quantization/utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import torch from pathlib import Path from nemo import lightning as nl @@ -46,8 +47,9 @@ def load_with_modelopt_layer_spec(nemo_checkpoint_path: str, calib_tp: int = 1, strategy=nl.MegatronStrategy( tensor_model_parallel_size=calib_tp, pipeline_model_parallel_size=calib_pp, + pipeline_dtype=torch.float32 ), - plugins=nl.MegatronMixedPrecision(precision='16-mixed'), + plugins=nl.MegatronMixedPrecision(precision='32', pipeline_dtype=torch.float32), ) fabric = trainer.to_fabric() fabric.launch() From 2925392b2d0c0c649b97cd3bd18e498f8b236fd6 Mon Sep 17 00:00:00 2001 From: Laplasjan107 Date: Wed, 23 Oct 2024 12:49:38 +0000 Subject: [PATCH 24/24] Apply isort and black reformatting Signed-off-by: Laplasjan107 --- nemo/collections/llm/quantization/utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/nemo/collections/llm/quantization/utils.py b/nemo/collections/llm/quantization/utils.py index a38e1d67fcad..86c343ad54ec 100644 --- a/nemo/collections/llm/quantization/utils.py +++ b/nemo/collections/llm/quantization/utils.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch from pathlib import Path +import torch + from nemo import lightning as nl from nemo.collections import llm from nemo.lightning.ckpt_utils import ckpt_to_context_subdir @@ -45,9 +46,7 @@ def load_with_modelopt_layer_spec(nemo_checkpoint_path: str, calib_tp: int = 1, trainer = nl.Trainer( devices=calib_tp * calib_pp, strategy=nl.MegatronStrategy( - tensor_model_parallel_size=calib_tp, - pipeline_model_parallel_size=calib_pp, - pipeline_dtype=torch.float32 + tensor_model_parallel_size=calib_tp, pipeline_model_parallel_size=calib_pp, pipeline_dtype=torch.float32 ), plugins=nl.MegatronMixedPrecision(precision='32', pipeline_dtype=torch.float32), )
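
Usage note (illustrative, not part of the patches above): the pieces introduced by this series compose into a short post-training quantization flow. The sketch below is a minimal example assembled from the dataclasses and functions added in nemo/collections/llm/quantization; the checkpoint path, output path, and parallel sizes are placeholders, and keyword defaults are spelled out only for clarity. It mirrors what scripts/llm/ptq.py wires up from its CLI flags, e.g. the invocation exercised by the updated CI job: python scripts/llm/ptq.py -nc /tmp/nemo2_ckpt -algo fp8 -out /tmp/nemo2_ptq_engine

    # Hypothetical end-to-end PTQ flow using the API added by this series.
    # Paths and TP/PP values are placeholders; adjust to the target setup.
    from nemo.collections.llm import quantization

    quantization_config = quantization.QuantizationConfig(
        algorithm="fp8",                      # one of QUANT_CFG_CHOICES, or None to export without quantizing
        enable_kv_cache=None,                 # None lets the Quantizer decide from the algorithm and decoder type
        calibration_dataset="cnn_dailymail",
        calibration_dataset_size=512,
        calibration_batch_size=64,
        calibration_seq_len=128,
    )
    export_config = quantization.ExportConfig(
        path="/tmp/nemo2_ptq_engine",         # output directory for the TensorRT-LLM checkpoint
        dtype="bf16",                         # precision of non-quantized layers
        inference_tensor_parallel=1,
        inference_pipeline_parallel=1,
    )

    quantizer = quantization.Quantizer(quantization_config, export_config)
    # Load the NeMo 2.0 checkpoint with the ModelOpt layer spec for calibration.
    model = quantization.load_with_modelopt_layer_spec("/tmp/nemo2_ckpt", calib_tp=1, calib_pp=1)
    model = quantizer.quantize(model)         # builds its own calibration forward loop when none is given
    quantizer.export(model)                   # writes the TensorRT-LLM checkpoint and, if possible, the tokenizer

If finer control over calibration is needed, the forward loop can still be built explicitly and passed to quantize(), as the pre-refactor script did. Again a sketch with assumed placeholder values; the dataset argument mirrors the script's --calibration_dataset option, and whether a local file path is accepted there is an assumption based on the CI flag this series removed.

    # Optional: construct the calibration loop by hand instead of relying on the defaults.
    get_dataloader = quantization.create_data_iterator_getter(
        model,
        dataset="cnn_dailymail",              # placeholder; a local dataset file may also work (assumption)
        seq_len=128,
        batch_size=64,
        calibration_size=512,
    )
    forward_loop = quantizer.create_megatron_forward_loop(
        get_dataloader,
        num_batches=512 // 64,                # calibration_size divided by batch size
        seq_length=128,
        micro_batch_size=64,
    )
    model = quantizer.quantize(model, forward_loop)
    quantizer.export(model)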