-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: Fix pipeline with join node (#7510)
* Update Pipeline.run method to pass on more info when using join node * Add release notes * Update docs and update _arun method as well * Added support for unique keys * Simplify the method * fix mypy error * ignore mypy in image converter * Ignore mypy warning * Update _debug as well when there are more than two streams being joined. * Add copy to invocation_context instead of overwriting * Fix the copy * Add copy to Shaper as well * Add another test to make sure query stays changed * Updated test to make sure that it does only pass with the new changes * Add to run_batch as well * Fix type annotations * Make join nodes work when no docs are provided * Ignore pylint error * Update haystack/nodes/other/join_docs.py Co-authored-by: tstadel <[email protected]> * Expand on release note * Fix test --------- Co-authored-by: Stefano Fiorucci <[email protected]> Co-authored-by: tstadel <[email protected]>
- Loading branch information
1 parent
c1bbd21
commit 08007df
Showing
8 changed files
with
287 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
69 changes: 69 additions & 0 deletions
69
releasenotes/notes/fix-pipeline-with-join-node-5f23a426cd4d88d9.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
--- | ||
fixes: | ||
- | | ||
When using a `Pipeline` with a `JoinNode` (e.g. `JoinDocuments`) all information from the previous nodes was lost | ||
other than a few select fields (e.g. `documents`). This was due to the `JoinNode` not properly passing on | ||
the information from the previous nodes. This has been fixed and now all information from the previous nodes is | ||
passed on to the next node in the pipeline. | ||
For example, this is a pipeline that rewrites the `query` during pipeline execution combined with a hybrid retrieval | ||
setup that requires a `JoinDocuments` node. Specifically the first prompt node rewrites the `query` to fix all | ||
spelling errors, and this new `query` is used for retrieval. And now the `JoinDocuments` node will now pass on the | ||
rewritten `query` so it can be used by the `QAPromptNode` node whereas before it would pass on the original query. | ||
```python | ||
from haystack import Pipeline | ||
from haystack.nodes import BM25Retriever, EmbeddingRetriever, PromptNode, Shaper, JoinDocuments, PromptTemplate | ||
from haystack.document_stores import InMemoryDocumentStore | ||
document_store = InMemoryDocumentStore(use_bm25=True) | ||
dicts = [{"content": "The capital of Germany is Berlin."}, {"content": "The capital of France is Paris."}] | ||
document_store.write_documents(dicts) | ||
query_prompt_node = PromptNode( | ||
model_name_or_path="gpt-3.5-turbo", | ||
api_key="", | ||
default_prompt_template=PromptTemplate("You are a spell checker. Given a user query return the same query with all spelling errors fixed.\nUser Query: {query}\nSpell Checked Query:") | ||
) | ||
shaper = Shaper( | ||
func="join_strings", | ||
inputs={"strings": "results"}, | ||
outputs=["query"], | ||
) | ||
qa_prompt_node = PromptNode( | ||
model_name_or_path="gpt-3.5-turbo", | ||
api_key="", | ||
default_prompt_template=PromptTemplate("Answer the user query. Query: {query}") | ||
) | ||
sparse_retriever = BM25Retriever( | ||
document_store=document_store, | ||
top_k=2 | ||
) | ||
dense_retriever = EmbeddingRetriever( | ||
document_store=document_store, | ||
embedding_model="intfloat/e5-base-v2", | ||
model_format="sentence_transformers", | ||
top_k=2 | ||
) | ||
document_store.update_embeddings(dense_retriever) | ||
pipeline = Pipeline() | ||
pipeline.add_node(component=query_prompt_node, name="QueryPromptNode", inputs=["Query"]) | ||
pipeline.add_node(component=shaper, name="ListToString", inputs=["QueryPromptNode"]) | ||
pipeline.add_node(component=sparse_retriever, name="BM25", inputs=["ListToString"]) | ||
pipeline.add_node(component=dense_retriever, name="Embedding", inputs=["ListToString"]) | ||
pipeline.add_node( | ||
component=JoinDocuments(join_mode="concatenate"), name="Join", inputs=["BM25", "Embedding"] | ||
) | ||
pipeline.add_node(component=qa_prompt_node, name="QAPromptNode", inputs=["Join"]) | ||
out = pipeline.run(query="What is the captial of Grmny?", debug=True) | ||
print(out["invocation_context"]) | ||
# Before Fix | ||
# {'query': 'What is the captial of Grmny?', <-- Original Query!! | ||
# 'results': ['The capital of Germany is Berlin.'], | ||
# 'prompts': ['Answer the user query. Query: What is the captial of Grmny?'], <-- Original Query!! | ||
# After Fix | ||
# {'query': 'What is the capital of Germany?', <-- Rewritten Query!! | ||
# 'results': ['The capital of Germany is Berlin.'], | ||
# 'prompts': ['Answer the user query. Query: What is the capital of Germany?'], <-- Rewritten Query!! | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.