[WIP][transformer] bring llm component #2363

Open · wants to merge 9 commits into base: main
208 changes: 181 additions & 27 deletions wenet/transformer/attention.py
@@ -16,16 +16,19 @@
"""Multi-Head Attention layer definition."""

import math
from typing import Tuple
from typing import Optional, Tuple

import torch
from torch import nn
from wenet.transformer.embedding import apply_rotary_emb

from wenet.utils.common import get_dtype_min


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.
    If n_kv_head is not None and n_kv_head != n_head, multi-query
    (grouped-query) attention is applied; see https://arxiv.org/pdf/1911.02150.pdf.

    Args:
        n_head (int): The number of heads.
@@ -34,22 +37,40 @@ class MultiHeadedAttention(nn.Module):

    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True,
                 use_sdpa: bool = False):
    def __init__(
        self,
        n_head: int,
        n_feat: int,
        dropout_rate: float,
        key_bias: bool = True,
        use_sdpa: bool = False,
        bias: bool = True,
        n_kv_head: Optional[int] = None,
        head_dim: Optional[int] = None,
    ):
        """Construct an MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0

        self.inner_dim = n_feat if head_dim is None else head_dim * n_head
        if n_kv_head is not None:
            assert head_dim is not None
            self.inner_kv_dim = head_dim * n_kv_head
            n_kv_head = n_kv_head
        else:
            self.inner_kv_dim = self.inner_dim
            n_kv_head = n_head
        if self.inner_dim == n_feat:
            assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.d_k = self.inner_dim // n_head
        assert self.d_k == self.inner_kv_dim // n_kv_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.h_kv = n_head if n_kv_head is None else n_kv_head

        self.linear_q = nn.Linear(n_feat, self.inner_dim, bias=bias)
        self.linear_k = nn.Linear(n_feat, self.inner_kv_dim, bias=key_bias)
        self.linear_v = nn.Linear(n_feat, self.inner_kv_dim, bias=bias)
        self.linear_out = nn.Linear(self.inner_dim, n_feat, bias=bias)
        self.dropout = nn.Dropout(p=dropout_rate)

        self.use_sdpa = use_sdpa
@@ -69,18 +90,18 @@ def forward_qkv(
            torch.Tensor: Transformed query tensor, size
                (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor, size
                (#batch, n_head, time2, d_k).
                (#batch, n_kv_head, time2, d_k).
            torch.Tensor: Transformed value tensor, size
                (#batch, n_head, time2, d_k).
                (#batch, n_kv_head, time2, d_k).

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h_kv, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h_kv, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
        k = k.transpose(1, 2)  # (batch, head_kv, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head_kv, time2, d_k)

        return q, k, v

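Note (illustrative, not part of the diff): with the new n_kv_head and head_dim arguments, forward_qkv projects queries to h heads but keys/values to only h_kv heads. A minimal standalone sketch of the resulting shapes, assuming toy sizes n_feat=256, n_head=8, n_kv_head=2, head_dim=32:

import torch
from torch import nn

# Assumed toy sizes for illustration only.
n_feat, n_head, n_kv_head, head_dim = 256, 8, 2, 32
inner_dim, inner_kv_dim = head_dim * n_head, head_dim * n_kv_head

linear_q = nn.Linear(n_feat, inner_dim)
linear_k = nn.Linear(n_feat, inner_kv_dim)
linear_v = nn.Linear(n_feat, inner_kv_dim)

x = torch.randn(4, 100, n_feat)  # (batch, time, n_feat)
q = linear_q(x).view(4, -1, n_head, head_dim).transpose(1, 2)     # (4, 8, 100, 32)
k = linear_k(x).view(4, -1, n_kv_head, head_dim).transpose(1, 2)  # (4, 2, 100, 32)
v = linear_v(x).view(4, -1, n_kv_head, head_dim).transpose(1, 2)  # (4, 2, 100, 32)

With these assumed sizes the key/value projection weights, and the K/V tensors they produce, are four times smaller than in standard multi-head attention, which is the point of the GQA/MQA change.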
@@ -197,6 +218,17 @@ def forward(
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        # non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)
        if self.h_kv != self.h:
            k = torch.repeat_interleave(
                k,
                self.h // self.h_kv,
                dim=1,
            )
            v = torch.repeat_interleave(
                v,
                self.h // self.h_kv,
                dim=1,
            )

        if not self.use_sdpa:
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
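Aside (illustrative, not part of the diff): the repeat_interleave branch above duplicates each of the h_kv key/value heads h // h_kv times along dim=1, so each group of h // h_kv consecutive query heads attends against a copy of the same KV head and the downstream matmul/SDPA sees matching head counts. A tiny sketch with assumed shapes:

import torch

k = torch.randn(4, 2, 100, 32)                 # (batch, h_kv, time2, d_k)
k = torch.repeat_interleave(k, 8 // 2, dim=1)  # each KV head repeated 4 times
print(k.shape)                                 # torch.Size([4, 8, 100, 32])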
@@ -225,16 +257,28 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
        dropout_rate (float): Dropout rate.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True,
                 use_sdpa: bool = False):
    def __init__(
        self,
        n_head: int,
        n_feat: int,
        dropout_rate: float,
        key_bias: bool = True,
        use_sdpa: bool = False,
        bias: bool = True,
        n_kv_head: Optional[int] = None,
        head_dim: Optional[int] = None,
    ):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate, key_bias, use_sdpa)
        super().__init__(n_head,
                         n_feat,
                         dropout_rate,
                         n_kv_head=n_kv_head,
                         key_bias=key_bias,
                         use_sdpa=use_sdpa,
                         head_dim=head_dim,
                         bias=bias)
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        self.linear_pos = nn.Linear(n_feat, self.inner_dim, bias=False)
        # these two learnable bias are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
@@ -320,6 +364,18 @@ def forward(
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        if self.h_kv != self.h:
            k = torch.repeat_interleave(
                k,
                self.h // self.h_kv,
                dim=1,
            )
            v = torch.repeat_interleave(
                v,
                self.h // self.h_kv,
                dim=1,
            )

        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        # non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)
@@ -369,3 +425,101 @@ def forward(
                query.size(0), -1,
                self.h * self.d_k))  # (batch, time1, d_model)
            return self.linear_out(output), new_cache
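Aside (illustrative, not part of the diff): the streaming cache packs K and V along the last dimension, which is why the forward methods split it with cache.size(-1) // 2 and why its documented shape is (1, head, cache_t, d_k * 2). A small round-trip sketch with assumed sizes:

import torch

k = torch.randn(1, 8, 16, 64)                             # (1, head, cache_t, d_k)
v = torch.randn(1, 8, 16, 64)
cache = torch.cat((k, v), dim=-1)                         # (1, head, cache_t, d_k * 2)
k2, v2 = torch.split(cache, cache.size(-1) // 2, dim=-1)  # recover K and V
assert torch.equal(k, k2) and torch.equal(v, v2)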


class RopeMultiHeadedAttention(MultiHeadedAttention):

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True,
                 use_sdpa: bool = False,
                 bias: bool = True,
                 n_kv_head: Optional[int] = None,
                 head_dim: Optional[int] = None):
        super().__init__(n_head, n_feat, dropout_rate, key_bias, use_sdpa,
                         bias, n_kv_head, head_dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
                1. When applying cross attention between decoder and encoder,
                   the batch padding mask for input is in (#batch, 1, T) shape.
                2. When applying self attention of the encoder,
                   the mask is in (#batch, T, T) shape.
                3. When applying self attention of the decoder,
                   the mask is in (#batch, L, L) shape.
                4. If different positions in the decoder see different blocks
                   of the encoder, such as Mocha, the passed-in mask could be
                   in (#batch, L, T) shape. But there is no such case in
                   current WeNet.
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        """
        q, k, v = self.forward_qkv(query, key, value)
        # NOTE(Mddct): In order to make the code easier to read,
        # these two lines are not placed in MultiHeadedAttention.
        q = apply_rotary_emb(q, freqs_cis=pos_emb)
        k = apply_rotary_emb(k, freqs_cis=pos_emb)

        # see above
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        new_cache = torch.cat((k, v), dim=-1)

        if self.h_kv != self.h:
            k = torch.repeat_interleave(
                k,
                self.h // self.h_kv,
                dim=1,
            )
            v = torch.repeat_interleave(
                v,
                self.h // self.h_kv,
                dim=1,
            )

        if not self.use_sdpa:
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
            return self.forward_attention(v, scores, mask), new_cache
        else:
            output = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask.unsqueeze(1),
                dropout_p=self.dropout_rate,
                scale=1 / math.sqrt(self.d_k),
            )
            output = (output.transpose(1, 2).contiguous().view(
                query.size(0), -1,
                self.h * self.d_k))  # (batch, time1, d_model)
            return self.linear_out(output), new_cache
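Note (assumption, not part of the diff): apply_rotary_emb is imported from wenet/transformer/embedding.py and its implementation is not shown in this PR excerpt. As a hedged illustration of what rotary position embedding typically does to the (batch, head, time, d_k) tensors produced by forward_qkv, here is a generic complex-multiplication formulation; WeNet's actual freqs_cis layout and dtype handling may differ:

import torch

def precompute_freqs_cis(head_dim: int, max_len: int, theta: float = 10000.0) -> torch.Tensor:
    # One complex rotation per position and per pair of channels.
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(max_len).float(), freqs)   # (max_len, head_dim // 2)
    return torch.polar(torch.ones_like(angles), angles)          # complex64

def rope_apply(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: (batch, head, time, d_k); rotate each channel pair by a position-dependent angle.
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    rot = freqs_cis[:x_c.size(2)].view(1, 1, x_c.size(2), x_c.size(3))
    return torch.view_as_real(x_c * rot).flatten(-2).type_as(x)

q = torch.randn(4, 8, 100, 32)
q_rot = rope_apply(q, precompute_freqs_cis(32, 512))
print(q_rot.shape)  # torch.Size([4, 8, 100, 32])

Because the rotation is applied to both q and k before the cache concatenation, cached keys already carry their absolute positions, matching the operation order in the forward above.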
12 changes: 8 additions & 4 deletions wenet/transformer/convolution.py
@@ -19,6 +19,8 @@
import torch
from torch import nn

from wenet.utils.class_utils import WENET_NORM_CLASSES


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model."""
@@ -29,7 +31,8 @@ def __init__(self,
                 activation: nn.Module = nn.ReLU(),
                 norm: str = "batch_norm",
                 causal: bool = False,
                 bias: bool = True):
                 bias: bool = True,
                 eps: float = 1e-5):
        """Construct an ConvolutionModule object.
        Args:
            channels (int): The number of channels of conv layers.
@@ -68,13 +71,14 @@ def __init__(self,
            bias=bias,
        )

        assert norm in ['batch_norm', 'layer_norm']
        assert norm in ['batch_norm', 'layer_norm', 'rms_norm']
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1d(channels)
            self.norm = WENET_NORM_CLASSES['batch_norm'](channels, eps=eps)
        else:
            self.use_layer_norm = True
            self.norm = nn.LayerNorm(channels)
            # layer_norm or rms_norm
            self.norm = WENET_NORM_CLASSES[norm](channels, eps=eps)

        self.pointwise_conv2 = nn.Conv1d(
            channels,
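Note (assumption, not part of the diff): WENET_NORM_CLASSES is imported from wenet/utils/class_utils.py, which is not included in this excerpt; presumably it maps the three names accepted by the assert ('batch_norm', 'layer_norm', 'rms_norm') to norm constructors taking (channels, eps=...). A minimal stand-in registry with a textbook RMSNorm, for illustration only:

import torch
from torch import nn

class RMSNorm(nn.Module):
    """Textbook RMSNorm: rescale features by their root mean square."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight

# Hypothetical stand-in; the real WENET_NORM_CLASSES may differ.
NORM_CLASSES = {
    'batch_norm': nn.BatchNorm1d,
    'layer_norm': nn.LayerNorm,
    'rms_norm': RMSNorm,
}

# Usage mirroring the ConvolutionModule change: pick a norm by name and eps.
norm = NORM_CLASSES['rms_norm'](256, eps=1e-5)
print(norm(torch.randn(4, 100, 256)).shape)  # torch.Size([4, 100, 256])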