From abef33c195b7c23ec0a18482baa4190775fe6b10 Mon Sep 17 00:00:00 2001
From: Leila Messallem
Date: Mon, 16 Dec 2024 15:20:25 +0100
Subject: [PATCH] Use BaseMessage class type

* for the type declaration of the `message_history` parameter
---
 src/neo4j_graphrag/generation/graphrag.py |  7 ++++---
 src/neo4j_graphrag/generation/prompts.py  |  6 +++---
 src/neo4j_graphrag/llm/anthropic_llm.py   |  8 ++++----
 src/neo4j_graphrag/llm/base.py            |  4 ++--
 src/neo4j_graphrag/llm/cohere_llm.py      |  7 ++++---
 src/neo4j_graphrag/llm/mistralai_llm.py   |  7 ++++---
 src/neo4j_graphrag/llm/ollama_llm.py      |  8 ++++----
 src/neo4j_graphrag/llm/openai_llm.py      |  8 ++++----
 src/neo4j_graphrag/llm/vertexai_llm.py    | 16 ++++++++--------
 9 files changed, 37 insertions(+), 34 deletions(-)

diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py
index 9aaee06d..6ce903f9 100644
--- a/src/neo4j_graphrag/generation/graphrag.py
+++ b/src/neo4j_graphrag/generation/graphrag.py
@@ -31,6 +31,7 @@
 )
 from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
 from neo4j_graphrag.llm import LLMInterface
+from neo4j_graphrag.llm.types import BaseMessage
 from neo4j_graphrag.retrievers.base import Retriever
 from neo4j_graphrag.types import RetrieverResult
 
@@ -87,7 +88,7 @@ def __init__(
     def search(
         self,
         query_text: str = "",
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[BaseMessage]] = None,
         examples: str = "",
         retriever_config: Optional[dict[str, Any]] = None,
         return_context: bool | None = None,
@@ -147,7 +148,7 @@ def search(
         return RagResultModel(**result)
 
     def build_query(
-        self, query_text: str, message_history: Optional[list[dict[str, str]]] = None
+        self, query_text: str, message_history: Optional[list[BaseMessage]] = None
     ) -> str:
         if message_history:
             summarization_prompt = ChatSummaryTemplate().format(
             )
             summary = self.llm.invoke(
                 input=summarization_prompt,
-                system_instruction=summarization_prompt.SYSTEM_MESSAGE,
+                system_instruction=ChatSummaryTemplate().SYSTEM_MESSAGE,
             ).content
             return ConversationTemplate().format(
                 summary=summary, current_query=query_text
diff --git a/src/neo4j_graphrag/generation/prompts.py b/src/neo4j_graphrag/generation/prompts.py
index 862c5af0..6df4d069 100644
--- a/src/neo4j_graphrag/generation/prompts.py
+++ b/src/neo4j_graphrag/generation/prompts.py
@@ -17,6 +17,7 @@
 import warnings
 from typing import Any, Optional
 
+from neo4j_graphrag.llm.types import BaseMessage
 from neo4j_graphrag.exceptions import (
     PromptMissingInputError,
     PromptMissingPlaceholderError,
@@ -207,10 +208,9 @@ class ChatSummaryTemplate(PromptTemplate):
     EXPECTED_INPUTS = ["message_history"]
     SYSTEM_MESSAGE = "You are a summarization assistant. Summarize the given text in no more than 200 words"
 
-    def format(self, message_history: list[dict[str, str]]) -> str:
+    def format(self, message_history: list[BaseMessage]) -> str:
         message_list = [
-            ": ".join([f"{value}" for _, value in message.items()])
-            for message in message_history
+            f"{message.role}: {message.content}" for message in message_history
         ]
         history = "\n".join(message_list)
         return super().format(message_history=history)
diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py
index a295f888..7089d8b1 100644
--- a/src/neo4j_graphrag/llm/anthropic_llm.py
+++ b/src/neo4j_graphrag/llm/anthropic_llm.py
@@ -19,7 +19,7 @@
 from neo4j_graphrag.exceptions import LLMGenerationError
 
 from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage
+from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage, BaseMessage
 
 if TYPE_CHECKING:
     from anthropic.types.message_param import MessageParam
@@ -71,7 +71,7 @@ def __init__(
         self.async_client = anthropic.AsyncAnthropic(**kwargs)
 
     def get_messages(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> Iterable[MessageParam]:
         messages = []
         if message_history:
@@ -84,7 +84,7 @@ def get_messages(
         return messages
 
     def invoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
 
@@ -108,7 +108,7 @@ def invoke(
             raise LLMGenerationError(e)
 
     async def ainvoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
 
diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py
index 48ffea3d..f4c66b94 100644
--- a/src/neo4j_graphrag/llm/base.py
+++ b/src/neo4j_graphrag/llm/base.py
@@ -17,7 +17,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Optional
 
-from .types import LLMResponse
+from .types import LLMResponse, BaseMessage
 
 
 class LLMInterface(ABC):
@@ -45,7 +45,7 @@ def __init__(
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[BaseMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the LLM and retrieves a response.
 
diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py
index 8d04b8df..43578e73 100644
--- a/src/neo4j_graphrag/llm/cohere_llm.py
+++ b/src/neo4j_graphrag/llm/cohere_llm.py
@@ -24,6 +24,7 @@
     MessageList,
     SystemMessage,
     UserMessage,
+    BaseMessage,
 )
 
 if TYPE_CHECKING:
@@ -74,7 +75,7 @@ def __init__(
         self.async_client = cohere.AsyncClientV2(**kwargs)
 
     def get_messages(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> ChatMessages:
         messages = []
         if self.system_instruction:
@@ -89,7 +90,7 @@ def get_messages(
         return messages
 
     def invoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
 
@@ -113,7 +114,7 @@ def invoke(
         )
 
     async def ainvoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
 
diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py
index a3bab0e9..6d9f073d 100644
--- a/src/neo4j_graphrag/llm/mistralai_llm.py
+++ b/src/neo4j_graphrag/llm/mistralai_llm.py
@@ -25,6 +25,7 @@
     MessageList,
     SystemMessage,
     UserMessage,
+    BaseMessage,
 )
 
 try:
@@ -65,7 +66,7 @@ def __init__(
         self.client = Mistral(api_key=api_key, **kwargs)
 
     def get_messages(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> list[Messages]:
         messages = []
         if self.system_instruction:
@@ -80,7 +81,7 @@ def get_messages(
         return messages
 
     def invoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         """Sends a text input to the Mistral chat completion model and returns the response's content.
 
@@ -112,7 +113,7 @@ def invoke(
             raise LLMGenerationError(e)
 
     async def ainvoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         """Asynchronously sends a text input to the MistralAI chat completion model and returns the response's content.
 
diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py
index 4e7a7b6f..0a75887f 100644
--- a/src/neo4j_graphrag/llm/ollama_llm.py
+++ b/src/neo4j_graphrag/llm/ollama_llm.py
@@ -19,7 +19,7 @@
 from neo4j_graphrag.exceptions import LLMGenerationError
 
 from .base import LLMInterface
-from .types import LLMResponse, SystemMessage, UserMessage, MessageList
+from .types import LLMResponse, SystemMessage, UserMessage, MessageList, BaseMessage
 
 if TYPE_CHECKING:
     from ollama import Message
@@ -50,7 +50,7 @@ def __init__(
         )
 
     def get_messages(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> Sequence[Message]:
         messages = []
         if self.system_instruction:
@@ -65,7 +65,7 @@ def get_messages(
         return messages
 
     def invoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         try:
             response = self.client.chat(
@@ -79,7 +79,7 @@ def invoke(
             raise LLMGenerationError(e)
 
     async def ainvoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         try:
             response = await self.async_client.chat(
diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py
index 3a1abf09..ab074c25 100644
--- a/src/neo4j_graphrag/llm/openai_llm.py
+++ b/src/neo4j_graphrag/llm/openai_llm.py
@@ -21,7 +21,7 @@
 from ..exceptions import LLMGenerationError
 
 from .base import LLMInterface
-from .types import LLMResponse, SystemMessage, UserMessage, MessageList
+from .types import LLMResponse, SystemMessage, UserMessage, MessageList, BaseMessage
 
 if TYPE_CHECKING:
     import openai
@@ -63,7 +63,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[Any]] = None,
+        message_history: Optional[list[BaseMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> Iterable[ChatCompletionMessageParam]:
         messages = []
@@ -86,7 +86,7 @@ def get_messages(
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[Any]] = None,
+        message_history: Optional[list[BaseMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the OpenAI chat completion model
@@ -115,7 +115,7 @@ def invoke(
             raise LLMGenerationError(e)
 
     async def ainvoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> LLMResponse:
         """Asynchronously sends a text input to the OpenAI chat completion model and returns the response's content.
 
diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py
index d8510c41..ede7d73b 100644
--- a/src/neo4j_graphrag/llm/vertexai_llm.py
+++ b/src/neo4j_graphrag/llm/vertexai_llm.py
@@ -19,7 +19,7 @@
 from neo4j_graphrag.exceptions import LLMGenerationError
 
 from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import LLMResponse, MessageList
+from neo4j_graphrag.llm.types import LLMResponse, MessageList, BaseMessage
 
 try:
     from vertexai.generative_models import (
@@ -77,7 +77,7 @@ def __init__(
         self.model_params = kwargs
 
     def get_messages(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
     ) -> list[Content]:
         messages = []
         if message_history:
@@ -87,16 +87,16 @@ def get_messages(
                 raise LLMGenerationError(e.errors()) from e
 
             for message in message_history:
-                if message.get("role") == "user":
+                if message.role == "user":
                     messages.append(
                         Content(
-                            role="user", parts=[Part.from_text(message.get("content"))]
+                            role="user", parts=[Part.from_text(message.content)]
                         )
                     )
-                elif message.get("role") == "assistant":
+                elif message.role == "assistant":
                     messages.append(
                         Content(
-                            role="model", parts=[Part.from_text(message.get("content"))]
+                            role="model", parts=[Part.from_text(message.content)]
                         )
                     )
 
@@ -106,7 +106,7 @@ def get_messages(
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[Any]] = None,
+        message_history: Optional[list[BaseMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
 
@@ -137,7 +137,7 @@ def invoke(
             raise LLMGenerationError(e)
 
     async def ainvoke(
-        self, input: str, message_history: Optional[list[Any]] = None
+        self, input: str, message_history: Optional[list[BaseMessage]] = None
    ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
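
Usage note (illustrative sketch, not part of the patch): after this change, `message_history` is declared as `Optional[list[BaseMessage]]` throughout, and entries are read via attribute access (`message.role`, `message.content`) instead of dict lookups. The sketch below assumes `BaseMessage` is the Pydantic model exported from `neo4j_graphrag.llm.types` with `role` and `content` fields, as the attribute access in this diff implies.

    # Hypothetical calling code; import paths taken from the files touched above.
    from neo4j_graphrag.llm.types import BaseMessage
    from neo4j_graphrag.generation.prompts import ChatSummaryTemplate

    history = [
        BaseMessage(role="user", content="What is a knowledge graph?"),
        BaseMessage(role="assistant", content="A graph of entities connected by typed relationships."),
    ]

    # With this patch, ChatSummaryTemplate.format() renders each entry as "role: content".
    summarization_prompt = ChatSummaryTemplate().format(message_history=history)
    print(summarization_prompt)

    # GraphRAG.search() and LLMInterface.invoke() now declare the same list[BaseMessage]
    # parameter; callers that previously passed [{"role": ..., "content": ...}] dicts
    # may need to construct BaseMessage instances (or rely on the library's validation).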