From 9c10541271bf570a92e524f3d0e60ddd184e957c Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Sat, 20 Apr 2024 09:17:08 -0700 Subject: [PATCH] vertexai[patch]: ChatVertexAI.generate, stream, bind_tools (#166) * vertexai[patch]: refactor bind_tools --------- Co-authored-by: Leonid Kuligin --- .../langchain_google_vertexai/chat_models.py | 557 ++++++++---------- .../functions_utils.py | 146 ++++- .../tests/unit_tests/test_function_utils.py | 28 +- 3 files changed, 404 insertions(+), 327 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index 1066ce32..c11a0354 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -10,7 +10,6 @@ from typing import ( Any, AsyncIterator, - Callable, Dict, Iterator, List, @@ -19,6 +18,9 @@ Type, Union, cast, + Literal, + TypedDict, + overload, ) import proto # type: ignore[import-untyped] @@ -47,7 +49,6 @@ ToolCallChunk, ToolMessage, ) -from langchain_core.tools import BaseTool, tool as tool_from_callable from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_functions import ( JsonOutputFunctionsParser, @@ -62,9 +63,13 @@ Content, GenerativeModel, Part, + Tool as VertexTool, ) from vertexai.generative_models._generative_models import ( # type: ignore ToolConfig, + SafetySettingsType, + GenerationConfigType, + GenerationResponse, ) from vertexai.language_models import ( # type: ignore ChatMessage, @@ -92,8 +97,13 @@ ) from langchain_google_vertexai.functions_utils import ( _format_tool_config, - _format_tool_to_vertex_function, - _format_tools_to_vertex_tool, + _ToolConfigDict, + _tool_choice_to_tool_config, + _ToolChoiceType, + _FunctionDeclarationLike, + _VertexToolDict, + _format_to_vertex_tool, + _format_functions_to_vertex_tool_dict, ) logger = logging.getLogger(__name__) @@ -107,6 +117,13 @@ class _ChatHistory: context: Optional[str] = None +class _GeminiGenerateContentKwargs(TypedDict): + generation_config: Optional[GenerationConfigType] + safety_settings: Optional[SafetySettingsType] + tools: Optional[List[VertexTool]] + tool_config: Optional[ToolConfig] + + def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory: """Parse a sequence of messages into history. @@ -302,17 +319,18 @@ def _get_question(messages: List[BaseMessage]) -> HumanMessage: return question -def _get_client_with_sys_instruction( - client: GenerativeModel, - system_instruction: Content, - model_name: str, -): - if client._system_instruction != system_instruction: - client = GenerativeModel( - model_name=model_name, - system_instruction=system_instruction, - ) - return client +@overload +def _parse_response_candidate( + response_candidate: "Candidate", streaming: Literal[False] = False +) -> AIMessage: + ... + + +@overload +def _parse_response_candidate( + response_candidate: "Candidate", streaming: Literal[True] +) -> AIMessageChunk: + ... 
def _parse_response_candidate( @@ -325,9 +343,8 @@ def _parse_response_candidate( tool_call_chunks = [] for part in response_candidate.content.parts: - text = None try: - text = part.text + text: Optional[str] = part.text except AttributeError: text = None @@ -469,13 +486,9 @@ def validate_environment(cls, values: Dict) -> Dict: @property def _is_gemini_advanced(self) -> bool: try: - if float(self.model_name.split("-")[1]) > 1.0: - return True - except ValueError: - pass - except IndexError: - pass - return False + return float(self.model_name.split("-")[1]) > 1.0 + except (ValueError, IndexError): + return False def _generate( self, @@ -500,80 +513,50 @@ def _generate( Raises: ValueError: if the last message in the list is not from human. """ - should_stream = stream if stream is not None else self.streaming - safety_settings = kwargs.pop("safety_settings", None) - if should_stream: - with telemetry.tool_context_manager(self._user_agent): - stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return generate_from_stream(stream_iter) + if stream is True or (stream is None and self.streaming): + stream_iter = self._stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return generate_from_stream(stream_iter) + if not self._is_gemini_model: + return self._generate_non_gemini(messages, stop=stop, **kwargs) + client, contents = self._gemini_client_and_contents(messages) + params = self._gemini_params(stop=stop, **kwargs) + with telemetry.tool_context_manager(self._user_agent): + response = client.generate_content(contents, **params) + return self._gemini_response_to_chat_result(response) + + def _generate_non_gemini( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> ChatResult: + kwargs.pop("safety_settings", None) params = self._prepare_params(stop=stop, stream=False, **kwargs) + question = _get_question(messages) + history = _parse_chat_history(messages[:-1]) + examples = kwargs.get("examples") or self.examples msg_params = {} if "candidate_count" in params: msg_params["candidate_count"] = params.pop("candidate_count") - - if self._is_gemini_model: - system_instruction, history_gemini = _parse_chat_history_gemini( - messages, - project=self.project, - convert_system_message_to_human=self.convert_system_message_to_human, - ) - self.client = _get_client_with_sys_instruction( - client=self.client, - system_instruction=system_instruction, - model_name=self.model_name, - ) - - # set param to `functions` until core tool/function calling implemented - raw_tools = params.pop("functions") if "functions" in params else None - tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None - raw_tool_config = ( - params.pop("tool_config") if "tool_config" in params else None - ) - tool_config = ( - _format_tool_config(raw_tool_config) if raw_tool_config else None + if examples: + params["examples"] = _parse_examples(examples) + with telemetry.tool_context_manager(self._user_agent): + chat = self._start_chat(history, **params) + response = chat.send_message(question.content, **msg_params) + generations = [ + ChatGeneration( + message=AIMessage(content=candidate.text), + generation_info=get_generation_info( + candidate, + self._is_gemini_model, + usage_metadata=response.raw_prediction_response.metadata, + ), ) - with telemetry.tool_context_manager(self._user_agent): - response = self.client.generate_content( - history_gemini, - generation_config=params, - tools=tools, - tool_config=tool_config, 
- safety_settings=safety_settings, - ) - generations = [ - ChatGeneration( - message=_parse_response_candidate(candidate), - generation_info=get_generation_info( - candidate, - self._is_gemini_model, - usage_metadata=response.to_dict().get("usage_metadata"), - ), - ) - for candidate in response.candidates - ] - else: - question = _get_question(messages) - history = _parse_chat_history(messages[:-1]) - examples = kwargs.get("examples") or self.examples - if examples: - params["examples"] = _parse_examples(examples) - with telemetry.tool_context_manager(self._user_agent): - chat = self._start_chat(history, **params) - response = chat.send_message(question.content, **msg_params) - generations = [ - ChatGeneration( - message=AIMessage(content=candidate.text), - generation_info=get_generation_info( - candidate, - self._is_gemini_model, - usage_metadata=response.raw_prediction_response.metadata, - ), - ) - for candidate in response.candidates - ] + for candidate in response.candidates + ] return ChatResult(generations=generations) async def _agenerate( @@ -601,72 +584,45 @@ async def _agenerate( kwargs.pop("stream") logger.warning("ChatVertexAI does not currently support async streaming.") - params = self._prepare_params(stop=stop, **kwargs) - safety_settings = kwargs.pop("safety_settings", None) + if not self._is_gemini_model: + return await self._agenerate_non_gemini(messages, stop=stop, **kwargs) + + client, contents = self._gemini_client_and_contents(messages) + params = self._gemini_params(stop=stop, **kwargs) + with telemetry.tool_context_manager(self._user_agent): + response = await client.generate_content_async(contents, **params) + return self._gemini_response_to_chat_result(response) + + async def _agenerate_non_gemini( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> ChatResult: + kwargs.pop("safety_settings", None) + params = self._prepare_params(stop=stop, stream=False, **kwargs) + question = _get_question(messages) + history = _parse_chat_history(messages[:-1]) + examples = kwargs.get("examples") or self.examples msg_params = {} if "candidate_count" in params: msg_params["candidate_count"] = params.pop("candidate_count") - - if self._is_gemini_model: - system_instruction, history_gemini = _parse_chat_history_gemini( - messages, - project=self.project, - convert_system_message_to_human=self.convert_system_message_to_human, - ) - - self.client = _get_client_with_sys_instruction( - client=self.client, - system_instruction=system_instruction, - model_name=self.model_name, - ) - # set param to `functions` until core tool/function calling implemented - raw_tools = params.pop("functions") if "functions" in params else None - tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None - raw_tool_config = ( - params.pop("tool_config") if "tool_config" in params else None - ) - tool_config = ( - _format_tool_config(raw_tool_config) if raw_tool_config else None + if examples: + params["examples"] = _parse_examples(examples) + with telemetry.tool_context_manager(self._user_agent): + chat = self._start_chat(history, **params) + response = await chat.send_message_async(question.content, **msg_params) + generations = [ + ChatGeneration( + message=AIMessage(content=candidate.text), + generation_info=get_generation_info( + candidate, + self._is_gemini_model, + usage_metadata=response.raw_prediction_response.metadata, + ), ) - with telemetry.tool_context_manager(self._user_agent): - response = await self.client.generate_content_async( - 
history_gemini, - generation_config=params, - tools=tools, - tool_config=tool_config, - safety_settings=safety_settings, - ) - generations = [ - ChatGeneration( - message=_parse_response_candidate(c), - generation_info=get_generation_info( - c, - self._is_gemini_model, - usage_metadata=response.to_dict().get("usage_metadata"), - ), - ) - for c in response.candidates - ] - else: - question = _get_question(messages) - history = _parse_chat_history(messages[:-1]) - examples = kwargs.get("examples", None) or self.examples - if examples: - params["examples"] = _parse_examples(examples) - with telemetry.tool_context_manager(self._user_agent): - chat = self._start_chat(history, **params) - response = await chat.send_message_async(question.content, **msg_params) - generations = [ - ChatGeneration( - message=AIMessage(content=r.text), - generation_info=get_generation_info( - r, - self._is_gemini_model, - usage_metadata=response.raw_prediction_response.metadata, - ), - ) - for r in response.candidates - ] + for candidate in response.candidates + ] return ChatResult(generations=generations) def _stream( @@ -676,81 +632,49 @@ def _stream( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: - params = self._prepare_params(stop=stop, stream=True, **kwargs) - if self._is_gemini_model: - safety_settings = params.pop("safety_settings", None) - system_instruction, history_gemini = _parse_chat_history_gemini( - messages, - project=self.project, - convert_system_message_to_human=self.convert_system_message_to_human, - ) - self.client = _get_client_with_sys_instruction( - client=self.client, - system_instruction=system_instruction, - model_name=self.model_name, - ) - # set param to `functions` until core tool/function calling implemented - raw_tools = params.pop("functions") if "functions" in params else None - tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None - raw_tool_config = ( - params.pop("tool_config") if "tool_config" in params else None - ) - tool_config = ( - _format_tool_config(raw_tool_config) if raw_tool_config else None + if not self._is_gemini_model: + yield from self._stream_non_gemini( + messages, stop=stop, run_manager=run_manager, **kwargs ) - with telemetry.tool_context_manager(self._user_agent): - responses = self.client.generate_content( - history_gemini, - stream=True, - generation_config=params, - tools=tools, - tool_config=tool_config, - safety_settings=safety_settings, - ) - for response in responses: - message = _parse_response_candidate( - response.candidates[0], streaming=True - ) - generation_info = get_generation_info( - response.candidates[0], + return + + client, contents = self._gemini_client_and_contents(messages) + params = self._gemini_params(stop=stop, stream=True, **kwargs) + with telemetry.tool_context_manager(self._user_agent): + response_iter = client.generate_content(contents, **params, stream=True) + for response_chunk in response_iter: + chunk = self._gemini_chunk_to_generation_chunk(response_chunk) + if run_manager and isinstance(chunk.message.content, str): + run_manager.on_llm_new_token(chunk.message.content) + yield chunk + + def _stream_non_gemini( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + params = self._prepare_params(stop=stop, stream=True, **kwargs) + question = _get_question(messages) + history = _parse_chat_history(messages[:-1]) + 
examples = kwargs.get("examples", None) + if examples: + params["examples"] = _parse_examples(examples) + with telemetry.tool_context_manager(self._user_agent): + chat = self._start_chat(history, **params) + responses = chat.send_message_streaming(question.content, **params) + for response in responses: + if run_manager: + run_manager.on_llm_new_token(response.text) + yield ChatGenerationChunk( + message=AIMessageChunk(content=response.text), + generation_info=get_generation_info( + response, self._is_gemini_model, - usage_metadata=response.to_dict().get("usage_metadata"), - ) - if run_manager and isinstance(message.content, str): - run_manager.on_llm_new_token(message.content) - if isinstance(message, AIMessageChunk): - yield ChatGenerationChunk( - message=message, - generation_info=generation_info, - ) - else: - yield ChatGenerationChunk( - message=AIMessageChunk( - content=message.content, - additional_kwargs=message.additional_kwargs, - ), - generation_info=generation_info, - ) - else: - question = _get_question(messages) - history = _parse_chat_history(messages[:-1]) - examples = kwargs.get("examples", None) - if examples: - params["examples"] = _parse_examples(examples) - with telemetry.tool_context_manager(self._user_agent): - chat = self._start_chat(history, **params) - responses = chat.send_message_streaming(question.content, **params) - for response in responses: - if run_manager: - run_manager.on_llm_new_token(response.text) - yield ChatGenerationChunk( - message=AIMessageChunk(content=response.text), - generation_info=get_generation_info( - response, - self._is_gemini_model, - usage_metadata=response.raw_prediction_response.metadata, - ), - ) + usage_metadata=response.raw_prediction_response.metadata, + ), + ) async def _astream( self, @@ -761,52 +685,16 @@ async def _astream( ) -> AsyncIterator[ChatGenerationChunk]: if not self._is_gemini_model: raise NotImplementedError() - params = self._prepare_params(stop=stop, stream=True, **kwargs) - safety_settings = params.pop("safety_settings", None) - system_instruction, history_gemini = _parse_chat_history_gemini( - messages, - project=self.project, - convert_system_message_to_human=self.convert_system_message_to_human, - ) - self.client = _get_client_with_sys_instruction( - client=self.client, - system_instruction=system_instruction, - model_name=self.model_name, - ) - raw_tools = params.pop("functions") if "functions" in params else None - tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None - raw_tool_config = params.pop("tool_config") if "tool_config" in params else None - tool_config = _format_tool_config(raw_tool_config) if raw_tool_config else None + client, contents = self._gemini_client_and_contents(messages) + params = self._gemini_params(stop=stop, stream=True, **kwargs) with telemetry.tool_context_manager(self._user_agent): - async for chunk in await self.client.generate_content_async( - history_gemini, - stream=True, - generation_config=params, - tools=tools, - tool_config=tool_config, - safety_settings=safety_settings, + async for response_chunk in await client.generate_content_async( + contents, **params, stream=True ): - message = _parse_response_candidate(chunk.candidates[0], streaming=True) - generation_info = get_generation_info( - chunk.candidates[0], - self._is_gemini_model, - usage_metadata=chunk.to_dict().get("usage_metadata"), - ) - if run_manager and isinstance(message.content, str): - await run_manager.on_llm_new_token(message.content) - if isinstance(message, AIMessageChunk): - yield 
ChatGenerationChunk( - message=message, - generation_info=generation_info, - ) - else: - yield ChatGenerationChunk( - message=AIMessageChunk( - content=message.content, - additional_kwargs=message.additional_kwargs, - ), - generation_info=generation_info, - ) + chunk = self._gemini_chunk_to_generation_chunk(response_chunk) + if run_manager and isinstance(chunk.message.content, str): + await run_manager.on_llm_new_token(chunk.message.content) + yield chunk def with_structured_output( self, @@ -910,21 +798,7 @@ class AnswerWithJustification(BaseModel): ) else: parser = JsonOutputFunctionsParser() - - name = _format_tool_to_vertex_function(schema)["name"] - - if self._is_gemini_advanced: - llm = self.bind( - functions=[schema], - tool_config={ - "function_calling_config": { - "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, - "allowed_function_names": [name], - } - }, - ) - else: - llm = self.bind(functions=[schema]) + llm = self.bind_tools([schema], tool_choice=self._is_gemini_advanced) if include_raw: parser_with_fallback = RunnablePassthrough.assign( parsed=itemgetter("raw") | parser, parsing_error=lambda _: None @@ -938,8 +812,10 @@ class AnswerWithJustification(BaseModel): def bind_tools( self, - tools: Sequence[Union[Type[BaseModel], Callable, BaseTool]], - tool_config: Optional[Dict[str, Any]] = None, + tools: Sequence[Union[_FunctionDeclarationLike, VertexTool]], + tool_config: Optional[_ToolConfigDict] = None, + *, + tool_choice: Optional[Union[_ToolChoiceType, bool]] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind tool-like objects to this chat model. @@ -954,19 +830,30 @@ def bind_tools( **kwargs: Any additional parameters to pass to the :class:`~langchain.runnable.Runnable` constructor. """ - formatted_tools = [] + if tool_choice and tool_config: + raise ValueError( + "Must specify at most one of tool_choice and tool_config, received " + f"both:\n\n{tool_choice=}\n\n{tool_config=}" + ) + vertexai_tools: List[_VertexToolDict] = [] + vertexai_functions = [] for schema in tools: - if isinstance(schema, BaseTool) or ( - isinstance(schema, type) and issubclass(schema, BaseModel) - ): - formatted_tools.append(schema) - elif callable(schema): - formatted_tools.append(tool_from_callable(schema)) # type: ignore - else: - raise ValueError( - "Tool must be a BaseTool, Pydantic model, or callable." + if isinstance(schema, VertexTool): + vertexai_tools.append( + {"function_declarations": schema.to_dict()["function_declarations"]} ) - return self.bind(functions=formatted_tools, tool_config=tool_config, **kwargs) + elif isinstance(schema, dict) and "function_declarations" in schema: + vertexai_tools.append(cast(_VertexToolDict, schema)) + else: + vertexai_functions.append(schema) + vertexai_tools.append(_format_functions_to_vertex_tool_dict(vertexai_functions)) + if tool_choice: + all_names = [ + f["name"] for vt in vertexai_tools for f in vt["function_declarations"] + ] + tool_config = _tool_choice_to_tool_config(tool_choice, all_names) + # Bind dicts for easier serialization/deserialization. 
+ return self.bind(tools=vertexai_tools, tool_config=tool_config, **kwargs) def _start_chat( self, history: _ChatHistory, **kwargs: Any @@ -977,3 +864,75 @@ def _start_chat( ) else: return self.client.start_chat(message_history=history.history, **kwargs) + + def _gemini_params( + self, + *, + stop: Optional[List[str]] = None, + stream: bool = False, + tools: Optional[List[Union[_VertexToolDict, VertexTool]]] = None, + functions: Optional[List[_FunctionDeclarationLike]] = None, + tool_config: Optional[Union[_ToolConfigDict, ToolConfig]] = None, + safety_settings: Optional[SafetySettingsType] = None, + **kwargs: Any, + ) -> _GeminiGenerateContentKwargs: + generation_config = self._prepare_params(stop=stop, stream=stream, **kwargs) + if tools: + tools = [_format_to_vertex_tool(tool) for tool in tools] + elif functions: + tools = [_format_to_vertex_tool(functions)] + else: + pass + + if tool_config and not isinstance(tool_config, ToolConfig): + tool_config = _format_tool_config(cast(_ToolConfigDict, tool_config)) + + return _GeminiGenerateContentKwargs( + generation_config=generation_config, + tools=tools, + tool_config=tool_config, + safety_settings=safety_settings, + ) + + def _gemini_client_and_contents( + self, messages: List[BaseMessage] + ) -> tuple[GenerativeModel, list[Content]]: + system, contents = _parse_chat_history_gemini( + messages, + project=self.project, + convert_system_message_to_human=self.convert_system_message_to_human, + ) + # TODO: Store default client params explicitly so private params don't have to + # be accessed, like _safety_settings. + client = GenerativeModel( + model_name=self.model_name, + system_instruction=system, + safety_settings=self.client._safety_settings, + ) + return client, contents + + def _gemini_response_to_chat_result( + self, response: GenerationResponse + ) -> ChatResult: + generations = [] + usage = response.to_dict().get("usage_metadata") + for candidate in response.candidates: + info = get_generation_info(candidate, is_gemini=True, usage_metadata=usage) + message = _parse_response_candidate(candidate) + generations.append(ChatGeneration(message=message, generation_info=info)) + return ChatResult(generations=generations) + + def _gemini_chunk_to_generation_chunk( + self, response_chunk: GenerationResponse + ) -> ChatGenerationChunk: + top_candidate = response_chunk.candidates[0] + message = _parse_response_candidate(top_candidate, streaming=True) + generation_info = get_generation_info( + top_candidate, + is_gemini=True, + usage_metadata=response_chunk.to_dict().get("usage_metadata"), + ) + return ChatGenerationChunk( + message=message, + generation_info=generation_info, + ) diff --git a/libs/vertexai/langchain_google_vertexai/functions_utils.py b/libs/vertexai/langchain_google_vertexai/functions_utils.py index d1d9d5e6..3eda90c3 100644 --- a/libs/vertexai/langchain_google_vertexai/functions_utils.py +++ b/libs/vertexai/langchain_google_vertexai/functions_utils.py @@ -1,24 +1,31 @@ +from __future__ import annotations + import json -from typing import Any, Dict, List, Optional, Type, Union +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Type, + TypedDict, + Union, +) from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import BaseOutputParser from langchain_core.outputs import ChatGeneration, Generation from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.tools import BaseTool +from langchain_core.tools import tool as callable_as_lc_tool from 
langchain_core.utils.function_calling import FunctionDescription from langchain_core.utils.json_schema import dereference_refs -from vertexai.generative_models import ( # type: ignore - FunctionDeclaration, -) -from vertexai.generative_models import ( - Tool as VertexTool, -) +from vertexai.generative_models import FunctionDeclaration # type: ignore +from vertexai.generative_models import Tool as VertexTool # FIXME: vertexai is not exporting ToolConfig -from vertexai.generative_models._generative_models import ( # type: ignore - ToolConfig, -) +from vertexai.generative_models._generative_models import ToolConfig # type: ignore def _format_pydantic_to_vertex_function( @@ -57,8 +64,8 @@ def _format_base_tool_to_vertex_function(tool: BaseTool) -> FunctionDescription: } -def _format_tool_to_vertex_function( - tool: Union[BaseTool, Type[BaseModel], dict], +def _format_to_vertex_function_dict( + tool: Union[BaseTool, Type[BaseModel], dict, Callable, FunctionDeclaration], ) -> FunctionDescription: "Format tool into the Vertex function declaration." if isinstance(tool, BaseTool): @@ -71,25 +78,62 @@ def _format_tool_to_vertex_function( "description": tool["description"], "parameters": _get_parameters_from_schema(tool["parameters"]), } + elif isinstance(tool, FunctionDeclaration): + tool_dict = tool.to_dict() + return FunctionDescription( + name=tool_dict["name"], + description=tool_dict["description"], + parameters=tool_dict["parameters"], + ) + elif callable(tool): + return _format_base_tool_to_vertex_function(callable_as_lc_tool()(tool)) else: raise ValueError(f"Unsupported tool call type {tool}") -def _format_tools_to_vertex_tool( - tools: List[Union[BaseTool, Type[BaseModel], dict]], -) -> List[VertexTool]: - "Format tools into the Vertex Tool instance." - function_declarations = [] - for tool in tools: - func = _format_tool_to_vertex_function(tool) - function_declarations.append(FunctionDeclaration(**func)) +_FunctionDeclarationLike = Union[ + BaseTool, Type[BaseModel], dict, Callable, FunctionDeclaration +] + - return [VertexTool(function_declarations=function_declarations)] +class _VertexToolDict(TypedDict): + function_declarations: List[FunctionDescription] -def _format_tool_config(tool_config: Dict[str, Any]) -> Union[ToolConfig, None]: +def _format_functions_to_vertex_tool_dict( + functions: List[_FunctionDeclarationLike], +) -> _VertexToolDict: + "Format tools into the Vertex Tool instance." + function_declarations = [_format_to_vertex_function_dict(fn) for fn in functions] + return _VertexToolDict(function_declarations=function_declarations) + + +def _format_to_vertex_tool( + tool: Union[VertexTool, _VertexToolDict, List[_FunctionDeclarationLike]], +) -> VertexTool: + if isinstance(tool, VertexTool): + return tool + elif isinstance(tool, (list, dict)): + tool = ( + _format_functions_to_vertex_tool_dict(tool) + if isinstance(tool, list) + else tool + ) + return VertexTool( + function_declarations=[ + FunctionDeclaration(**fd) for fd in tool["function_declarations"] + ] + ) + else: + raise ValueError(f"Unexpected tool value:\n\n{tool=}") + + +def _format_tool_config(tool_config: _ToolConfigDict) -> Union[ToolConfig, None]: if "function_calling_config" not in tool_config: - return None + raise ValueError( + "Invalid ToolConfig, missing 'function_calling_config' key. 
Received:\n\n" + f"{tool_config=}" + ) return ToolConfig( function_calling_config=ToolConfig.FunctionCallingConfig( **tool_config["function_calling_config"] @@ -195,3 +239,59 @@ def parse_result( def parse(self, text: str) -> BaseModel: raise ValueError("Can only parse messages") + + +class _FunctionCallingConfigDict(TypedDict): + mode: ToolConfig.FunctionCallingConfig.Mode + allowed_function_names: Optional[List[str]] + + +class _ToolConfigDict(TypedDict): + function_calling_config: _FunctionCallingConfigDict + + +_ToolChoiceType = Union[ + dict, List[str], str, Literal["auto", "none", "any"], Literal[True] +] + + +def _tool_choice_to_tool_config( + tool_choice: _ToolChoiceType, + all_names: List[str], +) -> _ToolConfigDict: + allowed_function_names: Optional[List[str]] = None + if tool_choice is True or tool_choice == "any": + mode = ToolConfig.FunctionCallingConfig.Mode.ANY + allowed_function_names = all_names + elif tool_choice == "auto": + mode = ToolConfig.FunctionCallingConfig.Mode.AUTO + elif tool_choice == "none": + mode = ToolConfig.FunctionCallingConfig.Mode.NONE + elif isinstance(tool_choice, str): + mode = ToolConfig.FunctionCallingConfig.Mode.ANY + allowed_function_names = [tool_choice] + elif isinstance(tool_choice, list): + mode = ToolConfig.FunctionCallingConfig.Mode.ANY + allowed_function_names = tool_choice + elif isinstance(tool_choice, dict): + if "mode" in tool_choice: + mode = tool_choice["mode"] + allowed_function_names = tool_choice.get("allowed_function_names") + elif "function_calling_config" in tool_choice: + mode = tool_choice["function_calling_config"]["mode"] + allowed_function_names = tool_choice["function_calling_config"].get( + "allowed_function_names" + ) + else: + raise ValueError( + f"Unrecognized tool choice format:\n\n{tool_choice=}\n\nShould match " + f"VertexAI ToolConfig or FunctionCallingConfig format." 
+ ) + else: + raise ValueError(f"Unrecognized tool choice format:\n\n{tool_choice=}") + return _ToolConfigDict( + function_calling_config=_FunctionCallingConfigDict( + mode=mode, + allowed_function_names=allowed_function_names, + ) + ) diff --git a/libs/vertexai/tests/unit_tests/test_function_utils.py b/libs/vertexai/tests/unit_tests/test_function_utils.py index bdc685ad..4d031eaf 100644 --- a/libs/vertexai/tests/unit_tests/test_function_utils.py +++ b/libs/vertexai/tests/unit_tests/test_function_utils.py @@ -1,6 +1,7 @@ from enum import Enum -from typing import Optional, Sequence +from typing import Any, Optional, Sequence +import pytest from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.tools import tool from vertexai.generative_models._generative_models import ( # type: ignore[import-untyped] @@ -10,7 +11,10 @@ from langchain_google_vertexai.functions_utils import ( _format_base_tool_to_vertex_function, _format_tool_config, + _FunctionCallingConfigDict, _get_parameters_from_schema, + _tool_choice_to_tool_config, + _ToolConfigDict, ) @@ -56,15 +60,17 @@ def do_something_optional(a: float, b: float = 0) -> str: assert len(schema["parameters"]["required"]) == 1 -def test_format_tool_config(): - tool_config = _format_tool_config({}) - assert tool_config is None +def test_format_tool_config_invalid(): + with pytest.raises(ValueError): + _format_tool_config({}) # type: ignore + +def test_format_tool_config(): tool_config = _format_tool_config( { "function_calling_config": { "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, - "allowed_function_names": "my_fun", + "allowed_function_names": ["my_fun"], } } ) @@ -139,3 +145,15 @@ class B(BaseModel): "title": "B", "required": ["array_field", "int_field", "str_field", "str_enum_field"], } + + +@pytest.mark.parametrize("choice", (True, "foo", ["foo"], "any")) +def test__tool_choice_to_tool_config(choice: Any) -> None: + expected = _ToolConfigDict( + function_calling_config=_FunctionCallingConfigDict( + mode=ToolConfig.FunctionCallingConfig.Mode.ANY, + allowed_function_names=["foo"], + ), + ) + actual = _tool_choice_to_tool_config(choice, ["foo"]) + assert expected == actual
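Reviewer note (not part of the patch): a minimal sketch of the new bind_tools / tool_choice path added here. The get_weather tool, the prompt, and the model name are illustrative assumptions, and running it requires Vertex AI credentials plus this branch installed; it is not a verbatim example from the diff.

# Minimal sketch of ChatVertexAI.bind_tools with the new tool_choice argument.
# The tool, model name, and prompt below are hypothetical, not from the diff.
from langchain_core.tools import tool
from langchain_google_vertexai import ChatVertexAI


@tool
def get_weather(city: str) -> str:
    """Return a short weather summary for the given city."""
    return f"It is always sunny in {city}."


llm = ChatVertexAI(model_name="gemini-1.5-pro-preview-0409")  # any Gemini model name
# tool_choice="get_weather" (or "any" / True) is normalized by
# _tool_choice_to_tool_config into a ToolConfig with Mode.ANY restricted to that name.
llm_with_tools = llm.bind_tools([get_weather], tool_choice="get_weather")

ai_msg = llm_with_tools.invoke("What's the weather in Paris?")
print(ai_msg.tool_calls)  # expected to contain the forced get_weather call

# Streaming goes through the same path; chunks that carry a function call expose
# tool_call_chunks on the AIMessageChunk rather than plain text content.
for chunk in llm_with_tools.stream("What's the weather in Paris?"):
    print(chunk.content or chunk.tool_call_chunks)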
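The different tool_choice spellings all normalize to the same ToolConfig dict; the sketch below mirrors the parametrized test__tool_choice_to_tool_config test added in this patch (these are private helpers, so the import path may change):

# Mirrors test__tool_choice_to_tool_config: every "forcing" spelling normalizes to
# Mode.ANY with allowed_function_names filled in from the bound tool names.
from langchain_google_vertexai.functions_utils import _tool_choice_to_tool_config
from vertexai.generative_models._generative_models import ToolConfig

for choice in (True, "any", "foo", ["foo"]):
    config = _tool_choice_to_tool_config(choice, ["foo"])
    fcc = config["function_calling_config"]
    assert fcc["mode"] == ToolConfig.FunctionCallingConfig.Mode.ANY
    assert fcc["allowed_function_names"] == ["foo"]

# "auto" and "none" map to Mode.AUTO / Mode.NONE with no name restriction (None).
print(_tool_choice_to_tool_config("auto", ["foo"]))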
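Finally, with_structured_output now routes through bind_tools (on Gemini 1.5 models it forces the function via tool_choice=True, per _is_gemini_advanced). A sketch of that path follows; the AnswerWithJustification fields, model name, and prompt are assumptions for illustration, since the docstring example is elided in the diff.

# Sketch of the structured-output path, which this patch reroutes through bind_tools.
# Schema fields, model name, and prompt are illustrative; parsing depends on the live model.
from langchain_core.pydantic_v1 import BaseModel
from langchain_google_vertexai import ChatVertexAI


class AnswerWithJustification(BaseModel):
    """An answer to the user question along with a justification for the answer."""

    answer: str
    justification: str


llm = ChatVertexAI(model_name="gemini-1.5-pro-preview-0409", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification)

result = structured_llm.invoke(
    "What weighs more, a pound of bricks or a pound of feathers?"
)
print(result)  # expected: an AnswerWithJustification instance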