Skip to content

Commit

Permalink
Added support for sparsevec type
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Jul 2, 2024
1 parent 478966e commit efb990a
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## 0.2.0 (unreleased)

- Added support for `halfvec` type
- Added support for `halfvec` and `sparsevec` types
- Fixed error with MSVC

## 0.1.1 (2022-11-13)
Expand Down
69 changes: 69 additions & 0 deletions include/pqxx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#pragma once

#include "halfvec.hpp"
#include "sparsevec.hpp"
#include "vector.hpp"
#include <pqxx/pqxx>
#include <sstream>
Expand Down Expand Up @@ -101,4 +102,72 @@ template <> struct string_traits<pgvector::HalfVector> {
static_cast<std::vector<float>>(value));
}
};

template <> std::string const type_name<pgvector::SparseVector>{"sparsevec"};

template <> struct nullness<pgvector::SparseVector> : pqxx::no_null<pgvector::SparseVector> {};

template <> struct string_traits<pgvector::SparseVector> {
static constexpr bool converts_to_string{true};

// TODO add from_string
static constexpr bool converts_from_string{false};

static zview to_buf(char* begin, char* end, const pgvector::SparseVector& value) {
char *const next = into_buf(begin, end, value);
return zview{begin, next - begin - 1};
}

static char* into_buf(char* begin, char* end, const pgvector::SparseVector& value) {
int dimensions = value.dimensions();
auto indices = value.indices();
auto values = value.values();
size_t nnz = indices.size();

// important! size_buffer cannot throw an exception on overflow
// so perform this check before writing any data
if (nnz > 16000) {
throw conversion_overrun{"sparsevec cannot have more than 16000 dimensions"};
}

char *here = begin;
*here++ = '{';

for (size_t i = 0; i < nnz; i++) {
if (i != 0) {
*here++ = ',';
}

here = string_traits<int>::into_buf(here, end, indices[i] + 1) - 1;
*here++ = ':';
here = string_traits<float>::into_buf(here, end, values[i]) - 1;
}

*here++ = '}';
*here++ = '/';
here = string_traits<int>::into_buf(here, end, dimensions) - 1;
*here++ = '\0';

return here;
}

static size_t size_buffer(const pgvector::SparseVector& value) noexcept {
int dimensions = value.dimensions();
auto indices = value.indices();
auto values = value.values();
size_t nnz = indices.size();

// cannot throw an exception here on overflow
// so throw in into_buf

size_t size = 4; // {, }, /, and \0
size += string_traits<int>::size_buffer(dimensions);
for (size_t i = 0; i < nnz; i++) {
size += 2; // : and ,
size += string_traits<int>::size_buffer(indices[i]);
size += string_traits<float>::size_buffer(values[i]);
}
return size;
}
};
} // namespace pqxx
73 changes: 73 additions & 0 deletions include/sparsevec.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*!
* pgvector-cpp v0.1.1
* https://github.com/pgvector/pgvector-cpp
* MIT License
*/

#pragma once

#include <ostream>
#include <vector>

namespace pgvector {
class SparseVector {
public:
SparseVector() = default;

SparseVector(int dimensions, const std::vector<int>& indices, const std::vector<float>& values) {
if (values.size() != indices.size()) {
throw std::invalid_argument("indices and values must be the same length");
}
dimensions_ = dimensions;
indices_ = indices;
values_ = values;
}

SparseVector(const std::vector<float>& value) {
dimensions_ = value.size();
for (size_t i = 0; i < value.size(); i++) {
float v = value[i];
if (v != 0) {
indices_.push_back(i);
values_.push_back(v);
}
}
}

int dimensions() const {
return dimensions_;
}

const std::vector<int>& indices() const {
return indices_;
}

const std::vector<float>& values() const {
return values_;
}

friend bool operator==(const SparseVector& lhs, const SparseVector& rhs) {
return lhs.dimensions_ == rhs.dimensions_ && lhs.indices_ == rhs.indices_ && lhs.values_ == rhs.values_;
}

friend std::ostream& operator<<(std::ostream& os, const SparseVector& value) {
os << "{";
for (size_t i = 0; i < value.indices_.size(); i++) {
if (i > 0) {
os << ",";
}
os << value.indices_[i] + 1;
os << ":";
os << value.values_[i];
}
os << "}/";
os << value.dimensions_;
return os;
}

private:
int dimensions_;
std::vector<int> indices_;
std::vector<float> values_;
};
} // namespace pgvector
8 changes: 4 additions & 4 deletions test/pqxx_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,16 @@ void test_sparsevec(pqxx::connection &conn) {
before_each(conn);

pqxx::work tx{conn};
auto embedding = "{1:1,2:2,3:3}/3";
auto embedding2 = "{1:4,2:5,3:6}/3";
auto embedding = pgvector::SparseVector({1, 2, 3});
auto embedding2 = pgvector::SparseVector({4, 5, 6});
tx.exec_params("INSERT INTO items (sparse_embedding) VALUES ($1), ($2), ($3)",
embedding, embedding2, std::nullopt);

pqxx::result res{tx.exec_params(
"SELECT sparse_embedding FROM items ORDER BY sparse_embedding <-> $1", embedding2)};
assert(res.size() == 3);
assert(res[0][0].as<std::string>() == embedding2);
assert(res[1][0].as<std::string>() == embedding);
assert(res[0][0].as<std::string>() == "{1:4,2:5,3:6}/3");
assert(res[1][0].as<std::string>() == "{1:1,2:2,3:3}/3");
assert(!res[2][0].as<std::optional<std::string>>().has_value());
tx.commit();
}
Expand Down

0 comments on commit efb990a

Please sign in to comment.