From 55aa6f9bc364d0e3a66145cbfb3fe610ebed7eb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20Kami=C5=84ski?= <67481570+Laplasjan107@users.noreply.github.com> Date: Fri, 25 Oct 2024 12:37:15 +0200 Subject: [PATCH] PTQ example for NeMo 2.0 (#10642) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * initial commit Signed-off-by: Piotr Kaminski * create Quantizer for NeMo 2.0 Signed-off-by: Piotr Kaminski * refactor Signed-off-by: Piotr Kaminski * Call quantize on an unwrapped mcore model Signed-off-by: Piotr Kaminski * Apply isort and black reformatting Signed-off-by: Laplasjan107 * Add tests, adjust unwrapping Signed-off-by: Piotr Kaminski * Apply isort and black reformatting Signed-off-by: Laplasjan107 * fix export Signed-off-by: Piotr Kaminski * Apply isort and black reformatting Signed-off-by: Laplasjan107 * Apply isort and black reformatting Signed-off-by: artbataev * Fix output_path argument for HF import Signed-off-by: Piotr Kamiński <67481570+Laplasjan107@users.noreply.github.com> * fix fabric ckpt loading Signed-off-by: Piotr Kaminski * Apply isort and black reformatting Signed-off-by: Laplasjan107 * code review suggestions Signed-off-by: Piotr Kaminski * Apply isort and black reformatting Signed-off-by: Laplasjan107 * remove unused import Signed-off-by: Piotr Kaminski * use cnn dataset in github ci Signed-off-by: Piotr Kaminski * applied code review Signed-off-by: Piotr Kaminski * code review changes Signed-off-by: Piotr Kaminski * Apply isort and black reformatting Signed-off-by: Laplasjan107 * simplify interface for data iterator Signed-off-by: Piotr Kaminski * Apply isort and black reformatting Signed-off-by: Laplasjan107 * (partial) PP fix Signed-off-by: Piotr Kaminski * Apply isort and black reformatting Signed-off-by: Laplasjan107 --------- Signed-off-by: Piotr Kaminski Signed-off-by: Laplasjan107 Signed-off-by: Piotr Kamiński <67481570+Laplasjan107@users.noreply.github.com> Signed-off-by: artbataev Co-authored-by: Piotr Kaminski Co-authored-by: Laplasjan107 Co-authored-by: artbataev --- .github/workflows/cicd-main.yml | 16 + nemo/collections/llm/__init__.py | 3 + nemo/collections/llm/gpt/model/__init__.py | 2 + nemo/collections/llm/quantization/__init__.py | 25 ++ .../collections/llm/quantization/quantizer.py | 333 ++++++++++++++++++ nemo/collections/llm/quantization/utils.py | 69 ++++ nemo/export/trt_llm/qnemo/tokenizer_utils.py | 6 + nemo/lightning/fabric/fabric.py | 7 +- scripts/llm/ptq.py | 99 ++++++ tests/collections/llm/test_hf_import.py | 29 ++ 10 files changed, 586 insertions(+), 3 deletions(-) create mode 100644 nemo/collections/llm/quantization/__init__.py create mode 100644 nemo/collections/llm/quantization/quantizer.py create mode 100644 nemo/collections/llm/quantization/utils.py create mode 100644 scripts/llm/ptq.py create mode 100644 tests/collections/llm/test_hf_import.py diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 55a952c21eb6..d5b4d2d8081e 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -4308,6 +4308,21 @@ jobs: SCRIPT: | bash tests/collections/llm/bitexact/mixtral/run.sh + L2_NeMo_2_PTQ_Llama2_FP8: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_PTQ_Llama2_FP8') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + python tests/collections/llm/test_hf_import.py --hf_model /home/TestData/nlp/megatron_llama/llama-ci-hf --output_path /tmp/nemo2_ckpt + + python scripts/llm/ptq.py -nc /tmp/nemo2_ckpt -algo fp8 -out /tmp/nemo2_ptq_engine + + AFTER_SCRIPT: | + rm -rf /tmp/nemo2_ckpt + rm -rf /tmp/nemo2_ptq_engine + Nemo_CICD_Test: needs: - pre-flight @@ -4455,6 +4470,7 @@ jobs: - L2_Speech_Transcription_Canary_Transcribe_Audio_Dir - L2_Megatron_GPT_Reranker - L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact + - L2_NeMo_2_PTQ_Llama2_FP8 if: always() runs-on: ubuntu-latest steps: diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 6dde88079567..6f8070015c07 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -73,6 +73,7 @@ MistralConfig7B, MistralModel, MistralNeMoConfig12B, + MixtralConfig, MixtralConfig8x3B, MixtralConfig8x7B, MixtralConfig8x22B, @@ -104,6 +105,7 @@ gpt_data_step, gpt_forward_step, ) +from nemo.collections.llm.quantization import Quantizer, get_calib_data_iter from nemo.collections.llm.t5.model import T5Config, T5Model, t5_data_step, t5_forward_step __all__ = [ @@ -120,6 +122,7 @@ "MistralConfig7B", "MistralNeMoConfig12B", "MistralModel", + "MixtralConfig", "MixtralConfig8x3B", "MixtralConfig8x7B", "MixtralConfig8x22B", diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index 26b8d67cb53d..a0d5d5a92663 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -56,6 +56,7 @@ ) from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel, MistralNeMoConfig12B from nemo.collections.llm.gpt.model.mixtral import ( + MixtralConfig, MixtralConfig8x3B, MixtralConfig8x7B, MixtralConfig8x22B, @@ -105,6 +106,7 @@ "MixtralConfig8x3B", "MixtralConfig8x7B", "MixtralConfig8x22B", + "MixtralConfig", "MixtralModel", "Starcoder2Config", "Starcoder2Model", diff --git a/nemo/collections/llm/quantization/__init__.py b/nemo/collections/llm/quantization/__init__.py new file mode 100644 index 000000000000..c6c690b3e801 --- /dev/null +++ b/nemo/collections/llm/quantization/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .quantizer import ExportConfig, QuantizationConfig, Quantizer, create_data_iterator_getter, get_calib_data_iter +from .utils import load_with_modelopt_layer_spec + +__all__ = [ + "Quantizer", + "QuantizationConfig", + "ExportConfig", + "get_calib_data_iter", + "load_with_modelopt_layer_spec", + "create_data_iterator_getter", +] diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py new file mode 100644 index 000000000000..15367cb25aba --- /dev/null +++ b/nemo/collections/llm/quantization/quantizer.py @@ -0,0 +1,333 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.distributed as dist +from datasets import load_dataset +from tqdm import tqdm + +from nemo.collections import llm +from nemo.utils import logging + +from .utils import get_unwrapped_mcore_model + +try: + import modelopt.torch.quantization as mtq + from modelopt.torch.export import export_tensorrt_llm_checkpoint + + QUANT_CFG_CHOICES = { + "int8": mtq.INT8_DEFAULT_CFG, + "int8_sq": mtq.INT8_SMOOTHQUANT_CFG, + "fp8": mtq.FP8_DEFAULT_CFG, + "int4_awq": mtq.INT4_AWQ_CFG, + "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG, + "int4": mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, + } + + HAVE_MODELOPT = True + +except (ImportError, ModuleNotFoundError) as e: + HAVE_MODELOPT = False + HAVE_MODELOPT_ERROR = e + + +SUPPORTED_DTYPE = [16, "16", "bf16"] # Default precision for non-quantized layers + + +@dataclass +class QuantizationConfig: + """Quantization parameters. + + Available quantization methods are listed in `QUANT_CFG_CHOICES` dictionary above. + Please consult Model Optimizer documentation https://nvidia.github.io/TensorRT-Model-Optimizer/ for details. + + Quantization algorithm can also be conveniently set to None to perform only weights export step + for TensorRT-LLM deployment. This is useful to getting baseline results for a full-precision model. + """ + + algorithm: Optional[str] = "fp8" + awq_block_size: int = 128 + sq_alpha: float = 0.5 + enable_kv_cache: Optional[bool] = None + + calibration_dataset: str = "cnn_dailymail" + calibration_dataset_size: int = 512 + calibration_batch_size: int = 64 + calibration_seq_len: int = 128 + + +@dataclass +class ExportConfig: + """Inference configuration for the quantized TensorRT-LLM engine""" + + path: str + dtype: Union[str, int] = "bf16" + decoder_type: Optional[str] = None + inference_tensor_parallel: int = 1 + inference_pipeline_parallel: int = 1 + + +def get_modelopt_decoder_type(config: llm.GPTConfig) -> str: + """Infers the modelopt decoder type from GPTConfig class""" + mapping = [ + (llm.Baichuan2Config, "baichuan"), + (llm.ChatGLMConfig, "chatglm"), + (llm.GemmaConfig, "gemma"), + (llm.LlamaConfig, "llama"), + (llm.MistralConfig7B, "llama"), + (llm.MixtralConfig, "llama"), + (llm.NemotronConfig, "gptnext"), + (llm.Qwen2Config, "qwen"), + # TODO: (llm.StarcoderConfig, ""), + (llm.Starcoder2Config, "gptnext"), + ] + + for config_class, decoder_type in mapping: + if isinstance(config, config_class): + return decoder_type + + logging.warning("Could not directly infer the decoder type") + # TODO: Add a reasonable behavior for GPTConfig (for instance based on position_embedding_type) + return "llama" + + +class Quantizer: + """Post-training quantization (PTQ) and TRT-LLM export of NeMo 2.0 checkpoints. + + PTQ converts selected model layers to low-precision format (e.g., INT4, FP8) for efficient serving. + The process consist of several steps: + + 1. Loading a Nemo model from disk using appropriate parallelism strategy + 2. Calibrating the model to obtain appropriate algorithm-specific scaling factors + 3. Producing output directory + + The output directory produced is intended to be consumed by TensorRT-LLM toolbox + for efficient inference. This can be achieved using NeMo inference containers. + """ + + def __init__(self, quantization_config: QuantizationConfig, export_config: ExportConfig): + """Initialize Quantizer with quantization and export configurations.""" + from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision + + if not HAVE_MODELOPT: + raise RuntimeError("nvidia-modelopt is needed to use Quantizer") from HAVE_MODELOPT_ERROR + if not torch.cuda.is_available(): + raise EnvironmentError("GPU is required for the quantization.") + + self.quantization_config: QuantizationConfig = quantization_config + self.export_config: ExportConfig = export_config + + algorithm = quantization_config.algorithm + dtype = export_config.dtype + # Export and Quantization config sanity checks + assert algorithm is None or algorithm in QUANT_CFG_CHOICES, f"Unsupported quantization algorithm: {algorithm}" + if export_config is not None: + assert dtype in SUPPORTED_DTYPE, f"Unsupported export dtype: {dtype}" + self.torch_dtype = torch_dtype_from_precision(dtype) + + def _setup(self, model: llm.GPTModel) -> None: + """Setup model for quantization.""" + # TODO: disable activation checkpointing + model.config.vocab_size = model.tokenizer.vocab_size + model.freeze() + + def _get_decoder_type(self, config: llm.GPTConfig): + return self.export_config.decoder_type or get_modelopt_decoder_type(config) + + def quantize(self, model: llm.GPTModel, forward_loop=None): + """Quantize the model and calibrate using given forward loop.""" + if forward_loop is None: + get_dataloader = create_data_iterator_getter( + model, + dataset=self.quantization_config.calibration_dataset, + seq_len=self.quantization_config.calibration_seq_len, + batch_size=self.quantization_config.calibration_batch_size, + calibration_size=self.quantization_config.calibration_dataset_size, + ) + + number_of_batches = ( + self.quantization_config.calibration_dataset_size // self.quantization_config.calibration_batch_size + ) + forward_loop = self.create_megatron_forward_loop( + get_dataloader, + num_batches=number_of_batches, + seq_length=self.quantization_config.calibration_seq_len, + micro_batch_size=self.quantization_config.calibration_batch_size, + ) + + algorithm = self.quantization_config.algorithm + if algorithm is None: + logging.info("Quantization algorithm set to None, returning the non-quantized model") + return model + + logging.info(f"Quantizing model to {algorithm}...") + + self._setup(model) + unwrapped_model = get_unwrapped_mcore_model(model) + decoder_type = self._get_decoder_type(unwrapped_model.config) + quant_cfg = QUANT_CFG_CHOICES[algorithm] + if "awq" in algorithm: + weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] + if isinstance(weight_quantizer, list): + weight_quantizer = weight_quantizer[0] + weight_quantizer["block_sizes"][-1] = self.quantization_config.awq_block_size + + # Always turn on FP8 kv cache to save memory footprint. + # For int8_sq, we use int8 kv cache. + # TODO: Investigate why enabling FP8 kv cache will cause accuracy regressions for Nemotron. + enable_quant_kv_cache = self.quantization_config.enable_kv_cache + if enable_quant_kv_cache is None: + enable_quant_kv_cache = "int8" not in algorithm and decoder_type != "gptnext" + logging.info(f'{"Enabled" if enable_quant_kv_cache else "Disabled"} KV cache quantization') + quant_cfg["quant_cfg"]["*output_quantizer"] = { + "num_bits": 8 if algorithm == "int8_sq" else (4, 3), + "axis": None, + "enable": enable_quant_kv_cache, + } + if algorithm == "int8_sq": + sq_alpha = self.quantization_config.sq_alpha + logging.info(f"Using int8_sq alpha = {sq_alpha}") + quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": sq_alpha} + + unwrapped_model = mtq.quantize(unwrapped_model, quant_cfg, forward_loop) + + if decoder_type == "gptnext": + # We found squared_relu may have an under-calibration problem. + # Clamp the scaling_factor with a min threshold to avoid under-calibration. + match algorithm: + case "fp8": + maxbound = 448 + case "int8_sq": + maxbound = 127 + case _: + maxbound = 0 + + unwrapped_model = mtq.postprocess_amax( + unwrapped_model, "*input_quantizer", lambda amax: torch.clamp(amax, min=0.01 * maxbound) + ) + + if dist.get_rank() == 0: + mtq.print_quant_summary(unwrapped_model) + + return model + + def create_megatron_forward_loop( + self, get_dataloader, num_batches, seq_length=None, micro_batch_size=None, decoder_seq_length=None + ): + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + forward_backward_func = get_forward_backward_func() + + def forward_step_func(data_iterator, model): + data = next(data_iterator) + batch_len, seq_len = data.shape + position_ids = torch.arange(seq_len, device=data.device).expand((batch_len, seq_len)) + output_tensor = model(data, position_ids, None) + + def _mock_loss_function(tensor): + return 0, {} + + return output_tensor, _mock_loss_function + + def loop(model): + dataloader = get_dataloader() + forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=dataloader, + model=model, + num_microbatches=num_batches, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + decoder_seq_length=decoder_seq_length, + forward_only=True, + ) + + return loop + + def export(self, model: llm.GPTModel) -> None: + assert self.export_config is not None, "Export config is not set" + # TODO: Add sample generate + # TODO: Support megatron_amp_O2 + export_dir = self.export_config.path + use_nfs_workspace = (model.trainer._fabric.__io__.num_nodes > 1) or ( + model.config.pipeline_model_parallel_size > 1 + ) + export_tensorrt_llm_checkpoint( + model=get_unwrapped_mcore_model(model), + decoder_type=self._get_decoder_type(model.config), + dtype=self.torch_dtype, + export_dir=export_dir, + inference_tensor_parallel=self.export_config.inference_tensor_parallel, + inference_pipeline_parallel=self.export_config.inference_pipeline_parallel, + use_nfs_workspace=use_nfs_workspace, + ) + + dist.barrier() # Wait until all ranks complete export_model_config step + logging.info(f"Export succeeded, model has been exported to {export_dir}. Saving tokenizer if possible...") + + if dist.get_rank() == 0: + try: + tokenizer_dst = os.path.join(export_dir, 'tokenizer') + model.tokenizer.tokenizer.save_pretrained(tokenizer_dst) + except Exception as err: + logging.warning("Could not save the tokenizer: " + str(err)) + + +def get_calib_data_iter( + data: str = "cnn_dailymail", batch_size: int = 64, calib_size: int = 512, max_sequence_length: int = 512 +): + """Creates a sample data iterator for calibration""" + if data == "wikitext": + dataset = load_dataset("wikitext", "wikitext-103-v1", split="train") + text_column = "text" + elif data == "cnn_dailymail": + dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train") + text_column = "article" + else: + # Assume a local JSON dataset with a column named "text" + dataset = load_dataset("json", data_files=data, split="train") + text_column = "text" + calib_size = max(min(len(dataset), calib_size), batch_size) + for i in range(calib_size // batch_size): + batch = dataset[i * batch_size : (i + 1) * batch_size][text_column] + for j in range(len(batch)): + batch[j] = batch[j][:max_sequence_length] + yield batch + + +def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration_size): + def _iterator(): + CHARACTERS_PER_TOKEN = 4 + + dataloader = get_calib_data_iter( + data=dataset, + max_sequence_length=CHARACTERS_PER_TOKEN * seq_len, + batch_size=batch_size, + calib_size=calibration_size, + ) + for batch in dataloader: + batch = [model.tokenizer.text_to_ids(text)[:seq_len] for text in batch] + batch = [ids + (seq_len - len(ids)) * [model.tokenizer.eos] for ids in batch] + yield torch.tensor(batch, device=model.device) + + def _iterator_getter(): + dataloader = _iterator() + dataloader = [data for data in dataloader] + return iter(tqdm(dataloader)) + + return _iterator_getter diff --git a/nemo/collections/llm/quantization/utils.py b/nemo/collections/llm/quantization/utils.py new file mode 100644 index 000000000000..86c343ad54ec --- /dev/null +++ b/nemo/collections/llm/quantization/utils.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import torch + +from nemo import lightning as nl +from nemo.collections import llm +from nemo.lightning.ckpt_utils import ckpt_to_context_subdir +from nemo.utils import logging + + +def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: + """Modify model config for TensorRT Model Optimizer""" + + from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import ( + get_gpt_layer_modelopt_spec, + ) + + model_cfg.transformer_layer_spec = get_gpt_layer_modelopt_spec() + if model_cfg.sequence_parallel: + logging.warning("Disabling sequence parallelism for quantization...") + model_cfg.sequence_parallel = False + # Only custom ModelOpt spec is supported for Quantization: this custom spec is largely based on local Megatron-LM + # layer definitions to avoid Transformer Engine implementations that are currently not supported. + # This layer spec also requires RoPE fusion to be disabled for tensor view operations in attention + # layer implementation from megatron/core/transformer/dot_product_attention.py to be functional. + model_cfg.name = "modelopt" + model_cfg.apply_rope_fusion = False + return model_cfg + + +def load_with_modelopt_layer_spec(nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1) -> llm.GPTModel: + trainer = nl.Trainer( + devices=calib_tp * calib_pp, + strategy=nl.MegatronStrategy( + tensor_model_parallel_size=calib_tp, pipeline_model_parallel_size=calib_pp, pipeline_dtype=torch.float32 + ), + plugins=nl.MegatronMixedPrecision(precision='32', pipeline_dtype=torch.float32), + ) + fabric = trainer.to_fabric() + fabric.launch() + + model_path = Path(nemo_checkpoint_path) + model = nl.io.load_context(ckpt_to_context_subdir(model_path)).model + model.config = quantizable_model_config(model.config) + return fabric.load_model(nemo_checkpoint_path, model=model) + + +def get_unwrapped_mcore_model(model: llm.GPTModel): + from megatron.core.models.gpt import GPTModel as MCoreGPTModel + + unwrapped_model = model + while not isinstance(unwrapped_model, MCoreGPTModel): + unwrapped_model = unwrapped_model.module + + return unwrapped_model diff --git a/nemo/export/trt_llm/qnemo/tokenizer_utils.py b/nemo/export/trt_llm/qnemo/tokenizer_utils.py index c3dd5c2befc9..36efa9259f9d 100644 --- a/nemo/export/trt_llm/qnemo/tokenizer_utils.py +++ b/nemo/export/trt_llm/qnemo/tokenizer_utils.py @@ -23,11 +23,17 @@ # from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer TOKENIZER_CONFIG_FILE = "tokenizer_config.yaml" +TOKENIZER_DIR = "tokenizer" def get_nmt_tokenizer(nemo_checkpoint_path: str): """Build tokenizer from Nemo tokenizer config.""" + tokenizer_dir = os.path.join(nemo_checkpoint_path, TOKENIZER_DIR) + if os.path.exists(tokenizer_dir): + print(f"Initializing tokenizer from {TOKENIZER_DIR} directory") + return AutoTokenizer.from_pretrained(tokenizer_dir) + print(f"Initializing tokenizer from {TOKENIZER_CONFIG_FILE}") tokenizer_cfg = OmegaConf.load(os.path.join(nemo_checkpoint_path, TOKENIZER_CONFIG_FILE)) diff --git a/nemo/lightning/fabric/fabric.py b/nemo/lightning/fabric/fabric.py index 55431d940193..ddad49f7d211 100644 --- a/nemo/lightning/fabric/fabric.py +++ b/nemo/lightning/fabric/fabric.py @@ -6,9 +6,9 @@ import lightning_fabric as lb import pytorch_lightning as pl from torch import nn - from typing_extensions import Self, override +from nemo.lightning.ckpt_utils import ckpt_to_context_subdir, ckpt_to_weights_subdir from nemo.lightning.io.mixin import IOMixin, serialization, track_io if TYPE_CHECKING: @@ -63,12 +63,13 @@ def load_model( from nemo.lightning.io import load_context + path = Path(path) if model is None: - context = load_context(path) + context = load_context(ckpt_to_context_subdir(path)) model = context.model dist_model = self.setup_module(model) - self.load(path, {"state_dict": dist_model}) + self.load(ckpt_to_weights_subdir(path), {"state_dict": dist_model}) return dist_model diff --git a/scripts/llm/ptq.py b/scripts/llm/ptq.py new file mode 100644 index 000000000000..0fd2c5682e8a --- /dev/null +++ b/scripts/llm/ptq.py @@ -0,0 +1,99 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from nemo.collections.llm import quantization + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description="NeMo PTQ argument parser", + ) + parser.add_argument("-nc", "--nemo_checkpoint", type=str, help="Source NeMo 2.0 checkpoint") + parser.add_argument("--decoder_type", type=str, help="Decoder type for TensorRT-Model-Optimizer") + parser.add_argument("-ctp", "--calib_tp", type=int, default=1) + parser.add_argument("-cpp", "--calib_pp", type=int, default=1) + parser.add_argument("-tps", "--tensor_parallelism_size", type=int, default=1) + parser.add_argument("-pps", "--pipeline_parallelism_size", type=int, default=1) + parser.add_argument('-out', '--output_path', type=str, help='Path for the exported engine') + parser.add_argument( + '-algo', + '--algorithm', + type=str, + default="fp8", + choices=["no_quant", "int8", "int8_sq", "fp8", "int4_awq", "w4a8_awq", "int4"], + help='TensorRT-Model-Optimizer quantization algorithm', + ) + parser.add_argument( + '-awq_bs', '--awq_block_size', type=int, default=128, help='Block size for AWQ quantization algorithms' + ) + parser.add_argument('--sq_alpha', type=float, default=0.5, help='Smooth-Quant alpha parameter') + parser.add_argument('--enable_kv_cache', help='Enables KV-cache quantization', action='store_true') + parser.add_argument('--disable_kv_cache', dest='enable_kv_cache', action='store_false') + parser.set_defaults(enable_kv_cache=None) + parser.add_argument( + '-dt', '--dtype', default="bf16", choices=["16", "bf16"], help='Default precision for non-quantized layers' + ) + parser.add_argument('-bs', '--batch_size', default=64, type=int, help='Calibration batch size') + parser.add_argument('-sl', '--seq_len', default=128, type=int, help='Length of the tokenized text') + parser.add_argument( + '-calib_size', '--calibration_dataset_size', default=512, type=int, help='Size of calibration dataset' + ) + parser.add_argument( + '-calib_ds', + '--calibration_dataset', + default="cnn_dailymail", + type=str, + help='Calibration dataset to be used. Should be \"wikitext\", \"cnn_dailymail\" or path to a local .json file', + ) + + args = parser.parse_args() + if args.output_path is None: + args.output_path = ( + f"./qnemo_{args.algorithm}_tp{args.tensor_parallelism_size}_pp{args.pipeline_parallelism_size}" + ) + return args + + +def main(): + args = get_args() + + quantization_config = quantization.QuantizationConfig( + algorithm=None if args.algorithm == "no_quant" else args.algorithm, + awq_block_size=args.awq_block_size, + sq_alpha=args.sq_alpha, + enable_kv_cache=args.enable_kv_cache, + calibration_dataset=args.calibration_dataset, + calibration_dataset_size=args.calibration_dataset_size, + calibration_batch_size=args.batch_size, + calibration_seq_len=args.seq_len, + ) + + export_config = quantization.ExportConfig( + path=args.output_path, + decoder_type=args.decoder_type, + inference_tensor_parallel=args.tensor_parallelism_size, + inference_pipeline_parallel=args.pipeline_parallelism_size, + dtype=args.dtype, + ) + + quantizer = quantization.Quantizer(quantization_config, export_config) + model = quantization.load_with_modelopt_layer_spec(args.nemo_checkpoint, args.calib_tp, args.calib_pp) + model = quantizer.quantize(model) + quantizer.export(model) + + +if __name__ == '__main__': + main() diff --git a/tests/collections/llm/test_hf_import.py b/tests/collections/llm/test_hf_import.py new file mode 100644 index 000000000000..53232eb02bb2 --- /dev/null +++ b/tests/collections/llm/test_hf_import.py @@ -0,0 +1,29 @@ +import argparse +from pathlib import Path + +from nemo import lightning as nl +from nemo.collections import llm + + +def get_args(): + parser = argparse.ArgumentParser(description='Test Llama2 7B model model conversion from HF') + parser.add_argument('--hf_model', type=str, help="Original HF model") + parser.add_argument('--output_path', type=str, help="NeMo 2.0 export path") + + return parser.parse_args() + + +if __name__ == '__main__': + args = get_args() + + model = llm.LlamaModel(config=llm.Llama2Config7B) + nemo2_path = llm.import_ckpt(model, "hf://" + args.hf_model, output_path=Path(args.output_path)) + + trainer = nl.Trainer( + devices=1, + strategy=nl.MegatronStrategy(tensor_model_parallel_size=1), + plugins=nl.MegatronMixedPrecision(precision='fp16'), + ) + fabric = trainer.to_fabric() + trainer.strategy.setup_environment() + fabric.load_model(nemo2_path)