diff --git a/src/datatrove/pipeline/dedup/minhash.py b/src/datatrove/pipeline/dedup/minhash.py
index 257d0a83..4b641922 100644
--- a/src/datatrove/pipeline/dedup/minhash.py
+++ b/src/datatrove/pipeline/dedup/minhash.py
@@ -18,8 +18,7 @@
 # http://en.wikipedia.org/wiki/Mersenne_prime
 _mersenne_prime = np.uint64((1 << 61) - 1)
-_max_hash = np.uint64((1 << 32) - 1)
-_hash_range = 1 << 32
+_max_hash_32b = np.uint64((1 << 32) - 1)
 
 """
 n_grams -> roughly nr of words (this should be small enough to catch fuzzy matches but big enough to not have each
 shingle be too common)
@@ -28,13 +27,31 @@
 probability of inclusion for s=0.8: 1-(1-0.8^8)^14=0.924
 """
 
-DEFAULT_NR_BUCKETS = 14
-DEFAULT_PER_BUCKET = 8
-DEFAULT_N_GRAMS = 5
-
 SENTINEL = (1 << 32) - 1
 
 
+@dataclass
+class MinhashConfig:
+    n_grams: int = 5
+
+    num_buckets: int = 14
+    hashes_per_bucket: int = 8
+
+    use_64bit_hashes: bool = False
+    seed: int = 1
+
+    @property
+    def hash_dtype(self):
+        return np.uint64 if self.use_64bit_hashes else np.uint32
+
+    @property
+    def hash_format(self):
+        return "Q" if self.use_64bit_hashes else "I"
+
+
+DEFAULT_MINHASH_CONFIG = MinhashConfig()
+
+
 @dataclass
 class HashSig:
     sig: tuple[int]
@@ -52,13 +69,13 @@
     def __lt__(self, other):
         return self.to_tuple() < other.to_tuple()
 
-def read_sigs(file: InputDataFile, reader_id: int, hashes_per_bucket: int, index_file: bool = False) -> Generator:
-    n = hashes_per_bucket + 1 if not index_file else hashes_per_bucket
-    for data in read_tuples_from_file(file, f"{n}I"):
-        if index_file:
+def read_sigs(file: InputDataFile, reader_id: int, config: MinhashConfig, index_file: bool = False) -> Generator:
+    if index_file:
+        for data in read_tuples_from_file(file, f"{config.hashes_per_bucket}{config.hash_format}"):
             yield HashSig(sig=data, doc_id=-1, file_id=-1, reader_id=reader_id)
-        else:
-            yield HashSig(sig=data[:-1], doc_id=data[-1], file_id=reader_id, reader_id=reader_id)
+    else:
+        for *data, doc_id in read_tuples_from_file(file, f"{config.hashes_per_bucket}{config.hash_format}", "I"):
+            yield HashSig(sig=data, doc_id=doc_id, file_id=reader_id, reader_id=reader_id)
 
 
 class MinhashDedupSignature(PipelineStep):
@@ -68,19 +85,13 @@
     def __init__(
         self,
         output_folder: BaseOutputDataFolder,
-        num_buckets: int = DEFAULT_NR_BUCKETS,
-        hashes_per_bucket: int = DEFAULT_PER_BUCKET,
-        n_grams: int = DEFAULT_N_GRAMS,
-        seed: int = 1,
+        config: MinhashConfig = DEFAULT_MINHASH_CONFIG,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.output_folder = output_folder
-        self.n_grams = n_grams
-        self.num_buckets = num_buckets
-        self.hashes_per_bucket = hashes_per_bucket
-        self.num_hashes = self.num_buckets * self.hashes_per_bucket
-        self.seed = seed
+        self.config = config
+        self.num_hashes = self.config.num_buckets * self.config.hashes_per_bucket
         self._parameters = None
 
     @property
@@ -89,7 +100,7 @@ def parameters(self):
             # Create parameters for a random bijective permutation function
             # that maps a 32-bit hash value to another 32-bit hash value.
             # http://en.wikipedia.org/wiki/Universal_hashing
-            gen = np.random.RandomState(self.seed)
+            gen = np.random.RandomState(self.config.seed)
             self._parameters = gen.randint(
                 1, _mersenne_prime, dtype=np.uint64, size=(1, self.num_hashes)
             ), gen.randint(0, _mersenne_prime, dtype=np.uint64, size=(1, self.num_hashes))
@@ -97,8 +108,12 @@
 
     def get_signature(self, shingles):
         a, b = self.parameters
-        phv = np.bitwise_and((shingles * a + b) % _mersenne_prime, _max_hash)
-        return [x.tolist() for x in np.split(np.min(phv, axis=0).astype(np.uint32), self.num_buckets)]
+        phv = (shingles * a + b) % _mersenne_prime
+        if not self.config.use_64bit_hashes:
+            phv = np.bitwise_and(phv, _max_hash_32b)
+        return [
+            x.tolist() for x in np.split(np.min(phv, axis=0).astype(self.config.hash_dtype), self.config.num_buckets)
+        ]
 
     def set_up_dl_locks(self, dl_lock, up_lock):
         self.output_folder.set_lock(up_lock)
@@ -107,7 +122,7 @@ def get_shingles(self, text):
         return np.array(
             [
                 [sha1_hash32(" ".join(x).encode("utf-8"))]
-                for x in ngrams(word_tokenize(simplify_content(text)), self.n_grams)
+                for x in ngrams(word_tokenize(simplify_content(text)), self.config.n_grams)
             ],
             dtype=np.uint64,
         )
@@ -115,7 +130,7 @@
     def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):
         buckets = [
             self.output_folder.open(f"bucket_{bi:03d}/{rank:05d}.minhash.sig", mode="wb")
-            for bi in range(self.num_buckets)
+            for bi in range(self.config.num_buckets)
         ]
         for doc_idx, doc in enumerate(data):
             self.stat_update(StatHints.total)
@@ -124,14 +139,20 @@
                 sig = self.get_signature(shingles)
                 for bi, (bucket, bucket_sig) in enumerate(zip(buckets, sig)):
                     # print(f"{self.hashes_per_bucket=} {bucket_sig=}")
-                    bucket.write(struct.pack("<%sI" % self.hashes_per_bucket, *bucket_sig))
+                    bucket.write(
+                        struct.pack(f"<{self.config.hashes_per_bucket}{self.config.hash_format}", *bucket_sig)
+                    )
+                    bucket.write(struct.pack("<I", doc_idx))
-                    assert prev is None or data >= prev
-                    prev = data
-                    assert 0 <= doc_id < 100
-                    doc_ids.add(doc_id)
-                assert len(doc_ids) == 100
+        for use_64bit_hashes in (True, False):
+            config = MinhashConfig(use_64bit_hashes=use_64bit_hashes)
+            minhash = MinhashDedupSignature(
+                output_folder=LocalOutputDataFolder(os.path.join(self.test_dir, "signatures1")), config=config
+            )
+            shingles = minhash.get_shingles(lorem_ipsum)
+            sig = minhash.get_signature(shingles)
+
+            minhash2 = MinhashDedupSignature(
+                output_folder=LocalOutputDataFolder(os.path.join(self.test_dir, "signatures2")), config=config
+            )
+            # check consistency
+            assert sig == minhash2.get_signature(shingles)
+
+            # check correct number of outputs
+            assert len(sig) == minhash.config.num_buckets
+            assert all((len(x) == minhash.config.hashes_per_bucket for x in sig))
+
+            # check similarity approximation
+            for pctd in range(0, 100, 5):
+                dec = pctd / 100
+                endp = floor(len(lorem_ipsum) * dec)
+                textd = lorem_ipsum[:endp] + lorem_ipsum[len(lorem_ipsum) - 1 : endp : -1]
+                sigd = minhash.get_signature(minhash.get_shingles(textd))
+                simil = (
+                    sum([1 if a == b else 0 for ba, bb in zip(sig, sigd) for a, b in zip(ba, bb)]) / minhash.num_hashes
+                )
+                assert dec - 0.21 < simil < dec + 0.21
+
+            # check output file format and order
+            samples = [Document(f"sample {i}, {lorem_ipsum[i::10]}", data_id="test") for i in range(100)]
+            minhash(samples)
+            for bi in range(config.num_buckets):
+                with open(
+                    os.path.join(minhash.output_folder.path, f"bucket_{bi:03d}", "00000.minhash.sig"), "rb"
+                ) as f:
+                    prev = None
+                    doc_ids = set()
+                    S = np.dtype(config.hash_dtype).itemsize
+                    for di in range(100):
+                        data = struct.unpack(
+                            f"<%s{config.hash_format}" % config.hashes_per_bucket, f.read(config.hashes_per_bucket * S)
+                        )
+                        doc_id = struct.unpack("<I", f.read(struct.calcsize("I")))[0]
+                        assert prev is None or data >= prev
+                        prev = data
+                        assert 0 <= doc_id < 100
+                        doc_ids.add(doc_id)
+                    assert len(doc_ids) == 100
 
     def test_buckets_and_cluster(self):
-        sigs_folder = os.path.join(self.test_dir, "b_signatures")
-        buckets_folder = os.path.join(self.test_dir, "b_buckets")
-        clusters_folder = os.path.join(self.test_dir, "b_clusters")
-
-        signatures_block = MinhashDedupSignature(output_folder=LocalOutputDataFolder(sigs_folder))
-        buckets_block = MinhashDedupBuckets(
-            input_folder=LocalInputDataFolder(sigs_folder),
-            output_folder=LocalOutputDataFolder(buckets_folder),
-        )
-
-        clusters = [[0, 20, 50], [400, 420], [800, 810, 820, 840, 860], [1200, 1215, 1225, 1245], [1600], [2000]]
-
-        cluster_samples = [
-            Document(content=lorem_ipsum[x : x + 300], data_id=f"{ci}_{xi}", metadata={"ci": ci, "xi": xi})
-            for ci, cluster in enumerate(clusters)
-            for xi, x in enumerate(cluster)
-        ]
-
-        signatures_block(cluster_samples)
-        # test file read
-        for fi, file in enumerate(buckets_block.input_folder.list_files()):
-            last = None
-            for sig in read_sigs(file, fi, buckets_block.hashes_per_bucket):
-                assert 0 <= sig.doc_id < 100
-                assert last is None or sig.sig >= last
-                assert len(sig.sig) == buckets_block.hashes_per_bucket
-                last = sig.sig
-
-        # test duplicate pairs
-        for b in range(buckets_block.num_buckets):
-            buckets_block(None, bucket=b, world_size=buckets_block.num_buckets)
-        bucket_results_folder = LocalInputDataFolder(buckets_folder)
-        dup_files = bucket_results_folder.list_files(extension=".dups")
-        pairs = defaultdict(set)
-        for dup_file in dup_files:
-            with dup_file.open(binary=True) as df:
-                while data := df.read(4 * struct.calcsize("I")):
-                    f1, d1, f2, d2 = struct.unpack("<4I", data)
-                    assert f1 == f2 == 0
-                    assert cluster_samples[d1].metadata["ci"] == cluster_samples[d2].metadata["ci"]
-                    pairs[d1].add(d2)
-                    pairs[d2].add(d1)
-        doc_id = 0
-        for cluster in clusters:
-            for a in range(doc_id, doc_id + len(cluster)):
-                assert len(cluster) < 2 or any(a in pairs[b] for b in range(doc_id, doc_id + len(cluster)) if a != b)
-            doc_id += len(cluster)
-
-        # clustering
-        cluster_block = MinhashDedupCluster(bucket_results_folder, LocalOutputDataFolder(clusters_folder))
-        cluster_block(None)
-
-        cluster_results_folder = LocalInputDataFolder(clusters_folder)
-        remove_ids = set()
-        with cluster_results_folder.list_files()[0].open_binary() as df:
-            while data := df.read(struct.calcsize("I")):
-                remove_ids.add(struct.unpack("<I", data)[0])
+                    assert last is None or sig.sig >= last
+                    assert len(sig.sig) == config.hashes_per_bucket
+                    last = sig.sig
+
+            # test duplicate pairs
+            for b in range(config.num_buckets):
+                buckets_block(None, bucket=b, world_size=config.num_buckets)
+            bucket_results_folder = LocalInputDataFolder(buckets_folder)
+            dup_files = bucket_results_folder.list_files(extension=".dups")
+            pairs = defaultdict(set)
+            for dup_file in dup_files:
+                with dup_file.open(binary=True) as df:
+                    while data := df.read(4 * struct.calcsize("I")):
+                        f1, d1, f2, d2 = struct.unpack("<4I", data)
+                        assert f1 == f2 == 0
+                        assert cluster_samples[d1].metadata["ci"] == cluster_samples[d2].metadata["ci"]
+                        pairs[d1].add(d2)
+                        pairs[d2].add(d1)
+            doc_id = 0
+            for cluster in clusters:
+                for a in range(doc_id, doc_id + len(cluster)):
+                    assert len(cluster) < 2 or any(
+                        a in pairs[b] for b in range(doc_id, doc_id + len(cluster)) if a != b
+                    )
+                doc_id += len(cluster)
+
+            # clustering
+            cluster_block = MinhashDedupCluster(
+                bucket_results_folder, LocalOutputDataFolder(clusters_folder), config=config
+            )
+            cluster_block(None)
+
+            cluster_results_folder = LocalInputDataFolder(clusters_folder)
+            remove_ids = set()
+            with cluster_results_folder.list_files()[0].open_binary() as df:
+                while data := df.read(struct.calcsize("I")):
+                    remove_ids.add(struct.unpack("<I", data)[0])
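Below is a minimal usage sketch of how the new MinhashConfig threads through the dedup blocks exercised in the tests above. It relies only on the constructor calls visible in this diff; the import locations, the local folder paths, and the config= keyword on MinhashDedupBuckets are assumptions for illustration, not part of the change itself.

# Illustrative sketch only: wiring MinhashConfig through the minhash dedup blocks
# shown in this diff. Folder paths are placeholders; passing config= to
# MinhashDedupBuckets is assumed by analogy with the other two blocks.
from datatrove.io import LocalInputDataFolder, LocalOutputDataFolder
from datatrove.pipeline.dedup.minhash import (
    MinhashConfig,
    MinhashDedupBuckets,
    MinhashDedupCluster,
    MinhashDedupSignature,
)

# Keep the default 14 buckets x 8 hashes per bucket, but switch to 64-bit hashes,
# which reduces spurious 32-bit hash collisions at the cost of doubling the size
# of the signature files.
config = MinhashConfig(use_64bit_hashes=True)

signatures = MinhashDedupSignature(
    output_folder=LocalOutputDataFolder("minhash/signatures"), config=config
)
buckets = MinhashDedupBuckets(
    input_folder=LocalInputDataFolder("minhash/signatures"),
    output_folder=LocalOutputDataFolder("minhash/buckets"),
    config=config,
)
clusters = MinhashDedupCluster(
    LocalInputDataFolder("minhash/buckets"),
    LocalOutputDataFolder("minhash/clusters"),
    config=config,
)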