From 10bad7deafba7341aee5d8ea96c1df24bc76a326 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Sun, 7 Apr 2024 18:55:44 +0000 Subject: [PATCH 1/3] Optimize the fp-dequantizer to get high memory-BW utilization --- csrc/fp_quantizer/quantize.cu | 197 ++++++++++++++++------------------ 1 file changed, 91 insertions(+), 106 deletions(-) diff --git a/csrc/fp_quantizer/quantize.cu b/csrc/fp_quantizer/quantize.cu index 37be6cc0657c..7bdd9c1ed486 100644 --- a/csrc/fp_quantizer/quantize.cu +++ b/csrc/fp_quantizer/quantize.cu @@ -219,119 +219,114 @@ __global__ void apply_quantization(T* val, } template -__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size) +__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int total_num_elements) { - int tidx = threadIdx.x; - int wid = tidx >> 5; - int lane = tidx & 0x1f; - int gid = blockIdx.x * quantization::warps + wid; + + constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T); + int tidx = (blockIdx.x * blockDim.x + threadIdx.x) * vector_size; + constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1; - constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1; + constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1; constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1; constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits; constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits); - - constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T); - constexpr uint32_t load_stride = vector_size * hw_warp_size; - const uint32_t thread_offset = lane * vector_size; - const uint32_t thread_load_offset = lane * vector_size * quantized_bits / 8; - const uint32_t base_load_offset = - gid * (group_size * quantized_bits / 8 + 4) + thread_load_offset; // 4-byte scale offset - const uint32_t base_store_offset = gid * group_size + thread_offset; - const uint8_t* load_base_ptr = val + base_load_offset; - + const uint32_t g_index = (tidx / group_size); + const uint32_t group_size_bytes = (group_size * quantized_bits / 8); + const uint8_t* load_base_ptr = val + g_index * (group_size_bytes + 4) + (tidx % group_size) * quantized_bits / 8; + int mantisa_mask = ((1 << q_mantisa_bits) - 1); mantisa_mask <<= (_mantisa_bits - q_mantisa_bits); - T* store_base_ptr = q_val + base_store_offset; - float scale; //= q_scale[gid]; + T* store_base_ptr = q_val + tidx; + float scale; - uint8_t* scale_as_int8 = reinterpret_cast(&scale); + uint8_t *scale_as_int8 = reinterpret_cast(&scale); if (quantized_bits == 6) { mem_access::load_global( scale_as_int8, - val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8)); + val + g_index * (group_size_bytes + 4) + group_size_bytes + ); mem_access::load_global( scale_as_int8 + quantization::quanitzed_access_granularity_6bits, - val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8) + - quantization::quanitzed_access_granularity_6bits); - } else + val + g_index * (group_size_bytes + 4) + group_size_bytes + quantization::quanitzed_access_granularity_6bits + ); + } + else mem_access::load_global( scale_as_int8, - val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8)); - -#pragma unroll - for (int i = 0; i < unroll; i++) { - if (i * load_stride + thread_offset < group_size) { - uint64_t q_buf_in; - uint64_t q_buf_in1; - uint8_t* int8_data = reinterpret_cast(&q_buf_in); - uint8_t* int8_data1 = 
reinterpret_cast(&q_buf_in1); - uint32_t loading_offset = i * load_stride * quantized_bits / 8; - if (quantized_bits == 6) { - mem_access::load_global( - int8_data, load_base_ptr + loading_offset); - mem_access::load_global( - int8_data + quantization::quanitzed_access_granularity_6bits, - load_base_ptr + loading_offset + - quantization::quanitzed_access_granularity_6bits); - mem_access::load_global( - int8_data + quantization::quanitzed_access_granularity_6bits * 2, - load_base_ptr + loading_offset + - quantization::quanitzed_access_granularity_6bits * 2); - } else { + val + g_index * (group_size_bytes + 4) + group_size_bytes + ); + + if (tidx < total_num_elements) + { + uint64_t q_buf_in; + uint64_t q_buf_in1; + uint8_t *int8_data = reinterpret_cast(&q_buf_in); + uint8_t *int8_data1 = reinterpret_cast(&q_buf_in1); + if (quantized_bits == 6) + { + mem_access::load_global( + int8_data, + load_base_ptr); + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity_6bits, + load_base_ptr + quantization::quanitzed_access_granularity_6bits); + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity_6bits * 2, + load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2); + } + else { + mem_access::load_global( + int8_data, + load_base_ptr); + if (quantized_bits > 4) { mem_access::load_global( - int8_data, load_base_ptr + loading_offset); - if (quantized_bits > 4) { + int8_data + quantization::quanitzed_access_granularity, + load_base_ptr + quantization::quanitzed_access_granularity); + if (quantized_bits == 12) + { mem_access::load_global( - int8_data + quantization::quanitzed_access_granularity, - load_base_ptr + loading_offset + - quantization::quanitzed_access_granularity); - if (quantized_bits == 12) { - mem_access::load_global( - int8_data1, - load_base_ptr + loading_offset + - quantization::quanitzed_access_granularity * 2); - } + int8_data1, + load_base_ptr + quantization::quanitzed_access_granularity * 2); } } - T store_buf[vector_size]; - uint16_t* q_buf = reinterpret_cast(store_buf); + } + T store_buf[vector_size]; + uint16_t* q_buf = reinterpret_cast(store_buf); #pragma unroll - for (int j = 0; j < vector_size; j++) { - uint16_t new_data; - if (j < 5 || quantized_bits != 12) { - new_data = (uint16_t)(q_buf_in >> (j * quantized_bits)); - } else { - if (j == 5) { - new_data = (uint16_t)(q_buf_in1); - new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60)); - } else - new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8)); - } - - uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits); - uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits; - uint16_t dst_mantisa = (new_data & _mantisa_mask); - - if (dst_exponent != (1 << q_exponent_bits) - 1) - dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) + - (1 << (q_exponent_bits - 1)) - 1; - - q_buf[j] = ((sign << (q_exponent_bits + q_mantisa_bits)) | - (dst_exponent << q_mantisa_bits) | - (dst_mantisa << (q_mantisa_bits - _mantisa_bits))); - float up_cast = conversion::to(store_buf[j]); - store_buf[j] = conversion::to(up_cast * scale); + for (int j = 0; j < vector_size; j++) + { + uint16_t new_data; + if (j < 5 || quantized_bits != 12) { + new_data = (uint16_t)(q_buf_in >> (j * quantized_bits)); } - mem_access::store_global( - store_base_ptr + i * load_stride, store_buf); + else { + if (j == 5) { + new_data = (uint16_t)(q_buf_in1); + new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60)); + } else + new_data = 
(uint16_t)(q_buf_in1 >> ((j-6) * quantized_bits + 8)); + } + + uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits); + uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits; + uint16_t dst_mantisa = (new_data & _mantisa_mask); + + if (dst_exponent != (1 << q_exponent_bits) - 1) + dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) + (1 << (q_exponent_bits - 1)) - 1; + + q_buf[j] = ((sign << (q_exponent_bits + q_mantisa_bits)) | + (dst_exponent << q_mantisa_bits) | + (dst_mantisa << (q_mantisa_bits - _mantisa_bits))); + float up_cast = conversion::to(store_buf[j]); + store_buf[j] = conversion::to(up_cast * scale); } + mem_access::store_global( + store_base_ptr, store_buf); } } @@ -386,11 +381,6 @@ INSTANTIATE_LAUNCH_QUANTIZATION(__nv_bfloat16, 23, 8); #endif INSTANTIATE_LAUNCH_QUANTIZATION(__half, 23, 8); -#define LAUNCH_FOR_DEQUANTIZATION_UNROLL(COUNT) \ - case COUNT: \ - apply_dequantization \ - <<>>(val, q_val, group_size); \ - break; template void launch_dequantization(uint8_t* val, @@ -401,22 +391,17 @@ void launch_dequantization(uint8_t* val, int q_exponent_bits, cudaStream_t stream) { - const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps); + int blocks = ((num_groups * group_size) - 1) / (quantization::threads * (quantization::access_granularity / sizeof(T))) + 1; + const dim3 grid(blocks); const dim3 block(quantization::threads); - - constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T); - const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll; - - DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] { - switch (copy_unroll) { - LAUNCH_FOR_DEQUANTIZATION_UNROLL(1) - LAUNCH_FOR_DEQUANTIZATION_UNROLL(2) - LAUNCH_FOR_DEQUANTIZATION_UNROLL(3) - LAUNCH_FOR_DEQUANTIZATION_UNROLL(4) - LAUNCH_FOR_DEQUANTIZATION_UNROLL(5) - LAUNCH_FOR_DEQUANTIZATION_UNROLL(6) - } - }); + DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] { + apply_dequantization<<>>( + val, + q_val, + group_size, + (num_groups * group_size) + ); + }); } #define INSTANTIATE_LAUNCH_DEQUANTIZATION(T, mantisa) \ template void launch_dequantization(uint8_t*, T*, int, int, int, int, cudaStream_t); From c5ba68e3ac32a0fa93d2be07b7ddcd9a40b3d122 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Wed, 10 Apr 2024 06:33:36 +0000 Subject: [PATCH 2/3] fix formating --- csrc/fp_quantizer/quantize.cu | 95 ++++++++++++++--------------------- 1 file changed, 39 insertions(+), 56 deletions(-) diff --git a/csrc/fp_quantizer/quantize.cu b/csrc/fp_quantizer/quantize.cu index 7bdd9c1ed486..5f0b58f124f0 100644 --- a/csrc/fp_quantizer/quantize.cu +++ b/csrc/fp_quantizer/quantize.cu @@ -225,108 +225,94 @@ template __global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int total_num_elements) { - constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T); int tidx = (blockIdx.x * blockDim.x + threadIdx.x) * vector_size; - + constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1; - constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1; + constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1; constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1; constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits; constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits); const uint32_t g_index = (tidx / group_size); const uint32_t group_size_bytes = (group_size * quantized_bits / 8); - const uint8_t* 
load_base_ptr = val + g_index * (group_size_bytes + 4) + (tidx % group_size) * quantized_bits / 8; - + const uint8_t* load_base_ptr = + val + g_index * (group_size_bytes + 4) + (tidx % group_size) * quantized_bits / 8; + int mantisa_mask = ((1 << q_mantisa_bits) - 1); mantisa_mask <<= (_mantisa_bits - q_mantisa_bits); T* store_base_ptr = q_val + tidx; float scale; - uint8_t *scale_as_int8 = reinterpret_cast(&scale); + uint8_t* scale_as_int8 = reinterpret_cast(&scale); if (quantized_bits == 6) { mem_access::load_global( - scale_as_int8, - val + g_index * (group_size_bytes + 4) + group_size_bytes - ); + scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes); mem_access::load_global( scale_as_int8 + quantization::quanitzed_access_granularity_6bits, - val + g_index * (group_size_bytes + 4) + group_size_bytes + quantization::quanitzed_access_granularity_6bits - ); - } - else + val + g_index * (group_size_bytes + 4) + group_size_bytes + + quantization::quanitzed_access_granularity_6bits); + } else mem_access::load_global( - scale_as_int8, - val + g_index * (group_size_bytes + 4) + group_size_bytes - ); + scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes); - if (tidx < total_num_elements) - { + if (tidx < total_num_elements) { uint64_t q_buf_in; uint64_t q_buf_in1; - uint8_t *int8_data = reinterpret_cast(&q_buf_in); - uint8_t *int8_data1 = reinterpret_cast(&q_buf_in1); - if (quantized_bits == 6) - { + uint8_t* int8_data = reinterpret_cast(&q_buf_in); + uint8_t* int8_data1 = reinterpret_cast(&q_buf_in1); + if (quantized_bits == 6) { mem_access::load_global( - int8_data, - load_base_ptr); + int8_data, load_base_ptr); mem_access::load_global( int8_data + quantization::quanitzed_access_granularity_6bits, load_base_ptr + quantization::quanitzed_access_granularity_6bits); mem_access::load_global( int8_data + quantization::quanitzed_access_granularity_6bits * 2, load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2); - } - else { - mem_access::load_global( - int8_data, - load_base_ptr); + } else { + mem_access::load_global(int8_data, + load_base_ptr); if (quantized_bits > 4) { mem_access::load_global( int8_data + quantization::quanitzed_access_granularity, load_base_ptr + quantization::quanitzed_access_granularity); - if (quantized_bits == 12) - { + if (quantized_bits == 12) { mem_access::load_global( - int8_data1, - load_base_ptr + quantization::quanitzed_access_granularity * 2); + int8_data1, load_base_ptr + quantization::quanitzed_access_granularity * 2); } } } T store_buf[vector_size]; uint16_t* q_buf = reinterpret_cast(store_buf); #pragma unroll - for (int j = 0; j < vector_size; j++) - { + for (int j = 0; j < vector_size; j++) { uint16_t new_data; if (j < 5 || quantized_bits != 12) { new_data = (uint16_t)(q_buf_in >> (j * quantized_bits)); - } - else { + } else { if (j == 5) { new_data = (uint16_t)(q_buf_in1); new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60)); } else - new_data = (uint16_t)(q_buf_in1 >> ((j-6) * quantized_bits + 8)); + new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8)); } - + uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits); uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits; uint16_t dst_mantisa = (new_data & _mantisa_mask); if (dst_exponent != (1 << q_exponent_bits) - 1) - dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) + (1 << (q_exponent_bits - 1)) - 1; - - q_buf[j] = ((sign << (q_exponent_bits + q_mantisa_bits)) | - (dst_exponent << 
q_mantisa_bits) | - (dst_mantisa << (q_mantisa_bits - _mantisa_bits))); + dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) + + (1 << (q_exponent_bits - 1)) - 1; + + q_buf[j] = + ((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) | + (dst_mantisa << (q_mantisa_bits - _mantisa_bits))); float up_cast = conversion::to(store_buf[j]); store_buf[j] = conversion::to(up_cast * scale); } - mem_access::store_global( - store_base_ptr, store_buf); + mem_access::store_global(store_base_ptr, store_buf); } } @@ -381,7 +367,6 @@ INSTANTIATE_LAUNCH_QUANTIZATION(__nv_bfloat16, 23, 8); #endif INSTANTIATE_LAUNCH_QUANTIZATION(__half, 23, 8); - template void launch_dequantization(uint8_t* val, T* q_val, @@ -391,17 +376,15 @@ void launch_dequantization(uint8_t* val, int q_exponent_bits, cudaStream_t stream) { - int blocks = ((num_groups * group_size) - 1) / (quantization::threads * (quantization::access_granularity / sizeof(T))) + 1; + int blocks = ((num_groups * group_size) - 1) / + (quantization::threads * (quantization::access_granularity / sizeof(T))) + + 1; const dim3 grid(blocks); const dim3 block(quantization::threads); - DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] { - apply_dequantization<<>>( - val, - q_val, - group_size, - (num_groups * group_size) - ); - }); + DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] { + apply_dequantization + <<>>(val, q_val, group_size, (num_groups * group_size)); + }); } #define INSTANTIATE_LAUNCH_DEQUANTIZATION(T, mantisa) \ template void launch_dequantization(uint8_t*, T*, int, int, int, int, cudaStream_t); From 53e91ecbb584162424678756426df13d562082f4 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Fri, 8 Nov 2024 23:14:15 +0000 Subject: [PATCH 3/3] Add several kernels to speed up training llama-based MoE model architectures --- deepspeed/tops/__init__.py | 12 + deepspeed/tops/includes/moe_gating.h | 62 + deepspeed/tops/includes/reduction_utils.h | 818 ++++++++ deepspeed/tops/includes/rope.h | 7 + deepspeed/tops/includes/swiglu.h | 6 + deepspeed/tops/includes/tops_context.h | 44 + deepspeed/tops/includes/utils.h | 1803 ++++++++++++++++++ deepspeed/tops/moe_gating/__init__.py | 2 + deepspeed/tops/moe_gating/moe_gather.py | 89 + deepspeed/tops/moe_gating/moe_gating.cpp | 285 +++ deepspeed/tops/moe_gating/moe_gating.py | 321 ++++ deepspeed/tops/moe_gating/test_moe_gating.py | 125 ++ deepspeed/tops/moe_gating/top1_moe_gating.cu | 781 ++++++++ deepspeed/tops/moe_gating/top2_moe_gating.cu | 930 +++++++++ deepspeed/tops/rope/__init__.py | 1 + deepspeed/tops/rope/rope-test.py | 16 + deepspeed/tops/rope/rope.cpp | 63 + deepspeed/tops/rope/rope.cu | 353 ++++ deepspeed/tops/rope/rope.py | 52 + deepspeed/tops/swiglu/__init__.py | 1 + deepspeed/tops/swiglu/swiglu.cpp | 55 + deepspeed/tops/swiglu/swiglu.cu | 214 +++ deepspeed/tops/swiglu/swiglu.py | 53 + deepspeed/tops/swiglu/test_swiglu.py | 54 + deepspeed/tops/tops.cpp | 18 + op_builder/tops.py | 83 + 26 files changed, 6248 insertions(+) create mode 100644 deepspeed/tops/__init__.py create mode 100644 deepspeed/tops/includes/moe_gating.h create mode 100644 deepspeed/tops/includes/reduction_utils.h create mode 100644 deepspeed/tops/includes/rope.h create mode 100644 deepspeed/tops/includes/swiglu.h create mode 100644 deepspeed/tops/includes/tops_context.h create mode 100644 deepspeed/tops/includes/utils.h create mode 100644 deepspeed/tops/moe_gating/__init__.py create mode 100644 deepspeed/tops/moe_gating/moe_gather.py create mode 100644 
deepspeed/tops/moe_gating/moe_gating.cpp create mode 100644 deepspeed/tops/moe_gating/moe_gating.py create mode 100644 deepspeed/tops/moe_gating/test_moe_gating.py create mode 100644 deepspeed/tops/moe_gating/top1_moe_gating.cu create mode 100644 deepspeed/tops/moe_gating/top2_moe_gating.cu create mode 100644 deepspeed/tops/rope/__init__.py create mode 100644 deepspeed/tops/rope/rope-test.py create mode 100644 deepspeed/tops/rope/rope.cpp create mode 100644 deepspeed/tops/rope/rope.cu create mode 100644 deepspeed/tops/rope/rope.py create mode 100644 deepspeed/tops/swiglu/__init__.py create mode 100644 deepspeed/tops/swiglu/swiglu.cpp create mode 100644 deepspeed/tops/swiglu/swiglu.cu create mode 100644 deepspeed/tops/swiglu/swiglu.py create mode 100644 deepspeed/tops/swiglu/test_swiglu.py create mode 100644 deepspeed/tops/tops.cpp create mode 100644 op_builder/tops.py diff --git a/deepspeed/tops/__init__.py b/deepspeed/tops/__init__.py new file mode 100644 index 000000000000..c31656310f98 --- /dev/null +++ b/deepspeed/tops/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +############################# +######## Training Ops ####### +############################# + +from .moe_gating import * +from .swiglu import * +from .rope import * \ No newline at end of file diff --git a/deepspeed/tops/includes/moe_gating.h b/deepspeed/tops/includes/moe_gating.h new file mode 100644 index 000000000000..9b8fcdf467ea --- /dev/null +++ b/deepspeed/tops/includes/moe_gating.h @@ -0,0 +1,62 @@ +#pragma once + +#include +#include +#include "moe_gating.cuh" + +void gate_scatter(torch::Tensor& moe_input, + torch::Tensor& expert_count_cumsums, + torch::Tensor& mapped_slots, + torch::Tensor& activations, + torch::Tensor& expert_counts, + torch::Tensor& mapped_expert_counts, + torch::Tensor& scores, + torch::Tensor& assignments, + torch::Tensor& offsets, + torch::Tensor& backup_offsets, + int top_k, + int capacity, + bool use_rts); + +void gate_fwd(torch::Tensor& moe_input, + torch::Tensor& expert_count_cumsums, + torch::Tensor& mapped_slots, + torch::Tensor& activations, + torch::Tensor& expert_counts, + torch::Tensor& mapped_expert_counts, + torch::Tensor& scores, + torch::Tensor& assignments, + torch::Tensor& offsets, + torch::Tensor& backup_offsets, + torch::Tensor& logits, + torch::Tensor& logits_out, + int top_k, + int capacity, + bool use_rts); + +void gate_bwd(torch::Tensor& moe_input_grad, + torch::Tensor& scores_grad, + torch::Tensor& activations_grad, + torch::Tensor& logits_grad, + torch::Tensor& logits, + torch::Tensor& assignments, + torch::Tensor& offsets, + torch::Tensor& mapped_slots, + int top_k, + int capacity, + bool use_rts); + + +void gather_fwd(torch::Tensor& layer_output, + torch::Tensor& moe_output, + torch::Tensor& scores, + torch::Tensor& mapped_slots, + int top_k); + +void gather_bwd(torch::Tensor& layer_output_grad, + torch::Tensor& scores_grad, + torch::Tensor& moe_output_grad, + torch::Tensor& moe_output, + torch::Tensor& scores, + torch::Tensor& mapped_slots, + int top_k); diff --git a/deepspeed/tops/includes/reduction_utils.h b/deepspeed/tops/includes/reduction_utils.h new file mode 100644 index 000000000000..3469bea2a22a --- /dev/null +++ b/deepspeed/tops/includes/reduction_utils.h @@ -0,0 +1,818 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "utils.h" + +namespace cg = cooperative_groups; + +namespace reduce { + +enum class ROpType { + // Addition + Add, + + // Maximum reduction + Max, + + // Minimum reduction + Min, +}; + +constexpr int max_threads = 1024; +constexpr int max_warps = max_threads / hw_warp_size; + +/* +High level API. The API takes in a set of operations and variables +and performs that reduction operation on that variable. The reductions +of each of the arguments are completely independent of each other ( +i.e., the val1-op1 combination has no impact on val2-op2). + +Example usage: +``` cpp +float max_val; +float min_val; +reduce::block(tb, warp, max_val, min_val); +``` + +TODO(cmikeh2): In theory, we might be able to do this sequentially with +device functions and rely on the assembler correctly behaving. My initial +instinct is this won't work, but if it does it would reduce implementation +cost significantly. + +TODO(cmikeh2): We need to support sub-block reductions. The warp intrinsic +currently supports this (more incidentally than anything else). It is not +uncommon in something like softmax or a fused attention kernel to map multiple +reductions to a thread block, but each reduction itself is only scoped +to part of the threads (i.e block size = 512, 128 threads per reduction). +*/ +template +DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile& warp, float& val); + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2); + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3); + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3, + float& val4); + +/* +The partitioned block is a special case of the above where in the warps of a threadblock are +partitioned into separate independent reductions. For example, I might have an 8 warp thread block +in which each pair of warps is processing an independent piece of data. I would then reduce that +data with the something like the following: +``` cpp +float max_val; +reduce::partitioned_block(tb, warp, max_val); +``` +After which, each pair of warps would have coherent data with each other. Note, this API will not +provide correct results if the number of warps per partition is not a power of 2. +*/ +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val); + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2); + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3); + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3, + float& val4); + +/* +Single element reduction primitives. Used inside serial collection +loops. + +Example usage: +using rop = reduce::OpType; +float min = init(); +for (int i = 0; i < 4; i++) { + min = reduce::element(min, data[i]); +} +*/ + +template +DS_D_INLINE T element(const T lhs, const T rhs); + +template +DS_D_INLINE T init(); + +/********************** Internal reduction APIs **********************/ + +/* +Single element "reductions". 
TODO(cmikeh2): this sort of "op" concept +should be refactored into its own implementation at some point. This interface +may be easily expanded for new types/operations, but the typical reductions +we need are covered with min/max/add on float. + +NOTE: there is no mean reduction because that relies on knowledge of how +many values were already reduced into each scalar. Implementing this on top +of reduce should be straightforward (can just wrap the sum reduction) and +would be a good extension of the header. +*/ + +DS_D_INLINE int _warp_rank() +{ + const int thread_rank = + threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y; + return thread_rank / hw_warp_size; +} + +/* Float element reduce implementations */ +template <> +DS_D_INLINE float element(const float lhs, const float rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE double element(const double lhs, const double rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE float element(const float lhs, const float rhs) +{ + return fmaxf(lhs, rhs); +} + +template <> +DS_D_INLINE float element(const float lhs, const float rhs) +{ + return fminf(lhs, rhs); +} + +/* __half element reduce implementation */ +template <> +DS_D_INLINE __half element(const __half lhs, const __half rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE __half element(const __half lhs, const __half rhs) +{ +#if __CUDA_ARCH__ >= 800 + // Intrinsic limited to Ampere + newer + return __hmax(lhs, rhs); +#else + return (lhs > rhs) ? lhs : rhs; +#endif +} + +template <> +DS_D_INLINE __nv_bfloat16 element(const __nv_bfloat16 lhs, const __nv_bfloat16 rhs) +{ +#if __CUDA_ARCH__ >= 800 + // Intrinsic limited to Ampere + newer + return __hmax(lhs, rhs); +#else + return (lhs > rhs) ? lhs : rhs; +#endif +} + +template <> +DS_D_INLINE __half element(const __half lhs, const __half rhs) +{ +#if __CUDA_ARCH__ >= 800 + // Intrinsic limited to Ampere + newer + return __hmin(lhs, rhs); +#else + return (lhs < rhs) ? lhs : rhs; +#endif +} + +/* __half2 element reduce implementation */ +template <> +DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) +{ +#if __CUDA_ARCH__ >= 800 + return __hmax2(lhs, rhs); +#else + __half2 ret_val; + ret_val.x = (lhs.x > rhs.x) ? lhs.x : rhs.x; + ret_val.y = (lhs.y > rhs.y) ? lhs.y : rhs.y; + return ret_val; +#endif +} + +template <> +DS_D_INLINE __nv_bfloat162 element(const __nv_bfloat162 lhs, const __nv_bfloat162 rhs) +{ +#if __CUDA_ARCH__ >= 800 + return __hmax2(lhs, rhs); +#else + __nv_bfloat162 ret_val; + ret_val.x = (lhs.x > rhs.x) ? lhs.x : rhs.x; + ret_val.y = (lhs.y > rhs.y) ? lhs.y : rhs.y; + return ret_val; +#endif +} + +template <> +DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) +{ +#if __CUDA_ARCH__ >= 800 + return __hmin2(lhs, rhs); +#else + __half2 ret_val; + ret_val.x = (lhs.x < rhs.x) ? lhs.x : rhs.x; + ret_val.y = (lhs.y < rhs.y) ? lhs.y : rhs.y; + return ret_val; +#endif +} + +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return (lhs < rhs) ? 
lhs : rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return (lhs < rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return (lhs < rhs) ? lhs : rhs; +} + +/* +Reduction initialization primitives +*/ +template <> +DS_D_INLINE float init() +{ + return 0.0f; +} +template <> +DS_D_INLINE double init() +{ + return (double)0.0f; +} + +template <> +DS_D_INLINE float init() +{ + // Positive infinity + return INFINITY; +} + +template <> +DS_D_INLINE float init() +{ + // Negative infinity + return -INFINITY; +} + +template <> +DS_D_INLINE __half init() +{ + constexpr __half_raw zero = {0x0000}; + return __half(zero); +} + +template <> +DS_D_INLINE __half init() +{ + constexpr __half_raw inf = {0x7C00}; + return __half(inf); +} + +template <> +DS_D_INLINE __half init() +{ + constexpr __half_raw neg_inf = {0xFC00}; + return __half(neg_inf); +} + +template <> +DS_D_INLINE __nv_bfloat16 init() +{ + constexpr __nv_bfloat16_raw neg_inf = {0xFF80}; + return __nv_bfloat16(neg_inf); +} + +template <> +DS_D_INLINE __half2 init() +{ +#ifdef __HIP_PLATFORM_AMD__ + return __half2{_Float16_2{0x0000, 0x0000}}; +#else + constexpr __half2_raw zero = {0x0000, 0x0000}; + return __half2(zero); +#endif +} + +template <> +DS_D_INLINE __half2 init() +{ +#ifdef __HIP_PLATFORM_AMD__ + return __half2{_Float16_2{0x7C00, 0x7C00}}; +#else + constexpr __half2_raw inf = {0x7C00, 0x7C00}; + return __half2(inf); +#endif +} + +template <> +DS_D_INLINE __half2 init() +{ +#ifdef __HIP_PLATFORM_AMD__ + return __half2{_Float16_2{0xFC00, 0xFC00}}; +#else + constexpr __half2_raw neg_inf = {0xFC00, 0xFC00}; + return __half2(neg_inf); +#endif +} + +template <> +DS_D_INLINE int32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int32_t init() +{ + return 0x7FFFFFFF; +} + +template <> +DS_D_INLINE int32_t init() +{ + return 0x80000000; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0xFFFFFFFF; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0x7FFFFFFFFFFFFFFF; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0x8000000000000000; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0xFFFFFFFFFFFFFFFF; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0; +} + +template +DS_D_INLINE void init(T* data) +{ + data[0] = init(); +} + +template +DS_D_INLINE void init(T* data) +{ + data[0] = init(); + data[1] = init(); +} + +template +DS_D_INLINE void init(T* data) +{ + data[0] = init(); + data[1] = init(); + data[2] = init(); +} + +template +DS_D_INLINE void init(T* data) +{ + data[0] = init(); + data[1] = init(); + data[2] = init(); + data[3] = init(); +} + +/* +Warp reduction primitives + +`reduction_width` is an unsafe template parameter, that is that +when using `reduction_width` < 
hw_warp_size the warp is partitioned +into `hw_warp_size` / `reduction_width` groups of partial sums. + +If someone can figure out how to use variadic templates in a reasonable way +here (fold is C++17 only and I don't think helps and recursion feels like +huge overkill that harms readability) that would be wonderful. +*/ + +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) +{ +#pragma unroll + for (int i = 1; i < reduce_width; i *= 2) { + data[0] = element(data[0], warp.shfl_xor(data[0], i)); + } +} + +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) +{ +#pragma unroll + for (int i = 1; i < reduce_width; i *= 2) { + data[0] = element(data[0], warp.shfl_xor(data[0], i)); + data[1] = element(data[1], warp.shfl_xor(data[1], i)); + } +} + +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) +{ +#pragma unroll + for (int i = 1; i < reduce_width; i *= 2) { + data[0] = element(data[0], warp.shfl_xor(data[0], i)); + data[1] = element(data[1], warp.shfl_xor(data[1], i)); + data[2] = element(data[2], warp.shfl_xor(data[2], i)); + } +} + +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) +{ +#pragma unroll + for (int i = 1; i < reduce_width; i *= 2) { + data[0] = element(data[0], warp.shfl_xor(data[0], i)); + data[1] = element(data[1], warp.shfl_xor(data[1], i)); + data[2] = element(data[2], warp.shfl_xor(data[2], i)); + data[3] = element(data[3], warp.shfl_xor(data[3], i)); + } +} + +/* +Implementation for primary block reduction that serves both `block` and +`partitioned_block`. + +Total warps refers to the reduction width of the reduction, not +the number of warps in the block (which may exceed that +if the block is partitioned or if we do a conservative bound at +compile time). +*/ +template +DS_D_INLINE void _block(cg::thread_block& tb, + cg::thread_block_tile& warp_arg, + T* data) +{ + constexpr int elems = sizeof...(Ops); + constexpr int bytes = sizeof(T); + // Unused when `partition_size == 1` or total_warps == 1 + __shared__ T reduce_buffer[max_warps * elems]; + +#ifdef __HIP_PLATFORM_AMD__ + const int total_threads = blockDim.x * blockDim.y * blockDim.z; + const int running_warps = total_threads / hw_warp_size; +#else + const int running_warps = warp_arg.meta_group_size(); +#endif + + // Always perform warp-scope reduction + _warp(warp_arg, data); + + // If max_warps == 1 let's skip the runtime check + if (total_warps != 1) { + if (warp_arg.thread_rank() == 0) { +#pragma unroll + for (int i = 0; i < elems; i++) { + mem_access::store_shared(reduce_buffer + elems * _warp_rank() + i, data + i); + } + } + + // Synchronization inside block-uniform conditional is safe + tb.sync(); + + if (_warp_rank() == 0) { + if (warp_arg.thread_rank() < running_warps) { +#pragma unroll + for (int i = 0; i < elems; i++) { + mem_access::load_shared( + data + i, reduce_buffer + elems * warp_arg.thread_rank() + i); + } + } else { + init(data); + } + + _warp(warp_arg, data); + +#pragma unroll + for (int i = 0; i < elems; i++) { + mem_access::store_shared(reduce_buffer + elems * warp_arg.thread_rank() + i, + data + i); + } + } + + // Synchronization inside block-uniform conditional is safe + tb.sync(); + +#pragma unroll + for (int i = 0; i < elems; i++) { + mem_access::load_shared(data + i, reduce_buffer + _warp_rank() * elems + i); + } + } +} + +/* +Main API implementations. For the most part, they just convert the individual +variables into arrays, which makes working with them easier with a single +implementation. 
In theory, we could use the `_block` implementation as another +option, but the nature of using a pointer is a little less safe and this allows +us to obfuscate the details of the partitioned implementation. +*/ +template +DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile& warp, float& val) +{ + _block(tb, warp, &val); +} + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2) +{ + float data[2] = {val1, val2}; + _block(tb, warp, data); + val1 = data[0]; + val2 = data[1]; +} + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3) +{ + float data[3] = {val1, val2, val3}; + _block(tb, warp, data); + val1 = data[0]; + val2 = data[1]; + val3 = data[2]; +} + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3, + float& val4) +{ + float data[4] = {val1, val2, val3, val4}; + _block(tb, warp, data); + val1 = data[0]; + val2 = data[1]; + val3 = data[2]; + val4 = data[3]; +} + +/* +Note: for the partitioned blocks, the implementation does not support non-power of 2 blocks in order +to shorten block scale reduction length. +*/ +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val) +{ + if (num_threads <= hw_warp_size) { + _warp(warp, &val); + } else { + constexpr int num_warps = num_threads / hw_warp_size; + _block(tb, warp, &val); + } +} + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2) +{ + float data[2] = {val1, val2}; + + if (num_threads <= hw_warp_size) { + _warp(warp, data); + } else { + constexpr int num_warps = num_threads / hw_warp_size; + _block(tb, warp, data); + } + + val1 = data[0]; + val2 = data[1]; +} + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3) +{ + float data[3] = {val1, val2, val3}; + + if (num_threads <= hw_warp_size) { + _warp(warp, data); + } else { + constexpr int num_warps = num_threads / hw_warp_size; + _block(tb, warp, data); + } + + val1 = data[0]; + val2 = data[1]; + val3 = data[2]; +} + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3, + float& val4) +{ + float data[4] = {val1, val2, val3, val4}; + + if (num_threads <= hw_warp_size) { + _warp(warp, data); + } else { + constexpr int num_warps = num_threads / hw_warp_size; + _block(tb, warp, data); + } + + val1 = data[0]; + val2 = data[1]; + val3 = data[2]; + val4 = data[3]; +} + +/* +Arg-reduce is a specialization of the above. We only support this with a single reduction +parameter. This only works for max/min reductions. +*/ + +__align__(8) struct IdxReduceResult { + /* + NOTE: ORDERING MATTERS HERE! The idx is the least significant set of bits + and the val is the most significant. Changing the order of this declaration + will break the code. + */ + int idx; + float val; +}; + +template +DS_D_INLINE IdxReduceResult +idx_reduce(cg::thread_block& tb, cg::thread_block_tile& warp, float val, int idx) +{ + IdxReduceResult res = {idx, val}; + + // Clear out the nan. This shouldn't be an issue for our initial applications + if (isnan(val)) res.val = init(); + + // Can do float compares as integers. 
By packing the index into the lower bits + // we can just do a single int64 rather than a branch, compare, and select. + // One side benefit of this is that it is by nature a stable algorithm and + // will always bias ties to the higher index. + int64_t* res_as_int = reinterpret_cast(&res); + + // The way floating point compare works is normally to perform a sign comparison + // and if they match, then do a comparison of the rest of the bits as unsigned + // integers. Since we are bundling these, that means for negative values we need + // to reverse the sort order, which we can do with an XOR. + if (val < 0) { *res_as_int ^= 0x7fffffff00000000; } + + _block(tb, warp, res_as_int); + + // Sign bit is preserved, so we can check if we need to invert the mantissa back + if (res.val < 0) { *res_as_int ^= 0x7fffffff00000000; } + + return res; +} + +} // namespace reduce diff --git a/deepspeed/tops/includes/rope.h b/deepspeed/tops/includes/rope.h new file mode 100644 index 000000000000..f3005c3021af --- /dev/null +++ b/deepspeed/tops/includes/rope.h @@ -0,0 +1,7 @@ + +#include +#include +#include "rope.cuh" + +void rope_fwd(torch::Tensor& query, torch::Tensor& key, int rotary_dim, float rope_theta); +void rope_bwd(torch::Tensor& query_grad, torch::Tensor& key_grad, int rotary_dim, float rope_theta); diff --git a/deepspeed/tops/includes/swiglu.h b/deepspeed/tops/includes/swiglu.h new file mode 100644 index 000000000000..76a80f17f30c --- /dev/null +++ b/deepspeed/tops/includes/swiglu.h @@ -0,0 +1,6 @@ +#include +#include +#include "swiglu.cuh" + +void swiglu_fwd(torch::Tensor& inp, torch::Tensor& out); +void swiglu_bwd(torch::Tensor& inp, torch::Tensor& out_grad, torch::Tensor& inp_grad); diff --git a/deepspeed/tops/includes/tops_context.h b/deepspeed/tops/includes/tops_context.h new file mode 100644 index 000000000000..7a5b466f42c6 --- /dev/null +++ b/deepspeed/tops/includes/tops_context.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include +#include "cuda.h" +#include "curand.h" +#include + +class TOPSContext { +public: + TOPSContext() : _seed(42), _curr_offset(0) + { + curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT); + curandSetPseudoRandomGeneratorSeed(_gen, 123); + } + + virtual ~TOPSContext() + { + } + + static TOPSContext& Instance() + { + static TOPSContext _ctx; + return _ctx; + } + + curandGenerator_t& GetRandGenerator() { return _gen; } + + std::pair IncrementOffset(uint64_t offset_inc) + { + uint64_t offset = _curr_offset; + _curr_offset += offset_inc; + return std::pair(_seed, offset); + } + + void SetSeed(uint64_t new_seed) { _seed = new_seed; } + + +private: + curandGenerator_t _gen; + uint64_t _seed; + uint64_t _curr_offset; +}; \ No newline at end of file diff --git a/deepspeed/tops/includes/utils.h b/deepspeed/tops/includes/utils.h new file mode 100644 index 000000000000..a8d65df2c9f8 --- /dev/null +++ b/deepspeed/tops/includes/utils.h @@ -0,0 +1,1803 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Centralized header file for preprocessor macros and constants +used throughout the codebase. 
+*/ + +#pragma once + +#include +#include + +#include + +#define DS_HD_INLINE __host__ __device__ __forceinline__ +#define DS_D_INLINE __device__ __forceinline__ + +// constexpr variant of warpSize for templating +constexpr int hw_warp_size = 32; + +#define HALF_PRECISION_AVAILABLE = 1 +#define PTX_AVAILABLE + +#define ASYNC_COPY_AVAILABLE + +#include + +#include + +inline int next_pow2(const int val) +{ + int rounded_val = val - 1; + rounded_val |= rounded_val >> 1; + rounded_val |= rounded_val >> 2; + rounded_val |= rounded_val >> 4; + rounded_val |= rounded_val >> 8; + return rounded_val + 1; +} + + + +namespace conversion { + +// Basic primitive for constructing conversions +template +DS_D_INLINE TO to(FROM val) +{ + return to(val); +} + +// Specializations + +/********************* Identity Conversions *********************/ +/* +Identity conversions are useful in templated functions where we might have +a fixed destination type. For example, I might have a kernel that accepts +__half, __nv_bfloat16, and float but always want to do the core computation +at floating point: + +T mem_value = input[idx]; +float compute_value = conversion::to(mem_value); + +In practice, we should be able to elide the second template parameter: +float compute_val = conversion::to(mem_value); + +In this case, we need an implementation to handle the T = float case + +NOTE: The type inferencing system appears to be unable to handle inferring the first +template parameter, even in the trivial case. +*/ + +// Floating point types +template <> +DS_D_INLINE double to(double val) +{ + return val; +} +template <> +DS_D_INLINE float to(float val) +{ + return val; +} +template <> +DS_D_INLINE __half to(__half val) +{ + return val; +} +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat16 to(__nv_bfloat16 val) +{ + return val; +} +#endif + +// Integer types +template <> +DS_D_INLINE int8_t to(int8_t val) +{ + return val; +} +template <> +DS_D_INLINE uint8_t to(uint8_t val) +{ + return val; +} +template <> +DS_D_INLINE int16_t to(int16_t val) +{ + return val; +} +template <> +DS_D_INLINE uint16_t to(uint16_t val) +{ + return val; +} +template <> +DS_D_INLINE int32_t to(int32_t val) +{ + return val; +} +template <> +DS_D_INLINE uint32_t to(uint32_t val) +{ + return val; +} +template <> +DS_D_INLINE int64_t to(int64_t val) +{ + return val; +} +template <> +DS_D_INLINE uint64_t to(uint64_t val) +{ + return val; +} + +// TODO: evaluate if we want bools + +/********************* To Double Conversions *********************/ + +// * to double variants + +// Would normally like to not use C cast, but this is an important enough conversion +// to keep +template <> +DS_D_INLINE double to(float val) +{ +#ifdef PTX_AVAILABLE + double ret_val; + asm("ctv.rn.f64.f32 %0, %1;\n" : "=d"(ret_val) : "f"(val)); + return ret_val; +#else + return double(val); +#endif +} +// Note: there is a CVT instruction for __half -> double, but there's no inline interface +// for passing a single half value +template <> +DS_D_INLINE double to(__half val) +{ + return to(__half2float(val)); +} +//template <> +//DS_D_INLINE double to(int64_t val) +//{ +// return __ll2double_rn(val); +//} +template <> +DS_D_INLINE double to(int32_t val) +{ + return __int2double_rn(val); +} +template <> +DS_D_INLINE double to(int16_t val) +{ + return __int2double_rn(val); +} +template <> +DS_D_INLINE double to(int8_t val) +{ + return __int2double_rn(val); +} +template <> +DS_D_INLINE double to(uint64_t val) +{ + return __ull2double_rn(val); +} +template <> 
+DS_D_INLINE double to(uint32_t val) +{ + return __uint2double_rn(val); +} +template <> +DS_D_INLINE double to(uint16_t val) +{ + return __uint2double_rn(val); +} +template <> +DS_D_INLINE double to(uint8_t val) +{ + return __uint2double_rn(val); +} + +// Same applies here +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE double to(__nv_bfloat16 val) +{ + return to(__bfloat162float(val)); +} +#endif + +/********************* To Float Conversions *********************/ + +template <> +DS_D_INLINE float to(double val) +{ + return __double2float_rn(val); +} +template <> +DS_D_INLINE float to(__half val) +{ + return __half2float(val); +} +template <> +DS_D_INLINE float to(int64_t val) +{ + return __ll2float_rn(val); +} +template <> +DS_D_INLINE float to(int32_t val) +{ + return __int2float_rn(val); +} +template <> +DS_D_INLINE float to(int16_t val) +{ + return __int2float_rn(val); +} +template <> +DS_D_INLINE float to(int8_t val) +{ + return __int2float_rn(val); +} +template <> +DS_D_INLINE float to(uint64_t val) +{ + return __ull2float_rn(val); +} +template <> +DS_D_INLINE float to(uint32_t val) +{ + return __uint2float_rn(val); +} +template <> +DS_D_INLINE float to(uint16_t val) +{ + return __uint2float_rn(val); +} +template <> +DS_D_INLINE float to(uint8_t val) +{ + return __uint2float_rn(val); +} + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE float to(__nv_bfloat16 val) +{ + return __bfloat162float(val); +} +#endif + +/********************* To Float2 Conversions *********************/ +template <> +DS_D_INLINE float2 to(__half2 val) +{ + return __half22float2(val); +} + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE float2 to(__nv_bfloat162 val) +{ + return __bfloat1622float2(val); +} +#endif + +/********************* To Half Conversions *********************/ +template <> +DS_D_INLINE __half to(double val) +{ +#ifdef __HIP_PLATFORM_AMD__ + float val_f = __double2float_rn(val); + return __float2half(val_f); +#else + return __double2half(val); +#endif +} +template <> +DS_D_INLINE __half to(float val) +{ + return __float2half(val); +} +template <> +DS_D_INLINE __half to(int64_t val) +{ + return __ll2half_rn(val); +} +template <> +DS_D_INLINE __half to(int32_t val) +{ + return __int2half_rn(val); +} +template <> +DS_D_INLINE __half to(int16_t val) +{ + return __short2half_rn(val); +} +template <> +DS_D_INLINE __half to(int8_t val) +{ + return __int2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint64_t val) +{ + return __ull2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint32_t val) +{ + return __uint2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint16_t val) +{ + return __ushort2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint8_t val) +{ + return __uint2half_rn(val); +} + +#ifdef BF16_AVAILABLE +// No direct conversion +template <> +DS_D_INLINE __half to(__nv_bfloat16 val) +{ + return to<__half>(to(val)); +} +#endif + +/********************* To Half2 Conversions *********************/ +template <> +DS_D_INLINE __half2 to(float2 val) +{ + return __float22half2_rn(val); +} +template <> +DS_D_INLINE __half2 to(float val) +{ + return __float2half2_rn(val); +} + +#ifdef BF16_AVAILABLE +// No direct conversion +template <> +DS_D_INLINE __half2 to(__nv_bfloat162 val) +{ + return to<__half2>(to(val)); +} +#endif + +/********************* To BF16 Conversions *********************/ +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat16 to(double val) +{ + return __double2bfloat16(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(float val) +{ + return 
__float2bfloat16(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(int64_t val) +{ + return __ll2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(int32_t val) +{ + return __int2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(int16_t val) +{ + return __short2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(int8_t val) +{ + return __int2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint64_t val) +{ + return __ull2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint32_t val) +{ + return __uint2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint16_t val) +{ + return __ushort2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint8_t val) +{ + return __uint2bfloat16_rn(val); +} +#endif + +/********************* To BF162 Conversions *********************/ +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat162 to(float2 val) +{ + return __float22bfloat162_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat162 to(float val) +{ + return __float2bfloat162_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat162 to(__half2 val) +{ + return to<__nv_bfloat162>(to(val)); +} +#endif + +/********************* To INT64_T Conversions *********************/ +template <> +DS_D_INLINE int64_t to(double val) +{ + return __double2ll_rn(val); +} +template <> +DS_D_INLINE int64_t to(float val) +{ + return __float2ll_rn(val); +} +template <> +DS_D_INLINE int64_t to(__half val) +{ + return __half2ll_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int64_t to(__nv_bfloat16 val) +{ + return __bfloat162ll_rn(val); +} +#endif + +/********************* To INT32_T Conversions *********************/ +template <> +DS_D_INLINE int32_t to(double val) +{ + return __double2int_rn(val); +} +template <> +DS_D_INLINE int32_t to(float val) +{ + return __float2int_rn(val); +} +template <> +DS_D_INLINE int32_t to(__half val) +{ + return __half2int_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int32_t to(__nv_bfloat16 val) +{ + return __bfloat162int_rn(val); +} +#endif + +/********************* To INT16_T Conversions *********************/ +template <> +DS_D_INLINE int16_t to(double val) +{ + return __double2int_rn(val); +} +template <> +DS_D_INLINE int16_t to(float val) +{ + return __float2int_rn(val); +} +template <> +DS_D_INLINE int16_t to(__half val) +{ + return __half2int_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int16_t to(__nv_bfloat16 val) +{ + return __bfloat162int_rn(val); +} +#endif + +/********************* To INT8_T Conversions *********************/ +template <> +DS_D_INLINE int8_t to(double val) +{ + return __double2int_rn(val); +} +template <> +DS_D_INLINE int8_t to(float val) +{ + return __float2int_rn(val); +} +template <> +DS_D_INLINE int8_t to(__half val) +{ + return __half2int_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int8_t to(__nv_bfloat16 val) +{ + return __bfloat162int_rn(val); +} +#endif + +/********************* To 
UINT64_T Conversions *********************/ +template <> +DS_D_INLINE uint64_t to(double val) +{ + return __double2ull_rn(val); +} +template <> +DS_D_INLINE uint64_t to(float val) +{ + return __float2ull_rn(val); +} +template <> +DS_D_INLINE uint64_t to(__half val) +{ + return __half2ull_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint64_t to(__nv_bfloat16 val) +{ + return __bfloat162ull_rn(val); +} +#endif + +/********************* To UINT32_T Conversions *********************/ +template <> +DS_D_INLINE uint32_t to(double val) +{ + return __double2uint_rn(val); +} +template <> +DS_D_INLINE uint32_t to(float val) +{ + return __float2uint_rn(val); +} +template <> +DS_D_INLINE uint32_t to(__half val) +{ + return __half2uint_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint32_t to(__nv_bfloat16 val) +{ + return __bfloat162uint_rn(val); +} +#endif + +/********************* To UINT16_T Conversions *********************/ +template <> +DS_D_INLINE uint16_t to(double val) +{ + return __double2uint_rn(val); +} +template <> +DS_D_INLINE uint16_t to(float val) +{ + return __float2uint_rn(val); +} +template <> +DS_D_INLINE uint16_t to(__half val) +{ + return __half2uint_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint16_t to(__nv_bfloat16 val) +{ + return __bfloat162uint_rn(val); +} +#endif + +/********************* To UINT8_T Conversions *********************/ +template <> +DS_D_INLINE uint8_t to(double val) +{ + return __double2uint_rn(val); +} +template <> +DS_D_INLINE uint8_t to(float val) +{ + return __float2uint_rn(val); +} +template <> +DS_D_INLINE uint8_t to(__half val) +{ + return __half2uint_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint8_t to(__nv_bfloat16 val) +{ + return __bfloat162uint_rn(val); +} +#endif + +} // namespace conversion + +/////////////////////////////// Memory Access Utils /////////////////////////////// +namespace mem_access { + +enum class LoadPolicy { + CacheAll, // Cache at all levels + CacheGlobal, // Cache at L2 only + CacheStreaming // Cache with evict first policy +}; + +enum class StorePolicy { + Writeback, // Cache in L1, write-back on eviction + CacheGlobal, // Bypass L1, write-back on eviction + CacheStreaming // Allocate cache line with evict first policy +}; + +template +__device__ __forceinline__ void load_global(void* dst, const void* src); + +template +__device__ __forceinline__ void load_global(void* dst, const void* src, bool do_access); + +// Shared accesses have no cache policy +template +__device__ __forceinline__ void load_shared(void* dst, const void* src); + +template +__device__ __forceinline__ void load_shared(void* dst, const void* src, bool do_access); + +template +__device__ __forceinline__ void store_global(void* dst, const void* src); + +// Shared accesses have no cache policy +template +__device__ __forceinline__ void store_shared(void* dst, const void* src); + +#ifdef ASYNC_COPY_AVAILABLE +template +__device__ __forceinline__ void memcpy_async(void* shr, 
const void* gbl); + +template +__device__ __forceinline__ void memcpy_async_nop(void* shr, const void* gbl, bool predicate); + +template +__device__ __forceinline__ void memcpy_async_zero(void* shr, const void* gbl, bool predicate); + +__device__ __forceinline__ void memcpy_async_fence(); + +template +__device__ __forceinline__ void memcpy_async_wait(); + +template +__device__ __forceinline__ void tail_complete_wait(int remaining_stages); +#endif + +// Util for tracking pipeline buffers +// TODO: Evaluate whether this should also be guarded by ASYNC_COPY_AVAILABLE +template +class BufferTracker { +public: + int current_state; + + __device__ __forceinline__ BufferTracker() : current_state(0) {} + + __device__ __forceinline__ int get() + { + int return_val = current_state++; + current_state = (current_state == max ? 0 : current_state); + return return_val; + } +}; + +__device__ __forceinline__ uint32_t lane_id() +{ +#ifdef PTX_AVAILABLE + unsigned int lane_id; + asm volatile("mov.u32 %0, %%laneid;" : "=r"(lane_id)); + return lane_id; +#else + return threadIdx.x & (warpSize - 1); // Portable +#endif +} + +/////////// Load Global /////////// +template <> +__device__ __forceinline__ void load_global<16>(void* dst, const void* src) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.ca.v4.u32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src)); +#else + const uint4* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<16>(void* dst, const void* src, bool do_access) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %5, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\tmov.b32 %2, 0;\n" + "\tmov.b32 %3, 0;\n" + "\t@p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src), "r"((int)do_access)); +#else + const uint4* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + data[0].z = 0; + data[0].w = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<16, LoadPolicy::CacheGlobal>(void* dst, const void* src) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src)); +#else + const uint4* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<16, LoadPolicy::CacheGlobal>(void* dst, + const void* src, + bool do_access) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %5, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\tmov.b32 %2, 0;\n" + "\tmov.b32 %3, 0;\n" + "\t@p ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src), "r"((int)do_access)); +#else + const uint4* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + data[0].z = 0; + data[0].w = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<16, LoadPolicy::CacheStreaming>(void* dst, + const void* src) +{ + uint4* data = 
reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src)); +#else + const uint4* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<16, LoadPolicy::CacheStreaming>(void* dst, + const void* src, + bool do_access) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %5, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\tmov.b32 %2, 0;\n" + "\tmov.b32 %3, 0;\n" + "\t@p ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src), "r"((int)do_access)); +#else + const uint4* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + data[0].z = 0; + data[0].w = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<8>(void* dst, const void* src) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.ca.v2.u32 {%0, %1}, [%2];\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src)); +#else + const uint2* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<8>(void* dst, const void* src, bool do_access) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %3, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\t@p ld.global.v2.u32 {%0, %1}, [%2];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src), "r"((int)do_access)); +#else + const uint2* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<8, LoadPolicy::CacheGlobal>(void* dst, const void* src) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cg.v2.u32 {%0, %1}, [%2];\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src)); +#else + const uint2* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<8, LoadPolicy::CacheGlobal>(void* dst, + const void* src, + bool do_access) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %3, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\t@p ld.global.cg.v2.u32 {%0, %1}, [%2];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src), "r"((int)do_access)); +#else + const uint2* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<8, LoadPolicy::CacheStreaming>(void* dst, + const void* src) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cs.v2.u32 {%0, %1}, [%2];\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src)); +#else + const uint2* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<8, LoadPolicy::CacheStreaming>(void* dst, + const void* src, + bool do_access) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" 
+ "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %3, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\t@p ld.global.cs.v2.u32 {%0, %1}, [%2];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src), "r"((int)do_access)); +#else + const uint2* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<4>(void* dst, const void* src) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.ca.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src)); +#else + const int32_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<4>(void* dst, const void* src, bool do_access) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.b32 %0, 0;\n" + "\t@p ld.global.u32 {%0}, [%1];\n" + "}\n" + : "=r"(data[0]) + : "l"(src), "r"((int)do_access)); +#else + const int32_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<4, LoadPolicy::CacheGlobal>(void* dst, const void* src) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cg.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src)); +#else + const int32_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<4, LoadPolicy::CacheGlobal>(void* dst, + const void* src, + bool do_access) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.b32 %0, 0;\n" + "\t@p ld.global.cg.u32 {%0}, [%1];\n" + "}\n" + : "=r"(data[0]) + : "l"(src), "r"((int)do_access)); +#else + const int32_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<4, LoadPolicy::CacheStreaming>(void* dst, + const void* src) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cs.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src)); +#else + const int32_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<4, LoadPolicy::CacheStreaming>(void* dst, + const void* src, + bool do_access) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.b32 %0, 0;\n" + "\t@p ld.global.cs.u32 {%0}, [%1];\n" + "}\n" + : "=r"(data[0]) + : "l"(src), "r"((int)do_access)); +#else + const int32_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<2>(void* dst, const void* src) +{ + int16_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.ca.u16 {%0}, [%1];\n" : "=h"(*data) : "l"(src)); +#else + const int16_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<2>(void* dst, const void* src, bool do_access) +{ + int16_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + 
"\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.u16 %0, 0;\n" + "\t@p ld.global.u16 {%0}, [%1];\n" + "}\n" + : "=h"(*data) + : "l"(src), "r"((int)do_access)); +#else + const int16_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<2, LoadPolicy::CacheGlobal>(void* dst, const void* src) +{ + int16_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cg.u16 {%0}, [%1];\n" : "=h"(*data) : "l"(src)); +#else + const int16_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<2, LoadPolicy::CacheGlobal>(void* dst, + const void* src, + bool do_access) +{ + int16_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.u16 %0, 0;\n" + "\t@p ld.global.cg.u16 {%0}, [%1];\n" + "}\n" + : "=h"(*data) + : "l"(src), "r"((int)do_access)); +#else + const int16_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<2, LoadPolicy::CacheStreaming>(void* dst, + const void* src) +{ + int16_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cs.u16 {%0}, [%1];\n" : "=h"(*data) : "l"(src)); +#else + const int16_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<2, LoadPolicy::CacheStreaming>(void* dst, + const void* src, + bool do_access) +{ + int16_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.u16 %0, 0;\n" + "\t@p ld.global.cs.u16 {%0}, [%1];\n" + "}\n" + : "=h"(*data) + : "l"(src), "r"((int)do_access)); +#else + const int16_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +/////////// Load Shared /////////// +namespace internal { + +#ifdef PTX_AVAILABLE +__device__ __forceinline__ unsigned convert_to_shared(const void* ptr) +{ +#if __CUDACC_VER_MAJOR__ >= 11 + // In CUDA 11 we have a builtin intrinsic + return __cvta_generic_to_shared(ptr); +#else + unsigned ret_val; + asm volatile( + "{\n" + "\t.reg .u64 p1;\n" + "\tcvta.to.shared.u64 p1, %1\n" + "\tcvt.u32.u64 %0, p1;\n" + "}\n" + : "=r"(ret_val) + : "l"(ptr)); + return ret_val; +#endif +} +#endif + +} // namespace internal + +template <> +__device__ __forceinline__ void load_shared<16>(void* dst, const void* src) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "r"(src_shr)); +#else + const uint4* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_shared<16>(void* dst, const void* src, bool do_access) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %5, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\tmov.b32 %2, 0;\n" + "\tmov.b32 %3, 0;\n" + "\t@p ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n" + "}\n" + : 
"=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "r"(src_shr), "r"((int)do_access)); +#else + const uint4* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + data[0].z = 0; + data[0].w = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_shared<8>(void* dst, const void* src) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "r"(src_shr)); +#else + const uint2* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_shared<8>(void* dst, const void* src, bool do_access) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %3, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\t@p ld.shared.v2.u32 {%0, %1}, [%2];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "r"(src_shr), "r"((int)do_access)); +#else + const uint2* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_shared<4>(void* dst, const void* src) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile("ld.shared.u32 {%0}, [%1];\n" : "=r"(*data) : "r"(src_shr)); +#else + const int32_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_shared<4>(void* dst, const void* src, bool do_access) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.b32 %0, 0;\n" + "\t@p ld.shared.u32 %0, [%1];\n" + "}\n" + : "=r"(data[0]) + : "r"(src_shr), "r"((int)do_access)); +#else + const int32_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +/////////// Store Global /////////// + +template <> +__device__ __forceinline__ void store_global<16>(void* dst, const void* src) +{ + const uint4* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.wb.v4.u32 [%0], {%1, %2, %3, %4};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w) + : "memory"); +#else + uint4* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<16, StorePolicy::CacheGlobal>(void* dst, + const void* src) +{ + const uint4* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cg.v4.u32 [%0], {%1, %2, %3, %4};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w) + : "memory"); +#else + uint4* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<16, StorePolicy::CacheStreaming>(void* dst, + const void* src) +{ + const uint4* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cs.v4.u32 [%0], {%1, %2, %3, %4};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w) + : "memory"); +#else + 
uint4* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<8>(void* dst, const void* src) +{ + const uint2* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.wb.v2.u32 [%0], {%1, %2};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y)); +#else + uint2* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<8, StorePolicy::CacheGlobal>(void* dst, + const void* src) +{ + const uint2* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cg.v2.u32 [%0], {%1, %2};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y)); +#else + uint2* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<8, StorePolicy::CacheStreaming>(void* dst, + const void* src) +{ + const uint2* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cs.v2.u32 [%0], {%1, %2};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y)); +#else + uint2* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<4>(void* dst, const void* src) +{ + const int32_t* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.wb.u32 [%0], %1;\n" : : "l"(dst), "r"(*data)); +#else + int32_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<4, StorePolicy::CacheGlobal>(void* dst, + const void* src) +{ + const int32_t* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cg.u32 [%0], %1;\n" : : "l"(dst), "r"(*data)); +#else + int32_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<4, StorePolicy::CacheStreaming>(void* dst, + const void* src) +{ + const int32_t* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cs.u32 [%0], %1;\n" : : "l"(dst), "r"(*data)); +#else + int32_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<2>(void* dst, const void* src) +{ + const int16_t* data = reinterpret_cast(src); + + int16_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +} + +template <> +__device__ __forceinline__ void store_global<2, StorePolicy::CacheGlobal>(void* dst, + const void* src) +{ + const int16_t* data = reinterpret_cast(src); + + int16_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +} + +template <> +__device__ __forceinline__ void store_global<2, StorePolicy::CacheStreaming>(void* dst, + const void* src) +{ + const int16_t* data = reinterpret_cast(src); + + int16_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +} +/////////// Store Shared /////////// + +template <> +__device__ __forceinline__ void store_shared<16>(void* dst, const void* src) +{ + const uint4* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + unsigned dst_int = internal::convert_to_shared(dst); + + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" + : + : "r"(dst_int), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w)); +#else + uint4* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_shared<8>(void* dst, const void* src) +{ + const uint2* data = 
reinterpret_cast(src); +#ifdef PTX_AVAILABLE + unsigned dst_int = internal::convert_to_shared(dst); + + asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" + : + : "r"(dst_int), "r"(data[0].x), "r"(data[0].y)); +#else + uint2* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_shared<4>(void* dst, const void* src) +{ + const int32_t* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + unsigned dst_int = internal::convert_to_shared(dst); + + asm volatile("st.shared.u32 [%0], %1;\n" : : "r"(dst_int), "r"(*data)); +#else + int32_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +/////////// Asynchronous Memory Copy /////////// + +#ifdef ASYNC_COPY_AVAILABLE +template +__device__ __forceinline__ void memcpy_async(void* shr, const void* gbl) +{ + static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16)); + unsigned shr_int = internal::convert_to_shared(shr); + + asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" + : + : "r"(shr_int), "l"(gbl), "n"(AccessSize)); +} + +template +__device__ __forceinline__ void memcpy_async_nop(void* shr, const void* gbl, bool predicate) +{ + static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16)); + unsigned shr_int = internal::convert_to_shared(shr); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" + : + : "r"((int)predicate), "r"(shr_int), "l"(gbl), "n"(AccessSize)); +} + +template +__device__ __forceinline__ void memcpy_async_zero(void* shr, const void* gbl, bool predicate) +{ + static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16)); + unsigned shr_int = internal::convert_to_shared(shr); + int bytes_to_copy = (predicate ? AccessSize : 0); + + asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n" + : + : "r"(shr_int), "l"(gbl), "n"(AccessSize), "r"(bytes_to_copy)); +} + +template +__device__ __forceinline__ void memcpy_async_zero_nop(void* shr, + const void* gbl, + bool zero_predicate, + bool nop_predicate) +{ + static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16)); + unsigned shr_int = internal::convert_to_shared(shr); + int bytes_to_copy = (zero_predicate ? AccessSize : 0); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3, %4;\n" + "}\n" + : + : "r"((int)nop_predicate), "r"(shr_int), "l"(gbl), "n"(AccessSize), "r"(bytes_to_copy)); +} + +// Cache global variants. Separate interface to require deliberate use of them. +__device__ __forceinline__ void memcpy_async_cg(void* shr, const void* gbl) +{ + unsigned shr_int = internal::convert_to_shared(shr); + + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" : : "r"(shr_int), "l"(gbl)); +} + +__device__ __forceinline__ void memcpy_async_nop_cg(void* shr, const void* gbl, bool predicate) +{ + unsigned shr_int = internal::convert_to_shared(shr); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], 16;\n" + "}\n" + : + : "r"((int)predicate), "r"(shr_int), "l"(gbl)); +} + +__device__ __forceinline__ void memcpy_async_zero_cg(void* shr, const void* gbl, bool predicate) +{ + unsigned shr_int = internal::convert_to_shared(shr); + int bytes_to_copy = (predicate ? 
16 : 0); + + asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n" + : + : "r"(shr_int), "l"(gbl), "r"(bytes_to_copy)); +} + +__device__ __forceinline__ void memcpy_async_zero_nop_cg(void* shr, + const void* gbl, + bool zero_predicate, + bool nop_predicate) +{ + unsigned shr_int = internal::convert_to_shared(shr); + int bytes_to_copy = (zero_predicate ? 16 : 0); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], 16, %3;\n" + "}\n" + : + : "r"((int)nop_predicate), "r"(shr_int), "l"(gbl), "r"(bytes_to_copy)); +} + +__device__ __forceinline__ void memcpy_async_fence() { asm volatile("cp.async.commit_group;\n"); } + +template +__device__ __forceinline__ void memcpy_async_wait() +{ + static_assert(stages <= 8); + + asm volatile("cp.async.wait_group %0;\n" : : "n"(stages)); +} + +// TODO: The tail complete should be a known compile time artifact, should try and induce this +// without all of the branches from the call-site. This is a hacky solution. +template <> +__device__ __forceinline__ void tail_complete_wait<1>(int remaining_stages) +{ + if (remaining_stages == 0) memcpy_async_wait<0>(); +} + +template <> +__device__ __forceinline__ void tail_complete_wait<2>(int remaining_stages) +{ + if (remaining_stages == 1) + memcpy_async_wait<1>(); + else if (remaining_stages == 0) + memcpy_async_wait<0>(); +} + +template <> +__device__ __forceinline__ void tail_complete_wait<3>(int remaining_stages) +{ + if (remaining_stages == 2) + memcpy_async_wait<2>(); + else if (remaining_stages == 1) + memcpy_async_wait<1>(); + else if (remaining_stages == 0) + memcpy_async_wait<0>(); +} + +template <> +__device__ __forceinline__ void tail_complete_wait<4>(int remaining_stages) +{ + if (remaining_stages == 3) + memcpy_async_wait<3>(); + else if (remaining_stages == 2) + memcpy_async_wait<2>(); + else if (remaining_stages == 1) + memcpy_async_wait<1>(); + else if (remaining_stages == 0) + memcpy_async_wait<0>(); +} + +template <> +__device__ __forceinline__ void tail_complete_wait<5>(int remaining_stages) +{ + if (remaining_stages == 4) + memcpy_async_wait<4>(); + else if (remaining_stages == 3) + memcpy_async_wait<3>(); + else if (remaining_stages == 2) + memcpy_async_wait<2>(); + else if (remaining_stages == 1) + memcpy_async_wait<1>(); + else if (remaining_stages == 0) + memcpy_async_wait<0>(); +} + +template <> +__device__ __forceinline__ void tail_complete_wait<6>(int remaining_stages) +{ + if (remaining_stages == 5) + memcpy_async_wait<5>(); + else if (remaining_stages == 4) + memcpy_async_wait<4>(); + else if (remaining_stages == 3) + memcpy_async_wait<3>(); + else if (remaining_stages == 2) + memcpy_async_wait<2>(); + else if (remaining_stages == 1) + memcpy_async_wait<1>(); + else if (remaining_stages == 0) + memcpy_async_wait<0>(); +} +#endif + +} // namespace mem_access diff --git a/deepspeed/tops/moe_gating/__init__.py b/deepspeed/tops/moe_gating/__init__.py new file mode 100644 index 000000000000..8839fb65a4c2 --- /dev/null +++ b/deepspeed/tops/moe_gating/__init__.py @@ -0,0 +1,2 @@ +from .moe_gating import MoEGating +from .moe_gather import MoEGather \ No newline at end of file diff --git a/deepspeed/tops/moe_gating/moe_gather.py b/deepspeed/tops/moe_gating/moe_gather.py new file mode 100644 index 000000000000..005a990a4ff4 --- /dev/null +++ b/deepspeed/tops/moe_gating/moe_gather.py @@ -0,0 +1,89 @@ + +import torch + +from typing import Tuple + +from deepspeed.ops.op_builder import TopsBuilder + +inf_module = None + 
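For orientation, here is a minimal usage sketch of the MoEGather module defined below, paired with MoEGating as exercised by test_moe_gating.py later in this patch. It is illustrative only (not part of the diff), assumes a CUDA build of the Tops ops, and uses a single weight matrix as a stand-in for a real expert MLP; the sizes mirror the MoEGating defaults.

    import torch
    from deepspeed.tops import MoEGating, MoEGather

    n_tokens, hidden_size, n_experts = 16384, 3072, 64
    hidden_states = torch.randn(n_tokens, hidden_size).bfloat16().cuda()
    logits = torch.randn(n_tokens, n_experts).cuda()
    expert_weight = torch.randn(hidden_size, hidden_size).bfloat16().cuda()  # stand-in expert MLP

    gate = MoEGating(top_k=1)
    gather = MoEGather(top_k=1)

    # First return is the aux loss when compute_aux_loss=True, otherwise the softmaxed logits.
    _, moe_input, scores, mapped_slots = gate(hidden_states, logits, capacity_factor=1.0)
    expert_out = torch.matmul(moe_input, expert_weight)      # per-slot expert compute
    layer_output = gather(expert_out, scores, mapped_slots)  # back to token order, scaled by gate scores

The gather side reads each token's expert output from its mapped slot, multiplies it by the token's gate score, and writes zeros for tokens that were dropped for exceeding expert capacity.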
+class MoEGatherFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, + moe_output, + scores, + mapped_slots, + is_grad_enabled, + top_k): + kernel = inf_module.moe_gather_fwd + + ctx.inp_shape = moe_output.shape + moe_output = moe_output.reshape(-1, moe_output.shape[-1]).contiguous() + top_k_tokens = scores.shape[0] + _, hidden_size = moe_output.shape + n_tokens = top_k_tokens // top_k + layer_output = torch.empty(n_tokens, hidden_size, dtype=moe_output.dtype, device=torch.cuda.current_device()) + kernel( + layer_output, + moe_output, + scores, + mapped_slots, + top_k + ) + ctx.top_k = top_k + if is_grad_enabled: + ctx.save_for_backward( + scores, + mapped_slots, + moe_output + ) + + return layer_output + + @staticmethod + def backward(ctx, layer_output_grad): + (scores, + mapped_slots, + moe_output) = ctx.saved_tensors + layer_output_grad = layer_output_grad.contiguous() + top_k = ctx.top_k + n_tokens, hidden_size = layer_output_grad.shape + kernel = inf_module.moe_gather_bwd + + moe_output_grad = torch.zeros(moe_output.shape, dtype=layer_output_grad.dtype, device=torch.cuda.current_device()) + scores_grad = torch.empty(n_tokens * top_k, dtype=scores.dtype, device=torch.cuda.current_device()) + + kernel( + layer_output_grad, + scores_grad, + moe_output_grad, + moe_output, + scores, + mapped_slots, + top_k, + ) + return moe_output_grad.reshape(ctx.inp_shape), scores_grad, None, None, None + +class MoEGather(torch.nn.Module): + + def __init__(self, logit_dtype=None, top_k=1, use_act_ckpting=False) -> None: + super(MoEGather, self).__init__() + global inf_module + if inf_module is None: + inf_module = TopsBuilder().load() + self.top_k = top_k + self.use_act_ckpting = use_act_ckpting + def forward(self, + moe_output: torch.Tensor, + scores: torch.Tensor, + mapped_slots: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + is_grad_enabled = self.use_act_ckpting and torch.is_grad_enabled() + return MoEGatherFunction.apply( + moe_output, + scores, + mapped_slots, + is_grad_enabled, + self.top_k + ) \ No newline at end of file diff --git a/deepspeed/tops/moe_gating/moe_gating.cpp b/deepspeed/tops/moe_gating/moe_gating.cpp new file mode 100644 index 000000000000..c99dd2ce98a1 --- /dev/null +++ b/deepspeed/tops/moe_gating/moe_gating.cpp @@ -0,0 +1,285 @@ +#include "moe_gating.h" + +#include + +#define DISPATCH_MOE_GATING(T_TYPE, C_TYPE) \ + if (activations.options().dtype() == torch::T_TYPE) { \ + if (top_k == 1) \ + launch_moe_gating((C_TYPE*)moe_input.data_ptr(), \ + (int32_t*)expert_count_cumsums.data_ptr(), \ + (int32_t*)mapped_slots.data_ptr(), \ + (const C_TYPE*)activations.data_ptr(), \ + (int32_t*)expert_counts.data_ptr(), \ + (int32_t*)mapped_expert_counts.data_ptr(), \ + (float*)scores.data_ptr(), \ + (int32_t*)assignments.data_ptr(), \ + (int32_t*)offsets.data_ptr(), \ + (int32_t*)backup_offsets.data_ptr(), \ + (float*)logits.data_ptr(), \ + (float*)logits_out.data_ptr(), \ + capacity, \ + n_tokens, \ + n_channels, \ + n_experts, \ + at::cuda::getCurrentCUDAStream()); \ + else \ + launch_top2_moe_gating((C_TYPE*)moe_input.data_ptr(), \ + (int32_t*)expert_count_cumsums.data_ptr(), \ + (int32_t*)mapped_slots.data_ptr(), \ + (const C_TYPE*)activations.data_ptr(), \ + (int32_t*)expert_counts.data_ptr(), \ + (int32_t*)mapped_expert_counts.data_ptr(), \ + (float*)scores.data_ptr(), \ + (int32_t*)assignments.data_ptr(), \ + (int32_t*)offsets.data_ptr(), \ + (int32_t*)backup_offsets.data_ptr(), \ + (float*)logits.data_ptr(), \ + 
(float*)logits_out.data_ptr(), \ + capacity, \ + n_tokens, \ + n_channels, \ + n_experts, \ + top_k, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + + +void gate_fwd(torch::Tensor& moe_input, + torch::Tensor& expert_count_cumsums, + torch::Tensor& mapped_slots, + torch::Tensor& activations, + torch::Tensor& expert_counts, + torch::Tensor& mapped_expert_counts, + torch::Tensor& scores, + torch::Tensor& assignments, + torch::Tensor& offsets, + torch::Tensor& backup_offsets, + torch::Tensor& logits, + torch::Tensor& logits_out, + int top_k, + int capacity, + bool use_rts) +{ + const int32_t n_tokens = activations.size(0); + const int32_t n_channels = activations.size(1); + + const int32_t n_experts = expert_count_cumsums.size(0) / top_k; + + DISPATCH_MOE_GATING(kHalf, __half); + +#ifdef BF16_AVAILABLE + DISPATCH_MOE_GATING(kBFloat16, __nv_bfloat16); +#endif + +} + + +#define DISPATCH_MOE_SCATTER(T_TYPE, C_TYPE) \ + if (activations.options().dtype() == torch::T_TYPE) { \ + if (top_k == 1) \ + launch_moe_scatter((C_TYPE*)moe_input.data_ptr(), \ + (int32_t*)expert_count_cumsums.data_ptr(), \ + (int32_t*)mapped_slots.data_ptr(), \ + (const C_TYPE*)activations.data_ptr(), \ + (int32_t*)expert_counts.data_ptr(), \ + (int32_t*)mapped_expert_counts.data_ptr(), \ + (float*)scores.data_ptr(), \ + (int32_t*)assignments.data_ptr(), \ + (int32_t*)offsets.data_ptr(), \ + (int32_t*)backup_offsets.data_ptr(), \ + capacity, \ + n_tokens, \ + n_channels, \ + n_experts, \ + at::cuda::getCurrentCUDAStream()); \ + else \ + launch_top2_moe_scatter((C_TYPE*)moe_input.data_ptr(), \ + (int32_t*)expert_count_cumsums.data_ptr(), \ + (int32_t*)mapped_slots.data_ptr(), \ + (const C_TYPE*)activations.data_ptr(), \ + (int32_t*)expert_counts.data_ptr(), \ + (int32_t*)mapped_expert_counts.data_ptr(), \ + (float*)scores.data_ptr(), \ + (int32_t*)assignments.data_ptr(), \ + (int32_t*)offsets.data_ptr(), \ + (int32_t*)backup_offsets.data_ptr(), \ + capacity, \ + n_tokens, \ + n_channels, \ + n_experts, \ + top_k, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + + +void gate_scatter(torch::Tensor& moe_input, + torch::Tensor& expert_count_cumsums, + torch::Tensor& mapped_slots, + torch::Tensor& activations, + torch::Tensor& expert_counts, + torch::Tensor& mapped_expert_counts, + torch::Tensor& scores, + torch::Tensor& assignments, + torch::Tensor& offsets, + torch::Tensor& backup_offsets, + int top_k, + int capacity, + bool use_rts) +{ + const int32_t n_tokens = activations.size(0); + const int32_t n_channels = activations.size(1); + + const int32_t n_experts = expert_count_cumsums.size(0) / top_k; + + DISPATCH_MOE_SCATTER(kHalf, __half); + +#ifdef BF16_AVAILABLE + DISPATCH_MOE_SCATTER(kBFloat16, __nv_bfloat16); +#endif + +} + +#define DISPATCH_MOE_GATING_BWD(T_TYPE, C_TYPE) \ + if (moe_input_grad.options().dtype() == torch::T_TYPE) { \ + if (top_k == 1) \ + launch_moe_gating_bwd((C_TYPE*)moe_input_grad.data_ptr(), \ + (float*)scores_grad.data_ptr(), \ + (C_TYPE*)activations_grad.data_ptr(), \ + (float*)logits_grad.data_ptr(), \ + (float*)logits.data_ptr(), \ + (int32_t*)assignments.data_ptr(), \ + (int32_t*)offsets.data_ptr(), \ + n_channels, \ + n_experts, \ + n_tokens, \ + capacity, \ + at::cuda::getCurrentCUDAStream()); \ + else \ + launch_top2_moe_gating_bwd((C_TYPE*)moe_input_grad.data_ptr(), \ + (float*)scores_grad.data_ptr(), \ + (C_TYPE*)activations_grad.data_ptr(), \ + (float*)logits_grad.data_ptr(), \ + (float*)logits.data_ptr(), \ + (int32_t*)assignments.data_ptr(), \ + 
(int32_t*)mapped_slots.data_ptr(), \ + n_channels, \ + n_experts, \ + n_tokens, \ + capacity, \ + top_k, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + +void gate_bwd(torch::Tensor& moe_input_grad, + torch::Tensor& scores_grad, + torch::Tensor& activations_grad, + torch::Tensor& logits_grad, + torch::Tensor& logits, + torch::Tensor& assignments, + torch::Tensor& offsets, + torch::Tensor& mapped_slots, + int top_k, + int capacity, + bool use_rts) +{ + const int32_t n_tokens = scores_grad.size(0) / top_k; + const int32_t n_channels = moe_input_grad.size(1); + + const int32_t n_experts = logits.size(1); + DISPATCH_MOE_GATING_BWD(kHalf, __half); + +#ifdef BF16_AVAILABLE + DISPATCH_MOE_GATING_BWD(kBFloat16, __nv_bfloat16); +#endif + +} + + + +#define DISPATCH_GATHER(T_TYPE, C_TYPE) \ + if (layer_output.options().dtype() == torch::T_TYPE) { \ + if (top_k == 1) \ + launch_moe_gather((C_TYPE*)layer_output.data_ptr(), \ + (const C_TYPE*)moe_output.data_ptr(), \ + (const float*)scores.data_ptr(), \ + (const int32_t*)mapped_slots.data_ptr(), \ + n_channels, \ + n_tokens, \ + at::cuda::getCurrentCUDAStream()); \ + else \ + launch_top2_moe_gather((C_TYPE*)layer_output.data_ptr(), \ + (const C_TYPE*)moe_output.data_ptr(), \ + (const float*)scores.data_ptr(), \ + (const int32_t*)mapped_slots.data_ptr(), \ + n_channels, \ + n_tokens, \ + top_k, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + +void gather_fwd(torch::Tensor& layer_output, + torch::Tensor& moe_output, + torch::Tensor& scores, + torch::Tensor& mapped_slots, + int top_k) +{ + const int32_t n_tokens = layer_output.size(0); + const int32_t n_channels = layer_output.size(1); + + DISPATCH_GATHER(kHalf, __half); + +#ifdef BF16_AVAILABLE + DISPATCH_GATHER(kBFloat16, __nv_bfloat16); +#endif + +} + +#define DISPATCH_GATHER_BWD(T_TYPE, C_TYPE) \ + if (layer_output_grad.options().dtype() == torch::T_TYPE) { \ + if (top_k == 1) \ + launch_moe_gather_bwd((C_TYPE*)layer_output_grad.data_ptr(), \ + (float*)scores_grad.data_ptr(), \ + (C_TYPE*)moe_output_grad.data_ptr(), \ + (C_TYPE*)moe_output.data_ptr(), \ + (const float*)scores.data_ptr(), \ + (const int32_t*)mapped_slots.data_ptr(), \ + n_channels, \ + n_tokens, \ + at::cuda::getCurrentCUDAStream()); \ + else \ + launch_top2_moe_gather_bwd((C_TYPE*)layer_output_grad.data_ptr(), \ + (float*)scores_grad.data_ptr(), \ + (C_TYPE*)moe_output_grad.data_ptr(), \ + (C_TYPE*)moe_output.data_ptr(), \ + (const float*)scores.data_ptr(), \ + (const int32_t*)mapped_slots.data_ptr(), \ + n_channels, \ + n_tokens, \ + top_k, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + + +void gather_bwd(torch::Tensor& layer_output_grad, + torch::Tensor& scores_grad, + torch::Tensor& moe_output_grad, + torch::Tensor& moe_output, + torch::Tensor& scores, + torch::Tensor& mapped_slots, + int top_k) +{ + const int32_t n_tokens = layer_output_grad.size(0); + const int32_t n_channels = layer_output_grad.size(1); + + DISPATCH_GATHER_BWD(kHalf, __half); + +#ifdef BF16_AVAILABLE + DISPATCH_GATHER_BWD(kBFloat16, __nv_bfloat16); +#endif + +} diff --git a/deepspeed/tops/moe_gating/moe_gating.py b/deepspeed/tops/moe_gating/moe_gating.py new file mode 100644 index 000000000000..15b694a719f3 --- /dev/null +++ b/deepspeed/tops/moe_gating/moe_gating.py @@ -0,0 +1,321 @@ + +import torch + +from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple + +from deepspeed.ops.op_builder import TopsBuilder +import math +import torch.nn.functional as F +from torch import Tensor + +inf_module = None + 
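Before the kernels are invoked, MoEGating.forward derives a per-expert capacity from the capacity factor. A small worked example (illustrative only) with the sizes used in test_moe_gating.py; the floored and ceiling variants coincide here because n_tokens divides evenly by n_experts:

    import math

    n_tokens, n_experts, capacity_factor, top_k = 16384, 64, 1.0, 1
    cap_floored = int(capacity_factor * (n_tokens / n_experts))        # 256, use_floored_capacity=True
    cap_ceil    = math.ceil(capacity_factor * (n_tokens / n_experts))  # 256, use_floored_capacity=False

    # The scattered buffer holds n_experts * capacity * top_k rows, padded up to a multiple of 256.
    moe_rows = n_experts * cap_floored * top_k
    moe_rows += (256 - moe_rows % 256) % 256                           # 16384, already 256-aligned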
+exp_selection_uniform_map: Dict[torch.device, Callable] = {} + +@torch.jit.script +def _top_idx(source, k): + return torch.topk(source, k=k, dim=0)[1] + +@torch.jit.script +def _capacity(gates: Tensor, capacity_factor: Tensor) -> Tensor: + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64) + return capacity + +@torch.jit.script +def _one_hot_to_float(x, num_classes): + return F.one_hot(x, num_classes=num_classes).float() + +class MoEGatingFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, + activations, + logits, + logits_out, + capacity, + use_rst, + is_grad_enabled, + scores , + expert_assignment , + mapped_slots , + expert_offset , + expert_backup_offset, + expert_counts , + mapped_expert_counts, + expert_cumsum , + top_k , + ): + kernel = inf_module.moe_gating_fwd + + n_tokens, hidden_size = activations.shape + _, n_experts = logits.shape + + moe_input_size = n_experts * capacity * top_k + # always cap the size to 256-divisible buffer-size! + if moe_input_size % 256 != 0: + moe_input_size = (256 - moe_input_size % 256) + moe_input_size + + moe_input = torch.zeros( + moe_input_size, + hidden_size, + dtype=activations.dtype, + device=activations.device + ) + if not is_grad_enabled: + expert_counts.zero_() + mapped_expert_counts.zero_() + expert_cumsum.zero_() + + if top_k > 1: + torch_capacity = _capacity(logits, torch.tensor(top_k)) + # Create a mask for 1st's expert per token + indices1_s = torch.argmax(logits, dim=1) + num_experts = int(logits.shape[1]) + mask1 = F.one_hot(indices1_s, num_classes=num_experts) + + logits_w_noise = logits # + gumbel_rsample(logits.shape, device=logits.device) + # Replace top-expert with min value + logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf")) + indices2_s = torch.argmax(logits_except1, dim=1) + mask2 = F.one_hot(indices2_s, num_classes=num_experts) + + if top_k > 2: + # do the same for 3 and 4 + logits_w_noise_2 = logits_except1 + logits_except1_2 = logits_w_noise_2.masked_fill(mask2.bool(), float("-inf")) + indices3_s = torch.argmax(logits_except1_2, dim=1) + mask3 = torch.nn.functional.one_hot(indices3_s, num_classes=n_experts) + + logits_w_noise_3 = logits_except1_2 + logits_except1_2_3 = logits_w_noise_3.masked_fill(mask3.bool(), float("-inf")) + indices4_s = torch.argmax(logits_except1_2_3, dim=1) + mask4 = torch.nn.functional.one_hot(indices4_s, num_classes=n_experts) + + # Random Token Selection + uniform = exp_selection_uniform_map.get(logits.device) + if uniform is None: + uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=logits.device), + high=torch.tensor(1.0, device=logits.device)).rsample + exp_selection_uniform_map[logits.device] = uniform + mask1_rand = mask1 * uniform(mask1.shape) + mask2_rand = mask2 * uniform(mask2.shape) + if top_k > 2: + mask3_rand = mask3 * uniform(mask3.shape) + mask4_rand = mask4 * uniform(mask4.shape) + + top_idx1 = _top_idx(mask1_rand, torch_capacity) + top_idx2 = _top_idx(mask2_rand, torch_capacity) + if top_k > 2: + top_idx3 = _top_idx(mask3_rand, torch_capacity) + top_idx4 = _top_idx(mask4_rand, torch_capacity) + + mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx1, 1) + mask2 = mask2 * torch.zeros_like(mask2).scatter_(0, top_idx2, 1) + + if top_k > 2: + mask3 = mask3 * torch.zeros_like(mask3).scatter_(0, top_idx3, 1) + mask4 = mask4 * torch.zeros_like(mask4).scatter_(0, top_idx4, 1) + + # Compute locations in capacity buffer + 
locations1 = torch.cumsum(mask1, dim=0) - 1 + locations2 = torch.cumsum(mask2, dim=0) - 1 + + # Update 2nd's location by accounting for locations of 1st + locations2 += torch.sum(mask1, dim=0, keepdim=True) + if top_k > 2: + locations3 = torch.cumsum(mask3, dim=0) - 1 + locations4 = torch.cumsum(mask4, dim=0) - 1 + locations3 += torch.sum(mask1, dim=0, keepdim=True) + torch.sum(mask2, dim=0, keepdim=True) + locations4 += torch.sum(mask1, dim=0, keepdim=True) + torch.sum(mask2, dim=0, keepdim=True) + torch.sum(mask3, dim=0, keepdim=True) + + # Remove locations outside capacity from mask + mask1 *= torch.lt(locations1, torch_capacity) + mask2 *= torch.lt(locations2, torch_capacity) + # Store the capacity location for each token + locations1_s = torch.sum(locations1 * mask1, dim=1) + locations2_s = torch.sum(locations2 * mask2, dim=1) + if top_k > 2: + mask3 *= torch.lt(locations3, torch_capacity) + mask4 *= torch.lt(locations4, torch_capacity) + locations3_s = torch.sum(locations3 * mask3, dim=1) + locations4_s = torch.sum(locations4 * mask4, dim=1) + + if top_k > 2: + expert_offset = torch.cat([locations1_s.to(torch.int32), + locations2_s.to(torch.int32), + locations3_s.to(torch.int32), + locations4_s.to(torch.int32)]).contiguous() + expert_backup_offset = torch.cat([mask1.sum(dim=1).to(torch.int32), + mask2.sum(dim=1).to(torch.int32), + mask3.sum(dim=1).to(torch.int32), + mask4.sum(dim=1).to(torch.int32)]).contiguous() + else: + expert_offset = torch.cat([locations1_s.to(torch.int32), + locations2_s.to(torch.int32)]).contiguous() + expert_backup_offset = torch.cat([mask1.sum(dim=1).to(torch.int32), + mask2.sum(dim=1).to(torch.int32)]).contiguous() + kernel( + moe_input, + expert_cumsum, + mapped_slots, + activations, + expert_counts, + mapped_expert_counts, + scores, + expert_assignment, + expert_offset, + expert_backup_offset, + logits, + logits_out, + top_k, + capacity, + use_rst, + ) + else: + inf_module.moe_gating_scatter( + moe_input, + expert_cumsum, + mapped_slots, + activations, + expert_counts, + mapped_expert_counts, + scores, + expert_assignment, + expert_offset, + expert_backup_offset, + top_k, + capacity, + use_rst, + ) + + ctx.top_k = top_k + ctx.capacity = capacity + ctx.use_rst = use_rst + + if is_grad_enabled: + ctx.save_for_backward(expert_assignment, expert_offset, logits_out, mapped_slots) + + return moe_input, scores, logits_out, expert_counts, mapped_slots, expert_assignment, expert_offset, expert_backup_offset + + @staticmethod + def backward(ctx, moe_inp_grad, scores_grad, logits_grad, expert_counts_grad, mapped_slots_grad, expert_assignment_grad, expert_offset_grad, expert_backup_offset_grad): + (expert_assignment, + expert_offset, + logits, + mapped_slots) = ctx.saved_tensors + + + moe_inp_grad = moe_inp_grad.contiguous() + scores_grad = scores_grad.contiguous() + logits_grad = logits_grad.contiguous() + + _, hidden_size = moe_inp_grad.shape + top_k_tokens = scores_grad.shape[0] + n_tokens = top_k_tokens // ctx.top_k + kernel = inf_module.moe_gating_bwd + + + activations_grad = torch.zeros(n_tokens, hidden_size, dtype=moe_inp_grad.dtype, device=torch.cuda.current_device()) + kernel( + moe_inp_grad, + scores_grad, + activations_grad, + logits_grad, + logits, + expert_assignment, + expert_offset, + mapped_slots, + ctx.top_k, + ctx.capacity, + ctx.use_rst + ) + return activations_grad, logits_grad, logits_grad, None, None, None, scores_grad, expert_assignment_grad, mapped_slots_grad, None, None, expert_counts_grad, None, None, None + +class 
MoEGating(torch.nn.Module): + """ + CUDA implementation of top-1 gating. This will perform a softmax on the logits, + and return the scale as well as its idx within that expert's allocation. + """ + + + def __init__(self, + logit_dtype=torch.bfloat16, + n_tokens=16384, + hidden_size=3072, + n_experts=64, + top_k=1, + use_floored_capacity=True, + compute_aux_loss=False, + use_act_ckpting=False) -> None: + super(MoEGating, self).__init__() + global inf_module + if inf_module is None: + inf_module = TopsBuilder().load() + self.scores = torch.empty(n_tokens * top_k, dtype=torch.float32, device=torch.cuda.current_device()) + self.expert_assignment = torch.empty(n_tokens * top_k, dtype=torch.int32, device=torch.cuda.current_device()) + + self.mapped_slots = torch.empty(n_tokens * top_k, dtype=torch.int32, device=torch.cuda.current_device()) + self.expert_offset = torch.empty(n_tokens * top_k, dtype=torch.int32, device=torch.cuda.current_device()) + self.expert_backup_offset = torch.empty(n_tokens * top_k, dtype=torch.int32, device=torch.cuda.current_device()) + + self.expert_counts = torch.empty(n_experts * top_k, dtype=torch.int32, device=torch.cuda.current_device()) + self.mapped_expert_counts = torch.empty(n_experts * top_k, dtype=torch.int32, device=torch.cuda.current_device()) + self.expert_cumsum = torch.empty(n_experts * top_k, dtype=torch.int32, device=torch.cuda.current_device()) + self.logits = torch.empty(n_tokens, n_experts, dtype=torch.float32, device=torch.cuda.current_device()) + self.top_k = top_k + self.use_floored_capacity = use_floored_capacity + self.compute_aux_loss = compute_aux_loss + self.use_act_ckpting = use_act_ckpting + + def forward(self, + activations: torch.Tensor, + logits: torch.Tensor, + capacity_factor: float, + use_rst: bool = True + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Perform top_1_gating and token scatter. 
+ """ + + n_tokens = activations.shape[0] + n_experts = logits.shape[-1] + + capacity = int(capacity_factor * (n_tokens / n_experts)) if self.use_floored_capacity else math.ceil(capacity_factor * (n_tokens / n_experts)) + + is_grad_enabled = self.use_act_ckpting and torch.is_grad_enabled() + + (moe_input, + self.scores, + self.logits, + self.expert_counts, + self.mapped_slots, + self.expert_assignment, + self.expert_offset, + self.expert_backup_offset) = MoEGatingFunction.apply( + activations, + logits, + self.logits, + capacity, + use_rst, + is_grad_enabled, + self.scores , + self.expert_assignment , + self.mapped_slots , + self.expert_offset , + self.expert_backup_offset, + self.expert_counts , + self.mapped_expert_counts, + self.expert_cumsum , + self.top_k + ) + if self.compute_aux_loss: + if self.top_k == 1: + l_aux = (torch.mean(self.logits, dim=0) * self.expert_counts[: n_experts] / n_tokens).sum() * n_experts + else: + l_aux = torch.mean((torch.mean(self.logits, dim=0) * self.expert_counts[: n_experts] / n_tokens)) * n_experts * n_experts + + return l_aux, moe_input, self.scores, self.mapped_slots + else: + return self.logits, moe_input, self.scores, self.mapped_slots \ No newline at end of file diff --git a/deepspeed/tops/moe_gating/test_moe_gating.py b/deepspeed/tops/moe_gating/test_moe_gating.py new file mode 100644 index 000000000000..d1d33bf521fe --- /dev/null +++ b/deepspeed/tops/moe_gating/test_moe_gating.py @@ -0,0 +1,125 @@ +import deepspeed +from deepspeed.tops import MoEGating, MoEGather +import torch + +import time + +import torch.nn.functional as F +from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple +from torch import Tensor + + +exp_selection_uniform_map: Dict[torch.device, Callable] = {} + +@torch.jit.script +def _top_idx(source, k): + return torch.topk(source, k=k, dim=0)[1] + +@torch.jit.script +def _one_hot_to_float(x, num_classes): + return F.one_hot(x, num_classes=num_classes).float() + +def top1gating(logits: Tensor, + capacity: int, + use_rts: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + + # everything is in fp32 in this function + gates = F.softmax(logits, dim=1) + + capacity = torch.tensor(capacity).to(torch.int64) + + # Create a mask for 1st's expert per token + # noisy gating + indices1_s = torch.argmax(gates, dim=1) + num_experts = int(gates.shape[1]) + mask1 = F.one_hot(indices1_s, num_classes=num_experts) + + # gating decisions + exp_counts = torch.sum(mask1, dim=0).detach().to('cpu') + + # Compute l_aux + me = torch.mean(gates, dim=0) + ce = torch.mean(mask1.float(), dim=0) + l_aux = me.mul_(ce).sum() * num_experts + + # Random Token Selection + if use_rts: + uniform = exp_selection_uniform_map.get(logits.device) + if uniform is None: + uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=logits.device), + high=torch.tensor(1.0, device=logits.device)).rsample + exp_selection_uniform_map[logits.device] = uniform + + mask1_rand = mask1 #* uniform(mask1.shape) + else: + mask1_rand = mask1 + + top_idx = _top_idx(mask1_rand, capacity) + + new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1) + mask1 = new_mask1 + + locations1 = torch.cumsum(mask1, dim=0) - 1 + + # Store the capacity location for each token + locations1_s = torch.sum(locations1 * mask1, dim=1) + + # Normalize gate probabilities + mask1_float = mask1.float() + gates = gates * mask1_float + + locations1_sc = _one_hot_to_float(locations1_s, capacity) + combine_weights = torch.einsum("se,sc->sec", gates, locations1_sc) + + 
dispatch_mask = combine_weights.bool() + + return l_aux, combine_weights, dispatch_mask, exp_counts + + +n_tokens = 4096*4 +n_experts = 64 +hidden_size = 3072 +capacity = 256 + +hidden_states = torch.randn(n_tokens, hidden_size, requires_grad=True).bfloat16().cuda() +hidden_states1 = hidden_states.clone() #torch.ones(n_tokens, hidden_size, requires_grad=True).bfloat16().cuda() +logits = torch.randn(n_tokens, n_experts, requires_grad=True).cuda() +logits_f = logits.clone() #torch.ones(n_tokens, n_experts, requires_grad=True).cuda() +# logits = logits.bfloat16() +weight = torch.randn(hidden_size, hidden_size, requires_grad=True).cuda().bfloat16() +weight1 = weight.clone() #torch.ones(hidden_size, hidden_size, requires_grad=True).cuda().bfloat16() +moe_gating = MoEGating(top_k=1) +moe_gather = MoEGather(top_k=1) + +def run_baseline(logits, hidden_states): + gate_out = top1gating(logits, n_tokens // n_experts) + dispatched_input = torch.einsum("sec,sm->ecm", gate_out[2].type_as(hidden_states), hidden_states) + out = torch.matmul(dispatched_input, weight1) + out = torch.einsum("sec,ecm->sm", gate_out[1].type_as(hidden_states), out) + return gate_out[0], out + +def run_deepspeed(hidden_states, logits): + l_aux, moe_inp, scores, mapped_slots = moe_gating(hidden_states, logits, 1.0) + out = torch.matmul(moe_inp, weight) + out = moe_gather(out, scores, mapped_slots,) + return l_aux, out, scores, mapped_slots + + +logits_f.retain_grad() +hidden_states1.retain_grad() + +logits.retain_grad() +hidden_states.retain_grad() + +weight.retain_grad() +weight1.retain_grad() + +l_aux, moe_input, scores, mapped_slots = run_deepspeed(hidden_states, logits) +print(l_aux.item(), moe_input.norm().item(), hidden_states.norm().item()) +loss = l_aux + moe_input.sum() +loss.backward() + +l_aux1, moe_input1 = run_baseline(logits_f, hidden_states1) +print(l_aux1.item(), moe_input1.norm().item(), hidden_states1.norm().item()) +loss1 = moe_input1.sum() + l_aux1 +loss1.backward() diff --git a/deepspeed/tops/moe_gating/top1_moe_gating.cu b/deepspeed/tops/moe_gating/top1_moe_gating.cu new file mode 100644 index 000000000000..7cfea2272f9e --- /dev/null +++ b/deepspeed/tops/moe_gating/top1_moe_gating.cu @@ -0,0 +1,781 @@ +#include "moe_gating.cuh" +#include "reduction_utils.h" +#include "stdio.h" +#include "tops_context.h" + +using ROp = reduce::ROpType; + + + +__global__ void top_1_gating_kernel(int32_t* expert_counts, + float* scores, + int32_t* assignments, + int32_t* offsets, + float* logits, + float* logits_out, + const int capacity, + const int32_t n_experts, + const int32_t n_tokens) +{ + const int32_t token_idx = blockIdx.x; + const int32_t expert_idx = threadIdx.x; + const int32_t max_warps = 1024 / hw_warp_size; + + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + float* token_logits = logits + token_idx * n_experts; + float* token_logits_out = logits_out + token_idx * n_experts; + + float logit_val; + if (expert_idx < n_experts) + logit_val = token_logits[expert_idx]; + else { + reduce::init(&logit_val); + } + + int32_t inverted_expert = n_experts - expert_idx - 1; + // Perform softmax + const reduce::IdxReduceResult res = + reduce::idx_reduce(tb, warp, logit_val, inverted_expert); + // Recover the original expert index + const int32_t assigned_expert = n_experts - res.idx - 1; + const float max_logit = res.val; + + float softlogit = __expf(logit_val - max_logit); + float softmax_sum = softlogit; + reduce::block(tb, warp, softmax_sum); + + // 
Compute the score + const float score = 1.0 / softmax_sum; + if (expert_idx < n_experts) token_logits_out[expert_idx] = softlogit / softmax_sum; + + if (threadIdx.x == 0) + { + atomicAdd(expert_counts + assigned_expert, 1); + scores[token_idx] = score; + assignments[token_idx] = assigned_expert; + } +} + +template +__global__ void refine_expert_mapping(int32_t* expert_counts, + int32_t* mapped_expert_counts, + int32_t* expert_count_cumsums, + float* scores, + int32_t* assignments, + int32_t* offsets, + int32_t* backup_offsets, + const int capacity, + const int32_t n_tokens, + const int32_t n_experts, + std::pair seed){ + + const int32_t bidx = blockIdx.x; + const int32_t tidx = threadIdx.x; + int32_t token_idx = bidx * blockDim.x + tidx; + if (token_idx >= n_tokens) { + return; + } + + int32_t assignment = assignments[token_idx]; + + int idx = token_idx << 2; + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + int32_t total_ec = expert_counts[assignment]; + float ratio = 1.0 - (float)(capacity) / (float)total_ec; + + float4 rand_val = curand_uniform4(&state); + + + { + if ((rand_val.z > ratio && rand_val.x > ratio && rand_val.y > ratio) || + (rand_val.z <= ratio && (rand_val.x > ratio || rand_val.y > ratio))) { + // if (true) { + offsets[token_idx] = atomicAdd(mapped_expert_counts + assignment, 1); + //backup_offsets[token_idx] = atomicAdd(expert_count_cumsums + assignment, 1);// gating::unassigned; + + } + else{ + offsets[token_idx] = gating::unassigned; + backup_offsets[token_idx] = atomicAdd(expert_count_cumsums + assignment, 1);// gating::unassigned; + + // assignments[token_idx] = n_experts; //gating::unassigned; + // scores[token_idx] = 0.f; + } // need to set these tokens to Zero! + } +} + + +__global__ void gate_logits_bwd_kernel(float* logits_grad, + float* scores_grad, + const int32_t* assignment, + float* logits, + const int32_t n_experts, + const int32_t n_tokens) +{ + const int32_t token_idx = blockIdx.x; + const int32_t expert_idx = threadIdx.x; + int32_t assigned_expert = assignment[token_idx]; + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // Padding tokens do not require + if (token_idx >= n_tokens) { + return; + } + + float* token_logits = logits + token_idx * n_experts; + float* token_logits_grad = logits_grad + token_idx * n_experts; + + float logit_val; + float logit_grad_val; + + if (expert_idx < n_experts) { + logit_val = token_logits[expert_idx]; + logit_grad_val = token_logits_grad[expert_idx]; + } else { + reduce::init(&logit_val); + reduce::init(&logit_grad_val); + } + + if (assigned_expert == expert_idx) { + logit_grad_val += scores_grad[token_idx]; + } + float softmax_grad_sum = logit_val * logit_grad_val; + reduce::block(tb, warp, softmax_grad_sum); + logit_grad_val = logit_val * (logit_grad_val - softmax_grad_sum); + + if (expert_idx < n_experts) + token_logits_grad[expert_idx] = logit_grad_val; +} + +template +__global__ void moe_gather_kernel(T* layer_output, + const T* moe_output, + const float* scores, + const int32_t* mapped_slots, + const int32_t n_channels, + const int32_t num_tokens) +{ + constexpr int32_t vector_size = scatter_gather::access_granularity / sizeof(T); + constexpr int32_t stride = vector_size * scatter_gather::threads; + + const int32_t token_idx = blockIdx.x; + const int32_t mapped_slot = mapped_slots[token_idx]; + + const float score = scores[token_idx]; + const int32_t channel_offset = threadIdx.x * vector_size; + + 
const T* moe_output_base = moe_output + mapped_slot * n_channels + channel_offset; + T* layer_output_base = layer_output + token_idx * n_channels + channel_offset; + +#pragma unroll + for (int i = 0; i < copyUnroll; i++) { + T reg_buffer[vector_size]; + + if (i * stride + channel_offset < n_channels) { + if (mapped_slot != gating::unassigned && mapped_slot < num_tokens) + { + mem_access::load_global( + reg_buffer, + moe_output_base + i * stride + ); +#pragma unroll + for (int j = 0; j < vector_size; j++) { + float up_cast = conversion::to(reg_buffer[j]); + reg_buffer[j] = conversion::to(up_cast * score); + } + } + else{ +#pragma unroll + for (int j = 0; j < vector_size; j++) { + reg_buffer[j] = conversion::to(0.f); + } + } + + mem_access::store_global( + layer_output_base + i * stride, + reg_buffer + ); + } + } +} + +template +__global__ void moe_gather_bwd_kernel(T* layer_output_grad, + float* scores_grad, + T* moe_output_grad, + T* moe_output, + const float* scores, + const int32_t* mapped_slots, + const int32_t n_channels, + const int32_t num_tokens) +{ + constexpr int32_t vector_size = scatter_gather::access_granularity / sizeof(T); + constexpr int32_t stride = vector_size * scatter_gather::threads; + + const int32_t token_idx = blockIdx.x; + const int32_t mapped_slot = mapped_slots[token_idx]; + + + const float score = scores[token_idx]; + const int32_t channel_offset = threadIdx.x * vector_size; + + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + T* layer_output_grad_base = layer_output_grad + token_idx * n_channels + channel_offset; + T* moe_output_grad_base = moe_output_grad + mapped_slot * n_channels + channel_offset; + T* moe_output_base = moe_output + mapped_slot * n_channels + channel_offset; + float score_grad = 0.f; +#pragma unroll + for (int i = 0; i < copyUnroll; i++) { + T reg_buffer[vector_size]; + T out_buffer[vector_size]; + + if (i * stride + channel_offset < n_channels) { + if (mapped_slot != gating::unassigned && mapped_slot < num_tokens) + { + mem_access::load_global(reg_buffer, + layer_output_grad_base + i * stride); + + mem_access::load_global(out_buffer, + moe_output_base + i * stride); + +#pragma unroll + for (int j = 0; j < vector_size; j++) { + float up_cast = conversion::to(reg_buffer[j]); + float out_up_cast = conversion::to(out_buffer[j]); + reg_buffer[j] = conversion::to(up_cast * score); + score_grad += (up_cast * out_up_cast); + } + mem_access::store_global(moe_output_grad_base + i * stride, + reg_buffer); + } + + } + } + + reduce::_block(tb, warp, &score_grad); + if (threadIdx.x == 0) scores_grad[token_idx] = score_grad; +} + +template +__global__ void moe_scatter_bwd_kernel(T* moe_input_grad, + T* activations_grad, + const int32_t* assignments, + const int32_t* offsets, + const int32_t n_channels, + const int capacity, + const int32_t num_tokens, + const int32_t n_experts) +{ + constexpr int32_t vector_size = scatter_gather::access_granularity / sizeof(T); + constexpr int32_t load_stride = vector_size * scatter_gather::threads; + + const int32_t token_idx = blockIdx.x; + const int32_t tidx = threadIdx.x; + const int32_t warp_rank = tidx / hw_warp_size; + + // Bank aligned and sufficient + __shared__ int32_t red_buffer[32]; + __shared__ int32_t token_0_row; + + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + int assigned_expert = assignments[token_idx]; + + // For the different codepaths, we'll converge on 
this variable for doing + // the token copy. + int32_t token_base_row; + + + token_base_row = capacity * assigned_expert; + + // Data copy to appropriate location + const int32_t thread_offset = tidx * vector_size; + + const int32_t base_load_offset = token_idx * n_channels + thread_offset; + T* load_base_ptr = activations_grad + base_load_offset; + int32_t offset = offsets[token_idx]; + const int32_t store_row = token_base_row + offset; + const int32_t base_store_offset = store_row * n_channels + thread_offset; + T* store_base_ptr = moe_input_grad + base_store_offset; + +#pragma unroll + for (int i = 0; i < copyUnroll; i++) { + T tmp_buf[vector_size]; + + if (i * load_stride + thread_offset < n_channels) { + if (assigned_expert < n_experts && offset != gating::unassigned && offset < capacity && store_row < num_tokens) + mem_access::load_global(tmp_buf, store_base_ptr + i * load_stride); + else + { + #pragma unroll + for (int k = 0; k < vector_size; k++) + tmp_buf[k] = conversion::to(0.f); + } + + mem_access::store_global(load_base_ptr + i * load_stride, tmp_buf); + } + } +} + +template +__global__ void moe_scatter_kernel(T* moe_input, + int32_t* expert_count_cumsums, + int32_t* mapped_slots, + float* scores, + const T* activations, + int32_t* assignments, + const int32_t* expert_counts, + const int32_t* mapped_expert_counts, + int32_t* offsets, + const int32_t* backup_offsets, + const int32_t n_channels, + const int32_t n_experts, + const int capacity, + const int32_t num_tokens) +{ + constexpr int32_t vector_size = scatter_gather::access_granularity / sizeof(T); + constexpr int32_t load_stride = vector_size * scatter_gather::threads; + + const int32_t token_idx = blockIdx.x; + const int32_t tidx = threadIdx.x; + const int32_t warp_rank = tidx / hw_warp_size; + + // Bank aligned and sufficient + __shared__ int32_t red_buffer[32]; + __shared__ int32_t token_0_row; + + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + int assigned_expert = assignments[token_idx]; + if (assigned_expert >= n_experts || assigned_expert == gating::unassigned) { + // For whatever reason, don't need to perform the copy, so we'll early return + // and signal this wasn't mapped with a negative 1. + + if (tidx == 0) mapped_slots[token_idx] = gating::unassigned; + return; + } + int expert_count = expert_counts[assigned_expert]; + int mapped_expert_count = mapped_expert_counts[assigned_expert]; + + // For the different codepaths, we'll converge on this variable for doing + // the token copy. 
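+    // Illustrative sketch of the slot bookkeeping used below (assumed example
+    // values, wrapped in #if 0 so it never compiles here): a token is copied
+    // only while the offset it was granted stays below `capacity`, and its
+    // destination row in moe_input is derived from the expert index and that
+    // offset.
+#if 0
+#include <cstdio>
+
+int main()
+{
+    const int capacity = 2;         // assumed slots per expert
+    const int assigned_expert = 3;  // expert chosen by the gate
+    const int offset = 1;           // slot granted to this token
+
+    if (offset < capacity) {
+        // Kept token: its activations land in this row of moe_input and the
+        // row is recorded in mapped_slots so the gather can find it again.
+        printf("store_row = %d\n", capacity * assigned_expert + offset);  // prints 7
+    } else {
+        // Over-capacity token: it is dropped, its slot marked unassigned and
+        // its gate score zeroed, as the code below does.
+        printf("token dropped\n");
+    }
+    return 0;
+}
+#endif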
+ int32_t token_base_row; + int32_t offset = offsets[token_idx]; + int32_t other_offset = backup_offsets[token_idx]; + //if (offset == gating::unassigned && expert_count <= capacity) { + // if (tidx == 0) { + // assignments[token_idx] = n_experts; + // mapped_slots[token_idx] = gating::unassigned; + // } + // return; + //} + //else if (mapped_expert_count != capacity){ + // offset = backup_offsets[token_idx] + mapped_expert_count; + // if (tidx == 0) printf("Coming here: E(%d), T(%d), MEC(%d), O(%d)\n", assigned_expert, token_idx, mapped_expert_count, offset); + //} +// + if (offset == gating::unassigned && expert_count > capacity && mapped_expert_count < capacity) + { + offset = backup_offsets[token_idx] + mapped_expert_count; + if (tidx == 0) offsets[token_idx] = offset; + } + // if (other_offset == 0)//>= capacity || offset == gating::unassigned) + if (offset >= capacity || offset == gating::unassigned) + { + if (tidx == 0) { + mapped_slots[token_idx] = gating::unassigned; + assignments[token_idx] = n_experts; //gating::unassigned; + scores[token_idx] = 0.f; + } + return; + } + //else{ + // if (tidx == 0) offsets[token_idx] = offset; + //} + + + token_base_row = capacity * assigned_expert; + + // Data copy to appropriate location + const int32_t thread_offset = tidx * vector_size; + + const int32_t base_load_offset = token_idx * n_channels + thread_offset; + const T* load_base_ptr = activations + base_load_offset; + const int32_t store_row = token_base_row + offset; + const int32_t base_store_offset = store_row * n_channels + thread_offset; + T* store_base_ptr = moe_input + base_store_offset; +#pragma unroll + for (int i = 0; i < copyUnroll; i++) { + T tmp_buf[vector_size]; + + if (i * load_stride + thread_offset < n_channels && store_row < num_tokens) { + mem_access::load_global(tmp_buf, + load_base_ptr + i * load_stride); + mem_access::store_global(store_base_ptr + i * load_stride, + tmp_buf); + } + } + + if (threadIdx.x == 0) { + mapped_slots[token_idx] = (store_row < num_tokens) ? 
store_row : gating::unassigned; + } +} + +#define LAUNCH_FOR_UNROLL(COUNT) \ + case COUNT: \ + moe_scatter_kernel<<>>(moe_input, \ + expert_count_cumsums, \ + mapped_slots, \ + scores, \ + activations, \ + assignments, \ + expert_counts, \ + mapped_expert_counts, \ + offsets, \ + backup_offsets, \ + n_channels, \ + n_experts, \ + capacity, \ + n_tokens); \ + break; + + +template +void launch_moe_gating(T* moe_input, + int32_t* expert_count_cumsums, + int32_t* mapped_slots, + const T* activations, + int32_t* expert_counts, + int32_t* mapped_expert_counts, + float* scores, + int32_t* assignments, + int32_t* offsets, + int32_t* backup_offsets, + float* logits, + float* logits_out, + const int capacity, + const int32_t n_tokens, + const int32_t n_channels, + const int32_t n_experts, + cudaStream_t stream) +{ + const dim3 grid(n_tokens); + const dim3 block(((n_experts - 1) / hw_warp_size + 1) * hw_warp_size); + + std::pair seed = TOPSContext::Instance().IncrementOffset(16); + + top_1_gating_kernel<<>>( + expert_counts, + scores, + assignments, + offsets, + logits, + logits_out, + capacity, + n_experts, + n_tokens + ); + const dim3 block2(scatter_gather::threads); + const dim3 grid2((n_tokens - 1) / scatter_gather::threads + 1); + + refine_expert_mapping<<>>( + expert_counts, + mapped_expert_counts, + expert_count_cumsums, + scores, + assignments, + offsets, + backup_offsets, + capacity, + n_tokens, + n_experts, + seed + ); + + constexpr int vals_per_unroll = scatter_gather::threads * scatter_gather::access_granularity / sizeof(T); + const int copy_unroll = (n_channels + vals_per_unroll - 1) / vals_per_unroll; + + const dim3 block1(scatter_gather::threads); + const dim3 grid1(n_tokens); + + switch (copy_unroll) { + LAUNCH_FOR_UNROLL(1); + LAUNCH_FOR_UNROLL(2); + LAUNCH_FOR_UNROLL(3); + LAUNCH_FOR_UNROLL(4); + LAUNCH_FOR_UNROLL(5); + LAUNCH_FOR_UNROLL(6); + } +} + +#define INSTANTIATE_MoE_Gating_KERNEL(T) \ + template void launch_moe_gating(T* moe_input, \ + int32_t* expert_count_cumsums, \ + int32_t* mapped_slots, \ + const T* activations, \ + int32_t * expert_counts, \ + int32_t * mapped_expert_counts, \ + float* scores, \ + int32_t* assignments, \ + int32_t* offsets, \ + int32_t* backup_offsets, \ + float* logits, \ + float* logits_out, \ + const int capacity, \ + const int32_t n_tokens, \ + const int32_t n_channels, \ + const int32_t n_experts, \ + cudaStream_t stream); + +INSTANTIATE_MoE_Gating_KERNEL(float) +INSTANTIATE_MoE_Gating_KERNEL(__half) +#ifdef BF16_AVAILABLE +INSTANTIATE_MoE_Gating_KERNEL(__nv_bfloat16) +#endif + + +template +void launch_moe_scatter(T* moe_input, + int32_t* expert_count_cumsums, + int32_t* mapped_slots, + const T* activations, + int32_t* expert_counts, + int32_t* mapped_expert_counts, + float* scores, + int32_t* assignments, + int32_t* offsets, + int32_t* backup_offsets, + const int capacity, + const int32_t n_tokens, + const int32_t n_channels, + const int32_t n_experts, + cudaStream_t stream) +{ + constexpr int vals_per_unroll = scatter_gather::threads * scatter_gather::access_granularity / sizeof(T); + const int copy_unroll = (n_channels + vals_per_unroll - 1) / vals_per_unroll; + + const dim3 block1(scatter_gather::threads); + const dim3 grid1(n_tokens); + + switch (copy_unroll) { + LAUNCH_FOR_UNROLL(1); + LAUNCH_FOR_UNROLL(2); + LAUNCH_FOR_UNROLL(3); + LAUNCH_FOR_UNROLL(4); + LAUNCH_FOR_UNROLL(5); + LAUNCH_FOR_UNROLL(6); + } +} + +#define INSTANTIATE_MoE_SCATTER_KERNEL(T) \ + template void launch_moe_scatter(T* moe_input, \ + int32_t* expert_count_cumsums, \ 
+ int32_t* mapped_slots, \ + const T* activations, \ + int32_t * expert_counts, \ + int32_t * mapped_expert_counts, \ + float* scores, \ + int32_t* assignments, \ + int32_t* offsets, \ + int32_t* backup_offsets, \ + const int capacity, \ + const int32_t n_tokens, \ + const int32_t n_channels, \ + const int32_t n_experts, \ + cudaStream_t stream); + +INSTANTIATE_MoE_SCATTER_KERNEL(float) +INSTANTIATE_MoE_SCATTER_KERNEL(__half) +#ifdef BF16_AVAILABLE +INSTANTIATE_MoE_SCATTER_KERNEL(__nv_bfloat16) +#endif + + +#define LAUNCH_FOR_UNROLL_GATHER(COUNT) \ + case COUNT: \ + moe_gather_kernel<<>>(layer_output, \ + moe_output, \ + scores, \ + mapped_slots, \ + n_channels, \ + n_tokens); \ + break; + +template +void launch_moe_gather(T* layer_output, + const T* moe_output, + const float* scores, + const int32_t* mapped_slots, + const int32_t n_channels, + const int32_t n_tokens, + cudaStream_t stream) +{ + constexpr int vals_per_unroll = scatter_gather::threads * scatter_gather::access_granularity / sizeof(T); + const int copy_unroll = (n_channels + vals_per_unroll - 1) / vals_per_unroll; + + const dim3 block(scatter_gather::threads); + const dim3 grid(n_tokens); + + switch (copy_unroll) { + LAUNCH_FOR_UNROLL_GATHER(1) + LAUNCH_FOR_UNROLL_GATHER(2) + LAUNCH_FOR_UNROLL_GATHER(3) + LAUNCH_FOR_UNROLL_GATHER(4) + LAUNCH_FOR_UNROLL_GATHER(5) + LAUNCH_FOR_UNROLL_GATHER(6) + } +} + +#define INSTANTIATE_GATHER_FOR_TYPE(TYPE) \ + template void launch_moe_gather(TYPE * layer_output, \ + const TYPE* moe_output, \ + const float* scores, \ + const int32_t* mapped_slots, \ + const int32_t n_channels, \ + const int32_t n_tokens, \ + cudaStream_t stream); \ + +INSTANTIATE_GATHER_FOR_TYPE(__half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_GATHER_FOR_TYPE(__nv_bfloat16) +#endif + + +#define LAUNCH_FOR_UNROLL_GATHER_BWD(COUNT) \ + case COUNT: \ + moe_gather_bwd_kernel<<>>(layer_output_grad, \ + scores_grad, \ + moe_output_grad, \ + moe_output, \ + scores, \ + mapped_slots, \ + n_channels, \ + n_tokens); \ + break; + +template +void launch_moe_gather_bwd(T* layer_output_grad, + float* scores_grad, + T* moe_output_grad, + T* moe_output, + const float* scores, + const int32_t* mapped_slots, + const int32_t n_channels, + const int32_t n_tokens, + cudaStream_t stream) +{ + constexpr int vals_per_unroll = scatter_gather::threads * scatter_gather::access_granularity / sizeof(T); + const int copy_unroll = (n_channels + vals_per_unroll - 1) / vals_per_unroll; + + const dim3 block(scatter_gather::threads); + const dim3 grid(n_tokens); + + switch (copy_unroll) { + LAUNCH_FOR_UNROLL_GATHER_BWD(1) + LAUNCH_FOR_UNROLL_GATHER_BWD(2) + LAUNCH_FOR_UNROLL_GATHER_BWD(3) + LAUNCH_FOR_UNROLL_GATHER_BWD(4) + LAUNCH_FOR_UNROLL_GATHER_BWD(5) + LAUNCH_FOR_UNROLL_GATHER_BWD(6) + } +} + +#define INSTANTIATE_GATHER_BWD_FOR_TYPE(TYPE) \ + template void launch_moe_gather_bwd(TYPE * layer_output_grad, \ + float* scores_grad, \ + TYPE* moe_output_grad, \ + TYPE* moe_output, \ + const float* scores, \ + const int32_t* mapped_slots, \ + const int32_t n_channels, \ + const int32_t n_tokens, \ + cudaStream_t stream); + +INSTANTIATE_GATHER_BWD_FOR_TYPE(__half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_GATHER_BWD_FOR_TYPE(__nv_bfloat16) +#endif + + +#define LAUNCH_FOR_UNROLL_MOE_BWD(COUNT) \ + case COUNT: \ + moe_scatter_bwd_kernel<<>>(moe_input_grad, \ + activations_grad, \ + assignments, \ + offsets, \ + n_channels, \ + capacity, \ + n_tokens, \ + n_experts); \ + break; + +template +void launch_moe_gating_bwd(T* moe_input_grad, + float* scores_grad, + T* 
activations_grad, + float* logits_grad, + float* logits, + const int32_t* assignments, + const int32_t* offsets, + const int32_t n_channels, + const int32_t n_experts, + const int32_t n_tokens, + int capacity, + cudaStream_t stream) +{ + const dim3 grid(n_tokens); + const dim3 block(((n_experts - 1) / hw_warp_size + 1) * hw_warp_size); + + gate_logits_bwd_kernel<<>> (logits_grad, scores_grad, assignments, logits, n_experts, n_tokens); + + constexpr int vals_per_unroll = scatter_gather::threads * scatter_gather::access_granularity / sizeof(T); + const int copy_unroll = (n_channels + vals_per_unroll - 1) / vals_per_unroll; + + const dim3 block1(scatter_gather::threads); + const dim3 grid1(n_tokens); + + switch (copy_unroll) { + LAUNCH_FOR_UNROLL_MOE_BWD(1); + LAUNCH_FOR_UNROLL_MOE_BWD(2); + LAUNCH_FOR_UNROLL_MOE_BWD(3); + LAUNCH_FOR_UNROLL_MOE_BWD(4); + LAUNCH_FOR_UNROLL_MOE_BWD(5); + LAUNCH_FOR_UNROLL_MOE_BWD(6); + } +} + +#define INSTANTIATE_MOE_GATING_BWD_FOR_TYPE(TYPE) \ + template void launch_moe_gating_bwd(TYPE * moe_input_grad, \ + float* scores_grad, \ + TYPE* activations_grad, \ + float* logits_grad, \ + float* logits, \ + const int32_t* assignments, \ + const int32_t* offsets, \ + const int32_t n_channels, \ + const int32_t n_experts, \ + const int32_t n_tokens, \ + int capacity, \ + cudaStream_t stream); \ + +INSTANTIATE_MOE_GATING_BWD_FOR_TYPE(__half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_MOE_GATING_BWD_FOR_TYPE(__nv_bfloat16) +#endif diff --git a/deepspeed/tops/moe_gating/top2_moe_gating.cu b/deepspeed/tops/moe_gating/top2_moe_gating.cu new file mode 100644 index 000000000000..779893a43589 --- /dev/null +++ b/deepspeed/tops/moe_gating/top2_moe_gating.cu @@ -0,0 +1,930 @@ +#include "moe_gating.cuh" +#include "reduction_utils.h" +#include "stdio.h" +#include "tops_context.h" + +using ROp = reduce::ROpType; + + +template +__global__ void top_2_gating_kernel(int32_t* expert_counts, + float* scores, + int32_t* assignments, + int32_t* offsets, + const float* logits, + float* logits_out, + const int32_t n_experts, + const int32_t n_tokens) +{ + const int32_t token_idx = blockIdx.x; + const int32_t expert_idx = threadIdx.x; + const int32_t max_warps = 1024 / hw_warp_size; + + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + const float* token_logits = logits + token_idx * n_experts; + float* token_logits_out = logits_out + token_idx * n_experts; + + float logit_val; + if (expert_idx < n_experts) { + logit_val = token_logits[expert_idx]; + } else { + reduce::init(&logit_val); + } + float reduce_val = logit_val; + + int32_t local_assigned_experts[TOP_K]; + float local_assigned_logits[TOP_K]; + + // Training code tends to use ``torch.argmax`` to select the expert, which + // which has ties broken by the lower index. Since our fused comparison algorithm + // breaks ties by the higher index (since it's the lower 32-bits of the 64-bit + // comparison), we invert the expert index to break ties by the lower index. 
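+    // A host-side illustration of that inversion (assumed values, guarded out
+    // of the build): when the reduction prefers the higher index on ties,
+    // reducing over inverted indices and un-inverting the winner reproduces
+    // torch.argmax's lowest-index tie-breaking.
+#if 0
+#include <cstdio>
+
+// Arg-max that keeps the LATER index on ties, mimicking the fused comparison.
+static int argmax_prefer_high(const float* vals, int n)
+{
+    int best = 0;
+    for (int i = 1; i < n; ++i)
+        if (vals[i] >= vals[best]) best = i;
+    return best;
+}
+
+int main()
+{
+    const int n_experts = 4;
+    const float logits[n_experts] = {0.7f, 0.9f, 0.9f, 0.1f};  // tie: experts 1 and 2
+
+    float inverted[n_experts];
+    for (int e = 0; e < n_experts; ++e) inverted[n_experts - e - 1] = logits[e];
+
+    const int winner = n_experts - argmax_prefer_high(inverted, n_experts) - 1;
+    printf("selected expert = %d\n", winner);  // prints 1, matching torch.argmax
+    return 0;
+}
+#endif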
+ int32_t inverted_expert = n_experts - expert_idx - 1; + + // Find the top k logits + for (int i = 0; i < TOP_K; ++i) { + const reduce::IdxReduceResult res = + reduce::idx_reduce(tb, warp, reduce_val, inverted_expert); + local_assigned_experts[i] = n_experts - res.idx - 1; + local_assigned_logits[i] = res.val; + + // Set the max logit to -inf so that it is not selected again + if (threadIdx.x == n_experts - res.idx - 1) { reduce::init(&reduce_val); } + } + + const float max_logit = local_assigned_logits[0]; + float softlogit = __expf(logit_val - max_logit); + float softmax_sum = softlogit; + + reduce::block(tb, warp, softmax_sum); + if (expert_idx < n_experts) + token_logits_out[expert_idx] = softlogit / softmax_sum; + + if (threadIdx.x == 0) { + #pragma unroll + for (int i = 0; i < TOP_K; ++i) { + scores[token_idx * TOP_K + i] = __expf(local_assigned_logits[i] - max_logit) / softmax_sum; + assignments[token_idx * TOP_K + i] = local_assigned_experts[i]; + atomicAdd(expert_counts + n_experts * i + local_assigned_experts[i], 1); + } + } +} + +template +__global__ void refine_expert_mapping_for_top2(int32_t* expert_counts, + int32_t* mapped_expert_counts, + int32_t* expert_count_cumsums, + float* scores, + int32_t* assignments, + int32_t* offsets, + int32_t* backup_offsets, + const int capacity, + const int32_t n_tokens, + const int32_t n_experts, + std::pair seed){ + + const int32_t bidx = blockIdx.x; + const int32_t tidx = threadIdx.x; + int32_t token_idx = bidx * blockDim.x + tidx; + if (token_idx >= n_tokens) { + return; + } + + int32_t assignment[TOP_K]; + #pragma unroll + for (int i = 0; i < TOP_K; ++i) + assignment[i] = assignments[(token_idx * TOP_K) + i]; + + int idx = token_idx << (1 + TOP_K); + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + +#pragma unroll + for (int i = 0; i < TOP_K; i++) { + float4 rand_val = curand_uniform4(&state); + + int32_t total_ec = expert_counts[assignment[i] + n_experts * i]; + float ratio = 1.0 - (float)(capacity * (TOP_K - i)) / (float)total_ec; + { + if ((rand_val.z > ratio && rand_val.x > ratio && rand_val.y > ratio) || + (rand_val.z <= ratio && (rand_val.x > ratio || rand_val.y > ratio))) { + offsets[token_idx * TOP_K + i] = atomicAdd(mapped_expert_counts + n_experts * i + assignment[i], 1); + } + else{ + offsets[token_idx * TOP_K + i] = gating::unassigned; + backup_offsets[token_idx * TOP_K + i] = atomicAdd(expert_count_cumsums + n_experts * i + assignment[i], 1); + } + } + } +} + +template +__global__ void top2_gate_logits_bwd_kernel(float* logits_grad, + float* scores_grad, + const int32_t* assignment, + float* logits, + const int32_t n_experts, + const int32_t n_tokens) +{ + const int32_t token_idx = blockIdx.x; + const int32_t expert_idx = threadIdx.x; + + int32_t assigned_expert[TOP_K]; + +#pragma unroll + for (int i = 0; i < TOP_K; i++) + assigned_expert[i] = assignment[token_idx * TOP_K + i]; + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // Padding tokens do not require + if (token_idx >= n_tokens) { + return; + } + + float* token_logits = logits + token_idx * n_experts; + float* token_logits_grad = logits_grad + token_idx * n_experts; + + float logit_val; + float logit_grad_val; + if (expert_idx < n_experts) { + logit_val = token_logits[expert_idx]; + logit_grad_val = token_logits_grad[expert_idx]; + } else { + reduce::init(&logit_val); + reduce::init(&logit_grad_val); + } + + +#pragma unroll + for (int i = 0; i < 
TOP_K; i++) + { + if (assigned_expert[i] == expert_idx) { + logit_grad_val += scores_grad[token_idx * TOP_K + i]; + } + } + float softmax_grad_sum = logit_val * logit_grad_val; + reduce::block(tb, warp, softmax_grad_sum); + logit_grad_val = logit_val * (logit_grad_val - softmax_grad_sum); + if (expert_idx < n_experts) + token_logits_grad[expert_idx] = logit_grad_val; +} + +template +__global__ void moe_top2_gather_kernel(T* layer_output, + const T* moe_output, + const float* scores, + const int32_t* mapped_slots, + const int32_t n_channels, + const int32_t num_tokens) +{ + constexpr int32_t vector_size = scatter_gather::access_granularity / sizeof(T); + constexpr int32_t stride = vector_size * scatter_gather::threads; + + const int32_t token_idx = blockIdx.x; + + int32_t mapped_slot[TOP_K]; + float score[TOP_K]; + + + float sum = 0.0f; + +#pragma unroll + for (int i = 0; i < TOP_K; i++) mapped_slot[i] = mapped_slots[token_idx * TOP_K + i]; + +#pragma unroll + for (int i = 0; i < TOP_K; i++) { + if (mapped_slot[i] != gating::unassigned && mapped_slot[i] < (num_tokens * TOP_K)) + { + score[i] = scores[token_idx * TOP_K + i]; + sum += score[i]; + } + } + sum += 1.192092895e-07; + + const int32_t channel_offset = threadIdx.x * vector_size; + + T* layer_output_base = layer_output + token_idx * n_channels + channel_offset; + +#pragma unroll + for (int i = 0; i < copyUnroll; i++) { + if (i * stride + channel_offset < n_channels) + { + float accumulator[vector_size]; + if (mapped_slot[0] != gating::unassigned && mapped_slot[0] < (num_tokens * TOP_K)) + { + T read_buf[vector_size]; + const T* moe_output_base = moe_output + mapped_slot[0] * n_channels + channel_offset; + mem_access::load_global( + read_buf, + moe_output_base + i * stride + ); +#pragma unroll + for (int j = 0; j < vector_size; j++) { + float up_cast = conversion::to(read_buf[j]); + accumulator[j] = up_cast * (score[0] / sum); + } + } + else{ +#pragma unroll + for (int j = 0; j < vector_size; j++) { + accumulator[j] = 0.f; + } + } +#pragma unroll + for (int k = 1; k < TOP_K; k++) { + const T* moe_output_base = moe_output + mapped_slot[k] * n_channels + channel_offset; + + if (mapped_slot[k] != gating::unassigned && mapped_slot[k] < (num_tokens * TOP_K)) + { + T read_buf[vector_size]; + mem_access::load_global( + read_buf, + moe_output_base + i * stride + ); +#pragma unroll + for (int j = 0; j < vector_size; j++) { + float up_cast = conversion::to(read_buf[j]); + accumulator[j] += up_cast * (score[k] / sum); + } + } + } + + T reg_buffer[vector_size]; +#pragma unroll + for (int j = 0; j < vector_size; j++) reg_buffer[j] = conversion::to(accumulator[j]); + + mem_access::store_global( + layer_output_base + i * stride, + reg_buffer + ); + } + } +} + +template +__global__ void moe_top2_gather_bwd_kernel(T* layer_output_grad, + float* scores_grad, + T* moe_output_grad, + T* moe_output, + const float* scores, + const int32_t* mapped_slots, + const int32_t n_channels, + const int32_t num_tokens) +{ + constexpr int32_t vector_size = scatter_gather::access_granularity / sizeof(T); + constexpr int32_t stride = vector_size * scatter_gather::threads; + + const int32_t token_idx = blockIdx.x; + + int32_t mapped_slot[TOP_K]; + float score[TOP_K]; + + + float sum = 0.0f; + +#pragma unroll + for (int i = 0; i < TOP_K; i++) mapped_slot[i] = mapped_slots[token_idx * TOP_K + i]; + +#pragma unroll + for (int i = 0; i < TOP_K; i++) { + score[i] = (mapped_slot[i] != gating::unassigned && mapped_slot[i] < (num_tokens * TOP_K)) ? 
scores[token_idx * TOP_K + i] : 0.f; + sum += score[i]; + } + sum += 1.192092895e-07; + + const int32_t channel_offset = threadIdx.x * vector_size; + + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + T* layer_output_grad_base = layer_output_grad + token_idx * n_channels + channel_offset; + // float score_grad[TOP_K]; + float score_out_grad[TOP_K]; + +#pragma unroll + for (int j = 0; j < TOP_K; j++) { + // score_grad[j] = 0.f; + score_out_grad[j] = 0.f; + } + +#pragma unroll + for (int i = 0; i < copyUnroll; i++) { + + if (i * stride + channel_offset < n_channels) + { + float reg_buffer[vector_size]; + { + T read_buf[vector_size]; + mem_access::load_global( + read_buf, layer_output_grad_base + i * stride); +#pragma unroll + for (int j = 0; j < vector_size; j++) reg_buffer[j] = conversion::to(read_buf[j]); + } + +#pragma unroll + for (int k = 0; k < TOP_K; k++) { + T store_buffer[vector_size]; + if (mapped_slot[k] != gating::unassigned && mapped_slot[k] < (num_tokens * TOP_K)) + { + T out_buffer[vector_size]; + T* moe_output_base = moe_output + mapped_slot[k] * n_channels + channel_offset; + T* moe_output_grad_base = moe_output_grad + mapped_slot[k] * n_channels + channel_offset; + mem_access::load_global( + out_buffer, moe_output_base + i * stride + ); + +#pragma unroll + for (int j = 0; j < vector_size; j++) { + float out_up_cast = conversion::to(out_buffer[j]); + store_buffer[j] = conversion::to(reg_buffer[j] * (score[k] / sum)); + for (int m = 0;m < TOP_K;m++) + score_out_grad[m] += (float)((double)(reg_buffer[j] * out_up_cast * + (m == k ? (sum - score[k]) : (-score[m]))) / (double)(sum * sum)); + } + mem_access::store_global( + moe_output_grad_base + i * stride, store_buffer + ); + } + } + } + } + + for (int j = 0; j < TOP_K; j++) + reduce::_block(tb, warp, score_out_grad + j); + + if (threadIdx.x == 0) { +#pragma unroll + for (int j = 0; j < TOP_K; j++) + { + scores_grad[token_idx * TOP_K + j] = (float)score_out_grad[j]; + } + } +} + +template +__global__ void moe_top2_scatter_bwd_kernel(T* moe_input_grad, + T* activations_grad, + const int32_t* assignments, + const int32_t* mapped_slots, + const int32_t n_channels, + const int capacity, + const int32_t num_tokens, + const int32_t n_experts) +{ + constexpr int32_t vector_size = scatter_gather::access_granularity / sizeof(T); + constexpr int32_t load_stride = vector_size * scatter_gather::threads; + + const int32_t token_idx = blockIdx.x; + const int32_t tidx = threadIdx.x; + const int32_t warp_rank = tidx / hw_warp_size; + + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + int assigned_expert[TOP_K]; + +#pragma unroll + for (int i = 0; i < TOP_K; i++) + assigned_expert[i] = assignments[token_idx * TOP_K + i]; + + + // Data copy to appropriate location + const int32_t thread_offset = tidx * vector_size; + + const int32_t base_load_offset = token_idx * n_channels + thread_offset; + T* load_base_ptr = activations_grad + base_load_offset; + int32_t store_row[TOP_K]; + +#pragma unroll + for (int i = 0; i < TOP_K; i++) + store_row[i] = mapped_slots[token_idx * TOP_K + i]; + +#pragma unroll + for (int i = 0; i < copyUnroll; i++) { + + float tmp_buf[vector_size]; + + if (i * load_stride + thread_offset < n_channels) { + if (assigned_expert[0] < n_experts && store_row[0] != gating::unassigned) + { + int32_t base_store_offset = store_row[0] * n_channels + thread_offset; + T* store_base_ptr = 
moe_input_grad + base_store_offset; + T reg_buffer[vector_size]; + mem_access::load_global(reg_buffer, store_base_ptr + i * load_stride); +#pragma unroll + for (int k = 0; k < vector_size; k++) + tmp_buf[k] = conversion::to(reg_buffer[k]); + } + else + { +#pragma unroll + for (int k = 0; k < vector_size; k++) + tmp_buf[k] = 0.f; + } +#pragma unroll + for (int k = 1; k < TOP_K; k++) + { + if (assigned_expert[k] < n_experts && store_row[k] != gating::unassigned){ + T reg_buffer[vector_size]; + const int32_t base_store_offset = store_row[k] * n_channels + thread_offset; + T* store_base_ptr = moe_input_grad + base_store_offset; + mem_access::load_global(reg_buffer, store_base_ptr + i * load_stride); + + #pragma unroll + for (int j = 0; j < vector_size; j++) + tmp_buf[j] += conversion::to(reg_buffer[j]); + } + } + T store_buf[vector_size]; + +#pragma unroll + for (int k = 0; k < vector_size; k++) + store_buf[k] = conversion::to(tmp_buf[k]); + mem_access::store_global(load_base_ptr + i * load_stride, store_buf); + } + } +} + +template +__global__ void moe_top2_scatter_kernel(T* moe_input, + int32_t* expert_count_cumsums, + int32_t* mapped_slots, + float* scores, + const T* activations, + int32_t* assignments, + const int32_t* expert_counts, + const int32_t* mapped_expert_counts, + int32_t* offsets, + const int32_t* backup_offsets, + const int32_t n_channels, + const int32_t n_experts, + const int capacity, + const int32_t num_tokens) +{ + constexpr int32_t vector_size = scatter_gather::access_granularity / sizeof(T); + constexpr int32_t load_stride = vector_size * scatter_gather::threads; + + const int32_t token_idx = blockIdx.x; + const int32_t tidx = threadIdx.x; + const int32_t warp_rank = tidx / hw_warp_size; + + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + int assigned_expert[TOP_K]; + int32_t offset[TOP_K]; + int mapped_expert_count[TOP_K]; + int expert_count[TOP_K]; +#pragma unroll + for (int i = 0; i < TOP_K; i++) { + + assigned_expert[i] = assignments[token_idx * TOP_K + i]; + + // expert_count[i] = expert_counts[assigned_expert[i] + n_experts * i]; + // mapped_expert_count[i] = mapped_expert_counts[assigned_expert[i] + n_experts * i]; + + offset[i] = offsets[token_idx + i * num_tokens]; + + // if (offset[i] == gating::unassigned && expert_count[i] > (capacity * TOP_K) && mapped_expert_count[i] < (capacity * TOP_K)) + // { + // offset[i] = backup_offsets[token_idx * TOP_K + i] + mapped_expert_count[i]; + // // if (tidx == 0) offsets[token_idx * TOP_K + i] = offset[i] + offsets_offset; + // } + + // if (offset[i] != gating::unassigned) { + // int32_t offsets_offset = 0; + // for (int j = 0; j < i; j++) offsets_offset += expert_count[j]; + // offset[i] += offsets_offset; + // } + int32_t other_offset = backup_offsets[token_idx + i * num_tokens]; + + if (other_offset == 0) //if (offset[i] >= (capacity * TOP_K) || offset[i] == gating::unassigned) + { + if (tidx == 0) { + mapped_slots[token_idx * TOP_K + i] = gating::unassigned; + assignments[token_idx * TOP_K + i] = n_experts; //gating::unassigned; + offset[i] = gating::unassigned; + scores[token_idx * TOP_K + i] = 0.f; + } + assigned_expert[i] = n_experts; + } + } + + // Data copy to appropriate location + const int32_t thread_offset = tidx * vector_size; + + const int32_t base_load_offset = token_idx * n_channels + thread_offset; + const T* load_base_ptr = activations + base_load_offset; + T* store_base_ptr[TOP_K]; + +#pragma unroll + for (int j = 0; j < TOP_K; 
j++) { + if (assigned_expert[j] < n_experts && assigned_expert[j] != gating::unassigned && offset[j] != gating::unassigned && offset[j] < (capacity * TOP_K)) + { + int32_t store_row = capacity * TOP_K * assigned_expert[j] + offset[j]; + int32_t base_store_offset = store_row * n_channels + thread_offset; + store_base_ptr[j] = moe_input + base_store_offset; + if (threadIdx.x == 0) { + mapped_slots[token_idx * TOP_K + j] = store_row; + } + } + else{ + store_base_ptr[j] = nullptr; + } + } +#pragma unroll + for (int i = 0; i < copyUnroll; i++) { + T tmp_buf[vector_size]; + if ((i * load_stride + thread_offset) < n_channels) + { + mem_access::load_global( + tmp_buf, load_base_ptr + i * load_stride); +#pragma unroll + for (int j = 0; j < TOP_K; j++) { + if (store_base_ptr[j] != nullptr) + mem_access::store_global( + store_base_ptr[j] + i * load_stride, tmp_buf); + + } + } + } +} + +#define LAUNCH_TOP2_FOR_UNROLL(COUNT) \ + case COUNT: \ + moe_top2_scatter_kernel<<>>(moe_input, \ + expert_count_cumsums, \ + mapped_slots, \ + scores, \ + activations, \ + assignments, \ + expert_counts, \ + mapped_expert_counts, \ + offsets, \ + backup_offsets, \ + n_channels, \ + n_experts, \ + capacity, \ + n_tokens); \ + break; + + +template +void launch_top2_moe_gating(T* moe_input, + int32_t* expert_count_cumsums, + int32_t* mapped_slots, + const T* activations, + int32_t* expert_counts, + int32_t* mapped_expert_counts, + float* scores, + int32_t* assignments, + int32_t* offsets, + int32_t* backup_offsets, + float* logits, + float* logits_out, + const int capacity, + const int32_t n_tokens, + const int32_t n_channels, + const int32_t n_experts, + const int32_t n_top_k, + cudaStream_t stream) +{ + const dim3 grid(n_tokens); + const dim3 block(((n_experts - 1) / hw_warp_size + 1) * hw_warp_size); + + std::pair seed = TOPSContext::Instance().IncrementOffset(16); + + TOP_K_SWITCH(n_top_k, [&] { + top_2_gating_kernel<<>>( + expert_counts, + scores, + assignments, + offsets, + logits, + logits_out, + n_experts, + n_tokens + ); + }); + const dim3 block2(scatter_gather::threads); + const dim3 grid2((n_tokens - 1) / scatter_gather::threads + 1); + + // TOP_K_SWITCH(n_top_k, [&] { + // refine_expert_mapping_for_top2<<>>( + // expert_counts, + // mapped_expert_counts, + // expert_count_cumsums, + // scores, + // assignments, + // offsets, + // backup_offsets, + // capacity, + // n_tokens, + // n_experts, + // seed + // ); + // }); + constexpr int vals_per_unroll = scatter_gather::threads * scatter_gather::access_granularity / sizeof(T); + const int copy_unroll = (n_channels + vals_per_unroll - 1) / vals_per_unroll; + + const dim3 block1(scatter_gather::threads); + const dim3 grid1(n_tokens); + + TOP_K_SWITCH(n_top_k, [&] { + switch (copy_unroll) { + LAUNCH_TOP2_FOR_UNROLL(1); + LAUNCH_TOP2_FOR_UNROLL(2); + LAUNCH_TOP2_FOR_UNROLL(3); + LAUNCH_TOP2_FOR_UNROLL(4); + LAUNCH_TOP2_FOR_UNROLL(5); + LAUNCH_TOP2_FOR_UNROLL(6); + } + }); +} + +#define INSTANTIATE_TOP2_MoE_Gating_KERNEL(T) \ + template void launch_top2_moe_gating(T* moe_input, \ + int32_t* expert_count_cumsums, \ + int32_t* mapped_slots, \ + const T* activations, \ + int32_t * expert_counts, \ + int32_t * mapped_expert_counts, \ + float* scores, \ + int32_t* assignments, \ + int32_t* offsets, \ + int32_t* backup_offsets, \ + float* logits, \ + float* logits_out, \ + const int capacity, \ + const int32_t n_tokens, \ + const int32_t n_channels, \ + const int32_t n_experts, \ + const int32_t n_top_k, \ + cudaStream_t stream); + 
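+// Illustrative arithmetic for the copy_unroll dispatch above (kept inside
+// #if 0). The real scatter_gather::threads / access_granularity values live
+// in headers not shown in this patch; 256 threads and 16-byte accesses are
+// assumed here purely to make the numbers concrete.
+#if 0
+#include <cstdio>
+
+constexpr int kThreads = 256;      // assumed scatter_gather::threads
+constexpr int kAccessBytes = 16;   // assumed scatter_gather::access_granularity
+
+template <typename T>
+constexpr int copy_unroll_for(int n_channels)
+{
+    // Values one thread block covers per unroll step.
+    constexpr int vals_per_unroll = kThreads * kAccessBytes / sizeof(T);
+    return (n_channels + vals_per_unroll - 1) / vals_per_unroll;
+}
+
+int main()
+{
+    // Using a 2-byte element as a stand-in for __half / bfloat16: each unroll
+    // step covers 2048 values, so 8192 channels need copy_unroll == 4. The
+    // switch above only instantiates factors 1..6, i.e. up to 12288 channels.
+    printf("copy_unroll = %d\n", copy_unroll_for<short>(8192));  // prints 4
+    return 0;
+}
+#endif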
+INSTANTIATE_TOP2_MoE_Gating_KERNEL(float) +INSTANTIATE_TOP2_MoE_Gating_KERNEL(__half) +#ifdef BF16_AVAILABLE +INSTANTIATE_TOP2_MoE_Gating_KERNEL(__nv_bfloat16) +#endif + + +template +void launch_top2_moe_scatter(T* moe_input, + int32_t* expert_count_cumsums, + int32_t* mapped_slots, + const T* activations, + int32_t* expert_counts, + int32_t* mapped_expert_counts, + float* scores, + int32_t* assignments, + int32_t* offsets, + int32_t* backup_offsets, + const int capacity, + const int32_t n_tokens, + const int32_t n_channels, + const int32_t n_experts, + const int32_t n_top_k, + cudaStream_t stream) +{ + constexpr int vals_per_unroll = scatter_gather::threads * scatter_gather::access_granularity / sizeof(T); + const int copy_unroll = (n_channels + vals_per_unroll - 1) / vals_per_unroll; + + const dim3 block1(scatter_gather::threads); + const dim3 grid1(n_tokens); + + TOP_K_SWITCH(n_top_k, [&] { + switch (copy_unroll) { + LAUNCH_TOP2_FOR_UNROLL(1); + LAUNCH_TOP2_FOR_UNROLL(2); + LAUNCH_TOP2_FOR_UNROLL(3); + LAUNCH_TOP2_FOR_UNROLL(4); + LAUNCH_TOP2_FOR_UNROLL(5); + LAUNCH_TOP2_FOR_UNROLL(6); + } + }); +} + +#define INSTANTIATE_TOP2_MoE_SCATTER_KERNEL(T) \ + template void launch_top2_moe_scatter(T* moe_input, \ + int32_t* expert_count_cumsums, \ + int32_t* mapped_slots, \ + const T* activations, \ + int32_t * expert_counts, \ + int32_t * mapped_expert_counts, \ + float* scores, \ + int32_t* assignments, \ + int32_t* offsets, \ + int32_t* backup_offsets, \ + const int capacity, \ + const int32_t n_tokens, \ + const int32_t n_channels, \ + const int32_t n_experts, \ + const int32_t n_top_k, \ + cudaStream_t stream); + +INSTANTIATE_TOP2_MoE_SCATTER_KERNEL(float) +INSTANTIATE_TOP2_MoE_SCATTER_KERNEL(__half) +#ifdef BF16_AVAILABLE +INSTANTIATE_TOP2_MoE_SCATTER_KERNEL(__nv_bfloat16) +#endif + + + +#define LAUNCH_FOR_UNROLL_GATHER_TOP2(COUNT) \ + case COUNT: \ + moe_top2_gather_kernel<<>>(layer_output, \ + moe_output, \ + scores, \ + mapped_slots, \ + n_channels, \ + n_tokens); \ + break; + +template +void launch_top2_moe_gather(T* layer_output, + const T* moe_output, + const float* scores, + const int32_t* mapped_slots, + const int32_t n_channels, + const int32_t n_tokens, + const int32_t n_top_k, + cudaStream_t stream) +{ + constexpr int vals_per_unroll = scatter_gather::threads * scatter_gather::access_granularity / sizeof(T); + const int copy_unroll = (n_channels + vals_per_unroll - 1) / vals_per_unroll; + + const dim3 block(scatter_gather::threads); + const dim3 grid(n_tokens); + + TOP_K_SWITCH(n_top_k, [&] { + switch (copy_unroll) { + LAUNCH_FOR_UNROLL_GATHER_TOP2(1) + LAUNCH_FOR_UNROLL_GATHER_TOP2(2) + LAUNCH_FOR_UNROLL_GATHER_TOP2(3) + LAUNCH_FOR_UNROLL_GATHER_TOP2(4) + LAUNCH_FOR_UNROLL_GATHER_TOP2(5) + LAUNCH_FOR_UNROLL_GATHER_TOP2(6) + } + }); +} + +#define INSTANTIATE_TOP2_GATHER_FOR_TYPE(TYPE) \ + template void launch_top2_moe_gather(TYPE * layer_output, \ + const TYPE* moe_output, \ + const float* scores, \ + const int32_t* mapped_slots, \ + const int32_t n_channels, \ + const int32_t n_tokens, \ + const int32_t n_top_k, \ + cudaStream_t stream); \ + +INSTANTIATE_TOP2_GATHER_FOR_TYPE(__half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_TOP2_GATHER_FOR_TYPE(__nv_bfloat16) +#endif + + +#define LAUNCH_FOR_UNROLL_GATHER_TOP2_BWD(COUNT) \ + case COUNT: \ + moe_top2_gather_bwd_kernel<<>>(layer_output_grad, \ + scores_grad, \ + moe_output_grad, \ + moe_output, \ + scores, \ + mapped_slots, \ + n_channels, \ + n_tokens); \ + break; + +template +void launch_top2_moe_gather_bwd(T* layer_output_grad, + 
float* scores_grad, + T* moe_output_grad, + T* moe_output, + const float* scores, + const int32_t* mapped_slots, + const int32_t n_channels, + const int32_t n_tokens, + const int32_t n_top_k, + cudaStream_t stream) +{ + constexpr int vals_per_unroll = scatter_gather::threads * scatter_gather::access_granularity / sizeof(T); + const int copy_unroll = (n_channels + vals_per_unroll - 1) / vals_per_unroll; + + const dim3 block(scatter_gather::threads); + const dim3 grid(n_tokens); + + TOP_K_SWITCH(n_top_k, [&] { + switch (copy_unroll) { + LAUNCH_FOR_UNROLL_GATHER_TOP2_BWD(1) + LAUNCH_FOR_UNROLL_GATHER_TOP2_BWD(2) + LAUNCH_FOR_UNROLL_GATHER_TOP2_BWD(3) + LAUNCH_FOR_UNROLL_GATHER_TOP2_BWD(4) + LAUNCH_FOR_UNROLL_GATHER_TOP2_BWD(5) + LAUNCH_FOR_UNROLL_GATHER_TOP2_BWD(6) + } + }); +} + +#define INSTANTIATE_TOP2_GATHER_BWD_FOR_TYPE(TYPE) \ + template void launch_top2_moe_gather_bwd(TYPE * layer_output_grad, \ + float* scores_grad, \ + TYPE* moe_output_grad, \ + TYPE* moe_output, \ + const float* scores, \ + const int32_t* mapped_slots, \ + const int32_t n_channels, \ + const int32_t n_tokens, \ + const int32_t n_top_k, \ + cudaStream_t stream); + +INSTANTIATE_TOP2_GATHER_BWD_FOR_TYPE(__half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_TOP2_GATHER_BWD_FOR_TYPE(__nv_bfloat16) +#endif + + +#define LAUNCH_FOR_UNROLL_TOP2_MOE_BWD(COUNT) \ + case COUNT: \ + moe_top2_scatter_bwd_kernel<<>>(moe_input_grad, \ + activations_grad, \ + assignments, \ + mapped_slots, \ + n_channels, \ + capacity, \ + n_tokens, \ + n_experts); \ + break; + +template +void launch_top2_moe_gating_bwd(T* moe_input_grad, + float* scores_grad, + T* activations_grad, + float* logits_grad, + float* logits, + const int32_t* assignments, + const int32_t* mapped_slots, + const int32_t n_channels, + const int32_t n_experts, + const int32_t n_tokens, + int capacity, + const int32_t n_top_k, + cudaStream_t stream) +{ + const dim3 grid(n_tokens); + const dim3 block(((n_experts - 1) / hw_warp_size + 1) * hw_warp_size); + + TOP_K_SWITCH(n_top_k, [&] { + top2_gate_logits_bwd_kernel<<>> + (logits_grad, scores_grad, assignments, logits, n_experts, n_tokens); + }); + constexpr int vals_per_unroll = scatter_gather::threads * scatter_gather::access_granularity / sizeof(T); + const int copy_unroll = (n_channels + vals_per_unroll - 1) / vals_per_unroll; + + const dim3 block1(scatter_gather::threads); + const dim3 grid1(n_tokens); + + TOP_K_SWITCH(n_top_k, [&] { + switch (copy_unroll) { + LAUNCH_FOR_UNROLL_TOP2_MOE_BWD(1); + LAUNCH_FOR_UNROLL_TOP2_MOE_BWD(2); + LAUNCH_FOR_UNROLL_TOP2_MOE_BWD(3); + LAUNCH_FOR_UNROLL_TOP2_MOE_BWD(4); + LAUNCH_FOR_UNROLL_TOP2_MOE_BWD(5); + LAUNCH_FOR_UNROLL_TOP2_MOE_BWD(6); + } + }); +} + +#define INSTANTIATE_TOP2_MOE_GATING_BWD_FOR_TYPE(TYPE) \ + template void launch_top2_moe_gating_bwd(TYPE * moe_input_grad, \ + float* scores_grad, \ + TYPE* activations_grad, \ + float* logits_grad, \ + float* logits, \ + const int32_t* assignments, \ + const int32_t* mapped_slots, \ + const int32_t n_channels, \ + const int32_t n_experts, \ + const int32_t n_tokens, \ + int capacity, \ + const int32_t n_top_k, \ + cudaStream_t stream); \ + +INSTANTIATE_TOP2_MOE_GATING_BWD_FOR_TYPE(__half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_TOP2_MOE_GATING_BWD_FOR_TYPE(__nv_bfloat16) +#endif diff --git a/deepspeed/tops/rope/__init__.py b/deepspeed/tops/rope/__init__.py new file mode 100644 index 000000000000..550f91316842 --- /dev/null +++ b/deepspeed/tops/rope/__init__.py @@ -0,0 +1 @@ +from .rope import RoPE \ No newline at end of file diff --git 
a/deepspeed/tops/rope/rope-test.py b/deepspeed/tops/rope/rope-test.py new file mode 100644 index 000000000000..d5c9893d7b3a --- /dev/null +++ b/deepspeed/tops/rope/rope-test.py @@ -0,0 +1,16 @@ +import torch + +import deepspeed +from deepspeed.tops import RoPE +from megatron.model.rotary_pos_embedding import RotaryEmbedding + + +rotary_pos_emb = RotaryEmbedding(128) +rotary_pos_emb = self.rotary_pos_emb(rotary_pos_emb_len) + +rope = RoPE() +query_layer = torch.randn(4096, 1, 1, 128, device=torch.cuda.current_device(), dtype=torch.bfloat16) +key_layer = torch.randn(4096, 1, 1, 128, device=torch.cuda.current_device(), dtype=torch.bfloat16) + +query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb) +key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb) \ No newline at end of file diff --git a/deepspeed/tops/rope/rope.cpp b/deepspeed/tops/rope/rope.cpp new file mode 100644 index 000000000000..13d9b8ea3b4f --- /dev/null +++ b/deepspeed/tops/rope/rope.cpp @@ -0,0 +1,63 @@ +#include "rope.h" +#include + +#define DISPATCH_ROPE(T_TYPE, C_TYPE) \ + if (query.options().dtype() == torch::T_TYPE) { \ + launch_apply_rotary_pos_emb((C_TYPE*)query.data_ptr(), \ + (C_TYPE*)key.data_ptr(), \ + head_size, \ + seq_len, \ + rotary_dim, \ + offset, \ + num_heads, \ + batch, \ + rope_theta, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + +void rope_fwd(torch::Tensor& query, torch::Tensor& key, int rotary_dim, float rope_theta) +{ + int seq_len = query.size(0); + int batch = query.size(1); + int num_heads = query.size(2); + int head_size = query.size(3); + int offset = 0; + + DISPATCH_ROPE(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_ROPE(kBFloat16, __nv_bfloat16); +#endif + +} + +#define DISPATCH_ROPE_BWD(T_TYPE, C_TYPE) \ + if (query_grad.options().dtype() == torch::T_TYPE) { \ + launch_apply_rotary_pos_bwd_emb((C_TYPE*)query_grad.data_ptr(), \ + (C_TYPE*)key_grad.data_ptr(), \ + head_size, \ + seq_len, \ + rotary_dim, \ + offset, \ + num_heads, \ + batch, \ + rope_theta, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + +void rope_bwd(torch::Tensor& query_grad, torch::Tensor& key_grad, int rotary_dim, float rope_theta) +{ + + int seq_len = query_grad.size(0); + int batch = query_grad.size(1); + int num_heads = query_grad.size(2); + int head_size = query_grad.size(3); + int offset = 0; + + DISPATCH_ROPE_BWD(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_ROPE_BWD(kBFloat16, __nv_bfloat16); +#endif + +} diff --git a/deepspeed/tops/rope/rope.cu b/deepspeed/tops/rope/rope.cu new file mode 100644 index 000000000000..781180b4875e --- /dev/null +++ b/deepspeed/tops/rope/rope.cu @@ -0,0 +1,353 @@ +#include "rope.cuh" +#include "utils.h" + +#include "reduction_utils.h" +#include + +namespace rot_half { +constexpr int threads = 256; +} // namespace rot_half + +template +__global__ void apply_rotary_pos_half(T* mixed_query, + T* key_layer, + unsigned rotary_dim, + unsigned seq_len, + unsigned seq_offset, + unsigned num_heads, + unsigned head_size, + unsigned total_count, + float rope_theta) +{ + constexpr int T_per_thread = granularity / sizeof(T); + constexpr int heads_per_block = rot_half::threads / threadsPerHead; + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile head_group = cg::tiled_partition(tb); + + const int head_idx = blockIdx.x * heads_per_block + threadIdx.x / threadsPerHead; + const int cur_seq_idx = head_idx / (total_count / seq_len); + const int offset = head_idx * head_size; + + const int seq_idx = cur_seq_idx + seq_offset; + const int half_dim = 
rotary_dim >> 1; + const int half_dim_threads = half_dim / T_per_thread; + + if (head_idx < total_count) { + const int base_neuron_idx = head_group.thread_rank() * T_per_thread; + + T q[T_per_thread], k[T_per_thread]; + mem_access::load_global(q, mixed_query + offset + base_neuron_idx); + mem_access::load_global(k, key_layer + offset + base_neuron_idx); + +#pragma unroll + for (int i = 0; i < T_per_thread; i++) { + const int neuron_idx = base_neuron_idx + i; + if (neuron_idx < rotary_dim) { + float inv_freq = (float)((neuron_idx % half_dim) * 2) / (float)rotary_dim; + inv_freq = 1.0 / powf(rope_theta, inv_freq) * (float)seq_idx; + + float rotary_sign = (neuron_idx > (half_dim - 1) ? -1.0 : 1.0); + float q_rot = conversion::to(q[i]) * rotary_sign; + float k_rot = conversion::to(k[i]) * rotary_sign; + + const int target_lane = (neuron_idx < half_dim) + ? head_group.thread_rank() + half_dim_threads + : head_group.thread_rank() - half_dim_threads; + + const float q_rot_temp = head_group.shfl(q_rot, target_lane); + const float k_rot_temp = head_group.shfl(k_rot, target_lane); + + q[i] = conversion::to(conversion::to(q[i]) * cosf(inv_freq) + + q_rot_temp * sinf(inv_freq)); + k[i] = conversion::to(conversion::to(k[i]) * cosf(inv_freq) + + k_rot_temp * sinf(inv_freq)); + } + } + + mem_access::store_global(mixed_query + offset + base_neuron_idx, q); + mem_access::store_global(key_layer + offset + base_neuron_idx, k); + } +} + + +template +__global__ void apply_rotary_pos_bwd_half(T* mixed_query_grad, + T* key_layer_grad, + unsigned rotary_dim, + unsigned seq_len, + unsigned seq_offset, + unsigned num_heads, + unsigned head_size, + unsigned total_count, + float rope_theta + ) +{ + constexpr int T_per_thread = granularity / sizeof(T); + constexpr int heads_per_block = rot_half::threads / threadsPerHead; + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile head_group = cg::tiled_partition(tb); + + const int head_idx = blockIdx.x * heads_per_block + threadIdx.x / threadsPerHead; + const int cur_seq_idx = head_idx / (total_count / seq_len); + const int offset = head_idx * head_size; + + const int seq_idx = cur_seq_idx + seq_offset; + const int half_dim = rotary_dim >> 1; + const int half_dim_threads = half_dim / T_per_thread; + + if (head_idx < total_count) { + const int base_neuron_idx = head_group.thread_rank() * T_per_thread; + + T q[T_per_thread], k[T_per_thread]; + mem_access::load_global(q, mixed_query_grad + offset + base_neuron_idx); + mem_access::load_global(k, key_layer_grad + offset + base_neuron_idx); + +#pragma unroll + for (int i = 0; i < T_per_thread; i++) { + const int neuron_idx = base_neuron_idx + i; + if (neuron_idx < rotary_dim) { + float inv_freq = (float)((neuron_idx % half_dim) * 2) / (float)rotary_dim; + inv_freq = 1.0 / powf(rope_theta, inv_freq) * (float)seq_idx; + + float rotary_sign = (neuron_idx > (half_dim - 1) ? 1.0 : -1.0); + float q_rot = conversion::to(q[i]) * rotary_sign; + float k_rot = conversion::to(k[i]) * rotary_sign; + + const int target_lane = (neuron_idx < half_dim) + ? 
head_group.thread_rank() + half_dim_threads + : head_group.thread_rank() - half_dim_threads; + + const float q_rot_temp = head_group.shfl(q_rot, target_lane); + const float k_rot_temp = head_group.shfl(k_rot, target_lane); + + q[i] = conversion::to(conversion::to(q[i]) * cosf(inv_freq) + + q_rot_temp * sinf(inv_freq)); + k[i] = conversion::to(conversion::to(k[i]) * cosf(inv_freq) + + k_rot_temp * sinf(inv_freq)); + } + } + + mem_access::store_global(mixed_query_grad + offset + base_neuron_idx, q); + mem_access::store_global(key_layer_grad + offset + base_neuron_idx, k); + } +} + +#define LAUNCH_ROT_POS_EMB_BWD_HALF(HEAD_THREADS, ALIGNMENT) \ + apply_rotary_pos_bwd_half<<>>(mixed_query_grad, \ + key_layer_grad, \ + rotary_dim, \ + seq_len, \ + offset, \ + num_heads, \ + head_size, \ + total_count, \ + rope_theta); + + +#ifdef __HIP_PLATFORM_AMD__ +#define LAUNCH_BWD_FOR_ALIGNMENT(ALIGNMENT) \ + if (threads_per_head == 4) { \ + LAUNCH_ROT_POS_EMB_BWD_HALF(4, ALIGNMENT); \ + } else if (threads_per_head == 8) { \ + LAUNCH_ROT_POS_EMB_BWD_HALF(8, ALIGNMENT); \ + } else if (threads_per_head == 16) { \ + LAUNCH_ROT_POS_EMB_BWD_HALF(16, ALIGNMENT); \ + } else if (threads_per_head == 32) { \ + LAUNCH_ROT_POS_EMB_BWD_HALF(32, ALIGNMENT); \ + } else if (threads_per_head == 64) { \ + LAUNCH_ROT_POS_EMB_BWD_HALF(64, ALIGNMENT); \ + } else { \ + assert(false); \ + } +#else +#define LAUNCH_BWD_FOR_ALIGNMENT(ALIGNMENT) \ + if (threads_per_head == 4) { \ + LAUNCH_ROT_POS_EMB_BWD_HALF(4, ALIGNMENT); \ + } else if (threads_per_head == 8) { \ + LAUNCH_ROT_POS_EMB_BWD_HALF(8, ALIGNMENT); \ + } else if (threads_per_head == 16) { \ + LAUNCH_ROT_POS_EMB_BWD_HALF(16, ALIGNMENT); \ + } else if (threads_per_head == 32) { \ + LAUNCH_ROT_POS_EMB_BWD_HALF(32, ALIGNMENT); \ + } else { \ + assert(false); \ + } +#endif + +#define LAUNCH_ROT_POS_EMB_HALF(HEAD_THREADS, ALIGNMENT) \ + apply_rotary_pos_half<<>>(mixed_query, \ + key_layer, \ + rotary_dim, \ + seq_len, \ + offset, \ + num_heads, \ + head_size, \ + total_count, \ + rope_theta); + +#ifdef __HIP_PLATFORM_AMD__ +#define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \ + if (threads_per_head == 4) { \ + LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \ + } else if (threads_per_head == 8) { \ + LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \ + } else if (threads_per_head == 16) { \ + LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \ + } else if (threads_per_head == 32) { \ + LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \ + } else if (threads_per_head == 64) { \ + LAUNCH_ROT_POS_EMB_HALF(64, ALIGNMENT); \ + } else { \ + assert(false); \ + } +#else +#define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \ + if (threads_per_head == 4) { \ + LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \ + } else if (threads_per_head == 8) { \ + LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \ + } else if (threads_per_head == 16) { \ + LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \ + } else if (threads_per_head == 32) { \ + LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \ + } else { \ + assert(false); \ + } +#endif + +template +void launch_apply_rotary_pos_emb(T* mixed_query, + T* key_layer, + unsigned head_size, + unsigned seq_len, + unsigned rotary_dim, + unsigned offset, + unsigned num_heads, + unsigned batch, + float rope_theta, + cudaStream_t stream) +{ + const int half_dim = rotary_dim >> 1; + + int alignment = sizeof(T); + if (half_dim % (16 / sizeof(T)) == 0) { + alignment = 16; + } else if (half_dim % (8 / sizeof(T)) == 0) { + alignment = 8; + } else if (half_dim % (4 / sizeof(T)) == 0) { + alignment = 4; + } else { + assert(false); + } + const int T_per_elem = 
alignment / sizeof(T); + + int total_count = batch * num_heads * seq_len; + + const int padded_head_size = next_pow2(head_size); + + assert(padded_head_size <= hw_warp_size * T_per_elem); + + const int threads_per_head = padded_head_size / T_per_elem; + const int heads_per_block = rot_half::threads / threads_per_head; + + dim3 block(rot_half::threads); + dim3 grid((total_count + heads_per_block - 1) / heads_per_block); + + if (alignment == 4) { + LAUNCH_FOR_ALIGNMENT(4); + } else if (alignment == 8) { + LAUNCH_FOR_ALIGNMENT(8); + } else if (alignment == 16) { + LAUNCH_FOR_ALIGNMENT(16); + } else { + assert(false); + } +} + + +#define INSTANTIATE_LAUNCH_ROTARY_POS_EMB(T) \ + template void launch_apply_rotary_pos_emb(T*, \ + T*, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + float, \ + cudaStream_t); + +INSTANTIATE_LAUNCH_ROTARY_POS_EMB(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_ROTARY_POS_EMB(__nv_bfloat16); +#endif +INSTANTIATE_LAUNCH_ROTARY_POS_EMB(__half); + +template +void launch_apply_rotary_pos_bwd_emb(T* mixed_query_grad, + T* key_layer_grad, + unsigned head_size, + unsigned seq_len, + unsigned rotary_dim, + unsigned offset, + unsigned num_heads, + unsigned batch, + float rope_theta, + cudaStream_t stream) +{ + const int half_dim = rotary_dim >> 1; + + int alignment = sizeof(T); + if (half_dim % (16 / sizeof(T)) == 0) { + alignment = 16; + } else if (half_dim % (8 / sizeof(T)) == 0) { + alignment = 8; + } else if (half_dim % (4 / sizeof(T)) == 0) { + alignment = 4; + } else { + assert(false); + } + const int T_per_elem = alignment / sizeof(T); + + int total_count = batch * num_heads * seq_len; + + const int padded_head_size = next_pow2(head_size); + + assert(padded_head_size <= hw_warp_size * T_per_elem); + + const int threads_per_head = padded_head_size / T_per_elem; + const int heads_per_block = rot_half::threads / threads_per_head; + + dim3 block(rot_half::threads); + dim3 grid((total_count + heads_per_block - 1) / heads_per_block); + + if (alignment == 4) { + LAUNCH_BWD_FOR_ALIGNMENT(4); + } else if (alignment == 8) { + LAUNCH_BWD_FOR_ALIGNMENT(8); + } else if (alignment == 16) { + LAUNCH_BWD_FOR_ALIGNMENT(16); + } else { + assert(false); + } +} + +#define INSTANTIATE_LAUNCH_ROTARY_POS_BWD_EMB(T) \ + template void launch_apply_rotary_pos_bwd_emb(T*, \ + T*, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + float, \ + cudaStream_t); + +INSTANTIATE_LAUNCH_ROTARY_POS_BWD_EMB(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_ROTARY_POS_BWD_EMB(__nv_bfloat16); +#endif +INSTANTIATE_LAUNCH_ROTARY_POS_BWD_EMB(__half); \ No newline at end of file diff --git a/deepspeed/tops/rope/rope.py b/deepspeed/tops/rope/rope.py new file mode 100644 index 000000000000..4a649904cfdd --- /dev/null +++ b/deepspeed/tops/rope/rope.py @@ -0,0 +1,52 @@ + +import torch + +from typing import Tuple + +from deepspeed.ops.op_builder import TopsBuilder + +inf_module = None + +class RoPEFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, dim, theta): + q = q.contiguous() + k = k.contiguous() + inf_module.rope_fwd(q, k, dim, theta) + + ctx.dim = dim + ctx.theta = theta + + return q, k + + @staticmethod + def backward(ctx, q_grad, k_grad): + q_grad = q_grad.contiguous() + k_grad = k_grad.contiguous() + + inf_module.rope_bwd(q_grad, k_grad, ctx.dim, ctx.theta) + + return q_grad, k_grad, None, None + +class RoPE(torch.nn.Module): + + def __init__(self, rotary_dim=None, rope_theta=10000.0) -> None: + 
super(RoPE, self).__init__() + global inf_module + if inf_module is None: + inf_module = TopsBuilder().load() + self.rotary_dim = rotary_dim + self.rope_theta = rope_theta + + def forward(self, + query: torch.Tensor, + key: torch.Tensor, + ) -> torch.Tensor: + + if self.rotary_dim is None: + self.rotary_dim = query.shape[-1] + + return RoPEFunction.apply( + query, key, self.rotary_dim, self.rope_theta + ) diff --git a/deepspeed/tops/swiglu/__init__.py b/deepspeed/tops/swiglu/__init__.py new file mode 100644 index 000000000000..44687d8ffb1a --- /dev/null +++ b/deepspeed/tops/swiglu/__init__.py @@ -0,0 +1 @@ +from .swiglu import SwiGlu \ No newline at end of file diff --git a/deepspeed/tops/swiglu/swiglu.cpp b/deepspeed/tops/swiglu/swiglu.cpp new file mode 100644 index 000000000000..a350d8e457c1 --- /dev/null +++ b/deepspeed/tops/swiglu/swiglu.cpp @@ -0,0 +1,55 @@ +#include "swiglu.h" +#include + +#define DISPATCH_SWIGLU(T_TYPE, C_TYPE) \ + if (inp.options().dtype() == torch::T_TYPE) { \ + launch_swiglu((C_TYPE*)out.data_ptr(), \ + (C_TYPE*)inp.data_ptr(), \ + bsz, \ + hidden_size, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + + +void swiglu_fwd(torch::Tensor& inp, torch::Tensor& out) +{ + int inp_dims = inp.sizes().size(); + int hidden_size = inp.size(inp_dims - 1); + int bsz = inp.size(0); + for (int i = 1;i < inp_dims - 1;i++) + bsz *= inp.size(i); + // printf("bsz = %d, hidden_size = %d \n", bsz, hidden_size); + + DISPATCH_SWIGLU(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_SWIGLU(kBFloat16, __nv_bfloat16); +#endif + +} + +#define DISPATCH_SWIGLU_BWD(T_TYPE, C_TYPE) \ + if (inp.options().dtype() == torch::T_TYPE) { \ + launch_swiglu_bwd((C_TYPE*)inp_grad.data_ptr(), \ + (C_TYPE*)out_grad.data_ptr(), \ + (C_TYPE*)inp.data_ptr(), \ + bsz, \ + hidden_size, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + +void swiglu_bwd(torch::Tensor& inp, torch::Tensor& out_grad, torch::Tensor& inp_grad) +{ + int inp_dims = inp.sizes().size(); + int hidden_size = inp.size(inp_dims - 1); + int bsz = inp.size(0); + for (int i = 1;i < inp_dims-1;i++) + bsz *= inp.size(i); + + DISPATCH_SWIGLU_BWD(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_SWIGLU_BWD(kBFloat16, __nv_bfloat16); +#endif + +} diff --git a/deepspeed/tops/swiglu/swiglu.cu b/deepspeed/tops/swiglu/swiglu.cu new file mode 100644 index 000000000000..82c132b5e8e5 --- /dev/null +++ b/deepspeed/tops/swiglu/swiglu.cu @@ -0,0 +1,214 @@ +#include "swiglu.cuh" +#include "utils.h" + +#include + +DS_D_INLINE float gated_act_fn(float x, float y) +{ + return y * (x / (1.0f + expf(-x))); +} + +template +__global__ void swiglu_kernel(T* out, T* inp, int hidden_size) +{ + + constexpr int read_vector = 16 / sizeof(T); + constexpr int write_vector = read_vector; // / 2; + + const int row = blockIdx.x; + const int col = threadIdx.x * read_vector; + + + + T* input_row = inp + row * hidden_size; + T* output_row = out + row * (hidden_size >> 1); + +#pragma unroll + for (int i = 0; i < loopUnroll; i++) { + T read1[read_vector]; + T read2[read_vector]; + T store[write_vector]; + + const int read_offset = col + ((read_vector * i) << 10); + const int write_offset = col + ((write_vector * i) << 10); + + if (i != loopUnroll - 1 || read_offset < (hidden_size >> 1)) { + mem_access::load_global<16>(read1, input_row + read_offset); + mem_access::load_global<16>(read2, input_row + read_offset + (hidden_size >> 1)); + + for (int j = 0; j < write_vector; j++) { + float g_val = conversion::to(read1[j]); + float a_val = conversion::to(read2[j]) ; + + 
float act_val = gated_act_fn(g_val, a_val); + store[j] = conversion::to(act_val); + // if (threadIdx.x == 0 && blockIdx.x == 0) printf("I am here! %f %p %p %d\n", act_val, out, output_row, write_offset); + } + + mem_access::store_global<16>(output_row + write_offset, store); + } + } +} + + +#define DISPATCH_UNROLL(unroll_val) \ + swiglu_kernel \ + <<>>(out, inp, hidden_size); + + +template +void launch_swiglu(T* out, + T* inp, + int bsz, + int hidden_size, + cudaStream_t stream) +{ + const int threads = 1024; + const dim3 grid(bsz); + const dim3 block(threads); + constexpr int cols_per_unroll = threads * 16 / sizeof(T); + const int unroll = ((hidden_size >> 1) - 1) / cols_per_unroll + 1; + // printf("bsz = %d, cols_per_unroll = %d, unroll = %d, hidden_size = %d \n", bsz, cols_per_unroll, unroll, hidden_size); + + if (unroll == 1) { + DISPATCH_UNROLL(1); + } else if (unroll == 2) { + DISPATCH_UNROLL(2); + } else if (unroll == 3) { + DISPATCH_UNROLL(3); + } else if (unroll == 4) { + DISPATCH_UNROLL(4); + } else if (unroll == 5) { + DISPATCH_UNROLL(5); + } else if (unroll == 6) { + DISPATCH_UNROLL(6); + } else { + throw std::runtime_error( + "[RuntimeError]: SwiGlu kernel limit surpassed"); + } +} + +#define INSTANTIATE_FOR_TYPE(T) \ + template void launch_swiglu(T * out, \ + T* inp, \ + int bsz, \ + int hidden_size, \ + cudaStream_t stream); + +INSTANTIATE_FOR_TYPE(float) +INSTANTIATE_FOR_TYPE(__half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_FOR_TYPE(__nv_bfloat16) +#endif + + + +DS_D_INLINE float gated_act_bwd_fn(float &x, float &y, float &grad) +{ + float sigmoid = 1.0 / (1.0 + expf(-x)); + return y * grad * sigmoid * (1.0 + x * (1.0 - sigmoid)); +} + +DS_D_INLINE float sig_fn(float x, float grad) +{ + return grad * (x / (1.0f + expf(-x))); +} + + +template +__global__ void swiglu_bwd_kernel(T* inp_grad, T* out_grad, T* inp, int hidden_size) +{ + + constexpr int read_vector = 16 / sizeof(T); + constexpr int write_vector = read_vector; /// 2; + + const int row = blockIdx.x; + const int col = threadIdx.x * read_vector; + + T* input_row = inp + row * hidden_size; + T* inp_grad_row = inp_grad + row * hidden_size; + T* out_grad_row = out_grad + row * (hidden_size >> 1); + +#pragma unroll + for (int i = 0; i < loopUnroll; i++) { + T read1[read_vector]; + T read2[read_vector]; + T read_grad[write_vector]; + T store1[read_vector]; + T store2[read_vector]; + + const int read_offset = col + ((read_vector * i) << 10); + const int write_offset = col + ((write_vector * i) << 10); + + if (i != loopUnroll - 1 || read_offset < (hidden_size >> 1)) { + mem_access::load_global<16>(read1, input_row + read_offset); + mem_access::load_global<16>(read2, input_row + read_offset + (hidden_size >> 1)); + mem_access::load_global<16>(read_grad, out_grad_row + write_offset); + + for (int j = 0; j < write_vector; j++) { + float g_val = conversion::to(read1[j]); + float a_val = conversion::to(read2[j]) ; + float grad_val = conversion::to(read_grad[j]) ; + + float grad_y = sig_fn(g_val, grad_val); + float grad_x = gated_act_bwd_fn(g_val, a_val, grad_val); + + store1[j] = conversion::to(grad_x); + store2[j] = conversion::to(grad_y); + } + + mem_access::store_global<16>(inp_grad_row + read_offset, store1); + mem_access::store_global<16>(inp_grad_row + read_offset + (hidden_size >> 1), store2); + } + } +} + + +#define BWD_DISPATCH_UNROLL(unroll_val) \ + swiglu_bwd_kernel \ + <<>>(inp_grad, out_grad, inp, hidden_size); + + +template +void launch_swiglu_bwd(T* inp_grad, T* out_grad, T* inp, + int bsz, int hidden_size, + 
cudaStream_t stream) +{ + const int threads = 1024; + const dim3 grid(bsz); + const dim3 block(threads); + constexpr int cols_per_unroll = threads * 16 / sizeof(T); + const int unroll = ((hidden_size >> 1) - 1) / cols_per_unroll + 1; + if (unroll == 1) { + BWD_DISPATCH_UNROLL(1); + } else if (unroll == 2) { + BWD_DISPATCH_UNROLL(2); + } else if (unroll == 3) { + BWD_DISPATCH_UNROLL(3); + } else if (unroll == 4) { + BWD_DISPATCH_UNROLL(4); + } else if (unroll == 5) { + BWD_DISPATCH_UNROLL(5); + } else if (unroll == 6) { + BWD_DISPATCH_UNROLL(6); + } else { + throw std::runtime_error( + "[RuntimeError]: SwiGlu BWD kernel limit surpassed"); + } +} + +#define INSTANTIATE_BWD_FOR_TYPE(T) \ + template void launch_swiglu_bwd(T * inp_grad, \ + T * out_grad, \ + T* inp, \ + int bsz, \ + int hidden_size, \ + cudaStream_t stream); + +INSTANTIATE_BWD_FOR_TYPE(float) +INSTANTIATE_BWD_FOR_TYPE(__half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_BWD_FOR_TYPE(__nv_bfloat16) +#endif \ No newline at end of file diff --git a/deepspeed/tops/swiglu/swiglu.py b/deepspeed/tops/swiglu/swiglu.py new file mode 100644 index 000000000000..a6d7424f01eb --- /dev/null +++ b/deepspeed/tops/swiglu/swiglu.py @@ -0,0 +1,53 @@ + +import torch + +from typing import Tuple + +from deepspeed.ops.op_builder import TopsBuilder + +inf_module = None + +def bwd(x, y, grad): + x_float = x.float() + y_float = y.float() + g_float = grad.float() + return (g_float * y_float * \ + torch.nn.functional.sigmoid(x_float) * (1.0 + x_float * (1.0 - torch.nn.functional.sigmoid(x_float)))).to(x.dtype) + +class SwiGluFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, inp, is_grad_enabled): + out = torch.empty((inp.shape[:-1] + (inp.shape[-1] // 2,)), + dtype=inp.dtype, device=inp.device) + + inp = inp.contiguous() + inf_module.swiglu_fwd(inp, out) + + if is_grad_enabled: + ctx.save_for_backward(inp) + return out + + @staticmethod + def backward(ctx, grad_out): + (inp,) = ctx.saved_tensors + grad_out = grad_out.contiguous() + inp_grad = torch.empty_like(inp) + inf_module.swiglu_bwd(inp, grad_out, inp_grad) + return inp_grad, None + +class SwiGlu(torch.nn.Module): + + def __init__(self, ) -> None: + super(SwiGlu, self).__init__() + global inf_module + if inf_module is None: + inf_module = TopsBuilder().load() + + def forward(self, + inp: torch.Tensor, + ) -> torch.Tensor: + is_grad_enabled = torch.is_grad_enabled() + return SwiGluFunction.apply( + inp, is_grad_enabled + ) diff --git a/deepspeed/tops/swiglu/test_swiglu.py b/deepspeed/tops/swiglu/test_swiglu.py new file mode 100644 index 000000000000..b3b511e0aabe --- /dev/null +++ b/deepspeed/tops/swiglu/test_swiglu.py @@ -0,0 +1,54 @@ +import torch + +import deepspeed +from deepspeed.tops import SwiGlu +import torch.nn.functional as F +import time + +def calc_error(ds_out, pt_out): + error = (ds_out - pt_out).abs().float().sum() / pt_out.numel() + rel_error = ((pt_out - ds_out).abs() / (pt_out + 1e-5).abs()).float().sum() / pt_out.numel() + return error, rel_error + +def pt_swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + +a = torch.ones(4 * 4096, 16384, dtype=torch.bfloat16, device=torch.cuda.current_device(), requires_grad=True) +aa = torch.ones(4 * 4096, 16384, dtype=torch.bfloat16, device=torch.cuda.current_device(), requires_grad=True) + + +aa.retain_grad() + +swiglu = SwiGlu() +for _ in range(10): + ds_out = swiglu(a) + error = ds_out.sum() + error.backward() +torch.cuda.synchronize() +t0 = time.time() +for _ in range(100): + ds_out = swiglu(a) + error = 
ds_out.sum() + error.backward() +torch.cuda.synchronize() +t1 = time.time() +ds_time = t1 - t0 +print(ds_time, (a.numel() + ds_out.numel()) * 2 / ds_time / 1000000) + +pt_out = pt_swiglu(aa) +for _ in range(10): + pt_out = pt_swiglu(aa) + error1 = pt_out.sum() + error1.backward() +torch.cuda.synchronize() +t0 = time.time() +for _ in range(100): + pt_out = pt_swiglu(aa) + error1 = pt_out.sum() + error1.backward() +torch.cuda.synchronize() +t1 = time.time() +pt_time = t1 - t0 +print(pt_time, (a.numel() + ds_out.numel()) * 2 / pt_time / 1000000) +print(f'speedup: {pt_time/ds_time:.2f}x') \ No newline at end of file diff --git a/deepspeed/tops/tops.cpp b/deepspeed/tops/tops.cpp new file mode 100644 index 000000000000..a98f8f23f398 --- /dev/null +++ b/deepspeed/tops/tops.cpp @@ -0,0 +1,18 @@ +#include "moe_gating.h" +#include "rope.h" +#include "swiglu.h" + +#include + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("swiglu_fwd", &swiglu_fwd, "swiglu_fwd function (fwd)"); + m.def("swiglu_bwd", &swiglu_bwd, "swiglu_bwd function (bwd)"); + m.def("rope_fwd", &rope_fwd, "rope_fwd function (fwd)"); + m.def("rope_bwd", &rope_bwd, "rope_bwd function (bwd)"); + m.def("moe_gating_fwd", &gate_fwd, "MoE gating function (fwd)"); + m.def("moe_gating_scatter", &gate_scatter, "MoE gating scatter function (fwd)"); + m.def("moe_gating_bwd", &gate_bwd, "MoE gating function (bwd)"); + m.def("moe_gather_fwd", &gather_fwd, "MoE gather function (fwd)"); + m.def("moe_gather_bwd", &gather_bwd, "MoE gather function (bwd)"); +} \ No newline at end of file diff --git a/op_builder/tops.py b/op_builder/tops.py new file mode 100644 index 000000000000..d6e17cc27f4a --- /dev/null +++ b/op_builder/tops.py @@ -0,0 +1,83 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os + +from .builder import CUDAOpBuilder, installed_cuda_version + + +class TopsBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_TOPS_OPS" + NAME = "tops_ops" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.tops.{self.NAME}' + + def is_compatible(self, verbose=True): + try: + import torch + except ImportError: + self.warning("Please install torch if trying to pre-compile inference kernels") + return False + + cuda_okay = True + if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split('.')[0]) + cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda + if cuda_capability < 6: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + cuda_okay = False + if cuda_capability >= 8: + if torch_cuda_major < 11 or sys_cuda_major < 11: + self.warning("On Ampere and higher architectures please use CUDA 11+") + cuda_okay = False + return super().is_compatible(verbose) and cuda_okay + + def filter_ccs(self, ccs): + ccs_retained = [] + ccs_pruned = [] + for cc in ccs: + if int(cc[0]) >= 6: + ccs_retained.append(cc) + else: + ccs_pruned.append(cc) + if len(ccs_pruned) > 0: + self.warning(f"Filtered compute capabilities {ccs_pruned}") + return ccs_retained + + def get_prefix(self): + ds_path = self.deepspeed_src_path("deepspeed") + return "deepspeed" if os.path.isdir(ds_path) else ".." 
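+    # The kernel and binding sources below are all compiled into the single "tops_ops" extension that tops.cpp binds.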
+ + def sources(self): + sources = [ + "tops/swiglu/swiglu.cu", + "tops/rope/rope.cu", + "tops/moe_gating/top1_moe_gating.cu", + "tops/moe_gating/top2_moe_gating.cu", + "tops/swiglu/swiglu.cpp", + "tops/rope/rope.cpp", + "tops/moe_gating/moe_gating.cpp", + "tops/tops.cpp", + ] + + prefix = self.get_prefix() + sources = [os.path.join(prefix, src) for src in sources] + return sources + + def extra_ldflags(self): + return [] + + def include_paths(self): + include_dirs = ['tops/includes'] + prefix = self.get_prefix() + includes = [os.path.join(prefix, include_dir) for include_dir in include_dirs] + + return includes
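
For reviewers, a minimal usage sketch of the new SwiGlu op follows, assuming the tops extension builds via TopsBuilder and a CUDA device is available. The shapes and dtype are illustrative, and `ref_swiglu` is a hypothetical helper that mirrors the chunk-then-SiLU reference (`pt_swiglu`) in test_swiglu.py; it only sanity-checks the fused forward/backward path and is not part of the patch.

    import torch
    import torch.nn.functional as F
    from deepspeed.tops import SwiGlu

    def ref_swiglu(x):
        # Same convention as the fused kernel: the first half of the last
        # dimension is the SiLU gate, the second half is the value it scales.
        gate, up = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * up

    x = torch.randn(8, 4096, dtype=torch.bfloat16,
                    device=torch.cuda.current_device(), requires_grad=True)

    swiglu = SwiGlu()
    y_fused = swiglu(x)                              # shape (8, 2048): last dim is halved
    y_ref = ref_swiglu(x.detach())

    print((y_fused.detach() - y_ref).abs().max())    # elementwise forward error

    y_fused.sum().backward()                         # exercises swiglu_bwd
    print(x.grad.shape)                              # torch.Size([8, 4096])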