diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py
index 0c8ed8380a..adc5e65898 100644
--- a/wenet/transformer/attention.py
+++ b/wenet/transformer/attention.py
@@ -16,7 +16,7 @@
 """Multi-Head Attention layer definition."""
 
 import math
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from torch import nn
@@ -34,23 +34,40 @@ class MultiHeadedAttention(nn.Module):
 
     """
 
-    def __init__(self,
-                 n_head: int,
-                 n_feat: int,
-                 dropout_rate: float,
-                 key_bias: bool = True,
-                 use_sdpa: bool = False,
-                 bias: bool = True):
+    def __init__(
+        self,
+        n_head: int,
+        n_feat: int,
+        dropout_rate: float,
+        key_bias: bool = True,
+        use_sdpa: bool = False,
+        bias: bool = True,
+        n_kv_head: Optional[int] = None,
+        head_dim: Optional[int] = None,
+    ):
         """Construct an MultiHeadedAttention object."""
         super().__init__()
-        assert n_feat % n_head == 0
+
+        self.inner_dim = n_feat if head_dim is None else head_dim * n_head
+        if n_kv_head is not None:
+            assert head_dim is not None
+            self.inner_kv_dim = head_dim * n_kv_head
+            n_kv_head = n_kv_head
+        else:
+            self.inner_kv_dim = self.inner_dim
+            n_kv_head = n_head
+        if self.inner_dim == n_feat:
+            assert n_feat % n_head == 0
         # We assume d_v always equals d_k
-        self.d_k = n_feat // n_head
+        self.d_k = self.inner_dim // n_head
+        assert self.d_k == self.inner_kv_dim // n_kv_head
         self.h = n_head
-        self.linear_q = nn.Linear(n_feat, n_feat, bias=bias)
-        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
-        self.linear_v = nn.Linear(n_feat, n_feat, bias=bias)
-        self.linear_out = nn.Linear(n_feat, n_feat, bias=bias)
+        self.h_kv = n_head if n_kv_head is None else n_kv_head
+
+        self.linear_q = nn.Linear(n_feat, self.inner_dim, bias=bias)
+        self.linear_k = nn.Linear(n_feat, self.inner_kv_dim, bias=key_bias)
+        self.linear_v = nn.Linear(n_feat, self.inner_kv_dim, bias=bias)
+        self.linear_out = nn.Linear(self.inner_dim, n_feat, bias=bias)
         self.dropout = nn.Dropout(p=dropout_rate)
         self.use_sdpa = use_sdpa
 
@@ -70,18 +87,18 @@ def forward_qkv(
             torch.Tensor: Transformed query tensor, size
                 (#batch, n_head, time1, d_k).
             torch.Tensor: Transformed key tensor, size
-                (#batch, n_head, time2, d_k).
+                (#batch, n_kv_head, time2, d_k).
             torch.Tensor: Transformed value tensor, size
-                (#batch, n_head, time2, d_k).
+                (#batch, n_kv_head, time2, d_k).
 
         """
         n_batch = query.size(0)
         q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
-        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
-        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+        k = self.linear_k(key).view(n_batch, -1, self.h_kv, self.d_k)
+        v = self.linear_v(value).view(n_batch, -1, self.h_kv, self.d_k)
         q = q.transpose(1, 2)  # (batch, head, time1, d_k)
-        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
-        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+        k = k.transpose(1, 2)  # (batch, head_kv, time2, d_k)
+        v = v.transpose(1, 2)  # (batch, head_kv, time2, d_k)
 
         return q, k, v
 
@@ -198,6 +215,17 @@ def forward(
         # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
         #   non-trivial to calculate `next_cache_start` here.
         new_cache = torch.cat((k, v), dim=-1)
+        if self.h_kv != self.h:
+            k = torch.repeat_interleave(
+                k,
+                self.h // self.h_kv,
+                dim=1,
+            )
+            v = torch.repeat_interleave(
+                v,
+                self.h // self.h_kv,
+                dim=1,
+            )
 
         if not self.use_sdpa:
             scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
@@ -226,22 +254,28 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         dropout_rate (float): Dropout rate.
""" - def __init__(self, - n_head: int, - n_feat: int, - dropout_rate: float, - key_bias: bool = True, - use_sdpa: bool = False, - bias: bool = True): + def __init__( + self, + n_head: int, + n_feat: int, + dropout_rate: float, + key_bias: bool = True, + use_sdpa: bool = False, + bias: bool = True, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, + ): """Construct an RelPositionMultiHeadedAttention object.""" super().__init__(n_head, n_feat, dropout_rate, - key_bias, - use_sdpa, + n_kv_head=n_kv_head, + key_bias=key_bias, + use_sdpa=use_sdpa, + head_dim=head_dim, bias=bias) # linear transformation for positional encoding - self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + self.linear_pos = nn.Linear(n_feat, self.inner_dim, bias=False) # these two learnable bias are used in matrix c and matrix d # as described in https://arxiv.org/abs/1901.02860 Section 3.3 self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) @@ -327,6 +361,18 @@ def forward( dim=-1) k = torch.cat([key_cache, k], dim=2) v = torch.cat([value_cache, v], dim=2) + if self.h_kv != self.h: + k = torch.repeat_interleave( + k, + self.h // self.h_kv, + dim=1, + ) + v = torch.repeat_interleave( + v, + self.h // self.h_kv, + dim=1, + ) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's # non-trivial to calculate `next_cache_start` here. new_cache = torch.cat((k, v), dim=-1) @@ -376,14 +422,3 @@ def forward( query.size(0), -1, self.h * self.d_k)) # (batch, time1, d_model) return self.linear_out(output), new_cache - - -class MultiQueryAttention(MultiHeadedAttention): - - def __init__(self, - n_head: int, - n_feat: int, - dropout_rate: float, - key_bias: bool = True, - use_sdpa: bool = False): - super().__init__(n_head, n_feat, dropout_rate, key_bias, use_sdpa) diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py index eb1aaf31ec..95584be22b 100644 --- a/wenet/transformer/decoder.py +++ b/wenet/transformer/decoder.py @@ -80,6 +80,8 @@ def __init__( bias: bool = True, layer_norm_type: str = 'layer_norm', eps: float = 1e-5, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, ): super().__init__() attention_dim = encoder_output_size @@ -114,6 +116,8 @@ def __init__( key_bias, use_sdpa, bias=bias, + n_kv_head=n_kv_head, + head_dim=head_dim, ), WENET_ATTENTION_CLASSES["selfattn"]( attention_heads, @@ -122,6 +126,8 @@ def __init__( key_bias, use_sdpa, bias=bias, + n_kv_head=n_kv_head, + head_dim=head_dim, ) if src_attention else None, mlp_class( attention_dim, @@ -328,6 +334,8 @@ def __init__( bias: bool = True, layer_norm_type: str = 'layer_norm', eps: float = 1e-5, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, ): super().__init__() @@ -353,6 +361,8 @@ def __init__( bias=bias, layer_norm_type=layer_norm_type, eps=eps, + n_kv_head=n_kv_head, + head_dim=head_dim, ) self.right_decoder = TransformerDecoder( @@ -376,6 +386,8 @@ def __init__( bias=bias, layer_norm_type=layer_norm_type, eps=eps, + n_kv_head=n_kv_head, + head_dim=head_dim, ) def forward( diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py index cd454ef5a6..acbe179e32 100644 --- a/wenet/transformer/encoder.py +++ b/wenet/transformer/encoder.py @@ -14,7 +14,7 @@ # limitations under the License. 
 # Modified from ESPnet(https://github.com/espnet/espnet)
 """Encoder definition."""
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 import torch.utils.checkpoint as ckpt
@@ -371,6 +371,8 @@ def __init__(
         bias: bool = True,
         layer_norm_type: str = 'layer_norm',
         eps: float = 1e-5,
+        n_kv_head: Optional[int] = None,
+        head_dim: Optional[int] = None,
     ):
         """ Construct TransformerEncoder
 
@@ -388,12 +390,16 @@ def __init__(
         self.encoders = torch.nn.ModuleList([
             TransformerEncoderLayer(
                 output_size,
-                WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
-                                                    output_size,
-                                                    attention_dropout_rate,
-                                                    key_bias,
-                                                    use_sdpa,
-                                                    bias=bias),
+                WENET_ATTENTION_CLASSES["selfattn"](
+                    attention_heads,
+                    output_size,
+                    attention_dropout_rate,
+                    key_bias=key_bias,
+                    use_sdpa=use_sdpa,
+                    bias=bias,
+                    n_kv_head=n_kv_head,
+                    head_dim=head_dim,
+                ),
                 mlp_class(output_size,
                           linear_units,
                           dropout_rate,
@@ -442,6 +448,8 @@ def __init__(
         bias: bool = True,
         layer_norm_type: str = 'layer_norm',
         eps: float = 1e-5,
+        n_kv_head: Optional[int] = None,
+        head_dim: Optional[int] = None,
     ):
         """Construct ConformerEncoder
 
@@ -491,6 +499,8 @@ def __init__(
             key_bias,
             use_sdpa,
             bias,
+            n_kv_head,
+            head_dim,
         )
         # feed-forward module definition
         positionwise_layer_args = (
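For reviewers who want to check the grouped-query attention (GQA) wiring outside of WeNet, here is a minimal, self-contained sketch (not part of this diff) of the same idea: `inner_dim` / `inner_kv_dim` sizing, separate `h` and `h_kv` head counts, and the `torch.repeat_interleave` expansion that lets `n_kv_head < n_head` key/value heads serve all query heads. The `ToyGQA` class and the shape checks below are hypothetical illustrations only; masking, dropout, the KV cache, SDPA, and the relative-position bias of the real modules are deliberately omitted.

```python
import math
from typing import Optional

import torch
from torch import nn


class ToyGQA(nn.Module):
    """Minimal grouped-query attention sketch (illustration only)."""

    def __init__(self, n_head: int, n_feat: int,
                 n_kv_head: Optional[int] = None,
                 head_dim: Optional[int] = None):
        super().__init__()
        # Query projection width; falls back to n_feat when head_dim is unset.
        self.inner_dim = n_feat if head_dim is None else head_dim * n_head
        if n_kv_head is not None:
            assert head_dim is not None
            self.inner_kv_dim = head_dim * n_kv_head
        else:
            self.inner_kv_dim = self.inner_dim
            n_kv_head = n_head
        assert self.inner_dim % n_head == 0
        self.d_k = self.inner_dim // n_head
        assert self.d_k == self.inner_kv_dim // n_kv_head
        self.h, self.h_kv = n_head, n_kv_head
        # q is projected to h heads, k/v only to h_kv heads.
        self.linear_q = nn.Linear(n_feat, self.inner_dim)
        self.linear_k = nn.Linear(n_feat, self.inner_kv_dim)
        self.linear_v = nn.Linear(n_feat, self.inner_kv_dim)
        self.linear_out = nn.Linear(self.inner_dim, n_feat)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, _ = x.size()
        q = self.linear_q(x).view(b, t, self.h, self.d_k).transpose(1, 2)
        k = self.linear_k(x).view(b, t, self.h_kv, self.d_k).transpose(1, 2)
        v = self.linear_v(x).view(b, t, self.h_kv, self.d_k).transpose(1, 2)
        if self.h_kv != self.h:
            # Expand the KV heads so every group of h // h_kv query heads
            # shares one key/value head.
            k = torch.repeat_interleave(k, self.h // self.h_kv, dim=1)
            v = torch.repeat_interleave(v, self.h // self.h_kv, dim=1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)  # (b, h, t, d_k)
        out = out.transpose(1, 2).reshape(b, t, self.inner_dim)
        return self.linear_out(out)


if __name__ == "__main__":
    x = torch.randn(2, 5, 256)
    mha = ToyGQA(n_head=8, n_feat=256)                            # standard MHA
    gqa = ToyGQA(n_head=8, n_feat=256, n_kv_head=2, head_dim=32)  # grouped-query
    mqa = ToyGQA(n_head=8, n_feat=256, n_kv_head=1, head_dim=32)  # multi-query
    for m in (mha, gqa, mqa):
        assert m(x).shape == (2, 5, 256)
```

With `n_kv_head=1` the same code path reduces to multi-query attention, which is why the separate `MultiQueryAttention` stub can be removed in favor of the new `n_kv_head` / `head_dim` arguments.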