From d56d3c5adefe15e68e7368eb8284e9c4236586c3 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 4 May 2024 19:27:19 +0200 Subject: [PATCH] add skip in decont index builder --- src/datatrove/pipeline/decont/n_grams.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/datatrove/pipeline/decont/n_grams.py b/src/datatrove/pipeline/decont/n_grams.py index 64b25d22..f197da08 100644 --- a/src/datatrove/pipeline/decont/n_grams.py +++ b/src/datatrove/pipeline/decont/n_grams.py @@ -147,8 +147,14 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1 for task_name, task in task_dict.items(): for eval_doc in task.eval_docs(): - for gold in eval_doc.get_golds(): - hashes[task_name].update(self.compute_hashes(gold, eval_doc.query)) + try: + golds = eval_doc.get_golds() + query = eval_doc.query + except Exception as e: + logger.warning(f"Error while fetching doc data: {e}") + continue + for gold in golds: + hashes[task_name].update(self.compute_hashes(gold, query)) for task_name, task_hashes in hashes.items(): hashes_array = np.array(list(task_hashes), dtype=" bool | Tuple[bool, str]: doc.metadata["contaminated_ngram"] = n_gram doc.metadata["contaminated_task"] = task self.stat_update(f"contaminated_{task}") + if ":" in task: + self.stat_update(f"contaminated_tg_{task[:task.index(':')]}") return False, "contaminated" return True