We didn't claim that more dilation is always better (consider the extreme case where the segment length starts from 1). For language tasks, we suggest a segment length of no less than 2048, for a good balance of efficiency and accuracy. In the paper, we use segment lengths of {2048, 4096, 8192, 16384, 32768}.
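To make the roles of segment length and dilation ratio concrete, here is a minimal sketch of which positions each segment keeps. It reflects one reading of the sparsification described in the paper, not the repository's actual code; the function name `dilated_indices` and the `offset` parameter are hypothetical.

```python
def dilated_indices(seq_len, segment_length, dilated_ratio, offset=0):
    """Return, for each segment, the token positions kept after dilation.

    Hypothetical helper for illustration only: the sequence is cut into
    consecutive segments of `segment_length` tokens, and inside each
    segment every `dilated_ratio`-th position is kept, starting at `offset`.
    """
    segments = []
    for start in range(0, seq_len, segment_length):
        end = min(start + segment_length, seq_len)
        segments.append(list(range(start + offset, end, dilated_ratio)))
    return segments


# Segment length 8, dilation ratio 2, on a 16-token sequence.
print(dilated_indices(16, 8, 2))   # [[0, 2, 4, 6], [8, 10, 12, 14]]

# The extreme case mentioned above: with segment length 1,
# every token can only attend to itself.
print(dilated_indices(4, 1, 1))    # [[0], [1], [2], [3]]
```

With very short segments each token sees almost no context, which is one way to read the advice to keep the smallest segment length at 2048 or above.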
As shown in Figure 5 of our paper, the speedup is significant only when the sequence length is greater than 8K. This is because flash-attn is already heavily optimized for dense attention, so attention accounts for only a small share of the total cost when the sequence is relatively short. In that case, any further optimization of the attention is barely visible in the overall speed.
torchrun --nproc_per_node=8 --nnodes=1 train.py ../../../fairseq/data-bin/wikitext-103/ --num-workers 0 --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 4096 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --ddp-backend=c10d --flash-attention --segment-length [1024,2048,4096] --dilated-ratio [1,2,4]
The best perplexity (ppl) on the validation set is 29.11.
torchrun --nproc_per_node=8 --nnodes=1 train.py ../../../fairseq/data-bin/wikitext-103/ --num-workers 0 --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 4096 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --ddp-backend=c10d --flash-attention --segment-length [2048,4096] --dilated-ratio [1,2]
The best perplexity (ppl) on the validation set is 28.08.
torchrun --nproc_per_node=8 --nnodes=1 train.py ../../../fairseq/data-bin/wikitext-103/ --num-workers 0 --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 4096 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --ddp-backend=c10d --flash-attention --segment-length [4096] --dilated-ratio [1]
The best perplexity (ppl) on the validation set is 27.92.
More dilation gives worse perplexity, which does not match the paper.
The training speeds of the three runs above are almost the same, which also does not match the paper.
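As a rough sanity check on the speed observation, the sketch below estimates how many query-key pairs each of the three configurations computes at 4096 tokens, relative to dense attention. It assumes the per-pattern cost scales as N·w/r² (N/w segments, each attending over w/r kept tokens) and that the costs of the patterns are summed; the function names are hypothetical, and causal masking, kernel overheads, and every non-attention layer are ignored.

```python
def attention_cost(seq_len, segment_lengths, dilated_ratios):
    """Estimated number of (query, key) pairs, summed over all patterns.

    Assumption for illustration: a pattern with segment length w and
    dilation ratio r costs (N / w) * (w / r) ** 2 = N * w / r ** 2 pairs.
    """
    total = 0.0
    for w, r in zip(segment_lengths, dilated_ratios):
        w = min(w, seq_len)  # a segment cannot be longer than the sequence
        total += seq_len * w / (r * r)
    return total


def relative_cost(seq_len, segment_lengths, dilated_ratios):
    """Cost relative to dense attention over the full sequence (N * N pairs)."""
    return attention_cost(seq_len, segment_lengths, dilated_ratios) / (seq_len * seq_len)


N = 4096  # --tokens-per-sample in the runs above
print(relative_cost(N, [1024, 2048, 4096], [1, 2, 4]))  # 0.4375
print(relative_cost(N, [2048, 4096], [1, 2]))           # 0.75
print(relative_cost(N, [4096], [1]))                    # 1.0
```

Under these assumptions, even the most dilated run still performs about 44% of the dense attention work at this sequence length, and attention is only one part of each training step, so similar wall-clock speeds across the three runs are plausible.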