Skip to content

Commit

Permalink
Fix P-tuning for Llama based models (#9300)
Browse files Browse the repository at this point in the history
* Fix P-tuning for Llama based models (#9297)

* Added the BOS token for Llama, Mistral and Mixtral.

Signed-off-by: Alexey Panteleev <[email protected]>

* Don't load an existing TRT-LLM model before export to speed up the export process and avoid possible contamination from previous runs.

Signed-off-by: Alexey Panteleev <[email protected]>

* Apply isort and black reformatting

Signed-off-by: apanteleev <[email protected]>

---------

Signed-off-by: Alexey Panteleev <[email protected]>
Signed-off-by: apanteleev <[email protected]>
Co-authored-by: apanteleev <[email protected]>
Co-authored-by: Onur Yilmaz <[email protected]>

* Fix the export test

---------

Signed-off-by: Alexey Panteleev <[email protected]>
Signed-off-by: apanteleev <[email protected]>
Signed-off-by: Onur Yilmaz <[email protected]>
Co-authored-by: Alexey Panteleev <[email protected]>
Co-authored-by: apanteleev <[email protected]>
Co-authored-by: Onur Yilmaz <[email protected]>
Signed-off-by: Jan Lasek <[email protected]>
  • Loading branch information
4 people authored and janekl committed Jun 12, 2024
1 parent 474e00e commit 1fe50da
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 3 deletions.
8 changes: 7 additions & 1 deletion nemo/export/trt_llm/tensorrt_llm_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,13 @@ def load(

max_batch_size = config["build_config"]["max_batch_size"]
max_input_len = config["build_config"]["max_input_len"]
add_bos = True if config["pretrained_config"]["architecture"] == "GemmaForCausalLM" else False
architectures_that_need_bos_token = [
"GemmaForCausalLM",
"LLaMAForCausalLM",
"MistralForCausalLM",
"MixtralForCausalLM",
]
add_bos = config["pretrained_config"]["architecture"] in architectures_that_need_bos_token

return TensorrtLLMHostContext(
executor=executor,
Expand Down
1 change: 1 addition & 0 deletions scripts/deploy/nlp/deploy_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def nemo_deploy(argv):
trt_llm_exporter = TensorRTLLM(
model_dir=trt_llm_path,
lora_ckpt_list=args.lora_ckpt,
load_model=(args.nemo_checkpoint is None),
use_python_runtime=(not args.use_cpp_runtime),
)

Expand Down
2 changes: 1 addition & 1 deletion scripts/export/export_to_trt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def nemo_export_trt_llm(argv):
return

try:
trt_llm_exporter = TensorRTLLM(model_dir=args.model_repository)
trt_llm_exporter = TensorRTLLM(model_dir=args.model_repository, load_model=False)

LOGGER.info("Export to TensorRT-LLM function is called.")
trt_llm_exporter.export(
Expand Down
2 changes: 1 addition & 1 deletion tests/export/test_nemo_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def run_trt_llm_inference(
print("---- LoRA could not be enabled and skipping the test.")
return None, None, None, None, None

trt_llm_exporter = TensorRTLLM(trt_llm_model_dir, lora_ckpt_list)
trt_llm_exporter = TensorRTLLM(trt_llm_model_dir, lora_ckpt_list, load_model=False)

trt_llm_exporter.export(
nemo_checkpoint_path=checkpoint_path,
Expand Down

0 comments on commit 1fe50da

Please sign in to comment.