From 8c576bb7dfe33f92ea4901349c3471550ec70636 Mon Sep 17 00:00:00 2001
From: Mddct
Date: Thu, 22 Feb 2024 23:27:42 +0800
Subject: [PATCH] gated mlp works

---
 wenet/transformer/decoder.py | 15 ++++++---
 wenet/transformer/encoder.py | 59 ++++++++++++++++++------------------
 wenet/utils/class_utils.py   |  8 +++++
 3 files changed, 48 insertions(+), 34 deletions(-)

diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py
index ec467ee434..263023508e 100644
--- a/wenet/transformer/decoder.py
+++ b/wenet/transformer/decoder.py
@@ -20,11 +20,11 @@
 import logging
 
 from wenet.transformer.decoder_layer import DecoderLayer
-from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
 from wenet.utils.class_utils import (
     WENET_EMB_CLASSES,
     WENET_ATTENTION_CLASSES,
     WENET_ACTIVATION_CLASSES,
+    WENET_MLP_CLASSES,
 )
 from wenet.utils.common import mask_to_bias
 from wenet.utils.mask import (subsequent_mask, make_pad_mask)
@@ -75,6 +75,7 @@ def __init__(
         gradient_checkpointing: bool = False,
         tie_word_embedding: bool = False,
         use_sdpa: bool = False,
+        mlp_type: str = 'position_wise_feed_forward',
     ):
         super().__init__()
         attention_dim = encoder_output_size
@@ -95,6 +96,7 @@ def __init__(
         else:
             self.output_layer = torch.nn.Identity()
         self.num_blocks = num_blocks
+        mlp_class = WENET_MLP_CLASSES[mlp_type]
         self.decoders = torch.nn.ModuleList([
             DecoderLayer(
                 attention_dim,
@@ -104,8 +106,8 @@ def __init__(
                 WENET_ATTENTION_CLASSES["selfattn"](
                     attention_heads, attention_dim, src_attention_dropout_rate,
                     key_bias, use_sdpa) if src_attention else None,
-                PositionwiseFeedForward(attention_dim, linear_units,
-                                        dropout_rate, activation),
+                mlp_class(attention_dim, linear_units, dropout_rate,
+                          activation),
                 dropout_rate,
                 normalize_before,
             ) for _ in range(self.num_blocks)
@@ -298,6 +300,7 @@ def __init__(
         gradient_checkpointing: bool = False,
         tie_word_embedding: bool = False,
         use_sdpa: bool = False,
+        mlp_type: str = 'position_wise_feed_forward',
     ):
         super().__init__()
 
@@ -318,7 +321,8 @@ def __init__(
             key_bias=key_bias,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
-            use_sdpa=use_sdpa)
+            use_sdpa=use_sdpa,
+            mlp_type=mlp_type)
 
         self.right_decoder = TransformerDecoder(
             vocab_size,
@@ -336,7 +340,8 @@ def __init__(
             key_bias=key_bias,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
-            use_sdpa=use_sdpa)
+            use_sdpa=use_sdpa,
+            mlp_type=mlp_type)
 
     def forward(
         self,
diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py
index acf6546b65..b3b624f00d 100644
--- a/wenet/transformer/encoder.py
+++ b/wenet/transformer/encoder.py
@@ -22,9 +22,9 @@
 from wenet.transformer.convolution import ConvolutionModule
 from wenet.transformer.encoder_layer import TransformerEncoderLayer
 from wenet.transformer.encoder_layer import ConformerEncoderLayer
-from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
 from wenet.utils.class_utils import (
     WENET_EMB_CLASSES,
+    WENET_MLP_CLASSES,
     WENET_SUBSAMPLE_CLASSES,
     WENET_ATTENTION_CLASSES,
     WENET_ACTIVATION_CLASSES,
@@ -341,28 +341,27 @@ def forward_chunk_by_chunk(
 class TransformerEncoder(BaseEncoder):
     """Transformer encoder module."""
 
-    def __init__(
-        self,
-        input_size: int,
-        output_size: int = 256,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        num_blocks: int = 6,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        attention_dropout_rate: float = 0.0,
-        input_layer: str = "conv2d",
-        pos_enc_layer_type: str = "abs_pos",
-        normalize_before: bool = True,
-        static_chunk_size: int = 0,
-        use_dynamic_chunk: bool = False,
-        global_cmvn: torch.nn.Module = None,
-        use_dynamic_left_chunk: bool = False,
-        key_bias: bool = True,
-        activation_type: str = "relu",
-        gradient_checkpointing: bool = False,
-        use_sdpa: bool = False,
-    ):
+    def __init__(self,
+                 input_size: int,
+                 output_size: int = 256,
+                 attention_heads: int = 4,
+                 linear_units: int = 2048,
+                 num_blocks: int = 6,
+                 dropout_rate: float = 0.1,
+                 positional_dropout_rate: float = 0.1,
+                 attention_dropout_rate: float = 0.0,
+                 input_layer: str = "conv2d",
+                 pos_enc_layer_type: str = "abs_pos",
+                 normalize_before: bool = True,
+                 static_chunk_size: int = 0,
+                 use_dynamic_chunk: bool = False,
+                 global_cmvn: torch.nn.Module = None,
+                 use_dynamic_left_chunk: bool = False,
+                 key_bias: bool = True,
+                 activation_type: str = "relu",
+                 gradient_checkpointing: bool = False,
+                 use_sdpa: bool = False,
+                 mlp_type: str = 'position_wise_feed_forward'):
         """ Construct TransformerEncoder
 
         See Encoder for the meaning of each parameter.
@@ -375,6 +374,7 @@ def __init__(
                          use_dynamic_left_chunk, gradient_checkpointing,
                          use_sdpa)
         activation = WENET_ACTIVATION_CLASSES[activation_type]()
+        mlp_class = WENET_MLP_CLASSES[mlp_type]
         self.encoders = torch.nn.ModuleList([
             TransformerEncoderLayer(
                 output_size,
@@ -382,9 +382,9 @@ def __init__(
                                                     output_size,
                                                     attention_dropout_rate,
                                                     key_bias, use_sdpa),
-                PositionwiseFeedForward(output_size, linear_units,
-                                        dropout_rate, activation),
-                dropout_rate, normalize_before) for _ in range(num_blocks)
+                mlp_class(output_size, linear_units, dropout_rate,
+                          activation), dropout_rate, normalize_before)
+            for _ in range(num_blocks)
         ])
 
 
@@ -419,6 +419,7 @@ def __init__(
         key_bias: bool = True,
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
+        mlp_type: str = 'position_wise_feed_forward',
     ):
         """Construct ConformerEncoder
 
@@ -465,14 +466,14 @@ def __init__(
         convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                   cnn_module_norm, causal)
 
+        mlp_class = WENET_MLP_CLASSES[mlp_type]
         self.encoders = torch.nn.ModuleList([
             ConformerEncoderLayer(
                 output_size,
                 WENET_ATTENTION_CLASSES[selfattention_layer_type](
                     *encoder_selfattn_layer_args),
-                PositionwiseFeedForward(*positionwise_layer_args),
-                PositionwiseFeedForward(
-                    *positionwise_layer_args) if macaron_style else None,
+                mlp_class(*positionwise_layer_args),
+                mlp_class(*positionwise_layer_args) if macaron_style else None,
                 ConvolutionModule(
                     *convolution_layer_args) if use_cnn_module else None,
                 dropout_rate,
diff --git a/wenet/utils/class_utils.py b/wenet/utils/class_utils.py
index a136359b5e..5c0800f72b 100644
--- a/wenet/utils/class_utils.py
+++ b/wenet/utils/class_utils.py
@@ -3,6 +3,8 @@
 # Copyright [2023-11-28]
 import torch
 
 from wenet.paraformer.embedding import ParaformerPositinoalEncoding
+from wenet.transformer.positionwise_feed_forward import (
+    GatedVariantsMLP, MoEFFNLayer, PositionwiseFeedForward)
 from wenet.transformer.swish import Swish
 from wenet.transformer.subsampling import (
@@ -66,3 +68,9 @@
     "rel_selfattn": RelPositionMultiHeadedAttention,
     "grouped_rel_selfattn": GroupedRelPositionMultiHeadedAttention,
 }
+
+WENET_MLP_CLASSES = {
+    'position_wise_feed_forward': PositionwiseFeedForward,
+    'moe': MoEFFNLayer,
+    'gated': GatedVariantsMLP
+}
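
Note on usage (not part of the patch): the encoders and decoders now pick the
feed-forward variant by name via WENET_MLP_CLASSES[mlp_type] and construct it
with the same positional arguments previously passed to
PositionwiseFeedForward, i.e. (idim, hidden_units, dropout_rate, activation).
The shipped GatedVariantsMLP and MoEFFNLayer live in
wenet/transformer/positionwise_feed_forward.py and are not shown in this
patch; the sketch below only illustrates what a "gated" (SwiGLU-style) MLP of
that shape typically computes, and the class and variable names in it are
invented for the example.

    import torch


    class GatedMLPSketch(torch.nn.Module):
        """Illustrative only: y = W_down(dropout(act(W_gate(x)) * W_up(x)))."""

        def __init__(self, idim: int, hidden_units: int, dropout_rate: float,
                     activation: torch.nn.Module):
            super().__init__()
            self.gate = torch.nn.Linear(idim, hidden_units)  # gating branch
            self.up = torch.nn.Linear(idim, hidden_units)    # value branch
            self.down = torch.nn.Linear(hidden_units, idim)  # back to model dim
            self.act = activation
            self.dropout = torch.nn.Dropout(dropout_rate)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Element-wise gate on the hidden units, then project back to idim.
            return self.down(self.dropout(self.act(self.gate(x)) * self.up(x)))


    if __name__ == "__main__":
        # Mirror the registry lookup the patch uses: WENET_MLP_CLASSES[mlp_type]
        mlp_classes = {"gated": GatedMLPSketch}
        mlp = mlp_classes["gated"](256, 2048, 0.1, torch.nn.SiLU())
        print(mlp(torch.randn(2, 10, 256)).shape)  # torch.Size([2, 10, 256])

With the patch applied, the variant should be selectable either directly, e.g.
TransformerEncoder(input_size=80, mlp_type='gated'), or from a training config
as mlp_type: 'gated' under encoder_conf/decoder_conf, assuming those dicts are
still forwarded to the constructors as keyword arguments as elsewhere in WeNet.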