From 9ba565626b99f74a00d625675768223061363bad Mon Sep 17 00:00:00 2001
From: Manman Ren
Date: Tue, 26 Nov 2024 12:35:17 -0800
Subject: [PATCH] skip tests

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/test_gpu/skip_tests_h100_pytorch.yaml     | 1 +
 test/test_gpu/skip_tests_h100_triton_main.yaml | 1 +
 tritonbench/kernels/triton_fused_attention.py  | 4 ++--
 3 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml
index 9f6be6d7..f484f8c0 100644
--- a/test/test_gpu/skip_tests_h100_pytorch.yaml
+++ b/test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -16,6 +16,7 @@ flash_attention:
   - triton_tutorial_flash_v2_tma
   - triton_tutorial_flash_v2_ws
   - triton_tutorial_flash_v2_tma_ws
+  - triton_tutorial_flash_v2_tma_ws_persistent
 fp8_attention:
   - colfax_fmha
   # triton_flash_v2 requires triton-main
diff --git a/test/test_gpu/skip_tests_h100_triton_main.yaml b/test/test_gpu/skip_tests_h100_triton_main.yaml
index 6843eaaa..58ac80d9 100644
--- a/test/test_gpu/skip_tests_h100_triton_main.yaml
+++ b/test/test_gpu/skip_tests_h100_triton_main.yaml
@@ -10,6 +10,7 @@ flash_attention:
   # _ws kernels require Triton with warp specialization
   - triton_tutorial_flash_v2_ws
   - triton_tutorial_flash_v2_tma_ws
+  - triton_tutorial_flash_v2_tma_ws_persistent
 fp8_attention:
   # fb-only kernel
   - colfax_fmha
diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py
index 3e20fb03..d50e547c 100644
--- a/tritonbench/kernels/triton_fused_attention.py
+++ b/tritonbench/kernels/triton_fused_attention.py
@@ -155,7 +155,7 @@ def _attn_fwd_inner(
         K_block_ptr = tl.advance(K_block_ptr, (0, lo))
         V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
     # loop over k, v and update accumulator
-    for start_n in tl.range(lo, hi, BLOCK_N, loop_schedule=LOOP_SCHEDULE):
+    for start_n in tl.range(lo, hi, BLOCK_N): #, loop_schedule=LOOP_SCHEDULE):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
         if ENABLE_TMA:
@@ -259,7 +259,7 @@ def _attn_fwd_inner_ws(
         K_block_ptr = tl.advance(K_block_ptr, (0, lo))
         V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
     # loop over k, v and update accumulator
-    for start_n in tl.range(lo, hi, BLOCK_N, loop_schedule=LOOP_SCHEDULE):
+    for start_n in tl.range(lo, hi, BLOCK_N): #, loop_schedule=LOOP_SCHEDULE):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
         with tl.async_task([0]):
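
The Triton hunks above comment out the loop_schedule keyword because only warp-specialization builds of Triton accept it on tl.range. A minimal host-side sketch of an alternative, not part of this patch, is to probe the installed Triton build for that keyword and use the result to decide whether to enable the _ws / _tma_ws_persistent variants instead of maintaining skip lists by hand; the supports_loop_schedule helper name is hypothetical.

# Hypothetical sketch (not part of this patch): detect whether this Triton
# build's tl.range accepts the warp-specialization-only `loop_schedule`
# keyword, so callers can skip or register the _ws / _tma_ws_persistent
# kernel variants accordingly.
import inspect

import triton.language as tl


def supports_loop_schedule() -> bool:
    """Return True if tl.range in this Triton build accepts `loop_schedule`."""
    try:
        return "loop_schedule" in inspect.signature(tl.range).parameters
    except (TypeError, ValueError):
        # tl.range is not introspectable in this build; assume no support.
        return False


if __name__ == "__main__":
    print(f"loop_schedule supported: {supports_loop_schedule()}")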