Skip to content

Commit

Permalink
Merge pull request #47 from arjbingly/vectordb
Browse files Browse the repository at this point in the history
Vectordb, Add support for deeplake
  • Loading branch information
sanchitvj authored Mar 24, 2024
2 parents 58be4d8 + 454bb5d commit 56f16cf
Show file tree
Hide file tree
Showing 15 changed files with 693 additions and 93 deletions.
8 changes: 6 additions & 2 deletions projects/Basic-RAG/BasicRAG_stuff.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from grag.grag.rag import BasicRAG
from grag.components.multivec_retriever import Retriever
from grag.components.vectordb.deeplake_client import DeepLakeClient
from grag.rag.basic_rag import BasicRAG

rag = BasicRAG(doc_chain="stuff")
client = DeepLakeClient(collection_name="test")
retriever = Retriever(vectordb=client)
rag = BasicRAG(doc_chain="stuff", retriever=retriever)

if __name__ == "__main__":
while True:
Expand Down
14 changes: 7 additions & 7 deletions projects/Retriver-GUI/retriever_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def render_search_results(self):
st.write(result.metadata)

def check_connection(self):
response = self.app.retriever.client.test_connection()
response = self.app.retriever.vectordb.test_connection()
if response:
return True
else:
Expand All @@ -55,14 +55,14 @@ def check_connection(self):
def render_stats(self):
st.write(f'''
**Chroma Client Details:** \n
Host Address : {self.app.retriever.client.host}:{self.app.retriever.client.port} \n
Collection Name : {self.app.retriever.client.collection_name} \n
Embeddings Type : {self.app.retriever.client.embedding_type} \n
Embeddings Model: {self.app.retriever.client.embedding_model} \n
Number of docs : {self.app.retriever.client.collection.count()} \n
Host Address : {self.app.retriever.vectordb.host}:{self.app.retriever.vectordb.port} \n
Collection Name : {self.app.retriever.vectordb.collection_name} \n
Embeddings Type : {self.app.retriever.vectordb.embedding_type} \n
Embeddings Model: {self.app.retriever.vectordb.embedding_model} \n
Number of docs : {self.app.retriever.vectordb.collection.count()} \n
''')
if st.button('Check Connection'):
response = self.app.retriever.client.test_connection()
response = self.app.retriever.vectordb.test_connection()
if response:
st.write(':green[Connection Active]')
else:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dependencies = [
"huggingface_hub>=0.20.2",
"pydantic>=2.5.0",
"rouge-score>=0.1.2",
"deeplake>=3.8.27"
]

[project.urls]
Expand Down
6 changes: 6 additions & 0 deletions src/config.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ n_gpu_layers_cpp : -1
std_out : True
base_dir : ${root:root_path}/models

[deeplake]
collection_name : arxiv
embedding_type : instructor-embedding
embedding_model : hkunlp/instructor-xl
store_path : ${data:data_path}/vectordb

[chroma]
host : localhost
port : 8000
Expand Down
47 changes: 23 additions & 24 deletions src/grag/components/multivec_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

import asyncio
import uuid
from typing import List
from typing import Any, Dict, List, Optional

from grag.components.chroma_client import ChromaClient
from grag.components.text_splitter import TextSplitter
from grag.components.utils import get_config
from grag.components.vectordb.base import VectorDB
from grag.components.vectordb.deeplake_client import DeepLakeClient
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import LocalFileStore
from langchain_core.documents import Document
Expand All @@ -28,7 +29,7 @@ class Retriever:
Attributes:
store_path: Path to the local file store
id_key: A key prefix for identifying documents
client: ChromaClient class instance from components.chroma_client
vectordb: VectorDB instance from components.vectordb (a DeepLakeClient is created when none is passed)
store: langchain.storage.LocalFileStore object, stores the key value pairs of document id and parent file
retriever: langchain.retrievers.multi_vector.MultiVectorRetriever class instance, langchain's multi-vector retriever
splitter: TextSplitter class instance from components.text_splitter
Expand All @@ -39,10 +40,12 @@ class Retriever:

def __init__(
self,
vectordb: Optional[VectorDB] = None,
store_path: str = multivec_retriever_conf["store_path"],
id_key: str = multivec_retriever_conf["id_key"],
namespace: str = multivec_retriever_conf["namespace"],
top_k=1,
client_kwargs: Optional[Dict[str, Any]] = None,
):
"""Initialize the Retriever.
Expand All @@ -55,10 +58,16 @@ def __init__(
self.store_path = store_path
self.id_key = id_key
self.namespace = uuid.UUID(namespace)
self.client = ChromaClient()
if vectordb is None:
if client_kwargs is not None:
self.vectordb = DeepLakeClient(**client_kwargs)
else:
self.vectordb = DeepLakeClient()
else:
self.vectordb = vectordb
self.store = LocalFileStore(self.store_path)
self.retriever = MultiVectorRetriever(
vectorstore=self.client.langchain_chroma,
vectorstore=self.vectordb.langchain_client,
byte_store=self.store,
id_key=self.id_key,
)
Expand Down Expand Up @@ -125,7 +134,7 @@ def add_docs(self, docs: List[Document]):
"""
chunks = self.split_docs(docs)
doc_ids = self.gen_doc_ids(docs)
self.client.add_docs(chunks)
self.vectordb.add_docs(chunks)
self.retriever.docstore.mset(list(zip(doc_ids, docs)))

