Support flops metric in proton profiling #111

Closed · wants to merge 1 commit
4 changes: 2 additions & 2 deletions tritonbench/operators/addmm/operator.py
@@ -126,14 +126,14 @@ def gbps(
         return numel / metrics.latency * 1e3
 
     @register_metric()
-    def tflops(
+    def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
         _, mat1, mat2 = example_inputs
         m, k = mat1.size()
         k, n = mat2.size()
         flops = m * k * 2 * n
-        return flops / metrics.latency / 1e12 * 1e3
+        return flops
 
     @register_x_val(label="(M, N, K)")
     def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
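Note on units in the hunk above: the operator-level metric now returns a raw flop count, and the removed `/ metrics.latency / 1e12 * 1e3` conversion (which reads as latency-in-milliseconds to TFLOPS) moves into the shared `tflops` metric in `tritonbench/utils/triton_op.py` at the end of this diff. A minimal sketch of the addmm count itself, using a hypothetical helper name that is not part of this PR:

```python
# Hypothetical helper, for illustration only: analytic flop count for the
# matmul term of addmm, mat1 (m x k) @ mat2 (k x n).
def addmm_matmul_flops(m: int, k: int, n: int) -> int:
    # Each of the m * n output elements takes k multiplies and k adds.
    return 2 * m * k * n


# 2 * 4096**3 is roughly 1.37e11 flops; at 10 ms latency that works out to
# about 13.7 TFLOPS under the harness conversion described above.
assert addmm_matmul_flops(4096, 4096, 4096) == 2 * 4096**3
```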
4 changes: 2 additions & 2 deletions tritonbench/operators/bf16xint16_gemm/bf16xint16_gemm.py
@@ -142,14 +142,14 @@ def nbytes(t):
         return gb / metrics.latency * 1e3
 
     @register_metric()
-    def tflops(
+    def flops(
        self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
         a, b = example_inputs
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k
-        return flops / metrics.latency / 1e12 * 1e3
+        return flops
 
     def plot(self):
         @triton.testing.perf_report(
7 changes: 0 additions & 7 deletions tritonbench/operators/flash_attention/operator.py
@@ -478,13 +478,6 @@ def causal_mask(b, h, q_idx, kv_idx):
 
         return lambda: flex_attention(q, k, v, block_mask=block_mask)
 
-    @register_metric()
-    def tflops(
-        self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
-    ) -> float:
-        analytic_flops = self.flops(fn_name, example_inputs, metrics)
-        return analytic_flops / metrics.latency * 1e-9
-
     @register_metric(x_only=True)
     def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
5 changes: 2 additions & 3 deletions tritonbench/operators/fp8_attention/operator.py
@@ -150,10 +150,9 @@ def get_input_iter(self) -> Generator:
             yield (q, k, v)
 
     @register_metric()
-    def tflops(
+    def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
         H = self.embedding_dim // self.D_HEAD
         flops_per_matmul = 2.0 * self.BATCH * H * self.N_CTX * self.N_CTX * self.D_HEAD
-        tflops = 2 * flops_per_matmul
-        return tflops / metrics.latency * 1e-9
+        return 2 * flops_per_matmul
@@ -132,14 +132,14 @@ def _impl(x1, x2, wq, w_scale, wd):
         return lambda: _impl(x1, x2, wq, w_scale, wd)
 
     @register_metric()
-    def tflops(
+    def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> List[float]:
         x1, _, wq, _, _ = example_inputs
         m, k = x1.size()
         n, k = wq.size()
         flops = m * k * 2 * n
-        return flops / metrics.latency / 1e12 * 1e3
+        return flops
 
     @register_x_val(label="(M, N, K)")
     def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
4 changes: 2 additions & 2 deletions tritonbench/operators/fp8_gemm/fp8_gemm.py
@@ -120,14 +120,14 @@ def nbytes(t):
         return gb / metrics.latency * 1e3
 
     @register_metric()
-    def tflops(
+    def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
         a, b = example_inputs
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k
-        return flops / metrics.latency / 1e12 * 1e3
+        return flops
 
     def plot(self):
         @triton.testing.perf_report(
4 changes: 2 additions & 2 deletions tritonbench/operators/fp8_gemm_blockwise/operator.py
@@ -138,14 +138,14 @@ def _cutlass(self, xq, wq, x_scale, w_scale) -> Callable:
         return lambda: cutlass_fp8_block(xq, wq, x_scale, w_scale)
 
     @register_metric()
-    def tflops(
+    def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> List[float]:
         xq, wq, _, _ = example_inputs
         m, k = xq.size()
         n, k = wq.size()
         flops = m * k * 2 * n
-        return flops / metrics.latency / 1e12 * 1e3
+        return flops
 
     @register_x_val(label="(M, N, K)")
     def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
4 changes: 2 additions & 2 deletions tritonbench/operators/fp8_gemm_rowwise/operator.py
@@ -168,14 +168,14 @@ def _cublas(self, xq, wq, x_scale, w_scale) -> Callable:
         # return lambda: _cublass(xq, wq, x_scale, w_scale)
 
     @register_metric()
-    def tflops(
+    def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> List[float]:
         xq, wq, _, _ = example_inputs
         m, k = xq.size()
         n, k = wq.size()
         flops = m * k * 2 * n
-        return flops / metrics.latency / 1e12 * 1e3
+        return flops
 
     @register_x_val(label="(M, N, K)")
     def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
4 changes: 2 additions & 2 deletions tritonbench/operators/int4_gemm/int4_gemm.py
@@ -94,15 +94,15 @@ def nbytes(t):
         return gb / metrics.latency * 1e3
 
     @register_metric()
-    def tflops(
+    def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
         a, b = example_inputs
         B, m, k = a.size()
         m = B * m
         _, n = b.size()
         flops = 2 * m * n * k
-        return flops / metrics.latency / 1e12 * 1e3
+        return flops
 
     def plot(self):
         @triton.testing.perf_report(
4 changes: 2 additions & 2 deletions tritonbench/operators/low_mem_dropout/operator.py
@@ -26,10 +26,10 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
         )
 
     @register_metric()
-    def tflops(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics):
+    def flops(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics):
         p, a = example_inputs
         flops = 2 * len(a)
-        return flops / metrics.latency
+        return flops
 
     @register_benchmark()
     def triton_dropout(self, p, x):
12 changes: 6 additions & 6 deletions tritonbench/operators/ragged_attention/operator.py
@@ -129,14 +129,14 @@ def get_bwd_fn(self, fwd_fn: Callable[..., Any]) -> Callable[..., Any]:
         return fn
 
     @register_metric()
-    def tflops(
+    def flops(
         self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics
     ) -> float:
         ratio = 2.0  # triangular masking
         f1 = 0.0
         f2 = 0.0
         jagged = True
-        q, k, v, seq_offsets, timestamps, num_targets = example_inputs
+        q, k, v, seq_offsets, timestamps, num_targets, seq_len = example_inputs
         _, nheads, attn_dim = q.shape
         _, _, hidden_dim = v.shape
         max_seqlen = timestamps.size(1) - 1
@@ -152,9 +152,9 @@ def tflops(
             # (QK^T)V, d(QK^T) = dOV^T, dV = (QK^T)^TdO,
             f2 += 2 * self.num_heads * hidden_dim * seq_len**2 // ratio
         if self.mode == Mode.FWD:
-            tflops = f1 + f2  # computes (QK^T) and (QK^T)V
+            flops = f1 + f2  # computes (QK^T) and (QK^T)V
         elif self.mode == Mode.BWD:
-            tflops = 3 * f1 + 2 * f2  # computes (QK^T), dQ, dK, dV, d(QK^T)
+            flops = 3 * f1 + 2 * f2  # computes (QK^T), dQ, dK, dV, d(QK^T)
         elif self.mode == Mode.FWD_BWD:
-            tflops = 4 * f1 + 3 * f2
-        return tflops / metrics.latency * 1e-9
+            flops = 4 * f1 + 3 * f2
+        return flops
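The forward/backward bookkeeping in this hunk is easy to misread in diff form; below is a small illustrative helper (the name `combine_attention_flops` and the standalone `Mode` enum are hypothetical, not from this PR) that mirrors how `f1` (the QK^T matmuls) and `f2` (the (QK^T)V matmuls) are combined per mode:

```python
from enum import Enum


class Mode(Enum):  # illustrative stand-in for the operator's Mode enum
    FWD = "fwd"
    BWD = "bwd"
    FWD_BWD = "fwd_bwd"


def combine_attention_flops(f1: float, f2: float, mode: Mode) -> float:
    """Combine per-matmul flop counts for the ragged attention kernel.

    f1: flops for QK^T; f2: flops for (QK^T)V. Both are assumed to be
    already summed over sequences and adjusted for triangular masking.
    """
    if mode == Mode.FWD:
        return f1 + f2  # QK^T and (QK^T)V
    if mode == Mode.BWD:
        return 3 * f1 + 2 * f2  # (QK^T), dQ, dK plus dV, d(QK^T)
    return 4 * f1 + 3 * f2  # FWD_BWD: forward pass plus backward pass
```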
14 changes: 6 additions & 8 deletions tritonbench/utils/triton_op.py
@@ -651,14 +651,6 @@ def __init__(
         if self.tb_args.precision == "bypass":
             self.tb_args.precision = self.DEFAULT_PRECISION
         self.dtype = PRECISION_DTYPE_MAPPING.get(self.tb_args.precision, None)
-        self.DEFAULT_METRICS.extend(
-            [
-                x
-                for x in REGISTERED_METRICS.get(self.name, [])
-                if x not in BUILTIN_METRICS
-            ]
-        )
-        self.DEFAULT_METRICS = list(set(self.DEFAULT_METRICS))
         if self.tb_args.baseline:
             BASELINE_BENCHMARKS[self.name] = self.tb_args.baseline
         self._only = _split_params_by_comma(self.tb_args.only)
@@ -1607,6 +1599,10 @@ def _compile_time_in_task(
     def tflops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
+        if self.has_metric("flops"):
+            flops = self.flops(fn_name, example_inputs, metrics)
+            return flops / metrics.latency / 1e12 * 1e3
+
         def _get_flops(self, func: Callable) -> float:
             """By default, use the torch.__dispatch__ based flops counter."""
             from torch.utils.flop_counter import FlopCounterMode
@@ -1678,4 +1674,6 @@ def has_bwd(cls) -> bool:
 
     @classmethod
     def has_metric(cls, metric_name: str) -> bool:
+        if metric_name == "tflops":
+            return bool(getattr(cls, "flops", None))
         return bool(getattr(cls, metric_name, None))
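Taken together with the operator hunks above, the intended flow is: each operator registers a raw `flops` count, the shared `tflops` metric derives a throughput from it when present, and proton profiling can consume the raw count directly. A simplified, self-contained sketch of that pattern (class and field names here are abbreviated stand-ins, not the real triton_op.py plumbing; latency is assumed to be in milliseconds):

```python
from dataclasses import dataclass


@dataclass
class Metrics:
    latency: float  # milliseconds, matching the / 1e12 * 1e3 conversion above


class OperatorSketch:
    @classmethod
    def has_metric(cls, metric_name: str) -> bool:
        return getattr(cls, metric_name, None) is not None

    def tflops(self, fn_name, example_inputs, metrics: Metrics) -> float:
        # Prefer an operator-provided analytic flop count when one is registered.
        if self.has_metric("flops"):
            flops = self.flops(fn_name, example_inputs, metrics)
            return flops / metrics.latency / 1e12 * 1e3
        raise NotImplementedError("this sketch omits the FlopCounterMode fallback")


class AddmmSketch(OperatorSketch):
    def flops(self, fn_name, example_inputs, metrics: Metrics) -> float:
        (m, k), (_, n) = example_inputs  # shapes of mat1 and mat2
        return 2.0 * m * k * n


op = AddmmSketch()
# Roughly 13.7 TFLOPS for a 4096^3 matmul that takes 10 ms.
print(op.tflops("addmm", ((4096, 4096), (4096, 4096)), Metrics(latency=10.0)))
```

This mirrors the `has_metric` special case in the hunk above: asking whether an operator supports `tflops` now only requires it to define `flops`.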