Vllm 0.6.0 integration test (#10697)
* fixes for vllm 0.5.5

Signed-off-by: HuiyingLi <[email protected]>

* vllm 0.6.0 updates

Signed-off-by: HuiyingLi <[email protected]>

* change --ptuning --lora arg

Signed-off-by: HuiyingLi <[email protected]>

* convert ptuning and lora to bool

Signed-off-by: HuiyingLi <[email protected]>

* parse arg being empty string to be false

Signed-off-by: HuiyingLi <[email protected]>

* cleanup and format

Signed-off-by: HuiyingLi <[email protected]>

---------

Signed-off-by: HuiyingLi <[email protected]>
Co-authored-by: Onur Yilmaz <[email protected]>
HuiyingLi and oyilmaz-nvidia authored Oct 4, 2024
1 parent 7877766 commit 3578f75
Showing 5 changed files with 30 additions and 10 deletions.
7 changes: 6 additions & 1 deletion nemo/export/vllm/model_config.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Union
+from typing import Any, Dict, Optional, Union
 
 import torch
 import yaml
@@ -39,6 +39,7 @@ def __init__(
         dtype: Union[str, torch.dtype],
         seed: int,
         revision: Optional[str] = None,
+        override_neuron_config: Optional[Dict[str, Any]] = None,
         code_revision: Optional[str] = None,
         rope_scaling: Optional[dict] = None,
         rope_theta: Optional[float] = None,
@@ -50,6 +51,7 @@
         max_seq_len_to_capture: Optional[int] = None,
         max_logprobs: int = 5,
         disable_sliding_window: bool = False,
+        use_async_output_proc: bool = False,
     ) -> None:
         # Don't call ModelConfig.__init__ because we don't want it to call
         # transformers.AutoConfig.from_pretrained(...)
@@ -67,6 +69,7 @@
         self.seed = seed
         self.revision = revision
         self.code_revision = code_revision
+        self.override_neuron_config = override_neuron_config
         self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
         self.tokenizer_revision = tokenizer_revision
@@ -77,6 +80,8 @@
         self.max_logprobs = max_logprobs
         self.disable_sliding_window = disable_sliding_window
         self.served_model_name = nemo_checkpoint
+        self.multimodal_config = None
+        self.use_async_output_proc = use_async_output_proc
 
         self.model_converter = get_model_converter(model_type)
         if self.model_converter is None:
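
Editor's note on this file: NemoModelConfig deliberately skips ModelConfig.__init__ (to avoid transformers.AutoConfig.from_pretrained), so every attribute that vLLM 0.6.0 code paths expect to find must be assigned by hand; that is why the hunks above add override_neuron_config, use_async_output_proc, and an explicit multimodal_config = None. A minimal sketch of the pattern, using hypothetical stand-in names rather than the real vLLM classes:

```python
# Sketch only: StubModelConfig is a hypothetical stand-in, not vLLM API.
from typing import Any, Dict, Optional

class StubModelConfig:
    """Mimics a config that bypasses its parent's __init__ and therefore
    must define every attribute newer engine code will read."""

    def __init__(
        self,
        override_neuron_config: Optional[Dict[str, Any]] = None,
        use_async_output_proc: bool = False,
    ) -> None:
        self.override_neuron_config = override_neuron_config
        self.use_async_output_proc = use_async_output_proc
        # Text-only export path: set explicitly so attribute reads don't fail.
        self.multimodal_config = None

cfg = StubModelConfig()
assert cfg.use_async_output_proc is False and cfg.multimodal_config is None
```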
3 changes: 1 addition & 2 deletions nemo/export/vllm/model_loader.py
@@ -74,7 +74,6 @@ def load_model(
         model_config: NemoModelConfig,
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
-        multimodal_config: Optional[MultiModalConfig],
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         cache_config: CacheConfig,
@@ -88,7 +87,7 @@
 
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
-                model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config)
+                model = _initialize_model(model_config, self.load_config, lora_config, cache_config)
 
         weights_iterator = model_config.model_converter.convert_weights(model_config.nemo_model_config, state_dict)
 
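
The change here just tracks a vLLM-internal signature: in 0.6.0, _initialize_model no longer takes a separate multimodal_config argument (multimodal settings moved onto the model config). A stub sketch of the arity change; the stand-in function below is illustrative, not the real loader:

```python
# Hypothetical stub mirroring the 0.6.0 call site; the real _initialize_model
# is vLLM-internal and takes real config objects.
def initialize_model_stub(model_config, load_config, lora_config, cache_config):
    return {"config": model_config}

# The 0.5.x call passed multimodal_config between lora_config and cache_config;
# with the 0.6.0 arity that extra positional argument would raise a TypeError.
model = initialize_model_stub("model_cfg", "load_cfg", None, "cache_cfg")
```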
5 changes: 5 additions & 0 deletions nemo/export/vllm/tokenizer_group.py
@@ -14,6 +14,7 @@
 
 from typing import List, Optional
 
+from vllm.config import TokenizerPoolConfig
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import BaseTokenizerGroup
 
@@ -29,6 +30,10 @@ def __init__(self, tokenizer: SentencePieceTokenizer, add_bos_token: bool = Fals
         self.tokenizer = tokenizer
         self.add_bos_token = add_bos_token
 
+    @classmethod
+    def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, **init_kwargs):
+        raise NotImplementedError
+
     def ping(self) -> bool:
         return True
 
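
In vLLM 0.6.0 the tokenizer-group base class declares a from_config factory classmethod, so the NeMo group must define it even though NeMo always constructs the group directly from its own SentencePiece tokenizer; the stub simply refuses pool-based construction. A self-contained sketch of that shape, with stand-in classes in place of the vLLM types:

```python
# Stand-ins only: TokenizerPoolConfigStub replaces vllm.config.TokenizerPoolConfig.
from typing import Optional

class TokenizerPoolConfigStub:
    pass

class NemoTokenizerGroupSketch:
    def __init__(self, tokenizer, add_bos_token: bool = False):
        self.tokenizer = tokenizer
        self.add_bos_token = add_bos_token

    @classmethod
    def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfigStub] = None, **init_kwargs):
        # Direct construction is the only supported path, so the factory
        # entry point required by the interface stays unimplemented.
        raise NotImplementedError

    def ping(self) -> bool:
        return True

group = NemoTokenizerGroupSketch(tokenizer=None)
assert group.ping()
```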
11 changes: 9 additions & 2 deletions nemo/export/vllm_exporter.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import logging
 import os.path
 from typing import Iterable, List, Optional, Union
@@ -147,6 +148,11 @@ def export(
             pipeline_parallel_size=pipeline_parallel_size, tensor_parallel_size=tensor_parallel_size
         )
 
+        # vllm/huggingface doesn't like the absence of a config file. Place config in load dir.
+        if model_config.model and not os.path.exists(os.path.join(model_config.model, 'config.json')):
+            with open(os.path.join(model_config.model, 'config.json'), "w") as f:
+                json.dump(model_config.hf_text_config.to_dict(), f)
+
         # See if we have an up-to-date safetensors file
         safetensors_file = os.path.join(model_config.model, 'model.safetensors')
         safetensors_file_valid = os.path.exists(safetensors_file) and os.path.getmtime(
@@ -248,7 +254,6 @@ export(
             device_config=device_config,
             load_config=load_config,
             lora_config=lora_config,
-            multimodal_config=None,
             speculative_config=None,
             decoding_config=None,
             observability_config=None,
@@ -295,7 +300,9 @@ def _add_request_to_engine(
         if top_p <= 0.0:
             top_p = 1.0
 
-        sampling_params = SamplingParams(max_tokens=max_output_len, temperature=temperature, top_k=top_k, top_p=top_p)
+        sampling_params = SamplingParams(
+            max_tokens=max_output_len, temperature=temperature, top_k=int(top_k), top_p=top_p
+        )
 
         if lora_uid is not None and lora_uid >= 0 and lora_uid < len(self.lora_checkpoints):
             lora_request = LoRARequest(
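
Two hunks above deserve a gloss: writing a minimal config.json into the load directory satisfies the HuggingFace-style loading path when only a NeMo checkpoint exists, and int(top_k) guards against the value arriving as a float (the test harness parses numeric args that can come through as floats) before it reaches vLLM's sampler. A short sketch of the sampling side, assuming a local vLLM install:

```python
# Assumes vLLM is installed; SamplingParams is its public sampling config.
from vllm import SamplingParams

top_k = 1.0  # e.g. parsed from JSON or argparse as a float
params = SamplingParams(
    max_tokens=64,
    temperature=0.8,
    top_k=int(top_k),  # coerce so the sampler sees an integer, as in the hunk above
    top_p=0.95,
)
print(params.top_k)  # 1
```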
14 changes: 9 additions & 5 deletions tests/export/nemo_export.py
@@ -436,7 +436,7 @@ def run_inference(
         if test_deployment:
             nm.stop()
 
-        if not save_trt_engine:
+        if not save_trt_engine and model_dir:
             shutil.rmtree(model_dir)
 
     return (functional_result, accuracy_result)
@@ -677,17 +677,17 @@ def get_args():
     )
     parser.add_argument(
         "--ptuning",
-        default=False,
-        action='store_true',
+        type=str,
+        default="False",
     )
     parser.add_argument(
         "--lora_checkpoint",
         type=str,
     )
     parser.add_argument(
         "--lora",
-        default=False,
-        action='store_true',
+        type=str,
+        default="False",
     )
     parser.add_argument(
         "--top_k",
@@ -784,6 +784,8 @@ def str_to_bool(name: str, s: str, optional: bool = False) -> Optional[bool]:
     s = s.lower()
     true_strings = ["true", "1"]
     false_strings = ["false", "0"]
+    if s == '':
+        return False
     if s in true_strings:
         return True
     if s in false_strings:
@@ -798,6 +800,8 @@
     args.save_trt_engine = str_to_bool("save_trt_engin", args.save_trt_engine)
     args.run_accuracy = str_to_bool("run_accuracy", args.run_accuracy)
     args.use_vllm = str_to_bool("use_vllm", args.use_vllm)
+    args.lora = str_to_bool("lora", args.lora)
+    args.ptuning = str_to_bool("ptuning", args.ptuning)
     args.use_parallel_embedding = str_to_bool("use_parallel_embedding", args.use_parallel_embedding)
     args.in_framework = str_to_bool("in_framework", args.in_framework)
     args.export_fp8_quantized = str_to_bool("export_fp8_quantized", args.export_fp8_quantized, optional=True)
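
Net effect of this test-file change: --ptuning and --lora stop being store_true flags and become string arguments routed through str_to_bool, so CI can pass explicit values, including an empty string, which the new branch maps to False instead of raising. A self-contained sketch of the resulting behavior (the helper below is simplified from the one in the diff):

```python
import argparse
from typing import Optional

def str_to_bool(name: str, s: str, optional: bool = False) -> Optional[bool]:
    # Simplified mirror of the test helper, with the new empty-string branch.
    if optional and s is None:
        return None
    s = s.lower()
    if s == '':
        return False  # e.g. --lora "" now means False rather than an error
    if s in ("true", "1"):
        return True
    if s in ("false", "0"):
        return False
    raise ValueError(f"Invalid boolean value '{s}' for --{name}")

parser = argparse.ArgumentParser()
parser.add_argument("--ptuning", type=str, default="False")
parser.add_argument("--lora", type=str, default="False")

args = parser.parse_args(["--lora", ""])
args.lora = str_to_bool("lora", args.lora)
args.ptuning = str_to_bool("ptuning", args.ptuning)
print(args.lora, args.ptuning)  # False False
```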
