Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][transformer] bring llm component #2363

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open

Conversation

Mddct
Copy link
Collaborator

@Mddct Mddct commented Feb 22, 2024

这个pr 会将以下llm的组件引入wenet
(llama gemma等都用了)

TODO

  • benchmark multiquery vs multiheaded
  • @Mddct ROPE 目前实现没有问题,但是init model的时候漏掉了 需要修复下
  • [train_engine] support fsdp #2412
  • fix some comment
    btw:日后可方便加载llm模型

@Mddct Mddct force-pushed the Mddct-llm-component branch 4 times, most recently from 8c576bb to f261744 Compare February 22, 2024 15:46
@Mddct Mddct force-pushed the Mddct-llm-component branch from f261744 to a17fc45 Compare February 22, 2024 16:26
@xingchensong
Copy link
Member

xingchensong commented Feb 23, 2024

准备以后也是ckpt重命名的方式引入llm吗?(而不是import transformers)

@Mddct
Copy link
Collaborator Author

Mddct commented Feb 23, 2024

准备以后也是ckpt重命名的方式引入llm吗?(而不是import transformers)

是 后边会用fsdp/deepspeed 直接和transformers用有一堆奇奇怪怪的问题,而且也不方便做部署之类的工作

@xingchensong
Copy link
Member

同意!

@Mddct
Copy link
Collaborator Author

Mddct commented Feb 23, 2024

这里大致罗列下主流llm的一些情况,可能随版本变动 有些出入

模型名称 参数 隐藏层维度 层数 注意力头数 训练数据 位置编码 激活函数 归一化方法 注意力机制 词表大小 分词方法 最大长度 linear bias
LLAMA 6.7B 4096 32 32 1T RoPE SwiGLU RMSNorm(pre-norm) 多头注意力机制(MHA) 32000 BBPE 2048
LLAMA2 7B 4096 32 32 2.0T RoPE SwiGLU RMSNorm(pre-norm) 多头注意力机制(MHA) 32000 BBPE 4096 false
chatglm-6B 6.2B 4096 28 32 1T RoPE 2d位置编码 GELU layer norm(post-norm) 多头注意力机制(MHA) 130528 BBPE 2048
chatglm2-6B 6.2B 4096 28 32 1.4T RoPE 推理时,舍弃2d位置编码,回归decoder-only SwiGLU RMSNorm(post-norm) Multi-Query Attention (MQA) 65024 BBPE 32768
baichuan-7b 7B 4096 32 32 1.2T RoPE SwiGLU RMSNorm(pre-norm) 多头注意力机制(MHA) 64000 BBPE 4096 false
Qwen-7B 7B 4096 32 32 2.2T RoPE SwiGLU RMSNorm(pre-norm) 多头注意力机制(MHA) 151851 BBPE 2048 false
gemma rope gelu rmsnorm multi query false

@Mddct Mddct force-pushed the Mddct-llm-component branch from 427a033 to 8c64edb Compare February 23, 2024 09:37
@Mddct
Copy link
Collaborator Author

Mddct commented Feb 23, 2024

还是这个配置:#2333 (comment)

batch size data type 训练时间 att/rescore/ctc greedy/ctc beam wer
step 模式 avg 20 step 1000 save interval (no stage1 shuffle) bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 10h36min 5.63/5.30/5.84/5.85
step 模式 avg 20 step 1000 save interval (stage1 shuffle) + sdpa bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 5.70/5.38/5.89/5.90
step 模式 avg 20 step 1000 save interval (no stage1 shuffle) + sdpa bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 5.67/5.32/89/5.89
bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 14h11min 5.53/5.24/5.85/5.85
+gated mlp bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 15h36min1 5.34/5.06/5.51/5.51
+encoder no bias (mlp with bias) bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 14h4min 5.66/5.29/5.91/5.91
+encoder/decoder no bias (mlp with no bias) bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 13h32min 5.56/5.28/5.91/5.91
+encoder/decoder no bias + gated mlp bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 15h14min 5.32/5.11/5.58/5.59
+rms norm, eps=1e-5 bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 14h22min 5.65/5.26/5.82/5.82
+encoder/decoder no bias + gated mlp + rms_norm, eps=1e-5 bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 15h22min 5.46/5.17/5.60/5.60
+encoder/decoder no bias + gated mlp + rms_norm, eps=1e-6 bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 15h23min 5.29/5.00/5.42/5.42
transformer encoder no pos directly through blocks bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 14h5min 6.46/5.46/6.13/6.13
+rope google bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 14h18min 5.65/5.28/5.87/5.87
+rope llama bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 14h4min 5.72/5.36/5.99/5.99
+encoder/decoder no bias + gated mlp + rms_norm, eps=1e-6 + rope bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 15h35min 5.24/5.02/5.53/5.53
+multiquery bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 14h10min 5.54/5.18/5.86/5.86
+encoder/decoder no bias + gated mlp + rms_norm, eps=1e-6 + rope + multiquery bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 15h18min 5.62/5.27/5.85/5.85

