From 14e5a80fb3da8fa215adc64803e4d383d4627247 Mon Sep 17 00:00:00 2001 From: Artem Astapenko Date: Sat, 14 Sep 2024 11:10:33 +0300 Subject: [PATCH] Add runnable qa chain --- services/ml/app/config/config.py | 6 ++-- services/ml/app/server.py | 34 +++++++++++++++++++++-- services/ml/packages/prompts/prompt_qa.py | 2 ++ 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/services/ml/app/config/config.py b/services/ml/app/config/config.py index 1441ea2..2bd8192 100644 --- a/services/ml/app/config/config.py +++ b/services/ml/app/config/config.py @@ -2,14 +2,14 @@ class Config: PORT: int = int(os.environ.get("PORT", 8000)) - QDRANT_HOST: str = os.environ.get("QDRANT_HOST", "http://localhost:6333") + QDRANT_HOST: str = os.environ.get("QDRANT_HOST", "http://84.201.156.82:6333") QDRANT_COLLECTION_NAME: str = os.environ.get("QDRANT_COLLECTION_NAME", "docs_768") - QDRANT_API_KEY: str = os.environ.get("QDRANT_API_KEY") + QDRANT_API_KEY: str = os.environ.get("QDRANT_API_KEY", "somwerhjsadimqwe") LLM_SOURCE: str = os.environ.get("LLM_SOURCE", "openai") def __init__(self): if self.LLM_SOURCE == "openai": - self.MODEL = os.environ.get("MODEL", "meta-llama/llama-3.1-8b-instruct:free") + self.MODEL = os.environ.get("MODEL", "meta-llama/llama-3.1-8b-instruct") self.OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://openrouter.ai/api/v1") self.OPENAI_API_KEY = os.environ.get("") self.EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2") diff --git a/services/ml/app/server.py b/services/ml/app/server.py index 70403c0..5b6af90 100644 --- a/services/ml/app/server.py +++ b/services/ml/app/server.py @@ -9,25 +9,42 @@ from packages.retriever import Retriever from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables import RunnableLambda +from sentence_transformers import CrossEncoder +from langchain.prompts import ChatPromptTemplate +from operator import itemgetter from .config import config llm_instance = LLMInstance(config) retriever_instance = Retriever(llm_instance.get_embeddings(), host=config.QDRANT_HOST, collection_name=config.QDRANT_COLLECTION_NAME, api_key=config.QDRANT_API_KEY) +cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2') + +retriever = retriever_instance.get_retriever() +retriever.search_kwargs['k'] = 20 + +def rerank_documents_with_crossencoder(query_and_docs): + docs = query_and_docs["docs"] + query = query_and_docs["question"] + scores = cross_encoder.predict([(query, doc.page_content) for doc in docs]) + ranked_docs = [doc for _, doc in sorted(zip(scores, docs), key=lambda x: x[0], reverse=True)] + return {"docs": ranked_docs[:3], "question": query} + +def join_docs(docs): + return " ".join([doc.page_content for doc in docs]) -# Modify the qa_chain definition qa_chain = ( RunnableParallel( - {"context": retriever_instance.get_retriever(), "question": RunnablePassthrough()} + {"docs": retriever, "question": itemgetter("question")} ) + | RunnableLambda(rerank_documents_with_crossencoder) + | RunnableParallel({"context": itemgetter("docs") | RunnableLambda(join_docs), "question": itemgetter("question")}) | prompt | llm_instance.get_llm() | StrOutputParser() | (lambda x: {"result": x}) ).with_types(input_type=InputChat, output_type=OutputChat) - app = FastAPI( title="LangChain Server", version="1.0", @@ -57,6 +74,17 @@ add_routes(app, path="/search", runnable=retriever_chain) +vqa_chain = ( + RunnableParallel( + {"context": retriever_instance.get_retriever(), "question": RunnablePassthrough()} + ) + | RunnableLambda(lambda x: x.replace("s3://afana-propdoc-production", "https://")) + | prompt + | llm_instance.get_llm() +).with_types(input_type=InputChat, output_type=OutputChat) + +add_routes(app, path="/vqa", runnable=vqa_chain) + if __name__ == "__main__": import uvicorn diff --git a/services/ml/packages/prompts/prompt_qa.py b/services/ml/packages/prompts/prompt_qa.py index 188aecc..972f14e 100644 --- a/services/ml/packages/prompts/prompt_qa.py +++ b/services/ml/packages/prompts/prompt_qa.py @@ -2,6 +2,8 @@ prompt = ChatPromptTemplate.from_template("""Вы являетесь помощником в выполнении поиска ответов на вопросы по нормативной документации по строительству объектов. Используйте приведенные ниже фрагменты извлеченного контекста, чтобы ответить на вопрос. + При ответе ссылайтесь на источник, откуда вы взяли информацию (например, информация может лежать в метаданных в headers или pdf_file). + Старайтесь изложить информацию четко и подробно. Если вы не знаете ответа, просто скажите, что вы не знаете. Question: {question} Context: {context}