diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index f585f3e..295f8f5 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -54,16 +54,13 @@ jobs: cache-dependency-path: pyproject.toml - - name: Set up project + - name: Update setuptools run: | # `setuptools 0.64.0` adds support for editable install hooks (PEP 660). # https://github.com/pypa/setuptools/blob/main/CHANGES.rst#v6400 pip install "setuptools>=64" --upgrade - # Install package in editable mode. - pip install --use-pep517 --prefer-binary --editable='.[develop,test]' - - name: Invoke tests run: | diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1980065..04e306e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -64,16 +64,13 @@ jobs: cache-dependency-path: pyproject.toml - - name: Set up project + - name: Update setuptools run: | # `setuptools 0.64.0` adds support for editable install hooks (PEP 660). # https://github.com/pypa/setuptools/blob/main/CHANGES.rst#v6400 pip install "setuptools>=64" --upgrade - # Install package in editable mode. - pip install --use-pep517 --prefer-binary --editable='.[develop,test]' - - name: Invoke tests run: | diff --git a/CHANGES.md b/CHANGES.md index b897543..22d0726 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,12 @@ ## Unreleased +- Added support for CrateDB's [FLOAT_VECTOR] data type and its accompanying + [KNN_MATCH] function, for HNSW matches. For SQLAlchemy column definitions, + you can use it like `FloatVector(dimensions=1536)`. + +[FLOAT_VECTOR]: https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector +[KNN_MATCH]: https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match ## 2024/06/11 0.36.1 - Dependencies: Use `crate==1.0.0dev0` diff --git a/DEVELOP.md b/DEVELOP.md index 041c40f..89c577f 100644 --- a/DEVELOP.md +++ b/DEVELOP.md @@ -16,6 +16,7 @@ further commands. 
Verify code by running all linters and software tests: + export CRATEDB_VERSION=latest docker compose -f tests/docker-compose.yml up poe check diff --git a/bootstrap.sh b/bootstrap.sh index 380f70c..a3e66b4 100644 --- a/bootstrap.sh +++ b/bootstrap.sh @@ -68,7 +68,7 @@ function setup_package() { fi # Install package in editable mode. - pip install ${PIP_OPTIONS} --editable='.[develop,test]' + pip install ${PIP_OPTIONS} --use-pep517 --prefer-binary --editable='.[all,develop,test]' # Install designated SQLAlchemy version. if [ -n "${SQLALCHEMY_VERSION}" ]; then diff --git a/docs/data-types.rst b/docs/data-types.rst index 51a3e65..4b06aca 100644 --- a/docs/data-types.rst +++ b/docs/data-types.rst @@ -45,6 +45,7 @@ CrateDB SQLAlchemy `integer`__ `Integer`__ `long`__ `NUMERIC`__ `float`__ `Float`__ +`float_vector`__ ``FloatVector`` `double`__ `DECIMAL`__ `timestamp`__ `TIMESTAMP`__ `string`__ `String`__ @@ -68,6 +69,7 @@ __ https://crate.io/docs/crate/reference/en/latest/general/ddl/data-types.html#n __ http://docs.sqlalchemy.org/en/latest/core/type_basics.html#sqlalchemy.types.NUMERIC __ https://crate.io/docs/crate/reference/en/latest/general/ddl/data-types.html#numeric-data __ http://docs.sqlalchemy.org/en/latest/core/type_basics.html#sqlalchemy.types.Float +__ https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector __ https://crate.io/docs/crate/reference/en/latest/general/ddl/data-types.html#numeric-data __ http://docs.sqlalchemy.org/en/latest/core/type_basics.html#sqlalchemy.types.DECIMAL __ https://crate.io/docs/crate/reference/en/latest/general/ddl/data-types.html#dates-and-times diff --git a/docs/index.rst b/docs/index.rst index 864dd6e..3f42ac4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -50,8 +50,9 @@ Install package from PyPI. 
pip install sqlalchemy-cratedb The CrateDB dialect for `SQLAlchemy`_ offers convenient ORM access and supports -CrateDB's ``OBJECT``, ``ARRAY``, and geospatial data types using `GeoJSON`_, -supporting different kinds of `GeoJSON geometry objects`_. +CrateDB's container data types ``OBJECT`` and ``ARRAY``, its vector data type +``FLOAT_VECTOR``, and geospatial data types using `GeoJSON`_, supporting different +kinds of `GeoJSON geometry objects`_. .. toctree:: :maxdepth: 2 diff --git a/docs/working-with-types.rst b/docs/working-with-types.rst index d5ccd3a..9fc6516 100644 --- a/docs/working-with-types.rst +++ b/docs/working-with-types.rst @@ -9,6 +9,7 @@ from the CrateDB SQLAlchemy dialect. Currently, these are: - Container types ``ObjectType`` and ``ObjectArray``. - Geospatial types ``Geopoint`` and ``Geoshape``. +- Vector data type ``FloatVector``. .. rubric:: Table of Contents @@ -257,6 +258,41 @@ objects: [('Tokyo', (139.75999999791384, 35.67999996710569), {"coordinates": [[[139.806, 35.515], [139.919, 35.703], [139.768, 35.817], [139.575, 35.76], [139.584, 35.619], [139.806, 35.515]]], "type": "Polygon"})] +Vector type +=========== + +CrateDB's vector data type, :ref:`crate-reference:type-float_vector`, +allows storing dense vectors of float values of fixed length. + + >>> from sqlalchemy_cratedb import FloatVector, knn_match + + >>> class SearchIndex(Base): + ... __tablename__ = 'search' + ... name = sa.Column(sa.String, primary_key=True) + ... embedding = sa.Column(FloatVector(3)) + +Create an entity and store it in the database. ``float_vector`` values +can be defined by using arrays of floating point numbers. + + >>> foo_item = SearchIndex(name="foo", embedding=[42.42, 43.43, 44.44]) + >>> session.add(foo_item) + >>> session.commit() + >>> _ = connection.execute(sa.text("REFRESH TABLE search")) + +When reading it back, the ``FLOAT_VECTOR`` value will be returned as a NumPy array. 
+ + >>> query = session.query(SearchIndex.name, SearchIndex.embedding) + >>> query.all() + [('foo', array([42.42, 43.43, 44.44], dtype=float32))] + +In order to apply search, i.e. to match embeddings against each other, use the +:ref:`crate-reference:scalar_knn_match` function like this. + + >>> query = session.query(SearchIndex.name) \ + ... .filter(knn_match(SearchIndex.embedding, [42.42, 43.43, 41.41], 3)) + >>> query.all() + [('foo',)] + .. hidden: Disconnect from database >>> session.close() diff --git a/pyproject.toml b/pyproject.toml index cab4400..30c1363 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,6 +92,9 @@ dependencies = [ "verlib2==0.2", ] [project.optional-dependencies] +all = [ + "sqlalchemy-cratedb[vector]", +] develop = [ "black<25", "mypy<1.11", @@ -116,6 +119,9 @@ test = [ "pytest-cov<6", "pytest-mock<4", ] +vector = [ + "numpy", +] [project.urls] changelog = "https://github.com/crate-workbench/sqlalchemy-cratedb/blob/main/CHANGES.md" documentation = "https://github.com/crate-workbench/sqlalchemy-cratedb" diff --git a/src/sqlalchemy_cratedb/__init__.py b/src/sqlalchemy_cratedb/__init__.py index 2ef5915..297e8fd 100644 --- a/src/sqlalchemy_cratedb/__init__.py +++ b/src/sqlalchemy_cratedb/__init__.py @@ -27,6 +27,7 @@ from .type.array import ObjectArray from .type.geo import Geopoint, Geoshape from .type.object import ObjectType +from .type.vector import FloatVector, knn_match if SA_VERSION < SA_1_4: import textwrap @@ -51,9 +52,11 @@ __all__ = [ dialect, + FloatVector, Geopoint, Geoshape, ObjectArray, ObjectType, match, + knn_match, ] diff --git a/src/sqlalchemy_cratedb/compiler.py b/src/sqlalchemy_cratedb/compiler.py index 07106b8..d3a6618 100644 --- a/src/sqlalchemy_cratedb/compiler.py +++ b/src/sqlalchemy_cratedb/compiler.py @@ -238,6 +238,12 @@ def visit_ARRAY(self, type_, **kw): def visit_OBJECT(self, type_, **kw): return "OBJECT" + def visit_FLOAT_VECTOR(self, type_, **kw): + dimensions = type_.dimensions  # vector length the column type was declared with + if dimensions is None:  # no length was given, so no `FLOAT_VECTOR(n)` clause can be rendered
+ raise ValueError("FloatVector must be initialized with dimension size") + return f"FLOAT_VECTOR({dimensions})"  # e.g. `FLOAT_VECTOR(3)` + class CrateCompiler(compiler.SQLCompiler): diff --git a/src/sqlalchemy_cratedb/dialect.py b/src/sqlalchemy_cratedb/dialect.py index 216fb11..53fae73 100644 --- a/src/sqlalchemy_cratedb/dialect.py +++ b/src/sqlalchemy_cratedb/dialect.py @@ -33,7 +33,7 @@ ) from crate.client.exceptions import TimezoneUnawareException from .sa_version import SA_VERSION, SA_1_4, SA_2_0 -from .type import ObjectArray, ObjectType +from .type import FloatVector, ObjectArray, ObjectType TYPES_MAP = { "boolean": sqltypes.Boolean, @@ -51,7 +51,8 @@ "float": sqltypes.Float, "real": sqltypes.Float, "string": sqltypes.String, - "text": sqltypes.String + "text": sqltypes.String, + "float_vector": FloatVector, } try: # SQLAlchemy >= 1.1 diff --git a/src/sqlalchemy_cratedb/type/__init__.py b/src/sqlalchemy_cratedb/type/__init__.py index 8e78f7d..36ba817 100644 --- a/src/sqlalchemy_cratedb/type/__init__.py +++ b/src/sqlalchemy_cratedb/type/__init__.py @@ -1,3 +1,4 @@ from .array import ObjectArray from .geo import Geopoint, Geoshape from .object import ObjectType +from .vector import FloatVector, knn_match diff --git a/src/sqlalchemy_cratedb/type/vector.py b/src/sqlalchemy_cratedb/type/vector.py new file mode 100644 index 0000000..56e1f50 --- /dev/null +++ b/src/sqlalchemy_cratedb/type/vector.py @@ -0,0 +1,173 @@ +""" +## About +SQLAlchemy data type implementation for CrateDB's `FLOAT_VECTOR` type. + +## References +- https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector +- https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match + +## Details +The implementation is based on SQLAlchemy's `TypeDecorator`, and also +offers compiler support. + +## Notes +CrateDB currently only supports the similarity function `VectorSimilarityFunction.EUCLIDEAN`. 
+-- https://github.com/crate/crate/blob/5.5.1/server/src/main/java/io/crate/types/FloatVectorType.java#L55 + +pgvector uses a comparator to apply different similarity functions as operators, +see `pgvector.sqlalchemy.Vector.comparator_factory`. + +<->: l2/euclidean_distance +<#>: max_inner_product +<=>: cosine_distance + +## Backlog +- After dropping support for SQLAlchemy 1.3, use + `class FloatVector(sa.TypeDecorator[t.Sequence[float]]):` + +## Origin +This module is based on the corresponding pgvector implementation +by Andrew Kane. Thank you. + +The MIT License (MIT) +Copyright (c) 2021-2023 Andrew Kane +https://github.com/pgvector/pgvector-python +""" +import typing as t + +if t.TYPE_CHECKING: + import numpy.typing as npt # pragma: no cover + +import sqlalchemy as sa +from sqlalchemy.sql.expression import ColumnElement, literal +from sqlalchemy.ext.compiler import compiles + + +__all__ = [ + "from_db", + "knn_match", + "to_db", + "FloatVector", +] + + +def from_db(value: t.Iterable) -> t.Optional["npt.ArrayLike"]: + import numpy as np  # local import: NumPy is declared under the optional `vector` extra in pyproject.toml + + # from `pgvector.utils` + # could be ndarray if already cast by lower-level driver + if value is None or isinstance(value, np.ndarray): + return value + + return np.array(value, dtype=np.float32)  # everything else is converted to a float32 ndarray + + +def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]: + import numpy as np  # local import, as in `from_db` + + # from `pgvector.utils` + if value is None: + return value + + if isinstance(value, np.ndarray): + if value.ndim != 1: + raise ValueError("expected ndim to be 1") + + if not np.issubdtype(value.dtype, np.integer) and not np.issubdtype(value.dtype, np.floating): + raise ValueError("dtype must be numeric") + + value = value.tolist()  # continue with a plain Python list + + if dim is not None and len(value) != dim:  # length check is skipped when no `dim` is given + raise ValueError("expected %d dimensions, not %d" % (dim, len(value))) + + return value + + +class FloatVector(sa.TypeDecorator): + """ + SQLAlchemy `FloatVector` data type for CrateDB. 
+ """ + + cache_ok = False + + __visit_name__ = "FLOAT_VECTOR"  # type compiler hook: `visit_FLOAT_VECTOR` in compiler.py + + _is_array = True + + zero_indexes = False + + impl = sa.ARRAY  # underlying type; `__init__` instantiates it with `sa.FLOAT` items + + def __init__(self, dimensions: int = None):  # NOTE(review): `None` is accepted here (so `t.Optional[int]` would be the accurate hint) and only rejected when DDL is compiled + super().__init__(sa.FLOAT, dimensions=dimensions) + + def as_generic(self, allow_nulltype=False): + return sa.ARRAY(item_type=sa.FLOAT) + + @property + def python_type(self): + return list + + def bind_processor(self, dialect: sa.engine.Dialect) -> t.Callable: + def process(value: t.Iterable) -> t.Optional[t.List]: + return to_db(value, self.dimensions)  # enforces the declared vector length on bound values + + return process + + def result_processor(self, dialect: sa.engine.Dialect, coltype: t.Any) -> t.Callable: + def process(value: t.Any) -> t.Optional["npt.ArrayLike"]: + return from_db(value) + + return process + + +class KnnMatch(ColumnElement): + """ + Wrap CrateDB's `KNN_MATCH` function as an SQLAlchemy column element. + + https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match + """ + inherit_cache = True + + def __init__(self, column, term, k=None): + super().__init__() + self.column = column + self.term = term + self.k = k + + def compile_column(self, compiler): + return compiler.process(self.column) + + def compile_term(self, compiler): + return compiler.process(literal(self.term))  # `literal()` turns the search vector into a bind parameter + + def compile_k(self, compiler): + return compiler.process(literal(self.k)) + + +def knn_match(column, term, k): + """ + Generate a match predicate for vector search. + + :param column: A reference to a column or an index. + + :param term: The term to match against. This is an array of floating point + values, which is compared to other vectors using an HNSW index search. + + :param k: The `k` argument determines the number of nearest neighbours to + search in the index. + """ + return KnnMatch(column, term, k) + + +@compiles(KnnMatch) +def compile_knn_match(knn_match, compiler, **kwargs): + """ + Clause compiler for `KNN_MATCH`. 
+ """ + return "KNN_MATCH(%s, %s, %s)" % ( + knn_match.compile_column(compiler), + knn_match.compile_term(compiler), + knn_match.compile_k(compiler), + ) diff --git a/tests/compiler_test.py b/tests/compiler_test.py index 8c1eccf..6773b75 100644 --- a/tests/compiler_test.py +++ b/tests/compiler_test.py @@ -18,6 +18,7 @@ # However, if you have executed another commercial license agreement # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. +import sys import warnings from textwrap import dedent from unittest import mock, skipIf, TestCase @@ -288,6 +289,8 @@ def test_for_update(self): FakeCursor = MagicMock(name='FakeCursor', spec=Cursor) +@skipIf(SA_VERSION < SA_1_4 and (3, 9) <= sys.version_info < (3, 10), + "SQLAlchemy 1.3 has problems with these test cases on Python 3.9") class CompilerTestCase(TestCase): """ A base class for providing mocking infrastructure to validate the DDL compiler. diff --git a/tests/integration.py b/tests/integration.py index 968099f..5e262fc 100644 --- a/tests/integration.py +++ b/tests/integration.py @@ -105,7 +105,13 @@ def provision_database(): name STRING PRIMARY KEY, coordinate GEO_POINT, area GEO_SHAPE - )""" + )""", + """ + CREATE TABLE search ( + name STRING PRIMARY KEY, + text STRING, + embedding FLOAT_VECTOR(3) + )""", ] _execute_statements(ddl_statements) @@ -120,6 +126,7 @@ def drop_tables(): "DROP TABLE IF EXISTS cities", "DROP TABLE IF EXISTS locations", "DROP BLOB TABLE IF EXISTS myfiles", + "DROP TABLE IF EXISTS search", 'DROP TABLE IF EXISTS "test-testdrive"', "DROP TABLE IF EXISTS todos", 'DROP TABLE IF EXISTS "user"', diff --git a/tests/vector_test.py b/tests/vector_test.py new file mode 100644 index 0000000..245ed30 --- /dev/null +++ b/tests/vector_test.py @@ -0,0 +1,249 @@ +# -*- coding: utf-8; -*- +# +# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor +# license agreements. 
See the NOTICE file distributed with this work for +# additional information regarding copyright ownership. Crate licenses +# this file to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may +# obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# However, if you have executed another commercial license agreement +# with Crate these terms will supersede the license and you may use the +# software solely pursuant to the terms of the relevant commercial agreement. + +from __future__ import absolute_import + +import re +import sys +from unittest import TestCase +from unittest.mock import MagicMock, patch + +import pytest +import sqlalchemy as sa +from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.sql import select + +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base + +from crate.client.cursor import Cursor + +from sqlalchemy_cratedb import SA_VERSION, SA_1_4 +from sqlalchemy_cratedb import FloatVector, knn_match +from sqlalchemy_cratedb.type.vector import from_db, to_db + +fake_cursor = MagicMock(name="fake_cursor") +FakeCursor = MagicMock(name="FakeCursor", spec=Cursor) +FakeCursor.return_value = fake_cursor + + +if SA_VERSION < SA_1_4: + pytest.skip(reason="The FloatVector type is not supported on SQLAlchemy 1.3 and earlier", allow_module_level=True) + + +@patch("crate.client.connection.Cursor", FakeCursor) +class SqlAlchemyVectorTypeTest(TestCase): + """ + Verify compilation of SQL statements where the schema includes the `FloatVector` 
type. + """ + def setUp(self): + self.engine = sa.create_engine("crate://") + metadata = sa.MetaData() + self.table = sa.Table( + "testdrive", + metadata, + sa.Column("name", sa.String), + sa.Column("data", FloatVector(3)), + ) + self.session = Session(bind=self.engine) + + def assertSQL(self, expected_str, actual_expr): + self.assertEqual(expected_str, str(actual_expr).replace('\n', '')) + + def test_create_invoke(self): + self.table.create(self.engine) + fake_cursor.execute.assert_called_with( + ( + "\nCREATE TABLE testdrive (\n\t" + "name STRING, \n\t" + "data FLOAT_VECTOR(3)\n)\n\n" + ), + (), + ) + + def test_insert_invoke(self): + stmt = self.table.insert().values( + name="foo", data=[42.42, 43.43, 44.44] + ) + with self.engine.connect() as conn: + conn.execute(stmt) + fake_cursor.execute.assert_called_with( + ("INSERT INTO testdrive (name, data) VALUES (?, ?)"), + ("foo", [42.42, 43.43, 44.44]), + ) + + def test_select_invoke(self): + stmt = select(self.table.c.data) + with self.engine.connect() as conn: + conn.execute(stmt) + fake_cursor.execute.assert_called_with( + ("SELECT testdrive.data \nFROM testdrive"), + (), + ) + + def test_sql_select(self): + self.assertSQL( + "SELECT testdrive.data FROM testdrive", select(self.table.c.data) + ) + + def test_sql_match(self): + query = self.session.query(self.table.c.name) \ + .filter(knn_match(self.table.c.data, [42.42, 43.43], 3))  # two values vs. `FloatVector(3)`: only SQL rendering is checked here + self.assertSQL( + "SELECT testdrive.name AS testdrive_name FROM testdrive WHERE KNN_MATCH(testdrive.data, ?, ?)", + query + ) + + +def test_from_db_success(): + """ + Verify succeeding uses of `sqlalchemy_cratedb.type.vector.from_db`. 
+ """ + np = pytest.importorskip("numpy") + assert from_db(None) is None + assert np.array_equal(from_db(False), np.array(0., dtype=np.float32))  # scalar input yields a 0-d array + assert np.array_equal(from_db(True), np.array(1., dtype=np.float32)) + assert np.array_equal(from_db(42), np.array(42, dtype=np.float32)) + assert np.array_equal(from_db(42.42), np.array(42.42, dtype=np.float32)) + assert np.array_equal(from_db([42.42, 43.43]), np.array([42.42, 43.43], dtype=np.float32)) + assert np.array_equal(from_db("42.42"), np.array(42.42, dtype=np.float32)) + assert np.array_equal(from_db(["42.42", "43.43"]), np.array([42.42, 43.43], dtype=np.float32)) + + +def test_from_db_failure(): + """ + Verify failing uses of `sqlalchemy_cratedb.type.vector.from_db`. + """ + pytest.importorskip("numpy") + + with pytest.raises(ValueError) as ex: + from_db("foo") + assert ex.match("could not convert string to float: 'foo'") + + with pytest.raises(ValueError) as ex: + from_db(["foo"]) + assert ex.match("could not convert string to float: 'foo'") + + with pytest.raises(TypeError) as ex: + from_db({"foo": "bar"}) + if sys.version_info < (3, 10): + assert ex.match(re.escape("float() argument must be a string or a number, not 'dict'")) + else: + assert ex.match(re.escape("float() argument must be a string or a real number, not 'dict'")) + + +def test_to_db_success(): + """ + Verify succeeding uses of `sqlalchemy_cratedb.type.vector.to_db`. 
+ """ + np = pytest.importorskip("numpy") + assert to_db(None) is None + assert to_db(False) is False  # without `dim`, non-ndarray values pass through unvalidated + assert to_db(True) is True + assert to_db(42) == 42 + assert to_db(42.42) == 42.42 + assert to_db([42.42, 43.43]) == [42.42, 43.43] + assert to_db(np.array([42.42, 43.43])) == [42.42, 43.43] + assert to_db("42.42") == "42.42" + assert to_db("foo") == "foo" + assert to_db(["foo"]) == ["foo"] + assert to_db({"foo": "bar"}) == {"foo": "bar"} + assert isinstance(to_db(object()), object)  # NOTE(review): always true; this only proves the call does not raise + + +def test_to_db_failure(): + """ + Verify failing uses of `sqlalchemy_cratedb.type.vector.to_db`. + """ + np = pytest.importorskip("numpy") + + with pytest.raises(ValueError) as ex: + to_db(np.array(["42.42", "43.43"])) + assert ex.match("dtype must be numeric") + + with pytest.raises(ValueError) as ex: + to_db(np.array([42.42, 43.43]), dim=33) + assert ex.match("expected 33 dimensions, not 2") + + with pytest.raises(ValueError) as ex: + to_db(np.array([[42.42, 43.43]])) + assert ex.match("expected ndim to be 1") + + +def test_float_vector_no_dimension_size(): + """ + Verify that creating a table fails when its FloatVector has no dimension size. + """ + engine = sa.create_engine("crate://") + metadata = sa.MetaData() + table = sa.Table( + "foo", + metadata, + sa.Column("data", FloatVector), + ) + with pytest.raises(ValueError) as ex: + table.create(engine) + ex.match("FloatVector must be initialized with dimension size")  # `ExceptionInfo.match` raises AssertionError on a mismatch + + +def test_float_vector_as_generic(): + """ + Verify the `as_generic` and `python_type` method/property on the FloatVector type object. + """ + fv = FloatVector(3) + assert isinstance(fv.as_generic(), sa.ARRAY) + assert fv.python_type is list + + +def test_float_vector_integration(): + """ + An integration test for `FLOAT_VECTOR` and `KNN_MATCH`. + """ + np = pytest.importorskip("numpy") + + engine = sa.create_engine(f"crate://")  # NOTE(review): the `f` prefix is redundant, the string has no placeholders + session = sessionmaker(bind=engine)()  # NOTE(review): no FakeCursor patch here, so this presumably needs a reachable CrateDB -- confirm + Base = declarative_base() + + # Define the ORM model; the table itself is (re-)created below. 
+ class SearchIndex(Base): + __tablename__ = 'search' + name = sa.Column(sa.String, primary_key=True) + embedding = sa.Column(FloatVector(3)) + + Base.metadata.drop_all(engine, checkfirst=True) + Base.metadata.create_all(engine, checkfirst=True) + + # Insert record. + foo_item = SearchIndex(name="foo", embedding=[42.42, 43.43, 44.44]) + session.add(foo_item) + session.commit() + session.execute(sa.text("REFRESH TABLE search")) + + # Query record. + query = session.query(SearchIndex.embedding) \ + .filter(knn_match(SearchIndex.embedding, [42.42, 43.43, 41.41], 3)) + result = query.first() + + # Compare outcome. + assert np.array_equal(result.embedding, np.array([42.42, 43.43, 44.44], np.float32))