Commit 52bfeba: style

mike-gee committed Oct 30, 2023
1 parent 14d84df
Showing 5 changed files with 82 additions and 73 deletions.
1 change: 0 additions & 1 deletion src/webtranspose/consts.py

This file was deleted.

115 changes: 62 additions & 53 deletions src/webtranspose/crawl.py
@@ -9,6 +9,7 @@
import zipfile
from datetime import datetime
from fnmatch import fnmatch
from typing import Dict, List, Optional
from urllib.parse import urljoin, urlparse, urlunparse

import httpx
@@ -20,17 +21,17 @@
class Crawl:
def __init__(
self,
url,
allowed_urls=[],
banned_urls=[],
n_workers=1,
max_pages=15,
render_js=False,
output_dir="webtranspose-out",
verbose=False,
api_key=None,
_created=False,
):
url: str,
allowed_urls: List[str] = [],
banned_urls: List[str] = [],
n_workers: int = 1,
max_pages: int = 15,
render_js: bool = False,
output_dir: str = "webtranspose-out",
verbose: bool = False,
api_key: Optional[str] = None,
_created: bool = False,
) -> None:
"""
Initialize the Crawl object.
@@ -74,19 +75,19 @@ def __init__(

@staticmethod
async def crawl_worker(
name,
queue,
crawl_id,
visited_urls,
allowed_urls,
banned_urls,
output_dir,
base_url,
max_pages,
leftover_queue,
ignored_queue,
verbose,
):
name: str,
queue: asyncio.Queue,
crawl_id: str,
visited_urls: Dict[str, str],
allowed_urls: List[str],
banned_urls: List[str],
output_dir: str,
base_url: str,
max_pages: int,
leftover_queue: asyncio.Queue,
ignored_queue: asyncio.Queue,
verbose: bool,
) -> None:
"""
Worker function for crawling URLs.
@@ -104,7 +105,7 @@ async def crawl_worker(
:param verbose: Whether to print verbose logging messages.
"""

def _lint_url(url):
def _lint_url(url: str) -> str:
"""
Lint the given URL by removing the fragment component.
@@ -200,7 +201,7 @@ def _lint_url(url):

queue.task_done()

async def create_crawl_api(self):
def create_crawl_api(self):
"""
Creates a Crawl on https://webtranspose.com
"""
@@ -219,7 +220,7 @@ async def create_crawl_api(self):
self.crawl_id = out_json["crawl_id"]
self.created = True

async def queue_crawl(self):
def queue_crawl(self):
"""
Resume crawling of Crawl object. Don't wait for it to finish crawling.
"""
@@ -228,21 +229,23 @@ async def queue_crawl(self):

else:
if not self.created:
await self.create_crawl_api()
self.create_crawl_api()
queue_json = {
"crawl_id": self.crawl_id,
}
run_webt_api(
out = run_webt_api(
queue_json,
"v1/crawl/resume",
self.api_key,
)

print(out)

async def crawl(self):
"""
Resume crawling of Crawl object.
"""
if self.verbose:
logging.info(f"Starting crawl of {self.base_url}")
if self.api_key is None:
leftover_queue = asyncio.Queue()
ignored_queue = asyncio.Queue()
@@ -274,14 +277,18 @@ async def crawl(self):
self.ignored_urls = list(ignored_queue._queue)
self.to_metadata()
else:
await self.queue_crawl()
self.queue_crawl()
status = self.status()
while status['num_queued'] > 0 and status['num_visited'] < status['max_pages']:
while status["num_queued"] + status["num_visited"] + status["num_ignored"] == 0:
await asyncio.sleep(5)
status = self.status()

while status["num_queued"] > 0 and status["num_visited"] < status["max_pages"]:
await asyncio.sleep(5)
status = self.status()
return self

def get_queue(self, n=10):
def get_queue(self, n: int = 10) -> list:
"""
Get a list of URLs from the queue.
@@ -316,7 +323,7 @@ def get_queue(self, n=10):
)
return out_json["urls"]

def set_allowed_urls(self, allowed_urls):
def set_allowed_urls(self, allowed_urls: list) -> "Crawl":
"""
Set the allowed URLs for the crawl.
@@ -341,7 +348,7 @@ def set_allowed_urls(self, allowed_urls):
)
return self

def set_banned_urls(self, banned_urls):
def set_banned_urls(self, banned_urls: list) -> "Crawl":
"""
Set the banned URLs for the crawl.
@@ -351,7 +358,7 @@ def set_banned_urls(self, banned_urls):
Returns:
self: The Crawl object.
"""
self.banned_urls = banned_urls
self.banned_urls = banned_urls
if not self.created:
self.to_metadata()
else:
@@ -366,7 +373,7 @@ def set_banned_urls(self, banned_urls):
)
return self

def get_filename(self, url):
def get_filename(self, url: str) -> str:
"""
Get the filename associated with a visited URL.
@@ -384,7 +391,7 @@ def get_filename(self, url):
except KeyError:
raise ValueError(f"URL {url} not found in visited URLs")

def set_max_pages(self, max_pages):
def set_max_pages(self, max_pages: int) -> "Crawl":
"""
Set the maximum number of pages to crawl.
@@ -409,7 +416,7 @@ def set_max_pages(self, max_pages):
)
return self

def status(self):
def status(self) -> dict:
"""
Get the status of the Crawl object.
@@ -441,8 +448,8 @@ def status(self):
)
crawl_status["loc"] = "cloud"
return crawl_status
def get_ignored(self):

def get_ignored(self) -> list:
"""
Get a list of ignored URLs.
@@ -462,7 +469,7 @@ def get_ignored(self):
)
return out_json["pages"]

def get_visited(self):
def get_visited(self) -> list:
"""
Get a list of visited URLs.
@@ -482,7 +489,7 @@ def get_visited(self):
)
return out_json["pages"]

def get_banned(self):
def get_banned(self) -> list:
"""
Get a list of banned URLs.
@@ -539,9 +546,10 @@ def download(self):
filename = urllib.parse.quote_plus(url).replace("/", "_")
filepath = os.path.join(base_dir, filename) + ".json"
shutil.move(json_file, filepath)

logging.info(f"The output of the crawl can be found at: {self.output_dir}")

def to_metadata(self):
def to_metadata(self) -> None:
"""
Save the metadata of the Crawl object to a file.
"""
@@ -564,7 +572,7 @@ def to_metadata(self):
json.dump(metadata, file)

@staticmethod
def from_metadata(crawl_id, output_dir="webtranspose-out"):
def from_metadata(crawl_id: str, output_dir: str = "webtranspose-out") -> "Crawl":
"""
Create a Crawl object from metadata stored in a file.
@@ -596,7 +604,7 @@ def from_metadata(crawl_id, output_dir="webtranspose-out"):
return crawl

@staticmethod
def from_cloud(crawl_id, api_key=None):
def from_cloud(crawl_id: str, api_key: Optional[str] = None) -> "Crawl":
"""
Create a Crawl object from metadata stored in the cloud.
@@ -631,7 +639,7 @@ def from_cloud(crawl_id, api_key=None):
"API key not found. Please set WEBTRANSPOSE_API_KEY environment variable or pass api_key argument."
)

def status(self):
def status(self) -> dict:
"""
Get the status of the Crawl object.
@@ -661,7 +669,7 @@ def status(self):
)
return crawl_status

def __str__(self):
def __str__(self) -> str:
"""
Get a string representation of the Crawl object.
@@ -683,7 +691,7 @@ def __str__(self):
f")"
)

def __repr__(self):
def __repr__(self) -> str:
"""
Get a string representation of the Crawl object.
@@ -705,7 +713,7 @@ def __repr__(self):
f")"
)

def get_page(self, url):
def get_page(self, url: str) -> dict:
"""
Get the page data for a given URL.
@@ -735,7 +743,7 @@ def get_page(self, url):
)
return out_json

def get_child_urls(self, url):
def get_child_urls(self, url: str) -> list:
"""
Get the child URLs for a given URL.
@@ -770,12 +778,13 @@ def get_child_urls(self, url):
return out_json


def get_crawl(crawl_id, api_key=None):
def get_crawl(crawl_id: str, api_key: Optional[str] = None) -> Crawl:
"""
Get a Crawl object based on the crawl ID.
Args:
crawl_id (str): The ID of the crawl.
api_key (str, optional): The API key. Defaults to None.
Returns:
Crawl: The Crawl object.
@@ -786,7 +795,7 @@ def get_crawl(crawl_id, api_key=None):
return Crawl.from_cloud(crawl_id, api_key=api_key)


def list_crawls(loc="cloud", api_key=None):
def list_crawls(loc: str = "cloud", api_key: Optional[str] = None) -> list:
"""
List all available crawls.
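Taken together, the crawl.py changes add type annotations throughout and make create_crawl_api and queue_crawl plain synchronous methods, while crawl() remains a coroutine that, in cloud mode, now waits for the run to show progress before polling it. A minimal usage sketch under those assumptions — the start URL and the allowed_urls pattern are placeholders, and only constructor arguments visible in this diff are used:

```python
import asyncio

from webtranspose.crawl import Crawl  # module path as shown in this diff


async def main() -> None:
    crawl = Crawl(
        "https://example.com",                        # placeholder start URL
        allowed_urls=["https://example.com/docs/*"],  # placeholder fnmatch-style pattern
        max_pages=15,
        verbose=True,
        api_key=None,  # None runs the local asyncio crawler; an API key runs on the cloud
    )
    # crawl() is still async: locally it drives the worker tasks, while in cloud
    # mode it calls the now-synchronous queue_crawl() and then polls status()
    # every 5 seconds until the queue drains or max_pages is reached.
    await crawl.crawl()
    print(crawl.status())


asyncio.run(main())
```

Dropping async from queue_crawl also means cloud callers no longer need an event loop just to enqueue a run; only the crawling and polling path in crawl() stays asynchronous.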
12 changes: 7 additions & 5 deletions src/webtranspose/openai.py
@@ -8,8 +8,8 @@
class OpenAIScraper:
def __init__(
self,
chunk_size=2500,
overlap_size=100,
chunk_size: int = 2500,
overlap_size: int = 100,
):
"""
Initialize the OpenAIScraper.
@@ -24,7 +24,7 @@ def __init__(
self.overlap_size = overlap_size

@staticmethod
def process_html(text, chunk_size, overlap_size, encoding):
def process_html(
text: str, chunk_size: int, overlap_size: int, encoding: tiktoken.Encoding
) -> list:
"""
Process the HTML text into chunks.
Expand All @@ -49,7 +51,7 @@ def process_html(text, chunk_size, overlap_size, encoding):
decoded_chunks = [encoding.decode(chunk) for chunk in chunks]
return decoded_chunks

def scrape(self, html, schema):
def scrape(self, html: str, schema: dict) -> dict:
"""
Scrape the HTML text using the provided schema.
@@ -106,7 +108,7 @@ def scrape(self, html, schema):

return out_data

def transform_schema(self, schema):
def transform_schema(self, schema: dict) -> dict:
"""
Transform the schema into the format required by OpenAI.
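In openai.py the commit only adds type hints, but the annotated process_html signature spells out the approach: the page text is encoded with tiktoken and cut into chunk_size-token pieces that overlap by overlap_size tokens before scraping. The exact stride is not visible in this diff, so the sketch below assumes a simple sliding window that advances by chunk_size - overlap_size tokens and a cl100k_base tokenizer; both are assumptions rather than confirmed library behavior:

```python
import tiktoken


def chunk_tokens(text: str, chunk_size: int = 2500, overlap_size: int = 100) -> list:
    """Split text into overlapping token chunks (defaults mirror OpenAIScraper)."""
    encoding = tiktoken.get_encoding("cl100k_base")  # tokenizer choice is an assumption
    tokens = encoding.encode(text)
    step = chunk_size - overlap_size  # each chunk re-reads the last overlap_size tokens
    chunks = [tokens[i : i + chunk_size] for i in range(0, len(tokens), step)]
    return [encoding.decode(chunk) for chunk in chunks]  # mirrors the decode step in the diff


html = "<p>example paragraph</p>" * 1000  # stand-in for scraped page HTML
print(len(chunk_tokens(html)), "overlapping chunks")
```

Keeping overlap_size small relative to chunk_size trades a few duplicated tokens per boundary for continuity, so content that straddles a chunk edge is still seen whole by at least one chunk.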