diff --git a/src/builtin-adapter/model_json_graph_convert.cc b/src/builtin-adapter/model_json_graph_convert.cc index ad68589b..fb73d2bd 100644 --- a/src/builtin-adapter/model_json_graph_convert.cc +++ b/src/builtin-adapter/model_json_graph_convert.cc @@ -88,6 +88,12 @@ enum class MlirDialect { kStablehlo, }; +struct TfMetadata { + int tf_version; + std::string saved_model_tags; + std::vector exported_names_vector; +}; + // Referred logic from lite/python/flatbuffer_to_mlir.cc. static mlir::OwningOpRef FlatBufferFileToMlirTranslation( llvm::SourceMgr* source_mgr, mlir::MLIRContext* context) { @@ -124,28 +130,36 @@ static int GetTfVersion(const tensorflow::SavedModel& saved_model) { return 2; } -// Obtains saved model tags and exported names from SavedModel proto. -absl::Status AssignTagsAndExportedNames( - const tensorflow::SavedModel& saved_model, const int tf_version, - std::string& tags_str, std::vector& exported_names) { +// Obtains TF metadata from the given SavedModel path. +// The metadata contains the tf version, saved model tags and exported names. +absl::StatusOr ObtainTfMetadata(absl::string_view model_path) { + TfMetadata tf_metadata; + tensorflow::SavedModel saved_model; + RETURN_IF_ERROR(tensorflow::ReadSavedModel(model_path, &saved_model)); + if (saved_model.meta_graphs_size() != 1) { + return absl::InvalidArgumentError(absl::StrCat( + "Only `SavedModel`s with 1 MetaGraph are supported. Instead, it has ", + saved_model.meta_graphs_size())); + } + tf_metadata.tf_version = GetTfVersion(saved_model); const tensorflow::MetaGraphDef::MetaInfoDef& meta_info_def = saved_model.meta_graphs()[0].meta_info_def(); - tags_str = absl::StrJoin(meta_info_def.tags(), ","); + tf_metadata.saved_model_tags = absl::StrJoin(meta_info_def.tags(), ","); // Only TF2 model needs it, TF1 model can use empty exported_names to apply // default values. - if (tf_version == 2) { + if (tf_metadata.tf_version == 2) { const tensorflow::SavedObjectGraph& object_graph_def = saved_model.meta_graphs()[0].object_graph_def(); const auto& saved_object_nodes = object_graph_def.nodes(); // According to saved_object_graph.proto, nodes[0] indicates root node. for (const auto& child : object_graph_def.nodes()[0].children()) { if (saved_object_nodes[child.node_id()].has_function()) { - exported_names.push_back(child.local_name()); + tf_metadata.exported_names_vector.push_back(child.local_name()); } } } - return absl::OkStatus(); + return tf_metadata; } absl::Status DeserializeVhloToStablehlo(mlir::ModuleOp module_op) { @@ -223,28 +237,17 @@ mlir::WalkResult PrintErrorAndInterupt(const absl::Status& status) { absl::StatusOr ConvertSavedModelToJson( const VisualizeConfig& config, absl::string_view model_path) { - tensorflow::SavedModel saved_model; - RETURN_IF_ERROR(tensorflow::ReadSavedModel(model_path, &saved_model)); - if (saved_model.meta_graphs_size() != 1) { - return absl::InvalidArgumentError(absl::StrCat( - "Only `SavedModel`s with 1 MetaGraph are supported. Instead, it has ", - saved_model.meta_graphs_size())); - } - const int tf_version = GetTfVersion(saved_model); - std::string saved_model_tags; - std::vector exported_names_vector; - - RETURN_IF_ERROR(AssignTagsAndExportedNames( - saved_model, tf_version, saved_model_tags, exported_names_vector)); - std::unordered_set tags = absl::StrSplit(saved_model_tags, ','); - absl::Span exported_names(exported_names_vector); + ASSIGN_OR_RETURN(TfMetadata tf_metadata, ObtainTfMetadata(model_path)); + std::unordered_set tags = + absl::StrSplit(tf_metadata.saved_model_tags, ','); + absl::Span exported_names(tf_metadata.exported_names_vector); mlir::MLIRContext context; // Enable parsing of MLIR modules with unregistered dialects. This is safe as // Model Explorer does not execute operations, only visualizes them. context.allowUnregisteredDialects(true); mlir::OwningOpRef module_op; - if (tf_version == 1) { + if (tf_metadata.tf_version == 1) { LOG(INFO) << "Converting SavedModel V1 to MLIR module..."; tensorflow::MLIRImportOptions import_options; import_options.upgrade_legacy = true;