Fix fp8_attention
xuzhao9 committed Nov 19, 2024
1 parent 34aeb9e commit 850ec97
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tritonbench/operators/fp8_attention/operator.py
@@ -10,7 +10,7 @@

 import torch

-from tritonbench.kernels.triton_fused_attention import attention as triton_attention
+from tritonbench.kernels.triton_fused_attention import attention_opt as triton_attention
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
@@ -110,7 +110,7 @@ def triton_flash_v2(
         triton_q, triton_k, triton_v = self.triton_preprocess(q, k, v)
         # full fp8 will be enabled if type of q,k,v is fp8
         return lambda: triton_attention(
-            triton_q, triton_k, triton_v, False, self.sm_scale
+            triton_q, triton_k, triton_v, False, self.sm_scale, "base"
         )

     def get_x_val(self, _example_inputs) -> Tuple[int, int, int, int]:
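The fix switches the benchmark to the attention_opt entry point, which expects one additional trailing argument selecting the kernel variant ("base" here). Below is a minimal sketch of the updated call site, assuming attention_opt keeps the positional ordering shown in the diff (q, k, v, a boolean flag, the softmax scale, then the variant string); the shapes, dtype, and device setup are illustrative and not part of the commit.

import torch

from tritonbench.kernels.triton_fused_attention import attention_opt as triton_attention

# Illustrative problem size (assumption, not taken from the commit).
BATCH, H, N_CTX, D_HEAD = 4, 16, 1024, 64
sm_scale = D_HEAD ** -0.5

# Plain fp16 inputs; in the benchmark these first go through
# self.triton_preprocess (skipped here for brevity).
q = torch.randn(BATCH, H, N_CTX, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Same positional arguments as the diff: the boolean flag (False, presumably
# the causal switch), the softmax scale, and the new "base" variant selector.
out = triton_attention(q, k, v, False, sm_scale, "base")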
