Skip to content

Commit

Permalink
improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Dec 18, 2024
1 parent 13ece0f commit 01fcb80
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,9 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
raise ValueError(msg)

def _message_to_part(self, message: ChatMessage) -> Part:
name = getattr(message, "name", None)
if name is None:
name = getattr(message, "_name", None)

if message.is_from(ChatRole.ASSISTANT) and name:
if message.is_from(ChatRole.ASSISTANT) and message.name:
p = Part()
p.function_call.name = name
p.function_call.name = message.name
p.function_call.args = {}
for k, v in json.loads(message.text).items():
p.function_call.args[k] = v
Expand All @@ -248,27 +244,21 @@ def _message_to_part(self, message: ChatMessage) -> Part:
return p
elif message.is_from(ChatRole.FUNCTION):
p = Part()
p.function_response.name = name
p.function_response.name = message.name
p.function_response.response = message.text
return p
elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL):
print("********* HERE *********")
part = Part()
part.function_response.name = message.tool_call_result.origin.tool_name
part.function_response.response = message.tool_call_result.result
print(part)
p = Part()
p.function_response.name = message.tool_call_result.origin.tool_name
p.function_response.response = message.tool_call_result.result
return p
elif message.is_from(ChatRole.USER):
return self._convert_part(message.text)

def _message_to_content(self, message: ChatMessage) -> Content:
# support both new and legacy ChatMessage
name = getattr(message, "name", None)
if name is None:
name = getattr(message, "_name", None)

if message.is_from(ChatRole.ASSISTANT) and name:
if message.is_from(ChatRole.ASSISTANT) and message.name:
part = Part()
part.function_call.name = name
part.function_call.name = message.name
part.function_call.args = {}
for k, v in json.loads(message.text).items():
part.function_call.args[k] = v
Expand All @@ -277,20 +267,26 @@ def _message_to_content(self, message: ChatMessage) -> Content:
part.text = message.text
elif message.is_from(ChatRole.FUNCTION):
part = Part()
part.function_response.name = name
part.function_response.name = message.name
part.function_response.response = message.text
elif message.is_from(ChatRole.USER):
part = self._convert_part(message.text)
elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL):
print("********* HERE *********")
part = Part()
part.function_response.name = message.tool_call_result.origin.tool_name
part.function_response.response = message.tool_call_result.result
print(part)
part.function_response.response = message.tool_call_result.result
elif message.is_from(ChatRole.USER):
part = self._convert_part(message.text)
else:
msg = f"Unsupported message role {message.role}"
raise ValueError(msg)
role = "user" if message.is_from(ChatRole.USER) or message.is_from(ChatRole.FUNCTION) else "model"
role = (
"user"
if message.is_from(ChatRole.USER)
or message.is_from(ChatRole.FUNCTION)
or ("TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL))
else "model"
)
return Content(parts=[part], role=role)

@component.output_types(replies=List[ChatMessage])
Expand All @@ -312,11 +308,9 @@ def run(
"""
streaming_callback = streaming_callback or self._streaming_callback
history = [self._message_to_content(m) for m in messages[:-1]]
print(history)
session = self._model.start_chat(history=history)

new_message = self._message_to_part(messages[-1])
print(new_message)
res = session.send_message(
content=new_message,
generation_config=self._generation_config,
Expand Down Expand Up @@ -395,7 +389,7 @@ def _get_stream_response(
metadata["function_call"] = part["function_call"]
content = json.dumps(dict(part["function_call"]["args"]))
new_message = ChatMessage.from_assistant(content, meta=metadata)
try:
try:
new_message.name = part["function_call"]["name"]
except AttributeError:
new_message._name = part["function_call"]["name"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,8 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
assert "function_call" in chat_message.meta
assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"}

weather = str(get_current_weather(**json.loads(response["replies"][0].text)))
weather = get_current_weather(**json.loads(chat_message.text))
messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")]
print(messages)
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
Expand Down

0 comments on commit 01fcb80

Please sign in to comment.