refactoring minhash + 64bit hash support
guipenedo committed Nov 2, 2023
1 parent dd30eac commit edbea5f
Showing 2 changed files with 197 additions and 162 deletions.
109 changes: 63 additions & 46 deletions src/datatrove/pipeline/dedup/minhash.py
@@ -18,8 +18,7 @@

# http://en.wikipedia.org/wiki/Mersenne_prime
_mersenne_prime = np.uint64((1 << 61) - 1)
_max_hash = np.uint64((1 << 32) - 1)
_hash_range = 1 << 32
_max_hash_32b = np.uint64((1 << 32) - 1)

"""
n_grams -> shingle size in words (small enough to catch fuzzy matches, large enough that individual shingles are not overly common)
@@ -28,13 +27,31 @@
probability of inclusion for s=0.8: 1-(1-0.8^8)^14=0.924
"""

DEFAULT_NR_BUCKETS = 14
DEFAULT_PER_BUCKET = 8
DEFAULT_N_GRAMS = 5

SENTINEL = (1 << 32) - 1


@dataclass
class MinhashConfig:
n_grams: int = 5

num_buckets: int = 14
hashes_per_bucket: int = 8

use_64bit_hashes: bool = False
seed: int = 1

@property
def hash_dtype(self):
return np.uint64 if self.use_64bit_hashes else np.uint32

@property
def hash_format(self):
return "Q" if self.use_64bit_hashes else "I"


DEFAULT_MINHASH_CONFIG = MinhashConfig()
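After the refactor every stage takes a single MinhashConfig instead of separate num_buckets/hashes_per_bucket/seed arguments, and the two properties derive the numpy dtype and struct format character from the chosen hash width. A quick sketch, assuming the module above is in scope:

# 64-bit hashes widen both the in-memory dtype and the on-disk records.
config = MinhashConfig(use_64bit_hashes=True)
assert config.hash_dtype == np.uint64 and config.hash_format == "Q"  # 8 bytes per hash
assert DEFAULT_MINHASH_CONFIG.hash_dtype == np.uint32                # default stays 32-bit
assert DEFAULT_MINHASH_CONFIG.hash_format == "I"                     # 4 bytes per hash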


@dataclass
class HashSig:
sig: tuple[int]
@@ -52,13 +69,13 @@ def __lt__(self, other):
return self.to_tuple() < other.to_tuple()


def read_sigs(file: InputDataFile, reader_id: int, hashes_per_bucket: int, index_file: bool = False) -> Generator:
n = hashes_per_bucket + 1 if not index_file else hashes_per_bucket
for data in read_tuples_from_file(file, f"{n}I"):
if index_file:
def read_sigs(file: InputDataFile, reader_id: int, config: MinhashConfig, index_file: bool = False) -> Generator:
if index_file:
for data in read_tuples_from_file(file, f"{config.hashes_per_bucket}{config.hash_format}"):
yield HashSig(sig=data, doc_id=-1, file_id=-1, reader_id=reader_id)
else:
yield HashSig(sig=data[:-1], doc_id=data[-1], file_id=reader_id, reader_id=reader_id)
else:
for *data, doc_id in read_tuples_from_file(file, f"{config.hashes_per_bucket}{config.hash_format}", "I"):
yield HashSig(sig=data, doc_id=doc_id, file_id=reader_id, reader_id=reader_id)
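read_sigs walks fixed-size binary records via datatrove's read_tuples_from_file helper: each record is hashes_per_bucket little-endian hashes, followed by a uint32 doc_id in signature files (index files store the hashes only). A hedged standalone equivalent using nothing but struct, with the function name a stand-in of my own:

import struct

def iter_sig_records(path: str, hashes_per_bucket: int, hash_format: str, index_file: bool = False):
    # Layout assumed from the writer below: hashes_per_bucket hashes
    # ("I" = uint32, "Q" = uint64), then a uint32 doc_id unless this is an index file.
    fmt = f"<{hashes_per_bucket}{hash_format}" + ("" if index_file else "I")
    record_size = struct.calcsize(fmt)
    with open(path, "rb") as f:
        while chunk := f.read(record_size):  # assumes the file is a whole number of records
            values = struct.unpack(fmt, chunk)
            if index_file:
                yield tuple(values), -1      # index entries carry no doc_id
            else:
                yield tuple(values[:-1]), values[-1]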


class MinhashDedupSignature(PipelineStep):
@@ -68,19 +85,13 @@ class MinhashDedupSignature(PipelineStep):
def __init__(
self,
output_folder: BaseOutputDataFolder,
num_buckets: int = DEFAULT_NR_BUCKETS,
hashes_per_bucket: int = DEFAULT_PER_BUCKET,
n_grams: int = DEFAULT_N_GRAMS,
seed: int = 1,
config: MinhashConfig = DEFAULT_MINHASH_CONFIG,
**kwargs,
):
super().__init__(**kwargs)
self.output_folder = output_folder
self.n_grams = n_grams
self.num_buckets = num_buckets
self.hashes_per_bucket = hashes_per_bucket
self.num_hashes = self.num_buckets * self.hashes_per_bucket
self.seed = seed
self.config = config
self.num_hashes = self.config.num_buckets * self.config.hashes_per_bucket
self._parameters = None

@property
@@ -89,16 +100,20 @@ def parameters(self):
# Create parameters for a random bijective permutation function
# that maps a hash value to another hash value (32- or 64-bit).
# http://en.wikipedia.org/wiki/Universal_hashing
gen = np.random.RandomState(self.seed)
gen = np.random.RandomState(self.config.seed)
self._parameters = gen.randint(
1, _mersenne_prime, dtype=np.uint64, size=(1, self.num_hashes)
), gen.randint(0, _mersenne_prime, dtype=np.uint64, size=(1, self.num_hashes))
return self._parameters

def get_signature(self, shingles):
a, b = self.parameters
phv = np.bitwise_and((shingles * a + b) % _mersenne_prime, _max_hash)
return [x.tolist() for x in np.split(np.min(phv, axis=0).astype(np.uint32), self.num_buckets)]
phv = (shingles * a + b) % _mersenne_prime
if not self.config.use_64bit_hashes:  # mask down to 32 bits only when 32-bit hashes are requested
phv = np.bitwise_and(phv, _max_hash_32b)
return [
x.tolist() for x in np.split(np.min(phv, axis=0).astype(self.config.hash_dtype), self.config.num_buckets)
]
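Each column of (shingles * a + b) % _mersenne_prime is an independent affine hash over the field of the Mersenne prime; the per-column minimum gives one minhash per hash function, and np.split regroups the minima into num_buckets bucket signatures. A self-contained toy run of the 32-bit path (constants copied from the module, input values made up):

import numpy as np

_mersenne_prime = np.uint64((1 << 61) - 1)
_max_hash_32b = np.uint64((1 << 32) - 1)

rng = np.random.RandomState(1)
num_hashes = 14 * 8  # num_buckets * hashes_per_bucket, the defaults above
a = rng.randint(1, _mersenne_prime, dtype=np.uint64, size=(1, num_hashes))
b = rng.randint(0, _mersenne_prime, dtype=np.uint64, size=(1, num_hashes))

shingles = np.array([[12345], [67890], [424242]], dtype=np.uint64)  # made-up shingle hashes
phv = (shingles * a + b) % _mersenne_prime  # the multiply wraps mod 2**64, as in the module
phv = np.bitwise_and(phv, _max_hash_32b)    # mask applied on the 32-bit path only
signature = [x.tolist() for x in np.split(np.min(phv, axis=0).astype(np.uint32), 14)]
print(len(signature), len(signature[0]))    # 14 buckets of 8 hashes each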

def set_up_dl_locks(self, dl_lock, up_lock):
self.output_folder.set_lock(up_lock)
@@ -107,15 +122,15 @@ def get_shingles(self, text):
return np.array(
[
[sha1_hash32(" ".join(x).encode("utf-8"))]
for x in ngrams(word_tokenize(simplify_content(text)), self.n_grams)
for x in ngrams(word_tokenize(simplify_content(text)), self.config.n_grams)
],
dtype=np.uint64,
)
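get_shingles normalizes the text with simplify_content, tokenizes it, forms n-gram shingles, and maps each joined shingle to a 32-bit sha1 value. A standalone approximation, with hashlib standing in for datatrove's sha1_hash32 and a naive split for word_tokenize:

import hashlib
import struct
import numpy as np

def sha1_hash32(data: bytes) -> int:
    # First four digest bytes as an unsigned 32-bit int (the byte order is an
    # assumption about datatrove's helper).
    return struct.unpack("<I", hashlib.sha1(data).digest()[:4])[0]

def get_shingles(text: str, n_grams: int = 5) -> np.ndarray:
    words = text.lower().split()  # crude stand-in for simplify_content + word_tokenize
    return np.array(
        [
            [sha1_hash32(" ".join(words[i : i + n_grams]).encode("utf-8"))]
            for i in range(len(words) - n_grams + 1)
        ],
        dtype=np.uint64,
    )

print(get_shingles("the quick brown fox jumps over the lazy dog").shape)  # (5, 1)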

def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):
buckets = [
self.output_folder.open(f"bucket_{bi:03d}/{rank:05d}.minhash.sig", mode="wb")
for bi in range(self.num_buckets)
for bi in range(self.config.num_buckets)
]
for doc_idx, doc in enumerate(data):
self.stat_update(StatHints.total)
@@ -124,14 +139,20 @@ def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):
sig = self.get_signature(shingles)
for bi, (bucket, bucket_sig) in enumerate(zip(buckets, sig)):
# print(f"{self.config.hashes_per_bucket=} {bucket_sig=}")
bucket.write(struct.pack("<%sI" % self.hashes_per_bucket, *bucket_sig))
bucket.write(
struct.pack(f"<{self.config.hashes_per_bucket}{self.config.hash_format}", *bucket_sig)
)
bucket.write(struct.pack("<I", doc_idx))
logger.info("Sorting buckets...")
for bi, bucket in enumerate(buckets):
bucket.close()
fo = self.output_folder.open(bucket.relative_path, mode="r+b", overwrite=True)
mmap = np.memmap(fo.file_handler, dtype=[(str(i), np.uint32) for i in range(self.hashes_per_bucket + 1)])
mmap.sort(order=[str(i) for i in range(self.hashes_per_bucket + 1)])
mmap = np.memmap(
fo.file_handler,
dtype=[(str(i), self.config.hash_dtype) for i in range(self.config.hashes_per_bucket)]
+ [(str(self.config.hashes_per_bucket), np.uint32)],
) # doc_id at the end
mmap.sort(order=[str(i) for i in range(self.config.hashes_per_bucket + 1)])
mmap.flush()
fo.close()
self.output_folder.close()
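The sort step above views each bucket file as a structured numpy array, one field per hash plus a trailing uint32 doc_id field, so records are reordered on disk without a Python-level sort. A minimal sketch of the same trick, assuming bucket.sig holds records written by this step on a little-endian machine:

import numpy as np

hashes_per_bucket, hash_dtype = 8, np.uint32  # np.uint64 for 64-bit hashes
dtype = [(str(i), hash_dtype) for i in range(hashes_per_bucket)] + [
    (str(hashes_per_bucket), np.uint32)  # doc_id stored after the hashes
]

records = np.memmap("bucket.sig", dtype=dtype, mode="r+")
records.sort(order=[str(i) for i in range(hashes_per_bucket + 1)])  # lexicographic by fields
records.flush()  # push the reordered records back to disk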
@@ -146,17 +167,15 @@ def __init__(
input_folder: BaseInputDataFolder,
output_folder: BaseOutputDataFolder,
index_folder: BaseInputDataFolder = None,
hashes_per_bucket: int = DEFAULT_PER_BUCKET,
num_buckets: int = DEFAULT_NR_BUCKETS,
config: MinhashConfig = DEFAULT_MINHASH_CONFIG,
only_dedup_in_index: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.input_folder = input_folder
self.output_folder = output_folder
self.num_buckets = num_buckets
self.hashes_per_bucket = hashes_per_bucket
self.index_folder = index_folder
self.config = config
self.only_dedup_in_index = only_dedup_in_index

def set_up_dl_locks(self, dl_lock, up_lock):
@@ -165,15 +184,15 @@ def set_up_dl_locks(self, dl_lock, up_lock):

def __call__(self, data: DocumentsPipeline, bucket: int = 0, world_size: int = 1):
assert data is None, "You should not use an input block before MinhashDedupBuckets"
assert world_size == self.num_buckets, "You must run exactly one task per bucket"
assert world_size == self.config.num_buckets, "You must run exactly one task per bucket"
sig_files = self.input_folder.list_files(suffix=f"bucket_{bucket:03d}")
sig_readers = [read_sigs(file, file_i, self.hashes_per_bucket) for file_i, file in enumerate(sig_files)]
sig_readers = [read_sigs(file, file_i, self.config) for file_i, file in enumerate(sig_files)]
index_files = self.index_folder.list_files(suffix=f"bucket_{bucket:03d}") if self.index_folder else None
if index_files:
logger.info(f"Found index file(s): {', '.join([file.relative_path for file in index_files])}")
sig_readers.extend(
[
read_sigs(file, len(sig_readers) + file_i, self.hashes_per_bucket, index_file=True)
read_sigs(file, len(sig_readers) + file_i, self.config, index_file=True)
for file_i, file in enumerate(index_files)
]
)
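The merge loop of this class is collapsed in this view. Based on the reader setup above and the analogous loop in MinhashBuildIndex below, the shape is a k-way merge over pre-sorted signature files: equal signatures surface back to back, and each consecutive equal pair is a duplicate. A hedged reconstruction that ignores index files and the only_dedup_in_index switch (out_f stands for an open output file, and the pair format is an assumption):

import heapq
import struct

pq = [next(reader) for reader in sig_readers]
heapq.heapify(pq)

last = None
while pq:
    v = heapq.heappop(pq)
    if last is not None and last.sig == v.sig:
        # Same signature as the previous record: record the duplicate pair.
        out_f.write(struct.pack("<4I", last.file_id, last.doc_id, v.file_id, v.doc_id))
    last = v
    nxt = next(sig_readers[v.reader_id], None)  # refill the heap from the drained reader
    if nxt is not None:
        heapq.heappush(pq, nxt)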
@@ -209,14 +228,14 @@ def __init__(
self,
input_folder: BaseInputDataFolder,
output_folder: BaseOutputDataFolder,
num_buckets: int = DEFAULT_NR_BUCKETS,
config: MinhashConfig = DEFAULT_MINHASH_CONFIG,
save_cluster_id: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.input_folder = input_folder
self.output_folder = output_folder
self.num_buckets = num_buckets
self.config = config
self.save_cluster_id = save_cluster_id

def set_up_dl_locks(self, dl_lock, up_lock):
@@ -225,7 +244,7 @@ def set_up_dl_locks(self, dl_lock, up_lock):

def __call__(self, data: DocumentsPipeline, _: int = 0, world_size: int = 1):
dup_files = self.input_folder.list_files(extension=".dups")
assert len(dup_files) == self.num_buckets, "There should be exactly one .dups file per bucket"
assert len(dup_files) == self.config.num_buckets, "There should be exactly one .dups file per bucket"
assert world_size == 1, "World size must be 1 for clustering"
union_set = {}
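union_set is the parent map of a union-find structure: every duplicate pair read from the .dups files links two documents, and documents sharing a root end up in one cluster. A minimal self-contained sketch of the pattern with path compression:

union_set: dict[int, int] = {}

def parent(x: int) -> int:
    # Find the root of x, flattening the chain along the way (path compression).
    if union_set.setdefault(x, x) == x:
        return x
    union_set[x] = parent(union_set[x])
    return union_set[x]

def union(a: int, b: int) -> None:
    # Merge the two clusters by attaching one root under the other.
    root_a, root_b = parent(a), parent(b)
    if root_a != root_b:
        union_set[root_a] = root_b

union(1, 2)
union(2, 3)
assert parent(1) == parent(3)  # 1, 2 and 3 now share one cluster root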

@@ -334,15 +353,13 @@ def __init__(
input_folder: BaseInputDataFolder,
output_folder: BaseOutputDataFolder,
index_name: str,
hashes_per_bucket: int = DEFAULT_PER_BUCKET,
num_buckets: int = DEFAULT_NR_BUCKETS,
config: MinhashConfig = DEFAULT_MINHASH_CONFIG,
**kwargs,
):
super().__init__(**kwargs)
self.input_folder = input_folder
self.output_folder = output_folder
self.num_buckets = num_buckets
self.hashes_per_bucket = hashes_per_bucket
self.config = config
self.index_name = index_name

def set_up_dl_locks(self, dl_lock, up_lock):
@@ -351,9 +368,9 @@ def set_up_dl_locks(self, dl_lock, up_lock):

def __call__(self, data: DocumentsPipeline, bucket: int = 0, world_size: int = 1):
assert data is None, "You should not use an input block before MinhashDedupBuckets"
assert world_size == self.num_buckets, "You must run exactly one task per bucket"
assert world_size == self.config.num_buckets, "You must run exactly one task per bucket"
sig_files = self.input_folder.list_files(suffix=f"bucket_{bucket:03d}")
sig_readers = [read_sigs(file, file_i, self.hashes_per_bucket) for file_i, file in enumerate(sig_files)]
sig_readers = [read_sigs(file, file_i, self.config) for file_i, file in enumerate(sig_files)]

pq = [next(sig_reader) for sig_reader in sig_readers]
heapq.heapify(pq)
@@ -365,7 +382,7 @@ def __call__(self, data: DocumentsPipeline, bucket: int = 0, world_size: int = 1
while pq:
v: HashSig = heapq.heappop(pq)
if not last or last.sig != v.sig:
out_f.write(struct.pack(f"<{self.config.hashes_per_bucket}{self.config.hash_format}", *v.sig))
out_f.write(struct.pack("<%dI" % self.config.hashes_per_bucket, *v.sig))
last = v
next_sig = next(sig_readers[v.file_id], None)
if next_sig:
