Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
Sync itrex1.3 (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
airMeng authored Dec 22, 2023
1 parent 799f67c commit 6d8bb4a
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 59 deletions.
9 changes: 3 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,7 @@ option(NE_AVX512_VBMI "neural_engine: enable AVX512-VBMI"
option(NE_AVX512_VNNI "neural_engine: enable AVX512-VNNI" OFF)
option(NE_FMA "neural_engine: enable FMA" ON)
option(NE_AMX "neural_engine: enable AMX" OFF)

# in MSVC F16C is implied with AVX2/AVX512
if (NOT MSVC)
option(NE_F16C "neural_engine: enable F16C" ON)
endif()
option(NE_F16C "neural_engine: enable F16C" ON)

# 3rd party libs
option(NE_ONEDNN "neural_engine: use oneDNN" ON)
Expand Down Expand Up @@ -93,6 +89,8 @@ if (NE_GELU_VEC)
endif()
# Build the Python API; disabled by default.
option(NE_PYTHON_API "neural_engine: use python api" OFF)
# SIMD-optimized fp16 vec_dot; enabled by default.
option(NE_SIMD_VEC_DOT_F16 "neural_engine: enable vec_dot_fp16 SIMD optimization" ON)
# BUILD_SHARED_LIBS is the standard CMake switch that selects the default
# library type for add_library() calls without an explicit STATIC/SHARED.
# It is declared here as a user-visible option that defaults to ON.
option(BUILD_SHARED_LIBS "If build as shared libs" ON)

# Expose the option to the sources as the NE_SIMD_VEC_DOT_F16 preprocessor
# definition (directory scope: applies to targets defined from here on).
if (NE_SIMD_VEC_DOT_F16)
add_compile_definitions(NE_SIMD_VEC_DOT_F16)
endif()
Expand All @@ -103,7 +101,6 @@ endif()

if (MSVC)
add_compile_definitions(_CRT_SECURE_NO_WARNINGS NOMINMAX)

if (BUILD_SHARED_LIBS)
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif()
Expand Down
21 changes: 12 additions & 9 deletions bestla/jblas/jit_blas_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ class SchedulerBase : public Scheduler2D {
mL2Use += static_cast<size_t>(mBlock[1]) * mBlock[2] * mEleSize[1];
mL2Use += static_cast<size_t>(mStep[0]) * mBlock[2] * mEleSize[0];
}
const float DensityThres = 32;
const float DensityThres = 16;
static size_t constexpr ReservedSize = 32ULL * 1024ULL;

virtual float calculate_score() {
Expand Down Expand Up @@ -364,7 +364,7 @@ class SchedulerKBlock : public Scheduler2D {
mL2Use += static_cast<size_t>(mBlock[1]) * mBlock[2] * mEleSize[1];
mL2Use += static_cast<size_t>(mStep[0]) * mBlock[2] * mEleSize[0];
}
const float DensityThres = 32;
const float DensityThres = 16;

float calculate_score() {
int tmpnstep = mThdSize[1] < _GemmCore_T::PREFERRED_N ? mThdSize[1] : _GemmCore_T::PREFERRED_N;
Expand Down Expand Up @@ -489,13 +489,14 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
this->mL2Use += static_cast<size_t>(blks) * (this->mBlock[1] + this->mStep[0]) *
(sizeof(float) + sizeof(int8_t) + sizeof(float)); // scale+zp+reduce
assert(this->mL2Use <= this->mL2Size - ReservedSize);
assert(this->mBlock[0]>0);
assert(this->mBlock[1]>0);
assert(this->mBlock[2]>0);
assert(this->mBlock[0] > 0);
assert(this->mBlock[1] > 0);
assert(this->mBlock[2] > 0);
assert(this->mBlock[2] % _GemmCore_T::KTILE == 0);
}

protected:
const float DensityThres = 32;
const float DensityThres = 16;
static size_t constexpr ReservedSize = 32ULL * 1024ULL;

void cache_blocking_compute() override {
Expand Down Expand Up @@ -529,6 +530,11 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
(this->mStep[0] * this->mEleSize[0] +
float(CorSize * (this->mStep[0] + this->mBlock[1])) / this->mKBlock +
this->mBlock[1] * this->mEleSize[1]));
if (rawk < this->mKBlock) {
rawk = static_cast<int>((valid_total - this->mBlock[0] * this->mBlock[1] * this->mEleSize[2] -
1 * CorSize * (this->mStep[0] + this->mBlock[1])) /
(this->mStep[0] * this->mEleSize[0] + this->mBlock[1] * this->mEleSize[1]));
}
rawk = std::min(rawk, this->mSizePadded[2]);
this->mBlock[2] = utils::padto_le(rawk, this->mStep[2]);
if (this->mBlock[2] > this->mKBlock) {
Expand Down Expand Up @@ -569,9 +575,6 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
this->mBlock[2] = static_cast<int>(getMaxK(this->mBlock[1]));
this->mBlock[2] = utils::padto_le(this->mBlock[2], this->mStep[2]);
this->mBlock[2] = std::min(mKBlock, this->mBlock[2]);
auto tmp = utils::updiv(mKBlock, this->mBlock[2]);
while (mKBlock % tmp != 0) tmp++; // TODO(Yu) optimize
this->mBlock[2] = utils::downdiv(mKBlock, tmp);
}
}

Expand Down
20 changes: 16 additions & 4 deletions bestla/jblas/kernel_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,14 @@ inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int
for (; j < align_col; j += 8) quant();
for (; j < col; j++) {
auto fp_v = ref::f8_to_fp32(srcptr[i * ld_src + j], src_f8_type);
if constexpr (std::is_same_v<_S_T, utils::f8>) {
dstptr[i * ld_dst + j] = fp_v * std::pow(2, sptr[j / _PACK_ROW].x);
} else if constexpr (std::is_same_v<_S_T, float>) {
dstptr[i * ld_dst + j] = fp_v * sptr[j / _PACK_ROW];
if constexpr (WITH_SCALE) {
if constexpr (std::is_same_v<_S_T, utils::f8>) {
dstptr[i * ld_dst + j] = fp_v * std::pow(2, sptr[j / _PACK_ROW].x);
} else if constexpr (std::is_same_v<_S_T, float>) {
dstptr[i * ld_dst + j] = fp_v * sptr[j / _PACK_ROW];
}
} else {
dstptr[i * ld_dst + j] = fp_v;
}
}
}
Expand Down Expand Up @@ -636,6 +640,14 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow1(
vzps[iv] = _mm256_cvtepi8_epi32(tmp);
}
}
auto rowre = row - irow;
int rowpad4 = utils::padto_le(rowre, UnrollRow) + irow;
for (; irow < rowpad4; irow += UnrollRow) {
for (int iter16 = 0; iter16 < Loop16; iter16++)
pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8 * iter16));
for (int iterr = 0; iterr < UnrollRow; iterr++)
dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * _NCOL, vscales, vzps);
}
for (; irow < row; irow++) {
if constexpr (_NCOL == 24) {
pad_bit4_16(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2));
Expand Down
51 changes: 33 additions & 18 deletions bestla/jblas/kernel_avx512f.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,13 +321,28 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr,
vzps[iv] = _mm512_cvtepi8_epi32(tmp);
}
}
}
for (; irow < row; irow++) {
pad_bit4(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2), zmm_mask, LoadMask48);
if constexpr (_IS_SYM) {
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, nullptr);
} else {
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
auto rowre = row - irow;
int rowpad4 = utils::padto_le(rowre, UnrollRow) + irow;
for (; irow < rowpad4; irow += UnrollRow) {
for (int iter64 = 0; iter64 < Loop64; iter64++) {
pad_bit4(tmpbuf + iter64 * 64, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 32 * iter64), zmm_mask,
LoadMask64);
}
for (int iterr = 0; iterr < UnrollRow; iterr++) {
if constexpr (_IS_SYM) {
dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, nullptr);
} else {
dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, vzps);
}
}
}
for (; irow < row; irow++) {
pad_bit4(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2), zmm_mask, LoadMask48);
if constexpr (_IS_SYM) {
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, nullptr);
} else {
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
}
}
}
return JblasSuccess;
Expand Down Expand Up @@ -565,7 +580,7 @@ inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int
auto quant = [&](__mmask16 mask) {
__m128i f8_src;
auto sign_revert =
_mm512_cvtepi8_epi32(_mm_mask_loadu_epi8(f8_src, mask, reinterpret_cast<__m128i*>(srcptr + i * ld_src + j)));
_mm512_cvtepi8_epi32(_mm_maskz_loadu_epi8(mask, reinterpret_cast<__m128i*>(srcptr + i * ld_src + j)));
auto e_revert = sign_revert;
auto mantissa_revert = sign_revert;
sign_revert = _mm512_slli_epi32(sign_revert, 24);
Expand Down Expand Up @@ -888,10 +903,10 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src
zmm2 = _mm512_add_ps(zmm2, zmm_zp);
zmm3 = _mm512_add_ps(zmm3, zmm_zp);
} else {
mask4 = _mm512_cmplt_ps_mask(zmm0, zmm_v0);
mask5 = _mm512_cmplt_ps_mask(zmm1, zmm_v0);
mask6 = _mm512_cmplt_ps_mask(zmm2, zmm_v0);
mask7 = _mm512_cmplt_ps_mask(zmm3, zmm_v0);
mask4 = _mm512_cmp_ps_mask(zmm0, zmm_v0, 1);
mask5 = _mm512_cmp_ps_mask(zmm1, zmm_v0, 1);
mask6 = _mm512_cmp_ps_mask(zmm2, zmm_v0, 1);
mask7 = _mm512_cmp_ps_mask(zmm3, zmm_v0, 1);

zmm0 = _mm512_abs_ps(zmm0);
zmm1 = _mm512_abs_ps(zmm1);
Expand All @@ -908,10 +923,10 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src
zmm5 = _mm512_sub_ps(zmm1, sub_v);
zmm6 = _mm512_sub_ps(zmm2, sub_v);
zmm7 = _mm512_sub_ps(zmm3, sub_v);
mask0 = _mm512_cmple_ps_mask(zmm4, zmm_v0);
mask1 = _mm512_cmple_ps_mask(zmm5, zmm_v0);
mask2 = _mm512_cmple_ps_mask(zmm6, zmm_v0);
mask3 = _mm512_cmple_ps_mask(zmm7, zmm_v0);
mask0 = _mm512_cmp_ps_mask(zmm4, zmm_v0, 2);
mask1 = _mm512_cmp_ps_mask(zmm5, zmm_v0, 2);
mask2 = _mm512_cmp_ps_mask(zmm6, zmm_v0, 2);
mask3 = _mm512_cmp_ps_mask(zmm7, zmm_v0, 2);
xmm0 = _mm_mask_blend_epi8(mask0, xmm0, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
xmm1 = _mm_mask_blend_epi8(mask1, xmm1, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
xmm2 = _mm_mask_blend_epi8(mask2, xmm2, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
Expand Down Expand Up @@ -949,7 +964,7 @@ inline void f32_f4_quantize_1x16(const float* srcptr, int8_t* dstptr, int ld_src
auto zp = _mm512_set1_ps(0.8480964004993439f);
zmm0 = _mm512_add_ps(zmm0, zp);
} else {
mask1 = _mm512_cmplt_ps_mask(zmm0, zmm_v0);
mask1 = _mm512_cmp_ps_mask(zmm0, zmm_v0, 1);
zmm0 = _mm512_abs_ps(zmm0);
}
constexpr int loop_num = F4_T == JBLAS_DTYPE::F4_NF4 ? 16 : 8;
Expand All @@ -959,7 +974,7 @@ inline void f32_f4_quantize_1x16(const float* srcptr, int8_t* dstptr, int ld_src
if constexpr (F4_T == JBLAS_DTYPE::F4_BNB) sub_v = _mm512_set1_ps(F4_BNB_quant_sub_helper[i]);
if constexpr (F4_T == JBLAS_DTYPE::F4_E2M1) sub_v = _mm512_set1_ps(F4_E2M1_quant_sub_helper[i]);
zmm1 = _mm512_sub_ps(zmm0, sub_v);
mask0 = _mm512_cmple_ps_mask(zmm1, zmm_v0);
mask0 = _mm512_cmp_ps_mask(zmm1, zmm_v0, 2);
xmm0 = _mm_mask_blend_epi8(mask0, xmm0, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
zmm0 = _mm512_mask_add_ps(zmm0, mask0, zmm0, avoid_double_cmp);
}
Expand Down
56 changes: 39 additions & 17 deletions bestla/jblas/kernel_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,25 +230,47 @@ inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) {
dstptr[7] = tmp;
}

// Unpacks eight 4-bit fields from the 4 bytes at srcptr into dstptr[0..7].
// Each output is the raw nibble value in [0, 15]; no sign extension and no
// offset is applied here. Field i is taken from bits [4*i, 4*i+3] of the
// 32-bit word read from srcptr, so the nibble order follows the host byte
// order of that word.
inline void convert_s4_s8_8_lowbits(int8_t* dstptr, int8_t* srcptr) {
  // Load all eight packed nibbles at once before any output is written.
  const uint32_t packed = *reinterpret_cast<uint32_t*>(srcptr);
  for (int idx = 0; idx < 8; ++idx) {
    // Shift the wanted field down to the low bits, then mask it out.
    const uint32_t nibble = (packed >> (4 * idx)) & 0xfu;
    dstptr[idx] = static_cast<int8_t>(nibble);
  }
}

template <>
inline void convert_s4_s8_8<JBLAS_DTYPE::S4_FULLRANGE>(int8_t* dstptr, int8_t* srcptr) {
auto src32 = *reinterpret_cast<uint32_t*>(srcptr);
auto tmp = static_cast<int8_t>(src32 & 0xf);
dstptr[0] = tmp - 8;
tmp = static_cast<int8_t>(src32 & 0xf0) >> 4;
dstptr[1] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf00) >> 8);
dstptr[2] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf000) >> 12);
dstptr[3] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf0000) >> 16);
dstptr[4] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf00000) >> 20);
dstptr[5] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf000000) >> 24);
dstptr[6] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf0000000) >> 28);
dstptr[7] = tmp - 8;
convert_s4_s8_8_lowbits(dstptr, srcptr);
for (size_t i = 0; i < 8; i++) {
dstptr[i] -= 8;
}
}

