From 6aa7422cbbd616a4b9ca435e039dde3dc73e2921 Mon Sep 17 00:00:00 2001 From: Jefferson Fialho Date: Mon, 18 Nov 2024 19:12:38 -0300 Subject: [PATCH] Squash 10400 Signed-off-by: Jefferson Fialho --- vllm/config.py | 5 + vllm/core/scheduler.py | 1 + vllm/entrypoints/llm.py | 107 +++++++++- vllm/entrypoints/openai/serving_embedding.py | 123 +++++++++--- vllm/inputs/data.py | 18 ++ vllm/inputs/preprocess.py | 2 + vllm/model_executor/models/bert.py | 146 ++++++++++++-- vllm/model_executor/models/interfaces.py | 36 ++++ vllm/model_executor/models/registry.py | 22 ++- vllm/model_executor/models/roberta.py | 194 +++++++++++++++---- vllm/multimodal/inputs.py | 5 +- vllm/sequence.py | 9 + vllm/worker/cpu_embedding_model_runner.py | 4 + vllm/worker/cpu_model_runner.py | 25 ++- vllm/worker/embedding_model_runner.py | 7 +- vllm/worker/model_runner.py | 28 +++ 16 files changed, 646 insertions(+), 86 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 8d0e7353c9e09..3093ca535cdda 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -696,6 +696,11 @@ def uses_mrope(self) -> bool: def is_multimodal_model(self) -> bool: return self.multimodal_config is not None + @property + def is_cross_encoder(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return ModelRegistry.is_cross_encoder_model(architectures) + class CacheConfig: """Configuration for the KV cache. diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index a6598118badf6..0c62de78c9f0c 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1452,6 +1452,7 @@ def schedule( encoder_seq_data=encoder_seq_data, cross_block_table=cross_block_table, state=seq_group.state, + token_type_ids=seq_group.token_type_ids, # `multi_modal_data` will only be present for the 1st comm # between engine and worker. # the subsequent comms can still use delta, but diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 86b0b6893f1d9..9c7c8e3b2aa13 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -18,7 +18,7 @@ apply_mistral_chat_template, parse_chat_messages, resolve_chat_template_content_format) -from vllm.inputs import PromptType, TextPrompt, TokensPrompt +from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -804,6 +804,111 @@ def encode( return self.engine_class.validate_outputs(outputs, EmbeddingRequestOutput) + def score( + self, + query: SingletonPrompt, + texts: Union[SingletonPrompt, Sequence[SingletonPrompt]], + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[EmbeddingRequestOutput]: + """Generates similarity scores for all pairs . + + This method pairs the input query with each of the texts to generate + a list of prompts for the cross encoder model. This class automatically + batches the prompts, considering the memory constraint. For the best + performance, put all of your texts into a single list and pass it to + this method. + + Args: + query: The query to compare against all other text input + texts: The texts to pair with the query to form the input + to the LLM. You may pass a sequence of texts for batch + inference. See :class:`~vllm.inputs.PromptType` for more + details about the format of each prompts. 
+            use_tqdm: Whether to use tqdm to display the progress bar.
+            lora_request: LoRA request to use for generation, if any.
+            prompt_adapter_request: Prompt Adapter request to use for
+                generation, if any.
+
+        Returns:
+            A list of ``EmbeddingRequestOutput`` objects containing the
+            generated scores in the same order as the input prompts.
+        """
+        task = self.llm_engine.model_config.task
+        if task != "embedding":
+            messages = ["LLM.score() is only supported for embedding models."]
+
+            supported_tasks = self.llm_engine.model_config.supported_tasks
+            if "embedding" in supported_tasks:
+                messages.append(
+                    "Your model supports the 'embedding' task, but is "
+                    f"currently initialized for the '{task}' task. Please "
+                    "initialize the model using `--task embedding`.")
+
+            raise ValueError(" ".join(messages))
+
+        if not self.llm_engine.model_config.is_cross_encoder:
+            raise ValueError("Your model does not support cross encoding")
+
+        tokenizer = self.llm_engine.get_tokenizer()
+
+        if isinstance(tokenizer, MistralTokenizer):
+            raise ValueError(
+                "MistralTokenizer not supported for cross-encoding")
+
+        # the tokenizer for models such as
+        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
+        # lists of tokens to the `text` and `text_pair` kwargs
+        def ensure_str(prompt: SingletonPrompt):
+            if isinstance(prompt, dict):
+                if "multi_modal_data" in prompt:
+                    raise ValueError("Multi-modal prompt is not "
+                                     "supported for cross encoding")
+                elif "prompt_token_ids" in prompt:
+                    prompt = tokenizer.decode(
+                        cast(TokensPrompt, prompt)["prompt_token_ids"])
+                elif "prompt" in prompt:
+                    prompt = cast(TextPrompt, prompt)["prompt"]
+            assert type(prompt) is str
+            return prompt
+
+        query = ensure_str(query)
+        if isinstance(texts, (str, dict)):
+            # Convert a single prompt to a list.
+ texts = [texts] + + input_pairs = [(query, ensure_str(t)) for t in texts] + pooling_params = PoolingParams() + + tokenization_kwargs: Dict[str, Any] = {} + if truncate_prompt_tokens is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + parsed_prompts = [] + + for q, t in input_pairs: + prompt_inputs = tokenizer(text=q, + text_pair=t, + **tokenization_kwargs) + engine_prompt = TokensPrompt( + prompt_token_ids=prompt_inputs["input_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + parsed_prompts.append(engine_prompt) + + self._validate_and_add_requests( + prompts=parsed_prompts, + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return self.engine_class.validate_outputs(outputs, + EmbeddingRequestOutput) + def start_profile(self) -> None: self.llm_engine.start_profile() diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 74ad7389784fc..91aa070b056d8 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -1,24 +1,34 @@ import asyncio import base64 import time -from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast +from typing import (Annotated, Any, AsyncGenerator, Dict, Final, List, Literal, + Optional, Sequence, Tuple, Union, cast) import numpy as np from fastapi import Request +from pydantic import Field from typing_extensions import assert_never from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption +from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + ChatTemplateContentFormatOption, + ConversationMessage, + parse_chat_messages_futures) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData, ErrorResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + OpenAIServing, + RequestPrompt) +from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) @@ -136,31 +146,49 @@ async def create_embedding( raise NotImplementedError("Prompt adapter is not supported " "for embedding models") - if isinstance(request, EmbeddingChatRequest): - ( - _, - request_prompts, - engine_prompts, - ) = await self._preprocess_chat( - request, - tokenizer, - request.messages, - chat_template=request.chat_template or self.chat_template, - chat_template_content_format=self. 
- chat_template_content_format, - add_generation_prompt=request.add_generation_prompt, - continue_final_message=request.continue_final_message, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, - ) + if self.model_config.is_cross_encoder: + if isinstance(request, EmbeddingChatRequest): + ( + _, + request_prompts, + engine_prompts, + ) = await self._preprocess_cross_encoding( + tokenizer, + request.messages, + truncate_prompt_tokens=truncate_prompt_tokens, + ) + else: + return self.create_error_response( + "Cross encoding requests must " + "use the chat embedding API") else: - request_prompts, engine_prompts = self._preprocess_completion( - request, - tokenizer, - request.input, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, - ) + if isinstance(request, EmbeddingChatRequest): + ( + _, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, + tokenizer, + request.messages, + chat_template=request.chat_template + or self.chat_template, + chat_template_content_format=self. + chat_template_content_format, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + else: + request_prompts, engine_prompts = self\ + ._preprocess_completion( + request, + tokenizer, + request.input, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) except ValueError as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) @@ -225,3 +253,44 @@ async def create_embedding( return self.create_error_response(str(e)) return response + + async def _preprocess_cross_encoding( + self, + tokenizer: AnyTokenizer, + messages: List[ChatCompletionMessageParam], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=3)]] = None, + ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt], + List[TokensPrompt]]: + + conversation, mm_data_future = parse_chat_messages_futures( + messages, self.model_config, tokenizer, "string") + await mm_data_future + + if len(conversation) != 2: + raise ValueError("For cross encoding two inputs must be provided") + prompts = [] + for msg in conversation: + content = msg["content"] + assert type(msg["content"]) is str + prompts.append(content) + + if isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "MistralTokenizer not supported for cross-encoding") + + request_prompt = f"{prompts[0]}{tokenizer.sep_token}{prompts[1]}" + + tokenization_kwargs: Dict[str, Any] = {} + if truncate_prompt_tokens is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + prompt_inputs = tokenizer(prompts[0], + text_pair=prompts[1], + **tokenization_kwargs) + + engine_prompt = TokensPrompt( + prompt_token_ids=prompt_inputs["input_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + + return conversation, [request_prompt], [engine_prompt] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 07ff9faa50f13..fb7dbbebd7b90 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -38,6 +38,9 @@ class TokensPrompt(TypedDict): prompt_token_ids: List[int] """A list of token IDs to pass to the model.""" + token_type_ids: NotRequired[List[int]] + """A list of token type IDs to pass to the cross encoder model.""" + multi_modal_data: 
NotRequired["MultiModalDataDict"] """ DEPRECATED: Optional multi-modal data to pass to the model, @@ -133,6 +136,9 @@ class TokenInputs(TypedDict): prompt_token_ids: List[int] """The token IDs of the prompt.""" + token_type_ids: NotRequired[List[int]] + """The token type IDs of the prompt.""" + prompt: NotRequired[str] """ The original prompt text corresponding to the token IDs, if available. @@ -160,6 +166,7 @@ class TokenInputs(TypedDict): def token_inputs( prompt_token_ids: List[int], + token_type_ids: Optional[List[int]] = None, prompt: Optional[str] = None, multi_modal_data: Optional["MultiModalDataDict"] = None, multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, @@ -170,6 +177,8 @@ def token_inputs( if prompt is not None: inputs["prompt"] = prompt + if token_type_ids is not None: + inputs["token_type_ids"] = token_type_ids if multi_modal_data is not None: inputs["multi_modal_data"] = multi_modal_data if multi_modal_placeholders is not None: @@ -234,6 +243,15 @@ def prompt_token_ids(self) -> List[int]: assert_never(inputs) + @cached_property + def token_type_ids(self) -> List[int]: + inputs = self.inputs + + if inputs["type"] == "token" or inputs["type"] == "multimodal": + return inputs.get("token_type_ids", []) + + assert_never(inputs) + @cached_property def prompt_embeds(self) -> Optional[torch.Tensor]: inputs = self.inputs diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index aacff87df6d79..1801397811b22 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -305,6 +305,7 @@ def _prompt_to_llm_inputs( tokens_content = parsed["content"] prompt_token_ids = tokens_content["prompt_token_ids"] + token_type_ids = tokens_content.get("token_type_ids") multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") @@ -318,6 +319,7 @@ def _prompt_to_llm_inputs( return token_inputs( prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, ) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index d8301a36acb01..0f6347a7fd78b 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -17,8 +17,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.model_executor.models.interfaces import SupportsCrossEncoding +from vllm.model_executor.pooling_metadata import (PoolingMetadata, + PoolingTensors) +from vllm.sequence import (EmbeddingSequenceGroupOutput, IntermediateTensors, + PoolerOutput) from .utils import maybe_prefix @@ -48,7 +51,9 @@ def __init__(self, config: BertConfig): def forward( self, input_ids: torch.Tensor, - position_ids: Optional[torch.Tensor] = None, + seq_lens: torch.Tensor, + position_ids: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: input_shape = input_ids.size() @@ -58,17 +63,34 @@ def forward( # Position embeddings. position_embeddings = self.position_embeddings(position_ids) - # Token type embeddings. (TODO: move off hotpath?) 
- token_type_embeddings = self.token_type_embeddings( - torch.zeros(input_shape, - dtype=torch.long, - device=inputs_embeds.device)) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, + dtype=torch.long, + device=inputs_embeds.device) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings = inputs_embeds + token_type_embeddings + position_embeddings embeddings = self.LayerNorm(embeddings) return embeddings +class BertPooler(nn.Module): + + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[0, :] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + class BertEncoder(nn.Module): def __init__(self, @@ -309,7 +331,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", - embedding_class: type = BertEmbedding): + embedding_class: type = BertEmbedding, + add_pooling_layer: bool = False): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -319,6 +342,7 @@ def __init__(self, cache_config, quant_config, prefix=f"{prefix}.encoder") + self.pooler = BertPooler(config) if add_pooling_layer else None def forward( self, @@ -328,13 +352,17 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.embeddings(input_ids=input_ids, - position_ids=position_ids) - + assert hasattr(attn_metadata, "seq_lens_tensor") + hidden_states = self.embeddings( + input_ids=input_ids, + seq_lens=attn_metadata.seq_lens_tensor, + position_ids=position_ids, + token_type_ids=token_type_ids) return self.encoder(hidden_states, kv_caches, attn_metadata) def load_weights(self, weights: Iterable[Tuple[str, @@ -349,7 +377,7 @@ def load_weights(self, weights: Iterable[Tuple[str, params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: - if "pooler" in name: + if self.pooler is None and "pooler" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: @@ -430,3 +458,95 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: pooling_type=PoolingType.CLS, normalize=True, softmax=False) + + +class BertForSequenceClassification(nn.Module, SupportsCrossEncoding): + """A model that uses Bert to provide embedding functionalities. + + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. 
+ """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + + self.num_labels = config.num_labels + self.bert = BertModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=BertEmbedding, + add_pooling_layer=True) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + self_weights = [] + + def weight_filter(): + for name, weight in weights: + if name.startswith("bert."): + yield (name[len("bert."):], weight) + else: + self_weights.append((name, weight)) + + self.bert.load_weights(weight_filter()) + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in self_weights: + if name.startswith("classifier"): + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + offset = 0 + pooled_data_lst = [] + for prompt_len in prompt_lens: + pooled_data_i = hidden_states[offset:offset + prompt_len] + + pooled_data_i = self.bert.pooler(pooled_data_i) + + pooled_data_lst.append(pooled_data_i) + offset += prompt_len + + pooled_output = torch.stack(pooled_data_lst) + + classifier_output = self.classifier(pooled_output) + + pooled_outputs = [ + EmbeddingSequenceGroupOutput(data.tolist()) + for data in classifier_output + ] + return PoolerOutput(outputs=pooled_outputs) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.bert(input_ids=input_ids, + position_ids=positions, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + attn_metadata=attn_metadata, + token_type_ids=token_type_ids) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index dcead65115132..4f0c75b2c6a57 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -7,6 +7,8 @@ from vllm.logger import init_logger from vllm.utils import supports_kw +from .interfaces_base import is_embedding_model + if TYPE_CHECKING: from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig from vllm.sequence import IntermediateTensors @@ -350,3 +352,37 @@ def is_attention_free( return isinstance(model, _IsAttentionFreeType) return isinstance(model, IsAttentionFree) + + +@runtime_checkable +class SupportsCrossEncoding(Protocol): + """The interface required for all models that support cross encoding.""" + + supports_cross_encoding: ClassVar[Literal[True]] = True + + +@overload +def supports_cross_encoding( + model: Type[object]) -> TypeIs[Type[SupportsCrossEncoding]]: + ... + + +@overload +def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: + ... 
+ + +def _supports_cross_encoding( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: + + if isinstance(model, type): + return isinstance(model, SupportsCrossEncoding) + + return isinstance(model, SupportsCrossEncoding) + + +def supports_cross_encoding( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: + return is_embedding_model(model) and _supports_cross_encoding(model) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 22c2e328bfb65..17ecd8001a3e5 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -21,7 +21,8 @@ from vllm.platforms import current_platform from .interfaces import (has_inner_state, is_attention_free, - supports_multimodal, supports_pp) + supports_cross_encoding, supports_multimodal, + supports_pp) from .interfaces_base import is_embedding_model, is_text_generation_model logger = init_logger(__name__) @@ -121,6 +122,14 @@ "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501, } +_CROSS_ENCODER_MODELS = { + "BertForSequenceClassification": ("bert", "BertForSequenceClassification"), + "RobertaForSequenceClassification": ("roberta", + "RobertaForSequenceClassification"), + "XLMRobertaForSequenceClassification": ("roberta", + "RobertaForSequenceClassification"), +} + _MULTIMODAL_MODELS = { # [Decoder-only] "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), @@ -159,6 +168,7 @@ _VLLM_MODELS = { **_TEXT_GENERATION_MODELS, **_EMBEDDING_MODELS, + **_CROSS_ENCODER_MODELS, **_MULTIMODAL_MODELS, **_SPECULATIVE_DECODING_MODELS, } @@ -193,6 +203,7 @@ class _ModelInfo: is_text_generation_model: bool is_embedding_model: bool + supports_cross_encoding: bool supports_multimodal: bool supports_pp: bool has_inner_state: bool @@ -203,6 +214,7 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": return _ModelInfo( is_text_generation_model=is_text_generation_model(model), is_embedding_model=is_embedding_model(model), + supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), @@ -415,6 +427,12 @@ def is_embedding_model( ) -> bool: return self.inspect_model_cls(architectures).is_embedding_model + def is_cross_encoder_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + return self.inspect_model_cls(architectures).supports_cross_encoding + def is_multimodal_model( self, architectures: Union[str, List[str]], @@ -489,4 +507,4 @@ def _run() -> None: if __name__ == "__main__": - _run() \ No newline at end of file + _run() diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index c1dcdd36ec3de..b6b17332e58d2 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -8,8 +8,14 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel -from vllm.sequence import IntermediateTensors +from vllm.model_executor.models.interfaces import 
SupportsCrossEncoding +from vllm.model_executor.models.utils import maybe_prefix +from vllm.model_executor.pooling_metadata import (PoolingMetadata, + PoolingTensors) +from vllm.sequence import (EmbeddingSequenceGroupOutput, IntermediateTensors, + PoolerOutput) class RobertaEmbedding(nn.Module): @@ -39,34 +45,93 @@ def __init__(self, config: RobertaConfig): def forward( self, input_ids: torch.Tensor, - position_ids: Optional[torch.Tensor] = None, + seq_lens: torch.Tensor, + position_ids: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: input_shape = input_ids.size() - - # Input embeddings. inputs_embeds = self.word_embeddings(input_ids) - # TODO: figure out if there is a better way - # to make to make position ids start at padding_idx + 1 + # Replace position ids because in RoBERTa models + # they have to start at padding_idx + 1 and ignore + # existing padding tokens # References: # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 - position_ids += self.padding_idx + 1 + pos_list = [] + token_list = [] + offset = 0 + for seq_len in seq_lens: + pos_list.append(position_ids[offset:offset + seq_len]) + token_list.append(input_ids[offset:offset + seq_len]) + offset += seq_len + + new_pos_list = [] + for positions, tokens in zip(pos_list, token_list): + # Verify assumption that incoming position are + # always a sequence from 0 to N. + expected_pos = torch.arange(positions.size()[0], + dtype=torch.long, + device=inputs_embeds.device) + assert torch.equal(positions, expected_pos) + new_pos_list.append( + create_position_ids_from_input_ids(tokens, self.padding_idx)) + position_ids = torch.cat(new_pos_list) # Position embeddings. position_embeddings = self.position_embeddings(position_ids) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, + dtype=torch.long, + device=inputs_embeds.device) - # Token type embeddings. (TODO: move off hotpath?) - token_type_embeddings = self.token_type_embeddings( - torch.zeros(input_shape, - dtype=torch.long, - device=inputs_embeds.device)) - + token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings = inputs_embeds + token_type_embeddings + position_embeddings embeddings = self.LayerNorm(embeddings) return embeddings +# Adapted from transformers +def create_position_ids_from_input_ids(input_ids, + padding_idx, + past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. + Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. 
+ mask = input_ids.ne(padding_idx).int() + + incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) + + past_key_values_length) * mask + + return incremental_indices.long() + padding_idx + + +# Adapted from transformers +class RobertaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config: RobertaConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[0, :] # take token (equiv. to [CLS]) + x = self.dense(x) + x = torch.tanh(x) + x = self.out_proj(x) + return x + + class RobertaEmbeddingModel(BertEmbeddingModel): """A model that uses Roberta to provide embedding functionalities. @@ -85,6 +150,78 @@ def _build_model(self, prefix=prefix, embedding_class=RobertaEmbedding) + +class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): + """A model that uses Roberta to provide embedding functionalities. + + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + roberta: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + + self.num_labels = config.num_labels + self.roberta = BertModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=RobertaEmbedding, + add_pooling_layer=False) + self.classifier = RobertaClassificationHead(config) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + self_weights = [] + + def weight_filter(): + for name, weight in weights: + if name.startswith("roberta."): + yield (name[len("roberta."):], weight) + else: + self_weights.append((name, weight)) + + self.roberta.load_weights(weight_filter()) + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in self_weights: + if name.startswith("classifier"): + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + offset = 0 + pooled_data_lst = [] + for prompt_len in prompt_lens: + pooled_data_i = hidden_states[offset:offset + prompt_len] + + pooled_data_i = self.classifier(pooled_data_i) + + pooled_data_lst.append(pooled_data_i) + offset += prompt_len + + pooled_output = torch.stack(pooled_data_lst) + + pooled_outputs = [ + EmbeddingSequenceGroupOutput(data.tolist()) + for data in pooled_output + ] + return PoolerOutput(outputs=pooled_outputs) + def forward( self, input_ids: Optional[torch.Tensor], @@ -93,25 +230,12 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - - # Verify assumption that position are always a sequence from - # 0 to N. (Actually here we just check 0 and N to simplify). - # This is important to fix the position which are assumed to - # start from padding_idx + 1 instead of 0 in the Roberta models. 
- assert hasattr(attn_metadata, "seq_lens_tensor") - cumulative = attn_metadata.seq_lens_tensor.cumsum(dim=0) - start_pos = torch.cat( - (torch.tensor([0], device=attn_metadata.seq_lens_tensor.device), - cumulative[:-1])) - assert len(torch.nonzero(positions[start_pos])) == 0 - end_pos = cumulative - 1 - last_tokens = attn_metadata.seq_lens_tensor - 1 - assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0 - - return super().forward(input_ids=input_ids, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) + return self.roberta(input_ids=input_ids, + position_ids=positions, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + attn_metadata=attn_metadata, + token_type_ids=token_type_ids) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 64a4c58d5509c..4035a87231712 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -6,7 +6,7 @@ import torch import torch.types from PIL.Image import Image -from typing_extensions import TypeAlias +from typing_extensions import NotRequired, TypeAlias from vllm.utils import JSONTree, is_list_of, json_map_leaves @@ -215,6 +215,9 @@ class MultiModalInputsV2(TypedDict): prompt_token_ids: List[int] """The processed token IDs which includes placeholder tokens.""" + token_type_ids: NotRequired[List[int]] + """The token type IDs of the prompt.""" + mm_kwargs: MultiModalKwargs """Keyword arguments to be directly passed to the model after batching.""" diff --git a/vllm/sequence.py b/vllm/sequence.py index 3b41d25a2fe42..16c8d8654a6a9 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -449,6 +449,10 @@ def prompt_token_ids(self) -> List[int]: def prompt_embeds(self) -> Optional[torch.Tensor]: return self.inputs.prompt_embeds + @property + def token_type_ids(self) -> List[int]: + return self.inputs.token_type_ids + @property def multi_modal_data(self) -> "MultiModalDataDict": return self.inputs.multi_modal_data @@ -684,6 +688,10 @@ def encoder_prompt_token_ids(self) -> Optional[List[int]]: return (self.encoder_seq.prompt_token_ids if self.encoder_seq is not None else None) + @property + def token_type_ids(self) -> Optional[List[int]]: + return self.first_seq.token_type_ids + @property def multi_modal_data(self) -> MultiModalDataDict: return self.first_seq.multi_modal_data @@ -906,6 +914,7 @@ class SequenceGroupMetadata( default_factory=lambda: SequenceGroupState()) # "MultiModalDataDict" types. We have to use Any due to msgspec # doesn't allow to have union of 2 different dicts. 
+ token_type_ids: Optional[List[int]] = None multi_modal_data: Optional[Any] = None multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None diff --git a/vllm/worker/cpu_embedding_model_runner.py b/vllm/worker/cpu_embedding_model_runner.py index d0b8fec48d74f..941e9910413fd 100644 --- a/vllm/worker/cpu_embedding_model_runner.py +++ b/vllm/worker/cpu_embedding_model_runner.py @@ -49,6 +49,9 @@ def execute_model( ] model_executable = self.model + cross_enc_kwargs = {} + if model_input.token_type_ids is not None: + cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids execute_model_kwargs = { "input_ids": model_input.input_tokens, @@ -60,6 +63,7 @@ def execute_model( model_input.attn_metadata, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device), + **cross_enc_kwargs, "intermediate_tensors": intermediate_tensors, } diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index d3e1202c15e61..bb1d5e58c3c8d 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -43,6 +43,7 @@ class ModelInputForCPU(ModelRunnerInputBase): """ input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None + token_type_ids: Optional[torch.Tensor] = None attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None virtual_engine: Optional[int] = None @@ -54,6 +55,7 @@ def as_broadcastable_tensor_dict( tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, + "token_type_ids": self.token_type_ids, "multi_modal_kwargs": self.multi_modal_kwargs, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) @@ -83,6 +85,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, + "token_type_ids": self.token_type_ids, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -127,18 +130,20 @@ def build(self) -> ModelInputForCPU: is_prompt = self.seq_group_metadata_list[0].is_prompt # Prepare input tensors. 
if is_prompt: - (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs) = self._prepare_prompt( + (input_tokens, input_positions, token_type_ids, attn_metadata, + seq_lens, multi_modal_kwargs) = self._prepare_prompt( self.seq_group_metadata_list) else: (input_tokens, input_positions, attn_metadata) = self._prepare_decode( self.seq_group_metadata_list) seq_lens = None + token_type_ids = None return self.model_input_cls( input_tokens=input_tokens, input_positions=input_positions, + token_type_ids=token_type_ids, attn_metadata=attn_metadata, multi_modal_kwargs=multi_modal_kwargs, # query_lens is not needed if chunked prefill is not @@ -203,11 +208,12 @@ def _compute_multi_modal_input( def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - BatchedTensorInputs]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, AttentionMetadata, + List[int], BatchedTensorInputs]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] + token_type_ids: List[int] = [] input_mrope_positions: List[List[int]] = [[] for _ in range(3)] slot_mapping: List[int] = [] @@ -225,11 +231,13 @@ def _prepare_prompt( seq_data = seq_group_metadata.seq_data[seq_id] prompt_tokens = seq_data.get_token_ids() + token_types = seq_group_metadata.token_type_ids computed_len = seq_data.get_num_computed_tokens() seq_len = len(prompt_tokens) seq_lens.append(seq_len) # Prompt token num input_tokens.extend(prompt_tokens) # Token ids + token_type_ids.extend(token_types if token_types else []) mrope_positions = None if seq_group_metadata.multi_modal_data: @@ -293,6 +301,11 @@ def _prepare_prompt( or input_mrope_positions, dtype=torch.long, device=self.device) # type: ignore + token_type_ids = torch.tensor(token_type_ids, + dtype=torch.long, + device=self.device)\ + if token_type_ids else None + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) # type: ignore @@ -317,8 +330,8 @@ def _prepare_prompt( multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - return (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs) + return (input_tokens, input_positions, token_type_ids, attn_metadata, + seq_lens, multi_modal_kwargs) def _prepare_decode( self, diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 37cfcbf13d7a3..19951c4787cf9 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -97,6 +97,10 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() + cross_enc_kwargs = {} + if model_input.token_types is not None: + cross_enc_kwargs["token_type_ids"] = model_input.token_types + with set_forward_context(model_input.attn_metadata): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, @@ -105,7 +109,8 @@ def execute_model( attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device)) + device=self.device), + **cross_enc_kwargs) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ed0360fb7f727..c97ac33490c40 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -92,6 +92,7 @@ class 
ModelInputForGPU(ModelRunnerInputBase): """ input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None + token_types: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None lora_mapping: Optional["LoRAMapping"] = None @@ -200,6 +201,7 @@ class InterDataForSeqGroup: def simple_reinit(self): self.input_tokens[0].clear() # type: ignore self.input_positions[0].clear() # type: ignore + self.token_types[0].clear() # type: ignore self.mrope_input_positions = None # type: ignore self.seq_lens[0] = 0 # type: ignore self.orig_seq_lens[0] = 0 # type: ignore @@ -226,6 +228,7 @@ def __init__( # Input tokens and positions. input_tokens: Optional[List[List[int]]] = None, input_positions: Optional[List[List[int]]] = None, + token_types: Optional[List[List[int]]] = None, mrope_input_positions: Optional[List[List[List[int]]]] = None, # The sequence length (may be capped to the sliding window). @@ -291,6 +294,12 @@ def __init__( for seq_id in range(len(self.seq_ids)): self.input_positions[seq_id].clear() + if token_types: + self.token_types = token_types + else: + for seq_id in range(len(self.seq_ids)): + self.token_types[seq_id].clear() + self.mrope_input_positions = None if seq_lens: @@ -354,6 +363,7 @@ def __init__( else: self.input_tokens = input_tokens or [] self.input_positions = input_positions or [] + self.token_types = token_types or [] self.mrope_input_positions = mrope_input_positions or None self.seq_lens = seq_lens or [] self.orig_seq_lens = orig_seq_lens or [] @@ -386,6 +396,7 @@ def __post_init__(self): self.input_tokens = [[] for _ in range(self.n_seqs)] self.input_positions = [[] for _ in range(self.n_seqs)] + self.token_types = [[] for _ in range(self.n_seqs)] self.mrope_input_positions = None self.seq_lens = [0] * self.n_seqs self.orig_seq_lens = [0] * self.n_seqs @@ -498,12 +509,15 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # Compute tokens. tokens = seq_data.get_token_ids()[context_len:seq_len] + token_types = seq_group_metadata.token_type_ids inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.context_lens[seq_idx] = context_len inter_data.input_tokens[seq_idx].extend(tokens) inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) + inter_data.token_types[seq_idx].extend( + token_types if token_types else []) inter_data.query_lens[seq_idx] = seq_len - context_len if seq_data.mrope_position_delta is not None: @@ -561,6 +575,8 @@ def _compute_for_prefix_cache_hit( seq_idx][uncomputed_start:] inter_data.input_positions[seq_idx] = inter_data.input_positions[ seq_idx][uncomputed_start:] + inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ + uncomputed_start:] context_len = prefix_cache_len inter_data.context_lens[seq_idx] = context_len @@ -575,6 +591,8 @@ def _compute_for_prefix_cache_hit( seq_idx][-1:] inter_data.input_positions[seq_idx] = inter_data.input_positions[ seq_idx][-1:] + inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ + -1:] inter_data.query_lens[seq_idx] = 1 inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1 @@ -803,9 +821,12 @@ def build(self) -> ModelInputForGPU: """ # Combine and flatten intermediate data. 
input_tokens = [] + token_types = [] for inter_data in self.inter_data_list: for cur_input_tokens in inter_data.input_tokens: input_tokens.extend(cur_input_tokens) + for cur_token_types in inter_data.token_types: + token_types.extend(cur_token_types) if not input_tokens: # This may happen when all prefill requests hit @@ -874,6 +895,12 @@ def build(self) -> ModelInputForGPU: input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, self.runner.device, self.runner.pin_memory) + + token_types_tensor = async_tensor_h2d(token_types, torch.long, + self.runner.device, + self.runner.pin_memory) \ + if token_types else None + if mrope_input_positions is not None: for idx in range(3): mrope_input_positions[idx].extend( @@ -952,6 +979,7 @@ def build(self) -> ModelInputForGPU: return self.model_input_cls( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, + token_types=token_types_tensor, attn_metadata=attn_metadata, seq_lens=seq_lens, query_lens=query_lens,
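
Usage note (not part of the patch): below is a minimal sketch of how the LLM.score() entrypoint added above could be called offline. It assumes a vLLM build that includes this change and that the LLM constructor accepts task="embedding" (the Python-API counterpart of the `--task embedding` flag referenced in the error message). The model name is the one mentioned in the patch comments; the .outputs.embedding access is an assumption that, for a single-label cross encoder, the pooled classifier output is a one-element list holding the relevance score.

    from vllm import LLM

    # Model referenced in the patch comments; any cross encoder registered
    # above (Bert/Roberta/XLM-Roberta *ForSequenceClassification) should
    # follow the same pattern.
    llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="embedding")

    query = "What is the capital of France?"
    texts = [
        "Paris is the capital and largest city of France.",
        "The Great Wall of China is visible from space.",
    ]

    # score() pairs the query with each text and tokenizes them with
    # text/text_pair, so the model also receives token_type_ids.
    outputs = llm.score(query, texts)

    for text, out in zip(texts, outputs):
        # One score per (query, text) pair, returned in input order.
        print(f"{out.outputs.embedding[0]:.3f}  {text}")

Over the OpenAI-compatible server, the same models are reached through the chat form of the embeddings API: per the serving_embedding.py changes above, a cross-encoder request must provide exactly two chat messages, which are tokenized as a text/text_pair pair with token_type_ids.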