From e59b3429bbf142cb414b183ce398a6c26091dd48 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Fri, 1 Nov 2024 15:13:32 -0700
Subject: [PATCH] Use skiplist to manage CI skips instead of flagging in the code

Summary:
It turns out that using the `ci=` flag is a bad idea because we are now dealing with three types of CI:

- OSS in triton main
- OSS in triton-pytorch
- fbcode triton

It is much better to organize the skips as YAML files so that we can enable/disable them per CI more easily. A sketch of the expected skip-file format follows the diff below.

Reviewed By: FindHao

Differential Revision: D65316542

fbshipit-source-id: 1c1cc0c00a1fe9d31f51e54e597415b822772a8f
---
 test/test_gpu/main.py                         | 25 ++++++++++++++++---
 ...orch.yaml => skip_tests_h100_pytorch.yaml} |  0
 .../operators/flash_attention/operator.py     | 12 ++++-----
 tritonbench/operators/gemm/operator.py        |  6 ++---
 tritonbench/operators/layer_norm/operator.py  |  2 +-
 tritonbench/utils/parser.py                   |  7 +++++-
 tritonbench/utils/triton_op.py                |  7 ++++++
 7 files changed, 44 insertions(+), 15 deletions(-)
 rename test/test_gpu/{skip_tests_pytorch.yaml => skip_tests_h100_pytorch.yaml} (100%)

diff --git a/test/test_gpu/main.py b/test/test_gpu/main.py
index e27bc942..32885988 100644
--- a/test/test_gpu/main.py
+++ b/test/test_gpu/main.py
@@ -4,10 +4,24 @@
 
 from typing import List, Optional
 
+import yaml
+
 from tritonbench.operators import load_opbench_by_name
 from tritonbench.operators_collection import list_operators_by_collection
 from tritonbench.utils.parser import get_parser
+from tritonbench.utils.triton_op import IS_FBCODE
+
+if IS_FBCODE:
+    import importlib
+
+    fbcode_skip_file_path = "fb/skip_tests_h100_fbcode.yaml"
+    SKIP_FILE = importlib.resources.files(__package__).joinpath(fbcode_skip_file_path)
+else:
+    SKIP_FILE = "skip_tests_h100_pytorch.yaml"
+
+with open(SKIP_FILE, "r") as f:
+    skip_tests = yaml.safe_load(f)
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -25,10 +39,9 @@ def check_ci_output(op):
     output = op.output
     output_impls = output.result[0][1].keys()
+    skipped_impls = op.tb_args.skip
     ci_enabled_impls = [
-        x
-        for x in REGISTERED_BENCHMARKS[output.op_name].keys()
-        if REGISTERED_BENCHMARKS[output.op_name][x].ci
+        x for x in REGISTERED_BENCHMARKS[output.op_name].keys() if x not in skipped_impls
     ]
     # Make sure that all the ci_enabled impls are in the output
     logger.info(f"output impls: {output_impls}, ci_enabled impls: {ci_enabled_impls}")
 
@@ -41,7 +54,13 @@ def _run_one_operator(
     tb_args: argparse.Namespace,
     extra_args: Optional[List[str]] = None,
 ):
+    if tb_args.op in skip_tests:
+        # If the op itself is in the skip list, skip all tests
+        if skip_tests[tb_args.op] is None:
+            return
+        tb_args.skip = ",".join(skip_tests[tb_args.op])
     Operator = load_opbench_by_name(tb_args.op)
+
     op = Operator(tb_args=tb_args, extra_args=extra_args)
     op.run()
     check_ci_output(op)
diff --git a/test/test_gpu/skip_tests_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml
similarity index 100%
rename from test/test_gpu/skip_tests_pytorch.yaml
rename to test/test_gpu/skip_tests_h100_pytorch.yaml
diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py
index b7d4a608..245c1ecf 100644
--- a/tritonbench/operators/flash_attention/operator.py
+++ b/tritonbench/operators/flash_attention/operator.py
@@ -290,7 +290,7 @@ def xformers_preprocess(
         )
         return fhma_input
 
-    @register_benchmark(enabled=False, ci=False)
+    @register_benchmark(enabled=False)
     def xformers(
         self,
         q: torch.Tensor,
@@ -301,7 +301,7 @@ def xformers(
         xformers_cutlass_fhma = xformers.ops.fmha.cutlass.FwOp
         return lambda: xformers_cutlass_fhma().apply(fhma_input, needs_gradient=False)
 
-    @register_benchmark(enabled=False, ci=False)
+    @register_benchmark(enabled=False)
     def xformers_splitk(
         self,
         q: torch.Tensor,
@@ -319,7 +319,7 @@ def colfax_cutlass_preprocess(self, q, k, v):
             torch.transpose(v, 1, 2),
         )
 
-    @register_benchmark(enabled=False, ci=False)
+    @register_benchmark(enabled=False)
    def colfax_cutlass(self, q, k, v):
         default_scale = 1.0 / math.sqrt(float(self.D_HEAD))
         colfax_q, colfax_k, colfax_v = self.colfax_cutlass_preprocess(q, k, v)
@@ -333,7 +333,7 @@ def colfax_cutlass(self, q, k, v):
             default_scale,
         )
 
-    @register_benchmark(enabled=False, ci=False)
+    @register_benchmark(enabled=False)
     def tk(self, q, k, v):
         o = torch.zeros_like(v)
 
@@ -346,9 +346,7 @@ def tk_dispatcher():
 
         return tk_dispatcher
 
-    @register_benchmark(
-        enabled=False, label=f"cudnn_{torch.backends.cudnn.version()}", ci=False
-    )
+    @register_benchmark(enabled=False, label=f"cudnn_{torch.backends.cudnn.version()}")
     def cudnn(self, q, k, v):
         os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
 
diff --git a/tritonbench/operators/gemm/operator.py b/tritonbench/operators/gemm/operator.py
index 87e8cfdb..fb2afa0b 100644
--- a/tritonbench/operators/gemm/operator.py
+++ b/tritonbench/operators/gemm/operator.py
@@ -162,7 +162,7 @@ def triton_persistent_matmul(self, a, b, bias) -> Callable:
         else:
             return lambda: matmul_persistent(a, b)
 
-    @register_benchmark(enabled=not IS_FBCODE, ci=False)
+    @register_benchmark(enabled=not IS_FBCODE)
     def triton_tma_persistent_matmul(self, a, b, bias) -> Callable:
         b = b.T.contiguous()
         if not bias == None:
@@ -170,7 +170,7 @@ def triton_tma_persistent_matmul(self, a, b, bias) -> Callable:
         else:
             return lambda: matmul_tma_persistent(a, b)
 
-    @register_benchmark(enabled=not IS_FBCODE, ci=False)
+    @register_benchmark(enabled=not IS_FBCODE)
     def triton_tma_persistent_cached_matmul(self, a, b, bias) -> Callable:
         b = b.T.contiguous()
         if not bias == None:
@@ -198,7 +198,7 @@ def hstu_triton_matmul(self, a, b, bias) -> Callable:
         else:
             return lambda: hstu_triton_matmul(a, b)
 
-    @register_benchmark(enabled=bool(colfax_gemm), ci=False)
+    @register_benchmark(enabled=bool(colfax_gemm))
     def colfax_cutlass_matmul(self, a, b, bias) -> Callable:
         assert colfax_gemm, f"colfax_gemm operator is not available."
         if not bias == None:
diff --git a/tritonbench/operators/layer_norm/operator.py b/tritonbench/operators/layer_norm/operator.py
index 17e4efda..73a0344c 100644
--- a/tritonbench/operators/layer_norm/operator.py
+++ b/tritonbench/operators/layer_norm/operator.py
@@ -39,7 +39,7 @@ def inner(*args):
 
         return lambda: inner(*args)
 
-    @register_benchmark(ci=HAS_LIGER_KERNEL)
+    @register_benchmark()
     def liger_layer_norm(self, *args):
         (x, w_shape, weight, bias, eps) = args
         return lambda: LigerLayerNormFunction.apply(x, weight, bias, eps)
diff --git a/tritonbench/utils/parser.py b/tritonbench/utils/parser.py
index 520e49be..26c7aac1 100644
--- a/tritonbench/utils/parser.py
+++ b/tritonbench/utils/parser.py
@@ -104,7 +104,12 @@ def get_parser(args=None):
     parser.add_argument(
         "--only",
         default=None,
-        help="Specify one or multiple operator implementations to run.",
+        help="Specify one or multiple kernel implementations to run.",
+    )
+    parser.add_argument(
+        "--skip",
+        default=None,
+        help="Specify one or multiple kernel implementations to skip.",
     )
     parser.add_argument(
         "--baseline", type=str, default=None, help="Override default baseline."
diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index b85be025..6a07ec01 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -590,6 +590,7 @@ def __init__( if self.tb_args.baseline: BASELINE_BENCHMARKS[self.name] = self.tb_args.baseline self._only = _split_params_by_comma(self.tb_args.only) + self._skip = _split_params_by_comma(self.tb_args.skip) self._input_id = self.tb_args.input_id self._num_inputs = self.tb_args.num_inputs @@ -673,6 +674,12 @@ def run( x_val = self.get_x_val(self.example_inputs) if self._only: benchmarks = self._only + elif self._skip: + benchmarks = [ + bm + for bm in REGISTERED_BENCHMARKS[self.name].keys() + if bm not in self._skip + ] else: benchmarks = ( [bm for bm in ENABLED_BENCHMARKS[self.name]]
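
For reference, here is a minimal sketch of how the YAML skiplists introduced in this patch are consumed, mirroring the logic added to test/test_gpu/main.py: a top-level key names an operator; a null value means the whole operator is skipped in that CI, while a list value is joined into the new `--skip` flag so only the listed impls are excluded. The operator and impl names below are illustrative placeholders, not the contents of the real skip files.

```python
# Sketch only: mirrors the skip-file handling added in test/test_gpu/main.py.
# The YAML below is a hypothetical example, not a real skip_tests_*.yaml.
from typing import Dict, List, Optional

import yaml

EXAMPLE_SKIP_FILE = """
flash_attention:   # skip only these impls, run the rest
  - xformers
  - cudnn
fp8_gemm:          # null value: skip the whole operator
"""

skip_tests: Dict[str, Optional[List[str]]] = yaml.safe_load(EXAMPLE_SKIP_FILE)


def impls_to_run(op: str, registered_impls: List[str]) -> Optional[List[str]]:
    """Return the impls CI should still run for `op`, or None to skip the op entirely."""
    if op in skip_tests:
        if skip_tests[op] is None:
            return None  # operator-level skip
        skipped = skip_tests[op]
        # Equivalent to launching the operator with --skip xformers,cudnn
        return [impl for impl in registered_impls if impl not in skipped]
    return registered_impls


print(impls_to_run("flash_attention", ["triton_tutorial_flash_v2", "xformers", "cudnn"]))
# -> ['triton_tutorial_flash_v2']
print(impls_to_run("fp8_gemm", ["torch_fp8_gemm", "triton_fp8_gemm"]))
# -> None
```

Note that in `Operator.run()` the `--only` filter takes precedence over `--skip` (the `elif self._skip` branch), so a skiplist never hides an impl that was requested explicitly. Keeping one file per CI environment (`skip_tests_h100_pytorch.yaml` in OSS, `fb/skip_tests_h100_fbcode.yaml` in fbcode) means a CI can be tuned by editing YAML instead of touching each `@register_benchmark` call.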