Skip to content

Commit

Permalink
Support for tool streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Nov 26, 2024
1 parent 7137144 commit 8c4362a
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -286,32 +286,76 @@ def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[C
def process_streaming_response(
    self, response_stream, streaming_callback: Callable[[StreamingChunk], None]
) -> List[ChatMessage]:
    """
    Process a Bedrock ConverseStream event stream into assistant ChatMessages.

    Invokes ``streaming_callback`` for every text delta as it arrives and
    accumulates text and tool-use content blocks; each completed content block
    becomes one assistant ``ChatMessage`` in the returned list.

    :param response_stream: Iterable of streaming events from the Bedrock
        ``converse_stream`` API (``contentBlockStart``, ``contentBlockDelta``,
        ``contentBlockStop``, ``messageStop``, ``metadata``).
    :param streaming_callback: Called with a ``StreamingChunk`` for each text
        delta received. Not called for tool-use deltas.
    :returns: List of assistant ``ChatMessage`` replies, one per content block.
    """
    replies: List[ChatMessage] = []
    current_content = ""
    # Accumulator for the tool-use block currently being streamed, or None.
    current_tool_use = None
    base_meta = {
        "model": self.model,
        "index": 0,
    }

    for event in response_stream:
        if "contentBlockStart" in event:
            # A new content block begins: reset the per-block accumulators.
            current_content = ""
            current_tool_use = None
            block_start = event["contentBlockStart"]
            if "start" in block_start and "toolUse" in block_start["start"]:
                tool_start = block_start["start"]["toolUse"]
                current_tool_use = {
                    "toolUseId": tool_start["toolUseId"],
                    "name": tool_start["name"],
                    "input": "",  # input arrives as string deltas; parsed at block stop
                }

        elif "contentBlockDelta" in event:
            delta = event["contentBlockDelta"]["delta"]
            if "text" in delta:
                delta_text = delta["text"]
                current_content += delta_text
                # it only makes sense to call callback on text deltas
                streaming_callback(StreamingChunk(content=delta_text, meta=None))
            elif "toolUse" in delta and current_tool_use:
                # Accumulate tool use input deltas (the JSON arrives in fragments).
                current_tool_use["input"] += delta["toolUse"].get("input", "")

        elif "contentBlockStop" in event:
            if current_tool_use:
                # Parse accumulated input if it's a JSON string
                try:
                    current_tool_use["input"] = json.loads(current_tool_use["input"])
                except json.JSONDecodeError:
                    # Keep as string if not valid JSON
                    pass
                replies.append(
                    ChatMessage.from_assistant(content=json.dumps(current_tool_use), meta=base_meta.copy())
                )
            elif current_content:
                replies.append(ChatMessage.from_assistant(content=current_content, meta=base_meta.copy()))
            # Block finished: reset accumulators so a following block that is not
            # preceded by a contentBlockStart event cannot duplicate this content.
            current_content = ""
            current_tool_use = None

        elif "messageStop" in event:
            # not 100% correct for multiple messages but no way around it
            for reply in replies:
                reply.meta["finish_reason"] = event["messageStop"].get("stopReason")

        elif "metadata" in event:
            metadata = event["metadata"]
            if "usage" in metadata:
                usage = metadata["usage"]
                # use OpenAI's format for usage for cross ChatGenerator compatibility
                usage_meta = {
                    "prompt_tokens": usage.get("inputTokens", 0),
                    "completion_tokens": usage.get("outputTokens", 0),
                    "total_tokens": usage.get("totalTokens", 0),
                }
                # not 100% correct for multiple messages but no way around it
                for reply in replies:
                    reply.meta["usage"] = usage_meta.copy()

    return replies
73 changes: 72 additions & 1 deletion integrations/amazon_bedrock/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@

KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator"
MODELS_TO_TEST = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0", "mistral.mistral-large-2402-v1:0"]
MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0", "mistral.mistral-large-2402-v1:0"]

# so far we've discovered these models support streaming and tool use
STREAMING_TOOL_MODELS = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0"]


@pytest.fixture
Expand Down Expand Up @@ -226,6 +229,74 @@ def test_tools_use(self, model_name):
assert first_reply.meta, "First reply has no metadata"


# Some models return thinking message as first and the second one as the tool call
if len(replies) > 1:
second_reply = replies[1]
assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance"
assert second_reply.content, "Second reply has no content"
assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant"
tool_call = json.loads(second_reply.content)
assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 'input' value"
else:
# case where the model returns the tool call as the first message
# double check that the tool call is correct
tool_call = json.loads(first_reply.content)
assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 'input' value"

@pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS)
@pytest.mark.integration
def test_tools_use_with_streaming(self, model_name):
"""
Test function calling with AWS Bedrock Anthropic adapter
"""
# See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html
tool_config = {
"tools": [
{
"toolSpec": {
"name": "top_song",
"description": "Get the most popular song played on a radio station.",
"inputSchema": {
"json": {
"type": "object",
"properties": {
"sign": {
"type": "string",
"description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ and WKRP."
}
},
"required": [
"sign"
]
}
}
}
}
],
# See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
"toolChoice": {"auto": {}}
}

messages = []
messages.append(ChatMessage.from_user("What is the most popular song on WZPZ?"))
client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=print_streaming_chunk)
response = client.run(messages=messages, generation_kwargs={"toolConfig": tool_config})
replies = response["replies"]
assert isinstance(replies, list), "Replies is not a list"
assert len(replies) > 0, "No replies received"

first_reply = replies[0]
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert first_reply.meta, "First reply has no metadata"

# Some models return thinking message as first and the second one as the tool call
if len(replies) > 1:
second_reply = replies[1]
Expand Down

0 comments on commit 8c4362a

Please sign in to comment.