Commit

Squash 10400
Signed-off-by: Jefferson Fialho <[email protected]>
fialhocoelho committed Nov 18, 2024
1 parent b715908 commit 6aa7422
Showing 16 changed files with 646 additions and 86 deletions.
5 changes: 5 additions & 0 deletions vllm/config.py
@@ -696,6 +696,11 @@ def uses_mrope(self) -> bool:
    def is_multimodal_model(self) -> bool:
        return self.multimodal_config is not None

    @property
    def is_cross_encoder(self) -> bool:
        architectures = getattr(self.hf_config, "architectures", [])
        return ModelRegistry.is_cross_encoder_model(architectures)


class CacheConfig:
    """Configuration for the KV cache.
1 change: 1 addition & 0 deletions vllm/core/scheduler.py
@@ -1452,6 +1452,7 @@ def schedule(
                    encoder_seq_data=encoder_seq_data,
                    cross_block_table=cross_block_table,
                    state=seq_group.state,
                    token_type_ids=seq_group.token_type_ids,
                    # `multi_modal_data` will only be present for the 1st comm
                    # between engine and worker.
                    # the subsequent comms can still use delta, but
107 changes: 106 additions & 1 deletion vllm/entrypoints/llm.py
@@ -18,7 +18,7 @@
apply_mistral_chat_template,
parse_chat_messages,
resolve_chat_template_content_format)
-from vllm.inputs import PromptType, TextPrompt, TokensPrompt
+from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -804,6 +804,111 @@ def encode(
        return self.engine_class.validate_outputs(outputs,
                                                  EmbeddingRequestOutput)

    def score(
        self,
        query: SingletonPrompt,
        texts: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates similarity scores for all pairs ``<query, text>``.

        This method pairs the input query with each of the texts to build
        the prompts for the cross-encoder model. Prompts are batched
        automatically, taking memory constraints into account. For best
        performance, put all of your texts into a single list and pass it
        to this method.

        Args:
            query: The query to compare against all other text inputs.
            texts: The texts to pair with the query to form the inputs
                to the LLM. You may pass a sequence of texts for batch
                inference. See :class:`~vllm.inputs.PromptType` for more
                details about the format of each prompt.
            truncate_prompt_tokens: If set, truncate each tokenized pair
                to at most this many tokens.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.
        """
        task = self.llm_engine.model_config.task
        if task != "embedding":
            messages = ["LLM.score() is only supported for embedding models."]

            supported_tasks = self.llm_engine.model_config.supported_tasks
            if "embedding" in supported_tasks:
                messages.append(
                    "Your model supports the 'embedding' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task embedding`.")

            raise ValueError(" ".join(messages))

        if not self.llm_engine.model_config.is_cross_encoder:
            raise ValueError("Your model does not support cross encoding.")

        tokenizer = self.llm_engine.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "MistralTokenizer not supported for cross-encoding")

        # The tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs.
        def ensure_str(prompt: SingletonPrompt):
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for cross encoding")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            assert type(prompt) is str
            return prompt

        query = ensure_str(query)
        if isinstance(texts, (str, dict)):
            # Convert a single prompt to a list.
            texts = [texts]

        input_pairs = [(query, ensure_str(t)) for t in texts]
        pooling_params = PoolingParams()

        tokenization_kwargs: Dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  EmbeddingRequestOutput)

    def start_profile(self) -> None:
        self.llm_engine.start_profile()

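For reference, a minimal sketch of how the new LLM.score() entrypoint could be called once this commit is in place. The task="embedding" constructor argument, the example model name (taken from the comment inside score()), and reading the score from outputs.embedding are assumptions based on the surrounding code, not part of this diff:

from vllm import LLM

# Hypothetical usage; the checkpoint name mirrors the comment in score().
llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="embedding")

query = "What is the capital of France?"
texts = [
    "Paris is the capital and most populous city of France.",
    "The Eiffel Tower is located in Paris.",
]

# score() pairs the query with each text and returns one output per pair,
# in the same order as the input texts.
outputs = llm.score(query, texts)
for text, output in zip(texts, outputs):
    # For a cross-encoder, the pooled output is expected to hold the score.
    print(text, output.outputs.embedding)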
123 changes: 96 additions & 27 deletions vllm/entrypoints/openai/serving_embedding.py
@@ -1,24 +1,34 @@
import asyncio
import base64
import time
-from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast
+from typing import (Annotated, Any, AsyncGenerator, Dict, Final, List, Literal,
+                    Optional, Sequence, Tuple, Union, cast)

import numpy as np
from fastapi import Request
+from pydantic import Field
from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
+from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
+                                         ChatTemplateContentFormatOption,
+                                         ConversationMessage,
+                                         parse_chat_messages_futures)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                              EmbeddingRequest,
                                              EmbeddingResponse,
                                              EmbeddingResponseData,
                                              ErrorResponse, UsageInfo)
-from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
+                                                    OpenAIServing,
+                                                    RequestPrompt)
+from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import merge_async_iterators, random_uuid

logger = init_logger(__name__)
@@ -136,31 +146,49 @@ async def create_embedding(
            raise NotImplementedError("Prompt adapter is not supported "
                                      "for embedding models")

-            if isinstance(request, EmbeddingChatRequest):
-                (
-                    _,
-                    request_prompts,
-                    engine_prompts,
-                ) = await self._preprocess_chat(
-                    request,
-                    tokenizer,
-                    request.messages,
-                    chat_template=request.chat_template or self.chat_template,
-                    chat_template_content_format=self.
-                    chat_template_content_format,
-                    add_generation_prompt=request.add_generation_prompt,
-                    continue_final_message=request.continue_final_message,
-                    truncate_prompt_tokens=truncate_prompt_tokens,
-                    add_special_tokens=request.add_special_tokens,
-                )
+            if self.model_config.is_cross_encoder:
+                if isinstance(request, EmbeddingChatRequest):
+                    (
+                        _,
+                        request_prompts,
+                        engine_prompts,
+                    ) = await self._preprocess_cross_encoding(
+                        tokenizer,
+                        request.messages,
+                        truncate_prompt_tokens=truncate_prompt_tokens,
+                    )
+                else:
+                    return self.create_error_response(
+                        "Cross encoding requests must "
+                        "use the chat embedding API")
             else:
-                request_prompts, engine_prompts = self._preprocess_completion(
-                    request,
-                    tokenizer,
-                    request.input,
-                    truncate_prompt_tokens=truncate_prompt_tokens,
-                    add_special_tokens=request.add_special_tokens,
-                )
+                if isinstance(request, EmbeddingChatRequest):
+                    (
+                        _,
+                        request_prompts,
+                        engine_prompts,
+                    ) = await self._preprocess_chat(
+                        request,
+                        tokenizer,
+                        request.messages,
+                        chat_template=request.chat_template
+                        or self.chat_template,
+                        chat_template_content_format=self.
+                        chat_template_content_format,
+                        add_generation_prompt=request.add_generation_prompt,
+                        continue_final_message=request.continue_final_message,
+                        truncate_prompt_tokens=truncate_prompt_tokens,
+                        add_special_tokens=request.add_special_tokens,
+                    )
+                else:
+                    request_prompts, engine_prompts = self\
+                        ._preprocess_completion(
+                            request,
+                            tokenizer,
+                            request.input,
+                            truncate_prompt_tokens=truncate_prompt_tokens,
+                            add_special_tokens=request.add_special_tokens,
+                        )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))
@@ -225,3 +253,44 @@ async def create_embedding(
            return self.create_error_response(str(e))

        return response

    async def _preprocess_cross_encoding(
        self,
        tokenizer: AnyTokenizer,
        messages: List[ChatCompletionMessageParam],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=3)]] = None,
    ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt],
               List[TokensPrompt]]:

        conversation, mm_data_future = parse_chat_messages_futures(
            messages, self.model_config, tokenizer, "string")
        await mm_data_future

        if len(conversation) != 2:
            raise ValueError(
                "For cross encoding, exactly two inputs must be provided")

        prompts = []
        for msg in conversation:
            content = msg["content"]
            assert type(content) is str
            prompts.append(content)

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "MistralTokenizer not supported for cross-encoding")

        request_prompt = f"{prompts[0]}{tokenizer.sep_token}{prompts[1]}"

        tokenization_kwargs: Dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

        prompt_inputs = tokenizer(prompts[0],
                                  text_pair=prompts[1],
                                  **tokenization_kwargs)

        engine_prompt = TokensPrompt(
            prompt_token_ids=prompt_inputs["input_ids"],
            token_type_ids=prompt_inputs.get("token_type_ids"))

        return conversation, [request_prompt], [engine_prompt]
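For reference, a sketch of what a cross-encoding request against the OpenAI-compatible server might look like once this handler is in place. The /v1/embeddings route, the vllm serve command line, and the response shape are assumptions based on the existing embedding API, not part of this diff; the two-message constraint comes from the length check above:

import requests

# Assumes a server started along the lines of:
#   vllm serve cross-encoder/ms-marco-MiniLM-L-6-v2 --task embedding
# Cross-encoding must go through the chat-style embedding request and
# carry exactly two messages: the query and the candidate text.
payload = {
    "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "user", "content": "Paris is the capital of France."},
    ],
}
resp = requests.post("http://localhost:8000/v1/embeddings", json=payload)
resp.raise_for_status()
# The pooled score for the pair is expected in the usual embedding field.
print(resp.json()["data"][0]["embedding"])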
18 changes: 18 additions & 0 deletions vllm/inputs/data.py
@@ -38,6 +38,9 @@ class TokensPrompt(TypedDict):
    prompt_token_ids: List[int]
    """A list of token IDs to pass to the model."""

    token_type_ids: NotRequired[List[int]]
    """A list of token type IDs to pass to the cross encoder model."""

    multi_modal_data: NotRequired["MultiModalDataDict"]
    """
    DEPRECATED: Optional multi-modal data to pass to the model,
@@ -133,6 +136,9 @@ class TokenInputs(TypedDict):
    prompt_token_ids: List[int]
    """The token IDs of the prompt."""

    token_type_ids: NotRequired[List[int]]
    """The token type IDs of the prompt."""

    prompt: NotRequired[str]
    """
    The original prompt text corresponding to the token IDs, if available.
@@ -160,6 +166,7 @@ class TokenInputs(TypedDict):

def token_inputs(
    prompt_token_ids: List[int],
    token_type_ids: Optional[List[int]] = None,
    prompt: Optional[str] = None,
    multi_modal_data: Optional["MultiModalDataDict"] = None,
    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
@@ -170,6 +177,8 @@ def token_inputs(

    if prompt is not None:
        inputs["prompt"] = prompt
    if token_type_ids is not None:
        inputs["token_type_ids"] = token_type_ids
    if multi_modal_data is not None:
        inputs["multi_modal_data"] = multi_modal_data
    if multi_modal_placeholders is not None:
@@ -234,6 +243,15 @@ def prompt_token_ids(self) -> List[int]:

        assert_never(inputs)

    @cached_property
    def token_type_ids(self) -> List[int]:
        inputs = self.inputs

        if inputs["type"] == "token" or inputs["type"] == "multimodal":
            return inputs.get("token_type_ids", [])

        assert_never(inputs)

    @cached_property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        inputs = self.inputs
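To illustrate what the new token_type_ids field carries, a minimal sketch using a Hugging Face tokenizer for a BERT-style cross-encoder (the checkpoint name is illustrative, taken from the comment in LLM.score() above):

from transformers import AutoTokenizer

from vllm.inputs.data import TokensPrompt

tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")

# Encoding a (query, text) pair yields token_type_ids of 0 for the query
# segment and 1 for the paired text segment.
enc = tokenizer(text="what is a panda?",
                text_pair="The giant panda is a bear native to China.")

prompt = TokensPrompt(prompt_token_ids=enc["input_ids"],
                      token_type_ids=enc.get("token_type_ids"))
print(prompt["token_type_ids"])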
2 changes: 2 additions & 0 deletions vllm/inputs/preprocess.py
@@ -305,6 +305,7 @@ def _prompt_to_llm_inputs(
        tokens_content = parsed["content"]

        prompt_token_ids = tokens_content["prompt_token_ids"]
        token_type_ids = tokens_content.get("token_type_ids")
        multi_modal_data = tokens_content.get("multi_modal_data")
        mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")

@@ -318,6 +319,7 @@

        return token_inputs(
            prompt_token_ids=prompt_token_ids,
            token_type_ids=token_type_ids,
            multi_modal_data=multi_modal_data,
            mm_processor_kwargs=mm_processor_kwargs,
        )
