diff --git a/data_juicer/__init__.py b/data_juicer/__init__.py
index 8ce9b3623..e40cf396b 100644
--- a/data_juicer/__init__.py
+++ b/data_juicer/__init__.py
@@ -1 +1,78 @@
 __version__ = '0.1.3'
+
+import os
+import subprocess
+
+import multiprocess as mp
+from loguru import logger
+
+
+def _cuda_device_count():
+    try:
+        nvidia_smi_output = subprocess.check_output(['nvidia-smi', '-L'],
+                                                    text=True)
+        all_devices = nvidia_smi_output.strip().split('\n')
+
+        cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
+
+        if cuda_visible_devices:
+            visible_devices = cuda_visible_devices.split(',')
+            visible_devices = [int(dev.strip()) for dev in visible_devices]
+            num_visible_devices = sum(1 for dev in visible_devices
+                                      if 0 <= dev < len(all_devices))
+        else:
+            num_visible_devices = len(all_devices)
+
+        return num_visible_devices
+    except Exception:
+        # nvidia-smi not found or other error
+        return 0
+
+
+_USE_CUDA = False
+_CUDA_COUNT = _cuda_device_count()
+
+
+def use_cuda():
+    return _USE_CUDA
+
+
+def cuda_device_count():
+    return _CUDA_COUNT
+
+
+def setup_mp():
+    method = os.getenv('MP_START_METHOD', 'auto').lower()
+    if method == 'auto':
+        if _CUDA_COUNT > 0:
+            # forkserver is more lightweight
+            method = ('forkserver' if 'forkserver'
+                      in mp.get_all_start_methods() else 'spawn')
+        else:
+            method = 'fork'
+    try:
+        logger.info(f"Setting multiprocess start method to '{method}'.")
+        mp.set_start_method(method, force=True)
+    except RuntimeError as e:
+        logger.warning(f'Error setting multiprocess start method: {e}')
+
+
+def setup_cuda():
+    global _USE_CUDA
+
+    method = mp.get_start_method()
+    if method != 'fork' and _CUDA_COUNT > 0:
+        _USE_CUDA = True
+    else:
+        _USE_CUDA = False
+    logger.debug(f'_USE_CUDA: {_USE_CUDA} | MP: {method} '
+                 f'({mp.current_process().name})')
+
+
+def initialize():
+    if mp.current_process().name == 'MainProcess':
+        setup_mp()
+        setup_cuda()
+
+
+initialize()
diff --git a/data_juicer/core/executor.py b/data_juicer/core/executor.py
index 93d4cb30b..17501d441 100644
--- a/data_juicer/core/executor.py
+++ b/data_juicer/core/executor.py
@@ -3,6 +3,7 @@
 
 from loguru import logger
 
+from data_juicer import cuda_device_count, use_cuda
 from data_juicer.config import init_configs
 from data_juicer.format.load import load_formatter
 from data_juicer.ops import (OPERATORS, Deduplicator, Filter, Mapper, Selector,
@@ -115,10 +116,17 @@ def run(self, load_data_np=None):
         for op_cfg, op in zip(self.process_list, self.ops):
             op_name, op_args = list(op_cfg.items())[0]
             prev = dataset  # record last dataset
+            if use_cuda() and op._accelerator == 'cuda':
+                op_proc = min(cuda_device_count(), self.cfg.np)
+                with_rank = True
+            else:
+                op_proc = self.cfg.np
+                with_rank = False
             try:
                 if isinstance(op, Mapper):
                     tmp = dataset.map(function=op.process,
-                                      num_proc=self.cfg.np,
+                                      num_proc=op_proc,
+                                      with_rank=with_rank,
                                       desc=op_name + '_process')
                     if self.open_tracer and \
                             op_name in self.op_list_to_trace:
@@ -142,7 +150,8 @@ def run(self, load_data_np=None):
                     if self.cfg.use_checkpoint:
                         prev = dataset
                     dataset = dataset.map(op.compute_stats,
-                                          num_proc=self.cfg.np,
+                                          num_proc=op_proc,
+                                          with_rank=with_rank,
                                           desc=op_name + '_compute_stats')
                     if self.cfg.use_checkpoint:
                         prev = dataset
@@ -157,7 +166,8 @@ def run(self, load_data_np=None):
                         self.tracer.trace_filter(op_name, dataset, tmp)
                 elif isinstance(op, Deduplicator):
                     dataset = dataset.map(op.compute_hash,
-                                          num_proc=self.cfg.np,
+                                          num_proc=op_proc,
+                                          with_rank=with_rank,
                                           desc=op_name + '_compute_hash')
                     if self.cfg.use_checkpoint:
                         prev = dataset
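For context on the executor change above: `datasets.Dataset.map(..., with_rank=True)` passes each worker's integer rank to the mapped function, and `num_proc` is capped at the visible GPU count so that ranks map one-to-one onto CUDA devices. A minimal, CPU-safe sketch of the mechanism (toy function, not a Data-Juicer op):

```python
from datasets import Dataset


def compute_stats(sample, rank=None):
    # With with_rank=True, map() supplies the worker's rank (0..num_proc-1);
    # a CUDA-enabled op would place its model on f'cuda:{rank}' here.
    sample['device'] = f'cuda:{rank}' if rank is not None else 'cpu'
    return sample


ds = Dataset.from_dict({'text': ['a', 'b', 'c', 'd']})
ds = ds.map(compute_stats, num_proc=2, with_rank=True)
print(ds['device'])  # e.g. ['cuda:0', 'cuda:0', 'cuda:1', 'cuda:1']
```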
diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py
index 8ee7e893c..6b494e5ec 100644
--- a/data_juicer/ops/base_op.py
+++ b/data_juicer/ops/base_op.py
@@ -31,6 +31,7 @@ def __init__(
         if audio_key is None:
             audio_key = 'audios'
         self.audio_key = audio_key
+        self._accelerator = 'cpu'
 
         from data_juicer.core.data import wrap_func_with_nested_access
         self.process = wrap_func_with_nested_access(self.process)
diff --git a/data_juicer/ops/deduplicator/document_simhash_deduplicator.py b/data_juicer/ops/deduplicator/document_simhash_deduplicator.py
index 28f109b6e..3d1afa475 100644
--- a/data_juicer/ops/deduplicator/document_simhash_deduplicator.py
+++ b/data_juicer/ops/deduplicator/document_simhash_deduplicator.py
@@ -21,46 +21,6 @@
 with AvailabilityChecking(['simhash-pybind'], OP_NAME):
     import simhash
 
-    def local_num_differing_bits(hash_a, hash_b):
-        """
-        Local implementation of calculating the number of different bits
-        between two integers.
-
-        :param hash_a: integer hash value a
-        :param hash_b: integer hash value b
-        :return: number of different bits between input hashes.
-        """
-        cnt = 0
-        n = hash_a ^ hash_b
-        while n != 0:
-            cnt += 1
-            n = n & (n - 1)
-        return cnt
-
-    def num_differing_bits_selector():
-        """
-        Select a num_differing_bits method according to the Python version
-        installed.
-
-        When Python >= 3.9, the original simhash library cannot be compiled
-        correctly due to some changes in cython. After fixing this
-        incompatibility, RecursionError occurs sometimes when calling
-        simhash.num_differing_bits. So we use our implementation when Python
-        >= 3.9. Otherwise, we use implementation of simhash.
-
-        :return: an available num_differing_bits function.
-        """
-        import platform
-        a, b, _ = platform.python_version().split('.')
-        if a == '3' and int(b) >= 9:
-            # for >= 3.9, use local implementation
-            return local_num_differing_bits
-        else:
-            # for < 3.9, use simhash version
-            return simhash.num_differing_bits
-
-    num_differing_bits = num_differing_bits_selector()
-
 
 @OPERATORS.register_module(OP_NAME)
 class DocumentSimhashDeduplicator(Deduplicator):
diff --git a/data_juicer/ops/filter/alphanumeric_filter.py b/data_juicer/ops/filter/alphanumeric_filter.py
index 111c97ddc..88e93c534 100644
--- a/data_juicer/ops/filter/alphanumeric_filter.py
+++ b/data_juicer/ops/filter/alphanumeric_filter.py
@@ -51,7 +51,7 @@ def __init__(self,
         if tokenization:
             self.model_key = prepare_model(
                 model_type='huggingface',
-                model_name_or_path='EleutherAI/pythia-6.9b-deduped',
+                pretrained_model_name_or_path='EleutherAI/pythia-6.9b-deduped',
                 return_model=False)
 
     def compute_stats(self, sample):
         # check if it's computed already
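The `local_num_differing_bits` helper removed above computed the Hamming distance between two simhash values; callers now rely on `simhash.num_differing_bits` directly (see `generate_caption_mapper` below). For reference, the removed logic was simply a popcount of the XOR:

```python
def num_differing_bits(hash_a: int, hash_b: int) -> int:
    # Brian Kernighan's trick: clear the lowest set bit until none remain.
    n, cnt = hash_a ^ hash_b, 0
    while n:
        n &= n - 1
        cnt += 1
    return cnt


assert num_differing_bits(0b1010, 0b0110) == 2
```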
diff --git a/data_juicer/ops/filter/image_text_matching_filter.py b/data_juicer/ops/filter/image_text_matching_filter.py
index 688b78420..294c140b4 100644
--- a/data_juicer/ops/filter/image_text_matching_filter.py
+++ b/data_juicer/ops/filter/image_text_matching_filter.py
@@ -69,12 +69,13 @@ def __init__(self,
                              f'Can only be one of ["any", "all"].')
         self.any = (any_or_all == 'any')
         self.model_key = prepare_model(model_type='huggingface',
-                                       model_name_or_path=hf_blip)
+                                       pretrained_model_name_or_path=hf_blip)
+        self._accelerator = 'cuda'
         self.reduce_mode = reduce_mode
         self.horizontal_flip = horizontal_flip
         self.vertical_flip = vertical_flip
 
-    def compute_stats(self, sample, context=False):
+    def compute_stats(self, sample, rank=None, context=False):
         # check if it's computed already
         if StatsKeys.image_text_matching_score in sample[Fields.stats]:
             return sample
@@ -94,7 +95,7 @@ def compute_stats(self, sample, context=False):
         text = sample[self.text_key]
         offset = 0
         matching_scores = []
-        model, processor = get_model(self.model_key)
+        model, processor = get_model(self.model_key, rank=rank)
 
         for chunk in text.split(SpecialTokens.eoc):
             count = chunk.count(SpecialTokens.image)
@@ -139,7 +140,7 @@ def compute_stats(self, sample, context=False):
 
         return sample
 
-    def process(self, sample):
+    def process(self, sample, rank=None):
         itm_scores = sample[Fields.stats][StatsKeys.image_text_matching_score]
         if len(itm_scores) <= 0:
             return True
diff --git a/data_juicer/ops/filter/image_text_similarity_filter.py b/data_juicer/ops/filter/image_text_similarity_filter.py
index d67356cd1..5efe2f107 100644
--- a/data_juicer/ops/filter/image_text_similarity_filter.py
+++ b/data_juicer/ops/filter/image_text_similarity_filter.py
@@ -70,12 +70,13 @@ def __init__(self,
                              f'Can only be one of ["any", "all"].')
         self.any = (any_or_all == 'any')
         self.model_key = prepare_model(model_type='huggingface',
-                                       model_name_or_path=hf_clip)
+                                       pretrained_model_name_or_path=hf_clip)
+        self._accelerator = 'cuda'
         self.reduce_mode = reduce_mode
         self.horizontal_flip = horizontal_flip
         self.vertical_flip = vertical_flip
 
-    def compute_stats(self, sample, context=False):
+    def compute_stats(self, sample, rank=None, context=False):
         # check if it's computed already
         if StatsKeys.image_text_similarity in sample[Fields.stats]:
             return sample
@@ -94,7 +95,7 @@ def compute_stats(self, sample, context=False):
         text = sample[self.text_key]
         offset = 0
         similarity = []
-        model, processor = get_model(self.model_key)
+        model, processor = get_model(self.model_key, rank=rank)
 
         for chunk in text.split(SpecialTokens.eoc):
             count = chunk.count(SpecialTokens.image)
@@ -137,7 +138,7 @@ def compute_stats(self, sample, context=False):
 
         return sample
 
-    def process(self, sample):
+    def process(self, sample, rank=None):
         similarity = sample[Fields.stats][StatsKeys.image_text_similarity]
         if len(similarity) <= 0:
             return True
diff --git a/data_juicer/ops/filter/phrase_grounding_recall_filter.py b/data_juicer/ops/filter/phrase_grounding_recall_filter.py
index f19f21843..bf91a41ce 100644
--- a/data_juicer/ops/filter/phrase_grounding_recall_filter.py
+++ b/data_juicer/ops/filter/phrase_grounding_recall_filter.py
@@ -130,7 +130,8 @@ def __init__(self,
                              f'Can only be one of ["any", "all"].')
         self.any = (any_or_all == 'any')
         self.model_key = prepare_model(model_type='huggingface',
-                                       model_name_or_path=hf_owlvit)
+                                       pretrained_model_name_or_path=hf_owlvit)
+        self._accelerator = 'cuda'
         self.reduce_mode = reduce_mode
         self.horizontal_flip = horizontal_flip
         self.vertical_flip = vertical_flip
@@ -144,7 +145,7 @@ def __init__(self,
         for nltk_data_pkg in requires_nltk_data:
             nltk.download(nltk_data_pkg)
 
-    def compute_stats(self, sample, context=False):
+    def compute_stats(self, sample, rank=None, context=False):
         # check if it's computed already
         if StatsKeys.phrase_grounding_recall in sample[Fields.stats]:
             return sample
@@ -163,7 +164,7 @@ def compute_stats(self, sample, context=False):
         text = sample[self.text_key]
         offset = 0
         recalls = []
-        model, processor = get_model(self.model_key)
+        model, processor = get_model(self.model_key, rank=rank)
 
         for chunk in text.split(SpecialTokens.eoc):
             count = chunk.count(SpecialTokens.image)
@@ -197,8 +198,9 @@ def compute_stats(self, sample, context=False):
 
             with torch.no_grad():
                 outputs = model(**inputs)
 
-            target_sizes = torch.tensor(
-                [img.size[::-1] for img in images_this_chunk])
+            target_sizes = torch.tensor([
+                img.size[::-1] for img in images_this_chunk
+            ]).to(model.device)
             results = processor.post_process_object_detection(
                 outputs,
                 threshold=self.conf_thr,
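On the `target_sizes` change just above: once the OwlViT model has been moved to `cuda:{rank}`, every auxiliary tensor consumed together with its outputs must live on the same device, otherwise post-processing fails with a CPU/CUDA mismatch. A minimal, hypothetical illustration of the rule (stand-in module, not the actual filter):

```python
import torch

model = torch.nn.Linear(4, 2)  # stand-in for the OwlViT model
device = next(model.parameters()).device
# Build auxiliary tensors on the model's device, as the patch does above.
target_sizes = torch.tensor([[480, 640]]).to(device)
```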
diff --git a/data_juicer/ops/filter/token_num_filter.py b/data_juicer/ops/filter/token_num_filter.py
index 342e77cbd..793a103b2 100644
--- a/data_juicer/ops/filter/token_num_filter.py
+++ b/data_juicer/ops/filter/token_num_filter.py
@@ -43,9 +43,10 @@ def __init__(self,
         self.min_num = min_num
         self.max_num = max_num
         self.hf_tokenizer = hf_tokenizer
-        self.model_key = prepare_model(model_type='huggingface',
-                                       model_name_or_path=hf_tokenizer,
-                                       return_model=False)
+        self.model_key = prepare_model(
+            model_type='huggingface',
+            pretrained_model_name_or_path=hf_tokenizer,
+            return_model=False)
 
     def compute_stats(self, sample):
         # check if it's computed already
diff --git a/data_juicer/ops/mapper/generate_caption_mapper.py b/data_juicer/ops/mapper/generate_caption_mapper.py
index 032743604..8cd587182 100644
--- a/data_juicer/ops/mapper/generate_caption_mapper.py
+++ b/data_juicer/ops/mapper/generate_caption_mapper.py
@@ -88,7 +88,8 @@ def __init__(self,
                 f'["random_any", "similar_one_simhash", "all"].')
 
         self.model_key = prepare_model(model_type='huggingface',
-                                       model_name_or_path=hf_blip2)
+                                       pretrained_model_name_or_path=hf_blip2)
+        self._accelerator = 'cuda'
         self.caption_num = caption_num
         self.keep_candidate_mode = keep_candidate_mode
         self.keep_original_sample = keep_original_sample
@@ -109,7 +110,7 @@ def __init__(self,
                 'Both the parameter `prompt` and `prompt_key` are '
                 'set. Data-Juicer will consider `prompt_key` first.')
 
-    def _process_single_sample(self, ori_sample):
+    def _process_single_sample(self, ori_sample, rank=None):
         """
         :param ori_sample: a single data sample before applying generation
@@ -146,7 +147,7 @@ def _process_single_sample(self, ori_sample):
         # the generated text will be placed following each SpecialTokens.img
         # and the original special tokens are kept in an order-preserving way.
-        model, processor = get_model(self.model_key)
+        model, processor = get_model(self.model_key, rank=rank)
 
         # do generation for each image chunk by chunk
         for chunk in ori_sample[self.text_key].split(SpecialTokens.eoc):
@@ -233,8 +234,10 @@ def _reduce_captions_per_image(self, chunk,
             new_generated_text_per_chunk.extend(
                 generated_text_candidates_single_chunk)
         elif self.keep_candidate_mode == 'similar_one_simhash':
-            from ..deduplicator.document_simhash_deduplicator import (
-                DocumentSimhashDeduplicator, num_differing_bits)
+            from simhash import num_differing_bits
+
+            from ..deduplicator.document_simhash_deduplicator import \
+                DocumentSimhashDeduplicator
             ori_normal_text = remove_special_tokens(chunk)
             # using a simhash OP to calculate their similarity
             # NOTE: simhash is just one method to calculate the similarities
@@ -262,7 +265,7 @@ def _reduce_captions_per_image(self, chunk,
                 generated_text_candidates_single_chunk[max_index])
         return new_generated_text_per_chunk
 
-    def process(self, samples):
+    def process(self, samples, rank=None):
         """
         Note: This is a batched_OP, whose input and output types are both
             lists. Suppose there are $N$ input sample lists with batch
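As background for the batched `process` above: `datasets` hands a batched mapper a "dict of lists", which the mapper flattens into per-sample dicts before generation and reassembles afterwards. A standalone sketch of that round-trip (illustrative helper names, not the library's API):

```python
def to_list_of_dicts(batch):
    # {'text': ['a', 'b'], 'id': [1, 2]} -> [{'text': 'a', 'id': 1}, ...]
    keys = list(batch)
    return [dict(zip(keys, values)) for values in zip(*batch.values())]


def to_dict_of_lists(samples):
    # the inverse reconstruction, as done after generation
    return {key: [s[key] for s in samples] for key in samples[0]}


batch = {'text': ['a', 'b'], 'id': [1, 2]}
assert to_dict_of_lists(to_list_of_dicts(batch)) == batch
```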
@@ -284,7 +287,8 @@ def process(self, samples):
         for ori_sample in reconstructed_samples:
             if self.keep_original_sample:
                 samples_after_generation.append(ori_sample)
-            generated_samples = self._process_single_sample(ori_sample)
+            generated_samples = self._process_single_sample(ori_sample,
+                                                            rank=rank)
             if len(generated_samples) != 0:
                 samples_after_generation.extend(generated_samples)
         # reconstruct samples from "list of dicts" to "dict of lists"
diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py
index b989649f6..4569c0de1 100644
--- a/data_juicer/utils/model_utils.py
+++ b/data_juicer/utils/model_utils.py
@@ -2,9 +2,12 @@
 import os
 from functools import partial
 
+import multiprocess as mp
 import wget
 from loguru import logger
 
+from data_juicer import use_cuda
+
 from .cache_utils import DATA_JUICER_MODELS_CACHE as DJMC
 
 MODEL_ZOO = {}
@@ -173,13 +176,13 @@ def prepare_nltk_model(lang, name_pattern='punkt.{}.pickle'):
     return nltk_model
 
 
-def prepare_huggingface_model(model_name_or_path,
+def prepare_huggingface_model(pretrained_model_name_or_path,
                               return_model=True,
                               trust_remote_code=False):
     """
     Prepare and load a HuggingFace model with the corresponding processor.
 
-    :param model_name: model name or path
+    :param pretrained_model_name_or_path: model name or path
     :param return_model: return model or not
     :param trust_remote_code: passed to transformers
     :return: a tuple (model, input processor) if `return_model` is True;
@@ -195,25 +198,25 @@ def prepare_huggingface_model(model_name_or_path,
     from transformers.models.auto.tokenization_auto import \
         TOKENIZER_MAPPING_NAMES
 
-    config = AutoConfig.from_pretrained(model_name_or_path)
+    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
     # TODO: What happens when there are more than one?
     arch = config.architectures[0]
     model_class = getattr(transformers, arch)
     model_type = config.model_type
     if model_type in PROCESSOR_MAPPING_NAMES:
         processor = AutoProcessor.from_pretrained(
-            model_name_or_path, trust_remote_code=trust_remote_code)
+            pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
     elif model_type in IMAGE_PROCESSOR_MAPPING_NAMES:
         processor = AutoImageProcessor.from_pretrained(
-            model_name_or_path, trust_remote_code=trust_remote_code)
+            pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
     elif model_type in TOKENIZER_MAPPING_NAMES:
         processor = AutoTokenizer.from_pretrained(
-            model_name_or_path, trust_remote_code=trust_remote_code)
+            pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
     else:
         processor = None
 
     if return_model:
-        model = model_class.from_pretrained(model_name_or_path)
+        model = model_class.from_pretrained(pretrained_model_name_or_path)
     return (model, processor) if return_model else processor
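The `get_model` rewrite in the next hunk implements a lazy, per-process model cache: `prepare_model` registers a factory (a `partial`) under a key, and each worker process fills its own `MODEL_ZOO` on first use, optionally moving the loaded model to `cuda:{rank}`. A simplified sketch of the caching pattern (illustrative, not the exact Data-Juicer code):

```python
from functools import partial

MODEL_ZOO = {}  # per-process cache; each worker holds its own replica


def get(model_key, rank=None, cuda=False):
    # model_key doubles as the factory, mirroring prepare_model's partial
    if model_key not in MODEL_ZOO:
        MODEL_ZOO[model_key] = model_key()
        if cuda:
            pass  # a real op would move the model to f'cuda:{rank or 0}'
    return MODEL_ZOO[model_key]


factory = partial(dict, name='stand-in model')
assert get(factory) is get(factory)  # loaded once, then served from cache
```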
@@ -275,11 +278,32 @@ def prepare_model(model_type, **model_kwargs):
     return model_key
 
 
-def get_model(model_key=None):
-    global MODEL_ZOO
+def move_to_cuda(model, rank):
+    # Assuming model can be either a single module or a tuple of modules
+    if not isinstance(model, tuple):
+        model = (model, )
+
+    for module in model:
+        if callable(getattr(module, 'to', None)):
+            logger.info(
+                f'Moving {module.__class__.__name__} to CUDA device {rank}')
+            module.to(f'cuda:{rank}')
+            # Optionally, verify the device assignment
+            logger.debug(f'{module.__class__.__name__} is on device '
+                         f'{next(module.parameters()).device}')
+
+
+def get_model(model_key=None, rank=None):
     if model_key is None:
-        logger.warning('Please specify model_key to get models')
         return None
+
+    global MODEL_ZOO
     if model_key not in MODEL_ZOO:
+        logger.debug(
+            f'{model_key} not found in MODEL_ZOO ({mp.current_process().name})'
+        )
         MODEL_ZOO[model_key] = model_key()
+    if use_cuda():
+        rank = 0 if rank is None else rank
+        move_to_cuda(MODEL_ZOO[model_key], rank)
     return MODEL_ZOO[model_key]
diff --git a/docs/DeveloperGuide.md b/docs/DeveloperGuide.md
index ff437367f..b8cdb3fd6 100644
--- a/docs/DeveloperGuide.md
+++ b/docs/DeveloperGuide.md
@@ -53,56 +53,78 @@ class StatsKeys(object):
 
 2. Create a new OP file `text_length_filter.py` in the corresponding `data_juicer/ops/filter/` directory as follows.
    - Because it's a Filter OP, the new OP needs to inherit from the basic `Filter` class in `base_op.py` and be decorated with `OPERATORS` to register itself automatically.
 
-```python
-import sys
-
-from jsonargparse.typing import PositiveInt
-
-from data_juicer.utils.constant import Fields, StatsKeys
-
-from ..base_op import OPERATORS, Filter
-
-
-@OPERATORS.register_module('text_length_filter')
-class TextLengthFilter(Filter):
-    """Filter to keep samples with total text length within a specific
-    range."""
-
-    def __init__(self,
-                 min_len: PositiveInt = 10,
-                 max_len: PositiveInt = sys.maxsize,
-                 *args,
-                 **kwargs):
-        """
-        Initialization method.
-
-        :param min_len: The min text length in the filtering. samples
-        will be filtered if their text length is below this
-        parameter.
-        :param max_len: The max text length in the filtering. samples
-        will be filtered if their text length exceeds this
-        parameter.
-        :param args: extra args
-        :param kwargs: extra args
-        """
-        super().__init__(*args, **kwargs)
-        self.min_len = min_len
-        self.max_len = max_len
-
-    def compute_stats(self, sample):
-        # check if it's computed already
-        if StatsKeys.text_len in sample[Fields.stats]:
+    ```python
+    import sys
+
+    from jsonargparse.typing import PositiveInt
+
+    from data_juicer.utils.constant import Fields, StatsKeys
+
+    from ..base_op import OPERATORS, Filter
+
+
+    @OPERATORS.register_module('text_length_filter')
+    class TextLengthFilter(Filter):
+        """Filter to keep samples with total text length within a specific
+        range."""
+
+        def __init__(self,
+                     min_len: PositiveInt = 10,
+                     max_len: PositiveInt = sys.maxsize,
+                     *args,
+                     **kwargs):
+            """
+            Initialization method.
+
+            :param min_len: The min text length in the filtering. samples
+            will be filtered if their text length is below this
+            parameter.
+            :param max_len: The max text length in the filtering. samples
+            will be filtered if their text length exceeds this
+            parameter.
+            :param args: extra args
+            :param kwargs: extra args
+            """
+            super().__init__(*args, **kwargs)
+            self.min_len = min_len
+            self.max_len = max_len
+
+        def compute_stats(self, sample):
+            # check if it's computed already
+            if StatsKeys.text_len in sample[Fields.stats]:
+                return sample
+
+            sample[Fields.stats][StatsKeys.text_len] = len(sample[self.text_key])
             return sample
 
-    sample[Fields.stats][StatsKeys.text_len] = len(sample[self.text_key])
-    return sample
+        def process(self, sample):
+            if self.min_len <= sample[Fields.stats][StatsKeys.text_len] <= self.max_len:
+                return True
+            else:
+                return False
+    ```
 
-    def process(self, sample):
-        if self.min_len <= sample[Fields.stats][StatsKeys.text_len] <= self.max_len:
-            return True
-        else:
-            return False
-```
+    - If Hugging Face models are used within an operator, you might want to leverage GPU acceleration. To achieve this, declare `self._accelerator = 'cuda'` in the constructor, and ensure that the `compute_stats` and `process` methods accept an additional positional argument `rank`.
+
+    ```python
+    # ... (same as above)
+
+    @OPERATORS.register_module('text_length_filter')
+    class TextLengthFilter(Filter):
+        def __init__(self,
+                     min_len: PositiveInt = 10,
+                     max_len: PositiveInt = sys.maxsize,
+                     *args,
+                     **kwargs):
+            # ... (same as above)
+            self._accelerator = 'cuda'
+
+        def compute_stats(self, sample, rank=None):
+            # ... (same as above)
+
+        def process(self, sample, rank=None):
+            # ... (same as above)
+    ```
 
 3. After implementation, add it to the OP dictionary in the `__init__.py` file in the `data_juicer/ops/filter/` directory.
@@ -249,9 +271,7 @@ class WordNumFilter(Filter):
 ```python
 # before modification
 ...
-tokenizer = get_model(self.model_key,
-                      lang=self.lang,
-                      model_type='sentencepiece')
+tokenizer = get_model(self.model_key)
 words = get_words_from_document(
     sample[self.text_key],
     token_func=tokenizer.encode_as_pieces if tokenizer else None)
@@ -265,9 +285,7 @@ if context and words_key in sample[Fields.context]:
     words = sample[Fields.context][words_key]
 else:
     # normal calculation process
-    tokenizer = get_model(self.model_key,
-                          lang=self.lang,
-                          model_type='sentencepiece')
+    tokenizer = get_model(self.model_key)
     words = get_words_from_document(
         sample[self.text_key],
         token_func=tokenizer.encode_as_pieces if tokenizer else None)
diff --git a/docs/DeveloperGuide_ZH.md b/docs/DeveloperGuide_ZH.md
index ffbd78377..9253139fc 100644
--- a/docs/DeveloperGuide_ZH.md
+++ b/docs/DeveloperGuide_ZH.md
@@ -48,56 +48,78 @@ class StatsKeys(object):
 
 2. 在 `data_juicer/ops/filter/` 目录下创建一个新的算子文件 `text_length_filter.py`，内容如下：
    - 因为它是一个 Filter 算子，所以需要继承 `base_op.py` 中的 `Filter` 基类，并用 `OPERATORS` 修饰以实现自动注册。
 
-```python
-import sys
-
-from jsonargparse.typing import PositiveInt
-
-from data_juicer.utils.constant import Fields, StatsKeys
-
-from ..base_op import OPERATORS, Filter
-
-
-@OPERATORS.register_module('text_length_filter')
-class TextLengthFilter(Filter):
-    """Filter to keep samples with total text length within a specific
-    range."""
-
-    def __init__(self,
-                 min_len: PositiveInt = 10,
-                 max_len: PositiveInt = sys.maxsize,
-                 *args,
-                 **kwargs):
-        """
-        Initialization method.
-
-        :param min_len: The min text length in the filtering. samples
-        will be filtered if their text length is below this
-        parameter.
-        :param max_len: The max text length in the filtering. samples
-        will be filtered if their text length exceeds this
-        parameter.
-        :param args: extra args
-        :param kwargs: extra args
-        """
-        super().__init__(*args, **kwargs)
-        self.min_len = min_len
-        self.max_len = max_len
-
-    def compute_stats(self, sample):
-        # check if it's computed already
-        if StatsKeys.text_len in sample[Fields.stats]:
+    ```python
+    import sys
+
+    from jsonargparse.typing import PositiveInt
+
+    from data_juicer.utils.constant import Fields, StatsKeys
+
+    from ..base_op import OPERATORS, Filter
+
+
+    @OPERATORS.register_module('text_length_filter')
+    class TextLengthFilter(Filter):
+        """Filter to keep samples with total text length within a specific
+        range."""
+
+        def __init__(self,
+                     min_len: PositiveInt = 10,
+                     max_len: PositiveInt = sys.maxsize,
+                     *args,
+                     **kwargs):
+            """
+            Initialization method.
+
+            :param min_len: The min text length in the filtering. samples
+            will be filtered if their text length is below this
+            parameter.
+            :param max_len: The max text length in the filtering. samples
+            will be filtered if their text length exceeds this
+            parameter.
+            :param args: extra args
+            :param kwargs: extra args
+            """
+            super().__init__(*args, **kwargs)
+            self.min_len = min_len
+            self.max_len = max_len
+
+        def compute_stats(self, sample):
+            # check if it's computed already
+            if StatsKeys.text_len in sample[Fields.stats]:
+                return sample
+
+            sample[Fields.stats][StatsKeys.text_len] = len(sample[self.text_key])
             return sample
 
-    sample[Fields.stats][StatsKeys.text_len] = len(sample[self.text_key])
-    return sample
+        def process(self, sample):
+            if self.min_len <= sample[Fields.stats][StatsKeys.text_len] <= self.max_len:
+                return True
+            else:
+                return False
+    ```
 
-    def process(self, sample):
-        if self.min_len <= sample[Fields.stats][StatsKeys.text_len] <= self.max_len:
-            return True
-        else:
-            return False
-```
+    - 如果在算子中使用了 Hugging Face 模型，您可能希望利用 GPU 加速。为了实现这一点，请在构造函数中声明 `self._accelerator = 'cuda'`，并确保 `compute_stats` 和 `process` 方法接受一个额外的位置参数 `rank`。
+
+    ```python
+    # ... (same as above)
+
+    @OPERATORS.register_module('text_length_filter')
+    class TextLengthFilter(Filter):
+        def __init__(self,
+                     min_len: PositiveInt = 10,
+                     max_len: PositiveInt = sys.maxsize,
+                     *args,
+                     **kwargs):
+            # ... (same as above)
+            self._accelerator = 'cuda'
+
+        def compute_stats(self, sample, rank=None):
+            # ... (same as above)
+
+        def process(self, sample, rank=None):
+            # ... (same as above)
+    ```
 
 3. 实现后，将其添加到 `data_juicer/ops/filter` 目录下 `__init__.py` 文件中的算子字典中：
@@ -229,9 +251,7 @@ class WordNumFilter(Filter):
 ```python
 # 修改计算逻辑前
 ...
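 # Note: model-specific kwargs such as `lang` and `model_type` are now
 # captured once by prepare_model, so get_model only needs the prepared
 # key (plus an optional `rank` for CUDA ops), as the change below shows.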
-tokenizer = get_model(self.model_key,
-                      lang=self.lang,
-                      model_type='sentencepiece')
+tokenizer = get_model(self.model_key)
 words = get_words_from_document(
     sample[self.text_key],
     token_func=tokenizer.encode_as_pieces if tokenizer else None)
@@ -245,9 +265,7 @@ if context and words_key in sample[Fields.context]:
     words = sample[Fields.context][words_key]
 else:
     # 正常计算流程
-    tokenizer = get_model(self.model_key,
-                          lang=self.lang,
-                          model_type='sentencepiece')
+    tokenizer = get_model(self.model_key)
     words = get_words_from_document(
         sample[self.text_key],
         token_func=tokenizer.encode_as_pieces if tokenizer else None)
diff --git a/tests/ops/filter/test_image_text_matching_filter.py b/tests/ops/filter/test_image_text_matching_filter.py
index c1014a405..15adfb5d4 100644
--- a/tests/ops/filter/test_image_text_matching_filter.py
+++ b/tests/ops/filter/test_image_text_matching_filter.py
@@ -30,7 +30,7 @@ def _run_filter(self, dataset: Dataset, target_list, op, num_proc=1):
 
         dataset = dataset.add_column(name=Fields.stats,
                                      column=[{}] * dataset.num_rows)
-        dataset = dataset.map(op.compute_stats, num_proc=num_proc)
+        dataset = dataset.map(op.compute_stats, num_proc=num_proc, with_rank=True)
         dataset = dataset.filter(op.process, num_proc=num_proc)
         dataset = dataset.select_columns(column_names=['text', 'images'])
         res_list = dataset.to_list()
diff --git a/tests/ops/filter/test_image_text_similarity_filter.py b/tests/ops/filter/test_image_text_similarity_filter.py
index 960cf1265..f50637561 100644
--- a/tests/ops/filter/test_image_text_similarity_filter.py
+++ b/tests/ops/filter/test_image_text_similarity_filter.py
@@ -30,7 +30,7 @@ def _run_filter(self, dataset: Dataset, target_list, op, num_proc=1):
 
         dataset = dataset.add_column(name=Fields.stats,
                                      column=[{}] * dataset.num_rows)
-        dataset = dataset.map(op.compute_stats, num_proc=num_proc)
+        dataset = dataset.map(op.compute_stats, num_proc=num_proc, with_rank=True)
         dataset = dataset.filter(op.process, num_proc=num_proc)
         dataset = dataset.select_columns(column_names=['text', 'images'])
         res_list = dataset.to_list()
diff --git a/tests/ops/filter/test_phrase_grounding_recall_filter.py b/tests/ops/filter/test_phrase_grounding_recall_filter.py
index bbd378cb3..c5510014d 100644
--- a/tests/ops/filter/test_phrase_grounding_recall_filter.py
+++ b/tests/ops/filter/test_phrase_grounding_recall_filter.py
@@ -34,7 +34,7 @@ def _run_filter(self, dataset: Dataset, target_list, op, num_proc=1):
 
         dataset = dataset.add_column(name=Fields.stats,
                                      column=[{}] * dataset.num_rows)
-        dataset = dataset.map(op.compute_stats, num_proc=num_proc)
+        dataset = dataset.map(op.compute_stats, num_proc=num_proc, with_rank=True)
         dataset = dataset.filter(op.process, num_proc=num_proc)
         dataset = dataset.select_columns(column_names=['text', 'images'])
         res_list = dataset.to_list()
diff --git a/tests/ops/mapper/test_generate_caption_mapper.py b/tests/ops/mapper/test_generate_caption_mapper.py
index 3cf563aae..bbf51668d 100644
--- a/tests/ops/mapper/test_generate_caption_mapper.py
+++ b/tests/ops/mapper/test_generate_caption_mapper.py
@@ -26,7 +26,7 @@ def tearDownClass(cls) -> None:
 
     def _run_mapper(self, dataset: NestedDataset, op, num_proc=1, caption_num=0):
-        dataset = dataset.map(op.process, num_proc=num_proc)
+        dataset = dataset.map(op.process, num_proc=num_proc, with_rank=True)
         dataset_list = dataset.select_columns(column_names=['text']).to_list()
         # assert the caption is generated successfully in terms of not_none
         # as the generated content is not deterministic
diff --git a/tests/ops/test_op_fusion.py b/tests/ops/test_op_fusion.py
index 5fc46d9bb..7f13ad431 100644
--- a/tests/ops/test_op_fusion.py
+++ b/tests/ops/test_op_fusion.py
@@ -1,8 +1,10 @@
 import unittest
 
 from data_juicer.ops.load import load_ops
+from data_juicer.utils.unittest_utils import SKIPPED_TESTS
 
 
+@SKIPPED_TESTS.register_module()
 class OpFusionTest(unittest.TestCase):
 
     def _run_op_fusion(self, original_process_list, target_process_list):
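A note on `SKIPPED_TESTS` used above: registering the test class lets the unittest runner exclude it from the default suite. A minimal sketch of such a registry decorator (illustrative; not the actual `unittest_utils` implementation):

```python
class Registry:

    def __init__(self):
        self.modules = {}

    def register_module(self):

        def decorator(cls):
            self.modules[cls.__name__] = cls
            return cls

        return decorator


SKIPPED_TESTS = Registry()


@SKIPPED_TESTS.register_module()
class OpFusionTest:
    pass


assert 'OpFusionTest' in SKIPPED_TESTS.modules
```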