Skip to content

Commit

Permalink
fix Index parameters handling and anniterator (zilliztech#913)
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandr Guzhva <[email protected]>
  • Loading branch information
alexanderguzhva authored and foxspy committed Nov 18, 2024
1 parent 50c87bc commit c611ed9
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 30 deletions.
34 changes: 15 additions & 19 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ struct FaissHnswIteratorWorkspace {
faiss::SearchParametersHNSW search_params;

// the query
std::unique_ptr<uint8_t[]> query;
std::unique_ptr<float[]> query;

// whether the initial search is done or not.
// basically, upon initialization, we need to traverse to the largest
Expand All @@ -493,7 +493,7 @@ struct FaissHnswIteratorWorkspace {
// Contains an iterator logic
class FaissHnswIterator : public IndexIterator {
public:
FaissHnswIterator(const std::shared_ptr<faiss::Index>& index_in, std::unique_ptr<uint8_t[]>&& query_in,
FaissHnswIterator(const std::shared_ptr<faiss::Index>& index_in, std::unique_ptr<float[]>&& query_in,
const BitsetView& bitset_in, const int32_t ef_in, bool larger_is_closer,
const float refine_ratio = 0.5f)
: IndexIterator(larger_is_closer, refine_ratio), index{index_in} {
Expand Down Expand Up @@ -559,10 +559,10 @@ class FaissHnswIterator : public IndexIterator {
}

// set query
workspace.qdis->set_query(reinterpret_cast<const float*>(query_in.get()));
workspace.qdis->set_query(query_in.get());

if (workspace.qdis_refine != nullptr) {
workspace.qdis_refine->set_query(reinterpret_cast<const float*>(query_in.get()));
workspace.qdis_refine->set_query(query_in.get());
}

// set up a buffer that tracks visited points
Expand Down Expand Up @@ -1168,6 +1168,12 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
return expected<std::vector<IndexNode::IteratorPtr>>::Err(Status::empty_index, "index not loaded");
}

if (data_format != DataFormatEnum::fp32 && data_format != DataFormatEnum::fp16 &&
data_format != DataFormatEnum::bf16) {
LOG_KNOWHERE_ERROR_ << "Unsupported data format";
return expected<std::vector<IndexNode::IteratorPtr>>::Err(Status::invalid_args, "unsupported data format");
}

// parse parameters
const auto dim = dataset->GetDim();
const auto n_queries = dataset->GetRows();
Expand All @@ -1187,24 +1193,14 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
for (int64_t i = 0; i < n_queries; i++) {
futs.emplace_back(search_pool->push([&, idx = i] {
// The query data is always cloned
std::unique_ptr<uint8_t[]> cur_query;
std::unique_ptr<float[]> cur_query = std::make_unique<float[]>(dim);

if (data_format == DataFormatEnum::fp32) {
cur_query = std::make_unique<uint8_t[]>(dim * sizeof(float));
std::copy_n(reinterpret_cast<const uint8_t*>(reinterpret_cast<const float*>(data) + idx * dim),
dim * sizeof(float), cur_query.get());
} else if (data_format == DataFormatEnum::fp16) {
cur_query = std::make_unique<uint8_t[]>(dim * sizeof(knowhere::fp16));
std::copy_n(
reinterpret_cast<const uint8_t*>(reinterpret_cast<const knowhere::fp16*>(data) + idx * dim),
dim * sizeof(knowhere::fp16), cur_query.get());
} else if (data_format == DataFormatEnum::bf16) {
cur_query = std::make_unique<uint8_t[]>(dim * sizeof(knowhere::bf16));
std::copy_n(
reinterpret_cast<const uint8_t*>(reinterpret_cast<const knowhere::bf16*>(data) + idx * dim),
dim * sizeof(knowhere::bf16), cur_query.get());
std::copy_n(reinterpret_cast<const float*>(data) + idx * dim, dim, cur_query.get());
} else if (data_format == DataFormatEnum::fp16 || data_format == DataFormatEnum::bf16) {
convert_rows_to_fp32(data, cur_query.get(), data_format, idx, 1, dim);
} else {
// invalid one
// invalid one. Should not be triggered, bcz input parameters are validated
throw;
}

Expand Down
63 changes: 57 additions & 6 deletions src/index/hnsw/faiss_hnsw_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class FaissHnswFlatConfig : public FaissHnswConfig {
*err_msg = "refine is not supported for this index";
LOG_KNOWHERE_ERROR_ << *err_msg;
}
return Status::invalid_value_in_json;
return Status::invalid_args;
}
}
return Status::success;
Expand Down Expand Up @@ -182,7 +182,7 @@ class FaissHnswSqConfig : public FaissHnswConfig {
*err_msg = "invalid scalar quantizer type";
LOG_KNOWHERE_ERROR_ << *err_msg;
}
return Status::invalid_value_in_json;
return Status::invalid_args;
}

// check refine
Expand All @@ -192,7 +192,7 @@ class FaissHnswSqConfig : public FaissHnswConfig {
*err_msg = "invalid refine type type";
LOG_KNOWHERE_ERROR_ << *err_msg;
}
return Status::invalid_value_in_json;
return Status::invalid_args;
}
}
}
Expand Down Expand Up @@ -225,7 +225,30 @@ class FaissHnswPqConfig : public FaissHnswConfig {

KNOHWERE_DECLARE_CONFIG(FaissHnswPqConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(m).description("m").set_default(32).for_train().set_range(1, 65536);
KNOWHERE_CONFIG_DECLARE_FIELD(nbits).description("nbits").set_default(8).for_train().set_range(1, 16);
// FAISS rejects nbits > 24, because it is not practical
KNOWHERE_CONFIG_DECLARE_FIELD(nbits).description("nbits").set_default(8).for_train().set_range(1, 24);
}

Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
switch (param_type) {
case PARAM_TYPE::TRAIN: {
if (dim.has_value() && m.has_value()) {
int vec_dim = dim.value();
int param_m = m.value();
if (vec_dim % param_m != 0) {
if (err_msg != nullptr) {
*err_msg =
"The dimension of the vector (dim) should be a multiple of the number of subquantizers "
"(m). Dimension: " +
std::to_string(vec_dim) + ", m: " + std::to_string(param_m);
}
return Status::invalid_args;
}
}
}
}
return Status::success;
}
};

Expand All @@ -239,8 +262,36 @@ class FaissHnswPrqConfig : public FaissHnswConfig {
CFG_INT nbits;
KNOHWERE_DECLARE_CONFIG(FaissHnswPrqConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(m).description("Number of splits").set_default(2).for_train().set_range(1, 65536);
KNOWHERE_CONFIG_DECLARE_FIELD(nrq).description("Number of residual subquantizers").for_train().set_range(1, 64);
KNOWHERE_CONFIG_DECLARE_FIELD(nbits).description("nbits").set_default(8).for_train().set_range(1, 64);
// I'm not sure whether nrq > 16 is practical
KNOWHERE_CONFIG_DECLARE_FIELD(nrq)
.description("Number of residual subquantizers")
.set_default(2)
.for_train()
.set_range(1, 16);
// FAISS rejects nbits > 24, because it is not practical
KNOWHERE_CONFIG_DECLARE_FIELD(nbits).description("nbits").set_default(8).for_train().set_range(1, 24);
}

Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
switch (param_type) {
case PARAM_TYPE::TRAIN: {
if (dim.has_value() && m.has_value()) {
int vec_dim = dim.value();
int param_m = m.value();
if (vec_dim % param_m != 0) {
if (err_msg != nullptr) {
*err_msg =
"The dimension of a vector (dim) should be a multiple of the number of subquantizers "
"(m). Dimension: " +
std::to_string(vec_dim) + ", m: " + std::to_string(param_m);
}
return Status::invalid_args;
}
}
}
}
return Status::success;
}
};

Expand Down
12 changes: 7 additions & 5 deletions src/index/ivf/ivf_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ class IvfPqConfig : public IvfConfig {
CFG_INT nbits;
KNOHWERE_DECLARE_CONFIG(IvfPqConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(m).description("m").for_train().set_range(1, 65536);
KNOWHERE_CONFIG_DECLARE_FIELD(nbits).description("nbits").set_default(8).for_train().set_range(1, 64);
// FAISS rejects nbits > 24, because it is not practical
KNOWHERE_CONFIG_DECLARE_FIELD(nbits).description("nbits").set_default(8).for_train().set_range(1, 24);
}

Status
Expand All @@ -80,8 +81,9 @@ class IvfPqConfig : public IvfConfig {
if (vec_dim % param_m != 0) {
if (err_msg) {
*err_msg =
"dimension must be able to be divided by `m`, dimension: " + std::to_string(vec_dim) +
", m: " + std::to_string(param_m);
"The dimension of a vector (dim) should be a multiple of the number of subquantizers "
"(m). Dimension: " +
std::to_string(vec_dim) + ", m: " + std::to_string(param_m);
}
return Status::invalid_args;
}
Expand Down Expand Up @@ -118,8 +120,8 @@ class ScannConfig : public IvfFlatConfig {
int vec_dim = dim.value();
if (vec_dim % 2 != 0) {
if (err_msg) {
*err_msg =
"dimension must be able to be divided by 2, dimension:" + std::to_string(vec_dim);
*err_msg = "The dimension of a vector (dim) should be a multiple of 2. Dimension:" +
std::to_string(vec_dim);
}
return Status::invalid_args;
}
Expand Down

0 comments on commit c611ed9

Please sign in to comment.