Skip to content

Commit

Permalink
Bloom filter (#31)
Browse files Browse the repository at this point in the history
* Start bloomfilter

* add polished bloom filter with test and example

* Rename examples
  • Loading branch information
alexchapeaux authored Oct 5, 2023
1 parent 4891228 commit a0c85dd
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 1 deletion.
File renamed without changes.
File renamed without changes.
1 change: 1 addition & 0 deletions src/datatrove/pipeline/dedup/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .bloom_filter import SingleBloomFilter
from .exact_substrings import DatasetToSequence, DedupReader, MergeSequences
from .minhash import MinhashDedupBuckets, MinhashDedupCluster, MinhashDedupFilter, MinhashDedupSignature
from .sentence_dedup import SentenceDedupFilter, SentenceDedupSignature, SentenceFindDedups
170 changes: 170 additions & 0 deletions src/datatrove/pipeline/dedup/bloom_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import contextlib
import math

import numpy as np
from loguru import logger
from nltk import ngrams, word_tokenize

from datatrove.data import Document, DocumentsPipeline
from datatrove.io import BaseOutputDataFolder
from datatrove.pipeline.base import PipelineStep
from datatrove.pipeline.dedup.utils import sha1_hash32, simplify_content
from datatrove.pipeline.writers.disk_base import DiskWriter
from datatrove.utils.typeshelper import StatHints


# http://en.wikipedia.org/wiki/Mersenne_prime
_mersenne_prime = np.uint64((1 << 61) - 1)
MAX_HASH = 1 << 32 - 1


def get_optimal_k(size_in_bytes: int, expected_elements: int) -> int:
assert expected_elements, f"if {expected_elements=} then k must be given"
m = size_in_bytes * 8
k = (m / expected_elements) * np.log(2)
return math.ceil(k)


def get_false_positive_prob(size_in_bytes: int, n: int, k: int) -> float:
m = size_in_bytes * 8
return (1.0 - (1.0 - (1.0 / m)) ** (k * n)) ** k


class SingleBloomFilter(PipelineStep):
type = "🫂 - DEDUPS"
name = "🪷 Bloom-filter"

def __init__(
self,
output_folder: BaseOutputDataFolder,
m_bytes: int,
k: int = None,
expected_elements: int = None,
duplicate_threshold: float = 0.8,
n_grams: int = 13,
seed: int = 0,
save_bloom_filter: bool = False,
exclusion_writer: DiskWriter = None,
**kwargs,
):
"""
:param output_folder: output folder: local or on S3
:param m_bytes: bloom filter size in bytes (actual size x8 bigger)
:param k: number of hashes
:param expected_elements: expected number of elements, aka shingles.
:param duplicate_threshold: above which documents are considered as duplicated
:param n_grams: n_grams to use
:param seed: seed
:param save_bloom_filter: if true saves bloom filter for later use
:param exclusion_writer: saves duplicated data
:param kwargs:
"""

super().__init__(**kwargs)
self.output_folder = output_folder
self.m_bytes = m_bytes # size in bits
self.k = k if k else get_optimal_k(self.m, expected_elements=expected_elements)
self.m = m_bytes * 8 # (self.m + 7) // 8 # size in bytes
self.duplicate_threshold = duplicate_threshold
self.n_grams = n_grams
self.bit_vector = bytearray(([0] * self.m_bytes))
self.save_bloom_filter = save_bloom_filter
self.exclusion_writer = exclusion_writer
assert self.m < MAX_HASH

self.seed = seed
self.total_shingles = 0
self._parameters = None

assert self.m_bytes < MAX_HASH, f"{MAX_HASH=} is smaller than {self.m_bytes=}"
if expected_elements:
fp = get_false_positive_prob(self.m_bytes, n=expected_elements, k=self.k)
if fp > 0.05:
logger.warning(f"False probability = {fp:.3}")
else:
logger.info(f"False probability = {fp:.3}")

def set_up_dl_locks(self, dl_lock, up_lock):
self.output_folder.set_lock(up_lock)

@property
def parameters(self):
if not self._parameters:
# Create parameters for a random bijective permutation function
# that maps a 32-bit hash value to another 32-bit hash value.
# http://en.wikipedia.org/wiki/Universal_hashing
gen = np.random.RandomState(self.seed)
self._parameters = gen.randint(1, _mersenne_prime, dtype=np.uint64, size=(1, self.k)), gen.randint(
0, _mersenne_prime, dtype=np.uint64, size=(1, self.k)
)
return self._parameters

def get_shingles(self, text: str) -> np.ndarray:
return np.array(
[
[sha1_hash32(" ".join(x).encode("utf-8"))]
for x in ngrams(word_tokenize(simplify_content(text)), self.n_grams)
],
dtype=np.uint64,
)

def get_indexes(self, shingles: np.ndarray) -> list[list[int]]:
a, b = self.parameters
phv = np.bitwise_and((shingles * a + b) % _mersenne_prime, self.m_bytes)
return phv.tolist()

def update_bf(self, indexes: list[int]):
for index in indexes:
byte_index, bit_index = divmod(index, 8)
mask = 1 << bit_index
self.bit_vector[byte_index] |= mask

def query(self, indexes: list[int]) -> bool:
for idx in indexes:
byte_index, bit_index = divmod(idx, 8)
mask = 1 << bit_index
if (self.bit_vector[byte_index] & mask) == 0:
return False

return True

def step(self, doc: Document) -> bool:
shingles = self.get_shingles(doc.content)
self.total_shingles += shingles.size
if shingles.size == 0:
return True
shingle_indexes = self.get_indexes(shingles)

duplicate_shingles = 0
indexes_to_update = []
for indexes in shingle_indexes:
if self.query(indexes):
duplicate_shingles += 1
else:
indexes_to_update.extend(indexes)

self.update_bf(indexes_to_update)
if duplicate_shingles / len(shingles) > self.duplicate_threshold:
self.stat_update(StatHints.dropped)
return False
return True

def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):
with self.exclusion_writer if self.exclusion_writer else contextlib.nullcontext() as writer:
for doc_idx, doc in enumerate(data):
with self.stats.time_manager:
self.stat_update(StatHints.total)
if self.step(doc):
self.stat_update(StatHints.forwarded)
yield doc
else:
if self.exclusion_writer:
writer.write(doc, rank)
if self.save_bloom_filter:
with self.output_folder.open("bloom_filter.bloom", mode="wb") as f:
f.write(self.bit_vector)

