diff --git a/examples/customize/llms/custom_llm.py b/examples/customize/llms/custom_llm.py
index e035048c..0722124b 100644
--- a/examples/customize/llms/custom_llm.py
+++ b/examples/customize/llms/custom_llm.py
@@ -1,6 +1,6 @@
 import random
 import string
-from typing import Any
+from typing import Any, Optional
 
 from neo4j_graphrag.llm import LLMInterface, LLMResponse
 
@@ -9,13 +9,23 @@ class CustomLLM(LLMInterface):
     def __init__(self, model_name: str, **kwargs: Any):
         super().__init__(model_name, **kwargs)
 
-    def invoke(self, input: str) -> LLMResponse:
+    def invoke(
+        self,
+        input: str,
+        message_history: Optional[list[dict[str, str]]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse:
         content: str = (
             self.model_name
             + ": "
             + "".join(random.choices(string.ascii_letters, k=30))
         )
         return LLMResponse(content=content)
 
-    async def ainvoke(self, input: str) -> LLMResponse:
+    async def ainvoke(
+        self,
+        input: str,
+        message_history: Optional[list[dict[str, str]]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse:
         raise NotImplementedError()
diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py
index 6b42cac5..3c2d7292 100644
--- a/src/neo4j_graphrag/llm/mistralai_llm.py
+++ b/src/neo4j_graphrag/llm/mistralai_llm.py
@@ -15,7 +15,7 @@
 from __future__ import annotations
 
 import os
-from typing import Any, Optional
+from typing import Any, Optional, cast
 
 from pydantic import ValidationError
 from neo4j_graphrag.exceptions import LLMGenerationError
@@ -31,8 +31,8 @@
     from mistralai import Mistral, Messages
     from mistralai.models.sdkerror import SDKError
 except ImportError:
-    Mistral = None
-    SDKError = None
+    Mistral = None  # type: ignore
+    SDKError = None  # type: ignore
 
 
 class MistralAILLM(LLMInterface):
@@ -85,7 +85,7 @@ def get_messages(
                 raise LLMGenerationError(e.errors()) from e
             messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
-        return messages
+        return cast(list[Messages], messages)
 
     def invoke(
         self,
diff --git a/tests/unit/llm/test_anthropic_llm.py b/tests/unit/llm/test_anthropic_llm.py
index e17b1ec2..13e57498 100644
--- a/tests/unit/llm/test_anthropic_llm.py
+++ b/tests/unit/llm/test_anthropic_llm.py
@@ -76,7 +76,7 @@ def test_anthropic_invoke_with_message_history_happy_path(mock_anthropic: Mock)
     response = llm.invoke(question, message_history)
     assert response.content == "generated text"
     message_history.append({"role": "user", "content": question})
-    llm.client.messages.create.assert_called_once_with(
+    llm.client.messages.create.assert_called_once_with(  # type: ignore[attr-defined]
         messages=message_history,
         model="claude-3-opus-20240229",
         system=system_instruction,
diff --git a/tests/unit/llm/test_mistralai_llm.py b/tests/unit/llm/test_mistralai_llm.py
index aab8eeb1..4e7ee6fd 100644
--- a/tests/unit/llm/test_mistralai_llm.py
+++ b/tests/unit/llm/test_mistralai_llm.py
@@ -71,7 +71,7 @@ def test_mistralai_llm_invoke_with_message_history(mock_mistral: Mock) -> None:
     messages = [{"role": "system", "content": system_instruction}]
     messages.extend(message_history)
     messages.append({"role": "user", "content": question})
-    llm.client.chat.complete.assert_called_once_with(
+    llm.client.chat.complete.assert_called_once_with(  # type: ignore[attr-defined]
         messages=messages,
         model=model,
     )
diff --git a/tests/unit/llm/test_ollama_llm.py b/tests/unit/llm/test_ollama_llm.py
index 779bf902..76272520 100644
--- a/tests/unit/llm/test_ollama_llm.py
+++ b/tests/unit/llm/test_ollama_llm.py
@@ -58,7 +58,7 @@ def test_ollama_llm_happy_path(mock_import: Mock) -> None:
         {"role": "system", "content": system_instruction},
         {"role": "user", "content": question},
     ]
-    llm.client.chat.assert_called_once_with(
+    llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
         model=model, messages=messages, options=model_params
     )
@@ -89,7 +89,7 @@ def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> Non
     messages = [{"role": "system", "content": system_instruction}]
     messages.extend(message_history)
     messages.append({"role": "user", "content": question})
-    llm.client.chat.assert_called_once_with(
+    llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
         model=model, messages=messages, options=model_params
     )
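Usage note (not part of the diff): a minimal sketch of how the extended invoke signature in examples/customize/llms/custom_llm.py can be called. The import path, model name, and conversation content are made up for illustration; both new parameters default to None, so existing single-prompt call sites keep working.

    from custom_llm import CustomLLM  # hypothetical import path for the example class

    llm = CustomLLM(model_name="my-model")

    # Existing call sites are unaffected by the new optional parameters:
    response = llm.invoke("What is a graph?")

    # Callers that track conversation state can now pass it through explicitly:
    response = llm.invoke(
        "And what is a node?",
        message_history=[
            {"role": "user", "content": "What is a graph?"},
            {"role": "assistant", "content": "Nodes connected by relationships."},
        ],
        system_instruction="Answer in one sentence.",
    )
    print(response.content)

The cast(list[Messages], messages) change in mistralai_llm.py follows the same spirit as the # type: ignore comments: it only informs the static type checker that the accumulated dicts satisfy the SDK's Messages type and has no runtime effect.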