Merge pull request opendatahub-io#14 from ROCm/fused_topK_softmax
Enable the fused topK_softmax kernel for the HIP (ROCm) path.
hthangirala authored May 16, 2024
2 parents 83ce7b2 + 41e348a commit e3ae076
Showing 4 changed files with 137 additions and 95 deletions.
2 changes: 2 additions & 0 deletions csrc/cuda_compat.h
@@ -18,8 +18,10 @@

#ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
#else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) __shfl_xor(var, lane_mask, width)
#endif

#ifndef USE_ROCM
27 changes: 17 additions & 10 deletions csrc/moe/topk_softmax_kernels.cu
@@ -19,15 +19,22 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "csrc/cuda_compat.h"

#include <cub/cub.cuh>
#include <cub/util_type.cuh>
#ifndef USE_ROCM
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

namespace vllm {
namespace moe {

static constexpr int WARP_SIZE = 32;

/// Aligned array type
template <
typename T,
@@ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW));
}

// From this point, thread max in all the threads have the max within the row.
@@ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW);
}

// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
@@ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);

// We want lower indices to "win" in every thread so we break ties this way
if (other_max > max_val || (other_max == max_val && other_expert < expert))
@@ -383,7 +390,7 @@ struct TopkConstants
{
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
@@ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
{
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;

static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
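
For reference, the unfused PyTorch equivalent of what topkGatingSoftmax produces — essentially the fallback that the HIP branch in fused_moe.py used before this change — is sketched below. This is an illustrative reference, not the kernel itself; note that the kernel breaks ties toward lower expert indices, whereas torch.topk makes no ordering guarantee among exact ties.

# Reference sketch (PyTorch), not the CUDA/HIP kernel: row-wise softmax over the
# expert logits followed by a per-token top-k. gating_output is assumed to be
# shaped [num_tokens, num_experts].
import torch

def topk_softmax_reference(gating_output: torch.Tensor, topk: int):
    routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    # The fused kernel breaks ties toward the lower expert index; torch.topk
    # makes no such guarantee, so results can differ on exact ties.
    topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
    return topk_weights, topk_ids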
8 changes: 3 additions & 5 deletions setup.py
@@ -362,15 +362,13 @@ def _read_requirements(filename: str) -> List[str]:

ext_modules = []

if _is_cuda():
ext_modules.append(CMakeExtension(name="vllm._moe_C"))

if _install_punica():
ext_modules.append(CMakeExtension(name="vllm._punica_C"))
if _is_cuda() and _install_punica():
ext_modules.append(CMakeExtension(name="vllm._punica_C"))

if not _is_neuron():
ext_modules.append(CMakeExtension(name="vllm._C"))
ext_modules.append(CMakeExtension(name="vllm._custom_C"))
ext_modules.append(CMakeExtension(name="vllm._moe_C"))

package_data = {
"vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
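
With this change, vllm._moe_C is built for every non-Neuron target (both CUDA and ROCm), while vllm._punica_C stays CUDA-only. A minimal post-build sanity check, assuming a completed install, might look like:

# Hypothetical sanity check: these extensions should import on both CUDA and
# ROCm builds after this change; vllm._punica_C is only expected on CUDA builds.
import importlib

for name in ("vllm._C", "vllm._custom_C", "vllm._moe_C"):
    importlib.import_module(name)  # raises ImportError if the extension was not built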
195 changes: 115 additions & 80 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -8,9 +8,9 @@
import triton
import triton.language as tl

import vllm._moe_C as moe_kernels
from vllm._C import ops
from vllm.logger import init_logger
from vllm.utils import is_hip

logger = init_logger(__name__)

@@ -108,8 +108,8 @@ def fused_moe_kernel(
offs_k[None, :] * stride_ak)

off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn)
b_ptrs = (b_ptr + off_experts * stride_be +
(offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn))

# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
@@ -121,10 +121,12 @@
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0)
a = tl.load(
a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
Expand All @@ -144,8 +146,8 @@ def fused_moe_kernel(
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_ptrs = (c_ptr + stride_cm * offs_token[:, None] +
stride_cn * offs_cn[None, :])
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)

@@ -193,31 +195,46 @@ def moe_align_block_size(
sorted_ids = torch.empty(
(topk_ids.numel() + num_experts * (block_size - 1), ),
dtype=torch.int32,
device=topk_ids.device)
expert_ids = torch.empty((topk_ids.numel() + num_experts, ),
dtype=torch.int32,
device=topk_ids.device)
device=topk_ids.device,
)
expert_ids = torch.empty(
(topk_ids.numel() + num_experts, ),
dtype=torch.int32,
device=topk_ids.device,
)
sorted_ids.fill_(topk_ids.numel())
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
ops.moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
return sorted_ids, expert_ids, num_tokens_post_pad


def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int,
config: Dict[str, Any]) -> None:
def invoke_fused_moe_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: Dict[str, Any],
) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1

grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
"BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), )

fused_moe_kernel[grid](
A,
@@ -310,8 +327,8 @@ def fused_moe(
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert (hidden_states.shape[0] == gating_output.shape[0]
), "Number of tokens mismatch"
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
@@ -323,34 +340,26 @@
M, _ = hidden_states.shape
E, N, _ = w1.shape

if is_hip():
# The MoE kernels are not yet supported on ROCm.
routing_weights = torch.softmax(gating_output,
dim=-1,
dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
else:
import vllm._moe_C as moe_kernels

topk_weights = torch.empty(M,
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk_weights = torch.empty(M,
topk,
dtype=torch.int32,
dtype=torch.float32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
moe_kernels.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
topk_ids = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
moe_kernels.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.

if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

@@ -367,48 +376,74 @@
else:
# Else use the default config
config = {
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
}

if M <= E:
config = {
'BLOCK_SIZE_M': 16,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 1
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}

intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache1 = torch.empty(
(M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache3 = torch.empty(
(M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)

sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E)
topk_ids, config["BLOCK_SIZE_M"], E)

invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
topk_weights, topk_ids, sorted_token_ids,
expert_ids, num_tokens_post_padded, False,
topk_ids.shape[1], config)
invoke_fused_moe_kernel(
hidden_states,
w1,
intermediate_cache1,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
)

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,
topk_weights, topk_ids, sorted_token_ids,
expert_ids, num_tokens_post_padded, True, 1,
config)
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
intermediate_cache3,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
)

if inplace:
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states)
return torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states,
)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
