diff --git a/wenet/transformer/convolution.py b/wenet/transformer/convolution.py index 071f25aac..ad5f7f15e 100644 --- a/wenet/transformer/convolution.py +++ b/wenet/transformer/convolution.py @@ -19,6 +19,8 @@ import torch from torch import nn +from wenet.utils.class_utils import WENET_NORM_CLASSES + class ConvolutionModule(nn.Module): """ConvolutionModule in Conformer model.""" @@ -68,13 +70,13 @@ def __init__(self, bias=bias, ) - assert norm in ['batch_norm', 'layer_norm'] + assert norm in ['batch_norm', 'layer_norm', 'rms_norm'] if norm == "batch_norm": self.use_layer_norm = False - self.norm = nn.BatchNorm1d(channels) + self.norm = WENET_NORM_CLASSES['batch_norm'](channels) else: self.use_layer_norm = True - self.norm = nn.LayerNorm(channels) + self.norm = WENET_NORM_CLASSES[norm](channels) self.pointwise_conv2 = nn.Conv1d( channels, diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py index ff4d932f3..c315eb39f 100644 --- a/wenet/transformer/decoder.py +++ b/wenet/transformer/decoder.py @@ -25,6 +25,7 @@ WENET_ATTENTION_CLASSES, WENET_ACTIVATION_CLASSES, WENET_MLP_CLASSES, + WENET_NORM_CLASSES, ) from wenet.utils.common import mask_to_bias from wenet.utils.mask import (subsequent_mask, make_pad_mask) @@ -81,6 +82,7 @@ def __init__( tie_word_embedding: bool = False, use_sdpa: bool = False, mlp_type: str = 'position_wise_feed_forward', + layer_norm_type: str = 'layer_norm', ): super().__init__() attention_dim = encoder_output_size @@ -93,8 +95,10 @@ def __init__( positional_dropout_rate), ) + assert layer_norm_type in ['layer_norm', 'rms_norm'] self.normalize_before = normalize_before - self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5) + self.after_norm = WENET_NORM_CLASSES[layer_norm_type](attention_dim, + eps=1e-5) self.use_output_layer = use_output_layer if use_output_layer: self.output_layer = torch.nn.Linear(attention_dim, vocab_size) diff --git a/wenet/transformer/decoder_layer.py b/wenet/transformer/decoder_layer.py index 4b9270fbb..d28da1dc2 100644 --- a/wenet/transformer/decoder_layer.py +++ b/wenet/transformer/decoder_layer.py @@ -18,6 +18,8 @@ import torch from torch import nn +from wenet.utils.class_utils import WENET_NORM_CLASSES + class DecoderLayer(nn.Module): """Single decoder layer module. @@ -46,6 +48,7 @@ def __init__( feed_forward: nn.Module, dropout_rate: float, normalize_before: bool = True, + layer_norm_type: str = 'layer_norm', ): """Construct an DecoderLayer object.""" super().__init__() @@ -53,9 +56,10 @@ def __init__( self.self_attn = self_attn self.src_attn = src_attn self.feed_forward = feed_forward - self.norm1 = nn.LayerNorm(size, eps=1e-5) - self.norm2 = nn.LayerNorm(size, eps=1e-5) - self.norm3 = nn.LayerNorm(size, eps=1e-5) + assert layer_norm_type in ['layer_norm', 'rms_norm'] + self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5) + self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5) + self.norm3 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5) self.dropout = nn.Dropout(dropout_rate) self.normalize_before = normalize_before diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py index 705609b4a..4d1dddc2a 100644 --- a/wenet/transformer/encoder.py +++ b/wenet/transformer/encoder.py @@ -25,6 +25,7 @@ from wenet.utils.class_utils import ( WENET_EMB_CLASSES, WENET_MLP_CLASSES, + WENET_NORM_CLASSES, WENET_SUBSAMPLE_CLASSES, WENET_ATTENTION_CLASSES, WENET_ACTIVATION_CLASSES, @@ -55,6 +56,7 @@ def __init__( use_dynamic_left_chunk: bool = False, gradient_checkpointing: bool = False, use_sdpa: bool = False, + layer_norm_type: str = 'layer_norm', ): """ Args: @@ -102,8 +104,10 @@ def __init__( positional_dropout_rate), ) + assert layer_norm_type in ['layer_norm', 'rms_norm'] self.normalize_before = normalize_before - self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5) + self.after_norm = WENET_NORM_CLASSES[layer_norm_type](output_size, + eps=1e-5) self.static_chunk_size = static_chunk_size self.use_dynamic_chunk = use_dynamic_chunk self.use_dynamic_left_chunk = use_dynamic_left_chunk @@ -368,6 +372,7 @@ def __init__( gradient_checkpointing: bool = False, use_sdpa: bool = False, mlp_type: str = 'position_wise_feed_forward', + layer_norm_type: str = 'layer_norm', ): """ Construct TransformerEncoder @@ -379,19 +384,21 @@ def __init__( input_layer, pos_enc_layer_type, normalize_before, static_chunk_size, use_dynamic_chunk, global_cmvn, use_dynamic_left_chunk, gradient_checkpointing, - use_sdpa) + use_sdpa, layer_norm_type) activation = WENET_ACTIVATION_CLASSES[activation_type]() mlp_class = WENET_MLP_CLASSES[mlp_type] self.encoders = torch.nn.ModuleList([ - TransformerEncoderLayer( - output_size, - WENET_ATTENTION_CLASSES["selfattn"](attention_heads, - output_size, - attention_dropout_rate, - query_bias, key_bias, - value_bias, use_sdpa), - mlp_class(output_size, linear_units, dropout_rate, activation, - mlp_bias), dropout_rate, normalize_before) + TransformerEncoderLayer(output_size, + WENET_ATTENTION_CLASSES["selfattn"]( + attention_heads, output_size, + attention_dropout_rate, query_bias, + key_bias, value_bias, use_sdpa), + mlp_class(output_size, linear_units, + dropout_rate, activation, + mlp_bias), + dropout_rate, + normalize_before, + layer_norm_type=layer_norm_type) for _ in range(num_blocks) ]) diff --git a/wenet/transformer/encoder_layer.py b/wenet/transformer/encoder_layer.py index aafcec412..31cf59291 100644 --- a/wenet/transformer/encoder_layer.py +++ b/wenet/transformer/encoder_layer.py @@ -20,6 +20,8 @@ import torch from torch import nn +from wenet.utils.class_utils import WENET_NORM_CLASSES + class TransformerEncoderLayer(nn.Module): """Encoder layer module. @@ -44,13 +46,15 @@ def __init__( feed_forward: torch.nn.Module, dropout_rate: float, normalize_before: bool = True, + layer_norm_type: str = 'layer_norm', ): """Construct an EncoderLayer object.""" super().__init__() self.self_attn = self_attn self.feed_forward = feed_forward - self.norm1 = nn.LayerNorm(size, eps=1e-5) - self.norm2 = nn.LayerNorm(size, eps=1e-5) + assert layer_norm_type in ['layer_norm', 'rms_norm'] + self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5) + self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5) self.dropout = nn.Dropout(dropout_rate) self.size = size self.normalize_before = normalize_before @@ -135,23 +139,29 @@ def __init__( conv_module: Optional[nn.Module] = None, dropout_rate: float = 0.1, normalize_before: bool = True, + layer_norm_type: str = 'layer_norm', ): """Construct an EncoderLayer object.""" super().__init__() self.self_attn = self_attn self.feed_forward = feed_forward + assert layer_norm_type in ['layer_norm', 'rms_norm'] self.feed_forward_macaron = feed_forward_macaron self.conv_module = conv_module - self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module - self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module + self.norm_ff = WENET_NORM_CLASSES[layer_norm_type]( + size, eps=1e-5) # for the FNN module + self.norm_mha = WENET_NORM_CLASSES[layer_norm_type]( + size, eps=1e-5) # for the MHA module if feed_forward_macaron is not None: - self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5) + self.norm_ff_macaron = WENET_NORM_CLASSES[layer_norm_type]( + size, eps=1e-5) self.ff_scale = 0.5 else: self.ff_scale = 1.0 if self.conv_module is not None: - self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module - self.norm_final = nn.LayerNorm( + self.norm_conv = WENET_NORM_CLASSES[layer_norm_type]( + size, eps=1e-5) # for the CNN module + self.norm_final = WENET_NORM_CLASSES[layer_norm_type]( size, eps=1e-5) # for the final output of the block self.dropout = nn.Dropout(dropout_rate) self.size = size diff --git a/wenet/transformer/norm.py b/wenet/transformer/norm.py new file mode 100644 index 000000000..2c3756f13 --- /dev/null +++ b/wenet/transformer/norm.py @@ -0,0 +1,22 @@ +import torch + + +class RMSNorm(torch.nn.Module): + """ https://arxiv.org/pdf/1910.07467.pdf + """ + + def __init__( + self, + dim: int, + eps: float = 1e-6, + ): + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + x = self._norm(x.float()).type_as(x) + return x * self.weight diff --git a/wenet/utils/class_utils.py b/wenet/utils/class_utils.py index 314e1b826..b25e58fad 100644 --- a/wenet/utils/class_utils.py +++ b/wenet/utils/class_utils.py @@ -2,7 +2,9 @@ # -*- coding: utf-8 -*- # Copyright [2023-11-28] import torch +from torch.nn import BatchNorm1d, LayerNorm from wenet.paraformer.embedding import ParaformerPositinoalEncoding +from wenet.transformer.norm import RMSNorm from wenet.transformer.positionwise_feed_forward import ( GatedVariantsMLP, MoEFFNLayer, PositionwiseFeedForward) @@ -77,3 +79,9 @@ 'moe': MoEFFNLayer, 'gated': GatedVariantsMLP } + +WENET_NORM_CLASSES = { + 'layer_norm': LayerNorm, + 'batch_norm': BatchNorm1d, + 'rms_norm': RMSNorm +}