This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Marlin moe integration #266

Closed
wants to merge 48 commits
Changes from 21 commits
Commits (48)
15d0f20
Start working on linking the elements together
ElizaWszola May 17, 2024
60c097c
Runs to completion
ElizaWszola May 21, 2024
42750bc
Commit before moving things around in the kernel
ElizaWszola May 22, 2024
b312ad5
Fix shared memory issue
ElizaWszola May 23, 2024
289cc8c
Fused marlin moe in model executor
ElizaWszola May 24, 2024
f7b106e
Various small fixes
ElizaWszola May 24, 2024
6765bde
quantize outside fused_marlin_moe function
ElizaWszola May 24, 2024
0886f76
Test shapes
ElizaWszola May 27, 2024
8ee85dd
Better test, some passing
ElizaWszola May 28, 2024
cb4dbaa
Working expert chunks
ElizaWszola Jun 3, 2024
cd4410f
working
ElizaWszola Jun 3, 2024
e9ee483
Combined kernel with fused op, scores
ElizaWszola Jun 11, 2024
c839319
Debugging
ElizaWszola Jun 12, 2024
5ac5ba6
Rand for everything
ElizaWszola Jun 12, 2024
7a9f453
Working A copy
ElizaWszola Jun 13, 2024
c412ab8
cleanup
ElizaWszola Jun 14, 2024
2d4ac6a
Merge branch 'main' into marlin-moe-integration
ElizaWszola Jun 14, 2024
ffd5f64
Try to fix pybind11 builds after merge
ElizaWszola Jun 14, 2024
a22272a
Continue work on bindings
ElizaWszola Jun 17, 2024
d2cba06
impl
ElizaWszola Jun 17, 2024
4b5a9f1
typo fix
ElizaWszola Jun 17, 2024
0991c74
all links but maxdiff (must fix assertions)
ElizaWszola Jun 17, 2024
9dd97b5
it works
ElizaWszola Jun 17, 2024
c62bc7f
Merge branch 'main' into marlin-moe-integration
ElizaWszola Jun 17, 2024
e65c195
lots of debugging
ElizaWszola Jun 24, 2024
5f1ebdb
Fixed types, it's working!
ElizaWszola Jun 24, 2024
836d627
Cleanup
ElizaWszola Jun 25, 2024
d3665fa
Tiny cleanups
ElizaWszola Jun 25, 2024
4bcfde6
Some renaming
ElizaWszola Jun 25, 2024
7701ee7
Renaming, Bill's feedback
ElizaWszola Jun 25, 2024
b9fdda3
Format
ElizaWszola Jun 25, 2024
1b05843
add act_order to marlin moe
ElizaWszola Jun 26, 2024
8adcb26
Tensor constness, factor common fused_moe parts into smaller functions
ElizaWszola Jun 26, 2024
6a3ef46
Spelling
ElizaWszola Jun 26, 2024
c676bea
single marlin test should still be disabled
ElizaWszola Jun 27, 2024
53c8ff7
Pass unit tests
ElizaWszola Jul 1, 2024
0fd62c8
integrate with model
ElizaWszola Jul 2, 2024
458c83f
cleanups
ElizaWszola Jul 2, 2024
400e35a
Merge branch 'marlin-moe-act_order' into marlin-moe-integration
ElizaWszola Jul 2, 2024
a0d7f77
Merge branch 'main' into marlin-moe-integration
ElizaWszola Jul 4, 2024
f629593
format
ElizaWszola Jul 4, 2024
c8b79f3
Merge branch 'main' into marlin-moe-integration
ElizaWszola Jul 11, 2024
e405879
Start work on integrating with refactored mixtral code
ElizaWszola Jul 11, 2024
cda9a0f
Runs to completion, but produces garbage
ElizaWszola Jul 12, 2024
c469b74
it works!
ElizaWszola Jul 12, 2024
d8b455f
Cleanup, format, minor fixes
ElizaWszola Jul 12, 2024
7504696
more efficient m blocking, a couple small fixes
ElizaWszola Jul 23, 2024
3641692
Multi-GPU works, but could make it faster
ElizaWszola Jul 31, 2024
7 changes: 5 additions & 2 deletions CMakeLists.txt
@@ -158,7 +158,8 @@ set(VLLM_EXT_SRC
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/torch_bindings.cpp")
"csrc/torch_bindings.cpp"
)

if(VLLM_GPU_LANG STREQUAL "CUDA")
include(FetchContent)
@@ -214,7 +215,9 @@ define_gpu_extension_target(

set(VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp"
"csrc/moe/topk_softmax_kernels.cu")
"csrc/moe/topk_softmax_kernels.cu"
"csrc/moe/marlin_moe_ops.cu"
)

define_gpu_extension_target(
_moe_C
1,320 changes: 1,320 additions & 0 deletions csrc/moe/marlin_moe_ops.cu
Member

How much of this file is copy-pasted from the original Marlin code? Could we factor out common functions? It would make it much easier to review if we could see what the new code is.

Author
@ElizaWszola Jun 26, 2024

There is quite a bit of overlap, and many of the changes boil down to adding one variable or an extra condition here and there. I don't really want to refactor into common functions until act_order is done, because there might be more of these tiny modifications (or is it better to do the refactor now?).

In any case, running a comparison of this file against csrc/quantization/gptq_marlin/gptq_marlin.cu helps to see what changed.

Edit: fixed file name

Member

That's fair for things that may be changed by act_order, but any functions that are copied over unmodified should be factored out, IMO.

Large diffs are not rendered by default.
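
To illustrate the kind of factoring discussed above, here is a minimal sketch of a shared header that both gptq_marlin.cu and marlin_moe_ops.cu could include. The header path, namespace, and the helper shown are illustrative assumptions, not part of this PR:

// csrc/quantization/gptq_marlin/marlin_common.cuh -- hypothetical shared header
#pragma once

namespace marlin {

// Small helpers of this kind are defined identically in both kernels and
// could live in one place instead of being duplicated.
__host__ __device__ constexpr int ceildiv(int a, int b) {
  return (a + b - 1) / b;
}

}  // namespace marlin

// gptq_marlin.cu and marlin_moe_ops.cu would then both #include this header,
// leaving only the MoE-specific logic (expert offsets, sorted token ids,
// top-k weighting) in csrc/moe/marlin_moe_ops.cu.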

7 changes: 7 additions & 0 deletions csrc/moe/marlin_moe_ops.h
@@ -0,0 +1,7 @@
#pragma once

#include <torch/all.h>

torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torch::Tensor& sorted_ids, torch::Tensor& topk_weights,
torch::Tensor& b_scales, torch::Tensor& expert_offsets, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
int64_t num_tokens_post_padded, int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights);
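
A rough reading of the arguments, inferred from their names and from how the op is wired up in torch_bindings.cpp below; the notes in the following comment block are assumptions, not documentation from the PR:

// Assumed meaning of the arguments (inferred, not authoritative):
//   a                        activations for the tokens being routed
//   b_q_weights              Marlin-packed quantized expert weight matrices
//   sorted_ids               token indices sorted so each expert's tokens are contiguous
//   topk_weights             per-token routing weights from topk_softmax
//   b_scales                 quantization scales for the expert weights
//   expert_offsets           start offset of each expert's segment in sorted_ids
//   workspace                scratch/lock buffer required by the Marlin kernel
//   size_m, size_n, size_k   GEMM problem sizes
//   num_tokens_post_padded   token count after padding to moe_block_size
//   num_experts, topk        number of experts and experts selected per token
//   moe_block_size           block size used by moe_align_block_size
//   replicate_input          whether the input is shared across the top-k expert copies
//   apply_weights            whether outputs are multiplied by topk_weights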
24 changes: 24 additions & 0 deletions csrc/moe/moe_ops.cpp
@@ -0,0 +1,24 @@
// #include <Python.h>

// #include "moe_ops.h"
// #include "marlin_moe_ops.h"

// #include <torch/extension.h>

// #include <pybind11/numpy.h>

// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// // m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
// m.def("marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, Tensor! topk_weights,
// Tensor! b_scales, py::array_t<int>& expert_offsets, Tensor! workspace, size_m, size_n, size_k,
// num_tokens_post_padded, num_experts, topk, moe_block_size, replicate_input, apply_weights)")
// m.impl("marlin_gemm_moe", torch::kCUDA, [](torch::Tensor& a, torch::Tensor& b_q_weights, torch::Tensor& sorted_ids, torch::Tensor& topk_weights,
// torch::Tensor& b_scales, py::array_t<int>& expert_offsets, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
// int64_t num_tokens_post_padded, int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights){
// py::buffer_info expert_offsets_bo = expert_offsets.request();
// return marlin_gemm_moe(a, b_q_weights, sorted_ids, topk_weights, b_scales,
// static_cast<int*>(expert_offsets_bo.ptr),
// workspace, size_m, size_n, size_k, num_tokens_post_padded,
// num_experts, topk, moe_block_size, replicate_input, apply_weights);
// }, "Marlin gemm moe kernel.");
// }
17 changes: 17 additions & 0 deletions csrc/moe/torch_bindings.cpp
@@ -1,12 +1,29 @@
#include "registration.h"
#include "moe_ops.h"
#include "marlin_moe_ops.h"

#include <torch/library.h>

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs.
m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);

m.def("marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, Tensor! topk_weights, "
"Tensor! b_scales, Tensor! expert_offsets, Tensor! workspace, size_m, size_n, size_k, "
"num_tokens_post_padded, num_experts, topk, moe_block_size, replicate_input, apply_weights) -> Tensor");
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
// m.impl("marlin_gemm_moe", torch::kCUDA, [](torch::Tensor& a, torch::Tensor& b_q_weights, torch::Tensor& sorted_ids, torch::Tensor& topk_weights,
// torch::Tensor& b_scales, py::array_t<int>& expert_offsets, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
// int64_t num_tokens_post_padded, int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights){
// py::buffer_info expert_offsets_bo = expert_offsets.request();
// return marlin_gemm_moe(a, b_q_weights, sorted_ids, topk_weights, b_scales,
// static_cast<int*>(expert_offsets_bo.ptr),
// workspace, size_m, size_n, size_k, num_tokens_post_padded,
// num_experts, topk, moe_block_size, replicate_input, apply_weights);
// }, "Marlin gemm moe kernel.");
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
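
For comparison, a sketch of the same schema string with explicit scalar types, which is what the dispatcher-facing signature would look like if every argument were typed; the int/bool annotations are assumptions derived from the C++ declaration in marlin_moe_ops.h, not necessarily what this PR ultimately landed:

m.def(
    "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
    "Tensor! topk_weights, Tensor! b_scales, Tensor! expert_offsets, "
    "Tensor! workspace, int size_m, int size_n, int size_k, "
    "int num_tokens_post_padded, int num_experts, int topk, "
    "int moe_block_size, bool replicate_input, bool apply_weights) -> Tensor");

In this schema syntax, int corresponds to a C++ int64_t and bool to bool, so the plain m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe) registration can bind directly against the declaration above without the commented-out pybind11 lambda.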
4 changes: 4 additions & 0 deletions csrc/ops.h
@@ -73,6 +73,10 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k);

// torch::Tensor marlin_gemm_moe(torch::Tensor& a, torch::Tensor& b_q_weights, torch::Tensor& sorted_ids, torch::Tensor& topk_weights,
// torch::Tensor& b_scales, torch::Tensor& expert_offsets, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
// int64_t num_tokens_post_padded, int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights);

torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
3 changes: 3 additions & 0 deletions csrc/torch_bindings.cpp
@@ -121,6 +121,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("marlin_gemm", &marlin_gemm);
ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);

// ops.def("marlin_gemm_moe", &marlin_gemm_moe);
// ops.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);

// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);