[CUDA] Workaround register spilling issue in mem-efficient SDP kernels on `sm60` (pytorch#120445)

We're seeing that a newer CUDA version (12.4) introduces register spilling for a few of these kernels on Pascal; this PR works around it for that specific toolkit version.

CC @ptrblck

Pull Request resolved: pytorch#120445
Approved by: https://github.com/Skylion007, https://github.com/drisspg
eqy authored and pytorchmergebot committed Feb 23, 2024
1 parent edf1c4e commit 9e9eaf0
Showing 3 changed files with 60 additions and 0 deletions.
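
The hunks below gate the 64x64-tile dropout kernels behind a CUDA-version check and substitute 32x32-tile instantiations when building with CUDA 12.4 (and not ROCm). As a rough standalone sketch of that pattern, using placeholder names (toy_tile_kernel, kTileM, kTileN) rather than the real PyTorch/CUTLASS symbols:

// Illustrative sketch only, not PyTorch code: pick a smaller tile shape when
// building with the toolkit that exhibits the spilling (CUDA 12.4), mirroring
// the guard used in the hunks below.
#include <cstdio>
#include <cuda_runtime.h>

#if defined(CUDA_VERSION) && CUDA_VERSION == 12040 && !defined(USE_ROCM)
// CUDA 12.4 spills registers for the larger tiles on Pascal, so fall back
// to 32x32 tiles for this toolkit version only.
constexpr int kTileM = 32;
constexpr int kTileN = 32;
#else
constexpr int kTileM = 64;
constexpr int kTileN = 64;
#endif

__global__ void toy_tile_kernel(const float* in, float* out, int n) {
  // Stand-in for the real attention backward kernel: the tile constants only
  // determine how much per-block state a real kernel would have to carry.
  __shared__ float tile[kTileM * kTileN];
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n && threadIdx.x < kTileM * kTileN) {
    tile[threadIdx.x] = in[idx];
    out[idx] = tile[threadIdx.x];
  }
}

int main() {
  std::printf("compiled tile shape: %dx%d\n", kTileM, kTileN);
  return 0;
}

The actual change applies the same idea at the level of AttentionBackwardKernel template instantiations: the 64x64 variants remain the default, and the 32x32 variants are declared and dispatched only under the guarded toolkit version.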
File 1 of 3
@@ -436,6 +436,16 @@ __global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 65536>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k65536_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 65536>::Params p);
#if defined(CUDA_VERSION) && CUDA_VERSION == 12040 && !defined(USE_ROCM)
__global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 32>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 32>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_32x32_k32_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 32>::Params p);
__global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 64>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_32x32_k64_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 64>::Params p);
#else
__global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 32>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 32>::kMinBlocksPerSm)
@@ -444,6 +454,7 @@ __global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 64>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 64>::Params p);
#endif
__global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 128>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 128>::kMinBlocksPerSm)
@@ -490,8 +501,13 @@ template <typename T> void dispatch_cutlassB_f32_sm50(T cb, int cc) {
cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 64>(), fmha_cutlassB_f32_aligned_64x64_k64_sm50);
cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 128>(), fmha_cutlassB_f32_aligned_64x64_k128_sm50);
cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 65536>(), fmha_cutlassB_f32_aligned_64x64_k65536_sm50);
#if defined(CUDA_VERSION) && CUDA_VERSION == 12040 && !defined(USE_ROCM)
cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 32>(), fmha_cutlassB_f32_aligned_32x32_k32_dropout_sm50);
cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 64>(), fmha_cutlassB_f32_aligned_32x32_k64_dropout_sm50);
#else
cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 32>(), fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm50);
cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 64>(), fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm50);
#endif
cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 128>(), fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm50);
cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 65536>(), fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm50);
cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, false, false, false, 64, 64, 32>(), fmha_cutlassB_f32_notaligned_64x64_k32_sm50);
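For context, dispatch_cutlassB_f32_sm50 above hands every registered kernel configuration to a caller-supplied callback together with its __global__ entry point, so the guarded block only changes which configurations get registered under CUDA 12.4. A hypothetical, heavily simplified model of that callback pattern follows; ToyKernel32/ToyKernel64, toy_dispatch, and the stub entry points are made-up names, not the real symbols:

// Hypothetical model of the dispatcher above; none of these names exist in
// PyTorch. Each registered configuration is passed to the callback together
// with the function that would launch it.
#include <cstdio>

struct ToyKernel32 { static constexpr int kBlockSizeI = 32; };
struct ToyKernel64 { static constexpr int kBlockSizeI = 64; };

void toy_entry_32() { std::puts("would launch the 32x32 kernel"); }
void toy_entry_64() { std::puts("would launch the 64x64 kernel"); }

// Mirrors the shape of dispatch_cutlassB_f32_sm50(T cb, int cc): the callback
// receives (kernel descriptor, kernel entry point) for every variant that is
// compiled in under the active preprocessor guards.
template <typename T>
void toy_dispatch(T cb) {
#if defined(CUDA_VERSION) && CUDA_VERSION == 12040 && !defined(USE_ROCM)
  cb(ToyKernel32{}, toy_entry_32);  // smaller tiles registered under CUDA 12.4
#else
  cb(ToyKernel64{}, toy_entry_64);
#endif
}

int main() {
  // The caller inspects the descriptor and decides what to launch.
  toy_dispatch([](auto kernel, auto fn) {
    std::printf("registered tile size: %d\n", decltype(kernel)::kBlockSizeI);
    fn();
  });
  return 0;
}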
File 2 of 3
@@ -8,6 +8,27 @@
// This file is auto-generated. See "generate_kernels.py"
#include <ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h>
using namespace PyTorchMemEffAttention;
#if defined(CUDA_VERSION) && CUDA_VERSION == 12040 && !defined(USE_ROCM)
__global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 32>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 32>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_32x32_k32_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 32>::Params p) {
#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ >= 500
#if __CUDA_ARCH__ < 700
if (!p.advance_to_block()) {
return;
}
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 32>::attention_kernel(p);
return;
#endif
#endif
printf(
"FATAL: kernel `fmha_cutlassB_f32_aligned_32x32_k32_dropout_sm50` is for sm50-sm70, but was built for sm%d\n",
int(__CUDA_ARCH__ + 0) / 10);
#endif
}
#else
__global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 32>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 32>::kMinBlocksPerSm)
@@ -27,6 +48,7 @@ fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm50(typename AttentionBackwardKerne
int(__CUDA_ARCH__ + 0) / 10);
#endif
}
#endif
__global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 32>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 32>::kMinBlocksPerSm)
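The generated kernel body above also shows the architecture guard used throughout these files: real device code is compiled only when __CUDA_ARCH__ falls in the sm50-sm70 range, and every other target compiles down to a fatal printf. A minimal standalone sketch of that guard, using the illustrative name toy_guarded_kernel rather than a PyTorch symbol:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void toy_guarded_kernel(float* out) {
#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ >= 500
#if __CUDA_ARCH__ < 700
  // Real work happens only when targeting sm50 through sm6x.
  out[threadIdx.x] = static_cast<float>(threadIdx.x);
  return;
#endif
#endif
  // Any other target fails loudly at runtime instead of silently misbehaving.
  printf("FATAL: toy_guarded_kernel is for sm50-sm70, but was built for sm%d\n",
         int(__CUDA_ARCH__ + 0) / 10);
#endif
}

int main() {
  float* d_out = nullptr;
  cudaMalloc(&d_out, 64 * sizeof(float));
  toy_guarded_kernel<<<1, 64>>>(d_out);
  cudaDeviceSynchronize();
  cudaFree(d_out);
  return 0;
}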
File 3 of 3
@@ -8,6 +8,27 @@
// This file is auto-generated. See "generate_kernels.py"
#include <ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h>
using namespace PyTorchMemEffAttention;
#if defined(CUDA_VERSION) && CUDA_VERSION == 12040 && !defined(USE_ROCM)
__global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 64>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_32x32_k64_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 64>::Params p) {
#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ >= 500
#if __CUDA_ARCH__ < 700
if (!p.advance_to_block()) {
return;
}
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 64>::attention_kernel(p);
return;
#endif
#endif
printf(
"FATAL: kernel `fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm50` is for sm50-sm70, but was built for sm%d\n",
int(__CUDA_ARCH__ + 0) / 10);
#endif
}
#else
__global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 64>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 64>::kMinBlocksPerSm)
@@ -27,6 +48,7 @@ fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm50(typename AttentionBackwardKerne
int(__CUDA_ARCH__ + 0) / 10);
#endif
}
#endif
__global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 64>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 64>::kMinBlocksPerSm)
