diff --git a/examples/sparse_search.py b/examples/sparse_search.py index 7e786e3..6ce33e8 100644 --- a/examples/sparse_search.py +++ b/examples/sparse_search.py @@ -45,10 +45,10 @@ def fetch_embeddings(input): ] embeddings = fetch_embeddings(input) for content, embedding in zip(input, embeddings): - conn.execute('INSERT INTO documents (content, embedding) VALUES (%s, %s)', (content, SparseVector.from_dense(embedding))) + conn.execute('INSERT INTO documents (content, embedding) VALUES (%s, %s)', (content, SparseVector(embedding))) query = 'forest' query_embedding = fetch_embeddings([query])[0] -result = conn.execute('SELECT content FROM documents ORDER BY embedding <#> %s LIMIT 5', (SparseVector.from_dense(query_embedding),)).fetchall() +result = conn.execute('SELECT content FROM documents ORDER BY embedding <#> %s LIMIT 5', (SparseVector(query_embedding),)).fetchall() for row in result: print(row[0]) diff --git a/pgvector/utils/sparsevec.py b/pgvector/utils/sparsevec.py index 96bbcdf..e6d7f87 100644 --- a/pgvector/utils/sparsevec.py +++ b/pgvector/utils/sparsevec.py @@ -3,48 +3,54 @@ class SparseVector: - def __init__(self, dim, indices, values): - # TODO improve - self._dim = int(dim) - self._indices = [int(i) for i in indices] - self._values = [float(v) for v in values] + def __init__(self, value, dimensions=None): + if value.__class__.__module__ == 'scipy.sparse._arrays': + if dimensions is not None: + raise ValueError('dimensions not allowed') + + self._from_sparse(value) + elif isinstance(value, dict): + self._from_dict(value, dimensions) + else: + if dimensions is not None: + raise ValueError('dimensions not allowed') + + self._from_dense(value) def __repr__(self): - return f'SparseVector({self._dim}, {self._indices}, {self._values})' + return f'SparseVector({self.to_dict()}, {self.dim()})' + + def _from_dict(self, d, dim): + if dim is None: + raise ValueError('dimensions required') - @classmethod - def from_dict(cls, d, dim): elements = [(i, v) for i, v in d.items()] elements.sort() - indices = [int(v[0]) for v in elements] - values = [float(v[1]) for v in elements] - return cls(dim, indices, values) + self._dim = int(dim) + self._indices = [int(v[0]) for v in elements] + self._values = [float(v[1]) for v in elements] - @classmethod - def from_sparse(cls, value): + def _from_sparse(self, value): value = value.tocoo() if value.ndim == 1: - dim = value.shape[0] + self._dim = value.shape[0] elif value.ndim == 2 and value.shape[0] == 1: - dim = value.shape[1] + self._dim = value.shape[1] else: raise ValueError('expected ndim to be 1') if hasattr(value, 'coords'): # scipy 1.13+ - indices = value.coords[0].tolist() + self._indices = value.coords[0].tolist() else: - indices = value.col.tolist() - values = value.data.tolist() - return cls(dim, indices, values) + self._indices = value.col.tolist() + self._values = value.data.tolist() - @classmethod - def from_dense(cls, value): - dim = len(value) - indices = [i for i, v in enumerate(value) if v != 0] - values = [float(value[i]) for i in indices] - return cls(dim, indices, values) + def _from_dense(self, value): + self._dim = len(value) + self._indices = [i for i, v in enumerate(value) if v != 0] + self._values = [float(value[i]) for i in self._indices] def dim(self): return self._dim @@ -86,21 +92,30 @@ def from_text(cls, value): i, v = e.split(':', 2) indices.append(int(i) - 1) values.append(float(v)) - return cls(int(dim), indices, values) + return cls._from_parts(int(dim), indices, values) @classmethod def from_binary(cls, value): dim, nnz, unused = unpack_from('>iii', value) indices = unpack_from(f'>{nnz}i', value, 12) values = unpack_from(f'>{nnz}f', value, 12 + nnz * 4) - return cls(int(dim), indices, values) + return cls._from_parts(int(dim), indices, values) + + @classmethod + def _from_parts(cls, dim, indices, values): + vec = cls.__new__(cls) + vec._dim = dim + vec._indices = indices + vec._values = values + return vec @classmethod def _to_db(cls, value, dim=None): if value is None: return value - value = cls._to_db_value(value) + if not isinstance(value, cls): + value = cls(value) if dim is not None and value.dim() != dim: raise ValueError('expected %d dimensions, not %d' % (dim, value.dim())) @@ -112,19 +127,11 @@ def _to_db_binary(cls, value): if value is None: return value - value = cls._to_db_value(value) + if not isinstance(value, cls): + value = cls(value) return value.to_binary() - @classmethod - def _to_db_value(cls, value): - if isinstance(value, cls): - return value - elif isinstance(value, (list, np.ndarray)): - return cls.from_dense(value) - else: - raise ValueError('expected sparsevec') - @classmethod def _from_db(cls, value): if value is None or isinstance(value, cls): diff --git a/tests/test_asyncpg.py b/tests/test_asyncpg.py index 3bfc888..829883e 100644 --- a/tests/test_asyncpg.py +++ b/tests/test_asyncpg.py @@ -82,7 +82,7 @@ async def test_sparsevec(self): await register_vector(conn) - embedding = SparseVector.from_dense([1.5, 2, 3]) + embedding = SparseVector([1.5, 2, 3]) await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding) res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id") diff --git a/tests/test_django.py b/tests/test_django.py index 186cf3a..421966f 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -88,9 +88,9 @@ class Migration(migrations.Migration): def create_items(): - Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector.from_dense([1, 1, 1])).save() - Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector.from_dense([2, 2, 2])).save() - Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector.from_dense([1, 1, 2])).save() + Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector([1, 1, 1])).save() + Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector([2, 2, 2])).save() + Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector([1, 1, 2])).save() class VectorForm(ModelForm): @@ -208,34 +208,34 @@ def test_bit_jaccard_distance(self): # assert [v.distance for v in items] == [0, 1/3, 1] def test_sparsevec(self): - Item(id=1, sparse_embedding=SparseVector.from_dense([1, 2, 3])).save() + Item(id=1, sparse_embedding=SparseVector([1, 2, 3])).save() item = Item.objects.get(pk=1) assert item.sparse_embedding.to_list() == [1, 2, 3] def test_sparsevec_l2_distance(self): create_items() - distance = L2Distance('sparse_embedding', SparseVector.from_dense([1, 1, 1])) + distance = L2Distance('sparse_embedding', SparseVector([1, 1, 1])) items = Item.objects.annotate(distance=distance).order_by(distance) assert [v.id for v in items] == [1, 3, 2] assert [v.distance for v in items] == [0, 1, sqrt(3)] def test_sparsevec_max_inner_product(self): create_items() - distance = MaxInnerProduct('sparse_embedding', SparseVector.from_dense([1, 1, 1])) + distance = MaxInnerProduct('sparse_embedding', SparseVector([1, 1, 1])) items = Item.objects.annotate(distance=distance).order_by(distance) assert [v.id for v in items] == [2, 3, 1] assert [v.distance for v in items] == [-6, -4, -3] def test_sparsevec_cosine_distance(self): create_items() - distance = CosineDistance('sparse_embedding', SparseVector.from_dense([1, 1, 1])) + distance = CosineDistance('sparse_embedding', SparseVector([1, 1, 1])) items = Item.objects.annotate(distance=distance).order_by(distance) assert [v.id for v in items] == [1, 2, 3] assert [v.distance for v in items] == [0, 0, 0.05719095841793653] def test_sparsevec_l1_distance(self): create_items() - distance = L1Distance('sparse_embedding', SparseVector.from_dense([1, 1, 1])) + distance = L1Distance('sparse_embedding', SparseVector([1, 1, 1])) items = Item.objects.annotate(distance=distance).order_by(distance) assert [v.id for v in items] == [1, 3, 2] assert [v.distance for v in items] == [0, 1, 3] @@ -402,7 +402,7 @@ def test_sparesevec_form_save_missing(self): assert Item.objects.get(pk=1).sparse_embedding is None def test_clean(self): - item = Item(id=1, embedding=[1, 2, 3], half_embedding=[1, 2, 3], binary_embedding='101', sparse_embedding=SparseVector.from_dense([1, 2, 3])) + item = Item(id=1, embedding=[1, 2, 3], half_embedding=[1, 2, 3], binary_embedding='101', sparse_embedding=SparseVector([1, 2, 3])) item.full_clean() def test_get_or_create(self): diff --git a/tests/test_peewee.py b/tests/test_peewee.py index 1455303..0882890 100644 --- a/tests/test_peewee.py +++ b/tests/test_peewee.py @@ -30,9 +30,9 @@ class Meta: def create_items(): - Item.create(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector.from_dense([1, 1, 1])) - Item.create(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector.from_dense([2, 2, 2])) - Item.create(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector.from_dense([1, 1, 2])) + Item.create(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector([1, 1, 1])) + Item.create(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector([2, 2, 2])) + Item.create(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector([1, 1, 2])) class TestPeewee: @@ -132,7 +132,7 @@ def test_sparsevec(self): def test_sparsevec_l2_distance(self): create_items() - distance = Item.sparse_embedding.l2_distance(SparseVector.from_dense([1, 1, 1])) + distance = Item.sparse_embedding.l2_distance(SparseVector([1, 1, 1])) items = Item.select(Item.id, distance.alias('distance')).order_by(distance).limit(5) assert [v.id for v in items] == [1, 3, 2] assert [v.distance for v in items] == [0, 1, sqrt(3)] diff --git a/tests/test_psycopg.py b/tests/test_psycopg.py index 2de1ec7..79ac190 100644 --- a/tests/test_psycopg.py +++ b/tests/test_psycopg.py @@ -100,20 +100,20 @@ def test_bit_text_format(self): assert repr(Bit(res)) == 'Bit(010100001)' def test_sparsevec(self): - embedding = SparseVector.from_dense([1.5, 2, 3]) + embedding = SparseVector([1.5, 2, 3]) conn.execute('INSERT INTO psycopg_items (sparse_embedding) VALUES (%s)', (embedding,)) res = conn.execute('SELECT sparse_embedding FROM psycopg_items ORDER BY id').fetchone()[0] assert res.to_list() == [1.5, 2, 3] def test_sparsevec_binary_format(self): - embedding = SparseVector.from_dense([1.5, 0, 2, 0, 3, 0]) + embedding = SparseVector([1.5, 0, 2, 0, 3, 0]) res = conn.execute('SELECT %b::sparsevec', (embedding,), binary=True).fetchone()[0] assert res.to_list() == [1.5, 0, 2, 0, 3, 0] assert np.array_equal(res.to_numpy(), np.array([1.5, 0, 2, 0, 3, 0])) def test_sparsevec_text_format(self): - embedding = SparseVector.from_dense([1.5, 0, 2, 0, 3, 0]) + embedding = SparseVector([1.5, 0, 2, 0, 3, 0]) res = conn.execute('SELECT %t::sparsevec', (embedding,)).fetchone()[0] assert res.to_list() == [1.5, 0, 2, 0, 3, 0] assert np.array_equal(res.to_numpy(), np.array([1.5, 0, 2, 0, 3, 0])) @@ -122,20 +122,20 @@ def test_text_copy(self): embedding = np.array([1.5, 2, 3]) cur = conn.cursor() with cur.copy("COPY psycopg_items (embedding, half_embedding, binary_embedding, sparse_embedding) FROM STDIN") as copy: - copy.write_row([embedding, HalfVector(embedding), '101', SparseVector.from_dense(embedding)]) + copy.write_row([embedding, HalfVector(embedding), '101', SparseVector(embedding)]) def test_binary_copy(self): embedding = np.array([1.5, 2, 3]) cur = conn.cursor() with cur.copy("COPY psycopg_items (embedding, half_embedding, binary_embedding, sparse_embedding) FROM STDIN WITH (FORMAT BINARY)") as copy: - copy.write_row([embedding, HalfVector(embedding), Bit('101'), SparseVector.from_dense(embedding)]) + copy.write_row([embedding, HalfVector(embedding), Bit('101'), SparseVector(embedding)]) def test_binary_copy_set_types(self): embedding = np.array([1.5, 2, 3]) cur = conn.cursor() with cur.copy("COPY psycopg_items (id, embedding, half_embedding, binary_embedding, sparse_embedding) FROM STDIN WITH (FORMAT BINARY)") as copy: copy.set_types(['int8', 'vector', 'halfvec', 'bit', 'sparsevec']) - copy.write_row([1, embedding, HalfVector(embedding), Bit('101'), SparseVector.from_dense(embedding)]) + copy.write_row([1, embedding, HalfVector(embedding), Bit('101'), SparseVector(embedding)]) @pytest.mark.asyncio async def test_async(self): diff --git a/tests/test_psycopg2.py b/tests/test_psycopg2.py index f18405f..54da6a7 100644 --- a/tests/test_psycopg2.py +++ b/tests/test_psycopg2.py @@ -46,7 +46,7 @@ def test_bit(self): assert res[1][0] is None def test_sparsevec(self): - embedding = SparseVector.from_dense([1.5, 2, 3]) + embedding = SparseVector([1.5, 2, 3]) cur.execute('INSERT INTO psycopg2_items (sparse_embedding) VALUES (%s), (NULL)', (embedding,)) cur.execute('SELECT sparse_embedding FROM psycopg2_items ORDER BY id') diff --git a/tests/test_sparse_vector.py b/tests/test_sparse_vector.py index 1286dbb..ae38e2a 100644 --- a/tests/test_sparse_vector.py +++ b/tests/test_sparse_vector.py @@ -6,27 +6,42 @@ class TestSparseVector: def test_from_dense(self): - assert SparseVector.from_dense([1, 0, 2, 0, 3, 0]).to_list() == [1, 0, 2, 0, 3, 0] - assert SparseVector.from_dense([1, 0, 2, 0, 3, 0]).to_numpy().tolist() == [1, 0, 2, 0, 3, 0] - assert SparseVector.from_dense(np.array([1, 0, 2, 0, 3, 0])).to_list() == [1, 0, 2, 0, 3, 0] + assert SparseVector([1, 0, 2, 0, 3, 0]).to_list() == [1, 0, 2, 0, 3, 0] + assert SparseVector([1, 0, 2, 0, 3, 0]).to_numpy().tolist() == [1, 0, 2, 0, 3, 0] + assert SparseVector(np.array([1, 0, 2, 0, 3, 0])).to_list() == [1, 0, 2, 0, 3, 0] + + def test_from_dense_dimensions(self): + with pytest.raises(ValueError) as error: + SparseVector([1, 0, 2, 0, 3, 0], 6) + assert str(error.value) == 'dimensions not allowed' def test_from_dict(self): - assert SparseVector.from_dict({0: 1, 2: 2, 4: 3}, 6).to_list() == [1, 0, 2, 0, 3, 0] + assert SparseVector({0: 1, 2: 2, 4: 3}, 6).to_list() == [1, 0, 2, 0, 3, 0] + + def test_from_dict_no_dimensions(self): + with pytest.raises(ValueError) as error: + SparseVector({0: 1, 2: 2, 4: 3}) + assert str(error.value) == 'dimensions required' def test_from_sparse(self): arr = coo_array(np.array([1, 0, 2, 0, 3, 0])) - assert SparseVector.from_sparse(arr).to_list() == [1, 0, 2, 0, 3, 0] - assert SparseVector.from_sparse(arr.todok()).to_list() == [1, 0, 2, 0, 3, 0] + assert SparseVector(arr).to_list() == [1, 0, 2, 0, 3, 0] + assert SparseVector(arr.todok()).to_list() == [1, 0, 2, 0, 3, 0] + + def test_from_sparse_dimensions(self): + with pytest.raises(ValueError) as error: + SparseVector(coo_array(np.array([1, 0, 2, 0, 3, 0])), 6) + assert str(error.value) == 'dimensions not allowed' def test_repr(self): - assert repr(SparseVector.from_dense([1, 0, 2, 0, 3, 0])) == 'SparseVector(6, [0, 2, 4], [1.0, 2.0, 3.0])' - assert str(SparseVector.from_dense([1, 0, 2, 0, 3, 0])) == 'SparseVector(6, [0, 2, 4], [1.0, 2.0, 3.0])' + assert repr(SparseVector([1, 0, 2, 0, 3, 0])) == 'SparseVector({0: 1.0, 2: 2.0, 4: 3.0}, 6)' + assert str(SparseVector([1, 0, 2, 0, 3, 0])) == 'SparseVector({0: 1.0, 2: 2.0, 4: 3.0}, 6)' def test_dim(self): - assert SparseVector.from_dense([1, 0, 2, 0, 3, 0]).dim() == 6 + assert SparseVector([1, 0, 2, 0, 3, 0]).dim() == 6 def test_to_dict(self): - assert SparseVector.from_dense([1, 0, 2, 0, 3, 0]).to_dict() == {0: 1, 2: 2, 4: 3} + assert SparseVector([1, 0, 2, 0, 3, 0]).to_dict() == {0: 1, 2: 2, 4: 3} def test_to_coo(self): - assert SparseVector.from_dense([1, 0, 2, 0, 3, 0]).to_coo().toarray().tolist() == [[1, 0, 2, 0, 3, 0]] + assert SparseVector([1, 0, 2, 0, 3, 0]).to_coo().toarray().tolist() == [[1, 0, 2, 0, 3, 0]] diff --git a/tests/test_sqlalchemy.py b/tests/test_sqlalchemy.py index 1c0fb80..edce3dc 100644 --- a/tests/test_sqlalchemy.py +++ b/tests/test_sqlalchemy.py @@ -41,9 +41,9 @@ class Item(Base): def create_items(): session = Session(engine) - session.add(Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector.from_dense([1, 1, 1]))) - session.add(Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector.from_dense([2, 2, 2]))) - session.add(Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector.from_dense([1, 1, 2]))) + session.add(Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector([1, 1, 1]))) + session.add(Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector([2, 2, 2]))) + session.add(Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector([1, 1, 2]))) session.commit() diff --git a/tests/test_sqlmodel.py b/tests/test_sqlmodel.py index 90f7e21..5685ce6 100644 --- a/tests/test_sqlmodel.py +++ b/tests/test_sqlmodel.py @@ -37,9 +37,9 @@ class Item(SQLModel, table=True): def create_items(): session = Session(engine) - session.add(Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector.from_dense([1, 1, 1]))) - session.add(Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector.from_dense([2, 2, 2]))) - session.add(Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector.from_dense([1, 1, 2]))) + session.add(Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector([1, 1, 1]))) + session.add(Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector([2, 2, 2]))) + session.add(Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector([1, 1, 2]))) session.commit()