TRT-LLM 0.12 + ModelOpt 0.17.0 updates (#10301)
* Update trtllm-build options

Signed-off-by: Jan Lasek <[email protected]>

* Pull QUANT_CFG_CHOICES into try/catch for HAVE_MODELOPT consistency

Signed-off-by: Jan Lasek <[email protected]>

* Remove deprecated parallel group setup

Signed-off-by: Jan Lasek <[email protected]>

* Remove deprecated size settings

Signed-off-by: Jan Lasek <[email protected]>

* Use max_seq_len instead of max_output_len [part I]

Signed-off-by: Jan Lasek <[email protected]>

---------

Signed-off-by: Jan Lasek <[email protected]>
janekl authored Aug 29, 2024
1 parent 736a6fc commit ea0f69f
Showing 3 changed files with 28 additions and 22 deletions.
21 changes: 9 additions & 12 deletions nemo/export/quantize/quantizer.py
@@ -31,7 +31,15 @@
try:
import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_tensorrt_llm_checkpoint
from modelopt.torch.utils.distributed import set_data_parallel_group, set_tensor_parallel_group

QUANT_CFG_CHOICES = {
"int8": mtq.INT8_DEFAULT_CFG,
"int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
"fp8": mtq.FP8_DEFAULT_CFG,
"int4_awq": mtq.INT4_AWQ_CFG,
"w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
"int4": mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
}

HAVE_MODELOPT = True

@@ -41,14 +41,6 @@


SUPPORTED_DTYPE = [16, "16", "bf16"] # Default precision for non-quantized layers
QUANT_CFG_CHOICES = {
"int8": mtq.INT8_DEFAULT_CFG,
"int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
"fp8": mtq.FP8_DEFAULT_CFG,
"int4_awq": mtq.INT4_AWQ_CFG,
"w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
"int4": mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
}


class Quantizer:
@@ -157,9 +157,6 @@ def dummy():
model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
model.trainer.strategy.setup_environment()

set_data_parallel_group(mpu.get_data_parallel_group())
set_tensor_parallel_group(mpu.get_tensor_model_parallel_group())

@staticmethod
def modify_model_config(model_cfg: DictConfig) -> DictConfig:
"""Modify model config for quantization."""
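For context, a minimal usage sketch (not part of this commit) of how a config from QUANT_CFG_CHOICES is typically applied with ModelOpt post-training quantization. The model, algorithm name, and calibration forward_loop below are placeholders, and the import guard mirrors the HAVE_MODELOPT pattern above.

import torch

try:
    import modelopt.torch.quantization as mtq

    # Subset of the choices above, kept inside the same try/except as HAVE_MODELOPT.
    QUANT_CFG_CHOICES = {
        "int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
        "fp8": mtq.FP8_DEFAULT_CFG,
    }

    HAVE_MODELOPT = True
except (ImportError, ModuleNotFoundError):
    HAVE_MODELOPT = False


def quantize_model(model: torch.nn.Module, algorithm: str, forward_loop):
    """Apply ModelOpt post-training quantization using one of the named configs."""
    if not HAVE_MODELOPT:
        raise RuntimeError("nvidia-modelopt is required for quantization")
    if algorithm not in QUANT_CFG_CHOICES:
        raise ValueError(f"Unsupported quantization algorithm: {algorithm}")
    # forward_loop(model) should run calibration batches through the model.
    return mtq.quantize(model, QUANT_CFG_CHOICES[algorithm], forward_loop=forward_loop)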
19 changes: 14 additions & 5 deletions nemo/export/tensorrt_llm.py
@@ -148,7 +148,7 @@ def export(
pipeline_parallelism_size: int = 1,
gpus_per_node: Optional[int] = None,
max_input_len: int = 256,
max_output_len: int = 256,
max_output_len: Optional[int] = 256,
max_input_token: Optional[int] = None,
max_output_token: Optional[int] = None,
max_batch_size: int = 8,
@@ -169,6 +169,7 @@
multiple_profiles: bool = False,
gpt_attention_plugin: str = "auto",
gemm_plugin: str = "auto",
reduce_fusion: bool = True,
fp8_quantized: Optional[bool] = None,
fp8_kvcache: Optional[bool] = None,
):
@@ -201,10 +202,11 @@
max_lora_rank (int): maximum lora rank.
max_num_tokens (int):
opt_num_tokens (int):
max_seq_len (int):
max_seq_len (int): the maximum sequence length of a single request.
multiple_profiles (bool): enables the multiple profiles feature of TRT-LLM. Default = False
gpt_attention_plugin (str): enable the gpt attention plugin. Default = "auto"
gemm_plugin (str): enable the gemm plugin. Default = "auto"
reduce_fusion (bool): enables fusing extra kernels after custom TRT-LLM allReduce
fp8_quantized (Optional[bool]): enables exporting to FP8 TRT-LLM checkpoints. If not set, autodetects the type.
fp8_kvcache (Optional[bool]): enables FP8 KV-cache quantization. If not set, autodetects the type.
"""
@@ -257,8 +259,14 @@
)
max_output_len = max_output_token

if max_seq_len is None:
max_seq_len = max_input_len + max_output_len
if max_output_len is not None:
warnings.warn(
"Parameter max_output_len is deprecated and will be removed. Please use max_seq_len instead.",
DeprecationWarning,
stacklevel=2,
)
if max_seq_len is None:
max_seq_len = max_input_len + max_output_len

if max_batch_size < 4:
warnings.warn(
@@ -284,21 +292,22 @@
nemo_checkpoint_path=nemo_checkpoint_path,
engine_dir=self.model_dir,
max_input_len=max_input_len,
max_output_len=max_output_len,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
tensor_parallel_size=tensor_parallelism_size,
pipeline_parallel_size=pipeline_parallelism_size,
use_parallel_embedding=use_parallel_embedding,
paged_kv_cache=paged_kv_cache,
paged_context_fmha=paged_context_fmha,
remove_input_padding=remove_input_padding,
use_lora_plugin=use_lora_plugin,
lora_target_modules=lora_target_modules,
max_lora_rank=max_lora_rank,
max_num_tokens=max_num_tokens,
opt_num_tokens=opt_num_tokens,
multiple_profiles=multiple_profiles,
reduce_fusion=reduce_fusion,
)
else:
if model_type is None:
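A hypothetical caller sketch of the migration this change implies: max_seq_len now bounds prompt plus generated tokens, and max_output_len is only accepted for backward compatibility with a deprecation warning. The checkpoint path, model_type, and size values are placeholders, not taken from this commit.

from nemo.export.tensorrt_llm import TensorRTLLM

exporter = TensorRTLLM(model_dir="/tmp/trt_llm_engine")  # engine output directory
exporter.export(
    nemo_checkpoint_path="/path/to/model.nemo",  # placeholder checkpoint path
    model_type="llama",
    max_input_len=2048,
    max_seq_len=4096,    # single limit for prompt + generated tokens; no max_output_len
    max_batch_size=8,
    reduce_fusion=True,  # forwarded to trtllm-build on the quantized-checkpoint path
)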
10 changes: 5 additions & 5 deletions nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py
@@ -27,14 +27,14 @@ def qnemo_to_tensorrt_llm(
nemo_checkpoint_path: str,
engine_dir: str,
max_input_len: int,
max_output_len: int,
max_seq_len: Optional[int],
max_batch_size: int,
max_prompt_embedding_table_size: int,
tensor_parallel_size: Optional[int] = None,
pipeline_parallel_size: Optional[int] = None,
use_parallel_embedding: bool = False,
paged_kv_cache: bool = True,
paged_context_fmha: bool = False,
remove_input_padding: bool = True,
use_lora_plugin: Optional[str] = None,
lora_target_modules: Optional[List[str]] = None,
@@ -43,6 +43,7 @@
opt_num_tokens: Optional[int] = None,
max_beam_width: int = 1,
multiple_profiles: bool = False,
reduce_fusion: bool = True,
):
"""Build TensorRT-LLM engine with trtllm-build command in a subprocess."""
assert not lora_target_modules, f"LoRA is not supported for quantized checkpoints, got {lora_target_modules}"
@@ -82,25 +83,24 @@
build_cmd += f"--workers {num_build_workers} "
build_cmd += f"--max_batch_size {max_batch_size} "
build_cmd += f"--max_input_len {max_input_len} "
build_cmd += f"--max_output_len {max_output_len} "
build_cmd += f"--max_beam_width {max_beam_width} "
build_cmd += f"--tp_size {config.mapping.tp_size} "
build_cmd += f"--pp_size {config.mapping.pp_size} "
build_cmd += f"--max_prompt_embedding_table_size {max_prompt_embedding_table_size} "
build_cmd += f"--builder_opt {builder_opt} "
build_cmd += f"--gpt_attention_plugin {config.dtype} "
build_cmd += f"--nccl_plugin {config.dtype} "
build_cmd += f"--paged_kv_cache {'enable' if paged_kv_cache else 'disable'} "
build_cmd += f"--use_paged_context_fmha {'enable' if paged_context_fmha else 'disable'} "
build_cmd += f"--remove_input_padding {'enable' if remove_input_padding else 'disable'} "
build_cmd += f"--multiple_profiles {'enable' if multiple_profiles else 'disable'} "
build_cmd += f"--reduce_fusion {'enable' if reduce_fusion else 'disable'} "

if use_fused_mlp:
build_cmd += "--use_fused_mlp " if "RecurrentGemma" not in config.architecture else ""

if not use_qdq:
build_cmd += f"--gemm_plugin {config.dtype} "

if max_seq_len:
if max_seq_len is not None:
build_cmd += f"--max_seq_len {max_seq_len} "

if max_num_tokens is not None:
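Illustrative only: roughly the shape of the resulting trtllm-build invocation after this change, with --max_seq_len and --reduce_fusion in place of the removed --max_output_len. Paths and sizes are placeholders, and the --checkpoint_dir/--output_dir flags come from standard trtllm-build usage rather than the hunk shown above.

import subprocess

build_cmd = (
    "trtllm-build "
    "--checkpoint_dir /path/to/quantized_checkpoint "
    "--output_dir /path/to/engine_dir "
    "--max_batch_size 8 "
    "--max_input_len 2048 "
    "--max_seq_len 4096 "          # replaces the removed --max_output_len
    "--reduce_fusion enable "      # option newly exposed by this commit
    "--multiple_profiles disable "
)
subprocess.run(build_cmd.split(), check=True)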
