diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index a070862321e71..733e410cd1da7 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -655,8 +655,11 @@ def run( if config_param := _get_runnable_config_param(self._run): tool_kwargs[config_param] = config + # For method tools, change name of self or cls to outer_instance if "self" in tool_kwargs: - tool_kwargs["outer_self"] = tool_kwargs.pop("self") + tool_kwargs["outer_instance"] = tool_kwargs.pop("self") + if "cls" in tool_kwargs: + tool_kwargs["outer_instance"] = tool_kwargs.pop("cls") response = context.run(self._run, *tool_args, **tool_kwargs) if self.response_format == "content_and_artifact": @@ -771,8 +774,11 @@ async def arun( if config_param := _get_runnable_config_param(func_to_check): tool_kwargs[config_param] = config + # For method tools, change name of self or cls to outer_instance if "self" in tool_kwargs: - tool_kwargs["outer_self"] = tool_kwargs.pop("self") + tool_kwargs["outer_instance"] = tool_kwargs.pop("self") + if "cls" in tool_kwargs: + tool_kwargs["outer_instance"] = tool_kwargs.pop("cls") coro = context.run(self._arun, *tool_args, **tool_kwargs) if asyncio_accepts_context(): diff --git a/libs/core/langchain_core/tools/convert.py b/libs/core/langchain_core/tools/convert.py index 4fe251bdf3612..0051183a93ca3 100644 --- a/libs/core/langchain_core/tools/convert.py +++ b/libs/core/langchain_core/tools/convert.py @@ -259,10 +259,15 @@ def invoke_wrapper( description = None if infer_schema or args_schema is not None: - if ( - not isinstance(dec_func, Runnable) - and "self" in inspect.signature(dec_func).parameters + if not isinstance(dec_func, Runnable) and ( + "self" in inspect.signature(dec_func).parameters + or "cls" in inspect.signature(dec_func).parameters ): + outer_instance_name: Literal["self", "cls"] = ( + "self" + if "self" in inspect.signature(dec_func).parameters + else "cls" + ) def method_tool(self: Callable) -> StructuredTool: return StructuredTool.from_function( @@ -276,7 +281,8 @@ def method_tool(self: Callable) -> StructuredTool: response_format=response_format, parse_docstring=parse_docstring, error_on_invalid_docstring=error_on_invalid_docstring, - outer_self=self, + outer_instance=self, + outer_instance_name=outer_instance_name, ) return property(method_tool) diff --git a/libs/core/langchain_core/tools/structured.py b/libs/core/langchain_core/tools/structured.py index 8644564464a6f..a41c35a5be7cf 100644 --- a/libs/core/langchain_core/tools/structured.py +++ b/libs/core/langchain_core/tools/structured.py @@ -42,30 +42,31 @@ class StructuredTool(BaseTool): """The function to run when the tool is called.""" coroutine: Optional[Callable[..., Awaitable[Any]]] = None """The asynchronous version of the function.""" - outer_self: Optional[Any] = None + outer_instance: Optional[Any] = None """The outer self of the tool for methods.""" + outer_instance_name: Union[Literal["self"], Literal["cls"]] = "self" # --- Runnable --- - def _add_outer_self( + def _add_outer_instance( self, input: Union[str, dict, ToolCall] ) -> Union[dict, ToolCall]: """Add outer self into arguments for method tools.""" # If input is a string, then it is the first argument if isinstance(input, str): - args = {"self": self.outer_self} + args = {self.outer_instance_name: self.outer_instance} for x in self.args: # loop should only happen once args[x] = input return args # ToolCall if "type" in input and input["type"] == "tool_call": - input["args"]["self"] = self.outer_self + input["args"][self.outer_instance_name] = self.outer_instance return input # Dict new_input = cast(dict, input) # to avoid mypy error - new_input["self"] = self.outer_self + new_input[self.outer_instance_name] = self.outer_instance return new_input def invoke( @@ -74,8 +75,8 @@ def invoke( config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: - if self.outer_self is not None: - input = self._add_outer_self(input) + if self.outer_instance is not None: + input = self._add_outer_instance(input) return super().invoke(input, config, **kwargs) # TODO: Is this needed? @@ -89,8 +90,8 @@ async def ainvoke( # If the tool does not implement async, fall back to default implementation return await run_in_executor(config, self.invoke, input, config, **kwargs) - if self.outer_self is not None: - input = self._add_outer_self(input) + if self.outer_instance is not None: + input = self._add_outer_instance(input) return await super().ainvoke(input, config, **kwargs) # --- Tool --- @@ -99,8 +100,8 @@ async def ainvoke( def args(self) -> dict: """The tool's input arguments.""" properties = self.args_schema.model_json_schema()["properties"] - if self.outer_self is not None: - properties.pop("self") + if self.outer_instance is not None: + properties.pop(self.outer_instance_name) return properties def _run( @@ -116,8 +117,8 @@ def _run( kwargs["callbacks"] = run_manager.get_child() if config_param := _get_runnable_config_param(self.func): kwargs[config_param] = config - if "outer_self" in kwargs: - kwargs["self"] = kwargs.pop("outer_self") + if "outer_instance" in kwargs: + kwargs[self.outer_instance_name] = kwargs.pop("outer_instance") return self.func(*args, **kwargs) msg = "StructuredTool does not support sync invocation." raise NotImplementedError(msg) @@ -135,8 +136,8 @@ async def _arun( kwargs["callbacks"] = run_manager.get_child() if config_param := _get_runnable_config_param(self.coroutine): kwargs[config_param] = config - if "outer_self" in kwargs: - kwargs["self"] = kwargs.pop("outer_self") + if "outer_instance" in kwargs: + kwargs[self.outer_instance_name] = kwargs.pop("outer_instance") return await self.coroutine(*args, **kwargs) # If self.coroutine is None, then this will delegate to the default @@ -159,7 +160,8 @@ def from_function( response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = False, - outer_self: Optional[Any] = None, + outer_instance: Optional[Any] = None, + outer_instance_name: Union[Literal["self"], Literal["cls"]] = "self", **kwargs: Any, ) -> StructuredTool: """Create tool from a given function. @@ -247,7 +249,8 @@ def add(a: int, b: int) -> int: description=description_, return_direct=return_direct, response_format=response_format, - outer_self=outer_self, + outer_instance=outer_instance, + outer_instance_name=outer_instance_name, **kwargs, ) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 5816b2e41f85e..50490c81a6c7b 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -2201,3 +2201,18 @@ def foo(self, a: int, b: int) -> int: tool_message = a.foo.invoke(tool_call) assert int(tool_message.content) == 13 + + +def test_method_tool_classmethod() -> None: + """Test that a method tool can be a classmethod.""" + + class A: + c = 5 + + @classmethod + @tool + def foo(cls, a: int, b: int) -> int: + """Add two numbers to c.""" + return a + b + cls.c + + assert A.foo.invoke({"a": 1, "b": 2}) == 8