logger.info(f"{self.total_shingles=}")
logger.info(f"False probability = {get_false_positive_prob(self.m_bytes, n=self.total_shingles, k=self.k):.3}")
logger.info(f"Optimal K given total shingles = {get_optimal_k(self.m_bytes, self.total_shingles)}")
19 changes: 18 additions & 1 deletion src/datatrove/pipeline/dedup/exact_substrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
a second file with the bytes offset of where each individual doc begins.
2) MergeSequences all sequences into a big single sequence. it saves the bytes offset per file.
... call deduplicate-text-datasets scripts ...
... call deduplicate-text-datasets scripts
in particular `cargo run self-similar ...` and `cargo run self-similar` need to be called
3) DedupReader reads docs and ranges at the same time and remove duplicates.
Expand Down Expand Up @@ -50,6 +53,12 @@ class DatasetToSequence(PipelineStep):
name = "🪞 - exact-substrings stage 1"

def __init__(self, output_folder=BaseOutputDataFolder, tokenizer_name: str = "gpt2", **kwargs):
"""
:param output_folder: folder where sequences are saved
:param tokenizer_name: name of tokenizer as in HF tokenizers.
:param kwargs:
"""
super().__init__(**kwargs)
self.output_folder = output_folder
self.tokenizer = tokenizers.Tokenizer.from_pretrained(tokenizer_name)
Expand Down Expand Up @@ -95,6 +104,14 @@ def __init__(
bytes_per_batch: int = int(500e6),
**kwargs,
):
"""
:param input_folder: folder where sequences were saved in stage 1
:param output_folder: folder where the big sequence will be saved
:param tasks_stage_1: number of tasks used in stage 1
:param bytes_per_batch: number of bytes read per sequence
:param kwargs:
"""
super().__init__(**kwargs)
self.input_folder = input_folder
self.output_folder = output_folder
Expand Down
109 changes: 109 additions & 0 deletions tests/pipeline/test_bloom_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import shutil
import tempfile
import unittest

from datatrove.data import Document
from datatrove.io import LocalInputDataFolder
from datatrove.pipeline.dedup.bloom_filter import SingleBloomFilter, get_false_positive_prob, get_optimal_k


TEXT_0 = (
"A SQUAT grey building of only thirty-four stories. Over the main entrance the words, CENTRAL LONDON HATCHERY "
"AND CONDITIONING CENTRE, and, in a shield, the World State's motto, COMMUNITY, IDENTITY, STABILITY. The enormous"
" room on the ground floor faced towards the north. Cold for all the summer beyond the panes, for all the "
"tropical heat of the room itself, a harsh thin light glared through the windows, hungrily seeking some draped "
"lay figure, some pallid shape of academic goose-flesh, but finding only the glass and nickel and bleakly shining"
" porcelain of a laboratory. Wintriness responded to wintriness. The overalls of the workers were white, their "
"hands gloved with a pale corpse-coloured rubber. The light was frozen, dead, a ghost. Only from the yellow "
"barrels of the microscopes did it borrow a certain rich and living substance, lying along the polished tubes "
"like butter, streak after luscious streak in long recession down the work tables. And this, said the Director "
"opening the door, 'is the Fertilizing Room.'"
)

