Types: Add support for CrateDB's FLOAT_VECTOR data type and KNN_MATCH function #9

Merged
merged 6 commits into from
Jun 13, 2024
5 changes: 1 addition & 4 deletions .github/workflows/nightly.yml
@@ -54,16 +54,13 @@ jobs:
cache-dependency-path:
pyproject.toml

- name: Set up project
- name: Update setuptools
run: |

# `setuptools 64.0.0` adds support for editable install hooks (PEP 660).
# https://github.com/pypa/setuptools/blob/main/CHANGES.rst#v6400
pip install "setuptools>=64" --upgrade

# Install package in editable mode.
pip install --use-pep517 --prefer-binary --editable='.[develop,test]'

- name: Invoke tests
run: |

5 changes: 1 addition & 4 deletions .github/workflows/tests.yml
@@ -64,16 +64,13 @@ jobs:
cache-dependency-path:
pyproject.toml

- name: Set up project
- name: Update setuptools
run: |

# `setuptools 64.0.0` adds support for editable install hooks (PEP 660).
# https://github.com/pypa/setuptools/blob/main/CHANGES.rst#v6400
pip install "setuptools>=64" --upgrade

# Install package in editable mode.
pip install --use-pep517 --prefer-binary --editable='.[develop,test]'

- name: Invoke tests
run: |

6 changes: 6 additions & 0 deletions CHANGES.md
@@ -2,6 +2,12 @@


## Unreleased
- Added support for CrateDB's [FLOAT_VECTOR] data type and its accompanying
[KNN_MATCH] function, for HNSW matches. For SQLAlchemy column definitions,
you can use it like `FloatVector(dimensions=1536)`.

[FLOAT_VECTOR]: https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector
[KNN_MATCH]: https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match

## 2024/06/11 0.36.1
- Dependencies: Use `crate==1.0.0dev0`
1 change: 1 addition & 0 deletions DEVELOP.md
@@ -16,6 +16,7 @@ further commands.

Verify code by running all linters and software tests:

export CRATEDB_VERSION=latest
docker compose -f tests/docker-compose.yml up
poe check

2 changes: 1 addition & 1 deletion bootstrap.sh
@@ -68,7 +68,7 @@ function setup_package() {
fi

# Install package in editable mode.
pip install ${PIP_OPTIONS} --editable='.[develop,test]'
pip install ${PIP_OPTIONS} --use-pep517 --prefer-binary --editable='.[all,develop,test]'

# Install designated SQLAlchemy version.
if [ -n "${SQLALCHEMY_VERSION}" ]; then
2 changes: 2 additions & 0 deletions docs/data-types.rst
@@ -45,6 +45,7 @@ CrateDB SQLAlchemy
`integer`__ `Integer`__
`long`__ `NUMERIC`__
`float`__ `Float`__
`float_vector`__ ``FloatVector``
`double`__ `DECIMAL`__
`timestamp`__ `TIMESTAMP`__
`string`__ `String`__
@@ -68,6 +69,7 @@ __ https://crate.io/docs/crate/reference/en/latest/general/ddl/data-types.html#n
__ http://docs.sqlalchemy.org/en/latest/core/type_basics.html#sqlalchemy.types.NUMERIC
__ https://crate.io/docs/crate/reference/en/latest/general/ddl/data-types.html#numeric-data
__ http://docs.sqlalchemy.org/en/latest/core/type_basics.html#sqlalchemy.types.Float
__ https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector
__ https://crate.io/docs/crate/reference/en/latest/general/ddl/data-types.html#numeric-data
__ http://docs.sqlalchemy.org/en/latest/core/type_basics.html#sqlalchemy.types.DECIMAL
__ https://crate.io/docs/crate/reference/en/latest/general/ddl/data-types.html#dates-and-times
5 changes: 3 additions & 2 deletions docs/index.rst
@@ -50,8 +50,9 @@ Install package from PyPI.
pip install sqlalchemy-cratedb

The CrateDB dialect for `SQLAlchemy`_ offers convenient ORM access and supports
CrateDB's ``OBJECT``, ``ARRAY``, and geospatial data types using `GeoJSON`_,
supporting different kinds of `GeoJSON geometry objects`_.
CrateDB's container data types ``OBJECT`` and ``ARRAY``, its vector data type
``FLOAT_VECTOR``, and geospatial data types using `GeoJSON`_, supporting different
kinds of `GeoJSON geometry objects`_.

.. toctree::
:maxdepth: 2
36 changes: 36 additions & 0 deletions docs/working-with-types.rst
@@ -9,6 +9,7 @@ from the CrateDB SQLAlchemy dialect. Currently, these are:

- Container types ``ObjectType`` and ``ObjectArray``.
- Geospatial types ``Geopoint`` and ``Geoshape``.
- Vector data type ``FloatVector``.


.. rubric:: Table of Contents
@@ -257,6 +258,41 @@ objects:
[('Tokyo', (139.75999999791384, 35.67999996710569), {"coordinates": [[[139.806, 35.515], [139.919, 35.703], [139.768, 35.817], [139.575, 35.76], [139.584, 35.619], [139.806, 35.515]]], "type": "Polygon"})]


Vector type
===========

CrateDB's vector data type, :ref:`crate-reference:type-float_vector`,
stores dense vectors of float values with a fixed length.

>>> from sqlalchemy_cratedb import FloatVector, knn_match

>>> class SearchIndex(Base):
... __tablename__ = 'search'
... name = sa.Column(sa.String, primary_key=True)
... embedding = sa.Column(FloatVector(3))

Create an entity and store it into the database. ``float_vector`` values
can be defined by using arrays of floating point numbers.

>>> foo_item = SearchIndex(name="foo", embedding=[42.42, 43.43, 44.44])
>>> session.add(foo_item)
>>> session.commit()
>>> _ = connection.execute(sa.text("REFRESH TABLE search"))

When reading it back, the ``FLOAT_VECTOR`` value will be returned as a NumPy array.

>>> query = session.query(SearchIndex.name, SearchIndex.embedding)
>>> query.all()
[('foo', array([42.42, 43.43, 44.44], dtype=float32))]

To run a similarity search, i.e. to match embeddings against each other, use the
:ref:`crate-reference:scalar_knn_match` function like this.

>>> query = session.query(SearchIndex.name) \
... .filter(knn_match(SearchIndex.embedding, [42.42, 43.43, 41.41], 3))
>>> query.all()
[('foo',)]
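
As a rough intuition for what a k-nearest-neighbour match selects, the following plain-Python sketch ranks stored vectors by Euclidean distance to a query vector, which the module notes name as the only similarity function CrateDB currently supports. It is a brute-force illustration only: CrateDB answers ``KNN_MATCH`` through an HNSW index search, and the helper name ``nearest_names`` as well as the sample rows ``bar`` and ``baz`` are made up for this example.

```python
import math


def nearest_names(rows, query, k):
    """Return the names of the k rows closest to `query` (brute force)."""
    # Sort all rows by Euclidean distance between their vector and the query.
    ranked = sorted(rows, key=lambda row: math.dist(row[1], query))
    return [name for name, _ in ranked[:k]]


# Hypothetical sample data; only "foo" appears in the documentation above.
rows = [
    ("foo", [42.42, 43.43, 44.44]),
    ("bar", [1.0, 2.0, 3.0]),
    ("baz", [40.0, 40.0, 40.0]),
]

print(nearest_names(rows, [42.42, 43.43, 41.41], k=2))
```

An index-based search trades this exhaustive scan for speed, so its results may be approximate rather than exact.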

.. hidden: Disconnect from database

>>> session.close()
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -92,6 +92,9 @@ dependencies = [
"verlib2==0.2",
]
[project.optional-dependencies]
all = [
"sqlalchemy-cratedb[vector]",
]
develop = [
"black<25",
"mypy<1.11",
@@ -116,6 +119,9 @@ test = [
"pytest-cov<6",
"pytest-mock<4",
]
vector = [
"numpy",
]
[project.urls]
changelog = "https://github.com/crate-workbench/sqlalchemy-cratedb/blob/main/CHANGES.md"
documentation = "https://github.com/crate-workbench/sqlalchemy-cratedb"
3 changes: 3 additions & 0 deletions src/sqlalchemy_cratedb/__init__.py
@@ -27,6 +27,7 @@
from .type.array import ObjectArray
from .type.geo import Geopoint, Geoshape
from .type.object import ObjectType
from .type.vector import FloatVector, knn_match

if SA_VERSION < SA_1_4:
import textwrap
@@ -51,9 +52,11 @@

__all__ = [
dialect,
FloatVector,
Geopoint,
Geoshape,
ObjectArray,
ObjectType,
match,
knn_match,
]
6 changes: 6 additions & 0 deletions src/sqlalchemy_cratedb/compiler.py
@@ -238,6 +238,12 @@ def visit_ARRAY(self, type_, **kw):
def visit_OBJECT(self, type_, **kw):
return "OBJECT"

def visit_FLOAT_VECTOR(self, type_, **kw):
dimensions = type_.dimensions
if dimensions is None:
raise ValueError("FloatVector must be initialized with dimension size")
return f"FLOAT_VECTOR({dimensions})"
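
The rule this visitor applies can be sketched outside of SQLAlchemy as a plain function: a defined dimension count is rendered into the DDL type string, and a missing one is rejected. The function name `render_float_vector` is hypothetical and only mirrors the method body above.

```python
from typing import Optional


def render_float_vector(dimensions: Optional[int]) -> str:
    """Mirror of the `visit_FLOAT_VECTOR` logic, detached from the compiler."""
    if dimensions is None:
        # Same guard as in the type compiler: the size is mandatory in DDL.
        raise ValueError("FloatVector must be initialized with dimension size")
    return f"FLOAT_VECTOR({dimensions})"


print(render_float_vector(1536))
```

Raising at compile time means a column declared as `FloatVector()` without a size fails when the DDL is generated, before any statement reaches the database.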


class CrateCompiler(compiler.SQLCompiler):

5 changes: 3 additions & 2 deletions src/sqlalchemy_cratedb/dialect.py
@@ -33,7 +33,7 @@
)
from crate.client.exceptions import TimezoneUnawareException
from .sa_version import SA_VERSION, SA_1_4, SA_2_0
from .type import ObjectArray, ObjectType
from .type import FloatVector, ObjectArray, ObjectType

TYPES_MAP = {
"boolean": sqltypes.Boolean,
@@ -51,7 +51,8 @@
"float": sqltypes.Float,
"real": sqltypes.Float,
"string": sqltypes.String,
"text": sqltypes.String
"text": sqltypes.String,
"float_vector": FloatVector,
}
try:
# SQLAlchemy >= 1.1
1 change: 1 addition & 0 deletions src/sqlalchemy_cratedb/type/__init__.py
@@ -1,3 +1,4 @@
from .array import ObjectArray
from .geo import Geopoint, Geoshape
from .object import ObjectType
from .vector import FloatVector, knn_match
173 changes: 173 additions & 0 deletions src/sqlalchemy_cratedb/type/vector.py
@@ -0,0 +1,173 @@
"""
## About
SQLAlchemy data type implementation for CrateDB's `FLOAT_VECTOR` type.

## References
- https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector
- https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match

## Details
The implementation is based on SQLAlchemy's `TypeDecorator`, and also
offers compiler support.

## Notes
CrateDB currently only supports the similarity function `VectorSimilarityFunction.EUCLIDEAN`.
-- https://github.com/crate/crate/blob/5.5.1/server/src/main/java/io/crate/types/FloatVectorType.java#L55

pgvector uses a comparator to apply different similarity functions as operators,
see `pgvector.sqlalchemy.Vector.comparator_factory`.

<->: l2/euclidean_distance
<#>: max_inner_product
<=>: cosine_distance

## Backlog
- After dropping support for SQLAlchemy 1.3, use
`class FloatVector(sa.TypeDecorator[t.Sequence[float]]):`

## Origin
This module is based on the corresponding pgvector implementation
by Andrew Kane. Thank you.

The MIT License (MIT)
Copyright (c) 2021-2023 Andrew Kane
https://github.com/pgvector/pgvector-python
"""
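
For orientation, the three measures behind the pgvector operators listed in the module notes can be written out in plain Python. This is a hedged sketch with made-up function names, not code from pgvector or CrateDB. It assumes, as pgvector documents it, that `<#>` yields the negated inner product so that smaller values rank first.

```python
import math


def l2_distance(a, b):
    # Euclidean distance, the measure behind `<->`.
    return math.dist(a, b)


def negative_inner_product(a, b):
    # Negated dot product, so that smaller values mean "more similar".
    return -sum(x * y for x, y in zip(a, b))


def cosine_distance(a, b):
    # 1 - cosine similarity, the measure behind `<=>`.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)
```

Of these, only the Euclidean measure corresponds to what the notes say CrateDB supports at the time of this change.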
import typing as t

if t.TYPE_CHECKING:
import numpy.typing as npt # pragma: no cover

import sqlalchemy as sa
from sqlalchemy.sql.expression import ColumnElement, literal
from sqlalchemy.ext.compiler import compiles


__all__ = [
"from_db",
"knn_match",
"to_db",
"FloatVector",
]


def from_db(value: t.Iterable) -> t.Optional["npt.ArrayLike"]:
import numpy as np

# from `pgvector.utils`
# could be ndarray if already cast by lower-level driver
if value is None or isinstance(value, np.ndarray):
return value

return np.array(value, dtype=np.float32)
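
`from_db` imports NumPy inside the function body rather than at module level, which presumably keeps the package importable when the optional `vector` extra from `pyproject.toml` is not installed. A hedged sketch of that pattern with a friendlier error message follows; the helper `require_numpy` and its message text are assumptions and not part of this change.

```python
import importlib
import importlib.util


def require_numpy():
    """Import and return NumPy lazily, or fail with an install hint."""
    # `find_spec` checks availability without importing the module.
    if importlib.util.find_spec("numpy") is None:
        raise ImportError(
            "NumPy is required for vector support; "
            "install it with: pip install 'sqlalchemy-cratedb[vector]'"
        )
    return importlib.import_module("numpy")
```

A caller would use `np = require_numpy()` at the top of a function such as `from_db`, so the cost and the failure mode only apply when vector columns are actually used.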


def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]:
import numpy as np

# from `pgvector.utils`
if value is None:
return value

if isinstance(value, np.ndarray):
if value.ndim != 1:
raise ValueError("expected ndim to be 1")

if not np.issubdtype(value.dtype, np.integer) and not np.issubdtype(value.dtype, np.floating):
raise ValueError("dtype must be numeric")

value = value.tolist()

if dim is not None and len(value) != dim:
raise ValueError("expected %d dimensions, not %d" % (dim, len(value)))

return value
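
The dimension check in `to_db` turns a wrongly sized embedding into a `ValueError` on the client side when the bind value is processed. The sketch below reproduces only that check for plain Python sequences, without the NumPy branch; `check_dimensions` is a hypothetical name used for illustration.

```python
from typing import List, Optional, Sequence


def check_dimensions(
    value: Optional[Sequence[float]], dim: Optional[int] = None
) -> Optional[List[float]]:
    """Simplified, NumPy-free variant of the length check in `to_db`."""
    if value is None:
        return value
    # Reject vectors whose length differs from the declared column size.
    if dim is not None and len(value) != dim:
        raise ValueError("expected %d dimensions, not %d" % (dim, len(value)))
    return list(value)


print(check_dimensions([42.42, 43.43, 44.44], dim=3))
```

Passing a two-element list with `dim=3` raises `ValueError: expected 3 dimensions, not 2`.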


class FloatVector(sa.TypeDecorator):
"""
SQLAlchemy `FloatVector` data type for CrateDB.
"""

cache_ok = False

__visit_name__ = "FLOAT_VECTOR"

_is_array = True

zero_indexes = False

impl = sa.ARRAY

def __init__(self, dimensions: t.Optional[int] = None):
super().__init__(sa.FLOAT, dimensions=dimensions)

def as_generic(self, allow_nulltype=False):
return sa.ARRAY(item_type=sa.FLOAT)

@property
def python_type(self):
return list

def bind_processor(self, dialect: sa.engine.Dialect) -> t.Callable:
def process(value: t.Iterable) -> t.Optional[t.List]:
return to_db(value, self.dimensions)

return process

def result_processor(self, dialect: sa.engine.Dialect, coltype: t.Any) -> t.Callable:
def process(value: t.Any) -> t.Optional["npt.ArrayLike"]:
return from_db(value)

return process


class KnnMatch(ColumnElement):
"""
Wrap CrateDB's `KNN_MATCH` function into an SQLAlchemy function.

https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match
"""
inherit_cache = True

def __init__(self, column, term, k=None):
super().__init__()
self.column = column
self.term = term
self.k = k

def compile_column(self, compiler):
return compiler.process(self.column)

def compile_term(self, compiler):
return compiler.process(literal(self.term))

def compile_k(self, compiler):
return compiler.process(literal(self.k))


def knn_match(column, term, k):
"""
Generate a match predicate for vector search.

:param column: A reference to a column or an index.

:param term: The term to match against. This is an array of floating point
values, which is compared to other vectors using a HNSW index search.

:param k: The `k` argument determines the number of nearest neighbours to
search in the index.
"""
return KnnMatch(column, term, k)


@compiles(KnnMatch)
def compile_knn_match(knn_match, compiler, **kwargs):
"""
Clause compiler for `KNN_MATCH`.
"""
return "KNN_MATCH(%s, %s, %s)" % (
knn_match.compile_column(compiler),
knn_match.compile_term(compiler),
knn_match.compile_k(compiler),
)