Skip to content

Commit

Permalink
Add runnable qa chain
Browse files Browse the repository at this point in the history
  • Loading branch information
jamakase committed Sep 14, 2024
1 parent b42f0d0 commit 14e5a80
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 6 deletions.
6 changes: 3 additions & 3 deletions services/ml/app/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

class Config:
PORT: int = int(os.environ.get("PORT", 8000))
QDRANT_HOST: str = os.environ.get("QDRANT_HOST", "http://localhost:6333")
QDRANT_HOST: str = os.environ.get("QDRANT_HOST", "http://84.201.156.82:6333")
QDRANT_COLLECTION_NAME: str = os.environ.get("QDRANT_COLLECTION_NAME", "docs_768")
QDRANT_API_KEY: str = os.environ.get("QDRANT_API_KEY")
QDRANT_API_KEY: str = os.environ.get("QDRANT_API_KEY", "somwerhjsadimqwe")
LLM_SOURCE: str = os.environ.get("LLM_SOURCE", "openai")

def __init__(self):
if self.LLM_SOURCE == "openai":
self.MODEL = os.environ.get("MODEL", "meta-llama/llama-3.1-8b-instruct:free")
self.MODEL = os.environ.get("MODEL", "meta-llama/llama-3.1-8b-instruct")
self.OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://openrouter.ai/api/v1")
self.OPENAI_API_KEY = os.environ.get("")
self.EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")
Expand Down
34 changes: 31 additions & 3 deletions services/ml/app/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,42 @@
from packages.retriever import Retriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda
from sentence_transformers import CrossEncoder
from langchain.prompts import ChatPromptTemplate
from operator import itemgetter

from .config import config

# LLM wrapper built from the loaded configuration; it also supplies the
# embeddings handed to the retriever below.
llm_instance = LLMInstance(config)

# Retriever wrapper configured from the QDRANT_* settings.
retriever_instance = Retriever(
    llm_instance.get_embeddings(),
    host=config.QDRANT_HOST,
    collection_name=config.QDRANT_COLLECTION_NAME,
    api_key=config.QDRANT_API_KEY,
)

# Cross-encoder used by rerank_documents_with_crossencoder to score
# (question, passage) pairs.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")

# Fetch 20 candidates per query; the reranking step trims them afterwards.
retriever = retriever_instance.get_retriever()
retriever.search_kwargs["k"] = 20

def rerank_documents_with_crossencoder(query_and_docs, *, top_k=3):
    """Rerank retrieved documents against the question and keep the best ones.

    Each document's ``page_content`` is paired with the question and scored by
    the module-level ``cross_encoder``. Documents are ordered from the highest
    score to the lowest; documents with equal scores keep their incoming order.

    Args:
        query_and_docs: Mapping with a ``"docs"`` entry (documents exposing
            ``page_content``) and a ``"question"`` entry.
        top_k: Maximum number of documents to keep; expected to be
            non-negative. Defaults to 3, the previously hard-coded cut-off.

    Returns:
        A dict with ``"docs"`` (at most ``top_k`` best-scoring documents) and
        ``"question"`` (passed through unchanged).
    """
    docs = list(query_and_docs["docs"])
    query = query_and_docs["question"]

    if not docs:
        # Nothing was retrieved: skip the model call entirely. There is
        # nothing to score, and an empty batch is not guaranteed to be
        # accepted by every cross-encoder implementation.
        return {"docs": [], "question": query}

    scores = cross_encoder.predict([(query, doc.page_content) for doc in docs])

    # Sort positions rather than (score, doc) pairs so documents themselves
    # are never compared; sorted() is stable, so ties keep retrieval order.
    order = sorted(range(len(docs)), key=lambda i: scores[i], reverse=True)
    return {"docs": [docs[i] for i in order[:top_k]], "question": query}

