Add sample generate to PTQ for NeMo 2.0 (#11339)
* Initial commit

Signed-off-by: Piotr Kaminski <[email protected]>

* Remove leftover print

Signed-off-by: Piotr Kaminski <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Laplasjan107 <[email protected]>

* Fix docs and type annotations

Signed-off-by: Piotr Kaminski <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Laplasjan107 <[email protected]>

* Applied code review suggestions

Signed-off-by: Piotr Kaminski <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Laplasjan107 <[email protected]>

* Fix _get_decoder_type parameter

Signed-off-by: Piotr Kamiński <[email protected]>

---------

Signed-off-by: Piotr Kaminski <[email protected]>
Signed-off-by: Laplasjan107 <[email protected]>
Signed-off-by: Piotr Kamiński <[email protected]>
Co-authored-by: Piotr Kaminski <[email protected]>
Co-authored-by: Laplasjan107 <[email protected]>
3 people authored Nov 25, 2024
1 parent 42d164e commit 8f779ba
Showing 3 changed files with 103 additions and 51 deletions.
113 changes: 64 additions & 49 deletions nemo/collections/llm/quantization/quantizer.py
@@ -24,10 +24,12 @@
from tqdm import tqdm

from nemo.collections import llm
from nemo.lightning.ckpt_utils import CONTEXT_PATH
from nemo.collections.llm.inference import MCoreTokenizerWrappper, generate
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.lightning.megatron_parallel import MegatronParallel
from nemo.utils import logging

from .utils import get_unwrapped_mcore_model
from .utils import get_modelopt_decoder_type, get_unwrapped_mcore_model

try:
import modelopt.torch.quantization as mtq
@@ -83,35 +85,12 @@ class ExportConfig:
decoder_type: Optional[str] = None
inference_tensor_parallel: int = 1
inference_pipeline_parallel: int = 1
generate_sample: bool = False

def __post_init__(self):
self.path = Path(self.path)


def get_modelopt_decoder_type(config: llm.GPTConfig) -> str:
"""Infers the modelopt decoder type from GPTConfig class."""
mapping = [
(llm.Baichuan2Config, "baichuan"),
(llm.ChatGLMConfig, "chatglm"),
(llm.GemmaConfig, "gemma"),
(llm.LlamaConfig, "llama"),
(llm.MistralConfig7B, "llama"),
(llm.MixtralConfig, "llama"),
(llm.NemotronConfig, "gptnext"),
(llm.Qwen2Config, "qwen"),
# TODO: (llm.StarcoderConfig, ""),
(llm.Starcoder2Config, "gptnext"),
]

for config_class, decoder_type in mapping:
if isinstance(config, config_class):
return decoder_type

logging.warning("Could not directly infer the decoder type")
# TODO: Add a reasonable behavior for GPTConfig (for instance based on position_embedding_type)
return "llama"


class Quantizer:
"""Post-training quantization (PTQ) and TensorRT-LLM export of NeMo 2.0 checkpoints.
@@ -146,16 +125,37 @@ def __init__(self, quantization_config: QuantizationConfig, export_config: Expor
assert dtype in SUPPORTED_DTYPE, f"Unsupported export dtype: {dtype}"
self.torch_dtype = torch_dtype_from_precision(dtype)

def _setup(self, model: llm.GPTModel) -> None:
@staticmethod
def _setup(model: MegatronParallel) -> None:
"""Setup model for quantization."""
# TODO: disable activation checkpointing
model.config.vocab_size = model.tokenizer.vocab_size
model.freeze()

def _get_decoder_type(self, config: llm.GPTConfig):
return self.export_config.decoder_type or get_modelopt_decoder_type(config)
def _get_decoder_type(self, model: MegatronParallel):
if self.export_config.decoder_type is not None:
return self.export_config.decoder_type
unwrapped_model = model
while not isinstance(unwrapped_model, llm.GPTModel):
unwrapped_model = unwrapped_model.module

return get_modelopt_decoder_type(unwrapped_model)

@staticmethod
def _generate_sample(model: MegatronParallel):
prompts = ["Born in north-east France, Soyer trained as a", "Born in California, Soyer trained as a"]

mcore_tokenizer = MCoreTokenizerWrappper(model.tokenizer)
mcore_inference = model.get_inference_wrapper(
params_dtype=torch.bfloat16, inference_batch_times_seqlen_threshold=30
)

generated = [r.generated_text for r in generate(mcore_inference, mcore_tokenizer, prompts)]
outputs = [prompt + generation for prompt, generation in zip(prompts, generated)]

logging.info(f'Sample generation after PTQ (with prompts): {outputs}')

def quantize(self, model: llm.GPTModel, forward_loop=None):
def quantize(self, model: MegatronParallel, forward_loop=None):
"""Quantize the model and calibrate using given forward loop."""
if forward_loop is None:
get_dataloader = create_data_iterator_getter(
@@ -185,7 +185,7 @@ def quantize(self, model: llm.GPTModel, forward_loop=None):

self._setup(model)
unwrapped_model = get_unwrapped_mcore_model(model)
decoder_type = self._get_decoder_type(unwrapped_model.config)
decoder_type = self._get_decoder_type(model)
quant_cfg = QUANT_CFG_CHOICES[algorithm]
if "awq" in algorithm:
weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
@@ -230,6 +230,10 @@ def quantize(self, model: llm.GPTModel, forward_loop=None):
if dist.get_rank() == 0:
mtq.print_quant_summary(unwrapped_model)

if self.export_config.generate_sample:
logging.info("Generating a sample output after model quantization.")
self._generate_sample(model)

return model

def create_megatron_forward_loop(
@@ -266,36 +270,48 @@ def loop(model):

return loop

def export(self, model: llm.GPTModel, model_dir: str) -> None:
@staticmethod
def _validate_quantized_checkpoint(checkpoint_dir: Path, tensor_parallelism_size: int) -> bool:
"""Basic validation of the model structure."""

saved_config = (checkpoint_dir / 'config.json').exists()
saved_weights = True
for i in range(tensor_parallelism_size):
saved_weights &= (checkpoint_dir / f'rank{i}.safetensors').exists()

export_successful = saved_config and saved_weights
if not export_successful:
logging.error("Failed to export the quantized model.")
return export_successful

def export(self, model: MegatronParallel, model_dir: str) -> None:
"""Export model to a TensorRT-LLM checkpoint."""
assert self.export_config is not None, "Export config is not set"
# TODO: Add sample generate
# TODO: Support megatron_amp_O2
export_dir = self.export_config.path
inference_tp = self.export_config.inference_tensor_parallel
inference_pp = self.export_config.inference_pipeline_parallel

use_nfs_workspace = model.config.pipeline_model_parallel_size > 1
export_tensorrt_llm_checkpoint(
model=get_unwrapped_mcore_model(model),
decoder_type=self._get_decoder_type(model.config),
decoder_type=self._get_decoder_type(model),
dtype=self.torch_dtype,
export_dir=export_dir,
inference_tensor_parallel=self.export_config.inference_tensor_parallel,
inference_pipeline_parallel=self.export_config.inference_pipeline_parallel,
inference_tensor_parallel=inference_tp,
inference_pipeline_parallel=inference_pp,
use_nfs_workspace=use_nfs_workspace,
)
dist.barrier()

# Save the model context in order to restore its tokenizer later. The destination
# path is "nemo_context" as this name is used in nemo.export to setup tokenizer.
if dist.get_rank() == 0:
assert self._validate_quantized_checkpoint(export_dir, inference_tp)
shutil.copytree(
os.path.join(model_dir, CONTEXT_PATH),
ckpt_to_context_subdir(model_dir),
os.path.join(export_dir, "nemo_context"),
dirs_exist_ok=True,
)
logging.info("Model context saved.")

logging.info(f"Export succeeded, model has been exported to {export_dir}.")
logging.info(f"Export succeeded, model has been exported to {export_dir}.")


def get_calib_data_iter(
@@ -323,7 +339,7 @@ def get_calib_data_iter(
def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration_size):
"""Create a function that provides iterator over a given dataset."""

def _iterator():
def _get_iterator():
CHARACTERS_PER_TOKEN = 4

dataloader = get_calib_data_iter(
@@ -332,14 +348,13 @@ def _iterator():
batch_size=batch_size,
calib_size=calibration_size,
)

data = []
for batch in dataloader:
batch = [model.tokenizer.text_to_ids(text)[:seq_len] for text in batch]
batch = [ids + (seq_len - len(ids)) * [model.tokenizer.eos] for ids in batch]
yield torch.tensor(batch, device=model.device)
data.append(torch.tensor(batch, device=model.device))

def _iterator_getter():
dataloader = _iterator()
dataloader = [data for data in dataloader]
return iter(tqdm(dataloader))
return iter(tqdm(data))

return _iterator_getter
return _get_iterator
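
For orientation, the following is a minimal usage sketch of the quantizer API touched in this file, assembled only from names visible in this commit (QuantizationConfig, ExportConfig, Quantizer, load_with_modelopt_layer_spec). The import path, field defaults, and all file paths are assumptions for illustration, not code taken from this commit.

# Hedged sketch: PTQ with the new generate_sample option. Assumes the
# `quantization` package re-exports the classes defined in the changed files.
from nemo.collections.llm import quantization

quantization_config = quantization.QuantizationConfig()  # field names/defaults not shown in this diff

export_config = quantization.ExportConfig(
    path="/results/model-trtllm-ckpt",  # hypothetical export directory
    generate_sample=True,               # new option introduced by this commit (defaults to False)
)

# Defined in quantization/utils.py; calib_tp/calib_pp default to 1, inference_only to True.
model = quantization.load_with_modelopt_layer_spec("/checkpoints/model-nemo2")  # hypothetical checkpoint

quantizer = quantization.Quantizer(quantization_config, export_config)
model = quantizer.quantize(model)                    # logs a sample completion after PTQ when generate_sample=True
quantizer.export(model, "/checkpoints/model-nemo2")  # also copies the NeMo context into "<export_dir>/nemo_context"
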
32 changes: 30 additions & 2 deletions nemo/collections/llm/quantization/utils.py
@@ -23,8 +23,33 @@
from nemo.utils import logging


def get_modelopt_decoder_type(model: llm.GPTModel) -> str:
"""Infers the modelopt decoder type from GPTModel subclass."""
mapping = [
(llm.Baichuan2Model, "baichuan"),
(llm.ChatGLMModel, "chatglm"),
(llm.Gemma2Model, "gemma2"),
(llm.GemmaModel, "gemma"),
(llm.LlamaModel, "llama"),
(llm.MistralModel, "llama"),
(llm.MixtralModel, "llama"),
(llm.NemotronModel, "gptnext"),
(llm.Qwen2Model, "qwen"),
(llm.StarcoderModel, "gptnext"),
(llm.Starcoder2Model, "gptnext"),
(llm.Phi3Model, "phi3"),
]

for config_class, decoder_type in mapping:
if isinstance(model, config_class):
return decoder_type

logging.warning("Could not infer the decoder type")
return None


def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig:
"""Modify model config for TensorRT Model Optimizer"""
"""Modify model config for TensorRT-Model-Optimizer quantization"""

from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import (
get_gpt_layer_modelopt_spec,
@@ -46,7 +71,9 @@ def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig:
def load_with_modelopt_layer_spec(
nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1, inference_only: bool = True
):
# TODO: setting ddp="pytorch" with manually deleting model.optim is a hackish way to disable DDP initialization. Needs a systematic solution.
"""Loads a model from a NeMo 2.0 checkpoint using modelopt layer spec."""
# TODO: setting ddp="pytorch" and deleting model.optim is a hackish way to disable DDP initialization.
# Needs a systematic solution.
if inference_only:
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=calib_tp,
@@ -81,6 +108,7 @@ def load_with_modelopt_layer_spec(


def get_unwrapped_mcore_model(model):
"""Unwraps NeMo 2.0 to base MCore model."""
from megatron.core.models.gpt import GPTModel as MCoreGPTModel

unwrapped_model = model
9 changes: 9 additions & 0 deletions scripts/llm/ptq.py
@@ -17,6 +17,8 @@


def get_args():
"""Parses PTQ arguments"""

parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description="NeMo PTQ argument parser",
@@ -58,6 +60,10 @@ def get_args():
type=str,
help='Calibration dataset to be used. Should be \"wikitext\", \"cnn_dailymail\" or path to a local .json file',
)
parser.add_argument(
'--generate_sample', help='Generate sample model output after performing PTQ', action='store_true'
)
parser.set_defaults(generate_sample=False)

args = parser.parse_args()
if args.output_path is None:
@@ -68,6 +74,8 @@


def main():
"""Example NeMo 2.0 Post Training Quantization workflow"""

args = get_args()

quantization_config = quantization.QuantizationConfig(
@@ -87,6 +95,7 @@ def main():
inference_tensor_parallel=args.tensor_parallelism_size,
inference_pipeline_parallel=args.pipeline_parallelism_size,
dtype=args.dtype,
generate_sample=args.generate_sample,
)

quantizer = quantization.Quantizer(quantization_config, export_config)
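
With this change, passing the new --generate_sample flag to scripts/llm/ptq.py (for example `python scripts/llm/ptq.py … --generate_sample`, with the remaining required arguments elided here) makes the script log a short sample completion from the quantized model before export, as a quick sanity check that PTQ has not broken generation.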
