Migrate Claude to messaging API
vblagoje committed Mar 6, 2024
1 parent 5a339d4 commit 9272da8
Showing 2 changed files with 53 additions and 60 deletions.
@@ -44,7 +44,7 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
:param response_body: The response body.
:returns: The extracted responses.
"""
- return self._extract_messages_from_response(self.response_body_message_key(), response_body)
+ return self._extract_messages_from_response(response_body)

def get_stream_responses(self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]:
tokens: List[str] = []
@@ -53,11 +53,8 @@ def get_stream_responses(self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]:
if chunk:
decoded_chunk = json.loads(chunk["bytes"].decode("utf-8"))
token = self._extract_token_from_stream(decoded_chunk)
- # take all the rest key/value pairs from the chunk, add them to the metadata
- stream_metadata = {k: v for (k, v) in decoded_chunk.items() if v != token}
- stream_chunk = StreamingChunk(content=token, meta=stream_metadata)
- # callback the stream handler with StreamingChunk
- stream_handler(stream_chunk)
+ stream_chunk = StreamingChunk(content=token)  # don't extract meta, we care about tokens only
+ stream_handler(stream_chunk)  # callback the stream handler with StreamingChunk
tokens.append(token)
responses = ["".join(tokens).lstrip()]
return responses
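For reference, a minimal sketch of what one decoded stream event looks like under the messages API. The event wrapper and payload shapes are assumed from the loop above; the values are illustrative, not taken from the commit:

```python
import json

# Hypothetical Bedrock stream event, shaped like the chunks the loop above consumes.
event = {
    "chunk": {
        "bytes": json.dumps(
            {"type": "content_block_delta", "delta": {"type": "text_delta", "text": "Hello"}}
        ).encode("utf-8")
    }
}

# Mirrors the decoding step in get_stream_responses.
decoded_chunk = json.loads(event["chunk"]["bytes"].decode("utf-8"))
print(decoded_chunk["delta"]["text"])  # -> "Hello"
```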
@@ -124,25 +121,14 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
:returns: A dictionary containing the resized prompt and additional information.
"""

- def _extract_messages_from_response(self, message_tag: str, response_body: Dict[str, Any]) -> List[ChatMessage]:
+ @abstractmethod
+ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
"""
Extracts the messages from the response body.
- :param message_tag: The key for the message in the response body.
:param response_body: The response body.
:returns: The extracted ChatMessage list.
"""
- metadata = {k: v for (k, v) in response_body.items() if k != message_tag}
- return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)]

- @abstractmethod
- def response_body_message_key(self) -> str:
-     """
-     Returns the key for the message in the response body.
-     Subclasses should override this method to return the correct message key - where the response is located.
-     :returns: The key for the message in the response body.
-     """

@abstractmethod
def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
@@ -183,7 +169,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
self.prompt_handler = DefaultPromptHandler(
tokenizer="gpt2",
model_max_length=model_max_length,
max_length=self.generation_kwargs.get("max_tokens_to_sample") or 512,
max_length=self.generation_kwargs.get("max_tokens") or 512,
)

def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
@@ -195,46 +181,33 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
:returns: The prepared body.
"""
default_params = {
"max_tokens_to_sample": self.generation_kwargs.get("max_tokens_to_sample") or 512,
"stop_sequences": ["\n\nHuman:"],
"anthropic_version": self.generation_kwargs.get("anthropic_version") or "bedrock-2023-05-31",
"max_tokens": self.generation_kwargs.get("max_tokens") or 512,
}

# combine stop words with default stop sequences, remove stop_words as Anthropic does not support it
stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", [])
if stop_sequences:
inference_kwargs["stop_sequences"] = stop_sequences
params = self._get_params(inference_kwargs, default_params)
body = {"prompt": self.prepare_chat_messages(messages=messages), **params}
body = {**self.prepare_chat_messages(messages=messages), **params}
return body
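As a quick standalone sketch of the stop-word merging above (the dict values are illustrative):

```python
# stop_words is folded into stop_sequences (which Anthropic accepts) and
# then removed, mirroring the prepare_body logic above.
inference_kwargs = {"stop_sequences": ["###"], "stop_words": ["Observation:"]}
stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", [])
if stop_sequences:
    inference_kwargs["stop_sequences"] = stop_sequences
print(inference_kwargs)  # {'stop_sequences': ['###', 'Observation:']}
```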

- def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
+ def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]:
"""
Prepares the chat messages for the Anthropic Claude request.
:param messages: The chat messages to prepare.
:returns: The prepared chat messages as a string.
"""
- conversation = []
- for index, message in enumerate(messages):
-     if message.is_from(ChatRole.USER):
-         conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_USER_TOKEN} {message.content.strip()}")
-     elif message.is_from(ChatRole.ASSISTANT):
-         conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN} {message.content.strip()}")
-     elif message.is_from(ChatRole.FUNCTION):
-         error_message = "Anthropic does not support function calls."
-         raise ValueError(error_message)
-     elif message.is_from(ChatRole.SYSTEM) and index == 0:
-         # Until we transition to the new chat message format system messages will be ignored
-         # see https://docs.anthropic.com/claude/reference/messages_post for more details
-         logger.warning(
-             "System messages are not fully supported by the current version of Claude and will be ignored."
-         )
-     else:
-         invalid_role = f"Invalid role {message.role} for message {message.content}"
-         raise ValueError(invalid_role)
-
- prepared_prompt = "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " "
- return self._ensure_token_limit(prepared_prompt)
+ body: Dict[str, Any] = {}
+ system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None
+ body["messages"] = [
+     self._to_anthropic_message(m) for m in messages if m.is_from(ChatRole.USER) or m.is_from(ChatRole.ASSISTANT)
+ ]
+ if system:
+     body["system"] = system
+ return body
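A sketch of what prepare_chat_messages now returns, assuming Haystack's ChatMessage API as used elsewhere in this diff; the message texts are illustrative:

```python
from haystack.dataclasses import ChatMessage

messages = [
    ChatMessage.from_system("You are a helpful assistant"),
    ChatMessage.from_user("What's NLP?"),
]
# With the code above, the prepared body would be, roughly:
# {
#     "messages": [{"role": "user", "content": [{"type": "text", "text": "What's NLP?"}]}],
#     "system": "You are a helpful assistant",
# }
```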

def check_prompt(self, prompt: str) -> Dict[str, Any]:
"""
@@ -245,13 +218,19 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
"""
return self.prompt_handler(prompt)

- def response_body_message_key(self) -> str:
+ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
"""
- Returns the key for the message in the response body for Anthropic Claude i.e. "completion".
+ Extracts the messages from the response body.
- :returns: The key for the message in the response body.
+ :param response_body: The response body.
+ :return: The extracted ChatMessage list.
"""
- return "completion"
+ messages: List[ChatMessage] = []
+ if response_body.get("type") == "message":
+     for content in response_body["content"]:
+         if content.get("type") == "text":
+             messages.append(ChatMessage.from_assistant(content["text"]))
+ return messages
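For reference, the messages API response body this parses has roughly the following shape; the shape is assumed from the checks above and the field values are illustrative:

```python
# Only "type" and "content" are inspected by the extraction above;
# other fields (role, stop_reason, ...) are ignored.
response_body = {
    "type": "message",
    "role": "assistant",
    "content": [{"type": "text", "text": "NLP is a field of AI ..."}],
    "stop_reason": "end_turn",
}
# _extract_messages_from_response(response_body) would return one assistant
# ChatMessage whose content is "NLP is a field of AI ...".
```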

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
"""
@@ -260,7 +239,17 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
:param chunk: The streaming chunk.
:returns: The extracted token.
"""
return chunk.get("completion", "")
if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta":
return chunk.get("delta", {}).get("text", "")
return ""

+ def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]:
+     """
+     Convert a ChatMessage to a dictionary with the content and role fields.
+     :param m: The ChatMessage to convert.
+     :return: The dictionary with the content and role fields.
+     """
+     return {"content": [{"type": "text", "text": m.content}], "role": m.role.value}


class MetaLlama2ChatAdapter(BedrockModelChatAdapter):
@@ -357,13 +346,16 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
"""
return self.prompt_handler(prompt)

- def response_body_message_key(self) -> str:
+ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
"""
- Returns the key for the message in the response body for Meta Llama 2 i.e. "generation".
+ Extracts the messages from the response body.
- :returns: The key for the message in the response body.
+ :param response_body: The response body.
+ :return: The extracted ChatMessage list.
"""
- return "generation"
+ message_tag = "generation"
+ metadata = {k: v for (k, v) in response_body.items() if k != message_tag}
+ return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)]
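For contrast with the Claude adapter, Llama 2 on Bedrock still returns a flat body keyed by "generation"; the other fields below are illustrative, and everything except the message text lands in the ChatMessage metadata:

```python
response_body = {
    "generation": "NLP is a field of AI ...",
    "prompt_token_count": 12,      # illustrative field, kept as metadata
    "generation_token_count": 40,  # illustrative field, kept as metadata
    "stop_reason": "stop",         # illustrative field, kept as metadata
}
# -> [ChatMessage.from_assistant("NLP is a field of AI ...", meta={...})]
```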

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
"""
@@ -25,20 +25,21 @@ class AmazonBedrockChatGenerator:
"""
`AmazonBedrockChatGenerator` enables text generation via Amazon Bedrock hosted chat LLMs.
- For example, to use the Anthropic Claude model, simply initialize the `AmazonBedrockChatGenerator` with the
- 'anthropic.claude-v2' model name.
+ For example, to use the Anthropic Claude 3 Sonnet model, simply initialize the `AmazonBedrockChatGenerator` with the
+ 'anthropic.claude-3-sonnet-20240229-v1:0' model name.
```python
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.components.generators.utils import print_streaming_chunk
messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant, answer in German only"),
ChatMessage.from_user("What's Natural Language Processing?")]
client = AmazonBedrockChatGenerator(model="anthropic.claude-v2", streaming_callback=print_streaming_chunk)
client.run(messages, generation_kwargs={"max_tokens_to_sample": 512})
client = AmazonBedrockChatGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0",
streaming_callback=print_streaming_chunk)
client.run(messages, generation_kwargs={"max_tokens": 512})
```
@@ -154,7 +155,7 @@ def invoke(self, *args, **kwargs):
msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt."
raise ValueError(msg)

- body = self.model_adapter.prepare_body(messages=messages, stop_words=self.stop_words, **kwargs)
+ body = self.model_adapter.prepare_body(messages=messages, **{"stop_words": self.stop_words, **kwargs})
try:
if self.streaming_callback:
response = self.client.invoke_model_with_response_stream(