async def aadd_docs(self, docs: List[Document]):
Expand All @@ -140,11 +149,11 @@ async def aadd_docs(self, docs: List[Document]):
"""
chunks = self.split_docs(docs)
doc_ids = self.gen_doc_ids(docs)
await asyncio.run(self.client.aadd_docs(chunks))
await asyncio.run(self.vectordb.aadd_docs(chunks))
self.retriever.docstore.mset(list(zip(doc_ids)))

def get_chunk(self, query: str, with_score=False, top_k=None):
"""Returns the most (cosine) similar chunks from the vector database.
"""Returns the most similar chunks from the vector database.
Args:
query: A query string
Expand All @@ -155,14 +164,8 @@ def get_chunk(self, query: str, with_score=False, top_k=None):
list of Documents
"""
if with_score:
return self.client.langchain_chroma.similarity_search_with_relevance_scores(
query=query, **{"k": top_k} if top_k else self.retriever.search_kwargs
)
else:
return self.client.langchain_chroma.similarity_search(
query=query, **{"k": top_k} if top_k else self.retriever.search_kwargs
)
_top_k = top_k if top_k else self.retriever.search_kwargs["k"]
return self.vectordb.get_chunk(query=query, top_k=_top_k, with_score=with_score)

async def aget_chunk(self, query: str, with_score=False, top_k=None):
"""Returns the most (cosine) similar chunks from the vector database, asynchronously.
Expand All @@ -176,14 +179,10 @@ async def aget_chunk(self, query: str, with_score=False, top_k=None):
list of Documents
"""
if with_score:
return await self.client.langchain_chroma.asimilarity_search_with_relevance_scores(
query=query, **{"k": top_k} if top_k else self.retriever.search_kwargs
)
else:
return await self.client.langchain_chroma.asimilarity_search(
query=query, **{"k": top_k} if top_k else self.retriever.search_kwargs
)
_top_k = top_k if top_k else self.retriever.search_kwargs["k"]
return await self.vectordb.aget_chunk(
query=query, top_k=_top_k, with_score=with_score
)

def get_doc(self, query: str):
"""Returns the parent document of the most (cosine) similar chunk from the vector database.
Expand Down
Empty file.
85 changes: 85 additions & 0 deletions src/grag/components/vectordb/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Abstract base class for vector database clients.
This module provides:
- VectorDB
"""

from abc import ABC, abstractmethod
from typing import List, Tuple, Union

from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain_core.documents import Document


class VectorDB(ABC):
    """Abstract base class for vector database clients.

    Concrete clients implement every abstract method below. They must also
    assign ``allowed_metadata_types``, because ``_filter_metadata`` reads it
    from the instance and this base class gives it no value.
    """

    # Metadata value types handed to ``filter_complex_metadata`` by
    # ``_filter_metadata``. Annotation only: no default is defined here, so a
    # subclass that forgets to set it gets an AttributeError on first use.
    allowed_metadata_types: Tuple[type, ...]

    @abstractmethod
    def __len__(self) -> int:
        """Number of chunks in the vector database."""
        ...

    @abstractmethod
    def delete(self) -> None:
        """Delete all chunks in the vector database."""
        ...

    @abstractmethod
    def add_docs(self, docs: List[Document], verbose: bool = True) -> None:
        """Adds documents to the vector database.

        Args:
            docs: List of Documents
            verbose: Show progress bar

        Returns:
            None
        """
        ...

    @abstractmethod
    async def aadd_docs(self, docs: List[Document], verbose: bool = True) -> None:
        """Adds documents to the vector database (asynchronous).

        Args:
            docs: List of Documents
            verbose: Show progress bar

        Returns:
            None
        """
        ...

    @abstractmethod
    def get_chunk(
        self, query: str, with_score: bool = False, top_k: Union[int, None] = None
    ) -> Union[List[Document], List[Tuple[Document, float]]]:
        """Returns the most similar chunks from the vector database.

        Args:
            query: A query string
            with_score: Outputs scores of returned chunks
            top_k: Number of top similar chunks to return; None leaves the
                choice of default to the implementation

        Returns:
            list of Documents, or list of (Document, score) tuples when
            scores are requested
        """
        ...

    @abstractmethod
    async def aget_chunk(
        self, query: str, with_score: bool = False, top_k: Union[int, None] = None
    ) -> Union[List[Document], List[Tuple[Document, float]]]:
        """Returns the most similar chunks from the vector database (asynchronous).

        Args:
            query: A query string
            with_score: Outputs scores of returned chunks
            top_k: Number of top similar chunks to return; None leaves the
                choice of default to the implementation

        Returns:
            list of Documents, or list of (Document, score) tuples when
            scores are requested
        """
        ...

    def _filter_metadata(self, docs: List[Document]) -> List[Document]:
        """Pass ``docs`` through ``filter_complex_metadata``.

        Args:
            docs: List of Documents

        Returns:
            Whatever ``filter_complex_metadata`` returns for ``docs`` with
            ``allowed_types=self.allowed_metadata_types``
        """
        return filter_complex_metadata(docs, allowed_types=self.allowed_metadata_types)
Loading

0 comments on commit 56f16cf

Please sign in to comment.