Skip to content

Commit

Permalink
Fix number of patch check for different vision feature select strategy (huggingface#32494)
Browse files Browse the repository at this point in the history

* Fix number of patch check for different vision feature select strategy

* add test

---------

Co-authored-by: raushan <[email protected]>
  • Loading branch information
insujang and zucchini-nlp authored Sep 17, 2024
1 parent 18e1a9c commit bcf8946
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
13 changes: 11 additions & 2 deletions src/transformers/models/llava_next/modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def _merge_input_ids_with_image_features(

return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids

def pack_image_features(self, image_features, image_sizes, image_newline=None):
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
Expand All @@ -654,6 +654,8 @@ def pack_image_features(self, image_features, image_sizes, image_newline=None):
List of image feature tensors, each containing all the visual features of all patches.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each image (H, W).
vision_feature_select_strategy (`str`)
The feature selection strategy used to select the vision feature from the vision backbone.
image_newline (`torch.Tensor` of shape `(embed_dim)`)
New line embedding vector.
Returns:
Expand All @@ -668,8 +670,14 @@ def pack_image_features(self, image_features, image_sizes, image_newline=None):
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
if height * width != base_image_feature.shape[0]:

if vision_feature_select_strategy == "default":
expected_num_patches = height * width
elif vision_feature_select_strategy == "full":
expected_num_patches = height * width + 1
if expected_num_patches != base_image_feature.shape[0]:
raise ValueError("The number of patches is not consistent with the image size.")

num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
Expand Down Expand Up @@ -825,6 +833,7 @@ def forward(
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
vision_feature_select_strategy=vision_feature_select_strategy,
image_newline=self.image_newline,
)
if legacy_processing:
Expand Down
21 changes: 21 additions & 0 deletions tests/models/llava_next/test_modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,3 +620,24 @@ def test_expansion_in_processing(self):

# check that both inputs are handled correctly and generate the same output
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())

@slow
@require_bitsandbytes
def test_small_model_integration_test_full_vision_state_selection(self):
    """Generate after assigning the "full" vision feature select strategy on the model.

    Loads the checkpoint with `load_in_4bit=True`, sets
    `vision_feature_select_strategy` to `"full"` on the model object, then
    checks that generation runs and decodes to the expected text.
    """
    EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes' # fmt: skip

    model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", load_in_4bit=True)
    # switching the strategy after loading must not make generation raise
    model.vision_feature_select_strategy = "full"

    model_inputs = self.processor(self.prompt, self.image, return_tensors="pt")

    # run generation and compare the decoded first sequence with the reference text
    generated_ids = model.generate(**model_inputs, max_new_tokens=30)
    decoded_text = self.processor.decode(generated_ids[0], skip_special_tokens=True)

    self.assertEqual(decoded_text, EXPECTED_DECODED_TEXT)

0 comments on commit bcf8946

Please sign in to comment.