-
Notifications
You must be signed in to change notification settings - Fork 0
/
topk_kernel.cu
151 lines (106 loc) · 4.54 KB
/
topk_kernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <random>
// Capture the current wall-clock time into a fresh local variable named __var__.
#define timestamp(__var__) auto __var__ = std::chrono::system_clock::now();
// Elapsed wall-clock time between two system_clock time points, in seconds.
inline double getDuration(std::chrono::time_point<std::chrono::system_clock> a,
                          std::chrono::time_point<std::chrono::system_clock> b)
{
    const std::chrono::duration<double> elapsed = b - a;
    return elapsed.count();
}
using namespace std;
// Warps per thread block; each warp processes one row, so a block handles
// WARPS_PER_BLOCK rows.
const int WARPS_PER_BLOCK = 16;
// Number of rows. Rounded down to a multiple of 256 so the grid dimension
// N / WARPS_PER_BLOCK divides evenly (232960 >> 8 << 8 == 232960 here).
const int N = 232960 >> 8 << 8;
//const int N = 4096;
// dim_in: features per row; dim_out: how many largest values to keep per row.
const int dim_in = 256, dim_out = 64;
__global__ void topk(float *, float *, unsigned int *);
// Benchmark driver: allocates N rows of dim_in floats in unified memory,
// fills them with uniform random values, then times the topk kernel that
// extracts the dim_out largest values (and their column indices) per row.
int main() {
    cout << "Test TopK kernel" << endl;
    cout << "N = " << N << ", dim_in = " << dim_in << ", dim_out = " << dim_out << ", preparing data..." << endl;
    // Abort with a readable message on any CUDA runtime failure. Kernel
    // launches report configuration errors only via cudaGetLastError(), and
    // asynchronous execution errors only at the next synchronizing call.
    auto check = [](cudaError_t err, const char *what) {
        if (err != cudaSuccess) {
            cerr << "CUDA error (" << what << "): " << cudaGetErrorString(err) << endl;
            exit(EXIT_FAILURE);
        }
    };
    float *data, *value;
    unsigned int *indices;
    check(cudaMallocManaged(&data, N * dim_in * sizeof(float)), "alloc data");
    check(cudaMallocManaged(&value, N * dim_out * sizeof(float)), "alloc value");
    check(cudaMallocManaged(&indices, N * dim_out * sizeof(unsigned int)), "alloc indices");
    default_random_engine engine;
    engine.seed(123);  // fixed seed for reproducible input data
    uniform_real_distribution<float> rd(0, 1);
    generate(data, data + N * dim_in, [&](){ return rd(engine); });
    // Per block: a float sort buffer plus an index-tracking buffer, each
    // covering WARPS_PER_BLOCK rows of dim_in entries (32 KB total here).
    unsigned int shared_mem_size = WARPS_PER_BLOCK * dim_in * (sizeof(float) + sizeof(unsigned int));
    cout << "Config GridDim = " << N / WARPS_PER_BLOCK << ", BlockDim = " << WARPS_PER_BLOCK * 32 << ", shared_mem_size = " << shared_mem_size << endl;
    dim3 grid(N / WARPS_PER_BLOCK, 1, 1);
    dim3 block(WARPS_PER_BLOCK * 32, 1, 1);
    int times = 10;
    // Warm-up launches so the timed loop below measures steady state
    // (unified-memory migration and first-launch overheads excluded).
    for (int i = 0; i < times; i++) {
        topk <<< grid, block, shared_mem_size >>> (data, value, indices);
    }
    check(cudaGetLastError(), "warm-up launch");
    check(cudaDeviceSynchronize(), "warm-up sync");
    double measured_time = 0;
    for (int i = 0; i < times; i++) {
        timestamp(t0);
        topk <<< grid, block, shared_mem_size >>> (data, value, indices);
        // Block until the kernel finishes so t1 - t0 covers its execution.
        check(cudaDeviceSynchronize(), "timed run");
        timestamp(t1);
        measured_time += getDuration(t0, t1);
    }
    check(cudaGetLastError(), "timed launch");
    cout << "top-k time = " << measured_time / times * 1000 << " ms" << endl;
    // Print the first row's dim_out results as a quick sanity check.
    for (int i = 0; i < dim_out; i += 1) {
        cout << "value[" << i << "] = " << *(value + i) << endl;
    }
    for (int i = 0; i < dim_out; i += 1) {
        cout << "indices[" << i << "] = " << *(indices + i) << endl;
    }
    check(cudaFree(data), "free data");
    check(cudaFree(value), "free value");
    check(cudaFree(indices), "free indices");
    return 0;
}
// Warp-level top-k: each warp loads one row of dim_in floats into shared
// memory, sorts it in DESCENDING order with odd-even transposition sort, and
// writes out the dim_out largest values plus their original column indices.
//
// Launch contract (see main): blockDim.x == WARPS_PER_BLOCK * 32,
// gridDim.x == N / WARPS_PER_BLOCK, dynamic shared memory of
// WARPS_PER_BLOCK * dim_in * (sizeof(float) + sizeof(unsigned int)) bytes.
__global__ void topk(float *data, float *value, unsigned int *indices) {
    // First half of the dynamic shared buffer holds the values being sorted;
    // the second half tracks each value's original column within its row.
    extern __shared__ float buffer[];
    unsigned int *track = (unsigned int*) &buffer[WARPS_PER_BLOCK * dim_in];
    const int warp_id = threadIdx.x / 32;
    const int local_tid = threadIdx.x % 32;
    const int warp_offset = WARPS_PER_BLOCK * dim_in;   // elements handled per block
    const int feature_per_warp = dim_in / 32;           // elements loaded per lane
    float v_holder;
    unsigned int idx_holder;
    // Cooperative load: each lane copies feature_per_warp consecutive elements
    // of this warp's row and records their original column numbers.
    #pragma unroll
    for (unsigned int i = 0; i < feature_per_warp; i += 1) {
        buffer[warp_id * dim_in + feature_per_warp * local_tid + i] = data[blockIdx.x * warp_offset + warp_id * dim_in + feature_per_warp * local_tid + i];
        track[warp_id * dim_in + feature_per_warp * local_tid + i] = local_tid * feature_per_warp + i;
    }
    __syncwarp();
    // Odd-even transposition sort: dim_in/2 iterations of (even phase + odd
    // phase) give the dim_in total phases required to fully sort dim_in
    // elements. Only lanes of the owning warp touch a row, so __syncwarp()
    // is a sufficient barrier between phases.
    #pragma unroll
    for (int iter = 0; iter < dim_in / 2; iter += 1) {
        // Even phase: compare-and-swap pairs (0,1), (2,3), ..., keeping the
        // larger value first. 32 lanes x (dim_in/64) strides cover the row.
        for (int i = 0; i < dim_in; i += 64) {
            int curr_idx = warp_id * dim_in + 2 * local_tid + i;
            if (buffer[curr_idx] < buffer[curr_idx + 1]) {
                v_holder = buffer[curr_idx];
                buffer[curr_idx] = buffer[curr_idx + 1];
                buffer[curr_idx + 1] = v_holder;
                idx_holder = track[curr_idx];
                track[curr_idx] = track[curr_idx + 1];
                track[curr_idx + 1] = idx_holder;
            }
        }
        __syncwarp();
        // Odd phase: compare-and-swap pairs (1,2), (3,4), ...
        // BUGFIX: the guard must test the position WITHIN the row, not the
        // absolute shared-memory index. The previous `curr_idx < dim_in - 1`
        // compared warp_id * dim_in + row_pos against dim_in - 1, which is
        // always false for every warp except warp 0 — so those rows never ran
        // the odd phase and were left only partially sorted. Guarding on
        // row_pos also keeps the last element of a row from being paired with
        // the first element of the next warp's row.
        for (int i = 0; i < dim_in; i += 64) {
            int row_pos = 2 * local_tid + i + 1;
            int curr_idx = warp_id * dim_in + row_pos;
            if (row_pos < dim_in - 1 && buffer[curr_idx] < buffer[curr_idx + 1]) {
                v_holder = buffer[curr_idx];
                buffer[curr_idx] = buffer[curr_idx + 1];
                buffer[curr_idx + 1] = v_holder;
                idx_holder = track[curr_idx];
                track[curr_idx] = track[curr_idx + 1];
                track[curr_idx + 1] = idx_holder;
            }
        }
        __syncwarp();
    }
    __syncwarp();
    // The row is now sorted in descending order; emit its first dim_out
    // values and their original column indices (32 lanes, dim_out/32 strides).
    for (int i = 0; i < dim_out; i += 32) {
        value[blockIdx.x * WARPS_PER_BLOCK * dim_out + warp_id * dim_out + local_tid + i] = buffer[warp_id * dim_in + local_tid + i];
        indices[blockIdx.x * WARPS_PER_BLOCK * dim_out + warp_id * dim_out + local_tid + i] = track[warp_id * dim_in + local_tid + i];
    }
}