-
Notifications
You must be signed in to change notification settings - Fork 48
/
benchmark_varlen_kvpacked_func.py
243 lines (225 loc) · 6.98 KB
/
benchmark_varlen_kvpacked_func.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
from flash_attn import flash_attn_varlen_kvpacked_func
import os
import torch
import torch.distributed as dist
from ring_flash_attn import (
ring_flash_attn_varlen_kvpacked_func,
zigzag_ring_flash_attn_varlen_kvpacked_func,
llama3_flash_attn_varlen_kvpacked_func,
llama3_flash_attn_prepare_cu_seqlens,
)
def benchmark(
    f,
    use_double_cu_seqlens,
    use_llama3=False,
    num_iter=100,
    forward_only=True,
    log=True,
    profile=False,
):
    """Time ``num_iter`` calls of a varlen kv-packed attention function on this rank.

    Inputs use the Llama-3 8B attention shape (32 query heads, 8 kv heads,
    head_dim 128) with a local sequence of 8192 tokens in bfloat16. Successive
    iterations cycle through a fixed list of ``cu_seqlens`` packings.

    Args:
        f: the attention function under test.
        use_double_cu_seqlens: if True, call ``f`` with separate q/k
            ``cu_seqlens`` and ``max_seqlen`` arguments (both set to the same
            value); otherwise pass a single ``cu_seqlens`` / ``max_seqlen``
            pair. Ignored when ``use_llama3`` is True.
        use_llama3: if True, precompute per-rank q/k ``cu_seqlens`` via
            ``llama3_flash_attn_prepare_cu_seqlens`` and call ``f`` with the
            extra ``heads_k_stride`` / ``local_k_slice`` arguments.
        num_iter: number of calls inside the timed region.
        forward_only: if True, run under ``torch.no_grad()`` with no backward;
            otherwise each call is followed by ``out.backward(dout)``.
        log: if True, rank 0 prints throughput and elapsed wall time.
        profile: if True, record a torch profiler trace under
            ``./benchmark/logs/<f.__name__>/rank_<rank>``.

    Requires an initialized ``torch.distributed`` process group; rank ``r``
    runs on device ``cuda:r``.
    """
    dtype = torch.bfloat16
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    deterministic = False

    # config of llama3 8B
    seqlen = 1024 * 8
    num_heads = 32
    num_kv_heads = 8
    head_dim = 128
    causal = True

    assert seqlen % (2 * world_size) == 0
    assert head_dim % 8 == 0

    q = torch.randn(
        seqlen, num_heads, head_dim, device=device, dtype=dtype, requires_grad=True
    )
    kv = torch.randn(
        seqlen,
        2,
        num_kv_heads,
        head_dim,
        device=device,
        dtype=dtype,
        requires_grad=True,
    )
    dout = torch.randn(seqlen, num_heads, head_dim, device=device, dtype=dtype)

    # Local packings of the ``seqlen`` tokens; each list starts at 0 and ends at 8192.
    cu_seqlens_list = [
        torch.tensor([0, 8192], device=device, dtype=torch.int32),
        torch.tensor([0, 256, 7648, 8192], device=device, dtype=torch.int32),
        torch.tensor([0, 4096, 8192], device=device, dtype=torch.int32),
        torch.tensor(
            [0, 3104, 6304, 7904, 8064, 8192], device=device, dtype=torch.int32
        ),
    ]
    num_packings = len(cu_seqlens_list)

    if use_llama3:
        cu_seqlens_q_list = []
        cu_seqlens_k_list = []
        max_seqlen_q_list = []
        max_seqlen_k_list = []
        local_k_slice_list = []
        for cu_seqlens in cu_seqlens_list:
            # The local packing scaled by world_size is used as the global packing.
            (
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                local_k_slice,
            ) = llama3_flash_attn_prepare_cu_seqlens(
                cu_seqlens * world_size,
                causal=causal,
                rank=rank,
                world_size=world_size,
            )
            cu_seqlens_q_list.append(cu_seqlens_q)
            cu_seqlens_k_list.append(cu_seqlens_k)
            max_seqlen_q_list.append(max_seqlen_q)
            max_seqlen_k_list.append(max_seqlen_k)
            local_k_slice_list.append(local_k_slice)
    else:
        # Longest packed sequence = max difference between consecutive offsets.
        # (Subtracting cu_seqlens[:1] would just subtract 0 and yield the total length.)
        max_seqlen_list = [
            (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            for cu_seqlens in cu_seqlens_list
        ]

    profiler = None
    if profile:
        torch.backends.cudnn.benchmark = True
        profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(
                wait=5,
                warmup=5,
                active=5,
            ),
            record_shapes=True,
            profile_memory=True,
            with_flops=True,
            with_modules=True,
            with_stack=True,
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                os.path.join(f"./benchmark/logs/{f.__name__}", f"rank_{rank}")
            ),
        )

    def wrapper(i: int):
        """Run ``f`` once on the ``i``-th packing (cycling through the list)."""
        idx = i % num_packings
        if use_llama3:
            return f(
                q,
                kv,
                cu_seqlens_q_list[idx],
                cu_seqlens_k_list[idx],
                max_seqlen_q_list[idx],
                max_seqlen_k_list[idx],
                heads_k_stride=4,
                local_k_slice=local_k_slice_list[idx],
                causal=causal,
                window_size=(-1, -1),
                alibi_slopes=None,
                deterministic=deterministic,
                return_attn_probs=False,
            )
        if use_double_cu_seqlens:
            return f(
                q,
                kv,
                cu_seqlens_list[idx],
                cu_seqlens_list[idx],
                max_seqlen_list[idx],
                max_seqlen_list[idx],
                causal=causal,
                window_size=(-1, -1),
                alibi_slopes=None,
                deterministic=deterministic,
                return_attn_probs=False,
            )
        return f(
            q,
            kv,
            cu_seqlens_list[idx],
            max_seqlen_list[idx],
            causal=causal,
            window_size=(-1, -1),
            alibi_slopes=None,
            deterministic=deterministic,
            return_attn_probs=False,
        )

    if profiler is not None:
        profiler.start()

    begin = torch.cuda.Event(enable_timing=True)
    begin.record()

    if forward_only:
        with torch.no_grad():
            for i in range(num_iter):
                _ = wrapper(i)
                # Step in this branch too; otherwise the wait/warmup/active
                # schedule never advances and no trace is written.
                if profiler is not None:
                    profiler.step()
    else:
        for i in range(num_iter):
            q.grad = None
            kv.grad = None
            out = wrapper(i)
            out.backward(dout)
            if profiler is not None:
                profiler.step()

    end = torch.cuda.Event(enable_timing=True)
    end.record()
    # elapsed_time is only valid once both events have completed on the device.
    torch.cuda.synchronize(device=device)
    elapsed = begin.elapsed_time(end) / 1000.0  # ms -> s

    if profiler is not None:
        profiler.stop()

    if rank == 0 and log:
        print(f"{num_iter / elapsed} iter/s, {elapsed} sec")
if __name__ == "__main__":
    # Entry point: benchmark each attention implementation on every rank of an
    # NCCL process group, then release the group.
    dist.init_process_group("nccl")
    rank = dist.get_rank()

    forward_only = False
    profile = False
    num_iter = 500 if forward_only else 100

    # (function, use_double_cu_seqlens, use_llama3)
    cases = [
        (flash_attn_varlen_kvpacked_func, True, False),
        (ring_flash_attn_varlen_kvpacked_func, False, False),
        (zigzag_ring_flash_attn_varlen_kvpacked_func, False, False),
        (llama3_flash_attn_varlen_kvpacked_func, True, True),
    ]

    for f, use_double_cu_seqlens, use_llama3 in cases:
        torch.cuda.empty_cache()
        if rank == 0:
            print(f"# {f.__name__}")
        # The first run is not logged or profiled (acts as a warm-up);
        # the second run is logged and optionally profiled.
        for log in (False, True):
            benchmark(
                f,
                use_double_cu_seqlens,
                use_llama3=use_llama3,
                forward_only=forward_only,
                num_iter=num_iter,
                log=log,
                profile=profile and log,
            )

    dist.destroy_process_group()