From 3cb07639782e4163d6254b5323555279f79a4ea3 Mon Sep 17 00:00:00 2001
From: guipenedo
Date: Tue, 26 Nov 2024 17:04:30 +0100
Subject: [PATCH] added check

---
 .../pipeline/dedup/fast_mh3/Cargo.toml   |  6 +-
 src/datatrove/pipeline/dedup/minhash.py  | 82 +------------------
 2 files changed, 8 insertions(+), 80 deletions(-)

diff --git a/src/datatrove/pipeline/dedup/fast_mh3/Cargo.toml b/src/datatrove/pipeline/dedup/fast_mh3/Cargo.toml
index 3b9e5759..cdc2bf51 100644
--- a/src/datatrove/pipeline/dedup/fast_mh3/Cargo.toml
+++ b/src/datatrove/pipeline/dedup/fast_mh3/Cargo.toml
@@ -24,4 +24,8 @@ indicatif = "0.17.7"
 tokio = { version = "1.33.0", features = ["full"] }
 
 # Retries
-tokio-retry = "0.3"
\ No newline at end of file
+tokio-retry = "0.3"
+
+[[bin]]
+name = "check"
+path = "src/check.rs"
\ No newline at end of file
diff --git a/src/datatrove/pipeline/dedup/minhash.py b/src/datatrove/pipeline/dedup/minhash.py
index 86d60a88..be294c82 100644
--- a/src/datatrove/pipeline/dedup/minhash.py
+++ b/src/datatrove/pipeline/dedup/minhash.py
@@ -1,4 +1,3 @@
-import array
 import contextlib
 import heapq
 import os
@@ -448,79 +447,6 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1
         out_index.close()
 
 
-class MemEfficientDict:
-    # keys are (file, doc) pairs. We assume #file << #docs, to really save on the memory overhead of ints
-    # this also assumes there are quite a lot of duplicates, otherwise the array will be too sparse
-
-    def __init__(self, dtype):
-        self.dtype = dtype
-        self._data = {}
-        self._SENTINEL = (1 << (struct.calcsize(dtype) * 8)) - 1
-
-    def __setitem__(self, key, value):
-        assert value < self._SENTINEL, f"value {value} for key {key} is too large"
-        k_f, k_d = key
-        if k_f not in self._data:
-            self._data[k_f] = array.array(self.dtype)
-        if len(self._data[k_f]) <= k_d:  # increase array size as needed
-            self._data[k_f].extend((self._SENTINEL for _ in range(k_d - len(self._data[k_f]) + 1)))
-        self._data[k_f][k_d] = value
-
-    def __contains__(self, item):
-        return self[item] is not None
-
-    def __getitem__(self, item):
-        k_f, k_d = item
-        if k_f not in self._data:
-            return None
-        if len(self._data[k_f]) <= k_d or self._data[k_f][k_d] == self._SENTINEL:
-            return None
-        return self._data[k_f][k_d]
-
-    def __iter__(self):
-        for f in self._data:
-            for d in range(len(self._data[f])):
-                if self._data[f][d] != self._SENTINEL:
-                    yield f, d
-
-    def pop(self, key, default):
-        pass
-
-    def items(self):
-        for f in self._data:
-            for d in range(len(self._data[f])):
-                val = self._data[f][d]
-                if val != self._SENTINEL:
-                    yield (f, d), val
-
-    def get(self, key, default):
-        val = self[key]
-        if val is None:
-            return default
-        return val
-
-
-class UnionSet:
-    def __init__(self):
-        self._fs = MemEfficientDict("H")
-        self._ds = MemEfficientDict("I")
-
-    def __getitem__(self, item):
-        return self._fs[item], self._ds[item]
-
-    def __setitem__(self, key, value):
-        self._fs[key], self._ds[key] = value
-
-    def __iter__(self):
-        return self._fs.__iter__()
-
-    def __contains__(self, item):
-        return item in self._fs
-
-    def items(self):
-        return zip(self._fs.items(), self._ds.items())
-
-
 class MinhashDedupCluster(PipelineStep):
     """Minhash Deduplication: Third Pipeline Step
 
@@ -539,7 +465,6 @@ def __init__(
         save_cluster_size: bool = False,
         ignore_index_matches: bool = False,
         lines_to_buffer: int = 5,
-        sparse_array_mode: bool = False,
     ):
         super().__init__()
         self.input_folder = get_datafolder(input_folder)
@@ -549,7 +474,6 @@ def __init__(
         self.save_cluster_size = save_cluster_size
         self.ignore_index_matches = ignore_index_matches
         self.lines_to_buffer = lines_to_buffer
-        self.sparse_array_mode = sparse_array_mode
 
     def run(self, data: DocumentsPipeline = None, _: int = 0, world_size: int = 1):
         dup_files = self.input_folder.list_files(glob_pattern="*.dups")
@@ -557,8 +481,8 @@ def run(self, data: DocumentsPipeline = None, _: int = 0, world_size: int = 1):
             len(dup_files) % self.config.num_buckets
         ) == 0, "Number of .dups files should be divisible by number of buckets"
         assert world_size == 1, "World size must be 1 for clustering"
-        union_set = UnionSet() if self.sparse_array_mode else {}
-        set_size = MemEfficientDict("I") if self.sparse_array_mode else {}
+        union_set = {}
+        set_size = {}
 
         def parent(x):
             if x not in union_set or union_set[x] == x:
@@ -577,7 +501,7 @@ def union(v_a, v_b):
             size_a = set_size.get(root_a, 1)
             size_b = set_size.get(root_b, 1)
             if size_a < size_b:
-                root_a, root_b, size_a, size_b = root_b, root_a, size_b, size_a
+                root_a, root_b = root_b, root_a  # #a >= #b
             union_set[root_b] = root_a  # make the smallest one join the biggest one to keep sets shallow
             set_size[root_a] = size_a + size_b
 
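
Note, not part of the patch above: with `sparse_array_mode` removed, clustering always uses the plain-dict union-find that the minhash.py hunks leave in place. Below is a minimal standalone sketch of that logic for reference. It reuses the `parent`/`union` names and the (file_id, doc_id) keys from the patch, but only the first two lines of `parent` and the tail of `union` are visible in the context lines, so the rest of both bodies is an assumption, not the exact datatrove code.

# Standalone sketch of the union-find kept in MinhashDedupCluster.
# Plain dicts replace the removed MemEfficientDict/UnionSet helpers.
union_set = {}  # node -> parent node; roots are absent or map to themselves
set_size = {}   # root -> number of nodes in its set

def parent(x):
    # Find the root of x's set. The recursive path compression below is
    # assumed; only the first two lines appear in the patch context.
    if x not in union_set or union_set[x] == x:
        return x
    root = parent(union_set[x])
    union_set[x] = root  # point x straight at the root for later lookups
    return root

def union(v_a, v_b):
    # Union by size: attach the smaller root under the larger one to keep
    # the trees shallow, the invariant the "#a >= #b" comment asserts.
    root_a, root_b = parent(v_a), parent(v_b)
    if root_a == root_b:
        return  # assumed early exit; not visible in the hunk
    size_a = set_size.get(root_a, 1)
    size_b = set_size.get(root_b, 1)
    if size_a < size_b:
        root_a, root_b = root_b, root_a  # no need to swap sizes: only the sum is stored
    union_set[root_b] = root_a
    set_size[root_a] = size_a + size_b

# Usage: three duplicate pairs collapse into a single cluster of size 4.
union((0, 1), (0, 5))
union((0, 5), (1, 2))
union((1, 2), (1, 3))
assert parent((0, 1)) == parent((1, 3))
assert set_size[parent((0, 1))] == 4

This also shows why the `+` line in the last hunk is safe: after the root swap, `size_a + size_b` is unchanged by commutativity, so the old four-way swap of roots and sizes was redundant.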