Merge branch 'main' into quantize
sanchitvj authored Mar 24, 2024
2 parents 14ca30d + 56f16cf commit d161553
Showing 21 changed files with 862 additions and 110 deletions.
8 changes: 6 additions & 2 deletions projects/Basic-RAG/BasicRAG_stuff.py
@@ -1,6 +1,10 @@
-from grag.grag.rag import BasicRAG
+from grag.components.multivec_retriever import Retriever
+from grag.components.vectordb.deeplake_client import DeepLakeClient
+from grag.rag.basic_rag import BasicRAG

-rag = BasicRAG(doc_chain="stuff")
+client = DeepLakeClient(collection_name="test")
+retriever = Retriever(vectordb=client)
+rag = BasicRAG(doc_chain="stuff", retriever=retriever)

if __name__ == "__main__":
while True:
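Because Retriever now takes any VectorDB implementation, this example script can target a different backend purely by swapping the injected client. A minimal sketch, assuming a Chroma client lives in the same vectordb package as DeepLakeClient (that import path and the constructor keywords are assumptions, not shown in this commit):

    from grag.components.multivec_retriever import Retriever
    from grag.components.vectordb.chroma_client import ChromaClient  # path assumed
    from grag.rag.basic_rag import BasicRAG

    # host/port mirror the [chroma] section of src/config.ini below
    client = ChromaClient(host="localhost", port=8000)
    retriever = Retriever(vectordb=client)
    rag = BasicRAG(doc_chain="stuff", retriever=retriever)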
14 changes: 7 additions & 7 deletions projects/Retriver-GUI/retriever_app.py
@@ -46,7 +46,7 @@ def render_search_results(self):
st.write(result.metadata)

def check_connection(self):
-response = self.app.retriever.client.test_connection()
+response = self.app.retriever.vectordb.test_connection()
if response:
return True
else:
@@ -55,14 +55,14 @@ def check_connection(self):
def render_stats(self):
st.write(f'''
**Chroma Client Details:** \n
-Host Address : {self.app.retriever.client.host}:{self.app.retriever.client.port} \n
-Collection Name : {self.app.retriever.client.collection_name} \n
-Embeddings Type : {self.app.retriever.client.embedding_type} \n
-Embeddings Model: {self.app.retriever.client.embedding_model} \n
-Number of docs : {self.app.retriever.client.collection.count()} \n
+Host Address : {self.app.retriever.vectordb.host}:{self.app.retriever.vectordb.port} \n
+Collection Name : {self.app.retriever.vectordb.collection_name} \n
+Embeddings Type : {self.app.retriever.vectordb.embedding_type} \n
+Embeddings Model: {self.app.retriever.vectordb.embedding_model} \n
+Number of docs : {self.app.retriever.vectordb.collection.count()} \n
''')
if st.button('Check Connection'):
-response = self.app.retriever.client.test_connection()
+response = self.app.retriever.vectordb.test_connection()
if response:
st.write(':green[Connection Active]')
else:
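The app changes above are a mechanical rename from client to vectordb, but they also pin down the attribute surface the GUI expects from whatever store is injected. A sketch of that implied contract, using typing.Protocol purely as illustration (the repository's real interface is grag.components.vectordb.base.VectorDB, whose definition is not shown in this diff):

    from typing import Any, Protocol

    class StatsVectorDB(Protocol):
        """Duck-typed surface read by retriever_app.py (illustrative only)."""

        host: str
        port: int
        collection_name: str
        embedding_type: str
        embedding_model: str
        collection: Any  # must expose .count()

        def test_connection(self) -> bool: ...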
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -42,6 +42,7 @@ dependencies = [
"huggingface_hub>=0.20.2",
"pydantic>=2.5.0",
"rouge-score>=0.1.2",
"deeplake>=3.8.27"
]

[project.urls]
@@ -101,9 +102,13 @@ exclude_lines = [
[tool.ruff]
line-length = 88
indent-width = 4
+extend-exclude = ["tests", "others"]

[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I", "D"]
+ignore = ["D104"]
+exclude = ["__about__.py"]


[tool.ruff.format]
quote-style = "double"
16 changes: 15 additions & 1 deletion src/config.ini
@@ -14,6 +14,12 @@ n_gpu_layers_cpp : -1
std_out : True
base_dir : ${root:root_path}/models

+[deeplake]
+collection_name : arxiv
+embedding_type : instructor-embedding
+embedding_model : hkunlp/instructor-xl
+store_path : ${data:data_path}/vectordb

[chroma]
host : localhost
port : 8000
@@ -25,6 +31,14 @@ embedding_model : hkunlp/instructor-xl
store_path : ${data:data_path}/vectordb
allow_reset : True

+[deeplake]
+collection_name : arxiv
+# embedding_type : sentence-transformers
+# embedding_model : "all-mpnet-base-v2"
+embedding_type : instructor-embedding
+embedding_model : hkunlp/instructor-xl
+store_path : ${data:data_path}/vectordb

[text_splitter]
chunk_size : 5000
chunk_overlap : 400
@@ -54,4 +68,4 @@ data_path : ${root:root_path}/data
root_path : /home/ubuntu/volume_2k/Capstone_5

[quantize]
-llama_cpp_path : ${root:root_path}
\ No newline at end of file
+llama_cpp_path : ${root:root_path}
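The ${section:key} references in this file are interpolation syntax; grag resolves them through get_config in grag.components.utils, whose internals are not shown here. A standard-library sketch of the same resolution, assuming it behaves like configparser's ExtendedInterpolation:

    from configparser import ConfigParser, ExtendedInterpolation

    # strict=False tolerates the duplicate [deeplake] section introduced by the
    # merge above; a strict parser would raise DuplicateSectionError.
    config = ConfigParser(interpolation=ExtendedInterpolation(), strict=False)
    config.read("src/config.ini")

    # ${data:data_path}/vectordb expands via [data] -> data_path, which in
    # turn expands ${root:root_path}.
    print(config["deeplake"]["store_path"])
    # /home/ubuntu/volume_2k/Capstone_5/data/vectordb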
8 changes: 8 additions & 0 deletions src/grag/components/embedding.py
@@ -1,3 +1,9 @@
"""Class for embedding.
This module provides:
- Embedding
"""

from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.embeddings.sentence_transformer import (
SentenceTransformerEmbeddings,
@@ -6,6 +12,7 @@

class Embedding:
"""A class for vector embeddings.
Supports:
huggingface sentence transformers -> model_type = 'sentence-transformers'
huggingface instructor embeddings -> model_type = 'instructor-embedding'
@@ -17,6 +24,7 @@ class Embedding:
"""

def __init__(self, embedding_type: str, embedding_model: str):
"""Initialize the embedding with embedding_type and embedding_model."""
self.embedding_type = embedding_type
self.embedding_model = embedding_model
match self.embedding_type:
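A minimal usage sketch of the class, with type/model pairings taken from src/config.ini above; the instance wraps whichever LangChain embedding class the match statement selects:

    from grag.components.embedding import Embedding

    emb = Embedding(
        embedding_type="instructor-embedding",   # or "sentence-transformers"
        embedding_model="hkunlp/instructor-xl",  # or "all-mpnet-base-v2"
    )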
3 changes: 3 additions & 0 deletions src/grag/components/llm.py
@@ -1,3 +1,5 @@
"""Class for LLM."""

import os
from pathlib import Path

@@ -50,6 +52,7 @@ def __init__(
quantization=llm_conf["quantization"],
pipeline=llm_conf["pipeline"],
):
"""Initialize the LLM class using the given parameters."""
self.base_dir = Path(base_dir)
self._model_name = model_name
self.quantization = quantization
76 changes: 43 additions & 33 deletions src/grag/components/multivec_retriever.py
@@ -1,10 +1,17 @@
"""Class for retriever.
This module provides:
- Retriever
"""

import asyncio
import uuid
-from typing import List
+from typing import Any, Dict, List, Optional

-from grag.components.chroma_client import ChromaClient
from grag.components.text_splitter import TextSplitter
from grag.components.utils import get_config
+from grag.components.vectordb.base import VectorDB
+from grag.components.vectordb.deeplake_client import DeepLakeClient
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import LocalFileStore
from langchain_core.documents import Document
@@ -13,14 +20,16 @@


class Retriever:
"""A class for multi vector retriever, it connects to a vector database and a local file store.
It is used to return most similar chunks from a vector store but has the additional funcationality
to return a linked document, chunk, etc.
"""A class for multi vector retriever.
It connects to a vector database and a local file store.
It is used to return most similar chunks from a vector store but has the additional functionality to return a
linked document, chunk, etc.
Attributes:
store_path: Path to the local file store
id_key: A key prefix for identifying documents
-client: ChromaClient class instance from components.chroma_client
+vectordb: ChromaClient class instance from components.client
store: langchain.storage.LocalFileStore object, stores the key value pairs of document id and parent file
retriever: langchain.retrievers.multi_vector.MultiVectorRetriever class instance, langchain's multi-vector retriever
splitter: TextSplitter class instance from components.text_splitter
@@ -31,12 +40,16 @@ class Retriever:

def __init__(
self,
+vectordb: Optional[VectorDB] = None,
store_path: str = multivec_retriever_conf["store_path"],
id_key: str = multivec_retriever_conf["id_key"],
namespace: str = multivec_retriever_conf["namespace"],
top_k=1,
+client_kwargs: Optional[Dict[str, Any]] = None,
):
"""Args:
"""Initialize the Retriever.
Args:
store_path: Path to the local file store, defaults to argument from config file
id_key: A key prefix for identifying documents, defaults to argument from config file
namespace: A namespace for producing unique id, defaults to argument from config file
@@ -45,10 +58,16 @@ def __init__(
self.store_path = store_path
self.id_key = id_key
self.namespace = uuid.UUID(namespace)
-self.client = ChromaClient()
+if vectordb is None:
+    if client_kwargs is not None:
+        self.vectordb = DeepLakeClient(**client_kwargs)
+    else:
+        self.vectordb = DeepLakeClient()
+else:
+    self.vectordb = vectordb
self.store = LocalFileStore(self.store_path)
self.retriever = MultiVectorRetriever(
-vectorstore=self.client.langchain_chroma,
+vectorstore=self.vectordb.langchain_client,
byte_store=self.store,
id_key=self.id_key,
)
@@ -58,6 +77,7 @@ def __init__(

def id_gen(self, doc: Document) -> str:
"""Takes a document and returns a unique id (uuid5) using the namespace and document source.
This ensures that a single document always gets the same unique id.
Args:
@@ -81,7 +101,9 @@ def gen_doc_ids(self, docs: List[Document]) -> List[str]:
return [self.id_gen(doc) for doc in docs]

def split_docs(self, docs: List[Document]) -> List[Document]:
"""Takes a list of documents and splits them into smaller chunks using TextSplitter from compoenents.text_splitter
"""Takes a list of documents and splits them into smaller chunks.
Using TextSplitter from components.text_splitter
Also adds the unique parent document id into metadata
Args:
@@ -101,8 +123,7 @@ def split_docs(self, docs: List[Document]) -> List[Document]:
return chunks

def add_docs(self, docs: List[Document]):
"""Takes a list of documents, splits them using the split_docs method and then adds them into the vector database
and adds the parent document into the file store.
"""Adds given documents into the vector database also adds the parent document into the file store.
Args:
docs: List of langchain_core.documents.Document
@@ -113,12 +134,11 @@ def add_docs(self, docs: List[Document]):
"""
chunks = self.split_docs(docs)
doc_ids = self.gen_doc_ids(docs)
-self.client.add_docs(chunks)
+self.vectordb.add_docs(chunks)
self.retriever.docstore.mset(list(zip(doc_ids, docs)))

async def aadd_docs(self, docs: List[Document]):
"""Takes a list of documents, splits them using the split_docs method and then adds them into the vector database
and adds the parent document into the file store.
"""Adds given documents into the vector database also adds the parent document into the file store.
Args:
docs: List of langchain_core.documents.Document
@@ -129,11 +149,11 @@ async def aadd_docs(self, docs: List[Document]):
"""
chunks = self.split_docs(docs)
doc_ids = self.gen_doc_ids(docs)
-await asyncio.run(self.client.aadd_docs(chunks))
+await asyncio.run(self.vectordb.aadd_docs(chunks))
self.retriever.docstore.mset(list(zip(doc_ids)))

def get_chunk(self, query: str, with_score=False, top_k=None):
"""Returns the most (cosine) similar chunks from the vector database.
"""Returns the most similar chunks from the vector database.
Args:
query: A query string
@@ -144,14 +164,8 @@ def get_chunk(self, query: str, with_score=False, top_k=None):
list of Documents
"""
-if with_score:
-    return self.client.langchain_chroma.similarity_search_with_relevance_scores(
-        query=query, **{"k": top_k} if top_k else self.retriever.search_kwargs
-    )
-else:
-    return self.client.langchain_chroma.similarity_search(
-        query=query, **{"k": top_k} if top_k else self.retriever.search_kwargs
-    )
+_top_k = top_k if top_k else self.retriever.search_kwargs["k"]
+return self.vectordb.get_chunk(query=query, top_k=_top_k, with_score=with_score)

async def aget_chunk(self, query: str, with_score=False, top_k=None):
"""Returns the most (cosine) similar chunks from the vector database, asynchronously.
@@ -165,14 +179,10 @@ async def aget_chunk(self, query: str, with_score=False, top_k=None):
list of Documents
"""
-if with_score:
-    return await self.client.langchain_chroma.asimilarity_search_with_relevance_scores(
-        query=query, **{"k": top_k} if top_k else self.retriever.search_kwargs
-    )
-else:
-    return await self.client.langchain_chroma.asimilarity_search(
-        query=query, **{"k": top_k} if top_k else self.retriever.search_kwargs
-    )
+_top_k = top_k if top_k else self.retriever.search_kwargs["k"]
+return await self.vectordb.aget_chunk(
+    query=query, top_k=_top_k, with_score=with_score
+)

def get_doc(self, query: str):
"""Returns the parent document of the most (cosine) similar chunk from the vector database.
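Taken together, the refactor swaps the hard-coded ChromaClient for an injected VectorDB and funnels every similarity search through vectordb.get_chunk / aget_chunk. A hedged end-to-end sketch (document content and source names are illustrative; the scored-pair return shape of with_score=True follows the delegated call above but is not verified here):

    from grag.components.multivec_retriever import Retriever
    from grag.components.vectordb.deeplake_client import DeepLakeClient
    from langchain_core.documents import Document

    retriever = Retriever(vectordb=DeepLakeClient(collection_name="test"))

    docs = [Document(page_content="GRAG is a RAG library.", metadata={"source": "doc1"})]
    retriever.add_docs(docs)  # chunks go to the vector store, parent docs to the file store

    chunks = retriever.get_chunk("What is GRAG?", top_k=2)
    scored = retriever.get_chunk("What is GRAG?", with_score=True)

Because id_gen derives a uuid5 from the fixed namespace and each document's metadata source, re-adding the same file should overwrite its file-store entry rather than duplicate it.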
13 changes: 11 additions & 2 deletions src/grag/components/parse_pdf.py
@@ -1,3 +1,9 @@
"""Classes for parsing files.
This module provides:
- ParsePDF
"""

from langchain_core.documents import Document
from unstructured.partition.pdf import partition_pdf

@@ -32,7 +38,7 @@
add_captions_to_blocks=parser_conf["add_captions_to_blocks"],
table_as_html=parser_conf["table_as_html"],
):
-# Instantialize instance variables with parameters
+"""Initialize instance variables with parameters."""
self.strategy = strategy
if extract_images: # by default always extract Table
self.extract_image_block_types = [
@@ -72,7 +78,8 @@ def partition(self, path: str):

def classify(self, partitions):
"""Classifies the partitioned elements into Text, Tables, and Images list in a dictionary.
Add captions for each element (if available).
Also adds captions for each element (if available).
Parameters:
partitions (list): The list of partitioned elements from the PDF document.
@@ -117,6 +124,8 @@ def classify(self, partitions):
return classified_elements

def text_concat(self, elements) -> str:
"""Context aware concatenates all elements into a single string."""
full_text = ""
for current_element, next_element in zip(elements, elements[1:]):
curr_type = current_element.category
next_type = next_element.category
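text_concat walks the element list pairwise via zip(elements, elements[1:]). A standalone illustration of the idiom; note the final element never appears as current_element, so the method (truncated above) presumably handles it after the loop:

    items = ["Text", "Table", "Text"]
    for current, nxt in zip(items, items[1:]):
        print(current, "->", nxt)
    # Text -> Table
    # Table -> Text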
(The remaining 13 of the 21 changed files are not shown.)
