Skip to content

Commit

Permalink
Support tf, tfl and stablehlo dialect MLIR in Model Explorer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 631968620
  • Loading branch information
yijie-yang authored and copybara-github committed May 8, 2024
1 parent ae9ee55 commit 7e64ce1
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 40 deletions.
111 changes: 78 additions & 33 deletions src/builtin-adapter/model_json_graph_convert.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
Expand All @@ -37,9 +38,11 @@ limitations under the License.
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"
Expand All @@ -57,6 +60,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h"
Expand All @@ -71,6 +75,12 @@ namespace tooling {
namespace visualization_client {
namespace {

// The MLIR dialect a parsed module is classified as. Returned by
// GetMlirDialect and used to pick the JSON translation path in
// ConvertMlirToJson.
enum class MlirDialect {
  kTf,         // TensorFlow dialect.
  kTflite,     // TensorFlow Lite dialect.
  kStablehlo,  // StableHLO dialect.
};

// Referred logic from lite/python/flatbuffer_to_mlir.cc.
static mlir::OwningOpRef<mlir::ModuleOp> FlatBufferFileToMlirTranslation(
llvm::SourceMgr* source_mgr, mlir::MLIRContext* context) {
Expand Down Expand Up @@ -181,42 +191,28 @@ absl::Status ConvertToStablehloModule(mlir::ModuleOp module_op) {
return absl::OkStatus();
}

} // namespace

absl::StatusOr<std::string> ConvertStablehloMlirToJson(
const VisualizeConfig& config, absl::string_view model_path) {
mlir::DialectRegistry registry;
// Note: This is more dialects than is currently visualized, but does include
// what is commonly produced by different frameworks. So this would parse
// correctly but then fail in visualization. This should result in a better
// user experience than failing to parse here.
registry.insert<mlir::stablehlo::StablehloDialect, mlir::chlo::ChloDialect,
mlir::mhlo::MhloDialect, mlir::vhlo::VhloDialect,
mlir::func::FuncDialect, mlir::arith::ArithDialect,
mlir::shape::ShapeDialect, mlir::TFL::TensorFlowLiteDialect,
mlir::scf::SCFDialect>();
mlir::MLIRContext context(registry);
mlir::ParserConfig parser_config(&context);
std::string model_content;
RETURN_IF_ERROR(tsl::ReadFileToString(
tsl::Env::Default(), std::string(model_path), &model_content));
auto module_op =
mlir::parseSourceString<::mlir::ModuleOp>(model_content, parser_config);
if (!module_op) return absl::InternalError("Unable to parse module");

// Converts StableHLO MLIR module to JSON string.
std::string json_output;
llvm::raw_string_ostream json_ost(json_output);
mlir::LogicalResult result =
JaxConvertedMlirToJsonTranslate(*module_op, json_ost);
if (mlir::failed(result)) {
return absl::InternalError(
"Failed to convert JAX converted MLIR module to JSON string.");
// Determines which supported dialect (tf, tfl or stablehlo) `module_op`
// belongs to by inspecting the dialect of the first operation in the first
// function's entry block.
//
// Returns InvalidArgumentError if the module contains no functions or the
// leading operation's dialect is not one of the supported ones.
absl::StatusOr<MlirDialect> GetMlirDialect(mlir::ModuleOp module_op) {
  auto fn_range = module_op.getOps<mlir::func::FuncOp>();
  if (fn_range.empty()) {
    return absl::InvalidArgumentError("Module is empty");
  }
  mlir::func::FuncOp fn = *fn_range.begin();
  // The first op of the entry block is taken as representative of the whole
  // module's dialect.
  mlir::Block& first_block = fn.getBody().front();
  mlir::Operation& first_op = first_block.front();
  mlir::Dialect* dialect = first_op.getDialect();
  if (llvm::isa<mlir::TF::TensorFlowDialect>(dialect)) {
    return MlirDialect::kTf;
  } else if (llvm::isa<mlir::TFL::TensorFlowLiteDialect>(dialect)) {
    return MlirDialect::kTflite;
  } else if (llvm::isa<mlir::stablehlo::StablehloDialect>(dialect)) {
    return MlirDialect::kStablehlo;
  } else {
    return absl::InvalidArgumentError("Unsupported dialect");
  }
}

} // namespace

absl::StatusOr<std::string> ConvertSavedModelToJson(
const VisualizeConfig& config, absl::string_view model_path) {
tensorflow::SavedModel saved_model;
Expand Down Expand Up @@ -334,5 +330,54 @@ absl::StatusOr<std::string> ConvertFlatbufferToJson(
return json_output;
}

// Converts an MLIR textual/bytecode file (tf, tfl or stablehlo dialect) into
// the visualizer JSON string.
absl::StatusOr<std::string> ConvertMlirToJson(const VisualizeConfig& config,
                                              absl::string_view model_path) {
  // Register more dialects than are currently visualized: this covers what
  // common frameworks produce, so such input parses here and fails later at
  // visualization time instead — a better user experience than a parse error.
  mlir::DialectRegistry registry;
  registry.insert<mlir::TFL::TensorFlowLiteDialect, mlir::TF::TensorFlowDialect,
                  mlir::stablehlo::StablehloDialect, mlir::chlo::ChloDialect,
                  mlir::mhlo::MhloDialect, mlir::vhlo::VhloDialect,
                  mlir::func::FuncDialect, mlir::arith::ArithDialect,
                  mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();
  mlir::MLIRContext context(registry);
  mlir::ParserConfig parser_config(&context);

  // Load the file contents and parse them into an MLIR module.
  std::string model_content;
  RETURN_IF_ERROR(tsl::ReadFileToString(
      tsl::Env::Default(), std::string(model_path), &model_content));
  mlir::OwningOpRef<mlir::ModuleOp> module_op =
      mlir::parseSourceString<::mlir::ModuleOp>(model_content, parser_config);
  if (!module_op) return absl::InternalError("Unable to parse module");

  std::string json_output;
  llvm::raw_string_ostream json_ost(json_output);

  // Dispatch to the translation matching the module's dialect.
  ASSIGN_OR_RETURN(MlirDialect dialect, GetMlirDialect(*module_op));
  switch (dialect) {
    case MlirDialect::kTf:
    case MlirDialect::kStablehlo: {
      // TF modules wrapping XlaCallModule ops are first lowered to StableHLO.
      if (HasXlaCallModule(*module_op)) {
        RETURN_IF_ERROR(ConvertToStablehloModule(*module_op));
      }
      if (mlir::failed(JaxConvertedMlirToJsonTranslate(*module_op, json_ost))) {
        return absl::InternalError(
            "Failed to convert TF or StableHLO MLIR module to JSON string.");
      }
      break;
    }
    case MlirDialect::kTflite: {
      if (mlir::failed(
              TfliteMlirToJsonTranslateImpl(config, *module_op, json_ost))) {
        return absl::InternalError(
            "Failed to convert TFL MLIR module to JSON string.");
      }
      break;
    }
    default:
      return absl::InvalidArgumentError("Unsupported dialect");
  }

  return json_output;
}

} // namespace visualization_client
} // namespace tooling
6 changes: 3 additions & 3 deletions src/builtin-adapter/model_json_graph_convert.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ absl::StatusOr<std::string> ConvertFlatbufferToJson(
bool is_modelpath);

// Converts a MLIR textual/bytecode file to visualizer JSON string.
// Note: this expects StableHLO inside the bytecode file.
absl::StatusOr<std::string> ConvertStablehloMlirToJson(
const VisualizeConfig& config, absl::string_view model_path);
// Note: currently only the tf, tfl, and stablehlo dialects are supported
// inside the file.
absl::StatusOr<std::string> ConvertMlirToJson(const VisualizeConfig& config,
absl::string_view model_path);

} // namespace visualization_client
} // namespace tooling
Expand Down
8 changes: 4 additions & 4 deletions src/builtin-adapter/models_to_json_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ using ::tooling::visualization_client::ConvertSavedModelToJson;
enum ModelFormat {
kFlatbuffer,
kSavedModel,
kStablehloMlir,
kMlir,
kFlatbufferDirect,
kSavedModelDirect,
kGraphDefDirect,
Expand Down Expand Up @@ -108,7 +108,7 @@ int main(int argc, char* argv[]) {
}
} else if (extension == ".mlirbc" || extension == ".mlir") {
// MLIR module (tf, tfl, or stablehlo dialect) represented using MLIR
// textual or bytecode format.
model_format = kStablehloMlir;
model_format = kMlir;
} else if (extension == ".pb" || extension == ".pbtxt" ||
extension == ".graphdef") {
model_format = kGraphDefDirect;
Expand All @@ -129,8 +129,8 @@ int main(int argc, char* argv[]) {
ConvertFlatbufferToJson(config, input_file, /*is_modelpath=*/true);
break;
}
case kStablehloMlir: {
json_output = ConvertStablehloMlirToJson(config, input_file);
case kMlir: {
json_output = ConvertMlirToJson(config, input_file);
break;
}
case kSavedModel: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ class VisualizeConfig:
def ConvertFlatbufferDirectlyToJson(arg0: VisualizeConfig, arg1: str) -> str: ...
def ConvertFlatbufferToJson(arg0: VisualizeConfig, arg1: str, arg2: bool) -> str: ...
def ConvertGraphDefDirectlyToJson(arg0: VisualizeConfig, arg1: str) -> str: ...
def ConvertMlirToJson(arg0: VisualizeConfig, arg1: str) -> str: ...
def ConvertSavedModelDirectlyToJson(arg0: VisualizeConfig, arg1: str) -> str: ...
def ConvertSavedModelToJson(arg0: VisualizeConfig, arg1: str) -> str: ...
16 changes: 16 additions & 0 deletions src/builtin-adapter/python/convert_wrapper/convert_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,22 @@ PYBIND11_MODULE(_pywrap_convert_wrapper, m) {
Converts a GraphDef directly to visualizer JSON string without MLIR or
execution. Raises `RuntimeError` exception if failed.
)pbdoc");
  // Python binding for ConvertMlirToJson: converts an MLIR textual/bytecode
  // file to the visualizer JSON string. A non-OK absl::Status is surfaced as
  // std::runtime_error, which pybind11 translates to Python RuntimeError.
  m.def(
      "ConvertMlirToJson",
      [](const VisualizeConfig& config,
         absl::string_view model_path) -> std::string {
        const absl::StatusOr<std::string> json_or_status =
            ::tooling::visualization_client::ConvertMlirToJson(config,
                                                               model_path);
        if (!json_or_status.ok()) {
          throw std::runtime_error(json_or_status.status().ToString());
        }
        return json_or_status.value();
      },
      R"pbdoc(
        Converts a MLIR textual/bytecode file to visualizer JSON string.
        Raises `RuntimeError` exception if failed.
      )pbdoc");
}

} // namespace pybind11

0 comments on commit 7e64ce1

Please sign in to comment.