From d5a287b94aaeb33bbfe9634e8b2715fd26fdab38 Mon Sep 17 00:00:00 2001
From: Leila Messallem
Date: Wed, 18 Dec 2024 11:29:38 +0100
Subject: [PATCH] Add TypedDict `LLMMessage`

* to help with the type declaration of the message history
---
 examples/customize/llms/custom_llm.py     |  5 +++--
 src/neo4j_graphrag/generation/graphrag.py |  7 ++++---
 src/neo4j_graphrag/llm/anthropic_llm.py   | 22 ++++++++++++++--------
 src/neo4j_graphrag/llm/base.py            |  6 +++---
 src/neo4j_graphrag/llm/cohere_llm.py      | 14 ++++++++------
 src/neo4j_graphrag/llm/mistralai_llm.py   | 14 ++++++++------
 src/neo4j_graphrag/llm/ollama_llm.py      | 21 ++++++++++++++-------
 src/neo4j_graphrag/llm/openai_llm.py      | 21 ++++++++++++++-------
 src/neo4j_graphrag/llm/types.py           |  7 ++++++-
 src/neo4j_graphrag/llm/vertexai_llm.py    | 12 ++++++------
 tests/unit/llm/test_anthropic_llm.py      |  4 ++--
 tests/unit/llm/test_cohere_llm.py         |  4 ++--
 tests/unit/llm/test_mistralai_llm.py      |  4 ++--
 tests/unit/llm/test_ollama_llm.py         |  4 ++--
 tests/unit/llm/test_openai_llm.py         |  6 +++---
 tests/unit/llm/test_vertexai_llm.py       |  6 ++++--
 tests/unit/test_graphrag.py               |  2 +-
 17 files changed, 96 insertions(+), 63 deletions(-)

diff --git a/examples/customize/llms/custom_llm.py b/examples/customize/llms/custom_llm.py
index 0722124b..e7fb2118 100644
--- a/examples/customize/llms/custom_llm.py
+++ b/examples/customize/llms/custom_llm.py
@@ -3,6 +3,7 @@
 from typing import Any, Optional
 
 from neo4j_graphrag.llm import LLMInterface, LLMResponse
+from neo4j_graphrag.llm.types import LLMMessage
 
 
 class CustomLLM(LLMInterface):
@@ -12,7 +13,7 @@ def __init__(self, model_name: str, **kwargs: Any):
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         content: str = (
@@ -23,7 +24,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         raise NotImplementedError()
diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py
index 5889c867..869da4f0 100644
--- a/src/neo4j_graphrag/generation/graphrag.py
+++ b/src/neo4j_graphrag/generation/graphrag.py
@@ -31,6 +31,7 @@
 )
 from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
 from neo4j_graphrag.llm import LLMInterface
+from neo4j_graphrag.llm.types import LLMMessage
 from neo4j_graphrag.retrievers.base import Retriever
 from neo4j_graphrag.types import RetrieverResult
 
@@ -87,7 +88,7 @@ def __init__(
     def search(
         self,
         query_text: str = "",
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         examples: str = "",
         retriever_config: Optional[dict[str, Any]] = None,
         return_context: bool | None = None,
@@ -147,11 +148,11 @@ def search(
         return RagResultModel(**result)
 
     def build_query(
-        self, query_text: str, message_history: Optional[list[dict[str, str]]] = None
+        self, query_text: str, message_history: Optional[list[LLMMessage]] = None
     ) -> str:
         if message_history:
             summarization_prompt = ChatSummaryTemplate().format(
-                message_history=message_history
+                message_history=message_history  # type: ignore
             )
             summary = self.llm.invoke(
                 input=summarization_prompt,
diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py
index f94e761f..04d6555e 100644
--- a/src/neo4j_graphrag/llm/anthropic_llm.py
+++ b/src/neo4j_graphrag/llm/anthropic_llm.py
@@ -13,13 +13,19 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import Any, Iterable, Optional, TYPE_CHECKING
+from typing import Any, Iterable, Optional, TYPE_CHECKING, cast
 
 from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage
+from neo4j_graphrag.llm.types import (
+    BaseMessage,
+    LLMMessage,
+    LLMResponse,
+    MessageList,
+    UserMessage,
+)
 
 if TYPE_CHECKING:
     from anthropic.types.message_param import MessageParam
@@ -71,22 +77,22 @@ def __init__(
         self.async_client = anthropic.AsyncAnthropic(**kwargs)
 
     def get_messages(
-        self, input: str, message_history: Optional[list[dict[str, str]]] = None
+        self, input: str, message_history: Optional[list[LLMMessage]] = None
     ) -> Iterable[MessageParam]:
-        messages = []
+        messages: list[dict[str, str]] = []
         if message_history:
             try:
-                MessageList(messages=message_history)  # type: ignore
+                MessageList(messages=cast(list[BaseMessage], message_history))
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(message_history)
+            messages.extend(cast(Iterable[dict[str, Any]], message_history))
         messages.append(UserMessage(content=input).model_dump())
         return messages  # type: ignore
 
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -119,7 +125,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py
index 488cf4c9..eab3eb4f 100644
--- a/src/neo4j_graphrag/llm/base.py
+++ b/src/neo4j_graphrag/llm/base.py
@@ -17,7 +17,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Optional
 
-from .types import LLMResponse
+from .types import LLMMessage, LLMResponse
 
 
 class LLMInterface(ABC):
@@ -45,7 +45,7 @@ def __init__(
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the LLM and retrieves a response.
@@ -66,7 +66,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the LLM and retrieves a response.
diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py
index 8c7f9e28..63e54aaa 100644
--- a/src/neo4j_graphrag/llm/cohere_llm.py
+++ b/src/neo4j_graphrag/llm/cohere_llm.py
@@ -14,12 +14,14 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Iterable, Optional, cast
 
 from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
 from neo4j_graphrag.llm.types import (
+    BaseMessage,
+    LLMMessage,
     LLMResponse,
     MessageList,
     SystemMessage,
@@ -76,7 +78,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> ChatMessages:
         messages = []
@@ -89,17 +91,17 @@ def get_messages(
             messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
-                MessageList(messages=message_history)  # type: ignore
+                MessageList(messages=cast(list[BaseMessage], message_history))
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(message_history)
+            messages.extend(cast(Iterable[dict[str, Any]], message_history))
         messages.append(UserMessage(content=input).model_dump())
         return messages  # type: ignore
 
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -127,7 +129,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py
index 3c2d7292..a3c84759 100644
--- a/src/neo4j_graphrag/llm/mistralai_llm.py
+++ b/src/neo4j_graphrag/llm/mistralai_llm.py
@@ -15,12 +15,14 @@
 from __future__ import annotations
 
 import os
-from typing import Any, Optional, cast
+from typing import Any, Iterable, Optional, cast
 
 from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
 from neo4j_graphrag.llm.types import (
+    BaseMessage,
+    LLMMessage,
     LLMResponse,
     MessageList,
     SystemMessage,
@@ -67,7 +69,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> list[Messages]:
         messages = []
@@ -80,17 +82,17 @@ def get_messages(
             messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
-                MessageList(messages=message_history)  # type: ignore
+                MessageList(messages=cast(list[BaseMessage], message_history))
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(message_history)
+            messages.extend(cast(Iterable[dict[str, Any]], message_history))
         messages.append(UserMessage(content=input).model_dump())
         return cast(list[Messages], messages)
 
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the Mistral chat completion model
@@ -126,7 +128,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the MistralAI chat
diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py
index e88a0f07..a36d34f9 100644
--- a/src/neo4j_graphrag/llm/ollama_llm.py
+++ b/src/neo4j_graphrag/llm/ollama_llm.py
@@ -13,14 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from __future__ import annotations
 
-from typing import Any, Optional, Sequence, TYPE_CHECKING
+from typing import Any, Iterable, Optional, Sequence, TYPE_CHECKING, cast
 
 from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 from .base import LLMInterface
-from .types import LLMResponse, SystemMessage, UserMessage, MessageList
+from .types import (
+    BaseMessage,
+    LLMMessage,
+    LLMResponse,
+    SystemMessage,
+    UserMessage,
+    MessageList,
+)
 
 if TYPE_CHECKING:
     from ollama import Message
@@ -53,7 +60,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> Sequence[Message]:
         messages = []
@@ -66,17 +73,17 @@ def get_messages(
             messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
-                MessageList(messages=message_history)  # type: ignore
+                MessageList(messages=cast(list[BaseMessage], message_history))
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(message_history)
+            messages.extend(cast(Iterable[dict[str, Any]], message_history))
         messages.append(UserMessage(content=input).model_dump())
         return messages  # type: ignore
 
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -103,7 +110,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the OpenAI chat
diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py
index 4e17bfb8..b3d99411 100644
--- a/src/neo4j_graphrag/llm/openai_llm.py
+++ b/src/neo4j_graphrag/llm/openai_llm.py
@@ -15,13 +15,20 @@
 from __future__ import annotations
 
 import abc
-from typing import TYPE_CHECKING, Any, Iterable, Optional
+from typing import TYPE_CHECKING, Any, Iterable, Optional, cast
 
 from pydantic import ValidationError
 
 from ..exceptions import LLMGenerationError
 from .base import LLMInterface
-from .types import LLMResponse, SystemMessage, UserMessage, MessageList
+from .types import (
+    BaseMessage,
+    LLMMessage,
+    LLMResponse,
+    SystemMessage,
+    UserMessage,
+    MessageList,
+)
 
 if TYPE_CHECKING:
     import openai
@@ -63,7 +70,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> Iterable[ChatCompletionMessageParam]:
         messages = []
@@ -76,17 +83,17 @@ def get_messages(
             messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
-                MessageList(messages=message_history)  # type: ignore
+                MessageList(messages=cast(list[BaseMessage], message_history))
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(message_history)
+            messages.extend(cast(Iterable[dict[str, Any]], message_history))
         messages.append(UserMessage(content=input).model_dump())
         return messages  # type: ignore
 
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the OpenAI chat completion model
@@ -117,7 +124,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the OpenAI chat
diff --git a/src/neo4j_graphrag/llm/types.py b/src/neo4j_graphrag/llm/types.py
index a6888475..77e89aef 100644
--- a/src/neo4j_graphrag/llm/types.py
+++ b/src/neo4j_graphrag/llm/types.py
@@ -1,11 +1,16 @@
 from pydantic import BaseModel
-from typing import Literal
+from typing import Literal, TypedDict
 
 
 class LLMResponse(BaseModel):
     content: str
 
 
+class LLMMessage(TypedDict):
+    role: Literal["system", "user", "assistant"]
+    content: str
+
+
 class BaseMessage(BaseModel):
     role: Literal["user", "assistant", "system"]
     content: str
diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py
index 2ca506e0..48acfc9f 100644
--- a/src/neo4j_graphrag/llm/vertexai_llm.py
+++ b/src/neo4j_graphrag/llm/vertexai_llm.py
@@ -13,13 +13,13 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import Any, Optional
+from typing import Any, Optional, cast
 
 from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import LLMResponse, MessageList
+from neo4j_graphrag.llm.types import BaseMessage, LLMMessage, LLMResponse, MessageList
 
 try:
     from vertexai.generative_models import (
@@ -77,12 +77,12 @@ def __init__(
         self.options = kwargs
 
     def get_messages(
-        self, input: str, message_history: Optional[list[dict[str, str]]] = None
+        self, input: str, message_history: Optional[list[LLMMessage]] = None
     ) -> list[Content]:
         messages = []
         if message_history:
             try:
-                MessageList(messages=message_history)  # type: ignore
+                MessageList(messages=cast(list[BaseMessage], message_history))
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
 
@@ -106,7 +106,7 @@ def get_messages(
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -139,7 +139,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
diff --git a/tests/unit/llm/test_anthropic_llm.py b/tests/unit/llm/test_anthropic_llm.py
index 5458926f..2b29908c 100644
--- a/tests/unit/llm/test_anthropic_llm.py
+++ b/tests/unit/llm/test_anthropic_llm.py
@@ -73,7 +73,7 @@ def test_anthropic_invoke_with_message_history_happy_path(mock_anthropic: Mock)
     ]
     question = "What about next season?"
 
-    response = llm.invoke(question, message_history)
+    response = llm.invoke(question, message_history)  # type: ignore
     assert response.content == "generated text"
     message_history.append({"role": "user", "content": question})
     llm.client.messages.create.assert_called_once_with(  # type: ignore[attr-defined]
@@ -104,7 +104,7 @@ def test_anthropic_invoke_with_message_history_validation_error(
     question = "What about next season?"
 
     with pytest.raises(LLMGenerationError) as exc_info:
-        llm.invoke(question, message_history)
+        llm.invoke(question, message_history)  # type: ignore
     assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value)
 
 
diff --git a/tests/unit/llm/test_cohere_llm.py b/tests/unit/llm/test_cohere_llm.py
index 1fb7839a..0835343e 100644
--- a/tests/unit/llm/test_cohere_llm.py
+++ b/tests/unit/llm/test_cohere_llm.py
@@ -60,7 +60,7 @@ def test_cohere_llm_invoke_with_message_history_happy_path(mock_cohere: Mock) ->
     ]
     question = "What about next season?"
 
-    res = llm.invoke(question, message_history)
+    res = llm.invoke(question, message_history)  # type: ignore
     assert isinstance(res, LLMResponse)
     assert res.content == "cohere response text"
     messages = [{"role": "system", "content": system_instruction}]
@@ -88,7 +88,7 @@ def test_cohere_llm_invoke_with_message_history_validation_error(
     question = "What about next season?"
 
     with pytest.raises(LLMGenerationError) as exc_info:
-        llm.invoke(question, message_history)
+        llm.invoke(question, message_history)  # type: ignore
     assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value)
 
 
diff --git a/tests/unit/llm/test_mistralai_llm.py b/tests/unit/llm/test_mistralai_llm.py
index a56a118c..99a4592f 100644
--- a/tests/unit/llm/test_mistralai_llm.py
+++ b/tests/unit/llm/test_mistralai_llm.py
@@ -64,7 +64,7 @@ def test_mistralai_llm_invoke_with_message_history(mock_mistral: Mock) -> None:
         {"role": "assistant", "content": "Usually around 6am."},
     ]
     question = "What about next season?"
-    res = llm.invoke(question, message_history)
+    res = llm.invoke(question, message_history)  # type: ignore
 
     assert isinstance(res, LLMResponse)
     assert res.content == "mistral response"
@@ -99,7 +99,7 @@ def test_mistralai_llm_invoke_with_message_history_validation_error(
     question = "What about next season?"
 
     with pytest.raises(LLMGenerationError) as exc_info:
-        llm.invoke(question, message_history)
+        llm.invoke(question, message_history)  # type: ignore
     assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value)
 
 
diff --git a/tests/unit/llm/test_ollama_llm.py b/tests/unit/llm/test_ollama_llm.py
index 76272520..1146f60a 100644
--- a/tests/unit/llm/test_ollama_llm.py
+++ b/tests/unit/llm/test_ollama_llm.py
@@ -84,7 +84,7 @@ def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> Non
     ]
     question = "What about next season?"
 
-    response = llm.invoke(question, message_history)
+    response = llm.invoke(question, message_history)  # type: ignore
     assert response.content == "ollama chat response"
     messages = [{"role": "system", "content": system_instruction}]
     messages.extend(message_history)
@@ -114,7 +114,7 @@ def test_ollama_invoke_with_message_history_validation_error(mock_import: Mock)
     question = "What about next season?"
 
     with pytest.raises(LLMGenerationError) as exc_info:
-        llm.invoke(question, message_history)
+        llm.invoke(question, message_history)  # type: ignore
     assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value)
 
 
diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py
index 9f097ebf..6e423b28 100644
--- a/tests/unit/llm/test_openai_llm.py
+++ b/tests/unit/llm/test_openai_llm.py
@@ -61,7 +61,7 @@ def test_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None:
     ]
     question = "What about next season?"
 
-    res = llm.invoke(question, message_history)
+    res = llm.invoke(question, message_history)  # type: ignore
     assert isinstance(res, LLMResponse)
     assert res.content == "openai chat response"
 
@@ -81,7 +81,7 @@ def test_openai_llm_with_message_history_validation_error(mock_import: Mock) ->
     question = "What about next season?"
 
     with pytest.raises(LLMGenerationError) as exc_info:
-        llm.invoke(question, message_history)
+        llm.invoke(question, message_history)  # type: ignore
     assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value)
 
 
@@ -134,7 +134,7 @@ def test_azure_openai_llm_with_message_history_happy_path(mock_import: Mock) ->
     ]
     question = "What about next season?"
 
-    res = llm.invoke(question, message_history)
+    res = llm.invoke(question, message_history)  # type: ignore
     assert isinstance(res, LLMResponse)
     assert res.content == "openai chat response"
 
diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py
index 62370d73..0c178684 100644
--- a/tests/unit/llm/test_vertexai_llm.py
+++ b/tests/unit/llm/test_vertexai_llm.py
@@ -13,11 +13,13 @@
 # limitations under the License.
 from __future__ import annotations
 
+from typing import cast
 from unittest import mock
 from unittest.mock import AsyncMock, MagicMock, Mock, patch
 
 import pytest
 from neo4j_graphrag.exceptions import LLMGenerationError
+from neo4j_graphrag.llm.types import LLMMessage
 from neo4j_graphrag.llm.vertexai_llm import VertexAILLM
 from vertexai.generative_models import Content, Part
 
@@ -70,7 +72,7 @@ def test_vertexai_get_messages(GenerativeModelMock: MagicMock) -> None:
     ]
 
     llm = VertexAILLM(model_name=model_name, system_instruction=system_instruction)
-    response = llm.get_messages(question, message_history)
+    response = llm.get_messages(question, cast(list[LLMMessage], message_history))
     GenerativeModelMock.assert_not_called
     assert len(response) == len(expected_response)
 
@@ -90,7 +92,7 @@ def test_vertexai_get_messages_validation_error(GenerativeModelMock: MagicMock)
     llm = VertexAILLM(model_name=model_name, system_instruction=system_instruction)
 
     with pytest.raises(LLMGenerationError) as exc_info:
-        llm.invoke(question, message_history)
+        llm.invoke(question, cast(list[LLMMessage], message_history))
     assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value)
 
 
diff --git a/tests/unit/test_graphrag.py b/tests/unit/test_graphrag.py
index a9624480..744a37a5 100644
--- a/tests/unit/test_graphrag.py
+++ b/tests/unit/test_graphrag.py
@@ -110,7 +110,7 @@ def test_graphrag_happy_path_with_message_history(
         {"role": "user", "content": "initial question"},
         {"role": "assistant", "content": "answer to initial question"},
     ]
-    res = rag.search("question", message_history)
+    res = rag.search("question", message_history)  # type: ignore
 
     expected_retriever_query_text = """
Message Summary:
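
Not part of the patch above, only a minimal usage sketch for orientation: it shows how a message history would be declared with the new `LLMMessage` TypedDict instead of `list[dict[str, str]]`. It assumes the patched neo4j_graphrag package is importable; the `format_history` helper is hypothetical and exists only to illustrate the annotation style.

    from typing import Optional

    from neo4j_graphrag.llm.types import LLMMessage

    # Each history entry is a TypedDict whose "role" is constrained to
    # "system" / "user" / "assistant" and whose "content" is a string.
    message_history: list[LLMMessage] = [
        {"role": "user", "content": "When does the sun come up in the summer?"},
        {"role": "assistant", "content": "Usually around 6am."},
    ]


    def format_history(
        question: str, history: Optional[list[LLMMessage]] = None
    ) -> str:
        # Hypothetical helper: flattens the history plus the new question into a
        # plain-text transcript, mirroring the annotation the patch applies to
        # invoke()/ainvoke()/get_messages().
        lines = [f"{m['role']}: {m['content']}" for m in (history or [])]
        lines.append(f"user: {question}")
        return "\n".join(lines)


    print(format_history("What about next season?", message_history))

With this annotation a type checker can reject histories whose role falls outside the allowed literals, which is what the `# type: ignore` comments added to the tests above work around for deliberately invalid input.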