
Commit: Test shapes
ElizaWszola committed May 27, 2024
1 parent 6765bde commit 0886f76
Showing 4 changed files with 93 additions and 254 deletions.
256 changes: 35 additions & 221 deletions csrc/moe/marlin_moe_ops.cu
@@ -195,204 +195,6 @@ __device__ inline void barrier_release(int* lock, bool reset = false) {
}
}


// // For a given "a" of size [M,K] performs a permutation of the K columns based on the given "perm"
// // indices.
// __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
// int const* __restrict__ perm_int_ptr,
// int4* __restrict__ out_int4_ptr, int size_m, int size_k,
// int block_rows) {

// int start_row = block_rows * blockIdx.x;
// int finish_row = start_row + block_rows;
// if (finish_row > size_m) {
// finish_row = size_m;
// }
// int cur_block_rows = finish_row - start_row;

// int row_stride = size_k * sizeof(half) / 16;

// auto permute_row = [&](int row) {
// int iters = size_k / default_threads;
// int rest = size_k % default_threads;

// int offset = row * row_stride;

// half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
// half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);

// int base_k = 0;

// for (int i = 0; i < iters; i++) {
// int cur_k = base_k + threadIdx.x;
// int src_pos = perm_int_ptr[cur_k];

// out_half[cur_k] = a_row_half[src_pos];

// base_k += default_threads;
// }

// if (rest) {
// if (threadIdx.x < rest) {
// int cur_k = base_k + threadIdx.x;
// int src_pos = perm_int_ptr[cur_k];

// out_half[cur_k] = a_row_half[src_pos];
// }
// }
// };

// for (int i = 0; i < cur_block_rows; i++) {
// int cur_row = start_row + i;
// if (cur_row < size_m) {
// permute_row(cur_row);
// }
// }
// }

// void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* g_idx, void* perm,
// void* a_tmp, int prob_m, int prob_n, int prob_k, void* workspace,
// bool has_act_order, bool is_k_full, int num_groups, int group_size, int dev,
// cudaStream_t stream, int thread_k, int thread_n, int sms, int max_par) {
// TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ",
// prob_k, "]");

// int tot_m = prob_m;
// int tot_m_blocks = ceildiv(tot_m, 16);
// int pad = 16 * tot_m_blocks - tot_m;

// if (sms == -1) {
// cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
// }

// // Set thread config
// thread_config_t th_config;
// if (thread_k != -1 && thread_n != -1) {
// // User-defined config
// th_config = thread_config_t{thread_k, thread_n, default_threads};
// } else {
// // Auto config
// th_config = determine_thread_config(prob_m, prob_n, prob_k);
// }

// TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k),
// "Invalid thread config: thread_k = " + str(th_config.thread_k) + ", thread_n = " +
// str(th_config.thread_n) + ", num_threads = " + str(th_config.num_threads) +
// " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]");

// int num_threads = th_config.num_threads;
// thread_k = th_config.thread_k;
// thread_n = th_config.thread_n;

// int thread_k_blocks = thread_k / 16;
// int thread_n_blocks = thread_n / 16;

// int blocks = sms;

// TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
// " is not divisible by thread_n = ", thread_n);
// TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
// " is not divisible by thread_k = ", thread_k);

// int group_blocks = 0;
// if (has_act_order) {
// if (is_k_full) {
// TORCH_CHECK(group_size != -1);
// group_blocks = group_size / 16;
// TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
// " is not divisible by group_blocks = ", group_blocks);
// } else {
// TORCH_CHECK(group_size == 0);
// group_blocks = 0;
// }

// } else {
// if (group_size == -1) {
// group_blocks = -1;
// } else {
// group_blocks = group_size / 16;
// TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
// " is not divisible by group_blocks = ", group_blocks);
// }
// }

// const int4* A_ptr = (const int4*)A;
// const int4* B_ptr = (const int4*)B;
// int4* C_ptr = (int4*)C;
// const int4* s_ptr = (const int4*)s;
// const int* g_idx_ptr = (const int*)g_idx;
// const int* perm_ptr = (const int*)perm;
// int4* a_tmp_ptr = (int4*)a_tmp;

// int* locks = (int*)workspace;

// if (has_act_order) {
// // Permute A columns
// int block_rows = ceildiv(prob_m, blocks);
// permute_cols_kernel<<<blocks, default_threads, 0, stream>>>(A_ptr, perm_ptr, a_tmp_ptr, prob_m,
// prob_k, block_rows);
// A_ptr = a_tmp_ptr;
// }

// // If we have a full K, then we can run the non-act-order version of Marlin (since the weight rows
// // are reordered by increasing group ids, and by having a full K, we have full original groups)
// if (is_k_full) {
// has_act_order = false;
// }

// // Main loop
// for (int i = 0; i < tot_m_blocks; i += 4) {
// int thread_m_blocks = tot_m_blocks - i;
// prob_m = tot_m - 16 * i;
// int par = 1;
// if (thread_m_blocks > 4) {
// // Note that parallel > 1 currently only works for inputs without any padding
// par = (16 * thread_m_blocks - pad) / 64;
// if (par > max_par)
// par = max_par;
// prob_m = 64 * par;
// i += 4 * (par - 1);
// thread_m_blocks = 4;
// }

// // Define kernel configurations
// if (false) {
// }
// CALL_IF(16, 4, 256)
// CALL_IF(8, 8, 256)
// CALL_IF(8, 4, 128)
// CALL_IF(4, 8, 128)
// else {
// TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " +
// str(prob_k) + "]" + ", has_act_order = " + str(has_act_order) +
// ", num_groups = " + str(num_groups) + ", group_size = " +
// str(group_size) + ", thread_m_blocks = " + str(thread_m_blocks) +
// ", thread_n_blocks = " + str(thread_n_blocks) +
// ", thread_k_blocks = " + str(thread_k_blocks));
// }

// A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
// C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
// }
// }

// // TODO make this batched and multiblock
// __global__ void
// ZeroOutput(int4* __restrict__ C, // fp16 output buffer of shape mxn
// int prob_m, // output dimension m
// int prob_n // output dimension n
// ) {
// int stride = blockDim.x;
// int size = prob_m * prob_n / 8;
// #pragma unroll
// for (int i = threadIdx.x; i < size; i += stride) {
// __half* ctrg = reinterpret_cast<__half*>(&C[i]);
// for (int j = 0; j < 8; ++j) {
// ctrg[j] = __float2half(0.0f);
// }
// }
// }

template <const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the
// threadblock
@@ -407,11 +209,14 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn /// TODO offset B to the beginning of right expert and use this as the func argument
int4* __restrict__ C, // fp16 output buffer of shape mxn
int* __restrict__ sorted_ids, // int32 sorted ids of experts
int* __restrict__ topk_ids, // int32 topk ids
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape (k/groupsize)xn
int4* __restrict__ red_tmp, // extra tmp buffer for computing reductions of shape moe_block_sizexn
int num_groups, // number of scale groups per output channel
int num_tokens_post_padded, // scales_ptrs size with padding
int expert_idx, // idx of current expert
int num_experts, // number of experts
int topk, // topk parameter of moe
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
@@ -703,12 +508,17 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
for (int i = 0; i < a_sh_wr_iters; i++) {
int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off;
int row = a_idx / a_gl_stride;
int sorted_row = sh_sorted[row];
int sorted_row = sh_sorted[row] / topk;
int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride;
// printf("LOAD ROW: %d -> %d\n", row, sorted_row);
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]],
&A[new_idx],
a_sh_wr_pred[i]);
if (sorted_row >= 0 && sorted_row < num_experts) {
// printf("LOAD ROW: %d -> %d -> %d\n", row, sh_sorted[row], sorted_row);
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]],
&A[new_idx],
a_sh_wr_pred[i]);
}
// else {
// printf("CAN'T LOAD ROW: %d -> %d -> %d\n", row, sh_sorted[row], sorted_row);
// }
}
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
@@ -1157,11 +967,14 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn /// TODO offset B to the beginning of right expert and use this as the func argument
int4* __restrict__ C, // fp16 output buffer of shape mxn
int* __restrict__ sorted_ids, // int32 sorted ids of experts
int* __restrict__ topk_ids, // int32 topk ids
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape (k/groupsize)xn
int4* __restrict__ red_tmp, // extra tmp buffer for computing reductions of shape moe_block_sizexn
int num_groups, // number of scale groups per output channel
int num_tokens_post_padded, // scales_ptrs size with padding
int expert_idx, // idx of current expert
int num_experts, // number of experts
int topk, // topk parameter of moe
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
@@ -1203,8 +1016,8 @@ static constexpr int pack_factor_4bit =
printf("%d %d %d %d %d\n", THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS); \
MarlinMoE<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, \
GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, s_ptr, red_tmp_ptr, \
num_groups, num_tokens_post_padded, expert_idx, \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_ids_ptr, s_ptr, red_tmp_ptr, \
num_groups, num_tokens_post_padded, expert_idx, num_experts, topk, \
prob_m, prob_n, prob_k, tot_m, locks); \
}

@@ -1314,10 +1127,10 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
__CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)

void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, void* sorted_ids, void* s,
void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, void* sorted_ids, void* topk_ids, void* s,
void* red_tmp, int prob_m, int prob_n, int prob_k, void* workspace,
int num_groups, int group_size,
int num_tokens_post_padded, int num_experts, int moe_block_size,
int num_tokens_post_padded, int num_experts, int topk, int moe_block_size,
int dev, cudaStream_t stream, int thread_k, int thread_n, int sms, int max_par) {

// #if defined(__CUDA_ARCH__)
Expand All @@ -1329,8 +1142,6 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, void* sorted_ids
// void* g_idx = (void*)nullptr;
// void* perm = (void*)nullptr;

// MarlinMoE<<<1, 1>>>();

TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ",
prob_k, "]");

@@ -1398,6 +1209,7 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, void* sorted_ids
const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx;
int4* C_ptr = (int4*)C;
int* sorted_ids_ptr = (int*)sorted_ids + moe_stride * expert_idx;
int* topk_ids_ptr = (int*)topk_ids; // TODO adjust
const int4* s_ptr = (const int4*)s + (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * prob_n / 8) * expert_idx;
// const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
// const int* perm_ptr = (const int*)perm;
@@ -1454,26 +1266,27 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, void* sorted_ids

} // namespace marlin

torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torch::Tensor& sorted_ids,
torch::Tensor& b_scales, /*torch::Tensor& g_idx, torch::Tensor& perm, */torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k/*, bool is_k_full,
int64_t num_tokens_post_padded, int64_t num_experts, int64_t moe_block_size*/)
torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torch::Tensor& sorted_ids, torch::Tensor& topk_ids,
torch::Tensor& b_scales, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
int64_t num_tokens_post_padded, int64_t num_experts, int64_t topk, int64_t moe_block_size)
{
int version = 0;
cudaRuntimeGetVersion(&version);
printf("cuda v: %d\n", version);

int64_t num_tokens_post_padded = 64;
int64_t num_experts = 1;
int64_t moe_block_size = 16;
// int64_t num_tokens_post_padded = 64;
// int64_t num_experts = 1;
// int64_t moe_block_size = 16;

// topk = 1; // TODO temporary

int max_par = 4;

int dev = a.get_device();

auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c = torch::zeros({size_m, size_n}, options);
torch::Tensor red_tmp = torch::empty({size_m, size_n}, options);
torch::Tensor c = torch::zeros({size_m, topk, size_n}, options);
torch::Tensor red_tmp = torch::empty({size_m, topk, size_n}, options);

// thread_k: `k` size of a thread_tile in `weights` (can usually be left as auto -1)
int thread_k = -1;
@@ -1487,7 +1300,7 @@ torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torc
int group_size = -1;

int b_rank = b_scales.sizes().size();
TORCH_CHECK(b_rank == 2 || b_rank == 3, "b_scales rank = ", b_rank, " is not 2");
TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3");
TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
" is not size_n = ", size_n);
num_groups = b_scales.size(1);
@@ -1501,10 +1314,11 @@ torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torc
group_size = -1;
}

marlin::marlin_mm_moe_f16i4(a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), b_scales.data_ptr(),
marlin::marlin_mm_moe_f16i4(a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(),
sorted_ids.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
red_tmp.data_ptr(), size_m, size_n, size_k,
workspace.data_ptr(), num_groups, group_size,
num_tokens_post_padded, num_experts, moe_block_size,
num_tokens_post_padded, num_experts, topk, moe_block_size,
dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par);
return c;
}
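Editor's note, not part of the diff above: the functional change in this file is that the kernel now receives topk_ids and topk, reads rows of A through sorted_row = sh_sorted[row] / topk, and writes into an output allocated as {size_m, topk, size_n} instead of {size_m, size_n}. The following is a minimal host-side C++ sketch of the index arithmetic this implies, assuming sorted_ids holds indices into the flattened size_m * topk token space; the concrete values are hypothetical and only illustrate the mapping.

// Illustrative sketch only: maps a flattened token index back to an input
// row of `a` and to an output slot of the {size_m, topk, size_n} buffer.
#include <cstdio>

int main() {
  const int size_m = 4;   // number of input tokens (hypothetical)
  const int topk   = 2;   // experts selected per token (hypothetical)
  const int size_n = 8;   // output features (hypothetical)

  int flat_token = 5;     // an index in [0, size_m * topk), as sorted_ids would hold

  int src_row   = flat_token / topk;  // row of `a` to load; the same A row is reused topk times
  int topk_slot = flat_token % topk;  // which of the token's topk expert slots this result fills

  // Linear offset of element (src_row, topk_slot, col) in the output tensor c.
  int col = 3;
  long c_offset = ((long)src_row * topk + topk_slot) * size_n + col;
  long c_elems  = (long)size_m * topk * size_n;  // total elements in c

  printf("flat %d -> a row %d, topk slot %d, c offset %ld of %ld\n",
         flat_token, src_row, topk_slot, c_offset, c_elems);
  return 0;
}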
7 changes: 3 additions & 4 deletions csrc/moe/marlin_moe_ops.h
@@ -1,6 +1,5 @@
#pragma once

torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torch::Tensor& sorted_ids,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k/*,
int64_t num_tokens_post_padded, int64_t num_experts, int64_t moe_block_size*/);
torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torch::Tensor& sorted_ids, torch::Tensor& topk_ids,
torch::Tensor& b_scales, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
int64_t num_tokens_post_padded, int64_t num_experts, int64_t topk, int64_t moe_block_size);
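Editor's note, not part of the diff: a hypothetical caller sketch for the new declaration above. The shapes and dtypes in the comments are assumptions inferred from the checks and allocations in marlin_moe_ops.cu (b_scales must be rank 3 with last dimension size_n, and the result is allocated as {size_m, topk, size_n}); the Marlin 4-bit packed layout of b_q_weights is not spelled out here, and run_moe_gemm itself is an invented wrapper name.

// Hypothetical usage sketch under the assumptions stated above.
#include <torch/torch.h>
#include "marlin_moe_ops.h"  // include path depends on the build setup

torch::Tensor run_moe_gemm(torch::Tensor a,            // fp16, {size_m, size_k}
                           torch::Tensor b_q_weights,  // packed 4-bit weights, one block per expert
                           torch::Tensor b_scales,     // fp16, assumed {num_experts, num_groups, size_n}
                           torch::Tensor sorted_ids,   // int32, token ids grouped by expert (padded)
                           torch::Tensor topk_ids,     // int32, per-token expert choices
                           torch::Tensor workspace,    // int32 scratch used for locks
                           int64_t num_tokens_post_padded,
                           int64_t num_experts, int64_t topk,
                           int64_t moe_block_size) {
  int64_t size_m = a.size(0);
  int64_t size_k = a.size(1);
  int64_t size_n = b_scales.size(2);
  // Returns an fp16 tensor of shape {size_m, topk, size_n}.
  return marlin_gemm_moe(a, b_q_weights, sorted_ids, topk_ids, b_scales, workspace,
                         size_m, size_n, size_k,
                         num_tokens_post_padded, num_experts, topk, moe_block_size);
}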

1 comment on commit 0886f76

@github-actions

bigger_is_better

Benchmark suite: Current 0886f76, Previous a6b9443

request_throughput: 3.838075436131264 prompts/s
token_throughput: 1473.8209674744055 tokens/s

Both benchmarks: "VLLM Engine throughput - synthetic", model NousResearch/Llama-2-7b-chat-hf, max_model_len 4096, benchmark_throughput with use-all-available-gpus, input-len 256, output-len 128, num-prompts 1000.
Environment: NVIDIA A10G x 1, vllm_version 0.2.0, python_version 3.10.12 (main, May 10 2024, 13:42:25) [GCC 9.4.0], torch_version 2.3.0+cu121.

This comment was automatically generated by workflow using github-action-benchmark.
