From 524ee6d9acbfed9c6d86ece971c661ad1cf5254f Mon Sep 17 00:00:00 2001 From: Mohammad Mohtashim <45242107+keenborder786@users.noreply.github.com> Date: Sun, 8 Dec 2024 00:33:40 +0500 Subject: [PATCH] Invalid `tool_choice` being passed to `ChatLiteLLM` (#28198) - **Description:** Invalid `tool_choice` is given to `ChatLiteLLM` to `bind_tools` due to it's parent's class default value being pass through `with_structured_output`. - **Issue:** #28176 --- .../chat_models/litellm.py | 73 +++++++++++++++++-- 1 file changed, 67 insertions(+), 6 deletions(-) diff --git a/libs/community/langchain_community/chat_models/litellm.py b/libs/community/langchain_community/chat_models/litellm.py index d6c9557339857..83c3020910155 100644 --- a/libs/community/langchain_community/chat_models/litellm.py +++ b/libs/community/langchain_community/chat_models/litellm.py @@ -11,6 +11,7 @@ Dict, Iterator, List, + Literal, Mapping, Optional, Sequence, @@ -212,6 +213,33 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: return message_dict +_OPENAI_MODELS = [ + "o1-mini", + "o1-preview", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4-turbo", + "gpt-4-turbo-preview", + "gpt-4-0125-preview", + "gpt-4-1106-preview", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-16k-0613", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", +] + + class ChatLiteLLM(BaseChatModel): """Chat model that uses the LiteLLM API.""" @@ -465,6 +493,9 @@ async def _agenerate( def bind_tools( self, tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "required", "any"], bool] + ] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind tool-like objects to this chat model. @@ -476,17 +507,47 @@ def bind_tools( Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic models, callables, and BaseTools will be automatically converted to their schema dictionary representation. - tool_choice: Which tool to require the model to call. - Must be the name of the single provided function or - "auto" to automatically determine which function to call - (if any), or a dict of the form: - {"type": "function", "function": {"name": <>}}. + tool_choice: Which tool to require the model to call. Options are: + - str of the form ``"<>"``: calls <> tool. + - ``"auto"``: + automatically selects a tool (including no tool). + - ``"none"``: + does not call a tool. + - ``"any"`` or ``"required"`` or ``True``: + forces least one tool to be called. + - dict of the form: + ``{"type": "function", "function": {"name": <>}}`` + - ``False`` or ``None``: no effect **kwargs: Any additional parameters to pass to the :class:`~langchain.runnable.Runnable` constructor. """ formatted_tools = [convert_to_openai_tool(tool) for tool in tools] - return super().bind(tools=formatted_tools, **kwargs) + + # In case of openai if tool_choice is `any` or if bool has been provided we + # change it to `required` as that is suppored by openai. + if ( + (self.model is not None and "azure" in self.model) + or (self.model_name is not None and "azure" in self.model_name) + or (self.model is not None and self.model in _OPENAI_MODELS) + or (self.model_name is not None and self.model_name in _OPENAI_MODELS) + ) and (tool_choice == "any" or isinstance(tool_choice, bool)): + tool_choice = "required" + # If tool_choice is bool apart from openai we make it `any` + elif isinstance(tool_choice, bool): + tool_choice = "any" + elif isinstance(tool_choice, dict): + tool_names = [ + formatted_tool["function"]["name"] for formatted_tool in formatted_tools + ] + if not any( + tool_name == tool_choice["function"]["name"] for tool_name in tool_names + ): + raise ValueError( + f"Tool choice {tool_choice} was specified, but the only " + f"provided tools were {tool_names}." + ) + return super().bind(tools=formatted_tools, tool_choice=tool_choice, **kwargs) @property def _identifying_params(self) -> Dict[str, Any]: