Skip to content

Commit

Permalink
[Model] SiglipVisionModel ported from transformers (vllm-project#6942)
Browse files Browse the repository at this point in the history
Co-authored-by: Roger Wang <[email protected]>
  • Loading branch information
ChristopherCho and ywang96 authored Aug 5, 2024
1 parent cc08fc7 commit c0d8f16
Show file tree
Hide file tree
Showing 3 changed files with 650 additions and 53 deletions.
3 changes: 2 additions & 1 deletion examples/offline_inference_vision_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def run_phi3v(question):
# PaliGemma
def run_paligemma(question):

prompt = question
# PaliGemma has special prompt format for VQA
prompt = "caption en"
llm = LLM(model="google/paligemma-3b-mix-224")

return llm, prompt
Expand Down
79 changes: 27 additions & 52 deletions vllm/model_executor/models/paligemma.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict

import torch
from PIL import Image
from torch import nn
from transformers import PaliGemmaConfig, SiglipVisionConfig, SiglipVisionModel
from transformers import PaliGemmaConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
Expand All @@ -18,9 +17,11 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from vllm.sequence import IntermediateTensors, SamplerOutput

from .interfaces import SupportsVision
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import merge_vision_embeddings

logger = init_logger(__name__)
Expand All @@ -32,55 +33,22 @@

def get_max_paligemma_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PaliGemmaConfig)
text_config = hf_config.text_config

return text_config.num_image_tokens


def dummy_seq_data_for_paligemma(
hf_config: PaliGemmaConfig,
seq_len: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = hf_config.text_config.num_image_tokens
else:
image_feature_size = image_feature_size_override

token_ids = [image_token_id] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
return SequenceData(token_ids)


def dummy_image_for_paligemma(
hf_config: SiglipVisionConfig,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
width = height = hf_config.image_size
if image_width_override is not None:
width = image_width_override
if image_height_override is not None:
height = image_height_override
vision_config = hf_config.vision_config

image = Image.new("RGB", (width, height), color=0)
return {"image": image}
return get_max_siglip_image_tokens(vision_config)


def dummy_data_for_paligemma(ctx: InputContext, seq_len: int):
hf_config = ctx.get_hf_config(PaliGemmaConfig)
vision_config = hf_config.vision_config

seq_data = dummy_seq_data_for_paligemma(
hf_config,
seq_data = dummy_seq_data_for_siglip(
vision_config,
seq_len,
image_token_id=hf_config.image_token_index,
)

mm_data = dummy_image_for_paligemma(vision_config)
mm_data = dummy_image_for_siglip(vision_config)
return seq_data, mm_data


Expand Down Expand Up @@ -208,30 +176,37 @@ def _parse_and_validate_image_input(
data=self._validate_pixel_values(pixel_values),
)

def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
def _image_pixels_to_features(
self,
vision_tower: SiglipVisionModel,
pixel_values: torch.Tensor,
) -> torch.Tensor:

target_dtype = vision_tower.get_input_embeddings().weight.dtype
image_outputs = vision_tower(pixel_values.to(dtype=target_dtype),
output_hidden_states=True)

selected_image_features = image_outputs.last_hidden_state
image_features = vision_tower(pixel_values.to(dtype=target_dtype))

return selected_image_features
return image_features

def _process_image_pixels(
self, inputs: PaliGemmaImagePixelInputs) -> torch.Tensor:
self,
inputs: PaliGemmaImagePixelInputs,
) -> torch.Tensor:
assert self.vision_tower is not None

pixel_values = inputs["data"]

return self._image_pixels_to_features(self.vision_tower, pixel_values)
return self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)

def _process_image_input(
self, image_input: PaliGemmaImageInputs) -> torch.Tensor:
self,
image_input: PaliGemmaImageInputs,
) -> torch.Tensor:

assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
image_features = self._process_image_pixels(image_input, )

return self.multi_modal_projector(image_features)

Expand Down
Loading

0 comments on commit c0d8f16

Please sign in to comment.