Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
ccurme committed Jul 5, 2024
1 parent b1e90b3 commit d8ae66b
Showing 1 changed file with 29 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
HumanMessage,
SystemMessage,
Expand All @@ -28,7 +29,13 @@ def magic_function(input: int) -> int:
return input + 2


def _validate_tool_call_message(message: AIMessage) -> None:
@tool
def magic_function_no_args() -> int:
"""Calculates a magic function."""
return 5


def _validate_tool_call_message(message: BaseMessage) -> None:
assert isinstance(message, AIMessage)
assert len(message.tool_calls) == 1
tool_call = message.tool_calls[0]
Expand All @@ -37,6 +44,15 @@ def _validate_tool_call_message(message: AIMessage) -> None:
assert tool_call["id"] is not None


def _validate_tool_call_message_no_args(message: BaseMessage) -> None:
assert isinstance(message, AIMessage)
assert len(message.tool_calls) == 1
tool_call = message.tool_calls[0]
assert tool_call["name"] == "magic_function_no_args"
assert tool_call["args"] == {}
assert tool_call["id"] is not None


class ChatModelIntegrationTests(ChatModelTests):
def test_invoke(self, model: BaseChatModel) -> None:
result = model.invoke("Hello")
Expand Down Expand Up @@ -131,7 +147,6 @@ def test_tool_calling(self, model: BaseChatModel) -> None:
# Test invoke
query = "What is the value of magic_function(3)? Use the tool."
result = model_with_tools.invoke(query)
assert isinstance(result, AIMessage)
_validate_tool_call_message(result)

# Test stream
Expand All @@ -141,6 +156,18 @@ def test_tool_calling(self, model: BaseChatModel) -> None:
assert isinstance(full, AIMessage)
_validate_tool_call_message(full)

# Test tool with no arguments
model_with_tools = model.bind_tools([magic_function_no_args])
query = "What is the value of magic_function()? Use the tool."
result = model_with_tools.invoke(query)
_validate_tool_call_message_no_args(result)

full = None
for chunk in model_with_tools.stream(query):
full = chunk if full is None else full + chunk # type: ignore
assert isinstance(full, AIMessage)
_validate_tool_call_message_no_args(full)

def test_structured_output(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
Expand Down

0 comments on commit d8ae66b

Please sign in to comment.