Skip to content

Commit

Permalink
BUGFIX: handle tool message type when converting to string (#13626)
Browse files Browse the repository at this point in the history
**Description:** Currently, if we pass in a ToolMessage back to the
chain, it crashes with error

`Got unsupported message type: `

This fixes it. 

Tested locally

---------

Co-authored-by: Bagatur <[email protected]>
  • Loading branch information
tanujtiwari-at and baskaryan authored Nov 22, 2023
1 parent 143049c commit 5064890
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 64 deletions.
2 changes: 2 additions & 0 deletions libs/core/langchain_core/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def get_buffer_string(
role = "System"
elif isinstance(m, FunctionMessage):
role = "Function"
elif isinstance(m, ToolMessage):
role = "Tool"
elif isinstance(m, ChatMessage):
role = m.role
else:
Expand Down
84 changes: 84 additions & 0 deletions libs/core/tests/unit_tests/test_messages.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
import unittest

import pytest

from langchain_core.messages import (
AIMessage,
AIMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
ToolMessage,
get_buffer_string,
messages_from_dict,
messages_to_dict,
)


Expand Down Expand Up @@ -100,3 +111,76 @@ def test_ani_message_chunks() -> None:
AIMessageChunk(example=True, content="I am") + AIMessageChunk(
example=False, content=" indeed."
)


class TestGetBufferString(unittest.TestCase):
def setUp(self) -> None:
self.human_msg = HumanMessage(content="human")
self.ai_msg = AIMessage(content="ai")
self.sys_msg = SystemMessage(content="system")
self.func_msg = FunctionMessage(name="func", content="function")
self.tool_msg = ToolMessage(tool_call_id="tool_id", content="tool")
self.chat_msg = ChatMessage(role="Chat", content="chat")

def test_empty_input(self) -> None:
self.assertEqual(get_buffer_string([]), "")

def test_valid_single_message(self) -> None:
expected_output = f"Human: {self.human_msg.content}"
self.assertEqual(
get_buffer_string([self.human_msg]),
expected_output,
)

def test_custom_human_prefix(self) -> None:
prefix = "H"
expected_output = f"{prefix}: {self.human_msg.content}"
self.assertEqual(
get_buffer_string([self.human_msg], human_prefix="H"),
expected_output,
)

def test_custom_ai_prefix(self) -> None:
prefix = "A"
expected_output = f"{prefix}: {self.ai_msg.content}"
self.assertEqual(
get_buffer_string([self.ai_msg], ai_prefix="A"),
expected_output,
)

def test_multiple_msg(self) -> None:
msgs = [
self.human_msg,
self.ai_msg,
self.sys_msg,
self.func_msg,
self.tool_msg,
self.chat_msg,
]
expected_output = "\n".join(
[
"Human: human",
"AI: ai",
"System: system",
"Function: function",
"Tool: tool",
"Chat: chat",
]
)
self.assertEqual(
get_buffer_string(msgs),
expected_output,
)


def test_multiple_msg() -> None:
human_msg = HumanMessage(content="human", additional_kwargs={"key": "value"})
ai_msg = AIMessage(content="ai")
sys_msg = SystemMessage(content="sys")

msgs = [
human_msg,
ai_msg,
sys_msg,
]
assert messages_from_dict(messages_to_dict(msgs)) == msgs
64 changes: 0 additions & 64 deletions libs/langchain/tests/unit_tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Test formatting functionality."""
import unittest
from typing import Union

import pytest
Expand All @@ -16,75 +15,12 @@
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
get_buffer_string,
messages_from_dict,
messages_to_dict,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, Generation
from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue
from langchain_core.pydantic_v1 import BaseModel, ValidationError


class TestGetBufferString(unittest.TestCase):
def setUp(self) -> None:
self.human_msg = HumanMessage(content="human")
self.ai_msg = AIMessage(content="ai")
self.sys_msg = SystemMessage(content="sys")

def test_empty_input(self) -> None:
self.assertEqual(get_buffer_string([]), "")

def test_valid_single_message(self) -> None:
expected_output = f"Human: {self.human_msg.content}"
self.assertEqual(
get_buffer_string([self.human_msg]),
expected_output,
)

def test_custom_human_prefix(self) -> None:
prefix = "H"
expected_output = f"{prefix}: {self.human_msg.content}"
self.assertEqual(
get_buffer_string([self.human_msg], human_prefix="H"),
expected_output,
)

def test_custom_ai_prefix(self) -> None:
prefix = "A"
expected_output = f"{prefix}: {self.ai_msg.content}"
self.assertEqual(
get_buffer_string([self.ai_msg], ai_prefix="A"),
expected_output,
)

def test_multiple_msg(self) -> None:
msgs = [self.human_msg, self.ai_msg, self.sys_msg]
expected_output = "\n".join(
[
f"Human: {self.human_msg.content}",
f"AI: {self.ai_msg.content}",
f"System: {self.sys_msg.content}",
]
)
self.assertEqual(
get_buffer_string(msgs),
expected_output,
)


def test_multiple_msg() -> None:
human_msg = HumanMessage(content="human", additional_kwargs={"key": "value"})
ai_msg = AIMessage(content="ai")
sys_msg = SystemMessage(content="sys")

msgs = [
human_msg,
ai_msg,
sys_msg,
]
assert messages_from_dict(messages_to_dict(msgs)) == msgs


def test_serialization_of_wellknown_objects() -> None:
"""Test that pydantic is able to serialize and deserialize well known objects."""

Expand Down

0 comments on commit 5064890

Please sign in to comment.