diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py
index 1582a238..ea5e55d5 100644
--- a/src/neo4j_graphrag/generation/graphrag.py
+++ b/src/neo4j_graphrag/generation/graphrag.py
@@ -154,7 +154,9 @@ def build_query(
                 chat_history=chat_history
             )
             summary = self.llm.invoke(
-                input=summarization_prompt, system_instruction=summarization_prompt.SYSTEM_MESSAGE).content
+                input=summarization_prompt,
+                system_instruction=summarization_prompt.SYSTEM_MESSAGE,
+            ).content
             return ConversationTemplate().format(
                 summary=summary, current_query=query_text
             )
diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py
index f08dbf16..1487b1b9 100644
--- a/src/neo4j_graphrag/llm/anthropic_llm.py
+++ b/src/neo4j_graphrag/llm/anthropic_llm.py
@@ -22,7 +22,6 @@
 from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage
 
 if TYPE_CHECKING:
-    import anthropic
     from anthropic.types.message_param import MessageParam
 
 
diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py
index 40081564..5cc94291 100644
--- a/src/neo4j_graphrag/llm/base.py
+++ b/src/neo4j_graphrag/llm/base.py
@@ -43,7 +43,10 @@ def __init__(
 
     @abstractmethod
     def invoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None
+        self,
+        input: str,
+        chat_history: Optional[list[dict[str, str]]] = None,
+        system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the LLM and retrieves a response.
 
diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py
index 6bbad87b..ceddd63b 100644
--- a/src/neo4j_graphrag/llm/cohere_llm.py
+++ b/src/neo4j_graphrag/llm/cohere_llm.py
@@ -27,7 +27,6 @@
 )
 
 if TYPE_CHECKING:
-    import cohere
     from cohere import ChatMessages
 
 
diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py
index ce020d4f..9ae72c87 100644
--- a/src/neo4j_graphrag/llm/ollama_llm.py
+++ b/src/neo4j_graphrag/llm/ollama_llm.py
@@ -22,7 +22,6 @@
 from .types import LLMResponse, SystemMessage, UserMessage, MessageList
 
 if TYPE_CHECKING:
-    import ollama
     from ollama import Message
 
 
diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py
index 2c83e952..895e5028 100644
--- a/src/neo4j_graphrag/llm/openai_llm.py
+++ b/src/neo4j_graphrag/llm/openai_llm.py
@@ -61,10 +61,17 @@ def __init__(
         super().__init__(model_name, model_params, system_instruction)
 
     def get_messages(
-        self, input: str, chat_history: Optional[list[Any]] = None, system_instruction: Optional[str] = None
+        self,
+        input: str,
+        chat_history: Optional[list[Any]] = None,
+        system_instruction: Optional[str] = None,
     ) -> Iterable[ChatCompletionMessageParam]:
         messages = []
-        system_message = system_instruction if system_instruction is not None else self.system_instruction
+        system_message = (
+            system_instruction
+            if system_instruction is not None
+            else self.system_instruction
+        )
         if system_message:
             messages.append(SystemMessage(content=system_message).model_dump())
         if chat_history:
@@ -77,7 +84,10 @@ def get_messages(
         return messages
 
     def invoke(
-        self, input: str, chat_history: Optional[list[Any]] = None, system_instruction: Optional[str] = None
+        self,
+        input: str,
+        chat_history: Optional[list[Any]] = None,
+        system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the OpenAI chat completion model and returns the response's content.
 
diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py
index 71ce6770..78ecdb56 100644
--- a/src/neo4j_graphrag/llm/vertexai_llm.py
+++ b/src/neo4j_graphrag/llm/vertexai_llm.py
@@ -104,7 +104,10 @@ def get_messages(
         return messages
 
     def invoke(
-        self, input: str, chat_history: Optional[list[Any]] = None, system_instruction: Optional[str] = None
+        self,
+        input: str,
+        chat_history: Optional[list[Any]] = None,
+        system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
 
@@ -116,9 +119,15 @@ def invoke(
         Returns:
             LLMResponse: The response from the LLM.
         """
-        system_message = system_instruction if system_instruction is not None else self.system_instruction
+        system_message = (
+            system_instruction
+            if system_instruction is not None
+            else self.system_instruction
+        )
         self.model = GenerativeModel(
-            model_name=self.model_name, system_instruction=[system_message], **self.model_params
+            model_name=self.model_name,
+            system_instruction=[system_message],
+            **self.model_params,
         )
         try:
             messages = self.get_messages(input, chat_history)
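Illustrative usage, not part of the diff: a minimal sketch of the per-call system_instruction override that these reformatted signatures preserve. It assumes OpenAILLM is exported from neo4j_graphrag.llm as in this package; the model name and prompt strings are placeholder values.

from neo4j_graphrag.llm import OpenAILLM

# Instance-level instruction, used when invoke() receives no override.
llm = OpenAILLM(
    model_name="gpt-4o-mini",
    system_instruction="You are a concise assistant.",
)

# Per-call instruction: in get_messages(), a non-None system_instruction
# argument takes precedence over self.system_instruction.
response = llm.invoke(
    input="Summarise the conversation so far.",
    system_instruction="Answer in exactly two sentences.",
)
print(response.content)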