Export & deploy updates (part II) (#11344)
* Set debug level accordingly and use logger for some prints

Signed-off-by: Jan Lasek <[email protected]>

* Param model_type is in fact required for vLLM deployment

Signed-off-by: Jan Lasek <[email protected]>

* It's time to remove version check

Signed-off-by: Jan Lasek <[email protected]>

* Start using v2 block manager (default) in vLLM

Signed-off-by: Jan Lasek <[email protected]>

* Rename original NeMo SP tokenizer to solve UnboundLocalError

Signed-off-by: Jan Lasek <[email protected]>

* Add output_generation_logits as Triton input for vLLM exporter

Signed-off-by: Jan Lasek <[email protected]>

* Update requirements_vllm.txt

Signed-off-by: Jan Lasek <[email protected]>

* Refer to NeMo tokenizers explicitly

Signed-off-by: Jan Lasek <[email protected]>

---------

Signed-off-by: Jan Lasek <[email protected]>
janekl authored Nov 21, 2024
1 parent 773590c commit ffdccaf
Showing 8 changed files with 42 additions and 20 deletions.
12 changes: 6 additions & 6 deletions nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
@@ -318,14 +318,14 @@ def build_tokenizer(tokenizer):
tokenizer.add_special_tokens({"eos_token": "</s>"})
else:
# For NeMo tokenizers, monkey patch encode & batch_decode methods for unified interface
from nemo.collections.common.tokenizers import AutoTokenizer, SentencePieceTokenizer, TokenizerSpec
import nemo.collections.common.tokenizers as nemo_tokenizers

if isinstance(tokenizer, TokenizerSpec):
if isinstance(tokenizer, AutoTokenizer):
if isinstance(tokenizer, nemo_tokenizers.TokenizerSpec):
if isinstance(tokenizer, nemo_tokenizers.AutoTokenizer):
# Unwrap the original methods of HF tokenizer
batch_decode = tokenizer.tokenizer.batch_decode
encode = tokenizer.tokenizer.encode
elif isinstance(tokenizer, SentencePieceTokenizer):
elif isinstance(tokenizer, nemo_tokenizers.SentencePieceTokenizer):
# Define HF equivalents based on available SP methods
def batch_decode(self, ids):
if torch.is_tensor(ids):
@@ -340,8 +340,8 @@ def batch_decode(self, ids):

tokenizer.bos_token_id = tokenizer.bos_id
tokenizer.eos_token_id = tokenizer.eos_id
TokenizerSpec.encode = encode
TokenizerSpec.batch_decode = batch_decode
nemo_tokenizers.TokenizerSpec.encode = encode
nemo_tokenizers.TokenizerSpec.batch_decode = batch_decode

return tokenizer

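The diff above switches from importing individual tokenizer classes to referencing them through the module, and keeps the existing monkey-patching approach that gives NeMo tokenizers an HF-style encode/batch_decode interface. Below is a minimal, self-contained sketch of that monkey-patching pattern; the toy class names are illustrative stand-ins and not part of the NeMo API.

class TokenizerBase:
    """Stand-in for nemo_tokenizers.TokenizerSpec."""

class ToySentencePieceTokenizer(TokenizerBase):
    """Stand-in for nemo_tokenizers.SentencePieceTokenizer."""

    def text_to_ids(self, text):
        # Pretend SentencePiece-style encoding: one id per character.
        return [ord(c) for c in text]

    def ids_to_text(self, ids):
        return "".join(chr(i) for i in ids)

def encode(self, text):
    return self.text_to_ids(text)

def batch_decode(self, batch_ids):
    return [self.ids_to_text(ids) for ids in batch_ids]

# Patching the base class gives every tokenizer subclass the HF-style methods.
TokenizerBase.encode = encode
TokenizerBase.batch_decode = batch_decode

tok = ToySentencePieceTokenizer()
assert tok.encode("hi") == [104, 105]
assert tok.batch_decode([[104, 105]]) == ["hi"]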
9 changes: 2 additions & 7 deletions nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py
@@ -18,7 +18,6 @@
import warnings
from typing import List, Optional

import tensorrt_llm
from tensorrt_llm.models import PretrainedConfig

from nemo.export.trt_llm.qnemo.utils import CONFIG_NAME, WEIGHTS_NAME
@@ -51,7 +50,7 @@ def qnemo_to_tensorrt_llm(

warnings.warn(
"Note that setting tensor_parallel_size, pipeline_parallel_size and use_parallel_embedding "
" parameters for quantized models is done on calibration step with nemo.export.quantize module."
" parameters for quantized models is done on the calibration step (in PTQ workflow)."
" These parameters are ignored when building and running TensorRT-LLM engine below.",
UserWarning,
stacklevel=3,
@@ -93,11 +92,7 @@ def qnemo_to_tensorrt_llm(
build_cmd += f"--remove_input_padding {'enable' if remove_input_padding else 'disable'} "
build_cmd += f"--multiple_profiles {'enable' if multiple_profiles else 'disable'} "
build_cmd += f"--reduce_fusion {'enable' if reduce_fusion else 'disable'} "
# TODO: resolve version check for setting use_fused_mlp once we move to 0.13.0 in the NeMo container
if tensorrt_llm.__version__ >= "0.13.0":
build_cmd += f"--use_fused_mlp {'enable' if use_fused_mlp else 'disable'} "
else:
build_cmd += "--use_fused_mlp " if use_fused_mlp else ""
build_cmd += f"--use_fused_mlp {'enable' if use_fused_mlp else 'disable'} "

if not use_qdq:
build_cmd += f"--gemm_plugin auto "
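With the old version check gone, all boolean build options now map uniformly onto trtllm-build's enable/disable flags. A small, hedged sketch of that flag-formatting pattern (the subset of flags and the default values here are illustrative):

def format_build_flags(remove_input_padding=True, multiple_profiles=False, use_fused_mlp=True):
    # Render boolean options as enable/disable values, the form TensorRT-LLM >= 0.13.0
    # expects for --use_fused_mlp (which is why the version check could be dropped).
    cmd = "trtllm-build "
    cmd += f"--remove_input_padding {'enable' if remove_input_padding else 'disable'} "
    cmd += f"--multiple_profiles {'enable' if multiple_profiles else 'disable'} "
    cmd += f"--use_fused_mlp {'enable' if use_fused_mlp else 'disable'} "
    return cmd

print(format_build_flags())
# -> trtllm-build --remove_input_padding enable --multiple_profiles disable --use_fused_mlp enable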
8 changes: 5 additions & 3 deletions nemo/export/trt_llm/qnemo/tokenizer_utils.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

from omegaconf import OmegaConf
@@ -24,22 +25,23 @@

TOKENIZER_CONFIG_FILE = "tokenizer_config.yaml"
TOKENIZER_DIR = "tokenizer"
LOGGER = logging.getLogger("NeMo")


def get_nmt_tokenizer(nemo_checkpoint_path: str):
"""Build tokenizer from Nemo tokenizer config."""

print(f"Initializing tokenizer from {TOKENIZER_CONFIG_FILE}")
LOGGER.info(f"Initializing tokenizer from {TOKENIZER_CONFIG_FILE}")
tokenizer_cfg = OmegaConf.load(os.path.join(nemo_checkpoint_path, TOKENIZER_CONFIG_FILE))

library = tokenizer_cfg.library
legacy = tokenizer_cfg.get("sentencepiece_legacy", library == "sentencepiece")

if library == "huggingface":
print(f"Getting HuggingFace AutoTokenizer with pretrained_model_name: {tokenizer_cfg.type}")
LOGGER.info(f"Getting HuggingFace AutoTokenizer with pretrained_model_name: {tokenizer_cfg.type}")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_cfg["type"], use_fast=tokenizer_cfg.get("use_fast", False))
elif library == "sentencepiece":
print(f"Getting SentencePieceTokenizer with model: {tokenizer_cfg.model}")
LOGGER.info(f"Getting SentencePieceTokenizer with model: {tokenizer_cfg.model}")
tokenizer = SentencePieceTokenizer(
model_path=os.path.join(nemo_checkpoint_path, tokenizer_cfg.model), legacy=legacy
)
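The change above is mostly about routing messages through the "NeMo" logger instead of bare print() calls. For reference, a dependency-free sketch of the surrounding dispatch logic, with an illustrative dict standing in for the OmegaConf-loaded tokenizer_config.yaml:

import logging

LOGGER = logging.getLogger("NeMo")

def describe_tokenizer(tokenizer_cfg: dict) -> str:
    # tokenizer_cfg stands in for OmegaConf.load(".../tokenizer_config.yaml").
    library = tokenizer_cfg["library"]
    if library == "huggingface":
        LOGGER.info("Getting HuggingFace AutoTokenizer with pretrained_model_name: %s", tokenizer_cfg["type"])
        return "huggingface"
    elif library == "sentencepiece":
        LOGGER.info("Getting SentencePieceTokenizer with model: %s", tokenizer_cfg["model"])
        return "sentencepiece"
    raise ValueError(f"Unsupported tokenizer library: {library}")

logging.basicConfig(level=logging.INFO)
describe_tokenizer({"library": "sentencepiece", "model": "tokenizer.model"})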
6 changes: 5 additions & 1 deletion nemo/export/vllm_exporter.py
@@ -222,7 +222,6 @@ def export(
max_num_seqs=256,
# Note: max_model_len can be derived by model_config if the input value is None
max_model_len=model_config.max_model_len,
use_v2_block_manager=False,
num_lookahead_slots=0,
delay_factor=0.0,
enable_chunked_prefill=False,
@@ -403,6 +402,7 @@ def get_triton_input(self):
Tensor(name="top_p", shape=(-1,), dtype=numpy.single, optional=True),
Tensor(name="temperature", shape=(-1,), dtype=numpy.single, optional=True),
Tensor(name="lora_uids", shape=(-1,), dtype=bytes, optional=True),
Tensor(name="output_generation_logits", shape=(-1,), dtype=numpy.bool_, optional=True),
)
return inputs

@@ -455,6 +455,7 @@ def forward(
prompt_embeddings_checkpoint_path: Optional[str] = None,
streaming: bool = False,
output_log_probs: bool = False,
output_generation_logits: bool = False,
) -> Union[List[List[str]], Iterable[List[List[str]]]]:
"""
The forward function performs LLM evaluation on the provided array of prompts with other parameters shared,
@@ -484,6 +485,9 @@
if output_log_probs:
raise NotImplementedError("output_log_probs is not supported")

if output_generation_logits:
raise NotImplementedError("output_generation_logits is not supported")

request_ids = []
for index in range(len(input_texts)):
prompt = input_texts[index]
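The exporter now advertises output_generation_logits as an optional boolean Triton input, while forward() still rejects it. A hedged client-side sketch follows, assuming PyTriton's ModelClient; the server address, model name, and the prompts/max_output_len input names are assumptions for illustration (only the inputs visible in the diff are confirmed), and sending True is expected to fail with NotImplementedError for now.

import numpy as np
from pytriton.client import ModelClient

# Assumed deployment endpoint and model name; adjust to the actual Triton setup.
with ModelClient("localhost:8000", "nemo_vllm_model") as client:
    result = client.infer_batch(
        prompts=np.array([["What is the capital of France?".encode("utf-8")]]),
        max_output_len=np.array([[32]], dtype=np.int_),
        temperature=np.array([[0.2]], dtype=np.single),
        # New optional flag from this commit; keep it False until supported.
        output_generation_logits=np.array([[False]], dtype=np.bool_),
    )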
2 changes: 1 addition & 1 deletion requirements/requirements_multimodal.txt
@@ -5,7 +5,7 @@ diffusers>=0.19.3
einops_exts
imageio
kornia
megatron-energon
megatron-energon<3.0.0
nerfacc>=0.5.3
open_clip_torch==2.24.0
PyMCubes
20 changes: 19 additions & 1 deletion requirements/requirements_vllm.txt
@@ -1 +1,19 @@
vllm==0.5.3.post1
# Minimal set of NeMo requirements to run vLLM export & deployment in /opt/venv in a NeMo container
braceexpand
faiss-cpu
h5py
hydra-core>1.3,<=1.3.2
ijson
jieba
lightning>2.2.1
matplotlib>=3.3.2
omegaconf<=2.3
onnx>=1.7.0
OpenCC
pangu
rouge_score
sacrebleu
scikit-learn
vllm==0.6.3
webdataset>=0.2.86
wget
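The file now pins vllm==0.6.3 plus a minimal NeMo dependency set for the /opt/venv environment in a NeMo container. A small, purely illustrative sanity check that the active environment matches the pin:

import importlib.metadata

# Assumes vllm was installed from requirements_vllm.txt into the current environment.
installed = importlib.metadata.version("vllm")
assert installed == "0.6.3", f"requirements_vllm.txt pins vllm==0.6.3, found {installed}"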
2 changes: 1 addition & 1 deletion scripts/deploy/nlp/deploy_vllm_triton.py
@@ -41,7 +41,7 @@ def get_args(argv):
"-mt",
"--model_type",
type=str,
required=False,
required=True,
choices=["llama", "mistral", "mixtral", "starcoder2", "gemma"],
help="Type of the model",
)
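Making --model_type required means a missing value now fails at argument parsing rather than later inside the vLLM exporter. A minimal sketch of the argparse behaviour (script wiring beyond this flag is omitted):

import argparse

parser = argparse.ArgumentParser(description="Deploy vLLM model to Triton (sketch)")
parser.add_argument(
    "-mt",
    "--model_type",
    type=str,
    required=True,
    choices=["llama", "mistral", "mixtral", "starcoder2", "gemma"],
    help="Type of the model",
)

# parser.parse_args([]) would now exit with:
#   error: the following arguments are required: -mt/--model_type
args = parser.parse_args(["--model_type", "llama"])
print(args.model_type)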
3 changes: 3 additions & 0 deletions tests/export/nemo_export.py
@@ -849,6 +849,9 @@ def run_inference_tests(args):
"Use the same value for --min_tps and --max_tps."
)

if args.debug:
LOGGER.setLevel(logging.DEBUG)

result_dic: Dict[int, Tuple[FunctionalResult, Optional[AccuracyResult]]] = {}

if args.existing_test_models:
Expand Down
