From 775447fad6a5dfe9fadd3f56d2f397fcf2c8713c Mon Sep 17 00:00:00 2001
From: Leila Messallem
Date: Tue, 17 Dec 2024 12:10:54 +0100
Subject: [PATCH] Update tests

---
 tests/unit/llm/test_anthropic_llm.py |  2 +-
 tests/unit/llm/test_cohere_llm.py    |  2 +-
 tests/unit/llm/test_mistralai_llm.py |  2 +-
 tests/unit/llm/test_openai_llm.py    | 11 ++++-------
 tests/unit/llm/test_vertexai_llm.py  | 15 +++++++++------
 tests/unit/test_graphrag.py          | 19 +++++++++++++------
 6 files changed, 29 insertions(+), 22 deletions(-)

diff --git a/tests/unit/llm/test_anthropic_llm.py b/tests/unit/llm/test_anthropic_llm.py
index 13e57498..5458926f 100644
--- a/tests/unit/llm/test_anthropic_llm.py
+++ b/tests/unit/llm/test_anthropic_llm.py
@@ -105,7 +105,7 @@ def test_anthropic_invoke_with_message_history_validation_error(
 
     with pytest.raises(LLMGenerationError) as exc_info:
         llm.invoke(question, message_history)
-    assert "Input should be 'user' or 'assistant'" in str(exc_info.value)
+    assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value)
 
 
 @pytest.mark.asyncio
diff --git a/tests/unit/llm/test_cohere_llm.py b/tests/unit/llm/test_cohere_llm.py
index c0e76bb0..1fb7839a 100644
--- a/tests/unit/llm/test_cohere_llm.py
+++ b/tests/unit/llm/test_cohere_llm.py
@@ -89,7 +89,7 @@ def test_cohere_llm_invoke_with_message_history_validation_error(
 
     with pytest.raises(LLMGenerationError) as exc_info:
         llm.invoke(question, message_history)
-    assert "Input should be 'user' or 'assistant'" in str(exc_info.value)
+    assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value)
 
 
 @pytest.mark.asyncio
diff --git a/tests/unit/llm/test_mistralai_llm.py b/tests/unit/llm/test_mistralai_llm.py
index 4e7ee6fd..a56a118c 100644
--- a/tests/unit/llm/test_mistralai_llm.py
+++ b/tests/unit/llm/test_mistralai_llm.py
@@ -100,7 +100,7 @@ def test_mistralai_llm_invoke_with_message_history_validation_error(
 
     with pytest.raises(LLMGenerationError) as exc_info:
         llm.invoke(question, message_history)
-    assert "Input should be 'user' or 'assistant'" in str(exc_info.value)
+    assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value)
 
 
 @pytest.mark.asyncio
diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py
index 356cc4e4..9f097ebf 100644
--- a/tests/unit/llm/test_openai_llm.py
+++ b/tests/unit/llm/test_openai_llm.py
@@ -82,7 +82,7 @@ def test_openai_llm_with_message_history_validation_error(mock_import: Mock) ->
 
     with pytest.raises(LLMGenerationError) as exc_info:
         llm.invoke(question, message_history)
-    assert "Input should be 'user' or 'assistant'" in str(exc_info.value)
+    assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value)
 
 
 @patch("builtins.__import__", side_effect=ImportError)
@@ -158,13 +158,10 @@ def test_azure_openai_llm_with_message_history_validation_error(
     )
 
     message_history = [
-        {"content": "When does the sun come up in the summer?"},
+        {"role": "user", "content": 33},
     ]
     question = "What about next season?"
 
     with pytest.raises(LLMGenerationError) as exc_info:
-        llm.invoke(question, message_history)
-    assert (
-        "{'type': 'missing', 'loc': ('messages', 0, 'role'), 'msg': 'Field required',"
-        in str(exc_info.value)
-    )
+        llm.invoke(question, message_history)  # type: ignore
+    assert "Input should be a valid string" in str(exc_info.value)
diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py
index 04363fe7..62370d73 100644
--- a/tests/unit/llm/test_vertexai_llm.py
+++ b/tests/unit/llm/test_vertexai_llm.py
@@ -30,15 +30,20 @@ def test_vertexai_llm_missing_dependency() -> None:
 
 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
 def test_vertexai_invoke_happy_path(GenerativeModelMock: MagicMock) -> None:
+    system_instruction = "You are a helpful assistant."
+    model_name = "gemini-1.5-flash-001"
+    input_text = "may thy knife chip and shatter"
     mock_response = Mock()
     mock_response.text = "Return text"
     mock_model = GenerativeModelMock.return_value
     mock_model.generate_content.return_value = mock_response
     model_params = {"temperature": 0.5}
-    llm = VertexAILLM("gemini-1.5-flash-001", model_params)
-    input_text = "may thy knife chip and shatter"
+    llm = VertexAILLM(model_name, model_params, system_instruction)
     response = llm.invoke(input_text)
     assert response.content == "Return text"
+    GenerativeModelMock.assert_called_once_with(
+        model_name=model_name, system_instruction=[system_instruction]
+    )
     llm.model.generate_content.assert_called_once_with([mock.ANY], **model_params)
 
 
@@ -67,9 +72,7 @@ def test_vertexai_get_messages(GenerativeModelMock: MagicMock) -> None:
     llm = VertexAILLM(model_name=model_name, system_instruction=system_instruction)
     response = llm.get_messages(question, message_history)
 
-    GenerativeModelMock.assert_called_once_with(
-        model_name=model_name, system_instruction=[system_instruction]
-    )
+    GenerativeModelMock.assert_not_called()
     assert len(response) == len(expected_response)
     for actual, expected in zip(response, expected_response):
         assert actual.role == expected.role
@@ -88,7 +91,7 @@ def test_vertexai_get_messages_validation_error(GenerativeModelMock: MagicMock)
     llm = VertexAILLM(model_name=model_name, system_instruction=system_instruction)
     with pytest.raises(LLMGenerationError) as exc_info:
         llm.invoke(question, message_history)
-    assert "Input should be 'user' or 'assistant'" in str(exc_info.value)
+    assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value)
 
 
 @pytest.mark.asyncio
diff --git a/tests/unit/test_graphrag.py b/tests/unit/test_graphrag.py
index 8d2e98dc..a9624480 100644
--- a/tests/unit/test_graphrag.py
+++ b/tests/unit/test_graphrag.py
@@ -113,19 +113,20 @@ def test_graphrag_happy_path_with_message_history(
     res = rag.search("question", message_history)
 
     expected_retriever_query_text = """
-Chat Summary:
+Message Summary:
 llm generated summary
 
 Current Query:
 question
 """
 
-    first_invokation = """
-Summarize the chat history:
+    first_invokation_input = """
+Summarize the message history:
 user: initial question
 assistant: answer to initial question
 """
+    first_invokation_system_instruction = "You are a summarization assistant. Summarize the given text in no more than 200 words"
 
     second_invokation = """Answer the user question using the following context
 
 Context:
@@ -146,7 +147,13 @@ def test_graphrag_happy_path_with_message_history(
     )
     assert llm.invoke.call_count == 2
     llm.invoke.assert_has_calls(
-        [call(first_invokation), call(second_invokation, message_history)]
+        [
+            call(
+                input=first_invokation_input,
+                system_instruction=first_invokation_system_instruction,
+            ),
+            call(second_invokation, message_history),
+        ]
     )
     assert isinstance(res, RagResultModel)
 
@@ -185,7 +192,7 @@ def test_chat_summary_template() -> None:
     assert (
         prompt
         == """
-Summarize the chat history:
+Summarize the message history:
 user: initial question
 assistant: answer to initial question
 """
@@ -203,7 +210,7 @@ def test_conversation_template() -> None:
     assert (
         prompt
         == """
-Chat Summary:
+Message Summary:
 llm generated chat summary
 
 Current Query:
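
--
Reviewer note (assumption, not part of the patch): the updated assertion strings
match pydantic v2 error messages for a message model roughly like the sketch
below. The class and field names are illustrative, not taken from the repository.

    from typing import Literal

    from pydantic import BaseModel, ValidationError


    class LLMMessage(BaseModel):
        # Allowing "system" here is what changes the expected error text from
        # "Input should be 'user' or 'assistant'" to
        # "Input should be 'user', 'assistant' or 'system'".
        role: Literal["user", "assistant", "system"]
        # Pydantic v2 does not coerce int to str, so content=33 fails with
        # "Input should be a valid string", as the Azure OpenAI test expects.
        content: str


    try:
        LLMMessage(role="user", content=33)  # type: ignore
    except ValidationError as e:
        assert "Input should be a valid string" in str(e)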