Skip to content

Commit

Permalink
Enhance/gpu support (#203)
Browse files Browse the repository at this point in the history
Support GPU model loading and refine multiprocessing strategy
  • Loading branch information
drcege authored Feb 22, 2024
1 parent e08a225 commit 092b3da
Show file tree
Hide file tree
Showing 18 changed files with 306 additions and 187 deletions.
77 changes: 77 additions & 0 deletions data_juicer/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,78 @@
__version__ = '0.1.3'

import os
import subprocess

import multiprocess as mp
from loguru import logger


def _is_known_device(token, device_lines):
    """
    Check whether one ``CUDA_VISIBLE_DEVICES`` entry names a listed device.

    :param token: a single, already stripped entry of the variable
    :param device_lines: non-blank lines printed by ``nvidia-smi -L``
    :return: True if the entry is an in-range integer index, or a UUID (or
        UUID prefix) that matches exactly one listed device.
    """
    if not token:
        return False
    try:
        return 0 <= int(token) < len(device_lines)
    except ValueError:
        # not an integer index, fall through to UUID matching
        pass
    # nvidia-smi prints each device as '... (UUID: GPU-xxxx...)'
    needle = '(UUID: ' + token
    return sum(1 for line in device_lines if needle in line) == 1


def _cuda_device_count():
    """
    Count the CUDA devices visible to the current process.

    The devices are listed with ``nvidia-smi -L`` (each non-blank output
    line is treated as one device) and then narrowed by the
    ``CUDA_VISIBLE_DEVICES`` environment variable when it is set. Entries of
    the variable are read in order and counting stops at the first entry
    that does not name a listed device, so an empty value hides all devices.

    :return: number of visible devices; 0 if ``nvidia-smi`` is missing or
        fails, or if no device is visible.
    """
    try:
        nvidia_smi_output = subprocess.check_output(['nvidia-smi', '-L'],
                                                    text=True)
    except (OSError, subprocess.SubprocessError, ValueError):
        # nvidia-smi not found, exited with an error, or printed output
        # that could not be decoded
        return 0

    # drop blank lines so that empty output counts as zero devices
    all_devices = [
        line.strip() for line in nvidia_smi_output.splitlines()
        if line.strip()
    ]

    cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
    if cuda_visible_devices is None:
        # variable unset: every listed device is visible
        return len(all_devices)

    num_visible_devices = 0
    for dev in cuda_visible_devices.split(','):
        if not _is_known_device(dev.strip(), all_devices):
            # entries after the first invalid one are ignored
            break
        num_visible_devices += 1
    return num_visible_devices


# Whether CUDA is enabled for this process; starts disabled and is
# recomputed by setup_cuda().
_USE_CUDA = False
# Number of visible CUDA devices, probed once when this module is imported.
_CUDA_COUNT = _cuda_device_count()


def use_cuda():
    """
    Tell whether CUDA is currently enabled for this process.

    :return: the current value of the module-level ``_USE_CUDA`` flag.
    """
    enabled = _USE_CUDA
    return enabled


def cuda_device_count():
    """
    Report how many CUDA devices were detected.

    :return: the current value of the module-level ``_CUDA_COUNT``.
    """
    detected = _CUDA_COUNT
    return detected


def setup_mp():
    """
    Choose and apply the start method used by ``multiprocess``.

    The method name is read from the ``MP_START_METHOD`` environment
    variable (case-insensitive). When the variable is unset or ``auto``:

    - with at least one CUDA device, ``forkserver`` is used if the platform
      offers it, otherwise ``spawn``;
    - without CUDA devices, ``fork`` is used if the platform offers it,
      otherwise ``spawn``.

    A method that cannot be applied (for example an unknown name given in
    ``MP_START_METHOD``) is reported with a warning and the start method is
    left as it was, instead of raising.
    """
    method = os.getenv('MP_START_METHOD', 'auto').lower()
    if method == 'auto':
        available = mp.get_all_start_methods()
        if _CUDA_COUNT > 0:
            # forkserver is more lightweight
            preferred = 'forkserver'
        else:
            preferred = 'fork'
        # not every platform offers every method; fall back to spawn
        method = preferred if preferred in available else 'spawn'
    try:
        logger.info(f"Setting multiprocess start method to '{method}'.")
        mp.set_start_method(method, force=True)
    except (RuntimeError, ValueError) as e:
        # unknown or unavailable methods are reported as ValueError
        logger.warning(f'Error setting multiprocess start method: {e}')


def setup_cuda():
    """
    Recompute the module-level ``_USE_CUDA`` flag for this process.

    CUDA is enabled only when at least one device was detected and the
    current multiprocess start method is not ``fork``. The outcome is logged
    at debug level together with the start method and the process name.
    """
    global _USE_CUDA

    method = mp.get_start_method()
    # both conditions must hold: a device exists and workers are not forked
    _USE_CUDA = _CUDA_COUNT > 0 and method != 'fork'
    logger.debug(f'_USE_CUDA: {_USE_CUDA} | MP: {method} '
                 f'({mp.current_process().name})')


def initialize():
    """
    Run the package start-up sequence.

    The multiprocess start method is configured only in the process named
    ``MainProcess``; the CUDA flag is then computed in every process.
    """
    is_main_process = mp.current_process().name == 'MainProcess'
    if is_main_process:
        setup_mp()
    setup_cuda()


# Run the start-up sequence as a side effect of importing this module.
initialize()
16 changes: 13 additions & 3 deletions data_juicer/core/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from loguru import logger

from data_juicer import cuda_device_count, use_cuda
from data_juicer.config import init_configs
from data_juicer.format.load import load_formatter
from data_juicer.ops import (OPERATORS, Deduplicator, Filter, Mapper, Selector,
Expand Down Expand Up @@ -115,10 +116,17 @@ def run(self, load_data_np=None):
for op_cfg, op in zip(self.process_list, self.ops):
op_name, op_args = list(op_cfg.items())[0]
prev = dataset # record last dataset
if use_cuda() and op._accelerator == 'cuda':
op_proc = min(cuda_device_count(), self.cfg.np)
with_rank = True
else:
op_proc = self.cfg.np
with_rank = False
try:
if isinstance(op, Mapper):
tmp = dataset.map(function=op.process,
num_proc=self.cfg.np,
num_proc=op_proc,
with_rank=with_rank,
desc=op_name + '_process')
if self.open_tracer and \
op_name in self.op_list_to_trace:
Expand All @@ -142,7 +150,8 @@ def run(self, load_data_np=None):
if self.cfg.use_checkpoint:
prev = dataset
dataset = dataset.map(op.compute_stats,
num_proc=self.cfg.np,
num_proc=op_proc,
with_rank=with_rank,
desc=op_name + '_compute_stats')
if self.cfg.use_checkpoint:
prev = dataset
Expand All @@ -157,7 +166,8 @@ def run(self, load_data_np=None):
self.tracer.trace_filter(op_name, dataset, tmp)
elif isinstance(op, Deduplicator):
dataset = dataset.map(op.compute_hash,
num_proc=self.cfg.np,
num_proc=op_proc,
with_rank=with_rank,
desc=op_name + '_compute_hash')
if self.cfg.use_checkpoint:
prev = dataset
Expand Down
1 change: 1 addition & 0 deletions data_juicer/ops/base_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
if audio_key is None:
audio_key = 'audios'
self.audio_key = audio_key
self._accelerator = 'cpu'

from data_juicer.core.data import wrap_func_with_nested_access
self.process = wrap_func_with_nested_access(self.process)
Expand Down
40 changes: 0 additions & 40 deletions data_juicer/ops/deduplicator/document_simhash_deduplicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,46 +21,6 @@
with AvailabilityChecking(['simhash-pybind'], OP_NAME):
import simhash

def local_num_differing_bits(hash_a, hash_b):
    """
    Local implementation of calculating the number of different bits
    between two integers.

    :param hash_a: integer hash value a
    :param hash_b: integer hash value b
    :return: number of different bits between input hashes.
    :raises ValueError: if the two hashes have different signs, because
        their XOR is then negative and has no finite bit count.
    """
    n = hash_a ^ hash_b
    if n < 0:
        # clearing the lowest set bit of a negative int never reaches 0,
        # so fail explicitly instead of looping forever
        raise ValueError(
            'hash_a and hash_b must have the same sign to count '
            'differing bits')
    # every '1' in the binary form of the XOR marks one differing bit
    return bin(n).count('1')

def num_differing_bits_selector():
    """
    Select a num_differing_bits method according to the Python version
    installed.

    When Python >= 3.9, the original simhash library cannot be compiled
    correctly due to some changes in cython. After fixing this
    incompatibility, RecursionError occurs sometimes when calling
    simhash.num_differing_bits. So we use our implementation when Python
    >= 3.9. Otherwise, we use implementation of simhash.

    :return: an available num_differing_bits function.
    """
    import sys

    # compare as a tuple so that every version from 3.9 on, including
    # later major versions, selects the local implementation
    if sys.version_info >= (3, 9):
        # for >= 3.9, use local implementation
        return local_num_differing_bits
    # for < 3.9, use simhash version
    return simhash.num_differing_bits

# Pick the bit-difference implementation once, when this module is loaded.
num_differing_bits = num_differing_bits_selector()


@OPERATORS.register_module(OP_NAME)
class DocumentSimhashDeduplicator(Deduplicator):
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/ops/filter/alphanumeric_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self,
if tokenization:
self.model_key = prepare_model(
model_type='huggingface',
model_name_or_path='EleutherAI/pythia-6.9b-deduped',
pretrained_model_name_or_path='EleutherAI/pythia-6.9b-deduped',
return_model=False)

def compute_stats(self, sample):
Expand Down
9 changes: 5 additions & 4 deletions data_juicer/ops/filter/image_text_matching_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,13 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')
self.model_key = prepare_model(model_type='huggingface',
model_name_or_path=hf_blip)
pretrained_model_name_or_path=hf_blip)
self._accelerator = 'cuda'
self.reduce_mode = reduce_mode
self.horizontal_flip = horizontal_flip
self.vertical_flip = vertical_flip

def compute_stats(self, sample, context=False):
def compute_stats(self, sample, rank=None, context=False):
# check if it's computed already
if StatsKeys.image_text_matching_score in sample[Fields.stats]:
return sample
Expand All @@ -94,7 +95,7 @@ def compute_stats(self, sample, context=False):
text = sample[self.text_key]
offset = 0
matching_scores = []
model, processor = get_model(self.model_key)
model, processor = get_model(self.model_key, rank=rank)

for chunk in text.split(SpecialTokens.eoc):
count = chunk.count(SpecialTokens.image)
Expand Down Expand Up @@ -139,7 +140,7 @@ def compute_stats(self, sample, context=False):

return sample

def process(self, sample):
def process(self, sample, rank=None):
itm_scores = sample[Fields.stats][StatsKeys.image_text_matching_score]
if len(itm_scores) <= 0:
return True
Expand Down
9 changes: 5 additions & 4 deletions data_juicer/ops/filter/image_text_similarity_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,13 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')
self.model_key = prepare_model(model_type='huggingface',
model_name_or_path=hf_clip)
pretrained_model_name_or_path=hf_clip)
self._accelerator = 'cuda'
self.reduce_mode = reduce_mode
self.horizontal_flip = horizontal_flip
self.vertical_flip = vertical_flip

