diff --git a/nemo/deploy/deploy_pytriton.py b/nemo/deploy/deploy_pytriton.py
index 22dea8ac47cd9..25e09cf3eacca 100644
--- a/nemo/deploy/deploy_pytriton.py
+++ b/nemo/deploy/deploy_pytriton.py
@@ -24,7 +24,6 @@
 
 
 class DeployPyTriton(DeployBase):
-
     """
     Deploys any models to Triton Inference Server that implements ITritonDeployable interface in nemo.deploy.
 
@@ -102,7 +101,6 @@ def __init__(
         )
 
     def deploy(self):
-
         """
         Deploys any models to Triton Inference Server.
         """
@@ -148,7 +146,6 @@ def deploy(self):
             print(e)
 
     def serve(self):
-
         """
         Starts serving the model and waits for the requests
         """
@@ -163,7 +160,6 @@ def serve(self):
             print(e)
 
     def run(self):
-
         """
         Starts serving the model asynchronously.
         """
diff --git a/nemo/deploy/nlp/query_llm.py b/nemo/deploy/nlp/query_llm.py
index 6a4337024eeb9..c8387914c2e9a 100644
--- a/nemo/deploy/nlp/query_llm.py
+++ b/nemo/deploy/nlp/query_llm.py
@@ -71,7 +71,8 @@ class NemoQueryLLM(NemoQueryLLMBase):
 
     def __init__(self, url, model_name):
         super().__init__(
-            url=url, model_name=model_name,
+            url=url,
+            model_name=model_name,
         )
 
     def query_llm(
diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py
index 8fc8522c8333a..2312a469a33a5 100644
--- a/nemo/export/tensorrt_llm.py
+++ b/nemo/export/tensorrt_llm.py
@@ -84,15 +84,24 @@ class TensorRTLLM(ITritonDeployable):
     """
 
-    def __init__(self, model_dir: str, lora_ckpt_list: List[str] = None, load_model: bool = True):
+    def __init__(
+        self,
+        model_dir: str,
+        lora_ckpt_list: List[str] = None,
+        load_model: bool = True,
+        use_python_runtime: bool = True,
+    ):
         """
         Args:
             model_dir (str): path for storing the TensorRT-LLM model files.
+            lora_ckpt_list (List[str]): lora checkpoint paths.
             load_model (bool): load TensorRT-LLM model if the engine files exist in the model_dir.
+            use_python_runtime (bool): whether to use python or c++ runtime.
""" self.model_dir = model_dir self.lora_ckpt_list = lora_ckpt_list + self.use_python_runtime = use_python_runtime self.model = None self.tokenizer = None self.n_gpus = None @@ -645,7 +654,7 @@ def _prep_ptuning_table(self): if len(vtokens_embeddings) > 0: self.p_table = torch.stack(vtokens_embeddings, dim=0).view(-1, self.get_hidden_size) - max_prompt_embedding_table_size = self.config['builder_config']['max_prompt_embedding_table_size'] + max_prompt_embedding_table_size = self.config['build_config']['max_prompt_embedding_table_size'] actual_prompt_table_size = self.p_table.shape[0] if actual_prompt_table_size > max_prompt_embedding_table_size: @@ -776,7 +785,10 @@ def _load(self): self._load_config_file() self.tokenizer = get_tokenzier(Path(os.path.join(self.model_dir))) self.model = load( - tokenizer=self.tokenizer, engine_dir=self.model_dir, lora_ckpt_list=self.lora_ckpt_list + tokenizer=self.tokenizer, + engine_dir=self.model_dir, + lora_ckpt_list=self.lora_ckpt_list, + use_python_runtime=self.use_python_runtime, ) self._load_prompt_tables() except Exception as error: diff --git a/nemo/export/trt_llm/decoder/decoder.py b/nemo/export/trt_llm/decoder/decoder.py index b3c0e2257e9fc..2d1993fd74c0c 100644 --- a/nemo/export/trt_llm/decoder/decoder.py +++ b/nemo/export/trt_llm/decoder/decoder.py @@ -90,7 +90,11 @@ def build_post_layernorm(self, layer) -> Optional[LayernormConfig]: pass def __init__( - self, decoder_type: str, dtype: trt.DataType = trt.float16, rank: int = 0, tensor_parallel: int = 1, + self, + decoder_type: str, + dtype: trt.DataType = trt.float16, + rank: int = 0, + tensor_parallel: int = 1, ): """Initializes the DecoderLayerConfigBuilder.""" self.decoder_type = decoder_type diff --git a/nemo/export/trt_llm/decoder/falcon.py b/nemo/export/trt_llm/decoder/falcon.py index 91edc7794607b..e05979fa75a01 100644 --- a/nemo/export/trt_llm/decoder/falcon.py +++ b/nemo/export/trt_llm/decoder/falcon.py @@ -69,7 +69,11 @@ def build_attention(self, layer) -> AttentionConfig: ) config.dense = LinearConfig.from_nn_module( - layer.self_attn.o_proj, LINEAR_ROW, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype, + layer.self_attn.o_proj, + LINEAR_ROW, + rank=self.rank, + tensor_parallel=self.tensor_parallel, + dtype=self.dtype, ) return config @@ -78,13 +82,25 @@ def build_attention(self, layer) -> AttentionConfig: def build_mlp(self, layer) -> MLPConfig: config = MLPConfig() config.fc = LinearConfig.from_nn_module( - layer.mlp.gate_proj, LINEAR_COLUMN, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype, + layer.mlp.gate_proj, + LINEAR_COLUMN, + rank=self.rank, + tensor_parallel=self.tensor_parallel, + dtype=self.dtype, ) config.proj = LinearConfig.from_nn_module( - layer.mlp.down_proj, LINEAR_ROW, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype, + layer.mlp.down_proj, + LINEAR_ROW, + rank=self.rank, + tensor_parallel=self.tensor_parallel, + dtype=self.dtype, ) config.gate = LinearConfig.from_nn_module( - layer.mlp.up_proj, LINEAR_COLUMN, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype, + layer.mlp.up_proj, + LINEAR_COLUMN, + rank=self.rank, + tensor_parallel=self.tensor_parallel, + dtype=self.dtype, ) return config @@ -130,4 +146,7 @@ def build_decoder(self, layer): config.set_if_not_exist('bias', False) config.set_if_not_exist('moe_num_experts', 0) - return FalconDecoderLayer(config=config, layer_idx=self.layer_id,) + return FalconDecoderLayer( + config=config, + layer_idx=self.layer_id, + ) diff --git 
diff --git a/nemo/export/trt_llm/decoder/decoder.py b/nemo/export/trt_llm/decoder/decoder.py
index b3c0e2257e9fc..2d1993fd74c0c 100644
--- a/nemo/export/trt_llm/decoder/decoder.py
+++ b/nemo/export/trt_llm/decoder/decoder.py
@@ -90,7 +90,11 @@ def build_post_layernorm(self, layer) -> Optional[LayernormConfig]:
         pass
 
     def __init__(
-        self, decoder_type: str, dtype: trt.DataType = trt.float16, rank: int = 0, tensor_parallel: int = 1,
+        self,
+        decoder_type: str,
+        dtype: trt.DataType = trt.float16,
+        rank: int = 0,
+        tensor_parallel: int = 1,
     ):
         """Initializes the DecoderLayerConfigBuilder."""
         self.decoder_type = decoder_type
diff --git a/nemo/export/trt_llm/decoder/falcon.py b/nemo/export/trt_llm/decoder/falcon.py
index 91edc7794607b..e05979fa75a01 100644
--- a/nemo/export/trt_llm/decoder/falcon.py
+++ b/nemo/export/trt_llm/decoder/falcon.py
@@ -69,7 +69,11 @@ def build_attention(self, layer) -> AttentionConfig:
         )
 
         config.dense = LinearConfig.from_nn_module(
-            layer.self_attn.o_proj, LINEAR_ROW, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.self_attn.o_proj,
+            LINEAR_ROW,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
 
         return config
@@ -78,13 +82,25 @@ def build_attention(self, layer) -> AttentionConfig:
     def build_mlp(self, layer) -> MLPConfig:
         config = MLPConfig()
         config.fc = LinearConfig.from_nn_module(
-            layer.mlp.gate_proj, LINEAR_COLUMN, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.mlp.gate_proj,
+            LINEAR_COLUMN,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
         config.proj = LinearConfig.from_nn_module(
-            layer.mlp.down_proj, LINEAR_ROW, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.mlp.down_proj,
+            LINEAR_ROW,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
         config.gate = LinearConfig.from_nn_module(
-            layer.mlp.up_proj, LINEAR_COLUMN, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.mlp.up_proj,
+            LINEAR_COLUMN,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
 
         return config
@@ -130,4 +146,7 @@ def build_decoder(self, layer):
         config.set_if_not_exist('bias', False)
         config.set_if_not_exist('moe_num_experts', 0)
 
-        return FalconDecoderLayer(config=config, layer_idx=self.layer_id,)
+        return FalconDecoderLayer(
+            config=config,
+            layer_idx=self.layer_id,
+        )
diff --git a/nemo/export/trt_llm/decoder/gemma.py b/nemo/export/trt_llm/decoder/gemma.py
index 10301c7a47d75..37f843dcf0ca7 100644
--- a/nemo/export/trt_llm/decoder/gemma.py
+++ b/nemo/export/trt_llm/decoder/gemma.py
@@ -64,7 +64,11 @@ def build_attention(self, layer) -> AttentionConfig:
         )
 
         config.dense = LinearConfig.from_nn_module(
-            layer.self_attn.o_proj, LINEAR_ROW, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.self_attn.o_proj,
+            LINEAR_ROW,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
 
         return config
@@ -73,13 +77,25 @@ def build_attention(self, layer) -> AttentionConfig:
     def build_mlp(self, layer) -> MLPConfig:
         config = MLPConfig()
         config.fc = LinearConfig.from_nn_module(
-            layer.mlp.gate_proj, LINEAR_COLUMN, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.mlp.gate_proj,
+            LINEAR_COLUMN,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
         config.proj = LinearConfig.from_nn_module(
-            layer.mlp.down_proj, LINEAR_ROW, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.mlp.down_proj,
+            LINEAR_ROW,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
         config.gate = LinearConfig.from_nn_module(
-            layer.mlp.up_proj, LINEAR_COLUMN, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.mlp.up_proj,
+            LINEAR_COLUMN,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
 
         return config
@@ -128,4 +144,7 @@ def build_decoder(self, layer):
         config.set_if_not_exist('dense_context_fmha', False)
         config.set_if_not_exist('moe_num_experts', 0)
 
-        return GemmaDecoderLayer(config=config, layer_idx=self.layer_id,)
+        return GemmaDecoderLayer(
+            config=config,
+            layer_idx=self.layer_id,
+        )
diff --git a/nemo/export/trt_llm/decoder/gpt.py b/nemo/export/trt_llm/decoder/gpt.py
index 8af4e4ef01e4e..a405aabbbd48e 100644
--- a/nemo/export/trt_llm/decoder/gpt.py
+++ b/nemo/export/trt_llm/decoder/gpt.py
@@ -54,11 +54,18 @@ def build_input_layernorm(self, layer) -> LayernormConfig:
     def build_attention(self, layer) -> AttentionConfig:
         config = AttentionConfig()
         config.qkv = LinearConfig.from_qkv_nn_modules(
-            [layer.attn.c_attn], rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            [layer.attn.c_attn],
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
 
         config.dense = LinearConfig.from_nn_module(
-            layer.attn.c_proj, LINEAR_ROW, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.attn.c_proj,
+            LINEAR_ROW,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
 
         return config
@@ -67,10 +74,18 @@ def build_attention(self, layer) -> AttentionConfig:
     def build_mlp(self, layer) -> MLPConfig:
         config = MLPConfig()
         config.fc = LinearConfig.from_nn_module(
-            layer.mlp.c_fc, LINEAR_COLUMN, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.mlp.c_fc,
+            LINEAR_COLUMN,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
         config.proj = LinearConfig.from_nn_module(
-            layer.mlp.c_proj, LINEAR_ROW, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.mlp.c_proj,
+            LINEAR_ROW,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
 
         return config
@@ -126,4 +141,7 @@ def build_decoder(self, layer):
         config.set_if_not_exist('rotary_pct', rotary_pct)
         config.set_if_not_exist('moe_num_experts', 0)
 
-        return GPTDecoderLayer(config=config, layer_idx=self.layer_id,)
+        return GPTDecoderLayer(
+            config=config,
+            layer_idx=self.layer_id,
+        )
diff --git a/nemo/export/trt_llm/decoder/gptj.py b/nemo/export/trt_llm/decoder/gptj.py
index aa65ca385a479..327a31fdd35cb 100644
--- a/nemo/export/trt_llm/decoder/gptj.py
+++ b/nemo/export/trt_llm/decoder/gptj.py
@@ -60,7 +60,11 @@ def build_attention(self, layer) -> AttentionConfig:
         )
 
         config.dense = LinearConfig.from_nn_module(
-            layer.attn.out_proj, LINEAR_ROW, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.attn.out_proj,
+            LINEAR_ROW,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
 
         config.rotary_dim = layer.attn.rotary_dim
@@ -71,10 +75,18 @@ def build_attention(self, layer) -> AttentionConfig:
     def build_mlp(self, layer) -> MLPConfig:
         config = MLPConfig()
         config.fc = LinearConfig.from_nn_module(
-            layer.mlp.fc_in, LINEAR_COLUMN, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.mlp.fc_in,
+            LINEAR_COLUMN,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
         config.proj = LinearConfig.from_nn_module(
-            layer.mlp.fc_out, LINEAR_ROW, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.mlp.fc_out,
+            LINEAR_ROW,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
 
         return config
diff --git a/nemo/export/trt_llm/decoder/llama.py b/nemo/export/trt_llm/decoder/llama.py
index 873c0306375b6..b37d62e214de3 100644
--- a/nemo/export/trt_llm/decoder/llama.py
+++ b/nemo/export/trt_llm/decoder/llama.py
@@ -66,7 +66,11 @@ def build_attention(self, layer) -> AttentionConfig:
         )
 
         config.dense = LinearConfig.from_nn_module(
-            layer.self_attn.o_proj, LINEAR_ROW, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.self_attn.o_proj,
+            LINEAR_ROW,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
 
         return config
@@ -75,13 +79,25 @@ def build_attention(self, layer) -> AttentionConfig:
     def build_mlp(self, layer) -> MLPConfig:
         config = MLPConfig()
         config.fc = LinearConfig.from_nn_module(
-            layer.mlp.gate_proj, LINEAR_COLUMN, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.mlp.gate_proj,
+            LINEAR_COLUMN,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
         config.proj = LinearConfig.from_nn_module(
-            layer.mlp.down_proj, LINEAR_ROW, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.mlp.down_proj,
+            LINEAR_ROW,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
         config.gate = LinearConfig.from_nn_module(
-            layer.mlp.up_proj, LINEAR_COLUMN, rank=self.rank, tensor_parallel=self.tensor_parallel, dtype=self.dtype,
+            layer.mlp.up_proj,
+            LINEAR_COLUMN,
+            rank=self.rank,
+            tensor_parallel=self.tensor_parallel,
+            dtype=self.dtype,
         )
 
         return config
@@ -147,4 +163,7 @@ def build_decoder(self, layer):
             config.moe_tp_mode = layer.moe_tp_mode
             config.moe_normalization_mode = layer.moe_renorm_mode
 
-        return LLaMADecoderLayer(config=config, layer_idx=self.layer_id,)
+        return LLaMADecoderLayer(
+            config=config,
+            layer_idx=self.layer_id,
+        )
diff --git a/nemo/export/trt_llm/model_config.py b/nemo/export/trt_llm/model_config.py
index dd360afd6b8af..0f120dc56153d 100644
--- a/nemo/export/trt_llm/model_config.py
+++ b/nemo/export/trt_llm/model_config.py
@@ -122,7 +122,11 @@ def from_nn_module(module: nn.Module, linear_type: str, rank=0, tensor_parallel=
hasattr(module, "bias") and module.bias is not None: if linear_type == LINEAR_COLUMN: config.bias = np.ascontiguousarray( - split(torch_to_numpy_with_dtype(module.bias, dtype), tensor_parallel, rank,) + split( + torch_to_numpy_with_dtype(module.bias, dtype), + tensor_parallel, + rank, + ) ) else: config.bias = torch_to_numpy_with_dtype(module.bias, dtype) @@ -234,7 +238,9 @@ class AttentionConfig: @staticmethod def from_nemo( - weights_dict: Dict[str, np.ndarray], layer_id: int, rank: int = 0, + weights_dict: Dict[str, np.ndarray], + layer_id: int, + rank: int = 0, ): """Converts the nemo weights and config to `AttentionConfig`.""" attention = AttentionConfig() @@ -243,12 +249,16 @@ def from_nemo( weights_dict, f"layers.{layer_id}.attention.query_key_value.weight.{rank}" ) attention.qkv.bias = get_tensor_from_dict( - weights_dict, f"layers.{layer_id}.attention.query_key_value.bias.{rank}", + weights_dict, + f"layers.{layer_id}.attention.query_key_value.bias.{rank}", ) attention.dense = LinearConfig(linear_type=LINEAR_ROW) attention.dense.weight = get_tensor_from_dict(weights_dict, f"layers.{layer_id}.attention.dense.weight.{rank}") - attention.dense.bias = get_tensor_from_dict(weights_dict, f"layers.{layer_id}.attention.dense.bias",) + attention.dense.bias = get_tensor_from_dict( + weights_dict, + f"layers.{layer_id}.attention.dense.bias", + ) return attention @@ -276,7 +286,10 @@ def from_nemo( # print("********** mlp.fc.weight : ", mlp.fc.weight ) - mlp.fc.bias = get_tensor_from_dict(weights_dict, f"layers.{layer_id}.mlp.dense_h_to_4h.bias.{rank}",) + mlp.fc.bias = get_tensor_from_dict( + weights_dict, + f"layers.{layer_id}.mlp.dense_h_to_4h.bias.{rank}", + ) gated = is_gated_activation(mlp.hidden_act) is_fast_glu = mlp.hidden_act in ['fast-geglu', 'fast-swiglu', 'fast-reglu'] @@ -287,9 +300,13 @@ def from_nemo( if isinstance(llm_config, LlamaConfig) and not is_mcore and not is_fast_glu else f"layers.{layer_id}.mlp.dense_h_to_4h.gate.weight.{rank}" ) - mlp.gate.weight = get_tensor_from_dict(weights_dict, layer_name,) + mlp.gate.weight = get_tensor_from_dict( + weights_dict, + layer_name, + ) mlp.gate.bias = get_tensor_from_dict( - weights_dict, f"layers.{layer_id}.mlp.dense_h_to_4h.gate.bias.{rank}", + weights_dict, + f"layers.{layer_id}.mlp.dense_h_to_4h.gate.bias.{rank}", ) mlp.proj = LinearConfig(linear_type=LINEAR_ROW) @@ -382,19 +399,23 @@ def from_nemo( LAYERNORM_RMS if isinstance(llm_config, LlamaConfig) else LAYERNORM_DEFAULT ) layer_config.input_layernorm.weight = get_tensor_from_dict( - weights_dict, f"layers.{layer_id}.input_layernorm.weight", + weights_dict, + f"layers.{layer_id}.input_layernorm.weight", ) layer_config.input_layernorm.bias = get_tensor_from_dict( - weights_dict, f"layers.{layer_id}.input_layernorm.bias", + weights_dict, + f"layers.{layer_id}.input_layernorm.bias", ) layer_config.mlp_layernorm = LayernormConfig() layer_config.mlp_layernorm.layernorm_type = LAYERNORM_DEFAULT # Falcon uses default layernorm layer_config.mlp_layernorm.weight = get_tensor_from_dict( - weights_dict, f"layers.{layer_id}.pre_mlp_layernorm.weight", + weights_dict, + f"layers.{layer_id}.pre_mlp_layernorm.weight", ) layer_config.mlp_layernorm.bias = get_tensor_from_dict( - weights_dict, f"layers.{layer_id}.pre_mlp_layernorm.bias", + weights_dict, + f"layers.{layer_id}.pre_mlp_layernorm.bias", ) layer_config.post_layernorm = LayernormConfig() @@ -403,10 +424,12 @@ def from_nemo( ) layer_config.post_layernorm.weight = get_tensor_from_dict( - weights_dict, 
f"layers.{layer_id}.post_attention_layernorm.weight", + weights_dict, + f"layers.{layer_id}.post_attention_layernorm.weight", ) layer_config.post_layernorm.bias = get_tensor_from_dict( - weights_dict, f"layers.{layer_id}.post_attention_layernorm.bias", + weights_dict, + f"layers.{layer_id}.post_attention_layernorm.bias", ) if layer_config.post_layernorm.weight is None: # Falcon doesn't have post layernorm @@ -415,7 +438,11 @@ def from_nemo( if layer_config.mlp_layernorm.weight is None: layer_config.mlp_layernorm = None - layer_config.attention = AttentionConfig.from_nemo(weights_dict, layer_id, rank,) + layer_config.attention = AttentionConfig.from_nemo( + weights_dict, + layer_id, + rank, + ) moe = False if llm_config.moe_num_experts is not None: diff --git a/nemo/export/trt_llm/nemo/nemo.py b/nemo/export/trt_llm/nemo/nemo.py index 9026cd9cfba9e..c3564f1c4e8e9 100644 --- a/nemo/export/trt_llm/nemo/nemo.py +++ b/nemo/export/trt_llm/nemo/nemo.py @@ -106,7 +106,9 @@ def extract_layers_with_prefix(model_, prefix): class UnpackedNemoCheckpointDir: def __init__( - self, checkpoints_dir: typing.Union[pathlib.Path, TarPath], load_checkpoints_to_cpu: bool = False, + self, + checkpoints_dir: typing.Union[pathlib.Path, TarPath], + load_checkpoints_to_cpu: bool = False, ): assert isinstance(checkpoints_dir, (pathlib.Path, TarPath)) self._checkpoints_dir = checkpoints_dir @@ -121,11 +123,7 @@ def model_config(self): model_configs_paths = list(self._checkpoints_dir.rglob(model_config_filename)) if model_configs_paths: if len(model_configs_paths) > 1: - raise RuntimeError( - f"There are more than single {model_config_filename} in" - f" {self._checkpoints_dir}:" - f" {', '.join(map(lambda p: p.as_posix(), model_configs_paths))}" - ) + LOGGER.debug(f"There are more than single {model_config_filename} in" f" {self._checkpoints_dir}") model_config_path = model_configs_paths[0] LOGGER.debug("Loading model config from %s", model_config_path) with model_config_path.open("r") as model_config_file: diff --git a/nemo/export/trt_llm/tensorrt_llm_model.py b/nemo/export/trt_llm/tensorrt_llm_model.py index 736d6180807e7..f4b44552af63d 100644 --- a/nemo/export/trt_llm/tensorrt_llm_model.py +++ b/nemo/export/trt_llm/tensorrt_llm_model.py @@ -144,7 +144,12 @@ def forward( if attention_mask is not None: attention_mask = expand_mask(attention_mask, shape(input_ids, -1)) - for layer_idx, (layer, past) in enumerate(zip(self.layers, kv_cache_params.past_key_value,)): + for layer_idx, (layer, past) in enumerate( + zip( + self.layers, + kv_cache_params.past_key_value, + ) + ): decoder_params = { "hidden_states": hidden_states, diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py index 92fc36272f7c6..fe0189b106286 100644 --- a/nemo/export/trt_llm/tensorrt_llm_run.py +++ b/nemo/export/trt_llm/tensorrt_llm_run.py @@ -16,17 +16,19 @@ import json import logging import os +import tempfile from dataclasses import dataclass from pathlib import Path from typing import List, Optional +import numpy as np import tensorrt_llm import torch from mpi4py.futures import MPIPoolExecutor from tensorrt_llm.logger import logger from tensorrt_llm.lora_manager import LoraManager from tensorrt_llm.quantization import QuantMode -from tensorrt_llm.runtime import ModelConfig, ModelRunnerCpp, SamplingConfig +from tensorrt_llm.runtime import ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig from transformers import PreTrainedTokenizer from nemo.export.trt_llm.tensor_utils import get_tensor_parallel_group @@ 
@@ -55,7 +57,7 @@ class TensorrtLLMHostContext:
 class TensorrtLLMWorkerContext:
     """The MPI worker side context for TRT LLM inference."""
 
-    decoder: ModelRunnerCpp = None
+    decoder: ModelRunner = None
     sampling_config: SamplingConfig = None
     lora_manager: LoraManager = None
     max_batch_size: int = 0
@@ -128,7 +130,13 @@ def _read_config(config_path: Path):
     return model_config, world_size, tensor_parallel_size, pipeline_parallel_size, dtype, max_input_len, max_batch_size
 
 
-def _load(tokenizer: PreTrainedTokenizer, engine_dir, lora_ckpt_list=None, num_beams=1):
+def _load(
+    tokenizer: PreTrainedTokenizer,
+    engine_dir,
+    lora_ckpt_list=None,
+    num_beams=1,
+    use_python_runtime: bool = True,
+):
     """The impl of `load` API for on a single GPU worker."""
     try:
         tensorrt_llm.logger.set_level("info")
@@ -147,17 +155,26 @@ def _load(tokenizer: PreTrainedTokenizer, engine_dir, lora_ckpt_list=None, num_b
 
         runtime_rank = tensorrt_llm.mpi_rank()
 
-        decoder = ModelRunnerCpp.from_dir(
-            engine_dir=engine_dir,
-            lora_dir=lora_ckpt_list,
-            lora_ckpt_source="nemo",
-            rank=runtime_rank,
-            max_batch_size=max_batch_size,
-            max_input_len=max_input_len,
-            max_output_len=max_output_len,
-            max_beam_width=max_beam_width,
-            debug_mode=False,
-        )
+        if use_python_runtime:
+            decoder = ModelRunner.from_dir(
+                engine_dir=engine_dir,
+                lora_dir=lora_ckpt_list,
+                lora_ckpt_source="nemo",
+                rank=runtime_rank,
+                debug_mode=False,
+            )
+        else:
+            decoder = ModelRunnerCpp.from_dir(
+                engine_dir=engine_dir,
+                lora_dir=lora_ckpt_list,
+                lora_ckpt_source="nemo",
+                rank=runtime_rank,
+                max_batch_size=max_batch_size,
+                max_input_len=max_input_len,
+                max_output_len=max_output_len,
+                max_beam_width=max_beam_width,
+                debug_mode=False,
+            )
 
         sampling_config = SamplingConfig(
             end_id=tokenizer.eos_token_id, pad_id=tokenizer.eos_token_id, num_beams=num_beams
         )
@@ -218,6 +235,13 @@ def _forward(
         with torch.no_grad():
             prompt_tasks = None if task_ids is None else ",".join(str(task) for task in task_ids)
 
+            if prompt_table is not None:
+                prompt_table = prompt_table.reshape(1, *prompt_table.shape)
+                tmp_dir = tempfile.TemporaryDirectory()
+                prompt_table_path = os.path.join(tmp_dir.name, 'prompt_table.npy')
+                np.save(prompt_table_path, prompt_table.cpu().float().numpy())
+                prompt_table = prompt_table_path
+
             outputs = decoder.generate(
                 input_tensors,
                 max_new_tokens=max_output_len,
@@ -230,6 +254,7 @@
                 stop_words_list=stop_words_list,
                 bad_words_list=bad_words_list,
                 lora_uids=lora_uids,
+                prompt_table_path=prompt_table,
                 prompt_table=prompt_table,
                 prompt_tasks=prompt_tasks,
                 streaming=streaming,
@@ -239,6 +264,9 @@
 
             torch.cuda.synchronize()
 
+        if prompt_table is not None:
+            tmp_dir.cleanup()
+
         runtime_rank = tensorrt_llm.mpi_rank()
         if runtime_rank == 0 or multiprocessed_env:
             return outputs
@@ -251,7 +279,11 @@
 def load(
-    tokenizer: PreTrainedTokenizer, engine_dir: str, lora_ckpt_list: List[str] = None, num_beams: int = 1
+    tokenizer: PreTrainedTokenizer,
+    engine_dir: str,
+    lora_ckpt_list: List[str] = None,
+    num_beams: int = 1,
+    use_python_runtime: bool = True,
 ) -> TensorrtLLMHostContext:
     """Loaded the compiled LLM model and run it.
@@ -263,17 +295,17 @@ def load(
         config = json.load(f)
     world_size = config["pretrained_config"]["mapping"]["world_size"]
     if world_size == 1:
-        _load(tokenizer, engine_dir, lora_ckpt_list, num_beams)
+        _load(tokenizer, engine_dir, lora_ckpt_list, num_beams, use_python_runtime)
         executor = None
     elif tensorrt_llm.mpi_world_size() > 1:
-        _load(tokenizer, engine_dir, lora_ckpt_list, num_beams)
+        _load(tokenizer, engine_dir, lora_ckpt_list, num_beams, use_python_runtime)
         executor = None
         tensorrt_llm.mpi_barrier()
     else:
         executor = MPIPoolExecutor(max_workers=world_size)
         futures = []
         for _ in range(world_size):
-            future = executor.submit(_load, tokenizer, engine_dir, lora_ckpt_list, num_beams)
+            future = executor.submit(_load, tokenizer, engine_dir, lora_ckpt_list, num_beams, use_python_runtime)
             futures.append(future)
         for future in futures:
             future.result()
diff --git a/scripts/deploy/nlp/deploy_triton.py b/scripts/deploy/nlp/deploy_triton.py
index 7370731ec996e..0a9604a73cdc2 100755
--- a/scripts/deploy/nlp/deploy_triton.py
+++ b/scripts/deploy/nlp/deploy_triton.py
@@ -80,7 +80,7 @@ def get_args(argv):
         "-mpet", "--max_prompt_embedding_table_size", default=None, type=int, help="Max prompt embedding table size"
     )
     parser.add_argument(
-        "-upkc", "--use_paged_kv_cache", default=False, action='store_true', help="Enable paged kv cache."
+        "-npkc", "--no_paged_kv_cache", default=False, action='store_true', help="Disable paged kv cache."
     )
     parser.add_argument(
         "-drip",
@@ -133,6 +133,13 @@ def get_args(argv):
     parser.add_argument(
         "-lc", "--lora_ckpt", default=None, type=str, nargs="+", help="The checkpoint list of LoRA weights"
     )
+    parser.add_argument(
+        "-ucr",
+        '--use_cpp_runtime',
+        default=False,
+        action='store_true',
+        help='Use TensorRT LLM C++ runtime',
+    )
     parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode")
     args = parser.parse_args(argv)
@@ -206,32 +213,13 @@ def nemo_deploy(argv):
         )
         return
 
-    trt_llm_exporter = TensorRTLLM(model_dir=trt_llm_path, lora_ckpt_list=args.lora_ckpt)
+    trt_llm_exporter = TensorRTLLM(
+        model_dir=trt_llm_path,
+        lora_ckpt_list=args.lora_ckpt,
+        use_python_runtime=(not args.use_cpp_runtime),
+    )
 
     if args.nemo_checkpoint is not None:
-
-        trt_llm_exporter.export(
-            nemo_checkpoint_path=args.nemo_checkpoint,
-            model_type=args.model_type,
-            n_gpus=args.num_gpus,
-            tensor_parallel_size=args.num_gpus,
-            pipeline_parallel_size=1,
-            max_input_token=args.max_input_len,
-            max_output_token=args.max_output_len,
-            max_batch_size=args.max_batch_size,
-            max_num_tokens=args.max_num_tokens,
-            opt_num_tokens=args.opt_num_tokens,
-            max_prompt_embedding_table_size=args.max_prompt_embedding_table_size,
-            paged_kv_cache=args.use_paged_kv_cache,
-            remove_input_padding=(not args.disable_remove_input_padding),
-            dtype=args.dtype,
-            enable_multi_block_mode=args.multi_block_mode,
-            use_lora_plugin=args.use_lora_plugin,
-            lora_target_modules=args.lora_target_modules,
-            max_lora_rank=args.max_lora_rank,
-            save_nemo_model_config=True,
-        )
-
         try:
             LOGGER.info("Export operation will be started to export the nemo checkpoint to TensorRT-LLM.")
             trt_llm_exporter.export(
@@ -246,7 +234,7 @@ def nemo_deploy(argv):
                 max_num_tokens=args.max_num_tokens,
                 opt_num_tokens=args.opt_num_tokens,
                 max_prompt_embedding_table_size=args.max_prompt_embedding_table_size,
-                paged_kv_cache=args.use_paged_kv_cache,
+                paged_kv_cache=(not args.no_paged_kv_cache),
                 remove_input_padding=(not args.disable_remove_input_padding),
                 dtype=args.dtype,
                 enable_multi_block_mode=args.multi_block_mode,
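
Usage note for the renamed and added deployment flags: a hedged invocation sketch, with the existing checkpoint, model, and Triton options elided since this diff leaves them unchanged. Paged KV cache is now on by default and is switched off with --no_paged_kv_cache, while --use_cpp_runtime opts back into the C++ runtime (the exporter otherwise defaults to the Python runtime).

    # Sketch only; <...> stands for the existing, unchanged options.
    python scripts/deploy/nlp/deploy_triton.py <checkpoint and Triton options> --no_paged_kv_cache --use_cpp_runtime
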
diff --git a/scripts/export/export_to_trt_llm.py b/scripts/export/export_to_trt_llm.py
index e9741516cf004..ce9ef6a1e1328 100644
--- a/scripts/export/export_to_trt_llm.py
+++ b/scripts/export/export_to_trt_llm.py
@@ -45,8 +45,8 @@ def get_args(argv):
     parser.add_argument(
         "-dt",
         "--dtype",
-        choices=["bf16", "fp16", "fp8", "int8"],
-        default="bf16",
+        choices=["bfloat16", "float16", "fp8", "int8"],
+        default="bfloat16",
         type=str,
         help="dtype of the model on TensorRT-LLM",
     )
@@ -59,7 +59,7 @@ def get_args(argv):
         "-mpet", "--max_prompt_embedding_table_size", default=None, type=int, help="Max prompt embedding table size"
     )
     parser.add_argument(
-        "-upkc", "--use_paged_kv_cache", default=False, action='store_true', help="Enable paged kv cache."
+        "-npkc", "--no_paged_kv_cache", default=False, action='store_true', help="Disable paged kv cache."
     )
     parser.add_argument(
         "-drip",
@@ -123,7 +123,7 @@ def nemo_export_trt_llm(argv):
     LOGGER.info("Logging level set to {}".format(loglevel))
     LOGGER.info(args)
 
-    if args.dtype != "bf16":
+    if args.dtype != "bfloat16":
         LOGGER.error(
             "Only bf16 is currently supported for the optimized deployment with TensorRT-LLM. "
             "Support for the other precisions will be added in the coming releases."
         )
@@ -146,7 +146,7 @@ def nemo_export_trt_llm(argv):
             max_num_tokens=args.max_num_tokens,
             opt_num_tokens=args.opt_num_tokens,
             max_prompt_embedding_table_size=args.max_prompt_embedding_table_size,
-            paged_kv_cache=args.use_paged_kv_cache,
+            paged_kv_cache=(not args.no_paged_kv_cache),
            remove_input_padding=(not args.disable_remove_input_padding),
            dtype=args.dtype,
            enable_multi_block_mode=args.multi_block_mode,
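
The export script follows the same conventions after this change: a hedged sketch, with the checkpoint and output options elided since they are untouched here. --dtype now takes the long-form names ("bfloat16", "float16"), and paged KV cache stays enabled unless --no_paged_kv_cache is passed.

    # Sketch only; <...> stands for the existing, unchanged options.
    python scripts/export/export_to_trt_llm.py <checkpoint and output options> --dtype bfloat16 --no_paged_kv_cache
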