Enable cudagraph mode for kineto_trace (#106)
Summary:
- Remove the per-operator `use_cuda_graphs = False` overrides, since non-cudagraph mode is now the default; many operators do not support cudagraph mode.
- Enable `--cudagraph --metrics kineto_trace` to turn on CUDA graph mode for the Kineto trace (see the example command below).
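
For example, an illustrative invocation that combines the new flag with the flags already shown in `docs/kineto_trace.md` (exact operator and impl choices are just for illustration):

```
$ python run.py --op flash_attention --num-inputs 2 --metrics kineto_trace --cudagraph --only triton_tutorial_flash_v2
```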

Pull Request resolved: #106

Reviewed By: FindHao

Differential Revision: D66923962

Pulled By: xuzhao9

fbshipit-source-id: 6942f1404aab1738a4262d517782cebd9e95e132
xuzhao9 authored and facebook-github-bot committed Dec 9, 2024
1 parent 87dffcc commit 1fb46a3
Showing 20 changed files with 85 additions and 34 deletions.
28 changes: 28 additions & 0 deletions docs/kineto_trace.md
@@ -0,0 +1,28 @@
# Kineto Trace Analysis with TritonBench

TritonBench supports generating a Kineto trace file for each `<input, impl>` pair.
For example, the following command generates 6 Kineto traces, since it runs 2 inputs (`--num-inputs 2`) against 3 impls (`flash_v3,cudnn,triton_tutorial_flash_v2`).

```
$ python run.py --op flash_attention --num-inputs 2 --metrics kineto_trace --only flash_v3,cudnn,triton_tutorial_flash_v2
(Batch, Heads, SeqLen, Dhead) flash_v3-kineto_trace cudnn_90100-kineto_trace triton_tutorial_flash_v2-kineto_trace
------------------------------- --------------------------------------------------------- ------------------------------------------------------ -------------------------------------------------------------------------
(4, 48, 128, 64) /tmp/tritonbench/flash_attention/kineto_traces/flash_v3_0 /tmp/tritonbench/flash_attention/kineto_traces/cudnn_0 /tmp/tritonbench/flash_attention/kineto_traces/triton_tutorial_flash_v2_0
(4, 48, 256, 64) /tmp/tritonbench/flash_attention/kineto_traces/flash_v3_1 /tmp/tritonbench/flash_attention/kineto_traces/cudnn_1 /tmp/tritonbench/flash_attention/kineto_traces/triton_tutorial_flash_v2_1
```

Each path in the output table is the directory where the corresponding Kineto trace file is stored.
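
A trace can also be inspected outside of a viewer, since Kineto writes the Chrome trace format (plain JSON). The snippet below is only an illustrative sketch, not part of TritonBench: the directory path is hypothetical (taken from the table layout above), and it assumes GPU kernels are tagged with the `"kernel"` category.

```python
import json
from pathlib import Path

# Hypothetical output directory, following the layout shown in the table above.
trace_dir = Path("/tmp/tritonbench/flash_attention/kineto_traces/flash_v3_0")
trace_file = next(trace_dir.glob("*.json"))  # the trace file written into that directory

with open(trace_file) as f:
    trace = json.load(f)

# List the GPU kernel events recorded during the profiled step.
kernels = [e for e in trace.get("traceEvents", []) if e.get("cat") == "kernel"]
print(f"{len(kernels)} GPU kernel events")
for k in kernels:
    print(k["name"], k.get("dur", 0), "us")
```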

## Example Kineto Trace Analysis

When opening the trace file with the Chrome Trace Viewer, we first need to separate the profiling iteration from the warm-up iterations.
The profiling iteration runs after all warm-up iterations and is labeled `ProfilerStep#<number>`.
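
The `ProfilerStep#` annotation comes from the profiler schedule. Below is a simplified, stand-alone sketch of the scheduling pattern used in `trace.py` later in this commit (the workload `fn` and the warm-up count are placeholders): only the final, active step is recorded, and that step is the one labeled `ProfilerStep#<number>` in the trace.

```python
import torch
from torch import profiler

def fn():
    # placeholder workload
    torch.square(torch.randn(1024, 1024, device="cuda"))

n_warmup = 3
with profiler.profile(
    schedule=profiler.schedule(wait=0, warmup=n_warmup, active=1, repeat=1),
    activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA],
) as prof:
    for _ in range(n_warmup + 1):
        fn()
        prof.step()  # advance the schedule; the last step is the profiled one
```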

![Kineto Trace](https://ossci-datasets.s3.us-east-1.amazonaws.com/tritonbench/docs/_static/img/kineto_trace_fig_1.png "Kineto Trace - Global View")

Zooming into the profiling iteration, we find two GPU kernels launched. The first one is the L2 cache flush, which clears the cache before the measurement.
The second one is the actual computation kernel, which comes from cuDNN for this flash_attention operator.
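
The cache flush is the standard benchmarking trick of writing a buffer larger than the L2 cache before the measured kernel runs (as Triton's `do_bench` does); a minimal sketch of the idea follows, where the buffer size and dtype are illustrative assumptions rather than TritonBench's exact values.

```python
import torch

# A buffer comfortably larger than the GPU's L2 cache; size chosen only for illustration.
cache = torch.empty(256 * 1024 * 1024 // 4, dtype=torch.int32, device="cuda")

def flush_l2_cache() -> None:
    # Writing the whole buffer evicts previously cached lines, so the kernel
    # measured next starts from a cold L2 cache.
    cache.zero_()
```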

![Kineto Trace](https://ossci-datasets.s3.us-east-1.amazonaws.com/tritonbench/docs/_static/img/kineto_trace_fig_2.png "Kineto Trace - Zoomed into Profile Iteration")

60 changes: 56 additions & 4 deletions tritonbench/components/kineto/trace.py
@@ -19,13 +19,62 @@
from .fb.run_utils import trace_handler


def do_bench_kineto_cudagraph(
    fn, warmup, grad_to_none, profile_opts, output_dir
) -> str:
    activity_groups = [
        profiler.ProfilerActivity.CUDA,
        profiler.ProfilerActivity.CPU,
    ]
    with torch.cuda.stream(torch.cuda.Stream()):
        # step 1 - construct a cuda graph
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            if grad_to_none is not None:
                for x in grad_to_none:
                    x.grad = None
            fn()
        torch.cuda.synchronize()
        prefix = f"tritonbench_cudagraph_{fn._name}"
        name = f"{prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{''.join(random.choices(string.digits, k=10))}.json"
        # step 2 - profile cuda graph launch with kineto
        with profiler.profile(
            schedule=profiler.schedule(wait=0, warmup=warmup, active=1, repeat=1),
            activities=activity_groups,
            record_shapes=profile_opts["record_shapes"],
            profile_memory=profile_opts["profile_memory"],
            with_stack=profile_opts["with_stack"],
            with_flops=profile_opts["with_flops"],
            with_modules=profile_opts["with_modules"],
            on_trace_ready=(
                partial(trace_handler, name)
                if not hasattr(torch.version, "git_version")
                else profiler.tensorboard_trace_handler(output_dir)
            ),
        ) as prof:
            for _i in range(warmup + 1):
                # we don't want `fn` to accumulate gradient values
                # if it contains a backward pass. So we clear the
                # provided gradients
                if grad_to_none is not None:
                    for x in grad_to_none:
                        x.grad = None
                g.replay()
                prof.step()
    if not hasattr(torch.version, "git_version"):
        return f"https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/{name}.gz&bucket=pyper_traces"
    else:
        return output_dir


def do_bench_kineto(
    fn: Callable,
    warmup=25,
    grad_to_none=None,
    fast_flush=True,
    profile_opts=None,
    output_dir=None,
    use_cuda_graphs: bool = False,
) -> str:
    """
    Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
@@ -44,6 +93,8 @@ def do_bench_kineto(
    :param output_dir: Output directory to store the trace
    :type output_dir: str, optional
    """
    if profile_opts is None:
        profile_opts = DEFAULT_PROFILE_OPTS
    import torch

    fn()
@@ -70,12 +121,15 @@

    # compute number of warmup and repeat
    n_warmup = max(1, int(warmup / estimate_ms))
    if use_cuda_graphs:
        return do_bench_kineto_cudagraph(
            fn, n_warmup, grad_to_none, profile_opts, output_dir
        )

    activity_groups = [
        profiler.ProfilerActivity.CUDA,
        profiler.ProfilerActivity.CPU,
    ]
    if profile_opts is None:
        profile_opts = DEFAULT_PROFILE_OPTS
    prefix = f"tritonbench_{fn._name}"
    name = f"{prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{''.join(random.choices(string.digits, k=10))}.json"
    with profiler.profile(
@@ -92,8 +146,6 @@
            else profiler.tensorboard_trace_handler(output_dir)
        ),
    ) as prof:
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        for i in range(n_warmup + 1):
            # we don't want `fn` to accumulate gradient values
            # if it contains a backward pass. So we clear the
1 change: 0 additions & 1 deletion tritonbench/operators/cross_entropy/operator.py
@@ -29,7 +29,6 @@ def __init__(
        self.T = 2048
        self.baseline_model = CrossEntropyLoss()
        self.liger_model = LigerCrossEntropyLoss()
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for V in [2**i for i in range(12, 18)]:
1 change: 0 additions & 1 deletion tritonbench/operators/embedding/operator.py
@@ -27,7 +27,6 @@ def __init__(
        # they are generated later
        self.baseline_op = None
        self.liger_op = None
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for B, T, D in [(32, 512, 768), (8, 2048, 4096)]:
2 changes: 0 additions & 2 deletions tritonbench/operators/flash_attention/operator.py
@@ -168,9 +168,7 @@ def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        super().__init__(tb_args, extra_args)
        self.use_cuda_graphs = False
        args = parse_op_args(self.extra_args)
        self.use_cuda_graphs = False
        self.BATCH = args.batch
        self.SEQ_LEN = args.seq_len
        self.H = args.n_heads
1 change: 0 additions & 1 deletion tritonbench/operators/fused_linear_cross_entropy/operator.py
@@ -78,7 +78,6 @@ def __init__(
        self.liger_model = LigerLMHeadCE(
            H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
        ).to(self.device)
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for BT in [2**i for i in range(12, 16)]:
2 changes: 0 additions & 2 deletions tritonbench/operators/fused_linear_jsd/operator.py
@@ -146,8 +146,6 @@ def __init__(
            self.liger_op.teacher_lin.weight.data
        ) = torch.rand(self.V, self.H, device=self.device, dtype=self.dtype)

        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for BT in [2**i for i in range(10, 14)]:
            student_input = torch.rand(
1 change: 0 additions & 1 deletion tritonbench/operators/geglu/operator.py
@@ -40,7 +40,6 @@ def __init__(
        self.liger_model = (
            LigerGEGLUMLP(self.llama_config).to(self.device).to(self.dtype)
        )
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for T in [2**i for i in range(10, 14)]:
1 change: 0 additions & 1 deletion tritonbench/operators/gemm/operator.py
@@ -141,7 +141,6 @@ def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        super().__init__(tb_args, extra_args)
        self.use_cuda_graphs = False
        gemm_args = parse_args(self.extra_args)
        self.layout = gemm_args.layout
        if IS_FBCODE and tb_args.production_shapes:
1 change: 0 additions & 1 deletion tritonbench/operators/grouped_gemm/operator.py
@@ -18,7 +18,6 @@
class Operator(BenchmarkOperator):
    DEFAULT_PRECISION = "fp16"
    DEFAULT_METRICS = ["latency", "speedup", "accuracy"]
    use_cuda_graphs = False

    @register_benchmark(baseline=True)
    def torch(self, group_A, group_B):
4 changes: 0 additions & 4 deletions tritonbench/operators/jagged_layer_norm/operator.py
@@ -40,10 +40,6 @@ class Operator(BenchmarkOperator):
    DEFAULT_METRICS = ["latency", "accuracy"]
    DEFAULT_PRECISION = "fp32"

    use_cuda_graphs = (
        False  # allows for a GPU/CPU sync, caused by methods like torch.unbind
    )

    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
4 changes: 0 additions & 4 deletions tritonbench/operators/jagged_mean/operator.py
@@ -96,10 +96,6 @@ class Operator(BenchmarkOperator):
    DEFAULT_METRICS = ["latency", "accuracy"]
    DEFAULT_PRECISION = "fp32"

    use_cuda_graphs = (
        False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
    )

    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
4 changes: 0 additions & 4 deletions tritonbench/operators/jagged_softmax/operator.py
@@ -75,10 +75,6 @@ class Operator(BenchmarkOperator):
    DEFAULT_METRICS = ["latency", "accuracy", "best_config"]
    DEFAULT_PRECISION = "fp32"

    use_cuda_graphs = (
        False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
    )

    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
3 changes: 0 additions & 3 deletions tritonbench/operators/jagged_sum/operator.py
@@ -95,9 +95,6 @@ def execute_kernel_variable_length_loop(x, sum_then_buffer):
class Operator(BenchmarkOperator):
    DEFAULT_METRICS = ["latency", "accuracy", "best_config"]
    DEFAULT_PRECISION = "fp32"
    use_cuda_graphs = (
        False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
    )

    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
1 change: 0 additions & 1 deletion tritonbench/operators/jsd/operator.py
@@ -65,7 +65,6 @@ def __init__(
        self.T = 2048
        self.baseline_op = TorchJSD()
        self.liger_op = LigerJSD()
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for V in [2**i for i in range(12, 18)]:
1 change: 0 additions & 1 deletion tritonbench/operators/kl_div/operator.py
@@ -27,7 +27,6 @@ def __init__(
        self.T = 512
        self.baseline_op = torch.nn.KLDivLoss(reduction="batchmean").to(self.device)
        self.liger_op = LigerKLDIVLoss(reduction="batchmean").to(self.device)
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for V in [2**i for i in range(12, 18)]:
1 change: 0 additions & 1 deletion tritonbench/operators/rms_norm/operator.py
@@ -45,7 +45,6 @@ def __init__(
        # they are generated later
        self.llama_rms_op = None
        self.liger_rms_op = None
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for H in [2**i for i in range(10, 16)]:
1 change: 0 additions & 1 deletion tritonbench/operators/rope/operator.py
@@ -30,7 +30,6 @@ def __init__(
        # they are generated later
        self.baseline_op = None
        self.liger_op = None
        self.use_cuda_graphs = False
        self.num_q_heads = 32
        self.num_kv_heads = 8

1 change: 0 additions & 1 deletion tritonbench/operators/swiglu/operator.py
@@ -39,7 +39,6 @@ def __init__(
        self.liger_op = (
            LigerSwiGLUMLP(config=llama_config).to(self.device).to(self.dtype)
        )
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for seq_len in [2**i for i in range(10, 14)]:
1 change: 1 addition & 0 deletions tritonbench/utils/triton_op.py
@@ -1470,6 +1470,7 @@ def kineto_trace(self, input_id: int, fn: Callable) -> str:
            fn=fn,
            grad_to_none=self.get_grad_to_none(self.example_inputs),
            output_dir=kineto_output_dir,
            use_cuda_graphs=self.use_cuda_graphs,
        )

    def compile_time(