Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an ArrayRecordDataSource for multiple shards. Adapted from the TfGrain DataSource #56

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,43 @@ py_test(
"@com_google_absl_py//absl/testing:parameterized",
],
)

cc_library(
name = "array_record_data_source_cpp",
srcs = ["array_record_data_source_cpp.cc"],
hdrs = ["array_record_data_source_cpp.h"],
deps = [
":read_instructions_lib",
"//cpp:array_record_reader",
"//cpp:array_record_writer",
"//util/task:status",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@com_google_riegeli//riegeli/bytes:file_reader",
],
)

cc_library(
name = "read_instructions_lib",
srcs = ["read_instructions_lib.cc"],
hdrs = ["read_instructions_lib.h"],
deps = [
"//cpp:array_record_reader",
"//cpp:common",
"//cpp:parallel_for",
"//cpp:thread_pool",
"//file/base",
"//third_party/py/grain/google/placer",
"//third_party/re2",
"//thread",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_riegeli//riegeli/bytes:file_reader",
],
)
112 changes: 112 additions & 0 deletions python/array_record_data_source_cpp.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#include "python/array_record_data_source_cpp.h"
#include "absl/flags/flag.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "python/read_instructions_lib.h"
#include "cpp/array_record_reader.h"
#include "cpp/array_record_writer.h"


namespace array_record {

using ArrayRecordReaderOptions = ::array_record::ArrayRecordReaderBase::Options;
using ArrayRecordWriterOptions = ::array_record::ArrayRecordWriterBase::Options;

static uint64_t ArrayRecordGetNumRecords(const std::string& filename) {
const array_record::ArrayRecordReader<riegeli::FileReader<>> reader(
std::forward_as_tuple(filename));
return reader.NumRecords();
}

ArrayRecordDataSource::ArrayRecordDataSource(absl::Span<std::string> paths_){
absl::StatusOr<std::vector<ReadInstruction>> read_instructions_or_failure =
GetReadInstructions(paths_, ArrayRecordGetNumRecords);
CHECK_OK(read_instructions_or_failure);
read_instructions_ = *read_instructions_or_failure;

total_num_records_ = 0;
for (const auto& ri : read_instructions_) {
total_num_records_ += ri.NumRecords();
}
readers_.resize(read_instructions_.size());
}

uint64_t ArrayRecordDataSource::NumRecords() const {return total_num_records_;}

std::pair<int, uint64_t> ArrayRecordDataSource::GetReaderIndexAndPosition(
uint64_t key) const {
int reader_index = 0;
CHECK(key < NumRecords()) << "Invalid key " << key;
while (key >= read_instructions_[reader_index].NumRecords()) {
key -= read_instructions_[reader_index].NumRecords();
reader_index++;
}
key += read_instructions_[reader_index].start;
return {reader_index, key};
}

absl::Status ArrayRecordDataSource::CheckGroupSize(
const absl::string_view filename,
const std::optional<std::string> options_string) {
// Check that ArrayRecord files were created with group_size=1. Old files
// (prior 2022-10) don't have this info.
if (!options_string.has_value()) {
return absl::OkStatus();
}
auto maybe_options = ArrayRecordWriterOptions::FromString(*options_string);
if (!maybe_options.ok()) {
return maybe_options.status();
}
const int group_size = maybe_options->group_size();
if (group_size != 1) {
return absl::InvalidArgumentError(absl::StrCat(
"File ", filename, " was created with group size ", group_size,
". Grain requires group size 1 for good performance. Please "
"re-generate your ArrayRecord files with 'group_size:1'."));
}
return absl::OkStatus();
}

void ArrayRecordDataSource::CreateReader(const int reader_index) {
// See b/262550570 for the readahead buffer size.
ArrayRecordReaderOptions array_record_reader_options;
array_record_reader_options.set_max_parallelism(0);
array_record_reader_options.set_readahead_buffer_size(0);
riegeli::FileReaderBase::Options file_reader_options;
file_reader_options.set_buffer_size(1 << 15);
// Copy is on purpose.
std::string filename = read_instructions_[reader_index].filename;
auto reader = std::make_unique<
array_record::ArrayRecordReader<riegeli::FileReader<>>>(
std::forward_as_tuple(filename, file_reader_options),
array_record_reader_options, array_record::ArrayRecordGlobalPool());
const auto status = CheckGroupSize(filename, reader->WriterOptionsString());
if (!status.ok()) {
LOG(ERROR) << status;
}
{
const std::lock_guard<std::mutex> lock(create_reader_mutex_);
if (readers_[reader_index] == nullptr) {
readers_[reader_index] = std::move(reader);
}
}
}

absl::Status ArrayRecordDataSource::GetItem(
uint64_t key, absl::string_view* record) {
int reader_index;
uint64_t position;
std::tie(reader_index, position) = GetReaderIndexAndPosition(key);
if (readers_[reader_index] == nullptr) {
CreateReader(reader_index);
}
return readers_[reader_index]->ParallelReadRecordsWithIndices(
{position},
[&](uint64_t read_idx, absl::string_view value) -> absl::Status {
// TODO(amrahmed): Follow on this
*record = value;
return absl::OkStatus();
});
}
} // namespace array_record
46 changes: 46 additions & 0 deletions python/array_record_data_source_cpp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#ifndef THIRD_PARTY_ARRAY_RECORD_PYTHON_ARRAY_RECORD_DATA_SOURCE_CPP_H_
#define THIRD_PARTY_ARRAY_RECORD_PYTHON_ARRAY_RECORD_DATA_SOURCE_CPP_H_


#include "python/read_instructions_lib.h"
#include "cpp/array_record_reader.h"
#include "riegeli/bytes/file_reader.h"


namespace array_record {

// A Datasource for multiple ArrayRecordFiles. It holds the file reader objects
// and implements the lookup logic. The constructor constructs the global index
// by reading the number of records per file. NumRecords() returns the total
// number of records. GetItem() looks up a single key and returns the record.
// If needed it will open file readers.
class ArrayRecordDataSource {
public:
explicit ArrayRecordDataSource(absl::Span<std::string> paths_);

uint64_t NumRecords() const;

absl::Status GetItem(uint64_t key, absl::string_view* record);

private:
const std::vector<std::string> paths_;
std::vector<ReadInstruction> read_instructions_;
uint64_t total_num_records_;

void CreateReader(int reader_index);

using Reader =
std::unique_ptr<array_record::ArrayRecordReader<riegeli::FileReader<>>>;
std::vector<Reader> readers_;
std::mutex create_reader_mutex_;

std::pair<int, uint64_t> GetReaderIndexAndPosition(uint64_t key) const;

absl::Status CheckGroupSize(
absl::string_view filename,
std::optional<std::string> options_string);
};
} // namespace array_record


#endif // THIRD_PARTY_ARRAY_RECORD_PYTHON_ARRAY_RECORD_DATA_SOURCE_CPP_H_
115 changes: 115 additions & 0 deletions python/read_instructions_lib.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#include "python/read_instructions_lib.h"

#include "absl/strings/match.h"
#include "absl/strings/str_format.h"
#include "cpp/parallel_for.h"
#include "third_party/re2/re2.h"
#include "thread/threadpool.h"
#include "absl/status/statusor.h"
#include "iostream"

