From 15c8425cfed1d60c16f26ac307a5ed3dfa159bc6 Mon Sep 17 00:00:00 2001 From: Guilherme Penedo Date: Sat, 4 May 2024 11:38:58 +0200 Subject: [PATCH] Adds n-gram based decontamination (#172) * adds n-gram decontamination * added docs * fix lighteval version * fix formatting --- pyproject.toml | 4 + src/datatrove/io.py | 5 + src/datatrove/pipeline/decont/__init__.py | 1 + src/datatrove/pipeline/decont/n_grams.py | 220 ++++++++++++++++++++++ tests/pipeline/test_ngrams_decont.py | 60 ++++++ tests/utils.py | 8 + 6 files changed, 298 insertions(+) create mode 100644 src/datatrove/pipeline/decont/__init__.py create mode 100644 src/datatrove/pipeline/decont/n_grams.py create mode 100644 tests/pipeline/test_ngrams_decont.py diff --git a/pyproject.toml b/pyproject.toml index 2db9b32a..54c9231d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,9 @@ processing = [ "fasteners", "xxhash" ] +decont = [ + "lighteval>=0.3.0" +] quality = [ "ruff>=0.1.5" ] @@ -66,6 +69,7 @@ testing = [ "datatrove[io]", "datatrove[processing]", "datatrove[s3]", + "datatrove[decont]", "pytest", "pytest-timeout", "pytest-xdist", diff --git a/src/datatrove/io.py b/src/datatrove/io.py index be99291e..f91f2c77 100644 --- a/src/datatrove/io.py +++ b/src/datatrove/io.py @@ -284,6 +284,11 @@ def open_file(file: IO | str, mode="rt", **kwargs): return file +def file_exists(path: str): + fs, a, fpath = get_fs_token_paths(path) + return fs.exists(fpath[0]) + + def download_file(remote_path: str, local_path: str, progress: bool = True): fs, _, paths = get_fs_token_paths(remote_path) fs.get_file( diff --git a/src/datatrove/pipeline/decont/__init__.py b/src/datatrove/pipeline/decont/__init__.py new file mode 100644 index 00000000..efea634c --- /dev/null +++ b/src/datatrove/pipeline/decont/__init__.py @@ -0,0 +1 @@ +from .n_grams import NGramsDecontConfig, NGramsDecontFilter, NGramsDecontIndexer diff --git a/src/datatrove/pipeline/decont/n_grams.py b/src/datatrove/pipeline/decont/n_grams.py new file mode 100644 index 00000000..64b25d22 --- /dev/null +++ b/src/datatrove/pipeline/decont/n_grams.py @@ -0,0 +1,220 @@ +""" +Used for n-gram decontamination. +First build an index using the tasks we want to use to decontaminate our training dataset. +Then read your training data and apply the filter with the index loaded. +""" + +import os +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import Tuple + +import numpy as np +from loguru import logger + +from datatrove.data import Document, DocumentsPipeline +from datatrove.io import DataFolderLike, file_exists, get_datafolder, open_file +from datatrove.pipeline.base import PipelineStep +from datatrove.pipeline.filters.base_filter import BaseFilter +from datatrove.pipeline.writers.disk_base import DiskWriter +from datatrove.utils.binaryio import read_np_from_file +from datatrove.utils.text import TextNormConfig, simplify_text, xxhash64 + + +@dataclass +class NGramsDecontConfig: + """ + Example for n_grams=4 + query = ['A', 'B', 'C', 'D', 'E'] (the prompt/instruction) + label = ['F', 'G', 'H', 'I', 'J'] (the answer/gold) + Will find the following N-GRAMS in the training data: + 'F G H I' + 'G H I J' + + IF find_query_ngrams: + 'A B C D' + 'B C D E' + + IF find_overlap_ngrams: + 'C D E F' + 'D E F G' + 'E F G H' + """ + + n_grams: int = 12 + find_query_ngrams: bool = False # enable to also check for matches in n-grams containing only the input/prompt + find_overlap_ngrams: bool = True # will also find matches for n-grams containing BOTH input and query + norm_config: TextNormConfig = field(default_factory=TextNormConfig) + + +DEFAULT_NGRAMS_DECONT_CONFIG = NGramsDecontConfig() + + +class NGramsDecontIndexer(PipelineStep): + """ + Creates a decontamination index (basically a list of uint64 hashes from ngrams) for each reference task. + Ways to provide task data: + - as input documents from the previous pipeline step with "text=label/correct answer" + and metadata={"query": query/prompt/input, "task": task name} + - as a list of strings in the format "suite|task" from the lighteval metadata table: + https://github.com/huggingface/lighteval/blob/main/src/lighteval/tasks/tasks_table.jsonl as `lighteval_tasks` + - a path to a text file containing one such list, with one "suite|task" per line as `lighteval_tasks` + you can also define your custom tasks with `custom_lighteval_tasks`. See explanation for `custom_tasks` here: + https://github.com/huggingface/lighteval/tree/main?tab=readme-ov-file#evaluate-a-model-on-extended-community-or-custom-tasks + + """ + + type = "🦠 - DECONT" + name = "💥 N-grams build index" + _requires_dependencies = ["nltk", "lighteval", "xxhash"] + + def __init__( + self, + output_folder: DataFolderLike, + lighteval_tasks: str | list[str] | None = None, # list in the format suite|task or path to one such list + custom_lighteval_tasks: str | None = None, + config: NGramsDecontConfig = DEFAULT_NGRAMS_DECONT_CONFIG, + language: str = "english", + ): + super().__init__() + self.output_folder = get_datafolder(output_folder) + # parse list of tasks + if isinstance(lighteval_tasks, str): + if file_exists(lighteval_tasks): + with open_file(lighteval_tasks, "rt") as f: + self.lighteval_tasks = f.read().strip().splitlines() + else: + self.lighteval_tasks = [lighteval_tasks] + else: + self.lighteval_tasks = lighteval_tasks + self.custom_lighteval_tasks = custom_lighteval_tasks + self.config = config + self.language = language + + def compute_hashes(self, label: str, query: str | None = None) -> list[int]: + from nltk import ngrams + from nltk.tokenize import word_tokenize + + label_tokens = word_tokenize(simplify_text(label, self.config.norm_config), language=self.language) + ngrams_to_compute = list(ngrams(label_tokens, self.config.n_grams)) + if query is not None: + query_tokens = word_tokenize(simplify_text(query, self.config.norm_config), language=self.language) + if self.config.find_query_ngrams: + ngrams_to_compute.extend(ngrams(query_tokens, self.config.n_grams)) + if self.config.find_overlap_ngrams: + # add tokens overlapping query and label + """ + A, B, C, D, E | F, G, H, I, J + 5 grams + B, C, D, E, F (-N + 1 + i:) + (:i + 1) + ... + E, F, G, H, I + """ + ngrams_to_compute.extend( + [ + query_tokens[-self.config.n_grams + 1 + i :] + label_tokens[: i + 1] + for i in range(self.config.n_grams - 1) + # make sure we actually get a list of size N + if len(query_tokens) >= self.config.n_grams - 1 - i and len(label_tokens) >= i + 1 + ] + ) + return list(map(xxhash64, map(" ".join, ngrams_to_compute))) + + def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1): + if world_size != 1: + raise ValueError("Decontamination index building requires a single worker.") + hashes = defaultdict(set) + # use whatever date is parsed in with the following format: + # doc.text -> label + # doc.metadata["input"] -> input + if data: + for doc in data: + if not self.config.find_query_ngrams and "query" not in doc.metadata: + raise ValueError( + "only_label_ngrams is False but could not find 'query' field in documents metadata" + ) + hashes[doc.metadata.get("task", "input")].update( + self.compute_hashes(doc.text, doc.metadata.get("query", None)) + ) + + # parse data from lighteval defined tasks + from lighteval.tasks.lighteval_task import LightevalTask + from lighteval.tasks.registry import Registry + + task_dict = Registry(cache_dir=os.getenv("HF_HOME")).get_task_dict( + self.lighteval_tasks, custom_tasks=self.custom_lighteval_tasks + ) + LightevalTask.load_datasets(task_dict.values()) + + for task_name, task in task_dict.items(): + for eval_doc in task.eval_docs(): + for gold in eval_doc.get_golds(): + hashes[task_name].update(self.compute_hashes(gold, eval_doc.query)) + + for task_name, task_hashes in hashes.items(): + hashes_array = np.array(list(task_hashes), dtype=" bool | Tuple[bool, str]: + if self._index_hashes is None: + self.load_index_hashes() + + from nltk import ngrams + from nltk.tokenize import word_tokenize + + text_tokens = word_tokenize(simplify_text(doc.text, self.config.norm_config), language=self.language) + ngrams_to_compute = list(ngrams(text_tokens, self.config.n_grams)) + for n_gram in map(" ".join, ngrams_to_compute): + task = self._index_hashes.get(xxhash64(n_gram), None) + if task is not None: + doc.metadata["contaminated_ngram"] = n_gram + doc.metadata["contaminated_task"] = task + self.stat_update(f"contaminated_{task}") + return False, "contaminated" + return True diff --git a/tests/pipeline/test_ngrams_decont.py b/tests/pipeline/test_ngrams_decont.py new file mode 100644 index 00000000..e60e9305 --- /dev/null +++ b/tests/pipeline/test_ngrams_decont.py @@ -0,0 +1,60 @@ +import copy +import shutil +import tempfile +import unittest + +from datatrove.data import Document +from datatrove.pipeline.decont import NGramsDecontConfig, NGramsDecontFilter, NGramsDecontIndexer +from tests.utils import require_xxhash + + +TEXTS = [ + "A lady walks to a barbell. She bends down and grabs the pole.", # 0: contaminated query + "get into formation, then begin dancing and flipping as male cheerleaders join them.", # 1: contaminated label + "He is using commercial lawn mowing equipment. he walks back and forth as he mows the grass.", # 2: cont overlap + "He is using commercial lawn mowing equipment. he is animated as he does the task.", # 3: incorrect completion + "walks outside plugs his lawn mower in and gets ready to mow", # 4: single contaminated query ngram + "", # 5: not contaminated at all + "walks outside plugs his lawn mower in and gets ready to", # 6: single contaminated query text < 1 ngram +] + +DOCS = [ + Document( + text="Nothing is so painful to the human mind as a great and sudden change. " + + text + + " Beware; for I am fearless, and therefore powerful.", + id=str(text_i), + ) + for text_i, text in enumerate(TEXTS) +] + + +@require_xxhash +class TestNGramDecont(unittest.TestCase): + def setUp(self): + # Create a temporary directory + self.tmp_dir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp_dir) + + def get_test_results(self, config): + indexer = NGramsDecontIndexer(self.tmp_dir, lighteval_tasks="leaderboard|hellaswag", config=config) + indexer.run() + nfilter = NGramsDecontFilter(self.tmp_dir, config=config) + return tuple([int(doc.id) for doc in nfilter(copy.deepcopy(DOCS))]) + + def test_label_only(self): + self.assertEqual( + self.get_test_results(NGramsDecontConfig(find_query_ngrams=False, find_overlap_ngrams=False)), + (0, 2, 3, 4, 5, 6), + ) + + def test_query(self): + self.assertEqual( + self.get_test_results(NGramsDecontConfig(find_query_ngrams=True, find_overlap_ngrams=False)), (2, 3, 5, 6) + ) + + def test_overlap(self): + self.assertEqual( + self.get_test_results(NGramsDecontConfig(find_query_ngrams=False, find_overlap_ngrams=True)), + (0, 3, 4, 5, 6), + ) diff --git a/tests/utils.py b/tests/utils.py index 65af1b3d..0fe4b7cd 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -103,3 +103,11 @@ def require_xxhash(test_case): except ImportError: test_case = unittest.skip("test requires xxhash")(test_case) return test_case + + +def require_lighteval(test_case): + try: + import lighteval # noqa: F401 + except ImportError: + test_case = unittest.skip("test requires lighteval")(test_case) + return test_case