-
Notifications
You must be signed in to change notification settings - Fork 0
/
freeze_graph_test.py
43 lines (39 loc) · 1.63 KB
/
freeze_graph_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph_test.py
# Not very useful on its own: a minimal standalone demo of freeze_graph.
import os

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.tools import freeze_graph
# Build a tiny graph (a variable holding 1.0, multiplied by 2), checkpoint it,
# write its GraphDef, and then freeze variables into constants with
# freeze_graph. All artifacts are written under `model_dir`.
model_dir = "models"
checkpoint_prefix = os.path.join(model_dir, "saved_checkpoint")
checkpoint_state_name = "checkpoint_state"
input_graph_name = "input_graph.pb"
output_graph_name = "output_graph.pb"

# Saver.save() refuses to write into a non-existent directory, so create the
# output directory up front (MakeDirs succeeds if it already exists).
tf.gfile.MakeDirs(model_dir)

# We'll create an input graph that has a single variable containing 1.0,
# and that then multiplies it by 2.
with ops.Graph().as_default():
    variable_node = tf.Variable(1.0, name="variable_node")
    output_node = tf.multiply(variable_node, 2.0, name="output_node")
    # Using the session as a context manager guarantees it gets closed.
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        output = sess.run(output_node)
        # Sanity-check the graph computes 1.0 * 2.0 before we freeze it.
        if abs(output - 2.0) > 1e-5:
            raise RuntimeError(
                "Unexpected graph output {!r} (expected 2.0)".format(output))
        saver = tf.train.Saver()
        checkpoint_path = saver.save(
            sess,
            checkpoint_prefix,
            global_step=0,
            latest_filename=checkpoint_state_name)
        # write_graph returns the path of the file it wrote.
        input_graph_path = tf.train.write_graph(
            sess.graph, model_dir, input_graph_name)

# We save out the graph to disk, and then call the const conversion
# routine. As in the upstream test, this happens outside the graph context
# above rather than inside the graph we just built.
input_saver_def_path = ""  # empty: no separate SaverDef file is supplied
input_binary = False  # write_graph above used its default text format
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_node_names = "output_node"
output_graph_path = os.path.join(model_dir, output_graph_name)
clear_devices = False
# The trailing "" is the initializer_nodes argument (none needed here).
freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                          input_binary, checkpoint_path, output_node_names,
                          restore_op_name, filename_tensor_name,
                          output_graph_path, clear_devices, "")