fix: Fix pipeline with join node (#7510)
* Update Pipeline.run method to pass on more info when using join node

* Add release notes

* Update docs and update _arun method as well

* Added support for unique keys

* Simplify the method

* fix mypy error

* ignore mypy in image converter

* Ignore mypy warning

* Update _debug as well when there are more than two streams being joined.

* Add copy to invocation_context instead of overwriting

* Fix the copy

* Add copy to Shaper as well

* Add another test to make sure query stays changed

* Updated test to make sure that it does only pass with the new changes

* Add to run_batch as well

* Fix type annotations

* Make join nodes work when no docs are provided

* Ignore pylint error

* Update haystack/nodes/other/join_docs.py

Co-authored-by: tstadel <[email protected]>

* Expand on release note

* Fix test

---------

Co-authored-by: Stefano Fiorucci <[email protected]>
Co-authored-by: tstadel <[email protected]>
3 people authored and vblagoje committed Apr 23, 2024
1 parent c55267a commit 23827c0
Showing 8 changed files with 287 additions and 33 deletions.
2 changes: 1 addition & 1 deletion haystack/nodes/file_converter/image.py
@@ -119,7 +119,7 @@ def convert(

file_path = Path(file_path)
image = Image.open(file_path)
pages = self._image_to_text(image)
pages = self._image_to_text(image) # type: ignore
if remove_numeric_tables is None:
remove_numeric_tables = self.remove_numeric_tables
if valid_languages is None:
23 changes: 14 additions & 9 deletions haystack/nodes/other/join_docs.py
@@ -1,7 +1,7 @@
import logging
from collections import defaultdict
from math import inf
from typing import List, Optional
from typing import List, Optional, Dict, Tuple

from haystack.nodes.other.join import JoinNode
from haystack.schema import Document
@@ -58,8 +58,13 @@ def __init__(
self.top_k_join = top_k_join
self.sort_by_score = sort_by_score

def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
def run_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
results = [inp["documents"] for inp in inputs]

# If all result lists are empty, skip joining and return an empty result
if all(not res for res in results):
return {"documents": [], "labels": inputs[0].get("labels", None)}, "output_1"

document_map = {doc.id: doc for result in results for doc in result}

if self.join_mode == "concatenate":
@@ -98,7 +103,7 @@ def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None):

return output, "output_1"

def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
# Join single document lists
if isinstance(inputs[0]["documents"][0], Document):
return self.run(inputs=inputs, top_k_join=top_k_join)
@@ -117,13 +122,13 @@ def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] =

return output, "output_1"

def _concatenate_results(self, results, document_map):
def _concatenate_results(self, results: List[List[Document]], document_map: Dict) -> Dict[str, float]:
"""
Concatenates multiple document result lists.
For documents that appear in more than one list, the higher score is kept.
"""
list_id = list(document_map.keys())
scores_map = {}
scores_map: Dict[str, float] = {}
for idx in list_id:
tmp = []
for result in results:
@@ -134,11 +139,11 @@ def _concatenate_results(self, results, document_map):
scores_map.update({idx: item_best_score.score})
return scores_map

def _calculate_comb_sum(self, results):
def _calculate_comb_sum(self, results: List[List[Document]]) -> Dict[str, float]:
"""
Calculates a combination sum by multiplying each score by its weight.
"""
scores_map = defaultdict(int)
scores_map: Dict[str, float] = defaultdict(float)
weights = self.weights if self.weights else [1 / len(results)] * len(results)

for result, weight in zip(results, weights):
@@ -147,14 +152,14 @@ def _calculate_comb_sum(self, results):

return scores_map

def _calculate_rrf(self, results):
def _calculate_rrf(self, results: List[List[Document]]) -> Dict[str, float]:
"""
Calculates the reciprocal rank fusion. The constant K is set to 61 (60 was suggested by the original paper,
plus 1 as python lists are 0-based and the paper used 1-based ranking).
"""
K = 61

scores_map = defaultdict(int)
scores_map: Dict[str, float] = defaultdict(float)
weights = self.weights if self.weights else [1 / len(results)] * len(results)

# Calculate weighted reciprocal rank fusion score
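For context, the `_calculate_rrf` method above computes a weighted reciprocal rank fusion score per document. A minimal standalone sketch of that calculation (illustrative only: the `rrf_scores` helper and the toy document ids are made up, and it operates on plain ids rather than `Document` objects):

```python
from collections import defaultdict
from typing import Dict, List, Optional


def rrf_scores(results: List[List[str]], weights: Optional[List[float]] = None, k: int = 61) -> Dict[str, float]:
    """Weighted reciprocal rank fusion over ranked lists of document ids (best first)."""
    weights = weights or [1 / len(results)] * len(results)
    scores: Dict[str, float] = defaultdict(float)
    for ranked_ids, weight in zip(results, weights):
        for rank, doc_id in enumerate(ranked_ids):
            scores[doc_id] += weight * (1 / (k + rank))  # k = 61 because rank is 0-based
    return dict(scores)


# Two retrievers with the default equal weights (0.5 each):
print(rrf_scores([["a", "b"], ["b", "c"]]))
# "b" appears in both lists and accumulates 0.5/62 + 0.5/61,
# so it outranks "a" (0.5/61) and "c" (0.5/62).
```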
8 changes: 6 additions & 2 deletions haystack/nodes/other/shaper.py
@@ -745,7 +745,11 @@ def run( # type: ignore
meta: Optional[dict] = None,
invocation_context: Optional[Dict[str, Any]] = None,
) -> Tuple[Dict, str]:
invocation_context = invocation_context or {}
if invocation_context is None:
invocation_context = {}
else:
invocation_context = invocation_context.copy()

if query and "query" not in invocation_context.keys():
invocation_context["query"] = query

@@ -755,7 +759,7 @@
if labels and "labels" not in invocation_context.keys():
invocation_context["labels"] = labels

if documents != None and "documents" not in invocation_context.keys():
if documents is not None and "documents" not in invocation_context.keys():
invocation_context["documents"] = documents

if meta and "meta" not in invocation_context.keys():
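The same `invocation_context or {}` → `.copy()` change is applied to `PromptNode._prepare` below. A rough sketch of why the copy matters (the `run_shared`/`run_copied` helpers are illustrative, not the actual Shaper logic): writing into the dict the pipeline passed in leaks the node's additions back into shared state, whereas working on a copy leaves the caller's `invocation_context` untouched.

```python
def run_shared(invocation_context=None):
    # Old pattern: falls back to {} but otherwise reuses (and mutates) the caller's dict.
    invocation_context = invocation_context or {}
    invocation_context["results"] = ["some output"]
    return invocation_context


def run_copied(invocation_context=None):
    # New pattern: work on a copy so the caller's dict stays unchanged.
    invocation_context = {} if invocation_context is None else invocation_context.copy()
    invocation_context["results"] = ["some output"]
    return invocation_context


upstream = {"query": "rewritten query"}
run_shared(upstream)
print("results" in upstream)   # True  -- upstream state was modified in place
upstream = {"query": "rewritten query"}
run_copied(upstream)
print("results" in upstream)   # False -- upstream state is untouched
```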
5 changes: 4 additions & 1 deletion haystack/nodes/prompt/prompt_node.py
@@ -247,7 +247,10 @@ def _prepare( # type: ignore
"""
Prepare prompt invocation.
"""
invocation_context = invocation_context or {}
if invocation_context is None:
invocation_context = {}
else:
invocation_context = invocation_context.copy()

if query and "query" not in invocation_context:
invocation_context["query"] = query
82 changes: 64 additions & 18 deletions haystack/pipelines/base.py
@@ -596,19 +596,27 @@ def run( # type: ignore
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
if query:
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **updated_input}
if query and "query" not in updated_input:
updated_input["query"] = query
if file_paths:
if file_paths and "file_paths" not in updated_input:
updated_input["file_paths"] = file_paths
if labels:
if labels and "labels" not in updated_input:
updated_input["labels"] = labels
if documents:
if documents and "documents" not in updated_input:
updated_input["documents"] = documents
if meta:
if meta and "meta" not in updated_input:
updated_input["meta"] = meta
else:
existing_input["inputs"].append(node_output)
updated_input = existing_input
if "_debug" in node_output.keys():
existing_input["_debug"] = {
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **existing_input}
queue[n] = updated_input
else:
queue[n] = node_output
@@ -618,6 +626,22 @@

return node_output

def _combine_node_outputs(self, existing_input: Dict[str, Any], node_output: Dict[str, Any]) -> Dict[str, Any]:
"""
Combines the outputs of two nodes into a single input for a downstream node.
For matching keys, the first node's (existing_input) value is kept. This is used for join nodes.
:param existing_input: The output of the first node.
:param node_output: The output of the second node.
"""
additional_input = {}
combined = {**node_output, **existing_input}
for key in combined:
# Don't overwrite these keys since they are set in Pipeline.run
if key not in ["inputs", "params", "_debug"]:
additional_input[key] = combined[key]
return additional_input

async def _arun( # noqa: C901,PLR0912 type: ignore
self,
query: Optional[str] = None,
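The merge in `_combine_node_outputs` above relies on dict-unpacking order: keys from `node_output` are written first and keys from `existing_input` overwrite them, so on a conflict the value from the first stream wins, and the reserved pipeline keys are filtered out afterwards. A small sketch with made-up values:

```python
existing_input = {"query": "rewritten query", "documents": ["doc_1"], "params": {"top_k": 2}}
node_output = {"query": "original query", "results": ["an answer"], "_debug": {}}

combined = {**node_output, **existing_input}  # existing_input wins on shared keys
additional_input = {k: v for k, v in combined.items() if k not in ["inputs", "params", "_debug"]}
print(additional_input)
# {'query': 'rewritten query', 'results': ['an answer'], 'documents': ['doc_1']}
```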
@@ -734,19 +758,27 @@ async def _arun( # noqa: C901,PLR0912 type: ignore
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
if query:
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **updated_input}
if query and "query" not in updated_input:
updated_input["query"] = query
if file_paths:
if file_paths and "file_paths" not in updated_input:
updated_input["file_paths"] = file_paths
if labels:
if labels and "labels" not in updated_input:
updated_input["labels"] = labels
if documents:
if documents and "documents" not in updated_input:
updated_input["documents"] = documents
if meta:
if meta and "meta" not in updated_input:
updated_input["meta"] = meta
else:
existing_input["inputs"].append(node_output)
updated_input = existing_input
if "_debug" in node_output.keys():
existing_input["_debug"] = {
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **existing_input}
queue[n] = updated_input
else:
queue[n] = node_output
@@ -756,6 +788,7 @@

return node_output

# pylint: disable=too-many-branches
def run_batch( # noqa: C901,PLR0912 type: ignore
self,
queries: Optional[List[str]] = None,
@@ -896,19 +929,32 @@ def run_batch( # noqa: C901,PLR0912 type: ignore
existing_input = queue[n]
if "inputs" not in existing_input.keys():
updated_input: Dict = {"inputs": [existing_input, node_output], "params": params}
if queries:
if "_debug" in existing_input.keys() or "_debug" in node_output.keys():
updated_input["_debug"] = {
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **updated_input}
if queries and "queries" not in updated_input:
updated_input["queries"] = queries
if file_paths:
if file_paths and "file_paths" not in updated_input:
updated_input["file_paths"] = file_paths
if labels:
if labels and "labels" not in updated_input:
updated_input["labels"] = labels
if documents:
if documents and "documents" not in updated_input:
updated_input["documents"] = documents
if meta:
if meta and "meta" not in updated_input:
updated_input["meta"] = meta
else:
existing_input["inputs"].append(node_output)
updated_input = existing_input
if "_debug" in node_output.keys():
existing_input["_debug"] = {
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **existing_input}
queue[n] = updated_input
else:
queue[n] = node_output
@@ -0,0 +1,69 @@
---
fixes:
- |
When using a `Pipeline` with a `JoinNode` (e.g. `JoinDocuments`), all information from the previous nodes was
lost except for a few select fields (e.g. `documents`), because the join node did not properly pass on its
inputs. This has been fixed: all information from the previous nodes is now passed on to the next node in the
pipeline.
For example, the pipeline below rewrites the `query` during execution and uses a hybrid retrieval setup that
requires a `JoinDocuments` node. The first prompt node rewrites the `query` to fix all spelling errors, and the
corrected `query` is used for retrieval. With this fix, the `JoinDocuments` node passes the rewritten `query` on
to the `QAPromptNode`, whereas before it would pass on the original query.
```python
from haystack import Pipeline
from haystack.nodes import BM25Retriever, EmbeddingRetriever, PromptNode, Shaper, JoinDocuments, PromptTemplate
from haystack.document_stores import InMemoryDocumentStore
document_store = InMemoryDocumentStore(use_bm25=True)
dicts = [{"content": "The capital of Germany is Berlin."}, {"content": "The capital of France is Paris."}]
document_store.write_documents(dicts)
query_prompt_node = PromptNode(
model_name_or_path="gpt-3.5-turbo",
api_key="",
default_prompt_template=PromptTemplate("You are a spell checker. Given a user query return the same query with all spelling errors fixed.\nUser Query: {query}\nSpell Checked Query:")
)
shaper = Shaper(
func="join_strings",
inputs={"strings": "results"},
outputs=["query"],
)
qa_prompt_node = PromptNode(
model_name_or_path="gpt-3.5-turbo",
api_key="",
default_prompt_template=PromptTemplate("Answer the user query. Query: {query}")
)
sparse_retriever = BM25Retriever(
document_store=document_store,
top_k=2
)
dense_retriever = EmbeddingRetriever(
document_store=document_store,
embedding_model="intfloat/e5-base-v2",
model_format="sentence_transformers",
top_k=2
)
document_store.update_embeddings(dense_retriever)
pipeline = Pipeline()
pipeline.add_node(component=query_prompt_node, name="QueryPromptNode", inputs=["Query"])
pipeline.add_node(component=shaper, name="ListToString", inputs=["QueryPromptNode"])
pipeline.add_node(component=sparse_retriever, name="BM25", inputs=["ListToString"])
pipeline.add_node(component=dense_retriever, name="Embedding", inputs=["ListToString"])
pipeline.add_node(
component=JoinDocuments(join_mode="concatenate"), name="Join", inputs=["BM25", "Embedding"]
)
pipeline.add_node(component=qa_prompt_node, name="QAPromptNode", inputs=["Join"])
out = pipeline.run(query="What is the captial of Grmny?", debug=True)
print(out["invocation_context"])
# Before Fix
# {'query': 'What is the captial of Grmny?', <-- Original Query!!
# 'results': ['The capital of Germany is Berlin.'],
# 'prompts': ['Answer the user query. Query: What is the captial of Grmny?'], <-- Original Query!!
# After Fix
# {'query': 'What is the capital of Germany?', <-- Rewritten Query!!
# 'results': ['The capital of Germany is Berlin.'],
# 'prompts': ['Answer the user query. Query: What is the capital of Germany?'], <-- Rewritten Query!!
```
24 changes: 23 additions & 1 deletion test/nodes/test_join_documents.py
@@ -1,7 +1,7 @@
import pytest


from haystack import Document
from haystack import Document, Pipeline
from haystack.nodes.other.join_docs import JoinDocuments
from copy import deepcopy

@@ -149,3 +149,25 @@ def test_joindocuments_rrf_weights():
assert result_none["documents"] == result_even["documents"]
assert result_uneven["documents"] != result_none["documents"]
assert result_uneven["documents"][0].score > result_none["documents"][0].score


@pytest.mark.unit
def test_join_node_empty_documents():
pipe = Pipeline()
join_node = JoinDocuments(join_mode="concatenate")
pipe.add_node(component=join_node, name="Join", inputs=["Query"])

# Joining an empty document list should yield no documents
output = pipe.run(query="test", documents=[])
assert len(output["documents"]) == 0


@pytest.mark.unit
def test_join_node_none_documents():
pipe = Pipeline()
join_node = JoinDocuments(join_mode="concatenate")
pipe.add_node(component=join_node, name="Join", inputs=["Query"])

# Joining with documents=None should yield no documents
output = pipe.run(query="test", documents=None)
assert len(output["documents"]) == 0