diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 79fc30343..e7c1f95f0 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -148,11 +148,14 @@ process: frame_num: 3 # the number of frames to be extracted uniformly from the video. Only works when frame_sampling_method is "uniform". If it's 1, only the middle frame will be extracted. If it's 2, only the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration. horizontal_flip: false # flip frame image horizontally (left to right). vertical_flip: false # flip frame image vertically (top to bottom). - - video_split_by_scene_mapper: # split videos into scene clips - detector: 'ContentDetector' # PySceneDetect scene detector. Should be one of ['ContentDetector', 'ThresholdDetector', 'AdaptiveDetector`] - threshold: 27.0 # threshold passed to the detector - min_scene_len: 15 # minimum length of any scene - show_progress: false # whether to show progress from scenedetect + - video_ffmpeg_wrapped_mapper: # simple wrapper for FFmpeg video filters + - video_remove_watermark_mapper: # Remove the watermarks in videos given regions + roi_strings: ['0,0,0.1,0.1'] # a given list of regions the watermarks locate. The format of each can be "x1, y1, x2, y2", "(x1, y1, x2, y2)", or "[x1, y1, x2, y2]". + roi_type: ratio # the roi string type. When the type is 'pixel', (x1, y1), (x2, y2) are the locations of pixels in the top left corner and the bottom right corner respectively. If the roi_type is 'ratio', the coordinates are normalized by wights and heights. + roi_key: null # the key name of fields in samples to store roi_strings for each sample. It's used for set different rois for different samples. + frame_num: 10 # the number of frames to be extracted uniformly from the video to detect the pixels of watermark. + min_frame_threshold: 7 # a coodination is considered as the location of a watermark pixel when it is a watermark pixel in no less min_frame_threshold frames. + detection_method: pixel_value # the method to detect the pixels of watermark. If it is 'pixel_value', we consider the distribution of pixel value in each frame. If it is 'pixel_diversity', we will consider the pixel diversity in different frames. - video_resize_aspect_ratio_mapper: # resize videos aspect ratios of videos (a fraction of width by height, r=w/h) to a specified range min_ratio: 9/21 # the minimum aspect ratio to enforce videos with an aspect ratio below `min_ratio` will be resized to match this minimum ratio. The ratio should be provided as a string in the format "9:21" or "9/21". max_ratio: 21/9 # the maximum aspect ratio to enforce videos with an aspect ratio above `max_ratio` will be resized to match this maximum ratio. The ratio should be provided as a string in the format "21:9" or "21/9". @@ -164,13 +167,17 @@ process: max_height: 1080 # the max vertical resolution (unit p), videos with height more than 'max_height' will be mapped to videos with equal or smaller height force_original_aspect_ratio: 'increase' # Enable decreasing or increasing output video width or height if necessary to keep the original aspect ratio force_divisible_by: 4 # Ensures that both the output dimensions, width and height, are divisible by the given integer when used together with force_original_aspect_ratio - - video_ffmpeg_wrapped_mapper: # simple wrapper for FFmpeg video filters - video_split_by_duration_mapper: # Mapper to split video by duration. split_duration: 10 # duration of each video split in seconds. min_last_split_duration: 0.1 # the minimum allowable duration in seconds for the last video split. If the duration of the last split is less than this value, it will be discarded. keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only cut sample in the final datasets and the original sample will be removed. It's True in default - video_split_by_key_frame_mapper: # Mapper to split video by key frame. keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only cut sample in the final datasets and the original sample will be removed. It's True in default + - video_split_by_scene_mapper: # split videos into scene clips + detector: 'ContentDetector' # PySceneDetect scene detector. Should be one of ['ContentDetector', 'ThresholdDetector', 'AdaptiveDetector`] + threshold: 27.0 # threshold passed to the detector + min_scene_len: 15 # minimum length of any scene + show_progress: false # whether to show progress from scenedetect - video_tagging_from_audio_mapper: # Mapper to generate video tags from audio streams extracted from the video. hf_ast: 'MIT/ast-finetuned-audioset-10-10-0.4593' # Huggingface model name for the audio classification model. - video_tagging_from_frames_mapper: # Mapper to generate video tags from frames extracted from the video. diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 1c5cc2855..4f6377a34 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -14,7 +14,7 @@ replace_content_mapper, sentence_split_mapper, video_captioning_from_audio_mapper, video_captioning_from_video_mapper, video_ffmpeg_wrapped_mapper, - video_resize_aspect_ratio_mapper, + video_remove_watermark_mapper, video_resize_aspect_ratio_mapper, video_resize_resolution_mapper, video_split_by_duration_mapper, video_split_by_key_frame_mapper, video_split_by_scene_mapper, video_tagging_from_audio_mapper, diff --git a/data_juicer/ops/mapper/video_remove_watermark_mapper.py b/data_juicer/ops/mapper/video_remove_watermark_mapper.py new file mode 100644 index 000000000..068eb6507 --- /dev/null +++ b/data_juicer/ops/mapper/video_remove_watermark_mapper.py @@ -0,0 +1,232 @@ +import os + +import av +import numpy as np +from jsonargparse.typing import List, PositiveInt + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.file_utils import transfer_filename +from data_juicer.utils.logger_utils import HiddenPrints +from data_juicer.utils.mm_utils import (extract_video_frames_uniformly, + load_data_with_context, load_video, + parse_string_to_roi, + process_each_frame) + +from ..base_op import OPERATORS, Mapper +from ..op_fusion import LOADED_VIDEOS + +OP_NAME = 'video_remove_watermark_mapper' + +with AvailabilityChecking(['opencv-python'], OP_NAME), HiddenPrints(): + import cv2 as cv + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoRemoveWatermarkMapper(Mapper): + """ + Remove the watermarks in videos given regions. + """ + + def __init__(self, + roi_strings: List[str] = ['0,0,0.1,0.1'], + roi_type: str = 'ratio', + roi_key: str = None, + frame_num: PositiveInt = 10, + min_frame_threshold: PositiveInt = 7, + detection_method: str = 'pixel_value', + threshold: int = None, + *args, + **kwargs): + """ + Initialization method. + + :param roi_strings: a given list of regions the watermarks locate. + The format of each can be "x1, y1, x2, y2", "(x1, y1, x2, y2)", + or "[x1, y1, x2, y2]". + :param roi_type: the roi string type. When the type is 'pixel', (x1, + y1), (x2, y2) are the locations of pixels in the top left corner + and the bottom right corner respectively. If the roi_type is + 'ratio', the coordinates are normalized by wights and heights. + :param roi_key: the key name of fields in samples to store roi_strings + for each sample. It's used for set different rois for different + samples. If it's none, use rois in parameter "roi_strings". + It's None in default. + :param frame_num: the number of frames to be extracted uniformly from + the video to detect the pixels of watermark. + :param min_frame_threshold: a coodination is considered as the + location of a watermark pixel when it is that in no less + min_frame_threshold frames. + :param detection_method: the method to detect the pixels of watermark. + If it is 'pixel_value', we consider the distribution of pixel + value in each frame. If it is 'pixel_diversity', we will consider + the pixel diversity in different frames. The min_frame_threshold + is useless and frame_num must be greater than 1 in + 'pixel_diversity' mode. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) + + if roi_type not in ['ratio', 'pixel']: + raise ValueError(f'roi_type [{roi_type}]' + f' is not supported. ' + f"Can only be one of ['ratio', 'pixel']. ") + + if detection_method not in ['pixel_value', 'pixel_diversity']: + raise ValueError( + f'etection_method [{detection_method}]' + f' is not supported. ' + f"Can only be one of ['pixel_value', 'pixel_diversity']. ") + + if detection_method == 'pixel_diversity' and frame_num < 2: + raise ValueError( + "frame_num must be gteater than 1 in 'pixel_diversity' mode.") + + rois = [] + if roi_key is None: + for roi_string in roi_strings: + roi = parse_string_to_roi(roi_string, roi_type) + if roi is None: + raise ValueError( + 'The roi in roi_strings must be four no negative' + ' numbers in the format of "x1, y1, x2, y2", ' + '"(x1, y1, x2, y2)", or "[x1, y1, x2, y2]".') + rois.append(roi) + + self.roi_type = roi_type + self.rois = rois + self.roi_key = roi_key + self.frame_num = frame_num + self.min_frame_threshold = min_frame_threshold + self.detection_method = detection_method + + def _detect_watermark_via_pixel_value(self, frames, rois): + + masks = [] + for frame in frames: + frame = frame.to_ndarray(format='bgr24') + mask = np.zeros_like(frame[:, :, 0], dtype=np.uint8) + for roi in rois: + # dimension of ndarray frame: height x width x channel + roi_frame = frame[roi[1]:roi[3], roi[0]:roi[2]] + gray_frame = cv.cvtColor(roi_frame, cv.COLOR_BGR2GRAY) + _, binary_frame = cv.threshold( + gray_frame, 0, 255, cv.THRESH_BINARY + cv.THRESH_OTSU) + + # assume the watermark is located in the box, so the pixel in + # the edge must be 0, if not, reverse binary_frame + edge_postive_num = (binary_frame[0] > + 0).sum() + (binary_frame[:, 0] > 0).sum() + total = binary_frame.shape[0] + binary_frame.shape[1] + if edge_postive_num * 2 > total: + binary_frame = ~binary_frame + + mask[roi[1]:roi[3], + roi[0]:roi[2]] = mask[roi[1]:roi[3], + roi[0]:roi[2]] | binary_frame + masks.append(mask) + final_mask = sum((mask == 255).astype(np.uint8) for mask in masks) + final_mask = np.where(final_mask >= self.min_frame_threshold, 255, 0) + final_mask = final_mask.astype(np.uint8) + return final_mask + + def _detect_watermark_via_pixel_diversity(self, frames, rois): + + mask = np.zeros((frames[0].height, frames[0].width), dtype=np.uint8) + frames = [frame.to_ndarray(format='bgr24') for frame in frames] + + for roi in rois: + roi_frames = [ + frame[roi[1]:roi[3], roi[0]:roi[2]] for frame in frames + ] + roi_frames = np.stack(roi_frames, axis=0) + pixel_diversity = roi_frames.std(axis=0) + pixel_diversity = pixel_diversity.sum(-1) + max_diversity = np.max(pixel_diversity) + min_diversity = np.min(pixel_diversity) + if max_diversity > min_diversity: + scaled_diversity = 255 * (pixel_diversity - min_diversity) / ( + max_diversity - min_diversity) + else: + scaled_diversity = np.zeros_like(pixel_diversity) + scaled_diversity = scaled_diversity.astype(np.uint8) + _, binary_frame = cv.threshold(scaled_diversity, 0, 255, + cv.THRESH_BINARY + cv.THRESH_OTSU) + # the watermark pixels have less diversity + binary_frame = ~binary_frame + mask[roi[1]:roi[3], + roi[0]:roi[2]] = mask[roi[1]:roi[3], + roi[0]:roi[2]] | binary_frame + + return mask + + def _generate_watermark_mask(self, video, sample): + frames = extract_video_frames_uniformly(video, self.frame_num) + + if self.roi_key is not None: + roi_strings = sample[self.roi_key] + if isinstance(roi_strings, str): + roi_strings = [roi_strings] + rois = [ + parse_string_to_roi(roi_string, self.roi_type) + for roi_string in roi_strings + ] + rois = [roi for roi in rois if roi is not None] + else: + rois = self.rois + if self.roi_type == 'ratio': + rois = [ + tuple([ + int(roi[0] * frames[0].width), + int(roi[1] * frames[0].height), + int(roi[2] * frames[0].width), + int(roi[3] * frames[0].height) + ]) for roi in self.rois + ] + + if self.detection_method == 'pixel_value': + mask = self._detect_watermark_via_pixel_value(frames, rois) + else: + mask = self._detect_watermark_via_pixel_diversity(frames, rois) + + kernel = np.ones((5, 5), np.uint8) + return cv.dilate(mask, kernel) + + def _clean_watermark(self, frame, watermark_mask): + np_frame = frame.to_ndarray(format='bgr24') + new_np_frame = cv.inpaint(np_frame, watermark_mask, 3, cv.INPAINT_NS) + return av.VideoFrame.from_ndarray(new_np_frame, format='bgr24') + + def process(self, sample, context=False): + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + return sample + + loaded_video_keys = sample[self.video_key] + sample, videos = load_data_with_context(sample, context, + loaded_video_keys, load_video) + + for index, video_key in enumerate(loaded_video_keys): + video = videos[video_key] + cleaned_video_key = transfer_filename(video_key, OP_NAME, + **self._init_parameters) + + if (not os.path.exists(cleaned_video_key) + or cleaned_video_key not in loaded_video_keys): + watermark_mask = self._generate_watermark_mask(video, sample) + + def process_frame_func(frame): + return self._clean_watermark(frame, watermark_mask) + + process_each_frame(video, cleaned_video_key, + process_frame_func) + + loaded_video_keys[index] = cleaned_video_key + + if not context: + video.close() + + sample[self.video_key] = loaded_video_keys + return sample diff --git a/data_juicer/ops/mapper/video_split_by_duration_mapper.py b/data_juicer/ops/mapper/video_split_by_duration_mapper.py index 053b3835e..ee7ff83b8 100644 --- a/data_juicer/ops/mapper/video_split_by_duration_mapper.py +++ b/data_juicer/ops/mapper/video_split_by_duration_mapper.py @@ -85,9 +85,8 @@ def split_videos_by_duration(self, video_key, container): def _process_single_sample(self, sample): # there is no video in this sample - if self.video_key not in sample \ - or sample[self.video_key] is None \ - or len(sample[self.video_key]) == 0: + if self.video_key not in sample or sample[ + self.video_key] is None or len(sample[self.video_key]) == 0: return [] # the split results diff --git a/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py b/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py index a91d1ead4..735dc6493 100644 --- a/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py +++ b/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py @@ -69,9 +69,8 @@ def get_split_key_frame(self, video_key, container): def _process_single_sample(self, sample): # there is no video in this sample - if self.video_key not in sample \ - or sample[self.video_key] is None \ - or len(sample[self.video_key]) == 0: + if self.video_key not in sample or sample[ + self.video_key] is None or len(sample[self.video_key]) == 0: return [] # the split results diff --git a/data_juicer/utils/mm_utils.py b/data_juicer/utils/mm_utils.py index 739e94e91..093a2e235 100644 --- a/data_juicer/utils/mm_utils.py +++ b/data_juicer/utils/mm_utils.py @@ -324,6 +324,62 @@ def cut_video_by_seconds( return os.path.exists(output_video) +def process_each_frame(input_video: Union[str, av.container.InputContainer], + output_video: str, frame_func): + """ + Process each frame in video by replacing each frame by + `frame_func(frame)`. + + :param input_video: the path to input video or the video container. + :param output_video: the path to output video. + :param frame_func: a function which inputs a frame and outputs another + frame. + """ + # open the original video + if isinstance(input_video, str): + container = av.open(input_video) + else: + container = input_video + + # create the output video + output_container = av.open(output_video, 'w') + + # add the audio stream into the output video with template of input audio + for input_audio_stream in container.streams.audio: + output_container.add_stream(template=input_audio_stream) + + # add the video stream into the output video according to input video + for input_video_stream in container.streams.video: + # search from the beginning + container.seek(0, backward=False, any_frame=True) + + codec_name = input_video_stream.codec_context.name + fps = input_video_stream.base_rate + output_video_stream = output_container.add_stream(codec_name, + rate=str(fps)) + output_video_stream.pix_fmt = input_video_stream.codec_context.pix_fmt + output_video_stream.width = input_video_stream.codec_context.width + output_video_stream.height = input_video_stream.codec_context.height + + for packet in container.demux(input_video_stream): + for frame in packet.decode(): + new_frame = frame_func(frame) + # for resize cases + output_video_stream.width = new_frame.width + output_video_stream.height = new_frame.height + for inter_packet in output_video_stream.encode(new_frame): + output_container.mux(inter_packet) + + # flush all packets + for packet in output_video_stream.encode(): + output_container.mux(packet) + + # close the output videos + if isinstance(input_video, str): + container.close() + output_container.close() + + def extract_key_frames(input_video: Union[str, av.container.InputContainer]): """ Extract key frames from the input video. @@ -670,3 +726,39 @@ def timecode_string_to_seconds(timecode: str): # compute the start/end time in second pts = dt.hour * 3600 + dt.minute * 60 + dt.second + dt.microsecond / 1e6 return pts + + +def parse_string_to_roi(roi_string, roi_type='pixel'): + """ + Convert a roi string to four number x1, y1, x2, y2 stand for the region. + When the type is 'pixel', (x1, y1), (x2, y2) are the locations of pixels + in the top left corner and the bottom right corner respectively. If the + roi_type is 'ratio', the coordinates are normalized by wights and + heights. + + :param roi_string: the roi string + :patam roi_type: the roi string type + return tuple of (x1, y1, x2, y2) if roi_string is valid, else None + """ + if not roi_string: + return None + + pattern = r'^\s*[\[\(]?\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*[\]\)]?\s*$' # noqa: E501 + + match = re.match(pattern, roi_string) + + if match: + if roi_type == 'pixel': + return tuple(int(num) for num in match.groups()) + elif roi_type == 'ratio': + return tuple(min(1.0, float(num)) for num in match.groups()) + else: + logger.warning('The roi_type must be "pixel" or "ratio".') + return None + else: + logger.warning( + 'The roi_string must be four no negative numbers in the ' + 'format of "x1, y1, x2, y2", "(x1, y1, x2, y2)", or ' + '"[x1, y1, x2, y2]".') + return None + return None diff --git a/docs/DeveloperGuide.md b/docs/DeveloperGuide.md index bf248aa82..dbb1cec99 100644 --- a/docs/DeveloperGuide.md +++ b/docs/DeveloperGuide.md @@ -140,10 +140,10 @@ class StatsKeys(object): # ... (same as above) self._batched_op = True - def compute_stats(self, sample, rank=None): + def compute_stats(self, sample): # ... (same as above) - def process(self, sample, rank=None): + def process(self, sample): # ... (same as above) ``` @@ -161,7 +161,7 @@ class StatsKeys(object): super().__init__(*args, **kwargs) self._init_parameters = self.remove_extra_parameters(locals()) - def process(self, sample, rank=None): + def process(self, sample): # ... (some codes) # captions[index] is the prompt for diffusion model related_parameters = self.add_parameters( @@ -184,7 +184,7 @@ class StatsKeys(object): super().__init__(*args, **kwargs) self._init_parameters = self.remove_extra_parameters(locals()) - def process(self, sample, rank=None): + def process(self, sample): # ... (some codes) split_video_path = transfer_filename( original_video_path, OP_NAME, **self._init_parameters) diff --git a/docs/DeveloperGuide_ZH.md b/docs/DeveloperGuide_ZH.md index 3c6bb2411..917e6d6b2 100644 --- a/docs/DeveloperGuide_ZH.md +++ b/docs/DeveloperGuide_ZH.md @@ -135,10 +135,10 @@ class StatsKeys(object): # ... (same as above) self._batched_op = True - def compute_stats(self, sample, rank=None): + def compute_stats(self, sample): # ... (same as above) - def process(self, sample, rank=None): + def process(self, sample): # ... (same as above) ``` @@ -156,7 +156,7 @@ class StatsKeys(object): super().__init__(*args, **kwargs) self._init_parameters = self.remove_extra_parameters(locals()) - def process(self, sample, rank=None): + def process(self, sample): # ... (some codes) # captions[index] is the prompt for diffusion model related_parameters = self.add_parameters( @@ -179,7 +179,7 @@ class StatsKeys(object): super().__init__(*args, **kwargs) self._init_parameters = self.remove_extra_parameters(locals()) - def process(self, sample, rank=None): + def process(self, sample): # ... (some codes) split_video_path = transfer_filename( original_video_path, OP_NAME, **self._init_parameters) diff --git a/docs/Operators.md b/docs/Operators.md index 9409449b3..815cc559b 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. | Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 38 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 39 | Edits and transforms samples | | [ Filter ]( #filter ) | 36 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 5 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 2 | Selects top samples based on ranking | @@ -80,6 +80,7 @@ All the specific operators are listed below, each featured with several capabili | video_captioning_from_audio_mapper | Multimodal | - | Caption a video according to its audio streams based on Qwen-Audio model | | video_captioning_from_video_mapper | Multimodal | - | generate samples whose captions are generated based on another model (video-blip) and sampled video frame within the original sample | | video_ffmpeg_wrapped_mapper | Video | - | Simple wrapper to run a FFmpeg video filter | +| video_remove_watermark_mapper | Video | - | Remove the watermarks in videos given regions | | video_resize_aspect_ratio_mapper | Video | - | Resize video aspect ratio to a specified range | | video_resize_resolution_mapper | Video | - | Map videos to ones with given resolution range | | video_split_by_duration_mapper | Multimodal | - | Mapper to split video by duration. | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 4517c614c..9defed7e3 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 38 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 39 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 36 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 5 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 2 | 基于排序选取高质量样本 | @@ -79,6 +79,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | video_captioning_from_audio_mapper | Multimodal | - | 基于 Qwen-Audio 模型根据视频的音频流为视频生成新的标题描述 | | video_captioning_from_video_mapper | Multimodal | - | 生成样本,其标题是根据另一个辅助模型(video-blip)和原始样本中的视频中指定帧的图像。 | | video_ffmpeg_wrapped_mapper | Video | - | 运行 FFmpeg 视频过滤器的简单封装 | +| video_remove_watermark_mapper | Video | - | 去除视频中给定区域的水印 | | video_resize_aspect_ratio_mapper | Video | - | 将视频的宽高比调整到指定范围内 | | video_resize_resolution_mapper | Video | - | 将视频映射到给定的分辨率区间 | | video_split_by_duration_mapper | Multimodal | - | 根据时长将视频切分为多个片段 | diff --git a/docs/sphinx_doc/source/data_juicer.ops.mapper.rst b/docs/sphinx_doc/source/data_juicer.ops.mapper.rst index 20e512f8f..18a96e777 100644 --- a/docs/sphinx_doc/source/data_juicer.ops.mapper.rst +++ b/docs/sphinx_doc/source/data_juicer.ops.mapper.rst @@ -258,6 +258,14 @@ data\_juicer.ops.mapper.video\_ffmpeg\_wrapped\_mapper :undoc-members: :show-inheritance: +data\_juicer.ops.mapper.video\_remove\_watermark\_mapper +------------------------------------------------------------- + +.. automodule:: data_juicer.ops.mapper.video_remove_watermark_mapper + :members: + :undoc-members: + :show-inheritance: + data\_juicer.ops.mapper.video\_resize\_aspect\_ratio\_mapper ------------------------------------------------------------------- diff --git a/environments/science_requires.txt b/environments/science_requires.txt index 4421aad0d..4faa330c8 100644 --- a/environments/science_requires.txt +++ b/environments/science_requires.txt @@ -24,3 +24,4 @@ diffusers simple-aesthetics-predictor scenedetect[opencv] ffmpeg-python +opencv-python diff --git a/tests/ops/mapper/test_video_remove_watermark_mapper.py b/tests/ops/mapper/test_video_remove_watermark_mapper.py new file mode 100644 index 000000000..4e5463ac8 --- /dev/null +++ b/tests/ops/mapper/test_video_remove_watermark_mapper.py @@ -0,0 +1,108 @@ +import os +import shutil +import unittest + +from datasets import Dataset + +from data_juicer.ops.mapper.video_remove_watermark_mapper import \ + VideoRemoveWatermarkMapper +from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) + + +@SKIPPED_TESTS.register_module() +class VideoRemoveWatermarkMapperTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + # video1: horizontal resolution 640p, vertical resolution 360p + vid1_path = os.path.join(data_path, 'video1.mp4') + + def _run_video_remove_watermask_mapper(self, + dataset: Dataset, + op, + test_name, + np=1): + if Fields.stats not in dataset.features: + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.process, num_proc=np) + dataset = dataset.select_columns(column_names=[op.video_key]) + + # check each video personally + output_dir = '../video_remove_watermark_mapper' + move_to_dir = os.path.join(output_dir, test_name) + if not os.path.exists(move_to_dir): + os.makedirs(move_to_dir) + for sample in dataset.to_list(): + for value in sample['videos']: + move_to_path = os.path.join(move_to_dir, + os.path.basename(value)) + shutil.copyfile(value, move_to_path) + + def test_roi_pixel_type(self): + ds_list = [{'videos': [self.vid1_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoRemoveWatermarkMapper(roi_strings=['[0, 0, 150, 60]'], + roi_type='pixel') + self._run_video_remove_watermask_mapper(dataset, op, + 'test_roi_pixel_type') + + def test_multi_roi_region(self): + ds_list = [{'videos': [self.vid1_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoRemoveWatermarkMapper( + roi_strings=['[0, 0, 150, 60]', '[30, 60, 75, 140]'], + roi_type='pixel') + self._run_video_remove_watermask_mapper(dataset, op, + 'test_multi_roi_region') + + def test_roi_ratio_type(self): + ds_list = [{'videos': [self.vid1_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoRemoveWatermarkMapper( + roi_strings=['[0, 0, 0.234375, 0.16667]'], roi_type='ratio') + self._run_video_remove_watermask_mapper(dataset, op, + 'test_roi_ratio_type') + + def test_frame_num_and_frame_threshold(self): + ds_list = [{'videos': [self.vid1_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoRemoveWatermarkMapper(roi_strings=['[0, 0, 150, 60]'], + roi_type='pixel', + frame_num=100, + min_frame_threshold=100) + self._run_video_remove_watermask_mapper( + dataset, op, 'test_frame_num_and_frame_threshold') + + def test_roi_key(self): + ds_list = [{ + 'videos': [self.vid1_path], + 'roi_strings': ['[30, 60, 75, 300]'], + }, { + 'videos': [self.vid1_path], + 'roi_strings': ['[30, 60, 75, 140]', '30, 140, 53, 300', 'none'], + }, { + 'videos': [self.vid1_path], + 'roi_strings': + ['[30, 60, 75, 140]', '30, 140, 53, 200', '(30, 200, 53, 300)'], + }] + dataset = Dataset.from_list(ds_list) + op = VideoRemoveWatermarkMapper(roi_type='pixel', + roi_key='roi_strings') + self._run_video_remove_watermask_mapper(dataset, op, 'test_roi_key') + + def test_detection_method(self): + ds_list = [{'videos': [self.vid1_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoRemoveWatermarkMapper( + roi_strings=['[0, 0, 150, 60]', '[30, 60, 75, 300]'], + roi_type='pixel', + detection_method='pixel_diversity') + self._run_video_remove_watermask_mapper(dataset, op, + 'test_detection_method') + + +if __name__ == '__main__': + unittest.main()