// F4_BNB specialization: forwards to convert_s4_s8_8_lowbits, so the eight
// outputs are the raw unsigned nibble values (0..15) with no offset applied.
// NOTE(review): presumably these values are later used as F4 code indices - confirm against callers.
template <>
inline void convert_s4_s8_8<JBLAS_DTYPE::F4_BNB>(int8_t* dstptr, int8_t* srcptr) {
convert_s4_s8_8_lowbits(dstptr, srcptr);
}

// F4_NF4 specialization: forwards to convert_s4_s8_8_lowbits, so the eight
// outputs are the raw unsigned nibble values (0..15) with no offset applied.
// NOTE(review): presumably these values are later used as F4 code indices - confirm against callers.
template <>
inline void convert_s4_s8_8<JBLAS_DTYPE::F4_NF4>(int8_t* dstptr, int8_t* srcptr) {
convert_s4_s8_8_lowbits(dstptr, srcptr);
}

// F4_E2M1 specialization: forwards to convert_s4_s8_8_lowbits, so the eight
// outputs are the raw unsigned nibble values (0..15) with no offset applied.
// NOTE(review): presumably these values are later used as F4 code indices - confirm against callers.
template <>
inline void convert_s4_s8_8<JBLAS_DTYPE::F4_E2M1>(int8_t* dstptr, int8_t* srcptr) {
convert_s4_s8_8_lowbits(dstptr, srcptr);
}

