forked from DefTruth/CUDA-Learn-Notes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hgemm_wmma.cu
755 lines (655 loc) · 29.6 KB
/
hgemm_wmma.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <mma.h>
#include <torch/types.h>
#include <torch/extension.h>
using namespace nvcuda;
// Number of threads per warp; used below to derive warp ids and block sizes.
#define WARP_SIZE 32
// Shorthand function qualifiers.
#define DEVICE_INLINE __device__ inline
#define HOST_DEVICE_INLINE __device__ __host__ inline
// Vector views: take the address of `value`, reinterpret it as a pointer to a
// wider vector type and yield an lvalue to the first element, so the macro can
// be used on either side of an assignment.
// NOTE(review): the address of `value` must satisfy the alignment of the wider
// type; nothing here checks that — callers are responsible.
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
// Same idea, named by access width: 32 bits (half2), 64 bits (float2),
// 128 bits (float4).
#define LDST32BITS(value) (reinterpret_cast<half2*>(&(value))[0])
#define LDST64BITS(value) (reinterpret_cast<float2*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
// Inline-PTX wrappers for the cp.async (asynchronous global -> shared copy)
// instruction family.
// NOTE(review): cp.async is presumably only available on sm_80+ targets —
// confirm against the build architecture.
// Commit all cp.async operations issued so far into one group.
#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::)
// Wait for all outstanding cp.async operations.
#define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::)
// Wait on committed groups; `n` is passed with the "n" constraint, so it must
// be a compile-time integer constant.
#define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n))
// ca (cache all, L1 + L2): supports 4, 8, 16 bytes; cg (cache global, L2): only supports 16 bytes.
// `dst` is a 32-bit shared-memory address ("r"), `src` a 64-bit global
// pointer ("l"), and `bytes` a compile-time constant ("n").
#define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
#define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
// A and B are both row-major, so that these kernels can be compared with the
// CUDA-core kernels in hgemm.cu and hgemm_async.cu.
// Integer division of `a` by `b`, bumped by one whenever there is a nonzero
// remainder. For a >= 0 and b > 0 this is ceil(a / b); it is used to count
// how many tiles of size `b` are needed to cover an extent `a`.
HOST_DEVICE_INLINE
int div_ceil(int a, int b) {
  const int quotient = a / b;
  return (a % b == 0) ? quotient : quotient + 1;
}
// Naive WMMA HGEMM: C = A * B, one warp (32 threads) per block, m16n16k16.
// A (M x K), B (K x N) and C (M x N) are row-major half matrices.
//
// Launch layout: blockDim = 32; blockIdx.x selects the N tile and blockIdx.y
// the M tile, i.e. grid = (ceil(N / WMMA_N), ceil(M / WMMA_M)).
// Each block computes one WMMA_M x WMMA_N tile of C directly from global
// memory (no shared memory is used).
//
// Preconditions: fragments are loaded and stored as whole tiles with no
// partial-tile handling, so M, N and K are expected to be multiples of
// WMMA_M, WMMA_N and WMMA_K respectively.
template<const int WMMA_M=16, const int WMMA_N=16, const int WMMA_K=16>
__global__ void hgemm_wmma_m16n16k16_naive_kernel(half* A, half* B, half* C,
                                                  int M, int N, int K) {
  const int NUM_K_TILES = div_ceil(K, WMMA_K);
  const int load_gmem_a_m = blockIdx.y * WMMA_M;  // first row of this C tile
  const int load_gmem_b_n = blockIdx.x * WMMA_N;  // first col of this C tile
  // A tile is out of range as soon as it starts past the matrix in EITHER
  // dimension. The test depends only on blockIdx, so the whole warp leaves
  // together and no thread is missing from the warp-wide WMMA calls below.
  if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;

  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> C_frag;
  wmma::fill_fragment(C_frag, __float2half(0.0f));

  #pragma unroll
  for (int k = 0; k < NUM_K_TILES; ++k) {
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> A_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> B_frag;

    // A tile: rows [load_gmem_a_m, +WMMA_M), cols [k*WMMA_K, +WMMA_K), ld = K.
    wmma::load_matrix_sync(A_frag, A + load_gmem_a_m * K + k * WMMA_K, K);
    // B tile: rows [k*WMMA_K, +WMMA_K), cols [load_gmem_b_n, +WMMA_N), ld = N.
    wmma::load_matrix_sync(B_frag, B + (k * WMMA_K) * N + load_gmem_b_n, N);

    wmma::mma_sync(C_frag, A_frag, B_frag, C_frag);
    // No block barrier is needed here: nothing is exchanged through shared
    // memory, and the *_sync WMMA calls operate on the warp on their own.
  }

  wmma::store_matrix_sync(C + load_gmem_a_m * N + load_gmem_b_n, C_frag, N,
                          wmma::mem_row_major);
}
// m16n16k16 WMMA + 4x2 warp tiling staged through shared memory.
// A (M x K), B (K x N), C (M x N): all row_major half matrices.
//
// Launch layout: 256 threads (8 warps) per block; blockIdx.x selects the
// BN-wide column block and blockIdx.y the BM-tall row block of C. Each warp
// produces one WMMA_M x WMMA_N tile of the BM x BN block tile.
//
// Preconditions:
//  - The thread -> shared-memory mapping below (tid / 4, tid / 16, warp_id / 2)
//    is hard-coded for the default template arguments and 256 threads.
//  - No partial-tile handling: M, N and K are expected to be multiples of
//    BM (64), BN (32) and BK (16).
//  - The 64-bit / 32-bit vector accesses require correspondingly aligned
//    addresses in A and B.
template<const int WMMA_M=16, const int WMMA_N=16, const int WMMA_K=16,
         const int WMMA_TILE_M=4, const int WMMA_TILE_N=2>
