
Commit

[DERCBOT-1037] TU OK
assouktim committed Oct 23, 2024
1 parent cb693dd commit 3a890f3
Showing 1 changed file with 65 additions and 37 deletions.
@@ -19,6 +19,7 @@

import logging
import time
from functools import partial
from logging import ERROR, WARNING
from typing import List

@@ -119,7 +120,7 @@ async def execute_qa_chain(query: RagQuery, debug: bool) -> RagResponse:
)

# RAG Guard
__rag_guard(inputs, response)
rag_guard(inputs, response)

# Calculation of RAG processing time
rag_duration = '{:.2f}'.format(time.time() - start_time)
@@ -166,62 +167,89 @@ def get_source_content(doc: Document) -> str:

def create_rag_chain(query: RagQuery) -> ConversationalRetrievalChain:
"""
Create the RAG chain from RagQuery, using the LLM and Embedding settings specified in the query
Create the RAG chain from RagQuery, using the LLM and Embedding settings specified in the query.
Args:
query: The RAG query
Returns:
The RAG chain.
"""

llm_factory = get_llm_factory(setting=query.question_answering_llm_setting)
em_factory = get_em_factory(setting=query.embedding_question_em_setting)
vector_store_factory = get_vector_store_factory(setting=query.vector_store_setting,
index_name=query.document_index_name,
embedding_function=em_factory.get_embedding_model())
vector_store_factory = get_vector_store_factory(
setting=query.vector_store_setting,
index_name=query.document_index_name,
embedding_function=em_factory.get_embedding_model()
)
retriever = vector_store_factory.get_vector_store_retriever(query.document_search_params.to_dict())

logger.info('RAG chain - LLM template validation')
# Log progress and validate prompt template
logger.info('RAG chain - Validating LLM prompt template')
validate_prompt_template(query.question_answering_prompt)

logger.debug('RAG chain - Document index name: %s', query.document_index_name)
logger.debug('RAG chain - Create a ConversationalRetrievalChain from LLM')

# Build LLM and prompt templates
llm = llm_factory.get_language_model()
rag_prompt = LangChainPromptTemplate.from_template(
template=query.question_answering_prompt.template,
template_format=query.question_answering_prompt.formatter.value,
partial_variables=query.question_answering_prompt.inputs
)
rag_chain = {
"context": lambda x: "\n\n".join(doc.page_content for doc in x["documents"])
} | rag_prompt | llm | StrOutputParser(name="rag_chain_output")
rag_prompt = build_rag_prompt(query)

chat_prompt = ChatPromptTemplate.from_messages(
[
("system", """Given a chat history and the latest user question which might reference context in \
the chat history, formulate a standalone question which can be understood without the chat history. \
Do NOT answer the question, just reformulate it if needed and otherwise return it as is."""),
MessagesPlaceholder(variable_name="chat_history"),
("human", "{question}"),
]
)
chat_chain = chat_prompt | llm | StrOutputParser(name="chat_chain_output")
# Construct the RAG chain using the prompt and LLM
rag_chain = construct_rag_chain(llm, rag_prompt)

def contextualizing_the_question(inputs: dict):
""" Contextualize a question in relation to the history of a conversation. """
if inputs.get("chat_history") and len(inputs["chat_history"]) > 0:
return chat_chain
else:
return inputs["question"]
# Build the chat chain for question contextualization
chat_chain = build_chat_chain(llm)

retriever = vector_store_factory.get_vector_store_retriever(query.document_search_params.to_dict())
rag_chain_with_source = contextualizing_the_question | RunnableParallel(
# Function to contextualize the question based on chat history
contextualize_question_fn = partial(contextualize_question, chat_chain=chat_chain)

# Final RAG chain with retriever and source documents
rag_chain_with_source = contextualize_question_fn | RunnableParallel(
{"question": RunnablePassthrough(), "documents": retriever}
).assign(answer=rag_chain)

return rag_chain_with_source
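
# Illustrative usage sketch - not part of this commit. `rag_query` is an
# assumed RagQuery carrying the LLM, embedding and vector store settings.
chain = create_rag_chain(query=rag_query)

response = chain.invoke({
    "question": "What are the opening hours?",
    "chat_history": [],  # LangChain messages from previous turns, empty on a first turn
})

# The RunnableParallel/assign combination yields the contextualized question,
# the retrieved documents and the generated answer.
print(response["answer"])
print(len(response["documents"]))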


def __rag_guard(inputs, response):
def build_rag_prompt(query: RagQuery) -> LangChainPromptTemplate:
"""
Build the RAG prompt template.
"""
return LangChainPromptTemplate.from_template(
template=query.question_answering_prompt.template,
template_format=query.question_answering_prompt.formatter.value,
partial_variables=query.question_answering_prompt.inputs
)

def construct_rag_chain(llm, rag_prompt):
"""
Construct the RAG chain from LLM and prompt.
"""
return {
"context": lambda x: "\n\n".join(doc.page_content for doc in x["documents"])
} | rag_prompt | llm | StrOutputParser(name="rag_chain_output")

def build_chat_chain(llm) -> ChatPromptTemplate:
"""
Build the chat chain for contextualizing questions.
"""
return ChatPromptTemplate.from_messages([
("system", """Given a chat history and the latest user question which might reference context in \
the chat history, formulate a standalone question which can be understood without the chat history. \
Do NOT answer the question, just reformulate it if needed and otherwise return it as is."""),
MessagesPlaceholder(variable_name="chat_history"),
("human", "{question}"),
]) | llm | StrOutputParser(name="chat_chain_output")

def contextualize_question(inputs: dict, chat_chain) -> str:
"""
Contextualize the question based on the chat history.
"""
if inputs.get("chat_history") and len(inputs["chat_history"]) > 0:
return chat_chain
return inputs["question"]
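
The `partial` call in `create_rag_chain` pre-binds `chat_chain`, turning `contextualize_question` into a single-argument callable that LCEL coerces into the head of the pipeline; when the callable returns a Runnable, LangChain invokes it with the same inputs, and when it returns a plain string, that string is passed downstream unchanged. A minimal sketch of this routing pattern in isolation (illustrative only, not part of this commit):

# Illustrative routing sketch - not part of this commit.
from functools import partial
from langchain_core.runnables import RunnableLambda

def route(inputs: dict, reformulation_chain):
    # Returning a Runnable makes LCEL invoke it with the same inputs;
    # returning a plain value passes that value downstream unchanged.
    if inputs.get("chat_history"):
        return reformulation_chain
    return inputs["question"]

# Stand-in for the LLM-backed chat_chain: just upper-cases the question.
reformulate = RunnableLambda(lambda x: x["question"].upper())

chain = partial(route, reformulation_chain=reformulate) | RunnableLambda(lambda q: f"standalone: {q}")

print(chain.invoke({"question": "hi", "chat_history": []}))       # standalone: hi
print(chain.invoke({"question": "hi", "chat_history": ["prev"]})) # standalone: HI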

def rag_guard(inputs, response):
"""
If a 'no_answer' input was given as a rag setting,
then the RAG system should give no further response when no source document has been found.
@@ -239,20 +267,20 @@ def __rag_guard(inputs, response):
and response['documents'] == []
):
message = 'The RAG gives an answer when no document has been found!'
__rag_log(level=ERROR, message=message, inputs=inputs, response=response)
rag_log(level=ERROR, message=message, inputs=inputs, response=response)
raise GenAIGuardCheckException(ErrorInfo(cause=message))

if (
response['answer'] == inputs['no_answer']
and response['documents'] != []
):
message = 'The RAG gives no answer for the user question, but some documents have been found!'
__rag_log(level=WARNING, message=message, inputs=inputs, response=response)
rag_log(level=WARNING, message=message, inputs=inputs, response=response)
# Remove source documents
response['documents'] = []


def __rag_log(level, message, inputs, response):
def rag_log(level, message, inputs, response):
"""
RAG logging
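Consistent with the commit message ([DERCBOT-1037] TU OK, i.e. unit tests passing), renaming __rag_guard/__rag_log to rag_guard/rag_log makes these helpers importable from tests. A minimal test sketch for the branch that is fully visible in the hunk (the import path and the no-answer sentinel are assumptions):

# Illustrative test sketch - not part of this commit.
from langchain_core.documents import Document

# Assumed import path; adjust to the project's actual module layout.
from gen_ai_orchestrator.services.langchain.rag_chain import rag_guard


def test_rag_guard_strips_documents_when_no_answer_is_returned():
    inputs = {'no_answer': 'Sorry, I do not know.'}          # assumed no-answer sentinel
    response = {
        'answer': 'Sorry, I do not know.',                   # answer matches the sentinel
        'documents': [Document(page_content='retrieved chunk')],
    }

    rag_guard(inputs, response)

    # The visible branch logs a warning and removes the source documents.
    assert response['documents'] == []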
