Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681475496
  • Loading branch information
Orbax Authors committed Oct 2, 2024
1 parent 4279d83 commit 738a03b
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 0 deletions.
126 changes: 126 additions & 0 deletions checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#include "third_party/py/orbax/checkpoint/_src/serialization/tensorstore_utils.h"

#include <filesystem> // NOLINT(build/c++17)
#include <optional>
#include <regex> // NOLINT(build/c++11)
#include <string>
#include <utility>
#include <variant>

#include "third_party/absl/status/status.h"
#include "third_party/absl/status/statusor.h"
#include "third_party/json/include/nlohmann/json.hpp"
#include "third_party/json/include/nlohmann/json_fwd.hpp"

namespace orbax_checkpoint {

namespace {

// Driver name placed in kvstore specs for paths that are not on GCS.
constexpr char kDefaultDriver[] = "file";

// Prefix of the per-process subdirectory appended to the base directory when
// a process id is supplied together with the OCDBT driver.
constexpr char kProcessSubdirPrefix[] = "ocdbt.process_";

// Pattern that a process id (after conversion to a string) must fully match.
constexpr char kOcdbtProcessIdRe[] = "[A-Za-z0-9]+";

// Pattern for GCS paths: group 1 captures the bucket, group 2 captures the
// path inside the bucket.
constexpr char kGcsPathRe[] = "^gs://([^/]*)/(.*)$";

}  // namespace

// Builds a TensorStore `gcs` kvstore spec from a `gs://<bucket>/<path>` URL.
//
// @param ckpt_path GCS URL of the form `gs://<bucket>/<path>`. The part after
//   the bucket may be empty, but the separating '/' is required.
//
// @return On success, a JSON object of the form
//   `{"driver": "gcs", "bucket": <bucket>, "path": <path>}`.
//   `absl::InvalidArgumentError` if `ckpt_path` does not match the expected
//   form or if the bucket name is empty.
absl::StatusOr<nlohmann::json> get_kvstore_for_gcs(
    const std::string& ckpt_path) {
  // Compiled once and reused: constructing a std::regex is expensive and the
  // pattern never changes.
  static const std::regex gcs_path_regex(kGcsPathRe);

  std::smatch match;
  // The pattern allows an empty first capture (e.g. "gs:///x"); such a URL has
  // no bucket name, so it is rejected together with non-matching input.
  if (!std::regex_match(ckpt_path, match, gcs_path_regex) ||
      match[1].length() == 0) {
    return absl::InvalidArgumentError(
        "The ckpt_path should contain the bucket name and the "
        "file path inside the bucket. Got: " +
        ckpt_path);
  }

  const std::string gcs_bucket = match[1].str();
  const std::string path_without_bucket = match[2].str();
  nlohmann::json json_spec = {
      {"driver", "gcs"}, {"bucket", gcs_bucket}, {"path", path_without_bucket}};
  // Returning the local by name lets it be moved into the StatusOr instead of
  // being copied.
  return json_spec;
}

// Constructs a spec for a TensorStore KvStore.
//
// @param directory Base path (key prefix) of the KvStore. May be a local path
//   or a `gs://` URL. It is lexically normalized before use.
// @param name Optional name of the parameter. With OCDBT it becomes the
//   "path" of the spec; otherwise it is appended to `directory`.
// @param use_ocdbt Whether to produce an `ocdbt` driver spec.
// @param process_id Only used with OCDBT. If provided,
//   `{directory}/ocdbt.process_{process_id}` is used as the base path. After
//   conversion to a string the value must fully match `[A-Za-z0-9]+` (so a
//   negative integer is rejected as well).
//
// @return The kvstore spec as JSON, or `absl::InvalidArgumentError` when
//   - OCDBT is requested and `directory` is neither a GCS path nor absolute,
//   - `process_id` does not match the required pattern, or
//   - OCDBT is not used and the GCS path cannot be split into bucket and path.
absl::StatusOr<nlohmann::json> build_kvstore_tspec(
    const std::string& directory, const std::optional<std::string>& name,
    bool use_ocdbt,
    const std::optional<std::variant<int, std::string>>& process_id) {
  std::string normalized_directory =
      std::filesystem::path(directory).lexically_normal().string();
  // lexically_normal() collapses the "//" of the "gs://" scheme into a single
  // "/". Restore it, but only at the very beginning of the string: a local
  // path that merely contains "gs:/" somewhere inside must not be rewritten.
  if (normalized_directory.starts_with("gs:/") &&
      !normalized_directory.starts_with("gs://")) {
    normalized_directory.insert(4, "/");
  }
  const bool is_gcs_path = normalized_directory.starts_with("gs://");

  nlohmann::json kv_spec;

  if (!use_ocdbt) {
    const std::string path =
        name.has_value()
            ? (std::filesystem::path(normalized_directory) / *name).string()
            : normalized_directory;
    if (is_gcs_path) {
      // Propagates either the parsed gcs spec or the InvalidArgument status.
      return get_kvstore_for_gcs(path);
    }
    kv_spec["driver"] = std::string(kDefaultDriver);
    kv_spec["path"] = path;
    return kv_spec;
  }

  if (!is_gcs_path &&
      !std::filesystem::path(normalized_directory).is_absolute()) {
    return absl::InvalidArgumentError(
        "Checkpoint path should be absolute. Got " + normalized_directory);
  }

  if (process_id.has_value()) {
    const std::string process_id_str =
        std::holds_alternative<int>(*process_id)
            ? std::to_string(std::get<int>(*process_id))
            : std::get<std::string>(*process_id);
    // Compiled once and reused; the pattern is a compile-time constant.
    static const std::regex process_id_regex(kOcdbtProcessIdRe);
    if (!std::regex_match(process_id_str, process_id_regex)) {
      return absl::InvalidArgumentError("process_id must conform to " +
                                        std::string(kOcdbtProcessIdRe) +
                                        " pattern, got " + process_id_str);
    }
    // Each process gets its own subdirectory below the base directory.
    normalized_directory =
        (std::filesystem::path(normalized_directory) /
         (std::string(kProcessSubdirPrefix) + process_id_str))
            .string();
  }

  // For GCS the base is given as the URL string itself; for everything else
  // it is an explicit {driver, path} object using the default driver.
  nlohmann::json base_driver_spec;
  if (is_gcs_path) {
    base_driver_spec = normalized_directory;
  } else {
    base_driver_spec = nlohmann::json{{"driver", std::string(kDefaultDriver)},
                                      {"path", normalized_directory}};
  }

  kv_spec["driver"] = "ocdbt";
  kv_spec["base"] = base_driver_spec;
  if (name.has_value()) {
    kv_spec["path"] = *name;
  }

  // Read-coalescing options passed through to the OCDBT driver.
  kv_spec["experimental_read_coalescing_threshold_bytes"] = 1000000;
  kv_spec["experimental_read_coalescing_merged_bytes"] = 500000000000;
  kv_spec["experimental_read_coalescing_interval"] = "1ms";
  // Context resource reference; presumably resolved against a cache pool
  // configured by the caller's TensorStore context -- confirm at call sites.
  kv_spec["cache_pool"] = "cache_pool#ocdbt";

  return kv_spec;
}

} // namespace orbax_checkpoint
40 changes: 40 additions & 0 deletions checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#ifndef THIRD_PARTY_PY_ORBAX_CHECKPOINT__SRC_SERIALIZATION_TENSORSTORE_UTILS_H_
#define THIRD_PARTY_PY_ORBAX_CHECKPOINT__SRC_SERIALIZATION_TENSORSTORE_UTILS_H_

