Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Dec 5, 2024
1 parent 2511db3 commit 88bf808
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,7 @@ def run(
system_prompts = [{"text": messages[0].content}]
messages = messages[1:]

messages_list = [
{"role": msg.role.value, "content": [{"text": msg.content}]}
for msg in messages
]
messages_list = [{"role": msg.role.value, "content": [{"text": msg.content}]} for msg in messages]

try:
# Build API parameters
Expand Down Expand Up @@ -276,23 +273,17 @@ def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[C
"prompt_tokens": response_body.get("usage", {}).get("inputTokens", 0),
"completion_tokens": response_body.get("usage", {}).get("outputTokens", 0),
"total_tokens": response_body.get("usage", {}).get("totalTokens", 0),
}
},
}

# Process each content block separately
for content_block in content_blocks:
if "text" in content_block:
replies.append(
ChatMessage.from_assistant(
content=content_block["text"],
meta=base_meta.copy()
)
)
replies.append(ChatMessage.from_assistant(content=content_block["text"], meta=base_meta.copy()))
elif "toolUse" in content_block:
replies.append(
ChatMessage.from_assistant(
content=json.dumps(content_block["toolUse"]),
meta={**base_meta.copy()}
content=json.dumps(content_block["toolUse"]), meta=base_meta.copy()
)
)
return replies
Expand All @@ -319,7 +310,7 @@ def process_streaming_response(
current_tool_use = {
"toolUseId": tool_start["toolUseId"],
"name": tool_start["name"],
"input": "" # Will accumulate deltas as string
"input": "", # Will accumulate deltas as string
}

elif "contentBlockDelta" in event:
Expand All @@ -344,16 +335,9 @@ def process_streaming_response(
pass

tool_content = json.dumps(current_tool_use)
replies.append(
ChatMessage.from_assistant(
content=tool_content,
meta=base_meta.copy()
)
)
replies.append(ChatMessage.from_assistant(content=tool_content, meta=base_meta.copy()))
elif current_content:
replies.append(
ChatMessage.from_assistant(content=current_content, meta=base_meta.copy())
)
replies.append(ChatMessage.from_assistant(content=current_content, meta=base_meta.copy()))

elif "messageStop" in event:
# not 100% correct for multiple messages but no way around it
Expand Down
57 changes: 37 additions & 20 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,16 @@
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator"
MODELS_TO_TEST = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0", "mistral.mistral-large-2402-v1:0"]
MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0", "mistral.mistral-large-2402-v1:0"]
MODELS_TO_TEST = [
"anthropic.claude-3-5-sonnet-20240620-v1:0",
"cohere.command-r-plus-v1:0",
"mistral.mistral-large-2402-v1:0",
]
MODELS_TO_TEST_WITH_TOOLS = [
"anthropic.claude-3-5-sonnet-20240620-v1:0",
"cohere.command-r-plus-v1:0",
"mistral.mistral-large-2402-v1:0",
]

# so far we've discovered these models support streaming and tool use
STREAMING_TOOL_MODELS = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0"]
Expand Down Expand Up @@ -119,7 +127,9 @@ def test_constructor_with_generation_kwargs(mock_boto3_session):
"""
generation_kwargs = {"temperature": 0.7}

layer = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0", generation_kwargs=generation_kwargs)
layer = AmazonBedrockChatGenerator(
model="anthropic.claude-3-5-sonnet-20240620-v1:0", generation_kwargs=generation_kwargs
)
assert layer.generation_kwargs == generation_kwargs


Expand Down Expand Up @@ -212,19 +222,19 @@ def test_tools_use(self, model_name):
"properties": {
"sign": {
"type": "string",
"description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ and WKRP."
"description": "The call sign for the radio station "
"for which you want the most popular song. "
"Example calls signs are WZPZ and WKRP.",
}
},
"required": [
"sign"
]
"required": ["sign"],
}
}
},
}
}
],
# See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
"toolChoice": {"auto": {}}
"toolChoice": {"auto": {}},
}

messages = []
Expand All @@ -241,7 +251,6 @@ def test_tools_use(self, model_name):
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert first_reply.meta, "First reply has no metadata"


# Some models return thinking message as first and the second one as the tool call
if len(replies) > 1:
second_reply = replies[1]
Expand All @@ -252,15 +261,19 @@ def test_tools_use(self, model_name):
assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 'input' value"
assert (
tool_call["input"]["sign"] == "WZPZ"
), f"Tool call {tool_call} does not contain the correct 'input' value"
else:
# case where the model returns the tool call as the first message
# double check that the tool call is correct
tool_call = json.loads(first_reply.content)
assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 'input' value"
assert (
tool_call["input"]["sign"] == "WZPZ"
), f"Tool call {tool_call} does not contain the correct 'input' value"

@pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS)
@pytest.mark.integration
Expand All @@ -281,19 +294,19 @@ def test_tools_use_with_streaming(self, model_name):
"properties": {
"sign": {
"type": "string",
"description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ and WKRP."
"description": "The call sign for the radio station "
"for which you want the most popular song. Example "
"calls signs are WZPZ and WKRP.",
}
},
"required": [
"sign"
]
"required": ["sign"],
}
}
},
}
}
],
# See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
"toolChoice": {"auto": {}}
"toolChoice": {"auto": {}},
}

messages = []
Expand All @@ -320,12 +333,16 @@ def test_tools_use_with_streaming(self, model_name):
assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 'input' value"
assert (
tool_call["input"]["sign"] == "WZPZ"
), f"Tool call {tool_call} does not contain the correct 'input' value"
else:
# case where the model returns the tool call as the first message
# double check that the tool call is correct
tool_call = json.loads(first_reply.content)
assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 'input' value"
assert (
tool_call["input"]["sign"] == "WZPZ"
), f"Tool call {tool_call} does not contain the correct 'input' value"

0 comments on commit 88bf808

Please sign in to comment.