Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix pipeline with join node #7510

Merged
merged 22 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion haystack/nodes/file_converter/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def convert(

file_path = Path(file_path)
image = Image.open(file_path)
pages = self._image_to_text(image)
pages = self._image_to_text(image) # type: ignore
if remove_numeric_tables is None:
remove_numeric_tables = self.remove_numeric_tables
if valid_languages is None:
Expand Down
23 changes: 14 additions & 9 deletions haystack/nodes/other/join_docs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from collections import defaultdict
from math import inf
from typing import List, Optional
from typing import List, Optional, Dict, Tuple

from haystack.nodes.other.join import JoinNode
from haystack.schema import Document
Expand Down Expand Up @@ -58,8 +58,13 @@ def __init__(
self.top_k_join = top_k_join
self.sort_by_score = sort_by_score

def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
def run_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
results = [inp["documents"] for inp in inputs]

# Check if all results are non-empty
if all(not res for res in results):
return {"documents": [], "labels": inputs[0].get("labels", None)}, "output_1"

document_map = {doc.id: doc for result in results for doc in result}

if self.join_mode == "concatenate":
Expand Down Expand Up @@ -98,7 +103,7 @@ def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None):

return output, "output_1"

def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
# Join single document lists
if isinstance(inputs[0]["documents"][0], Document):
return self.run(inputs=inputs, top_k_join=top_k_join)
Expand All @@ -117,13 +122,13 @@ def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] =

return output, "output_1"

def _concatenate_results(self, results, document_map):
def _concatenate_results(self, results: List[List[Document]], document_map: Dict) -> Dict[str, float]:
"""
Concatenates multiple document result lists.
Return the documents with the higher score.
"""
list_id = list(document_map.keys())
scores_map = {}
scores_map: Dict[str, float] = {}
for idx in list_id:
tmp = []
for result in results:
Expand All @@ -134,11 +139,11 @@ def _concatenate_results(self, results, document_map):
scores_map.update({idx: item_best_score.score})
return scores_map

def _calculate_comb_sum(self, results):
def _calculate_comb_sum(self, results: List[List[Document]]) -> Dict[str, float]:
"""
Calculates a combination sum by multiplying each score by its weight.
"""
scores_map = defaultdict(int)
scores_map: Dict[str, float] = defaultdict(float)
weights = self.weights if self.weights else [1 / len(results)] * len(results)

for result, weight in zip(results, weights):
Expand All @@ -147,14 +152,14 @@ def _calculate_comb_sum(self, results):

return scores_map

def _calculate_rrf(self, results):
def _calculate_rrf(self, results: List[List[Document]]) -> Dict[str, float]:
"""
Calculates the reciprocal rank fusion. The constant K is set to 61 (60 was suggested by the original paper,
plus 1 as python lists are 0-based and the paper used 1-based ranking).
"""
K = 61

scores_map = defaultdict(int)
scores_map: Dict[str, float] = defaultdict(float)
weights = self.weights if self.weights else [1 / len(results)] * len(results)

# Calculate weighted reciprocal rank fusion score
Expand Down
8 changes: 6 additions & 2 deletions haystack/nodes/other/shaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,11 @@ def run( # type: ignore
meta: Optional[dict] = None,
invocation_context: Optional[Dict[str, Any]] = None,
) -> Tuple[Dict, str]:
invocation_context = invocation_context or {}
if invocation_context is None:
invocation_context = {}
else:
invocation_context = invocation_context.copy()

if query and "query" not in invocation_context.keys():
invocation_context["query"] = query

