Hacky nonmult8 for VNNI #90

Open · wants to merge 13 commits into base: master
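Judging from the diff, this draft teaches the AVX512VNNI kernels to handle B matrices whose column count is not a multiple of eight: the main loops run over the trimmed multiple-of-eight columns, a new tail path accumulates the remaining columns, and new RunPartial callbacks write only the valid lane results. A benchmark, benchmarks/non_mult_8.cc, exercises padding, preparation, and a small multiply.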
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -88,7 +88,7 @@ if(INTGEMM_DONT_BUILD_TESTS)
return()
endif()

-foreach(exe benchmark biasmultiply benchmark_quantizer)
+foreach(exe benchmark biasmultiply benchmark_quantizer non_mult_8)
add_executable(${exe} benchmarks/${exe}.cc)
target_link_libraries(${exe} intgemm)
endforeach()
149 changes: 149 additions & 0 deletions benchmarks/non_mult_8.cc
@@ -0,0 +1,149 @@
#include "../intgemm/aligned.h"
#include "intgemm/intgemm_config.h"
#include "../intgemm/avx512_gemm.h"
#include "../intgemm/sse2_gemm.h"
#include "../intgemm/avx2_gemm.h"
#include "../intgemm/ssse3_gemm.h"
#include "../intgemm/intgemm.h"
#include "../intgemm/stats.h"
#include "../intgemm/callbacks.h"
#include <cstdlib> // std::div used in padMatrixTst
#include <string>
#include <random>
#include <iostream>

/************************************************************************************ util ************************************************************************************/
template <class T>
int numDigits(T number) {
int digits = 0;
if (number <= 0) {
digits = 1; // count the minus and take care of the zero case
}
while (number) {
number /= 10;
digits++;
}
return digits;
}

template<class intType>
void printMat(intType * a, size_t rows, size_t cols, std::string name, int digits = 0) {
std::cerr << name << std::endl;
for (size_t i = 0; i < rows; i++) {
for (size_t j = 0; j < cols; j++) {
int numbah = (int)a[i*cols + j];
// Pad for nice printing
int mydigits = digits - numDigits(numbah);
for (int t = 0; t < mydigits; t++) {
std::cerr << ' ';
}
std::cerr << numbah << " ";
}
std::cerr << std::endl;
}
std::cerr << std::endl;
}

template<class intType>
void toColMajor(intType *in, intType * out, size_t rows, size_t cols) {
for (size_t i = 0; i < rows; i++) {
for (size_t j = 0; j < cols; j++) {
out[j*rows + i] = in[i*cols + j];
}
}
}

namespace intgemm {
template <class Routine>
void prepBtst(Index width, Index B_cols, float * in = nullptr) {
AlignedVector<float> B(width * B_cols);

//std::mt19937 gen;
//std::uniform_real_distribution<float> dist(-1.0f, 1.0f);

if (in != nullptr) {
for (Index i = 0; i<width*B_cols; i++) {
B[i] = in[i];
}
} else {
for (Index i = 0; i<width*B_cols; i++) {
B[i] = (float)(i%127);
}
}

float alpha = 127.0f;
float quant_mult = 127.0f / alpha;
//float unquant_mult = 1.0f / (quant_mult*quant_mult);

printMat(B.begin(), width, B_cols, "Raw Mat", 4);

AlignedVector<int8_t> B_prep(B.size());
//AlignedVector<int8_t> B_prep_print(B.size());
Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
printMat(B_prep.begin(), B_cols, width, "Prep Mat", 3);


//toColMajor(B_prep.begin(), B_prep_print.begin(), B_cols, width);
//printMat(B_prep_print.begin(), B_cols, width, "Prep Mat trans", 3);

}

void padMatrixTst(Index width, Index B_cols) {
AlignedVector<float> B(width * B_cols);
std::div_t results = std::div(B_cols, 8);

for (Index i = 0; i<width*B_cols; i++) {
B[i] = (float)(i%127);
}
auto padded = padMatrix(B.begin(), width, B_cols);
printMat(B.begin(), width, B_cols, "Raw Mat", 4);
printMat(padded.begin(), width, 8, "Padded", 4);

auto shrunk = shrinkMat(B.begin(), width, B_cols); // padMatrix/shrinkMat come from other commits in this PR; see the sketch after this file's diff
printMat(shrunk.begin(), width, results.quot*8, "Trimmed", 4);
prepBtst<SSSE3::Kernels8>(width, 8, padded.begin());
}


template <class Routine>
void smallMultTst(Index A_rows, Index width, Index B_cols) {
AlignedVector<float> A(A_rows* width);
AlignedVector<float> B(width * B_cols);
AlignedVector<float> C(A_rows * B_cols);


for (Index i = 0; i<width*B_cols; i++) {
B[i] = (float)(i%127);
}

for (Index i = 0; i<A_rows*width; i++) {
A[i] = (float)(i%127);
}

float alpha = 127.0f;
float quant_mult = 127.0f / alpha;
float unquant_mult = 1.0f / (quant_mult*quant_mult);

printMat(A.begin(), A_rows, width, "Raw A", 3);
printMat(B.begin(), width, B_cols, "Raw B", 3);

AlignedVector<int8_t> A_prep(A.size());
AlignedVector<int8_t> B_prep(B.size());

Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); // A is strictly positive here
Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
printMat(B_prep.begin(), B_cols, width, "Prep Mat B", 3);

Routine::Multiply8Shift((uint8_t*)A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, C.begin()));
printMat(C.begin(), A_rows, B_cols, "Prep Mat C", 5);

}

} // namespace intgemm
int main() {
using namespace intgemm;
//prepBtst<SSSE3::Kernels8>(32, 35);
//prepBtst<AVX512VNNI::Kernels8>(64, 9);
//padMatrixTst(32, 35);
smallMultTst<AVX512VNNI::Kernels8>(2, 64, 9);
}
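padMatrix and shrinkMat are introduced in other commits of this PR and are not shown in this diff. Judging from their use in padMatrixTst above (the padded result is printed as width x 8, the shrunk one as width x quot*8), a plausible sketch follows; the signatures and semantics here are assumptions, not the PR's actual definitions, and it relies on AlignedVector being movable.

namespace intgemm {
// Copy the trailing B_cols % 8 columns into a width x 8 matrix, zero-padding the missing columns.
AlignedVector<float> padMatrix(const float *B, Index width, Index B_cols) {
  Index rem = B_cols % 8;
  Index first = B_cols - rem; // first column of the remainder block
  AlignedVector<float> out(width * 8);
  for (Index r = 0; r < width; ++r)
    for (Index c = 0; c < 8; ++c)
      out[r * 8 + c] = (c < rem) ? B[r * B_cols + first + c] : 0.0f;
  return out;
}

// Copy the leading multiple-of-8 columns into their own, smaller matrix.
AlignedVector<float> shrinkMat(const float *B, Index width, Index B_cols) {
  Index cols = (B_cols / 8) * 8;
  AlignedVector<float> out(width * cols);
  for (Index r = 0; r < width; ++r)
    for (Index c = 0; c < cols; ++c)
      out[r * cols + c] = B[r * B_cols + c];
  return out;
}
} // namespace intgemm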
2 changes: 2 additions & 0 deletions intgemm/aligned.h
@@ -39,7 +39,9 @@ template <class T> class AlignedVector {
}

AlignedVector(const AlignedVector&) = delete;
AlignedVector(AlignedVector&) = delete;
AlignedVector& operator=(const AlignedVector&) = delete;
AlignedVector& operator=(AlignedVector&) = delete;

~AlignedVector() {
#ifdef _MSC_VER
7 changes: 5 additions & 2 deletions intgemm/avx512_gemm.h
@@ -254,19 +254,22 @@ struct Kernels8 {
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */

INTGEMM_AVX512BW static void QuantizeU(const float *input, uint8_t *output, float quant_mult, Index size) {
-assert(size % 16 == 0);
+std::div_t result = std::div(size, 16);
assert(reinterpret_cast<uintptr_t>(input) % 64 == 0);
const __m512i pos127 = _mm512_set1_epi32(127);
const __m512i zero = _mm512_setzero_si512();
const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult);
-const float *end = input + size;
+const float *end = input + result.quot*16; // Do the majority using AVX512
for (; input < end; input += 16, output += 16) {
__m512i asint = QuantizerGrab(input, quant_mult_reg);
asint = _mm512_min_epi32(asint, pos127);
asint = _mm512_add_epi32(asint, pos127);
asint = _mm512_max_epi32(asint, zero);
_mm512_mask_cvtusepi32_storeu_epi8(output, 0xffff, asint);
}
for (int i = 0; i < result.rem; i++) { // Fill in the remainder scalar-wise (see the scalar reference after this file's diff)
float q = std::min(roundf(input[i]*quant_mult), 127.0f) + 127.0f;
output[i] = static_cast<uint8_t>(std::max(q, 0.0f));
}
}

// Tile size for B; B must be a multiple of this block size.
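The new tail in QuantizeU mirrors the vector path: scale and round, clamp to 127 from above, shift by +127 into unsigned range, clamp to 0 from below. A scalar reference for one lane, useful for checking the tail (a sketch, not part of the PR):

#include <algorithm>
#include <cmath>
#include <cstdint>

inline uint8_t QuantizeURef(float value, float quant_mult) {
  float q = std::round(value * quant_mult);       // QuantizerGrab: scale and round
  q = std::min(q, 127.0f) + 127.0f;               // min with 127, then shift by 127
  return static_cast<uint8_t>(std::max(q, 0.0f)); // max with zero; result fits 0..254
}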
64 changes: 60 additions & 4 deletions intgemm/avx512vnni_gemm.h
@@ -83,15 +83,21 @@ struct Kernels8 : public AVX512BW::Kernels8 {
template <typename Callback>
INTGEMM_AVX512VNNI static void Multiply8Shift(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
assert(width % sizeof(Register) == 0);
-assert(B_cols % 8 == 0);
+std::div_t results = std::div(B_cols, 8);
+Index B_cols_trimmed = B_cols;
+if (results.rem != 0) {
+B_cols_trimmed = results.quot*8;
+}
+assert(B_cols_trimmed % 8 == 0);
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Register) == 0);
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
const Index simd_width = width / sizeof(Register);
Register zeros = setzero_si<Register>();
// Go over 8 columns of B at a time.
Index B0_colidx = 0; // OMP can't deal with this variable being assigned outside of the loop, hence we declare it once and assign 0 twice
#pragma omp for
-for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+for (B0_colidx = 0; B0_colidx < B_cols_trimmed; B0_colidx += 8) {
const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
// Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
@@ -119,20 +125,51 @@ struct Kernels8 : public AVX512BW::Kernels8 {
callback_impl.Run(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols));
}
}
// Final columns, if B_cols is not a multiple of eight
if (results.rem != 0) {
const Register *B0_col = reinterpret_cast<const Register*>(B) + (B_cols_trimmed * width)/(sizeof(Register));
// Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
// Iterate over shared (inner) dimension.
const Register *A_live = reinterpret_cast<const Register *>(A + A_rowidx * width);
const Register *A_end = A_live + simd_width;
const Register *B_live = B0_col;
// TODO: separate first step.
Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros;
Register * sums[8] = {&sum0, &sum1, &sum2, &sum3, &sum4, &sum5, &sum6, &sum7};
for (; A_live != A_end; ++A_live, B_live += results.rem) {
Register a = *A_live;
//MultiplyAdd
for (int i = 0; i < results.rem; i++) {
VNNI8(*sums[i], a,*(B_live + i));
}
}
Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
auto total = PermuteSummer(pack0123, pack4567);
callback_impl.RunPartial(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols), (Index)results.rem);
}
}
}

template <typename Callback>
INTGEMM_AVX512VNNI static void PrepareBias(const int8_t *B, Index width, Index B_cols, Callback callback) {
assert(width % sizeof(Register) == 0);
-assert(B_cols % 8 == 0);
+std::div_t results = std::div(B_cols, 8);
+Index B_cols_trimmed = B_cols;
+if (results.rem != 0) {
+B_cols_trimmed = results.quot*8;
+}
+assert(B_cols_trimmed % 8 == 0);
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
Index simd_width = width / sizeof(Register);
Register zeros = setzero_si<Register>();
const Register a = set1_epi8<Register>(1);
// Go over 8 columns of B at a time.
Index B0_colidx = 0; // OMP can't deal with this variable being assigned outside of the loop, hence we declare it once and assign 0 twice
#pragma omp for
-for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+for (B0_colidx = 0; B0_colidx < B_cols_trimmed; B0_colidx += 8) {
const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
const Register *B_live = B0_col; // Keep the code as close as possible to the function above
const Register *B_end = B_live + simd_width*8;
Expand All @@ -155,6 +192,25 @@ struct Kernels8 : public AVX512BW::Kernels8 {
auto total = PermuteSummer(pack0123, pack4567);
callback_impl.Run(total, callbacks::OutputBufferInfo(0, B0_colidx, 1, B_cols));
}
// Final columns, if B_cols is not a multiple of eight
if (results.rem != 0) {
const Register *B0_col = reinterpret_cast<const Register*>(B) + (B_cols_trimmed * width)/(sizeof(Register));
const Register *B_live = B0_col; // Keep the code as close as possible to the function above
const Register *B_end = B_live + simd_width*results.rem;

// TODO: separate first step.
Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros;
Register * sums[8] = {&sum0, &sum1, &sum2, &sum3, &sum4, &sum5, &sum6, &sum7};
for (; B_live != B_end; B_live += results.rem) {
for (int i = 0; i < results.rem; i++) {
VNNI8(*sums[i], a,*(B_live + i));
}
}
Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
auto total = PermuteSummer(pack0123, pack4567);
callback_impl.RunPartial(total, callbacks::OutputBufferInfo(0, B0_colidx, 1, B_cols), (Index)results.rem);
}
}

constexpr static const char *const kName = "8-bit AVX512VNNI";
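For example, with B_cols = 9 the trimmed loops above cover columns 0 through 7 and the tail path computes only column 8: each slice of the inner dimension supplies results.rem column registers of B, the unused accumulators stay zero, and RunPartial writes just results.rem of the eight lane results. PrepareBias follows the same pattern with A replaced by a register of ones, so its tail produces the column sums needed for the bias correction.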
53 changes: 53 additions & 0 deletions intgemm/callbacks/implementations.inl
@@ -147,6 +147,18 @@ public:
kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
}

INTGEMM_TARGET void RunPartial(vi input, const OutputBufferInfo& info, Index partial) {
// Workaround gcc 5 internal compiler error that can't read register members in debug.
vf mult_reg;
#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
#else
mult_reg = unquant_mult;
#endif
auto result = kernels::unquantize(input, mult_reg);
kernels::write_partial(result, config.output_addr, info.row_idx * info.cols + info.col_idx, partial);
}

private:
vf unquant_mult;
UnquantizeAndWrite config;
@@ -172,6 +184,17 @@ public:
auto result = kernels::relu<float>(kernels::unquantize(input, mult_reg));
kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
}
INTGEMM_TARGET void RunPartial(vi input, const OutputBufferInfo& info, Index partial) {
// Workaround gcc 5 internal compiler error that can't read register members in debug.
vf mult_reg;
#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
#else
mult_reg = unquant_mult;
#endif
auto result = kernels::relu<float>(kernels::unquantize(input, mult_reg));
kernels::write_partial(result, config.output_addr, info.row_idx * info.cols + info.col_idx, partial);
}

private:
vf unquant_mult;
Expand All @@ -191,6 +214,11 @@ public:
kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
}

INTGEMM_TARGET void RunPartial(vi input, const OutputBufferInfo& info, Index partial) {
auto result = kernels::add_bias_partial(input, config.bias_addr, info.col_idx, partial);
kernels::write_partial(result, config.output_addr, info.row_idx * info.cols + info.col_idx, partial);
}

private:
AddBiasAndWrite config;
};
@@ -216,6 +244,18 @@ public:
result = kernels::add_bias(result, config.bias_addr, info.col_idx);
kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
}
INTGEMM_TARGET void RunPartial(vi input, const OutputBufferInfo& info, Index partial) {
// Workaround gcc 5 internal compiler error that can't read register members in debug.
vf mult_reg;
#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
#else
mult_reg = unquant_mult;
#endif
auto result = kernels::unquantize(input, mult_reg);
result = kernels::add_bias_partial(result, config.bias_addr, info.col_idx, partial);
kernels::write_partial(result, config.output_addr, info.row_idx * info.cols + info.col_idx, partial);
}
private:
vf unquant_mult;
UnquantizeAndAddBiasAndWrite config;
@@ -243,6 +283,19 @@ public:
result = kernels::relu<float>(result);
kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
}
INTGEMM_TARGET void RunPartial(vi input, const OutputBufferInfo& info, Index partial) {
// Workaround gcc 5 internal compiler error that can't read register members in debug.
vf mult_reg;
#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
#else
mult_reg = unquant_mult;
#endif
auto result = kernels::unquantize(input, mult_reg);
result = kernels::add_bias_partial(result, config.bias_addr, info.col_idx, partial);
result = kernels::relu<float>(result);
kernels::write_partial(result, config.output_addr, info.row_idx * info.cols + info.col_idx, partial);
}
private:
vf unquant_mult;
UnquantizeAndAddBiasAndWriteRelu config;
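The RunPartial callbacks rely on kernels::write_partial and kernels::add_bias_partial, which are added elsewhere in this PR and not shown here. A minimal sketch of the intended semantics, assuming the AVX2-width register types (8 x float) these callbacks operate on; the names and signatures are assumptions, and the real kernels are presumably masked and branch-free:

#include <immintrin.h>

// Store only the first `partial` lanes of `result` at output[offset...].
static inline void write_partial(__m256 result, float *output, Index offset, Index partial) {
  float tmp[8];
  _mm256_storeu_ps(tmp, result); // spill all 8 lanes
  for (Index i = 0; i < partial; ++i)
    output[offset + i] = tmp[i]; // copy only the valid ones
}

// Add bias to the first `partial` lanes; the remaining lanes gain zero.
static inline __m256 add_bias_partial(__m256 result, const float *bias, Index offset, Index partial) {
  float tmp[8] = {0.0f}; // zero so invalid lanes are unchanged
  for (Index i = 0; i < partial; ++i)
    tmp[i] = bias[offset + i];
  return _mm256_add_ps(result, _mm256_loadu_ps(tmp));
}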