Skip to content

Commit

Permalink
ensure tool_choice=name works for func & cls tools
Browse files Browse the repository at this point in the history
  • Loading branch information
mattf committed Jul 24, 2024
1 parent 24598b3 commit c3745e0
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 7 deletions.
10 changes: 3 additions & 7 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,19 +489,15 @@ def bind_tools(
tool_name = tool_choice["function"]["name"]

# check that the specified tool is in the tools list
tool_dicts = [convert_to_openai_tool(tool) for tool in tools]
if tool_name:
if not any(
isinstance(tool, BaseTool) and tool.name == tool_name for tool in tools
) and not any(
isinstance(tool, dict) and tool.get("name") == tool_name
for tool in tools
):
if not any(tool["function"]["name"] == tool_name for tool in tool_dicts):
raise ValueError(
f"Tool choice '{tool_name}' not found in the tools list"
)

return super().bind(
tools=[convert_to_openai_tool(tool) for tool in tools],
tools=tool_dicts,
tool_choice=tool_choice,
**kwargs,
)
Expand Down
64 changes: 64 additions & 0 deletions libs/ai-endpoints/tests/unit_tests/test_bind_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import warnings
from typing import Any

import pytest
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool

from langchain_nvidia_ai_endpoints import ChatNVIDIA


def xxyyzz_func(a: int, b: int) -> int:
"""xxyyzz two numbers"""
return 42


class xxyyzz_cls(BaseModel):
"""xxyyzz two numbers"""

a: int = Field(..., description="First number")
b: int = Field(..., description="Second number")


@tool
def xxyyzz_tool(
a: int = Field(..., description="First number"),
b: int = Field(..., description="Second number"),
) -> int:
"""xxyyzz two numbers"""
return 42


@pytest.mark.parametrize(
"tools, choice",
[
([xxyyzz_func], "xxyyzz_func"),
([xxyyzz_cls], "xxyyzz_cls"),
([xxyyzz_tool], "xxyyzz_tool"),
],
ids=["func", "cls", "tool"],
)
def test_bind_tool_and_select(tools: Any, choice: str) -> None:
warnings.filterwarnings(
"ignore", category=UserWarning, message=".*not known to support tools.*"
)
ChatNVIDIA(api_key="BOGUS").bind_tools(tools=tools, tool_choice=choice)


@pytest.mark.parametrize(
"tools, choice",
[
([], "wrong"),
([xxyyzz_func], "wrong_xxyyzz_func"),
([xxyyzz_cls], "wrong_xxyyzz_cls"),
([xxyyzz_tool], "wrong_xxyyzz_tool"),
],
ids=["empty", "func", "cls", "tool"],
)
def test_bind_tool_and_select_negative(tools: Any, choice: str) -> None:
warnings.filterwarnings(
"ignore", category=UserWarning, message=".*not known to support tools.*"
)
with pytest.raises(ValueError) as e:
ChatNVIDIA(api_key="BOGUS").bind_tools(tools=tools, tool_choice=choice)
assert "not found in the tools list" in str(e.value)

0 comments on commit c3745e0

Please sign in to comment.