diff --git a/docs/kineto_trace.md b/docs/kineto_trace.md
new file mode 100644
index 00000000..e38a2cde
--- /dev/null
+++ b/docs/kineto_trace.md
@@ -0,0 +1,28 @@
+# Kineto Trace Analysis with TritonBench
+
+TritonBench supports generating a Kineto trace file for each `(input, impl)` pair.
+For example, the following command generates 6 Kineto traces, since it runs 2 inputs (`--num-inputs 2`) against 3 impls (`flash_v3,cudnn,triton_tutorial_flash_v2`).
+
+```
+$ python run.py --op flash_attention --num-inputs 2 --metrics kineto_trace --only flash_v3,cudnn,triton_tutorial_flash_v2
+
+ (Batch, Heads, SeqLen, Dhead) flash_v3-kineto_trace cudnn_90100-kineto_trace triton_tutorial_flash_v2-kineto_trace
+------------------------------- --------------------------------------------------------- ------------------------------------------------------ -------------------------------------------------------------------------
+ (4, 48, 128, 64) /tmp/tritonbench/flash_attention/kineto_traces/flash_v3_0 /tmp/tritonbench/flash_attention/kineto_traces/cudnn_0 /tmp/tritonbench/flash_attention/kineto_traces/triton_tutorial_flash_v2_0
+ (4, 48, 256, 64) /tmp/tritonbench/flash_attention/kineto_traces/flash_v3_1 /tmp/tritonbench/flash_attention/kineto_traces/cudnn_1 /tmp/tritonbench/flash_attention/kineto_traces/triton_tutorial_flash_v2_1
+```
+
+The output table shows the directory where each Kineto trace file is stored.
+
+## Example Kineto Trace Analysis
+
+When opening the trace file in the Chrome Trace Viewer, we first need to separate the profiling iteration from the warm-up iterations.
+The profiling iteration runs after all warm-up iterations and is labeled `ProfilerStep#`.
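+
+To locate the profiling iteration programmatically instead of scrolling through the timeline, the trace can also be read as plain JSON. Below is a minimal sketch that assumes the Chrome-trace layout emitted by the PyTorch profiler (a top-level `traceEvents` list whose entries carry `name`, `ts`, and `dur` fields); the file path is a placeholder for a trace stored under one of the directories above.
+
+```
+import json
+
+# Placeholder path: the actual file name inside the output directory varies.
+with open("/tmp/tritonbench/flash_attention/kineto_traces/cudnn_0/trace.json") as f:
+    trace = json.load(f)
+
+# The profiling iteration is annotated with a name like "ProfilerStep#<n>".
+steps = [e for e in trace["traceEvents"] if e.get("name", "").startswith("ProfilerStep#")]
+for step in steps:
+    # Chrome-trace timestamps and durations are in microseconds.
+    print(step["name"], "starts at", step["ts"], "us and lasts", step["dur"], "us")
+```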
+
+![Kineto Trace](https://ossci-datasets.s3.us-east-1.amazonaws.com/tritonbench/docs/_static/img/kineto_trace_fig_1.png "Kineto Trace - Global View")
+
+Zooming into the profiling iteration, we find two GPU kernels launched. The first one is the L2 cache flush that clears the cache.
+The second one is the actual computation kernel, which comes from cuDNN for this flash_attention operator.
+
+![Kineto Trace](https://ossci-datasets.s3.us-east-1.amazonaws.com/tritonbench/docs/_static/img/kineto_trace_fig_2.png "Kineto Trace - Zoomed into Profile Iteration")
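+
+To confirm which kernels run inside the profiling iteration without the trace viewer, the same JSON can be filtered for the GPU kernel events that fall inside the `ProfilerStep#` window. This continues the sketch above and assumes kernel events are categorized as `kernel` (older traces use `Kernel`); the window check is approximate, since GPU kernels can finish slightly after the CPU-side step ends.
+
+```
+# Continues the sketch above: `trace` and `steps` were loaded there.
+step = steps[-1]
+start, end = step["ts"], step["ts"] + step["dur"]
+
+# GPU kernels are usually categorized as "kernel" (older traces use "Kernel").
+kernels = [
+    e
+    for e in trace["traceEvents"]
+    if e.get("cat") in ("kernel", "Kernel") and start <= e.get("ts", 0) <= end
+]
+for k in kernels:
+    print(f'{k["name"]}: {k["dur"]} us')
+
+# Expect two kernels here: the L2 cache flush followed by the attention kernel (cuDNN in this example).
+```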
+
diff --git a/tritonbench/components/kineto/trace.py b/tritonbench/components/kineto/trace.py
index 25b0b34c..fb661c56 100644
--- a/tritonbench/components/kineto/trace.py
+++ b/tritonbench/components/kineto/trace.py
@@ -19,6 +19,54 @@
from .fb.run_utils import trace_handler
+def do_bench_kineto_cudagraph(
+ fn, warmup, grad_to_none, profile_opts, output_dir
+) -> str:
+ activity_groups = [
+ profiler.ProfilerActivity.CUDA,
+ profiler.ProfilerActivity.CPU,
+ ]
+ with torch.cuda.stream(torch.cuda.Stream()):
+ # step 1 - construct a cuda graph
+ g = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(g):
+ if grad_to_none is not None:
+ for x in grad_to_none:
+ x.grad = None
+ fn()
+ torch.cuda.synchronize()
+ prefix = f"tritonbench_cudagraph_{fn._name}"
+ name = f"{prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{''.join(random.choices(string.digits, k=10))}.json"
+ # step 2 - profile cuda graph launch with kineto
+ with profiler.profile(
+ schedule=profiler.schedule(wait=0, warmup=warmup, active=1, repeat=1),
+ activities=activity_groups,
+ record_shapes=profile_opts["record_shapes"],
+ profile_memory=profile_opts["profile_memory"],
+ with_stack=profile_opts["with_stack"],
+ with_flops=profile_opts["with_flops"],
+ with_modules=profile_opts["with_modules"],
+ on_trace_ready=(
+ partial(trace_handler, name)
+ if not hasattr(torch.version, "git_version")
+ else profiler.tensorboard_trace_handler(output_dir)
+ ),
+ ) as prof:
+ for _i in range(warmup + 1):
+ # we don't want `fn` to accumulate gradient values
+ # if it contains a backward pass. So we clear the
+ # provided gradients
+ if grad_to_none is not None:
+ for x in grad_to_none:
+ x.grad = None
+ g.replay()
+ prof.step()
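+    # internal builds lack torch.version.git_version; return the uploaded trace URL there, otherwise the local output dir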
+ if not hasattr(torch.version, "git_version"):
+ return f"https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/{name}.gz&bucket=pyper_traces"
+ else:
+ return output_dir
+
+
def do_bench_kineto(
fn: Callable,
warmup=25,
@@ -26,6 +74,7 @@ def do_bench_kineto(
fast_flush=True,
profile_opts=None,
output_dir=None,
+ use_cuda_graphs: bool = False,
) -> str:
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
@@ -44,6 +93,8 @@ def do_bench_kineto(
:param output_dir: Output directory to store the trace
:type output_dir: str, optional
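+    :param use_cuda_graphs: Whether to capture :code:`fn` into a CUDA graph and profile graph replays instead of direct calls
+    :type use_cuda_graphs: bool, optional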
"""
+ if profile_opts is None:
+ profile_opts = DEFAULT_PROFILE_OPTS
import torch
fn()
@@ -70,12 +121,15 @@ def do_bench_kineto(
# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
+ if use_cuda_graphs:
+ return do_bench_kineto_cudagraph(
+ fn, n_warmup, grad_to_none, profile_opts, output_dir
+ )
+
activity_groups = [
profiler.ProfilerActivity.CUDA,
profiler.ProfilerActivity.CPU,
]
- if profile_opts is None:
- profile_opts = DEFAULT_PROFILE_OPTS
prefix = f"tritonbench_{fn._name}"
name = f"{prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{''.join(random.choices(string.digits, k=10))}.json"
with profiler.profile(
@@ -92,8 +146,6 @@ def do_bench_kineto(
else profiler.tensorboard_trace_handler(output_dir)
),
) as prof:
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
for i in range(n_warmup + 1):
# we don't want `fn` to accumulate gradient values
# if it contains a backward pass. So we clear the
diff --git a/tritonbench/operators/cross_entropy/operator.py b/tritonbench/operators/cross_entropy/operator.py
index a32ab277..96192384 100644
--- a/tritonbench/operators/cross_entropy/operator.py
+++ b/tritonbench/operators/cross_entropy/operator.py
@@ -29,7 +29,6 @@ def __init__(
self.T = 2048
self.baseline_model = CrossEntropyLoss()
self.liger_model = LigerCrossEntropyLoss()
- self.use_cuda_graphs = False
def get_input_iter(self) -> Generator:
for V in [2**i for i in range(12, 18)]:
diff --git a/tritonbench/operators/embedding/operator.py b/tritonbench/operators/embedding/operator.py
index 7779c082..6ed8aec6 100644
--- a/tritonbench/operators/embedding/operator.py
+++ b/tritonbench/operators/embedding/operator.py
@@ -27,7 +27,6 @@ def __init__(
# they are generated later
self.baseline_op = None
self.liger_op = None
- self.use_cuda_graphs = False
def get_input_iter(self) -> Generator:
for B, T, D in [(32, 512, 768), (8, 2048, 4096)]:
diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py
index e92127a2..56ea8657 100644
--- a/tritonbench/operators/flash_attention/operator.py
+++ b/tritonbench/operators/flash_attention/operator.py
@@ -168,9 +168,7 @@ def __init__(
self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
):
super().__init__(tb_args, extra_args)
- self.use_cuda_graphs = False
args = parse_op_args(self.extra_args)
- self.use_cuda_graphs = False
self.BATCH = args.batch
self.SEQ_LEN = args.seq_len
self.H = args.n_heads
diff --git a/tritonbench/operators/fused_linear_cross_entropy/operator.py b/tritonbench/operators/fused_linear_cross_entropy/operator.py
index 9f359345..4503d917 100644
--- a/tritonbench/operators/fused_linear_cross_entropy/operator.py
+++ b/tritonbench/operators/fused_linear_cross_entropy/operator.py
@@ -78,7 +78,6 @@ def __init__(
self.liger_model = LigerLMHeadCE(
H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
).to(self.device)
- self.use_cuda_graphs = False
def get_input_iter(self) -> Generator:
for BT in [2**i for i in range(12, 16)]:
diff --git a/tritonbench/operators/fused_linear_jsd/operator.py b/tritonbench/operators/fused_linear_jsd/operator.py
index 7ebdcc3b..c26d3f0c 100644
--- a/tritonbench/operators/fused_linear_jsd/operator.py
+++ b/tritonbench/operators/fused_linear_jsd/operator.py
@@ -146,8 +146,6 @@ def __init__(
self.liger_op.teacher_lin.weight.data
) = torch.rand(self.V, self.H, device=self.device, dtype=self.dtype)
- self.use_cuda_graphs = False
-
def get_input_iter(self) -> Generator:
for BT in [2**i for i in range(10, 14)]:
student_input = torch.rand(
diff --git a/tritonbench/operators/geglu/operator.py b/tritonbench/operators/geglu/operator.py
index 9613bc96..640e4185 100644
--- a/tritonbench/operators/geglu/operator.py
+++ b/tritonbench/operators/geglu/operator.py
@@ -40,7 +40,6 @@ def __init__(
self.liger_model = (
LigerGEGLUMLP(self.llama_config).to(self.device).to(self.dtype)
)
- self.use_cuda_graphs = False
def get_input_iter(self) -> Generator:
for T in [2**i for i in range(10, 14)]:
diff --git a/tritonbench/operators/gemm/operator.py b/tritonbench/operators/gemm/operator.py
index 4082a814..d594c6e4 100644
--- a/tritonbench/operators/gemm/operator.py
+++ b/tritonbench/operators/gemm/operator.py
@@ -141,7 +141,6 @@ def __init__(
self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
):
super().__init__(tb_args, extra_args)
- self.use_cuda_graphs = False
gemm_args = parse_args(self.extra_args)
self.layout = gemm_args.layout
if IS_FBCODE and tb_args.production_shapes:
diff --git a/tritonbench/operators/grouped_gemm/operator.py b/tritonbench/operators/grouped_gemm/operator.py
index 0ab2604c..0b695d45 100644
--- a/tritonbench/operators/grouped_gemm/operator.py
+++ b/tritonbench/operators/grouped_gemm/operator.py
@@ -18,7 +18,6 @@
class Operator(BenchmarkOperator):
DEFAULT_PRECISION = "fp16"
DEFAULT_METRICS = ["latency", "speedup", "accuracy"]
- use_cuda_graphs = False
@register_benchmark(baseline=True)
def torch(self, group_A, group_B):
diff --git a/tritonbench/operators/jagged_layer_norm/operator.py b/tritonbench/operators/jagged_layer_norm/operator.py
index d34f4896..36252c58 100644
--- a/tritonbench/operators/jagged_layer_norm/operator.py
+++ b/tritonbench/operators/jagged_layer_norm/operator.py
@@ -40,10 +40,6 @@ class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["latency", "accuracy"]
DEFAULT_PRECISION = "fp32"
- use_cuda_graphs = (
- False # allows for a GPU/CPU sync, caused by methods like torch.unbind
- )
-
def __init__(
self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
):
diff --git a/tritonbench/operators/jagged_mean/operator.py b/tritonbench/operators/jagged_mean/operator.py
index c9c975d2..cbcc8b74 100644
--- a/tritonbench/operators/jagged_mean/operator.py
+++ b/tritonbench/operators/jagged_mean/operator.py
@@ -96,10 +96,6 @@ class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["latency", "accuracy"]
DEFAULT_PRECISION = "fp32"
- use_cuda_graphs = (
- False # enables GPU/CPU sync (for methods like NestedTensor unbind)
- )
-
def __init__(
self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
):
diff --git a/tritonbench/operators/jagged_softmax/operator.py b/tritonbench/operators/jagged_softmax/operator.py
index cc3284c0..02299d75 100644
--- a/tritonbench/operators/jagged_softmax/operator.py
+++ b/tritonbench/operators/jagged_softmax/operator.py
@@ -75,10 +75,6 @@ class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["latency", "accuracy", "best_config"]
DEFAULT_PRECISION = "fp32"
- use_cuda_graphs = (
- False # enables GPU/CPU sync (for methods like NestedTensor unbind)
- )
-
def __init__(
self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
):
diff --git a/tritonbench/operators/jagged_sum/operator.py b/tritonbench/operators/jagged_sum/operator.py
index 36d519db..97d9506f 100644
--- a/tritonbench/operators/jagged_sum/operator.py
+++ b/tritonbench/operators/jagged_sum/operator.py
@@ -95,9 +95,6 @@ def execute_kernel_variable_length_loop(x, sum_then_buffer):
class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["latency", "accuracy", "best_config"]
DEFAULT_PRECISION = "fp32"
- use_cuda_graphs = (
- False # enables GPU/CPU sync (for methods like NestedTensor unbind)
- )
def __init__(
self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
diff --git a/tritonbench/operators/jsd/operator.py b/tritonbench/operators/jsd/operator.py
index 5a42f294..06b863b8 100644
--- a/tritonbench/operators/jsd/operator.py
+++ b/tritonbench/operators/jsd/operator.py
@@ -65,7 +65,6 @@ def __init__(
self.T = 2048
self.baseline_op = TorchJSD()
self.liger_op = LigerJSD()
- self.use_cuda_graphs = False
def get_input_iter(self) -> Generator:
for V in [2**i for i in range(12, 18)]:
diff --git a/tritonbench/operators/kl_div/operator.py b/tritonbench/operators/kl_div/operator.py
index 0d600cce..24fcef4e 100644
--- a/tritonbench/operators/kl_div/operator.py
+++ b/tritonbench/operators/kl_div/operator.py
@@ -27,7 +27,6 @@ def __init__(
self.T = 512
self.baseline_op = torch.nn.KLDivLoss(reduction="batchmean").to(self.device)
self.liger_op = LigerKLDIVLoss(reduction="batchmean").to(self.device)
- self.use_cuda_graphs = False
def get_input_iter(self) -> Generator:
for V in [2**i for i in range(12, 18)]:
diff --git a/tritonbench/operators/rms_norm/operator.py b/tritonbench/operators/rms_norm/operator.py
index 0c62d39d..adde1f48 100644
--- a/tritonbench/operators/rms_norm/operator.py
+++ b/tritonbench/operators/rms_norm/operator.py
@@ -45,7 +45,6 @@ def __init__(
# they are generated later
self.llama_rms_op = None
self.liger_rms_op = None
- self.use_cuda_graphs = False
def get_input_iter(self) -> Generator:
for H in [2**i for i in range(10, 16)]:
diff --git a/tritonbench/operators/rope/operator.py b/tritonbench/operators/rope/operator.py
index 174626ac..ac9073e6 100644
--- a/tritonbench/operators/rope/operator.py
+++ b/tritonbench/operators/rope/operator.py
@@ -30,7 +30,6 @@ def __init__(
# they are generated later
self.baseline_op = None
self.liger_op = None
- self.use_cuda_graphs = False
self.num_q_heads = 32
self.num_kv_heads = 8
diff --git a/tritonbench/operators/swiglu/operator.py b/tritonbench/operators/swiglu/operator.py
index b21fede9..ab53ab76 100644
--- a/tritonbench/operators/swiglu/operator.py
+++ b/tritonbench/operators/swiglu/operator.py
@@ -39,7 +39,6 @@ def __init__(
self.liger_op = (
LigerSwiGLUMLP(config=llama_config).to(self.device).to(self.dtype)
)
- self.use_cuda_graphs = False
def get_input_iter(self) -> Generator:
for seq_len in [2**i for i in range(10, 14)]:
diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py
index 0a0aa962..fea2ed43 100644
--- a/tritonbench/utils/triton_op.py
+++ b/tritonbench/utils/triton_op.py
@@ -1470,6 +1470,7 @@ def kineto_trace(self, input_id: int, fn: Callable) -> str:
fn=fn,
grad_to_none=self.get_grad_to_none(self.example_inputs),
output_dir=kineto_output_dir,
+ use_cuda_graphs=self.use_cuda_graphs,
)
def compile_time(