diff --git a/nntrainer/tensor/matrix_transpose_neon/matrix_transpose_kernels_neon.h b/nntrainer/tensor/matrix_transpose_neon/matrix_transpose_kernels_neon.h
index fd8290eca3..ad9125155c 100644
--- a/nntrainer/tensor/matrix_transpose_neon/matrix_transpose_kernels_neon.h
+++ b/nntrainer/tensor/matrix_transpose_neon/matrix_transpose_kernels_neon.h
@@ -80,7 +80,11 @@ static void transpose_kernel_mxn_neon_128(unsigned int N, const __fp16 *src,
   unsigned i;
   for (i = 0; i < M; ++i) {
-    input[i] = vbsl_f16(bitmask_v8, vld1_f16(&src[i * ld_src]), ZEROS);
+    float16x4_t tmp = ZEROS;
+    for (unsigned int n = 0; n < N; ++n) {
+      tmp[n] = src[i * ld_src + n];
+    }
+    input[i] = tmp;
   }
   for (; i < 4; ++i) {
     input[i] = vmov_n_f16(0.F);
   }
@@ -95,7 +99,6 @@ static void transpose_kernel_mxn_neon_128(unsigned int N, const __fp16 *src,
     temp[i] = vmov_n_f16(0.F);
   }
 
-  bitmask_v8 = vld1_u16(reinterpret_cast<const uint16_t *>(masks[M]));
   for (i = 0; i < N; ++i) {
     if (i % 2 == 0) {
       input[i] =
@@ -106,10 +109,12 @@ static void transpose_kernel_mxn_neon_128(unsigned int N, const __fp16 *src,
         vcombine_f32(vget_high_f32(vcvt_f32_f16(temp[i / 2])),
                      vget_high_f32(vcvt_f32_f16(temp[2 + i / 2]))));
     }
-    vst1_f16(&dst[i * ld_dst],
-             vbsl_f16(bitmask_v8, input[i], vld1_f16(&dst[i * ld_dst])));
+    for (unsigned int m = 0; m < M; ++m) {
+      dst[i * ld_dst + m] = input[i][m];
+    }
   }
 }
+
 /**
  * @brief 8x8 sized kernel for matrix transpose in NEON
  *
@@ -182,6 +187,7 @@ static inline void transpose_kernel_8x8_neon(const __fp16 *src,
   vst1q_f16(&dst[6 * ld_dst], g);
   vst1q_f16(&dst[7 * ld_dst], h);
 }
+
 /**
  * @brief general case mxn sized matrix transpose kernel with 256 bit SIMD
  * register
@@ -203,7 +209,15 @@ static void transpose_kernel_mxn_neon_256(unsigned int N, const __fp16 *src,
   float16x8_t input[8];
   unsigned i;
   for (i = 0; i < M; ++i) {
-    input[i] = vbslq_f16(bitmask_v8, vld1q_f16(&src[i * ld_src]), ZEROS);
+    if (N == 8) {
+      input[i] = vld1q_f16(&src[i * ld_src]);
+    } else {
+      float16x8_t tmp = ZEROS;
+      for (unsigned int n = 0; n < N; ++n) {
+        tmp[n] = src[i * ld_src + n];
+      }
+      input[i] = tmp;
+    }
   }
   for (; i < 8; ++i) {
     input[i] = ZEROS;
@@ -235,8 +249,6 @@ static void transpose_kernel_mxn_neon_256(unsigned int N, const __fp16 *src,
       vbslq_f16(shuffle_mask, vextq_f16(temp[4 * i + 1], temp[4 * i + 1], 2),
                 temp[4 * i + 3]);
   }
-  bitmask_v8 =
-    vld1q_u16(reinterpret_cast<const uint16_t *>(neon_16bit_masks[M]));
   for (i = 0; i < N; ++i) {
     if (i < 4) {
       temp[i] =
@@ -245,9 +257,14 @@ static void transpose_kernel_mxn_neon_256(unsigned int N, const __fp16 *src,
         vcombine_f16(vget_low_f16(input[i]), vget_low_f16(input[4 + i]));
     } else {
       temp[i] =
         vcombine_f16(vget_high_f16(input[i - 4]), vget_high_f16(input[i]));
     }
-    vst1q_f16(&dst[i * ld_dst],
-              vbslq_f16(bitmask_v8, temp[i], vld1q_f16(&dst[i * ld_dst])));
+    if (M == 8) {
+      vst1q_f16(&dst[i * ld_dst], temp[i]);
+    } else {
+      for (unsigned int m = 0; m < M; ++m) {
+        dst[i * ld_dst + m] = temp[i][m];
+      }
+    }
   }
 }
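
Why this works: `vbsl_f16`/`vbslq_f16` only select lanes after the fact; they never reduced the memory traffic. The masked loads removed above still fetched a full 4- or 8-lane vector from `src`, and the masked stores had to read `dst` back (`vld1q_f16(&dst[i * ld_dst])`) just to preserve the lanes at or beyond `M`, so an edge tile could read and write past the end of its rows. The per-lane loops touch exactly the `N` (or `M`) valid elements, while the `N == 8` and `M == 8` branches keep the single-instruction vector path for full tiles.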
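
Below is a minimal standalone sketch of the same tail-handling pattern, assuming an AArch64 target with fp16 support; `load_partial_f16`, `store_partial_f16`, and `demo` are hypothetical names for illustration, not helpers from this patch.

```cpp
// Sketch of partial-vector load/store without mask blending.
// Assumed build: g++ -O2 -march=armv8.2-a+fp16 -c tail_sketch.cpp
#include <arm_neon.h>

// Load n (1..8) halfwords; lanes n..7 are zero-filled. Unlike a full-width
// vld1q_f16 blended with vbslq_f16, this never dereferences src[n..7].
static inline float16x8_t load_partial_f16(const __fp16 *src, unsigned n) {
  if (n == 8) {
    return vld1q_f16(src); // full tile: one 16-byte vector load
  }
  float16x8_t v = vmovq_n_f16(0.F);
  for (unsigned i = 0; i < n; ++i) {
    v[i] = src[i]; // per-lane copy stays inside the buffer
  }
  return v;
}

// Store the first m (1..8) lanes; unlike vbslq_f16 + vst1q_f16, this
// neither reads nor writes dst[m..7].
static inline void store_partial_f16(__fp16 *dst, float16x8_t v, unsigned m) {
  if (m == 8) {
    vst1q_f16(dst, v); // full tile: one 16-byte vector store
    return;
  }
  for (unsigned i = 0; i < m; ++i) {
    dst[i] = v[i];
  }
}

// Usage: read 3 valid halfwords into a zero-padded vector, then write 5
// lanes back out; no lane beyond the valid range is touched.
void demo(const __fp16 *row3, __fp16 *out5) {
  float16x8_t v = load_partial_f16(row3, 3);
  store_partial_f16(out5, v, 5);
}
```

The scalar path only runs on edge tiles (at most 7 ragged rows or columns per kernel call), so its cost is confined to the matrix boundary while interior 8x8 tiles stay fully vectorized.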