diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 846246a2062a6..ce9dc9e457c09 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -65,7 +65,8 @@ def run_phi3v(question): # PaliGemma def run_paligemma(question): - prompt = question + # PaliGemma has special prompt format for VQA + prompt = "caption en" llm = LLM(model="google/paligemma-3b-mix-224") return llm, prompt diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index fe91611cd30ff..9ba53b8b59a2f 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,9 +1,8 @@ from typing import Iterable, List, Literal, Optional, Tuple, TypedDict import torch -from PIL import Image from torch import nn -from transformers import PaliGemmaConfig, SiglipVisionConfig, SiglipVisionModel +from transformers import PaliGemmaConfig from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig @@ -18,9 +17,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import cached_get_tokenizer -from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData +from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsVision +from .siglip import (SiglipVisionModel, dummy_image_for_siglip, + dummy_seq_data_for_siglip, get_max_siglip_image_tokens) from .utils import merge_vision_embeddings logger = init_logger(__name__) @@ -32,55 +33,22 @@ def get_max_paligemma_image_tokens(ctx: InputContext): hf_config = ctx.get_hf_config(PaliGemmaConfig) - text_config = hf_config.text_config - - return text_config.num_image_tokens - - -def dummy_seq_data_for_paligemma( - hf_config: PaliGemmaConfig, - seq_len: int, - *, - image_token_id: int, - image_feature_size_override: Optional[int] = None, -): - if image_feature_size_override is None: - image_feature_size = hf_config.text_config.num_image_tokens - else: - image_feature_size = image_feature_size_override - - token_ids = [image_token_id] * image_feature_size - token_ids += [0] * (seq_len - image_feature_size) - return SequenceData(token_ids) - - -def dummy_image_for_paligemma( - hf_config: SiglipVisionConfig, - *, - image_width_override: Optional[int] = None, - image_height_override: Optional[int] = None, -): - width = height = hf_config.image_size - if image_width_override is not None: - width = image_width_override - if image_height_override is not None: - height = image_height_override + vision_config = hf_config.vision_config - image = Image.new("RGB", (width, height), color=0) - return {"image": image} + return get_max_siglip_image_tokens(vision_config) def dummy_data_for_paligemma(ctx: InputContext, seq_len: int): hf_config = ctx.get_hf_config(PaliGemmaConfig) vision_config = hf_config.vision_config - seq_data = dummy_seq_data_for_paligemma( - hf_config, + seq_data = dummy_seq_data_for_siglip( + vision_config, seq_len, image_token_id=hf_config.image_token_index, ) - mm_data = dummy_image_for_paligemma(vision_config) + mm_data = dummy_image_for_siglip(vision_config) return seq_data, mm_data @@ -208,30 +176,37 @@ def _parse_and_validate_image_input( data=self._validate_pixel_values(pixel_values), ) - def _image_pixels_to_features(self, vision_tower: SiglipVisionModel, - pixel_values: torch.Tensor) -> torch.Tensor: + def 
_image_pixels_to_features( + self, + vision_tower: SiglipVisionModel, + pixel_values: torch.Tensor, + ) -> torch.Tensor: target_dtype = vision_tower.get_input_embeddings().weight.dtype - image_outputs = vision_tower(pixel_values.to(dtype=target_dtype), - output_hidden_states=True) - - selected_image_features = image_outputs.last_hidden_state + image_features = vision_tower(pixel_values.to(dtype=target_dtype)) - return selected_image_features + return image_features def _process_image_pixels( - self, inputs: PaliGemmaImagePixelInputs) -> torch.Tensor: + self, + inputs: PaliGemmaImagePixelInputs, + ) -> torch.Tensor: assert self.vision_tower is not None pixel_values = inputs["data"] - return self._image_pixels_to_features(self.vision_tower, pixel_values) + return self._image_pixels_to_features( + self.vision_tower, + pixel_values, + ) def _process_image_input( - self, image_input: PaliGemmaImageInputs) -> torch.Tensor: + self, + image_input: PaliGemmaImageInputs, + ) -> torch.Tensor: assert self.vision_tower is not None - image_features = self._process_image_pixels(image_input) + image_features = self._process_image_pixels(image_input, ) return self.multi_modal_projector(image_features) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py new file mode 100644 index 0000000000000..6faef45c9a6d3 --- /dev/null +++ b/vllm/model_executor/models/siglip.py @@ -0,0 +1,621 @@ +"""Implementation of SiglipVisionModel intended to be only used +within a vision language model.""" + +import math +from typing import Optional, Tuple + +import torch +from PIL import Image +from torch import nn +from transformers import SiglipConfig, SiglipVisionConfig +from transformers.models.siglip.modeling_siglip import SiglipAttention +from vllm_flash_attn import flash_attn_func +from xformers.ops import memory_efficient_attention + +from vllm.config import ModelConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import LLMInputs +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.multimodal.image import (cached_get_tokenizer, + repeat_and_pad_image_tokens) +from vllm.sequence import SequenceData + + +def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: + assert image_size % patch_size == 0 + return image_size // patch_size + + +def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int: + grid_length = get_siglip_patch_grid_length(image_size=image_size, + patch_size=patch_size) + return grid_length * grid_length + + +def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int: + return get_siglip_num_patches(image_size=hf_config.image_size, + patch_size=hf_config.patch_size) + + +def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int: + return get_siglip_image_feature_size(hf_config) + + +def dummy_seq_data_for_siglip( + hf_config: SiglipVisionConfig, + seq_len: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + if image_feature_size_override is None: + image_feature_size = get_siglip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + token_ids = [image_token_id] * image_feature_size + token_ids += [0] * 
(seq_len - image_feature_size) + return SequenceData(token_ids) + + +def dummy_image_for_siglip( + hf_config: SiglipVisionConfig, + *, + image_width_override: Optional[int] = None, + image_height_override: Optional[int] = None, +): + width = height = hf_config.image_size + if image_width_override is not None: + width = image_width_override + if image_height_override is not None: + height = image_height_override + + image = Image.new("RGB", (width, height), color=0) + return {"image": image} + + +def input_processor_for_siglip( + model_config: ModelConfig, + hf_config: SiglipVisionConfig, + llm_inputs: LLMInputs, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + tokenizer = cached_get_tokenizer(model_config.tokenizer) + + if image_feature_size_override is None: + image_feature_size = get_siglip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + new_prompt, new_token_ids = repeat_and_pad_image_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + image_token_id=image_token_id, + repeat_count=image_feature_size, + ) + + # NOTE: Create a defensive copy of the original inputs + return LLMInputs( + prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data, + ) + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa +class SiglipVisionEmbeddings(nn.Module): + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + self.position_embedding = VocabParallelEmbedding( + self.num_positions, self.embed_dim) + self.register_buffer( + "position_ids", + torch.arange(self.num_positions, dtype=torch.int64).expand( + (1, -1)), + persistent=False, + ) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, + width: int) -> torch.Tensor: + """ + This method is an adapted method for SigLIP (due to SigLIP not having + class embedding unlike other ViTs) that allows the model to interpolate + the pre-trained position encodings such that it can be usable on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + position_embeddings = self.position_embedding.weight.unsqueeze(0) + num_patches = embeddings.shape[1] + num_positions = position_embeddings.shape[1] + if num_patches == num_positions and height == width: + return position_embeddings + + dim = embeddings.shape[-1] + height = height // self.patch_size + width = width // self.patch_size + # we add a small number to avoid floating point error + # in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + + patch_pos_embed = position_embeddings.reshape( + 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), + dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=( + height / math.sqrt(num_positions), + width / math.sqrt(num_positions), + ), + mode="bicubic", + align_corners=False, + ) + if (int(height) != patch_pos_embed.shape[-2] + or int(width) != patch_pos_embed.shape[-1]): + raise ValueError("Width or height does not match with " + "the interpolated position embeddings") + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward(self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = False) -> torch.Tensor: + _, _, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + dtype=target_dtype)) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding( + self.position_ids) + return embeddings + + +# NOTE: Not used - kept for later when we TP the ViT +# TODO(ChristopherCho): Implement TP version of Attention +class SiglipTPAttention(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + if self.total_num_heads % tp_size != 0: + raise ValueError( + f"Number of attention heads ({self.total_num_heads}) " + "must be divisible by the tensor model parallel size" + f" ({tp_size}).") + + self.num_heads = self.total_num_heads // tp_size + self.head_dim = self.embed_dim // self.total_num_heads + if self.head_dim * self.total_num_heads != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads (got " + "`embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.qkv_size = self.num_heads * self.head_dim + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + quant_config=quant_config, + ) + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + ) + + self.attn_fn = self._basic_attention_forward + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + batch_size, q_len, _ = hidden_states.size() + + qkv_states, _ = 
self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv_states.split( + [self.qkv_size] * 3, dim=-1) + + attn_output = self.attn_fn( + q=query_states, + k=key_states, + v=value_states, + batch_size=batch_size, + q_len=q_len, + ) + + attn_output, _ = self.out_proj(attn_output) + return attn_output + + def _basic_attention_forward(self, q, k, v, batch_size, q_len): + q = q.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + k = k.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + v = v.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + + k_v_seq_len = k.shape[-2] + attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale + + if attn_weights.size() != ( + batch_size, + self.num_heads, + q_len, + k_v_seq_len, + ): + raise ValueError( + "Attention weights should be of size " + f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}") + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, + dim=-1, + dtype=torch.float32).to(q.dtype) + attn_weights = nn.functional.dropout(attn_weights, + p=self.dropout, + training=self.training) + attn_output = torch.matmul(attn_weights, v) + + if attn_output.size() != ( + batch_size, + self.num_heads, + q_len, + self.head_dim, + ): + raise ValueError( + "`attn_output` should be of size " + f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + return attn_output + + +# NOTE: Not used - kept for later when we TP the ViT +# TODO(ChristopherCho): flash_attn_func is not working properly. +# It constantly throws a CUDA error. +class SiglipFlashAttention2(SiglipTPAttention): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.attn_fn = self._flash_attention_forward + + # Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449 + # and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133 + def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args, + **kwargs): + """Implements the multihead softmax attention. + Arguments + --------- + q, k, v: The tensor containing the + query, key, and value. 
(B, S, H, D) + """ + + q = q.view(batch_size, q_len, self.num_heads, self.head_dim) + k = k.view(batch_size, q_len, self.num_heads, self.head_dim) + v = v.view(batch_size, q_len, self.num_heads, self.head_dim) + + attn_output = flash_attn_func( + q, + k, + v, + dropout_p=self.dropout, + causal=False, + ) + + attn_output = attn_output.reshape(batch_size, q_len, + self.embed_dim).contiguous() + + return attn_output + + +# NOTE: Not used - kept for later when we TP the ViT +class SiglipSdpaAttention(SiglipTPAttention): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False + self.attn_fn = self._sdpa_attention_forward + + def _sdpa_attention_forward(self, q, k, v, batch_size, q_len): + q = q.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + k = k.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + v = v.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, q_len, self.embed_dim) + + return attn_output + + +# NOTE: Not used - kept for later when we TP the ViT +class SiglipxFormersAttention(SiglipTPAttention): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.attn_fn = self._xformers_attention_forward + + def _xformers_attention_forward(self, q, k, v, batch_size, q_len): + q = q.view(batch_size, q_len, self.num_heads, self.head_dim) + k = k.view(batch_size, q_len, self.num_heads, self.head_dim) + v = v.view(batch_size, q_len, self.num_heads, self.head_dim) + + attn_output = memory_efficient_attention(q, + k, + v, + p=0.0, + scale=self.scale) + attn_output = attn_output.reshape(batch_size, q_len, + self.embed_dim).contiguous() + + return attn_output + + +# NOTE: Not used - kept for later when we TP the ViT +SIGLIP_ATTENTION_CLASSES = { + "eager": SiglipTPAttention, + "flash_attention_2": SiglipFlashAttention2, + "sdpa": SiglipSdpaAttention, + "xformers": SiglipxFormersAttention, +} + + +class SiglipMLP(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + + # For quantization, we require the hidden size to be a multiple of 64 + quantizable = (config.hidden_size % 64 == 0 + and config.intermediate_size % 64 == 0) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config if quantizable else None, + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config if quantizable else None, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class SiglipEncoderLayer(nn.Module): + + def __init__( + self, + config: SiglipConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.embed_dim = config.hidden_size + + # TODO(ChristopherCho): use TP'ed Attention block + self.self_attn = SiglipAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = SiglipMLP( + config, + quant_config=quant_config, + ) + 
self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> Tuple[torch.Tensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, None + + +class SiglipEncoder(nn.Module): + + def __init__( + self, + config: SiglipConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + SiglipEncoderLayer( + config, + quant_config=quant_config, + ) for _ in range(config.num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds: torch.Tensor, + ) -> Tuple: + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states, _ = encoder_layer(hidden_states) + + return hidden_states + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention + self.attention = torch.nn.MultiheadAttention( + config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config=config, quant_config=quant_config) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class SiglipVisionTransformer(nn.Module): + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder( + config, + quant_config=quant_config, + ) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + self.use_head = (True if not hasattr(config, "vision_use_head") else + config.vision_use_head) + if self.use_head: + self.head = SiglipMultiheadAttentionPoolingHead( + config=config, quant_config=quant_config) + + def forward( + self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = True, + ) -> torch.Tensor: + hidden_states = self.embeddings( + pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + encoder_outputs = self.encoder(inputs_embeds=hidden_states) + + last_hidden_state = self.post_layernorm(encoder_outputs) + + # TODO: add this back when pooled_output is used in inference + # if self.use_head: + # pooled_output = self.head(last_hidden_state) + + return last_hidden_state + + +class SiglipVisionModel(nn.Module): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.vision_model = 
SiglipVisionTransformer( + config, + quant_config, + ) + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + return self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + )
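
Usage note (not part of the diff): the sketch below shows how the updated example drives PaliGemma through vLLM's offline API with the "caption en" prompt introduced above. It assumes the dict-based multimodal prompt format already used in examples/offline_inference_vision_language.py; the image path is a placeholder.

# Minimal sketch of offline PaliGemma inference, mirroring the updated example.
# Assumes the dict-based multimodal prompt format from
# examples/offline_inference_vision_language.py; "example.jpg" is a placeholder.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="google/paligemma-3b-mix-224")

# The mix checkpoints expect task-prefixed prompts such as "caption en"
# or "answer en <question>".
prompt = "caption en"
image = Image.open("example.jpg").convert("RGB")

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)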
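
For reference, a small sketch of what the new helpers in vllm/model_executor/models/siglip.py compute for the PaliGemma vision tower. The config values (image_size=224, patch_size=14) are assumptions based on the SigLIP tower shipped with google/paligemma-3b-mix-224; adjust if the checkpoint differs.

# Worked example: image feature size for a 224x224 SigLIP tower with 14px patches.
from transformers import SiglipVisionConfig

from vllm.model_executor.models.siglip import (get_max_siglip_image_tokens,
                                               get_siglip_image_feature_size)

cfg = SiglipVisionConfig(image_size=224, patch_size=14)

# (224 // 14) ** 2 == 16 * 16 == 256 patches, i.e. 256 image tokens, which is
# also the count that dummy_seq_data_for_siglip pads into the dummy sequence.
assert get_siglip_image_feature_size(cfg) == 256
assert get_max_siglip_image_tokens(cfg) == 256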