diff --git a/projects/Basic-RAG/BasicRAG_stuff.py b/projects/Basic-RAG/BasicRAG_stuff.py index 4bfafc3..63edeab 100644 --- a/projects/Basic-RAG/BasicRAG_stuff.py +++ b/projects/Basic-RAG/BasicRAG_stuff.py @@ -1,6 +1,10 @@ -from grag.grag.rag import BasicRAG +from grag.components.multivec_retriever import Retriever +from grag.components.vectordb.deeplake_client import DeepLakeClient +from grag.rag.basic_rag import BasicRAG -rag = BasicRAG(doc_chain="stuff") +client = DeepLakeClient(collection_name="test") +retriever = Retriever(vectordb=client) +rag = BasicRAG(doc_chain="stuff", retriever=retriever) if __name__ == "__main__": while True: diff --git a/projects/Retriver-GUI/retriever_app.py b/projects/Retriver-GUI/retriever_app.py index f55c0c6..9f4198c 100644 --- a/projects/Retriver-GUI/retriever_app.py +++ b/projects/Retriver-GUI/retriever_app.py @@ -46,7 +46,7 @@ def render_search_results(self): st.write(result.metadata) def check_connection(self): - response = self.app.retriever.client.test_connection() + response = self.app.retriever.vectordb.test_connection() if response: return True else: @@ -55,14 +55,14 @@ def check_connection(self): def render_stats(self): st.write(f''' **Chroma Client Details:** \n - Host Address : {self.app.retriever.client.host}:{self.app.retriever.client.port} \n - Collection Name : {self.app.retriever.client.collection_name} \n - Embeddings Type : {self.app.retriever.client.embedding_type} \n - Embeddings Model: {self.app.retriever.client.embedding_model} \n - Number of docs : {self.app.retriever.client.collection.count()} \n + Host Address : {self.app.retriever.vectordb.host}:{self.app.retriever.vectordb.port} \n + Collection Name : {self.app.retriever.vectordb.collection_name} \n + Embeddings Type : {self.app.retriever.vectordb.embedding_type} \n + Embeddings Model: {self.app.retriever.vectordb.embedding_model} \n + Number of docs : {self.app.retriever.vectordb.collection.count()} \n ''') if st.button('Check Connection'): - response = 
self.app.retriever.client.test_connection() + response = self.app.retriever.vectordb.test_connection() if response: st.write(':green[Connection Active]') else: diff --git a/pyproject.toml b/pyproject.toml index 897ab02..b97ba19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ "huggingface_hub>=0.20.2", "pydantic>=2.5.0", "rouge-score>=0.1.2", + "deeplake>=3.8.27" ] [project.urls] @@ -101,9 +102,13 @@ exclude_lines = [ [tool.ruff] line-length = 88 indent-width = 4 +extend-exclude = ["tests", "others"] [tool.ruff.lint] select = ["E4", "E7", "E9", "F", "I", "D"] +ignore = ["D104"] +exclude = ["__about__.py"] + [tool.ruff.format] quote-style = "double" diff --git a/src/config.ini b/src/config.ini index 74ab6c4..c2938f9 100644 --- a/src/config.ini +++ b/src/config.ini @@ -14,6 +14,12 @@ n_gpu_layers_cpp : -1 std_out : True base_dir : ${root:root_path}/models +[deeplake] +collection_name : arxiv +embedding_type : instructor-embedding +embedding_model : hkunlp/instructor-xl +store_path : ${data:data_path}/vectordb + [chroma] host : localhost port : 8000 @@ -25,6 +31,14 @@ embedding_model : hkunlp/instructor-xl store_path : ${data:data_path}/vectordb allow_reset : True +[deeplake] +collection_name : arxiv +# embedding_type : sentence-transformers +# embedding_model : "all-mpnet-base-v2" +embedding_type : instructor-embedding +embedding_model : hkunlp/instructor-xl +store_path : ${data:data_path}/vectordb + [text_splitter] chunk_size : 5000 chunk_overlap : 400 @@ -54,4 +68,4 @@ data_path : ${root:root_path}/data root_path : /home/ubuntu/volume_2k/Capstone_5 [quantize] -llama_cpp_path : ${root:root_path} \ No newline at end of file +llama_cpp_path : ${root:root_path} diff --git a/src/grag/components/embedding.py b/src/grag/components/embedding.py index 7a9d249..eeb0f82 100644 --- a/src/grag/components/embedding.py +++ b/src/grag/components/embedding.py @@ -1,3 +1,9 @@ +"""Class for embedding. 
+ +This module provides: +- Embedding +""" + from langchain_community.embeddings import HuggingFaceInstructEmbeddings from langchain_community.embeddings.sentence_transformer import ( SentenceTransformerEmbeddings, @@ -6,6 +12,7 @@ class Embedding: """A class for vector embeddings. + Supports: huggingface sentence transformers -> model_type = 'sentence-transformers' huggingface instructor embeddings -> model_type = 'instructor-embedding' @@ -17,6 +24,7 @@ class Embedding: """ def __init__(self, embedding_type: str, embedding_model: str): + """Initialize the embedding with embedding_type and embedding_model.""" self.embedding_type = embedding_type self.embedding_model = embedding_model match self.embedding_type: diff --git a/src/grag/components/llm.py b/src/grag/components/llm.py index 20db968..6e7296c 100644 --- a/src/grag/components/llm.py +++ b/src/grag/components/llm.py @@ -1,3 +1,5 @@ +"""Class for LLM.""" + import os from pathlib import Path @@ -50,6 +52,7 @@ def __init__( quantization=llm_conf["quantization"], pipeline=llm_conf["pipeline"], ): + """Initialize the LLM class using the given parameters.""" self.base_dir = Path(base_dir) self._model_name = model_name self.quantization = quantization diff --git a/src/grag/components/multivec_retriever.py b/src/grag/components/multivec_retriever.py index 18ed752..9fd8664 100644 --- a/src/grag/components/multivec_retriever.py +++ b/src/grag/components/multivec_retriever.py @@ -1,10 +1,17 @@ +"""Class for retriever. 
+ +This module provides: +- Retriever +""" + import asyncio import uuid -from typing import List +from typing import Any, Dict, List, Optional -from grag.components.chroma_client import ChromaClient from grag.components.text_splitter import TextSplitter from grag.components.utils import get_config +from grag.components.vectordb.base import VectorDB +from grag.components.vectordb.deeplake_client import DeepLakeClient from langchain.retrievers.multi_vector import MultiVectorRetriever from langchain.storage import LocalFileStore from langchain_core.documents import Document @@ -13,14 +20,16 @@ class Retriever: - """A class for multi vector retriever, it connects to a vector database and a local file store. - It is used to return most similar chunks from a vector store but has the additional funcationality - to return a linked document, chunk, etc. + """A class for multi vector retriever. + + It connects to a vector database and a local file store. + It is used to return most similar chunks from a vector store but has the additional functionality to return a + linked document, chunk, etc. Attributes: store_path: Path to the local file store id_key: A key prefix for identifying documents - client: ChromaClient class instance from components.chroma_client + vectordb: VectorDB instance from components.vectordb (a DeepLakeClient is created when none is passed) store: langchain.storage.LocalFileStore object, stores the key value pairs of document id and parent file retriever: langchain.retrievers.multi_vector.MultiVectorRetriever class instance, langchain's multi-vector retriever splitter: TextSplitter class instance from components.text_splitter @@ -31,12 +40,16 @@ class Retriever: def __init__( self, + vectordb: Optional[VectorDB] = None, store_path: str = multivec_retriever_conf["store_path"], id_key: str = multivec_retriever_conf["id_key"], namespace: str = multivec_retriever_conf["namespace"], top_k=1, + client_kwargs: Optional[Dict[str, Any]] = None, ): - """Args: + """Initialize the Retriever. 
+ + Args: store_path: Path to the local file store, defaults to argument from config file id_key: A key prefix for identifying documents, defaults to argument from config file namespace: A namespace for producing unique id, defaults to argument from congig file @@ -45,10 +58,16 @@ def __init__( self.store_path = store_path self.id_key = id_key self.namespace = uuid.UUID(namespace) - self.client = ChromaClient() + if vectordb is None: + if client_kwargs is not None: + self.vectordb = DeepLakeClient(**client_kwargs) + else: + self.vectordb = DeepLakeClient() + else: + self.vectordb = vectordb self.store = LocalFileStore(self.store_path) self.retriever = MultiVectorRetriever( - vectorstore=self.client.langchain_chroma, + vectorstore=self.vectordb.langchain_client, byte_store=self.store, id_key=self.id_key, ) @@ -58,6 +77,7 @@ def __init__( def id_gen(self, doc: Document) -> str: """Takes a document and returns a unique id (uuid5) using the namespace and document source. + This ensures that a single document always gets the same unique id. Args: @@ -81,7 +101,9 @@ def gen_doc_ids(self, docs: List[Document]) -> List[str]: return [self.id_gen(doc) for doc in docs] def split_docs(self, docs: List[Document]) -> List[Document]: - """Takes a list of documents and splits them into smaller chunks using TextSplitter from compoenents.text_splitter + """Takes a list of documents and splits them into smaller chunks. + + Using TextSplitter from components.text_splitter Also adds the unique parent document id into metadata Args: @@ -101,8 +123,7 @@ def split_docs(self, docs: List[Document]) -> List[Document]: return chunks def add_docs(self, docs: List[Document]): - """Takes a list of documents, splits them using the split_docs method and then adds them into the vector database - and adds the parent document into the file store. + """Adds given documents into the vector database also adds the parent document into the file store. 
Args: docs: List of langchain_core.documents.Document @@ -113,12 +134,11 @@ def add_docs(self, docs: List[Document]): """ chunks = self.split_docs(docs) doc_ids = self.gen_doc_ids(docs) - self.client.add_docs(chunks) + self.vectordb.add_docs(chunks) self.retriever.docstore.mset(list(zip(doc_ids, docs))) async def aadd_docs(self, docs: List[Document]): - """Takes a list of documents, splits them using the split_docs method and then adds them into the vector database - and adds the parent document into the file store. + """Asynchronously splits the documents, adds the chunks to the vector database and the parent documents to the file store. Args: docs: List of langchain_core.documents.Document @@ -129,11 +149,11 @@ async def aadd_docs(self, docs: List[Document]): """ chunks = self.split_docs(docs) doc_ids = self.gen_doc_ids(docs) - await asyncio.run(self.client.aadd_docs(chunks)) - self.retriever.docstore.mset(list(zip(doc_ids))) + await self.vectordb.aadd_docs(chunks) + self.retriever.docstore.mset(list(zip(doc_ids, docs))) def get_chunk(self, query: str, with_score=False, top_k=None): - """Returns the most (cosine) similar chunks from the vector database. + """Returns the most similar chunks from the vector database. Args: query: A query string @@ -144,14 +164,8 @@ def get_chunk(self, query: str, with_score=False, top_k=None): list of Documents """ - if with_score: - return self.client.langchain_chroma.similarity_search_with_relevance_scores( - query=query, **{"k": top_k} if top_k else self.retriever.search_kwargs - ) - else: - return self.client.langchain_chroma.similarity_search( - query=query, **{"k": top_k} if top_k else self.retriever.search_kwargs - ) + _top_k = top_k if top_k else self.retriever.search_kwargs["k"] + return self.vectordb.get_chunk(query=query, top_k=_top_k, with_score=with_score) async def aget_chunk(self, query: str, with_score=False, top_k=None): """Returns the most (cosine) similar chunks from the vector database, asynchronously. 
@@ -165,14 +179,10 @@ async def aget_chunk(self, query: str, with_score=False, top_k=None): list of Documents """ - if with_score: - return await self.client.langchain_chroma.asimilarity_search_with_relevance_scores( - query=query, **{"k": top_k} if top_k else self.retriever.search_kwargs - ) - else: - return await self.client.langchain_chroma.asimilarity_search( - query=query, **{"k": top_k} if top_k else self.retriever.search_kwargs - ) + _top_k = top_k if top_k else self.retriever.search_kwargs["k"] + return await self.vectordb.aget_chunk( + query=query, top_k=_top_k, with_score=with_score + ) def get_doc(self, query: str): """Returns the parent document of the most (cosine) similar chunk from the vector database. diff --git a/src/grag/components/parse_pdf.py b/src/grag/components/parse_pdf.py index d918c93..dc30f8a 100644 --- a/src/grag/components/parse_pdf.py +++ b/src/grag/components/parse_pdf.py @@ -1,3 +1,9 @@ +"""Classes for parsing files. + +This module provides: +- ParsePDF +""" + from langchain_core.documents import Document from unstructured.partition.pdf import partition_pdf @@ -32,7 +38,7 @@ def __init__( add_captions_to_blocks=parser_conf["add_captions_to_blocks"], table_as_html=parser_conf["table_as_html"], ): - # Instantialize instance variables with parameters + """Initialize instance variables with parameters.""" self.strategy = strategy if extract_images: # by default always extract Table self.extract_image_block_types = [ @@ -72,7 +78,8 @@ def partition(self, path: str): def classify(self, partitions): """Classifies the partitioned elements into Text, Tables, and Images list in a dictionary. - Add captions for each element (if available). + + Also adds captions for each element (if available). Parameters: partitions (list): The list of partitioned elements from the PDF document. 
@@ -117,6 +124,8 @@ def classify(self, partitions): return classified_elements def text_concat(self, elements) -> str: + """Context aware concatenates all elements into a single string.""" + full_text = "" for current_element, next_element in zip(elements, elements[1:]): curr_type = current_element.category next_type = next_element.category diff --git a/src/grag/components/prompt.py b/src/grag/components/prompt.py index ecefa71..4364c06 100644 --- a/src/grag/components/prompt.py +++ b/src/grag/components/prompt.py @@ -1,3 +1,10 @@ +"""Classes for prompts. + +This module provides: +- Prompt - for generic prompts +- FewShotPrompt - for few-shot prompts +""" + import json from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -13,6 +20,20 @@ class Prompt(BaseModel): + """A class for generic prompts. + + Attributes: + name (str): The prompt name (Optional, defaults to "custom_prompt") + llm_type (str): The type of llm, llama2, etc (Optional, defaults to "None") + task (str): The task (Optional, defaults to QA) + source (str): The source of the prompt (Optional, defaults to "NoSource") + doc_chain (str): The doc chain for the prompt ("stuff", "refine") (Optional, defaults to "stuff") + language (str): The language of the prompt (Optional, defaults to "en") + filepath (str): The filepath of the prompt (Optional) + input_keys (List[str]): The input keys for the prompt + template (str): The template for the prompt + """ + name: str = Field(default="custom_prompt") llm_type: str = Field(default="None") task: str = Field(default="QA") @@ -27,6 +48,7 @@ class Prompt(BaseModel): @field_validator("input_keys") @classmethod def validate_input_keys(cls, v) -> List[str]: + """Validate the input_keys field.""" if v is None or v == []: raise ValueError("input_keys cannot be empty") return v @@ -34,6 +56,7 @@ def validate_input_keys(cls, v) -> List[str]: @field_validator("doc_chain") @classmethod def validate_doc_chain(cls, v: str) -> str: + """Validate the 
doc_chain field.""" if v not in SUPPORTED_DOC_CHAINS: raise ValueError( f"The provided doc_chain, {v} is not supported, supported doc_chains are {SUPPORTED_DOC_CHAINS}" @@ -43,6 +66,7 @@ def validate_doc_chain(cls, v: str) -> str: @field_validator("task") @classmethod def validate_task(cls, v: str) -> str: + """Validate the task field.""" if v not in SUPPORTED_TASKS: raise ValueError( f"The provided task, {v} is not supported, supported tasks are {SUPPORTED_TASKS}" @@ -53,6 +77,7 @@ def validate_task(cls, v: str) -> str: # def load_template(self): # self.prompt = ChatPromptTemplate.from_template(self.template) def __init__(self, **kwargs): + """Initialize the prompt.""" super().__init__(**kwargs) self.prompt = PromptTemplate( input_variables=self.input_keys, template=self.template @@ -61,6 +86,7 @@ def __init__(self, **kwargs): def save( self, filepath: Union[Path, str, None], overwrite=False ) -> Union[None, ValueError]: + """Saves the prompt class into a json file.""" dump = self.model_dump_json(indent=2, exclude_defaults=True, exclude_none=True) if filepath is None: filepath = f"{self.name}.json" @@ -74,17 +100,37 @@ def save( @classmethod def load(cls, filepath: Union[Path, str]): + """Loads a json file and returns a Prompt class.""" with open(f"{filepath}", "r") as f: prompt_json = json.load(f) _prompt = cls(**prompt_json) _prompt.filepath = str(filepath) return _prompt - def format(self, **kwargs): + def format(self, **kwargs) -> str: + """Formats the prompt with provided keys and returns a string.""" return self.prompt.format(**kwargs) class FewShotPrompt(Prompt): + """A class for few-shot prompts. 
+ + Attributes: + name (str): The prompt name (Optional, defaults to "custom_prompt") (Parent Class) + llm_type (str): The type of llm, llama2, etc (Optional, defaults to "None") (Parent Class) + task (str): The task (Optional, defaults to QA) (Parent Class) + source (str): The source of the prompt (Optional, defaults to "NoSource") (Parent Class) + doc_chain (str): The doc chain for the prompt ("stuff", "refine") (Optional, defaults to "stuff") (Parent Class) + language (str): The language of the prompt (Optional, defaults to "en") (Parent Class) + filepath (str): The filepath of the prompt (Optional) (Parent Class) + input_keys (List[str]): The input keys for the prompt (Parent Class) + output_keys (List[str]): The output keys for the prompt + prefix (str): The template prefix for the prompt + suffix (str): The template suffix for the prompt + example_template (str): The template for formatting the examples + examples (List[Dict[str, Any]]): The list of examples, each example is a dictionary with respective keys + """ + output_keys: List[str] examples: List[Dict[str, Any]] prefix: str @@ -95,6 +141,7 @@ class FewShotPrompt(Prompt): ) def __init__(self, **kwargs): + """Initialize the prompt.""" super().__init__(**kwargs) eg_formatter = PromptTemplate( input_vars=self.input_keys + self.output_keys, @@ -111,6 +158,7 @@ def __init__(self, **kwargs): @field_validator("output_keys") @classmethod def validate_output_keys(cls, v) -> List[str]: + """Validate the output_keys field.""" if v is None or v == []: raise ValueError("output_keys cannot be empty") return v @@ -118,6 +166,7 @@ def validate_output_keys(cls, v) -> List[str]: @field_validator("examples") @classmethod def validate_examples(cls, v) -> List[Dict[str, Any]]: + """Validate the examples field.""" if v is None or v == []: raise ValueError("examples cannot be empty") for eg in v: diff --git a/src/grag/components/text_splitter.py b/src/grag/components/text_splitter.py index cff3c7c..d04c9a5 100644 --- 
a/src/grag/components/text_splitter.py +++ b/src/grag/components/text_splitter.py @@ -1,3 +1,9 @@ +"""Class for splitting/chunking text. + +This module provides: +- TextSplitter +""" + from langchain.text_splitter import RecursiveCharacterTextSplitter from .utils import get_config @@ -7,10 +13,23 @@ # %% class TextSplitter: - def __init__(self): + """Class for recursively chunking text; it splits on paragraph breaks first, then line breaks, and so on. + + Attributes: + chunk_size: maximum size of chunk + chunk_overlap: chunk overlap size + """ + + def __init__( + self, + chunk_size: int = text_splitter_conf["chunk_size"], + chunk_overlap: int = text_splitter_conf["chunk_overlap"], + ): + """Initialize TextSplitter.""" self.text_splitter = RecursiveCharacterTextSplitter( - chunk_size=int(text_splitter_conf["chunk_size"]), - chunk_overlap=int(text_splitter_conf["chunk_overlap"]), + chunk_size=int(chunk_size), + chunk_overlap=int(chunk_overlap), length_function=len, is_separator_regex=False, ) + """Initialize TextSplitter using chunk_size and chunk_overlap""" diff --git a/src/grag/components/utils.py b/src/grag/components/utils.py index cb64258..2e34dc9 100644 --- a/src/grag/components/utils.py +++ b/src/grag/components/utils.py @@ -1,3 +1,12 @@ +"""Utils functions. + +This module provides: +- stuff_docs: concats langchain documents into string +- load_prompt: loads json prompt to langchain prompt +- find_config_path: finds the path of the 'config.ini' file by traversing up the directory tree from the current path. +- get_config: retrieves and parses the configuration settings from the 'config.ini' file. +""" + import json import os import textwrap @@ -10,7 +19,9 @@ def stuff_docs(docs: List[Document]) -> str: - """Args: + r"""Concatenates langchain documents into a string using '\n\n' separator. 
+ + Args: docs: List of langchain_core.documents.Document Returns: @@ -20,8 +31,7 @@ def stuff_docs(docs: List[Document]) -> str: def reformat_text_with_line_breaks(input_text, max_width=110): - """Reformat the given text to ensure each line does not exceed a specific width, - preserving existing line breaks. + """Reformat the given text to ensure each line does not exceed a specific width, preserving existing line breaks. Args: input_text (str): The text to be reformatted. @@ -62,7 +72,7 @@ def display_llm_output_and_sources(response_from_llm): def load_prompt(json_file: str | os.PathLike, return_input_vars=False): - """Loads a prompt template from json file and returns a langchain ChatPromptTemplate + """Loads a prompt template from json file and returns a langchain ChatPromptTemplate. Args: json_file: path to the prompt template json file. diff --git a/src/grag/components/vectordb/__init__.py b/src/grag/components/vectordb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/grag/components/vectordb/base.py b/src/grag/components/vectordb/base.py new file mode 100644 index 0000000..b0b0623 --- /dev/null +++ b/src/grag/components/vectordb/base.py @@ -0,0 +1,85 @@ +"""Abstract base class for vector database clients. + +This module provides: +- VectorDB +""" + +from abc import ABC, abstractmethod +from typing import List, Tuple, Union + +from langchain_community.vectorstores.utils import filter_complex_metadata +from langchain_core.documents import Document + + +class VectorDB(ABC): + """Abstract base class for vector database clients.""" + + @abstractmethod + def __len__(self) -> int: + """Number of chunks in the vector database.""" + ... + + @abstractmethod + def delete(self) -> None: + """Delete all chunks in the vector database.""" + + @abstractmethod + def add_docs(self, docs: List[Document], verbose: bool = True) -> None: + """Adds documents to the vector database. 
+ + Args: + docs: List of Documents + verbose: Show progress bar + + Returns: + None + """ + ... + + @abstractmethod + async def aadd_docs(self, docs: List[Document], verbose: bool = True) -> None: + """Adds documents to the vector database (asynchronous). + + Args: + docs: List of Documents + verbose: Show progress bar + + Returns: + None + """ + ... + + @abstractmethod + def get_chunk( + self, query: str, with_score: bool = False, top_k: int = None + ) -> Union[List[Document], List[Tuple[Document, float]]]: + """Returns the most similar chunks from the vector database. + + Args: + query: A query string + with_score: Outputs scores of returned chunks + top_k: Number of top similar chunks to return, if None defaults to self.top_k + + Returns: + list of Documents + """ + ... + + @abstractmethod + async def aget_chunk( + self, query: str, with_score: bool = False, top_k: int = None + ) -> Union[List[Document], List[Tuple[Document, float]]]: + """Returns the most similar chunks from the vector database (asynchronous). + + Args: + query: A query string + with_score: Outputs scores of returned chunks + top_k: Number of top similar chunks to return, if None defaults to self.top_k + + Returns: + list of Documents + """ + ... + + def _filter_metadata(self, docs: List[Document]) -> List[Document]: + return filter_complex_metadata(docs, allowed_types=self.allowed_metadata_types) diff --git a/src/grag/components/chroma_client.py b/src/grag/components/vectordb/chroma_client.py similarity index 53% rename from src/grag/components/chroma_client.py rename to src/grag/components/vectordb/chroma_client.py index 7efd7c3..cac8ab3 100644 --- a/src/grag/components/chroma_client.py +++ b/src/grag/components/vectordb/chroma_client.py @@ -1,10 +1,16 @@ -from typing import List +"""Class for Chroma vector database. 
+ +This module provides: +- ChromaClient +""" + +from typing import List, Tuple, Union import chromadb from grag.components.embedding import Embedding from grag.components.utils import get_config +from grag.components.vectordb.base import VectorDB from langchain_community.vectorstores import Chroma -from langchain_community.vectorstores.utils import filter_complex_metadata from langchain_core.documents import Document from tqdm import tqdm from tqdm.asyncio import tqdm as atqdm @@ -12,7 +18,7 @@ chroma_conf = get_config()["chroma"] -class ChromaClient: +class ChromaClient(VectorDB): """A class for connecting to a hosted Chroma Vectorstore collection. Attributes: @@ -24,15 +30,15 @@ class ChromaClient: name of the collection in the Chroma Vectorstore, each ChromaClient connects to a single collection embedding_type : str type of embedding used, supported 'sentence-transformers' and 'instructor-embedding' - embedding_modelname : str + embedding_model : str model name of embedding used, should correspond to the embedding_type embedding_function a function of the embedding model, derived from the embedding_type and embedding_modelname - chroma_client + client: chromadb.HttpClient Chroma API for client collection Chroma API for the collection - langchain_chroma + langchain_client: langchain_community.vectorstores.Chroma LangChain wrapper for Chroma collection """ @@ -44,7 +50,9 @@ def __init__( embedding_type=chroma_conf["embedding_type"], embedding_model=chroma_conf["embedding_model"], ): - """Args: + """Initialize a ChromaClient object. 
+ + Args: host: IP Address of hosted Chroma Vectorstore, defaults to argument from config file port: port address of hosted Chroma Vectorstore, defaults to argument from config file collection_name: name of the collection in the Chroma Vectorstore, defaults to argument from config file @@ -61,19 +69,35 @@ def __init__( embedding_model=self.embedding_model, embedding_type=self.embedding_type ).embedding_function - self.chroma_client = chromadb.HttpClient(host=self.host, port=self.port) - self.collection = self.chroma_client.get_or_create_collection( + self.client = chromadb.HttpClient(host=self.host, port=self.port) + self.collection = self.client.get_or_create_collection( name=self.collection_name ) - self.langchain_chroma = Chroma( - client=self.chroma_client, + self.langchain_client = Chroma( + client=self.client, collection_name=self.collection_name, embedding_function=self.embedding_function, ) self.allowed_metadata_types = (str, int, float, bool) - def test_connection(self, verbose=True): - """Tests connection with Chroma Vectorstore + def __len__(self) -> int: + """Count the number of chunks in the database.""" + return self.collection.count() + + def delete(self) -> None: + """Delete all the chunks in the database collection.""" + self.client.delete_collection(self.collection_name) + self.collection = self.client.get_or_create_collection( + name=self.collection_name + ) + self.langchain_client = Chroma( + client=self.client, + collection_name=self.collection_name, + embedding_function=self.embedding_function, + ) + + def test_connection(self, verbose=True) -> int: + """Tests connection with Chroma Vectorstore. 
Args: verbose: if True, prints connection status @@ -81,7 +105,7 @@ def test_connection(self, verbose=True): Returns: A random integer if connection is alive else None """ - response = self.chroma_client.heartbeat() + response = self.client.heartbeat() if verbose: if response: print(f"Connection to {self.host}/{self.port} is alive..") @@ -89,8 +113,24 @@ def test_connection(self, verbose=True): print(f"Connection to {self.host}/{self.port} is not alive !!") return response - async def aadd_docs(self, docs: List[Document], verbose=True): - """Asynchronously adds documents to chroma vectorstore + def add_docs(self, docs: List[Document], verbose=True) -> None: + """Adds documents to chroma vectorstore. + + Args: + docs: List of Documents + verbose: Show progress bar + + Returns: + None + """ + docs = self._filter_metadata(docs) + for doc in ( + tqdm(docs, desc=f"Adding to {self.collection_name}:") if verbose else docs + ): + _id = self.langchain_client.add_documents([doc]) + + async def aadd_docs(self, docs: List[Document], verbose=True) -> None: + """Asynchronously adds documents to chroma vectorstore. 
Args: docs: List of Documents @@ -100,37 +140,59 @@ async def aadd_docs(self, docs: List[Document], verbose=True): None """ docs = self._filter_metadata(docs) - # tasks = [self.langchain_chroma.aadd_documents([doc]) for doc in docs] - # if verbose: - # await tqdm_asyncio.gather(*tasks, desc=f'Adding to {self.collection_name}') - # else: - # await asyncio.gather(*tasks) if verbose: for doc in atqdm( docs, desc=f"Adding documents to {self.collection_name}", total=len(docs), ): - await self.langchain_chroma.aadd_documents([doc]) + await self.langchain_client.aadd_documents([doc]) else: for doc in docs: - await self.langchain_chroma.aadd_documents([doc]) + await self.langchain_client.aadd_documents([doc]) - def add_docs(self, docs: List[Document], verbose=True): - """Adds documents to chroma vectorstore + def get_chunk( + self, query: str, with_score: bool = False, top_k: int = None + ) -> Union[List[Document], List[Tuple[Document, float]]]: + """Returns the most similar chunks from the chroma database. Args: - docs: List of Documents - verbose: Show progress bar + query: A query string + with_score: Outputs scores of returned chunks + top_k: Number of top similar chunks to return, if None defaults to self.top_k Returns: - None + list of Documents + """ - docs = self._filter_metadata(docs) - for doc in ( - tqdm(docs, desc=f"Adding to {self.collection_name}:") if verbose else docs - ): - _id = self.langchain_chroma.add_documents([doc]) + if with_score: + return self.langchain_client.similarity_search_with_relevance_scores( + query=query, k=top_k if top_k else 1 + ) + else: + return self.langchain_client.similarity_search( + query=query, k=top_k if top_k else 1 + ) + + async def aget_chunk( + self, query: str, with_score=False, top_k=None + ) -> Union[List[Document], List[Tuple[Document, float]]]: + """Returns the most (cosine) similar chunks from the vector database, asynchronously. 
+ + Args: + query: A query string + with_score: Outputs scores of returned chunks + top_k: Number of top similar chunks to return, if None defaults to self.top_k + + Returns: + list of Documents - def _filter_metadata(self, docs: List[Document]): - return filter_complex_metadata(docs, allowed_types=self.allowed_metadata_types) + """ + if with_score: + return await self.langchain_client.asimilarity_search_with_relevance_scores( + query=query, k=top_k if top_k else 1 + ) + else: + return await self.langchain_client.asimilarity_search( + query=query, k=top_k if top_k else 1 + ) diff --git a/src/grag/components/vectordb/deeplake_client.py b/src/grag/components/vectordb/deeplake_client.py new file mode 100644 index 0000000..f0d5ba5 --- /dev/null +++ b/src/grag/components/vectordb/deeplake_client.py @@ -0,0 +1,159 @@ +"""Class for DeepLake vector database. + +This module provides: +- DeepLakeClient +""" + +from pathlib import Path +from typing import List, Tuple, Union + +from grag.components.embedding import Embedding +from grag.components.utils import get_config +from grag.components.vectordb.base import VectorDB +from langchain_community.vectorstores import DeepLake +from langchain_core.documents import Document +from tqdm import tqdm +from tqdm.asyncio import tqdm as atqdm + +deeplake_conf = get_config()["deeplake"] + + +class DeepLakeClient(VectorDB): + """A class for connecting to a DeepLake Vectorstore. + + Attributes: + store_path : str, Path + The path to store the DeepLake vectorstore. 
+ embedding_type : str + type of embedding used, supported 'sentence-transformers' and 'instructor-embedding' + embedding_model : str + model name of embedding used, should correspond to the embedding_type + embedding_function + a function of the embedding model, derived from the embedding_type and embedding_modelname + client: deeplake.core.vectorstore.VectorStore + DeepLake API + collection + Chroma API for the collection + langchain_client: langchain_community.vectorstores.DeepLake + LangChain wrapper for DeepLake API + """ + + def __init__( + self, + collection_name: str = deeplake_conf["collection_name"], + store_path: Union[str, Path] = deeplake_conf["store_path"], + embedding_type: str = deeplake_conf["embedding_type"], + embedding_model: str = deeplake_conf["embedding_model"], + read_only: bool = False, + ): + """Initialize DeepLake client object.""" + self.store_path = Path(store_path) + self.collection_name = collection_name + self.read_only = read_only + self.embedding_type: str = embedding_type + self.embedding_model: str = embedding_model + + self.embedding_function = Embedding( + embedding_model=self.embedding_model, embedding_type=self.embedding_type + ).embedding_function + + # self.client = VectorStore(path=self.store_path / self.collection_name) + self.langchain_client = DeepLake( + dataset_path=str(self.store_path / self.collection_name), + embedding=self.embedding_function, + read_only=self.read_only, + ) + self.client = self.langchain_client.vectorstore + self.allowed_metadata_types = (str, int, float, bool) + + def __len__(self) -> int: + """Number of chunks in the vector database.""" + return self.client.__len__() + + def delete(self) -> None: + """Delete all chunks in the vector database.""" + self.client.delete(delete_all=True) + + def add_docs(self, docs: List[Document], verbose=True) -> None: + """Adds documents to deeplake vectorstore. 
+ + Args: + docs: List of Documents + verbose: Show progress bar + + Returns: + None + """ + docs = self._filter_metadata(docs) + for doc in ( + tqdm(docs, desc=f"Adding to {self.collection_name}:") if verbose else docs + ): + _id = self.langchain_client.add_documents([doc]) + + async def aadd_docs(self, docs: List[Document], verbose=True) -> None: + """Asynchronously adds documents to deeplake vectorstore. + + Args: + docs: List of Documents + verbose: Show progress bar + + Returns: + None + """ + docs = self._filter_metadata(docs) + if verbose: + for doc in atqdm( + docs, + desc=f"Adding documents to {self.collection_name}", + total=len(docs), + ): + await self.langchain_client.aadd_documents([doc]) + else: + for doc in docs: + await self.langchain_client.aadd_documents([doc]) + + def get_chunk( + self, query: str, with_score: bool = False, top_k: int = None + ) -> Union[List[Document], List[Tuple[Document, float]]]: + """Returns the most similar chunks from the deeplake database. + + Args: + query: A query string + with_score: Outputs scores of returned chunks + top_k: Number of top similar chunks to return, if None defaults to 1 + + Returns: + list of Documents + + """ + if with_score: + return self.langchain_client.similarity_search_with_score( + query=query, k=top_k if top_k else 1 + ) + else: + return self.langchain_client.similarity_search( + query=query, k=top_k if top_k else 1 + ) + + async def aget_chunk( + self, query: str, with_score=False, top_k=None + ) -> Union[List[Document], List[Tuple[Document, float]]]: + """Returns the most similar chunks from the deeplake database, asynchronously. 
+ + Args: + query: A query string + with_score: Outputs scores of returned chunks + top_k: Number of top similar chunks to return, if None defaults to 1 + + Returns: + list of Documents + + """ + if with_score: + return await self.langchain_client.asimilarity_search_with_score( + query=query, k=top_k if top_k else 1 + ) + else: + return await self.langchain_client.asimilarity_search( + query=query, k=top_k if top_k else 1 + ) diff --git a/src/grag/rag/basic_rag.py b/src/grag/rag/basic_rag.py index a99ecdd..da461b6 100644 --- a/src/grag/rag/basic_rag.py +++ b/src/grag/rag/basic_rag.py @@ -1,5 +1,11 @@ +"""Class for Basic RAG. + +This module provides: +- BasicRAG +""" + import json -from typing import List, Union +from typing import List, Optional, Union from grag import prompts from grag.components.llm import LLM @@ -13,8 +19,20 @@ class BasicRAG: + """Class for Basic RAG. + + Attributes: + model_name (str): Name of the llm model + doc_chain (str): Name of the document chain, ("stuff", "refine"), defaults to "stuff" + task (str): Name of task, defaults to "QA" + llm_kwargs (dict): Keyword arguments for LLM class + retriever_kwargs (dict): Keyword arguments for Retriever class + custom_prompt (Prompt): Prompt, defaults to None + """ + def __init__( self, + retriever: Optional[Retriever] = None, model_name=None, doc_chain="stuff", task="QA", @@ -22,10 +40,13 @@ def __init__( retriever_kwargs=None, custom_prompt: Union[Prompt, FewShotPrompt, None] = None, ): - if retriever_kwargs is None: - self.retriever = Retriever() + if retriever is None: + if retriever_kwargs is None: + self.retriever = Retriever() + else: + self.retriever = Retriever(**retriever_kwargs) else: - self.retriever = Retriever(**retriever_kwargs) + self.retriever = retriever if llm_kwargs is None: self.llm_ = LLM() @@ -54,6 +75,7 @@ def __init__( @property def model_name(self): + """Return the name of the model.""" return self._model_name @model_name.setter @@ -67,6 +89,7 @@ def 
model_name(self, value): @property def doc_chain(self): + """Returns the doc_chain.""" return self._doc_chain @doc_chain.setter @@ -86,6 +109,7 @@ def doc_chain(self, value): @property def task(self): + """Returns the task.""" return self._task @task.setter @@ -99,6 +123,7 @@ def task(self, value): self.prompt_matcher() def prompt_matcher(self): + """Matches relevant prompt using model, task and doc_chain.""" matcher_path = self.prompt_path.joinpath("matcher.json") with open(f"{matcher_path}", "r") as f: matcher_dict = json.load(f) @@ -122,7 +147,9 @@ def prompt_matcher(self): @staticmethod def stuff_docs(docs: List[Document]) -> str: - """Args: + r"""Concatenates docs into a string separated by '\n\n'. + + Args: docs: List of langchain_core.documents.Document Returns: @@ -132,6 +159,8 @@ def stuff_docs(docs: List[Document]) -> str: @staticmethod def output_parser(call_func): + """Decorator to format llm output.""" + def output_parser_wrapper(*args, **kwargs): response, sources = call_func(*args, **kwargs) if conf["llm"]["std_out"] == "False": @@ -146,6 +175,7 @@ def output_parser_wrapper(*args, **kwargs): @output_parser def stuff_call(self, query: str): + """Call function for stuff chain.""" retrieved_docs = self.retriever.get_chunk(query) context = self.stuff_docs(retrieved_docs) prompt = self.main_prompt.format(context=context, question=query) @@ -155,6 +185,7 @@ def stuff_call(self, query: str): @output_parser def refine_call(self, query: str): + """Call function for refine chain.""" retrieved_docs = self.retriever.get_chunk(query) sources = [doc.metadata["source"] for doc in retrieved_docs] responses = [] @@ -176,6 +207,7 @@ def refine_call(self, query: str): return responses, sources def __call__(self, query: str): + """Call function for the class.""" if self.doc_chain == "stuff": return self.stuff_call(query) elif self.doc_chain == "refine": diff --git a/src/tests/components/multivec_retriever_test.py b/src/tests/components/multivec_retriever_test.py index 
3ccb3fb..14dad0b 100644 --- a/src/tests/components/multivec_retriever_test.py +++ b/src/tests/components/multivec_retriever_test.py @@ -1,3 +1,92 @@ +import json + +from grag.components.multivec_retriever import Retriever +from langchain_core.documents import Document + +retriever = Retriever() # pass test collection + +doc = Document(page_content="Hello worlds", metadata={"source": "bars"}) + + +def test_retriver_id_gen(): + doc = Document(page_content="Hello world", metadata={"source": "bar"}) + id_ = retriever.id_gen(doc) + assert isinstance(id_, str) + assert len(id_) == 32 + doc.page_content = doc.page_content + 'ABC' + id_1 = retriever.id_gen(doc) + assert id_ == id_1 + doc.metadata["source"] = "bars" + id_1 = retriever.id_gen(doc) + assert id_ != id_1 + + +def test_retriever_gen_doc_ids(): + docs = [Document(page_content="Hello world", metadata={"source": "bar"}), + Document(page_content="Hello", metadata={"source": "foo"})] + ids = retriever.gen_doc_ids(docs) + assert len(ids) == len(docs) + assert all(isinstance(doc_id, str) for doc_id in ids) + + +def test_retriever_split_docs(): + pass + + +def test_retriever_split_docs(): + pass + + +def test_retriever_add_docs(): + # small enough docs to not split. + docs = [Document(page_content= + """And so on this rainbow day, with storms all around them, and blue sky + above, they rode only as far as the valley. But from there, before they + turned to go back, the monuments appeared close, and they loomed + grandly with the background of purple bank and creamy cloud and shafts + of golden lightning. They seemed like sentinels--guardians of a great + and beautiful love born under their lofty heights, in the lonely + silence of day, in the star-thrown shadow of night. They were like that + love. 
And they held Lucy and Slone, calling every day, giving a + nameless and tranquil content, binding them true to love, true to the + sage and the open, true to that wild upland home""", metadata={"source": "test_doc_1"}), + Document(page_content= + """Slone and Lucy never rode down so far as the stately monuments, though + these held memories as hauntingly sweet as others were poignantly + bitter. Lucy never rode the King again. But Slone rode him, learned to + love him. And Lucy did not race any more. When Slone tried to stir in + her the old spirit all the response he got was a wistful shake of head + or a laugh that hid the truth or an excuse that the strain on her + ankles from Joel Creech's lasso had never mended. The girl was + unutterably happy, but it was possible that she would never race a + horse again.""", metadata={"source": "test_doc_2"}), + Document(page_content= + """Bostil wanted to be alone, to welcome the King, to lead him back to the + home corral, perhaps to hide from all eyes the change and the uplift + that would forever keep him from wronging another man. + + The late rains came and like magic, in a few days, the sage grew green + and lustrous and fresh, the gray turning to purple. + + Every morning the sun rose white and hot in a blue and cloudless sky. + And then soon the horizon line showed creamy clouds that rose and + spread and darkened. Every afternoon storms hung along the ramparts and + rainbows curved down beautiful and ethereal. 
The dim blackness of the + storm-clouds was split to the blinding zigzag of lightning, and the + thunder rolled and boomed, like the Colorado in flood.""", metadata={"source": "test_doc_3"}) + ] + ids = retriever.gen_doc_ids(docs) + retriever.add_docs(docs) + retrieved = retriever.store.mget(ids) + assert len(retrieved) == len(ids) + for i, doc in enumerate(docs): + retrieved_doc = json.loads(retrieved[i].decode()) + assert doc.metadata == retrieved_doc.metadata + + +def test_retriever_aadd_docs(): + pass + # # add code folder to sys path # import os # from pathlib import Path diff --git a/src/tests/components/vectordb/__init__.py b/src/tests/components/vectordb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/tests/components/chroma_client_test.py b/src/tests/components/vectordb/chroma_client_test.py similarity index 59% rename from src/tests/components/chroma_client_test.py rename to src/tests/components/vectordb/chroma_client_test.py index 1596dd3..c491dfd 100644 --- a/src/tests/components/chroma_client_test.py +++ b/src/tests/components/vectordb/chroma_client_test.py @@ -1,12 +1,13 @@ import asyncio -from grag.components.chroma_client import ChromaClient +import pytest +from grag.components.vectordb.chroma_client import ChromaClient from langchain_core.documents import Document def test_chroma_connection(): - client = ChromaClient() - response = client.test_connection() + chroma_client = ChromaClient() + response = chroma_client.test_connection() assert isinstance(response, int) @@ -45,14 +46,12 @@ def test_chroma_add_docs(): storm-clouds was split to the blinding zigzag of lightning, and the thunder rolled and boomed, like the Colorado in flood.""", ] - client = ChromaClient(collection_name="test") - if client.collection.count() > 0: - client.chroma_client.delete_collection("test") - client = ChromaClient(collection_name="test") + chroma_client = ChromaClient(collection_name="test") + if len(chroma_client) > 0: + chroma_client.delete() 
docs = [Document(page_content=doc) for doc in docs] - client.add_docs(docs) - collection_count = client.collection.count() - assert collection_count == len(docs) + chroma_client.add_docs(docs) + assert len(chroma_client) == len(docs) def test_chroma_aadd_docs(): @@ -90,11 +89,60 @@ def test_chroma_aadd_docs(): storm-clouds was split to the blinding zigzag of lightning, and the thunder rolled and boomed, like the Colorado in flood.""", ] - client = ChromaClient(collection_name="test") - if client.collection.count() > 0: - client.chroma_client.delete_collection("test") - client = ChromaClient(collection_name="test") + chroma_client = ChromaClient(collection_name="test") + if len(chroma_client) > 0: + chroma_client.delete() docs = [Document(page_content=doc) for doc in docs] loop = asyncio.get_event_loop() - loop.run_until_complete(client.aadd_docs(docs)) - assert client.collection.count() == len(docs) + loop.run_until_complete(chroma_client.aadd_docs(docs)) + assert len(chroma_client) == len(docs) + + +chrome_get_chunk_params = [(1, False), (1, True), (2, False), (2, True)] + + +@pytest.mark.parametrize("top_k,with_score", chrome_get_chunk_params) +def test_chroma_get_chunk(top_k, with_score): + query = """Slone and Lucy never rode down so far as the stately monuments, though + these held memories as hauntingly sweet as others were poignantly + bitter. Lucy never rode the King again. But Slone rode him, learned to + love him. And Lucy did not race any more. When Slone tried to stir in + her the old spirit all the response he got was a wistful shake of head + or a laugh that hid the truth or an excuse that the strain on her + ankles from Joel Creech's lasso had never mended. 
The girl was + unutterably happy, but it was possible that she would never race a + horse again.""" + chroma_client = ChromaClient(collection_name="test") + retrieved_chunks = chroma_client.get_chunk( + query=query, top_k=top_k, with_score=with_score + ) + assert len(retrieved_chunks) == top_k + if with_score: + assert all(isinstance(doc[0], Document) for doc in retrieved_chunks) + assert all(isinstance(doc[1], float) for doc in retrieved_chunks) + else: + assert all(isinstance(doc, Document) for doc in retrieved_chunks) + + +@pytest.mark.parametrize("top_k,with_score", chrome_get_chunk_params) +def test_chroma_aget_chunk(top_k, with_score): + query = """Slone and Lucy never rode down so far as the stately monuments, though + these held memories as hauntingly sweet as others were poignantly + bitter. Lucy never rode the King again. But Slone rode him, learned to + love him. And Lucy did not race any more. When Slone tried to stir in + her the old spirit all the response he got was a wistful shake of head + or a laugh that hid the truth or an excuse that the strain on her + ankles from Joel Creech's lasso had never mended. 
The girl was + unutterably happy, but it was possible that she would never race a + horse again.""" + chroma_client = ChromaClient(collection_name="test") + loop = asyncio.get_event_loop() + retrieved_chunks = loop.run_until_complete( + chroma_client.aget_chunk(query=query, top_k=top_k, with_score=with_score) + ) + assert len(retrieved_chunks) == top_k + if with_score: + assert all(isinstance(doc[0], Document) for doc in retrieved_chunks) + assert all(isinstance(doc[1], float) for doc in retrieved_chunks) + else: + assert all(isinstance(doc, Document) for doc in retrieved_chunks) diff --git a/src/tests/components/vectordb/deeplake_client_test.py b/src/tests/components/vectordb/deeplake_client_test.py new file mode 100644 index 0000000..cea5e61 --- /dev/null +++ b/src/tests/components/vectordb/deeplake_client_test.py @@ -0,0 +1,146 @@ +import asyncio + +import pytest +from grag.components.vectordb.deeplake_client import DeepLakeClient +from langchain_core.documents import Document + + +def test_deeplake_add_docs(): + docs = [ + """And so on this rainbow day, with storms all around them, and blue sky + above, they rode only as far as the valley. But from there, before they + turned to go back, the monuments appeared close, and they loomed + grandly with the background of purple bank and creamy cloud and shafts + of golden lightning. They seemed like sentinels--guardians of a great + and beautiful love born under their lofty heights, in the lonely + silence of day, in the star-thrown shadow of night. They were like that + love. And they held Lucy and Slone, calling every day, giving a + nameless and tranquil content, binding them true to love, true to the + sage and the open, true to that wild upland home.""", + """Slone and Lucy never rode down so far as the stately monuments, though + these held memories as hauntingly sweet as others were poignantly + bitter. Lucy never rode the King again. But Slone rode him, learned to + love him. And Lucy did not race any more. 
When Slone tried to stir in + her the old spirit all the response he got was a wistful shake of head + or a laugh that hid the truth or an excuse that the strain on her + ankles from Joel Creech's lasso had never mended. The girl was + unutterably happy, but it was possible that she would never race a + horse again.""", + """Bostil wanted to be alone, to welcome the King, to lead him back to the + home corral, perhaps to hide from all eyes the change and the uplift + that would forever keep him from wronging another man. + + The late rains came and like magic, in a few days, the sage grew green + and lustrous and fresh, the gray turning to purple. + + Every morning the sun rose white and hot in a blue and cloudless sky. + And then soon the horizon line showed creamy clouds that rose and + spread and darkened. Every afternoon storms hung along the ramparts and + rainbows curved down beautiful and ethereal. The dim blackness of the + storm-clouds was split to the blinding zigzag of lightning, and the + thunder rolled and boomed, like the Colorado in flood.""", + ] + deeplake_client = DeepLakeClient(collection_name="test") + if len(deeplake_client) > 0: + deeplake_client.delete() + docs = [Document(page_content=doc) for doc in docs] + deeplake_client.add_docs(docs) + assert len(deeplake_client) == len(docs) + del deeplake_client + + +def test_chroma_aadd_docs(): + docs = [ + """And so on this rainbow day, with storms all around them, and blue sky + above, they rode only as far as the valley. But from there, before they + turned to go back, the monuments appeared close, and they loomed + grandly with the background of purple bank and creamy cloud and shafts + of golden lightning. They seemed like sentinels--guardians of a great + and beautiful love born under their lofty heights, in the lonely + silence of day, in the star-thrown shadow of night. They were like that + love. 
And they held Lucy and Slone, calling every day, giving a + nameless and tranquil content, binding them true to love, true to the + sage and the open, true to that wild upland home.""", + """Slone and Lucy never rode down so far as the stately monuments, though + these held memories as hauntingly sweet as others were poignantly + bitter. Lucy never rode the King again. But Slone rode him, learned to + love him. And Lucy did not race any more. When Slone tried to stir in + her the old spirit all the response he got was a wistful shake of head + or a laugh that hid the truth or an excuse that the strain on her + ankles from Joel Creech's lasso had never mended. The girl was + unutterably happy, but it was possible that she would never race a + horse again.""", + """Bostil wanted to be alone, to welcome the King, to lead him back to the + home corral, perhaps to hide from all eyes the change and the uplift + that would forever keep him from wronging another man. + + The late rains came and like magic, in a few days, the sage grew green + and lustrous and fresh, the gray turning to purple. + + Every morning the sun rose white and hot in a blue and cloudless sky. + And then soon the horizon line showed creamy clouds that rose and + spread and darkened. Every afternoon storms hung along the ramparts and + rainbows curved down beautiful and ethereal. 
The dim blackness of the + storm-clouds was split to the blinding zigzag of lightning, and the + thunder rolled and boomed, like the Colorado in flood.""", + ] + deeplake_client = DeepLakeClient(collection_name="test") + if len(deeplake_client) > 0: + deeplake_client.delete() + docs = [Document(page_content=doc) for doc in docs] + loop = asyncio.get_event_loop() + loop.run_until_complete(deeplake_client.aadd_docs(docs)) + assert len(deeplake_client) == len(docs) + del deeplake_client + + +deeplake_get_chunk_params = [(1, False), (1, True), (2, False), (2, True)] + + +@pytest.mark.parametrize("top_k,with_score", deeplake_get_chunk_params) +def test_deeplake_get_chunk(top_k, with_score): + query = """Slone and Lucy never rode down so far as the stately monuments, though + these held memories as hauntingly sweet as others were poignantly + bitter. Lucy never rode the King again. But Slone rode him, learned to + love him. And Lucy did not race any more. When Slone tried to stir in + her the old spirit all the response he got was a wistful shake of head + or a laugh that hid the truth or an excuse that the strain on her + ankles from Joel Creech's lasso had never mended. 
The girl was + unutterably happy, but it was possible that she would never race a + horse again.""" + deeplake_client = DeepLakeClient(collection_name="test", read_only=True) + retrieved_chunks = deeplake_client.get_chunk( + query=query, top_k=top_k, with_score=with_score + ) + assert len(retrieved_chunks) == top_k + if with_score: + assert all(isinstance(doc[0], Document) for doc in retrieved_chunks) + assert all(isinstance(doc[1], float) for doc in retrieved_chunks) + else: + assert all(isinstance(doc, Document) for doc in retrieved_chunks) + del deeplake_client + + +@pytest.mark.parametrize("top_k,with_score", deeplake_get_chunk_params) +def test_deeplake_aget_chunk(top_k, with_score): + query = """Slone and Lucy never rode down so far as the stately monuments, though + these held memories as hauntingly sweet as others were poignantly + bitter. Lucy never rode the King again. But Slone rode him, learned to + love him. And Lucy did not race any more. When Slone tried to stir in + her the old spirit all the response he got was a wistful shake of head + or a laugh that hid the truth or an excuse that the strain on her + ankles from Joel Creech's lasso had never mended. 
The girl was + unutterably happy, but it was possible that she would never race a + horse again.""" + deeplake_client = DeepLakeClient(collection_name="test", read_only=True) + loop = asyncio.get_event_loop() + retrieved_chunks = loop.run_until_complete( + deeplake_client.aget_chunk(query=query, top_k=top_k, with_score=with_score) + ) + assert len(retrieved_chunks) == top_k + if with_score: + assert all(isinstance(doc[0], Document) for doc in retrieved_chunks) + assert all(isinstance(doc[1], float) for doc in retrieved_chunks) + else: + assert all(isinstance(doc, Document) for doc in retrieved_chunks) + del deeplake_client diff --git a/src/tests/rag/basic_rag_test.py b/src/tests/rag/basic_rag_test.py index 06db25e..2249028 100644 --- a/src/tests/rag/basic_rag_test.py +++ b/src/tests/rag/basic_rag_test.py @@ -1,4 +1,4 @@ -from typing import List, Text +from typing import Text, List from grag.rag.basic_rag import BasicRAG