From 2682c69ae7cf7265e46df45742b91fdf9be971b3 Mon Sep 17 00:00:00 2001 From: tstadel Date: Thu, 30 Nov 2023 15:45:42 +0100 Subject: [PATCH 1/2] feat: extend OpenSearch params support --- .../src/opensearch_haystack/bm25_retriever.py | 58 +++++++++++++++++-- .../src/opensearch_haystack/document_store.py | 8 ++- .../embedding_retriever.py | 13 ++++- .../opensearch/tests/test_bm25_retriever.py | 58 +++++++++++++++++++ .../opensearch/tests/test_document_store.py | 46 +++++++++++++++ .../tests/test_embedding_retriever.py | 32 ++++++++++ 6 files changed, 204 insertions(+), 11 deletions(-) diff --git a/integrations/opensearch/src/opensearch_haystack/bm25_retriever.py b/integrations/opensearch/src/opensearch_haystack/bm25_retriever.py index 9755d6253..788a40870 100644 --- a/integrations/opensearch/src/opensearch_haystack/bm25_retriever.py +++ b/integrations/opensearch/src/opensearch_haystack/bm25_retriever.py @@ -16,7 +16,22 @@ def __init__( fuzziness: str = "AUTO", top_k: int = 10, scale_score: bool = False, + all_terms_must_match: bool = False, ): + """ + Create the OpenSearchBM25Retriever component. + + :param document_store: An instance of OpenSearchDocumentStore. + :param filters: Filters applied to the retrieved Documents. Defaults to None. + :param fuzziness: Fuzziness parameter for full-text queries. Defaults to "AUTO". + :param top_k: Maximum number of Documents to return, defaults to 10 + :param scale_score: Whether to scale the score of retrieved documents between 0 and 1. + This is useful when comparing documents across different indexes. + :param all_terms_must_match: If True, all terms in the query string must be present in the retrieved documents. + This is useful when searching for short text where even one term can make a difference. + :raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore. + + """ if not isinstance(document_store, OpenSearchDocumentStore): msg = "document_store must be an instance of OpenSearchDocumentStore" raise ValueError(msg) @@ -26,6 +41,7 @@ def __init__( self._fuzziness = fuzziness self._top_k = top_k self._scale_score = scale_score + self._all_terms_must_match = all_terms_must_match def to_dict(self) -> Dict[str, Any]: return default_to_dict( @@ -45,12 +61,44 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchBM25Retriever": return default_from_dict(cls, data) @component.output_types(documents=List[Document]) - def run(self, query: str): + def run( + self, + query: str, + filters: Optional[Dict[str, Any]] = None, + all_terms_must_match: Optional[bool] = None, + top_k: Optional[int] = None, + fuzziness: Optional[str] = None, + scale_score: Optional[bool] = None, + ): + """ + Retrieve documents using BM25 retrieval. + + :param query: The query string + :param filters: Optional filters to narrow down the search space. + :param all_terms_must_match: If True, all terms in the query string must be present in the retrieved documents. + :param top_k: Maximum number of Documents to return. + :param fuzziness: Fuzziness parameter for full-text queries. + :param scale_score: Whether to scale the score of retrieved documents between 0 and 1. + This is useful when comparing documents across different indexes. + :return: A dictionary containing the retrieved documents. + """ + if filters is None: + filters = self._filters + if all_terms_must_match is None: + all_terms_must_match = self._all_terms_must_match + if top_k is None: + top_k = self._top_k + if fuzziness is None: + fuzziness = self._fuzziness + if scale_score is None: + scale_score = self._scale_score + docs = self._document_store._bm25_retrieval( query=query, - filters=self._filters, - fuzziness=self._fuzziness, - top_k=self._top_k, - scale_score=self._scale_score, + filters=filters, + fuzziness=fuzziness, + top_k=top_k, + scale_score=scale_score, + all_terms_must_match=all_terms_must_match, ) return {"documents": docs} diff --git a/integrations/opensearch/src/opensearch_haystack/document_store.py b/integrations/opensearch/src/opensearch_haystack/document_store.py index fe8495fb0..e4167f777 100644 --- a/integrations/opensearch/src/opensearch_haystack/document_store.py +++ b/integrations/opensearch/src/opensearch_haystack/document_store.py @@ -221,6 +221,7 @@ def _bm25_retrieval( fuzziness: str = "AUTO", top_k: int = 10, scale_score: bool = False, + all_terms_must_match: bool = False, ) -> List[Document]: """ OpenSearch by defaults uses BM25 search algorithm. @@ -234,13 +235,13 @@ def _bm25_retrieval( `query` must be a non empty string, otherwise a `ValueError` will be raised. :param query: String to search in saved Documents' text. - :param filters: Filters applied to the retrieved Documents, for more info - see `OpenSearchDocumentStore.filter_documents`, defaults to None + :param filters: Optional filters to narrow down the search space. :param fuzziness: Fuzziness parameter passed to OpenSearch, defaults to "AUTO". see the official documentation for valid values: https://www.elastic.co/guide/en/OpenSearch/reference/current/common-options.html#fuzziness :param top_k: Maximum number of Documents to return, defaults to 10 :param scale_score: If `True` scales the Document`s scores between 0 and 1, defaults to False + :param all_terms_must_match: If `True` all terms in `query` must be present in the Document, defaults to False :raises ValueError: If `query` is an empty string :return: List of Document that match `query` """ @@ -249,6 +250,7 @@ def _bm25_retrieval( msg = "query must be a non empty string" raise ValueError(msg) + operator = "AND" if all_terms_must_match else "OR" body: Dict[str, Any] = { "size": top_k, "query": { @@ -259,7 +261,7 @@ def _bm25_retrieval( "query": query, "fuzziness": fuzziness, "type": "most_fields", - "operator": "AND", + "operator": operator, } } ] diff --git a/integrations/opensearch/src/opensearch_haystack/embedding_retriever.py b/integrations/opensearch/src/opensearch_haystack/embedding_retriever.py index 9bbc2a7a3..427920e8a 100644 --- a/integrations/opensearch/src/opensearch_haystack/embedding_retriever.py +++ b/integrations/opensearch/src/opensearch_haystack/embedding_retriever.py @@ -54,16 +54,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchEmbeddingRetriever": return default_from_dict(cls, data) @component.output_types(documents=List[Document]) - def run(self, query_embedding: List[float]): + def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): """ Retrieve documents using a vector similarity metric. :param query_embedding: Embedding of the query. + :param filters: Optional filters to narrow down the search space. + :param top_k: Maximum number of Documents to return. :return: List of Document similar to `query_embedding`. """ + if filters is None: + filters = self._filters + if top_k is None: + top_k = self._top_k + docs = self._document_store._embedding_retrieval( query_embedding=query_embedding, - filters=self._filters, - top_k=self._top_k, + filters=filters, + top_k=top_k, ) return {"documents": docs} diff --git a/integrations/opensearch/tests/test_bm25_retriever.py b/integrations/opensearch/tests/test_bm25_retriever.py index cfea2d767..c552113c9 100644 --- a/integrations/opensearch/tests/test_bm25_retriever.py +++ b/integrations/opensearch/tests/test_bm25_retriever.py @@ -72,6 +72,64 @@ def test_run(): fuzziness="AUTO", top_k=10, scale_score=False, + all_terms_must_match=False, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +def test_run_init_params(): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] + retriever = OpenSearchBM25Retriever( + document_store=mock_store, + filters={"from": "init"}, + all_terms_must_match=True, + scale_score=True, + top_k=11, + fuzziness="1", + ) + res = retriever.run(query="some query") + mock_store._bm25_retrieval.assert_called_once_with( + query="some query", + filters={"from": "init"}, + fuzziness="1", + top_k=11, + scale_score=True, + all_terms_must_match=True, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +def test_run_time_params(): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] + retriever = OpenSearchBM25Retriever( + document_store=mock_store, + filters={"from": "init"}, + all_terms_must_match=True, + scale_score=True, + top_k=11, + fuzziness="1", + ) + res = retriever.run( + query="some query", + filters={"from": "run"}, + all_terms_must_match=False, + scale_score=False, + top_k=9, + fuzziness="2", + ) + mock_store._bm25_retrieval.assert_called_once_with( + query="some query", + filters={"from": "run"}, + fuzziness="2", + top_k=9, + scale_score=False, + all_terms_must_match=False, ) assert len(res) == 1 assert len(res["documents"]) == 1 diff --git a/integrations/opensearch/tests/test_document_store.py b/integrations/opensearch/tests/test_document_store.py index 4678cca70..8f6e0a13c 100644 --- a/integrations/opensearch/tests/test_document_store.py +++ b/integrations/opensearch/tests/test_document_store.py @@ -167,6 +167,52 @@ def test_bm25_retrieval_pagination(self, document_store: OpenSearchDocumentStore assert len(res) == 11 assert all("programming" in doc.content for doc in res) + def test_bm25_retrieval_all_terms_must_match(self, document_store: OpenSearchDocumentStore): + document_store.write_documents( + [ + Document(content="Haskell is a functional programming language"), + Document(content="Lisp is a functional programming language"), + Document(content="Exilir is a functional programming language"), + Document(content="F# is a functional programming language"), + Document(content="C# is a functional programming language"), + Document(content="C++ is an object oriented programming language"), + Document(content="Dart is an object oriented programming language"), + Document(content="Go is an object oriented programming language"), + Document(content="Python is a object oriented programming language"), + Document(content="Ruby is a object oriented programming language"), + Document(content="PHP is a object oriented programming language"), + ] + ) + + res = document_store._bm25_retrieval("functional Haskell", top_k=3, all_terms_must_match=True) + assert len(res) == 1 + assert "Haskell is a functional programming language" in res[0].content + + def test_bm25_retrieval_all_terms_must_match_false(self, document_store: OpenSearchDocumentStore): + document_store.write_documents( + [ + Document(content="Haskell is a functional programming language"), + Document(content="Lisp is a functional programming language"), + Document(content="Exilir is a functional programming language"), + Document(content="F# is a functional programming language"), + Document(content="C# is a functional programming language"), + Document(content="C++ is an object oriented programming language"), + Document(content="Dart is an object oriented programming language"), + Document(content="Go is an object oriented programming language"), + Document(content="Python is a object oriented programming language"), + Document(content="Ruby is a object oriented programming language"), + Document(content="PHP is a object oriented programming language"), + ] + ) + + res = document_store._bm25_retrieval("functional Haskell", top_k=10, all_terms_must_match=False) + assert len(res) == 5 + assert "functional" in res[0].content + assert "functional" in res[1].content + assert "functional" in res[2].content + assert "functional" in res[3].content + assert "functional" in res[4].content + def test_bm25_retrieval_with_fuzziness(self, document_store: OpenSearchDocumentStore): document_store.write_documents( [ diff --git a/integrations/opensearch/tests/test_embedding_retriever.py b/integrations/opensearch/tests/test_embedding_retriever.py index 2bfe5761a..f97dd6e9a 100644 --- a/integrations/opensearch/tests/test_embedding_retriever.py +++ b/integrations/opensearch/tests/test_embedding_retriever.py @@ -68,3 +68,35 @@ def test_run(): assert len(res["documents"]) == 1 assert res["documents"][0].content == "Test doc" assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_init_params(): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filters={"from": "init"}, top_k=11) + res = retriever.run(query_embedding=[0.5, 0.7]) + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters={"from": "init"}, + top_k=11, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_time_params(): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filters={"from": "init"}, top_k=11) + res = retriever.run(query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9) + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters={"from": "run"}, + top_k=9, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] From 7ae718f2fdbf7e9dbf58f038a6e8766919dfc0bf Mon Sep 17 00:00:00 2001 From: tstadel Date: Thu, 30 Nov 2023 15:48:35 +0100 Subject: [PATCH 2/2] add defaults to docstrings --- .../opensearch/src/opensearch_haystack/bm25_retriever.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/opensearch/src/opensearch_haystack/bm25_retriever.py b/integrations/opensearch/src/opensearch_haystack/bm25_retriever.py index 788a40870..91a133345 100644 --- a/integrations/opensearch/src/opensearch_haystack/bm25_retriever.py +++ b/integrations/opensearch/src/opensearch_haystack/bm25_retriever.py @@ -26,9 +26,9 @@ def __init__( :param fuzziness: Fuzziness parameter for full-text queries. Defaults to "AUTO". :param top_k: Maximum number of Documents to return, defaults to 10 :param scale_score: Whether to scale the score of retrieved documents between 0 and 1. - This is useful when comparing documents across different indexes. + This is useful when comparing documents across different indexes. Defaults to False. :param all_terms_must_match: If True, all terms in the query string must be present in the retrieved documents. - This is useful when searching for short text where even one term can make a difference. + This is useful when searching for short text where even one term can make a difference. Defaults to False. :raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore. """