Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix pipeline with join node #7510

Merged
merged 22 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion haystack/nodes/file_converter/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def convert(

file_path = Path(file_path)
image = Image.open(file_path)
pages = self._image_to_text(image)
pages = self._image_to_text(image) # type: ignore
if remove_numeric_tables is None:
remove_numeric_tables = self.remove_numeric_tables
if valid_languages is None:
Expand Down
2 changes: 1 addition & 1 deletion haystack/nodes/prompt/prompt_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _prepare( # type: ignore
"""
Prepare prompt invocation.
"""
invocation_context = invocation_context or {}
invocation_context = invocation_context.copy() or {}
sjrl marked this conversation as resolved.
Show resolved Hide resolved

if query and "query" not in invocation_context:
invocation_context["query"] = query
Expand Down
56 changes: 44 additions & 12 deletions haystack/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,19 +596,27 @@ def run( # type: ignore
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
if query:
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **updated_input}
if query and "query" not in updated_input:
updated_input["query"] = query
if file_paths:
if file_paths and "file_paths" not in updated_input:
updated_input["file_paths"] = file_paths
if labels:
if labels and "labels" not in updated_input:
updated_input["labels"] = labels
if documents:
if documents and "documents" not in updated_input:
updated_input["documents"] = documents
if meta:
if meta and "meta" not in updated_input:
updated_input["meta"] = meta
else:
existing_input["inputs"].append(node_output)
updated_input = existing_input
if "_debug" in node_output.keys():
existing_input["_debug"] = {
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **existing_input}
queue[n] = updated_input
else:
queue[n] = node_output
Expand All @@ -618,6 +626,22 @@ def run( # type: ignore

return node_output

def _combine_node_outputs(self, existing_input: Dict[str, Any], node_output: Dict[str, Any]) -> Dict[str, Any]:
"""
Combines the outputs of two nodes into a single input for a downstream node.
For matching keys first node's (existing_input) value is kept. This is used for join nodes.

:param existing_input: The output of the first node.
:param node_output: The output of the second node.
"""
additional_input = {}
combined = {**node_output, **existing_input}
for key in combined:
# Don't overwrite these keys since they are set in Pipeline.run
if key not in ["inputs", "params", "_debug"]:
additional_input[key] = combined[key]
return additional_input

async def _arun( # noqa: C901,PLR0912 type: ignore
self,
query: Optional[str] = None,
Expand Down Expand Up @@ -734,19 +758,27 @@ async def _arun( # noqa: C901,PLR0912 type: ignore
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
if query:
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **updated_input}
if query and "query" not in updated_input:
updated_input["query"] = query
if file_paths:
if file_paths and "file_paths" not in updated_input:
updated_input["file_paths"] = file_paths
if labels:
if labels and "labels" not in updated_input:
updated_input["labels"] = labels
if documents:
if documents and "documents" not in updated_input:
updated_input["documents"] = documents
if meta:
if meta and "meta" not in updated_input:
updated_input["meta"] = meta
else:
existing_input["inputs"].append(node_output)
updated_input = existing_input
if "_debug" in node_output.keys():
existing_input["_debug"] = {
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
sjrl marked this conversation as resolved.
Show resolved Hide resolved
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **existing_input}
queue[n] = updated_input
else:
queue[n] = node_output
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
Fixes pipelines using join nodes to properly pass on additional key value pairs from nodes prior to the join node to nodes that come after the join node.
anakin87 marked this conversation as resolved.
Show resolved Hide resolved
88 changes: 88 additions & 0 deletions test/pipelines/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from haystack.nodes.base import BaseComponent
from haystack.nodes.retriever.sparse import BM25Retriever
from haystack.nodes.retriever.sparse import FilterRetriever
from haystack.nodes import Shaper
from haystack.pipelines import (
Pipeline,
RootNode,
Expand Down Expand Up @@ -2086,6 +2087,93 @@ def test_fix_to_pipeline_execution_when_join_follows_join():
assert len(documents) == 4 # all four documents should be found


@pytest.mark.unit
def test_pipeline_execution_using_join_preserves_previous_keys():
document_store_1 = InMemoryDocumentStore()
retriever_1 = FilterRetriever(document_store_1, scale_score=True)
dicts_1 = [{"content": "Alpha", "score": 0.552}]
document_store_1.write_documents(dicts_1)

document_store_2 = InMemoryDocumentStore()
retriever_2 = FilterRetriever(document_store_2, scale_score=True)
dicts_2 = [{"content": "Beta", "score": 0.542}]
document_store_2.write_documents(dicts_2)

# Create Shaper to insert "invocation_context" and "test_key" into the node_output
shaper = Shaper(func="rename", inputs={"value": "query"}, outputs=["test_key"])

pipeline = Pipeline()
pipeline.add_node(component=shaper, name="Shaper", inputs=["Query"])
pipeline.add_node(component=retriever_1, name="Retriever1", inputs=["Shaper"])
pipeline.add_node(component=retriever_2, name="Retriever2", inputs=["Shaper"])
pipeline.add_node(
component=JoinDocuments(join_mode="concatenate"), name="Join", inputs=["Retriever1", "Retriever2"]
)
res = pipeline.run(query="Alpha Beta Gamma Delta")
assert set(res.keys()) == {
"documents",
"labels",
"root_node",
"params",
"test_key",
"invocation_context",
"query",
"node_id",
}
assert res["test_key"] == "Alpha Beta Gamma Delta"
assert res["invocation_context"] == {"query": "Alpha Beta Gamma Delta", "test_key": "Alpha Beta Gamma Delta"}
assert len(res["documents"]) == 2


@pytest.mark.unit
def test_pipeline_execution_using_join_preserves_previous_keys_three_streams():
document_store_1 = InMemoryDocumentStore()
retriever_1 = FilterRetriever(document_store_1, scale_score=True)
dicts_1 = [{"content": "Alpha", "score": 0.552}]
document_store_1.write_documents(dicts_1)

document_store_2 = InMemoryDocumentStore()
retriever_2 = FilterRetriever(document_store_2, scale_score=True)
dicts_2 = [{"content": "Beta", "score": 0.542}]
document_store_2.write_documents(dicts_2)

document_store_3 = InMemoryDocumentStore()
retriever_3 = FilterRetriever(document_store_3, scale_score=True)
dicts_3 = [{"content": "Gamma", "score": 0.532}]
document_store_3.write_documents(dicts_3)

# Create Shaper to insert "invocation_context" and "test_key" into the node_output
shaper1 = Shaper(func="rename", inputs={"value": "query"}, outputs=["test_key1"])
shaper2 = Shaper(func="rename", inputs={"value": "query"}, outputs=["test_key2"])

pipeline = Pipeline()
pipeline.add_node(component=shaper1, name="Shaper1", inputs=["Query"])
pipeline.add_node(component=shaper2, name="Shaper2", inputs=["Query"])
pipeline.add_node(component=retriever_3, name="Retriever3", inputs=["Shaper2"])
pipeline.add_node(component=retriever_1, name="Retriever1", inputs=["Shaper1"])
pipeline.add_node(component=retriever_2, name="Retriever2", inputs=["Shaper1"])

pipeline.add_node(
component=JoinDocuments(join_mode="concatenate"), name="Join", inputs=["Retriever3", "Retriever1", "Retriever2"]
)
res = pipeline.run(query="Alpha Beta Gamma Delta")
assert set(res.keys()) == {
"documents",
"labels",
"root_node",
"params",
"test_key1",
"test_key2",
"invocation_context",
"query",
"node_id",
}
assert res["test_key1"] == "Alpha Beta Gamma Delta"
assert res["test_key2"] == "Alpha Beta Gamma Delta"
assert res["invocation_context"] == {"query": "Alpha Beta Gamma Delta", "test_key1": "Alpha Beta Gamma Delta"}
assert len(res["documents"]) == 3


@pytest.mark.unit
def test_update_config_hash():
fake_configs = {
Expand Down
Loading