Skip to content

Commit

Permalink
Anthropic allows multiple system messages, simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Aug 20, 2024
1 parent f50b491 commit c8d93f1
Showing 1 changed file with 17 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -179,20 +179,20 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
f"Model parameters {disallowed_params} are not allowed and will be ignored. "
f"Allowed parameters are {self.ALLOWED_PARAMS}."
)

# adapt ChatMessage(s) to the format expected by the Anthropic API
anthropic_formatted_messages = self._convert_to_anthropic_format(messages)

# system message provided by the user overrides the system message from the self.generation_kwargs
system = [anthropic_formatted_messages[0]] if messages and messages[0].is_from(ChatRole.SYSTEM) else None
if system:
anthropic_formatted_messages = anthropic_formatted_messages[1:]
system_messages: List[ChatMessage] = [msg for msg in messages if msg.is_from(ChatRole.SYSTEM)]
non_system_messages: List[ChatMessage] = [msg for msg in messages if not msg.is_from(ChatRole.SYSTEM)]
system_messages_formatted: List[Dict[str, Any]] = (
self._convert_to_anthropic_format(system_messages) if system_messages else []
)
messages_formatted: List[Dict[str, Any]] = (
self._convert_to_anthropic_format(non_system_messages) if non_system_messages else []
)

response: Union[Message, Stream[MessageStreamEvent]] = self.client.messages.create(
max_tokens=filtered_generation_kwargs.pop("max_tokens", 512),
system=system if system else filtered_generation_kwargs.pop("system", ""),
system=system_messages_formatted or filtered_generation_kwargs.pop("system", ""),
model=self.model,
messages=anthropic_formatted_messages,
messages=messages_formatted,
stream=self.streaming_callback is not None,
**filtered_generation_kwargs,
)
Expand Down Expand Up @@ -261,14 +261,15 @@ def _convert_to_anthropic_format(self, messages: List[ChatMessage]) -> List[Dict
anthropic_formatted_messages = []
for m in messages:
message_dict = dataclasses.asdict(m)
filtered_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v}
formatted_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v}
if m.is_from(ChatRole.SYSTEM):
# system messages need to be in the format expected by the Anthropic API
filtered_message.pop("role")
filtered_message["type"] = "text"
filtered_message["text"] = filtered_message.pop("content")
filtered_message.update(m.meta or {})
anthropic_formatted_messages.append(filtered_message)
# remove role and content from the message dict, add type and text
formatted_message.pop("role")
formatted_message["type"] = "text"
formatted_message["text"] = formatted_message.pop("content")
formatted_message.update(m.meta or {})
anthropic_formatted_messages.append(formatted_message)
return anthropic_formatted_messages

def _connect_chunks(
Expand Down

0 comments on commit c8d93f1

Please sign in to comment.