From 0f2f620f428b06449a37d22699dba8720793103c Mon Sep 17 00:00:00 2001
From: Jan Deriu
Date: Tue, 25 Jun 2024 13:13:35 +0200
Subject: [PATCH] Add robots.txt filtering to the FineWeb pipeline

Add a RobotsFilter step that fetches and parses robots.txt once per domain
and drops documents whose URL is disallowed for the configured user agents
('*' and 'ccbot' in the FineWeb pipeline). The per-domain disallow rules are
logged through an optional JsonlWriter. The FineWeb pipeline now reads the
parquet files through a custom adapter and writes via SwissAIJsonlWriter.

---
 pipelines/fineweb_pipeline.py                 |  43 +++++-
 .../pipeline/filters/robots_filter.py         | 122 ++++++++++++++++++
 .../pipeline/formatters/pii_removal.py        |   2 +
 3 files changed, 163 insertions(+), 4 deletions(-)
 create mode 100644 src/swiss_ai/pipeline/filters/robots_filter.py

diff --git a/pipelines/fineweb_pipeline.py b/pipelines/fineweb_pipeline.py
index f215b4ca..9e4ee238 100644
--- a/pipelines/fineweb_pipeline.py
+++ b/pipelines/fineweb_pipeline.py
@@ -3,18 +3,53 @@
 from datatrove.pipeline.filters import LambdaFilter
 from datatrove.pipeline.writers import JsonlWriter
 from swiss_ai.pipeline.formatters.pii_removal import PIIFormatter
+from swiss_ai.pipeline.filters.robots_filter import RobotsFilter
+from swiss_ai.writers.jsonl import SwissAIJsonlWriter
+
+
+def _fineweb_adapter(self, data: dict, path: str, id_in_file: int | str):
+    year = int(data.get('date', '2024').split('-')[0])
+
+    metadata = {
+        "language": data["language"],
+        "year": year,
+        "token_count": data["token_count"],
+        "optional": {
+            "url": data["url"],
+            "dump": data["dump"],
+            "language_score": data["language_score"],
+        },
+    }
+
+    return {
+        "text": data.pop("text", ""),
+        "id": f"{path}/{id_in_file}",
+        "media": data.pop("media", []),
+        "metadata": metadata,
+    }
 
 
 if __name__ == "__main__":
+
+    robots_writer = JsonlWriter(
+        '/work_space_data/swiss_ai/logs/hugginface/fineweb/robots',
+        adapter=lambda s, x: x.metadata,
+        compression=None
+    )
+
     pipeline_exec = LocalPipelineExecutor(
         pipeline=[
             # replace "data/CC-MAIN-2024-10" with "sample/100BT" to use the 100BT sample
-            ParquetReader("/work_space_data/hf_cache/hub/datasets--HuggingFaceFW--fineweb-edu/snapshots/", limit=10_000),
-            PIIFormatter(),
-            JsonlWriter("/work_space_data/swiss_ai/hugginface/fineweb")
+            ParquetReader(
+                "/work_space_data/hf_cache/hub/datasets--HuggingFaceFW--fineweb-edu/snapshots/",
+                limit=10_000,
+                adapter=_fineweb_adapter
+            ),
+            RobotsFilter(robots_writer=robots_writer, disallow_agents=['*', 'ccbot']),
+            SwissAIJsonlWriter("/work_space_data/swiss_ai/hugginface/fineweb")
         ],
         start_method="spawn",
         workers=8,
-        logging_dir='work_space_data/swiss_ai/logs',
+        logging_dir='/work_space_data/swiss_ai/logs',
         tasks=10
     )
diff --git a/src/swiss_ai/pipeline/filters/robots_filter.py b/src/swiss_ai/pipeline/filters/robots_filter.py
new file mode 100644
index 00000000..6cfd9524
--- /dev/null
+++ b/src/swiss_ai/pipeline/filters/robots_filter.py
@@ -0,0 +1,122 @@
+import contextlib
+import urllib.parse
+from collections import defaultdict
+
+import requests
+
+from datatrove.data import Document, DocumentsPipeline
+from datatrove.pipeline.filters.base_filter import BaseFilter
+from datatrove.pipeline.writers.disk_base import DiskWriter
+from datatrove.utils.typeshelper import StatHints
+
+
+class RobotsFilter(BaseFilter):
+    """
+    Drops documents whose URL is disallowed by their domain's robots.txt for the configured agents.
+    """
+
+    name = "😈 Robots.txt-filter"
+
+    def __init__(
+        self,
+        robots_writer: DiskWriter = None,
+        disallow_agents=("*",),
+        exclusion_writer: DiskWriter = None,
+    ):
+        super().__init__(exclusion_writer)
+        # cache: domain -> list of disallowed path prefixes for the configured agents
+        self.domain_to_disallowed_paths = {}
+        self.robots_writer = robots_writer
+        self.disallow_agents = disallow_agents
+
+    def fetch_robots_txt(self, domain):
+        urls = [f"https://{domain}/robots.txt", f"http://{domain}/robots.txt"]
+        for url in urls:
+            try:
+                response = requests.get(url, timeout=10)
+            except requests.RequestException:
+                continue
+            if response.status_code == 200:
+                return response.text
+        return None
+
+    def parse_robots_txt(self, robots_txt):
+        disallowed_paths_for_agent = defaultdict(list)
+        current_agent = "*"
+        for line in robots_txt.splitlines():
+            line = line.strip()
+            if not line or line.startswith('#'):
+                continue
+
+            if line.lower().startswith('user-agent:'):
+                current_agent = line.split(':', 1)[1].strip().lower()
+            elif line.lower().startswith('disallow:'):
+                path = line.split(':', 1)[1].strip()
+                if path:
+                    disallowed_paths_for_agent[current_agent].append(path)
+
+        # merge the rules of all agents we want to respect
+        disallowed_paths = []
+        for agent in self.disallow_agents:
+            disallowed_paths.extend(disallowed_paths_for_agent.get(agent.lower(), []))
+
+        return disallowed_paths, disallowed_paths_for_agent
+
+    def check_disallowed_urls(self, domain, disallowed_paths):
+        parsed_domain = urllib.parse.urlparse(f"http://{domain}")
+        disallowed_urls = [f"{parsed_domain.scheme}://{parsed_domain.netloc}{path}" for path in disallowed_paths]
+        return disallowed_urls
+
+    def extract_domain(self, url):
+        parsed_url = urllib.parse.urlparse(url)
+        return parsed_url.netloc
+
+    def is_url_disallowed(self, url, disallowed_paths):
+        parsed_url = urllib.parse.urlparse(url)
+        path = parsed_url.path
+        for disallowed_path in disallowed_paths:
+            if path.startswith(disallowed_path):
+                return True
+        return False
+
+    def filter(self, document: Document, writer: DiskWriter, rank: int) -> bool | tuple[bool, str]:
+        url = document.metadata['optional'].get("url")
+        domain = self.extract_domain(url)
+
+        disallowed_paths = self.domain_to_disallowed_paths.get(domain)
+        if disallowed_paths is None:
+            # only fetch and parse robots.txt the first time we see a domain
+            robots_txt = self.fetch_robots_txt(domain)
+            if robots_txt:
+                disallowed_paths, disallowed_paths_for_agent = self.parse_robots_txt(robots_txt)
+                if writer:
+                    rbts_doc = Document(metadata={domain: disallowed_paths_for_agent}, text='', id='')
+                    writer.write(rbts_doc, rank)
+            else:
+                disallowed_paths = []
+            self.domain_to_disallowed_paths[domain] = disallowed_paths
+
+        disallowed = self.is_url_disallowed(url, disallowed_paths)
+        # keep the reason generic so the stats do not get one entry per URL
+        return not disallowed, "disallowed_robots_txt"
+
+    def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline:
+        with self.exclusion_writer if self.exclusion_writer else contextlib.nullcontext() as writer, \
+                self.robots_writer if self.robots_writer else contextlib.nullcontext() as rbs_writer:
+            for doc in data:
+                self.stat_update(StatHints.total)
+                with self.track_time():
+                    filter_result, reason = self.filter(doc, rbs_writer, rank)
+                if filter_result:
+                    self.stat_update(StatHints.forwarded)
+                    self.update_doc_stats(doc)
+                else:
+                    self.stat_update(StatHints.dropped)
+                    if reason:
+                        self.stat_update(f"dropped_{reason}")
+                    if self.exclusion_writer:
+                        if reason:
+                            doc.metadata["filter_reason"] = reason
+                        writer.write(doc, rank)
+                    continue
+                yield doc
diff --git a/src/swiss_ai/pipeline/formatters/pii_removal.py b/src/swiss_ai/pipeline/formatters/pii_removal.py
index edfc2755..e2ae4fec 100644
--- a/src/swiss_ai/pipeline/formatters/pii_removal.py
+++ b/src/swiss_ai/pipeline/formatters/pii_removal.py
@@ -101,6 +101,8 @@ def format(self, text: str) -> str:
             text = self.emails_replacer.replace(text)
         if self.remove_ips:
             text = self.ip_replacer.replace(text)
+
+        # note: phone numbers etc. need to be removed first, before the EU-specific replacers run
         if self.remove_eu:
             for eu_replacer in self.eu_replacers:
                 text = eu_replacer.replace(text)
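
As a quick sanity check, a minimal sketch of how the new RobotsFilter parsing
helpers behave in isolation (the robots.txt body and the example.com URLs are
invented for illustration; swiss_ai, datatrove and requests must be importable):

    from swiss_ai.pipeline.filters.robots_filter import RobotsFilter

    # same agent list as the FineWeb pipeline; the robots.txt content is made up
    robots_txt = (
        "User-agent: ccbot\n"
        "Disallow: /private/\n"
        "User-agent: *\n"
        "Disallow: /tmp/\n"
    )

    robots_filter = RobotsFilter(disallow_agents=["*", "ccbot"])
    paths, per_agent = robots_filter.parse_robots_txt(robots_txt)
    print(paths)             # ['/tmp/', '/private/']
    print(dict(per_agent))   # {'ccbot': ['/private/'], '*': ['/tmp/']}
    print(robots_filter.is_url_disallowed("https://example.com/private/page", paths))  # True
    print(robots_filter.is_url_disallowed("https://example.com/blog/post", paths))     # False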