From b312ad59529108ac9f66441be9cb142af66fafd1 Mon Sep 17 00:00:00 2001
From: Eliza Wszola
Date: Thu, 23 May 2024 13:13:47 +0000
Subject: [PATCH] Fix shared memory issue

---
 csrc/moe/marlin_moe_ops.cu | 160 ++++++++++++++++---------------------
 1 file changed, 71 insertions(+), 89 deletions(-)

diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu
index 0d2ff16c920e3..5791dd832d089 100644
--- a/csrc/moe/marlin_moe_ops.cu
+++ b/csrc/moe/marlin_moe_ops.cu
@@ -551,6 +551,7 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
   constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
   constexpr int s_tb_groups =
       group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1;
+  // printf("B: %d / %d\n", group_blocks, thread_k_blocks);
   constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
   int s_gl_rd_delta = s_gl_stride;
@@ -613,7 +614,9 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
   int sh_num_groups = -1;
   constexpr int sh_max_num_groups = 32;

-  int shs_size = stages * s_sh_stage > 0 ? stages * s_sh_stage : -stages * s_sh_stage;
+  // int shs_size = sh_max_num_groups * s_sh_stride + threads;
+  int shs_size = group_blocks > 0 ? stages * s_sh_stage : threads;
+  // printf("SHS size: %d\n", shs_size);

   extern __shared__ int4 sh[];
   // Shared memory storage for global fetch pipelines.
@@ -622,6 +625,8 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
   int4* sh_s = sh_b + (stages * b_sh_stage);
   int* sh_sorted = (int*)(sh_s + shs_size);

+  // printf("sh arrays: %d %d %d\n", stages * a_sh_stage, stages * b_sh_stage, shs_size);
+
   // Precompute which thread should not read memory in which iterations; this is needed if there are
   // more threads than required for a certain tilesize or when the batchsize is not a multiple
   // of 16.
@@ -689,37 +694,6 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
     reinterpret_cast<float*>(frag_c)[i] = 0;
   };

-  auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) {
-    sh_first_group_id = first_group_id;
-    sh_num_groups = last_group_id - first_group_id + 1;
-
-    if (sh_num_groups < sh_max_num_groups) {
-      sh_num_groups = sh_max_num_groups;
-    }
-
-    if (sh_first_group_id + sh_num_groups > num_groups) {
-      sh_num_groups = num_groups - sh_first_group_id;
-    }
-
-    int row_offset = first_group_id * s_gl_stride;
-
-    if (is_async) {
-      for (int i = 0; i < sh_num_groups; i++) {
-        if (threadIdx.x < s_sh_stride) {
-          cp_async4_pred(
-              &sh_s[(i * s_sh_stride) + threadIdx.x],
-              &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]);
-        }
-      }
-    } else {
-      for (int i = 0; i < sh_num_groups; i++) {
-        if (threadIdx.x < s_sh_stride) {
-          sh_s[(i * s_sh_stride) + threadIdx.x] =
-              scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x];
-        }
-      }
-    }
-  };

   // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline
   // location.
   auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
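
The heart of the fix is the `shs_size` expression above: with channelwise quantization (group_blocks == -1), `s_sh_stage` is derived from a negative `s_tb_groups`, and the old code merely took the absolute value of that meaningless size. Since `sh_sorted` is carved out of the same dynamic shared memory block right after the scales, a wrong `shs_size` shifts it relative to what the host allocated. A minimal sketch of the partitioning scheme follows; all constants (STAGES, A_STAGE, B_STAGE, S_STAGE, THREADS) are placeholders, not the kernel's real values:

    // partition_sketch.cu -- carving one extern __shared__ block into sub-arrays.
    #include <cstdio>
    #include <cuda_runtime.h>

    constexpr int STAGES  = 4;    // pipeline depth (placeholder)
    constexpr int A_STAGE = 128;  // int4s of A per stage (placeholder)
    constexpr int B_STAGE = 256;  // int4s of B per stage (placeholder)
    constexpr int S_STAGE = 8;    // int4s of scales per stage (placeholder)
    constexpr int THREADS = 256;

    __global__ void partition(int group_blocks) {
      extern __shared__ int4 sh[];
      // Scales are staged per pipeline stage only for grouped quantization;
      // for channelwise scales a single thread-indexed row is enough.
      int shs_size = group_blocks > 0 ? STAGES * S_STAGE : THREADS;
      int4* sh_a = sh;
      int4* sh_b = sh_a + STAGES * A_STAGE;
      int4* sh_s = sh_b + STAGES * B_STAGE;
      int* sh_sorted = (int*)(sh_s + shs_size);
      if (threadIdx.x == 0)
        printf("sh_sorted starts at int4 offset %ld\n", (long)((int4*)sh_sorted - sh));
    }

    int main() {
      // The host must size the launch to cover every region, sh_sorted included.
      size_t smem = (STAGES * (A_STAGE + B_STAGE) + THREADS) * sizeof(int4) + 1024;
      partition<<<1, THREADS, smem>>>(-1);  // channelwise case
      return (int)cudaDeviceSynchronize();
    }
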
@@ -731,7 +705,7 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
         int row = a_idx / a_gl_stride;
         int sorted_row = sh_sorted[row];
         int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride;
-        printf("LOAD ROW: %d -> %d\n", row, sorted_row);
+        // printf("LOAD ROW: %d -> %d\n", row, sorted_row);
         cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx],
                        a_sh_wr_pred[i]);
@@ -773,10 +747,11 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
     const int mpt = ceildiv(prob_m, threads);
     for (int i = 0; i < mpt; i++) {
       if ((i * sorted_gl_stride) + threadIdx.x < prob_m) {
-        // printf("load %d -> %d to shared sorted, eid: %d\n", (i * sorted_gl_stride) + threadIdx.x,
-        //        sorted_ids[(i * sorted_gl_stride) + threadIdx.x], expert_idx);
         sh_sorted[(i * sorted_sh_stride) + threadIdx.x] =
             sorted_ids[(i * sorted_gl_stride) + threadIdx.x];
+        // printf("load %d -> %d to shared sorted, eid: %d (%d / %d)\n", (i * sorted_gl_stride) + threadIdx.x,
+        //        sorted_ids[(i * sorted_gl_stride) + threadIdx.x], expert_idx, (i * sorted_sh_stride) + threadIdx.x,
+        //        sh_sorted[(i * sorted_sh_stride) + threadIdx.x]);
       }
     }
   };
@@ -1039,18 +1014,23 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
 #pragma unroll
     for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {
       if (c_gl_wr < c_gl_wr_end) {
-        int row = sh_sorted[c_gl_wr / c_gl_stride];
-        int off = row * c_gl_stride + c_gl_wr % c_gl_stride;
-        __half* ctrg = reinterpret_cast<__half*>(&C[off]);
-        // HERE we read from sh, how is the access different?
-        __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]);
-
-        for (int j = 0; j < 8; ++j) {
-          __half old = ctrg[j];
-          ctrg[j] = __float2half(__half2float(old) + __half2float(csrc[j]));
+        if (c_gl_wr / c_gl_stride < prob_m) {
+          int row = sh_sorted[c_gl_wr / c_gl_stride];
+          int off = row * c_gl_stride + c_gl_wr % c_gl_stride;
+          __half* ctrg = reinterpret_cast<__half*>(&C[off]);
+          // HERE we read from sh, how is the access different?
+          __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]);
+          // printf("c offset: %d at row %d from %d (%d %d)\n", off, row, c_gl_wr / c_gl_stride, threadIdx.x, blockIdx.x);
+          for (int j = 0; j < 8; ++j) {
+            // printf("csrc %f\n", __half2float(csrc[j]));
+            // printf("ctrg %f\n", __half2float(ctrg[j]));
+            // printf("csrc %f, ctrg %f\n", __half2float(csrc[j]), __half2float(ctrg[j]));
+            __half old = ctrg[j];
+            ctrg[j] = __float2half(__half2float(old) + __half2float(csrc[j]));
+          }
+          c_gl_wr += c_gl_wr_delta;
+          c_sh_rd += c_sh_rd_delta;
         }
-        c_gl_wr += c_gl_wr_delta;
-        c_sh_rd += c_sh_rd_delta;
       }
     }
   };
@@ -1115,6 +1095,7 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
     if constexpr (group_blocks == -1) {
       if (last) {
         if (s_sh_wr_pred) {
+          // printf("COPY SCALES HERE\n");
           cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
         }
         cp_async_fence();
@@ -1132,7 +1113,6 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
         }
       }
     }
-
     if (slice_count > 1) { // only globally reduce if there is more than one block in a slice
       barrier_acquire(&locks[slice_col], slice_idx);
       global_reduce(slice_idx == 0, last);
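
The global_reduce change above wraps the accumulation in `if (c_gl_wr / c_gl_stride < prob_m)`: with MoE block padding, the write cursor can point at a padded row that has no entry in `sh_sorted`, and the old code would index the sort table out of bounds. Below is a standalone sketch of the same guarded half-precision read-modify-write through a row-sort table; `accumulate_rows`, `rows_used`, and friends are illustrative names, not the kernel's:

    // reduce_sketch.cu -- guarded fp16 accumulate through a row-sort table.
    #include <cstdio>
    #include <cuda_fp16.h>

    __global__ void accumulate_rows(__half* C, const int* sorted, const __half* src,
                                    int rows_used, int row_len) {
      int r = blockIdx.x;   // logical (possibly padded) row
      int c = threadIdx.x;  // column
      if (r < rows_used && c < row_len) {  // padded rows have no sorted id: skip
        int row = sorted[r];               // indirection through the sort table
        __half old = C[row * row_len + c];
        C[row * row_len + c] =
            __float2half(__half2float(old) + __half2float(src[r * row_len + c]));
      }
    }

    int main() {
      const int rows_used = 2, padded = 4, row_len = 8;
      __half *C, *src; int* sorted;
      cudaMallocManaged(&C, padded * row_len * sizeof(__half));
      cudaMallocManaged(&src, padded * row_len * sizeof(__half));
      cudaMallocManaged(&sorted, padded * sizeof(int));
      for (int i = 0; i < padded * row_len; i++) {
        C[i] = __float2half(1.f);
        src[i] = __float2half(2.f);
      }
      sorted[0] = 1; sorted[1] = 0;  // only the first rows_used entries are valid
      accumulate_rows<<<padded, row_len>>>(C, sorted, src, rows_used, row_len);
      cudaDeviceSynchronize();
      printf("C[0] = %f\n", __half2float(C[0]));  // expect 3.0
      return 0;
    }
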
@@ -1200,11 +1180,10 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
 // 8 warps are a good choice since every SM has 4 schedulers and having more
 // than 1 warp per schedule allows some more latency hiding. At the same time,
 // we want relatively few warps to have many registers per warp and small tiles.
-const int USER_THREADS =
-    256; // Note: This is only used with user-provided thread_k/n
+const int USER_THREADS = 256; // Note: This is only used with user-provided thread_k/n
 const int STAGES = 4; // 4 pipeline stages fit into shared memory
-const int SHARED_MEM =
-    96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
+// const int SHARED_MEM =
+//     96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)

 static constexpr int min_thread_n = 64;
 static constexpr int min_thread_k = 64;
@@ -1222,9 +1201,10 @@ static constexpr int pack_factor_4bit =
       group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {              \
     cudaFuncSetAttribute(MarlinMoE<THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
-                         cudaFuncAttributeMaxDynamicSharedMemorySize, SHARED_MEM); \
+                         cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
+    printf("%d %d %d %d %d\n", THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS); \
     MarlinMoE<THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES,         \
-        GROUP_BLOCKS><<<blocks, NUM_THREADS, SHARED_MEM, stream>>>(              \
+        GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>(          \
         A_ptr, B_ptr, C_ptr, sorted_ids_ptr, s_ptr, red_tmp_ptr,                 \
         num_groups, num_tokens_post_padded, expert_idx,                          \
         prob_m, prob_n, prob_k, tot_m, locks);                                   \
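
For context on the SHARED_MEM removal and the new `max_shared_mem` above: CUDA caps dynamic shared memory at 48 KB per block unless the kernel is explicitly opted in, and the hard-coded 96 KB both exceeds what some parts offer and wastes capacity on newer ones, so the macro now uses the runtime-queried limit. A self-contained example of the query-then-opt-in pattern (`my_kernel` is a stand-in, not the Marlin kernel):

    // optin_smem.cu -- raising a kernel's dynamic shared memory limit at runtime.
    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void my_kernel() {
      extern __shared__ int4 sh[];
      if (threadIdx.x == 0) sh[0] = make_int4(0, 0, 0, 0);
    }

    int main() {
      int dev = 0, max_shared_mem = 0;
      cudaDeviceGetAttribute(&max_shared_mem,
                             cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
      // Without this attribute, any launch requesting more than 48 KB of
      // dynamic shared memory fails at launch.
      cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                           max_shared_mem);
      my_kernel<<<1, 32, max_shared_mem>>>();
      printf("launch: %s (opt-in max: %d bytes)\n",
             cudaGetErrorString(cudaGetLastError()), max_shared_mem);
      return (int)cudaDeviceSynchronize();
    }
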
@@ -1308,36 +1288,36 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
   return thread_config_t{-1, -1, -1};
 }

+// #define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS)    \
+//   __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
+//   __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS)  \
+//   __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS)  \
+//   __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
+//                                                         \
+//   __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
+//   __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS)  \
+//   __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS)  \
+//   __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
+//                                                         \
+//   __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
+//   __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS)  \
+//   __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS)  \
+//   __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
+//                                                         \
+//   __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
+//   __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS)  \
+//   __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS)  \
+//   __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)
+
 #define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS)    \
-  __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
-  __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS)  \
-  __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS)  \
-  __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
-                                                        \
-  __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
-  __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS)  \
-  __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS)  \
-  __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
-                                                        \
-  __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
-  __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS)  \
-  __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS)  \
-  __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
-                                                        \
-  __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
-  __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS)  \
-  __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS)  \
-  __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)
-
-
-// // #define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS)  \
-// //   __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
-// //   __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
-// //   __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
-// //   __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)
+  __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
+  __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
+  __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
+  __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)

 void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, void* sorted_ids, void* s,
-                         void* a_tmp, void* red_tmp, int prob_m, int prob_n, int prob_k, void* workspace,
+                         void* red_tmp, int prob_m, int prob_n, int prob_k, void* workspace,
                          int num_groups, int group_size, int num_tokens_post_padded,
                          int num_experts, int moe_block_size, int dev, cudaStream_t stream,
                          int thread_k, int thread_n, int sms, int max_par) {
@@ -1402,6 +1382,11 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, void* sorted_ids
                 " is not divisible by group_blocks = ", group_blocks);
   }

+  int max_shared_mem = 0;
+  cudaDeviceGetAttribute(&max_shared_mem,
+                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+  TORCH_CHECK(max_shared_mem > 0);
+
   // ZeroOutput<<<1, num_threads>>>((int4*)C, tot_m, prob_n);

   // TODO get scales size accurately
@@ -1416,7 +1401,6 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, void* sorted_ids
   const int4* s_ptr = (const int4*)s;// + (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * prob_n / 8) * expert_idx;
   // const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
   // const int* perm_ptr = (const int*)perm;
-  int4* a_tmp_ptr = (int4*)a_tmp;
   int4* red_tmp_ptr = (int4*)red_tmp;
   // const __half* shalf = reinterpret_cast<const __half*>(s_ptr);
@@ -1475,9 +1459,9 @@ torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torc
                               int64_t size_m, int64_t size_n, int64_t size_k/*, bool is_k_full,
                               int64_t num_tokens_post_padded, int64_t num_experts, int64_t moe_block_size*/) {
-    int version = 0;
-    cudaRuntimeGetVersion(&version);
-    printf("cuda v: %d\n", version);
+  int version = 0;
+  cudaRuntimeGetVersion(&version);
+  printf("cuda v: %d\n", version);

   int64_t num_tokens_post_padded = 64;
   int64_t num_experts = 1;
@@ -1488,13 +1472,10 @@ torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torc
   int dev = a.get_device();
   auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
-  // TODO make this torch.empty when we have a way to figure out how the kernel knows that we
-  // write to a C row for the first time
   torch::Tensor c = torch::zeros({size_m, size_n}, options);
-  torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
   torch::Tensor red_tmp = torch::empty({size_m, size_n}, options);

-  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as auto -1)
+  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as auto -1)
   int thread_k = -1;
   // thread_n: `n` size of a thread_tile in `weights` (can usually be left as auto -1)
   int thread_n = -1;
@@ -1510,6 +1491,7 @@ torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torc
   TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
               " is not size_n = ", size_n);
   num_groups = b_scales.size(1);
+  // printf("NUM GROUPS: %d\n", num_groups);

   if (num_groups > 1) {
     TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
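
For readers new to the `__CALL_IF_MOE` machinery trimmed above: each expansion contributes one `else if` arm that matches the runtime tile configuration against a concrete template instantiation, so cutting the GROUP_BLOCKS = 2/4/8 variants drops twelve kernel instantiations per (N_BLOCKS, K_BLOCKS, NUM_THREADS) combination, at the cost of supporting only GROUP_BLOCKS == -1 for now. A plain C++ sketch of the dispatch pattern; `run`, `dispatch`, and `CALL_IF` are illustrative names:

    // dispatch_sketch.cc -- runtime parameters selecting template instantiations.
    #include <cstdio>

    template <int M_BLOCKS, int G_BLOCKS>
    void run(int prob_m) {
      std::printf("instantiation: m_blocks=%d group_blocks=%d (prob_m=%d)\n",
                  M_BLOCKS, G_BLOCKS, prob_m);
    }

    // Each expansion is one `else if` arm; every arm forces the compiler to
    // emit one more specialization, which is why trimming the macro shrinks
    // build time and binary size.
    #define CALL_IF(M_BLOCKS, G_BLOCKS)                        \
      else if (m_blocks == M_BLOCKS && g_blocks == G_BLOCKS) { \
        run<M_BLOCKS, G_BLOCKS>(prob_m);                       \
      }

    void dispatch(int m_blocks, int g_blocks, int prob_m) {
      if (false) {
      }
      CALL_IF(1, -1)
      CALL_IF(2, -1)
      CALL_IF(3, -1)
      CALL_IF(4, -1)
      else {
        std::printf("unsupported config\n");
      }
    }

    int main() { dispatch(2, -1, 64); }
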
@@ -1520,7 +1502,7 @@ torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torc
   }

   marlin::marlin_mm_moe_f16i4(a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(),
                               sorted_ids.data_ptr(), b_scales.data_ptr(),
-                              a_tmp.data_ptr(), red_tmp.data_ptr(), size_m, size_n, size_k,
+                              red_tmp.data_ptr(), size_m, size_n, size_k,
                               workspace.data_ptr(), num_groups, group_size, num_tokens_post_padded,
                               num_experts, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par);
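
One note on the buffers this call site hands to the kernel: `c` has to stay `torch::zeros` (the removed TODO above acknowledged this) because `global_reduce` accumulates into C with half-precision read-modify-writes, so rows must start at zero; `red_tmp` appears to be pure scratch, and `a_tmp` is dropped from the interface entirely now that nothing reads it. A small libtorch sketch of that host-side setup, assuming a CUDA build and an available device; sizes are placeholders:

    // alloc_sketch.cc -- output vs. scratch allocation for an accumulating kernel.
    #include <torch/torch.h>
    #include <iostream>

    int main() {
      int64_t size_m = 64, size_n = 128;
      auto options = torch::TensorOptions().dtype(torch::kFloat16)
                                           .device(torch::kCUDA);
      // Zero-initialized: a safe target for "+=" style writes from the kernel.
      torch::Tensor c = torch::zeros({size_m, size_n}, options);
      // Scratch for the inter-block reduction: written before it is read,
      // so uninitialized memory is fine.
      torch::Tensor red_tmp = torch::empty({size_m, size_n}, options);
      std::cout << c.sizes() << " " << red_tmp.sizes() << std::endl;
      return 0;
    }
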