diff --git a/haystack/pipelines/base.py b/haystack/pipelines/base.py index 538d207500..14677fd20b 100644 --- a/haystack/pipelines/base.py +++ b/haystack/pipelines/base.py @@ -610,7 +610,6 @@ def run( # type: ignore updated_input["meta"] = meta else: existing_input["inputs"].append(node_output) - # TODO This doesn't have an effect until we also pass on keys that only occur once additional_input = self._combine_node_outputs(existing_input, node_output) updated_input = {**additional_input, **existing_input} queue[n] = updated_input @@ -630,13 +629,24 @@ def _combine_node_outputs(self, existing_input: Dict[str, Any], node_output: Dic :param node_output: The output of the second node. """ additional_input = {} - # Pass keys that appear in both inputs that have the same values + # TODO Should we support overwriting keys that exist in both? --> first node's value is kept + # Add shared items from existing_input and node_output that have matching values shared_items = { k: existing_input[k] for k in existing_input if k in node_output and existing_input[k] == node_output[k] } for key in shared_items: - if key != "inputs" or key != "params" or key != "_debug": + if key not in ["inputs", "params", "_debug"]: additional_input[key] = shared_items[key] + unique_existing_input = {k: v for k, v in existing_input.items() if k not in shared_items} + # Add unique keys from existing_input + for key in unique_existing_input: + if key not in ["inputs", "params", "_debug"]: + additional_input[key] = unique_existing_input[key] + # Add unique keys from node_output + unique_node_output = {k: v for k, v in node_output.items() if k not in shared_items} + for key in unique_node_output: + if key not in ["inputs", "params", "_debug"]: + additional_input[key] = unique_node_output[key] return additional_input async def _arun( # noqa: C901,PLR0912 type: ignore diff --git a/test/pipelines/test_pipeline.py b/test/pipelines/test_pipeline.py index a454891534..d8239bfefe 100644 --- a/test/pipelines/test_pipeline.py +++ b/test/pipelines/test_pipeline.py @@ -2143,15 +2143,18 @@ def test_pipeline_execution_using_join_preserves_previous_keys_three_streams(): document_store_3.write_documents(dicts_3) # Create Shaper to insert "invocation_context" and "test_key" into the node_output - shaper = Shaper(func="rename", inputs={"value": "query"}, outputs=["test_key"]) + shaper1 = Shaper(func="rename", inputs={"value": "query"}, outputs=["test_key1"]) + shaper2 = Shaper(func="rename", inputs={"value": "query"}, outputs=["test_key2"]) pipeline = Pipeline() - pipeline.add_node(component=shaper, name="Shaper", inputs=["Query"]) - pipeline.add_node(component=retriever_1, name="Retriever1", inputs=["Shaper"]) - pipeline.add_node(component=retriever_2, name="Retriever2", inputs=["Shaper"]) - pipeline.add_node(component=retriever_3, name="Retriever3", inputs=["Shaper"]) + pipeline.add_node(component=shaper1, name="Shaper1", inputs=["Query"]) + pipeline.add_node(component=shaper2, name="Shaper2", inputs=["Query"]) + pipeline.add_node(component=retriever_3, name="Retriever3", inputs=["Shaper2"]) + pipeline.add_node(component=retriever_1, name="Retriever1", inputs=["Shaper1"]) + pipeline.add_node(component=retriever_2, name="Retriever2", inputs=["Shaper1"]) + pipeline.add_node( - component=JoinDocuments(join_mode="concatenate"), name="Join", inputs=["Retriever1", "Retriever2", "Retriever3"] + component=JoinDocuments(join_mode="concatenate"), name="Join", inputs=["Retriever3", "Retriever1", "Retriever2"] ) res = pipeline.run(query="Alpha Beta Gamma Delta") assert set(res.keys()) == { @@ -2159,13 +2162,15 @@ def test_pipeline_execution_using_join_preserves_previous_keys_three_streams(): "labels", "root_node", "params", - "test_key", + "test_key1", + "test_key2", "invocation_context", "query", "node_id", } - assert res["test_key"] == "Alpha Beta Gamma Delta" - assert res["invocation_context"] == {"query": "Alpha Beta Gamma Delta", "test_key": "Alpha Beta Gamma Delta"} + assert res["test_key1"] == "Alpha Beta Gamma Delta" + assert res["test_key2"] == "Alpha Beta Gamma Delta" + assert res["invocation_context"] == {"query": "Alpha Beta Gamma Delta", "test_key1": "Alpha Beta Gamma Delta"} assert len(res["documents"]) == 3