From a2d30428237695f076060dec881bae0258123775 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 22 Dec 2023 16:05:48 -0800 Subject: [PATCH] Improve graph repr for runnable passthrough and itemgetter (#15083) --- libs/core/langchain_core/runnables/base.py | 8 +- libs/core/langchain_core/runnables/graph.py | 6 +- .../langchain_core/runnables/passthrough.py | 16 ++++ .../runnables/__snapshots__/test_graph.ambr | 90 +++++++++---------- .../unit_tests/runnables/test_runnable.py | 9 +- 5 files changed, 72 insertions(+), 57 deletions(-) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 07b2e935225e3..04e9dce584342 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -2007,7 +2007,7 @@ def get_input_schema( ): # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] - "RunnableParallelInput", + "RunnableMapInput", **{ k: (v.annotation, v.default) for step in self.steps.values() @@ -2024,7 +2024,7 @@ def get_output_schema( ) -> Type[BaseModel]: # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] - "RunnableParallelOutput", + "RunnableMapOutput", **{k: (v.OutputType, None) for k, v in self.steps.items()}, __config__=_SchemaConfig, ) @@ -2650,7 +2650,9 @@ def __eq__(self, other: Any) -> bool: def __repr__(self) -> str: """A string representation of this runnable.""" - if hasattr(self, "func"): + if hasattr(self, "func") and isinstance(self.func, itemgetter): + return f"RunnableLambda({str(self.func)[len('operator.'):]})" + elif hasattr(self, "func"): return f"RunnableLambda({get_lambda_source(self.func) or '...'})" elif hasattr(self, "afunc"): return f"RunnableLambda(afunc={get_lambda_source(self.afunc) or '...'})" diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 211b44b1f1e30..c1906b55155d1 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -123,13 +123,13 @@ def node_data(node: Node) -> str: or len(data.splitlines()) > 1 ): data = node.data.__class__.__name__ - elif len(data) > 36: - data = data[:36] + "..." + elif len(data) > 42: + data = data[:42] + "..." except Exception: data = node.data.__class__.__name__ else: data = node.data.__name__ - return data + return data if not data.startswith("Runnable") else data[8:] return draw( {node.id: node_data(node) for node in self.nodes.values()}, diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index db446cd7d315c..159f981607e9d 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -34,6 +34,7 @@ get_executor_for_config, patch_config, ) +from langchain_core.runnables.graph import Graph from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec from langchain_core.utils.aiter import atee, py_anext from langchain_core.utils.iter import safetee @@ -297,6 +298,9 @@ async def input_aiter() -> AsyncIterator[Other]: yield chunk +_graph_passthrough: RunnablePassthrough = RunnablePassthrough() + + class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): """ A runnable that assigns key-value pairs to Dict[str, Any] inputs. @@ -355,6 +359,18 @@ def get_output_schema( def config_specs(self) -> List[ConfigurableFieldSpec]: return self.mapper.config_specs + def get_graph(self, config: RunnableConfig | None = None) -> Graph: + # get graph from mapper + graph = self.mapper.get_graph(config) + # add passthrough node and edges + input_node = graph.first_node() + output_node = graph.last_node() + if input_node is not None and output_node is not None: + passthrough_node = graph.add_node(_graph_passthrough) + graph.add_edge(input_node, passthrough_node) + graph.add_edge(passthrough_node, output_node) + return graph + def _invoke( self, input: Dict[str, Any], diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index f241ea429fd6f..eb42966cd5d93 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -32,51 +32,51 @@ # --- # name: test_graph_sequence_map ''' - +-------------+ - | PromptInput | - +-------------+ - * - * - * - +----------------+ - | PromptTemplate | - +----------------+ - * - * - * - +-------------+ - | FakeListLLM | - +-------------+ - * - * - * - +-----------------------+ - | RunnableParallelInput | - +-----------------------+** - **** ******* - **** ***** - ** ******* - +---------------------+ *** - | RunnableLambdaInput | * - +---------------------+ * - *** *** * - *** *** * - ** ** * - +-----------------+ +-----------------+ * - | StrOutputParser | | XMLOutputParser | * - +-----------------+ +-----------------+ * - *** *** * - *** *** * - ** ** * - +----------------------+ +--------------------------------+ - | RunnableLambdaOutput | | CommaSeparatedListOutputParser | - +----------------------+ +--------------------------------+ - **** ******* - **** ***** - ** **** - +------------------------+ - | RunnableParallelOutput | - +------------------------+ + +-------------+ + | PromptInput | + +-------------+ + * + * + * + +----------------+ + | PromptTemplate | + +----------------+ + * + * + * + +-------------+ + | FakeListLLM | + +-------------+ + * + * + * + +---------------+ + | ParallelInput | + +---------------+***** + *** ****** + *** ***** + ** ***** + +-------------+ *** + | LambdaInput | * + +-------------+ * + ** ** * + *** *** * + ** ** * + +-----------------+ +-----------------+ * + | StrOutputParser | | XMLOutputParser | * + +-----------------+ +-----------------+ * + ** ** * + *** *** * + ** ** * + +--------------+ +--------------------------------+ + | LambdaOutput | | CommaSeparatedListOutputParser | + +--------------+ +--------------------------------+ + *** ****** + *** ***** + ** *** + +-----------+ + | MapOutput | + +-----------+ ''' # --- # name: test_graph_single_runnable diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 2b9dc25909890..a67b91efac8df 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -569,7 +569,7 @@ async def typed_async_lambda_impl(x: str) -> int: "properties": {"name": {"title": "Name", "type": "string"}}, } assert seq_w_map.output_schema.schema() == { - "title": "RunnableParallelOutput", + "title": "RunnableMapOutput", "type": "object", "properties": { "original": {"title": "Original", "type": "string"}, @@ -613,7 +613,7 @@ def test_passthrough_assign_schema() -> None: # expected dict input_schema assert invalid_seq_w_assign.input_schema.schema() == { "properties": {"question": {"title": "Question"}}, - "title": "RunnableParallelInput", + "title": "RunnableMapInput", "type": "object", } @@ -768,7 +768,7 @@ def test_schema_complex_seq() -> None: ) assert chain2.input_schema.schema() == { - "title": "RunnableParallelInput", + "title": "RunnableMapInput", "type": "object", "properties": { "person": {"title": "Person", "type": "string"}, @@ -2221,7 +2221,6 @@ async def list_producer(input: AsyncIterator[Any]) -> AsyncIterator[AddableDict] } -@pytest.mark.asyncio @freeze_time("2023-01-01") async def test_prompt_with_llm_and_async_lambda( mocker: MockerFixture, snapshot: SnapshotAssertion @@ -4262,7 +4261,6 @@ def test_with_config_callbacks() -> None: assert isinstance(result, RunnableBinding) -@pytest.mark.asyncio async def test_ainvoke_on_returned_runnable() -> None: """Verify that a runnable returned by a sync runnable in the async path will be runthroughaasync path (issue #13407)""" @@ -4301,7 +4299,6 @@ def idchain_sync(__input: dict) -> bool: assert tracer.runs[0].child_runs[0].name == "RunnableParallel" -@pytest.mark.asyncio async def test_ainvoke_astream_passthrough_assign_trace() -> None: def idchain_sync(__input: dict) -> bool: return False