Dynamic group blocks in Marlin MoE #11

Draft · wants to merge 6 commits into base: main
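At its core this PR is a mechanical refactor of the kernel in the diff below: `group_blocks` stops being a compile-time template parameter (one kernel instantiation per value, selected with `if constexpr`) and becomes a runtime kernel argument (checked with a plain `if` inside a single instantiation). A minimal standalone sketch of that pattern, with hypothetical names that are not taken from this codebase:

```cpp
#include <cstdio>

// Before: group_blocks is a template parameter, so every supported value
// needs its own kernel instantiation and its own dispatch branch.
template <bool has_act_order, int group_blocks = -1>
void kernel_static() {
  if constexpr (!has_act_order && group_blocks != -1) {
    std::printf("grouped scales, group_blocks=%d\n", group_blocks);
  }
}

// After: group_blocks is a runtime argument; the constexpr guard keeps only
// the act_order split, and the group_blocks checks move into ordinary ifs.
template <bool has_act_order>
void kernel_dynamic(int group_blocks) {
  if constexpr (!has_act_order) {
    if (group_blocks != -1) {
      std::printf("grouped scales, group_blocks=%d\n", group_blocks);
    }
  }
}

int main() {
  kernel_static<false, 4>();  // distinct instantiation per group_blocks value
  kernel_dynamic<false>(4);   // one instantiation serves 2, 4, 8, ...
  kernel_dynamic<false>(-1);  // per-column scales: inner branch not taken
  return 0;
}
```

The tradeoff, as the macro changes at the end of the diff reflect, is fewer template instantiations (smaller binary, faster builds) in exchange for runtime branching; since `group_blocks` is uniform across a kernel launch, those branches should be warp-coherent.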
154 changes: 79 additions & 75 deletions csrc/moe/marlin_kernels/marlin_moe_kernel.h
@@ -247,9 +247,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
- const bool has_act_order, // whether act_order is enabled
- const int group_blocks = -1 // number of consecutive 16x16 blocks
- // with a separate quantization scale
+ const bool has_act_order // whether act_order is enabled
>
__device__ inline void MarlinMoESingle(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@@ -261,6 +259,8 @@ __device__ inline void MarlinMoESingle(
// (k/groupsize)xn
const int* __restrict__ g_idx, // int32 group indices of shape k
const int* __restrict__ expert_offsets,
+ int group_blocks, // number of consecutive 16x16 blocks
+ // with a separate quantization scale
int num_groups, // number of scale groups per output channel
int expert_idx, // idx of current expert
int num_experts, // number of experts
@@ -289,8 +289,8 @@ __device__ inline void MarlinMoESingle(
int n_tiles = prob_n / 16 / thread_n_blocks;
int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);

- if constexpr (!has_act_order && group_blocks != -1) {
- if (group_blocks >= thread_k_blocks) {
+ if constexpr (!has_act_order) {
+ if (group_blocks != -1 && group_blocks >= thread_k_blocks) {
// Ensure that the number of tiles in each stripe is a multiple of the
// groupsize; this avoids an annoying special case where a stripe starts
// in the middle of group.
@@ -384,11 +384,11 @@ __device__ inline void MarlinMoESingle(
// Scale sizes/strides without act_order
int s_gl_stride = prob_n / 8;
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
- constexpr int s_tb_groups =
+ int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
: 1;
- constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
+ int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;
// Scale size/strides with act_order
constexpr int tb_k = 16 * thread_k_blocks;
@@ -432,7 +432,7 @@ __device__ inline void MarlinMoESingle(
// No act_order
int s_gl_rd;
if constexpr (!has_act_order) {
- if constexpr (group_blocks == -1) {
+ if (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
@@ -446,7 +446,7 @@ __device__ inline void MarlinMoESingle(
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int s_sh_rd;
- if constexpr (group_blocks != -1)
+ if (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
else
@@ -612,10 +612,10 @@ __device__ inline void MarlinMoESingle(
}
}
} else {
- if constexpr (group_blocks != -1) {
+ if (group_blocks != -1) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;

- if constexpr (group_blocks >= thread_k_blocks) {
+ if (group_blocks >= thread_k_blocks) {
// Only fetch scales if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (s_sh_wr_pred) {
@@ -703,8 +703,8 @@ __device__ inline void MarlinMoESingle(

if constexpr (!has_act_order) {
// No act-order case
- if constexpr (group_blocks != -1) {
- if constexpr (group_blocks >= thread_k_blocks) {
+ if (group_blocks != -1) {
+ if (group_blocks >= thread_k_blocks) {
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
@@ -824,7 +824,7 @@ __device__ inline void MarlinMoESingle(
scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0);
} else {
- if constexpr (group_blocks != -1) {
+ if (group_blocks != -1) {
scale(frag_b0, frag_s[k % 2][j], 0);
}
}
@@ -835,7 +835,7 @@ __device__ inline void MarlinMoESingle(
act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1);

} else {
- if constexpr (group_blocks != -1) {
+ if (group_blocks != -1) {
scale(frag_b1, frag_s[k % 2][j], 1);
}
}
@@ -1009,9 +1009,10 @@ __device__ inline void MarlinMoESingle(

// For per-column quantization we finally apply the scale here (only for
// 4-bit)
- if constexpr (!has_act_order && group_blocks == -1 &&
- w_type.size_bits() == 4) {
- res = __hmul2(res, s[0]);
+ if constexpr (!has_act_order && w_type.size_bits() == 4) {
+ if (group_blocks == -1) {
+ res = __hmul2(res, s[0]);
+ }
}

((half2*)sh)[idx] = res;
@@ -1140,52 +1141,59 @@ __device__ inline void MarlinMoESingle(
if (slice_iters == 0) {
cp_async_wait<0>();
bool last = slice_idx == slice_count - 1;
- if constexpr (!has_act_order && group_blocks == -1) {
+ if constexpr (!has_act_order) {
if constexpr (w_type.size_bits() == 8) {
- if (s_sh_wr_pred) {
- cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
- }
- cp_async_fence();
- } else {
- // For 4-bit per-column scales, we only fetch them here in the
- // final step before write-out
- if (last) {
+ if (group_blocks == -1) {
+ if (s_sh_wr_pred) {
+ cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
+ }
+ cp_async_fence();
+ }
+ } else {
+ if (group_blocks == -1) {
+ // For 4-bit per-column scales, we only fetch them here in the
+ // final step before write-out
+ if (last) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
cp_async_fence();
}
+ }
}
}

thread_block_reduce();
- if constexpr (!has_act_order && group_blocks == -1) {
+ if constexpr (!has_act_order) {
if constexpr (w_type.size_bits() == 8) {
- cp_async_wait<0>();
- __syncthreads();
- if (threadIdx.x / 32 < thread_n_blocks / 4) {
- reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
- reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
- }
-
- } else {
- if (last) {
+ if (group_blocks == -1) {
+ cp_async_wait<0>();
+ __syncthreads();
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
+ reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
+ reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
+ }
+ }
+
+ } else {
+ if (group_blocks == -1) {
+ if (last) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
}
}
+ }
}
}

// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
- if constexpr (!has_act_order && group_blocks == -1 &&
- w_type.size_bits() == 8) {
- if (threadIdx.x / 32 < thread_n_blocks / 4) {
+ if constexpr (!has_act_order && w_type.size_bits() == 8) {
+ if (group_blocks == -1 && threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
@@ -1249,9 +1257,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
- const bool has_act_order, // whether act_order is enabled
- const int group_blocks = -1 // number of consecutive 16x16 blocks
- // with a separate quantization scale
+ const bool has_act_order // whether act_order is enabled
>
__global__ void MarlinMoE(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@@ -1263,6 +1269,8 @@ __global__ void MarlinMoE(
// (k/groupsize)xn
const int* __restrict__ g_idx, // int32 group indices of shape k
const int* __restrict__ expert_offsets,
+ int group_blocks, // number of consecutive 16x16 blocks
+ // with a separate quantization scale
int num_groups, // number of scale groups per output channel
int expert_idx, // idx of current expert
int num_experts, // number of experts
@@ -1309,31 +1317,31 @@ __global__ void MarlinMoE(

if (max_block == 1) {
MarlinMoESingle<w_type_id, threads, 1, thread_n_blocks, thread_k_blocks,
- stages, has_act_order, group_blocks>(
+ stages, has_act_order>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
- expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
- prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
+ expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk,
+ prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block);
} else if (max_block == 2) {
MarlinMoESingle<w_type_id, threads, 2, thread_n_blocks, thread_k_blocks,
- stages, has_act_order, group_blocks>(
+ stages, has_act_order>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
- expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
- prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
+ expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk,
+ prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block);
} else if (max_block == 3) {
MarlinMoESingle<w_type_id, threads, 3, thread_n_blocks, thread_k_blocks,
- stages, has_act_order, group_blocks>(
+ stages, has_act_order>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
- expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
- prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
+ expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk,
+ prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block);
} else {
MarlinMoESingle<w_type_id, threads, 4, thread_n_blocks, thread_k_blocks,
- stages, has_act_order, group_blocks>(
+ stages, has_act_order>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
- expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
- prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
+ expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk,
+ prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block);
}
}
@@ -1346,9 +1354,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
- const bool has_act_order, // whether act_order is enabled
- const int group_blocks = -1 // number of consecutive 16x16 blocks
- // with a separate quantization scale
+ const bool has_act_order // whether act_order is enabled
>
__global__ void MarlinMoE(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@@ -1360,6 +1366,8 @@ __global__ void MarlinMoE(
// (k/groupsize)xn
const int* __restrict__ g_idx, // int32 group indices of shape k
const int* __restrict__ expert_offsets,
+ int group_blocks, // number of consecutive 16x16 blocks
+ // with a separate quantization scale
int num_groups, // number of scale groups per output channel
int expert_idx, // idx of current expert
int num_experts, // number of experts
@@ -1396,30 +1404,26 @@ static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \
- GROUP_BLOCKS, NUM_THREADS) \
+ NUM_THREADS) \
else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
- has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
- num_threads == NUM_THREADS) { \
- cudaFuncSetAttribute( \
- MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
- STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>, \
- cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
+ has_act_order == HAS_ACT_ORDER && num_threads == NUM_THREADS) { \
+ cudaFuncSetAttribute(MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, \
+ THREAD_K_BLOCKS, STAGES, HAS_ACT_ORDER>, \
+ cudaFuncAttributeMaxDynamicSharedMemorySize, \
+ max_shared_mem); \
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
- STAGES, HAS_ACT_ORDER, GROUP_BLOCKS> \
+ STAGES, HAS_ACT_ORDER> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
- g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
- num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
- replicate_input, apply_weights, m_block, max_par, \
+ g_idx_ptr, expert_offsets_ptr, group_blocks, num_groups, \
+ expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, \
+ locks, replicate_input, apply_weights, m_block, max_par, \
cfg_max_m_blocks); \
}

- #define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
- __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
- __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
- __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
- __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
- __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
+ #define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
+ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, NUM_THREADS) \
+ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, NUM_THREADS)

} // namespace marlin_moe
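With `GROUP_BLOCKS` removed from the dispatch macro above, each `GPTQ_CALL_IF_MOE` expansion shrinks from five `__CALL_IF_MOE` arms (the act-order case plus `group_blocks` of -1, 2, 4, and 8) to two, and the host passes `group_blocks` straight through as a kernel argument. A hedged sketch of how a caller might derive that runtime value, assuming the usual Marlin convention that each 16x16 block spans 16 rows of k (so group sizes 32/64/128 map to `group_blocks` 2/4/8, and -1 means per-column scales); the helper name is hypothetical, not part of this PR:

```cpp
#include <cassert>

// Hypothetical helper: map a quantization group size to the runtime
// group_blocks argument now taken by MarlinMoE.
int compute_group_blocks(int group_size, int prob_k) {
  if (group_size == -1) return -1;  // per-column (channelwise) scales
  assert(group_size % 16 == 0);     // each block covers 16 rows of k
  assert(prob_k % group_size == 0); // groups must tile the k dimension
  return group_size / 16;           // 32 -> 2, 64 -> 4, 128 -> 8
}
```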