From dfef216024ce7848092808de6b1c634686f6fa9c Mon Sep 17 00:00:00 2001 From: Hongxia Yang Date: Thu, 16 May 2024 16:29:05 +0000 Subject: [PATCH] [ROCm] adding a missing triton autotune config --- vllm/attention/ops/triton_flash_attention.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index b86e845020b07..77390f2d0d696 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -239,6 +239,16 @@ def _attn_fwd_inner( num_stages=1, num_warps=8, ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), triton.Config( { "BLOCK_M": 128,