TEXT_1 = (
"Wintriness responded to wintriness. The overalls of the workers were white, their "
"hands gloved with a pale corpse-coloured rubber. The light was frozen, dead, a ghost. Only from the yellow "
"barrels of the microscopes did it borrow a certain rich and living substance, lying along the polished tubes "
"like butter, streak after luscious streak in long recession down the work tables. What wintriness even mean ?"
"If you google it you will find that pretty bus it is used as a word in this book and pretty much it."
)

TEXT_2 = (
"Arise, arise, Riders of Théoden! Fell deeds awake: fire and slaughter! Spear shall be shaken, shield be "
"splintered, a sword-day, a red day, ere the sun rises!"
)

TEXT_3 = (
"I hope you're pleased with yourselves. We could all have been killed — or worse, expelled. Now if you don't "
"mind, I'm going to bed."
)

TEXT_4 = (
"Meycauayan Tree is one of the three acacia trees (Samanea saman) located in the patio of the Parish Church "
"of St. Francis of Assisi in Meycauayan City, Bulacan, Philippines. Planted by an unknown person, it has "
"stood on the grounds of the parish church for almost a century and a half."
)

TEXT_5 = (
"Geologically the Castelltallat range is made up of limestone and marl mountains oriented WSW-ENE. "
"The highest point of the range is the 936 m high 'Tossal'. The northern slopes are steep and forested, "
"while the southern slopes are used for agriculture owing to their lesser inclination. Most of the mountain "
"belongs to the municipality of Sant Mateu de Bages while the western part lies within the municipalities "
"of Pinós and La Molsosa. The village of Castelltallat was a municipality until 1840 when it became part "
"of San Mateu de Bages municipal term. The parish church of Sant Miquel has been documented since 1031 "
"and is located at an altitude of 887 m."
)

TEXT_6 = (
"Chukanovite was first discovered in weathered cavities of a meteorite which fell near the small village "
"of Dronino, 350 km southeast of Moscow, Russia, but the mineral has since been found elsewhere in cavities"
" of other iron-rich meteorites. It occurs primarily in association with goethite, akaganeite, hematite, "
"hibbingite, reevesite, honessite, and kamacite, though the meteorites that contain chukanovite also tend "
"to contain taenite and chromite. Individual crystals form from a reaction between kamacite and cold "
"water that is rich in dissolved carbon dioxide, during which they adopt a fibrous to acicular habit and "
"grow to an average size of roughly 0.5 mm in length and 2-3 μm in thickness. Individual crystals tend to "
"coalesce within the meteorite cavities into porous collections or crusts of spherulites, each with a "
"diameter of about 1 mm."
)

TEXT_7 = "1 + 1 = 2, 2 + 2 = 4, 4 + 4 = 8, ..."

DOCS = [
Document(content=TEXT_0, data_id="0"),
Document(content=TEXT_1, data_id="1"),
Document(content=TEXT_2, data_id="2"),
Document(content=TEXT_3, data_id="3"),
Document(content=TEXT_4, data_id="4"),
Document(content=TEXT_5, data_id="5"),
Document(content=TEXT_6, data_id="6"),
Document(content=TEXT_7, data_id="7"),
Document(content=TEXT_0, data_id="8"),
Document(content=TEXT_1, data_id="9"),
Document(content=TEXT_6[:-10], data_id="10"),
]

TARGETS = [True] * 8 + [False] * 3


class SentenceDedup(unittest.TestCase):
def setUp(self):
# Create a temporary directory
self.test_dir = tempfile.mkdtemp()

def tearDown(self):
# Remove the directory after the test
shutil.rmtree(self.test_dir)

def test_sd(self):
bloom_filter = SingleBloomFilter(
output_folder=LocalInputDataFolder(self.test_dir), m_bytes=2**10 - 1, k=7, expected_elements=866
)

for doc_idx, doc in enumerate(DOCS):
is_unique = bloom_filter.step(doc)
self.assertEqual(is_unique, TARGETS[doc_idx])

fp = get_false_positive_prob(bloom_filter.m_bytes, n=bloom_filter.total_shingles, k=bloom_filter.k)
print(f"False probability = {fp:.3}")
print(f"Optimal K given total shingles = {get_optimal_k(bloom_filter.m_bytes, bloom_filter.total_shingles)}")
print(f"{bloom_filter.total_shingles=}")

0 comments on commit a0c85dd

Please sign in to comment.