From fba36166dc7bdbfc1d55abf2f645219ece7b5eeb Mon Sep 17 00:00:00 2001 From: Andrea Soria Date: Tue, 3 Dec 2024 15:44:52 -0400 Subject: [PATCH] Add save-to-HF support for the async and sync webcrawlers --- crawl4ai/async_configs.py | 19 ++- crawl4ai/async_webcrawler.py | 10 +- crawl4ai/data_persistence_strategy.py | 152 ++++++++++++++++++ crawl4ai/models.py | 1 + crawl4ai/web_crawler.py | 10 +- docs/examples/async_webcrawler_md_to_hf.py | 22 +++ .../async_webcrawler_structured_to_hf.py | 67 ++++++++ requirements.txt | 3 +- tests/async/test_data_persistance_strategy.py | 93 +++++++++++ 9 files changed, 367 insertions(+), 10 deletions(-) create mode 100644 crawl4ai/data_persistence_strategy.py create mode 100644 docs/examples/async_webcrawler_md_to_hf.py create mode 100644 docs/examples/async_webcrawler_structured_to_hf.py create mode 100644 tests/async/test_data_persistance_strategy.py diff --git a/crawl4ai/async_configs.py b/crawl4ai/async_configs.py index aa0b849..93279b1 100644 --- a/crawl4ai/async_configs.py +++ b/crawl4ai/async_configs.py @@ -8,6 +8,7 @@ from .extraction_strategy import ExtractionStrategy from .chunking_strategy import ChunkingStrategy from .markdown_generation_strategy import MarkdownGenerationStrategy +from .data_persistence_strategy import DataPersistenceStrategy, SkipDataPersistenceStrategy class BrowserConfig: """ @@ -188,6 +189,7 @@ class CrawlerRunConfig: Default: None (NoExtractionStrategy is used if None). chunking_strategy (ChunkingStrategy): Strategy to chunk content before extraction. Default: RegexChunking(). + data_persistence_strategy (DataPersistenceStrategy): Strategy for persisting crawl results. Default: SkipDataPersistenceStrategy(). content_filter (RelevantContentFilter or None): Optional filter to prune irrelevant content. Default: None. cache_mode (CacheMode or None): Defines how caching is handled.
@@ -268,11 +270,12 @@ class CrawlerRunConfig: def __init__( self, word_count_threshold: int = MIN_WORD_THRESHOLD , - extraction_strategy : ExtractionStrategy=None, # Will default to NoExtractionStrategy if None - chunking_strategy : ChunkingStrategy= None, # Will default to RegexChunking if None + extraction_strategy : ExtractionStrategy = None, # Will default to NoExtractionStrategy if None + chunking_strategy : ChunkingStrategy = None, # Will default to RegexChunking if None + data_persistence_strategy: DataPersistenceStrategy = SkipDataPersistenceStrategy(), markdown_generator : MarkdownGenerationStrategy = None, - content_filter=None, - cache_mode=None, + content_filter = None, + cache_mode = None, session_id: str = None, bypass_cache: bool = False, disable_cache: bool = False, @@ -285,7 +288,7 @@ def __init__( only_text: bool = False, image_description_min_word_threshold: int = IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD, prettiify: bool = False, - js_code=None, + js_code = None, wait_for: str = None, js_only: bool = False, wait_until: str = "domcontentloaded", @@ -311,6 +314,7 @@ def __init__( self.word_count_threshold = word_count_threshold self.extraction_strategy = extraction_strategy self.chunking_strategy = chunking_strategy + self.data_persistence_strategy = data_persistence_strategy self.markdown_generator = markdown_generator self.content_filter = content_filter self.cache_mode = cache_mode @@ -354,7 +358,9 @@ def __init__( raise ValueError("extraction_strategy must be an instance of ExtractionStrategy") if self.chunking_strategy is not None and not isinstance(self.chunking_strategy, ChunkingStrategy): raise ValueError("chunking_strategy must be an instance of ChunkingStrategy") - + if self.data_persistence_strategy is not None and not isinstance(data_persistence_strategy, DataPersistenceStrategy): + raise ValueError("data_persistence_strategy must be an instance of DataPersistenceStrategy") + # Set default chunking strategy if None if self.chunking_strategy is None: from .chunking_strategy import RegexChunking @@ -367,6 +373,7 @@ def from_kwargs(kwargs: dict) -> "CrawlerRunConfig": word_count_threshold=kwargs.get("word_count_threshold", 200), extraction_strategy=kwargs.get("extraction_strategy"), chunking_strategy=kwargs.get("chunking_strategy"), + data_persistence_strategy=kwargs.get("data_persistence_strategy"), markdown_generator=kwargs.get("markdown_generator"), content_filter=kwargs.get("content_filter"), cache_mode=kwargs.get("cache_mode"), diff --git a/crawl4ai/async_webcrawler.py b/crawl4ai/async_webcrawler.py index 9b96815..b5d1dc9 100644 --- a/crawl4ai/async_webcrawler.py +++ b/crawl4ai/async_webcrawler.py @@ -25,6 +25,7 @@ IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD, URL_LOG_SHORTEN_LENGTH ) +from .data_persistence_strategy import DataPersistenceStrategy, SkipDataPersistenceStrategy from .utils import ( sanitize_input_encode, InvalidCSSSelectorError, @@ -153,6 +154,7 @@ async def arun( word_count_threshold=MIN_WORD_THRESHOLD, extraction_strategy: ExtractionStrategy = None, chunking_strategy: ChunkingStrategy = RegexChunking(), + data_persistence_strategy: DataPersistenceStrategy = SkipDataPersistenceStrategy(), content_filter: RelevantContentFilter = None, cache_mode: Optional[CacheMode] = None, # Deprecated cache parameters @@ -206,7 +208,7 @@ async def arun( if crawler_config is not None: if any(param is not None for param in [ word_count_threshold, extraction_strategy, chunking_strategy, - content_filter, cache_mode, css_selector, screenshot, pdf + data_persistence_strategy, 
content_filter, cache_mode, css_selector, screenshot, pdf ]): self.logger.warning( message="Both crawler_config and legacy parameters provided. crawler_config will take precedence.", @@ -219,6 +221,7 @@ async def arun( "word_count_threshold": word_count_threshold, "extraction_strategy": extraction_strategy, "chunking_strategy": chunking_strategy, + "data_persistence_strategy": data_persistence_strategy, "content_filter": content_filter, "cache_mode": cache_mode, "bypass_cache": bypass_cache, @@ -350,6 +353,9 @@ async def arun( } ) + if config.data_persistence_strategy: + crawl_result.storage_metadata = config.data_persistence_strategy.save(crawl_result) + # Update cache if appropriate if cache_context.should_write() and not bool(cached_result): await async_db_manager.acache_url(crawl_result) @@ -530,6 +536,7 @@ async def arun_many( word_count_threshold=MIN_WORD_THRESHOLD, extraction_strategy: ExtractionStrategy = None, chunking_strategy: ChunkingStrategy = RegexChunking(), + data_persistence_strategy: DataPersistenceStrategy = SkipDataPersistenceStrategy(), content_filter: RelevantContentFilter = None, cache_mode: Optional[CacheMode] = None, bypass_cache: bool = False, @@ -683,4 +690,3 @@ async def aget_cache_size(self): """Get the total number of cached items.""" return await async_db_manager.aget_total_count() - diff --git a/crawl4ai/data_persistence_strategy.py b/crawl4ai/data_persistence_strategy.py new file mode 100644 index 0000000..b982c62 --- /dev/null +++ b/crawl4ai/data_persistence_strategy.py @@ -0,0 +1,152 @@ +from abc import ABC, abstractmethod +from .models import CrawlResult +import json +import re +from datasets import Dataset +from huggingface_hub import DatasetCard +from typing import Any, Optional + + +class DataPersistenceStrategy(ABC): + """ + Abstract base class for implementing data persistence strategies. + """ + + @abstractmethod + def save(self, result: CrawlResult) -> Optional[dict[str, Any]]: + """ + Save the given crawl result using a specific persistence strategy. + + Args: + result (CrawlResult): The crawl result containing data to persist. + + Returns: + Optional[dict[str, Any]]: A dictionary describing the outcome of the persistence operation, or None if nothing was persisted. + """ + pass + + +class SkipDataPersistenceStrategy(DataPersistenceStrategy): + def save(self, result: CrawlResult) -> Optional[dict[str, Any]]: + return None + + +DATASET_CARD_TEMPLATE = """ +--- +tags: +- crawl4ai +- crawl +--- + +**Source of the data:** + +The dataset was generated using the [Crawl4ai](https://crawl4ai.com/mkdocs/) library from {url}. + +""" + + +class HFDataPersistenceStrategy(DataPersistenceStrategy): + """ + A persistence strategy for uploading extracted content or markdown from crawl results to the Hugging Face Hub. + + This strategy converts the extracted content or markdown into a Hugging Face Dataset + and uploads it to a specified repository on the Hub. + + Args: + repo_id (str): The repository ID on the Hugging Face Hub. + private (bool): Whether the repository should be private. + card (str, optional): The card information for the dataset. Defaults to None. + token (str, optional): The authentication token for the Hugging Face Hub. Defaults to None. + verbose (bool, optional): Whether to print progress messages. Defaults to False. + **kwargs: Additional keyword arguments.
+ """ + + def __init__( + self, repo_id: str, private: bool, card: str = None, token=None, **kwargs + ): + self.repo_id = repo_id + self.private = private + self.card = card + self.verbose = kwargs.get("verbose", False) + self.token = token + + def save(self, result: CrawlResult) -> dict[str, Any]: + """ + Uploads extracted content or markdown from the given crawl result to the Hugging Face Hub. + + Args: + result (CrawlResult): The crawl result containing extracted content or markdown to upload. + + Returns: + dict[str, Any]: A dictionary with the repository ID and dataset split name. + + Raises: + ValueError: If neither extracted content nor markdown is present in the result. + TypeError: If extracted content or markdown is not a string. + + Notes: + - Extracted content should be a JSON string containing a list of dictionaries. + - If extracted content is invalid, raw markdown will be used as a fallback. + - The repository ID and dataset split name are returned upon successful upload. + """ + if not (result.extracted_content or result.markdown): + raise ValueError("No extracted content or markdown present.") + + if result.extracted_content and not isinstance(result.extracted_content, str): + raise TypeError("Extracted content must be a string.") + + if result.markdown and not isinstance(result.markdown, str): + raise TypeError("Markdown must be a string.") + + records = self._prepare_records(result) + + if self.verbose: + print( + f"[LOG] 🔄 Successfully converted extracted content to JSON records: {len(records)} records found" + ) + + ds = Dataset.from_list(records) + sanitized_split_name = re.sub(r"[^a-zA-Z0-9_]", "_", result.url) + commit_info = ds.push_to_hub( + repo_id=self.repo_id, + private=self.private, + token=self.token, + split=sanitized_split_name, + ) + + repo_id = commit_info.repo_url.repo_id + self._push_dataset_card(repo_id, result.url) + + if self.verbose: + print( + f"[LOG] ✅ Data has been successfully pushed to the Hugging Face Hub. Repository ID: {repo_id}" + ) + + return {"repo_id": repo_id, "split": sanitized_split_name} + + def _prepare_records(self, result: CrawlResult) -> list[dict[str, Any]]: + if result.extracted_content: + try: + records = json.loads(result.extracted_content) + if not isinstance(records, list) or not all( + isinstance(rec, dict) for rec in records + ): + raise ValueError( + "Extracted content must be a JSON list of dictionaries." 
+ ) + except json.JSONDecodeError as e: + if self.verbose: + print(f"[LOG] ⚠️ Failed to parse extracted content as JSON: {e}") + records = [{"extracted_content": result.extracted_content}] + else: + records = [{"markdown": result.markdown}] + + return records + + def _push_dataset_card(self, repo_id: str, url: str) -> None: + card_content = self.card or DATASET_CARD_TEMPLATE.format(url=url) + DatasetCard(content=card_content).push_to_hub( + repo_id=repo_id, repo_type="dataset", token=self.token + ) + if self.verbose: + print(f"[LOG] 🔄 Dataset card successfully pushed to repository: {repo_id}") diff --git a/crawl4ai/models.py b/crawl4ai/models.py index 315069f..2258c5f 100644 --- a/crawl4ai/models.py +++ b/crawl4ai/models.py @@ -34,6 +34,7 @@ class CrawlResult(BaseModel): session_id: Optional[str] = None response_headers: Optional[dict] = None status_code: Optional[int] = None + storage_metadata: Optional[dict] = None class AsyncCrawlResponse(BaseModel): html: str diff --git a/crawl4ai/web_crawler.py b/crawl4ai/web_crawler.py index a32a988..34479fa 100644 --- a/crawl4ai/web_crawler.py +++ b/crawl4ai/web_crawler.py @@ -11,6 +11,7 @@ from typing import List from concurrent.futures import ThreadPoolExecutor from .content_scraping_strategy import WebScrapingStrategy +from .data_persistence_strategy import DataPersistenceStrategy, SkipDataPersistenceStrategy from .config import * import warnings import json @@ -109,6 +110,7 @@ def run( word_count_threshold=MIN_WORD_THRESHOLD, extraction_strategy: ExtractionStrategy = None, chunking_strategy: ChunkingStrategy = RegexChunking(), + data_persistence_strategy: DataPersistenceStrategy = SkipDataPersistenceStrategy(), bypass_cache: bool = False, css_selector: str = None, screenshot: bool = False, @@ -123,7 +125,9 @@ def run( raise ValueError("Unsupported extraction strategy") if not isinstance(chunking_strategy, ChunkingStrategy): raise ValueError("Unsupported chunking strategy") - + if not isinstance(data_persistence_strategy, DataPersistenceStrategy): + raise ValueError("Unsupported data persistence strategy") + word_count_threshold = max(word_count_threshold, MIN_WORD_THRESHOLD) cached = None @@ -157,6 +161,10 @@ def run( crawl_result = self.process_html(url, html, extracted_content, word_count_threshold, extraction_strategy, chunking_strategy, css_selector, screenshot_data, verbose, bool(cached), **kwargs) crawl_result.success = bool(html) + + if data_persistence_strategy: + crawl_result.storage_metadata = data_persistence_strategy.save(crawl_result) + return crawl_result except Exception as e: if not hasattr(e, "msg"): diff --git a/docs/examples/async_webcrawler_md_to_hf.py b/docs/examples/async_webcrawler_md_to_hf.py new file mode 100644 index 0000000..1e228dd --- /dev/null +++ b/docs/examples/async_webcrawler_md_to_hf.py @@ -0,0 +1,22 @@ +import asyncio +from crawl4ai import AsyncWebCrawler +from crawl4ai.data_persistence_strategy import HFDataPersistenceStrategy + + +async def main(): + async with AsyncWebCrawler(verbose=True) as crawler: + persistence_strategy = HFDataPersistenceStrategy( + repo_id="crawl4ai_hf_page_md", private=False, verbose=True + ) + + result = await crawler.arun( + url="https://huggingface.co/", + data_persistence_strategy=persistence_strategy, + ) + + print(f"Successfully crawled markdown: {result.markdown}") + print(f"Persistence details: {result.storage_metadata}") + + +# Run the async main function +asyncio.run(main()) diff --git a/docs/examples/async_webcrawler_structured_to_hf.py 
b/docs/examples/async_webcrawler_structured_to_hf.py new file mode 100644 index 0000000..c97f437 --- /dev/null +++ b/docs/examples/async_webcrawler_structured_to_hf.py @@ -0,0 +1,67 @@ +import os +import sys +import asyncio +from crawl4ai.extraction_strategy import JsonCssExtractionStrategy +from crawl4ai.data_persistence_strategy import HFDataPersistenceStrategy + +# Add the parent directory to the Python path +parent_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +sys.path.append(parent_dir) + +from crawl4ai.async_webcrawler import AsyncWebCrawler + + +async def main(): + async with AsyncWebCrawler(verbose=True) as crawler: + url = "https://www.nbcnews.com/business" + schema = { + "name": "News Teaser Extractor", + "baseSelector": ".wide-tease-item__wrapper", + "fields": [ + { + "name": "category", + "selector": ".unibrow span[data-testid='unibrow-text']", + "type": "text", + }, + { + "name": "headline", + "selector": ".wide-tease-item__headline", + "type": "text", + }, + { + "name": "summary", + "selector": ".wide-tease-item__description", + "type": "text", + }, + { + "name": "link", + "selector": "a[href]", + "type": "attribute", + "attribute": "href", + }, + ], + } + + extraction_strategy = JsonCssExtractionStrategy(schema, verbose=True) + persistence_strategy = HFDataPersistenceStrategy( + repo_id="crawl4ai_nbcnews_structured", private=False, verbose=True + ) + + result = await crawler.arun( + url=url, + bypass_cache=True, + extraction_strategy=extraction_strategy, + data_persistence_strategy=persistence_strategy, + ) + if result.success: + print(f"Successfully crawled: {result.url}") + print(f"Persistence details: {result.storage_metadata}") + else: + print(f"Failed to crawl: {result.url}") + print(f"Error: {result.error_message}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/requirements.txt b/requirements.txt index 741e12e..ee1e09c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,5 @@ rank-bm25~=0.2 aiofiles>=24.1.0 colorama~=0.4 snowballstemmer~=2.2 -pydantic>=2.10 \ No newline at end of file +pydantic>=2.10 +datasets~=3.1.0 diff --git a/tests/async/test_data_persistance_strategy.py b/tests/async/test_data_persistance_strategy.py new file mode 100644 index 0000000..f64d196 --- /dev/null +++ b/tests/async/test_data_persistance_strategy.py @@ -0,0 +1,93 @@ +import os +import sys +import pytest +from datasets import load_dataset +from crawl4ai.extraction_strategy import JsonCssExtractionStrategy +from crawl4ai.data_persistence_strategy import HFDataPersistenceStrategy + +# Add the parent directory to the Python path +parent_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +sys.path.append(parent_dir) + +from crawl4ai.async_webcrawler import AsyncWebCrawler + + +@pytest.mark.asyncio +async def test_save_with_unsupported_data_persistence_strategy(): + async with AsyncWebCrawler(verbose=True) as crawler: + url = "https://www.nbcnews.com/business" + unsupported_data_persistence_strategy = JsonCssExtractionStrategy({}) + + result = await crawler.arun( + url=url, + bypass_cache=True, + data_persistence_strategy=unsupported_data_persistence_strategy, + ) + assert not result.success + assert result.url == url + assert not result.html + assert "data_persistence_strategy must be an instance of DataPersistenceStrategy" in result.error_message + assert result.storage_metadata is None + + +@pytest.mark.asyncio +async def test_save_to_hf(): + async with 
AsyncWebCrawler(verbose=True) as crawler: + url = "https://www.nbcnews.com/business" + schema = { + "name": "News Teaser Extractor", + "baseSelector": ".wide-tease-item__wrapper", + "fields": [ + { + "name": "category", + "selector": ".unibrow span[data-testid='unibrow-text']", + "type": "text", + }, + { + "name": "headline", + "selector": ".wide-tease-item__headline", + "type": "text", + }, + { + "name": "summary", + "selector": ".wide-tease-item__description", + "type": "text", + }, + { + "name": "link", + "selector": "a[href]", + "type": "attribute", + "attribute": "href", + }, + ], + } + + extraction_strategy = JsonCssExtractionStrategy(schema, verbose=True) + repo_id = "test_repo" + data_persistence_strategy = HFDataPersistenceStrategy( + repo_id=repo_id, private=False, verbose=True + ) + + result = await crawler.arun( + url=url, + bypass_cache=True, + extraction_strategy=extraction_strategy, + data_persistence_strategy=data_persistence_strategy, + ) + assert result.success + assert result.url == url + assert result.html + assert result.markdown + assert result.cleaned_html + assert result.storage_metadata + assert result.storage_metadata["split"] == "https___www_nbcnews_com_business" + created_repo_id = result.storage_metadata["repo_id"] + new_dataset = load_dataset(created_repo_id, split=result.storage_metadata["split"]) + assert len(new_dataset) > 0 + + +# Entry point for debugging +if __name__ == "__main__": + pytest.main([__file__, "-v"])
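
Beyond the bundled Hugging Face strategy, the DataPersistenceStrategy interface is the extension point this patch introduces: any object with a save(result) method that returns a metadata dict (or None) can be passed to arun via data_persistence_strategy, and whatever it returns is stored on CrawlResult.storage_metadata. The sketch below is a minimal, hypothetical illustration of that contract and is not part of the patch; the class name JsonlDataPersistenceStrategy, the output path, and the chosen record fields are illustrative assumptions.

import asyncio
import json
from typing import Any, Optional

from crawl4ai import AsyncWebCrawler
from crawl4ai.data_persistence_strategy import DataPersistenceStrategy
from crawl4ai.models import CrawlResult


class JsonlDataPersistenceStrategy(DataPersistenceStrategy):
    """Hypothetical strategy: append each crawl result to a local JSONL file."""

    def __init__(self, path: str = "crawl_results.jsonl"):
        self.path = path

    def save(self, result: CrawlResult) -> Optional[dict[str, Any]]:
        # Prefer structured extracted content; fall back to the page markdown.
        record = {
            "url": result.url,
            "content": result.extracted_content or result.markdown,
        }
        with open(self.path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")
        # Whatever save() returns ends up on result.storage_metadata.
        return {"path": self.path}


async def main():
    async with AsyncWebCrawler(verbose=True) as crawler:
        result = await crawler.arun(
            url="https://example.com",
            data_persistence_strategy=JsonlDataPersistenceStrategy(),
        )
        print(result.storage_metadata)  # e.g. {"path": "crawl_results.jsonl"}


asyncio.run(main())

For the HF-backed strategy the same field carries the repo_id and the sanitized split name returned by save(), so a pushed dataset can be loaded back with load_dataset(metadata["repo_id"], split=metadata["split"]), as the test above does.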