diff --git a/configs/config_all.yaml b/configs/config_all.yaml
index 33298cbaf..47de328ad 100644
--- a/configs/config_all.yaml
+++ b/configs/config_all.yaml
@@ -40,6 +40,8 @@ executor_type: default            # type of executor,
 ray_address: auto                                # the address of the Ray cluster.
 
 # only for data analysis
+percentiles: [0.25, 0.5, 0.75]                   # percentiles to analyse the dataset distribution
+export_original_dataset: false                   # whether to export the original dataset with stats. If you only need the stats of the dataset, setting it to false could speed up the exporting.
 save_stats_in_one_file: false                    # whether to store all stats result into one file
 
 # for sandbox or hpo
@@ -478,7 +480,9 @@ process:
       ignore_pattern: null                       # whether to ignore sub-strings with specific pattern when computing simhash.
   - image_deduplicator:                          # deduplicator to deduplicate samples at document-level using exact matching of images between documents.
       method: phash                              # hash method for image. One of [phash, dhash, whash, ahash]
+      consider_text: false                       # whether to consider text hash together with image hash when applying deduplication.
   - video_deduplicator:                          # deduplicator to deduplicate samples at document-level using exact matching of videos between documents.
+      consider_text: false                       # whether to consider text hash together with video hash when applying deduplication.
   - ray_video_deduplicator:                      # the simple video deduplicator that can run on multi-nodes using md5 hashing exact matching method
       redis_host: 'redis_host'                   # the host of the redis instance
       redis_port: 6380                           # the port of redis instance, please note that the default port of redis is 6379 which is the same as default port for ray, so we need to modify the default redis config to use it in other port
diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py
index 1755ae920..d05b50c3d 100644
--- a/data_juicer/config/config.py
+++ b/data_juicer/config/config.py
@@ -301,6 +301,19 @@ def init_configs(args=None):
         type=List[Dict],
         help='List of several operators with their arguments, these ops will '
         'be applied to dataset in order')
+    parser.add_argument(
+        '--percentiles',
+        type=List[float],
+        default=[],
+        help='Percentiles to analyse the dataset distribution. Only used in '
+        'Analysis.')
+    parser.add_argument(
+        '--export_original_dataset',
+        type=bool,
+        default=False,
+        help='whether to export the original dataset with stats. If you only '
+        'need the stats of the dataset, setting it to false could speed '
+        'up the exporting..')
     parser.add_argument(
         '--save_stats_in_one_file',
         type=bool,
diff --git a/data_juicer/core/analyser.py b/data_juicer/core/analyser.py
index 04c48b33b..641471156 100644
--- a/data_juicer/core/analyser.py
+++ b/data_juicer/core/analyser.py
@@ -51,12 +51,14 @@ def __init__(self, cfg=None):
         # (export_ds=False). Instead, only need to export stats
         # (export_stats=True).
         logger.info('Preparing exporter...')
-        self.exporter = Exporter(self.cfg.export_path,
-                                 self.cfg.export_shard_size,
-                                 self.cfg.export_in_parallel,
-                                 self.cfg.np,
-                                 export_ds=False,
-                                 export_stats=True)
+        self.exporter = Exporter(
+            self.cfg.export_path,
+            self.cfg.export_shard_size,
+            self.cfg.export_in_parallel,
+            self.cfg.np,
+            export_ds=self.cfg.export_original_dataset,
+            keep_stats_in_res_ds=self.cfg.export_original_dataset,
+            export_stats=True)
 
         # parsed_res
         self.overall_result = None
@@ -121,8 +123,10 @@ def run(self, load_data_np=None, skip_export=False):
 
         logger.info('Applying overall analysis on stats...')
         overall_analysis = OverallAnalysis(dataset, self.analysis_path)
-        self.overall_result = overall_analysis.analyse(num_proc=self.cfg.np,
-                                                       skip_export=skip_export)
+        self.overall_result = overall_analysis.analyse(
+            percentiles=self.cfg.percentiles,
+            num_proc=self.cfg.np,
+            skip_export=skip_export)
 
         logger.info(f'The overall analysis results are: {self.overall_result}')
 
diff --git a/data_juicer/core/exporter.py b/data_juicer/core/exporter.py
index fe74aacc5..72b555d34 100644
--- a/data_juicer/core/exporter.py
+++ b/data_juicer/core/exporter.py
@@ -125,8 +125,11 @@ def _export_impl(self, dataset, export_path, suffix, export_stats=True):
                 dataset = dataset.remove_columns(removed_fields)
         if not self.keep_hashes_in_res_ds:
             extra_fields = {
-                HashKeys.hash, HashKeys.minhash, HashKeys.simhash,
-                HashKeys.imagehash
+                HashKeys.hash,
+                HashKeys.minhash,
+                HashKeys.simhash,
+                HashKeys.imagehash,
+                HashKeys.videohash,
             }
             feature_fields = set(dataset.features.keys())
             removed_fields = extra_fields.intersection(feature_fields)
diff --git a/data_juicer/ops/deduplicator/image_deduplicator.py b/data_juicer/ops/deduplicator/image_deduplicator.py
index 50ccc1014..d61e18cea 100644
--- a/data_juicer/ops/deduplicator/image_deduplicator.py
+++ b/data_juicer/ops/deduplicator/image_deduplicator.py
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import Dict, Set
+from typing import Dict, Set, Tuple
 
 import numpy as np
 
@@ -9,6 +9,7 @@
 
 from ..base_op import OPERATORS, Deduplicator
 from ..op_fusion import LOADED_IMAGES
+from .document_deduplicator import DocumentDeduplicator
 
 OP_NAME = 'image_deduplicator'
 
@@ -38,11 +39,17 @@ class ImageDeduplicator(Deduplicator):
     of images between documents.
     """
 
-    def __init__(self, method: str = 'phash', *args, **kwargs):
+    def __init__(self,
+                 method: str = 'phash',
+                 consider_text: bool = False,
+                 *args,
+                 **kwargs):
         """
         Initialization method.
 
         :param method: hash method for image
+        :param consider_text: whether to consider text hash together with image
+            hash when applying deduplication.
         :param args: extra args
         :param kwargs: extra args
         """
@@ -51,6 +58,10 @@ def __init__(self, method: str = 'phash', *args, **kwargs):
             raise ValueError(f'Keep strategy [{method}] is not supported. '
                              f'Can only be one of {HASH_METHOD}.')
         self.hasher = get_hash_method(method)()
+        self.consider_text = consider_text
+        self.text_dedup_op = None
+        if self.consider_text:
+            self.text_dedup_op = DocumentDeduplicator(**kwargs)
 
     def compute_hash(self, sample, context=False):
         # check if it's computed already
@@ -71,6 +82,8 @@ def compute_hash(self, sample, context=False):
         for key in images:
             sample[HashKeys.imagehash] += self.hasher.encode_image(
                 image_array=np.array(images[key]))
+        if self.consider_text:
+            sample = self.text_dedup_op.compute_hash(sample)
         return sample
 
     def process(self, dataset, show_num=0):
@@ -89,8 +102,14 @@ def process(self, dataset, show_num=0):
         dup_hashes = None
         if show_num > 0:
             # sample duplicate pairs
-            hash2ids: Dict[int, Set[int]] = defaultdict(set)
-            for sid, hash_val in enumerate(dataset[HashKeys.imagehash]):
+            if self.consider_text:
+                hash2ids: Dict[Tuple[int], Set[int]] = defaultdict(set)
+                hashes = zip(dataset[HashKeys.imagehash],
+                             dataset[HashKeys.hash])
+            else:
+                hash2ids: Dict[int, Set[int]] = defaultdict(set)
+                hashes = dataset[HashKeys.imagehash]
+            for sid, hash_val in enumerate(hashes):
                 if hash_val:
                     hash2ids[hash_val].add(sid)
             dup_samples = sorted(list(hash2ids.items()),
@@ -101,7 +120,10 @@
             ][:show_num])
 
         def _filter_dup_helper(sample, hashes):
-            hash = sample[HashKeys.imagehash]
+            if self.consider_text:
+                hash = (sample[HashKeys.imagehash], sample[HashKeys.hash])
+            else:
+                hash = sample[HashKeys.imagehash]
             if not hash:
                 return True
             if show_num > 0 and hash in dup_hashes \
diff --git a/data_juicer/ops/deduplicator/video_deduplicator.py b/data_juicer/ops/deduplicator/video_deduplicator.py
index 205e5a8d3..8073fec2e 100644
--- a/data_juicer/ops/deduplicator/video_deduplicator.py
+++ b/data_juicer/ops/deduplicator/video_deduplicator.py
@@ -1,12 +1,13 @@
 import hashlib
 from collections import defaultdict
-from typing import Dict, Set
+from typing import Dict, Set, Tuple
 
 from data_juicer.utils.constant import HashKeys
 from data_juicer.utils.mm_utils import load_data_with_context, load_video
 
 from ..base_op import OPERATORS, Deduplicator
 from ..op_fusion import LOADED_VIDEOS
+from .document_deduplicator import DocumentDeduplicator
 
 OP_NAME = 'video_deduplicator'
 
@@ -19,14 +20,20 @@ class VideoDeduplicator(Deduplicator):
     of videos between documents.
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, consider_text: bool = False, *args, **kwargs):
         """
         Initialization.
 
+        :param consider_text: whether to consider text hash together with video
+            hash when applying deduplication.
         :param args: extra args
         :param kwargs: extra args
         """
         super().__init__(*args, **kwargs)
+        self.consider_text = consider_text
+        self.text_dedup_op = None
+        if self.consider_text:
+            self.text_dedup_op = DocumentDeduplicator(**kwargs)
 
     def compute_hash(self, sample, context=False):
         # check if it's computed already
@@ -52,6 +59,8 @@ def compute_hash(self, sample, context=False):
                     md5_hash.update(bytes(packet))
 
         sample[HashKeys.videohash] = md5_hash.hexdigest()
+        if self.consider_text:
+            sample = self.text_dedup_op.compute_hash(sample)
         return sample
 
     def process(self, dataset, show_num=0):
@@ -70,8 +79,14 @@ def process(self, dataset, show_num=0):
         dup_hashes = None
         if show_num > 0:
             # sample duplicate pairs
-            hash2ids: Dict[int, Set[int]] = defaultdict(set)
-            for sid, hash_val in enumerate(dataset[HashKeys.videohash]):
+            if self.consider_text:
+                hash2ids: Dict[Tuple[int], Set[int]] = defaultdict(set)
+                hashes = zip(dataset[HashKeys.videohash],
+                             dataset[HashKeys.hash])
+            else:
+                hash2ids: Dict[int, Set[int]] = defaultdict(set)
+                hashes = dataset[HashKeys.videohash]
+            for sid, hash_val in enumerate(hashes):
                 if hash_val:
                     hash2ids[hash_val].add(sid)
             dup_samples = sorted(list(hash2ids.items()),
@@ -82,7 +97,10 @@
             ][:show_num])
 
         def _filter_dup_helper(sample, hashes):
-            hash = sample[HashKeys.videohash]
+            if self.consider_text:
+                hash = (sample[HashKeys.videohash], sample[HashKeys.hash])
+            else:
+                hash = sample[HashKeys.videohash]
             if not hash:
                 return True
             if show_num > 0 and hash in dup_hashes \
diff --git a/tests/ops/deduplicator/test_image_deduplicator.py b/tests/ops/deduplicator/test_image_deduplicator.py
index 0c5752b4b..53c85758d 100644
--- a/tests/ops/deduplicator/test_image_deduplicator.py
+++ b/tests/ops/deduplicator/test_image_deduplicator.py
@@ -32,10 +32,12 @@ class ImageDeduplicatorTest(DataJuicerTestCaseBase):
         os.symlink(img6_path, img7_path)
 
     def _run_image_deduplicator(self, dataset: Dataset, target_list, op):
+        key_list = [op.image_key, op.text_key] \
+            if op.consider_text else [op.image_key]
 
         dataset = dataset.map(op.compute_hash)
         dataset, _ = op.process(dataset)
-        dataset = dataset.select_columns(column_names=[op.image_key])
+        dataset = dataset.select_columns(column_names=key_list)
         res_list = dataset.to_list()
         self.assertEqual(res_list, target_list)
 
@@ -101,6 +103,50 @@ def test_3(self):
 
         op = ImageDeduplicator()
         self._run_image_deduplicator(dataset, tgt_list, op)
 
+    def test_3_consider_text(self):
+
+        ds_list = [{
+            'images': [self.img1_path],
+            'text': '
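
Below is a minimal usage sketch (not part of the patch) of the new consider_text flag introduced by this change. The image paths and captions are hypothetical placeholders and must point to real image files for the hashes to be computed:

    from datasets import Dataset

    from data_juicer.ops.deduplicator.image_deduplicator import ImageDeduplicator

    # Two samples share the same image but carry different text. With the default
    # consider_text=False they share an image hash and one is dropped by process();
    # with consider_text=True the differing text hashes keep both samples.
    ds = Dataset.from_list([
        {'images': ['path/to/img1.jpg'], 'text': 'caption A'},
        {'images': ['path/to/img1.jpg'], 'text': 'caption B'},
    ])

    op = ImageDeduplicator(method='phash', consider_text=True)
    ds = ds.map(op.compute_hash)
    ds, _ = op.process(ds)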