#include <optional>
#include <string>
#include <variant>
#include <vector>

#include "third_party/absl/status/statusor.h"
#include "third_party/json/include/nlohmann/json_fwd.hpp"

namespace orbax_checkpoint {


/*
 * Constructs a spec for a Tensorstore KvStore.
 *
 * @param directory Base path (key prefix) of the KvStore, used by the
 *   underlying file driver. May also be a `gs://` path. When the OCDBT driver
 *   is used, a non-GCS directory must be an absolute path.
 * @param name Name (filename) of the parameter. With the OCDBT driver it is
 *   stored as the "path" of the spec; otherwise it is joined onto `directory`.
 * @param use_ocdbt Whether to use OCDBT driver.
 * @param process_id [only used with OCDBT driver] If provided,
 *   `{directory}/ocdbt.process_{process_id}` path is used as the base path.
 *   The value, converted to a string, must conform to the [A-Za-z0-9]+
 *   pattern.
 *
 * @return A Tensorstore KvStore spec in dictionary (JSON object) form, or an
 *   InvalidArgument status if the directory is required to be absolute but is
 *   not, if `process_id` does not conform to the pattern, or if a GCS path
 *   used without OCDBT cannot be split into a bucket and a path.
 */
absl::StatusOr<nlohmann::json> build_kvstore_tspec(
const std::string& directory,
const std::optional<std::string>& name = std::nullopt,
bool use_ocdbt = true,
const std::optional<std::variant<int, std::string>>& process_id =
std::nullopt);

/*
 * Constructs a Tensorstore KvStore spec for the `gcs` driver from a GCS path.
 *
 * @param ckpt_path Path of the form `gs://<bucket>/<path inside the bucket>`.
 *
 * @return A JSON object with the keys "driver" (always "gcs"), "bucket" and
 *   "path", or an InvalidArgument status if `ckpt_path` does not have the
 *   expected form.
 */
absl::StatusOr<nlohmann::json> get_kvstore_for_gcs(
const std::string& ckpt_path);

} // namespace orbax_checkpoint

#endif // THIRD_PARTY_PY_ORBAX_CHECKPOINT__SRC_SERIALIZATION_TENSORSTORE_UTILS_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#include "third_party/py/orbax/checkpoint/_src/serialization/tensorstore_utils.h"

#include <string>
#include <vector>

