From ebdb921a6d4fdb343f3e3eebc1de3f35acb6aa1c Mon Sep 17 00:00:00 2001
From: FindHao
Date: Mon, 25 Nov 2024 13:35:52 -0800
Subject: [PATCH 1/3] Ops bug fix and args clean (#76)

Summary:
Fix rope's bug for specific nsys profile.
Clean torch.compile args.

Pull Request resolved: https://github.com/pytorch-labs/tritonbench/pull/76

Reviewed By: adamomainz

Differential Revision: D66461602

Pulled By: FindHao

fbshipit-source-id: 8f56e3e60826a6d712ee3ea338e3be5dda65b6ab
---
 tritonbench/operators/embedding/operator.py           | 2 +-
 .../operators/fused_linear_cross_entropy/operator.py  | 2 +-
 tritonbench/operators/fused_linear_jsd/operator.py    | 2 +-
 tritonbench/operators/geglu/operator.py               | 2 +-
 tritonbench/operators/jsd/operator.py                 | 2 +-
 tritonbench/operators/kl_div/operator.py              | 2 +-
 tritonbench/operators/rms_norm/operator.py            | 6 +++++-
 tritonbench/operators/rope/operator.py                | 6 ++----
 tritonbench/operators/swiglu/operator.py              | 2 +-
 9 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/tritonbench/operators/embedding/operator.py b/tritonbench/operators/embedding/operator.py
index 8c7ff41b..00fd4696 100644
--- a/tritonbench/operators/embedding/operator.py
+++ b/tritonbench/operators/embedding/operator.py
@@ -48,7 +48,7 @@ def liger_embedding(self, V, D, input) -> Callable:
     @register_benchmark()
     def inductor_embedding(self, V, D, input) -> Callable:
         self.baseline_op = Embedding(V, D).to(self.device).to(self.dtype)
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(input)
 
     @register_x_val(label="(B, T, D, V)")
diff --git a/tritonbench/operators/fused_linear_cross_entropy/operator.py b/tritonbench/operators/fused_linear_cross_entropy/operator.py
index 2d384783..9f359345 100644
--- a/tritonbench/operators/fused_linear_cross_entropy/operator.py
+++ b/tritonbench/operators/fused_linear_cross_entropy/operator.py
@@ -104,7 +104,7 @@ def liger_lm_head_ce(self, input, target) -> Callable:
 
     @register_benchmark()
     def inductor_fused_linear_cross_entropy(self, input, target) -> Callable:
-        compiled = torch.compile(self.baseline_model, dynamic=False)
+        compiled = torch.compile(self.baseline_model)
         return lambda: compiled(input, target)
 
     @register_x_val(label="(B*T, H)")
diff --git a/tritonbench/operators/fused_linear_jsd/operator.py b/tritonbench/operators/fused_linear_jsd/operator.py
index 758c0c98..7ebdcc3b 100644
--- a/tritonbench/operators/fused_linear_jsd/operator.py
+++ b/tritonbench/operators/fused_linear_jsd/operator.py
@@ -166,7 +166,7 @@ def liger_lm_head_jsd(self, student_input, teacher_input) -> Callable:
 
     @register_benchmark()
     def inductor_lm_head_jsd(self, student_input, teacher_input) -> Callable:
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(student_input, teacher_input)
 
     @register_x_val(label="(B*T, H)")
diff --git a/tritonbench/operators/geglu/operator.py b/tritonbench/operators/geglu/operator.py
index fa2b2f04..237f850d 100644
--- a/tritonbench/operators/geglu/operator.py
+++ b/tritonbench/operators/geglu/operator.py
@@ -59,7 +59,7 @@ def liger_geglu(self, input) -> Callable:
 
     @register_benchmark()
     def inductor_geglu(self, input) -> Callable:
-        compiled = torch.compile(self.baseline_model, dynamic=False)
+        compiled = torch.compile(self.baseline_model)
         return lambda: compiled(input)
 
     @register_x_val(label="(B, T, H)")
diff --git a/tritonbench/operators/jsd/operator.py b/tritonbench/operators/jsd/operator.py
index 881a8d5d..5a42f294 100644
--- a/tritonbench/operators/jsd/operator.py
+++ b/tritonbench/operators/jsd/operator.py
@@ -87,7 +87,7 @@ def liger_jsd(self, _input, target) -> Callable:
 
     @register_benchmark()
     def inductor_jsd(self, _input, target) -> Callable:
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(_input, target)
 
     @register_x_val(label="(B, T, V)")
diff --git a/tritonbench/operators/kl_div/operator.py b/tritonbench/operators/kl_div/operator.py
index 129ac5c6..0d600cce 100644
--- a/tritonbench/operators/kl_div/operator.py
+++ b/tritonbench/operators/kl_div/operator.py
@@ -47,7 +47,7 @@ def liger_kl_div(self, input, target) -> Callable:
 
     @register_benchmark()
     def inductor_kl_div(self, input, target) -> Callable:
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(input, target)
 
     @register_x_val(label="(B, T, V)")
diff --git a/tritonbench/operators/rms_norm/operator.py b/tritonbench/operators/rms_norm/operator.py
index 492f7de6..0c62d39d 100644
--- a/tritonbench/operators/rms_norm/operator.py
+++ b/tritonbench/operators/rms_norm/operator.py
@@ -65,7 +65,11 @@ def liger_rms(self, H, input) -> Callable:
 
     @register_benchmark()
     def inductor_rms(self, H, input) -> Callable:
-        compiled = torch.compile(self.llama_rms_op, dynamic=False)
+        if self.llama_rms_op is None:
+            self.llama_rms_op = LlamaRMSNorm(hidden_size=H, eps=self.eps).to(
+                self.device
+            )
+        compiled = torch.compile(self.llama_rms_op)
         return lambda: compiled(input)
 
     @register_x_val(label="(M, H)")
diff --git a/tritonbench/operators/rope/operator.py b/tritonbench/operators/rope/operator.py
index ab2b8476..174626ac 100644
--- a/tritonbench/operators/rope/operator.py
+++ b/tritonbench/operators/rope/operator.py
@@ -88,11 +88,9 @@ def liger_rotary_pos_emb(self, hidden_size, seq_length) -> Callable:
     def inductor_rotary_pos_emb_full_op(self, hidden_size, seq_length) -> Callable:
         q, k, cos, sin, pos_ids = self.prepare_input(hidden_size, seq_length)
         head_dim = hidden_size // self.num_q_heads
-        compiled = torch.compile(
-            LlamaRotaryEmbedding(head_dim, device=self.device), dynamic=False
-        )
+        compiled = torch.compile(LlamaRotaryEmbedding(head_dim, device=self.device))
         cos, sin = compiled(k, pos_ids)
-        compiled_func = torch.compile(apply_rotary_pos_emb, dynamic=False)
+        compiled_func = torch.compile(apply_rotary_pos_emb)
         return lambda: compiled_func(q, k, cos, sin, pos_ids)
 
     @register_x_val(label="(H, T)")
diff --git a/tritonbench/operators/swiglu/operator.py b/tritonbench/operators/swiglu/operator.py
index 7808da8d..b21fede9 100644
--- a/tritonbench/operators/swiglu/operator.py
+++ b/tritonbench/operators/swiglu/operator.py
@@ -60,7 +60,7 @@ def liger_swiglu(self, input) -> Callable:
 
     @register_benchmark()
     def inductor_swiglu(self, input) -> Callable:
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(input)
 
     @register_x_val(label="(B, T, H)")

From 66816daabd3647f256802100eec0ed0790eae409 Mon Sep 17 00:00:00 2001
From: Adam Mainz
Date: Mon, 25 Nov 2024 17:11:32 -0800
Subject: [PATCH 2/3] fp8 gemm was using fp32 by default

Summary: why was fp8 defaulting to fp32 before?
Reviewed By: xuzhao9

Differential Revision: D66474486

fbshipit-source-id: 560ae17c93ce225e74b6a91bb6147536535089c1
---
 .../operators/fp8_fused_quant_gemm_rowwise/operator.py | 2 +-
 tritonbench/operators/fp8_gemm/fp8_gemm.py             | 1 +
 tritonbench/operators/fp8_gemm_blockwise/operator.py   | 2 +-
 tritonbench/operators/fp8_gemm_rowwise/operator.py     | 4 ++--
 4 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/tritonbench/operators/fp8_fused_quant_gemm_rowwise/operator.py b/tritonbench/operators/fp8_fused_quant_gemm_rowwise/operator.py
index 58d210b8..0f87a011 100644
--- a/tritonbench/operators/fp8_fused_quant_gemm_rowwise/operator.py
+++ b/tritonbench/operators/fp8_fused_quant_gemm_rowwise/operator.py
@@ -72,7 +72,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:
 
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops"]
-    DEFAULT_PRECISION = "fp32"
+    DEFAULT_PRECISION = "fp8"
 
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
diff --git a/tritonbench/operators/fp8_gemm/fp8_gemm.py b/tritonbench/operators/fp8_gemm/fp8_gemm.py
index 556c56ca..d277528c 100644
--- a/tritonbench/operators/fp8_gemm/fp8_gemm.py
+++ b/tritonbench/operators/fp8_gemm/fp8_gemm.py
@@ -43,6 +43,7 @@ def parse_args(args):
 
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
+    DEFAULT_PRECISION = "fp8"
 
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
diff --git a/tritonbench/operators/fp8_gemm_blockwise/operator.py b/tritonbench/operators/fp8_gemm_blockwise/operator.py
index 0bcf507a..d1f9bc73 100644
--- a/tritonbench/operators/fp8_gemm_blockwise/operator.py
+++ b/tritonbench/operators/fp8_gemm_blockwise/operator.py
@@ -115,7 +115,7 @@ def fp8_block_quantize(
 
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "speedup", "accuracy"]
-    DEFAULT_PRECISION = "fp32"
+    DEFAULT_PRECISION = "fp8"
 
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
diff --git a/tritonbench/operators/fp8_gemm_rowwise/operator.py b/tritonbench/operators/fp8_gemm_rowwise/operator.py
index f65383e4..cae6fd34 100644
--- a/tritonbench/operators/fp8_gemm_rowwise/operator.py
+++ b/tritonbench/operators/fp8_gemm_rowwise/operator.py
@@ -109,7 +109,7 @@ def fp8_row_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "speedup", "accuracy"]
-    DEFAULT_PRECISION = "fp32"
+    DEFAULT_PRECISION = "fp8"
 
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
@@ -118,7 +118,7 @@ def __init__(
         self.use_cuda_graphs = True
         addmm_args = parse_args(self.extra_args)
         if hasattr(tb_args, "production_shapes") and tb_args.production_shapes:
             self.shapes = get_production_shapes(self.name, "fp32_gemm")
         elif addmm_args.m and addmm_args.n and addmm_args.k:
             self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)]
         elif addmm_args.llama:

From 0ca9f4022178996cf68185339628f2560b87069e Mon Sep 17 00:00:00 2001
From: Adam Mainz
Date: Tue, 26 Nov 2024 11:38:06 -0800
Subject: [PATCH 3/3] changing hw rooflines to match xformers (#80)

Summary:
Pull Request resolved: https://github.com/pytorch-labs/tritonbench/pull/80

TSIA see https://github.com/facebookresearch/xformers/blob/6e10bd21ac6fc878657b24684723ccd05e41d385/xformers/profiler/device_limits.py#L25

Reviewed By: xuzhao9

Differential Revision: D66502297

fbshipit-source-id: 19964c3e240df40f552c4599a05a1dc7b799b9b8
---
 tritonbench/utils/gpu_utils.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/tritonbench/utils/gpu_utils.py b/tritonbench/utils/gpu_utils.py
index e50422ec..d0c28ddf 100644
--- a/tritonbench/utils/gpu_utils.py
+++ b/tritonbench/utils/gpu_utils.py
@@ -14,13 +14,14 @@
 }
 
 # NVIDIA H100 GPU Datasheet:
-# https://www.nvidia.com/en-gb/data-center/h100
+# "https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
 NV_H100 = {
-    "fp32": 51,
-    "tf32": 756,
-    "bf16": 1513,
-    "fp16": 1513,
-    "fp8": 3026,
+    "fp32": 989 // 2,
+    "tf32": 989 // 2,
+    "bf16": 1979 // 2,
+    "fp16": 1979 // 2,
+    "fp8": 3958 // 2,
+    "int8": 3958 // 2,
 }
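
On the roofline change in [PATCH 3/3]: the H100 datasheet quotes peak tensor-core throughput with 2:4 structured sparsity, so the new NV_H100 entries halve those figures (989 // 2, 1979 // 2, 3958 // 2) to get dense peaks, in line with the xformers device_limits table linked above. Below is a minimal sketch of how such a roofline dict can be used to estimate hardware utilization; the hw_utilization helper and the 620 TFLOPS measurement are hypothetical and not part of tritonbench.

```python
# Sketch only: NV_H100 mirrors the dict added in the patch above. The dense
# peak TFLOPS are the sparse datasheet numbers divided by 2.
NV_H100 = {
    "fp32": 989 // 2,
    "tf32": 989 // 2,
    "bf16": 1979 // 2,
    "fp16": 1979 // 2,
    "fp8": 3958 // 2,
    "int8": 3958 // 2,
}


def hw_utilization(measured_tflops: float, dtype: str) -> float:
    """Fraction of the dense H100 peak achieved for the given dtype (hypothetical helper)."""
    return measured_tflops / NV_H100[dtype]


# Example: a GEMM measured at 620 TFLOPS in bf16 reaches roughly 63% of the dense roofline.
print(f"{hw_utilization(620.0, 'bf16'):.1%}")
```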