Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix pipeline with join node #7510

Merged
merged 22 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion haystack/nodes/file_converter/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def convert(

file_path = Path(file_path)
image = Image.open(file_path)
pages = self._image_to_text(image)
pages = self._image_to_text(image) # type: ignore
if remove_numeric_tables is None:
remove_numeric_tables = self.remove_numeric_tables
if valid_languages is None:
Expand Down
27 changes: 17 additions & 10 deletions haystack/nodes/other/join_docs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from collections import defaultdict
from math import inf
from typing import List, Optional
from typing import List, Optional, Dict, Tuple

from haystack.nodes.other.join import JoinNode
from haystack.schema import Document
Expand Down Expand Up @@ -58,8 +58,13 @@ def __init__(
self.top_k_join = top_k_join
self.sort_by_score = sort_by_score

def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
def run_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
results = [inp["documents"] for inp in inputs]

# Check if all results are non-empty
if all(res is None for res in results) or all(res == [] for res in results):
sjrl marked this conversation as resolved.
Show resolved Hide resolved
return {"documents": [], "labels": inputs[0].get("labels", None)}, "output_1"

document_map = {doc.id: doc for result in results for doc in result}

if self.join_mode == "concatenate":
Expand Down Expand Up @@ -98,9 +103,11 @@ def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None):

return output, "output_1"

def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
# Join single document lists
if isinstance(inputs[0]["documents"][0], Document):
if inputs[0]["documents"] is None or inputs[0]["documents"] == []:
return {"documents": [], "labels": inputs[0].get("labels", None)}, "output_1"
sjrl marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(inputs[0]["documents"][0], Document):
return self.run(inputs=inputs, top_k_join=top_k_join)
# Join lists of document lists
else:
Expand All @@ -117,13 +124,13 @@ def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] =

return output, "output_1"

def _concatenate_results(self, results, document_map):
def _concatenate_results(self, results: List[List[Document]], document_map: Dict) -> Dict[str, float]:
"""
Concatenates multiple document result lists.
Return the documents with the higher score.
"""
list_id = list(document_map.keys())
scores_map = {}
scores_map: Dict[str, float] = {}
for idx in list_id:
tmp = []
for result in results:
Expand All @@ -134,11 +141,11 @@ def _concatenate_results(self, results, document_map):
scores_map.update({idx: item_best_score.score})
return scores_map

def _calculate_comb_sum(self, results):
def _calculate_comb_sum(self, results: List[List[Document]]) -> Dict[str, float]:
"""
Calculates a combination sum by multiplying each score by its weight.
"""
scores_map = defaultdict(int)
scores_map: Dict[str, float] = defaultdict(float)
weights = self.weights if self.weights else [1 / len(results)] * len(results)

for result, weight in zip(results, weights):
Expand All @@ -147,14 +154,14 @@ def _calculate_comb_sum(self, results):

return scores_map

def _calculate_rrf(self, results):
def _calculate_rrf(self, results: List[List[Document]]) -> Dict[str, float]:
"""
Calculates the reciprocal rank fusion. The constant K is set to 61 (60 was suggested by the original paper,
plus 1 as python lists are 0-based and the paper used 1-based ranking).
"""
K = 61

scores_map = defaultdict(int)
scores_map: Dict[str, float] = defaultdict(float)
weights = self.weights if self.weights else [1 / len(results)] * len(results)

# Calculate weighted reciprocal rank fusion score
Expand Down
8 changes: 6 additions & 2 deletions haystack/nodes/other/shaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,11 @@ def run( # type: ignore
meta: Optional[dict] = None,
invocation_context: Optional[Dict[str, Any]] = None,
) -> Tuple[Dict, str]:
invocation_context = invocation_context or {}
if invocation_context is None:
invocation_context = {}
else:
invocation_context = invocation_context.copy()

if query and "query" not in invocation_context.keys():
invocation_context["query"] = query

