Commit e7da4ac

Merge branch 'main' into findhao/add-nsys-analyzer
FindHao committed Nov 26, 2024
2 parents e42929a + 0ca9f40 commit e7da4ac
Showing 14 changed files with 26 additions and 22 deletions.
2 changes: 1 addition & 1 deletion tritonbench/operators/embedding/operator.py

@@ -48,7 +48,7 @@ def liger_embedding(self, V, D, input) -> Callable:
     @register_benchmark()
     def inductor_embedding(self, V, D, input) -> Callable:
         self.baseline_op = Embedding(V, D).to(self.device).to(self.dtype)
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(input)

     @register_x_val(label="(B, T, D, V)")
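
Note (not part of the diff): the torch.compile changes in this commit drop the dynamic=False argument. With dynamic=False, compiled kernels are specialized to the exact input shapes and any new shape forces a recompile; omitting the argument leaves shape handling to torch.compile's default heuristics. A minimal standalone sketch of the two call forms, assuming only the public torch.compile API (emb is a hypothetical stand-in for the benchmark's baseline op):

import torch
import torch.nn as nn

# Hypothetical module standing in for self.baseline_op in the diffs above/below.
emb = nn.Embedding(1000, 64)

compiled_default = torch.compile(emb)                # new form used in this commit
compiled_static = torch.compile(emb, dynamic=False)  # old form being removed

tokens = torch.randint(0, 1000, (8, 128))
out = compiled_default(tokens)  # compiles on first call, then reuses the cached kernel
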

@@ -72,7 +72,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:

 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops"]
-    DEFAULT_PRECISION = "fp32"
+    DEFAULT_PRECISION = "fp8"

     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None

1 change: 1 addition & 0 deletions tritonbench/operators/fp8_gemm/fp8_gemm.py

@@ -43,6 +43,7 @@ def parse_args(args):

 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
+    DEFAULT_PRECISION = "fp8"

     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
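
For context on the DEFAULT_PRECISION = "fp8" defaults added above (illustration only, not code from this repository): PyTorch exposes 8-bit floating-point dtypes such as torch.float8_e4m3fn, and fp8 benchmarks typically cast their inputs to one of them before running the GEMM under test. A minimal round-trip sketch:

import torch

x = torch.randn(128, 128)
x_fp8 = x.to(torch.float8_e4m3fn)   # quantize to 8-bit e4m3 floating point
x_back = x_fp8.to(torch.float32)    # cast back to inspect the quantization error
print((x - x_back).abs().max())
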

2 changes: 1 addition & 1 deletion tritonbench/operators/fp8_gemm_blockwise/operator.py

@@ -115,7 +115,7 @@ def fp8_block_quantize(

 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "speedup", "accuracy"]
-    DEFAULT_PRECISION = "fp32"
+    DEFAULT_PRECISION = "fp8"

     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None

4 changes: 2 additions & 2 deletions tritonbench/operators/fp8_gemm_rowwise/operator.py

@@ -109,7 +109,7 @@ def fp8_row_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "speedup", "accuracy"]
-    DEFAULT_PRECISION = "fp32"
+    DEFAULT_PRECISION = "fp8"

     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
@@ -118,7 +118,7 @@ def __init__(
         self.use_cuda_graphs = True
         addmm_args = parse_args(self.extra_args)
         if hasattr(tb_args, "production_shapes") and tb_args.production_shapes:
-            self.shapes = get_production_shapes(self.name, "fp8_gemm")
+            self.shapes = get_production_shapes(self.name, "fp32_gemm")
         elif addmm_args.m and addmm_args.n and addmm_args.k:
             self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)]
         elif addmm_args.llama:

@@ -104,7 +104,7 @@ def liger_lm_head_ce(self, input, target) -> Callable:

     @register_benchmark()
     def inductor_fused_linear_cross_entropy(self, input, target) -> Callable:
-        compiled = torch.compile(self.baseline_model, dynamic=False)
+        compiled = torch.compile(self.baseline_model)
         return lambda: compiled(input, target)

     @register_x_val(label="(B*T, H)")

2 changes: 1 addition & 1 deletion tritonbench/operators/fused_linear_jsd/operator.py

@@ -166,7 +166,7 @@ def liger_lm_head_jsd(self, student_input, teacher_input) -> Callable:

     @register_benchmark()
     def inductor_lm_head_jsd(self, student_input, teacher_input) -> Callable:
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(student_input, teacher_input)

     @register_x_val(label="(B*T, H)")

2 changes: 1 addition & 1 deletion tritonbench/operators/geglu/operator.py

@@ -59,7 +59,7 @@ def liger_geglu(self, input) -> Callable:

     @register_benchmark()
     def inductor_geglu(self, input) -> Callable:
-        compiled = torch.compile(self.baseline_model, dynamic=False)
+        compiled = torch.compile(self.baseline_model)
         return lambda: compiled(input)

     @register_x_val(label="(B, T, H)")

2 changes: 1 addition & 1 deletion tritonbench/operators/jsd/operator.py

@@ -87,7 +87,7 @@ def liger_jsd(self, _input, target) -> Callable:

     @register_benchmark()
     def inductor_jsd(self, _input, target) -> Callable:
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(_input, target)

     @register_x_val(label="(B, T, V)")

2 changes: 1 addition & 1 deletion tritonbench/operators/kl_div/operator.py

@@ -47,7 +47,7 @@ def liger_kl_div(self, input, target) -> Callable:

     @register_benchmark()
     def inductor_kl_div(self, input, target) -> Callable:
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(input, target)

     @register_x_val(label="(B, T, V)")

6 changes: 5 additions & 1 deletion tritonbench/operators/rms_norm/operator.py

@@ -65,7 +65,11 @@ def liger_rms(self, H, input) -> Callable:

     @register_benchmark()
     def inductor_rms(self, H, input) -> Callable:
-        compiled = torch.compile(self.llama_rms_op, dynamic=False)
+        if self.llama_rms_op is None:
+            self.llama_rms_op = LlamaRMSNorm(hidden_size=H, eps=self.eps).to(
+                self.device
+            )
+        compiled = torch.compile(self.llama_rms_op)
         return lambda: compiled(input)

     @register_x_val(label="(M, H)")

6 changes: 2 additions & 4 deletions tritonbench/operators/rope/operator.py

@@ -88,11 +88,9 @@ def liger_rotary_pos_emb(self, hidden_size, seq_length) -> Callable:
     def inductor_rotary_pos_emb_full_op(self, hidden_size, seq_length) -> Callable:
         q, k, cos, sin, pos_ids = self.prepare_input(hidden_size, seq_length)
         head_dim = hidden_size // self.num_q_heads
-        compiled = torch.compile(
-            LlamaRotaryEmbedding(head_dim, device=self.device), dynamic=False
-        )
+        compiled = torch.compile(LlamaRotaryEmbedding(head_dim, device=self.device))
         cos, sin = compiled(k, pos_ids)
-        compiled_func = torch.compile(apply_rotary_pos_emb, dynamic=False)
+        compiled_func = torch.compile(apply_rotary_pos_emb)
         return lambda: compiled_func(q, k, cos, sin, pos_ids)

     @register_x_val(label="(H, T)")

2 changes: 1 addition & 1 deletion tritonbench/operators/swiglu/operator.py

@@ -60,7 +60,7 @@ def liger_swiglu(self, input) -> Callable:

     @register_benchmark()
     def inductor_swiglu(self, input) -> Callable:
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(input)

     @register_x_val(label="(B, T, H)")

13 changes: 7 additions & 6 deletions tritonbench/utils/gpu_utils.py

@@ -14,13 +14,14 @@
 }

 # NVIDIA H100 GPU Datasheet:
-# https://www.nvidia.com/en-gb/data-center/h100
+# "https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
 NV_H100 = {
-    "fp32": 51,
-    "tf32": 756,
-    "bf16": 1513,
-    "fp16": 1513,
-    "fp8": 3026,
+    "fp32": 989 // 2,
+    "tf32": 989 // 2,
+    "bf16": 1979 // 2,
+    "fp16": 1979 // 2,
+    "fp8": 3958 // 2,
+    "int8": 3958 // 2,
 }
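
The new NV_H100 numbers appear to be the H100 datasheet's "with sparsity" tensor-core peaks halved to dense peaks: 989 // 2 = 494 TFLOPS for tf32, 1979 // 2 = 989 for bf16/fp16, and 3958 // 2 = 1979 for fp8/int8. A small sketch of how such a peak table can be used; hw_utilization is a hypothetical helper, not a function in gpu_utils.py:

# Dense H100 tensor-core peaks, matching the diff above (datasheet sparsity figures halved).
NV_H100 = {
    "fp32": 989 // 2,
    "tf32": 989 // 2,
    "bf16": 1979 // 2,
    "fp16": 1979 // 2,
    "fp8": 3958 // 2,
    "int8": 3958 // 2,
}

def hw_utilization(measured_tflops: float, precision: str) -> float:
    """Fraction of the dense tensor-core peak achieved at a given precision."""
    return measured_tflops / NV_H100[precision]

print(hw_utilization(700.0, "bf16"))  # ~0.71 of the 989 TFLOPS dense bf16 peak
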