def join_docs(docs):
    """Concatenate the text of the given documents into one string.

    Only each document's ``page_content`` is used; the pieces are separated
    by a single space. An empty input yields an empty string.
    """
    texts = []
    for document in docs:
        texts.append(document.page_content)
    return " ".join(texts)

# Question-answering chain. Stages, in order: gather documents and the
# question, rerank the documents with the cross-encoder, join the kept
# documents into a single context string, fill the prompt, call the LLM,
# parse the output to a string and wrap it as {"result": ...}.
# NOTE(review): this listing comes from a rendered diff without +/- markers.
# The two dict literals inside the first RunnableParallel look like the
# removed and the added version of the same line; presumably only the second
# one ({"docs": ..., "question": ...}) exists in the real file -- confirm.
qa_chain = (
RunnableParallel(
{"context": retriever_instance.get_retriever(), "question": RunnablePassthrough()}
{"docs": retriever, "question": itemgetter("question")}
)
# Reorder the documents by cross-encoder score and keep the top ones.
| RunnableLambda(rerank_documents_with_crossencoder)
# Turn the kept documents into the "context" text; pass "question" through.
| RunnableParallel({"context": itemgetter("docs") | RunnableLambda(join_docs), "question": itemgetter("question")})
| prompt
| llm_instance.get_llm()
| StrOutputParser()
# Wrap the parsed string under the "result" key.
| (lambda x: {"result": x})
).with_types(input_type=InputChat, output_type=OutputChat)


app = FastAPI(
title="LangChain Server",
version="1.0",
Expand Down Expand Up @@ -57,6 +74,17 @@

add_routes(app, path="/search", runnable=retriever_chain)


def _rewrite_s3_links(inputs):
    """Return ``inputs`` with S3 bucket URIs in the context rewritten.

    The preceding ``RunnableParallel`` step produces a mapping with
    ``"context"`` and ``"question"`` entries, so the replacement has to be
    applied to the context value; calling ``.replace`` on the mapping itself
    raises ``AttributeError``. The context is converted with ``str()`` first,
    so the substitution runs on its string form.

    NOTE(review): presumably the bucket URIs live in the retrieved documents'
    metadata, which is why the whole string form is rewritten -- confirm.

    Args:
        inputs: Mapping holding at least a ``"context"`` entry.

    Returns:
        A new dict with the same entries as ``inputs``, except ``"context"``,
        which is the rewritten string.
    """
    context_text = str(inputs["context"]).replace(
        "s3://afana-propdoc-production", "https://"
    )
    return {**inputs, "context": context_text}


# Chain served at /vqa: retrieve context, rewrite S3 links inside it, fill
# the prompt and call the LLM.
# NOTE(review): unlike qa_chain, the raw LLM output is returned here without
# StrOutputParser or a {"result": ...} wrapper -- confirm that this matches
# OutputChat.
vqa_chain = (
    RunnableParallel(
        {"context": retriever_instance.get_retriever(), "question": RunnablePassthrough()}
    )
    | RunnableLambda(_rewrite_s3_links)
    | prompt
    | llm_instance.get_llm()
).with_types(input_type=InputChat, output_type=OutputChat)

add_routes(app, path="/vqa", runnable=vqa_chain)

if __name__ == "__main__":
import uvicorn

Expand Down
2 changes: 2 additions & 0 deletions services/ml/packages/prompts/prompt_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

prompt = ChatPromptTemplate.from_template("""Вы являетесь помощником в выполнении поиска ответов на вопросы по нормативной документации по строительству объектов.
Используйте приведенные ниже фрагменты извлеченного контекста, чтобы ответить на вопрос.
При ответе ссылайтесь на источник, откуда вы взяли информацию (например, информация может лежать в метаданных в headers или pdf_file).
Старайтесь изложить информацию четко и подробно.
Если вы не знаете ответа, просто скажите, что вы не знаете.
Question: {question}
Context: {context}
Expand Down

0 comments on commit 14e5a80

Please sign in to comment.