Skip to content

Commit

Permalink
Add an ArrayRecordDataSource for multiple shards. Adapted from the Tf…
Browse files Browse the repository at this point in the history
…Grain DataSource

PiperOrigin-RevId: 534384957
  • Loading branch information
ArrayRecord Team authored and copybara-github committed May 23, 2023
1 parent 48abb8c commit 079903e
Show file tree
Hide file tree
Showing 5 changed files with 356 additions and 0 deletions.
46 changes: 46 additions & 0 deletions python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,49 @@ py_test(
"@com_google_absl_py//absl/testing:parameterized",
],
)

# C++ library for the multi-shard ArrayRecord data source
# (array_record_data_source_cpp.cc / array_record_data_source_cpp.h).
# Depends on :read_instructions_lib to resolve input paths into per-file
# read instructions.
cc_library(
name = "array_record_data_source_cpp",
srcs = ["array_record_data_source_cpp.cc"],
hdrs = ["array_record_data_source_cpp.h"],
deps = [
":read_instructions_lib",
"//cpp:array_record_reader",
"//cpp:array_record_writer",
"//third_party/py/grain/google/placer",
"//util/task:status",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/meta:type_traits",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@com_google_riegeli//riegeli/bytes:file_reader",
],
)

# C++ library that turns a list of paths (plain filenames, filename[start:end]
# ranges, or patterns) into ReadInstruction entries
# (read_instructions_lib.cc / read_instructions_lib.h).
cc_library(
name = "read_instructions_lib",
srcs = ["read_instructions_lib.cc"],
hdrs = ["read_instructions_lib.h"],
deps = [
"//cpp:array_record_reader",
"//cpp:common",
"//cpp:parallel_for",
"//cpp:thread_pool",
"//file/base",
"//third_party/py/grain/google/placer",
"//third_party/re2",
"//thread",
"//util/task:status",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_riegeli//riegeli/bytes:file_reader",
],
)
112 changes: 112 additions & 0 deletions python/array_record_data_source_cpp.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#include "python/array_record_data_source_cpp.h"
#include "absl/flags/flag.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "python/read_instructions_lib.h"
#include "cpp/array_record_reader.h"
#include "cpp/array_record_writer.h"


