[transformer] support bias
Mddct committed Mar 8, 2024
1 parent 1a6dcfe commit b0e9d22
Showing 4 changed files with 48 additions and 11 deletions.
16 changes: 12 additions & 4 deletions wenet/transformer/attention.py
@@ -38,17 +38,19 @@ def __init__(self,
                  n_head: int,
                  n_feat: int,
                  dropout_rate: float,
+                 query_bias: bool = True,
                  key_bias: bool = True,
+                 value_bias: bool = True,
                  use_sdpa: bool = False):
         """Construct an MultiHeadedAttention object."""
         super().__init__()
         assert n_feat % n_head == 0
         # We assume d_v always equals d_k
         self.d_k = n_feat // n_head
         self.h = n_head
-        self.linear_q = nn.Linear(n_feat, n_feat)
+        self.linear_q = nn.Linear(n_feat, n_feat, bias=query_bias)
         self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
-        self.linear_v = nn.Linear(n_feat, n_feat)
+        self.linear_v = nn.Linear(n_feat, n_feat, bias=value_bias)
         self.linear_out = nn.Linear(n_feat, n_feat)
         self.dropout = nn.Dropout(p=dropout_rate)
 
@@ -239,10 +241,13 @@ def __init__(self,
                  n_head: int,
                  n_feat: int,
                  dropout_rate: float,
+                 query_bias: bool = True,
                  key_bias: bool = True,
+                 value_bias: bool = True,
                  use_sdpa: bool = False):
         """Construct an RelPositionMultiHeadedAttention object."""
-        super().__init__(n_head, n_feat, dropout_rate, key_bias, use_sdpa)
+        super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
+                         value_bias, use_sdpa)
         # linear transformation for positional encoding
         self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
         # these two learnable bias are used in matrix c and matrix d
@@ -387,9 +392,12 @@ def __init__(self,
                  n_head: int,
                  n_feat: int,
                  dropout_rate: float,
+                 query_bias: bool = True,
                  key_bias: bool = True,
+                 value_bias: bool = True,
                  use_sdpa: bool = False):
-        super().__init__(n_head, n_feat, dropout_rate, key_bias, use_sdpa)
+        super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
+                         value_bias, use_sdpa)
 
     def forward(
         self,
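
Note: a minimal usage sketch of the new per-projection flags when constructing the attention module directly. The sizes and flag values below are illustrative, not part of the commit; only the query_bias/key_bias/value_bias keywords come from this change.

    from wenet.transformer.attention import MultiHeadedAttention

    # Illustrative values; only the bias keywords are new in this commit.
    attn = MultiHeadedAttention(n_head=4,
                                n_feat=256,
                                dropout_rate=0.1,
                                query_bias=False,   # linear_q built with bias=False
                                key_bias=True,      # linear_k keeps its bias
                                value_bias=False)   # linear_v built with bias=False

    print(attn.linear_q.bias)        # None
    print(attn.linear_k.bias.shape)  # torch.Size([256])
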
20 changes: 17 additions & 3 deletions wenet/transformer/decoder.py
@@ -48,7 +48,9 @@ class TransformerDecoder(torch.nn.Module):
             False: use layer_norm after each sub-block of a layer.
         src_attention: if false, encoder-decoder cross attention is not
             applied, such as CIF model
+        query_bias: whether use bias in attention.linear_q
         key_bias: whether use bias in attention.linear_k, False for whisper models.
+        value_bias: whether use bias in attention.linear_v
         gradient_checkpointing: rerunning a forward-pass segment for each
             checkpointed segment during backward.
         tie_word_embedding: Tie or clone module weights depending of whether we are
@@ -70,7 +72,10 @@ def __init__(
         use_output_layer: bool = True,
         normalize_before: bool = True,
         src_attention: bool = True,
+        query_bias: bool = True,
         key_bias: bool = True,
+        value_bias: bool = True,
+        mlp_bias: bool = True,
         activation_type: str = "relu",
         gradient_checkpointing: bool = False,
         tie_word_embedding: bool = False,
@@ -100,12 +105,14 @@ def __init__(
                 attention_dim,
                 WENET_ATTENTION_CLASSES["selfattn"](
                     attention_heads, attention_dim,
-                    self_attention_dropout_rate, key_bias, use_sdpa),
+                    self_attention_dropout_rate, query_bias, key_bias,
+                    value_bias, use_sdpa),
                 WENET_ATTENTION_CLASSES["crossattn"](
                     attention_heads, attention_dim, src_attention_dropout_rate,
-                    key_bias, use_sdpa) if src_attention else None,
+                    query_bias, key_bias, value_bias, use_sdpa)
+                if src_attention else None,
                 PositionwiseFeedForward(attention_dim, linear_units,
-                                        dropout_rate, activation),
+                                        dropout_rate, activation, mlp_bias),
                 dropout_rate,
                 normalize_before,
             ) for _ in range(self.num_blocks)
@@ -308,7 +315,10 @@ def __init__(
         input_layer: str = "embed",
         use_output_layer: bool = True,
         normalize_before: bool = True,
+        query_bias: bool = True,
         key_bias: bool = True,
+        value_bias: bool = True,
+        mlp_bias: bool = True,
         gradient_checkpointing: bool = False,
         tie_word_embedding: bool = False,
         use_sdpa: bool = False,
@@ -330,7 +340,9 @@ def __init__(
             input_layer,
             use_output_layer,
             normalize_before,
+            query_bias=query_bias,
             key_bias=key_bias,
+            value_bias=value_bias,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
             use_sdpa=use_sdpa)
@@ -348,7 +360,9 @@ def __init__(
             input_layer,
             use_output_layer,
             normalize_before,
+            query_bias=query_bias,
             key_bias=key_bias,
+            value_bias=value_bias,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
             use_sdpa=use_sdpa)
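
Note: in WeNet these constructor keywords are typically filled in from the decoder_conf section of the training config. A hypothetical fragment showing only the bias-related keys added here (all surrounding options assumed to keep their upstream defaults):

    # Hypothetical decoder_conf fragment; keyword names match the new
    # TransformerDecoder arguments, the values are illustrative.
    decoder_conf = dict(
        query_bias=True,
        key_bias=False,   # e.g. Whisper-style decoder attention without key bias
        value_bias=True,
        mlp_bias=True,    # forwarded to PositionwiseFeedForward
    )
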
19 changes: 16 additions & 3 deletions wenet/transformer/encoder.py
@@ -83,7 +83,9 @@ def __init__(
             global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
             use_dynamic_left_chunk (bool): whether use dynamic left chunk in
                 dynamic chunk training
+            query_bias: whether use bias in attention.linear_q
             key_bias: whether use bias in attention.linear_k, False for whisper models.
+            value_bias: whether use bias in attention.linear_v
             gradient_checkpointing: rerunning a forward-pass segment for each
                 checkpointed segment during backward.
             use_sdpa: whether to use SDPA, currently only support transformer for now
@@ -358,7 +360,10 @@ def __init__(
         use_dynamic_chunk: bool = False,
         global_cmvn: torch.nn.Module = None,
         use_dynamic_left_chunk: bool = False,
+        query_bias: bool = True,
         key_bias: bool = True,
+        value_bias: bool = True,
+        mlp_bias: bool = True,
         activation_type: str = "relu",
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
@@ -381,9 +386,10 @@ def __init__(
                 WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
                                                     output_size,
                                                     attention_dropout_rate,
-                                                    key_bias, use_sdpa),
+                                                    query_bias, key_bias,
+                                                    value_bias, use_sdpa),
                 PositionwiseFeedForward(output_size, linear_units,
-                                        dropout_rate, activation),
+                                        dropout_rate, activation, mlp_bias),
                 dropout_rate, normalize_before) for _ in range(num_blocks)
         ])
 
@@ -416,7 +422,11 @@ def __init__(
         cnn_module_kernel: int = 15,
         causal: bool = False,
         cnn_module_norm: str = "batch_norm",
+        query_bias: bool = True,
         key_bias: bool = True,
+        value_bias: bool = True,
+        mlp_bias: bool = True,
+        conv_bias: bool = True,
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
     ):
@@ -451,7 +461,9 @@ def __init__(
             attention_heads,
             output_size,
             attention_dropout_rate,
+            query_bias,
             key_bias,
+            value_bias,
             use_sdpa,
         )
         # feed-forward module definition
@@ -460,10 +472,11 @@ def __init__(
             linear_units,
             dropout_rate,
             activation,
+            mlp_bias,
         )
         # convolution module definition
         convolution_layer_args = (output_size, cnn_module_kernel, activation,
-                                  cnn_module_norm, causal)
+                                  cnn_module_norm, causal, conv_bias)
 
         self.encoders = torch.nn.ModuleList([
             ConformerEncoderLayer(
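
Note: the Conformer encoder now exposes bias switches for attention, the feed-forward module, and the convolution module. A hedged construction sketch follows; input_size and every omitted argument are assumptions about upstream defaults, not part of this diff.

    from wenet.transformer.encoder import ConformerEncoder

    # Only the bias-related keywords below come from this commit.
    encoder = ConformerEncoder(
        input_size=80,      # e.g. 80-dim fbank features (assumption)
        query_bias=True,
        key_bias=True,
        value_bias=True,
        mlp_bias=True,      # forwarded to PositionwiseFeedForward
        conv_bias=False,    # forwarded to the convolution module args
    )
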
4 changes: 3 additions & 1 deletion wenet/transformer/positionwise_feed_forward.py
@@ -36,10 +36,11 @@ def __init__(
             hidden_units: int,
             dropout_rate: float,
             activation: torch.nn.Module = torch.nn.ReLU(),
+            bias: bool = True,
     ):
         """Construct a PositionwiseFeedForward object."""
         super(PositionwiseFeedForward, self).__init__()
-        self.w_1 = torch.nn.Linear(idim, hidden_units)
+        self.w_1 = torch.nn.Linear(idim, hidden_units, bias=bias)
         self.activation = activation
         self.dropout = torch.nn.Dropout(dropout_rate)
         self.w_2 = torch.nn.Linear(hidden_units, idim)
@@ -80,6 +81,7 @@ def __init__(
             hidden_units: int,
             dropout_rate: float,
             activation: torch.nn.Module = torch.nn.ReLU(),
+            bias: bool = True,
     ):
         super(MoEFFNLayer, self).__init__()
         self.gate = torch.nn.Linear(idim, n_expert, bias=False)
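
Note: a quick check of the new bias flag on the feed-forward module; sizes are illustrative. In this change only w_1 receives the flag, while w_2 is left untouched.

    from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward

    # Sketch with illustrative sizes; only `bias` is new in this commit.
    ffn = PositionwiseFeedForward(256, 2048, 0.1, bias=False)

    print(ffn.w_1.bias)        # None -- first projection built with bias=False
    print(ffn.w_2.bias.shape)  # torch.Size([256]) -- second projection unchanged
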
