Skip to content

Commit

Permalink
Fix app code
Browse files: browse the repository at this point in the history
Signed-off-by: Sanket <[email protected]>
  • Loading branch information
sanketsudake committed Aug 22, 2024
1 parent 76cedfa commit 054edbc
Showing 1 changed file with 39 additions and 25 deletions.
64 changes: 39 additions & 25 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_community.vectorstores.chroma import Chroma
import chromadb
from chromadb.config import Settings
from chromadb.utils.embedding_functions import HuggingFaceEmbeddingServer
Expand Down Expand Up @@ -175,20 +176,15 @@ def load_documents(self, doc, num_docs=250):
return documents

def chunk_doc(self, pages, chunk_size=512, chunk_overlap=30):
    """Split the given documents into chunks with a tokenizer-aware splitter.

    Args:
        pages: Documents handed to the splitter's ``split_documents``.
        chunk_size: Forwarded to the splitter as ``chunk_size``.
        chunk_overlap: Forwarded to the splitter as ``chunk_overlap``.

    Returns:
        The chunks produced by ``split_documents`` for ``pages``.
    """
    # The splitter is built from the "BAAI/bge-large-en-v1.5" tokenizer.
    # NOTE(review): presumably this matches the embedding model used
    # elsewhere in the app -- confirm against the embedding service config.
    tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-large-en-v1.5")
    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    chunks = text_splitter.split_documents(pages)
    print("Document chunked")
    return chunks

def insert_embeddings(
self, chunks, chroma_embedding_function, batch_size=32
):
def insert_embeddings(self, chunks, chroma_embedding_function, batch_size=32):
print(
"Inserting embeddings into collection: {collection_name}".format(
collection_name=self.collection_name
Expand All @@ -212,9 +208,11 @@ def insert_embeddings(
print("Embeddings inserted\n")
# return db

def query_docs(self, model, question, vector_store, prompt, chat_history, use_reranker=False):
def query_docs(
self, model, question, vector_store, prompt, chat_history, use_reranker=False
):
retriever = vector_store.as_retriever(
search_type="similarity", search_kwargs={"k": 10}
search_type="similarity", search_kwargs={"k": 10}
)
if use_reranker:
compressor = TEIRerank(
Expand All @@ -236,22 +234,28 @@ def query_docs(self, model, question, vector_store, prompt, chat_history, use_re
| model
| StrOutputParser()
)

answer = rag_chain.invoke({"question": question, "chat_history": chat_history})
return answer


def format_docs(docs):
    """Join the ``page_content`` of each document into one string.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string.

    Returns:
        The contents separated by a blank line (``"\\n\\n"``); an empty
        string when ``docs`` is empty.
    """
    return "\n\n".join(doc.page_content for doc in docs)


def create_retriever(
name, model, description, client, chroma_embedding_function, embedder, reranker=False
name, description, client, chroma_embedding_function, embedding_svc, reranker=False
):
rag = RAG(collection_name="Slack", db_client=client)
collection_name = "software-slacks"
rag = RAG(collection_name=collection_name, db_client=client)
pages = rag.load_documents("spencer/software_slacks", num_docs=100)
chunks = rag.chunk_doc(pages)
vector_store = rag.insert_embeddings(chunks, chroma_embedding_function, embedder)

rag.insert_embeddings(chunks, chroma_embedding_function)
vector_store = Chroma(
embedding_function=embedding_svc,
collection_name=collection_name,
client=client,
)
if reranker:
compressor = TEIRerank(
url="http://{host}:{port}".format(
Expand All @@ -261,6 +265,7 @@ def create_retriever(
top_n=10,
batch_size=16,
)

retriever = vector_store.as_retriever(
search_type="similarity", search_kwargs={"k": 100}
)
Expand All @@ -278,22 +283,31 @@ def create_retriever(


def setup_tools(_model, _client, _chroma_embedding_function, _embedder):
    """Build the list of tools, skipping those whose API key is not set.

    The StackExchange and Tavily tools are only created when their
    environment variable holds a non-blank value. The Slack conversations
    retriever is always added last.

    Args:
        _model: Not used by this function; kept so existing callers
            continue to work.
        _client: Passed to ``create_retriever`` as ``client``.
        _chroma_embedding_function: Passed to ``create_retriever`` as
            ``chroma_embedding_function``.
        _embedder: Passed to ``create_retriever`` as ``embedding_svc``.

    Returns:
        A list of the tools that were created, in the order: StackExchange
        (if enabled), Tavily web search (if enabled), Slack retriever.
    """
    tools = []

    # A single lookup with a "" default replaces the double getenv call;
    # unset and whitespace-only values both count as "not configured".
    if os.getenv("STACK_OVERFLOW_API_KEY", "").strip():
        stackexchange_wrapper = StackExchangeAPIWrapper(max_results=3)
        stackexchange_tool = StackExchangeTool(api_wrapper=stackexchange_wrapper)
        tools.append(stackexchange_tool)

    if os.getenv("TAVILY_API_KEY", "").strip():
        web_search_tool = TavilySearchResults(max_results=10, handle_tool_error=True)
        tools.append(web_search_tool)

    # Only the exact string "True" enables the reranker.
    use_reranker = os.getenv("USE_RERANKER", "False") == "True"
    retriever = create_retriever(
        name="slack_conversations_retriever",
        description="Useful for when you need to answer from Slack conversations.",
        client=_client,
        chroma_embedding_function=_chroma_embedding_function,
        embedding_svc=_embedder,
        reranker=use_reranker,
    )
    tools.append(retriever)

    return tools


@st.cache_resource
Expand Down

0 comments on commit 054edbc

Please sign in to comment.