namespace array_record {

using ArrayRecordReaderOptions = ::array_record::ArrayRecordReaderBase::Options;
using ArrayRecordWriterOptions = ::array_record::ArrayRecordWriterBase::Options;

static uint64_t ArrayRecordGetNumRecords(const std::string& filename) {
const array_record::ArrayRecordReader<riegeli::FileReader<>> reader(
std::forward_as_tuple(filename));
return reader.NumRecords();
}

// Builds the global index over all input files.
//
// `paths` is resolved into ReadInstructions via GetReadInstructions(), using
// ArrayRecordGetNumRecords to obtain the record count of files given without
// an explicit range. CHECK-fails if the resolution returns a non-OK status.
//
// The parameter is named `paths` (not `paths_`) so that it no longer shadows
// the `paths_` member; the member is now populated with a copy of the input
// instead of being left empty.
ArrayRecordDataSource::ArrayRecordDataSource(absl::Span<std::string> paths)
    : paths_(paths.begin(), paths.end()) {
  absl::StatusOr<std::vector<ReadInstruction>> read_instructions_or_failure =
      GetReadInstructions(paths, ArrayRecordGetNumRecords);
  CHECK_OK(read_instructions_or_failure);
  // Move the resolved instructions out of the StatusOr instead of copying.
  read_instructions_ = *std::move(read_instructions_or_failure);

  // The total is the sum of the per-instruction ranges (end - start).
  total_num_records_ = 0;
  for (const ReadInstruction& ri : read_instructions_) {
    total_num_records_ += ri.NumRecords();
  }
  // One initially-null slot per read instruction; GetItem() fills slots on
  // first use through CreateReader().
  readers_.resize(read_instructions_.size());
}

// Returns the total record count that the constructor accumulated over all
// read instructions.
uint64_t ArrayRecordDataSource::NumRecords() const {
  return total_num_records_;
}

// Translates a global record index into a (reader index, position) pair.
//
// Walks read_instructions_ in order, subtracting each instruction's record
// count from `key` until the remainder falls inside the current instruction.
// The returned position is that remainder offset by the instruction's `start`.
// CHECK-fails when `key` is not smaller than NumRecords().
std::pair<int, uint64_t> ArrayRecordDataSource::GetReaderIndexAndPosition(
    uint64_t key) const {
  CHECK(key < NumRecords()) << "Invalid key " << key;
  int shard = 0;
  uint64_t remaining = key;
  for (;; ++shard) {
    const auto shard_size = read_instructions_[shard].NumRecords();
    if (remaining < shard_size) {
      break;
    }
    remaining -= shard_size;
  }
  remaining += read_instructions_[shard].start;
  return {shard, remaining};
}

// Verifies that an ArrayRecord file was written with group_size=1.
//
// `options_string` is the writer-options string stored in the file, if any.
// Returns:
//   - OK when no options string is present (old files, prior to 2022-10, do
//     not carry this information) or when the parsed group size is 1;
//   - the parse error when the options string cannot be parsed;
//   - InvalidArgument, mentioning `filename`, for any other group size.
absl::Status ArrayRecordDataSource::CheckGroupSize(
    const absl::string_view filename,
    const std::optional<std::string> options_string) {
  if (!options_string) {
    return absl::OkStatus();
  }

  auto parsed_options = ArrayRecordWriterOptions::FromString(*options_string);
  if (!parsed_options.ok()) {
    return parsed_options.status();
  }

  const int group_size = parsed_options->group_size();
  if (group_size == 1) {
    return absl::OkStatus();
  }
  return absl::InvalidArgumentError(absl::StrCat(
      "File ", filename, " was created with group size ", group_size,
      ". Grain requires group size 1 for good performance. Please "
      "re-generate your ArrayRecord files with 'group_size:1'."));
}

void ArrayRecordDataSource::CreateReader(const int reader_index) {
// See b/262550570 for the readahead buffer size.
ArrayRecordReaderOptions array_record_reader_options;
array_record_reader_options.set_max_parallelism(0);
array_record_reader_options.set_readahead_buffer_size(0);
riegeli::FileReaderBase::Options file_reader_options;
file_reader_options.set_buffer_size(1 << 15);
// Copy is on purpose.
std::string filename = read_instructions_[reader_index].filename;
auto reader = std::make_unique<
array_record::ArrayRecordReader<riegeli::FileReader<>>>(
std::forward_as_tuple(filename, file_reader_options),
array_record_reader_options, array_record::ArrayRecordGlobalPool());
const auto status = CheckGroupSize(filename, reader->WriterOptionsString());
if (!status.ok()) {
LOG(ERROR) << status;
}
{
const std::lock_guard<std::mutex> lock(create_reader_mutex_);
if (readers_[reader_index] == nullptr) {
readers_[reader_index] = std::move(reader);
}
}
}

// Looks up the record with global index `key` and stores a view of it in
// `*record`.
//
// The reader for the owning file is created on first use. Returns the status
// of the underlying ParallelReadRecordsWithIndices() call.
//
// NOTE(review): the null check on the reader slot is made without holding
// create_reader_mutex_ -- confirm whether concurrent GetItem() calls are
// expected.
// NOTE(review): `*record` is set to the string_view handed to the read
// callback; whether that memory stays valid after this call returns is not
// visible here -- confirm before relying on it.
absl::Status ArrayRecordDataSource::GetItem(
    uint64_t key, absl::string_view* record) {
  const std::pair<int, uint64_t> location = GetReaderIndexAndPosition(key);
  const int reader_index = location.first;
  const uint64_t position = location.second;

  if (!readers_[reader_index]) {
    CreateReader(reader_index);
  }
  return readers_[reader_index]->ParallelReadRecordsWithIndices(
      {position},
      [record](uint64_t /*read_idx*/,
               absl::string_view value) -> absl::Status {
        // TODO(amrahmed): Follow on this
        *record = value;
        return absl::OkStatus();
      });
}
} // namespace array_record
46 changes: 46 additions & 0 deletions python/array_record_data_source_cpp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#ifndef THIRD_PARTY_ARRAY_RECORD_PYTHON_ARRAY_RECORD_DATA_SOURCE_CPP_H_
#define THIRD_PARTY_ARRAY_RECORD_PYTHON_ARRAY_RECORD_DATA_SOURCE_CPP_H_


#include "python/read_instructions_lib.h"
#include "cpp/array_record_reader.h"
#include "riegeli/bytes/file_reader.h"


namespace array_record {

// A data source over multiple ArrayRecord files. It holds the file reader
// objects and implements the lookup logic:
//   - The constructor builds the global index by resolving the input paths
//     into read instructions (reading the number of records per file where
//     no explicit range is given).
//   - NumRecords() returns the total number of records.
//   - GetItem() looks up a single key and returns the record, opening the
//     file reader for the owning file if it has not been opened yet.
class ArrayRecordDataSource {
 public:
  // `paths` lists the inputs to index. The parameter is deliberately not
  // named like the private `paths_` member to avoid shadowing it.
  explicit ArrayRecordDataSource(absl::Span<std::string> paths);

