Add sample generate to PTQ for NeMo 2.0 #11339

Merged · 8 commits · Nov 25, 2024
113 changes: 64 additions & 49 deletions nemo/collections/llm/quantization/quantizer.py
@@ -24,10 +24,12 @@
from tqdm import tqdm

from nemo.collections import llm
from nemo.lightning.ckpt_utils import CONTEXT_PATH
from nemo.collections.llm.inference import MCoreTokenizerWrappper, generate
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.lightning.megatron_parallel import MegatronParallel
from nemo.utils import logging

from .utils import get_unwrapped_mcore_model
from .utils import get_modelopt_decoder_type, get_unwrapped_mcore_model

try:
import modelopt.torch.quantization as mtq
@@ -83,35 +85,12 @@ class ExportConfig:
decoder_type: Optional[str] = None
inference_tensor_parallel: int = 1
inference_pipeline_parallel: int = 1
generate_sample: bool = False

def __post_init__(self):
self.path = Path(self.path)


def get_modelopt_decoder_type(config: llm.GPTConfig) -> str:
"""Infers the modelopt decoder type from GPTConfig class."""
mapping = [
(llm.Baichuan2Config, "baichuan"),
(llm.ChatGLMConfig, "chatglm"),
(llm.GemmaConfig, "gemma"),
(llm.LlamaConfig, "llama"),
(llm.MistralConfig7B, "llama"),
(llm.MixtralConfig, "llama"),
(llm.NemotronConfig, "gptnext"),
(llm.Qwen2Config, "qwen"),
# TODO: (llm.StarcoderConfig, ""),
(llm.Starcoder2Config, "gptnext"),
]

for config_class, decoder_type in mapping:
if isinstance(config, config_class):
return decoder_type

logging.warning("Could not directly infer the decoder type")
# TODO: Add a reasonable behavior for GPTConfig (for instance based on position_embedding_type)
return "llama"


class Quantizer:
"""Post-training quantization (PTQ) and TensorRT-LLM export of NeMo 2.0 checkpoints.

@@ -146,16 +125,37 @@ def __init__(self, quantization_config: QuantizationConfig, export_config: Expor
assert dtype in SUPPORTED_DTYPE, f"Unsupported export dtype: {dtype}"
self.torch_dtype = torch_dtype_from_precision(dtype)

def _setup(self, model: llm.GPTModel) -> None:
@staticmethod
def _setup(model: MegatronParallel) -> None:
"""Setup model for quantization."""
# TODO: disable activation checkpointing
model.config.vocab_size = model.tokenizer.vocab_size
model.freeze()

def _get_decoder_type(self, config: llm.GPTConfig):
return self.export_config.decoder_type or get_modelopt_decoder_type(config)
def _get_decoder_type(self, model: MegatronParallel):
if self.export_config.decoder_type is not None:
return self.export_config.decoder_type
unwrapped_model = model
while not isinstance(unwrapped_model, llm.GPTModel):
unwrapped_model = unwrapped_model.module

return get_modelopt_decoder_type(unwrapped_model)

@staticmethod
def _generate_sample(model: MegatronParallel):
prompts = ["Born in north-east France, Soyer trained as a", "Born in California, Soyer trained as a"]

mcore_tokenizer = MCoreTokenizerWrappper(model.tokenizer)
mcore_inference = model.get_inference_wrapper(
params_dtype=torch.bfloat16, inference_batch_times_seqlen_threshold=30
)

generated = [r.generated_text for r in generate(mcore_inference, mcore_tokenizer, prompts)]
outputs = [prompt + generation for prompt, generation in zip(prompts, generated)]

logging.info(f'Sample generation after PTQ (with prompts): {outputs}')

def quantize(self, model: llm.GPTModel, forward_loop=None):
def quantize(self, model: MegatronParallel, forward_loop=None):
"""Quantize the model and calibrate using given forward loop."""
if forward_loop is None:
get_dataloader = create_data_iterator_getter(
Expand Down Expand Up @@ -185,7 +185,7 @@ def quantize(self, model: llm.GPTModel, forward_loop=None):

self._setup(model)
unwrapped_model = get_unwrapped_mcore_model(model)
decoder_type = self._get_decoder_type(unwrapped_model.config)
decoder_type = self._get_decoder_type(model)
quant_cfg = QUANT_CFG_CHOICES[algorithm]
if "awq" in algorithm:
weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
Expand Down Expand Up @@ -230,6 +230,10 @@ def quantize(self, model: llm.GPTModel, forward_loop=None):
if dist.get_rank() == 0:
mtq.print_quant_summary(unwrapped_model)

if self.export_config.generate_sample:
logging.info("Generating a sample output after model quantization.")
self._generate_sample(model)

return model

def create_megatron_forward_loop(
Expand Down Expand Up @@ -266,36 +270,48 @@ def loop(model):

return loop

def export(self, model: llm.GPTModel, model_dir: str) -> None:
@staticmethod
def _validate_quantized_checkpoint(checkpoint_dir: Path, tensor_parallelism_size: int) -> bool:
"""Basic validation of the model structure."""

saved_config = (checkpoint_dir / 'config.json').exists()
saved_weights = True
for i in range(tensor_parallelism_size):
saved_weights &= (checkpoint_dir / f'rank{i}.safetensors').exists()

export_successful = saved_config and saved_weights
if not export_successful:
logging.error("Failed to export the quantized model.")
return export_successful

def export(self, model: MegatronParallel, model_dir: str) -> None:
"""Export model to a TensorRT-LLM checkpoint."""
assert self.export_config is not None, "Export config is not set"
# TODO: Add sample generate
# TODO: Support megatron_amp_O2
export_dir = self.export_config.path
inference_tp = self.export_config.inference_tensor_parallel
inference_pp = self.export_config.inference_pipeline_parallel

use_nfs_workspace = model.config.pipeline_model_parallel_size > 1
export_tensorrt_llm_checkpoint(
model=get_unwrapped_mcore_model(model),
decoder_type=self._get_decoder_type(model.config),
decoder_type=self._get_decoder_type(model),
dtype=self.torch_dtype,
export_dir=export_dir,
inference_tensor_parallel=self.export_config.inference_tensor_parallel,
inference_pipeline_parallel=self.export_config.inference_pipeline_parallel,
inference_tensor_parallel=inference_tp,
inference_pipeline_parallel=inference_pp,
use_nfs_workspace=use_nfs_workspace,
)
dist.barrier()

# Save the model context in order to restore its tokenizer later. The destination
# path is "nemo_context" as this name is used in nemo.export to setup tokenizer.
if dist.get_rank() == 0:
assert self._validate_quantized_checkpoint(export_dir, inference_tp)
shutil.copytree(
os.path.join(model_dir, CONTEXT_PATH),
ckpt_to_context_subdir(model_dir),
os.path.join(export_dir, "nemo_context"),
dirs_exist_ok=True,
)
logging.info("Model context saved.")

logging.info(f"Export succeeded, model has been exported to {export_dir}.")
logging.info(f"Export succeeded, model has been exported to {export_dir}.")


def get_calib_data_iter(
Expand Down Expand Up @@ -323,7 +339,7 @@ def get_calib_data_iter(
def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration_size):
"""Create a function that provides iterator over a given dataset."""

def _iterator():
def _get_iterator():
CHARACTERS_PER_TOKEN = 4

dataloader = get_calib_data_iter(
@@ -332,14 +348,13 @@ def _iterator():
batch_size=batch_size,
calib_size=calibration_size,
)

data = []
for batch in dataloader:
batch = [model.tokenizer.text_to_ids(text)[:seq_len] for text in batch]
batch = [ids + (seq_len - len(ids)) * [model.tokenizer.eos] for ids in batch]
yield torch.tensor(batch, device=model.device)
data.append(torch.tensor(batch, device=model.device))

def _iterator_getter():
dataloader = _iterator()
dataloader = [data for data in dataloader]
return iter(tqdm(dataloader))
return iter(tqdm(data))

return _iterator_getter
return _get_iterator
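
Taken together, the quantizer changes above fold sample generation into the standard PTQ flow. Below is a minimal sketch of driving the updated `Quantizer` end to end; the `ExportConfig` fields, `Quantizer` methods, and `load_with_modelopt_layer_spec` signature are taken from this diff, while the checkpoint paths, the `QuantizationConfig` arguments, and the precision value are illustrative assumptions.

```python
# Sketch only: mirrors the flow in scripts/llm/ptq.py as it appears in this PR.
# Paths and the QuantizationConfig arguments below are illustrative assumptions.
from nemo.collections.llm import quantization
from nemo.collections.llm.quantization.utils import load_with_modelopt_layer_spec

model = load_with_modelopt_layer_spec(
    "/checkpoints/llama3-8b-nemo2",  # hypothetical NeMo 2.0 checkpoint directory
    calib_tp=1,
    calib_pp=1,
)

quantization_config = quantization.QuantizationConfig(algorithm="fp8")  # field name assumed
export_config = quantization.ExportConfig(
    path="/results/llama3-8b-fp8-trtllm",  # hypothetical TensorRT-LLM export directory
    dtype=16,                              # precision value assumed
    inference_tensor_parallel=1,
    inference_pipeline_parallel=1,
    generate_sample=True,                  # new in this PR: log a sample generation after PTQ
)

quantizer = quantization.Quantizer(quantization_config, export_config)
model = quantizer.quantize(model)  # calibrates with the default data iterator
quantizer.export(model, "/checkpoints/llama3-8b-nemo2")
```

Note that the second argument to `export()` is the source checkpoint directory: it is only used to copy the model context into `nemo_context` next to the exported TensorRT-LLM weights so the tokenizer can be restored later.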
32 changes: 30 additions & 2 deletions nemo/collections/llm/quantization/utils.py
@@ -23,8 +23,33 @@
from nemo.utils import logging


def get_modelopt_decoder_type(model: llm.GPTModel) -> str:
"""Infers the modelopt decoder type from GPTModel subclass."""
mapping = [
(llm.Baichuan2Model, "baichuan"),
(llm.ChatGLMModel, "chatglm"),
(llm.Gemma2Model, "gemma2"),
(llm.GemmaModel, "gemma"),
(llm.LlamaModel, "llama"),
(llm.MistralModel, "llama"),
(llm.MixtralModel, "llama"),
(llm.NemotronModel, "gptnext"),
(llm.Qwen2Model, "qwen"),
(llm.StarcoderModel, "gptnext"),
(llm.Starcoder2Model, "gptnext"),
(llm.Phi3Model, "phi3"),
]

for config_class, decoder_type in mapping:
if isinstance(model, config_class):
return decoder_type

logging.warning("Could not infer the decoder type")
return None


def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig:
"""Modify model config for TensorRT Model Optimizer"""
"""Modify model config for TensorRT-Model-Optimizer quantization"""

from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import (
get_gpt_layer_modelopt_spec,
@@ -46,7 +71,9 @@ def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig:
def load_with_modelopt_layer_spec(
nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1, inference_only: bool = True
):
# TODO: setting ddp="pytorch" with manually deleting model.optim is a hackish way to disable DDP initialization. Needs a systematic solution.
"""Loads a model from a NeMo 2.0 checkpoint using modelopt layer spec."""
# TODO: setting ddp="pytorch" and deleting model.optim is a hackish way to disable DDP initialization.
# Needs a systematic solution.
if inference_only:
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=calib_tp,
@@ -81,6 +108,7 @@ def load_with_modelopt_layer_spec(


def get_unwrapped_mcore_model(model):
"""Unwraps NeMo 2.0 to base MCore model."""
from megatron.core.models.gpt import GPTModel as MCoreGPTModel

unwrapped_model = model
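
As a quick illustration of the relocated helper, the mapping above now resolves decoder types from the model class rather than the config class. A minimal sketch follows; the model instantiation is hypothetical and the import path simply follows the file modified in this diff.

```python
from nemo.collections import llm
from nemo.collections.llm.quantization.utils import get_modelopt_decoder_type

# Per the mapping above, Llama- and Mistral-family models resolve to the "llama" decoder type.
model = llm.LlamaModel(llm.Llama2Config7B())  # hypothetical instantiation; no weights are built here
assert get_modelopt_decoder_type(model) == "llama"
```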
9 changes: 9 additions & 0 deletions scripts/llm/ptq.py
@@ -17,6 +17,8 @@


def get_args():
"""Parses PTQ arguments"""

parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description="NeMo PTQ argument parser",
@@ -58,6 +60,10 @@ def get_args():
type=str,
help='Calibration dataset to be used. Should be \"wikitext\", \"cnn_dailymail\" or path to a local .json file',
)
parser.add_argument(
'--generate_sample', help='Generate sample model output after performing PTQ', action='store_true'
)
parser.set_defaults(generate_sample=False)

args = parser.parse_args()
if args.output_path is None:
@@ -68,6 +74,8 @@


def main():
"""Example NeMo 2.0 Post Training Quantization workflow"""

args = get_args()

quantization_config = quantization.QuantizationConfig(
@@ -87,6 +95,7 @@ def main():
inference_tensor_parallel=args.tensor_parallelism_size,
inference_pipeline_parallel=args.pipeline_parallelism_size,
dtype=args.dtype,
generate_sample=args.generate_sample,
)

quantizer = quantization.Quantizer(quantization_config, export_config)
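
With the script change above, sample generation can be requested from the command line by appending `--generate_sample` to an otherwise unchanged PTQ invocation of `scripts/llm/ptq.py` (only this flag appears verbatim in the diff; the remaining arguments are unchanged). Because the flag defaults to off via `parser.set_defaults(generate_sample=False)`, existing PTQ workflows are unaffected.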