Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a knn method to elasticsearch_dsl.search.Search #1691

Merged
merged 2 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/search_dsl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ The ``Search`` object represents the entire search request:

* aggregations

* k-nearest neighbor searches

* sort

* pagination
Expand Down Expand Up @@ -352,6 +354,31 @@ As opposed to other methods on the ``Search`` objects, defining aggregations is
done in-place (does not return a copy).


K-Nearest Neighbor Searches
~~~~~~~~~~~~~~~~~~~~~~~~~~~

To issue a kNN search, use the ``.knn()`` method:

.. code:: python

s = Search()
vector = get_embedding("search text")

s = s.knn(
field="embedding",
k=5,
num_candidates=10,
query_vector=vector
)

The ``field``, ``k`` and ``num_candidates`` arguments can be given as
positional or keyword arguments and are required. In addition to these,
``query_vector`` or ``query_vector_builder`` must be given as well.

The ``.knn()`` method can be invoked multiple times to include multiple kNN
searches in the request.


Sorting
~~~~~~~

Expand Down
72 changes: 71 additions & 1 deletion elasticsearch_dsl/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .aggs import A, AggBase
from .connections import get_connection
from .exceptions import IllegalOperation
from .query import Bool, Q
from .query import Bool, Q, Query
from .response import Hit, Response
from .utils import AttrDict, DslBase, recursive_to_dict

Expand Down Expand Up @@ -319,6 +319,7 @@ def __init__(self, **kwargs):
self.aggs = AggsProxy(self)
self._sort = []
self._collapse = {}
self._knn = []
self._source = None
self._highlight = {}
self._highlight_opts = {}
Expand Down Expand Up @@ -406,6 +407,7 @@ def _clone(self):
s = super()._clone()

s._response_class = self._response_class
s._knn = [knn.copy() for knn in self._knn]
s._collapse = self._collapse.copy()
s._sort = self._sort[:]
s._source = copy.copy(self._source) if self._source is not None else None
Expand Down Expand Up @@ -445,6 +447,10 @@ def update_from_dict(self, d):
self.aggs._params = {
"aggs": {name: A(value) for (name, value) in aggs.items()}
}
if "knn" in d:
self._knn = d.pop("knn")
if isinstance(self._knn, dict):
self._knn = [self._knn]
if "collapse" in d:
self._collapse = d.pop("collapse")
if "sort" in d:
Expand Down Expand Up @@ -494,6 +500,64 @@ def script_fields(self, **kwargs):
s._script_fields.update(kwargs)
return s

def knn(
self,
field,
k,
num_candidates,
query_vector=None,
query_vector_builder=None,
boost=None,
filter=None,
similarity=None,
):
"""
Add a k-nearest neighbor (kNN) search.

:arg field: the name of the vector field to search against
:arg k: number of nearest neighbors to return as top hits
:arg num_candidates: number of nearest neighbor candidates to consider per shard
:arg query_vector: the vector to search for
:arg query_vector_builder: A dictionary indicating how to build a query vector
:arg boost: A floating-point boost factor for kNN scores
:arg filter: query to filter the documents that can match
:arg similarity: the minimum similarity required for a document to be considered a match, as a float value

Example::

s = Search()
s = s.knn(field='embedding', k=5, num_candidates=10, query_vector=vector,
filter=Q('term', category='blog')))
"""
s = self._clone()
s._knn.append(
{
"field": field,
"k": k,
"num_candidates": num_candidates,
}
)
if query_vector is None and query_vector_builder is None:
raise ValueError("one of query_vector and query_vector_builder is required")
if query_vector is not None and query_vector_builder is not None:
raise ValueError(
"only one of query_vector and query_vector_builder must be given"
)
if query_vector is not None:
s._knn[-1]["query_vector"] = query_vector
if query_vector_builder is not None:
s._knn[-1]["query_vector_builder"] = query_vector_builder
if boost is not None:
s._knn[-1]["boost"] = boost
if filter is not None:
if isinstance(filter, Query):
s._knn[-1]["filter"] = filter.to_dict()
else:
s._knn[-1]["filter"] = filter
if similarity is not None:
s._knn[-1]["similarity"] = similarity
return s

def source(self, fields=None, **kwargs):
"""
Selectively control how the _source field is returned.
Expand Down Expand Up @@ -677,6 +741,12 @@ def to_dict(self, count=False, **kwargs):
if self.query:
d["query"] = self.query.to_dict()

if self._knn:
if len(self._knn) == 1:
d["knn"] = self._knn[0]
else:
d["knn"] = self._knn

# count request doesn't care for sorting and other things
if not count:
if self.post_filter:
Expand Down
54 changes: 54 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,60 @@ class MyDocument(Document):
assert s._doc_type_map == {}


def test_knn():
s = search.Search()

with raises(TypeError):
s.knn()
with raises(TypeError):
s.knn("field")
with raises(TypeError):
s.knn("field", 5)
with raises(ValueError):
s.knn("field", 5, 100)
with raises(ValueError):
s.knn("field", 5, 100, query_vector=[1, 2, 3], query_vector_builder={})

s = s.knn("field", 5, 100, query_vector=[1, 2, 3])
assert {
"knn": {
"field": "field",
"k": 5,
"num_candidates": 100,
"query_vector": [1, 2, 3],
}
} == s.to_dict()

s = s.knn(
k=4,
num_candidates=40,
boost=0.8,
field="name",
query_vector_builder={
"text_embedding": {"model_id": "foo", "model_text": "search text"}
},
)
assert {
"knn": [
{
"field": "field",
"k": 5,
"num_candidates": 100,
"query_vector": [1, 2, 3],
},
{
"field": "name",
"k": 4,
"num_candidates": 40,
"query_vector_builder": {
"text_embedding": {"model_id": "foo", "model_text": "search text"}
},
"boost": 0.8,
},
]
} == s.to_dict()


def test_sort():
s = search.Search()
s = s.sort("fielda", "-fieldb")
Expand Down
Loading