Skip to content

Commit

Permalink
fix: window_size set during run instead of construction (#8463)
Browse files Browse the repository at this point in the history
* window_size set during runtime

* revert init and update run with window_size

* improved doc, removed print

* adding release notes

* updating tests

* reverting docstring example

* Update haystack/components/retrievers/sentence_window_retriever.py

Co-authored-by: Silvano Cerza <[email protected]>

* Update haystack/components/retrievers/sentence_window_retriever.py

Co-authored-by: Silvano Cerza <[email protected]>

* Update haystack/components/retrievers/sentence_window_retriever.py

Co-authored-by: Silvano Cerza <[email protected]>

---------

Co-authored-by: David S. Batista <[email protected]>
Co-authored-by: Silvano Cerza <[email protected]>
  • Loading branch information
3 people authored Oct 22, 2024
1 parent 0157459 commit a556e11
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 10 deletions.
16 changes: 11 additions & 5 deletions haystack/components/retrievers/sentence_window_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.document_stores.types import DocumentStore
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(self, document_store: DocumentStore, window_size: int = 3):
:param document_store: The Document Store to retrieve the surrounding documents from.
:param window_size: The number of documents to retrieve before and after the relevant one.
For example, `window_size: 2` fetches 2 preceding and 2 following documents.
For example, `window_size: 2` fetches 2 preceding and 2 following documents.
"""
if window_size < 1:
raise ValueError("The window_size parameter must be greater than 0.")
Expand Down Expand Up @@ -144,14 +144,16 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceWindowRetriever":
return default_from_dict(cls, data)

@component.output_types(context_windows=List[str], context_documents=List[List[Document]])
def run(self, retrieved_documents: List[Document]):
def run(self, retrieved_documents: List[Document], window_size: Optional[int] = None):
"""
Based on the `source_id` and on the `doc.meta['split_id']` get surrounding documents from the document store.
Implements the logic behind the sentence-window technique, retrieving the surrounding documents of a given
document from the document store.
:param retrieved_documents: List of retrieved documents from the previous retriever.
:param window_size: The number of documents to retrieve before and after the relevant one. When provided, this
overrides the `window_size` set in the constructor for this call only; the value stored on the component is not changed.
:returns:
A dictionary with the following keys:
- `context_windows`: A list of strings, where each string represents the concatenated text from the
Expand All @@ -161,6 +163,10 @@ def run(self, retrieved_documents: List[Document]):
`retrieved_documents`.
"""
window_size = window_size or self.window_size

if window_size < 1:
raise ValueError("The window_size parameter must be greater than 0.")

if not all("split_id" in doc.meta for doc in retrieved_documents):
raise ValueError("The retrieved documents must have 'split_id' in the metadata.")
Expand All @@ -173,8 +179,8 @@ def run(self, retrieved_documents: List[Document]):
for doc in retrieved_documents:
source_id = doc.meta["source_id"]
split_id = doc.meta["split_id"]
min_before = min(list(range(split_id - 1, split_id - self.window_size - 1, -1)))
max_after = max(list(range(split_id + 1, split_id + self.window_size + 1, 1)))
min_before = min(list(range(split_id - 1, split_id - window_size - 1, -1)))
max_after = max(list(range(split_id + 1, split_id + window_size + 1, 1)))
context_docs = self.document_store.filter_documents(
{
"operator": "AND",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
enhancements:
- |
The SentenceWindowRetriever now accepts a `window_size` parameter at run time, which overrides the value set in the constructor for that call without modifying it.
39 changes: 34 additions & 5 deletions test/components/retrievers/test_sentence_window_retriever.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pytest

from haystack import Document, DeserializationError, Pipeline
from haystack import DeserializationError, Document, Pipeline
from haystack.components.preprocessors import DocumentSplitter
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.components.retrievers.sentence_window_retriever import SentenceWindowRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.preprocessors import DocumentSplitter


class TestSentenceWindowRetriever:
Expand Down Expand Up @@ -118,9 +118,32 @@ def test_document_without_source_id(self):
retriever = SentenceWindowRetriever(document_store=InMemoryDocumentStore(), window_size=3)
retriever.run(retrieved_documents=docs)

def test_run_invalid_window_size(self):
    """`run()` must reject an invalid `window_size` passed at run time."""
    docs = [Document(content="This is a text with some words. There is a ", meta={"id": "doc_0", "split_id": 0})]
    # Build the component with a valid size outside the `raises` block so that
    # only the call to `run()` can satisfy the expectation.
    retriever = SentenceWindowRetriever(document_store=InMemoryDocumentStore(), window_size=3)
    # A negative value is used on purpose: `run()` resolves the override with
    # `window_size or self.window_size`, so 0 would fall back to the constructor value.
    # Matching the message ensures we hit the window_size check rather than one of
    # the metadata checks, which also raise ValueError.
    with pytest.raises(ValueError, match="window_size parameter must be greater than 0"):
        retriever.run(retrieved_documents=docs, window_size=-1)

def test_constructor_parameter_does_not_change(self):
    """A run-time `window_size` must leave the value given to the constructor untouched."""
    configured_size = 5
    sentence_retriever = SentenceWindowRetriever(InMemoryDocumentStore(), window_size=configured_size)
    assert sentence_retriever.window_size == configured_size

    # Serialized form of a single split document, turned into a Document below.
    serialized_doc = {
        "id": "doc_0",
        "content": "This is a text with some words. There is a ",
        "source_id": "c5d7c632affc486d0cfe7b3c0f4dc1d3896ea720da2b538d6d10b104a3df5f99",
        "page_number": 1,
        "split_id": 0,
        "split_idx_start": 0,
        "_split_overlap": [{"doc_id": "doc_1", "range": (0, 23)}],
    }
    retrieved = [Document.from_dict(serialized_doc)]

    # Run with a different window size, then check the stored attribute again.
    sentence_retriever.run(retrieved_documents=retrieved, window_size=1)
    assert sentence_retriever.window_size == configured_size

@pytest.mark.integration
def test_run_with_pipeline(self):
splitter = DocumentSplitter(split_length=10, split_overlap=5, split_by="word")
splitter = DocumentSplitter(split_length=1, split_overlap=0, split_by="sentence")
text = (
"This is a text with some words. There is a second sentence. And there is also a third sentence. "
"It also contains a fourth sentence. And a fifth sentence. And a sixth sentence. And a seventh sentence"
Expand All @@ -139,11 +162,17 @@ def test_run_with_pipeline(self):
result = pipe.run({"bm25_retriever": {"query": "third"}})

assert result["sentence_window_retriever"]["context_windows"] == [
"some words. There is a second sentence. And there is also a third sentence. It also "
"contains a fourth sentence. And a fifth sentence. And a sixth sentence. And a "
"This is a text with some words. There is a second sentence. And there is also a third sentence. "
"It also contains a fourth sentence. And a fifth sentence."
]
assert len(result["sentence_window_retriever"]["context_documents"][0]) == 5

result = pipe.run({"bm25_retriever": {"query": "third"}, "sentence_window_retriever": {"window_size": 1}})
assert result["sentence_window_retriever"]["context_windows"] == [
" There is a second sentence. And there is also a third sentence. It also contains a fourth sentence."
]
assert len(result["sentence_window_retriever"]["context_documents"][0]) == 3

@pytest.mark.integration
def test_serialization_deserialization_in_pipeline(self):
doc_store = InMemoryDocumentStore()
Expand Down

0 comments on commit a556e11

Please sign in to comment.