diff --git a/tools/compile_tf_graph.py b/tools/compile_tf_graph.py index 569e9ff36e..e3c6652efb 100755 --- a/tools/compile_tf_graph.py +++ b/tools/compile_tf_graph.py @@ -72,7 +72,7 @@ def main(argv): argparser.add_argument('--search', type=int, default=0, help='beam search. 0 disable (default), 1 enable') argparser.add_argument("--verbosity", default=4, type=int, help="5 for all seqs (default: 4)") argparser.add_argument("--summaries_tensor_name") - argparser.add_argument("--output_file", help='output pb or pbtxt file') + argparser.add_argument("--output_file", help='output pb, pbtxt or meta, metatxt file') argparser.add_argument("--output_file_model_params_list", help="line-based, names of model params") argparser.add_argument("--output_file_state_vars_list", help="line-based, name of state vars") args = argparser.parse_args(argv[1:]) @@ -107,21 +107,28 @@ def main(argv): assert isinstance(summaries_tensor, tf.Tensor), "no summaries in the graph?" tf.identity(summaries_tensor, name=args.summaries_tensor_name) + if args.output_file and os.path.splitext(args.output_file)[1] in [".meta", ".metatxt"]: + # https://www.tensorflow.org/api_guides/python/meta_graph + saver = tf.train.Saver( + var_list=network.get_saveable_params_list(), max_to_keep=2 ** 31 - 1) + graph_def = saver.export_meta_graph() + else: + graph_def = graph.as_graph_def(add_shapes=True) + print("Graph collection keys:", graph.get_all_collection_keys()) print("Graph num operations:", len(graph.get_operations())) - graph_def = graph.as_graph_def(add_shapes=True) print("Graph def size:", Util.human_bytes_size(graph_def.ByteSize())) if args.output_file: filename = args.output_file _, ext = os.path.splitext(filename) - assert ext in [".pb", ".pbtxt"], 'filename %r extension should be pb or pbtxt' % filename + assert ext in [".pb", ".pbtxt", ".meta", ".metatxt"], 'filename %r extension invalid' % filename print("Write graph to file:", filename) graph_io.write_graph( graph_def, logdir=os.path.dirname(filename), name=os.path.basename(filename), - as_text=(ext == ".pbtxt")) + as_text=ext.endswith("txt")) else: print("Use --output_file if you want to store the graph.")