add captions generation
VladOS95-cyber committed Nov 24, 2024
1 parent 9ecb971 commit 336def5
Showing 1 changed file with 86 additions and 19 deletions.
examples/inference/distributed/llava_next_video.py
@@ -17,11 +17,13 @@
 import pathlib
 import queue
 import time
+import av
 from concurrent.futures import ThreadPoolExecutor
 
 import fire
 import torch
-from datasets import load_dataset
+from huggingface_hub import snapshot_download
+import numpy as np
 from tqdm import tqdm
 from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor
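(The new av dependency is PyAV, installable with pip install av; it decodes the sampled video frames that np.stack assembles for the processor below.)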

@@ -46,13 +48,11 @@ def save_results(output_queue: queue.Queue, output_dir: pathlib.Path):
             item = output_queue.get(timeout=5)
             if item is None:
                 break
+            prompt, video, generated_text = item
             example_file = f"example_{count}"
             temp_dir = os.path.join(output_dir, example_file)
 
-            metadata = {
-                "caption": item[0],
-                "generated_answer": item[1],
-            }
+            metadata = {"prompt": prompt, "video": video, "generated_text": generated_text}
             with open(temp_dir, "w") as f:
                 json.dump(metadata, f, indent=4)
             count += 1
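
# A sketch of one saved "example_<n>" file under the new schema (values
# illustrative; generated_text is a list because processor.batch_decode()
# returns one string per sequence):
#
# {
#     "prompt": "USER: <video>\nGenerate caption ASSISTANT:",
#     "video": "/path/to/clip.mp4",
#     "generated_text": ["USER: ... ASSISTANT: A dog runs across a grassy field."]
# }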
@@ -61,23 +61,85 @@ def save_results(output_queue: queue.Queue, output_dir: pathlib.Path):
             continue
 
 
-def get_batches(captions, batch_size):
-    num_batches = (len(captions) + batch_size - 1) // batch_size
+def get_batches(processed_videos, batch_size):
+    num_batches = (len(processed_videos) + batch_size - 1) // batch_size
     batches = []
 
     for i in range(num_batches):
         start_index = i * batch_size
-        end_index = min((i + 1) * batch_size, len(captions))
-        batch = captions[start_index:end_index]
+        end_index = min((i + 1) * batch_size, len(processed_videos))
+        batch = processed_videos[start_index:end_index]
         batches.append(batch)
 
     return batches
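
# For intuition (not part of the commit): get_batches is plain fixed-size
# chunking, equivalent to this slice-based version for list inputs.
def get_batches_sliced(items, batch_size):
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]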


+def read_video_pyav(container, indices):
+    """
+    Decode the video with PyAV decoder.
+    Args:
+        container (`av.container.input.InputContainer`): PyAV container.
+        indices (`List[int]`): List of frame indices to decode.
+    Returns:
+        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+    """
+    frames = []
+    container.seek(0)
+    start_index = indices[0]
+    end_index = indices[-1]
+    for i, frame in enumerate(container.decode(video=0)):
+        if i > end_index:
+            break
+        if i >= start_index and i in indices:
+            frames.append(frame)
+    return np.stack([x.to_ndarray(format="rgb24") for x in frames])
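
# A minimal sketch (not in the file) of driving read_video_pyav on one clip;
# the path is hypothetical, and the sampling recipe mirrors process_videos below.
container = av.open("clip.mp4")
total_frames = container.streams.video[0].frames  # frame count from stream metadata
indices = np.arange(0, total_frames, total_frames / 8).astype(int)  # ~8 evenly spaced frames
clip = read_video_pyav(container, indices)  # -> (num_frames, height, width, 3) uint8 array
container.close()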


+def get_video_paths(video_dir):
+    """Get paths to all video files in the directory and its subdirectories."""
+    video_extensions = (".mp4", ".avi", ".mov", ".mkv")  # Add more extensions if needed
+    video_paths = []
+
+    for root, _, files in os.walk(video_dir):
+        for file in files:
+            if file.lower().endswith(video_extensions):
+                video_paths.append(os.path.join(root, file))
+
+    return video_paths


+def process_videos(video_paths, processor, prompt):
+    """Process a batch of videos and prepare them for the model."""
+    batch_inputs = []
+
+    for video_path in video_paths:
+        try:
+            container = av.open(video_path)
+            total_frames = container.streams.video[0].frames
+            indices = np.arange(0, total_frames, total_frames / 8).astype(int)
+            clip = read_video_pyav(container, indices)
+            container.close()
+
+            processed = processor(text=prompt, videos=clip, return_tensors="pt")
+            batch_inputs.append(
+                {
+                    "input_ids": processed["input_ids"],
+                    "pixel_values_videos": processed["pixel_values_videos"],
+                    "video": video_path,
+                }
+            )
+
+        except Exception as e:
+            print(f"Error processing video {video_path}: {str(e)}")
+            continue
+
+    return batch_inputs
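
# What one returned sample holds (a sketch; exact sizes depend on the processor
# config -- the leading 1 is the per-video batch dimension):
#
# sample = process_videos(video_paths, processor, prompt)[0]
# sample["input_ids"].shape            # torch.Size([1, seq_len]): tokenized prompt
# sample["pixel_values_videos"].shape  # e.g. torch.Size([1, 8, 3, H, W]): 8 sampled frames
# sample["video"]                      # source path, carried into the metadata JSON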


 def main(
     model_name: str = "llava-hf/LLaVA-NeXT-Video-7B-hf",
     save_dir: str = "./evaluation/examples",
-    max_captions: int = 10,
+    prompt: str = "USER: <video>\nGenerate caption ASSISTANT:",
     max_new_tokens: int = 100,
     batch_size: int = 4,
     dtype: str = "fp16",
@@ -99,20 +161,25 @@ def main(
     else:
         print(f"Directory '{save_dir}' already exists.")
 
-    captions = load_dataset("nkp37/OpenVid-1M", split="train")["caption"]
-    reduced_captions = captions[: min(len(captions), max_captions)]
-    batches = get_batches(reduced_captions, batch_size)
+    videos_dir = snapshot_download(repo_id="malterei/LLaVA-Video-small-swift", repo_type="dataset")
+    video_paths = get_video_paths(videos_dir)
+    processed_videos = process_videos(video_paths, processor, prompt)
+    batches = get_batches(processed_videos, batch_size)
 
     output_queue = queue.Queue()
     save_thread = ThreadPoolExecutor(max_workers=num_workers)
     save_future = save_thread.submit(save_results, output_queue, save_dir)
-    for _, caption_batch in tqdm(enumerate(batches), total=len(batches)):
+    for _, batch_raw in tqdm(enumerate(batches), total=len(batches)):
         try:
-            with distributed_state.split_between_processes(caption_batch) as caption:
-                input = processor(caption, padding=True, return_tensors="pt").to(model.device)
-                output = model.generate(**input, max_new_tokens=max_new_tokens)
-                generated_text = processor.batch_decode(output, skip_special_tokens=True)
-                output_queue.put((caption, generated_text))
+            with distributed_state.split_between_processes(batch_raw) as batched_inputs:
+                for batch in batched_inputs:
+                    output = model.generate(
+                        input_ids=batch["input_ids"].to(distributed_state.device),
+                        pixel_values_videos=batch["pixel_values_videos"].to(distributed_state.device, model.dtype),
+                        max_new_tokens=max_new_tokens,
+                    )
+                    generated_text = processor.batch_decode(output, skip_special_tokens=True)
+                    output_queue.put((prompt, batch["video"], generated_text))
         finally:
             output_queue.put(None)
             save_thread.shutdown(wait=True)
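
Because the loop shards each batch with distributed_state.split_between_processes, the script is presumably run under the Accelerate launcher, e.g. accelerate launch --num_processes 2 examples/inference/distributed/llava_next_video.py; each process then captions its slice of every batch and pushes (prompt, video path, decoded text) tuples to the writer thread.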
