Reduce memory usage in ConvertSavedModelToJson
Improved memory management in `ConvertSavedModelToJson` when handling large SavedModels. Refactored the initial proto loading into a dedicated function that extracts the required information and then immediately releases the proto from memory, so the proto no longer occupies memory during subsequent operations.

PiperOrigin-RevId: 655218370
yijie-yang authored and copybara-github committed Jul 23, 2024
1 parent e0b60aa commit 1fa883a
Showing 1 changed file with 27 additions and 24 deletions.
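The commit message describes a scoping pattern: read the heavyweight `SavedModel` proto inside a helper, copy out only the small pieces of metadata that later stages need, and let the proto be destroyed when the helper returns. Below is a minimal, self-contained sketch of that pattern, using hypothetical stand-in types (`BigModelProto`, `Metadata`, `LoadMetadata`) rather than the real TensorFlow classes that appear in the diff further down.

```cpp
// Sketch of the memory-scoping pattern from this commit, with hypothetical
// stand-ins for the TensorFlow proto types. The large object is a local of
// the helper, so its memory is released before any later work runs.
#include <iostream>
#include <string>
#include <vector>

// Stand-in for a large deserialized proto (hypothetical, ~100 MB payload).
struct BigModelProto {
  std::vector<char> blob = std::vector<char>(100 * 1024 * 1024);
  int version = 2;
  std::string tags = "serve";
};

// Small summary that outlives the proto (analogous to TfMetadata).
struct Metadata {
  int version;
  std::string tags;
};

// The proto lives only inside this function; it is destroyed when the
// function returns, so callers never keep the large blob resident.
Metadata LoadMetadata(const std::string& model_path) {
  (void)model_path;  // In the real code the proto is read from this path;
                     // here it is default-constructed for illustration.
  BigModelProto proto;
  return Metadata{proto.version, proto.tags};
}

int main() {
  const Metadata meta = LoadMetadata("/tmp/saved_model");
  // Heavy follow-up conversion work would happen here without the proto
  // still occupying memory.
  std::cout << "version=" << meta.version << " tags=" << meta.tags << "\n";
  return 0;
}
```

Returning the small struct by value is what makes the pattern work: the proto's lifetime ends at the helper's closing brace, and only the copied metadata escapes to the caller.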
51 changes: 27 additions & 24 deletions src/builtin-adapter/model_json_graph_convert.cc
@@ -88,6 +88,12 @@ enum class MlirDialect {
   kStablehlo,
 };
 
+struct TfMetadata {
+  int tf_version;
+  std::string saved_model_tags;
+  std::vector<std::string> exported_names_vector;
+};
+
 // Referred logic from lite/python/flatbuffer_to_mlir.cc.
 static mlir::OwningOpRef<mlir::ModuleOp> FlatBufferFileToMlirTranslation(
     llvm::SourceMgr* source_mgr, mlir::MLIRContext* context) {
@@ -124,28 +130,36 @@ static int GetTfVersion(const tensorflow::SavedModel& saved_model) {
   return 2;
 }
 
-// Obtains saved model tags and exported names from SavedModel proto.
-absl::Status AssignTagsAndExportedNames(
-    const tensorflow::SavedModel& saved_model, const int tf_version,
-    std::string& tags_str, std::vector<std::string>& exported_names) {
+// Obtains TF metadata from the given SavedModel path.
+// The metadata contains the tf version, saved model tags and exported names.
+absl::StatusOr<TfMetadata> ObtainTfMetadata(absl::string_view model_path) {
+  TfMetadata tf_metadata;
+  tensorflow::SavedModel saved_model;
+  RETURN_IF_ERROR(tensorflow::ReadSavedModel(model_path, &saved_model));
+  if (saved_model.meta_graphs_size() != 1) {
+    return absl::InvalidArgumentError(absl::StrCat(
+        "Only `SavedModel`s with 1 MetaGraph are supported. Instead, it has ",
+        saved_model.meta_graphs_size()));
+  }
+  tf_metadata.tf_version = GetTfVersion(saved_model);
   const tensorflow::MetaGraphDef::MetaInfoDef& meta_info_def =
       saved_model.meta_graphs()[0].meta_info_def();
-  tags_str = absl::StrJoin(meta_info_def.tags(), ",");
+  tf_metadata.saved_model_tags = absl::StrJoin(meta_info_def.tags(), ",");
 
   // Only TF2 model needs it, TF1 model can use empty exported_names to apply
   // default values.
-  if (tf_version == 2) {
+  if (tf_metadata.tf_version == 2) {
     const tensorflow::SavedObjectGraph& object_graph_def =
         saved_model.meta_graphs()[0].object_graph_def();
     const auto& saved_object_nodes = object_graph_def.nodes();
     // According to saved_object_graph.proto, nodes[0] indicates root node.
     for (const auto& child : object_graph_def.nodes()[0].children()) {
       if (saved_object_nodes[child.node_id()].has_function()) {
-        exported_names.push_back(child.local_name());
+        tf_metadata.exported_names_vector.push_back(child.local_name());
       }
     }
   }
-  return absl::OkStatus();
+  return tf_metadata;
 }
 
 absl::Status DeserializeVhloToStablehlo(mlir::ModuleOp module_op) {
@@ -223,28 +237,17 @@ mlir::WalkResult PrintErrorAndInterupt(const absl::Status& status) {
 
 absl::StatusOr<std::string> ConvertSavedModelToJson(
     const VisualizeConfig& config, absl::string_view model_path) {
-  tensorflow::SavedModel saved_model;
-  RETURN_IF_ERROR(tensorflow::ReadSavedModel(model_path, &saved_model));
-  if (saved_model.meta_graphs_size() != 1) {
-    return absl::InvalidArgumentError(absl::StrCat(
-        "Only `SavedModel`s with 1 MetaGraph are supported. Instead, it has ",
-        saved_model.meta_graphs_size()));
-  }
-  const int tf_version = GetTfVersion(saved_model);
-  std::string saved_model_tags;
-  std::vector<std::string> exported_names_vector;
-
-  RETURN_IF_ERROR(AssignTagsAndExportedNames(
-      saved_model, tf_version, saved_model_tags, exported_names_vector));
-  std::unordered_set<std::string> tags = absl::StrSplit(saved_model_tags, ',');
-  absl::Span<std::string> exported_names(exported_names_vector);
+  ASSIGN_OR_RETURN(TfMetadata tf_metadata, ObtainTfMetadata(model_path));
+  std::unordered_set<std::string> tags =
+      absl::StrSplit(tf_metadata.saved_model_tags, ',');
+  absl::Span<std::string> exported_names(tf_metadata.exported_names_vector);
 
   mlir::MLIRContext context;
   // Enable parsing of MLIR modules with unregistered dialects. This is safe as
   // Model Explorer does not execute operations, only visualizes them.
   context.allowUnregisteredDialects(true);
   mlir::OwningOpRef<mlir::ModuleOp> module_op;
-  if (tf_version == 1) {
+  if (tf_metadata.tf_version == 1) {
     LOG(INFO) << "Converting SavedModel V1 to MLIR module...";
     tensorflow::MLIRImportOptions import_options;
     import_options.upgrade_legacy = true;
