From e24f86e55f73f6dd84c94be764d1d509926033b0 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Tue, 10 Dec 2024 01:59:38 -0800 Subject: [PATCH] core[patch]: return ToolMessage from tool (#28605) --- libs/core/langchain_core/messages/tool.py | 11 ++- libs/core/langchain_core/tools/base.py | 103 +++++++++++++++++----- libs/core/langchain_core/tools/simple.py | 6 +- libs/core/tests/unit_tests/test_tools.py | 64 ++++++++++++++ 4 files changed, 158 insertions(+), 26 deletions(-) diff --git a/libs/core/langchain_core/messages/tool.py b/libs/core/langchain_core/messages/tool.py index 653dd838f860e..873f872cef268 100644 --- a/libs/core/langchain_core/messages/tool.py +++ b/libs/core/langchain_core/messages/tool.py @@ -9,7 +9,16 @@ from langchain_core.utils._merge import merge_dicts, merge_obj -class ToolMessage(BaseMessage): +class ToolOutputMixin: + """Mixin for objects that tools can return directly. + + If a custom BaseTool is invoked with a ToolCall and the output of custom code is + not an instance of ToolOutputMixin, the output will automatically be coerced to a + string and wrapped in a ToolMessage. + """ + + +class ToolMessage(BaseMessage, ToolOutputMixin): """Message for passing the result of executing a tool back to a model. ToolMessages contain the result of a tool invocation. Typically, the result diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 815607f3b4325..ff264edac3284 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -45,7 +45,7 @@ CallbackManager, Callbacks, ) -from langchain_core.messages.tool import ToolCall, ToolMessage +from langchain_core.messages.tool import ToolCall, ToolMessage, ToolOutputMixin from langchain_core.runnables import ( RunnableConfig, RunnableSerializable, @@ -494,7 +494,9 @@ async def ainvoke( # --- Tool --- - def _parse_input(self, tool_input: Union[str, dict]) -> Union[str, dict[str, Any]]: + def _parse_input( + self, tool_input: Union[str, dict], tool_call_id: Optional[str] + ) -> Union[str, dict[str, Any]]: """Convert tool input to a pydantic model. Args: @@ -512,9 +514,39 @@ def _parse_input(self, tool_input: Union[str, dict]) -> Union[str, dict[str, Any else: if input_args is not None: if issubclass(input_args, BaseModel): + for k, v in get_all_basemodel_annotations(input_args).items(): + if ( + _is_injected_arg_type(v, injected_type=InjectedToolCallId) + and k not in tool_input + ): + if tool_call_id is None: + msg = ( + "When tool includes an InjectedToolCallId " + "argument, tool must always be invoked with a full " + "model ToolCall of the form: {'args': {...}, " + "'name': '...', 'type': 'tool_call', " + "'tool_call_id': '...'}" + ) + raise ValueError(msg) + tool_input[k] = tool_call_id result = input_args.model_validate(tool_input) result_dict = result.model_dump() elif issubclass(input_args, BaseModelV1): + for k, v in get_all_basemodel_annotations(input_args).items(): + if ( + _is_injected_arg_type(v, injected_type=InjectedToolCallId) + and k not in tool_input + ): + if tool_call_id is None: + msg = ( + "When tool includes an InjectedToolCallId " + "argument, tool must always be invoked with a full " + "model ToolCall of the form: {'args': {...}, " + "'name': '...', 'type': 'tool_call', " + "'tool_call_id': '...'}" + ) + raise ValueError(msg) + tool_input[k] = tool_call_id result = input_args.parse_obj(tool_input) result_dict = result.dict() else: @@ -570,8 +602,10 @@ async def _arun(self, *args: Any, **kwargs: Any) -> Any: kwargs["run_manager"] = kwargs["run_manager"].get_sync() return await run_in_executor(None, self._run, *args, **kwargs) - def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]: - tool_input = self._parse_input(tool_input) + def _to_args_and_kwargs( + self, tool_input: Union[str, dict], tool_call_id: Optional[str] + ) -> tuple[tuple, dict]: + tool_input = self._parse_input(tool_input, tool_call_id) # For backwards compatibility, if run_input is a string, # pass as a positional argument. if isinstance(tool_input, str): @@ -648,10 +682,9 @@ def run( child_config = patch_config(config, callbacks=run_manager.get_child()) context = copy_context() context.run(_set_config_context, child_config) - tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) + tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id) if signature(self._run).parameters.get("run_manager"): tool_kwargs["run_manager"] = run_manager - if config_param := _get_runnable_config_param(self._run): tool_kwargs[config_param] = config response = context.run(self._run, *tool_args, **tool_kwargs) @@ -755,7 +788,7 @@ async def arun( artifact = None error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None try: - tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) + tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id) child_config = patch_config(config, callbacks=run_manager.get_child()) context = copy_context() context.run(_set_config_context, child_config) @@ -889,20 +922,23 @@ def _prep_run_args( def _format_output( - content: Any, artifact: Any, tool_call_id: Optional[str], name: str, status: str -) -> Union[ToolMessage, Any]: - if tool_call_id: - if not _is_message_content_type(content): - content = _stringify(content) - return ToolMessage( - content, - artifact=artifact, - tool_call_id=tool_call_id, - name=name, - status=status, - ) - else: + content: Any, + artifact: Any, + tool_call_id: Optional[str], + name: str, + status: str, +) -> Union[ToolOutputMixin, Any]: + if isinstance(content, ToolOutputMixin) or not tool_call_id: return content + if not _is_message_content_type(content): + content = _stringify(content) + return ToolMessage( + content, + artifact=artifact, + tool_call_id=tool_call_id, + name=name, + status=status, + ) def _is_message_content_type(obj: Any) -> bool: @@ -954,10 +990,31 @@ class InjectedToolArg: """Annotation for a Tool arg that is **not** meant to be generated by a model.""" -def _is_injected_arg_type(type_: type) -> bool: +class InjectedToolCallId(InjectedToolArg): + r'''Annotation for injecting the tool_call_id. + + Example: + ..code-block:: python + + from typing_extensions import Annotated + + from langchain_core.messages import ToolMessage + from langchain_core.tools import tool, InjectedToolCallID + + @tool + def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallID]) -> ToolMessage: + """Return x.""" + return ToolMessage(str(x), artifact=x, name="foo", tool_call_id=tool_call_id) + ''' # noqa: E501 + + +def _is_injected_arg_type( + type_: type, injected_type: Optional[type[InjectedToolArg]] = None +) -> bool: + injected_type = injected_type or InjectedToolArg return any( - isinstance(arg, InjectedToolArg) - or (isinstance(arg, type) and issubclass(arg, InjectedToolArg)) + isinstance(arg, injected_type) + or (isinstance(arg, type) and issubclass(arg, injected_type)) for arg in get_args(type_)[1:] ) diff --git a/libs/core/langchain_core/tools/simple.py b/libs/core/langchain_core/tools/simple.py index 118c8b39f6db3..d9e38ba227c8b 100644 --- a/libs/core/langchain_core/tools/simple.py +++ b/libs/core/langchain_core/tools/simple.py @@ -62,9 +62,11 @@ def args(self) -> dict: # assume it takes a single string input. return {"tool_input": {"type": "string"}} - def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]: + def _to_args_and_kwargs( + self, tool_input: Union[str, dict], tool_call_id: Optional[str] + ) -> tuple[tuple, dict]: """Convert tool input to pydantic model.""" - args, kwargs = super()._to_args_and_kwargs(tool_input) + args, kwargs = super()._to_args_and_kwargs(tool_input, tool_call_id) # For backwards compatibility. The tool must be run with a single input all_args = list(args) + list(kwargs.values()) if len(all_args) != 1: diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index ce7ea4894bb5a..164ecc508e76e 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -31,6 +31,7 @@ CallbackManagerForToolRun, ) from langchain_core.messages import ToolMessage +from langchain_core.messages.tool import ToolOutputMixin from langchain_core.runnables import ( Runnable, RunnableConfig, @@ -46,6 +47,7 @@ ) from langchain_core.tools.base import ( InjectedToolArg, + InjectedToolCallId, SchemaAnnotationError, _is_message_content_block, _is_message_content_type, @@ -856,6 +858,7 @@ class _RaiseNonValidationErrorTool(BaseTool): def _parse_input( self, tool_input: Union[str, dict], + tool_call_id: Optional[str], ) -> Union[str, dict[str, Any]]: raise NotImplementedError @@ -920,6 +923,7 @@ class _RaiseNonValidationErrorTool(BaseTool): def _parse_input( self, tool_input: Union[str, dict], + tool_call_id: Optional[str], ) -> Union[str, dict[str, Any]]: raise NotImplementedError @@ -2110,3 +2114,63 @@ def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: return foo.value assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" # type: ignore + + +def test_tool_injected_tool_call_id() -> None: + @tool + def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage: + """foo""" + return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore + + assert foo.invoke( + {"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"} + ) == ToolMessage(0, tool_call_id="bar") # type: ignore + + with pytest.raises(ValueError): + assert foo.invoke({"x": 0}) + + @tool + def foo2(x: int, tool_call_id: Annotated[str, InjectedToolCallId()]) -> ToolMessage: + """foo""" + return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore + + assert foo2.invoke( + {"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"} + ) == ToolMessage(0, tool_call_id="bar") # type: ignore + + +def test_tool_uninjected_tool_call_id() -> None: + @tool + def foo(x: int, tool_call_id: str) -> ToolMessage: + """foo""" + return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore + + with pytest.raises(ValueError): + foo.invoke({"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}) + + assert foo.invoke( + { + "type": "tool_call", + "args": {"x": 0, "tool_call_id": "zap"}, + "name": "foo", + "id": "bar", + } + ) == ToolMessage(0, tool_call_id="zap") # type: ignore + + +def test_tool_return_output_mixin() -> None: + class Bar(ToolOutputMixin): + def __init__(self, x: int) -> None: + self.x = x + + def __eq__(self, other: Any) -> bool: + return isinstance(other, self.__class__) and self.x == other.x + + @tool + def foo(x: int) -> Bar: + """Foo.""" + return Bar(x=x) + + assert foo.invoke( + {"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"} + ) == Bar(x=0)