From 4ddf056dfac41498b8a64173c11a281be0300079 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Sat, 7 Dec 2024 10:10:40 -0500 Subject: [PATCH 1/7] Add doc to kineto trace metric --- docs/kineto_trace.md | 28 ++++++++++++++++++++++++++ tritonbench/components/kineto/trace.py | 2 -- 2 files changed, 28 insertions(+), 2 deletions(-) create mode 100644 docs/kineto_trace.md diff --git a/docs/kineto_trace.md b/docs/kineto_trace.md new file mode 100644 index 00000000..162a48bc --- /dev/null +++ b/docs/kineto_trace.md @@ -0,0 +1,28 @@ +# Kineto Trace Analysis with TritonBench + +TritonBench supports generating a Kineto trace file for each `` pair. +For example, the following command will generate 6 Kineto traces, as it is running 2 inputs(`--num-inputs 2`) with 3 impls (`flash_v3,cudnn,triton_tutorial_flash_v2`). + +``` +$ python run.py --op flash_attention --num-inputs 2 --metrics kineto_trace --only flash_v3,cudnn,triton_tutorial_flash_v2 + + (Batch, Heads, SeqLen, Dhead) flash_v3-kineto_trace cudnn_90100-kineto_trace triton_tutorial_flash_v2-kineto_trace +------------------------------- --------------------------------------------------------- ------------------------------------------------------ ------------------------------------------------------------------------- + (4, 48, 128, 64) /tmp/tritonbench/flash_attention/kineto_traces/flash_v3_0 /tmp/tritonbench/flash_attention/kineto_traces/cudnn_0 /tmp/tritonbench/flash_attention/kineto_traces/triton_tutorial_flash_v2_0 + (4, 48, 256, 64) /tmp/tritonbench/flash_attention/kineto_traces/flash_v3_1 /tmp/tritonbench/flash_attention/kineto_traces/cudnn_1 /tmp/tritonbench/flash_attention/kineto_traces/triton_tutorial_flash_v2_1 +``` + +The output table shows the directory where the Kineto trace file is stored. + +## Example Kineto Trace Analysis + +Opening the trace file with Chrome Trace Viewer, we need to first separate the profiling iteration with the warm-up iterations. +The profiling iteration runs after all warm-up iteraions and is labeled by `ProfilerStep#`. + +![Kineto Trace](https://ossci-datasets.s3.us-east-1.amazonaws.com/tritonbench/docs/_static/img/kineto_trace_fig_1.png "Kineto Trace - Global View") + +Zooming into the profile iteration, we find two GPU kernels launched. The first one corresponds to the L2 Cache clearance. +The second one corresponds to the actual computation kernel, which is from CUDNN in this flash_attention operator. + +![Kineto Trace](https://ossci-datasets.s3.us-east-1.amazonaws.com/tritonbench/docs/_static/img/kineto_trace_fig_2.png "Kineto Trace - Zoomed into Profile Iteration") + diff --git a/tritonbench/components/kineto/trace.py b/tritonbench/components/kineto/trace.py index 25b0b34c..51056547 100644 --- a/tritonbench/components/kineto/trace.py +++ b/tritonbench/components/kineto/trace.py @@ -92,8 +92,6 @@ def do_bench_kineto( else profiler.tensorboard_trace_handler(output_dir) ), ) as prof: - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) for i in range(n_warmup + 1): # we don't want `fn` to accumulate gradient values # if it contains a backward pass. So we clear the From fb7fa6e01d52082205054bc2e1e565677cbbb7b8 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Sat, 7 Dec 2024 10:12:56 -0500 Subject: [PATCH 2/7] Fix typo --- docs/kineto_trace.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/kineto_trace.md b/docs/kineto_trace.md index 162a48bc..e38a2cde 100644 --- a/docs/kineto_trace.md +++ b/docs/kineto_trace.md @@ -21,7 +21,7 @@ The profiling iteration runs after all warm-up iteraions and is labeled by `Prof ![Kineto Trace](https://ossci-datasets.s3.us-east-1.amazonaws.com/tritonbench/docs/_static/img/kineto_trace_fig_1.png "Kineto Trace - Global View") -Zooming into the profile iteration, we find two GPU kernels launched. The first one corresponds to the L2 Cache clearance. +Zooming into the profile iteration, we find two GPU kernels launched. The first one corresponds to the L2 Cache flush to clear the cache. The second one corresponds to the actual computation kernel, which is from CUDNN in this flash_attention operator. ![Kineto Trace](https://ossci-datasets.s3.us-east-1.amazonaws.com/tritonbench/docs/_static/img/kineto_trace_fig_2.png "Kineto Trace - Zoomed into Profile Iteration") From 06f62d8d199db2d958054a24eebcd9046458ca2a Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Sat, 7 Dec 2024 10:31:59 -0500 Subject: [PATCH 3/7] Add kineto profiling with cudagraph --- tritonbench/components/kineto/trace.py | 54 ++++++++++++++++++++++++++ tritonbench/utils/triton_op.py | 1 + 2 files changed, 55 insertions(+) diff --git a/tritonbench/components/kineto/trace.py b/tritonbench/components/kineto/trace.py index 51056547..7c2910f4 100644 --- a/tritonbench/components/kineto/trace.py +++ b/tritonbench/components/kineto/trace.py @@ -19,6 +19,57 @@ from .fb.run_utils import trace_handler +def do_bench_kineto_cudagraph(fn, warmup, grad_to_none, profile_opts, output_dir) -> str: + activity_groups = [ + profiler.ProfilerActivity.CUDA, + profiler.ProfilerActivity.CPU, + ] + with torch.cuda.stream(torch.cuda.Stream()): + # step 1 - warmup + fn() + if grad_to_none is not None: + for x in grad_to_none: + x.detach_() + x.requires_grad_(True) + x.grad = None + # step 2 - construct a cuda graph + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + fn() + torch.cuda.synchronize() + prefix = f"tritonbench_cudagraph_{fn._name}" + name = f"{prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{''.join(random.choices(string.digits, k=10))}.json" + # step 3 - profile cuda graph with kineto + with profiler.profile( + schedule=profiler.schedule(wait=0, warmup=warmup, active=1, repeat=1), + activities=activity_groups, + record_shapes=profile_opts["record_shapes"], + profile_memory=profile_opts["profile_memory"], + with_stack=profile_opts["with_stack"], + with_flops=profile_opts["with_flops"], + with_modules=profile_opts["with_modules"], + on_trace_ready=( + partial(trace_handler, name) + if not hasattr(torch.version, "git_version") + else profiler.tensorboard_trace_handler(output_dir) + ), + ) as prof: + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + g.replay() + prof.step() + if not hasattr(torch.version, "git_version"): + return f"https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/{name}.gz&bucket=pyper_traces" + else: + return output_dir + def do_bench_kineto( fn: Callable, warmup=25, @@ -26,6 +77,7 @@ def do_bench_kineto( fast_flush=True, profile_opts=None, output_dir=None, + use_cuda_graphs: bool = False, ) -> str: """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with @@ -44,6 +96,8 @@ def do_bench_kineto( :param output_dir: Output directory to store the trace :type output_dir: str, optional """ + if use_cuda_graphs: + return do_bench_kineto_cudagraph(fn, warmup, grad_to_none, profile_opts, output_dir) import torch fn() diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index 0a0aa962..fea2ed43 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -1470,6 +1470,7 @@ def kineto_trace(self, input_id: int, fn: Callable) -> str: fn=fn, grad_to_none=self.get_grad_to_none(self.example_inputs), output_dir=kineto_output_dir, + use_cuda_graphs=self.use_cuda_graphs, ) def compile_time( From 5be21f0419b39f4112d74a4cbf210dfebbc28e06 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Sat, 7 Dec 2024 10:37:28 -0500 Subject: [PATCH 4/7] Add profiler opts --- tritonbench/components/kineto/trace.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tritonbench/components/kineto/trace.py b/tritonbench/components/kineto/trace.py index 7c2910f4..8b60b6e6 100644 --- a/tritonbench/components/kineto/trace.py +++ b/tritonbench/components/kineto/trace.py @@ -96,6 +96,8 @@ def do_bench_kineto( :param output_dir: Output directory to store the trace :type output_dir: str, optional """ + if profile_opts is None: + profile_opts = DEFAULT_PROFILE_OPTS if use_cuda_graphs: return do_bench_kineto_cudagraph(fn, warmup, grad_to_none, profile_opts, output_dir) import torch @@ -128,8 +130,6 @@ def do_bench_kineto( profiler.ProfilerActivity.CUDA, profiler.ProfilerActivity.CPU, ] - if profile_opts is None: - profile_opts = DEFAULT_PROFILE_OPTS prefix = f"tritonbench_{fn._name}" name = f"{prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{''.join(random.choices(string.digits, k=10))}.json" with profiler.profile( From 48aaba4306523416b26387c28a1a6f8474e68c22 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Sat, 7 Dec 2024 12:45:47 -0500 Subject: [PATCH 5/7] Remove use cuda graph false --- tritonbench/operators/cross_entropy/operator.py | 1 - tritonbench/operators/embedding/operator.py | 1 - tritonbench/operators/flash_attention/operator.py | 2 -- tritonbench/operators/fused_linear_cross_entropy/operator.py | 1 - tritonbench/operators/fused_linear_jsd/operator.py | 2 -- tritonbench/operators/geglu/operator.py | 1 - tritonbench/operators/gemm/operator.py | 1 - tritonbench/operators/grouped_gemm/operator.py | 1 - tritonbench/operators/jagged_layer_norm/operator.py | 4 ---- tritonbench/operators/jagged_mean/operator.py | 4 ---- tritonbench/operators/jagged_softmax/operator.py | 4 ---- tritonbench/operators/jagged_sum/operator.py | 3 --- tritonbench/operators/jsd/operator.py | 1 - tritonbench/operators/kl_div/operator.py | 1 - tritonbench/operators/rms_norm/operator.py | 1 - tritonbench/operators/rope/operator.py | 1 - tritonbench/operators/swiglu/operator.py | 1 - 17 files changed, 30 deletions(-) diff --git a/tritonbench/operators/cross_entropy/operator.py b/tritonbench/operators/cross_entropy/operator.py index a32ab277..96192384 100644 --- a/tritonbench/operators/cross_entropy/operator.py +++ b/tritonbench/operators/cross_entropy/operator.py @@ -29,7 +29,6 @@ def __init__( self.T = 2048 self.baseline_model = CrossEntropyLoss() self.liger_model = LigerCrossEntropyLoss() - self.use_cuda_graphs = False def get_input_iter(self) -> Generator: for V in [2**i for i in range(12, 18)]: diff --git a/tritonbench/operators/embedding/operator.py b/tritonbench/operators/embedding/operator.py index 7779c082..6ed8aec6 100644 --- a/tritonbench/operators/embedding/operator.py +++ b/tritonbench/operators/embedding/operator.py @@ -27,7 +27,6 @@ def __init__( # they are generated later self.baseline_op = None self.liger_op = None - self.use_cuda_graphs = False def get_input_iter(self) -> Generator: for B, T, D in [(32, 512, 768), (8, 2048, 4096)]: diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py index e92127a2..56ea8657 100644 --- a/tritonbench/operators/flash_attention/operator.py +++ b/tritonbench/operators/flash_attention/operator.py @@ -168,9 +168,7 @@ def __init__( self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None ): super().__init__(tb_args, extra_args) - self.use_cuda_graphs = False args = parse_op_args(self.extra_args) - self.use_cuda_graphs = False self.BATCH = args.batch self.SEQ_LEN = args.seq_len self.H = args.n_heads diff --git a/tritonbench/operators/fused_linear_cross_entropy/operator.py b/tritonbench/operators/fused_linear_cross_entropy/operator.py index 9f359345..4503d917 100644 --- a/tritonbench/operators/fused_linear_cross_entropy/operator.py +++ b/tritonbench/operators/fused_linear_cross_entropy/operator.py @@ -78,7 +78,6 @@ def __init__( self.liger_model = LigerLMHeadCE( H=self.hidden_size, V=self.vocab_size, dtype=self.dtype ).to(self.device) - self.use_cuda_graphs = False def get_input_iter(self) -> Generator: for BT in [2**i for i in range(12, 16)]: diff --git a/tritonbench/operators/fused_linear_jsd/operator.py b/tritonbench/operators/fused_linear_jsd/operator.py index 7ebdcc3b..c26d3f0c 100644 --- a/tritonbench/operators/fused_linear_jsd/operator.py +++ b/tritonbench/operators/fused_linear_jsd/operator.py @@ -146,8 +146,6 @@ def __init__( self.liger_op.teacher_lin.weight.data ) = torch.rand(self.V, self.H, device=self.device, dtype=self.dtype) - self.use_cuda_graphs = False - def get_input_iter(self) -> Generator: for BT in [2**i for i in range(10, 14)]: student_input = torch.rand( diff --git a/tritonbench/operators/geglu/operator.py b/tritonbench/operators/geglu/operator.py index 9613bc96..640e4185 100644 --- a/tritonbench/operators/geglu/operator.py +++ b/tritonbench/operators/geglu/operator.py @@ -40,7 +40,6 @@ def __init__( self.liger_model = ( LigerGEGLUMLP(self.llama_config).to(self.device).to(self.dtype) ) - self.use_cuda_graphs = False def get_input_iter(self) -> Generator: for T in [2**i for i in range(10, 14)]: diff --git a/tritonbench/operators/gemm/operator.py b/tritonbench/operators/gemm/operator.py index 4082a814..d594c6e4 100644 --- a/tritonbench/operators/gemm/operator.py +++ b/tritonbench/operators/gemm/operator.py @@ -141,7 +141,6 @@ def __init__( self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None ): super().__init__(tb_args, extra_args) - self.use_cuda_graphs = False gemm_args = parse_args(self.extra_args) self.layout = gemm_args.layout if IS_FBCODE and tb_args.production_shapes: diff --git a/tritonbench/operators/grouped_gemm/operator.py b/tritonbench/operators/grouped_gemm/operator.py index 0ab2604c..0b695d45 100644 --- a/tritonbench/operators/grouped_gemm/operator.py +++ b/tritonbench/operators/grouped_gemm/operator.py @@ -18,7 +18,6 @@ class Operator(BenchmarkOperator): DEFAULT_PRECISION = "fp16" DEFAULT_METRICS = ["latency", "speedup", "accuracy"] - use_cuda_graphs = False @register_benchmark(baseline=True) def torch(self, group_A, group_B): diff --git a/tritonbench/operators/jagged_layer_norm/operator.py b/tritonbench/operators/jagged_layer_norm/operator.py index d34f4896..36252c58 100644 --- a/tritonbench/operators/jagged_layer_norm/operator.py +++ b/tritonbench/operators/jagged_layer_norm/operator.py @@ -40,10 +40,6 @@ class Operator(BenchmarkOperator): DEFAULT_METRICS = ["latency", "accuracy"] DEFAULT_PRECISION = "fp32" - use_cuda_graphs = ( - False # allows for a GPU/CPU sync, caused by methods like torch.unbind - ) - def __init__( self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None ): diff --git a/tritonbench/operators/jagged_mean/operator.py b/tritonbench/operators/jagged_mean/operator.py index c9c975d2..cbcc8b74 100644 --- a/tritonbench/operators/jagged_mean/operator.py +++ b/tritonbench/operators/jagged_mean/operator.py @@ -96,10 +96,6 @@ class Operator(BenchmarkOperator): DEFAULT_METRICS = ["latency", "accuracy"] DEFAULT_PRECISION = "fp32" - use_cuda_graphs = ( - False # enables GPU/CPU sync (for methods like NestedTensor unbind) - ) - def __init__( self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None ): diff --git a/tritonbench/operators/jagged_softmax/operator.py b/tritonbench/operators/jagged_softmax/operator.py index cc3284c0..02299d75 100644 --- a/tritonbench/operators/jagged_softmax/operator.py +++ b/tritonbench/operators/jagged_softmax/operator.py @@ -75,10 +75,6 @@ class Operator(BenchmarkOperator): DEFAULT_METRICS = ["latency", "accuracy", "best_config"] DEFAULT_PRECISION = "fp32" - use_cuda_graphs = ( - False # enables GPU/CPU sync (for methods like NestedTensor unbind) - ) - def __init__( self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None ): diff --git a/tritonbench/operators/jagged_sum/operator.py b/tritonbench/operators/jagged_sum/operator.py index 36d519db..97d9506f 100644 --- a/tritonbench/operators/jagged_sum/operator.py +++ b/tritonbench/operators/jagged_sum/operator.py @@ -95,9 +95,6 @@ def execute_kernel_variable_length_loop(x, sum_then_buffer): class Operator(BenchmarkOperator): DEFAULT_METRICS = ["latency", "accuracy", "best_config"] DEFAULT_PRECISION = "fp32" - use_cuda_graphs = ( - False # enables GPU/CPU sync (for methods like NestedTensor unbind) - ) def __init__( self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None diff --git a/tritonbench/operators/jsd/operator.py b/tritonbench/operators/jsd/operator.py index 5a42f294..06b863b8 100644 --- a/tritonbench/operators/jsd/operator.py +++ b/tritonbench/operators/jsd/operator.py @@ -65,7 +65,6 @@ def __init__( self.T = 2048 self.baseline_op = TorchJSD() self.liger_op = LigerJSD() - self.use_cuda_graphs = False def get_input_iter(self) -> Generator: for V in [2**i for i in range(12, 18)]: diff --git a/tritonbench/operators/kl_div/operator.py b/tritonbench/operators/kl_div/operator.py index 0d600cce..24fcef4e 100644 --- a/tritonbench/operators/kl_div/operator.py +++ b/tritonbench/operators/kl_div/operator.py @@ -27,7 +27,6 @@ def __init__( self.T = 512 self.baseline_op = torch.nn.KLDivLoss(reduction="batchmean").to(self.device) self.liger_op = LigerKLDIVLoss(reduction="batchmean").to(self.device) - self.use_cuda_graphs = False def get_input_iter(self) -> Generator: for V in [2**i for i in range(12, 18)]: diff --git a/tritonbench/operators/rms_norm/operator.py b/tritonbench/operators/rms_norm/operator.py index 0c62d39d..adde1f48 100644 --- a/tritonbench/operators/rms_norm/operator.py +++ b/tritonbench/operators/rms_norm/operator.py @@ -45,7 +45,6 @@ def __init__( # they are generated later self.llama_rms_op = None self.liger_rms_op = None - self.use_cuda_graphs = False def get_input_iter(self) -> Generator: for H in [2**i for i in range(10, 16)]: diff --git a/tritonbench/operators/rope/operator.py b/tritonbench/operators/rope/operator.py index 174626ac..ac9073e6 100644 --- a/tritonbench/operators/rope/operator.py +++ b/tritonbench/operators/rope/operator.py @@ -30,7 +30,6 @@ def __init__( # they are generated later self.baseline_op = None self.liger_op = None - self.use_cuda_graphs = False self.num_q_heads = 32 self.num_kv_heads = 8 diff --git a/tritonbench/operators/swiglu/operator.py b/tritonbench/operators/swiglu/operator.py index b21fede9..ab53ab76 100644 --- a/tritonbench/operators/swiglu/operator.py +++ b/tritonbench/operators/swiglu/operator.py @@ -39,7 +39,6 @@ def __init__( self.liger_op = ( LigerSwiGLUMLP(config=llama_config).to(self.device).to(self.dtype) ) - self.use_cuda_graphs = False def get_input_iter(self) -> Generator: for seq_len in [2**i for i in range(10, 14)]: From 93653a51359d9bd795a89c89dbcea278e628cd49 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Sat, 7 Dec 2024 13:46:09 -0500 Subject: [PATCH 6/7] Enable cuda graph mode tracing for Kineto --- tritonbench/components/kineto/trace.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tritonbench/components/kineto/trace.py b/tritonbench/components/kineto/trace.py index 8b60b6e6..a2fa6586 100644 --- a/tritonbench/components/kineto/trace.py +++ b/tritonbench/components/kineto/trace.py @@ -57,14 +57,15 @@ def do_bench_kineto_cudagraph(fn, warmup, grad_to_none, profile_opts, output_dir else profiler.tensorboard_trace_handler(output_dir) ), ) as prof: - # we don't want `fn` to accumulate gradient values - # if it contains a backward pass. So we clear the - # provided gradients - if grad_to_none is not None: - for x in grad_to_none: - x.grad = None - g.replay() - prof.step() + for _i in range(warmup + 1): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + g.replay() + prof.step() if not hasattr(torch.version, "git_version"): return f"https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/{name}.gz&bucket=pyper_traces" else: @@ -98,8 +99,6 @@ def do_bench_kineto( """ if profile_opts is None: profile_opts = DEFAULT_PROFILE_OPTS - if use_cuda_graphs: - return do_bench_kineto_cudagraph(fn, warmup, grad_to_none, profile_opts, output_dir) import torch fn() @@ -126,6 +125,9 @@ def do_bench_kineto( # compute number of warmup and repeat n_warmup = max(1, int(warmup / estimate_ms)) + if use_cuda_graphs: + return do_bench_kineto_cudagraph(fn, n_warmup, grad_to_none, profile_opts, output_dir) + activity_groups = [ profiler.ProfilerActivity.CUDA, profiler.ProfilerActivity.CPU, From 344b2632bae80c9bfe79a4d07a3e8f5f758e2378 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Sat, 7 Dec 2024 13:48:48 -0500 Subject: [PATCH 7/7] Bugfix --- tritonbench/components/kineto/trace.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/tritonbench/components/kineto/trace.py b/tritonbench/components/kineto/trace.py index a2fa6586..fb661c56 100644 --- a/tritonbench/components/kineto/trace.py +++ b/tritonbench/components/kineto/trace.py @@ -19,20 +19,15 @@ from .fb.run_utils import trace_handler -def do_bench_kineto_cudagraph(fn, warmup, grad_to_none, profile_opts, output_dir) -> str: +def do_bench_kineto_cudagraph( + fn, warmup, grad_to_none, profile_opts, output_dir +) -> str: activity_groups = [ profiler.ProfilerActivity.CUDA, profiler.ProfilerActivity.CPU, ] with torch.cuda.stream(torch.cuda.Stream()): - # step 1 - warmup - fn() - if grad_to_none is not None: - for x in grad_to_none: - x.detach_() - x.requires_grad_(True) - x.grad = None - # step 2 - construct a cuda graph + # step 1 - construct a cuda graph g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): if grad_to_none is not None: @@ -42,7 +37,7 @@ def do_bench_kineto_cudagraph(fn, warmup, grad_to_none, profile_opts, output_dir torch.cuda.synchronize() prefix = f"tritonbench_cudagraph_{fn._name}" name = f"{prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{''.join(random.choices(string.digits, k=10))}.json" - # step 3 - profile cuda graph with kineto + # step 2 - profile cuda graph launch with kineto with profiler.profile( schedule=profiler.schedule(wait=0, warmup=warmup, active=1, repeat=1), activities=activity_groups, @@ -71,6 +66,7 @@ def do_bench_kineto_cudagraph(fn, warmup, grad_to_none, profile_opts, output_dir else: return output_dir + def do_bench_kineto( fn: Callable, warmup=25, @@ -126,7 +122,9 @@ def do_bench_kineto( # compute number of warmup and repeat n_warmup = max(1, int(warmup / estimate_ms)) if use_cuda_graphs: - return do_bench_kineto_cudagraph(fn, n_warmup, grad_to_none, profile_opts, output_dir) + return do_bench_kineto_cudagraph( + fn, n_warmup, grad_to_none, profile_opts, output_dir + ) activity_groups = [ profiler.ProfilerActivity.CUDA,