Fix mypy errors
leila-messallem committed Dec 17, 2024
1 parent 819179e commit 2143973
Showing 5 changed files with 21 additions and 11 deletions.
16 changes: 13 additions & 3 deletions examples/customize/llms/custom_llm.py
@@ -1,6 +1,6 @@
 import random
 import string
-from typing import Any
+from typing import Any, Optional
 
 from neo4j_graphrag.llm import LLMInterface, LLMResponse
 
@@ -9,13 +9,23 @@ class CustomLLM(LLMInterface):
     def __init__(self, model_name: str, **kwargs: Any):
         super().__init__(model_name, **kwargs)
 
-    def invoke(self, input: str) -> LLMResponse:
+    def invoke(
+        self,
+        input: str,
+        message_history: Optional[list[dict[str, str]]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse:
         content: str = (
             self.model_name + ": " + "".join(random.choices(string.ascii_letters, k=30))
         )
         return LLMResponse(content=content)
 
-    async def ainvoke(self, input: str) -> LLMResponse:
+    async def ainvoke(
+        self,
+        input: str,
+        message_history: Optional[list[dict[str, str]]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse:
         raise NotImplementedError()
 
 
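The `invoke`/`ainvoke` signatures are widened here, presumably to match the base `LLMInterface` that mypy checks subclasses against. A minimal sketch (not part of the commit) of how the updated example class might be called, with the `message_history` shape implied by the new type hints:

```python
# Assumes the CustomLLM class from the diff above.
llm = CustomLLM(model_name="my-model")

history = [
    {"role": "user", "content": "First question"},
    {"role": "assistant", "content": "First answer"},
]

response = llm.invoke(
    "Follow-up question",
    message_history=history,
    system_instruction="Answer briefly.",
)
print(response.content)  # "my-model: " plus 30 random letters
```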
8 changes: 4 additions & 4 deletions src/neo4j_graphrag/llm/mistralai_llm.py
@@ -15,7 +15,7 @@
 from __future__ import annotations
 
 import os
-from typing import Any, Optional
+from typing import Any, Optional, cast
 from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
@@ -31,8 +31,8 @@
     from mistralai import Mistral, Messages
     from mistralai.models.sdkerror import SDKError
 except ImportError:
-    Mistral = None
-    SDKError = None
+    Mistral = None  # type: ignore
+    SDKError = None  # type: ignore
 
 
 class MistralAILLM(LLMInterface):
@@ -85,7 +85,7 @@ def get_messages(
                 raise LLMGenerationError(e.errors()) from e
             messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
-        return messages
+        return cast(list[Messages], messages)
 
     def invoke(
         self,
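Two kinds of fixes here: the `# type: ignore` comments let the `None` fallbacks stand in for the optional `mistralai` imports, and the `cast` satisfies mypy's return-type check in `get_messages`, which builds plain dicts via `model_dump()` while promising `list[Messages]`. A toy reproduction (not the library code) of why the cast is needed:

```python
from typing import TypedDict, cast


class ChatMessage(TypedDict):  # stand-in for mistralai's Messages type
    role: str
    content: str


def build_messages(question: str) -> list[ChatMessage]:
    # mypy sees list[dict[str, str]] here and, since list element types
    # are checked strictly, rejects returning it as list[ChatMessage] ...
    messages: list[dict[str, str]] = []
    messages.append({"role": "user", "content": question})
    # ... so cast() asserts the declared type, with no runtime effect.
    return cast(list[ChatMessage], messages)
```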
2 changes: 1 addition & 1 deletion tests/unit/llm/test_anthropic_llm.py
@@ -76,7 +76,7 @@ def test_anthropic_invoke_with_message_history_happy_path(mock_anthropic: Mock)
     response = llm.invoke(question, message_history)
     assert response.content == "generated text"
     message_history.append({"role": "user", "content": question})
-    llm.client.messages.create.assert_called_once_with(
+    llm.client.messages.create.assert_called_once_with(  # type: ignore[attr-defined]
         messages=message_history,
         model="claude-3-opus-20240229",
         system=system_instruction,
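The same pattern recurs in the mistralai and ollama tests below: at runtime `llm.client` has been replaced by a `Mock`, but mypy still types it as the real SDK client, whose methods have no `assert_called_once_with` attribute. The narrow `attr-defined` ignore suppresses exactly that error. A toy reproduction (not from the repo):

```python
from unittest.mock import MagicMock


class Client:
    def chat(self, prompt: str) -> str:
        return "real response"


client = Client()
client.chat = MagicMock(return_value="mocked")  # type: ignore[method-assign]
client.chat("hi")

# mypy still types client.chat as the original method, which lacks the
# Mock assertion helpers, hence the targeted ignore:
client.chat.assert_called_once_with("hi")  # type: ignore[attr-defined]
```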
2 changes: 1 addition & 1 deletion tests/unit/llm/test_mistralai_llm.py
@@ -71,7 +71,7 @@ def test_mistralai_llm_invoke_with_message_history(mock_mistral: Mock) -> None:
     messages = [{"role": "system", "content": system_instruction}]
     messages.extend(message_history)
     messages.append({"role": "user", "content": question})
-    llm.client.chat.complete.assert_called_once_with(
+    llm.client.chat.complete.assert_called_once_with(  # type: ignore[attr-defined]
         messages=messages,
         model=model,
     )
4 changes: 2 additions & 2 deletions tests/unit/llm/test_ollama_llm.py
@@ -58,7 +58,7 @@ def test_ollama_llm_happy_path(mock_import: Mock) -> None:
         {"role": "system", "content": system_instruction},
         {"role": "user", "content": question},
     ]
-    llm.client.chat.assert_called_once_with(
+    llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
         model=model, messages=messages, options=model_params
     )
 
@@ -89,7 +89,7 @@ def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> None:
     messages = [{"role": "system", "content": system_instruction}]
     messages.extend(message_history)
     messages.append({"role": "user", "content": question})
-    llm.client.chat.assert_called_once_with(
+    llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
         model=model, messages=messages, options=model_params
     )

