// Patch overview: three files are touched (faiss_hnsw.cc, faiss_hnsw_config.h, ivf_config.h).
//
// NOTE(review): this patch text has lost its original line breaks, and template argument
// lists appear to be missing (e.g. `std::unique_ptr query;`). Removed/added pairs that look
// identical below presumably differed only in that missing part - confirm against the
// upstream commit before relying on this copy.
//
// src/index/hnsw/faiss_hnsw.cc
//   * FaissHnswIteratorWorkspace::query and the FaissHnswIterator constructor parameter
//     `query_in` are re-declared (the visible text of old and new lines is the same here).
//   * Both set_query() calls stop wrapping query_in.get() in a reinterpret_cast and pass the
//     pointer directly.
//   * A new guard logs "Unsupported data format" and returns Status::invalid_args unless
//     data_format is fp32, fp16 or bf16. It is placed right after the "index not loaded"
//     check and before the parameters are parsed.
//   * Per-query clone inside the search_pool lambda: the buffer is now created up front with
//     make_unique(dim). fp32 input is copied with std::copy_n over `dim` elements; fp16 and
//     bf16 input is handed to convert_rows_to_fp32(data, cur_query.get(), data_format, idx, 1, dim).
//     The three per-format byte-sized allocate-and-copy branches are removed.
//   * NOTE(review): the fallback branch still executes a bare `throw;`. In C++ a rethrow with
//     no exception currently being handled calls std::terminate. The new guard is meant to
//     make the branch unreachable, but confirm that terminating is the intended failure mode.
//
// src/index/hnsw/faiss_hnsw_config.h
//   * Three validation failures ("refine is not supported for this index",
//     "invalid scalar quantizer type", "invalid refine type type") now return
//     Status::invalid_args instead of Status::invalid_value_in_json.
//     NOTE(review): "invalid refine type type" looks like a typo in an untouched context line.
//   * FaissHnswPqConfig: nbits range changes from (1, 16) to (1, 24). A CheckAndAdjust override
//     is added that, for PARAM_TYPE::TRAIN, returns Status::invalid_args when both dim and m
//     are set and dim % m != 0, filling *err_msg if it is non-null.
//   * FaissHnswPrqConfig: nrq gains set_default(2) and its range changes from (1, 64) to (1, 16).
//     nbits range changes from (1, 64) to (1, 24). The same CheckAndAdjust override is added.
//   * NOTE(review): the two new overrides do not call LOG_KNOWHERE_ERROR_, unlike the
//     neighbouring validation code visible in this patch - confirm whether that is deliberate.
//
// src/index/ivf/ivf_config.h
//   * IvfPqConfig: nbits range changes from (1, 64) to (1, 24).
//   * The "dim must be divisible by m" and "dim must be divisible by 2" error messages are
//     reworded. NOTE(review): the ScaNN message has no space after "Dimension:".
diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 013c0875f..63711ee8b 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -476,7 +476,7 @@ struct FaissHnswIteratorWorkspace { faiss::SearchParametersHNSW search_params; // the query - std::unique_ptr query; + std::unique_ptr query; // whether the initial search is done or not. // basically, upon initialization, we need to traverse to the largest @@ -493,7 +493,7 @@ struct FaissHnswIteratorWorkspace { // Contains an iterator logic class FaissHnswIterator : public IndexIterator { public: - FaissHnswIterator(const std::shared_ptr& index_in, std::unique_ptr&& query_in, + FaissHnswIterator(const std::shared_ptr& index_in, std::unique_ptr&& query_in, const BitsetView& bitset_in, const int32_t ef_in, bool larger_is_closer, const float refine_ratio = 0.5f) : IndexIterator(larger_is_closer, refine_ratio), index{index_in} { @@ -559,10 +559,10 @@ class FaissHnswIterator : public IndexIterator { } // set query - workspace.qdis->set_query(reinterpret_cast(query_in.get())); + workspace.qdis->set_query(query_in.get()); if (workspace.qdis_refine != nullptr) { - workspace.qdis_refine->set_query(reinterpret_cast(query_in.get())); + workspace.qdis_refine->set_query(query_in.get()); } // set up a buffer that tracks visited points @@ -1168,6 +1168,12 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { return expected>::Err(Status::empty_index, "index not loaded"); } + if (data_format != DataFormatEnum::fp32 && data_format != DataFormatEnum::fp16 && + data_format != DataFormatEnum::bf16) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return expected>::Err(Status::invalid_args, "unsupported data format"); + } + // parse parameters const auto dim = dataset->GetDim(); const auto n_queries = dataset->GetRows(); @@ -1187,24 +1193,14 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { for (int64_t i = 0; i < n_queries; i++) { 
futs.emplace_back(search_pool->push([&, idx = i] { // The query data is always cloned - std::unique_ptr cur_query; + std::unique_ptr cur_query = std::make_unique(dim); if (data_format == DataFormatEnum::fp32) { - cur_query = std::make_unique(dim * sizeof(float)); - std::copy_n(reinterpret_cast(reinterpret_cast(data) + idx * dim), - dim * sizeof(float), cur_query.get()); - } else if (data_format == DataFormatEnum::fp16) { - cur_query = std::make_unique(dim * sizeof(knowhere::fp16)); - std::copy_n( - reinterpret_cast(reinterpret_cast(data) + idx * dim), - dim * sizeof(knowhere::fp16), cur_query.get()); - } else if (data_format == DataFormatEnum::bf16) { - cur_query = std::make_unique(dim * sizeof(knowhere::bf16)); - std::copy_n( - reinterpret_cast(reinterpret_cast(data) + idx * dim), - dim * sizeof(knowhere::bf16), cur_query.get()); + std::copy_n(reinterpret_cast(data) + idx * dim, dim, cur_query.get()); + } else if (data_format == DataFormatEnum::fp16 || data_format == DataFormatEnum::bf16) { + convert_rows_to_fp32(data, cur_query.get(), data_format, idx, 1, dim); } else { - // invalid one + // invalid one. 
Should not be triggered, bcz input parameters are validated throw; } diff --git a/src/index/hnsw/faiss_hnsw_config.h b/src/index/hnsw/faiss_hnsw_config.h index d5ded26ad..4a35d63a4 100644 --- a/src/index/hnsw/faiss_hnsw_config.h +++ b/src/index/hnsw/faiss_hnsw_config.h @@ -146,7 +146,7 @@ class FaissHnswFlatConfig : public FaissHnswConfig { *err_msg = "refine is not supported for this index"; LOG_KNOWHERE_ERROR_ << *err_msg; } - return Status::invalid_value_in_json; + return Status::invalid_args; } } return Status::success; @@ -182,7 +182,7 @@ class FaissHnswSqConfig : public FaissHnswConfig { *err_msg = "invalid scalar quantizer type"; LOG_KNOWHERE_ERROR_ << *err_msg; } - return Status::invalid_value_in_json; + return Status::invalid_args; } // check refine @@ -192,7 +192,7 @@ class FaissHnswSqConfig : public FaissHnswConfig { *err_msg = "invalid refine type type"; LOG_KNOWHERE_ERROR_ << *err_msg; } - return Status::invalid_value_in_json; + return Status::invalid_args; } } } @@ -225,7 +225,30 @@ class FaissHnswPqConfig : public FaissHnswConfig { KNOHWERE_DECLARE_CONFIG(FaissHnswPqConfig) { KNOWHERE_CONFIG_DECLARE_FIELD(m).description("m").set_default(32).for_train().set_range(1, 65536); - KNOWHERE_CONFIG_DECLARE_FIELD(nbits).description("nbits").set_default(8).for_train().set_range(1, 16); + // FAISS rejects nbits > 24, because it is not practical + KNOWHERE_CONFIG_DECLARE_FIELD(nbits).description("nbits").set_default(8).for_train().set_range(1, 24); + } + + Status + CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { + switch (param_type) { + case PARAM_TYPE::TRAIN: { + if (dim.has_value() && m.has_value()) { + int vec_dim = dim.value(); + int param_m = m.value(); + if (vec_dim % param_m != 0) { + if (err_msg != nullptr) { + *err_msg = + "The dimension of the vector (dim) should be a multiple of the number of subquantizers " + "(m). 
Dimension: " + + std::to_string(vec_dim) + ", m: " + std::to_string(param_m); + } + return Status::invalid_args; + } + } + } + } + return Status::success; } }; @@ -239,8 +262,36 @@ class FaissHnswPrqConfig : public FaissHnswConfig { CFG_INT nbits; KNOHWERE_DECLARE_CONFIG(FaissHnswPrqConfig) { KNOWHERE_CONFIG_DECLARE_FIELD(m).description("Number of splits").set_default(2).for_train().set_range(1, 65536); - KNOWHERE_CONFIG_DECLARE_FIELD(nrq).description("Number of residual subquantizers").for_train().set_range(1, 64); - KNOWHERE_CONFIG_DECLARE_FIELD(nbits).description("nbits").set_default(8).for_train().set_range(1, 64); + // I'm not sure whether nrq > 16 is practical + KNOWHERE_CONFIG_DECLARE_FIELD(nrq) + .description("Number of residual subquantizers") + .set_default(2) + .for_train() + .set_range(1, 16); + // FAISS rejects nbits > 24, because it is not practical + KNOWHERE_CONFIG_DECLARE_FIELD(nbits).description("nbits").set_default(8).for_train().set_range(1, 24); + } + + Status + CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { + switch (param_type) { + case PARAM_TYPE::TRAIN: { + if (dim.has_value() && m.has_value()) { + int vec_dim = dim.value(); + int param_m = m.value(); + if (vec_dim % param_m != 0) { + if (err_msg != nullptr) { + *err_msg = + "The dimension of a vector (dim) should be a multiple of the number of subquantizers " + "(m). 
Dimension: " + + std::to_string(vec_dim) + ", m: " + std::to_string(param_m); + } + return Status::invalid_args; + } + } + } + } + return Status::success; } }; diff --git a/src/index/ivf/ivf_config.h b/src/index/ivf/ivf_config.h index a65b3a27e..2c7fecfe9 100644 --- a/src/index/ivf/ivf_config.h +++ b/src/index/ivf/ivf_config.h @@ -67,7 +67,8 @@ class IvfPqConfig : public IvfConfig { CFG_INT nbits; KNOHWERE_DECLARE_CONFIG(IvfPqConfig) { KNOWHERE_CONFIG_DECLARE_FIELD(m).description("m").for_train().set_range(1, 65536); - KNOWHERE_CONFIG_DECLARE_FIELD(nbits).description("nbits").set_default(8).for_train().set_range(1, 64); + // FAISS rejects nbits > 24, because it is not practical + KNOWHERE_CONFIG_DECLARE_FIELD(nbits).description("nbits").set_default(8).for_train().set_range(1, 24); } Status @@ -80,8 +81,9 @@ class IvfPqConfig : public IvfConfig { if (vec_dim % param_m != 0) { if (err_msg) { *err_msg = - "dimension must be able to be divided by `m`, dimension: " + std::to_string(vec_dim) + - ", m: " + std::to_string(param_m); + "The dimension of a vector (dim) should be a multiple of the number of subquantizers " + "(m). Dimension: " + + std::to_string(vec_dim) + ", m: " + std::to_string(param_m); } return Status::invalid_args; } @@ -118,8 +120,8 @@ class ScannConfig : public IvfFlatConfig { int vec_dim = dim.value(); if (vec_dim % 2 != 0) { if (err_msg) { - *err_msg = - "dimension must be able to be divided by 2, dimension:" + std::to_string(vec_dim); + *err_msg = "The dimension of a vector (dim) should be a multiple of 2. Dimension:" + + std::to_string(vec_dim); } return Status::invalid_args; }