template <JBLAS_DTYPE S4_T>
Expand Down
20 changes: 18 additions & 2 deletions neural_speed/cmake/Common.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,25 @@ function(add_executable_w_warning TARGET)
warning_check(${TARGET})
endfunction()

function(add_library_w_warning TARGET)
add_library(${TARGET} STATIC ${ARGN})
# Common implementation behind the *_library_w_warning helpers.
#
#   TARGET - name of the library target to create
#   ARGN   - forwarded verbatim to add_library(); may start with a library
#            type keyword (STATIC/SHARED) followed by the source files
#
# The target is pinned to C11 / C++11 with compiler extensions disabled and
# is then passed to warning_check().
function(add_library_w_warning_ TARGET)
  add_library(${TARGET} ${ARGN})
  set_target_properties(${TARGET} PROPERTIES
    C_STANDARD 11
    C_STANDARD_REQUIRED ON
    C_EXTENSIONS OFF
    CXX_STANDARD 11
    CXX_STANDARD_REQUIRED ON
    CXX_EXTENSIONS OFF
  )
  warning_check(${TARGET})
endfunction()

# Create TARGET as a STATIC library from the sources in ARGN, regardless of
# BUILD_SHARED_LIBS, via add_library_w_warning_().
function(add_library_w_warning TARGET)
add_library_w_warning_(${TARGET} STATIC ${ARGN})
endfunction()

