From 5f6d10c14c17122e6d711a4829ee0ca672e07f6f Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 22 May 2024 03:18:41 -0400 Subject: [PATCH] [CI/Build] Enforce style for C++ and CUDA code with `clang-format` (#4722) --- .clang-format | 26 + .github/workflows/clang-format.yml | 42 + csrc/activation_kernels.cu | 139 +- csrc/attention/attention_generic.cuh | 19 +- csrc/attention/attention_kernels.cu | 636 ++- csrc/attention/attention_utils.cuh | 11 +- csrc/attention/dtype_bfloat16.cuh | 74 +- csrc/attention/dtype_float16.cuh | 92 +- csrc/attention/dtype_float32.cuh | 88 +- csrc/attention/dtype_fp8.cuh | 32 +- csrc/cache.h | 44 +- csrc/cache_kernels.cu | 288 +- csrc/cpu/activation.cpp | 60 +- csrc/cpu/attention.cpp | 411 +- csrc/cpu/cache.cpp | 53 +- csrc/cpu/layernorm.cpp | 32 +- csrc/cpu/pos_encoding.cpp | 66 +- csrc/cpu/pybind.cpp | 75 +- csrc/cuda_compat.h | 9 +- csrc/cuda_utils.h | 7 +- csrc/cuda_utils_kernels.cu | 40 +- csrc/custom_all_reduce.cu | 55 +- csrc/custom_all_reduce.cuh | 105 +- csrc/custom_all_reduce_test.cu | 38 +- csrc/dispatch_utils.h | 42 +- csrc/layernorm_kernels.cu | 242 +- csrc/moe/moe_ops.cpp | 3 +- csrc/moe/moe_ops.h | 8 +- csrc/moe_align_block_size_kernels.cu | 211 +- csrc/ops.h | 330 +- csrc/pos_encoding_kernels.cu | 229 +- csrc/pybind.cpp | 142 +- csrc/quantization/aqlm/gemm_kernels.cu | 536 +-- csrc/quantization/awq/dequantize.cuh | 138 +- csrc/quantization/awq/gemm_kernels.cu | 611 +-- .../cutlass_w8a8/scaled_mm_dq_c2x.cu | 38 +- .../cutlass_w8a8/scaled_mm_dq_c3x.cu | 22 +- .../cutlass_w8a8/scaled_mm_dq_entry.cu | 47 +- csrc/quantization/fp8/amd/hip_float8.h | 216 +- csrc/quantization/fp8/amd/hip_float8_impl.h | 520 +-- csrc/quantization/fp8/amd/quant_utils.cuh | 711 ++-- csrc/quantization/fp8/common.cu | 87 +- csrc/quantization/fp8/nvidia/quant_utils.cuh | 138 +- csrc/quantization/gptq/compat.cuh | 70 +- csrc/quantization/gptq/matrix_view.cuh | 503 +-- csrc/quantization/gptq/q_gemm.cu | 3441 ++++++++--------- csrc/quantization/gptq/qdq_2.cuh | 107 +- csrc/quantization/gptq/qdq_3.cuh | 246 +- csrc/quantization/gptq/qdq_4.cuh | 203 +- csrc/quantization/gptq/qdq_8.cuh | 34 +- csrc/quantization/gptq/qdq_util.cuh | 58 +- csrc/quantization/gptq_marlin/gptq_marlin.cu | 696 ++-- csrc/quantization/gptq_marlin/gptq_marlin.cuh | 50 +- .../gptq_marlin/gptq_marlin_dtypes.cuh | 89 +- .../gptq_marlin/gptq_marlin_repack.cu | 94 +- .../marlin/dense/marlin_cuda_kernel.cu | 460 ++- csrc/quantization/marlin/sparse/common/base.h | 12 +- csrc/quantization/marlin/sparse/common/mem.h | 64 +- csrc/quantization/marlin/sparse/common/mma.h | 107 +- .../marlin/sparse/marlin_24_cuda_kernel.cu | 446 ++- .../squeezellm/quant_cuda_kernel.cu | 63 +- csrc/reduction_utils.cuh | 20 +- format.sh | 57 +- requirements-dev.txt | 1 + 64 files changed, 6571 insertions(+), 6963 deletions(-) create mode 100644 .clang-format create mode 100644 .github/workflows/clang-format.yml diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000..7f9e6d720fae5 --- /dev/null +++ b/.clang-format @@ -0,0 +1,26 @@ +BasedOnStyle: Google +UseTab: Never +IndentWidth: 2 +ColumnLimit: 80 + +# Force pointers to the type for C++. +DerivePointerAlignment: false +PointerAlignment: Left + +# Reordering #include statements can (and currently will) introduce errors +SortIncludes: false + +# Style choices +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +IndentPPDirectives: BeforeHash + +IncludeCategories: + - Regex: '^<' + Priority: 4 + - Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/' + Priority: 3 + - Regex: '^"(qoda|\.\.)/' + Priority: 2 + - Regex: '.*' + Priority: 1 diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml new file mode 100644 index 0000000000000..e9b6e28fa6bcb --- /dev/null +++ b/.github/workflows/clang-format.yml @@ -0,0 +1,42 @@ +name: clang-format + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + clang-format: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install clang-format==18.1.5 + - name: Running clang-format + run: | + EXCLUDES=( + 'csrc/moe/topk_softmax_kernels.cu' + 'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu' + 'csrc/punica/bgmv/bgmv_config.h' + 'csrc/punica/bgmv/bgmv_impl.cuh' + 'csrc/punica/bgmv/vec_dtypes.cuh' + 'csrc/punica/punica_ops.cu' + 'csrc/punica/type_convert.h' + ) + find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ + | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ + | xargs clang-format --dry-run --Werror \ No newline at end of file diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 24d972702c858..867f63f12de4b 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -10,11 +10,11 @@ namespace vllm { // Activation and gating kernel template. -template +template __global__ void act_and_mul_kernel( - scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., 2, d] - const int d) { + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); @@ -23,72 +23,66 @@ __global__ void act_and_mul_kernel( } } -template +template __device__ __forceinline__ T silu_kernel(const T& x) { // x * sigmoid(x) - return (T) (((float) x) / (1.0f + expf((float) -x))); + return (T)(((float)x) / (1.0f + expf((float)-x))); } -template +template __device__ __forceinline__ T gelu_kernel(const T& x) { // Equivalent to PyTorch GELU with 'none' approximation. // Refer to: // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 - const float f = (float) x; + const float f = (float)x; constexpr float ALPHA = M_SQRT1_2; - return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA))); + return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA))); } -template +template __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { // Equivalent to PyTorch GELU with 'tanh' approximation. // Refer to: // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 - const float f = (float) x; + const float f = (float)x; constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; constexpr float KAPPA = 0.044715; float x_cube = f * f * f; float inner = BETA * (f + KAPPA * x_cube); - return (T) (0.5f * f * (1.0f + ::tanhf(inner))); + return (T)(0.5f * f * (1.0f + ::tanhf(inner))); } -} // namespace vllm +} // namespace vllm // Launch activation and gating kernel. -#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "act_and_mul_kernel", \ - [&] { \ - vllm::act_and_mul_kernel><<>>( \ - out.data_ptr(), \ - input.data_ptr(), \ - d); \ - }); - -void silu_and_mul( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); + +void silu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); } -void gelu_and_mul( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); } -void gelu_tanh_and_mul( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); } @@ -96,11 +90,11 @@ void gelu_tanh_and_mul( namespace vllm { // Element-wise activation kernel template. -template +template __global__ void activation_kernel( - scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., d] - const int d) { + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., d] + const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); @@ -108,54 +102,49 @@ __global__ void activation_kernel( } } -} // namespace vllm +} // namespace vllm // Launch element-wise activation kernel. -#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ - int d = input.size(-1); \ - int64_t num_tokens = input.numel() / d; \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "activation_kernel", \ - [&] { \ - vllm::activation_kernel><<>>( \ - out.data_ptr(), \ - input.data_ptr(), \ - d); \ - }); +#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ + int d = input.size(-1); \ + int64_t num_tokens = input.numel() / d; \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \ + vllm::activation_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); namespace vllm { -template +template __device__ __forceinline__ T gelu_new_kernel(const T& x) { - const float x3 = (float) (x * x * x); - const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); - return ((T) 0.5) * x * (((T) 1.0) + t); + const float x3 = (float)(x * x * x); + const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); + return ((T)0.5) * x * (((T)1.0) + t); } -template +template __device__ __forceinline__ T gelu_fast_kernel(const T& x) { - const float f = (float) x; - const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); - return ((T) 0.5) * x * (((T) 1.0) + t); + const float f = (float)x; + const T t = + (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); + return ((T)0.5) * x * (((T)1.0) + t); } -} // namespace vllm +} // namespace vllm -void gelu_new( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_new(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); } -void gelu_fast( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_fast(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); } diff --git a/csrc/attention/attention_generic.cuh b/csrc/attention/attention_generic.cuh index 31fb401cbe2c1..62409c0cce93e 100644 --- a/csrc/attention/attention_generic.cuh +++ b/csrc/attention/attention_generic.cuh @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -22,31 +23,31 @@ namespace vllm { // A vector type to store Q, K, V elements. -template +template struct Vec {}; // A vector type to store FP32 accumulators. -template +template struct FloatVec {}; // Template vector operations. -template +template inline __device__ Acc mul(A a, B b); -template +template inline __device__ float sum(T v); -template +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template +template inline __device__ void zero(T& dst) { constexpr int WORDS = sizeof(T) / 4; union { @@ -61,4 +62,4 @@ inline __device__ void zero(T& dst) { dst = tmp.raw; } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 41b337dd91d36..d6203174e7275 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -27,15 +28,15 @@ #ifdef USE_ROCM #include #include "../quantization/fp8/amd/quant_utils.cuh" - typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat16 __nv_bfloat16; #else #include "../quantization/fp8/nvidia/quant_utils.cuh" #endif #ifndef USE_ROCM -#define WARP_SIZE 32 + #define WARP_SIZE 32 #else -#define WARP_SIZE warpSize + #define WARP_SIZE warpSize #endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) @@ -45,7 +46,7 @@ namespace vllm { // Utility function for attention softmax. -template +template inline __device__ float block_sum(float* red_smem, float sum) { // Decompose the thread index into warp / lane. int warp = threadIdx.x / WARP_SIZE; @@ -82,31 +83,28 @@ inline __device__ float block_sum(float* red_smem, float sum) { // TODO(woosuk): Merge the last two dimensions of the grid. // Grid: (num_heads, num_seqs, max_num_partitions). -template< - typename scalar_t, - typename cache_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - vllm::Fp8KVCacheDataType KV_DTYPE, - int PARTITION_SIZE = 0> // Zero means no partitioning. +template // Zero means no partitioning. __device__ void paged_attention_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const float kv_scale) { + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float kv_scale) { const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; @@ -118,22 +116,29 @@ __device__ void paged_attention_kernel( } const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); - const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; // [start_block_idx, end_block_idx) is the range of blocks to process. - const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); + const int start_block_idx = + USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); const int num_blocks = end_block_idx - start_block_idx; // [start_token_idx, end_token_idx) is the range of tokens to process. const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); const int num_tokens = end_token_idx - start_token_idx; constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int thread_idx = threadIdx.x; const int warp_idx = thread_idx / WARP_SIZE; @@ -143,13 +148,14 @@ __device__ void paged_attention_kernel( const int num_heads = gridDim.x; const int num_queries_per_kv = num_heads / num_kv_heads; const int kv_head_idx = head_idx / num_queries_per_kv; - const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + const float alibi_slope = + alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; // A vector type to store a part of a key or a query. - // The vector size is configured in such a way that the threads in a thread group - // fetch or compute 16 bytes at a time. - // For example, if the size of a thread group is 4 and the data type is half, - // then the vector size is 16 / (4 * sizeof(half)) == 2. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); using K_vec = typename Vec::Type; using Q_vec = typename Vec::Type; @@ -163,18 +169,21 @@ __device__ void paged_attention_kernel( // Load the query to registers. // Each thread in a thread group has a different part of the query. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... - // th vectors of the query, and so on. - // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the query, and the second thread + // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because + // q is split from a qkv tensor, it may not be contiguous. const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; #pragma unroll - for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; - q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); } - __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a + // memory wall right before we use q_vecs // Memory planning. extern __shared__ char shared_mem[]; @@ -193,44 +202,50 @@ __device__ void paged_attention_kernel( // Each thread group in a warp fetches a key from the block, and computes // dot product with the query. const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 - // because int32 can lead to overflow when this variable is multiplied by large numbers - // (e.g., kv_block_stride). - const int64_t physical_block_number = static_cast(block_table[block_idx]); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); // Load a key to registers. // Each thread in a thread group has a different part of the key. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th - // vectors of the key, and so on. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the key, and the second thread + // has 1, 5, 9, ... th vectors of the key, and so on. for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; K_vec k_vecs[NUM_VECS_PER_THREAD]; #pragma unroll for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { - const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride - + physical_block_offset * x; + const cache_t* k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { - k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); } else { // Vector conversion from Quant_vec to K_vec. Quant_vec k_vec_quant = *reinterpret_cast( - k_ptr + offset1 * BLOCK_SIZE * x + offset2); - k_vecs[j] = fp8::scaled_convert(k_vec_quant, kv_scale); + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8::scaled_convert( + k_vec_quant, kv_scale); } } // Compute dot product. // This includes a reduction across the threads in the same thread group. - float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); // Add the ALiBi bias if slopes are given. qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; @@ -285,13 +300,12 @@ __device__ void paged_attention_kernel( // If partitioning is enabled, store the max logit and exp_sum. if (USE_PARTITIONING && thread_idx == 0) { - float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; + float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; *max_logits_ptr = qk_max; - float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; *exp_sums_ptr = exp_sum; } @@ -304,7 +318,8 @@ __device__ void paged_attention_kernel( constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. float accs[NUM_ROWS_PER_THREAD]; @@ -315,18 +330,21 @@ __device__ void paged_attention_kernel( scalar_t zero_value; zero(zero_value); - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 - // because int32 can lead to overflow when this variable is multiplied by large numbers - // (e.g., kv_block_stride). - const int64_t physical_block_number = static_cast(block_table[block_idx]); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; L_vec logits_vec; - from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); + from_float(logits_vec, *reinterpret_cast(logits + token_idx - + start_token_idx)); - const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride; + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -337,14 +355,17 @@ __device__ void paged_attention_kernel( if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { v_vec = *reinterpret_cast(v_ptr + offset); } else { - V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); + V_quant_vec v_quant_vec = + *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. - v_vec = fp8::scaled_convert(v_quant_vec, kv_scale); + v_vec = fp8::scaled_convert(v_quant_vec, + kv_scale); } if (block_idx == num_seq_blocks - 1) { - // NOTE(woosuk): When v_vec contains the tokens that are out of the context, - // we should explicitly zero out the values since they may contain NaNs. - // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + // NOTE(woosuk): When v_vec contains the tokens that are out of the + // context, we should explicitly zero out the values since they may + // contain NaNs. See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); #pragma unroll for (int j = 0; j < V_VEC_SIZE; j++) { @@ -367,8 +388,8 @@ __device__ void paged_attention_kernel( accs[i] = acc; } - // NOTE(woosuk): A barrier is required because the shared memory space for logits - // is reused for the output. + // NOTE(woosuk): A barrier is required because the shared memory space for + // logits is reused for the output. __syncthreads(); // Perform reduction across warps. @@ -405,9 +426,9 @@ __device__ void paged_attention_kernel( // Write the final output. if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE - + partition_idx * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -419,79 +440,75 @@ __device__ void paged_attention_kernel( } // Grid: (num_heads, num_seqs, 1). -template< - typename scalar_t, - typename cache_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - vllm::Fp8KVCacheDataType KV_DTYPE> +template __global__ void paged_attention_v1_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const float kv_scale) { - paged_attention_kernel( - /* exp_sums */ nullptr, /* max_logits */ nullptr, - out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float kv_scale) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, + v_cache, num_kv_heads, scale, block_tables, seq_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, + kv_head_stride, kv_scale); } // Grid: (num_heads, num_seqs, max_num_partitions). -template< - typename scalar_t, - typename cache_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - vllm::Fp8KVCacheDataType KV_DTYPE, - int PARTITION_SIZE> +template __global__ void paged_attention_v2_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const float kv_scale) { - paged_attention_kernel( - exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, - q_stride, kv_block_stride, kv_head_stride, kv_scale); + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float kv_scale) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, + kv_block_stride, kv_head_stride, kv_scale); } // Grid: (num_heads, num_seqs). -template< - typename scalar_t, - int HEAD_SIZE, - int NUM_THREADS, - int PARTITION_SIZE> +template __global__ void paged_attention_v2_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_partitions) { + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_partitions) { const int num_heads = gridDim.x; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; @@ -499,9 +516,11 @@ __global__ void paged_attention_v2_reduce_kernel( const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); if (num_partitions == 1) { // No need to reduce. Only copy tmp_out to out. - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { out_ptr[i] = tmp_out_ptr[i]; } @@ -520,8 +539,9 @@ __global__ void paged_attention_v2_reduce_kernel( // Load max logits to shared memory. float* shared_max_logits = reinterpret_cast(shared_mem); - const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float max_logit = -FLT_MAX; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { const float l = max_logits_ptr[i]; @@ -550,9 +570,11 @@ __global__ void paged_attention_v2_reduce_kernel( max_logit = VLLM_SHFL_SYNC(max_logit, 0); // Load rescaled exp sums to shared memory. - float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); - const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + float* shared_exp_sums = + reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float global_exp_sum = 0.0f; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { float l = shared_max_logits[i]; @@ -565,61 +587,45 @@ __global__ void paged_attention_v2_reduce_kernel( const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); // Aggregate tmp_out to out. - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; #pragma unroll for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { float acc = 0.0f; for (int j = 0; j < num_partitions; ++j) { - acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * + inv_global_exp_sum; } from_float(out_ptr[i], acc); } } -} // namespace vllm - -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel), shared_mem_size); \ - vllm::paged_attention_v1_kernel<<>>( \ - out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - seq_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride, \ - kv_scale); +} // namespace vllm + +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel< \ + T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + kv_scale); // TODO(woosuk): Tune NUM_THREADS. -template< - typename T, - typename CACHE_T, - int BLOCK_SIZE, - vllm::Fp8KVCacheDataType KV_DTYPE, - int NUM_THREADS = 128> +template void paged_attention_v1_launcher( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& seq_lens, - int max_seq_len, - const c10::optional& alibi_slopes, - float kv_scale) { + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float kv_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -632,9 +638,10 @@ void paged_attention_v1_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); @@ -644,7 +651,8 @@ void paged_attention_v1_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; int logits_size = padded_max_seq_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len @@ -683,19 +691,10 @@ void paged_attention_v1_launcher( } } -#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ - paged_attention_v1_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - seq_lens, \ - max_seq_len, \ - alibi_slopes, \ - kv_scale); +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ + paged_attention_v1_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + seq_lens, max_seq_len, alibi_slopes, kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. @@ -716,74 +715,45 @@ void paged_attention_v1_launcher( } void paged_attention_v1( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& seq_lens, // [num_seqs] - int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale) { - - DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V1_LAUNCHER_BLOCK_SIZE) -} - -#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ - <<>>( \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - seq_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride, \ - kv_scale); \ - vllm::paged_attention_v2_reduce_kernel \ - <<>>( \ - out_ptr, \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - seq_lens_ptr, \ - max_num_partitions); - -template< - typename T, - typename CACHE_T, - int BLOCK_SIZE, - vllm::Fp8KVCacheDataType KV_DTYPE, - int NUM_THREADS = 128, - int PARTITION_SIZE = 512> + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int num_kv_heads, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int block_size, int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale){ + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V1_LAUNCHER_BLOCK_SIZE)} +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ + value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, kv_scale); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ + max_num_partitions); + +template void paged_attention_v2_launcher( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& seq_lens, - int max_seq_len, - const c10::optional& alibi_slopes, - float kv_scale) { + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float kv_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -796,9 +766,10 @@ void paged_attention_v2_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); @@ -853,59 +824,50 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ - paged_attention_v2_launcher( \ - out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - seq_lens, \ - max_seq_len, \ - alibi_slopes, \ - kv_scale); +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ + paged_attention_v2_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ + kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ - switch (block_size) { \ - case 8: \ - CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ - break; \ - case 16: \ - CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ - break; \ - case 32: \ - CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v2( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& seq_lens, // [num_seqs] - int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale) { - DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int num_kv_heads, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int block_size, int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale) { + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V2_LAUNCHER_BLOCK_SIZE) } #undef WARP_SIZE diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index ff64c4bd8f80c..cdcee42748998 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -26,7 +27,7 @@ namespace vllm { // Q*K^T operation. -template +template inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { using A_vec = typename FloatVec::Type; // Compute the parallel products for Q*K^T (treat vector lanes separately). @@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { return qk; } -template +template struct Qk_dot { - template + template static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { return qk_dot_(q, k); } }; -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 31e0cee01d2e1..3cdcb95e08099 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -28,8 +30,8 @@ #include #include - typedef __hip_bfloat162 __nv_bfloat162; - typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat162 __nv_bfloat162; +typedef __hip_bfloat16 __nv_bfloat16; #endif #include @@ -50,37 +52,37 @@ struct bf16_8_t { }; // BF16 vector types for Q, K, V. -template<> +template <> struct Vec<__nv_bfloat16, 1> { using Type = __nv_bfloat16; }; -template<> +template <> struct Vec<__nv_bfloat16, 2> { using Type = __nv_bfloat162; }; -template<> +template <> struct Vec<__nv_bfloat16, 4> { using Type = bf16_4_t; }; -template<> +template <> struct Vec<__nv_bfloat16, 8> { using Type = bf16_8_t; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec<__nv_bfloat16> { using Type = float; }; -template<> +template <> struct FloatVec<__nv_bfloat162> { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = Float4_; }; -template<> +template <> struct FloatVec { using Type = Float8_; }; @@ -108,9 +110,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { assert(false); #else #ifndef USE_ROCM - return a + b; + return a + b; #else - return __hadd(a, b); + return __hadd(a, b); #endif #endif } @@ -161,7 +163,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { } // Vector multiplication. -template<> +template <> inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); @@ -170,7 +172,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { #endif } -template<> +template <> inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); @@ -179,12 +181,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { #endif } -template<> +template <> inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); } -template<> +template <> inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { bf16_4_t c; c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); @@ -192,7 +194,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { return c; } -template<> +template <> inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { __nv_bfloat162 s = bf162bf162(a); bf16_4_t c; @@ -201,7 +203,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { return c; } -template<> +template <> inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { bf16_8_t c; c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); @@ -211,7 +213,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { return c; } -template<> +template <> inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { __nv_bfloat162 s = bf162bf162(a); bf16_8_t c; @@ -222,26 +224,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { return c; } -template<> +template <> inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { float fa = __bfloat162float(a); float fb = __bfloat162float(b); return fa * fb; } -template<> +template <> inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { float2 fa = bf1622float2(a); float2 fb = bf1622float2(b); return mul(fa, fb); } -template<> +template <> inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { return mul(bf162bf162(a), b); } -template<> +template <> inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { Float4_ fc; fc.x = mul(a.x, b.x); @@ -249,7 +251,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { return fc; } -template<> +template <> inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { __nv_bfloat162 s = bf162bf162(a); Float4_ fc; @@ -258,7 +260,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { Float8_ fc; fc.x = mul(a.x, b.x); @@ -268,7 +270,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { __nv_bfloat162 s = bf162bf162(a); Float8_ fc; @@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { } // Vector fused multiply-add. -inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, + __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else @@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf #endif } -inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, + __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else @@ -379,23 +383,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { } // Vector sum. -template<> +template <> inline __device__ float sum(__nv_bfloat16 v) { return __bfloat162float(v); } -template<> +template <> inline __device__ float sum(__nv_bfloat162 v) { float2 vf = bf1622float2(v); return vf.x + vf.y; } -template<> +template <> inline __device__ float sum(bf16_4_t v) { return sum(v.x) + sum(v.y); } -template<> +template <> inline __device__ float sum(bf16_8_t v) { return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); } @@ -448,4 +452,4 @@ inline __device__ void zero(__nv_bfloat16& dst) { #endif } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index d3271e69cd69d..3a1815f0ed4fc 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -30,37 +32,37 @@ namespace vllm { // FP16 vector types for Q, K, V. -template<> +template <> struct Vec { using Type = uint16_t; }; -template<> +template <> struct Vec { using Type = uint32_t; }; -template<> +template <> struct Vec { using Type = uint2; }; -template<> +template <> struct Vec { using Type = uint4; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec { using Type = float; }; -template<> +template <> struct FloatVec { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = Float4_; }; -template<> +template <> struct FloatVec { using Type = Float8_; }; @@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) { return b; #else union { - uint32_t u32; - uint16_t u16[2]; + uint32_t u32; + uint16_t u16[2]; } tmp; tmp.u16[0] = a; tmp.u16[1] = a; @@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) { } tmp; #ifndef USE_ROCM #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" + : "=r"(tmp.u32) + : "f"(f.y), "f"(f.x)); #else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); #endif #else tmp.u16[0] = float_to_half(f.x); @@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { } // Vector multiplication. -template<> +template <> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { uint16_t c; #ifndef USE_ROCM @@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) { return c; } -template<> +template <> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { uint32_t c; #ifndef USE_ROCM @@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) { return c; } -template<> +template <> inline __device__ uint32_t mul(uint16_t a, uint32_t b) { return mul(h0_h0(a), b); } -template<> +template <> inline __device__ uint2 mul(uint2 a, uint2 b) { uint2 c; c.x = mul(a.x, b.x); @@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) { return c; } -template<> +template <> inline __device__ uint2 mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); uint2 c; @@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) { return c; } -template<> +template <> inline __device__ uint4 mul(uint4 a, uint4 b) { uint4 c; c.x = mul(a.x, b.x); @@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) { return c; } -template<> +template <> inline __device__ uint4 mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); uint4 c; @@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) { return c; } -template<> +template <> inline __device__ float mul(uint16_t a, uint16_t b) { float fa = half_to_float(a); float fb = half_to_float(b); return fa * fb; } -template<> +template <> inline __device__ float2 mul(uint32_t a, uint32_t b) { float2 fa = half2_to_float2(a); float2 fb = half2_to_float2(b); return mul(fa, fb); } -template<> +template <> inline __device__ float2 mul(uint16_t a, uint32_t b) { return mul(h0_h0(a), b); } -template<> +template <> inline __device__ Float4_ mul(uint2 a, uint2 b) { Float4_ fc; fc.x = mul(a.x, b.x); @@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) { return fc; } -template<> +template <> inline __device__ Float4_ mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); Float4_ fc; @@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(uint4 a, uint4 b) { Float8_ fc; fc.x = mul(a.x, b.x); @@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); Float8_ fc; @@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { uint32_t d; #ifndef USE_ROCM - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); #else - asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); + asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" + : "=v"(d) + : "v"(a), "v"(b), "v"(c)); #endif return d; } @@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { } // Vector sum. -template<> +template <> inline __device__ float sum(uint16_t v) { return half_to_float(v); } -template<> +template <> inline __device__ float sum(uint32_t v) { float2 tmp = half2_to_float2(v); return tmp.x + tmp.y; } -template<> +template <> inline __device__ float sum(uint2 v) { uint32_t c = add(v.x, v.y); return sum(c); } -template<> +template <> inline __device__ float sum(uint4 v) { uint32_t c = add(v.x, v.y); c = add(c, v.z); @@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) { } // From float16 to float32. -inline __device__ float to_float(uint16_t u) { - return half_to_float(u); -} +inline __device__ float to_float(uint16_t u) { return half_to_float(u); } -inline __device__ float2 to_float(uint32_t u) { - return half2_to_float2(u); -} +inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); } inline __device__ Float4_ to_float(uint2 u) { Float4_ tmp; @@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) { } // Zero-out a variable. -inline __device__ void zero(uint16_t& dst) { - dst = uint16_t(0); -} +inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh index b200d2d226eb0..7c6a686db3ba9 100644 --- a/csrc/attention/dtype_float32.cuh +++ b/csrc/attention/dtype_float32.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -38,37 +40,35 @@ struct Float8_ { }; // FP32 vector types for Q, K, V. -template<> +template <> struct Vec { using Type = float; }; -template<> +template <> struct Vec { using Type = float2; }; -template<> +template <> struct Vec { using Type = float4; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec { using Type = float; }; -template<> +template <> struct FloatVec { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = float4; }; // Vector addition. -inline __device__ float add(float a, float b) { - return a + b; -} +inline __device__ float add(float a, float b) { return a + b; } inline __device__ float2 add(float2 a, float2 b) { float2 c; @@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) { } // Vector multiplication. -template<> +template <> inline __device__ float mul(float a, float b) { return a * b; } -template<> +template <> inline __device__ float2 mul(float2 a, float2 b) { float2 c; c.x = a.x * b.x; @@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) { return c; } -template<> +template <> inline __device__ float2 mul(float a, float2 b) { float2 c; c.x = a * b.x; @@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) { return c; } -template<> +template <> inline __device__ float4 mul(float4 a, float4 b) { float4 c; c.x = a.x * b.x; @@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) { return c; } -template<> +template <> inline __device__ float4 mul(float a, float4 b) { float4 c; c.x = a * b.x; @@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) { } // Vector fused multiply-add. -inline __device__ float fma(float a, float b, float c) { - return a * b + c; -} +inline __device__ float fma(float a, float b, float c) { return a * b + c; } inline __device__ float2 fma(float2 a, float2 b, float2 c) { float2 d; @@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { } // Vector sum. -template<> +template <> inline __device__ float sum(float v) { return v; } -template<> +template <> inline __device__ float sum(float2 v) { return v.x + v.y; } -template<> +template <> inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } -template<> +template <> inline __device__ float sum(Float4_ v) { return v.x.x + v.x.y + v.y.x + v.y.y; } -template<> +template <> inline __device__ float sum(Float8_ v) { return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; } // Vector dot product. -inline __device__ float dot(float a, float b) { - return a * b; -} +inline __device__ float dot(float a, float b) { return a * b; } inline __device__ float dot(float2 a, float2 b) { float2 c = mul(a, b); @@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) { } // From float to float. -inline __device__ void from_float(float& dst, float src) { - dst = src; -} +inline __device__ void from_float(float& dst, float src) { dst = src; } -inline __device__ void from_float(float2& dst, float2 src) { - dst = src; -} +inline __device__ void from_float(float2& dst, float2 src) { dst = src; } -inline __device__ void from_float(float4& dst, float4 src) { - dst = src; -} +inline __device__ void from_float(float4& dst, float4 src) { dst = src; } // From float to float. -inline __device__ float to_float(float u) { - return u; -} +inline __device__ float to_float(float u) { return u; } -inline __device__ float2 to_float(float2 u) { - return u; -} +inline __device__ float2 to_float(float2 u) { return u; } -inline __device__ float4 to_float(float4 u) { - return u; -} +inline __device__ float4 to_float(float4 u) { return u; } -inline __device__ Float4_ to_float(Float4_ u) { - return u; -} +inline __device__ Float4_ to_float(Float4_ u) { return u; } -inline __device__ Float8_ to_float(Float8_ u) { - return u; -} +inline __device__ Float8_ to_float(Float8_ u) { return u; } // Zero-out a variable. -inline __device__ void zero(float& dst) { - dst = 0.f; -} +inline __device__ void zero(float& dst) { dst = 0.f; } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_fp8.cuh b/csrc/attention/dtype_fp8.cuh index 2b32ce372a64f..e714e321b0beb 100644 --- a/csrc/attention/dtype_fp8.cuh +++ b/csrc/attention/dtype_fp8.cuh @@ -4,38 +4,38 @@ #include #ifdef ENABLE_FP8 -#ifndef USE_ROCM -#include -#endif // USE_ROCM -#endif // ENABLE_FP8 + #ifndef USE_ROCM + #include + #endif // USE_ROCM +#endif // ENABLE_FP8 namespace vllm { enum class Fp8KVCacheDataType { - kAuto = 0, - kFp8E4M3 = 1, - kFp8E5M2 = 2, + kAuto = 0, + kFp8E4M3 = 1, + kFp8E5M2 = 2, }; // fp8 vector types for quantization of kv cache -template<> +template <> struct Vec { - using Type = uint8_t; + using Type = uint8_t; }; -template<> +template <> struct Vec { - using Type = uint16_t; + using Type = uint16_t; }; -template<> +template <> struct Vec { - using Type = uint32_t; + using Type = uint32_t; }; -template<> +template <> struct Vec { - using Type = uint2; + using Type = uint2; }; -} // namespace vllm +} // namespace vllm diff --git a/csrc/cache.h b/csrc/cache.h index 8c176c452425e..435ae3e57f555 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -5,36 +5,24 @@ #include #include -void swap_blocks( - torch::Tensor& src, - torch::Tensor& dst, - const torch::Tensor& block_mapping); +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping); -void copy_blocks( - std::vector& key_caches, - std::vector& value_caches, - const torch::Tensor& block_mapping); +void copy_blocks(std::vector& key_caches, + std::vector& value_caches, + const torch::Tensor& block_mapping); -void reshape_and_cache( - torch::Tensor& key, - torch::Tensor& value, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, - const float kv_scale); +void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, const float kv_scale); -void reshape_and_cache_flash( - torch::Tensor& key, - torch::Tensor& value, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype); +void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype); // Just for unittest -void convert_fp8( - torch::Tensor& dst_cache, - torch::Tensor& src_cache, - const float scale, - const std::string& kv_cache_dtype); +void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, + const float scale, const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index e5b74da6ad068..d924ac39b89ca 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -6,9 +6,9 @@ #include "dispatch_utils.h" #ifdef USE_ROCM -#include "quantization/fp8/amd/quant_utils.cuh" + #include "quantization/fp8/amd/quant_utils.cuh" #else -#include "quantization/fp8/nvidia/quant_utils.cuh" + #include "quantization/fp8/nvidia/quant_utils.cuh" #endif #include @@ -18,20 +18,17 @@ #ifdef USE_ROCM #include - typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat16 __nv_bfloat16; #endif -void swap_blocks( - torch::Tensor& src, - torch::Tensor& dst, - const torch::Tensor& block_mapping) { +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { torch::Device src_device = src.device(); torch::Device dst_device = dst.device(); cudaMemcpyKind memcpy_type; if (src_device.is_cuda() && dst_device.is_cuda()) { - TORCH_CHECK( - src_device.index() == dst_device.index(), - "src and dst must be on the same GPU"); + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); memcpy_type = cudaMemcpyDeviceToDevice; } else if (src_device.is_cuda() && dst_device.is_cpu()) { memcpy_type = cudaMemcpyDeviceToHost; @@ -41,16 +38,17 @@ void swap_blocks( TORCH_CHECK(false, "Invalid device combination"); } - // NOTE(youkaichao): keep in mind that `block_mapping` should be + // NOTE(youkaichao): keep in mind that `block_mapping` should be // a cpu tensor, otherwise every `item` call will require a gpu-cpu // synchronization. TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); - char *src_ptr = static_cast(src.data_ptr()); - char *dst_ptr = static_cast(dst.data_ptr()); + char* src_ptr = static_cast(src.data_ptr()); + char* dst_ptr = static_cast(dst.data_ptr()); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); - const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); + const at::cuda::OptionalCUDAGuard device_guard( + src_device.is_cuda() ? src_device : dst_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // NOTE(woosuk): This can be slow if the number of blocks is large. const int64_t num_blocks = block_mapping.size(0); @@ -59,29 +57,25 @@ void swap_blocks( int64_t dst_block_number = block_mapping[i][1].item(); int64_t src_offset = src_block_number * block_size_in_bytes; int64_t dst_offset = dst_block_number * block_size_in_bytes; - cudaMemcpyAsync( - dst_ptr + dst_offset, - src_ptr + src_offset, - block_size_in_bytes, - memcpy_type, - stream); + cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset, + block_size_in_bytes, memcpy_type, stream); } } namespace vllm { // Grid: (num_layers, num_pairs) -template -__global__ void copy_blocks_kernel( - int64_t* key_cache_ptrs, - int64_t* value_cache_ptrs, - const int64_t* __restrict__ block_mapping, - const int numel_per_block) { +template +__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, + int64_t* value_cache_ptrs, + const int64_t* __restrict__ block_mapping, + const int numel_per_block) { const int layer_idx = blockIdx.x; const int pair_idx = blockIdx.y; scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); - scalar_t* value_cache = reinterpret_cast(value_cache_ptrs[layer_idx]); + scalar_t* value_cache = + reinterpret_cast(value_cache_ptrs[layer_idx]); int64_t src_block_number = block_mapping[2 * pair_idx]; int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; @@ -99,12 +93,11 @@ __global__ void copy_blocks_kernel( } } -} // namespace vllm +} // namespace vllm -void copy_blocks( - std::vector& key_caches, - std::vector& value_caches, - const torch::Tensor& block_mapping) { +void copy_blocks(std::vector& key_caches, + std::vector& value_caches, + const torch::Tensor& block_mapping) { int num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { @@ -118,8 +111,10 @@ void copy_blocks( int64_t key_cache_ptrs[num_layers]; int64_t value_cache_ptrs[num_layers]; for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { - key_cache_ptrs[layer_idx] = reinterpret_cast(key_caches[layer_idx].data_ptr()); - value_cache_ptrs[layer_idx] = reinterpret_cast(value_caches[layer_idx].data_ptr()); + key_cache_ptrs[layer_idx] = + reinterpret_cast(key_caches[layer_idx].data_ptr()); + value_cache_ptrs[layer_idx] = + reinterpret_cast(value_caches[layer_idx].data_ptr()); } // block_mapping is a 2D tensor with shape (num_pairs, 2). @@ -127,10 +122,12 @@ void copy_blocks( // Move the data structures to the GPU. // NOTE: This synchronizes the CPU and GPU. - torch::Tensor key_cache_ptrs_tensor = torch::from_blob( - key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); - torch::Tensor value_cache_ptrs_tensor = torch::from_blob( - value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); + torch::Tensor key_cache_ptrs_tensor = + torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64) + .to(cache_device); + torch::Tensor value_cache_ptrs_tensor = + torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64) + .to(cache_device); // Launch the kernel. const int numel_per_block = key_caches[0][0].numel(); @@ -139,31 +136,28 @@ void copy_blocks( const at::cuda::OptionalCUDAGuard device_guard(cache_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( - key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { - vllm::copy_blocks_kernel<<>>( - key_cache_ptrs_tensor.data_ptr(), - value_cache_ptrs_tensor.data_ptr(), - block_mapping.data_ptr(), - numel_per_block); - })); + key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { + vllm::copy_blocks_kernel<<>>( + key_cache_ptrs_tensor.data_ptr(), + value_cache_ptrs_tensor.data_ptr(), + block_mapping.data_ptr(), numel_per_block); + })); } namespace vllm { -template +template __global__ void reshape_and_cache_kernel( - const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] - const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int key_stride, - const int value_stride, - const int num_heads, - const int head_size, - const int block_size, - const int x, - const float kv_scale) { + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, + // block_size, x] + cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, + // block_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, const int value_stride, const int num_heads, + const int head_size, const int block_size, const int x, + const float kv_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -184,40 +178,39 @@ __global__ void reshape_and_cache_kernel( const int x_idx = head_offset / x; const int x_offset = head_offset % x; - const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x - + head_idx * (head_size / x) * block_size * x - + x_idx * block_size * x - + block_offset * x - + x_offset; - const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size - + head_idx * head_size * block_size - + head_offset * block_size - + block_offset; + const int64_t tgt_key_idx = + block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + + block_offset * x + x_offset; + const int64_t tgt_value_idx = + block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + head_offset * block_size + + block_offset; scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_value = value[src_value_idx]; if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { key_cache[tgt_key_idx] = tgt_key; value_cache[tgt_value_idx] = tgt_value; } else { - key_cache[tgt_key_idx] = fp8::scaled_convert(tgt_key, kv_scale); - value_cache[tgt_value_idx] = fp8::scaled_convert(tgt_value, kv_scale); + key_cache[tgt_key_idx] = + fp8::scaled_convert(tgt_key, kv_scale); + value_cache[tgt_value_idx] = + fp8::scaled_convert(tgt_value, kv_scale); } } } -template +template __global__ void reshape_and_cache_flash_kernel( - const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] - const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size] - scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int block_stride, - const int key_stride, - const int value_stride, - const int num_heads, - const int head_size, - const int block_size) { + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, + // head_size] + scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, + // head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, const int key_stride, const int value_stride, + const int num_heads, const int head_size, const int block_size) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; // NOTE: slot_idx can be -1 if the token is padded @@ -232,43 +225,37 @@ __global__ void reshape_and_cache_flash_kernel( const int64_t src_value_idx = token_idx * value_stride + i; const int head_idx = i / head_size; const int head_offset = i % head_size; - const int64_t tgt_value_idx = block_idx * block_stride - + block_offset * num_heads * head_size - + head_idx * head_size - + head_offset; + const int64_t tgt_value_idx = block_idx * block_stride + + block_offset * num_heads * head_size + + head_idx * head_size + head_offset; k_cache[tgt_value_idx] = key[src_key_idx]; v_cache[tgt_value_idx] = value[src_value_idx]; } } -} // namespace vllm +} // namespace vllm // KV_T is the stored data type of kv-cache. // CACHE_T is the data type of key and value tensors. // KV_DTYPE is the real data type of kv-cache. -#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ - vllm::reshape_and_cache_kernel<<>>( \ - reinterpret_cast(key.data_ptr()), \ - reinterpret_cast(value.data_ptr()), \ - reinterpret_cast(key_cache.data_ptr()), \ - reinterpret_cast(value_cache.data_ptr()), \ - slot_mapping.data_ptr(), \ - key_stride, \ - value_stride, \ - num_heads, \ - head_size, \ - block_size, \ - x, \ - kv_scale); +#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), key_stride, value_stride, \ + num_heads, head_size, block_size, x, kv_scale); void reshape_and_cache( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, - const float kv_scale) -{ + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype, const float kv_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); @@ -283,17 +270,17 @@ void reshape_and_cache( const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE) + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, + CALL_RESHAPE_AND_CACHE) } void reshape_and_cache_flash( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype) -{ + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype) { // FIXME: only support auto datatype, does not support fp8 if (kv_cache_dtype != "auto") { TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); @@ -313,62 +300,47 @@ void reshape_and_cache_flash( const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - key.scalar_type(), - "reshape_and_cache_flash", - [&] { - vllm::reshape_and_cache_flash_kernel<<>>( - key.data_ptr(), - value.data_ptr(), - k_cache.data_ptr(), - v_cache.data_ptr(), - slot_mapping.data_ptr(), - block_stride, - key_stride, - value_stride, - num_heads, - head_size, - block_size); - }); + key.scalar_type(), "reshape_and_cache_flash", [&] { + vllm::reshape_and_cache_flash_kernel + <<>>( + key.data_ptr(), value.data_ptr(), + k_cache.data_ptr(), v_cache.data_ptr(), + slot_mapping.data_ptr(), block_stride, key_stride, + value_stride, num_heads, head_size, block_size); + }); } namespace vllm { -template -__global__ void convert_fp8_kernel( - const Tin* __restrict__ src_cache, - Tout* __restrict__ dst_cache, - const float kv_scale, - const int64_t block_stride) { +template +__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, + Tout* __restrict__ dst_cache, + const float kv_scale, + const int64_t block_stride) { const int64_t block_idx = blockIdx.x; for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { int64_t idx = block_idx * block_stride + i; - dst_cache[idx] = fp8::scaled_convert(src_cache[idx], kv_scale); + dst_cache[idx] = + fp8::scaled_convert(src_cache[idx], kv_scale); } } -} // namespace vllm +} // namespace vllm -#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ - vllm::convert_fp8_kernel<<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst_cache.data_ptr()), \ - kv_scale, \ - block_stride); +#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ + vllm::convert_fp8_kernel<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst_cache.data_ptr()), kv_scale, block_stride); // Only for testing. -void convert_fp8( - torch::Tensor& dst_cache, - torch::Tensor& src_cache, - const float kv_scale, - const std::string& kv_cache_dtype) -{ +void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, + const float kv_scale, const std::string& kv_cache_dtype) { torch::Device src_device = src_cache.device(); torch::Device dst_device = dst_cache.device(); TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") - TORCH_CHECK( - src_device.index() == dst_device.index(), - "src and dst must be on the same GPU"); + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); at::cuda::OptionalCUDAGuard device_guard(src_device); int64_t num_blocks = src_cache.size(0); @@ -398,13 +370,15 @@ void convert_fp8( } else if (src_cache.dtype() == at::ScalarType::Half) { CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); } else if (src_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kFp8E4M3); + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, + vllm::Fp8KVCacheDataType::kFp8E4M3); } else if (dst_cache.dtype() == at::ScalarType::Float) { CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); } else if (dst_cache.dtype() == at::ScalarType::Half) { CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); } } else { TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); diff --git a/csrc/cpu/activation.cpp b/csrc/cpu/activation.cpp index 1bd24eb79d129..becd2ac42f17a 100644 --- a/csrc/cpu/activation.cpp +++ b/csrc/cpu/activation.cpp @@ -1,10 +1,10 @@ #include "cpu_types.hpp" namespace { -template -void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, - scalar_t *__restrict__ output) { +void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input, + scalar_t* __restrict__ output) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); @@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, } } -FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 zeros(0.0); const vec_op::FP32Vec8 ones(1.0); return x / (ones + (zeros - x).exp()); } -FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w2(0.044715f); @@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { return w3 * x * (ones + t); } -FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w2(0.044715f); @@ -59,14 +59,14 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { return w3 * x * (ones + t); } -FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(M_SQRT1_2); const vec_op::FP32Vec8 w2(0.5); return x * w2 * (ones + (x * w1).er()); } -FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); const vec_op::FP32Vec8 w2(0.5); @@ -75,40 +75,36 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); return x * w2 * (ones + inner.tanh()); } -}; // namespace +}; // namespace -void silu_and_mul(torch::Tensor &out, torch::Tensor &input) { +void silu_and_mul(torch::Tensor& out, torch::Tensor& input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "silu_and_mul_impl", [&] { - CPU_KERNEL_GUARD_IN(silu_and_mul_impl) - activation_kernel(num_tokens, d, - input.data_ptr(), - out.data_ptr()); - CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(silu_and_mul_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) + }); } -void gelu_and_mul(torch::Tensor &out, // [..., d] - torch::Tensor &input) // [..., 2 * d] +void gelu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "gelu_and_mul_impl", [&] { - CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) - activation_kernel(num_tokens, d, - input.data_ptr(), - out.data_ptr()); - CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) + }); } -void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] - torch::Tensor &input) // [..., 2 * d] +void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; @@ -123,7 +119,7 @@ void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] }); } -void gelu_new(torch::Tensor &out, torch::Tensor &input) { +void gelu_new(torch::Tensor& out, torch::Tensor& input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1); @@ -135,7 +131,7 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input) { }); } -void gelu_fast(torch::Tensor &out, torch::Tensor &input) { +void gelu_fast(torch::Tensor& out, torch::Tensor& input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1); diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index c1d765be05598..54df69b7379d6 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -2,7 +2,8 @@ namespace { -template struct KernelVecType { +template +struct KernelVecType { using q_load_vec_type = void; using q_vec_type = void; using k_load_vec_type = void; @@ -11,7 +12,8 @@ template struct KernelVecType { using v_load_vec_type = void; }; -template <> struct KernelVecType { +template <> +struct KernelVecType { using q_load_vec_type = vec_op::FP32Vec4; using q_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::FP32Vec16; @@ -21,7 +23,8 @@ template <> struct KernelVecType { }; #ifdef __AVX512BF16__ -template <> struct KernelVecType { +template <> +struct KernelVecType { using q_load_vec_type = vec_op::BF16Vec8; using q_vec_type = vec_op::BF16Vec32; using k_load_vec_type = vec_op::BF16Vec32; @@ -30,7 +33,8 @@ template <> struct KernelVecType { using v_load_vec_type = vec_op::BF16Vec16; }; #else -template <> struct KernelVecType { +template <> +struct KernelVecType { using q_load_vec_type = vec_op::BF16Vec8; using q_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::BF16Vec16; @@ -41,7 +45,7 @@ template <> struct KernelVecType { #endif template -FORCE_INLINE std::pair reduceSoftmax(T *data, const int size, +FORCE_INLINE std::pair reduceSoftmax(T* data, const int size, const int capacity) { T max = data[0]; for (int i = 1; i < size; ++i) { @@ -67,10 +71,11 @@ FORCE_INLINE std::pair reduceSoftmax(T *data, const int size, } template -FORCE_INLINE std::pair -reduceSoftmaxAlibi(T *data, const int size, const int capacity, - const float alibi_slope, const int start_index, - const int seq_len) { +FORCE_INLINE std::pair reduceSoftmaxAlibi(T* data, const int size, + const int capacity, + const float alibi_slope, + const int start_index, + const int seq_len) { data[0] += alibi_slope * (start_index - seq_len + 1); T max = data[0]; for (int i = 1; i < size; ++i) { @@ -98,7 +103,7 @@ reduceSoftmaxAlibi(T *data, const int size, const int capacity, } template -FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data, +FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data, const int size) { T max = max_data[0]; for (int i = 1; i < size; ++i) { @@ -132,9 +137,9 @@ struct reduceQKBlockKernel { static_assert(k_load_vec_type::get_elem_num() % x == 0); static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); - FORCE_INLINE static void call(const scalar_t *__restrict__ q, - const scalar_t *__restrict__ k_block, - float *__restrict__ logits, float scale, + FORCE_INLINE static void call(const scalar_t* __restrict__ q, + const scalar_t* __restrict__ k_block, + float* __restrict__ logits, float scale, const int token_num) { const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; @@ -196,8 +201,8 @@ struct reduceQKBlockKernel { template -FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, - acc_t &&acc) { +FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block, + acc_t&& acc) { using v_load_vec_type = typename KernelVecType::v_load_vec_type; constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); static_assert(BLOCK_SIZE == ELEM_NUM); @@ -209,27 +214,27 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; }); } -}; // namespace +}; // namespace // Paged attention v1 namespace { template struct paged_attention_v1_impl { - static void - call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, + static void call( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] - const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, // head_size, block_size] - const int num_kv_heads, const float scale, - const int - *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float *__restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const int num_seqs, const int num_heads) { + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, + // max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const int num_seqs, const int num_heads) { constexpr int x = 16 / sizeof(scalar_t); const int num_queries_per_kv = num_heads / num_kv_heads; @@ -243,32 +248,31 @@ struct paged_attention_v1_impl { size_t logits_bytes = parallel_work_item_num * max_seq_len_padded * sizeof(float); - float *logits = (float *)std::aligned_alloc( - 64, logits_bytes); // Cacheline alignment for each context token. - // [parallel_work_item_num, max_seq_len_padded] + float* logits = (float*)std::aligned_alloc( + 64, logits_bytes); // Cacheline alignment for each context token. + // [parallel_work_item_num, max_seq_len_padded] #pragma omp parallel for collapse(2) schedule(dynamic, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { int seq_len = seq_lens[seq_idx]; - const int *seq_block_table = + const int* seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx; const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t *__restrict__ q_vec_ptr = + const scalar_t* __restrict__ q_vec_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; - const int last_block_token_num = - seq_len - (block_num - 1) * BLOCK_SIZE; - float *__restrict__ thread_block_logits = + const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE; + float* __restrict__ thread_block_logits = logits + omp_get_thread_num() * max_seq_len_padded; // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t *__restrict__ k_block_cache_ptr = + const scalar_t* __restrict__ k_block_cache_ptr = k_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride; - float *__restrict__ head_block_logits = + float* __restrict__ head_block_logits = thread_block_logits + block_idx * BLOCK_SIZE; reduceQKBlockKernel::call( @@ -282,8 +286,7 @@ struct paged_attention_v1_impl { block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, seq_len); } else { - reduceSoftmax(thread_block_logits, seq_len, - block_num * BLOCK_SIZE); + reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE); } // Compute value @@ -293,14 +296,14 @@ struct paged_attention_v1_impl { for (int head_part_idx = 0; head_part_idx < head_partition_num; ++head_part_idx) { vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t *__restrict__ out_ptr = + scalar_t* __restrict__ out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_part_idx * head_elem_num_per_partition; for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const float *__restrict__ prob_vec_ptr = + const float* __restrict__ prob_vec_ptr = thread_block_logits + block_idx * BLOCK_SIZE; - const scalar_t *__restrict__ v_block_cache_ptr = + const scalar_t* __restrict__ v_block_cache_ptr = v_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -311,7 +314,7 @@ struct paged_attention_v1_impl { if (block_idx != block_num - 1) { const int64_t next_physical_block_idx = seq_block_table[block_idx + 1]; - const scalar_t *__restrict__ next_v_block_cache_ptr = + const scalar_t* __restrict__ next_v_block_cache_ptr = v_cache + next_physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -340,16 +343,16 @@ struct paged_attention_v1_impl { #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ paged_attention_v1_impl::call( \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ num_heads); template void paged_attention_v1_impl_launcher( - torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, - torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &seq_lens, - int max_seq_len, const c10::optional &alibi_slopes) { + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -359,67 +362,66 @@ void paged_attention_v1_impl_launcher( int kv_head_stride = key_cache.stride(1); // NOTE: alibi_slopes is optional. - const float *alibi_slopes_ptr = + const float* alibi_slopes_ptr = alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) + ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - T *out_ptr = reinterpret_cast(out.data_ptr()); - T *query_ptr = reinterpret_cast(query.data_ptr()); - T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int *block_tables_ptr = block_tables.data_ptr(); - int *seq_lens_ptr = seq_lens.data_ptr(); + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { - case 64: - LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 256: - LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; + case 64: + LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 256: + LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; } } -#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v1_impl_launcher( \ - out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ +#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v1_impl_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ seq_lens, max_seq_len, alibi_slopes); -#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V1_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V1_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } -} // namespace +} // namespace -void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, - torch::Tensor &key_cache, torch::Tensor &value_cache, +void paged_attention_v1(torch::Tensor& out, torch::Tensor& query, + torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, - torch::Tensor &seq_lens, int block_size, - int max_seq_len, - const c10::optional &alibi_slopes, - const std::string &kv_cache_dtype, float kv_scale) { + torch::Tensor& block_tables, torch::Tensor& seq_lens, + int block_size, int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", [&] { @@ -434,23 +436,24 @@ namespace { template struct paged_attention_v2_impl { static void call( - scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] - float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float - *__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] - const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] const int num_kv_heads, const float scale, - const int - *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ seq_lens, // [num_seqs] + const int* __restrict__ block_tables, // [num_seqs, + // max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, - const float *__restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, const int num_seqs, const int num_heads, const int max_num_partitions) { constexpr int x = 16 / sizeof(scalar_t); @@ -468,8 +471,7 @@ struct paged_attention_v2_impl { const int seq_len = seq_lens[seq_idx]; const int start_token_idx = partition_idx * PARTITION_SIZE; - if (start_token_idx >= seq_len) - continue; + if (start_token_idx >= seq_len) continue; const int partition_num = (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; @@ -477,15 +479,14 @@ struct paged_attention_v2_impl { const int token_num = (std::min(seq_len, start_token_idx + PARTITION_SIZE) - start_token_idx); - const int block_num = - (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; const int last_block_token_num = token_num - (block_num - 1) * BLOCK_SIZE; - const int *seq_block_table = block_tables + + const int* seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx + start_token_idx / BLOCK_SIZE; const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t *__restrict__ q_vec_ptr = + const scalar_t* __restrict__ q_vec_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; @@ -493,10 +494,10 @@ struct paged_attention_v2_impl { // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t *__restrict__ k_block_cache_ptr = + const scalar_t* __restrict__ k_block_cache_ptr = k_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride; - float *__restrict__ head_block_logits = + float* __restrict__ head_block_logits = logits + block_idx * BLOCK_SIZE; reduceQKBlockKernel::call( @@ -510,13 +511,13 @@ struct paged_attention_v2_impl { logits, token_num, block_num * BLOCK_SIZE, alibi_slopes[head_idx], start_token_idx, seq_len); } else { - max_and_sum = reduceSoftmax(logits, token_num, - block_num * BLOCK_SIZE); + max_and_sum = + reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); } - auto &&[max_logit, exp_sum] = max_and_sum; + auto&& [max_logit, exp_sum] = max_and_sum; - scalar_t *__restrict__ output_buffer = nullptr; + scalar_t* __restrict__ output_buffer = nullptr; if (!no_reduce) { auto idx = seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions + partition_idx; @@ -538,13 +539,13 @@ struct paged_attention_v2_impl { for (int head_part_idx = 0; head_part_idx < head_partition_num; ++head_part_idx) { vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t *__restrict__ out_ptr = + scalar_t* __restrict__ out_ptr = output_buffer + head_part_idx * head_elem_num_per_partition; for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const float *__restrict__ prob_vec_ptr = + const float* __restrict__ prob_vec_ptr = logits + block_idx * BLOCK_SIZE; - const scalar_t *__restrict__ v_block_cache_ptr = + const scalar_t* __restrict__ v_block_cache_ptr = v_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -555,7 +556,7 @@ struct paged_attention_v2_impl { if (block_idx != block_num - 1) { const int64_t next_physical_block_idx = seq_block_table[block_idx + 1]; - const scalar_t *__restrict__ next_v_block_cache_ptr = + const scalar_t* __restrict__ next_v_block_cache_ptr = v_cache + next_physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -587,8 +588,7 @@ struct paged_attention_v2_impl { const int partition_num = (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - if (partition_num == 1) - continue; + if (partition_num == 1) continue; reducePartitonSoftmax( max_logits + seq_idx * num_heads * max_num_partitions + @@ -603,11 +603,11 @@ struct paged_attention_v2_impl { using v_load_vec_type = typename KernelVecType::v_load_vec_type; static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); constexpr int head_elem_num_per_group = - 16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE - // didn't align with 64 bytes + 16; // Note: didn't align with the cacheline size, due to some + // HEAD_SIZE didn't align with 64 bytes static_assert(HEAD_SIZE % head_elem_num_per_group == 0); constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; - const float *__restrict__ rescale_factors = exp_sums; + const float* __restrict__ rescale_factors = exp_sums; #pragma omp parallel for collapse(3) schedule(static, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { @@ -616,17 +616,16 @@ struct paged_attention_v2_impl { const int partition_num = (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - if (partition_num == 1) - continue; + if (partition_num == 1) continue; - const float *__restrict__ seq_head_rescale_factors = + const float* __restrict__ seq_head_rescale_factors = rescale_factors + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; - const scalar_t *__restrict__ seq_head_tmp_out = + const scalar_t* __restrict__ seq_head_tmp_out = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE + group_idx * head_elem_num_per_group; - scalar_t *__restrict__ seq_head_output = + scalar_t* __restrict__ seq_head_output = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + group_idx * head_elem_num_per_group; @@ -645,21 +644,21 @@ struct paged_attention_v2_impl { } }; -#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ - paged_attention_v2_impl::call( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ - key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, num_seqs, num_heads, \ +#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v2_impl::call( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ + key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, num_seqs, num_heads, \ max_num_partitions); template void paged_attention_v2_impl_launcher( - torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, - torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, - torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size, - int max_seq_len, const c10::optional &alibi_slopes) { + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -670,72 +669,72 @@ void paged_attention_v2_impl_launcher( int max_num_partitions = exp_sums.size(-1); // NOTE: alibi_slopes is optional. - const float *alibi_slopes_ptr = + const float* alibi_slopes_ptr = alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) + ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - T *out_ptr = reinterpret_cast(out.data_ptr()); - float *exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float *max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T *tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T *query_ptr = reinterpret_cast(query.data_ptr()); - T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int *block_tables_ptr = block_tables.data_ptr(); - int *seq_lens_ptr = seq_lens.data_ptr(); + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { - case 64: - LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 256: - LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; + case 64: + LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 256: + LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; } } -#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v2_impl_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, seq_lens, block_size, \ - max_seq_len, alibi_slopes); - -#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V2_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v2_impl_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \ + alibi_slopes); + +#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V2_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } -} // namespace - -void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, - torch::Tensor &max_logits, torch::Tensor &tmp_out, - torch::Tensor &query, torch::Tensor &key_cache, - torch::Tensor &value_cache, int num_kv_heads, - float scale, torch::Tensor &block_tables, - torch::Tensor &seq_lens, int block_size, +} // namespace + +void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums, + torch::Tensor& max_logits, torch::Tensor& tmp_out, + torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, + float scale, torch::Tensor& block_tables, + torch::Tensor& seq_lens, int block_size, int max_seq_len, - const c10::optional &alibi_slopes, - const std::string &kv_cache_dtype, float kv_scale) { + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", [&] { diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 26e81685d623e..2890ba6e2bb32 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -5,25 +5,26 @@ namespace { template -void copy_blocks_cpu_impl( - std::vector &key_caches, - std::vector &value_caches, - const torch::Tensor& mapping_pairs, - const int element_num_per_block, const int layer_num) { +void copy_blocks_cpu_impl(std::vector& key_caches, + std::vector& value_caches, + const torch::Tensor& mapping_pairs, + const int element_num_per_block, + const int layer_num) { const size_t pair_num = mapping_pairs.size(0); const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; #pragma omp parallel for collapse(2) for (int layer = 0; layer < layer_num; ++layer) { for (size_t pair = 0; pair < pair_num; ++pair) { - int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item(); + int64_t source_offset = + element_num_per_block * mapping_pairs[pair][0].item(); int64_t target_offset = element_num_per_block * mapping_pairs[pair][1].item(); - scalar_t *key_cache_ptr = key_caches[layer].data_ptr(); - scalar_t *source_ptr = key_cache_ptr + source_offset; - scalar_t *target_ptr = key_cache_ptr + target_offset; + scalar_t* key_cache_ptr = key_caches[layer].data_ptr(); + scalar_t* source_ptr = key_cache_ptr + source_offset; + scalar_t* target_ptr = key_cache_ptr + target_offset; std::memcpy(target_ptr, source_ptr, block_bytes); - scalar_t *value_cache_ptr = value_caches[layer].data_ptr(); + scalar_t* value_cache_ptr = value_caches[layer].data_ptr(); source_ptr = value_cache_ptr + source_offset; target_ptr = value_cache_ptr + target_offset; std::memcpy(target_ptr, source_ptr, block_bytes); @@ -33,9 +34,9 @@ void copy_blocks_cpu_impl( template void reshape_and_cache_cpu_impl( - const scalar_t *__restrict__ key, const scalar_t *__restrict__ value, - scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache, - const int64_t *__restrict__ slot_mapping, const int num_tokens, + const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, + scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, + const int64_t* __restrict__ slot_mapping, const int num_tokens, const int key_stride, const int value_stride, const int num_heads, const int head_size, const int block_size, const int x) { const int block_elem_num = num_heads * head_size * block_size; @@ -48,14 +49,14 @@ void reshape_and_cache_cpu_impl( int src_key_head_idx = token_idx * key_stride + head_idx * head_size; int src_value_head_idx = token_idx * value_stride + head_idx * head_size; - const scalar_t *src_key_head_ptr = key + src_key_head_idx; - const scalar_t *src_value_head_ptr = value + src_value_head_idx; + const scalar_t* src_key_head_ptr = key + src_key_head_idx; + const scalar_t* src_value_head_ptr = value + src_value_head_idx; const int64_t block_index = slot_idx / block_size; const int64_t block_offset = slot_idx % block_size; - scalar_t *target_key_head_ptr = key_cache + + scalar_t* target_key_head_ptr = key_cache + block_elem_num * block_index + head_idx * block_size * head_size; - scalar_t *target_value_head_ptr = value_cache + + scalar_t* target_value_head_ptr = value_cache + block_elem_num * block_index + head_idx * block_size * head_size; @@ -79,10 +80,10 @@ void reshape_and_cache_cpu_impl( } } } -}; // namespace +}; // namespace -void copy_blocks(std::vector &key_caches, - std::vector &value_caches, +void copy_blocks(std::vector& key_caches, + std::vector& value_caches, const torch::Tensor& block_mapping) { unsigned num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); @@ -100,10 +101,10 @@ void copy_blocks(std::vector &key_caches, }); } -void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, - torch::Tensor &key_cache, torch::Tensor &value_cache, - torch::Tensor &slot_mapping, - const std::string &kv_cache_dtype, float kv_scale) { +void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); int num_tokens = key.size(0); @@ -127,7 +128,7 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, }); } -void swap_blocks(torch::Tensor &src, torch::Tensor &dst, - const torch::Tensor&block_mapping) { +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") } diff --git a/csrc/cpu/layernorm.cpp b/csrc/cpu/layernorm.cpp index 467f0dc84982c..65d3ddcec5709 100644 --- a/csrc/cpu/layernorm.cpp +++ b/csrc/cpu/layernorm.cpp @@ -2,10 +2,10 @@ namespace { template -void rms_norm_impl(scalar_t *__restrict__ out, - const scalar_t *__restrict__ input, - const scalar_t *__restrict__ weight, const float epsilon, - const int num_tokens, const int hidden_size) { +void rms_norm_impl(scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ weight, const float epsilon, + const int num_tokens, const int hidden_size) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); @@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out, } template -void fused_add_rms_norm_impl(scalar_t *__restrict__ input, - scalar_t *__restrict__ residual, - const scalar_t *__restrict__ weight, - const float epsilon, const int num_tokens, - const int hidden_size) { +void fused_add_rms_norm_impl(scalar_t* __restrict__ input, + scalar_t* __restrict__ residual, + const scalar_t* __restrict__ weight, + const float epsilon, const int num_tokens, + const int hidden_size) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); @@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input, } } } -} // namespace +} // namespace -void rms_norm(torch::Tensor &out, torch::Tensor &input, - torch::Tensor &weight, float epsilon) { +void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, + float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] { CPU_KERNEL_GUARD_IN(rms_norm_impl) rms_norm_impl(out.data_ptr(), input.data_ptr(), - weight.data_ptr(), epsilon, num_tokens, - hidden_size); + weight.data_ptr(), epsilon, num_tokens, + hidden_size); CPU_KERNEL_GUARD_OUT(rms_norm_impl) }); } -void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, - torch::Tensor &weight, float epsilon) { +void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, + torch::Tensor& weight, float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index 5dc1bde45ac5f..73bf77e46f538 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -4,16 +4,16 @@ namespace { template void rotary_embedding_impl( - const int64_t - *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t - *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or - /// [num_tokens, num_heads, head_size] - scalar_t - *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or - // [num_tokens, num_kv_heads, head_size] - const scalar_t - *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, + /// head_size] or [num_tokens, num_heads, + /// head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int num_heads, const int num_kv_heads, const int head_size, const int num_tokens) { @@ -26,7 +26,7 @@ void rotary_embedding_impl( #pragma omp parallel for for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { int64_t pos = positions[token_idx]; - const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; for (int i = 0; i < num_heads; ++i) { const int head_idx = i; @@ -94,16 +94,16 @@ void rotary_embedding_impl( template void rotary_embedding_gptj_impl( - const int64_t - *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t - *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or - /// [num_tokens, num_heads, head_size] - scalar_t - *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or - // [num_tokens, num_kv_heads, head_size] - const scalar_t - *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, + /// head_size] or [num_tokens, num_heads, + /// head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int num_heads, const int num_kv_heads, const int head_size, const int num_tokens) { @@ -113,13 +113,13 @@ void rotary_embedding_gptj_impl( for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int i = 0; i < num_heads; ++i) { int64_t pos = positions[token_idx]; - const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; - const scalar_t *cos_cache_ptr = cache_ptr; - const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t* cos_cache_ptr = cache_ptr; + const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; const int head_idx = i; const int64_t token_head = token_idx * query_stride + head_idx * head_size; - scalar_t *head_query = token_head + query; + scalar_t* head_query = token_head + query; for (int j = 0; j < embed_dim; j += 1) { const int rot_offset = j; const int x_index = 2 * rot_offset; @@ -141,12 +141,12 @@ void rotary_embedding_gptj_impl( for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int i = 0; i < num_kv_heads; ++i) { int64_t pos = positions[token_idx]; - const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; - const scalar_t *cos_cache_ptr = cache_ptr; - const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t* cos_cache_ptr = cache_ptr; + const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; const int head_idx = i; const int64_t token_head = token_idx * key_stride + head_idx * head_size; - scalar_t *head_key = key + token_head; + scalar_t* head_key = key + token_head; for (int j = 0; j < embed_dim; j += 1) { const int rot_offset = j; const int x_index = 2 * rot_offset; @@ -164,11 +164,11 @@ void rotary_embedding_gptj_impl( } } } -}; // namespace +}; // namespace -void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, - torch::Tensor &key, int head_size, - torch::Tensor &cos_sin_cache, bool is_neox) { +void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int head_size, + torch::Tensor& cos_sin_cache, bool is_neox) { int num_tokens = query.numel() / query.size(-1); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp index bba044087f37c..63082393c8102 100644 --- a/csrc/cpu/pybind.cpp +++ b/csrc/cpu/pybind.cpp @@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); // Attention ops - ops.def( - "paged_attention_v1", - &paged_attention_v1, - "Compute the attention between an input query and the cached keys/values using PagedAttention."); - ops.def( - "paged_attention_v2", - &paged_attention_v2, - "PagedAttention V2."); + ops.def("paged_attention_v1", &paged_attention_v1, + "Compute the attention between an input query and the cached " + "keys/values using PagedAttention."); + ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2."); // Activation ops - ops.def( - "silu_and_mul", - &silu_and_mul, - "Activation function used in SwiGLU."); - ops.def( - "gelu_and_mul", - &gelu_and_mul, - "Activation function used in GeGLU with `none` approximation."); - ops.def( - "gelu_tanh_and_mul", - &gelu_tanh_and_mul, - "Activation function used in GeGLU with `tanh` approximation."); - ops.def( - "gelu_new", - &gelu_new, - "GELU implementation used in GPT-2."); - ops.def( - "gelu_fast", - &gelu_fast, - "Approximate GELU implementation."); + ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); + ops.def("gelu_and_mul", &gelu_and_mul, + "Activation function used in GeGLU with `none` approximation."); + ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, + "Activation function used in GeGLU with `tanh` approximation."); + ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); + ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation."); // Layernorm - ops.def( - "rms_norm", - &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); + ops.def("rms_norm", &rms_norm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); - ops.def( - "fused_add_rms_norm", - &fused_add_rms_norm, - "In-place fused Add and RMS Normalization"); + ops.def("fused_add_rms_norm", &fused_add_rms_norm, + "In-place fused Add and RMS Normalization"); // Rotary embedding - ops.def( - "rotary_embedding", - &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + ops.def("rotary_embedding", &rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); - cache_ops.def( - "swap_blocks", - &swap_blocks, - "Swap in (out) the cache blocks from src to dst"); - cache_ops.def( - "copy_blocks", - ©_blocks, - "Copy the cache blocks from src to dst"); - cache_ops.def( - "reshape_and_cache", - &reshape_and_cache, - "Reshape the key and value tensors and cache them"); + cache_ops.def("swap_blocks", &swap_blocks, + "Swap in (out) the cache blocks from src to dst"); + cache_ops.def("copy_blocks", ©_blocks, + "Copy the cache blocks from src to dst"); + cache_ops.def("reshape_and_cache", &reshape_and_cache, + "Reshape the key and value tensors and cache them"); } diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index 1ebb2e74a82fc..5909e5eaf5e60 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -1,7 +1,7 @@ #pragma once #ifdef USE_ROCM -#include + #include #endif #ifndef USE_ROCM @@ -17,7 +17,8 @@ #endif #ifndef USE_ROCM - #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask) #else #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) #endif @@ -29,7 +30,8 @@ #endif #ifndef USE_ROCM - #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down_sync(uint32_t(-1), var, lane_delta) + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \ + __shfl_down_sync(uint32_t(-1), var, lane_delta) #else #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) #endif @@ -41,4 +43,3 @@ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) #endif - diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 1483484faeb4a..2ba49b339e148 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -2,9 +2,6 @@ #include -int get_device_attribute( - int attribute, - int device_id); +int get_device_attribute(int attribute, int device_id); -int get_max_shared_memory_per_block_device_attribute( - int device_id); +int get_max_shared_memory_per_block_device_attribute(int device_id); diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index 1a443ef3620cc..7d8e2e19720fa 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -2,34 +2,28 @@ #include #include #endif -int get_device_attribute( - int attribute, - int device_id) -{ - int device, value; - if (device_id < 0) { - cudaGetDevice(&device); - } - else { - device = device_id; - } - cudaDeviceGetAttribute(&value, static_cast(attribute), device); - return value; +int get_device_attribute(int attribute, int device_id) { + int device, value; + if (device_id < 0) { + cudaGetDevice(&device); + } else { + device = device_id; + } + cudaDeviceGetAttribute(&value, static_cast(attribute), + device); + return value; } - -int get_max_shared_memory_per_block_device_attribute( - int device_id) -{ -int attribute; -// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html -// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 +int get_max_shared_memory_per_block_device_attribute(int device_id) { + int attribute; + // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html + // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 #ifdef USE_ROCM - attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; + attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; #else - attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; + attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; #endif - return get_device_attribute(attribute, device_id); + return get_device_attribute(attribute, device_id); } diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 3906dcfc80dbf..0b1d95848525a 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -7,11 +7,11 @@ // fake pointer type using fptr_t = uint64_t; -static_assert(sizeof(void *) == sizeof(fptr_t)); +static_assert(sizeof(void*) == sizeof(fptr_t)); -fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, - const std::vector &handles, - const std::vector &offsets, int rank, +fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, int rank, bool full_nvlink) { int world_size = offsets.size(); if (world_size > 8) @@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); } return (fptr_t) new vllm::CustomAllreduce( - reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), + reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); } @@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, * 5. A[None].expand(2, -1, -1, -1): Not OK * 6. A[:, 1:, 1:]: Not OK */ -bool _is_weak_contiguous(torch::Tensor &t) { +bool _is_weak_contiguous(torch::Tensor& t) { return t.is_contiguous() || (t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size()); } -bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, +bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, bool full_nvlink) { auto inp_size = inp.numel() * inp.element_size(); // custom allreduce requires input byte size to be multiples of 16 @@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, return false; } -void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, +void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, cudaStream_t stream) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); TORCH_CHECK(_is_weak_contiguous(out)); switch (out.scalar_type()) { case at::ScalarType::Float: { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); break; } case at::ScalarType::Half: { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), - out.numel()); + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); break; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) case at::ScalarType::BFloat16: { fa->allreduce( - stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), out.numel()); + stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); break; } #endif @@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, } } -void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); auto stream = c10::cuda::getCurrentCUDAStream().stream(); TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); @@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { _all_reduce(_fa, inp, out, stream); } -void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, - torch::Tensor &out) { +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, + torch::Tensor& out) { const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); auto stream = c10::cuda::getCurrentCUDAStream().stream(); @@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, } void dispose(fptr_t _fa) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); delete fa; } int meta_size() { return sizeof(vllm::Signal); } -void register_buffer(fptr_t _fa, torch::Tensor &t, - const std::vector &handles, - const std::vector &offsets) { - auto fa = reinterpret_cast(_fa); +void register_buffer(fptr_t _fa, torch::Tensor& t, + const std::vector& handles, + const std::vector& offsets) { + auto fa = reinterpret_cast(_fa); fa->register_buffer(handles, offsets, t.data_ptr()); } std::pair, std::vector> get_graph_buffer_ipc_meta( fptr_t _fa) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); return fa->get_graph_buffer_ipc_meta(); } -void register_graph_buffers(fptr_t _fa, const std::vector &handles, - const std::vector> &offsets) { - auto fa = reinterpret_cast(_fa); +void register_graph_buffers(fptr_t _fa, const std::vector& handles, + const std::vector>& offsets) { + auto fa = reinterpret_cast(_fa); fa->register_graph_buffers(handles, offsets); } diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 750e68d42f6c6..1ed49b8aa9cae 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -31,9 +31,9 @@ struct Signal { alignas(128) uint32_t end[kMaxBlocks][8]; }; -struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; +struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; -struct __align__(16) RankSignals { volatile Signal *signals[8]; }; +struct __align__(16) RankSignals { volatile Signal* signals[8]; }; // like std::array, but aligned template @@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) { // scalar add functions // for some reason when compiling with Pytorch, the + operator for half and // bfloat is disabled so we call the intrinsics directly -DINLINE half &assign_add(half &a, half b) { +DINLINE half& assign_add(half& a, half b) { a = __hadd(a, b); return a; } -DINLINE float &assign_add(float &a, float b) { return a += b; } +DINLINE float& assign_add(float& a, float b) { return a += b; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } @@ -80,14 +80,14 @@ template <> DINLINE nv_bfloat16 downcast_s(float val) { return __float2bfloat16(val); } -DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) { +DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) { a = __hadd(a, b); return a; } #endif template -DINLINE array_t &packed_assign_add(array_t &a, array_t b) { +DINLINE array_t& packed_assign_add(array_t& a, array_t b) { #pragma unroll for (int i = 0; i < N; i++) { assign_add(a.data[i], b.data[i]); @@ -128,7 +128,7 @@ DINLINE O downcast(array_t val) { // prior memory accesses. Note: volatile writes will not be reordered against // other volatile writes. template -DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, +DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg, int rank) { if (threadIdx.x < ngpus) { // reset flag for next time @@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, // Latency = 1 p2p write sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; // wait until we got true from all ranks - while (!self_sg->start[blockIdx.x][threadIdx.x]) - ; + while (!self_sg->start[blockIdx.x][threadIdx.x]); } __syncthreads(); } @@ -147,13 +146,13 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, // barrier in the all reduce kernel. If it's the final synchronization barrier, // we don't need to make any visibility guarantees for prior memory accesses. template -DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, +DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg, int rank) { __syncthreads(); // eliminate the case that prior writes are not visible after signals become // visible. Note that I did not managed to make this happen through a lot of // testing. Might be the case that hardware provides stronger guarantee than - // the memory model. + // the memory model. if constexpr (!final_sync) __threadfence_system(); if (threadIdx.x < ngpus) { // reset flag for next time @@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, // Latency = 1 p2p write sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; // wait until we got true from all ranks - while (!self_sg->end[blockIdx.x][threadIdx.x]) - ; + while (!self_sg->end[blockIdx.x][threadIdx.x]); } if constexpr (!final_sync) __syncthreads(); } template -DINLINE P packed_reduce(const P *ptrs[], int idx) { +DINLINE P packed_reduce(const P* ptrs[], int idx) { A tmp = upcast(ptrs[0][idx]); #pragma unroll for (int i = 1; i < ngpus; i++) { @@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) { template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_1stage(RankData *_dp, RankSignals sg, - volatile Signal *self_sg, T *__restrict__ result, + cross_device_reduce_1stage(RankData* _dp, RankSignals sg, + volatile Signal* self_sg, T* __restrict__ result, int rank, int size) { using P = typename packed_t::P; using A = typename packed_t::A; @@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1) // do the actual reduction for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { - ((P *)result)[idx] = - packed_reduce((const P **)&dp.ptrs[0], idx); + ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); } end_sync(sg, self_sg, rank); } template -DINLINE P *get_tmp_buf(volatile Signal *sg) { - return (P *)(((Signal *)sg) + 1); +DINLINE P* get_tmp_buf(volatile Signal* sg) { + return (P*)(((Signal*)sg) + 1); } template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_2stage(RankData *_dp, RankSignals sg, - volatile Signal *self_sg, T *__restrict__ result, + cross_device_reduce_2stage(RankData* _dp, RankSignals sg, + volatile Signal* self_sg, T* __restrict__ result, int rank, int size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; @@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1) int start = rank * part; int end = rank == ngpus - 1 ? size : start + part; int largest_part = part + size % ngpus; - const P *ptrs[ngpus]; - P *tmps[ngpus]; + const P* ptrs[ngpus]; + P* tmps[ngpus]; #pragma unroll for (int i = 0; i < ngpus; i++) { int target = (rank + i) % ngpus; - ptrs[i] = (const P *)_dp->ptrs[target]; + ptrs[i] = (const P*)_dp->ptrs[target]; tmps[i] = get_tmp_buf

(sg.signals[target]); } auto tmp_out = tmps[0]; @@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1) int gather_from_rank = ((rank + i) % ngpus); if (gather_from_rank == ngpus - 1 || idx < part) { int dst_idx = gather_from_rank * part + idx; - ((P *)result)[dst_idx] = tmps[i][idx]; + ((P*)result)[dst_idx] = tmps[i][idx]; } } } @@ -261,14 +258,14 @@ class CustomAllreduce { // below are device pointers RankSignals sg_; - std::unordered_map buffers_; - Signal *self_sg_; + std::unordered_map buffers_; + Signal* self_sg_; // stores the registered device pointers from all ranks RankData *d_rank_data_base_, *d_rank_data_end_; - std::vector graph_unreg_buffers_; + std::vector graph_unreg_buffers_; // a map from IPC handles to opened IPC pointers - std::map ipc_handles_; + std::map ipc_handles_; /** * meta is a pointer to device metadata and temporary buffer for allreduce. @@ -279,22 +276,22 @@ class CustomAllreduce { * note: this class does not own any device memory. Any required buffers * are passed in from the constructor */ - CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz, - const cudaIpcMemHandle_t *handles, - const std::vector &offsets, int rank, + CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, + const cudaIpcMemHandle_t* handles, + const std::vector& offsets, int rank, bool full_nvlink = true) : rank_(rank), world_size_(offsets.size()), full_nvlink_(full_nvlink), self_sg_(meta), - d_rank_data_base_(reinterpret_cast(rank_data)), + d_rank_data_base_(reinterpret_cast(rank_data)), d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { for (int i = 0; i < world_size_; i++) { - Signal *rank_sg; + Signal* rank_sg; if (i != rank_) { - char *handle = open_ipc_handle(&handles[i]); + char* handle = open_ipc_handle(&handles[i]); handle += offsets[i]; - rank_sg = (Signal *)handle; + rank_sg = (Signal*)handle; } else { rank_sg = self_sg_; } @@ -302,13 +299,13 @@ class CustomAllreduce { } } - char *open_ipc_handle(const void *ipc_handle) { + char* open_ipc_handle(const void* ipc_handle) { auto [it, new_handle] = - ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr}); + ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); if (new_handle) { - char *ipc_ptr; - CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr, - *((const cudaIpcMemHandle_t *)ipc_handle), + char* ipc_ptr; + CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr, + *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess)); it->second = ipc_ptr; } @@ -323,7 +320,7 @@ class CustomAllreduce { std::vector offsets(num_buffers); for (int i = 0; i < num_buffers; i++) { auto ptr = graph_unreg_buffers_[i]; - void *base_ptr; + void* base_ptr; // note: must share the base address of each allocation, or we get wrong // address if (cuPointerGetAttribute(&base_ptr, @@ -331,8 +328,8 @@ class CustomAllreduce { (CUdeviceptr)ptr) != CUDA_SUCCESS) throw std::runtime_error("failed to get pointer attr"); CUDACHECK(cudaIpcGetMemHandle( - (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); - offsets[i] = ((char *)ptr) - ((char *)base_ptr); + (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); + offsets[i] = ((char*)ptr) - ((char*)base_ptr); } return std::make_pair(handles, offsets); } @@ -344,13 +341,13 @@ class CustomAllreduce { std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); } - void register_buffer(const std::vector &handles, - const std::vector &offsets, void *self) { + void register_buffer(const std::vector& handles, + const std::vector& offsets, void* self) { check_rank_data_capacity(); RankData data; for (int i = 0; i < world_size_; i++) { if (i != rank_) { - char *handle = open_ipc_handle(handles[i].data()); + char* handle = open_ipc_handle(handles[i].data()); handle += offsets[i]; data.ptrs[i] = handle; } else { @@ -371,17 +368,17 @@ class CustomAllreduce { // got a different address. IPC handles have internal reference counting // mechanism so overhead should be small. void register_graph_buffers( - const std::vector &handles, - const std::vector> &offsets) { + const std::vector& handles, + const std::vector>& offsets) { auto num_buffers = graph_unreg_buffers_.size(); check_rank_data_capacity(num_buffers); std::vector rank_data(num_buffers); for (int i = 0; i < num_buffers; i++) { auto self_ptr = graph_unreg_buffers_[i]; - auto &rd = rank_data[i]; + auto& rd = rank_data[i]; for (int j = 0; j < world_size_; j++) { if (j != rank_) { - char *handle = + char* handle = open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); handle += offsets[j][i]; rd.ptrs[j] = handle; @@ -405,7 +402,7 @@ class CustomAllreduce { * will cause contention on NVLink bus. */ template - void allreduce(cudaStream_t stream, T *input, T *output, int size, + void allreduce(cudaStream_t stream, T* input, T* output, int size, int threads = 512, int block_limit = 36) { auto d = packed_t::P::size; if (size % d != 0) @@ -418,7 +415,7 @@ class CustomAllreduce { std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit)); - RankData *ptrs; + RankData* ptrs; cudaStreamCaptureStatus status; CUDACHECK(cudaStreamIsCapturing(stream, &status)); if (status == cudaStreamCaptureStatusActive) { diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index c34a50389c21c..f7868233076cd 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -48,7 +48,7 @@ __global__ void dummy_kernel() { } template -__global__ void set_data(T *data, int size, int myRank) { +__global__ void set_data(T* data, int size, int myRank) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { data[idx] = myRank * 0.11f; @@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) { } template -__global__ void convert_data(const T *data1, const T *data2, double *fdata1, - double *fdata2, int size) { +__global__ void convert_data(const T* data1, const T* data2, double* fdata1, + double* fdata2, int size) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { fdata1[idx] = data1[idx]; @@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1, } } -__global__ void init_rand(curandState_t *state, int size, int nRanks) { +__global__ void init_rand(curandState_t* state, int size, int nRanks) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { for (int i = 0; i < nRanks; i++) { @@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) { } template -__global__ void gen_data(curandState_t *state, T *data, double *ground_truth, +__global__ void gen_data(curandState_t* state, T* data, double* ground_truth, int myRank, int nRanks, int size) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { @@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth, } template -void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, +void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, int data_size, bool performance_test) { - T *result; + T* result; cudaStream_t stream; CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); @@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, cudaIpcMemHandle_t self_data_handle; cudaIpcMemHandle_t data_handles[8]; - vllm::Signal *buffer; - T *self_data_copy; + vllm::Signal* buffer; + T* self_data_copy; /** * Allocate IPC buffer * @@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), MPI_BYTE, MPI_COMM_WORLD)); - void *rank_data; + void* rank_data; size_t rank_data_sz = 16 * 1024 * 1024; CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); std::vector offsets(nRanks, 0); vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, offsets, myRank); - auto *self_data = - reinterpret_cast(reinterpret_cast(buffer) + - sizeof(vllm::Signal) + data_size * sizeof(T)); + auto* self_data = + reinterpret_cast(reinterpret_cast(buffer) + + sizeof(vllm::Signal) + data_size * sizeof(T)); // hack buffer registration { std::vector handles; handles.reserve(nRanks); for (int i = 0; i < nRanks; i++) { - char *begin = (char *)&data_handles[i]; - char *end = (char *)&data_handles[i + 1]; + char* begin = (char*)&data_handles[i]; + char* end = (char*)&data_handles[i + 1]; handles.emplace_back(begin, end); } std::vector offsets(nRanks, @@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, fa.register_buffer(handles, offsets, self_data); } - double *ground_truth; + double* ground_truth; CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double))); - curandState_t *states; + curandState_t* states; CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); gen_data<<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, @@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, CUDACHECK(cudaStreamDestroy(stream)); } -int main(int argc, char **argv) { +int main(int argc, char** argv) { int nRanks, myRank; MPICHECK(MPI_Init(&argc, &argv)); MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); @@ -296,7 +296,7 @@ int main(int argc, char **argv) { ncclUniqueId id; ncclComm_t comm; if (myRank == 0) ncclGetUniqueId(&id); - MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, + MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD)); NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 91abd9e85b4bb..3ecea03242f06 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -6,32 +6,30 @@ #include -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) -#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) - -#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ +#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, \ + VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) -#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index e56b4d2204005..70a2b3b0a07b1 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -11,26 +11,24 @@ #include #include - using __nv_bfloat16 = __hip_bfloat16; - using __nv_bfloat162 = __hip_bfloat162; +using __nv_bfloat16 = __hip_bfloat16; +using __nv_bfloat162 = __hip_bfloat162; #endif namespace vllm { // TODO(woosuk): Further optimize this kernel. -template +template __global__ void rms_norm_kernel( - scalar_t* __restrict__ out, // [..., hidden_size] - const scalar_t* __restrict__ input, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, - const int num_tokens, - const int hidden_size) { + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - const float x = (float) input[blockIdx.x * hidden_size + idx]; + const float x = (float)input[blockIdx.x * hidden_size + idx]; variance += x * x; } variance = blockReduceSum(variance); @@ -40,12 +38,12 @@ __global__ void rms_norm_kernel( __syncthreads(); for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float) input[blockIdx.x * hidden_size + idx]; - out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; + float x = (float)input[blockIdx.x * hidden_size + idx]; + out[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; } } - /* Converter structs for the conversion from torch types to HIP/CUDA types, and the associated type conversions within HIP/CUDA. These helpers need to be implemented for now because the relevant type conversion @@ -54,51 +52,68 @@ __global__ void rms_norm_kernel( Each struct should have the member static constexpr bool `exists`: If false, the optimized kernel is not used for the corresponding torch type. - If true, the struct should be fully defined as shown in the examples below. + If true, the struct should be fully defined as shown in the examples below. */ -template -struct _typeConvert { static constexpr bool exists = false; }; +template +struct _typeConvert { + static constexpr bool exists = false; +}; #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) // CUDA < 12.0 runs into issues with packed type conversion -template<> +template <> struct _typeConvert { static constexpr bool exists = true; using hip_type = __half; using packed_hip_type = __half2; __device__ static inline float convert(hip_type x) { return __half2float(x); } - __device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); } - __device__ static inline hip_type convert(float x) { return __float2half_rn(x); } - __device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); } + __device__ static inline float2 convert(packed_hip_type x) { + return __half22float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2half_rn(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22half2_rn(x); + } }; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // CUDA_ARCH < 800 does not have BF16 support // TODO: Add in ROCm support once public headers handle bf16 maturely -template<> +template <> struct _typeConvert { static constexpr bool exists = true; using hip_type = __nv_bfloat16; using packed_hip_type = __nv_bfloat162; - __device__ static inline float convert(hip_type x) { return __bfloat162float(x); } - __device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } - __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); } - __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } + __device__ static inline float convert(hip_type x) { + return __bfloat162float(x); + } + __device__ static inline float2 convert(packed_hip_type x) { + return __bfloat1622float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2bfloat16(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22bfloat162_rn(x); + } }; -#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) + #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= + // 12000)) /* Vector POD struct to generate vectorized and packed FP16/BF16 ops for appropriate specializations of fused_add_rms_norm_kernel. Only functions that are necessary in that kernel are implemented. Alignment to 16 bytes is required to use 128-bit global memory ops. */ -template +template struct alignas(16) _f16Vec { - /* Not theoretically necessary that width is a power of 2 but should - almost always be the case for optimization purposes */ + /* Not theoretically necessary that width is a power of 2 but should + almost always be the case for optimization purposes */ static_assert(width > 0 && (width & (width - 1)) == 0, "Width is not a positive power of 2!"); using Converter = _typeConvert; @@ -108,51 +123,49 @@ struct alignas(16) _f16Vec { __device__ _f16Vec& operator+=(const _f16Vec& other) { if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i+1]}; - temp += T2{other.data[i], other.data[i+1]}; + T2 temp{data[i], data[i + 1]}; + temp += T2{other.data[i], other.data[i + 1]}; data[i] = temp.x; - data[i+1] = temp.y; + data[i + 1] = temp.y; } } else { - #pragma unroll - for (int i = 0; i < width; ++i) - data[i] += other.data[i]; +#pragma unroll + for (int i = 0; i < width; ++i) data[i] += other.data[i]; } return *this; } __device__ _f16Vec& operator*=(const _f16Vec& other) { if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i+1]}; - temp *= T2{other.data[i], other.data[i+1]}; + T2 temp{data[i], data[i + 1]}; + temp *= T2{other.data[i], other.data[i + 1]}; data[i] = temp.x; - data[i+1] = temp.y; + data[i + 1] = temp.y; } } else { - #pragma unroll - for (int i = 0; i < width; ++i) - data[i] *= other.data[i]; +#pragma unroll + for (int i = 0; i < width; ++i) data[i] *= other.data[i]; } return *this; } __device__ _f16Vec& operator*=(const float scale) { if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - float2 temp_f = Converter::convert(T2{data[i], data[i+1]}); + float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); temp_f.x *= scale; temp_f.y *= scale; T2 temp = Converter::convert(temp_f); data[i] = temp.x; - data[i+1] = temp.y; + data[i + 1] = temp.y; } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < width; ++i) { float temp = Converter::convert(data[i]) * scale; data[i] = Converter::convert(temp); @@ -164,13 +177,13 @@ struct alignas(16) _f16Vec { __device__ float sum_squares() const { float result = 0.0f; if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - float2 z = Converter::convert(T2{data[i], data[i+1]}); + float2 z = Converter::convert(T2{data[i], data[i + 1]}); result += z.x * z.x + z.y * z.y; } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < width; ++i) { float x = Converter::convert(data[i]); result += x * x; @@ -184,15 +197,13 @@ struct alignas(16) _f16Vec { Additional optimizations we can make in this case are packed and vectorized operations, which help with the memory latency bottleneck. */ -template -__global__ std::enable_if_t< - (width > 0) && _typeConvert::exists> fused_add_rms_norm_kernel( - scalar_t* __restrict__ input, // [..., hidden_size] - scalar_t* __restrict__ residual, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, - const int num_tokens, - const int hidden_size) { +template +__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> +fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { // Sanity checks on our vector struct and type-punned pointer arithmetic static_assert(std::is_pod_v<_f16Vec>); static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); @@ -203,9 +214,12 @@ __global__ std::enable_if_t< /* These and the argument pointers are all declared `restrict` as they are not aliased in practice. Argument pointers should not be dereferenced in this kernel as that would be undefined behavior */ - auto* __restrict__ input_v = reinterpret_cast<_f16Vec*>(input); - auto* __restrict__ residual_v = reinterpret_cast<_f16Vec*>(residual); - auto* __restrict__ weight_v = reinterpret_cast*>(weight); + auto* __restrict__ input_v = + reinterpret_cast<_f16Vec*>(input); + auto* __restrict__ residual_v = + reinterpret_cast<_f16Vec*>(residual); + auto* __restrict__ weight_v = + reinterpret_cast*>(weight); for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { int id = blockIdx.x * vec_hidden_size + idx; @@ -215,10 +229,11 @@ __global__ std::enable_if_t< residual_v[id] = temp; } /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ + calculation of max_block_size in fused_add_rms_norm */ if (num_tokens < 256) { variance = blockReduceSum(variance); - } else variance = blockReduceSum(variance); + } else + variance = blockReduceSum(variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -233,52 +248,50 @@ __global__ std::enable_if_t< } } - /* Generic fused_add_rms_norm_kernel The width field is not used here but necessary for other specializations. */ -template -__global__ std::enable_if_t< - (width == 0) || !_typeConvert::exists> fused_add_rms_norm_kernel( - scalar_t* __restrict__ input, // [..., hidden_size] - scalar_t* __restrict__ residual, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, - const int num_tokens, - const int hidden_size) { +template +__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> +fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { scalar_t z = input[blockIdx.x * hidden_size + idx]; z += residual[blockIdx.x * hidden_size + idx]; - float x = (float) z; + float x = (float)z; variance += x * x; residual[blockIdx.x * hidden_size + idx] = z; } /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ + calculation of max_block_size in fused_add_rms_norm */ if (num_tokens < 256) { variance = blockReduceSum(variance); - } else variance = blockReduceSum(variance); + } else + variance = blockReduceSum(variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } __syncthreads(); for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float) residual[blockIdx.x * hidden_size + idx]; - input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; + float x = (float)residual[blockIdx.x * hidden_size + idx]; + input[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; } } -} // namespace vllm +} // namespace vllm -void rms_norm( - torch::Tensor& out, // [..., hidden_size] - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] - float epsilon) { +void rms_norm(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -286,40 +299,27 @@ void rms_norm( dim3 block(std::min(hidden_size, 1024)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), - "rms_norm_kernel", - [&] { - vllm::rms_norm_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size); - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { + vllm::rms_norm_kernel<<>>( + out.data_ptr(), input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, hidden_size); + }); } -#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "fused_add_rms_norm_kernel", \ - [&] { \ - vllm::fused_add_rms_norm_kernel \ - <<>>( \ - input.data_ptr(), \ - residual.data_ptr(), \ - weight.data_ptr(), \ - epsilon, \ - num_tokens, \ - hidden_size); \ - }); - -void fused_add_rms_norm( - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& residual, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] - float epsilon) { +#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ + vllm::fused_add_rms_norm_kernel \ + <<>>(input.data_ptr(), \ + residual.data_ptr(), \ + weight.data_ptr(), epsilon, \ + num_tokens, hidden_size); \ + }); + +void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -342,8 +342,8 @@ void fused_add_rms_norm( auto inp_ptr = reinterpret_cast(input.data_ptr()); auto res_ptr = reinterpret_cast(residual.data_ptr()); auto wt_ptr = reinterpret_cast(weight.data_ptr()); - bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \ - && wt_ptr % 16 == 0; + bool ptrs_are_aligned = + inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; if (ptrs_are_aligned && hidden_size % 8 == 0) { LAUNCH_FUSED_ADD_RMS_NORM(8); } else { diff --git a/csrc/moe/moe_ops.cpp b/csrc/moe/moe_ops.cpp index 35c328499a22d..4122f7630d7c7 100644 --- a/csrc/moe/moe_ops.cpp +++ b/csrc/moe/moe_ops.cpp @@ -3,5 +3,6 @@ #include PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs."); + m.def("topk_softmax", &topk_softmax, + "Apply topk softmax to the gating outputs."); } diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index a01be3e426d72..93e7844ac1993 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -2,8 +2,6 @@ #include -void topk_softmax( - torch::Tensor& topk_weights, - torch::Tensor& topk_indices, - torch::Tensor& token_expert_indices, - torch::Tensor& gating_output); +void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, + torch::Tensor& token_expert_indices, + torch::Tensor& gating_output); diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe_align_block_size_kernels.cu index e01b23685ef4e..edc441d121029 100644 --- a/csrc/moe_align_block_size_kernels.cu +++ b/csrc/moe_align_block_size_kernels.cu @@ -7,119 +7,128 @@ #include "cuda_compat.h" #include "dispatch_utils.h" -#define CEILDIV(x,y) (((x) + (y) - 1) / (y)) +#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) namespace vllm { namespace { -__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { - // don't worry about overflow because num_experts is relatively small - return row * total_col + col; -} +__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, + int32_t col) { + // don't worry about overflow because num_experts is relatively small + return row * total_col + col; } +} // namespace template -__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, - int32_t *sorted_token_ids, - int32_t *expert_ids, - int32_t *total_tokens_post_pad, - int32_t num_experts, - int32_t block_size, - size_t numel) { - const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); - const size_t start_idx = threadIdx.x * tokens_per_thread; - - extern __shared__ int32_t shared_mem[]; - - int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) - int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1) - - for (int i = 0; i < num_experts; ++i) { - tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; - } - - /** - * In the first step we compute token_cnts[thread_index + 1][expert_index], - * which counts how many tokens in the token shard of thread_index are assigned - * to expert expert_index. - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; - } - - __syncthreads(); - - // For each expert we accumulate the token counts from the different threads. - tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; - for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)]; - } - - __syncthreads(); - - // We accumulate the token counts of all experts in thread 0. - if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; - } - *total_tokens_post_pad = cumsum[num_experts]; - } - - __syncthreads(); - - /** - * For each expert, each thread processes the tokens of the corresponding blocks - * and stores the corresponding expert_id for each block. - */ - for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) { - expert_ids[i / block_size] = threadIdx.x; +__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, + int32_t* sorted_token_ids, + int32_t* expert_ids, + int32_t* total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, size_t numel) { + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + extern __shared__ int32_t shared_mem[]; + + int32_t* tokens_cnts = + shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) + int32_t* cumsum = + shared_mem + (num_experts + 1) * + num_experts; // 1d tensor with shape (num_experts + 1) + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + } + + /** + * In the first step we compute token_cnts[thread_index + 1][expert_index], + * which counts how many tokens in the token shard of thread_index are + * assigned to expert expert_index. + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads. + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, threadIdx.x)] += + tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], + block_size) * + block_size; } - - /** - * Each thread processes a token shard, calculating the index of each token after - * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and - * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], - * where * represents a padding value(preset in python). - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int32_t expert_id = topk_ids[i]; - /** The cumsum[expert_id] stores the starting index of the tokens that the - * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id] - * stores the indices of the tokens processed by the expert with expert_id within - * the current thread's token shard. - */ - int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id]; - sorted_token_ids[rank_post_pad] = i; - ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; - } -} + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding + * blocks and stores the corresponding expert_id for each block. + */ + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; + i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + + /** + * Each thread processes a token shard, calculating the index of each token + * after sorting by expert number. Given the example topk_ids = + * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, + * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a + * padding value(preset in python). + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and + * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens + * processed by the expert with expert_id within the current thread's token + * shard. + */ + int32_t rank_post_pad = + tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; + } } - -void moe_align_block_size( - torch::Tensor topk_ids, - int num_experts, - int block_size, - torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_INTEGRAL_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors - const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); +} // namespace vllm + +void moe_align_block_size(torch::Tensor topk_ids, int num_experts, + int block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `tokens_cnts` and `cumsum` + // tensors + const int32_t shared_mem = + ((num_experts + 1) * num_experts + (num_experts + 1)) * + sizeof(int32_t); // set dynamic shared mem auto kernel = vllm::moe_align_block_size_kernel; - AT_CUDA_CHECK( - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem)); + AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( + (void*)kernel, shared_mem)); kernel<<<1, num_experts, shared_mem, stream>>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), - num_experts, - block_size, + topk_ids.data_ptr(), sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, topk_ids.numel()); - }); + }); } diff --git a/csrc/ops.h b/csrc/ops.h index 8c2c2ae6e1f5a..f5e0e423bb65d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -2,224 +2,136 @@ #include -void paged_attention_v1( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& seq_lens, - int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale); - -void paged_attention_v2( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& seq_lens, - int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale); - -void rms_norm( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& weight, - float epsilon); - -void fused_add_rms_norm( - torch::Tensor& input, - torch::Tensor& residual, - torch::Tensor& weight, - float epsilon); - -void rotary_embedding( - torch::Tensor& positions, - torch::Tensor& query, - torch::Tensor& key, - int head_size, - torch::Tensor& cos_sin_cache, - bool is_neox); - -void batched_rotary_embedding( - torch::Tensor& positions, - torch::Tensor& query, - torch::Tensor& key, - int head_size, - torch::Tensor& cos_sin_cache, - bool is_neox, - int rot_dim, - torch::Tensor& cos_sin_cache_offsets); - -void silu_and_mul( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_and_mul( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_tanh_and_mul( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_new( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_fast( - torch::Tensor& out, - torch::Tensor& input); +void paged_attention_v1(torch::Tensor& out, torch::Tensor& query, + torch::Tensor& key_cache, torch::Tensor& value_cache, + int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, + int block_size, int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale); + +void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums, + torch::Tensor& max_logits, torch::Tensor& tmp_out, + torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, + float scale, torch::Tensor& block_tables, + torch::Tensor& seq_lens, int block_size, + int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale); + +void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, + float epsilon); + +void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, + torch::Tensor& weight, float epsilon); + +void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int head_size, + torch::Tensor& cos_sin_cache, bool is_neox); + +void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int head_size, + torch::Tensor& cos_sin_cache, bool is_neox, + int rot_dim, + torch::Tensor& cos_sin_cache_offsets); + +void silu_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_new(torch::Tensor& out, torch::Tensor& input); + +void gelu_fast(torch::Tensor& out, torch::Tensor& input); #ifndef USE_ROCM -torch::Tensor aqlm_gemm( - const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const torch::Tensor& codebook_partition_sizes, - const std::optional& bias -); - -torch::Tensor aqlm_dequant( - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& codebook_partition_sizes -); - -torch::Tensor awq_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters); - -torch::Tensor awq_dequantize( - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters, - int thx, - int thy); - -torch::Tensor marlin_gemm( - torch::Tensor& a, - torch::Tensor& b_q_weight, - torch::Tensor& b_scales, - torch::Tensor& workspace, - int64_t size_m, - int64_t size_n, - int64_t size_k); - -torch::Tensor gptq_marlin_24_gemm( - torch::Tensor &a, - torch::Tensor &b_q_weight, - torch::Tensor &b_meta, - torch::Tensor &b_scales, - torch::Tensor &workspace, - int64_t num_bits, - int64_t size_m, - int64_t size_n, - int64_t size_k); - -torch::Tensor gptq_marlin_gemm( - torch::Tensor &a, - torch::Tensor &b_q_weight, - torch::Tensor &b_scales, - torch::Tensor &g_idx, - torch::Tensor &perm, - torch::Tensor &workspace, - int64_t num_bits, - int64_t size_m, - int64_t size_n, - int64_t size_k, - bool is_k_full); - -torch::Tensor gptq_marlin_repack( - torch::Tensor &b_q_weight, - torch::Tensor &perm, - int64_t size_k, - int64_t size_n, - int64_t num_bits); - -int cutlass_scaled_mm_dq( - torch::Tensor& out, - torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales); +torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias); + +torch::Tensor aqlm_dequant(const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes); + +torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + int split_k_iters); + +torch::Tensor awq_dequantize(torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int split_k_iters, int thx, + int thy); + +torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t size_m, int64_t size_n, int64_t size_k); + +torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k); + +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full); + +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, + int64_t size_k, int64_t size_n, + int64_t num_bits); + +int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales); #endif -void squeezellm_gemm( - torch::Tensor vec, - torch::Tensor mat, - torch::Tensor mul, - torch::Tensor lookup_table); - -torch::Tensor gptq_gemm( - torch::Tensor a, - torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, - torch::Tensor b_g_idx, - bool use_exllama, - int bit); - -void gptq_shuffle( - torch::Tensor q_weight, - torch::Tensor q_perm, - int bit); - -void static_scaled_fp8_quant( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& scale); - -void dynamic_scaled_fp8_quant( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& scale); - -void moe_align_block_size( - torch::Tensor topk_ids, - int num_experts, - int block_size, - torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad); +void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor lookup_table); + +torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, + bool use_exllama, int bit); + +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit); + +void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& scale); + +void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& scale); + +void moe_align_block_size(torch::Tensor topk_ids, int num_experts, + int block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad); #ifndef USE_ROCM using fptr_t = uint64_t; -fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, - const std::vector &handles, - const std::vector &offsets, int rank, - bool full_nvlink); -bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, +fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, int rank, + bool full_nvlink); +bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, bool full_nvlink); -void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); -void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, - torch::Tensor &out); +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, + torch::Tensor& out); void dispose(fptr_t _fa); int meta_size(); -void register_buffer(fptr_t _fa, torch::Tensor &t, - const std::vector &handles, - const std::vector &offsets); -std::pair, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa); -void register_graph_buffers(fptr_t _fa, const std::vector &handles, - const std::vector> &offsets); +void register_buffer(fptr_t _fa, torch::Tensor& t, + const std::vector& handles, + const std::vector& offsets); +std::pair, std::vector> get_graph_buffer_ipc_meta( + fptr_t _fa); +void register_graph_buffers(fptr_t _fa, const std::vector& handles, + const std::vector>& offsets); #endif diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index d80cb6973fad6..69d6dae1c26bc 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -7,14 +7,10 @@ namespace vllm { -template +template inline __device__ void apply_token_rotary_embedding( - scalar_t* __restrict__ arr, - const scalar_t* __restrict__ cos_ptr, - const scalar_t* __restrict__ sin_ptr, - int rot_offset, - int embed_dim) -{ + scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) { int x_index, y_index; scalar_t cos, sin; if (IS_NEOX) { @@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding( arr[y_index] = y * cos + x * sin; } -template +template inline __device__ void apply_rotary_embedding( - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] - const scalar_t* cache_ptr, - const int head_size, - const int num_heads, - const int num_kv_heads, - const int rot_dim, - const int token_idx, - const int64_t query_stride, - const int64_t key_stride) -{ + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* cache_ptr, const int head_size, const int num_heads, + const int num_kv_heads, const int rot_dim, const int token_idx, + const int64_t query_stride, const int64_t key_stride) { const int embed_dim = rot_dim / 2; const scalar_t* cos_ptr = cache_ptr; const scalar_t* sin_ptr = cache_ptr + embed_dim; @@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding( const int head_idx = i / embed_dim; const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_token_rotary_embedding(query + token_head, cos_ptr, - sin_ptr, rot_offset, embed_dim); + apply_token_rotary_embedding( + query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); } const int nk = num_kv_heads * embed_dim; @@ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding( const int head_idx = i / embed_dim; const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_token_rotary_embedding(key + token_head, cos_ptr, - sin_ptr, rot_offset, embed_dim); + apply_token_rotary_embedding( + key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); } } -template +template __global__ void rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] - const int rot_dim, - const int64_t query_stride, - const int64_t key_stride, - const int num_heads, - const int num_kv_heads, - const int head_size) { + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size) { // Each thread block is responsible for one token. const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; - apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); + apply_rotary_embedding( + query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, + token_idx, query_stride, key_stride); } -template +template __global__ void batched_rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] - const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens] - const int rot_dim, - const int64_t query_stride, - const int64_t key_stride, - const int num_heads, - const int num_kv_heads, - const int head_size) { + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] + // or [num_tokens] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size) { // Each thread block is responsible for one token. const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; - const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; + const scalar_t* cache_ptr = + cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; - apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); + apply_rotary_embedding( + query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, + token_idx, query_stride, key_stride); } -} // namespace vllm +} // namespace vllm void rotary_embedding( - torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] - torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] - int head_size, - torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox) { + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] + int head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox) { int64_t num_tokens = query.numel() / query.size(-1); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; @@ -135,36 +141,21 @@ void rotary_embedding( dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - query.scalar_type(), - "rotary_embedding", - [&] { - if (is_neox) { - vllm::rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } else { - vllm::rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } - }); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { + if (is_neox) { + vllm::rotary_embedding_kernel<<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), rot_dim, + query_stride, key_stride, num_heads, num_kv_heads, head_size); + } else { + vllm::rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + rot_dim, query_stride, key_stride, num_heads, num_kv_heads, + head_size); + } + }); } /* @@ -172,14 +163,15 @@ Batched version of rotary embedding, pack multiple LoRAs together and process in batched manner. */ void batched_rotary_embedding( - torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] - torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] - int head_size, - torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox, - int rot_dim, - torch::Tensor& cos_sin_cache_offsets // [num_tokens] + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] + int head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox, int rot_dim, + torch::Tensor& cos_sin_cache_offsets // [num_tokens] ) { int64_t num_tokens = cos_sin_cache_offsets.size(0); int num_heads = query.size(-1) / head_size; @@ -191,36 +183,21 @@ void batched_rotary_embedding( dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - query.scalar_type(), - "rotary_embedding", - [&] { - if (is_neox) { - vllm::batched_rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } else { - vllm::batched_rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } - }); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { + if (is_neox) { + vllm::batched_rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size); + } else { + vllm::batched_rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size); + } + }); } diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index f5b4865506568..cba07f0ae9f2a 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -8,116 +8,87 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); // Attention ops - ops.def( - "paged_attention_v1", - &paged_attention_v1, - "Compute the attention between an input query and the cached keys/values using PagedAttention."); - ops.def( - "paged_attention_v2", - &paged_attention_v2, - "PagedAttention V2."); + ops.def("paged_attention_v1", &paged_attention_v1, + "Compute the attention between an input query and the cached " + "keys/values using PagedAttention."); + ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2."); // Activation ops - ops.def( - "silu_and_mul", - &silu_and_mul, - "Activation function used in SwiGLU."); - ops.def( - "gelu_and_mul", - &gelu_and_mul, - "Activation function used in GeGLU with `none` approximation."); - ops.def( - "gelu_tanh_and_mul", - &gelu_tanh_and_mul, - "Activation function used in GeGLU with `tanh` approximation."); - ops.def( - "gelu_new", - &gelu_new, - "GELU implementation used in GPT-2."); - ops.def( - "gelu_fast", - &gelu_fast, - "Approximate GELU implementation."); + ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); + ops.def("gelu_and_mul", &gelu_and_mul, + "Activation function used in GeGLU with `none` approximation."); + ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, + "Activation function used in GeGLU with `tanh` approximation."); + ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); + ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation."); // Layernorm - ops.def( - "rms_norm", - &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); + ops.def("rms_norm", &rms_norm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); - ops.def( - "fused_add_rms_norm", - &fused_add_rms_norm, - "In-place fused Add and RMS Normalization"); + ops.def("fused_add_rms_norm", &fused_add_rms_norm, + "In-place fused Add and RMS Normalization"); // Rotary embedding - ops.def( - "rotary_embedding", - &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + ops.def("rotary_embedding", &rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); - ops.def( - "batched_rotary_embedding", - &batched_rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)"); + ops.def("batched_rotary_embedding", &batched_rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key " + "(supports multiple loras)"); // Quantization ops #ifndef USE_ROCM ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM"); ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); - ops.def("marlin_gemm", &marlin_gemm, "Marlin (Dense) Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); + ops.def("marlin_gemm", &marlin_gemm, + "Marlin (Dense) Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, + "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, + "gptq_marlin Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_repack", &gptq_marlin_repack, + "gptq_marlin repack from GPTQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); - ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column quantization."); + ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, + "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or " + "per-row/column quantization."); #endif - + ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); - ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor"); - ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); - ops.def( - "moe_align_block_size", - &moe_align_block_size, - "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."); + ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, + "Compute FP8 quantized tensor for given scaling factor"); + ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, + "Compute FP8 quantized tensor and scaling factor"); + ops.def("moe_align_block_size", &moe_align_block_size, + "Aligning the number of tokens to be processed by each expert such " + "that it is divisible by the block size."); // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); - cache_ops.def( - "swap_blocks", - &swap_blocks, - "Swap in (out) the cache blocks from src to dst"); - cache_ops.def( - "copy_blocks", - ©_blocks, - "Copy the cache blocks from src to dst"); - cache_ops.def( - "reshape_and_cache", - &reshape_and_cache, - "Reshape the key and value tensors and cache them"); - cache_ops.def( - "reshape_and_cache_flash", - &reshape_and_cache_flash, - "Reshape the key and value tensors and cache them"); - cache_ops.def( - "convert_fp8", - &convert_fp8, - "Convert the key and value cache to fp8 data type"); + cache_ops.def("swap_blocks", &swap_blocks, + "Swap in (out) the cache blocks from src to dst"); + cache_ops.def("copy_blocks", ©_blocks, + "Copy the cache blocks from src to dst"); + cache_ops.def("reshape_and_cache", &reshape_and_cache, + "Reshape the key and value tensors and cache them"); + cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash, + "Reshape the key and value tensors and cache them"); + cache_ops.def("convert_fp8", &convert_fp8, + "Convert the key and value cache to fp8 data type"); // Cuda utils - pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); - cuda_utils.def( - "get_device_attribute", - &get_device_attribute, - "Gets the specified device attribute."); + pybind11::module cuda_utils = + m.def_submodule("cuda_utils", "vLLM cuda utils"); + cuda_utils.def("get_device_attribute", &get_device_attribute, + "Gets the specified device attribute."); - cuda_utils.def( - "get_max_shared_memory_per_block_device_attribute", - &get_max_shared_memory_per_block_device_attribute, - "Gets the maximum shared memory per block device attribute."); + cuda_utils.def("get_max_shared_memory_per_block_device_attribute", + &get_max_shared_memory_per_block_device_attribute, + "Gets the maximum shared memory per block device attribute."); #ifndef USE_ROCM // Custom all-reduce kernels @@ -134,5 +105,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { custom_ar.def("register_graph_buffers", ®ister_graph_buffers, "register_graph_buffers"); #endif - } diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu index 4415316e1e8cd..255844eec56d4 100644 --- a/csrc/quantization/aqlm/gemm_kernels.cu +++ b/csrc/quantization/aqlm/gemm_kernels.cu @@ -25,32 +25,28 @@ #include #include - namespace vllm { namespace aqlm { __global__ void Code1x16MatVec( - const int4* __restrict__ A, - const int4* __restrict__ B, - int4* __restrict__ C, - const int4* __restrict__ codebook, - const int prob_m, - const int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. - const int codebook_stride // as int4. + const int4* __restrict__ A, const int4* __restrict__ B, + int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m, + const int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) - { - // advance to the correct codebook, this easy because we only multiply one column of the codebook. + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) - { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; } } @@ -67,8 +63,7 @@ __global__ void Code1x16MatVec( // We pad shared memory to avoid bank conflicts during reads __syncthreads(); for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { - if (b_gl_rd + i < prob_k / 8) - sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; } __syncthreads(); b_gl_rd += 32 * 8; @@ -76,22 +71,19 @@ __global__ void Code1x16MatVec( int b_sh_rd = 9 * (threadIdx.x % 32); if (pred && a_gl_rd < a_gl_end) { const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { uint32_t dec[4]; - // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't - // actually help us; this brings > 2x speedup. - asm volatile ( - "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" - : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) - : "l"((void*) &codebook[enc[i]]) - ); + // We bypass the L1 cache to avoid massive amounts of memory streaming + // that doesn't actually help us; this brings > 2x speedup. + asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*)&codebook[enc[i]])); half2* a = reinterpret_cast(&dec); half2* b = reinterpret_cast(&sh_b[b_sh_rd]); half2 res2 = {}; - #pragma unroll - for (int j = 0; j < 4; j++) - res2 = __hfma2(a[j], b[j], res2); +#pragma unroll + for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2); res += __half2float(res2.x) + __half2float(res2.y); b_sh_rd++; } @@ -100,37 +92,33 @@ __global__ void Code1x16MatVec( } if (pred) { - #pragma unroll - for (int i = 16; i > 0; i /= 2) - res += __shfl_down_sync(0xffffffff, res, i); +#pragma unroll + for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); if (threadIdx.x % 32 == 0) reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); } } __global__ void Code2x8MatVec( - const int4* __restrict__ A, - const int4* __restrict__ B, - int4* __restrict__ C, - const int4* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. - const int codebook_stride // as int4. + const int4* __restrict__ A, const int4* __restrict__ B, + int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) - { - // advance to the correct codebook, this easy because we only multiply one column of the codebook. + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) - { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; } } @@ -148,9 +136,8 @@ __global__ void Code2x8MatVec( for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { int4 dec = codebook[i]; - #pragma unroll - for (int j = 0; j < 8; j++) - sh_code[8 * i + (j + lane) % 8] = dec; +#pragma unroll + for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; } __syncthreads(); @@ -161,8 +148,7 @@ __global__ void Code2x8MatVec( // We pad shared memory to avoid bank conflicts during reads __syncthreads(); for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { - if (b_gl_rd + i < prob_k / 8) - sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; } __syncthreads(); b_gl_rd += 32 * 8; @@ -170,13 +156,15 @@ __global__ void Code2x8MatVec( int b_sh_rd = 9 * (threadIdx.x % 32); if (pred && a_gl_rd < a_gl_end) { const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { - half2* a0 = reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); - half2* a1 = reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); - half2* b = reinterpret_cast(&sh_b[b_sh_rd]); + half2* a0 = + reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = + reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); + half2* b = reinterpret_cast(&sh_b[b_sh_rd]); half2 res2 = {}; - #pragma unroll +#pragma unroll for (int j = 0; j < 4; j++) res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2); res += __half2float(res2.x) + __half2float(res2.y); @@ -187,36 +175,31 @@ __global__ void Code2x8MatVec( } if (pred) { - #pragma unroll - for (int i = 16; i > 0; i /= 2) - res += __shfl_down_sync(0xffffffff, res, i); +#pragma unroll + for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); if (threadIdx.x % 32 == 0) reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); } } - __global__ void Code1x16Dequant( - const int4* __restrict__ A, - int4* __restrict__ C, - const int4* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m. - const int codebook_stride // as int4 + const int4* __restrict__ A, int4* __restrict__ C, + const int4* __restrict__ codebook, int prob_m, int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long, sums to m. + const int codebook_stride // as int4 ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) - { - // advance to the correct codebook, this easy because we only multiply one column of the codebook. + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) - { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; } } @@ -231,17 +214,15 @@ __global__ void Code1x16Dequant( while (iters--) { if (pred && a_gl_rd < a_gl_end) { const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { int4 chunk; auto dec = reinterpret_cast(&chunk); - // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't - // actually help us; this brings > 2x speedup. - asm volatile ( - "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" - : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) - : "l"((void*) &codebook[enc[i]]) - ); + // We bypass the L1 cache to avoid massive amounts of memory streaming + // that doesn't actually help us; this brings > 2x speedup. + asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*)&codebook[enc[i]])); C[a_gl_rd * 8 + i] = chunk; } @@ -250,28 +231,25 @@ __global__ void Code1x16Dequant( } } - __global__ void Code2x8Dequant( - const int4* __restrict__ A, - int4* __restrict__ C, - const int4* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. - const int codebook_stride // as int4 + const int4* __restrict__ A, int4* __restrict__ C, + const int4* __restrict__ codebook, int prob_m, int prob_k, + const int4 + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long, corresponds to cols. + const int codebook_stride // as int4 ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) - { - // advance to the correct codebook, this easy because we only multiply one column of the codebook. + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) - { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; } } @@ -290,9 +268,8 @@ __global__ void Code2x8Dequant( for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { int4 dec = codebook[i]; - #pragma unroll - for (int j = 0; j < 8; j++) - sh_code[8 * i + (j + lane) % 8] = dec; +#pragma unroll + for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; } __syncthreads(); @@ -302,12 +279,14 @@ __global__ void Code2x8Dequant( while (iters--) { if (pred && a_gl_rd < a_gl_end) { const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { int4 chunk; - half2* a0 = reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); - half2* a1 = reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); - #pragma unroll + half2* a0 = + reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = + reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); +#pragma unroll for (int j = 0; j < 4; j++) reinterpret_cast(&chunk)[j] = __hadd2(a0[j], a1[j]); C[a_gl_rd * 8 + i] = chunk; @@ -317,22 +296,15 @@ __global__ void Code2x8Dequant( } } -inline int ceildiv(int a, int b) { - return (a + b - 1) / b; -} +inline int ceildiv(int a, int b) { return (a + b - 1) / b; } const int THREAD_M = 16; -void code1x16_matvec_cuda( - const void* __restrict__ A, - const void* __restrict__ B, - void* __restrict__ C, - const void* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, - const int codebook_stride -) { +void code1x16_matvec_cuda(const void* __restrict__ A, + const void* __restrict__ B, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, + int prob_k, const int4 codebook_a_sizes, + const int codebook_stride) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); int waves = 0; @@ -345,28 +317,16 @@ void code1x16_matvec_cuda( int blocks = ceildiv(prob_m, thread_m); int threads = 32 * thread_m; cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - Code1x16MatVec<<>>( - (const int4*) A, - (const int4*) B, - (int4*) C, - (const int4*) codebook, - prob_m, - prob_k, - codebook_a_sizes, - codebook_stride - ); + Code1x16MatVec<<>>( + (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, + prob_k, codebook_a_sizes, codebook_stride); } -void code2x8_matvec_cuda( - const void* __restrict__ A, - const void* __restrict__ B, - void* __restrict__ C, - const void* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, - const int codebook_stride -) { +void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B, + void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, + int prob_k, const int4 codebook_a_sizes, + const int codebook_stride) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); int waves = 0; @@ -379,30 +339,20 @@ void code2x8_matvec_cuda( int blocks = ceildiv(prob_m, thread_m); int threads = 32 * thread_m; int shared = 16 * (2 * 256 * 8 + 32 * 9); - cudaFuncSetAttribute( - Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared - ); + cudaFuncSetAttribute(Code2x8MatVec, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); Code2x8MatVec<<>>( - (const int4*) A, - (const int4*) B, - (int4*) C, - (const int4*) codebook, - prob_m, - prob_k, - codebook_a_sizes, - codebook_stride - ); + (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, + prob_k, codebook_a_sizes, codebook_stride); } void code1x16_dequant_cuda( - const void* __restrict__ A, - void* __restrict__ C, - const void* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. - const int codebook_stride // as int4. + const void* __restrict__ A, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. ) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); @@ -417,25 +367,21 @@ void code1x16_dequant_cuda( int threads = 32 * thread_m; cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); Code1x16Dequant<<>>( - (const int4*) A, - (int4*) C, - (const int4*) codebook, - prob_m, - prob_k, - codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. - codebook_stride // as int4. + (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long. + codebook_stride // as int4. ); } // Dequantizes the code and codebook into weights. -void code2x8_dequant_cuda( - const void* __restrict__ A, - void* __restrict__ C, - const void* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. - const int codebook_stride // as int4 +void code2x8_dequant_cuda( + const void* __restrict__ A, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, int prob_k, + const int4 + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long, corresponds to cols. + const int codebook_stride // as int4 ) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); @@ -451,74 +397,50 @@ void code2x8_dequant_cuda( int shared = 16 * (2 * 256 * 8 + 32 * 9); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cudaFuncSetAttribute( - Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared - ); + cudaFuncSetAttribute(Code2x8Dequant, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared); Code2x8Dequant<<>>( - (const int4*) A, - (int4*) C, - (const int4*) codebook, - prob_m, - prob_k, - codebook_a_sizes, - codebook_stride - ); + (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, + codebook_a_sizes, codebook_stride); } -int codebook_stride(const torch::Tensor& codebooks) -{ +int codebook_stride(const torch::Tensor& codebooks) { return codebooks.stride(0) * codebooks.element_size() / sizeof(int4); } void code1x16_matvec( - const torch::Tensor& A, - const torch::Tensor& B, - torch::Tensor& C, - const torch::Tensor& codebook, - const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long. + const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C, + const torch::Tensor& codebook, + const int4 codebook_a_sizes // cumulative sizes of A spanning each + // codebook, at most 3 long. ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); int prob_m = C.size(0); int prob_k = B.size(0); - code1x16_matvec_cuda( - A.data_ptr(), - B.data_ptr(), - C.data_ptr(), - codebook.data_ptr(), - prob_m, - prob_k, - codebook_a_sizes, - codebook_stride(codebook) - ); + code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), + codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, + codebook_stride(codebook)); } -torch::Tensor code1x16_matmat( - const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const int4 codebook_a_sizes, - const std::optional& bias) { +torch::Tensor code1x16_matmat(const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias) { auto input_sizes = input.sizes(); auto out_features = codes.size(0) * codebooks.size(2); auto flat_input = input.reshape({-1, input.size(-1)}); - auto flat_output = torch::empty({flat_input.size(0), out_features}, - torch::TensorOptions() - .dtype(input.dtype()) - .device(input.device()) - ); + auto flat_output = torch::empty( + {flat_input.size(0), out_features}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); for (int i = 0; i < flat_input.size(0); ++i) { auto input_vec = flat_input.index({i}); auto output_vec = flat_output.index({i}); - code1x16_matvec( - codes.squeeze(2), - input_vec, - output_vec, - codebooks, - codebook_a_sizes - ); + code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, + codebook_a_sizes); } flat_output *= scales.flatten().unsqueeze(0); @@ -533,55 +455,35 @@ torch::Tensor code1x16_matmat( return output; } -void code2x8_matvec( - const torch::Tensor& A, - const torch::Tensor& B, - torch::Tensor& C, - const torch::Tensor& codebook, - const int4 codebook_a_sizes -) { +void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B, + torch::Tensor& C, const torch::Tensor& codebook, + const int4 codebook_a_sizes) { const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); int prob_m = C.size(0); int prob_k = B.size(0); - code2x8_matvec_cuda( - A.data_ptr(), - B.data_ptr(), - C.data_ptr(), - codebook.data_ptr(), - prob_m, - prob_k, - codebook_a_sizes, - 2 * codebook_stride(codebook) - ); + code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), + codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, + 2 * codebook_stride(codebook)); } -torch::Tensor code2x8_matmat( - const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const int4 codebook_a_sizes, - const std::optional& bias -) { +torch::Tensor code2x8_matmat(const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias) { auto input_sizes = input.sizes(); auto out_features = codes.size(0) * codebooks.size(2); auto flat_input = input.reshape({-1, input.size(-1)}); - auto flat_output = torch::empty({flat_input.size(0), out_features}, - torch::TensorOptions() - .dtype(input.dtype()) - .device(input.device()) - ); + auto flat_output = torch::empty( + {flat_input.size(0), out_features}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); for (int i = 0; i < flat_input.size(0); ++i) { auto input_vec = flat_input.index({i}); auto output_vec = flat_output.index({i}); - code2x8_matvec( - codes.squeeze(2), - input_vec, - output_vec, - codebooks, - codebook_a_sizes - ); + code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, + codebook_a_sizes); } flat_output *= scales.flatten().unsqueeze(0); if (bias.has_value()) { @@ -596,64 +498,56 @@ torch::Tensor code2x8_matmat( } // Accumulate the partition sizes. -int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) -{ +int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) { int4 cumulative_sizes; auto cumulative_size = &cumulative_sizes.x; int i = 0; int last = 0; assert(codebook_partition_sizes.size(0) <= 4); - for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) - { + for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) { *cumulative_size = codebook_partition_sizes[i].item() + last; last = *cumulative_size; } // fill in the rest with unreachable. - for (; i < 4; ++i, ++cumulative_size) - { - *cumulative_size = last*10; + for (; i < 4; ++i, ++cumulative_size) { + *cumulative_size = last * 10; } return cumulative_sizes; } -} // namespace aqlm -} // namespace vllm - +} // namespace aqlm +} // namespace vllm -torch::Tensor aqlm_gemm( - const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const torch::Tensor& codebook_partition_sizes, - const std::optional& bias -) -{ - int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); +torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias) { + int4 cumulative_sizes = + vllm::aqlm::accumulate_sizes(codebook_partition_sizes); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const entries = codebooks.size(1); - if (nbooks == 1 && entries == (1 << 16)) - { - return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); + if (nbooks == 1 && entries == (1 << 16)) { + return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, + cumulative_sizes, bias); } - if (nbooks == 2 && entries == (1 << 8)) - { - return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); + if (nbooks == 2 && entries == (1 << 8)) { + return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, + cumulative_sizes, bias); } - TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, + " entries is not currently supported.") return {}; } -torch::Tensor aqlm_dequant( - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& codebook_partition_sizes -) -{ - int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); +torch::Tensor aqlm_dequant(const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes) { + int4 cumulative_sizes = + vllm::aqlm::accumulate_sizes(codebook_partition_sizes); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const entries = codebooks.size(1); @@ -668,45 +562,37 @@ torch::Tensor aqlm_dequant( assert(out_features = codebook_partition_sizes.sum().item()); auto weights = torch::empty({out_features, in_features}, - torch::TensorOptions() - .dtype(codebooks.dtype()) - .device(codebooks.device()) - ); + torch::TensorOptions() + .dtype(codebooks.dtype()) + .device(codebooks.device())); + + if (nbooks == 1 && entries == (1 << 16)) { + vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(), + codebooks.data_ptr(), out_features, + in_features, cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); - if (nbooks == 1 && entries == (1 << 16)) - { - vllm::aqlm::code1x16_dequant_cuda( - codes.data_ptr(), - weights.data_ptr(), - codebooks.data_ptr(), - out_features, - in_features, - cumulative_sizes, - vllm::aqlm::codebook_stride(codebooks)); - - // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.) - // weights *= scales.index({"...", 0, 0}); - - return weights; + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower + // and not consistent with gemv implementation.) weights *= + // scales.index({"...", 0, 0}); + + return weights; } - if (nbooks == 2 && entries == (1 << 8)) - { - vllm::aqlm::code2x8_dequant_cuda( - codes.data_ptr(), - weights.data_ptr(), - codebooks.data_ptr(), - out_features, - in_features, - cumulative_sizes, - vllm::aqlm::codebook_stride(codebooks)); - - // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation) - // weights *= scales.index({"...", 0, 0}); - - return weights; + if (nbooks == 2 && entries == (1 << 8)) { + vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(), + codebooks.data_ptr(), out_features, + in_features, cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); + + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower + // and not consistent with gemv implementation) weights *= + // scales.index({"...", 0, 0}); + + return weights; } - TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, + " entries is not currently supported.") return {}; } diff --git a/csrc/quantization/awq/dequantize.cuh b/csrc/quantization/awq/dequantize.cuh index d1d926de18d78..813ec6716cf54 100644 --- a/csrc/quantization/awq/dequantize.cuh +++ b/csrc/quantization/awq/dequantize.cuh @@ -1,11 +1,11 @@ /* Adapted from https://github.com/mit-han-lab/llm-awq -Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +Modified from NVIDIA FasterTransformer: +https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h @article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, - author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, - journal={arXiv}, - year={2023} + title={AWQ: Activation-aware Weight Quantization for LLM Compression and +Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, +Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ @@ -14,74 +14,88 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor namespace vllm { namespace awq { -__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) -{ +__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 assert(false); #else - uint4 result; + uint4 result; - uint32_t* h = reinterpret_cast(&result); - uint32_t const i4s = reinterpret_cast(source); + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t BOTTOM_MASK = 0x000f000f; - static constexpr uint32_t TOP_MASK = 0x00f000f0; - static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; - // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing - // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. - // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and - // elt_67 to fp16 without having to shift them to the bottom bits before hand. + // Note that the entire sequence only requires 1 shift instruction. This is + // thanks to the register packing format and the fact that we force our + // integers to be unsigned, and account for this in the fp16 subtractions. In + // addition, I exploit the fact that sub and fma have the same throughput in + // order to convert elt_23 and elt_67 to fp16 without having to shift them to + // the bottom bits before hand. - // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue - // immediately before required. - const uint32_t top_i4s = i4s >> 8; - // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[1]) - : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[2]) - : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[3]) - : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW + // dependency if we issue immediately before required. + const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); - // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the - // half2 ctor. In this case, I chose performance reliability over code readability. + // I use inline PTX below because I am not sure if the compiler will emit + // float2half instructions if I use the half2 ctor. In this case, I chose + // performance reliability over code readability. - // This is the half2 {1032, 1032} represented as an integer. - // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; - // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] - static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; - // This is the half2 {1 / 16, 1 / 16} represented as an integer. - static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; - // This is the half2 {-72, -72} represented as an integer. - // static constexpr uint32_t NEG_72 = 0xd480d480; - // Haotian: Let's use {-64, -64}. - static constexpr uint32_t NEG_64 = 0xd400d400; + // This is the half2 {1032, 1032} represented as an integer. + // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + // static constexpr uint32_t NEG_72 = 0xd480d480; + // Haotian: Let's use {-64, -64}. + static constexpr uint32_t NEG_64 = 0xd400d400; - // Finally, we construct the output numbers. - // Convert elt_01 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_23 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); - // Convert elt_45 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_67 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[2]) + : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[3]) + : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); - return result; + return result; #endif } -} // namespace awq -} // namespace vllm +} // namespace awq +} // namespace vllm diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 5aefb0bd16aef..bb8e5bbb23d7f 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -1,14 +1,12 @@ /* Adapted from https://github.com/mit-han-lab/llm-awq @article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, - author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, - journal={arXiv}, - year={2023} + title={AWQ: Activation-aware Weight Quantization for LLM Compression and +Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, +Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ - #include #include @@ -20,26 +18,20 @@ namespace vllm { namespace awq { // Pack two half values. -static inline __device__ __host__ unsigned -__pack_half2(const half x, const half y) { - unsigned v0 = *((unsigned short *)&x); - unsigned v1 = *((unsigned short *)&y); +static inline __device__ __host__ unsigned __pack_half2(const half x, + const half y) { + unsigned v0 = *((unsigned short*)&x); + unsigned v1 = *((unsigned short*)&y); return (v1 << 16) | v0; } -template -__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( - int G, - int split_k_iters, - half* __restrict__ A, - int* __restrict__ B, - half* __restrict__ scaling_factors, - int* __restrict__ zeros, - int M, - int IC, - int OC, - half* __restrict__ C) -{ +template +__global__ void __launch_bounds__(64) + gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters, + half* __restrict__ A, int* __restrict__ B, + half* __restrict__ scaling_factors, + int* __restrict__ zeros, int M, int IC, + int OC, half* __restrict__ C) { // Only support matrix n = 64 or 128 assert(N == 64 || N == 128); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 @@ -70,43 +62,46 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( static constexpr int row_stride = 2 * 32 * 8 / N; bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id + bool ld_A_flag = + (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id // bool wb_C_flag = (threadIdx.x / 4) < M; - half* A_ptr = A - + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC - + (((int)threadIdx.x) % (32 / 8)) * 8; - - int* B_ptr = B - + ((int)threadIdx.y) * (OC / 8) * (256 / N) - + (((int)threadIdx.x) / (N / 8)) * (OC / 8) - + (((int)blockIdx_y) % j_factors1) * (N / 8) - + (((int)threadIdx.x) % (N / 8)) * 1; -// Why * 1 in the above line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) - + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) - + (((int)threadIdx.x) % (32 / 8) ) * 8; - - half* B_shared_ptr = B_shared - + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) - + (((int)threadIdx.x) / (N / 8)) * (N + 8) - + (((int)threadIdx.x) % (N / 8)) * 8; - - int* zeros_ptr = zeros - + (((int)blockIdx_y) % j_factors1) * (N / 8) - + ((int)threadIdx.x) % (N / 8); - - half* scaling_factors_ptr = scaling_factors - + (((int)blockIdx_y) % j_factors1) * N - + (((int)threadIdx.x) % (N / 8)) * 8; - - half* C_ptr = C - + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim - + (((int)blockIdx_y) % j_factors1) * N - + ((int)threadIdx.y) * (N / 2) - + (((int)threadIdx.x) % 4) * 2; + half* A_ptr = + A + + (((int)blockIdx_y) / j_factors1 * 16 + + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * + IC + + (((int)threadIdx.x) % (32 / 8)) * 8; + + int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) + + (((int)threadIdx.x) / (N / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + (((int)threadIdx.x) % (N / 8)) * 1; + // Why * 1 in the above line? + + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8)) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + + (((int)threadIdx.x) / (N / 8)) * (N + 8) + + (((int)threadIdx.x) % (N / 8)) * 8; + + int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) + + ((int)threadIdx.x) % (N / 8); + + half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * N + + (((int)threadIdx.x) % (N / 8)) * 8; + + half* C_ptr = + C + + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) + + (((int)threadIdx.x) % 4) * 2; // preload s.f. and zeros int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; @@ -115,57 +110,83 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; __syncthreads(); // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - if (ld_A_flag) - { + if (ld_A_flag) { *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); - } - else - { + } else { *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); } // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + uint4 B_loaded_scale = + *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); /* - if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ - printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); + if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && + threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, + B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, + B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); } */ // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { - // B: 32 x 136 (128+8) float16 // each warp: 32 x 4 - // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 - // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); - // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) - uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus + // zero -> WB UINT4 + // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * + // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) + // * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * + // 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * + // 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * + // 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N) + uint32_t B_loaded = + *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); + // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / + // 8)) * 8); - // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); + // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x + // % (cta_N / 8)) * 8); // - zero and * scale - // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = + // q * scale - zero * scale. + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); /* - if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ - printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); + if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == + 0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n", + B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); } */ // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = + B_loaded_fp16; } __syncthreads(); @@ -173,112 +194,179 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " + "addr; }\n" + : "=r"(addr) + : "l"((void*)((&(A_shared[(k_0_1 * 16)])) + + (((((int)threadIdx.x) & 15) * 40) + + ((((int)threadIdx.x) >> 4) * 8))))); __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) - : "r"(addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned*)(A_shared_warp + 0))[0]), + "=r"(((unsigned*)(A_shared_warp + 0))[1]), + "=r"(((unsigned*)(A_shared_warp + 0))[2]), + "=r"(((unsigned*)(A_shared_warp + 0))[3]) + : "r"(addr)); } for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) - ); + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " + "addr; }\n" + : "=r"(addr) + : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + + (((int)threadIdx.y) * (N / 2))) + + (ax1_0 * 16))])) + + (((((int)threadIdx.x) & 15) * (N + 8)) + + ((((int)threadIdx.x) >> 4) * 8))))); __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) - : "r"(addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr)); } } for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } -#else + #else { __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " + "%13};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " + "%13};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } -#endif + #endif } } } -// TODO: Shang: Hoist loop invariance. + // TODO: Shang: Hoist loop invariance. for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { for (int local_id = 0; local_id < 8; ++local_id) { - int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; - if (row_offset < M) - { - *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + if (row_offset < M) { + *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); } } } #endif } -__global__ void __launch_bounds__(64) dequantize_weights( - int* __restrict__ B, - half* __restrict__ scaling_factors, - int* __restrict__ zeros, - half* __restrict__ C, - int G -) -{ +__global__ void __launch_bounds__(64) + dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors, + int* __restrict__ zeros, half* __restrict__ C, int G) { int j_factors1 = 4; int row_stride2 = 4; int split_k_iters = 1; @@ -310,14 +398,30 @@ __global__ void __launch_bounds__(64) dequantize_weights( uint32_t B_loaded = *(uint32_t*)B_ptr2; uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); *(uint4*)B_shared_ptr2 = B_loaded_fp16; @@ -326,58 +430,57 @@ __global__ void __launch_bounds__(64) dequantize_weights( } } -} // namespace awq -} // namespace vllm - -torch::Tensor awq_dequantize( - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters, - int thx, - int thy) -{ - int in_c = _kernel.size(0); - int qout_c = _kernel.size(1); - int out_c = qout_c * 8; - int G = in_c / _scaling_factors.size(0); - - int x_thread = thx; - int y_thread = thy; - - int x_blocks = 1; - int y_blocks = 1; - if (thx==0) { - x_thread = qout_c; - } - if (thy==0) { - y_thread = in_c; - } - if (thx==0 && thy==0) { - x_thread = 8; - y_thread = 8; - x_blocks = (int)(qout_c / 8); - y_blocks = (int)(in_c / 8); - } +} // namespace awq +} // namespace vllm + +torch::Tensor awq_dequantize(torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int split_k_iters, int thx, + int thy) { + int in_c = _kernel.size(0); + int qout_c = _kernel.size(1); + int out_c = qout_c * 8; + int G = in_c / _scaling_factors.size(0); + + int x_thread = thx; + int y_thread = thy; + + int x_blocks = 1; + int y_blocks = 1; + if (thx == 0) { + x_thread = qout_c; + } + if (thy == 0) { + y_thread = in_c; + } + if (thx == 0 && thy == 0) { + x_thread = 8; + y_thread = 8; + x_blocks = (int)(qout_c / 8); + y_blocks = (int)(in_c / 8); + } - const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); - auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); - at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); + auto options = torch::TensorOptions() + .dtype(_scaling_factors.dtype()) + .device(_scaling_factors.device()); + at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); + auto scaling_factors = + reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); - dim3 num_blocks(x_blocks, y_blocks); - dim3 threads_per_block(x_thread, y_thread); + dim3 num_blocks(x_blocks, y_blocks); + dim3 threads_per_block(x_thread, y_thread); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - vllm::awq::dequantize_weights<<>>( - kernel, scaling_factors, zeros, de_kernel, G); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + vllm::awq::dequantize_weights<<>>( + kernel, scaling_factors, zeros, de_kernel, G); - return _de_kernel; + return _de_kernel; } // in_feats: M, IC [float16] @@ -386,61 +489,61 @@ torch::Tensor awq_dequantize( // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] // assume that batch_size < 16 for now -torch::Tensor awq_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters) -{ - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); - - auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); - at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); - int num_out_feats = _out_feats.size(-2); - int num_out_channels = _out_feats.size(-1); - - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - int group_size = num_in_channels / _scaling_factors.size(0); - - if (num_out_channels % 64 != 0) - throw std::invalid_argument("OC is not multiple of cta_N = 64"); - if (num_out_channels % 8 != 0) - throw std::invalid_argument("OC is not multiple of pack_num = 8"); - if (group_size % 32 != 0) - throw std::invalid_argument("Group size should be a multiple of 32"); - if (num_out_channels % group_size != 0) - throw std::invalid_argument("OC is not multiple of Group size"); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (num_out_channels % 128 == 0) - { - int j_factors1 = num_out_channels / 128 / 1; - dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, - num_out_channels, out_feats); - } - else if (num_out_channels % 64 == 0) - { - int j_factors1 = num_out_channels / 64 / 1; - dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, - num_out_channels, out_feats); - } - return _out_feats.sum(0); +torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + int split_k_iters) { + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + + auto options = torch::TensorOptions() + .dtype(_in_feats.dtype()) + .device(_in_feats.device()); + at::Tensor _out_feats = + torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); + int num_out_feats = _out_feats.size(-2); + int num_out_channels = _out_feats.size(-1); + + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + auto scaling_factors = + reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + int group_size = num_in_channels / _scaling_factors.size(0); + + if (num_out_channels % 64 != 0) + throw std::invalid_argument("OC is not multiple of cta_N = 64"); + if (num_out_channels % 8 != 0) + throw std::invalid_argument("OC is not multiple of pack_num = 8"); + if (group_size % 32 != 0) + throw std::invalid_argument("Group size should be a multiple of 32"); + if (num_out_channels % group_size != 0) + throw std::invalid_argument("OC is not multiple of Group size"); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (num_out_channels % 128 == 0) { + int j_factors1 = num_out_channels / 128 / 1; + dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128> + <<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + num_in_feats, num_in_channels, num_out_channels, out_feats); + } else if (num_out_channels % 64 == 0) { + int j_factors1 = num_out_channels / 64 / 1; + dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * + split_k_iters); + + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64> + <<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + num_in_feats, num_in_channels, num_out_channels, out_feats); + } + return _out_feats.sum(0); } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu index 3ec454f78c654..e62fe731a98d3 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu @@ -117,10 +117,10 @@ struct cutlass_2x_gemm { }; template -void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; @@ -136,9 +136,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, using StrideC = Stride, Int<0>>; StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto c_ptr = static_cast(out.data_ptr()); + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); auto a_scales_ptr = a_scales.data_ptr(); auto b_scales_ptr = b_scales.data_ptr(); @@ -196,10 +196,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, } // namespace -void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { TORCH_CHECK(a.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(a_scales.dtype() == torch::kFloat32); @@ -223,10 +223,10 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, } } -void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { TORCH_CHECK(a.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(a_scales.dtype() == torch::kFloat32); @@ -250,10 +250,10 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, } } -void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu index 37b096de23e3b..12efcac7bb919 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -120,10 +120,10 @@ struct cutlass_3x_gemm { }; template -void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; @@ -146,12 +146,12 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, using GemmKernel = typename Gemm::GemmKernel; typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, b_stride}; - auto c_ptr = static_cast(out.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); typename GemmKernel::EpilogueArguments epilogue_args{ {}, c_ptr, c_stride, c_ptr, c_stride}; @@ -183,10 +183,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, } } // namespace -void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32); diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu index a4e696d4a3322..dab73ac6c831e 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu @@ -2,29 +2,29 @@ #include #include -void cutlass_scaled_mm_dq_sm75(torch::Tensor &c, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales); +void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); -void cutlass_scaled_mm_dq_sm80(torch::Tensor &c, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales); +void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); -void cutlass_scaled_mm_dq_sm89(torch::Tensor &c, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales); +void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); -void cutlass_scaled_mm_dq_sm90(torch::Tensor &c, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales); +void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); -void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a, - torch::Tensor const &b, torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { int32_t major_capability; int32_t minor_capability; cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, @@ -36,14 +36,15 @@ void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a, // Checks for conformality TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); + b.size(1) == c.size(1)); TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && + b.stride(1) % 16 == 0); // 16 Byte Alignment TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); diff --git a/csrc/quantization/fp8/amd/hip_float8.h b/csrc/quantization/fp8/amd/hip_float8.h index 87c7c9ce66100..f9c80fcdec576 100644 --- a/csrc/quantization/fp8/amd/hip_float8.h +++ b/csrc/quantization/fp8/amd/hip_float8.h @@ -1,167 +1,137 @@ #pragma once #ifdef __HIPCC__ -#include + #include #else -#include -#include -#include -#include + #include + #include + #include + #include #endif #include "hip_float8_impl.h" -struct alignas(1) hip_fp8 -{ - struct from_bits_t - { - }; - HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); } - uint8_t data; - - hip_fp8() = default; - HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; - HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; - explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) - : data(v) - { - } +struct alignas(1) hip_fp8 { + struct from_bits_t {}; + HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + uint8_t data; + + hip_fp8() = default; + HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; + HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; + explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) + : data(v) {} #ifdef __HIP__MI300__ - // NOTE: ON-DEVICE... always optimal bias - explicit HIP_FP8_DEVICE hip_fp8(float v) - : data(hip_fp8_impl::to_fp8_from_fp32(v)) - { - } - - explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) - : hip_fp8(static_cast(v)) - { - } - - // Host only implementation using s/w simulation - explicit HIP_FP8_HOST -#else // __HIP__MI300__ - // both Host and DEVICE for non-MI300 using s/w simulation - explicit HIP_FP8_HOST_DEVICE -#endif // __HIP__MI300__ - hip_fp8(float v) - { - data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v); - } - - explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) - : hip_fp8(static_cast(v)) - { - } + // NOTE: ON-DEVICE... always optimal bias + explicit HIP_FP8_DEVICE hip_fp8(float v) + : data(hip_fp8_impl::to_fp8_from_fp32(v)) {} + + explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) + : hip_fp8(static_cast(v)) {} + + // Host only implementation using s/w simulation + explicit HIP_FP8_HOST +#else // __HIP__MI300__ + // both Host and DEVICE for non-MI300 using s/w simulation + explicit HIP_FP8_HOST_DEVICE +#endif // __HIP__MI300__ + hip_fp8(float v) { + data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, + true /*clip*/>(v); + } + + explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) + : hip_fp8(static_cast(v)) {} #ifdef __HIP__MI300__ - // upcast using device specific intrinsic - explicit inline HIP_FP8_DEVICE operator float() const - { - float fval; - uint32_t i32val = static_cast(data); - - // upcast - asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); - - return fval; - } - - explicit inline HIP_FP8_HOST operator float() const -#else // __HIP__MI300__ - explicit inline HIP_FP8_HOST_DEVICE operator float() const -#endif // __HIP__MI300__ - { - return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data); - } + // upcast using device specific intrinsic + explicit inline HIP_FP8_DEVICE operator float() const { + float fval; + uint32_t i32val = static_cast(data); + + // upcast + asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" + : "=v"(fval) + : "v"(i32val)); + + return fval; + } + + explicit inline HIP_FP8_HOST operator float() const +#else // __HIP__MI300__ + explicit inline HIP_FP8_HOST_DEVICE operator float() const +#endif // __HIP__MI300__ + { + return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>( + data); + } }; -namespace std -{ -inline hip_fp8 sin(hip_fp8 a) -{ - return hip_fp8(sinf(float(a))); -} -inline hip_fp8 cos(hip_fp8 a) -{ - return hip_fp8(cosf(float(a))); -} -HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) -{ - return a; -} -} // namespace std +namespace std { +inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); } +inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); } +HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; } +} // namespace std // Special operator overloading -inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) -{ - return os << float(f8); +inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) { + return os << float(f8); } // all + operator overloading with mixed types -// mixed types, always converts to f32, does computation in f32, and returns float -inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) -{ - return (fa + float(b)); +// mixed types, always converts to f32, does computation in f32, and returns +// float +inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) { + return (fa + float(b)); } -inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) -{ - return (float(a) + fb); +inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) { + return (float(a) + fb); } -inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) -{ - return hip_fp8(float(a) + float(b)); +inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) { + return hip_fp8(float(a) + float(b)); } -inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) -{ - return a = hip_fp8(float(a) + float(b)); +inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) { + return a = hip_fp8(float(a) + float(b)); } // overloading multiplication, always returns float, -inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) -{ - return float(a) * float(b); +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) { + return float(a) * float(b); } -inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) -{ - return (a * float(b)); +inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) { + return (a * float(b)); } -inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) -{ - return (float(a) * b); +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) { + return (float(a) * b); } -inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) -{ - return ((float)a * float(b)); +inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) { + return ((float)a * float(b)); } -inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) -{ - return ((float)a * float(b)); +inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) { + return ((float)a * float(b)); } // overloading for compare -inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) -{ - return (a.data == b.data); +inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) { + return (a.data == b.data); } -inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) -{ - return (a.data != b.data); +inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) { + return (a.data != b.data); } -inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) -{ - return static_cast(a) >= static_cast(b); +inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) { + return static_cast(a) >= static_cast(b); } -inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) -{ - return static_cast(a) > static_cast(b); +inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) { + return static_cast(a) > static_cast(b); } diff --git a/csrc/quantization/fp8/amd/hip_float8_impl.h b/csrc/quantization/fp8/amd/hip_float8_impl.h index e05905b4e49e8..90251c3539534 100644 --- a/csrc/quantization/fp8/amd/hip_float8_impl.h +++ b/csrc/quantization/fp8/amd/hip_float8_impl.h @@ -1,316 +1,316 @@ #pragma once -#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) -#define __HIP__MI300__ +#if defined(__HIPCC__) && \ + (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300__ #endif #ifdef __HIPCC__ -#define HIP_FP8_HOST_DEVICE __host__ __device__ -#define HIP_FP8_HOST __host__ -#define HIP_FP8_DEVICE __device__ + #define HIP_FP8_HOST_DEVICE __host__ __device__ + #define HIP_FP8_HOST __host__ + #define HIP_FP8_DEVICE __device__ #else -#define HIP_FP8_HOST_DEVICE -#define HIP_FP8_HOST -#define HIP_FP8_DEVICE + #define HIP_FP8_HOST_DEVICE + #define HIP_FP8_HOST + #define HIP_FP8_DEVICE #endif -namespace hip_fp8_impl -{ +namespace hip_fp8_impl { #ifdef __HIP__MI300__ -HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) -{ - uint8_t i8data; - union { - float fval; - uint32_t i32val; - uint8_t i8val[4]; // NOTE: not endian independent - } val; - - uint32_t ival = 0; - val.fval = v; - - if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping - val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); - } - - ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, - false); // false -> WORD0 - val.i32val = ival; - i8data = val.i8val[0]; - - return i8data; +HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) { + uint8_t i8data; + union { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // NOTE: not endian independent + } val; + + uint32_t ival = 0; + val.fval = v; + + if ((val.i32val & 0x7F800000) != + 0x7F800000) { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); + } + + ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, + false); // false -> WORD0 + val.i32val = ival; + i8data = val.i8val[0]; + + return i8data; } -#endif // __HIP__MI300__ +#endif // __HIP__MI300__ -HIP_FP8_HOST inline int clz(uint32_t x) -{ - return __builtin_clz(x); -} +HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); } #if defined(__HIPCC__) || defined(__CUDA_ARCH__) -HIP_FP8_DEVICE inline int clz(uint32_t x) -{ - return __clz(x); -} +HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); } #endif template -HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0) -{ +HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, + uint32_t rng = 0) { #ifdef __HIPCC__ - constexpr bool is_half = std::is_same::value; + constexpr bool is_half = std::is_same::value; #else - constexpr bool is_half = false; + constexpr bool is_half = false; #endif - constexpr bool is_float = std::is_same::value; - static_assert(wm + we == 7, "wm+we==7"); - static_assert(is_half || is_float, "Only half and float can be cast to f8"); - - const int mfmt = (sizeof(T) == 4) ? 23 : 10; - uint32_t x; + constexpr bool is_float = std::is_same::value; + static_assert(wm + we == 7, "wm+we==7"); + static_assert(is_half || is_float, "Only half and float can be cast to f8"); + + const int mfmt = (sizeof(T) == 4) ? 23 : 10; + uint32_t x; + if (sizeof(T) == 4) { + x = reinterpret_cast(_x); + } else { + x = reinterpret_cast(_x); + } + + uint32_t head, mantissa; + int exponent, bias; + uint32_t sign; + + if (sizeof(T) == 4) { + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; + bias = 127; + } else { + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; + bias = 15; + } + + uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); + + // Deal with inf and NaNs + if (negative_zero_nan) { if (sizeof(T) == 4) { - x = reinterpret_cast(_x); + if ((x & 0x7F800000) == 0x7F800000) { + return 0x80; + } } else { - x = reinterpret_cast(_x); + // if(__hisinf(x) || __hisnan(x)) + if ((x & 0x7C00) == 0x7C00) { + return 0x80; + } } - - uint32_t head, mantissa; - int exponent, bias; - uint32_t sign; - + } else { if (sizeof(T) == 4) { - head = x & 0xFF800000; - mantissa = x & 0x7FFFFF; - exponent = (head >> 23) & 0xFF; - sign = head >> 31; - bias = 127; + if ((x & 0x7F800000) == 0x7F800000) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } } else { - head = x & 0xFC00; - mantissa = x & 0x3FF; - exponent = (head >> 10) & 0x1F; - sign = head >> 15; - bias = 15; + if ((x & 0x7C00) == 0x7C00) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } } - - uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); - - // Deal with inf and NaNs - if (negative_zero_nan) { - if (sizeof(T) == 4) { - if ((x & 0x7F800000) == 0x7F800000) { - return 0x80; - } - } else { - // if(__hisinf(x) || __hisnan(x)) - if ((x & 0x7C00) == 0x7C00) { - return 0x80; - } - } - } else { - if (sizeof(T) == 4) { - if ((x & 0x7F800000) == 0x7F800000) { - return signed_inf + (mantissa != 0 ? 1 : 0); - } - } else { - if ((x & 0x7C00) == 0x7C00) { - return signed_inf + (mantissa != 0 ? 1 : 0); - } - } - } - if (x == 0) { - return 0; - } - - // First need to check if it is normal or denorm as there is a difference of - // implicit 1 Then need to adjust the exponent to align with the F8 exponent, - // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng - // to mantissa and truncate. And for RNE, no need to add rng. Then probably - // need to check whether there is carry and adjust exponent and mantissa again - - // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent - // bits - const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); - const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal - // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) - // f8_exponent is the converted f8 exponent with bias encoding - // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, - // the difference needs to be adjusted and mantissa shifted - int act_exponent, f8_exponent, exponent_diff; - - if (exponent == 0) { // fp32/fp16 is in denormal. - /* fp32 denormal is below 2^-127 so it is usually not a concern here, we + } + if (x == 0) { + return 0; + } + + // First need to check if it is normal or denorm as there is a difference of + // implicit 1 Then need to adjust the exponent to align with the F8 exponent, + // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng + // to mantissa and truncate. And for RNE, no need to add rng. Then probably + // need to check whether there is carry and adjust exponent and mantissa again + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent + // bits + const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int f8_denormal_act_exponent = + 1 - f8_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // f8_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, f8_exponent, exponent_diff; + + if (exponent == 0) { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ - act_exponent = exponent - bias + 1; - exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal - } else { // fp32/fp16 is normal with implicit 1 - act_exponent = exponent - bias; - if (act_exponent <= f8_denormal_act_exponent) { - /* This is the case where fp32/fp16 is normal but it is in f8 denormal - range. For example fp8 nanoo mode, denormal exponent is -7, but if the - fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1, - Therefore it needs to be adjust to -6 and mantissa shift right by 1. - So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ - exponent_diff = f8_denormal_act_exponent - act_exponent; - } else { // both fp32/fp16 and f8 are in normal range - exponent_diff = 0; // exponent_diff=0 does not mean there is no difference - // for this case, - // act_exponent could be larger. Just that it does not need shift mantissa - } - mantissa += (1 << mfmt); // Add the implicit 1 into mantissa + act_exponent = exponent - bias + 1; + exponent_diff = + f8_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } else { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if (act_exponent <= f8_denormal_act_exponent) { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal +range. For example fp8 nanoo mode, denormal exponent is -7, but if the +fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1, +Therefore it needs to be adjust to -6 and mantissa shift right by 1. +So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } else { // both fp32/fp16 and f8 are in normal range + exponent_diff = 0; // exponent_diff=0 does not mean there is no + // difference for this case, act_exponent could be + // larger. Just that it does not need shift mantissa } - - bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == - static_cast(1 << (mfmt - wm + exponent_diff - 1)); - /* This part is a bit tricky. The judgment of whether it is a tie needs to be - done before we shift right as shift right could rip off some residual part - and make something not midpoint look like midpoint. For example, the fp16 - number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after - shift right by 4 bits, it would look like midpoint. + mantissa += (1 << mfmt); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == + static_cast(1 << (mfmt - wm + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be + done before we shift right as shift right could rip off some residual part + and make something not midpoint look like midpoint. For example, the fp16 + number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after + shift right by 4 bits, it would look like midpoint. */ - if (exponent_diff > 0) { - mantissa >>= exponent_diff; - } else if (exponent_diff == -1) { - mantissa <<= -exponent_diff; + if (exponent_diff > 0) { + mantissa >>= exponent_diff; + } else if (exponent_diff == -1) { + mantissa <<= -exponent_diff; + } + bool implicit_one = mantissa & (1 << mfmt); + // if there is no implicit 1, it means the f8 is denormal and need to adjust + // to denorm exponent + f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + + f8_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + uint32_t drop_mask = (1 << (mfmt - wm)) - 1; + bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit + // that is not truncated is 1 + mantissa += + (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & + drop_mask; + + // Now we deal with overflow + if (f8_exponent == 0) { + if ((1 << mfmt) & mantissa) { + f8_exponent = 1; // denormal overflow to become normal, promote exponent } - bool implicit_one = mantissa & (1 << mfmt); - // if there is no implicit 1, it means the f8 is denormal and need to adjust - // to denorm exponent - f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); - - // Now we have the exponent and mantissa adjusted - uint32_t drop_mask = (1 << (mfmt - wm)) - 1; - bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit that - // is not truncated is 1 - mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; - - // Now we deal with overflow - if (f8_exponent == 0) { - if ((1 << mfmt) & mantissa) { - f8_exponent = 1; // denormal overflow to become normal, promote exponent - } - } else { - if ((1 << (mfmt + 1)) & mantissa) { - mantissa >>= 1; - f8_exponent++; - } + } else { + if ((1 << (mfmt + 1)) & mantissa) { + mantissa >>= 1; + f8_exponent++; } + } - mantissa >>= (mfmt - wm); - - // above range: quantize to maximum possible float of the same sign - const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); - if (f8_exponent > max_exp) { - if (clip) { - mantissa = (1 << wm) - 1; - f8_exponent = max_exp; - } else { - return signed_inf; - } - } + mantissa >>= (mfmt - wm); - if (f8_exponent == 0 && mantissa == 0) { - return negative_zero_nan ? 0 : (sign << 7); + // above range: quantize to maximum possible float of the same sign + const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); + if (f8_exponent > max_exp) { + if (clip) { + mantissa = (1 << wm) - 1; + f8_exponent = max_exp; + } else { + return signed_inf; } - mantissa &= (1 << wm) - 1; - return (sign << 7) | (f8_exponent << wm) | mantissa; + } + + if (f8_exponent == 0 && mantissa == 0) { + return negative_zero_nan ? 0 : (sign << 7); + } + mantissa &= (1 << wm) - 1; + return (sign << 7) | (f8_exponent << wm) | mantissa; } template -inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) -{ +inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) { #ifdef __HIPCC__ - constexpr bool is_half = std::is_same::value; + constexpr bool is_half = std::is_same::value; #else - constexpr bool is_half = false; + constexpr bool is_half = false; #endif - constexpr bool is_float = std::is_same::value; - static_assert(is_half || is_float, "only half and float are supported"); + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "only half and float are supported"); - constexpr int weo = is_half ? 5 : 8; - constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); + constexpr int weo = is_half ? 5 : 8; + constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); - T fInf, fNegInf, fNaN, fNeg0; + T fInf, fNegInf, fNaN, fNeg0; #ifdef __HIPCC__ - if (is_half) { - const uint16_t ihInf = 0x7C00; - const uint16_t ihNegInf = 0xFC00; - const uint16_t ihNaN = 0x7C01; - const uint16_t ihNeg0 = 0x8000; - fInf = reinterpret_cast(ihInf); - fNegInf = reinterpret_cast(ihNegInf); - fNaN = reinterpret_cast(ihNaN); - fNeg0 = reinterpret_cast(ihNeg0); - } else + if (is_half) { + const uint16_t ihInf = 0x7C00; + const uint16_t ihNegInf = 0xFC00; + const uint16_t ihNaN = 0x7C01; + const uint16_t ihNeg0 = 0x8000; + fInf = reinterpret_cast(ihInf); + fNegInf = reinterpret_cast(ihNegInf); + fNaN = reinterpret_cast(ihNaN); + fNeg0 = reinterpret_cast(ihNeg0); + } else #endif - if (is_float) { - const uint32_t ifInf = 0x7F800000; - const uint32_t ifNegInf = 0xFF800000; - const uint32_t ifNaN = 0x7F800001; - const uint32_t ifNeg0 = 0x80000000; - fInf = reinterpret_cast(ifInf); - fNegInf = reinterpret_cast(ifNegInf); - fNaN = reinterpret_cast(ifNaN); - fNeg0 = reinterpret_cast(ifNeg0); - } - - if (x == 0) { - return 0; - } - - uint32_t sign = x >> 7; - uint32_t mantissa = x & ((1 << wm) - 1); - int exponent = (x & 0x7F) >> wm; - if (negative_zero_nan) { - if (x == 0x80) { - return fNaN; - } - } else { - if (x == 0x80) { - return fNeg0; - } - if (exponent == ((1 << we) - 1)) { - return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; - } - } - typename std::conditional::type retval; - if (we == 5 && is_half && !negative_zero_nan) { - retval = x << 8; - return reinterpret_cast(retval); + if (is_float) { + const uint32_t ifInf = 0x7F800000; + const uint32_t ifNegInf = 0xFF800000; + const uint32_t ifNaN = 0x7F800001; + const uint32_t ifNeg0 = 0x80000000; + fInf = reinterpret_cast(ifInf); + fNegInf = reinterpret_cast(ifNegInf); + fNaN = reinterpret_cast(ifNaN); + fNeg0 = reinterpret_cast(ifNeg0); + } + + if (x == 0) { + return 0; + } + + uint32_t sign = x >> 7; + uint32_t mantissa = x & ((1 << wm) - 1); + int exponent = (x & 0x7F) >> wm; + if (negative_zero_nan) { + if (x == 0x80) { + return fNaN; } - - const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); - - // subnormal input - if (exponent == 0) { - // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above - int sh = 1 + clz(mantissa) - (32 - wm); - mantissa <<= sh; - exponent += 1 - sh; - mantissa &= ((1 << wm) - 1); + } else { + if (x == 0x80) { + return fNeg0; } - exponent += exp_low_cutoff - 1; - mantissa <<= wmo - wm; - - // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) - if (exponent <= 0) { - mantissa |= 1 << wmo; - mantissa >>= 1 - exponent; - exponent = 0; - } - - if (sizeof(T) == 2) { - retval = (sign << 15) | (exponent << 10) | mantissa; - } else { - retval = (sign << 31) | (exponent << 23) | mantissa; + if (exponent == ((1 << we) - 1)) { + return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; } + } + typename std::conditional::type retval; + if (we == 5 && is_half && !negative_zero_nan) { + retval = x << 8; return reinterpret_cast(retval); + } + + const int exp_low_cutoff = + (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); + + // subnormal input + if (exponent == 0) { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + clz(mantissa) - (32 - wm); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << wm) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if (exponent <= 0) { + mantissa |= 1 << wmo; + mantissa >>= 1 - exponent; + exponent = 0; + } + + if (sizeof(T) == 2) { + retval = (sign << 15) | (exponent << 10) | mantissa; + } else { + retval = (sign << 31) | (exponent << 23) | mantissa; + } + return reinterpret_cast(retval); } -} // namespace hip_fp8_impl +} // namespace hip_fp8_impl diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh index df0329f79d361..35123d7fc65d4 100644 --- a/csrc/quantization/fp8/amd/quant_utils.cuh +++ b/csrc/quantization/fp8/amd/quant_utils.cuh @@ -9,566 +9,567 @@ #include "../../../attention/dtype_float32.cuh" #include "../../../attention/dtype_bfloat16.cuh" -namespace vllm -{ +namespace vllm { #ifdef USE_ROCM namespace fp8 { -#ifdef ENABLE_FP8 + #ifdef ENABLE_FP8 template -__inline__ __device__ Tout vec_conversion(const Tin& x) -{ - return x; +__inline__ __device__ Tout vec_conversion(const Tin& x) { + return x; } template -__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale) -{ - return x; +__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, + const float scale) { + return x; } // fp8 -> half template <> -__inline__ __device__ uint16_t vec_conversion(const uint8_t& a) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - __half_raw res; - res.data = static_cast(f8); - return res.x; +__inline__ __device__ uint16_t +vec_conversion(const uint8_t& a) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8); + return res.x; } // fp8x2 -> half2 template <> -__inline__ __device__ uint32_t vec_conversion(const uint16_t& a) -{ -#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - union { - __half2_raw h2r; - uint32_t ui32; - } tmp; - tmp.h2r.x.data = f2[0]; - tmp.h2r.y.data = f2[1]; - return tmp.ui32; -#else - union { - uint16_t u16[2]; - uint32_t u32; - } tmp; - - tmp.u16[0] = vec_conversion(static_cast(a)); - tmp.u16[1] = vec_conversion(static_cast(a >> 8U)); - return tmp.u32; -#endif +__inline__ __device__ uint32_t +vec_conversion(const uint16_t& a) { + #if defined(__HIP__MI300__) && \ + defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0]; + tmp.h2r.y.data = f2[1]; + return tmp.ui32; + #else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = vec_conversion(static_cast(a)); + tmp.u16[1] = vec_conversion(static_cast(a >> 8U)); + return tmp.u32; + #endif } // fp8x4 -> half2x2 template <> -__inline__ __device__ uint2 vec_conversion(const uint32_t& a) -{ - union { - uint2 u32x2; - uint32_t u32[2]; - } tmp; - tmp.u32[0] = vec_conversion((uint16_t)a); - tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); - return tmp.u32x2; +__inline__ __device__ uint2 vec_conversion(const uint32_t& a) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a); + tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); + return tmp.u32x2; } // fp8x8 -> half2x4 template <> -__inline__ __device__ uint4 vec_conversion(const uint2& a) -{ - union { - uint4 u64x2; - uint2 u64[2]; - } tmp; - tmp.u64[0] = vec_conversion(a.x); - tmp.u64[1] = vec_conversion(a.y); - return tmp.u64x2; +__inline__ __device__ uint4 vec_conversion(const uint2& a) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = vec_conversion(a.x); + tmp.u64[1] = vec_conversion(a.y); + return tmp.u64x2; } using __nv_bfloat16 = __hip_bfloat16; // fp8 -> __nv_bfloat16 template <> -__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - float f{f8}; - return __float2bfloat16(f); +__inline__ __device__ __nv_bfloat16 +vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f); } using __nv_bfloat162 = __hip_bfloat162; // fp8x2 -> __nv_bfloat162 template <> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) -{ - __nv_bfloat162 res; - res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); - res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); - return res; +__inline__ __device__ __nv_bfloat162 +vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) { + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); + return res; } // fp8x4 -> bf16_4_t template <> -__inline__ __device__ bf16_4_t vec_conversion(const uint32_t& a) -{ - bf16_4_t res; - res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); - res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); - return res; +__inline__ __device__ bf16_4_t +vec_conversion(const uint32_t& a) { + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); + res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); + return res; } // fp8x8 -> bf16_8_t template <> -__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) -{ - bf16_4_t tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - bf16_8_t res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) { + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } // fp8 -> float template <> -__inline__ __device__ float vec_conversion(const uint8_t& a) -{ - hip_fp8 fp8{a, hip_fp8::from_bits()}; - return static_cast(fp8); +__inline__ __device__ float vec_conversion(const uint8_t& a) { + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8); } // fp8x2 -> float2 template <> -__inline__ __device__ float2 vec_conversion(const uint16_t& a) -{ -#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - float2 res; - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - res.x = f2[0]; - res.y = f2[1]; - return res; -#else - float2 res; - res.x = vec_conversion(static_cast(a)); - res.y = vec_conversion(static_cast(a >> 8U)); - return res; -#endif +__inline__ __device__ float2 +vec_conversion(const uint16_t& a) { + #if defined(__HIP__MI300__) && \ + defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0]; + res.y = f2[1]; + return res; + #else + float2 res; + res.x = vec_conversion(static_cast(a)); + res.y = vec_conversion(static_cast(a >> 8U)); + return res; + #endif } // fp8x4 -> float4 template <> -__inline__ __device__ Float4_ vec_conversion(const uint32_t& a) -{ - Float4_ res; - res.x = vec_conversion((uint16_t)a); - res.y = vec_conversion((uint16_t)(a >> 16U)); - return res; +__inline__ __device__ Float4_ +vec_conversion(const uint32_t& a) { + Float4_ res; + res.x = vec_conversion((uint16_t)a); + res.y = vec_conversion((uint16_t)(a >> 16U)); + return res; } // fp8x8 -> float8 template <> -__inline__ __device__ Float8_ vec_conversion(const uint2& a) -{ - Float4_ tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - Float8_ res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ Float8_ vec_conversion(const uint2& a) { + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } // half -> fp8 template <> -__inline__ __device__ uint8_t vec_conversion(const uint16_t& a) -{ - __half_raw tmp; - tmp.x = a; +__inline__ __device__ uint8_t +vec_conversion(const uint16_t& a) { + __half_raw tmp; + tmp.x = a; - hip_fp8 f8{static_cast(tmp.data)}; - return f8.data; + hip_fp8 f8{static_cast(tmp.data)}; + return f8.data; } // bf16 -> fp8 template <> -__inline__ __device__ uint8_t vec_conversion(const __nv_bfloat16& a) -{ - hip_fp8 res{__bfloat162float(a)}; - return res.data; +__inline__ __device__ uint8_t +vec_conversion(const __nv_bfloat16& a) { + hip_fp8 res{__bfloat162float(a)}; + return res.data; } // float -> fp8 template <> -__inline__ __device__ uint8_t vec_conversion(const float& a) -{ - hip_fp8 f8(a); - return f8.data; +__inline__ __device__ uint8_t vec_conversion(const float& a) { + hip_fp8 f8(a); + return f8.data; } // fp8x4 -> float4 template <> -__inline__ __device__ float4 vec_conversion(const uint32_t& a) -{ - Float4_ tmp = vec_conversion(a); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); - return res; +__inline__ __device__ float4 +vec_conversion(const uint32_t& a) { + Float4_ tmp = vec_conversion(a); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; } // float2 -> half2 template <> -__inline__ __device__ uint32_t vec_conversion(const float2& a) -{ - union { - half2 float16; - uint32_t uint32; - }; +__inline__ __device__ uint32_t +vec_conversion(const float2& a) { + union { + half2 float16; + uint32_t uint32; + }; - float16 = __float22half2_rn(a); - return uint32; + float16 = __float22half2_rn(a); + return uint32; } // Float4 -> half2x2 template <> -__inline__ __device__ uint2 vec_conversion(const Float4_& a) -{ - uint2 b; - float2 val; - val.x = a.x.x; - val.y = a.x.y; - b.x = vec_conversion(val); +__inline__ __device__ uint2 vec_conversion(const Float4_& a) { + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val); - val.x = a.y.x; - val.y = a.y.y; - b.y = vec_conversion(val); - return b; + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val); + return b; } // Float4 -> float4 template <> -__inline__ __device__ float4 vec_conversion(const Float4_& a) -{ - float4 b; - b.x = a.x.x; - b.y = a.x.y; - b.z = a.y.x; - b.w = a.y.y; - return b; +__inline__ __device__ float4 vec_conversion(const Float4_& a) { + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; } // Float8 -> half2x4 template <> -__inline__ __device__ uint4 vec_conversion(const Float8_& a) -{ - uint4 b; - b.x = vec_conversion(a.x); - b.y = vec_conversion(a.y); - b.z = vec_conversion(a.z); - b.w = vec_conversion(a.w); - return b; +__inline__ __device__ uint4 vec_conversion(const Float8_& a) { + uint4 b; + b.x = vec_conversion(a.x); + b.y = vec_conversion(a.y); + b.z = vec_conversion(a.z); + b.w = vec_conversion(a.w); + return b; } // float2 -> bfloat162 template <> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a) -{ - __nv_bfloat162 b = __float22bfloat162_rn(a); - return b; +__inline__ __device__ __nv_bfloat162 +vec_conversion<__nv_bfloat162, float2>(const float2& a) { + __nv_bfloat162 b = __float22bfloat162_rn(a); + return b; } // Float4 -> bfloat162x2 template <> -__inline__ __device__ bf16_4_t vec_conversion(const Float4_& a) -{ - bf16_4_t b; - b.x = __float22bfloat162_rn(a.x); - b.y = __float22bfloat162_rn(a.y); - return b; +__inline__ __device__ bf16_4_t +vec_conversion(const Float4_& a) { + bf16_4_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + return b; } // Float8 -> bfloat162x4 template <> -__inline__ __device__ bf16_8_t vec_conversion(const Float8_& a) -{ - bf16_8_t b; - b.x = __float22bfloat162_rn(a.x); - b.y = __float22bfloat162_rn(a.y); - b.z = __float22bfloat162_rn(a.z); - b.w = __float22bfloat162_rn(a.w); - return b; +__inline__ __device__ bf16_8_t +vec_conversion(const Float8_& a) { + bf16_8_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + b.z = __float22bfloat162_rn(a.z); + b.w = __float22bfloat162_rn(a.w); + return b; } +/* Scaled and vectorized conversions, for data exchange between high and low + precision domains -/* Scaled and vectorized conversions, for data exchange between high and low precision domains - - Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale ) - s.t. - Quantize(HP / scale) => FP8 - Dequant(FP8) * scale => HP + Convention of the scale in API, e.g: FP8_data = Quantization( + High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) * + scale => HP */ // fp8 -> half template <> -__inline__ __device__ uint16_t scaled_vec_conversion(const uint8_t& a, const float scale) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - __half_raw res; - res.data = static_cast(f8) * scale; - return res.x; +__inline__ __device__ uint16_t +scaled_vec_conversion(const uint8_t& a, const float scale) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8) * scale; + return res.x; } // fp8x2 -> half2 template <> -__inline__ __device__ uint32_t scaled_vec_conversion(const uint16_t& a, const float scale) -{ -#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - union { - __half2_raw h2r; - uint32_t ui32; - } tmp; - tmp.h2r.x.data = f2[0] * scale; - tmp.h2r.y.data = f2[1] * scale; - return tmp.ui32; -#else - union { - uint16_t u16[2]; - uint32_t u32; - } tmp; - - tmp.u16[0] = scaled_vec_conversion(static_cast(a), scale); - tmp.u16[1] = scaled_vec_conversion(static_cast(a >> 8U), scale); - return tmp.u32; -#endif +__inline__ __device__ uint32_t scaled_vec_conversion( + const uint16_t& a, const float scale) { + #if defined(__HIP__MI300__) && \ + defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0] * scale; + tmp.h2r.y.data = f2[1] * scale; + return tmp.ui32; + #else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = + scaled_vec_conversion(static_cast(a), scale); + tmp.u16[1] = scaled_vec_conversion( + static_cast(a >> 8U), scale); + return tmp.u32; + #endif } // fp8x4 -> half2x2 template <> -__inline__ __device__ uint2 scaled_vec_conversion(const uint32_t& a, const float scale) -{ - union { - uint2 u32x2; - uint32_t u32[2]; - } tmp; - tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); - tmp.u32[1] = scaled_vec_conversion((uint16_t)(a >> 16U), scale); - return tmp.u32x2; +__inline__ __device__ uint2 +scaled_vec_conversion(const uint32_t& a, const float scale) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); + tmp.u32[1] = + scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return tmp.u32x2; } // fp8x8 -> half2x4 template <> -__inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, const float scale) -{ - union { - uint4 u64x2; - uint2 u64[2]; - } tmp; - tmp.u64[0] = scaled_vec_conversion(a.x, scale); - tmp.u64[1] = scaled_vec_conversion(a.y, scale); - return tmp.u64x2; +__inline__ __device__ uint4 +scaled_vec_conversion(const uint2& a, const float scale) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale); + tmp.u64[1] = scaled_vec_conversion(a.y, scale); + return tmp.u64x2; } using __nv_bfloat16 = __hip_bfloat16; // fp8 -> __nv_bfloat16 template <> -__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, const float scale) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - float f{f8}; - return __float2bfloat16(f * scale); +__inline__ __device__ __nv_bfloat16 +scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, + const float scale) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f * scale); } using __nv_bfloat162 = __hip_bfloat162; // fp8x2 -> __nv_bfloat162 template <> -__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, const float scale) -{ - __nv_bfloat162 res; - res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); - res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); - return res; +__inline__ __device__ __nv_bfloat162 +scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, + const float scale) { + __nv_bfloat162 res; + res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); + res.y = + scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); + return res; } // fp8x4 -> bf16_4_t template <> -__inline__ __device__ bf16_4_t scaled_vec_conversion(const uint32_t& a, const float scale) -{ - bf16_4_t res; - res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); - res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale); - return res; +__inline__ __device__ bf16_4_t scaled_vec_conversion( + const uint32_t& a, const float scale) { + bf16_4_t res; + res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); + res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), + scale); + return res; } // fp8x8 -> bf16_8_t template <> -__inline__ __device__ bf16_8_t scaled_vec_conversion(const uint2& a, const float scale) -{ - bf16_4_t tmp1, tmp2; - tmp1 = scaled_vec_conversion(a.x, scale); - tmp2 = scaled_vec_conversion(a.y, scale); - bf16_8_t res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ bf16_8_t +scaled_vec_conversion(const uint2& a, const float scale) { + bf16_4_t tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } // fp8 -> float template <> -__inline__ __device__ float scaled_vec_conversion(const uint8_t& a, const float scale) -{ - hip_fp8 fp8{a, hip_fp8::from_bits()}; - return static_cast(fp8) * scale; +__inline__ __device__ float scaled_vec_conversion( + const uint8_t& a, const float scale) { + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8) * scale; } // fp8x2 -> float2 template <> -__inline__ __device__ float2 scaled_vec_conversion(const uint16_t& a, const float scale) -{ -#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - float2 res; - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - res.x = f2[0] * scale; - res.y = f2[1] * scale; - return res; -#else - float2 res; - res.x = scaled_vec_conversion(static_cast(a), scale); - res.y = scaled_vec_conversion(static_cast(a >> 8U), scale); - return res; -#endif +__inline__ __device__ float2 +scaled_vec_conversion(const uint16_t& a, const float scale) { + #if defined(__HIP__MI300__) && \ + defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0] * scale; + res.y = f2[1] * scale; + return res; + #else + float2 res; + res.x = scaled_vec_conversion(static_cast(a), scale); + res.y = scaled_vec_conversion(static_cast(a >> 8U), + scale); + return res; + #endif } // fp8x4 -> float4 template <> -__inline__ __device__ Float4_ scaled_vec_conversion(const uint32_t& a, const float scale) -{ - Float4_ res; - res.x = scaled_vec_conversion((uint16_t)a, scale); - res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale); - return res; +__inline__ __device__ Float4_ +scaled_vec_conversion(const uint32_t& a, const float scale) { + Float4_ res; + res.x = scaled_vec_conversion((uint16_t)a, scale); + res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return res; } // fp8x8 -> float8 template <> -__inline__ __device__ Float8_ scaled_vec_conversion(const uint2& a, const float scale) -{ - Float4_ tmp1, tmp2; - tmp1 = scaled_vec_conversion(a.x, scale); - tmp2 = scaled_vec_conversion(a.y, scale); - Float8_ res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ Float8_ +scaled_vec_conversion(const uint2& a, const float scale) { + Float4_ tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } - /* Quantize(HP / scale) => FP8 */ // TODO(Hai): vectorized to add // half -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const uint16_t& a, const float scale) -{ - __half_raw tmp; - tmp.x = a; +__inline__ __device__ uint8_t +scaled_vec_conversion(const uint16_t& a, const float scale) { + __half_raw tmp; + tmp.x = a; - hip_fp8 f8{static_cast(tmp.data)/scale}; - return f8.data; + hip_fp8 f8{static_cast(tmp.data) / scale}; + return f8.data; } // bf16 -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const __nv_bfloat16& a, const float scale) -{ - hip_fp8 res{__bfloat162float(a)/scale}; - return res.data; +__inline__ __device__ uint8_t scaled_vec_conversion( + const __nv_bfloat16& a, const float scale) { + hip_fp8 res{__bfloat162float(a) / scale}; + return res.data; } // float -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const float& a, const float scale) -{ - hip_fp8 f8(a/scale); - return f8.data; +__inline__ __device__ uint8_t +scaled_vec_conversion(const float& a, const float scale) { + hip_fp8 f8(a / scale); + return f8.data; } // fp8x4 -> float4 template <> -__inline__ __device__ float4 scaled_vec_conversion(const uint32_t& a, const float scale) -{ - Float4_ tmp = scaled_vec_conversion(a, scale); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); - return res; +__inline__ __device__ float4 +scaled_vec_conversion(const uint32_t& a, const float scale) { + Float4_ tmp = scaled_vec_conversion(a, scale); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; } -#endif // ENABLE_FP8 + #endif // ENABLE_FP8 template -__inline__ __device__ Tout convert(const Tin &x) { -#ifdef ENABLE_FP8 +__inline__ __device__ Tout convert(const Tin& x) { + #ifdef ENABLE_FP8 if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { return vec_conversion(x); } -#endif + #endif assert(false); } template -__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) { -#ifdef ENABLE_FP8 +__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { + #ifdef ENABLE_FP8 if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { return scaled_vec_conversion(x, scale); } -#endif + #endif assert(false); } -// The following macro is used to dispatch the conversion function based on the -// data type of the key and value cache. The FN is a macro that calls a function -// with template. -#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ - if (KV_DTYPE == "auto") { \ - if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ - } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ - } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ - } else { \ - TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ - } \ - } else { \ - if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + // The following macro is used to dispatch the conversion function based on + // the data type of the key and value cache. The FN is a macro that calls a + // function with template. + #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ } else { \ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ } \ } else { \ - TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ - } \ - } + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } -} // fp8 -#endif // USE_ROCM -} // namespace vllm +} // namespace fp8 +#endif // USE_ROCM +} // namespace vllm diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index b9c5d39277ca5..55be3305a9b8c 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -10,17 +10,20 @@ namespace vllm { __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { - float old; - old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) : - __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); + float old; + old = (value >= 0) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float( + atomicMin((unsigned int*)addr, __float_as_uint(value))); - return old; + return old; } #define FP8_E4M3_MAX std::numeric_limits::max() -template -__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar_t val, const float scale) { +template +__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion( + const scalar_t val, const float scale) { float x = static_cast(val) / scale; float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); return static_cast(r); @@ -32,11 +35,10 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar // So to get the right answer, *scale needs to be initialized to // a value <= 0.0 and we need to wait for all thread blocks to // finish before consuming *scale. -template -__global__ void segmented_max_reduction( - float* __restrict__ scale, - const scalar_t* __restrict__ input, - int64_t num_elems) { +template +__global__ void segmented_max_reduction(float* __restrict__ scale, + const scalar_t* __restrict__ input, + int64_t num_elems) { __shared__ float cache[1024]; int i = blockDim.x * blockIdx.x + threadIdx.x; @@ -56,7 +58,7 @@ __global__ void segmented_max_reduction( int ib = blockDim.x / 2; while (ib != 0) { if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) { - cache[threadIdx.x] = cache[threadIdx.x + ib]; + cache[threadIdx.x] = cache[threadIdx.x + ib]; } __syncthreads(); ib /= 2; @@ -64,16 +66,16 @@ __global__ void segmented_max_reduction( // Finally, since cache[0] contains the maximum for this thread block, // atomically write the max to the target location if (threadIdx.x == 0) { - atomicMaxFloat(scale, cache[0] / std::numeric_limits::max()); + atomicMaxFloat(scale, + cache[0] / std::numeric_limits::max()); } } -template -__global__ void scaled_fp8_quant_kernel( - c10::Float8_e4m3fn* __restrict__ out, - const scalar_t* __restrict__ input, - const float* __restrict__ scale, - int64_t num_elems) { +template +__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, + const scalar_t* __restrict__ input, + const float* __restrict__ scale, + int64_t num_elems) { int i = blockDim.x * blockIdx.x + threadIdx.x; while (i < num_elems) { out[i] = scaled_fp8_conversion(input[i], *scale); @@ -81,12 +83,11 @@ __global__ void scaled_fp8_quant_kernel( } } -} // namespace vllm +} // namespace vllm -void static_scaled_fp8_quant( - torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., d] - torch::Tensor& scale) // [1] +void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., d] + torch::Tensor& scale) // [1] { int64_t num_tokens = input.numel() / input.size(-1); int64_t num_elems = input.numel(); @@ -95,21 +96,16 @@ void static_scaled_fp8_quant( const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), - "scaled_fp8_quant_kernel", - [&] { - vllm::scaled_fp8_quant_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - scale.data_ptr(), - num_elems); + input.scalar_type(), "scaled_fp8_quant_kernel", [&] { + vllm::scaled_fp8_quant_kernel<<>>( + out.data_ptr(), input.data_ptr(), + scale.data_ptr(), num_elems); }); } -void dynamic_scaled_fp8_quant( - torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., d] - torch::Tensor& scale) // [1] +void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., d] + torch::Tensor& scale) // [1] { int64_t num_tokens = input.numel() / input.size(-1); int64_t num_elems = input.numel(); @@ -118,18 +114,11 @@ void dynamic_scaled_fp8_quant( const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), - "scaled_fp8_quant_kernel", - [&] { - vllm::segmented_max_reduction<<>>( - scale.data_ptr(), - input.data_ptr(), - num_elems); - vllm::scaled_fp8_quant_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - scale.data_ptr(), - num_elems); + input.scalar_type(), "scaled_fp8_quant_kernel", [&] { + vllm::segmented_max_reduction<<>>( + scale.data_ptr(), input.data_ptr(), num_elems); + vllm::scaled_fp8_quant_kernel<<>>( + out.data_ptr(), input.data_ptr(), + scale.data_ptr(), num_elems); }); } - diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/fp8/nvidia/quant_utils.cuh index 4eeacf7a6f9d9..cde26dbda18cf 100644 --- a/csrc/quantization/fp8/nvidia/quant_utils.cuh +++ b/csrc/quantization/fp8/nvidia/quant_utils.cuh @@ -10,9 +10,9 @@ namespace vllm { #ifndef USE_ROCM namespace fp8 { -#ifdef ENABLE_FP8 + #ifdef ENABLE_FP8 -#if 0 // Disable the following code to reduce the binary size. + #if 0 // Disable the following code to reduce the binary size. template __inline__ __device__ Tout vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) { @@ -177,13 +177,13 @@ __inline__ __device__ uint8_t vec_conversion( template <> __inline__ __device__ uint8_t vec_conversion( const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); -#else + #else __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8( __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type); return (uint8_t)res; -#endif + #endif } // float -> fp8 @@ -276,7 +276,7 @@ __inline__ __device__ bf16_8_t vec_conversion( from_float(b, a); return b; } -#endif + #endif /* Scaled and vectorized conversions, for data exchange between high and low precision domains Convention of the scale in API, e.g: FP8_data = @@ -286,14 +286,14 @@ __inline__ __device__ bf16_8_t vec_conversion( template __inline__ __device__ Tout scaled_vec_conversion( - const Tin &x, const float scale, const __nv_fp8_interpretation_t fp8_type) { + const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) { return x; } // fp8 -> half template <> __inline__ __device__ uint16_t scaled_vec_conversion( - const uint8_t &a, const float scale, + const uint8_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type); return float_to_half(half_to_float(tmp.x) * scale); @@ -302,7 +302,7 @@ __inline__ __device__ uint16_t scaled_vec_conversion( // fp8x2 -> half2 template <> __inline__ __device__ uint32_t scaled_vec_conversion( - const uint16_t &a, const float scale, + const uint16_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { union { uint16_t u16[2]; @@ -317,7 +317,7 @@ __inline__ __device__ uint32_t scaled_vec_conversion( // fp8x4 -> half2x2 template <> __inline__ __device__ uint2 scaled_vec_conversion( - const uint32_t &a, const float scale, + const uint32_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { union { uint2 u32x2; @@ -333,7 +333,7 @@ __inline__ __device__ uint2 scaled_vec_conversion( // fp8x8 -> half2x4 template <> __inline__ __device__ uint4 -scaled_vec_conversion(const uint2 &a, const float scale, +scaled_vec_conversion(const uint2& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { union { uint4 u64x2; @@ -348,7 +348,7 @@ scaled_vec_conversion(const uint2 &a, const float scale, template <> __inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>( - const uint8_t &a, const float scale, + const uint8_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { // Note there is no direct convert function from fp8 to bf16. // fp8 -> half @@ -362,7 +362,7 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>( template <> __inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>( - const uint16_t &a, const float scale, + const uint16_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { __nv_bfloat162 res; res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, @@ -375,7 +375,7 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>( // fp8x4 -> bf16_4_t template <> __inline__ __device__ bf16_4_t scaled_vec_conversion( - const uint32_t &a, const float scale, + const uint32_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { bf16_4_t res; res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, @@ -388,7 +388,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion( // fp8x8 -> bf16_8_t template <> __inline__ __device__ bf16_8_t scaled_vec_conversion( - const uint2 &a, const float scale, + const uint2& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { bf16_4_t tmp1, tmp2; tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); @@ -404,9 +404,8 @@ __inline__ __device__ bf16_8_t scaled_vec_conversion( // fp8 -> float template <> __inline__ __device__ float scaled_vec_conversion( - const uint8_t &a, const float scale, + const uint8_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { - // fp8 -> half __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); uint16_t tmp = res.x; @@ -418,7 +417,7 @@ __inline__ __device__ float scaled_vec_conversion( // fp8x2 -> float2 template <> __inline__ __device__ float2 scaled_vec_conversion( - const uint16_t &a, const float scale, + const uint16_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { // fp8x2 -> half2 uint32_t tmp = scaled_vec_conversion(a, scale, fp8_type); @@ -429,7 +428,7 @@ __inline__ __device__ float2 scaled_vec_conversion( // fp8x4 -> float4 template <> __inline__ __device__ Float4_ scaled_vec_conversion( - const uint32_t &a, const float scale, + const uint32_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { Float4_ res; res.x = scaled_vec_conversion((uint16_t)a, scale, fp8_type); @@ -441,7 +440,7 @@ __inline__ __device__ Float4_ scaled_vec_conversion( // fp8x8 -> float8 template <> __inline__ __device__ Float8_ scaled_vec_conversion( - const uint2 &a, const float scale, + const uint2& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { Float4_ tmp1, tmp2; tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); @@ -457,7 +456,7 @@ __inline__ __device__ Float8_ scaled_vec_conversion( // half -> fp8 template <> __inline__ __device__ uint8_t scaled_vec_conversion( - const uint16_t &a, const float scale, + const uint16_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type); @@ -467,21 +466,21 @@ __inline__ __device__ uint8_t scaled_vec_conversion( // bf16 -> fp8 template <> __inline__ __device__ uint8_t scaled_vec_conversion( - const __nv_bfloat16 &a, const float scale, + const __nv_bfloat16& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); -#else + #else __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale, __NV_SATFINITE, fp8_type); return (uint8_t)res; -#endif + #endif } // float -> fp8 template <> __inline__ __device__ uint8_t scaled_vec_conversion( - const float &a, const float scale, + const float& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type); @@ -491,78 +490,81 @@ __inline__ __device__ uint8_t scaled_vec_conversion( // fp8x4 -> float4 template <> __inline__ __device__ float4 scaled_vec_conversion( - const uint32_t &a, const float scale, + const uint32_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { Float4_ tmp = scaled_vec_conversion(a, scale, fp8_type); float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); return res; } -#endif // ENABLE_FP8 + #endif // ENABLE_FP8 template -__inline__ __device__ Tout convert(const Tin &x) { -#if 0 // Disable the following code to reduce the binary size. +__inline__ __device__ Tout convert(const Tin& x) { + #if 0 // Disable the following code to reduce the binary size. if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { return vec_conversion(x, __NV_E4M3); } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { return vec_conversion(x, __NV_E5M2); } -#endif + #endif assert(false); } template -__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) { -#ifdef ENABLE_FP8 +__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { + #ifdef ENABLE_FP8 if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { return scaled_vec_conversion(x, scale, __NV_E4M3); } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { return scaled_vec_conversion(x, scale, __NV_E5M2); } -#endif + #endif assert(false); } -// The following macro is used to dispatch the conversion function based on the -// data type of the key and value cache. The FN is a macro that calls a function -// with template. -#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ - if (KV_DTYPE == "auto") { \ - if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ - } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ - } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ - } else { \ - TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ - } \ - } else { \ - if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + // The following macro is used to dispatch the conversion function based on + // the data type of the key and value cache. The FN is a macro that calls a + // function with template. + #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ } else { \ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ } \ - } else if (KV_DTYPE == "fp8_e5m2") { \ - if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ - } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ - } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else if (KV_DTYPE == "fp8_e5m2") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ } else { \ - TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ } \ - } else { \ - TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ - } \ - } + } -} // namespace fp8 -#endif // not USE_ROCM -} // namespace vllm +} // namespace fp8 +#endif // not USE_ROCM +} // namespace vllm diff --git a/csrc/quantization/gptq/compat.cuh b/csrc/quantization/gptq/compat.cuh index 4da0bc6e2df38..1b3fb3d39103f 100644 --- a/csrc/quantization/gptq/compat.cuh +++ b/csrc/quantization/gptq/compat.cuh @@ -9,54 +9,54 @@ namespace vllm { namespace gptq { // atomicAdd for half types, to support CC < 7.x -__device__ __forceinline__ void atomicAdd_half(half* address, half val) -{ - unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); - unsigned int old = *address_as_ui; - unsigned int assumed; +__device__ __forceinline__ void atomicAdd_half(half* address, half val) { + unsigned int* address_as_ui = + (unsigned int*)((char*)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; - do - { - assumed = old; - __half_raw hsum; - hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); - half tmpres = __hadd(hsum, val); - hsum = __half_raw(tmpres); - old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; - old = atomicCAS(address_as_ui, assumed, old); - } - while (assumed != old); + do { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) + : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); } // atomicAdd for half2 types -__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) -{ - unsigned int* address_as_ui = (unsigned int*)address; - unsigned int old = *address_as_ui; - unsigned int assumed; - do - { - assumed = old; - half2 old_val = *((half2*)&old); - half2 new_val = __hadd2(old_val, val); - old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); - } - while (assumed != old); +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) { + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } while (assumed != old); } // #if defined(__CUDA_ARCH__) || defined(USE_ROCM) -#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + #if __CUDA_ARCH__ < 700 || defined(USE_ROCM) -__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } +__device__ __forceinline__ void atomicAdd(half* address, half val) { + atomicAdd_half(address, val); +} -#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) -__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } -#endif + #if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { + atomicAdd_half2(address, val); +} + #endif -#endif + #endif #endif } // namespace gptq diff --git a/csrc/quantization/gptq/matrix_view.cuh b/csrc/quantization/gptq/matrix_view.cuh index eda3436eb5375..2b6719fbdc1bc 100644 --- a/csrc/quantization/gptq/matrix_view.cuh +++ b/csrc/quantization/gptq/matrix_view.cuh @@ -1,5 +1,6 @@ /* -Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama +Adapted from https://github.com/turboderp/exllamav2 and +https://github.com/turboderp/exllama */ #ifndef _matrix_view_cuh @@ -13,260 +14,280 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo namespace vllm { namespace gptq { -class MatrixView_half -{ -public: - const half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } - __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } - - __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const - { - half2* ptr = (half2*) item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __low2half(i01); - items[1] = __high2half(i01); - items[2] = __low2half(i23); - items[3] = __high2half(i23); - } - __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const - { - half2* ptr = (half2*)item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __half2float(__low2half(i01)); - items[1] = __half2float(__high2half(i01)); - items[2] = __half2float(__low2half(i23)); - items[3] = __half2float(__high2half(i23)); - } - - __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const - { - half2* ptr = (half2*)item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __half2half2(__low2half(i01)); - items[1] = __half2half2(__high2half(i01)); - items[2] = __half2half2(__low2half(i23)); - items[3] = __half2half2(__high2half(i23)); - } +class MatrixView_half { + public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2*)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { + return __half2half2(data[row * width + column]); + } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { + return &data[row * width + column]; + } + + __device__ __forceinline__ void item4(half (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + } }; -class MatrixView_half_rw -{ -public: - half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } - __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } - __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } - - __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) - { - half2 v01 = __halves2half2(v0, v1); - half2 v23 = __halves2half2(v2, v3); - half2* ptr = (half2*) item_ptr(row, column); - ptr[0] = v01; - ptr[1] = v23; - } +class MatrixView_half_rw { + public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2*)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half* item_ptr(int row, int column) { + return &data[row * width + column]; + } + __device__ __forceinline__ void set(int row, int column, half value) { + data[row * width + column] = value; + } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { + ((half2*)data)[(row * width + column) / 2] = value; + } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, + half v2, half v3) { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2* ptr = (half2*)item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } }; -class MatrixView_q4_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x07) * 4; - return (data[row * width / 8 + column / 8] >> shift) & 0x0f; - } - - __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const - { - int shift = (column & 0x07) * 4; - uint32_t d = data[row * width / 8 + column / 8] >> shift; - items[0] = d & 0x0f; - items[1] = (d >> 4) & 0x0f; - } - - __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const - { - int shift = (column & 0x07) * 4; - uint32_t d = data[row * width / 8 + column / 8] >> shift; - items[0] = d & 0x0f; - items[1] = (d >> 4) & 0x0f; - items[2] = (d >> 8) & 0x0f; - items[3] = (d >> 12) & 0x0f; - } +class MatrixView_q4_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } }; -class MatrixView_q4_column -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (row & 0x07) * 4; - return (data[row / 8 * width + column] >> shift) & 0x0f; - } - - __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } - __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } +class MatrixView_q4_column { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { + return data[row / 8 * width + column]; + } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, + int column) { + return &data[row / 8 * width + column]; + } }; -class MatrixView_q2_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x0f) * 2; - return (data[row * width / 16 + column / 16] >> shift) & 0x03; - } - - __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const - { - int shift = (column & 0x0f) * 2; - uint32_t d = data[row * width / 16 + column / 16] >> shift; - items[0] = d & 0x03; - items[1] = (d >> 2) & 0x03; - } - - __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const - { - int shift = (column & 0x0f) * 2; - uint32_t d = data[row * width / 16 + column / 16] >> shift; - items[0] = d & 0x03; - items[1] = (d >> 2) & 0x03; - items[2] = (d >> 4) & 0x03; - items[3] = (d >> 6) & 0x03; - } +class MatrixView_q2_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x0f) * 2; + return (data[row * width / 16 + column / 16] >> shift) & 0x03; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + items[2] = (d >> 4) & 0x03; + items[3] = (d >> 6) & 0x03; + } }; -class MatrixView_q3_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int z_w = column * 3 / 32; - int z_mod = column & 0x1f; - - if (z_mod == 10) { - return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); - } else if (z_mod == 21) { - return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); - } else if (z_mod < 10) { - return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; - } else if (z_mod < 21) { - return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07; - } else { - return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07; - } +class MatrixView_q3_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int z_w = column * 3 / 32; + int z_mod = column & 0x1f; + + if (z_mod == 10) { + return (data[row * width * 3 / 32 + z_w] >> 30) | + ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); + } else if (z_mod == 21) { + return (data[row * width * 3 / 32 + z_w] >> 31) | + ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); + } else if (z_mod < 10) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; + } else if (z_mod < 21) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07; + } else { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07; } - - __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const - { - int shift = (column & 0x1f); - uint32_t d; - if (shift <= 4) { - d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); - } else if (shift == 8) { - d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); - } else if (shift <= 16) { - d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); - } else if (shift == 20) { - d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); - } else { - d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); - } - items[0] = d & 0x07; - items[1] = (d >> 3) & 0x07; - items[2] = (d >> 6) & 0x07; - items[3] = (d >> 9) & 0x07; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x1f); + uint32_t d; + if (shift <= 4) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); + } else if (shift == 8) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | + ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); + } else if (shift <= 16) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); + } else if (shift == 20) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | + ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); + } else { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); } + items[0] = d & 0x07; + items[1] = (d >> 3) & 0x07; + items[2] = (d >> 6) & 0x07; + items[3] = (d >> 9) & 0x07; + } }; -class MatrixView_q8_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x03) * 8; - return (data[row * width / 4 + column / 4] >> shift) & 0xff; - } - - __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const - { - int shift = (column & 0x03) * 8; - uint32_t d = data[row * width / 4 + column / 4] >> shift; - items[0] = d & 0xff; - items[1] = (d >> 8) & 0xff; - } - - __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const - { - int shift = (column & 0x03) * 2; - uint32_t d = data[row * width / 4 + column / 4] >> shift; - items[0] = d & 0xff; - items[1] = (d >> 8) & 0xff; - items[2] = (d >> 16) & 0xff; - items[3] = (d >> 24) & 0xff; - } +class MatrixView_q8_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x03) * 8; + return (data[row * width / 4 + column / 4] >> shift) & 0xff; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x03) * 8; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x03) * 2; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + items[2] = (d >> 16) & 0xff; + items[3] = (d >> 24) & 0xff; + } }; } // namespace gptq diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index cc56649917a8a..480c4986c3821 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -1,5 +1,6 @@ /* -Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopqwop200/GPTQ-for-LLaMa +Adapted from https://github.com/turboderp/exllamav2 and +https://github.com/qwopqwop200/GPTQ-for-LLaMa */ #include @@ -32,2044 +33,1824 @@ namespace gptq { #define DIVIDE(x, size) (((x) + (size) - 1) / (size)) #if defined(USE_ROCM) -#include -__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, - hipblasOperation_t transA, - hipblasOperation_t transB, - int m, - int n, - int k, - const half* alpha, - const half* AP, - int lda, - const half* BP, - int ldb, - const half* beta, - half* CP, - int ldc) { - return hipblasHgemm(handle, transA, transB, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(AP), lda, - reinterpret_cast(BP), ldb, - reinterpret_cast(beta), - reinterpret_cast(CP), ldc); + #include +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm( + hipblasHandle_t handle, hipblasOperation_t transA, + hipblasOperation_t transB, int m, int n, int k, const half* alpha, + const half* AP, int lda, const half* BP, int ldb, const half* beta, + half* CP, int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); } -#define hipblasHgemm __compat_hipblasHgemm + #define hipblasHgemm __compat_hipblasHgemm -// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. -#define rocblas_operation_none HIPBLAS_OP_N -#define rocblas_hgemm __compat_hipblasHgemm + // Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. + #define rocblas_operation_none HIPBLAS_OP_N + #define rocblas_hgemm __compat_hipblasHgemm #endif -__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __hadd2(result, g_result); +__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr, + const half2 g_result) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hadd2(result, g_result); } -__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __half2float(__low2half(result)) + __half2float(__high2half(result)); +__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __half2float(__low2half(result)) + __half2float(__high2half(result)); } -__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); } -__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +__forceinline__ __device__ half2 dot22_16(half2 (&dq)[8], const half* a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); } -__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); - return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +__forceinline__ __device__ half2 dot22_32(half2 (&dq)[16], const half* a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); } -__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); - return fma(result_f, qs_f, g_result); +__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = + __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); } -__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); - float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); - return fma(result_f, qs_f, g_result); +__forceinline__ __device__ float dot22_16_f(half2 (&dq)[8], const half* a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = + __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); } -__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); - float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); - return fma(result_f, qs_f, g_result); +__forceinline__ __device__ float dot22_32_f(half2 (&dq)[16], const half* a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = + __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); } -__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h) -{ - // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127 - - float result = {}; - #pragma unroll - for (int i = 0; i < 4; i++) - { - half2 w01 = dq[i]; - float w0 = __low2float(w01); - float w1 = __high2float(w01); - float x0 = __half2float(*a_ptr++); - float x1 = __half2float(*a_ptr++); - result = fma(w0, x0, result); - result = fma(w1, x1, result); - } - float qs = __half2float(qs_h); - result *= qs; - half result_h = __float2half_rn(result); - return __hadd(result_h, g_result); +__forceinline__ __device__ half dot22_8_h(half2 (&dq)[4], const half* a_ptr, + const half g_result, + const half qs_h) { + // Use FP32 accumulator to avoid potential overflow since unscaled weights are + // in the range -128..127 + + float result = {}; +#pragma unroll + for (int i = 0; i < 4; i++) { + half2 w01 = dq[i]; + float w0 = __low2float(w01); + float w1 = __high2float(w01); + float x0 = __half2float(*a_ptr++); + float x1 = __half2float(*a_ptr++); + result = fma(w0, x0, result); + result = fma(w1, x1, result); + } + float qs = __half2float(qs_h); + result *= qs; + half result_h = __float2half_rn(result); + return __hadd(result_h, g_result); } -__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); - half result_h = __hadd(__low2half(result), __high2half(result)); - return __hfma(result_h, qs_h, g_result); +__forceinline__ __device__ half dot22_16_h(half2 (&dq)[8], const half* a_ptr, + const half g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); } -__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); - half result_h = __hadd(__low2half(result), __high2half(result)); - return __hfma(result_h, qs_h, g_result); +__forceinline__ __device__ half dot22_32_h(half2 (&dq)[16], const half* a_ptr, + const half g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); } - -typedef void (*fp_gemm_half_q_half_gptq_kernel) -( - const half*, - const uint32_t*, - const uint32_t*, - const half*, - half*, - const int, - const int, - const int, - const int, - const int* -); - +typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half*, const uint32_t*, + const uint32_t*, const half*, + half*, const int, const int, + const int, const int, + const int*); template -__global__ void gemm_half_q_half_gptq_4bit_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, +__global__ void gemm_half_q_half_gptq_4bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int* __restrict__ b_q_perm -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; - - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; } + } - // Zero output - if (n >= size_n) return; + // Zero output + if (n >= size_n) return; - if (blockIdx.z == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + float scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + // Column result + float block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); } - __syncthreads(); - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - int qk = offset_k / (32 / 4); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; - - // Initial group - int zeros[4]; - float scales[4]; - half2 z1z16[4][2]; - half2 y1y16[4][2]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_f(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); - - // Column result - float block_c[m_count][4] = {}; - - // Dequantize and multiply - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_f(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); - } - - #pragma unroll - for (int j = 0; j < 4; j++) - { - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - half2 dq[4][4]; - dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); - dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); - dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); - dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); - - #pragma unroll - for (int m = 0; m < m_count; m++) - { - block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); - block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); - block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); - block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); - } - - b_ptr += size_n; - a_ptr += 8; - } - - k += 32; +#pragma unroll + for (int j = 0; j < 4; j++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][4]; + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, + false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, + false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, + false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, + false); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], + block_c[m][0]); + block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], + block_c[m][1]); + block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], + block_c[m][2]); + block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], + block_c[m][3]); + } + + b_ptr += size_n; + a_ptr += 8; } - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); - half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); - atomicAdd(out , result01); - atomicAdd(out + 1, result23); - } + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), + __float2half_rn(block_c[m][1])); + half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), + __float2half_rn(block_c[m][3])); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } } template -__global__ void gemm_half_q_half_gptq_2bit_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, +__global__ void gemm_half_q_half_gptq_2bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int* __restrict__ b_q_perm -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; - - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; } + } - // Zero output - if (n >= size_n) return; + // Zero output + if (n >= size_n) return; - if (blockIdx.z == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 2); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); } - __syncthreads(); - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - int qk = offset_k / (32 / 2); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; - - // Initial group - int zeros[4]; - half scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - // Column result - half block_c[m_count][4] = {}; - - // Dequantize and multiply - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - } - - #pragma unroll - for (int j = 0; j < 1; j++) - { - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - half2 dq[4][8]; - dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); - dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); - dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); - dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); - - #pragma unroll - for (int m = 0; m < m_count; m++) - { - block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); - block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); - block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); - block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); - } - - b_ptr += size_n; - a_ptr += 16; - } - - k += 16; +#pragma unroll + for (int j = 0; j < 1; j++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][8]; + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = + dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = + dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = + dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + + b_ptr += size_n; + a_ptr += 16; } - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); - half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); - atomicAdd(out , result01); - atomicAdd(out + 1, result23); - } + k += 16; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } } template -__global__ void gemm_half_q_half_gptq_3bit_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, +__global__ void gemm_half_q_half_gptq_3bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int* __restrict__ b_q_perm -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; - - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; } + } - // Zero output - if (n >= size_n) return; + // Zero output + if (n >= size_n) return; - if (blockIdx.z == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / 32 * 3; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); } - __syncthreads(); - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - int qk = offset_k / 32 * 3; - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; - - // Initial group - int zeros[4]; - half scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - // Column result - half block_c[m_count][4] = {}; - - // Dequantize and multiply - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - } - - #pragma unroll - for (int j = 0; j < 1; j++) - { - int4 load_int4[3]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][16]; - dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1); - dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1); - dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1); - dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1); - - #pragma unroll - for (int m = 0; m < m_count; m++) - { - block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); - block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); - block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); - block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); - } - a_ptr += 32; - } - - k += 32; +#pragma unroll + for (int j = 0; j < 1; j++) { + int4 load_int4[3]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], + size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], + size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], + size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], + size_n, zeros[3] + 1); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = + dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = + dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = + dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + a_ptr += 32; } - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); - half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); - atomicAdd(out , result01); - atomicAdd(out + 1, result23); - } + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } } template -__global__ void gemm_half_q_half_gptq_8bit_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, +__global__ void gemm_half_q_half_gptq_8bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int* __restrict__ b_q_perm -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; - - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } - } - - // Zero output - if (n >= size_n) return; - - if (blockIdx.z == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; } + } - __syncthreads(); - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - int qk = offset_k / (32 / 8); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; - - // Initial group - int zeros[4]; - half scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - // Column result - half block_c[m_count][4] = {}; - - // Dequantize and multiply - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - } + // Zero output + if (n >= size_n) return; - #pragma unroll - for (int j = 0; j < 4; j++) - { - int4 load_int4[2]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][4]; - dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1); - dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1); - dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1); - dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1); - - for (int m = 0; m < m_count; m++) - { - block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); - block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); - block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); - block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); - } - a_ptr += 8; - } - k += 32; + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 8); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); } - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); - half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); - atomicAdd(out , result01); - atomicAdd(out + 1, result23); +#pragma unroll + for (int j = 0; j < 4; j++) { + int4 load_int4[2]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, + zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, + zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, + zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, + zeros[3] + 1); + + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = + dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = + dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = + dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + a_ptr += 8; } + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } } fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel( - bool first_block, const int m_count, const int bit) -{ - #define SELECT_KERNEL(M_COUNT) \ - if (m_count == M_COUNT) { \ - if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel; \ - if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel; \ - if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel; \ - if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel; \ - } - #if BLOCK_M_SIZE_MAX >= 1 - SELECT_KERNEL(1); - #endif - #if BLOCK_M_SIZE_MAX >= 2 - SELECT_KERNEL(2); - #endif - #if BLOCK_M_SIZE_MAX >= 3 - SELECT_KERNEL(3); - #endif - #if BLOCK_M_SIZE_MAX >= 4 - SELECT_KERNEL(4); - #endif - #if BLOCK_M_SIZE_MAX >= 5 - SELECT_KERNEL(5); - #endif - #if BLOCK_M_SIZE_MAX >= 6 - SELECT_KERNEL(6); - #endif - #if BLOCK_M_SIZE_MAX >= 7 - SELECT_KERNEL(7); - #endif - #if BLOCK_M_SIZE_MAX >= 8 - SELECT_KERNEL(8); - #endif - return NULL; + bool first_block, const int m_count, const int bit) { +#define SELECT_KERNEL(M_COUNT) \ + if (m_count == M_COUNT) { \ + if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel; \ + if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel; \ + if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel; \ + if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel; \ + } +#if BLOCK_M_SIZE_MAX >= 1 + SELECT_KERNEL(1); +#endif +#if BLOCK_M_SIZE_MAX >= 2 + SELECT_KERNEL(2); +#endif +#if BLOCK_M_SIZE_MAX >= 3 + SELECT_KERNEL(3); +#endif +#if BLOCK_M_SIZE_MAX >= 4 + SELECT_KERNEL(4); +#endif +#if BLOCK_M_SIZE_MAX >= 5 + SELECT_KERNEL(5); +#endif +#if BLOCK_M_SIZE_MAX >= 6 + SELECT_KERNEL(6); +#endif +#if BLOCK_M_SIZE_MAX >= 7 + SELECT_KERNEL(7); +#endif +#if BLOCK_M_SIZE_MAX >= 8 + SELECT_KERNEL(8); +#endif + return NULL; } +void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_q_perm, + half* c, int size_m, int size_n, int size_k, + int m_count, int groups, int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + fp_gemm_half_q_half_gptq_kernel kernel = + pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>(a, b_q_weight, b_gptq_qzeros, + b_gptq_scales, c, size_m, size_n, + size_k, groups, b_q_perm); +} -void gemm_half_q_half_cuda_part -( - const half* a, - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_q_perm, - half* c, - int size_m, - int size_n, - int size_k, - int m_count, - int groups, - int bit -) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - blockDim.z = 1; - gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); - gridDim.y = DIVIDE(size_m, m_count); - gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); +__global__ void reconstruct_exllama_8bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>> - ( - a, - b_q_weight, - b_gptq_qzeros, - b_gptq_scales, - c, - size_m, - size_n, - size_k, - groups, - b_q_perm - ); -} + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; -__global__ void reconstruct_exllama_8bit_kernel -( - const uint32_t* __restrict__ b_q_weight, - const int* __restrict__ b_q_perm, - const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - const int size_k, - const int size_n, - const int groups, - half* __restrict__ b -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - __shared__ int perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; - } + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } - // Column - int n = offset_n + t * 4; - if (n >= size_n) return; + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; - // b offset - int qk = offset_k / (32 / 8); + // b offset + int qk = offset_k / (32 / 8); - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - // Initial zeros/scale - int zeros[4]; - half2 scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); - __syncthreads(); + __syncthreads(); - int k = offset_k; - int lk = 0; + int k = offset_k; + int lk = 0; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - } + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } - for (int p = 0; p < 4; p++) - { - int4 load_int4[2]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][4]; - dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1); - dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1); - dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1); - dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1); - - //half* dqh = (half*)dq; - if (b_q_perm) - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } + for (int p = 0; p < 4; p++) { + int4 load_int4[2]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, + zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, + zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, + zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, + zeros[3] + 1); + + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); } - k += 32; + } else { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } } + k += 32; + } } -__global__ void reconstruct_exllama_4bit_kernel -( - const uint32_t* __restrict__ b_q_weight, - const int* __restrict__ b_q_perm, +__global__ void reconstruct_exllama_4bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - const int size_k, - const int size_n, - const int groups, - half* __restrict__ b -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - __shared__ int perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); } - // Column - int n = offset_n + t * 4; - if (n >= size_n) return; - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // b offset - int qk = offset_k / (32 / 4); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - - // Initial zeros/scale - int zeros[4]; - half2 scales[4]; - half2 z1z16[4][2]; - half2 y1y16[4][2]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); - - __syncthreads(); - - int k = offset_k; - int lk = 0; - - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + for (int p = 0; p < 4; p++) { + half2 dq[4][4]; + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, + false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, + false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, + false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, + false); + + b_ptr += size_n; + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); } - - for (int p = 0; p < 4; p++) - { - half2 dq[4][4]; - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); - dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); - dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); - dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); - - b_ptr += size_n; - //half* dqh = (half*)dq; - if (b_q_perm) - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } + } else { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); } - k += 32; + } } + k += 32; + } } -__global__ void reconstruct_exllama_3bit_kernel -( - const uint32_t* __restrict__ b_q_weight, - const int* __restrict__ b_q_perm, +__global__ void reconstruct_exllama_3bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - const int size_k, - const int size_n, - const int groups, - half* __restrict__ b -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - __shared__ int perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; - } + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - // Column - int n = offset_n + t * 4; - if (n >= size_n) return; + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - // b offset - int qk = offset_k / 32* 3; + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } - // Initial zeros/scale - int zeros[4]; - half2 scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; - __syncthreads(); + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; - int k = offset_k; - int lk = 0; + // b offset + int qk = offset_k / 32 * 3; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - } + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); - for (int p = 0; p < 1; p++) - { - int4 load_int4[3]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][16]; - dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1); - dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1); - dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1); - dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1); - - if (b_q_perm) - { - for (int j = 0; j < 16; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 16; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } + + for (int p = 0; p < 1; p++) { + int4 load_int4[3]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], + size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], + size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], + size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], + size_n, zeros[3] + 1); + + if (b_q_perm) { + for (int j = 0; j < 16; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 16; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); } - k += 32; + } } + k += 32; + } } -__global__ void reconstruct_exllama_2bit_kernel -( - const uint32_t* __restrict__ b_q_weight, - const int* __restrict__ b_q_perm, +__global__ void reconstruct_exllama_2bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - const int size_k, - const int size_n, - const int groups, - half* __restrict__ b -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - __shared__ int perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; - } + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - // Column - int n = offset_n + t * 4; - if (n >= size_n) return; + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - // b offset - int qk = offset_k / (32 / 2); + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } - // Initial zeros/scale - int zeros[4]; - half2 scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; - __syncthreads(); + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; - int k = offset_k; - int lk = 0; + // b offset + int qk = offset_k / (32 / 2); - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - } + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - for (int p = 0; p < 2; p++) - { - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - half2 dq[4][8]; - dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); - dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); - dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); - dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); - - b_ptr += size_n; - //half* dqh = (half*)dq; - if (b_q_perm) - { - for (int j = 0; j < 8; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 8; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - } - k += 32; - } -} + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); -void reconstruct_exllama -( - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_q_perm, - half* out, - int height, - int width, - int groups, - int bit -) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); - gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + __syncthreads(); - auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel; - if (bit == 2) { - reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel; - } else if (bit == 3) { - reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel; - } else if (bit == 8) { - reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel; + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - reconstruct_exllama_kernel<<>> - ( - b_q_weight, - b_q_perm, - b_gptq_qzeros, - b_gptq_scales, - height, - width, - groups, - out - ); + for (int p = 0; p < 2; p++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][8]; + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + + b_ptr += size_n; + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 8; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 8; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } } +void reconstruct_exllama(const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_q_perm, + half* out, int height, int width, int groups, + int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + + auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel; + if (bit == 2) { + reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel; + } else if (bit == 3) { + reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel; + } else if (bit == 8) { + reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + reconstruct_exllama_kernel<<>>( + b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups, + out); +} __global__ void gemm_half_q_half_alt_4bit_kernel( - const half2* __restrict__ vec, - const uint32_t* __restrict__ mat, - half* __restrict__ mul, - const half* __restrict__ scales, - const uint32_t* __restrict__ zeros, - const int* __restrict__ g_idx, - int batch, - int height, - int width -) -{ - int zero_width = width / 8; - int vec_height = height * 4; - const int blockwidth2 = BLOCK_KN_SIZE / 2; - int b = blockIdx.y * BLOCK_M_SIZE_MAX; - int b_end = min(BLOCK_M_SIZE_MAX, batch - b); - int h = BLOCK_KN_SIZE * blockIdx.z / 8; - int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; - int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - - __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; - if (threadIdx.x < h_end) { - for (int m = 0; m < b_end; ++m) { - blockvec[m][threadIdx.x] = - vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + - threadIdx.x]; - } + const half2* __restrict__ vec, const uint32_t* __restrict__ mat, + half* __restrict__ mul, const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, + int batch, int height, int width) { + int zero_width = width / 8; + int vec_height = height * 4; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + int h = BLOCK_KN_SIZE * blockIdx.z / 8; + int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; + int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = + vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + + threadIdx.x]; } - - __shared__ half2 deq2[256][8]; - int val = threadIdx.x / 8; - int off = threadIdx.x % 8; - for (; val < 256; val += BLOCK_KN_SIZE / 8) { - deq2[val][off] = __halves2half2( - __int2half_rn(val & 0xF), __int2half_rn(val >> 4) - ); + } + + __shared__ half2 deq2[256][8]; + int val = threadIdx.x / 8; + int off = threadIdx.x % 8; + for (; val < 256; val += BLOCK_KN_SIZE / 8) { + deq2[val][off] = + __halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4)); + } + + if (blockIdx.z == 0) { + for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0); + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 8; + int k = 0; + int z_w = w / 8; + int z_mod = (w % 8) * 4; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[4]; + half2 zeros_tmp[4]; + for (int tmp_k = 0; tmp_k < 4; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, + __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - + 1)), + __hmul(scale_f2, + __int2half_rn( + -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; } - - if (blockIdx.z == 0) - { - for (int m = 0; m < b_end; m++) - mul[(b + m) * width + w] = __int2half_rn(0); - } - __syncthreads(); - - int i = width * h + w; - int g_h = h * 8; - int k = 0; - int z_w = w / 8; - int z_mod = (w % 8) * 4; - half2 res2; - half res[BLOCK_M_SIZE_MAX] = {}; - - unsigned int tmp; - while (k < h_end) { - tmp = mat[i]; - half2 scales_tmp[4]; - half2 zeros_tmp[4]; - for (int tmp_k = 0; tmp_k < 4; tmp_k++) { - int g = g_idx[g_h + (k + tmp_k) * 2]; - int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; - half scale_f = scales[g * width + w]; - half scale_f2 = scales[g2 * width + w]; - half2 scale = __halves2half2(scale_f, scale_f2); - half2 zero = __halves2half2( - __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)), - __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1)) - ); - scales_tmp[tmp_k] = scale; - zeros_tmp[tmp_k] = zero; - } - for (int m = 0; m < b_end; m++) { + for (int m = 0; m < b_end; m++) { #ifndef USE_ROCM - res2 = {}; + res2 = {}; #else - res2.x = __half_as_ushort(__float2half(0)); - res2.y = __half_as_ushort(__float2half(0)); + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); #endif - res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), + blockvec[m][k + 0], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), + blockvec[m][k + 1], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), + blockvec[m][k + 2], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), + blockvec[m][k + 3], res2); #ifndef USE_ROCM - res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); #else - res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); + res[m] = __hadd( + res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); #endif - } - i += width; - k += 4; - } - for (int m = 0; m < b_end; m++) { - atomicAdd(&mul[(b + m) * width + w], res[m]); } + i += width; + k += 4; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } } - __global__ void gemm_half_q_half_alt_8bit_kernel( - const half2* __restrict__ vec, - const uint32_t* __restrict__ mat, - half* __restrict__ mul, - const half* __restrict__ scales, - const uint32_t* __restrict__ zeros, - const int* __restrict__ g_idx, - int batch, - int height, - int width -) -{ - int zero_width = width / 4; - int vec_height = height * 2; - const int blockwidth2 = BLOCK_KN_SIZE / 2; - int b = blockIdx.y * BLOCK_M_SIZE_MAX; - int b_end = min(BLOCK_M_SIZE_MAX, batch - b); - int h = BLOCK_KN_SIZE * blockIdx.z / 4; - int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2; - int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - - __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; - if (threadIdx.x < h_end) { - for (int m = 0; m < b_end; ++m) { - blockvec[m][threadIdx.x] = - vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + - threadIdx.x]; - } + const half2* __restrict__ vec, const uint32_t* __restrict__ mat, + half* __restrict__ mul, const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, + int batch, int height, int width) { + int zero_width = width / 4; + int vec_height = height * 2; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + int h = BLOCK_KN_SIZE * blockIdx.z / 4; + int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2; + int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = + vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + + threadIdx.x]; } - - - if (blockIdx.z == 0) - { - for (int m = 0; m < b_end; m++) - mul[(b + m) * width + w] = __int2half_rn(0); + } + + if (blockIdx.z == 0) { + for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0); + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 4; + int k = 0; + int z_w = w / 4; + int z_mod = (w % 4) * 8; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[2]; + half2 zeros_tmp[2]; + for (int tmp_k = 0; tmp_k < 2; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, + __int2half_rn( + -((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)), + __hmul(scale_f2, + __int2half_rn( + -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1))); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; } - __syncthreads(); - - int i = width * h + w; - int g_h = h * 4; - int k = 0; - int z_w = w / 4; - int z_mod = (w % 4) * 8; - half2 res2; - half res[BLOCK_M_SIZE_MAX] = {}; - - unsigned int tmp; - while (k < h_end) { - tmp = mat[i]; - half2 scales_tmp[2]; - half2 zeros_tmp[2]; - for (int tmp_k = 0; tmp_k < 2; tmp_k++) { - int g = g_idx[g_h + (k + tmp_k) * 2]; - int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; - half scale_f = scales[g * width + w]; - half scale_f2 = scales[g2 * width + w]; - half2 scale = __halves2half2(scale_f, scale_f2); - half2 zero = __halves2half2( - __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)), - __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1)) - ); - scales_tmp[tmp_k] = scale; - zeros_tmp[tmp_k] = zero; - } - for (int m = 0; m < b_end; m++) { + for (int m = 0; m < b_end; m++) { #ifndef USE_ROCM - res2 = {}; + res2 = {}; #else - res2.x = __half_as_ushort(__float2half(0)); - res2.y = __half_as_ushort(__float2half(0)); + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); #endif - half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF), __int2half_rn((tmp >> 8) & 0xFF)); - res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); - half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF), __int2half_rn((tmp >> 24) & 0xFF)); - res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); + half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF), + __int2half_rn((tmp >> 8) & 0xFF)); + res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]), + blockvec[m][k + 0], res2); + half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF), + __int2half_rn((tmp >> 24) & 0xFF)); + res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]), + blockvec[m][k + 1], res2); #ifndef USE_ROCM - res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); #else - res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); + res[m] = __hadd( + res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); #endif - } - i += width; - k += 2; - } - for (int m = 0; m < b_end; m++) { - atomicAdd(&mul[(b + m) * width + w], res[m]); } + i += width; + k += 2; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } } -void gemm_half_q_half_alt -( - const half* a, - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_g_idx, - half* c, - int size_m, - int size_n, - int size_k, - int bit -) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - blockDim.z = 1; - gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE); - gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); - gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); - - auto kernel = gemm_half_q_half_alt_4bit_kernel; - if (bit == 8) { - kernel = gemm_half_q_half_alt_8bit_kernel; - } - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>> - ( - (const half2*) a, - b_q_weight, - c, - b_gptq_scales, - b_gptq_qzeros, - b_g_idx, - size_m, - size_k / 32 * bit, - size_n - ); +void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_g_idx, + half* c, int size_m, int size_n, int size_k, + int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE); + gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + auto kernel = gemm_half_q_half_alt_4bit_kernel; + if (bit == 8) { + kernel = gemm_half_q_half_alt_8bit_kernel; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>( + (const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx, + size_m, size_k / 32 * bit, size_n); } -template -__global__ void reconstruct_gptq_kernel -( - const uint32_t* __restrict__ w, - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int* __restrict__ g_idx, - const int height, - const int width, - const int group, - half* __restrict__ out -) -{ - // Start of block - - int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - int row = blockIdx.y * 32 / bit; - if (column >= width) return; - - // Views - - MatrixView_half_rw out_(out, height, width); - MatrixView_half w_scales_(w_scales, group, width); - T w_zeros_(w_zeros, group, width); - - uint32_t w_read = w[blockIdx.y * width + column]; - half* out_ptr = out_.item_ptr(row, column); - - #pragma unroll - for (int s = 0; s < 32; s += bit) - { - int group = g_idx[row + s / bit]; - half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; - half w_item = __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), w_scale); - *out_ptr = w_item; out_ptr += out_.width; - } +template +__global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int* __restrict__ g_idx, + const int height, const int width, + const int group, + half* __restrict__ out) { + // Start of block + + int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + int row = blockIdx.y * 32 / bit; + if (column >= width) return; + + // Views + + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + T w_zeros_(w_zeros, group, width); + + uint32_t w_read = w[blockIdx.y * width + column]; + half* out_ptr = out_.item_ptr(row, column); + +#pragma unroll + for (int s = 0; s < 32; s += bit) { + int group = g_idx[row + s / bit]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + half w_item = + __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), + w_scale); + *out_ptr = w_item; + out_ptr += out_.width; + } } -__global__ void reconstruct_gptq_3bit_kernel -( - const uint32_t* __restrict__ w, - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int* __restrict__ g_idx, - const int height, - const int width, - const int group, - half* __restrict__ out -) -{ - // Start of block - int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - int row = blockIdx.y * 32; - if (column >= width) return; - - // Views - - MatrixView_half_rw out_(out, height, width); - MatrixView_half w_scales_(w_scales, group, width); - MatrixView_q3_row w_zeros_(w_zeros, group, width); - - uint32_t w1 = w[(blockIdx.y * 3) * width + column]; - uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column]; - uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; - half* out_ptr = out_.item_ptr(row, column); - - #pragma unroll - for (int i = 0; i < 32; i += 1) - { - int group = g_idx[row + i]; - half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; - int w_item; - if (i == 10) { - w_item = (w1 >> 30) | ((w2 << 2) & 0x4); - } else if (i == 21) { - w_item = (w2 >> 31) | ((w3 << 1) & 0x6); - } else if (i < 10) { - w_item = ((w1 >> (i * 3)) & 0x7); - } else if (i < 21) { - w_item = ((w2 >> (i * 3 - 32)) & 0x7); - } else { - w_item = ((w3 >> (i * 3 - 64)) & 0x7); - } - *out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale); - out_ptr += out_.width; +__global__ void reconstruct_gptq_3bit_kernel( + const uint32_t* __restrict__ w, const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx, + const int height, const int width, const int group, + half* __restrict__ out) { + // Start of block + int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + int row = blockIdx.y * 32; + if (column >= width) return; + + // Views + + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + MatrixView_q3_row w_zeros_(w_zeros, group, width); + + uint32_t w1 = w[(blockIdx.y * 3) * width + column]; + uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column]; + uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; + half* out_ptr = out_.item_ptr(row, column); + +#pragma unroll + for (int i = 0; i < 32; i += 1) { + int group = g_idx[row + i]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + int w_item; + if (i == 10) { + w_item = (w1 >> 30) | ((w2 << 2) & 0x4); + } else if (i == 21) { + w_item = (w2 >> 31) | ((w3 << 1) & 0x6); + } else if (i < 10) { + w_item = ((w1 >> (i * 3)) & 0x7); + } else if (i < 21) { + w_item = ((w2 >> (i * 3 - 32)) & 0x7); + } else { + w_item = ((w3 >> (i * 3 - 64)) & 0x7); } + *out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale); + out_ptr += out_.width; + } } -void reconstruct_gptq -( - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_g_idx, - half* out, - int height, - int width, - int groups, - int bit -) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - gridDim.y = DIVIDE(height, 32 / bit); - gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); - - auto kernel = reconstruct_gptq_kernel; - if (bit == 2) { - kernel = reconstruct_gptq_kernel; - } else if (bit == 8) { - kernel = reconstruct_gptq_kernel; - } else if (bit == 3) { - kernel = reconstruct_gptq_3bit_kernel; - gridDim.y = DIVIDE(height, 32); - } - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>> - ( - b_q_weight, - b_gptq_scales, - b_gptq_qzeros, - b_g_idx, - height, - width, - groups, - out - ); +void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_g_idx, half* out, + int height, int width, int groups, int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, 32 / bit); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + + auto kernel = reconstruct_gptq_kernel; + if (bit == 2) { + kernel = reconstruct_gptq_kernel; + } else if (bit == 8) { + kernel = reconstruct_gptq_kernel; + } else if (bit == 3) { + kernel = reconstruct_gptq_3bit_kernel; + gridDim.y = DIVIDE(height, 32); + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>(b_q_weight, b_gptq_scales, + b_gptq_qzeros, b_g_idx, height, + width, groups, out); } - -void gemm_half_q_half_cuda -( - cublasHandle_t cublas_handle, - const half* a, - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_g_idx, - half* c, - half* temp_dq, - int size_m, - int size_n, - int size_k, - int groups, - bool use_exllama, - int bit -) -{ - bool use_reconstruct; +void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_g_idx, + half* c, half* temp_dq, int size_m, int size_n, + int size_k, int groups, bool use_exllama, int bit) { + bool use_reconstruct; + if (use_exllama) { + use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || + (bit != 8 && size_m > MAX_Q_GEMM_ROWS)); + } else { + // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so + // we disabled them for now. + use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS); + } + if (use_reconstruct) { + // Reconstruct FP16 matrix, then cuBLAS if (use_exllama) { - use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || (bit != 8 && size_m > MAX_Q_GEMM_ROWS)); + reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups, bit); } else { - // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so we disabled them for now. - use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS); + reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups, bit); } - if (use_reconstruct) { - // Reconstruct FP16 matrix, then cuBLAS - if (use_exllama) { - reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, - size_k, size_n, groups, bit); - } - else - { - reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - temp_dq, size_k, size_n, groups, bit); - } - const half alpha = __float2half(1.0f); - const half beta = __float2half(0.0f); - cublasHgemm(cublas_handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - size_n, size_m, size_k, - &alpha, temp_dq, size_n, - a, size_k, - &beta, c, size_n); + const half alpha = __float2half(1.0f); + const half beta = __float2half(0.0f); + cublasHgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, size_n, size_m, size_k, + &alpha, temp_dq, size_n, a, size_k, &beta, c, size_n); + } else if (use_exllama) { + // Quantized matmul + int max_chunks = size_m / BLOCK_M_SIZE_MAX; + int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; + int last_chunk_size = size_m - last_chunk; + + if (max_chunks) { + gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, c, last_chunk, size_n, size_k, + BLOCK_M_SIZE_MAX, groups, bit); } - else if (use_exllama) - { - // Quantized matmul - int max_chunks = size_m / BLOCK_M_SIZE_MAX; - int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; - int last_chunk_size = size_m - last_chunk; - - if (max_chunks) - { - gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, - groups, bit); - } - if (last_chunk_size) - { - gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros, - b_gptq_scales, b_g_idx, c + last_chunk * size_n, - last_chunk_size, size_n, size_k, last_chunk_size, - groups, bit); - } - } - else - { - gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - c, size_m, size_n, size_k, bit); + if (last_chunk_size) { + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, + b_gptq_qzeros, b_gptq_scales, b_g_idx, + c + last_chunk * size_n, last_chunk_size, + size_n, size_k, last_chunk_size, groups, bit); } + } else { + gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + c, size_m, size_n, size_k, bit); + } } -__global__ void shuffle_4bit_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; } +__global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_4bit_8(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 8; + } } -__global__ void shuffle_8bit_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < size_k) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; } +__global__ void shuffle_8bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_8bit_4(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 4; + } } -__global__ void shuffle_2bit_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < size_k) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; } +__global__ void shuffle_2bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_2bit_16(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 16; + } } -__global__ void shuffle_3bit_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < size_k) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; } +__global__ void shuffle_3bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_3bit_32(b_ptr, size_n); + b_ptr += 3 * size_n; + k += 32; + } } -__global__ void make_sequential_4bit_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - int w2_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - int w_new2_row = blockIdx.y; - int q_perm_idx = w_new2_row << 3; - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 8; i++) - { - int source_row = q_perm[q_perm_idx++]; - - int w2_row = source_row >> 3; - int w2_subrow = source_row & 0x07; - int w2_row_shift = w2_subrow << 2; - int wnew2_row_shift = i << 2; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x0000000f0000000f; - src <<= wnew2_row_shift; - dst |= src; - } - w_new2[w_new2_row * w2_stride + w2_column] = dst; +__global__ void make_sequential_4bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 3; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 8; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; } -__global__ void make_sequential_2bit_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - int w2_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - int w_new2_row = blockIdx.y; - int q_perm_idx = w_new2_row << 4; - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 16; i++) - { - int source_row = q_perm[q_perm_idx++]; - - int w2_row = source_row >> 4; - int w2_subrow = source_row & 0x0f; - int w2_row_shift = w2_subrow << 1; - int wnew2_row_shift = i << 1; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x0000000300000003; - src <<= wnew2_row_shift; - dst |= src; - } - w_new2[w_new2_row * w2_stride + w2_column] = dst; +__global__ void make_sequential_2bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 4; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 16; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 4; + int w2_subrow = source_row & 0x0f; + int w2_row_shift = w2_subrow << 1; + int wnew2_row_shift = i << 1; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000300000003; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; } -__global__ void make_sequential_3bit_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, - const int w_width -) -{ - int w_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w_column >= w_width) return; - int w_new_row = blockIdx.y * 3; - int q_perm_idx = blockIdx.y << 5; - uint32_t dst[3] = {0, 0, 0}; - - #pragma unroll - for (int i = 0; i < 32; i++) - { - int source_row = q_perm[q_perm_idx++]; - int z_w = (source_row / 32) * 3; - int z_mod = source_row % 32; - int z_bit; - - if (z_mod != 10){ - if (z_mod != 21){ - z_bit = z_mod; - if (z_bit > 21){ - z_bit *= 3; - z_bit -= 64; - z_w += 2; - } else if (z_bit > 10){ - z_bit *= 3; - z_bit -= 32; - z_w += 1; - } else { - z_bit *= 3; - } - } else { - z_w += 1; - } - } - - uint64_t src; - if (z_mod == 10) { - src = (w[z_w * w_width + w_column] >> 30) | ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4); - } else if (z_mod == 21){ - src = (w[z_w * w_width + w_column] >> 31) | ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6); +__global__ void make_sequential_3bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + int w_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w_column >= w_width) return; + int w_new_row = blockIdx.y * 3; + int q_perm_idx = blockIdx.y << 5; + uint32_t dst[3] = {0, 0, 0}; + +#pragma unroll + for (int i = 0; i < 32; i++) { + int source_row = q_perm[q_perm_idx++]; + int z_w = (source_row / 32) * 3; + int z_mod = source_row % 32; + int z_bit; + + if (z_mod != 10) { + if (z_mod != 21) { + z_bit = z_mod; + if (z_bit > 21) { + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10) { + z_bit *= 3; + z_bit -= 32; + z_w += 1; } else { - src = w[z_w * w_width + w_column]; - src >>= z_bit; - src &= 0x07; + z_bit *= 3; } + } else { + z_w += 1; + } + } - z_w = 0; - if (i != 10){ - if (i != 21){ - z_bit = i; - if (z_bit > 21){ - z_bit *= 3; - z_bit -= 64; - z_w += 2; - } else if (z_bit > 10){ - z_bit *= 3; - z_bit -= 32; - z_w += 1; - } else { - z_bit *= 3; - } - } else { - z_w += 1; - } - } - if (i == 10) { - dst[z_w] |= (src & 0x03) << 30; - dst[z_w + 1] |= ((src & 0x4) >> 2); - } else if (i == 21) { - dst[z_w] |= (src & 0x01) << 31; - dst[z_w + 1] |= ((src & 0x6) >> 1); + uint64_t src; + if (z_mod == 10) { + src = (w[z_w * w_width + w_column] >> 30) | + ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4); + } else if (z_mod == 21) { + src = (w[z_w * w_width + w_column] >> 31) | + ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6); + } else { + src = w[z_w * w_width + w_column]; + src >>= z_bit; + src &= 0x07; + } + + z_w = 0; + if (i != 10) { + if (i != 21) { + z_bit = i; + if (z_bit > 21) { + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10) { + z_bit *= 3; + z_bit -= 32; + z_w += 1; } else { - dst[z_w] |= (src << z_bit); + z_bit *= 3; } + } else { + z_w += 1; + } + } + if (i == 10) { + dst[z_w] |= (src & 0x03) << 30; + dst[z_w + 1] |= ((src & 0x4) >> 2); + } else if (i == 21) { + dst[z_w] |= (src & 0x01) << 31; + dst[z_w + 1] |= ((src & 0x6) >> 1); + } else { + dst[z_w] |= (src << z_bit); } - w_new[w_new_row * w_width + w_column] = dst[0]; - w_new[(w_new_row + 1) * w_width + w_column] = dst[1]; - w_new[(w_new_row + 2) * w_width + w_column] = dst[2]; + } + w_new[w_new_row * w_width + w_column] = dst[0]; + w_new[(w_new_row + 1) * w_width + w_column] = dst[1]; + w_new[(w_new_row + 2) * w_width + w_column] = dst[2]; } -__global__ void make_sequential_8bit_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - int w2_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - int w_new2_row = blockIdx.y; - int q_perm_idx = w_new2_row << 2; - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 4; i++) - { - int source_row = q_perm[q_perm_idx++]; - - int w2_row = source_row >> 2; - int w2_subrow = source_row & 0x03; - int w2_row_shift = w2_subrow << 3; - int wnew2_row_shift = i << 3; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x000000ff000000ff; - src <<= wnew2_row_shift; - dst |= src; - } - w_new2[w_new2_row * w2_stride + w2_column] = dst; +__global__ void make_sequential_8bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 2; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 4; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 2; + int w2_subrow = source_row & 0x03; + int w2_row_shift = w2_subrow << 3; + int wnew2_row_shift = i << 3; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x000000ff000000ff; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; } +void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height, + int width, int bit) { + if (q_perm) { + uint32_t* new_qweight = NULL; + cudaMalloc(&new_qweight, height / 32 * bit * width * sizeof(uint32_t)); -void shuffle_exllama_weight -( - uint32_t* q_weight, - int* q_perm, - int height, - int width, - int bit -) -{ - if (q_perm) - { - uint32_t* new_qweight = NULL; - cudaMalloc(&new_qweight, height / 32 * bit * width * sizeof(uint32_t)); - - dim3 blockDim, gridDim; - blockDim.x = THREADS_X; - blockDim.y = 1; - gridDim.x = DIVIDE(width, THREADS_X); - gridDim.y = height / 32 * bit; - - auto kernel = make_sequential_4bit_kernel; - if (bit == 2) { - kernel = make_sequential_2bit_kernel; - } else if (bit == 3) { - kernel = make_sequential_3bit_kernel; - gridDim.y = height / 32; - } else if (bit == 8) { - kernel = make_sequential_8bit_kernel; - } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>> - ( - q_weight, - new_qweight, - q_perm, - width - ); - // Replace qweights - cudaMemcpyAsync(q_weight, new_qweight, height / 32 * bit * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); - // Cleanup - cudaDeviceSynchronize(); - cudaFree(new_qweight); - } dim3 blockDim, gridDim; blockDim.x = THREADS_X; blockDim.y = 1; gridDim.x = DIVIDE(width, THREADS_X); - gridDim.y = 1; - auto shuffle_kernel = shuffle_4bit_kernel; + gridDim.y = height / 32 * bit; + + auto kernel = make_sequential_4bit_kernel; if (bit == 2) { - shuffle_kernel = shuffle_2bit_kernel; + kernel = make_sequential_2bit_kernel; } else if (bit == 3) { - shuffle_kernel = shuffle_3bit_kernel; + kernel = make_sequential_3bit_kernel; + gridDim.y = height / 32; } else if (bit == 8) { - shuffle_kernel = shuffle_8bit_kernel; + kernel = make_sequential_8bit_kernel; } const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - shuffle_kernel<<>>(q_weight, height, width); + kernel<<>>(q_weight, new_qweight, q_perm, + width); + // Replace qweights + cudaMemcpyAsync(q_weight, new_qweight, + height / 32 * bit * width * sizeof(uint32_t), + cudaMemcpyDeviceToDevice); + // Cleanup + cudaDeviceSynchronize(); + cudaFree(new_qweight); + } + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = 1; + auto shuffle_kernel = shuffle_4bit_kernel; + if (bit == 2) { + shuffle_kernel = shuffle_2bit_kernel; + } else if (bit == 3) { + shuffle_kernel = shuffle_3bit_kernel; + } else if (bit == 8) { + shuffle_kernel = shuffle_8bit_kernel; + } + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + shuffle_kernel<<>>(q_weight, height, width); } } // namespace gptq } // namespace vllm -torch::Tensor gptq_gemm -( - torch::Tensor a, - torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, - torch::Tensor b_g_idx, - bool use_exllama, - int bit -) -{ - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); - at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options); - - vllm::gptq::gemm_half_q_half_cuda - ( - at::cuda::getCurrentCUDABlasHandle(), - (const half*) a.data_ptr(), - (const uint32_t*) b_q_weight.data_ptr(), - (const uint32_t*)b_gptq_qzeros.data_ptr(), - (const half*) b_gptq_scales.data_ptr(), - b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(), - (half*) c.data_ptr(), - (half*) temp_dq.data_ptr(), - c.size(0), // m - c.size(1), // n - a.size(1), // k - b_gptq_qzeros.size(0), // group number - use_exllama, - bit - ); - return c; +torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, + bool use_exllama, int bit) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); + at::Tensor temp_dq = torch::empty( + {b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options); + + vllm::gptq::gemm_half_q_half_cuda( + at::cuda::getCurrentCUDABlasHandle(), (const half*)a.data_ptr(), + (const uint32_t*)b_q_weight.data_ptr(), + (const uint32_t*)b_gptq_qzeros.data_ptr(), + (const half*)b_gptq_scales.data_ptr(), + b_g_idx.device().is_meta() ? NULL : (const int*)b_g_idx.data_ptr(), + (half*)c.data_ptr(), (half*)temp_dq.data_ptr(), + c.size(0), // m + c.size(1), // n + a.size(1), // k + b_gptq_qzeros.size(0), // group number + use_exllama, bit); + return c; } -void gptq_shuffle -( - torch::Tensor q_weight, - torch::Tensor q_perm, - int bit -) -{ - const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); - vllm::gptq::shuffle_exllama_weight( - (uint32_t*) q_weight.data_ptr(), - q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(), - q_weight.size(0) * 32 / bit, - q_weight.size(1), - bit - ); +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); + vllm::gptq::shuffle_exllama_weight( + (uint32_t*)q_weight.data_ptr(), + q_perm.device().is_meta() || q_perm.numel() == 0 + ? NULL + : (int*)q_perm.data_ptr(), + q_weight.size(0) * 32 / bit, q_weight.size(1), bit); } diff --git a/csrc/quantization/gptq/qdq_2.cuh b/csrc/quantization/gptq/qdq_2.cuh index 295872a91de37..ca0f810608d1b 100644 --- a/csrc/quantization/gptq/qdq_2.cuh +++ b/csrc/quantization/gptq/qdq_2.cuh @@ -14,71 +14,60 @@ namespace gptq { // // ffddbb99 77553311 eeccaa88 66442200 -__forceinline__ __device__ void shuffle_2bit_16 -( - uint32_t* q, - int stride -) -{ - uint32_t qa = q[0]; - uint32_t qb = 0; +__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; - #pragma unroll - for (int i = 0; i < 8; i++) - { - uint32_t qa0 = qa & 0x03; - uint32_t qa1 = (qa & 0x0c) >> 2; - qa >>= 4; - qb |= (qa1 << (i * 2 + 16)); - qb |= (qa0 << (i * 2)); - } - q[0] = qb; +#pragma unroll + for (int i = 0; i < 8; i++) { + uint32_t qa0 = qa & 0x03; + uint32_t qa1 = (qa & 0x0c) >> 2; + qa >>= 4; + qb |= (qa1 << (i * 2 + 16)); + qb |= (qa0 << (i * 2)); + } + q[0] = qb; } -__forceinline__ __device__ void dequant_2bit_16 -( - const uint32_t q_0, - half2 (&dq)[8], - int stride, - const uint32_t zero -) -{ - const uint32_t c0 = 0x64006400; - const half y4_ = __float2half_rn(1.0f / 4.0f); - const half y16_ = __float2half_rn(1.0f / 16.0f); - const half y64_ = __float2half_rn(1.0f / 64.0f); - const half2 y4 = __halves2half2(y4_, y4_); - const half2 y16 = __halves2half2(y16_, y16_); - const half2 y64 = __halves2half2(y64_, y64_); +__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0, + half2 (&dq)[8], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y4_ = __float2half_rn(1.0f / 4.0f); + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y4 = __halves2half2(y4_, y4_); + const half2 y16 = __halves2half2(y16_, y16_); + const half2 y64 = __halves2half2(y64_, y64_); - const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); - const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero)); - const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); - const half2 z1 = __half2half2(z1_.as_half); - const half2 z4 = __half2half2(z4_); - const half2 z16 = __half2half2(z16_); - const half2 z64 = __half2half2(z64_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero)); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z4 = __half2half2(z4_); + const half2 z16 = __half2half2(z16_); + const half2 z64 = __half2half2(z64_); - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 - half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 - half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 - qa >>= 8; - half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 - half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 - half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 - half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 + half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 + half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 + qa >>= 8; + half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 + half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 + half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 + half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 - dq[0] = __hadd2(q0.as_half2, z1); - dq[1] = __hfma2(q1.as_half2, y4, z4); - dq[2] = __hfma2(q2.as_half2, y16, z16); - dq[3] = __hfma2(q3.as_half2, y64, z64); - dq[4] = __hadd2(q4.as_half2, z1); - dq[5] = __hfma2(q5.as_half2, y4, z4); - dq[6] = __hfma2(q6.as_half2, y16, z16); - dq[7] = __hfma2(q7.as_half2, y64, z64); + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y4, z4); + dq[2] = __hfma2(q2.as_half2, y16, z16); + dq[3] = __hfma2(q3.as_half2, y64, z64); + dq[4] = __hadd2(q4.as_half2, z1); + dq[5] = __hfma2(q5.as_half2, y4, z4); + dq[6] = __hfma2(q6.as_half2, y16, z16); + dq[7] = __hfma2(q7.as_half2, y64, z64); } } // namespace gptq diff --git a/csrc/quantization/gptq/qdq_3.cuh b/csrc/quantization/gptq/qdq_3.cuh index 3e7ecde752ba3..0d5c2adf5dbbe 100644 --- a/csrc/quantization/gptq/qdq_3.cuh +++ b/csrc/quantization/gptq/qdq_3.cuh @@ -11,128 +11,136 @@ namespace gptq { // vjjjhhhf ffdddbbb uiiiggge eecccaaa // vtttrrrp ppnnnlll usssqqqo oommmkkk -__forceinline__ __device__ void shuffle_3bit_32 -( - uint32_t* q, - int stride -) -{ - uint32_t qa = q[0 * stride]; - uint32_t qb = q[1 * stride]; - uint32_t qc = q[2 * stride]; - - // qa: aa999888 77766655 54443332 22111000 - // qb: lkkkjjji iihhhggg fffeeedd dcccbbba - // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll - - uint32_t qd = qc >> 26; - qc <<= 4; - qc |= qb >> 28; - qb <<= 2; - qb |= qa >> 30; - - // qa: ..999888 77766655 54443332 22111000 - // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa - // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk - // qd: vvvuuu - - uint32_t za = 0; - uint32_t zb = 0; - uint32_t zc = 0; - - for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); } - for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); } - for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); } - - // za: 9997775 55333111 8886664 44222000 - // zb: jjjhhhf ffdddbbb iiiggge eecccaaa - // zc: tttrrrp ppnnnlll sssqqqo oommmkkk - // qd: vvvuuu - - za |= ((qd & 0x01) >> 0) << 15; - zb |= ((qd & 0x02) >> 1) << 15; - zc |= ((qd & 0x04) >> 2) << 15; - za |= ((qd & 0x08) >> 3) << 31; - zb |= ((qd & 0x10) >> 4) << 31; - zc |= ((qd & 0x20) >> 5) << 31; - - // za: v9997775 55333111 u8886664 44222000 (u, v lsb) - // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa - // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk - - q[0 * stride] = za; - q[1 * stride] = zb; - q[2 * stride] = zc; -} - -__forceinline__ __device__ void dequant_3bit_32 -( - const uint32_t q_0, - const uint32_t q_1, - const uint32_t q_2, - half2 (&dq)[16], - int stride, - const uint32_t zero -) -{ - const uint32_t c0 = 0x64006400; - const half y8_ = __float2half_rn(1.0f / 8.0f); - const half y64_ = __float2half_rn(1.0f / 64.0f); - const half2 y8 = __halves2half2(y8_, y8_); - const half2 y64 = __halves2half2(y64_, y64_); - const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); - const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero)); - const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); - const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half); - const half2 z8 = __halves2half2(z8_, z8_); - const half2 z64 = __halves2half2(z64_, z64_); - - uint32_t qa = q_0; - uint32_t qb = q_1; - uint32_t qc = q_2; - - half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 +__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) { + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + + // qa: aa999888 77766655 54443332 22111000 + // qb: lkkkjjji iihhhggg fffeeedd dcccbbba + // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll + + uint32_t qd = qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: ..999888 77766655 54443332 22111000 + // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa + // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk + // qd: vvvuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + + for (int i = 0; i < 5; i++) { + uint32_t t0 = qa & 0x07; + uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; - half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 - half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 - half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 - qa >>= 9; - qa &= 0x00010001; - half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 - half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + za |= (t0 << (i * 3)); + za |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qb & 0x07; + uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; - half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 - half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 - half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 - qb >>= 8; - qb &= 0x00020002; - half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 - half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + zb |= (t0 << (i * 3)); + zb |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qc & 0x07; + uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; - half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 - half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 - half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 - qc >>= 7; - qc &= 0x00040004; - half2_uint32 q15((qa | qb | qc) | c0); - - dq[ 0] = __hadd2( q0.as_half2, z1); - dq[ 1] = __hfma2( q1.as_half2, y8, z8); - dq[ 2] = __hadd2( q2.as_half2, z1); - dq[ 3] = __hfma2( q3.as_half2, y8, z8); - dq[ 4] = __hfma2( q4.as_half2, y64, z64); - dq[ 5] = __hadd2( q5.as_half2, z1); - dq[ 6] = __hfma2( q6.as_half2, y8, z8); - dq[ 7] = __hadd2( q7.as_half2, z1); - dq[ 8] = __hfma2( q8.as_half2, y8, z8); - dq[ 9] = __hfma2( q9.as_half2, y64, z64); - dq[10] = __hadd2(q10.as_half2, z1); - dq[11] = __hfma2(q11.as_half2, y8, z8); - dq[12] = __hadd2(q12.as_half2, z1); - dq[13] = __hfma2(q13.as_half2, y8, z8); - dq[14] = __hfma2(q14.as_half2, y64, z64); - dq[15] = __hadd2(q15.as_half2, z1); + zc |= (t0 << (i * 3)); + zc |= (t1 << (i * 3 + 16)); + } + + // za: 9997775 55333111 8886664 44222000 + // zb: jjjhhhf ffdddbbb iiiggge eecccaaa + // zc: tttrrrp ppnnnlll sssqqqo oommmkkk + // qd: vvvuuu + + za |= ((qd & 0x01) >> 0) << 15; + zb |= ((qd & 0x02) >> 1) << 15; + zc |= ((qd & 0x04) >> 2) << 15; + za |= ((qd & 0x08) >> 3) << 31; + zb |= ((qd & 0x10) >> 4) << 31; + zc |= ((qd & 0x20) >> 5) << 31; + + // za: v9997775 55333111 u8886664 44222000 (u, v lsb) + // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa + // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; +} + +__forceinline__ __device__ void dequant_3bit_32(const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[16], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y8_ = __float2half_rn(1.0f / 8.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y8 = __halves2half2(y8_, y8_); + const half2 y64 = __halves2half2(y64_, y64_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half); + const half2 z8 = __halves2half2(z8_, z8_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + + half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 + qa >>= 6; + half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 + half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 + qa >>= 9; + qa &= 0x00010001; + half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 + half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + qb >>= 6; + half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 + half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 + half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 + qb >>= 8; + qb &= 0x00020002; + half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 + half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + qc >>= 6; + half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 + half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 + qc >>= 7; + qc &= 0x00040004; + half2_uint32 q15((qa | qb | qc) | c0); + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y8, z8); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y8, z8); + dq[4] = __hfma2(q4.as_half2, y64, z64); + dq[5] = __hadd2(q5.as_half2, z1); + dq[6] = __hfma2(q6.as_half2, y8, z8); + dq[7] = __hadd2(q7.as_half2, z1); + dq[8] = __hfma2(q8.as_half2, y8, z8); + dq[9] = __hfma2(q9.as_half2, y64, z64); + dq[10] = __hadd2(q10.as_half2, z1); + dq[11] = __hfma2(q11.as_half2, y8, z8); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y8, z8); + dq[14] = __hfma2(q14.as_half2, y64, z64); + dq[15] = __hadd2(q15.as_half2, z1); } } // namespace gptq diff --git a/csrc/quantization/gptq/qdq_4.cuh b/csrc/quantization/gptq/qdq_4.cuh index 881f353f6564d..7f65d2d2819b1 100644 --- a/csrc/quantization/gptq/qdq_4.cuh +++ b/csrc/quantization/gptq/qdq_4.cuh @@ -13,133 +13,112 @@ namespace gptq { // // 77775555 33331111 66664444 22220000 -__forceinline__ __device__ void shuffle_4bit_8 -( - uint32_t* q, - int stride -) -{ - uint32_t qa = q[0]; - uint32_t qb = 0; - - #pragma unroll - for (int i = 0; i < 4; i++) - { - uint32_t qa0 = qa & 0x0f; - uint32_t qa1 = (qa & 0xf0) >> 4; - qa >>= 8; - qb |= (qa1 << (i * 4 + 16)); - qb |= (qa0 << (i * 4)); - } - q[0] = qb; -} - -__forceinline__ __device__ void dequant_4bit_8 -( - const uint32_t q_0, - half2 (&dq)[4], - int stride, - const uint32_t zero -) -{ - const uint32_t c0 = 0x64006400; - const half y16_ = __float2half_rn(1.0f / 16.0f); - const half2 y16 = __halves2half2(y16_, y16_); - const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); - const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - const half2 z1 = __half2half2(z1_.as_half); - const half2 z16 = __half2half2(z16_); - - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 +__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; + +#pragma unroll + for (int i = 0; i < 4; i++) { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; qa >>= 8; - half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 - half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; +} - dq[0] = __hadd2(q0.as_half2, z1); - dq[1] = __hfma2(q1.as_half2, y16, z16); - dq[2] = __hadd2(q2.as_half2, z1); - dq[3] = __hfma2(q3.as_half2, y16, z16); +__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0, + half2 (&dq)[4], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z16 = __half2half2(z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); } -__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale -( - const uint32_t zero, - const half scale, - half2 (&z1z16)[2], - half2 (&y1y16)[2] -) -{ - half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); - half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale( + const uint32_t zero, const half scale, half2 (&z1z16)[2], + half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - half2 scale2 = __half2half2(scale); + half2 scale2 = __half2half2(scale); - z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); - z1z16[1] = __hmul2(scale2, __half2half2(z16)); + z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); - const half y1 = __float2half_rn(1.0f); - const half y16 = __float2half_rn(1.0f / 16.0f); + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); - y1y16[0] = __hmul2(scale2, __half2half2(y1)); - y1y16[1] = __hmul2(scale2, __half2half2(y16)); + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); } -__forceinline__ __device__ void dequant_4bit_8_prep_zero -( - const uint32_t zero, - half2(&z1z16)[2], - half2(&y1y16)[2] -) -{ - half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); - half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); +__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero, + half2 (&z1z16)[2], + half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - z1z16[0] = __half2half2(z1.as_half); - z1z16[1] = __half2half2(z16); + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); - const half y1 = __float2half_rn(1.0f); - const half y16 = __float2half_rn(1.0f / 16.0f); + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); - y1y16[0] = __half2half2(y1); - y1y16[1] = __half2half2(y16); + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); } - -__forceinline__ __device__ void dequant_4bit_8_gptq -( - const uint32_t q_0, - half2 (&dq)[4], - half2 (&z1z16)[2], - half2 (&y1y16)[2], - int stride, - bool scaled -) -{ - const uint32_t c0 = 0x64006400; - - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) - half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) - qa >>= 8; - half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 ) - half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) - - if (scaled) - { - dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) - dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) - dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); - dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); - } - else - { - dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) - dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z ) - dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) - dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z ) - } +__forceinline__ __device__ void dequant_4bit_8_gptq(const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1z16)[2], + half2 (&y1y16)[2], + int stride, bool scaled) { + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | + c0); // half2( q[0] + 1024, q[1] + 1024 ) + half2_uint32 q1((qa & 0x00f000f0) | + c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | + c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | + c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) { + dq[0] = __hfma2(q0.as_half2, y1y16[0], + z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } else { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], + z1z16[1]); // half2( q[6] - z, q[7] - z ) + } } } // namespace gptq } // namespace vllm diff --git a/csrc/quantization/gptq/qdq_8.cuh b/csrc/quantization/gptq/qdq_8.cuh index 0c7ad7876140b..feb5d220424b0 100644 --- a/csrc/quantization/gptq/qdq_8.cuh +++ b/csrc/quantization/gptq/qdq_8.cuh @@ -10,28 +10,18 @@ Copied from https://github.com/turboderp/exllamav2 namespace vllm { namespace gptq { -__forceinline__ __device__ void shuffle_8bit_4 -( - uint32_t* q, - int stride -) -{ -} - -__forceinline__ __device__ void dequant_8bit_8 -( - const uint32_t q_0, - const uint32_t q_1, - half2 (&dq)[4], - int stride, - const uint32_t zero -) -{ - half dqh[8]; - for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), zero); - for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero); - - for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {} + +__forceinline__ __device__ void dequant_8bit_8(const uint32_t q_0, + const uint32_t q_1, + half2 (&dq)[4], int stride, + const uint32_t zero) { + half dqh[8]; + for (int i = 0; i < 4; i++) dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero); + for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero); + + for (int i = 0; i < 4; i++) + dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); } } // namespace gptq diff --git a/csrc/quantization/gptq/qdq_util.cuh b/csrc/quantization/gptq/qdq_util.cuh index 1722a9aa6cb34..9426408fec502 100644 --- a/csrc/quantization/gptq/qdq_util.cuh +++ b/csrc/quantization/gptq/qdq_util.cuh @@ -8,51 +8,47 @@ Copied from https://github.com/turboderp/exllamav2 namespace vllm { namespace gptq { -union half2_uint32 -{ - uint32_t as_uint32; - half2 as_half2; - __device__ half2_uint32(uint32_t val) : as_uint32(val) {} - __device__ half2_uint32(half2 val) : as_half2(val) {} +union half2_uint32 { + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} }; -union half_uint16 -{ - uint16_t as_uint16; - half as_half; - __device__ half_uint16(uint16_t val) : as_uint16(val) {} - __device__ half_uint16(half val) : as_half(val) {} +union half_uint16 { + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} }; // Max_scale premultiplied by 1/256 -__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) -{ - int qs_i = qs + 1; - half qs_h = __int2half_rn(qs_i * qs_i); - qs_h = __hmul(qs_h, max_scale); - return qs_h; +__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) { + int qs_i = qs + 1; + half qs_h = __int2half_rn(qs_i * qs_i); + qs_h = __hmul(qs_h, max_scale); + return qs_h; } -__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) -{ - return __hmul(__int2half_rn(q - qzero), scale); +__forceinline__ __device__ half dq(const int q, const int qzero, + const half scale) { + return __hmul(__int2half_rn(q - qzero), scale); } -__forceinline__ __device__ half dq_ns(const int q, const int qzero) -{ - //return __hsub(__int2half_rn(q), __int2half_rn(qzero)); - return __int2half_rn(q - qzero); +__forceinline__ __device__ half dq_ns(const int q, const int qzero) { + // return __hsub(__int2half_rn(q), __int2half_rn(qzero)); + return __int2half_rn(q - qzero); } -__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) -{ - return (int)((q >> shift) & mask); +__forceinline__ __device__ int exb(const uint32_t q, const int shift, + const int mask) { + return (int)((q >> shift) & mask); } -__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) -{ - return (int)(__funnelshift_rc(q0, q1, shift) & mask); +__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, + const int shift, const int mask) { + return (int)(__funnelshift_rc(q0, q1, shift) & mask); } } // namespace gptq diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 34950a5d13cf5..c573b9041065b 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -22,53 +22,58 @@ #include "gptq_marlin.cuh" #include "gptq_marlin_dtypes.cuh" -#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) static_assert(\ - std::is_same::value || std::is_same::value, \ - "only float16 and bfloat16 is supported"); - -template inline std::string str(T x) { return std::to_string(x); } +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +template +inline std::string str(T x) { + return std::to_string(x); +} namespace gptq_marlin { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, - int const *__restrict__ perm_int_ptr, - int4 *__restrict__ out_int4_ptr, int size_m, +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) {} -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > -__global__ void -Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int *__restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) {} -} // namespace gptq_marlin +} // namespace gptq_marlin -torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &g_idx, - torch::Tensor &perm, torch::Tensor &workspace, +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full) { TORCH_CHECK_NOT_IMPLEMENTED(false, @@ -81,24 +86,26 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. template -__device__ inline void mma(const typename ScalarType::FragA &a_frag, - const typename ScalarType::FragB &frag_b, - typename ScalarType::FragC &frag_c) { - const uint32_t *a = reinterpret_cast(&a_frag); - const uint32_t *b = reinterpret_cast(&frag_b); - float *c = reinterpret_cast(&frag_c); +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); if constexpr (std::is_same::value) { - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), - "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } else if constexpr (std::is_same::value) { - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), - "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } else { STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); } @@ -107,8 +114,9 @@ __device__ inline void mma(const typename ScalarType::FragA &a_frag, // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. template -__device__ inline void ldsm4(typename ScalarType::FragA &frag_a, const void *smem_ptr) { - uint32_t *a = reinterpret_cast(&frag_a); +__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) @@ -118,7 +126,8 @@ __device__ inline void ldsm4(typename ScalarType::FragA &frag_a, const // Lookup-table based 3-input logical operation; explicitly used for // dequantization as the compiler does not seem to automatically recognize it in // all cases. -template __device__ inline int lop3(int a, int b, int c) { +template +__device__ inline int lop3(int a, int b, int c) { int res; asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) @@ -140,8 +149,10 @@ __device__ inline uint32_t prmt(uint32_t a) { // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 // values. We mostly follow the strategy in the link below, with some small // changes: -// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 template __device__ inline typename ScalarType::FragB dequant_4bit(int q) { STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); @@ -161,16 +172,17 @@ __device__ inline typename ScalarType::FragB dequant_4bit(int q) { const int MUL = 0x2c002c00; const int ADD = 0xd480d480; typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); return frag_b; } template <> -__device__ inline typename ScalarType::FragB dequant_4bit(int q) { +__device__ inline typename ScalarType::FragB +dequant_4bit(int q) { static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t EX = 0x43004300; @@ -184,7 +196,7 @@ __device__ inline typename ScalarType::FragB dequant_4bit(&lo), + frag_b[0] = __hfma2(*reinterpret_cast(&lo), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); frag_b[1] = __hfma2(*reinterpret_cast(&hi), @@ -193,10 +205,12 @@ __device__ inline typename ScalarType::FragB dequant_4bit __device__ inline typename ScalarType::FragB dequant_8bit(int q) { STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); @@ -214,24 +228,26 @@ __device__ inline typename ScalarType::FragB dequant_8bit(int q) { static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); return frag_b; } template <> -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { +__device__ inline typename ScalarType::FragB +dequant_8bit(int q) { typename ScalarType::FragB frag_b; float fp32_intermediates[4]; - uint32_t * fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); fp32_intermediates[0] -= 8388736.f; @@ -240,8 +256,10 @@ __device__ inline typename ScalarType::FragB dequant_8bit(&frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); return frag_b; } @@ -249,30 +267,32 @@ __device__ inline typename ScalarType::FragB dequant_8bit -__device__ inline void scale(typename ScalarType::FragB &frag_b, - typename ScalarType::FragS &frag_s, int i) { +__device__ inline void scale(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s, + int i) { using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + scalar_t2 s = + ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } // Same as above, but for act_order (each K is multiplied individually) template -__device__ inline void scale4(typename ScalarType::FragB &frag_b, - typename ScalarType::FragS &frag_s_1, - typename ScalarType::FragS &frag_s_2, - typename ScalarType::FragS &frag_s_3, - typename ScalarType::FragS &frag_s_4, +__device__ inline void scale4(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s_1, + typename ScalarType::FragS& frag_s_2, + typename ScalarType::FragS& frag_s_3, + typename ScalarType::FragS& frag_s_4, int i) { - using scalar_t2 = typename ScalarType::scalar_t2; + using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 s_val_1_2; - s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; scalar_t2 s_val_3_4; - s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; frag_b[0] = __hmul2(frag_b[0], s_val_1_2); frag_b[1] = __hmul2(frag_b[1], s_val_3_4); @@ -280,14 +300,15 @@ __device__ inline void scale4(typename ScalarType::FragB &frag_b, // Given 2 floats multiply by 2 scales (halves) template -__device__ inline void scale_float(float *c, typename ScalarType::FragS &s) { - scalar_t *s_ptr = reinterpret_cast(&s); +__device__ inline void scale_float(float* c, + typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); } // Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int *lock, int count) { +__device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { int state = -1; do @@ -302,7 +323,7 @@ __device__ inline void barrier_acquire(int *lock, int count) { } // Release barrier and increment visitation count. -__device__ inline void barrier_release(int *lock, bool reset = false) { +__device__ inline void barrier_release(int* lock, bool reset = false) { __syncthreads(); if (threadIdx.x == 0) { if (reset) { @@ -321,11 +342,10 @@ __device__ inline void barrier_release(int *lock, bool reset = false) { // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. -__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, - int const *__restrict__ perm_int_ptr, - int4 *__restrict__ out_int4_ptr, int size_m, +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) { - int start_row = block_rows * blockIdx.x; int finish_row = start_row + block_rows; if (finish_row > size_m) { @@ -341,9 +361,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, int offset = row * row_stride; - half const *a_row_half = - reinterpret_cast(a_int4_ptr + offset); - half *out_half = reinterpret_cast(out_int4_ptr + offset); + half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); int base_k = 0; @@ -374,31 +393,32 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, } } -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > -__global__ void -Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int *__restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * @@ -445,11 +465,11 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int slice_row = (iters * blockIdx.x) % k_tiles; int slice_col_par = (iters * blockIdx.x) / k_tiles; int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice + int slice_iters; // number of threadblock tiles in the current slice int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top // We can easily implement parallel problem execution by just remapping // indices and advancing global pointers @@ -465,27 +485,22 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk auto init_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) - slice_iters = 0; - if (slice_iters == 0) - return; - if (slice_row + slice_iters > k_tiles) - slice_iters = k_tiles - slice_row; + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); if (col_first <= k_tiles * (slice_col_par + 1)) { int col_off = col_first - k_tiles * slice_col_par; slice_count = div_ceil(k_tiles - col_off, iters); - if (col_off > 0) - slice_count++; + if (col_off > 0) slice_count++; int delta_first = iters * blockIdx.x - col_first; if (delta_first < 0 || (col_off == 0 && delta_first == 0)) slice_idx = slice_count - 1; else { slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) - slice_idx--; + if (col_off > 0) slice_idx--; } } if (slice_col == n_tiles) { @@ -605,7 +620,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; @@ -623,13 +638,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // loop unrolls, all shared memory accesses are static, we simply precompute // both transformed reads and writes. int a_sh_wr_trans[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < thread_m_blocks; j++) a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); @@ -639,30 +654,30 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4 *B_ptr[b_sh_wr_iters]; -#pragma unroll + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. - int4 *sh_a = sh; - int4 *sh_b = sh_a + (stages * a_sh_stage); - int4 *sh_g_idx = sh_b + (stages * b_sh_stage); - int4 *sh_s = sh_g_idx + (stages * g_idx_stage); + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_s = sh_g_idx + (stages * g_idx_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order // Zero accumulators. auto zero_accums = [&]() { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; + reinterpret_cast(frag_c)[i] = 0; }; int sh_first_group_id = -1; @@ -706,18 +721,18 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // shared memory pipeline location. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { cp_async4_pred( &sh_a_stage[a_sh_wr_trans[i]], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], a_sh_wr_pred[i]); } - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; -#pragma unroll + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < b_thread_vecs; j++) { cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); } @@ -730,10 +745,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int full_pipe = a_off; int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; if (cur_k < prob_k && cur_k < slice_k_finish) { - int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int4 const *cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); if (threadIdx.x < g_idx_stage) { cp_async4_pred(&sh_g_idx_stage[threadIdx.x], @@ -742,7 +757,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } } else { if constexpr (group_blocks != -1) { - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + int4* sh_s_stage = sh_s + s_sh_stage * pipe; if constexpr (group_blocks >= thread_k_blocks) { // Only fetch scales if this tile starts a new group @@ -782,15 +797,16 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Load the next sub-tile from the current location in the shared memory pipe // into the current register buffer. auto fetch_to_registers = [&](int k, int pipe) { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; + ldsm4(frag_a[k % 2][i], + &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; -#pragma unroll + #pragma unroll for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( + frag_b_quant[k % 2][i] = *reinterpret_cast( &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); } }; @@ -805,8 +821,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk return; } - int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); int group_id_1 = sh_g_idx_int_ptr[0]; int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; @@ -822,10 +838,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // No act-order case if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - int4 *sh_s_stage = + int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; } else { int warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; @@ -838,9 +854,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int k_blocks = cur_k / 16; int cur_group_id = k_blocks / group_blocks; - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + int4* sh_s_stage = sh_s + s_sh_stage * pipe; - reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; } } @@ -867,7 +883,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // thread-id) int warp_id = threadIdx.x / 32; int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N int warp_row = warp_id / n_warps; int warp_col = warp_id % n_warps; @@ -875,7 +891,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk cur_k += warp_row * 16; int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix int s_col_shift = /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + @@ -883,45 +899,44 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk if (is_same_group[pipe]) { if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); } for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); } return; } - int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread + 9}; // Tensor core offsets per thread -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; int group_id = sh_g_idx_int_ptr[actual_k]; int rel_group_id = group_id - sh_first_group_id; - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; } }; // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { -// We have the m dimension as the inner loop in order to encourage overlapping -// dequantization and matmul operations. -#pragma unroll + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll for (int j = 0; j < 4; j++) { FragB frag_b0; FragB frag_b1; @@ -933,7 +948,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk frag_b1 = dequant_4bit(b_quant_shift); } else { - int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; @@ -943,8 +958,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Apply scale to frag_b0 if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); + scale4(frag_b0, act_frag_s[k % 2][0][j], + act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], + act_frag_s[k % 2][3][j], 0); } else { if constexpr (group_blocks != -1) { scale(frag_b0, frag_s[k % 2][j], 0); @@ -953,8 +969,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Apply scale to frag_b1 if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + scale4(frag_b1, act_frag_s[k % 2][0][j], + act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], + act_frag_s[k % 2][3][j], 1); } else { if constexpr (group_blocks != -1) { @@ -962,7 +979,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } } -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); @@ -987,38 +1004,38 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // unnecessary read or write iterations, e.g., for two warps we write only // once by warp 1 and read only once by warp 0. -#pragma unroll + #pragma unroll for (int m_block = 0; m_block < thread_m_blocks; m_block++) { -#pragma unroll + #pragma unroll for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4 * 2; j++) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { - float *c_rd = reinterpret_cast( - &sh[red_sh_delta * j + red_sh_rd]); - float *c_wr = reinterpret_cast(&sh[red_sh_wr]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { -#pragma unroll + #pragma unroll for (int i = 0; i < 4 * 2; i++) { - float *c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; } } @@ -1049,39 +1066,39 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int row = (threadIdx.x % 32) / 4; if (!first) { -// Interestingly, doing direct global accesses here really seems to mess up the -// compiler and lead to slowdowns, hence we also use async-copies even though -// these fetches are not actually asynchronous. -#pragma unroll + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || - 8 * (i / 2) + row < prob_m); + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); } cp_async_fence(); cp_async_wait<0>(); } -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { if (!first) { int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; -#pragma unroll + #pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( + reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); + Dtype::num2float(reinterpret_cast(&c_red)[j]); } } if (!last) { int4 c; -#pragma unroll + #pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); } C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = @@ -1115,8 +1132,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // We first reorder in shared memory to guarantee the most efficient final // global write patterns - auto write = [&](int idx, float c0, float c1, FragS &s) { - scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = + Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); // For per-column quantization we finally apply the scale here (only for // 4-bit) @@ -1124,13 +1142,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk res = __hmul2(res, s[0]); } - ((scalar_t2 *)sh)[idx] = res; + ((scalar_t2*)sh)[idx] = res; }; if (threadIdx.x / 32 < thread_n_blocks / 4) { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4; j++) { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], @@ -1147,7 +1165,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } __syncthreads(); -#pragma unroll + #pragma unroll for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { @@ -1162,7 +1180,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Start global fetch and register load pipelines. auto start_pipes = [&]() { -#pragma unroll + #pragma unroll for (int i = 0; i < stages - 1; i++) { if (has_act_order && i == 0) { int last_g_idx = slice_k_start + stages * tb_k * 2; @@ -1193,9 +1211,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // have even length meaning that the next iteration will always start at // index 0. -#pragma unroll + #pragma unroll for (int pipe = 0; pipe < stages;) { -#pragma unroll + #pragma unroll for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); fetch_scales_to_registers(k + 1, pipe); @@ -1261,8 +1279,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } else { @@ -1270,8 +1288,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } } @@ -1282,31 +1300,35 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // overflow in fp16) if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) { if (threadIdx.x / 32 < thread_n_blocks / 4) { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float( + reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float( + reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); } } } } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); barrier_release(&locks[slice_col], last); } - if (last) // only the last block in a slice actually writes the result + if (last) // only the last block in a slice actually writes the result write_result(); slice_row = 0; slice_col_par++; @@ -1315,13 +1337,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; if (slice_col == 0) { -#pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] -= b_gl_stride; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } // Update slice k/n for scales loading @@ -1341,23 +1362,24 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } } -#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ - prob_k, locks); \ - } + #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin<<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ + prob_k, locks); \ + } typedef struct { int thread_k; @@ -1389,7 +1411,7 @@ thread_config_t large_batch_thread_configs[] = { }; -int get_scales_cache_size(thread_config_t const &th_config, int prob_m, +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full) { bool cache_scales_chunk = has_act_order && !is_k_full; @@ -1402,15 +1424,15 @@ int get_scales_cache_size(thread_config_t const &th_config, int prob_m, if (group_size == -1) { tb_groups = 1; } else if (group_size == 0) { - tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size } else { tb_groups = div_ceil(tb_k, group_size); } if (cache_scales_chunk) { int load_groups = - tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K - load_groups = max(load_groups, 32); // We load at least 32 scale groups + tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups return load_groups * tb_n * 2; } else { @@ -1420,7 +1442,7 @@ int get_scales_cache_size(thread_config_t const &th_config, int prob_m, } } -bool is_valid_cache_size(thread_config_t const &th_config, int max_m_blocks, +bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int scales_cache_size, int max_shared_mem) { int pack_factor = 32 / num_bits; @@ -1451,12 +1473,12 @@ bool is_valid_cache_size(thread_config_t const &th_config, int max_m_blocks, float pipe_size = (a_size + b_size) * pipe_stages; - TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); } -bool is_valid_config(thread_config_t const &th_config, int max_m_blocks, +bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, int max_shared_mem) { @@ -1519,43 +1541,43 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, } } - max_m_blocks--; // Process less M blocks per invocation to reduce cache - // usage + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage } return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) template -void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, - void *g_idx, void *perm, void *a_tmp, int prob_m, - int prob_n, int prob_k, void *workspace, int num_bits, +void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, + void* g_idx, void* perm, void* a_tmp, int prob_m, + int prob_n, int prob_k, void* workspace, int num_bits, bool has_act_order, bool is_k_full, int num_groups, int group_size, int dev, cudaStream_t stream, int thread_k, int thread_n, int sms, int max_par) { @@ -1639,15 +1661,15 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, } } - const int4 *A_ptr = (const int4 *)A; - const int4 *B_ptr = (const int4 *)B; - int4 *C_ptr = (int4 *)C; - const int4 *s_ptr = (const int4 *)s; - const int *g_idx_ptr = (const int *)g_idx; - const int *perm_ptr = (const int *)perm; - int4 *a_tmp_ptr = (int4 *)a_tmp; + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; + const int* g_idx_ptr = (const int*)g_idx; + const int* perm_ptr = (const int*)perm; + int4* a_tmp_ptr = (int4*)a_tmp; - int *locks = (int *)workspace; + int* locks = (int*)workspace; if (has_act_order) { // Permute A columns @@ -1673,8 +1695,7 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, // Note that parallel > 1 currently only works for inputs without any // padding par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); - if (par > max_par) - par = max_par; + if (par > max_par) par = max_par; prob_m = (16 * exec_cfg.max_m_blocks) * par; i += exec_cfg.max_m_blocks * (par - 1); thread_m_blocks = exec_cfg.max_m_blocks; @@ -1709,11 +1730,11 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, } } -} // namespace gptq_marlin +} // namespace gptq_marlin -torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &g_idx, - torch::Tensor &perm, torch::Tensor &workspace, +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full) { // Verify num_bits @@ -1824,18 +1845,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { gptq_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, - size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, - num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, gptq_marlin::max_par); + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), + a_tmp.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups, + group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + thread_n, sms, gptq_marlin::max_par); } else if (a.scalar_type() == at::ScalarType::BFloat16) { gptq_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, - size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, - num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, gptq_marlin::max_par); + a.data_ptr(), b_q_weight.data_ptr(), + c.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order, + is_k_full, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + gptq_marlin::max_par); } else { TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); } diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cuh b/csrc/quantization/gptq_marlin/gptq_marlin.cuh index 35ea48aaba310..ba5368ea8835f 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cuh +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cuh @@ -11,22 +11,23 @@ namespace gptq_marlin { -// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per -// schedule allows some more latency hiding. At the same time, we want relatively few warps to have -// many registers per warp and small tiles. +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. static constexpr int default_threads = 256; -static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory +static constexpr int pipe_stages = + 4; // 4 pipeline stages fit into shared memory static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; static constexpr int tile_size = 16; -static constexpr int max_par = 16; +static constexpr int max_par = 16; template struct Vec { - T elems[n]; + T elems[n]; __device__ T& operator[](int i) { return elems[i]; } }; @@ -35,30 +36,35 @@ using I4 = Vec; constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - // No support for async +// No support for async #else -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); } __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); } -__device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); } +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} template __device__ inline void cp_async_wait() { @@ -67,4 +73,4 @@ __device__ inline void cp_async_wait() { #endif -} // namespace gptq_marlin +} // namespace gptq_marlin diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh b/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh index 7881abbe4cbbf..ca1b7099d6ec7 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh +++ b/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh @@ -5,58 +5,73 @@ #include #include - namespace gptq_marlin { template -class ScalarType { -}; +class ScalarType {}; template <> class ScalarType { -public: - using scalar_t = half; - using scalar_t2 = half2; - - // Matrix fragments for tensor core instructions; their precise layout is - // documented here: - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type - using FragA = Vec; - using FragB = Vec; - using FragC = Vec; - using FragS = Vec; - - static __device__ float inline num2float(const half x) { return __half2float(x); } - - static __device__ half2 inline num2num2(const half x) { return __half2half2(x); } - - static __device__ half2 inline nums2num2(const half x1, const half x2) { return __halves2half2(x1, x2); } - - static __host__ __device__ half inline float2num(const float x) { return __float2half(x); } + public: + using scalar_t = half; + using scalar_t2 = half2; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + + static __device__ float inline num2float(const half x) { + return __half2float(x); + } + + static __device__ half2 inline num2num2(const half x) { + return __half2half2(x); + } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ half inline float2num(const float x) { + return __float2half(x); + } }; template <> class ScalarType { -public: - using scalar_t = nv_bfloat16; - using scalar_t2 = nv_bfloat162; + public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; - using FragA = Vec; - using FragB = Vec; - using FragC = Vec; - using FragS = Vec; + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); } - - static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { return __bfloat162bfloat162(x); } - - static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) { return __halves2bfloat162(x1, x2); } - - static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); } + static __device__ float inline num2float(const nv_bfloat16 x) { + return __bfloat162float(x); + } + + static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { + return __bfloat162bfloat162(x); + } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, + const nv_bfloat16 x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { + return __float2bfloat16(x); + } #endif }; -} +} // namespace gptq_marlin #endif diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu index 0d3da6240dbca..4adc158eb14ea 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -12,14 +12,14 @@ static constexpr int tile_n_size = tile_k_size * 4; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 template -__global__ void -marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, - uint32_t const *__restrict__ perm_ptr, - uint32_t *__restrict__ out_ptr, int size_k, int size_n) {} +__global__ void marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, + int size_k, int size_n) {} -} // namespace gptq_marlin +} // namespace gptq_marlin -torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) { TORCH_CHECK_NOT_IMPLEMENTED( @@ -30,10 +30,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, #else template -__global__ void -marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, - uint32_t const *__restrict__ perm_ptr, - uint32_t *__restrict__ out_ptr, int size_k, int size_n) { +__global__ void marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, + int size_k, int size_n) { constexpr int pack_factor = 32 / num_bits; int k_tiles = size_k / tile_k_size; @@ -61,8 +61,8 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, constexpr int perm_size = tile_k_size / 4; - int4 *sh_perm_ptr = sh; - int4 *sh_pipe_ptr = sh_perm_ptr; + int4* sh_perm_ptr = sh; + int4* sh_pipe_ptr = sh_perm_ptr; if constexpr (has_perm) { sh_pipe_ptr += perm_size; } @@ -76,7 +76,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, auto load_perm_to_shared = [&](int k_tile_id) { int first_k_int4 = (k_tile_id * tile_k_size) / 4; - int4 const *perm_int4_ptr = reinterpret_cast(perm_ptr); + int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); if (threadIdx.x < perm_size) { sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; @@ -92,22 +92,22 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, int first_n = n_tile_id * tile_n_size; - int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe; + int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; if constexpr (has_perm) { if (threadIdx.x < stage_size) { int k_id = threadIdx.x / stage_n_threads; int n_id = threadIdx.x % stage_n_threads; - uint32_t const *sh_perm_int_ptr = - reinterpret_cast(sh_perm_ptr); + uint32_t const* sh_perm_int_ptr = + reinterpret_cast(sh_perm_ptr); int src_k = sh_perm_int_ptr[k_id]; int src_k_packed = src_k / pack_factor; cp_async4( &sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast(&( + reinterpret_cast(&( b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); } @@ -120,7 +120,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, int first_k_packed = first_k / pack_factor; cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast( + reinterpret_cast( &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)]))); } @@ -151,10 +151,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, constexpr int sh_stride = 64; constexpr uint32_t mask = (1 << num_bits) - 1; - int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; - uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); - uint32_t *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + uint32_t* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); uint32_t vals[8]; @@ -176,17 +176,16 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, } } else { - uint32_t b1_vals[tile_ints]; uint32_t b2_vals[tile_ints]; -#pragma unroll + #pragma unroll for (int i = 0; i < tile_ints; i++) { b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; } -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { int cur_elem = tc_row + tc_offsets[i]; int cur_int = cur_elem / pack_factor; @@ -206,7 +205,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; uint32_t res = 0; -#pragma unroll + #pragma unroll for (int i = 0; i < 8; i++) { res |= vals[pack_idx[i]] << (i * 4); } @@ -218,7 +217,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, uint32_t res1 = 0; uint32_t res2 = 0; -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { res1 |= vals[pack_idx[i]] << (i * 8); res2 |= vals[4 + pack_idx[i]] << (i * 8); @@ -230,14 +229,14 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, }; auto start_pipes = [&](int k_tile_id, int n_tile_id) { -#pragma unroll + #pragma unroll for (int pipe = 0; pipe < repack_stages - 1; pipe++) { fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); } wait_for_stage(); }; -#pragma unroll + #pragma unroll for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { int n_tile_id = 0; @@ -248,7 +247,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, start_pipes(k_tile_id, n_tile_id); while (n_tile_id < n_tiles) { -#pragma unroll + #pragma unroll for (int pipe = 0; pipe < repack_stages; pipe++) { fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); @@ -260,21 +259,21 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, } } -} // namespace gptq_marlin - -#define CALL_IF(NUM_BITS, HAS_PERM) \ - else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ - cudaFuncSetAttribute( \ - gptq_marlin::marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - gptq_marlin::marlin_repack_kernel \ - <<>>( \ - b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ - } +} // namespace gptq_marlin + + #define CALL_IF(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + cudaFuncSetAttribute( \ + gptq_marlin::marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + gptq_marlin::marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ + } -torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) { // Verify compatibility with marlin tile of 16x64 @@ -318,11 +317,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, bool has_perm = perm.size(0) != 0; // Get ptrs - uint32_t const *b_q_weight_ptr = - reinterpret_cast(b_q_weight.data_ptr()); - uint32_t const *perm_ptr = - reinterpret_cast(perm.data_ptr()); - uint32_t *out_ptr = reinterpret_cast(out.data_ptr()); + uint32_t const* b_q_weight_ptr = + reinterpret_cast(b_q_weight.data_ptr()); + uint32_t const* perm_ptr = reinterpret_cast(perm.data_ptr()); + uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); // Get dev info int dev = b_q_weight.get_device(); diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu index 002a70001885d..03d66cecedf1f 100644 --- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu @@ -25,7 +25,10 @@ #include -template inline std::string str(T x) { return std::to_string(x); } +template +inline std::string str(T x) { + return std::to_string(x); +} namespace marlin { @@ -38,9 +41,10 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } // corresponding index accesses must be compile-time constants, which is why we // extensively use `#pragma unroll` throughout the kernel code to guarantee // this. -template struct Vec { +template +struct Vec { T elems[n]; - __device__ T &operator[](int i) { return elems[i]; } + __device__ T& operator[](int i) { return elems[i]; } }; using I4 = Vec; @@ -51,29 +55,32 @@ using I4 = Vec; using FragA = Vec; using FragB = Vec; using FragC = Vec; -using FragS = Vec; // quantization scales +using FragS = Vec; // quantization scales // Predicated asynchronous global->shared copy; used for inputs A where we apply // predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); } // Asynchronous global->shared copy -__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); } // Async copy fence. @@ -82,28 +89,30 @@ __device__ inline void cp_async_fence() { } // Wait until at most `n` async copy stages are still pending. -template __device__ inline void cp_async_wait() { +template +__device__ inline void cp_async_wait() { asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); } // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. -__device__ inline void mma(const FragA &a_frag, const FragB &frag_b, - FragC &frag_c) { - const uint32_t *a = reinterpret_cast(&a_frag); - const uint32_t *b = reinterpret_cast(&frag_b); - float *c = reinterpret_cast(&frag_c); - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), - "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, + FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { - uint32_t *a = reinterpret_cast(&frag_a); +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) @@ -113,7 +122,8 @@ __device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { // Lookup-table based 3-input logical operation; explicitly used for // dequantization as the compiler does not seem to automatically recognize it in // all cases. -template __device__ inline int lop3(int a, int b, int c) { +template +__device__ inline int lop3(int a, int b, int c) { int res; asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) @@ -138,24 +148,24 @@ __device__ inline FragB dequant(int q) { const int MUL = 0x2c002c00; const int ADD = 0xd480d480; FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); return frag_b; } // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. -__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]); +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } // Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int *lock, int count) { +__device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { int state = -1; do @@ -170,7 +180,7 @@ __device__ inline void barrier_acquire(int *lock, int count) { } // Release barrier and increment visitation count. -__device__ inline void barrier_release(int *lock, bool reset = false) { +__device__ inline void barrier_release(int* lock, bool reset = false) { __syncthreads(); if (threadIdx.x == 0) { if (reset) { @@ -187,26 +197,27 @@ __device__ inline void barrier_release(int *lock, bool reset = false) { } } -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > -__global__ void -Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 - *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * @@ -241,11 +252,11 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int slice_row = (iters * blockIdx.x) % k_tiles; int slice_col_par = (iters * blockIdx.x) / k_tiles; int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice + int slice_iters; // number of threadblock tiles in the current slice int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top // We can easily implement parallel problem execution by just remapping // indices and advancing global pointers @@ -261,27 +272,22 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk auto init_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) - slice_iters = 0; - if (slice_iters == 0) - return; - if (slice_row + slice_iters > k_tiles) - slice_iters = k_tiles - slice_row; + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); if (col_first <= k_tiles * (slice_col_par + 1)) { int col_off = col_first - k_tiles * slice_col_par; slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) - slice_count++; + if (col_off > 0) slice_count++; int delta_first = iters * blockIdx.x - col_first; if (delta_first < 0 || (col_off == 0 && delta_first == 0)) slice_idx = slice_count - 1; else { slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) - slice_idx--; + if (col_off > 0) slice_idx--; } } if (slice_col == n_tiles) { @@ -293,29 +299,30 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk }; init_slice(); - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory // We typically use `constexpr` to indicate that this value is a compile-time // constant constexpr int a_sh_stride = - 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory + 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / - 8; // delta between subsequent A tiles in global memory + 8; // delta between subsequent A tiles in global memory int a_gl_rd_delta_i = a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile constexpr int a_sh_wr_delta = - a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes + a_sh_stride * + (threads / a_gl_rd_delta_o); // between shared memory writes constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads + (thread_n_blocks / 4)); // between shared memory tile reads constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile + a_sh_stride * 16; // within a shared memory tile constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + a_sh_stride * (16 * thread_m_blocks); // overall size of a tile constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile + a_sh_wr_delta); // number of shared write iterations for a tile int b_gl_stride = 16 * prob_n / 32; constexpr int b_sh_stride = 32 * thread_n_blocks / 4; @@ -368,7 +375,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; @@ -387,13 +394,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // loop unrolls, all shared memory accesses are static, we simply precompute // both transformed reads and writes. int a_sh_wr_trans[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < thread_m_blocks; j++) a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); @@ -403,16 +410,16 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4 *B_ptr[b_sh_wr_iters]; -#pragma unroll + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. - int4 *sh_a = sh; - int4 *sh_b = sh_a + (stages * a_sh_stage); - int4 *sh_s = sh_b + (stages * b_sh_stage); + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2]; @@ -421,34 +428,33 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Zero accumulators. auto zero_accums = [&]() { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; + reinterpret_cast(frag_c)[i] = 0; }; // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { cp_async4_pred( &sh_a_stage[a_sh_wr_trans[i]], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], a_sh_wr_pred[i]); } - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; -#pragma unroll + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); B_ptr[i] += b_gl_rd_delta_o; } // Only fetch scales if this tile starts a new group if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) - cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); s_gl_rd += s_gl_rd_delta; } } @@ -475,37 +481,35 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // theoretically better attempts have lead to bad instruction ordering by // the compiler and correspondingly a noticeable drop in performance. if (group_blocks != -1) { - int4 *sh_s_stage = + int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; } - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast( &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); }; // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { -// We have the m dimension as the inner loop in order to encourage overlapping -// dequantization and matmul operations. -#pragma unroll + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll for (int j = 0; j < 4; j++) { int b_quant = frag_b_quant[k % 2][j]; int b_quant_shift = b_quant >> 8; FragB frag_b0 = dequant(b_quant); // If there are no groups, we can just scale the final output once and can // avoid doing so for each weight. - if (group_blocks != -1) - scale(frag_b0, frag_s[k % 2][j], 0); + if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); FragB frag_b1 = dequant(b_quant_shift); - if (group_blocks != -1) - scale(frag_b1, frag_s[k % 2][j], 1); -#pragma unroll + if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); @@ -530,38 +534,38 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // unnecessary read or write iterations, e.g., for two warps we write only // once by warp 1 and read only once by warp 0. -#pragma unroll + #pragma unroll for (int m_block = 0; m_block < thread_m_blocks; m_block++) { -#pragma unroll + #pragma unroll for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4 * 2; j++) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { - float *c_rd = reinterpret_cast( - &sh[red_sh_delta * j + red_sh_rd]); - float *c_wr = reinterpret_cast(&sh[red_sh_wr]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { -#pragma unroll + #pragma unroll for (int i = 0; i < 4 * 2; i++) { - float *c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; } } @@ -571,9 +575,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk }; // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped partitioning - // minimizes the number of such reductions and our outputs are usually rather - // small, we perform this reduction serially in L2 cache. + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. auto global_reduce = [&](bool first = false, bool last = false) { // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out @@ -592,39 +596,39 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int row = (threadIdx.x % 32) / 4; if (!first) { -// Interestingly, doing direct global accesses here really seems to mess up the -// compiler and lead to slowdowns, hence we also use async-copies even though -// these fetches are not actually asynchronous. -#pragma unroll + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || - 8 * (i / 2) + row < prob_m); + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); } cp_async_fence(); cp_async_wait<0>(); } -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { if (!first) { int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; -#pragma unroll + #pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( + reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half *>(&c_red)[j]); + __half2float(reinterpret_cast<__half*>(&c_red)[j]); } } if (!last) { int4 c; -#pragma unroll + #pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half *>(&c)[j] = - __float2half(reinterpret_cast( + reinterpret_cast<__half*>(&c)[j] = + __float2half(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); } C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = @@ -658,17 +662,17 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // We first reorder in shared memory to guarantee the most efficient final // global write patterns - auto write = [&](int idx, float c0, float c1, FragS &s) { + auto write = [&](int idx, float c0, float c1, FragS& s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); if (group_blocks == - -1) // for per-column quantization we finally apply the scale here + -1) // for per-column quantization we finally apply the scale here res = __hmul2(res, s[0]); - ((half2 *)sh)[idx] = res; + ((half2*)sh)[idx] = res; }; if (threadIdx.x / 32 < thread_n_blocks / 4) { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4; j++) { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], @@ -685,7 +689,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } __syncthreads(); -#pragma unroll + #pragma unroll for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { @@ -699,9 +703,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Start global fetch and register load pipelines. auto start_pipes = [&]() { -#pragma unroll - for (int i = 0; i < stages - 1; i++) - fetch_to_shared(i, i, i < slice_iters); + #pragma unroll + for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); zero_accums(); wait_for_stage(); fetch_to_registers(0, 0); @@ -711,12 +714,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Main loop. while (slice_iters) { -// We unroll over both the global fetch and the register load pipeline to ensure -// all shared memory accesses are static. Note that both pipelines have even -// length meaning that the next iteration will always start at index 0. -#pragma unroll + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines have + // even length meaning that the next iteration will always start at index 0. + #pragma unroll for (int pipe = 0; pipe < stages;) { -#pragma unroll + #pragma unroll for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); if (k == b_sh_wr_iters - 2) { @@ -728,8 +731,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk matmul(k); } slice_iters--; - if (slice_iters == 0) - break; + if (slice_iters == 0) break; } a_gl_rd += a_gl_rd_delta_o * stages; @@ -742,8 +744,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // For per-column scales, we only fetch them here in the final step before // write-out if (group_blocks == -1 && last) { - if (s_sh_wr_pred) - cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); cp_async_fence(); } thread_block_reduce(); @@ -751,17 +752,17 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); barrier_release(&locks[slice_col], last); } - if (last) // only the last block in a slice actually writes the result + if (last) // only the last block in a slice actually writes the result write_result(); slice_row = 0; slice_col_par++; @@ -770,13 +771,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; if (slice_col == 0) { -#pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] -= b_gl_stride; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } s_gl_rd = s_sh_stride * slice_col + threadIdx.x; start_pipes(); @@ -787,26 +787,27 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk #else -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > -__global__ void -Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 - *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) { // Marlin is not implemented yet for SM < 8.0 assert(false); @@ -819,10 +820,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // than 1 warp per schedule allows some more latency hiding. At the same time, // we want relatively few warps to have many registers per warp and small tiles. const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory const int SHARED_MEM = - 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; @@ -831,7 +832,7 @@ static constexpr int tile_size = 16; static constexpr int max_par = 16; static constexpr int pack_factor_4bit = - 8; // We have 8 4-bit vals inside a 32 bit + 8; // We have 8 4-bit vals inside a 32 bit #define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ GROUP_BLOCKS, NUM_THREADS) \ @@ -858,23 +859,23 @@ thread_config_t small_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N }; thread_config_t large_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X + {64, 256, 256}, // Default + {128, 128, 256}, // Reduce N 2X, increase K 2X + {64, 128, 128}, // Reduce N 2X, same K + {128, 64, 128}, // Reduce N 4X, increase K 2X }; -bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, +bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, int prob_k) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || @@ -907,7 +908,6 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, } thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { for (auto th_config : small_batch_thread_configs) { if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { @@ -926,20 +926,20 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { return thread_config_t{-1, -1, -1}; } -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ +#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) -void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, - int prob_n, int prob_k, void *workspace, int groupsize = -1, +void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m, + int prob_n, int prob_k, void* workspace, int groupsize = -1, int dev = 0, cudaStream_t stream = 0, int thread_k = -1, int thread_n = -1, int sms = -1, int max_par = 16) { int tot_m = prob_m; @@ -996,12 +996,12 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, " is not divisible by group_blocks = ", group_blocks); } - const int4 *A_ptr = (const int4 *)A; - const int4 *B_ptr = (const int4 *)B; - int4 *C_ptr = (int4 *)C; - const int4 *s_ptr = (const int4 *)s; + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; - int *locks = (int *)workspace; + int* locks = (int*)workspace; for (int i = 0; i < tot_m_blocks; i += 4) { int thread_m_blocks = tot_m_blocks - i; @@ -1011,8 +1011,7 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, // Note that parallel > 1 currently only works for inputs without any // padding par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) - par = max_par; + if (par > max_par) par = max_par; prob_m = 64 * par; i += 4 * (par - 1); thread_m_blocks = 4; @@ -1041,12 +1040,11 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, } } -} // namespace marlin +} // namespace marlin -torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &workspace, +torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k) { - // Verify M TORCH_CHECK(size_m == a.size(0), "Shape mismatch: a.size(0) = " + str(a.size(0)) + @@ -1074,9 +1072,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit; - TORCH_CHECK(size_n == actual_size_n, - "size_n = " + str(size_n) + - ", actual_size_n = " + str(actual_size_n)); + TORCH_CHECK( + size_n == actual_size_n, + "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); // Verify A device and strides TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); diff --git a/csrc/quantization/marlin/sparse/common/base.h b/csrc/quantization/marlin/sparse/common/base.h index 929b39d7642f1..16018d331bec2 100644 --- a/csrc/quantization/marlin/sparse/common/base.h +++ b/csrc/quantization/marlin/sparse/common/base.h @@ -26,12 +26,14 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } // corresponding index accesses must be compile-time constants, which is why we // extensively use `#pragma unroll` throughout the kernel code to guarantee // this. -template struct Vec { +template +struct Vec { T elems[n]; - __device__ T &operator[](int i) { return elems[i]; } + __device__ T& operator[](int i) { return elems[i]; } }; -template struct ShapeBase { +template +struct ShapeBase { static constexpr int M = M_, N = N_, K = K_; }; @@ -44,6 +46,6 @@ using FragA = Vec; using FragB = Vec; using FragM = Vec; using FragC = Vec; -using FragS = Vec; // quantization scales +using FragS = Vec; // quantization scales -} // namespace marlin_24 +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/common/mem.h b/csrc/quantization/marlin/sparse/common/mem.h index a49d15ca544eb..83e3578d2f511 100644 --- a/csrc/quantization/marlin/sparse/common/mem.h +++ b/csrc/quantization/marlin/sparse/common/mem.h @@ -21,41 +21,44 @@ namespace marlin_24 { // Predicated asynchronous global->shared copy; used for inputs A where we apply // predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred_zfill(void *smem_ptr, - const void *glob_ptr, +__device__ inline void cp_async4_pred_zfill(void* smem_ptr, + const void* glob_ptr, bool pred = true, const bool zfill = false) { const int BYTES = 16; int src_in_bytes = (zfill ? 0 : BYTES); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); } -__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); } // Asynchronous global->shared copy -__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); } // Async copy fence. @@ -64,22 +67,23 @@ __device__ inline void cp_async_fence() { } // Wait until at most `n` async copy stages are still pending. -template __device__ inline void cp_async_wait() { +template +__device__ inline void cp_async_wait() { asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { - uint32_t *a = reinterpret_cast(&frag_a); +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)); } -__device__ inline void ldsm4_m(FragM &frag_m, const void *smem_ptr) { - uint32_t *a = reinterpret_cast(&frag_m); +__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_m); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(a[0]), "=r"(a[1]) @@ -88,8 +92,8 @@ __device__ inline void ldsm4_m(FragM &frag_m, const void *smem_ptr) { // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -__device__ inline void ldsm4_t(FragA &frag_a, const void *smem_ptr) { - uint32_t *a = reinterpret_cast(&frag_a); +__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" @@ -98,7 +102,7 @@ __device__ inline void ldsm4_t(FragA &frag_a, const void *smem_ptr) { } // Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int *lock, int count) { +__device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { int state = -1; do @@ -113,7 +117,7 @@ __device__ inline void barrier_acquire(int *lock, int count) { } // Release barrier and increment visitation count. -__device__ inline void barrier_release(int *lock, bool reset = false) { +__device__ inline void barrier_release(int* lock, bool reset = false) { __syncthreads(); if (threadIdx.x == 0) { if (reset) { @@ -129,4 +133,4 @@ __device__ inline void barrier_release(int *lock, bool reset = false) { : "l"(lock), "r"(val)); } } -} // namespace marlin_24 +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/common/mma.h b/csrc/quantization/marlin/sparse/common/mma.h index 9319456677d36..45ab67a78a1de 100644 --- a/csrc/quantization/marlin/sparse/common/mma.h +++ b/csrc/quantization/marlin/sparse/common/mma.h @@ -22,51 +22,56 @@ namespace marlin_24 { // m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. -__device__ inline void mma_sp(const FragB &a_frag0, const FragB &a_frag1, - const FragA &frag_b, FragC &frag_c, FragM &frag_m, +__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, + const FragA& frag_b, FragC& frag_c, FragM& frag_m, const int psel) { - const uint32_t *a0 = reinterpret_cast(&a_frag0); - const uint32_t *a1 = reinterpret_cast(&a_frag1); - const uint32_t *b = reinterpret_cast(&frag_b); - const uint32_t *e = reinterpret_cast(&frag_m); - float *c = reinterpret_cast(&frag_c); + const uint32_t* a0 = reinterpret_cast(&a_frag0); + const uint32_t* a1 = reinterpret_cast(&a_frag1); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* e = reinterpret_cast(&frag_m); + float* c = reinterpret_cast(&frag_c); if (psel == 0) { - asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), - "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), - "f"(c[2]), "f"(c[3]), "r"(e[0])); - asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), - "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), - "f"(c[6]), "f"(c[7]), "r"(e[0])); + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]), + "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]), + "r"(e[0])); + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]), + "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]), + "r"(e[0])); } else { - asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), - "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), - "f"(c[2]), "f"(c[3]), "r"(e[0])); - asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), - "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), - "f"(c[6]), "f"(c[7]), "r"(e[0])); + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]), + "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]), + "r"(e[0])); + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]), + "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]), + "r"(e[0])); } } // Lookup-table based 3-input logical operation; explicitly used for // dequantization as the compiler does not seem to automatically recognize it in // all cases. -template __device__ inline int lop3(int a, int b, int c) { +template +__device__ inline int lop3(int a, int b, int c) { int res; asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) @@ -120,11 +125,11 @@ __device__ inline FragB dequant_4bit(int q) { const int ADD = 0xd480d480; FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); return frag_b; } @@ -143,24 +148,24 @@ __device__ inline FragB dequant_8bit(int q) { static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); return frag_b; } // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. -__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]); +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } -__device__ inline void scale_floats(float *c0, float *c1, float *c2, float *c3, - FragS &s0, float *c4, float *c5, float *c6, - float *c7, FragS &s1) { +__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3, + FragS& s0, float* c4, float* c5, float* c6, + float* c7, FragS& s1) { *c0 = __fmul_rn(*c0, __half2float(s0[0].x)); *c1 = __fmul_rn(*c1, __half2float(s0[0].y)); *c2 = __fmul_rn(*c2, __half2float(s0[1].x)); @@ -172,4 +177,4 @@ __device__ inline void scale_floats(float *c0, float *c1, float *c2, float *c3, *c7 = __fmul_rn(*c7, __half2float(s1[1].y)); } -} // namespace marlin_24 +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index 42b0566183a8d..54ad27676e207 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -32,12 +32,15 @@ #else -#include "common/mem.h" -#include "common/mma.h" + #include "common/mem.h" + #include "common/mma.h" #endif -template inline std::string str(T x) { return std::to_string(x); } +template +inline std::string str(T x) { + return std::to_string(x); +} namespace marlin_24 { @@ -45,7 +48,7 @@ namespace marlin_24 { // than 1 warp per schedule allows some more latency hiding. At the same time, // we want relatively few warps to have many registers per warp and small tiles. static constexpr int THREADS = 256; -static constexpr int STAGES = 4; // 4 pipeline stages fit into shared memory +static constexpr int STAGES = 4; // 4 pipeline stages fit into shared memory static constexpr int min_thread_n = 128; @@ -54,35 +57,36 @@ static constexpr int max_par = 16; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > __global__ void Marlin_24( - const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - const int4 - *__restrict__ meta, // 2bit metadata information about 2:4 format on B - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 - *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4* __restrict__ meta, // 2bit metadata information about 2:4 + // format on B + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) {} -torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_meta, - torch::Tensor &b_scales, - torch::Tensor &workspace, int64_t num_bits, +torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k) { TORCH_CHECK_NOT_IMPLEMENTED( @@ -92,29 +96,30 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, #else -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > __global__ void Marlin_24( - const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - const int4 - *__restrict__ meta, // 2bit metadata information about 2:4 format on B - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 - *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4* __restrict__ meta, // 2bit metadata information about 2:4 + // format on B + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * @@ -174,27 +179,22 @@ __global__ void Marlin_24( auto init_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) - slice_iters = 0; - if (slice_iters == 0) - return; - if (slice_row + slice_iters > k_tiles) - slice_iters = k_tiles - slice_row; + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); if (col_first <= k_tiles * (slice_col_par + 1)) { int col_off = col_first - k_tiles * slice_col_par; slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) - slice_count++; + if (col_off > 0) slice_count++; int delta_first = iters * blockIdx.x - col_first; if (delta_first < 0 || (col_off == 0 && delta_first == 0)) slice_idx = slice_count - 1; else { slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) - slice_idx--; + if (col_off > 0) slice_idx--; } } if (slice_col == n_tiles) { @@ -207,7 +207,7 @@ __global__ void Marlin_24( init_slice(); // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory // stride of an A matrix tile in shared memory constexpr int a_sh_stride = 32 * thread_k_blocks / 8; @@ -239,9 +239,9 @@ __global__ void Marlin_24( constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 + int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 constexpr int m_sh_stride = - (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp + (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks; int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride); constexpr int m_sh_wr_delta = threads / 2; @@ -305,7 +305,7 @@ __global__ void Marlin_24( // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; } @@ -325,13 +325,13 @@ __global__ void Marlin_24( // loop unrolls, all shared memory accesses are static, we simply precompute // both transformed reads and writes. int a_sh_wr_trans[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks]; -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < thread_m_blocks; j++) { a_sh_rd_trans[0][i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); @@ -344,23 +344,23 @@ __global__ void Marlin_24( // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4 *B_ptr[b_sh_wr_iters]; -#pragma unroll + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta; - const int4 *meta_ptr[m_sh_iters]; -#pragma unroll + const int4* meta_ptr[m_sh_iters]; + #pragma unroll for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. - int4 *sh_a = sh; - int4 *sh_b = sh_a + (stages * a_sh_stage); - int4 *sh_s = sh_b + (stages * b_sh_stage); - int4 *sh_m = sh_s + (stages * s_sh_stage); + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); + int4* sh_m = sh_s + (stages * s_sh_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks][2]; I4 frag_b_quant[2][b_thread_vecs]; @@ -370,46 +370,43 @@ __global__ void Marlin_24( // Zero accumulators. auto zero_accums = [&]() { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; + reinterpret_cast(frag_c)[i] = 0; }; // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { cp_async4_pred( &sh_a_stage[a_sh_wr_trans[i]], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], a_sh_wr_pred[i]); } - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; -#pragma unroll + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], - B_ptr[i] + j); + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); } B_ptr[i] += b_gl_rd_delta_o; } - int4 *sh_meta_stage = sh_m + m_sh_stage * pipe; -#pragma unroll + int4* sh_meta_stage = sh_m + m_sh_stage * pipe; + #pragma unroll for (int i = 0; i < m_sh_iters; i++) { if (m_sh_wr_pred) - cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], - meta_ptr[i]); + cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]); meta_ptr[i] += m_gl_rd_delta_o; } // Only fetch scales if this tile starts a new group if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) - cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); s_gl_rd += s_gl_rd_delta; } } @@ -436,13 +433,13 @@ __global__ void Marlin_24( // theoretically better attempts have lead to bad instruction ordering by // the compiler and correspondingly a noticeable drop in performance. if (group_blocks != -1) { - int4 *sh_s_stage = + int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; } - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { ldsm4(frag_a[k % 2][i][0], &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]); @@ -450,24 +447,24 @@ __global__ void Marlin_24( &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]); } - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; -#pragma unroll + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( + frag_b_quant[k % 2][i] = *reinterpret_cast( &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); } // Load meta with ldsm4 - int4 *sh_m_stage = sh_m + m_sh_stage * pipe; + int4* sh_m_stage = sh_m + m_sh_stage * pipe; ldsm4_m(frag_m[k % 2][0], &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]); }; // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { -// We have the m dimension as the inner loop in order to encourage overlapping -// dequantization and matmul operations. -#pragma unroll + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll for (int j = 0; j < 4; j++) { FragB frag_b0; FragB frag_b1; @@ -480,7 +477,7 @@ __global__ void Marlin_24( frag_b1 = dequant_4bit(b_quant_shift); } else { - int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; @@ -497,7 +494,7 @@ __global__ void Marlin_24( scale(frag_b1, frag_s[k % 2][j], 1); } -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0], frag_m[k % 2][j / 2], j % 2); @@ -518,41 +515,41 @@ __global__ void Marlin_24( int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); -// Parallel logarithmic shared memory reduction. We make sure to avoid any -// unnecessary read or write iterations, e.g., for two warps we write only once -// by warp 1 and read only once by warp 0. -#pragma unroll + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + #pragma unroll for (int m_block = 0; m_block < thread_m_blocks; m_block++) { -#pragma unroll + #pragma unroll for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4 * 2; j++) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { - float *c_rd = reinterpret_cast( - &sh[red_sh_delta * j + red_sh_rd]); - float *c_wr = reinterpret_cast(&sh[red_sh_wr]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { -#pragma unroll + #pragma unroll for (int i = 0; i < 4 * 2; i++) { - float *c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; } } @@ -562,9 +559,9 @@ __global__ void Marlin_24( }; // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped partitioning - // minimizes the number of such reductions and our outputs are usually rather - // small, we perform this reduction serially in L2 cache. + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. auto global_reduce = [&](bool first = false, bool last = false) { // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out @@ -574,7 +571,7 @@ __global__ void Marlin_24( int c_gl_stride = prob_n / 8; int c_gl_wr_delta_o = 2 * 4 * c_gl_stride; int c_gl_wr_delta_i = - c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) + c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) + 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4; c_gl_wr += (2 * thread_n_blocks) * slice_col; @@ -584,10 +581,10 @@ __global__ void Marlin_24( int col = 2 * ((threadIdx.x % 32) % 4); if (!first) { -// Interestingly, doing direct global accesses here really seems to mess up the -// compiler and lead to slowdowns, hence we also use async-copies even though -// these fetches are not actually asynchronous. -#pragma unroll + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + @@ -599,32 +596,32 @@ __global__ void Marlin_24( cp_async_wait<0>(); } -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + col + (i % 2) < prob_m) { if (!first) { int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; -#pragma unroll + #pragma unroll for (int j2 = 0; j2 < 2; j2++) { -#pragma unroll + #pragma unroll for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast( + reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + 4 * ((i % 4) / 2) + i % 2] += __half2float( - reinterpret_cast<__half *>(&c_red)[(j2 * 4 + j1)]); + reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]); } } } if (!last) { int4 c; -#pragma unroll + #pragma unroll for (int j2 = 0; j2 < 2; j2++) { -#pragma unroll + #pragma unroll for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast<__half *>(&c)[(j2 * 4 + j1)] = - __float2half(reinterpret_cast( + reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] = + __float2half(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + 4 * ((i % 4) / 2) + i % 2]); } @@ -643,9 +640,9 @@ __global__ void Marlin_24( auto write_result = [&]() { int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: - constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: - constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: + constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: + constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: + constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); @@ -654,22 +651,22 @@ __global__ void Marlin_24( c_gl_wr += (2 * thread_n_blocks) * slice_col; int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) + - ((threadIdx.x % 32) / 4); // RLC: - c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) + ((threadIdx.x % 32) / 4); // RLC: + c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) constexpr int c_sh_rd_delta = - c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: + c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) + (threadIdx.x % (2 * 2 * thread_n_blocks)); int c_gl_wr_end = c_gl_stride * prob_m; - auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS &s0, - float c4, float c5, float c6, float c7, FragS &s1) { + auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0, + float c4, float c5, float c6, float c7, FragS& s1) { uint2 res[2]; res[0] = to_half4(c0, c1, c2, c3); res[1] = to_half4(c4, c5, c6, c7); - half2 *tmp = (half2 *)&res; + half2* tmp = (half2*)&res; // for per-column quantization we finally apply the scale here if constexpr (group_blocks == -1 && num_bits == 4) { tmp[0] = __hmul2(tmp[0], s0[0]); @@ -677,12 +674,12 @@ __global__ void Marlin_24( tmp[2] = __hmul2(tmp[2], s1[0]); tmp[3] = __hmul2(tmp[3], s1[1]); } - ((int4 *)sh)[idx] = *((int4 *)&res[0]); + ((int4*)sh)[idx] = *((int4*)&res[0]); }; // RLC: only warp 0 and 1 baseline example if (threadIdx.x / 32 < thread_n_blocks / 4) { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { int wr = c_sh_wr; write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0], @@ -707,7 +704,7 @@ __global__ void Marlin_24( } __syncthreads(); -#pragma unroll + #pragma unroll for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { @@ -721,9 +718,8 @@ __global__ void Marlin_24( // Start global fetch and register load pipelines. auto start_pipes = [&]() { -#pragma unroll - for (int i = 0; i < stages - 1; i++) - fetch_to_shared(i, i, i < slice_iters); + #pragma unroll + for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); zero_accums(); wait_for_stage(); fetch_to_registers(0, 0); @@ -733,10 +729,10 @@ __global__ void Marlin_24( // Main loop. while (slice_iters) { -// We unroll over both the global fetch and the register load pipeline to ensure -// all shared memory accesses are static. Note that both pipelines have even -// length meaning that the next iteration will always start at index 0. -#pragma unroll + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines have + // even length meaning that the next iteration will always start at index 0. + #pragma unroll for (int pipe = 0; pipe < stages;) { fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); @@ -747,8 +743,7 @@ __global__ void Marlin_24( pipe++; slice_iters--; - if (slice_iters == 0) - break; + if (slice_iters == 0) break; } a_gl_rd += a_gl_rd_delta_o * stages; @@ -762,13 +757,11 @@ __global__ void Marlin_24( // write-out if constexpr (group_blocks == -1) { if constexpr (num_bits == 8) { - if (s_sh_wr_pred) - cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); cp_async_fence(); } else { if (last) { - if (s_sh_wr_pred) - cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); cp_async_fence(); } } @@ -780,14 +773,14 @@ __global__ void Marlin_24( cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - *(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]); + *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); } } else { if (last) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - *(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]); + *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); } } } @@ -798,7 +791,7 @@ __global__ void Marlin_24( // overflow in fp16) if constexpr (group_blocks == -1 && num_bits == 8) { if (threadIdx.x / 32 < thread_n_blocks / 4) { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0], &frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0], @@ -827,13 +820,13 @@ __global__ void Marlin_24( } } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); barrier_release(&locks[slice_col], last); } - if (last) // only the last block in a slice actually writes the result + if (last) // only the last block in a slice actually writes the result write_result(); slice_row = 0; @@ -843,19 +836,17 @@ __global__ void Marlin_24( if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; -#pragma unroll + #pragma unroll for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles; if (slice_col == 0) { -#pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] -= b_gl_stride; -#pragma unroll - for (int i = 0; i < m_sh_iters; i++) - meta_ptr[i] -= m_gl_stride; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + #pragma unroll + for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride; } s_gl_rd = s_sh_stride * slice_col + threadIdx.x; start_pipes(); @@ -866,26 +857,26 @@ __global__ void Marlin_24( #endif -#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, GROUP_BLOCKS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS) { \ - cudaFuncSetAttribute( \ - Marlin_24, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin_24 \ - <<>>(A_ptr, B_ptr, meta_ptr, \ - C_ptr, s_ptr, prob_n, \ - prob_m, prob_k, locks); \ +#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, GROUP_BLOCKS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS) { \ + cudaFuncSetAttribute( \ + Marlin_24, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin_24 \ + <<>>(A_ptr, B_ptr, meta_ptr, \ + C_ptr, s_ptr, prob_n, \ + prob_m, prob_k, locks); \ } -void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, - void *s, int prob_m, int prob_n, int prob_k, - void *workspace, int num_bits, int groupsize = -1, +void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, + void* s, int prob_m, int prob_n, int prob_k, + void* workspace, int num_bits, int groupsize = -1, int dev = 0, cudaStream_t stream = 0, int thread_k = -1, int thread_m = -1, int sms = -1, int max_par = 16) { int tot_n = prob_n; @@ -904,8 +895,8 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, if (thread_k == -1 || thread_m == -1) { if (prob_n <= 16) { - // For small batchizes, better partitioningif is slightly more important than - // better compute utilization + // For small batchizes, better partitioningif is slightly more important + // than better compute utilization thread_k = 128; thread_m = 128; } else { @@ -914,7 +905,7 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, } } - int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction + int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction int thread_m_blocks = thread_m / 16; int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; int blocks = sms; @@ -931,13 +922,13 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); - const int4 *A_ptr = (const int4 *)A; - const int4 *B_ptr = (const int4 *)B; - const int4 *meta_ptr = (const int4 *)meta; - int4 *C_ptr = (int4 *)C; - const int4 *s_ptr = (const int4 *)s; + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + const int4* meta_ptr = (const int4*)meta; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; - int *locks = (int *)workspace; + int* locks = (int*)workspace; for (int i = 0; i < tot_n_blocks; i += 4) { int thread_n_blocks = tot_n_blocks - i; prob_n = tot_n - 16 * i; @@ -946,8 +937,7 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, // Note that parallel > 1 currently only works for inputs without any // padding par = (16 * thread_n_blocks - pad) / 64; - if (par > max_par) - par = max_par; + if (par > max_par) par = max_par; prob_n = 64 * par; i += 4 * (par - 1); thread_n_blocks = 4; @@ -956,16 +946,16 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, // For compilation speed, we only define the kernel configurations that have // seemed useful (in terms of performance) in our testing, however many more // are, in principle, possible. - + // the false is start of the CALL_IF macros - if (false) { - } // BMxBNxBK, group + if (false) { + } // BMxBNxBK, group // 4-bit - CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 - CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 - CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 CALL_IF_2_4(4, 16, 2, 2, 4) CALL_IF_2_4(4, 16, 3, 2, -1) CALL_IF_2_4(4, 16, 3, 2, 4) @@ -973,11 +963,11 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, CALL_IF_2_4(4, 16, 4, 2, 4) // 8-bit - CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 - CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 - CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 CALL_IF_2_4(8, 16, 2, 2, 4) CALL_IF_2_4(8, 16, 3, 2, -1) CALL_IF_2_4(8, 16, 3, 2, 4) @@ -997,12 +987,12 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, } } -} // namespace marlin_24 +} // namespace marlin_24 -torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_meta, - torch::Tensor &b_scales, - torch::Tensor &workspace, int64_t num_bits, +torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k) { // Verify num_bits @@ -1037,9 +1027,9 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, " is not divisible by tile_size = " + str(marlin_24::tile_size)); int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor; - TORCH_CHECK(size_n == actual_size_n, - "size_n = " + str(size_n) + - ", actual_size_n = " + str(actual_size_n)); + TORCH_CHECK( + size_n == actual_size_n, + "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); // Verify meta TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2, @@ -1081,7 +1071,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, ", is not divisible by b_scales.size(0) = " + str(b_scales.size(0))); groupsize = size_k / b_scales.size(0); - groupsize /= 2; // Because of 24 + groupsize /= 2; // Because of 24 } // Verify groupsize diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu index 09964903622b4..1b339fa4b392b 100644 --- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu +++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu @@ -22,27 +22,23 @@ __device__ inline unsigned int as_unsigned(int i) { // 4-bit matvec kernel (LUT-based) __global__ void NUQ4MatMulKernel( #ifndef USE_ROCM - const half2* __restrict__ vec, + const half2* __restrict__ vec, #else - const __half2* __restrict__ vec, + const __half2* __restrict__ vec, #endif - const int* __restrict__ mat, + const int* __restrict__ mat, #ifndef USE_ROCM - half2* __restrict__ mul, + half2* __restrict__ mul, #else - float2* __restrict__ mul, + float2* __restrict__ mul, #endif - const __half* __restrict__ lookup_table, - int height, - int width, - int batch, - int vec_height -) { + const __half* __restrict__ lookup_table, int height, int width, int batch, + int vec_height) { const int blockwidth2 = BLOCKWIDTH / 2; int row = BLOCKHEIGHT4 * blockIdx.x; - int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; + int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; #ifndef USE_ROCM __shared__ half2 blockvec[blockwidth2]; @@ -73,14 +69,16 @@ __global__ void NUQ4MatMulKernel( unsigned int tmp1; unsigned int lut_index1, lut_index2; - for (int b = 0; b < batch; ++b){ + for (int b = 0; b < batch; ++b) { i = width * row + col; res = __int2half_rd(0); k = 0; __syncthreads(); if (threadIdx.x < blockwidth2) - blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x]; + blockvec[threadIdx.x] = + vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + + threadIdx.x]; __syncthreads(); while (k < blockwidth2) { @@ -143,7 +141,8 @@ __global__ void NUQ4MatMulKernel( #ifndef USE_ROCM res = __hadd(__hadd(res2.x, res2.y), res); #else - res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res); + res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), + res); #endif i += width; @@ -179,46 +178,38 @@ __global__ void NUQ4MatMulKernel( } } -} // namespace squeezellm -} // namespace vllm +} // namespace squeezellm +} // namespace vllm // 4-bit matvec kernel (LUT-based) -void squeezellm_gemm( - torch::Tensor vec, - torch::Tensor mat, - torch::Tensor mul, - torch::Tensor lookup_table -) { +void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor lookup_table) { int height = mat.size(0); int width = mat.size(1); int batch = vec.size(0); int vec_height = vec.size(1); - dim3 blocks( - (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, - (width + BLOCKWIDTH - 1) / BLOCKWIDTH - ); + dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH); dim3 threads(BLOCKWIDTH); const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); vllm::squeezellm::NUQ4MatMulKernel<<>>( #ifndef USE_ROCM - (half2*) vec.data(), + (half2*)vec.data(), #else - (__half2*) vec.data_ptr(), + (__half2*)vec.data_ptr(), #endif - mat.data_ptr(), + mat.data_ptr(), #ifndef USE_ROCM - (half2*) mul.data(), - (__half*) lookup_table.data(), + (half2*)mul.data(), (__half*)lookup_table.data(), #else - (float2*) mul.data_ptr(), - (__half*) lookup_table.data_ptr(), + (float2*)mul.data_ptr(), + (__half*)lookup_table.data_ptr(), #endif - height, width, batch, vec_height - ); + height, width, batch, vec_height); } #undef BLOCKWIDTH diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index bb5171f854d55..9af4aae516151 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -20,12 +21,12 @@ #include "cuda_compat.h" namespace vllm { -template +template __inline__ __device__ T warpReduceSum(T val) { static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0, "numLanes is not a positive power of 2!"); static_assert(numLanes <= WARP_SIZE); - #pragma unroll +#pragma unroll for (int mask = numLanes >> 1; mask > 0; mask >>= 1) val += VLLM_SHFL_XOR_SYNC(val, mask); return val; @@ -38,22 +39,23 @@ static constexpr int _nextPow2(unsigned int num) { } /* Calculate the sum of all elements in a block */ -template +template __inline__ __device__ T blockReduceSum(T val) { static_assert(maxBlockSize <= 1024); if constexpr (maxBlockSize > WARP_SIZE) { val = warpReduceSum(val); - // Calculates max number of lanes that need to participate in the last warpReduce + // Calculates max number of lanes that need to participate in the last + // warpReduce constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE; static __shared__ T shared[maxActiveLanes]; int lane = threadIdx.x % WARP_SIZE; int wid = threadIdx.x / WARP_SIZE; - if (lane == 0) - shared[wid] = val; + if (lane == 0) shared[wid] = val; __syncthreads(); - val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f); + val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] + : (T)(0.0f); val = warpReduceSum(val); } else { // A single warpReduce is equal to blockReduce @@ -62,4 +64,4 @@ __inline__ __device__ T blockReduceSum(T val) { return val; } -} // namespace vllm +} // namespace vllm diff --git a/format.sh b/format.sh index 5f6e20256d404..aaec25a8aa0dc 100755 --- a/format.sh +++ b/format.sh @@ -26,6 +26,7 @@ RUFF_VERSION=$(ruff --version | awk '{print $2}') MYPY_VERSION=$(mypy --version | awk '{print $2}') CODESPELL_VERSION=$(codespell --version) ISORT_VERSION=$(isort --vn) +CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}') # # params: tool name, tool version, required version tool_version_check() { @@ -40,6 +41,7 @@ tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)" tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-dev.txt | cut -d'=' -f3)" tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)" +tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-dev.txt | cut -d'=' -f3)" YAPF_FLAGS=( '--recursive' @@ -179,7 +181,6 @@ lint_changed() { } # Run Ruff -echo 'vLLM ruff:' ### This flag lints individual files. --files *must* be the first command line ### arg to use this option. if [[ "$1" == '--files' ]]; then @@ -192,6 +193,7 @@ else # Format only the files that changed in last commit. lint_changed fi +echo 'vLLM ruff: Done' # check spelling of specified files isort_check() { @@ -233,6 +235,59 @@ else fi echo 'vLLM isort: Done' +# Clang-format section +# Exclude some files for formatting because they are vendored +# NOTE: Keep up to date with .github/workflows/clang-format.yml +CLANG_FORMAT_EXCLUDES=( + 'csrc/moe/topk_softmax_kernels.cu' + 'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu' + 'csrc/punica/bgmv/bgmv_config.h' + 'csrc/punica/bgmv/bgmv_impl.cuh' + 'csrc/punica/bgmv/vec_dtypes.cuh' + 'csrc/punica/punica_ops.cu' + 'csrc/punica/type_convert.h' +) + +# Format specified files with clang-format +clang_format() { + clang-format -i "$@" +} + +# Format files that differ from main branch with clang-format. +clang_format_changed() { + # The `if` guard ensures that the list of filenames is not empty, which + # could cause clang-format to receive 0 positional arguments, making it hang + # waiting for STDIN. + # + # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that + # exist on both branches. + MERGEBASE="$(git merge-base origin/main HEAD)" + + # Get the list of changed files, excluding the specified ones + changed_files=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.h' '*.cpp' '*.cu' '*.cuh' | grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}")) + if [ -n "$changed_files" ]; then + echo "$changed_files" | xargs -P 5 clang-format -i + fi +} + +# Format all files with clang-format +clang_format_all() { + find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ + | grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") \ + | xargs clang-format -i +} + +# Run clang-format +if [[ "$1" == '--files' ]]; then + clang_format "${@:2}" +elif [[ "$1" == '--all' ]]; then + clang_format_all +else + clang_format_changed +fi +echo 'vLLM clang-format: Done' + + if ! git diff --quiet &>/dev/null; then echo 'Reformatted files. Please review and stage the changes.' echo 'Changes not staged for commit:' diff --git a/requirements-dev.txt b/requirements-dev.txt index 4f6c27d95fe6a..cf2bb9bef22d9 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,6 +5,7 @@ tomli==2.0.1 ruff==0.1.5 codespell==2.2.6 isort==5.13.2 +clang-format==18.1.5 # type checking mypy==1.9.0