Skip to content

Commit

Permalink
Vector: Add wrapper for HNSW matching function KNN_MATCH
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Jun 9, 2024
1 parent 1abd9cc commit 84030bd
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 3 deletions.
10 changes: 9 additions & 1 deletion docs/working-with-types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ CrateDB's vector data type, :ref:`crate-reference:type-float_vector`,
allows to store dense vectors of float values of fixed length.
``float_vector`` values are defined like float arrays.

>>> from sqlalchemy_cratedb.type.vector import FloatVector
>>> from sqlalchemy_cratedb.type.vector import knn_match, FloatVector

>>> class SearchIndex(Base):
... __tablename__ = 'search'
Expand All @@ -284,6 +284,14 @@ When reading it back, the ``FLOAT_VECTOR`` value will be returned as a NumPy arr
>>> query.all()
[('foo', 'bar', array([42.42, 43.43, 44.44], dtype=float32))]

In order to apply search, i.e. to match embeddings against each other, use the
``KNN_MATCH`` function like this.

>>> query = session.query(SearchIndex.name) \
... .filter(knn_match(SearchIndex.embedding, [42.42, 43.43, 41.41], 3))
>>> query.all()
[('foo',)]

.. hidden: Disconnect from database
>>> session.close()
Expand Down
3 changes: 3 additions & 0 deletions src/sqlalchemy_cratedb/type/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@
import numpy.typing as npt # pragma: no cover

import sqlalchemy as sa
from sqlalchemy.sql.expression import ColumnElement, literal
from sqlalchemy.ext.compiler import compiles


__all__ = [
"from_db",
"knn_match",
"to_db",
"FloatVector",
]
Expand Down
51 changes: 49 additions & 2 deletions tests/vector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,20 @@

import pytest
import sqlalchemy as sa
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.sql import select

from sqlalchemy_cratedb import SA_VERSION, SA_1_4
from sqlalchemy_cratedb.type import FloatVector

try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base

from crate.client.cursor import Cursor

from sqlalchemy_cratedb.type.vector import from_db, to_db
from sqlalchemy_cratedb.type.vector import from_db, knn_match, to_db

fake_cursor = MagicMock(name="fake_cursor")
FakeCursor = MagicMock(name="FakeCursor", spec=Cursor)
Expand Down Expand Up @@ -102,6 +107,14 @@ def test_sql_select(self):
"SELECT testdrive.data FROM testdrive", select(self.table.c.data)
)

def test_sql_match(self):
query = self.session.query(self.table.c.name) \
.filter(knn_match(self.table.c.data, [42.42, 43.43], 3))
self.assertSQL(
"SELECT testdrive.name AS testdrive_name FROM testdrive WHERE knn_match(testdrive.data, ?, ?)",
query
)


def test_from_db_success():
"""
Expand Down Expand Up @@ -201,3 +214,37 @@ def test_float_vector_as_generic():
fv = FloatVector(3)
assert isinstance(fv.as_generic(), sa.ARRAY)
assert fv.python_type is list


def test_float_vector_integration():
"""
An integration test for `FLOAT_VECTOR` and `KNN_SEARCH`.
"""
np = pytest.importorskip("numpy")

engine = sa.create_engine(f"crate://")
session = sessionmaker(bind=engine)()
Base = declarative_base()

# Define DDL.
class SearchIndex(Base):
__tablename__ = 'search'
name = sa.Column(sa.String, primary_key=True)
embedding = sa.Column(FloatVector(3))

Base.metadata.drop_all(engine, checkfirst=True)
Base.metadata.create_all(engine, checkfirst=True)

# Insert record.
foo_item = SearchIndex(name="foo", embedding=[42.42, 43.43, 44.44])
session.add(foo_item)
session.commit()
session.execute(sa.text("REFRESH TABLE search"))

# Query record.
query = session.query(SearchIndex.embedding) \
.filter(knn_match(SearchIndex.embedding, [42.42, 43.43, 41.41], 3))
result = query.first()

# Compare outcome.
assert np.array_equal(result.embedding, np.array([42.42, 43.43, 44.44], np.float32))

0 comments on commit 84030bd

Please sign in to comment.