Skip to content

Commit

Permalink
fix: chat roles for model responses in chat generators (#1030)
Browse files Browse the repository at this point in the history
  • Loading branch information
Amnah199 committed Oct 2, 2024
1 parent 4ac29d1 commit d7926a2
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 17 deletions.
2 changes: 1 addition & 1 deletion integrations/amazon_bedrock/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):
"""
Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False
"""
messages = [ChatMessage.from_system("What is the biggest city in United States?")]
messages = [ChatMessage.from_user("What is the biggest city in United States?")]

# Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens)
max_length_generated_text = 3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,14 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
raise ValueError(msg)

def _message_to_part(self, message: ChatMessage) -> Part:
if message.role == ChatRole.SYSTEM and message.name:
if message.role == ChatRole.ASSISTANT and message.name:
p = Part()
p.function_call.name = message.name
p.function_call.args = {}
for k, v in message.content.items():
p.function_call.args[k] = v
return p
elif message.role == ChatRole.SYSTEM:
elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}:
p = Part()
p.text = message.content
return p
Expand All @@ -250,13 +250,13 @@ def _message_to_part(self, message: ChatMessage) -> Part:
return self._convert_part(message.content)

def _message_to_content(self, message: ChatMessage) -> Content:
if message.role == ChatRole.SYSTEM and message.name:
if message.role == ChatRole.ASSISTANT and message.name:
part = Part()
part.function_call.name = message.name
part.function_call.args = {}
for k, v in message.content.items():
part.function_call.args[k] = v
elif message.role == ChatRole.SYSTEM:
elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}:
part = Part()
part.text = message.content
elif message.role == ChatRole.FUNCTION:
Expand Down Expand Up @@ -315,12 +315,12 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess
for candidate in response_body.candidates:
for part in candidate.content.parts:
if part.text != "":
replies.append(ChatMessage.from_system(part.text))
replies.append(ChatMessage.from_assistant(part.text))
elif part.function_call is not None:
replies.append(
ChatMessage(
content=dict(part.function_call.args.items()),
role=ChatRole.SYSTEM,
role=ChatRole.ASSISTANT,
name=part.function_call.name,
)
)
Expand All @@ -343,4 +343,4 @@ def _get_stream_response(
responses.append(content)

combined_response = "".join(responses).lstrip()
return [ChatMessage.from_system(content=combined_response)]
return [ChatMessage.from_assistant(content=combined_response)]
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,9 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
def test_past_conversation():
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro")
messages = [
ChatMessage.from_system(content="You are a knowledageable mathematician."),
ChatMessage.from_user(content="What is 2+2?"),
ChatMessage.from_system(content="It's an arithmetic operation."),
ChatMessage.from_assistant(content="It's an arithmetic operation."),
ChatMessage.from_user(content="Yeah, but what's the result?"),
]
res = gemini_chat.run(messages=messages)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,24 +161,24 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
raise ValueError(msg)

def _message_to_part(self, message: ChatMessage) -> Part:
if message.role == ChatRole.SYSTEM and message.name:
if message.role == ChatRole.ASSISTANT and message.name:
p = Part.from_dict({"function_call": {"name": message.name, "args": {}}})
for k, v in message.content.items():
p.function_call.args[k] = v
return p
elif message.role == ChatRole.SYSTEM:
elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}:
return Part.from_text(message.content)
elif message.role == ChatRole.FUNCTION:
return Part.from_function_response(name=message.name, response=message.content)
elif message.role == ChatRole.USER:
return self._convert_part(message.content)

def _message_to_content(self, message: ChatMessage) -> Content:
if message.role == ChatRole.SYSTEM and message.name:
if message.role == ChatRole.ASSISTANT and message.name:
part = Part.from_dict({"function_call": {"name": message.name, "args": {}}})
for k, v in message.content.items():
part.function_call.args[k] = v
elif message.role == ChatRole.SYSTEM:
elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}:
part = Part.from_text(message.content)
elif message.role == ChatRole.FUNCTION:
part = Part.from_function_response(name=message.name, response=message.content)
Expand Down Expand Up @@ -233,12 +233,12 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
for candidate in response_body.candidates:
for part in candidate.content.parts:
if part._raw_part.text != "":
replies.append(ChatMessage.from_system(part.text))
replies.append(ChatMessage.from_assistant(part.text))
elif part.function_call is not None:
replies.append(
ChatMessage(
content=dict(part.function_call.args.items()),
role=ChatRole.SYSTEM,
role=ChatRole.ASSISTANT,
name=part.function_call.name,
)
)
Expand All @@ -261,4 +261,4 @@ def _get_stream_response(
responses.append(streaming_chunk.content)

combined_response = "".join(responses).lstrip()
return [ChatMessage.from_system(content=combined_response)]
return [ChatMessage.from_assistant(content=combined_response)]
2 changes: 1 addition & 1 deletion integrations/ollama/examples/chat_generator_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

messages = [
ChatMessage.from_user("What's Natural Language Processing?"),
ChatMessage.from_system(
ChatMessage.from_assistant(
"Natural Language Processing (NLP) is a field of computer science and artificial "
"intelligence concerned with the interaction between computers and human language"
),
Expand Down

0 comments on commit d7926a2

Please sign in to comment.