__global__ void hgemm_wmma_m16n16k16_mma4x2_kernel(
  half* A, half* B, half* C, int M, int N, int K) {
  // 256 threads (8 warps) per block.
  const int bx = blockIdx.x;
  const int by = blockIdx.y;
  const int NUM_K_TILES = div_ceil(K, WMMA_K);
  constexpr int BM = WMMA_M * WMMA_TILE_M; // 16x4=64
  constexpr int BN = WMMA_N * WMMA_TILE_N; // 16x2=32
  constexpr int BK = WMMA_K; // 16
  __shared__ half s_a[BM][BK], s_b[BK][BN]; // 64x16x2=2KB, 16x32x2=1KB

  // Out-of-range guard. It is evaluated on block coordinates only, so either
  // every thread of the block returns or none does; a per-thread test here
  // could leave some threads missing from the __syncthreads() calls below.
  if (by * BM >= M || bx * BN >= N) return;

  // All threads of a warp must execute the same WMMA instructions, so the
  // tile assignment is derived from the warp id:
  // warp_id 0 -> warp_m 0, warp_n 0
  // warp_id 1 -> warp_m 0, warp_n 1
  // warp_id 2 -> warp_m 1, warp_n 0
  // warp_id 3 -> warp_m 1, warp_n 1
  // ...
  const int tid = threadIdx.y * blockDim.x + threadIdx.x;
  const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block
  const int warp_m = warp_id / 2; // 0,1,2,3
  const int warp_n = warp_id % 2; // 0,1

  // The 256 threads cooperatively fill s_a (64x16) and s_b (16x32):
  // 64*16/256 = 4 halves per thread for s_a, 16*32/256 = 2 halves for s_b.
  // s_a: 4 halves per thread -> 4 threads per row, 64 rows -> 256 threads.
  const int load_smem_a_m = tid / 4; // 0~63
  const int load_smem_a_k = (tid % 4) * 4; // 0,4,8,12
  // s_b: 2 halves per thread -> 16 threads per row, 16 rows -> 256 threads.
  const int load_smem_b_k = tid / 16; // 0~15
  const int load_smem_b_n = (tid % 16) * 2; // 0,2,4,...,30
  const int load_gmem_a_m = by * BM + load_smem_a_m; // global m
  const int load_gmem_b_n = bx * BN + load_smem_b_n; // global n

  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> C_frag;
  wmma::fill_fragment(C_frag, __float2half(0.0f));

  #pragma unroll
  for (int k = 0; k < NUM_K_TILES; ++k) {
    int load_gmem_a_k = k * WMMA_K + load_smem_a_k; // global col of a
    int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
    int load_gmem_b_k = k * WMMA_K + load_smem_b_k; // global row of b
    int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
    // 64-bit synchronous copy gmem_a -> smem_a.
    LDST64BITS(s_a[load_smem_a_m][load_smem_a_k]) = (
      LDST64BITS(A[load_gmem_a_addr]));
    // 32-bit synchronous copy gmem_b -> smem_b.
    LDST32BITS(s_b[load_smem_b_k][load_smem_b_n]) = (
      LDST32BITS(B[load_gmem_b_addr]));
    // Make the freshly written tiles visible to every warp before reading.
    __syncthreads();

    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> A_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> B_frag;

    wmma::load_matrix_sync(A_frag, &s_a[warp_m * WMMA_M][0], BK); // s_a is BM x BK, ld = BK
    wmma::load_matrix_sync(B_frag, &s_b[0][warp_n * WMMA_N], BN); // s_b is BK x BN, ld = BN

    wmma::mma_sync(C_frag, A_frag, B_frag, C_frag);

    // All warps must be done reading before the next iteration overwrites
    // the shared tiles.
    __syncthreads();
  }

  const int store_gmem_a_m = by * BM + warp_m * WMMA_M;
  const int store_gmem_a_n = bx * BN + warp_n * WMMA_N;
  wmma::store_matrix_sync(C + store_gmem_a_m * N + store_gmem_a_n, C_frag, N,
                          wmma::mem_row_major);
}
// m16n16k16 WMMA, 4x2 warps per block, each warp computing a 2x4 grid of
// WMMA tiles, staged through shared memory. A, B, C: all row_major.
//
// Launch layout: 256 threads (8 warps) per block; blockIdx.x / blockIdx.y
// select the BN-wide column block / BM-tall row block of C (128 x 128 with
// the default template arguments).
// The thread -> shared-memory mapping is hard-coded for those defaults, and
// there is no partial-tile handling (M, N, K are expected to be multiples of
// BM, BN, BK). M is not read inside the kernel.
template<const int WMMA_M=16, const int WMMA_N=16, const int WMMA_K=16,
         const int WMMA_TILE_M=4, const int WMMA_TILE_N=2,
         const int WARP_TILE_M=2, const int WARP_TILE_N=4>
