diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py
index 394bf5a2..84f6bb38 100644
--- a/libs/vertexai/langchain_google_vertexai/chat_models.py
+++ b/libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -58,7 +58,7 @@
 from langchain_core.output_parsers.openai_tools import parse_tool_calls
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, root_validator, Field
-from langchain_core.runnables import Runnable, RunnablePassthrough
+from langchain_core.runnables import Runnable, RunnablePassthrough, RunnableGenerator
 from vertexai.generative_models import (  # type: ignore
     Tool as VertexTool,
 )
@@ -1497,6 +1497,15 @@ def with_structured_output(
     ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
         """Model wrapper that returns outputs formatted to match the given schema.
 
+        .. versionchanged:: 1.1.0
+
+            Return type corrected in version 1.1.0. Previously if a dict schema
+            was provided then the output had the form
+            ``[{"args": {}, "name": "schema_name"}]`` where the output was a list with
+            a single dict and the "args" of the one dict corresponded to the schema.
+            As of `1.1.0` this has been fixed so that the schema (the value
+            corresponding to the old "args" key) is returned directly.
+
         Args:
             schema: The output schema as a dict or a Pydantic class. If a Pydantic class
                 then the model output will be an object of that class. If a dict then
@@ -1589,7 +1598,9 @@ class AnswerWithJustification(BaseModel):
                 tools=[schema], first_tool_only=True
             )
         else:
-            parser = JsonOutputToolsParser()
+            parser = JsonOutputToolsParser(first_tool_only=True) | RunnableGenerator(
+                _yield_args
+            )
             llm = self.bind_tools([schema], tool_choice=self._is_gemini_advanced)
         if include_raw:
             parser_with_fallback = RunnablePassthrough.assign(
@@ -1707,6 +1718,11 @@ def _gemini_chunk_to_generation_chunk(
     )
 
 
+def _yield_args(tool_call_chunks: Iterator[dict]) -> Iterator[dict]:
+    for tc in tool_call_chunks:
+        yield tc["args"]
+
+
 def _get_usage_metadata_gemini(raw_metadata: dict) -> Optional[UsageMetadata]:
     """Get UsageMetadata from raw response metadata."""
     input_tokens = raw_metadata.get("prompt_token_count", 0)
diff --git a/libs/vertexai/langchain_google_vertexai/functions_utils.py b/libs/vertexai/langchain_google_vertexai/functions_utils.py
index 771fd70f..10cc7676 100644
--- a/libs/vertexai/langchain_google_vertexai/functions_utils.py
+++ b/libs/vertexai/langchain_google_vertexai/functions_utils.py
@@ -24,7 +24,10 @@
 from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.tools import BaseTool
 from langchain_core.tools import tool as callable_as_lc_tool
-from langchain_core.utils.function_calling import FunctionDescription
+from langchain_core.utils.function_calling import (
+    FunctionDescription,
+    convert_to_openai_tool,
+)
 from langchain_core.utils.json_schema import dereference_refs
 
 logger = logging.getLogger(__name__)
@@ -169,11 +172,8 @@ def _format_to_gapic_function_declaration(
     elif isinstance(tool, dict):
         # this could come from
         # 'langchain_core.utils.function_calling.convert_to_openai_tool'
-        if tool.get("type") == "function" and tool.get("function"):
-            return _format_dict_to_function_declaration(
-                cast(FunctionDescription, tool.get("function"))
-            )
-        return _format_dict_to_function_declaration(tool)
+        function = convert_to_openai_tool(cast(dict, tool))["function"]
+        return _format_dict_to_function_declaration(cast(FunctionDescription, function))
     else:
         raise ValueError(f"Unsupported tool call type {tool}")
 
diff --git a/libs/vertexai/tests/integration_tests/test_chat_models.py b/libs/vertexai/tests/integration_tests/test_chat_models.py
index a50f8c90..0bfe2a9b 100644
--- a/libs/vertexai/tests/integration_tests/test_chat_models.py
+++ b/libs/vertexai/tests/integration_tests/test_chat_models.py
@@ -545,16 +545,28 @@ class MyModel(BaseModel):
         {"name": "MyModel", "description": "MyModel", "parameters": MyModel.schema()}
     )
     response = model.invoke([message])
-    expected = [
+    assert response == {
+        "name": "Erick",
+        "age": 27,
+    }
+
+    model = llm.with_structured_output(
         {
-            "type": "MyModel",
-            "args": {
-                "name": "Erick",
-                "age": 27,
+            "title": "MyModel",
+            "description": "MyModel",
+            "type": "object",
+            "properties": {
+                "name": {"type": "string"},
+                "age": {"type": "integer"},
+            },
+            "required": ["name", "age"],
         }
-    ]
-    assert response == expected
+    )
+    response = model.invoke([message])
+    assert response == {
+        "name": "Erick",
+        "age": 27,
+    }
 
 
 @pytest.mark.release
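
For context, a minimal usage sketch of the corrected dict-schema path of `with_structured_output`. The model name and prompt below are illustrative assumptions, not taken from this diff; the schema mirrors the one in the integration test above.

```python
from langchain_google_vertexai import ChatVertexAI

# JSON-schema dict, mirroring the schema used in the integration test above.
schema = {
    "title": "MyModel",
    "description": "MyModel",
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
}

llm = ChatVertexAI(model_name="gemini-1.5-pro")  # illustrative model name
structured_llm = llm.with_structured_output(schema)

result = structured_llm.invoke("My name is Erick and I am 27 years old.")
# Previously the dict-schema path returned a one-element list that wrapped the
# parsed values under an "args" key; with this change the parsed dict is
# returned directly, e.g. {"name": "Erick", "age": 27}.
```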