Skip to content

Commit

Permalink
[Model] Add support for H2OVL-Mississippi models (#9747)
Browse files Browse the repository at this point in the history
Signed-off-by: Shanshan Wang <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
  • Loading branch information
cooleel and ywang96 authored Nov 4, 2024
1 parent 1f1b6d6 commit 5459772
Show file tree
Hide file tree
Showing 12 changed files with 698 additions and 4 deletions.
6 changes: 6 additions & 0 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,12 @@ Text Generation
- :code:`THUDM/glm-4v-9b` etc.
-
- ✅︎
* - :code:`H2OVLChatModel`
- H2OVL
- T + I\ :sup:`E+`
- :code:`h2oai/h2ovl-mississippi-800m`, :code:`h2oai/h2ovl-mississippi-2b`, etc.
-
- ✅︎
* - :code:`InternVLChatModel`
- InternVL2
- T + I\ :sup:`E+`
Expand Down
28 changes: 27 additions & 1 deletion examples/offline_inference_vision_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,31 @@ def run_minicpmv(question: str, modality: str):
return llm, prompt, stop_token_ids


# H2OVL-Mississippi
def run_h2ovl(question: str, modality: str):
assert modality == "image"

model_name = "h2oai/h2ovl-mississippi-2b"

llm = LLM(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
)

tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)

# Stop tokens for H2OVL-Mississippi
# https://huggingface.co/h2oai/h2ovl-mississippi-2b
stop_token_ids = [tokenizer.eos_token_id]
return llm, prompt, stop_token_ids


# InternVL
def run_internvl(question: str, modality: str):
assert modality == "image"
Expand Down Expand Up @@ -363,6 +388,7 @@ def run_glm4v(question: str, modality: str):
"chameleon": run_chameleon,
"minicpmv": run_minicpmv,
"blip-2": run_blip2,
"h2ovl_chat": run_h2ovl,
"internvl_chat": run_internvl,
"NVLM_D": run_nvlm_d,
"qwen_vl": run_qwen_vl,
Expand Down Expand Up @@ -475,4 +501,4 @@ def main(args):
default=16,
help='Number of frames to extract from the video.')
args = parser.parse_args()
main(args)
main(args)
35 changes: 35 additions & 0 deletions examples/offline_inference_vision_language_multi_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,40 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
)


def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
model_name = "h2oai/h2ovl-mississippi-2b"

llm = LLM(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
limit_mm_per_prompt={"image": len(image_urls)},
mm_processor_kwargs={"max_dynamic_patch": 4},
)

placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)

# Stop tokens for H2OVL-Mississippi
# https://huggingface.co/h2oai/h2ovl-mississippi-2b
stop_token_ids = [tokenizer.eos_token_id]

return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)


def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
model_name = "OpenGVLab/InternVL2-2B"

Expand Down Expand Up @@ -258,6 +292,7 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:

