// embedding.cu (forked from DefTruth/CUDA-Learn-Notes)
// FP32/FP16 embedding-lookup kernels (scalar, unrolled, and 128-bit packed)
// with PyTorch bindings.
#include <stdio.h>
#include <stdlib.h>
#include <iostream>
#include <stdexcept>
#include <float.h>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <torch/types.h>
#include <torch/extension.h>
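// Reinterpret 4 consecutive floats (or 8 consecutive halfs) as one float4 so
// a single 128-bit load/store instruction moves 16 bytes at a time.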
#define FLOAT4(value) (reinterpret_cast<float4 *>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])
// One block per index; thread tx copies element tx of the selected row.
// Requires blockDim.x == emb_size (so emb_size <= 1024).
__global__ void embedding_f32_kernel(const int *idx, float *weight, float *output, int n, int emb_size)
{
  int tx = threadIdx.x;
  int bx = blockIdx.x;
  int offset = idx[bx] * emb_size;
  output[bx * emb_size + tx] = weight[offset + tx];
}
// One block per index; each thread copies 4 consecutive floats, so the
// per-thread offset must be scaled by 4 (blockDim.x == emb_size / 4).
__global__ void embedding_f32x4_kernel(const int *idx, float *weight, float *output, int n, int emb_size)
{
  int tx = threadIdx.x;
  int bx = blockIdx.x;
  int offset = idx[bx] * emb_size;
  int dst = bx * emb_size + 4 * tx;
  int src = offset + 4 * tx;
  output[dst + 0] = weight[src + 0];
  output[dst + 1] = weight[src + 1];
  output[dst + 2] = weight[src + 2];
  output[dst + 3] = weight[src + 3];
}
// Same layout as f32x4, but one vectorized 128-bit load/store per thread.
__global__ void embedding_f32x4_pack_kernel(const int *idx, float *weight, float *output, int n, int emb_size)
{
  int tx = threadIdx.x;
  int bx = blockIdx.x;
  int offset = idx[bx] * emb_size;
  LDST128BITS(output[bx * emb_size + 4 * tx]) = LDST128BITS(weight[offset + 4 * tx]);
}
// Half-precision variant: thread tx copies element tx of the selected row.
__global__ void embedding_f16_kernel(const int *idx, half *weight, half *output, int n, int emb_size)
{
  int tx = threadIdx.x;
  int bx = blockIdx.x;
  int offset = idx[bx] * emb_size;
  output[bx * emb_size + tx] = weight[offset + tx];
}
// Each thread copies 8 consecutive half values (blockDim.x == emb_size / 8),
// so the per-thread offset must be scaled by 8.
__global__ void embedding_f16x8_kernel(const int *idx, half *weight, half *output, int n, int emb_size)
{
  int tx = threadIdx.x;
  int bx = blockIdx.x;
  int offset = idx[bx] * emb_size;
  int dst = bx * emb_size + 8 * tx;
  int src = offset + 8 * tx;
#pragma unroll
  for (int i = 0; i < 8; ++i)
  {
    output[dst + i] = weight[src + i];
  }
}
// 8 halfs = 128 bits, so a single float4-sized load/store moves a full pack.
__global__ void embedding_f16x8_pack_kernel(const int *idx, half *weight, half *output, int n, int emb_size)
{
  int tx = threadIdx.x;
  int bx = blockIdx.x;
  int offset = idx[bx] * emb_size;
  LDST128BITS(output[bx * emb_size + 8 * tx]) = LDST128BITS(weight[offset + 8 * tx]);
}
// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
m.def(STRINGFY(func), &func, STRINGFY(func));
#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \
if (((T).options().dtype() != (th_type))) \
{ \
std::cout << "Tensor Info:" << (T).options() << std::endl; \
throw std::runtime_error("values must be " #th_type); \
}
#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \
if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) \
{ \
throw std::runtime_error("Tensor size mismatch!"); \
}
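// Generates a host-side wrapper that validates dtypes and launches the
// matching kernel with one block per index and emb_size / n_elements threads
// per block; emb_size must be divisible by n_elements and the resulting block
// size must not exceed 1024 threads.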
#define TORCH_BINDING_EMBEDDING(packed_type, th_type, element_type, n_elements) \
void embedding_##packed_type( \
torch::Tensor a, torch::Tensor weight, torch::Tensor o) \
{ \
CHECK_TORCH_TENSOR_DTYPE(a, (torch::kInt32)); \
CHECK_TORCH_TENSOR_DTYPE(weight, (th_type)); \
CHECK_TORCH_TENSOR_DTYPE(o, (th_type)); \
\
const int N = a.size(0); \
const int emb_size = weight.size(1); \
dim3 block(emb_size / n_elements); \
dim3 grid(N); \
embedding_##packed_type##_kernel<<<grid, block>>>( \
reinterpret_cast<int *>(a.data_ptr()), \
reinterpret_cast<element_type *>(weight.data_ptr()), \
reinterpret_cast<element_type *>(o.data_ptr()), N, emb_size); \
}
TORCH_BINDING_EMBEDDING(f32, torch::kFloat32, float, 1)
TORCH_BINDING_EMBEDDING(f32x4, torch::kFloat32, float, 4)
TORCH_BINDING_EMBEDDING(f32x4_pack, torch::kFloat32, float, 4)
TORCH_BINDING_EMBEDDING(f16, torch::kHalf, half, 1)
TORCH_BINDING_EMBEDDING(f16x8, torch::kHalf, half, 8)
TORCH_BINDING_EMBEDDING(f16x8_pack, torch::kHalf, half, 8)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
TORCH_BINDING_COMMON_EXTENSION(embedding_f32);
TORCH_BINDING_COMMON_EXTENSION(embedding_f32x4);
TORCH_BINDING_COMMON_EXTENSION(embedding_f32x4_pack);
TORCH_BINDING_COMMON_EXTENSION(embedding_f16);
TORCH_BINDING_COMMON_EXTENSION(embedding_f16x8);
TORCH_BINDING_COMMON_EXTENSION(embedding_f16x8_pack);
}
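/*
 * Usage sketch (not part of the original file). Assumes the file is built as
 * a PyTorch extension via torch.utils.cpp_extension.load; the module name
 * "embedding_lib" and the shapes below are illustrative only:
 *
 *   import torch
 *   from torch.utils.cpp_extension import load
 *
 *   lib = load(name="embedding_lib", sources=["embedding.cu"])
 *   idx = torch.randint(0, 4096, (1024,), device="cuda", dtype=torch.int32)
 *   weight = torch.randn(4096, 512, device="cuda")  # emb_size=512, divisible by 4 and 8
 *   out = torch.empty(1024, 512, device="cuda")
 *   lib.embedding_f32x4_pack(idx, weight, out)
 *   torch.testing.assert_close(out, weight[idx.long()])  # matches PyTorch indexing
 */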