diff --git a/README.md b/README.md
index 6ee00eac..a62d2b51 100644
--- a/README.md
+++ b/README.md
@@ -48,6 +48,7 @@ Available flavours (combine them with `,` i.e. `[processing,s3]`):
 - `processing` dependencies for text extraction, filtering and tokenization: `pip install datatrove[processing]`
 - `s3` s3 support: `pip install datatrove[s3]`
 - `cli` for command line tools: `pip install datatrove[cli]`
+- `video` dependencies to work with video data: `pip install datatrove[video]`. You will also need ffmpeg installed on your system; see https://www.ffmpeg.org/download.html for details
 
 ## Quickstart examples
 You can check the following [examples](examples):
diff --git a/examples/media_experiment.py b/examples/media_experiment.py
new file mode 100644
index 00000000..ab67b667
--- /dev/null
+++ b/examples/media_experiment.py
@@ -0,0 +1,34 @@
+from datatrove.executor.base import PipelineExecutor
+from datatrove.executor.local import LocalPipelineExecutor
+from datatrove.pipeline.filters import VideoFrozenFilter
+from datatrove.pipeline.readers import VideoTripletReader
+
+
+def run_step_1():
+    video_triplet_reader = VideoTripletReader(
+        data_folder="s3://amotoratolins/datatrovetest/", metadata_origin="youtube"
+    )
+
+    video_frozen_filter = VideoFrozenFilter()
+
+    pipeline_1 = [video_triplet_reader, video_frozen_filter]
+
+    # Create the executor with the pipeline
+    executor_1: PipelineExecutor = LocalPipelineExecutor(pipeline=pipeline_1, workers=1, tasks=1)
+
+    # Execute the pipeline
+    executor_1.run()
+
+
+# # Additional debugging
+# for document in video_triplet_reader.read_file(None):
+#     print(f"Document ID: {document.id}")
+#     print(f"Text: {document.text[:100]}...")  # Print first 100 characters of text
+#     print(f"Media: {document.media}")
+#     print(f"Metadata: {document.metadata}")
+#     print("-" * 80)
+
+
+# Run the example pipeline
+if __name__ == "__main__":
+    run_step_1()
diff --git a/pyproject.toml b/pyproject.toml
index a290adf6..e169a3eb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -92,10 +92,16 @@
 all = [
     "datatrove[quality]",
     "datatrove[testing]",
 ]
+
+video = [
+    "ffmpeg-python"
+]
+
 dev = [
-    "datatrove[all]"
+    "datatrove[all]",
+    "datatrove[video]"
 ]
 
 [project.urls]
 Repository = "https://github.com/huggingface/datatrove"
diff --git a/src/datatrove/pipeline/filters/__init__.py b/src/datatrove/pipeline/filters/__init__.py
index 065496a2..1620f7ee 100644
--- a/src/datatrove/pipeline/filters/__init__.py
+++ b/src/datatrove/pipeline/filters/__init__.py
@@ -9,3 +9,4 @@
 from .sampler_filter import SamplerFilter
 from .unigram_log_probs import UnigramLogProbFilter
 from .url_filter import URLFilter
+from .video_frozen_filter import VideoFrozenFilter
diff --git a/src/datatrove/pipeline/filters/video_frozen_filter.py b/src/datatrove/pipeline/filters/video_frozen_filter.py
new file mode 100644
index 00000000..2fc569b8
--- /dev/null
+++ b/src/datatrove/pipeline/filters/video_frozen_filter.py
@@ -0,0 +1,106 @@
+import os
+import shutil
+from typing import Tuple
+
+from loguru import logger
+
+from datatrove.data import Document
+from datatrove.pipeline.filters.base_filter import BaseFilter
+
+
+class VideoFrozenFilter(BaseFilter):
+    """Filter that uses ffmpeg to detect if a video is static (frozen)."""
+
+    name = "🧊 Video-Frozen-filter"
+    _requires_dependencies = [("ffmpeg", "ffmpeg-python")]
+
+    def __init__(
+        self, exclusion_writer=None, batch_size: int = 1, freeze_threshold: float = 0.005, freeze_duration: int = 60
+    ):
+        """
+        Args:
+            exclusion_writer: optionally pass in a writer that will save the dropped documents.
+            batch_size: the number of documents to process in a batch.
+            freeze_threshold: the noise threshold for detecting a frozen frame (default is 0.005).
+            freeze_duration: the minimum duration (in seconds) that frames must be frozen to trigger detection (default is 60 seconds).
+        """
+        super().__init__(exclusion_writer, batch_size)
+        self.ffmpeg = None
+        self.freeze_threshold = freeze_threshold
+        self.freeze_duration = freeze_duration
+
+        # Fail early if the ffmpeg binary is not installed on the system
+        if shutil.which("ffmpeg") is None:
+            raise OSError(
+                "ffmpeg is not installed. Please install it to use the VideoFrozenFilter. More details: https://www.ffmpeg.org/download.html"
+            )
+
+    def filter(self, doc: Document) -> bool | Tuple[bool, str]:
+        video_path = doc.media[0].local_path if doc.media else None
+        if not video_path or not os.path.exists(video_path):
+            logger.warning(f"Video path does not exist: {video_path}")
+            return False, "missing_video"
+        if self.is_video_frozen(video_path):
+            return False, "frozen_video"
+        return True
+
+    def is_video_frozen(self, video_path: str) -> bool:
+        """Dynamically determines intervals and checks if the video is frozen during those intervals."""
+        # Import lazily so the Python binding is only needed once the filter actually runs
+        if self.ffmpeg is None:
+            import ffmpeg
+
+            self.ffmpeg = ffmpeg
+
+        video_duration = self.get_video_duration(video_path)
+
+        # Effective duration ignores a 10-second padding at the start and at the end
+        effective_duration = video_duration - 20
+
+        if effective_duration <= 0:
+            # Nothing left to analyze after removing the padding
+            return False
+
+        if effective_duration < 300:
+            # For short videos, analyze the whole effective duration in one pass
+            intervals = [("10", str(effective_duration))]
+        else:
+            # Otherwise, analyze a 1-minute chunk every 5 minutes (300 seconds)
+            intervals = [(str(10 + i * 300), "60") for i in range(int(effective_duration // 300))]
+
+            # Handle the remaining part of the video, if any
+            remainder = effective_duration % 300
+            if remainder > 0:
+                intervals.append((str(video_duration - remainder - 10), str(remainder)))
+
+        for start_time, duration in intervals:
+            if self.check_freeze(video_path, start_time, duration):
+                logger.info(f"{video_path} at {start_time}s detected as frozen")
+                return True
+        return False
+
+    def get_video_duration(self, video_path: str) -> float:
+        """Get the duration of the video in seconds using ffprobe."""
+        try:
+            probe = self.ffmpeg.probe(video_path)
+            return float(probe["format"]["duration"])
+        except self.ffmpeg.Error as e:
+            logger.info(f"ffprobe {video_path}:")
+            logger.error(e.stderr.decode("utf-8"))
+            raise e
+
+    def check_freeze(self, video_path: str, start_time: str, duration: str) -> bool:
+        """Check for frozen frames in a specific interval using ffmpeg's freezedetect filter."""
+        try:
+            _, err = (
+                self.ffmpeg.input(video_path, ss=start_time, t=duration)
+                .filter("freezedetect", n=self.freeze_threshold, d=self.freeze_duration)
+                .output("null", f="null")
+                .run(capture_stdout=True, capture_stderr=True)
+            )
+            err = err.decode("utf-8")
+            # The interval counts as frozen if a freeze starts but never ends within it
+            return "freeze_start" in err and "freeze_end" not in err
+        except self.ffmpeg.Error as e:
+            logger.error(f"Error processing video {video_path}: {e}")
+            return False
diff --git a/src/datatrove/pipeline/readers/__init__.py b/src/datatrove/pipeline/readers/__init__.py
index 5f460e7d..41f198b2 100644
--- a/src/datatrove/pipeline/readers/__init__.py
+++ b/src/datatrove/pipeline/readers/__init__.py
@@ -3,4 +3,5 @@
 from .ipc import IpcReader
 from .jsonl import JsonlReader
 from .parquet import ParquetReader
+from .videotriplet import VideoTripletReader
 from .warc import WarcReader
diff --git a/src/datatrove/pipeline/readers/videotriplet.py b/src/datatrove/pipeline/readers/videotriplet.py
new file mode 100644
index 00000000..3e2b522f
--- /dev/null
+++ b/src/datatrove/pipeline/readers/videotriplet.py
@@ -0,0 +1,167 @@
+import json
+import os
+import warnings
+from typing import Callable, Dict, List
+
+from datatrove.data import Document, DocumentsPipeline, Media, MediaType
+from datatrove.io import DataFileLike, DataFolderLike, download_file
+from datatrove.pipeline.readers.base import BaseDiskReader
+
+
+class VideoTripletReader(BaseDiskReader):
+    """Read triplets of video, metadata, and optional caption files."""
+
+    name = "🎥 Video Triplet Reader"
+
+    def __init__(
+        self,
+        data_folder: DataFolderLike,
+        paths_file: DataFileLike | None = None,
+        metadata_origin: str | None = None,
+        limit: int = -1,
+        skip: int = 0,
+        file_progress: bool = False,
+        doc_progress: bool = False,
+        adapter: Callable = None,
+        text_key: str = "text",
+        id_key: str = "id",
+        default_metadata: dict = None,
+        recursive: bool = True,
+        local_cache_dir: str = "/tmp/local_video_cache",
+    ):
+        self.metadata_origin = metadata_origin
+        self.local_cache_dir = local_cache_dir
+        os.makedirs(self.local_cache_dir, exist_ok=True)
+        super().__init__(
+            data_folder,
+            paths_file,
+            limit,
+            skip,
+            file_progress,
+            doc_progress,
+            adapter,
+            text_key,
+            id_key,
+            default_metadata,
+            recursive,
+        )
+
+    def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline:
+        """Overrides the base run method so that triplets, not individual files, drive the statistics."""
+        if data:
+            yield from data
+        for triplet in self.find_triplets(rank, world_size):
+            document = self.process_triplet(triplet)
+            if document:
+                self.stat_update("documents")  # track the number of triplets processed
+                self.update_doc_stats(document)
+                yield document
+
+    def find_triplets(self, rank: int = 0, world_size: int = 1) -> List[Dict[str, str]]:
+        """Find triplets of video, metadata, and caption files in the data folder."""
+        triplets = []
+        video_extensions = (".mp4", ".avi", ".mkv", ".mov")
+        metadata_extension = ".json"
+        caption_extension = ".vtt"
+
+        if self.paths_file:
+            with self.data_folder.open(self.paths_file, "r") as f:
+                paths = [line.strip() for line in f]
+        else:
+            paths = self.data_folder.list_files(recursive=self.recursive)
+
+        for path in paths:
+            base_name, ext = os.path.splitext(path)
+            if ext.lower() in video_extensions:
+                video_file = path
+                metadata_file = base_name + metadata_extension
+                caption_file = base_name + caption_extension
+
+                # A matching metadata file is required; the caption file is optional
+                if self.data_folder.exists(metadata_file):
+                    triplet = {
+                        "video": video_file,
+                        "metadata": metadata_file,
+                        "caption": caption_file if self.data_folder.exists(caption_file) else None,
+                    }
+                    triplets.append(triplet)
+        return triplets[rank::world_size]
+
+    def read_file(self, filepath: str):
+        for triplet in self.find_triplets():
+            with self.track_time():
+                document = self.process_triplet(triplet)
+                if document:
+                    yield document
+
+    def process_triplet(self, triplet: Dict[str, str]) -> Document | None:
+        video_path = triplet["video"]
+        metadata_path = triplet["metadata"]
+        caption_path = triplet["caption"]
+        video_id = os.path.splitext(os.path.basename(video_path))[0]
+
+        # Resolve the correct URL and local paths
+        video_url = self.data_folder.resolve_paths(video_path)
+        video_local_path = self.ensure_local_copy(video_url)
+
+        # Load metadata and caption data
+        metadata = self.load_json(metadata_path)
+        video_media = Media(type=MediaType.VIDEO, url=video_url, local_path=video_local_path)
+        caption_text = self.load_caption(caption_path) if caption_path else ""
+
+        document = Document(
+            text=caption_text,
+            id=video_id,
+            media=[video_media],
+            metadata=metadata,
+        )
+
+        return document
+
+    def ensure_local_copy(self, video_url: str) -> str:
+        """Ensure that the video is available locally. If not, download it."""
+        if self.data_folder.is_local():
+            return video_url
+
+        local_path = os.path.join(self.local_cache_dir, os.path.basename(video_url))
+        if not os.path.exists(local_path):
+            download_file(video_url, local_path)
+        return local_path
+
+    def load_json(self, filepath: str) -> dict:
+        with self.data_folder.open(filepath, "r") as f:
+            data = json.load(f)
+
+        if self.metadata_origin == "youtube":
+            return self.process_youtube_metadata(data)
+        if self.metadata_origin is None:
+            warnings.warn("metadata_origin is not specified. Loading full JSON without processing.")
+        else:
+            warnings.warn(f"Unknown metadata_origin '{self.metadata_origin}'. Loading full JSON without processing.")
+        return data
+
+    def load_caption(self, filepath: str) -> str:
+        with self.data_folder.open(filepath, "r") as f:
+            return f.read()
+
+    def process_youtube_metadata(self, data: dict) -> dict:
+        """Keep only a curated subset of the yt-dlp style metadata fields."""
+        processed_metadata = {
+            "video_codec": data.get("vcodec"),
+            "audio_codec": data.get("acodec"),
+            "video_resolution": data.get("resolution"),
+            "duration": data.get("duration_string"),
+            "title": data.get("title"),
+            "description": data.get("description"),
+            "categories": data.get("categories"),
+            "tags": data.get("tags"),
+            "channel": data.get("channel"),
+            "view_count": data.get("view_count"),
+            "comment_count": data.get("comment_count"),
+            "like_count": data.get("like_count"),
+            "channel_follower_count": data.get("channel_follower_count"),
+            "upload_date": data.get("upload_date"),
+            "language": data.get("language"),
+            "age_limit": data.get("age_limit"),
+        }
+        return processed_metadata
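
The `exclusion_writer` mentioned in the filter's docstring is not exercised by `examples/media_experiment.py`; below is a minimal sketch of wiring it up so that dropped documents can be inspected afterwards. The `./videos` layout and the `./removed_frozen` output path are hypothetical placeholders.

```python
from datatrove.executor.local import LocalPipelineExecutor
from datatrove.pipeline.filters import VideoFrozenFilter
from datatrove.pipeline.readers import VideoTripletReader
from datatrove.pipeline.writers.jsonl import JsonlWriter

# Hypothetical layout: ./videos holds clip.mp4 + clip.json (+ optional clip.vtt) triplets
pipeline = [
    VideoTripletReader(data_folder="./videos", metadata_origin="youtube"),
    # Documents dropped as frozen are written out here for manual inspection
    VideoFrozenFilter(exclusion_writer=JsonlWriter("./removed_frozen")),
]

LocalPipelineExecutor(pipeline=pipeline, tasks=1, workers=1).run()
```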