Skip to content

Commit

Permalink
Vector: Add software tests
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Jun 8, 2024
1 parent f4bf5f6 commit 467d21e
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 1 deletion.
1 change: 1 addition & 0 deletions DEVELOP.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ further commands.

Verify code by running all linters and software tests:

export CRATEDB_VERSION=latest
docker compose -f tests/docker-compose.yml up
poe check

Expand Down
29 changes: 29 additions & 0 deletions docs/working-with-types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ from the CrateDB SQLAlchemy dialect. Currently, these are:

- Container types ``ObjectType`` and ``ObjectArray``.
- Geospatial types ``Geopoint`` and ``Geoshape``.
- Vector data type ``FloatVector``.


.. rubric:: Table of Contents
Expand Down Expand Up @@ -255,6 +256,34 @@ objects:
[('Tokyo', (139.75999999791384, 35.67999996710569), {"coordinates": [[[139.806, 35.515], [139.919, 35.703], [139.768, 35.817], [139.575, 35.76], [139.584, 35.619], [139.806, 35.515]]], "type": "Polygon"})]


Vector type
===========

CrateDB's vector data type, :ref:`crate-reference:type-float_vector`,
allows you to store dense vectors of float values of fixed length.
``float_vector`` values are defined like float arrays.

>>> from sqlalchemy_cratedb.type.vector import FloatVector

>>> class SearchIndex(Base):
... __tablename__ = 'search'
... name = sa.Column(sa.String, primary_key=True)
... text = sa.Column(sa.Text)
... embedding = sa.Column(FloatVector(3))

Create an entity and store it into the database.

>>> foo_item = SearchIndex(name="foo", text="bar", embedding=[42.42, 43.43, 44.44])
>>> session.add(foo_item)
>>> session.commit()
>>> _ = connection.execute(text("REFRESH TABLE search"))

When reading it back, the ``FLOAT_VECTOR`` value will be returned as a NumPy array.

>>> query = session.query(SearchIndex.name, SearchIndex.text, SearchIndex.embedding)
>>> query.all()
[('foo', 'bar', array([42.42, 43.43, 44.44], dtype=float32))]

.. hidden: Disconnect from database
>>> session.close()
Expand Down
9 changes: 8 additions & 1 deletion tests/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,13 @@ def provision_database():
name STRING PRIMARY KEY,
coordinate GEO_POINT,
area GEO_SHAPE
)"""
)""",
"""
CREATE TABLE search (
name STRING PRIMARY KEY,
text STRING,
embedding FLOAT_VECTOR(3)
)""",
]
_execute_statements(ddl_statements)

Expand All @@ -120,6 +126,7 @@ def drop_tables():
"DROP TABLE IF EXISTS cities",
"DROP TABLE IF EXISTS locations",
"DROP BLOB TABLE IF EXISTS myfiles",
"DROP TABLE IF EXISTS search",
'DROP TABLE IF EXISTS "test-testdrive"',
"DROP TABLE IF EXISTS todos",
'DROP TABLE IF EXISTS "user"',
Expand Down
175 changes: 175 additions & 0 deletions tests/vector_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# -*- coding: utf-8; -*-
#
# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
# license agreements. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership. Crate licenses
# this file to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may
# obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# However, if you have executed another commercial license agreement
# with Crate these terms will supersede the license and you may use the
# software solely pursuant to the terms of the relevant commercial agreement.

from __future__ import absolute_import

import re
from unittest import TestCase, skipIf
from unittest.mock import MagicMock, patch

import pytest
import sqlalchemy as sa
from sqlalchemy.orm import Session
from sqlalchemy.sql import select

from sqlalchemy_cratedb import SA_VERSION, SA_1_4
from sqlalchemy_cratedb.type import FloatVector

from crate.client.cursor import Cursor

from sqlalchemy_cratedb.type.vector import from_db, to_db

# Stand-in cursor instance on which the tests inspect the recorded `execute` calls.
fake_cursor = MagicMock(name="fake_cursor")
# Stand-in for the `Cursor` class; calling it hands out the shared `fake_cursor`.
FakeCursor = MagicMock(name="FakeCursor", spec=Cursor, return_value=fake_cursor)


@skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 and earlier do not support the FloatVector type")
@patch("crate.client.connection.Cursor", FakeCursor)
class SqlAlchemyVectorTypeTest(TestCase):
    """
    Exercise SQL compilation and statement dispatch for a table which
    includes a `FloatVector` column, using a mocked database cursor.
    """

    def setUp(self):
        """Create the engine, the `testdrive` table definition, and a session."""
        self.engine = sa.create_engine("crate://")
        columns = (
            sa.Column("name", sa.String),
            sa.Column("data", FloatVector(3)),
        )
        self.table = sa.Table("testdrive", sa.MetaData(), *columns)
        self.session = Session(bind=self.engine)

    def assertSQL(self, expected_str, actual_expr):
        """Compare `expected_str` with the rendered expression, newlines removed."""
        rendered = "".join(str(actual_expr).split('\n'))
        self.assertEqual(expected_str, rendered)

    def test_create_invoke(self):
        """Creating the table sends the expected DDL to the cursor."""
        expected_ddl = (
            "\nCREATE TABLE testdrive (\n\t"
            "name STRING, \n\t"
            "data FLOAT_VECTOR(3)\n)\n\n"
        )
        self.table.create(self.engine)
        fake_cursor.execute.assert_called_with(expected_ddl, ())

    def test_insert_invoke(self):
        """Inserting a record sends a parameterized INSERT to the cursor."""
        insert_stmt = self.table.insert().values(
            name="foo",
            data=[42.42, 43.43, 44.44],
        )
        with self.engine.connect() as connection:
            connection.execute(insert_stmt)
        expected_sql = "INSERT INTO testdrive (name, data) VALUES (?, ?)"
        expected_params = ("foo", [42.42, 43.43, 44.44])
        fake_cursor.execute.assert_called_with(expected_sql, expected_params)

    def test_select_invoke(self):
        """Selecting the vector column sends the expected SELECT to the cursor."""
        select_stmt = select(self.table.c.data)
        with self.engine.connect() as connection:
            connection.execute(select_stmt)
        expected_sql = "SELECT testdrive.data \nFROM testdrive"
        fake_cursor.execute.assert_called_with(expected_sql, ())

    def test_sql_select(self):
        """The SELECT statement renders as expected without executing it."""
        select_stmt = select(self.table.c.data)
        self.assertSQL("SELECT testdrive.data FROM testdrive", select_stmt)


def test_from_db_success():
    """
    Check the inputs which `sqlalchemy_cratedb.type.vector.from_db` accepts.
    """
    np = pytest.importorskip("numpy")
    float32 = np.float32

    # `None` comes back as `None` rather than as an array.
    assert from_db(None) is None

    # Pairs of (input value, expected float32 array).
    conversions = [
        (False, np.array(0., dtype=float32)),
        (True, np.array(1., dtype=float32)),
        (42, np.array(42, dtype=float32)),
        (42.42, np.array(42.42, dtype=float32)),
        ([42.42, 43.43], np.array([42.42, 43.43], dtype=float32)),
        ("42.42", np.array(42.42, dtype=float32)),
        (["42.42", "43.43"], np.array([42.42, 43.43], dtype=float32)),
    ]
    for given, expected in conversions:
        assert np.array_equal(from_db(given), expected)


def test_from_db_failure():
    """
    Verify failing uses of `sqlalchemy_cratedb.type.vector.from_db`.

    Values which are not convertible to float raise `ValueError` for
    strings, and `TypeError` for unsupported object types.
    """
    pytest.importorskip("numpy")

    with pytest.raises(ValueError) as ex:
        from_db("foo")
    assert ex.match("could not convert string to float: 'foo'")

    with pytest.raises(ValueError) as ex:
        from_db(["foo"])
    assert ex.match("could not convert string to float: 'foo'")

    # The wording of this `TypeError` depends on the interpreter version:
    # CPython 3.10+ says "a real number", earlier releases say "a number".
    # Accept both, so the test does not fail on older Python versions.
    with pytest.raises(TypeError) as ex:
        from_db({"foo": "bar"})
    assert ex.match(r"float\(\) argument must be a string or a (real )?number, not 'dict'")

    with pytest.raises(TypeError) as ex:
        from_db(object())
    assert ex.match(r"float\(\) argument must be a string or a (real )?number, not 'object'")


def test_to_db_success():
    """
    Verify succeeding uses of `sqlalchemy_cratedb.type.vector.to_db`.

    NumPy arrays are expected to be converted to plain lists, while the
    other values exercised here are expected to come back unchanged.
    """
    np = pytest.importorskip("numpy")
    assert to_db(None) is None
    assert to_db(False) is False
    assert to_db(True) is True
    assert to_db(42) == 42
    assert to_db(42.42) == 42.42
    assert to_db([42.42, 43.43]) == [42.42, 43.43]
    assert to_db(np.array([42.42, 43.43])) == [42.42, 43.43]
    assert to_db("42.42") == "42.42"
    assert to_db("foo") == "foo"
    assert to_db(["foo"]) == ["foo"]
    assert to_db({"foo": "bar"}) == {"foo": "bar"}

    # `isinstance(x, object)` holds for every value, so it verifies nothing.
    # Check instead that the very same object is handed back.
    # NOTE(review): assumes `to_db` passes arbitrary objects through
    # unchanged, like the other non-array values above -- confirm.
    sentinel = object()
    assert to_db(sentinel) is sentinel


def test_to_db_failure():
    """
    Check the NumPy arrays which `sqlalchemy_cratedb.type.vector.to_db` rejects.
    """
    np = pytest.importorskip("numpy")

    # An array of strings.
    string_array = np.array(["42.42", "43.43"])
    with pytest.raises(ValueError, match="dtype must be numeric"):
        to_db(string_array)

    # A two-element array where 33 dimensions are requested.
    short_array = np.array([42.42, 43.43])
    with pytest.raises(ValueError, match="expected 33 dimensions, not 2"):
        to_db(short_array, dim=33)

    # A two-dimensional array.
    nested_array = np.array([[42.42, 43.43]])
    with pytest.raises(ValueError, match="expected ndim to be 1"):
        to_db(nested_array)

0 comments on commit 467d21e

Please sign in to comment.