Expand All @@ -755,7 +759,7 @@ def run( # type: ignore
if labels and "labels" not in invocation_context.keys():
invocation_context["labels"] = labels

if documents != None and "documents" not in invocation_context.keys():
if documents is not None and "documents" not in invocation_context.keys():
invocation_context["documents"] = documents

if meta and "meta" not in invocation_context.keys():
Expand Down
5 changes: 4 additions & 1 deletion haystack/nodes/prompt/prompt_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,10 @@ def _prepare( # type: ignore
"""
Prepare prompt invocation.
"""
invocation_context = invocation_context or {}
if invocation_context is None:
invocation_context = {}
else:
invocation_context = invocation_context.copy()

if query and "query" not in invocation_context:
invocation_context["query"] = query
Expand Down
82 changes: 64 additions & 18 deletions haystack/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,19 +596,27 @@ def run( # type: ignore
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
if query:
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **updated_input}
if query and "query" not in updated_input:
updated_input["query"] = query
if file_paths:
if file_paths and "file_paths" not in updated_input:
updated_input["file_paths"] = file_paths
if labels:
if labels and "labels" not in updated_input:
updated_input["labels"] = labels
if documents:
if documents and "documents" not in updated_input:
updated_input["documents"] = documents
if meta:
if meta and "meta" not in updated_input:
updated_input["meta"] = meta
else:
existing_input["inputs"].append(node_output)
updated_input = existing_input
if "_debug" in node_output.keys():
existing_input["_debug"] = {
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **existing_input}
queue[n] = updated_input
else:
queue[n] = node_output
Expand All @@ -618,6 +626,22 @@ def run( # type: ignore

return node_output

def _combine_node_outputs(self, existing_input: Dict[str, Any], node_output: Dict[str, Any]) -> Dict[str, Any]:
"""
Combines the outputs of two nodes into a single input for a downstream node.
For matching keys first node's (existing_input) value is kept. This is used for join nodes.

:param existing_input: The output of the first node.
:param node_output: The output of the second node.
"""
additional_input = {}
combined = {**node_output, **existing_input}
for key in combined:
# Don't overwrite these keys since they are set in Pipeline.run
if key not in ["inputs", "params", "_debug"]:
additional_input[key] = combined[key]
return additional_input

async def _arun( # noqa: C901,PLR0912 type: ignore
self,
query: Optional[str] = None,
Expand Down Expand Up @@ -734,19 +758,27 @@ async def _arun( # noqa: C901,PLR0912 type: ignore
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
if query:
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **updated_input}
if query and "query" not in updated_input:
updated_input["query"] = query
if file_paths:
if file_paths and "file_paths" not in updated_input:
updated_input["file_paths"] = file_paths
if labels:
if labels and "labels" not in updated_input:
updated_input["labels"] = labels
if documents:
if documents and "documents" not in updated_input:
updated_input["documents"] = documents
if meta:
if meta and "meta" not in updated_input:
updated_input["meta"] = meta
else:
existing_input["inputs"].append(node_output)
updated_input = existing_input
if "_debug" in node_output.keys():
existing_input["_debug"] = {
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
sjrl marked this conversation as resolved.
Show resolved Hide resolved
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **existing_input}
queue[n] = updated_input
else:
queue[n] = node_output
Expand All @@ -756,6 +788,7 @@ async def _arun( # noqa: C901,PLR0912 type: ignore

return node_output

# pylint: disable=too-many-branches
def run_batch( # noqa: C901,PLR0912 type: ignore
self,
queries: Optional[List[str]] = None,
Expand Down Expand Up @@ -896,19 +929,32 @@ def run_batch( # noqa: C901,PLR0912 type: ignore
existing_input = queue[n]
if "inputs" not in existing_input.keys():
updated_input: Dict = {"inputs": [existing_input, node_output], "params": params}
if queries:
if "_debug" in existing_input.keys() or "_debug" in node_output.keys():
updated_input["_debug"] = {
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **updated_input}
if queries and "queries" not in updated_input:
updated_input["queries"] = queries
if file_paths:
if file_paths and "file_paths" not in updated_input:
updated_input["file_paths"] = file_paths
if labels:
if labels and "labels" not in updated_input:
updated_input["labels"] = labels
if documents:
if documents and "documents" not in updated_input:
updated_input["documents"] = documents
if meta:
if meta and "meta" not in updated_input:
updated_input["meta"] = meta
else:
existing_input["inputs"].append(node_output)
updated_input = existing_input
if "_debug" in node_output.keys():
existing_input["_debug"] = {
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **existing_input}
queue[n] = updated_input
else:
queue[n] = node_output
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
---
fixes:
- |
When using a `Pipeline` with a `JoinNode` (e.g. `JoinDocuments`) all information from the previous nodes was lost
other than a few select fields (e.g. `documents`). This was due to the `JoinNode` not properly passing on
the information from the previous nodes. This has been fixed and now all information from the previous nodes is
passed on to the next node in the pipeline.

For example, this is a pipeline that rewrites the `query` during pipeline execution combined with a hybrid retrieval
setup that requires a `JoinDocuments` node. Specifically the first prompt node rewrites the `query` to fix all
spelling errors, and this new `query` is used for retrieval. And now the `JoinDocuments` node will now pass on the
rewritten `query` so it can be used by the `QAPromptNode` node whereas before it would pass on the original query.
```python
from haystack import Pipeline
from haystack.nodes import BM25Retriever, EmbeddingRetriever, PromptNode, Shaper, JoinDocuments, PromptTemplate
from haystack.document_stores import InMemoryDocumentStore

document_store = InMemoryDocumentStore(use_bm25=True)
dicts = [{"content": "The capital of Germany is Berlin."}, {"content": "The capital of France is Paris."}]
document_store.write_documents(dicts)

query_prompt_node = PromptNode(
model_name_or_path="gpt-3.5-turbo",
api_key="",
default_prompt_template=PromptTemplate("You are a spell checker. Given a user query return the same query with all spelling errors fixed.\nUser Query: {query}\nSpell Checked Query:")
)
shaper = Shaper(
func="join_strings",
inputs={"strings": "results"},
outputs=["query"],
)
qa_prompt_node = PromptNode(
model_name_or_path="gpt-3.5-turbo",
api_key="",
default_prompt_template=PromptTemplate("Answer the user query. Query: {query}")
)
sparse_retriever = BM25Retriever(
document_store=document_store,
top_k=2
)
dense_retriever = EmbeddingRetriever(
document_store=document_store,
embedding_model="intfloat/e5-base-v2",
model_format="sentence_transformers",
top_k=2
)
document_store.update_embeddings(dense_retriever)

pipeline = Pipeline()
pipeline.add_node(component=query_prompt_node, name="QueryPromptNode", inputs=["Query"])
pipeline.add_node(component=shaper, name="ListToString", inputs=["QueryPromptNode"])
pipeline.add_node(component=sparse_retriever, name="BM25", inputs=["ListToString"])
pipeline.add_node(component=dense_retriever, name="Embedding", inputs=["ListToString"])
pipeline.add_node(
component=JoinDocuments(join_mode="concatenate"), name="Join", inputs=["BM25", "Embedding"]
)
pipeline.add_node(component=qa_prompt_node, name="QAPromptNode", inputs=["Join"])

out = pipeline.run(query="What is the captial of Grmny?", debug=True)
print(out["invocation_context"])
# Before Fix
# {'query': 'What is the captial of Grmny?', <-- Original Query!!
# 'results': ['The capital of Germany is Berlin.'],
# 'prompts': ['Answer the user query. Query: What is the captial of Grmny?'], <-- Original Query!!
# After Fix
# {'query': 'What is the capital of Germany?', <-- Rewritten Query!!
# 'results': ['The capital of Germany is Berlin.'],
# 'prompts': ['Answer the user query. Query: What is the capital of Germany?'], <-- Rewritten Query!!
```
24 changes: 23 additions & 1 deletion test/nodes/test_join_documents.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest


from haystack import Document
from haystack import Document, Pipeline
from haystack.nodes.other.join_docs import JoinDocuments
from copy import deepcopy

Expand Down Expand Up @@ -149,3 +149,25 @@ def test_joindocuments_rrf_weights():
assert result_none["documents"] == result_even["documents"]
assert result_uneven["documents"] != result_none["documents"]
assert result_uneven["documents"][0].score > result_none["documents"][0].score


@pytest.mark.unit
def test_join_node_empty_documents():
pipe = Pipeline()
join_node = JoinDocuments(join_mode="concatenate")
pipe.add_node(component=join_node, name="Join", inputs=["Query"])

# Test single document lists
output = pipe.run(query="test", documents=[])
assert len(output["documents"]) == 0


@pytest.mark.unit
def test_join_node_none_documents():
pipe = Pipeline()
join_node = JoinDocuments(join_mode="concatenate")
pipe.add_node(component=join_node, name="Join", inputs=["Query"])

# Test single document lists
output = pipe.run(query="test", documents=None)
assert len(output["documents"]) == 0
Loading
Loading