From a22a608c0a004401fdf72f30c3e4c8a080ef020d Mon Sep 17 00:00:00 2001 From: Alex Ostapenko Date: Thu, 27 Jun 2024 17:35:15 +0200 Subject: [PATCH] fixed sequence of multiple AIMessage (#335) --- .../langchain_google_vertexai/chat_models.py | 31 ++++++++----- .../functions_utils.py | 6 +++ .../tests/unit_tests/test_chat_models.py | 45 +++++++++++++++++++ 3 files changed, 71 insertions(+), 11 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index 554acc9b..892c32f6 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -296,6 +296,14 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]: ) parts.append(Part(function_call=function_call)) + prev_content = vertex_messages[-1] + prev_content_is_model = prev_content and prev_content.role == "model" + if prev_content_is_model: + prev_parts = list(prev_content.parts) + prev_parts.extend(parts) + vertex_messages[-1] = Content(role=role, parts=prev_parts) + continue + vertex_messages.append(Content(role=role, parts=parts)) elif isinstance(message, FunctionMessage): prev_ai_message = None @@ -306,18 +314,18 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]: name=message.name, response={"content": message.content} ) ) + parts = [part] prev_content = vertex_messages[-1] prev_content_is_function = prev_content and prev_content.role == "function" + if prev_content_is_function: - parts = list(prev_content.parts) - parts.append(part) + prev_parts = list(prev_content.parts) + prev_parts.extend(parts) # replacing last message - vertex_messages[-1] = Content(role=role, parts=parts) + vertex_messages[-1] = Content(role=role, parts=prev_parts) continue - parts = [part] - vertex_messages.append(Content(role=role, parts=parts)) elif isinstance(message, ToolMessage): role = "function" @@ -383,18 +391,19 @@ def _parse_content(raw_content: str | Dict[Any, Any]) -> Dict[Any, Any]: response=content, ) ) + parts = [part] prev_content = vertex_messages[-1] prev_content_is_function = prev_content and prev_content.role == "function" + if prev_content_is_function: - parts = list(prev_content.parts) - parts.append(part) + prev_parts = list(prev_content.parts) + prev_parts.extend(parts) # replacing last message - vertex_messages[-1] = Content(role=role, parts=parts) + vertex_messages[-1] = Content(role=role, parts=prev_parts) continue - else: - parts = [part] - vertex_messages.append(Content(role=role, parts=parts)) + + vertex_messages.append(Content(role=role, parts=parts)) else: raise ValueError( f"Unexpected message with type {type(message)} at the position {i}." diff --git a/libs/vertexai/langchain_google_vertexai/functions_utils.py b/libs/vertexai/langchain_google_vertexai/functions_utils.py index 4fa54c81..771fd70f 100644 --- a/libs/vertexai/langchain_google_vertexai/functions_utils.py +++ b/libs/vertexai/langchain_google_vertexai/functions_utils.py @@ -167,6 +167,12 @@ def _format_to_gapic_function_declaration( elif isinstance(tool, vertexai.FunctionDeclaration): return _format_vertex_to_function_declaration(tool) elif isinstance(tool, dict): + # this could come from + # 'langchain_core.utils.function_calling.convert_to_openai_tool' + if tool.get("type") == "function" and tool.get("function"): + return _format_dict_to_function_declaration( + cast(FunctionDescription, tool.get("function")) + ) return _format_dict_to_function_declaration(tool) else: raise ValueError(f"Unsupported tool call type {tool}") diff --git a/libs/vertexai/tests/unit_tests/test_chat_models.py b/libs/vertexai/tests/unit_tests/test_chat_models.py index e9953877..afc01006 100644 --- a/libs/vertexai/tests/unit_tests/test_chat_models.py +++ b/libs/vertexai/tests/unit_tests/test_chat_models.py @@ -464,6 +464,51 @@ def test_parse_history_gemini_function() -> None: ) ], ), + ( + [ + AIMessage( + content=["Mike age is 30"], + tool_calls=[ + ToolCall( + name="Information", + args={"name": "Rob"}, + id="00000000-0000-0000-0000-00000000000", + ), + ], + ), + AIMessage( + content=["Arthur age is 30"], + tool_calls=[ + ToolCall( + name="Information", + args={"name": "Ben"}, + id="00000000-0000-0000-0000-00000000000", + ), + ], + ), + ], + [ + Content( + role="model", + parts=[ + Part(text="Mike age is 30"), + Part( + function_call=FunctionCall( + name="Information", + args={"name": "Rob"}, + ) + ), + Part(text="Arthur age is 30"), + Part( + function_call=FunctionCall( + name="Information", + args={"name": "Ben"}, + ) + ), + ], + ) + ], + ), ], ) def test_parse_history_gemini_multi(source_history, expected_history) -> None: