From 850ec97bcb8917c5f5004acc25c27b4aa8e05d36 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Mon, 18 Nov 2024 19:24:45 -0500
Subject: [PATCH] Fix fp8_attention

---
 tritonbench/operators/fp8_attention/operator.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tritonbench/operators/fp8_attention/operator.py b/tritonbench/operators/fp8_attention/operator.py
index 63d1e33b..0131ca73 100644
--- a/tritonbench/operators/fp8_attention/operator.py
+++ b/tritonbench/operators/fp8_attention/operator.py
@@ -10,7 +10,7 @@
 
 import torch
 
-from tritonbench.kernels.triton_fused_attention import attention as triton_attention
+from tritonbench.kernels.triton_fused_attention import attention_opt as triton_attention
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
@@ -110,7 +110,7 @@ def triton_flash_v2(
         triton_q, triton_k, triton_v = self.triton_preprocess(q, k, v)
         # full fp8 will be enabled if type of q,k,v is fp8
         return lambda: triton_attention(
-            triton_q, triton_k, triton_v, False, self.sm_scale
+            triton_q, triton_k, triton_v, False, self.sm_scale, "base"
         )
 
     def get_x_val(self, _example_inputs) -> Tuple[int, int, int, int]: