Skip to content

Commit

Permalink
Merge pull request #16 from jzaldi/fix_no_args_function_call
Browse files Browse the repository at this point in the history
Fix no args function call
  • Loading branch information
lkuligin authored Feb 21, 2024
2 parents 302f264 + 48ebb98 commit 2689bd3
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 23 deletions.
57 changes: 34 additions & 23 deletions libs/vertexai/langchain_google_vertexai/functions_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Dict, List, Type, Union
from typing import Any, Dict, List, Type, Union

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseOutputParser
Expand All @@ -25,17 +25,7 @@ def _format_pydantic_to_vertex_function(
return {
"name": schema["title"],
"description": schema.get("description", ""),
"parameters": {
"properties": {
k: {
"type": v["type"],
"description": v.get("description"),
}
for k, v in schema["properties"].items()
},
"required": schema["required"],
"type": schema["type"],
},
"parameters": _get_parameters_from_schema(schema=schema),
}


Expand All @@ -48,17 +38,7 @@ def _format_tool_to_vertex_function(tool: BaseTool) -> FunctionDescription:
return {
"name": tool.name or schema["title"],
"description": tool.description or schema["description"],
"parameters": {
"properties": {
k: {
"type": v["type"],
"description": v.get("description"),
}
for k, v in schema["properties"].items()
},
"required": schema["required"],
"type": schema["type"],
},
"parameters": _get_parameters_from_schema(schema=schema),
}
else:
return {
Expand Down Expand Up @@ -89,6 +69,37 @@ def _format_tools_to_vertex_tool(
return [VertexTool(function_declarations=function_declarations)]


def _get_parameters_from_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
"""Given a schema, format the parameters key to match VertexAI
expected input.
Args:
schema: Dictionary that must have the following keys.
Returns:
Dictionary with the formatted parameters.
"""

parameters = {}

parameters["type"] = schema["type"]

if "required" in schema:
parameters["required"] = schema["required"]

schema_properties: Dict[str, Any] = schema.get("properties", {})

parameters["properties"] = {
parameter_name: {
"type": parameter_dict["type"],
"description": parameter_dict.get("description"),
}
for parameter_name, parameter_dict in schema_properties.items()
}

return parameters


class PydanticFunctionsOutputParser(BaseOutputParser):
"""Parse an output as a pydantic object.
Expand Down
45 changes: 45 additions & 0 deletions libs/vertexai/tests/unit_tests/test_function_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from langchain_core.tools import tool

from langchain_google_vertexai.functions_utils import _format_tool_to_vertex_function


def test_format_tool_to_vertex_function():
@tool
def get_datetime() -> str:
"""Gets the current datetime"""
import datetime

return datetime.datetime.now().strftime("%Y-%m-%d")

schema = _format_tool_to_vertex_function(get_datetime) # type: ignore

assert schema["name"] == "get_datetime"
assert schema["description"] == "get_datetime() -> str - Gets the current datetime"
assert "parameters" in schema
assert "required" not in schema["parameters"]

@tool
def sum_two_numbers(a: float, b: float) -> str:
"""Sum two numbers 'a' and 'b'.
Returns:
a + b in string format
"""
return str(a + b)

schema = _format_tool_to_vertex_function(sum_two_numbers) # type: ignore

assert schema["name"] == "sum_two_numbers"
assert "parameters" in schema
assert len(schema["parameters"]["required"]) == 2

@tool
def do_something_optional(a: float, b: float = 0) -> str:
"""Some description"""
return str(a + b)

schema = _format_tool_to_vertex_function(do_something_optional) # type: ignore

assert schema["name"] == "do_something_optional"
assert "parameters" in schema
assert len(schema["parameters"]["required"]) == 1

0 comments on commit 2689bd3

Please sign in to comment.