Skip to content

Commit

Permalink
[ BLAS ] Implement transpose case functions for K=1 GEMM
Browse files Browse the repository at this point in the history
- To cover transpose cases like, (1,M).T * (1,N) and all other transpose combinations, transpose with SIMD, and apply the original kernel

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <[email protected]>
  • Loading branch information
skykongkong8 committed Jun 28, 2024
1 parent 7abcbd3 commit a2d0536
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 9 deletions.
20 changes: 17 additions & 3 deletions nntrainer/tensor/blas_neon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1589,9 +1589,7 @@ unsigned int isamax(const unsigned int N, const __fp16 *X) {
void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
uint32_t K, float alpha, float beta, bool TransA, bool TransB) {
if (K == 1) {
unsigned int lda = (TransA) ? M : K;
unsigned int ldb = (TransB) ? K : N;
return hgemm_K1(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
return hgemm_K1(A, B, C, M, N, K, alpha, beta, TransA, TransB);
}
// dynamic creation to avoid reaching stack limit(causes segmentation fault)
float *C32 = (float *)malloc(M * N * sizeof(float));
Expand Down Expand Up @@ -1644,6 +1642,22 @@ void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
free(C32);
}

void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
uint32_t N, uint32_t K, float alpha, float beta, bool TransA,
bool TransB) {
unsigned int lda = (TransA) ? M : K;
unsigned int ldb = (TransB) ? K : N;
if (!TransA && TransB) {
hgemm_K1_transB(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
} else if (TransA && !TransB) {
hgemm_K1_transA(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
} else if (!TransA && !TransB) {
hgemm_K1_noTrans(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
} else { // TransA && TransB
hgemm_K1_transAB(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
}
}

void ele_mul(const unsigned int N, const __fp16 *X, const __fp16 *Y, __fp16 *Z,
float alpha, float beta) {
unsigned int i = 0;
Expand Down
16 changes: 16 additions & 0 deletions nntrainer/tensor/blas_neon.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,22 @@ unsigned int isamax(const unsigned int N, const __fp16 *X);
void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
uint32_t K, float alpha, float beta, bool TransA, bool TransB);

/**
* @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
* where op(X) is one of X or X**T
* @param[in] A __fp16 * for Matrix A
* @param[in] B __fp16 * for Matrix B
* @param[in] C __fp16 * for Matrix C
* @param[in] M number of op(A)'s and C's row
* @param[in] N number of op(B)'s and C's columns
* @param[in] K number of op(A)'s and columns and op(B)'s rows
* @param[in] alpha float number
* @param[in] beta float number
*/
void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
uint32_t N, uint32_t K, float alpha, float beta, bool TransA,
bool TransB);

/**
* @brief squared root transformation with neon : X = sqrt(X)
*
Expand Down
61 changes: 56 additions & 5 deletions nntrainer/tensor/hgemm/hgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,23 +76,74 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
}
}

void hgemm_K1(unsigned int M, unsigned int N, unsigned int K,
void hgemm_K1_noTrans(unsigned int M, unsigned int N, unsigned int K,
const __fp16 *A, unsigned int lda, const __fp16 *B,
unsigned int ldb, __fp16 *C, unsigned int ldc,
float alpha, float beta) {
const float eps = std::numeric_limits<float>::epsilon();
float16x8_t a_vec;
unsigned int N8 = (N >> 3) << 3;
for (unsigned int m = 0; m < M; ++m) {
a_vec = vmovq_n_f16(A[m]);
for (unsigned int n = 0; n < N8; n += 8) {
vst1q_f16(&C[m * ldc + n], vmulq_f16(a_vec, vld1q_f16(&B[n])));
a_vec = vmovq_n_f16(alpha * A[m]);
if (std::fpclassify(beta) != FP_ZERO) {
for (unsigned int n = 0; n < N8; n += 8) {
vst1q_f16(&C[m * ldc + n],
vaddq_f16(vmulq_f16(a_vec, vld1q_f16(&B[n])),
vmulq_n_f16(vld1q_f16(&C[m * ldc + n]), beta)));
}
} else {
for (unsigned int n = 0; n < N8; n += 8) {
vst1q_f16(&C[m * ldc + n], vmulq_f16(a_vec, vld1q_f16(&B[n])));
}
}
for (unsigned int n = N8; n < N; ++n) {
C[m * ldc + n] = A[m] * B[n];
C[m * ldc + n] = alpha * A[m] * B[n] + beta * C[m * ldc + n];
}
}
}

void hgemm_K1_transA(unsigned int M, unsigned int N, unsigned int K,
const __fp16 *A, unsigned int lda, const __fp16 *B,
unsigned int ldb, __fp16 *C, unsigned int ldc, float alpha,
float beta) {
__fp16 *A_T = new __fp16[M * K];

transpose_neon<__fp16>(K, M, A, M, A_T, K);

hgemm_K1_noTrans(M, N, K, A_T, lda, B, ldb, C, ldc, alpha, beta);

free(A_T);
}

void hgemm_K1_transB(unsigned int M, unsigned int N, unsigned int K,
const __fp16 *A, unsigned int lda, const __fp16 *B,
unsigned int ldb, __fp16 *C, unsigned int ldc, float alpha,
float beta) {
__fp16 *B_T = new __fp16[K * N];

transpose_neon<__fp16>(N, K, B, K, B_T, N);

hgemm_K1_noTrans(M, N, K, A, lda, B_T, ldb, C, ldc, alpha, beta);

free(B_T);
}

void hgemm_K1_transAB(unsigned int M, unsigned int N, unsigned int K,
const __fp16 *A, unsigned int lda, const __fp16 *B,
unsigned int ldb, __fp16 *C, unsigned int ldc,
float alpha, float beta) {
__fp16 *A_T = new __fp16[M * K];
__fp16 *B_T = new __fp16[K * N];

transpose_neon<__fp16>(K, M, A, M, A_T, K);
transpose_neon<__fp16>(N, K, B, K, B_T, N);

hgemm_K1_noTrans(M, N, K, A_T, lda, B_T, ldb, C, ldc, alpha, beta);

free(A_T);
free(B_T);
}

void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
const __fp16 *A, unsigned int lda, const __fp16 *B,
unsigned int ldb, __fp16 *C, unsigned int ldc,
Expand Down
56 changes: 55 additions & 1 deletion nntrainer/tensor/hgemm/hgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,61 @@ void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
* @param[in] alpha float number
* @param[in] beta float number
*/
void hgemm_K1(unsigned int M, unsigned int N, unsigned int K,
void hgemm_K1_noTrans(unsigned int M, unsigned int N, unsigned int K,
const __fp16 *A, unsigned int lda, const __fp16 *B,
unsigned int ldb, __fp16 *C, unsigned int ldc,
float alpha = 1.F, float beta = 0.F);
/**
* @brief hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
* @param M length of the row of matrix A
* @param N length of the col of matrix B
* @param K length of the col of matrix A
* @param A input matrix A
* @param lda length of the col of matrix A
* @param B input matrix B
* @param ldb length of the col of matrix B
* @param C output matrix C
* @param ldc length of the col of matrix C
* @param[in] alpha float number
* @param[in] beta float number
*/
void hgemm_K1_transA(unsigned int M, unsigned int N, unsigned int K,
const __fp16 *A, unsigned int lda, const __fp16 *B,
unsigned int ldb, __fp16 *C, unsigned int ldc,
float alpha = 1.F, float beta = 0.F);
/**
* @brief hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
* @param M length of the row of matrix A
* @param N length of the col of matrix B
* @param K length of the col of matrix A
* @param A input matrix A
* @param lda length of the col of matrix A
* @param B input matrix B
* @param ldb length of the col of matrix B
* @param C output matrix C
* @param ldc length of the col of matrix C
* @param[in] alpha float number
* @param[in] beta float number
*/
void hgemm_K1_transB(unsigned int M, unsigned int N, unsigned int K,
const __fp16 *A, unsigned int lda, const __fp16 *B,
unsigned int ldb, __fp16 *C, unsigned int ldc,
float alpha = 1.F, float beta = 0.F);
/**
* @brief hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
* @param M length of the row of matrix A
* @param N length of the col of matrix B
* @param K length of the col of matrix A
* @param A input matrix A
* @param lda length of the col of matrix A
* @param B input matrix B
* @param ldb length of the col of matrix B
* @param C output matrix C
* @param ldc length of the col of matrix C
* @param[in] alpha float number
* @param[in] beta float number
*/
void hgemm_K1_transAB(unsigned int M, unsigned int N, unsigned int K,
const __fp16 *A, unsigned int lda, const __fp16 *B,
unsigned int ldb, __fp16 *C, unsigned int ldc,
float alpha = 1.F, float beta = 0.F);
Expand Down

0 comments on commit a2d0536

Please sign in to comment.