This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Fused marlin moe in model executor
ElizaWszola committed May 24, 2024
1 parent b312ad5 commit 289cc8c
Showing 4 changed files with 417 additions and 42 deletions.
34 changes: 16 additions & 18 deletions csrc/moe/marlin_moe_ops.cu
@@ -1014,23 +1014,21 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
#pragma unroll
for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {
if (c_gl_wr < c_gl_wr_end) {
if (c_gl_wr / c_gl_stride < prob_m) {
int row = sh_sorted[c_gl_wr / c_gl_stride];
int off = row * c_gl_stride + c_gl_wr % c_gl_stride;
__half* ctrg = reinterpret_cast<__half*>(&C[off]);
// HERE we read from sh, how is the access different?
__half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]);
// printf("c offset: %d at row %d from %d (%d %d)\n", off, row, c_gl_wr / c_gl_stride, threadIdx.x, blockIdx.x);
for (int j = 0; j < 8; ++j) {
// printf("csrc %f\n", __half2float(csrc[j]));
// printf("ctrg %f\n", __half2float(ctrg[j]));
// printf("csrc %f, ctrg %f\n", __half2float(csrc[j]), __half2float(ctrg[j]));
__half old = ctrg[j];
ctrg[j] = __float2half(__half2float(old) + __half2float(csrc[j]));
}
c_gl_wr += c_gl_wr_delta;
c_sh_rd += c_sh_rd_delta;
int row = sh_sorted[c_gl_wr / c_gl_stride];
int off = row * c_gl_stride + c_gl_wr % c_gl_stride;
__half* ctrg = reinterpret_cast<__half*>(&C[off]);
// HERE we read from sh, how is the access different?
__half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]);
// printf("c offset: %d at row %d from %d (%d %d)\n", off, row, c_gl_wr / c_gl_stride, threadIdx.x, blockIdx.x);
for (int j = 0; j < 8; ++j) {
// printf("csrc %f\n", __half2float(csrc[j]));
// printf("ctrg %f\n", __half2float(ctrg[j]));
// printf("csrc %f, ctrg %f\n", __half2float(csrc[j]), __half2float(ctrg[j]));
__half old = ctrg[j];
ctrg[j] = __float2half(__half2float(old) + __half2float(csrc[j]));
}
c_gl_wr += c_gl_wr_delta;
c_sh_rd += c_sh_rd_delta;
}
}
};
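The hunk above changes the write-back loop: partial fp16 results are read from shared memory (`sh`) and accumulated into the rows of `C` selected by `sh_sorted`, and the old `prob_m` guard around the accumulation is dropped. A rough, single-threaded Python analogue of the indexing, with illustrative names only (not taken from the kernel):

```python
import torch

def writeback(C: torch.Tensor, partial: torch.Tensor, sorted_rows: torch.Tensor) -> torch.Tensor:
    # C: (M, N) global output; partial: (B, N) per-block results ("shared memory");
    # sorted_rows: (B,) destination row in C for each partial result.
    for i, row in enumerate(sorted_rows.tolist()):
        C[row] += partial[i]  # read-modify-write, like ctrg[j] = old + csrc[j]
    return C
```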
@@ -1395,10 +1393,10 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, void* sorted_ids
for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
// printf("init ptrs for expert %d and gs %d\n", expert_idx, group_size);
const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B;// + (prob_n * prob_k / 32) * expert_idx;
const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx;
int4* C_ptr = (int4*)C;
int* sorted_ids_ptr = (int*)sorted_ids;// + moe_block_size * expert_idx;
const int4* s_ptr = (const int4*)s;// + (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * prob_n / 8) * expert_idx;
const int4* s_ptr = (const int4*)s + (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * prob_n / 8) * expert_idx;
// const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
// const int* perm_ptr = (const int*)perm;
int4* red_tmp_ptr = (int4*)red_tmp;
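In `marlin_mm_moe_f16i4`, the per-expert `B` and scale pointers are now actually offset per expert instead of pointing at expert 0: `B` holds 4-bit weights addressed in 128-bit `int4` units (32 weights each), and the fp16 scales are addressed the same way (8 scales per `int4`). A small Python sketch of that offset arithmetic, with illustrative names:

```python
def expert_offsets(expert_idx: int, prob_k: int, prob_n: int, group_size: int):
    """Offsets (in int4 units) applied to B_ptr and s_ptr for one expert."""
    # Weights: prob_k * prob_n 4-bit values per expert, 32 values per int4.
    b_off = (prob_n * prob_k // 32) * expert_idx
    # Scales: one group when group_size is -1 or 0, else prob_k // group_size
    # groups, each with prob_n fp16 scales, 8 scales per int4.
    num_groups = 1 if group_size in (-1, 0) else prob_k // group_size
    s_off = (num_groups * prob_n // 8) * expert_idx
    return b_off, s_off

# e.g. expert 3 with k = n = 512 and per-channel scales (group_size == -1):
print(expert_offsets(3, 512, 512, -1))  # -> (24576, 192)
```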
88 changes: 67 additions & 21 deletions tests/kernels/test_moe.py
@@ -6,14 +6,14 @@
import pytest
import torch
import numpy
from typing import List
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe import fused_moe, fused_marlin_moe
from vllm.model_executor.models.mixtral import MixtralMoE

"""
def torch_moe(a, w1, w2, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
@@ -24,13 +24,13 @@ def torch_moe(a, w1, w2, score, topk):
topk_ids = topk_ids.view(-1)
for i in range(w1.shape[0]):
mask = topk_ids == i
# print(mask)
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)

@pytest.mark.parametrize("m", [512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@@ -59,8 +59,8 @@ def test_fused_moe(
[torch.float32, torch.float16, torch.bfloat16])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype):
"Make sure our Mixtral MoE implementation agrees with the one from
huggingface."
"Make sure our Mixtral MoE implementation agrees with the one from"
"huggingface."

# Instantiate our and huggingface's MoE blocks
config = MixtralConfig()
Expand Down Expand Up @@ -102,8 +102,6 @@ def test_mixtral_moe(dtype: torch.dtype):
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])

"""

def get_marlin_perms():
perm = []
for i in range(32):
@@ -162,6 +160,8 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
dtype=numpy.uint32)

print("PACKED:", q_w.shape, ">>", q_packed.shape)

for i in range(pack_factor):
q_packed |= q_w[:, i::pack_factor] << num_bits * i

@@ -236,6 +236,7 @@ def marlin_quantize(
num_bits: int,
group_size: int,
):
print("START:", w.size(), num_bits, group_size)
perm, scale_perm, scale_perm_single = get_marlin_perms()

print("SHAPE:", w.shape)
@@ -249,6 +250,7 @@

# Quantize
w_ref, q_w, s = quantize_weights(w, num_bits, group_size)
print("interm:", w_ref.size(), q_w.size(), s.size())

#TODO experts
# Reformat to marlin
@@ -268,35 +270,79 @@

import vllm._moe_C as moe_kernels

def stack_and_dev(tensors: List[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)

def test_fused_marlin_moe():
m = 16
n = 128
k = 128
e = 1
m = 256
n = 512
k = 512
e = 8
topk = 2
dtype = torch.float16
moe_block_size = 16
e_m = 8

a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
w1 = torch.randn((2 * n, k), device='cuda', dtype=dtype) / 10
w2 = torch.randn((k, n), device='cuda', dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
w_m = torch.randn((e_m, k, n), device='cuda', dtype=dtype) / 10

a_2d = a.view(-1, a.shape[-1])

w_ref, qweight, scales = marlin_quantize(w1, 4, -1)
qweight = qweight.unsqueeze(0)
scales = scales.unsqueeze(0)
w_refs = []
qweights = []
scaless = []

for i in range(w_m.shape[0]):
w_ref, qweight, scales = marlin_quantize(w_m[i], 4, -1)
w_refs.append(w_ref)
qweights.append(qweight)
scaless.append(scales)

w_ref = stack_and_dev(w_refs)
qweight = stack_and_dev(qweights)
scales = stack_and_dev(scaless)

print("w_ref size:", w_ref.size())
print("qweight size:", qweight.size())
print("scales size:", scales.size())

# Allocate marlin workspace
max_workspace_size = (n // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)
shuffles = torch.range(0, m - 1, dtype=torch.int)
sorted_ids = torch.full([m + (moe_block_size - 1)], m, dtype=torch.int).cuda()
sorted_ids[:m] = shuffles

# score = torch.randn((m, e), device='cuda', dtype=dtype)
moe_kernels.marlin_gemm_moe(a_2d, qweight, sorted_ids, scales, workspace, m, n, k)
# triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
# torch_output = torch_moe(a, w1, w2, score, topk)
# assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)
score = torch.randn((m, e), device='cuda', dtype=dtype)
marlin_output = moe_kernels.marlin_gemm_moe(a, qweight, sorted_ids, scales, workspace, m, n, k)
torch_output = torch_moe(a, w1, w2, score, topk)
# assert torch.allclose(marlin_output, torch_output, atol=1e-2, rtol=0)

@pytest.mark.parametrize("m", [512]) #, 222, 33, 1])
@pytest.mark.parametrize("n", [2048]) #, 256, 1024])
@pytest.mark.parametrize("k", [128]) #, 511, 1024])
@pytest.mark.parametrize("e", [8]) #, 64])
@pytest.mark.parametrize("topk", [2]) #, 6])
@pytest.mark.parametrize("dtype", [torch.float16]) #, torch.bfloat16])
def test_fused_marlin_moe_2(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
):
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10

score = torch.randn((m, e), device='cuda', dtype=dtype)
triton_output = fused_marlin_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk)
assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)
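The `PACKED:` print above comes from `marlin_weights`, which folds `pack_factor = 32 // num_bits` quantized values into each `uint32` column with a strided shift-and-OR. A self-contained numpy sketch of that packing step on toy data (without the Marlin permutation):

```python
import numpy as np

num_bits = 4
pack_factor = 32 // num_bits  # 8 nibbles per uint32

# Toy 4-bit weights; marlin_weights applies its permutation before this step.
q_w = np.random.randint(0, 16, size=(16, 64)).astype(np.uint32)

q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
for i in range(pack_factor):
    # Packed column j collects original columns 8*j .. 8*j + 7,
    # with column 8*j + i stored at bit offset 4*i.
    q_packed |= q_w[:, i::pack_factor] << (num_bits * i)

# Round-trip check: unpack and compare.
unpacked = np.zeros_like(q_w)
for i in range(pack_factor):
    unpacked[:, i::pack_factor] = (q_packed >> (num_bits * i)) & 0xF
assert np.array_equal(unpacked, q_w)
```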
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/fused_moe/__init__.py
@@ -1,7 +1,8 @@
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_moe, get_config_file_name)
fused_moe, fused_marlin_moe, get_config_file_name)

__all__ = [
"fused_moe",
"fused_marlin_moe",
"get_config_file_name",
]
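With `fused_marlin_moe` exported from the package, the new path can be called the same way `test_fused_marlin_moe_2` does; a minimal usage sketch (the signature is inferred from that test and may differ in the final API):

```python
import torch
from vllm.model_executor.layers.fused_moe import fused_marlin_moe

m, n, k, e, topk = 512, 2048, 128, 8, 2
dtype = torch.float16

a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)

out = fused_marlin_moe(a, w1, w2, score, topk, renormalize=False)
```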

1 comment on commit 289cc8c

@github-actions


bigger_is_better

| Benchmark suite | Current: 289cc8c | Previous: a10b831 | Ratio |
| --- | --- | --- | --- |
| {"name": "request_throughput", "description": "VLLM Engine throughput - synthetic\nmodel - NousResearch/Llama-2-7b-chat-hf\nmax_model_len - 4096\nbenchmark_throughput {\n \"use-all-available-gpus_\": \"\",\n \"input-len\": 256,\n \"output-len\": 128,\n \"num-prompts\": 1000\n}", "gpu_description": "NVIDIA A10G x 1", "vllm_version": "0.2.0", "python_version": "3.10.12 (main, May 10 2024, 13:42:25) [GCC 9.4.0]", "torch_version": "2.3.0+cu121"} | 3.841071779379211 prompts/s | | |
| {"name": "token_throughput", "description": "VLLM Engine throughput - synthetic\nmodel - NousResearch/Llama-2-7b-chat-hf\nmax_model_len - 4096\nbenchmark_throughput {\n \"use-all-available-gpus_\": \"\",\n \"input-len\": 256,\n \"output-len\": 128,\n \"num-prompts\": 1000\n}", "gpu_description": "NVIDIA A10G x 1", "vllm_version": "0.2.0", "python_version": "3.10.12 (main, May 10 2024, 13:42:25) [GCC 9.4.0]", "torch_version": "2.3.0+cu121"} | 1474.971563281617 tokens/s | | |

This comment was automatically generated by workflow using github-action-benchmark.