def compute_stats(self, sample, context=False):
def compute_stats(self, sample, rank=None, context=False):
# check if it's computed already
if StatsKeys.image_text_similarity in sample[Fields.stats]:
return sample
Expand All @@ -94,7 +95,7 @@ def compute_stats(self, sample, context=False):
text = sample[self.text_key]
offset = 0
similarity = []
model, processor = get_model(self.model_key)
model, processor = get_model(self.model_key, rank=rank)

for chunk in text.split(SpecialTokens.eoc):
count = chunk.count(SpecialTokens.image)
Expand Down Expand Up @@ -137,7 +138,7 @@ def compute_stats(self, sample, context=False):

return sample

def process(self, sample):
def process(self, sample, rank=None):
similarity = sample[Fields.stats][StatsKeys.image_text_similarity]
if len(similarity) <= 0:
return True
Expand Down
12 changes: 7 additions & 5 deletions data_juicer/ops/filter/phrase_grounding_recall_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')
self.model_key = prepare_model(model_type='huggingface',
model_name_or_path=hf_owlvit)
pretrained_model_name_or_path=hf_owlvit)
self._accelerator = 'cuda'
self.reduce_mode = reduce_mode
self.horizontal_flip = horizontal_flip
self.vertical_flip = vertical_flip
Expand All @@ -144,7 +145,7 @@ def __init__(self,
for nltk_data_pkg in requires_nltk_data:
nltk.download(nltk_data_pkg)

def compute_stats(self, sample, context=False):
def compute_stats(self, sample, rank=None, context=False):
# check if it's computed already
if StatsKeys.phrase_grounding_recall in sample[Fields.stats]:
return sample
Expand All @@ -163,7 +164,7 @@ def compute_stats(self, sample, context=False):
text = sample[self.text_key]
offset = 0
recalls = []
model, processor = get_model(self.model_key)
model, processor = get_model(self.model_key, rank=rank)

for chunk in text.split(SpecialTokens.eoc):
count = chunk.count(SpecialTokens.image)
Expand Down Expand Up @@ -197,8 +198,9 @@ def compute_stats(self, sample, context=False):

with torch.no_grad():
outputs = model(**inputs)
target_sizes = torch.tensor(
[img.size[::-1] for img in images_this_chunk])
target_sizes = torch.tensor([
img.size[::-1] for img in images_this_chunk
]).to(model.device)
results = processor.post_process_object_detection(
outputs,
threshold=self.conf_thr,
Expand Down
7 changes: 4 additions & 3 deletions data_juicer/ops/filter/token_num_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@ def __init__(self,
self.min_num = min_num
self.max_num = max_num
self.hf_tokenizer = hf_tokenizer
self.model_key = prepare_model(model_type='huggingface',
model_name_or_path=hf_tokenizer,
return_model=False)
self.model_key = prepare_model(
model_type='huggingface',
pretrained_model_name_or_path=hf_tokenizer,
return_model=False)

def compute_stats(self, sample):
# check if it's computed already
Expand Down
18 changes: 11 additions & 7 deletions data_juicer/ops/mapper/generate_caption_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def __init__(self,
f'["random_any", "similar_one_simhash", "all"].')

self.model_key = prepare_model(model_type='huggingface',
model_name_or_path=hf_blip2)
pretrained_model_name_or_path=hf_blip2)
self._accelerator = 'cuda'
self.caption_num = caption_num
self.keep_candidate_mode = keep_candidate_mode
self.keep_original_sample = keep_original_sample
Expand All @@ -109,7 +110,7 @@ def __init__(self,
'Both the parameter `prompt` and `prompt_key` are '
'set. Data-Juicer will consider `prompt_key` first.')

def _process_single_sample(self, ori_sample):
def _process_single_sample(self, ori_sample, rank=None):
"""
:param ori_sample: a single data sample before applying generation
Expand Down Expand Up @@ -146,7 +147,7 @@ def _process_single_sample(self, ori_sample):
# the generated text will be placed following each SpecialTokens.img
# and the original special tokens are kept in an order-preserving way.

model, processor = get_model(self.model_key)
model, processor = get_model(self.model_key, rank=rank)

# do generation for each image chunk by chunk
for chunk in ori_sample[self.text_key].split(SpecialTokens.eoc):
Expand Down Expand Up @@ -233,8 +234,10 @@ def _reduce_captions_per_image(self, chunk,
new_generated_text_per_chunk.extend(
generated_text_candidates_single_chunk)
elif self.keep_candidate_mode == 'similar_one_simhash':
from ..deduplicator.document_simhash_deduplicator import (
DocumentSimhashDeduplicator, num_differing_bits)
from simhash import num_differing_bits

from ..deduplicator.document_simhash_deduplicator import \
DocumentSimhashDeduplicator
ori_normal_text = remove_special_tokens(chunk)
# using a simhash OP to calculate their similarity
# NOTE: simhash is just one method to calculate the similarities
Expand Down Expand Up @@ -262,7 +265,7 @@ def _reduce_captions_per_image(self, chunk,
generated_text_candidates_single_chunk[max_index])
return new_generated_text_per_chunk

def process(self, samples):
def process(self, samples, rank=None):
"""
Note: This is a batched_OP, whose the input and output type are
both list. Suppose there are $N$ input sample list with batch
Expand All @@ -284,7 +287,8 @@ def process(self, samples):
for ori_sample in reconstructed_samples:
if self.keep_original_sample:
samples_after_generation.append(ori_sample)
generated_samples = self._process_single_sample(ori_sample)
generated_samples = self._process_single_sample(ori_sample,
rank=rank)
if len(generated_samples) != 0:
samples_after_generation.extend(generated_samples)
# reconstruct samples from "list of dicts" to "dict of lists"
Expand Down
Loading

0 comments on commit 092b3da

Please sign in to comment.