__global__ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_kernel(
  half* A, half* B, half* C, int M, int N, int K) {
  constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; // 16x4*2=128
  constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; // 16x2*4=128
  constexpr int BK = WMMA_K; // 16
  __shared__ half s_a[BM][BK], s_b[BK][BN]; // 128x16 halves = 4KB each

  const int k_tile_count = div_ceil(K, WMMA_K);
  const int block_row0 = blockIdx.y * BM;  // first row of C owned by this block
  const int block_col0 = blockIdx.x * BN;  // first col of C owned by this block

  // Flat thread id and the warp's position in the 4x2 warp grid. Every
  // thread of a warp gets the same warp coordinates, so all lanes issue the
  // same WMMA instructions.
  const int t = threadIdx.y * blockDim.x + threadIdx.x;
  const int warp = t / WARP_SIZE;                              // 0~7
  const int warp_row0 = (warp / 2) * (WMMA_M * WARP_TILE_M);   // row offset in block tile
  const int warp_col0 = (warp % 2) * (WMMA_N * WARP_TILE_N);   // col offset in block tile

  // Staging map, 8 halves (128 bits) per thread:
  //   s_a is 128 x 16 -> 2 threads per row, 128 rows -> 256 threads
  //   s_b is 16 x 128 -> 16 threads per row, 16 rows -> 256 threads
  const int a_smem_row = t / 2;          // 0~127
  const int a_smem_col = (t % 2) * 8;    // 0 or 8
  const int b_smem_row = t / 16;         // 0~15
  const int b_smem_col = (t % 16) * 8;   // 0,8,...,120

  const int a_gmem_row = block_row0 + a_smem_row;  // global row of A
  const int b_gmem_col = block_col0 + b_smem_col;  // global col of B

  wmma::fragment<wmma::accumulator,
                 WMMA_M, WMMA_N, WMMA_K,
                 half> C_frag[WARP_TILE_M][WARP_TILE_N];
  #pragma unroll
  for (int idx = 0; idx < WARP_TILE_M * WARP_TILE_N; ++idx) {
    wmma::fill_fragment(C_frag[idx / WARP_TILE_N][idx % WARP_TILE_N], 0.0);
  }

  #pragma unroll
  for (int kt = 0; kt < k_tile_count; ++kt) {
    const int k0 = kt * WMMA_K;  // first K index covered by this tile
    const int a_gmem_idx = a_gmem_row * K + (k0 + a_smem_col);
    const int b_gmem_idx = (k0 + b_smem_row) * N + b_gmem_col;

    // 128-bit synchronous copies global -> shared.
    LDST128BITS(s_a[a_smem_row][a_smem_col]) = (
      LDST128BITS(A[a_gmem_idx]));
    LDST128BITS(s_b[b_smem_row][b_smem_col]) = (
      LDST128BITS(B[b_gmem_idx]));
    // Tiles must be complete before any warp reads them.
    __syncthreads();

    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half,
                   wmma::row_major> A_frag[WARP_TILE_M];
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half,
                   wmma::row_major> B_frag[WARP_TILE_N];

    // Shared -> registers: WARP_TILE_M fragments of A, ld = BK.
    #pragma unroll
    for (int i = 0; i < WARP_TILE_M; ++i) {
      wmma::load_matrix_sync(A_frag[i], &s_a[warp_row0 + i * WMMA_M][0], BK);
    }
    // Shared -> registers: WARP_TILE_N fragments of B, ld = BN.
    #pragma unroll
    for (int j = 0; j < WARP_TILE_N; ++j) {
      wmma::load_matrix_sync(B_frag[j], &s_b[0][warp_col0 + j * WMMA_N], BN);
    }

    // Outer product of the fragment sets, accumulated into C_frag.
    #pragma unroll
    for (int i = 0; i < WARP_TILE_M; ++i) {
      #pragma unroll
      for (int j = 0; j < WARP_TILE_N; ++j) {
        wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]);
      }
    }
    // Every warp must finish reading before the tiles are overwritten.
    __syncthreads();
  }

  // Write each accumulator tile back to C (row-major, ld = N).
  #pragma unroll
  for (int i = 0; i < WARP_TILE_M; ++i) {
    const int c_row = block_row0 + warp_row0 + i * WMMA_M;
    #pragma unroll
    for (int j = 0; j < WARP_TILE_N; ++j) {
      const int c_col = block_col0 + warp_col0 + j * WMMA_N;
      wmma::store_matrix_sync(C + c_row * N + c_col, C_frag[i][j], N,
                              wmma::mem_row_major);
    }
  }
}
// Double-buffered m16n16k16 WMMA HGEMM: 4x2 warps per block, each warp
// computing a 2x4 grid of WMMA tiles, with cp.async prefetch of the next K
// tile into the alternate shared-memory buffer while the current one is
// being multiplied. A, B, C: all row_major half matrices.
//
// Launch layout: 256 threads (8 warps) per block; blockIdx.x / blockIdx.y
// select the BN-wide column block / BM-tall row block of C.
// OFFSET pads the innermost shared-memory dimension (in halves) to reduce
// bank conflicts.
//
// Preconditions:
//  - The thread -> shared-memory mapping is hard-coded for the default
//    tile template arguments and 256 threads.
//  - No partial-tile handling: M, N, K are expected to be multiples of
//    BM (128), BN (128), BK (16). M is not read inside the kernel.
//  - Each cp.async moves 16 bytes, so the source and destination addresses
//    must be 16-byte aligned.
template<const int WMMA_M=16, const int WMMA_N=16, const int WMMA_K=16,
         const int WMMA_TILE_M=4, const int WMMA_TILE_N=2,
         const int WARP_TILE_M=2, const int WARP_TILE_N=4,
         const int OFFSET=0>
