ring flash attention with BPT #24

Open
JiaoPL opened this issue Mar 22, 2024 · 3 comments

JiaoPL commented Mar 22, 2024

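Hi @zhuzilin,
I'm trying to integrate BPT into ring flash attention: q, k and v are split by chunk_size, and attention is computed locally over smaller chunks.
Following the forward and backward in ring_flash_attn.py, I implemented blockwise_flash_attn_forward and blockwise_flash_attn_backward. The forward results now match, but the backward has errors. What might be wrong with my backward implementation?
Here is my implementation: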

```python
import torch
from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward
from ring_flash_attn.utils import update_out_and_lse


def blockwise_flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_chunk_size: int,
    k_chunk_size: int,
    softmax_scale,
    dropout_p=0,
    causal=True,
    return_softmax=True
):
    # q, k, v: (batch, seqlen, num_head, head_dim), the flash_attn layout.
    # The causal skip (j > i) below assumes q_chunk_size == k_chunk_size.
    assert q.shape[1] % q_chunk_size == 0
    assert k.shape[1] % k_chunk_size == 0

    num_q_chunk = q.shape[1] // q_chunk_size
    num_k_chunk = k.shape[1] // k_chunk_size
    batch, seqlen, num_head, head_dim = q.shape

    # float32 accumulators for the merged output and log-sum-exp
    block_out = torch.empty(q.shape, dtype=torch.float32, device=q.device)
    block_lse = torch.empty((batch, seqlen, num_head, 1), dtype=torch.float32, device=q.device)

    for i in range(num_q_chunk):
        q_i = q[:, i * q_chunk_size: (i + 1) * q_chunk_size]
        out_i = None
        lse_i = None

        for j in range(num_k_chunk - 1, -1, -1):
            # with causal masking, K/V chunks after the query chunk are fully masked
            if j > i and causal:
                continue

            k_j = k[:, j * k_chunk_size: (j + 1) * k_chunk_size]
            v_j = v[:, j * k_chunk_size: (j + 1) * k_chunk_size]

            out_ij, _, _, _, _, lse_ij, _, _ = _flash_attn_forward(
                q_i,
                k_j,
                v_j,
                dropout_p,
                softmax_scale,
                causal=causal and j == i,  # only the diagonal chunk needs a causal mask
                return_softmax=return_softmax and dropout_p > 0
            )
            # merge this chunk's partial result into the running result via its lse
            out_i, lse_i = update_out_and_lse(out_i, lse_i, out_ij, lse_ij)

        block_out[:, i * q_chunk_size: (i + 1) * q_chunk_size] = out_i
        block_lse[:, i * q_chunk_size: (i + 1) * q_chunk_size] = lse_i

    # lse is returned as (batch, num_head, seqlen), like flash_attn's softmax_lse
    return block_out, block_lse.squeeze(dim=-1).transpose(-1, -2)


def blockwise_flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    q_chunk_size,
    k_chunk_size,
    softmax_lse,
    dq,
    dk,
    dv,
    softmax_scale,
    dropout_p,
    causal=True,
    rng_state=None
):
    assert q.shape[1] % q_chunk_size == 0
    assert k.shape[1] % k_chunk_size == 0

    num_q_chunk = q.shape[1] // q_chunk_size
    num_k_chunk = k.shape[1] // k_chunk_size

    # per-chunk gradient buffers, written by each _flash_attn_backward call
    temp_dq_buffer = torch.empty(q[:, :q_chunk_size].shape, dtype=q.dtype, device=q.device)
    temp_dk_buffer = torch.empty(k[:, :k_chunk_size].shape, dtype=k.dtype, device=k.device)
    temp_dv_buffer = torch.empty(v[:, :k_chunk_size].shape, dtype=v.dtype, device=v.device)

    for i in range(num_q_chunk):
        # out and softmax_lse are the final (fully merged) values for this query chunk
        q_i = q[:, i * q_chunk_size: (i + 1) * q_chunk_size].contiguous()
        dout_i = dout[:, i * q_chunk_size: (i + 1) * q_chunk_size].contiguous()
        out_i = out[:, i * q_chunk_size: (i + 1) * q_chunk_size].contiguous()
        softmax_lse_i = softmax_lse[:, :, i * q_chunk_size: (i + 1) * q_chunk_size].contiguous()

        for j in range(num_k_chunk):
            if j > i and causal:
                continue

            k_j = k[:, j * k_chunk_size: (j + 1) * k_chunk_size].contiguous()
            v_j = v[:, j * k_chunk_size: (j + 1) * k_chunk_size].contiguous()

            _flash_attn_backward(
                dout_i,
                q_i,
                k_j,
                v_j,
                out_i,
                softmax_lse_i,
                temp_dq_buffer,
                temp_dk_buffer,
                temp_dv_buffer,
                dropout_p,
                softmax_scale,
                causal=causal and j == i,
                rng_state=rng_state,
            )

            # update dq dk dv
            # NOTE: dq/dk/dv are accumulated with +=; this function never zeroes
            # them, so the caller's buffers must start at zero.
            dq[:, i * q_chunk_size: (i + 1) * q_chunk_size] += temp_dq_buffer
            dk[:, j * k_chunk_size: (j + 1) * k_chunk_size] += temp_dk_buffer
            dv[:, j * k_chunk_size: (j + 1) * k_chunk_size] += temp_dv_buffer
```
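The forward pass works because partial results over disjoint key chunks can be merged through their log-sum-exp: if a chunk gives (out_b, lse_b), the merged lse is log(e^lse + e^lse_b) and the merged output is e^(lse - new_lse) * out + e^(lse_b - new_lse) * out_b. The pure-Python sketch below applies that rule to a single query row and checks it against unchunked softmax attention. It only illustrates the math that update_out_and_lse performs; the helper names are hypothetical and this is not the repository's code.

```python
import math


def softmax_attention(q, keys, values, scale):
    """Attention output and log-sum-exp for one query row (plain lists)."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    weights = [math.exp(s - lse) for s in scores]
    out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]
    return out, lse


def merge(out, lse, block_out, block_lse):
    """Combine two partial results that attended to disjoint key sets."""
    if out is None:
        return block_out, block_lse
    m = max(lse, block_lse)
    new_lse = m + math.log(math.exp(lse - m) + math.exp(block_lse - m))
    a = math.exp(lse - new_lse)        # rescale the running result
    b = math.exp(block_lse - new_lse)  # rescale the new chunk
    return [a * x + b * y for x, y in zip(out, block_out)], new_lse


def blockwise_attention(q, keys, values, scale, chunk):
    """Attend chunk by chunk and merge, like the inner loop over j."""
    out, lse = None, None
    for start in range(0, len(keys), chunk):
        block_out, block_lse = softmax_attention(
            q, keys[start:start + chunk], values[start:start + chunk], scale
        )
        out, lse = merge(out, lse, block_out, block_lse)
    return out, lse


q = [0.3, -1.2]
keys = [[0.5, 0.1], [-0.4, 0.9], [1.1, -0.7], [0.2, 0.2]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
full_out, full_lse = softmax_attention(q, keys, values, scale=0.7)
blk_out, blk_lse = blockwise_attention(q, keys, values, scale=0.7, chunk=2)
print(full_out, blk_out)  # equal up to floating-point rounding
```

Because the merge is exact, the final out and lse handed to the backward pass are the same as for unchunked attention, which is why each backward call can use them directly.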

These replace _flash_attn_forward in ring_flash_attn_forward and _flash_attn_backward in ring_flash_attn_backward, respectively.
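One difference worth checking when swapping the calls (a hypothesis, not confirmed in this thread): _flash_attn_backward writes its result into the dq/dk/dv tensors it receives, whereas blockwise_flash_attn_backward adds into them with += and never clears them. If the ring wrapper reuses the same buffers on every step, each step's result would then include everything left over from earlier steps. The toy sketch below uses scalar "gradients" and hypothetical function names to show the effect, and that clearing the buffer first restores overwrite semantics.

```python
def reference_step(buf, chunk_grads):
    """Overwrite semantics: afterwards buf holds only this step's gradient."""
    for idx in range(len(buf)):
        buf[idx] = sum(grads[idx] for grads in chunk_grads)


def accumulate_step(buf, chunk_grads):
    """Adds every chunk's gradient into buf with +=, without clearing it first."""
    for grads in chunk_grads:
        for idx, g in enumerate(grads):
            buf[idx] += g


def accumulate_step_zeroed(buf, chunk_grads):
    """Clears buf before accumulating, which restores overwrite semantics."""
    for idx in range(len(buf)):
        buf[idx] = 0.0
    accumulate_step(buf, chunk_grads)


def ring_like_caller(step_fn, steps):
    """Reuses one buffer for every step and sums what each step leaves in it."""
    buf = [0.0, 0.0]
    total = [0.0, 0.0]
    for chunk_grads in steps:
        step_fn(buf, chunk_grads)
        total = [t + b for t, b in zip(total, buf)]
    return total


# two ring steps, each made of two chunks of toy gradients
steps = [[[1.0, 2.0], [0.5, 0.5]], [[3.0, 1.0], [1.0, 1.0]]]
print(ring_like_caller(reference_step, steps))          # [5.5, 4.5]
print(ring_like_caller(accumulate_step, steps))         # [7.0, 7.0]
print(ring_like_caller(accumulate_step_zeroed, steps))  # [5.5, 4.5]
```

In the PyTorch code the corresponding change to try would be calling dq.zero_(), dk.zero_() and dv.zero_() at the top of blockwise_flash_attn_backward. If the caller allocates these buffers with torch.empty, zeroing also removes any uninitialized values on the first step.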

Here are my test results:

```
##############################
# forward:
##############################
out: max 2.896484375, mean 0.0203094482421875
lse: max 10.417832374572754, mean 9.204237937927246
out diff:
[0] max 0.00048828125, mean 8.881092071533203e-06
[1] max 0.0001220703125, mean 7.450580596923828e-06
[2] max 0.0001220703125, mean 5.9604644775390625e-06
[3] max 6.103515625e-05, mean 5.066394805908203e-06
[4] max 6.103515625e-05, mean 4.5299530029296875e-06
[5] max 6.103515625e-05, mean 4.112720489501953e-06
[6] max 6.103515625e-05, mean 3.814697265625e-06
[7] max 6.103515625e-05, mean 3.516674041748047e-06
lse diff:
[0] max 9.5367431640625e-07, mean 1.645181413323371e-07
[1] max 9.5367431640625e-07, mean 2.641230878452916e-07
[2] max 1.9073486328125e-06, mean 3.0044466825529526e-07
[3] max 1.9073486328125e-06, mean 3.3890827921823075e-07
[4] max 1.9073486328125e-06, mean 3.8137659430503845e-07
[5] max 1.9073486328125e-06, mean 4.0913002408160537e-07
[6] max 1.9073486328125e-06, mean 4.272908142866072e-07
[7] max 1.9073486328125e-06, mean 4.6798959374427795e-07
##############################
# backward:
##############################
load_dq:
[0] max 2.783203125, mean 0.052520751953125
[1] max 0.3310546875, mean 0.02398681640625
[2] max 0.2083740234375, mean 0.0184478759765625
[3] max 0.1162109375, mean 0.0155792236328125
[4] max 0.13330078125, mean 0.01374053955078125
[5] max 0.1204833984375, mean 0.01241302490234375
[6] max 0.11260986328125, mean 0.0114288330078125
[7] max 0.0775146484375, mean 0.01064300537109375
dq diff:
[0] max 0.005859375, mean 7.49826431274414e-05
[1] max 0.186279296875, mean 0.01239776611328125
[2] max 0.1973876953125, mean 0.01953125
[3] max 0.235107421875, mean 0.0253143310546875
[4] max 0.30615234375, mean 0.0301361083984375
[5] max 0.52392578125, mean 0.03436279296875
[6] max 0.56689453125, mean 0.038177490234375
[7] max 0.3955078125, mean 0.041748046875
load_dk:
[0] max 2.654296875, mean 0.05340576171875
[1] max 0.256591796875, mean 0.021697998046875
[2] max 0.169921875, mean 0.01535797119140625
[3] max 0.13330078125, mean 0.0116729736328125
[4] max 0.09124755859375, mean 0.0090484619140625
[5] max 0.1158447265625, mean 0.006908416748046875
[6] max 0.050384521484375, mean 0.00492095947265625
[7] max 0.03936767578125, mean 0.002498626708984375
dk diff:
[0] max 0.253173828125, mean 0.03192138671875
[1] max 0.16845703125, mean 0.0232696533203125
[2] max 0.130126953125, mean 0.017364501953125
[3] max 0.1097412109375, mean 0.012786865234375
[4] max 0.10797119140625, mean 0.00893402099609375
[5] max 0.049530029296875, mean 0.005580902099609375
[6] max 0.039337158203125, mean 0.002498626708984375
[7] max 1.52587890625e-05, mean 3.5762786865234375e-07
load_dv:
[0] max 5.89453125, mean 0.05450439453125
[1] max 0.1951904296875, mean 0.021484375
[2] max 0.11883544921875, mean 0.01525115966796875
[3] max 0.10003662109375, mean 0.01158905029296875
[4] max 0.07550048828125, mean 0.00901031494140625
[5] max 0.06658935546875, mean 0.006816864013671875
[6] max 0.041015625, mean 0.00492095947265625
[7] max 0.041961669921875, mean 0.002475738525390625
dv diff:
[0] max 0.3232421875, mean 0.042572021484375
[1] max 0.21240234375, mean 0.03094482421875
[2] max 0.1527099609375, mean 0.0223236083984375
[3] max 0.1075439453125, mean 0.015625
[4] max 0.08245849609375, mean 0.010223388671875
[5] max 0.0447998046875, mean 0.005950927734375
[6] max 0.0419921875, mean 0.002475738525390625
[7] max 3.0517578125e-05, mean 3.5762786865234375e-07
```
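A pattern in these numbers: the small errors are dq on rank 0 and dk/dv for block 7. Assuming the causal schedule of ring_flash_attn_backward (at step s, rank r holds the K/V block of rank (r - s) mod world_size and skips the step when s > r), the sketch below counts how many backward calls contribute to each rank's dq and to each K/V block's dk/dv. The accurate entries are exactly the ones produced by a single call, which is consistent with, though it does not prove, state carrying over between calls. The function names are hypothetical.

```python
def causal_ring_schedule(world_size):
    """Map (rank, step) to the K/V block that rank uses, or None if the step is skipped."""
    schedule = {}
    for rank in range(world_size):
        for step in range(world_size):
            kv_rank = (rank - step) % world_size
            schedule[(rank, step)] = kv_rank if step <= rank else None
    return schedule


def usage_counts(world_size):
    """Count backward calls contributing to each rank's dq and each block's dk/dv."""
    q_uses = [0] * world_size
    kv_uses = [0] * world_size
    for (rank, _step), kv_rank in causal_ring_schedule(world_size).items():
        if kv_rank is not None:
            q_uses[rank] += 1
            kv_uses[kv_rank] += 1
    return q_uses, kv_uses


q_uses, kv_uses = usage_counts(8)
print(q_uses)   # [1, 2, 3, 4, 5, 6, 7, 8]
print(kv_uses)  # [8, 7, 6, 5, 4, 3, 2, 1]
```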

GeneZC commented Mar 24, 2024

Do we even still need the BPT if we have the ring attention implemented in this repo? @zhuzilin

I personally think BPT is a single-GPU version of ring attention, right?

Edenzzzz

> Do we even still need the BPT if we have the ring attention implemented in this repo? @zhuzilin
>
> I personally think BPT is a single-GPU version of ring attention, right?

That's right, BPT is inherently supported by ring attention. We do not need another implementation.
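One way to see this: under causal masking, the (query block, K/V block) pairs that a ring over N devices computes are exactly the (i, j) chunk pairs that a single-device blockwise loop with N chunks computes; only where each pair runs differs. The sketch below checks this, assuming the ring_flash_attn schedule (rank r holds K/V block (r - s) mod N at step s and skips steps with s > r). Function names are hypothetical.

```python
def ring_pairs(n):
    """(query block, K/V block) pairs computed by a causal ring over n devices."""
    pairs = set()
    for rank in range(n):
        for step in range(n):
            if step <= rank:  # steps with step > rank hold future K/V and are skipped
                pairs.add((rank, (rank - step) % n))
    return pairs


def blockwise_pairs(n):
    """(query chunk, K/V chunk) pairs computed by a causal blockwise loop with n chunks."""
    return {(i, j) for i in range(n) for j in range(n) if j <= i}


for n in (1, 2, 4, 8):
    print(n, ring_pairs(n) == blockwise_pairs(n))  # True for each n
```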

zhuzilin (Owner) commented Apr 18, 2024

I'm not sure that splitting the sequence on each device into smaller blocks would save memory (we still need to keep buffers, and flash_attn itself seems to use memory linear in sequence length) or speed things up (it would launch smaller kernels).
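To put rough numbers on the memory argument, the sketch below compares, for an assumed per-device shape (batch 1, 32768 tokens, 32 heads, head_dim 128, bf16 inputs; all hypothetical), the size of one input tensor, the float32 output accumulator that blockwise_flash_attn_forward allocates, its float32 lse buffer, and the full seqlen x seqlen score matrix that naive attention would materialize. Flash attention never materializes that matrix, so its memory already grows linearly with sequence length, and extra chunking mainly adds accumulators.

```python
def tensor_bytes(*shape, bytes_per_elem):
    """Bytes needed for a dense tensor of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_elem


MiB = 1024 ** 2
batch, seqlen, num_head, head_dim = 1, 32768, 32, 128  # assumed per-device shape

q_bytes = tensor_bytes(batch, seqlen, num_head, head_dim, bytes_per_elem=2)        # bf16 input
out_acc_bytes = tensor_bytes(batch, seqlen, num_head, head_dim, bytes_per_elem=4)  # float32 out accumulator
lse_bytes = tensor_bytes(batch, seqlen, num_head, bytes_per_elem=4)                # float32 lse
scores_bytes = tensor_bytes(batch, num_head, seqlen, seqlen, bytes_per_elem=2)     # naive S = QK^T in bf16

for name, size in [("q", q_bytes), ("out accumulator", out_acc_bytes),
                   ("lse", lse_bytes), ("naive scores", scores_bytes)]:
    print(f"{name}: {size / MiB:.1f} MiB")
```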
