Commit

added check
guipenedo committed Nov 26, 2024
1 parent 5600ecb commit 3cb0763
Showing 2 changed files with 8 additions and 80 deletions.
6 changes: 5 additions & 1 deletion src/datatrove/pipeline/dedup/fast_mh3/Cargo.toml
@@ -24,4 +24,8 @@ indicatif = "0.17.7"
tokio = { version = "1.33.0", features = ["full"] }

# Retries
tokio-retry = "0.3"
tokio-retry = "0.3"

[[bin]]
name = "check"
path = "src/check.rs"
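
For context (not part of the diff): a [[bin]] table in Cargo.toml declares an additional binary target, so Cargo will now compile src/check.rs into a standalone executable named check, which can be built with cargo build --release --bin check and launched with cargo run --release --bin check. The contents of src/check.rs itself are not shown in this commit.
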
82 changes: 3 additions & 79 deletions src/datatrove/pipeline/dedup/minhash.py
@@ -1,4 +1,3 @@
import array
import contextlib
import heapq
import os
@@ -448,79 +447,6 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1
out_index.close()


class MemEfficientDict:
# keys are (file, doc) pairs. We assume #file << #docs, to really save on the memory overhead of ints
# this also assumes there are quite a lot of duplicates, otherwise the array will be too sparse

def __init__(self, dtype):
self.dtype = dtype
self._data = {}
self._SENTINEL = (1 << (struct.calcsize(dtype) * 8)) - 1

def __setitem__(self, key, value):
assert value < self._SENTINEL, f"value {value} for key {key} is too large"
k_f, k_d = key
if k_f not in self._data:
self._data[k_f] = array.array(self.dtype)
if len(self._data[k_f]) <= k_d: # increase array size as needed
self._data[k_f].extend((self._SENTINEL for _ in range(k_d - len(self._data[k_f]) + 1)))
self._data[k_f][k_d] = value

def __contains__(self, item):
return self[item] is not None

def __getitem__(self, item):
k_f, k_d = item
if k_f not in self._data:
return None
if len(self._data[k_f]) <= k_d or self._data[k_f][k_d] == self._SENTINEL:
return None
return self._data[k_f][k_d]

def __iter__(self):
for f in self._data:
for d in range(len(self._data[f])):
if self._data[f][d] != self._SENTINEL:
yield f, d

def pop(self, key, default):
pass

def items(self):
for f in self._data:
for d in range(len(self._data[f])):
val = self._data[f][d]
if val != self._SENTINEL:
yield (f, d), val

def get(self, key, default):
val = self[key]
if val is None:
return default
return val
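
Illustrative usage (not part of the commit) of the MemEfficientDict removed above, assuming the class definition and the array/struct imports of the pre-change file: it acts like a dict keyed by (file_id, doc_id) tuples, but stores each file's values in a flat array.array whose unset slots hold an all-ones sentinel.

d = MemEfficientDict("I")  # one unsigned-int array per file; sentinel = (1 << (8 * struct.calcsize("I"))) - 1
d[(0, 5)] = 42             # file 0, doc 5 -> 42; the file-0 array grows to 6 slots
(0, 5) in d                # True
d.get((0, 7), -1)          # -1: slot 7 was never set, so it still holds the sentinel
list(d.items())            # [((0, 5), 42)] -- sentinel slots are skipped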


class UnionSet:
def __init__(self):
self._fs = MemEfficientDict("H")
self._ds = MemEfficientDict("I")

def __getitem__(self, item):
return self._fs[item], self._ds[item]

def __setitem__(self, key, value):
self._fs[key], self._ds[key] = value

def __iter__(self):
return self._fs.__iter__()

def __contains__(self, item):
return item in self._fs

def items(self):
return zip(self._fs.items(), self._ds.items())
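
Illustrative usage (not part of the commit) of the removed UnionSet: it maps (file, doc) keys to (file, doc) parent pointers, splitting each value across two MemEfficientDicts, an unsigned-short ("H") one for file ids and an unsigned-int ("I") one for doc ids.

parents = UnionSet()
parents[(0, 3)] = (1, 7)  # parent of doc 3 in file 0 is doc 7 in file 1
parents[(0, 3)]           # (1, 7)
(0, 3) in parents         # True

As written, items() zips the two inner dicts' (key, value) pairs rather than yielding key -> (file, doc) values, and pop() is an empty stub; the whole class is dropped by this commit together with the sparse_array_mode flag below.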


class MinhashDedupCluster(PipelineStep):
"""Minhash Deduplication: Third Pipeline Step
@@ -539,7 +465,6 @@ def __init__(
save_cluster_size: bool = False,
ignore_index_matches: bool = False,
lines_to_buffer: int = 5,
sparse_array_mode: bool = False,
):
super().__init__()
self.input_folder = get_datafolder(input_folder)
@@ -549,16 +474,15 @@ def __init__(
self.save_cluster_size = save_cluster_size
self.ignore_index_matches = ignore_index_matches
self.lines_to_buffer = lines_to_buffer
self.sparse_array_mode = sparse_array_mode

def run(self, data: DocumentsPipeline = None, _: int = 0, world_size: int = 1):
dup_files = self.input_folder.list_files(glob_pattern="*.dups")
assert (
len(dup_files) % self.config.num_buckets
) == 0, "Number of .dups files should be divisible by number of buckets"
assert world_size == 1, "World size must be 1 for clustering"
union_set = UnionSet() if self.sparse_array_mode else {}
set_size = MemEfficientDict("I") if self.sparse_array_mode else {}
union_set = {}
set_size = {}

def parent(x):
if x not in union_set or union_set[x] == x:
@@ -577,7 +501,7 @@ def union(v_a, v_b):
size_a = set_size.get(root_a, 1)
size_b = set_size.get(root_b, 1)
if size_a < size_b:
root_a, root_b, size_a, size_b = root_b, root_a, size_b, size_a
root_a, root_b = root_b, root_a
# #a >= #b
union_set[root_b] = root_a # make the smallest one join the biggest one to keep sets shallow
set_size[root_a] = size_a + size_b
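
After this change the clustering step always uses plain Python dicts. A self-contained sketch (not the commit's exact code, whose parent() body is collapsed above) of the dict-based union-find it relies on: keys are (file_id, doc_id) tuples, smaller trees are attached under larger roots, and lookups compress paths so later parent() calls stay cheap.

union_set, set_size = {}, {}

def parent(x):
    # walk to the root, then point every visited node directly at it
    path = []
    while x in union_set and union_set[x] != x:
        path.append(x)
        x = union_set[x]
    for p in path:
        union_set[p] = x
    return x

def union(v_a, v_b):
    root_a, root_b = parent(v_a), parent(v_b)
    if root_a == root_b:
        return
    size_a, size_b = set_size.get(root_a, 1), set_size.get(root_b, 1)
    if size_a < size_b:
        root_a, root_b = root_b, root_a
    union_set[root_b] = root_a  # the smaller tree joins the bigger one, keeping sets shallow
    set_size[root_a] = size_a + size_b

union((0, 1), (0, 2))
union((0, 2), (1, 5))
assert parent((0, 1)) == parent((1, 5))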