#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
#include "third_party/json/include/nlohmann/json.hpp"
#include "third_party/json/include/nlohmann/json_fwd.hpp"

namespace orbax_checkpoint {
namespace {

using ::testing::TestWithParam;
using ::testing::status::StatusIs;

// One parameterized case for the BuildKvstoreTspec test.
//
// The test-suite instantiation fills this struct with positional
// brace-initializers, so the order of the fields must not be changed.
struct FormattingTestCase {
// Name of the generated test (returned by the instantiation's name generator).
std::string test_name;
// `directory` argument passed to build_kvstore_tspec().
std::string directory;
// `name` argument passed to build_kvstore_tspec().
std::string param_name;
// `use_ocdbt` argument passed to build_kvstore_tspec().
bool use_ocdbt;
// `process_id` argument passed to build_kvstore_tspec(); always an int here,
// so the string alternative of the variant is not exercised by this struct.
int process_id;
// Expected result, as JSON text that is parsed before comparison.
std::string expected_tspec_json_str;
};

// Fixture for tests parameterized over FormattingTestCase.
using TensorStoreUtilTest = TestWithParam<FormattingTestCase>;

// Checks that build_kvstore_tspec() returns OK and produces exactly the JSON
// spec given by the test case.
TEST_P(TensorStoreUtilTest, BuildKvstoreTspec) {
  const FormattingTestCase& test_case = GetParam();
  ASSERT_OK_AND_ASSIGN(
      nlohmann::json json_kvstore_spec,
      build_kvstore_tspec(test_case.directory, test_case.param_name,
                          test_case.use_ocdbt, test_case.process_id));
  const nlohmann::json expected_json_kvstore_spec =
      nlohmann::json::parse(test_case.expected_tspec_json_str);
  // EXPECT_EQ (instead of EXPECT_TRUE on operator==) prints both specs when
  // they differ, which makes failures diagnosable from the log alone.
  EXPECT_EQ(json_kvstore_spec, expected_json_kvstore_spec);
}

// Cases for TensorStoreUtilTest. Fields are given positionally in the order
// declared by FormattingTestCase: test_name, directory, param_name, use_ocdbt,
// process_id, expected_tspec_json_str.
INSTANTIATE_TEST_SUITE_P(
    NumbersTestSuiteInstantiation, TensorStoreUtilTest,
    testing::ValuesIn<FormattingTestCase>({
        // Without OCDBT the process id is ignored and the default driver is
        // used; the expected driver must match kDefaultDriver ("file") in
        // tensorstore_utils.cc.
        {"local_fs_path", "/tmp/local_path", "params/a", false, 13,
         R"({"driver":"file","path":"/tmp/local_path/params/a"})"},
        // With OCDBT on GCS the base is the gs:// URL of the per-process
        // subdirectory and the parameter name becomes the path.
        {"regular_gcs_path_with_ocdbt", "gs://gcs_bucket/object_path",
         "params/a", true, 0,
         R"({"driver": "ocdbt",
"base": "gs://gcs_bucket/object_path/ocdbt.process_0",
"path": "params/a",
"experimental_read_coalescing_threshold_bytes": 1000000,
"experimental_read_coalescing_merged_bytes": 500000000000,
"experimental_read_coalescing_interval": "1ms",
"cache_pool": "cache_pool#ocdbt"})"},
    }),
    [](const testing::TestParamInfo<TensorStoreUtilTest::ParamType>& info) {
      return info.param.test_name;
    });

// Checks that get_kvstore_for_gcs() splits well-formed gs:// paths into bucket
// and path, and rejects paths that lack a bucket/path split or the gs scheme.
TEST(TensorStoreUtilSimpleTest, GetKvstoreForGcs) {
  const std::vector<std::string> valid_paths = {
      "gs://my-bucket/data/file.txt",
      "gs://another-bucket/folder/",
  };

  const std::vector<nlohmann::json> expected_json{
      {{"bucket", "my-bucket"}, {"driver", "gcs"}, {"path", "data/file.txt"}},
      {{"bucket", "another-bucket"}, {"driver", "gcs"}, {"path", "folder/"}}};

  // The two vectors are indexed in parallel; guard against them drifting
  // apart when cases are added.
  ASSERT_EQ(valid_paths.size(), expected_json.size());
  for (size_t i = 0; i < valid_paths.size(); ++i) {
    ASSERT_OK_AND_ASSIGN(auto json_spec, get_kvstore_for_gcs(valid_paths[i]));
    ASSERT_EQ(json_spec, expected_json[i]);
  }

  const std::vector<std::string> invalid_paths = {"gs://invalid-path",
                                                  "https://www.example.com"};
  for (const std::string& invalid_path : invalid_paths) {
    EXPECT_THAT(get_kvstore_for_gcs(invalid_path),
                StatusIs(util::error::INVALID_ARGUMENT));
  }
}

} // namespace
} // namespace orbax_checkpoint

0 comments on commit 738a03b

Please sign in to comment.