diff --git a/.ci/tritonbench/test-gpu.sh b/.ci/tritonbench/test-gpu.sh index 0f786130..bf762c37 100644 --- a/.ci/tritonbench/test-gpu.sh +++ b/.ci/tritonbench/test-gpu.sh @@ -8,4 +8,7 @@ fi . "${SETUP_SCRIPT}" +# install deps +pip install psutil tabulate + python -m unittest test.test_gpu.main diff --git a/.github/workflows/linter.yaml b/.github/workflows/linter.yaml new file mode 100644 index 00000000..ab69a1e5 --- /dev/null +++ b/.github/workflows/linter.yaml @@ -0,0 +1,29 @@ +name: Linter +on: + pull_request: + push: + branches: + - main + workflow_dispatch: + +jobs: + pylint: + permissions: + contents: read + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + path: tritonbench + - name: Install deps + run: | + pip install ruff-api==0.1.0 + - name: Check Formatting + uses: omnilib/ufmt@action-v1 + with: + path: tritonbench + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index e89e6fb8..22d525c9 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -5,6 +5,9 @@ on: - .ci/* - tritonbench/* - .github/workflows/pr.yaml + push: + branches: + - main jobs: h100-pytorch-test: @@ -27,6 +30,10 @@ jobs: sudo nvidia-smi -pm 1 sudo ldconfig nvidia-smi - - name: Test Tritonbench operators + - name: Test Tritonbench operators on H100 GPU run: | - bash ./.ci/tritonbench/test-operators.sh + bash ./.ci/tritonbench/test-gpu.sh + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..c8630b05 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,7 @@ +[tool.ufmt] +formatter = "ruff-api" +excludes = ["submodules/"] + +[tool.black] +line-length = 88 +target-version = ["py312"] diff --git a/requirements.txt b/requirements.txt index ae8d8208..1cc03e35 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ packaging pynvml +psutil +tabulate transformers==4.46.1 diff --git a/test/test_cpu/main.py b/test/test_cpu/main.py index a4bfccf1..05e7940e 100644 --- a/test/test_cpu/main.py +++ b/test/test_cpu/main.py @@ -7,7 +7,6 @@ class TestTritonbenchCpu(unittest.TestCase): - def _get_test_op(self): parser = get_parser(["--device", "cpu", "--op", "test_op"]) tb_args, extra_args = parser.parse_known_args( diff --git a/test/test_gpu/main.py b/test/test_gpu/main.py index d4ffdfa6..9e952b2a 100644 --- a/test/test_gpu/main.py +++ b/test/test_gpu/main.py @@ -18,7 +18,11 @@ fbcode_skip_file_path = "fb/skip_tests_h100_fbcode.yaml" SKIP_FILE = importlib.resources.files(__package__).joinpath(fbcode_skip_file_path) else: - SKIP_FILE = "skip_tests_h100_pytorch.yaml" + import os + + SKIP_FILE = os.path.abspath( + os.path.join(os.path.dirname(__file__), "skip_tests_h100_pytorch.yaml") + ) with open(SKIP_FILE, "r") as f: skip_tests = yaml.safe_load(f) @@ -55,7 +59,7 @@ def _run_one_operator( ): if tb_args.op in skip_tests: # If the op itself is in the skip list, skip all tests - if skip_tests[tb_args.op] is None: + if not skip_tests[tb_args.op]: return tb_args.skip = ",".join(skip_tests[tb_args.op]) Operator = load_opbench_by_name(tb_args.op) diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml index 6119dc9d..6acc2094 100644 --- 
a/test/test_gpu/skip_tests_h100_pytorch.yaml +++ b/test/test_gpu/skip_tests_h100_pytorch.yaml @@ -1,11 +1,37 @@ # Tests we skip in OSS CI # This file is regarding to the Triton version bundled with pytorch -# Use to skip an entire operator -# Use to skip an impl -- test_op -- bf16xint16_gemm/bf16xint16 -- fp8_attention/colfax_fmha -- fp8_fused_quant_gemm_rowwise -- fp8_gemm/triton_persistent_fp8_gemm -- fp8_gemm/triton_tma_persistent_fp8_gemm -- fp8_gemm_rowwise +# Use to skip an entire operator +# Use to skip an impl +bf16xint16_gemm: + - bf16xint16 +# TODO: we have many buggy backends for flash_attention +# Need to fix them in the CI +flash_attention: +# - triton_tutorial_flash_v2_tma +# - triton_op_flash_v2 +# - xformers_splitk +# - colfax_cutlass +# - tk +# - sdpa +# - cudnn +# - flex_attention +fp8_attention: + - colfax_fmha +fp8_fused_quant_gemm_rowwise: +fp8_gemm: + - triton_persistent_fp8_gemm + - triton_tma_persistent_fp8_gemm +fp8_gemm_rowwise: +gemm: +grouped_gemm: +int4_gemm: +jagged_layer_norm: +jagged_mean: +jagged_softmax: +jagged_sum: +layer_norm: +low_mem_dropout: +rms_norm: +rope: +template_attention: +test_op: diff --git a/tools/cuda_utils.py b/tools/cuda_utils.py index 7baf70e8..91bcbbbe 100644 --- a/tools/cuda_utils.py +++ b/tools/cuda_utils.py @@ -238,7 +238,5 @@ def check_torch_nightly_version(force_date: Optional[str] = None): if args.install_torch_nightly: install_pytorch_nightly(cuda_version=args.cudaver, env=os.environ) if args.check_torch_nightly_version: - assert ( - not args.install_torch_nightly - ), "Error: Can't run install torch nightly and check version in the same command." + assert not args.install_torch_nightly, "Error: Can't run install torch nightly and check version in the same command." check_torch_nightly_version(args.force_date) diff --git a/tritonbench/components/workers/subprocess_rpc.py b/tritonbench/components/workers/subprocess_rpc.py index aa9e1e3d..2fd15aca 100644 --- a/tritonbench/components/workers/subprocess_rpc.py +++ b/tritonbench/components/workers/subprocess_rpc.py @@ -274,8 +274,10 @@ def write(self, msg: bytes) -> None: def get_writer_pid(self) -> int: assert ( self._writer_pid is not None - ), "Writer pid is not specified. Maybe calling from child process or input pipe.\ + ), ( + "Writer pid is not specified. Maybe calling from child process or input pipe.\ Please report a bug." 
+ ) return self._writer_pid def set_writer_pid(self, writer_pid: int) -> None: diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py index e76342ca..3c0900e8 100644 --- a/tritonbench/kernels/triton_fused_attention.py +++ b/tritonbench/kernels/triton_fused_attention.py @@ -37,7 +37,6 @@ class TmaAutoTuneHelper: - # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498 class KernelParamWrapper: def __init__(self, desc): @@ -734,7 +733,6 @@ def _attn_fwd_tma( # Q, V, desc_k, desc_v, sm_scale, M, Out, # HEAD_DIM: tl.constexpr, # STAGE: tl.constexpr, # ): - tl.static_assert(BLOCK_N <= HEAD_DIM) start_m = tl.program_id(0) off_hz = tl.program_id(1) @@ -848,7 +846,14 @@ def _attn_fwd_tma( # Q, V, desc_k, desc_v, sm_scale, M, Out, # @triton.jit def _attn_bwd_preprocess( - O, DO, Delta, Z, H, N_CTX, BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # # # # + O, + DO, + Delta, + Z, + H, + N_CTX, + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, # # # # ): off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) off_hz = tl.program_id(1) @@ -1179,7 +1184,6 @@ def _attn_bwd( class _attention_ws(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal, sm_scale): # shape constraints @@ -1232,7 +1236,7 @@ def forward(ctx, q, k, v, causal, sm_scale): N_CTX=q.shape[2], # HEAD_DIM=HEAD_DIM_K, # STAGE=stage, # - **extra_kern_args + **extra_kern_args, ) ctx.save_for_backward(q, k, v, o, M) @@ -1304,7 +1308,6 @@ def backward(ctx, do): class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal, sm_scale): # shape constraints @@ -1355,7 +1358,7 @@ def forward(ctx, q, k, v, causal, sm_scale): N_CTX=q.shape[2], # HEAD_DIM=HEAD_DIM_K, # STAGE=stage, # - **extra_kern_args + **extra_kern_args, ) ctx.save_for_backward(q, k, v, o, M) @@ -1427,7 +1430,6 @@ def backward(ctx, do): class _attention_tma(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal, sm_scale): # shape constraints @@ -1587,7 +1589,7 @@ def grid_tma(META): N_CTX=q.shape[2], # HEAD_DIM=HEAD_DIM_K, # STAGE=stage, # - **extra_kern_args + **extra_kern_args, ) ctx.save_for_backward(q, k, v, o, M) diff --git a/tritonbench/operators/gather_gemv/operator.py b/tritonbench/operators/gather_gemv/operator.py index 1ea4ab17..2702ad25 100644 --- a/tritonbench/operators/gather_gemv/operator.py +++ b/tritonbench/operators/gather_gemv/operator.py @@ -26,7 +26,6 @@ class Operator(BenchmarkOperator): - @register_metric() def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics): arg0_1, arg1_1, arg2_1 = example_inputs diff --git a/tritonbench/operators/gather_gemv/triton_gather_gemv.py b/tritonbench/operators/gather_gemv/triton_gather_gemv.py index 94f9fbcd..550622c8 100644 --- a/tritonbench/operators/gather_gemv/triton_gather_gemv.py +++ b/tritonbench/operators/gather_gemv/triton_gather_gemv.py @@ -82,7 +82,7 @@ def triton_red_fused_mv_0( rbase = tl.arange(0, RBLOCK)[None, :].to(tl.int64) x0 = xindex # x0 // rnumel should have the same value of either 0 or 1 - tmp0 = tl.load(in_ptr0 + ((x0 // rnumel)), None, eviction_policy="evict_last") + tmp0 = tl.load(in_ptr0 + (x0 // rnumel), None, eviction_policy="evict_last") _tmp11 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) for roffset in range(0, rnumel, RBLOCK): rindex = roffset + rbase diff --git a/tritonbench/operators/jagged_layer_norm/operator.py b/tritonbench/operators/jagged_layer_norm/operator.py index c1742c75..63dc3a6e 100644 --- 
a/tritonbench/operators/jagged_layer_norm/operator.py +++ b/tritonbench/operators/jagged_layer_norm/operator.py @@ -36,7 +36,6 @@ def parse_op_args(args: List[str]): class Operator(BenchmarkOperator): - DEFAULT_METRICS = ["latency", "accuracy"] DEFAULT_PRECISION = "fp32" @@ -48,8 +47,8 @@ def __init__( self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None ): super().__init__(tb_args, extra_args) - self.sizes = list(range(2, 12, 4)) + list( - range(12, 23, 3) + self.sizes = ( + list(range(2, 12, 4)) + list(range(12, 23, 3)) ) # bias towards larger sizes, which are more representative of real-world shapes args = parse_op_args(self.extra_args) @@ -105,8 +104,8 @@ def _inner(): ) # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm padded_normalized = ( - padded_values - mean - ) * padded_mask_values # mask elements outside of the ragged dimension size for correct variance calculation + (padded_values - mean) * padded_mask_values + ) # mask elements outside of the ragged dimension size for correct variance calculation variance = ( torch.sum( diff --git a/tritonbench/operators/jagged_mean/kernels.py b/tritonbench/operators/jagged_mean/kernels.py index 65af3712..66f9136a 100644 --- a/tritonbench/operators/jagged_mean/kernels.py +++ b/tritonbench/operators/jagged_mean/kernels.py @@ -53,8 +53,9 @@ def triton_jagged_mean_kernel_simple_fused_sum_then_buffer( offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M) mask_m = offsets_m < M - ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load( - input_ptr_offsets + (pid_b + 1) + ragged_start, ragged_end = ( + tl.load(input_ptr_offsets + pid_b), + tl.load(input_ptr_offsets + (pid_b + 1)), ) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1] ragged_len = ragged_end - ragged_start @@ -133,8 +134,9 @@ def triton_jagged_mean_kernel_simple_fused_buffer_then_sum( offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M) mask_m = offsets_m < M - ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load( - input_ptr_offsets + (pid_b + 1) + ragged_start, ragged_end = ( + tl.load(input_ptr_offsets + pid_b), + tl.load(input_ptr_offsets + (pid_b + 1)), ) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1] ragged_len = ragged_end - ragged_start @@ -212,8 +214,9 @@ def triton_jagged_mean_kernel_variable_length_loop_sum_then_buffer( offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M) mask_m = offsets_m < M - ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load( - input_ptr_offsets + (pid_b + 1) + ragged_start, ragged_end = ( + tl.load(input_ptr_offsets + pid_b), + tl.load(input_ptr_offsets + (pid_b + 1)), ) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1] ragged_len = ragged_end - ragged_start @@ -288,8 +291,9 @@ def triton_jagged_mean_kernel_variable_length_loop_buffer_then_sum( offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M) mask_m = offsets_m < M - ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load( - input_ptr_offsets + (pid_ragged + 1) + ragged_start, ragged_end = ( + tl.load(input_ptr_offsets + pid_ragged), + tl.load(input_ptr_offsets + (pid_ragged + 1)), ) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1] ragged_len = ragged_end - ragged_start diff --git a/tritonbench/operators/jagged_mean/operator.py b/tritonbench/operators/jagged_mean/operator.py index 
487c1f9f..d5a82269 100644 --- a/tritonbench/operators/jagged_mean/operator.py +++ b/tritonbench/operators/jagged_mean/operator.py @@ -92,7 +92,6 @@ def execute_kernel_variable_length_loop(x, sum_then_buffer): class Operator(BenchmarkOperator): - DEFAULT_METRICS = ["latency", "accuracy"] DEFAULT_PRECISION = "fp32" @@ -104,8 +103,8 @@ def __init__( self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None ): super().__init__(tb_args, extra_args) - self.sizes = list(range(2, 12, 4)) + list( - range(12, 23, 3) + self.sizes = ( + list(range(2, 12, 4)) + list(range(12, 23, 3)) ) # bias towards larger sizes, which are more representative of real-world shapes args = parse_op_args(self.extra_args) @@ -130,28 +129,37 @@ def torch_jagged_mean_unbind_torch_mean( def torch_jagged_mean_torch_nanmean( self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float ): - return lambda: torch.nanmean( - torch.ops.aten._jagged_to_padded_dense_forward( - x.values(), - [x.offsets()], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`. - max_lengths=[seqlen], # max length of ragged dimension - padding_value=float("nan"), - ), - dim=1, + return ( + lambda: torch.nanmean( + torch.ops.aten._jagged_to_padded_dense_forward( + x.values(), + [ + x.offsets() + ], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`. + max_lengths=[seqlen], # max length of ragged dimension + padding_value=float("nan"), + ), + dim=1, + ) ) @register_benchmark() def torch_jagged_mean_torch_sum( self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float ): - return lambda: torch.sum( - torch.ops.aten._jagged_to_padded_dense_forward( - x.values(), - [x.offsets()], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`. - max_lengths=[seqlen], # max length of ragged dimension - ), - dim=1, - ) / x.offsets().diff().unsqueeze(1) + return ( + lambda: torch.sum( + torch.ops.aten._jagged_to_padded_dense_forward( + x.values(), + [ + x.offsets() + ], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`. + max_lengths=[seqlen], # max length of ragged dimension + ), + dim=1, + ) + / x.offsets().diff().unsqueeze(1) + ) @register_benchmark() def triton_jagged_mean_simple_fused( @@ -176,7 +184,9 @@ def torch_compile_nested_tensor_integration( self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float ): def _inner(x: torch.Tensor): # mean along ragged dimension (dim == 1) - return torch.mean(x, dim=x._ragged_idx, keepdim=True) # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `_ragged_idx`. + return torch.mean( + x, dim=x._ragged_idx, keepdim=True + ) # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `_ragged_idx`. 
torch_compile_func = torch.compile(_inner) return lambda: torch_compile_func(x) diff --git a/tritonbench/operators/jagged_softmax/kernels.py b/tritonbench/operators/jagged_softmax/kernels.py index a48a4d32..0a4cb512 100644 --- a/tritonbench/operators/jagged_softmax/kernels.py +++ b/tritonbench/operators/jagged_softmax/kernels.py @@ -54,8 +54,9 @@ def triton_jagged_softmax_kernel_simple_fused_buffer_then_sum( offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M) mask_m = offsets_m < M - ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load( - input_ptr_offsets + (pid_b + 1) + ragged_start, ragged_end = ( + tl.load(input_ptr_offsets + pid_b), + tl.load(input_ptr_offsets + (pid_b + 1)), ) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1] buffer_max_all = tl.full( @@ -163,8 +164,9 @@ def triton_jagged_softmax_kernel_variable_length_loop_buffer_then_sum( offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M) mask_m = offsets_m < M - ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load( - input_ptr_offsets + (pid_b + 1) + ragged_start, ragged_end = ( + tl.load(input_ptr_offsets + pid_b), + tl.load(input_ptr_offsets + (pid_b + 1)), ) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1] buffer_max_all = tl.full( diff --git a/tritonbench/operators/jagged_softmax/operator.py b/tritonbench/operators/jagged_softmax/operator.py index 657207f5..aad5a41c 100644 --- a/tritonbench/operators/jagged_softmax/operator.py +++ b/tritonbench/operators/jagged_softmax/operator.py @@ -71,7 +71,6 @@ def parse_op_args(args: List[str]): class Operator(BenchmarkOperator): - DEFAULT_METRICS = ["latency", "accuracy", "best_config"] DEFAULT_PRECISION = "fp32" @@ -83,8 +82,8 @@ def __init__( self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None ): super().__init__(tb_args, extra_args) - self.sizes = list(range(2, 12, 4)) + list( - range(12, 23, 3) + self.sizes = ( + list(range(2, 12, 4)) + list(range(12, 23, 3)) ) # bias towards larger sizes, which are more representative of real-world shapes args = parse_op_args(self.extra_args) @@ -114,7 +113,9 @@ def torch_jagged_softmax_torch_sum( def _inner(): padded = torch.ops.aten._jagged_to_padded_dense_forward( x.values(), - [x.offsets()], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`. + [ + x.offsets() + ], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`. max_lengths=[seqlen], # max length of ragged dimension padding_value=float("-inf"), # e^-inf = 0 ) @@ -153,7 +154,9 @@ def torch_compile_nested_tensor_integration( self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float ): def _inner(x: torch.Tensor): # softmax along ragged dimension - return torch.softmax(x, dim=x._ragged_idx) # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `_ragged_idx`. + return torch.softmax( + x, dim=x._ragged_idx + ) # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `_ragged_idx`. 
torch_compile_func = torch.compile(_inner) return lambda: torch_compile_func( diff --git a/tritonbench/operators/jagged_sum/kernels.py b/tritonbench/operators/jagged_sum/kernels.py index 296f6a2f..cd43ac1d 100644 --- a/tritonbench/operators/jagged_sum/kernels.py +++ b/tritonbench/operators/jagged_sum/kernels.py @@ -52,8 +52,9 @@ def triton_jagged_sum_kernel_simple_fused_sum_then_buffer( offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M) mask_m = offsets_m < M - ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load( - input_ptr_offsets + (pid_ragged + 1) + ragged_start, ragged_end = ( + tl.load(input_ptr_offsets + pid_ragged), + tl.load(input_ptr_offsets + (pid_ragged + 1)), ) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1] for block_pos in range( @@ -127,8 +128,9 @@ def triton_jagged_sum_kernel_simple_fused_buffer_then_sum( offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M) mask_m = offsets_m < M - ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load( - input_ptr_offsets + (pid_ragged + 1) + ragged_start, ragged_end = ( + tl.load(input_ptr_offsets + pid_ragged), + tl.load(input_ptr_offsets + (pid_ragged + 1)), ) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1] for block_pos in range( @@ -201,8 +203,9 @@ def triton_jagged_sum_kernel_variable_length_loop_sum_then_buffer( offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M) mask_m = offsets_m < M - ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load( - input_ptr_offsets + (pid_b + 1) + ragged_start, ragged_end = ( + tl.load(input_ptr_offsets + pid_b), + tl.load(input_ptr_offsets + (pid_b + 1)), ) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1] for block_start_ragged in range( @@ -272,8 +275,9 @@ def triton_jagged_sum_kernel_variable_length_loop_buffer_then_sum( offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M) mask_m = offsets_m < M - ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load( - input_ptr_offsets + (pid_ragged + 1) + ragged_start, ragged_end = ( + tl.load(input_ptr_offsets + pid_ragged), + tl.load(input_ptr_offsets + (pid_ragged + 1)), ) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1] for block_start_ragged in range( diff --git a/tritonbench/operators/jagged_sum/operator.py b/tritonbench/operators/jagged_sum/operator.py index 6762d5d6..c531186c 100644 --- a/tritonbench/operators/jagged_sum/operator.py +++ b/tritonbench/operators/jagged_sum/operator.py @@ -92,7 +92,6 @@ def execute_kernel_variable_length_loop(x, sum_then_buffer): class Operator(BenchmarkOperator): - DEFAULT_METRICS = ["latency", "accuracy", "best_config"] DEFAULT_PRECISION = "fp32" use_cuda_graphs = ( @@ -103,8 +102,8 @@ def __init__( self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None ): super().__init__(tb_args, extra_args) - self.sizes = list(range(2, 12, 4)) + list( - range(12, 23, 3) + self.sizes = ( + list(range(2, 12, 4)) + list(range(12, 23, 3)) ) # bias towards larger sizes, which are more representative of real-world shapes args = parse_op_args(self.extra_args) @@ -133,13 +132,17 @@ def torch_jagged_sum_no_pad( def torch_jagged_sum_pad( self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float ): - return lambda: torch.sum( - torch.ops.aten._jagged_to_padded_dense_forward( - x.values(), - [x.offsets()], # pyre-ignore: Undefined attribute [16]: 
`torch._tensor.Tensor` has no attribute `offsets`. - max_lengths=[seqlen], # max length of ragged dimension - ), - dim=1, + return ( + lambda: torch.sum( + torch.ops.aten._jagged_to_padded_dense_forward( + x.values(), + [ + x.offsets() + ], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`. + max_lengths=[seqlen], # max length of ragged dimension + ), + dim=1, + ) ) # sum along ragged dimension (dim == 1) @register_benchmark() @@ -165,7 +168,9 @@ def torch_compile_nested_tensor_integration( self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float ): def _inner(x: torch.Tensor): # sum along ragged dimension (dim == 1) - return torch.sum(x, dim=x._ragged_idx) # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `_ragged_idx`. + return torch.sum( + x, dim=x._ragged_idx + ) # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `_ragged_idx`. torch_compile_func = torch.compile(_inner) return lambda: torch_compile_func(x) diff --git a/tritonbench/operators/layer_norm/tutorial.py b/tritonbench/operators/layer_norm/tutorial.py index 2e49fd86..7d403416 100644 --- a/tritonbench/operators/layer_norm/tutorial.py +++ b/tritonbench/operators/layer_norm/tutorial.py @@ -233,7 +233,6 @@ def _layer_norm_bwd_dwdb( class LayerNorm(torch.autograd.Function): - @staticmethod def forward(ctx, x, normalized_shape, weight, bias, eps): # allocate output diff --git a/tritonbench/operators/op_task.py b/tritonbench/operators/op_task.py index 60317d4c..9c4d47b8 100644 --- a/tritonbench/operators/op_task.py +++ b/tritonbench/operators/op_task.py @@ -57,7 +57,6 @@ class OpDetails: class OpTask(base_task.TaskBase): - # The worker may (and often does) consume significant system resources. # In order to ensure that runs do not interfere with each other, we only # allow a single OpTask to exist at a time. 
diff --git a/tritonbench/operators/ragged_attention/operator.py b/tritonbench/operators/ragged_attention/operator.py index 770dc59d..72838636 100644 --- a/tritonbench/operators/ragged_attention/operator.py +++ b/tritonbench/operators/ragged_attention/operator.py @@ -29,7 +29,7 @@ def __init__( self.max_seq_len = 2**args.max_seq_len_log2 self.num_buckets = args.num_buckets # set a default number of inputs - self._num_inputs = 10 + self._num_inputs = 10 if self._num_inputs is None else self._num_inputs @register_benchmark() def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps): diff --git a/tritonbench/operators/rope/operator.py b/tritonbench/operators/rope/operator.py index e9398d9f..ab2b8476 100644 --- a/tritonbench/operators/rope/operator.py +++ b/tritonbench/operators/rope/operator.py @@ -58,9 +58,10 @@ def prepare_input(self, hidden_size, seq_length): requires_grad=True, dtype=self.dtype, ).transpose(1, 2) - dq, dk = torch.randn_like( - q, device=self.device, dtype=self.dtype - ), torch.randn_like(k, device=self.device) + dq, dk = ( + torch.randn_like(q, device=self.device, dtype=self.dtype), + torch.randn_like(k, device=self.device), + ) pos_ids = torch.arange( seq_length, device=self.device, dtype=torch.long ).unsqueeze(0) diff --git a/tritonbench/operators/softmax/operator.py b/tritonbench/operators/softmax/operator.py index 2644af48..af1e9dc8 100644 --- a/tritonbench/operators/softmax/operator.py +++ b/tritonbench/operators/softmax/operator.py @@ -13,7 +13,6 @@ class Operator(BenchmarkOperator): - @register_benchmark() def triton_softmax(self, x): n_rows, n_cols = x.shape diff --git a/tritonbench/operators/sum/kernels.py b/tritonbench/operators/sum/kernels.py index 5e1c8c63..68ef3bd3 100644 --- a/tritonbench/operators/sum/kernels.py +++ b/tritonbench/operators/sum/kernels.py @@ -16,8 +16,8 @@ def triton_sum_kernel_scalar_result( block_start = pid * BLOCK_SIZE_M # offsets have shape equal to input shape - offsets = block_start + tl.arange( - 0, BLOCK_SIZE_M + offsets = ( + block_start + tl.arange(0, BLOCK_SIZE_M) ) # create 1D vector (input shape) ranging from beginning to end of this program's block # mask has shape equal to input shape @@ -133,7 +133,8 @@ def triton_sum_kernel_1D_result_sum_then_buffer( num_warps=w, ) for b, w in itertools.product( - [2, 4, 8, 16], [2, 4, 8] # block sizes # number of warps + [2, 4, 8, 16], + [2, 4, 8], # block sizes # number of warps ) ], key=["M", "N"], @@ -206,7 +207,8 @@ def triton_sum_kernel_1D_result_buffer_then_sum( num_warps=w, ) for b, w in itertools.product( - [2, 4, 16, 32, 128, 256], [2, 4, 8] # block sizes, number of warps + [2, 4, 16, 32, 128, 256], + [2, 4, 8], # block sizes, number of warps ) ], key=["N"], diff --git a/tritonbench/operators/sum/operator.py b/tritonbench/operators/sum/operator.py index a8034916..ffdc790d 100644 --- a/tritonbench/operators/sum/operator.py +++ b/tritonbench/operators/sum/operator.py @@ -146,7 +146,6 @@ def execute_kernel_2D_result(x): class Operator(BenchmarkOperator): - DEFAULT_METRICS = ["latency", "accuracy", "best_config"] def __init__( diff --git a/tritonbench/operators/test_op/operator.py b/tritonbench/operators/test_op/operator.py index 0cd0bee9..341fe3cf 100644 --- a/tritonbench/operators/test_op/operator.py +++ b/tritonbench/operators/test_op/operator.py @@ -12,7 +12,6 @@ class Operator(BenchmarkOperator): - DEFAULT_METRICS = ["test_metric"] def __init__( diff --git a/tritonbench/operators/vector_add/operator.py b/tritonbench/operators/vector_add/operator.py index 
efb681e6..f1a63b82 100644 --- a/tritonbench/operators/vector_add/operator.py +++ b/tritonbench/operators/vector_add/operator.py @@ -15,7 +15,6 @@ class Operator(BenchmarkOperator): - @register_metric() def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics): return ( @@ -28,7 +27,6 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics): @register_benchmark() def triton_add(self, x: torch.Tensor, y: torch.Tensor): - # We need to preallocate the output. output = torch.empty_like(x) n_elements = output.numel() @@ -78,7 +76,6 @@ def plot(self): ) ) def _plot(size, provider): - gbps, max_gbps, min_gbps = self.output.get_y_vals(size, provider, "gbps") return gbps, max_gbps, min_gbps
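
Editor's note on the skip-list rework above: the patch changes `test/test_gpu/skip_tests_h100_pytorch.yaml` from a flat list of `op` / `op/impl` strings to a mapping, and `test/test_gpu/main.py` now treats a falsy value (`null` in YAML) as "skip the whole operator" (`if not skip_tests[tb_args.op]: return`) and a list as "skip only these impls". The sketch below is illustrative only and is not part of the patch; `EXAMPLE_SKIP_YAML` and `resolve_skips` are hypothetical names used to show the intended semantics under that assumption.

```python
# Minimal sketch of the skip-file semantics after this patch (assumes PyYAML,
# which the test harness already imports). Not the repository's code.
import yaml

EXAMPLE_SKIP_YAML = """
fp8_attention:
  - colfax_fmha
fp8_fused_quant_gemm_rowwise:
test_op:
"""

def resolve_skips(op: str, skip_tests: dict):
    """Return None to skip the operator entirely, otherwise a comma-separated
    impl skip string (empty if nothing is skipped)."""
    if op in skip_tests:
        # A null/empty value in the YAML means the whole operator is skipped,
        # mirroring `if not skip_tests[tb_args.op]: return` in the patch.
        if not skip_tests[op]:
            return None
        return ",".join(skip_tests[op])
    return ""

skips = yaml.safe_load(EXAMPLE_SKIP_YAML)
assert resolve_skips("test_op", skips) is None            # whole operator skipped
assert resolve_skips("fp8_attention", skips) == "colfax_fmha"  # one impl skipped
assert resolve_skips("gemm", skips) == ""                 # not listed: run everything
```

This mirrors why the patch replaces the old `is None` check: with the mapping format, both a missing list and an explicitly empty one should disable the operator's tests, and `not skip_tests[op]` covers both cases.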