diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h
index db817a0657ffcb..5089fb2e294ff6 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h
@@ -540,7 +540,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             : pytorch_flash::convert_type_relu<Element>(acc_s);
         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2)
         // if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8.
-        Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout()));
+        Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout()));
         Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP);     // ((Atom,AtomNum), MMA_N, MMA_N)
         cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
         // if (cute::thread0()) { print(tPaP); }
diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h
index 0386a07cc64fd6..9d97abb5eb90d9 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h
@@ -339,7 +339,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout()));
         // if (cute::thread0()) { print(tOrP); }
         pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
         // if (cute::thread0()) { print(scores); }
@@ -402,7 +402,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout()));
         pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     }
@@ -895,7 +895,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout()));
         pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
@@ -957,7 +957,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout()));
         pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     }
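For context on the call these hunks touch: per the comments in each hunk, convert_layout_acc_Aregs reinterprets the per-thread softmax accumulator registers, shaped (MMA=4, MMA_M, MMA_N), as A-operand registers shaped ((4, 2), MMA_M, MMA_N / 2) for the m16n8k16 case, without moving any data. The sketch below is a minimal, standalone plain-C++ illustration of that re-indexing; the array sizes and helper lambdas are hypothetical and it does not use CuTe or any flash-attention code.

// Minimal illustration (hypothetical names, not flash-attention code): the accumulator
// view acc(i, m, n) has shape (4, MMA_M, MMA_N); the A-operand view aregs((i, j), m, n2)
// has shape ((4, 2), MMA_M, MMA_N / 2) and addresses the SAME storage via n = j + 2 * n2,
// i.e. element (i, m, n) is viewed as ((i, n % 2), m, n / 2).
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    const int MMA_M = 2, MMA_N = 4;                  // toy tile counts, chosen for the demo
    std::vector<float> regs(4 * MMA_M * MMA_N);      // stand-in for the per-thread registers
    for (size_t k = 0; k < regs.size(); ++k) regs[k] = static_cast<float>(k);

    // Accumulator view: mode 0 (the 4 values per MMA tile) is the fastest-varying index.
    auto acc = [&](int i, int m, int n) -> float& {
        return regs[i + 4 * (m + MMA_M * n)];
    };
    // Reinterpreted A-operand view over the same registers: j in {0, 1} picks which of
    // the two fused N-tiles a value came from; no data is moved, only re-indexed.
    auto aregs = [&](int i, int j, int m, int n2) -> float& {
        return acc(i, m, j + 2 * n2);
    };

    // Every accumulator element is reachable, at the same address, through the new view.
    for (int n = 0; n < MMA_N; ++n)
        for (int m = 0; m < MMA_M; ++m)
            for (int i = 0; i < 4; ++i)
                assert(&acc(i, m, n) == &aregs(i, n % 2, m, n / 2));
    std::printf("reinterpreted all %zu register values without copying\n", regs.size());
    return 0;
}

Per the comments in the diff, the kernels express the analogous re-indexing declaratively through CuTe layouts (make_tensor over rP.data() with the converted layout) rather than explicit index arithmetic, so no copy of rP is materialized before the P @ V GEMM.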