From 6514ec6424b5875006d21b3a7758be87a1756f0b Mon Sep 17 00:00:00 2001
From: Orbax Authors
Date: Wed, 2 Oct 2024 09:07:16 -0700
Subject: [PATCH] Internal change.

PiperOrigin-RevId: 681475496
---
 .../_src/serialization/tensorstore_utils.cc   | 141 ++++++++++++++++++
 .../_src/serialization/tensorstore_utils.h    |  50 ++++++
 .../serialization/tensorstore_utils_test.cc   |  86 +++++++++++
 3 files changed, 277 insertions(+)
 create mode 100644 checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.cc
 create mode 100644 checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.h
 create mode 100644 checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.cc

diff --git a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.cc b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.cc
new file mode 100644
index 00000000..f7323801
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.cc
@@ -0,0 +1,141 @@
+#include "third_party/py/orbax/checkpoint/_src/serialization/tensorstore_utils.h"
+
+#include <filesystem>  // NOLINT(build/c++17)
+#include <optional>
+#include <regex>  // NOLINT(build/c++11)
+#include <string>
+#include <utility>
+#include <variant>
+
+#include "third_party/absl/status/status.h"
+#include "third_party/absl/status/statusor.h"
+#include "third_party/json/include/nlohmann/json.hpp"
+#include "third_party/json/include/nlohmann/json_fwd.hpp"
+
+namespace orbax_checkpoint {
+
+namespace {
+
+// Driver used for paths that are not on GCS.
+const char kDefaultDriver[] = "file";
+// Prefix of the per-process subdirectory appended to the OCDBT base path.
+const char kProcessSubdirPrefix[] = "ocdbt.process_";
+// Pattern a process id must match in full.
+const char kOcdbtProcessIdRe[] = "[A-Za-z0-9]+";
+// Captures the bucket (group 1) and the in-bucket path (group 2) of a GCS URL.
+const char kGcsPathRe[] = "^gs://([^/]*)/(.*)$";
+
+}  // namespace
+
+// Builds a "gcs" driver spec from a `gs://<bucket>/<path>` URL, storing the
+// bucket and the in-bucket path as separate entries. Returns InvalidArgument
+// if `ckpt_path` does not match kGcsPathRe.
+absl::StatusOr<nlohmann::json> get_kvstore_for_gcs(
+    const std::string& ckpt_path) {
+  std::regex gcs_path_regex(kGcsPathRe);
+  std::smatch match;
+  if (!std::regex_match(ckpt_path, match, gcs_path_regex)) {
+    return absl::InvalidArgumentError(
+        "The ckpt_path should contain the bucket name and the "
+        "file path inside the bucket. Got: " +
+        ckpt_path);
+  }
+  std::string gcs_bucket = match[1];
+  std::string path_without_bucket = match[2];
+  nlohmann::json json_spec = {
+      {"driver", "gcs"}, {"bucket", gcs_bucket}, {"path", path_without_bucket}};
+  return absl::StatusOr<nlohmann::json>(json_spec);
+}
+
+// See the declaration in tensorstore_utils.h for the parameter contract.
+absl::StatusOr<nlohmann::json> build_kvstore_tspec(
+    const std::string& directory, const std::optional<std::string>& name,
+    bool use_ocdbt,
+    const std::optional<std::variant<int, std::string>>& process_id) {
+  std::string default_driver = std::string(kDefaultDriver);
+
+  std::string normalized_directory =
+      std::filesystem::path(directory).lexically_normal().string();
+  // Normalization collapses the "//" of a "gs://" scheme into a single "/".
+  // Restore it for a leading scheme only, so that a local path that merely
+  // contains "gs:/" somewhere inside it is left untouched.
+  if (normalized_directory.starts_with("gs:/") &&
+      !normalized_directory.starts_with("gs://")) {
+    normalized_directory.insert(4, "/");
+  }
+
+  bool is_gcs_path = normalized_directory.starts_with("gs://");
+
+  nlohmann::json kv_spec;
+
+  if (use_ocdbt) {
+    if (!is_gcs_path &&
+        !std::filesystem::path(normalized_directory).is_absolute()) {
+      return absl::InvalidArgumentError(
+          "Checkpoint path should be absolute. Got " + normalized_directory);
+    }
+
+    if (process_id.has_value()) {
+      std::string process_id_str;
+      if (std::holds_alternative<int>(*process_id)) {
+        process_id_str = std::to_string(std::get<int>(*process_id));
+      } else {
+        process_id_str = std::get<std::string>(*process_id);
+      }
+      std::regex process_id_regex(kOcdbtProcessIdRe);
+      if (!std::regex_match(process_id_str, process_id_regex)) {
+        return absl::InvalidArgumentError("process_id must conform to " +
+                                          std::string(kOcdbtProcessIdRe) +
+                                          " pattern, got " + process_id_str);
+      }
+      normalized_directory =
+          (std::filesystem::path(normalized_directory) /
+           (std::string(kProcessSubdirPrefix) + process_id_str))
+              .string();
+    }
+
+    // A GCS base is passed as a plain URL string; any other base is spelled
+    // out as an explicit {driver, path} spec.
+    nlohmann::json base_driver_spec;
+    if (is_gcs_path) {
+      base_driver_spec = normalized_directory;
+    } else {
+      base_driver_spec = nlohmann::json {
+        {"driver", default_driver}, { "path", normalized_directory }
+      };
+    }
+
+    kv_spec["driver"] = "ocdbt";
+    kv_spec["base"] = base_driver_spec;
+
+    if (name.has_value()) {
+      kv_spec["path"] = *name;
+    }
+
+    kv_spec["experimental_read_coalescing_threshold_bytes"] = 1000000;
+    kv_spec["experimental_read_coalescing_merged_bytes"] = 500000000000;
+    kv_spec["experimental_read_coalescing_interval"] = "1ms";
+    kv_spec["cache_pool"] = "cache_pool#ocdbt";
+
+  } else {
+    std::string path =
+        name.has_value()
+            ? (std::filesystem::path(normalized_directory) / *name).string()
+            : normalized_directory;
+
+    if (is_gcs_path) {
+      absl::StatusOr<nlohmann::json> gcs_kvstore = get_kvstore_for_gcs(path);
+      if (!gcs_kvstore.ok()) {
+        return gcs_kvstore.status();
+      }
+      kv_spec = std::move(gcs_kvstore).value();
+    } else {
+      kv_spec["driver"] = default_driver;
+      kv_spec["path"] = path;
+    }
+  }
+
+  return kv_spec;
+}
+
+}  // namespace orbax_checkpoint
diff --git a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.h b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.h
new file mode 100644
index 00000000..51555ead
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.h
@@ -0,0 +1,50 @@
+#ifndef THIRD_PARTY_PY_ORBAX_CHECKPOINT__SRC_SERIALIZATION_TENSORSTORE_UTILS_H_
+#define THIRD_PARTY_PY_ORBAX_CHECKPOINT__SRC_SERIALIZATION_TENSORSTORE_UTILS_H_
+
+#include <optional>
+#include <string>
+#include <utility>
+#include <variant>
+
+#include "third_party/absl/status/statusor.h"
+#include "third_party/json/include/nlohmann/json_fwd.hpp"
+
+namespace orbax_checkpoint {
+
+/*
+ * Constructs a spec for a Tensorstore KvStore.
+ *
+ * @param directory Base path (key prefix) of the KvStore, used by the underlying
+ * file driver.
+ * @param name Name (filename) of the parameter.
+ * @param use_ocdbt Whether to use OCDBT driver.
+ * @param process_id [only used with OCDBT driver] If provided,
+ * `{directory}/ocdbt.process_{process_id}` path is used as the base path.
+ * Must conform to [A-Za-z0-9]+ pattern (an int is checked in decimal form).
+ *
+ * @return A Tensorstore KvStore spec in dictionary form, or InvalidArgument
+ * if `use_ocdbt` is set and `directory` is neither absolute nor a `gs://`
+ * path, if `process_id` does not match the pattern above, or if a `gs://`
+ * path used without OCDBT has no `/` after the bucket name.
+ */
+absl::StatusOr<nlohmann::json> build_kvstore_tspec(
+    const std::string& directory,
+    const std::optional<std::string>& name = std::nullopt,
+    bool use_ocdbt = true,
+    const std::optional<std::variant<int, std::string>>& process_id =
+        std::nullopt);
+
+/*
+ * Constructs a `gcs` driver KvStore spec from a `gs://{bucket}/{path}` URL.
+ *
+ * @param ckpt_path GCS URL; must contain a `/` after the bucket name.
+ *
+ * @return A spec with `driver`, `bucket` and `path` entries, or
+ * InvalidArgument if `ckpt_path` does not have the form above.
+ */
+absl::StatusOr<nlohmann::json> get_kvstore_for_gcs(
+    const std::string& ckpt_path);
+
+}  // namespace orbax_checkpoint
+
+#endif  // THIRD_PARTY_PY_ORBAX_CHECKPOINT__SRC_SERIALIZATION_TENSORSTORE_UTILS_H_
diff --git a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.cc b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.cc
new file mode 100644
index 00000000..896b8d7e
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.cc
@@ -0,0 +1,86 @@
+#include "third_party/py/orbax/checkpoint/_src/serialization/tensorstore_utils.h"
+
+#include <string>
+#include <vector>
+
+#include "testing/base/public/gmock.h"
+#include "testing/base/public/gunit.h"
+#include "third_party/json/include/nlohmann/json.hpp"
+#include "third_party/json/include/nlohmann/json_fwd.hpp"
+
+namespace orbax_checkpoint {
+namespace {
+
+using ::testing::TestWithParam;
+using ::testing::status::StatusIs;
+
+// One parameterized input/expectation pair for build_kvstore_tspec.
+struct FormattingTestCase {
+  std::string test_name;
+  std::string directory;
+  std::string param_name;
+  bool use_ocdbt;
+  int process_id;
+  std::string expected_tspec_json_str;
+};
+
+using TensorStoreUtilTest = TestWithParam<FormattingTestCase>;
+
+TEST_P(TensorStoreUtilTest, BuildKvstoreTspec) {
+  const FormattingTestCase& test_case = GetParam();
+  ASSERT_OK_AND_ASSIGN(
+      nlohmann::json json_kvstore_spec,
+      build_kvstore_tspec(test_case.directory, test_case.param_name,
+                          test_case.use_ocdbt, test_case.process_id));
+  nlohmann::json expected_json_kvstore_spec =
+      nlohmann::json::parse(test_case.expected_tspec_json_str);
+  EXPECT_EQ(json_kvstore_spec, expected_json_kvstore_spec);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    NumbersTestSuiteInstantiation, TensorStoreUtilTest,
+    testing::ValuesIn<FormattingTestCase>({
+        // Non-GCS paths use the "file" driver (kDefaultDriver).
+        {"local_fs_path", "/tmp/local_path", "params/a", false, 13,
+         R"({"driver":"file","path":"/tmp/local_path/params/a"})"},
+        {"regular_gcs_path_with_ocdbt", "gs://gcs_bucket/object_path",
+         "params/a", true, 0,
+         R"({"driver": "ocdbt",
+         "base": "gs://gcs_bucket/object_path/ocdbt.process_0",
+         "path": "params/a",
+         "experimental_read_coalescing_threshold_bytes": 1000000,
+         "experimental_read_coalescing_merged_bytes": 500000000000,
+         "experimental_read_coalescing_interval": "1ms",
+         "cache_pool": "cache_pool#ocdbt"})"},
+    }),
+    [](const testing::TestParamInfo<TensorStoreUtilTest::ParamType>& info) {
+      return info.param.test_name;
+    });
+
+TEST(TensorStoreUtilSimpleTest, GetKvstoreForGcs) {
+  std::vector<std::string> valid_paths = {
+      "gs://my-bucket/data/file.txt",
+      "gs://another-bucket/folder/",
+  };
+
+  std::vector<nlohmann::json> expected_json{
+      {{"bucket", "my-bucket"}, {"driver", "gcs"}, {"path", "data/file.txt"}},
+      {{"bucket", "another-bucket"}, {"driver", "gcs"}, {"path", "folder/"}}};
+
+  // Guards the index-based loop below against a mismatched fixture.
+  ASSERT_EQ(valid_paths.size(), expected_json.size());
+  for (size_t i = 0; i < valid_paths.size(); i++) {
+    ASSERT_OK_AND_ASSIGN(auto json_spec, get_kvstore_for_gcs(valid_paths[i]));
+    ASSERT_EQ(json_spec, expected_json[i]);
+  }
+
+  std::vector<std::string> invalid_paths = {"gs://invalid-path",
+                                            "https://www.example.com"};
+  for (const std::string& invalid_path : invalid_paths) {
+    EXPECT_THAT(get_kvstore_for_gcs(invalid_path),
+                StatusIs(util::error::INVALID_ARGUMENT));
+  }
+}
+
+}  // namespace
+}  // namespace orbax_checkpoint