diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index dbd47249c..64d375f80 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -21,8 +21,10 @@ env: ENV_NAME_MMCV: 'pt1.11v1' GPU_REQUESTS: 1 SLURM_PAR_SH1988: ${{ vars.SLURM_PAR_SH1984 != '' && vars.SLURM_PAR_SH1984 || 'pat_rd -x SH-IDC1-10-198-8-58,SH-IDC1-10-198-8-87' }} + SLURM_PAR_SH1424: ${{ vars.SLURM_PAR_SH1424 != '' && vars.SLURM_PAR_SH1424 || 'pat_rd' }} SLURM_PAR_CAMB: ${{ vars.SLURM_PAR_CAMB != '' && vars.SLURM_PAR_CAMB || 'camb_mlu370_m8 --exclude HOST-10-142-11-120,HOST-10-142-11-126' }} CLUSTER_1988: SH1988 + CLUSTER_1424: SH1424 CLUSTER_CAMB: CAMB CLUSTER_ASCEND: ASCEND CLUSTER_TOPSRIDER: TOPSRIDER @@ -82,6 +84,9 @@ jobs: && rsync -a --delete ${GITHUB_WORKSPACE}/DIOPI/ ${CLUSTER_TOPSRIDER}:${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source/ || echo "failure to connect to topsrider" ssh ${CLUSTER_SUPA} "mkdir -p ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source" \ && rsync -a --delete ${GITHUB_WORKSPACE}/DIOPI/ ${CLUSTER_SUPA}:${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source/ || echo "failure to connect to supa" + ssh ${CLUSTER_1424} "mkdir -p ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source" \ + && rsync -a --delete ${GITHUB_WORKSPACE}/DIOPI/ ${CLUSTER_1424}:${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source/ || echo "failure to connect to sh1424" + lint: name: lint @@ -117,16 +122,24 @@ jobs: export CI=true srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_SH1988} --time=10 bash -c 'cd impl && bash scripts/build_impl.sh torch' || ( cd ${NFS_PATH}/${GITHUB_RUN_NUMBER}/ && rm -rf ${BUILD_TEST1} && exit 1 ) """ - - name: build-dyload + + Build-Nvidia-A100: + name: Build-Nvidia-A100 + runs-on: tps-diopi-ci + needs: [Rsync] + steps: + - name: build run: | - ssh ${CLUSTER_1988} """ - set -e - source ${ENV_PATH}/github_bashrc && source /mnt/cache/share/platform/env/${ENV_NAME} - cd ${NFS_PATH}/${GITHUB_RUN_NUMBER} && rm -rf ${BUILD_TEST2} && cp -R source ${BUILD_TEST2} && cd ${BUILD_TEST2} + ssh ${CLUSTER_1424} """ + set -ex + export USE_COVERAGE=ON + source /mnt/cache/share/platform/env/${ENV_NAME} + cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER} && rm -rf ${BUILD_TEST1} && cp -R source ${BUILD_TEST1} && cd ${BUILD_TEST1} export CI=true - srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_SH1988} --time=10 bash -c 'cd impl && bash scripts/build_impl.sh torch_dyload' || ( cd ${NFS_PATH}/${GITHUB_RUN_NUMBER}/ && rm -rf ${BUILD_TEST2} && exit 1 ) + srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_SH1424} --time=10 bash -c 'cd impl && bash scripts/build_impl.sh torch' || ( cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/ && rm -rf ${BUILD_TEST1} && exit 1 ) """ + Build-Camb: name: Build-Camb runs-on: tps-diopi-ci @@ -203,6 +216,32 @@ jobs: || ( cd ${NFS_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1} && git clean -xdf ${GEN_DATA} && exit 1 ) """ + Gen-Data-Op-Test-A100: + name: Gen-Data-Op-Test-A100 + runs-on: tps-diopi-ci + needs: [Build-Nvidia-A100] + steps: + - name: gen-test-data + run: | + ssh ${CLUSTER_1424} """ + set -e + export CI=true + source /mnt/cache/share/platform/env/${ENV_NAME} + cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER} && cd ${BUILD_TEST1} && cp -f scripts/ci/diopi_config_A100.py diopi_test/python/configs/diopi_configs.py && cd diopi_test/python && ls && + srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_SH1424} --time=10 --gres=gpu:${GPU_REQUESTS} bash -c 'python main.py --mode gen_data' \ + || ( cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1} && git clean -xdf ${GEN_DATA} && exit 1 ) + """ + 
- name: test-op + run: | + ssh ${CLUSTER_1424} """ + set -e + export CI=true + source /mnt/cache/share/platform/env/${ENV_NAME} && cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER} && cd ${BUILD_TEST1} + export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/impl/lib + echo \$LD_LIBRARY_PATH + srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_SH1424} --time=20 --gres=gpu:${GPU_REQUESTS} bash -c 'cd diopi_test/python && python main.py --mode gen_case && python main.py --mode run_test' + """ + Op-Test-Nvidia: name: Op-Test-Nvidia runs-on: tps-diopi-ci @@ -224,25 +263,24 @@ jobs: bash /mnt/cache/share/platform/dep/sonar/coverage_DIOPI_nv.sh ${NFS_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1} ${GITHUB_RUN_NUMBER} || echo "get coverage fail" fi """ - - name: increment coverage check - if: ${{ contains( github.event_name, 'pull_request' ) && contains( github.base_ref, 'main' ) }} + - name: test run: | ssh ${CLUSTER_1988} """ set -e + export CI=true source ${ENV_PATH}/github_bashrc && source /mnt/cache/share/platform/env/${ENV_NAME} && cd ${NFS_PATH}/${GITHUB_RUN_NUMBER} && cd ${BUILD_TEST1} - bash scripts/increment_coverage.sh ${REQUIRE_COVERAGE} + export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:${NFS_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/impl/lib + echo \$LD_LIBRARY_PATH + srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_SH1988} --time=20 --gres=gpu:${GPU_REQUESTS} bash -c 'cd diopi_test/python && python main.py --mode gen_case && + python main.py --mode run_test' """ - - name: dyload-test + - name: increment coverage check + if: ${{ contains( github.event_name, 'pull_request' ) && contains( github.base_ref, 'main' ) }} run: | ssh ${CLUSTER_1988} """ set -e - export CI=true - source ${ENV_PATH}/github_bashrc && source /mnt/cache/share/platform/env/${ENV_NAME} && cd ${NFS_PATH}/${GITHUB_RUN_NUMBER} && cd ${BUILD_TEST2} - rm -rf ${GEN_DATA} && ln -s ${NFS_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/${GEN_DATA} ${GEN_DATA} - export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:${NFS_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST2}/impl/lib - echo \$LD_LIBRARY_PATH - srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_SH1988} --time=20 --gres=gpu:${GPU_REQUESTS} bash -c 'cd diopi_test/python && python main.py --mode gen_case && - python main.py --mode run_test' + source ${ENV_PATH}/github_bashrc && source /mnt/cache/share/platform/env/${ENV_NAME} && cd ${NFS_PATH}/${GITHUB_RUN_NUMBER} && cd ${BUILD_TEST1} + bash scripts/increment_coverage.sh ${REQUIRE_COVERAGE} """ Rt-Test-Nvidia: diff --git a/diopi_test/diopi_stub/codegen/gen.py b/diopi_test/diopi_stub/codegen/gen.py index 9a7e9018c..83c57b209 100644 --- a/diopi_test/diopi_stub/codegen/gen.py +++ b/diopi_test/diopi_stub/codegen/gen.py @@ -100,16 +100,7 @@ def get_func_info(content): return type_change, args, attr_types, paras_can_be_none, ins_vector, outs_vector, out_ptr -def gen_functions(options, functions_fm): - _cur_dir = os.path.dirname(os.path.abspath(__file__)) - with open(os.path.join(_cur_dir, options.get('source_dir'), 'functions.h'), 'r', encoding='utf8')as f: - content = f.readlines() - exports = [] - device = options.get('device') - if device == 'ascend': - ft = OT.function_ascend_template - else: - ft = OT.function_template +def get_export(content, ft, exports): for idx, row in enumerate(content): if row.startswith("DIOPI_API"): row = row[10:] @@ -167,6 +158,23 @@ def gen_functions(options, functions_fm): else: exports.append(ft.substitute(env=dict(func_name=func_name, attrs=', '.join(arg_def), convert='', out_copy='', 
call_func=call_func))) + return exports + + +def gen_functions(options, functions_fm): + _cur_dir = os.path.dirname(os.path.abspath(__file__)) + with open(os.path.join(_cur_dir, options.get('source_dir'), 'functions.h'), 'r', encoding='utf8')as f: + content = f.readlines() + exports = [] + device = options.get('device') + if device == 'ascend': + ft = OT.function_ascend_template + else: + ft = OT.function_template + exports = get_export(content, ft, exports) + with open(os.path.join(_cur_dir, options.get('source_dir'), 'functions_ext.h'), 'r', encoding='utf8')as f: + content_ext = f.readlines() + exports = get_export(content_ext, ft, exports) functions_fm.write("export_functions.cpp", OT.operators_template, env=dict(export_functions=exports)) diff --git a/diopi_test/diopi_stub/codegen/op_template.py b/diopi_test/diopi_stub/codegen/op_template.py index b4e85a055..062b48a40 100644 --- a/diopi_test/diopi_stub/codegen/op_template.py +++ b/diopi_test/diopi_stub/codegen/op_template.py @@ -15,6 +15,7 @@ class OpTemplate(object): #include "litert.hpp" #include #include +#include namespace py = pybind11; diff --git a/diopi_test/python/configs/diopi_configs.py b/diopi_test/python/configs/diopi_configs.py index 5ca6ffa70..367f8798d 100644 --- a/diopi_test/python/configs/diopi_configs.py +++ b/diopi_test/python/configs/diopi_configs.py @@ -8018,4 +8018,94 @@ ], ), ), + + 'rotary_emb': dict( + name=['rotary_emb'], + interface=['CustomizedTest'], + dtype=[np.float64, np.float32, np.float16], + para=dict( + conj=[False, True, False, True], + ), + tensor_para=dict( + gen_fn='Genfunc.randn', + args=[ + { + "ins": ['input'], + "shape": ((1, 125, 16, 32), (1, 125, 16, 32), (2, 64, 16, 32), (3, 100, 8, 64)), + }, + { + "ins": ['cos'], + "shape": ((125, 1, 16), (125, 1, 16), (64, 1, 16), (100, 1, 32)), + }, + { + "ins": ['sin'], + "shape": ((125, 1, 16), (125, 1, 16), (64, 1, 16), (100, 1, 32)), + }, + ], + ), + ), + + 'rms_norm': dict( + name=['rms_norm'], + interface=['CustomizedTest'], + dtype=[np.float32], + para=dict( + eps=[1e-6, 1e-6, 1e-6, 1e-6], + normalized_shape=[(5, ), (32, ), (64, ), (8, )], + ), + tensor_para=dict( + gen_fn='Genfunc.randn', + args=[ + { + "ins": ['input'], + "shape": ((5, 5), (35, 125, 32), (16, 64, 64), (1, 32, 32, 8)), + }, + { + "ins": ['weight'], + "shape": ((5, ), (32, ), (64, ), (8, )), + }, + { + "ins": ['bias'], + "shape": ((5, ), (32, ), (64, ), (8, )), + }, + ], + ), + ), + + # 'multihead_attention_forward': dict( + # name=['multihead_attention_forward'], + # interface=['CustomizedTest'], + # dtype=[np.float16], + # atol=1e-3, + # rtol=1e-4, + # para=dict( + # dropout_p=[0, 0], + # is_causal=[False, False], + # return_debug_mask=[False, False], + # scale=[None, None] + # ), + # tensor_para=dict( + # gen_fn='Genfunc.randn', + # args=[ + # { + # "ins": ['q'], + # "shape": ((2, 2, 2, 8), (2, 5, 7, 8)), + # "dtype": [np..float16], + # "gen_fn": Genfunc.randn, + # }, + # { + # "ins": ['k'], + # "shape": ((2, 2, 2, 8), (2, 5, 7, 8)), + # "dtype": [np.float16], + # "gen_fn": Genfunc.randn, + # }, + # { + # "ins": ['v'], + # "shape": ((2, 2, 2, 8), (2, 5, 7, 8)), + # "dtype": [np.float16], + # "gen_fn": Genfunc.randn, + # }, + # ], + # ), + # ), } diff --git a/diopi_test/python/conformance/diopi_functions.py b/diopi_test/python/conformance/diopi_functions.py index cc6be9137..1ad82c6ae 100644 --- a/diopi_test/python/conformance/diopi_functions.py +++ b/diopi_test/python/conformance/diopi_functions.py @@ -3974,3 +3974,40 @@ def linalgqr(input, mode): out = [q, r] 
check_returncode(ret)
    return out
+
+
+def rotary_emb(input, cos, sin, conj):
+    call = "diopiRotaryEmbedding"
+    func = check_function(call)
+    size = list(input.size().data)
+    out = Tensor(size, input.get_dtype())
+    ret = func(input.context(), out, input, cos, sin, conj, False)
+    check_returncode(ret)
+    return out
+
+
+def rms_norm(input, normalized_shape, weight, bias, eps):
+    call = "diopiRMSNorm"
+    func = check_function(call)
+    size = list(input.size().data)
+    out = Tensor(size, input.get_dtype())
+    inv_rms = Tensor(size, input.get_dtype())
+    normalized_shape = Sizes(list(normalized_shape))
+    ret = func(input.context(), out, inv_rms, input, normalized_shape, weight, bias, eps)
+    check_returncode(ret)
+    return out
+
+
+def multihead_attention_forward(q, k, v, dropout_p, is_causal, return_debug_mask, scale):
+    call = "diopiMultiHeadAttention"
+    func = check_function(call)
+    q_size = list(q.size().data)
+    k_size = list(k.size().data)
+    out = Tensor(q_size, q.get_dtype())
+    softmax_lse = Tensor([q_size[0], q_size[2], q_size[1]], q.get_dtype())
+    gen = None
+    debug_attn_mask = Tensor([0], q.get_dtype())
+    softmax_scale = 1.0 / math.sqrt(q.shape().data[-1]) if not scale else scale
+    ret = func(q.context(), q, k, v, dropout_p, is_causal, return_debug_mask, softmax_scale, out, softmax_lse, gen, debug_attn_mask)
+    check_returncode(ret)
+    return out
diff --git a/diopi_test/python/conformance/gen_output.py b/diopi_test/python/conformance/gen_output.py
index efc5b5c7a..9100491e4 100644
--- a/diopi_test/python/conformance/gen_output.py
+++ b/diopi_test/python/conformance/gen_output.py
@@ -186,6 +186,47 @@ def batch_norm_elemt(input, weight, bias, mean, invstd, eps):
        out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        return out
+    def rotary_emb(input, cos, sin, conj):
+        x1, x2 = input.chunk(2, dim=-1)
+        data_type = input.dtype
+        x1 = x1.to(torch.float32)
+        x2 = x2.to(torch.float32)
+        cos = cos.to(torch.float32)
+        sin = sin.to(torch.float32)
+        if not conj:
+            out1 = x1 * cos - x2 * sin
+            out2 = x1 * sin + x2 * cos
+        else:
+            out1 = x1 * cos + x2 * sin
+            out2 = -x1 * sin + x2 * cos
+        out1 = out1.to(data_type)
+        out2 = out2.to(data_type)
+        out = torch.cat((out1, out2), dim=-1)
+        return out
+
+    def rms_norm(input, normalized_shape, weight, bias, eps):
+        variance = input.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        input = input * torch.rsqrt(variance + eps)
+        out = weight * input
+        return out
+
+    def multihead_attention_forward(q, k, v, dropout_p, is_causal, return_debug_mask, scale):
+        # dropout is disabled during testing so the reference result stays deterministic and precise
+        from einops import rearrange
+        import math
+
+        _, seqlen = q.shape[0], q.shape[1]
+        softmax_scale = 1.0 / math.sqrt(q.shape[-1]) if not scale else scale
+        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
+        if is_causal:
+            causal_mask = torch.triu(
+                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
+            )
+            scores = scores + causal_mask.to(dtype=scores.dtype)
+        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
+        output = torch.einsum("bhts,bshd->bthd", attention, v)
+        return output
+
class GenOutputData(object):
    r'''
diff --git a/impl/ascend/device_configs.py b/impl/ascend/device_configs.py
index ac3c7ec05..6cc7362ed 100755
--- a/impl/ascend/device_configs.py
+++ b/impl/ascend/device_configs.py
@@ -2440,5 +2440,29 @@
        interface=['torch'],
        atol=1e-4,
        rtol=1e-4,
+    ),
+
+    'rotary_emb': dict(
+        name=["rotary_emb"],
+        tensor_para=dict(
+            args=[
+                {
+                    "ins": ['input'],
+                    "dtype": [Skip(np.float64), Skip(np.float32), Skip(np.float16)],
+                },
+            ],
+
), + ), + + 'rms_norm': dict( + name=["rms_norm"], + tensor_para=dict( + args=[ + { + "ins": ['input'], + "dtype": [Skip(np.float32)], + }, + ], + ), ) } diff --git a/impl/camb/common/dtype_cast.cpp b/impl/camb/common/dtype_cast.cpp index 172a725d2..e3fdffdb6 100644 --- a/impl/camb/common/dtype_cast.cpp +++ b/impl/camb/common/dtype_cast.cpp @@ -22,6 +22,7 @@ bool isComplexDtype(diopiDtype_t dtype) { return (dtype == diopi_dtype_complex32 inline bool canCastByInt32(uint64_t castType) { // special convert (cnnl doesn't support) + constexpr uint64_t boolInt8 = MAKE_KEY(diopi_dtype_bool, diopi_dtype_int8); constexpr uint64_t boolInt64 = MAKE_KEY(diopi_dtype_bool, diopi_dtype_int64); constexpr uint64_t int16Int64 = MAKE_KEY(diopi_dtype_int16, diopi_dtype_int64); constexpr uint64_t uint8Bool = MAKE_KEY(diopi_dtype_uint8, diopi_dtype_bool); @@ -32,7 +33,7 @@ inline bool canCastByInt32(uint64_t castType) { constexpr uint64_t int64Int8 = MAKE_KEY(diopi_dtype_int64, diopi_dtype_int8); return boolInt64 == castType || int16Int64 == castType || uint8Bool == castType || int16Bool == castType || int64Bool == castType || int8Bool == castType || - int8Int64 == castType || int64Int8 == castType; + int8Int64 == castType || int64Int8 == castType || boolInt8 == castType; } inline bool canCastByFloat32(uint64_t castType) { diff --git a/impl/camb/device_configs.py b/impl/camb/device_configs.py index 63f297d28..6699618a5 100644 --- a/impl/camb/device_configs.py +++ b/impl/camb/device_configs.py @@ -1956,4 +1956,28 @@ atol=1e-2, rtol=1e-3, ), + + 'rotary_emb': dict( + name=["rotary_emb"], + tensor_para=dict( + args=[ + { + "ins": ['input'], + "dtype": [Skip(np.float64), Skip(np.float32), Skip(np.float16)], + }, + ], + ), + ), + + 'rms_norm': dict( + name=["rms_norm"], + tensor_para=dict( + args=[ + { + "ins": ['input'], + "dtype": [Skip(np.float32)], + }, + ], + ), + ), } diff --git a/impl/torch/CMakeLists.txt b/impl/torch/CMakeLists.txt index a68c1aa5c..5ef27894a 100644 --- a/impl/torch/CMakeLists.txt +++ b/impl/torch/CMakeLists.txt @@ -27,12 +27,13 @@ if (DYLOAD) set(IMPL_SRC error.cpp wrap_func.cpp) endif() -file(GLOB_RECURSE REAL_IMPL_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} functions_mmcv/*.cu) +file(GLOB REAL_IMPL_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} functions_mmcv/*.cu functions_ext/*.cu) set(REAL_IMPL_SRC ${REAL_IMPL_SRC} nms_kernel.cu roi_align_kernel.cu functions.cpp functions_mmcv.cpp + functions_ext.cpp error.cpp) if(HIP) @@ -74,7 +75,10 @@ else () target_link_libraries(${DEVICEIMPL} ${HIP_LIBRARIES} ${TORCH_LIBRARIES}) else() cuda_add_library(${DEVICEIMPL} SHARED ${REAL_IMPL_SRC}) + # target_compile_definitions(${DEVICEIMPL} PRIVATE __CUDA_NO_HALF_OPERATORS__) target_link_libraries(${DEVICEIMPL} ${CUDA_LIBRARIES} ${TORCH_LIBRARIES}) + add_subdirectory(functions_ext/flash-attention) + target_link_libraries(${DEVICEIMPL} diopi_torch_ext_flash_attn) endif() endif() diff --git a/impl/torch/code_gen.py b/impl/torch/code_gen.py index ff7260ace..dbe3df60c 100644 --- a/impl/torch/code_gen.py +++ b/impl/torch/code_gen.py @@ -21,6 +21,7 @@ */\n\ #include \n\ #include \n\ +#include \n\ #include \n\ #include \n\ \n\ @@ -129,6 +130,10 @@ def gen_wrapper_func(content): content_mmcv = f.readlines() print("generate for functions_mmcv.h") gen_wrapper_func(content_mmcv) + with open(os.path.join(_cur_dir, '../proto/include/diopi/functions_ext.h'), 'r') as f: + content_ext = f.readlines() + print("generate for functions_ext.h") + gen_wrapper_func(content_ext) os.system("rm -f wrap_func.cpp") print("generate 
wrap_func.cpp") with open('wrap_func.cpp', 'w') as f: diff --git a/impl/torch/cuda_helpers.h b/impl/torch/cuda_helpers.h index ba7ac86ca..ac277b0d1 100644 --- a/impl/torch/cuda_helpers.h +++ b/impl/torch/cuda_helpers.h @@ -7,27 +7,22 @@ #ifndef IMPL_TORCH_CUDA_HELPERS_H_ #define IMPL_TORCH_CUDA_HELPERS_H_ -#include -#include -#include -#include +#include // IWYU pragma: export +#include -#include -#include +#include // IWYU pragma: export +#include // IWYU pragma: export #include - -using at::Half; -using at::Tensor; -using phalf = at::Half; - -#define __PHALF(x) (x) -#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) +#include namespace cuda { namespace helper { -template -constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) { +using at::Tensor; + +template +constexpr __host__ __device__ inline IntType ceil_div(IntType n, IntType m) { + static_assert(std::is_integral::value, "ceil_div only accept integral types"); return (n + m - 1) / m; } @@ -41,12 +36,12 @@ constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) { for (size_t i = blockIdx.x; i < (n); i += gridDim.x) \ for (size_t j = blockIdx.y; j < (m); j += gridDim.y) -#define THREADS_PER_BLOCK 512 +constexpr int THREADS_PER_BLOCK = 512; inline int GET_BLOCKS(const int N, const int num_threads = THREADS_PER_BLOCK) { int optimal_block_num = (N + num_threads - 1) / num_threads; int max_block_num = 4096; - return min(optimal_block_num, max_block_num); + return std::min(optimal_block_num, max_block_num); } template diff --git a/impl/torch/ext_kernel.h b/impl/torch/ext_kernel.h new file mode 100644 index 000000000..3c347a740 --- /dev/null +++ b/impl/torch/ext_kernel.h @@ -0,0 +1,26 @@ +/** + * @file + * @author DeepLink + * @copyright (c) 2023, DeepLink. + */ + +#ifndef IMPL_TORCH_EXT_KERNEL_H_ +#define IMPL_TORCH_EXT_KERNEL_H_ + +#include +#include + +namespace ext { +namespace ops { + +void apply_rotary_cuda(const at::Tensor& x1, const at::Tensor& x2, const at::Tensor& cos, const at::Tensor& sin, at::Tensor out1, at::Tensor out2, + const bool conj); + +void rms_norm_forward(at::Tensor input, at::IntArrayRef normalized_shape, at::Tensor gamma, double epsilon, at::Tensor output, at::Tensor invvar); + +void rms_norm_backward(at::Tensor dout, at::Tensor invvar, at::Tensor input, at::IntArrayRef normalized_shape, at::Tensor gamma, double epsilon, + at::Tensor grad_input, at::Tensor grad_gamma); + +} // namespace ops +} // namespace ext +#endif // IMPL_TORCH_EXT_KERNEL_H_ diff --git a/impl/torch/functions_ext.cpp b/impl/torch/functions_ext.cpp new file mode 100644 index 000000000..12624ae8f --- /dev/null +++ b/impl/torch/functions_ext.cpp @@ -0,0 +1,253 @@ +/** + * @file + * @author DeepLink + * @copyright (c) 2023, DeepLink. 
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+// TODO(lljbash): the dependency on context.h makes no sense, check and refactor
+#include "context.h" // IWYU pragma: keep
+#include "ext_kernel.h"
+#include "helper.hpp"
+
+namespace {
+
+c10::optional buildGeneratorForMha(diopiContextHandle_t ctx, diopiGeneratorHandle_t gen, double dropoutP) {
+    if (gen == nullptr) {
+        if (dropoutP != 0) {
+            throw std::runtime_error("dropout option requires a generator to be set");
+        }
+        return c10::nullopt;
+    }
+    return impl::aten::buildGenerator(ctx, gen);
+}
+
+}  // namespace
+
+extern "C" {
+
+diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t x, diopiConstTensorHandle_t cos,
+                                  diopiConstTensorHandle_t sin, const bool conj, const bool interleaved) {
+    impl::aten::setCurCtx(ctx);
+    auto atX = impl::aten::buildATen(x);
+    auto atCos = impl::aten::buildATen(cos);
+    auto atSin = impl::aten::buildATen(sin);
+    auto atOut = impl::aten::buildATen(out);
+    int lastDim = atX.dim() - 1;          // index of the last dimension
+    auto chunks = atX.chunk(2, lastDim);  // split atX into two halves
+    auto x1 = chunks[0];
+    auto x2 = chunks[1];
+    auto chunksOut = atOut.chunk(2, lastDim);
+    auto out1 = chunksOut[0];
+    auto out2 = chunksOut[1];
+    ext::ops::apply_rotary_cuda(x1, x2, atCos, atSin, out1, out2, conj);
+    impl::aten::unsetCurCtx();
+    return diopiSuccess;
+}
+
+diopiError_t diopiRMSNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t invRMS, diopiConstTensorHandle_t input,
+                          diopiSize_t normalized_shape, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, double eps) {
+    impl::aten::setCurCtx(ctx);
+    auto atOut = impl::aten::buildATen(out);
+    auto atInvRMS = impl::aten::buildATen(invRMS);
+    auto atInput = impl::aten::buildATen(input);
+    auto atNormalizedShape = impl::aten::buildAtIntArray(normalized_shape);
+    auto atWeight = impl::aten::buildATen(weight);
+    auto atBias = impl::aten::buildATen(bias);  // bias is not actually used here
+    ext::ops::rms_norm_forward(atInput, atNormalizedShape, atWeight, eps, atOut, atInvRMS);
+    impl::aten::unsetCurCtx();
+    return diopiSuccess;
+}
+
+diopiError_t diopiRMSNormBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiTensorHandle_t gradWeight, diopiTensorHandle_t gradBias,
+                                  diopiConstTensorHandle_t gradOutput, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight,
+                                  diopiConstTensorHandle_t bias, diopiConstTensorHandle_t invRMS, diopiSize_t normalized_shape, double eps) {
+    impl::aten::setCurCtx(ctx);
+    auto atGradInput = impl::aten::buildATen(gradInput);
+    auto atGradWeight = impl::aten::buildATen(gradWeight);
+    auto atGradBias = impl::aten::buildATen(gradBias);
+    auto atGradOutput = impl::aten::buildATen(gradOutput);
+    auto atInvRMS = impl::aten::buildATen(invRMS);
+    auto atInput = impl::aten::buildATen(input);
+    auto atNormalizedShape = impl::aten::buildAtIntArray(normalized_shape);
+    auto atWeight = impl::aten::buildATen(weight);
+    auto atBias = impl::aten::buildATen(bias);  // bias is not actually used here
+    ext::ops::rms_norm_backward(atGradOutput, atInvRMS, atInput, atNormalizedShape, atWeight, eps, atGradInput, atGradWeight);
+    impl::aten::unsetCurCtx();
+    return diopiSuccess;
+}
+
+diopiError_t diopiMultiHeadAttention(diopiContextHandle_t ctx, diopiTensorHandle_t q, diopiTensorHandle_t k, diopiTensorHandle_t v, double dropout_p,
+                                     bool is_causal, bool return_debug_mask, double scale, diopiTensorHandle_t out, diopiTensorHandle_t softmax_lse,
+
diopiGeneratorHandle_t gen, diopiTensorHandle_t debug_attn_mask) { + impl::aten::setCurCtx(ctx); + + auto atQ = impl::aten::buildATen(q).contiguous(); + auto atK = impl::aten::buildATen(k).contiguous(); + auto atV = impl::aten::buildATen(v).contiguous(); + auto atGen = buildGeneratorForMha(ctx, gen, dropout_p); + + c10::optional nullOpt; // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea) + std::vector result = DIOPI_EXT_CALL_FLASH(mha_fwd, atQ, atK, atV, nullOpt, dropout_p, scale, is_causal, -1, -1, return_debug_mask, atGen); + // const auto& atOutput = result[0]; + const auto& atQPaded = result[1]; + const auto& atKPaded = result[2]; + const auto& atVPaded = result[3]; + const auto& atOutPaded = result[4]; + const auto& atLogSumexp = result[5]; + const auto& atDebugAttnMask = result[6]; + // const auto& atRngState = result[7]; + + // TODO(lljbash): support auto padding + auto headSize = atQ.sizes()[3]; + TORCH_CHECK(headSize % 8 == 0, "DIOPI now only support head sizes which are multiple of 8"); + + // TODO(gqw): check why this is needed + impl::aten::updateATen2Tensor(ctx, atQPaded, q); + impl::aten::updateATen2Tensor(ctx, atKPaded, k); + impl::aten::updateATen2Tensor(ctx, atVPaded, v); + impl::aten::updateATen2Tensor(ctx, atOutPaded, out); + impl::aten::updateATen2Tensor(ctx, atLogSumexp, softmax_lse); + if (return_debug_mask) { + impl::aten::updateATen2Tensor(ctx, atDebugAttnMask, debug_attn_mask); + } + + impl::aten::unsetCurCtx(); + return diopiSuccess; +} + +diopiError_t diopiMultiHeadAttentionBackward(diopiContextHandle_t ctx, diopiConstTensorHandle_t grad_out, diopiConstTensorHandle_t q, + diopiConstTensorHandle_t k, diopiConstTensorHandle_t v, diopiConstTensorHandle_t out, + diopiConstTensorHandle_t softmax_lse, double dropout_p, bool is_causal, diopiGeneratorHandle_t gen, double scale, + diopiTensorHandle_t grad_q, diopiTensorHandle_t grad_k, diopiTensorHandle_t grad_v) { + impl::aten::setCurCtx(ctx); + + auto atQ = impl::aten::buildATen(q).contiguous(); + auto atK = impl::aten::buildATen(k).contiguous(); + auto atV = impl::aten::buildATen(v).contiguous(); + auto atGen = buildGeneratorForMha(ctx, gen, dropout_p); + auto atGradOut = impl::aten::buildATen(grad_out).contiguous(); + auto atOut = impl::aten::buildATen(out).contiguous(); + auto atLogsumexp = impl::aten::buildATen(softmax_lse); + + c10::optional nullOpt; // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea) + std::vector result = DIOPI_EXT_CALL_FLASH( + mha_bwd, atGradOut, atQ, atK, atV, atOut, atLogsumexp, nullOpt, nullOpt, nullOpt, dropout_p, scale, is_causal, -1, -1, atGen, nullOpt); + const auto& atGradQ = result[0]; + const auto& atGradK = result[1]; + const auto& atGradV = result[2]; + + impl::aten::updateATen2Tensor(ctx, atGradQ, grad_q); + impl::aten::updateATen2Tensor(ctx, atGradK, grad_k); + impl::aten::updateATen2Tensor(ctx, atGradV, grad_v); + + impl::aten::unsetCurCtx(); + return diopiSuccess; +} + +diopiError_t diopiMultiHeadAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHandle_t q, diopiTensorHandle_t k, diopiTensorHandle_t v, + diopiConstTensorHandle_t cum_seq_q, diopiConstTensorHandle_t cum_seq_k, int64_t max_q, int64_t max_k, + double dropout_p, bool is_causal, bool return_debug_mask, double scale, diopiTensorHandle_t out, + diopiTensorHandle_t softmax_lse, diopiGeneratorHandle_t gen, diopiTensorHandle_t debug_attn_mask) { + impl::aten::setCurCtx(ctx); + + auto atQ = impl::aten::buildATen(q).clone(); + auto atK = 
impl::aten::buildATen(k).clone(); + auto atV = impl::aten::buildATen(v).clone(); + auto atCumSeqQ = impl::aten::buildATen(cum_seq_q); + auto atCumSeqK = impl::aten::buildATen(cum_seq_k); + auto atGen = buildGeneratorForMha(ctx, gen, dropout_p); + + c10::optional outputNull; + std::vector result = DIOPI_EXT_CALL_FLASH( + mha_varlen_fwd, atQ, atK, atV, outputNull, atCumSeqQ, atCumSeqK, max_q, max_k, dropout_p, scale, false, is_causal, -1, -1, return_debug_mask, atGen); + // auto atOutput = result[0]; + auto atQPadded = result[1]; + auto atKPadded = result[2]; + auto atVPadded = result[3]; + auto atOutPadded = result[4]; + auto atLogSumexp = result[5]; + auto atDebugAttnMask = result[6]; + // auto atRngState = result[7]; + + // TODO(lljbash): support auto padding + auto headSize = atQ.sizes()[3]; + TORCH_CHECK(headSize % 8 == 0, "DIOPI now only support head sizes which are multiple of 8"); + + // TODO(gqw): check why this is needed + impl::aten::updateATen2Tensor(ctx, atQPadded, q); + impl::aten::updateATen2Tensor(ctx, atKPadded, k); + impl::aten::updateATen2Tensor(ctx, atVPadded, v); + impl::aten::updateATen2Tensor(ctx, atOutPadded, out); + impl::aten::updateATen2Tensor(ctx, atLogSumexp, softmax_lse); + if (return_debug_mask) { + impl::aten::updateATen2Tensor(ctx, atDebugAttnMask, debug_attn_mask); + } + + impl::aten::unsetCurCtx(); + return diopiSuccess; +} + +diopiError_t diopiMultiHeadAttentionVarLenBackward(diopiContextHandle_t ctx, diopiConstTensorHandle_t grad_out, diopiConstTensorHandle_t q, + diopiConstTensorHandle_t k, diopiConstTensorHandle_t v, diopiConstTensorHandle_t out, + diopiConstTensorHandle_t softmax_lse, diopiConstTensorHandle_t cum_seq_q, diopiConstTensorHandle_t cum_seq_k, + int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, diopiGeneratorHandle_t gen, double scale, + diopiTensorHandle_t grad_q, diopiTensorHandle_t grad_k, diopiTensorHandle_t grad_v) { + impl::aten::setCurCtx(ctx); + + auto atQ = impl::aten::buildATen(q).contiguous(); + auto atK = impl::aten::buildATen(k).contiguous(); + auto atV = impl::aten::buildATen(v).contiguous(); + auto atGen = buildGeneratorForMha(ctx, gen, dropout_p); + auto atGradOut = impl::aten::buildATen(grad_out).contiguous(); + auto atOut = impl::aten::buildATen(out).contiguous(); + auto atLogsumexp = impl::aten::buildATen(softmax_lse); + auto atCumSeqQ = impl::aten::buildATen(cum_seq_q); + auto atCumSeqK = impl::aten::buildATen(cum_seq_k); + + auto nullOpt = c10::optional(); // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea) + std::vector result = DIOPI_EXT_CALL_FLASH(mha_varlen_bwd, + atGradOut, + atQ, + atK, + atV, + atOut, + atLogsumexp, + nullOpt, + nullOpt, + nullOpt, + atCumSeqQ, + atCumSeqK, + max_q, + max_k, + dropout_p, + scale, + false, + is_causal, + -1, + -1, + atGen, + nullOpt); + const auto& atGradQ = result[0]; + const auto& atGradK = result[1]; + const auto& atGradV = result[2]; + + impl::aten::updateATen2Tensor(ctx, atGradQ, grad_q); + impl::aten::updateATen2Tensor(ctx, atGradK, grad_k); + impl::aten::updateATen2Tensor(ctx, atGradV, grad_v); + + impl::aten::unsetCurCtx(); + return diopiSuccess; +} + +} // extern "C" diff --git a/impl/torch/functions_ext/flash-attention/CMakeLists.txt b/impl/torch/functions_ext/flash-attention/CMakeLists.txt new file mode 100644 index 000000000..d86fe77a0 --- /dev/null +++ b/impl/torch/functions_ext/flash-attention/CMakeLists.txt @@ -0,0 +1,26 @@ +find_library( + DIOPI_TORCH_EXT_FLASH_ATTN_LIB + NAMES # Set env FLASH_ATTN_LIB_NAME if your 
lib has a different name + ENV FLASH_ATTN_LIB_NAME + # this is the default name of the flash-attention library + flash_attn_2_cuda.cpython-38-x86_64-linux-gnu.so + HINTS + # Set environment variable FLASH_ATTN_LIB_DIR to the path of the library + ENV FLASH_ATTN_LIB_DIR + # This is the path on cluster 1424 A100 + /mnt/cache/shenliancheng/workspace/flash-attention/build/lib.linux-x86_64-cpython-38 +) + +if(NOT DIOPI_TORCH_EXT_FLASH_ATTN_LIB) + message(WARNING "flash-attention NOT FOUND, will build without mha support") + add_library(diopi_torch_ext_flash_attn INTERFACE) +else() + message(STATUS "FOUND flash-attention: ${DIOPI_TORCH_EXT_FLASH_ATTN_LIB}") + add_library(diopi_torch_ext_flash_attn SHARED IMPORTED GLOBAL) + set_target_properties( + diopi_torch_ext_flash_attn PROPERTIES IMPORTED_LOCATION + ${DIOPI_TORCH_EXT_FLASH_ATTN_LIB}) + target_link_options(diopi_torch_ext_flash_attn INTERFACE "LINKER:-no-as-needed") +endif() + +target_include_directories(diopi_torch_ext_flash_attn INTERFACE include) diff --git a/impl/torch/functions_ext/flash-attention/include/flash_attn/flash_api.h b/impl/torch/functions_ext/flash-attention/include/flash_attn/flash_api.h new file mode 100644 index 000000000..72f2e3f37 --- /dev/null +++ b/impl/torch/functions_ext/flash-attention/include/flash_attn/flash_api.h @@ -0,0 +1,72 @@ +#ifndef IMPL_TORCH_FUNCTIONS_EXT_FLASH_ATTENTION_INCLUDE_FLASH_ATTN_FLASH_API_H_ +#define IMPL_TORCH_FUNCTIONS_EXT_FLASH_ATTENTION_INCLUDE_FLASH_ATTN_FLASH_API_H_ + +#include +#include +#include + +#include // IWYU pragma: keep +#include + +#define DIOPI_EXT_CALL_FLASH(func, ...) \ + [&] { \ + if (func == nullptr) { \ + throw std::runtime_error("unable to call flash " #func ": DIOPI is built without flash-attention"); \ + } \ + return func(__VA_ARGS__); \ + }() + +std::vector __attribute__((weak)) mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size + const float p_dropout, const float softmax_scale, bool is_causal, const int window_size_left, + int window_size_right, const bool return_softmax, c10::optional gen_); + +std::vector __attribute__((weak)) +mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + const int max_seqlen_q, const int max_seqlen_k, const float p_dropout, const float softmax_scale, const bool zero_tensors, const bool is_causal, + const int window_size_left, int window_size_right, const bool return_softmax, c10::optional gen_); + +std::vector __attribute__((weak)) mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + c10::optional &dq_, // batch_size x seqlen_q x 
num_heads x head_size + c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + const float p_dropout, // probability to drop + const float softmax_scale, const bool is_causal, const int window_size_left, int window_size_right, + c10::optional gen_, c10::optional &rng_state); + +std::vector __attribute__((weak)) +mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + c10::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + c10::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, const bool zero_tensors, const bool is_causal, const int window_size_left, int window_size_right, + c10::optional gen_, c10::optional &rng_state); + +std::vector __attribute__((weak)) +mha_fwd_kvcache(at::Tensor &q, const at::Tensor &kcache, const at::Tensor &vcache, const at::Tensor &k, const at::Tensor &v, const at::Tensor &seqlens_k, + const at::Tensor &rotary_cos, const at::Tensor &rotary_sin, const at::Tensor &cache_batch_idx, float softmax_scale, bool is_causal, + int window_size_left, int window_size_right, bool is_rotary_interleaved, int num_splits); + +#endif // IMPL_TORCH_FUNCTIONS_EXT_FLASH_ATTENTION_INCLUDE_FLASH_ATTN_FLASH_API_H_ diff --git a/impl/torch/functions_ext/layer_norm_cuda_kernel.cu b/impl/torch/functions_ext/layer_norm_cuda_kernel.cu new file mode 100644 index 000000000..2badb1159 --- /dev/null +++ b/impl/torch/functions_ext/layer_norm_cuda_kernel.cu @@ -0,0 +1,1028 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace { + +template +__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) { + count = count + U(1); + U delta = curr - mu; + U lmean = mu + delta / count; + mu = lmean; + U delta2 = curr - lmean; + sigma2 = sigma2 + delta * delta2; +} + +template +__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB, U& mu, U& sigma2, U& count) { + U delta = muB - mu; + U nA = count; + U nB = countB; + count = count + countB; + U nX = count; + if (nX > U(0)) { + nA = nA / nX; + nB = nB / nX; + mu = nA * mu + nB * muB; + sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; + } else { + mu = U(0); + sigma2 = U(0); + } +} + +template +__device__ void cuRMSOnlineSum(const U curr, U& sigma2) { + sigma2 = sigma2 + curr * curr; +} + +template +__device__ void cuChanRMSOnlineSum(const U sigma2B, U& sigma2) { + sigma2 = sigma2 + sigma2B; +} + +template +__device__ void cuWelfordMuSigma2(const T* __restrict__ vals, const int n1, const int n2, const int i1, U& mu, U& sigma2, U* buf, bool rms_only) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 
2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + U count = U(0); + mu = U(0); + sigma2 = U(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const T* lvals = vals + i1 * n2; + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + U curr = static_cast(lvals[l + k]); + if (!rms_only) { + cuWelfordOnlineSum(curr, mu, sigma2, count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + } + for (; l < n2; ++l) { + U curr = static_cast(lvals[l]); + if (!rms_only) { + cuWelfordOnlineSum(curr, mu, sigma2, count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + U sigma2B = WARP_SHFL(sigma2, srcLaneB); + if (!rms_only) { + U muB = WARP_SHFL(mu, srcLaneB); + U countB = WARP_SHFL(count, srcLaneB); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + U* ubuf = reinterpret_cast(buf); + U* ibuf = reinterpret_cast(ubuf + blockDim.y); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + if (!rms_only) { + ubuf[2 * wrt_y] = mu; + ibuf[wrt_y] = count; + } + ubuf[2 * wrt_y + 1] = sigma2; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + U sigma2B = ubuf[2 * threadIdx.y + 1]; + if (!rms_only) { + U muB = ubuf[2 * threadIdx.y]; + U countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + if (!rms_only) { + ubuf[0] = mu; + } + ubuf[1] = sigma2; + } + __syncthreads(); + if (!rms_only) { + mu = ubuf[0]; + } + sigma2 = ubuf[1] / U(n2); + // don't care about final value of count, we know count == n2 + } else { + if (!rms_only) { + mu = WARP_SHFL(mu, 0); + } + sigma2 = WARP_SHFL(sigma2 / U(n2), 0); + } + } +} + +template <> +__device__ void cuWelfordMuSigma2(const at::Half* __restrict__ vals, const int n1, const int n2, const int i1, float& mu, float& sigma2, float* buf, + bool rms_only) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. 
+ // + // compute variance and mean over n2 + float count = 0.0f; + mu = static_cast(0); + sigma2 = static_cast(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + at::Half* lvals = const_cast(vals) + i1 * n2; + int l = 8 * thrx; + if ((((size_t)lvals) & 3) != 0) { + // 16 bit alignment + // first thread consumes first point + if (thrx == 0) { + float curr = static_cast(lvals[0]); + if (!rms_only) { + cuWelfordOnlineSum(curr, mu, sigma2, count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + ++l; + } + // at this point, lvals[l] are 32 bit aligned for all threads. + for (; l + 7 < n2; l += 8 * numx) { + for (int k = 0; k < 8; k += 2) { + at::Half* offset = lvals + l + k; + float2 curr = __half22float2(*reinterpret_cast<__half2*>(offset)); + if (!rms_only) { + cuWelfordOnlineSum(curr.x, mu, sigma2, count); + cuWelfordOnlineSum(curr.y, mu, sigma2, count); + } else { + cuRMSOnlineSum(curr.x, sigma2); + cuRMSOnlineSum(curr.y, sigma2); + } + } + } + for (; l < n2; ++l) { + float curr = static_cast(lvals[l]); + if (!rms_only) { + cuWelfordOnlineSum(curr, mu, sigma2, count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + float sigma2B = WARP_SHFL(sigma2, srcLaneB); + if (!rms_only) { + float muB = WARP_SHFL(mu, srcLaneB); + float countB = WARP_SHFL(count, srcLaneB); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + float* ubuf = reinterpret_cast(buf); + float* ibuf = reinterpret_cast(ubuf + blockDim.y); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2 * wrt_y + 1] = sigma2; + if (!rms_only) { + ubuf[2 * wrt_y] = mu; + ibuf[wrt_y] = count; + } + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + float sigma2B = ubuf[2 * threadIdx.y + 1]; + if (!rms_only) { + float muB = ubuf[2 * threadIdx.y]; + float countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + if (!rms_only) { + ubuf[0] = mu; + } + ubuf[1] = sigma2; + } + __syncthreads(); + if (!rms_only) { + mu = ubuf[0]; + } + sigma2 = ubuf[1] / static_cast(n2); + // don't care about final value of count, we know count == n2 + } else { + if (!rms_only) { + mu = WARP_SHFL(mu, 0); + } + sigma2 = WARP_SHFL(sigma2 / static_cast(n2), 0); + } + } +} + +} // namespace + +template +static U rsqrt(U v) { + return U(1) / sqrt(v); +} +template <> +float rsqrt(float v) { + return rsqrtf(v); +} +template <> +double rsqrt(double v) { + return rsqrt(v); +} + +namespace { +// This is the un-specialized struct. Note that we prevent instantiation of this +// struct by putting an undefined symbol in the function body so it won't compile. 
+// template +// struct SharedMemory +// { +// // Ensure that we won't compile any un-specialized types +// __device__ T *getPointer() +// { +// extern __device__ void error(void); +// error(); +// return NULL; +// } +// }; +// https://github.com/NVIDIA/apex/issues/246 +template +struct SharedMemory; + +template <> +struct SharedMemory { + __device__ float* getPointer() { + extern __shared__ float s_float[]; + return s_float; + } +}; + +template <> +struct SharedMemory { + __device__ double* getPointer() { + extern __shared__ double s_double[]; + return s_double; + } +}; +} // namespace + +namespace { + +template +__device__ void cuApplyLayerNorm_(V* __restrict__ output_vals, U* __restrict__ mean, U* __restrict__ invvar, const T* __restrict__ vals, const int n1, + const int n2, const U epsilon, const V* __restrict__ gamma, const V* __restrict__ beta, bool rms_only) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensors are contiguous + // + for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + SharedMemory shared; + U* buf = shared.getPointer(); + U mu, sigma2; + cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf, rms_only); + + const T* lvals = vals + i1 * n2; + V* ovals = output_vals + i1 * n2; + U c_invvar = rsqrt(sigma2 + epsilon); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL && (beta != NULL || rms_only)) { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + if (!rms_only) { + ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; + } else { + ovals[i] = gamma[i] * static_cast(c_invvar * curr); + } + } + } else { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + if (!rms_only) { + ovals[i] = static_cast(c_invvar * (curr - mu)); + } else { + ovals[i] = static_cast(c_invvar * curr); + } + } + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + if (!rms_only) { + mean[i1] = mu; + } + invvar[i1] = c_invvar; + } + __syncthreads(); + } +} + +template +__global__ void cuApplyLayerNorm(V* __restrict__ output_vals, U* __restrict__ mean, U* __restrict__ invvar, const T* __restrict__ vals, const int n1, + const int n2, const U epsilon, const V* __restrict__ gamma, const V* __restrict__ beta) { + cuApplyLayerNorm_(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, false); +} + +template +__global__ void cuApplyRMSNorm(V* __restrict__ output_vals, U* __restrict__ invvar, const T* __restrict__ vals, const int n1, const int n2, const U epsilon, + const V* __restrict__ gamma) { + cuApplyLayerNorm_(output_vals, NULL, invvar, vals, n1, n2, epsilon, gamma, NULL, true); +} + +template +__device__ void cuLoadWriteStridedInputs(const int i1_block, const int thr_load_row_off, const int thr_load_col_off, const int i2_off, const int row_stride, + U* warp_buf1, U* warp_buf2, const T* input, const V* dout, const int i1_end, const int n2, const U* __restrict__ mean, + const U* __restrict__ invvar, bool rms_only) { + int i1 = i1_block + thr_load_row_off; + if (i1 < i1_end) { + U curr_mean; + if (!rms_only) { + curr_mean = mean[i1]; + } + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1 * n2 + i2; + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + if (!rms_only) { + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = curr_dout * (curr_input - 
curr_mean) * curr_invvar; + } else { + warp_buf2[write_idx] = curr_dout * (curr_input)*curr_invvar; + } + } else { + if (!rms_only) { + warp_buf1[write_idx] = U(0); + } + warp_buf2[write_idx] = U(0); + } + } + } else { + for (int k = 0; k < blockDim.y; ++k) { + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (!rms_only) { + warp_buf1[write_idx] = U(0); + } + warp_buf2[write_idx] = U(0); + } + } +} + +template +__device__ void cuLoadAddStridedInputs(const int i1_block, const int thr_load_row_off, const int thr_load_col_off, const int i2_off, const int row_stride, + U* warp_buf1, U* warp_buf2, const T* input, const V* dout, const int i1_end, const int n2, const U* __restrict__ mean, + const U* __restrict__ invvar, bool rms_only) { + int i1 = i1_block + thr_load_row_off; + if (i1 < i1_end) { + U curr_mean; + if (!rms_only) { + curr_mean = mean[i1]; + } + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1 * n2 + i2; + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + if (!rms_only) { + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + warp_buf2[write_idx] += curr_dout * (curr_input)*curr_invvar; + } + } + } + } +} + +template +__global__ void cuComputePartGradGammaBeta(const V* __restrict__ dout, const T* __restrict__ input, const int n1, const int n2, const U* __restrict__ mean, + const U* __restrict__ invvar, U epsilon, U* part_grad_gamma, U* part_grad_beta, bool rms_only) { + const int numsegs_n1 = (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); + const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; + const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y; + const int i1_beg_plus_one = (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y; + const int i1_end = i1_beg_plus_one < n1 ? 
i1_beg_plus_one : n1; + const int row_stride = blockDim.x + 1; + const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1); + const int thr_load_row_off = (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y; + const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; + SharedMemory shared; + U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements + U* warp_buf1 = reinterpret_cast(buf); + U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; + // compute partial sums from strided inputs + // do this to increase number of loads in flight + cuLoadWriteStridedInputs( + i1_beg, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar, rms_only); + for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; i1_block += blockDim.y * blockDim.y) { + cuLoadAddStridedInputs( + i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar, rms_only); + } + __syncthreads(); + // inter-warp reductions + // sum within each warp + U acc1 = U(0); + U acc2 = U(0); + for (int k = 0; k < blockDim.y; ++k) { + int row1 = threadIdx.y + k * blockDim.y; + int idx1 = row1 * row_stride + threadIdx.x; + if (!rms_only) { + acc1 += warp_buf1[idx1]; + } + acc2 += warp_buf2[idx1]; + } + if (!rms_only) { + warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1; + } + warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = blockDim.y / 2; offset > 1; offset /= 2) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + if (!rms_only) { + warp_buf1[idx1] += warp_buf1[idx2]; + } + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < n2) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + if (!rms_only) { + part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2]; + } + part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template +__global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, const U* part_grad_beta, const int part_size, const int n1, const int n2, V* grad_gamma, + V* grad_beta, bool rms_only) { + // sum partial gradients for gamma and beta + SharedMemory shared; + U* buf = shared.getPointer(); + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (i2 < n2) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / blockDim.y; + U sum_gamma = U(0); + U sum_beta = U(0); + const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset * n2]; + if (!rms_only) { + sum_beta += part_grad_beta_ptr[warp_offset * n2]; + } + } + // inter-warp reductions + const int nbsize3 = blockDim.x * blockDim.y / 2; + for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { + // top half write to shared memory + if (threadIdx.y 
>= offset && threadIdx.y < 2 * offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + if (!rms_only) { + buf[write_idx + nbsize3] = sum_beta; + } + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_gamma += buf[read_idx]; + if (!rms_only) { + sum_beta += buf[read_idx + nbsize3]; + } + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + if (!rms_only) { + grad_beta[i2] = sum_beta; + } + } + } +} + +template +__global__ void cuComputeGradInput(const V* __restrict__ dout, const T* __restrict__ input, const int n1, const int n2, const U* __restrict__ mean, + const U* __restrict__ invvar, U epsilon, const V* gamma, T* grad_input, bool rms_only) { + for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + U sum_loss1 = U(0); + U sum_loss2 = U(0); + U c_mean; + if (!rms_only) { + c_mean = mean[i1]; + } + const U c_invvar = invvar[i1]; + const T* k_input = input + i1 * n2; + const V* k_dout = dout + i1 * n2; + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL) { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + if (!rms_only) { + sum_loss1 += c_loss * gamma[l + k]; + sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * gamma[l + k] * (c_h)*c_invvar; + } + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + if (!rms_only) { + sum_loss1 += c_loss * gamma[l]; + sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * gamma[l] * (c_h)*c_invvar; + } + } + } else { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + if (!rms_only) { + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * (c_h)*c_invvar; + } + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + if (!rms_only) { + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * (c_h)*c_invvar; + } + } + } + // intra-warp reductions + for (int mask = blockDim.x / 2; mask > 0; mask /= 2) { + if (!rms_only) { + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + } + sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); + } + // inter-warp reductions + if (blockDim.y > 1) { + SharedMemory shared; + U* buf = shared.getPointer(); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + if (!rms_only) { + buf[2 * wrt_i] = sum_loss1; + } + buf[2 * wrt_i + 1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + if (!rms_only) { + sum_loss1 += buf[2 * read_i]; + } + sum_loss2 += buf[2 * read_i + 1]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + if (!rms_only) { + buf[2 * threadIdx.x] = sum_loss1; + } + buf[2 * threadIdx.x + 1] = 
sum_loss2; + } + __syncthreads(); + if (threadIdx.y != 0) { + if (!rms_only) { + sum_loss1 = buf[2 * threadIdx.x]; + } + sum_loss2 = buf[2 * threadIdx.x + 1]; + } + } + // all threads now have the two sums over l + U fH = (U)n2; + U term1 = (U(1) / fH) * c_invvar; + T* k_grad_input = grad_input + i1 * n2; + if (gamma != NULL) { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss * gamma[l]; + if (!rms_only) { + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } else { + f_grad_input -= (c_h)*c_invvar * sum_loss2; + } + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss; + if (!rms_only) { + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } else { + f_grad_input -= (c_h)*c_invvar * sum_loss2; + } + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } + // prevent race where buf is written again before reads are done + __syncthreads(); + } +} + +} // namespace + +namespace ext { +namespace ops { +namespace { + +template +void HostApplyLayerNorm(V* output, U* mean, U* invvar, const T* input, int n1, int n2, double epsilon, const V* gamma, const V* beta) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const dim3 threads(32, 4, 1); + const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + int nshared = threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; + cuApplyLayerNorm<<>>(output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta); +} + +template +void HostApplyRMSNorm(V* output, U* invvar, const T* input, int n1, int n2, double epsilon, const V* gamma) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const dim3 threads(32, 4, 1); + const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + int nshared = threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; + cuApplyRMSNorm<<>>(output, invvar, input, n1, n2, U(epsilon), gamma); +} + +template +void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, at::Tensor* input, int n1, int n2, const V* gamma, const V* beta, double epsilon, + T* grad_input, V* grad_gamma, V* grad_beta) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + if (gamma != NULL && beta != NULL) { + // compute grad_gamma(j) and grad_beta(j) + const int part_size = 16; + const dim3 threads2(32, 4, 1); + const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1); + const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(U); + const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; + // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that + // the `cuda_layer_norm_gradient` doesn't support double. + const auto part_grad_dtype = + (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ? 
at::ScalarType::Float : input->scalar_type(); + at::Tensor part_grad_gamma = at::empty({part_size, n2}, input->options().dtype(part_grad_dtype)); + at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); + cuComputePartGradGammaBeta<<>>( + dout, input->data_ptr(), n1, n2, mean, invvar, U(epsilon), part_grad_gamma.data_ptr(), part_grad_beta.data_ptr(), false); + + const dim3 threads3(32, 8, 1); + const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); + const int nshared3 = threads3.x * threads3.y * sizeof(U); + cuComputeGradGammaBeta<<>>( + part_grad_gamma.data_ptr(), part_grad_beta.data_ptr(), part_size, n1, n2, grad_gamma, grad_beta, false); + } + + // compute grad_input + const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); + const dim3 threads1(32, 4, 1); + int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0; + cuComputeGradInput<<>>(dout, input->data_ptr(), n1, n2, mean, invvar, U(epsilon), gamma, grad_input, false); +} + +template +void HostRMSNormGradient(const V* dout, const U* invvar, at::Tensor* input, int n1, int n2, const V* gamma, double epsilon, T* grad_input, V* grad_gamma) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + if (gamma != NULL) { + const int part_size = 16; + const dim3 threads2(32, 4, 1); + const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1); + const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(U); + const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; + // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that + // the `cuda_layer_norm_gradient` doesn't support double. + const auto part_grad_dtype = + (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input->scalar_type(); + at::Tensor part_grad_gamma = at::empty({part_size, n2}, input->options().dtype(part_grad_dtype)); + cuComputePartGradGammaBeta<<>>(dout, + input->data_ptr(), + n1, + n2, + invvar, // unused + invvar, + U(epsilon), + part_grad_gamma.data_ptr(), + part_grad_gamma.data_ptr(), /* unused */ + true); + + const dim3 threads3(32, 8, 1); + const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); + const int nshared3 = threads3.x * threads3.y * sizeof(U); + cuComputeGradGammaBeta<<>>(part_grad_gamma.data_ptr(), + part_grad_gamma.data_ptr(), /* unused */ + part_size, + n1, + n2, + grad_gamma, + grad_gamma, /* unused */ + true); + } + + // compute grad_input + const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); + const dim3 threads1(32, 4, 1); + int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0; + cuComputeGradInput<<>>(dout, + input->data_ptr(), + n1, + n2, + invvar, /* unused */ + invvar, + U(epsilon), + gamma, + grad_input, + true); +} + +} // namespace + +#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) 
\ + switch (TYPEIN) { \ + case at::ScalarType::Double: { \ + using scalar_t_in = double; \ + switch (TYPEOUT) { \ + case at::ScalarType::Double: { \ + using scalar_t_out = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Float: { \ + using scalar_t_in = float; \ + switch (TYPEOUT) { \ + case at::ScalarType::Float: { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ + switch (TYPEIN) { \ + case at::ScalarType::Float: { \ + using scalar_t_in = float; \ + switch (TYPEOUT) { \ + case at::ScalarType::Float: { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } + +void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, at::Tensor* input, int n1, int n2, at::IntArrayRef normalized_shape, + at::Tensor* gamma, at::Tensor* beta, double epsilon) { + using namespace at; + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), output->scalar_type(), "layer_norm_cuda_kernel", using accscalar_t = at::acc_type; + HostApplyLayerNorm(output->data_ptr(), + mean->data_ptr(), + invvar->data_ptr(), + input->data_ptr(), + n1, + n2, + epsilon, + gamma != NULL ? gamma->data_ptr() : NULL, + beta != NULL ? 
beta->data_ptr() : NULL);) +} + +void cuda_rms_norm(at::Tensor* output, at::Tensor* invvar, at::Tensor* input, int n1, int n2, at::IntArrayRef normalized_shape, at::Tensor* gamma, + double epsilon) { + using namespace at; + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), output->scalar_type(), "rms_norm_cuda_kernel", using accscalar_t = at::acc_type; + HostApplyRMSNorm(output->data_ptr(), + invvar->data_ptr(), + input->data_ptr(), + n1, + n2, + epsilon, + gamma != NULL ? gamma->data_ptr() : NULL);) +} + +void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, at::Tensor* invvar, at::Tensor* input, int n1, int n2, at::IntArrayRef normalized_shape, + at::Tensor* gamma, at::Tensor* beta, double epsilon, at::Tensor* grad_input, at::Tensor* grad_gamma, at::Tensor* grad_beta) { + using namespace at; + // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16 + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(input->scalar_type(), + gamma == NULL ? input->scalar_type() : gamma->scalar_type(), + "cuComputeGradInput", + using accscalar_t = at::acc_type; + HostLayerNormGradient(dout->data_ptr(), + mean->data_ptr(), + invvar->data_ptr(), + input, + n1, + n2, + // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta + // if gamma Tensor is NULL on input. + gamma != NULL ? gamma->data_ptr() : NULL, + gamma != NULL ? beta->data_ptr() : NULL, + epsilon, + grad_input->data_ptr(), + gamma != NULL ? grad_gamma->data_ptr() : NULL, + gamma != NULL ? grad_beta->data_ptr() : NULL);) +} + +void cuda_rms_norm_gradient(at::Tensor* dout, at::Tensor* invvar, at::Tensor* input, int n1, int n2, at::IntArrayRef normalized_shape, at::Tensor* gamma, + double epsilon, at::Tensor* grad_input, at::Tensor* grad_gamma) { + using namespace at; + // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16 + // DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(input->scalar_type(), + gamma == NULL ? input->scalar_type() : gamma->scalar_type(), + "cuComputeGradInputRMS", + using accscalar_t = at::acc_type; + HostRMSNormGradient(dout->data_ptr(), + invvar->data_ptr(), + input, + n1, + n2, + // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta + // if gamma Tensor is NULL on input. + gamma != NULL ? gamma->data_ptr() : NULL, + epsilon, + grad_input->data_ptr(), + gamma != NULL ? 
grad_gamma->data_ptr() : NULL);) +} + +} // namespace ops +} // namespace ext diff --git a/impl/torch/functions_ext/rmsnorm.cu b/impl/torch/functions_ext/rmsnorm.cu new file mode 100644 index 000000000..9983f6e81 --- /dev/null +++ b/impl/torch/functions_ext/rmsnorm.cu @@ -0,0 +1,111 @@ +#include +#include +#include + +namespace ext { +namespace ops { +namespace { + +#define DIOPI_TORCH_EXT_CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") + +#define DIOPI_TORCH_EXT_CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +#define DIOPI_TORCH_EXT_CHECK_INPUT(x) \ + DIOPI_TORCH_EXT_CHECK_CUDA(x); \ + DIOPI_TORCH_EXT_CHECK_CONTIGUOUS(x) + +void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int& n1, int& n2) { + int idiff = input.ndimension() - normalized_shape.size(); + n2 = 1; + for (int i = 0; i < static_cast<int>(normalized_shape.size()); ++i) { + assert(input.sizes()[i + idiff] == normalized_shape[i]); + n2 *= normalized_shape[i]; + } + n1 = 1; + for (int i = 0; i < idiff; ++i) { + n1 *= input.sizes()[i]; + } +} + +void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta) { + TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); + TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape)); +} + +void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma) { TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); } + +void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int& n1, int& n2) { + int64_t normalized_ndim = normalized_shape.size(); + + if (normalized_ndim < 1) { + std::stringstream ss; + ss << "Expected normalized_shape to be at least 1-dimensional, i.e., " + << "containing at least one element, but got normalized_shape=" << normalized_shape; + throw std::runtime_error(ss.str()); + } + + auto input_shape = input.sizes(); + auto input_ndim = input.dim(); + + if (input_ndim < normalized_ndim || !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << normalized_shape << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + throw std::runtime_error(ss.str()); + } + + compute_n1_n2(input, normalized_shape, n1, n2); +} + +void check_args(at::Tensor input, at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta, int& n1, int& n2) { + check_args(input, normalized_shape, n1, n2); + check_args(normalized_shape, gamma, beta); +} + +void check_args(at::Tensor input, at::IntArrayRef normalized_shape, at::Tensor gamma, int& n1, int& n2) { + check_args(input, normalized_shape, n1, n2); + check_args(normalized_shape, gamma); +} + +} // namespace + +// implemented in layer_norm_cuda_kernel.cu +void cuda_rms_norm(at::Tensor* output, at::Tensor* invvar, at::Tensor* input, int n1, int n2, at::IntArrayRef normalized_shape, at::Tensor* gamma, + double epsilon); + +// implemented in layer_norm_cuda_kernel.cu +void cuda_rms_norm_gradient(at::Tensor* dout, at::Tensor* invvar, at::Tensor* input, int n1, int n2, at::IntArrayRef normalized_shape, at::Tensor* gamma, + double epsilon, at::Tensor* grad_input, at::Tensor* grad_gamma); + +// Entry points exposed to the C++ layer. + +void rms_norm_forward(at::Tensor input, at::IntArrayRef normalized_shape, at::Tensor gamma, double epsilon, at::Tensor output, at::Tensor invvar) { + DIOPI_TORCH_EXT_CHECK_INPUT(input); +
DIOPI_TORCH_EXT_CHECK_INPUT(gamma); + int n1; + int n2; + check_args(input, normalized_shape, gamma, n1, n2); + const auto stats_dtype = + (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type(); + cuda_rms_norm(&output, &invvar, &input, n1, n2, normalized_shape, &gamma, epsilon); + return; +} + +void rms_norm_backward(at::Tensor dout, at::Tensor invvar, at::Tensor input, at::IntArrayRef normalized_shape, at::Tensor gamma, double epsilon, + at::Tensor grad_input, at::Tensor grad_gamma) { + DIOPI_TORCH_EXT_CHECK_INPUT(dout); + DIOPI_TORCH_EXT_CHECK_INPUT(invvar); + DIOPI_TORCH_EXT_CHECK_INPUT(input); + DIOPI_TORCH_EXT_CHECK_INPUT(gamma); + int n1; + int n2; + check_args(input, normalized_shape, gamma, n1, n2); + cuda_rms_norm_gradient(&dout, &invvar, &input, n1, n2, normalized_shape, &gamma, epsilon, &grad_input, &grad_gamma); + return; +} + +} // namespace ops +} // namespace ext diff --git a/impl/torch/functions_ext/rotary.cu b/impl/torch/functions_ext/rotary.cu new file mode 100644 index 000000000..80f737ce2 --- /dev/null +++ b/impl/torch/functions_ext/rotary.cu @@ -0,0 +1,48 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#include +#include +#include + +#include + +namespace ext { +namespace ops { + +void apply_rotary_cuda(const at::Tensor& x1, const at::Tensor& x2, const at::Tensor& cos, const at::Tensor& sin, at::Tensor out1, at::Tensor out2, + const bool conj) { + auto iter = at::TensorIteratorConfig() + .add_output(out1) + .add_output(out2) + .add_input(x1) + .add_input(x2) + .add_input(cos) + .add_input(sin) + .check_all_same_dtype(false) + .promote_inputs_to_common_dtype(false) + .build(); + + if (!conj) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] { + at::native::gpu_kernel_multiple_outputs( + iter, [] GPU_LAMBDA(scalar_t x1, scalar_t x2, scalar_t cos, scalar_t sin) -> thrust::tuple { + scalar_t out1 = static_cast(x1) * static_cast(cos) - static_cast(x2) * static_cast(sin); + scalar_t out2 = static_cast(x1) * static_cast(sin) + static_cast(x2) * static_cast(cos); + return {out1, out2}; + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] { + at::native::gpu_kernel_multiple_outputs( + iter, [] GPU_LAMBDA(scalar_t x1, scalar_t x2, scalar_t cos, scalar_t sin) -> thrust::tuple { + scalar_t out1 = static_cast(x1) * static_cast(cos) + static_cast(x2) * static_cast(sin); + scalar_t out2 = -static_cast(x1) * static_cast(sin) + static_cast(x2) * static_cast(cos); + return {out1, out2}; + }); + }); + } +} +} // namespace ops +} // namespace ext diff --git a/proto/include/diopi/functions_ext.h b/proto/include/diopi/functions_ext.h index 677797196..20971fb9f 100644 --- a/proto/include/diopi/functions_ext.h +++ b/proto/include/diopi/functions_ext.h @@ -66,8 +66,14 @@ DIOPI_API diopiError_t diopiRMSNormBackward(diopiContextHandle_t ctx, diopiTenso * @brief Compute the forward pass for MultiheadAttention. * @param[in] ctx The diopi context. * @param[in] q Query tensor. shape = [batch_size, q_seq_len, head_num, head_dim]. type = [float32, float16, float64]. + * - For the implementation of flash-attention in CUDA, it is necessary to pad for 'q' to ensure that the 'head_dim' of the output from 'q' is + * divisible by 8. 
Therefore, it is required to perform in-place modifications on 'q' by setting its data type to 'diopiTensorHandle_t'. * @param[in] k Key tensor. shape = [batch_size, k_seq_len, head_num, head_dim]. type = [float32, float16, float64]. + * - For the implementation of flash-attention in CUDA, it is necessary to pad for 'k' to ensure that the 'head_dim' of the output from 'k' is + * divisible by 8. Therefore, it is required to perform in-place modifications on 'k' by setting its data type to 'diopiTensorHandle_t'. * @param[in] v Value tensor. shape = [batch_size, v_seq_len, head_num, head_dim]. type = [float32, float16, float64]. + * - For the implementation of flash-attention in CUDA, it is necessary to pad for 'v' to ensure that the 'head_dim' of the output from 'v' is + * divisible by 8. Therefore, it is required to perform in-place modifications on 'v' by setting its data type to 'diopiTensorHandle_t'. * @param[in] dropout_p Dropout probability. type = [float32, float16, float64]. * @param[in] is_causal Flag to determine if the attention should be causal, masking future tokens. type = [bool] * @param[in] return_debug_mask Flag indicating if the attention debug mask should be returned. type = [bool]. @@ -80,9 +86,9 @@ DIOPI_API diopiError_t diopiRMSNormBackward(diopiContextHandle_t ctx, diopiTenso * @param[out] debug_attn_mask Debugging tensor for the attention mask (returned if return_debug_mask is true). shape = [batch_size, num_heads, q_seq_len, * k_seq_len]. type = [bool]. */ -DIOPI_API diopiError_t diopiMultiHeadAttention(diopiContextHandle_t ctx, diopiConstTensorHandle_t q, diopiConstTensorHandle_t k, diopiConstTensorHandle_t v, - double dropout_p, bool is_causal, bool return_debug_mask, double scale, diopiTensorHandle_t out, - diopiTensorHandle_t softmax_lse, diopiGeneratorHandle_t gen, diopiTensorHandle_t debug_attn_mask); +DIOPI_API diopiError_t diopiMultiHeadAttention(diopiContextHandle_t ctx, diopiTensorHandle_t q, diopiTensorHandle_t k, diopiTensorHandle_t v, double dropout_p, + bool is_causal, bool return_debug_mask, double scale, diopiTensorHandle_t out, diopiTensorHandle_t softmax_lse, + diopiGeneratorHandle_t gen, diopiTensorHandle_t debug_attn_mask); /** * @brief Compute the forward pass for MultiheadAttention. @@ -111,8 +117,14 @@ DIOPI_API diopiError_t diopiMultiHeadAttentionBackward(diopiContextHandle_t ctx, * @brief Compute the forward pass for MultiheadAttentionVarLen. * @param[in] ctx The diopi context. * @param[in] q Query tensor. shape = [q_nums, head_num, head_dim]. type = [float32, float16, float64]. + * - For the implementation of flash-attention in CUDA, it is necessary to pad for 'q' to ensure that the 'head_dim' of the output from 'q' is + * divisible by 8. Therefore, it is required to perform in-place modifications on 'q' by setting its data type to 'diopiTensorHandle_t'. * @param[in] k Key tensor. shape = [k_nums, head_num, head_dim]. type = [float32, float16, float64]. + * - For the implementation of flash-attention in CUDA, it is necessary to pad for 'k' to ensure that the 'head_dim' of the output from 'k' is + * divisible by 8. Therefore, it is required to perform in-place modifications on 'k' by setting its data type to 'diopiTensorHandle_t'. * @param[in] v Value tensor. shape = [v_nums, head_num, head_dim]. type = [float32, float16, float64]. + * - For the implementation of flash-attention in CUDA, it is necessary to pad for 'v' to ensure that the 'head_dim' of the output from 'v' is + * divisible by 8. 
Therefore, it is required to perform in-place modifications on 'v' by setting its data type to 'diopiTensorHandle_t'. * @param[in] cum_seq_q Cumulative sequence length for the query. shape = [batch_size+1, ]. type = [int64, int32]. * @param[in] cum_seq_k Cumulative sequence length for the key. shape = [batch_size+1, ]. type = [int64, int32]. * @param[in] max_q Maximum sequence length for the query. type = [int64, int32]. @@ -127,11 +139,10 @@ DIOPI_API diopiError_t diopiMultiHeadAttentionBackward(diopiContextHandle_t ctx, * @param[out] debug_attn_mask Debugging tensor for the attention mask (returned if return_debug_mask is true). shape = [batch_size, num_heads, max_q, max_k]. * type = [bool]. */ -DIOPI_API diopiError_t diopiMultiHeadAttentionVarLen(diopiContextHandle_t ctx, diopiConstTensorHandle_t q, diopiConstTensorHandle_t k, - diopiConstTensorHandle_t v, diopiConstTensorHandle_t cum_seq_q, diopiConstTensorHandle_t cum_seq_k, - int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, bool return_debug_mask, double scale, - diopiTensorHandle_t out, diopiTensorHandle_t softmax_lse, diopiGeneratorHandle_t gen, - diopiTensorHandle_t debug_attn_mask); +DIOPI_API diopiError_t diopiMultiHeadAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHandle_t q, diopiTensorHandle_t k, diopiTensorHandle_t v, + diopiConstTensorHandle_t cum_seq_q, diopiConstTensorHandle_t cum_seq_k, int64_t max_q, int64_t max_k, + double dropout_p, bool is_causal, bool return_debug_mask, double scale, diopiTensorHandle_t out, + diopiTensorHandle_t softmax_lse, diopiGeneratorHandle_t gen, diopiTensorHandle_t debug_attn_mask); /** * @brief Compute the forward pass for MultiheadAttentionVarLen. diff --git a/scripts/ci/diopi_config_A100.py b/scripts/ci/diopi_config_A100.py new file mode 100644 index 000000000..8bf58ed6f --- /dev/null +++ b/scripts/ci/diopi_config_A100.py @@ -0,0 +1,43 @@ +# Copyright (c) 2023, DeepLink. +# This is a temporary solution for testing the flash attention operator on the A100. It is intended for use +# only in the CI (Continuous Integration) environment. To use this file on the A100, replace the original +# DIOPI/diopi_test/python/conformance/diopi_configs.py with it. +import numpy as np + +ops_with_states = {} + +diopi_configs = { + 'multihead_attention_forward': dict( + name=['multihead_attention_forward'], + interface=['CustomizedTest'], + dtype=[np.float16], + atol=1e-3, + rtol=1e-4, + para=dict( + dropout_p=[0, 0, 0, 0], + is_causal=[False, False, True, False], + return_debug_mask=[False, False, False, False], + scale=[None, None, None, 0.1334] + ), + tensor_para=dict( + gen_fn='Genfunc.randn', + args=[ + { + "ins": ['q'], + "shape": ((2, 2, 2, 8), (2, 5, 7, 8), (4, 103, 8, 32), (8, 256, 16, 256)), + "dtype": [np.float16], + }, + { + "ins": ['k'], + "shape": ((2, 2, 2, 8), (2, 5, 7, 8), (4, 103, 8, 32), (8, 256, 16, 256)), + "dtype": [np.float16], + }, + { + "ins": ['v'], + "shape": ((2, 2, 2, 8), (2, 5, 7, 8), (4, 103, 8, 32), (8, 256, 16, 256)), + "dtype": [np.float16], + }, + ], + ), + ), +} \ No newline at end of file
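
For orientation, HostApplyLayerNorm and HostApplyRMSNorm above launch the forward kernels that also produce the per-row statistics (mean, invvar) reused by the backward path. The sketch below is a plain C++ row-wise reference of those quantities under the usual convention that invvar stores 1/sqrt(Var[x] + eps) for LayerNorm and 1/sqrt(mean(x^2) + eps) for RMSNorm; this convention is an assumption here, not copied from the kernels, and the function names are made up for illustration.

#include <cmath>
#include <vector>

// Reference for a single row of length n2. Empty gamma/beta mean "not supplied".
void layer_norm_row_reference(const std::vector<float>& x, const std::vector<float>& gamma,
                              const std::vector<float>& beta, float epsilon,
                              std::vector<float>& y, float& mean, float& invvar) {
    const int n2 = static_cast<int>(x.size());
    mean = 0.f;
    for (float v : x) mean += v;
    mean /= n2;
    float var = 0.f;
    for (float v : x) var += (v - mean) * (v - mean);
    var /= n2;
    invvar = 1.f / std::sqrt(var + epsilon);
    y.resize(n2);
    for (int j = 0; j < n2; ++j) {
        const float xhat = (x[j] - mean) * invvar;
        y[j] = gamma.empty() ? xhat : xhat * gamma[j] + (beta.empty() ? 0.f : beta[j]);
    }
}

// RMSNorm reference: no mean subtraction, invvar = 1 / sqrt(mean(x^2) + eps).
void rms_norm_row_reference(const std::vector<float>& x, const std::vector<float>& gamma,
                            float epsilon, std::vector<float>& y, float& invvar) {
    const int n2 = static_cast<int>(x.size());
    float ms = 0.f;
    for (float v : x) ms += v * v;
    ms /= n2;
    invvar = 1.f / std::sqrt(ms + epsilon);
    y.resize(n2);
    for (int j = 0; j < n2; ++j) {
        y[j] = gamma.empty() ? x[j] * invvar : x[j] * invvar * gamma[j];
    }
}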
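
cuComputePartGradGammaBeta and cuComputeGradGammaBeta split the reduction over the n1 rows into part_size slices (16 in HostLayerNormGradient and HostRMSNormGradient), so that no single block has to reduce the entire batch. Below is a minimal host-side C++ reference of the same two-pass scheme, assuming the standard weight gradients grad_gamma[j] = sum_i dout[i][j] * xhat[i][j] and grad_beta[j] = sum_i dout[i][j]; it deliberately ignores the shared-memory tiling and warp reductions the kernels use, and all names here are hypothetical.

#include <algorithm>
#include <cstddef>
#include <vector>

void grad_gamma_beta_reference(const std::vector<float>& dout,   // n1 * n2, row-major
                               const std::vector<float>& x_hat,  // normalized input, n1 * n2
                               int n1, int n2, int part_size,
                               std::vector<float>& grad_gamma,   // n2
                               std::vector<float>& grad_beta) {  // n2
    // Pass 1: each of the part_size row slices accumulates its own partial sums
    // (on the GPU, one blockIdx.y per slice writes part_grad_gamma / part_grad_beta).
    std::vector<float> part_gamma(static_cast<std::size_t>(part_size) * n2, 0.f);
    std::vector<float> part_beta(static_cast<std::size_t>(part_size) * n2, 0.f);
    const int rows_per_part = (n1 + part_size - 1) / part_size;
    for (int p = 0; p < part_size; ++p) {
        const int row_end = std::min(n1, (p + 1) * rows_per_part);
        for (int i1 = p * rows_per_part; i1 < row_end; ++i1) {
            for (int i2 = 0; i2 < n2; ++i2) {
                part_gamma[p * n2 + i2] += dout[i1 * n2 + i2] * x_hat[i1 * n2 + i2];
                part_beta[p * n2 + i2] += dout[i1 * n2 + i2];
            }
        }
    }
    // Pass 2: reduce the part_size partial rows into the final gradients
    // (the role of cuComputeGradGammaBeta).
    grad_gamma.assign(n2, 0.f);
    grad_beta.assign(n2, 0.f);
    for (int i2 = 0; i2 < n2; ++i2) {
        for (int p = 0; p < part_size; ++p) {
            grad_gamma[i2] += part_gamma[p * n2 + i2];
            grad_beta[i2] += part_beta[p * n2 + i2];
        }
    }
}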
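
The inner loops of cuComputeGradInput reduce two row-wise sums (sum_loss1 and sum_loss2) and combine them with term1 = invvar / n2. For reference, with H = n2 and \hat{x} the normalized input, the value the kernel writes corresponds to the familiar input-gradient identity (under the invvar convention stated above):

\[
\frac{\partial L}{\partial x_j}
  = \frac{\mathrm{invvar}}{H}\left( H\,\gamma_j\,\frac{\partial L}{\partial y_j}
  - \underbrace{\sum_{k}\gamma_k\,\frac{\partial L}{\partial y_k}}_{\texttt{sum\_loss1}}
  - \hat{x}_j\,\underbrace{\sum_{k}\gamma_k\,\frac{\partial L}{\partial y_k}\,\hat{x}_k}_{\texttt{sum\_loss2}} \right),
\qquad \hat{x}_j = (x_j - \mu)\,\mathrm{invvar}.
\]

In the rms_only path the sum_loss1 term is dropped and \hat{x}_j = x_j \cdot \mathrm{invvar}; when gamma is NULL the same formula applies with \gamma_j \equiv 1. This matches term1 = (1/fH) * c_invvar and the f_grad_input updates in the kernel above.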
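
The DISPATCH_*_INOUT_TYPES macros expand their trailing arguments once per supported (input dtype, output dtype) pair, binding the type aliases scalar_t_in and scalar_t_out for the body to use. A usage sketch follows; report_pair is a made-up helper that exists only to show the expansion pattern, while the macro is the one defined above. Note that for Half and BFloat16 inputs the macros force scalar_t_out to equal the input type, regardless of the TYPEOUT argument.

#include <cstdio>

#include <ATen/ATen.h>

template <typename TIn, typename TOut>
void report_pair() {
    // Placeholder body: a real call site would launch a kernel templated on <TIn, TOut>.
    std::printf("dispatched with sizeof(in)=%zu, sizeof(out)=%zu\n", sizeof(TIn), sizeof(TOut));
}

void dispatch_example(const at::Tensor& src, const at::Tensor& dst) {
    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
        src.scalar_type(), dst.scalar_type(), "dispatch_example",
        report_pair<scalar_t_in, scalar_t_out>();)
}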
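
In rmsnorm.cu, compute_n1_n2 collapses the leading (non-normalized) dimensions of the input into n1 and the normalized dimensions into n2, so the kernels can treat the tensor as an n1 x n2 matrix; for example, input shape [8, 16, 1024] with normalized_shape [1024] gives n1 = 128 and n2 = 1024. A standalone illustration (plain C++, ATen-free; the function name is hypothetical):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

void compute_n1_n2_reference(const std::vector<int64_t>& input_shape,
                             const std::vector<int64_t>& normalized_shape, int& n1, int& n2) {
    const std::size_t idiff = input_shape.size() - normalized_shape.size();
    n1 = 1;
    n2 = 1;
    for (std::size_t i = 0; i < idiff; ++i) {
        n1 *= static_cast<int>(input_shape[i]);
    }
    for (std::size_t i = 0; i < normalized_shape.size(); ++i) {
        // Same invariant check_args enforces with a readable error message.
        assert(input_shape[i + idiff] == normalized_shape[i]);
        n2 *= static_cast<int>(normalized_shape[i]);
    }
}
// compute_n1_n2_reference({8, 16, 1024}, {1024}, n1, n2)  ->  n1 == 128, n2 == 1024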
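
apply_rotary_cuda in rotary.cu rotates each (x1, x2) pair by the angle whose cosine and sine are supplied, with conj selecting the transposed (inverse) rotation, as a backward pass typically would. The per-element arithmetic performed by the two TensorIterator lambdas is just:

#include <utility>

// Forward rotation for conj == false, inverse rotation for conj == true.
std::pair<float, float> rotary_pair(float x1, float x2, float cos_v, float sin_v, bool conj) {
    if (!conj) {
        return {x1 * cos_v - x2 * sin_v, x1 * sin_v + x2 * cos_v};
    }
    return {x1 * cos_v + x2 * sin_v, -x1 * sin_v + x2 * cos_v};
}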
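
The functions_ext.h change above makes q, k and v non-const because the CUDA flash-attention path may pad them in place so that head_dim becomes a multiple of 8. Below is a hedged ATen sketch of that kind of padding; it illustrates the requirement only and is not the DIOPI implementation, and it returns a padded copy for simplicity where the actual interface expects in-place modification (which is why the const qualifier was dropped). Note that every head_dim used in scripts/ci/diopi_config_A100.py (8, 8, 32, 256) is already a multiple of 8, so those CI cases do not exercise the padding path.

#include <ATen/ATen.h>

// Zero-pad the last dimension (head_dim) of a [batch, seq_len, head_num, head_dim]
// tensor up to the next multiple of 8.
at::Tensor pad_head_dim_to_multiple_of_8(const at::Tensor& t) {
    const int64_t head_dim = t.size(-1);
    const int64_t padded = (head_dim + 7) / 8 * 8;
    if (padded == head_dim) {
        return t;
    }
    // constant_pad_nd pads from the last dimension backwards: {pad_left, pad_right}.
    return at::constant_pad_nd(t, {0, padded - head_dim}, /*value=*/0);
}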