From d33e62925739ee78f9ee049658841717b131d1de Mon Sep 17 00:00:00 2001 From: "xianliang.li" Date: Wed, 30 Oct 2024 11:11:30 +0800 Subject: [PATCH] add range check Signed-off-by: xianliang.li --- include/knowhere/config.h | 54 +++++++++++++------ .../sparse/sparse_inverted_index_config.h | 4 +- tests/ut/test_config.cc | 26 +++++++++ 3 files changed, 65 insertions(+), 19 deletions(-) diff --git a/include/knowhere/config.h b/include/knowhere/config.h index 2c30a7d2b..204f757ad 100644 --- a/include/knowhere/config.h +++ b/include/knowhere/config.h @@ -58,6 +58,32 @@ typedef nlohmann::json Json; #define CFG_MATERIALIZED_VIEW_SEARCH_INFO_TYPE std::optional #endif +template +struct Range { + T left; + T right; + bool include_left; + bool include_right; + + Range(T left, T right, bool includeLeft, bool includeRight) + : left(left), right(right), include_left(includeLeft), include_right(includeRight) { + } + + bool + within(T val) { + bool left_range_check = left < val || (include_left && left <= val); + bool right_range_check = val < right || (include_right && val <= right); + return left_range_check && right_range_check; + } + + std::string + to_string() { + std::string left_mark = include_left ? "[" : "("; + std::string right_mark = include_right ? "]" : ")"; + return left_mark + std::to_string(left) + ", " + std::to_string(right) + right_mark; + } +}; + template struct Entry {}; @@ -114,7 +140,7 @@ struct Entry { CFG_FLOAT* val; std::optional default_val; uint32_t type; - std::optional> range; + std::optional> range; std::optional desc; bool allow_empty_without_default = false; }; @@ -139,7 +165,7 @@ struct Entry { CFG_INT* val; std::optional default_val; uint32_t type; - std::optional> range; + std::optional> range; std::optional desc; bool allow_empty_without_default = false; }; @@ -164,7 +190,7 @@ struct Entry { CFG_INT64* val; std::optional default_val; uint32_t type; - std::optional> range; + std::optional> range; std::optional desc; bool allow_empty_without_default = false; }; @@ -228,8 +254,8 @@ class EntryAccess { } EntryAccess& - set_range(typename T::value_type a, typename T::value_type b) { - entry->range = std::make_pair(a, b); + set_range(typename T::value_type a, typename T::value_type b, bool include_left = true, bool include_right = true) { + entry->range = Range(a, b, include_left, include_right); return *this; } @@ -360,13 +386,11 @@ class Config { } CFG_INT::value_type v = json[it.first]; auto range_val = ptr->range.value(); - if (range_val.first <= v && v <= range_val.second) { + if (range_val.within(v)) { *ptr->val = v; } else { std::string msg = "Out of range in json: param '" + it.first + "' (" + - to_string(json[it.first]) + ") should be in range [" + - std::to_string(range_val.first) + ", " + std::to_string(range_val.second) + - "]"; + to_string(json[it.first]) + ") should be in range " + range_val.to_string(); show_err_msg(msg); return Status::out_of_range_in_json; } @@ -408,13 +432,11 @@ class Config { } CFG_INT64::value_type v = json[it.first]; auto range_val = ptr->range.value(); - if (range_val.first <= v && v <= range_val.second) { + if (range_val.within(v)) { *ptr->val = v; } else { std::string msg = "Out of range in json: param '" + it.first + "' (" + - to_string(json[it.first]) + ") should be in range [" + - std::to_string(range_val.first) + ", " + std::to_string(range_val.second) + - "]"; + to_string(json[it.first]) + ") should be in range " + range_val.to_string(); show_err_msg(msg); return Status::out_of_range_in_json; } @@ -456,13 +478,11 @@ class Config { } CFG_FLOAT::value_type v = json[it.first]; auto range_val = ptr->range.value(); - if (range_val.first <= v && v <= range_val.second) { + if (range_val.within(v)) { *ptr->val = v; } else { std::string msg = "Out of range in json: param '" + it.first + "' (" + - to_string(json[it.first]) + ") should be in range [" + - std::to_string(range_val.first) + ", " + std::to_string(range_val.second) + - "]"; + to_string(json[it.first]) + ") should be in range " + range_val.to_string(); show_err_msg(msg); return Status::out_of_range_in_json; } diff --git a/src/index/sparse/sparse_inverted_index_config.h b/src/index/sparse/sparse_inverted_index_config.h index 6f492925b..2d416e43a 100644 --- a/src/index/sparse/sparse_inverted_index_config.h +++ b/src/index/sparse/sparse_inverted_index_config.h @@ -27,12 +27,12 @@ class SparseInvertedIndexConfig : public BaseConfig { KNOWHERE_CONFIG_DECLARE_FIELD(drop_ratio_build) .description("drop ratio for build") .set_default(0.0f) - .set_range(0.0f, 1.0f) + .set_range(0.0f, 1.0f, true, false) .for_train(); KNOWHERE_CONFIG_DECLARE_FIELD(drop_ratio_search) .description("drop ratio for search") .set_default(0.0f) - .set_range(0.0f, 1.0f) + .set_range(0.0f, 1.0f, true, false) .for_search() .for_range_search() .for_iterator(); diff --git a/tests/ut/test_config.cc b/tests/ut/test_config.cc index 33f11ba92..cbc07f98f 100644 --- a/tests/ut/test_config.cc +++ b/tests/ut/test_config.cc @@ -19,6 +19,7 @@ #include "knowhere/version.h" #ifdef KNOWHERE_WITH_DISKANN #include "index/diskann/diskann_config.h" +#include "index/sparse/sparse_inverted_index_config.h" #endif #ifdef KNOWHERE_WITH_RAFT #include "index/gpu_raft/gpu_raft_cagra_config.h" @@ -110,6 +111,31 @@ TEST_CASE("Test config json parse", "[config]") { CHECK(test_config.dim.value() == 10000000000L); } + SECTION("check range data values") { + auto sparse_valid = GENERATE(as{}, + R"({ + "drop_ratio_build": 0.0 + })"); + knowhere::BaseConfig test_config; + knowhere::Json test_json = knowhere::Json::parse(sparse_valid); + s = knowhere::Config::FormatAndCheck(test_config, test_json); + CHECK(s == knowhere::Status::success); + s = knowhere::Config::Load(test_config, test_json, knowhere::TRAIN); + CHECK(s == knowhere::Status::success); + + auto sparse_invalid = GENERATE(as{}, + R"({ + "drop_ratio_build": 1.0 + })"); + + knowhere::SparseInvertedIndexConfig test_invalid_config; + knowhere::Json test_invalid_json = knowhere::Json::parse(sparse_invalid); + s = knowhere::Config::FormatAndCheck(test_invalid_config, test_invalid_json); + CHECK(s == knowhere::Status::success); + s = knowhere::Config::Load(test_invalid_config, test_invalid_json, knowhere::TRAIN); + CHECK(s == knowhere::Status::out_of_range_in_json); + } + SECTION("check invalid json values") { auto invalid_json_str = GENERATE(as{}, R"({