From 37225fdd124d21701cbfa7a0f7256c3d91e8c3b4 Mon Sep 17 00:00:00 2001 From: Leila Messallem Date: Mon, 16 Dec 2024 14:56:33 +0100 Subject: [PATCH] Rename `chat_history` to `message_history` --- src/neo4j_graphrag/generation/graphrag.py | 14 ++++++------ src/neo4j_graphrag/generation/prompts.py | 14 ++++++------ src/neo4j_graphrag/llm/anthropic_llm.py | 20 ++++++++--------- src/neo4j_graphrag/llm/base.py | 8 +++---- src/neo4j_graphrag/llm/cohere_llm.py | 20 ++++++++--------- src/neo4j_graphrag/llm/mistralai_llm.py | 20 ++++++++--------- src/neo4j_graphrag/llm/ollama_llm.py | 16 +++++++------- src/neo4j_graphrag/llm/openai_llm.py | 20 ++++++++--------- src/neo4j_graphrag/llm/vertexai_llm.py | 20 ++++++++--------- tests/unit/llm/test_anthropic_llm.py | 16 +++++++------- tests/unit/llm/test_cohere_llm.py | 14 ++++++------ tests/unit/llm/test_mistralai_llm.py | 14 ++++++------ tests/unit/llm/test_ollama_llm.py | 14 ++++++------ tests/unit/llm/test_openai_llm.py | 26 ++++++++++++----------- tests/unit/llm/test_vertexai_llm.py | 8 +++---- tests/unit/test_graphrag.py | 12 +++++------ 16 files changed, 129 insertions(+), 127 deletions(-) diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py index ea5e55d5..9aaee06d 100644 --- a/src/neo4j_graphrag/generation/graphrag.py +++ b/src/neo4j_graphrag/generation/graphrag.py @@ -87,7 +87,7 @@ def __init__( def search( self, query_text: str = "", - chat_history: Optional[list[dict[str, str]]] = None, + message_history: Optional[list[dict[str, str]]] = None, examples: str = "", retriever_config: Optional[dict[str, Any]] = None, return_context: bool | None = None, @@ -105,7 +105,7 @@ def search( Args: query_text (str): The user question. - chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. examples (str): Examples added to the LLM prompt. retriever_config (Optional[dict]): Parameters passed to the retriever. 
search method; e.g.: top_k @@ -130,7 +130,7 @@ def search( ) except ValidationError as e: raise SearchValidationError(e.errors()) - query = self.build_query(validated_data.query_text, chat_history) + query = self.build_query(validated_data.query_text, message_history) retriever_result: RetrieverResult = self.retriever.search( query_text=query, **validated_data.retriever_config ) @@ -140,18 +140,18 @@ def search( ) logger.debug(f"RAG: retriever_result={retriever_result}") logger.debug(f"RAG: prompt={prompt}") - answer = self.llm.invoke(prompt, chat_history) + answer = self.llm.invoke(prompt, message_history) result: dict[str, Any] = {"answer": answer.content} if return_context: result["retriever_result"] = retriever_result return RagResultModel(**result) def build_query( - self, query_text: str, chat_history: Optional[list[dict[str, str]]] = None + self, query_text: str, message_history: Optional[list[dict[str, str]]] = None ) -> str: - if chat_history: + if message_history: summarization_prompt = ChatSummaryTemplate().format( - chat_history=chat_history + message_history=message_history ) summary = self.llm.invoke( input=summarization_prompt, diff --git a/src/neo4j_graphrag/generation/prompts.py b/src/neo4j_graphrag/generation/prompts.py index ca8451ab..862c5af0 100644 --- a/src/neo4j_graphrag/generation/prompts.py +++ b/src/neo4j_graphrag/generation/prompts.py @@ -200,25 +200,25 @@ def format( class ChatSummaryTemplate(PromptTemplate): DEFAULT_TEMPLATE = """ -Summarize the chat history: +Summarize the message history: -{chat_history} +{message_history} """ - EXPECTED_INPUTS = ["chat_history"] + EXPECTED_INPUTS = ["message_history"] SYSTEM_MESSAGE = "You are a summarization assistant. Summarize the given text in no more than 200 words" - def format(self, chat_history: list[dict[str, str]]) -> str: + def format(self, message_history: list[dict[str, str]]) -> str: message_list = [ ": ".join([f"{value}" for _, value in message.items()]) - for message in chat_history + for message in message_history ] history = "\n".join(message_list) - return super().format(chat_history=history) + return super().format(message_history=history) class ConversationTemplate(PromptTemplate): DEFAULT_TEMPLATE = """ -Chat Summary: +Message Summary: {summary} Current Query: diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py index 1487b1b9..a295f888 100644 --- a/src/neo4j_graphrag/llm/anthropic_llm.py +++ b/src/neo4j_graphrag/llm/anthropic_llm.py @@ -71,32 +71,32 @@ def __init__( self.async_client = anthropic.AsyncAnthropic(**kwargs) def get_messages( - self, input: str, chat_history: Optional[list[Any]] = None + self, input: str, message_history: Optional[list[Any]] = None ) -> Iterable[MessageParam]: messages = [] - if chat_history: + if message_history: try: - MessageList(messages=chat_history) + MessageList(messages=message_history) except ValidationError as e: raise LLMGenerationError(e.errors()) from e - messages.extend(chat_history) + messages.extend(message_history) messages.append(UserMessage(content=input).model_dump()) return messages def invoke( - self, input: str, chat_history: Optional[list[Any]] = None + self, input: str, message_history: Optional[list[Any]] = None ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. 
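Note on the build_query change above: the renamed message_history is first flattened and summarized via ChatSummaryTemplate before retrieval. A minimal, self-contained sketch of that flattening step, assuming each message is a dict with "role" and "content" keys as used throughout the tests:

    def flatten_history(message_history: list[dict[str, str]]) -> str:
        # Each message becomes "<role>: <content>"; messages are joined by newlines,
        # mirroring the list comprehension in ChatSummaryTemplate.format above.
        return "\n".join(
            ": ".join(str(value) for value in message.values())
            for message in message_history
        )

    history = [
        {"role": "user", "content": "When does the sun come up in the summer?"},
        {"role": "assistant", "content": "Usually around 6am."},
    ]
    print(flatten_history(history))
    # user: When does the sun come up in the summer?
    # assistant: Usually around 6am.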
+ message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. Returns: LLMResponse: The response from the LLM. """ try: - messages = self.get_messages(input, chat_history) + messages = self.get_messages(input, message_history) response = self.client.messages.create( model=self.model_name, system=self.system_instruction, @@ -108,19 +108,19 @@ def invoke( raise LLMGenerationError(e) async def ainvoke( - self, input: str, chat_history: Optional[list[Any]] = None + self, input: str, message_history: Optional[list[Any]] = None ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. Returns: LLMResponse: The response from the LLM. """ try: - messages = self.get_messages(input, chat_history) + messages = self.get_messages(input, message_history) response = await self.async_client.messages.create( model=self.model_name, system=self.system_instruction, diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index 5cc94291..48ffea3d 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -45,14 +45,14 @@ def __init__( def invoke( self, input: str, - chat_history: Optional[list[dict[str, str]]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends a text input to the LLM and retrieves a response. Args: input (str): Text sent to the LLM. - chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. system_instruction (Optional[str]): An option to override the llm system message for this invokation. Returns: @@ -64,13 +64,13 @@ def invoke( @abstractmethod async def ainvoke( - self, input: str, chat_history: Optional[list[dict[str, str]]] = None + self, input: str, message_history: Optional[list[dict[str, str]]] = None ) -> LLMResponse: """Asynchronously sends a text input to the LLM and retrieves a response. Args: input (str): Text sent to the LLM. - chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. 
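A usage sketch for the renamed parameter on the Anthropic path. The constructor and model name come from the tests further down; the max_tokens value and an available ANTHROPIC_API_KEY are assumptions, not part of this diff:

    from neo4j_graphrag.llm.anthropic_llm import AnthropicLLM

    llm = AnthropicLLM(
        "claude-3-opus-20240229",
        model_params={"max_tokens": 1000},  # assumed value; not part of this diff
    )
    message_history = [
        {"role": "user", "content": "When does the sun come up in the summer?"},
        {"role": "assistant", "content": "Usually around 6am."},
    ]
    response = llm.invoke("What about next season?", message_history)
    print(response.content)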
Returns: diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py index ceddd63b..8d04b8df 100644 --- a/src/neo4j_graphrag/llm/cohere_llm.py +++ b/src/neo4j_graphrag/llm/cohere_llm.py @@ -74,34 +74,34 @@ def __init__( self.async_client = cohere.AsyncClientV2(**kwargs) def get_messages( - self, input: str, chat_history: Optional[list[Any]] = None + self, input: str, message_history: Optional[list[Any]] = None ) -> ChatMessages: messages = [] if self.system_instruction: messages.append(SystemMessage(content=self.system_instruction).model_dump()) - if chat_history: + if message_history: try: - MessageList(messages=chat_history) + MessageList(messages=message_history) except ValidationError as e: raise LLMGenerationError(e.errors()) from e - messages.extend(chat_history) + messages.extend(message_history) messages.append(UserMessage(content=input).model_dump()) return messages def invoke( - self, input: str, chat_history: Optional[list[Any]] = None + self, input: str, message_history: Optional[list[Any]] = None ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. Returns: LLMResponse: The response from the LLM. """ try: - messages = self.get_messages(input, chat_history) + messages = self.get_messages(input, message_history) res = self.client.chat( messages=messages, model=self.model_name, @@ -113,19 +113,19 @@ def invoke( ) async def ainvoke( - self, input: str, chat_history: Optional[list[Any]] = None + self, input: str, message_history: Optional[list[Any]] = None ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. Returns: LLMResponse: The response from the LLM. 
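The providers in this patch share the same assembly order in get_messages. A plain-Python sketch of that ordering (role validation is delegated to MessageList in the real code and omitted here):

    def assemble_messages(
        input: str,
        message_history: list[dict[str, str]] | None = None,
        system_instruction: str | None = None,
    ) -> list[dict[str, str]]:
        # Order: optional system instruction, then prior messages, then the new input.
        messages: list[dict[str, str]] = []
        if system_instruction:
            messages.append({"role": "system", "content": system_instruction})
        if message_history:
            messages.extend(message_history)
        messages.append({"role": "user", "content": input})
        return messages

    assert assemble_messages(
        "What about next season?",
        [{"role": "user", "content": "When does the sun come up in the summer?"}],
        "You are a helpful assistant.",
    ) == [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "When does the sun come up in the summer?"},
        {"role": "user", "content": "What about next season?"},
    ]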
""" try: - messages = self.get_messages(input, chat_history) + messages = self.get_messages(input, message_history) res = self.async_client.chat( messages=messages, model=self.model_name, diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py index b95249c8..a3bab0e9 100644 --- a/src/neo4j_graphrag/llm/mistralai_llm.py +++ b/src/neo4j_graphrag/llm/mistralai_llm.py @@ -65,29 +65,29 @@ def __init__( self.client = Mistral(api_key=api_key, **kwargs) def get_messages( - self, input: str, chat_history: Optional[list[Any]] = None + self, input: str, message_history: Optional[list[Any]] = None ) -> list[Messages]: messages = [] if self.system_instruction: messages.append(SystemMessage(content=self.system_instruction).model_dump()) - if chat_history: + if message_history: try: - MessageList(messages=chat_history) + MessageList(messages=message_history) except ValidationError as e: raise LLMGenerationError(e.errors()) from e - messages.extend(chat_history) + messages.extend(message_history) messages.append(UserMessage(content=input).model_dump()) return messages def invoke( - self, input: str, chat_history: Optional[list[Any]] = None + self, input: str, message_history: Optional[list[Any]] = None ) -> LLMResponse: """Sends a text input to the Mistral chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. - chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. Returns: LLMResponse: The response from MistralAI. @@ -96,7 +96,7 @@ def invoke( LLMGenerationError: If anything goes wrong. """ try: - messages = self.get_messages(input, chat_history) + messages = self.get_messages(input, message_history) response = self.client.chat.complete( model=self.model_name, messages=messages, @@ -112,14 +112,14 @@ def invoke( raise LLMGenerationError(e) async def ainvoke( - self, input: str, chat_history: Optional[list[Any]] = None + self, input: str, message_history: Optional[list[Any]] = None ) -> LLMResponse: """Asynchronously sends a text input to the MistralAI chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. - chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. Returns: LLMResponse: The response from MistralAI. @@ -128,7 +128,7 @@ async def ainvoke( LLMGenerationError: If anything goes wrong. 
""" try: - messages = self.get_messages(input, chat_history) + messages = self.get_messages(input, message_history) response = await self.client.chat.complete_async( model=self.model_name, messages=messages, diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py index 9ae72c87..4e7a7b6f 100644 --- a/src/neo4j_graphrag/llm/ollama_llm.py +++ b/src/neo4j_graphrag/llm/ollama_llm.py @@ -50,27 +50,27 @@ def __init__( ) def get_messages( - self, input: str, chat_history: Optional[list[Any]] = None + self, input: str, message_history: Optional[list[Any]] = None ) -> Sequence[Message]: messages = [] if self.system_instruction: messages.append(SystemMessage(content=self.system_instruction).model_dump()) - if chat_history: + if message_history: try: - MessageList(messages=chat_history) + MessageList(messages=message_history) except ValidationError as e: raise LLMGenerationError(e.errors()) from e - messages.extend(chat_history) + messages.extend(message_history) messages.append(UserMessage(content=input).model_dump()) return messages def invoke( - self, input: str, chat_history: Optional[list[Any]] = None + self, input: str, message_history: Optional[list[Any]] = None ) -> LLMResponse: try: response = self.client.chat( model=self.model_name, - messages=self.get_messages(input, chat_history), + messages=self.get_messages(input, message_history), options=self.model_params, ) content = response.message.content or "" @@ -79,12 +79,12 @@ def invoke( raise LLMGenerationError(e) async def ainvoke( - self, input: str, chat_history: Optional[list[Any]] = None + self, input: str, message_history: Optional[list[Any]] = None ) -> LLMResponse: try: response = await self.async_client.chat( model=self.model_name, - messages=self.get_messages(input, chat_history), + messages=self.get_messages(input, message_history), options=self.model_params, ) content = response.message.content or "" diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index 895e5028..3a1abf09 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -63,7 +63,7 @@ def __init__( def get_messages( self, input: str, - chat_history: Optional[list[Any]] = None, + message_history: Optional[list[Any]] = None, system_instruction: Optional[str] = None, ) -> Iterable[ChatCompletionMessageParam]: messages = [] @@ -74,19 +74,19 @@ def get_messages( ) if system_message: messages.append(SystemMessage(content=system_message).model_dump()) - if chat_history: + if message_history: try: - MessageList(messages=chat_history) + MessageList(messages=message_history) except ValidationError as e: raise LLMGenerationError(e.errors()) from e - messages.extend(chat_history) + messages.extend(message_history) messages.append(UserMessage(content=input).model_dump()) return messages def invoke( self, input: str, - chat_history: Optional[list[Any]] = None, + message_history: Optional[list[Any]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends a text input to the OpenAI chat completion model @@ -94,7 +94,7 @@ def invoke( Args: input (str): Text sent to the LLM. - chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. system_instruction (Optional[str]): An option to override the llm system message for this invokation. 
Returns: @@ -105,7 +105,7 @@ def invoke( """ try: response = self.client.chat.completions.create( - messages=self.get_messages(input, chat_history, system_instruction), + messages=self.get_messages(input, message_history, system_instruction), model=self.model_name, **self.model_params, ) @@ -115,14 +115,14 @@ def invoke( raise LLMGenerationError(e) async def ainvoke( - self, input: str, chat_history: Optional[list[Any]] = None + self, input: str, message_history: Optional[list[Any]] = None ) -> LLMResponse: """Asynchronously sends a text input to the OpenAI chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. - chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. Returns: LLMResponse: The response from OpenAI. @@ -132,7 +132,7 @@ async def ainvoke( """ try: response = await self.async_client.chat.completions.create( - messages=self.get_messages(input, chat_history), + messages=self.get_messages(input, message_history), model=self.model_name, **self.model_params, ) diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index 78ecdb56..d8510c41 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -77,16 +77,16 @@ def __init__( self.model_params = kwargs def get_messages( - self, input: str, chat_history: Optional[list[Any]] = None + self, input: str, message_history: Optional[list[Any]] = None ) -> list[Content]: messages = [] - if chat_history: + if message_history: try: - MessageList(messages=chat_history) + MessageList(messages=message_history) except ValidationError as e: raise LLMGenerationError(e.errors()) from e - for message in chat_history: + for message in message_history: if message.get("role") == "user": messages.append( Content( @@ -106,14 +106,14 @@ def get_messages( def invoke( self, input: str, - chat_history: Optional[list[Any]] = None, + message_history: Optional[list[Any]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. system_instruction (Optional[str]): An option to override the llm system message for this invokation. Returns: @@ -130,26 +130,26 @@ def invoke( **self.model_params, ) try: - messages = self.get_messages(input, chat_history) + messages = self.get_messages(input, message_history) response = self.model.generate_content(messages, **self.model_params) return LLMResponse(content=response.text) except ResponseValidationError as e: raise LLMGenerationError(e) async def ainvoke( - self, input: str, chat_history: Optional[list[Any]] = None + self, input: str, message_history: Optional[list[Any]] = None ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. + message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned. Returns: LLMResponse: The response from the LLM. 
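On the Vertex AI path, get_messages converts message_history into Content objects, keeping role "user" as-is; the Vertex AI tests below expect "assistant" turns to come back with role "model". A plain-dict stand-in sketch of that mapping (the real code builds vertexai Content/Part objects, and the assistant branch is assumed to mirror the user branch shown above):

    def to_vertexai_contents(message_history: list[dict[str, str]]) -> list[dict[str, object]]:
        contents: list[dict[str, object]] = []
        for message in message_history:
            if message.get("role") == "user":
                contents.append({"role": "user", "parts": [message.get("content")]})
            elif message.get("role") == "assistant":
                # Vertex AI uses "model" for assistant turns.
                contents.append({"role": "model", "parts": [message.get("content")]})
        return contents

    print(to_vertexai_contents([
        {"role": "user", "content": "When does the sun come up in the summer?"},
        {"role": "assistant", "content": "Usually around 6am."},
    ]))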
""" try: - messages = self.get_messages(input, chat_history) + messages = self.get_messages(input, message_history) response = await self.model.generate_content_async( messages, **self.model_params ) diff --git a/tests/unit/llm/test_anthropic_llm.py b/tests/unit/llm/test_anthropic_llm.py index 917a3de4..e17b1ec2 100644 --- a/tests/unit/llm/test_anthropic_llm.py +++ b/tests/unit/llm/test_anthropic_llm.py @@ -56,7 +56,7 @@ def test_anthropic_invoke_happy_path(mock_anthropic: Mock) -> None: ) -def test_anthropic_invoke_with_chat_history_happy_path(mock_anthropic: Mock) -> None: +def test_anthropic_invoke_with_message_history_happy_path(mock_anthropic: Mock) -> None: mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( content="generated text" ) @@ -67,24 +67,24 @@ def test_anthropic_invoke_with_chat_history_happy_path(mock_anthropic: Mock) -> model_params=model_params, system_instruction=system_instruction, ) - chat_history = [ + message_history = [ {"role": "user", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, ] question = "What about next season?" - response = llm.invoke(question, chat_history) + response = llm.invoke(question, message_history) assert response.content == "generated text" - chat_history.append({"role": "user", "content": question}) + message_history.append({"role": "user", "content": question}) llm.client.messages.create.assert_called_once_with( - messages=chat_history, + messages=message_history, model="claude-3-opus-20240229", system=system_instruction, **model_params, ) -def test_anthropic_invoke_with_chat_history_validation_error( +def test_anthropic_invoke_with_message_history_validation_error( mock_anthropic: Mock, ) -> None: mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( @@ -97,14 +97,14 @@ def test_anthropic_invoke_with_chat_history_validation_error( model_params=model_params, system_instruction=system_instruction, ) - chat_history = [ + message_history = [ {"role": "human", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, ] question = "What about next season?" with pytest.raises(LLMGenerationError) as exc_info: - llm.invoke(question, chat_history) + llm.invoke(question, message_history) assert "Input should be 'user' or 'assistant'" in str(exc_info.value) diff --git a/tests/unit/llm/test_cohere_llm.py b/tests/unit/llm/test_cohere_llm.py index b69e02e8..c0e76bb0 100644 --- a/tests/unit/llm/test_cohere_llm.py +++ b/tests/unit/llm/test_cohere_llm.py @@ -47,24 +47,24 @@ def test_cohere_llm_happy_path(mock_cohere: Mock) -> None: assert res.content == "cohere response text" -def test_cohere_llm_invoke_with_chat_history_happy_path(mock_cohere: Mock) -> None: +def test_cohere_llm_invoke_with_message_history_happy_path(mock_cohere: Mock) -> None: chat_response_mock = MagicMock() chat_response_mock.message.content = [MagicMock(text="cohere response text")] mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock system_instruction = "You are a helpful assistant." llm = CohereLLM(model_name="something", system_instruction=system_instruction) - chat_history = [ + message_history = [ {"role": "user", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, ] question = "What about next season?" 
- res = llm.invoke(question, chat_history) + res = llm.invoke(question, message_history) assert isinstance(res, LLMResponse) assert res.content == "cohere response text" messages = [{"role": "system", "content": system_instruction}] - messages.extend(chat_history) + messages.extend(message_history) messages.append({"role": "user", "content": question}) llm.client.chat.assert_called_once_with( messages=messages, @@ -72,7 +72,7 @@ def test_cohere_llm_invoke_with_chat_history_happy_path(mock_cohere: Mock) -> No ) -def test_cohere_llm_invoke_with_chat_history_validation_error( +def test_cohere_llm_invoke_with_message_history_validation_error( mock_cohere: Mock, ) -> None: chat_response_mock = MagicMock() @@ -81,14 +81,14 @@ def test_cohere_llm_invoke_with_chat_history_validation_error( system_instruction = "You are a helpful assistant." llm = CohereLLM(model_name="something", system_instruction=system_instruction) - chat_history = [ + message_history = [ {"role": "robot", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, ] question = "What about next season?" with pytest.raises(LLMGenerationError) as exc_info: - llm.invoke(question, chat_history) + llm.invoke(question, message_history) assert "Input should be 'user' or 'assistant'" in str(exc_info.value) diff --git a/tests/unit/llm/test_mistralai_llm.py b/tests/unit/llm/test_mistralai_llm.py index 938b7364..aab8eeb1 100644 --- a/tests/unit/llm/test_mistralai_llm.py +++ b/tests/unit/llm/test_mistralai_llm.py @@ -47,7 +47,7 @@ def test_mistralai_llm_invoke(mock_mistral: Mock) -> None: @patch("neo4j_graphrag.llm.mistralai_llm.Mistral") -def test_mistralai_llm_invoke_with_chat_history(mock_mistral: Mock) -> None: +def test_mistralai_llm_invoke_with_message_history(mock_mistral: Mock) -> None: mock_mistral_instance = mock_mistral.return_value chat_response_mock = MagicMock() chat_response_mock.choices = [ @@ -59,17 +59,17 @@ def test_mistralai_llm_invoke_with_chat_history(mock_mistral: Mock) -> None: llm = MistralAILLM(model_name=model, system_instruction=system_instruction) - chat_history = [ + message_history = [ {"role": "user", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, ] question = "What about next season?" - res = llm.invoke(question, chat_history) + res = llm.invoke(question, message_history) assert isinstance(res, LLMResponse) assert res.content == "mistral response" messages = [{"role": "system", "content": system_instruction}] - messages.extend(chat_history) + messages.extend(message_history) messages.append({"role": "user", "content": question}) llm.client.chat.complete.assert_called_once_with( messages=messages, @@ -78,7 +78,7 @@ def test_mistralai_llm_invoke_with_chat_history(mock_mistral: Mock) -> None: @patch("neo4j_graphrag.llm.mistralai_llm.Mistral") -def test_mistralai_llm_invoke_with_chat_history_validation_error( +def test_mistralai_llm_invoke_with_message_history_validation_error( mock_mistral: Mock, ) -> None: mock_mistral_instance = mock_mistral.return_value @@ -92,14 +92,14 @@ def test_mistralai_llm_invoke_with_chat_history_validation_error( llm = MistralAILLM(model_name=model, system_instruction=system_instruction) - chat_history = [ + message_history = [ {"role": "user", "content": "When does the sun come up in the summer?"}, {"role": "monkey", "content": "Usually around 6am."}, ] question = "What about next season?" 
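The validation-error tests above rely on a role check that rejects anything other than the allowed roles. MessageList itself is not part of this diff, so the model below is a minimal pydantic sketch with assumed field names; it only demonstrates how an unexpected role produces the asserted "Input should be 'user' or 'assistant'" error:

    from typing import Literal
    from pydantic import BaseModel, ValidationError

    class Message(BaseModel):
        role: Literal["user", "assistant"]
        content: str

    class MessageList(BaseModel):
        messages: list[Message]

    try:
        MessageList(messages=[{"role": "robot", "content": "When does the sun come up?"}])
    except ValidationError as e:
        print(e)  # ... Input should be 'user' or 'assistant' ...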
with pytest.raises(LLMGenerationError) as exc_info: - llm.invoke(question, chat_history) + llm.invoke(question, message_history) assert "Input should be 'user' or 'assistant'" in str(exc_info.value) diff --git a/tests/unit/llm/test_ollama_llm.py b/tests/unit/llm/test_ollama_llm.py index 6f86b8fb..779bf902 100644 --- a/tests/unit/llm/test_ollama_llm.py +++ b/tests/unit/llm/test_ollama_llm.py @@ -64,7 +64,7 @@ def test_ollama_llm_happy_path(mock_import: Mock) -> None: @patch("builtins.__import__") -def test_ollama_invoke_with_chat_history_happy_path(mock_import: Mock) -> None: +def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> None: mock_ollama = get_mock_ollama() mock_import.return_value = mock_ollama mock_ollama.Client.return_value.chat.return_value = MagicMock( @@ -78,16 +78,16 @@ def test_ollama_invoke_with_chat_history_happy_path(mock_import: Mock) -> None: model_params=model_params, system_instruction=system_instruction, ) - chat_history = [ + message_history = [ {"role": "user", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, ] question = "What about next season?" - response = llm.invoke(question, chat_history) + response = llm.invoke(question, message_history) assert response.content == "ollama chat response" messages = [{"role": "system", "content": system_instruction}] - messages.extend(chat_history) + messages.extend(message_history) messages.append({"role": "user", "content": question}) llm.client.chat.assert_called_once_with( model=model, messages=messages, options=model_params @@ -95,7 +95,7 @@ def test_ollama_invoke_with_chat_history_happy_path(mock_import: Mock) -> None: @patch("builtins.__import__") -def test_ollama_invoke_with_chat_history_validation_error(mock_import: Mock) -> None: +def test_ollama_invoke_with_message_history_validation_error(mock_import: Mock) -> None: mock_ollama = get_mock_ollama() mock_import.return_value = mock_ollama mock_ollama.ResponseError = ollama.ResponseError @@ -107,14 +107,14 @@ def test_ollama_invoke_with_chat_history_validation_error(mock_import: Mock) -> model_params=model_params, system_instruction=system_instruction, ) - chat_history = [ + message_history = [ {"role": "human", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, ] question = "What about next season?" 
with pytest.raises(LLMGenerationError) as exc_info: - llm.invoke(question, chat_history) + llm.invoke(question, message_history) assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value) diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py index 7eb3af3e..356cc4e4 100644 --- a/tests/unit/llm/test_openai_llm.py +++ b/tests/unit/llm/test_openai_llm.py @@ -48,40 +48,40 @@ def test_openai_llm_happy_path(mock_import: Mock) -> None: @patch("builtins.__import__") -def test_openai_llm_with_chat_history_happy_path(mock_import: Mock) -> None: +def test_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None: mock_openai = get_mock_openai() mock_import.return_value = mock_openai mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( choices=[MagicMock(message=MagicMock(content="openai chat response"))], ) llm = OpenAILLM(api_key="my key", model_name="gpt") - chat_history = [ + message_history = [ {"role": "user", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, ] question = "What about next season?" - res = llm.invoke(question, chat_history) + res = llm.invoke(question, message_history) assert isinstance(res, LLMResponse) assert res.content == "openai chat response" @patch("builtins.__import__") -def test_openai_llm_with_chat_history_validation_error(mock_import: Mock) -> None: +def test_openai_llm_with_message_history_validation_error(mock_import: Mock) -> None: mock_openai = get_mock_openai() mock_import.return_value = mock_openai mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( choices=[MagicMock(message=MagicMock(content="openai chat response"))], ) llm = OpenAILLM(api_key="my key", model_name="gpt") - chat_history = [ + message_history = [ {"role": "human", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, ] question = "What about next season?" with pytest.raises(LLMGenerationError) as exc_info: - llm.invoke(question, chat_history) + llm.invoke(question, message_history) assert "Input should be 'user' or 'assistant'" in str(exc_info.value) @@ -113,7 +113,7 @@ def test_azure_openai_llm_happy_path(mock_import: Mock) -> None: @patch("builtins.__import__") -def test_azure_openai_llm_with_chat_history_happy_path(mock_import: Mock) -> None: +def test_azure_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None: mock_openai = get_mock_openai() mock_import.return_value = mock_openai mock_openai.AzureOpenAI.return_value.chat.completions.create.return_value = ( @@ -128,19 +128,21 @@ def test_azure_openai_llm_with_chat_history_happy_path(mock_import: Mock) -> Non api_version="version", ) - chat_history = [ + message_history = [ {"role": "user", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, ] question = "What about next season?" 
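The Ollama test above asserts a slightly different message, "Input should be 'user', 'assistant' or 'system'", because that path also accepts a system role. A sketch of the same check with the wider Literal (field names are assumptions, as above):

    from typing import Literal
    from pydantic import BaseModel, ValidationError

    class BaseMessage(BaseModel):
        role: Literal["user", "assistant", "system"]
        content: str

    try:
        BaseMessage(role="human", content="When does the sun come up in the summer?")
    except ValidationError as e:
        print(e)  # ... Input should be 'user', 'assistant' or 'system' ...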
- res = llm.invoke(question, chat_history) + res = llm.invoke(question, message_history) assert isinstance(res, LLMResponse) assert res.content == "openai chat response" @patch("builtins.__import__") -def test_azure_openai_llm_with_chat_history_validation_error(mock_import: Mock) -> None: +def test_azure_openai_llm_with_message_history_validation_error( + mock_import: Mock, +) -> None: mock_openai = get_mock_openai() mock_import.return_value = mock_openai mock_openai.AzureOpenAI.return_value.chat.completions.create.return_value = ( @@ -155,13 +157,13 @@ def test_azure_openai_llm_with_chat_history_validation_error(mock_import: Mock) api_version="version", ) - chat_history = [ + message_history = [ {"content": "When does the sun come up in the summer?"}, ] question = "What about next season?" with pytest.raises(LLMGenerationError) as exc_info: - llm.invoke(question, chat_history) + llm.invoke(question, message_history) assert ( "{'type': 'missing', 'loc': ('messages', 0, 'role'), 'msg': 'Field required'," in str(exc_info.value) diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py index 17201502..04363fe7 100644 --- a/tests/unit/llm/test_vertexai_llm.py +++ b/tests/unit/llm/test_vertexai_llm.py @@ -47,7 +47,7 @@ def test_vertexai_get_messages(GenerativeModelMock: MagicMock) -> None: system_instruction = "You are a helpful assistant." model_name = "gemini-1.5-flash-001" question = "When does it set?" - chat_history = [ + message_history = [ {"role": "user", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, {"role": "user", "content": "What about next season?"}, @@ -65,7 +65,7 @@ def test_vertexai_get_messages(GenerativeModelMock: MagicMock) -> None: ] llm = VertexAILLM(model_name=model_name, system_instruction=system_instruction) - response = llm.get_messages(question, chat_history) + response = llm.get_messages(question, message_history) GenerativeModelMock.assert_called_once_with( model_name=model_name, system_instruction=[system_instruction] @@ -81,13 +81,13 @@ def test_vertexai_get_messages_validation_error(GenerativeModelMock: MagicMock) system_instruction = "You are a helpful assistant." model_name = "gemini-1.5-flash-001" question = "hi!" 
- chat_history = [ + message_history = [ {"role": "model", "content": "hello!"}, ] llm = VertexAILLM(model_name=model_name, system_instruction=system_instruction) with pytest.raises(LLMGenerationError) as exc_info: - llm.invoke(question, chat_history) + llm.invoke(question, message_history) assert "Input should be 'user' or 'assistant'" in str(exc_info.value) diff --git a/tests/unit/test_graphrag.py b/tests/unit/test_graphrag.py index b5e91e99..8d2e98dc 100644 --- a/tests/unit/test_graphrag.py +++ b/tests/unit/test_graphrag.py @@ -89,7 +89,7 @@ def test_graphrag_happy_path(retriever_mock: MagicMock, llm: MagicMock) -> None: assert res.retriever_result is None -def test_graphrag_happy_path_with_chat_history( +def test_graphrag_happy_path_with_message_history( retriever_mock: MagicMock, llm: MagicMock ) -> None: rag = GraphRAG( @@ -106,11 +106,11 @@ def test_graphrag_happy_path_with_chat_history( LLMResponse(content="llm generated summary"), LLMResponse(content="llm generated text"), ] - chat_history = [ + message_history = [ {"role": "user", "content": "initial question"}, {"role": "assistant", "content": "answer to initial question"}, ] - res = rag.search("question", chat_history) + res = rag.search("question", message_history) expected_retriever_query_text = """ Chat Summary: @@ -146,7 +146,7 @@ def test_graphrag_happy_path_with_chat_history( ) assert llm.invoke.call_count == 2 llm.invoke.assert_has_calls( - [call(first_invokation), call(second_invokation, chat_history)] + [call(first_invokation), call(second_invokation, message_history)] ) assert isinstance(res, RagResultModel) @@ -174,14 +174,14 @@ def test_graphrag_search_error(retriever_mock: MagicMock, llm: MagicMock) -> Non def test_chat_summary_template() -> None: - chat_history = [ + message_history = [ {"role": "user", "content": "initial question"}, {"role": "assistant", "content": "answer to initial question"}, {"role": "user", "content": "second question"}, {"role": "assistant", "content": "answer to second question"}, ] template = ChatSummaryTemplate() - prompt = template.format(chat_history=chat_history) + prompt = template.format(message_history=message_history) assert ( prompt == """
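Finally, an end-to-end sketch of the renamed parameter on GraphRAG.search: the history is summarized to build the retrieval query, then forwarded to the answering LLM call, as exercised by test_graphrag_happy_path_with_message_history above. Driver credentials, the index name, the embedder, and the retriever/embedder import paths follow the library's documented usage and are assumptions, not part of this diff:

    import neo4j
    from neo4j_graphrag.embeddings import OpenAIEmbeddings
    from neo4j_graphrag.generation import GraphRAG
    from neo4j_graphrag.llm.openai_llm import OpenAILLM
    from neo4j_graphrag.retrievers import VectorRetriever

    driver = neo4j.GraphDatabase.driver("neo4j://localhost:7687", auth=("neo4j", "password"))
    retriever = VectorRetriever(driver, index_name="my-index", embedder=OpenAIEmbeddings())
    llm = OpenAILLM(api_key="...", model_name="gpt-4o")  # placeholders
    rag = GraphRAG(retriever=retriever, llm=llm)

    message_history = [
        {"role": "user", "content": "initial question"},
        {"role": "assistant", "content": "answer to initial question"},
    ]
    result = rag.search("second question", message_history, retriever_config={"top_k": 5})
    print(result.answer)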