ring flash attention with BPT #24
Comments
Do we even still need BPT if we have ring attention implemented in this repo? @zhuzilin I personally think BPT is a single-GPU version of ring attention, right?
That's right, BPT is inherently supported by ring attention. We do not need another implementation.
I'm not sure splitting the sequence length on each device into blocks could save memory (because we still need to save buffers, and flash_attn itself seems to use memory linear in the sequence length) or speed things up (because it would just call smaller kernels).
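To make the comparison concrete, here is a rough sketch of what splitting only the local query sequence into blocks would look like, assuming the public flash_attn_func API and non-causal attention for simplicity; chunked_local_attention and chunk_size are illustrative names, not repo options. Each iteration launches a smaller kernel over the same k/v, so the k/v buffers and the output stay linear in the local sequence length.

```python
# Rough sketch, assuming flash_attn>=2 and non-causal attention for simplicity.
import torch
from flash_attn import flash_attn_func


def chunked_local_attention(q, k, v, chunk_size=1024):
    # q, k, v: fp16/bf16 tensors of shape (batch, seqlen, nheads, head_dim) on one GPU.
    # Each query chunk still attends to the full local k/v, so this only launches
    # smaller kernels; the k/v buffers and the output remain linear in seqlen.
    outs = [flash_attn_func(q_blk, k, v, causal=False)
            for q_blk in q.split(chunk_size, dim=1)]
    return torch.cat(outs, dim=1)
```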
Hi~ @zhuzilin
I'm trying to integrate BPT into ring flash attention: split q/k/v by a chunk_size and compute attention locally over smaller chunks.
Following the forward and backward in ring_flash_attn.py, I implemented
blockwise_flash_attn_forward
and blockwise_flash_attn_backward.
The forward currently matches in precision, but the backward has an error. What problems might there be in my backward implementation? Here is my implementation:
It replaces _flash_attn_forward inside ring_flash_attn_forward and _flash_attn_backward inside ring_flash_attn_backward, respectively.
Below are my test results: