small refactoring
VladOS95-cyber committed Nov 24, 2024
1 parent 336def5 commit a71dabe
Showing 1 changed file with 18 additions and 18 deletions.
examples/inference/distributed/llava_next_video.py (36 changes: 18 additions & 18 deletions)
@@ -17,13 +17,13 @@
 import pathlib
 import queue
 import time
-import av
 from concurrent.futures import ThreadPoolExecutor
 
+import av
 import fire
+import numpy as np
 import torch
 from huggingface_hub import snapshot_download
-import numpy as np
 from tqdm import tqdm
 from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor
 
@@ -108,26 +108,25 @@ def get_video_paths(video_dir):
     return video_paths
 
 
-def process_videos(video_paths, processor, prompt):
+def process_videos(video_paths, processor, prompt, frames_per_video):
     """Process a batch of videos and prepare them for the model."""
     batch_inputs = []
 
     for video_path in video_paths:
         try:
-            container = av.open(video_path)
-            total_frames = container.streams.video[0].frames
-            indices = np.arange(0, total_frames, total_frames / 8).astype(int)
-            clip = read_video_pyav(container, indices)
-            container.close()
-
-            processed = processor(text=prompt, videos=clip, return_tensors="pt")
-            batch_inputs.append(
-                {
-                    "input_ids": processed["input_ids"],
-                    "pixel_values_videos": processed["pixel_values_videos"],
-                    "video": video_path,
-                }
-            )
+            with av.open(video_path) as container:
+                total_frames = container.streams.video[0].frames
+                indices = np.arange(0, total_frames, total_frames / frames_per_video).astype(int)
+                clip = read_video_pyav(container, indices)
+
+                processed = processor(text=prompt, videos=clip, return_tensors="pt")
+                batch_inputs.append(
+                    {
+                        "input_ids": processed["input_ids"],
+                        "pixel_values_videos": processed["pixel_values_videos"],
+                        "video": video_path,
+                    }
+                )
 
         except Exception as e:
             print(f"Error processing video {video_path}: {str(e)}")
@@ -140,6 +139,7 @@ def main(
     model_name: str = "llava-hf/LLaVA-NeXT-Video-7B-hf",
     save_dir: str = "./evaluation/examples",
     prompt: str = "USER: <video>\nGenerate caption ASSISTANT:",
+    frames_per_video: int = 8,
     max_new_tokens: int = 100,
     batch_size: int = 4,
     dtype: str = "fp16",
@@ -163,7 +163,7 @@
 
     videos_dir = snapshot_download(repo_id="malterei/LLaVA-Video-small-swift", repo_type="dataset")
     video_paths = get_video_paths(videos_dir)
-    processed_videos = process_videos(video_paths, processor, prompt)
+    processed_videos = process_videos(video_paths, processor, prompt, frames_per_video)
     batches = get_batches(processed_videos, batch_size)
 
     output_queue = queue.Queue()
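
Since fire is imported at the top of the script, main's keyword arguments are presumably exposed as command-line flags through a fire.Fire(main) entry point (outside the hunks shown). Assuming the usual accelerate launcher for this distributed example, a hypothetical invocation overriding the new default would look like:

accelerate launch examples/inference/distributed/llava_next_video.py \
    --frames_per_video 16 \
    --batch_size 2

With the default of 8 frames per video, existing callers see no change in behavior.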
