diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py index 9699e562..f94e761f 100644 --- a/src/neo4j_graphrag/llm/anthropic_llm.py +++ b/src/neo4j_graphrag/llm/anthropic_llm.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Iterable, Optional, TYPE_CHECKING, cast +from typing import Any, Iterable, Optional, TYPE_CHECKING from pydantic import ValidationError @@ -81,7 +81,7 @@ def get_messages( raise LLMGenerationError(e.errors()) from e messages.extend(message_history) messages.append(UserMessage(content=input).model_dump()) - return cast(Iterable[MessageParam], messages) + return messages # type: ignore def invoke( self, diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py index acdf669e..8c7f9e28 100644 --- a/src/neo4j_graphrag/llm/cohere_llm.py +++ b/src/neo4j_graphrag/llm/cohere_llm.py @@ -14,7 +14,7 @@ # limitations under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError @@ -94,7 +94,7 @@ def get_messages( raise LLMGenerationError(e.errors()) from e messages.extend(message_history) messages.append(UserMessage(content=input).model_dump()) - return cast(ChatMessages, messages) + return messages # type: ignore def invoke( self, diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py index f358e2ac..e88a0f07 100644 --- a/src/neo4j_graphrag/llm/ollama_llm.py +++ b/src/neo4j_graphrag/llm/ollama_llm.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations -from typing import Any, Optional, Sequence, TYPE_CHECKING, cast +from typing import Any, Optional, Sequence, TYPE_CHECKING from pydantic import ValidationError @@ -71,7 +71,7 @@ def get_messages( raise LLMGenerationError(e.errors()) from e messages.extend(message_history) messages.append(UserMessage(content=input).model_dump()) - return cast(Sequence[Message], messages) + return messages # type: ignore def invoke( self, diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index bc29469f..4e17bfb8 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -15,7 +15,7 @@ from __future__ import annotations import abc -from typing import TYPE_CHECKING, Any, Iterable, Optional, cast +from typing import TYPE_CHECKING, Any, Iterable, Optional from pydantic import ValidationError @@ -81,7 +81,7 @@ def get_messages( raise LLMGenerationError(e.errors()) from e messages.extend(message_history) messages.append(UserMessage(content=input).model_dump()) - return cast(Iterable[ChatCompletionMessageParam], messages) + return messages # type: ignore def invoke( self,