From 3085cf9ce3ac9950df4989b4100822cd450267e4 Mon Sep 17 00:00:00 2001 From: Corentin Meyer Date: Tue, 20 Aug 2024 16:09:12 +0200 Subject: [PATCH] fix(Bedrock): allow tools kwargs for AWS Bedrock Claude model (#976) --- .../amazon_bedrock/chat/adapters.py | 18 ++++++-- .../tests/test_chat_generator.py | 44 +++++++++++++++++++ 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index 56eefdf09..f5e8f8181 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -166,6 +166,8 @@ class AnthropicClaudeChatAdapter(BedrockModelChatAdapter): "top_p", "top_k", "system", + "tools", + "tool_choice", ] def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]): @@ -253,10 +255,18 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List """ messages: List[ChatMessage] = [] if response_body.get("type") == "message": - for content in response_body["content"]: - if content.get("type") == "text": - meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]} - messages.append(ChatMessage.from_assistant(content["text"], meta=meta)) + if response_body.get("stop_reason") == "tool_use": # If `tool_use` we only keep the tool_use content + for content in response_body["content"]: + if content.get("type") == "tool_use": + meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]} + json_answer = json.dumps(content) + messages.append(ChatMessage.from_assistant(json_answer, meta=meta)) + else: # For other stop_reason, return all text content + for content in response_body["content"]: + if content.get("type") == "text": + meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]} + messages.append(ChatMessage.from_assistant(content["text"], meta=meta)) + return messages def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 98f20fc2d..ed0c27401 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -1,3 +1,4 @@ +import json import logging import os from typing import Optional, Type @@ -17,6 +18,7 @@ KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"] +MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-haiku-20240307-v1:0"] MISTRAL_MODELS = [ "mistral.mistral-7b-instruct-v0:2", "mistral.mixtral-8x7b-instruct-v0:1", @@ -303,6 +305,48 @@ def test_prepare_body_with_custom_inference_params(self) -> None: assert body == expected_body + @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS) + @pytest.mark.integration + def test_tools_use(self, model_name): + """ + Test function calling with AWS Bedrock Anthropic adapter + """ + # See https://docs.anthropic.com/en/docs/tool-use for more information + tools = [ + { + "name": "top_song", + "description": "Get the most popular song played on a radio station.", + "input_schema": { + "type": "object", + "properties": { + "sign": { + "type": "string", + "description": "The call sign for the radio station for which you want the most popular" + " song. Example calls signs are WZPZ and WKRP.", + } + }, + "required": ["sign"], + }, + } + ] + messages = [] + messages.append(ChatMessage.from_user("What is the most popular song on WZPZ?")) + client = AmazonBedrockChatGenerator(model=model_name) + response = client.run(messages=messages, generation_kwargs={"tools": tools, "tool_choice": {"type": "any"}}) + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "top_song" in first_reply.content.lower(), "First reply does not contain top_song" + assert first_reply.meta, "First reply has no metadata" + fc_response = json.loads(first_reply.content) + assert "name" in fc_response, "First reply does not contain name of the tool" + assert "input" in fc_response, "First reply does not contain input of the tool" + class TestMistralAdapter: def test_prepare_body_with_default_params(self) -> None: