diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py
index 42586bb7..48a864e4 100644
--- a/src/neo4j_graphrag/generation/graphrag.py
+++ b/src/neo4j_graphrag/generation/graphrag.py
@@ -24,12 +24,7 @@
     RagInitializationError,
     SearchValidationError,
 )
-from neo4j_graphrag.generation.prompts import (
-    SUMMARY_SYSTEM_MESSAGE,
-    RagTemplate,
-    ChatSummaryTemplate,
-    ConversationTemplate,
-)
+from neo4j_graphrag.generation.prompts import RagTemplate
 from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
 from neo4j_graphrag.llm import LLMInterface
 from neo4j_graphrag.llm.types import LLMMessage
@@ -151,11 +146,35 @@ def search(
     def build_query(
         self, query_text: str, message_history: Optional[list[LLMMessage]] = None
     ) -> str:
+        summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words."
         if message_history:
-            summarization_prompt = ChatSummaryTemplate(message_history=message_history)
+            summarization_prompt = self.chat_summary_prompt(
+                message_history=message_history
+            )
             summary = self.llm.invoke(
                 input=summarization_prompt,
-                system_instruction=SUMMARY_SYSTEM_MESSAGE,
+                system_instruction=summary_system_message,
             ).content
-            return ConversationTemplate(summary=summary, current_query=query_text)
+            return self.conversation_prompt(summary=summary, current_query=query_text)
         return query_text
+
+    def chat_summary_prompt(self, message_history: list[LLMMessage]) -> str:
+        message_list = [
+            ": ".join([f"{value}" for _, value in message.items()])
+            for message in message_history
+        ]
+        history = "\n".join(message_list)
+        return f"""
+Summarize the message history:
+
+{history}
+"""
+
+    def conversation_prompt(self, summary: str, current_query: str) -> str:
+        return f"""
+Message Summary:
+{summary}
+
+Current Query:
+{current_query}
+"""
diff --git a/src/neo4j_graphrag/generation/prompts.py b/src/neo4j_graphrag/generation/prompts.py
index 190ae32e..365d74c0 100644
--- a/src/neo4j_graphrag/generation/prompts.py
+++ b/src/neo4j_graphrag/generation/prompts.py
@@ -21,7 +21,6 @@
     PromptMissingInputError,
     PromptMissingPlaceholderError,
 )
-from neo4j_graphrag.llm.types import LLMMessage
 
 
 class PromptTemplate:
@@ -197,29 +196,3 @@ def format(
         text: str = "",
     ) -> str:
         return super().format(text=text, schema=schema, examples=examples)
-
-
-SUMMARY_SYSTEM_MESSAGE = "You are a summarization assistant. Summarize the given text in no more than 200 words"
-
-
-def ChatSummaryTemplate(message_history: list[LLMMessage]) -> str:
-    message_list = [
-        ": ".join([f"{value}" for _, value in message.items()])
-        for message in message_history
-    ]
-    history = "\n".join(message_list)
-    return f"""
-Summarize the message history:
-
-{history}
-"""
-
-
-def ConversationTemplate(summary: str, current_query: str) -> str:
-    return f"""
-Message Summary:
-{summary}
-
-Current Query:
-{current_query}
-"""
diff --git a/tests/unit/test_graphrag.py b/tests/unit/test_graphrag.py
index 59d26ac7..178d34f7 100644
--- a/tests/unit/test_graphrag.py
+++ b/tests/unit/test_graphrag.py
@@ -17,11 +17,7 @@
 import pytest
 from neo4j_graphrag.exceptions import RagInitializationError, SearchValidationError
 from neo4j_graphrag.generation.graphrag import GraphRAG
-from neo4j_graphrag.generation.prompts import (
-    RagTemplate,
-    ChatSummaryTemplate,
-    ConversationTemplate,
-)
+from neo4j_graphrag.generation.prompts import RagTemplate
 from neo4j_graphrag.generation.types import RagResultModel
 from neo4j_graphrag.llm import LLMResponse
 from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem
@@ -126,7 +122,7 @@
 user: initial question
 assistant: answer to initial question
 """
-    first_invokation_system_instruction = "You are a summarization assistant. Summarize the given text in no more than 200 words"
+    first_invokation_system_instruction = "You are a summarization assistant. Summarize the given text in no more than 300 words."
     second_invokation = """Answer the user question using the following context
 
 Context:
@@ -180,14 +176,18 @@ def test_graphrag_search_error(retriever_mock: MagicMock, llm: MagicMock) -> Non
     assert "Input should be a valid string" in str(excinfo)
 
 
-def test_chat_summary_template() -> None:
+def test_chat_summary_template(retriever_mock: MagicMock, llm: MagicMock) -> None:
     message_history = [
         {"role": "user", "content": "initial question"},
         {"role": "assistant", "content": "answer to initial question"},
         {"role": "user", "content": "second question"},
        {"role": "assistant", "content": "answer to second question"},
     ]
-    prompt = ChatSummaryTemplate(message_history=message_history)  # type: ignore
+    rag = GraphRAG(
+        retriever=retriever_mock,
+        llm=llm,
+    )
+    prompt = rag.chat_summary_prompt(message_history=message_history)  # type: ignore
     assert (
         prompt
         == """
@@ -201,8 +201,12 @@ def test_chat_summary_template() -> None:
     )
 
 
-def test_conversation_template() -> None:
-    prompt = ConversationTemplate(
+def test_conversation_template(retriever_mock: MagicMock, llm: MagicMock) -> None:
+    rag = GraphRAG(
+        retriever=retriever_mock,
+        llm=llm,
+    )
+    prompt = rag.conversation_prompt(
         summary="llm generated chat summary", current_query="latest question"
     )
     assert (