rename flash_attn_raw to flash_attn_unpadded (#51704)
* rename flash_attn_raw to flash_attn_unpadded

* fix static api

* fix static return
kuizhiqing authored Mar 16, 2023
1 parent 86bf827 commit 0b778bd
Showing 10 changed files with 291 additions and 129 deletions.
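
For downstream C++ call sites the migration is mechanical: the kernel symbols lose the Raw suffix in favor of Unpadded, and the softmax / softmax_lse output pointers swap places. A minimal caller-side sketch follows; the wrapper function and its argument names are illustrative and not part of this commit, only the phi::FlashAttnUnpaddedKernel signature comes from the headers changed below.

#include "paddle/phi/kernels/flash_attn_kernel.h"

// Hypothetical helper: forwards to the renamed unpadded kernel with the
// attention hyper-parameters left at the defaults declared in ops.yaml.
template <typename T, typename Context>
void CallFlashAttnUnpadded(const Context& ctx,
                           const phi::DenseTensor& q,
                           const phi::DenseTensor& k,
                           const phi::DenseTensor& v,
                           const phi::DenseTensor& cu_seqlens_q,
                           const phi::DenseTensor& cu_seqlens_k,
                           int64_t max_seqlen_q,
                           int64_t max_seqlen_k,
                           float scale,
                           phi::DenseTensor* out,
                           phi::DenseTensor* softmax,
                           phi::DenseTensor* softmax_lse,
                           phi::DenseTensor* seed_offset) {
  // Before this commit: phi::FlashAttnRawKernel<T, Context>(..., out, softmax_lse, softmax, seed_offset);
  // After: renamed kernel, and softmax now precedes softmax_lse in the output list.
  phi::FlashAttnUnpaddedKernel<T, Context>(ctx,
                                           q, k, v,
                                           cu_seqlens_q, cu_seqlens_k,
                                           max_seqlen_q, max_seqlen_k,
                                           scale,
                                           /*dropout=*/0.0f,
                                           /*causal=*/false,
                                           /*return_softmax=*/false,
                                           out, softmax, softmax_lse, seed_offset);
}

The gradient entry point changes the same way: FlashAttnRawGradKernel becomes FlashAttnUnpaddedGradKernel with an otherwise identical parameter list.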
8 changes: 4 additions & 4 deletions paddle/phi/api/yaml/backward.yaml
@@ -540,7 +540,7 @@
   inplace : (out_grad -> x_grad)

 - backward_op : flash_attn_grad
-  forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
+  forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
@@ -550,15 +550,15 @@
     func : flash_attn_grad
     data_type: q

-- backward_op : flash_attn_raw_grad
-  forward : flash_attn_raw (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
+- backward_op : flash_attn_unpadded_grad
+  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
     func : FlashAttnGradInferMeta
     param : [q, k, v]
   kernel :
-    func : flash_attn_raw_grad
+    func : flash_attn_unpadded_grad
     data_type: q

 - backward_op : flip_grad
12 changes: 7 additions & 5 deletions paddle/phi/api/yaml/ops.yaml
@@ -530,25 +530,27 @@

 - op : flash_attn
   args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false)
-  output : Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
+  output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]
   kernel :
     func : flash_attn
     data_type : q
+  intermediate : softmax_lse, seed_offset
   backward : flash_attn_grad

-- op : flash_attn_raw
+- op : flash_attn_unpadded
   args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false)
-  output : Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
+  output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]
   kernel :
-    func : flash_attn_raw
+    func : flash_attn_unpadded
     data_type : q
-  backward : flash_attn_raw_grad
+  intermediate : softmax_lse, seed_offset
+  backward : flash_attn_unpadded_grad

 - op : flip
   args : (Tensor x, int[] axis)
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/ternary.cc
@@ -259,8 +259,8 @@ void FlashAttnInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out,
-MetaTensor* softmax_lse,
MetaTensor* softmax,
+MetaTensor* softmax_lse,
MetaTensor* seed_offset) {
out->set_dims(q.dims());
out->set_dtype(q.dtype());
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/ternary.h
@@ -67,8 +67,8 @@ void FlashAttnInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out,
-MetaTensor* softmax_lse,
MetaTensor* softmax,
+MetaTensor* softmax_lse,
MetaTensor* seed_offset);

void InstanceNormInferMeta(const MetaTensor& x,
36 changes: 18 additions & 18 deletions paddle/phi/kernels/flash_attn_grad_kernel.h
@@ -20,24 +20,24 @@
namespace phi {

template <typename T, typename Context>
-void FlashAttnRawGradKernel(const Context& ctx,
-const DenseTensor& q,
-const DenseTensor& k,
-const DenseTensor& v,
-const DenseTensor& cu_seqlens_q,
-const DenseTensor& cu_seqlens_k,
-const DenseTensor& out,
-const DenseTensor& softmax_lse,
-const DenseTensor& seed_offset,
-const DenseTensor& dout,
-int64_t max_seqlen_q,
-int64_t max_seqlen_k,
-float scale,
-float dropout,
-bool causal,
-DenseTensor* dq,
-DenseTensor* dk,
-DenseTensor* dv);
+void FlashAttnUnpaddedGradKernel(const Context& ctx,
+const DenseTensor& q,
+const DenseTensor& k,
+const DenseTensor& v,
+const DenseTensor& cu_seqlens_q,
+const DenseTensor& cu_seqlens_k,
+const DenseTensor& out,
+const DenseTensor& softmax_lse,
+const DenseTensor& seed_offset,
+const DenseTensor& dout,
+int64_t max_seqlen_q,
+int64_t max_seqlen_k,
+float scale,
+float dropout,
+bool causal,
+DenseTensor* dq,
+DenseTensor* dk,
+DenseTensor* dv);

template <typename T, typename Context>
void FlashAttnGradKernel(const Context& ctx,
34 changes: 17 additions & 17 deletions paddle/phi/kernels/flash_attn_kernel.h
@@ -20,22 +20,22 @@
namespace phi {

template <typename T, typename Context>
-void FlashAttnRawKernel(const Context& ctx,
-const DenseTensor& q,
-const DenseTensor& k,
-const DenseTensor& v,
-const DenseTensor& cu_seqlens_q,
-const DenseTensor& cu_seqlens_k,
-int64_t max_seqlen_q,
-int64_t max_seqlen_k,
-float scale,
-float dropout,
-bool causal,
-bool return_softmax,
-DenseTensor* out,
-DenseTensor* softmax_lse,
-DenseTensor* softmax,
-DenseTensor* seed_offset);
+void FlashAttnUnpaddedKernel(const Context& ctx,
+const DenseTensor& q,
+const DenseTensor& k,
+const DenseTensor& v,
+const DenseTensor& cu_seqlens_q,
+const DenseTensor& cu_seqlens_k,
+int64_t max_seqlen_q,
+int64_t max_seqlen_k,
+float scale,
+float dropout,
+bool causal,
+bool return_softmax,
+DenseTensor* out,
+DenseTensor* softmax,
+DenseTensor* softmax_lse,
+DenseTensor* seed_offset);

template <typename T, typename Context>
void FlashAttnKernel(const Context& ctx,
@@ -46,8 +46,8 @@ void FlashAttnKernel(const Context& ctx,
bool causal,
bool return_softmax,
DenseTensor* out,
-DenseTensor* softmax_lse,
DenseTensor* softmax,
+DenseTensor* softmax_lse,
DenseTensor* seed_offset);

} // namespace phi
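
To make the padded / unpadded split concrete: FlashAttnKernel consumes activations padded to a common sequence length, while FlashAttnUnpaddedKernel consumes sequences packed back to back and delimited by the cu_seqlens tensors. The layout below is inferred from the surrounding code (the q,k,v [batch_size, seq_len, num_heads, head_dim] comment and the cu_seqlens construction in flash_attn_kernel.cu), so treat it as an assumption rather than documented API; the helper is purely illustrative.

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Padded entry point (FlashAttnKernel), assumed layout:
//   q, k, v : [batch_size, seq_len, num_heads, head_dim]
// Unpadded entry point (FlashAttnUnpaddedKernel), assumed varlen layout:
//   q : [total_q, num_heads, head_dim], total_q = sum of per-sequence lengths
//   cu_seqlens_q : [batch_size + 1] int32 exclusive prefix sums of the lengths
struct UnpaddedBatchMeta {
  std::vector<int32_t> cu_seqlens;  // e.g. {0, 7, 19, 24} for lengths {7, 12, 5}
  int64_t max_seqlen = 0;           // e.g. 12, the longest sequence in the batch
};

// Build the metadata from per-sequence lengths (assumes a non-empty batch).
inline UnpaddedBatchMeta BuildUnpaddedMeta(const std::vector<int32_t>& seqlens) {
  UnpaddedBatchMeta meta;
  meta.cu_seqlens.assign(seqlens.size() + 1, 0);
  std::partial_sum(seqlens.begin(), seqlens.end(), meta.cu_seqlens.begin() + 1);
  meta.max_seqlen = *std::max_element(seqlens.begin(), seqlens.end());
  return meta;
}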
76 changes: 38 additions & 38 deletions paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -28,24 +28,24 @@
namespace phi {

template <typename T, typename Context>
-void FlashAttnRawGradKernel(const Context& ctx,
-const DenseTensor& q,
-const DenseTensor& k,
-const DenseTensor& v,
-const DenseTensor& cu_seqlens_q,
-const DenseTensor& cu_seqlens_k,
-const DenseTensor& out,
-const DenseTensor& softmax_lse,
-const DenseTensor& seed_offset,
-const DenseTensor& dout,
-int64_t max_seqlen_q,
-int64_t max_seqlen_k,
-float scale,
-float dropout,
-bool causal,
-DenseTensor* dq,
-DenseTensor* dk,
-DenseTensor* dv) {
+void FlashAttnUnpaddedGradKernel(const Context& ctx,
+const DenseTensor& q,
+const DenseTensor& k,
+const DenseTensor& v,
+const DenseTensor& cu_seqlens_q,
+const DenseTensor& cu_seqlens_k,
+const DenseTensor& out,
+const DenseTensor& softmax_lse,
+const DenseTensor& seed_offset,
+const DenseTensor& dout,
+int64_t max_seqlen_q,
+int64_t max_seqlen_k,
+float scale,
+float dropout,
+bool causal,
+DenseTensor* dq,
+DenseTensor* dk,
+DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
ctx.template Alloc<T>(dq);
ctx.template Alloc<T>(dk);
@@ -202,34 +202,34 @@ void FlashAttnGradKernel(const Context& ctx,
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);

-FlashAttnRawGradKernel<T, Context>(ctx,
-q_t_s,
-k_t_s,
-v_t_s,
-cu_seqlens_q,
-cu_seqlens_k,
-out,
-softmax_lse,
-seed_offset,
-dout,
-seq_len_q,
-seq_len_k,
-scale,
-dropout,
-causal,
-dq,
-dk,
-dv);
+FlashAttnUnpaddedGradKernel<T, Context>(ctx,
+q_t_s,
+k_t_s,
+v_t_s,
+cu_seqlens_q,
+cu_seqlens_k,
+out,
+softmax_lse,
+seed_offset,
+dout,
+seq_len_q,
+seq_len_k,
+scale,
+dropout,
+causal,
+dq,
+dk,
+dv);

#endif
}

} // namespace phi

-PD_REGISTER_KERNEL(flash_attn_raw_grad,
+PD_REGISTER_KERNEL(flash_attn_unpadded_grad,
GPU,
ALL_LAYOUT,
-phi::FlashAttnRawGradKernel,
+phi::FlashAttnUnpaddedGradKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(7).SetBackend(phi::Backend::CPU); // seed_offset
70 changes: 35 additions & 35 deletions paddle/phi/kernels/gpu/flash_attn_kernel.cu
@@ -31,22 +31,22 @@
namespace phi {

template <typename T, typename Context>
-void FlashAttnRawKernel(const Context& ctx,
-const DenseTensor& q,
-const DenseTensor& k,
-const DenseTensor& v,
-const DenseTensor& cu_seqlens_q,
-const DenseTensor& cu_seqlens_k,
-int64_t max_seqlen_q,
-int64_t max_seqlen_k,
-float scale,
-float dropout,
-bool causal,
-bool return_softmax,
-DenseTensor* out,
-DenseTensor* softmax_lse,
-DenseTensor* softmax,
-DenseTensor* seed_offset) {
+void FlashAttnUnpaddedKernel(const Context& ctx,
+const DenseTensor& q,
+const DenseTensor& k,
+const DenseTensor& v,
+const DenseTensor& cu_seqlens_q,
+const DenseTensor& cu_seqlens_k,
+int64_t max_seqlen_q,
+int64_t max_seqlen_k,
+float scale,
+float dropout,
+bool causal,
+bool return_softmax,
+DenseTensor* out,
+DenseTensor* softmax,
+DenseTensor* softmax_lse,
+DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
ctx.template Alloc<T>(out);

@@ -185,8 +185,8 @@ void FlashAttnKernel(const Context& ctx,
bool causal,
bool return_softmax,
DenseTensor* out,
-DenseTensor* softmax_lse,
DenseTensor* softmax,
+DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
// q,k,v [batch_size, seq_len, num_heads, head_dim]
@@ -224,32 +224,32 @@ void FlashAttnKernel(const Context& ctx,
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);

-FlashAttnRawKernel<T, Context>(ctx,
-q_t_s,
-k_t_s,
-v_t_s,
-cu_seqlens_q,
-cu_seqlens_k,
-seq_len_q,
-seq_len_k,
-scale,
-dropout,
-causal,
-return_softmax,
-out,
-softmax_lse,
-softmax,
-seed_offset);
+FlashAttnUnpaddedKernel<T, Context>(ctx,
+q_t_s,
+k_t_s,
+v_t_s,
+cu_seqlens_q,
+cu_seqlens_k,
+seq_len_q,
+seq_len_k,
+scale,
+dropout,
+causal,
+return_softmax,
+out,
+softmax,
+softmax_lse,
+seed_offset);

#endif
}

} // namespace phi

-PD_REGISTER_KERNEL(flash_attn_raw,
+PD_REGISTER_KERNEL(flash_attn_unpadded,
GPU,
ALL_LAYOUT,
-phi::FlashAttnRawKernel,
+phi::FlashAttnUnpaddedKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
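
A worked example of the arithmetic behind the dispatch above: for a padded batch with batch_size = 2 and seq_len_k = 1024, ArangeNullaryKernel(ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k) produces cu_seqlens_k = {0, 1024, 2048}, so the padded kernel is lowered onto the unpadded one with every sequence reported at its full padded length. A stand-alone sketch of that sequence (plain C++, independent of the phi kernels):

#include <cstdint>
#include <vector>

// Mirrors the arange call above: 0, seq_len, 2 * seq_len, ..., batch_size * seq_len,
// i.e. the cumulative offsets of batch_size equal-length sequences.
std::vector<int32_t> FixedLenCuSeqlens(int64_t batch_size, int64_t seq_len) {
  std::vector<int32_t> cu_seqlens;
  cu_seqlens.reserve(static_cast<size_t>(batch_size) + 1);
  for (int64_t i = 0; i <= batch_size; ++i) {
    cu_seqlens.push_back(static_cast<int32_t>(i * seq_len));
  }
  return cu_seqlens;  // batch_size = 2, seq_len = 1024  ->  {0, 1024, 2048}
}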
