From 54597724f4c6b52d50152f3cc46e86c101d9c820 Mon Sep 17 00:00:00 2001 From: shanshan wang Date: Sun, 3 Nov 2024 18:15:36 -0600 Subject: [PATCH] [Model] Add support for H2OVL-Mississippi models (#9747) Signed-off-by: Shanshan Wang Signed-off-by: Roger Wang Co-authored-by: Roger Wang --- docs/source/models/supported_models.rst | 6 + examples/offline_inference_vision_language.py | 28 +- ...e_inference_vision_language_multi_image.py | 35 ++ .../vision_language/test_h2ovl.py | 130 ++++++ .../vision_language/test_models.py | 17 + .../vision_language/vlm_utils/model_utils.py | 60 +++ vllm/entrypoints/chat_utils.py | 3 +- vllm/model_executor/models/h2ovl.py | 401 ++++++++++++++++++ vllm/model_executor/models/registry.py | 3 +- vllm/transformers_utils/config.py | 2 + vllm/transformers_utils/configs/__init__.py | 4 +- vllm/transformers_utils/configs/h2ovl.py | 13 + 12 files changed, 698 insertions(+), 4 deletions(-) create mode 100644 tests/models/decoder_only/vision_language/test_h2ovl.py create mode 100644 vllm/model_executor/models/h2ovl.py create mode 100644 vllm/transformers_utils/configs/h2ovl.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index a5c085bb84db9..55835d945b00c 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -440,6 +440,12 @@ Text Generation - :code:`THUDM/glm-4v-9b` etc. - - ✅︎ + * - :code:`H2OVLChatModel` + - H2OVL + - T + I\ :sup:`E+` + - :code:`h2oai/h2ovl-mississippi-800m`, :code:`h2oai/h2ovl-mississippi-2b`, etc. + - + - ✅︎ * - :code:`InternVLChatModel` - InternVL2 - T + I\ :sup:`E+` diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 60cdb186331fe..4fd002caf1763 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -176,6 +176,31 @@ def run_minicpmv(question: str, modality: str): return llm, prompt, stop_token_ids +# H2OVL-Mississippi +def run_h2ovl(question: str, modality: str): + assert modality == "image" + + model_name = "h2oai/h2ovl-mississippi-2b" + + llm = LLM( + model=model_name, + trust_remote_code=True, + max_model_len=8192, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + messages = [{'role': 'user', 'content': f"<image>\n{question}"}] + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + # Stop tokens for H2OVL-Mississippi + # https://huggingface.co/h2oai/h2ovl-mississippi-2b + stop_token_ids = [tokenizer.eos_token_id] + return llm, prompt, stop_token_ids + + # InternVL def run_internvl(question: str, modality: str): assert modality == "image" @@ -363,6 +388,7 @@ def run_glm4v(question: str, modality: str): "chameleon": run_chameleon, "minicpmv": run_minicpmv, "blip-2": run_blip2, + "h2ovl_chat": run_h2ovl, "internvl_chat": run_internvl, "NVLM_D": run_nvlm_d, "qwen_vl": run_qwen_vl, @@ -475,4 +501,4 @@ def main(args): default=16, help='Number of frames to extract from the video.') args = parser.parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index e28514bf403f7..d99684078ff3d 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -107,6 +107,40 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData: ) +def 
load_h2ovl(question: str, image_urls: List[str]) -> ModelRequestData: + model_name = "h2oai/h2ovl-mississippi-2b" + + llm = LLM( + model=model_name, + trust_remote_code=True, + max_model_len=8192, + limit_mm_per_prompt={"image": len(image_urls)}, + mm_processor_kwargs={"max_dynamic_patch": 4}, + ) + + placeholders = "\n".join(f"Image-{i}: <image>\n" + for i, _ in enumerate(image_urls, start=1)) + messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] + + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + # Stop tokens for H2OVL-Mississippi + # https://huggingface.co/h2oai/h2ovl-mississippi-2b + stop_token_ids = [tokenizer.eos_token_id] + + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=stop_token_ids, + image_data=[fetch_image(url) for url in image_urls], + chat_template=None, + ) + + def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData: model_name = "OpenGVLab/InternVL2-2B" @@ -258,6 +292,7 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData: model_example_map = { "phi3_v": load_phi3v, + "h2ovl_chat": load_h2ovl, "internvl_chat": load_internvl, "NVLM_D": load_nvlm_d, "qwen2_vl": load_qwen2_vl, diff --git a/tests/models/decoder_only/vision_language/test_h2ovl.py b/tests/models/decoder_only/vision_language/test_h2ovl.py new file mode 100644 index 0000000000000..ad9aa3104750b --- /dev/null +++ b/tests/models/decoder_only/vision_language/test_h2ovl.py @@ -0,0 +1,130 @@ +from typing import Optional, Tuple + +import pytest +import torch +from PIL.Image import Image +from transformers import AutoConfig + +# Import the functions to test +from vllm.model_executor.models.h2ovl import (calculate_num_blocks, + image_to_pixel_values_wrapper) +from vllm.multimodal.utils import rescale_image_size + +models = [ + "h2oai/h2ovl-mississippi-800m", + "h2oai/h2ovl-mississippi-2b", +] +target_dtype = "bfloat16" + + +def run_preprocessing_test( + image: Image, + config, + max_dynamic_patch: Optional[int] = None, +) -> Tuple[torch.Tensor, int]: + """Test the image preprocessing and calculate expected blocks.""" + + if max_dynamic_patch is None: + max_dynamic_patch = config.max_dynamic_patch + + width, height = image.size + use_MSAC = config.use_msac + + # Create the mapper function with the provided configuration + mapper = image_to_pixel_values_wrapper(config, max_dynamic_patch, use_MSAC) + pixel_values = mapper(image) + + # Calculate the expected number of blocks + if use_MSAC: + # First pass + blocks1, _, _, aspect_ratio = calculate_num_blocks( + width, + height, + config.min_dynamic_patch, + max_dynamic_patch, + config.vision_config.image_size, + use_thumbnail=False, # Thumbnail is handled separately + prior_aspect_ratio=None, + ) + + # Second pass + blocks2, _, _, _ = calculate_num_blocks( + width, + height, + config.min_dynamic_patch, + max_dynamic_patch, + config.vision_config.image_size, + use_thumbnail=False, + prior_aspect_ratio=aspect_ratio, + ) + + # Add thumbnail if use_thumbnail is True and total_blocks > 1 + if config.use_thumbnail: + blocks1 += 1 if blocks1 > 1 else 0 + blocks2 += 1 if blocks2 > 1 else 0 + + # Total blocks is the sum of blocks from both passes minus overlapping + total_blocks = blocks1 + blocks2 - 1 + + expected_blocks = total_blocks + + else: + blocks, _, _, _ = calculate_num_blocks( + width, + height, + config.min_dynamic_patch, + 
max_dynamic_patch, + config.vision_config.image_size, + use_thumbnail=False, + prior_aspect_ratio=None, + ) + expected_blocks = blocks + + if config.use_thumbnail and expected_blocks > 1: + expected_blocks += 1 + + return pixel_values, expected_blocks + + +@pytest.mark.parametrize("model_name", models) +@pytest.mark.parametrize( + "size_factors", + [ + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("max_dynamic_patch", [None, 2, 4, 8]) +def test_image_preprocessing(image_assets, model_name, size_factors, + max_dynamic_patch): + """Test image preprocessing pipeline with different configurations.""" + # Load the configuration from the model + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + + for asset in image_assets: + image = asset.pil_image + for factor in size_factors: + scaled_image = rescale_image_size(image, factor) + + # Test preprocessing and get expected number of blocks + pixel_values, expected_blocks = run_preprocessing_test( + scaled_image, config, max_dynamic_patch) + + # Verify output shapes and properties + actual_blocks = pixel_values.shape[0] + assert actual_blocks == expected_blocks, ( + f"Expected {expected_blocks} blocks, got {actual_blocks}") + + # Check image dimensions + expected_size = ( + 3, # Number of channels (C, H, W) + config.vision_config.image_size, + config.vision_config.image_size, + ) + for img in pixel_values: + assert img.shape == expected_size, ( + f"Expected image size {expected_size}, got {img.shape}") diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index e49ea6f98324d..cfd2d61f2b633 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -187,6 +187,23 @@ marks=[large_gpu_mark(min_gb=48)], patch_hf_runner=model_utils.glm_patch_hf_runner, ), + "h2ovl": VLMTestInfo( + models=[ + "h2oai/h2ovl-mississippi-800m", + "h2oai/h2ovl-mississippi-2b", + ], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts({ + "stop_sign": "<image>\nWhat's the content in the center of the image?", # noqa: E501 + "cherry_blossom": "<image>\nWhat is the season?", + }), + multi_image_prompt="Image-1: <image>\nImage-2: <image>\nDescribe the two images in short.", # noqa: E501 + max_model_len=8192, + dtype="bfloat16", + use_tokenizer_eos=True, + patch_hf_runner=model_utils.h2ovl_patch_hf_runner, + ), "intern_vl": VLMTestInfo( models=[ "OpenGVLab/InternVL2-1B", diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index e925934db0e7c..849857b4232e7 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -259,6 +259,66 @@ def processor(*args, text="", images=None, **kwargs): return hf_model +def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + """Patches and returns an instance of the HfRunner to use for H2OVL.""" + + class H2OVLProcessor: + """A simple processor for H2OVL models.""" + + def __init__(self, hf_runner: HfRunner): + self.num_image_token = hf_runner.model.num_image_token + self.tokenizer = hf_runner.tokenizer + self.dtype = hf_runner.model.dtype + + self.config = 
AutoConfig.from_pretrained(hf_runner.model_name, + trust_remote_code=True) + self.vision_config = self.config.vision_config + self.use_thumbnail = self.config.use_thumbnail + self.min_num = self.config.min_dynamic_patch + self.max_num = self.config.max_dynamic_patch + self.image_size = self.vision_config.image_size + + def __call__(self, text: str, images: Union[Image, List[Image]], + **kwargs): + # yapf: disable + from vllm.model_executor.models.h2ovl import ( + IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values) + + # yapf: enable + images = [images] if isinstance(images, Image) else images + pixel_values = [ + image_to_pixel_values(image, + self.image_size, + self.min_num, + self.max_num, + self.use_thumbnail, + use_MSAC=self.config.use_msac).to( + self.dtype) for image in images + ] + num_patches_list = [ + pixel_value.shape[0] for pixel_value in pixel_values + ] + pixel_values = torch.cat(pixel_values, dim=0) + for num_patches in num_patches_list: + context_tokens = IMG_CONTEXT * self.num_image_token \ + * num_patches + image_tokens = IMG_START + context_tokens + IMG_END + text = text.replace('<image>', image_tokens, 1) + prompt = self.tokenizer(text, return_tensors="pt") + prompt.update({"pixel_values": pixel_values}) + return prompt + + img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids( + "<IMG_CONTEXT>") + hf_model.model.img_context_token_id = img_context_token_id + hf_model.processor = H2OVLProcessor(hf_model) + hf_model.model.get_output_embeddings = lambda: \ + hf_model.model.language_model.get_output_embeddings() + hf_model.model.generate = types.MethodType(_internvl_generate, + hf_model.model) + return hf_model + + def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for InternVL.""" diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index bc2de2d162473..c9552977710d1 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -187,7 +187,8 @@ def _placeholder_str(self, modality: ModalityStr, if model_type.startswith("llava"): return self._cached_token_str(self._tokenizer, hf_config.image_token_index) - if model_type in ("chameleon", "internvl_chat", "NVLM_D"): + if model_type in ("chameleon", "internvl_chat", "NVLM_D", + "h2ovl_chat"): return "<image>" if model_type == "mllama": return "<|image|>" diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py new file mode 100644 index 0000000000000..43242fe370ba2 --- /dev/null +++ b/vllm/model_executor/models/h2ovl.py @@ -0,0 +1,401 @@ +# adapted from https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/modeling_h2ovl_chat.py +# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/image_process.py +# -------------------------------------------------------- +# H2OVL-Mississippi +# Copyright (c) 2024 H2O.AI +# Licensed under Apache 2.0 License [see LICENSE for details] +# -------------------------------------------------------- +from functools import partial +from typing import List, Optional, Tuple + +import torch +from PIL import Image +from transformers import PretrainedConfig + +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.utils import is_list_of + +from .intern_vit import InternVisionModel +from 
.internvl import (IMG_CONTEXT, IMG_END, IMG_START, InternVLChatModel, + InternVLInputPipeline, build_transform, + find_closest_aspect_ratio, get_internvl_num_patches) + + +# modified to include blocks generated in second pass +def calculate_num_blocks( + orig_width: int, + orig_height: int, + min_num: int, + max_num: int, + image_size: int, + use_thumbnail: bool, + prior_aspect_ratio=None, +) -> Tuple[int, int, int, Tuple[int, int]]: + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set((i, j) for n in range(min_num, max_num + 1) + for i in range(1, n + 1) for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # if prior_aspect_ratio is provided, filter the target ratios + if prior_aspect_ratio is not None: + target_ratios = [ + ratio for ratio in target_ratios if prior_aspect_ratio[0] % + ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0 + ] + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, + target_ratios, orig_width, + orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + # add thumbnail image if num_blocks > 1 + if use_thumbnail and blocks > 1: + blocks += 1 + return blocks, target_width, target_height, target_aspect_ratio + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +# refactored to handle prior_aspect_ratio as optional +def dynamic_preprocess( + image: Image.Image, + min_num: int, + max_num: int, + image_size: int, + use_thumbnail: bool, + prior_aspect_ratio: Optional[Tuple[int, int]] = None, +) -> Tuple[List[Image.Image], Tuple[int, int]]: + orig_width, orig_height = image.size + + # calculate the number of blocks based on prior aspect ratio if available + blocks, target_width, target_height, target_aspect_ratio = ( + calculate_num_blocks( + orig_width, + orig_height, + min_num, + max_num, + image_size, + use_thumbnail=False, + prior_aspect_ratio=prior_aspect_ratio, + )) + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images, target_aspect_ratio + + +def load_image( + image: Image.Image, + input_size=448, + min_num=1, + max_num=6, + use_thumbnail=True, + prior_aspect_ratio: Optional[Tuple[int, int]] = None, +) -> Tuple[torch.Tensor, Tuple[int, int]]: + transform = build_transform(input_size=input_size) + images, target_aspect_ratio = dynamic_preprocess( + image, + image_size=input_size, + use_thumbnail=use_thumbnail, + min_num=min_num, + max_num=max_num, + prior_aspect_ratio=prior_aspect_ratio, + ) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values, target_aspect_ratio + + +# refactored to use the 
combined load_image function +def image_to_pixel_values( + image: Image.Image, + input_size: int, + min_num: int, + max_num: int, + use_thumbnail: bool, + use_MSAC: bool, +) -> torch.Tensor: + # when MSAC is turned on, we need to process the image twice + if use_MSAC: + # first pass + pixel_values, target_aspect_ratio = load_image( + image, + input_size=input_size, + min_num=min_num, + max_num=max_num, + use_thumbnail=True, + ) + # second pass + pixel_values2, _ = load_image( + image, + input_size=input_size, + min_num=min_num, + max_num=max_num, + prior_aspect_ratio=target_aspect_ratio, + ) + # combine pixel values + pixel_values = torch.cat( + [pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0) + + else: + pixel_values, _ = load_image( + image, + input_size=input_size, + min_num=min_num, + max_num=max_num, + use_thumbnail=use_thumbnail, + ) + + return pixel_values + + +def image_to_pixel_values_wrapper(hf_config: PretrainedConfig, + max_dynamic_patch: Optional[int] = None, + use_MSAC: Optional[bool] = None): + image_size = hf_config.vision_config.image_size + min_num = hf_config.min_dynamic_patch + if max_dynamic_patch is None: + max_dynamic_patch = hf_config.max_dynamic_patch + if use_MSAC is None: + use_MSAC = hf_config.use_msac + use_thumbnail = hf_config.use_thumbnail + return partial( + image_to_pixel_values, + input_size=image_size, + min_num=min_num, + max_num=max_dynamic_patch, + use_thumbnail=use_thumbnail, + use_MSAC=use_MSAC, + ) + + +def get_max_internvl_image_tokens(ctx: InputContext, + *, + max_dynamic_patch: Optional[int] = None): + """ + Calculate the maximum number of tokens with/without MSAC and thumbnail + """ + hf_config = ctx.get_hf_config() + use_thumbnail = hf_config.use_thumbnail + use_MSAC = hf_config.use_msac + + if max_dynamic_patch is None: + max_dynamic_patch = hf_config.max_dynamic_patch + + num_patches = get_internvl_num_patches(hf_config) + + coefficient = 2 if use_MSAC else 1 + num_blocks = coefficient * max_dynamic_patch + (1 if use_thumbnail else 0) + + return num_blocks * num_patches + + +class H2OVLInputPipeline(InternVLInputPipeline): + """ + Input pipeline for processing image and text data for the H2OVL model. 
+ """ + + def input_processor( + self, + ctx: InputContext, + inputs: DecoderOnlyInputs, + *, + max_dynamic_patch: Optional[int] = None, + ) -> DecoderOnlyInputs: + # get multi_modal_data + multi_modal_data = inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return inputs + + model_config = ctx.model_config + hf_config = ctx.get_hf_config() + use_MSAC = hf_config.use_msac + + image_data = multi_modal_data["image"] + num_patches = get_internvl_num_patches(hf_config) + + image_pixel_values_mapper = image_to_pixel_values_wrapper( + hf_config, max_dynamic_patch=max_dynamic_patch) + + # single image + if isinstance(image_data, Image.Image): + pixel_values = image_pixel_values_mapper(image_data, + use_MSAC=use_MSAC) + num_blocks = pixel_values.shape[0] + image_feature_sizes = [num_blocks * num_patches] + pixel_values = pixel_values.unsqueeze(0) + + # multi images + elif is_list_of(image_data, Image.Image): + # Do not use MSAC for multi images + image_feature_sizes = [] + pixel_values = [ + image_pixel_values_mapper(image, use_MSAC=False) + for image in image_data + ] + for pixel_value in pixel_values: + num_blocks = pixel_value.shape[0] + image_feature_sizes.append(num_blocks * num_patches) + + # image embeddings as input + elif isinstance(image_data, torch.Tensor): + _, image_feature_size, _ = image_data.shape + image_feature_sizes = [image_feature_size] + pixel_values = None + + # multi-image image embeddings + elif is_list_of(image_data, torch.Tensor): + + image_feature_sizes = [] + for image_embed in image_data: + _, image_feature_size, _ = image_embed.shape + image_feature_sizes.append(image_feature_size) + pixel_values = None + + else: + raise TypeError(f"Invalid image type: {type(image_data)}") + + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) + + prompt = inputs.get("prompt") + prompt_token_ids = inputs["prompt_token_ids"] + if prompt is None: + prompt = tokenizer.decode(prompt_token_ids) + + new_prompt = self._expand_image_prompt(prompt, image_feature_sizes, + num_patches) + new_prompt_token_ids = tokenizer.encode(new_prompt) + + # Wrap image processing in input_processor to avoid duplication + image_token_id = tokenizer.encode( + self.img_context_token, + add_special_tokens=False, + return_tensors="pt", + )[0] + + # Update multi_modal_data to return + if pixel_values is not None: + multi_modal_data = { + "image": { + "pixel_values": pixel_values, + "image_token_id": image_token_id, + } + } + else: + multi_modal_data = {"image": {"image_embeds": image_data}} + + return token_inputs( + prompt=prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data, + ) + + def input_mapper( + self, + ctx: InputContext, + data: object, + *, + max_dynamic_patch: Optional[int] = None, + ) -> MultiModalInputs: + + # NOTE: Preprocessing for the image data is done in the + # 'input_processor' function during actual inference. + if isinstance(data, dict): + return MultiModalInputs(data) + + # The section below is only used with dummy data during + # memory profiling. 
+ hf_config = ctx.get_hf_config() + + image_pixel_values_mapper = image_to_pixel_values_wrapper( + hf_config, max_dynamic_patch) + + if isinstance(data, Image.Image): + pixel_values = image_pixel_values_mapper(data) + pixel_values = pixel_values.unsqueeze(0) + + elif is_list_of(data, Image.Image): + hf_config.use_msac = False + pixel_values = [image_pixel_values_mapper(img) for img in data] + + else: + return MultiModalInputs({"image_embeds": data}) + model_config = ctx.model_config + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) + image_token_id = tokenizer.encode( + self.img_context_token, + add_special_tokens=False, + return_tensors="pt", + )[0] + + return MultiModalInputs({ + "pixel_values": pixel_values, + "image_token_id": image_token_id + }) + + +input_pipeline = H2OVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT) + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens) +@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data) +@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor) +class H2OVLChatModel(InternVLChatModel): + + def _init_vision_model( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + *, + is_mono: bool, + prefix: str, + ): + if not is_mono: + vision_feature_layer = config.select_layer + if vision_feature_layer < 0: + num_hidden_layers = (config.vision_config.num_hidden_layers + + vision_feature_layer + 1) + else: + num_hidden_layers = vision_feature_layer + 1 + + return InternVisionModel( + config.vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers, + prefix=prefix, + ) + else: + msg = "Monolith mode is not applicable to H2OVL" + raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index f50ceaccb1bbe..3a929f5cb5195 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -128,6 +128,7 @@ def add_embedding_models(base_models, embedding_models): "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), + "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"), "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 @@ -482,4 +483,4 @@ def _run() -> None: if __name__ == "__main__": - _run() + _run() \ No newline at end of file diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 9bd2531d7a15c..08697274854e0 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -19,6 +19,7 @@ # yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, EAGLEConfig, ExaoneConfig, + H2OVLChatConfig, InternVLChatConfig, JAISConfig, MedusaConfig, MllamaConfig, MLPSpeculatorConfig, MPTConfig, @@ -52,6 +53,7 @@ "medusa": MedusaConfig, "eagle": EAGLEConfig, "exaone": ExaoneConfig, + "h2ovl_chat": H2OVLChatConfig, "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, "NVLM_D": NVLM_D_Config, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index f0d79197a82c5..d1e19c9a33c24 
100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -6,6 +6,7 @@ # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig +from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.medusa import MedusaConfig @@ -22,6 +23,7 @@ "DbrxConfig", "MPTConfig", "RWConfig", + "H2OVLChatConfig", "InternVLChatConfig", "JAISConfig", "MedusaConfig", @@ -33,4 +35,4 @@ "NVLM_D_Config", "SolarConfig", "UltravoxConfig", -] +] \ No newline at end of file diff --git a/vllm/transformers_utils/configs/h2ovl.py b/vllm/transformers_utils/configs/h2ovl.py new file mode 100644 index 0000000000000..b94c5b77e4b7f --- /dev/null +++ b/vllm/transformers_utils/configs/h2ovl.py @@ -0,0 +1,13 @@ +# Adapted from +# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/configuration_h2ovl_chat.py +# -------------------------------------------------------- +# H2OVL-Mississippi +# Copyright (c) 2024 H2O.AI +# Licensed under Apache 2.0 License [see LICENSE for details] +# -------------------------------------------------------- + +from .internvl import InternVLChatConfig + + +class H2OVLChatConfig(InternVLChatConfig): + model_type = "h2ovl_chat"
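For reference, the new support can be exercised end to end with a short offline-inference script. The sketch below mirrors the run_h2ovl() example added in examples/offline_inference_vision_language.py combined with vLLM's standard generate() API; the image path, question, and sampling settings are illustrative assumptions, not part of the patch.

from PIL import Image
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

model_name = "h2oai/h2ovl-mississippi-2b"

# Same engine arguments as the run_h2ovl() example in the patch.
llm = LLM(model=model_name, trust_remote_code=True, max_model_len=8192)

# Build the chat prompt; "<image>" is the placeholder the h2ovl_chat model type expects.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
messages = [{"role": "user", "content": "<image>\nWhat is shown in this image?"}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

image = Image.open("example.jpg").convert("RGB")  # illustrative local image
sampling_params = SamplingParams(max_tokens=128,
                                 stop_token_ids=[tokenizer.eos_token_id])

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)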
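The get_max_internvl_image_tokens() helper in the new h2ovl.py doubles the dynamic-patch block budget when MSAC is enabled and adds one extra block for the thumbnail. A worked example of that arithmetic is sketched below; the configuration values are assumptions chosen for illustration (not read from the released checkpoints), and the per-block token count follows the get_internvl_num_patches() formula reused from the InternVL code path.

# Assumed configuration values, for illustration only.
image_size = 448          # hf_config.vision_config.image_size
patch_size = 14           # hf_config.vision_config.patch_size
downsample_ratio = 0.5    # hf_config.downsample_ratio
max_dynamic_patch = 6     # hf_config.max_dynamic_patch
use_thumbnail = True      # hf_config.use_thumbnail
use_MSAC = True           # hf_config.use_msac

# Tokens contributed by one image_size x image_size block, as in
# get_internvl_num_patches(): (448 // 14) ** 2 * 0.5 ** 2 = 256.
num_patches = int((image_size // patch_size) ** 2 * downsample_ratio ** 2)

# MSAC runs two tiling passes over the image, so the block budget doubles;
# the thumbnail contributes one more block.
coefficient = 2 if use_MSAC else 1
num_blocks = coefficient * max_dynamic_patch + (1 if use_thumbnail else 0)

max_image_tokens = num_blocks * num_patches  # 13 * 256 = 3328
print(num_patches, num_blocks, max_image_tokens)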