-
Notifications
You must be signed in to change notification settings - Fork 158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Video support for datatrove #271
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from datatrove.executor.base import PipelineExecutor | ||
from datatrove.executor.local import LocalPipelineExecutor | ||
from datatrove.pipeline.filters import VideoFrozenFilter | ||
from datatrove.pipeline.readers import VideoTripletReader | ||
|
||
|
||
def run_step_1():
    """Build and run a minimal test pipeline: read video triplets from S3 and
    drop videos that the frozen-video filter detects as static."""
    video_triplet_reader = VideoTripletReader(
        data_folder="s3://amotoratolins/datatrovetest/", metadata_origin="youtube"
    )

    video_frozen_filter = VideoFrozenFilter()

    pipeline_1 = [video_triplet_reader, video_frozen_filter]

    # Create the executor with the pipeline (single local worker/task)
    executor_1: PipelineExecutor = LocalPipelineExecutor(pipeline=pipeline_1, workers=1, tasks=1)

    # Execute the pipeline
    executor_1.run()


if __name__ == "__main__":
    # Run the testing pipeline only when executed as a script, not on import.
    run_step_1()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import shutil | ||
from typing import Tuple | ||
|
||
from loguru import logger | ||
|
||
from datatrove.data import Document | ||
from datatrove.pipeline.filters.base_filter import BaseFilter | ||
|
||
|
||
class VideoFrozenFilter(BaseFilter):
    """Filter that uses ffmpeg's ``freezedetect`` filter to detect and drop
    videos that are static (frozen).

    The video is sampled in 1-minute chunks every 5 minutes (skipping 10 seconds
    of padding at each end); if any sampled chunk is frozen for its entire
    length, the document is dropped with the reason ``"frozen_video"``.
    """

    name = "🧊 Video-Frozen-filter"
    # (import name, pip package name) so users are shown the correct package to install
    _requires_dependencies = [("ffmpeg", "ffmpeg-python")]

    def __init__(
        self, exclusion_writer=None, batch_size: int = 1, freeze_threshold: float = 0.005, freeze_duration: int = 60
    ):
        """
        Args:
            exclusion_writer: optionally pass in a writer that will save the dropped documents.
            batch_size: the number of documents to process in a batch.
            freeze_threshold: the noise threshold for detecting a frozen frame (default is 0.005).
            freeze_duration: the minimum duration (in seconds) that frames must be frozen to trigger detection (default is 60 seconds).
        """
        super().__init__(exclusion_writer, batch_size)
        self.ffmpeg = None  # lazily imported in is_video_frozen
        self.freeze_threshold = freeze_threshold
        self.freeze_duration = freeze_duration

        # Fail fast if the ffmpeg binary is missing (the python binding alone is not enough)
        if shutil.which("ffmpeg") is None:
            raise EnvironmentError(
                "ffmpeg is not installed. Please install it to use the VideoFrozenFilter. More details: https://www.ffmpeg.org/download.html"
            )

    def filter(self, doc: Document) -> bool | Tuple[bool, str]:
        """Keep the document unless its video file is detected as frozen."""
        import os

        video_path = doc.media[0].local_path if doc.media else None

        # A missing or unset local path cannot be analyzed: warn and keep the
        # document instead of crashing ffprobe on a nonexistent file (the
        # original code also raised TypeError when doc.media was empty).
        if not video_path or not os.path.exists(video_path):
            logger.warning(f"Video path does not exist: {video_path}")
            return True
        if self.is_video_frozen(video_path):
            return False, "frozen_video"
        return True

    def is_video_frozen(self, video_path: str) -> bool:
        """Dynamically determines intervals and checks if the video is frozen during those intervals."""
        if self.ffmpeg is None:
            import ffmpeg

            self.ffmpeg = ffmpeg

        video_duration = self.get_video_duration(video_path)

        # Adjusted video duration to account for 10-second padding
        effective_duration = video_duration - 20  # Remove 10 seconds from start and end

        if effective_duration <= 0:
            # Nothing left to analyze after padding: treat as not frozen
            return False

        # If the effective duration is very short, analyze the whole effective video
        if effective_duration < 300:
            intervals = [("10", str(effective_duration))]
        else:
            # Create intervals every 5 minutes (300 seconds), analyzing 1-minute chunks
            intervals = [(str(10 + i * 300), "60") for i in range(int(effective_duration // 300))]

            # Handle the remaining part of the video, if it exists
            remainder = effective_duration % 300
            if remainder > 0:
                intervals.append((str(video_duration - remainder - 10), str(remainder)))

        for start_time, duration in intervals:
            if self.check_freeze(video_path, start_time, duration):
                logger.info(f"{video_path} at {start_time} seen as frozen")
                return True
        return False

    def get_video_duration(self, video_path: str) -> float:
        """Get the duration of the video in seconds using ffmpeg.

        Raises:
            ffmpeg.Error: if ffprobe fails; stderr is logged before re-raising.
        """
        try:
            probe = self.ffmpeg.probe(video_path)
            return float(probe["format"]["duration"])
        except self.ffmpeg.Error as e:
            logger.info(f"ffprobe {video_path}:")
            logger.error(e.stderr.decode("utf-8"))
            raise e

    def check_freeze(self, video_path: str, start_time: str, duration: str) -> bool:
        """Check for frozen frames in a specific interval using ffmpeg's freezedetect filter.

        Returns True only when a freeze starts inside the interval and never
        ends within it, i.e. the interval is frozen through to its end.
        """
        try:
            out, err = (
                self.ffmpeg.input(video_path, ss=start_time, t=duration)
                .filter("freezedetect", n=self.freeze_threshold, d=self.freeze_duration)
                .output("null", f="null")
                .run(capture_stdout=True, capture_stderr=True)
            )
            err = err.decode("utf-8")
            # freezedetect reports freeze_start/freeze_end markers on stderr
            return "freeze_start" in err and "freeze_end" not in err
        except self.ffmpeg.Error as e:
            logger.error(f"Error processing video {video_path}: {e}")
            return False
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
import json | ||
import os | ||
import warnings | ||
from typing import Callable, Dict, List | ||
|
||
from datatrove.data import Document, DocumentsPipeline, Media, MediaType | ||
from datatrove.io import DataFileLike, DataFolderLike, download_file | ||
from datatrove.pipeline.readers.base import BaseDiskReader | ||
|
||
|
||
class VideoTripletReader(BaseDiskReader):
    """Read triplets of video, metadata, and optional caption files.

    A triplet is a video file (``.mp4``/``.avi``/``.mkv``/``.mov``) paired with
    a same-basename ``.json`` metadata file and, optionally, a ``.vtt`` caption
    file. Remote videos are downloaded into a local cache directory so
    downstream pipeline steps can access them through a local path.
    """

    name = "🎥 Video Triplet Reader"

    def __init__(
        self,
        data_folder: DataFolderLike,
        paths_file: DataFileLike | None = None,
        metadata_origin: str | None = None,
        limit: int = -1,
        skip: int = 0,
        file_progress: bool = False,
        doc_progress: bool = False,
        adapter: Callable = None,
        text_key: str = "text",
        id_key: str = "id",
        default_metadata: dict = None,
        recursive: bool = True,
        local_cache_dir="/tmp/local_video_cache",
    ):
        """
        Args:
            data_folder: folder (possibly remote, e.g. s3://) containing the triplets.
            paths_file: optional file listing the paths to read instead of listing the folder.
            metadata_origin: origin of the metadata json; "youtube" enables field extraction,
                any other value loads the raw json unprocessed.
            local_cache_dir: directory where remote videos are cached locally.
            (the remaining arguments are forwarded unchanged to BaseDiskReader)
        """
        self.metadata_origin = metadata_origin
        self.local_cache_dir = local_cache_dir
        os.makedirs(self.local_cache_dir, exist_ok=True)
        super().__init__(
            data_folder,
            paths_file,
            limit,
            skip,
            file_progress,
            doc_progress,
            adapter,
            text_key,
            id_key,
            default_metadata,
            recursive,
        )

    def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline:
        """Overrides the base run method to handle triplet statistics correctly."""
        if data:
            yield from data
        for triplet in self.find_triplets(rank, world_size):
            document = self.process_triplet(triplet)
            if document:
                self.stat_update("documents")  # track the number of triplets processed
                self.update_doc_stats(document)
                yield document

    def find_triplets(self, rank: int = 0, world_size: int = 1) -> List[Dict[str, str]]:
        """Find triplets of video, metadata, and caption files in the data folder.

        Returns this rank's shard of the triplets (every world_size-th triplet,
        starting at `rank`). A triplet is only emitted when its metadata json
        exists; the caption file is optional.
        """
        triplets = []
        video_extensions = (".mp4", ".avi", ".mkv", ".mov")
        metadata_extension = ".json"
        caption_extension = ".vtt"

        if self.paths_file:
            with self.data_folder.open(self.paths_file, "r") as f:
                paths = [line.strip() for line in f]
        else:
            paths = self.data_folder.list_files(recursive=self.recursive)

        for path in paths:
            base_name, ext = os.path.splitext(path)
            if ext in video_extensions:
                video_file = path
                metadata_file = base_name + metadata_extension
                caption_file = base_name + caption_extension

                if self.data_folder.exists(metadata_file):
                    triplet = {
                        "video": video_file,
                        "metadata": metadata_file,
                        "caption": caption_file if self.data_folder.exists(caption_file) else None,
                    }
                    triplets.append(triplet)
        # shard deterministically across ranks
        return triplets[rank::world_size]

    def read_file(self, filepath: str):
        """Yield documents for all triplets; `filepath` is ignored since
        triplets are discovered by scanning the data folder, not per-file."""
        for triplet in self.find_triplets():
            with self.track_time():
                document = self.process_triplet(triplet)
            if document:
                yield document

    def process_triplet(self, triplet: Dict[str, str]) -> Document | None:
        """Turn one triplet into a Document: caption text, video media, metadata."""
        video_path = triplet["video"]
        metadata_path = triplet["metadata"]
        caption_path = triplet["caption"]
        # document id is the video filename without its extension
        video_id = os.path.splitext(os.path.basename(video_path))[0]

        # Resolve the correct URL and local paths
        video_url = self.data_folder.resolve_paths(video_path)
        video_local_path = self.ensure_local_copy(video_url)

        # Load metadata, video, and caption data
        metadata = self.load_json(metadata_path)
        video_media = Media(type=MediaType.VIDEO, url=video_url, local_path=video_local_path)
        caption_text = self.load_caption(caption_path) if caption_path else ""

        document = Document(
            text=caption_text,
            id=video_id,
            media=[video_media],
            metadata=metadata,
        )

        return document

    def ensure_local_copy(self, video_url: str) -> str:
        """Ensure that the video is available locally. If not, download it."""
        if self.data_folder.is_local():
            return video_url

        local_path = os.path.join(self.local_cache_dir, os.path.basename(video_url))
        if not os.path.exists(local_path):
            download_file(video_url, local_path)
        return local_path

    def load_json(self, filepath: str) -> dict:
        """Load the metadata json, applying origin-specific processing when the
        origin is recognized; otherwise return the raw json with a warning."""
        with self.data_folder.open(filepath, "r") as f:
            data = json.load(f)

        if self.metadata_origin == "youtube":
            return self.process_youtube_metadata(data)
        if self.metadata_origin is None:
            warnings.warn("metadata_origin is not specified. Loading full JSON without processing.")
        else:
            # unknown origin: keep the original behavior (raw json) but surface the issue
            warnings.warn(
                f"Unrecognized metadata_origin {self.metadata_origin!r}. Loading full JSON without processing."
            )
        return data

    def load_caption(self, filepath: str) -> str:
        """Return the raw caption file contents (e.g. WebVTT text)."""
        with self.data_folder.open(filepath, "r") as f:
            return f.read()

    def process_youtube_metadata(self, data: dict) -> dict:
        """Extract a fixed subset of yt-dlp style metadata fields; missing keys become None."""
        processed_metadata = {
            "video_codec": data.get("vcodec"),
            "audio_codec": data.get("acodec"),
            "video_resolution": data.get("resolution"),
            "duration": data.get("duration_string"),
            "title": data.get("title"),
            "description": data.get("description"),
            "categories": data.get("categories"),
            "tags": data.get("tags"),
            "channel": data.get("channel"),
            "view_count": data.get("view_count"),
            "comment_count": data.get("comment_count"),
            "like_count": data.get("like_count"),
            "channel_follower_count": data.get("channel_follower_count"),
            "upload_date": data.get("upload_date"),
            "language": data.get("language"),
            "age_limit": data.get("age_limit"),
        }
        return processed_metadata
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should probably be this to show users the correct package name to install
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You forgot to accept the actual change here I think