add index (inter-dataset) dedup support
guipenedo committed Oct 18, 2023
1 parent 2355157 commit d03013b
Showing 1 changed file with 87 additions and 15 deletions.
102 changes: 87 additions & 15 deletions src/datatrove/pipeline/dedup/minhash.py
@@ -32,20 +32,40 @@
DEFAULT_PER_BUCKET = 8
DEFAULT_N_GRAMS = 5

SENTINEL = (1 << 32) - 1


@dataclass
class HashSig:
    sig: tuple[int]
    doc_id: int
    file_id: int
    reader_id: int

    def to_tuple(self):
        return self.sig, self.file_id, self.doc_id
        return self.sig, self.file_id, self.doc_id, self.reader_id

    def is_from_index(self):
        return self.reader_id != self.file_id

    def __lt__(self, other):
        return self.to_tuple() < other.to_tuple()


def read_sigs(file: InputDataFile, reader_id: int, hashes_per_bucket: int, index_file: bool = False) -> Generator:
    n = hashes_per_bucket + 1 if not index_file else hashes_per_bucket
    with file.open(binary=True) as f:
        while True:
            data = f.read(n * struct.calcsize("I"))
            if not data:
                return
            data = struct.unpack("<%sI" % n, data)
            if index_file:
                yield HashSig(sig=data, doc_id=-1, file_id=-1, reader_id=reader_id)
            else:
                yield HashSig(sig=data[:-1], doc_id=data[-1], file_id=reader_id, reader_id=reader_id)
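
For reference, a minimal standalone sketch (not part of the commit) of the record layout read_sigs consumes: signature files pack hashes_per_bucket little-endian uint32 hashes plus a uint32 doc_id per record, while index files omit the doc_id. All names below are local to the sketch.

# Not part of the commit: standalone sketch of the record layout read_sigs consumes.
# Signature files store hashes_per_bucket little-endian uint32 hashes followed by a
# uint32 doc_id; index files store only the hashes.
import io
import struct

hashes_per_bucket = 4
buf = io.BytesIO()
buf.write(struct.pack("<5I", 11, 22, 33, 44, 0))  # 4 hashes + doc_id 0
buf.write(struct.pack("<5I", 11, 22, 33, 44, 1))  # same signature, doc_id 1 -> a duplicate pair
buf.seek(0)

n = hashes_per_bucket + 1  # an index file would use n = hashes_per_bucket (no doc_id)
while True:
    data = buf.read(n * struct.calcsize("I"))
    if not data:
        break
    record = struct.unpack("<%sI" % n, data)
    print(record[:-1], record[-1])  # (sig, doc_id)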


class MinhashDedupSignature(PipelineStep):
    type = "🫂 - DEDUP"
    name = "🎯 MinHash stage 1"
@@ -130,6 +150,7 @@ def __init__(
        self,
        input_folder: BaseInputDataFolder,
        output_folder: BaseOutputDataFolder,
        index_folder: BaseInputDataFolder = None,
        hashes_per_bucket: int = DEFAULT_PER_BUCKET,
        num_buckets: int = DEFAULT_NR_BUCKETS,
        **kwargs,
@@ -139,16 +160,7 @@ def __init__(
        self.output_folder = output_folder
        self.num_buckets = num_buckets
        self.hashes_per_bucket = hashes_per_bucket

    def read_sigs(self, file: InputDataFile, file_id: int) -> Generator:
        n = self.hashes_per_bucket + 1
        with file.open(binary=True) as f:
            while True:
                data = f.read(n * struct.calcsize("I"))
                if not data:
                    return
                data = struct.unpack("<%sI" % n, data)
                yield HashSig(sig=data[:-1], doc_id=data[-1], file_id=file_id)
        self.index_folder = index_folder

    def set_up_dl_locks(self, dl_lock, up_lock):
        self.input_folder.set_lock(dl_lock)
@@ -158,7 +170,15 @@ def __call__(self, data: DocumentsPipeline, bucket: int = 0, world_size: int = 1
        assert data is None, "You should not use an input block before MinhashDedupBuckets"
        assert world_size == self.num_buckets, "You must run exactly one task per bucket"
        sig_files = self.input_folder.list_files(suffix=f"bucket_{bucket:03d}")
        sig_readers = [self.read_sigs(file, file_i) for file_i, file in enumerate(sig_files)]
        sig_readers = [read_sigs(file, file_i, self.hashes_per_bucket) for file_i, file in enumerate(sig_files)]
        index_files = self.index_folder.list_files(suffix=f"bucket_{bucket:03d}") if self.index_folder else None
        if index_files:
            sig_readers.extend(
                [
                    read_sigs(file, len(sig_readers) + file_i, self.hashes_per_bucket, index_file=True)
                    for file_i, file in enumerate(index_files)
                ]
            )

        pq = [next(sig_reader) for sig_reader in sig_readers]
        heapq.heapify(pq)
@@ -168,10 +188,14 @@ def __call__(self, data: DocumentsPipeline, bucket: int = 0, world_size: int = 1
        last: HashSig | None = None
        while pq:
            v: HashSig = heapq.heappop(pq)
            if last and last.sig == v.sig:
                out_f.write(struct.pack("<4I", last.file_id, last.doc_id, v.file_id, v.doc_id))
            if last and last.sig == v.sig and not v.is_from_index():
                # write (file_id1, doc_id1, file_id2, doc_id2), where file_id1 <= file_id2
                if last.is_from_index():
                    out_f.write(struct.pack("<4I", v.file_id, v.doc_id, SENTINEL, SENTINEL))
                else:
                    out_f.write(struct.pack("<4I", last.file_id, last.doc_id, v.file_id, v.doc_id))
            last = v
            next_sig = next(sig_readers[v.file_id], None)
            next_sig = next(sig_readers[v.reader_id], None)
            if next_sig:
                heapq.heappush(pq, next_sig)
        self.output_folder.close()
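
Not part of the commit: a hypothetical reader for a bucket's dupes output, based only on the "<4I" records written above. A pair whose second half equals (SENTINEL, SENTINEL) means the document matched a signature from the external index rather than another document in this dataset. The path argument is a placeholder, not a datatrove API.

# Not part of the commit: hypothetical decoder for a bucket's dupes output.
import struct

SENTINEL = (1 << 32) - 1
RECORD = struct.calcsize("<4I")

def read_dupes(path):  # `path` is a placeholder
    with open(path, "rb") as f:
        while chunk := f.read(RECORD):
            f1, d1, f2, d2 = struct.unpack("<4I", chunk)
            if (f2, d2) == (SENTINEL, SENTINEL):
                yield ("index_dup", f1, d1)     # duplicate of an indexed document
            else:
                yield ("pair", f1, d1, f2, d2)  # duplicate pair within the dataset
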
@@ -297,3 +321,51 @@ def load_clusters():
                continue
            self.stat_update(StatHints.forwarded)
            yield doc


class MinhashBuildIndex(PipelineStep):
    type = "🫂 - DEDUP"
    name = "🎯 MinHash build index"

    def __init__(
        self,
        input_folder: BaseInputDataFolder,
        output_folder: BaseOutputDataFolder,
        index_name: str,
        hashes_per_bucket: int = DEFAULT_PER_BUCKET,
        num_buckets: int = DEFAULT_NR_BUCKETS,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_folder = input_folder
        self.output_folder = output_folder
        self.num_buckets = num_buckets
        self.hashes_per_bucket = hashes_per_bucket
        self.index_name = index_name

    def set_up_dl_locks(self, dl_lock, up_lock):
        self.input_folder.set_lock(dl_lock)
        self.output_folder.set_lock(up_lock)

    def __call__(self, data: DocumentsPipeline, bucket: int = 0, world_size: int = 1):
        assert data is None, "You should not use an input block before MinhashDedupBuckets"
        assert world_size == self.num_buckets, "You must run exactly one task per bucket"
        sig_files = self.input_folder.list_files(suffix=f"bucket_{bucket:03d}")
        sig_readers = [read_sigs(file, file_i, self.hashes_per_bucket) for file_i, file in enumerate(sig_files)]

        pq = [next(sig_reader) for sig_reader in sig_readers]
        heapq.heapify(pq)

        # writes all the sigs for the entire bucket, sequentially
        out_f = self.output_folder.open(f"bucket_{bucket:03d}/{self.index_name}.minhash.index", mode="wb")

        last: HashSig | None = None
        while pq:
            v: HashSig = heapq.heappop(pq)
            if not last or last.sig != v.sig:
                out_f.write(struct.pack("<%dI" % self.hashes_per_bucket, *v.sig))
            last = v
            next_sig = next(sig_readers[v.file_id], None)
            if next_sig:
                heapq.heappush(pq, next_sig)
        self.output_folder.close()
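
Not part of the commit: a self-contained toy run of the same merge logic, using plain tuples instead of datatrove's folder and reader machinery, to show how an index entry (doc_id -1 here, as in read_sigs) turns a matching document of a new dataset into an "already in the index" duplicate. All names are local to the sketch.

# Toy illustration only (no datatrove imports): build a per-bucket "index" of signatures
# from a reference dataset, then flag documents of a new dataset whose signature matches.
import heapq

# (sig tuple, doc_id) pairs for one bucket
reference_sigs = [((1, 2, 3, 4), 0), ((5, 6, 7, 8), 1), ((1, 2, 3, 4), 2)]
new_dataset_sigs = [((1, 2, 3, 4), 0), ((9, 9, 9, 9), 1)]

# MinhashBuildIndex-like pass: keep each distinct signature once, in sorted order
index = []
last = None
for sig, _doc in sorted(reference_sigs):
    if sig != last:
        index.append(sig)
    last = sig

# MinhashDedupBuckets-like pass with an index: merge both sorted streams; among equal
# signatures the index entry (doc_id -1) sorts first, so the dataset entry that follows
# it is a duplicate of something already in the index
heap = [(sig, -1) for sig in index] + sorted(new_dataset_sigs)
heapq.heapify(heap)

index_dups, pair_dups = [], []
last = None  # (sig, doc_id) of the previously popped entry
while heap:
    sig, doc = heapq.heappop(heap)
    if last is not None and last[0] == sig and doc != -1:
        if last[1] == -1:
            index_dups.append(doc)            # duplicate of an index entry (the SENTINEL case)
        else:
            pair_dups.append((last[1], doc))  # duplicate pair within the new dataset
    last = (sig, doc)

print(index_dups)  # [0]
print(pair_dups)   # []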
