Skip to content

Commit

Permalink
Fixed calls to concat and convolution2d
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgilbert12 committed Jun 6, 2017
1 parent cfdbdf1 commit 514a10d
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions adversarial_crypto/train_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ def model(self, collection, message, key=None):
"""

if key is not None:
combined_message = tf.concat(1, [message, key])
combined_message = tf.concat([message, key], 1)
else:
combined_message = message

# Ensure that all variables created are in the specified collection.
with tf.contrib.framework.arg_scope(
[tf.contrib.layers.fully_connected, tf.contrib.layers.convolution],
[tf.contrib.layers.fully_connected, tf.contrib.layers.convolution2d],
variables_collections=[collection]):

fc = tf.contrib.layers.fully_connected(
Expand All @@ -147,13 +147,13 @@ def model(self, collection, message, key=None):
# and then squeezing it back down).
fc = tf.expand_dims(fc, 2)
# 2,1 -> 1,2
conv = tf.contrib.layers.convolution(
conv = tf.contrib.layers.convolution2d(
fc, 2, 2, 2, 'SAME', activation_fn=tf.nn.sigmoid)
# 1,2 -> 1, 2
conv = tf.contrib.layers.convolution(
conv = tf.contrib.layers.convolution2d(
conv, 2, 1, 1, 'SAME', activation_fn=tf.nn.sigmoid)
# 1,2 -> 1, 1
conv = tf.contrib.layers.convolution(
conv = tf.contrib.layers.convolution2d(
conv, 1, 1, 1, 'SAME', activation_fn=tf.nn.tanh)
conv = tf.squeeze(conv, 2)
return conv
Expand Down

0 comments on commit 514a10d

Please sign in to comment.