Skip to content

Commit

Permalink
Merge pull request opendatahub-io#17 from ROCm/triton-config-fix
Browse files Browse the repository at this point in the history
[ROCm] adding a missing triton autotune config
  • Loading branch information
hongxiayang authored May 16, 2024
2 parents 4b39609 + dfef216 commit bebcbe6
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions vllm/attention/ops/triton_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,16 @@ def _attn_fwd_inner(
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 128,
Expand Down

0 comments on commit bebcbe6

Please sign in to comment.