Conversation
```python
    return torch.stack(tensors, dim=0).to(dev)
...
def fused_marlin_moe(
```
This function does not need to adhere to the exact same interface as fused_moe. This function will be called on the hotpath: it should receive INT4 weights and scales and just call the Marlin MoE kernel directly.
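For illustration, a minimal sketch of the hotpath entry point being described, assuming hypothetical parameter names and a placeholder `marlin_moe_kernel` binding (not the actual PR code):

```python
import torch


def marlin_moe_kernel(*args):
    # Stand-in for the real C++/CUDA Marlin MoE binding.
    raise NotImplementedError


def fused_marlin_moe(hidden_states: torch.Tensor,
                     w1_q: torch.Tensor, w2_q: torch.Tensor,
                     w1_scales: torch.Tensor, w2_scales: torch.Tensor,
                     gating_output: torch.Tensor,
                     topk: int) -> torch.Tensor:
    """Hotpath sketch: weights arrive already packed in Marlin INT4 format."""
    # Routing: pick the top-k experts per token from the gating logits.
    topk_weights, topk_ids = torch.topk(
        torch.softmax(gating_output, dim=-1, dtype=torch.float32), topk, dim=-1)
    # Dispatch straight to the kernel; no quantization happens on the hotpath.
    return marlin_moe_kernel(hidden_states, w1_q, w2_q, w1_scales, w2_scales,
                             topk_weights, topk_ids)
```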
```python
    qweights1 = []
    scaless1 = []

    for i in range(w1.shape[0]):
```
This will not be called on the hotpath; rather, the quantized weights should be an input to this function.
see comments in code
So, a couple of things. vLLM is laid out in the following way:
- Models --> llama.py, which uses linear_layers like ColumnParallelLinear
- Each layer has a LinearMethod, which handles the representation of the weights and the forward pass
- Each LinearMethod exposes the following interface (roughly, create_weights plus an apply/forward method; see the sketch below)

So, we will eventually want to create a LinearMethod for MarlinMoE. As a result, the fused_moe_marlin kernel should receive already-quantized weights and just execute the computation; we should not be quantizing the weights inside of that function. For this PR, we should land the kernel + testing code for the kernel. We can work on adding the LinearMethod afterwards.
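For context, a simplified sketch of what such an interface roughly looks like (the exact method names and signatures in vLLM differ; `apply_weights` and the argument lists here are approximations). A MarlinMoE LinearMethod would register the pre-packed INT4 weights in `create_weights` and call `fused_marlin_moe` in its forward method:

```python
from abc import ABC, abstractmethod
from typing import Optional

import torch


class LinearMethodBase(ABC):
    """Simplified sketch of vLLM's per-layer weight-handling interface."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, *weight_args,
                       **extra_weight_attrs) -> None:
        """Register the layer's (possibly quantized) parameters at load time."""
        ...

    @abstractmethod
    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the forward pass using whatever representation was created."""
        ...
```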
```diff
@@ -477,3 +478,342 @@ def fused_moe(
         out=hidden_states)
     return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                      dim=1)
```
Per my comment below, we will load the already-compressed weights via create_weights, so none of this will need to be called on the hotpath. As a result, all of this should be moved into testing utilities.
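A sketch of what such a testing utility could look like, assuming some single-matrix Marlin quantizer is available to the tests (the `quantize_fn` argument is a stand-in, not an existing helper):

```python
import torch


def quantize_experts_for_test(w: torch.Tensor, quantize_fn):
    """Test-only helper: pack each expert's fp16 weight into Marlin INT4 format.

    `w` is the expert-stacked weight tensor; `quantize_fn` should return
    (qweight, scales) for a single expert's weight matrix.
    """
    qweights, scales = [], []
    for e in range(w.shape[0]):
        qw, s = quantize_fn(w[e])
        qweights.append(qw)
        scales.append(s)
    return torch.stack(qweights, dim=0), torch.stack(scales, dim=0)
```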
tests/kernels/test_moe.py (outdated)
I think it would be good to add a test to make sure this works on other GPUs as well (we do this in the cutlass unit tests, if you want to replicate that here)
Do you mean testing on different devices, i.e. @pytest.mark.parametrize("device", CUDA_DEVICES)?
yes exactly
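A sketch of that pattern, modeled on how the existing tests enumerate devices (the exact `CUDA_DEVICES` definition and test body here are assumptions):

```python
import pytest
import torch

# Cover a second GPU when one is available (mirrors the cutlass tests).
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_marlin_moe_devices(device: str):
    torch.set_default_device(device)
    # ... build inputs on `device`, run fused_marlin_moe, and compare
    # against the reference fused_moe output ...
```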
Doing anything on cuda:1 results in memory errors (illegal access) in moe_align_block_size_kernel, which I rely on but didn't modify. Should I look into it, or is it OK to leave it for now?
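Not a confirmed diagnosis, but one common cause of illegal accesses on a non-default GPU is a kernel being launched while a different device is current; a hedged sketch of guarding the launch on the Python side, as something easy to try:

```python
import torch


def launch_on_input_device(kernel, hidden_states: torch.Tensor, *args):
    # Make the input's device current before launching, so any allocations or
    # kernel launches inside `kernel` target cuda:N rather than cuda:0.
    with torch.cuda.device(hidden_states.device):
        return kernel(hidden_states, *args)
```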
```python
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")
    assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
```
is 16 a hardcoded block size?
This is related to the Marlin format, where this value is hardcoded.
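As a readability tweak (not part of the PR), the factor could be given a name so the assumption is visible at the call site; this sketch assumes 16 is the Marlin tile size used when packing the K dimension:

```python
import torch

MARLIN_TILE_SIZE = 16  # Marlin packs the K dimension in 16-row tiles.


def check_hidden_size(hidden_states: torch.Tensor, w_q: torch.Tensor) -> None:
    # The packed qweight stores K / 16 rows, so the unpacked hidden size
    # is w_q.shape[1] * 16.
    assert hidden_states.shape[1] == w_q.shape[1] * MARLIN_TILE_SIZE, (
        "Hidden size mismatch")
```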
```python
    w1_s = self.experts[i].w1.get_parameter("scales").half()
    w3_s = self.experts[i].w3.get_parameter("scales").half()
    w2_qw = self.experts[i].w2.get_parameter("qweight").int()
    w2_s = self.experts[i].w2.get_parameter("scales").half()
```
are these guaranteed to be fp16?
From the PyTorch documentation: self.half() is equivalent to self.to(torch.float16). The scales are not necessarily fp16 when loaded, so the .half() call makes the conversion explicit.
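A small sketch of keeping that conversion in one place (illustrative only; assumes the kernel wants int32-packed weights and fp16 scales):

```python
import torch


def to_marlin_inputs(qweight: torch.Tensor, scales: torch.Tensor):
    # Convert explicitly instead of relying on whatever dtype the checkpoint
    # happened to be saved with.
    return qweight.to(torch.int32), scales.to(torch.float16)
```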
csrc/moe/marlin_moe_ops.cu (outdated)
How much of this file is copy-pasted from the original Marlin code? Could we factor out common functions? It would make it much easier to review if we could see what the new code is.
There is quite a bit of overlap, and many of the changes boil down to adding one variable or an extra condition here and there. I don't really want to refactor into common functions until act_order is done, because there might be more of these tiny modifications (or is it better to do the refactor now?). In any case, running a comparison of this file against csrc/quantization/gptq_marlin/gptq_marlin.cu helps to see what changed.
Edit: fixed file name
That's fair for things that may be changed by act_order, but any functions that are copied over unmodified should be factored out IMO.
Moved to the other repo.
Unit testing (requires uncommenting @pytest.mark.skip in test_moe.py).

End-to-end testing: run offline_inference.py with ... and ...