__global__ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async_kernel(
  half* A, half* B, half* C, int M, int N, int K) {
  // 256 threads (8 warps) per block.
  const int bx = blockIdx.x;
  const int by = blockIdx.y;
  const int NUM_K_TILES = div_ceil(K, WMMA_K);
  constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; // 16x4*2=128
  constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; // 16x2*4=128
  constexpr int BK = WMMA_K; // 16
  // Two buffers per operand; with OFFSET=0 one buffer is 128x16 halves = 4KB.
  // OFFSET pads each row to reduce bank conflicts.
  __shared__ half s_a[2][BM][BK+OFFSET], s_b[2][BK][BN+OFFSET];

  // All threads of a warp must execute the same WMMA instructions, so the
  // tile assignment is derived from the warp id.
  const int tid = threadIdx.y * blockDim.x + threadIdx.x;
  const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block
  const int warp_m = warp_id / 2; // 0,1,2,3
  const int warp_n = warp_id % 2; // 0,1

  // 0. Shared-memory indices for the cooperative tile load.
  // s_a[BM][BK] (BM=128, BK=16), A row-major: each row holds 16 halves and
  // each thread copies 8 -> 2 threads per row, 128 rows -> 256 threads.
  int load_smem_a_m = tid / 2; // row 0~127
  int load_smem_a_k = (tid % 2 == 0) ? 0 : 8; // col 0,8
  // s_b[BK][BN] (BK=16, BN=128), B row-major: each row holds 128 halves and
  // each thread copies 8 -> 16 threads per row, 16 rows -> 256 threads.
  int load_smem_b_k = tid / 16; // row 0~15
  int load_smem_b_n = (tid % 16) * 8; // col 0,8,...,120
  // 1. Matching global-memory indices. Each block produces one BM x BN
  // block of C.
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c

  wmma::fragment<wmma::accumulator,
                 WMMA_M, WMMA_N, WMMA_K,
                 half> C_frag[WARP_TILE_M][WARP_TILE_N];

  #pragma unroll
  for (int i = 0; i < WARP_TILE_M; ++i) {
    #pragma unroll
    for (int j = 0; j < WARP_TILE_N; ++j) {
      wmma::fill_fragment(C_frag[i][j], __float2half(0.0f));
    }
  }

  // Prologue: K tile 0 is loaded into buffer 0.
  {
    int load_gmem_a_k = load_smem_a_k; // global col of a
    int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
    int load_gmem_b_k = load_smem_b_k; // global row of b
    int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;

    uint32_t load_smem_a_ptr = __cvta_generic_to_shared(
      &s_a[0][load_smem_a_m][load_smem_a_k]);
    CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16);

    uint32_t load_smem_b_ptr = __cvta_generic_to_shared(
      &s_b[0][load_smem_b_k][load_smem_b_n]);
    CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16);

    CP_ASYNC_COMMIT_GROUP();
    CP_ASYNC_WAIT_GROUP(0);
  }
  __syncthreads();

  // Main loop: iteration k multiplies tile k-1 (already resident) while
  // tile k is being copied into the other buffer.
  #pragma unroll
  for (int k = 1; k < NUM_K_TILES; ++k) { // start from 1
    int smem_sel = (k - 1) & 1; // k 1->0, k 2->1, k 3->0, ...
    int smem_sel_next = k & 1;  // k 1->1, k 2->0, k 3->1, ...

    int load_gmem_a_k = k * WMMA_K + load_smem_a_k; // global col of a
    int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
    int load_gmem_b_k = k * WMMA_K + load_smem_b_k; // global row of b
    int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;

    uint32_t load_smem_a_ptr = __cvta_generic_to_shared(
      &s_a[smem_sel_next][load_smem_a_m][load_smem_a_k]);
    CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16);

    uint32_t load_smem_b_ptr = __cvta_generic_to_shared(
      &s_b[smem_sel_next][load_smem_b_k][load_smem_b_n]);
    CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16);

    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half,
                   wmma::row_major> A_frag[WARP_TILE_M];
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half,
                   wmma::row_major> B_frag[WARP_TILE_N];

    #pragma unroll
    for (int i = 0; i < WARP_TILE_M; ++i) {
      // load 2 tiles -> reg, smem a -> frags a, warp_m 0~3
      const int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M;
      wmma::load_matrix_sync(A_frag[i], &s_a[smem_sel][warp_smem_a_m][0], BK+OFFSET);
    }

    #pragma unroll
    for (int j = 0; j < WARP_TILE_N; ++j) {
      // load 4 tiles -> reg, smem b -> frags b, warp_n 0~1
      const int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N;
      wmma::load_matrix_sync(B_frag[j], &s_b[smem_sel][0][warp_smem_b_n], BN+OFFSET);
    }

    #pragma unroll
    for (int i = 0; i < WARP_TILE_M; ++i) {
      #pragma unroll
      for (int j = 0; j < WARP_TILE_N; ++j) {
        wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]);
      }
    }

    // The prefetched tile must have landed, and every warp must be done with
    // the current buffer, before the next iteration swaps roles.
    CP_ASYNC_COMMIT_GROUP();
    CP_ASYNC_WAIT_GROUP(0);
    __syncthreads();
  }

  // Epilogue: multiply the last K tile. Tile t was loaded into buffer t & 1,
  // so the last tile (index NUM_K_TILES - 1) lives in buffer
  // (NUM_K_TILES - 1) & 1 — buffer 0 when the tile count is odd.
  {
    const int smem_sel = (NUM_K_TILES - 1) & 1;

    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half,
                   wmma::row_major> A_frag[WARP_TILE_M];
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half,
                   wmma::row_major> B_frag[WARP_TILE_N];

    #pragma unroll
    for (int i = 0; i < WARP_TILE_M; ++i) {
      // load 2 tiles -> reg, smem a -> frags a, warp_m 0~3
      const int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M;
      wmma::load_matrix_sync(A_frag[i], &s_a[smem_sel][warp_smem_a_m][0], BK+OFFSET);
    }

    #pragma unroll
    for (int j = 0; j < WARP_TILE_N; ++j) {
      // load 4 tiles -> reg, smem b -> frags b, warp_n 0~1
      const int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N;
      wmma::load_matrix_sync(B_frag[j], &s_b[smem_sel][0][warp_smem_b_n], BN+OFFSET);
    }

    #pragma unroll
    for (int i = 0; i < WARP_TILE_M; ++i) {
      #pragma unroll
      for (int j = 0; j < WARP_TILE_N; ++j) {
        wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]);
      }
    }
  }

  // Finally, store the accumulators back to C (row-major, ld = N).
  #pragma unroll
  for (int i = 0; i < WARP_TILE_M; ++i) {
    #pragma unroll
    for (int j = 0; j < WARP_TILE_N; ++j) {
      const int store_gmem_a_m = by * BM + warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M;
      const int store_gmem_a_n = bx * BN + warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N;
      wmma::store_matrix_sync(C + store_gmem_a_m * N + store_gmem_a_n, C_frag[i][j], N,
                              wmma::mem_row_major);
    }
  }
}
// Double-buffered m32n8k16 WMMA HGEMM: 2x4 warps per block, each warp
// computing a 2x4 grid of WMMA tiles, with cp.async prefetch of the next K
// tile into the alternate shared-memory buffer while the current one is
// being multiplied. A, B, C: all row_major half matrices.
//
// Launch layout: 256 threads (8 warps) per block; blockIdx.x / blockIdx.y
// select the BN-wide column block / BM-tall row block of C.
// OFFSET pads the innermost shared-memory dimension (in halves) to reduce
// bank conflicts.
//
// Preconditions:
//  - The thread -> shared-memory mapping is hard-coded for the default
//    tile template arguments and 256 threads.
//  - No partial-tile handling: M, N, K are expected to be multiples of
//    BM (128), BN (128), BK (16). M is not read inside the kernel.
//  - Each cp.async moves 16 bytes, so the source and destination addresses
//    must be 16-byte aligned.
template<const int WMMA_M=32, const int WMMA_N=8, const int WMMA_K=16,
         const int WMMA_TILE_M=2, const int WMMA_TILE_N=4,
         const int WARP_TILE_M=2, const int WARP_TILE_N=4,
         const int OFFSET=0>