# Create TARGET as a SHARED library from the sources in ARGN, regardless of
# BUILD_SHARED_LIBS, via add_library_w_warning_().
function(add_shared_library_w_warning TARGET)
add_library_w_warning_(${TARGET} SHARED ${ARGN})
endfunction()

# Create TARGET from the sources in ARGN as a library whose type follows
# BUILD_SHARED_LIBS: SHARED when it is true, STATIC otherwise.
function(add_shareable_library_w_warning TARGET)
  # Default to STATIC and switch only when shared builds are requested.
  set(ne_library_type STATIC)
  if (BUILD_SHARED_LIBS)
    set(ne_library_type SHARED)
  endif()
  add_library_w_warning_(${TARGET} ${ne_library_type} ${ARGN})
endfunction()
3 changes: 3 additions & 0 deletions neural_speed/cmake/ISA.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.

if (MSVC)
# On MSVC, define the __F16C__ macro by hand when NE_F16C is enabled.
# NOTE(review): presumably needed because MSVC has no dedicated F16C switch
# and does not predefine this macro itself - confirm.
if(NE_F16C)
add_compile_definitions(__F16C__)
endif()
if (NE_AVX512)
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
Expand Down
2 changes: 1 addition & 1 deletion neural_speed/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ find_package(Threads REQUIRED)
file(GLOB layers_srcs "layers/*.cpp")
set(sources ne_layers.c ${layers_srcs})

add_library_w_warning(ne_layers "${sources}")
add_shareable_library_w_warning(ne_layers "${sources}")

target_include_directories(ne_layers PUBLIC .)
target_compile_features(ne_layers PUBLIC c_std_11) # don't bump
Expand Down
4 changes: 2 additions & 2 deletions neural_speed/scripts/convert_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,8 +855,8 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None)


SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {'F16': DT_F16, 'F32': DT_F32, 'I32': DT_I32, 'BOOL': DT_BOOL}

SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {'F16': DT_F16, 'F32': DT_F32, 'I32': DT_I32, 'BOOL': DT_BOOL,
'BF16': DT_BF16}

def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
header_size, = struct.unpack('<Q', fp.read(8))
Expand Down

0 comments on commit 6d8bb4a

Please sign in to comment.