
Commit

Rename chat_history to message_history
leila-messallem committed Dec 16, 2024
1 parent 07038dd commit 37225fd
Showing 16 changed files with 129 additions and 127 deletions.
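The change is a pure API rename: every public entry point that previously accepted a `chat_history` keyword now takes `message_history`, with the same type and semantics. A minimal before/after sketch of a call site; the `llm` object and the role/content keys of each message dict are illustrative assumptions, not something this commit defines:

```python
# Hypothetical call site; `llm` is any LLM instance from this library.
history = [
    {"role": "user", "content": "What is a knowledge graph?"},
    {"role": "assistant", "content": "A graph of entities and their relationships."},
]

# Before this commit:
# response = llm.invoke("How do I build one in Neo4j?", chat_history=history)

# After this commit:
response = llm.invoke("How do I build one in Neo4j?", message_history=history)
print(response.content)
```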
14 changes: 7 additions & 7 deletions src/neo4j_graphrag/generation/graphrag.py
@@ -87,7 +87,7 @@ def __init__(
     def search(
         self,
         query_text: str = "",
-        chat_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         examples: str = "",
         retriever_config: Optional[dict[str, Any]] = None,
         return_context: bool | None = None,
@@ -105,7 +105,7 @@
         Args:
             query_text (str): The user question.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
             examples (str): Examples added to the LLM prompt.
             retriever_config (Optional[dict]): Parameters passed to the retriever.
                 search method; e.g.: top_k
@@ -130,7 +130,7 @@
             )
         except ValidationError as e:
             raise SearchValidationError(e.errors())
-        query = self.build_query(validated_data.query_text, chat_history)
+        query = self.build_query(validated_data.query_text, message_history)
         retriever_result: RetrieverResult = self.retriever.search(
             query_text=query, **validated_data.retriever_config
         )
@@ -140,18 +140,18 @@
         )
         logger.debug(f"RAG: retriever_result={retriever_result}")
         logger.debug(f"RAG: prompt={prompt}")
-        answer = self.llm.invoke(prompt, chat_history)
+        answer = self.llm.invoke(prompt, message_history)
         result: dict[str, Any] = {"answer": answer.content}
         if return_context:
             result["retriever_result"] = retriever_result
         return RagResultModel(**result)

     def build_query(
-        self, query_text: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, query_text: str, message_history: Optional[list[dict[str, str]]] = None
     ) -> str:
-        if chat_history:
+        if message_history:
             summarization_prompt = ChatSummaryTemplate().format(
-                chat_history=chat_history
+                message_history=message_history
             )
             summary = self.llm.invoke(
                 input=summarization_prompt,
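Based on the updated `search` signature above, a usage sketch with the renamed keyword. Only the `search` call reflects this diff; the retriever and LLM objects (`my_retriever`, `my_llm`) and the role/content message keys are assumptions for illustration:

```python
from neo4j_graphrag.generation.graphrag import GraphRAG

# `my_retriever` and `my_llm` are assumed to be already-configured instances;
# their construction is outside the scope of this commit.
rag = GraphRAG(retriever=my_retriever, llm=my_llm)

history = [
    {"role": "user", "content": "Which movies did Tom Hanks act in?"},
    {"role": "assistant", "content": "Forrest Gump, Cast Away, and others."},
]

result = rag.search(
    query_text="Which of those were released in the 90s?",
    message_history=history,        # renamed from chat_history in this commit
    retriever_config={"top_k": 5},  # example retriever parameter
    return_context=True,
)
print(result.answer)
# With return_context=True, result.retriever_result is also populated.
```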
14 changes: 7 additions & 7 deletions src/neo4j_graphrag/generation/prompts.py
@@ -200,25 +200,25 @@ def format(

 class ChatSummaryTemplate(PromptTemplate):
     DEFAULT_TEMPLATE = """
-Summarize the chat history:
+Summarize the message history:
-{chat_history}
+{message_history}
 """
-    EXPECTED_INPUTS = ["chat_history"]
+    EXPECTED_INPUTS = ["message_history"]
     SYSTEM_MESSAGE = "You are a summarization assistant. Summarize the given text in no more than 200 words"

-    def format(self, chat_history: list[dict[str, str]]) -> str:
+    def format(self, message_history: list[dict[str, str]]) -> str:
         message_list = [
             ": ".join([f"{value}" for _, value in message.items()])
-            for message in chat_history
+            for message in message_history
         ]
         history = "\n".join(message_list)
-        return super().format(chat_history=history)
+        return super().format(message_history=history)


 class ConversationTemplate(PromptTemplate):
     DEFAULT_TEMPLATE = """
-Chat Summary:
+Message Summary:
 {summary}

 Current Query:
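A small worked example of the renamed `format` method, based directly on the code above: each message dict's values are joined with ": " and substituted into the template. The role/content keys are an assumption; `format` itself only looks at the dict values:

```python
from neo4j_graphrag.generation.prompts import ChatSummaryTemplate

history = [
    {"role": "user", "content": "Who directed Alien?"},
    {"role": "assistant", "content": "Ridley Scott directed Alien."},
]

prompt = ChatSummaryTemplate().format(message_history=history)
print(prompt)
# Prints roughly:
#
# Summarize the message history:
# user: Who directed Alien?
# assistant: Ridley Scott directed Alien.
```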
20 changes: 10 additions & 10 deletions src/neo4j_graphrag/llm/anthropic_llm.py
@@ -71,32 +71,32 @@ def __init__(
         self.async_client = anthropic.AsyncAnthropic(**kwargs)

     def get_messages(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> Iterable[MessageParam]:
         messages = []
-        if chat_history:
+        if message_history:
             try:
-                MessageList(messages=chat_history)
+                MessageList(messages=message_history)
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(chat_history)
+            messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
         return messages

     def invoke(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
         Args:
             input (str): The text to send to the LLM.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
         Returns:
             LLMResponse: The response from the LLM.
         """
         try:
-            messages = self.get_messages(input, chat_history)
+            messages = self.get_messages(input, message_history)
             response = self.client.messages.create(
                 model=self.model_name,
                 system=self.system_instruction,
@@ -108,19 +108,19 @@ def invoke(
             raise LLMGenerationError(e)

     async def ainvoke(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
         Args:
             input (str): The text to send to the LLM.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
         Returns:
             LLMResponse: The response from the LLM.
         """
         try:
-            messages = self.get_messages(input, chat_history)
+            messages = self.get_messages(input, message_history)
             response = await self.async_client.messages.create(
                 model=self.model_name,
                 system=self.system_instruction,
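The `get_messages` pattern above is repeated in every adapter touched by this commit: validate the incoming history, copy it, then append the current input as the final user turn. A standalone sketch of that assembly logic in plain Python (not the library's classes) to make the ordering explicit:

```python
from typing import Any, Optional


def assemble_messages(
    input: str, message_history: Optional[list[dict[str, Any]]] = None
) -> list[dict[str, Any]]:
    """Mimics get_messages: prior turns first, the new user input last."""
    messages: list[dict[str, Any]] = []
    if message_history:
        messages.extend(message_history)  # validated history, oldest first
    messages.append({"role": "user", "content": input})  # the new user turn
    return messages


history = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
print(assemble_messages("What can you do?", history))
# [{'role': 'user', 'content': 'Hi'},
#  {'role': 'assistant', 'content': 'Hello!'},
#  {'role': 'user', 'content': 'What can you do?'}]
```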
8 changes: 4 additions & 4 deletions src/neo4j_graphrag/llm/base.py
@@ -45,14 +45,14 @@ def __init__(
     def invoke(
         self,
         input: str,
-        chat_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the LLM and retrieves a response.
         Args:
             input (str): Text sent to the LLM.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
             system_instruction (Optional[str]): An option to override the llm system message for this invokation.
         Returns:
@@ -64,13 +64,13 @@ def invoke(

     @abstractmethod
     async def ainvoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, message_history: Optional[list[dict[str, str]]] = None
     ) -> LLMResponse:
         """Asynchronously sends a text input to the LLM and retrieves a response.
         Args:
             input (str): Text sent to the LLM.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
         Returns:
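Because the rename lands in the abstract base methods, custom LLM implementations should adopt the new parameter name as well. A minimal sketch of a subclass; the class name `LLMInterface`, the `LLMResponse` import path, and the `model_name` constructor argument are assumptions, since this diff only shows the method signatures:

```python
from typing import Optional

from neo4j_graphrag.llm.base import LLMInterface  # base class name assumed; not shown in this diff
from neo4j_graphrag.llm.types import LLMResponse  # import path assumed


class EchoLLM(LLMInterface):
    """Toy implementation that only echoes, honoring the renamed parameter."""

    def invoke(
        self,
        input: str,
        message_history: Optional[list[dict[str, str]]] = None,
        system_instruction: Optional[str] = None,
    ) -> LLMResponse:
        turns = len(message_history or [])
        return LLMResponse(content=f"echo ({turns} prior messages): {input}")

    async def ainvoke(
        self, input: str, message_history: Optional[list[dict[str, str]]] = None
    ) -> LLMResponse:
        return self.invoke(input, message_history)


# Constructor arguments assumed from the base class; "echo" is a placeholder model name.
llm = EchoLLM(model_name="echo")
print(llm.invoke("ping", message_history=[{"role": "user", "content": "hi"}]).content)
```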
20 changes: 10 additions & 10 deletions src/neo4j_graphrag/llm/cohere_llm.py
@@ -74,34 +74,34 @@ def __init__(
         self.async_client = cohere.AsyncClientV2(**kwargs)

     def get_messages(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> ChatMessages:
         messages = []
         if self.system_instruction:
             messages.append(SystemMessage(content=self.system_instruction).model_dump())
-        if chat_history:
+        if message_history:
             try:
-                MessageList(messages=chat_history)
+                MessageList(messages=message_history)
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(chat_history)
+            messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
         return messages

     def invoke(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
         Args:
             input (str): The text to send to the LLM.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
         Returns:
             LLMResponse: The response from the LLM.
         """
         try:
-            messages = self.get_messages(input, chat_history)
+            messages = self.get_messages(input, message_history)
             res = self.client.chat(
                 messages=messages,
                 model=self.model_name,
@@ -113,19 +113,19 @@ def invoke(
             )

     async def ainvoke(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
         Args:
             input (str): The text to send to the LLM.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
         Returns:
             LLMResponse: The response from the LLM.
         """
         try:
-            messages = self.get_messages(input, chat_history)
+            messages = self.get_messages(input, message_history)
             res = self.async_client.chat(
                 messages=messages,
                 model=self.model_name,
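One detail visible in this adapter (and in the MistralAI and Ollama adapters below): `get_messages` prepends `self.system_instruction` itself, so callers pass only the conversation turns in `message_history`. A usage sketch under stated assumptions; the class name `CohereLLM`, the constructor arguments, and the model name are not confirmed by this diff:

```python
# Assumes the class in this file is CohereLLM and that it accepts a
# system_instruction argument, as suggested by self.system_instruction above.
from neo4j_graphrag.llm.cohere_llm import CohereLLM

llm = CohereLLM(
    model_name="command-r",              # placeholder model name
    system_instruction="Answer concisely.",
)

history = [
    {"role": "user", "content": "What is Cypher?"},
    {"role": "assistant", "content": "Neo4j's graph query language."},
]

# get_messages() assembles: [system instruction, *history, new user message]
response = llm.invoke("Show me a MATCH example.", message_history=history)
print(response.content)
```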
20 changes: 10 additions & 10 deletions src/neo4j_graphrag/llm/mistralai_llm.py
@@ -65,29 +65,29 @@ def __init__(
         self.client = Mistral(api_key=api_key, **kwargs)

     def get_messages(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> list[Messages]:
         messages = []
         if self.system_instruction:
             messages.append(SystemMessage(content=self.system_instruction).model_dump())
-        if chat_history:
+        if message_history:
             try:
-                MessageList(messages=chat_history)
+                MessageList(messages=message_history)
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(chat_history)
+            messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
         return messages

     def invoke(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends a text input to the Mistral chat completion model
         and returns the response's content.
         Args:
             input (str): Text sent to the LLM.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
         Returns:
             LLMResponse: The response from MistralAI.
@@ -96,7 +96,7 @@ def invoke(
             LLMGenerationError: If anything goes wrong.
         """
         try:
-            messages = self.get_messages(input, chat_history)
+            messages = self.get_messages(input, message_history)
             response = self.client.chat.complete(
                 model=self.model_name,
                 messages=messages,
@@ -112,14 +112,14 @@ def invoke(
             raise LLMGenerationError(e)

     async def ainvoke(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
    ) -> LLMResponse:
         """Asynchronously sends a text input to the MistralAI chat
         completion model and returns the response's content.
         Args:
             input (str): Text sent to the LLM.
-            chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
+            message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
         Returns:
             LLMResponse: The response from MistralAI.
@@ -128,7 +128,7 @@ async def ainvoke(
             LLMGenerationError: If anything goes wrong.
         """
         try:
-            messages = self.get_messages(input, chat_history)
+            messages = self.get_messages(input, message_history)
             response = await self.client.chat.complete_async(
                 model=self.model_name,
                 messages=messages,
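Both the synchronous and asynchronous paths accept the renamed keyword. A short sketch of the async path; the class name `MistralAILLM`, the model name, and the API key coming from the environment are assumptions, since the diff only shows the module `src/neo4j_graphrag/llm/mistralai_llm.py`:

```python
import asyncio

from neo4j_graphrag.llm.mistralai_llm import MistralAILLM  # class name assumed


async def main() -> None:
    # Assumes the Mistral API key is picked up from the environment.
    llm = MistralAILLM(model_name="mistral-small-latest")  # placeholder model name
    history = [
        {"role": "user", "content": "Name a graph database."},
        {"role": "assistant", "content": "Neo4j."},
    ]
    response = await llm.ainvoke(
        "What language does it use for queries?", message_history=history
    )
    print(response.content)


asyncio.run(main())
```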
16 changes: 8 additions & 8 deletions src/neo4j_graphrag/llm/ollama_llm.py
@@ -50,27 +50,27 @@ def __init__(
         )

     def get_messages(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> Sequence[Message]:
         messages = []
         if self.system_instruction:
             messages.append(SystemMessage(content=self.system_instruction).model_dump())
-        if chat_history:
+        if message_history:
             try:
-                MessageList(messages=chat_history)
+                MessageList(messages=message_history)
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(chat_history)
+            messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
         return messages

     def invoke(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         try:
             response = self.client.chat(
                 model=self.model_name,
-                messages=self.get_messages(input, chat_history),
+                messages=self.get_messages(input, message_history),
                 options=self.model_params,
             )
             content = response.message.content or ""
@@ -79,12 +79,12 @@ def invoke(
             raise LLMGenerationError(e)

     async def ainvoke(
-        self, input: str, chat_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         try:
             response = await self.async_client.chat(
                 model=self.model_name,
-                messages=self.get_messages(input, chat_history),
+                messages=self.get_messages(input, message_history),
                 options=self.model_params,
             )
             content = response.message.content or ""
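As in the other adapters, the history is validated through `MessageList` before use, so a malformed entry surfaces as `LLMGenerationError` rather than a provider error. A sketch of that behavior; the class name `OllamaLLM`, the exception import path, and the model name are assumptions, and the ollama Python package (and normally a running Ollama server) is required:

```python
from neo4j_graphrag.exceptions import LLMGenerationError  # import path assumed
from neo4j_graphrag.llm.ollama_llm import OllamaLLM       # class name assumed from the module name

llm = OllamaLLM(model_name="llama3.1")  # placeholder model name

bad_history = [{"text": "missing role/content keys"}]  # deliberately malformed

try:
    llm.invoke("Hello", message_history=bad_history)
except LLMGenerationError as e:
    # Pydantic validation of MessageList fails and is re-raised as
    # LLMGenerationError, as shown in get_messages above.
    print("history rejected:", e)
```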
(Diffs for the remaining 9 of the 16 changed files are not shown here.)
