From 467d21ea41cfd7f09440f4d9f17610c64a11c35e Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Fri, 7 Jun 2024 15:12:16 +0200 Subject: [PATCH] Vector: Add software tests --- DEVELOP.md | 1 + docs/working-with-types.rst | 29 ++++++ tests/integration.py | 9 +- tests/vector_test.py | 175 ++++++++++++++++++++++++++++++++++++ 4 files changed, 213 insertions(+), 1 deletion(-) create mode 100644 tests/vector_test.py diff --git a/DEVELOP.md b/DEVELOP.md index bf5a29d0..62d3a221 100644 --- a/DEVELOP.md +++ b/DEVELOP.md @@ -16,6 +16,7 @@ further commands. Verify code by running all linters and software tests: + export CRATEDB_VERSION=latest docker compose -f tests/docker-compose.yml up poe check diff --git a/docs/working-with-types.rst b/docs/working-with-types.rst index 2505f831..3509bb89 100644 --- a/docs/working-with-types.rst +++ b/docs/working-with-types.rst @@ -9,6 +9,7 @@ from the CrateDB SQLAlchemy dialect. Currently, these are: - Container types ``ObjectType`` and ``ObjectArray``. - Geospatial types ``Geopoint`` and ``Geoshape``. +- Vector data type ``FloatVector``. .. rubric:: Table of Contents @@ -255,6 +256,34 @@ objects: [('Tokyo', (139.75999999791384, 35.67999996710569), {"coordinates": [[[139.806, 35.515], [139.919, 35.703], [139.768, 35.817], [139.575, 35.76], [139.584, 35.619], [139.806, 35.515]]], "type": "Polygon"})] +Vector type +=========== + +CrateDB's vector data type, :ref:`crate-reference:type-float_vector`, +allows storing dense vectors of float values of fixed length. +``float_vector`` values are defined like float arrays. + + >>> from sqlalchemy_cratedb.type.vector import FloatVector + + >>> class SearchIndex(Base): + ... __tablename__ = 'search' + ... name = sa.Column(sa.String, primary_key=True) + ... text = sa.Column(sa.Text) + ... embedding = sa.Column(FloatVector(3)) + +Create an entity and store it into the database. 
+ + >>> foo_item = SearchIndex(name="foo", text="bar", embedding=[42.42, 43.43, 44.44]) + >>> session.add(foo_item) + >>> session.commit() + >>> _ = connection.execute(text("REFRESH TABLE search")) + +When reading it back, the ``FLOAT_VECTOR`` value will be returned as a NumPy array. + + >>> query = session.query(SearchIndex.name, SearchIndex.text, SearchIndex.embedding) + >>> query.all() + [('foo', 'bar', array([42.42, 43.43, 44.44], dtype=float32))] + .. hidden: Disconnect from database >>> session.close() diff --git a/tests/integration.py b/tests/integration.py index 968099f6..5e262fc8 100644 --- a/tests/integration.py +++ b/tests/integration.py @@ -105,7 +105,13 @@ def provision_database(): name STRING PRIMARY KEY, coordinate GEO_POINT, area GEO_SHAPE - )""" + )""", + """ + CREATE TABLE search ( + name STRING PRIMARY KEY, + text STRING, + embedding FLOAT_VECTOR(3) + )""", ] _execute_statements(ddl_statements) @@ -120,6 +126,7 @@ def drop_tables(): "DROP TABLE IF EXISTS cities", "DROP TABLE IF EXISTS locations", "DROP BLOB TABLE IF EXISTS myfiles", + "DROP TABLE IF EXISTS search", 'DROP TABLE IF EXISTS "test-testdrive"', "DROP TABLE IF EXISTS todos", 'DROP TABLE IF EXISTS "user"', diff --git a/tests/vector_test.py b/tests/vector_test.py new file mode 100644 index 00000000..756c93cf --- /dev/null +++ b/tests/vector_test.py @@ -0,0 +1,175 @@ +# -*- coding: utf-8; -*- +# +# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor +# license agreements. See the NOTICE file distributed with this work for +# additional information regarding copyright ownership. Crate licenses +# this file to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
You may +# obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# However, if you have executed another commercial license agreement +# with Crate these terms will supersede the license and you may use the +# software solely pursuant to the terms of the relevant commercial agreement. + +from __future__ import absolute_import + +import re +from unittest import TestCase, skipIf +from unittest.mock import MagicMock, patch + +import pytest +import sqlalchemy as sa +from sqlalchemy.orm import Session +from sqlalchemy.sql import select + +from sqlalchemy_cratedb import SA_VERSION, SA_1_4 +from sqlalchemy_cratedb.type import FloatVector + +from crate.client.cursor import Cursor + +from sqlalchemy_cratedb.type.vector import from_db, to_db + +fake_cursor = MagicMock(name="fake_cursor") +FakeCursor = MagicMock(name="FakeCursor", spec=Cursor) +FakeCursor.return_value = fake_cursor + + +@skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 and earlier do not support the FloatVector type") +@patch("crate.client.connection.Cursor", FakeCursor) +class SqlAlchemyVectorTypeTest(TestCase): + """ + Verify compilation of SQL statements where the schema includes the `FloatVector` type. 
+ """ + def setUp(self): + self.engine = sa.create_engine("crate://") + metadata = sa.MetaData() + self.table = sa.Table( + "testdrive", + metadata, + sa.Column("name", sa.String), + sa.Column("data", FloatVector(3)), + ) + self.session = Session(bind=self.engine) + + def assertSQL(self, expected_str, actual_expr): + self.assertEqual(expected_str, str(actual_expr).replace('\n', '')) + + def test_create_invoke(self): + self.table.create(self.engine) + fake_cursor.execute.assert_called_with( + ( + "\nCREATE TABLE testdrive (\n\t" + "name STRING, \n\t" + "data FLOAT_VECTOR(3)\n)\n\n" + ), + (), + ) + + def test_insert_invoke(self): + stmt = self.table.insert().values( + name="foo", data=[42.42, 43.43, 44.44] + ) + with self.engine.connect() as conn: + conn.execute(stmt) + fake_cursor.execute.assert_called_with( + ("INSERT INTO testdrive (name, data) VALUES (?, ?)"), + ("foo", [42.42, 43.43, 44.44]), + ) + + def test_select_invoke(self): + stmt = select(self.table.c.data) + with self.engine.connect() as conn: + conn.execute(stmt) + fake_cursor.execute.assert_called_with( + ("SELECT testdrive.data \nFROM testdrive"), + (), + ) + + def test_sql_select(self): + self.assertSQL( + "SELECT testdrive.data FROM testdrive", select(self.table.c.data) + ) + + +def test_from_db_success(): + """ + Verify succeeding uses of `sqlalchemy_cratedb.type.vector.from_db`. 
+ """ + np = pytest.importorskip("numpy") + assert from_db(None) is None + assert np.array_equal(from_db(False), np.array(0., dtype=np.float32)) + assert np.array_equal(from_db(True), np.array(1., dtype=np.float32)) + assert np.array_equal(from_db(42), np.array(42, dtype=np.float32)) + assert np.array_equal(from_db(42.42), np.array(42.42, dtype=np.float32)) + assert np.array_equal(from_db([42.42, 43.43]), np.array([42.42, 43.43], dtype=np.float32)) + assert np.array_equal(from_db("42.42"), np.array(42.42, dtype=np.float32)) + assert np.array_equal(from_db(["42.42", "43.43"]), np.array([42.42, 43.43], dtype=np.float32)) + + +def test_from_db_failure(): + """ + Verify failing uses of `sqlalchemy_cratedb.type.vector.from_db`. + """ + pytest.importorskip("numpy") + + with pytest.raises(ValueError) as ex: + from_db("foo") + assert ex.match("could not convert string to float: 'foo'") + + with pytest.raises(ValueError) as ex: + from_db(["foo"]) + assert ex.match("could not convert string to float: 'foo'") + + with pytest.raises(TypeError) as ex: + from_db({"foo": "bar"}) + assert ex.match(re.escape("float() argument must be a string or a real number, not 'dict'")) + + with pytest.raises(TypeError) as ex: + from_db(object()) + assert ex.match(re.escape("float() argument must be a string or a real number, not 'object'")) + + +def test_to_db_success(): + """ + Verify succeeding uses of `sqlalchemy_cratedb.type.vector.to_db`. 
+ """ + np = pytest.importorskip("numpy") + assert to_db(None) is None + assert to_db(False) is False + assert to_db(True) is True + assert to_db(42) == 42 + assert to_db(42.42) == 42.42 + assert to_db([42.42, 43.43]) == [42.42, 43.43] + assert to_db(np.array([42.42, 43.43])) == [42.42, 43.43] + assert to_db("42.42") == "42.42" + assert to_db("foo") == "foo" + assert to_db(["foo"]) == ["foo"] + assert to_db({"foo": "bar"}) == {"foo": "bar"} + assert isinstance(to_db(object()), object) + + +def test_to_db_failure(): + """ + Verify failing uses of `sqlalchemy_cratedb.type.vector.to_db`. + """ + np = pytest.importorskip("numpy") + + with pytest.raises(ValueError) as ex: + to_db(np.array(["42.42", "43.43"])) + assert ex.match("dtype must be numeric") + + with pytest.raises(ValueError) as ex: + to_db(np.array([42.42, 43.43]), dim=33) + assert ex.match("expected 33 dimensions, not 2") + + with pytest.raises(ValueError) as ex: + to_db(np.array([[42.42, 43.43]])) + assert ex.match("expected ndim to be 1")