namespace array_record {

// Getting the read instructions is cheap but IO bound. We create a temporary
// thread pool to get the number of records.
constexpr int kNumThreadsForReadInstructions = 256;

absl::StatusOr<ReadInstruction> ReadInstruction::Parse(absl::string_view path) {
static const LazyRE2 kPattern = {R"((.+)\[(\d+):(\d+)\])"};
std::string filename;
int64_t start, end;
if (RE2::FullMatch(path, *kPattern, &filename, &start, &end)) {
return ReadInstruction{filename, start, end};
}
return absl::InvalidArgumentError(
absl::StrFormat("Can't parse %s as ReadInstruction", path));
}

// Get the read instructions for a list of paths where each path can be:
// - A normal filename.
// - A filename with read instructions: filename[start:end].
// Unless the filename is given with read instruction, the file will be opened
// to get the total number of records.
absl::StatusOr<std::vector<ReadInstruction>> GetReadInstructions(
absl::Span<std::string> paths,
const GetNumRecords& get_num_records) {
std::vector<ReadInstruction> read_instructions;

// Step 1: Parse potential read instructions.
bool missing_num_records = false;
for (const std::string& path : paths) {
absl::StatusOr<ReadInstruction> read_instruction =
ReadInstruction::Parse(path);
if (read_instruction.ok()) {
read_instructions.push_back(read_instruction.value());
} else {
missing_num_records = true;
const std::string pattern = path;
read_instructions.push_back({pattern});
}
}
if (!missing_num_records) {
return read_instructions;
}

ThreadPool* pool = new ThreadPool(
"ReadInstructionsPool", kNumThreadsForReadInstructions);
pool->StartWorkers();

std::vector<std::vector<ReadInstruction>> filled_instructions;
filled_instructions.resize(read_instructions.size());

// Step 2: Match any patterns.
auto match_pattern = [&](int i) {
const std::string& pattern = read_instructions[i].filename;
if (read_instructions[i].end >= 0 || !absl::StrContains(pattern, '?')) {
filled_instructions[i].push_back(std::move(read_instructions[i]));
return;
}
const auto status_or_filenames = file::Match(pattern, file::Defaults());
if (!status_or_filenames.ok() || status_or_filenames->empty()) {
LOG(ERROR) << "Failed to find matching files for pattern " << pattern;
return;
}
auto filenames = *status_or_filenames;
// Make sure we always read files in the same order.
absl::c_sort(filenames);
filled_instructions[i].reserve(filenames.size());
for (const std::string& filename : filenames) {
filled_instructions[i].push_back({filename, 0, -1});
}
};

array_record::ParallelFor(Seq(read_instructions.size()), pool, match_pattern);

// Flatten filled_instructions into read_instructions;
read_instructions.clear();
for (const auto& instructions : filled_instructions) {
read_instructions.insert(read_instructions.end(), instructions.begin(),
instructions.end());
}

// Step 3: Get number of records.
auto add_num_records = [&](int i) {
if (read_instructions[i].end >= 0) {
return;
}
const std::string& filename = read_instructions[i].filename;

std::cout << file::Exists(filename, file::Defaults()) << "\n";
if (!file::Exists(filename, file::Defaults()).ok()) {
LOG(ERROR) << "File " << filename << " not found.";
return;
}

read_instructions[i].end =
static_cast<int64_t>(get_num_records(filename));
};
array_record::ParallelFor(
Seq(read_instructions.size()),
pool,
add_num_records);
return read_instructions;
}

} // namespace array_record
37 changes: 37 additions & 0 deletions python/read_instructions_lib.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#ifndef THIRD_PARTY_ARRAY_RECORD_PYTHON_READ_INTRUCTIONS_LIB_H_
#define THIRD_PARTY_ARRAY_RECORD_PYTHON_READ_INTRUCTIONS_LIB_H_

#include <cstdint>
#include <string>
#include <vector>

#include "absl/status/statusor.h"

namespace array_record {

struct ReadInstruction {
std::string filename;
int64_t start = 0; // Always >= 0.
// Must be >= start or -1. -1 indicates that the end of the file.
int64_t end = -1;

static absl::StatusOr<ReadInstruction> Parse(absl::string_view path);

int64_t NumRecords() const { return end - start; }
};

using GetNumRecords = std::function<uint64_t(const std::string&)>;

// Get the read instructions for a list of paths where each path can be:
// - A normal filename.
// - A filename with read instructions: filename[start:end].
// Unless the filename is given with read instruction the file will be opened
// to get the total number of records.
absl::StatusOr<std::vector<ReadInstruction>> GetReadInstructions(
absl::Span<std::string> paths,
const GetNumRecords& get_num_records);

} // namespace array_record


#endif // THIRD_PARTY_ARRAY_RECORD_PYTHON_READ_INTRUCTIONS_LIB_H_