Skip to content

Commit

Permalink
Use StableHLO adapter implementation as default in MLIR
Browse files Browse the repository at this point in the history
This should support dialects like `mhlo` and other `hlo` family of dialects.

PiperOrigin-RevId: 662596997
  • Loading branch information
yijie-yang authored and copybara-github committed Aug 13, 2024
1 parent b701e79 commit b29ce2d
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions src/builtin-adapter/model_json_graph_convert.cc
Original file line number Diff line number Diff line change
Expand Up @@ -391,21 +391,16 @@ absl::StatusOr<std::string> ConvertMlirToJson(const VisualizeConfig& config,
mlir::Block& block = fop.getBody().front();
mlir::Operation& first_op = block.front();
absl::StatusOr<Subgraph> subgraph;
if (llvm::isa<mlir::stablehlo::StablehloDialect>(
first_op.getDialect())) {
subgraph = StablehloFunctionToSubgraph(config, fop);
} else if (llvm::isa<mlir::TF::TensorFlowDialect>(
first_op.getDialect())) {
if (llvm::isa<mlir::TF::TensorFlowDialect>(first_op.getDialect())) {
subgraph = TfFunctionToSubgraph(config, fop);
} else if (llvm::isa<mlir::TFL::TensorFlowLiteDialect>(
first_op.getDialect())) {
subgraph = TfliteFunctionToSubgraph(config, fop);
} else {
llvm::errs() << "Unknown dialect: "
<< first_op.getDialect()->getNamespace()
<< " in function: " << fop.getSymName()
<< ", we skip serializing this function.\n";
return mlir::WalkResult::skip();
// Use StableHLO adapter as default for all other dialects. It will do
// best effort for "hlo" family of dialects, but no guarantees for the
// others.
subgraph = StablehloFunctionToSubgraph(config, fop);
}
if (!subgraph.ok()) {
return PrintErrorAndInterupt(subgraph.status());
Expand Down

0 comments on commit b29ce2d

Please sign in to comment.