Commit 775447f: Update tests

leila-messallem committed Dec 17, 2024
1 parent 2143973
Showing 6 changed files with 29 additions and 22 deletions.
2 changes: 1 addition & 1 deletion tests/unit/llm/test_anthropic_llm.py
@@ -105,7 +105,7 @@ def test_anthropic_invoke_with_message_history_validation_error(

with pytest.raises(LLMGenerationError) as exc_info:
llm.invoke(question, message_history)
assert "Input should be 'user' or 'assistant'" in str(exc_info.value)
assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value)


@pytest.mark.asyncio
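The updated assertion tracks Pydantic's literal-validation message now that 'system' is an accepted message role. A minimal sketch of the behavior being asserted, assuming the message history is validated by a Pydantic model along these lines (Message is illustrative, not the library's actual class):

from typing import Literal

from pydantic import BaseModel, ValidationError

class Message(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str

try:
    # An unknown role fails literal validation.
    Message(role="bot", content="hello")
except ValidationError as e:
    # Pydantic v2 reports: "Input should be 'user', 'assistant' or 'system'"
    print(e)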
2 changes: 1 addition & 1 deletion tests/unit/llm/test_cohere_llm.py
@@ -89,7 +89,7 @@ def test_cohere_llm_invoke_with_message_history_validation_error(

with pytest.raises(LLMGenerationError) as exc_info:
llm.invoke(question, message_history)
assert "Input should be 'user' or 'assistant'" in str(exc_info.value)
assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value)


@pytest.mark.asyncio
2 changes: 1 addition & 1 deletion tests/unit/llm/test_mistralai_llm.py
@@ -100,7 +100,7 @@ def test_mistralai_llm_invoke_with_message_history_validation_error(

with pytest.raises(LLMGenerationError) as exc_info:
llm.invoke(question, message_history)
assert "Input should be 'user' or 'assistant'" in str(exc_info.value)
assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value)


@pytest.mark.asyncio
11 changes: 4 additions & 7 deletions tests/unit/llm/test_openai_llm.py
@@ -82,7 +82,7 @@ def test_openai_llm_with_message_history_validation_error(mock_import: Mock) ->

with pytest.raises(LLMGenerationError) as exc_info:
llm.invoke(question, message_history)
assert "Input should be 'user' or 'assistant'" in str(exc_info.value)
assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value)


@patch("builtins.__import__", side_effect=ImportError)
@@ -158,13 +158,10 @@ def test_azure_openai_llm_with_message_history_validation_error(
)

message_history = [
{"content": "When does the sun come up in the summer?"},
{"role": "user", "content": 33},
]
question = "What about next season?"

with pytest.raises(LLMGenerationError) as exc_info:
- llm.invoke(question, message_history)
- assert (
-     "{'type': 'missing', 'loc': ('messages', 0, 'role'), 'msg': 'Field required',"
-     in str(exc_info.value)
- )
+ llm.invoke(question, message_history)  # type: ignore
+ assert "Input should be a valid string" in str(exc_info.value)
15 changes: 9 additions & 6 deletions tests/unit/llm/test_vertexai_llm.py
@@ -30,15 +30,20 @@ def test_vertexai_llm_missing_dependency() -> None:

@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
def test_vertexai_invoke_happy_path(GenerativeModelMock: MagicMock) -> None:
+ system_instruction = "You are a helpful assistant."
+ model_name = "gemini-1.5-flash-001"
+ input_text = "may thy knife chip and shatter"
mock_response = Mock()
mock_response.text = "Return text"
mock_model = GenerativeModelMock.return_value
mock_model.generate_content.return_value = mock_response
model_params = {"temperature": 0.5}
- llm = VertexAILLM("gemini-1.5-flash-001", model_params)
- input_text = "may thy knife chip and shatter"
+ llm = VertexAILLM(model_name, model_params, system_instruction)
response = llm.invoke(input_text)
assert response.content == "Return text"
+ GenerativeModelMock.assert_called_once_with(
+     model_name=model_name, system_instruction=[system_instruction]
+ )
llm.model.generate_content.assert_called_once_with([mock.ANY], **model_params)
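For reference, a minimal sketch of the real call shape this mocked test mirrors, assuming the google-cloud-aiplatform SDK and working GCP credentials ("my-project" is a placeholder, not from the commit):

import vertexai
from vertexai.generative_models import GenerativeModel

vertexai.init(project="my-project", location="us-central1")

model = GenerativeModel(
    model_name="gemini-1.5-flash-001",
    system_instruction=["You are a helpful assistant."],
)
# The test forwards model_params as **kwargs to a mock; against the real
# SDK, sampling options such as temperature belong in generation_config.
response = model.generate_content(
    ["may thy knife chip and shatter"],
    generation_config={"temperature": 0.5},
)
print(response.text)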


@@ -67,9 +72,7 @@ def test_vertexai_get_messages(GenerativeModelMock: MagicMock) -> None:
llm = VertexAILLM(model_name=model_name, system_instruction=system_instruction)
response = llm.get_messages(question, message_history)

- GenerativeModelMock.assert_called_once_with(
-     model_name=model_name, system_instruction=[system_instruction]
- )
+ GenerativeModelMock.assert_not_called()
assert len(response) == len(expected_response)
for actual, expected in zip(response, expected_response):
assert actual.role == expected.role
@@ -88,7 +91,7 @@ def test_vertexai_get_messages_validation_error(GenerativeModelMock: MagicMock)
llm = VertexAILLM(model_name=model_name, system_instruction=system_instruction)
with pytest.raises(LLMGenerationError) as exc_info:
llm.invoke(question, message_history)
assert "Input should be 'user' or 'assistant'" in str(exc_info.value)
assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value)


@pytest.mark.asyncio
19 changes: 13 additions & 6 deletions tests/unit/test_graphrag.py
@@ -113,19 +113,20 @@ def test_graphrag_happy_path_with_message_history(
res = rag.search("question", message_history)

expected_retriever_query_text = """
- Chat Summary:
+ Message Summary:
llm generated summary
Current Query:
question
"""

- first_invokation = """
- Summarize the chat history:
+ first_invokation_input = """
+ Summarize the message history:
user: initial question
assistant: answer to initial question
"""
+ first_invokation_system_instruction = "You are a summarization assistant. Summarize the given text in no more than 200 words"
second_invokation = """Answer the user question using the following context
Context:
@@ -146,7 +147,13 @@
)
assert llm.invoke.call_count == 2
llm.invoke.assert_has_calls(
- [call(first_invokation), call(second_invokation, message_history)]
+ [
+     call(
+         input=first_invokation_input,
+         system_instruction=first_invokation_system_instruction,
+     ),
+     call(second_invokation, message_history),
+ ]
)

assert isinstance(res, RagResultModel)
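The rewritten expectation spells out keyword arguments because unittest.mock matches each recorded call exactly, including whether an argument was passed positionally or by keyword. A small standalone illustration:

from unittest.mock import MagicMock, call

llm = MagicMock()
llm.invoke(input="Summarize the message history: ...", system_instruction="You are a summarization assistant. ...")
llm.invoke("Answer the user question using the following context ...", ["message history"])

# Each call() entry must mirror how the arguments were actually passed.
llm.invoke.assert_has_calls(
    [
        call(
            input="Summarize the message history: ...",
            system_instruction="You are a summarization assistant. ...",
        ),
        call("Answer the user question using the following context ...", ["message history"]),
    ]
)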
@@ -185,7 +192,7 @@ def test_chat_summary_template() -> None:
assert (
prompt
== """
- Summarize the chat history:
+ Summarize the message history:
user: initial question
assistant: answer to initial question
@@ -203,7 +210,7 @@ def test_conversation_template() -> None:
assert (
prompt
== """
- Chat Summary:
+ Message Summary:
llm generated chat summary
Current Query:
