Deeplake tests, typing
arjbingly committed Mar 22, 2024
1 parent 2f05d98 commit 41b2bcf
Showing 5 changed files with 220 additions and 38 deletions.
24 changes: 18 additions & 6 deletions src/grag/components/vectordb/base.py
@@ -1,13 +1,23 @@
from abc import ABC, abstractmethod
from typing import List
from typing import List, Tuple, Union

from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain_core.documents import Document


class VectorDB(ABC):

@abstractmethod
def __len__(self) -> int:
"""Number of chunks in the vector database."""
...

@abstractmethod
def delete(self) -> None:
"""Delete all chunks in the vector database."""

@abstractmethod
def add_docs(self, docs: List[Document], verbose: bool = True):
def add_docs(self, docs: List[Document], verbose: bool = True) -> None:
"""Adds documents to the vector database.
Args:
@@ -20,7 +30,7 @@ def add_docs(self, docs: List[Document], verbose: bool = True):
...

@abstractmethod
async def aadd_docs(self, docs: List[Document], verbose: bool = True):
async def aadd_docs(self, docs: List[Document], verbose: bool = True) -> None:
"""Adds documents to the vector database (asynchronous).
Args:
@@ -33,7 +43,8 @@ async def aadd_docs(self, docs: List[Document], verbose: bool = True):
...

@abstractmethod
def get_chunk(self, query: str, with_score: bool = False, top_k: int = None):
def get_chunk(self, query: str, with_score: bool = False, top_k: int = None) -> Union[
List[Document], List[Tuple[Document, float]]]:
"""Returns the most similar chunks from the vector database.
Args:
@@ -47,7 +58,8 @@ def get_chunk(self, query: str, with_score: bool = False, top_k: int = None):
...

@abstractmethod
async def aget_chunk(self, query: str, with_score: bool = False, top_k: int = None):
async def aget_chunk(self, query: str, with_score: bool = False, top_k: int = None) -> Union[
List[Document], List[Tuple[Document, float]]]:
"""Returns the most similar chunks from the vector database. (asynchronous)
Args:
@@ -60,5 +72,5 @@ async def aget_chunk(self, query: str, with_score: bool = False, top_k: int = No
"""
...

def _filter_metadata(self, docs: List[Document]):
def _filter_metadata(self, docs: List[Document]) -> List[Document]:
return filter_complex_metadata(docs, allowed_types=self.allowed_metadata_types)
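
A minimal sketch (not part of this commit) of how a concrete subclass could satisfy the updated VectorDB interface; the InMemoryVectorDB class and its internals are illustrative assumptions, and only the method signatures and the _filter_metadata helper come from base.py above.

from typing import List, Tuple, Union

from langchain_core.documents import Document

from grag.components.vectordb.base import VectorDB


class InMemoryVectorDB(VectorDB):
    """Toy in-memory store, for illustration only."""

    def __init__(self):
        self.allowed_metadata_types = (str, int, float, bool)
        self._docs: List[Document] = []

    def __len__(self) -> int:
        return len(self._docs)

    def delete(self) -> None:
        self._docs.clear()

    def add_docs(self, docs: List[Document], verbose: bool = True) -> None:
        self._docs.extend(self._filter_metadata(docs))

    async def aadd_docs(self, docs: List[Document], verbose: bool = True) -> None:
        self.add_docs(docs, verbose=verbose)

    def get_chunk(self, query: str, with_score: bool = False, top_k: int = None) -> Union[
            List[Document], List[Tuple[Document, float]]]:
        # No real similarity search here; just return the first top_k documents.
        k = top_k if top_k else 1
        return [(d, 0.0) for d in self._docs[:k]] if with_score else self._docs[:k]

    async def aget_chunk(self, query: str, with_score: bool = False, top_k: int = None) -> Union[
            List[Document], List[Tuple[Document, float]]]:
        return self.get_chunk(query, with_score=with_score, top_k=top_k)
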
28 changes: 22 additions & 6 deletions src/grag/components/vectordb/chroma_client.py
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Tuple, Union

import chromadb
from grag.components.embedding import Embedding
@@ -72,7 +72,21 @@ def __init__(
)
self.allowed_metadata_types = (str, int, float, bool)

def test_connection(self, verbose=True):
def __len__(self) -> int:
return self.collection.count()

def delete(self) -> None:
self.client.delete_collection(self.collection_name)
self.collection = self.client.get_or_create_collection(
name=self.collection_name
)
self.langchain_client = Chroma(
client=self.client,
collection_name=self.collection_name,
embedding_function=self.embedding_function,
)

def test_connection(self, verbose=True) -> int:
"""Tests connection with Chroma Vectorstore
Args:
@@ -89,7 +103,7 @@ def test_connection(self, verbose=True):
print(f"Connection to {self.host}/{self.port} is not alive !!")
return response

def add_docs(self, docs: List[Document], verbose=True):
def add_docs(self, docs: List[Document], verbose=True) -> None:
"""Adds documents to chroma vectorstore
Args:
@@ -105,7 +119,7 @@ def add_docs(self, docs: List[Document], verbose=True):
):
_id = self.langchain_client.add_documents([doc])

async def aadd_docs(self, docs: List[Document], verbose=True):
async def aadd_docs(self, docs: List[Document], verbose=True) -> None:
"""Asynchronously adds documents to chroma vectorstore
Args:
@@ -127,7 +141,8 @@ async def aadd_docs(self, docs: List[Document], verbose=True):
for doc in docs:
await self.langchain_client.aadd_documents([doc])

def get_chunk(self, query: str, with_score: bool = False, top_k: int = None):
def get_chunk(self, query: str, with_score: bool = False, top_k: int = None) -> Union[
List[Document], List[Tuple[Document, float]]]:
"""Returns the most similar chunks from the chroma database.
Args:
@@ -148,7 +163,8 @@ def get_chunk(self, query: str, with_score: bool = False, top_k: int = None):
query=query, k=top_k if top_k else 1
)

async def aget_chunk(self, query: str, with_score=False, top_k=None):
async def aget_chunk(self, query: str, with_score=False, top_k=None) -> Union[
List[Document], List[Tuple[Document, float]]]:
"""Returns the most (cosine) similar chunks from the vector database, asynchronously.
Args:
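
An illustrative usage sketch (not part of the commit) of the new __len__ and delete surface on ChromaClient; the collection name and sample text below are assumptions, while the calls follow the signatures shown in the diff above.

from grag.components.vectordb.chroma_client import ChromaClient
from langchain_core.documents import Document

chroma_client = ChromaClient(collection_name="test")
chroma_client.test_connection()      # prints whether the heartbeat succeeded and returns the response

if len(chroma_client) > 0:           # new __len__: delegates to collection.count()
    chroma_client.delete()           # new delete(): drops and recreates the collection

chroma_client.add_docs([Document(page_content="sample text for the test collection")])

# with_score=True returns (Document, score) tuples; otherwise plain Documents
for doc, score in chroma_client.get_chunk("sample text", with_score=True, top_k=1):
    print(round(score, 3), doc.page_content)
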
47 changes: 30 additions & 17 deletions src/grag/components/vectordb/deeplake_client.py
@@ -1,7 +1,6 @@
from pathlib import Path
from typing import List, Union
from typing import List, Tuple, Union

from deeplake.core.vectorstore import VectorStore
from grag.components.embedding import Embedding
from grag.components.utils import get_config
from grag.components.vectordb.base import VectorDB
@@ -34,24 +33,36 @@ class DeepLakeClient(VectorDB):
"""

def __init__(self,
store_path: Union[str, Path],
embedding_model: str,
embedding_type: str,
collection_name: str = deeplake_conf["collection_name"],
store_path: Union[str, Path] = deeplake_conf["store_path"],
embedding_type: str = deeplake_conf["embedding_type"],
embedding_model: str = deeplake_conf["embedding_model"],
read_only: bool = False
):
self.store_path = Path(store_path)
self.collection_name = collection_name
self.read_only = read_only
self.embedding_type: str = embedding_type
self.embedding_model: str = embedding_model

self.embedding_function = Embedding(
embedding_model=self.embedding_model, embedding_type=self.embedding_type
).embedding_function

self.client = VectorStore(path=self.store_path)
self.langchain_client = DeepLake(path=self.store_path,
embedding=self.embedding_function)
# self.client = VectorStore(path=self.store_path / self.collection_name)
self.langchain_client = DeepLake(dataset_path=str(self.store_path / self.collection_name),
embedding=self.embedding_function,
read_only=self.read_only)
self.client = self.langchain_client.vectorstore
self.allowed_metadata_types = (str, int, float, bool)

def add_docs(self, docs: List[Document], verbose=True):
def __len__(self) -> int:
return self.client.__len__()

def delete(self) -> None:
self.client.delete(delete_all=True)

def add_docs(self, docs: List[Document], verbose=True) -> None:
"""Adds documents to deeplake vectorstore
Args:
@@ -65,9 +76,9 @@ def add_docs(self, docs: List[Document], verbose=True):
for doc in (
tqdm(docs, desc=f"Adding to {self.collection_name}:") if verbose else docs
):
_id = self.langchain_chroma.add_documents([doc])
_id = self.langchain_client.add_documents([doc])

async def aadd_docs(self, docs: List[Document], verbose=True):
async def aadd_docs(self, docs: List[Document], verbose=True) -> None:
"""Asynchronously adds documents to chroma vectorstore
Args:
@@ -84,12 +95,13 @@ async def aadd_docs(self, docs: List[Document], verbose=True):
desc=f"Adding documents to {self.collection_name}",
total=len(docs),
):
await self.langchain_deeplake.aadd_documents([doc])
await self.langchain_client.aadd_documents([doc])
else:
for doc in docs:
await self.langchain_deeplake.aadd_documents([doc])
await self.langchain_client.aadd_documents([doc])

def get_chunk(self, query: str, with_score: bool = False, top_k: int = None):
def get_chunk(self, query: str, with_score: bool = False, top_k: int = None) -> Union[
List[Document], List[Tuple[Document, float]]]:
"""Returns the most similar chunks from the deeplake database.
Args:
@@ -102,15 +114,16 @@ def get_chunk(self, query: str, with_score: bool = False, top_k: int = None):
"""
if with_score:
return self.langchain_client.similarity_search_with_relevance_scores(
return self.langchain_client.similarity_search_with_score(
query=query, k=top_k if top_k else 1
)
else:
return self.langchain_client.similarity_search(
query=query, k=top_k if top_k else 1
)

async def aget_chunk(self, query: str, with_score=False, top_k=None):
async def aget_chunk(self, query: str, with_score=False, top_k=None) -> Union[
List[Document], List[Tuple[Document, float]]]:
"""Returns the most similar chunks from the deeplake database, asynchronously.
Args:
@@ -123,7 +136,7 @@ async def aget_chunk(self, query: str, with_score=False, top_k=None):
"""
if with_score:
return await self.langchain_client.asimilarity_search_with_relevance_scores(
return await self.langchain_client.asimilarity_search_with_score(
query=query, k=top_k if top_k else 1
)
else:
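
An illustrative usage sketch (not part of the commit) of the reworked DeepLakeClient; after this change every constructor argument falls back to the config file, so the collection name and sample text below are assumptions.

from grag.components.vectordb.deeplake_client import DeepLakeClient
from langchain_core.documents import Document

deeplake_client = DeepLakeClient(collection_name="test")

if len(deeplake_client) > 0:    # __len__ delegates to the underlying DeepLake vectorstore
    deeplake_client.delete()    # delete(delete_all=True) clears the collection

deeplake_client.add_docs([Document(page_content="sample text for the test collection")])

# with_score=True now routes through similarity_search_with_score
print(deeplake_client.get_chunk("sample text", with_score=True, top_k=1))
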
15 changes: 6 additions & 9 deletions src/tests/components/vectordb/chroma_client_test.py
@@ -47,13 +47,11 @@ def test_chroma_add_docs():
thunder rolled and boomed, like the Colorado in flood.""",
]
chroma_client = ChromaClient(collection_name="test")
if chroma_client.collection.count() > 0:
chroma_client.client.delete_collection("test")
chroma_client = ChromaClient(collection_name="test")
if len(chroma_client) > 0:
chroma_client.delete()
docs = [Document(page_content=doc) for doc in docs]
chroma_client.add_docs(docs)
collection_count = chroma_client.collection.count()
assert collection_count == len(docs)
assert len(chroma_client) == len(docs)


def test_chroma_aadd_docs():
@@ -92,13 +90,12 @@ def test_chroma_aadd_docs():
thunder rolled and boomed, like the Colorado in flood.""",
]
chroma_client = ChromaClient(collection_name="test")
if chroma_client.collection.count() > 0:
chroma_client.client.delete_collection("test")
chroma_client = ChromaClient(collection_name="test")
if len(chroma_client) > 0:
chroma_client.delete()
docs = [Document(page_content=doc) for doc in docs]
loop = asyncio.get_event_loop()
loop.run_until_complete(chroma_client.aadd_docs(docs))
assert chroma_client.collection.count() == len(docs)
assert len(chroma_client) == len(docs)


chrome_get_chunk_params = [(1, False), (1, True), (2, False), (2, True)]
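
The view is cut off just after chrome_get_chunk_params; a sketch (not from the commit) of how a parametrized get_chunk test over those tuples might look, assuming pytest parametrization and a pre-populated "test" collection, with the params list repeated so the sketch is self-contained.

import pytest
from grag.components.vectordb.chroma_client import ChromaClient

chrome_get_chunk_params = [(1, False), (1, True), (2, False), (2, True)]


@pytest.mark.parametrize("top_k,with_score", chrome_get_chunk_params)
def test_chroma_get_chunk(top_k, with_score):
    query = "thunder rolled and boomed, like the Colorado in flood."
    chroma_client = ChromaClient(collection_name="test")
    retrieved_chunks = chroma_client.get_chunk(query=query, with_score=with_score, top_k=top_k)
    assert len(retrieved_chunks) == top_k
    if with_score:
        # each result should be a (Document, score) tuple
        assert all(isinstance(chunk, tuple) for chunk in retrieved_chunks)
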