[aten_cuda/flash_attn] Add typename to template argument Kernel_traits::TiledMma(SdP) when calling the template device function pytorch_flash::convert_layout_acc_Aregs
kiukchung committed May 31, 2024
1 parent 4a0d96e commit 9ae6929
Showing 2 changed files with 5 additions and 5 deletions.
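
The change itself is a one-keyword fix repeated at five call sites: Kernel_traits is a template parameter, so Kernel_traits::TiledMma and Kernel_traits::TiledMmaSdP are dependent names, and inside a template the compiler assumes a dependent qualified name does not refer to a type unless it is preceded by typename. Passing such a name as a type template argument to pytorch_flash::convert_layout_acc_Aregs without the keyword is therefore ill-formed and is rejected by strict front ends such as Clang. The sketch below is a minimal, self-contained illustration of that rule; MyTraits, convert_layout, and use_traits are made-up stand-ins, not code from the flash-attention kernels.

// Minimal sketch of the dependent-name rule behind this fix (illustrative
// stand-ins only; not code from the flash-attention kernels).
template <class TiledMma>
int convert_layout(int layout) {  // stand-in for pytorch_flash::convert_layout_acc_Aregs
    return layout;
}

struct MyTraits {
    struct TiledMma {};           // member type, analogous to Kernel_traits::TiledMma
};

template <class Kernel_traits>
int use_traits(int layout) {
    // Kernel_traits::TiledMma is a dependent name: without the typename keyword
    // the compiler must assume it is not a type, so using it as a type template
    // argument is ill-formed, and strict front ends (e.g. Clang) reject it.
    return convert_layout<typename Kernel_traits::TiledMma>(layout);
}

int main() { return use_traits<MyTraits>(0); }

Each hunk below applies exactly that one-token change to a call of pytorch_flash::convert_layout_acc_Aregs, which accounts for the 5 additions and 5 deletions.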
@@ -540,7 +540,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
: pytorch_flash::convert_type_relu<Element>(acc_s);
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2)
// if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8.
- Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout()));
+ Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMmaSdP>(rP.layout()));
Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
// if (cute::thread0()) { print(tPaP); }
@@ -339,7 +339,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
- Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+ Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
// if (cute::thread0()) { print(tOrP); }
pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// if (cute::thread0()) { print(scores); }
@@ -402,7 +402,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
- Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+ Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
}

@@ -895,7 +895,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
- Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+ Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));

pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);

@@ -957,7 +957,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
- Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+ Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));

pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
}
