Skip to content

Commit

Permalink
Ability to train the translation model on arbitrary input sources.
Browse files Browse the repository at this point in the history
  • Loading branch information
Viacheslav Kovalevskyi committed Jan 4, 2017
1 parent 0d9a3ab commit c902a86
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 44 deletions.
68 changes: 50 additions & 18 deletions tutorials/rnn/translate/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ def data_to_token_ids(data_path, target_path, vocabulary_path,
counter += 1
if counter % 100000 == 0:
print(" tokenizing line %d" % counter)
token_ids = sentence_to_token_ids(tf.compat.as_bytes(line), vocab,
tokenizer, normalize_digits)
token_ids = sentence_to_token_ids(line, vocab, tokenizer,
normalize_digits)
tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n")


Expand All @@ -267,24 +267,56 @@ def prepare_wmt_data(data_dir, en_vocabulary_size, fr_vocabulary_size, tokenizer
train_path = get_wmt_enfr_train_set(data_dir)
dev_path = get_wmt_enfr_dev_set(data_dir)

from_train_path = train_path + ".en"
to_train_path = train_path + ".fr"
from_dev_path = dev_path + ".en"
to_dev_path = dev_path + ".fr"
return prepare_data(data_dir, from_train_path, to_train_path, from_dev_path, to_dev_path, en_vocabulary_size,
fr_vocabulary_size, tokenizer)


def prepare_data(data_dir, from_train_path, to_train_path, from_dev_path, to_dev_path, from_vocabulary_size,
to_vocabulary_size, tokenizer=None):
"""Preapre all necessary files that are required for the training.
Args:
data_dir: directory in which the data sets will be stored.
from_train_path: path to the file that includes "from" training samples.
to_train_path: path to the file that includes "to" training samples.
from_dev_path: path to the file that includes "from" dev samples.
to_dev_path: path to the file that includes "to" dev samples.
from_vocabulary_size: size of the "from language" vocabulary to create and use.
to_vocabulary_size: size of the "to language" vocabulary to create and use.
tokenizer: a function to use to tokenize each data sentence;
if None, basic_tokenizer will be used.
Returns:
A tuple of 6 elements:
(1) path to the token-ids for "from language" training data-set,
(2) path to the token-ids for "to language" training data-set,
(3) path to the token-ids for "from language" development data-set,
(4) path to the token-ids for "to language" development data-set,
(5) path to the "from language" vocabulary file,
(6) path to the "to language" vocabulary file.
"""
# Create vocabularies of the appropriate sizes.
fr_vocab_path = os.path.join(data_dir, "vocab%d.fr" % fr_vocabulary_size)
en_vocab_path = os.path.join(data_dir, "vocab%d.en" % en_vocabulary_size)
create_vocabulary(fr_vocab_path, train_path + ".fr", fr_vocabulary_size, tokenizer)
create_vocabulary(en_vocab_path, train_path + ".en", en_vocabulary_size, tokenizer)
to_vocab_path = os.path.join(data_dir, "vocab%d" % to_vocabulary_size)
from_vocab_path = os.path.join(data_dir, "vocab%d" % from_vocabulary_size)
create_vocabulary(to_vocab_path, to_train_path , to_vocabulary_size, tokenizer)
create_vocabulary(from_vocab_path, from_train_path , from_vocabulary_size, tokenizer)

# Create token ids for the training data.
fr_train_ids_path = train_path + (".ids%d.fr" % fr_vocabulary_size)
en_train_ids_path = train_path + (".ids%d.en" % en_vocabulary_size)
data_to_token_ids(train_path + ".fr", fr_train_ids_path, fr_vocab_path, tokenizer)
data_to_token_ids(train_path + ".en", en_train_ids_path, en_vocab_path, tokenizer)
to_train_ids_path = to_train_path + (".ids%d" % to_vocabulary_size)
from_train_ids_path = from_train_path + (".ids%d" % from_vocabulary_size)
data_to_token_ids(to_train_path, to_train_ids_path, to_vocab_path, tokenizer)
data_to_token_ids(from_train_path, from_train_ids_path, from_vocab_path, tokenizer)

# Create token ids for the development data.
fr_dev_ids_path = dev_path + (".ids%d.fr" % fr_vocabulary_size)
en_dev_ids_path = dev_path + (".ids%d.en" % en_vocabulary_size)
data_to_token_ids(dev_path + ".fr", fr_dev_ids_path, fr_vocab_path, tokenizer)
data_to_token_ids(dev_path + ".en", en_dev_ids_path, en_vocab_path, tokenizer)

return (en_train_ids_path, fr_train_ids_path,
en_dev_ids_path, fr_dev_ids_path,
en_vocab_path, fr_vocab_path)
to_dev_ids_path = to_dev_path + (".ids%d" % to_vocabulary_size)
from_dev_ids_path = from_dev_path + (".ids%d" % from_vocabulary_size)
data_to_token_ids(to_dev_path, to_dev_ids_path, to_vocab_path, tokenizer)
data_to_token_ids(from_dev_path, from_dev_ids_path, from_vocab_path, tokenizer)

return (from_train_ids_path, to_train_ids_path,
from_dev_ids_path, to_dev_ids_path,
from_vocab_path, to_vocab_path)
21 changes: 8 additions & 13 deletions tutorials/rnn/translate/seq2seq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,27 +108,22 @@ def sampled_loss(labels, inputs):
local_b = tf.cast(b, tf.float32)
local_inputs = tf.cast(inputs, tf.float32)
return tf.cast(
tf.nn.sampled_softmax_loss(
weights=local_w_t,
biases=local_b,
labels=labels,
inputs=local_inputs,
num_sampled=num_samples,
num_classes=self.target_vocab_size),
tf.nn.sampled_softmax_loss(local_w_t, local_b, local_inputs, labels,
num_samples, self.target_vocab_size),
dtype)
softmax_loss_function = sampled_loss

# Create the internal multi-layer cell for our RNN.
single_cell = tf.contrib.rnn.GRUCell(size)
single_cell = tf.nn.rnn_cell.GRUCell(size)
if use_lstm:
single_cell = tf.contrib.rnn.BasicLSTMCell(size)
single_cell = tf.nn.rnn_cell.BasicLSTMCell(size)
cell = single_cell
if num_layers > 1:
cell = tf.contrib.rnn.MultiRNNCell([single_cell] * num_layers)
cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * num_layers)

# The seq2seq function: we use embedding for the input and attention.
def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
return tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
return tf.nn.seq2seq.embedding_attention_seq2seq(
encoder_inputs,
decoder_inputs,
cell,
Expand Down Expand Up @@ -158,7 +153,7 @@ def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):

# Training outputs and losses.
if forward_only:
self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
self.outputs, self.losses = tf.nn.seq2seq.model_with_buckets(
self.encoder_inputs, self.decoder_inputs, targets,
self.target_weights, buckets, lambda x, y: seq2seq_f(x, y, True),
softmax_loss_function=softmax_loss_function)
Expand All @@ -170,7 +165,7 @@ def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
for output in self.outputs[b]
]
else:
self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
self.outputs, self.losses = tf.nn.seq2seq.model_with_buckets(
self.encoder_inputs, self.decoder_inputs, targets,
self.target_weights, buckets,
lambda x, y: seq2seq_f(x, y, False),
Expand Down
51 changes: 38 additions & 13 deletions tutorials/rnn/translate/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,14 @@
"Batch size to use during training.")
tf.app.flags.DEFINE_integer("size", 1024, "Size of each model layer.")
tf.app.flags.DEFINE_integer("num_layers", 3, "Number of layers in the model.")
tf.app.flags.DEFINE_integer("en_vocab_size", 40000, "English vocabulary size.")
tf.app.flags.DEFINE_integer("fr_vocab_size", 40000, "French vocabulary size.")
tf.app.flags.DEFINE_integer("from_vocab_size", 40000, "English vocabulary size.")
tf.app.flags.DEFINE_integer("to_vocab_size", 40000, "French vocabulary size.")
tf.app.flags.DEFINE_string("data_dir", "/tmp", "Data directory")
tf.app.flags.DEFINE_string("train_dir", "/tmp", "Training directory.")
tf.app.flags.DEFINE_string("from_train_data", None, "Training data.")
tf.app.flags.DEFINE_string("to_train_data", None, "Training data.")
tf.app.flags.DEFINE_string("from_dev_data", None, "Training data.")
tf.app.flags.DEFINE_string("to_dev_data", None, "Training data.")
tf.app.flags.DEFINE_integer("max_train_data_size", 0,
"Limit on the size of training data (0: no limit).")
tf.app.flags.DEFINE_integer("steps_per_checkpoint", 200,
Expand Down Expand Up @@ -119,8 +123,8 @@ def create_model(session, forward_only):
"""Create translation model and initialize or load parameters in session."""
dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
model = seq2seq_model.Seq2SeqModel(
FLAGS.en_vocab_size,
FLAGS.fr_vocab_size,
FLAGS.from_vocab_size,
FLAGS.to_vocab_size,
_buckets,
FLAGS.size,
FLAGS.num_layers,
Expand All @@ -142,10 +146,31 @@ def create_model(session, forward_only):

def train():
"""Train a en->fr translation model using WMT data."""
# Prepare WMT data.
print("Preparing WMT data in %s" % FLAGS.data_dir)
en_train, fr_train, en_dev, fr_dev, _, _ = data_utils.prepare_wmt_data(
FLAGS.data_dir, FLAGS.en_vocab_size, FLAGS.fr_vocab_size)
from_train = None
to_train = None
from_dev = None
to_dev = None
if FLAGS.from_train_data and FLAGS.to_train_data:
from_train_data = FLAGS.from_train_data
to_train_data = FLAGS.to_train_data
from_dev_data = from_train_data
to_dev_data = to_train_data
if FLAGS.from_dev_data and FLAGS.to_dev_data:
from_dev_data = FLAGS.from_dev_data
to_dev_data = FLAGS.to_dev_data
from_train, to_train, from_dev, to_dev, _, _ = data_utils.prepare_data(
FLAGS.data_dir,
from_train_data,
to_train_data,
from_dev_data,
to_dev_data,
FLAGS.from_vocab_size,
FLAGS.to_vocab_size)
else:
# Prepare WMT data.
print("Preparing WMT data in %s" % FLAGS.data_dir)
from_train, to_train, from_dev, to_dev, _, _ = data_utils.prepare_wmt_data(
FLAGS.data_dir, FLAGS.from_vocab_size, FLAGS.to_vocab_size)

with tf.Session() as sess:
# Create model.
Expand All @@ -155,8 +180,8 @@ def train():
# Read data into buckets and compute their sizes.
print ("Reading development and training data (limit: %d)."
% FLAGS.max_train_data_size)
dev_set = read_data(en_dev, fr_dev)
train_set = read_data(en_train, fr_train, FLAGS.max_train_data_size)
dev_set = read_data(from_dev, to_dev)
train_set = read_data(from_train, to_train, FLAGS.max_train_data_size)
train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))]
train_total_size = float(sum(train_bucket_sizes))

Expand Down Expand Up @@ -225,9 +250,9 @@ def decode():

# Load vocabularies.
en_vocab_path = os.path.join(FLAGS.data_dir,
"vocab%d.en" % FLAGS.en_vocab_size)
"vocab%d.from" % FLAGS.from_vocab_size)
fr_vocab_path = os.path.join(FLAGS.data_dir,
"vocab%d.fr" % FLAGS.fr_vocab_size)
"vocab%d.to" % FLAGS.to_vocab_size)
en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
_, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)

Expand All @@ -245,7 +270,7 @@ def decode():
bucket_id = i
break
else:
logging.warning("Sentence truncated: %s", sentence)
logging.warning("Sentence truncated: %s", sentence)

# Get a 1-element batch to feed the sentence to the model.
encoder_inputs, decoder_inputs, target_weights = model.get_batch(
Expand Down

0 comments on commit c902a86

Please sign in to comment.