From 69c29a95d82a7f2bbd53949257a926351521879e Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Tue, 25 Jun 2024 17:24:10 +0200 Subject: [PATCH] feat: add custom_query param to OpenSearch retrievers (#841) * feat: add custom_query param to OpenSearch retrievers * feat: add custom_query to OpenSearch retrievers * add as run param * fix lint * switch to jinja2 templates * Revert "switch to jinja2 templates" This reverts commit f36ed13fa25abc5d17df7e087841a9ecf839c75f. * support custom_query as dict * remove unneccessary comments * remove str * fix lint --- .../retrievers/opensearch/bm25_retriever.py | 56 +++++++ .../opensearch/embedding_retriever.py | 75 ++++++++- .../opensearch/document_store.py | 152 +++++++++++++----- .../opensearch/tests/test_bm25_retriever.py | 4 + .../opensearch/tests/test_document_store.py | 106 ++++++++++++ .../tests/test_embedding_retriever.py | 7 +- 6 files changed, 357 insertions(+), 43 deletions(-) diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py index 0ad257b42..949826230 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py @@ -19,6 +19,7 @@ def __init__( top_k: int = 10, scale_score: bool = False, all_terms_must_match: bool = False, + custom_query: Optional[Dict[str, Any]] = None, ): """ Create the OpenSearchBM25Retriever component. @@ -31,6 +32,31 @@ def __init__( This is useful when comparing documents across different indexes. Defaults to False. :param all_terms_must_match: If True, all terms in the query string must be present in the retrieved documents. This is useful when searching for short text where even one term can make a difference. Defaults to False. + :param custom_query: The query containing a mandatory `$query` and an optional `$filters` placeholder + + **An example custom_query:** + + ```python + { + "query": { + "bool": { + "should": [{"multi_match": { + "query": "$query", // mandatory query placeholder + "type": "most_fields", + "fields": ["content", "title"]}}], + "filter": "$filters" // optional filter placeholder + } + } + } + ``` + + **For this custom_query, a sample `run()` could be:** + + ```python + retriever.run(query="Why did the revenue increase?", + filters={"years": ["2019"], "quarters": ["Q1", "Q2"]}) + ``` + :raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore. """ @@ -44,6 +70,7 @@ def __init__( self._top_k = top_k self._scale_score = scale_score self._all_terms_must_match = all_terms_must_match + self._custom_query = custom_query def to_dict(self) -> Dict[str, Any]: """ @@ -86,6 +113,7 @@ def run( top_k: Optional[int] = None, fuzziness: Optional[str] = None, scale_score: Optional[bool] = None, + custom_query: Optional[Dict[str, Any]] = None, ): """ Retrieve documents using BM25 retrieval. @@ -97,6 +125,31 @@ def run( :param fuzziness: Fuzziness parameter for full-text queries. :param scale_score: Whether to scale the score of retrieved documents between 0 and 1. This is useful when comparing documents across different indexes. + :param custom_query: The query containing a mandatory `$query` and an optional `$filters` placeholder + + **An example custom_query:** + + ```python + { + "query": { + "bool": { + "should": [{"multi_match": { + "query": "$query", // mandatory query placeholder + "type": "most_fields", + "fields": ["content", "title"]}}], + "filter": "$filters" // optional filter placeholder + } + } + } + ``` + + **For this custom_query, a sample `run()` could be:** + + ```python + retriever.run(query="Why did the revenue increase?", + filters={"years": ["2019"], "quarters": ["Q1", "Q2"]}) + ``` + :returns: A dictionary containing the retrieved documents with the following structure: @@ -113,6 +166,8 @@ def run( fuzziness = self._fuzziness if scale_score is None: scale_score = self._scale_score + if custom_query is None: + custom_query = self._custom_query docs = self._document_store._bm25_retrieval( query=query, @@ -121,5 +176,6 @@ def run( top_k=top_k, scale_score=scale_score, all_terms_must_match=all_terms_must_match, + custom_query=custom_query, ) return {"documents": docs} diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py index 50b30d7f1..81688601f 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py @@ -22,6 +22,7 @@ def __init__( document_store: OpenSearchDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, + custom_query: Optional[Dict[str, Any]] = None, ): """ Create the OpenSearchEmbeddingRetriever component. @@ -30,6 +31,37 @@ def __init__( :param filters: Filters applied to the retrieved Documents. Defaults to None. Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. :param top_k: Maximum number of Documents to return, defaults to 10 + :param custom_query: The query containing a mandatory `$query_embedding` and an optional `$filters` placeholder + + **An example custom_query:** + + ```python + { + "query": { + "bool": { + "must": [ + { + "knn": { + "embedding": { + "vector": "$query_embedding", // mandatory query placeholder + "k": 10000, + } + } + } + ], + "filter": "$filters" // optional filter placeholder + } + } + } + ``` + + **For this custom_query, a sample `run()` could be:** + + ```python + retriever.run(query_embedding=embedding, + filters={"years": ["2019"], "quarters": ["Q1", "Q2"]}) + ``` + :raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore. """ if not isinstance(document_store, OpenSearchDocumentStore): @@ -39,6 +71,7 @@ def __init__( self._document_store = document_store self._filters = filters or {} self._top_k = top_k + self._custom_query = custom_query def to_dict(self) -> Dict[str, Any]: """ @@ -71,13 +104,50 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchEmbeddingRetriever": return default_from_dict(cls, data) @component.output_types(documents=List[Document]) - def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): + def run( + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + custom_query: Optional[Dict[str, Any]] = None, + ): """ Retrieve documents using a vector similarity metric. :param query_embedding: Embedding of the query. :param filters: Optional filters to narrow down the search space. :param top_k: Maximum number of Documents to return. + :param custom_query: The query containing a mandatory `$query_embedding` and an optional `$filters` placeholder + + **An example custom_query:** + + ```python + { + "query": { + "bool": { + "must": [ + { + "knn": { + "embedding": { + "vector": "$query_embedding", // mandatory query placeholder + "k": 10000, + } + } + } + ], + "filter": "$filters" // optional filter placeholder + } + } + } + ``` + + **For this custom_query, a sample `run()` could be:** + + ```python + retriever.run(query_embedding=embedding, + filters={"years": ["2019"], "quarters": ["Q1", "Q2"]}) + ``` + :returns: Dictionary with key "documents" containing the retrieved Documents. - documents: List of Document similar to `query_embedding`. @@ -86,10 +156,13 @@ def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = filters = self._filters if top_k is None: top_k = self._top_k + if custom_query is None: + custom_query = self._custom_query docs = self._document_store._embedding_retrieval( query_embedding=query_embedding, filters=filters, top_k=top_k, + custom_query=custom_query, ) return {"documents": docs} diff --git a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py index 6f7ef3f93..d382e3d94 100644 --- a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py +++ b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py @@ -312,6 +312,7 @@ def _bm25_retrieval( top_k: int = 10, scale_score: bool = False, all_terms_must_match: bool = False, + custom_query: Optional[Dict[str, Any]] = None, ) -> List[Document]: """ OpenSearch by defaults uses BM25 search algorithm. @@ -322,8 +323,6 @@ def _bm25_retrieval( `OpenSearchDocumentStore` nor called directly. `OpenSearchBM25Retriever` uses this method directly and is the public interface for it. - `query` must be a non empty string, otherwise a `ValueError` will be raised. - :param query: String to search in saved Documents' text. :param filters: Optional filters to narrow down the search space. :param fuzziness: Fuzziness parameter passed to OpenSearch, defaults to "AUTO". see the official documentation @@ -331,35 +330,58 @@ def _bm25_retrieval( :param top_k: Maximum number of Documents to return, defaults to 10 :param scale_score: If `True` scales the Document`s scores between 0 and 1, defaults to False :param all_terms_must_match: If `True` all terms in `query` must be present in the Document, defaults to False - :raises ValueError: If `query` is an empty string + :param custom_query: The query containing a mandatory `$query` and an optional `$filters` placeholder + + **An example custom_query:** + + ```python + { + "query": { + "bool": { + "should": [{"multi_match": { + "query": "$query", // mandatory query placeholder + "type": "most_fields", + "fields": ["content", "title"]}}], + "filter": "$filters" // optional filter placeholder + } + } + } + ``` + :returns: List of Document that match `query` """ if not query: - msg = "query must be a non empty string" - raise ValueError(msg) + body: Dict[str, Any] = {"query": {"bool": {"must": {"match_all": {}}}}} + if filters: + body["query"]["bool"]["filter"] = normalize_filters(filters) + + if isinstance(custom_query, dict): + body = self._render_custom_query(custom_query, {"$query": query, "$filters": normalize_filters(filters)}) - operator = "AND" if all_terms_must_match else "OR" - body: Dict[str, Any] = { - "size": top_k, - "query": { - "bool": { - "must": [ - { - "multi_match": { - "query": query, - "fuzziness": fuzziness, - "type": "most_fields", - "operator": operator, + else: + operator = "AND" if all_terms_must_match else "OR" + body = { + "query": { + "bool": { + "must": [ + { + "multi_match": { + "query": query, + "fuzziness": fuzziness, + "type": "most_fields", + "operator": operator, + } } - } - ] - } - }, - } + ] + } + }, + } - if filters: - body["query"]["bool"]["filter"] = normalize_filters(filters) + if filters: + body["query"]["bool"]["filter"] = normalize_filters(filters) + + body["size"] = top_k # For some applications not returning the embedding can save a lot of bandwidth # if you don't need this data not retrieving it can be a good idea @@ -380,6 +402,7 @@ def _embedding_retrieval( *, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, + custom_query: Optional[Dict[str, Any]] = None, ) -> List[Document]: """ Retrieves documents that are most similar to the query embedding using a vector similarity metric. @@ -393,6 +416,29 @@ def _embedding_retrieval( :param filters: Filters applied to the retrieved Documents. Defaults to None. Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. :param top_k: Maximum number of Documents to return, defaults to 10 + :param custom_query: The query containing a mandatory `$query_embedding` and an optional `$filters` placeholder + + **An example custom_query:** + ```python + { + "query": { + "bool": { + "must": [ + { + "knn": { + "embedding": { + "vector": "$query_embedding", // mandatory query placeholder + "k": 10000, + } + } + } + ], + "filter": "$filters" // optional filter placeholder + } + } + } + ``` + :raises ValueError: If `query_embedding` is an empty list :returns: List of Document that are most similar to `query_embedding` """ @@ -401,26 +447,33 @@ def _embedding_retrieval( msg = "query_embedding must be a non-empty list of floats" raise ValueError(msg) - body: Dict[str, Any] = { - "query": { - "bool": { - "must": [ - { - "knn": { - "embedding": { - "vector": query_embedding, - "k": top_k, + if isinstance(custom_query, dict): + body = self._render_custom_query( + custom_query, {"$query_embedding": query_embedding, "$filters": normalize_filters(filters)} + ) + + else: + body = { + "query": { + "bool": { + "must": [ + { + "knn": { + "embedding": { + "vector": query_embedding, + "k": top_k, + } } } - } - ], - } - }, - "size": top_k, - } + ], + } + }, + } - if filters: - body["query"]["bool"]["filter"] = normalize_filters(filters) + if filters: + body["query"]["bool"]["filter"] = normalize_filters(filters) + + body["size"] = top_k # For some applications not returning the embedding can save a lot of bandwidth # if you don't need this data not retrieving it can be a good idea @@ -429,3 +482,20 @@ def _embedding_retrieval( docs = self._search_documents(**body) return docs + + def _render_custom_query(self, custom_query: Any, substitutions: Dict[str, Any]) -> Any: + """ + Recursively replaces the placeholders in the custom_query with the actual values. + + :param custom_query: The custom query to replace the placeholders in. + :param substitutions: The dictionary containing the actual values to replace the placeholders with. + :returns: The custom query with the placeholders replaced. + """ + if isinstance(custom_query, dict): + return {key: self._render_custom_query(value, substitutions) for key, value in custom_query.items()} + elif isinstance(custom_query, list): + return [self._render_custom_query(entry, substitutions) for entry in custom_query] + elif isinstance(custom_query, str): + return substitutions.get(custom_query, custom_query) + + return custom_query diff --git a/integrations/opensearch/tests/test_bm25_retriever.py b/integrations/opensearch/tests/test_bm25_retriever.py index 4242386f0..a3e3e3acd 100644 --- a/integrations/opensearch/tests/test_bm25_retriever.py +++ b/integrations/opensearch/tests/test_bm25_retriever.py @@ -91,6 +91,7 @@ def test_run(): top_k=10, scale_score=False, all_terms_must_match=False, + custom_query=None, ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -107,6 +108,7 @@ def test_run_init_params(): scale_score=True, top_k=11, fuzziness="1", + custom_query={"some": "custom query"}, ) res = retriever.run(query="some query") mock_store._bm25_retrieval.assert_called_once_with( @@ -116,6 +118,7 @@ def test_run_init_params(): top_k=11, scale_score=True, all_terms_must_match=True, + custom_query={"some": "custom query"}, ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -148,6 +151,7 @@ def test_run_time_params(): top_k=9, scale_score=False, all_terms_must_match=False, + custom_query=None, ) assert len(res) == 1 assert len(res["documents"]) == 1 diff --git a/integrations/opensearch/tests/test_document_store.py b/integrations/opensearch/tests/test_document_store.py index 4b7e242f2..369782ef3 100644 --- a/integrations/opensearch/tests/test_document_store.py +++ b/integrations/opensearch/tests/test_document_store.py @@ -336,6 +336,87 @@ def test_bm25_retrieval_with_fuzziness(self, document_store: OpenSearchDocumentS assert "functional" in res[1].content assert "functional" in res[2].content + def test_bm25_retrieval_with_custom_query(self, document_store: OpenSearchDocumentStore): + document_store.write_documents( + [ + Document( + content="Haskell is a functional programming language", + meta={"likes": 100000, "language_type": "functional"}, + id="1", + ), + Document( + content="Lisp is a functional programming language", + meta={"likes": 10000, "language_type": "functional"}, + id="2", + ), + Document( + content="Exilir is a functional programming language", + meta={"likes": 1000, "language_type": "functional"}, + id="3", + ), + Document( + content="F# is a functional programming language", + meta={"likes": 100, "language_type": "functional"}, + id="4", + ), + Document( + content="C# is a functional programming language", + meta={"likes": 10, "language_type": "functional"}, + id="5", + ), + Document( + content="C++ is an object oriented programming language", + meta={"likes": 100000, "language_type": "object_oriented"}, + id="6", + ), + Document( + content="Dart is an object oriented programming language", + meta={"likes": 10000, "language_type": "object_oriented"}, + id="7", + ), + Document( + content="Go is an object oriented programming language", + meta={"likes": 1000, "language_type": "object_oriented"}, + id="8", + ), + Document( + content="Python is a object oriented programming language", + meta={"likes": 100, "language_type": "object_oriented"}, + id="9", + ), + Document( + content="Ruby is a object oriented programming language", + meta={"likes": 10, "language_type": "object_oriented"}, + id="10", + ), + Document( + content="PHP is a object oriented programming language", + meta={"likes": 1, "language_type": "object_oriented"}, + id="11", + ), + ] + ) + + custom_query = { + "query": { + "function_score": { + "query": {"bool": {"must": {"match": {"content": "$query"}}, "filter": "$filters"}}, + "field_value_factor": {"field": "likes", "factor": 0.1, "modifier": "log1p", "missing": 0}, + } + } + } + + res = document_store._bm25_retrieval( + "functional", + top_k=3, + custom_query=custom_query, + filters={"field": "language_type", "operator": "==", "value": "functional"}, + ) + assert len(res) == 3 + assert "1" == res[0].id + assert "2" == res[1].id + assert "3" == res[2].id + def test_embedding_retrieval(self, document_store_embedding_dim_4: OpenSearchDocumentStore): docs = [ Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), @@ -387,6 +468,31 @@ def test_embedding_retrieval_pagination(self, document_store_embedding_dim_4: Op ) assert len(results) == 11 + def test_embedding_retrieval_with_custom_query(self, document_store_embedding_dim_4: OpenSearchDocumentStore): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document( + content="Not very similar document with meta field", + embedding=[0.0, 0.8, 0.3, 0.9], + meta={"meta_field": "custom_value"}, + ), + ] + document_store_embedding_dim_4.write_documents(docs) + + custom_query = { + "query": { + "bool": {"must": [{"knn": {"embedding": {"vector": "$query_embedding", "k": 3}}}], "filter": "$filters"} + } + } + + filters = {"field": "meta_field", "operator": "==", "value": "custom_value"} + results = document_store_embedding_dim_4._embedding_retrieval( + query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=1, filters=filters, custom_query=custom_query + ) + assert len(results) == 1 + assert results[0].content == "Not very similar document with meta field" + def test_embedding_retrieval_query_documents_different_embedding_sizes( self, document_store_embedding_dim_4: OpenSearchDocumentStore ): diff --git a/integrations/opensearch/tests/test_embedding_retriever.py b/integrations/opensearch/tests/test_embedding_retriever.py index 7bf6c09eb..325ff16d3 100644 --- a/integrations/opensearch/tests/test_embedding_retriever.py +++ b/integrations/opensearch/tests/test_embedding_retriever.py @@ -98,6 +98,7 @@ def test_run(): query_embedding=[0.5, 0.7], filters={}, top_k=10, + custom_query=None, ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -108,12 +109,15 @@ def test_run(): def test_run_init_params(): mock_store = Mock(spec=OpenSearchDocumentStore) mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] - retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filters={"from": "init"}, top_k=11) + retriever = OpenSearchEmbeddingRetriever( + document_store=mock_store, filters={"from": "init"}, top_k=11, custom_query="custom_query" + ) res = retriever.run(query_embedding=[0.5, 0.7]) mock_store._embedding_retrieval.assert_called_once_with( query_embedding=[0.5, 0.7], filters={"from": "init"}, top_k=11, + custom_query="custom_query", ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -130,6 +134,7 @@ def test_run_time_params(): query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9, + custom_query=None, ) assert len(res) == 1 assert len(res["documents"]) == 1