From de2aee42a46bf5786fe8085e2f4f4200fc3170a8 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 8 Jun 2024 14:14:58 +0200 Subject: [PATCH] Vector: Add wrapper for HNSW matching function `KNN_MATCH` --- docs/working-with-types.rst | 10 ++++- src/sqlalchemy_cratedb/__init__.py | 3 +- src/sqlalchemy_cratedb/type/__init__.py | 2 +- src/sqlalchemy_cratedb/type/vector.py | 17 ++++---- tests/vector_test.py | 52 +++++++++++++++++++++++-- 5 files changed, 69 insertions(+), 15 deletions(-) diff --git a/docs/working-with-types.rst b/docs/working-with-types.rst index b6283e37..9fc65168 100644 --- a/docs/working-with-types.rst +++ b/docs/working-with-types.rst @@ -264,7 +264,7 @@ Vector type CrateDB's vector data type, :ref:`crate-reference:type-float_vector`, allows to store dense vectors of float values of fixed length. - >>> from sqlalchemy_cratedb.type.vector import FloatVector + >>> from sqlalchemy_cratedb import FloatVector, knn_match >>> class SearchIndex(Base): ... __tablename__ = 'search' @@ -285,6 +285,14 @@ When reading it back, the ``FLOAT_VECTOR`` value will be returned as a NumPy arr >>> query.all() [('foo', array([42.42, 43.43, 44.44], dtype=float32))] +In order to apply search, i.e. to match embeddings against each other, use the +:ref:`crate-reference:scalar_knn_match` function like this. + + >>> query = session.query(SearchIndex.name) \ + ... .filter(knn_match(SearchIndex.embedding, [42.42, 43.43, 41.41], 3)) + >>> query.all() + [('foo',)] + .. hidden: Disconnect from database >>> session.close() diff --git a/src/sqlalchemy_cratedb/__init__.py b/src/sqlalchemy_cratedb/__init__.py index d65bb395..297e8fd9 100644 --- a/src/sqlalchemy_cratedb/__init__.py +++ b/src/sqlalchemy_cratedb/__init__.py @@ -27,7 +27,7 @@ from .type.array import ObjectArray from .type.geo import Geopoint, Geoshape from .type.object import ObjectType -from .type.vector import FloatVector +from .type.vector import FloatVector, knn_match if SA_VERSION < SA_1_4: import textwrap @@ -58,4 +58,5 @@ ObjectArray, ObjectType, match, + knn_match, ] diff --git a/src/sqlalchemy_cratedb/type/__init__.py b/src/sqlalchemy_cratedb/type/__init__.py index 5bd871dc..36ba8173 100644 --- a/src/sqlalchemy_cratedb/type/__init__.py +++ b/src/sqlalchemy_cratedb/type/__init__.py @@ -1,4 +1,4 @@ from .array import ObjectArray from .geo import Geopoint, Geoshape from .object import ObjectType -from .vector import FloatVector +from .vector import FloatVector, knn_match diff --git a/src/sqlalchemy_cratedb/type/vector.py b/src/sqlalchemy_cratedb/type/vector.py index 01f62b2b..56e1f50c 100644 --- a/src/sqlalchemy_cratedb/type/vector.py +++ b/src/sqlalchemy_cratedb/type/vector.py @@ -22,9 +22,6 @@ <=>: cosine_distance ## Backlog -- The type implementation might want to be accompanied by corresponding support - for the `KNN_MATCH` function, similar to what the dialect already offers for - fulltext search through its `Match` predicate. - After dropping support for SQLAlchemy 1.3, use `class FloatVector(sa.TypeDecorator[t.Sequence[float]]):` @@ -42,10 +39,13 @@ import numpy.typing as npt # pragma: no cover import sqlalchemy as sa +from sqlalchemy.sql.expression import ColumnElement, literal +from sqlalchemy.ext.compiler import compiles __all__ = [ "from_db", + "knn_match", "to_db", "FloatVector", ] @@ -131,7 +131,7 @@ class KnnMatch(ColumnElement): inherit_cache = True def __init__(self, column, term, k=None): - super(KnnMatch, self).__init__() + super().__init__() self.column = column self.term = term self.k = k @@ -150,11 +150,10 @@ def knn_match(column, term, k): """ Generate a match predicate for vector search. - :param column: A reference to a column or an index, or a subcolumn, or a - dictionary of subcolumns with boost values. + :param column: A reference to a column or an index. :param term: The term to match against. This is an array of floating point - values, which is compared to other vectors using a HNSW index. + values, which is compared to other vectors using a HNSW index search. :param k: The `k` argument determines the number of nearest neighbours to search in the index. @@ -165,9 +164,9 @@ def knn_match(column, term, k): @compiles(KnnMatch) def compile_knn_match(knn_match, compiler, **kwargs): """ - Clause compiler for `knn_match`. + Clause compiler for `KNN_MATCH`. """ - return "knn_match(%s, %s, %s)" % ( + return "KNN_MATCH(%s, %s, %s)" % ( knn_match.compile_column(compiler), knn_match.compile_term(compiler), knn_match.compile_k(compiler), diff --git a/tests/vector_test.py b/tests/vector_test.py index 38a52929..245ed308 100644 --- a/tests/vector_test.py +++ b/tests/vector_test.py @@ -28,14 +28,18 @@ import pytest import sqlalchemy as sa -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.sql import select -from sqlalchemy_cratedb import SA_VERSION, SA_1_4 -from sqlalchemy_cratedb.type import FloatVector +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base from crate.client.cursor import Cursor +from sqlalchemy_cratedb import SA_VERSION, SA_1_4 +from sqlalchemy_cratedb import FloatVector, knn_match from sqlalchemy_cratedb.type.vector import from_db, to_db fake_cursor = MagicMock(name="fake_cursor") @@ -102,6 +106,14 @@ def test_sql_select(self): "SELECT testdrive.data FROM testdrive", select(self.table.c.data) ) + def test_sql_match(self): + query = self.session.query(self.table.c.name) \ + .filter(knn_match(self.table.c.data, [42.42, 43.43], 3)) + self.assertSQL( + "SELECT testdrive.name AS testdrive_name FROM testdrive WHERE KNN_MATCH(testdrive.data, ?, ?)", + query + ) + def test_from_db_success(): """ @@ -201,3 +213,37 @@ def test_float_vector_as_generic(): fv = FloatVector(3) assert isinstance(fv.as_generic(), sa.ARRAY) assert fv.python_type is list + + +def test_float_vector_integration(): + """ + An integration test for `FLOAT_VECTOR` and `KNN_SEARCH`. + """ + np = pytest.importorskip("numpy") + + engine = sa.create_engine(f"crate://") + session = sessionmaker(bind=engine)() + Base = declarative_base() + + # Define DDL. + class SearchIndex(Base): + __tablename__ = 'search' + name = sa.Column(sa.String, primary_key=True) + embedding = sa.Column(FloatVector(3)) + + Base.metadata.drop_all(engine, checkfirst=True) + Base.metadata.create_all(engine, checkfirst=True) + + # Insert record. + foo_item = SearchIndex(name="foo", embedding=[42.42, 43.43, 44.44]) + session.add(foo_item) + session.commit() + session.execute(sa.text("REFRESH TABLE search")) + + # Query record. + query = session.query(SearchIndex.embedding) \ + .filter(knn_match(SearchIndex.embedding, [42.42, 43.43, 41.41], 3)) + result = query.first() + + # Compare outcome. + assert np.array_equal(result.embedding, np.array([42.42, 43.43, 44.44], np.float32))