  // Total number of records across all read instructions.
  uint64_t NumRecords() const;

  // Reads the record with global index `key` into `*record`. `key` must be
  // smaller than NumRecords().
  absl::Status GetItem(uint64_t key, absl::string_view* record);

 private:
  // Intended to hold the input paths given to the constructor.
  const std::vector<std::string> paths_;
  // One entry per resolved file range, in lookup order.
  std::vector<ReadInstruction> read_instructions_;
  // Sum of NumRecords() over read_instructions_; zero until the constructor
  // has computed it.
  uint64_t total_num_records_ = 0;

  // Opens the reader for read_instructions_[reader_index] and stores it in
  // readers_[reader_index] if that slot is still empty.
  void CreateReader(int reader_index);

  using Reader =
      std::unique_ptr<array_record::ArrayRecordReader<riegeli::FileReader<>>>;
  // Parallel to read_instructions_; a null entry means "not opened yet".
  std::vector<Reader> readers_;
  // Held while a newly created reader is installed into readers_.
  std::mutex create_reader_mutex_;

  // Maps a global record index to (index into readers_, position in file).
  std::pair<int, uint64_t> GetReaderIndexAndPosition(uint64_t key) const;

  // Returns a non-OK status when `options_string` is present and does not
  // describe a file written with group size 1.
  absl::Status CheckGroupSize(
      absl::string_view filename,
      std::optional<std::string> options_string);
};
} // namespace array_record


#endif // THIRD_PARTY_ARRAY_RECORD_PYTHON_ARRAY_RECORD_DATA_SOURCE_CPP_H_
115 changes: 115 additions & 0 deletions python/read_instructions_lib.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#include "python/read_instructions_lib.h"

#include "absl/strings/match.h"
#include "absl/strings/str_format.h"
#include "cpp/parallel_for.h"
#include "third_party/re2/re2.h"
#include "thread/threadpool.h"
#include "absl/status/statusor.h"
#include "iostream"

