diff --git a/docs/kineto_trace.md b/docs/kineto_trace.md
new file mode 100644
index 00000000..e38a2cde
--- /dev/null
+++ b/docs/kineto_trace.md
@@ -0,0 +1,49 @@
+# Kineto Trace Analysis with TritonBench
+
+TritonBench supports generating a Kineto trace file for each `(input, impl)` pair.
+For example, the following command generates 6 Kineto traces, as it runs 2 inputs (`--num-inputs 2`) with 3 impls (`flash_v3,cudnn,triton_tutorial_flash_v2`).
+
+```
+$ python run.py --op flash_attention --num-inputs 2 --metrics kineto_trace --only flash_v3,cudnn,triton_tutorial_flash_v2
+
+  (Batch, Heads, SeqLen, Dhead)  flash_v3-kineto_trace                                       cudnn_90100-kineto_trace                                 triton_tutorial_flash_v2-kineto_trace
+-------------------------------  ----------------------------------------------------------  -------------------------------------------------------  ---------------------------------------------------------------------------
+               (4, 48, 128, 64)  /tmp/tritonbench/flash_attention/kineto_traces/flash_v3_0  /tmp/tritonbench/flash_attention/kineto_traces/cudnn_0  /tmp/tritonbench/flash_attention/kineto_traces/triton_tutorial_flash_v2_0
+               (4, 48, 256, 64)  /tmp/tritonbench/flash_attention/kineto_traces/flash_v3_1  /tmp/tritonbench/flash_attention/kineto_traces/cudnn_1  /tmp/tritonbench/flash_attention/kineto_traces/triton_tutorial_flash_v2_1
+```
+
+Each cell of the output table shows the directory where the corresponding Kineto trace file is stored.
+
+## Example Kineto Trace Analysis
+
+When opening the trace file with Chrome Trace Viewer, we first need to separate the profiling iteration from the warm-up iterations.
+The profiling iteration runs after all warm-up iterations and is labeled with `ProfilerStep#`.
+
+![Kineto Trace](https://ossci-datasets.s3.us-east-1.amazonaws.com/tritonbench/docs/_static/img/kineto_trace_fig_1.png "Kineto Trace - Global View")
+
+Zooming into the profiling iteration, we find two GPU kernels launched. The first one corresponds to the L2 cache flush that clears the cache.
+The second one corresponds to the actual computation kernel, which comes from cuDNN in this flash_attention operator.
+
+![Kineto Trace](https://ossci-datasets.s3.us-east-1.amazonaws.com/tritonbench/docs/_static/img/kineto_trace_fig_2.png "Kineto Trace - Zoomed into Profile Iteration")
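+
+The trace files are standard Chrome trace JSON, so they can also be inspected programmatically. Below is a minimal sketch (not part of TritonBench); it assumes the `cudnn_0` output directory from the table above and lists the GPU kernel events captured in the trace, sorted by duration.
+
+```python
+import glob
+import json
+import os
+
+# Example output directory taken from the table above; adjust for your own run.
+trace_dir = "/tmp/tritonbench/flash_attention/kineto_traces/cudnn_0"
+trace_file = glob.glob(os.path.join(trace_dir, "*.json"))[0]
+
+with open(trace_file) as f:
+    events = json.load(f)["traceEvents"]
+
+# GPU kernel events use the "kernel" category ("Kernel" in older PyTorch versions);
+# "ph" == "X" marks complete events, and "dur" is the duration in microseconds.
+kernels = [e for e in events if e.get("ph") == "X" and e.get("cat", "").lower() == "kernel"]
+for k in sorted(kernels, key=lambda e: e["dur"], reverse=True):
+    print(f'{k["dur"]:>12.1f} us  {k["name"]}')
+```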
+
diff --git a/tritonbench/components/kineto/trace.py b/tritonbench/components/kineto/trace.py
index 25b0b34c..fb661c56 100644
--- a/tritonbench/components/kineto/trace.py
+++ b/tritonbench/components/kineto/trace.py
@@ -19,6 +19,54 @@ from .fb.run_utils import trace_handler
 
 
+def do_bench_kineto_cudagraph(
+    fn, warmup, grad_to_none, profile_opts, output_dir
+) -> str:
+    activity_groups = [
+        profiler.ProfilerActivity.CUDA,
+        profiler.ProfilerActivity.CPU,
+    ]
+    with torch.cuda.stream(torch.cuda.Stream()):
+        # step 1 - construct a cuda graph
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            if grad_to_none is not None:
+                for x in grad_to_none:
+                    x.grad = None
+            fn()
+        torch.cuda.synchronize()
+        prefix = f"tritonbench_cudagraph_{fn._name}"
+        name = f"{prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{''.join(random.choices(string.digits, k=10))}.json"
+        # step 2 - profile cuda graph launch with kineto
+        with profiler.profile(
+            schedule=profiler.schedule(wait=0, warmup=warmup, active=1, repeat=1),
+            activities=activity_groups,
+            record_shapes=profile_opts["record_shapes"],
+            profile_memory=profile_opts["profile_memory"],
+            with_stack=profile_opts["with_stack"],
+            with_flops=profile_opts["with_flops"],
+            with_modules=profile_opts["with_modules"],
+            on_trace_ready=(
+                partial(trace_handler, name)
+                if not hasattr(torch.version, "git_version")
+                else profiler.tensorboard_trace_handler(output_dir)
+            ),
+        ) as prof:
+            for _i in range(warmup + 1):
+                # we don't want `fn` to accumulate gradient values
+                # if it contains a backward pass. So we clear the
+                # provided gradients
+                if grad_to_none is not None:
+                    for x in grad_to_none:
+                        x.grad = None
+                g.replay()
+                prof.step()
+    if not hasattr(torch.version, "git_version"):
+        return f"https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/{name}.gz&bucket=pyper_traces"
+    else:
+        return output_dir
+
+
 def do_bench_kineto(
     fn: Callable,
     warmup=25,
@@ -26,6 +74,7 @@ def do_bench_kineto(
     fast_flush=True,
     profile_opts=None,
     output_dir=None,
+    use_cuda_graphs: bool = False,
 ) -> str:
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
@@ -44,6 +93,8 @@ def do_bench_kineto(
     :param output_dir: Output directory to store the trace
     :type output_dir: str, optional
     """
+    if profile_opts is None:
+        profile_opts = DEFAULT_PROFILE_OPTS
     import torch
 
     fn()
@@ -70,12 +121,15 @@ def do_bench_kineto(
 
     # compute number of warmup and repeat
     n_warmup = max(1, int(warmup / estimate_ms))
+    if use_cuda_graphs:
+        return do_bench_kineto_cudagraph(
+            fn, n_warmup, grad_to_none, profile_opts, output_dir
+        )
+
     activity_groups = [
         profiler.ProfilerActivity.CUDA,
         profiler.ProfilerActivity.CPU,
     ]
-    if profile_opts is None:
-        profile_opts = DEFAULT_PROFILE_OPTS
     prefix = f"tritonbench_{fn._name}"
     name = f"{prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{''.join(random.choices(string.digits, k=10))}.json"
     with profiler.profile(
@@ -92,8 +146,6 @@
             else profiler.tensorboard_trace_handler(output_dir)
         ),
     ) as prof:
-        start_event = torch.cuda.Event(enable_timing=True)
-        end_event = torch.cuda.Event(enable_timing=True)
         for i in range(n_warmup + 1):
             # we don't want `fn` to accumulate gradient values
             # if it contains a backward pass. So we clear the
diff --git a/tritonbench/operators/cross_entropy/operator.py b/tritonbench/operators/cross_entropy/operator.py
index a32ab277..96192384 100644
--- a/tritonbench/operators/cross_entropy/operator.py
+++ b/tritonbench/operators/cross_entropy/operator.py
@@ -29,7 +29,6 @@ def __init__(
         self.T = 2048
         self.baseline_model = CrossEntropyLoss()
         self.liger_model = LigerCrossEntropyLoss()
-        self.use_cuda_graphs = False
 
     def get_input_iter(self) -> Generator:
         for V in [2**i for i in range(12, 18)]:
diff --git a/tritonbench/operators/embedding/operator.py b/tritonbench/operators/embedding/operator.py
index 7779c082..6ed8aec6 100644
--- a/tritonbench/operators/embedding/operator.py
+++ b/tritonbench/operators/embedding/operator.py
@@ -27,7 +27,6 @@ def __init__(
         # they are generated later
         self.baseline_op = None
         self.liger_op = None
-        self.use_cuda_graphs = False
 
     def get_input_iter(self) -> Generator:
         for B, T, D in [(32, 512, 768), (8, 2048, 4096)]:
diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py
index e92127a2..56ea8657 100644
--- a/tritonbench/operators/flash_attention/operator.py
+++ b/tritonbench/operators/flash_attention/operator.py
@@ -168,9 +168,7 @@ def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
     ):
         super().__init__(tb_args, extra_args)
-        self.use_cuda_graphs = False
         args = parse_op_args(self.extra_args)
-        self.use_cuda_graphs = False
         self.BATCH = args.batch
         self.SEQ_LEN = args.seq_len
         self.H = args.n_heads
diff --git a/tritonbench/operators/fused_linear_cross_entropy/operator.py b/tritonbench/operators/fused_linear_cross_entropy/operator.py
index 9f359345..4503d917 100644
--- a/tritonbench/operators/fused_linear_cross_entropy/operator.py
+++ b/tritonbench/operators/fused_linear_cross_entropy/operator.py
@@ -78,7 +78,6 @@ def __init__(
         self.liger_model = LigerLMHeadCE(
             H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
         ).to(self.device)
-        self.use_cuda_graphs = False
 
     def get_input_iter(self) -> Generator:
         for BT in [2**i for i in range(12, 16)]:
diff --git a/tritonbench/operators/fused_linear_jsd/operator.py b/tritonbench/operators/fused_linear_jsd/operator.py
index 7ebdcc3b..c26d3f0c 100644
--- a/tritonbench/operators/fused_linear_jsd/operator.py
+++ b/tritonbench/operators/fused_linear_jsd/operator.py
@@ -146,8 +146,6 @@ def __init__(
             self.liger_op.teacher_lin.weight.data
         ) = torch.rand(self.V, self.H, device=self.device, dtype=self.dtype)
 
-        self.use_cuda_graphs = False
-
     def get_input_iter(self) -> Generator:
         for BT in [2**i for i in range(10, 14)]:
             student_input = torch.rand(
diff --git a/tritonbench/operators/geglu/operator.py b/tritonbench/operators/geglu/operator.py
index 9613bc96..640e4185 100644
--- a/tritonbench/operators/geglu/operator.py
+++ b/tritonbench/operators/geglu/operator.py
@@ -40,7 +40,6 @@ def __init__(
         self.liger_model = (
             LigerGEGLUMLP(self.llama_config).to(self.device).to(self.dtype)
         )
-        self.use_cuda_graphs = False
 
     def get_input_iter(self) -> Generator:
         for T in [2**i for i in range(10, 14)]:
diff --git a/tritonbench/operators/gemm/operator.py b/tritonbench/operators/gemm/operator.py
index 4082a814..d594c6e4 100644
--- a/tritonbench/operators/gemm/operator.py
+++ b/tritonbench/operators/gemm/operator.py
@@ -141,7 +141,6 @@ def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
     ):
         super().__init__(tb_args, extra_args)
-        self.use_cuda_graphs = False
         gemm_args = parse_args(self.extra_args)
         self.layout = gemm_args.layout
         if IS_FBCODE and tb_args.production_shapes:
diff --git a/tritonbench/operators/grouped_gemm/operator.py b/tritonbench/operators/grouped_gemm/operator.py
index 0ab2604c..0b695d45 100644
--- a/tritonbench/operators/grouped_gemm/operator.py
+++ b/tritonbench/operators/grouped_gemm/operator.py
@@ -18,7 +18,6 @@ class Operator(BenchmarkOperator):
 
     DEFAULT_PRECISION = "fp16"
     DEFAULT_METRICS = ["latency", "speedup", "accuracy"]
-    use_cuda_graphs = False
 
     @register_benchmark(baseline=True)
     def torch(self, group_A, group_B):
diff --git a/tritonbench/operators/jagged_layer_norm/operator.py b/tritonbench/operators/jagged_layer_norm/operator.py
index d34f4896..36252c58 100644
--- a/tritonbench/operators/jagged_layer_norm/operator.py
+++ b/tritonbench/operators/jagged_layer_norm/operator.py
@@ -40,10 +40,6 @@ class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["latency", "accuracy"]
     DEFAULT_PRECISION = "fp32"
 
-    use_cuda_graphs = (
-        False  # allows for a GPU/CPU sync, caused by methods like torch.unbind
-    )
-
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
     ):
diff --git a/tritonbench/operators/jagged_mean/operator.py b/tritonbench/operators/jagged_mean/operator.py
index c9c975d2..cbcc8b74 100644
--- a/tritonbench/operators/jagged_mean/operator.py
+++ b/tritonbench/operators/jagged_mean/operator.py
@@ -96,10 +96,6 @@ class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["latency", "accuracy"]
     DEFAULT_PRECISION = "fp32"
 
-    use_cuda_graphs = (
-        False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
-    )
-
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
     ):
diff --git a/tritonbench/operators/jagged_softmax/operator.py b/tritonbench/operators/jagged_softmax/operator.py
index cc3284c0..02299d75 100644
--- a/tritonbench/operators/jagged_softmax/operator.py
+++ b/tritonbench/operators/jagged_softmax/operator.py
@@ -75,10 +75,6 @@ class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["latency", "accuracy", "best_config"]
     DEFAULT_PRECISION = "fp32"
 
-    use_cuda_graphs = (
-        False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
-    )
-
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
     ):
diff --git a/tritonbench/operators/jagged_sum/operator.py b/tritonbench/operators/jagged_sum/operator.py
index 36d519db..97d9506f 100644
--- a/tritonbench/operators/jagged_sum/operator.py
+++ b/tritonbench/operators/jagged_sum/operator.py
@@ -95,9 +95,6 @@ def execute_kernel_variable_length_loop(x, sum_then_buffer):
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["latency", "accuracy", "best_config"]
     DEFAULT_PRECISION = "fp32"
-    use_cuda_graphs = (
-        False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
-    )
 
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
diff --git a/tritonbench/operators/jsd/operator.py b/tritonbench/operators/jsd/operator.py
index 5a42f294..06b863b8 100644
--- a/tritonbench/operators/jsd/operator.py
+++ b/tritonbench/operators/jsd/operator.py
@@ -65,7 +65,6 @@ def __init__(
         self.T = 2048
         self.baseline_op = TorchJSD()
         self.liger_op = LigerJSD()
-        self.use_cuda_graphs = False
 
     def get_input_iter(self) -> Generator:
         for V in [2**i for i in range(12, 18)]:
diff --git a/tritonbench/operators/kl_div/operator.py b/tritonbench/operators/kl_div/operator.py
index 0d600cce..24fcef4e 100644
--- a/tritonbench/operators/kl_div/operator.py
+++ b/tritonbench/operators/kl_div/operator.py
@@ -27,7 +27,6 @@ def __init__(
         self.T = 512
         self.baseline_op = torch.nn.KLDivLoss(reduction="batchmean").to(self.device)
         self.liger_op = LigerKLDIVLoss(reduction="batchmean").to(self.device)
-        self.use_cuda_graphs = False
 
     def get_input_iter(self) -> Generator:
         for V in [2**i for i in range(12, 18)]:
diff --git a/tritonbench/operators/rms_norm/operator.py b/tritonbench/operators/rms_norm/operator.py
index 0c62d39d..adde1f48 100644
--- a/tritonbench/operators/rms_norm/operator.py
+++ b/tritonbench/operators/rms_norm/operator.py
@@ -45,7 +45,6 @@ def __init__(
         # they are generated later
         self.llama_rms_op = None
         self.liger_rms_op = None
-        self.use_cuda_graphs = False
 
     def get_input_iter(self) -> Generator:
         for H in [2**i for i in range(10, 16)]:
diff --git a/tritonbench/operators/rope/operator.py b/tritonbench/operators/rope/operator.py
index 174626ac..ac9073e6 100644
--- a/tritonbench/operators/rope/operator.py
+++ b/tritonbench/operators/rope/operator.py
@@ -30,7 +30,6 @@ def __init__(
         # they are generated later
         self.baseline_op = None
         self.liger_op = None
-        self.use_cuda_graphs = False
 
         self.num_q_heads = 32
         self.num_kv_heads = 8
diff --git a/tritonbench/operators/swiglu/operator.py b/tritonbench/operators/swiglu/operator.py
index b21fede9..ab53ab76 100644
--- a/tritonbench/operators/swiglu/operator.py
+++ b/tritonbench/operators/swiglu/operator.py
@@ -39,7 +39,6 @@ def __init__(
         self.liger_op = (
             LigerSwiGLUMLP(config=llama_config).to(self.device).to(self.dtype)
         )
-        self.use_cuda_graphs = False
 
     def get_input_iter(self) -> Generator:
         for seq_len in [2**i for i in range(10, 14)]:
diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py
index 0a0aa962..fea2ed43 100644
--- a/tritonbench/utils/triton_op.py
+++ b/tritonbench/utils/triton_op.py
@@ -1470,6 +1470,7 @@ def kineto_trace(self, input_id: int, fn: Callable) -> str:
             fn=fn,
             grad_to_none=self.get_grad_to_none(self.example_inputs),
             output_dir=kineto_output_dir,
+            use_cuda_graphs=self.use_cuda_graphs,
         )
 
     def compile_time(