From d92d1d8c6a23f30b2f08c203e1236d2a814da257 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Thu, 21 Dec 2023 15:10:33 +0100 Subject: [PATCH 1/6] Vector: Add support for CrateDB's `FLOAT_VECTOR` data type: `FloatVector` https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector --- CHANGES.md | 6 + docs/data-types.rst | 2 + docs/index.rst | 5 +- pyproject.toml | 6 + src/sqlalchemy_cratedb/__init__.py | 2 + src/sqlalchemy_cratedb/compiler.py | 6 + src/sqlalchemy_cratedb/dialect.py | 5 +- src/sqlalchemy_cratedb/type/__init__.py | 1 + src/sqlalchemy_cratedb/type/vector.py | 173 ++++++++++++++++++++++++ 9 files changed, 202 insertions(+), 4 deletions(-) create mode 100644 src/sqlalchemy_cratedb/type/vector.py diff --git a/CHANGES.md b/CHANGES.md index b8975439..22d0726a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,12 @@ ## Unreleased +- Added support for CrateDB's [FLOAT_VECTOR] data type and its accompanying + [KNN_MATCH] function, for HNSW matches. For SQLAlchemy column definitions, + you can use it like `FloatVector(dimensions=1536)`. + +[FLOAT_VECTOR]: https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector +[KNN_MATCH]: https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match ## 2024/06/11 0.36.1 - Dependencies: Use `crate==1.0.0dev0` diff --git a/docs/data-types.rst b/docs/data-types.rst index 51a3e659..4b06aca5 100644 --- a/docs/data-types.rst +++ b/docs/data-types.rst @@ -45,6 +45,7 @@ CrateDB SQLAlchemy `integer`__ `Integer`__ `long`__ `NUMERIC`__ `float`__ `Float`__ +`float_vector`__ ``FloatVector`` `double`__ `DECIMAL`__ `timestamp`__ `TIMESTAMP`__ `string`__ `String`__ @@ -68,6 +69,7 @@ __ https://crate.io/docs/crate/reference/en/latest/general/ddl/data-types.html#n __ http://docs.sqlalchemy.org/en/latest/core/type_basics.html#sqlalchemy.types.NUMERIC __ https://crate.io/docs/crate/reference/en/latest/general/ddl/data-types.html#numeric-data __ http://docs.sqlalchemy.org/en/latest/core/type_basics.html#sqlalchemy.types.Float +__ https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector __ https://crate.io/docs/crate/reference/en/latest/general/ddl/data-types.html#numeric-data __ http://docs.sqlalchemy.org/en/latest/core/type_basics.html#sqlalchemy.types.DECIMAL __ https://crate.io/docs/crate/reference/en/latest/general/ddl/data-types.html#dates-and-times diff --git a/docs/index.rst b/docs/index.rst index 864dd6e8..3f42ac4e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -50,8 +50,9 @@ Install package from PyPI. pip install sqlalchemy-cratedb The CrateDB dialect for `SQLAlchemy`_ offers convenient ORM access and supports -CrateDB's ``OBJECT``, ``ARRAY``, and geospatial data types using `GeoJSON`_, -supporting different kinds of `GeoJSON geometry objects`_. +CrateDB's container data types ``OBJECT`` and ``ARRAY``, its vector data type +``FLOAT_VECTOR``, and geospatial data types using `GeoJSON`_, supporting different +kinds of `GeoJSON geometry objects`_. .. toctree:: :maxdepth: 2 diff --git a/pyproject.toml b/pyproject.toml index cab44009..30c13633 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,6 +92,9 @@ dependencies = [ "verlib2==0.2", ] [project.optional-dependencies] +all = [ + "sqlalchemy-cratedb[vector]", +] develop = [ "black<25", "mypy<1.11", @@ -116,6 +119,9 @@ test = [ "pytest-cov<6", "pytest-mock<4", ] +vector = [ + "numpy", +] [project.urls] changelog = "https://github.com/crate-workbench/sqlalchemy-cratedb/blob/main/CHANGES.md" documentation = "https://github.com/crate-workbench/sqlalchemy-cratedb" diff --git a/src/sqlalchemy_cratedb/__init__.py b/src/sqlalchemy_cratedb/__init__.py index 2ef5915e..d65bb395 100644 --- a/src/sqlalchemy_cratedb/__init__.py +++ b/src/sqlalchemy_cratedb/__init__.py @@ -27,6 +27,7 @@ from .type.array import ObjectArray from .type.geo import Geopoint, Geoshape from .type.object import ObjectType +from .type.vector import FloatVector if SA_VERSION < SA_1_4: import textwrap @@ -51,6 +52,7 @@ __all__ = [ dialect, + FloatVector, Geopoint, Geoshape, ObjectArray, diff --git a/src/sqlalchemy_cratedb/compiler.py b/src/sqlalchemy_cratedb/compiler.py index 07106b87..d3a66188 100644 --- a/src/sqlalchemy_cratedb/compiler.py +++ b/src/sqlalchemy_cratedb/compiler.py @@ -238,6 +238,12 @@ def visit_ARRAY(self, type_, **kw): def visit_OBJECT(self, type_, **kw): return "OBJECT" + def visit_FLOAT_VECTOR(self, type_, **kw): + dimensions = type_.dimensions + if dimensions is None: + raise ValueError("FloatVector must be initialized with dimension size") + return f"FLOAT_VECTOR({dimensions})" + class CrateCompiler(compiler.SQLCompiler): diff --git a/src/sqlalchemy_cratedb/dialect.py b/src/sqlalchemy_cratedb/dialect.py index 216fb110..53fae734 100644 --- a/src/sqlalchemy_cratedb/dialect.py +++ b/src/sqlalchemy_cratedb/dialect.py @@ -33,7 +33,7 @@ ) from crate.client.exceptions import TimezoneUnawareException from .sa_version import SA_VERSION, SA_1_4, SA_2_0 -from .type import ObjectArray, ObjectType +from .type import FloatVector, ObjectArray, ObjectType TYPES_MAP = { "boolean": sqltypes.Boolean, @@ -51,7 +51,8 @@ "float": sqltypes.Float, "real": sqltypes.Float, "string": sqltypes.String, - "text": sqltypes.String + "text": sqltypes.String, + "float_vector": FloatVector, } try: # SQLAlchemy >= 1.1 diff --git a/src/sqlalchemy_cratedb/type/__init__.py b/src/sqlalchemy_cratedb/type/__init__.py index 8e78f7da..5bd871dc 100644 --- a/src/sqlalchemy_cratedb/type/__init__.py +++ b/src/sqlalchemy_cratedb/type/__init__.py @@ -1,3 +1,4 @@ from .array import ObjectArray from .geo import Geopoint, Geoshape from .object import ObjectType +from .vector import FloatVector diff --git a/src/sqlalchemy_cratedb/type/vector.py b/src/sqlalchemy_cratedb/type/vector.py new file mode 100644 index 00000000..01f55513 --- /dev/null +++ b/src/sqlalchemy_cratedb/type/vector.py @@ -0,0 +1,173 @@ +""" +## About +SQLAlchemy data type implementation for CrateDB's `FLOAT_VECTOR` type. + +## References +- https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector +- https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match + +## Details +The implementation is based on SQLAlchemy's `TypeDecorator`, and also +offers compiler support. + +## Notes +CrateDB currently only supports the similarity function `VectorSimilarityFunction.EUCLIDEAN`. +-- https://github.com/crate/crate/blob/5.5.1/server/src/main/java/io/crate/types/FloatVectorType.java#L55 + +On the other hand, pgvector use a comparator to apply different similarity +functions as operators, see `pgvector.sqlalchemy.Vector.comparator_factory`. + +<->: l2/euclidean_distance +<#>: max_inner_product +<=>: cosine_distance + +## Backlog +- The type implementation might want to be accompanied by corresponding support + for the `KNN_MATCH` function, similar to what the dialect already offers for + fulltext search through its `Match` predicate. + +## Origin +This module is based on the corresponding pgvector implementation +by Andrew Kane. Thank you. + +The MIT License (MIT) +Copyright (c) 2021-2023 Andrew Kane +https://github.com/pgvector/pgvector-python +""" +import typing as t + +if t.TYPE_CHECKING: + import numpy.typing as npt # pragma: no cover + +import sqlalchemy as sa + + +__all__ = [ + "from_db", + "to_db", + "FloatVector", +] + + +def from_db(value: t.Iterable) -> t.Optional[npt.ArrayLike]: + import numpy as np + + # from `pgvector.utils` + # could be ndarray if already cast by lower-level driver + if value is None or isinstance(value, np.ndarray): + return value + + return np.array(value, dtype=np.float32) + + +def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]: + import numpy as np + + # from `pgvector.utils` + if value is None: + return value + + if isinstance(value, np.ndarray): + if value.ndim != 1: + raise ValueError("expected ndim to be 1") + + if not np.issubdtype(value.dtype, np.integer) and not np.issubdtype(value.dtype, np.floating): + raise ValueError("dtype must be numeric") + + value = value.tolist() + + if dim is not None and len(value) != dim: + raise ValueError("expected %d dimensions, not %d" % (dim, len(value))) + + return value + + +class FloatVector(sa.TypeDecorator[t.Sequence[float]]): + + """ + SQLAlchemy `FloatVector` data type for CrateDB. + """ + + cache_ok = False + + __visit_name__ = "FLOAT_VECTOR" + + _is_array = True + + zero_indexes = False + + impl = sa.ARRAY + + def __init__(self, dimensions: int = None): + super().__init__(sa.FLOAT, dimensions=dimensions) + + def as_generic(self, allow_nulltype=False): + return sa.ARRAY(item_type=sa.FLOAT) + + @property + def python_type(self): + return list + + def bind_processor(self, dialect: sa.Dialect) -> t.Callable: + def process(value: t.Iterable) -> t.Optional[t.List]: + return to_db(value, self.dimensions) + + return process + + def result_processor(self, dialect: sa.Dialect, coltype: t.Any) -> t.Callable: + def process(value: t.Any) -> t.Optional[npt.ArrayLike]: + return from_db(value) + + return process + + +class KnnMatch(ColumnElement): + """ + Wrap CrateDB's `KNN_MATCH` function into an SQLAlchemy function. + + https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match + """ + inherit_cache = True + + def __init__(self, column, term, k=None): + super(KnnMatch, self).__init__() + self.column = column + self.term = term + self.k = k + + def compile_column(self, compiler): + return compiler.process(self.column) + + def compile_term(self, compiler): + return compiler.process(literal(self.term)) + + def compile_k(self, compiler): + return compiler.process(literal(self.k)) + + +def knn_match(column, term, k): + """ + Generate a match predicate for vector search. + + :param column: A reference to a column or an index, or a subcolumn, or a + dictionary of subcolumns with boost values. + + :param term: The term to match against. This is an array of floating point + values, which is compared to other vectors using a HNSW index. + + :param k: The `k` argument determines the number of nearest neighbours to + search in the index. + """ + return KnnMatch(column, term, k) + + +@compiles(KnnMatch) +def compile_knn_match(knn_match, compiler, **kwargs): + """ + Clause compiler for `knn_match`. + """ + return "knn_match(%s, %s, %s)" % ( + knn_match.compile_column(compiler), + knn_match.compile_term(compiler), + knn_match.compile_k(compiler), + ) From 9919fd33c0314c17bc5a263a058b4b0eab0c0dd5 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Thu, 21 Dec 2023 15:14:53 +0100 Subject: [PATCH 2/6] CI: Adjust configuration to install `all` extra The `all` extra bundles all optional dependencies, like, in this case, `numpy`, added to support the `FLOAT_VECTOR` data type. --- .github/workflows/nightly.yml | 5 +---- .github/workflows/tests.yml | 5 +---- bootstrap.sh | 2 +- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index f585f3ee..295f8f51 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -54,16 +54,13 @@ jobs: cache-dependency-path: pyproject.toml - - name: Set up project + - name: Update setuptools run: | # `setuptools 0.64.0` adds support for editable install hooks (PEP 660). # https://github.com/pypa/setuptools/blob/main/CHANGES.rst#v6400 pip install "setuptools>=64" --upgrade - # Install package in editable mode. - pip install --use-pep517 --prefer-binary --editable='.[develop,test]' - - name: Invoke tests run: | diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 19800659..04e306ea 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -64,16 +64,13 @@ jobs: cache-dependency-path: pyproject.toml - - name: Set up project + - name: Update setuptools run: | # `setuptools 0.64.0` adds support for editable install hooks (PEP 660). # https://github.com/pypa/setuptools/blob/main/CHANGES.rst#v6400 pip install "setuptools>=64" --upgrade - # Install package in editable mode. - pip install --use-pep517 --prefer-binary --editable='.[develop,test]' - - name: Invoke tests run: | diff --git a/bootstrap.sh b/bootstrap.sh index 380f70cf..a3e66b49 100644 --- a/bootstrap.sh +++ b/bootstrap.sh @@ -68,7 +68,7 @@ function setup_package() { fi # Install package in editable mode. - pip install ${PIP_OPTIONS} --editable='.[develop,test]' + pip install ${PIP_OPTIONS} --use-pep517 --prefer-binary --editable='.[all,develop,test]' # Install designated SQLAlchemy version. if [ -n "${SQLALCHEMY_VERSION}" ]; then From 2c2da6897291ec7700bd100414ac809900ff3cd9 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Thu, 21 Dec 2023 15:28:50 +0100 Subject: [PATCH 3/6] Vector: Fix type checking and compatibility with SQLAlchemy 1.x --- src/sqlalchemy_cratedb/type/vector.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/sqlalchemy_cratedb/type/vector.py b/src/sqlalchemy_cratedb/type/vector.py index 01f55513..07f61d42 100644 --- a/src/sqlalchemy_cratedb/type/vector.py +++ b/src/sqlalchemy_cratedb/type/vector.py @@ -25,6 +25,8 @@ - The type implementation might want to be accompanied by corresponding support for the `KNN_MATCH` function, similar to what the dialect already offers for fulltext search through its `Match` predicate. +- After dropping support for SQLAlchemy 1.3, use + `class FloatVector(sa.TypeDecorator[t.Sequence[float]]):` ## Origin This module is based on the corresponding pgvector implementation @@ -49,7 +51,7 @@ ] -def from_db(value: t.Iterable) -> t.Optional[npt.ArrayLike]: +def from_db(value: t.Iterable) -> t.Optional["npt.ArrayLike"]: import numpy as np # from `pgvector.utils` @@ -82,8 +84,7 @@ def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]: return value -class FloatVector(sa.TypeDecorator[t.Sequence[float]]): - +class FloatVector(sa.TypeDecorator): """ SQLAlchemy `FloatVector` data type for CrateDB. """ @@ -108,14 +109,14 @@ def as_generic(self, allow_nulltype=False): def python_type(self): return list - def bind_processor(self, dialect: sa.Dialect) -> t.Callable: + def bind_processor(self, dialect: sa.engine.Dialect) -> t.Callable: def process(value: t.Iterable) -> t.Optional[t.List]: return to_db(value, self.dimensions) return process - def result_processor(self, dialect: sa.Dialect, coltype: t.Any) -> t.Callable: - def process(value: t.Any) -> t.Optional[npt.ArrayLike]: + def result_processor(self, dialect: sa.engine.Dialect, coltype: t.Any) -> t.Callable: + def process(value: t.Any) -> t.Optional["npt.ArrayLike"]: return from_db(value) return process From e5986f6446031d0ab9fabad727ac236b769cd6bd Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Fri, 7 Jun 2024 15:12:16 +0200 Subject: [PATCH 4/6] Vector: Add software tests --- DEVELOP.md | 1 + docs/working-with-types.rst | 28 ++++ src/sqlalchemy_cratedb/type/vector.py | 8 +- tests/integration.py | 9 +- tests/vector_test.py | 203 ++++++++++++++++++++++++++ 5 files changed, 244 insertions(+), 5 deletions(-) create mode 100644 tests/vector_test.py diff --git a/DEVELOP.md b/DEVELOP.md index 041c40fa..89c577f8 100644 --- a/DEVELOP.md +++ b/DEVELOP.md @@ -16,6 +16,7 @@ further commands. Verify code by running all linters and software tests: + export CRATEDB_VERSION=latest docker compose -f tests/docker-compose.yml up poe check diff --git a/docs/working-with-types.rst b/docs/working-with-types.rst index d5ccd3ab..b6283e37 100644 --- a/docs/working-with-types.rst +++ b/docs/working-with-types.rst @@ -9,6 +9,7 @@ from the CrateDB SQLAlchemy dialect. Currently, these are: - Container types ``ObjectType`` and ``ObjectArray``. - Geospatial types ``Geopoint`` and ``Geoshape``. +- Vector data type ``FloatVector``. .. rubric:: Table of Contents @@ -257,6 +258,33 @@ objects: [('Tokyo', (139.75999999791384, 35.67999996710569), {"coordinates": [[[139.806, 35.515], [139.919, 35.703], [139.768, 35.817], [139.575, 35.76], [139.584, 35.619], [139.806, 35.515]]], "type": "Polygon"})] +Vector type +=========== + +CrateDB's vector data type, :ref:`crate-reference:type-float_vector`, +allows to store dense vectors of float values of fixed length. + + >>> from sqlalchemy_cratedb.type.vector import FloatVector + + >>> class SearchIndex(Base): + ... __tablename__ = 'search' + ... name = sa.Column(sa.String, primary_key=True) + ... embedding = sa.Column(FloatVector(3)) + +Create an entity and store it into the database. ``float_vector`` values +can be defined by using arrays of floating point numbers. + + >>> foo_item = SearchIndex(name="foo", embedding=[42.42, 43.43, 44.44]) + >>> session.add(foo_item) + >>> session.commit() + >>> _ = connection.execute(sa.text("REFRESH TABLE search")) + +When reading it back, the ``FLOAT_VECTOR`` value will be returned as a NumPy array. + + >>> query = session.query(SearchIndex.name, SearchIndex.embedding) + >>> query.all() + [('foo', array([42.42, 43.43, 44.44], dtype=float32))] + .. hidden: Disconnect from database >>> session.close() diff --git a/src/sqlalchemy_cratedb/type/vector.py b/src/sqlalchemy_cratedb/type/vector.py index 07f61d42..01f62b2b 100644 --- a/src/sqlalchemy_cratedb/type/vector.py +++ b/src/sqlalchemy_cratedb/type/vector.py @@ -3,8 +3,8 @@ SQLAlchemy data type implementation for CrateDB's `FLOAT_VECTOR` type. ## References -- https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector -- https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match +- https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector +- https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match ## Details The implementation is based on SQLAlchemy's `TypeDecorator`, and also @@ -14,8 +14,8 @@ CrateDB currently only supports the similarity function `VectorSimilarityFunction.EUCLIDEAN`. -- https://github.com/crate/crate/blob/5.5.1/server/src/main/java/io/crate/types/FloatVectorType.java#L55 -On the other hand, pgvector use a comparator to apply different similarity -functions as operators, see `pgvector.sqlalchemy.Vector.comparator_factory`. +pgvector use a comparator to apply different similarity functions as operators, +see `pgvector.sqlalchemy.Vector.comparator_factory`. <->: l2/euclidean_distance <#>: max_inner_product diff --git a/tests/integration.py b/tests/integration.py index 968099f6..5e262fc8 100644 --- a/tests/integration.py +++ b/tests/integration.py @@ -105,7 +105,13 @@ def provision_database(): name STRING PRIMARY KEY, coordinate GEO_POINT, area GEO_SHAPE - )""" + )""", + """ + CREATE TABLE search ( + name STRING PRIMARY KEY, + text STRING, + embedding FLOAT_VECTOR(3) + )""", ] _execute_statements(ddl_statements) @@ -120,6 +126,7 @@ def drop_tables(): "DROP TABLE IF EXISTS cities", "DROP TABLE IF EXISTS locations", "DROP BLOB TABLE IF EXISTS myfiles", + "DROP TABLE IF EXISTS search", 'DROP TABLE IF EXISTS "test-testdrive"', "DROP TABLE IF EXISTS todos", 'DROP TABLE IF EXISTS "user"', diff --git a/tests/vector_test.py b/tests/vector_test.py new file mode 100644 index 00000000..38a52929 --- /dev/null +++ b/tests/vector_test.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8; -*- +# +# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor +# license agreements. See the NOTICE file distributed with this work for +# additional information regarding copyright ownership. Crate licenses +# this file to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may +# obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# However, if you have executed another commercial license agreement +# with Crate these terms will supersede the license and you may use the +# software solely pursuant to the terms of the relevant commercial agreement. + +from __future__ import absolute_import + +import re +import sys +from unittest import TestCase +from unittest.mock import MagicMock, patch + +import pytest +import sqlalchemy as sa +from sqlalchemy.orm import Session +from sqlalchemy.sql import select + +from sqlalchemy_cratedb import SA_VERSION, SA_1_4 +from sqlalchemy_cratedb.type import FloatVector + +from crate.client.cursor import Cursor + +from sqlalchemy_cratedb.type.vector import from_db, to_db + +fake_cursor = MagicMock(name="fake_cursor") +FakeCursor = MagicMock(name="FakeCursor", spec=Cursor) +FakeCursor.return_value = fake_cursor + + +if SA_VERSION < SA_1_4: + pytest.skip(reason="The FloatVector type is not supported on SQLAlchemy 1.3 and earlier", allow_module_level=True) + + +@patch("crate.client.connection.Cursor", FakeCursor) +class SqlAlchemyVectorTypeTest(TestCase): + """ + Verify compilation of SQL statements where the schema includes the `FloatVector` type. + """ + def setUp(self): + self.engine = sa.create_engine("crate://") + metadata = sa.MetaData() + self.table = sa.Table( + "testdrive", + metadata, + sa.Column("name", sa.String), + sa.Column("data", FloatVector(3)), + ) + self.session = Session(bind=self.engine) + + def assertSQL(self, expected_str, actual_expr): + self.assertEqual(expected_str, str(actual_expr).replace('\n', '')) + + def test_create_invoke(self): + self.table.create(self.engine) + fake_cursor.execute.assert_called_with( + ( + "\nCREATE TABLE testdrive (\n\t" + "name STRING, \n\t" + "data FLOAT_VECTOR(3)\n)\n\n" + ), + (), + ) + + def test_insert_invoke(self): + stmt = self.table.insert().values( + name="foo", data=[42.42, 43.43, 44.44] + ) + with self.engine.connect() as conn: + conn.execute(stmt) + fake_cursor.execute.assert_called_with( + ("INSERT INTO testdrive (name, data) VALUES (?, ?)"), + ("foo", [42.42, 43.43, 44.44]), + ) + + def test_select_invoke(self): + stmt = select(self.table.c.data) + with self.engine.connect() as conn: + conn.execute(stmt) + fake_cursor.execute.assert_called_with( + ("SELECT testdrive.data \nFROM testdrive"), + (), + ) + + def test_sql_select(self): + self.assertSQL( + "SELECT testdrive.data FROM testdrive", select(self.table.c.data) + ) + + +def test_from_db_success(): + """ + Verify succeeding uses of `sqlalchemy_cratedb.type.vector.from_db`. + """ + np = pytest.importorskip("numpy") + assert from_db(None) is None + assert np.array_equal(from_db(False), np.array(0., dtype=np.float32)) + assert np.array_equal(from_db(True), np.array(1., dtype=np.float32)) + assert np.array_equal(from_db(42), np.array(42, dtype=np.float32)) + assert np.array_equal(from_db(42.42), np.array(42.42, dtype=np.float32)) + assert np.array_equal(from_db([42.42, 43.43]), np.array([42.42, 43.43], dtype=np.float32)) + assert np.array_equal(from_db("42.42"), np.array(42.42, dtype=np.float32)) + assert np.array_equal(from_db(["42.42", "43.43"]), np.array([42.42, 43.43], dtype=np.float32)) + + +def test_from_db_failure(): + """ + Verify failing uses of `sqlalchemy_cratedb.type.vector.from_db`. + """ + pytest.importorskip("numpy") + + with pytest.raises(ValueError) as ex: + from_db("foo") + assert ex.match("could not convert string to float: 'foo'") + + with pytest.raises(ValueError) as ex: + from_db(["foo"]) + assert ex.match("could not convert string to float: 'foo'") + + with pytest.raises(TypeError) as ex: + from_db({"foo": "bar"}) + if sys.version_info < (3, 10): + assert ex.match(re.escape("float() argument must be a string or a number, not 'dict'")) + else: + assert ex.match(re.escape("float() argument must be a string or a real number, not 'dict'")) + + +def test_to_db_success(): + """ + Verify succeeding uses of `sqlalchemy_cratedb.type.vector.to_db`. + """ + np = pytest.importorskip("numpy") + assert to_db(None) is None + assert to_db(False) is False + assert to_db(True) is True + assert to_db(42) == 42 + assert to_db(42.42) == 42.42 + assert to_db([42.42, 43.43]) == [42.42, 43.43] + assert to_db(np.array([42.42, 43.43])) == [42.42, 43.43] + assert to_db("42.42") == "42.42" + assert to_db("foo") == "foo" + assert to_db(["foo"]) == ["foo"] + assert to_db({"foo": "bar"}) == {"foo": "bar"} + assert isinstance(to_db(object()), object) + + +def test_to_db_failure(): + """ + Verify failing uses of `sqlalchemy_cratedb.type.vector.to_db`. + """ + np = pytest.importorskip("numpy") + + with pytest.raises(ValueError) as ex: + to_db(np.array(["42.42", "43.43"])) + assert ex.match("dtype must be numeric") + + with pytest.raises(ValueError) as ex: + to_db(np.array([42.42, 43.43]), dim=33) + assert ex.match("expected 33 dimensions, not 2") + + with pytest.raises(ValueError) as ex: + to_db(np.array([[42.42, 43.43]])) + assert ex.match("expected ndim to be 1") + + +def test_float_vector_no_dimension_size(): + """ + Verify a FloatVector can not be initialized without a dimension size. + """ + engine = sa.create_engine("crate://") + metadata = sa.MetaData() + table = sa.Table( + "foo", + metadata, + sa.Column("data", FloatVector), + ) + with pytest.raises(ValueError) as ex: + table.create(engine) + ex.match("FloatVector must be initialized with dimension size") + + +def test_float_vector_as_generic(): + """ + Verify the `as_generic` and `python_type` method/property on the FloatVector type object. + """ + fv = FloatVector(3) + assert isinstance(fv.as_generic(), sa.ARRAY) + assert fv.python_type is list From de2aee42a46bf5786fe8085e2f4f4200fc3170a8 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 8 Jun 2024 14:14:58 +0200 Subject: [PATCH 5/6] Vector: Add wrapper for HNSW matching function `KNN_MATCH` --- docs/working-with-types.rst | 10 ++++- src/sqlalchemy_cratedb/__init__.py | 3 +- src/sqlalchemy_cratedb/type/__init__.py | 2 +- src/sqlalchemy_cratedb/type/vector.py | 17 ++++---- tests/vector_test.py | 52 +++++++++++++++++++++++-- 5 files changed, 69 insertions(+), 15 deletions(-) diff --git a/docs/working-with-types.rst b/docs/working-with-types.rst index b6283e37..9fc65168 100644 --- a/docs/working-with-types.rst +++ b/docs/working-with-types.rst @@ -264,7 +264,7 @@ Vector type CrateDB's vector data type, :ref:`crate-reference:type-float_vector`, allows to store dense vectors of float values of fixed length. - >>> from sqlalchemy_cratedb.type.vector import FloatVector + >>> from sqlalchemy_cratedb import FloatVector, knn_match >>> class SearchIndex(Base): ... __tablename__ = 'search' @@ -285,6 +285,14 @@ When reading it back, the ``FLOAT_VECTOR`` value will be returned as a NumPy arr >>> query.all() [('foo', array([42.42, 43.43, 44.44], dtype=float32))] +In order to apply search, i.e. to match embeddings against each other, use the +:ref:`crate-reference:scalar_knn_match` function like this. + + >>> query = session.query(SearchIndex.name) \ + ... .filter(knn_match(SearchIndex.embedding, [42.42, 43.43, 41.41], 3)) + >>> query.all() + [('foo',)] + .. hidden: Disconnect from database >>> session.close() diff --git a/src/sqlalchemy_cratedb/__init__.py b/src/sqlalchemy_cratedb/__init__.py index d65bb395..297e8fd9 100644 --- a/src/sqlalchemy_cratedb/__init__.py +++ b/src/sqlalchemy_cratedb/__init__.py @@ -27,7 +27,7 @@ from .type.array import ObjectArray from .type.geo import Geopoint, Geoshape from .type.object import ObjectType -from .type.vector import FloatVector +from .type.vector import FloatVector, knn_match if SA_VERSION < SA_1_4: import textwrap @@ -58,4 +58,5 @@ ObjectArray, ObjectType, match, + knn_match, ] diff --git a/src/sqlalchemy_cratedb/type/__init__.py b/src/sqlalchemy_cratedb/type/__init__.py index 5bd871dc..36ba8173 100644 --- a/src/sqlalchemy_cratedb/type/__init__.py +++ b/src/sqlalchemy_cratedb/type/__init__.py @@ -1,4 +1,4 @@ from .array import ObjectArray from .geo import Geopoint, Geoshape from .object import ObjectType -from .vector import FloatVector +from .vector import FloatVector, knn_match diff --git a/src/sqlalchemy_cratedb/type/vector.py b/src/sqlalchemy_cratedb/type/vector.py index 01f62b2b..56e1f50c 100644 --- a/src/sqlalchemy_cratedb/type/vector.py +++ b/src/sqlalchemy_cratedb/type/vector.py @@ -22,9 +22,6 @@ <=>: cosine_distance ## Backlog -- The type implementation might want to be accompanied by corresponding support - for the `KNN_MATCH` function, similar to what the dialect already offers for - fulltext search through its `Match` predicate. - After dropping support for SQLAlchemy 1.3, use `class FloatVector(sa.TypeDecorator[t.Sequence[float]]):` @@ -42,10 +39,13 @@ import numpy.typing as npt # pragma: no cover import sqlalchemy as sa +from sqlalchemy.sql.expression import ColumnElement, literal +from sqlalchemy.ext.compiler import compiles __all__ = [ "from_db", + "knn_match", "to_db", "FloatVector", ] @@ -131,7 +131,7 @@ class KnnMatch(ColumnElement): inherit_cache = True def __init__(self, column, term, k=None): - super(KnnMatch, self).__init__() + super().__init__() self.column = column self.term = term self.k = k @@ -150,11 +150,10 @@ def knn_match(column, term, k): """ Generate a match predicate for vector search. - :param column: A reference to a column or an index, or a subcolumn, or a - dictionary of subcolumns with boost values. + :param column: A reference to a column or an index. :param term: The term to match against. This is an array of floating point - values, which is compared to other vectors using a HNSW index. + values, which is compared to other vectors using a HNSW index search. :param k: The `k` argument determines the number of nearest neighbours to search in the index. @@ -165,9 +164,9 @@ def knn_match(column, term, k): @compiles(KnnMatch) def compile_knn_match(knn_match, compiler, **kwargs): """ - Clause compiler for `knn_match`. + Clause compiler for `KNN_MATCH`. """ - return "knn_match(%s, %s, %s)" % ( + return "KNN_MATCH(%s, %s, %s)" % ( knn_match.compile_column(compiler), knn_match.compile_term(compiler), knn_match.compile_k(compiler), diff --git a/tests/vector_test.py b/tests/vector_test.py index 38a52929..245ed308 100644 --- a/tests/vector_test.py +++ b/tests/vector_test.py @@ -28,14 +28,18 @@ import pytest import sqlalchemy as sa -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.sql import select -from sqlalchemy_cratedb import SA_VERSION, SA_1_4 -from sqlalchemy_cratedb.type import FloatVector +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base from crate.client.cursor import Cursor +from sqlalchemy_cratedb import SA_VERSION, SA_1_4 +from sqlalchemy_cratedb import FloatVector, knn_match from sqlalchemy_cratedb.type.vector import from_db, to_db fake_cursor = MagicMock(name="fake_cursor") @@ -102,6 +106,14 @@ def test_sql_select(self): "SELECT testdrive.data FROM testdrive", select(self.table.c.data) ) + def test_sql_match(self): + query = self.session.query(self.table.c.name) \ + .filter(knn_match(self.table.c.data, [42.42, 43.43], 3)) + self.assertSQL( + "SELECT testdrive.name AS testdrive_name FROM testdrive WHERE KNN_MATCH(testdrive.data, ?, ?)", + query + ) + def test_from_db_success(): """ @@ -201,3 +213,37 @@ def test_float_vector_as_generic(): fv = FloatVector(3) assert isinstance(fv.as_generic(), sa.ARRAY) assert fv.python_type is list + + +def test_float_vector_integration(): + """ + An integration test for `FLOAT_VECTOR` and `KNN_SEARCH`. + """ + np = pytest.importorskip("numpy") + + engine = sa.create_engine(f"crate://") + session = sessionmaker(bind=engine)() + Base = declarative_base() + + # Define DDL. + class SearchIndex(Base): + __tablename__ = 'search' + name = sa.Column(sa.String, primary_key=True) + embedding = sa.Column(FloatVector(3)) + + Base.metadata.drop_all(engine, checkfirst=True) + Base.metadata.create_all(engine, checkfirst=True) + + # Insert record. + foo_item = SearchIndex(name="foo", embedding=[42.42, 43.43, 44.44]) + session.add(foo_item) + session.commit() + session.execute(sa.text("REFRESH TABLE search")) + + # Query record. + query = session.query(SearchIndex.embedding) \ + .filter(knn_match(SearchIndex.embedding, [42.42, 43.43, 41.41], 3)) + result = query.first() + + # Compare outcome. + assert np.array_equal(result.embedding, np.array([42.42, 43.43, 44.44], np.float32)) From 36d8b3d0cd57ea5f415c126be6e00f3ea12c6084 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 8 Jun 2024 21:17:50 +0200 Subject: [PATCH 6/6] Chore: Fix tests on Python 3.9 and SQLAlchemy 1.3, by skipping them We don't know which circumstances cause this problem. SQLAlchemy 1.3 is EOL anyway, so we don't care too much. sqlalchemy.exc.InvalidRequestError: When initializing mapper mapped class RootStore->root, expression 'ItemStore' failed to locate a name ('ItemStore'). If this is a class name, consider adding this relationship() to the .RootStore'> class after both dependent classes have been defined. --- tests/compiler_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/compiler_test.py b/tests/compiler_test.py index 8c1eccfb..6773b75e 100644 --- a/tests/compiler_test.py +++ b/tests/compiler_test.py @@ -18,6 +18,7 @@ # However, if you have executed another commercial license agreement # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. +import sys import warnings from textwrap import dedent from unittest import mock, skipIf, TestCase @@ -288,6 +289,8 @@ def test_for_update(self): FakeCursor = MagicMock(name='FakeCursor', spec=Cursor) +@skipIf(SA_VERSION < SA_1_4 and (3, 9) <= sys.version_info < (3, 10), + "SQLAlchemy 1.3 has problems with these test cases on Python 3.9") class CompilerTestCase(TestCase): """ A base class for providing mocking infrastructure to validate the DDL compiler.