diff --git a/bestla/bestla/kernel_avx2.h b/bestla/bestla/kernel_avx2.h index a6f7ad800..6518cbaf4 100644 --- a/bestla/bestla/kernel_avx2.h +++ b/bestla/bestla/kernel_avx2.h @@ -21,11 +21,11 @@ namespace bestla { namespace kernel { namespace avx2 { #if CompileAVX2() -#if defined(__GNUC__) +#if defined(__INTEL_LLVM_COMPILER) +#pragma clang attribute push(__attribute__((target("avx2,fma,f16c"))), apply_to = function) +#elif defined(__GNUC__) #pragma GCC push_options #pragma GCC target("avx2", "fma", "f16c") -#elif defined(ICX) -//#pragma clang attribute push(__attribute__((target("avx2,fma,f16c"))), apply_to = function) #endif static inline void zero_reg() { _mm256_zeroupper(); } @@ -5373,10 +5373,11 @@ static inline BTLA_CODE inplace_precompute_max_softmax_fp32_fp32(int m_size, int return BTLA_CODE::Success; } -#ifdef __GNUC__ +#if defined(__INTEL_LLVM_COMPILER) +#pragma clang attribute pop +#elif defined(__GNUC__) #pragma GCC pop_options #endif - #endif } // namespace avx2 } // namespace kernel diff --git a/bestla/bestla/kernel_avx512_bf16.h b/bestla/bestla/kernel_avx512_bf16.h index 5fdc9925e..607f8800a 100644 --- a/bestla/bestla/kernel_avx512_bf16.h +++ b/bestla/bestla/kernel_avx512_bf16.h @@ -19,12 +19,13 @@ namespace kernel { namespace avx512f { namespace avx512_bf16 { #if CompileBF16() -#if defined(__GNUC__) +#if defined(__INTEL_LLVM_COMPILER) +#pragma clang attribute push(__attribute__((target("avx512f,avx512bf16,avx512vl,avx512bw"))), apply_to = function) +#elif defined(__GNUC__) #pragma GCC push_options #pragma GCC target("avx512bf16", "avx512vl", "avx512bw") -#elif defined(ICX) -#pragma clang attribute push(__attribute__((target("avx512bf16,avx512vl,avx512bw"))), apply_to = function) #endif + static inline __m256i zmm_cvt_fp32_bf16(__m512 vfp32) { return (__m256i)_mm512_cvtneps_pbh(vfp32); } static inline __m512 load_bf16_fp32(const utils::bf16* srcptr) { @@ -175,7 +176,9 @@ static inline BTLA_CODE inplace_precompute_max_softmax_fp32_bf16(int m_size, int } return BTLA_CODE::Success; } -#if defined(__GNUC__) +#if defined(__INTEL_LLVM_COMPILER) +#pragma clang attribute pop +#elif defined(__GNUC__) #pragma GCC pop_options #endif #endif diff --git a/bestla/bestla/kernel_avx512_fp16.h b/bestla/bestla/kernel_avx512_fp16.h index 8ad426c5b..124ca0076 100644 --- a/bestla/bestla/kernel_avx512_fp16.h +++ b/bestla/bestla/kernel_avx512_fp16.h @@ -20,11 +20,11 @@ namespace kernel { namespace avx512f { namespace avx512_fp16 { #if CompileFP16() -#if defined(__GNUC__) +#if defined(__INTEL_LLVM_COMPILER) +#pragma clang attribute push(__attribute__((target("avx512f,avx512bf16,avx512bw,avx512fp16"))), apply_to = function) +#elif defined(__GNUC__) #pragma GCC push_options #pragma GCC target("avx512f", "avx512bf16", "avx512vl", "avx512bw", "avx512fp16") -#elif defined(ICX) -#pragma clang attribute push(__attribute__((target("avx512f,avx512bf16,avx512bw,avx512fp16"))), apply_to = function) #endif inline __m512 zmm_cvt_fp16_fp32(__m256i vfp16) { return _mm512_cvtxph_ps((__m256h)vfp16); } @@ -465,7 +465,9 @@ static inline BTLA_CODE inplace_precompute_max_softmax_fp32_fp16(int m_size, int } return BTLA_CODE::Success; } -#if defined(__GNUC__) +#if defined(__INTEL_LLVM_COMPILER) +#pragma clang attribute pop +#elif defined(__GNUC__) #pragma GCC pop_options #endif #endif diff --git a/bestla/bestla/kernel_avx512_vnni.h b/bestla/bestla/kernel_avx512_vnni.h index 23d4ccff0..53d12b2b0 100644 --- a/bestla/bestla/kernel_avx512_vnni.h +++ b/bestla/bestla/kernel_avx512_vnni.h @@ -18,12 +18,12 @@ namespace bestla { 
namespace kernel { namespace avx512f { #if CompileAVX512VNNI() -#ifdef __GNUC__ -#pragma GCC push_options -#pragma GCC target("avx512f", "avx512bw", "avx512vl", "avx512dq", "avx512vnni") -#elif defined(ICX) +#if defined(__INTEL_LLVM_COMPILER) #pragma clang attribute push(__attribute__((target("avx512f,avx512bw,avx512vl,avx512dq,avx512vnni"))), \ apply_to = function) +#elif defined(__GNUC__) +#pragma GCC push_options +#pragma GCC target("avx512f", "avx512bw", "avx512vl", "avx512dq", "avx512vnni") #endif namespace vnni { @@ -1517,9 +1517,10 @@ static inline BTLA_CODE gemv_7bit_s8s8_fp32(const utils::GemvParamA& A, const ut } // namespace vnni -#ifdef __GNUC__ +#if defined(__INTEL_LLVM_COMPILER) +#pragma clang attribute pop +#elif defined(__GNUC__) #pragma GCC pop_options -#else #endif #endif } // namespace avx512f diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index ad55c2eb9..8c7eaa939 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -26,11 +26,11 @@ namespace bestla { namespace kernel { namespace avx512f { #if CompileAVX512F() -#ifdef __GNUC__ +#if defined(__INTEL_LLVM_COMPILER) +#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,avx512bw,avx512dq"))), apply_to = function) +#elif defined(__GNUC__) #pragma GCC push_options #pragma GCC target("avx512f", "avx512bw", "avx512vl", "avx512dq") -#elif defined(ICX) -#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,avx512bw,avx512dq"))), apply_to = function) #endif inline __m512 zmm_cvt_fp16_fp32(__m256i vfp16) { return _mm512_cvtph_ps(vfp16); } @@ -6512,9 +6512,10 @@ static inline BTLA_CODE inplace_precompute_max_softmax_fp32_u8(int m_size, int n } return BTLA_CODE::Success; } -#ifdef __GNUC__ +#if defined(__INTEL_LLVM_COMPILER) +#pragma clang attribute pop +#elif defined(__GNUC__) #pragma GCC pop_options -#else #endif #endif } // namespace avx512f diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index 56ed1efdb..4811c04c0 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -1925,6 +1925,7 @@ class ScaleTrackMax { const int M_offset, const int N_offset, const int M, const int N, float scale, int causal_offset, float alibi_slope, float tanh_scale, void* tmpcache, size_t cachesize) { +#if CompileAVX2() if (alibi_slope == 0 && tanh_scale == 0) return avx2::scale_track_max_fp32_fp32(src, src_step, dst, dst_max, ld_dst, M_offset, N_offset, M, N, scale, causal_offset, alibi_slope, tanh_scale, tmpcache, @@ -1937,14 +1938,15 @@ class ScaleTrackMax { return avx2::scale_track_max_fp32_fp32(src, src_step, dst, dst_max, ld_dst, M_offset, N_offset, M, N, scale, causal_offset, alibi_slope, tanh_scale, tmpcache, cachesize); - else - return BTLA_CODE::NotSupport; +#endif + return BTLA_CODE::NotSupport; } static BTLA_CODE forward_avx512(const SType* src, const int src_step, DType* dst, DType* dst_max, int ld_dst, const int M_offset, const int N_offset, const int M, const int N, float scale, int causal_offset, float alibi_slope, float tanh_scale, void* tmpcache, size_t cachesize) { +#if CompileAVX512F() if (alibi_slope == 0 && tanh_scale == 0) return avx512f::scale_track_max_fp32_fp32(src, src_step, dst, dst_max, ld_dst, M_offset, N_offset, M, N, scale, causal_offset, alibi_slope, tanh_scale, @@ -1957,8 +1959,8 @@ class ScaleTrackMax { return avx512f::scale_track_max_fp32_fp32(src, src_step, dst, dst_max, ld_dst, M_offset, N_offset, M, N, scale, causal_offset, alibi_slope, tanh_scale, tmpcache, cachesize); 
- else - return BTLA_CODE::NotSupport; +#endif + return BTLA_CODE::NotSupport; } }; diff --git a/bestla/bestla/sycl/sycl_prologue_b.h b/bestla/bestla/sycl/sycl_prologue_b.h index b6f845aa0..7c9810d8c 100644 --- a/bestla/bestla/sycl/sycl_prologue_b.h +++ b/bestla/bestla/sycl/sycl_prologue_b.h @@ -376,76 +376,76 @@ class WeightS4Trans { int constexpr TileK = 32; int constexpr GroupK = SgSize * TileK; auto ev = q->submit([&](sycl::handler& cgh) { - cgh.parallel_for(sycl::nd_range<1>(problem, group), - [=](sycl::nd_item<1> it) [[sycl::reqd_work_group_size( - 1, 1, SgSize)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] { - int g_idx = it.get_group(0); - auto sg = it.get_sub_group(); - int sg_id = sg.get_local_id()[0]; - int g_n = g_idx; - auto sptr = B_scale + g_n * ldb; - auto bptr = B + g_n * k / 2; - auto aptr = A; - auto cptr = C + g_n; - if constexpr (std::is_same_v) { - sycl::half2 tmpAcc = {0.f, 0.f}; - for (int i = 0; i < k; i += GroupK * Unroll) { + cgh.parallel_for( + sycl::nd_range<1>(problem, group), + [=](sycl::nd_item<1> it) [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] { + int g_idx = it.get_group(0); + auto sg = it.get_sub_group(); + int sg_id = sg.get_local_id()[0]; + int g_n = g_idx; + auto sptr = B_scale + g_n * ldb; + auto bptr = B + g_n * k / 2; + auto aptr = A; + auto cptr = C + g_n; + if constexpr (std::is_same_v) { + sycl::half2 tmpAcc = {0.f, 0.f}; + for (int i = 0; i < k; i += GroupK * Unroll) { #pragma unroll - for (int iu = 0; iu < Unroll; iu++) { - uint8_t tmps8[TileK / 2]; - *(sycl::vec*)tmps8 = - *(sycl::vec*)(bptr + sg_id * TileK / 2); - CType scale = *(sptr + sg_id * TileK / blocksize); + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK / 2); + CType scale = *(sptr + sg_id * TileK / blocksize); #pragma unroll - for (int ikk = 0; ikk < TileK; ikk += 2) { - sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk]; - sycl::half2 tmpB = {static_cast((tmps8[ikk / 2] & 0x0f) - 8), - static_cast((tmps8[ikk / 2] >> 4) - 8)}; - tmpAcc += tmpA * tmpB * scale; - } - sptr += GroupK / blocksize; - aptr += GroupK; - bptr += GroupK / 2; - } - } - sycl::half2 sum = {0.f, 0.f}; - for (int i = 0; i < SgSize; i += 1) { - sum += sg.shuffle(tmpAcc, i); - } - if (sg_id == 0) { - *cptr = sum[0] + sum[1]; - } - } else { - CType tmpAcc = 0.f; - int constexpr Unroll = 2; - for (int i = 0; i < k; i += GroupK * Unroll) { + for (int ikk = 0; ikk < TileK; ikk += 2) { + sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk]; + sycl::half2 tmpB = {static_cast((tmps8[ikk / 2] & 0x0f) - 8), + static_cast((tmps8[ikk / 2] >> 4) - 8)}; + tmpAcc += tmpA * tmpB * scale; + } + sptr += GroupK / blocksize; + aptr += GroupK; + bptr += GroupK / 2; + } + } + sycl::half2 sum = {0.f, 0.f}; + for (int i = 0; i < SgSize; i += 1) { + sum += sg.shuffle(tmpAcc, i); + } + if (sg_id == 0) { + *cptr = sum[0] + sum[1]; + } + } else { + CType tmpAcc = 0.f; + int constexpr Unroll = 2; + for (int i = 0; i < k; i += GroupK * Unroll) { #pragma unroll - for (int iu = 0; iu < Unroll; iu++) { - uint8_t tmps8[TileK / 2]; - *(sycl::vec*)tmps8 = - *(sycl::vec*)(bptr + sg_id * TileK / 2); - CType scale = *(sptr + sg_id * TileK / blocksize); + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK / 2); + CType scale = *(sptr + sg_id * TileK / blocksize); #pragma unroll - for (int ikk = 0; ikk < TileK; ikk += 2) { - tmpAcc 
+= CType(aptr[sg_id * TileK + ikk]) * - static_cast((tmps8[ikk / 2] & 0x0f) - 8) * scale; - tmpAcc += CType(aptr[sg_id * TileK + ikk + 1]) * - static_cast((tmps8[ikk / 2] >> 4) - 8) * scale; - } - sptr += GroupK / blocksize; - aptr += GroupK; - bptr += GroupK / 2; - } - } - float sum = 0.f; - for (int i = 0; i < SgSize; i += 1) { - sum += sg.shuffle(tmpAcc, i); - } - if (sg_id == 0) { - *cptr = sum; - } - } - }); + for (int ikk = 0; ikk < TileK; ikk += 2) { + tmpAcc += + CType(aptr[sg_id * TileK + ikk]) * static_cast((tmps8[ikk / 2] & 0x0f) - 8) * scale; + tmpAcc += + CType(aptr[sg_id * TileK + ikk + 1]) * static_cast((tmps8[ikk / 2] >> 4) - 8) * scale; + } + sptr += GroupK / blocksize; + aptr += GroupK; + bptr += GroupK / 2; + } + } + float sum = 0.f; + for (int i = 0; i < SgSize; i += 1) { + sum += sg.shuffle(tmpAcc, i); + } + if (sg_id == 0) { + *cptr = sum; + } + } + }); }); return ev; } else { @@ -458,8 +458,7 @@ class WeightS4Trans { auto ev = q->submit([&](sycl::handler& cgh) { cgh.parallel_for( sycl::nd_range<1>(problem, group), - [=](sycl::nd_item<1> it) [[sycl::reqd_work_group_size( - 1, 1, SgSize)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] { + [=](sycl::nd_item<1> it) [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] { int g_idx = it.get_group(0); auto sg = it.get_sub_group(); int sg_id = sg.get_local_id()[0]; diff --git a/bestla/bestla/sycl/sycl_wrapper.h b/bestla/bestla/sycl/sycl_wrapper.h index 49fb53de5..ea8d1b828 100644 --- a/bestla/bestla/sycl/sycl_wrapper.h +++ b/bestla/bestla/sycl/sycl_wrapper.h @@ -59,24 +59,22 @@ class Launcher { sycl::range<2> problem{static_cast(m_pad), static_cast(n) / GemmCore::TileN}; auto ev = q->submit([&](sycl::handler& cgh) { sycl::local_accessor slm_b(sycl::range(GemmCore::SLM_B_Size), cgh); - cgh.parallel_for( - sycl::nd_range<2>(problem, group), - [=](sycl::nd_item<2> it) [[sycl::reqd_work_group_size( - 1, GemmCore::WgM, - GemmCore::WgN)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(GemmCore::SgSize)]] { - sycl_utils::nd_item_helper helper(it); - if constexpr (debug) { - compute_tile(k, B, ldb, slm_b, A, lda, C, ldc, it); - } else { - int m_tail = m - helper.sg_g_m(); - m_tail = m_tail > GemmCore::TileM ? GemmCore::TileM : m_tail; - if (m_tail == GemmCore::TileM) { - compute_tile(k, B, ldb, slm_b, A, lda, C, ldc, it); - } else { - compute_tail(k, B, ldb, slm_b, A, lda, C, ldc, m_tail, it); - } - } - }); + cgh.parallel_for(sycl::nd_range<2>(problem, group), + [=](sycl::nd_item<2> it) + [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(GemmCore::SgSize)]] { + sycl_utils::nd_item_helper helper(it); + if constexpr (debug) { + compute_tile(k, B, ldb, slm_b, A, lda, C, ldc, it); + } else { + int m_tail = m - helper.sg_g_m(); + m_tail = m_tail > GemmCore::TileM ? 
GemmCore::TileM : m_tail; + if (m_tail == GemmCore::TileM) { + compute_tile(k, B, ldb, slm_b, A, lda, C, ldc, it); + } else { + compute_tail(k, B, ldb, slm_b, A, lda, C, ldc, m_tail, it); + } + } + }); }); return ev; } @@ -151,24 +149,22 @@ class LauncherWOQ { sycl::range<2> problem{static_cast(m_pad), static_cast(n) / GemmCore::TileN}; auto ev = q->submit([&](sycl::handler& cgh) { sycl::local_accessor slm_b(sycl::range(GemmCore::SLM_B_Size), cgh); - cgh.parallel_for( - sycl::nd_range<2>(problem, group), - [=](sycl::nd_item<2> it) [[sycl::reqd_work_group_size( - 1, GemmCore::WgM, - GemmCore::WgN)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(GemmCore::SgSize)]] { - sycl_utils::nd_item_helper helper(it); - if constexpr (debug) { - compute_tile(k, blocksize, B, B_scale, ldb, slm_b, A, lda, C, ldc, it); - } else { - int m_tail = m - helper.sg_g_m(); - m_tail = m_tail > GemmCore::TileM ? GemmCore::TileM : m_tail; - if (m_tail == GemmCore::TileM) { - compute_tile(k, blocksize, B, B_scale, ldb, slm_b, A, lda, C, ldc, it); - } else { - compute_tail(k, blocksize, m_tail, B, B_scale, ldb, slm_b, A, lda, C, ldc, it); - } - } - }); + cgh.parallel_for(sycl::nd_range<2>(problem, group), + [=](sycl::nd_item<2> it) + [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(GemmCore::SgSize)]] { + sycl_utils::nd_item_helper helper(it); + if constexpr (debug) { + compute_tile(k, blocksize, B, B_scale, ldb, slm_b, A, lda, C, ldc, it); + } else { + int m_tail = m - helper.sg_g_m(); + m_tail = m_tail > GemmCore::TileM ? GemmCore::TileM : m_tail; + if (m_tail == GemmCore::TileM) { + compute_tile(k, blocksize, B, B_scale, ldb, slm_b, A, lda, C, ldc, it); + } else { + compute_tail(k, blocksize, m_tail, B, B_scale, ldb, slm_b, A, lda, C, ldc, it); + } + } + }); }); return ev; } diff --git a/bestla/bestla/ut/sycl_gemm.cpp b/bestla/bestla/ut/sycl_gemm.cpp index 19c58946a..87286e142 100644 --- a/bestla/bestla/ut/sycl_gemm.cpp +++ b/bestla/bestla/ut/sycl_gemm.cpp @@ -430,7 +430,7 @@ class UT_SyclS4Gemv { int constexpr TileK = 2; int constexpr GroupK = SgSize * TileK; sycl::range<1> group{SgSize}; - sycl::range<1> problem{n * SgSize}; + sycl::range<1> problem{(size_t)n * SgSize}; auto S_d = dS.data(); auto A_d = dA.data(); auto B_d = dB.data(); @@ -471,7 +471,7 @@ class UT_SyclS4Gemv { int constexpr TileK = 32; int constexpr GroupK = SgSize * TileK; sycl::range<1> group{SgSize}; - sycl::range<1> problem{n * SgSize}; + sycl::range<1> problem{(size_t)n * SgSize}; auto S_d = dS.data(); auto A_d = dA.data(); auto B_d = dB.data(); @@ -513,7 +513,7 @@ void mha_sref(float* Q, float* K, float* V, float* S, float* O, int batch, int s } float sums = 0.f; for (int jj = 0; jj < seqA; jj++) { - tmps[jj] = std::expf(tmps[jj] - maxs); + tmps[jj] = std::exp(tmps[jj] - maxs); sums += tmps[jj]; } sums = 1.f / sums; @@ -610,17 +610,17 @@ class UT_MHASgemm { int jj = wg_loc_id * 2; for (; jj < seq_acc_pad; jj += WgSize * 2) { auto s2 = *(TC*)&slm[jj]; - s2[0] = std::expf(s2[0] - fmax); - s2[1] = std::expf(s2[1] - fmax); + s2[0] = std::exp(s2[0] - fmax); + s2[1] = std::exp(s2[1] - fmax); fsums += s2[0]; fsums += s2[1]; *(TC*)&slm[jj] = s2; } if (jj < seq_acc) { - slm[jj] = std::expf(float(slm[jj]) - fmax); + slm[jj] = std::exp(float(slm[jj]) - fmax); fsums += slm[jj]; if (jj + 1 < seq_acc) { - slm[jj + 1] = std::expf(float(slm[jj + 1]) - fmax); + slm[jj + 1] = std::exp(float(slm[jj + 1]) - fmax); fsums += slm[jj + 1]; } } @@ -694,7 +694,7 @@ class UT_MHASgemm { auto Sptr = dS.data(); auto Optr = dO.data(); int nf = 
hnum * hsize; - sycl::range<1> num_items{batch * seq * hnum}; + sycl::range<1> num_items{(size_t)batch * seq * hnum}; int n_past = seqA - seq; const float attn_scale = 1.0f / sqrtf(static_cast(hsize)); if (seq > 1) { @@ -729,7 +729,7 @@ class UT_MHASgemm { // } // float sums = 0.f; // for (int jj = 0; jj < seqA; jj++) { - // tmps[jj] = std::expf(tmps[jj] - maxs); + // tmps[jj] = std::exp(tmps[jj] - maxs); // sums += tmps[jj]; // } // sums = 1.f / sums; diff --git a/bestla/bestla/ut/sycl_misc.cpp b/bestla/bestla/ut/sycl_misc.cpp index 2a41bfaa9..cc4b2866e 100644 --- a/bestla/bestla/ut/sycl_misc.cpp +++ b/bestla/bestla/ut/sycl_misc.cpp @@ -102,7 +102,7 @@ class UT_BlockQunatize_S3S4 { using ProB = sycl_prologue_b::WeightS4Trans; sycl_utils::sycl_vector dequantB(n * k, q); int blks = updiv(k, blocksize); - auto evt = ProB::dequant_s4( + auto evt = ProB::template dequant_s4( n, k, blocksize, {(uint8_t*)sycl_stor.mQBuf, (float*)sycl_stor.mSBuf, blks}, dequantB.data(), q); evt.wait(); q->memcpy(dequant.data(), dequantB.data(), dequantB.size() * 4).wait(); diff --git a/neural_speed/core/CMakeLists.txt b/neural_speed/core/CMakeLists.txt index 487273c4c..6a974dfae 100644 --- a/neural_speed/core/CMakeLists.txt +++ b/neural_speed/core/CMakeLists.txt @@ -16,12 +16,12 @@ find_package(Threads REQUIRED) file(GLOB layers_srcs "layers/*.cpp") file(GLOB test_srcs "layers/*test*.cpp") list(REMOVE_ITEM layers_srcs ${test_srcs}) -set(sources ne_layers.c ${layers_srcs}) +set(sources ne_layers.cpp ${layers_srcs}) add_shareable_library_w_warning(ne_layers "${sources}") target_include_directories(ne_layers PUBLIC .) -target_compile_features(ne_layers PUBLIC c_std_11) # don't bump +target_compile_features(ne_layers PUBLIC cxx_std_17) set_target_properties(ne_layers PROPERTIES POSITION_INDEPENDENT_CODE ON) if (NS_TP) find_package(oneCCL REQUIRED) diff --git a/neural_speed/core/layers/ne_bestla_sycl.cpp b/neural_speed/core/layers/ne_bestla_sycl.cpp index 170f5c503..442847cdd 100644 --- a/neural_speed/core/layers/ne_bestla_sycl.cpp +++ b/neural_speed/core/layers/ne_bestla_sycl.cpp @@ -202,12 +202,12 @@ void bestla_device_mul_f32(const struct ne_compute_params* params, const struct const size_t nb1 = dst->nb[1]; const size_t nb2 = dst->nb[2]; const size_t nb3 = dst->nb[3]; - sycl::range<1> num_items{ne00 * ne01 * ne02 * ne03}; + sycl::range<1> num_items{static_cast(ne00) * ne01 * ne02 * ne03}; auto src0ptr = (float*)src0->data; auto src1ptr = (float*)src1->data; auto dstptr = (float*)dst->data; auto ev = q->submit([&](sycl::handler& cgh) { - cgh.parallel_for(num_items, [=](auto it) { + cgh.parallel_for(num_items, [=](auto it) [[intel::reqd_sub_group_size(16)]] { int i = it; int i00 = i % ne00; i /= ne00; @@ -264,12 +264,12 @@ void bestla_device_add_f32(const struct ne_compute_params* params, const struct const size_t nb1 = dst->nb[1]; const size_t nb2 = dst->nb[2]; const size_t nb3 = dst->nb[3]; - sycl::range<1> num_items{ne00 * ne01 * ne02 * ne03}; + sycl::range<1> num_items{static_cast(ne00) * ne01 * ne02 * ne03}; auto src0ptr = (float*)src0->data; auto src1ptr = (float*)src1->data; auto dstptr = (float*)dst->data; auto ev = q->submit([&](sycl::handler& cgh) { - cgh.parallel_for(num_items, [=](auto it) { + cgh.parallel_for(num_items, [=](auto it) [[intel::reqd_sub_group_size(16)]] { int i = it; int i00 = i % ne00; i /= ne00; @@ -309,9 +309,9 @@ void bestla_device_elewise_f32(const struct ne_compute_params* params, const str auto srcptr = (float*)src0->data; auto dstptr = (float*)dst->data; - sycl::range<1> 
num_items{ne00 * ne01 * ne02 * ne03}; + sycl::range<1> num_items{static_cast(ne00) * ne01 * ne02 * ne03}; auto ev = q->submit([&](sycl::handler& cgh) { - cgh.parallel_for(num_items, [=](auto it) { + cgh.parallel_for(num_items, [=](auto it) [[intel::reqd_sub_group_size(16)]] { int i = it; float srcval = srcptr[i]; if (op == NE_OP_SILU) { @@ -347,6 +347,7 @@ void bestla_device_rms_norm_f32(const struct ne_compute_params* params, const st const size_t nb3 = dst->nb[3]; int64_t constexpr WgSize = 1024; int constexpr SgSize = 16; + int constexpr SgNum = WgSize / SgSize; int64_t ne00_ = bestla::utils::padto_le(ne00, WgSize); auto src0ptr = (float*)src0->data; auto dstptr = (float*)dst->data; @@ -370,34 +371,24 @@ void bestla_device_rms_norm_f32(const struct ne_compute_params* params, const st float* src0_ptr = (float*)((char*)src0ptr + i03 * nb03 + i02 * nb02 + i01 * nb01); float sum = 0.0; int64_t i00 = wg_loc_id; - for (; i00 < ne00_; i00 += WgSize) { + for (; i00 < ne00; i00 += WgSize) { sum += (src0_ptr[i00] * src0_ptr[i00]); } - if (i00 < ne00) { - sum += (src0_ptr[i00] * src0_ptr[i00]); + auto gsum = sycl::reduce_over_group(sg, sum, sycl::plus()); + if (lane_id == 0) { + slm[sg_idx] = gsum; } - slm[wg_loc_id] = sum; it.barrier(sycl::access::fence_space::local_space); - if (sg_idx == 0) { - for (size_t i = wg_loc_id; i < WgSize - SgSize; i += SgSize) { - sum += slm[i + SgSize]; - } - float gsum = 0.f; - for (int i = 0; i < SgSize; i += 1) { - gsum += sg.shuffle(sum, i); - } - float mean = gsum / ne00; - const float scale = 1.0f / sqrtf(mean + eps); - slm[0] = scale; + sum = 0; + for (size_t i = lane_id; i < SgNum; i += SgSize) { + sum += slm[i + SgSize]; } - it.barrier(sycl::access::fence_space::local_space); + gsum = sycl::reduce_over_group(sg, sum, sycl::plus()); - float scale = slm[0]; + float mean = gsum / ne00; + const float scale = 1.0f / sqrtf(mean + eps); i00 = wg_loc_id; - for (; i00 < ne00_; i00 += WgSize) { - dst_ptr[i00] = src0_ptr[i00] * scale; - } - if (i00 < ne00) { + for (; i00 < ne00; i00 += WgSize) { dst_ptr[i00] = src0_ptr[i00] * scale; } }); @@ -484,7 +475,7 @@ void bestla_device_rope_f32(const struct ne_compute_params* params, const struct const size_t nb2 = dst->nb[2]; const size_t nb3 = dst->nb[3]; - const int nr = ne1 * ne2 * ne3; + const size_t nr = ne1 * ne2 * ne3; const float theta_scale = powf(freq_base, -2.0f / n_dims); const float inv_ndims = -1.f / n_dims; @@ -495,15 +486,14 @@ void bestla_device_rope_f32(const struct ne_compute_params* params, const struct int constexpr SgSize = 16; auto src0ptr = (float*)src0->data; auto dstptr = (float*)dst->data; + int ne00_h = ne00 >> 1; + assert(ne00 % 2 == 0); auto ev = q->submit([&](sycl::handler& cgh) { // sycl::local_accessor slm(sycl::range(WgSize), cgh); - cgh.parallel_for(sycl::nd_range<1>(nr * SgSize, SgSize), [=](auto it) [[intel::reqd_sub_group_size(SgSize)]] { - auto sg = it.get_sub_group(); - auto sg_idx = sg.get_group_id()[0]; - auto wg_idx = it.get_group(0); - auto wg_loc_id = it.get_local_id(); - auto lane_id = sg.get_local_id()[0]; - int i = wg_idx; + cgh.parallel_for(nr * ne00_h, [=](auto it) [[intel::reqd_sub_group_size(SgSize)]] { + int i = it; + int i0 = i % ne00_h; + i /= ne00_h; int i1 = i % ne1; i /= ne1; int i2 = i % ne2; @@ -511,22 +501,19 @@ void bestla_device_rope_f32(const struct ne_compute_params* params, const struct int i3 = i % ne3; const int64_t p = n_past + i2; - float theta_base = (float)p; - for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - float cos_theta, sin_theta; - rope_yarn(theta_base, 
freq_scale, corr_dims0, corr_dims1, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - - theta_base *= theta_scale; + float theta_base = (float)p * sycl::pow(theta_scale, (float)i0); + i0 *= 2; + float cos_theta, sin_theta; + rope_yarn(theta_base, freq_scale, corr_dims0, corr_dims1, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - const float* const src = (float*)((char*)src0ptr + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00); - float* dst_data = (float*)((char*)dstptr + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0); + const float* const src = (float*)((char*)src0ptr + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00); + float* dst_data = (float*)((char*)dstptr + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0); - const float x0 = src[0]; - const float x1 = src[1]; + const float x0 = src[0]; + const float x1 = src[1]; - dst_data[0] = x0 * cos_theta - x1 * sin_theta; - dst_data[1] = x0 * sin_theta + x1 * cos_theta; - } + dst_data[0] = x0 * cos_theta - x1 * sin_theta; + dst_data[1] = x0 * sin_theta + x1 * cos_theta; }); }); if (sycl_device::SyclDevice::is_cpu(q)) { @@ -564,9 +551,9 @@ void bestla_device_dup_f32(const struct ne_compute_params* params, const struct auto srcptr = (float*)src0->data; auto dstptr = (float*)dst->data; auto dtype = dst->type; - sycl::range<1> num_items{ne0 * ne1 * ne2 * ne3}; + sycl::range<1> num_items{static_cast(ne0) * ne1 * ne2 * ne3}; auto ev = q->submit([&](sycl::handler& cgh) { - cgh.parallel_for(num_items, [=](auto it) { + cgh.parallel_for(num_items, [=](auto it) [[intel::reqd_sub_group_size(16)]] { int i = it; int i0 = i % ne0; i /= ne0; @@ -644,10 +631,7 @@ class MHA { tmp *= attn_scale; } T tmp_sum = tmp[0] + tmp[1]; - T sum = 0; - for (int i = 0; i < SgSize; i += 1) { - sum += sg.shuffle(tmp_sum, i); - } + T sum = sycl::reduce_over_group(sg, tmp_sum, sycl::plus()); slm[jj] = sum; maxs = std::max(maxs, sum); } @@ -656,24 +640,21 @@ class MHA { int jj = wg_loc_id * 2; for (; jj < seq_acc_pad; jj += WgSize * 2) { auto s2 = *(TC*)&slm[jj]; - s2[0] = std::expf(s2[0] - fmax); - s2[1] = std::expf(s2[1] - fmax); + s2[0] = std::exp(s2[0] - fmax); + s2[1] = std::exp(s2[1] - fmax); fsums += s2[0]; fsums += s2[1]; *(TC*)&slm[jj] = s2; } if (jj < seq_acc) { - slm[jj] = std::expf(float(slm[jj]) - fmax); + slm[jj] = std::exp(float(slm[jj]) - fmax); fsums += slm[jj]; if (jj + 1 < seq_acc) { - slm[jj + 1] = std::expf(float(slm[jj + 1]) - fmax); + slm[jj + 1] = std::exp(float(slm[jj + 1]) - fmax); fsums += slm[jj + 1]; } } - float gsum = 0; - for (int i = 0; i < SgSize; i += 1) { - gsum += sg.shuffle(fsums, i); - } + float gsum = sycl::reduce_over_group(sg, fsums, sycl::plus()); T scale = 1.f / gsum; jj = wg_loc_id * 2; for (; jj < seq_acc_pad; jj += WgSize * 2) { @@ -703,10 +684,7 @@ class MHA { } } T tmp_sum = tmp[0] + tmp[1]; - T sum = 0; - for (int i = 0; i < SgSize; i += 1) { - sum += sg.shuffle(tmp_sum, i); - } + T sum = sycl::reduce_over_group(sg, tmp_sum, sycl::plus()); O[O_off + kk] = sum; } }); @@ -774,11 +752,7 @@ class MHA { } tmp *= attn_scale; } - T sum = 0; -#pragma unroll - for (int i = 0; i < SgSize; i += 1) { - sum += sg.shuffle(tmp, i); - } + T sum = sycl::reduce_over_group(sg, tmp, sycl::plus()); slm[jj] = sum; maxs = std::max(maxs, sum); } @@ -787,20 +761,16 @@ class MHA { int jj = wg_loc_id; for (; jj < seq_acc_pad; jj += SgSize) { auto s = slm[jj]; - s = std::expf(s - fmax); + s = std::exp(s - fmax); fsums += s; slm[jj] = s; } if (jj < seq_acc) { - auto s = std::expf(float(slm[jj]) - fmax); + auto s = std::exp(float(slm[jj]) - fmax); fsums += s; 
slm[jj] = s; } - float gsum = 0; -#pragma unroll - for (int i = 0; i < SgSize; i += 1) { - gsum += sg.shuffle(fsums, i); - } + auto gsum = sycl::reduce_over_group(sg, fsums, sycl::plus()); T scale = 1.f / gsum; jj = wg_loc_id; for (; jj < seq_acc_pad; jj += WgSize) { diff --git a/neural_speed/core/ne_layers.c b/neural_speed/core/ne_layers.cpp similarity index 96% rename from neural_speed/core/ne_layers.c rename to neural_speed/core/ne_layers.cpp index 278ec50c7..40052c48d 100644 --- a/neural_speed/core/ne_layers.c +++ b/neural_speed/core/ne_layers.cpp @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. // Defines CLOCK_MONOTONIC on Linux -#define _GNU_SOURCE - #include "ne_layers.h" #if defined(_MSC_VER) || defined(__MINGW32__) @@ -30,7 +28,14 @@ #include #include #include - +#include +#ifndef PRId64 +#if defined(_MSC_VER) || defined(__MINGW32__) +#define PRId64 "lld" +#else +#define PRId64 "ld" +#endif +#endif #if defined(_MSC_VER) || defined(__MINGW32__) #include #else @@ -50,24 +55,9 @@ #include "ne.h" #include "ne_bestla.h" -// if C99 - static_assert is noop -// ref: https://stackoverflow.com/a/53923785/4039976 -#ifndef static_assert -#define static_assert(cond, msg) struct global_scope_noop_trick -#endif - #if defined(_WIN32) #include - -typedef volatile LONG atomic_int; -typedef atomic_int atomic_bool; - -static void atomic_store(atomic_int* ptr, LONG val) { InterlockedExchange(ptr, val); } -static LONG atomic_load(atomic_int* ptr) { return InterlockedCompareExchange(ptr, 0, 0); } -static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) { return InterlockedExchangeAdd(ptr, inc); } -static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) { return atomic_fetch_add(ptr, -(dec)); } - typedef HANDLE pthread_t; typedef DWORD thread_ret_t; @@ -91,8 +81,6 @@ static int sched_yield(void) { } #else #include -#include - typedef void* thread_ret_t; #endif @@ -254,78 +242,76 @@ int64_t ne_cycles_per_ms(void) { return CLOCKS_PER_SEC / 1000; } // cache line // -#if defined(__cpp_lib_hardware_interference_size) -#define CACHE_LINE_SIZE hardware_destructive_interference_size -#else #define CACHE_LINE_SIZE 64 -#endif static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE / sizeof(float); static const quantize_fns_t quantize_fns[NE_TYPE_COUNT] = { - [NE_TYPE_Q4_0] = - { - .dequantize_row_q = (dequantize_row_q_t)dequantize_row_q4_0, - .quantize_row_q = quantize_row_q4_0, - .quantize_row_q_reference = (quantize_row_q_t)quantize_row_q4_0_reference, - .quantize_row_q_dot = quantize_row_q8_0, - .vec_dot_q = ne_vec_dot_q4_0_q8_0, - .vec_dot_type = NE_TYPE_Q8_0, - }, - [NE_TYPE_Q4_1] = - { - .dequantize_row_q = (dequantize_row_q_t)dequantize_row_q4_1, - .quantize_row_q = quantize_row_q4_1, - .quantize_row_q_reference = (quantize_row_q_t)quantize_row_q4_1_reference, - .quantize_row_q_dot = quantize_row_q8_1, - .vec_dot_q = ne_vec_dot_q4_1_q8_1, - .vec_dot_type = NE_TYPE_Q8_1, - }, - [NE_TYPE_Q5_0] = - { - .dequantize_row_q = (dequantize_row_q_t)dequantize_row_q5_0, - .quantize_row_q = quantize_row_q5_0, - .quantize_row_q_reference = (quantize_row_q_t)quantize_row_q5_0_reference, - .quantize_row_q_dot = quantize_row_q8_0, - .vec_dot_q = ne_vec_dot_q5_0_q8_0, - .vec_dot_type = NE_TYPE_Q8_0, - }, - [NE_TYPE_Q5_1] = - { - .dequantize_row_q = (dequantize_row_q_t)dequantize_row_q5_1, - .quantize_row_q = quantize_row_q5_1, - .quantize_row_q_reference = (quantize_row_q_t)quantize_row_q5_1_reference, - .quantize_row_q_dot = quantize_row_q8_1, - 
.vec_dot_q = ne_vec_dot_q5_1_q8_1, - .vec_dot_type = NE_TYPE_Q8_1, - }, - [NE_TYPE_Q8_0] = - { - .dequantize_row_q = dequantize_row_q8_0, - .quantize_row_q = quantize_row_q8_0, - .quantize_row_q_reference = (quantize_row_q_t)quantize_row_q8_0_reference, - .quantize_row_q_dot = quantize_row_q8_0, - .vec_dot_q = ne_vec_dot_q8_0_q8_0, - .vec_dot_type = NE_TYPE_Q8_0, - }, - [NE_TYPE_Q8_1] = - { - .dequantize_row_q = NULL, // TODO - .quantize_row_q = quantize_row_q8_1, - .quantize_row_q_reference = (quantize_row_q_t)quantize_row_q8_1_reference, - .quantize_row_q_dot = quantize_row_q8_1, - .vec_dot_q = NULL, // TODO - .vec_dot_type = NE_TYPE_Q8_1, - }, - [NE_TYPE_Q6_K] = - { - .dequantize_row_q = (dequantize_row_q_t)dequantize_row_q6_K, // TODO - .quantize_row_q = quantize_row_q6_K, - .quantize_row_q_reference = (quantize_row_q_t)quantize_row_q6_K_reference, - .quantize_row_q_dot = quantize_row_q8_K, - .vec_dot_q = ggml_vec_dot_q6_K_q8_K, // TODO - .vec_dot_type = NE_TYPE_Q8_K, - }, + {}, // NE_TYPE_F32 = 0 + {}, // NE_TYPE_F16 = 1 + { + (dequantize_row_q_t)dequantize_row_q4_0, //.dequantize_row_q = + quantize_row_q4_0, //.quantize_row_q = + (quantize_row_q_t)quantize_row_q4_0_reference, //.quantize_row_q_reference = + quantize_row_q8_0, //.quantize_row_q_dot = + ne_vec_dot_q4_0_q8_0, //.vec_dot_q = + NE_TYPE_Q8_0, //.vec_dot_type = + }, // NE_TYPE_Q4_0 = 2, + { + (dequantize_row_q_t)dequantize_row_q4_1, //.dequantize_row_q = + quantize_row_q4_1, //.quantize_row_q = + (quantize_row_q_t)quantize_row_q4_1_reference, //.quantize_row_q_reference = + quantize_row_q8_1, //.quantize_row_q_dot = + ne_vec_dot_q4_1_q8_1, //.vec_dot_q = + NE_TYPE_Q8_1, //.vec_dot_type = + }, // NE_TYPE_Q4_1 =3 + {}, // NE_TYPE_Q4_2 = 4, support has been removed + {}, // NE_TYPE_Q4_3 (5) support has been removed + { + (dequantize_row_q_t)dequantize_row_q5_0, //.dequantize_row_q = + quantize_row_q5_0, //.quantize_row_q = + (quantize_row_q_t)quantize_row_q5_0_reference, //.quantize_row_q_reference = + quantize_row_q8_0, //.quantize_row_q_dot = + ne_vec_dot_q5_0_q8_0, //.vec_dot_q = + NE_TYPE_Q8_0, //.vec_dot_type = + }, // NE_TYPE_Q5_0 = 6, + + { + (dequantize_row_q_t)dequantize_row_q5_1, //.dequantize_row_q = + quantize_row_q5_1, //.quantize_row_q = + (quantize_row_q_t)quantize_row_q5_1_reference, //.quantize_row_q_reference = + quantize_row_q8_1, //.quantize_row_q_dot = + ne_vec_dot_q5_1_q8_1, //.vec_dot_q = + NE_TYPE_Q8_1, //.vec_dot_type = + }, // NE_TYPE_Q5_1 = 7, + { + dequantize_row_q8_0, //.dequantize_row_q = + quantize_row_q8_0, //.quantize_row_q = + (quantize_row_q_t)quantize_row_q8_0_reference, //.quantize_row_q_reference = + quantize_row_q8_0, //.quantize_row_q_dot = + ne_vec_dot_q8_0_q8_0, //.vec_dot_q = + NE_TYPE_Q8_0, //.vec_dot_type = + }, // NE_TYPE_Q8_0 = 8, + { + NULL, // .dequantize_row_q = + quantize_row_q8_1, //.quantize_row_q = + (quantize_row_q_t)quantize_row_q8_1_reference, //.quantize_row_q_reference = + quantize_row_q8_1, //.quantize_row_q_dot = + NULL, // .vec_dot_q = + NE_TYPE_Q8_1, //.vec_dot_type = + }, // NE_TYPE_Q8_1 = 9, + {}, + {}, + {}, + {}, + { + (dequantize_row_q_t)dequantize_row_q6_K, // .dequantize_row_q = + quantize_row_q6_K, //.quantize_row_q = + (quantize_row_q_t)quantize_row_q6_K_reference, //.quantize_row_q_reference = + quantize_row_q8_K, //.quantize_row_q_dot = + ggml_vec_dot_q6_K_q8_K, // .vec_dot_q = + NE_TYPE_Q8_K, //.vec_dot_type = + }, // NE_TYPE_Q6_K = 14, }; // For internal test use @@ -338,35 +324,81 @@ quantize_fns_t ne_internal_get_quantize_fn(size_t i) { // data types 
// -static const int NE_BLCK_SIZE[NE_TYPE_COUNT] = { - [NE_TYPE_F32] = 1, [NE_TYPE_F16] = 1, [NE_TYPE_Q4_0] = QK4_0, [NE_TYPE_Q4_1] = QK4_1, - [NE_TYPE_Q5_0] = QK5_0, [NE_TYPE_Q5_1] = QK5_1, [NE_TYPE_Q8_0] = QK8_0, [NE_TYPE_Q8_1] = QK8_1, - [NE_TYPE_Q6_K] = QK_K, [NE_TYPE_Q8_K] = QK_K, [NE_TYPE_I8] = 1, [NE_TYPE_I16] = 1, - [NE_TYPE_I32] = 1, -}; +static const int NE_BLCK_SIZE[NE_TYPE_COUNT] = {1, // NE_TYPE_F32=0 + 1, // NE_TYPE_F16=1 + QK4_0, // NE_TYPE_Q4_0=2 + QK4_1, // NE_TYPE_Q4_1=3 + 0, 0, + QK5_0, //[NE_TYPE_Q5_0] = + QK5_1, //[NE_TYPE_Q5_1] = + QK8_0, //[NE_TYPE_Q8_0] = + QK8_1, //[NE_TYPE_Q8_1] = + 0, 0, 0, 0, + QK_K, //[NE_TYPE_Q6_K] = + QK_K, //[NE_TYPE_Q8_K] = + 1, //[NE_TYPE_I8] = + 1, //[NE_TYPE_I16] = + 1, //[NE_TYPE_I32] = + -1}; static_assert(NE_TYPE_COUNT == 20, "NE_BLCK_SIZE is outdated"); -static const size_t NE_TYPE_SIZE[NE_TYPE_COUNT] = { - [NE_TYPE_F32] = sizeof(float), [NE_TYPE_F16] = sizeof(ne_fp16_t), [NE_TYPE_Q4_0] = sizeof(block_q4_0), - [NE_TYPE_Q4_1] = sizeof(block_q4_1), [NE_TYPE_Q5_0] = sizeof(block_q5_0), [NE_TYPE_Q5_1] = sizeof(block_q5_1), - [NE_TYPE_Q8_0] = sizeof(block_q8_0), [NE_TYPE_Q8_1] = sizeof(block_q8_1), [NE_TYPE_Q6_K] = sizeof(block_q6_K), - [NE_TYPE_Q8_K] = sizeof(block_q8_K), [NE_TYPE_I8] = sizeof(int8_t), [NE_TYPE_I16] = sizeof(int16_t), - [NE_TYPE_I32] = sizeof(int32_t), -}; +static const size_t NE_TYPE_SIZE[NE_TYPE_COUNT] = {sizeof(float), //[NE_TYPE_F32] = + sizeof(ne_fp16_t), //[NE_TYPE_F16] = + sizeof(block_q4_0), //[NE_TYPE_Q4_0] = + sizeof(block_q4_1), //[NE_TYPE_Q4_1] = + 0, + 0, + sizeof(block_q5_0), //[NE_TYPE_Q5_0] = + sizeof(block_q5_1), //[NE_TYPE_Q5_1] = + sizeof(block_q8_0), //[NE_TYPE_Q8_0] = + sizeof(block_q8_1), //[NE_TYPE_Q8_1] = + 0, + 0, + 0, + 0, + sizeof(block_q6_K), //[NE_TYPE_Q6_K] = + sizeof(block_q8_K), //[NE_TYPE_Q8_K] = + sizeof(int8_t), //[NE_TYPE_I8] = + sizeof(int16_t), //[NE_TYPE_I16] = + sizeof(int32_t), //[NE_TYPE_I32] = + 1}; static_assert(NE_TYPE_COUNT == 20, "NE_TYPE_SIZE is outdated"); -static const char* NE_TYPE_NAME[NE_TYPE_COUNT] = { - [NE_TYPE_F32] = "f32", [NE_TYPE_F16] = "f16", [NE_TYPE_Q4_0] = "q4_0", [NE_TYPE_Q4_1] = "q4_1", - [NE_TYPE_Q5_0] = "q5_0", [NE_TYPE_Q5_1] = "q5_1", [NE_TYPE_Q8_0] = "q8_0", [NE_TYPE_Q8_1] = "q8_1", - [NE_TYPE_Q6_K] = "q6_k", [NE_TYPE_Q8_K] = "q8_k", [NE_TYPE_I8] = "i8", [NE_TYPE_I16] = "i16", - [NE_TYPE_I32] = "i32", -}; +static const char* NE_TYPE_NAME[NE_TYPE_COUNT] = {"f32", //[NE_TYPE_F32] = + "f16", //[NE_TYPE_F16] = + "q4_0", //[NE_TYPE_Q4_0] = + "q4_1", //[NE_TYPE_Q4_1] = + "", "", + "q5_0", //[NE_TYPE_Q5_0] = + "q5_1", //[NE_TYPE_Q5_1] = + "q8_0", //[NE_TYPE_Q8_0] = + "q8_1", //[NE_TYPE_Q8_1] = + "", "", "", "", + "q6_k", //[NE_TYPE_Q6_K] = + "q8_k", //[NE_TYPE_Q8_K] = + "i8", //[NE_TYPE_I8] = + "i16", //[NE_TYPE_I16] = + "i32", //[NE_TYPE_I32] = + "bestla"}; static_assert(NE_TYPE_COUNT == 20, "NE_TYPE_NAME is outdated"); static bool NE_IS_QUANTIZED[NE_TYPE_COUNT] = { - [NE_TYPE_F32] = false, [NE_TYPE_F16] = false, [NE_TYPE_Q4_0] = true, [NE_TYPE_Q4_1] = true, [NE_TYPE_Q5_0] = true, - [NE_TYPE_Q5_1] = true, [NE_TYPE_Q8_0] = true, [NE_TYPE_Q8_1] = true, [NE_TYPE_Q6_K] = true, [NE_TYPE_Q6_K] = true, - [NE_TYPE_I8] = false, [NE_TYPE_I16] = false, [NE_TYPE_I32] = false, [NE_TYPE_BTLA] = true, + false, // [NE_TYPE_F32] = + false, //[NE_TYPE_F16] = + true, //[NE_TYPE_Q4_0] = + true, //[NE_TYPE_Q4_1] = + false, false, + true, //[NE_TYPE_Q5_0] = + true, //[NE_TYPE_Q5_1] = + true, //[NE_TYPE_Q8_0] = + true, //[NE_TYPE_Q8_1] = + false, false, false, false, + true, 
//[NE_TYPE_Q6_K] = + true, //[NE_TYPE_Q6_K] = + false, //[NE_TYPE_I8] = + false, //[NE_TYPE_I16] = + false, //[NE_TYPE_I32] = + true, //[NE_TYPE_BTLA] = }; static_assert(NE_TYPE_COUNT == 20, "NE_IS_QUANTIZED is outdated"); @@ -530,23 +562,23 @@ struct ne_state { // global state static struct ne_state g_state; -static atomic_int g_state_barrier = 0; +static std::atomic g_state_barrier = 0; // barrier via spin lock inline static void ne_critical_section_start(void) { - int processing = atomic_fetch_add(&g_state_barrier, 1); + int processing = g_state_barrier.fetch_add(1); while (processing > 0) { // wait for other threads to finish - atomic_fetch_sub(&g_state_barrier, 1); + g_state_barrier.fetch_sub(1); sched_yield(); // TODO: reconsider this - processing = atomic_fetch_add(&g_state_barrier, 1); + processing = g_state_barrier.fetch_add(1); } } // TODO: make this somehow automatically executed // some sort of "sentry" mechanism -inline static void ne_critical_section_end(void) { atomic_fetch_sub(&g_state_barrier, 1); } +inline static void ne_critical_section_end(void) { g_state_barrier.fetch_sub(1); } //////////////////////////////////////////////////////////////////////////////// @@ -756,7 +788,7 @@ struct ne_context* ne_init(struct ne_init_params params) { const uint64_t t_start = ne_time_us(); UNUSED(t_start); - g_state = (struct ne_state){ + g_state = ne_state{ /*.contexts =*/{{0}}, }; @@ -796,26 +828,25 @@ struct ne_context* ne_init(struct ne_init_params params) { const size_t mem_size = (params.mem_size + NE_MEM_ALIGN - 1) & ~(NE_MEM_ALIGN - 1); - *ctx = - (struct ne_context){/*.mem_size =*/mem_size, - /*.mem_buffer =*/params.mem_buffer ? params.mem_buffer : NE_ALIGNED_MALLOC(mem_size), - /*.mem_buffer_owned =*/params.mem_buffer ? false : true, - /*.no_alloc =*/params.no_alloc, - /*.n_objects =*/0, - /*.objects_begin =*/NULL, - /*.objects_end =*/NULL, - /*.scratch =*/ - { - 0, - 0, - NULL, - }, - /*.scratch_save =*/ - { - 0, - 0, - NULL, - }}; + *ctx = ne_context{/*.mem_size =*/mem_size, + /*.mem_buffer =*/params.mem_buffer ? params.mem_buffer : NE_ALIGNED_MALLOC(mem_size), + /*.mem_buffer_owned =*/params.mem_buffer ? 
false : true, + /*.no_alloc =*/params.no_alloc, + /*.n_objects =*/0, + /*.objects_begin =*/NULL, + /*.objects_end =*/NULL, + /*.scratch =*/ + { + 0, + 0, + NULL, + }, + /*.scratch_save =*/ + { + 0, + 0, + NULL, + }}; NE_ASSERT(ctx->mem_buffer != NULL); @@ -957,10 +988,10 @@ struct ne_tensor* ne_new_device_tensor_impl(struct ne_context* ctx, enum ne_type return NULL; } - *obj_new = (struct ne_object){ - .offs = cur_end + NE_OBJECT_SIZE, - .size = obj_size, - .next = NULL, + *obj_new = ne_object{ + cur_end + NE_OBJECT_SIZE, + obj_size, + NULL, }; } else { if (cur_end + sizeof(struct ne_tensor) + NE_OBJECT_SIZE > ctx->mem_size) { @@ -969,10 +1000,10 @@ struct ne_tensor* ne_new_device_tensor_impl(struct ne_context* ctx, enum ne_type assert(false); return NULL; } - *obj_new = (struct ne_object){ - .offs = cur_end + NE_OBJECT_SIZE, - .size = sizeof(struct ne_tensor), - .next = NULL, + *obj_new = ne_object{ + cur_end + NE_OBJECT_SIZE, + sizeof(struct ne_tensor), + NULL, }; if (type == NE_TYPE_BTLA) { if (ctx->scratch.offs + bestla_device_storage_size() > ctx->scratch.size) { @@ -997,27 +1028,27 @@ struct ne_tensor* ne_new_device_tensor_impl(struct ne_context* ctx, enum ne_type struct ne_tensor* const result = (struct ne_tensor*)(mem_buffer + obj_new->offs); - *result = (struct ne_tensor){ - .type = type, - .backend = NE_BACKEND_SYCL, - .n_dims = n_dims, - .ne = {1, 1, 1, 1}, - .nb = {0, 0, 0, 0}, - .op = NE_OP_NONE, - .is_param = false, - .op_params = {0}, - .grad = NULL, - .src0 = NULL, - .src1 = NULL, - .opt = {NULL}, - .n_tasks = 0, - .perf_runs = 0, - .perf_cycles = 0, - .perf_time_us = 0, - .data = NULL, - .size = size_needed, - .name = {0}, - .padding = {0}, + *result = ne_tensor{ + type, + NE_BACKEND_SYCL, + n_dims, + {1, 1, 1, 1}, + {0, 0, 0, 0}, + NE_OP_NONE, + false, + {0}, + NULL, + NULL, + NULL, + {NULL}, + 0, + 0, + 0, + 0, + NULL, + size_needed, + {0}, + {0}, }; if (type == NE_TYPE_BTLA) { result->data = (void*)(result + 1); @@ -1100,10 +1131,10 @@ struct ne_tensor* ne_new_tensor_impl(struct ne_context* ctx, enum ne_type type, return NULL; } - *obj_new = (struct ne_object){ - .offs = cur_end + NE_OBJECT_SIZE, - .size = size_needed, - .next = NULL, + *obj_new = ne_object{ + cur_end + NE_OBJECT_SIZE, + size_needed, + NULL, }; } else { if (ctx->scratch.offs + size_needed > ctx->scratch.size) { @@ -1124,10 +1155,10 @@ struct ne_tensor* ne_new_tensor_impl(struct ne_context* ctx, enum ne_type type, data = (char* const)ctx->scratch.data + ctx->scratch.offs; - *obj_new = (struct ne_object){ - .offs = cur_end + NE_OBJECT_SIZE, - .size = sizeof(struct ne_tensor), - .next = NULL, + *obj_new = ne_object{ + cur_end + NE_OBJECT_SIZE, + sizeof(struct ne_tensor), + NULL, }; ctx->scratch.offs += size_needed; @@ -1144,27 +1175,27 @@ struct ne_tensor* ne_new_tensor_impl(struct ne_context* ctx, enum ne_type type, struct ne_tensor* const result = (struct ne_tensor*)(mem_buffer + obj_new->offs); - *result = (struct ne_tensor){ - .type = type, - .backend = NE_BACKEND_CPU, - .n_dims = n_dims, - .ne = {1, 1, 1, 1}, - .nb = {0, 0, 0, 0}, - .op = NE_OP_NONE, - .is_param = false, - .op_params = {0}, - .grad = NULL, - .src0 = NULL, - .src1 = NULL, - .opt = {NULL}, - .n_tasks = 0, - .perf_runs = 0, - .perf_cycles = 0, - .perf_time_us = 0, - .data = (data == NULL && !ctx->no_alloc) ? (void*)(result + 1) : data, - .size = data ? 
size : size_needed, - .name = {0}, - .padding = {0}, + *result = ne_tensor{ + type, + NE_BACKEND_CPU, + n_dims, + {1, 1, 1, 1}, + {0, 0, 0, 0}, + NE_OP_NONE, + false, + {0}, + NULL, + NULL, + NULL, + {NULL}, + 0, + 0, + 0, + 0, + (data == NULL && !ctx->no_alloc) ? (void*)(result + 1) : data, + data ? size : size_needed, + {0}, + {0}, }; for (int i = 0; i < n_dims; i++) { @@ -1395,7 +1426,7 @@ void ne_set_i32_1d(const struct ne_tensor* tensor, int i, int32_t value) { } break; case NE_TYPE_F16: { NE_ASSERT(tensor->nb[0] == sizeof(ne_fp16_t)); - ((ne_fp16_t*)(tensor->data))[i] = NE_FP32_TO_FP16(value); + ((ne_fp16_t*)(tensor->data))[i] = NE_FP32_TO_FP16(float(value)); } break; case NE_TYPE_F32: { NE_ASSERT(tensor->nb[0] == sizeof(float)); @@ -1547,7 +1578,7 @@ struct ne_tensor* ne_debug_op(struct ne_context* ctx, struct ne_tensor* a, ne_de result->op = NE_OP_DEBUG; result->src0 = a; static_assert(sizeof(void*) <= sizeof(result->padding), "No enough space for function ptr!"); - *((void**)(result->padding)) = cb; + *(ne_debug_callback_t*)(result->padding) = cb; return result; } @@ -2340,7 +2371,6 @@ struct ne_tensor* ne_rms_norm_back(struct ne_context* ctx, struct ne_tensor* a, struct ne_tensor* ne_mul_mat(struct ne_context* ctx, struct ne_tensor* a, struct ne_tensor* b) { NE_ASSERT(ne_can_mul_mat(a, b)); - NE_ASSERT(!ne_is_transposed(a)); bool is_node = false; @@ -2362,7 +2392,6 @@ struct ne_tensor* ne_mul_mat(struct ne_context* ctx, struct ne_tensor* a, struct struct ne_tensor* ne_mul_mat_with_bias(struct ne_context* ctx, struct ne_tensor* w, struct ne_tensor* b, struct ne_tensor* a) { NE_ASSERT(ne_can_mul_mat(w, a)); - NE_ASSERT(!ne_is_transposed(w)); bool is_node = false; @@ -2410,7 +2439,6 @@ struct ne_tensor* ne_mul_mat_id(struct ne_context* ctx, struct ne_tensor* const struct ne_tensor* a = as[i]; NE_ASSERT(ne_are_same_shape(as[0], a)); NE_ASSERT(ne_can_mul_mat(a, b)); - NE_ASSERT(!ne_is_transposed(a)); result->opt[i] = a; } return result; @@ -10174,41 +10202,41 @@ static void ne_compute_forward_flash_attn_f32_f16_f16(const struct ne_compute_pa float scale = *(float*)dst->padding; ne_attn_flags_t flags = *(bool*)&dst->padding[sizeof(scale)]; attn_fp32_fp16_fp16_fp32_fwd_args_t args = { - .Q = (float*)q->data, - .K = (ne_fp16_t*)k->data, - .V = (ne_fp16_t*)v->data, - .dst = (float*)dst->data, - .Q_sc = 1.f, - .K_sc = 1.f, - .V_sc = 1.f, - .dst_sc = 1.f, - .tmp = (char*)tmp->data, - .QK_scale = scale, - .attn_flags = flags, - .batch_size = batch, - .head_num = headnum, - .heads_kv = heads_kv, - .head_size = headsize, - .sl_q = seq_cur, - .sl_kv = seq_all, - .Q_layout = ATTN_FWD_LAYOUT_PLAIN, - .K_layout = ATTN_FWD_LAYOUT_PLAIN, - .V_layout = ATTN_FWD_LAYOUT_PLAIN, - .dst_layout = ATTN_FWD_LAYOUT_PLAIN, - .step_q_bs = seq_cur * embedsize, - .step_q_head_num = headsize, - .step_q_sl = embedsize, - .step_k_bs = step_k_bs, - .step_k_head_num = step_k_head_num, - .step_k_sl = step_k_sl, - .step_k_head_size = step_k_head_size, // TODO - .step_v_bs = step_v_bs, - .step_v_head_num = step_v_head_num, - .step_v_sl = step_v_sl, - .step_v_head_size = 1, - .step_dst_bs = seq_cur * embedsize, - .step_dst_head_num = headsize, - .step_dst_sl = embedsize, + (float*)q->data, + (ne_fp16_t*)k->data, + (ne_fp16_t*)v->data, + (float*)dst->data, + 1.f, + 1.f, + 1.f, + 1.f, + (char*)tmp->data, + scale, + flags, + static_cast(batch), + static_cast(headnum), + static_cast(heads_kv), + static_cast(headsize), + static_cast(seq_cur), + static_cast(seq_all), + ATTN_FWD_LAYOUT_PLAIN, + ATTN_FWD_LAYOUT_PLAIN, + 
ATTN_FWD_LAYOUT_PLAIN, + ATTN_FWD_LAYOUT_PLAIN, + static_cast(seq_cur * embedsize), + static_cast(headsize), + static_cast(embedsize), + step_k_bs, + step_k_head_num, + step_k_sl, + step_k_head_size, // TODO + step_v_bs, + step_v_head_num, + step_v_sl, + 1, + static_cast(seq_cur * embedsize), + static_cast(headsize), + static_cast(embedsize), }; bestla_fusion_attn_fp32_fp16_fp16_fp32_forward(&args); } @@ -10236,45 +10264,45 @@ static void ne_compute_forward_flash_attn_reordered(const struct ne_compute_para ATTN_FWD_LAYOUT V_layout = *(ATTN_FWD_LAYOUT*)(&v->nb[0]); bestla_reordered_attn_fp32_fp32_fwd_args_t args = { - .Q = (float*)q->data, - .K = (char*)k->data, - .V = (char*)v->data, - .dst = (float*)dst->data, - .Q_sc = 1.f, - .K_sc = 1.f, - .V_sc = 1.f, - .dst_sc = 1.f, - .tmp = (char*)tmp->data, - .QK_scale = scale, - .attn_flags = flags, - .batch_size = batch, - .head_num = headnum, - .heads_kv = heads_kv, - .head_size = headsize, - .sl_q = seq_cur, - .sl_kv = seq_all, - .Q_layout = ATTN_FWD_LAYOUT_PLAIN, - .K_layout = K_layout, - .V_layout = V_layout, - .dst_layout = ATTN_FWD_LAYOUT_PLAIN, - .step_q_bs = q->nb[3] / q_ele_size, - .step_q_head_num = q->nb[2] / q_ele_size, - .step_q_sl = q->nb[1] / q_ele_size, - - .stride_k_bs = k->nb[3], - .stride_k_head_num = k->nb[2], - .stride_k_sl = k->nb[1], - .stride_k_head_size = 0, - - .stride_v_bs = v->nb[3], - .stride_v_head_num = v->nb[2], - .stride_v_sl = 0, - .stride_v_head_size = v->nb[1], + (float*)q->data, + (char*)k->data, + (char*)v->data, + (float*)dst->data, + 1.f, + 1.f, + 1.f, + 1.f, + (char*)tmp->data, + scale, + flags, + static_cast(batch), + static_cast(headnum), + static_cast(heads_kv), + static_cast(headsize), + static_cast(seq_cur), + static_cast(seq_all), + ATTN_FWD_LAYOUT_PLAIN, + K_layout, + V_layout, + ATTN_FWD_LAYOUT_PLAIN, + static_cast(q->nb[3] / q_ele_size), + static_cast(q->nb[2] / q_ele_size), + static_cast(q->nb[1] / q_ele_size), + + static_cast(k->nb[3]), + static_cast(k->nb[2]), + static_cast(k->nb[1]), + 0, + + static_cast(v->nb[3]), + static_cast(v->nb[2]), + 0, + static_cast(v->nb[1]), // dst in (head_size, n_head, seq, bs) - .step_dst_bs = dst->nb[3] / dst_ele_size, - .step_dst_head_num = dst->nb[1] / dst_ele_size, - .step_dst_sl = dst->nb[2] / dst_ele_size, + static_cast(dst->nb[3] / dst_ele_size), + static_cast(dst->nb[1] / dst_ele_size), + static_cast(dst->nb[2] / dst_ele_size), }; bestla_reordered_attn_fp32_forward(&args); } @@ -10537,19 +10565,19 @@ static void ne_compute_forward_flash_attn_kv_update(const struct ne_compute_para const bool no_zeroing = (bool)p_data[2]; NE_ASSERT(cur->type == NE_TYPE_F32); bestla_fusion_attn_fp32_update_kv_args_t args = { - .src = (float*)cur->data, - .cache = (char*)cache->data, - .batch_size = cur->ne[3], - .heads_kv = cur->ne[1], - .head_size = cur->ne[0], - .seq_off = n_past, - .seq_size = cur->ne[2], - .seq_max = cache->ne[1], - .step_bs = cur->nb[3] / NE_TYPE_SIZE[cur->type], - .step_head_num = cur->nb[1] / NE_TYPE_SIZE[cur->type], - .step_seq = cur->nb[2] / NE_TYPE_SIZE[cur->type], - .step_head_size = cur->nb[0] / NE_TYPE_SIZE[cur->type], - .no_zeroing = no_zeroing, + (float*)cur->data, + (char*)cache->data, + static_cast(cur->ne[3]), + (int)cur->ne[1], + (int)cur->ne[0], + n_past, + (int)cur->ne[2], + (int)cache->ne[1], + (int)(cur->nb[3] / NE_TYPE_SIZE[cur->type]), + (int)(cur->nb[1] / NE_TYPE_SIZE[cur->type]), + (int)(cur->nb[2] / NE_TYPE_SIZE[cur->type]), + (int)(cur->nb[0] / NE_TYPE_SIZE[cur->type]), + no_zeroing, }; if (is_v) 
bestla_reordered_attn_fp32_update_v(&args); @@ -12806,55 +12834,53 @@ struct ne_opt_params ne_opt_default_params(enum ne_opt_type type) { switch (type) { case NE_OPT_ADAM: { - result = (struct ne_opt_params){ - .type = NE_OPT_ADAM, - .n_threads = 1, - .past = 0, - .delta = 1e-5f, - - .max_no_improvement = 100, - - .print_forward_graph = true, - .print_backward_graph = true, - - .adam = - { - .n_iter = 10000, - .alpha = 0.001f, - .beta1 = 0.9f, - .beta2 = 0.999f, - .eps = 1e-8f, - .eps_f = 1e-5f, - .eps_g = 1e-3f, - }, - }; + result = ne_opt_params{NE_OPT_ADAM, //.type = + 1, //.n_threads = + 0, //.past = + 1e-5f, //.delta = + + 100, //.max_no_improvement = + + true, //.print_forward_graph = + true, //.print_backward_graph = + + { + 10000, //.n_iter = + 0.001f, //.alpha = + 0.9f, //.beta1 = + 0.999f, //.beta2 = + 1e-8f, //.eps = + 1e-5f, //.eps_f = + 1e-3f, //.eps_g = + }, //.adam = + {}}; } break; case NE_OPT_LBFGS: { - result = (struct ne_opt_params){ - .type = NE_OPT_LBFGS, - .n_threads = 1, - .past = 0, - .delta = 1e-5f, - - .max_no_improvement = 0, - - .print_forward_graph = true, - .print_backward_graph = true, - - .lbfgs = - { - .m = 6, - .n_iter = 100, - .max_linesearch = 20, - - .eps = 1e-5f, - .ftol = 1e-4f, - .wolfe = 0.9f, - .min_step = 1e-20f, - .max_step = 1e+20f, - - .linesearch = NE_LINESEARCH_DEFAULT, - }, + result = ne_opt_params{ + NE_OPT_LBFGS, //.type = + 1, //.n_threads = + 0, //.past = + 1e-5f, //.delta = + + 0, //.max_no_improvement = + + true, //.print_forward_graph = + true, //.print_backward_graph = + {}, + + { + 6, //.m = + 100, //.n_iter = + 20, //.max_linesearch = + + 1e-5f, //.eps = + 1e-4f, //.ftol = + 0.9f, //.wolfe = + 1e-20f, //.min_step = + 1e+20f, //.max_step = + + NE_LINESEARCH_DEFAULT, //.linesearch = + }, //.lbfgs = }; } break; } @@ -12866,9 +12892,9 @@ enum ne_opt_result ne_opt(struct ne_context* ctx, struct ne_opt_params params, s bool free_ctx = false; if (ctx == NULL) { struct ne_init_params params_ctx = { - .mem_size = 16 * 1024 * 1024, - .mem_buffer = NULL, - .no_alloc = false, + 16 * 1024 * 1024, + NULL, + false, }; ctx = ne_init(params_ctx); diff --git a/neural_speed/models/llama/llama.cpp b/neural_speed/models/llama/llama.cpp index 486c3760e..9ae2643df 100644 --- a/neural_speed/models/llama/llama.cpp +++ b/neural_speed/models/llama/llama.cpp @@ -352,7 +352,6 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp hparams.freq_scale); Kcur = ne_rope_inplace( // n_ctx exceeds but it will be shift-roped back with cached K ctx0, Kcur, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0, hparams.freq_base, hparams.freq_scale); - // Vcur = ne_transpose(ctx0, ne_reshape_2d(ctx0, Vcur, head_size * n_head_kv, N)); } ne_set_name(Qcur, "Qcur"); ne_set_name(Kcur, "Kcur");
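Notes on the changes above (reviewer commentary, not part of the patch). The kernel headers (kernel_avx2.h, kernel_avx512_bf16.h, kernel_avx512_fp16.h, kernel_avx512_vnni.h, kernel_avx512f.h) replace the old `defined(ICX)` check with the compiler's own `__INTEL_LLVM_COMPILER` predefined macro and move it ahead of the `__GNUC__` branch; icx/icpx also defines `__GNUC__` for compatibility, so the order matters. A minimal, self-contained sketch of the resulting guard pattern; the ISA list and `example_kernel()` are illustrative placeholders, not taken from the patch:

```cpp
// Sketch of the per-compiler function-target guard now used in the kernel headers.
#if defined(__INTEL_LLVM_COMPILER)
// icx/icpx is clang-based: attach the target attribute to every function that
// follows, until the matching pop. Checked first because icx also defines __GNUC__.
#pragma clang attribute push(__attribute__((target("avx2,fma,f16c"))), apply_to = function)
#elif defined(__GNUC__)
// Plain GCC: save the current options and compile this region for these ISAs.
#pragma GCC push_options
#pragma GCC target("avx2", "fma", "f16c")
#endif

static inline void example_kernel() {
  // AVX2/FMA/F16C intrinsics would go here.
}

#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#endif
```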
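The `ScaleTrackMax::forward_avx2` / `forward_avx512` changes in kernel_wrapper.h wrap the dispatch body in `#if CompileAVX2()` / `#if CompileAVX512F()` and fall through to a single `return BTLA_CODE::NotSupport;` when the ISA path is compiled out or no branch matches. A small hedged sketch of that shape; `COMPILE_AVX2`, `forward_avx2_example` and the `plain` condition are illustrative stand-ins:

```cpp
// Sketch of the fall-through guard added in kernel_wrapper.h.
enum class Code { Success, NotSupport };
#define COMPILE_AVX2 1  // stand-in for CompileAVX2()

static Code forward_avx2_example(const float* src, float* dst, int n, bool plain) {
#if COMPILE_AVX2
  if (plain) {
    for (int i = 0; i < n; ++i) dst[i] = src[i];  // placeholder for the avx2:: kernel call
    return Code::Success;
  }
#endif
  return Code::NotSupport;  // ISA compiled out, or no matching variant
}
```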
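In ne_bestla_sycl.cpp the hand-written sub-group reductions (`for (int i = 0; i < SgSize; i++) sum += sg.shuffle(tmp, i);`) are replaced with `sycl::reduce_over_group`, letting the backend pick its native reduction. A minimal runnable sketch of the equivalence, assuming a 16-wide sub-group; the kernel body and the `out` buffer are illustrative:

```cpp
#include <sycl/sycl.hpp>
#include <cstdio>

// Sketch: sum one value per lane across a sub-group with reduce_over_group,
// the same primitive the patch substitutes for the manual shuffle loop.
int main() {
  sycl::queue q;
  constexpr int SgSize = 16;
  float* out = sycl::malloc_shared<float>(1, q);
  q.parallel_for(sycl::nd_range<1>(sycl::range<1>(SgSize), sycl::range<1>(SgSize)),
                 [=](sycl::nd_item<1> it) [[intel::reqd_sub_group_size(SgSize)]] {
                   auto sg = it.get_sub_group();
                   float lane_val = float(sg.get_local_id()[0]);  // per-lane partial value
                   // Old form: float sum = 0; for (int i = 0; i < SgSize; ++i) sum += sg.shuffle(lane_val, i);
                   float sum = sycl::reduce_over_group(sg, lane_val, sycl::plus<float>());
                   if (sg.get_local_id()[0] == 0) out[0] = sum;  // expect 0+1+...+15 = 120
                 })
      .wait();
  std::printf("sub-group sum = %f\n", out[0]);
  sycl::free(out, q);
  return 0;
}
```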
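The rewritten rope kernel in ne_bestla_sycl.cpp parallelizes over every (row, dimension-pair) index instead of one work-item per row, so it must recover the rotation angle in closed form: after k pair steps the serial update `theta_base *= theta_scale` equals `p * theta_scale^k`, which is what `sycl::pow(theta_scale, (float)i0)` computes with `i0` still holding the pair index. A tiny host-side sketch checking that identity; `ne0`, `p` and `theta_scale` are made-up values:

```cpp
#include <cmath>
#include <cstdio>

// Sketch: the serial theta update and the closed-form per-pair angle used by
// the rewritten rope kernel agree.
int main() {
  const float theta_scale = 0.9f;
  const float p = 7.0f;  // plays the role of n_past + i2 in the kernel
  const int ne0 = 8;     // head dimension, assumed even

  float serial[ne0 / 2];
  float theta = p;
  for (int i0 = 0; i0 < ne0; i0 += 2) {  // old single-work-item loop
    serial[i0 / 2] = theta;
    theta *= theta_scale;
  }
  for (int k = 0; k < ne0 / 2; ++k) {    // new per-work-item computation
    float direct = p * std::pow(theta_scale, (float)k);
    std::printf("pair %d: serial=%.6f direct=%.6f\n", k, serial[k], direct);
  }
  return 0;
}
```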
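Renaming ne_layers.c to ne_layers.cpp (and bumping the target to cxx_std_17 in neural_speed/core/CMakeLists.txt) lets the file drop the hand-rolled Windows `Interlocked*` atomic shims and express the global barrier as a `std::atomic`, keeping the same fetch_add/fetch_sub spin protocol. A self-contained sketch of that critical-section pattern; `std::this_thread::yield()` stands in here for the `sched_yield()` call in the source, and the `<int>` specialization is assumed:

```cpp
#include <atomic>
#include <thread>

// Sketch of the spin-lock critical section retained by the C++ port of
// ne_layers, now built on std::atomic instead of per-platform shims.
static std::atomic<int> g_state_barrier{0};

static inline void critical_section_start() {
  int processing = g_state_barrier.fetch_add(1);
  while (processing > 0) {                      // another thread holds the section
    g_state_barrier.fetch_sub(1);               // back off ...
    std::this_thread::yield();                  // ... let it finish ...
    processing = g_state_barrier.fetch_add(1);  // ... and try again
  }
}

static inline void critical_section_end() { g_state_barrier.fetch_sub(1); }

int main() {
  critical_section_start();
  // mutate shared global state here
  critical_section_end();
  return 0;
}
```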
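The other mechanical consequence of the C++ port shows up in the `quantize_fns`, `NE_BLCK_SIZE`, `NE_TYPE_SIZE`, `NE_TYPE_NAME` and `NE_IS_QUANTIZED` tables: C99 array designators such as `[NE_TYPE_Q4_0] = ...` are not valid C++, so every table is rewritten positionally, with explicit placeholder rows for unused enum slots and a comment naming the slot each row fills. A toy sketch of the transformation; `example_type` and the values are illustrative, the real tables are indexed by `ne_type`:

```cpp
// Sketch of the designated-initializer to positional-initializer rewrite that
// the C++ port applies to the per-type tables.
enum example_type { TYPE_F32 = 0, TYPE_F16 = 1, TYPE_Q4_0 = 2, TYPE_COUNT = 3 };

// C version:   static const int BLCK[TYPE_COUNT] = { [TYPE_F32] = 1, [TYPE_Q4_0] = 32 };
// C++ version: every slot listed in order, unused slots filled explicitly.
static const int BLCK[TYPE_COUNT] = {
    1,   // TYPE_F32
    1,   // TYPE_F16
    32,  // TYPE_Q4_0
};
static_assert(TYPE_COUNT == 3, "BLCK is outdated");

int main() { return BLCK[TYPE_Q4_0] == 32 ? 0 : 1; }
```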