model_example_map = {
"phi3_v": load_phi3v,
"h2ovl_chat": load_h2onvl,
"internvl_chat": load_internvl,
"NVLM_D": load_nvlm_d,
"qwen2_vl": load_qwen2_vl,
Expand Down
130 changes: 130 additions & 0 deletions tests/models/decoder_only/vision_language/test_h2ovl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from typing import Optional, Tuple

import pytest
import torch
from PIL.Image import Image
from transformers import AutoConfig

# Import the functions to test
from vllm.model_executor.models.h2ovl import (calculate_num_blocks,
image_to_pixel_values_wrapper)
from vllm.multimodal.utils import rescale_image_size

models = [
"h2oai/h2ovl-mississippi-800m", # Replace with your actual model names
"h2oai/h2ovl-mississippi-2b",
]
target_dtype = "bfloat16"


def run_preprocessing_test(
image: Image,
config,
max_dynamic_patch: Optional[int] = None,
) -> Tuple[torch.Tensor, int]:
"""Test the image preprocessing and calculate expected blocks."""

if max_dynamic_patch is None:
max_dynamic_patch = config.max_dynamic_patch

width, height = image.size
use_MSAC = config.use_msac

# Create the mapper function with the provided configuration
mapper = image_to_pixel_values_wrapper(config, max_dynamic_patch, use_MSAC)
pixel_values = mapper(image)

# Calculate the expected number of blocks
if use_MSAC:
# First pass
blocks1, _, _, aspect_ratio = calculate_num_blocks(
width,
height,
config.min_dynamic_patch,
max_dynamic_patch,
config.vision_config.image_size,
use_thumbnail=False, # Thumbnail is handled separately
prior_aspect_ratio=None,
)

# Second pass
blocks2, _, _, _ = calculate_num_blocks(
width,
height,
config.min_dynamic_patch,
max_dynamic_patch,
config.vision_config.image_size,
use_thumbnail=False,
prior_aspect_ratio=aspect_ratio,
)

# Add thumbnail if use_thumbnail is True and total_blocks > 1
if config.use_thumbnail:
blocks1 += 1 if blocks1 > 1 else 0
blocks2 += 1 if blocks2 > 1 else 0

# Total blocks is the sum of blocks from both passes minus overlapping
total_blocks = blocks1 + blocks2 - 1

expected_blocks = total_blocks

else:
blocks, _, _, _ = calculate_num_blocks(
width,
height,
config.min_dynamic_patch,
max_dynamic_patch,
config.vision_config.image_size,
use_thumbnail=False,
prior_aspect_ratio=None,
)
expected_blocks = blocks

if config.use_thumbnail and expected_blocks > 1:
expected_blocks += 1

return pixel_values, expected_blocks


@pytest.mark.parametrize("model_name", models)
@pytest.mark.parametrize(
"size_factors",
[
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("max_dynamic_patch", [None, 2, 4, 8])
def test_image_preprocessing(image_assets, model_name, size_factors,
max_dynamic_patch):
"""Test image preprocessing pipeline with different configurations."""
# Load the configuration from the model
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

for asset in image_assets:
image = asset.pil_image
for factor in size_factors:
scaled_image = rescale_image_size(image, factor)

# Test preprocessing and get expected number of blocks
pixel_values, expected_blocks = run_preprocessing_test(
scaled_image, config, max_dynamic_patch)

# Verify output shapes and properties
actual_blocks = pixel_values.shape[0]
assert actual_blocks == expected_blocks, (
f"Expected {expected_blocks} blocks, got {actual_blocks}")

# Check image dimensions
expected_size = (
3, # Number of channels (C, H, W)
config.vision_config.image_size,
config.vision_config.image_size,
)
for img in pixel_values:
assert img.shape == expected_size, (
f"Expected image size {expected_size}, got {img.shape}")
17 changes: 17 additions & 0 deletions tests/models/decoder_only/vision_language/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,23 @@
marks=[large_gpu_mark(min_gb=48)],
patch_hf_runner=model_utils.glm_patch_hf_runner,
),
"h2ovl": VLMTestInfo(
models = [
"h2oai/h2ovl-mississippi-800m",
"h2oai/h2ovl-mississippi-2b",
],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", # noqa: E501
single_image_prompts=IMAGE_ASSETS.prompts({
"stop_sign": "<image>\nWhat's the content in the center of the image?", # noqa: E501
"cherry_blossom": "<image>\nWhat is the season?",
}),
multi_image_prompt="Image-1: <image>\nImage-2: <image>\nDescribe the two images in short.", # noqa: E501
max_model_len=8192,
dtype="bfloat16",
use_tokenizer_eos=True,
patch_hf_runner=model_utils.h2ovl_patch_hf_runner,
),
"intern_vl": VLMTestInfo(
models=[
"OpenGVLab/InternVL2-1B",
Expand Down
60 changes: 60 additions & 0 deletions tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,66 @@ def processor(*args, text="", images=None, **kwargs):
return hf_model


def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for H2OVL."""

class H2OVLProcessor:
"""A simple processor for H2OVL models."""

def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
self.dtype = hf_runner.model.dtype

self.config = AutoConfig.from_pretrained(hf_runner.model_name,
trust_remote_code=True)
self.vision_config = self.config.vision_config
self.use_thumbnail = self.config.use_thumbnail
self.min_num = self.config.min_dynamic_patch
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size

def __call__(self, text: str, images: Union[Image, List[Image]],
**kwargs):
# yapf: disable
from vllm.model_executor.models.h2ovl import (
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)

# yapf: enable
images = [images] if isinstance(images, Image) else images
pixel_values = [
image_to_pixel_values(image,
self.image_size,
self.min_num,
self.max_num,
self.use_thumbnail,
use_MSAC=self.config.use_msac).to(
self.dtype) for image in images
]
num_patches_list = [
pixel_value.shape[0] for pixel_value in pixel_values
]
pixel_values = torch.cat(pixel_values, dim=0)
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token \
* num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)
prompt = self.tokenizer(text, return_tensors="pt")
prompt.update({"pixel_values": pixel_values})
return prompt

img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids(
"<IMG_CONTEXT>")
hf_model.model.img_context_token_id = img_context_token_id
hf_model.processor = H2OVLProcessor(hf_model)
hf_model.model.get_output_embeddings = lambda: \
hf_model.model.language_model.get_output_embeddings()
hf_model.model.generate = types.MethodType(_internvl_generate,
hf_model.model)
return hf_model


def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for InternVL."""

Expand Down
3 changes: 2 additions & 1 deletion vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def _placeholder_str(self, modality: ModalityStr,
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat", "NVLM_D"):
if model_type in ("chameleon", "internvl_chat", "NVLM_D",
"h2ovl_chat"):
return "<image>"
if model_type == "mllama":
return "<|image|>"
Expand Down
Loading

0 comments on commit 5459772

Please sign in to comment.