diff --git a/src/builtin-adapter/BUILD b/src/builtin-adapter/BUILD index b16c391d..38d986e1 100644 --- a/src/builtin-adapter/BUILD +++ b/src/builtin-adapter/BUILD @@ -145,6 +145,7 @@ cc_binary( ":models_to_json_lib", ":visualize_config", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@org_tensorflow//tensorflow/compiler/mlir:init_mlir", "@org_tensorflow//tensorflow/compiler/mlir/lite/tools:command_line_flags", diff --git a/src/builtin-adapter/models_to_json_main.cc b/src/builtin-adapter/models_to_json_main.cc index 4290486b..1ba9b034 100644 --- a/src/builtin-adapter/models_to_json_main.cc +++ b/src/builtin-adapter/models_to_json_main.cc @@ -17,6 +17,7 @@ #include #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "tensorflow/compiler/mlir/init_mlir.h" #include "models_to_json_lib.h" @@ -47,9 +48,9 @@ int main(int argc, char* argv[]) { std::vector flag_list = { mlir::Flag::CreateFlag(kInputFileFlag, &input_file, "Input filename or directory", - mlir::Flag::kRequired), + mlir::Flag::kOptional), mlir::Flag::CreateFlag(kOutputFileFlag, &output_file, "Output filename", - mlir::Flag::kRequired), + mlir::Flag::kOptional), mlir::Flag::CreateFlag( kConstElementCountLimitFlag, &const_element_count_limit, "The maximum number of constant elements. If the number exceeds this " @@ -69,6 +70,11 @@ int main(int argc, char* argv[]) { return 1; } + if (output_file.empty()) { + LOG(ERROR) << "Output filename cannot be empty."; + return 1; + } + if (output_file.substr(output_file.size() - 4, 4) != "json") { LOG(WARNING) << "Please specify output format to be JSON."; }