diff --git a/bestla/bestla/kernel_avx2.h b/bestla/bestla/kernel_avx2.h index 02d9bae29..b3b16becb 100644 --- a/bestla/bestla/kernel_avx2.h +++ b/bestla/bestla/kernel_avx2.h @@ -555,7 +555,15 @@ static inline void dequant_f4_N(_DST_T* dstptr, int8_t* srcptr, __m256* vscales, for (int iv = 0; iv < VLoop; iv++) { auto idx = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr + iv * 8)); auto pad_idx = _mm256_cvtepu8_epi32(idx); - auto fp32_dq_v = _mm256_i32gather_ps(LUT, pad_idx, 4); +#if 0 + auto fp32_dq_v = _mm256_i32gather_ps(LUT, pad_idx, 4); +#else + auto mskgt8 = _mm256_cmpgt_epi32(pad_idx, v8); + auto fp32_dq_v0 = _mm256_permutevar8x32_ps(vLutL, pad_idx); + pad_idx = _mm256_sub_epi32(pad_idx, v8); + auto fp32_dq_v1 = _mm256_permutevar8x32_ps(vLutH, pad_idx); + auto fp32_dq_v = _mm256_blendv_ps(vLutL, vLutH, _mm256_castsi256_ps(mskgt8)); +#endif fp32_dq_v = _mm256_mul_ps(fp32_dq_v, vscales[iv]); if constexpr (std::is_same_v<_DST_T, float>) { _mm256_storeu_ps(dstptr + iv * 8, fp32_dq_v);