diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 9b2f48aac6c39..cb4c0e085f453 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -540,7 +540,7 @@
   inplace : (out_grad -> x_grad)

 - backward_op : flash_attn_grad
-  forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
+  forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
@@ -550,15 +550,15 @@
     func : flash_attn_grad
     data_type: q

-- backward_op : flash_attn_raw_grad
-  forward : flash_attn_raw (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
+- backward_op : flash_attn_unpadded_grad
+  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
     func : FlashAttnGradInferMeta
     param : [q, k, v]
   kernel :
-    func : flash_attn_raw_grad
+    func : flash_attn_unpadded_grad
     data_type: q

 - backward_op : flip_grad
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 63bbbf2170894..d0f450c7868f8 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -530,25 +530,27 @@

 - op : flash_attn
   args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false)
-  output : Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
+  output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]
   kernel :
     func : flash_attn
     data_type : q
+  intermediate : softmax_lse, seed_offset
   backward : flash_attn_grad

-- op : flash_attn_raw
+- op : flash_attn_unpadded
   args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false)
-  output : Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
+  output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]
   kernel :
-    func : flash_attn_raw
+    func : flash_attn_unpadded
     data_type : q
-  backward : flash_attn_raw_grad
+  intermediate : softmax_lse, seed_offset
+  backward : flash_attn_unpadded_grad

 - op : flip
   args : (Tensor x, int[] axis)
diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc
index d2f1c78eb0b86..c1bbc0cff2dd4 100644
--- a/paddle/phi/infermeta/ternary.cc
+++ b/paddle/phi/infermeta/ternary.cc
@@ -259,8 +259,8 @@ void FlashAttnInferMeta(const MetaTensor& q,
                         const MetaTensor& k,
                         const MetaTensor& v,
                         MetaTensor* out,
-                        MetaTensor* softmax_lse,
                         MetaTensor* softmax,
+                        MetaTensor* softmax_lse,
                         MetaTensor* seed_offset) {
   out->set_dims(q.dims());
   out->set_dtype(q.dtype());
diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h
index 338579930ae47..47c4b9826da4a 100644
--- a/paddle/phi/infermeta/ternary.h
+++ b/paddle/phi/infermeta/ternary.h
@@ -67,8 +67,8 @@ void FlashAttnInferMeta(const MetaTensor& q,
                         const MetaTensor& k,
                         const MetaTensor& v,
                         MetaTensor* out,
-                        MetaTensor* softmax_lse,
                         MetaTensor* softmax,
+                        MetaTensor* softmax_lse,
                         MetaTensor* seed_offset);

 void InstanceNormInferMeta(const MetaTensor& x,
diff --git a/paddle/phi/kernels/flash_attn_grad_kernel.h b/paddle/phi/kernels/flash_attn_grad_kernel.h
index d22ddb0ef1840..ba3a6020e4545 100644
--- a/paddle/phi/kernels/flash_attn_grad_kernel.h
+++ b/paddle/phi/kernels/flash_attn_grad_kernel.h
@@ -20,24 +20,24 @@
 namespace phi {

 template <typename T, typename Context>
-void FlashAttnRawGradKernel(const Context& ctx,
-                            const DenseTensor& q,
-                            const DenseTensor& k,
-                            const DenseTensor& v,
-                            const DenseTensor& cu_seqlens_q,
-                            const DenseTensor& cu_seqlens_k,
-                            const DenseTensor& out,
-                            const DenseTensor& softmax_lse,
-                            const DenseTensor& seed_offset,
-                            const DenseTensor& dout,
-                            int64_t max_seqlen_q,
-                            int64_t max_seqlen_k,
-                            float scale,
-                            float dropout,
-                            bool causal,
-                            DenseTensor* dq,
-                            DenseTensor* dk,
-                            DenseTensor* dv);
+void FlashAttnUnpaddedGradKernel(const Context& ctx,
+                                 const DenseTensor& q,
+                                 const DenseTensor& k,
+                                 const DenseTensor& v,
+                                 const DenseTensor& cu_seqlens_q,
+                                 const DenseTensor& cu_seqlens_k,
+                                 const DenseTensor& out,
+                                 const DenseTensor& softmax_lse,
+                                 const DenseTensor& seed_offset,
+                                 const DenseTensor& dout,
+                                 int64_t max_seqlen_q,
+                                 int64_t max_seqlen_k,
+                                 float scale,
+                                 float dropout,
+                                 bool causal,
+                                 DenseTensor* dq,
+                                 DenseTensor* dk,
+                                 DenseTensor* dv);

 template <typename T, typename Context>
 void FlashAttnGradKernel(const Context& ctx,
diff --git a/paddle/phi/kernels/flash_attn_kernel.h b/paddle/phi/kernels/flash_attn_kernel.h
index dd6db04d45cd5..9027c0f6fa905 100644
--- a/paddle/phi/kernels/flash_attn_kernel.h
+++ b/paddle/phi/kernels/flash_attn_kernel.h
@@ -20,22 +20,22 @@
 namespace phi {

 template <typename T, typename Context>
-void FlashAttnRawKernel(const Context& ctx,
-                        const DenseTensor& q,
-                        const DenseTensor& k,
-                        const DenseTensor& v,
-                        const DenseTensor& cu_seqlens_q,
-                        const DenseTensor& cu_seqlens_k,
-                        int64_t max_seqlen_q,
-                        int64_t max_seqlen_k,
-                        float scale,
-                        float dropout,
-                        bool causal,
-                        bool return_softmax,
-                        DenseTensor* out,
-                        DenseTensor* softmax_lse,
-                        DenseTensor* softmax,
-                        DenseTensor* seed_offset);
+void FlashAttnUnpaddedKernel(const Context& ctx,
+                             const DenseTensor& q,
+                             const DenseTensor& k,
+                             const DenseTensor& v,
+                             const DenseTensor& cu_seqlens_q,
+                             const DenseTensor& cu_seqlens_k,
+                             int64_t max_seqlen_q,
+                             int64_t max_seqlen_k,
+                             float scale,
+                             float dropout,
+                             bool causal,
+                             bool return_softmax,
+                             DenseTensor* out,
+                             DenseTensor* softmax,
+                             DenseTensor* softmax_lse,
+                             DenseTensor* seed_offset);

 template <typename T, typename Context>
 void FlashAttnKernel(const Context& ctx,
@@ -46,8 +46,8 @@ void FlashAttnKernel(const Context& ctx,
                      bool causal,
                      bool return_softmax,
                      DenseTensor* out,
-                     DenseTensor* softmax_lse,
                      DenseTensor* softmax,
+                     DenseTensor* softmax_lse,
                      DenseTensor* seed_offset);

 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
index 038557d9feb21..049f89d7507ed 100644
--- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -28,24 +28,24 @@
 namespace phi {

 template <typename T, typename Context>
-void FlashAttnRawGradKernel(const Context& ctx,
-                            const DenseTensor& q,
-                            const DenseTensor& k,
-                            const DenseTensor& v,
-                            const DenseTensor& cu_seqlens_q,
-                            const DenseTensor& cu_seqlens_k,
-                            const DenseTensor& out,
-                            const DenseTensor& softmax_lse,
-                            const DenseTensor& seed_offset,
-                            const DenseTensor& dout,
-                            int64_t max_seqlen_q,
-                            int64_t max_seqlen_k,
-                            float scale,
-                            float dropout,
-                            bool causal,
-                            DenseTensor* dq,
-                            DenseTensor* dk,
-                            DenseTensor* dv) {
+void FlashAttnUnpaddedGradKernel(const Context& ctx,
+                                 const DenseTensor& q,
+                                 const DenseTensor& k,
+                                 const DenseTensor& v,
+                                 const DenseTensor& cu_seqlens_q,
+                                 const DenseTensor& cu_seqlens_k,
+                                 const DenseTensor& out,
+                                 const DenseTensor& softmax_lse,
+                                 const DenseTensor& seed_offset,
+                                 const DenseTensor& dout,
+                                 int64_t max_seqlen_q,
+                                 int64_t max_seqlen_k,
+                                 float scale,
+                                 float dropout,
+                                 bool causal,
+                                 DenseTensor* dq,
+                                 DenseTensor* dk,
+                                 DenseTensor* dv) {
 #ifdef PADDLE_WITH_FLASHATTN
   ctx.template Alloc<T>(dq);
   ctx.template Alloc<T>(dk);
@@ -202,34 +202,34 @@ void FlashAttnGradKernel(const Context& ctx,
   ArangeNullaryKernel<int64_t, Context>(
       ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);

-  FlashAttnRawGradKernel<T, Context>(ctx,
-                                     q_t_s,
-                                     k_t_s,
-                                     v_t_s,
-                                     cu_seqlens_q,
-                                     cu_seqlens_k,
-                                     out,
-                                     softmax_lse,
-                                     seed_offset,
-                                     dout,
-                                     seq_len_q,
-                                     seq_len_k,
-                                     scale,
-                                     dropout,
-                                     causal,
-                                     dq,
-                                     dk,
-                                     dv);
+  FlashAttnUnpaddedGradKernel<T, Context>(ctx,
+                                          q_t_s,
+                                          k_t_s,
+                                          v_t_s,
+                                          cu_seqlens_q,
+                                          cu_seqlens_k,
+                                          out,
+                                          softmax_lse,
+                                          seed_offset,
+                                          dout,
+                                          seq_len_q,
+                                          seq_len_k,
+                                          scale,
+                                          dropout,
+                                          causal,
+                                          dq,
+                                          dk,
+                                          dv);

 #endif
 }

 }  // namespace phi

-PD_REGISTER_KERNEL(flash_attn_raw_grad,
+PD_REGISTER_KERNEL(flash_attn_unpadded_grad,
                    GPU,
                    ALL_LAYOUT,
-                   phi::FlashAttnRawGradKernel,
+                   phi::FlashAttnUnpaddedGradKernel,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {
   kernel->InputAt(7).SetBackend(phi::Backend::CPU);  // seed_offset
diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu
index ef8bd2a98d15e..c77248122a8c8 100644
--- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu
@@ -31,22 +31,22 @@
 namespace phi {

 template <typename T, typename Context>
-void FlashAttnRawKernel(const Context& ctx,
-                        const DenseTensor& q,
-                        const DenseTensor& k,
-                        const DenseTensor& v,
-                        const DenseTensor& cu_seqlens_q,
-                        const DenseTensor& cu_seqlens_k,
-                        int64_t max_seqlen_q,
-                        int64_t max_seqlen_k,
-                        float scale,
-                        float dropout,
-                        bool causal,
-                        bool return_softmax,
-                        DenseTensor* out,
-                        DenseTensor* softmax_lse,
-                        DenseTensor* softmax,
-                        DenseTensor* seed_offset) {
+void FlashAttnUnpaddedKernel(const Context& ctx,
+                             const DenseTensor& q,
+                             const DenseTensor& k,
+                             const DenseTensor& v,
+                             const DenseTensor& cu_seqlens_q,
+                             const DenseTensor& cu_seqlens_k,
+                             int64_t max_seqlen_q,
+                             int64_t max_seqlen_k,
+                             float scale,
+                             float dropout,
+                             bool causal,
+                             bool return_softmax,
+                             DenseTensor* out,
+                             DenseTensor* softmax,
+                             DenseTensor* softmax_lse,
+                             DenseTensor* seed_offset) {
 #ifdef PADDLE_WITH_FLASHATTN
   ctx.template Alloc<T>(out);

@@ -185,8 +185,8 @@ void FlashAttnKernel(const Context& ctx,
                      bool causal,
                      bool return_softmax,
                      DenseTensor* out,
-                     DenseTensor* softmax_lse,
                      DenseTensor* softmax,
+                     DenseTensor* softmax_lse,
                      DenseTensor* seed_offset) {
 #ifdef PADDLE_WITH_FLASHATTN
   // q,k,v [batch_size, seq_len, num_heads, head_dim]
@@ -224,32 +224,32 @@
   ArangeNullaryKernel<int64_t, Context>(
       ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);

-  FlashAttnRawKernel<T, Context>(ctx,
-                                 q_t_s,
-                                 k_t_s,
-                                 v_t_s,
-                                 cu_seqlens_q,
-                                 cu_seqlens_k,
-                                 seq_len_q,
-                                 seq_len_k,
-                                 scale,
-                                 dropout,
-                                 causal,
-                                 return_softmax,
-                                 out,
-                                 softmax_lse,
-                                 softmax,
-                                 seed_offset);
+  FlashAttnUnpaddedKernel<T, Context>(ctx,
+                                      q_t_s,
+                                      k_t_s,
+                                      v_t_s,
+                                      cu_seqlens_q,
+                                      cu_seqlens_k,
+                                      seq_len_q,
+                                      seq_len_k,
+                                      scale,
+                                      dropout,
+                                      causal,
+                                      return_softmax,
+                                      out,
+                                      softmax,
+                                      softmax_lse,
+                                      seed_offset);

 #endif
 }

 }  // namespace phi

-PD_REGISTER_KERNEL(flash_attn_raw,
+PD_REGISTER_KERNEL(flash_attn_unpadded,
                    GPU,
                    ALL_LAYOUT,
-                   phi::FlashAttnRawKernel,
+                   phi::FlashAttnUnpaddedKernel,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
diff --git a/python/paddle/fluid/tests/unittests/test_flash_attention.py b/python/paddle/fluid/tests/unittests/test_flash_attention.py
index 223a17c797b72..034ccb22f8045 100644
--- a/python/paddle/fluid/tests/unittests/test_flash_attention.py
+++ b/python/paddle/fluid/tests/unittests/test_flash_attention.py
@@ -22,7 +22,10 @@
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 import paddle.nn.functional as F
-from paddle.nn.functional.flash_attention import flash_attention
+from paddle.nn.functional.flash_attention import (
+    flash_attention,
+    flash_attn_unpadded,
+)


 def get_cuda_version():
@@ -66,9 +69,9 @@ def setUp(self):
         self.causal = False
         self.return_softmax = False

-    def test_raw(self):
+    def test_unpadded(self):
         print(
-            f"Test Raw case shape {self.shape} dtype {self.dtype} causal {self.causal}"
+            f"Test unpadded case shape {self.shape} dtype {self.dtype} causal {self.causal}"
         )

         paddle.disable_static()
@@ -92,7 +95,7 @@ def test_raw(self):

         cu_q = paddle.arange(0, (bs + 1) * ms, ms, dtype='int32')
         qq = paddle.reshape(q, [bs * ms, nh, hd])
-        out, _, _, _ = paddle._C_ops.flash_attn_raw(
+        out, _ = flash_attn_unpadded(
             qq,
             qq,
             qq,
@@ -116,6 +119,45 @@ def test_raw(self):
             q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03
         )

+        # test static
+        paddle.enable_static()
+
+        with paddle.static.program_guard(paddle.static.Program()):
+            qs = paddle.static.data(
+                name="q", shape=self.shape, dtype=self.dtype
+            )
+
+            cu_q = paddle.arange(0, (bs + 1) * ms, ms, dtype='int32')
+            qs = paddle.reshape(qs, [bs * ms, nh, hd])
+
+            outs, softmax = flash_attn_unpadded(
+                qs,
+                qs,
+                qs,
+                cu_q,
+                cu_q,
+                ms,
+                ms,
+                scale,
+                self.dropout,
+                self.causal,
+                self.return_softmax,
+            )
+
+            exe = fluid.Executor(self.place)
+            fetches_result = exe.run(
+                feed={
+                    "q": query.astype('float16'),
+                    "k": query.astype('float16'),
+                    "v": query.astype('float16'),
+                },
+                fetch_list=[outs],
+            )
+
+            np.testing.assert_allclose(
+                fetches_result[0], out_, rtol=5e-03, atol=1e-03
+            )
+
     def test_all(self):
         print(
             f"Test case shape {self.shape} dtype {self.dtype} causal {self.causal}"
         )
diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 0bda34c2436de..f34041e18958f 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -78,12 +78,7 @@ def flash_attention(
             print(output)
     """
     if in_dynamic_mode():
-        (
-            result_attention,
-            result_softmax_lse,
-            result_softmax,
-            seed_offset,
-        ) = _C_ops.flash_attn(
+        (result_attention, result_softmax,) = _C_ops.flash_attn(
             query,
             key,
             value,
@@ -121,3 +116,126 @@ def flash_attention(
         },
     )
     return out, softmax
+
+
+def flash_attn_unpadded(
+    query,
+    key,
+    value,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    scale,
+    dropout=0.0,
+    causal=False,
+    return_softmax=False,
+    name=None,
+):
+    r"""
+    The equation is:
+
+    .. math::
+
+        result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V
+
+    where ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
+    The dimensions of the three parameters are the same.
+    ``d`` represents the size of the last dimension of the three parameters.
+
+    Warning:
+        This API only supports inputs with dtype float16 and bfloat16.
+
+    Args:
+        query(Tensor): The query tensor in the Attention module.
+                        3-D tensor with shape:
+                        [total_seq_len, num_heads, head_dim].
+                        The dtype can be float16 or bfloat16.
+        key(Tensor): The key tensor in the Attention module.
+                        3-D tensor with shape:
+                        [total_seq_len, num_heads, head_dim].
+                        The dtype can be float16 or bfloat16.
+        value(Tensor): The value tensor in the Attention module.
+                        3-D tensor with shape:
+                        [total_seq_len, num_heads, head_dim].
+                        The dtype can be float16 or bfloat16.
+        cu_seqlens_q(Tensor): The cumulative sequence lengths of the sequences in the batch,
+                        used to index query.
+        cu_seqlens_k(Tensor): The cumulative sequence lengths of the sequences in the batch,
+                        used to index key and value.
+        max_seqlen_q(int): Maximum sequence length of query in the batch.
+        max_seqlen_k(int): Maximum sequence length of key/value in the batch.
+        scale(float): The scaling of QK^T before applying softmax.
+        dropout(float): The dropout ratio.
+        causal(bool): Whether to enable causal mode.
+        return_softmax(bool): Whether to return softmax.
+        name(str, optional): The default value is None. Normally there is no need for user
+                        to set this property. For more information, please refer to
+                        :ref:`api_guide_Name`.
+
+    Returns:
+        out(Tensor): The attention tensor.
+                     3-D tensor with shape: [total_seq_len, num_heads, head_dim].
+                     The dtype can be float16 or bfloat16.
+        softmax(Tensor): The softmax tensor. None if return_softmax is False.
+
+    Examples:
+        .. code-block:: python
+
+            # required: skiptest
+            import paddle
+
+            q = paddle.rand((128, 2, 16), dtype=paddle.float16)
+            cu = paddle.arange(0, 256, 128, dtype='int32')
+            output = paddle.nn.functional.flash_attn_unpadded(q, q, q, cu, cu, 128, 128, 0.25)
+            print(output)
+    """
+    if in_dynamic_mode():
+        (result_attention, result_softmax,) = _C_ops.flash_attn_unpadded(
+            query,
+            key,
+            value,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            scale,
+            dropout,
+            causal,
+            return_softmax,
+        )
+        return result_attention, result_softmax
+
+    helper = LayerHelper('flash_attn_unpadded', **locals())
+    dtype = helper.input_dtype(input_param_name='q')
+    out = helper.create_variable_for_type_inference(dtype)
+    softmax = helper.create_variable_for_type_inference(dtype)
+    softmax_lse = helper.create_variable_for_type_inference(paddle.float32)
+    seed_offset = helper.create_variable_for_type_inference(paddle.int64)
+    inputs = {
+        'q': query,
+        'k': key,
+        'v': value,
+        'cu_seqlens_q': cu_seqlens_q,
+        'cu_seqlens_k': cu_seqlens_k,
+    }
+    outputs = {
+        'out': out,
+        'softmax': softmax,
+        'softmax_lse': softmax_lse,
+        'seed_offset': seed_offset,
+    }
+    helper.append_op(
+        type='flash_attn_unpadded',
+        inputs=inputs,
+        outputs=outputs,
+        attrs={
+            'max_seqlen_q': max_seqlen_q,
+            'max_seqlen_k': max_seqlen_k,
+            'scale': scale,
+            'dropout': dropout,
+            'causal': causal,
+            'return_softmax': return_softmax,
+        },
+    )
+    return out, softmax
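Usage sketch (not part of the patch): the snippet below shows how the renamed flash_attn_unpadded API might be called in dynamic mode. The import path follows the updated unit test; the batch layout, shapes, and scale value are illustrative assumptions, and running it requires a CUDA build with FlashAttention enabled.

# Minimal sketch of calling the renamed flash_attn_unpadded API in dynamic mode.
# Assumes a CUDA build with FlashAttention support; shapes and scale are illustrative.
import paddle
from paddle.nn.functional.flash_attention import flash_attn_unpadded

bs, seq_len, num_heads, head_dim = 2, 128, 8, 16

# Unpadded layout packs every sequence into one [total_seq_len, num_heads, head_dim] tensor.
q = paddle.rand((bs * seq_len, num_heads, head_dim), dtype=paddle.float16)

# Cumulative sequence lengths, e.g. [0, 128, 256] for two sequences of length 128.
cu_seqlens = paddle.arange(0, (bs + 1) * seq_len, seq_len, dtype='int32')

scale = head_dim ** -0.5  # 1 / sqrt(head_dim), as in the unit test above

# With softmax_lse and seed_offset declared as intermediate outputs, the wrapper
# returns (out, softmax); softmax is only meaningful when return_softmax=True.
out, softmax = flash_attn_unpadded(
    q, q, q,
    cu_seqlens, cu_seqlens,
    seq_len, seq_len,
    scale,
    dropout=0.0,
    causal=False,
    return_softmax=False,
)
print(out.shape)  # [256, 8, 16]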