diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py
index d0e8008aca..a8a8fd833e 100644
--- a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py
+++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py
@@ -19,6 +19,7 @@
 
 import logging
 import time
+from functools import partial
 from logging import ERROR, WARNING
 from typing import List
 
@@ -119,7 +120,7 @@ async def execute_qa_chain(query: RagQuery, debug: bool) -> RagResponse:
     )
 
     # RAG Guard
-    __rag_guard(inputs, response)
+    rag_guard(inputs, response)
 
     # Calculation of RAG processing time
     rag_duration = '{:.2f}'.format(time.time() - start_time)
@@ -166,62 +167,89 @@ def get_source_content(doc: Document) -> str:
 
 def create_rag_chain(query: RagQuery) -> ConversationalRetrievalChain:
     """
-    Create the RAG chain from RagQuery, using the LLM and Embedding settings specified in the query
+    Create the RAG chain from RagQuery, using the LLM and Embedding settings specified in the query.
 
     Args:
         query: The RAG query
     Returns:
         The RAG chain.
     """
+
     llm_factory = get_llm_factory(setting=query.question_answering_llm_setting)
     em_factory = get_em_factory(setting=query.embedding_question_em_setting)
-    vector_store_factory = get_vector_store_factory(setting=query.vector_store_setting,
-                                                    index_name=query.document_index_name,
-                                                    embedding_function=em_factory.get_embedding_model())
+    vector_store_factory = get_vector_store_factory(
+        setting=query.vector_store_setting,
+        index_name=query.document_index_name,
+        embedding_function=em_factory.get_embedding_model()
+    )
+
+    retriever = vector_store_factory.get_vector_store_retriever(query.document_search_params.to_dict())
 
-    logger.info('RAG chain - LLM template validation')
+    # Log progress and validate prompt template
+    logger.info('RAG chain - Validating LLM prompt template')
     validate_prompt_template(query.question_answering_prompt)
 
     logger.debug('RAG chain - Document index name: %s', query.document_index_name)
-    logger.debug('RAG chain - Create a ConversationalRetrievalChain from LLM')
+    # Build LLM and prompt templates
     llm = llm_factory.get_language_model()
-    rag_prompt = LangChainPromptTemplate.from_template(
-        template=query.question_answering_prompt.template,
-        template_format=query.question_answering_prompt.formatter.value,
-        partial_variables=query.question_answering_prompt.inputs
-    )
-    rag_chain = {
-        "context": lambda x: "\n\n".join(doc.page_content for doc in x["documents"])
-    } | rag_prompt | llm | StrOutputParser(name="rag_chain_output")
+    rag_prompt = build_rag_prompt(query)
 
-    chat_prompt = ChatPromptTemplate.from_messages(
-        [
-            ("system", """Given a chat history and the latest user question which might reference context in \
-            the chat history, formulate a standalone question which can be understood without the chat history. \
-            Do NOT answer the question, just reformulate it if needed and otherwise return it as is."""),
-            MessagesPlaceholder(variable_name="chat_history"),
-            ("human", "{question}"),
-        ]
-    )
-    chat_chain = chat_prompt | llm | StrOutputParser(name="chat_chain_output")
+    # Construct the RAG chain using the prompt and LLM
+    rag_chain = construct_rag_chain(llm, rag_prompt)
 
-    def contextualizing_the_question(inputs: dict):
-        """ Contextualize a question in relation to the history of a conversation.
-        """
-        if inputs.get("chat_history") and len(inputs["chat_history"]) > 0:
-            return chat_chain
-        else:
-            return inputs["question"]
+    # Build the chat chain for question contextualization
+    chat_chain = build_chat_chain(llm)
 
-    retriever = vector_store_factory.get_vector_store_retriever(query.document_search_params.to_dict())
-    rag_chain_with_source = contextualizing_the_question | RunnableParallel(
+    # Function to contextualize the question based on chat history
+    contextualize_question_fn = partial(contextualize_question, chat_chain=chat_chain)
+
+    # Final RAG chain with retriever and source documents
+    rag_chain_with_source = contextualize_question_fn | RunnableParallel(
         {"question": RunnablePassthrough(), "documents": retriever}
     ).assign(answer=rag_chain)
 
     return rag_chain_with_source
 
 
-def __rag_guard(inputs, response):
+def build_rag_prompt(query: RagQuery) -> LangChainPromptTemplate:
+    """
+    Build the RAG prompt template.
+    """
+    return LangChainPromptTemplate.from_template(
+        template=query.question_answering_prompt.template,
+        template_format=query.question_answering_prompt.formatter.value,
+        partial_variables=query.question_answering_prompt.inputs
+    )
+
+def construct_rag_chain(llm, rag_prompt):
+    """
+    Construct the RAG chain from LLM and prompt.
+    """
+    return {
+        "context": lambda x: "\n\n".join(doc.page_content for doc in x["documents"])
    } | rag_prompt | llm | StrOutputParser(name="rag_chain_output")
+
+def build_chat_chain(llm) -> ChatPromptTemplate:
+    """
+    Build the chat chain for contextualizing questions.
+    """
+    return ChatPromptTemplate.from_messages([
+        ("system", """Given a chat history and the latest user question which might reference context in \
+        the chat history, formulate a standalone question which can be understood without the chat history. \
+        Do NOT answer the question, just reformulate it if needed and otherwise return it as is."""),
+        MessagesPlaceholder(variable_name="chat_history"),
+        ("human", "{question}"),
+    ]) | llm | StrOutputParser(name="chat_chain_output")
+
+def contextualize_question(inputs: dict, chat_chain) -> str:
+    """
+    Contextualize the question based on the chat history.
+    """
+    if inputs.get("chat_history") and len(inputs["chat_history"]) > 0:
+        return chat_chain
+    return inputs["question"]
+
+def rag_guard(inputs, response):
     """
     If a 'no_answer' input was given as a rag setting,
     then the RAG system should give no further response when no source document has been found.
@@ -239,7 +267,7 @@ def __rag_guard(inputs, response):
         and response['documents'] == []
     ):
         message = 'The RAG gives an answer when no document has been found!'
-        __rag_log(level=ERROR, message=message, inputs=inputs, response=response)
+        rag_log(level=ERROR, message=message, inputs=inputs, response=response)
         raise GenAIGuardCheckException(ErrorInfo(cause=message))
 
     if (
@@ -247,12 +275,12 @@ def __rag_guard(inputs, response):
         and response['documents'] != []
    ):
         message = 'The RAG gives no answer for user question, but some documents has been found!'
-        __rag_log(level=WARNING, message=message, inputs=inputs, response=response)
+        rag_log(level=WARNING, message=message, inputs=inputs, response=response)
         # Remove source documents
         response['documents'] = []
 
 
-def __rag_log(level, message, inputs, response):
+def rag_log(level, message, inputs, response):
     """
     RAG logging