
Commit

Use BaseMessage class type
* for the type declaration of the `message_history` parameter
leila-messallem committed Dec 16, 2024
1 parent 37225fd commit abef33c
Showing 9 changed files with 37 additions and 34 deletions.
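
All nine files swap the `message_history` annotation from `list[dict[str, str]]` (or `list[Any]`) to `list[BaseMessage]`. The `BaseMessage` model itself is defined in `neo4j_graphrag/llm/types.py` and is not part of this diff; as a reading aid, a minimal sketch of the shape the changed code relies on (attribute access to `role` and `content`, per the prompts.py and vertexai_llm.py hunks) might look like this — the field names come from the diff, everything else is an assumption:

```python
# Hypothetical sketch only; the real model lives in neo4j_graphrag/llm/types.py.
from pydantic import BaseModel

class BaseMessage(BaseModel):
    role: str     # e.g. "user", "assistant", "system"
    content: str  # the message text
```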
7 changes: 4 additions & 3 deletions src/neo4j_graphrag/generation/graphrag.py
@@ -31,6 +31,7 @@
 )
 from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
 from neo4j_graphrag.llm import LLMInterface
+from neo4j_graphrag.llm.types import BaseMessage
 from neo4j_graphrag.retrievers.base import Retriever
 from neo4j_graphrag.types import RetrieverResult

@@ -87,7 +88,7 @@ def __init__(
     def search(
         self,
         query_text: str = "",
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[BaseMessage]] = None,
         examples: str = "",
         retriever_config: Optional[dict[str, Any]] = None,
         return_context: bool | None = None,
@@ -147,15 +148,15 @@ def search(
         return RagResultModel(**result)

     def build_query(
-        self, query_text: str, message_history: Optional[list[dict[str, str]]] = None
+        self, query_text: str, message_history: Optional[list[BaseMessage]] = None
     ) -> str:
         if message_history:
             summarization_prompt = ChatSummaryTemplate().format(
                 message_history=message_history
             )
             summary = self.llm.invoke(
                 input=summarization_prompt,
-                system_instruction=summarization_prompt.SYSTEM_MESSAGE,
+                system_instruction=ChatSummaryTemplate().SYSTEM_MESSAGE,
             ).content
             return ConversationTemplate().format(
                 summary=summary, current_query=query_text
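
In practice this means callers of `GraphRAG.search()` now pass the conversation as message objects rather than plain dicts. A hedged usage sketch (retriever and llm construction elided; `response.answer` assumes the usual `RagResultModel` field):

```python
from neo4j_graphrag.generation import GraphRAG
from neo4j_graphrag.llm.types import BaseMessage

rag = GraphRAG(retriever=retriever, llm=llm)  # assumes retriever/llm already built

history = [
    BaseMessage(role="user", content="Who directed The Matrix?"),
    BaseMessage(role="assistant", content="The Wachowskis directed The Matrix."),
]

# build_query() summarizes the history with ChatSummaryTemplate before retrieval.
response = rag.search(query_text="What else did they direct?", message_history=history)
print(response.answer)
```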
6 changes: 3 additions & 3 deletions src/neo4j_graphrag/generation/prompts.py
@@ -17,6 +17,7 @@
 import warnings
 from typing import Any, Optional

+from neo4j_graphrag.llm.types import BaseMessage
 from neo4j_graphrag.exceptions import (
     PromptMissingInputError,
     PromptMissingPlaceholderError,
@@ -207,10 +208,9 @@ class ChatSummaryTemplate(PromptTemplate):
     EXPECTED_INPUTS = ["message_history"]
     SYSTEM_MESSAGE = "You are a summarization assistant. Summarize the given text in no more than 200 words"

-    def format(self, message_history: list[dict[str, str]]) -> str:
+    def format(self, message_history: list[BaseMessage]) -> str:
         message_list = [
-            ": ".join([f"{value}" for _, value in message.items()])
-            for message in message_history
+            f"{message.role}: {message.content}" for message in message_history
         ]
         history = "\n".join(message_list)
         return super().format(message_history=history)
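
The rewritten `format` uses attribute access, so each turn renders as `role: content` before being summarized. A small illustration (assuming the `BaseMessage` fields sketched above):

```python
from neo4j_graphrag.generation.prompts import ChatSummaryTemplate
from neo4j_graphrag.llm.types import BaseMessage

history = [
    BaseMessage(role="user", content="Hello"),
    BaseMessage(role="assistant", content="Hi! How can I help?"),
]

prompt = ChatSummaryTemplate().format(message_history=history)
# The history injected into the template now reads:
# user: Hello
# assistant: Hi! How can I help?
```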
8 changes: 4 additions & 4 deletions src/neo4j_graphrag/llm/anthropic_llm.py
@@ -19,7 +19,7 @@

 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage
+from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage, BaseMessage

 if TYPE_CHECKING:
     from anthropic.types.message_param import MessageParam
@@ -71,7 +71,7 @@ def __init__(
         self.async_client = anthropic.AsyncAnthropic(**kwargs)

     def get_messages(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> Iterable[MessageParam]:
         messages = []
         if message_history:
@@ -84,7 +84,7 @@ def get_messages(
         return messages

     def invoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -108,7 +108,7 @@ def invoke(
             raise LLMGenerationError(e)

     async def ainvoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
4 changes: 2 additions & 2 deletions src/neo4j_graphrag/llm/base.py
@@ -17,7 +17,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Optional

-from .types import LLMResponse
+from .types import LLMResponse, BaseMessage


 class LLMInterface(ABC):
@@ -45,7 +45,7 @@ def __init__(
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[BaseMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the LLM and retrieves a response.
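
Since the abstract `LLMInterface.invoke` now declares `list[BaseMessage]`, third-party implementations should update their signatures to match. A toy sketch of a conforming override (the echo behaviour is illustrative only; the base class has other abstract members, such as `ainvoke`, that a real subclass must also implement):

```python
from typing import Optional

from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import BaseMessage, LLMResponse

class EchoLLM(LLMInterface):
    """Toy example showing the new invoke signature (not a real provider)."""

    def invoke(
        self,
        input: str,
        message_history: Optional[list[BaseMessage]] = None,
        system_instruction: Optional[str] = None,
    ) -> LLMResponse:
        # Render the history exactly the way ChatSummaryTemplate does: "role: content".
        transcript = "\n".join(
            f"{m.role}: {m.content}" for m in (message_history or [])
        )
        return LLMResponse(content=f"{transcript}\nuser: {input}")
```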
7 changes: 4 additions & 3 deletions src/neo4j_graphrag/llm/cohere_llm.py
@@ -24,6 +24,7 @@
     MessageList,
     SystemMessage,
     UserMessage,
+    BaseMessage,
 )

 if TYPE_CHECKING:
@@ -74,7 +75,7 @@ def __init__(
         self.async_client = cohere.AsyncClientV2(**kwargs)

     def get_messages(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> ChatMessages:
         messages = []
         if self.system_instruction:
@@ -89,7 +90,7 @@ def get_messages(
         return messages

     def invoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -113,7 +114,7 @@ def invoke(
             )

     async def ainvoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
7 changes: 4 additions & 3 deletions src/neo4j_graphrag/llm/mistralai_llm.py
@@ -25,6 +25,7 @@
     MessageList,
     SystemMessage,
     UserMessage,
+    BaseMessage,
 )

 try:
@@ -65,7 +66,7 @@ def __init__(
         self.client = Mistral(api_key=api_key, **kwargs)

     def get_messages(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> list[Messages]:
         messages = []
         if self.system_instruction:
@@ -80,7 +81,7 @@ def get_messages(
         return messages

     def invoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         """Sends a text input to the Mistral chat completion model
         and returns the response's content.
@@ -112,7 +113,7 @@ def invoke(
             raise LLMGenerationError(e)

     async def ainvoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         """Asynchronously sends a text input to the MistralAI chat
         completion model and returns the response's content.
8 changes: 4 additions & 4 deletions src/neo4j_graphrag/llm/ollama_llm.py
@@ -19,7 +19,7 @@
 from neo4j_graphrag.exceptions import LLMGenerationError

 from .base import LLMInterface
-from .types import LLMResponse, SystemMessage, UserMessage, MessageList
+from .types import LLMResponse, SystemMessage, UserMessage, MessageList, BaseMessage

 if TYPE_CHECKING:
     from ollama import Message
@@ -50,7 +50,7 @@ def __init__(
         )

     def get_messages(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> Sequence[Message]:
         messages = []
         if self.system_instruction:
@@ -65,7 +65,7 @@ def get_messages(
         return messages

     def invoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         try:
             response = self.client.chat(
@@ -79,7 +79,7 @@ def invoke(
             raise LLMGenerationError(e)

     async def ainvoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         try:
             response = await self.async_client.chat(
8 changes: 4 additions & 4 deletions src/neo4j_graphrag/llm/openai_llm.py
@@ -21,7 +21,7 @@

 from ..exceptions import LLMGenerationError
 from .base import LLMInterface
-from .types import LLMResponse, SystemMessage, UserMessage, MessageList
+from .types import LLMResponse, SystemMessage, UserMessage, MessageList, BaseMessage

 if TYPE_CHECKING:
     import openai
@@ -63,7 +63,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[Any]] = None,
+        message_history: Optional[list[BaseMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> Iterable[ChatCompletionMessageParam]:
         messages = []
@@ -86,7 +86,7 @@ def get_messages(
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[Any]] = None,
+        message_history: Optional[list[BaseMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the OpenAI chat completion model
@@ -115,7 +115,7 @@ def invoke(
             raise LLMGenerationError(e)

     async def ainvoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         """Asynchronously sends a text input to the OpenAI chat
         completion model and returns the response's content.
16 changes: 8 additions & 8 deletions src/neo4j_graphrag/llm/vertexai_llm.py
@@ -19,7 +19,7 @@

 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import LLMResponse, MessageList
+from neo4j_graphrag.llm.types import LLMResponse, MessageList, BaseMessage

 try:
     from vertexai.generative_models import (
@@ -77,7 +77,7 @@ def __init__(
         self.model_params = kwargs

     def get_messages(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> list[Content]:
         messages = []
         if message_history:
@@ -87,16 +87,16 @@ def get_messages(
                 raise LLMGenerationError(e.errors()) from e

             for message in message_history:
-                if message.get("role") == "user":
+                if message.role == "user":
                     messages.append(
                         Content(
-                            role="user", parts=[Part.from_text(message.get("content"))]
+                            role="user", parts=[Part.from_text(message.content)]
                         )
                     )
-                elif message.get("role") == "assistant":
+                elif message.role == "assistant":
                     messages.append(
                         Content(
-                            role="model", parts=[Part.from_text(message.get("content"))]
+                            role="model", parts=[Part.from_text(message.content)]
                         )
                     )

@@ -106,7 +106,7 @@ def get_messages(
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[Any]] = None,
+        message_history: Optional[list[BaseMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -137,7 +137,7 @@ def invoke(
             raise LLMGenerationError(e)

     async def ainvoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
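
VertexAI is the only backend where behaviour changes beyond the annotation: dict lookups (`message.get("role")`, `message.get("content")`) become attribute access (`message.role`, `message.content`). The conversion can be pictured as a standalone helper — a sketch using the `Content`/`Part` types named in the diff, not the library's actual code:

```python
from neo4j_graphrag.llm.types import BaseMessage
from vertexai.generative_models import Content, Part

def to_vertex_contents(history: list[BaseMessage]) -> list[Content]:
    """Illustrative equivalent of the loop inside get_messages in vertexai_llm.py."""
    contents: list[Content] = []
    for message in history:
        if message.role == "user":
            contents.append(
                Content(role="user", parts=[Part.from_text(message.content)])
            )
        elif message.role == "assistant":
            # Vertex AI labels assistant turns with the "model" role.
            contents.append(
                Content(role="model", parts=[Part.from_text(message.content)])
            )
    return contents
```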
