Adding save-to-HF support for async webcrawler #312

Open · wants to merge 1 commit into main
19 changes: 13 additions & 6 deletions crawl4ai/async_configs.py
@@ -8,6 +8,7 @@
from .extraction_strategy import ExtractionStrategy
from .chunking_strategy import ChunkingStrategy
from .markdown_generation_strategy import MarkdownGenerationStrategy
from .data_persistence_strategy import DataPersistenceStrategy, SkipDataPersistenceStrategy

class BrowserConfig:
"""
@@ -188,6 +189,7 @@ class CrawlerRunConfig:
Default: None (NoExtractionStrategy is used if None).
chunking_strategy (ChunkingStrategy): Strategy to chunk content before extraction.
Default: RegexChunking().
data_persistence_strategy (DataPersistenceStrategy): Strategy for persisting crawl results.
Default: SkipDataPersistenceStrategy() (results are not persisted).
content_filter (RelevantContentFilter or None): Optional filter to prune irrelevant content.
Default: None.
cache_mode (CacheMode or None): Defines how caching is handled.
@@ -268,11 +270,12 @@ class CrawlerRunConfig:
def __init__(
self,
word_count_threshold: int = MIN_WORD_THRESHOLD ,
extraction_strategy : ExtractionStrategy=None, # Will default to NoExtractionStrategy if None
chunking_strategy : ChunkingStrategy= None, # Will default to RegexChunking if None
extraction_strategy : ExtractionStrategy = None, # Will default to NoExtractionStrategy if None
chunking_strategy : ChunkingStrategy = None, # Will default to RegexChunking if None
data_persistence_strategy: DataPersistenceStrategy = SkipDataPersistenceStrategy(),
markdown_generator : MarkdownGenerationStrategy = None,
content_filter=None,
cache_mode=None,
content_filter = None,
cache_mode = None,
session_id: str = None,
bypass_cache: bool = False,
disable_cache: bool = False,
@@ -285,7 +288,7 @@ def __init__(
only_text: bool = False,
image_description_min_word_threshold: int = IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
prettiify: bool = False,
js_code=None,
js_code = None,
wait_for: str = None,
js_only: bool = False,
wait_until: str = "domcontentloaded",
@@ -311,6 +314,7 @@ def __init__(
self.word_count_threshold = word_count_threshold
self.extraction_strategy = extraction_strategy
self.chunking_strategy = chunking_strategy
self.data_persistence_strategy = data_persistence_strategy
self.markdown_generator = markdown_generator
self.content_filter = content_filter
self.cache_mode = cache_mode
@@ -354,7 +358,9 @@ def __init__(
raise ValueError("extraction_strategy must be an instance of ExtractionStrategy")
if self.chunking_strategy is not None and not isinstance(self.chunking_strategy, ChunkingStrategy):
raise ValueError("chunking_strategy must be an instance of ChunkingStrategy")

if self.data_persistence_strategy is not None and not isinstance(self.data_persistence_strategy, DataPersistenceStrategy):
raise ValueError("data_persistence_strategy must be an instance of DataPersistenceStrategy")

# Set default chunking strategy if None
if self.chunking_strategy is None:
from .chunking_strategy import RegexChunking
@@ -367,6 +373,7 @@ def from_kwargs(kwargs: dict) -> "CrawlerRunConfig":
word_count_threshold=kwargs.get("word_count_threshold", 200),
extraction_strategy=kwargs.get("extraction_strategy"),
chunking_strategy=kwargs.get("chunking_strategy"),
data_persistence_strategy=kwargs.get("data_persistence_strategy"),
markdown_generator=kwargs.get("markdown_generator"),
content_filter=kwargs.get("content_filter"),
cache_mode=kwargs.get("cache_mode"),
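For reference, a minimal sketch of how the new option is carried by CrawlerRunConfig. The repo_id below is illustrative, and the config object is assumed to be passed to AsyncWebCrawler.arun() as crawler_config, which the precedence warning below shows takes priority over the legacy keyword arguments:

from crawl4ai.async_configs import CrawlerRunConfig
from crawl4ai.data_persistence_strategy import HFDataPersistenceStrategy

# Build a run config whose results are pushed to the Hugging Face Hub.
config = CrawlerRunConfig(
    data_persistence_strategy=HFDataPersistenceStrategy(
        repo_id="my-username/crawl4ai-results",  # illustrative repository ID
        private=True,
    ),
)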
10 changes: 8 additions & 2 deletions crawl4ai/async_webcrawler.py
@@ -25,6 +25,7 @@
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
URL_LOG_SHORTEN_LENGTH
)
from .data_persistence_strategy import DataPersistenceStrategy, SkipDataPersistenceStrategy
from .utils import (
sanitize_input_encode,
InvalidCSSSelectorError,
@@ -153,6 +154,7 @@ async def arun(
word_count_threshold=MIN_WORD_THRESHOLD,
extraction_strategy: ExtractionStrategy = None,
chunking_strategy: ChunkingStrategy = RegexChunking(),
data_persistence_strategy: DataPersistenceStrategy = SkipDataPersistenceStrategy(),
content_filter: RelevantContentFilter = None,
cache_mode: Optional[CacheMode] = None,
# Deprecated cache parameters
@@ -206,7 +208,7 @@ async def arun(
if crawler_config is not None:
if any(param is not None for param in [
word_count_threshold, extraction_strategy, chunking_strategy,
content_filter, cache_mode, css_selector, screenshot, pdf
data_persistence_strategy, content_filter, cache_mode, css_selector, screenshot, pdf
]):
self.logger.warning(
message="Both crawler_config and legacy parameters provided. crawler_config will take precedence.",
@@ -219,6 +221,7 @@
"word_count_threshold": word_count_threshold,
"extraction_strategy": extraction_strategy,
"chunking_strategy": chunking_strategy,
"data_persistence_strategy": data_persistence_strategy,
"content_filter": content_filter,
"cache_mode": cache_mode,
"bypass_cache": bypass_cache,
@@ -350,6 +353,9 @@ async def arun(
}
)

if config.data_persistence_strategy:
crawl_result.storage_metadata = config.data_persistence_strategy.save(crawl_result)

# Update cache if appropriate
if cache_context.should_write() and not bool(cached_result):
await async_db_manager.acache_url(crawl_result)
@@ -530,6 +536,7 @@ async def arun_many(
word_count_threshold=MIN_WORD_THRESHOLD,
extraction_strategy: ExtractionStrategy = None,
chunking_strategy: ChunkingStrategy = RegexChunking(),
data_persistence_strategy: DataPersistenceStrategy = SkipDataPersistenceStrategy(),
content_filter: RelevantContentFilter = None,
cache_mode: Optional[CacheMode] = None,
bypass_cache: bool = False,
@@ -683,4 +690,3 @@ async def aget_cache_size(self):
"""Get the total number of cached items."""
return await async_db_manager.aget_total_count()


152 changes: 152 additions & 0 deletions crawl4ai/data_persistence_strategy.py
@@ -0,0 +1,152 @@
from abc import ABC, abstractmethod
from .models import CrawlResult
import json
import re
from datasets import Dataset
from huggingface_hub import DatasetCard
from typing import Any, Optional


class DataPersistenceStrategy(ABC):
"""
Abstract base class for implementing data persistence strategies.
"""

@abstractmethod
def save(self, result: CrawlResult) -> Optional[dict[str, Any]]:
"""
Save the given crawl result using a specific persistence strategy.

Args:
result (CrawlResult): The crawl result containing data to persist.

Returns:
Optional[dict[str, Any]]: A dictionary describing the outcome of the persistence operation, or None if nothing was persisted.
"""
pass


class SkipDataPersistenceStrategy(DataPersistenceStrategy):
def save(self, result: CrawlResult) -> Optional[dict[str, Any]]:
return None


DATASET_CARD_TEMPLATE = """
---
tags:
- crawl4ai
- crawl
---

**Source of the data:**

The dataset was generated with the [Crawl4ai](https://crawl4ai.com/mkdocs/) library from {url}.

"""


class HFDataPersistenceStrategy(DataPersistenceStrategy):
"""
A persistence strategy for uploading extracted content or markdown from crawl results to the Hugging Face Hub.

This strategy converts the extracted content or markdown into a Hugging Face Dataset
and uploads it to a specified repository on the Hub.

Args:
repo_id (str): The repository ID on the Hugging Face Hub.
private (bool): Whether the repository should be private.
card (str, optional): Markdown content for the dataset card. Defaults to None, in which case a template card referencing the crawled URL is pushed.
token (str, optional): The authentication token for the Hugging Face Hub. Defaults to None.
verbose (bool, optional): Whether to print progress messages; read from **kwargs. Defaults to False.
**kwargs: Additional keyword arguments.
"""

def __init__(
self, repo_id: str, private: bool, card: str = None, token=None, **kwargs
):
self.repo_id = repo_id
self.private = private
self.card = card
self.verbose = kwargs.get("verbose", False)
self.token = token

def save(self, result: CrawlResult) -> dict[str, Any]:
"""
Uploads extracted content or markdown from the given crawl result to the Hugging Face Hub.

Args:
result (CrawlResult): The crawl result containing extracted content or markdown to upload.

Returns:
dict[str, Any]: A dictionary with the repository ID and dataset split name.

Raises:
ValueError: If neither extracted content nor markdown is present in the result.
TypeError: If extracted content or markdown is not a string.

Notes:
- Extracted content should be a JSON string containing a list of dictionaries.
- If extracted content is invalid, raw markdown will be used as a fallback.
- The repository ID and dataset split name are returned upon successful upload.
"""
if not (result.extracted_content or result.markdown):
raise ValueError("No extracted content or markdown present.")

if result.extracted_content and not isinstance(result.extracted_content, str):
raise TypeError("Extracted content must be a string.")

if result.markdown and not isinstance(result.markdown, str):
raise TypeError("Markdown must be a string.")

records = self._prepare_records(result)

if self.verbose:
print(
f"[LOG] 🔄 Successfully converted extracted content to JSON records: {len(records)} records found"
)

ds = Dataset.from_list(records)
sanitized_split_name = re.sub(r"[^a-zA-Z0-9_]", "_", result.url)
commit_info = ds.push_to_hub(
repo_id=self.repo_id,
private=self.private,
token=self.token,
split=sanitized_split_name,
)

repo_id = commit_info.repo_url.repo_id
self._push_dataset_card(repo_id, result.url)

if self.verbose:
print(
f"[LOG] ✅ Data has been successfully pushed to the Hugging Face Hub. Repository ID: {repo_id}"
)

return {"repo_id": repo_id, "split": sanitized_split_name}

def _prepare_records(self, result: CrawlResult) -> list[dict[str, Any]]:
if result.extracted_content:
try:
records = json.loads(result.extracted_content)
if not isinstance(records, list) or not all(
isinstance(rec, dict) for rec in records
):
raise ValueError(
"Extracted content must be a JSON list of dictionaries."
)
except json.JSONDecodeError as e:
if self.verbose:
print(f"[LOG] ⚠️ Failed to parse extracted content as JSON: {e}")
records = [{"extracted_content": result.extracted_content}]
else:
records = [{"markdown": result.markdown}]

return records

def _push_dataset_card(self, repo_id: str, url: str) -> None:
card_content = self.card or DATASET_CARD_TEMPLATE.format(url=url)
DatasetCard(content=card_content).push_to_hub(
repo_id=repo_id, repo_type="dataset", token=self.token
)
if self.verbose:
print(f"[LOG] 🔄 Dataset card successfully pushed to repository: {repo_id}")
1 change: 1 addition & 0 deletions crawl4ai/models.py
@@ -34,6 +34,7 @@ class CrawlResult(BaseModel):
session_id: Optional[str] = None
response_headers: Optional[dict] = None
status_code: Optional[int] = None
storage_metadata: Optional[dict] = None

class AsyncCrawlResponse(BaseModel):
html: str
10 changes: 9 additions & 1 deletion crawl4ai/web_crawler.py
@@ -11,6 +11,7 @@
from typing import List
from concurrent.futures import ThreadPoolExecutor
from .content_scraping_strategy import WebScrapingStrategy
from .data_persistence_strategy import DataPersistenceStrategy, SkipDataPersistenceStrategy
from .config import *
import warnings
import json
@@ -109,6 +110,7 @@ def run(
word_count_threshold=MIN_WORD_THRESHOLD,
extraction_strategy: ExtractionStrategy = None,
chunking_strategy: ChunkingStrategy = RegexChunking(),
data_persistence_strategy: DataPersistenceStrategy = SkipDataPersistenceStrategy(),
bypass_cache: bool = False,
css_selector: str = None,
screenshot: bool = False,
@@ -123,7 +125,9 @@
raise ValueError("Unsupported extraction strategy")
if not isinstance(chunking_strategy, ChunkingStrategy):
raise ValueError("Unsupported chunking strategy")

if not isinstance(data_persistence_strategy, DataPersistenceStrategy):
raise ValueError("Unsupported data persistence strategy")

word_count_threshold = max(word_count_threshold, MIN_WORD_THRESHOLD)

cached = None
@@ -157,6 +161,10 @@

crawl_result = self.process_html(url, html, extracted_content, word_count_threshold, extraction_strategy, chunking_strategy, css_selector, screenshot_data, verbose, bool(cached), **kwargs)
crawl_result.success = bool(html)

if data_persistence_strategy:
crawl_result.storage_metadata = data_persistence_strategy.save(crawl_result)

return crawl_result
except Exception as e:
if not hasattr(e, "msg"):
22 changes: 22 additions & 0 deletions docs/examples/async_webcrawler_md_to_hf.py
@@ -0,0 +1,22 @@
import asyncio
from crawl4ai import AsyncWebCrawler
from crawl4ai.data_persistence_strategy import HFDataPersistenceStrategy


async def main():
async with AsyncWebCrawler(verbose=True) as crawler:
persistence_strategy = HFDataPersistenceStrategy(
repo_id="crawl4ai_hf_page_md", private=False, verbose=True
)

result = await crawler.arun(
url="https://huggingface.co/",
data_persistence_strategy=persistence_strategy,
)

print(f"Successfully crawled markdown: {result.markdown}")
print(f"Persistence details: {result.storage_metadata}")


# Run the async main function
asyncio.run(main())
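For a private dataset, a variant of the same example that authenticates explicitly; the token value is a placeholder, and when token is omitted huggingface_hub generally falls back to the cached login or the HF_TOKEN environment variable:

persistence_strategy = HFDataPersistenceStrategy(
    repo_id="crawl4ai_hf_page_md",
    private=True,
    token="hf_xxx",  # placeholder access token
    verbose=True,
)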