From 9f6113fb4e40c52f8d243954e8d7b7e38cc560ec Mon Sep 17 00:00:00 2001
From: Divakar Verma
Date: Thu, 9 May 2024 21:53:28 +0000
Subject: [PATCH 1/4] enable fused topK_softmax kernel for hip

---
 csrc/cuda_compat.h               |  2 +
 csrc/moe/topk_softmax_kernels.cu | 27 ++++++----
 setup.py                         |  8 ++-
 .../layers/fused_moe/fused_moe.py | 50 ++++++++-----------
 4 files changed, 44 insertions(+), 43 deletions(-)

diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h
index c711d8d1b24b9..fc3e217aef7a1 100644
--- a/csrc/cuda_compat.h
+++ b/csrc/cuda_compat.h
@@ -18,8 +18,10 @@

 #ifndef USE_ROCM
   #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
+  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
 #else
   #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
+  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) __shfl_xor(var, lane_mask, width)
 #endif

 #ifndef USE_ROCM
diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu
index 8c65f40fe836a..e23f2ca353860 100644
--- a/csrc/moe/topk_softmax_kernels.cu
+++ b/csrc/moe/topk_softmax_kernels.cu
@@ -19,15 +19,22 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#include "csrc/cuda_compat.h"

-#include <cub/cub.cuh>
-#include <cub/util_type.cuh>
+#ifndef USE_ROCM
+    #include <cub/util_type.cuh>
+    #include <cub/cub.cuh>
+#else
+    #include <hipcub/util_type.hpp>
+    #include <hipcub/hipcub.hpp>
+#endif
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))

 namespace vllm {
 namespace moe {

-static constexpr int WARP_SIZE = 32;
-
 /// Aligned array type
 template <
     typename T,
@@ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
 #pragma unroll
         for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
         {
-            thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
+            thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW));
         }

         // From this point, thread max in all the threads have the max within the row.
@@ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ #pragma unroll for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); + row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW); } // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables @@ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ #pragma unroll for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); - int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); + float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW); + int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW); // We want lower indices to "win" in every thread so we break ties this way if (other_max > max_val || (other_max == max_val && other_expert < expert)) @@ -383,7 +390,7 @@ struct TopkConstants { static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); - static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; static constexpr int THREADS_PER_ROW = EXPERTS / VPT; static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; @@ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f { static constexpr std::size_t MAX_BYTES_PER_LDG = 16; - static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); + static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); using Constants = detail::TopkConstants; static constexpr int VPT = Constants::VPT; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; diff --git a/setup.py b/setup.py index ad15f693875a6..a4f7a61e8ef71 100644 --- a/setup.py +++ b/setup.py @@ -362,15 +362,13 @@ def _read_requirements(filename: str) -> List[str]: ext_modules = [] -if _is_cuda(): - ext_modules.append(CMakeExtension(name="vllm._moe_C")) - - if _install_punica(): - ext_modules.append(CMakeExtension(name="vllm._punica_C")) +if _is_cuda() and _install_punica(): + ext_modules.append(CMakeExtension(name="vllm._punica_C")) if not _is_neuron(): ext_modules.append(CMakeExtension(name="vllm._C")) ext_modules.append(CMakeExtension(name="vllm._custom_C")) + ext_modules.append(CMakeExtension(name="vllm._moe_C")) package_data = { "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1ec09f0cd4c28..c60f44c870d8f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -11,6 +11,8 @@ from vllm._C import ops from vllm.logger import init_logger from vllm.utils import is_hip +import vllm._moe_C as moe_kernels + logger = init_logger(__name__) @@ -323,34 +325,26 @@ def fused_moe( M, _ = hidden_states.shape E, N, _ = w1.shape - if is_hip(): - # The MoE kernels are not yet supported on ROCm. 
- routing_weights = torch.softmax(gating_output, - dim=-1, - dtype=torch.float32) - topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) - else: - import vllm._moe_C as moe_kernels - - topk_weights = torch.empty(M, - topk, - dtype=torch.float32, - device=hidden_states.device) - topk_ids = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) - token_expert_indicies = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) - moe_kernels.topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output.float(), # TODO(woosuk): Optimize this. - ) - del token_expert_indicies # Not used. Will be used in the future. + topk_weights = torch.empty(M, + topk, + dtype=torch.float32, + device=hidden_states.device) + topk_ids = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + moe_kernels.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), # TODO(woosuk): Optimize this. + ) + del token_expert_indicies # Not used. Will be used in the future. + if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) From 6177c6c8e9e99b0092cf9ed2405825a09efd996b Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Thu, 9 May 2024 22:54:05 +0000 Subject: [PATCH 2/4] [nit] ruff: rm is_hip --- vllm/model_executor/layers/fused_moe/fused_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index c60f44c870d8f..a7ea98d16c1b6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -10,7 +10,6 @@ from vllm._C import ops from vllm.logger import init_logger -from vllm.utils import is_hip import vllm._moe_C as moe_kernels From 97439cb8b07fc0aee19f3d3eae1a1da6daa468a2 Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Fri, 10 May 2024 15:01:31 +0000 Subject: [PATCH 3/4] [nit] ruff isort imports --- .../layers/fused_moe/fused_moe.py | 226 +++++++++++------- 1 file changed, 137 insertions(+), 89 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a7ea98d16c1b6..836b8620355c7 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -8,10 +8,9 @@ import triton import triton.language as tl +import vllm._moe_C as moe_kernels from vllm._C import ops from vllm.logger import init_logger -import vllm._moe_C as moe_kernels - logger = init_logger(__name__) @@ -105,12 +104,16 @@ def fused_moe_kernel( offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + - offs_k[None, :] * stride_ak) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) off_experts = tl.load(expert_ids_ptr + pid_m) - b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + - offs_bn[None, :] * stride_bn) + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. 
@@ -122,13 +125,14 @@ def fused_moe_kernel( for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. - a = tl.load(a_ptrs, - mask=token_mask[:, None] & - (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, - other=0.0) + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load( + b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 + ) # We accumulate along the K dimension. accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. @@ -136,24 +140,25 @@ def fused_moe_kernel( b_ptrs += BLOCK_SIZE_K * stride_bk if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, - mask=token_mask, - other=0) + moe_weight = tl.load( + topk_weights_ptr + offs_token, mask=token_mask, other=0 + ) accumulator = accumulator * moe_weight[:, None] accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ - None, :] + c_ptrs = ( + c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + ) c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) def moe_align_block_size( - topk_ids: torch.Tensor, block_size: int, - num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + topk_ids: torch.Tensor, block_size: int, num_experts: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block size for matrix multiplication. @@ -192,33 +197,50 @@ def moe_align_block_size( by block_size for proper block matrix operations. 
""" sorted_ids = torch.empty( - (topk_ids.numel() + num_experts * (block_size - 1), ), + (topk_ids.numel() + num_experts * (block_size - 1),), + dtype=torch.int32, + device=topk_ids.device, + ) + expert_ids = torch.empty( + (topk_ids.numel() + num_experts,), dtype=torch.int32, - device=topk_ids.device) - expert_ids = torch.empty((topk_ids.numel() + num_experts, ), - dtype=torch.int32, - device=topk_ids.device) + device=topk_ids.device, + ) sorted_ids.fill_(topk_ids.numel()) - num_tokens_post_pad = torch.empty((1), - dtype=torch.int32, - device=topk_ids.device) - ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, - expert_ids, num_tokens_post_pad) + num_tokens_post_pad = torch.empty( + (1), dtype=torch.int32, device=topk_ids.device + ) + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) return sorted_ids, expert_ids, num_tokens_post_pad -def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, top_k: int, - config: Dict[str, Any]) -> None: +def invoke_fused_moe_kernel( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: Dict[str, Any], +) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 - grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ - 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) + grid = lambda META: ( + triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), + ) fused_moe_kernel[grid]( A, @@ -267,11 +289,13 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]: json_file_name = get_config_file_name(E, N) config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name + ) if os.path.exists(config_file_path): with open(config_file_path) as f: logger.info( - f"Using configuration from {config_file_path} for MoE layer.") + f"Using configuration from {config_file_path} for MoE layer." + ) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} @@ -311,31 +335,27 @@ def fused_moe( - torch.Tensor: The output tensor after applying the MoE layer. """ # Check constraints. 
- assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") + assert ( + hidden_states.shape[0] == gating_output.shape[0] + ), "Number of tokens mismatch" assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] M, _ = hidden_states.shape E, N, _ = w1.shape - topk_weights = torch.empty(M, - topk, - dtype=torch.float32, - device=hidden_states.device) - topk_ids = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) - token_expert_indicies = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) + topk_weights = torch.empty( + M, topk, dtype=torch.float32, device=hidden_states.device + ) + topk_ids = torch.empty( + M, topk, dtype=torch.int32, device=hidden_states.device + ) + token_expert_indicies = torch.empty( + M, topk, dtype=torch.int32, device=hidden_states.device + ) moe_kernels.topk_softmax( topk_weights, topk_ids, @@ -360,48 +380,76 @@ def fused_moe( else: # Else use the default config config = { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, } if M <= E: config = { - 'BLOCK_SIZE_M': 16, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 1 + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, } - intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype) + intermediate_cache1 = torch.empty( + (M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache3 = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config['BLOCK_SIZE_M'], E) + topk_ids, config["BLOCK_SIZE_M"], E + ) - invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, - topk_weights, topk_ids, sorted_token_ids, - expert_ids, num_tokens_post_padded, False, - topk_ids.shape[1], config) + invoke_fused_moe_kernel( + hidden_states, + w1, + intermediate_cache1, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + ) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, - topk_weights, topk_ids, sorted_token_ids, - expert_ids, num_tokens_post_padded, True, 1, - config) + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + intermediate_cache3, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + 
num_tokens_post_padded, + True, + 1, + config, + ) if inplace: - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=hidden_states) - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) + return torch.sum( + intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=hidden_states, + ) + return torch.sum( + intermediate_cache3.view(*intermediate_cache3.shape), dim=1 + ) From 41e348af011b05d7efa9c51df16741fe70a025e7 Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Fri, 10 May 2024 15:15:50 +0000 Subject: [PATCH 4/4] [nit] yapf formatting --- .../layers/fused_moe/fused_moe.py | 96 +++++++++---------- 1 file changed, 45 insertions(+), 51 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 836b8620355c7..a48b4385ba917 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -104,16 +104,12 @@ def fused_moe_kernel( offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + ( - offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak - ) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) off_experts = tl.load(expert_ids_ptr + pid_m) - b_ptrs = ( - b_ptr - + off_experts * stride_be - + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - ) + b_ptrs = (b_ptr + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -127,12 +123,13 @@ def fused_moe_kernel( # K dimension. a = tl.load( a_ptrs, - mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0, ) - b = tl.load( - b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 - ) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) # We accumulate along the K dimension. accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. @@ -140,25 +137,24 @@ def fused_moe_kernel( b_ptrs += BLOCK_SIZE_K * stride_bk if MUL_ROUTED_WEIGHT: - moe_weight = tl.load( - topk_weights_ptr + offs_token, mask=token_mask, other=0 - ) + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) accumulator = accumulator * moe_weight[:, None] accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = ( - c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] - ) + c_ptrs = (c_ptr + stride_cm * offs_token[:, None] + + stride_cn * offs_cn[None, :]) c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) def moe_align_block_size( - topk_ids: torch.Tensor, block_size: int, num_experts: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + topk_ids: torch.Tensor, block_size: int, + num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block size for matrix multiplication. @@ -197,19 +193,19 @@ def moe_align_block_size( by block_size for proper block matrix operations. 
""" sorted_ids = torch.empty( - (topk_ids.numel() + num_experts * (block_size - 1),), + (topk_ids.numel() + num_experts * (block_size - 1), ), dtype=torch.int32, device=topk_ids.device, ) expert_ids = torch.empty( - (topk_ids.numel() + num_experts,), + (topk_ids.numel() + num_experts, ), dtype=torch.int32, device=topk_ids.device, ) sorted_ids.fill_(topk_ids.numel()) - num_tokens_post_pad = torch.empty( - (1), dtype=torch.int32, device=topk_ids.device - ) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) ops.moe_align_block_size( topk_ids, num_experts, @@ -237,10 +233,8 @@ def invoke_fused_moe_kernel( assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 - grid = lambda META: ( - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) - * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), - ) + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ + "BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) fused_moe_kernel[grid]( A, @@ -289,13 +283,11 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]: json_file_name = get_config_file_name(E, N) config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name - ) + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) if os.path.exists(config_file_path): with open(config_file_path) as f: logger.info( - f"Using configuration from {config_file_path} for MoE layer." - ) + f"Using configuration from {config_file_path} for MoE layer.") # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} @@ -335,27 +327,31 @@ def fused_moe( - torch.Tensor: The output tensor after applying the MoE layer. """ # Check constraints. 
- assert ( - hidden_states.shape[0] == gating_output.shape[0] - ), "Number of tokens mismatch" + assert (hidden_states.shape[0] == gating_output.shape[0] + ), "Number of tokens mismatch" assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] M, _ = hidden_states.shape E, N, _ = w1.shape - topk_weights = torch.empty( - M, topk, dtype=torch.float32, device=hidden_states.device - ) - topk_ids = torch.empty( - M, topk, dtype=torch.int32, device=hidden_states.device - ) - token_expert_indicies = torch.empty( - M, topk, dtype=torch.int32, device=hidden_states.device - ) + topk_weights = torch.empty(M, + topk, + dtype=torch.float32, + device=hidden_states.device) + topk_ids = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) moe_kernels.topk_softmax( topk_weights, topk_ids, @@ -411,8 +407,7 @@ def fused_moe( ) sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config["BLOCK_SIZE_M"], E - ) + topk_ids, config["BLOCK_SIZE_M"], E) invoke_fused_moe_kernel( hidden_states, @@ -450,6 +445,5 @@ def fused_moe( dim=1, out=hidden_states, ) - return torch.sum( - intermediate_cache3.view(*intermediate_cache3.shape), dim=1 - ) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1)
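
Usage sketch (editor's illustration, not part of the patches above): after this series, the routing step in fused_moe() calls the compiled vllm._moe_C.topk_softmax kernel on both CUDA and ROCm, instead of falling back to torch.softmax/torch.topk on HIP. The snippet below is a minimal sketch of that call pattern, assuming a build where vllm._moe_C is available; route_tokens is a hypothetical helper name, and the shapes, dtypes, and call signature follow the fused_moe.py hunks above.

import torch
import vllm._moe_C as moe_kernels

def route_tokens(gating_output: torch.Tensor, topk: int, renormalize: bool = True):
    # gating_output: [num_tokens, num_experts] router logits
    M = gating_output.shape[0]
    device = gating_output.device
    topk_weights = torch.empty(M, topk, dtype=torch.float32, device=device)
    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=device)
    token_expert_indicies = torch.empty(M, topk, dtype=torch.int32, device=device)
    # Fused softmax + top-k over experts; the same kernel now runs on CUDA and ROCm.
    moe_kernels.topk_softmax(
        topk_weights,
        topk_ids,
        token_expert_indicies,
        gating_output.float(),
    )
    del token_expert_indicies  # not used yet, mirroring the patch
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

On the kernel side, the warp reductions now go through VLLM_SHFL_XOR_SYNC_WIDTH, which maps to __shfl_xor with an explicit width on ROCm, and the local WARP_SIZE constant is dropped in favor of the one from cuda_compat.h, presumably so the sub-warp reductions over THREADS_PER_ROW lanes also work with ROCm's 64-lane wavefront.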