
[metrics] Enable cudagraph mode for kineto_trace #106

Closed · wants to merge 7 commits
28 changes: 28 additions & 0 deletions docs/kineto_trace.md
@@ -0,0 +1,28 @@
# Kineto Trace Analysis with TritonBench

TritonBench supports generating a Kineto trace file for each `<input, impl>` pair.
For example, the following command generates 6 Kineto traces, because it runs 2 inputs (`--num-inputs 2`) with 3 impls (`flash_v3,cudnn,triton_tutorial_flash_v2`).

```
$ python run.py --op flash_attention --num-inputs 2 --metrics kineto_trace --only flash_v3,cudnn,triton_tutorial_flash_v2

(Batch, Heads, SeqLen, Dhead) flash_v3-kineto_trace cudnn_90100-kineto_trace triton_tutorial_flash_v2-kineto_trace
------------------------------- --------------------------------------------------------- ------------------------------------------------------ -------------------------------------------------------------------------
(4, 48, 128, 64) /tmp/tritonbench/flash_attention/kineto_traces/flash_v3_0 /tmp/tritonbench/flash_attention/kineto_traces/cudnn_0 /tmp/tritonbench/flash_attention/kineto_traces/triton_tutorial_flash_v2_0
(4, 48, 256, 64) /tmp/tritonbench/flash_attention/kineto_traces/flash_v3_1 /tmp/tritonbench/flash_attention/kineto_traces/cudnn_1 /tmp/tritonbench/flash_attention/kineto_traces/triton_tutorial_flash_v2_1
```

Each cell of the output table shows the directory where the corresponding Kineto trace file is stored.
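
A trace can also be inspected programmatically instead of in the trace viewer. Below is a minimal sketch, assuming the directory path from the table above and that the directory holds a single Chrome-trace JSON file (possibly gzipped) written by `torch.profiler.tensorboard_trace_handler`:

```
import glob
import gzip
import json
import os

# One of the directories reported in the output table above.
trace_dir = "/tmp/tritonbench/flash_attention/kineto_traces/cudnn_0"

# Pick up the (possibly gzipped) Chrome-trace JSON file inside it.
candidates = glob.glob(os.path.join(trace_dir, "*.json")) + glob.glob(
    os.path.join(trace_dir, "*.json.gz")
)
path = candidates[0]
opener = gzip.open if path.endswith(".gz") else open
with opener(path, "rt") as f:
    trace = json.load(f)

print(f"{len(trace['traceEvents'])} trace events in {path}")
```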

## Example Kineto Trace Analysis

When opening a trace file in the Chrome Trace Viewer, we first need to separate the profiling iteration from the warm-up iterations.
The profiling iteration runs after all warm-up iterations and is labeled `ProfilerStep#<number>`.

![Kineto Trace](https://ossci-datasets.s3.us-east-1.amazonaws.com/tritonbench/docs/_static/img/kineto_trace_fig_1.png "Kineto Trace - Global View")
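
The same separation can be scripted. Here is a minimal sketch, assuming the trace has been unzipped to a hypothetical `trace.json` and follows the standard Chrome-trace event layout that Kineto emits:

```
import json

with open("trace.json") as f:  # hypothetical path to an unzipped trace file
    trace = json.load(f)

# ProfilerStep spans are recorded as Chrome "complete" events (ph == "X").
steps = [
    e
    for e in trace["traceEvents"]
    if e.get("ph") == "X" and str(e.get("name", "")).startswith("ProfilerStep#")
]
for s in steps:
    print(s["name"], f"start={s['ts']}us", f"dur={s['dur']}us")
```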

Zooming into the profiling iteration, we find two GPU kernels. The first corresponds to the L2 cache flush that clears the cache before the measured run;
the second is the actual computation kernel, which comes from cuDNN in this flash_attention operator.

![Kineto Trace](https://ossci-datasets.s3.us-east-1.amazonaws.com/tritonbench/docs/_static/img/kineto_trace_fig_2.png "Kineto Trace - Zoomed into Profile Iteration")
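
Along the same lines, the kernels inside that window can be listed without the viewer. Another minimal sketch under the same assumptions (GPU kernel events may end slightly after the CPU-side `ProfilerStep` span, so the timestamp filter is approximate):

```
import json

with open("trace.json") as f:  # hypothetical path, as in the previous sketch
    events = json.load(f)["traceEvents"]

step = next(
    e
    for e in events
    if e.get("ph") == "X" and str(e.get("name", "")).startswith("ProfilerStep#")
)
lo, hi = step["ts"], step["ts"] + step["dur"]

# Kineto tags GPU kernels with the "kernel" category.
for e in events:
    if e.get("ph") == "X" and e.get("cat") == "kernel" and lo <= e.get("ts", lo - 1) <= hi:
        print(f"{e['dur']:>10.1f} us  {e['name']}")
```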

60 changes: 56 additions & 4 deletions tritonbench/components/kineto/trace.py
@@ -19,13 +19,62 @@
from .fb.run_utils import trace_handler


def do_bench_kineto_cudagraph(
    fn, warmup, grad_to_none, profile_opts, output_dir
) -> str:
    activity_groups = [
        profiler.ProfilerActivity.CUDA,
        profiler.ProfilerActivity.CPU,
    ]
    with torch.cuda.stream(torch.cuda.Stream()):
        # step 1 - construct a cuda graph
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            if grad_to_none is not None:
                for x in grad_to_none:
                    x.grad = None
            fn()
        torch.cuda.synchronize()
        prefix = f"tritonbench_cudagraph_{fn._name}"
        name = f"{prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{''.join(random.choices(string.digits, k=10))}.json"
        # step 2 - profile cuda graph launch with kineto
        with profiler.profile(
            schedule=profiler.schedule(wait=0, warmup=warmup, active=1, repeat=1),
            activities=activity_groups,
            record_shapes=profile_opts["record_shapes"],
            profile_memory=profile_opts["profile_memory"],
            with_stack=profile_opts["with_stack"],
            with_flops=profile_opts["with_flops"],
            with_modules=profile_opts["with_modules"],
            on_trace_ready=(
                partial(trace_handler, name)
                if not hasattr(torch.version, "git_version")
                else profiler.tensorboard_trace_handler(output_dir)
            ),
        ) as prof:
            for _i in range(warmup + 1):
                # we don't want `fn` to accumulate gradient values
                # if it contains a backward pass. So we clear the
                # provided gradients
                if grad_to_none is not None:
                    for x in grad_to_none:
                        x.grad = None
                g.replay()
                prof.step()
    if not hasattr(torch.version, "git_version"):
        return f"https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/{name}.gz&bucket=pyper_traces"
    else:
        return output_dir
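
For orientation, a minimal usage sketch (illustrative only, not part of the change): the workload, output directory, and `_name` value are hypothetical, and `fn._name` is set because the trace-file prefix above is derived from it.

```
import torch

from tritonbench.components.kineto.trace import do_bench_kineto

def fn():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    torch.mm(a, b)

fn._name = "mm_example"  # used for the trace-file prefix

trace_location = do_bench_kineto(
    fn,
    warmup=25,
    profile_opts=None,  # falls back to DEFAULT_PROFILE_OPTS
    output_dir="/tmp/tritonbench_traces",  # hypothetical output directory
    use_cuda_graphs=True,  # routes through do_bench_kineto_cudagraph
)
print(trace_location)
```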


def do_bench_kineto(
    fn: Callable,
    warmup=25,
    grad_to_none=None,
    fast_flush=True,
    profile_opts=None,
    output_dir=None,
    use_cuda_graphs: bool = False,
) -> str:
    """
    Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
@@ -44,6 +93,8 @@ def do_bench_kineto(
    :param output_dir: Output directory to store the trace
    :type output_dir: str, optional
    """
    if profile_opts is None:
        profile_opts = DEFAULT_PROFILE_OPTS
    import torch

    fn()
@@ -70,12 +121,15 @@

    # compute number of warmup and repeat
    n_warmup = max(1, int(warmup / estimate_ms))
    if use_cuda_graphs:
        return do_bench_kineto_cudagraph(
            fn, n_warmup, grad_to_none, profile_opts, output_dir
        )

    activity_groups = [
        profiler.ProfilerActivity.CUDA,
        profiler.ProfilerActivity.CPU,
    ]
    if profile_opts is None:
        profile_opts = DEFAULT_PROFILE_OPTS
    prefix = f"tritonbench_{fn._name}"
    name = f"{prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{''.join(random.choices(string.digits, k=10))}.json"
    with profiler.profile(
@@ -92,8 +146,6 @@ ) as prof:
            else profiler.tensorboard_trace_handler(output_dir)
        ),
    ) as prof:
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        for i in range(n_warmup + 1):
            # we don't want `fn` to accumulate gradient values
            # if it contains a backward pass. So we clear the
1 change: 0 additions & 1 deletion tritonbench/operators/cross_entropy/operator.py
@@ -29,7 +29,6 @@ def __init__(
        self.T = 2048
        self.baseline_model = CrossEntropyLoss()
        self.liger_model = LigerCrossEntropyLoss()
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for V in [2**i for i in range(12, 18)]:
1 change: 0 additions & 1 deletion tritonbench/operators/embedding/operator.py
@@ -27,7 +27,6 @@ def __init__(
        # they are generated later
        self.baseline_op = None
        self.liger_op = None
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for B, T, D in [(32, 512, 768), (8, 2048, 4096)]:
2 changes: 0 additions & 2 deletions tritonbench/operators/flash_attention/operator.py
@@ -168,9 +168,7 @@ def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        super().__init__(tb_args, extra_args)
        self.use_cuda_graphs = False
        args = parse_op_args(self.extra_args)
        self.use_cuda_graphs = False
        self.BATCH = args.batch
        self.SEQ_LEN = args.seq_len
        self.H = args.n_heads
@@ -78,7 +78,6 @@ def __init__(
        self.liger_model = LigerLMHeadCE(
            H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
        ).to(self.device)
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for BT in [2**i for i in range(12, 16)]:
2 changes: 0 additions & 2 deletions tritonbench/operators/fused_linear_jsd/operator.py
@@ -146,8 +146,6 @@ def __init__(
            self.liger_op.teacher_lin.weight.data
        ) = torch.rand(self.V, self.H, device=self.device, dtype=self.dtype)

        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for BT in [2**i for i in range(10, 14)]:
            student_input = torch.rand(
1 change: 0 additions & 1 deletion tritonbench/operators/geglu/operator.py
@@ -40,7 +40,6 @@ def __init__(
        self.liger_model = (
            LigerGEGLUMLP(self.llama_config).to(self.device).to(self.dtype)
        )
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for T in [2**i for i in range(10, 14)]:
1 change: 0 additions & 1 deletion tritonbench/operators/gemm/operator.py
@@ -141,7 +141,6 @@ def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        super().__init__(tb_args, extra_args)
        self.use_cuda_graphs = False
        gemm_args = parse_args(self.extra_args)
        self.layout = gemm_args.layout
        if IS_FBCODE and tb_args.production_shapes:
1 change: 0 additions & 1 deletion tritonbench/operators/grouped_gemm/operator.py
@@ -18,7 +18,6 @@
class Operator(BenchmarkOperator):
    DEFAULT_PRECISION = "fp16"
    DEFAULT_METRICS = ["latency", "speedup", "accuracy"]
    use_cuda_graphs = False

    @register_benchmark(baseline=True)
    def torch(self, group_A, group_B):
4 changes: 0 additions & 4 deletions tritonbench/operators/jagged_layer_norm/operator.py
@@ -40,10 +40,6 @@ class Operator(BenchmarkOperator):
    DEFAULT_METRICS = ["latency", "accuracy"]
    DEFAULT_PRECISION = "fp32"

    use_cuda_graphs = (
        False  # allows for a GPU/CPU sync, caused by methods like torch.unbind
    )

    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
4 changes: 0 additions & 4 deletions tritonbench/operators/jagged_mean/operator.py
@@ -96,10 +96,6 @@ class Operator(BenchmarkOperator):
    DEFAULT_METRICS = ["latency", "accuracy"]
    DEFAULT_PRECISION = "fp32"

    use_cuda_graphs = (
        False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
    )

    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
4 changes: 0 additions & 4 deletions tritonbench/operators/jagged_softmax/operator.py
@@ -75,10 +75,6 @@ class Operator(BenchmarkOperator):
    DEFAULT_METRICS = ["latency", "accuracy", "best_config"]
    DEFAULT_PRECISION = "fp32"

    use_cuda_graphs = (
        False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
    )

    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
3 changes: 0 additions & 3 deletions tritonbench/operators/jagged_sum/operator.py
@@ -95,9 +95,6 @@ def execute_kernel_variable_length_loop(x, sum_then_buffer):
class Operator(BenchmarkOperator):
    DEFAULT_METRICS = ["latency", "accuracy", "best_config"]
    DEFAULT_PRECISION = "fp32"
    use_cuda_graphs = (
        False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
    )

    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
1 change: 0 additions & 1 deletion tritonbench/operators/jsd/operator.py
@@ -65,7 +65,6 @@ def __init__(
        self.T = 2048
        self.baseline_op = TorchJSD()
        self.liger_op = LigerJSD()
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for V in [2**i for i in range(12, 18)]:
1 change: 0 additions & 1 deletion tritonbench/operators/kl_div/operator.py
@@ -27,7 +27,6 @@ def __init__(
        self.T = 512
        self.baseline_op = torch.nn.KLDivLoss(reduction="batchmean").to(self.device)
        self.liger_op = LigerKLDIVLoss(reduction="batchmean").to(self.device)
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for V in [2**i for i in range(12, 18)]:
1 change: 0 additions & 1 deletion tritonbench/operators/rms_norm/operator.py
@@ -45,7 +45,6 @@ def __init__(
        # they are generated later
        self.llama_rms_op = None
        self.liger_rms_op = None
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for H in [2**i for i in range(10, 16)]:
1 change: 0 additions & 1 deletion tritonbench/operators/rope/operator.py
@@ -30,7 +30,6 @@ def __init__(
        # they are generated later
        self.baseline_op = None
        self.liger_op = None
        self.use_cuda_graphs = False
        self.num_q_heads = 32
        self.num_kv_heads = 8

1 change: 0 additions & 1 deletion tritonbench/operators/swiglu/operator.py
@@ -39,7 +39,6 @@ def __init__(
        self.liger_op = (
            LigerSwiGLUMLP(config=llama_config).to(self.device).to(self.dtype)
        )
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for seq_len in [2**i for i in range(10, 14)]:
1 change: 1 addition & 0 deletions tritonbench/utils/triton_op.py
@@ -1470,6 +1470,7 @@ def kineto_trace(self, input_id: int, fn: Callable) -> str:
            fn=fn,
            grad_to_none=self.get_grad_to_none(self.example_inputs),
            output_dir=kineto_output_dir,
            use_cuda_graphs=self.use_cuda_graphs,
        )
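
With this change, the `kineto_trace` metric follows whatever `use_cuda_graphs` setting the operator is configured with. As a usage sketch (the `--cudagraph` flag name is an assumption about the existing CLI and may differ), enabling cudagraph mode for the trace would look like:

```
$ python run.py --op flash_attention --num-inputs 1 --metrics kineto_trace --cudagraph --only cudnn
```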

    def compile_time(