diff --git a/haystack/components/retrievers/in_memory/bm25_retriever.py b/haystack/components/retrievers/in_memory/bm25_retriever.py index ac3f8486f8..e2edfeb578 100644 --- a/haystack/components/retrievers/in_memory/bm25_retriever.py +++ b/haystack/components/retrievers/in_memory/bm25_retriever.py @@ -38,7 +38,7 @@ class InMemoryBM25Retriever: ``` """ - def __init__( + def __init__( # pylint: disable=too-many-positional-arguments self, document_store: InMemoryDocumentStore, filters: Optional[Dict[str, Any]] = None, diff --git a/haystack/components/retrievers/in_memory/embedding_retriever.py b/haystack/components/retrievers/in_memory/embedding_retriever.py index cddd2cef41..016b7a940b 100644 --- a/haystack/components/retrievers/in_memory/embedding_retriever.py +++ b/haystack/components/retrievers/in_memory/embedding_retriever.py @@ -50,7 +50,7 @@ class InMemoryEmbeddingRetriever: ``` """ - def __init__( + def __init__( # pylint: disable=too-many-positional-arguments self, document_store: InMemoryDocumentStore, filters: Optional[Dict[str, Any]] = None, @@ -143,7 +143,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "InMemoryEmbeddingRetriever": return default_from_dict(cls, data) @component.output_types(documents=List[Document]) - def run( + def run( # pylint: disable=too-many-positional-arguments self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, diff --git a/haystack/components/retrievers/sentence_window_retriever.py b/haystack/components/retrievers/sentence_window_retriever.py index 979a462ea4..be1f9df100 100644 --- a/haystack/components/retrievers/sentence_window_retriever.py +++ b/haystack/components/retrievers/sentence_window_retriever.py @@ -151,7 +151,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceWindowRetriever": # deserialize the component return default_from_dict(cls, data) - @component.output_types(context_windows=List[str], context_documents=List[List[Document]]) + @component.output_types(context_windows=List[str], context_documents=List[Document]) def run(self, retrieved_documents: List[Document], window_size: Optional[int] = None): """ Based on the `source_id` and on the `doc.meta['split_id']` get surrounding documents from the document store. @@ -166,9 +166,9 @@ def run(self, retrieved_documents: List[Document], window_size: Optional[int] = A dictionary with the following keys: - `context_windows`: A list of strings, where each string represents the concatenated text from the context window of the corresponding document in `retrieved_documents`. - - `context_documents`: A list of lists of `Document` objects, where each inner list contains the - documents that come from the context window for the corresponding document in - `retrieved_documents`. + - `context_documents`: A list `Document` objects, containing the retrieved documents plus the context + document surrounding them. The documents are sorted by the `split_idx_start` + meta field. """ window_size = window_size or self.window_size @@ -200,6 +200,7 @@ def run(self, retrieved_documents: List[Document], window_size: Optional[int] = } ) context_text.append(self.merge_documents_text(context_docs)) - context_documents.append(context_docs) + context_docs_sorted = sorted(context_docs, key=lambda doc: doc.meta["split_idx_start"]) + context_documents.extend(context_docs_sorted) return {"context_windows": context_text, "context_documents": context_documents} diff --git a/releasenotes/notes/sentence-window-retriever-change-output-7beca98e9951039e.yaml b/releasenotes/notes/sentence-window-retriever-change-output-7beca98e9951039e.yaml new file mode 100644 index 0000000000..a2a72bbf2a --- /dev/null +++ b/releasenotes/notes/sentence-window-retriever-change-output-7beca98e9951039e.yaml @@ -0,0 +1,4 @@ +--- +upgrade: + - | + The SentenceWindowRetriever output key `context_documents` now outputs a List[Document] containing the retrieved documents together with the context windows ordered by `split_idx_start`. diff --git a/test/components/retrievers/test_sentence_window_retriever.py b/test/components/retrievers/test_sentence_window_retriever.py index 786d278dd6..4979fa334d 100644 --- a/test/components/retrievers/test_sentence_window_retriever.py +++ b/test/components/retrievers/test_sentence_window_retriever.py @@ -141,6 +141,39 @@ def test_constructor_parameter_does_not_change(self): retriever.run(retrieved_documents=[Document.from_dict(doc)], window_size=1) assert retriever.window_size == 5 + def test_context_documents_returned_are_ordered_by_split_idx_start(self): + docs = [] + accumulated_length = 0 + for sent in range(10): + content = f"Sentence {sent}." + docs.append( + Document( + content=content, + meta={ + "id": f"doc_{sent}", + "split_idx_start": accumulated_length, + "source_id": "source1", + "split_id": sent, + }, + ) + ) + accumulated_length += len(content) + + import random + + random.shuffle(docs) + + doc_store = InMemoryDocumentStore() + doc_store.write_documents(docs) + retriever = SentenceWindowRetriever(document_store=doc_store, window_size=3) + + # run the retriever with a document whose content = "Sentence 4." + result = retriever.run(retrieved_documents=[doc for doc in docs if doc.content == "Sentence 4."]) + + # assert that the context documents are in the correct order + assert len(result["context_documents"]) == 7 + assert [doc.meta["split_idx_start"] for doc in result["context_documents"]] == [11, 22, 33, 44, 55, 66, 77] + @pytest.mark.integration def test_run_with_pipeline(self): splitter = DocumentSplitter(split_length=1, split_overlap=0, split_by="sentence") @@ -165,13 +198,13 @@ def test_run_with_pipeline(self): "This is a text with some words. There is a second sentence. And there is also a third sentence. " "It also contains a fourth sentence. And a fifth sentence." ] - assert len(result["sentence_window_retriever"]["context_documents"][0]) == 5 + assert len(result["sentence_window_retriever"]["context_documents"]) == 5 result = pipe.run({"bm25_retriever": {"query": "third"}, "sentence_window_retriever": {"window_size": 1}}) assert result["sentence_window_retriever"]["context_windows"] == [ " There is a second sentence. And there is also a third sentence. It also contains a fourth sentence." ] - assert len(result["sentence_window_retriever"]["context_documents"][0]) == 3 + assert len(result["sentence_window_retriever"]["context_documents"]) == 3 @pytest.mark.integration def test_serialization_deserialization_in_pipeline(self):