Use skiplist to manage CI skips instead of flagging in the code
Summary:
It turns out that using the `ci=` flag is a bad idea, because we are now dealing with three types of CI:
- OSS in triton main
- OSS in triton-pytorch
- fbcode triton

It is much better to organize the skips as YAML files so that we can enable or disable tests for each CI more easily.
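The skip files themselves are not part of this diff, but given how test/test_gpu/main.py loads them below, a file such as skip_tests_h100_pytorch.yaml presumably maps each operator either to a list of implementations to skip, or to no value at all to skip the operator entirely. A hypothetical sketch (operator and implementation names are taken from elsewhere in this commit; the real file may differ):

```yaml
# Hypothetical skip-list entries; illustrative only, not the committed file.
flash_attention:          # skip only these implementations
  - xformers
  - colfax_cutlass
gemm:
  - triton_tma_persistent_matmul
some_other_op:            # no value (null) means skip the whole operator
```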

Reviewed By: FindHao

Differential Revision: D65316542

fbshipit-source-id: 1c1cc0c00a1fe9d31f51e54e597415b822772a8f
xuzhao9 authored and facebook-github-bot committed Nov 1, 2024
1 parent 107261c commit e59b342
Showing 7 changed files with 44 additions and 15 deletions.
25 changes: 22 additions & 3 deletions test/test_gpu/main.py
@@ -4,10 +4,24 @@

from typing import List, Optional

import yaml

from tritonbench.operators import load_opbench_by_name
from tritonbench.operators_collection import list_operators_by_collection

from tritonbench.utils.parser import get_parser
from tritonbench.utils.triton_op import IS_FBCODE

if IS_FBCODE:
import importlib

fbcode_skip_file_path = "fb/skip_tests_h100_fbcode.yaml"
SKIP_FILE = importlib.resources.files(__package__).joinpath(fbcode_skip_file_path)
else:
SKIP_FILE = "skip_tests_h100_pytorch.yaml"

with open(SKIP_FILE, "r") as f:
skip_tests = yaml.safe_load(f)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -25,10 +39,9 @@ def check_ci_output(op):

output = op.output
output_impls = output.result[0][1].keys()
skiped_impls = op.tb_args.skip
ci_enabled_impls = [
x
for x in REGISTERED_BENCHMARKS[output.op_name].keys()
if REGISTERED_BENCHMARKS[output.op_name][x].ci
x for x in REGISTERED_BENCHMARKS[output.op_name].keys() if x not in skiped_impls
]
# Make sure that all the ci_enabled impls are in the output
logger.info(f"output impls: {output_impls}, ci_enabled impls: {ci_enabled_impls}")
@@ -41,7 +54,13 @@ def _run_one_operator(
tb_args: argparse.Namespace,
extra_args: Optional[List[str]] = None,
):
if tb_args.op in skip_tests:
# If the op itself is in the skip list, skip all tests
if skip_tests[tb_args.op] is None:
return
tb_args.skip = ",".join(skip_tests[tb_args.op])
Operator = load_opbench_by_name(tb_args.op)

op = Operator(tb_args=tb_args, extra_args=extra_args)
op.run()
check_ci_output(op)
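Taken together, the new logic in test/test_gpu/main.py boils down to the following standalone sketch (assuming a YAML layout like the hypothetical one above; `resolve_skips` is an illustrative helper, not a function from this commit):

```python
# Minimal sketch of the skip-list flow; not the literal test_gpu/main.py code.
import yaml


def resolve_skips(skip_file: str, op_name: str):
    """Return None to skip the operator entirely, otherwise a comma-separated
    string of implementations suitable for passing to --skip."""
    with open(skip_file, "r") as f:
        skip_tests = yaml.safe_load(f) or {}
    if op_name not in skip_tests:
        return ""                      # nothing to skip for this operator
    if skip_tests[op_name] is None:
        return None                    # whole operator is skipped in this CI
    return ",".join(skip_tests[op_name])


# Hypothetical usage:
#   skips = resolve_skips("skip_tests_h100_pytorch.yaml", "flash_attention")
#   if skips is None:  -> skip the operator
#   elif skips:        -> run with --skip <skips>
#   else:              -> run all implementations
```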
File renamed without changes.
12 changes: 5 additions & 7 deletions tritonbench/operators/flash_attention/operator.py
@@ -290,7 +290,7 @@ def xformers_preprocess(
)
return fhma_input

@register_benchmark(enabled=False, ci=False)
@register_benchmark(enabled=False)
def xformers(
self,
q: torch.Tensor,
@@ -301,7 +301,7 @@ def xformers(
xformers_cutlass_fhma = xformers.ops.fmha.cutlass.FwOp
return lambda: xformers_cutlass_fhma().apply(fhma_input, needs_gradient=False)

@register_benchmark(enabled=False, ci=False)
@register_benchmark(enabled=False)
def xformers_splitk(
self,
q: torch.Tensor,
@@ -319,7 +319,7 @@ def colfax_cutlass_preprocess(self, q, k, v):
torch.transpose(v, 1, 2),
)

@register_benchmark(enabled=False, ci=False)
@register_benchmark(enabled=False)
def colfax_cutlass(self, q, k, v):
default_scale = 1.0 / math.sqrt(float(self.D_HEAD))
colfax_q, colfax_k, colfax_v = self.colfax_cutlass_preprocess(q, k, v)
@@ -333,7 +333,7 @@ def colfax_cutlass(self, q, k, v):
default_scale,
)

@register_benchmark(enabled=False, ci=False)
@register_benchmark(enabled=False)
def tk(self, q, k, v):
o = torch.zeros_like(v)

@@ -346,9 +346,7 @@ def tk_dispatcher():

return tk_dispatcher

@register_benchmark(
enabled=False, label=f"cudnn_{torch.backends.cudnn.version()}", ci=False
)
@register_benchmark(enabled=False, label=f"cudnn_{torch.backends.cudnn.version()}")
def cudnn(self, q, k, v):
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"

6 changes: 3 additions & 3 deletions tritonbench/operators/gemm/operator.py
@@ -162,15 +162,15 @@ def triton_persistent_matmul(self, a, b, bias) -> Callable:
else:
return lambda: matmul_persistent(a, b)

@register_benchmark(enabled=not IS_FBCODE, ci=False)
@register_benchmark(enabled=not IS_FBCODE)
def triton_tma_persistent_matmul(self, a, b, bias) -> Callable:
b = b.T.contiguous()
if not bias == None:
return lambda: matmul_tma_persistent(a, b) + bias
else:
return lambda: matmul_tma_persistent(a, b)

@register_benchmark(enabled=not IS_FBCODE, ci=False)
@register_benchmark(enabled=not IS_FBCODE)
def triton_tma_persistent_cached_matmul(self, a, b, bias) -> Callable:
b = b.T.contiguous()
if not bias == None:
@@ -198,7 +198,7 @@ def hstu_triton_matmul(self, a, b, bias) -> Callable:
else:
return lambda: hstu_triton_matmul(a, b)

@register_benchmark(enabled=bool(colfax_gemm), ci=False)
@register_benchmark(enabled=bool(colfax_gemm))
def colfax_cutlass_matmul(self, a, b, bias) -> Callable:
assert colfax_gemm, f"colfax_gemm operator is not available."
if not bias == None:
2 changes: 1 addition & 1 deletion tritonbench/operators/layer_norm/operator.py
@@ -39,7 +39,7 @@ def inner(*args):

return lambda: inner(*args)

@register_benchmark(ci=HAS_LIGER_KERNEL)
@register_benchmark()
def liger_layer_norm(self, *args):
(x, w_shape, weight, bias, eps) = args
return lambda: LigerLayerNormFunction.apply(x, weight, bias, eps)
7 changes: 6 additions & 1 deletion tritonbench/utils/parser.py
@@ -104,7 +104,12 @@ def get_parser(args=None):
parser.add_argument(
"--only",
default=None,
help="Specify one or multiple operator implementations to run.",
help="Specify one or multiple kernel implementations to run.",
)
parser.add_argument(
"--skip",
default=None,
help="Specify one or multiple kernel implementations to skip.",
)
parser.add_argument(
"--baseline", type=str, default=None, help="Override default baseline."
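With the new flag in place, a CI wrapper can turn a skip-list entry into a command-line filter, e.g. something along the lines of `--op gemm --skip triton_tma_persistent_matmul,colfax_cutlass_matmul` (hypothetical invocation); `--only` remains the positive filter and `--skip` the negative one.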
7 changes: 7 additions & 0 deletions tritonbench/utils/triton_op.py
@@ -590,6 +590,7 @@ def __init__(
if self.tb_args.baseline:
BASELINE_BENCHMARKS[self.name] = self.tb_args.baseline
self._only = _split_params_by_comma(self.tb_args.only)
self._skip = _split_params_by_comma(self.tb_args.skip)
self._input_id = self.tb_args.input_id
self._num_inputs = self.tb_args.num_inputs

@@ -673,6 +674,12 @@ def run(
x_val = self.get_x_val(self.example_inputs)
if self._only:
benchmarks = self._only
elif self._skip:
benchmarks = [
bm
for bm in REGISTERED_BENCHMARKS[self.name].keys()
if bm not in self._skip
]
else:
benchmarks = (
[bm for bm in ENABLED_BENCHMARKS[self.name]]
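Note the selection order this implies: `--only` takes precedence; otherwise `--skip` filters the full set of registered benchmarks (including ones that are disabled by default); and only when neither is given does the default enabled set run.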
