Skip to content

Commit

Permalink
Added support for halfvec type
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 19, 2024
1 parent 0d7966b commit 44ede65
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
## 0.2.0 (unreleased)

- Added support for `halfvec` type
- Fixed error with MSVC

## 0.1.1 (2022-11-13)
Expand Down
57 changes: 57 additions & 0 deletions include/halfvec.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*!
* pgvector-cpp v0.1.1
* https://github.com/pgvector/pgvector-cpp
* MIT License
*/

#pragma once

#include <ostream>
#include <vector>

namespace pgvector {
class HalfVector {
public:
HalfVector() = default;

HalfVector(const std::vector<float> &value) {
value_ = value;
}

HalfVector(std::vector<float>&& value) {
value_ = std::move(value);
}

HalfVector(const float *value, size_t n) {
value_ = std::vector<float>{value, value + n};
}

size_t dimensions() const {
return value_.size();
}

operator const std::vector<float>() const {
return value_;
}

friend bool operator==(const HalfVector &lhs, const HalfVector &rhs) {
return lhs.value_ == rhs.value_;
}

friend std::ostream &operator<<(std::ostream &os, const HalfVector &value) {
os << "[";
for (size_t i = 0; i < value.value_.size(); i++) {
if (i > 0) {
os << ",";
}
os << value.value_[i];
}
os << "]";
return os;
}

private:
// TODO use std::float16_t for C++23
std::vector<float> value_;
};
} // namespace pgvector
46 changes: 46 additions & 0 deletions include/pqxx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#pragma once

#include "halfvec.hpp"
#include "vector.hpp"
#include <pqxx/pqxx>
#include <sstream>
Expand Down Expand Up @@ -55,4 +56,49 @@ template <> struct string_traits<pgvector::Vector> {
static_cast<std::vector<float>>(value));
}
};

template <> std::string const type_name<pgvector::HalfVector>{"halfvec"};

template <> struct nullness<pgvector::HalfVector> : pqxx::no_null<pgvector::HalfVector> {};

template <> struct string_traits<pgvector::HalfVector> {
static constexpr bool converts_to_string{true};

static constexpr bool converts_from_string{true};

static pgvector::HalfVector from_string(std::string_view text) {
if (text.front() != '[' || text.back() != ']') {
throw conversion_error("Malformed halfvec literal");
}

// TODO don't copy string
std::vector<float> result;
std::stringstream ss(std::string(text.substr(1, -2)));
while (ss.good()) {
std::string substr;
getline(ss, substr, ',');
result.push_back(std::stod(substr));
}
return pgvector::HalfVector(result);
}

static zview to_buf(char *begin, char *end, pgvector::HalfVector const &value) {
char *const next = into_buf(begin, end, value);
return zview{begin, next - begin - 1};
}

static char *into_buf(char *begin, char *end, pgvector::HalfVector const &value) {
auto ret = string_traits<std::vector<float>>::into_buf(
begin, end, static_cast<std::vector<float>>(value));
// replace array brackets
*begin = '[';
*(ret - 2) = ']';
return ret;
}

static size_t size_buffer(pgvector::HalfVector const &value) noexcept {
return string_traits<std::vector<float>>::size_buffer(
static_cast<std::vector<float>>(value));
}
};
} // namespace pqxx
20 changes: 19 additions & 1 deletion test/pqxx_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ void setup(pqxx::connection &conn) {
pqxx::work tx{conn};
tx.exec0("CREATE EXTENSION IF NOT EXISTS vector");
tx.exec0("DROP TABLE IF EXISTS items");
tx.exec0("CREATE TABLE items (id serial PRIMARY KEY, embedding vector(3))");
tx.exec0("CREATE TABLE items (id serial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3))");
tx.commit();
}

Expand All @@ -29,6 +29,24 @@ void test_works(pqxx::connection &conn) {
tx.commit();
}

void test_halfvec(pqxx::connection &conn) {
pqxx::work tx{conn};
auto embedding = pgvector::HalfVector({1, 2, 3});
assert(embedding.dimensions() == 3);
float arr[] = {4, 5, 6};
auto embedding2 = pgvector::HalfVector(arr, 3);
tx.exec_params("INSERT INTO items (half_embedding) VALUES ($1), ($2), ($3)",
embedding, embedding2, std::nullopt);

pqxx::result res{tx.exec_params(
"SELECT embedding FROM items ORDER BY half_embedding <-> $1", embedding2)};
assert(res.size() == 3);
assert(res[0][0].as<pgvector::HalfVector>() == embedding2);
assert(res[1][0].as<pgvector::HalfVector>() == embedding);
assert(!res[2][0].as<std::optional<pgvector::HalfVector>>().has_value());
tx.commit();
}

void test_stream(pqxx::connection &conn) {
pqxx::work tx{conn};
int count = 0;
Expand Down

0 comments on commit 44ede65

Please sign in to comment.