diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py
index d644b98f22..22b0c6fcd7 100644
--- a/wenet/transformer/attention.py
+++ b/wenet/transformer/attention.py
@@ -38,7 +38,9 @@ def __init__(self,
                  n_head: int,
                  n_feat: int,
                  dropout_rate: float,
+                 query_bias: bool = True,
                  key_bias: bool = True,
+                 value_bias: bool = True,
                  use_sdpa: bool = False):
         """Construct an MultiHeadedAttention object."""
         super().__init__()
@@ -46,9 +48,9 @@ def __init__(self,
         # We assume d_v always equals d_k
         self.d_k = n_feat // n_head
         self.h = n_head
-        self.linear_q = nn.Linear(n_feat, n_feat)
+        self.linear_q = nn.Linear(n_feat, n_feat, bias=query_bias)
         self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
-        self.linear_v = nn.Linear(n_feat, n_feat)
+        self.linear_v = nn.Linear(n_feat, n_feat, bias=value_bias)
         self.linear_out = nn.Linear(n_feat, n_feat)
         self.dropout = nn.Dropout(p=dropout_rate)
@@ -239,10 +241,13 @@ def __init__(self,
                  n_head: int,
                  n_feat: int,
                  dropout_rate: float,
+                 query_bias: bool = True,
                  key_bias: bool = True,
+                 value_bias: bool = True,
                  use_sdpa: bool = False):
         """Construct an RelPositionMultiHeadedAttention object."""
-        super().__init__(n_head, n_feat, dropout_rate, key_bias, use_sdpa)
+        super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
+                         value_bias, use_sdpa)
         # linear transformation for positional encoding
         self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
         # these two learnable bias are used in matrix c and matrix d
@@ -387,9 +392,12 @@ def __init__(self,
                  n_head: int,
                  n_feat: int,
                  dropout_rate: float,
+                 query_bias: bool = True,
                  key_bias: bool = True,
+                 value_bias: bool = True,
                  use_sdpa: bool = False):
-        super().__init__(n_head, n_feat, dropout_rate, key_bias, use_sdpa)
+        super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
+                         value_bias, use_sdpa)
 
     def forward(
         self,
diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py
index 4d4542660b..2b3a5d61a6 100644
--- a/wenet/transformer/decoder.py
+++ b/wenet/transformer/decoder.py
@@ -48,7 +48,9 @@ class TransformerDecoder(torch.nn.Module):
             False: use layer_norm after each sub-block of a layer.
         src_attention: if false, encoder-decoder cross attention is not
             applied, such as CIF model
+        query_bias: whether use bias in attention.linear_q
         key_bias: whether use bias in attention.linear_k, False for whisper models.
+        value_bias: whether use bias in attention.linear_v
         gradient_checkpointing: rerunning a forward-pass segment for each
             checkpointed segment during backward.
         tie_word_embedding: Tie or clone module weights depending of whether we are
@@ -70,7 +72,10 @@ def __init__(
         use_output_layer: bool = True,
         normalize_before: bool = True,
         src_attention: bool = True,
+        query_bias: bool = True,
         key_bias: bool = True,
+        value_bias: bool = True,
+        mlp_bias: bool = True,
         activation_type: str = "relu",
         gradient_checkpointing: bool = False,
         tie_word_embedding: bool = False,
@@ -100,12 +105,14 @@ def __init__(
                 attention_dim,
                 WENET_ATTENTION_CLASSES["selfattn"](
                     attention_heads, attention_dim,
-                    self_attention_dropout_rate, key_bias, use_sdpa),
+                    self_attention_dropout_rate, query_bias, key_bias,
+                    value_bias, use_sdpa),
                 WENET_ATTENTION_CLASSES["crossattn"](
                     attention_heads, attention_dim, src_attention_dropout_rate,
-                    key_bias, use_sdpa) if src_attention else None,
+                    query_bias, key_bias, value_bias, use_sdpa)
+                if src_attention else None,
                 PositionwiseFeedForward(attention_dim, linear_units,
-                                        dropout_rate, activation),
+                                        dropout_rate, activation, mlp_bias),
                 dropout_rate,
                 normalize_before,
             ) for _ in range(self.num_blocks)
@@ -308,7 +315,10 @@ def __init__(
         input_layer: str = "embed",
         use_output_layer: bool = True,
         normalize_before: bool = True,
+        query_bias: bool = True,
         key_bias: bool = True,
+        value_bias: bool = True,
+        mlp_bias: bool = True,
         gradient_checkpointing: bool = False,
         tie_word_embedding: bool = False,
         use_sdpa: bool = False,
@@ -330,7 +340,9 @@ def __init__(
             input_layer,
             use_output_layer,
             normalize_before,
+            query_bias=query_bias,
             key_bias=key_bias,
+            value_bias=value_bias,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
             use_sdpa=use_sdpa)
@@ -348,7 +360,9 @@ def __init__(
             input_layer,
             use_output_layer,
             normalize_before,
+            query_bias=query_bias,
             key_bias=key_bias,
+            value_bias=value_bias,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
             use_sdpa=use_sdpa)
diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py
index acf6546b65..fb4d81e981 100644
--- a/wenet/transformer/encoder.py
+++ b/wenet/transformer/encoder.py
@@ -83,7 +83,9 @@ def __init__(
             global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
             use_dynamic_left_chunk (bool): whether use dynamic left chunk in
                 dynamic chunk training
+            query_bias: whether use bias in attention.linear_q
             key_bias: whether use bias in attention.linear_k, False for whisper models.
+            value_bias: whether use bias in attention.linear_v
             gradient_checkpointing: rerunning a forward-pass segment for each
                 checkpointed segment during backward.
             use_sdpa: whether to use SDPA, currently only support transformer for now
@@ -358,7 +360,10 @@ def __init__(
         use_dynamic_chunk: bool = False,
         global_cmvn: torch.nn.Module = None,
         use_dynamic_left_chunk: bool = False,
+        query_bias: bool = True,
         key_bias: bool = True,
+        value_bias: bool = True,
+        mlp_bias: bool = True,
         activation_type: str = "relu",
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
@@ -381,9 +386,10 @@ def __init__(
                 WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
                                                     output_size,
                                                     attention_dropout_rate,
-                                                    key_bias, use_sdpa),
+                                                    query_bias, key_bias,
+                                                    value_bias, use_sdpa),
                 PositionwiseFeedForward(output_size, linear_units,
-                                        dropout_rate, activation),
+                                        dropout_rate, activation, mlp_bias),
                 dropout_rate, normalize_before) for _ in range(num_blocks)
         ])
@@ -416,7 +422,11 @@ def __init__(
         cnn_module_kernel: int = 15,
         causal: bool = False,
         cnn_module_norm: str = "batch_norm",
+        query_bias: bool = True,
         key_bias: bool = True,
+        value_bias: bool = True,
+        mlp_bias: bool = True,
+        conv_bias: bool = True,
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
     ):
@@ -451,7 +461,9 @@ def __init__(
             attention_heads,
             output_size,
             attention_dropout_rate,
+            query_bias,
             key_bias,
+            value_bias,
             use_sdpa,
         )
         # feed-forward module definition
@@ -460,10 +472,11 @@ def __init__(
             linear_units,
             dropout_rate,
             activation,
+            mlp_bias,
         )
         # convolution module definition
         convolution_layer_args = (output_size, cnn_module_kernel, activation,
-                                  cnn_module_norm, causal)
+                                  cnn_module_norm, causal, conv_bias)
 
         self.encoders = torch.nn.ModuleList([
             ConformerEncoderLayer(
diff --git a/wenet/transformer/positionwise_feed_forward.py b/wenet/transformer/positionwise_feed_forward.py
index b7a2cf6e73..68aeb0b04e 100644
--- a/wenet/transformer/positionwise_feed_forward.py
+++ b/wenet/transformer/positionwise_feed_forward.py
@@ -36,10 +36,11 @@ def __init__(
         hidden_units: int,
         dropout_rate: float,
         activation: torch.nn.Module = torch.nn.ReLU(),
+        bias: bool = True,
     ):
         """Construct a PositionwiseFeedForward object."""
         super(PositionwiseFeedForward, self).__init__()
-        self.w_1 = torch.nn.Linear(idim, hidden_units)
+        self.w_1 = torch.nn.Linear(idim, hidden_units, bias=bias)
         self.activation = activation
         self.dropout = torch.nn.Dropout(dropout_rate)
         self.w_2 = torch.nn.Linear(hidden_units, idim)
@@ -80,6 +81,7 @@ def __init__(
         hidden_units: int,
         dropout_rate: float,
         activation: torch.nn.Module = torch.nn.ReLU(),
+        bias: bool = True,
     ):
         super(MoEFFNLayer, self).__init__()
         self.gate = torch.nn.Linear(idim, n_expert, bias=False)
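For reviewers, a minimal sketch of how the new flags are expected to behave once this patch is applied, based only on the updated MultiHeadedAttention signature shown above. The head count, feature size, and the whisper-style choice of disabling only the key bias are illustrative values, not part of the patch:

    from wenet.transformer.attention import MultiHeadedAttention

    # Whisper-style setup: biases kept on query/value projections, dropped on key.
    attn = MultiHeadedAttention(n_head=4,
                                n_feat=256,
                                dropout_rate=0.0,
                                query_bias=True,
                                key_bias=False,
                                value_bias=True)

    assert attn.linear_q.bias is not None   # kept by query_bias=True
    assert attn.linear_k.bias is None       # removed by key_bias=False
    assert attn.linear_v.bias is not None   # kept by value_bias=True

The same pattern applies to mlp_bias (forwarded to PositionwiseFeedForward) and conv_bias (forwarded to the Conformer convolution module args).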