Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

sync SYCL code #312

Closed
wants to merge 2 commits (source and target branch names not captured in this page extract)
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions bestla/bestla/kernel_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ namespace bestla {
namespace kernel {
namespace avx2 {
#if CompileAVX2()
#if defined(__GNUC__)
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute push(__attribute__((target("avx2,fma,f16c"))), apply_to = function)
#elif defined(__GNUC__)
#pragma GCC push_options
#pragma GCC target("avx2", "fma", "f16c")
#elif defined(ICX)
//#pragma clang attribute push(__attribute__((target("avx2,fma,f16c"))), apply_to = function)
#endif

// Issue vzeroupper: clears the upper 128 bits of all YMM registers so that
// subsequent legacy-SSE code does not pay AVX/SSE transition penalties.
static inline void zero_reg() {
  _mm256_zeroupper();
}
Expand Down Expand Up @@ -5373,10 +5373,11 @@ static inline BTLA_CODE inplace_precompute_max_softmax_fp32_fp32(int m_size, int

return BTLA_CODE::Success;
}
#ifdef __GNUC__
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#endif

#endif
} // namespace avx2
} // namespace kernel
Expand Down
11 changes: 7 additions & 4 deletions bestla/bestla/kernel_avx512_bf16.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ namespace kernel {
namespace avx512f {
namespace avx512_bf16 {
#if CompileBF16()
#if defined(__GNUC__)
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute push(__attribute__((target("avx512f,avx512bf16,avx512vl,avx512bw"))), apply_to = function)
#elif defined(__GNUC__)
#pragma GCC push_options
#pragma GCC target("avx512bf16", "avx512vl", "avx512bw")
#elif defined(ICX)
#pragma clang attribute push(__attribute__((target("avx512bf16,avx512vl,avx512bw"))), apply_to = function)
#endif

// Narrow 16 packed fp32 lanes to bf16 (VCVTNEPS2BF16, round-to-nearest-even)
// and hand the result back as a raw 256-bit integer vector.
static inline __m256i zmm_cvt_fp32_bf16(__m512 vfp32) {
  const __m256bh vbf16 = _mm512_cvtneps_pbh(vfp32);
  return (__m256i)vbf16;
}

static inline __m512 load_bf16_fp32(const utils::bf16* srcptr) {
Expand Down Expand Up @@ -175,7 +176,9 @@ static inline BTLA_CODE inplace_precompute_max_softmax_fp32_bf16(int m_size, int
}
return BTLA_CODE::Success;
}
#if defined(__GNUC__)
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#endif
#endif
Expand Down
10 changes: 6 additions & 4 deletions bestla/bestla/kernel_avx512_fp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ namespace kernel {
namespace avx512f {
namespace avx512_fp16 {
#if CompileFP16()
#if defined(__GNUC__)
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute push(__attribute__((target("avx512f,avx512bf16,avx512bw,avx512fp16"))), apply_to = function)
#elif defined(__GNUC__)
#pragma GCC push_options
#pragma GCC target("avx512f", "avx512bf16", "avx512vl", "avx512bw", "avx512fp16")
#elif defined(ICX)
#pragma clang attribute push(__attribute__((target("avx512f,avx512bf16,avx512bw,avx512fp16"))), apply_to = function)
#endif

inline __m512 zmm_cvt_fp16_fp32(__m256i vfp16) { return _mm512_cvtxph_ps((__m256h)vfp16); }
Expand Down Expand Up @@ -465,7 +465,9 @@ static inline BTLA_CODE inplace_precompute_max_softmax_fp32_fp16(int m_size, int
}
return BTLA_CODE::Success;
}
#if defined(__GNUC__)
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#endif
#endif
Expand Down
13 changes: 7 additions & 6 deletions bestla/bestla/kernel_avx512_vnni.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ namespace bestla {
namespace kernel {
namespace avx512f {
#if CompileAVX512VNNI()
#ifdef __GNUC__
#pragma GCC push_options
#pragma GCC target("avx512f", "avx512bw", "avx512vl", "avx512dq", "avx512vnni")
#elif defined(ICX)
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute push(__attribute__((target("avx512f,avx512bw,avx512vl,avx512dq,avx512vnni"))), \
apply_to = function)
#elif defined(__GNUC__)
#pragma GCC push_options
#pragma GCC target("avx512f", "avx512bw", "avx512vl", "avx512dq", "avx512vnni")
#endif

namespace vnni {
Expand Down Expand Up @@ -1517,9 +1517,10 @@ static inline BTLA_CODE gemv_7bit_s8s8_fp32(const utils::GemvParamA& A, const ut

} // namespace vnni

#ifdef __GNUC__
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#else
#endif
#endif
} // namespace avx512f
Expand Down
11 changes: 6 additions & 5 deletions bestla/bestla/kernel_avx512f.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ namespace bestla {
namespace kernel {
namespace avx512f {
#if CompileAVX512F()
#ifdef __GNUC__
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,avx512bw,avx512dq"))), apply_to = function)
#elif defined(__GNUC__)
#pragma GCC push_options
#pragma GCC target("avx512f", "avx512bw", "avx512vl", "avx512dq")
#elif defined(ICX)
#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,avx512bw,avx512dq"))), apply_to = function)
#endif

// Widen 16 packed IEEE fp16 values (raw bits in a __m256i) to fp32 using
// the AVX-512F VCVTPH2PS conversion; no __m256h type needed here.
inline __m512 zmm_cvt_fp16_fp32(__m256i vfp16) {
  return _mm512_cvtph_ps(vfp16);
}
Expand Down Expand Up @@ -6512,9 +6512,10 @@ static inline BTLA_CODE inplace_precompute_max_softmax_fp32_u8(int m_size, int n
}
return BTLA_CODE::Success;
}
#ifdef __GNUC__
#if defined(__INTEL_LLVM_COMPILER)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#else
#endif
#endif
} // namespace avx512f
Expand Down
10 changes: 6 additions & 4 deletions bestla/bestla/kernel_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -1925,6 +1925,7 @@ class ScaleTrackMax {
const int M_offset, const int N_offset, const int M, const int N, float scale,
int causal_offset, float alibi_slope, float tanh_scale, void* tmpcache,
size_t cachesize) {
#if CompileAVX2()
if (alibi_slope == 0 && tanh_scale == 0)
return avx2::scale_track_max_fp32_fp32<false, false>(src, src_step, dst, dst_max, ld_dst, M_offset, N_offset, M,
N, scale, causal_offset, alibi_slope, tanh_scale, tmpcache,
Expand All @@ -1937,14 +1938,15 @@ class ScaleTrackMax {
return avx2::scale_track_max_fp32_fp32<true, false>(src, src_step, dst, dst_max, ld_dst, M_offset, N_offset, M, N,
scale, causal_offset, alibi_slope, tanh_scale, tmpcache,
cachesize);
else
return BTLA_CODE::NotSupport;
#endif
return BTLA_CODE::NotSupport;
}

static BTLA_CODE forward_avx512(const SType* src, const int src_step, DType* dst, DType* dst_max, int ld_dst,
const int M_offset, const int N_offset, const int M, const int N, float scale,
int causal_offset, float alibi_slope, float tanh_scale, void* tmpcache,
size_t cachesize) {
#if CompileAVX512F()
if (alibi_slope == 0 && tanh_scale == 0)
return avx512f::scale_track_max_fp32_fp32<false, false>(src, src_step, dst, dst_max, ld_dst, M_offset, N_offset,
M, N, scale, causal_offset, alibi_slope, tanh_scale,
Expand All @@ -1957,8 +1959,8 @@ class ScaleTrackMax {
return avx512f::scale_track_max_fp32_fp32<true, false>(src, src_step, dst, dst_max, ld_dst, M_offset, N_offset, M,
N, scale, causal_offset, alibi_slope, tanh_scale, tmpcache,
cachesize);
else
return BTLA_CODE::NotSupport;
#endif
return BTLA_CODE::NotSupport;
}
};

Expand Down
135 changes: 67 additions & 68 deletions bestla/bestla/sycl/sycl_prologue_b.h
Original file line number Diff line number Diff line change
Expand Up @@ -376,76 +376,76 @@ class WeightS4Trans {
int constexpr TileK = 32;
int constexpr GroupK = SgSize * TileK;
auto ev = q->submit([&](sycl::handler& cgh) {
cgh.parallel_for(sycl::nd_range<1>(problem, group),
[=](sycl::nd_item<1> it) [[sycl::reqd_work_group_size(
1, 1, SgSize)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] {
int g_idx = it.get_group(0);
auto sg = it.get_sub_group();
int sg_id = sg.get_local_id()[0];
int g_n = g_idx;
auto sptr = B_scale + g_n * ldb;
auto bptr = B + g_n * k / 2;
auto aptr = A;
auto cptr = C + g_n;
if constexpr (std::is_same_v<CType, sycl::half>) {
sycl::half2 tmpAcc = {0.f, 0.f};
for (int i = 0; i < k; i += GroupK * Unroll) {
cgh.parallel_for(
sycl::nd_range<1>(problem, group),
[=](sycl::nd_item<1> it) [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] {
int g_idx = it.get_group(0);
auto sg = it.get_sub_group();
int sg_id = sg.get_local_id()[0];
int g_n = g_idx;
auto sptr = B_scale + g_n * ldb;
auto bptr = B + g_n * k / 2;
auto aptr = A;
auto cptr = C + g_n;
if constexpr (std::is_same_v<CType, sycl::half>) {
sycl::half2 tmpAcc = {0.f, 0.f};
for (int i = 0; i < k; i += GroupK * Unroll) {
#pragma unroll
for (int iu = 0; iu < Unroll; iu++) {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
CType scale = *(sptr + sg_id * TileK / blocksize);
for (int iu = 0; iu < Unroll; iu++) {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
CType scale = *(sptr + sg_id * TileK / blocksize);
#pragma unroll
for (int ikk = 0; ikk < TileK; ikk += 2) {
sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk];
sycl::half2 tmpB = {static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
tmpAcc += tmpA * tmpB * scale;
}
sptr += GroupK / blocksize;
aptr += GroupK;
bptr += GroupK / 2;
}
}
sycl::half2 sum = {0.f, 0.f};
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum[0] + sum[1];
}
} else {
CType tmpAcc = 0.f;
int constexpr Unroll = 2;
for (int i = 0; i < k; i += GroupK * Unroll) {
for (int ikk = 0; ikk < TileK; ikk += 2) {
sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk];
sycl::half2 tmpB = {static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
tmpAcc += tmpA * tmpB * scale;
}
sptr += GroupK / blocksize;
aptr += GroupK;
bptr += GroupK / 2;
}
}
sycl::half2 sum = {0.f, 0.f};
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum[0] + sum[1];
}
} else {
CType tmpAcc = 0.f;
int constexpr Unroll = 2;
for (int i = 0; i < k; i += GroupK * Unroll) {
#pragma unroll
for (int iu = 0; iu < Unroll; iu++) {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
CType scale = *(sptr + sg_id * TileK / blocksize);
for (int iu = 0; iu < Unroll; iu++) {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
CType scale = *(sptr + sg_id * TileK / blocksize);
#pragma unroll
for (int ikk = 0; ikk < TileK; ikk += 2) {
tmpAcc += CType(aptr[sg_id * TileK + ikk]) *
static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8) * scale;
tmpAcc += CType(aptr[sg_id * TileK + ikk + 1]) *
static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8) * scale;
}
sptr += GroupK / blocksize;
aptr += GroupK;
bptr += GroupK / 2;
}
}
float sum = 0.f;
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum;
}
}
});
for (int ikk = 0; ikk < TileK; ikk += 2) {
tmpAcc +=
CType(aptr[sg_id * TileK + ikk]) * static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8) * scale;
tmpAcc +=
CType(aptr[sg_id * TileK + ikk + 1]) * static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8) * scale;
}
sptr += GroupK / blocksize;
aptr += GroupK;
bptr += GroupK / 2;
}
}
float sum = 0.f;
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum;
}
}
});
});
return ev;
} else {
Expand All @@ -458,8 +458,7 @@ class WeightS4Trans {
auto ev = q->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl::nd_range<1>(problem, group),
[=](sycl::nd_item<1> it) [[sycl::reqd_work_group_size(
1, 1, SgSize)]] [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] {
[=](sycl::nd_item<1> it) [[intel::kernel_args_restrict]] [[intel::reqd_sub_group_size(SgSize)]] {
int g_idx = it.get_group(0);
auto sg = it.get_sub_group();
int sg_id = sg.get_local_id()[0];
Expand Down
Loading
Loading