From 3c83e0b9be62a8983edb1e1bdd799439a5e3de2d Mon Sep 17 00:00:00 2001 From: Sam Ginzburg Date: Tue, 12 Nov 2024 08:05:35 -0800 Subject: [PATCH] AMD changed the default pipeliner behavior in OSS Triton so num_stages=0 triggers an assert. Reviewed By: bertmaher Differential Revision: D65695544 fbshipit-source-id: 95988b90debba743f90bfca3a63fa9f34df5e492 --- tritonbench/operators/bf16xint16_gemm/kernel.py | 10 +++++----- tritonbench/operators/fp8_gemm/tutorial.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tritonbench/operators/bf16xint16_gemm/kernel.py b/tritonbench/operators/bf16xint16_gemm/kernel.py index c3590ad2..78155fa6 100644 --- a/tritonbench/operators/bf16xint16_gemm/kernel.py +++ b/tritonbench/operators/bf16xint16_gemm/kernel.py @@ -194,7 +194,7 @@ def get_hip_autotune_config(): "waves_per_eu": 2, }, num_warps=4, - num_stages=0, + num_stages=2, ), triton.Config( { @@ -205,7 +205,7 @@ def get_hip_autotune_config(): "waves_per_eu": 2, }, num_warps=8, - num_stages=0, + num_stages=2, ), triton.Config( { @@ -216,7 +216,7 @@ def get_hip_autotune_config(): "waves_per_eu": 2, }, num_warps=8, - num_stages=0, + num_stages=2, ), triton.Config( { @@ -227,7 +227,7 @@ def get_hip_autotune_config(): "waves_per_eu": 3, }, num_warps=4, - num_stages=0, + num_stages=2, ), triton.Config( { @@ -238,7 +238,7 @@ def get_hip_autotune_config(): "waves_per_eu": 8, }, num_warps=4, - num_stages=0, + num_stages=2, ), ] diff --git a/tritonbench/operators/fp8_gemm/tutorial.py b/tritonbench/operators/fp8_gemm/tutorial.py index ed312deb..99ae23d0 100644 --- a/tritonbench/operators/fp8_gemm/tutorial.py +++ b/tritonbench/operators/fp8_gemm/tutorial.py @@ -341,7 +341,7 @@ def get_hip_autotune_config(): "waves_per_eu": 2, }, num_warps=4, - num_stages=0, + num_stages=2, ), triton.Config( { @@ -352,7 +352,7 @@ def get_hip_autotune_config(): "waves_per_eu": 2, }, num_warps=8, - num_stages=0, + num_stages=2, ), triton.Config( { @@ -363,7 +363,7 @@ def get_hip_autotune_config(): "waves_per_eu": 2, }, num_warps=8, - num_stages=0, + num_stages=2, ), triton.Config( { @@ -374,7 +374,7 @@ def get_hip_autotune_config(): "waves_per_eu": 3, }, num_warps=4, - num_stages=0, + num_stages=2, ), triton.Config( { @@ -385,7 +385,7 @@ def get_hip_autotune_config(): "waves_per_eu": 8, }, num_warps=4, - num_stages=0, + num_stages=2, ), ]