From bcf8946f0acb578c534b1d33d534450d1fc88507 Mon Sep 17 00:00:00 2001 From: Insu Jang Date: Tue, 17 Sep 2024 03:33:07 -0400 Subject: [PATCH] Fix number of patch check for different vision feature select strategy (#32494) * Fix number of patch check for different vision feature select strategy * add test --------- Co-authored-by: raushan --- .../models/llava_next/modeling_llava_next.py | 13 ++++++++++-- .../llava_next/test_modeling_llava_next.py | 21 +++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index ebb4da3102da42..c1d1ca8c276d7a 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -645,7 +645,7 @@ def _merge_input_ids_with_image_features( return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids - def pack_image_features(self, image_features, image_sizes, image_newline=None): + def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): """ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. @@ -654,6 +654,8 @@ def pack_image_features(self, image_features, image_sizes, image_newline=None): List of image feature tensor, each contains all the visual feature of all patches. image_sizes (`torch.Tensor` of shape `(num_images, 2)`) Actual image size of each images (H, W). + vision_feature_select_strategy (`str`) + The feature selection strategy used to select the vision feature from the vision backbone. image_newline (`torch.Tensor` of shape `(embed_dim)`) New line embedding vector. Returns: @@ -668,8 +670,14 @@ def pack_image_features(self, image_features, image_sizes, image_newline=None): base_image_feature = image_feature[0] image_feature = image_feature[1:] height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size - if height * width != base_image_feature.shape[0]: + + if vision_feature_select_strategy == "default": + expected_num_patches = height * width + elif vision_feature_select_strategy == "full": + expected_num_patches = height * width + 1 + if expected_num_patches != base_image_feature.shape[0]: raise ValueError("The number of patches is not consistent with the image size.") + num_patch_height, num_patch_width = get_anyres_image_grid_shape( image_sizes[image_idx], self.config.image_grid_pinpoints, @@ -825,6 +833,7 @@ def forward( image_features, feature_lens = self.pack_image_features( image_features, image_sizes, + vision_feature_select_strategy=vision_feature_select_strategy, image_newline=self.image_newline, ) if legacy_processing: diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index 3120db216ea4bb..bd0b5a19064650 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -620,3 +620,24 @@ def test_expansion_in_processing(self): # check that both inputs are handled correctly and generate the same output self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_full_vision_state_selection(self): + model = LlavaNextForConditionalGeneration.from_pretrained( + "llava-hf/llava-v1.6-mistral-7b-hf", + load_in_4bit=True, + ) + # test that changing `strategy` won't error out + model.vision_feature_select_strategy = "full" + + inputs = self.processor(self.prompt, self.image, return_tensors="pt") + + # verify generation + output = model.generate(**inputs, max_new_tokens=30) + EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes' # fmt: skip + + self.assertEqual( + self.processor.decode(output[0], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + )