Skip to content

Commit

Permalink
added functionality to set max length of predicted answer span and to force the predicted end to be greater than predicted start of span
Browse files Browse the repository at this point in the history
  • Loading branch information
andrejonasson committed Feb 16, 2018
1 parent a1fd765 commit e99128b
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
5 changes: 5 additions & 0 deletions question_answering/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
tf.app.flags.DEFINE_integer("pool_size", 4, "Number of units the maxout network pools.")
tf.app.flags.DEFINE_integer("max_iter", 4, "Maximum number of iterations of decoder.")
tf.app.flags.DEFINE_float("keep_prob", 0.80, "Decoder: Fraction of units randomly kept on non-recurrent connections.")
# Fix: this flag holds a boolean default (True) and is consumed as a truth value,
# so it must be declared with DEFINE_boolean, not DEFINE_integer.
tf.app.flags.DEFINE_boolean("force_end_gt_start", True, "Forces the predicted answer end to be greater than the start.")
tf.app.flags.DEFINE_integer("max_answer_length", 20, "Maximum length of model's predicted answer span.")

# Character embeddings (NOTE: INPUT PROCESSING NOT IMPLEMENTED YET)
tf.app.flags.DEFINE_boolean("use_char_cnn", False, "Whether to use character embeddings to build word vectors.")
Expand Down Expand Up @@ -229,6 +231,9 @@ def multibatch_prediction_truth(session, model, data, num_batches=None, random=F
begin_idx = i * FLAGS.batch_size
q, p, ql, pl, a = data[begin_idx:begin_idx+FLAGS.batch_size]
answer_start, answer_end = session.run(model.answer, model.fill_feed_dict(q, p, ql, pl))
# for i, s in enumerate(answer_start):
# if s > answer_end[i]:
# print('predicted: ', (s, answer_end[i], pl[i]), 'truth: ', (a[i][0], a[i][1]))
start.append(answer_start)
end.append(answer_end)
truth.extend(a)
Expand Down
10 changes: 8 additions & 2 deletions question_answering/networks/dcn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import tensorflow as tf

from networks.modules import maybe_dropout, max_product_span, naive_decode, cell_factory, char_cnn_word_vectors
from networks.modules import maybe_dropout, max_product_span, naive_decode, cell_factory, char_cnn_word_vectors, _maybe_mask_to_start
from networks.dcn_plus import baseline_encode, dcn_encode, dcnplus_encode, dcn_decode, dcn_loss
from tensorflow.contrib.seq2seq.python.ops.attention_wrapper import _maybe_mask_score

class DCN:
""" Builds graph for DCN-type models
Expand Down Expand Up @@ -82,7 +83,12 @@ def __init__(self, pretrained_embeddings, hparams):
logits = dcn_decode(encoding, self.paragraph_length, hparams['state_size'], hparams['pool_size'], hparams['max_iter'], keep_prob=maybe_dropout(hparams['keep_prob'], self.is_training))
last_iter_logit = logits.read(hparams['max_iter']-1)
start_logit, end_logit = last_iter_logit[:,:,0], last_iter_logit[:,:,1]
self.answer = (tf.argmax(start_logit, axis=1, name='answer_start'), tf.argmax(end_logit, axis=1, name='answer_end'))
start = tf.argmax(start_logit, axis=1, name='answer_start')
if hparams['force_end_gt_start']:
end_logit = _maybe_mask_to_start(end_logit, start, -1e30)
if hparams['max_answer_length'] > 0:
end_logit = _maybe_mask_score(end_logit, start+hparams['max_answer_length'], -1e30)
self.answer = (start, tf.argmax(end_logit, axis=1, name='answer_end'))

with tf.variable_scope('loss'):
self.loss = dcn_loss(logits, self.answer_span, max_iter=hparams['max_iter'])
Expand Down
6 changes: 6 additions & 0 deletions question_answering/networks/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def maybe_mask_affinity(affinity, sequence_length, affinity_mask_value=float('-i
return tf.where(score_mask, affinity, affinity_mask_values)


def _maybe_mask_to_start(score, start, score_mask_value):
    """Replace entries of `score` at positions before `start` with `score_mask_value`.

    `score` is a [batch, length] tensor of logits and `start` a [batch] tensor of
    per-example cutoff indices; positions 0..start-1 are overwritten so that an
    argmax over the result can never land before `start`.
    """
    # True for the prefix [0, start) of each row — exactly the positions to mask.
    prefix_mask = tf.sequence_mask(start, maxlen=tf.shape(score)[1])
    fill = score_mask_value * tf.ones_like(score)
    # Select the fill value on masked prefix positions, original score elsewhere.
    return tf.where(prefix_mask, fill, score)


def maybe_dropout(keep_prob, is_training):
    """Return `keep_prob` when training, otherwise 1.0 (i.e. dropout disabled).

    `is_training` may be a Python bool or a boolean tensor; it is converted to a
    tensor so the branch is resolved inside the graph.
    """
    training_flag = tf.convert_to_tensor(is_training)
    keep = lambda: keep_prob
    no_dropout = lambda: 1.0
    return tf.cond(training_flag, keep, no_dropout)

Expand Down

0 comments on commit e99128b

Please sign in to comment.