diff --git a/src/beyondllm/embeddings/base.py b/src/beyondllm/embeddings/base.py index 7c1a171..c20999d 100644 --- a/src/beyondllm/embeddings/base.py +++ b/src/beyondllm/embeddings/base.py @@ -1,9 +1,9 @@ from pydantic import BaseModel class EmbeddingConfig(BaseModel): - """Base configuration model for all LLMs. + """Base configuration model for all Embeddings. - This class can be extended to include more fields specific to certain LLMs. + This class can be extended to include more fields specific to certain Embeddings. """ pass diff --git a/src/beyondllm/retrieve.py b/src/beyondllm/retrieve.py index efd7a0d..11a4b8b 100644 --- a/src/beyondllm/retrieve.py +++ b/src/beyondllm/retrieve.py @@ -6,7 +6,7 @@ from beyondllm.retrievers.utils import generate_qa_dataset, evaluate_from_dataset import pandas as pd -def auto_retriever(data,embed_model=None,type="normal",top_k=4,**kwargs): +def auto_retriever(data=None,embed_model=None,type="normal",top_k=4,vectordb=None,**kwargs): """ Automatically selects and initializes a retriever based on the specified type. Parameters: @@ -16,6 +16,7 @@ def auto_retriever(data,embed_model=None,type="normal",top_k=4,**kwargs): type (str): The type of retriever to use. Options include 'normal', 'flag-rerank', 'cross-rerank', and 'hybrid'. Defaults to 'normal'. top_k (int): The number of top results to retrieve. Defaults to 4. + vectordb (VectorDb): The vectordb to use for retrieval Additional parameters: reranker: Name of the reranking model to be used. To be specified only for type = 'flag-rerank' and 'cross-rerank' mode: Possible options are 'AND' or 'OR'. To be specified only for type = 'hybrid. 'AND' mode will retrieve nodes in common between @@ -27,22 +28,23 @@ def auto_retriever(data,embed_model=None,type="normal",top_k=4,**kwargs): data = embed_model = + vector_store = - retriever = auto_retriever(data=data, embed_model=embed_model, type="normal", top_k=5) + retriever = auto_retriever(data=data, embed_model=embed_model, type="normal", top_k=5, vectordb=vector_store) """ if embed_model is None: embed_model = GeminiEmbeddings() if type == 'normal': - retriever = NormalRetriever(data,embed_model,top_k,**kwargs) + retriever = NormalRetriever(data,embed_model,top_k,vectordb,**kwargs) elif type == 'flag-rerank': from .retrievers.flagReranker import FlagEmbeddingRerankRetriever - retriever = FlagEmbeddingRerankRetriever(data,embed_model,top_k,**kwargs) + retriever = FlagEmbeddingRerankRetriever(data,embed_model,top_k,vectordb,**kwargs) elif type == 'cross-rerank': from .retrievers.crossEncoderReranker import CrossEncoderRerankRetriever - retriever = CrossEncoderRerankRetriever(data,embed_model,top_k,**kwargs) + retriever = CrossEncoderRerankRetriever(data,embed_model,top_k,vectordb,**kwargs) elif type == 'hybrid': from .retrievers.hybridRetriever import HybridRetriever - retriever = HybridRetriever(data,embed_model,top_k,**kwargs) + retriever = HybridRetriever(data,embed_model,top_k,vectordb,**kwargs) else: raise NotImplementedError(f"Retriever for the type '{type}' is not implemented.") diff --git a/src/beyondllm/retrievers/base.py b/src/beyondllm/retrievers/base.py index 239a164..312b2e8 100644 --- a/src/beyondllm/retrievers/base.py +++ b/src/beyondllm/retrievers/base.py @@ -8,10 +8,12 @@ class BaseRetriever: data: The dataset to be indexed or retrieved from. embed_model: The embedding model used to generate embeddings for the data. top_k: The top k similarity search results to be retrieved + vectordb: The vectordb to be used for retrieval """ - def __init__(self, data, embed_model,**kwargs): + def __init__(self, data, embed_model, vectordb, **kwargs): self.data = data self.embed_model = embed_model + self.vectordb = vectordb def load_index(self): raise NotImplementedError("This method should be implemented by subclasses.") diff --git a/src/beyondllm/retrievers/crossEncoderReranker.py b/src/beyondllm/retrievers/crossEncoderReranker.py index ec13704..d8ff23d 100644 --- a/src/beyondllm/retrievers/crossEncoderReranker.py +++ b/src/beyondllm/retrievers/crossEncoderReranker.py @@ -1,5 +1,5 @@ from beyondllm.retrievers.base import BaseRetriever -from llama_index.core import VectorStoreIndex, ServiceContext +from llama_index.core import VectorStoreIndex, ServiceContext,StorageContext from llama_index.core.schema import QueryBundle import sys import subprocess @@ -45,10 +45,34 @@ def __init__(self, data, embed_model, top_k,*args, **kwargs): self.reranker = kwargs.get('reranker',"cross-encoder/ms-marco-MiniLM-L-2-v2") def load_index(self): - service_context = ServiceContext.from_defaults(llm=None, embed_model=self.embed_model) - index = VectorStoreIndex( - self.data, service_context= service_context, - ) + if self.data is None: + index = self.initialize_from_vector_store() + else: + index = self.initialize_from_data() + + return index + + def initialize_from_vector_store(self): + if self.vectordb is None: + raise ValueError("Vector store must be provided if no data is passed") + else: + index = VectorStoreIndex.from_vector_store( + self.vectordb, + embed_model=self.embed_model, + ) + return index + + + def initialize_from_data(self): + if self.vectordb==None: + index = VectorStoreIndex( + self.data, embed_model=self.embed_model + ) + else: + storage_context = StorageContext.from_defaults(vector_store=self.vectordb) + index = VectorStoreIndex( + self.data, storage_context=storage_context, embed_model=self.embed_model + ) return index def retrieve(self, query): diff --git a/src/beyondllm/retrievers/flagReranker.py b/src/beyondllm/retrievers/flagReranker.py index ed68a15..371d93d 100644 --- a/src/beyondllm/retrievers/flagReranker.py +++ b/src/beyondllm/retrievers/flagReranker.py @@ -1,5 +1,5 @@ from beyondllm.retrievers.base import BaseRetriever -from llama_index.core import VectorStoreIndex, ServiceContext +from llama_index.core import VectorStoreIndex, ServiceContext, StorageContext import sys import subprocess try: @@ -44,10 +44,34 @@ def __init__(self, data, embed_model, top_k,*args, **kwargs): self.reranker = kwargs.get('reranker',"BAAI/bge-reranker-large") def load_index(self): - service_context = ServiceContext.from_defaults(llm=None, embed_model=self.embed_model) - index = VectorStoreIndex( - self.data, service_context= service_context, - ) + if self.data is None: + index = self.initialize_from_vector_store() + else: + index = self.initialize_from_data() + + return index + + def initialize_from_vector_store(self): + if self.vectordb is None: + raise ValueError("Vector store must be provided if no data is passed") + else: + index = VectorStoreIndex.from_vector_store( + self.vectordb, + embed_model=self.embed_model, + ) + return index + + + def initialize_from_data(self): + if self.vectordb==None: + index = VectorStoreIndex( + self.data, embed_model=self.embed_model + ) + else: + storage_context = StorageContext.from_defaults(vector_store=self.vectordb) + index = VectorStoreIndex( + self.data, storage_context=storage_context, embed_model=self.embed_model + ) return index def retrieve(self, query): diff --git a/src/beyondllm/retrievers/hybridRetriever.py b/src/beyondllm/retrievers/hybridRetriever.py index 4aebe8a..a807bfd 100644 --- a/src/beyondllm/retrievers/hybridRetriever.py +++ b/src/beyondllm/retrievers/hybridRetriever.py @@ -88,15 +88,39 @@ def __init__(self, data, embed_model, top_k,*args, **kwargs): raise ValueError("Invalid mode. Mode must be 'AND' or 'OR'.") def load_index(self): - service_context = ServiceContext.from_defaults(llm=None, embed_model=self.embed_model) - storage_context = StorageContext.from_defaults() - vector_index = VectorStoreIndex( - self.data, service_context= service_context, storage_context=storage_context - ) - keyword_index = SimpleKeywordTableIndex( - self.data,service_context=service_context,storage_context=storage_context - ) + if self.data is None: + raise ValueError("Data needs to be passed for keyword retrieval.") + else: + vector_index, keyword_index = self.initialize_from_data() + return vector_index, keyword_index + + def initialize_from_data(self): + if self.vectordb==None: + vector_index = VectorStoreIndex( + self.data, embed_model=self.embed_model + ) + keyword_index = SimpleKeywordTableIndex( + self.data, service_context=ServiceContext.from_defaults(llm=None,embed_model=None) + ) + else: + storage_context = StorageContext.from_defaults(vector_store=self.vectordb) + vector_index = VectorStoreIndex( + self.data, storage_context=storage_context, embed_model=self.embed_model + ) + keyword_index = SimpleKeywordTableIndex( + self.data, service_context=ServiceContext.from_defaults(llm=None,embed_model=None) + ) + return vector_index, keyword_index + + # def load_index(self): + # vector_index = VectorStoreIndex( + # self.data, embed_model=self.embed_model + # ) + # keyword_index = SimpleKeywordTableIndex( + # self.data, service_context=ServiceContext.from_defaults(llm=None,embed_model=None) + # ) + # return vector_index, keyword_index def as_retriever(self): vector_index, keyword_index = self.load_index() diff --git a/src/beyondllm/retrievers/normalRetriever.py b/src/beyondllm/retrievers/normalRetriever.py index 6c91138..61da295 100644 --- a/src/beyondllm/retrievers/normalRetriever.py +++ b/src/beyondllm/retrievers/normalRetriever.py @@ -1,5 +1,5 @@ from beyondllm.retrievers.base import BaseRetriever -from llama_index.core import VectorStoreIndex, ServiceContext +from llama_index.core import VectorStoreIndex, ServiceContext, StorageContext class NormalRetriever(BaseRetriever): """ @@ -14,7 +14,7 @@ class NormalRetriever(BaseRetriever): results = retriever.retrieve("") """ - def __init__(self, data, embed_model, top_k,*args, **kwargs): + def __init__(self, data, embed_model, top_k, vectordb,*args, **kwargs): """ Initializes a NormalRetriever instance. @@ -22,19 +22,45 @@ def __init__(self, data, embed_model, top_k,*args, **kwargs): data: The dataset to be indexed. embed_model: The embedding model to use. top_k: The number of top results to retrieve. + vectordb: The vectordb to use for retrieval """ - super().__init__(data, embed_model,*args, **kwargs) + super().__init__(data, embed_model, vectordb,*args, **kwargs) self.embed_model = embed_model self.data = data self.top_k = top_k + self.vectordb = vectordb def load_index(self): - service_context = ServiceContext.from_defaults(llm=None, embed_model=self.embed_model) - index = VectorStoreIndex( - self.data, service_context= service_context - ) + if self.data is None: + index = self.initialize_from_vector_store() + else: + index = self.initialize_from_data() + return index + def initialize_from_vector_store(self): + if self.vectordb is None: + raise ValueError("Vector store must be provided if no data is passed") + else: + index = VectorStoreIndex.from_vector_store( + self.vectordb, + embed_model=self.embed_model, + ) + return index + + + def initialize_from_data(self): + if self.vectordb==None: + index = VectorStoreIndex( + self.data, embed_model=self.embed_model + ) + else: + storage_context = StorageContext.from_defaults(vector_store=self.vectordb) + index = VectorStoreIndex( + self.data, storage_context=storage_context, embed_model=self.embed_model + ) + return index + def retrieve(self, query): retriever = self.as_retriever() return retriever.retrieve(query) diff --git a/src/beyondllm/vectordb/__init__.py b/src/beyondllm/vectordb/__init__.py new file mode 100644 index 0000000..08522af --- /dev/null +++ b/src/beyondllm/vectordb/__init__.py @@ -0,0 +1 @@ +from .chroma import ChromaVectorDb \ No newline at end of file diff --git a/src/beyondllm/vectordb/base.py b/src/beyondllm/vectordb/base.py index e69de29..aa32dda 100644 --- a/src/beyondllm/vectordb/base.py +++ b/src/beyondllm/vectordb/base.py @@ -0,0 +1,24 @@ +from pydantic import BaseModel + +class VectorDbConfig(BaseModel): + """Base configuration model for all LLMs. + + This class can be extended to include more fields specific to certain LLMs. + """ + pass + +class VectorDb(BaseModel): + def load(self): + raise NotImplementedError("This method should be implemented by subclasses.") + + def add(self,*args, **kwargs): + raise NotImplementedError("This method should be implemented by subclasses.") + + def stores_text(self,*args, **kwargs): + raise NotImplementedError("This method should be implemented by subclasses.") + + def is_embedding_query(self,*args, **kwargs): + raise NotImplementedError("This method should be implemented by subclasses.") + + def query(self,*args, **kwargs): + raise NotImplementedError("This method should be implemented by subclasses.") \ No newline at end of file diff --git a/src/beyondllm/vectordb/chroma.py b/src/beyondllm/vectordb/chroma.py new file mode 100644 index 0000000..aee89ff --- /dev/null +++ b/src/beyondllm/vectordb/chroma.py @@ -0,0 +1,73 @@ +from beyondllm.vectordb.base import VectorDb, VectorDbConfig +from dataclasses import dataclass, field +import warnings +warnings.filterwarnings("ignore") +import subprocess,sys +try: + from llama_index.vector_stores.chroma import ChromaVectorStore +except ImportError: + user_agree = input("The feature you're trying to use requires an additional library(s):llama_index.vector_stores.chroma. Would you like to install it now? [y/N]: ") + if user_agree.lower() == 'y': + subprocess.check_call([sys.executable, "-m", "pip", "install", "llama_index.vector_stores.chroma"]) + from llama_index.vector_stores.chroma import ChromaVectorStore + else: + raise ImportError("The required 'llama_index.vector_stores.chroma' is not installed.") +import chromadb + +@dataclass +class ChromaVectorDb: + """ + from beyondllm.vectordb import ChromaVectorDb + vectordb = ChromaVectorDb(collection_name="quickstart",persist_directory="./db/chroma/") + """ + collection_name: str + persist_directory: str = "" + + def __post_init__(self): + if self.persist_directory=="" or self.persist_directory==None: + self.chroma_client = chromadb.EphemeralClient() + else: + self.chroma_client = chromadb.PersistentClient(self.persist_directory) + self.load() + + def load(self): + try: + from llama_index.vector_stores.chroma import ChromaVectorStore + except: + raise ImportError("ChromaVectorStore library is not installed. Please install it with ``pip install llama_index.vector_stores.chroma``.") + + # More clarity and specificity required for try error statements + try: + try: + chroma_collection = self.chroma_client.get_collection(self.collection_name) + except Exception: + chroma_collection = self.chroma_client.create_collection(self.collection_name) + vector_store = ChromaVectorStore(chroma_collection=chroma_collection) + self.client = vector_store + except Exception as e: + raise Exception(f"Failed to load the Chroma Vectorstore: {e}") + + return self.client + + def add(self,*args, **kwargs): + client = self.client + return client.add(*args, **kwargs) + + def stores_text(self,*args, **kwargs): + client = self.client + return client.stores_text(*args, **kwargs) + + def is_embedding_query(self,*args, **kwargs): + client = self.client + return client.is_embedding_query(*args, **kwargs) + + def query(self,*args, **kwargs): + client = self.client + return client.query(*args, **kwargs) + + + @staticmethod + def load_from_kwargs(self,kwargs): + embed_config = VectorDbConfig(**kwargs) + self.config = embed_config + self.load()