
Commit cc65433
remove SSE unpack 4bit
luoyu-intel committed Mar 14, 2024
1 parent 8ba7ee9 commit cc65433
Showing 1 changed file with 86 additions and 111 deletions.
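This change drops the SSE helpers for 4-bit unpacking (shuffle_map, unpack_4bits_sse, convert_s4_s8_16_sse, fp4_pad_4bit) along with the fixed-width convert_s4_s8_32_avx2, and routes all 4-bit decompression through a single convert_s4_s8_N_avx2<N, QT_T> helper. The packrow1/packrow2 decompress kernels now take the quantization dtype as a template parameter instead of dequantize/pad function pointers.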
197 changes: 86 additions & 111 deletions bestla/bestla/kernel_avx2.h
@@ -28,38 +28,42 @@ namespace avx2 {
#else
#endif

static uint8_t shuffle_map[] = {0x00, 0x01, 0x02, 0x03, 0xff, 0xff, 0xff, 0xff,
0x04, 0x05, 0x06, 0x07, 0xff, 0xff, 0xff, 0xff};

template <BTLA_DTYPE S4_T>
static inline __m128i unpack_4bits_sse(void* srcptr) {
auto shuffle_v = _mm_loadu_si128(reinterpret_cast<__m128i*>(shuffle_map));
auto raw_data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr));
auto xmm0 = _mm_shuffle_epi8(raw_data, shuffle_v);
auto xmm1 = _mm_srli_epi32(xmm0, 0x04);
auto and_helper = _mm_set1_epi8(0x0f);
xmm0 = _mm_and_si128(xmm0, and_helper);
xmm1 = _mm_and_si128(xmm1, and_helper);
auto xmm2 = _mm_unpacklo_epi8(xmm0, xmm1);
auto xmm3 = _mm_unpackhi_epi8(xmm0, xmm1);
xmm2 = _mm_unpacklo_epi64(xmm2, xmm3);
if constexpr (S4_T != BTLA_DTYPE::F4_NF4) xmm2 = _mm_slli_epi32(xmm2, 4);
return xmm2;
}
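// unpack_4bits_sse: expands 8 packed bytes (16 nibbles) into 16 int8 values. The shuffle
// splits the 8 input bytes across the two 64-bit halves of the register, the 4-bit shift
// plus 0x0f mask separates low and high nibbles, and the unpacks restore source order.
// For everything except F4_NF4 the values are shifted left by 4 so each nibble occupies
// the upper half of its byte.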

template <bool LowBits>
static inline __m256i unpack_4bits_avx2(void* srcptr, __m256i mask) {
auto raw_data = _mm_loadu_si128(reinterpret_cast<__m128i*>(srcptr));
auto ymm0 = _mm256_cvtepu8_epi16(raw_data);
auto ymm1 = _mm256_slli_epi16(ymm0, 8);
ymm0 = _mm256_slli_epi16(ymm0, 4);
ymm0 = _mm256_or_si256(ymm0, ymm1);
ymm0 = _mm256_and_si256(ymm0, mask);
if constexpr (LowBits) {
ymm0 = _mm256_srli_epi16(ymm0, 4);
}
return ymm0;
}
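// unpack_4bits_avx2: widens 16 packed bytes to 16-bit lanes, then shifts by 8 and by 4 and
// ORs so the high nibble of each source byte lands in the upper result byte and the low
// nibble in the lower one; the 0xf0f0f0f0 mask keeps one nibble in the top four bits of
// every output byte. With LowBits the nibbles are shifted back down to bits 0-3 (the
// non-S4_CLIP call sites); otherwise they stay high so the nibble's sign bit sits in the
// byte's sign position.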

static inline void convert_s4_s8_32_avx2(int8_t* dstptr, int8_t* srcptr, __m256i mask) {
auto dst0 = unpack_4bits_avx2(srcptr, mask);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), dst0);
template <int N, BTLA_DTYPE QT_T>
static inline void convert_s4_s8_N_avx2(int8_t* dstptr, int8_t* srcptr, __m256i mask) {
static_assert(N % 2 == 0);
static_assert(N <= 64);
if constexpr (N == 32) {
auto dst0 = unpack_4bits_avx2<QT_T != BTLA_DTYPE::S4_CLIP>(srcptr, mask);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), dst0);
} else if constexpr (N > 32) {
auto dst0 = unpack_4bits_avx2<QT_T != BTLA_DTYPE::S4_CLIP>(srcptr, mask);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), dst0);
int8_t temp[32];
memcpy(temp, srcptr + 16, (N - 32) / 2);
dst0 = unpack_4bits_avx2<QT_T != BTLA_DTYPE::S4_CLIP>(temp, mask);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(temp), dst0);
memcpy(dstptr + 32, temp, (N - 32));
} else {
int8_t temp[32];
memcpy(temp, srcptr, N / 2);
auto dst0 = unpack_4bits_avx2<QT_T != BTLA_DTYPE::S4_CLIP>(temp, mask);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(temp), dst0);
memcpy(dstptr, temp, N);
}
}
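// convert_s4_s8_N_avx2 accepts any even N up to 64: N == 32 unpacks 16 source bytes straight
// into dstptr; N > 32 stores the first 32 values directly and stages the tail through a
// 32-byte temporary; N < 32 stages both the packed input and the unpacked output through the
// temporary so no load or store runs past the caller's buffers.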

inline __m256 ymm_cvt_bf16_fp32(__m128i vbf16) {
@@ -85,16 +89,6 @@ inline __m128i ymm_cvt_fp32_bf16(__m256 vfp32) {
return ymm_cvtepi32_epi16(_mm256_bsrli_epi128(_mm256_castps_si256(vfp32), 2));
}

template <BTLA_DTYPE S4_T>
static inline void convert_s4_s8_16_sse(int8_t* dstptr, int8_t* srcptr) {
auto dst0 = unpack_4bits_sse<S4_T>(srcptr);
if constexpr (S4_T == BTLA_DTYPE::F4_NF4) {
auto s8 = _mm_set1_epi8(8);
dst0 = _mm_sub_epi8(dst0, s8);
}
_mm_storeu_si128(reinterpret_cast<__m128i*>(dstptr), dst0);
}

template <typename T>
static inline void convert_s8_fp_v8(T* dstptr, int8_t* srcptr) {
auto xmm = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr));
@@ -108,11 +102,6 @@ static inline void convert_s8_fp_v8(T* dstptr, int8_t* srcptr) {
}
}

static inline void fp4_pad_4bit(int8_t* dstptr, int8_t* srcptr) {
auto dst0 = unpack_4bits_sse<BTLA_DTYPE::F4_NF4>(srcptr);
_mm_storeu_si128(reinterpret_cast<__m128i*>(dstptr), dst0);
}

template <int N, bool _IS_SYM>
static inline void dequant_s8_N_avx2(float* dstptr, int8_t* srcptr, __m256* vscales, __m256i* vzps = nullptr) {
static_assert(N % 8 == 0);
@@ -383,10 +372,9 @@ static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr,
if (col == ld_src) {
size_t elesize = static_cast<size_t>(row) * col;
size_t velt = utils::padto_le(elesize, 32);
size_t velt2 = utils::padto_le(elesize, 16);
size_t i = 0;
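// Main loop: each iteration turns 16 packed bytes into 32 int8 values; the scalar loop
// below finishes the remaining nibble pairs.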
for (; i < velt; i += 32) {
convert_s4_s8_32_avx2(dstptr + i, reinterpret_cast<int8_t*>(srcptr + i / 2), vmask);
convert_s4_s8_N_avx2<32, S4_T>(dstptr + i, reinterpret_cast<int8_t*>(srcptr + i / 2), vmask);
}
for (; i < elesize; i += 2) {
auto tmp = srcptr[i / 2];
@@ -405,27 +393,17 @@ inline BTLA_CODE decompress_kblock_s4_s8fp(utils::int4x2* srcptr, _DST_T* dstptr
auto vmask = _mm256_set1_epi32(*reinterpret_cast<int*>(&mask));
if (col == ld_src) {
size_t elesize = static_cast<size_t>(row) * col;
#if 0
size_t ele16 = utils::padto_le(elesize, 16);
size_t i = 0;
assert(tmpsize >= 16);
for (; i < ele16; i += 16) {
convert_s4_s8_16_sse<S4_T>(tmp, reinterpret_cast<int8_t*>(srcptr + i / 2));
convert_s8_fp_v8(dstptr + i, tmp);
convert_s8_fp_v8(dstptr + i + 8, tmp + 8);
}
#else

size_t velt = utils::padto_le(elesize, 32);
size_t i = 0;
assert(tmpsize >= 32);
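// Unpack 32 nibbles at a time into the int8 scratch buffer, then convert to the destination
// float type eight values per convert_s8_fp_v8 call.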
for (; i < velt; i += 32) {
convert_s4_s8_32_avx2(tmp, reinterpret_cast<int8_t*>(srcptr + i / 2), vmask);
convert_s4_s8_N_avx2<32, S4_T>(tmp, reinterpret_cast<int8_t*>(srcptr + i / 2), vmask);
convert_s8_fp_v8(dstptr + i, tmp);
convert_s8_fp_v8(dstptr + i + 8, tmp + 8);
convert_s8_fp_v8(dstptr + i + 16, tmp + 16);
convert_s8_fp_v8(dstptr + i + 24, tmp + 24);
}
#endif
for (; i < elesize; i += 2) {
auto tmp = srcptr[i / 2];
dstptr[i + 0] = static_cast<_DST_T>(static_cast<float>(ref::get_s8<S4_T>(tmp.x)));
@@ -622,12 +600,12 @@ inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* dst
auto vmask = _mm256_set1_epi32(*reinterpret_cast<int*>(&mask));
if (col == ld_src) {
size_t elesize = static_cast<size_t>(row) * col;
size_t ele16 = utils::padto_le(elesize, 16);
size_t velt = utils::padto_le(elesize, 32);
size_t i = 0;
assert(tmpsize >= 16);
for (; i < ele16; i += 16) {
fp4_pad_4bit(tmp, reinterpret_cast<int8_t*>(srcptr + i / 2));
unpack_f4_N<16, DST_T, F4_T>(dstptr + i, tmp);
assert(tmpsize >= 32);
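// F4 path: the same AVX2 nibble unpack fills the scratch buffer, and unpack_f4_N maps each
// 4-bit float code to the destination type.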
for (; i < velt; i += 32) {
convert_s4_s8_N_avx2<32, F4_T>(tmp, reinterpret_cast<int8_t*>(srcptr + i / 2), vmask);
unpack_f4_N<32, DST_T, F4_T>(dstptr + i, tmp);
}
for (; i < elesize; i += 2) {
auto tmp = srcptr[i / 2];
@@ -639,11 +617,11 @@ inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* dst
return BTLA_CODE::Success;
}

template <bool _IS_SYM, int _NCOL, typename _ST, typename _DST_T>
static inline BTLA_CODE decompress_kblock_bit4_packrow1(
utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, _ST* scales, int8_t* zero_points,
int k_offset, int kblock, int NPad, void (*dequantize)(_DST_T*, int8_t*, __m256*, __m256i*),
void (*pad_bit4_16)(int8_t*, int8_t*), void (*pad_bit4_8)(int8_t*, int8_t*), int8_t* tmpbuf, size_t tmpsize) {
template <BTLA_DTYPE QT_T, bool _IS_SYM, int _NCOL, typename _ST, typename _DST_T>
static inline BTLA_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col,
int ld_src, int ld_dst, _ST* scales, int8_t* zero_points,
int k_offset, int kblock, int NPad, int8_t* tmpbuf,
size_t tmpsize) {
uint32_t mask = 0xf0f0f0f0;
auto vmask = _mm256_set1_epi32(*reinterpret_cast<int*>(&mask));
int constexpr NReg = _NCOL / 8;
@@ -654,13 +632,21 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow1(
__m256i vzps[NReg];
int constexpr UnrollRow = 4;
assert(kblock % UnrollRow == 0);
int constexpr Loop16 = _NCOL * UnrollRow / 16;
int constexpr NTile = 32;
int constexpr Loop32 = _NCOL * UnrollRow / NTile;
assert(tmpsize >= (_NCOL * UnrollRow));
int row0 = kblock - k_offset % kblock;
row0 = row0 == kblock ? 0 : row0;
row0 = row0 > row ? row : row0;
int row1 = row - row0;
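// row0 is the number of rows left in the quantization block that k_offset points into;
// row1 is the remainder, processed kblock rows at a time below.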
int irow = 0;
auto dequantize = [&](_DST_T* dstptr, int8_t* srcptr, __m256* vscales, __m256i* vzps = nullptr) {
if constexpr (QT_T == BTLA_DTYPE::S4_CLIP) {
dequant_s8_N_avx2<_NCOL, _IS_SYM>(dstptr, srcptr, vscales, vzps);
} else {
dequant_f4_N<_NCOL, _DST_T, QT_T>(dstptr, srcptr, vscales, vzps);
}
};
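// The dequantize lambda picks the kernel from the template type: S4_CLIP uses the int8
// scale/zero-point path, all other 4-bit types go through the f4 lookup dequant.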
if (row0) {
int rowpad4 = utils::padto_le(row0, UnrollRow);
for (int iv = 0; iv < NReg; iv++) {
@@ -672,19 +658,15 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow1(
}
}
for (; irow < rowpad4; irow += UnrollRow) {
for (int iter16 = 0; iter16 < Loop16; iter16++)
pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8 * iter16));
for (int iter16 = 0; iter16 < Loop32; iter16++)
convert_s4_s8_N_avx2<NTile, QT_T>(
tmpbuf + iter16 * NTile, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + NTile / 2 * iter16), vmask);
for (int iterr = 0; iterr < UnrollRow; iterr++)
dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * _NCOL, vscales, vzps);
}
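// Rows of row0 that do not fill an UnrollRow group are converted one row (_NCOL nibbles)
// at a time with the same AVX2 unpack.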
for (; irow < row0; irow++) {
if constexpr (_NCOL == 24) {
pad_bit4_16(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2));
pad_bit4_8(tmpbuf + 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8));
} else {
for (int iter16 = 0; iter16 < 3; iter16++)
pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8 * iter16));
}
convert_s4_s8_N_avx2<_NCOL, QT_T>(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2), vmask);

dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
}
}
@@ -700,8 +682,10 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow1(
}
}
for (int irr = 0; irr < kblock; irr += UnrollRow) {
for (int iter16 = 0; iter16 < Loop16; iter16++)
pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast<int8_t*>(srcptr + (irow + irr) * ld_src / 2 + 8 * iter16));
for (int iter16 = 0; iter16 < Loop32; iter16++)
convert_s4_s8_N_avx2<NTile, QT_T>(
tmpbuf + iter16 * NTile, reinterpret_cast<int8_t*>(srcptr + (irow + irr) * ld_src / 2 + NTile / 2 * iter16),
vmask);
for (int iterr = 0; iterr < UnrollRow; iterr++)
dequantize(dstptr + (irow + irr + iterr) * ld_src, tmpbuf + iterr * _NCOL, vscales, vzps);
}
@@ -718,31 +702,24 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow1(
auto rowre = row - irow;
int rowpad4 = utils::padto_le(rowre, UnrollRow) + irow;
for (; irow < rowpad4; irow += UnrollRow) {
for (int iter16 = 0; iter16 < Loop16; iter16++)
pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8 * iter16));
for (int iter16 = 0; iter16 < Loop32; iter16++)
convert_s4_s8_N_avx2<NTile, QT_T>(
tmpbuf + iter16 * NTile, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + NTile / 2 * iter16), vmask);
for (int iterr = 0; iterr < UnrollRow; iterr++)
dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * _NCOL, vscales, vzps);
}
for (; irow < row; irow++) {
if constexpr (_NCOL == 24) {
pad_bit4_16(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2));
pad_bit4_8(tmpbuf + 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8));
} else {
for (int iter16 = 0; iter16 < 3; iter16++)
pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8 * iter16));
}
convert_s4_s8_N_avx2<_NCOL, QT_T>(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2), vmask);
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
}
}
return BTLA_CODE::Success;
}

template <bool _IS_SYM, typename _ST, typename _DST_T>
template <BTLA_DTYPE S4_T, bool _IS_SYM, typename _ST, typename _DST_T>
static inline BTLA_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col,
int ld_src, int ld_dst, _ST* scales, int8_t* zero_points,
int k_offset, int kblock, int NPad,
void (*dequantize)(_DST_T*, int8_t*, __m256*, __m256i*),
void (*pad_bit4)(int8_t*, int8_t*), int8_t* tmp,
int k_offset, int kblock, int NPad, int8_t* tmp,
size_t tmpsize) {
return BTLA_CODE::NotSupport;
}
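// The pack-row-2 layout has no AVX2 implementation; this stub just reports NotSupport.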
@@ -755,28 +732,28 @@ static inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* d
if constexpr (_PACK_ROW == 1 && std::is_same_v<_DST_T, float> && std::is_same_v<_ST, float>) {
if (zero_points == nullptr) {
if (col == 24) {
ret = avx2::decompress_kblock_bit4_packrow1<true, 24>(
srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad,
&avx2::dequant_s8_N_avx2<24, true>, &avx2::convert_s4_s8_16_sse<S4_T>, &ref::convert_s4_s8_8<S4_T>,
reinterpret_cast<int8_t*>(tmp), tmpsize);
ret = decompress_kblock_bit4_packrow1<S4_T, true, 24>(srcptr, dstptr, row, col, ld_src, ld_dst, scales,
zero_points, k_offset, kblock, NPad,
reinterpret_cast<int8_t*>(tmp), tmpsize);
} else if (col == 48) {
ret = avx2::decompress_kblock_bit4_packrow1<true, 48>(
srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad,
&avx2::dequant_s8_N_avx2<48, true>, &avx2::convert_s4_s8_16_sse<S4_T>, &ref::convert_s4_s8_8<S4_T>,
reinterpret_cast<int8_t*>(tmp), tmpsize);
ret = decompress_kblock_bit4_packrow1<S4_T, true, 48>(srcptr, dstptr, row, col, ld_src, ld_dst, scales,
zero_points, k_offset, kblock, NPad,
reinterpret_cast<int8_t*>(tmp), tmpsize);
} else {
assert(0);
}

} else {
if (col == 24) {
ret = avx2::decompress_kblock_bit4_packrow1<false, 24>(
srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad,
&avx2::dequant_s8_N_avx2<24, false>, &avx2::convert_s4_s8_16_sse<S4_T>, &ref::convert_s4_s8_8<S4_T>,
reinterpret_cast<int8_t*>(tmp), tmpsize);
ret = decompress_kblock_bit4_packrow1<S4_T, false, 24>(srcptr, dstptr, row, col, ld_src, ld_dst, scales,
zero_points, k_offset, kblock, NPad,
reinterpret_cast<int8_t*>(tmp), tmpsize);
} else if (col == 48) {
ret = avx2::decompress_kblock_bit4_packrow1<false, 48>(
srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad,
&avx2::dequant_s8_N_avx2<48, false>, &avx2::convert_s4_s8_16_sse<S4_T>, &ref::convert_s4_s8_8<S4_T>,
reinterpret_cast<int8_t*>(tmp), tmpsize);
ret = decompress_kblock_bit4_packrow1<S4_T, false, 48>(srcptr, dstptr, row, col, ld_src, ld_dst, scales,
zero_points, k_offset, kblock, NPad,
reinterpret_cast<int8_t*>(tmp), tmpsize);
} else {
assert(0);
}
}
}
@@ -789,20 +766,18 @@ static inline BTLA_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dst
int8_t* tmp, size_t tmpsize) {
if constexpr (_PACK_ROW == 1) {
if (col == 24) {
return decompress_kblock_bit4_packrow1<true, 24, _ST, _DST_T>(
srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad,
&dequant_f4_N<24, _DST_T, _F4_T>, fp4_pad_4bit, &ref::convert_s4_s8_8<_F4_T>, tmp, tmpsize);
return decompress_kblock_bit4_packrow1<_F4_T, true, 24, _ST, _DST_T>(
srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad, tmp, tmpsize);
}
if (col == 48) {
return decompress_kblock_bit4_packrow1<true, 48, _ST, _DST_T>(
srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad,
&dequant_f4_N<48, _DST_T, _F4_T>, fp4_pad_4bit, &ref::convert_s4_s8_8<_F4_T>, tmp, tmpsize);
return decompress_kblock_bit4_packrow1<_F4_T, true, 48, _ST, _DST_T>(
srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad, tmp, tmpsize);
}
} else if constexpr (_PACK_ROW == 2) {
return decompress_kblock_bit4_packrow2<true, _ST, _DST_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr,
k_offset, kblock, NPad, &dequant_f4_N<64, _DST_T, _F4_T>,
fp4_pad_4bit, tmp, tmpsize);
return decompress_kblock_bit4_packrow2<_F4_T, true, _ST, _DST_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales,
nullptr, k_offset, kblock, NPad, tmp, tmpsize);
}
assert(0);
return BTLA_CODE::NotSupport;
}
