From df1027f87c108eb5b4599fd27a77d45ae43c8f59 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Wed, 11 Dec 2024 07:41:34 -0800 Subject: [PATCH] Support flops metric in proton profiling (#111) Summary: Support tflops+proton metric on more operators. Now we require operator to use `def flops(): ` and the framework will add `tflops` metric automatically. The same interface will be used to support `flops` in Proton. Differential Revision: D67062546 --- tritonbench/operators/addmm/operator.py | 4 ++-- .../operators/bf16xint16_gemm/bf16xint16_gemm.py | 4 ++-- tritonbench/operators/flash_attention/operator.py | 7 ------- tritonbench/operators/fp8_attention/operator.py | 5 ++--- .../fp8_fused_quant_gemm_rowwise/operator.py | 4 ++-- tritonbench/operators/fp8_gemm/fp8_gemm.py | 4 ++-- .../operators/fp8_gemm_blockwise/operator.py | 4 ++-- tritonbench/operators/fp8_gemm_rowwise/operator.py | 4 ++-- tritonbench/operators/int4_gemm/int4_gemm.py | 4 ++-- tritonbench/operators/low_mem_dropout/operator.py | 4 ++-- tritonbench/operators/ragged_attention/operator.py | 12 ++++++------ tritonbench/utils/triton_op.py | 14 ++++++-------- 12 files changed, 30 insertions(+), 40 deletions(-) diff --git a/tritonbench/operators/addmm/operator.py b/tritonbench/operators/addmm/operator.py index 5abc1293..1f5fd43d 100644 --- a/tritonbench/operators/addmm/operator.py +++ b/tritonbench/operators/addmm/operator.py @@ -126,14 +126,14 @@ def gbps( return numel / metrics.latency * 1e3 @register_metric() - def tflops( + def flops( self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics ) -> float: _, mat1, mat2 = example_inputs m, k = mat1.size() k, n = mat2.size() flops = m * k * 2 * n - return flops / metrics.latency / 1e12 * 1e3 + return flops @register_x_val(label="(M, N, K)") def get_x_val(self, example_inputs) -> Tuple[int, int, int]: diff --git a/tritonbench/operators/bf16xint16_gemm/bf16xint16_gemm.py b/tritonbench/operators/bf16xint16_gemm/bf16xint16_gemm.py index c2c0935c..72908ac1 100644 --- a/tritonbench/operators/bf16xint16_gemm/bf16xint16_gemm.py +++ b/tritonbench/operators/bf16xint16_gemm/bf16xint16_gemm.py @@ -142,14 +142,14 @@ def nbytes(t): return gb / metrics.latency * 1e3 @register_metric() - def tflops( + def flops( self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics ) -> float: a, b = example_inputs m, k = a.size() _, n = b.size() flops = 2 * m * n * k - return flops / metrics.latency / 1e12 * 1e3 + return flops def plot(self): @triton.testing.perf_report( diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py index 51033b76..d400015a 100644 --- a/tritonbench/operators/flash_attention/operator.py +++ b/tritonbench/operators/flash_attention/operator.py @@ -478,13 +478,6 @@ def causal_mask(b, h, q_idx, kv_idx): return lambda: flex_attention(q, k, v, block_mask=block_mask) - @register_metric() - def tflops( - self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics - ) -> float: - analytic_flops = self.flops(fn_name, example_inputs, metrics) - return analytic_flops / metrics.latency * 1e-9 - @register_metric(x_only=True) def flops( self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics diff --git a/tritonbench/operators/fp8_attention/operator.py b/tritonbench/operators/fp8_attention/operator.py index 0131ca73..14264650 100644 --- a/tritonbench/operators/fp8_attention/operator.py +++ b/tritonbench/operators/fp8_attention/operator.py @@ -150,10 +150,9 @@ def get_input_iter(self) -> Generator: yield (q, k, v) @register_metric() - def tflops( + def flops( self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics ) -> float: H = self.embedding_dim // self.D_HEAD flops_per_matmul = 2.0 * self.BATCH * H * self.N_CTX * self.N_CTX * self.D_HEAD - tflops = 2 * flops_per_matmul - return tflops / metrics.latency * 1e-9 + return 2 * flops_per_matmul diff --git a/tritonbench/operators/fp8_fused_quant_gemm_rowwise/operator.py b/tritonbench/operators/fp8_fused_quant_gemm_rowwise/operator.py index 0f87a011..e55360de 100644 --- a/tritonbench/operators/fp8_fused_quant_gemm_rowwise/operator.py +++ b/tritonbench/operators/fp8_fused_quant_gemm_rowwise/operator.py @@ -132,14 +132,14 @@ def _impl(x1, x2, wq, w_scale, wd): return lambda: _impl(x1, x2, wq, w_scale, wd) @register_metric() - def tflops( + def flops( self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics ) -> List[float]: x1, _, wq, _, _ = example_inputs m, k = x1.size() n, k = wq.size() flops = m * k * 2 * n - return flops / metrics.latency / 1e12 * 1e3 + return flops @register_x_val(label="(M, N, K)") def get_x_val(self, example_inputs) -> Tuple[int, int, int]: diff --git a/tritonbench/operators/fp8_gemm/fp8_gemm.py b/tritonbench/operators/fp8_gemm/fp8_gemm.py index 61b4e29f..960a789a 100644 --- a/tritonbench/operators/fp8_gemm/fp8_gemm.py +++ b/tritonbench/operators/fp8_gemm/fp8_gemm.py @@ -120,14 +120,14 @@ def nbytes(t): return gb / metrics.latency * 1e3 @register_metric() - def tflops( + def flops( self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics ) -> float: a, b = example_inputs m, k = a.size() _, n = b.size() flops = 2 * m * n * k - return flops / metrics.latency / 1e12 * 1e3 + return flops def plot(self): @triton.testing.perf_report( diff --git a/tritonbench/operators/fp8_gemm_blockwise/operator.py b/tritonbench/operators/fp8_gemm_blockwise/operator.py index d1f9bc73..b1a8c3aa 100644 --- a/tritonbench/operators/fp8_gemm_blockwise/operator.py +++ b/tritonbench/operators/fp8_gemm_blockwise/operator.py @@ -138,14 +138,14 @@ def _cutlass(self, xq, wq, x_scale, w_scale) -> Callable: return lambda: cutlass_fp8_block(xq, wq, x_scale, w_scale) @register_metric() - def tflops( + def flops( self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics ) -> List[float]: xq, wq, _, _ = example_inputs m, k = xq.size() n, k = wq.size() flops = m * k * 2 * n - return flops / metrics.latency / 1e12 * 1e3 + return flops @register_x_val(label="(M, N, K)") def get_x_val(self, example_inputs) -> Tuple[int, int, int]: diff --git a/tritonbench/operators/fp8_gemm_rowwise/operator.py b/tritonbench/operators/fp8_gemm_rowwise/operator.py index 53bf30b0..499a82ae 100644 --- a/tritonbench/operators/fp8_gemm_rowwise/operator.py +++ b/tritonbench/operators/fp8_gemm_rowwise/operator.py @@ -168,14 +168,14 @@ def _cublas(self, xq, wq, x_scale, w_scale) -> Callable: # return lambda: _cublass(xq, wq, x_scale, w_scale) @register_metric() - def tflops( + def flops( self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics ) -> List[float]: xq, wq, _, _ = example_inputs m, k = xq.size() n, k = wq.size() flops = m * k * 2 * n - return flops / metrics.latency / 1e12 * 1e3 + return flops @register_x_val(label="(M, N, K)") def get_x_val(self, example_inputs) -> Tuple[int, int, int]: diff --git a/tritonbench/operators/int4_gemm/int4_gemm.py b/tritonbench/operators/int4_gemm/int4_gemm.py index 8b611451..7664fc49 100644 --- a/tritonbench/operators/int4_gemm/int4_gemm.py +++ b/tritonbench/operators/int4_gemm/int4_gemm.py @@ -94,7 +94,7 @@ def nbytes(t): return gb / metrics.latency * 1e3 @register_metric() - def tflops( + def flops( self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics ) -> float: a, b = example_inputs @@ -102,7 +102,7 @@ def tflops( m = B * m _, n = b.size() flops = 2 * m * n * k - return flops / metrics.latency / 1e12 * 1e3 + return flops def plot(self): @triton.testing.perf_report( diff --git a/tritonbench/operators/low_mem_dropout/operator.py b/tritonbench/operators/low_mem_dropout/operator.py index 543ad41f..3abbf5cd 100644 --- a/tritonbench/operators/low_mem_dropout/operator.py +++ b/tritonbench/operators/low_mem_dropout/operator.py @@ -26,10 +26,10 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics): ) @register_metric() - def tflops(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics): + def flops(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics): p, a = example_inputs flops = 2 * len(a) - return flops / metrics.latency + return flops @register_benchmark() def triton_dropout(self, p, x): diff --git a/tritonbench/operators/ragged_attention/operator.py b/tritonbench/operators/ragged_attention/operator.py index 2ab21e09..48bb51bb 100644 --- a/tritonbench/operators/ragged_attention/operator.py +++ b/tritonbench/operators/ragged_attention/operator.py @@ -129,14 +129,14 @@ def get_bwd_fn(self, fwd_fn: Callable[..., Any]) -> Callable[..., Any]: return fn @register_metric() - def tflops( + def flops( self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics ) -> float: ratio = 2.0 # triangular masking f1 = 0.0 f2 = 0.0 jagged = True - q, k, v, seq_offsets, timestamps, num_targets = example_inputs + q, k, v, seq_offsets, timestamps, num_targets, seq_len = example_inputs _, nheads, attn_dim = q.shape _, _, hidden_dim = v.shape max_seqlen = timestamps.size(1) - 1 @@ -152,9 +152,9 @@ def tflops( # (QK^T)V, d(QK^T) = dOV^T, dV = (QK^T)^TdO, f2 += 2 * self.num_heads * hidden_dim * seq_len**2 // ratio if self.mode == Mode.FWD: - tflops = f1 + f2 # computes (QK^T) and (QK^T)V + flops = f1 + f2 # computes (QK^T) and (QK^T)V elif self.mode == Mode.BWD: - tflops = 3 * f1 + 2 * f2 # computes (QK^T), dQ, dK, dV, d(QK^T) + flops = 3 * f1 + 2 * f2 # computes (QK^T), dQ, dK, dV, d(QK^T) elif self.mode == Mode.FWD_BWD: - tflops = 4 * f1 + 3 * f2 - return tflops / metrics.latency * 1e-9 + flops = 4 * f1 + 3 * f2 + return flops diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index c7855cff..c500f227 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -651,14 +651,6 @@ def __init__( if self.tb_args.precision == "bypass": self.tb_args.precision = self.DEFAULT_PRECISION self.dtype = PRECISION_DTYPE_MAPPING.get(self.tb_args.precision, None) - self.DEFAULT_METRICS.extend( - [ - x - for x in REGISTERED_METRICS.get(self.name, []) - if x not in BUILTIN_METRICS - ] - ) - self.DEFAULT_METRICS = list(set(self.DEFAULT_METRICS)) if self.tb_args.baseline: BASELINE_BENCHMARKS[self.name] = self.tb_args.baseline self._only = _split_params_by_comma(self.tb_args.only) @@ -1607,6 +1599,10 @@ def _compile_time_in_task( def tflops( self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics ) -> float: + if self.has_metric("flops"): + flops = self.flops(fn_name, example_inputs, metrics) + return flops / metrics.latency / 1e12 * 1e3 + def _get_flops(self, func: Callable) -> float: """By default, use the torch.__dispatch__ based flops counter.""" from torch.utils.flop_counter import FlopCounterMode @@ -1678,4 +1674,6 @@ def has_bwd(cls) -> bool: @classmethod def has_metric(cls, metric_name: str) -> bool: + if metric_name == "tflops": + return bool(getattr(cls, "flops", None)) return bool(getattr(cls, metric_name, None))