Expand All @@ -755,7 +759,7 @@ def run( # type: ignore
if labels and "labels" not in invocation_context.keys():
invocation_context["labels"] = labels

if documents != None and "documents" not in invocation_context.keys():
if documents is not None and "documents" not in invocation_context.keys():
invocation_context["documents"] = documents

if meta and "meta" not in invocation_context.keys():
Expand Down
5 changes: 4 additions & 1 deletion haystack/nodes/prompt/prompt_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,10 @@ def _prepare( # type: ignore
"""
Prepare prompt invocation.
"""
invocation_context = invocation_context or {}
if invocation_context is None:
invocation_context = {}
else:
invocation_context = invocation_context.copy()

if query and "query" not in invocation_context:
invocation_context["query"] = query
Expand Down
82 changes: 64 additions & 18 deletions haystack/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,19 +596,27 @@ def run( # type: ignore
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
if query:
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **updated_input}
if query and "query" not in updated_input:
updated_input["query"] = query
if file_paths:
if file_paths and "file_paths" not in updated_input:
updated_input["file_paths"] = file_paths
if labels:
if labels and "labels" not in updated_input:
updated_input["labels"] = labels
if documents:
if documents and "documents" not in updated_input:
updated_input["documents"] = documents
if meta:
if meta and "meta" not in updated_input:
updated_input["meta"] = meta
else:
existing_input["inputs"].append(node_output)
updated_input = existing_input
if "_debug" in node_output.keys():
existing_input["_debug"] = {
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **existing_input}
queue[n] = updated_input
else:
queue[n] = node_output
Expand All @@ -618,6 +626,22 @@ def run( # type: ignore

return node_output

def _combine_node_outputs(self, existing_input: Dict[str, Any], node_output: Dict[str, Any]) -> Dict[str, Any]:
"""
Combines the outputs of two nodes into a single input for a downstream node.
For matching keys first node's (existing_input) value is kept. This is used for join nodes.
:param existing_input: The output of the first node.
:param node_output: The output of the second node.
"""
additional_input = {}
combined = {**node_output, **existing_input}
for key in combined:
# Don't overwrite these keys since they are set in Pipeline.run
if key not in ["inputs", "params", "_debug"]:
additional_input[key] = combined[key]
return additional_input

async def _arun( # noqa: C901,PLR0912 type: ignore
self,
query: Optional[str] = None,
Expand Down Expand Up @@ -734,19 +758,27 @@ async def _arun( # noqa: C901,PLR0912 type: ignore
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
if query:
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **updated_input}
if query and "query" not in updated_input:
updated_input["query"] = query
if file_paths:
if file_paths and "file_paths" not in updated_input:
updated_input["file_paths"] = file_paths
if labels:
if labels and "labels" not in updated_input:
updated_input["labels"] = labels
if documents:
if documents and "documents" not in updated_input:
updated_input["documents"] = documents
if meta:
if meta and "meta" not in updated_input:
updated_input["meta"] = meta
else:
existing_input["inputs"].append(node_output)
updated_input = existing_input
if "_debug" in node_output.keys():
existing_input["_debug"] = {
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
sjrl marked this conversation as resolved.
Show resolved Hide resolved
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **existing_input}
queue[n] = updated_input
else:
queue[n] = node_output
Expand All @@ -756,6 +788,7 @@ async def _arun( # noqa: C901,PLR0912 type: ignore

return node_output

# pylint: disable=too-many-branches
def run_batch( # noqa: C901,PLR0912 type: ignore
self,
queries: Optional[List[str]] = None,
Expand Down Expand Up @@ -896,19 +929,32 @@ def run_batch( # noqa: C901,PLR0912 type: ignore
existing_input = queue[n]
if "inputs" not in existing_input.keys():
updated_input: Dict = {"inputs": [existing_input, node_output], "params": params}
if queries:
if "_debug" in existing_input.keys() or "_debug" in node_output.keys():
updated_input["_debug"] = {
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **updated_input}
if queries and "queries" not in updated_input:
updated_input["queries"] = queries
if file_paths:
if file_paths and "file_paths" not in updated_input:
updated_input["file_paths"] = file_paths
if labels:
if labels and "labels" not in updated_input:
updated_input["labels"] = labels
if documents:
if documents and "documents" not in updated_input:
updated_input["documents"] = documents
if meta:
if meta and "meta" not in updated_input:
updated_input["meta"] = meta
else:
existing_input["inputs"].append(node_output)
updated_input = existing_input
if "_debug" in node_output.keys():
existing_input["_debug"] = {
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **existing_input}
queue[n] = updated_input
else:
queue[n] = node_output
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
Fixed pipelines that use join nodes: additional key-value pairs produced by nodes before the join node are now correctly passed on to the nodes that come after it.
anakin87 marked this conversation as resolved.
Show resolved Hide resolved
32 changes: 31 additions & 1 deletion test/nodes/test_join_documents.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest


from haystack import Document
from haystack import Document, Pipeline
from haystack.nodes.other.join_docs import JoinDocuments
from copy import deepcopy

Expand Down Expand Up @@ -149,3 +149,33 @@ def test_joindocuments_rrf_weights():
assert result_none["documents"] == result_even["documents"]
assert result_uneven["documents"] != result_none["documents"]
assert result_uneven["documents"][0].score > result_none["documents"][0].score


@pytest.mark.unit
def test_join_node_empty_documents():
    """A concatenate JoinDocuments node yields no documents when given empty inputs."""
    joiner = JoinDocuments(join_mode="concatenate")
    pipeline = Pipeline()
    pipeline.add_node(component=joiner, name="Join", inputs=["Query"])

    # Single document list: run through the pipeline
    result = pipeline.run(query="test", documents=[])
    assert len(result["documents"]) == 0

    # Lists of document lists: call the node's batch API directly
    batch_result = joiner.run_batch(queries=["test"], documents=[])
    assert len(batch_result[0]["documents"]) == 0


@pytest.mark.unit
def test_join_node_none_documents():
    """A concatenate JoinDocuments node yields no documents when inputs are None."""
    joiner = JoinDocuments(join_mode="concatenate")
    pipeline = Pipeline()
    pipeline.add_node(component=joiner, name="Join", inputs=["Query"])

    # Single document list: run through the pipeline
    result = pipeline.run(query="test", documents=None)
    assert len(result["documents"]) == 0

    # Lists of document lists: call the node's batch API directly
    batch_result = joiner.run_batch(queries=["test"], documents=None)
    assert len(batch_result[0]["documents"]) == 0
Loading
Loading