Skip to content

Commit

Permalink
[ hgemm ] Generalize redundant micro hgemm kernel implementation
Browse files Browse the repository at this point in the history
- Previous implementation naively used fixed-sized ukernels for the K-direction accumulation.
- Such kernels were excessively long, but had better performance than looping through single K-iteration.
- However, recent test results have shown that just stacking 4 K iterations into one ukernel and looping over it preserves the performance while giving better code readability.

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <[email protected]>
  • Loading branch information
skykongkong8 committed Aug 9, 2024
1 parent 23c0983 commit 097d83e
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 1,650 deletions.
239 changes: 36 additions & 203 deletions nntrainer/tensor/hgemm/hgemm_kernel/hgemm_kernel_4x8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,205 +25,41 @@
v9 = vdupq_n_f16(0.F); \
} while (0)

// 1. Partial sum 256 digits
#define KERNEL_4x8_ACC16() \
do { \
dv0 = vld1_f16(a); \
v24 = vld1q_f16(b); \
v0 = vfmaq_lane_f16(v0, v24, dv0, 0); \
v3 = vfmaq_lane_f16(v3, v24, dv0, 1); \
v6 = vfmaq_lane_f16(v6, v24, dv0, 2); \
v9 = vfmaq_lane_f16(v9, v24, dv0, 3); \
dv1 = vld1_f16(a + 4); \
v25 = vld1q_f16(b + 8); \
v0 = vfmaq_lane_f16(v0, v25, dv1, 0); \
v3 = vfmaq_lane_f16(v3, v25, dv1, 1); \
v6 = vfmaq_lane_f16(v6, v25, dv1, 2); \
v9 = vfmaq_lane_f16(v9, v25, dv1, 3); \
dv2 = vld1_f16(a + 4 * 2); \
v26 = vld1q_f16(b + 8 * 2); \
v0 = vfmaq_lane_f16(v0, v26, dv2, 0); \
v3 = vfmaq_lane_f16(v3, v26, dv2, 1); \
v6 = vfmaq_lane_f16(v6, v26, dv2, 2); \
v9 = vfmaq_lane_f16(v9, v26, dv2, 3); \
dv3 = vld1_f16(a + 4 * 3); \
v27 = vld1q_f16(b + 8 * 3); \
v0 = vfmaq_lane_f16(v0, v27, dv3, 0); \
v3 = vfmaq_lane_f16(v3, v27, dv3, 1); \
v6 = vfmaq_lane_f16(v6, v27, dv3, 2); \
v9 = vfmaq_lane_f16(v9, v27, dv3, 3); \
dv4 = vld1_f16(a + 4 * 4); \
v28 = vld1q_f16(b + 8 * 4); \
v0 = vfmaq_lane_f16(v0, v28, dv4, 0); \
v3 = vfmaq_lane_f16(v3, v28, dv4, 1); \
v6 = vfmaq_lane_f16(v6, v28, dv4, 2); \
v9 = vfmaq_lane_f16(v9, v28, dv4, 3); \
dv5 = vld1_f16(a + 4 * 5); \
v29 = vld1q_f16(b + 8 * 5); \
v0 = vfmaq_lane_f16(v0, v29, dv5, 0); \
v3 = vfmaq_lane_f16(v3, v29, dv5, 1); \
v6 = vfmaq_lane_f16(v6, v29, dv5, 2); \
v9 = vfmaq_lane_f16(v9, v29, dv5, 3); \
dv6 = vld1_f16(a + 4 * 6); \
v30 = vld1q_f16(b + 8 * 6); \
v0 = vfmaq_lane_f16(v0, v30, dv6, 0); \
v3 = vfmaq_lane_f16(v3, v30, dv6, 1); \
v6 = vfmaq_lane_f16(v6, v30, dv6, 2); \
v9 = vfmaq_lane_f16(v9, v30, dv6, 3); \
dv7 = vld1_f16(a + 4 * 7); \
v31 = vld1q_f16(b + 8 * 7); \
v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
dv7 = vld1_f16(a + 4 * 8); \
v31 = vld1q_f16(b + 8 * 8); \
v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
dv7 = vld1_f16(a + 4 * 9); \
v31 = vld1q_f16(b + 8 * 9); \
v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
dv7 = vld1_f16(a + 4 * 10); \
v31 = vld1q_f16(b + 8 * 10); \
v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
dv7 = vld1_f16(a + 4 * 11); \
v31 = vld1q_f16(b + 8 * 11); \
v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
dv7 = vld1_f16(a + 4 * 12); \
v31 = vld1q_f16(b + 8 * 12); \
v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
dv7 = vld1_f16(a + 4 * 13); \
v31 = vld1q_f16(b + 8 * 13); \
v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
dv7 = vld1_f16(a + 4 * 14); \
v31 = vld1q_f16(b + 8 * 14); \
v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
dv7 = vld1_f16(a + 4 * 15); \
v31 = vld1q_f16(b + 8 * 15); \
v0 = vfmaq_lane_f16(v0, v31, dv7, 0); \
v3 = vfmaq_lane_f16(v3, v31, dv7, 1); \
v6 = vfmaq_lane_f16(v6, v31, dv7, 2); \
v9 = vfmaq_lane_f16(v9, v31, dv7, 3); \
l += 16; \
__builtin_prefetch(b + 128, 0, 3); \
__builtin_prefetch(a + 64, 0, 3); \
b += 8 * 16; \
a += 4 * 16; \
// Accumulates N consecutive K-steps into the 4x8 tile accumulators.
//
// The loop consumes four K-steps per iteration. For each step, 4 half values
// are loaded from `a` and 8 half values from `b`; v0, v3, v6 and v9 are then
// updated with (b vector) * (lane 0..3 of the a vector) respectively.
// Afterwards `l` is advanced by N, `a` by 4 * N and `b` by 8 * N elements.
//
// N: number of K-steps to consume. The loop strides by 4, so N is expected to
//    be a multiple of 4 (the visible call sites pass 4, 8 and 16). N is
//    evaluated more than once, so it must not have side effects; every use is
//    parenthesized so that expression arguments such as `2 + 2` expand
//    correctly.
//
// Relies on a, b, l, v0/v3/v6/v9, dv0..dv3 and v24..v27 being declared in the
// scope where the macro is expanded.
#define KERNEL_4x8_ACC_N4(N)                          \
  do {                                                \
    for (int i = 0; i < (N); i += 4) {                \
      dv0 = vld1_f16(a + 4 * i);                      \
      v24 = vld1q_f16(b + 8 * i);                     \
      v0 = vfmaq_lane_f16(v0, v24, dv0, 0);           \
      v3 = vfmaq_lane_f16(v3, v24, dv0, 1);           \
      v6 = vfmaq_lane_f16(v6, v24, dv0, 2);           \
      v9 = vfmaq_lane_f16(v9, v24, dv0, 3);           \
      dv1 = vld1_f16(a + 4 * i + 4);                  \
      v25 = vld1q_f16(b + 8 * i + 8);                 \
      v0 = vfmaq_lane_f16(v0, v25, dv1, 0);           \
      v3 = vfmaq_lane_f16(v3, v25, dv1, 1);           \
      v6 = vfmaq_lane_f16(v6, v25, dv1, 2);           \
      v9 = vfmaq_lane_f16(v9, v25, dv1, 3);           \
      dv2 = vld1_f16(a + 4 * i + 8);                  \
      v26 = vld1q_f16(b + 8 * i + 16);                \
      v0 = vfmaq_lane_f16(v0, v26, dv2, 0);           \
      v3 = vfmaq_lane_f16(v3, v26, dv2, 1);           \
      v6 = vfmaq_lane_f16(v6, v26, dv2, 2);           \
      v9 = vfmaq_lane_f16(v9, v26, dv2, 3);           \
      dv3 = vld1_f16(a + 4 * i + 12);                 \
      v27 = vld1q_f16(b + 8 * i + 24);                \
      v0 = vfmaq_lane_f16(v0, v27, dv3, 0);           \
      v3 = vfmaq_lane_f16(v3, v27, dv3, 1);           \
      v6 = vfmaq_lane_f16(v6, v27, dv3, 2);           \
      v9 = vfmaq_lane_f16(v9, v27, dv3, 3);           \
    }                                                 \
    l += (N);                                         \
    /* Prefetch the start of the next chunk before */ \
    /* stepping the pointers past the current one. */ \
    __builtin_prefetch(b + 8 * (N), 0, 3);            \
    __builtin_prefetch(a + 4 * (N), 0, 3);            \
    b += 8 * (N);                                     \
    a += 4 * (N);                                     \
  } while (0)

// 1. Partial sum over eight K-steps (8 steps * 4 * 8 = 256 multiply-adds).
//
// One K-step: load 4 halfs from `a` into DV and 8 halfs from `b` into VB at
// step index STEP, then add VB * DV[lane] to v0, v3, v6 and v9 for lanes 0..3.
#define KERNEL_4x8_ACC8_STEP(DV, VB, STEP) \
  do {                                     \
    DV = vld1_f16(a + 4 * (STEP));         \
    VB = vld1q_f16(b + 8 * (STEP));        \
    v0 = vfmaq_lane_f16(v0, VB, DV, 0);    \
    v3 = vfmaq_lane_f16(v3, VB, DV, 1);    \
    v6 = vfmaq_lane_f16(v6, VB, DV, 2);    \
    v9 = vfmaq_lane_f16(v9, VB, DV, 3);    \
  } while (0)

// Runs eight K-steps back to back, each with its own temporaries
// (dv0..dv7 / v24..v31), then advances l by 8, a by 32 and b by 64 elements.
// Expects a, b, l, the accumulators and the temporaries to be declared in the
// expanding scope.
#define KERNEL_4x8_ACC8()              \
  do {                                 \
    KERNEL_4x8_ACC8_STEP(dv0, v24, 0); \
    KERNEL_4x8_ACC8_STEP(dv1, v25, 1); \
    KERNEL_4x8_ACC8_STEP(dv2, v26, 2); \
    KERNEL_4x8_ACC8_STEP(dv3, v27, 3); \
    KERNEL_4x8_ACC8_STEP(dv4, v28, 4); \
    KERNEL_4x8_ACC8_STEP(dv5, v29, 5); \
    KERNEL_4x8_ACC8_STEP(dv6, v30, 6); \
    KERNEL_4x8_ACC8_STEP(dv7, v31, 7); \
    l += 8;                            \
    __builtin_prefetch(b + 64, 0, 3);  \
    __builtin_prefetch(a + 32, 0, 3);  \
    b += 64;                           \
    a += 32;                           \
  } while (0)

// 2. Partial sum over four K-steps (4 steps * 4 * 8 = 128 multiply-adds).
//
// One K-step: load 4 halfs from `a` into DV and 8 halfs from `b` into VB at
// step index STEP, then add VB * DV[lane] to v0, v3, v6 and v9 for lanes 0..3.
#define KERNEL_4x8_ACC4_STEP(DV, VB, STEP) \
  do {                                     \
    DV = vld1_f16(a + 4 * (STEP));         \
    VB = vld1q_f16(b + 8 * (STEP));        \
    v0 = vfmaq_lane_f16(v0, VB, DV, 0);    \
    v3 = vfmaq_lane_f16(v3, VB, DV, 1);    \
    v6 = vfmaq_lane_f16(v6, VB, DV, 2);    \
    v9 = vfmaq_lane_f16(v9, VB, DV, 3);    \
  } while (0)

// Runs four K-steps back to back using dv0..dv3 / v24..v27 as temporaries,
// then advances l by 4, a by 16 and b by 32 elements. Expects a, b, l, the
// accumulators and the temporaries to be declared in the expanding scope.
#define KERNEL_4x8_ACC4()              \
  do {                                 \
    KERNEL_4x8_ACC4_STEP(dv0, v24, 0); \
    KERNEL_4x8_ACC4_STEP(dv1, v25, 1); \
    KERNEL_4x8_ACC4_STEP(dv2, v26, 2); \
    KERNEL_4x8_ACC4_STEP(dv3, v27, 3); \
    l += 4;                            \
    __builtin_prefetch(b + 32, 0, 3);  \
    __builtin_prefetch(a + 16, 0, 3);  \
    b += 32;                           \
    a += 16;                           \
  } while (0)

// 3. Partial sum 32 digits
#define KERNEL_4x8_ACC1() \
do { \
dv0 = vld1_f16(a); \
Expand Down Expand Up @@ -277,9 +113,6 @@ void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K,
float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
INIT_KERNEL_4X8();
l = 0;
for (; l < K8;) {
KERNEL_4x8_ACC8();
}
for (; l < K;) {
KERNEL_4x8_ACC1();
}
Expand Down Expand Up @@ -319,17 +152,17 @@ void hgemm_kernel_4x8(unsigned int M, unsigned int N, unsigned int K,
l = 0;
for (; l < K16;) {
INIT_KERNEL_4X8();
KERNEL_4x8_ACC16();
KERNEL_4x8_ACC_N4(16);
SAVE_KERNEL_4X8_F16_F32();
}
for (; l < K8;) {
INIT_KERNEL_4X8();
KERNEL_4x8_ACC8();
KERNEL_4x8_ACC_N4(8);
SAVE_KERNEL_4X8_F16_F32();
}
for (; l < K4;) {
INIT_KERNEL_4X8();
KERNEL_4x8_ACC4();
KERNEL_4x8_ACC_N4(4);
SAVE_KERNEL_4X8_F16_F32();
}
for (; l < K;) {
Expand Down
Loading

0 comments on commit 097d83e

Please sign in to comment.