NOTE: 上述实验中的subsampling (conv2d) + ctc dense 都存在bias

截屏2024-02-26 14 57 31

train_conformer.yaml

batch size data type 训练时间 att/rescore/ctc greedy/ctc beam wer
static batch size = 18 raw / 5.18/4.61/4.94/4.94
+encoder/decoder no bias + gated mlp + rms_norm, eps=1e-6, conv_norm: bn bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 21h42min 4.75/4.52/4.88/4.89
+encoder/decoder no bias + gated mlp + rms_norm, eps=1e-6, conv_norm: rms norm bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 21h48min 4.73/4.5/4.85/4.85
+encoder/decoder no bias + gated mlp + rms_norm, eps=1e-6, conv_norm: bn + rope bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 21h25min 4.64/4.49/4.73/4.74
+encoder/decoder no bias + gated mlp + rms_norm, eps=1e-6, conv_norm: bn + rope sync bn bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 21h 4.65/4.43/4.70/4.70
avg30: 4.65/4.39/4.66/4.66
+encoder/decoder no bias + gated mlp + rms_norm, eps=1e-6, conv_norm: bn + rope sync bn no final ln bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 21h 4.57/4.32/4.58/4.58
+encoder/decoder no bias + gated mlp + rms_norm, eps=1e-6, conv_norm: bn + rope + tie word emb bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 21h29min 4.70/4.51/4.79/4.79
+encoder/decoder no bias + gated mlp + rms_norm, eps=1e-6, conv_norm: bn + rope + tie word emb + dec no linear bias bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 21h50min 4.67/4.52/4.80/4.80
+encoder/decoder no bias + gated mlp + rms_norm, eps=1e-6, conv_norm: rms norm + rope bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 21h55min 4.70/4.52/4.88/4.87
+encoder/decoder no bias + gated mlp + rms_norm, eps=1e-6, conv_norm: rms norm conv no bias + rope bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 21h29min 4.70/4.50/4.86/4.87
+encoder/decoder no bias + gated mlp + rms_norm, eps=1e-6, conv_norm:bn + conv norm no bias rope+ multiquery bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 21h22min 4.62/4.60/6.26/6.26
+encoder/decoder no bias + gated mlp + rms_norm, eps=1e-6, conv_norm: bn + conv norm bias rope+ multiquery bucket_boundaries: [500, 1000, 1500]
bucket_batch_sizes: [128, 64, 32, 16]
raw 21h20min 4.71/4.49/4.74/4.76

conformer moe见:#2474 (comment)

@Mddct Mddct force-pushed the Mddct-llm-component branch 4 times, most recently from 2e99c5e to 0f3deaa Compare February 26, 2024 06:25
@Mddct Mddct force-pushed the Mddct-llm-component branch from 0f3deaa to 1afd9de Compare February 26, 2024 06:40
@Mddct Mddct force-pushed the Mddct-llm-component branch 4 times, most recently from 2867f5a to 9046622 Compare February 27, 2024 17:40
@Mddct Mddct force-pushed the Mddct-llm-component branch 4 times, most recently from 558fd14 to ea0888f Compare February 28, 2024 09:44
@Mddct Mddct force-pushed the Mddct-llm-component branch from ea0888f to 0dc48f1 Compare February 28, 2024 09:48
@Mddct Mddct changed the title [transformer] bring llm component [WIP][transformer] bring llm component Feb 28, 2024
@Mddct Mddct force-pushed the Mddct-llm-component branch 14 times, most recently from 9f61138 to 522a60a Compare March 1, 2024 02:44
@Mddct Mddct force-pushed the Mddct-llm-component branch from 522a60a to 6dec5bd Compare March 1, 2024 02:50
@Mddct Mddct force-pushed the Mddct-llm-component branch from 7888af5 to ddf648b Compare March 1, 2024 09:06
@Mddct
Copy link
Collaborator Author

Mddct commented Mar 8, 2024

该pr会拆成若干pr 完成

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants