Add video processing logic for idefics2 (#418)
kcz358 authored Nov 23, 2024
1 parent 9baa13e commit d65e0e2
Showing 1 changed file with 17 additions and 2 deletions.
lmms_eval/models/idefics2.py (19 changes: 17 additions & 2 deletions)
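
Summary of the change: when the visual passed to the generation loop is a file path (a video) rather than an image, the model now decodes up to max_frames_num frames with load_video_decord, appends one image slot to the chat content per frame, converts each frame to a PIL image with to_pil_image, and hands those frames to the processor in place of the raw visuals.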
@@ -4,13 +4,15 @@
 import torch
 from accelerate import Accelerator, DistributedType
 from accelerate.state import AcceleratorState
+from torchvision.transforms.functional import to_pil_image
 from tqdm import tqdm
 from transformers import AutoProcessor, Idefics2ForConditionalGeneration
 
 from lmms_eval import utils
 from lmms_eval.api.instance import Instance
 from lmms_eval.api.model import lmms
 from lmms_eval.api.registry import register_model
+from lmms_eval.models.model_utils.load_video import load_video_decord
 
 warnings.filterwarnings("ignore")
 
@@ -53,6 +55,7 @@ def __init__(
         device_map: str = "",
         use_cache: bool = True,
         do_image_splitting: bool = False,
+        max_frames_num: int = 16,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -70,6 +73,7 @@ def __init__(
             dtype = getattr(torch, dtype)
         self._model = Idefics2ForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
         self._processor = AutoProcessor.from_pretrained(pretrained, do_image_splitting=do_image_splitting, revision=revision, trust_remote_code=trust_remote_code)
+        self.max_frames_num = max_frames_num
 
         self._tokenizer = self._processor.tokenizer
         self._config = self._model.config
@@ -205,16 +209,27 @@ def _collate(x):
             gen_kwargs["temperature"] = 0
 
             prompts = []
+            videos = None
             for context, visual in zip(contexts, visuals):
                 content = []
-                if DEFAULT_IMAGE_TOKEN not in context:
+                if isinstance(visual[0], str):
+                    videos = load_video_decord(visual[0], max_frames_num=self.max_frames_num)
+                    for _ in range(videos.shape[0]):
+                        content.append({"type": "image"})
+                elif DEFAULT_IMAGE_TOKEN not in context:
                     for image in visual:
                         content.append({"type": "image"})
                 content.append({"type": "text", "text": context})
                 message = [{"role": "user", "content": content}]
                 prompt = self._processor.apply_chat_template(message, add_generation_prompt=True)
                 prompts.append(prompt)
-            inputs = self._processor(text=prompts, images=visuals, padding=True, return_tensors="pt")
+            if videos is not None:
+                images = []
+                for frame in videos:
+                    images.append(to_pil_image(frame))
+                inputs = self._processor(text=prompts, images=images, padding=True, return_tensors="pt")
+            else:
+                inputs = self._processor(text=prompts, images=visuals, padding=True, return_tensors="pt")
             inputs = {k: v.to(self.device) for k, v in inputs.items()}
             output_ids = self.model.generate(**inputs, **gen_kwargs)
             # only retain the generated text
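
For readers skimming the diff, here is a self-contained sketch of the video path this commit introduces. It is an illustration, not the repository's code: sample_frames is a hypothetical stand-in for load_video_decord, written on the assumption that the helper uniformly samples up to max_frames_num frames and returns them as a (num_frames, H, W, 3) uint8 array that to_pil_image can consume.

import numpy as np
from decord import VideoReader, cpu
from torchvision.transforms.functional import to_pil_image

def sample_frames(video_path: str, max_frames_num: int = 16) -> np.ndarray:
    # Hypothetical stand-in for lmms_eval's load_video_decord: uniformly
    # sample up to max_frames_num frames as a (N, H, W, 3) uint8 array.
    vr = VideoReader(video_path, ctx=cpu(0))
    indices = np.linspace(0, len(vr) - 1, max_frames_num, dtype=int)
    return vr.get_batch(indices).asnumpy()

frames = sample_frames("clip.mp4")  # assumed local video file
content = [{"type": "image"} for _ in range(frames.shape[0])]  # one image slot per frame
content.append({"type": "text", "text": "Describe the video."})
images = [to_pil_image(frame) for frame in frames]  # the processor expects PIL images

Downstream, the processor call is unchanged except that it receives these frames: processor(text=prompts, images=images, padding=True, return_tensors="pt"). The new max_frames_num knob defaults to 16 and can presumably be set through lmms_eval's --model_args string (e.g. pretrained=...,max_frames_num=32), following how the other constructor arguments are passed.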
