diff --git a/libs/core/langchain_core/messages/__init__.py b/libs/core/langchain_core/messages/__init__.py index cb20a40eeb434..7bcd9bfed5fa5 100644 --- a/libs/core/langchain_core/messages/__init__.py +++ b/libs/core/langchain_core/messages/__init__.py @@ -54,6 +54,8 @@ def get_buffer_string( role = "System" elif isinstance(m, FunctionMessage): role = "Function" + elif isinstance(m, ToolMessage): + role = "Tool" elif isinstance(m, ChatMessage): role = m.role else: diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index 9e8918dbd8953..ef1cfa8fb1459 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -1,10 +1,21 @@ +import unittest + import pytest from langchain_core.messages import ( + AIMessage, AIMessageChunk, + ChatMessage, ChatMessageChunk, + FunctionMessage, FunctionMessageChunk, + HumanMessage, HumanMessageChunk, + SystemMessage, + ToolMessage, + get_buffer_string, + messages_from_dict, + messages_to_dict, ) @@ -100,3 +111,76 @@ def test_ani_message_chunks() -> None: AIMessageChunk(example=True, content="I am") + AIMessageChunk( example=False, content=" indeed." ) + + +class TestGetBufferString(unittest.TestCase): + def setUp(self) -> None: + self.human_msg = HumanMessage(content="human") + self.ai_msg = AIMessage(content="ai") + self.sys_msg = SystemMessage(content="system") + self.func_msg = FunctionMessage(name="func", content="function") + self.tool_msg = ToolMessage(tool_call_id="tool_id", content="tool") + self.chat_msg = ChatMessage(role="Chat", content="chat") + + def test_empty_input(self) -> None: + self.assertEqual(get_buffer_string([]), "") + + def test_valid_single_message(self) -> None: + expected_output = f"Human: {self.human_msg.content}" + self.assertEqual( + get_buffer_string([self.human_msg]), + expected_output, + ) + + def test_custom_human_prefix(self) -> None: + prefix = "H" + expected_output = f"{prefix}: {self.human_msg.content}" + self.assertEqual( + get_buffer_string([self.human_msg], human_prefix="H"), + expected_output, + ) + + def test_custom_ai_prefix(self) -> None: + prefix = "A" + expected_output = f"{prefix}: {self.ai_msg.content}" + self.assertEqual( + get_buffer_string([self.ai_msg], ai_prefix="A"), + expected_output, + ) + + def test_multiple_msg(self) -> None: + msgs = [ + self.human_msg, + self.ai_msg, + self.sys_msg, + self.func_msg, + self.tool_msg, + self.chat_msg, + ] + expected_output = "\n".join( + [ + "Human: human", + "AI: ai", + "System: system", + "Function: function", + "Tool: tool", + "Chat: chat", + ] + ) + self.assertEqual( + get_buffer_string(msgs), + expected_output, + ) + + +def test_multiple_msg() -> None: + human_msg = HumanMessage(content="human", additional_kwargs={"key": "value"}) + ai_msg = AIMessage(content="ai") + sys_msg = SystemMessage(content="sys") + + msgs = [ + human_msg, + ai_msg, + sys_msg, + ] + assert messages_from_dict(messages_to_dict(msgs)) == msgs diff --git a/libs/langchain/tests/unit_tests/test_schema.py b/libs/langchain/tests/unit_tests/test_schema.py index c749b1ca1c883..0f7efb07dfef0 100644 --- a/libs/langchain/tests/unit_tests/test_schema.py +++ b/libs/langchain/tests/unit_tests/test_schema.py @@ -1,5 +1,4 @@ """Test formatting functionality.""" -import unittest from typing import Union import pytest @@ -16,75 +15,12 @@ HumanMessageChunk, SystemMessage, SystemMessageChunk, - get_buffer_string, - messages_from_dict, - messages_to_dict, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, Generation from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue from langchain_core.pydantic_v1 import BaseModel, ValidationError -class TestGetBufferString(unittest.TestCase): - def setUp(self) -> None: - self.human_msg = HumanMessage(content="human") - self.ai_msg = AIMessage(content="ai") - self.sys_msg = SystemMessage(content="sys") - - def test_empty_input(self) -> None: - self.assertEqual(get_buffer_string([]), "") - - def test_valid_single_message(self) -> None: - expected_output = f"Human: {self.human_msg.content}" - self.assertEqual( - get_buffer_string([self.human_msg]), - expected_output, - ) - - def test_custom_human_prefix(self) -> None: - prefix = "H" - expected_output = f"{prefix}: {self.human_msg.content}" - self.assertEqual( - get_buffer_string([self.human_msg], human_prefix="H"), - expected_output, - ) - - def test_custom_ai_prefix(self) -> None: - prefix = "A" - expected_output = f"{prefix}: {self.ai_msg.content}" - self.assertEqual( - get_buffer_string([self.ai_msg], ai_prefix="A"), - expected_output, - ) - - def test_multiple_msg(self) -> None: - msgs = [self.human_msg, self.ai_msg, self.sys_msg] - expected_output = "\n".join( - [ - f"Human: {self.human_msg.content}", - f"AI: {self.ai_msg.content}", - f"System: {self.sys_msg.content}", - ] - ) - self.assertEqual( - get_buffer_string(msgs), - expected_output, - ) - - -def test_multiple_msg() -> None: - human_msg = HumanMessage(content="human", additional_kwargs={"key": "value"}) - ai_msg = AIMessage(content="ai") - sys_msg = SystemMessage(content="sys") - - msgs = [ - human_msg, - ai_msg, - sys_msg, - ] - assert messages_from_dict(messages_to_dict(msgs)) == msgs - - def test_serialization_of_wellknown_objects() -> None: """Test that pydantic is able to serialize and deserialize well known objects."""