This repository has been archived by the owner on Oct 11, 2024. It is now read-only.
forked from vllm-project/vllm
Marlin moe integration #266
Closed
Changes shown from 21 of 48 commits
Commits (all by ElizaWszola):

15d0f20  Start working on linking the elements together
60c097c  Runs to completion
42750bc  Commit before moving things around in the kernel
b312ad5  Fix shared memory issue
289cc8c  Fused marlin moe in model executor
f7b106e  Various small fixes
6765bde  quantize outside fused_marlin_moe function
0886f76  Test shapes
8ee85dd  Better test, some passing
cb4dbaa  Working expert chunks
cd4410f  working
e9ee483  Combined kernel with fused op, scores
c839319  Debugging
5ac5ba6  Rand for everything
7a9f453  Working A copy
c412ab8  cleanup
2d4ac6a  Merge branch 'main' into marlin-moe-integration
ffd5f64  Try to fix pybind11 builds after merge
a22272a  Continue work on bindings
d2cba06  impl
4b5a9f1  typo fix
0991c74  all links but maxdiff (must fix assertions)
9dd97b5  it works
c62bc7f  Merge branch 'main' into marlin-moe-integration
e65c195  lots of debugging
5f1ebdb  Fixed types, it's working!
836d627  Cleanup
d3665fa  Tiny cleanups
4bcfde6  Some renaming
7701ee7  Renaming, Bill's feedback
b9fdda3  Format
1b05843  add act_order to marlin moe
8adcb26  Tensor constness, factor common fused_moe parts into smaller functions
6a3ef46  Spelling
c676bea  single marlin test should still be disabled
53c8ff7  Pass unit tests
0fd62c8  integrate with model
458c83f  cleanups
400e35a  Merge branch 'marlin-moe-act_order' into marlin-moe-integration
a0d7f77  Merge branch 'main' into marlin-moe-integration
f629593  format
c8b79f3  Merge branch 'main' into marlin-moe-integration
e405879  Start work on integrating with refactored mixtral code
cda9a0f  Runs to completion, but produces garbage
c469b74  it works!
d8b455f  Cleanup, format, minor fixes
7504696  more efficient m blocking, a couple small fixes
3641692  Multi-GPU works, but could make it faster
New header file declaring the fused Marlin MoE GEMM entry point:

@@ -0,0 +1,7 @@

#pragma once

#include <torch/all.h>

torch::Tensor marlin_gemm_moe(
    torch::Tensor& a, torch::Tensor& b_q_weights, torch::Tensor& sorted_ids,
    torch::Tensor& topk_weights, torch::Tensor& b_scales,
    torch::Tensor& expert_offsets, torch::Tensor& workspace,
    int64_t size_m, int64_t size_n, int64_t size_k,
    int64_t num_tokens_post_padded, int64_t num_experts, int64_t topk,
    int64_t moe_block_size, bool replicate_input, bool apply_weights);
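The declaration above takes pre-sorted scheduling metadata rather than raw router output. As a rough illustration of where `sorted_ids`, `expert_offsets`, and `num_tokens_post_padded` come from (a hypothetical pure-Python sketch of the usual block-alignment scheme; the function name and padding sentinel are made up and are not this PR's actual code):

```python
def align_block_size(topk_ids, num_experts, block_size):
    """Bucket token indices by assigned expert, then pad each expert's
    bucket to a multiple of block_size so a block-tiled MoE GEMM can
    process whole blocks per expert. Illustrative sketch only."""
    buckets = [[] for _ in range(num_experts)]
    for token_idx, expert in enumerate(topk_ids):
        buckets[expert].append(token_idx)
    pad = -1  # sentinel id marking padding slots
    sorted_ids, expert_offsets = [], [0]
    for bucket in buckets:
        padded_len = len(bucket) + (-len(bucket)) % block_size
        sorted_ids.extend(bucket + [pad] * (padded_len - len(bucket)))
        expert_offsets.append(len(sorted_ids))
    # len(sorted_ids) plays the role of num_tokens_post_padded
    return sorted_ids, expert_offsets, len(sorted_ids)

ids, offsets, n_post_pad = align_block_size(
    [0, 1, 0, 2, 1], num_experts=3, block_size=4)
# ids      → [0, 2, -1, -1, 1, 4, -1, -1, 3, -1, -1, -1]
# offsets  → [0, 4, 8, 12]; n_post_pad → 12
```

The kernel can then iterate over `expert_offsets[e] .. expert_offsets[e+1]` in steps of `moe_block_size`, skipping the sentinel slots.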
New file containing the retired pybind11 binding attempt, kept fully commented out:

@@ -0,0 +1,24 @@

// #include <Python.h>

// #include "moe_ops.h"
// #include "marlin_moe_ops.h"

// #include <torch/extension.h>

// #include <pybind11/numpy.h>

// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//   // m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
//   m.def("marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, Tensor! topk_weights, "
//         "Tensor! b_scales, py::array_t<int>& expert_offsets, Tensor! workspace, size_m, size_n, size_k, "
//         "num_tokens_post_padded, num_experts, topk, moe_block_size, replicate_input, apply_weights)")
//   m.impl("marlin_gemm_moe", torch::kCUDA,
//          [](torch::Tensor& a, torch::Tensor& b_q_weights, torch::Tensor& sorted_ids,
//             torch::Tensor& topk_weights, torch::Tensor& b_scales,
//             py::array_t<int>& expert_offsets, torch::Tensor& workspace,
//             int64_t size_m, int64_t size_n, int64_t size_k,
//             int64_t num_tokens_post_padded, int64_t num_experts, int64_t topk,
//             int64_t moe_block_size, bool replicate_input, bool apply_weights) {
//            py::buffer_info expert_offsets_bo = expert_offsets.request();
//            return marlin_gemm_moe(a, b_q_weights, sorted_ids, topk_weights, b_scales,
//                                   static_cast<int*>(expert_offsets_bo.ptr),
//                                   workspace, size_m, size_n, size_k, num_tokens_post_padded,
//                                   num_experts, topk, moe_block_size, replicate_input, apply_weights);
//          }, "Marlin gemm moe kernel.");
// }
Updated torch bindings, registering the new op through TORCH_LIBRARY instead of pybind11 (the schema now carries explicit `int`/`bool` types for the scalar arguments, which the schema parser requires; a commented-out duplicate of the pybind11 lambda above has been dropped):

@@ -1,12 +1,29 @@

#include "registration.h"
#include "moe_ops.h"
#include "marlin_moe_ops.h"

#include <torch/library.h>

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
  // Apply topk softmax to the gating outputs.
  m.def(
      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
      "token_expert_indices, Tensor gating_output) -> ()");
  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);

  m.def(
      "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
      "Tensor! topk_weights, Tensor! b_scales, Tensor! expert_offsets, "
      "Tensor! workspace, int size_m, int size_n, int size_k, "
      "int num_tokens_post_padded, int num_experts, int topk, "
      "int moe_block_size, bool replicate_input, bool apply_weights) -> Tensor");
  m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
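The registration above exposes `topk_softmax`, which turns per-token gating logits into expert weights and indices. A plain-Python sketch of the standard semantics (softmax over all experts, then keep the k largest; this illustrates the conventional operation, not the CUDA kernel's code):

```python
import math

def topk_softmax(gating_logits, k):
    """Softmax over expert logits, then select the k largest weights.
    Returns (topk_weights, topk_indices). Illustrative sketch only."""
    shift = max(gating_logits)  # subtract max for numerical stability
    exps = [math.exp(x - shift) for x in gating_logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    topk_indices = sorted(range(len(probs)),
                          key=probs.__getitem__, reverse=True)[:k]
    return [probs[i] for i in topk_indices], topk_indices

weights, indices = topk_softmax([1.0, 2.0, 3.0, 0.5], k=2)
# indices → [2, 1]
```

Whether the selected weights are then renormalized to sum to 1 is a separate choice that varies between MoE implementations.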
How much of this file is copy-pasted from the original Marlin code? Could we factor out common functions? It will make it much easier to review if we can see what the new code is.
There is quite a bit of overlap, and many of the changes boil down to adding one variable or an extra condition here and there. I don't really want to refactor into common functions until act_order is done, because there might be more of these tiny modifications (or is it better to do the refactor now?). In any case, running a comparison of this file against csrc/quantization/gptq_marlin/gptq_marlin.cu helps to see what changed.

Edit: fixed file name
That's fair for things that may be changed by act_order, but any functions that are copied over unmodified should be factored out, IMO.