diff --git a/slim/BUILD b/slim/BUILD index 348ca75956..bc38704a36 100644 --- a/slim/BUILD +++ b/slim/BUILD @@ -390,3 +390,26 @@ py_binary( ":preprocessing_factory", ], ) + +py_binary( + name = "export_inference_graph", + srcs = ["export_inference_graph.py"], + deps = [ + ":dataset_factory", + ":nets_factory", + ], +) + +py_test( + name = "export_inference_graph_test", + size = "medium", + srcs = ["export_inference_graph_test.py"], + srcs_version = "PY2AND3", + tags = [ + "manual", + ], + deps = [ + ":export_inference_graph", + ":nets_factory", + ], +) diff --git a/slim/README.md b/slim/README.md index 6fe5a71836..179b806065 100644 --- a/slim/README.md +++ b/slim/README.md @@ -32,6 +32,8 @@ Maintainers of TF-slim: Training from scratch
Fine tuning to a new task
Evaluating performance
+Exporting Inference Graph
+Troubleshooting
# Installation @@ -204,7 +206,6 @@ Model | TF-Slim File | Checkpoint | Top-1 Accuracy| Top-5 Accuracy | [MobileNet_v1_1.0_224](https://arxiv.org/pdf/1704.04861.pdf)|[Code](https://github.com/tensorflow/models/blob/master/slim/nets/mobilenet_v1.py)|[mobilenet_v1_1.0_224_2017_06_14.tar.gz](http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz)|70.7|89.5| [MobileNet_v1_0.50_160](https://arxiv.org/pdf/1704.04861.pdf)|[Code](https://github.com/tensorflow/models/blob/master/slim/nets/mobilenet_v1.py)|[mobilenet_v1_0.50_160_2017_06_14.tar.gz](http://download.tensorflow.org/models/mobilenet_v1_0.50_160_2017_06_14.tar.gz)|59.9|82.5| [MobileNet_v1_0.25_128](https://arxiv.org/pdf/1704.04861.pdf)|[Code](https://github.com/tensorflow/models/blob/master/slim/nets/mobilenet_v1.py)|[mobilenet_v1_0.25_128_2017_06_14.tar.gz](http://download.tensorflow.org/models/mobilenet_v1_0.25_128_2017_06_14.tar.gz)|41.3|66.2| - ^ ResNet V2 models use Inception pre-processing and input image size of 299 (use `--preprocessing_name inception --eval_image_size 299` when using `eval_image_classifier.py`). Performance numbers for ResNet V2 models are @@ -327,8 +328,72 @@ $ python eval_image_classifier.py \ ``` +# Exporting the Inference Graph + + +Saves out a GraphDef containing the architecture of the model. + +To use it with a model name defined by slim, run: + +```shell +$ python export_inference_graph.py \ + --alsologtostderr \ + --model_name=inception_v3 \ + --output_file=/tmp/inception_v3_inf_graph.pb + +$ python export_inference_graph.py \ + --alsologtostderr \ + --model_name=mobilenet_v1 \ + --image_size=224 \ + --output_file=/tmp/mobilenet_v1_224.pb +``` + +## Freezing the exported Graph +If you then want to use the resulting model with your own or pretrained +checkpoints as part of a mobile model, you can run freeze_graph to get a graph +def with the variables inlined as constants using: + +```shell +bazel build tensorflow/python/tools:freeze_graph + +bazel-bin/tensorflow/python/tools/freeze_graph \ + --input_graph=/tmp/inception_v3_inf_graph.pb \ + --input_checkpoint=/tmp/checkpoints/inception_v3.ckpt \ + --input_binary=true --output_graph=/tmp/frozen_inception_v3.pb \ + --output_node_names=InceptionV3/Predictions/Reshape_1 +``` + +The output node names will vary depending on the model, but you can inspect and +estimate them using the summarize_graph tool: + +```shell +bazel build tensorflow/tools/graph_transforms:summarize_graph + +bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \ + --in_graph=/tmp/inception_v3_inf_graph.pb +``` + +## Run label image in C++ + +To run the resulting graph in C++, you can look at the label_image sample code: + +```shell +bazel build tensorflow/examples/label_image:label_image + +bazel-bin/tensorflow/examples/label_image/label_image \ + --image=${HOME}/Pictures/flowers.jpg \ + --input_layer=input \ + --output_layer=InceptionV3/Predictions/Reshape_1 \ + --graph=/tmp/frozen_inception_v3.pb \ + --labels=/tmp/imagenet_slim_labels.txt \ + --input_mean=0 \ + --input_std=255 \ + --logtostderr +``` + # Troubleshooting + #### The model runs out of CPU memory. diff --git a/slim/export_inference_graph.py b/slim/export_inference_graph.py new file mode 100644 index 0000000000..13f10ce003 --- /dev/null +++ b/slim/export_inference_graph.py @@ -0,0 +1,122 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Saves out a GraphDef containing the architecture of the model. + +To use it, run something like this, with a model name defined by slim: + +bazel build tensorflow_models/slim:export_inference_graph +bazel-bin/tensorflow_models/slim/export_inference_graph \ +--model_name=inception_v3 --output_file=/tmp/inception_v3_inf_graph.pb + +If you then want to use the resulting model with your own or pretrained +checkpoints as part of a mobile model, you can run freeze_graph to get a graph +def with the variables inlined as constants using: + +bazel build tensorflow/python/tools:freeze_graph +bazel-bin/tensorflow/python/tools/freeze_graph \ +--input_graph=/tmp/inception_v3_inf_graph.pb \ +--input_checkpoint=/tmp/checkpoints/inception_v3.ckpt \ +--input_binary=true --output_graph=/tmp/frozen_inception_v3.pb \ +--output_node_names=InceptionV3/Predictions/Reshape_1 + +The output node names will vary depending on the model, but you can inspect and +estimate them using the summarize_graph tool: + +bazel build tensorflow/tools/graph_transforms:summarize_graph +bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \ +--in_graph=/tmp/inception_v3_inf_graph.pb + +To run the resulting graph in C++, you can look at the label_image sample code: + +bazel build tensorflow/examples/label_image:label_image +bazel-bin/tensorflow/examples/label_image/label_image \ +--image=${HOME}/Pictures/flowers.jpg \ +--input_layer=input \ +--output_layer=InceptionV3/Predictions/Reshape_1 \ +--graph=/tmp/frozen_inception_v3.pb \ +--labels=/tmp/imagenet_slim_labels.txt \ +--input_mean=0 \ +--input_std=255 \ +--logtostderr + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.python.platform import gfile +from datasets import dataset_factory +from nets import nets_factory + + +slim = tf.contrib.slim + +tf.app.flags.DEFINE_string( + 'model_name', 'inception_v3', 'The name of the architecture to save.') + +tf.app.flags.DEFINE_boolean( + 'is_training', False, + 'Whether to save out a training-focused version of the model.') + +tf.app.flags.DEFINE_integer( + 'default_image_size', 224, + 'The image size to use if the model does not define it.') + +tf.app.flags.DEFINE_string('dataset_name', 'imagenet', + 'The name of the dataset to use with the model.') + +tf.app.flags.DEFINE_integer( + 'labels_offset', 0, + 'An offset for the labels in the dataset. This flag is primarily used to ' + 'evaluate the VGG and ResNet architectures which do not use a background ' + 'class for the ImageNet dataset.') + +tf.app.flags.DEFINE_string( + 'output_file', '', 'Where to save the resulting file to.') + +tf.app.flags.DEFINE_string( + 'dataset_dir', '', 'Directory to save intermediate dataset files to') + +FLAGS = tf.app.flags.FLAGS + + +def main(_): + if not FLAGS.output_file: + raise ValueError('You must supply the path to save to with --output_file') + tf.logging.set_verbosity(tf.logging.INFO) + with tf.Graph().as_default() as graph: + dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'validation', + FLAGS.dataset_dir) + network_fn = nets_factory.get_network_fn( + FLAGS.model_name, + num_classes=(dataset.num_classes - FLAGS.labels_offset), + is_training=FLAGS.is_training) + if hasattr(network_fn, 'default_image_size'): + image_size = network_fn.default_image_size + else: + image_size = FLAGS.default_image_size + placeholder = tf.placeholder(name='input', dtype=tf.float32, + shape=[1, image_size, image_size, 3]) + network_fn(placeholder) + graph_def = graph.as_graph_def() + with gfile.GFile(FLAGS.output_file, 'wb') as f: + f.write(graph_def.SerializeToString()) + + +if __name__ == '__main__': + tf.app.run() diff --git a/slim/export_inference_graph_test.py b/slim/export_inference_graph_test.py new file mode 100644 index 0000000000..a730e67e58 --- /dev/null +++ b/slim/export_inference_graph_test.py @@ -0,0 +1,44 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for export_inference_graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + + +import tensorflow as tf + +from tensorflow.python.platform import gfile +from google3.third_party.tensorflow_models.slim import export_inference_graph + + +class ExportInferenceGraphTest(tf.test.TestCase): + + def testExportInferenceGraph(self): + tmpdir = self.get_temp_dir() + output_file = os.path.join(tmpdir, 'inception_v3.pb') + flags = tf.app.flags.FLAGS + flags.output_file = output_file + flags.model_name = 'inception_v3' + flags.dataset_dir = tmpdir + export_inference_graph.main(None) + self.assertTrue(gfile.Exists(output_file)) + +if __name__ == '__main__': + tf.test.main()