Skip to content

Commit

Permalink
skip tests
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
manman-ren committed Nov 26, 2024
1 parent 18e6b46 commit 9ba5656
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
1 change: 1 addition & 0 deletions test/test_gpu/skip_tests_h100_pytorch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ flash_attention:
- triton_tutorial_flash_v2_tma
- triton_tutorial_flash_v2_ws
- triton_tutorial_flash_v2_tma_ws
+ - triton_tutorial_flash_v2_tma_ws_persistent
fp8_attention:
- colfax_fmha
# triton_flash_v2 requires triton-main
Expand Down
1 change: 1 addition & 0 deletions test/test_gpu/skip_tests_h100_triton_main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ flash_attention:
# _ws kernels require Triton with warp specialization
- triton_tutorial_flash_v2_ws
- triton_tutorial_flash_v2_tma_ws
+ - triton_tutorial_flash_v2_tma_ws_persistent
fp8_attention:
# fb-only kernel
- colfax_fmha
Expand Down
4 changes: 2 additions & 2 deletions tritonbench/kernels/triton_fused_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _attn_fwd_inner(
K_block_ptr = tl.advance(K_block_ptr, (0, lo))
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
# loop over k, v and update accumulator
- for start_n in tl.range(lo, hi, BLOCK_N, loop_schedule=LOOP_SCHEDULE):
+ for start_n in tl.range(lo, hi, BLOCK_N): #, loop_schedule=LOOP_SCHEDULE):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
if ENABLE_TMA:
Expand Down Expand Up @@ -259,7 +259,7 @@ def _attn_fwd_inner_ws(
K_block_ptr = tl.advance(K_block_ptr, (0, lo))
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
# loop over k, v and update accumulator
- for start_n in tl.range(lo, hi, BLOCK_N, loop_schedule=LOOP_SCHEDULE):
+ for start_n in tl.range(lo, hi, BLOCK_N): #, loop_schedule=LOOP_SCHEDULE):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
with tl.async_task([0]):
Expand Down

0 comments on commit 9ba5656

Please sign in to comment.