This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Fix shared memory issue
ElizaWszola committed May 23, 2024
1 parent 42750bc commit b312ad5
Showing 1 changed file with 71 additions and 89 deletions.
160 changes: 71 additions & 89 deletions csrc/moe/marlin_moe_ops.cu
@@ -551,6 +551,7 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups =
group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1;
// printf("B: %d / %d\n", group_blocks, thread_k_blocks);
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;

@@ -613,7 +614,9 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
int sh_num_groups = -1;
constexpr int sh_max_num_groups = 32;

int shs_size = stages * s_sh_stage > 0 ? stages * s_sh_stage : -stages * s_sh_stage;
// int shs_size = sh_max_num_groups * s_sh_stride + threads;
int shs_size = group_blocks > 0 ? stages * s_sh_stage : threads;
// printf("SHS size: %d\n", shs_size);

extern __shared__ int4 sh[];
// Shared memory storage for global fetch pipelines.
@@ -622,6 +625,8 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
int4* sh_s = sh_b + (stages * b_sh_stage);
int* sh_sorted = (int*)(sh_s + shs_size);

// printf("sh arrays: %d %d %d\n", stages * a_sh_stage, stages * b_sh_stage, shs_size);

// Precompute which thread should not read memory in which iterations; this is needed if there are
// more threads than required for a certain tilesize or when the batchsize is not a multiple
// of 16.
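The kernel carves a single dynamically sized extern __shared__ buffer into staging regions for the A tiles, B tiles, dequantization scales, and sorted token ids; with this commit the scale region is sized stages * s_sh_stage only when group_blocks > 0 and falls back to threads entries otherwise. A minimal sketch of that carving, with hypothetical runtime parameters standing in for the kernel's constexpr strides:

__global__ void carve_shared_demo(int stages, int a_stage, int b_stage,
                                  int s_stage, bool has_group_scales) {
  extern __shared__ int4 sh[];
  // Grouped scales keep stages * s_stage entries; per-channel scales
  // (group_blocks == -1) reserve only one entry per thread, mirroring the
  // updated shs_size computation above.
  int shs_size = has_group_scales ? stages * s_stage : (int)blockDim.x;

  int4* sh_a = sh;                        // A tiles (assumed to start at the base of sh, as in upstream Marlin)
  int4* sh_b = sh_a + stages * a_stage;   // B tiles for all pipeline stages
  int4* sh_s = sh_b + stages * b_stage;   // dequantization scales
  int* sh_sorted = reinterpret_cast<int*>(sh_s + shs_size);  // sorted token ids

  if (threadIdx.x == 0) {                 // touch each region so the sketch compiles cleanly
    sh_a[0] = make_int4(0, 0, 0, 0);
    sh_b[0] = make_int4(0, 0, 0, 0);
    sh_s[0] = make_int4(0, 0, 0, 0);
    sh_sorted[0] = 0;
  }
}

The dynamic shared allocation requested at launch has to cover all four regions, which is what the device shared-memory limit handled later in this diff constrains.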
@@ -689,37 +694,6 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
reinterpret_cast<float*>(frag_c)[i] = 0;
};

auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) {
sh_first_group_id = first_group_id;
sh_num_groups = last_group_id - first_group_id + 1;

if (sh_num_groups < sh_max_num_groups) {
sh_num_groups = sh_max_num_groups;
}

if (sh_first_group_id + sh_num_groups > num_groups) {
sh_num_groups = num_groups - sh_first_group_id;
}

int row_offset = first_group_id * s_gl_stride;

if (is_async) {
for (int i = 0; i < sh_num_groups; i++) {
if (threadIdx.x < s_sh_stride) {
cp_async4_pred(
&sh_s[(i * s_sh_stride) + threadIdx.x],
&scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]);
}
}
} else {
for (int i = 0; i < sh_num_groups; i++) {
if (threadIdx.x < s_sh_stride) {
sh_s[(i * s_sh_stride) + threadIdx.x] =
scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x];
}
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline
// location.
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
@@ -731,7 +705,7 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
int row = a_idx / a_gl_stride;
int sorted_row = sh_sorted[row];
int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride;
printf("LOAD ROW: %d -> %d\n", row, sorted_row);
// printf("LOAD ROW: %d -> %d\n", row, sorted_row);
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]],
&A[new_idx],
a_sh_wr_pred[i]);
@@ -773,10 +747,11 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int mpt = ceildiv(prob_m, threads);
for (int i = 0; i < mpt; i++) {
if ((i * sorted_gl_stride) + threadIdx.x < prob_m) {
// printf("load %d -> %d to shared sorted, eid: %d\n", (i * sorted_gl_stride) + threadIdx.x,
// sorted_ids[(i * sorted_gl_stride) + threadIdx.x], expert_idx);
sh_sorted[(i * sorted_sh_stride) + threadIdx.x] =
sorted_ids[(i * sorted_gl_stride) + threadIdx.x];
// printf("load %d -> %d to shared sorted, eid: %d (%d / %d)\n", (i * sorted_gl_stride) + threadIdx.x,
// sorted_ids[(i * sorted_gl_stride) + threadIdx.x], expert_idx, (i * sorted_sh_stride) + threadIdx.x,
// sh_sorted[(i * sorted_sh_stride) + threadIdx.x]);
}
}
};
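The loop above stages the sorted token-id table into shared memory with a thread-strided copy, and the A-tile loads earlier in the diff then rewrite each row index through that table (sorted_row = sh_sorted[row]). A simplified, synchronous sketch of the same gather, using plain loads instead of the kernel's cp_async4_pred pipeline and int4 vectorization:

// Hypothetical standalone kernel; prob_m * sizeof(int) bytes of dynamic
// shared memory are assumed at launch.
__global__ void gather_rows_demo(const float* A, float* out,
                                 const int* sorted_ids, int prob_m, int prob_k) {
  extern __shared__ int sh_sorted[];

  // Thread-strided copy of the permutation table into shared memory.
  for (int i = threadIdx.x; i < prob_m; i += blockDim.x)
    sh_sorted[i] = sorted_ids[i];
  __syncthreads();

  // Each element load rewrites its row index through the table.
  for (int idx = threadIdx.x; idx < prob_m * prob_k; idx += blockDim.x) {
    int row = idx / prob_k;
    int col = idx % prob_k;
    int src_row = sh_sorted[row];   // expert-sorted source row
    out[idx] = A[src_row * prob_k + col];
  }
}

A launch such as gather_rows_demo<<<1, 256, prob_m * sizeof(int)>>>(A, out, sorted_ids, prob_m, prob_k) would exercise it.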
@@ -1039,18 +1014,23 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
#pragma unroll
for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {
if (c_gl_wr < c_gl_wr_end) {
int row = sh_sorted[c_gl_wr / c_gl_stride];
int off = row * c_gl_stride + c_gl_wr % c_gl_stride;
__half* ctrg = reinterpret_cast<__half*>(&C[off]);
// HERE we read from sh, how is the access different?
__half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]);

for (int j = 0; j < 8; ++j) {
__half old = ctrg[j];
ctrg[j] = __float2half(__half2float(old) + __half2float(csrc[j]));
if (c_gl_wr / c_gl_stride < prob_m) {
int row = sh_sorted[c_gl_wr / c_gl_stride];
int off = row * c_gl_stride + c_gl_wr % c_gl_stride;
__half* ctrg = reinterpret_cast<__half*>(&C[off]);
// HERE we read from sh, how is the access different?
__half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]);
// printf("c offset: %d at row %d from %d (%d %d)\n", off, row, c_gl_wr / c_gl_stride, threadIdx.x, blockIdx.x);
for (int j = 0; j < 8; ++j) {
// printf("csrc %f\n", __half2float(csrc[j]));
// printf("ctrg %f\n", __half2float(ctrg[j]));
// printf("csrc %f, ctrg %f\n", __half2float(csrc[j]), __half2float(ctrg[j]));
__half old = ctrg[j];
ctrg[j] = __float2half(__half2float(old) + __half2float(csrc[j]));
}
c_gl_wr += c_gl_wr_delta;
c_sh_rd += c_sh_rd_delta;
}
c_gl_wr += c_gl_wr_delta;
c_sh_rd += c_sh_rd_delta;
}
}
};
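The new c_gl_wr / c_gl_stride < prob_m check is the functional change in this hunk: threads whose output row falls past the real problem size no longer read an unwritten sh_sorted entry or write outside the valid rows of C. A minimal sketch of the guarded fp16 read-modify-write (names are illustrative, not the kernel's actual variables):

#include <cuda_fp16.h>

// Accumulate an 8-wide fp16 fragment from shared memory into a permuted row
// of C, skipping rows at or beyond prob_m.
__device__ void guarded_fp16_accumulate(__half* C, const __half* csrc,
                                        const int* sh_sorted, int c_gl_wr,
                                        int c_gl_stride, int prob_m) {
  int logical_row = c_gl_wr / c_gl_stride;
  if (logical_row >= prob_m) return;         // out-of-range rows do nothing

  int row = sh_sorted[logical_row];          // map back to the original token row
  __half* ctrg = C + row * c_gl_stride + c_gl_wr % c_gl_stride;
  for (int j = 0; j < 8; ++j) {
    // The addition is done in fp32, matching the __half2float round trip above.
    float sum = __half2float(ctrg[j]) + __half2float(csrc[j]);
    ctrg[j] = __float2half(sum);
  }
}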
@@ -1115,6 +1095,7 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
if constexpr (group_blocks == -1) {
if (last) {
if (s_sh_wr_pred) {
// printf("COPY SCALES HERE\n");
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
cp_async_fence();
@@ -1132,7 +1113,6 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
}
}
}

if (slice_count > 1) { // only globally reduce if there is more than one block in a slice
barrier_acquire(&locks[slice_col], slice_idx);
global_reduce(slice_idx == 0, last);
@@ -1200,11 +1180,10 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
const int USER_THREADS =
256; // Note: This is only used with user-provided thread_k/n
const int USER_THREADS = 256; // Note: This is only used with user-provided thread_k/n
const int STAGES = 4; // 4 pipeline stages fit into shared memory
const int SHARED_MEM =
96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
// const int SHARED_MEM =
// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)

static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
@@ -1222,9 +1201,10 @@ static constexpr int pack_factor_4bit =
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute(MarlinMoE<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
STAGES, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, SHARED_MEM); \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
printf("%d %d %d %d %d\n", THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS); \
MarlinMoE<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, \
GROUP_BLOCKS><<<blocks, NUM_THREADS, SHARED_MEM, stream>>>( \
GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, s_ptr, red_tmp_ptr, \
num_groups, num_tokens_post_padded, expert_idx, \
prob_m, prob_n, prob_k, tot_m, locks); \
@@ -1308,36 +1288,36 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
return thread_config_t{-1, -1, -1};
}

// #define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS) \
// __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
// __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS) \
// __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS) \
// __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
// \
// __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
// __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS) \
// __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS) \
// __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
// \
// __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
// __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS) \
// __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS) \
// __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
// \
// __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
// __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS) \
// __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS) \
// __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)


#define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS) \
__CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS) \
__CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
\
__CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS) \
__CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS) \
__CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
\
__CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS) \
__CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS) \
__CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
\
__CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS) \
__CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS) \
__CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)


// // #define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS) \
// // __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
// // __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
// // __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
// // __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)
__CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)

void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, void* sorted_ids, void* s,
void* a_tmp, void* red_tmp, int prob_m, int prob_n, int prob_k, void* workspace,
void* red_tmp, int prob_m, int prob_n, int prob_k, void* workspace,
int num_groups, int group_size,
int num_tokens_post_padded, int num_experts, int moe_block_size,
int dev, cudaStream_t stream, int thread_k, int thread_n, int sms, int max_par) {
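CALL_IF_MOE expands __CALL_IF_MOE over the supported thread_m_blocks / group_blocks combinations; each expansion (see the fragment around cudaFuncSetAttribute earlier in the diff) compares the runtime tile configuration against the template constants and launches the matching instantiation with the queried max_shared_mem. A stripped-down illustration of this dispatch pattern with a hypothetical kernel, not the real signature:

#include <cuda_runtime.h>

template <int THREADS, int M_BLOCKS, int GROUP_BLOCKS>
__global__ void __launch_bounds__(THREADS) moe_kernel_demo(const float* A,
                                                           float* C, int prob_m) {
  // ... tile work would go here ...
}

// One link in the dispatch chain: launch only the instantiation whose
// template constants match the runtime configuration.
#define DEMO_CALL_IF(M_BLOCKS, GROUP_BLOCKS, NUM_THREADS)                      \
  else if (thread_m_blocks == M_BLOCKS && group_blocks == GROUP_BLOCKS &&      \
           num_threads == NUM_THREADS) {                                       \
    cudaFuncSetAttribute(moe_kernel_demo<NUM_THREADS, M_BLOCKS, GROUP_BLOCKS>, \
                         cudaFuncAttributeMaxDynamicSharedMemorySize,          \
                         max_shared_mem);                                      \
    moe_kernel_demo<NUM_THREADS, M_BLOCKS, GROUP_BLOCKS>                       \
        <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(A, C, prob_m);       \
  }

void launch_demo(const float* A, float* C, int prob_m, int thread_m_blocks,
                 int group_blocks, int num_threads, int blocks,
                 int max_shared_mem, cudaStream_t stream) {
  if (false) {
  }
  DEMO_CALL_IF(1, -1, 256)
  DEMO_CALL_IF(2, -1, 256)
  DEMO_CALL_IF(1, 8, 256)
  DEMO_CALL_IF(2, 8, 256)
}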
Expand Down Expand Up @@ -1402,6 +1382,11 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, void* sorted_ids
" is not divisible by group_blocks = ", group_blocks);
}

int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);

// ZeroOutput<<<1, num_threads>>>((int4*)C, tot_m, prob_n);

// TODO get scales size accurately
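The hard-coded SHARED_MEM constant (commented out further up) is replaced by this runtime query of the device's opt-in per-block shared-memory limit; the __CALL_IF_MOE launch macro then passes max_shared_mem both to cudaFuncSetAttribute and to the launch configuration. A self-contained sketch of the same pattern with a hypothetical kernel; the per-architecture limits in the comment are approximate:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void smem_demo(float* out, int n) {
  extern __shared__ float smem[];
  for (int i = threadIdx.x; i < n; i += blockDim.x) smem[i] = (float)i;
  __syncthreads();
  for (int i = threadIdx.x; i < n; i += blockDim.x) out[i] = smem[i];
}

int main() {
  int dev = 0;
  int max_shared_mem = 0;
  // Opt-in per-block limit (roughly 99 KB on SM 8.6, 163 KB on SM 8.0),
  // larger than the 48 KB a kernel gets without opting in.
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  // Raise the kernel's dynamic shared-memory cap to the queried limit.
  cudaFuncSetAttribute(smem_demo, cudaFuncAttributeMaxDynamicSharedMemorySize,
                       max_shared_mem);

  int n = max_shared_mem / (int)sizeof(float);
  float* out = nullptr;
  cudaMalloc(&out, n * sizeof(float));
  smem_demo<<<1, 256, max_shared_mem>>>(out, n);
  cudaDeviceSynchronize();
  printf("used %d bytes of dynamic shared memory\n", max_shared_mem);
  cudaFree(out);
  return 0;
}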
@@ -1416,7 +1401,6 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, void* sorted_ids
const int4* s_ptr = (const int4*)s;// + (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * prob_n / 8) * expert_idx;
// const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
// const int* perm_ptr = (const int*)perm;
int4* a_tmp_ptr = (int4*)a_tmp;
int4* red_tmp_ptr = (int4*)red_tmp;

// const __half* shalf = reinterpret_cast<const __half*>(s_ptr);
@@ -1475,9 +1459,9 @@ torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torc
int64_t size_m, int64_t size_n, int64_t size_k/*, bool is_k_full,
int64_t num_tokens_post_padded, int64_t num_experts, int64_t moe_block_size*/)
{
int version = 0;
cudaRuntimeGetVersion(&version);
printf("cuda v: %d\n", version);
int version = 0;
cudaRuntimeGetVersion(&version);
printf("cuda v: %d\n", version);

int64_t num_tokens_post_padded = 64;
int64_t num_experts = 1;
@@ -1488,13 +1472,10 @@ torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torc
int dev = a.get_device();

auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
// TODO make this torch.empty when we have a way to figure out how the kernel knows that we
// write to a C row for the first time
torch::Tensor c = torch::zeros({size_m, size_n}, options);
torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
torch::Tensor red_tmp = torch::empty({size_m, size_n}, options);

// thread_k: `k` size of a thread_tile in `weights` (can usually be left as auto -1)
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as auto -1)
int thread_k = -1;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as auto -1)
int thread_n = -1;
@@ -1510,6 +1491,7 @@ torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torc
TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
" is not size_n = ", size_n);
num_groups = b_scales.size(1);
// printf("NUM GROUPS: %d\n", num_groups);

if (num_groups > 1) {
TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
@@ -1520,7 +1502,7 @@ torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torc
}

marlin::marlin_mm_moe_f16i4(a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), b_scales.data_ptr(),
a_tmp.data_ptr(), red_tmp.data_ptr(), size_m, size_n, size_k,
red_tmp.data_ptr(), size_m, size_n, size_k,
workspace.data_ptr(), num_groups, group_size,
num_tokens_post_padded, num_experts, moe_block_size,
dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par);

1 comment on commit b312ad5

@github-actions

bigger_is_better

Benchmark suite | Current: b312ad5 | Previous: a10b831 | Ratio
{"name": "request_throughput", "description": "VLLM Engine throughput - synthetic\nmodel - NousResearch/Llama-2-7b-chat-hf\nmax_model_len - 4096\nbenchmark_throughput {\n \"use-all-available-gpus_\": \"\",\n \"input-len\": 256,\n \"output-len\": 128,\n \"num-prompts\": 1000\n}", "gpu_description": "NVIDIA A10G x 1", "vllm_version": "0.2.0", "python_version": "3.10.12 (main, May 10 2024, 13:42:25) [GCC 9.4.0]", "torch_version": "2.3.0+cu121"} | 3.839382272706773 prompts/s | |
{"name": "token_throughput", "description": "VLLM Engine throughput - synthetic\nmodel - NousResearch/Llama-2-7b-chat-hf\nmax_model_len - 4096\nbenchmark_throughput {\n \"use-all-available-gpus_\": \"\",\n \"input-len\": 256,\n \"output-len\": 128,\n \"num-prompts\": 1000\n}", "gpu_description": "NVIDIA A10G x 1", "vllm_version": "0.2.0", "python_version": "3.10.12 (main, May 10 2024, 13:42:25) [GCC 9.4.0]", "torch_version": "2.3.0+cu121"} | 1474.3227927194007 tokens/s | |

This comment was automatically generated by workflow using github-action-benchmark.
