diff --git a/bestla/bestla/kernel_avx2.h b/bestla/bestla/kernel_avx2.h
index cbd5db18b..9cb6a09d8 100644
--- a/bestla/bestla/kernel_avx2.h
+++ b/bestla/bestla/kernel_avx2.h
@@ -28,25 +28,7 @@ namespace avx2 {
 #else
 #endif

-static uint8_t shuffle_map[] = {0x00, 0x01, 0x02, 0x03, 0xff, 0xff, 0xff, 0xff,
-                                0x04, 0x05, 0x06, 0x07, 0xff, 0xff, 0xff, 0xff};
-
-template <BTLA_DTYPE S4_T>
-static inline __m128i unpack_4bits_sse(void* srcptr) {
-  auto shuffle_v = _mm_loadu_si128(reinterpret_cast<__m128i*>(shuffle_map));
-  auto raw_data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr));
-  auto xmm0 = _mm_shuffle_epi8(raw_data, shuffle_v);
-  auto xmm1 = _mm_srli_epi32(xmm0, 0x04);
-  auto and_helper = _mm_set1_epi8(0x0f);
-  xmm0 = _mm_and_si128(xmm0, and_helper);
-  xmm1 = _mm_and_si128(xmm1, and_helper);
-  auto xmm2 = _mm_unpacklo_epi8(xmm0, xmm1);
-  auto xmm3 = _mm_unpackhi_epi8(xmm0, xmm1);
-  xmm2 = _mm_unpacklo_epi64(xmm2, xmm3);
-  if constexpr (S4_T != BTLA_DTYPE::F4_NF4) xmm2 = _mm_slli_epi32(xmm2, 4);
-  return xmm2;
-}
-
+template <bool LowBits>
 static inline __m256i unpack_4bits_avx2(void* srcptr, __m256i mask) {
   auto raw_data = _mm_loadu_si128(reinterpret_cast<__m128i*>(srcptr));
   auto ymm0 = _mm256_cvtepu8_epi16(raw_data);
@@ -54,12 +36,34 @@ static inline __m256i unpack_4bits_avx2(void* srcptr, __m256i mask) {
   ymm0 = _mm256_slli_epi16(ymm0, 4);
   ymm0 = _mm256_or_si256(ymm0, ymm1);
   ymm0 = _mm256_and_si256(ymm0, mask);
+  if constexpr (LowBits) {
+    ymm0 = _mm256_srli_epi16(ymm0, 4);
+  }
   return ymm0;
 }

-static inline void convert_s4_s8_32_avx2(int8_t* dstptr, int8_t* srcptr, __m256i mask) {
-  auto dst0 = unpack_4bits_avx2(srcptr, mask);
-  _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), dst0);
+template <int N, BTLA_DTYPE QT_T>
+static inline void convert_s4_s8_N_avx2(int8_t* dstptr, int8_t* srcptr, __m256i mask) {
+  static_assert(N % 2 == 0);
+  static_assert(N <= 64);
+  if constexpr (N == 32) {
+    auto dst0 = unpack_4bits_avx2<QT_T == BTLA_DTYPE::F4_NF4>(srcptr, mask);
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), dst0);
+  } else if constexpr (N > 32) {
+    auto dst0 = unpack_4bits_avx2<QT_T == BTLA_DTYPE::F4_NF4>(srcptr, mask);
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), dst0);
+    int8_t temp[32];
+    memcpy(temp, srcptr + 16, (N - 32) / 2);
+    dst0 = unpack_4bits_avx2<QT_T == BTLA_DTYPE::F4_NF4>(temp, mask);
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>(temp), dst0);
+    memcpy(dstptr + 32, temp, (N - 32));
+  } else {
+    int8_t temp[32];
+    memcpy(temp, srcptr, N / 2);
+    auto dst0 = unpack_4bits_avx2<QT_T == BTLA_DTYPE::F4_NF4>(temp, mask);
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>(temp), dst0);
+    memcpy(dstptr, temp, N);
+  }
 }

 inline __m256 ymm_cvt_bf16_fp32(__m128i vbf16) {
@@ -85,16 +89,6 @@ inline __m128i ymm_cvt_fp32_bf16(__m256 vfp32) {
   return ymm_cvtepi32_epi16(_mm256_bsrli_epi128(_mm256_castps_si256(vfp32), 2));
 }

-template <BTLA_DTYPE S4_T>
-static inline void convert_s4_s8_16_sse(int8_t* dstptr, int8_t* srcptr) {
-  auto dst0 = unpack_4bits_sse<S4_T>(srcptr);
-  if constexpr (S4_T == BTLA_DTYPE::F4_NF4) {
-    auto s8 = _mm_set1_epi8(8);
-    dst0 = _mm_sub_epi8(dst0, s8);
-  }
-  _mm_storeu_si128(reinterpret_cast<__m128i*>(dstptr), dst0);
-}
-
 template <typename T>
 static inline void convert_s8_fp_v8(T* dstptr, int8_t* srcptr) {
   auto xmm = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr));
@@ -108,11 +102,6 @@ static inline void convert_s8_fp_v8(T* dstptr, int8_t* srcptr) {
   }
 }

-static inline void fp4_pad_4bit(int8_t* dstptr, int8_t* srcptr) {
-  auto dst0 = unpack_4bits_sse<BTLA_DTYPE::F4_BNB>(srcptr);
-  _mm_storeu_si128(reinterpret_cast<__m128i*>(dstptr), dst0);
-}
-
 template <int N, bool _IS_SYM>
 static inline void dequant_s8_N_avx2(float* dstptr, int8_t* srcptr, __m256* vscales, __m256i* vzps = nullptr) {
   static_assert(N % 8 == 0);
@@ -383,10 +372,9 @@ static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr,
   if (col == ld_src) {
     size_t elesize = static_cast<size_t>(row) * col;
     size_t velt = utils::padto_le(elesize, 32);
-    size_t velt2 = utils::padto_le(elesize, 16);
     size_t i = 0;
     for (; i < velt; i += 32) {
-      convert_s4_s8_32_avx2(dstptr + i, reinterpret_cast<int8_t*>(srcptr + i / 2), vmask);
+      convert_s4_s8_N_avx2<32, S4_T>(dstptr + i, reinterpret_cast<int8_t*>(srcptr + i / 2), vmask);
     }
     for (; i < elesize; i += 2) {
       auto tmp = srcptr[i / 2];
@@ -405,27 +393,17 @@ inline BTLA_CODE decompress_kblock_s4_s8fp(utils::int4x2* srcptr, _DST_T* dstptr
   auto vmask = _mm256_set1_epi32(*reinterpret_cast<int*>(&mask));
   if (col == ld_src) {
     size_t elesize = static_cast<size_t>(row) * col;
-#if 0
-    size_t ele16 = utils::padto_le(elesize, 16);
-    size_t i = 0;
-    assert(tmpsize >= 16);
-    for (; i < ele16; i += 16) {
-      convert_s4_s8_16_sse<S4_T>(tmp, reinterpret_cast<int8_t*>(srcptr + i / 2));
-      convert_s8_fp_v8(dstptr + i, tmp);
-      convert_s8_fp_v8(dstptr + i + 8, tmp + 8);
-    }
-#else
+    size_t velt = utils::padto_le(elesize, 32);
     size_t i = 0;
     assert(tmpsize >= 32);
     for (; i < velt; i += 32) {
-      convert_s4_s8_32_avx2(tmp, reinterpret_cast<int8_t*>(srcptr + i / 2), vmask);
+      convert_s4_s8_N_avx2<32, S4_T>(tmp, reinterpret_cast<int8_t*>(srcptr + i / 2), vmask);
       convert_s8_fp_v8(dstptr + i, tmp);
       convert_s8_fp_v8(dstptr + i + 8, tmp + 8);
       convert_s8_fp_v8(dstptr + i + 16, tmp + 16);
       convert_s8_fp_v8(dstptr + i + 24, tmp + 24);
     }
-#endif
     for (; i < elesize; i += 2) {
       auto tmp = srcptr[i / 2];
       dstptr[i + 0] = static_cast<_DST_T>(static_cast<float>(ref::get_s8<S4_T>(tmp.x)));
@@ -622,12 +600,12 @@ inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* dst
   auto vmask = _mm256_set1_epi32(*reinterpret_cast<int*>(&mask));
   if (col == ld_src) {
     size_t elesize = static_cast<size_t>(row) * col;
-    size_t ele16 = utils::padto_le(elesize, 16);
+    size_t velt = utils::padto_le(elesize, 32);
     size_t i = 0;
-    assert(tmpsize >= 16);
-    for (; i < ele16; i += 16) {
-      fp4_pad_4bit(tmp, reinterpret_cast<int8_t*>(srcptr + i / 2));
-      unpack_f4_N<16, DST_T, F4_T>(dstptr + i, tmp);
+    assert(tmpsize >= 32);
+    for (; i < velt; i += 32) {
+      convert_s4_s8_N_avx2<32, F4_T>(tmp, reinterpret_cast<int8_t*>(srcptr + i / 2), vmask);
+      unpack_f4_N<32, DST_T, F4_T>(dstptr + i, tmp);
     }
     for (; i < elesize; i += 2) {
       auto tmp = srcptr[i / 2];
@@ -639,11 +617,11 @@ inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* dst
   return BTLA_CODE::Success;
 }

-template <bool _IS_SYM, int _NCOL, typename _ST, typename _DST_T>
-static inline BTLA_CODE decompress_kblock_bit4_packrow1(
-    utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, _ST* scales, int8_t* zero_points,
-    int k_offset, int kblock, int NPad, void (*dequantize)(_DST_T*, int8_t*, __m256*, __m256i*),
-    void (*pad_bit4_16)(int8_t*, int8_t*), void (*pad_bit4_8)(int8_t*, int8_t*), int8_t* tmpbuf, size_t tmpsize) {
+template <BTLA_DTYPE QT_T, bool _IS_SYM, int _NCOL, typename _ST, typename _DST_T>
+static inline BTLA_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col,
+                                                        int ld_src, int ld_dst, _ST* scales, int8_t* zero_points,
+                                                        int k_offset, int kblock, int NPad, int8_t* tmpbuf,
+                                                        size_t tmpsize) {
   uint32_t mask = 0xf0f0f0f0;
   auto vmask = _mm256_set1_epi32(*reinterpret_cast<int*>(&mask));
   int constexpr NReg = _NCOL / 8;
@@ -654,13 +632,21 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow1(
   __m256i vzps[NReg];
   int constexpr UnrollRow = 4;
   assert(kblock % UnrollRow == 0);
-  int constexpr Loop16 = _NCOL * UnrollRow / 16;
+  int constexpr NTile = 32;
+  int constexpr Loop32 = _NCOL * UnrollRow / NTile;
   assert(tmpsize >= (_NCOL * UnrollRow));
   int row0 = kblock - k_offset % kblock;
   row0 = row0 == kblock ? 0 : row0;
   row0 = row0 > row ? row : row0;
   int row1 = row - row0;
   int irow = 0;
+  auto dequantize = [&](_DST_T* dstptr, int8_t* srcptr, __m256* vscales, __m256i* vzps = nullptr) {
+    if constexpr (QT_T == BTLA_DTYPE::S4_CLIP) {
+      dequant_s8_N_avx2<_NCOL, _IS_SYM>(dstptr, srcptr, vscales, vzps);
+    } else {
+      dequant_f4_N<_NCOL, _DST_T, QT_T>(dstptr, srcptr, vscales, vzps);
+    }
+  };
   if (row0) {
     int rowpad4 = utils::padto_le(row0, UnrollRow);
     for (int iv = 0; iv < NReg; iv++) {
@@ -672,19 +658,15 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow1(
       }
     }
     for (; irow < rowpad4; irow += UnrollRow) {
-      for (int iter16 = 0; iter16 < Loop16; iter16++)
-        pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8 * iter16));
+      for (int iter16 = 0; iter16 < Loop32; iter16++)
+        convert_s4_s8_N_avx2<NTile, QT_T>(
+            tmpbuf + iter16 * NTile, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + NTile / 2 * iter16), vmask);
       for (int iterr = 0; iterr < UnrollRow; iterr++)
        dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * _NCOL, vscales, vzps);
     }
     for (; irow < row0; irow++) {
-      if constexpr (_NCOL == 24) {
-        pad_bit4_16(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2));
-        pad_bit4_8(tmpbuf + 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8));
-      } else {
-        for (int iter16 = 0; iter16 < 3; iter16++)
-          pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8 * iter16));
-      }
+      convert_s4_s8_N_avx2<_NCOL, QT_T>(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2), vmask);
+      dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
     }
   }
@@ -700,8 +682,10 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow1(
      }
    }
    for (int irr = 0; irr < kblock; irr += UnrollRow) {
-      for (int iter16 = 0; iter16 < Loop16; iter16++)
-        pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast<int8_t*>(srcptr + (irow + irr) * ld_src / 2 + 8 * iter16));
+      for (int iter16 = 0; iter16 < Loop32; iter16++)
+        convert_s4_s8_N_avx2<NTile, QT_T>(
+            tmpbuf + iter16 * NTile, reinterpret_cast<int8_t*>(srcptr + (irow + irr) * ld_src / 2 + NTile / 2 * iter16),
+            vmask);
      for (int iterr = 0; iterr < UnrollRow; iterr++)
        dequantize(dstptr + (irow + irr + iterr) * ld_src, tmpbuf + iterr * _NCOL, vscales, vzps);
    }
@@ -718,31 +702,24 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow1(
    auto rowre = row - irow;
    int rowpad4 = utils::padto_le(rowre, UnrollRow) + irow;
    for (; irow < rowpad4; irow += UnrollRow) {
-      for (int iter16 = 0; iter16 < Loop16; iter16++)
-        pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8 * iter16));
+      for (int iter16 = 0; iter16 < Loop32; iter16++)
+        convert_s4_s8_N_avx2<NTile, QT_T>(
+            tmpbuf + iter16 * NTile, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + NTile / 2 * iter16), vmask);
      for (int iterr = 0; iterr < UnrollRow; iterr++)
        dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * _NCOL, vscales, vzps);
    }
    for (; irow < row; irow++) {
-      if constexpr (_NCOL == 24) {
-        pad_bit4_16(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2));
-        pad_bit4_8(tmpbuf + 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8));
-      } else {
-        for (int iter16 = 0; iter16 < 3; iter16++)
-          pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8 * iter16));
-      }
+      convert_s4_s8_N_avx2<_NCOL, QT_T>(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2), vmask);
      dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
    }
  }
  return BTLA_CODE::Success;
 }

-template <typename _ST, typename _DST_T>
+template <BTLA_DTYPE QT_T, bool _IS_SYM, typename _ST, typename _DST_T>
 static inline BTLA_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col,
                                                         int ld_src, int ld_dst, _ST* scales, int8_t* zero_points,
-                                                        int k_offset, int kblock, int NPad,
-                                                        void (*dequantize)(_DST_T*, int8_t*, __m256*, __m256i*),
-                                                        void (*pad_bit4)(int8_t*, int8_t*), int8_t* tmp,
+                                                        int k_offset, int kblock, int NPad, int8_t* tmp,
                                                         size_t tmpsize) {
   return BTLA_CODE::NotSupport;
 }
@@ -755,28 +732,28 @@ static inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* d
   if constexpr (_PACK_ROW == 1 && std::is_same_v<_DST_T, float> && std::is_same_v<_ST, float>) {
     if (zero_points == nullptr) {
       if (col == 24) {
-        ret = avx2::decompress_kblock_bit4_packrow1(
-            srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad,
-            &avx2::dequant_s8_N_avx2<24, true>, &avx2::convert_s4_s8_16_sse<S4_T>, &ref::convert_s4_s8_8<S4_T>,
-            reinterpret_cast<int8_t*>(tmp), tmpsize);
+        ret = decompress_kblock_bit4_packrow1<S4_T, true, 24, float, float>(srcptr, dstptr, row, col, ld_src, ld_dst, scales,
+                                                                            zero_points, k_offset, kblock, NPad,
+                                                                            reinterpret_cast<int8_t*>(tmp), tmpsize);
       } else if (col == 48) {
-        ret = avx2::decompress_kblock_bit4_packrow1(
-            srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad,
-            &avx2::dequant_s8_N_avx2<48, true>, &avx2::convert_s4_s8_16_sse<S4_T>, &ref::convert_s4_s8_8<S4_T>,
-            reinterpret_cast<int8_t*>(tmp), tmpsize);
+        ret = decompress_kblock_bit4_packrow1<S4_T, true, 48, float, float>(srcptr, dstptr, row, col, ld_src, ld_dst, scales,
+                                                                            zero_points, k_offset, kblock, NPad,
+                                                                            reinterpret_cast<int8_t*>(tmp), tmpsize);
+      } else {
+        assert(0);
       }
     } else {
       if (col == 24) {
-        ret = avx2::decompress_kblock_bit4_packrow1(
-            srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad,
-            &avx2::dequant_s8_N_avx2<24, false>, &avx2::convert_s4_s8_16_sse<S4_T>, &ref::convert_s4_s8_8<S4_T>,
-            reinterpret_cast<int8_t*>(tmp), tmpsize);
+        ret = decompress_kblock_bit4_packrow1<S4_T, false, 24, float, float>(srcptr, dstptr, row, col, ld_src, ld_dst, scales,
+                                                                             zero_points, k_offset, kblock, NPad,
+                                                                             reinterpret_cast<int8_t*>(tmp), tmpsize);
      } else if (col == 48) {
-        ret = avx2::decompress_kblock_bit4_packrow1(
-            srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad,
-            &avx2::dequant_s8_N_avx2<48, false>, &avx2::convert_s4_s8_16_sse<S4_T>, &ref::convert_s4_s8_8<S4_T>,
-            reinterpret_cast<int8_t*>(tmp), tmpsize);
+        ret = decompress_kblock_bit4_packrow1<S4_T, false, 48, float, float>(srcptr, dstptr, row, col, ld_src, ld_dst, scales,
+                                                                             zero_points, k_offset, kblock, NPad,
+                                                                             reinterpret_cast<int8_t*>(tmp), tmpsize);
+      } else {
+        assert(0);
      }
    }
  }
@@ -789,20 +766,18 @@ static inline BTLA_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dst
                                                 int8_t* tmp, size_t tmpsize) {
   if constexpr (_PACK_ROW == 1) {
     if (col == 24) {
-      return decompress_kblock_bit4_packrow1(
-          srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad,
-          &dequant_f4_N<24, _DST_T, _F4_T>, fp4_pad_4bit, &ref::convert_s4_s8_8<_F4_T>, tmp, tmpsize);
+      return decompress_kblock_bit4_packrow1<_F4_T, true, 24, _ST, _DST_T>(
+          srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad, tmp, tmpsize);
     }
     if (col == 48) {
-      return decompress_kblock_bit4_packrow1(
-          srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad,
-          &dequant_f4_N<48, _DST_T, _F4_T>, fp4_pad_4bit, &ref::convert_s4_s8_8<_F4_T>, tmp, tmpsize);
+      return decompress_kblock_bit4_packrow1<_F4_T, true, 48, _ST, _DST_T>(
+          srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad, tmp, tmpsize);
     }
   } else if constexpr (_PACK_ROW == 2) {
-    return decompress_kblock_bit4_packrow2(srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr,
-                                           k_offset, kblock, NPad, &dequant_f4_N<64, _DST_T, _F4_T>,
-                                           fp4_pad_4bit, tmp, tmpsize);
+    return decompress_kblock_bit4_packrow2<_F4_T, true, _ST, _DST_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales,
+                                                                     nullptr, k_offset, kblock, NPad, tmp, tmpsize);
   }
+  assert(0);
   return BTLA_CODE::NotSupport;
 }
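Reviewer note (not part of the patch): a minimal scalar sketch of what the new unpack_4bits_avx2<LowBits> helper computes for each packed byte, assuming the low nibble of source byte k holds element 2k and the high nibble holds element 2k+1 (the ordering the tail loops use via tmp.x / tmp.y). The helper name unpack_byte_scalar is hypothetical and only for illustration.

#include <cstdint>

// Scalar model: one packed byte -> two unpacked int8 outputs.
inline void unpack_byte_scalar(uint8_t packed, bool low_bits, int8_t out[2]) {
  uint8_t lo = packed & 0x0f;  // element 2k
  uint8_t hi = packed & 0xf0;  // element 2k+1, already sitting in the high nibble
  if (low_bits) {
    // LowBits == true: emit plain 0..15 codes (the added _mm256_srli_epi16(ymm0, 4) branch).
    out[0] = static_cast<int8_t>(lo);
    out[1] = static_cast<int8_t>(hi >> 4);
  } else {
    // LowBits == false: keep each code in the high nibble (value * 16), matching the
    // 0xf0f0f0f0 mask that callers pass in as vmask.
    out[0] = static_cast<int8_t>(static_cast<uint8_t>(lo << 4));
    out[1] = static_cast<int8_t>(hi);
  }
}

convert_s4_s8_N_avx2<N, QT_T> applies this to N/2 packed bytes at a time; when N is not exactly 32 it
round-trips through a 32-byte stack buffer so the 256-bit stores never write past dstptr + N.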