Webbase documentloader proxysupport #1

Open · wants to merge 10 commits into base: master
121 changes: 121 additions & 0 deletions langchain/document_loaders/rss.py
@@ -0,0 +1,121 @@
"""Loader that fetches a sitemap and loads those URLs."""
import re
import itertools
from typing import Any, Callable, List, Optional, Generator

from langchain.document_loaders.web_base import WebBaseLoader
from langchain.schema import Document

from lxml import etree


def _default_parsing_function_text(content: Any) -> str:
text = ""
if "content" in content:
text = content["content"]
elif "description" in content:
text = content["description"]

return text


def _default_parsing_function_meta(meta: Any) -> dict:
r_meta = dict(meta)
if "content" in r_meta:
del r_meta["content"]

if "description" in r_meta:
del r_meta["description"]

return r_meta

class RssLoader(WebBaseLoader):
"""Loader that fetches a sitemap and loads those URLs."""

    def __init__(
        self,
        web_path: str,
        parsing_function_text: Optional[Callable] = None,
        parsing_function_meta: Optional[Callable] = None,
        header_template: Optional[dict] = None,
    ):
        """Initialize with the RSS feed URL and optional parsing functions.

        Args:
            web_path: url of the RSS feed
            parsing_function_text: function that extracts the page content
                from a parsed feed item
            parsing_function_meta: function that extracts the metadata
                from a parsed feed item
            header_template: optional headers to use when fetching the feed
        """

try:
import lxml # noqa:F401
except ImportError:
raise ValueError(
"lxml package not found, please install it with " "`pip install lxml`"
)

super().__init__(
web_path,
header_template=header_template,
)

self.parsing_function_text = parsing_function_text or _default_parsing_function_text
self.parsing_function_meta = parsing_function_meta or _default_parsing_function_meta

self.namespaces = {
'content': 'http://purl.org/rss/1.0/modules/content/',
'dc':'http://purl.org/dc/elements/1.1/'
}
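        # Mapping from RSS item elements to Document metadata fields; entries
        # marked "multi" are collected into lists, and "html" values are
        # stripped of markup with BeautifulSoup.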
self.fields = [
{"tag": "./link", "field":"source"},
{"tag": "./title", "field":"title"},
{"tag": "./category", "field": "category", "multi": True},
{"tag": "./pubDate", "field":"publication_date"},
{"tag": "./dc:creator", "field": "author"},
{"tag": "./description", "field": "description", "type":"html"},
{"tag": "./content:encoded", "field":"content", "type":"html"},
]
self.items_selector = './channel/item'

    def parse_rss(self, root: Any) -> Generator[dict, None, None]:
        """Parse the RSS XML tree and yield one metadata dict per item."""

        for item in root.findall(self.items_selector):
            meta: dict = {}
            for field in self.fields:
                element_list = item.findall(field["tag"], namespaces=self.namespaces)
                for element in element_list:
                    text = element.text or ""

                    if field.get("type") == "html":
                        soup = BeautifulSoup(text, "html.parser")
                        text = soup.get_text()

                    key = field["field"]
                    if field.get("multi"):
                        meta.setdefault(key, []).append(text)
                    else:
                        meta[key] = text

            yield meta



def load(self) -> List[Document]:
"""Load feeds."""

docs: List[Document] = list()
        for feed in self.web_paths:
            response = self.session.get(feed)
            root = etree.fromstring(response.content)

for item in self.parse_rss(root):
text = self.parsing_function_text(item)
metadata = self.parsing_function_meta(item)

docs.append(Document(page_content=text, metadata=metadata))

return docs
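
For reference, a minimal usage sketch of the new loader (a sketch only, assuming the class is exposed as langchain.document_loaders.rss.RssLoader; the feed URL is a placeholder):

from langchain.document_loaders.rss import RssLoader

# Placeholder feed URL, for illustration only.
loader = RssLoader("https://example.com/feed.xml")
docs = loader.load()

# page_content comes from <content:encoded> or <description>; the remaining
# RSS fields (title, source, author, category, ...) land in metadata.
for doc in docs:
    print(doc.metadata.get("title"), len(doc.page_content))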
49 changes: 46 additions & 3 deletions langchain/document_loaders/sitemap.py
@@ -1,6 +1,10 @@
"""Loader that fetches a sitemap and loads those URLs."""
import re
from typing import Any, Callable, List, Optional
import itertools
from typing import Any, Callable, List, Optional, Iterable, Generator

from aiohttp.helpers import BasicAuth
from aiohttp.typedefs import StrOrURL

from langchain.document_loaders.web_base import WebBaseLoader
from langchain.schema import Document
@@ -9,6 +13,13 @@
def _default_parsing_function(content: Any) -> str:
return str(content.get_text())

def _default_meta_function(meta: dict, _content: Any) -> dict:
    return meta

def _batch_block(iterable: Iterable, size: int) -> Generator[List[dict], None, None]:
    """Yield successive lists of at most ``size`` items from ``iterable``."""
    it = iter(iterable)
    while item := list(itertools.islice(it, size)):
        yield item

class SitemapLoader(WebBaseLoader):
"""Loader that fetches a sitemap and loads those URLs."""
@@ -18,6 +29,13 @@ def __init__(
web_path: str,
filter_urls: Optional[List[str]] = None,
parsing_function: Optional[Callable] = None,
meta_function: Optional[Callable] = None,
header_template: Optional[dict] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
cookies: Optional[dict] = None,
blocksize: Optional[int] = None,
blocknum: Optional[int] = None,
):
"""Initialize with webpage path and optional filter URLs.

@@ -26,6 +44,10 @@
filter_urls: list of strings or regexes that will be applied to filter the
urls that are parsed and loaded
parsing_function: Function to parse bs4.Soup output
proxy: proxy url
proxy_auth: proxy server authentication
            blocksize: number of sitemap locations per block
blocknum: the number of the block that should be loaded - zero indexed
"""

try:
@@ -35,9 +57,19 @@ def __init__(
"lxml package not found, please install it with " "`pip install lxml`"
)

super().__init__(web_path)
super().__init__(
web_path,
proxy=proxy,
proxy_auth=proxy_auth,
cookies=cookies,
header_template=header_template,
)

self.blocksize = blocksize
self.blocknum = blocknum

self.filter_urls = filter_urls
self.meta_function = meta_function or _default_meta_function
self.parsing_function = parsing_function or _default_parsing_function

def parse_sitemap(self, soup: Any) -> List[dict]:
@@ -76,12 +108,23 @@ def load(self) -> List[Document]:

els = self.parse_sitemap(soup)

if self.blocksize is not None and self.blocknum is not None:
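            # Split the sitemap URLs into blocks of `blocksize` and keep only
            # the zero-indexed block `blocknum`.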
total_item_count = len(els)
elblocks = list(_batch_block(els, self.blocksize))
blockcount = len(elblocks)
if blockcount - 1 < self.blocknum:
raise ValueError(
"Selected sitemap does not contain enough blocks for given blocknum"
)
else:
els = elblocks[self.blocknum]

results = self.scrape_all([el["loc"].strip() for el in els if "loc" in el])

return [
Document(
page_content=self.parsing_function(results[i]),
metadata={**{"source": els[i]["loc"]}, **els[i]},
metadata={**{"source": els[i]["loc"]}, **self.meta_function(els[i], results[i])},
)
for i in range(len(results))
]
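
A usage sketch of the new blocksize/blocknum parameters (the sitemap URL matches the one used in the integration tests; the block size is illustrative):

from langchain.document_loaders import SitemapLoader

# Scrape only the second block of 50 URLs from a large sitemap.
loader = SitemapLoader(
    "https://langchain.readthedocs.io/sitemap.xml",
    blocksize=50,
    blocknum=1,  # zero indexed
)
docs = loader.load()
print(len(docs))  # at most 50 documents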
68 changes: 53 additions & 15 deletions langchain/document_loaders/web_base.py
@@ -6,6 +6,8 @@

import aiohttp
import requests
from aiohttp.helpers import BasicAuth
from aiohttp.typedefs import StrOrURL

from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
@@ -47,8 +49,19 @@ class WebBaseLoader(BaseLoader):
default_parser: str = "html.parser"
"""Default parser to use for BeautifulSoup."""

proxy: Optional[StrOrURL] = None
"""aiohttp proxy server"""

proxy_auth: Optional[BasicAuth] = None
"""aio proxy auth"""

def __init__(
self, web_path: Union[str, List[str]], header_template: Optional[dict] = None
self,
web_path: Union[str, List[str]],
header_template: Optional[dict] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
cookies: Optional[dict] = None,
):
"""Initialize with webpage path."""

@@ -61,24 +74,37 @@ def __init__(
self.web_paths = web_path

self.session = requests.Session()
self.proxy = proxy
self.proxy_auth = proxy_auth
try:
import bs4 # noqa:F401
except ImportError:
raise ValueError(
"bs4 package not found, please install it with " "`pip install bs4`"
)

try:
from fake_useragent import UserAgent

headers = header_template or default_header_template
headers["User-Agent"] = UserAgent().random
self.session.headers = dict(headers)
except ImportError:
logger.info(
"fake_useragent not found, using default user agent."
"To get a realistic header for requests, `pip install fake_useragent`."
)
headers = header_template or default_header_template
        if not headers.get("User-Agent"):
try:
from fake_useragent import UserAgent

headers["User-Agent"] = UserAgent().random
except ImportError:
logger.info(
"fake_useragent not found, using default user agent."
"To get a realistic header for requests, `pip install fake_useragent`."
)

self.session.headers = dict(headers)

# Combine cookies
if cookies is None:
cookies = {}
self.session.cookies.update(cookies)

@property
def web_path(self) -> str:
@@ -89,11 +115,16 @@ def web_path(self) -> str:
async def _fetch(
self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
) -> str:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(
cookies=self.session.cookies.get_dict()
) as session:
for i in range(retries):
try:
async with session.get(
url, headers=self.session.headers
url,
headers=self.session.headers,
proxy=self.proxy,
proxy_auth=self.proxy_auth,
) as response:
return await response.text()
except aiohttp.ClientConnectionError as e:
@@ -168,7 +199,14 @@ def _scrape(self, url: str, parser: Union[str, None] = None) -> Any:

self._check_parser(parser)

html_doc = self.session.get(url)
proxies = None
if self.proxy is not None:
            # requests expects plain string proxy URLs for both schemes.
            proxies = {
                "http": str(self.proxy),
                "https": str(self.proxy),
            }

html_doc = self.session.get(url, proxies=proxies)
return BeautifulSoup(html_doc.text, parser)

def scrape(self, parser: Union[str, None] = None) -> Any:
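
A sketch of how the new proxy, proxy_auth, and cookies options might be used (the proxy endpoint and credentials are placeholders; note that proxy_auth is only wired into the aiohttp path here, while the synchronous requests path uses the bare proxy URL):

from aiohttp import BasicAuth

from langchain.document_loaders import WebBaseLoader

# Placeholder proxy endpoint and credentials, for illustration only.
loader = WebBaseLoader(
    "https://python.langchain.com/en/latest/",
    proxy="http://proxy.example.com:8080",
    proxy_auth=BasicAuth("user", "secret"),
    cookies={"session": "abc123"},
)
docs = loader.load()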
3 changes: 2 additions & 1 deletion langchain/vectorstores/pgvector.py
@@ -4,6 +4,7 @@
import enum
import logging
import uuid
import os
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type

import sqlalchemy
@@ -19,7 +20,7 @@
Base = declarative_base() # type: Any


ADA_TOKEN_COUNT = 1536
ADA_TOKEN_COUNT = int(os.getenv("PGVECTOR_ADA_TOKEN_COUNT", default="1536"))
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"


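
Since ADA_TOKEN_COUNT is now read from the environment, the embedding dimension can be overridden before the module is imported; a sketch (1024 is just an example dimension):

import os

# Must be set before langchain.vectorstores.pgvector is imported, because the
# constant is evaluated at module import time.
os.environ["PGVECTOR_ADA_TOKEN_COUNT"] = "1024"

from langchain.vectorstores.pgvector import PGVector  # noqa: E402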
27 changes: 25 additions & 2 deletions tests/integration_tests/document_loaders/test_sitemap.py
@@ -1,5 +1,5 @@
from langchain.document_loaders import SitemapLoader

import pytest

def test_sitemap() -> None:
"""Test sitemap loader."""
@@ -9,11 +9,34 @@ def test_sitemap() -> None:
assert "🦜🔗" in documents[0].page_content


def test_sitemap_block() -> None:
"""Test sitemap loader."""
loader = SitemapLoader("https://langchain.readthedocs.io/sitemap.xml", blocksize=1, blocknum=1)
documents = loader.load()
assert len(documents) == 1
assert "🦜🔗" in documents[0].page_content


def test_sitemap_block_only_one() -> None:
"""Test sitemap loader."""
loader = SitemapLoader("https://langchain.readthedocs.io/sitemap.xml", blocksize=1000000, blocknum=0)
documents = loader.load()
assert len(documents) > 1
assert "🦜🔗" in documents[0].page_content


def test_sitemap_block_does_not_exists() -> None:
"""Test sitemap loader."""
loader = SitemapLoader("https://langchain.readthedocs.io/sitemap.xml", blocksize=1000000, blocknum=15)
with pytest.raises(ValueError):
        loader.load()


def test_filter_sitemap() -> None:
"""Test sitemap loader."""
loader = SitemapLoader(
"https://langchain.readthedocs.io/sitemap.xml",
filter_urls=["https://langchain.readthedocs.io/en/stable/"],
filter_urls=["https://python.langchain.com/en/stable/"],
)
documents = loader.load()
assert len(documents) == 1