__global__ void hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async_kernel(
  half* A, half* B, half* C, int M, int N, int K) {
  // 256 threads (8 warps) per block.
  const int bx = blockIdx.x;
  const int by = blockIdx.y;
  const int NUM_K_TILES = div_ceil(K, WMMA_K);
  constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M; // 32x2*2=128
  constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N; // 8x4*4=128
  constexpr int BK = WMMA_K; // 16
  // Two buffers per operand; with OFFSET=0 one buffer is 128x16 halves = 4KB.
  // OFFSET pads each row to reduce bank conflicts.
  __shared__ half s_a[2][BM][BK+OFFSET], s_b[2][BK][BN+OFFSET];

  // All threads of a warp must execute the same WMMA instructions, so the
  // tile assignment is derived from the warp id.
  const int tid = threadIdx.y * blockDim.x + threadIdx.x;
  const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block
  const int warp_m = warp_id / 4; // 0,1
  const int warp_n = warp_id % 4; // 0,1,2,3

  // 0. Shared-memory indices for the cooperative tile load.
  // s_a[BM][BK] (BM=128, BK=16), A row-major: each row holds 16 halves and
  // each thread copies 8 -> 2 threads per row, 128 rows -> 256 threads.
  int load_smem_a_m = tid / 2; // row 0~127
  int load_smem_a_k = (tid % 2 == 0) ? 0 : 8; // col 0,8
  // s_b[BK][BN] (BK=16, BN=128), B row-major: each row holds 128 halves and
  // each thread copies 8 -> 16 threads per row, 16 rows -> 256 threads.
  int load_smem_b_k = tid / 16; // row 0~15
  int load_smem_b_n = (tid % 16) * 8; // col 0,8,...,120
  // 1. Matching global-memory indices. Each block produces one BM x BN
  // block of C.
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c

  wmma::fragment<wmma::accumulator,
                 WMMA_M, WMMA_N, WMMA_K,
                 half> C_frag[WARP_TILE_M][WARP_TILE_N];

  #pragma unroll
  for (int i = 0; i < WARP_TILE_M; ++i) {
    #pragma unroll
    for (int j = 0; j < WARP_TILE_N; ++j) {
      wmma::fill_fragment(C_frag[i][j], __float2half(0.0f));
    }
  }

  // Prologue: K tile 0 is loaded into buffer 0.
  {
    int load_gmem_a_k = load_smem_a_k; // global col of a
    int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
    int load_gmem_b_k = load_smem_b_k; // global row of b
    int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;

    uint32_t load_smem_a_ptr = __cvta_generic_to_shared(
      &s_a[0][load_smem_a_m][load_smem_a_k]);
    CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16);

    uint32_t load_smem_b_ptr = __cvta_generic_to_shared(
      &s_b[0][load_smem_b_k][load_smem_b_n]);
    CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16);

    CP_ASYNC_COMMIT_GROUP();
    CP_ASYNC_WAIT_GROUP(0);
  }
  __syncthreads();

  // Main loop: iteration k multiplies tile k-1 (already resident) while
  // tile k is being copied into the other buffer.
  #pragma unroll
  for (int k = 1; k < NUM_K_TILES; ++k) { // start from 1
    int smem_sel = (k - 1) & 1; // k 1->0, k 2->1, k 3->0, ...
    int smem_sel_next = k & 1;  // k 1->1, k 2->0, k 3->1, ...

    int load_gmem_a_k = k * WMMA_K + load_smem_a_k; // global col of a
    int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
    int load_gmem_b_k = k * WMMA_K + load_smem_b_k; // global row of b
    int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;

    uint32_t load_smem_a_ptr = __cvta_generic_to_shared(
      &s_a[smem_sel_next][load_smem_a_m][load_smem_a_k]);
    CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16);

    uint32_t load_smem_b_ptr = __cvta_generic_to_shared(
      &s_b[smem_sel_next][load_smem_b_k][load_smem_b_n]);
    CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16);

    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half,
                   wmma::row_major> A_frag[WARP_TILE_M];
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half,
                   wmma::row_major> B_frag[WARP_TILE_N];

    #pragma unroll
    for (int i = 0; i < WARP_TILE_M; ++i) {
      // load 2 tiles -> reg, smem a -> frags a, warp_m 0~1
      const int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M;
      wmma::load_matrix_sync(A_frag[i], &s_a[smem_sel][warp_smem_a_m][0], BK+OFFSET);
    }

    #pragma unroll
    for (int j = 0; j < WARP_TILE_N; ++j) {
      // load 4 tiles -> reg, smem b -> frags b, warp_n 0~3
      const int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N;
      wmma::load_matrix_sync(B_frag[j], &s_b[smem_sel][0][warp_smem_b_n], BN+OFFSET);
    }

    #pragma unroll
    for (int i = 0; i < WARP_TILE_M; ++i) {
      #pragma unroll
      for (int j = 0; j < WARP_TILE_N; ++j) {
        wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]);
      }
    }

    // The prefetched tile must have landed, and every warp must be done with
    // the current buffer, before the next iteration swaps roles.
    CP_ASYNC_COMMIT_GROUP();
    CP_ASYNC_WAIT_GROUP(0);
    __syncthreads();
  }

  // Epilogue: multiply the last K tile. Tile t was loaded into buffer t & 1,
  // so the last tile (index NUM_K_TILES - 1) lives in buffer
  // (NUM_K_TILES - 1) & 1 — buffer 0 when the tile count is odd.
  {
    const int smem_sel = (NUM_K_TILES - 1) & 1;

    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half,
                   wmma::row_major> A_frag[WARP_TILE_M];
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half,
                   wmma::row_major> B_frag[WARP_TILE_N];

    #pragma unroll
    for (int i = 0; i < WARP_TILE_M; ++i) {
      // load 2 tiles -> reg, smem a -> frags a, warp_m 0~1
      const int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M;
      wmma::load_matrix_sync(A_frag[i], &s_a[smem_sel][warp_smem_a_m][0], BK+OFFSET);
    }

    #pragma unroll
    for (int j = 0; j < WARP_TILE_N; ++j) {
      // load 4 tiles -> reg, smem b -> frags b, warp_n 0~3
      const int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N;
      wmma::load_matrix_sync(B_frag[j], &s_b[smem_sel][0][warp_smem_b_n], BN+OFFSET);
    }

    #pragma unroll
    for (int i = 0; i < WARP_TILE_M; ++i) {
      #pragma unroll
      for (int j = 0; j < WARP_TILE_N; ++j) {
        wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]);
      }
    }
  }

  // Finally, store the accumulators back to C (row-major, ld = N).
  #pragma unroll
  for (int i = 0; i < WARP_TILE_M; ++i) {
    #pragma unroll
    for (int j = 0; j < WARP_TILE_N; ++j) {
      const int store_gmem_a_m = by * BM + warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M;
      const int store_gmem_a_n = bx * BN + warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N;
      wmma::store_matrix_sync(C + store_gmem_a_m * N + store_gmem_a_n, C_frag[i][j], N,
                              wmma::mem_row_major);
    }
  }
}
// --------------------- PyTorch bindings for custom kernel -----------------------
// Turn a token into a string literal.
#define STRINGFY(str) #str
// Register `func` under its own name (also used as the docstring). Expands to
// a call on a variable named `m`, which must be in scope at the expansion site.
// NOTE(review): presumably `m` is the pybind11 module object inside a
// PYBIND11_MODULE block — the binding site is not visible here; confirm.
#define TORCH_BINDING_COMMON_EXTENSION(func) \
m.def(STRINGFY(func), &func, STRINGFY(func));
// Throw std::runtime_error if tensor T does not have dtype `th_type`; the
// tensor's options are printed to stdout first.
// Expands to a bare `if` block, so call sites use it without a trailing
// semicolon.
#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \
if(((T).options().dtype() != (th_type))) { \
std::cout << "Tensor Info:" << (T).options() << std::endl; \
throw std::runtime_error("values must be "#th_type); \
}
// Throw std::runtime_error unless T.size(0) == S0 and T.size(1) == S1.
// Only the first two dimensions are inspected.
#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \
if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \
throw std::runtime_error("Tensor size mismatch!"); \
}
// Host launcher for hgemm_wmma_m16n16k16_naive_kernel: c = a * b.
// 1 warp per block (32 threads), m16n16k16. A, B, C: all row_major.
//
// a: (M, K), b: (K, N), c: (M, N), all torch::kHalf. `c` is written in place.
// Throws std::runtime_error on a dtype or shape mismatch, or if the kernel
// launch itself is rejected.
// The kernel handles whole tiles only, so M, N and K are expected to be
// multiples of 16.
// NOTE(review): raw data_ptr() values are handed to the kernel, so the
// tensors are assumed to be contiguous CUDA tensors — not checked here.
void hgemm_wmma_m16n16k16_naive(
  torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  const int N = b.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(b, K, N)
  CHECK_TORCH_TENSOR_SHAPE(c, M, N)
  constexpr int WMMA_M = 16;
  constexpr int WMMA_N = 16;
  constexpr int WMMA_K = 16;

  // One warp per block; one block per 16x16 tile of C.
  dim3 block(WARP_SIZE);
  dim3 grid(div_ceil(N, WMMA_N), div_ceil(M, WMMA_M));

  hgemm_wmma_m16n16k16_naive_kernel<
    WMMA_M, WMMA_N, WMMA_K><<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(b.data_ptr()),
    reinterpret_cast<half*>(c.data_ptr()),
    M, N, K
  );
  // A kernel launch returns no status; query it explicitly so that an
  // invalid launch configuration is reported instead of silently ignored.
  const cudaError_t launch_err = cudaGetLastError();
  if (launch_err != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(launch_err));
  }
}
// Host launcher for hgemm_wmma_m16n16k16_mma4x2_kernel: c = a * b.
// 256 threads (4x2 warps) per block, one 64x32 block of C per thread block.
//
// a: (M, K), b: (K, N), c: (M, N), all torch::kHalf. `c` is written in place.
// Throws std::runtime_error on a dtype or shape mismatch, or if the kernel
// launch itself is rejected.
// The kernel handles whole tiles only, so M, N and K are expected to be
// multiples of 64, 32 and 16 respectively.
// NOTE(review): raw data_ptr() values are handed to the kernel, so the
// tensors are assumed to be contiguous CUDA tensors — not checked here.
void hgemm_wmma_m16n16k16_mma4x2(
  torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  const int N = b.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(b, K, N)
  CHECK_TORCH_TENSOR_SHAPE(c, M, N)
  constexpr int WMMA_M = 16;
  constexpr int WMMA_N = 16;
  constexpr int WMMA_K = 16;
  constexpr int WMMA_TILE_M = 4;
  constexpr int WMMA_TILE_N = 2;
  constexpr int NUM_THREADS= (
    WMMA_TILE_M * WMMA_TILE_N * WARP_SIZE); // 4 * 2 * 32 = 256

  dim3 block(NUM_THREADS);
  dim3 grid(div_ceil(N, WMMA_N * WMMA_TILE_N),
            div_ceil(M, WMMA_M * WMMA_TILE_M));

  hgemm_wmma_m16n16k16_mma4x2_kernel<
    WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N><<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(b.data_ptr()),
    reinterpret_cast<half*>(c.data_ptr()),
    M, N, K
  );
  // A kernel launch returns no status; query it explicitly so that an
  // invalid launch configuration is reported instead of silently ignored.
  const cudaError_t launch_err = cudaGetLastError();
  if (launch_err != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(launch_err));
  }
}
// Host launcher for hgemm_wmma_m16n16k16_mma4x2_warp2x4_kernel: c = a * b.
// 256 threads (4x2 warps) per block, each warp computing 2x4 WMMA tiles, so
// one thread block produces a 128x128 block of C.
//
// a: (M, K), b: (K, N), c: (M, N), all torch::kHalf. `c` is written in place.
// Throws std::runtime_error on a dtype or shape mismatch, or if the kernel
// launch itself is rejected.
// The kernel handles whole tiles only, so M, N and K are expected to be
// multiples of 128, 128 and 16 respectively.
// NOTE(review): raw data_ptr() values are handed to the kernel, so the
// tensors are assumed to be contiguous CUDA tensors — not checked here.
void hgemm_wmma_m16n16k16_mma4x2_warp2x4(
  torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  const int N = b.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(b, K, N)
  CHECK_TORCH_TENSOR_SHAPE(c, M, N)
  constexpr int WMMA_M = 16;
  constexpr int WMMA_N = 16;
  constexpr int WMMA_K = 16;
  constexpr int WMMA_TILE_M = 4;
  constexpr int WMMA_TILE_N = 2;
  constexpr int WARP_TILE_M = 2;
  constexpr int WARP_TILE_N = 4;
  constexpr int NUM_THREADS= (
    WMMA_TILE_M * WMMA_TILE_N * WARP_SIZE); // 4 * 2 * 32 = 256

  dim3 block(NUM_THREADS);
  dim3 grid(div_ceil(N, WMMA_N * WMMA_TILE_N * WARP_TILE_N),
            div_ceil(M, WMMA_M * WMMA_TILE_M * WARP_TILE_M));

  hgemm_wmma_m16n16k16_mma4x2_warp2x4_kernel<
    WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N,
    WARP_TILE_M, WARP_TILE_N><<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(b.data_ptr()),
    reinterpret_cast<half*>(c.data_ptr()),
    M, N, K
  );
  // A kernel launch returns no status; query it explicitly so that an
  // invalid launch configuration is reported instead of silently ignored.
  const cudaError_t launch_err = cudaGetLastError();
  if (launch_err != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(launch_err));
  }
}
// Host launcher for hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async_kernel:
// c = a * b using double-buffered shared memory with a padding of 8 halves
// per shared-memory row.
// 256 threads (4x2 warps) per block, each warp computing 2x4 WMMA tiles, so
// one thread block produces a 128x128 block of C.
//
// a: (M, K), b: (K, N), c: (M, N), all torch::kHalf. `c` is written in place.
// Throws std::runtime_error on a dtype or shape mismatch, or if the kernel
// launch itself is rejected.
// The kernel handles whole tiles only, so M, N and K are expected to be
// multiples of 128, 128 and 16 respectively.
// NOTE(review): raw data_ptr() values are handed to the kernel, so the
// tensors are assumed to be contiguous CUDA tensors — not checked here.
void hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async(
  torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  const int N = b.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(b, K, N)
  CHECK_TORCH_TENSOR_SHAPE(c, M, N)
  constexpr int WMMA_M = 16;
  constexpr int WMMA_N = 16;
  constexpr int WMMA_K = 16;
  constexpr int WMMA_TILE_M = 4;
  constexpr int WMMA_TILE_N = 2;
  constexpr int WARP_TILE_M = 2;
  constexpr int WARP_TILE_N = 4;
  constexpr int NUM_THREADS= (
    WMMA_TILE_M * WMMA_TILE_N * WARP_SIZE); // 4 * 2 * 32 = 256

  dim3 block(NUM_THREADS);
  dim3 grid(div_ceil(N, WMMA_N * WMMA_TILE_N * WARP_TILE_N),
            div_ceil(M, WMMA_M * WMMA_TILE_M * WARP_TILE_M));

  // Last template argument: OFFSET = 8 halves of shared-memory row padding.
  hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async_kernel<
    WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N,
    WARP_TILE_M, WARP_TILE_N, 8><<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(b.data_ptr()),
    reinterpret_cast<half*>(c.data_ptr()),
    M, N, K
  );
  // A kernel launch returns no status; query it explicitly so that an
  // invalid launch configuration is reported instead of silently ignored.
  const cudaError_t launch_err = cudaGetLastError();
  if (launch_err != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(launch_err));
  }
}
// Host launcher for hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async_kernel:
// c = a * b using m32n8k16 fragments and double-buffered shared memory with
// a padding of 8 halves per shared-memory row.
// 256 threads (2x4 warps) per block, each warp computing 2x4 WMMA tiles, so
// one thread block produces a 128x128 block of C.
//
// a: (M, K), b: (K, N), c: (M, N), all torch::kHalf. `c` is written in place.
// Throws std::runtime_error on a dtype or shape mismatch, or if the kernel
// launch itself is rejected.
// The kernel handles whole tiles only, so M, N and K are expected to be
// multiples of 128, 128 and 16 respectively.
// NOTE(review): raw data_ptr() values are handed to the kernel, so the
// tensors are assumed to be contiguous CUDA tensors — not checked here.
void hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async(
  torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  const int N = b.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(b, K, N)
  CHECK_TORCH_TENSOR_SHAPE(c, M, N)
  constexpr int WMMA_M = 32;
  constexpr int WMMA_N = 8;
  constexpr int WMMA_K = 16;
  constexpr int WMMA_TILE_M = 2;
  constexpr int WMMA_TILE_N = 4;
  constexpr int WARP_TILE_M = 2;
  constexpr int WARP_TILE_N = 4;
  constexpr int NUM_THREADS= (
    WMMA_TILE_M * WMMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256

  dim3 block(NUM_THREADS);
  dim3 grid(div_ceil(N, WMMA_N * WMMA_TILE_N * WARP_TILE_N),
            div_ceil(M, WMMA_M * WMMA_TILE_M * WARP_TILE_M));

  // Last template argument: OFFSET = 8 halves of shared-memory row padding.
  hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async_kernel<
    WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N,
    WARP_TILE_M, WARP_TILE_N, 8><<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(b.data_ptr()),
    reinterpret_cast<half*>(c.data_ptr()),
    M, N, K
  );
  // A kernel launch returns no status; query it explicitly so that an
  // invalid launch configuration is reported instead of silently ignored.
  const cudaError_t launch_err = cudaGetLastError();
  if (launch_err != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(launch_err));
  }
}