Reproduce reported virtual adversarial text results
Ryan Sepassi committed Jun 15, 2017
1 parent fc7342b commit b5afddb
Showing 17 changed files with 208 additions and 116 deletions.
21 changes: 21 additions & 0 deletions adversarial_text/BUILD
@@ -1,10 +1,14 @@
licenses(["notice"]) # Apache 2.0

# Binaries
# ==============================================================================
py_binary(
name = "evaluate",
srcs = ["evaluate.py"],
deps = [
":graphs",
# google3 file dep,
# tensorflow dep,
],
)

@@ -14,6 +18,8 @@ py_binary(
deps = [
":graphs",
":train_utils",
# google3 file dep,
# tensorflow dep,
],
)

@@ -25,6 +31,8 @@ py_binary(
deps = [
":graphs",
":train_utils",
# google3 file dep,
# tensorflow dep,
],
)

@@ -37,30 +45,42 @@ py_library(
":adversarial_losses",
":inputs",
":layers",
# tensorflow dep,
],
)

py_library(
name = "adversarial_losses",
srcs = ["adversarial_losses.py"],
deps = [
# tensorflow dep,
],
)

py_library(
name = "inputs",
srcs = ["inputs.py"],
deps = [
# tensorflow dep,
"//adversarial_text/data:data_utils",
],
)

py_library(
name = "layers",
srcs = ["layers.py"],
deps = [
# tensorflow dep,
],
)

py_library(
name = "train_utils",
srcs = ["train_utils.py"],
deps = [
# numpy dep,
# tensorflow dep,
],
)

# Tests
@@ -71,6 +91,7 @@ py_test(
srcs = ["graphs_test.py"],
deps = [
":graphs",
# tensorflow dep,
"//adversarial_text/data:data_utils",
],
)
19 changes: 9 additions & 10 deletions adversarial_text/README.md
@@ -56,7 +56,6 @@ $ bazel run :pretrain -- \
--embedding_dims=256 \
--rnn_cell_size=1024 \
--num_candidate_samples=1024 \
--optimizer=adam \
--batch_size=256 \
--learning_rate=0.001 \
--learning_rate_decay_factor=0.9999 \
@@ -87,7 +86,6 @@ $ bazel run :train_classifier -- \
--rnn_cell_size=1024 \
--cl_num_layers=1 \
--cl_hidden_size=30 \
--optimizer=adam \
--batch_size=64 \
--learning_rate=0.0005 \
--learning_rate_decay_factor=0.9998 \
@@ -96,7 +94,8 @@ $ bazel run :train_classifier -- \
--num_timesteps=400 \
--keep_prob_emb=0.5 \
--normalize_embeddings \
--adv_training_method=vat
--adv_training_method=vat \
--perturb_norm_length=5.0
```
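
As a rough sketch of what this VAT configuration optimizes (pseudocode paraphrasing https://arxiv.org/abs/1507.00677, not text from the committed README):

```
total_loss = classification_loss
           + KL( p(y | emb) || p(y | emb + r_vadv) ),  with ||r_vadv||_2 = perturb_norm_length
```

Here `emb` are the word embeddings and `r_vadv` is the virtual adversarial perturbation, so `--perturb_norm_length=5.0` sets the L2 norm of that perturbation.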

### Evaluate on test data
@@ -136,21 +135,21 @@ adversarial training losses). The training loop itself is defined in
### Command-Line Flags

Flags related to distributed training and the training loop itself are defined
in `train_utils.py`.
in [`train_utils.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/train_utils.py).

Flags related to model hyperparameters are defined in `graphs.py`.
Flags related to model hyperparameters are defined in [`graphs.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/graphs.py).

Flags related to adversarial training are defined in `adversarial_losses.py`.
Flags related to adversarial training are defined in [`adversarial_losses.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/adversarial_losses.py).

Flags particular to each job are defined in the main binary files.

### Data Generation

* Vocabulary generation: `gen_vocab.py`
* Data generation: `gen_data.py`
* Vocabulary generation: [`gen_vocab.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/data/gen_vocab.py)
* Data generation: [`gen_data.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/data/gen_data.py)

Command-line flags defined in `document_generators.py` control which dataset is
processed and how.
Command-line flags defined in [`document_generators.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/data/document_generators.py)
control which dataset is processed and how.

## Contact for Issues

68 changes: 27 additions & 41 deletions adversarial_text/adversarial_losses.py
@@ -1,4 +1,4 @@
# Copyright 2017 Google, Inc. All Rights Reserved.
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,25 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Adversarial losses for text models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS

# Adversarial and virtual adversarial training parameters.
flags.DEFINE_float('perturb_norm_length', 0.1,
flags.DEFINE_float('perturb_norm_length', 5.0,
'Norm length of adversarial perturbation to be '
'optimized with validation')
'optimized with validation. '
'5.0 is optimal on IMDB with virtual adversarial training. ')

# Virtual adversarial training parameters
flags.DEFINE_integer('num_power_iteration', 1, 'The number of power iteration')
flags.DEFINE_float('small_constant_for_finite_diff', 1e-3,
flags.DEFINE_float('small_constant_for_finite_diff', 1e-1,
'Small constant for finite difference method')

# Parameters for building the graph
@@ -83,19 +85,22 @@ def virtual_adversarial_loss(logits, embedded, inputs,
"""
# Stop gradient of logits. See https://arxiv.org/abs/1507.00677 for details.
logits = tf.stop_gradient(logits)

# Only care about the KL divergence on the final timestep.
weights = _end_of_seq_mask(inputs.labels)
weights = inputs.eos_weights
assert weights is not None

# Initialize perturbation with random noise.
# shape(embedded) = (batch_size, num_timesteps, embedding_dim)
d = _mask_by_length(tf.random_normal(shape=tf.shape(embedded)), inputs.length)
d = tf.random_normal(shape=tf.shape(embedded))

# Perform finite difference method and power iteration.
# See Eq.(8) in the paper http://arxiv.org/pdf/1507.00677.pdf,
# Adding small noise to input and taking gradient with respect to the noise
# corresponds to 1 power iteration.
for _ in xrange(FLAGS.num_power_iteration):
d = _scale_l2(d, FLAGS.small_constant_for_finite_diff)
d = _scale_l2(
_mask_by_length(d, inputs.length), FLAGS.small_constant_for_finite_diff)
d_logits = logits_from_embedding_fn(embedded + d)
kl = _kl_divergence_with_logits(logits, d_logits, weights)
d, = tf.gradients(
@@ -104,8 +109,7 @@ def virtual_adversarial_loss(logits, embedded, inputs,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
d = tf.stop_gradient(d)

perturb = _scale_l2(
_mask_by_length(d, inputs.length), FLAGS.perturb_norm_length)
perturb = _scale_l2(d, FLAGS.perturb_norm_length)
vadv_logits = logits_from_embedding_fn(embedded + perturb)
return _kl_divergence_with_logits(logits, vadv_logits, weights)
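
A brief sketch of why the comment above equates one noisy-gradient step with one power iteration (paraphrasing https://arxiv.org/abs/1507.00677, not code from this commit): write KL(r) = KL(p(.|x) || p(.|x + r)) as a function of the perturbation r. The minimum is at r = 0, so the first-order term vanishes and KL(r) ≈ ½ rᵀHr for the Hessian H at r = 0. For a random direction d and a small ξ (the `small_constant_for_finite_diff` flag), the gradient of KL at r = ξd is approximately ξHd, so normalizing that gradient is one power-iteration step toward the dominant eigenvector of H, i.e. the direction the model is most sensitive to.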

@@ -136,7 +140,8 @@ def virtual_adversarial_loss_bidir(logits, embedded, inputs,
"""Virtual adversarial loss for bidirectional models."""
logits = tf.stop_gradient(logits)
f_inputs, _ = inputs
weights = _end_of_seq_mask(f_inputs.labels)
weights = f_inputs.eos_weights
assert weights is not None

perturbs = [
_mask_by_length(tf.random_normal(shape=tf.shape(emb)), f_inputs.length)
@@ -155,10 +160,7 @@ def virtual_adversarial_loss_bidir(logits, embedded, inputs,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
perturbs = [tf.stop_gradient(d) for d in perturbs]

perturbs = [
_scale_l2(_mask_by_length(d, f_inputs.length), FLAGS.perturb_norm_length)
for d in perturbs
]
perturbs = [_scale_l2(d, FLAGS.perturb_norm_length) for d in perturbs]
vadv_logits = logits_from_embedding_fn(
[emb + d for (emb, d) in zip(embedded, perturbs)])
return _kl_divergence_with_logits(logits, vadv_logits, weights)
@@ -167,40 +169,26 @@ def virtual_adversarial_loss_bidir(logits, embedded, inputs,
def _mask_by_length(t, length):
"""Mask t, 3-D [batch, time, dim], by length, 1-D [batch,]."""
maxlen = t.get_shape().as_list()[1]
mask = tf.sequence_mask(length, maxlen=maxlen)

# Subtract 1 from length to prevent the perturbation from going on 'eos'
mask = tf.sequence_mask(length - 1, maxlen=maxlen)
mask = tf.expand_dims(tf.cast(mask, tf.float32), -1)
# shape(mask) = (batch, num_timesteps, 1)
return t * mask


def _scale_l2(x, norm_length):
# shape(x) = (batch, num_timesteps, d)

# Divide x by max(abs(x)) for a numerically stable L2 norm.
# 2norm(x) = a * 2norm(x/a)
# Scale over the full sequence, dims (1, 2)
alpha = tf.reduce_max(tf.abs(x), (1, 2), keep_dims=True) + 1e-12
l2_norm = alpha * tf.sqrt(tf.reduce_sum(tf.pow(x / alpha, 2), (1, 2),
keep_dims=True) + 1e-6)
l2_norm = alpha * tf.sqrt(
tf.reduce_sum(tf.pow(x / alpha, 2), (1, 2), keep_dims=True) + 1e-6)
x_unit = x / l2_norm
return norm_length * x_unit
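
As an aside, a minimal NumPy sketch (not part of the commit) of the scaling `_scale_l2` performs; dividing by `max(abs(x))` first keeps the squared sum from over- or underflowing before the square root:

```python
import numpy as np

def scale_l2(x, norm_length, eps=1e-6):
    # x: (batch, num_timesteps, dim); scale each example over the (time, dim) axes.
    alpha = np.abs(x).max(axis=(1, 2), keepdims=True) + 1e-12
    l2_norm = alpha * np.sqrt(((x / alpha) ** 2).sum(axis=(1, 2), keepdims=True) + eps)
    return norm_length * x / l2_norm
```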


def _end_of_seq_mask(tokens):
"""Generate a mask for the EOS token (1.0 on EOS, 0.0 otherwise).
Args:
tokens: 1-D integer tensor [num_timesteps*batch_size]. Each element is an
id from the vocab.
Returns:
Float tensor same shape as tokens, whose values are 1.0 on the end of
sequence and 0.0 on the others.
"""
eos_id = FLAGS.vocab_size - 1
return tf.cast(tf.equal(tokens, eos_id), tf.float32)


def _kl_divergence_with_logits(q_logits, p_logits, weights):
"""Returns weighted KL divergence between distributions q and p.
Expand All @@ -218,21 +206,19 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):
# For logistic regression
if FLAGS.num_classes == 2:
q = tf.nn.sigmoid(q_logits)
p = tf.nn.sigmoid(p_logits)
kl = (-tf.nn.sigmoid_cross_entropy_with_logits(logits=q_logits, labels=q) +
tf.nn.sigmoid_cross_entropy_with_logits(logits=p_logits, labels=q))
kl = tf.squeeze(kl)

# For softmax regression
else:
q = tf.nn.softmax(q_logits)
p = tf.nn.softmax(p_logits)
kl = tf.reduce_sum(q * (tf.log(q) - tf.log(p)), 1)
kl = tf.reduce_sum(
q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), 1)

num_labels = tf.reduce_sum(weights)
num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)

kl.get_shape().assert_has_rank(2)
kl.get_shape().assert_has_rank(1)
weights.get_shape().assert_has_rank(1)
loss = tf.identity(tf.reduce_sum(tf.expand_dims(weights, -1) * kl) /
num_labels, name='kl')
loss = tf.identity(tf.reduce_sum(weights * kl) / num_labels, name='kl')
return loss
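
For reference, a small NumPy sketch (not from the commit) of the softmax branch above; computing `log(softmax(z))` directly can underflow to `log(0)`, while the shifted `log_softmax` form stays finite:

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # shift logits for numerical stability
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def kl_with_logits(q_logits, p_logits):
    q = np.exp(log_softmax(q_logits))       # softmax(q_logits)
    return (q * (log_softmax(q_logits) - log_softmax(p_logits))).sum(axis=-1)
```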
11 changes: 11 additions & 0 deletions adversarial_text/data/BUILD
@@ -1,3 +1,5 @@
licenses(["notice"]) # Apache 2.0

package(
default_visibility = [
"//adversarial_text:__subpackages__",
@@ -10,6 +12,7 @@ py_binary(
deps = [
":data_utils",
":document_generators",
# tensorflow dep,
],
)

@@ -19,23 +22,31 @@ py_binary(
deps = [
":data_utils",
":document_generators",
# tensorflow dep,
],
)

py_library(
name = "document_generators",
srcs = ["document_generators.py"],
deps = [
# tensorflow dep,
],
)

py_library(
name = "data_utils",
srcs = ["data_utils.py"],
deps = [
# tensorflow dep,
],
)

py_test(
name = "data_utils_test",
srcs = ["data_utils_test.py"],
deps = [
":data_utils",
# tensorflow dep,
],
)
20 changes: 13 additions & 7 deletions adversarial_text/data/data_utils.py
@@ -1,4 +1,4 @@
# Copyright 2017 Google, Inc. All Rights Reserved.
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities for generating/preprocessing data for adversarial text models."""

import operator
import os
import random
import re

# Dependency imports

import tensorflow as tf

EOS_TOKEN = '</s>'
@@ -215,13 +217,17 @@ def build_lm_sequence(seq):
Returns:
SequenceWrapper with `seq` tokens copied over to output sequence tokens and
labels (offset by 1, i.e. predict next token) with weights set to 1.0.
labels (offset by 1, i.e. predict next token) with weights set to 1.0,
except for <eos> token.
"""
lm_seq = SequenceWrapper()
for i, timestep in enumerate(seq[:-1]):
lm_seq.add_timestep().set_token(timestep.token).set_label(
seq[i + 1].token).set_weight(1.0)

for i, timestep in enumerate(seq):
if i == len(seq) - 1:
lm_seq.add_timestep().set_token(timestep.token).set_label(
seq[i].token).set_weight(0.0)
else:
lm_seq.add_timestep().set_token(timestep.token).set_label(
seq[i + 1].token).set_weight(1.0)
return lm_seq
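
To illustrate the new weighting (hypothetical tokens, not from the commit): for `seq = [the, cat, </s>]`, the loop now emits

```python
# (token, label, weight)
# (the,   cat,   1.0)
# (cat,   </s>,  1.0)
# (</s>,  </s>,  0.0)   # the <eos> step no longer contributes to the LM loss
```

so the language model is never trained to predict a token that follows `</s>`.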

