From b339d089ed4a7ac7a8ed4ab7d57cccdc0602f0ed Mon Sep 17 00:00:00 2001
From: Levi Tamasi
Date: Mon, 9 Dec 2024 18:56:27 -0800
Subject: [PATCH] Write-side support for FAISS IVF indices (#13197)

Summary:
Pull Request resolved: https://github.com/facebook/rocksdb/pull/13197

The patch adds initial support for backing FAISS's inverted-file-based indices with data stored in RocksDB. It introduces a `SecondaryIndex` implementation called `FaissIVFIndex` which takes ownership of a `faiss::IndexIVF` object. During indexing, `FaissIVFIndex` treats the original value of the specified primary column as an embedding vector, and passes it to the provided FAISS index object to perform quantization. It replaces the original embedding vector with the result of the coarse quantizer (i.e. the inverted list id), and puts the result of the fine quantizer (if any) into the secondary index value.

Note that this patch is only one half of the equation; it provides a way of storing FAISS inverted lists in RocksDB, but there is currently no retrieval/search support (this will be a follow-up change). Also, the integration currently works only with our internal Buck build. I plan to add support for `cmake`/`make`-based builds similarly to how we handle Folly.

Reviewed By: jowlyzhang Differential Revision: D66907065 fbshipit-source-id: 63fdf29895d5feeffc230254a7ddfb0aac050967 --- BUCK | 25 ++ Makefile | 2 +- buckifier/buckify_rocksdb.py | 50 +++- src.mk | 6 + utilities/secondary_index/faiss_ivf_index.cc | 214 ++++++++++++++++++ utilities/secondary_index/faiss_ivf_index.h | 60 +++++ .../secondary_index/faiss_ivf_index_test.cc | 124 ++++++++++ 7 files changed, 472 insertions(+), 9 deletions(-) create mode 100644 utilities/secondary_index/faiss_ivf_index.cc create mode 100644 utilities/secondary_index/faiss_ivf_index.h create mode 100644 utilities/secondary_index/faiss_ivf_index_test.cc diff --git a/BUCK b/BUCK index 7e12501725d..bdc7423b6b9 100644 --- a/BUCK +++ b/BUCK @@ -368,6 +368,11 @@ cpp_library_wrapper(name="rocksdb_lib", srcs=[ cpp_library_wrapper(name="rocksdb_whole_archive_lib", srcs=[], deps=[":rocksdb_lib"], headers=[], link_whole=True, extra_test_libs=False) +cpp_library_wrapper(name="rocksdb_with_faiss_lib", srcs=["utilities/secondary_index/faiss_ivf_index.cc"], deps=[ + "//faiss:faiss", + ":rocksdb_lib", + ], headers=[], link_whole=False, extra_test_libs=False) + cpp_library_wrapper(name="rocksdb_test_lib", srcs=[ "db/db_test_util.cc", "db/db_with_timestamp_test_util.cc", @@ -382,6 +387,20 @@ cpp_library_wrapper(name="rocksdb_test_lib", srcs=[ "utilities/cassandra/test_utils.cc", ], deps=[":rocksdb_lib"], headers=[], link_whole=False, extra_test_libs=True) +cpp_library_wrapper(name="rocksdb_with_faiss_test_lib", srcs=[ + "db/db_test_util.cc", + "db/db_with_timestamp_test_util.cc", + "table/mock_table.cc", + "test_util/mock_time_env.cc", + "test_util/secondary_cache_test_util.cc", + "test_util/testharness.cc", + "test_util/testutil.cc", + "tools/block_cache_analyzer/block_cache_trace_analyzer.cc", + "tools/trace_analyzer_tool.cc", + "utilities/agg_merge/test_agg_merge.cc", + "utilities/cassandra/test_utils.cc", + ], deps=[":rocksdb_with_faiss_lib"], headers=[], link_whole=False, extra_test_libs=True) + 
cpp_library_wrapper(name="rocksdb_tools_lib", srcs=[ "test_util/testutil.cc", "tools/block_cache_analyzer/block_cache_trace_analyzer.cc", @@ -5078,6 +5097,12 @@ cpp_unittest_wrapper(name="external_sst_file_test", extra_compiler_flags=[]) +cpp_unittest_wrapper(name="faiss_ivf_index_test", + srcs=["utilities/secondary_index/faiss_ivf_index_test.cc"], + deps=[":rocksdb_with_faiss_test_lib"], + extra_compiler_flags=[]) + + cpp_unittest_wrapper(name="fault_injection_test", srcs=["db/fault_injection_test.cc"], deps=[":rocksdb_test_lib"], diff --git a/Makefile b/Makefile index 0bedd667b37..e2dee455a4c 100644 --- a/Makefile +++ b/Makefile @@ -659,7 +659,7 @@ ifneq ($(filter check-headers, $(MAKECMDGOALS)),) # TODO: add/support JNI headers DEV_HEADER_DIRS := $(sort include/ $(dir $(ALL_SOURCES))) # Some headers like in port/ are platform-specific - DEV_HEADERS_TO_CHECK := $(shell $(FIND) $(DEV_HEADER_DIRS) -type f -name '*.h' | grep -E -v 'port/|plugin/|lua/|range_tree/') + DEV_HEADERS_TO_CHECK := $(shell $(FIND) $(DEV_HEADER_DIRS) -type f -name '*.h' | grep -E -v 'port/|plugin/|lua/|range_tree/|secondary_index/') PUBLIC_HEADERS_TO_CHECK := $(shell $(FIND) include/ -type f -name '*.h' | grep -E -v 'lua/') else DEV_HEADERS_TO_CHECK := diff --git a/buckifier/buckify_rocksdb.py b/buckifier/buckify_rocksdb.py index 92fcb8a7bb3..035254b5ad1 100755 --- a/buckifier/buckify_rocksdb.py +++ b/buckifier/buckify_rocksdb.py @@ -161,6 +161,15 @@ def generate_buck(repo_path, deps_map): extra_external_deps="", link_whole=True, ) + # rocksdb_with_faiss_lib + BUCK.add_library( + "rocksdb_with_faiss_lib", + src_mk.get("WITH_FAISS_LIB_SOURCES", []), + deps=[ + "//faiss:faiss", + ":rocksdb_lib", + ], + ) # rocksdb_test_lib BUCK.add_library( "rocksdb_test_lib", @@ -171,6 +180,18 @@ def generate_buck(repo_path, deps_map): [":rocksdb_lib"], extra_test_libs=True, ) + # rocksdb_with_faiss_test_lib + BUCK.add_library( + "rocksdb_with_faiss_test_lib", + src_mk.get("MOCK_LIB_SOURCES", []) + + 
src_mk.get("TEST_LIB_SOURCES", []) + + src_mk.get("EXP_LIB_SOURCES", []) + + src_mk.get("ANALYZER_LIB_SOURCES", []), + deps=[ + ":rocksdb_with_faiss_lib", + ], + extra_test_libs=True, + ) # rocksdb_tools_lib BUCK.add_library( "rocksdb_tools_lib", @@ -278,11 +299,16 @@ def generate_buck(repo_path, deps_map): for test_src in src_mk.get("TEST_MAIN_SOURCES", []): test = test_src.split(".c")[0].strip().split("/")[-1].strip() - test_source_map[test] = test_src + test_source_map[test] = (test_src, False) print("" + test + " " + test_src) + for test_src in src_mk.get("WITH_FAISS_TEST_MAIN_SOURCES", []): + test = test_src.split(".c")[0].strip().split("/")[-1].strip() + test_source_map[test] = (test_src, True) + print("" + test + " " + test_src + " [FAISS]") + for target_alias, deps in deps_map.items(): - for test, test_src in sorted(test_source_map.items()): + for test, (test_src, with_faiss) in sorted(test_source_map.items()): if len(test) == 0: print(ColorString.warning("Failed to get test name for %s" % test_src)) continue @@ -304,12 +330,20 @@ def generate_buck(repo_path, deps_map): extra_compiler_flags=json.dumps(deps["extra_compiler_flags"]), ) else: - BUCK.register_test( - test_target_name, - test_src, - deps=json.dumps(deps["extra_deps"] + [":rocksdb_test_lib"]), - extra_compiler_flags=json.dumps(deps["extra_compiler_flags"]), - ) + if with_faiss: + BUCK.register_test( + test_target_name, + test_src, + deps=json.dumps(deps["extra_deps"] + [":rocksdb_with_faiss_test_lib"]), + extra_compiler_flags=json.dumps(deps["extra_compiler_flags"]), + ) + else: + BUCK.register_test( + test_target_name, + test_src, + deps=json.dumps(deps["extra_deps"] + [":rocksdb_test_lib"]), + extra_compiler_flags=json.dumps(deps["extra_compiler_flags"]), + ) BUCK.export_file("tools/db_crashtest.py") print(ColorString.info("Generated BUCK Summary:")) diff --git a/src.mk b/src.mk index fbe9ba1ea8e..121a08e928a 100644 --- a/src.mk +++ b/src.mk @@ -341,6 +341,9 @@ LIB_SOURCES_ASM = LIB_SOURCES_C = 
endif +WITH_FAISS_LIB_SOURCES = \ + utilities/secondary_index/faiss_ivf_index.cc \ + RANGE_TREE_SOURCES =\ utilities/transactions/lock/range/range_tree/lib/locktree/concurrent_tree.cc \ utilities/transactions/lock/range/range_tree/lib/locktree/keyrange.cc \ @@ -651,6 +654,9 @@ TEST_MAIN_SOURCES = \ TEST_MAIN_SOURCES_C = \ db/c_test.c \ +WITH_FAISS_TEST_MAIN_SOURCES = \ + utilities/secondary_index/faiss_ivf_index_test.cc \ + MICROBENCH_SOURCES = \ microbench/ribbon_bench.cc \ microbench/db_basic_bench.cc \ diff --git a/utilities/secondary_index/faiss_ivf_index.cc b/utilities/secondary_index/faiss_ivf_index.cc new file mode 100644 index 00000000000..c419b98a2d1 --- /dev/null +++ b/utilities/secondary_index/faiss_ivf_index.cc @@ -0,0 +1,214 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#include "utilities/secondary_index/faiss_ivf_index.h" + +#include + +#include "faiss/invlists/InvertedLists.h" +#include "util/coding.h" + +namespace ROCKSDB_NAMESPACE { + +class FaissIVFIndex::Adapter : public faiss::InvertedLists { + public: + Adapter(size_t num_lists, size_t code_size) + : faiss::InvertedLists(num_lists, code_size) { + use_iterator = true; + } + + // Non-iterator-based read interface; not implemented/used since use_iterator + // is true + size_t list_size(size_t /* list_no */) const override { + assert(false); + return 0; + } + + const uint8_t* get_codes(size_t /* list_no */) const override { + assert(false); + return nullptr; + } + + const faiss::idx_t* get_ids(size_t /* list_no */) const override { + assert(false); + return nullptr; + } + + // Iterator-based read interface; not yet implemented + faiss::InvertedListsIterator* get_iterator( + size_t /* list_no */, + void* /* inverted_list_context */ = nullptr) const override { + // TODO: implement this + + 
assert(false); + return nullptr; + } + + // Write interface; only add_entry is implemented/required for now + size_t add_entry(size_t /* list_no */, faiss::idx_t /* id */, + const uint8_t* code, + void* inverted_list_context = nullptr) override { + std::string* const code_str = + static_cast(inverted_list_context); + assert(code_str); + + code_str->assign(reinterpret_cast(code), code_size); + + return 0; + } + + size_t add_entries(size_t /* list_no */, size_t /* num_entries */, + const faiss::idx_t* /* ids */, + const uint8_t* /* code */) override { + assert(false); + return 0; + } + + void update_entry(size_t /* list_no */, size_t /* offset */, + faiss::idx_t /* id */, const uint8_t* /* code */) override { + assert(false); + } + + void update_entries(size_t /* list_no */, size_t /* offset */, + size_t /* num_entries */, const faiss::idx_t* /* ids */, + const uint8_t* /* code */) override { + assert(false); + } + + void resize(size_t /* list_no */, size_t /* new_size */) override { + assert(false); + } +}; + +std::string FaissIVFIndex::SerializeLabel(faiss::idx_t label) { + std::string label_str; + PutVarsignedint64(&label_str, label); + + return label_str; +} + +faiss::idx_t FaissIVFIndex::DeserializeLabel(Slice label_slice) { + faiss::idx_t label = -1; + [[maybe_unused]] const bool ok = GetVarsignedint64(&label_slice, &label); + assert(ok); + + return label; +} + +FaissIVFIndex::FaissIVFIndex(std::unique_ptr&& index, + std::string primary_column_name) + : adapter_(std::make_unique(index->nlist, index->code_size)), + index_(std::move(index)), + primary_column_name_(std::move(primary_column_name)) { + assert(index_); + assert(index_->quantizer); + + index_->replace_invlists(adapter_.get()); +} + +FaissIVFIndex::~FaissIVFIndex() = default; + +void FaissIVFIndex::SetPrimaryColumnFamily(ColumnFamilyHandle* column_family) { + assert(column_family); + primary_column_family_ = column_family; +} + +void FaissIVFIndex::SetSecondaryColumnFamily( + ColumnFamilyHandle* 
column_family) { + assert(column_family); + secondary_column_family_ = column_family; +} + +ColumnFamilyHandle* FaissIVFIndex::GetPrimaryColumnFamily() const { + return primary_column_family_; +} + +ColumnFamilyHandle* FaissIVFIndex::GetSecondaryColumnFamily() const { + return secondary_column_family_; +} + +Slice FaissIVFIndex::GetPrimaryColumnName() const { + return primary_column_name_; +} + +Status FaissIVFIndex::UpdatePrimaryColumnValue( + const Slice& /* primary_key */, const Slice& primary_column_value, + std::optional>* updated_column_value) + const { + assert(updated_column_value); + + if (primary_column_value.size() != index_->d * sizeof(float)) { + return Status::InvalidArgument( + "Incorrectly sized vector passed to FaissIVFIndex"); + } + + constexpr faiss::idx_t n = 1; + faiss::idx_t label = -1; + + try { + index_->quantizer->assign( + n, reinterpret_cast(primary_column_value.data()), &label); + } catch (const std::exception& e) { + return Status::InvalidArgument(e.what()); + } + + if (label < 0 || label >= index_->nlist) { + return Status::InvalidArgument( + "Unexpected label returned by coarse quantizer"); + } + + updated_column_value->emplace(SerializeLabel(label)); + + return Status::OK(); +} + +Status FaissIVFIndex::GetSecondaryKeyPrefix( + const Slice& /* primary_key */, const Slice& primary_column_value, + std::variant* secondary_key_prefix) const { + assert(secondary_key_prefix); + + [[maybe_unused]] const faiss::idx_t label = + DeserializeLabel(primary_column_value); + assert(label >= 0); + assert(label < index_->nlist); + + *secondary_key_prefix = primary_column_value; + + return Status::OK(); +} + +Status FaissIVFIndex::GetSecondaryValue( + const Slice& /* primary_key */, const Slice& primary_column_value, + const Slice& original_column_value, + std::optional>* secondary_value) const { + assert(secondary_value); + + const faiss::idx_t label = DeserializeLabel(primary_column_value); + assert(label >= 0); + assert(label < index_->nlist); + + 
constexpr faiss::idx_t n = 1; + constexpr faiss::idx_t* xids = nullptr; + std::string code_str; + + try { + index_->add_core( + n, reinterpret_cast(original_column_value.data()), xids, + &label, &code_str); + } catch (const std::exception& e) { + return Status::InvalidArgument(e.what()); + } + + if (code_str.size() != index_->code_size) { + return Status::InvalidArgument( + "Unexpected code returned by fine quantizer"); + } + + secondary_value->emplace(std::move(code_str)); + + return Status::OK(); +} + +} // namespace ROCKSDB_NAMESPACE diff --git a/utilities/secondary_index/faiss_ivf_index.h b/utilities/secondary_index/faiss_ivf_index.h new file mode 100644 index 00000000000..956dba7762e --- /dev/null +++ b/utilities/secondary_index/faiss_ivf_index.h @@ -0,0 +1,60 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#pragma once + +#include +#include + +#include "faiss/IndexIVF.h" +#include "rocksdb/utilities/secondary_index.h" + +namespace ROCKSDB_NAMESPACE { + +// A SecondaryIndex implementation that wraps a FAISS inverted file index. 
+class FaissIVFIndex : public SecondaryIndex { + public: + explicit FaissIVFIndex(std::unique_ptr&& index, + std::string primary_column_name); + ~FaissIVFIndex() override; + + void SetPrimaryColumnFamily(ColumnFamilyHandle* column_family) override; + void SetSecondaryColumnFamily(ColumnFamilyHandle* column_family) override; + + ColumnFamilyHandle* GetPrimaryColumnFamily() const override; + ColumnFamilyHandle* GetSecondaryColumnFamily() const override; + + Slice GetPrimaryColumnName() const override; + + Status UpdatePrimaryColumnValue( + const Slice& primary_key, const Slice& primary_column_value, + std::optional>* updated_column_value) + const override; + + Status GetSecondaryKeyPrefix( + const Slice& primary_key, const Slice& primary_column_value, + std::variant* secondary_key_prefix) const override; + + Status GetSecondaryValue(const Slice& primary_key, + const Slice& primary_column_value, + const Slice& original_column_value, + std::optional>* + secondary_value) const override; + + private: + class Adapter; + + static std::string SerializeLabel(faiss::idx_t label); + static faiss::idx_t DeserializeLabel(Slice label_slice); + + std::unique_ptr adapter_; + std::unique_ptr index_; + std::string primary_column_name_; + ColumnFamilyHandle* primary_column_family_{}; + ColumnFamilyHandle* secondary_column_family_{}; +}; + +} // namespace ROCKSDB_NAMESPACE diff --git a/utilities/secondary_index/faiss_ivf_index_test.cc b/utilities/secondary_index/faiss_ivf_index_test.cc new file mode 100644 index 00000000000..5d2008a47a7 --- /dev/null +++ b/utilities/secondary_index/faiss_ivf_index_test.cc @@ -0,0 +1,124 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). 
+ +#include "utilities/secondary_index/faiss_ivf_index.h" + +#include +#include +#include +#include + +#include "faiss/IndexFlat.h" +#include "faiss/IndexIVFFlat.h" +#include "faiss/utils/random.h" +#include "rocksdb/utilities/transaction_db.h" +#include "test_util/testharness.h" +#include "util/coding.h" + +namespace ROCKSDB_NAMESPACE { + +TEST(FaissIVFIndexTest, Basic) { + constexpr size_t dim = 128; + auto quantizer = std::make_unique(dim); + + constexpr size_t num_lists = 16; + auto index = + std::make_unique(quantizer.get(), dim, num_lists); + + constexpr faiss::idx_t num_vectors = 1024; + std::vector embeddings(dim * num_vectors); + faiss::float_rand(embeddings.data(), dim * num_vectors, 42); + + index->train(num_vectors, embeddings.data()); + + index->nprobe = 2; + + const std::string db_name = test::PerThreadDBPath("faiss_ivf_index_test"); + EXPECT_OK(DestroyDB(db_name, Options())); + + Options options; + options.create_if_missing = true; + + TransactionDBOptions txn_db_options; + const std::string primary_column_name = "embedding"; + txn_db_options.secondary_indices.emplace_back( + std::make_shared(std::move(index), primary_column_name)); + + TransactionDB* db = nullptr; + ASSERT_OK(TransactionDB::Open(options, txn_db_options, db_name, &db)); + + std::unique_ptr db_guard(db); + + ColumnFamilyOptions cf1_opts; + ColumnFamilyHandle* cfh1 = nullptr; + ASSERT_OK(db->CreateColumnFamily(cf1_opts, "cf1", &cfh1)); + std::unique_ptr cfh1_guard(cfh1); + + ColumnFamilyOptions cf2_opts; + ColumnFamilyHandle* cfh2 = nullptr; + ASSERT_OK(db->CreateColumnFamily(cf2_opts, "cf2", &cfh2)); + std::unique_ptr cfh2_guard(cfh2); + + const auto& secondary_index = txn_db_options.secondary_indices.back(); + secondary_index->SetPrimaryColumnFamily(cfh1); + secondary_index->SetSecondaryColumnFamily(cfh2); + + { + std::unique_ptr txn(db->BeginTransaction(WriteOptions())); + + for (faiss::idx_t i = 0; i < num_vectors; ++i) { + const std::string primary_key = std::to_string(i); + + 
ASSERT_OK(txn->PutEntity( + cfh1, primary_key, + WideColumns{ + {primary_column_name, + Slice(reinterpret_cast(embeddings.data() + i * dim), + dim * sizeof(float))}})); + } + + ASSERT_OK(txn->Commit()); + } + + { + size_t num_found = 0; + + std::unique_ptr it(db->NewIterator(ReadOptions(), cfh2)); + + for (it->SeekToFirst(); it->Valid(); it->Next()) { + Slice key = it->key(); + faiss::idx_t label = -1; + ASSERT_TRUE(GetVarsignedint64(&key, &label)); + ASSERT_GE(label, 0); + ASSERT_LT(label, num_lists); + + faiss::idx_t id = -1; + ASSERT_EQ(std::from_chars(key.data(), key.data() + key.size(), id).ec, + std::errc()); + ASSERT_GE(id, 0); + ASSERT_LT(id, num_vectors); + + // Since we use IndexIVFFlat, there is no fine quantization, so the code + // is actually just the original embedding + ASSERT_EQ( + it->value(), + Slice(reinterpret_cast(embeddings.data() + id * dim), + dim * sizeof(float))); + + ++num_found; + } + + ASSERT_OK(it->status()); + ASSERT_EQ(num_found, num_vectors); + } +} + +} // namespace ROCKSDB_NAMESPACE + +int main(int argc, char** argv) { + ROCKSDB_NAMESPACE::port::InstallStackTraceHandler(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}