diff --git a/src/builtin-adapter/WORKSPACE b/src/builtin-adapter/WORKSPACE index 43842c7b..0998b23a 100644 --- a/src/builtin-adapter/WORKSPACE +++ b/src/builtin-adapter/WORKSPACE @@ -13,9 +13,9 @@ http_archive( ], ) -TENSORFLOW_COMMIT = "8b5370df5655da95b113362e18a9cd850ded7973" +TENSORFLOW_COMMIT = "7b26bb0c266f8b122932904b5f1216818429709d" -TENSORFLOW_SHA256 = "b0366cb1eef7bdc18f9e733c082115894f0147f87bf764d0474bd97f686d5d12" +TENSORFLOW_SHA256 = "418bd874023857039ffbadd7ecd95dd191a750aaa9da1170ae71e03946c4a4c0" http_archive( name = "org_tensorflow", diff --git a/src/builtin-adapter/python/convert_wrapper/BUILD b/src/builtin-adapter/python/convert_wrapper/BUILD index 4f623dff..64e1eff4 100644 --- a/src/builtin-adapter/python/convert_wrapper/BUILD +++ b/src/builtin-adapter/python/convert_wrapper/BUILD @@ -16,10 +16,9 @@ pybind_extension( "_pywrap_convert_wrapper.pyi", ], deps = [ + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@pybind11", - "@pybind11_abseil//pybind11_abseil:import_status_module", - "@pybind11_abseil//pybind11_abseil:status_casters", "//:direct_flatbuffer_to_json_graph_convert", "//:direct_saved_model_to_json_graph_convert", "//:model_json_graph_convert", diff --git a/src/builtin-adapter/python/convert_wrapper/convert_wrapper.cc b/src/builtin-adapter/python/convert_wrapper/convert_wrapper.cc index ac57d628..e93a6c4c 100644 --- a/src/builtin-adapter/python/convert_wrapper/convert_wrapper.cc +++ b/src/builtin-adapter/python/convert_wrapper/convert_wrapper.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <stdexcept> +#include <string> + +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "pybind11/pybind11.h" -#include "pybind11_abseil/import_status_module.h" -#include "pybind11_abseil/status_casters.h" // IWYU pragma : keep #include "direct_flatbuffer_to_json_graph_convert.h" #include "direct_saved_model_to_json_graph_convert.h" #include "model_json_graph_convert.h" @@ -27,8 +29,6 @@ using tooling::visualization_client::VisualizeConfig; namespace pybind11 { PYBIND11_MODULE(_pywrap_convert_wrapper, m) { - pybind11::google::ImportStatusModule(); - class_<VisualizeConfig>(m, "VisualizeConfig") .def(init<>()) .def_readwrite("const_element_count_limit", @@ -36,58 +36,87 @@ PYBIND11_MODULE(_pywrap_convert_wrapper, m) { m.def( "ConvertSavedModelToJson", - [](const VisualizeConfig& config, absl::string_view model_path) { - return ::tooling::visualization_client::ConvertSavedModelToJson( - config, model_path); + [](const VisualizeConfig& config, + absl::string_view model_path) -> std::string { + const absl::StatusOr<std::string> json_or_status = + ::tooling::visualization_client::ConvertSavedModelToJson( + config, model_path); + if (!json_or_status.ok()) { + throw std::runtime_error(json_or_status.status().ToString()); + } + return json_or_status.value(); }, R"pbdoc( Converts a SavedModel to visualizer JSON string through tf dialect MLIR - module if succeeded, otherwise raises `StatusNotOk` exception. + module if succeeded, otherwise raises `RuntimeError` exception. )pbdoc"); m.def( "ConvertFlatbufferToJson", [](const VisualizeConfig& config, absl::string_view model_path, - bool is_modelpath) { - return ::tooling::visualization_client::ConvertFlatbufferToJson( - config, model_path, is_modelpath); + bool is_modelpath) -> std::string { + const absl::StatusOr<std::string> json_or_status = + ::tooling::visualization_client::ConvertFlatbufferToJson( + config, model_path, is_modelpath); + if (!json_or_status.ok()) { + throw std::runtime_error(json_or_status.status().ToString()); + } + return json_or_status.value(); }, R"pbdoc( Converts a Flatbuffer to visualizer JSON string through tfl dialect MLIR - module if succeeded, otherwise raises `StatusNotOk` exception. + module if succeeded, otherwise raises `RuntimeError` exception. )pbdoc"); m.def( "ConvertFlatbufferDirectlyToJson", - [](const VisualizeConfig& config, absl::string_view model_path) { - return ::tooling::visualization_client::ConvertFlatbufferDirectlyToJson( - config, model_path); + [](const VisualizeConfig& config, + absl::string_view model_path) -> std::string { + const absl::StatusOr<std::string> json_or_status = + ::tooling::visualization_client::ConvertFlatbufferDirectlyToJson( + config, model_path); + if (!json_or_status.ok()) { + throw std::runtime_error(json_or_status.status().ToString()); + } + return json_or_status.value(); }, R"pbdoc( Converts a Flatbuffer directly to visualizer JSON string without MLIR or - execution. Raises `StatusNotOk` exception if failed. + execution. Raises `RuntimeError` exception if failed. )pbdoc"); m.def( "ConvertSavedModelDirectlyToJson", - [](const VisualizeConfig& config, absl::string_view model_path) { - return ::tooling::visualization_client::ConvertSavedModelDirectlyToJson( - config, model_path); + [](const VisualizeConfig& config, + absl::string_view model_path) -> std::string { + const absl::StatusOr<std::string> json_or_status = + ::tooling::visualization_client::ConvertSavedModelDirectlyToJson( + config, model_path); + if (!json_or_status.ok()) { + throw std::runtime_error(json_or_status.status().ToString()); + } + return json_or_status.value(); }, R"pbdoc( Converts a SavedModel directly to visualizer JSON string without MLIR or - execution. Raises `StatusNotOk` exception if failed. + execution. Raises `RuntimeError` exception if failed. )pbdoc"); m.def( "ConvertGraphDefDirectlyToJson", - [](const VisualizeConfig& config, absl::string_view model_path) { - return ::tooling::visualization_client::ConvertGraphDefDirectlyToJson( - config, model_path); + [](const VisualizeConfig& config, + absl::string_view model_path) -> std::string { + const absl::StatusOr<std::string> json_or_status = + ::tooling::visualization_client::ConvertGraphDefDirectlyToJson( + config, model_path); + if (!json_or_status.ok()) { + throw std::runtime_error(json_or_status.status().ToString()); + } + return json_or_status.value(); }, R"pbdoc( Converts a GraphDef directly to visualizer JSON string without MLIR or - execution. Raises `StatusNotOk` exception if failed. + execution. Raises `RuntimeError` exception if failed. )pbdoc"); } diff --git a/src/builtin-adapter/tools/load_opdefs.h b/src/builtin-adapter/tools/load_opdefs.h index 6f326543..4fd1f496 100644 --- a/src/builtin-adapter/tools/load_opdefs.h +++ b/src/builtin-adapter/tools/load_opdefs.h @@ -12,6 +12,10 @@ namespace visualization_client { struct OpMetadata { std::vector<std::string> arguments; std::vector<std::string> results; + + OpMetadata(const std::vector<std::string>& arguments, + const std::vector<std::string>& results) + : arguments(arguments), results(results) {} }; absl::flat_hash_map<std::string, OpMetadata> LoadTfliteOpdefs();