Skip to content

Commit

Permalink
Update Neural GPU: allow to quantize activations, add a few tasks.
Browse files Browse the repository at this point in the history
  • Loading branch information
Lukasz Kaiser committed Mar 14, 2016
1 parent c3589c2 commit 2e4f31a
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 42 deletions.
58 changes: 51 additions & 7 deletions neural_gpu/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Convolutional Gated Recurrent Networks for Algorithm Learning."""

import math
Expand All @@ -27,9 +26,10 @@

FLAGS = tf.app.flags.FLAGS

bins = [8, 16, 32, 64, 128]
all_tasks = ["sort", "id", "rev", "incr", "left", "right", "left-shift", "add",
"right-shift", "bmul", "dup", "badd", "qadd"]
bins = [8, 12, 16, 20, 24, 28, 32, 36, 40, 48, 64, 128]
all_tasks = ["sort", "kvsort", "id", "rev", "rev2", "incr", "add", "left",
"right", "left-shift", "right-shift", "bmul", "mul", "dup",
"badd", "qadd", "search"]
forward_max = 128
log_filename = ""

Expand Down Expand Up @@ -82,10 +82,13 @@ def rand_pair(l, task):
d2 = [np.random.randint(base) for _ in xrange(k)]
if task in ["add", "badd", "qadd"]:
res = add(d1, d2, base)
elif task in ["bmul"]:
elif task in ["mul", "bmul"]:
d1n = sum([d * (base ** i) for i, d in enumerate(d1)])
d2n = sum([d * (base ** i) for i, d in enumerate(d2)])
res = [int(x) for x in list(reversed(str(bin(d1n * d2n))))[:-2]]
if task == "bmul":
res = [int(x) for x in list(reversed(str(bin(d1n * d2n))))[:-2]]
else:
res = [int(x) for x in list(reversed(str(d1n * d2n)))]
else:
sys.exit()
sep = [12]
Expand All @@ -101,6 +104,32 @@ def rand_dup_pair(l):
res = x + x + [0 for _ in xrange(l - 2*k)]
return inp, res

def rand_rev2_pair(l):
"""Random data pair for reverse2 task. Total length should be <= l."""
inp = [(np.random.randint(nclass - 1) + 1,
np.random.randint(nclass - 1) + 1) for _ in xrange(l/2)]
res = [i for i in reversed(inp)]
return [x for p in inp for x in p], [x for p in res for x in p]

def rand_search_pair(l):
"""Random data pair for search task. Total length should be <= l."""
inp = [(np.random.randint(nclass - 1) + 1,
np.random.randint(nclass - 1) + 1) for _ in xrange(l-1/2)]
q = np.random.randint(nclass - 1) + 1
res = 0
for (k, v) in reversed(inp):
if k == q:
res = v
return [x for p in inp for x in p] + [q], [res]

def rand_kvsort_pair(l):
"""Random data pair for key-value sort. Total length should be <= l."""
keys = [(np.random.randint(nclass - 1) + 1, i) for i in xrange(l/2)]
vals = [np.random.randint(nclass - 1) + 1 for _ in xrange(l/2)]
kv = [(k, vals[i]) for (k, i) in keys]
sorted_kv = [(k, vals[i]) for (k, i) in sorted(keys)]
return [x for p in kv for x in p], [x for p in sorted_kv for x in p]

def spec(inp):
"""Return the target given the input for some tasks."""
if task == "sort":
Expand Down Expand Up @@ -140,7 +169,7 @@ def spec(inp):
cur_time = time.time()
if l > 10000 and case % 100 == 1:
print_out(" avg gen time %.4f s" % (total_time / float(case)))
if task in ["add", "badd", "qadd", "bmul"]:
if task in ["add", "badd", "qadd", "bmul", "mul"]:
i, t = rand_pair(l, task)
train_set[task][len(i)].append([i, t])
i, t = rand_pair(l, task)
Expand All @@ -150,6 +179,21 @@ def spec(inp):
train_set[task][len(i)].append([i, t])
i, t = rand_dup_pair(l)
test_set[task][len(i)].append([i, t])
elif task == "rev2":
i, t = rand_rev2_pair(l)
train_set[task][len(i)].append([i, t])
i, t = rand_rev2_pair(l)
test_set[task][len(i)].append([i, t])
elif task == "search":
i, t = rand_search_pair(l)
train_set[task][len(i)].append([i, t])
i, t = rand_search_pair(l)
test_set[task][len(i)].append([i, t])
elif task == "kvsort":
i, t = rand_kvsort_pair(l)
train_set[task][len(i)].append([i, t])
i, t = rand_kvsort_pair(l)
test_set[task][len(i)].append([i, t])
else:
inp = [np.random.randint(nclass - 1) + 1 for i in xrange(l)]
target = spec(inp)
Expand Down
59 changes: 48 additions & 11 deletions neural_gpu/neural_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""The Neural GPU Model."""

import time
Expand Down Expand Up @@ -47,17 +46,46 @@ def sigmoid_cutoff(x, cutoff):
return tf.minimum(1.0, tf.maximum(0.0, cutoff * y - d))


def tanh_cutoff(x, cutoff):
"""Tanh with cutoff, e.g., 1.1tanh(x) cut to [-1. 1]."""
y = tf.tanh(x)
if cutoff < 1.01: return y
d = (cutoff - 1.0) / 2.0
return tf.minimum(1.0, tf.maximum(-1.0, (1.0 + d) * y))


def conv_gru(inpts, mem, kw, kh, nmaps, cutoff, prefix):
"""Convolutional GRU."""
def conv_lin(args, suffix, bias_start):
return conv_linear(args, kw, kh, len(args) * nmaps, nmaps, True, bias_start,
prefix + "/" + suffix)
reset = sigmoid_cutoff(conv_lin(inpts + [mem], "r", 1.0), cutoff)
# candidate = tanh_cutoff(conv_lin(inpts + [reset * mem], "c", 0.0), cutoff)
candidate = tf.tanh(conv_lin(inpts + [reset * mem], "c", 0.0))
gate = sigmoid_cutoff(conv_lin(inpts + [mem], "g", 1.0), cutoff)
return gate * mem + (1 - gate) * candidate


@tf.RegisterGradient("CustomIdG")
def _custom_id_grad(_, grads):
return grads


def quantize(t, quant_scale, max_value=1.0):
"""Quantize a tensor t with each element in [-max_value, max_value]."""
t = tf.minimum(max_value, tf.maximum(t, -max_value))
big = quant_scale * (t + max_value) + 0.5
with tf.get_default_graph().gradient_override_map({"Floor": "CustomIdG"}):
res = (tf.floor(big) / quant_scale) - max_value
return res


def quantize_weights_op(quant_scale, max_value):
ops = [v.assign(quantize(v, quant_scale, float(max_value)))
for v in tf.trainable_variables()]
return tf.group(*ops)


def relaxed_average(var_name_suffix, rx_step):
"""Calculate the average of relaxed variables having var_name_suffix."""
relaxed_vars = []
Expand Down Expand Up @@ -117,7 +145,7 @@ class NeuralGPU(object):

def __init__(self, nmaps, vec_size, niclass, noclass, dropout, rx_step,
max_grad_norm, cutoff, nconvs, kw, kh, height, mode,
learning_rate, pull, pull_incr, min_length):
learning_rate, pull, pull_incr, min_length, act_noise=0.0):
# Feeds for parameters and ops to update them.
self.global_step = tf.Variable(0, trainable=False)
self.cur_length = tf.Variable(min_length, trainable=False)
Expand Down Expand Up @@ -195,7 +223,9 @@ def __init__(self, nmaps, vec_size, niclass, noclass, dropout, rx_step,
first = tf.concat(2, first)

# Computation steps.
step = [tf.nn.dropout(first, 1.0 - self.do_training * dropout) * mask]
keep_prob = 1.0 - self.do_training * (dropout * 8.0 / float(length))
step = [tf.nn.dropout(first, keep_prob) * mask]
act_noise_scale = act_noise * self.do_training * self.pull
outputs = []
for it in xrange(length):
with tf.variable_scope("RX%d" % (it % rx_step)) as vs:
Expand All @@ -205,9 +235,12 @@ def __init__(self, nmaps, vec_size, niclass, noclass, dropout, rx_step,
# Do nconvs-many CGRU steps.
for layer in xrange(nconvs):
cur = conv_gru([], cur, kw, kh, nmaps, cutoff, "cgru_%d" % layer)
cur = tf.nn.dropout(cur, 1.0 - self.do_training * dropout)
cur *= mask
outputs.append(tf.slice(cur, [0, 0, 0, 0], [-1, -1, 1, -1]))
cur = tf.nn.dropout(cur, keep_prob)
if act_noise > 0.00001:
cur += tf.truncated_normal(tf.shape(cur)) * act_noise_scale
step.append(cur * mask)
outputs.append(tf.slice(step[-1], [0, 0, 0, 0], [-1, -1, 1, -1]))

self.steps.append([tf.reshape(s, [-1, length, height * nmaps])
for s in step])
Expand All @@ -216,8 +249,10 @@ def __init__(self, nmaps, vec_size, niclass, noclass, dropout, rx_step,
# Final convolution to get logits, list outputs.
output = conv_linear(output, 1, 1, nmaps, noclass, True, 0.0, "output")
output = tf.reshape(output, [-1, length, noclass])
self.outputs.append([tf.reshape(o, [-1, noclass])
for o in list(tf.split(1, length, output))])
external_output = [tf.reshape(o, [-1, noclass])
for o in list(tf.split(1, length, output))]
external_output = [tf.nn.softmax(o) for o in external_output]
self.outputs.append(external_output)

# Calculate cross-entropy loss and normalize it.
targets = tf.concat(1, [make_dense(self.target[l], noclass)
Expand Down Expand Up @@ -252,7 +287,8 @@ def __init__(self, nmaps, vec_size, niclass, noclass, dropout, rx_step,
" %.2f s." % (length, time.time() - start_time))
self.saver = tf.train.Saver(tf.all_variables())

def step(self, sess, inp, target, do_backward, noise_param=None):
def step(self, sess, inp, target, do_backward, noise_param=None,
get_steps=False):
"""Run a step of the network."""
assert len(inp) == len(target)
length = len(target)
Expand All @@ -272,14 +308,15 @@ def step(self, sess, inp, target, do_backward, noise_param=None):
for l in xrange(length):
feed_in[self.target[l].name] = target[l]
feed_out.append(self.outputs[index][l])
for l in xrange(length+1):
feed_out.append(self.steps[index][l])
if get_steps:
for l in xrange(length+1):
feed_out.append(self.steps[index][l])
res = sess.run(feed_out, feed_in)
offset = 0
norm = None
if do_backward:
offset = 2
norm = res[1]
outputs = res[offset + 1:offset + 1 + length]
steps = res[offset + 1 + length:]
steps = res[offset + 1 + length:] if get_steps else None
return res[offset], outputs, norm, steps
Loading

0 comments on commit 2e4f31a

Please sign in to comment.