Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
Rand for everything
Browse files (browse the repository at this point in the history)
  • Loading branch information
ElizaWszola committed Jun 12, 2024
1 parent c839319 commit 5ac5ba6
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
4 changes: 2 additions & 2 deletions csrc/moe/marlin_moe_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
for (int i = 0; i < a_sh_wr_iters; i++) {
int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off;
int row = a_idx / a_gl_stride;
int sorted_row = sorted_ids[row] / (replicate_input ? topk : 1);
int sorted_row = replicate_input ? sorted_ids[row] / topk : sorted_ids[row];
int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride;
// if (threadIdx.x < 8 && blockIdx.x == 80) {
// // int mcols = replicate_input ? 1 : topk;
Expand All @@ -490,7 +490,7 @@ MarlinMoE(const int4* __restrict__ A, // fp16 input matrix of shape mxk
// // printf("row: %d -> %d, sh: %d -> %d ? %d // %d, %d, %d\n", row, sorted_row, i,
// // a_sh_wr_trans[i], a_sh_wr_pred[i], tot_m * (replicate_input ? 1 : topk), a_sh_wr_iters, stages * a_sh_stage);
// }
if (sorted_row < tot_m * (replicate_input ? 1 : topk)) {
if (sorted_row < tot_m * (replicate_input ? 1 : topk) && new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) {
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]],
&A[new_idx],
a_sh_wr_pred[i]);
Expand Down
2 changes: 2 additions & 0 deletions tests/kernels/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
import torch
import numpy
import random
from typing import List
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
Expand Down Expand Up @@ -355,6 +356,7 @@ def test_fused_marlin_moe(
topk: int,
group_size: int,
):
random.seed(4000)
torch.manual_seed(4000)
if topk > e:
return
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def fused_marlin_moe(

# sorted_token_ids[0:M] = torch.range(0, M - 1) * 2

# print("sorted_token_ids", sorted_token_ids)
print("sorted_token_ids", sorted_token_ids)

max_workspace_size = (N // 64) * 16
workspace = torch.zeros(max_workspace_size,
Expand Down Expand Up @@ -817,7 +817,7 @@ def fused_marlin_moe(

# intermediate_cache2 = intermediate_cache2.view(M, -1, N).sum(dim=1)

# print("intermediate op:", intermediate_cache2.size(), w2.size())
print("intermediate op:", intermediate_cache2.size(), w2.size(), M, N, K, topk)

intermediate_cache3 = moe_kernels.marlin_gemm_moe(intermediate_cache2, w2,
sorted_token_ids, topk_weights, w2_scale, expert_offsets_np, workspace,
Expand Down

1 comment on commit 5ac5ba6

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bigger_is_better

| Benchmark suite | Current: 5ac5ba6 | Previous: c9d8a5d | Ratio |
|---|---|---|---|
| `{"name": "request_throughput", "description": "VLLM Engine throughput - synthetic\nmodel - NousResearch/Llama-2-7b-chat-hf\nmax_model_len - 4096\nbenchmark_throughput {\n \"use-all-available-gpus_\": \"\",\n \"input-len\": 256,\n \"output-len\": 128,\n \"num-prompts\": 1000\n}", "gpu_description": "NVIDIA A10G x 1", "vllm_version": "0.2.0", "python_version": "3.10.12 (main, Jun 10 2024, 18:47:50) [GCC 9.4.0]", "torch_version": "2.3.0+cu121"}` | 3.9169633996661437 prompts/s | | |
| `{"name": "token_throughput", "description": "VLLM Engine throughput - synthetic\nmodel - NousResearch/Llama-2-7b-chat-hf\nmax_model_len - 4096\nbenchmark_throughput {\n \"use-all-available-gpus_\": \"\",\n \"input-len\": 256,\n \"output-len\": 128,\n \"num-prompts\": 1000\n}", "gpu_description": "NVIDIA A10G x 1", "vllm_version": "0.2.0", "python_version": "3.10.12 (main, Jun 10 2024, 18:47:50) [GCC 9.4.0]", "torch_version": "2.3.0+cu121"}` | 1504.1139454717993 tokens/s | | |

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.