Skip to content

Commit

Permalink
Improve graph repr for runnable passthrough and itemgetter (#15083)
Browse files Browse the repository at this point in the history
<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
  • Loading branch information
nfcampos authored Dec 23, 2023
1 parent 0d0901e commit a2d3042
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 57 deletions.
8 changes: 5 additions & 3 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2007,7 +2007,7 @@ def get_input_schema(
):
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"RunnableParallelInput",
"RunnableMapInput",
**{
k: (v.annotation, v.default)
for step in self.steps.values()
Expand All @@ -2024,7 +2024,7 @@ def get_output_schema(
) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"RunnableParallelOutput",
"RunnableMapOutput",
**{k: (v.OutputType, None) for k, v in self.steps.items()},
__config__=_SchemaConfig,
)
Expand Down Expand Up @@ -2650,7 +2650,9 @@ def __eq__(self, other: Any) -> bool:

def __repr__(self) -> str:
"""A string representation of this runnable."""
if hasattr(self, "func"):
if hasattr(self, "func") and isinstance(self.func, itemgetter):
return f"RunnableLambda({str(self.func)[len('operator.'):]})"
elif hasattr(self, "func"):
return f"RunnableLambda({get_lambda_source(self.func) or '...'})"
elif hasattr(self, "afunc"):
return f"RunnableLambda(afunc={get_lambda_source(self.afunc) or '...'})"
Expand Down
6 changes: 3 additions & 3 deletions libs/core/langchain_core/runnables/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,13 @@ def node_data(node: Node) -> str:
or len(data.splitlines()) > 1
):
data = node.data.__class__.__name__
elif len(data) > 36:
data = data[:36] + "..."
elif len(data) > 42:
data = data[:42] + "..."
except Exception:
data = node.data.__class__.__name__
else:
data = node.data.__name__
return data
return data if not data.startswith("Runnable") else data[8:]

return draw(
{node.id: node_data(node) for node in self.nodes.values()},
Expand Down
16 changes: 16 additions & 0 deletions libs/core/langchain_core/runnables/passthrough.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
get_executor_for_config,
patch_config,
)
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec
from langchain_core.utils.aiter import atee, py_anext
from langchain_core.utils.iter import safetee
Expand Down Expand Up @@ -297,6 +298,9 @@ async def input_aiter() -> AsyncIterator[Other]:
yield chunk


_graph_passthrough: RunnablePassthrough = RunnablePassthrough()


class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
"""
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
Expand Down Expand Up @@ -355,6 +359,18 @@ def get_output_schema(
def config_specs(self) -> List[ConfigurableFieldSpec]:
return self.mapper.config_specs

def get_graph(self, config: RunnableConfig | None = None) -> Graph:
# get graph from mapper
graph = self.mapper.get_graph(config)
# add passthrough node and edges
input_node = graph.first_node()
output_node = graph.last_node()
if input_node is not None and output_node is not None:
passthrough_node = graph.add_node(_graph_passthrough)
graph.add_edge(input_node, passthrough_node)
graph.add_edge(passthrough_node, output_node)
return graph

def _invoke(
self,
input: Dict[str, Any],
Expand Down
90 changes: 45 additions & 45 deletions libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -32,51 +32,51 @@
# ---
# name: test_graph_sequence_map
'''
+-------------+
| PromptInput |
+-------------+
*
*
*
+----------------+
| PromptTemplate |
+----------------+
*
*
*
+-------------+
| FakeListLLM |
+-------------+
*
*
*
+-----------------------+
| RunnableParallelInput |
+-----------------------+**
**** *******
**** *****
** *******
+---------------------+ ***
| RunnableLambdaInput | *
+---------------------+ *
*** *** *
*** *** *
** ** *
+-----------------+ +-----------------+ *
| StrOutputParser | | XMLOutputParser | *
+-----------------+ +-----------------+ *
*** *** *
*** *** *
** ** *
+----------------------+ +--------------------------------+
| RunnableLambdaOutput | | CommaSeparatedListOutputParser |
+----------------------+ +--------------------------------+
**** *******
**** *****
** ****
+------------------------+
| RunnableParallelOutput |
+------------------------+
+-------------+
| PromptInput |
+-------------+
*
*
*
+----------------+
| PromptTemplate |
+----------------+
*
*
*
+-------------+
| FakeListLLM |
+-------------+
*
*
*
+---------------+
| ParallelInput |
+---------------+*****
*** ******
*** *****
** *****
+-------------+ ***
| LambdaInput | *
+-------------+ *
** ** *
*** *** *
** ** *
+-----------------+ +-----------------+ *
| StrOutputParser | | XMLOutputParser | *
+-----------------+ +-----------------+ *
** ** *
*** *** *
** ** *
+--------------+ +--------------------------------+
| LambdaOutput | | CommaSeparatedListOutputParser |
+--------------+ +--------------------------------+
*** ******
*** *****
** ***
+-----------+
| MapOutput |
+-----------+
'''
# ---
# name: test_graph_single_runnable
Expand Down
9 changes: 3 additions & 6 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ async def typed_async_lambda_impl(x: str) -> int:
"properties": {"name": {"title": "Name", "type": "string"}},
}
assert seq_w_map.output_schema.schema() == {
"title": "RunnableParallelOutput",
"title": "RunnableMapOutput",
"type": "object",
"properties": {
"original": {"title": "Original", "type": "string"},
Expand Down Expand Up @@ -613,7 +613,7 @@ def test_passthrough_assign_schema() -> None:
# expected dict input_schema
assert invalid_seq_w_assign.input_schema.schema() == {
"properties": {"question": {"title": "Question"}},
"title": "RunnableParallelInput",
"title": "RunnableMapInput",
"type": "object",
}

Expand Down Expand Up @@ -768,7 +768,7 @@ def test_schema_complex_seq() -> None:
)

assert chain2.input_schema.schema() == {
"title": "RunnableParallelInput",
"title": "RunnableMapInput",
"type": "object",
"properties": {
"person": {"title": "Person", "type": "string"},
Expand Down Expand Up @@ -2221,7 +2221,6 @@ async def list_producer(input: AsyncIterator[Any]) -> AsyncIterator[AddableDict]
}


@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_prompt_with_llm_and_async_lambda(
mocker: MockerFixture, snapshot: SnapshotAssertion
Expand Down Expand Up @@ -4262,7 +4261,6 @@ def test_with_config_callbacks() -> None:
assert isinstance(result, RunnableBinding)


@pytest.mark.asyncio
async def test_ainvoke_on_returned_runnable() -> None:
"""Verify that a runnable returned by a sync runnable in the async path will
be runthroughaasync path (issue #13407)"""
Expand Down Expand Up @@ -4301,7 +4299,6 @@ def idchain_sync(__input: dict) -> bool:
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"


@pytest.mark.asyncio
async def test_ainvoke_astream_passthrough_assign_trace() -> None:
def idchain_sync(__input: dict) -> bool:
return False
Expand Down

0 comments on commit a2d3042

Please sign in to comment.