Skip to content

Commit

Permalink
Add classmethod support
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanglide committed Dec 9, 2024
1 parent b5daee7 commit ee66099
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 23 deletions.
10 changes: 8 additions & 2 deletions libs/core/langchain_core/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,11 @@ def run(
if config_param := _get_runnable_config_param(self._run):
tool_kwargs[config_param] = config

# For method tools, change name of self or cls to outer_instance
if "self" in tool_kwargs:
tool_kwargs["outer_self"] = tool_kwargs.pop("self")
tool_kwargs["outer_instance"] = tool_kwargs.pop("self")
if "cls" in tool_kwargs:
tool_kwargs["outer_instance"] = tool_kwargs.pop("cls")

response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_artifact":
Expand Down Expand Up @@ -771,8 +774,11 @@ async def arun(
if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config

# For method tools, change name of self or cls to outer_instance
if "self" in tool_kwargs:
tool_kwargs["outer_self"] = tool_kwargs.pop("self")
tool_kwargs["outer_instance"] = tool_kwargs.pop("self")
if "cls" in tool_kwargs:
tool_kwargs["outer_instance"] = tool_kwargs.pop("cls")

coro = context.run(self._arun, *tool_args, **tool_kwargs)
if asyncio_accepts_context():
Expand Down
14 changes: 10 additions & 4 deletions libs/core/langchain_core/tools/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,15 @@ def invoke_wrapper(
description = None

if infer_schema or args_schema is not None:
if (
not isinstance(dec_func, Runnable)
and "self" in inspect.signature(dec_func).parameters
if not isinstance(dec_func, Runnable) and (
"self" in inspect.signature(dec_func).parameters
or "cls" in inspect.signature(dec_func).parameters
):
outer_instance_name: Literal["self", "cls"] = (
"self"
if "self" in inspect.signature(dec_func).parameters
else "cls"
)

def method_tool(self: Callable) -> StructuredTool:
return StructuredTool.from_function(
Expand All @@ -276,7 +281,8 @@ def method_tool(self: Callable) -> StructuredTool:
response_format=response_format,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
outer_self=self,
outer_instance=self,
outer_instance_name=outer_instance_name,
)

return property(method_tool)
Expand Down
37 changes: 20 additions & 17 deletions libs/core/langchain_core/tools/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,30 +42,31 @@ class StructuredTool(BaseTool):
"""The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
"""The asynchronous version of the function."""
outer_self: Optional[Any] = None
outer_instance: Optional[Any] = None
"""The outer self of the tool for methods."""
outer_instance_name: Union[Literal["self"], Literal["cls"]] = "self"

# --- Runnable ---
def _add_outer_self(
def _add_outer_instance(
self, input: Union[str, dict, ToolCall]
) -> Union[dict, ToolCall]:
"""Add outer self into arguments for method tools."""

# If input is a string, then it is the first argument
if isinstance(input, str):
args = {"self": self.outer_self}
args = {self.outer_instance_name: self.outer_instance}
for x in self.args: # loop should only happen once
args[x] = input
return args

# ToolCall
if "type" in input and input["type"] == "tool_call":
input["args"]["self"] = self.outer_self
input["args"][self.outer_instance_name] = self.outer_instance
return input

# Dict
new_input = cast(dict, input) # to avoid mypy error
new_input["self"] = self.outer_self
new_input[self.outer_instance_name] = self.outer_instance
return new_input

def invoke(
Expand All @@ -74,8 +75,8 @@ def invoke(
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
if self.outer_self is not None:
input = self._add_outer_self(input)
if self.outer_instance is not None:
input = self._add_outer_instance(input)
return super().invoke(input, config, **kwargs)

# TODO: Is this needed?
Expand All @@ -89,8 +90,8 @@ async def ainvoke(
# If the tool does not implement async, fall back to default implementation
return await run_in_executor(config, self.invoke, input, config, **kwargs)

if self.outer_self is not None:
input = self._add_outer_self(input)
if self.outer_instance is not None:
input = self._add_outer_instance(input)
return await super().ainvoke(input, config, **kwargs)

# --- Tool ---
Expand All @@ -99,8 +100,8 @@ async def ainvoke(
def args(self) -> dict:
"""The tool's input arguments."""
properties = self.args_schema.model_json_schema()["properties"]
if self.outer_self is not None:
properties.pop("self")
if self.outer_instance is not None:
properties.pop(self.outer_instance_name)
return properties

def _run(
Expand All @@ -116,8 +117,8 @@ def _run(
kwargs["callbacks"] = run_manager.get_child()
if config_param := _get_runnable_config_param(self.func):
kwargs[config_param] = config
if "outer_self" in kwargs:
kwargs["self"] = kwargs.pop("outer_self")
if "outer_instance" in kwargs:
kwargs[self.outer_instance_name] = kwargs.pop("outer_instance")
return self.func(*args, **kwargs)
msg = "StructuredTool does not support sync invocation."
raise NotImplementedError(msg)
Expand All @@ -135,8 +136,8 @@ async def _arun(
kwargs["callbacks"] = run_manager.get_child()
if config_param := _get_runnable_config_param(self.coroutine):
kwargs[config_param] = config
if "outer_self" in kwargs:
kwargs["self"] = kwargs.pop("outer_self")
if "outer_instance" in kwargs:
kwargs[self.outer_instance_name] = kwargs.pop("outer_instance")
return await self.coroutine(*args, **kwargs)

# If self.coroutine is None, then this will delegate to the default
Expand All @@ -159,7 +160,8 @@ def from_function(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = False,
outer_self: Optional[Any] = None,
outer_instance: Optional[Any] = None,
outer_instance_name: Union[Literal["self"], Literal["cls"]] = "self",
**kwargs: Any,
) -> StructuredTool:
"""Create tool from a given function.
Expand Down Expand Up @@ -247,7 +249,8 @@ def add(a: int, b: int) -> int:
description=description_,
return_direct=return_direct,
response_format=response_format,
outer_self=outer_self,
outer_instance=outer_instance,
outer_instance_name=outer_instance_name,
**kwargs,
)

Expand Down
15 changes: 15 additions & 0 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2201,3 +2201,18 @@ def foo(self, a: int, b: int) -> int:
tool_message = a.foo.invoke(tool_call)

assert int(tool_message.content) == 13


def test_method_tool_classmethod() -> None:
"""Test that a method tool can be a classmethod."""

class A:
c = 5

@classmethod
@tool
def foo(cls, a: int, b: int) -> int:
"""Add two numbers to c."""
return a + b + cls.c

assert A.foo.invoke({"a": 1, "b": 2}) == 8

0 comments on commit ee66099

Please sign in to comment.