-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PiperOrigin-RevId: 681475496
- Loading branch information
Orbax Authors
committed
Oct 2, 2024
1 parent
4279d83
commit 738a03b
Showing
3 changed files
with
248 additions
and
0 deletions.
There are no files selected for viewing
126 changes: 126 additions & 0 deletions
126
checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
#include "third_party/py/orbax/checkpoint/_src/serialization/tensorstore_utils.h" | ||
|
||
#include <filesystem> // NOLINT(build/c++17) | ||
#include <optional> | ||
#include <regex> // NOLINT(build/c++11) | ||
#include <string> | ||
#include <utility> | ||
#include <variant> | ||
|
||
#include "third_party/absl/status/status.h" | ||
#include "third_party/absl/status/statusor.h" | ||
#include "third_party/json/include/nlohmann/json.hpp" | ||
#include "third_party/json/include/nlohmann/json_fwd.hpp" | ||
|
||
namespace orbax_checkpoint { | ||
|
||
namespace { | ||
|
||
const char kDefaultDriver[] = "file"; | ||
const char kProcessSubdirPrefix[] = "ocdbt.process_"; | ||
const char kOcdbtProcessIdRe[] = "[A-Za-z0-9]+"; | ||
const char kGcsPathRe[] = "^gs://([^/]*)/(.*)$"; | ||
|
||
} // namespace | ||
|
||
absl::StatusOr<nlohmann::json> get_kvstore_for_gcs( | ||
const std::string& ckpt_path) { | ||
std::regex gcs_path_regex(kGcsPathRe); | ||
std::smatch match; | ||
if (!std::regex_match(ckpt_path, match, gcs_path_regex)) { | ||
return absl::InvalidArgumentError( | ||
"The ckpt_path should contain the bucket name and the " | ||
"file path inside the bucket. Got: " + | ||
ckpt_path); | ||
} | ||
std::string gcs_bucket = match[1]; | ||
std::string path_without_bucket = match[2]; | ||
nlohmann::json json_spec = { | ||
{"driver", "gcs"}, {"bucket", gcs_bucket}, {"path", path_without_bucket}}; | ||
return absl::StatusOr<nlohmann::json>(json_spec); | ||
} | ||
|
||
absl::StatusOr<nlohmann::json> build_kvstore_tspec( | ||
const std::string& directory, const std::optional<std::string>& name, | ||
bool use_ocdbt, | ||
const std::optional<std::variant<int, std::string>>& process_id) { | ||
std::string default_driver = std::string(kDefaultDriver); | ||
|
||
std::string normalized_directory = | ||
std::filesystem::path(directory).lexically_normal().string(); | ||
normalized_directory = | ||
std::regex_replace(normalized_directory, std::regex("gs:/"), "gs://"); | ||
|
||
bool is_gcs_path = normalized_directory.starts_with("gs://"); | ||
|
||
nlohmann::json kv_spec; | ||
|
||
if (use_ocdbt) { | ||
if (!is_gcs_path && | ||
!std::filesystem::path(normalized_directory).is_absolute()) { | ||
return absl::InvalidArgumentError( | ||
"Checkpoint path should be absolute. Got " + normalized_directory); | ||
} | ||
|
||
if (process_id.has_value()) { | ||
std::string process_id_str; | ||
if (std::holds_alternative<int>(*process_id)) { | ||
process_id_str = std::to_string(std::get<int>(*process_id)); | ||
} else { | ||
process_id_str = std::get<std::string>(*process_id); | ||
} | ||
std::regex process_id_regex(kOcdbtProcessIdRe); | ||
if (!std::regex_match(process_id_str, process_id_regex)) { | ||
return absl::InvalidArgumentError("process_id must conform to " + | ||
std::string(kOcdbtProcessIdRe) + | ||
" pattern, got " + process_id_str); | ||
} | ||
normalized_directory = | ||
(std::filesystem::path(normalized_directory) / | ||
(std::string(kProcessSubdirPrefix) + process_id_str)) | ||
.string(); | ||
} | ||
|
||
nlohmann::json base_driver_spec; | ||
if (is_gcs_path) { | ||
base_driver_spec = normalized_directory; | ||
} else { | ||
base_driver_spec = nlohmann::json { | ||
{"driver", default_driver}, { "path", normalized_directory } | ||
}; | ||
} | ||
|
||
kv_spec["driver"] = "ocdbt"; | ||
kv_spec["base"] = base_driver_spec; | ||
|
||
if (name.has_value()) { | ||
kv_spec["path"] = *name; | ||
} | ||
|
||
kv_spec["experimental_read_coalescing_threshold_bytes"] = 1000000; | ||
kv_spec["experimental_read_coalescing_merged_bytes"] = 500000000000; | ||
kv_spec["experimental_read_coalescing_interval"] = "1ms"; | ||
kv_spec["cache_pool"] = "cache_pool#ocdbt"; | ||
|
||
} else { | ||
std::string path = | ||
name.has_value() | ||
? (std::filesystem::path(normalized_directory) / *name).string() | ||
: normalized_directory; | ||
|
||
if (is_gcs_path) { | ||
absl::StatusOr<nlohmann::json> gcs_kvstore = get_kvstore_for_gcs(path); | ||
if (!gcs_kvstore.ok()) { | ||
return gcs_kvstore.status(); | ||
} | ||
kv_spec = std::move(gcs_kvstore).value(); | ||
} else { | ||
kv_spec["driver"] = default_driver; | ||
kv_spec["path"] = path; | ||
} | ||
} | ||
|
||
return kv_spec; | ||
} | ||
|
||
} // namespace orbax_checkpoint |
40 changes: 40 additions & 0 deletions
40
checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#ifndef THIRD_PARTY_PY_ORBAX_CHECKPOINT__SRC_SERIALIZATION_TENSORSTORE_UTILS_H_ | ||
#define THIRD_PARTY_PY_ORBAX_CHECKPOINT__SRC_SERIALIZATION_TENSORSTORE_UTILS_H_ | ||
|
||
#include <optional> | ||
#include <string> | ||
#include <variant> | ||
#include <vector> | ||
|
||
#include "third_party/absl/status/statusor.h" | ||
#include "third_party/json/include/nlohmann/json_fwd.hpp" | ||
|
||
namespace orbax_checkpoint { | ||
|
||
|
||
/* | ||
* Constructs a spec for a Tensorstore KvStore. | ||
* | ||
* @param directory Base path (key prefix) of the KvStore, used by the underlying | ||
* file driver. | ||
* @param name Name (filename) of the parameter. | ||
* @param use_ocdbt Whether to use OCDBT driver. | ||
* @param process_id [only used with OCDBT driver] If provided, | ||
* `{directory}/ocdbt.process_{process_id}` path is used as the base path. | ||
* If a string, must conform to [A-Za-z0-9]+ pattern. | ||
* | ||
* @return A Tensorstore KvStore spec in dictionary form. | ||
*/ | ||
absl::StatusOr<nlohmann::json> build_kvstore_tspec( | ||
const std::string& directory, | ||
const std::optional<std::string>& name = std::nullopt, | ||
bool use_ocdbt = true, | ||
const std::optional<std::variant<int, std::string>>& process_id = | ||
std::nullopt); | ||
|
||
absl::StatusOr<nlohmann::json> get_kvstore_for_gcs( | ||
const std::string& ckpt_path); | ||
|
||
} // namespace orbax_checkpoint | ||
|
||
#endif // THIRD_PARTY_PY_ORBAX_CHECKPOINT__SRC_SERIALIZATION_TENSORSTORE_UTILS_H_ |
82 changes: 82 additions & 0 deletions
82
checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
#include "third_party/py/orbax/checkpoint/_src/serialization/tensorstore_utils.h" | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
#include "testing/base/public/gmock.h" | ||
#include "testing/base/public/gunit.h" | ||
#include "third_party/json/include/nlohmann/json.hpp" | ||
#include "third_party/json/include/nlohmann/json_fwd.hpp" | ||
|
||
namespace orbax_checkpoint { | ||
namespace { | ||
|
||
using ::testing::TestWithParam; | ||
using ::testing::status::StatusIs; | ||
|
||
struct FormattingTestCase { | ||
std::string test_name; | ||
std::string directory; | ||
std::string param_name; | ||
bool use_ocdbt; | ||
int process_id; | ||
std::string expected_tspec_json_str; | ||
}; | ||
|
||
using TensorStoreUtilTest = TestWithParam<FormattingTestCase>; | ||
|
||
TEST_P(TensorStoreUtilTest, BuildKvstoreTspec) { | ||
const FormattingTestCase& test_case = GetParam(); | ||
ASSERT_OK_AND_ASSIGN( | ||
nlohmann::json json_kvstore_spec, | ||
build_kvstore_tspec(test_case.directory, test_case.param_name, | ||
test_case.use_ocdbt, test_case.process_id)); | ||
nlohmann::json expected_json_kvstore_spec = | ||
nlohmann::json::parse(test_case.expected_tspec_json_str); | ||
EXPECT_TRUE(json_kvstore_spec == expected_json_kvstore_spec); | ||
} | ||
|
||
INSTANTIATE_TEST_SUITE_P( | ||
NumbersTestSuiteInstantiation, TensorStoreUtilTest, | ||
testing::ValuesIn<FormattingTestCase>({ | ||
{"local_fs_path", "/tmp/local_path", "params/a", false, 13, | ||
R"({"driver":"gfile","path":"/tmp/local_path/params/a"})"}, | ||
{"regular_gcs_path_with_ocdbt", "gs://gcs_bucket/object_path", | ||
"params/a", true, 0, | ||
R"({"driver": "ocdbt", | ||
"base": "gs://gcs_bucket/object_path/ocdbt.process_0", | ||
"path": "params/a", | ||
"experimental_read_coalescing_threshold_bytes": 1000000, | ||
"experimental_read_coalescing_merged_bytes": 500000000000, | ||
"experimental_read_coalescing_interval": "1ms", | ||
"cache_pool": "cache_pool#ocdbt"})"}, | ||
}), | ||
[](const testing::TestParamInfo<TensorStoreUtilTest::ParamType>& info) { | ||
return info.param.test_name; | ||
}); | ||
|
||
TEST(TensorStoreUtilSimpleTest, GetKvstoreForGcs) { | ||
std::vector<std::string> valid_paths = { | ||
"gs://my-bucket/data/file.txt", | ||
"gs://another-bucket/folder/", | ||
}; | ||
|
||
std::vector<nlohmann::json> expected_json{ | ||
{{"bucket", "my-bucket"}, {"driver", "gcs"}, {"path", "data/file.txt"}}, | ||
{{"bucket", "another-bucket"}, {"driver", "gcs"}, {"path", "folder/"}}}; | ||
|
||
for (int i = 0; i < valid_paths.size(); i++) { | ||
ASSERT_OK_AND_ASSIGN(auto json_spec, get_kvstore_for_gcs(valid_paths[i])); | ||
ASSERT_EQ(json_spec, expected_json[i]); | ||
} | ||
|
||
std::vector<std::string> invalid_paths = {"gs://invalid-path", | ||
"https://www.example.com"}; | ||
for (int i = 0; i < invalid_paths.size(); i++) { | ||
EXPECT_THAT(get_kvstore_for_gcs(invalid_paths[i]), | ||
StatusIs(util::error::INVALID_ARGUMENT)); | ||
} | ||
} | ||
|
||
} // namespace | ||
} // namespace orbax_checkpoint |