diff --git a/python/BUILD b/python/BUILD
index 12c1612..5575008 100644
--- a/python/BUILD
+++ b/python/BUILD
@@ -56,3 +56,52 @@ py_test(
         "@com_google_absl_py//absl/testing:parameterized",
     ],
 )
+
+cc_library(
+    name = "array_record_data_source_cpp",
+    srcs = ["array_record_data_source_cpp.cc"],
+    hdrs = ["array_record_data_source_cpp.h"],
+    deps = [
+        ":read_instructions_lib",
+        "//cpp:array_record_reader",
+        "//cpp:array_record_writer",
+        "//third_party/py/grain/google/placer",
+        "//util/task:status",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/meta:type_traits",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@com_google_riegeli//riegeli/bytes:file_reader",
+    ],
+)
+
+cc_library(
+    name = "read_instructions_lib",
+    srcs = ["read_instructions_lib.cc"],
+    hdrs = ["read_instructions_lib.h"],
+    deps = [
+        "//cpp:array_record_reader",
+        "//cpp:common",
+        "//cpp:parallel_for",
+        "//cpp:thread_pool",
+        "//file/base",
+        "//third_party/py/grain/google/placer",
+        "//third_party/re2",
+        "//thread",
+        "//util/task:status",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/types:span",
+        "@com_google_riegeli//riegeli/bytes:file_reader",
+    ],
+)
diff --git a/python/array_record_data_source_cpp.cc b/python/array_record_data_source_cpp.cc
new file mode 100644
index 0000000..682a8f2
--- /dev/null
+++ b/python/array_record_data_source_cpp.cc
@@ -0,0 +1,151 @@
+#include "python/array_record_data_source_cpp.h"
+
+#include <cstdint>
+#include <memory>
+#include <mutex>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/flags/flag.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "python/read_instructions_lib.h"
+#include "cpp/array_record_reader.h"
+#include "cpp/array_record_writer.h"
+#include "riegeli/bytes/file_reader.h"
+
+namespace array_record {
+
+using ArrayRecordReaderOptions = ::array_record::ArrayRecordReaderBase::Options;
+using ArrayRecordWriterOptions = ::array_record::ArrayRecordWriterBase::Options;
+
+// Opens `filename` with default reader options and returns the record count
+// reported by the reader. Used as the GetNumRecords callback below.
+static uint64_t ArrayRecordGetNumRecords(const std::string& filename) {
+  const array_record::ArrayRecordReader<riegeli::FileReader<>> reader(
+      std::forward_as_tuple(filename));
+  return reader.NumRecords();
+}
+
+// Resolves `paths` into read instructions and sums up the record counts.
+// CHECK-fails if the read instructions cannot be determined. Readers are
+// created lazily on first access, so only empty slots are reserved here.
+ArrayRecordDataSource::ArrayRecordDataSource(
+    absl::Span<const std::string> paths)
+    : paths_(paths.begin(), paths.end()) {
+  absl::StatusOr<std::vector<ReadInstruction>> read_instructions_or_failure =
+      GetReadInstructions(paths, ArrayRecordGetNumRecords);
+  CHECK_OK(read_instructions_or_failure);
+  read_instructions_ = *std::move(read_instructions_or_failure);
+
+  total_num_records_ = 0;
+  for (const auto& ri : read_instructions_) {
+    total_num_records_ += ri.NumRecords();
+  }
+  readers_.resize(read_instructions_.size());
+}
+
+uint64_t ArrayRecordDataSource::NumRecords() const {
+  return total_num_records_;
+}
+
+// Maps a global `key` to (index into read_instructions_, record position
+// inside that file). CHECK-fails if `key` is out of range.
+std::pair<int, uint64_t> ArrayRecordDataSource::GetReaderIndexAndPosition(
+    uint64_t key) const {
+  int reader_index = 0;
+  CHECK(key < NumRecords()) << "Invalid key " << key;
+  // Walk the files in order, subtracting each file's record count until the
+  // key falls inside the current file.
+  while (key >= read_instructions_[reader_index].NumRecords()) {
+    key -= read_instructions_[reader_index].NumRecords();
+    reader_index++;
+  }
+  // Shift by the start offset of the read instruction.
+  key += read_instructions_[reader_index].start;
+  return {reader_index, key};
+}
+
+// Verifies that the file was written with group_size=1. `options_string` is
+// the writer options string stored in the file, if any.
+absl::Status ArrayRecordDataSource::CheckGroupSize(
+    const absl::string_view filename,
+    const std::optional<std::string> options_string) {
+  // Check that ArrayRecord files were created with group_size=1. Old files
+  // (prior 2022-10) don't have this info.
+  if (!options_string.has_value()) {
+    return absl::OkStatus();
+  }
+  auto maybe_options = ArrayRecordWriterOptions::FromString(*options_string);
+  if (!maybe_options.ok()) {
+    return maybe_options.status();
+  }
+  const int group_size = maybe_options->group_size();
+  if (group_size != 1) {
+    return absl::InvalidArgumentError(absl::StrCat(
+        "File ", filename, " was created with group size ", group_size,
+        ". Grain requires group size 1 for good performance. Please "
+        "re-generate your ArrayRecord files with 'group_size:1'."));
+  }
+  return absl::OkStatus();
+}
+
+// Opens the reader for read_instructions_[reader_index] and stores it in
+// readers_ unless another thread already did. A group size mismatch is only
+// logged; the reader is still installed.
+void ArrayRecordDataSource::CreateReader(const int reader_index) {
+  // See b/262550570 for the readahead buffer size.
+  ArrayRecordReaderOptions array_record_reader_options;
+  array_record_reader_options.set_max_parallelism(0);
+  array_record_reader_options.set_readahead_buffer_size(0);
+  riegeli::FileReaderBase::Options file_reader_options;
+  file_reader_options.set_buffer_size(1 << 15);
+  // Copy is on purpose.
+  std::string filename = read_instructions_[reader_index].filename;
+  auto reader = std::make_unique<
+      array_record::ArrayRecordReader<riegeli::FileReader<>>>(
+      std::forward_as_tuple(filename, file_reader_options),
+      array_record_reader_options, array_record::ArrayRecordGlobalPool());
+  const auto status = CheckGroupSize(filename, reader->WriterOptionsString());
+  if (!status.ok()) {
+    LOG(ERROR) << status;
+  }
+  {
+    const std::lock_guard<std::mutex> lock(create_reader_mutex_);
+    if (readers_[reader_index] == nullptr) {
+      readers_[reader_index] = std::move(reader);
+    }
+  }
+}
+
+// Looks up the record for the global `key` and points `*record` at it,
+// opening the file reader first if needed.
+absl::Status ArrayRecordDataSource::GetItem(
+    uint64_t key, absl::string_view* record) {
+  int reader_index;
+  uint64_t position;
+  std::tie(reader_index, position) = GetReaderIndexAndPosition(key);
+  // NOTE(review): this null check runs without create_reader_mutex_ while
+  // CreateReader() assigns the slot under it -- confirm callers never race.
+  if (readers_[reader_index] == nullptr) {
+    CreateReader(reader_index);
+  }
+  return readers_[reader_index]->ParallelReadRecordsWithIndices(
+      {position},
+      [&](uint64_t read_idx, absl::string_view value) -> absl::Status {
+        // TODO(amrahmed): Follow on this
+        // NOTE(review): `value` is a view; presumably it does not outlive the
+        // reader's buffer -- confirm before handing `*record` to callers.
+        *record = value;
+        return absl::OkStatus();
+      });
+}
+}  // namespace array_record
diff --git a/python/array_record_data_source_cpp.h b/python/array_record_data_source_cpp.h
new file mode 100644
index 0000000..0ec11fc
--- /dev/null
+++ b/python/array_record_data_source_cpp.h
@@ -0,0 +1,64 @@
+#ifndef THIRD_PARTY_ARRAY_RECORD_PYTHON_ARRAY_RECORD_DATA_SOURCE_CPP_H_
+#define THIRD_PARTY_ARRAY_RECORD_PYTHON_ARRAY_RECORD_DATA_SOURCE_CPP_H_
+
+#include <cstdint>
+#include <memory>
+#include <mutex>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "python/read_instructions_lib.h"
+#include "cpp/array_record_reader.h"
+#include "riegeli/bytes/file_reader.h"
+
+namespace array_record {
+
+// A Datasource for multiple ArrayRecordFiles. It holds the file reader objects
+// and implements the lookup logic. The constructor constructs the global index
+// by reading the number of records per file. NumRecords() returns the total
+// number of records. GetItem() looks up a single key and returns the record.
+// If needed it will open file readers.
+class ArrayRecordDataSource {
+ public:
+  // `paths` may contain plain filenames or `filename[start:end]` read
+  // instructions. CHECK-fails if the read instructions cannot be resolved.
+  explicit ArrayRecordDataSource(absl::Span<const std::string> paths);
+
+  // Total number of records across all read instructions.
+  uint64_t NumRecords() const;
+
+  // Reads the record with global index `key` into `*record`. `key` must be
+  // smaller than NumRecords().
+  absl::Status GetItem(uint64_t key, absl::string_view* record);
+
+ private:
+  const std::vector<std::string> paths_;
+  std::vector<ReadInstruction> read_instructions_;
+  uint64_t total_num_records_ = 0;
+
+  // Opens the reader for read_instructions_[reader_index].
+  void CreateReader(int reader_index);
+
+  using Reader =
+      std::unique_ptr<array_record::ArrayRecordReader<riegeli::FileReader<>>>;
+  // One slot per read instruction; nullptr until the file is first accessed.
+  std::vector<Reader> readers_;
+  // Guards installing a newly created reader into readers_.
+  std::mutex create_reader_mutex_;
+
+  // Maps a global key to (index into readers_, position inside that file).
+  std::pair<int, uint64_t> GetReaderIndexAndPosition(uint64_t key) const;
+
+  // Returns an error if the writer options say group_size != 1.
+  absl::Status CheckGroupSize(
+      absl::string_view filename,
+      std::optional<std::string> options_string);
+};
+}  // namespace array_record
+
+#endif  // THIRD_PARTY_ARRAY_RECORD_PYTHON_ARRAY_RECORD_DATA_SOURCE_CPP_H_
diff --git a/python/read_instructions_lib.cc b/python/read_instructions_lib.cc
new file mode 100644
index 0000000..92dabf5
--- /dev/null
+++ b/python/read_instructions_lib.cc
@@ -0,0 +1,137 @@
+#include "python/read_instructions_lib.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "cpp/parallel_for.h"
+#include "third_party/re2/re2.h"
+#include "thread/threadpool.h"
+#include "absl/status/statusor.h"
+#include "iostream"
+
+namespace array_record {
+
+// Getting the read instructions is cheap but IO bound. We create a temporary
+// thread pool to get the number of records.
+constexpr int kNumThreadsForReadInstructions = 256;
+
+// Parses `filename[start:end]`. Returns InvalidArgument for any other input.
+absl::StatusOr<ReadInstruction> ReadInstruction::Parse(
+    absl::string_view path) {
+  static const LazyRE2 kPattern = {R"((.+)\[(\d+):(\d+)\])"};
+  std::string filename;
+  int64_t start, end;
+  if (RE2::FullMatch(path, *kPattern, &filename, &start, &end)) {
+    return ReadInstruction{filename, start, end};
+  }
+  return absl::InvalidArgumentError(
+      absl::StrFormat("Can't parse %s as ReadInstruction", path));
+}
+
+// Get the read instructions for a list of paths where each path can be:
+// - A normal filename.
+// - A filename with read instructions: filename[start:end].
+// Unless the filename is given with read instruction, the file will be opened
+// to get the total number of records.
+// Returns NotFound if the number of records of a file cannot be determined.
+absl::StatusOr<std::vector<ReadInstruction>> GetReadInstructions(
+    absl::Span<const std::string> paths,
+    const GetNumRecords& get_num_records) {
+  std::vector<ReadInstruction> read_instructions;
+
+  // Step 1: Parse potential read instructions.
+  bool missing_num_records = false;
+  for (const std::string& path : paths) {
+    absl::StatusOr<ReadInstruction> read_instruction =
+        ReadInstruction::Parse(path);
+    if (read_instruction.ok()) {
+      read_instructions.push_back(read_instruction.value());
+    } else {
+      missing_num_records = true;
+      const std::string pattern = path;
+      read_instructions.push_back({pattern});
+    }
+  }
+  if (!missing_num_records) {
+    return read_instructions;
+  }
+
+  // Owned by a unique_ptr so the pool is released on every return path.
+  auto pool = std::make_unique<ThreadPool>("ReadInstructionsPool",
+                                           kNumThreadsForReadInstructions);
+  pool->StartWorkers();
+
+  std::vector<std::vector<ReadInstruction>> filled_instructions;
+  filled_instructions.resize(read_instructions.size());
+
+  // Step 2: Match any patterns.
+  auto match_pattern = [&](int i) {
+    const std::string& pattern = read_instructions[i].filename;
+    if (read_instructions[i].end >= 0 || !absl::StrContains(pattern, '?')) {
+      filled_instructions[i].push_back(std::move(read_instructions[i]));
+      return;
+    }
+    const auto status_or_filenames = file::Match(pattern, file::Defaults());
+    if (!status_or_filenames.ok() || status_or_filenames->empty()) {
+      LOG(ERROR) << "Failed to find matching files for pattern " << pattern;
+      return;
+    }
+    auto filenames = *status_or_filenames;
+    // Make sure we always read files in the same order.
+    absl::c_sort(filenames);
+    filled_instructions[i].reserve(filenames.size());
+    for (const std::string& filename : filenames) {
+      filled_instructions[i].push_back({filename, 0, -1});
+    }
+  };
+
+  array_record::ParallelFor(Seq(read_instructions.size()), pool.get(),
+                            match_pattern);
+
+  // Flatten filled_instructions into read_instructions;
+  read_instructions.clear();
+  for (const auto& instructions : filled_instructions) {
+    read_instructions.insert(read_instructions.end(), instructions.begin(),
+                             instructions.end());
+  }
+
+  // Step 3: Get number of records.
+  auto add_num_records = [&](int i) {
+    if (read_instructions[i].end >= 0) {
+      return;
+    }
+    const std::string& filename = read_instructions[i].filename;
+
+    if (!file::Exists(filename, file::Defaults()).ok()) {
+      LOG(ERROR) << "File " << filename << " not found.";
+      return;
+    }
+
+    read_instructions[i].end =
+        static_cast<int64_t>(get_num_records(filename));
+  };
+  array_record::ParallelFor(Seq(read_instructions.size()), pool.get(),
+                            add_num_records);
+
+  // An instruction that still has end == -1 would report a negative
+  // NumRecords() and corrupt any total computed from it, so fail instead.
+  for (const ReadInstruction& instruction : read_instructions) {
+    if (instruction.end < 0) {
+      return absl::NotFoundError(absl::StrFormat(
+          "Could not get the number of records for %s", instruction.filename));
+    }
+  }
+  return read_instructions;
+}
+
+}  // namespace array_record
diff --git a/python/read_instructions_lib.h b/python/read_instructions_lib.h
new file mode 100644
index 0000000..fd5d1a3
--- /dev/null
+++ b/python/read_instructions_lib.h
@@ -0,0 +1,43 @@
+#ifndef THIRD_PARTY_ARRAY_RECORD_PYTHON_READ_INSTRUCTIONS_LIB_H_
+#define THIRD_PARTY_ARRAY_RECORD_PYTHON_READ_INSTRUCTIONS_LIB_H_
+
+#include <cstdint>
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+
+namespace array_record {
+
+// A half-open record range [start, end) inside a single file.
+struct ReadInstruction {
+  std::string filename;
+  int64_t start = 0;  // Always >= 0.
+  // Must be >= start or -1. -1 indicates the end of the file.
+  int64_t end = -1;
+
+  // Parses `filename[start:end]`; fails for any other format.
+  static absl::StatusOr<ReadInstruction> Parse(absl::string_view path);
+
+  // Only meaningful once `end` has been resolved (end >= start).
+  int64_t NumRecords() const { return end - start; }
+};
+
+// Callback that returns the number of records in the file `filename`.
+using GetNumRecords = std::function<uint64_t(const std::string&)>;
+
+// Get the read instructions for a list of paths where each path can be:
+// - A normal filename.
+// - A filename with read instructions: filename[start:end].
+// Unless the filename is given with read instruction the file will be opened
+// to get the total number of records.
+absl::StatusOr<std::vector<ReadInstruction>> GetReadInstructions(
+    absl::Span<const std::string> paths,
+    const GetNumRecords& get_num_records);
+
+}  // namespace array_record
+
+#endif  // THIRD_PARTY_ARRAY_RECORD_PYTHON_READ_INSTRUCTIONS_LIB_H_