From 24598b3e8c27fc53459cda63c586a34197fa8e6e Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Wed, 24 Jul 2024 07:21:01 -0400 Subject: [PATCH] allow AIMessage.content=None for tool calls --- .../chat_models.py | 4 ++- .../tests/unit_tests/test_messages.py | 36 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 libs/ai-endpoints/tests/unit_tests/test_messages.py diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py index 86a6a8ca..18c206dd 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py @@ -378,7 +378,9 @@ def _get_payload( messages.append(dict(role="user", content=msg)) elif isinstance(msg, dict): if msg.get("content", None) is None: - raise ValueError(f"Message {msg} has no content") + # content=None is valid for assistant messages (tool calling) + if not msg.get("role") == "assistant": + raise ValueError(f"Message {msg} has no content.") messages.append(msg) else: raise ValueError(f"Unknown message received: {msg} of type {type(msg)}") diff --git a/libs/ai-endpoints/tests/unit_tests/test_messages.py b/libs/ai-endpoints/tests/unit_tests/test_messages.py new file mode 100644 index 00000000..3bbbaa92 --- /dev/null +++ b/libs/ai-endpoints/tests/unit_tests/test_messages.py @@ -0,0 +1,36 @@ +import requests_mock +from langchain_core.messages import AIMessage + +from langchain_nvidia_ai_endpoints import ChatNVIDIA + + +def test_invoke_aimessage_content_none(requests_mock: requests_mock.Mocker) -> None: + requests_mock.post( + "https://integrate.api.nvidia.com/v1/chat/completions", + json={ + "id": "mock-id", + "created": 1234567890, + "object": "chat.completion", + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "WORKED"}, + } + ], + }, + ) + + empty_aimessage = AIMessage(content="EMPTY") + empty_aimessage.content = None # type: ignore + + llm = ChatNVIDIA() + response = llm.invoke([empty_aimessage]) + request = requests_mock.request_history[0] + assert request.method == "POST" + assert request.url == "https://integrate.api.nvidia.com/v1/chat/completions" + message = request.json()["messages"][0] + assert "content" in message and message["content"] != "EMPTY" + assert "content" in message and message["content"] is None + assert isinstance(response, AIMessage) + assert response.content == "WORKED"