Skip to content

Commit

Permalink
Add export inference_graph (tensorflow#1702)
Browse files Browse the repository at this point in the history
* Add export inference_graph

* Update Readme.md to include export_inference_graph
  • Loading branch information
sguada authored Jun 20, 2017
1 parent f305d2d commit 434c277
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 1 deletion.
23 changes: 23 additions & 0 deletions slim/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,26 @@ py_binary(
":preprocessing_factory",
],
)

py_binary(
name = "export_inference_graph",
srcs = ["export_inference_graph.py"],
deps = [
":dataset_factory",
":nets_factory",
],
)

py_test(
name = "export_inference_graph_test",
size = "medium",
srcs = ["export_inference_graph_test.py"],
srcs_version = "PY2AND3",
tags = [
"manual",
],
deps = [
":export_inference_graph",
":nets_factory",
],
)
67 changes: 66 additions & 1 deletion slim/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ Maintainers of TF-slim:
<a href='#Training'>Training from scratch</a><br>
<a href='#Tuning'>Fine tuning to a new task</a><br>
<a href='#Eval'>Evaluating performance</a><br>
<a href='#Export'>Exporting Inference Graph</a><br>
<a href='#Troubleshooting'>Troubleshooting</a><br>

# Installation
<a id='Install'></a>
Expand Down Expand Up @@ -204,7 +206,6 @@ Model | TF-Slim File | Checkpoint | Top-1 Accuracy| Top-5 Accuracy |
[MobileNet_v1_1.0_224](https://arxiv.org/pdf/1704.04861.pdf)|[Code](https://github.com/tensorflow/models/blob/master/slim/nets/mobilenet_v1.py)|[mobilenet_v1_1.0_224_2017_06_14.tar.gz](http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz)|70.7|89.5|
[MobileNet_v1_0.50_160](https://arxiv.org/pdf/1704.04861.pdf)|[Code](https://github.com/tensorflow/models/blob/master/slim/nets/mobilenet_v1.py)|[mobilenet_v1_0.50_160_2017_06_14.tar.gz](http://download.tensorflow.org/models/mobilenet_v1_0.50_160_2017_06_14.tar.gz)|59.9|82.5|
[MobileNet_v1_0.25_128](https://arxiv.org/pdf/1704.04861.pdf)|[Code](https://github.com/tensorflow/models/blob/master/slim/nets/mobilenet_v1.py)|[mobilenet_v1_0.25_128_2017_06_14.tar.gz](http://download.tensorflow.org/models/mobilenet_v1_0.25_128_2017_06_14.tar.gz)|41.3|66.2|

^ ResNet V2 models use Inception pre-processing and input image size of 299 (use
`--preprocessing_name inception --eval_image_size 299` when using
`eval_image_classifier.py`). Performance numbers for ResNet V2 models are
Expand Down Expand Up @@ -327,8 +328,72 @@ $ python eval_image_classifier.py \
```


# Exporting the Inference Graph
<a id='Export'></a>

Saves out a GraphDef containing the architecture of the model.

To use it with a model name defined by slim, run:

```shell
$ python export_inference_graph.py \
--alsologtostderr \
--model_name=inception_v3 \
--output_file=/tmp/inception_v3_inf_graph.pb

$ python export_inference_graph.py \
--alsologtostderr \
--model_name=mobilenet_v1 \
--image_size=224 \
--output_file=/tmp/mobilenet_v1_224.pb
```

## Freezing the exported Graph
If you then want to use the resulting model with your own or pretrained
checkpoints as part of a mobile model, you can run freeze_graph to get a graph
def with the variables inlined as constants using:

```shell
bazel build tensorflow/python/tools:freeze_graph

bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/tmp/inception_v3_inf_graph.pb \
--input_checkpoint=/tmp/checkpoints/inception_v3.ckpt \
--input_binary=true --output_graph=/tmp/frozen_inception_v3.pb \
--output_node_names=InceptionV3/Predictions/Reshape_1
```

The output node names will vary depending on the model, but you can inspect and
estimate them using the summarize_graph tool:

```shell
bazel build tensorflow/tools/graph_transforms:summarize_graph

bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
--in_graph=/tmp/inception_v3_inf_graph.pb
```

## Run label image in C++

To run the resulting graph in C++, you can look at the label_image sample code:

```shell
bazel build tensorflow/examples/label_image:label_image

bazel-bin/tensorflow/examples/label_image/label_image \
--image=${HOME}/Pictures/flowers.jpg \
--input_layer=input \
--output_layer=InceptionV3/Predictions/Reshape_1 \
--graph=/tmp/frozen_inception_v3.pb \
--labels=/tmp/imagenet_slim_labels.txt \
--input_mean=0 \
--input_std=255 \
--logtostderr
```


# Troubleshooting
<a id='Troubleshooting'></a>

#### The model runs out of CPU memory.

Expand Down
122 changes: 122 additions & 0 deletions slim/export_inference_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Saves out a GraphDef containing the architecture of the model.
To use it, run something like this, with a model name defined by slim:
bazel build tensorflow_models/slim:export_inference_graph
bazel-bin/tensorflow_models/slim/export_inference_graph \
--model_name=inception_v3 --output_file=/tmp/inception_v3_inf_graph.pb
If you then want to use the resulting model with your own or pretrained
checkpoints as part of a mobile model, you can run freeze_graph to get a graph
def with the variables inlined as constants using:
bazel build tensorflow/python/tools:freeze_graph
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/tmp/inception_v3_inf_graph.pb \
--input_checkpoint=/tmp/checkpoints/inception_v3.ckpt \
--input_binary=true --output_graph=/tmp/frozen_inception_v3.pb \
--output_node_names=InceptionV3/Predictions/Reshape_1
The output node names will vary depending on the model, but you can inspect and
estimate them using the summarize_graph tool:
bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
--in_graph=/tmp/inception_v3_inf_graph.pb
To run the resulting graph in C++, you can look at the label_image sample code:
bazel build tensorflow/examples/label_image:label_image
bazel-bin/tensorflow/examples/label_image/label_image \
--image=${HOME}/Pictures/flowers.jpg \
--input_layer=input \
--output_layer=InceptionV3/Predictions/Reshape_1 \
--graph=/tmp/frozen_inception_v3.pb \
--labels=/tmp/imagenet_slim_labels.txt \
--input_mean=0 \
--input_std=255 \
--logtostderr
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.python.platform import gfile
from datasets import dataset_factory
from nets import nets_factory


slim = tf.contrib.slim

tf.app.flags.DEFINE_string(
'model_name', 'inception_v3', 'The name of the architecture to save.')

tf.app.flags.DEFINE_boolean(
'is_training', False,
'Whether to save out a training-focused version of the model.')

tf.app.flags.DEFINE_integer(
'default_image_size', 224,
'The image size to use if the model does not define it.')

tf.app.flags.DEFINE_string('dataset_name', 'imagenet',
'The name of the dataset to use with the model.')

tf.app.flags.DEFINE_integer(
'labels_offset', 0,
'An offset for the labels in the dataset. This flag is primarily used to '
'evaluate the VGG and ResNet architectures which do not use a background '
'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
'output_file', '', 'Where to save the resulting file to.')

tf.app.flags.DEFINE_string(
'dataset_dir', '', 'Directory to save intermediate dataset files to')

FLAGS = tf.app.flags.FLAGS


def main(_):
if not FLAGS.output_file:
raise ValueError('You must supply the path to save to with --output_file')
tf.logging.set_verbosity(tf.logging.INFO)
with tf.Graph().as_default() as graph:
dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'validation',
FLAGS.dataset_dir)
network_fn = nets_factory.get_network_fn(
FLAGS.model_name,
num_classes=(dataset.num_classes - FLAGS.labels_offset),
is_training=FLAGS.is_training)
if hasattr(network_fn, 'default_image_size'):
image_size = network_fn.default_image_size
else:
image_size = FLAGS.default_image_size
placeholder = tf.placeholder(name='input', dtype=tf.float32,
shape=[1, image_size, image_size, 3])
network_fn(placeholder)
graph_def = graph.as_graph_def()
with gfile.GFile(FLAGS.output_file, 'wb') as f:
f.write(graph_def.SerializeToString())


if __name__ == '__main__':
tf.app.run()
44 changes: 44 additions & 0 deletions slim/export_inference_graph_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for export_inference_graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os


import tensorflow as tf

from tensorflow.python.platform import gfile
from google3.third_party.tensorflow_models.slim import export_inference_graph


class ExportInferenceGraphTest(tf.test.TestCase):

def testExportInferenceGraph(self):
tmpdir = self.get_temp_dir()
output_file = os.path.join(tmpdir, 'inception_v3.pb')
flags = tf.app.flags.FLAGS
flags.output_file = output_file
flags.model_name = 'inception_v3'
flags.dataset_dir = tmpdir
export_inference_graph.main(None)
self.assertTrue(gfile.Exists(output_file))

if __name__ == '__main__':
tf.test.main()

0 comments on commit 434c277

Please sign in to comment.