Skip to content

Commit

Permalink
Vector: Add software tests
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Jun 7, 2024
1 parent e861e2f commit 38e8eee
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 1 deletion.
1 change: 1 addition & 0 deletions DEVELOP.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ further commands.

Verify code by running all linters and software tests:

export CRATEDB_VERSION=latest
docker compose -f tests/docker-compose.yml up
poe check

Expand Down
29 changes: 29 additions & 0 deletions docs/working-with-types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ from the CrateDB SQLAlchemy dialect. Currently, these are:

- Container types ``ObjectType`` and ``ObjectArray``.
- Geospatial types ``Geopoint`` and ``Geoshape``.
- Vector data type ``FloatVector``.


.. rubric:: Table of Contents
Expand All @@ -35,6 +36,7 @@ Import the relevant symbols:
>>> from uuid import uuid4
>>> from crate.client.sqlalchemy.types import ObjectType, ObjectArray
>>> from crate.client.sqlalchemy.types import Geopoint, Geoshape
>>> from sqlalchemy_cratedb.type import FloatVector

Establish a connection to the database, see also :ref:`sa:engines_toplevel`
and :ref:`connect`:
Expand Down Expand Up @@ -255,6 +257,33 @@ objects:
[('Tokyo', (139.75999999791384, 35.67999996710569), {"coordinates": [[[139.806, 35.515], [139.919, 35.703], [139.768, 35.817], [139.575, 35.76], [139.584, 35.619], [139.806, 35.515]]], "type": "Polygon"})]


Vector type
===========

CrateDB's vector data type, :ref:`crate-reference:type-float_vector`,
allows storing dense vectors of float values of fixed length.
``float_vector`` values are defined like float arrays.

>>> class SearchIndex(Base):
... __tablename__ = 'search'
... name = sa.Column(sa.String, primary_key=True)
... text = sa.Column(sa.Text)
... embedding = sa.Column(FloatVector(3))

Create an entity and store it in the database.

>>> foo_item = SearchIndex(name="foo", text="bar", embedding=[42.42, 43.43, 44.44])
>>> session.add(foo_item)
>>> session.commit()
>>> _ = connection.execute(text("REFRESH TABLE search"))

When reading it back, the ``FLOAT_VECTOR`` value will be returned as a NumPy array.

>>> query = session.query(SearchIndex.name, SearchIndex.text, SearchIndex.embedding)
>>> query.all()
[('foo', 'bar', array([42.42, 43.43, 44.44], dtype=float32))]


.. hidden: Disconnect from database
>>> session.close()
Expand Down
9 changes: 8 additions & 1 deletion tests/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,13 @@ def provision_database():
name STRING PRIMARY KEY,
coordinate GEO_POINT,
area GEO_SHAPE
)"""
)""",
"""
CREATE TABLE search (
name STRING PRIMARY KEY,
text STRING,
embedding FLOAT_VECTOR(3)
)""",
]
_execute_statements(ddl_statements)

Expand All @@ -120,6 +126,7 @@ def drop_tables():
"DROP TABLE IF EXISTS cities",
"DROP TABLE IF EXISTS locations",
"DROP BLOB TABLE IF EXISTS myfiles",
"DROP TABLE IF EXISTS search",
'DROP TABLE IF EXISTS "test-testdrive"',
"DROP TABLE IF EXISTS todos",
'DROP TABLE IF EXISTS "user"',
Expand Down
83 changes: 83 additions & 0 deletions tests/vector_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# -*- coding: utf-8; -*-
#
# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
# license agreements. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership. Crate licenses
# this file to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may
# obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# However, if you have executed another commercial license agreement
# with Crate these terms will supersede the license and you may use the
# software solely pursuant to the terms of the relevant commercial agreement.

from __future__ import absolute_import

from unittest import TestCase, skipIf
from unittest.mock import MagicMock, patch

import sqlalchemy as sa
from sqlalchemy.sql import select

from sqlalchemy_cratedb import SA_VERSION, SA_1_4
from sqlalchemy_cratedb.type import FloatVector

from crate.client.cursor import Cursor


# Stand-in for a cursor instance; tests inspect the calls recorded on it.
fake_cursor = MagicMock(name="fake_cursor")

# Stand-in for the `Cursor` class, constrained to its interface through `spec`.
# Calling it hands out the shared `fake_cursor` instance.
FakeCursor = MagicMock(name="FakeCursor", spec=Cursor, return_value=fake_cursor)


@skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 and earlier do not support the FloatVector type")
@patch("crate.client.connection.Cursor", FakeCursor)
class SqlAlchemyVectorTypeTest(TestCase):
    """
    Verify the SQL statements emitted for tables using the `FloatVector` type.

    `crate.client.connection.Cursor` is patched with `FakeCursor`, so the
    tests inspect the statements and parameters passed to `cursor.execute`.
    """

    def setUp(self):
        """Create an engine and a table definition with a `FloatVector(3)` column."""
        # `fake_cursor` is a module-level mock shared by all tests. Clear the
        # calls recorded by previously executed tests, so each assertion only
        # sees calls made by the test currently running.
        fake_cursor.reset_mock()

        self.engine = sa.create_engine("crate://")
        metadata = sa.MetaData()
        self.table = sa.Table(
            "testdrive",
            metadata,
            sa.Column("name", sa.String),
            sa.Column("data", FloatVector(3)),
        )

    def assertSQL(self, expected_str, selectable):
        """
        Compile `selectable` against the test engine and compare it to
        `expected_str`, ignoring newlines in the compiled statement.
        """
        actual_expr = selectable.compile(bind=self.engine)
        self.assertEqual(expected_str, str(actual_expr).replace("\n", ""))

    def test_create(self):
        """Creating the table emits a `FLOAT_VECTOR(3)` column definition."""
        self.table.create(self.engine)
        fake_cursor.execute.assert_called_with(
            (
                "\nCREATE TABLE testdrive (\n\t"
                "name STRING, \n\t"
                "data FLOAT_VECTOR(3)\n)\n\n"
            ),
            (),
        )

    def test_insert(self):
        """Inserting a record passes the vector as a plain list parameter."""
        stmt = self.table.insert().values(
            name="foo", data=[42.42, 43.43, 44.44]
        )
        with self.engine.connect() as conn:
            conn.execute(stmt)
            # `assert_called_with` only checks the most recent call, so
            # verify right after executing the statement.
            fake_cursor.execute.assert_called_with(
                ("INSERT INTO testdrive (name, data) VALUES (?, ?)"),
                ("foo", [42.42, 43.43, 44.44]),
            )

    def test_select(self):
        """Selecting the vector column compiles to a plain column reference."""
        self.assertSQL(
            "SELECT testdrive.data FROM testdrive", select(self.table.c.data)
        )

0 comments on commit 38e8eee

Please sign in to comment.