Commit

revert temp changes
andrejonasson committed Feb 8, 2018
1 parent c907459 · commit 023c5ea
Showing 2 changed files with 3 additions and 4 deletions.
question_answering/networks/dcn_model.py (4 changes: 2 additions & 2 deletions)
@@ -51,7 +51,7 @@ def __init__(self, pretrained_embeddings, hparams):

         # Setup RNN Cells
         cell = lambda: cell_factory(hparams['cell'], hparams['state_size'], self.is_training, hparams['input_keep_prob'], hparams['output_keep_prob'], hparams['state_keep_prob'])
-        final_cell = lambda: cell_factory(hparams['cell'], hparams['state_size'], self.is_training, 1.0, hparams['output_keep_prob'], hparams['state_keep_prob']) # TODO TEMP hparams['final_input_keep_prob']
+        final_cell = lambda: cell_factory(hparams['cell'], hparams['state_size'], self.is_training, hparams['final_input_keep_prob'], hparams['output_keep_prob'], hparams['state_keep_prob']) # TODO TEMP
 
         # Setup Encoders
         with tf.variable_scope('prediction'):
@@ -61,7 +61,7 @@ def __init__(self, pretrained_embeddings, hparams):
                 self.encode = dcn_encode
             else:
                 self.encode = dcnplus_encode
-            encoding = self.encode(cell, final_cell, q_embeddings, self.question_length, p_embeddings, self.paragraph_length, keep_prob=maybe_dropout(hparams['keep_prob'], self.is_training), final_input_keep_prob=maybe_dropout(hparams['final_input_keep_prob'], self.is_training))
+            encoding = self.encode(cell, final_cell, q_embeddings, self.question_length, p_embeddings, self.paragraph_length, keep_prob=maybe_dropout(hparams['keep_prob'], self.is_training))
             encoding = tf.nn.dropout(encoding, keep_prob=maybe_dropout(hparams['encoding_keep_prob'], self.is_training))
 
             # Decoder, loss and prediction mechanism are different for baseline/mixed and dcn/dcn_plus
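The diff above calls two project helpers, cell_factory and maybe_dropout, whose definitions are not part of this commit. As a rough sketch of what they plausibly do (the real implementations may differ), the cell factory presumably builds an RNN cell and wraps it with TensorFlow 1.x dropout, while maybe_dropout disables dropout outside of training; under that assumption, restoring hparams['final_input_keep_prob'] here is the in-cell counterpart of the explicit tf.nn.dropout that the second file removes:

import tensorflow as tf  # TensorFlow 1.x API

def maybe_dropout(keep_prob, is_training):
    # Hypothetical sketch: keep_prob while training, 1.0 (no dropout) at inference.
    return tf.cond(is_training,
                   lambda: tf.constant(keep_prob, dtype=tf.float32),
                   lambda: tf.constant(1.0, dtype=tf.float32))

def cell_factory(cell_type, state_size, is_training,
                 input_keep_prob, output_keep_prob, state_keep_prob):
    # Hypothetical sketch: build a recurrent cell and apply dropout inside the
    # wrapper, so a final_cell built with hparams['final_input_keep_prob']
    # handles its own input dropout without a separate tf.nn.dropout call.
    cell = (tf.nn.rnn_cell.GRUCell(state_size) if cell_type == 'gru'
            else tf.nn.rnn_cell.LSTMCell(state_size))
    return tf.nn.rnn_cell.DropoutWrapper(
        cell,
        input_keep_prob=maybe_dropout(input_keep_prob, is_training),
        output_keep_prob=maybe_dropout(output_keep_prob, is_training),
        state_keep_prob=maybe_dropout(state_keep_prob, is_training))

# Usage mirroring the restored line (values hypothetical):
# final_cell = lambda: cell_factory('lstm', 200, is_training, 0.8, 0.7, 1.0)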
question_answering/networks/dcn_plus.py (3 changes: 1 addition & 2 deletions)
@@ -89,7 +89,7 @@ def baseline_encode(cell_factory, final_cell_factory, query, query_length, docum
     return encoding # N x P x 2H


-def dcn_encode(cell_factory, final_cell_factory, query, query_length, document, document_length, keep_prob=1.0, final_input_keep_prob=1.0):
+def dcn_encode(cell_factory, final_cell_factory, query, query_length, document, document_length, keep_prob=1.0):
     """ DCN Encoder that encodes questions and paragraphs into one representation.
     It first encodes the question and paragraphs using a shared LSTM, then uses a
@@ -136,7 +136,6 @@ def dcn_encode(cell_factory, final_cell_factory, query, query_length, document,
 
     with tf.variable_scope('final_encoder'):
         document_representation = tf.concat(document_representations, 2)
-        document_representation = tf.nn.dropout(document_representation, final_input_keep_prob) # test if wanted
         final = final_cell_factory()
         outputs, _ = tf.nn.bidirectional_dynamic_rnn(
             cell_fw = final,
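After this deletion, the only dropout applied on the way into the final bidirectional encoder is whatever final_cell applies internally. A sketch of the resulting final_encoder block follows; the bidirectional_dynamic_rnn arguments beyond cell_fw are not shown in this diff, so the cell_bw, inputs, sequence_length, and dtype lines are assumptions about the surrounding code:

with tf.variable_scope('final_encoder'):
    document_representation = tf.concat(document_representations, 2)  # explicit dropout removed by this commit
    final = final_cell_factory()  # dropout now lives inside the cell wrapper (see sketch above)
    outputs, _ = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=final,
        cell_bw=final_cell_factory(),     # assumption: separate cell for the backward direction
        inputs=document_representation,   # assumption
        sequence_length=document_length,  # assumption
        dtype=tf.float32)                 # assumption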
