Support flops metric in proton profiling (#111)
Summary:

Support the tflops metric, and its Proton counterpart, on more operators.

Operators now implement `def flops()`, and the framework derives the `tflops` metric from it automatically.
The same interface will be used to support `flops` in Proton.

Differential Revision: D67062546
xuzhao9 authored and facebook-github-bot committed Dec 11, 2024
1 parent 7b67b0a commit 7670b39
Showing 12 changed files with 28 additions and 32 deletions.
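
For context, a minimal sketch (not part of this commit) of the convention described in the summary above: an operator registers a `flops` metric that returns raw FLOPs, and the framework derives `tflops` from it by dividing by latency. The import path and base-class name below are assumptions based on the `tritonbench.utils.triton_op` module touched in this diff.

    from typing import Any

    # Assumed import path; the decorator and metrics class appear in the diff below.
    from tritonbench.utils.triton_op import (
        BenchmarkOperator,
        BenchmarkOperatorMetrics,
        register_metric,
    )


    class MyGemmOperator(BenchmarkOperator):
        @register_metric()
        def flops(
            self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
        ) -> float:
            # Report raw analytic FLOPs for one call; the framework's base tflops
            # hook divides by latency to report TFLOPS, and Proton can reuse it.
            a, b = example_inputs
            m, k = a.size()
            _, n = b.size()
            return 2.0 * m * n * k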
4 changes: 2 additions & 2 deletions tritonbench/operators/addmm/operator.py
@@ -126,14 +126,14 @@ def gbps(
         return numel / metrics.latency * 1e3

     @register_metric()
-    def tflops(
+    def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
         _, mat1, mat2 = example_inputs
         m, k = mat1.size()
         k, n = mat2.size()
         flops = m * k * 2 * n
-        return flops / metrics.latency / 1e12 * 1e3
+        return flops

     @register_x_val(label="(M, N, K)")
     def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
4 changes: 2 additions & 2 deletions tritonbench/operators/bf16xint16_gemm/bf16xint16_gemm.py
@@ -142,14 +142,14 @@ def nbytes(t):
         return gb / metrics.latency * 1e3

     @register_metric()
-    def tflops(
+    def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
         a, b = example_inputs
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k
-        return flops / metrics.latency / 1e12 * 1e3
+        return flops

     def plot(self):
         @triton.testing.perf_report(
7 changes: 0 additions & 7 deletions tritonbench/operators/flash_attention/operator.py
@@ -478,13 +478,6 @@ def causal_mask(b, h, q_idx, kv_idx):

         return lambda: flex_attention(q, k, v, block_mask=block_mask)

-    @register_metric()
-    def tflops(
-        self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
-    ) -> float:
-        analytic_flops = self.flops(fn_name, example_inputs, metrics)
-        return analytic_flops / metrics.latency * 1e-9
-
     @register_metric(x_only=True)
     def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
5 changes: 2 additions & 3 deletions tritonbench/operators/fp8_attention/operator.py
@@ -150,10 +150,9 @@ def get_input_iter(self) -> Generator:
             yield (q, k, v)

     @register_metric()
-    def tflops(
+    def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
         H = self.embedding_dim // self.D_HEAD
         flops_per_matmul = 2.0 * self.BATCH * H * self.N_CTX * self.N_CTX * self.D_HEAD
-        tflops = 2 * flops_per_matmul
-        return tflops / metrics.latency * 1e-9
+        return 2 * flops_per_matmul
4 changes: 2 additions & 2 deletions (file path not shown)
@@ -132,14 +132,14 @@ def _impl(x1, x2, wq, w_scale, wd):
         return lambda: _impl(x1, x2, wq, w_scale, wd)

     @register_metric()
-    def tflops(
+    def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> List[float]:
         x1, _, wq, _, _ = example_inputs
         m, k = x1.size()
         n, k = wq.size()
         flops = m * k * 2 * n
-        return flops / metrics.latency / 1e12 * 1e3
+        return flops

     @register_x_val(label="(M, N, K)")
     def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
4 changes: 2 additions & 2 deletions tritonbench/operators/fp8_gemm/fp8_gemm.py
@@ -120,14 +120,14 @@ def nbytes(t):
         return gb / metrics.latency * 1e3

     @register_metric()
-    def tflops(
+    def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
         a, b = example_inputs
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k
-        return flops / metrics.latency / 1e12 * 1e3
+        return flops

     def plot(self):
         @triton.testing.perf_report(
4 changes: 2 additions & 2 deletions tritonbench/operators/fp8_gemm_blockwise/operator.py
@@ -138,14 +138,14 @@ def _cutlass(self, xq, wq, x_scale, w_scale) -> Callable:
         return lambda: cutlass_fp8_block(xq, wq, x_scale, w_scale)

     @register_metric()
-    def tflops(
+    def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> List[float]:
         xq, wq, _, _ = example_inputs
         m, k = xq.size()
         n, k = wq.size()
         flops = m * k * 2 * n
-        return flops / metrics.latency / 1e12 * 1e3
+        return flops

     @register_x_val(label="(M, N, K)")
     def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
4 changes: 2 additions & 2 deletions tritonbench/operators/fp8_gemm_rowwise/operator.py
@@ -168,14 +168,14 @@ def _cublas(self, xq, wq, x_scale, w_scale) -> Callable:
         # return lambda: _cublass(xq, wq, x_scale, w_scale)

     @register_metric()
-    def tflops(
+    def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> List[float]:
         xq, wq, _, _ = example_inputs
         m, k = xq.size()
         n, k = wq.size()
         flops = m * k * 2 * n
-        return flops / metrics.latency / 1e12 * 1e3
+        return flops

     @register_x_val(label="(M, N, K)")
     def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
4 changes: 2 additions & 2 deletions tritonbench/operators/int4_gemm/int4_gemm.py
@@ -94,15 +94,15 @@ def nbytes(t):
         return gb / metrics.latency * 1e3

     @register_metric()
-    def tflops(
+    def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
         a, b = example_inputs
         B, m, k = a.size()
         m = B * m
         _, n = b.size()
         flops = 2 * m * n * k
-        return flops / metrics.latency / 1e12 * 1e3
+        return flops

     def plot(self):
         @triton.testing.perf_report(
4 changes: 2 additions & 2 deletions tritonbench/operators/low_mem_dropout/operator.py
@@ -26,10 +26,10 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
         )

     @register_metric()
-    def tflops(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics):
+    def flops(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics):
         p, a = example_inputs
         flops = 2 * len(a)
-        return flops / metrics.latency
+        return flops

     @register_benchmark()
     def triton_dropout(self, p, x):
12 changes: 6 additions & 6 deletions tritonbench/operators/ragged_attention/operator.py
@@ -129,14 +129,14 @@ def get_bwd_fn(self, fwd_fn: Callable[..., Any]) -> Callable[..., Any]:
         return fn

     @register_metric()
-    def tflops(
+    def flops(
         self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics
     ) -> float:
         ratio = 2.0  # triangular masking
         f1 = 0.0
         f2 = 0.0
         jagged = True
-        q, k, v, seq_offsets, timestamps, num_targets = example_inputs
+        q, k, v, seq_offsets, timestamps, num_targets, seq_len = example_inputs
         _, nheads, attn_dim = q.shape
         _, _, hidden_dim = v.shape
         max_seqlen = timestamps.size(1) - 1
@@ -152,9 +152,9 @@ def tflops(
             # (QK^T)V, d(QK^T) = dOV^T, dV = (QK^T)^TdO,
             f2 += 2 * self.num_heads * hidden_dim * seq_len**2 // ratio
         if self.mode == Mode.FWD:
-            tflops = f1 + f2  # computes (QK^T) and (QK^T)V
+            flops = f1 + f2  # computes (QK^T) and (QK^T)V
         elif self.mode == Mode.BWD:
-            tflops = 3 * f1 + 2 * f2  # computes (QK^T), dQ, dK, dV, d(QK^T)
+            flops = 3 * f1 + 2 * f2  # computes (QK^T), dQ, dK, dV, d(QK^T)
         elif self.mode == Mode.FWD_BWD:
-            tflops = 4 * f1 + 3 * f2
-        return tflops / metrics.latency * 1e-9
+            flops = 4 * f1 + 3 * f2
+        return flops
4 changes: 4 additions & 0 deletions tritonbench/utils/triton_op.py
@@ -1607,6 +1607,10 @@ def _compile_time_in_task(
     def tflops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
+        if self.has_metric("flops"):
+            flops = self.flops(fn_name, example_inputs, metrics)
+            return flops / metrics.latency / 1e12 * 1e3
+
     def _get_flops(self, func: Callable) -> float:
         """By default, use the torch.__dispatch__ based flops counter."""
         from torch.utils.flop_counter import FlopCounterMode
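
A note on the unit conversion in the new base `tflops` hook: `metrics.latency` appears to be measured in milliseconds (the same assumption behind the `* 1e3` factor in the `gbps` metrics above), so `flops / metrics.latency / 1e12 * 1e3` is raw FLOPs per second scaled to tera. A small sketch of the arithmetic with hypothetical numbers:

    flops = 2 * 4096 * 4096 * 4096               # hypothetical GEMM: 2 * M * N * K raw FLOPs
    latency_ms = 1.25                            # hypothetical measured latency, in milliseconds
    tflops = flops / (latency_ms * 1e-3) / 1e12
    # Equivalent to the framework's expression: flops / latency_ms / 1e12 * 1e3
    print(f"{tflops:.1f} TFLOPS")                # ~110.0 TFLOPS for these numbers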
