Adds n-gram based decontamination (huggingface#172)
* adds n-gram decontamination

* added docs

* fix lighteval version

* fix formatting
guipenedo authored May 4, 2024
1 parent 6a4881d commit 15c8425
Showing 6 changed files with 298 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -58,6 +58,9 @@ processing = [
"fasteners",
"xxhash"
]
decont = [
"lighteval>=0.3.0"
]
quality = [
"ruff>=0.1.5"
]
@@ -66,6 +69,7 @@ testing = [
"datatrove[io]",
"datatrove[processing]",
"datatrove[s3]",
"datatrove[decont]",
"pytest",
"pytest-timeout",
"pytest-xdist",
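The new decont extra keeps lighteval an optional dependency: it is only needed when building or applying decontamination indexes. On a source checkout it would typically be installed with pip install -e ".[decont]" (the exact command is an assumption, not part of this diff).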
5 changes: 5 additions & 0 deletions src/datatrove/io.py
@@ -284,6 +284,11 @@ def open_file(file: IO | str, mode="rt", **kwargs):
return file


def file_exists(path: str):
    fs, _, paths = get_fs_token_paths(path)
    return fs.exists(paths[0])


def download_file(remote_path: str, local_path: str, progress: bool = True):
fs, _, paths = get_fs_token_paths(remote_path)
fs.get_file(
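For context, a minimal sketch of the new helper in use — it mirrors how the indexer below loads a task list, and the remote path is purely hypothetical:

from datatrove.io import file_exists, open_file

tasks_file = "s3://my-bucket/decont/lighteval_tasks.txt"  # hypothetical path; any fsspec-supported scheme should work
if file_exists(tasks_file):
    with open_file(tasks_file, "rt") as f:
        tasks = f.read().strip().splitlines()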
1 change: 1 addition & 0 deletions src/datatrove/pipeline/decont/__init__.py
@@ -0,0 +1 @@
from .n_grams import NGramsDecontConfig, NGramsDecontFilter, NGramsDecontIndexer
220 changes: 220 additions & 0 deletions src/datatrove/pipeline/decont/n_grams.py
@@ -0,0 +1,220 @@
"""
Used for n-gram decontamination.
First build an index using the tasks we want to use to decontaminate our training dataset.
Then read your training data and apply the filter with the index loaded.
"""

import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Tuple

import numpy as np
from loguru import logger

from datatrove.data import Document, DocumentsPipeline
from datatrove.io import DataFolderLike, file_exists, get_datafolder, open_file
from datatrove.pipeline.base import PipelineStep
from datatrove.pipeline.filters.base_filter import BaseFilter
from datatrove.pipeline.writers.disk_base import DiskWriter
from datatrove.utils.binaryio import read_np_from_file
from datatrove.utils.text import TextNormConfig, simplify_text, xxhash64


@dataclass
class NGramsDecontConfig:
"""
Example for n_grams=4
query = ['A', 'B', 'C', 'D', 'E'] (the prompt/instruction)
label = ['F', 'G', 'H', 'I', 'J'] (the answer/gold)
Will find the following N-GRAMS in the training data:
'F G H I'
'G H I J'
+ IF find_query_ngrams:
'A B C D'
'B C D E'
+ IF find_overlap_ngrams:
'C D E F'
'D E F G'
'E F G H'
"""

n_grams: int = 12
    find_query_ngrams: bool = False  # set to True to also match n-grams built only from the input/prompt
    find_overlap_ngrams: bool = True  # also match n-grams that span BOTH the query and the label
norm_config: TextNormConfig = field(default_factory=TextNormConfig)


DEFAULT_NGRAMS_DECONT_CONFIG = NGramsDecontConfig()


class NGramsDecontIndexer(PipelineStep):
"""
Creates a decontamination index (basically a list of uint64 hashes from ngrams) for each reference task.
Ways to provide task data:
- as input documents from the previous pipeline step with "text=label/correct answer"
and metadata={"query": query/prompt/input, "task": task name}
- as a list of strings in the format "suite|task" from the lighteval metadata table:
https://github.com/huggingface/lighteval/blob/main/src/lighteval/tasks/tasks_table.jsonl as `lighteval_tasks`
- a path to a text file containing one such list, with one "suite|task" per line as `lighteval_tasks`
    You can also define your own custom tasks with `custom_lighteval_tasks`. See the explanation for `custom_tasks` here:
https://github.com/huggingface/lighteval/tree/main?tab=readme-ov-file#evaluate-a-model-on-extended-community-or-custom-tasks
"""

type = "🦠 - DECONT"
name = "💥 N-grams build index"
_requires_dependencies = ["nltk", "lighteval", "xxhash"]

def __init__(
self,
output_folder: DataFolderLike,
lighteval_tasks: str | list[str] | None = None, # list in the format suite|task or path to one such list
custom_lighteval_tasks: str | None = None,
config: NGramsDecontConfig = DEFAULT_NGRAMS_DECONT_CONFIG,
language: str = "english",
):
super().__init__()
self.output_folder = get_datafolder(output_folder)
# parse list of tasks
if isinstance(lighteval_tasks, str):
if file_exists(lighteval_tasks):
with open_file(lighteval_tasks, "rt") as f:
self.lighteval_tasks = f.read().strip().splitlines()
else:
self.lighteval_tasks = [lighteval_tasks]
else:
self.lighteval_tasks = lighteval_tasks
self.custom_lighteval_tasks = custom_lighteval_tasks
self.config = config
self.language = language

def compute_hashes(self, label: str, query: str | None = None) -> list[int]:
from nltk import ngrams
from nltk.tokenize import word_tokenize

label_tokens = word_tokenize(simplify_text(label, self.config.norm_config), language=self.language)
ngrams_to_compute = list(ngrams(label_tokens, self.config.n_grams))
if query is not None:
query_tokens = word_tokenize(simplify_text(query, self.config.norm_config), language=self.language)
if self.config.find_query_ngrams:
ngrams_to_compute.extend(ngrams(query_tokens, self.config.n_grams))
if self.config.find_overlap_ngrams:
# add tokens overlapping query and label
"""
A, B, C, D, E | F, G, H, I, J
5 grams
B, C, D, E, F (-N + 1 + i:) + (:i + 1)
...
E, F, G, H, I
"""
ngrams_to_compute.extend(
[
query_tokens[-self.config.n_grams + 1 + i :] + label_tokens[: i + 1]
for i in range(self.config.n_grams - 1)
# make sure we actually get a list of size N
if len(query_tokens) >= self.config.n_grams - 1 - i and len(label_tokens) >= i + 1
]
)
return list(map(xxhash64, map(" ".join, ngrams_to_compute)))

def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1):
if world_size != 1:
raise ValueError("Decontamination index building requires a single worker.")
hashes = defaultdict(set)
        # use whatever data is passed in, with the following format:
        # doc.text -> label
        # doc.metadata["query"] -> query/prompt/input
if data:
for doc in data:
if not self.config.find_query_ngrams and "query" not in doc.metadata:
raise ValueError(
"only_label_ngrams is False but could not find 'query' field in documents metadata"
)
hashes[doc.metadata.get("task", "input")].update(
self.compute_hashes(doc.text, doc.metadata.get("query", None))
)

# parse data from lighteval defined tasks
from lighteval.tasks.lighteval_task import LightevalTask
from lighteval.tasks.registry import Registry

task_dict = Registry(cache_dir=os.getenv("HF_HOME")).get_task_dict(
self.lighteval_tasks, custom_tasks=self.custom_lighteval_tasks
)
LightevalTask.load_datasets(task_dict.values())

for task_name, task in task_dict.items():
for eval_doc in task.eval_docs():
for gold in eval_doc.get_golds():
hashes[task_name].update(self.compute_hashes(gold, eval_doc.query))

for task_name, task_hashes in hashes.items():
hashes_array = np.array(list(task_hashes), dtype="<u8")
logger.info(f"Saving {len(task_hashes)} hashes for {task_name}")
with self.output_folder.open(f"{task_name.replace(' ', '_')}.index.hashes", mode="wb") as f:
if self.output_folder.is_local():
hashes_array.tofile(f)
else:
f.write(hashes_array.tobytes())


class NGramsDecontFilter(BaseFilter):
"""
Loads list of hashes created by the Indexer step.
For each document in the block's input, we will check if any of its ngrams are part of the reference eval tasks.
If so, they will be removed. The contaminated ngram and task where it was found will be saved in the removed
document's metadata.
"""

type = "🦠 - DECONT"
name = "💥 N-grams decontaminate"
_requires_dependencies = ["nltk", "xxhash"]

def __init__(
self,
index_folder: DataFolderLike,
config: NGramsDecontConfig = DEFAULT_NGRAMS_DECONT_CONFIG,
exclusion_writer: DiskWriter = None,
language: str = "english",
):
super().__init__()
self.index_folder = get_datafolder(index_folder)
self.config = config
self.exclusion_writer = exclusion_writer
self.language = language
self._index_hashes = None

def load_index_hashes(self):
def load_index_from_file(file):
with self.index_folder.open(file, mode="rb") as f:
return file, read_np_from_file(f, np.dtype("<u8"), self.index_folder.is_local()).tolist()

with ThreadPoolExecutor() as pool:
hashes = pool.map(load_index_from_file, self.index_folder.list_files())

self._index_hashes = {}
for filename, hashlist in hashes:
taskname = filename.removesuffix(".index.hashes")
logger.info(f"Loading {len(hashlist)} hashes for {taskname}")
for hash in hashlist:
self._index_hashes[hash] = taskname

def filter(self, doc: Document) -> bool | Tuple[bool, str]:
if self._index_hashes is None:
self.load_index_hashes()

from nltk import ngrams
from nltk.tokenize import word_tokenize

text_tokens = word_tokenize(simplify_text(doc.text, self.config.norm_config), language=self.language)
ngrams_to_compute = list(ngrams(text_tokens, self.config.n_grams))
for n_gram in map(" ".join, ngrams_to_compute):
task = self._index_hashes.get(xxhash64(n_gram), None)
if task is not None:
doc.metadata["contaminated_ngram"] = n_gram
doc.metadata["contaminated_task"] = task
self.stat_update(f"contaminated_{task}")
return False, "contaminated"
return True
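Taken together, the indexer and filter form a two-stage pipeline: the index is built once (the indexer requires a single worker because it aggregates all hashes in memory before writing one .index.hashes file per task), then the filter is applied to the training data. A minimal sketch, assuming local JSONL training data — the reader/writer choices and all paths are illustrative, not part of this commit:

from datatrove.executor import LocalPipelineExecutor
from datatrove.pipeline.decont import NGramsDecontConfig, NGramsDecontFilter, NGramsDecontIndexer
from datatrove.pipeline.readers import JsonlReader
from datatrove.pipeline.writers import JsonlWriter

config = NGramsDecontConfig()  # defaults: 12-grams, overlap n-grams enabled

# stage 1: hash the eval queries/golds into per-task index files
LocalPipelineExecutor(
    pipeline=[NGramsDecontIndexer("decont_index/", lighteval_tasks="leaderboard|hellaswag", config=config)],
    tasks=1,  # the indexer raises if world_size != 1
).run()

# stage 2: drop training documents whose n-grams hit the index
LocalPipelineExecutor(
    pipeline=[
        JsonlReader("training_data/"),
        NGramsDecontFilter("decont_index/", config=config),
        JsonlWriter("decontaminated_data/"),
    ],
).run()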
60 changes: 60 additions & 0 deletions tests/pipeline/test_ngrams_decont.py
@@ -0,0 +1,60 @@
import copy
import shutil
import tempfile
import unittest

from datatrove.data import Document
from datatrove.pipeline.decont import NGramsDecontConfig, NGramsDecontFilter, NGramsDecontIndexer
from tests.utils import require_lighteval, require_xxhash


TEXTS = [
"A lady walks to a barbell. She bends down and grabs the pole.", # 0: contaminated query
"get into formation, then begin dancing and flipping as male cheerleaders join them.", # 1: contaminated label
"He is using commercial lawn mowing equipment. he walks back and forth as he mows the grass.", # 2: cont overlap
"He is using commercial lawn mowing equipment. he is animated as he does the task.", # 3: incorrect completion
"walks outside plugs his lawn mower in and gets ready to mow", # 4: single contaminated query ngram
"", # 5: not contaminated at all
"walks outside plugs his lawn mower in and gets ready to", # 6: single contaminated query text < 1 ngram
]

DOCS = [
Document(
text="Nothing is so painful to the human mind as a great and sudden change. "
+ text
+ " Beware; for I am fearless, and therefore powerful.",
id=str(text_i),
)
for text_i, text in enumerate(TEXTS)
]


@require_lighteval
@require_xxhash
class TestNGramDecont(unittest.TestCase):
def setUp(self):
# Create a temporary directory
self.tmp_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, self.tmp_dir)

def get_test_results(self, config):
indexer = NGramsDecontIndexer(self.tmp_dir, lighteval_tasks="leaderboard|hellaswag", config=config)
indexer.run()
nfilter = NGramsDecontFilter(self.tmp_dir, config=config)
        return tuple([int(doc.id) for doc in nfilter(copy.deepcopy(DOCS))])  # ids of documents that were kept

def test_label_only(self):
self.assertEqual(
self.get_test_results(NGramsDecontConfig(find_query_ngrams=False, find_overlap_ngrams=False)),
(0, 2, 3, 4, 5, 6),
)

def test_query(self):
self.assertEqual(
self.get_test_results(NGramsDecontConfig(find_query_ngrams=True, find_overlap_ngrams=False)), (2, 3, 5, 6)
)

def test_overlap(self):
self.assertEqual(
self.get_test_results(NGramsDecontConfig(find_query_ngrams=False, find_overlap_ngrams=True)),
(0, 3, 4, 5, 6),
)
8 changes: 8 additions & 0 deletions tests/utils.py
@@ -103,3 +103,11 @@ def require_xxhash(test_case):
except ImportError:
test_case = unittest.skip("test requires xxhash")(test_case)
return test_case


def require_lighteval(test_case):
try:
import lighteval # noqa: F401
except ImportError:
test_case = unittest.skip("test requires lighteval")(test_case)
return test_case