namespace array_record {

// Number of worker threads in the temporary thread pool that
// GetReadInstructions() creates. Getting the read instructions is cheap but
// IO bound, so the pattern matching and record counting are fanned out over
// this pool.
constexpr int kNumThreadsForReadInstructions = 256;

// Parses a path of the form "filename[start:end]" into a ReadInstruction.
//
// The whole string must match; `start` and `end` are runs of decimal digits.
// Returns InvalidArgument for any other input. No check is made here that
// `end` is at least `start`.
absl::StatusOr<ReadInstruction> ReadInstruction::Parse(absl::string_view path) {
  static const LazyRE2 kPattern = {R"((.+)\[(\d+):(\d+)\])"};

  std::string parsed_filename;
  int64_t range_start = 0;
  int64_t range_end = 0;
  const bool matched = RE2::FullMatch(path, *kPattern, &parsed_filename,
                                      &range_start, &range_end);
  if (!matched) {
    return absl::InvalidArgumentError(
        absl::StrFormat("Can't parse %s as ReadInstruction", path));
  }
  return ReadInstruction{parsed_filename, range_start, range_end};
}

// Get the read instructions for a list of paths where each path can be:
// - A normal filename.
// - A filename with read instructions: filename[start:end].
// - A pattern containing '?', which is expanded with file::Match(); the
//   matched files are used in sorted order.
// Unless the filename is given with read instruction, the file will be opened
// (through `get_num_records`) to get the total number of records.
//
// Returns the resolved instructions in input order. Returns a non-OK status
// when a pattern cannot be matched to any file or when a file does not exist,
// instead of silently handing back instructions without a valid `end`.
absl::StatusOr<std::vector<ReadInstruction>> GetReadInstructions(
    absl::Span<std::string> paths,
    const GetNumRecords& get_num_records) {
  std::vector<ReadInstruction> read_instructions;
  read_instructions.reserve(paths.size());

  // Step 1: Parse potential read instructions.
  bool missing_num_records = false;
  for (const std::string& path : paths) {
    absl::StatusOr<ReadInstruction> read_instruction =
        ReadInstruction::Parse(path);
    if (read_instruction.ok()) {
      read_instructions.push_back(*std::move(read_instruction));
    } else {
      // Not of the form filename[start:end]: treat it as a filename or
      // pattern whose record count still has to be looked up.
      missing_num_records = true;
      read_instructions.push_back({path});
    }
  }
  if (!missing_num_records) {
    return read_instructions;
  }

  // The pool lives on the stack so that it is torn down on every return path
  // (it was previously heap-allocated and never deleted).
  ThreadPool pool("ReadInstructionsPool", kNumThreadsForReadInstructions);
  pool.StartWorkers();

  std::vector<std::vector<ReadInstruction>> filled_instructions(
      read_instructions.size());
  // One status slot per task; each task writes only its own slot, so no
  // extra synchronization is needed.
  std::vector<absl::Status> match_statuses(read_instructions.size());

  // Step 2: Match any patterns.
  auto match_pattern = [&](int i) {
    const std::string& pattern = read_instructions[i].filename;
    if (read_instructions[i].end >= 0 || !absl::StrContains(pattern, '?')) {
      filled_instructions[i].push_back(std::move(read_instructions[i]));
      return;
    }
    const auto status_or_filenames = file::Match(pattern, file::Defaults());
    if (!status_or_filenames.ok() || status_or_filenames->empty()) {
      LOG(ERROR) << "Failed to find matching files for pattern " << pattern;
      match_statuses[i] =
          status_or_filenames.ok()
              ? absl::NotFoundError(absl::StrFormat(
                    "Failed to find matching files for pattern %s", pattern))
              : status_or_filenames.status();
      return;
    }
    auto filenames = *status_or_filenames;
    // Make sure we always read files in the same order.
    absl::c_sort(filenames);
    filled_instructions[i].reserve(filenames.size());
    for (const std::string& filename : filenames) {
      filled_instructions[i].push_back({filename, 0, -1});
    }
  };

  array_record::ParallelFor(Seq(read_instructions.size()), &pool,
                            match_pattern);
  for (const absl::Status& status : match_statuses) {
    if (!status.ok()) {
      return status;
    }
  }

  // Flatten filled_instructions into read_instructions;
  read_instructions.clear();
  for (const auto& instructions : filled_instructions) {
    read_instructions.insert(read_instructions.end(), instructions.begin(),
                             instructions.end());
  }

  // Step 3: Get number of records.
  std::vector<absl::Status> count_statuses(read_instructions.size());
  auto add_num_records = [&](int i) {
    if (read_instructions[i].end >= 0) {
      return;
    }
    const std::string& filename = read_instructions[i].filename;

    if (!file::Exists(filename, file::Defaults()).ok()) {
      LOG(ERROR) << "File " << filename << " not found.";
      count_statuses[i] = absl::NotFoundError(
          absl::StrFormat("File %s not found.", filename));
      return;
    }

    read_instructions[i].end =
        static_cast<int64_t>(get_num_records(filename));
  };
  array_record::ParallelFor(Seq(read_instructions.size()), &pool,
                            add_num_records);
  for (const absl::Status& status : count_statuses) {
    if (!status.ok()) {
      return status;
    }
  }
  return read_instructions;
}

} // namespace array_record
37 changes: 37 additions & 0 deletions python/read_instructions_lib.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#ifndef THIRD_PARTY_ARRAY_RECORD_PYTHON_READ_INTRUCTIONS_LIB_H_
#define THIRD_PARTY_ARRAY_RECORD_PYTHON_READ_INTRUCTIONS_LIB_H_

#include <cstdint>
#include <string>
#include <vector>

#include "absl/status/statusor.h"

namespace array_record {

// Describes which records of one file to read: the half-open record range
// [start, end) of `filename`.
struct ReadInstruction {
  // Path of the file to read from.
  std::string filename;
  // Index of the first record to read. Always >= 0.
  int64_t start{0};
  // One past the last record to read. Must be >= start, or -1 to indicate
  // "until the end of the file".
  int64_t end{-1};

  // Number of records covered by the range, computed as end - start.
  int64_t NumRecords() const { return end - start; }

  // Parses a string of the form "filename[start:end]".
  static absl::StatusOr<ReadInstruction> Parse(absl::string_view path);
};

// Callback that returns the number of records stored in the file with the
// given filename.
using GetNumRecords = std::function<uint64_t(const std::string&)>;

// Get the read instructions for a list of paths where each path can be:
// - A normal filename.
// - A filename with read instructions: filename[start:end].
// - A pattern containing '?', which is matched against existing files.
// Unless the filename is given with a read instruction, `get_num_records` is
// called on the file to get its total number of records.
absl::StatusOr<std::vector<ReadInstruction>> GetReadInstructions(
absl::Span<std::string> paths,
const GetNumRecords& get_num_records);

} // namespace array_record


#endif // THIRD_PARTY_ARRAY_RECORD_PYTHON_READ_INTRUCTIONS_LIB_H_

0 comments on commit 079903e

Please sign in to comment.