diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py
index ea87c2ca19..f257abf373 100644
--- a/wenet/transformer/attention.py
+++ b/wenet/transformer/attention.py
@@ -20,6 +20,7 @@
 import torch
 from torch import nn
 
+from wenet.transformer.embedding import apply_rotary_emb
 from wenet.utils.common import get_dtype_min
 
 
@@ -424,3 +425,101 @@ def forward(
             query.size(0), -1,
             self.h * self.d_k))  # (batch, time1, d_model)
         return self.linear_out(output), new_cache
+
+
+class RopeMultiHeadedAttention(MultiHeadedAttention):
+
+    def __init__(self,
+                 n_head: int,
+                 n_feat: int,
+                 dropout_rate: float,
+                 key_bias: bool = True,
+                 use_sdpa: bool = False,
+                 bias: bool = True,
+                 n_kv_head: Optional[int] = None,
+                 head_dim: Optional[int] = None):
+        super().__init__(n_head, n_feat, dropout_rate, key_bias, use_sdpa,
+                         bias, n_kv_head, head_dim)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+        pos_emb: torch.Tensor = torch.empty(0),
+        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+                1.When applying cross attention between decoder and encoder,
+                the batch padding mask for input is in (#batch, 1, T) shape.
+                2.When applying self attention of encoder,
+                the mask is in (#batch, T, T) shape.
+                3.When applying self attention of decoder,
+                the mask is in (#batch, L, L) shape.
+                4.If the different position in decoder see different block
+                of the encoder, such as Mocha, the passed in mask could be
+                in (#batch, L, T) shape. But there is no such case in current
+                Wenet.
+            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        # see above
+        if cache.size(0) > 0:
+            key_cache, value_cache = torch.split(cache,
+                                                 cache.size(-1) // 2,
+                                                 dim=-1)
+            k = torch.cat([key_cache, k], dim=2)
+            v = torch.cat([value_cache, v], dim=2)
+
+        # NOTE(Mddct): In order to make the code easier to read,
+        # these two lines are not placed in MultiHeadedAttention.
+        q = apply_rotary_emb(q, freqs_cis=pos_emb)
+        k = apply_rotary_emb(k, freqs_cis=pos_emb)
+
+        new_cache = torch.cat((k, v), dim=-1)
+        if self.h_kv != self.h:
+            k = torch.repeat_interleave(
+                k,
+                self.h // self.h_kv,
+                dim=1,
+            )
+            v = torch.repeat_interleave(
+                v,
+                self.h // self.h_kv,
+                dim=1,
+            )
+
+        if not self.use_sdpa:
+            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+            return self.forward_attention(v, scores, mask), new_cache
+        else:
+            output = torch.nn.functional.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=mask.unsqueeze(1),
+                dropout_p=self.dropout_rate,
+                scale=1 / math.sqrt(self.d_k),
+            )
+            output = (output.transpose(1, 2).contiguous().view(
+                query.size(0), -1,
+                self.h * self.d_k))  # (batch, time1, d_model)
+            return self.linear_out(output), new_cache
diff --git a/wenet/transformer/embedding.py b/wenet/transformer/embedding.py
index 17d8810ffd..75bbf477f3 100644
--- a/wenet/transformer/embedding.py
+++ b/wenet/transformer/embedding.py
@@ -178,7 +178,7 @@ class NoPositionalEncoding(torch.nn.Module):
     """ No position encoding
     """
 
-    def __init__(self, d_model: int, dropout_rate: float):
+    def __init__(self, d_model: int, dropout_rate: float, *args):
         super().__init__()
         self.d_model = d_model
         self.dropout = torch.nn.Dropout(p=dropout_rate)
@@ -195,3 +195,63 @@ def forward(self,
     def position_encoding(self, offset: Union[int, torch.Tensor],
                           size: int) -> torch.Tensor:
         return torch.zeros(1, size, self.d_model)
+
+
+# copy from:https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L84
+def precompute_freqs_cis(dim: int,
+                         end: int,
+                         theta: float = 10000.0) -> torch.Tensor:
+    """Precomputes the frequency cis."""
+    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device)
+    freqs = torch.outer(t, freqs).float()
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+    return freqs_cis
+
+
+# copy from:https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L95
+def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+    """Applies the rotary embedding to the query and key tensors."""
+    x_ = torch.view_as_complex(
+        torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
+    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
+    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
+    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
+                          -1).transpose(1, 2)
+    return x_out
+
+
+class RopePositionalEncoding(PositionalEncoding):
+
+    def __init__(self,
+                 d_model: int,
+                 dropout_rate: float,
+                 max_len: int = 1500,
+                 rope_theta=10000.0):
+        super().__init__(d_model, dropout_rate=dropout_rate, max_len=max_len)
+        delattr(self, 'pe')
+        self.pe = precompute_freqs_cis(d_model, max_len * 2, rope_theta)
+        del self.dropout
+        self.dropout_rate = dropout_rate
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        offset: Union[int,
+                      torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        self.pe = self.pe.to(x.device)
+        pos_emb = self.position_encoding(offset, x.size(1), False)
+        # NOTE(Mddct): some models don't scale
+        # TODO(Mddct): fix
+        x = x * self.xscale
+        # NOTE(Mddct): dropout doesn't support complex float for pos_emb
+        return self.dropout(x), self.dropout_complex(pos_emb)
+
+    def dropout_complex(self, x):
+        mask = torch.nn.functional.dropout(
+            torch.ones_like(x.real),
+            training=self.training,
+            p=self.dropout_rate,
+        )
+        return x * mask
diff --git a/wenet/utils/class_utils.py b/wenet/utils/class_utils.py
index e06b2e7b47..4ad32d2e00 100644
--- a/wenet/utils/class_utils.py
+++ b/wenet/utils/class_utils.py
@@ -21,11 +21,13 @@
 from wenet.squeezeformer.subsampling import DepthwiseConv2dSubsampling4
 from wenet.transformer.embedding import (PositionalEncoding,
                                          RelPositionalEncoding,
+                                         RopePositionalEncoding,
                                          WhisperPositionalEncoding,
                                          LearnablePositionalEncoding,
                                          NoPositionalEncoding)
 from wenet.transformer.attention import (MultiHeadedAttention,
-                                         RelPositionMultiHeadedAttention)
+                                         RelPositionMultiHeadedAttention,
+                                         RopeMultiHeadedAttention)
 from wenet.efficient_conformer.attention import GroupedRelPositionMultiHeadedAttention
 
 WENET_ACTIVATION_CLASSES = {
@@ -63,12 +65,14 @@
     "abs_pos_whisper": WhisperPositionalEncoding,
     "embed_learnable_pe": LearnablePositionalEncoding,
     "abs_pos_paraformer": ParaformerPositinoalEncoding,
+    "rope": RopePositionalEncoding,
 }
 
 WENET_ATTENTION_CLASSES = {
     "selfattn": MultiHeadedAttention,
     "rel_selfattn": RelPositionMultiHeadedAttention,
     "grouped_rel_selfattn": GroupedRelPositionMultiHeadedAttention,
+    "rope_selfattn": RopeMultiHeadedAttention,
 }
 
 WENET_MLP_CLASSES = {
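For reviewers who want to sanity-check the copied Gemma helpers in isolation, here is a small smoke test. It is not part of the diff: the sizes, the `(batch, time, head, head_dim)` layout, and the position shift are illustrative assumptions taken from the upstream Gemma code that `precompute_freqs_cis` / `apply_rotary_emb` were copied from (the attention module above calls them with its own tensor layout). The check exercises the defining RoPE property: attention scores depend only on the relative distance between query and key positions, so rotating both with a globally shifted position index must leave the scores unchanged.

```python
import torch

from wenet.transformer.embedding import apply_rotary_emb, precompute_freqs_cis

# Toy sizes; layout (batch, time, head, head_dim) follows the upstream Gemma
# helpers and is an assumption for this standalone check.
B, T, H, D = 2, 16, 4, 64
shift = 3

# Complex rotations for positions [0, T + shift); shape (T + shift, D // 2).
freqs_cis = precompute_freqs_cis(D, T + shift)

q = torch.randn(B, T, H, D)
k = torch.randn(B, T, H, D)

# Rotate q/k once with positions [0, T) and once with every position shifted
# by the same offset.
q0, k0 = apply_rotary_emb(q, freqs_cis[:T]), apply_rotary_emb(k, freqs_cis[:T])
q1, k1 = apply_rotary_emb(q, freqs_cis[shift:]), apply_rotary_emb(k, freqs_cis[shift:])

# RoPE encodes relative position: the score between query position m and key
# position n depends only on m - n, so the global shift changes nothing.
s0 = torch.einsum('bmhd,bnhd->bhmn', q0, k0)
s1 = torch.einsum('bmhd,bnhd->bhmn', q1, k1)
print(torch.allclose(s0, s1, atol=1e-3))  # True (up to float32 rounding)
```

With `class_utils.py` updated as above, the new classes are reachable through the registry keys `rope` (positional encoding) and `rope_selfattn` (attention); presumably they are selected through the usual `pos_enc_layer_type` / `selfattention_layer_type` encoder options, which is not shown in this diff.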