diff --git a/real_nvp/README.md b/real_nvp/README.md new file mode 100644 index 00000000000..20eda29ec3d --- /dev/null +++ b/real_nvp/README.md @@ -0,0 +1,278 @@ +# Real NVP in TensorFlow + +*A TensorFlow implementation of the training procedure of* +[*Density estimation using Real NVP*](https://arxiv.org/abs/1605.08803)*, by +Laurent Dinh, Jascha Sohl-Dickstein and Samy Bengio, for Imagenet +(32x32 and 64x64), CelebA and LSUN, including the scripts to +put the datasets in `.tfrecords` format.* + +We are happy to open source the code for *Real NVP*, a novel approach to +density estimation using deep neural networks that enables tractable density +estimation and efficient one-pass inference and sampling. This model +successfully decomposes images into hierarchical features ranging from +high-level concepts to low-resolution details. Visualizations are available +[here](http://goo.gl/yco14s). + +## Installation +* python 2.7: + * python 3 support is not available yet +* pip (python package manager) + * `apt-get install python-pip` on Ubuntu + * `brew` installs pip along with python on OSX +* Install the dependencies for [LSUN](https://github.com/fyu/lsun.git) + * Install [OpenCV](http://opencv.org/) + * `pip install numpy lmdb` +* Install the python dependencies + * `pip install scipy scikit-image Pillow` +* Install the +[latest TensorFlow pip package](https://www.tensorflow.org/get_started/os_setup.html#using-pip) +for Python 2.7 + +## Getting Started +Once you have successfully installed the dependencies, you can start by +downloading the repository: +```shell +git clone --recursive https://github.com/tensorflow/models.git +``` +Afterward, you can use the utilities in this folder to prepare the datasets. 
+ +## Preparing datasets +### CelebA +For [*CelebA*](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html), download +`img_align_celeba.zip` from the Dropbox link on this +[page](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) under the +link *Align&Cropped Images* in the *Img* directory and `list_eval_partition.txt` +under the link *Train/Val/Test Partitions* in the *Eval* directory. Then do: + +```shell +mkdir celeba +cd celeba +unzip img_align_celeba.zip +``` + +We'll format the training subset: +```shell +python2.7 ../models/real_nvp/celeba_formatting.py \ + --partition_fn list_eval_partition.txt \ + --file_out celeba_train \ + --fn_root img_align_celeba \ + --set 0 +``` + +Then the validation subset: +```shell +python2.7 ../models/real_nvp/celeba_formatting.py \ + --partition_fn list_eval_partition.txt \ + --file_out celeba_valid \ + --fn_root img_align_celeba \ + --set 1 +``` + +And finally the test subset: +```shell +python2.7 ../models/real_nvp/celeba_formatting.py \ + --partition_fn list_eval_partition.txt \ + --file_out celeba_test \ + --fn_root img_align_celeba \ + --set 2 +``` + +Afterward: +```shell +cd .. +``` + +### Small Imagenet +Downloading the [*small Imagenet*](http://image-net.org/small/download.php) +dataset is more straightforward and can be done +entirely in the shell: +```shell +mkdir small_imnet +cd small_imnet +for FILENAME in train_32x32.tar valid_32x32.tar train_64x64.tar valid_64x64.tar +do + curl -O http://image-net.org/small/$FILENAME + tar -xvf $FILENAME +done +``` + +Then, you can format the datasets as follows: +```shell +for DIRNAME in train_32x32 valid_32x32 train_64x64 valid_64x64 +do + python2.7 ../models/real_nvp/imnet_formatting.py \ + --file_out $DIRNAME \ + --fn_root $DIRNAME +done +cd .. 
+``` + +### LSUN +To prepare the [*LSUN*](http://lsun.cs.princeton.edu/2016/) dataset, we will +need to use the associated code: +```shell +git clone https://github.com/fyu/lsun.git +cd lsun +``` +Then we'll download the db files: +```shell +for CATEGORY in bedroom church_outdoor tower +do + python2.7 download.py -c $CATEGORY + unzip "$CATEGORY"_train_lmdb.zip + unzip "$CATEGORY"_val_lmdb.zip + python2.7 data.py export "$CATEGORY"_train_lmdb \ + --out_dir "$CATEGORY"_train --flat + python2.7 data.py export "$CATEGORY"_val_lmdb \ + --out_dir "$CATEGORY"_val --flat +done +``` + +Finally, we format the dataset into `.tfrecords`: +```shell +for CATEGORY in bedroom church_outdoor tower +do + python2.7 ../models/real_nvp/lsun_formatting.py \ + --file_out "$CATEGORY"_train \ + --fn_root "$CATEGORY"_train + python2.7 ../models/real_nvp/lsun_formatting.py \ + --file_out "$CATEGORY"_val \ + --fn_root "$CATEGORY"_val +done +cd .. +``` + + +## Training +We'll give an example of how to train a model on the small Imagenet +dataset (32x32): +```shell +cd models/real_nvp/ +python2.7 real_nvp_multiscale_dataset.py \ +--image_size 32 \ +--hpconfig=n_scale=4,base_dim=32,clip_gradient=100,residual_blocks=4 \ +--dataset imnet \ +--traindir /tmp/real_nvp_imnet32/train \ +--logdir /tmp/real_nvp_imnet32/train \ +--data_path ../../small_imnet/train_32x32_?????.tfrecords +``` +In parallel, you can run the script to generate visualizations from the model: +```shell +python2.7 real_nvp_multiscale_dataset.py \ +--image_size 32 \ +--hpconfig=n_scale=4,base_dim=32,clip_gradient=100,residual_blocks=4 \ +--dataset imnet \ +--traindir /tmp/real_nvp_imnet32/train \ +--logdir /tmp/real_nvp_imnet32/sample \ +--data_path ../../small_imnet/valid_32x32_?????.tfrecords \ +--mode sample +``` +Additionally, you can run the script to evaluate the model on the +validation set: +```shell +python2.7 real_nvp_multiscale_dataset.py \ +--image_size 32 \ 
+--hpconfig=n_scale=4,base_dim=32,clip_gradient=100,residual_blocks=4 \ +--dataset imnet \ +--traindir /tmp/real_nvp_imnet32/train \ +--logdir /tmp/real_nvp_imnet32/eval \ +--data_path ../../small_imnet/valid_32x32_?????.tfrecords \ +--eval_set_size 50000 \ +--mode eval +``` +The visualizations and validation set evaluation can be viewed through +[Tensorboard](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/README.md). + +Another example would be how to run the model on LSUN (church_outdoor category): +```shell +# train the model +python2.7 real_nvp_multiscale_dataset.py \ +--image_size 64 \ +--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \ +--dataset lsun \ +--traindir /tmp/real_nvp_church_outdoor/train \ +--logdir /tmp/real_nvp_church_outdoor/train \ +--data_path ../../lsun/church_outdoor_train_?????.tfrecords +``` + +```shell +# sample from the model +python2.7 real_nvp_multiscale_dataset.py \ +--image_size 64 \ +--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \ +--dataset lsun \ +--traindir /tmp/real_nvp_church_outdoor/train \ +--logdir /tmp/real_nvp_church_outdoor/sample \ +--data_path ../../lsun/church_outdoor_val_?????.tfrecords \ +--mode sample +``` + +```shell +# evaluate the model +python2.7 real_nvp_multiscale_dataset.py \ +--image_size 64 \ +--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \ +--dataset lsun \ +--traindir /tmp/real_nvp_church_outdoor/train \ +--logdir /tmp/real_nvp_church_outdoor/eval \ +--data_path ../../lsun/church_outdoor_val_?????.tfrecords \ +--eval_set_size 300 \ +--mode eval +``` + +Finally, we'll give the commands to run the model on the CelebA dataset: +```shell +# train the model +python2.7 real_nvp_multiscale_dataset.py \ +--image_size 64 \ +--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \ +--dataset celeba \ +--traindir /tmp/real_nvp_celeba/train \ +--logdir /tmp/real_nvp_celeba/train \ +--data_path 
../../celeba/celeba_train.tfrecords +``` + +```shell +# sample from the model +python2.7 real_nvp_multiscale_dataset.py \ +--image_size 64 \ +--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \ +--dataset celeba \ +--traindir /tmp/real_nvp_celeba/train \ +--logdir /tmp/real_nvp_celeba/sample \ +--data_path ../../celeba/celeba_valid.tfrecords \ +--mode sample +``` + +```shell +# evaluate the model on validation set +python2.7 real_nvp_multiscale_dataset.py \ +--image_size 64 \ +--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \ +--dataset celeba \ +--traindir /tmp/real_nvp_celeba/train \ +--logdir /tmp/real_nvp_celeba/eval_valid \ +--data_path ../../celeba/celeba_valid.tfrecords \ +--eval_set_size 19867 \ +--mode eval + +# evaluate the model on test set +python2.7 real_nvp_multiscale_dataset.py \ +--image_size 64 \ +--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \ +--dataset celeba \ +--traindir /tmp/real_nvp_celeba/train \ +--logdir /tmp/real_nvp_celeba/eval_test \ +--data_path ../../celeba/celeba_test.tfrecords \ +--eval_set_size 19962 \ +--mode eval +``` + +## Credits +This code was written by Laurent Dinh +([@laurent-dinh](https://github.com/laurent-dinh)) with +the help of +Jascha Sohl-Dickstein ([@Sohl-Dickstein](https://github.com/Sohl-Dickstein) +and [jaschasd@google.com](mailto:jaschasd@google.com)), +Samy Bengio, Jon Shlens, Sherry Moore and +David Andersen. diff --git a/real_nvp/__init__.py b/real_nvp/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/real_nvp/celeba_formatting.py b/real_nvp/celeba_formatting.py new file mode 100644 index 00000000000..f1520083fac --- /dev/null +++ b/real_nvp/celeba_formatting.py @@ -0,0 +1,94 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +r"""CelebA dataset formatting. + +Download img_align_celeba.zip from +http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html under the +link "Align&Cropped Images" in the "Img" directory and list_eval_partition.txt +under the link "Train/Val/Test Partitions" in the "Eval" directory. Then do: +unzip img_align_celeba.zip + +Use the script as follows: +python celeba_formatting.py \ + --partition_fn [PARTITION_FILE_PATH] \ + --file_out [OUTPUT_FILE_PATH_PREFIX] \ + --fn_root [CELEBA_FOLDER] \ + --set [SUBSET_INDEX] + +""" + +import os +import os.path + +import scipy.io +import scipy.io.wavfile +import scipy.ndimage +import tensorflow as tf + + +tf.flags.DEFINE_string("file_out", "", + "Filename of the output .tfrecords file.") +tf.flags.DEFINE_string("fn_root", "", "Name of root file path.") +tf.flags.DEFINE_string("partition_fn", "", "Partition file path.") +tf.flags.DEFINE_string("set", "", "Name of subset.") + +FLAGS = tf.flags.FLAGS + + +def _int64_feature(value): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + + +def _bytes_feature(value): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def main(): + """Main converter function.""" + # Celeb A + with open(FLAGS.partition_fn, "r") as infile: + img_fn_list = infile.readlines() + img_fn_list = [elem.strip().split() for elem in img_fn_list] + img_fn_list = [elem[0] for elem in img_fn_list if elem[1] == FLAGS.set] + fn_root = FLAGS.fn_root + num_examples = len(img_fn_list) + + 
file_out = "%s.tfrecords" % FLAGS.file_out + writer = tf.python_io.TFRecordWriter(file_out) + for example_idx, img_fn in enumerate(img_fn_list): + if example_idx % 1000 == 0: + print example_idx, "/", num_examples + image_raw = scipy.ndimage.imread(os.path.join(fn_root, img_fn)) + rows = image_raw.shape[0] + cols = image_raw.shape[1] + depth = image_raw.shape[2] + image_raw = image_raw.tostring() + example = tf.train.Example( + features=tf.train.Features( + feature={ + "height": _int64_feature(rows), + "width": _int64_feature(cols), + "depth": _int64_feature(depth), + "image_raw": _bytes_feature(image_raw) + } + ) + ) + writer.write(example.SerializeToString()) + writer.close() + + +if __name__ == "__main__": + main() diff --git a/real_nvp/imnet_formatting.py b/real_nvp/imnet_formatting.py new file mode 100644 index 00000000000..954e1cd2624 --- /dev/null +++ b/real_nvp/imnet_formatting.py @@ -0,0 +1,103 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +r"""Small Imagenet dataset formatting. 
+ +Download and format the Imagenet dataset as follows: +mkdir [IMAGENET_PATH] +cd [IMAGENET_PATH] +for FILENAME in train_32x32.tar valid_32x32.tar train_64x64.tar valid_64x64.tar +do + curl -O http://image-net.org/small/$FILENAME + tar -xvf $FILENAME +done + +Then use the script as follows: +for DIRNAME in train_32x32 valid_32x32 train_64x64 valid_64x64 +do + python imnet_formatting.py \ + --file_out $DIRNAME \ + --fn_root $DIRNAME +done + +""" + +import os +import os.path + +import scipy.io +import scipy.io.wavfile +import scipy.ndimage +import tensorflow as tf + + +tf.flags.DEFINE_string("file_out", "", + "Filename of the output .tfrecords file.") +tf.flags.DEFINE_string("fn_root", "", "Name of root file path.") + +FLAGS = tf.flags.FLAGS + + +def _int64_feature(value): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + + +def _bytes_feature(value): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def main(): + """Main converter function.""" + # Small Imagenet + fn_root = FLAGS.fn_root + img_fn_list = os.listdir(fn_root) + img_fn_list = [img_fn for img_fn in img_fn_list + if img_fn.endswith('.png')] + num_examples = len(img_fn_list) + + n_examples_per_file = 10000 + for example_idx, img_fn in enumerate(img_fn_list): + if example_idx % n_examples_per_file == 0: + file_out = "%s_%05d.tfrecords" + file_out = file_out % (FLAGS.file_out, + example_idx // n_examples_per_file) + print "Writing on:", file_out + writer = tf.python_io.TFRecordWriter(file_out) + if example_idx % 1000 == 0: + print example_idx, "/", num_examples + image_raw = scipy.ndimage.imread(os.path.join(fn_root, img_fn)) + rows = image_raw.shape[0] + cols = image_raw.shape[1] + depth = image_raw.shape[2] + image_raw = image_raw.astype("uint8") + image_raw = image_raw.tostring() + example = tf.train.Example( + features=tf.train.Features( + feature={ + "height": _int64_feature(rows), + "width": _int64_feature(cols), + "depth": _int64_feature(depth), + "image_raw": 
_bytes_feature(image_raw) + } + ) + ) + writer.write(example.SerializeToString()) + if example_idx % n_examples_per_file == (n_examples_per_file - 1): + writer.close() + writer.close() + + +if __name__ == "__main__": + main() diff --git a/real_nvp/lsun_formatting.py b/real_nvp/lsun_formatting.py new file mode 100644 index 00000000000..715c283df89 --- /dev/null +++ b/real_nvp/lsun_formatting.py @@ -0,0 +1,104 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +r"""LSUN dataset formatting. 
+ +Download and format the LSUN dataset as follows: +git clone https://github.com/fyu/lsun.git +cd lsun +python2.7 download.py -c [CATEGORY] + +Then unzip the downloaded .zip files before executing: +python2.7 data.py export [IMAGE_DB_PATH] --out_dir [LSUN_FOLDER] --flat + +Then use the script as follows: +python lsun_formatting.py \ + --file_out [OUTPUT_FILE_PATH_PREFIX] \ + --fn_root [LSUN_FOLDER] + +""" + +import os +import os.path + +import numpy +import skimage.transform +from PIL import Image +import tensorflow as tf + + +tf.flags.DEFINE_string("file_out", "", + "Filename of the output .tfrecords file.") +tf.flags.DEFINE_string("fn_root", "", "Name of root file path.") + +FLAGS = tf.flags.FLAGS + + +def _int64_feature(value): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + + +def _bytes_feature(value): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def main(): + """Main converter function.""" + fn_root = FLAGS.fn_root + img_fn_list = os.listdir(fn_root) + img_fn_list = [img_fn for img_fn in img_fn_list + if img_fn.endswith('.webp')] + num_examples = len(img_fn_list) + + n_examples_per_file = 10000 + for example_idx, img_fn in enumerate(img_fn_list): + if example_idx % n_examples_per_file == 0: + file_out = "%s_%05d.tfrecords" + file_out = file_out % (FLAGS.file_out, + example_idx // n_examples_per_file) + print "Writing on:", file_out + writer = tf.python_io.TFRecordWriter(file_out) + if example_idx % 1000 == 0: + print example_idx, "/", num_examples + image_raw = numpy.array(Image.open(os.path.join(fn_root, img_fn))) + rows = image_raw.shape[0] + cols = image_raw.shape[1] + depth = image_raw.shape[2] + downscale = min(rows / 96., cols / 96.) + image_raw = skimage.transform.pyramid_reduce(image_raw, downscale) + image_raw *= 255. 
+ image_raw = image_raw.astype("uint8") + rows = image_raw.shape[0] + cols = image_raw.shape[1] + depth = image_raw.shape[2] + image_raw = image_raw.tostring() + example = tf.train.Example( + features=tf.train.Features( + feature={ + "height": _int64_feature(rows), + "width": _int64_feature(cols), + "depth": _int64_feature(depth), + "image_raw": _bytes_feature(image_raw) + } + ) + ) + writer.write(example.SerializeToString()) + if example_idx % n_examples_per_file == (n_examples_per_file - 1): + writer.close() + writer.close() + + +if __name__ == "__main__": + main() diff --git a/real_nvp/real_nvp_multiscale_dataset.py b/real_nvp/real_nvp_multiscale_dataset.py new file mode 100644 index 00000000000..8587261f9b5 --- /dev/null +++ b/real_nvp/real_nvp_multiscale_dataset.py @@ -0,0 +1,1636 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +r"""Script for training, evaluation and sampling for Real NVP. 
+ +$ python real_nvp_multiscale_dataset.py \ +--alsologtostderr \ +--image_size 64 \ +--hpconfig=n_scale=5,base_dim=8 \ +--dataset imnet \ +--data_path [DATA_PATH] +""" + +import time +from datetime import datetime +import os + +import numpy +import tensorflow as tf + +from tensorflow import gfile + +from real_nvp_utils import ( + batch_norm, batch_norm_log_diff, conv_layer, + squeeze_2x2, squeeze_2x2_ordered, standard_normal_ll, + standard_normal_sample, unsqueeze_2x2, variable_on_cpu) + + +tf.flags.DEFINE_string("master", "local", + "BNS name of the TensorFlow master, or local.") + +tf.flags.DEFINE_string("logdir", "/tmp/real_nvp_multiscale", + "Directory to which logs are written.") + +tf.flags.DEFINE_string("traindir", "/tmp/real_nvp_multiscale", + "Directory to which logs are written.") + +tf.flags.DEFINE_integer("train_steps", 1000000000000000000, + "Number of steps to train for.") + +tf.flags.DEFINE_string("data_path", "", "Path to the data.") + +tf.flags.DEFINE_string("mode", "train", + "Mode of execution. Must be 'train', " + "'sample' or 'eval'.") + +tf.flags.DEFINE_string("dataset", "imnet", + "Dataset used. Must be 'imnet', " + "'celeba' or 'lsun'.") + +tf.flags.DEFINE_integer("recursion_type", 2, + "Type of the recursion.") + +tf.flags.DEFINE_integer("image_size", 64, + "Size of the input image.") + +tf.flags.DEFINE_integer("eval_set_size", 0, + "Size of evaluation dataset.") + +tf.flags.DEFINE_string( + "hpconfig", "", + "A comma separated list of hyperparameters for the model. Format is " + "hp1=value1,hp2=value2,etc. 
If this FLAG is set, the model will be trained " + "with the specified hyperparameters, filling in missing hyperparameters " + "from the default_values in |hyper_params|.") + +FLAGS = tf.flags.FLAGS + +class HParams(object): + """Dictionary of hyperparameters.""" + def __init__(self, **kwargs): + self.dict_ = kwargs + self.__dict__.update(self.dict_) + + def update_config(self, in_string): + """Update the dictionary with a comma separated list.""" + pairs = in_string.split(",") + pairs = [pair.split("=") for pair in pairs] + for key, val in pairs: + self.dict_[key] = type(self.dict_[key])(val) + self.__dict__.update(self.dict_) + return self + + def __getitem__(self, key): + return self.dict_[key] + + def __setitem__(self, key, val): + self.dict_[key] = val + self.__dict__.update(self.dict_) + + +def get_default_hparams(): + """Get the default hyperparameters.""" + return HParams( + batch_size=64, + residual_blocks=2, + n_couplings=2, + n_scale=4, + learning_rate=0.001, + momentum=1e-1, + decay=1e-3, + l2_coeff=0.00005, + clip_gradient=100., + optimizer="adam", + dropout_mask=0, + base_dim=32, + bottleneck=0, + use_batch_norm=1, + alternate=1, + use_aff=1, + skip=1, + data_constraint=.9, + n_opt=0) + + +# RESNET UTILS +def residual_block(input_, dim, name, use_batch_norm=True, + train=True, weight_norm=True, bottleneck=False): + """Residual convolutional block.""" + with tf.variable_scope(name): + res = input_ + if use_batch_norm: + res = batch_norm( + input_=res, dim=dim, name="bn_in", scale=False, + train=train, epsilon=1e-4, axes=[0, 1, 2]) + res = tf.nn.relu(res) + if bottleneck: + res = conv_layer( + input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim, + name="h_0", stddev=numpy.sqrt(2. 
/ (dim)), + strides=[1, 1, 1, 1], padding="SAME", + nonlinearity=None, bias=(not use_batch_norm), + weight_norm=weight_norm, scale=False) + if use_batch_norm: + res = batch_norm( + input_=res, dim=dim, + name="bn_0", scale=False, train=train, + epsilon=1e-4, axes=[0, 1, 2]) + res = tf.nn.relu(res) + res = conv_layer( + input_=res, filter_size=[3, 3], dim_in=dim, + dim_out=dim, name="h_1", stddev=numpy.sqrt(2. / (1. * dim)), + strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None, + bias=(not use_batch_norm), + weight_norm=weight_norm, scale=False) + if use_batch_norm: + res = batch_norm( + input_=res, dim=dim, name="bn_1", scale=False, + train=train, epsilon=1e-4, axes=[0, 1, 2]) + res = tf.nn.relu(res) + res = conv_layer( + input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim, + name="out", stddev=numpy.sqrt(2. / (1. * dim)), + strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None, + bias=True, weight_norm=weight_norm, scale=True) + else: + res = conv_layer( + input_=res, filter_size=[3, 3], dim_in=dim, dim_out=dim, + name="h_0", stddev=numpy.sqrt(2. / (dim)), + strides=[1, 1, 1, 1], padding="SAME", + nonlinearity=None, bias=(not use_batch_norm), + weight_norm=weight_norm, scale=False) + if use_batch_norm: + res = batch_norm( + input_=res, dim=dim, name="bn_0", scale=False, + train=train, epsilon=1e-4, axes=[0, 1, 2]) + res = tf.nn.relu(res) + res = conv_layer( + input_=res, filter_size=[3, 3], dim_in=dim, dim_out=dim, + name="out", stddev=numpy.sqrt(2. / (1. 
* dim)), + strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None, + bias=True, weight_norm=weight_norm, scale=True) + res += input_ + + return res + + +def resnet(input_, dim_in, dim, dim_out, name, use_batch_norm=True, + train=True, weight_norm=True, residual_blocks=5, + bottleneck=False, skip=True): + """Residual convolutional network.""" + with tf.variable_scope(name): + res = input_ + if residual_blocks != 0: + res = conv_layer( + input_=res, filter_size=[3, 3], dim_in=dim_in, dim_out=dim, + name="h_in", stddev=numpy.sqrt(2. / (dim_in)), + strides=[1, 1, 1, 1], padding="SAME", + nonlinearity=None, bias=True, + weight_norm=weight_norm, scale=False) + if skip: + out = conv_layer( + input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim, + name="skip_in", stddev=numpy.sqrt(2. / (dim)), + strides=[1, 1, 1, 1], padding="SAME", + nonlinearity=None, bias=True, + weight_norm=weight_norm, scale=True) + + # residual blocks + for idx_block in xrange(residual_blocks): + res = residual_block(res, dim, "block_%d" % idx_block, + use_batch_norm=use_batch_norm, train=train, + weight_norm=weight_norm, + bottleneck=bottleneck) + if skip: + out += conv_layer( + input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim, + name="skip_%d" % idx_block, stddev=numpy.sqrt(2. / (dim)), + strides=[1, 1, 1, 1], padding="SAME", + nonlinearity=None, bias=True, + weight_norm=weight_norm, scale=True) + # outputs + if skip: + res = out + if use_batch_norm: + res = batch_norm( + input_=res, dim=dim, name="bn_pre_out", scale=False, + train=train, epsilon=1e-4, axes=[0, 1, 2]) + res = tf.nn.relu(res) + res = conv_layer( + input_=res, filter_size=[1, 1], dim_in=dim, + dim_out=dim_out, + name="out", stddev=numpy.sqrt(2. / (1. * dim)), + strides=[1, 1, 1, 1], padding="SAME", + nonlinearity=None, bias=True, + weight_norm=weight_norm, scale=True) + else: + if bottleneck: + res = conv_layer( + input_=res, filter_size=[1, 1], dim_in=dim_in, dim_out=dim, + name="h_0", stddev=numpy.sqrt(2. 
/ (dim_in)), + strides=[1, 1, 1, 1], padding="SAME", + nonlinearity=None, bias=(not use_batch_norm), + weight_norm=weight_norm, scale=False) + if use_batch_norm: + res = batch_norm( + input_=res, dim=dim, name="bn_0", scale=False, + train=train, epsilon=1e-4, axes=[0, 1, 2]) + res = tf.nn.relu(res) + res = conv_layer( + input_=res, filter_size=[3, 3], dim_in=dim, + dim_out=dim, name="h_1", stddev=numpy.sqrt(2. / (1. * dim)), + strides=[1, 1, 1, 1], padding="SAME", + nonlinearity=None, + bias=(not use_batch_norm), + weight_norm=weight_norm, scale=False) + if use_batch_norm: + res = batch_norm( + input_=res, dim=dim, name="bn_1", scale=False, + train=train, epsilon=1e-4, axes=[0, 1, 2]) + res = tf.nn.relu(res) + res = conv_layer( + input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim_out, + name="out", stddev=numpy.sqrt(2. / (1. * dim)), + strides=[1, 1, 1, 1], padding="SAME", + nonlinearity=None, bias=True, + weight_norm=weight_norm, scale=True) + else: + res = conv_layer( + input_=res, filter_size=[3, 3], dim_in=dim_in, dim_out=dim, + name="h_0", stddev=numpy.sqrt(2. / (dim_in)), + strides=[1, 1, 1, 1], padding="SAME", + nonlinearity=None, bias=(not use_batch_norm), + weight_norm=weight_norm, scale=False) + if use_batch_norm: + res = batch_norm( + input_=res, dim=dim, name="bn_0", scale=False, + train=train, epsilon=1e-4, axes=[0, 1, 2]) + res = tf.nn.relu(res) + res = conv_layer( + input_=res, filter_size=[3, 3], dim_in=dim, dim_out=dim_out, + name="out", stddev=numpy.sqrt(2. / (1. 
* dim)), + strides=[1, 1, 1, 1], padding="SAME", + nonlinearity=None, bias=True, + weight_norm=weight_norm, scale=True) + return res + + +# COUPLING LAYERS +# masked convolution implementations +def masked_conv_aff_coupling(input_, mask_in, dim, name, + use_batch_norm=True, train=True, weight_norm=True, + reverse=False, residual_blocks=5, + bottleneck=False, use_width=1., use_height=1., + mask_channel=0., skip=True): + """Affine coupling with masked convolution.""" + with tf.variable_scope(name) as scope: + if reverse or (not train): + scope.reuse_variables() + shape = input_.get_shape().as_list() + batch_size = shape[0] + height = shape[1] + width = shape[2] + channels = shape[3] + + # build mask + mask = use_width * numpy.arange(width) + mask = use_height * numpy.arange(height).reshape((-1, 1)) + mask + mask = mask.astype("float32") + mask = tf.mod(mask_in + mask, 2) + mask = tf.reshape(mask, [-1, height, width, 1]) + if mask.get_shape().as_list()[0] == 1: + mask = tf.tile(mask, [batch_size, 1, 1, 1]) + res = input_ * tf.mod(mask_channel + mask, 2) + + # initial input + if use_batch_norm: + res = batch_norm( + input_=res, dim=channels, name="bn_in", scale=False, + train=train, epsilon=1e-4, axes=[0, 1, 2]) + res *= 2. + res = tf.concat_v2([res, -res], 3) + res = tf.concat_v2([res, mask], 3) + dim_in = 2. 
* channels + 1 + res = tf.nn.relu(res) + res = resnet(input_=res, dim_in=dim_in, dim=dim, + dim_out=2 * channels, + name="resnet", use_batch_norm=use_batch_norm, + train=train, weight_norm=weight_norm, + residual_blocks=residual_blocks, + bottleneck=bottleneck, skip=skip) + mask = tf.mod(mask_channel + mask, 2) + res = tf.split(res, 2, 3) + shift, log_rescaling = res[-2], res[-1] + scale = variable_on_cpu( + "rescaling_scale", [], + tf.constant_initializer(0.)) + shift = tf.reshape( + shift, [batch_size, height, width, channels]) + log_rescaling = tf.reshape( + log_rescaling, [batch_size, height, width, channels]) + log_rescaling = scale * tf.tanh(log_rescaling) + if not use_batch_norm: + scale_shift = variable_on_cpu( + "scale_shift", [], + tf.constant_initializer(0.)) + log_rescaling += scale_shift + shift *= (1. - mask) + log_rescaling *= (1. - mask) + if reverse: + res = input_ + if use_batch_norm: + mean, var = batch_norm_log_diff( + input_=res * (1. - mask), dim=channels, name="bn_out", + train=False, epsilon=1e-4, axes=[0, 1, 2]) + log_var = tf.log(var) + res *= tf.exp(.5 * log_var * (1. - mask)) + res += mean * (1. - mask) + res *= tf.exp(-log_rescaling) + res -= shift + log_diff = -log_rescaling + if use_batch_norm: + log_diff += .5 * log_var * (1. - mask) + else: + res = input_ + res += shift + res *= tf.exp(log_rescaling) + log_diff = log_rescaling + if use_batch_norm: + mean, var = batch_norm_log_diff( + input_=res * (1. - mask), dim=channels, name="bn_out", + train=train, epsilon=1e-4, axes=[0, 1, 2]) + log_var = tf.log(var) + res -= mean * (1. - mask) + res *= tf.exp(-.5 * log_var * (1. - mask)) + log_diff -= .5 * log_var * (1. 
- mask) + + return res, log_diff + + +def masked_conv_add_coupling(input_, mask_in, dim, name, + use_batch_norm=True, train=True, weight_norm=True, + reverse=False, residual_blocks=5, + bottleneck=False, use_width=1., use_height=1., + mask_channel=0., skip=True): + """Additive coupling with masked convolution.""" + with tf.variable_scope(name) as scope: + if reverse or (not train): + scope.reuse_variables() + shape = input_.get_shape().as_list() + batch_size = shape[0] + height = shape[1] + width = shape[2] + channels = shape[3] + + # build mask + mask = use_width * numpy.arange(width) + mask = use_height * numpy.arange(height).reshape((-1, 1)) + mask + mask = mask.astype("float32") + mask = tf.mod(mask_in + mask, 2) + mask = tf.reshape(mask, [-1, height, width, 1]) + if mask.get_shape().as_list()[0] == 1: + mask = tf.tile(mask, [batch_size, 1, 1, 1]) + res = input_ * tf.mod(mask_channel + mask, 2) + + # initial input + if use_batch_norm: + res = batch_norm( + input_=res, dim=channels, name="bn_in", scale=False, + train=train, epsilon=1e-4, axes=[0, 1, 2]) + res *= 2. + res = tf.concat_v2([res, -res], 3) + res = tf.concat_v2([res, mask], 3) + dim_in = 2. * channels + 1 + res = tf.nn.relu(res) + shift = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=channels, + name="resnet", use_batch_norm=use_batch_norm, + train=train, weight_norm=weight_norm, + residual_blocks=residual_blocks, + bottleneck=bottleneck, skip=skip) + mask = tf.mod(mask_channel + mask, 2) + shift *= (1. - mask) + # use_batch_norm = False + if reverse: + res = input_ + if use_batch_norm: + mean, var = batch_norm_log_diff( + input_=res * (1. - mask), + dim=channels, name="bn_out", train=False, epsilon=1e-4) + log_var = tf.log(var) + res *= tf.exp(.5 * log_var * (1. - mask)) + res += mean * (1. - mask) + res -= shift + log_diff = tf.zeros_like(res) + if use_batch_norm: + log_diff += .5 * log_var * (1. 
- mask) + else: + res = input_ + res += shift + log_diff = tf.zeros_like(res) + if use_batch_norm: + mean, var = batch_norm_log_diff( + input_=res * (1. - mask), dim=channels, + name="bn_out", train=train, epsilon=1e-4, axes=[0, 1, 2]) + log_var = tf.log(var) + res -= mean * (1. - mask) + res *= tf.exp(-.5 * log_var * (1. - mask)) + log_diff -= .5 * log_var * (1. - mask) + + return res, log_diff + + +def masked_conv_coupling(input_, mask_in, dim, name, + use_batch_norm=True, train=True, weight_norm=True, + reverse=False, residual_blocks=5, + bottleneck=False, use_aff=True, + use_width=1., use_height=1., + mask_channel=0., skip=True): + """Coupling with masked convolution.""" + if use_aff: + return masked_conv_aff_coupling( + input_=input_, mask_in=mask_in, dim=dim, name=name, + use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm, + reverse=reverse, residual_blocks=residual_blocks, + bottleneck=bottleneck, use_width=use_width, use_height=use_height, + mask_channel=mask_channel, skip=skip) + else: + return masked_conv_add_coupling( + input_=input_, mask_in=mask_in, dim=dim, name=name, + use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm, + reverse=reverse, residual_blocks=residual_blocks, + bottleneck=bottleneck, use_width=use_width, use_height=use_height, + mask_channel=mask_channel, skip=skip) + + +# channel-axis splitting implementations +def conv_ch_aff_coupling(input_, dim, name, + use_batch_norm=True, train=True, weight_norm=True, + reverse=False, residual_blocks=5, + bottleneck=False, change_bottom=True, skip=True): + """Affine coupling with channel-wise splitting.""" + with tf.variable_scope(name) as scope: + if reverse or (not train): + scope.reuse_variables() + + if change_bottom: + input_, canvas = tf.split(input_, 2, 3) + else: + canvas, input_ = tf.split(input_, 2, 3) + shape = input_.get_shape().as_list() + batch_size = shape[0] + height = shape[1] + width = shape[2] + channels = shape[3] + res = input_ + + # initial 
input + if use_batch_norm: + res = batch_norm( + input_=res, dim=channels, name="bn_in", scale=False, + train=train, epsilon=1e-4, axes=[0, 1, 2]) + res = tf.concat_v2([res, -res], 3) + dim_in = 2. * channels + res = tf.nn.relu(res) + res = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=2 * channels, + name="resnet", use_batch_norm=use_batch_norm, + train=train, weight_norm=weight_norm, + residual_blocks=residual_blocks, + bottleneck=bottleneck, skip=skip) + shift, log_rescaling = tf.split(res, 2, 3) + scale = variable_on_cpu( + "scale", [], + tf.constant_initializer(1.)) + shift = tf.reshape( + shift, [batch_size, height, width, channels]) + log_rescaling = tf.reshape( + log_rescaling, [batch_size, height, width, channels]) + log_rescaling = scale * tf.tanh(log_rescaling) + if not use_batch_norm: + scale_shift = variable_on_cpu( + "scale_shift", [], + tf.constant_initializer(0.)) + log_rescaling += scale_shift + if reverse: + res = canvas + if use_batch_norm: + mean, var = batch_norm_log_diff( + input_=res, dim=channels, name="bn_out", train=False, + epsilon=1e-4, axes=[0, 1, 2]) + log_var = tf.log(var) + res *= tf.exp(.5 * log_var) + res += mean + res *= tf.exp(-log_rescaling) + res -= shift + log_diff = -log_rescaling + if use_batch_norm: + log_diff += .5 * log_var + else: + res = canvas + res += shift + res *= tf.exp(log_rescaling) + log_diff = log_rescaling + if use_batch_norm: + mean, var = batch_norm_log_diff( + input_=res, dim=channels, name="bn_out", train=train, + epsilon=1e-4, axes=[0, 1, 2]) + log_var = tf.log(var) + res -= mean + res *= tf.exp(-.5 * log_var) + log_diff -= .5 * log_var + if change_bottom: + res = tf.concat_v2([input_, res], 3) + log_diff = tf.concat_v2([tf.zeros_like(log_diff), log_diff], 3) + else: + res = tf.concat_v2([res, input_], 3) + log_diff = tf.concat_v2([log_diff, tf.zeros_like(log_diff)], 3) + + return res, log_diff + + +def conv_ch_add_coupling(input_, dim, name, + use_batch_norm=True, train=True, weight_norm=True, + 
reverse=False, residual_blocks=5, + bottleneck=False, change_bottom=True, skip=True): + """Additive coupling with channel-wise splitting.""" + with tf.variable_scope(name) as scope: + if reverse or (not train): + scope.reuse_variables() + + if change_bottom: + input_, canvas = tf.split(input_, 2, 3) + else: + canvas, input_ = tf.split(input_, 2, 3) + shape = input_.get_shape().as_list() + channels = shape[3] + res = input_ + + # initial input + if use_batch_norm: + res = batch_norm( + input_=res, dim=channels, name="bn_in", scale=False, + train=train, epsilon=1e-4, axes=[0, 1, 2]) + res = tf.concat_v2([res, -res], 3) + dim_in = 2. * channels + res = tf.nn.relu(res) + shift = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=channels, + name="resnet", use_batch_norm=use_batch_norm, + train=train, weight_norm=weight_norm, + residual_blocks=residual_blocks, + bottleneck=bottleneck, skip=skip) + if reverse: + res = canvas + if use_batch_norm: + mean, var = batch_norm_log_diff( + input_=res, dim=channels, name="bn_out", train=False, + epsilon=1e-4, axes=[0, 1, 2]) + log_var = tf.log(var) + res *= tf.exp(.5 * log_var) + res += mean + res -= shift + log_diff = tf.zeros_like(res) + if use_batch_norm: + log_diff += .5 * log_var + else: + res = canvas + res += shift + log_diff = tf.zeros_like(res) + if use_batch_norm: + mean, var = batch_norm_log_diff( + input_=res, dim=channels, name="bn_out", train=train, + epsilon=1e-4, axes=[0, 1, 2]) + log_var = tf.log(var) + res -= mean + res *= tf.exp(-.5 * log_var) + log_diff -= .5 * log_var + if change_bottom: + res = tf.concat_v2([input_, res], 3) + log_diff = tf.concat_v2([tf.zeros_like(log_diff), log_diff], 3) + else: + res = tf.concat_v2([res, input_], 3) + log_diff = tf.concat_v2([log_diff, tf.zeros_like(log_diff)], 3) + + return res, log_diff + + +def conv_ch_coupling(input_, dim, name, + use_batch_norm=True, train=True, weight_norm=True, + reverse=False, residual_blocks=5, + bottleneck=False, use_aff=True, 
change_bottom=True, + skip=True): + """Coupling with channel-wise splitting.""" + if use_aff: + return conv_ch_aff_coupling( + input_=input_, dim=dim, name=name, + use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm, + reverse=reverse, residual_blocks=residual_blocks, + bottleneck=bottleneck, change_bottom=change_bottom, skip=skip) + else: + return conv_ch_add_coupling( + input_=input_, dim=dim, name=name, + use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm, + reverse=reverse, residual_blocks=residual_blocks, + bottleneck=bottleneck, change_bottom=change_bottom, skip=skip) + + +# RECURSIVE USE OF COUPLING LAYERS +def rec_masked_conv_coupling(input_, hps, scale_idx, n_scale, + use_batch_norm=True, weight_norm=True, + train=True): + """Recursion on coupling layers.""" + shape = input_.get_shape().as_list() + channels = shape[3] + residual_blocks = hps.residual_blocks + base_dim = hps.base_dim + mask = 1. + use_aff = hps.use_aff + res = input_ + skip = hps.skip + log_diff = tf.zeros_like(input_) + dim = base_dim + if FLAGS.recursion_type < 4: + dim *= 2 ** scale_idx + with tf.variable_scope("scale_%d" % scale_idx): + # initial coupling layers + res, inc_log_diff = masked_conv_coupling( + input_=res, + mask_in=mask, dim=dim, + name="coupling_0", + use_batch_norm=use_batch_norm, train=train, + weight_norm=weight_norm, + reverse=False, residual_blocks=residual_blocks, + bottleneck=hps.bottleneck, use_aff=use_aff, + use_width=1., use_height=1., skip=skip) + log_diff += inc_log_diff + res, inc_log_diff = masked_conv_coupling( + input_=res, + mask_in=1. 
- mask, dim=dim, + name="coupling_1", + use_batch_norm=use_batch_norm, train=train, + weight_norm=weight_norm, + reverse=False, residual_blocks=residual_blocks, + bottleneck=hps.bottleneck, use_aff=use_aff, + use_width=1., use_height=1., skip=skip) + log_diff += inc_log_diff + res, inc_log_diff = masked_conv_coupling( + input_=res, + mask_in=mask, dim=dim, + name="coupling_2", + use_batch_norm=use_batch_norm, train=train, + weight_norm=weight_norm, + reverse=False, residual_blocks=residual_blocks, + bottleneck=hps.bottleneck, use_aff=True, + use_width=1., use_height=1., skip=skip) + log_diff += inc_log_diff + if scale_idx < (n_scale - 1): + with tf.variable_scope("scale_%d" % scale_idx): + res = squeeze_2x2(res) + log_diff = squeeze_2x2(log_diff) + res, inc_log_diff = conv_ch_coupling( + input_=res, + change_bottom=True, dim=2 * dim, + name="coupling_4", + use_batch_norm=use_batch_norm, train=train, + weight_norm=weight_norm, + reverse=False, residual_blocks=residual_blocks, + bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip) + log_diff += inc_log_diff + res, inc_log_diff = conv_ch_coupling( + input_=res, + change_bottom=False, dim=2 * dim, + name="coupling_5", + use_batch_norm=use_batch_norm, train=train, + weight_norm=weight_norm, + reverse=False, residual_blocks=residual_blocks, + bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip) + log_diff += inc_log_diff + res, inc_log_diff = conv_ch_coupling( + input_=res, + change_bottom=True, dim=2 * dim, + name="coupling_6", + use_batch_norm=use_batch_norm, train=train, + weight_norm=weight_norm, + reverse=False, residual_blocks=residual_blocks, + bottleneck=hps.bottleneck, use_aff=True, skip=skip) + log_diff += inc_log_diff + res = unsqueeze_2x2(res) + log_diff = unsqueeze_2x2(log_diff) + if FLAGS.recursion_type > 1: + res = squeeze_2x2_ordered(res) + log_diff = squeeze_2x2_ordered(log_diff) + if FLAGS.recursion_type > 2: + res_1 = res[:, :, :, :channels] + res_2 = res[:, :, :, channels:] + log_diff_1 = 
log_diff[:, :, :, :channels] + log_diff_2 = log_diff[:, :, :, channels:] + else: + res_1, res_2 = tf.split(res, 2, 3) + log_diff_1, log_diff_2 = tf.split(log_diff, 2, 3) + res_1, inc_log_diff = rec_masked_conv_coupling( + input_=res_1, hps=hps, scale_idx=scale_idx + 1, n_scale=n_scale, + use_batch_norm=use_batch_norm, weight_norm=weight_norm, + train=train) + res = tf.concat_v2([res_1, res_2], 3) + log_diff_1 += inc_log_diff + log_diff = tf.concat_v2([log_diff_1, log_diff_2], 3) + res = squeeze_2x2_ordered(res, reverse=True) + log_diff = squeeze_2x2_ordered(log_diff, reverse=True) + else: + res = squeeze_2x2_ordered(res) + log_diff = squeeze_2x2_ordered(log_diff) + res, inc_log_diff = rec_masked_conv_coupling( + input_=res, hps=hps, scale_idx=scale_idx + 1, n_scale=n_scale, + use_batch_norm=use_batch_norm, weight_norm=weight_norm, + train=train) + log_diff += inc_log_diff + res = squeeze_2x2_ordered(res, reverse=True) + log_diff = squeeze_2x2_ordered(log_diff, reverse=True) + else: + with tf.variable_scope("scale_%d" % scale_idx): + res, inc_log_diff = masked_conv_coupling( + input_=res, + mask_in=1. - mask, dim=dim, + name="coupling_3", + use_batch_norm=use_batch_norm, train=train, + weight_norm=weight_norm, + reverse=False, residual_blocks=residual_blocks, + bottleneck=hps.bottleneck, use_aff=True, + use_width=1., use_height=1., skip=skip) + log_diff += inc_log_diff + return res, log_diff + + +def rec_masked_deconv_coupling(input_, hps, scale_idx, n_scale, + use_batch_norm=True, weight_norm=True, + train=True): + """Recursion on inverting coupling layers.""" + shape = input_.get_shape().as_list() + channels = shape[3] + residual_blocks = hps.residual_blocks + base_dim = hps.base_dim + mask = 1. 
+ use_aff = hps.use_aff + res = input_ + log_diff = tf.zeros_like(input_) + skip = hps.skip + dim = base_dim + if FLAGS.recursion_type < 4: + dim *= 2 ** scale_idx + if scale_idx < (n_scale - 1): + if FLAGS.recursion_type > 1: + res = squeeze_2x2_ordered(res) + log_diff = squeeze_2x2_ordered(log_diff) + if FLAGS.recursion_type > 2: + res_1 = res[:, :, :, :channels] + res_2 = res[:, :, :, channels:] + log_diff_1 = log_diff[:, :, :, :channels] + log_diff_2 = log_diff[:, :, :, channels:] + else: + res_1, res_2 = tf.split(res, 2, 3) + log_diff_1, log_diff_2 = tf.split(log_diff, 2, 3) + res_1, log_diff_1 = rec_masked_deconv_coupling( + input_=res_1, hps=hps, + scale_idx=scale_idx + 1, n_scale=n_scale, + use_batch_norm=use_batch_norm, weight_norm=weight_norm, + train=train) + res = tf.concat_v2([res_1, res_2], 3) + log_diff = tf.concat_v2([log_diff_1, log_diff_2], 3) + res = squeeze_2x2_ordered(res, reverse=True) + log_diff = squeeze_2x2_ordered(log_diff, reverse=True) + else: + res = squeeze_2x2_ordered(res) + log_diff = squeeze_2x2_ordered(log_diff) + res, log_diff = rec_masked_deconv_coupling( + input_=res, hps=hps, + scale_idx=scale_idx + 1, n_scale=n_scale, + use_batch_norm=use_batch_norm, weight_norm=weight_norm, + train=train) + res = squeeze_2x2_ordered(res, reverse=True) + log_diff = squeeze_2x2_ordered(log_diff, reverse=True) + with tf.variable_scope("scale_%d" % scale_idx): + res = squeeze_2x2(res) + log_diff = squeeze_2x2(log_diff) + res, inc_log_diff = conv_ch_coupling( + input_=res, + change_bottom=True, dim=2 * dim, + name="coupling_6", + use_batch_norm=use_batch_norm, train=train, + weight_norm=weight_norm, + reverse=True, residual_blocks=residual_blocks, + bottleneck=hps.bottleneck, use_aff=True, skip=skip) + log_diff += inc_log_diff + res, inc_log_diff = conv_ch_coupling( + input_=res, + change_bottom=False, dim=2 * dim, + name="coupling_5", + use_batch_norm=use_batch_norm, train=train, + weight_norm=weight_norm, + reverse=True, 
residual_blocks=residual_blocks, + bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip) + log_diff += inc_log_diff + res, inc_log_diff = conv_ch_coupling( + input_=res, + change_bottom=True, dim=2 * dim, + name="coupling_4", + use_batch_norm=use_batch_norm, train=train, + weight_norm=weight_norm, + reverse=True, residual_blocks=residual_blocks, + bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip) + log_diff += inc_log_diff + res = unsqueeze_2x2(res) + log_diff = unsqueeze_2x2(log_diff) + else: + with tf.variable_scope("scale_%d" % scale_idx): + res, inc_log_diff = masked_conv_coupling( + input_=res, + mask_in=1. - mask, dim=dim, + name="coupling_3", + use_batch_norm=use_batch_norm, train=train, + weight_norm=weight_norm, + reverse=True, residual_blocks=residual_blocks, + bottleneck=hps.bottleneck, use_aff=True, + use_width=1., use_height=1., skip=skip) + log_diff += inc_log_diff + + with tf.variable_scope("scale_%d" % scale_idx): + res, inc_log_diff = masked_conv_coupling( + input_=res, + mask_in=mask, dim=dim, + name="coupling_2", + use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm, + reverse=True, residual_blocks=residual_blocks, + bottleneck=hps.bottleneck, use_aff=True, + use_width=1., use_height=1., skip=skip) + log_diff += inc_log_diff + res, inc_log_diff = masked_conv_coupling( + input_=res, + mask_in=1. 
- mask, dim=dim, + name="coupling_1", + use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm, + reverse=True, residual_blocks=residual_blocks, + bottleneck=hps.bottleneck, use_aff=use_aff, + use_width=1., use_height=1., skip=skip) + log_diff += inc_log_diff + res, inc_log_diff = masked_conv_coupling( + input_=res, + mask_in=mask, dim=dim, + name="coupling_0", + use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm, + reverse=True, residual_blocks=residual_blocks, + bottleneck=hps.bottleneck, use_aff=use_aff, + use_width=1., use_height=1., skip=skip) + log_diff += inc_log_diff + + return res, log_diff + + +# ENCODER AND DECODER IMPLEMENTATIONS +# start the recursions +def encoder(input_, hps, n_scale, use_batch_norm=True, + weight_norm=True, train=True): + """Encoding/gaussianization function.""" + res = input_ + log_diff = tf.zeros_like(input_) + res, inc_log_diff = rec_masked_conv_coupling( + input_=res, hps=hps, scale_idx=0, n_scale=n_scale, + use_batch_norm=use_batch_norm, weight_norm=weight_norm, + train=train) + log_diff += inc_log_diff + + return res, log_diff + + +def decoder(input_, hps, n_scale, use_batch_norm=True, + weight_norm=True, train=True): + """Decoding/generator function.""" + res, log_diff = rec_masked_deconv_coupling( + input_=input_, hps=hps, scale_idx=0, n_scale=n_scale, + use_batch_norm=use_batch_norm, weight_norm=weight_norm, + train=train) + + return res, log_diff + + +class RealNVP(object): + """Real NVP model.""" + + def __init__(self, hps, sampling=False): + # DATA TENSOR INSTANTIATION + device = "/cpu:0" + if FLAGS.dataset == "imnet": + with tf.device( + tf.train.replica_device_setter(0, worker_device=device)): + filename_queue = tf.train.string_input_producer( + gfile.Glob(FLAGS.data_path), num_epochs=None) + reader = tf.TFRecordReader() + _, serialized_example = reader.read(filename_queue) + features = tf.parse_single_example( + serialized_example, + features={ + "image_raw": tf.FixedLenFeature([], 
tf.string), + }) + image = tf.decode_raw(features["image_raw"], tf.uint8) + image.set_shape([FLAGS.image_size * FLAGS.image_size * 3]) + image = tf.cast(image, tf.float32) + if FLAGS.mode == "train": + images = tf.train.shuffle_batch( + [image], batch_size=hps.batch_size, num_threads=1, + capacity=1000 + 3 * hps.batch_size, + # Ensures a minimum amount of shuffling of examples. + min_after_dequeue=1000) + else: + images = tf.train.batch( + [image], batch_size=hps.batch_size, num_threads=1, + capacity=1000 + 3 * hps.batch_size) + self.x_orig = x_orig = images + image_size = FLAGS.image_size + x_in = tf.reshape( + x_orig, + [hps.batch_size, FLAGS.image_size, FLAGS.image_size, 3]) + x_in = tf.clip_by_value(x_in, 0, 255) + x_in = (tf.cast(x_in, tf.float32) + + tf.random_uniform(tf.shape(x_in))) / 256. + elif FLAGS.dataset == "celeba": + with tf.device( + tf.train.replica_device_setter(0, worker_device=device)): + filename_queue = tf.train.string_input_producer( + gfile.Glob(FLAGS.data_path), num_epochs=None) + reader = tf.TFRecordReader() + _, serialized_example = reader.read(filename_queue) + features = tf.parse_single_example( + serialized_example, + features={ + "image_raw": tf.FixedLenFeature([], tf.string), + }) + image = tf.decode_raw(features["image_raw"], tf.uint8) + image.set_shape([218 * 178 * 3]) # 218, 178 + image = tf.cast(image, tf.float32) + image = tf.reshape(image, [218, 178, 3]) + image = image[40:188, 15:163, :] + if FLAGS.mode == "train": + image = tf.image.random_flip_left_right(image) + images = tf.train.shuffle_batch( + [image], batch_size=hps.batch_size, num_threads=1, + capacity=1000 + 3 * hps.batch_size, + min_after_dequeue=1000) + else: + images = tf.train.batch( + [image], batch_size=hps.batch_size, num_threads=1, + capacity=1000 + 3 * hps.batch_size) + self.x_orig = x_orig = images + image_size = 64 + x_in = tf.reshape(x_orig, [hps.batch_size, 148, 148, 3]) + x_in = tf.image.resize_images( + x_in, [64, 64], method=0, align_corners=False) + 
x_in = (tf.cast(x_in, tf.float32) + + tf.random_uniform(tf.shape(x_in))) / 256. + elif FLAGS.dataset == "lsun": + with tf.device( + tf.train.replica_device_setter(0, worker_device=device)): + filename_queue = tf.train.string_input_producer( + gfile.Glob(FLAGS.data_path), num_epochs=None) + reader = tf.TFRecordReader() + _, serialized_example = reader.read(filename_queue) + features = tf.parse_single_example( + serialized_example, + features={ + "image_raw": tf.FixedLenFeature([], tf.string), + "height": tf.FixedLenFeature([], tf.int64), + "width": tf.FixedLenFeature([], tf.int64), + "depth": tf.FixedLenFeature([], tf.int64) + }) + image = tf.decode_raw(features["image_raw"], tf.uint8) + height = tf.reshape(features["height"], [1]) + height = tf.cast(height, tf.int32) + width = tf.reshape(features["width"], [1]) + width = tf.cast(width, tf.int32) + depth = tf.reshape(features["depth"], [1]) + depth = tf.cast(depth, tf.int32) + image = tf.reshape(image, tf.concat_v2([height, width, depth], 0)) + image = tf.random_crop(image, [64, 64, 3]) + if FLAGS.mode == "train": + image = tf.image.random_flip_left_right(image) + images = tf.train.shuffle_batch( + [image], batch_size=hps.batch_size, num_threads=1, + capacity=1000 + 3 * hps.batch_size, + # Ensures a minimum amount of shuffling of examples. + min_after_dequeue=1000) + else: + images = tf.train.batch( + [image], batch_size=hps.batch_size, num_threads=1, + capacity=1000 + 3 * hps.batch_size) + self.x_orig = x_orig = images + image_size = 64 + x_in = tf.reshape(x_orig, [hps.batch_size, 64, 64, 3]) + x_in = (tf.cast(x_in, tf.float32) + + tf.random_uniform(tf.shape(x_in))) / 256.
+ else: + raise ValueError("Unknown dataset.") + x_in = tf.reshape(x_in, [hps.batch_size, image_size, image_size, 3]) + side_shown = int(numpy.sqrt(hps.batch_size)) + shown_x = tf.transpose( + tf.reshape( + x_in[:(side_shown * side_shown), :, :, :], + [side_shown, image_size * side_shown, image_size, 3]), + [0, 2, 1, 3]) + shown_x = tf.transpose( + tf.reshape( + shown_x, + [1, image_size * side_shown, image_size * side_shown, 3]), + [0, 2, 1, 3]) * 255. + tf.summary.image( + "inputs", + tf.cast(shown_x, tf.uint8), + max_outputs=1) + + # restrict the data + FLAGS.image_size = image_size + data_constraint = hps.data_constraint + pre_logit_scale = numpy.log(data_constraint) + pre_logit_scale -= numpy.log(1. - data_constraint) + pre_logit_scale = tf.cast(pre_logit_scale, tf.float32) + logit_x_in = 2. * x_in # [0, 2] + logit_x_in -= 1. # [-1, 1] + logit_x_in *= data_constraint # [-.9, .9] + logit_x_in += 1. # [.1, 1.9] + logit_x_in /= 2. # [.05, .95] + # logit the data + logit_x_in = tf.log(logit_x_in) - tf.log(1. 
- logit_x_in) + transform_cost = tf.reduce_sum( + tf.nn.softplus(logit_x_in) + tf.nn.softplus(-logit_x_in) + - tf.nn.softplus(-pre_logit_scale), + [1, 2, 3]) + + # INFERENCE AND COSTS + z_out, log_diff = encoder( + input_=logit_x_in, hps=hps, n_scale=hps.n_scale, + use_batch_norm=hps.use_batch_norm, weight_norm=True, + train=True) + if FLAGS.mode != "train": + z_out, log_diff = encoder( + input_=logit_x_in, hps=hps, n_scale=hps.n_scale, + use_batch_norm=hps.use_batch_norm, weight_norm=True, + train=False) + final_shape = [image_size, image_size, 3] + prior_ll = standard_normal_ll(z_out) + prior_ll = tf.reduce_sum(prior_ll, [1, 2, 3]) + log_diff = tf.reduce_sum(log_diff, [1, 2, 3]) + log_diff += transform_cost + cost = -(prior_ll + log_diff) + + self.x_in = x_in + self.z_out = z_out + self.cost = cost = tf.reduce_mean(cost) + + l2_reg = sum( + [tf.reduce_sum(tf.square(v)) for v in tf.trainable_variables() + if ("magnitude" in v.name) or ("rescaling_scale" in v.name)]) + + bit_per_dim = ((cost + numpy.log(256.) * image_size * image_size * 3.) + / (image_size * image_size * 3. * numpy.log(2.))) + self.bit_per_dim = bit_per_dim + + # OPTIMIZATION + momentum = 1. - hps.momentum + decay = 1. 
- hps.decay + if hps.optimizer == "adam": + optimizer = tf.train.AdamOptimizer( + learning_rate=hps.learning_rate, + beta1=momentum, beta2=decay, epsilon=1e-08, + use_locking=False, name="Adam") + elif hps.optimizer == "rmsprop": + optimizer = tf.train.RMSPropOptimizer( + learning_rate=hps.learning_rate, decay=decay, + momentum=momentum, epsilon=1e-04, + use_locking=False, name="RMSProp") + else: + optimizer = tf.train.MomentumOptimizer(hps.learning_rate, + momentum=momentum) + + step = tf.get_variable( + "global_step", [], tf.int64, + tf.zeros_initializer(), + trainable=False) + self.step = step + grads_and_vars = optimizer.compute_gradients( + cost + hps.l2_coeff * l2_reg, + tf.trainable_variables()) + grads, vars_ = zip(*grads_and_vars) + capped_grads, gradient_norm = tf.clip_by_global_norm( + grads, clip_norm=hps.clip_gradient) + gradient_norm = tf.check_numerics(gradient_norm, + "Gradient norm is NaN or Inf.") + + l2_z = tf.reduce_sum(tf.square(z_out), [1, 2, 3]) + if not sampling: + tf.summary.scalar("negative_log_likelihood", tf.reshape(cost, [])) + tf.summary.scalar("gradient_norm", tf.reshape(gradient_norm, [])) + tf.summary.scalar("bit_per_dim", tf.reshape(bit_per_dim, [])) + tf.summary.scalar("log_diff", tf.reshape(tf.reduce_mean(log_diff), [])) + tf.summary.scalar("prior_ll", tf.reshape(tf.reduce_mean(prior_ll), [])) + tf.summary.scalar( + "log_diff_var", + tf.reshape(tf.reduce_mean(tf.square(log_diff)) + - tf.square(tf.reduce_mean(log_diff)), [])) + tf.summary.scalar( + "prior_ll_var", + tf.reshape(tf.reduce_mean(tf.square(prior_ll)) + - tf.square(tf.reduce_mean(prior_ll)), [])) + tf.summary.scalar("l2_z_mean", tf.reshape(tf.reduce_mean(l2_z), [])) + tf.summary.scalar( + "l2_z_var", + tf.reshape(tf.reduce_mean(tf.square(l2_z)) + - tf.square(tf.reduce_mean(l2_z)), [])) + + + capped_grads_and_vars = zip(capped_grads, vars_) + self.train_step = optimizer.apply_gradients( + capped_grads_and_vars, global_step=step) + + # SAMPLING AND VISUALIZATION + if 
sampling: + # SAMPLES + sample = standard_normal_sample([100] + final_shape) + sample, _ = decoder( + input_=sample, hps=hps, n_scale=hps.n_scale, + use_batch_norm=hps.use_batch_norm, weight_norm=True, + train=True) + sample = tf.nn.sigmoid(sample) + + sample = tf.clip_by_value(sample, 0, 1) * 255. + sample = tf.reshape(sample, [100, image_size, image_size, 3]) + sample = tf.transpose( + tf.reshape(sample, [10, image_size * 10, image_size, 3]), + [0, 2, 1, 3]) + sample = tf.transpose( + tf.reshape(sample, [1, image_size * 10, image_size * 10, 3]), + [0, 2, 1, 3]) + tf.summary.image( + "samples", + tf.cast(sample, tf.uint8), + max_outputs=1) + + # CONCATENATION + concatenation, _ = encoder( + input_=logit_x_in, hps=hps, + n_scale=hps.n_scale, + use_batch_norm=hps.use_batch_norm, weight_norm=True, + train=False) + concatenation = tf.reshape( + concatenation, + [(side_shown * side_shown), image_size, image_size, 3]) + concatenation = tf.transpose( + tf.reshape( + concatenation, + [side_shown, image_size * side_shown, image_size, 3]), + [0, 2, 1, 3]) + concatenation = tf.transpose( + tf.reshape( + concatenation, + [1, image_size * side_shown, image_size * side_shown, 3]), + [0, 2, 1, 3]) + concatenation, _ = decoder( + input_=concatenation, hps=hps, n_scale=hps.n_scale, + use_batch_norm=hps.use_batch_norm, weight_norm=True, + train=False) + concatenation = tf.nn.sigmoid(concatenation) * 255. 
+ tf.summary.image( + "concatenation", + tf.cast(concatenation, tf.uint8), + max_outputs=1) + + # MANIFOLD + + # Data basis + z_u, _ = encoder( + input_=logit_x_in[:8, :, :, :], hps=hps, + n_scale=hps.n_scale, + use_batch_norm=hps.use_batch_norm, weight_norm=True, + train=False) + u_1 = tf.reshape(z_u[0, :, :, :], [-1]) + u_2 = tf.reshape(z_u[1, :, :, :], [-1]) + u_3 = tf.reshape(z_u[2, :, :, :], [-1]) + u_4 = tf.reshape(z_u[3, :, :, :], [-1]) + u_5 = tf.reshape(z_u[4, :, :, :], [-1]) + u_6 = tf.reshape(z_u[5, :, :, :], [-1]) + u_7 = tf.reshape(z_u[6, :, :, :], [-1]) + u_8 = tf.reshape(z_u[7, :, :, :], [-1]) + + # 3D dome + manifold_side = 8 + angle_1 = numpy.arange(manifold_side) * 1. / manifold_side + angle_2 = numpy.arange(manifold_side) * 1. / manifold_side + angle_1 *= 2. * numpy.pi + angle_2 *= 2. * numpy.pi + angle_1 = angle_1.astype("float32") + angle_2 = angle_2.astype("float32") + angle_1 = tf.reshape(angle_1, [1, -1, 1]) + angle_1 += tf.zeros([manifold_side, manifold_side, 1]) + angle_2 = tf.reshape(angle_2, [-1, 1, 1]) + angle_2 += tf.zeros([manifold_side, manifold_side, 1]) + n_angle_3 = 40 + angle_3 = numpy.arange(n_angle_3) * 1. 
/ n_angle_3 + angle_3 *= 2 * numpy.pi + angle_3 = angle_3.astype("float32") + angle_3 = tf.reshape(angle_3, [-1, 1, 1, 1]) + angle_3 += tf.zeros([n_angle_3, manifold_side, manifold_side, 1]) + manifold = tf.cos(angle_1) * ( + tf.cos(angle_2) * ( + tf.cos(angle_3) * u_1 + tf.sin(angle_3) * u_2) + + tf.sin(angle_2) * ( + tf.cos(angle_3) * u_3 + tf.sin(angle_3) * u_4)) + manifold += tf.sin(angle_1) * ( + tf.cos(angle_2) * ( + tf.cos(angle_3) * u_5 + tf.sin(angle_3) * u_6) + + tf.sin(angle_2) * ( + tf.cos(angle_3) * u_7 + tf.sin(angle_3) * u_8)) + manifold = tf.reshape( + manifold, + [n_angle_3 * manifold_side * manifold_side] + final_shape) + manifold, _ = decoder( + input_=manifold, hps=hps, n_scale=hps.n_scale, + use_batch_norm=hps.use_batch_norm, weight_norm=True, + train=False) + manifold = tf.nn.sigmoid(manifold) + + manifold = tf.clip_by_value(manifold, 0, 1) * 255. + manifold = tf.reshape( + manifold, + [n_angle_3, + manifold_side * manifold_side, + image_size, + image_size, + 3]) + manifold = tf.transpose( + tf.reshape( + manifold, + [n_angle_3, manifold_side, + image_size * manifold_side, image_size, 3]), [0, 1, 3, 2, 4]) + manifold = tf.transpose( + tf.reshape( + manifold, + [n_angle_3, image_size * manifold_side, + image_size * manifold_side, 3]), + [0, 2, 1, 3]) + manifold = tf.transpose(manifold, [1, 2, 0, 3]) + manifold = tf.reshape( + manifold, + [1, image_size * manifold_side, + image_size * manifold_side, 3 * n_angle_3]) + tf.summary.image( + "manifold", + tf.cast(manifold[:, :, :, :3], tf.uint8), + max_outputs=1) + + # COMPRESSION + z_complete, _ = encoder( + input_=logit_x_in[:hps.n_scale, :, :, :], hps=hps, + n_scale=hps.n_scale, + use_batch_norm=hps.use_batch_norm, weight_norm=True, + train=False) + z_compressed_list = [z_complete] + z_noisy_list = [z_complete] + z_lost = z_complete + for scale_idx in xrange(hps.n_scale - 1): + z_lost = squeeze_2x2_ordered(z_lost) + z_lost, _ = tf.split(z_lost, 2, 3) + z_compressed = z_lost + z_noisy = z_lost + 
for _ in xrange(scale_idx + 1): + z_compressed = tf.concat_v2( + [z_compressed, tf.zeros_like(z_compressed)], 3) + z_compressed = squeeze_2x2_ordered( + z_compressed, reverse=True) + z_noisy = tf.concat_v2( + [z_noisy, tf.random_normal( + z_noisy.get_shape().as_list())], 3) + z_noisy = squeeze_2x2_ordered(z_noisy, reverse=True) + z_compressed_list.append(z_compressed) + z_noisy_list.append(z_noisy) + self.z_reduced = z_lost + z_compressed = tf.concat_v2(z_compressed_list, 0) + z_noisy = tf.concat_v2(z_noisy_list, 0) + noisy_images, _ = decoder( + input_=z_noisy, hps=hps, n_scale=hps.n_scale, + use_batch_norm=hps.use_batch_norm, weight_norm=True, + train=False) + compressed_images, _ = decoder( + input_=z_compressed, hps=hps, n_scale=hps.n_scale, + use_batch_norm=hps.use_batch_norm, weight_norm=True, + train=False) + noisy_images = tf.nn.sigmoid(noisy_images) + compressed_images = tf.nn.sigmoid(compressed_images) + + noisy_images = tf.clip_by_value(noisy_images, 0, 1) * 255. + noisy_images = tf.reshape( + noisy_images, + [(hps.n_scale * hps.n_scale), image_size, image_size, 3]) + noisy_images = tf.transpose( + tf.reshape( + noisy_images, + [hps.n_scale, image_size * hps.n_scale, image_size, 3]), + [0, 2, 1, 3]) + noisy_images = tf.transpose( + tf.reshape( + noisy_images, + [1, image_size * hps.n_scale, image_size * hps.n_scale, 3]), + [0, 2, 1, 3]) + tf.summary.image( + "noise", + tf.cast(noisy_images, tf.uint8), + max_outputs=1) + compressed_images = tf.clip_by_value(compressed_images, 0, 1) * 255. 
+ compressed_images = tf.reshape( + compressed_images, + [(hps.n_scale * hps.n_scale), image_size, image_size, 3]) + compressed_images = tf.transpose( + tf.reshape( + compressed_images, + [hps.n_scale, image_size * hps.n_scale, image_size, 3]), + [0, 2, 1, 3]) + compressed_images = tf.transpose( + tf.reshape( + compressed_images, + [1, image_size * hps.n_scale, image_size * hps.n_scale, 3]), + [0, 2, 1, 3]) + tf.summary.image( + "compression", + tf.cast(compressed_images, tf.uint8), + max_outputs=1) + + # SAMPLES x2 + final_shape[0] *= 2 + final_shape[1] *= 2 + big_sample = standard_normal_sample([25] + final_shape) + big_sample, _ = decoder( + input_=big_sample, hps=hps, n_scale=hps.n_scale, + use_batch_norm=hps.use_batch_norm, weight_norm=True, + train=True) + big_sample = tf.nn.sigmoid(big_sample) + + big_sample = tf.clip_by_value(big_sample, 0, 1) * 255. + big_sample = tf.reshape( + big_sample, + [25, image_size * 2, image_size * 2, 3]) + big_sample = tf.transpose( + tf.reshape( + big_sample, + [5, image_size * 10, image_size * 2, 3]), [0, 2, 1, 3]) + big_sample = tf.transpose( + tf.reshape( + big_sample, + [1, image_size * 10, image_size * 10, 3]), + [0, 2, 1, 3]) + tf.summary.image( + "big_sample", + tf.cast(big_sample, tf.uint8), + max_outputs=1) + + # SAMPLES x10 + final_shape[0] *= 5 + final_shape[1] *= 5 + extra_large = standard_normal_sample([1] + final_shape) + extra_large, _ = decoder( + input_=extra_large, hps=hps, n_scale=hps.n_scale, + use_batch_norm=hps.use_batch_norm, weight_norm=True, + train=True) + extra_large = tf.nn.sigmoid(extra_large) + + extra_large = tf.clip_by_value(extra_large, 0, 1) * 255. 
+ tf.summary.image( + "extra_large", + tf.cast(extra_large, tf.uint8), + max_outputs=1) + + def eval_epoch(self, hps): + """Evaluate bits/dim.""" + n_eval_dict = { + "imnet": 50000, + "lsun": 300, + "celeba": 19962, + "svhn": 26032, + } + if FLAGS.eval_set_size == 0: + num_examples_eval = n_eval_dict[FLAGS.dataset] + else: + num_examples_eval = FLAGS.eval_set_size + n_epoch = num_examples_eval / hps.batch_size + eval_costs = [] + bar_len = 70 + for epoch_idx in xrange(n_epoch): + n_equal = epoch_idx * bar_len * 1. / n_epoch + n_equal = numpy.ceil(n_equal) + n_equal = int(n_equal) + n_dash = bar_len - n_equal + progress_bar = "[" + "=" * n_equal + "-" * n_dash + "]\r" + print progress_bar, + cost = self.bit_per_dim.eval() + eval_costs.append(cost) + print "" + return float(numpy.mean(eval_costs)) + + +def train_model(hps, logdir): + """Training.""" + with tf.Graph().as_default(): + with tf.device(tf.train.replica_device_setter(0)): + with tf.variable_scope("model"): + model = RealNVP(hps) + + saver = tf.train.Saver(tf.global_variables()) + + # Build the summary operation from the last tower summaries. + summary_op = tf.summary.merge_all() + + # Build an initialization operation to run below. + init = tf.global_variables_initializer() + + # Start running operations on the Graph. allow_soft_placement must be set to + # True to build towers on GPU, as some of the ops do not have GPU + # implementations. + sess = tf.Session(config=tf.ConfigProto( + allow_soft_placement=True, + log_device_placement=True)) + sess.run(init) + + ckpt_state = tf.train.get_checkpoint_state(logdir) + if ckpt_state and ckpt_state.model_checkpoint_path: + print "Loading file %s" % ckpt_state.model_checkpoint_path + saver.restore(sess, ckpt_state.model_checkpoint_path) + + # Start the queue runners. 
+ tf.train.start_queue_runners(sess=sess) + + summary_writer = tf.summary.FileWriter( + logdir, + graph=sess.graph) + + local_step = 0 + while True: + fetches = [model.step, model.bit_per_dim, model.train_step] + # The chief worker evaluates the summaries every 100 steps. + should_eval_summaries = local_step % 100 == 0 + if should_eval_summaries: + fetches += [summary_op] + + + start_time = time.time() + outputs = sess.run(fetches) + global_step_val = outputs[0] + loss = outputs[1] + duration = time.time() - start_time + assert not numpy.isnan( + loss), 'Model diverged with loss = NaN' + + if local_step % 10 == 0: + examples_per_sec = hps.batch_size / float(duration) + format_str = ('%s: step %d, loss = %.2f ' + '(%.1f examples/sec; %.3f ' + 'sec/batch)') + print format_str % (datetime.now(), global_step_val, loss, + examples_per_sec, duration) + + if should_eval_summaries: + summary_str = outputs[-1] + summary_writer.add_summary(summary_str, global_step_val) + + # Save the model checkpoint periodically. 
+ if local_step % 1000 == 0 or (local_step + 1) == FLAGS.train_steps: + checkpoint_path = os.path.join(logdir, 'model.ckpt') + saver.save( + sess, + checkpoint_path, + global_step=global_step_val) + + if outputs[0] >= FLAGS.train_steps: + break + + local_step += 1 + + +def evaluate(hps, logdir, traindir, subset="valid", return_val=False): + """Evaluation.""" + hps.batch_size = 100 + with tf.Graph().as_default(): + with tf.device("/cpu:0"): + with tf.variable_scope("model") as var_scope: + eval_model = RealNVP(hps) + summary_writer = tf.summary.FileWriter(logdir) + var_scope.reuse_variables() + + saver = tf.train.Saver() + sess = tf.Session(config=tf.ConfigProto( + allow_soft_placement=True, + log_device_placement=True)) + tf.train.start_queue_runners(sess) + + previous_global_step = 0 # don't run eval for step = 0 + + with sess.as_default(): + while True: + ckpt_state = tf.train.get_checkpoint_state(traindir) + if not (ckpt_state and ckpt_state.model_checkpoint_path): + print "No model to eval yet at %s" % traindir + time.sleep(30) + continue + print "Loading file %s" % ckpt_state.model_checkpoint_path + saver.restore(sess, ckpt_state.model_checkpoint_path) + + current_step = tf.train.global_step(sess, eval_model.step) + if current_step == previous_global_step: + print "Waiting for the checkpoint to be updated." + time.sleep(30) + continue + previous_global_step = current_step + + print "Evaluating..." + bit_per_dim = eval_model.eval_epoch(hps) + print ("Epoch: %d, %s -> %.3f bits/dim" + % (current_step, subset, bit_per_dim)) + print "Writing summary..." 
+ summary = tf.Summary() + summary.value.extend( + [tf.Summary.Value( + tag="bit_per_dim", + simple_value=bit_per_dim)]) + summary_writer.add_summary(summary, current_step) + + if return_val: + return current_step, bit_per_dim + + +def sample_from_model(hps, logdir, traindir): + """Sampling.""" + hps.batch_size = 100 + with tf.Graph().as_default(): + with tf.device("/cpu:0"): + with tf.variable_scope("model") as var_scope: + eval_model = RealNVP(hps, sampling=True) + summary_writer = tf.summary.FileWriter(logdir) + var_scope.reuse_variables() + + summary_op = tf.summary.merge_all() + saver = tf.train.Saver() + sess = tf.Session(config=tf.ConfigProto( + allow_soft_placement=True, + log_device_placement=True)) + coord = tf.train.Coordinator() + threads = tf.train.start_queue_runners(sess=sess, coord=coord) + + previous_global_step = 0 # don't run eval for step = 0 + + initialized = False + with sess.as_default(): + while True: + ckpt_state = tf.train.get_checkpoint_state(traindir) + if not (ckpt_state and ckpt_state.model_checkpoint_path): + if not initialized: + print "No model to eval yet at %s" % traindir + time.sleep(30) + continue + else: + print ("Loading file %s" + % ckpt_state.model_checkpoint_path) + saver.restore(sess, ckpt_state.model_checkpoint_path) + + current_step = tf.train.global_step(sess, eval_model.step) + if current_step == previous_global_step: + print "Waiting for the checkpoint to be updated." 
+ time.sleep(30) + continue + previous_global_step = current_step + + fetches = [summary_op] + + outputs = sess.run(fetches) + summary_writer.add_summary(outputs[0], current_step) + coord.request_stop() + coord.join(threads) + + +def main(unused_argv): + hps = get_default_hparams().update_config(FLAGS.hpconfig) + if FLAGS.mode == "train": + train_model(hps=hps, logdir=FLAGS.logdir) + elif FLAGS.mode == "sample": + sample_from_model(hps=hps, logdir=FLAGS.logdir, + traindir=FLAGS.traindir) + else: + hps.batch_size = 100 + evaluate(hps=hps, logdir=FLAGS.logdir, + traindir=FLAGS.traindir, subset=FLAGS.mode) + +if __name__ == "__main__": + tf.app.run() diff --git a/real_nvp/real_nvp_utils.py b/real_nvp/real_nvp_utils.py new file mode 100644 index 00000000000..203ca35ec4a --- /dev/null +++ b/real_nvp/real_nvp_utils.py @@ -0,0 +1,474 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +r"""Utility functions for Real NVP. 
+""" + +# pylint: disable=dangerous-default-value + +import numpy +import tensorflow as tf +from tensorflow.python.framework import ops + +DEFAULT_BN_LAG = .0 + + +def stable_var(input_, mean=None, axes=[0]): + """Numerically more stable variance computation.""" + if mean is None: + mean = tf.reduce_mean(input_, axes) + res = tf.square(input_ - mean) + max_sqr = tf.reduce_max(res, axes) + res /= max_sqr + res = tf.reduce_mean(res, axes) + res *= max_sqr + + return res + + +def variable_on_cpu(name, shape, initializer, trainable=True): + """Helper to create a Variable stored on CPU memory. + + Args: + name: name of the variable + shape: list of ints + initializer: initializer for Variable + trainable: boolean defining if the variable is for training + Returns: + Variable Tensor + """ + var = tf.get_variable( + name, shape, initializer=initializer, trainable=trainable) + return var + + +# layers +def conv_layer(input_, + filter_size, + dim_in, + dim_out, + name, + stddev=1e-2, + strides=[1, 1, 1, 1], + padding="SAME", + nonlinearity=None, + bias=False, + weight_norm=False, + scale=False): + """Convolutional layer.""" + with tf.variable_scope(name) as scope: + weights = variable_on_cpu( + "weights", + filter_size + [dim_in, dim_out], + tf.random_uniform_initializer( + minval=-stddev, maxval=stddev)) + # weight normalization + if weight_norm: + weights /= tf.sqrt(tf.reduce_sum(tf.square(weights), [0, 1, 2])) + if scale: + magnitude = variable_on_cpu( + "magnitude", [dim_out], + tf.constant_initializer( + stddev * numpy.sqrt(dim_in * numpy.prod(filter_size) / 12.))) + weights *= magnitude + res = input_ + # handling filter size bigger than image size + if hasattr(input_, "shape"): + if input_.get_shape().as_list()[1] < filter_size[0]: + pad_1 = tf.zeros([ + input_.get_shape().as_list()[0], + filter_size[0] - input_.get_shape().as_list()[1], + input_.get_shape().as_list()[2], + input_.get_shape().as_list()[3] + ]) + pad_2 = tf.zeros([ + input_.get_shape().as_list()[0], + 
filter_size[0], + filter_size[1] - input_.get_shape().as_list()[2], + input_.get_shape().as_list()[3] + ]) + res = tf.concat(1, [pad_1, res]) + res = tf.concat(2, [pad_2, res]) + res = tf.nn.conv2d( + input=res, + filter=weights, + strides=strides, + padding=padding, + name=scope.name) + + if hasattr(input_, "shape"): + if input_.get_shape().as_list()[1] < filter_size[0]: + res = tf.slice(res, [ + 0, filter_size[0] - input_.get_shape().as_list()[1], + filter_size[1] - input_.get_shape().as_list()[2], 0 + ], [-1, -1, -1, -1]) + + if bias: + biases = variable_on_cpu("biases", [dim_out], tf.constant_initializer(0.)) + res = tf.nn.bias_add(res, biases) + if nonlinearity is not None: + res = nonlinearity(res) + + return res + + +def max_pool_2x2(input_): + """Max pooling.""" + return tf.nn.max_pool( + input_, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME") + + +def depool_2x2(input_, stride=2): + """Depooling.""" + shape = input_.get_shape().as_list() + batch_size = shape[0] + height = shape[1] + width = shape[2] + channels = shape[3] + res = tf.reshape(input_, [batch_size, height, 1, width, 1, channels]) + res = tf.concat( + 2, [res, tf.zeros([batch_size, height, stride - 1, width, 1, channels])]) + res = tf.concat(4, [ + res, tf.zeros([batch_size, height, stride, width, stride - 1, channels]) + ]) + res = tf.reshape(res, [batch_size, stride * height, stride * width, channels]) + + return res + + +# random flip on a batch of images +def batch_random_flip(input_): + """Simultaneous horizontal random flip.""" + if isinstance(input_, (float, int)): + return input_ + shape = input_.get_shape().as_list() + batch_size = shape[0] + height = shape[1] + width = shape[2] + channels = shape[3] + res = tf.split(0, batch_size, input_) + res = [elem[0, :, :, :] for elem in res] + res = [tf.image.random_flip_left_right(elem) for elem in res] + res = [tf.reshape(elem, [1, height, width, channels]) for elem in res] + res = tf.concat(0, res) + + return res + + +# build a one 
hot representation corresponding to the integer tensor +# the one-hot dimension is appended to the integer tensor shape +def as_one_hot(input_, n_indices): + """Convert indices to one-hot.""" + shape = input_.get_shape().as_list() + n_elem = numpy.prod(shape) + indices = tf.range(n_elem) + indices = tf.cast(indices, tf.int64) + indices_input = tf.concat(0, [indices, tf.reshape(input_, [-1])]) + indices_input = tf.reshape(indices_input, [2, -1]) + indices_input = tf.transpose(indices_input) + res = tf.sparse_to_dense( + indices_input, [n_elem, n_indices], 1., 0., name="flat_one_hot") + res = tf.reshape(res, [elem for elem in shape] + [n_indices]) + + return res + + +def squeeze_2x2(input_): + """Squeezing operation: reshape to convert space to channels.""" + return squeeze_nxn(input_, n_factor=2) + + +def squeeze_nxn(input_, n_factor=2): + """Squeezing operation: reshape to convert space to channels.""" + if isinstance(input_, (float, int)): + return input_ + shape = input_.get_shape().as_list() + batch_size = shape[0] + height = shape[1] + width = shape[2] + channels = shape[3] + if height % n_factor != 0: + raise ValueError("Height not divisible by %d." % n_factor) + if width % n_factor != 0: + raise ValueError("Width not divisible by %d." 
% n_factor) + res = tf.reshape( + input_, + [batch_size, + height // n_factor, + n_factor, width // n_factor, + n_factor, channels]) + res = tf.transpose(res, [0, 1, 3, 5, 2, 4]) + res = tf.reshape( + res, + [batch_size, + height // n_factor, + width // n_factor, + channels * n_factor * n_factor]) + + return res + + +def unsqueeze_2x2(input_): + """Unsqueezing operation: reshape to convert channels into space.""" + if isinstance(input_, (float, int)): + return input_ + shape = input_.get_shape().as_list() + batch_size = shape[0] + height = shape[1] + width = shape[2] + channels = shape[3] + if channels % 4 != 0: + raise ValueError("Number of channels not divisible by 4.") + res = tf.reshape(input_, [batch_size, height, width, channels // 4, 2, 2]) + res = tf.transpose(res, [0, 1, 4, 2, 5, 3]) + res = tf.reshape(res, [batch_size, 2 * height, 2 * width, channels // 4]) + + return res + + +# batch norm +def batch_norm(input_, + dim, + name, + scale=True, + train=True, + epsilon=1e-8, + decay=.1, + axes=[0], + bn_lag=DEFAULT_BN_LAG): + """Batch normalization.""" + # create variables + with tf.variable_scope(name): + var = variable_on_cpu( + "var", [dim], tf.constant_initializer(1.), trainable=False) + mean = variable_on_cpu( + "mean", [dim], tf.constant_initializer(0.), trainable=False) + step = variable_on_cpu("step", [], tf.constant_initializer(0.), trainable=False) + if scale: + gamma = variable_on_cpu("gamma", [dim], tf.constant_initializer(1.)) + beta = variable_on_cpu("beta", [dim], tf.constant_initializer(0.)) + # choose the appropriate moments + if train: + used_mean, used_var = tf.nn.moments(input_, axes, name="batch_norm") + cur_mean, cur_var = used_mean, used_var + if bn_lag > 0.: + used_mean -= (1. - bn_lag) * (used_mean - tf.stop_gradient(mean)) + used_var -= (1 - bn_lag) * (used_var - tf.stop_gradient(var)) + used_mean /= (1. - bn_lag**(step + 1)) + used_var /= (1. 
- bn_lag**(step + 1)) + else: + used_mean, used_var = mean, var + cur_mean, cur_var = used_mean, used_var + + # normalize + res = (input_ - used_mean) / tf.sqrt(used_var + epsilon) + # de-normalize + if scale: + res *= gamma + res += beta + + # update variables + if train: + with tf.name_scope(name, "AssignMovingAvg", [mean, cur_mean, decay]): + with ops.colocate_with(mean): + new_mean = tf.assign_sub( + mean, + tf.check_numerics(decay * (mean - cur_mean), "NaN in moving mean.")) + with tf.name_scope(name, "AssignMovingAvg", [var, cur_var, decay]): + with ops.colocate_with(var): + new_var = tf.assign_sub( + var, + tf.check_numerics(decay * (var - cur_var), + "NaN in moving variance.")) + with tf.name_scope(name, "IncrementTime", [step]): + with ops.colocate_with(step): + new_step = tf.assign_add(step, 1.) + res += 0. * new_mean * new_var * new_step + + return res + + +# batch normalization taking into account the volume transformation +def batch_norm_log_diff(input_, + dim, + name, + train=True, + epsilon=1e-8, + decay=.1, + axes=[0], + reuse=None, + bn_lag=DEFAULT_BN_LAG): + """Batch normalization with corresponding log determinant Jacobian.""" + if reuse is None: + reuse = not train + # create variables + with tf.variable_scope(name) as scope: + if reuse: + scope.reuse_variables() + var = variable_on_cpu( + "var", [dim], tf.constant_initializer(1.), trainable=False) + mean = variable_on_cpu( + "mean", [dim], tf.constant_initializer(0.), trainable=False) + step = variable_on_cpu("step", [], tf.constant_initializer(0.), trainable=False) + # choose the appropriate moments + if train: + used_mean, used_var = tf.nn.moments(input_, axes, name="batch_norm") + cur_mean, cur_var = used_mean, used_var + if bn_lag > 0.: + used_var = stable_var(input_=input_, mean=used_mean, axes=axes) + cur_var = used_var + used_mean -= (1 - bn_lag) * (used_mean - tf.stop_gradient(mean)) + used_mean /= (1. 
- bn_lag**(step + 1)) + used_var -= (1 - bn_lag) * (used_var - tf.stop_gradient(var)) + used_var /= (1. - bn_lag**(step + 1)) + else: + used_mean, used_var = mean, var + cur_mean, cur_var = used_mean, used_var + + # update variables + if train: + with tf.name_scope(name, "AssignMovingAvg", [mean, cur_mean, decay]): + with ops.colocate_with(mean): + new_mean = tf.assign_sub( + mean, + tf.check_numerics( + decay * (mean - cur_mean), "NaN in moving mean.")) + with tf.name_scope(name, "AssignMovingAvg", [var, cur_var, decay]): + with ops.colocate_with(var): + new_var = tf.assign_sub( + var, + tf.check_numerics(decay * (var - cur_var), + "NaN in moving variance.")) + with tf.name_scope(name, "IncrementTime", [step]): + with ops.colocate_with(step): + new_step = tf.assign_add(step, 1.) + used_var += 0. * new_mean * new_var * new_step + used_var += epsilon + + return used_mean, used_var + + +def convnet(input_, + dim_in, + dim_hid, + filter_sizes, + dim_out, + name, + use_batch_norm=True, + train=True, + nonlinearity=tf.nn.relu): + """Chaining of convolutional layers.""" + dims_in = [dim_in] + dim_hid[:-1] + dims_out = dim_hid + res = input_ + + bias = (not use_batch_norm) + with tf.variable_scope(name): + for layer_idx in xrange(len(dim_hid)): + res = conv_layer( + input_=res, + filter_size=filter_sizes[layer_idx], + dim_in=dims_in[layer_idx], + dim_out=dims_out[layer_idx], + name="h_%d" % layer_idx, + stddev=1e-2, + nonlinearity=None, + bias=bias) + if use_batch_norm: + res = batch_norm( + input_=res, + dim=dims_out[layer_idx], + name="bn_%d" % layer_idx, + scale=(nonlinearity == tf.nn.relu), + train=train, + epsilon=1e-8, + axes=[0, 1, 2]) + if nonlinearity is not None: + res = nonlinearity(res) + + res = conv_layer( + input_=res, + filter_size=filter_sizes[-1], + dim_in=dims_out[-1], + dim_out=dim_out, + name="out", + stddev=1e-2, + nonlinearity=None) + + return res + + +# distributions +# log-likelihood estimation +def standard_normal_ll(input_): + """Log-likelihood 
of standard Gaussian distribution.""" + res = -.5 * (tf.square(input_) + numpy.log(2. * numpy.pi)) + + return res + + +def standard_normal_sample(shape): + """Samples from standard Gaussian distribution.""" + return tf.random_normal(shape) + + +SQUEEZE_MATRIX = numpy.array([[[[1., 0., 0., 0.]], [[0., 0., 1., 0.]]], + [[[0., 0., 0., 1.]], [[0., 1., 0., 0.]]]]) + + +def squeeze_2x2_ordered(input_, reverse=False): + """Squeezing operation with a controlled ordering.""" + shape = input_.get_shape().as_list() + batch_size = shape[0] + height = shape[1] + width = shape[2] + channels = shape[3] + if reverse: + if channels % 4 != 0: + raise ValueError("Number of channels not divisible by 4.") + channels /= 4 + else: + if height % 2 != 0: + raise ValueError("Height not divisible by 2.") + if width % 2 != 0: + raise ValueError("Width not divisible by 2.") + weights = numpy.zeros((2, 2, channels, 4 * channels)) + for idx_ch in xrange(channels): + slice_2 = slice(idx_ch, (idx_ch + 1)) + slice_3 = slice((idx_ch * 4), ((idx_ch + 1) * 4)) + weights[:, :, slice_2, slice_3] = SQUEEZE_MATRIX + shuffle_channels = [idx_ch * 4 for idx_ch in xrange(channels)] + shuffle_channels += [idx_ch * 4 + 1 for idx_ch in xrange(channels)] + shuffle_channels += [idx_ch * 4 + 2 for idx_ch in xrange(channels)] + shuffle_channels += [idx_ch * 4 + 3 for idx_ch in xrange(channels)] + shuffle_channels = numpy.array(shuffle_channels) + weights = weights[:, :, :, shuffle_channels].astype("float32") + if reverse: + res = tf.nn.conv2d_transpose( + value=input_, + filter=weights, + output_shape=[batch_size, height * 2, width * 2, channels], + strides=[1, 2, 2, 1], + padding="SAME", + name="unsqueeze_2x2") + else: + res = tf.nn.conv2d( + input=input_, + filter=weights, + strides=[1, 2, 2, 1], + padding="SAME", + name="squeeze_2x2") + + return res
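
Review note on `squeeze_2x2_ordered`: the channel ordering it produces is not obvious from the convolution weights alone, so here is a minimal pure-Python sketch of what the forward and reverse paths appear to compute for even-sized inputs. It assumes the ordering read off `SQUEEZE_MATRIX` and the `shuffle_channels` permutation is (0,0), (1,1), (0,1), (1,0) within each 2x2 block. The names `BLOCK_ORDER`, `squeeze_2x2_ordered_py` and `unsqueeze_2x2_ordered_py` are hypothetical and are not part of the diff; nested lists stand in for a single image, not a batched tensor.

```python
# Illustrative only: plain nested lists stand in for one image of shape
# [height][width][channels]. This is not the TensorFlow implementation.

# (row, column) offsets inside each 2x2 block, in the order in which they
# fill the four output channel groups. Assumed from SQUEEZE_MATRIX plus
# the channel shuffle in squeeze_2x2_ordered.
BLOCK_ORDER = [(0, 0), (1, 1), (0, 1), (1, 0)]


def squeeze_2x2_ordered_py(image):
    """Map [H][W][C] to [H/2][W/2][4*C], one channel group per offset."""
    height, width = len(image), len(image[0])
    if height % 2 != 0 or width % 2 != 0:
        raise ValueError("Height and width must be divisible by 2.")
    out = []
    for i in range(height // 2):
        row = []
        for j in range(width // 2):
            pixel = []
            for d_row, d_col in BLOCK_ORDER:
                # Append all C channels of this block position.
                pixel.extend(image[2 * i + d_row][2 * j + d_col])
            row.append(pixel)
        out.append(row)
    return out


def unsqueeze_2x2_ordered_py(image):
    """Inverse mapping: [H][W][4*C] back to [2H][2W][C]."""
    height, width = len(image), len(image[0])
    depth = len(image[0][0])
    if depth % 4 != 0:
        raise ValueError("Number of channels not divisible by 4.")
    channels = depth // 4
    out = [[None] * (2 * width) for _ in range(2 * height)]
    for i in range(height):
        for j in range(width):
            for k, (d_row, d_col) in enumerate(BLOCK_ORDER):
                group = image[i][j][k * channels:(k + 1) * channels]
                out[2 * i + d_row][2 * j + d_col] = list(group)
    return out


# A 2x2 single-channel image whose values name their raster position.
example = [[[0], [1]],
           [[2], [3]]]
squeezed = squeeze_2x2_ordered_py(example)    # [[[0, 3, 1, 2]]]
restored = unsqueeze_2x2_ordered_py(squeezed)  # equals example
```

If this reading is right, the first and second channel groups hold the two diagonal positions of each block and the last two hold the anti-diagonal positions. A one-line comment to that effect next to `SQUEEZE_MATRIX` would make the intent easier to follow.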