From 2a723c280df02a5c9c8ecc32464ced162a2deb2c Mon Sep 17 00:00:00 2001 From: Piotr Rudnik Date: Sun, 22 Dec 2024 15:06:11 +0100 Subject: [PATCH] genai: Fix handling of optional arrays in tool input --- libs/genai/langchain_google_genai/_function_utils.py | 8 ++++++++ libs/genai/tests/unit_tests/test_function_utils.py | 9 +++++++++ 2 files changed, 17 insertions(+) diff --git a/libs/genai/langchain_google_genai/_function_utils.py b/libs/genai/langchain_google_genai/_function_utils.py index d7227268..3d8eb4f7 100644 --- a/libs/genai/langchain_google_genai/_function_utils.py +++ b/libs/genai/langchain_google_genai/_function_utils.py @@ -314,6 +314,14 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]: if properties_item.get("type_") == glm.Type.ARRAY and v.get("items"): properties_item["items"] = _get_items_from_schema_any(v.get("items")) + elif properties_item.get("type_") == glm.Type.ARRAY and v.get("anyOf"): + types_with_items = [t for t in v.get("anyOf") if t.get("items")] + if len(types_with_items) > 1: + logger.warning( + "Only first value for 'anyOf' key is supported in array types." + f"Got {len(types_with_items)} types, using first one: {types_with_items[0]}" + ) + properties_item["items"] = _get_items_from_schema_any(types_with_items[0]['items']) if properties_item.get("type_") == glm.Type.OBJECT and v.get("properties"): properties_item["properties"] = _get_properties_from_schema_any( diff --git a/libs/genai/tests/unit_tests/test_function_utils.py b/libs/genai/tests/unit_tests/test_function_utils.py index 536c2a99..d87efbc2 100644 --- a/libs/genai/tests/unit_tests/test_function_utils.py +++ b/libs/genai/tests/unit_tests/test_function_utils.py @@ -309,3 +309,12 @@ class MyModel(BaseModel): gapic_tool = convert_to_genai_function_declarations([MyModel]) tool_dict = tool_to_dict(gapic_tool) assert gapic_tool == convert_to_genai_function_declarations([tool_dict]) + + +def test_tool_input_can_have_optional_arrays() -> None: + class ExampleToolInput(BaseModel): + numbers: Optional[List[str]] = Field() + + gapic_tool = convert_to_genai_function_declarations([ExampleToolInput]) + assert gapic_tool.function_declarations[0].parameters.properties.get('numbers').items.type_ == 1 +