Skip to content

Commit

Permalink
community[minor]: Add async methods to CassandraLoader (#20609)
Browse files Browse the repository at this point in the history
Co-authored-by: Eugene Yurtsev <[email protected]>
  • Loading branch information
cbornet and eyurtsev authored Apr 18, 2024
1 parent 8c29b7b commit d2d0137
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 8 deletions.
14 changes: 14 additions & 0 deletions libs/community/langchain_community/document_loaders/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Iterator,
Optional,
Expand All @@ -13,6 +14,7 @@
from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader
from langchain_community.utilities.cassandra import wrapped_response_future

_NOT_SET = object()

Expand Down Expand Up @@ -112,3 +114,15 @@ def lazy_load(self) -> Iterator[Document]:
yield Document(
page_content=self.page_content_mapper(row), metadata=metadata
)

async def alazy_load(self) -> AsyncIterator[Document]:
    """Asynchronously yield one ``Document`` per row returned by the query.

    The query is started with ``self.session.execute_async`` (passing
    ``self.query`` and ``self.query_kwargs``) and awaited through
    ``wrapped_response_future``. Each row is then turned into a document
    using ``self.page_content_mapper`` and ``self.metadata_mapper``.

    Yields:
        A ``Document`` whose metadata is ``self.metadata`` updated with the
        per-row metadata.
    """
    rows = await wrapped_response_future(
        self.session.execute_async,
        self.query,
        **self.query_kwargs,
    )
    for row in rows:
        # Work on a copy so the loader-level metadata is never mutated.
        doc_metadata = self.metadata.copy()
        doc_metadata.update(self.metadata_mapper(row))
        content = self.page_content_mapper(row)
        yield Document(page_content=content, metadata=doc_metadata)
24 changes: 24 additions & 0 deletions libs/community/langchain_community/utilities/cassandra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Any, Callable

if TYPE_CHECKING:
from cassandra.cluster import ResponseFuture


async def wrapped_response_future(
    func: Callable[..., "ResponseFuture"], *args: Any, **kwargs: Any
) -> Any:
    """Start a request via ``func`` and await its response future.

    ``func`` is called with ``*args`` and ``**kwargs`` and must return an
    object exposing ``add_callbacks(callback, errback)`` and ``result()``.
    The outcome reported through those callbacks is forwarded to an asyncio
    future bound to the running event loop.

    Args:
        func: Callable that issues the request and returns the response
            future (for example ``session.execute_async``).
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The value of ``response_future.result()`` once the success callback
        has fired.

    Raises:
        BaseException: Whatever exception the response future hands to the
            error callback.
    """
    # This is a coroutine, so a loop is necessarily running; the asyncio docs
    # prefer get_running_loop() over get_event_loop() in that situation.
    loop = asyncio.get_running_loop()
    asyncio_future = loop.create_future()
    response_future = func(*args, **kwargs)

    def _settle(setter: Callable[[Any], None], value: Any) -> None:
        # Runs on the event loop. If the awaiting task was cancelled in the
        # meantime the future is already done, and setting it again would
        # raise InvalidStateError inside the loop.
        if not asyncio_future.done():
            setter(value)

    def success_handler(_: Any) -> None:
        # The callback argument is ignored on purpose: the value handed to the
        # awaiter is whatever response_future.result() returns.
        loop.call_soon_threadsafe(
            _settle, asyncio_future.set_result, response_future.result()
        )

    def error_handler(exc: BaseException) -> None:
        loop.call_soon_threadsafe(_settle, asyncio_future.set_exception, exc)

    # NOTE(review): callbacks are presumably invoked from a driver thread,
    # hence call_soon_threadsafe above — confirm against the driver docs.
    response_future.add_callbacks(success_handler, error_handler)
    return await asyncio_future
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def keyspace() -> Iterator[str]:
session.execute(f"DROP TABLE IF EXISTS {keyspace}.{CASSANDRA_TABLE}")


def test_loader_table(keyspace: str) -> None:
async def test_loader_table(keyspace: str) -> None:
loader = CassandraLoader(table=CASSANDRA_TABLE)
assert loader.load() == [
expected = [
Document(
page_content="Row(row_id='id1', body_blob='text1')",
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
Expand All @@ -67,24 +67,28 @@ def test_loader_table(keyspace: str) -> None:
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
),
]
assert loader.load() == expected
assert await loader.aload() == expected


def test_loader_query(keyspace: str) -> None:
async def test_loader_query(keyspace: str) -> None:
loader = CassandraLoader(
query=f"SELECT body_blob FROM {keyspace}.{CASSANDRA_TABLE}"
)
assert loader.load() == [
expected = [
Document(page_content="Row(body_blob='text1')"),
Document(page_content="Row(body_blob='text2')"),
]
assert loader.load() == expected
assert await loader.aload() == expected


def test_loader_page_content_mapper(keyspace: str) -> None:
async def test_loader_page_content_mapper(keyspace: str) -> None:
def mapper(row: Any) -> str:
return str(row.body_blob)

loader = CassandraLoader(table=CASSANDRA_TABLE, page_content_mapper=mapper)
assert loader.load() == [
expected = [
Document(
page_content="text1",
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
Expand All @@ -94,14 +98,16 @@ def mapper(row: Any) -> str:
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
),
]
assert loader.load() == expected
assert await loader.aload() == expected


def test_loader_metadata_mapper(keyspace: str) -> None:
async def test_loader_metadata_mapper(keyspace: str) -> None:
def mapper(row: Any) -> dict:
return {"id": row.row_id}

loader = CassandraLoader(table=CASSANDRA_TABLE, metadata_mapper=mapper)
assert loader.load() == [
expected = [
Document(
page_content="Row(row_id='id1', body_blob='text1')",
metadata={
Expand All @@ -119,3 +125,5 @@ def mapper(row: Any) -> dict:
},
),
]
assert loader.load() == expected
assert await loader.aload() == expected

0 comments on commit d2d0137

Please sign in to comment.