dequant.py
import itertools

import torch
import triton
from triton import language as tl


def make_dequant_configs(block_sizes, num_warps):
    """Build one Triton autotune config per (block size, num_warps) pair."""
    configs = []
    for bs, ws in itertools.product(block_sizes, num_warps):
        configs.append(triton.Config({"X_BLOCK": bs}, num_warps=ws))
    return configs


DEFAULT_DEQUANT_CONFIGS = make_dequant_configs([128, 256, 512, 1024], [4, 8])
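
# Example: one way to enable autotuning over DEFAULT_DEQUANT_CONFIGS is to
# restore the decorator that is commented out on the kernel below,
#
#     @triton.autotune(DEFAULT_DEQUANT_CONFIGS, key=["numels"])
#     @triton.jit
#     def dequant_kernel_248(...):
#
# and drop the explicit X_BLOCK argument from the launch call in dequant248,
# so Triton caches the best (X_BLOCK, num_warps) pair per distinct numels.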


# @triton.autotune(DEFAULT_DEQUANT_CONFIGS, key=["numels"])
@triton.jit
def dequant_kernel_248(
    g_idx_ptr,
    scales_ptr,
    qweight_ptr,
    qzeros_ptr,
    out_ptr,
    numels,
    maxq: tl.constexpr,
    bits: tl.constexpr,
    outfeatures: tl.constexpr,
    num_groups: tl.constexpr,
    X_BLOCK: tl.constexpr,
):
    # Block indexing: each program handles X_BLOCK contiguous elements of out.
    xoffset = tl.program_id(0) * X_BLOCK
    x_index = xoffset + tl.arange(0, X_BLOCK)
    xmask = x_index < numels
    row_idx = x_index // outfeatures
    col_idx = x_index % outfeatures
    elements_per_feature: tl.constexpr = 32 // bits  # values packed per int32 word

    # Load the per-row group index and the packed word that holds this weight
    # (logical rows are packed along dim 0 of qweight).
    g_idx = tl.load(g_idx_ptr + (row_idx), None, eviction_policy="evict_last")
    qweights = tl.load(
        qweight_ptr + (col_idx + (outfeatures * (row_idx // elements_per_feature))),
        None,
    )
    wf_weights = (row_idx % elements_per_feature) * bits
    wf_zeros = (col_idx % elements_per_feature) * bits

    # Wrap negative group indices (Python-style), then bounds-check.
    tmp1 = g_idx + num_groups
    tmp2 = g_idx < 0
    tl.device_assert(g_idx >= 0, "index out of bounds: g_idx must be >= 0")
    groups = tl.where(tmp2, tmp1, g_idx)
    scales = tl.load(scales_ptr + (col_idx + (outfeatures * groups)), None).to(
        tl.float32
    )

    # Unpack weights: shift the packed word down, mask off the low `bits` bits.
    weights = qweights >> wf_weights
    weights = weights & maxq

    # Unpack zeros (logical columns are packed along dim 1 of qzeros).
    qzero_ncols: tl.constexpr = outfeatures // elements_per_feature
    qzeros = tl.load(
        qzeros_ptr + ((qzero_ncols * groups) + (col_idx // elements_per_feature)),
        None,
        eviction_policy="evict_last",
    )
    zeros = qzeros >> wf_zeros
    zeros = zeros & maxq

    # Dequantize: GPTQ stores zero-points offset by one, hence the +1.
    zeros = zeros + 1
    weights = weights - zeros
    weights = weights.to(tl.float32)
    weights = scales * weights
    tl.store(out_ptr + (x_index), weights, mask=xmask)


def dequant248(qweight, scales, qzeros, g_idx, bits, maxq=None):
    """
    Launcher for the Triton dequant kernel. Only valid for bits in {2, 4, 8}.
    """
    num_groups = scales.shape[0]
    outfeatures = scales.shape[1]
    infeatures = g_idx.shape[0]
    out = torch.empty((infeatures, outfeatures), device="cuda", dtype=torch.float16)
    numels = out.numel()
    maxq = 2**bits - 1 if maxq is None else maxq
    grid = lambda meta: (triton.cdiv(numels, meta["X_BLOCK"]),)  # noqa: E731
    dequant_kernel_248[grid](
        g_idx,
        scales,
        qweight,
        qzeros,
        out,
        numels,
        maxq=maxq,
        bits=bits,
        outfeatures=outfeatures,
        num_groups=num_groups,
        # X_BLOCK must be supplied explicitly while the autotuner is disabled;
        # 512 is one of the DEFAULT_DEQUANT_CONFIGS block sizes.
        X_BLOCK=512,
    )
    return out
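

# dequant248_reference is a pure-PyTorch sketch of the same unpacking math,
# handy for checking the kernel on small inputs; it is an assumption based on
# the layout the kernel reads (32 // bits values per int32 word, rows packed
# in qweight, columns packed in qzeros, zero-points offset by one) and is not
# tuned for speed.
def dequant248_reference(qweight, scales, qzeros, g_idx, bits):
    elems = 32 // bits
    maxq = 2**bits - 1
    rows = torch.arange(g_idx.shape[0], device=qweight.device)
    cols = torch.arange(scales.shape[1], device=qweight.device)
    groups = g_idx.long()
    # Logical row r lives in packed row r // elems at bit offset (r % elems) * bits.
    w = (qweight[rows // elems] >> ((rows % elems) * bits)[:, None]) & maxq
    # Logical column c lives in packed column c // elems of its group's row.
    z = (qzeros[groups][:, cols // elems] >> ((cols % elems) * bits)[None, :]) & maxq
    return scales[groups].float() * (w - (z + 1)).float()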


def quant_matmul_248(
    input, qweight, scales, qzeros, g_idx, bits, maxq=None, transpose=False
):
    """Dequantize qweight to float16, then matmul input against it (or its transpose)."""
    W = dequant248(qweight, scales, qzeros, g_idx, bits, maxq=maxq)
    if transpose:
        return input @ W.t()
    return input @ W
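

# Usage sketch (requires a CUDA device). Shapes follow the packing convention
# the kernel assumes: qweight is (infeatures // (32 // bits), outfeatures)
# int32, qzeros is (num_groups, outfeatures // (32 // bits)) int32. Tensor
# contents are random placeholders, useful only for smoke-testing.
if __name__ == "__main__":
    bits, infeatures, outfeatures, group_size = 4, 128, 256, 32
    elems = 32 // bits
    num_groups = infeatures // group_size
    qweight = torch.randint(
        0, 2**31, (infeatures // elems, outfeatures), dtype=torch.int32, device="cuda"
    )
    qzeros = torch.randint(
        0, 2**31, (num_groups, outfeatures // elems), dtype=torch.int32, device="cuda"
    )
    scales = torch.rand(num_groups, outfeatures, dtype=torch.float16, device="cuda")
    g_idx = torch.arange(infeatures, dtype=torch.int32, device="cuda") // group_size
    x = torch.randn(8, infeatures, dtype=torch.float16, device="cuda")

    # Kernel output should match the PyTorch reference after fp16 rounding.
    W = dequant248(qweight, scales, qzeros, g_idx, bits)
    torch.testing.assert_close(
        W, dequant248_reference(qweight, scales, qzeros, g_idx, bits).half()
    )
    y = quant_matmul_248(x, qweight, scales, qzeros, g_idx, bits)
    print(y.shape)  # torch.Size([8, 256])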