Skip to content

Commit

Permalink
Added test for arrays with SQLAlchemy - #96
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Oct 13, 2024
1 parent 43b809f commit 0852a1f
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions tests/test_sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector, avg, sum
import pytest
from sqlalchemy import create_engine, insert, inspect, select, text, MetaData, Table, Column, Index, Integer
from sqlalchemy import create_engine, insert, inspect, select, text, MetaData, Table, Column, Index, Integer, ARRAY
from sqlalchemy.exc import StatementError
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import declarative_base, Session
Expand Down Expand Up @@ -31,6 +31,7 @@ class Item(Base):
half_embedding = mapped_column(HALFVEC(3))
binary_embedding = mapped_column(BIT(3))
sparse_embedding = mapped_column(SPARSEVEC(3))
embeddings = mapped_column(ARRAY(VECTOR(3)))


Base.metadata.drop_all(engine)
Expand Down Expand Up @@ -70,7 +71,8 @@ def test_core(self):
Column('embedding', VECTOR(3)),
Column('half_embedding', HALFVEC(3)),
Column('binary_embedding', BIT(3)),
Column('sparse_embedding', SPARSEVEC(3))
Column('sparse_embedding', SPARSEVEC(3)),
Column('embeddings', ARRAY(VECTOR(3)))
)

metadata.drop_all(engine)
Expand Down Expand Up @@ -422,6 +424,14 @@ def test_automap(self):
item = session.query(AutoItem).first()
assert item.embedding.tolist() == [1, 2, 3]

def test_vector_array(self):
session = Session(engine)
session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
session.commit()

# this fails if the driver does not cast arrays
# item = session.get(Item, 1)

@pytest.mark.asyncio
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
async def test_async(self):
Expand Down

0 comments on commit 0852a1f

Please sign in to comment.