Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: efficient knn filtering support for OpenSearch #1134

Merged
merged 4 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
custom_query: Optional[Dict[str, Any]] = None,
raise_on_failure: bool = True,
efficient_filtering: bool = False,
):
"""
Create the OpenSearchEmbeddingRetriever component.
Expand Down Expand Up @@ -85,6 +86,8 @@ def __init__(
:param raise_on_failure:
If `True`, raises an exception if the API call fails.
If `False`, logs a warning and returns an empty list.
:param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search.
This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib".

:raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore.
"""
Expand All @@ -100,6 +103,7 @@ def __init__(
)
self._custom_query = custom_query
self._raise_on_failure = raise_on_failure
self._efficient_filtering = efficient_filtering

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -116,6 +120,7 @@ def to_dict(self) -> Dict[str, Any]:
filter_policy=self._filter_policy.value,
custom_query=self._custom_query,
raise_on_failure=self._raise_on_failure,
efficient_filtering=self._efficient_filtering,
)

@classmethod
Expand Down Expand Up @@ -146,6 +151,7 @@ def run(
filters: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
custom_query: Optional[Dict[str, Any]] = None,
efficient_filtering: Optional[bool] = None,
):
"""
Retrieve documents using a vector similarity metric.
Expand Down Expand Up @@ -196,6 +202,9 @@ def run(
)
```

:param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search.
This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib".

:returns:
Dictionary with key "documents" containing the retrieved Documents.
- documents: List of Document similar to `query_embedding`.
Expand All @@ -208,6 +217,8 @@ def run(
top_k = self._top_k
if custom_query is None:
custom_query = self._custom_query
if efficient_filtering is None:
efficient_filtering = self._efficient_filtering

docs: List[Document] = []

Expand All @@ -217,6 +228,7 @@ def run(
filters=filters,
top_k=top_k,
custom_query=custom_query,
efficient_filtering=efficient_filtering,
)
except Exception as e:
if self._raise_on_failure:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ def _embedding_retrieval(
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
custom_query: Optional[Dict[str, Any]] = None,
efficient_filtering: bool = False,
) -> List[Document]:
"""
Retrieves documents that are most similar to the query embedding using a vector similarity metric.
Expand Down Expand Up @@ -474,6 +475,8 @@ def _embedding_retrieval(
}
```

:param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search.
This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib".
:raises ValueError: If `query_embedding` is an empty list
:returns: List of Document that are most similar to `query_embedding`
"""
Expand Down Expand Up @@ -509,7 +512,10 @@ def _embedding_retrieval(
}

if filters:
body["query"]["bool"]["filter"] = normalize_filters(filters)
if efficient_filtering:
body["query"]["bool"]["must"][0]["knn"]["embedding"]["filter"] = normalize_filters(filters)
else:
body["query"]["bool"]["filter"] = normalize_filters(filters)

body["size"] = top_k

Expand Down
47 changes: 47 additions & 0 deletions integrations/opensearch/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,27 @@ def document_store_embedding_dim_4(self, request):
yield store
store.client.indices.delete(index=index, params={"ignore": [400, 404]})

@pytest.fixture
def document_store_embedding_dim_4_faiss(self, request):
"""
Provide an OpenSearchDocumentStore with 4-dimensional embeddings whose kNN index uses the "faiss" engine
(HNSW, inner product), for tests that exercise efficient filtering. The index is deleted after the test.
"""
hosts = ["https://localhost:9200"]
# Use a different index for each test so we can run them in parallel
index = f"{request.node.name}"

store = OpenSearchDocumentStore(
hosts=hosts,
index=index,
http_auth=("admin", "admin"),
verify_certs=False,
embedding_dim=4,
method={"space_type": "innerproduct", "engine": "faiss", "name": "hnsw"},
)
yield store
store.client.indices.delete(index=index, params={"ignore": [400, 404]})

def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
"""
The OpenSearchDocumentStore.filter_documents() method returns Documents with their score set.
Expand Down Expand Up @@ -690,6 +711,32 @@ def test_embedding_retrieval_with_filters(self, document_store_embedding_dim_4:
assert len(results) == 1
assert results[0].content == "Not very similar document with meta field"

def test_embedding_retrieval_with_filters_efficient_filtering(
self, document_store_embedding_dim_4_faiss: OpenSearchDocumentStore
):
docs = [
Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]),
Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]),
Document(
content="Not very similar document with meta field",
embedding=[0.0, 0.8, 0.3, 0.9],
meta={"meta_field": "custom_value"},
),
]
document_store_embedding_dim_4_faiss.write_documents(docs)

filters = {"field": "meta_field", "operator": "==", "value": "custom_value"}
# NOTE: this store uses the "faiss" engine, so efficient filtering is supported (unlike the default "nmslib").
# TODO: top_k=3 was set out of caution about nmslib support; it is likely unnecessary here and can be removed.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem necessary for this test: I tried running it locally without setting top_k and it passed.

results = document_store_embedding_dim_4_faiss._embedding_retrieval(
query_embedding=[0.1, 0.1, 0.1, 0.1],
top_k=3,
filters=filters,
efficient_filtering=True,
)
tstadel marked this conversation as resolved.
Show resolved Hide resolved
assert len(results) == 1
assert results[0].content == "Not very similar document with meta field"

def test_embedding_retrieval_pagination(self, document_store_embedding_dim_4: OpenSearchDocumentStore):
"""
Test that handling of pagination works as expected, when the matching documents are > 10.
Expand Down
15 changes: 13 additions & 2 deletions integrations/opensearch/tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_init_default():
assert retriever._filters == {}
assert retriever._top_k == 10
assert retriever._filter_policy == FilterPolicy.REPLACE
assert retriever._efficient_filtering is False

retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filter_policy="replace")
assert retriever._filter_policy == FilterPolicy.REPLACE
Expand Down Expand Up @@ -82,6 +83,7 @@ def test_to_dict(_mock_opensearch_client):
"filter_policy": "replace",
"custom_query": {"some": "custom query"},
"raise_on_failure": True,
"efficient_filtering": False,
},
}

Expand All @@ -101,6 +103,7 @@ def test_from_dict(_mock_opensearch_client):
"filter_policy": "replace",
"custom_query": {"some": "custom query"},
"raise_on_failure": False,
"efficient_filtering": True,
},
}
retriever = OpenSearchEmbeddingRetriever.from_dict(data)
Expand All @@ -110,6 +113,7 @@ def test_from_dict(_mock_opensearch_client):
assert retriever._custom_query == {"some": "custom query"}
assert retriever._raise_on_failure is False
assert retriever._filter_policy == FilterPolicy.REPLACE
assert retriever._efficient_filtering is True

# For backwards compatibility with older versions of the retriever without a filter policy
data = {
Expand Down Expand Up @@ -139,6 +143,7 @@ def test_run():
filters={},
top_k=10,
custom_query=None,
efficient_filtering=False,
)
assert len(res) == 1
assert len(res["documents"]) == 1
Expand All @@ -150,14 +155,19 @@ def test_run_init_params():
mock_store = Mock(spec=OpenSearchDocumentStore)
mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]
retriever = OpenSearchEmbeddingRetriever(
document_store=mock_store, filters={"from": "init"}, top_k=11, custom_query="custom_query"
document_store=mock_store,
filters={"from": "init"},
top_k=11,
custom_query="custom_query",
efficient_filtering=True,
)
res = retriever.run(query_embedding=[0.5, 0.7])
mock_store._embedding_retrieval.assert_called_once_with(
query_embedding=[0.5, 0.7],
filters={"from": "init"},
top_k=11,
custom_query="custom_query",
efficient_filtering=True,
)
assert len(res) == 1
assert len(res["documents"]) == 1
Expand All @@ -169,12 +179,13 @@ def test_run_time_params():
mock_store = Mock(spec=OpenSearchDocumentStore)
mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]
retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filters={"from": "init"}, top_k=11)
res = retriever.run(query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9)
res = retriever.run(query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9, efficient_filtering=True)
mock_store._embedding_retrieval.assert_called_once_with(
query_embedding=[0.5, 0.7],
filters={"from": "run"},
top_k=9,
custom_query=None,
efficient_filtering=True,
)
assert len(res) == 1
assert len(res["documents"]) == 1
Expand Down