diff --git a/src/builtin-adapter/model_json_graph_convert.cc b/src/builtin-adapter/model_json_graph_convert.cc index 89724c02..898d7ee6 100644 --- a/src/builtin-adapter/model_json_graph_convert.cc +++ b/src/builtin-adapter/model_json_graph_convert.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" @@ -37,9 +38,11 @@ limitations under the License. #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/IR/AsmState.h" +#include "mlir/IR/Block.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/OwningOpRef.h" #include "mlir/IR/Verifier.h" #include "mlir/Parser/Parser.h" @@ -57,6 +60,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" @@ -71,6 +75,12 @@ namespace tooling { namespace visualization_client { namespace { +enum class MlirDialect { + kTf, + kTflite, + kStablehlo, +}; + // Referred logic from lite/python/flatbuffer_to_mlir.cc. 
static mlir::OwningOpRef<mlir::ModuleOp> FlatBufferFileToMlirTranslation( llvm::SourceMgr* source_mgr, mlir::MLIRContext* context) { @@ -181,42 +191,28 @@ absl::Status ConvertToStablehloModule(mlir::ModuleOp module_op) { return absl::OkStatus(); } -} // namespace - -absl::StatusOr<std::string> ConvertStablehloMlirToJson( - const VisualizeConfig& config, absl::string_view model_path) { - mlir::DialectRegistry registry; - // Note: This is more dialects than is currently visualized, but does include - // what is commonly produced by different frameworks. So this would parse - // correctly but then fail in visualization. This should result in a better - // user experience than failing to parse here. - registry.insert(); - mlir::MLIRContext context(registry); - mlir::ParserConfig parser_config(&context); - std::string model_content; - RETURN_IF_ERROR(tsl::ReadFileToString( - tsl::Env::Default(), std::string(model_path), &model_content)); - auto module_op = - mlir::parseSourceString<::mlir::ModuleOp>(model_content, parser_config); - if (!module_op) return absl::InternalError("Unable to parse module"); - - // Converts StableHLO MLIR module to JSON string. 
- std::string json_output; - llvm::raw_string_ostream json_ost(json_output); - mlir::LogicalResult result = - JaxConvertedMlirToJsonTranslate(*module_op, json_ost); - if (mlir::failed(result)) { - return absl::InternalError( - "Failed to convert JAX converted MLIR module to JSON string."); +absl::StatusOr<MlirDialect> GetMlirDialect(mlir::ModuleOp module_op) { + auto fn_range = module_op.getOps<mlir::func::FuncOp>(); + if (fn_range.empty()) { + return absl::InvalidArgumentError("Module is empty"); + } + mlir::func::FuncOp fn = *fn_range.begin(); + mlir::Block& first_block = fn.getBody().front(); + mlir::Operation& first_op = first_block.front(); + mlir::Dialect* dialect = first_op.getDialect(); + if (llvm::isa<mlir::TF::TensorFlowDialect>(dialect)) { + return MlirDialect::kTf; + } else if (llvm::isa<mlir::TFL::TensorFlowLiteDialect>(dialect)) { + return MlirDialect::kTflite; + } else if (llvm::isa<mlir::stablehlo::StablehloDialect>(dialect)) { + return MlirDialect::kStablehlo; + } else { + return absl::InvalidArgumentError("Unsupported dialect"); } - - return json_output; } +} // namespace + absl::StatusOr<std::string> ConvertSavedModelToJson( const VisualizeConfig& config, absl::string_view model_path) { tensorflow::SavedModel saved_model; @@ -334,5 +330,54 @@ absl::StatusOr<std::string> ConvertFlatbufferToJson( return json_output; } +absl::StatusOr<std::string> ConvertMlirToJson(const VisualizeConfig& config, + absl::string_view model_path) { + mlir::DialectRegistry registry; + // Note: This is more dialects than is currently visualized, but does include + // what is commonly produced by different frameworks. So this would parse + // correctly but then fail in visualization. This should result in a better + // user experience than failing to parse here. 
+ registry.insert(); + mlir::MLIRContext context(registry); + mlir::ParserConfig parser_config(&context); + std::string model_content; + RETURN_IF_ERROR(tsl::ReadFileToString( + tsl::Env::Default(), std::string(model_path), &model_content)); + auto module_op = + mlir::parseSourceString<::mlir::ModuleOp>(model_content, parser_config); + if (!module_op) return absl::InternalError("Unable to parse module"); + + std::string json_output; + llvm::raw_string_ostream json_ost(json_output); + + ASSIGN_OR_RETURN(MlirDialect dialect, GetMlirDialect(*module_op)); + if (dialect == MlirDialect::kTf || dialect == MlirDialect::kStablehlo) { + if (HasXlaCallModule(*module_op)) { + RETURN_IF_ERROR(ConvertToStablehloModule(*module_op)); + } + mlir::LogicalResult result = + JaxConvertedMlirToJsonTranslate(*module_op, json_ost); + if (mlir::failed(result)) { + return absl::InternalError( + "Failed to convert TF or StableHLO MLIR module to JSON string."); + } + } else if (dialect == MlirDialect::kTflite) { + mlir::LogicalResult result = + TfliteMlirToJsonTranslateImpl(config, *module_op, json_ost); + if (mlir::failed(result)) { + return absl::InternalError( + "Failed to convert TFL MLIR module to JSON string."); + } + } else { + return absl::InvalidArgumentError("Unsupported dialect"); + } + + return json_output; +} + } // namespace visualization_client } // namespace tooling diff --git a/src/builtin-adapter/model_json_graph_convert.h b/src/builtin-adapter/model_json_graph_convert.h index 077c3395..82b34b63 100644 --- a/src/builtin-adapter/model_json_graph_convert.h +++ b/src/builtin-adapter/model_json_graph_convert.h @@ -35,9 +35,9 @@ absl::StatusOr<std::string> ConvertFlatbufferToJson( bool is_modelpath); // Converts a MLIR textual/bytecode file to visualizer JSON string. -// Note: this expects StableHLO inside the bytecode file. 
-absl::StatusOr<std::string> ConvertStablehloMlirToJson( - const VisualizeConfig& config, absl::string_view model_path); +// Note: now only supports tf, tfl, stablehlo dialects inside the file. +absl::StatusOr<std::string> ConvertMlirToJson(const VisualizeConfig& config, + absl::string_view model_path); } // namespace visualization_client } // namespace tooling diff --git a/src/builtin-adapter/models_to_json_main.cc b/src/builtin-adapter/models_to_json_main.cc index 31463164..d1db4914 100644 --- a/src/builtin-adapter/models_to_json_main.cc +++ b/src/builtin-adapter/models_to_json_main.cc @@ -42,7 +42,7 @@ using ::tooling::visualization_client::ConvertSavedModelToJson; enum ModelFormat { kFlatbuffer, kSavedModel, - kStablehloMlir, + kMlir, kFlatbufferDirect, kSavedModelDirect, kGraphDefDirect, @@ -108,7 +108,7 @@ int main(int argc, char* argv[]) { } } else if (extension == ".mlirbc" || extension == ".mlir") { // StableHLO module represented using MLIR textual or bytecode format. - model_format = kStablehloMlir; + model_format = kMlir; } else if (extension == ".pb" || extension == ".pbtxt" || extension == ".graphdef") { model_format = kGraphDefDirect; @@ -129,8 +129,8 @@ ConvertFlatbufferToJson(config, input_file, /*is_modelpath=*/true); break; } - case kStablehloMlir: { - json_output = ConvertStablehloMlirToJson(config, input_file); + case kMlir: { + json_output = ConvertMlirToJson(config, input_file); break; } case kSavedModel: { diff --git a/src/builtin-adapter/python/convert_wrapper/_pywrap_convert_wrapper.pyi b/src/builtin-adapter/python/convert_wrapper/_pywrap_convert_wrapper.pyi index 07680aed..b0878ad1 100644 --- a/src/builtin-adapter/python/convert_wrapper/_pywrap_convert_wrapper.pyi +++ b/src/builtin-adapter/python/convert_wrapper/_pywrap_convert_wrapper.pyi @@ -20,5 +20,6 @@ class VisualizeConfig: def ConvertFlatbufferDirectlyToJson(arg0: VisualizeConfig, arg1: str) -> str: ... 
def ConvertFlatbufferToJson(arg0: VisualizeConfig, arg1: str, arg2: bool) -> str: ... def ConvertGraphDefDirectlyToJson(arg0: VisualizeConfig, arg1: str) -> str: ... +def ConvertMlirToJson(arg0: VisualizeConfig, arg1: str) -> str: ... def ConvertSavedModelDirectlyToJson(arg0: VisualizeConfig, arg1: str) -> str: ... def ConvertSavedModelToJson(arg0: VisualizeConfig, arg1: str) -> str: ... diff --git a/src/builtin-adapter/python/convert_wrapper/convert_wrapper.cc b/src/builtin-adapter/python/convert_wrapper/convert_wrapper.cc index e93a6c4c..2edce5d5 100644 --- a/src/builtin-adapter/python/convert_wrapper/convert_wrapper.cc +++ b/src/builtin-adapter/python/convert_wrapper/convert_wrapper.cc @@ -118,6 +118,22 @@ PYBIND11_MODULE(_pywrap_convert_wrapper, m) { Converts a GraphDef directly to visualizer JSON string without MLIR or execution. Raises `RuntimeError` exception if failed. )pbdoc"); + m.def( + "ConvertMlirToJson", + [](const VisualizeConfig& config, + absl::string_view model_path) -> std::string { + const absl::StatusOr<std::string> json_or_status = + ::tooling::visualization_client::ConvertMlirToJson(config, + model_path); + if (!json_or_status.ok()) { + throw std::runtime_error(json_or_status.status().ToString()); + } + return json_or_status.value(); + }, + R"pbdoc( + Converts a MLIR textual/bytecode file to visualizer JSON string. + Raises `RuntimeError` exception if failed. + )pbdoc"); } } // namespace pybind11