From 5623deb79cfcd28f8f8c5463b58b5bd76a81fd0d Mon Sep 17 00:00:00 2001
From: Afroz Mohiuddin
Date: Fri, 5 Mar 2021 00:14:45 -0800
Subject: [PATCH] [Mesh-TF] Add is_training as an arg to mtf.dropout

PiperOrigin-RevId: 361088273
---
 tensor2tensor/models/mtf_image_transformer.py | 14 +++++++++++++-
 tensor2tensor/models/mtf_transformer.py       | 18 +++++++++++++-----
 2 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/tensor2tensor/models/mtf_image_transformer.py b/tensor2tensor/models/mtf_image_transformer.py
index aabce6a47..fdcee7a23 100644
--- a/tensor2tensor/models/mtf_image_transformer.py
+++ b/tensor2tensor/models/mtf_image_transformer.py
@@ -243,8 +243,10 @@ def import_to_batch_by_length(x, name):
 def layer_prepostprocess_dropout(x, hparams):
   batch_dim = x.shape.dims[0]
   model_dim = x.shape.dims[-1]
+  mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+  is_training = mode == tf.estimator.ModeKeys.TRAIN
   return mtf.dropout(
-      x,
+      x, is_training,
       keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
       noise_shape=mtf.Shape([batch_dim, model_dim]))
 
@@ -259,6 +261,8 @@ def local_attention1d_spatial_decoder(x, kv_dim, heads_dim,
   x = mtf.reshape(
       x, mtf.Shape([batch_dim, num_w_blocks_dim, blocks_w_dim, model_dim]))
   # [ self attention - ffn - residual + dropout] x n
+  mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+  is_training = mode == tf.estimator.ModeKeys.TRAIN
   for layer in range(hparams.num_decoder_layers):
     layer_name = "decoder_layer_%d" % layer
     with tf.variable_scope(layer_name):
@@ -268,6 +272,7 @@ def local_attention1d_spatial_decoder(x, kv_dim, heads_dim,
               mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
               kv_dim,
               heads_dim,
+              is_training,
               memory_w_dim=blocks_w_dim,
               mask_right=True,
               name="self_att"), hparams)
@@ -276,6 +281,7 @@ def local_attention1d_spatial_decoder(x, kv_dim, heads_dim,
           mtf.layers.dense_relu_dense(
               mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
               feedforward_dim,
+              is_training,
               hparams.dropout,
               dropout_broadcast_dims=[length_dim]), hparams)
 
@@ -305,6 +311,8 @@ def local_attention2d_spatial_decoder(x, kv_dim, heads_dim,
       batch_dim, num_h_blocks_dim, num_w_blocks_dim, blocks_h_dim,
       blocks_w_dim, model_dim
   ]))
+  mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+  is_training = mode == tf.estimator.ModeKeys.TRAIN
   # Image Transformer Decoder
   # [ self attention - ffn - residual + dropout] x n
   for layer in range(hparams.num_decoder_layers):
@@ -316,6 +324,7 @@ def local_attention2d_spatial_decoder(x, kv_dim, heads_dim,
               mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
               kv_dim,
               heads_dim,
+              is_training,
               memory_h_dim=num_h_blocks_dim,
               memory_w_dim=num_w_blocks_dim,
               name="self_att"), hparams)
@@ -336,6 +345,8 @@ def local_attention1d_masked_decoder(x, kv_dim, heads_dim,
   """Image Transformer decoder with local1D masked layers."""
   print(x)
   _, length_dim, model_dim = x.shape.dims
+  mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+  is_training = mode == tf.estimator.ModeKeys.TRAIN
   for layer in range(hparams.num_decoder_layers):
     layer_name = "decoder_layer_%d" % layer
     with tf.variable_scope(layer_name):
@@ -347,6 +358,7 @@ def local_attention1d_masked_decoder(x, kv_dim, heads_dim,
               mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
               kv_dim,
               heads_dim,
+              is_training,
               window_size=hparams.block_length,
               length_per_split=length_per_split,
               name="self_att"), hparams)
diff --git a/tensor2tensor/models/mtf_transformer.py b/tensor2tensor/models/mtf_transformer.py
index 044170ef9..5ac5e091a 100644
--- a/tensor2tensor/models/mtf_transformer.py
+++ b/tensor2tensor/models/mtf_transformer.py
@@ -242,6 +242,8 @@ def _mtf_model_fn(self, features, mesh):
     hparams = self._hparams
     extra_losses = []
     targets = tf.to_int32(features["targets"])
+    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+    is_training = mode == tf.estimator.ModeKeys.TRAIN
     if len(targets.get_shape()) > 2:
       tf.logging.info("targets = %s" % targets)
       targets = tf.squeeze(targets, [2, 3])
@@ -289,7 +291,7 @@ def pad_to_max_length(x):
 
     def layer_prepostprocess_dropout(x):
       return mtf.dropout(
-          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
+          x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
           noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))
 
     (inputs_embedding_var,
@@ -426,10 +428,11 @@ def _feedforward_layer(self, x, layer_type, losses=None):
       ValueError: if hparams make no sense
     """
     hparams = self._hparams
-
+    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+    is_training = mode == tf.estimator.ModeKeys.TRAIN
     if layer_type == "drd":
       return mtf.layers.dense_relu_dense(
-          x, self.feedforward_dim, dropout=hparams.relu_dropout,
+          x, self.feedforward_dim, is_training, dropout=hparams.relu_dropout,
           dropout_broadcast_dims=[self.length_dim],
           master_dtype=self.master_dtype,
           slice_dtype=self.slice_dtype)
@@ -493,11 +496,13 @@ def _layer_stack(self,
     """
     hparams = self._hparams
     is_incremental = (step_num is not None)
+    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+    is_training = mode == tf.estimator.ModeKeys.TRAIN
     def layer_prepostprocess_dropout(x):
       if is_incremental:
         return x
       return mtf.dropout(
-          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
+          x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
           noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))
     num_layers = len(layers)
     num_layer_norms = num_layers + 1
@@ -540,6 +545,7 @@ def normalize(x):
             mtf.layers.multihead_attention(
                 normalize(x), None, self_attention_mask, self.kv_dim,
                 self.heads_dim,
+                is_training,
                 dropout=hparams.attention_dropout,
                 dropout_broadcast_dims=[self.length_dim],
                 master_dtype=self.master_dtype,
@@ -560,6 +566,7 @@ def normalize(x):
             mtf.layers.multihead_attention(
                 normalize(x), encoder_output, encdec_attention_mask,
                 self.kv_dim, self.heads_dim,
+                is_training,
                 dropout=hparams.attention_dropout,
                 dropout_broadcast_dims=[self.length_dim],
                 master_dtype=self.master_dtype,
@@ -582,7 +589,7 @@ def normalize(x):
           x += layer_prepostprocess_dropout(
               mtf.layers.masked_local_attention_1d(
                   normalize(x),
-                  self.kv_dim, self.heads_dim,
+                  self.kv_dim, self.heads_dim, is_training,
                   window_size=hparams.local_attention_window_size,
                   master_dtype=self.master_dtype,
                   slice_dtype=self.slice_dtype,
@@ -601,6 +608,7 @@ def normalize(x):
                   compression_factor=hparams.compression_factor,
                   kv_channels=self.kv_dim,
                   heads=self.heads_dim,
+                  is_training=is_training,
                   dropout=hparams.attention_dropout,
                   dropout_broadcast_dims=[self.length_dim],
                   master_dtype=self.master_dtype,
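
Every call site above follows the same recipe: read hparams.mode (defaulting to
TRAIN when the attribute is absent) and thread the resulting is_training flag
into mtf.dropout and the mtf.layers helpers so dropout is skipped outside
training. A minimal sketch of that recipe, lifted from the patched
layer_prepostprocess_dropout helper in mtf_image_transformer.py, is below; it
assumes a Mesh TensorFlow release whose mtf.dropout accepts is_training as its
second positional argument, and "hparams" stands in for any object carrying
optional mode and layer_prepostprocess_dropout attributes.

    # Sketch of the pattern this patch applies; assumes a Mesh TensorFlow
    # version where mtf.dropout is dropout(x, is_training, keep_prob=..., ...).
    import mesh_tensorflow as mtf
    import tensorflow.compat.v1 as tf


    def layer_prepostprocess_dropout(x, hparams):
      """Dropout over the [batch, model] dims, active only in TRAIN mode."""
      batch_dim = x.shape.dims[0]
      model_dim = x.shape.dims[-1]
      # Default to TRAIN when hparams carries no mode, as the patch does.
      mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
      is_training = mode == tf.estimator.ModeKeys.TRAIN
      return mtf.dropout(
          x, is_training,
          keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape([batch_dim, model_dim]))

When is_training is False, mtf.dropout is expected to return x unchanged, so
eval and predict graphs no longer apply the dropout mask.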