diff --git a/DEVELOP.md b/DEVELOP.md index bf5a29d..62d3a22 100644 --- a/DEVELOP.md +++ b/DEVELOP.md @@ -16,6 +16,7 @@ further commands. Verify code by running all linters and software tests: + export CRATEDB_VERSION=latest docker compose -f tests/docker-compose.yml up poe check diff --git a/docs/working-with-types.rst b/docs/working-with-types.rst index 2505f83..f56ea29 100644 --- a/docs/working-with-types.rst +++ b/docs/working-with-types.rst @@ -9,6 +9,7 @@ from the CrateDB SQLAlchemy dialect. Currently, these are: - Container types ``ObjectType`` and ``ObjectArray``. - Geospatial types ``Geopoint`` and ``Geoshape``. +- Vector data type ``FloatVector``. .. rubric:: Table of Contents @@ -35,6 +36,7 @@ Import the relevant symbols: >>> from uuid import uuid4 >>> from crate.client.sqlalchemy.types import ObjectType, ObjectArray >>> from crate.client.sqlalchemy.types import Geopoint, Geoshape + >>> from sqlalchemy_cratedb.type import FloatVector Establish a connection to the database, see also :ref:`sa:engines_toplevel` and :ref:`connect`: @@ -255,6 +257,33 @@ objects: [('Tokyo', (139.75999999791384, 35.67999996710569), {"coordinates": [[[139.806, 35.515], [139.919, 35.703], [139.768, 35.817], [139.575, 35.76], [139.584, 35.619], [139.806, 35.515]]], "type": "Polygon"})] +Vector type +=========== + +CrateDB's vector data type, :ref:`crate-reference:type-float_vector`, +allows storing dense vectors of float values of fixed length. +``float_vector`` values are defined like float arrays. + + >>> class SearchIndex(Base): + ... __tablename__ = 'search' + ... name = sa.Column(sa.String, primary_key=True) + ... text = sa.Column(sa.Text) + ... embedding = sa.Column(FloatVector(3)) + +Create an entity and store it into the database. 
+ + >>> foo_item = SearchIndex(name="foo", text="bar", embedding=[42.42, 43.43, 44.44]) + >>> session.add(foo_item) + >>> session.commit() + >>> _ = connection.execute(text("REFRESH TABLE search")) + +When reading it back, the ``FLOAT_VECTOR`` value will be returned as a NumPy array. + + >>> query = session.query(SearchIndex.name, SearchIndex.text, SearchIndex.embedding) + >>> query.all() + [('foo', 'bar', array([42.42, 43.43, 44.44], dtype=float32))] + + .. hidden: Disconnect from database >>> session.close() diff --git a/tests/integration.py b/tests/integration.py index 968099f..5e262fc 100644 --- a/tests/integration.py +++ b/tests/integration.py @@ -105,7 +105,13 @@ def provision_database(): name STRING PRIMARY KEY, coordinate GEO_POINT, area GEO_SHAPE - )""" + )""", + """ + CREATE TABLE search ( + name STRING PRIMARY KEY, + text STRING, + embedding FLOAT_VECTOR(3) + )""", ] _execute_statements(ddl_statements) @@ -120,6 +126,7 @@ def drop_tables(): "DROP TABLE IF EXISTS cities", "DROP TABLE IF EXISTS locations", "DROP BLOB TABLE IF EXISTS myfiles", + "DROP TABLE IF EXISTS search", 'DROP TABLE IF EXISTS "test-testdrive"', "DROP TABLE IF EXISTS todos", 'DROP TABLE IF EXISTS "user"', diff --git a/tests/vector_test.py b/tests/vector_test.py new file mode 100644 index 0000000..e756d8c --- /dev/null +++ b/tests/vector_test.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8; -*- +# +# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor +# license agreements. See the NOTICE file distributed with this work for +# additional information regarding copyright ownership. Crate licenses +# this file to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
You may +# obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# However, if you have executed another commercial license agreement +# with Crate these terms will supersede the license and you may use the +# software solely pursuant to the terms of the relevant commercial agreement. + +from __future__ import absolute_import + +from unittest import TestCase, skipIf +from unittest.mock import MagicMock, patch + +import sqlalchemy as sa +from sqlalchemy.sql import select + +from sqlalchemy_cratedb import SA_VERSION, SA_1_4 +from sqlalchemy_cratedb.type import FloatVector + +from crate.client.cursor import Cursor + + +fake_cursor = MagicMock(name="fake_cursor") +FakeCursor = MagicMock(name="FakeCursor", spec=Cursor) +FakeCursor.return_value = fake_cursor + + +@skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 and earlier do not support the FloatVector type") +@patch("crate.client.connection.Cursor", FakeCursor) +class SqlAlchemyVectorTypeTest(TestCase): + def setUp(self): + self.engine = sa.create_engine("crate://") + metadata = sa.MetaData() + self.table = sa.Table( + "testdrive", + metadata, + sa.Column("name", sa.String), + sa.Column("data", FloatVector(3)), + ) + + def assertSQL(self, expected_str, selectable): + actual_expr = selectable.compile(bind=self.engine) + self.assertEqual(expected_str, str(actual_expr).replace("\n", "")) + + def test_create(self): + self.table.create(self.engine) + fake_cursor.execute.assert_called_with( + ( + "\nCREATE TABLE testdrive (\n\t" + "name STRING, \n\t" + "data FLOAT_VECTOR(3)\n)\n\n" + ), + (), + ) + + def test_insert(self): + stmt = self.table.insert().values( + name="foo", 
data=[42.42, 43.43, 44.44] + ) + with self.engine.connect() as conn: + conn.execute(stmt) + fake_cursor.execute.assert_called_with( + ("INSERT INTO testdrive (name, data) VALUES (?, ?)"), + ("foo", [42.42, 43.43, 44.44]), + ) + + def test_select(self): + self.assertSQL( + "SELECT testdrive.data FROM testdrive", select(self.table.c.data) + )