From c93e056be5613f53a93e9f7a5147abd373c5d6ce Mon Sep 17 00:00:00 2001 From: Mohamed ASSOUKTI Date: Mon, 21 Oct 2024 17:34:01 +0200 Subject: [PATCH] [DERCBOT-1037] Rewrite the RAG chain using LCEL --- .../callbacks/rag_callback_handler.py | 61 ++++++ .../retriever_json_callback_handler.py | 188 ------------------ .../services/langchain/rag_chain.py | 165 ++++++++------- .../services/test_langchain_callbacks.py | 142 ++++--------- .../server/tests/services/test_rag_chain.py | 12 +- 5 files changed, 206 insertions(+), 362 deletions(-) create mode 100644 gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/callbacks/rag_callback_handler.py delete mode 100644 gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/callbacks/retriever_json_callback_handler.py diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/callbacks/rag_callback_handler.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/callbacks/rag_callback_handler.py new file mode 100644 index 0000000000..e85003d9ac --- /dev/null +++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/callbacks/rag_callback_handler.py @@ -0,0 +1,61 @@ +# Copyright (C) 2023-2024 Credit Mutuel Arkea +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Retriever callback handler for LangChain.""" + +import logging +from typing import Any, Dict, Optional + +from langchain.callbacks.base import BaseCallbackHandler +from langchain_core.messages import SystemMessage, AIMessage +from langchain_core.prompt_values import ChatPromptValue, StringPromptValue + +logger = logging.getLogger(__name__) + + +class RAGCallbackHandler(BaseCallbackHandler): + """Customized RAG callback handler that retrieves data from the chain execution.""" + + records: Dict[str, Any] = { + 'chat_prompt': None, + 'chat_chain_output': None, + 'rag_prompt': None, + 'rag_chain_output': None, + 'documents': None, + } + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Print out that we are entering a chain.""" + + if kwargs['name'] == 'chat_chain_output' and isinstance(inputs, AIMessage): + self.records['chat_chain_output'] = inputs.content + + if kwargs['name'] == 'rag_chain_output' and isinstance(inputs, AIMessage): + self.records['rag_chain_output'] = inputs.content + + if kwargs['name'] == 'RunnableAssign' and 'documents' in inputs: + self.records['documents'] = inputs['documents'] + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Print out that we finished a chain.""" # if outputs is instance of StringPromptValue + + if isinstance(outputs, ChatPromptValue): + self.records['chat_prompt'] = next( + (msg.content for msg in outputs.messages if isinstance(msg, SystemMessage)), None + ) + + if isinstance(outputs, StringPromptValue): + self.records['rag_prompt'] = outputs.text diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/callbacks/retriever_json_callback_handler.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/callbacks/retriever_json_callback_handler.py deleted file mode 100644 index 2ee502e34a..0000000000 --- a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/callbacks/retriever_json_callback_handler.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright (C) 2023-2024 Credit Mutuel Arkea -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -"""Retriever callback handler for LangChain.""" - -import logging -import re -from typing import Any, Dict, List, Optional, Union - -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult - -logger = logging.getLogger(__name__) - - -class RetrieverJsonCallbackHandler(BaseCallbackHandler): - """Callback Handler that reorganize logs to json data.""" - - def __init__(self, color: Optional[str] = None) -> None: - """Initialize callback handler.""" - self.logger = logger - self.color = color - - self.records: Dict[str, Any] = { - # "on_llm_start_records": [], - # "on_llm_token_records": [], - # "on_llm_end_records": [], - 'on_chain_start_records': [], - 'on_chain_end_records': [], - # "on_tool_start_records": [], - # "on_tool_end_records": [], - 'on_text_records': [], - # "on_agent_finish_records": [], - # "on_agent_action_records": [], - 'action_records': [], - } - - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - """Do nothing.""" - pass - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - pass - - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: - """Print out that we are entering a chain.""" - # filter to gest only input documents - if 'input_documents' in inputs: - docs = inputs['input_documents'] - input_documents = [ - {'page_content': doc.page_content, 'metadata': doc.metadata} - for doc in docs - ] - json_data = { - 'event_name': 'on_chain_start', - 'inputs': { - 'input_documents': input_documents, - 'question': inputs['question'], - 'chat_history': inputs['chat_history'], - }, - } - if json_data not in self.records['on_chain_start_records']: - self.records['on_chain_start_records'].append(json_data) - if json_data not in self.records['action_records']: - self.records['action_records'].append(json_data) - - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Print out that we finished a chain.""" - # reponse FAQ - if 'text' in outputs: - json_data = {'event_name': 'on_chain_end', 'output': outputs['text']} - if json_data not in self.records['on_chain_end_records']: - self.records['on_chain_end_records'].append(json_data) - if json_data not in self.records['action_records']: - self.records['action_records'].append(json_data) - - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - pass - - def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - **kwargs: Any, - ) -> None: - """Do nothing.""" - pass - - def on_agent_action( - self, action: AgentAction, color: Optional[str] = None, **kwargs: Any - ) -> Any: - """Do nothing.""" - pass - - def on_tool_end( - self, - output: str, - color: Optional[str] = None, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, - **kwargs: Any, - ) -> None: - """Do nothing.""" - pass - - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - pass - - def on_text( - self, - text: str, - color: Optional[str] = None, - end: str = '', - **kwargs: Any, - ) -> None: - """Run when agent ends.""" - json_data = { - 'event_name': 'on_text', - 'text': self.normalise_prompt(text), - } - if json_data not in self.records['on_text_records']: - self.records['on_text_records'].append(json_data) - if json_data not in self.records['action_records']: - self.records['action_records'].append(json_data) - - def on_agent_finish( - self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any - ) -> None: - """Do nothing.""" - pass - - def show_records(self, record_name: str = None): - """Show registered records from handler""" - if record_name != None and record_name in self.records: - records = self.records[record_name] - else: - records = self.records - return records - - - def normalise_prompt(self, prompt: str): - """ - Remove 'on after prompt' and color on prompt. - To identify the color ansi sequence, the function uses this regular expression : \x1B\[[0-?]*[ -/]*[@-~] - - Args: - prompt: the prompt to normalise - """ - - # remove ansi escape sequences - ansi_escape = re.compile(r'\x1B\[[0-?]*[ -/]*[@-~]') - prompt = ansi_escape.sub('', prompt) - - # remove a static sentence - return prompt.replace('Prompt after formatting:\n', '') - diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py index 895ab11eba..a8a8fd833e 100644 --- a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py +++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py @@ -19,13 +19,16 @@ import logging import time +from functools import partial from logging import ERROR, WARNING -from typing import List, Optional +from typing import List from langchain.chains import ConversationalRetrievalChain from langchain_community.chat_message_histories import ChatMessageHistory from langchain_core.documents import Document -from langchain_core.prompts import PromptTemplate as LangChainPromptTemplate +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import PromptTemplate as LangChainPromptTemplate, ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables import RunnablePassthrough, RunnableParallel from gen_ai_orchestrator.errors.exceptions.exceptions import ( GenAIGuardCheckException, @@ -48,8 +51,8 @@ ) from gen_ai_orchestrator.routers.requests.requests import RagQuery from gen_ai_orchestrator.routers.responses.responses import RagResponse -from gen_ai_orchestrator.services.langchain.callbacks.retriever_json_callback_handler import ( - RetrieverJsonCallbackHandler, +from gen_ai_orchestrator.services.langchain.callbacks.rag_callback_handler import ( + RAGCallbackHandler, ) from gen_ai_orchestrator.services.langchain.factories.langchain_factory import ( get_em_factory, @@ -100,7 +103,7 @@ async def execute_qa_chain(query: RagQuery, debug: bool) -> RagResponse: ) callback_handlers = [] - records_callback_handler = RetrieverJsonCallbackHandler() + records_callback_handler = RAGCallbackHandler() if debug: # Debug callback handler callback_handlers.append(records_callback_handler) @@ -117,7 +120,7 @@ async def execute_qa_chain(query: RagQuery, debug: bool) -> RagResponse: ) # RAG Guard - __rag_guard(inputs, response) + rag_guard(inputs, response) # Calculation of RAG processing time rag_duration = '{:.2f}'.format(time.time() - start_time) @@ -135,12 +138,12 @@ async def execute_qa_chain(query: RagQuery, debug: bool) -> RagResponse: url=doc.metadata['source'], content=get_source_content(doc), ), - response['source_documents'], + response['documents'], ) ), ), debug=get_rag_debug_data( - query, response, records_callback_handler, rag_duration + query, records_callback_handler, rag_duration ) if debug else None @@ -164,40 +167,89 @@ def get_source_content(doc: Document) -> str: def create_rag_chain(query: RagQuery) -> ConversationalRetrievalChain: """ - Create the RAG chain from RagQuery, using the LLM and Embedding settings specified in the query + Create the RAG chain from RagQuery, using the LLM and Embedding settings specified in the query. Args: query: The RAG query Returns: The RAG chain. """ + llm_factory = get_llm_factory(setting=query.question_answering_llm_setting) em_factory = get_em_factory(setting=query.embedding_question_em_setting) - vector_store_factory = get_vector_store_factory(setting=query.vector_store_setting, - index_name=query.document_index_name, - embedding_function=em_factory.get_embedding_model()) + vector_store_factory = get_vector_store_factory( + setting=query.vector_store_setting, + index_name=query.document_index_name, + embedding_function=em_factory.get_embedding_model() + ) + retriever = vector_store_factory.get_vector_store_retriever(query.document_search_params.to_dict()) - logger.info('RAG chain - LLM template validation') + # Log progress and validate prompt template + logger.info('RAG chain - Validating LLM prompt template') validate_prompt_template(query.question_answering_prompt) + logger.debug('RAG chain - Document index name: %s', query.document_index_name) + # Build LLM and prompt templates + llm = llm_factory.get_language_model() + rag_prompt = build_rag_prompt(query) - logger.debug('RAG chain - Document index name: %s', query.document_index_name) - logger.debug('RAG chain - Create a ConversationalRetrievalChain from LLM') - return ConversationalRetrievalChain.from_llm( - llm=llm_factory.get_language_model(), - retriever=vector_store_factory.get_vector_store_retriever(query.document_search_params.to_dict()), - return_source_documents=True, - return_generated_question=True, - combine_docs_chain_kwargs={ - 'prompt': LangChainPromptTemplate.from_template( - template=query.question_answering_prompt.template, - template_format=query.question_answering_prompt.formatter.value, - ) - }, + # Construct the RAG chain using the prompt and LLM + rag_chain = construct_rag_chain(llm, rag_prompt) + + # Build the chat chain for question contextualization + chat_chain = build_chat_chain(llm) + + # Function to contextualize the question based on chat history + contextualize_question_fn = partial(contextualize_question, chat_chain=chat_chain) + + # Final RAG chain with retriever and source documents + rag_chain_with_source = contextualize_question_fn | RunnableParallel( + {"question": RunnablePassthrough(), "documents": retriever} + ).assign(answer=rag_chain) + + return rag_chain_with_source + + +def build_rag_prompt(query: RagQuery) -> LangChainPromptTemplate: + """ + Build the RAG prompt template. + """ + return LangChainPromptTemplate.from_template( + template=query.question_answering_prompt.template, + template_format=query.question_answering_prompt.formatter.value, + partial_variables=query.question_answering_prompt.inputs ) -def __rag_guard(inputs, response): +def construct_rag_chain(llm, rag_prompt): + """ + Construct the RAG chain from LLM and prompt. + """ + return { + "context": lambda x: "\n\n".join(doc.page_content for doc in x["documents"]) + } | rag_prompt | llm | StrOutputParser(name="rag_chain_output") + +def build_chat_chain(llm) -> ChatPromptTemplate: + """ + Build the chat chain for contextualizing questions. + """ + return ChatPromptTemplate.from_messages([ + ("system", """Given a chat history and the latest user question which might reference context in \ + the chat history, formulate a standalone question which can be understood without the chat history. \ + Do NOT answer the question, just reformulate it if needed and otherwise return it as is."""), + MessagesPlaceholder(variable_name="chat_history"), + ("human", "{question}"), + ]) | llm | StrOutputParser(name="chat_chain_output") + +def contextualize_question(inputs: dict, chat_chain) -> str: + """ + Contextualize the question based on the chat history. + """ + if inputs.get("chat_history") and len(inputs["chat_history"]) > 0: + return chat_chain + return inputs["question"] + +def rag_guard(inputs, response): """ If a 'no_answer' input was given as a rag setting, then the RAG system should give no further response when no source document has been found. @@ -212,23 +264,23 @@ def __rag_guard(inputs, response): if 'no_answer' in inputs: if ( response['answer'] != inputs['no_answer'] - and response['source_documents'] == [] + and response['documents'] == [] ): message = 'The RAG gives an answer when no document has been found!' - __rag_log(level=ERROR, message=message, inputs=inputs, response=response) + rag_log(level=ERROR, message=message, inputs=inputs, response=response) raise GenAIGuardCheckException(ErrorInfo(cause=message)) if ( response['answer'] == inputs['no_answer'] - and response['source_documents'] != [] + and response['documents'] != [] ): message = 'The RAG gives no answer for user question, but some documents has been found!' - __rag_log(level=WARNING, message=message, inputs=inputs, response=response) + rag_log(level=WARNING, message=message, inputs=inputs, response=response) # Remove source documents - response['source_documents'] = [] + response['documents'] = [] -def __rag_log(level, message, inputs, response): +def rag_log(level, message, inputs, response): """ RAG logging @@ -247,12 +299,12 @@ def __rag_log(level, message, inputs, response): 'message': message, 'question': inputs['question'], 'answer': response['answer'], - 'documents': response['source_documents'], + 'documents': response['documents'], }, ) -def get_rag_documents(handler: RetrieverJsonCallbackHandler) -> List[RagDocument]: +def get_rag_documents(handler: RAGCallbackHandler) -> List[RagDocument]: """ Get documents used on RAG context @@ -260,56 +312,29 @@ def get_rag_documents(handler: RetrieverJsonCallbackHandler) -> List[RagDocument handler: the callback handler """ - on_chain_start_records = handler.show_records('on_chain_start_records') return [ # Get first 100 char of content RagDocument( - content=doc['page_content'][0:len(doc['metadata']['title'])+100] + '...', - metadata=RagDocumentMetadata(**doc['metadata']), + content=doc.page_content[0:len(doc.metadata['title'])+100] + '...', + metadata=RagDocumentMetadata(**doc.metadata), ) - for doc in on_chain_start_records[0]['inputs']['input_documents'] + for doc in handler.records['documents'] ] -def get_condense_question(handler: RetrieverJsonCallbackHandler) -> Optional[str]: - """Get the condensed question""" - - on_text_records = handler.show_records('on_text_records') - # If the handler records 2 texts (prompts), this means that 2 LLM providers are invoked - if len(on_text_records) == 2: - # So the user question is condensed - on_chain_start_records = handler.show_records('on_chain_start_records') - return on_chain_start_records[0]['inputs']['question'] - else: - # Else, the user's question was not formulated - return None - - -def get_llm_prompts(handler: RetrieverJsonCallbackHandler) -> (Optional[str], str): - """Get used llm prompt""" - - on_text_records = handler.show_records('on_text_records') - # If the handler records 2 texts (prompts), this means that 2 LLM providers are invoked - if len(on_text_records) == 2: - return on_text_records[0]['text'], on_text_records[1]['text'] - - # Else, only the LLM for "question answering" was invoked - return None, on_text_records[0]['text'] - - def get_rag_debug_data( - query, response, records_callback_handler, rag_duration + query: RagQuery, records_callback_handler: RAGCallbackHandler, rag_duration ) -> RagDebugData: """RAG debug data assembly""" return RagDebugData( user_question=query.question_answering_prompt.inputs['question'], - condense_question_prompt=get_llm_prompts(records_callback_handler)[0], - condense_question=get_condense_question(records_callback_handler), - question_answering_prompt=get_llm_prompts(records_callback_handler)[1], + condense_question_prompt=records_callback_handler.records['chat_prompt'], + condense_question=records_callback_handler.records['chat_chain_output'], + question_answering_prompt=records_callback_handler.records['rag_prompt'], documents=get_rag_documents(records_callback_handler), document_index_name=query.document_index_name, document_search_params=query.document_search_params, - answer=response['answer'], + answer=records_callback_handler.records['rag_chain_output'], duration=rag_duration, ) diff --git a/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_langchain_callbacks.py b/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_langchain_callbacks.py index af2a66a6d0..8895550589 100644 --- a/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_langchain_callbacks.py +++ b/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_langchain_callbacks.py @@ -13,112 +13,58 @@ # limitations under the License. # from langchain_core.documents import Document +from langchain_core.messages import AIMessage, SystemMessage, HumanMessage +from langchain_core.prompt_values import StringPromptValue, ChatPromptValue -from gen_ai_orchestrator.services.langchain.callbacks.retriever_json_callback_handler import ( - RetrieverJsonCallbackHandler, +from gen_ai_orchestrator.services.langchain.callbacks.rag_callback_handler import ( + RAGCallbackHandler, ) -def test_retriever_json_callback_handler_on_chain_start(): +def test_rag_callback_handler_qa_documents(): """Check records are added (in the correct entries)""" - handler = RetrieverJsonCallbackHandler() - _inputs = { - 'input_documents': [ - Document( - page_content='some page content', - metadata={'some meta': 'some meta value'}, - ) - ], - 'question': 'What is happening?', - 'chat_history': [], - } - handler.on_chain_start(serialized={}, inputs=_inputs) - expected_json_data = { - 'event_name': 'on_chain_start', - 'inputs': { - 'input_documents': [ - { - 'page_content': 'some page content', - 'metadata': {'some meta': 'some meta value'}, - } - ], - 'question': _inputs['question'], - 'chat_history': _inputs['chat_history'], - }, - } - assert handler.records['on_chain_start_records'][0] == expected_json_data - assert handler.records['action_records'][0] == expected_json_data - - -def test_retriever_json_callback_handler_on_chain_start_no_double_entries(): - """Check records are added only once in history.""" - handler = RetrieverJsonCallbackHandler() - _inputs = { - 'input_documents': [ - Document( - page_content='some page content', - metadata={'some meta': 'some meta value'}, - ) - ], - 'question': 'What is happening?', - 'chat_history': [], - } - handler.on_chain_start(serialized={}, inputs=_inputs) - expected_json_data = { - 'event_name': 'on_chain_start', - 'inputs': { - 'input_documents': [ - { - 'page_content': 'some page content', - 'metadata': {'some meta': 'some meta value'}, - } - ], - 'question': _inputs['question'], - 'chat_history': _inputs['chat_history'], - }, - } - assert expected_json_data in handler.records['on_chain_start_records'] - assert expected_json_data in handler.records['action_records'] - assert len(handler.records['on_chain_start_records']) == 1 - assert len(handler.records['action_records']) == 1 - handler.on_chain_start(serialized={}, inputs=_inputs) - assert expected_json_data in handler.records['on_chain_start_records'] - assert expected_json_data in handler.records['action_records'] - assert len(handler.records['on_chain_start_records']) == 1 - assert len(handler.records['action_records']) == 1 - - -def test_retriever_json_callback_handler_on_chain_start_no_inputs(): - """Check no records are added if none are present in chain inputs.""" - handler = RetrieverJsonCallbackHandler() - _inputs = {'question': 'What is happening?', 'chat_history': []} - handler.on_chain_start(serialized={}, inputs=_inputs) - assert len(handler.records['on_chain_start_records']) == 0 - assert len(handler.records['action_records']) == 0 + handler = RAGCallbackHandler() + docs = [Document( + page_content='some page content', + metadata={'some meta': 'some meta value'}, + )] + handler.on_chain_start(serialized={}, + inputs={'documents': docs}, + **{'name': 'RunnableAssign'}) + assert handler.records['documents'] == docs +def test_rag_callback_handler_chat_prompt_output(): + """Check records are added (in the correct entries)""" + handler = RAGCallbackHandler() + llm_output = 'llm result !' + handler.on_chain_start(serialized={}, + inputs=AIMessage(content=llm_output), + **{'name': 'chat_chain_output'}) + assert handler.records['chat_chain_output'] == llm_output -def test_retriever_json_callback_handler_on_chain_end(): +def test_rag_callback_handler_qa_prompt_output(): """Check records are added (in the correct entries)""" - handler = RetrieverJsonCallbackHandler() - _outputs = { - 'text': 'This is what is happening', - } - handler.on_chain_end(outputs=_outputs) - expected_json_data = { - 'event_name': 'on_chain_end', - 'output': 'This is what is happening', - } - assert handler.records['on_chain_end_records'][0] == expected_json_data - assert handler.records['action_records'][0] == expected_json_data + handler = RAGCallbackHandler() + llm_output = 'llm result !' + handler.on_chain_start(serialized={}, + inputs=AIMessage(content=llm_output), + **{'name': 'rag_chain_output'}) + assert handler.records['rag_chain_output'] == llm_output +def test_rag_callback_handler_chat_prompt(): + """Check records are added (in the correct entries)""" + handler = RAGCallbackHandler() + prompt = 'A custom prompt !' + outputs = ChatPromptValue(messages=[ + SystemMessage(content=prompt), + HumanMessage(content='hi !') + ]) + handler.on_chain_end(serialized={}, outputs=outputs) + assert handler.records['chat_prompt'] == prompt -def test_retriever_json_callback_handler_on_text(): +def test_rag_callback_handler_qa_prompt(): """Check records are added (in the correct entries)""" - handler = RetrieverJsonCallbackHandler() - handler.on_text(text='Some text arrives') - expected_json_data = { - 'event_name': 'on_text', - 'text': 'Some text arrives', - } - assert handler.records['on_text_records'][0] == expected_json_data - assert handler.records['action_records'][0] == expected_json_data + handler = RAGCallbackHandler() + prompt = 'A custom prompt !' + handler.on_chain_end(serialized={}, outputs=StringPromptValue(text=prompt)) + assert handler.records['rag_prompt'] == prompt diff --git a/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_rag_chain.py b/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_rag_chain.py index 73ae69eea3..cd9786af82 100644 --- a/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_rag_chain.py +++ b/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_rag_chain.py @@ -25,8 +25,8 @@ ) from gen_ai_orchestrator.routers.requests.requests import RagQuery from gen_ai_orchestrator.services.langchain import rag_chain -from gen_ai_orchestrator.services.langchain.callbacks.retriever_json_callback_handler import ( - RetrieverJsonCallbackHandler, +from gen_ai_orchestrator.services.langchain.callbacks.rag_callback_handler import ( + RAGCallbackHandler, ) from gen_ai_orchestrator.services.langchain.rag_chain import ( execute_qa_chain, @@ -232,7 +232,7 @@ def test_rag_guard_removes_docs_if_no_answer(mocked_log): def test_get_llm_prompts_one_record(): - handler = RetrieverJsonCallbackHandler() + handler = RAGCallbackHandler() handler.on_text(text='LLM 1') llm_1, llm_2 = get_llm_prompts(handler) assert llm_1 is None @@ -240,7 +240,7 @@ def test_get_llm_prompts_one_record(): def test_get_llm_prompts_one_record(): - handler = RetrieverJsonCallbackHandler() + handler = RAGCallbackHandler() handler.on_text(text='LLM 1') handler.on_text(text='LLM 2') llm_1, llm_2 = get_llm_prompts(handler) @@ -249,7 +249,7 @@ def test_get_llm_prompts_one_record(): def test_get_condense_question_none(): - handler = RetrieverJsonCallbackHandler() + handler = RAGCallbackHandler() handler.on_text(text='LLM 1') handler.on_chain_start( serialized={}, @@ -264,7 +264,7 @@ def test_get_condense_question_none(): def test_get_condense_question(): - handler = RetrieverJsonCallbackHandler() + handler = RAGCallbackHandler() handler.on_text(text='LLM 1') handler.on_text(text='LLM 2') handler.on_chain_start(