Skip to content

Commit

Permalink
Move and rename the prompts
Browse files Browse the repository at this point in the history
* ... for query embedding and summarization to the GraphRAG class
  • Loading branch information
leila-messallem committed Dec 20, 2024
1 parent fa12a9f commit f5a9833
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 46 deletions.
37 changes: 28 additions & 9 deletions src/neo4j_graphrag/generation/graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,7 @@
RagInitializationError,
SearchValidationError,
)
from neo4j_graphrag.generation.prompts import (
SUMMARY_SYSTEM_MESSAGE,
RagTemplate,
ChatSummaryTemplate,
ConversationTemplate,
)
from neo4j_graphrag.generation.prompts import RagTemplate
from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.llm.types import LLMMessage
Expand Down Expand Up @@ -151,11 +146,35 @@ def search(
def build_query(
    self, query_text: str, message_history: "Optional[list[LLMMessage]]" = None
) -> str:
    """Build the final query text sent to retrieval/generation.

    If a message history is provided, the history is first summarized by
    the LLM and that summary is combined with *query_text* into a
    conversation prompt; otherwise *query_text* is returned unchanged.

    Args:
        query_text: The user's current question.
        message_history: Prior chat messages (role/content dicts), or None.

    Returns:
        The query string to embed/search with.
    """
    summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words."
    if message_history:
        # Summarize the prior conversation so the combined query stays short.
        summarization_prompt = self.chat_summary_prompt(
            message_history=message_history
        )
        summary = self.llm.invoke(
            input=summarization_prompt,
            system_instruction=summary_system_message,
        ).content
        return self.conversation_prompt(summary=summary, current_query=query_text)
    return query_text

def chat_summary_prompt(self, message_history: "list[LLMMessage]") -> str:
    """Return a prompt asking the LLM to summarize *message_history*.

    Each message is rendered as its dict values joined by ": "
    (typically "role: content"), one message per line.

    Args:
        message_history: Prior chat messages as role/content dicts.

    Returns:
        The summarization prompt string.
    """
    # Only the values are used, so iterate .values() directly instead of
    # unpacking and discarding the keys from .items() (ruff PERF102);
    # join also takes a generator rather than a throwaway list.
    message_list = [
        ": ".join(str(value) for value in message.values())
        for message in message_history
    ]
    history = "\n".join(message_list)
    return f"""
Summarize the message history:
{history}
"""

def conversation_prompt(self, summary: str, current_query: str) -> str:
    """Combine the chat-history summary and the current user query into
    a single conversation prompt string."""
    lines = (
        "",
        "Message Summary:",
        summary,
        "Current Query:",
        current_query,
        "",
    )
    return "\n".join(lines)
27 changes: 0 additions & 27 deletions src/neo4j_graphrag/generation/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
PromptMissingInputError,
PromptMissingPlaceholderError,
)
from neo4j_graphrag.llm.types import LLMMessage


class PromptTemplate:
Expand Down Expand Up @@ -197,29 +196,3 @@ def format(
text: str = "",
) -> str:
return super().format(text=text, schema=schema, examples=examples)


# System instruction for the chat-history summarization LLM call.
# NOTE(review): the inline replacement in graphrag.py says "no more than
# 300 words." — confirm which word limit is actually intended.
SUMMARY_SYSTEM_MESSAGE = "You are a summarization assistant. Summarize the given text in no more than 200 words"


def ChatSummaryTemplate(message_history: "list[LLMMessage]") -> str:
    """Return a prompt asking the LLM to summarize *message_history*.

    Each message is rendered as its dict values joined by ": "
    (typically "role: content"), one message per line.

    Args:
        message_history: Prior chat messages as role/content dicts.

    Returns:
        The summarization prompt string.
    """
    # Only the values are used, so iterate .values() directly instead of
    # unpacking and discarding the keys from .items() (ruff PERF102);
    # join also takes a generator rather than a throwaway list.
    message_list = [
        ": ".join(str(value) for value in message.values())
        for message in message_history
    ]
    history = "\n".join(message_list)
    return f"""
Summarize the message history:
{history}
"""


def ConversationTemplate(summary: str, current_query: str) -> str:
    """Render the conversation summary and the current user query as a
    single prompt string."""
    return (
        f"\nMessage Summary:\n{summary}"
        f"\nCurrent Query:\n{current_query}\n"
    )
24 changes: 14 additions & 10 deletions tests/unit/test_graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@
import pytest
from neo4j_graphrag.exceptions import RagInitializationError, SearchValidationError
from neo4j_graphrag.generation.graphrag import GraphRAG
from neo4j_graphrag.generation.prompts import (
RagTemplate,
ChatSummaryTemplate,
ConversationTemplate,
)
from neo4j_graphrag.generation.prompts import RagTemplate
from neo4j_graphrag.generation.types import RagResultModel
from neo4j_graphrag.llm import LLMResponse
from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem
Expand Down Expand Up @@ -126,7 +122,7 @@ def test_graphrag_happy_path_with_message_history(
user: initial question
assistant: answer to initial question
"""
first_invokation_system_instruction = "You are a summarization assistant. Summarize the given text in no more than 200 words"
first_invokation_system_instruction = "You are a summarization assistant. Summarize the given text in no more than 300 words."
second_invokation = """Answer the user question using the following context
Context:
Expand Down Expand Up @@ -180,14 +176,18 @@ def test_graphrag_search_error(retriever_mock: MagicMock, llm: MagicMock) -> Non
assert "Input should be a valid string" in str(excinfo)


def test_chat_summary_template() -> None:
def test_chat_summary_template(retriever_mock: MagicMock, llm: MagicMock) -> None:
message_history = [
{"role": "user", "content": "initial question"},
{"role": "assistant", "content": "answer to initial question"},
{"role": "user", "content": "second question"},
{"role": "assistant", "content": "answer to second question"},
]
prompt = ChatSummaryTemplate(message_history=message_history) # type: ignore
rag = GraphRAG(
retriever=retriever_mock,
llm=llm,
)
prompt = rag.chat_summary_prompt(message_history=message_history) # type: ignore
assert (
prompt
== """
Expand All @@ -201,8 +201,12 @@ def test_chat_summary_template() -> None:
)


def test_conversation_template() -> None:
prompt = ConversationTemplate(
def test_conversation_template(retriever_mock: MagicMock, llm: MagicMock) -> None:
rag = GraphRAG(
retriever=retriever_mock,
llm=llm,
)
prompt = rag.conversation_prompt(
summary="llm generated chat summary", current_query="latest question"
)
assert (
Expand Down

0 comments on commit f5a9833

Please sign in to comment.