Commit
fixed sequence of multiple AIMessage (langchain-ai#335)
alx13 authored Jun 27, 2024
1 parent 2e4d8ca commit a22a608
Showing 3 changed files with 71 additions and 11 deletions.
31 changes: 20 additions & 11 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -296,6 +296,14 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
                 )
                 parts.append(Part(function_call=function_call))
 
+            prev_content = vertex_messages[-1]
+            prev_content_is_model = prev_content and prev_content.role == "model"
+            if prev_content_is_model:
+                prev_parts = list(prev_content.parts)
+                prev_parts.extend(parts)
+                vertex_messages[-1] = Content(role=role, parts=prev_parts)
+                continue
+
             vertex_messages.append(Content(role=role, parts=parts))
         elif isinstance(message, FunctionMessage):
             prev_ai_message = None
@@ -306,18 +314,18 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
                     name=message.name, response={"content": message.content}
                 )
             )
+            parts = [part]
+
             prev_content = vertex_messages[-1]
             prev_content_is_function = prev_content and prev_content.role == "function"
 
             if prev_content_is_function:
-                parts = list(prev_content.parts)
-                parts.append(part)
+                prev_parts = list(prev_content.parts)
+                prev_parts.extend(parts)
                 # replacing last message
-                vertex_messages[-1] = Content(role=role, parts=parts)
+                vertex_messages[-1] = Content(role=role, parts=prev_parts)
                 continue
 
-            parts = [part]
-
             vertex_messages.append(Content(role=role, parts=parts))
         elif isinstance(message, ToolMessage):
             role = "function"
@@ -383,18 +391,19 @@ def _parse_content(raw_content: str | Dict[Any, Any]) -> Dict[Any, Any]:
                     response=content,
                 )
             )
+            parts = [part]
+
             prev_content = vertex_messages[-1]
             prev_content_is_function = prev_content and prev_content.role == "function"
 
             if prev_content_is_function:
-                parts = list(prev_content.parts)
-                parts.append(part)
+                prev_parts = list(prev_content.parts)
+                prev_parts.extend(parts)
                 # replacing last message
-                vertex_messages[-1] = Content(role=role, parts=parts)
+                vertex_messages[-1] = Content(role=role, parts=prev_parts)
                 continue
-            else:
-                parts = [part]
-                vertex_messages.append(Content(role=role, parts=parts))
+
+            vertex_messages.append(Content(role=role, parts=parts))
         else:
             raise ValueError(
                 f"Unexpected message with type {type(message)} at the position {i}."
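The hunks above are the core of the fix: when the previous Content in vertex_messages has the same role, its parts are extended in place instead of appending a second Content with a duplicate role. Below is a minimal standalone sketch of that merging pattern; SimpleContent and merge_consecutive are hypothetical stand-ins for illustration only, not types or helpers from the library.

from dataclasses import dataclass
from typing import List


@dataclass
class SimpleContent:
    # Hypothetical stand-in for the Vertex AI Content type.
    role: str
    parts: List[str]


def merge_consecutive(contents: List[SimpleContent]) -> List[SimpleContent]:
    # Collapse adjacent entries that share a role, mirroring the diff above:
    # extend the previous entry's parts and replace the last element instead
    # of appending a second entry with the same role.
    merged: List[SimpleContent] = []
    for content in contents:
        prev = merged[-1] if merged else None
        if prev is not None and prev.role == content.role:
            merged[-1] = SimpleContent(role=prev.role, parts=prev.parts + content.parts)
            continue
        merged.append(content)
    return merged


# Two consecutive "model" entries collapse into one entry carrying both parts.
print(merge_consecutive([
    SimpleContent(role="model", parts=["Mike age is 30"]),
    SimpleContent(role="model", parts=["Arthur age is 30"]),
]))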
6 changes: 6 additions & 0 deletions libs/vertexai/langchain_google_vertexai/functions_utils.py
@@ -167,6 +167,12 @@ def _format_to_gapic_function_declaration(
     elif isinstance(tool, vertexai.FunctionDeclaration):
         return _format_vertex_to_function_declaration(tool)
     elif isinstance(tool, dict):
+        # this could come from
+        # 'langchain_core.utils.function_calling.convert_to_openai_tool'
+        if tool.get("type") == "function" and tool.get("function"):
+            return _format_dict_to_function_declaration(
+                cast(FunctionDescription, tool.get("function"))
+            )
         return _format_dict_to_function_declaration(tool)
     else:
         raise ValueError(f"Unsupported tool call type {tool}")
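As the new comment in the diff notes, tool dicts produced by langchain_core.utils.function_calling.convert_to_openai_tool wrap the actual schema under a "function" key. The sketch below shows the shape being handled; the Information schema is illustrative, not taken from the repository.

# OpenAI-style tool dict as produced by convert_to_openai_tool: the schema
# (name, description, parameters) is nested under the "function" key.
openai_style_tool = {
    "type": "function",
    "function": {
        "name": "Information",
        "description": "Extract a person's name.",
        "parameters": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        },
    },
}

# With the branch added above, _format_to_gapic_function_declaration formats
# the nested openai_style_tool["function"] dict rather than the outer wrapper.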
45 changes: 45 additions & 0 deletions libs/vertexai/tests/unit_tests/test_chat_models.py
@@ -464,6 +464,51 @@ def test_parse_history_gemini_function() -> None:
                 )
             ],
         ),
+        (
+            [
+                AIMessage(
+                    content=["Mike age is 30"],
+                    tool_calls=[
+                        ToolCall(
+                            name="Information",
+                            args={"name": "Rob"},
+                            id="00000000-0000-0000-0000-00000000000",
+                        ),
+                    ],
+                ),
+                AIMessage(
+                    content=["Arthur age is 30"],
+                    tool_calls=[
+                        ToolCall(
+                            name="Information",
+                            args={"name": "Ben"},
+                            id="00000000-0000-0000-0000-00000000000",
+                        ),
+                    ],
+                ),
+            ],
+            [
+                Content(
+                    role="model",
+                    parts=[
+                        Part(text="Mike age is 30"),
+                        Part(
+                            function_call=FunctionCall(
+                                name="Information",
+                                args={"name": "Rob"},
+                            )
+                        ),
+                        Part(text="Arthur age is 30"),
+                        Part(
+                            function_call=FunctionCall(
+                                name="Information",
+                                args={"name": "Ben"},
+                            )
+                        ),
+                    ],
+                )
+            ],
+        ),
     ],
 )
 def test_parse_history_gemini_multi(source_history, expected_history) -> None:
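The new parametrized case above covers the history shape that motivated the commit: two AIMessages in a row that must end up in a single "model" Content. A rough end-user sketch of that scenario follows; the model name and message texts are illustrative assumptions, and running it requires Vertex AI credentials.

from langchain_core.messages import AIMessage, HumanMessage
from langchain_google_vertexai import ChatVertexAI

# Illustrative setup: model name and project configuration are assumptions.
llm = ChatVertexAI(model_name="gemini-1.5-pro")

history = [
    HumanMessage(content="How old are Mike and Arthur?"),
    # Two AIMessage entries in a row: with this fix they are merged into a
    # single "model" turn before being sent to the Gemini API.
    AIMessage(content="Mike age is 30"),
    AIMessage(content="Arthur age is 30"),
    HumanMessage(content="Thanks, and what about Rob?"),
]

response = llm.invoke(history)
print(response.content)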
