diff --git a/nntrainer/tensor/blas_neon.cpp b/nntrainer/tensor/blas_neon.cpp
index f442fd8cb6..576bc0e3e7 100644
--- a/nntrainer/tensor/blas_neon.cpp
+++ b/nntrainer/tensor/blas_neon.cpp
@@ -1588,7 +1588,11 @@ unsigned int isamax(const unsigned int N, const __fp16 *X) {
 
 void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
            uint32_t K, float alpha, float beta, bool TransA, bool TransB) {
-
+  if (K == 1) {
+    unsigned int lda = (TransA) ? M : K;
+    unsigned int ldb = (TransB) ? K : N;
+    return hgemm_K1(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
+  }
   // dynamic creation to avoid reaching stack limit(causes segmentation fault)
   float *C32 = (float *)malloc(M * N * sizeof(float));
 
diff --git a/nntrainer/tensor/hgemm/hgemm.cpp b/nntrainer/tensor/hgemm/hgemm.cpp
index b8827d0bd6..52e089943f 100644
--- a/nntrainer/tensor/hgemm/hgemm.cpp
+++ b/nntrainer/tensor/hgemm/hgemm.cpp
@@ -76,6 +76,23 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
   }
 }
 
+void hgemm_K1(unsigned int M, unsigned int N, unsigned int K,
+              const __fp16 *A, unsigned int lda, const __fp16 *B,
+              unsigned int ldb, __fp16 *C, unsigned int ldc,
+              float alpha, float beta) {
+  float16x8_t a_vec;
+  unsigned int N8 = (N >> 3) << 3;
+  for (unsigned int m = 0; m < M; ++m) {
+    a_vec = vmovq_n_f16(A[m]);
+    for (unsigned int n = 0; n < N8; n += 8) {
+      vst1q_f16(&C[m * ldc + n], vmulq_f16(a_vec, vld1q_f16(&B[n])));
+    }
+    for (unsigned int n = N8; n < N; ++n) {
+      C[m * ldc + n] = A[m] * B[n];
+    }
+  }
+}
+
 void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
                        const __fp16 *A, unsigned int lda, const __fp16 *B,
                        unsigned int ldb, __fp16 *C, unsigned int ldc,
diff --git a/nntrainer/tensor/hgemm/hgemm.h b/nntrainer/tensor/hgemm/hgemm.h
index 8a071eadf3..085aef4e76 100644
--- a/nntrainer/tensor/hgemm/hgemm.h
+++ b/nntrainer/tensor/hgemm/hgemm.h
@@ -61,6 +61,25 @@ void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
                             unsigned int ldb, float *C, unsigned int ldc,
                             float alpha = 1.F, float beta = 0.F);
 
+/**
+ * @brief hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_K1(unsigned int M, unsigned int N, unsigned int K,
+              const __fp16 *A, unsigned int lda, const __fp16 *B,
+              unsigned int ldb, __fp16 *C, unsigned int ldc,
+              float alpha = 1.F, float beta = 0.F);
+
 /**
  * @brief hgemm noTrans computation with 1x4 kernel : C = A*B,
  *
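
Note on the K == 1 fast path: when K is 1, op(A) is an M x 1 column and op(B) is a 1 x N row, so the product is an outer product with C[m * ldc + n] = A[m] * B[n]. The new `hgemm_K1` broadcasts A[m] into a `float16x8_t` and multiplies eight columns of B at a time, with a scalar loop for the N % 8 tail. As written it does not apply `alpha` or `beta`, so it matches the documented Y = alpha*op(A)*op(B) + beta*C contract only for alpha = 1 and beta = 0. The sketch below is a minimal scalar reference for this path, not part of the patch: it assumes plain `float` instead of `__fp16` so it builds off-device, and the names `ref_hgemm_K1` and the test values are illustrative.

```cpp
#include <cstdio>
#include <vector>

// Scalar reference for the K == 1 case: C = alpha * A * B + beta * C,
// where A is M x 1 and B is 1 x N (an outer product). Uses float rather
// than __fp16 so the sketch compiles on non-ARM hosts.
void ref_hgemm_K1(unsigned int M, unsigned int N, const float *A,
                  const float *B, float *C, unsigned int ldc,
                  float alpha = 1.F, float beta = 0.F) {
  for (unsigned int m = 0; m < M; ++m) {
    for (unsigned int n = 0; n < N; ++n) {
      C[m * ldc + n] = alpha * A[m] * B[n] + beta * C[m * ldc + n];
    }
  }
}

int main() {
  const unsigned int M = 3, N = 5;
  std::vector<float> A = {1.f, 2.f, 3.f};               // M x 1 column
  std::vector<float> B = {1.f, 2.f, 3.f, 4.f, 5.f};     // 1 x N row
  std::vector<float> C(M * N, 0.f);

  ref_hgemm_K1(M, N, A.data(), B.data(), C.data(), N);

  // Each row of C is a scaled copy of B: row m equals A[m] * B.
  for (unsigned int m = 0; m < M; ++m) {
    for (unsigned int n = 0; n < N; ++n)
      printf("%5.1f ", C[m * N + n]);
    printf("\n");
  }
  return 0;
}
```

A half-precision unit test could compare `hgemm_K1` against this kind of reference with alpha = 1 and beta = 0; if other alpha/beta values need to reach the K == 1 branch, the NEON kernel would have to scale and accumulate explicitly rather than overwrite C.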