-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add an ArrayRecordDataSource for multiple shards. Adapted from the TfGrain DataSource.
PiperOrigin-RevId: 534384957
- Loading branch information
1 parent
48abb8c
commit cac26b2
Showing
5 changed files
with
350 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
#include "python/array_record_data_source_cpp.h" | ||
#include "absl/flags/flag.h" | ||
#include "absl/log/check.h" | ||
#include "absl/status/status.h" | ||
#include "absl/strings/string_view.h" | ||
#include "python/read_instructions_lib.h" | ||
#include "cpp/array_record_reader.h" | ||
#include "cpp/array_record_writer.h" | ||
|
||
|
||
namespace array_record { | ||
|
||
using ArrayRecordReaderOptions = ::array_record::ArrayRecordReaderBase::Options; | ||
using ArrayRecordWriterOptions = ::array_record::ArrayRecordWriterBase::Options; | ||
|
||
static uint64_t ArrayRecordGetNumRecords(const std::string& filename) { | ||
const array_record::ArrayRecordReader<riegeli::FileReader<>> reader( | ||
std::forward_as_tuple(filename)); | ||
return reader.NumRecords(); | ||
} | ||
|
||
// Builds the global index over all shards.
//
// Each entry of `paths_` is resolved into one or more ReadInstructions by
// GetReadInstructions(); the total record count is the sum over all resolved
// instructions. One (initially null) reader slot is reserved per instruction;
// readers are opened lazily on first access.
//
// Dies (CHECK) if the read instructions cannot be computed or if the record
// count of any shard could not be determined.
ArrayRecordDataSource::ArrayRecordDataSource(absl::Span<std::string> paths_)
    // Inside the initializer expression `paths_` names the constructor
    // parameter (it hides the member), so this copies the caller's paths into
    // the member of the same name instead of leaving that member empty.
    : paths_(paths_.begin(), paths_.end()), total_num_records_(0) {
  absl::StatusOr<std::vector<ReadInstruction>> read_instructions_or_failure =
      GetReadInstructions(paths_, ArrayRecordGetNumRecords);
  CHECK_OK(read_instructions_or_failure);
  // Move the payload out of the StatusOr rather than deep-copying the vector.
  read_instructions_ = *std::move(read_instructions_or_failure);

  for (const auto& ri : read_instructions_) {
    // An instruction whose end was never resolved (still -1) reports a
    // negative size; adding that to the unsigned total would silently wrap.
    CHECK_GE(ri.NumRecords(), 0)
        << "Could not determine the number of records for " << ri.filename;
    total_num_records_ += ri.NumRecords();
  }
  readers_.resize(read_instructions_.size());
}
|
||
// Returns the total number of records across all shards, as computed by the
// constructor.
uint64_t ArrayRecordDataSource::NumRecords() const {
  return total_num_records_;
}
|
||
std::pair<int, uint64_t> ArrayRecordDataSource::GetReaderIndexAndPosition( | ||
uint64_t key) const { | ||
int reader_index = 0; | ||
CHECK(key < NumRecords()) << "Invalid key " << key; | ||
while (key >= read_instructions_[reader_index].NumRecords()) { | ||
key -= read_instructions_[reader_index].NumRecords(); | ||
reader_index++; | ||
} | ||
key += read_instructions_[reader_index].start; | ||
return {reader_index, key}; | ||
} | ||
|
||
// Verifies that an ArrayRecord file was written with group_size=1.
//
// `options_string` is the writer-options string stored in the file, if any.
// Files written before 2022-10 carry no such string and are accepted as-is.
// Returns the parse error if the options string cannot be parsed,
// InvalidArgument if the group size is not 1, and OK otherwise.
absl::Status ArrayRecordDataSource::CheckGroupSize(
    const absl::string_view filename,
    const std::optional<std::string> options_string) {
  if (!options_string.has_value()) {
    // Nothing recorded in the file; nothing to verify.
    return absl::OkStatus();
  }

  const auto parsed_options =
      ArrayRecordWriterOptions::FromString(*options_string);
  if (!parsed_options.ok()) {
    return parsed_options.status();
  }

  const int group_size = parsed_options->group_size();
  if (group_size == 1) {
    return absl::OkStatus();
  }
  return absl::InvalidArgumentError(absl::StrCat(
      "File ", filename, " was created with group size ", group_size,
      ". Grain requires group size 1 for good performance. Please "
      "re-generate your ArrayRecord files with 'group_size:1'."));
}
|
||
void ArrayRecordDataSource::CreateReader(const int reader_index) { | ||
// See b/262550570 for the readahead buffer size. | ||
ArrayRecordReaderOptions array_record_reader_options; | ||
array_record_reader_options.set_max_parallelism(0); | ||
array_record_reader_options.set_readahead_buffer_size(0); | ||
riegeli::FileReaderBase::Options file_reader_options; | ||
file_reader_options.set_buffer_size(1 << 15); | ||
// Copy is on purpose. | ||
std::string filename = read_instructions_[reader_index].filename; | ||
auto reader = std::make_unique< | ||
array_record::ArrayRecordReader<riegeli::FileReader<>>>( | ||
std::forward_as_tuple(filename, file_reader_options), | ||
array_record_reader_options, array_record::ArrayRecordGlobalPool()); | ||
const auto status = CheckGroupSize(filename, reader->WriterOptionsString()); | ||
if (!status.ok()) { | ||
LOG(ERROR) << status; | ||
} | ||
{ | ||
const std::lock_guard<std::mutex> lock(create_reader_mutex_); | ||
if (readers_[reader_index] == nullptr) { | ||
readers_[reader_index] = std::move(reader); | ||
} | ||
} | ||
} | ||
|
||
// Looks up the record with global index `key` and points `*record` at it.
//
// The shard's reader is created on first use. Returns the status reported by
// the reader's ParallelReadRecordsWithIndices() call.
//
// NOTE(review): `*record` is a view onto the value handed to the read
// callback; whether that memory stays valid after this function returns
// depends on the reader -- confirm before relying on it.
// NOTE(review): the null check on `readers_` below is done without holding
// create_reader_mutex_ -- confirm this is safe for concurrent callers.
absl::Status ArrayRecordDataSource::GetItem(
    uint64_t key, absl::string_view* record) {
  const std::pair<int, uint64_t> location = GetReaderIndexAndPosition(key);
  const int reader_index = location.first;
  uint64_t position = location.second;

  if (readers_[reader_index] == nullptr) {
    CreateReader(reader_index);
  }

  return readers_[reader_index]->ParallelReadRecordsWithIndices(
      {position},
      [record](uint64_t /*read_idx*/,
               absl::string_view value) -> absl::Status {
        // TODO(amrahmed): Follow on this
        *record = value;
        return absl::OkStatus();
      });
}
} // namespace array_record |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
#ifndef THIRD_PARTY_ARRAY_RECORD_PYTHON_ARRAY_RECORD_DATA_SOURCE_CPP_H_ | ||
#define THIRD_PARTY_ARRAY_RECORD_PYTHON_ARRAY_RECORD_DATA_SOURCE_CPP_H_ | ||
|
||
|
||
#include "python/read_instructions_lib.h" | ||
#include "cpp/array_record_reader.h" | ||
#include "riegeli/bytes/file_reader.h" | ||
|
||
|
||
namespace array_record { | ||
|
||
// A Datasource for multiple ArrayRecordFiles. It holds the file reader objects | ||
// and implements the lookup logic. The constructor constructs the global index | ||
// by reading the number of records per file. NumRecords() returns the total | ||
// number of records. GetItem() looks up a single key and returns the record. | ||
// If needed it will open file readers. | ||
class ArrayRecordDataSource { | ||
public: | ||
explicit ArrayRecordDataSource(absl::Span<std::string> paths_); | ||
|
||
uint64_t NumRecords() const; | ||
|
||
absl::Status GetItem(uint64_t key, absl::string_view* record); | ||
|
||
private: | ||
const std::vector<std::string> paths_; | ||
std::vector<ReadInstruction> read_instructions_; | ||
uint64_t total_num_records_; | ||
|
||
void CreateReader(int reader_index); | ||
|
||
using Reader = | ||
std::unique_ptr<array_record::ArrayRecordReader<riegeli::FileReader<>>>; | ||
std::vector<Reader> readers_; | ||
std::mutex create_reader_mutex_; | ||
|
||
std::pair<int, uint64_t> GetReaderIndexAndPosition(uint64_t key) const; | ||
|
||
absl::Status CheckGroupSize( | ||
absl::string_view filename, | ||
std::optional<std::string> options_string); | ||
}; | ||
} // namespace array_record | ||
|
||
|
||
#endif // THIRD_PARTY_ARRAY_RECORD_PYTHON_ARRAY_RECORD_DATA_SOURCE_CPP_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
#include "python/read_instructions_lib.h" | ||
|
||
#include "absl/strings/match.h" | ||
#include "absl/strings/str_format.h" | ||
#include "cpp/parallel_for.h" | ||
#include "third_party/re2/re2.h" | ||
#include "thread/threadpool.h" | ||
#include "absl/status/statusor.h" | ||
#include "iostream" | ||
|
||
namespace array_record { | ||
|
||
// Size of the temporary thread pool used while resolving read instructions.
// Resolving an instruction is cheap but IO bound (files are matched and opened
// to count their records), hence the large number of threads.
constexpr int kNumThreadsForReadInstructions{256};
|
||
// Parses a path of the form "filename[start:end]" into a ReadInstruction.
//
// Returns InvalidArgument when `path` does not match that form, which is also
// the case for a plain filename without a "[start:end]" suffix.
// NOTE(review): start <= end is not validated here.
absl::StatusOr<ReadInstruction> ReadInstruction::Parse(absl::string_view path) {
  static const LazyRE2 kPattern = {R"((.+)\[(\d+):(\d+)\])"};
  std::string parsed_filename;
  int64_t parsed_start, parsed_end;
  const bool matched = RE2::FullMatch(path, *kPattern, &parsed_filename,
                                      &parsed_start, &parsed_end);
  if (!matched) {
    return absl::InvalidArgumentError(
        absl::StrFormat("Can't parse %s as ReadInstruction", path));
  }
  return ReadInstruction{parsed_filename, parsed_start, parsed_end};
}
|
||
// Get the read instructions for a list of paths where each path can be:
// - A normal filename.
// - A filename with read instructions: filename[start:end].
// - A file pattern containing '?', which is expanded to all matching files
//   (sorted, so the order is deterministic).
// Unless the filename is given with read instructions, the file will be opened
// through `get_num_records` to get the total number of records.
//
// Resolution is best effort: a pattern that matches nothing is logged and
// dropped from the result, and a file that does not exist is logged and keeps
// end == -1 in the returned instruction.
//
// `get_num_records` is invoked through ParallelFor on a temporary thread pool,
// so it may be called from several threads concurrently.
absl::StatusOr<std::vector<ReadInstruction>> GetReadInstructions(
    absl::Span<std::string> paths,
    const GetNumRecords& get_num_records) {
  std::vector<ReadInstruction> read_instructions;
  read_instructions.reserve(paths.size());

  // Step 1: Parse potential read instructions.
  bool missing_num_records = false;
  for (const std::string& path : paths) {
    absl::StatusOr<ReadInstruction> read_instruction =
        ReadInstruction::Parse(path);
    if (read_instruction.ok()) {
      read_instructions.push_back(*std::move(read_instruction));
    } else {
      // Not of the form filename[start:end]: treat it as a filename or
      // pattern whose record count still has to be determined.
      missing_num_records = true;
      read_instructions.push_back({path});
    }
  }
  if (!missing_num_records) {
    return read_instructions;
  }

  // The pool lives on the stack so that it and its worker threads are released
  // when this function returns, instead of being leaked on every call.
  ThreadPool pool("ReadInstructionsPool", kNumThreadsForReadInstructions);
  pool.StartWorkers();

  std::vector<std::vector<ReadInstruction>> filled_instructions;
  filled_instructions.resize(read_instructions.size());

  // Step 2: Match any patterns. Each task only touches slot `i` of both
  // vectors, so no synchronization between tasks is needed.
  auto match_pattern = [&](int i) {
    const std::string& pattern = read_instructions[i].filename;
    if (read_instructions[i].end >= 0 || !absl::StrContains(pattern, '?')) {
      filled_instructions[i].push_back(std::move(read_instructions[i]));
      return;
    }
    const auto status_or_filenames = file::Match(pattern, file::Defaults());
    if (!status_or_filenames.ok() || status_or_filenames->empty()) {
      LOG(ERROR) << "Failed to find matching files for pattern " << pattern;
      return;
    }
    auto filenames = *status_or_filenames;
    // Make sure we always read files in the same order.
    absl::c_sort(filenames);
    filled_instructions[i].reserve(filenames.size());
    for (const std::string& filename : filenames) {
      filled_instructions[i].push_back({filename, 0, -1});
    }
  };

  array_record::ParallelFor(Seq(read_instructions.size()), &pool,
                            match_pattern);

  // Flatten filled_instructions into read_instructions.
  read_instructions.clear();
  for (auto& instructions : filled_instructions) {
    for (ReadInstruction& instruction : instructions) {
      read_instructions.push_back(std::move(instruction));
    }
  }

  // Step 3: Get number of records for every instruction that lacks an end.
  auto add_num_records = [&](int i) {
    if (read_instructions[i].end >= 0) {
      return;
    }
    const std::string& filename = read_instructions[i].filename;

    if (!file::Exists(filename, file::Defaults()).ok()) {
      LOG(ERROR) << "File " << filename << " not found.";
      return;
    }

    read_instructions[i].end =
        static_cast<int64_t>(get_num_records(filename));
  };
  array_record::ParallelFor(Seq(read_instructions.size()), &pool,
                            add_num_records);
  return read_instructions;
}
|
||
} // namespace array_record |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#ifndef THIRD_PARTY_ARRAY_RECORD_PYTHON_READ_INTRUCTIONS_LIB_H_ | ||
#define THIRD_PARTY_ARRAY_RECORD_PYTHON_READ_INTRUCTIONS_LIB_H_ | ||
|
||
#include <cstdint> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "absl/status/statusor.h" | ||
|
||
namespace array_record { | ||
|
||
struct ReadInstruction { | ||
std::string filename; | ||
int64_t start = 0; // Always >= 0. | ||
// Must be >= start or -1. -1 indicates that the end of the file. | ||
int64_t end = -1; | ||
|
||
static absl::StatusOr<ReadInstruction> Parse(absl::string_view path); | ||
|
||
int64_t NumRecords() const { return end - start; } | ||
}; | ||
|
||
// Callback that returns the number of records stored in the file with the | ||
// given filename. GetReadInstructions() calls it for files whose record | ||
// count is not given explicitly. | ||
using GetNumRecords = std::function<uint64_t(const std::string&)>; | ||
|
||
// Get the read instructions for a list of paths where each path can be: | ||
// - A normal filename. | ||
// - A filename with read instructions: filename[start:end]. | ||
// - A file pattern containing '?', which is expanded to the matching files. | ||
// Unless the filename is given with read instructions, the file will be | ||
// opened (through `get_num_records`) to get the total number of records. | ||
// Patterns without a match and files that do not exist are logged rather | ||
// than reported through the returned status. | ||
absl::StatusOr<std::vector<ReadInstruction>> GetReadInstructions( | ||
absl::Span<std::string> paths, | ||
const GetNumRecords& get_num_records); | ||
|
||
} // namespace array_record | ||
|
||
|
||
#endif // THIRD_PARTY_ARRAY_RECORD_PYTHON_READ_INTRUCTIONS_LIB_H_ |