Add CPU and GPU unit tests
Summary:
Add CPU and GPU tests to pytorch/tritonbench/fb.

We can now remove the GPU tests from pytorch/benchmark/test_tritonbench.

Reviewed By: FindHao

Differential Revision: D64989354

fbshipit-source-id: 20acceecaceeba45e0518fb3683557a1b6327377
xuzhao9 authored and facebook-github-bot committed Oct 28, 2024
1 parent 7e60e23 commit 67ad80b
Showing 5 changed files with 78 additions and 26 deletions.
2 changes: 2 additions & 0 deletions tritonbench/operators/jagged_layer_norm/operator.py
@@ -38,6 +38,8 @@ def parse_op_args(args: List[str]):
 class Operator(BenchmarkOperator):
 
     DEFAULT_METRICS = ["latency", "accuracy"]
+    DEFAULT_PRECISION = "fp32"
+
     use_cuda_graphs = (
         False  # allows for a GPU/CPU sync, caused by methods like torch.unbind
     )
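The two added lines above are class-level defaults on the operator; the same DEFAULT_PRECISION = "fp32" addition is applied to jagged_mean and jagged_softmax below. As a hedged illustration (not part of the commit) of how such a class attribute can be read, assuming tritonbench is importable; the printed values simply reflect this diff:

# Hedged illustration: DEFAULT_PRECISION is a plain class attribute, so it can
# be inspected without instantiating the operator.
from tritonbench.operators.jagged_layer_norm.operator import Operator

print(Operator.DEFAULT_METRICS)    # ["latency", "accuracy"]
print(Operator.DEFAULT_PRECISION)  # "fp32"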
2 changes: 2 additions & 0 deletions tritonbench/operators/jagged_mean/operator.py
@@ -94,6 +94,8 @@ def execute_kernel_variable_length_loop(x, sum_then_buffer):
 class Operator(BenchmarkOperator):
 
     DEFAULT_METRICS = ["latency", "accuracy"]
+    DEFAULT_PRECISION = "fp32"
+
     use_cuda_graphs = (
         False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
     )
2 changes: 2 additions & 0 deletions tritonbench/operators/jagged_softmax/operator.py
@@ -73,6 +73,8 @@ def parse_op_args(args: List[str]):
 class Operator(BenchmarkOperator):
 
     DEFAULT_METRICS = ["latency", "accuracy", "best_config"]
+    DEFAULT_PRECISION = "fp32"
+
     use_cuda_graphs = (
         False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
     )
1 change: 1 addition & 0 deletions tritonbench/operators_collection/liger/__init__.py
@@ -3,6 +3,7 @@
     "embedding",
     "rms_norm",
     "rope",
+    "jsd",
 ]
 
 
97 changes: 71 additions & 26 deletions tritonbench/utils/triton_op.py
@@ -37,11 +37,27 @@
 
 logger = logging.getLogger(__name__)
 
+
+@dataclass
+class BenchmarkOperatorBackend:
+    # backend name
+    name: str
+    # backend label
+    label: str
+    # baseline
+    baseline: bool = False
+    # enabled
+    enabled: bool = True
+    # need to be tested in ci
+    # ci = False implies enabled = False
+    ci: bool = True
+
+
 IS_FBCODE = not hasattr(torch.version, "git_version")
 DEFAULT_WARMUP = 25
 DEFAULT_RUN_ITERS = 100
 DEFAULT_QUANTILES = [0.5, 0.1, 0.9]
-REGISTERED_BENCHMARKS: Dict[str, OrderedDict[str, str]] = {}
+REGISTERED_BENCHMARKS: Dict[str, OrderedDict[str, BenchmarkOperatorBackend]] = {}
 ENABLED_BENCHMARKS: Dict[str, List[str]] = {}
 REGISTERED_METRICS: Dict[str, List[str]] = {}
 REGISTERED_X_VALS: Dict[str, str] = {}
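To make the registry change above concrete, here is a hedged sketch of what an entry looks like once backends are stored as config objects instead of label strings. Only the BenchmarkOperatorBackend fields and the REGISTERED_BENCHMARKS typing come from this diff; the function name below is invented for illustration:

# Illustrative only (not part of the commit).
from collections import OrderedDict
from typing import Dict

from tritonbench.utils.triton_op import BenchmarkOperatorBackend  # path mirrors this file

# A registry entry is now a full backend config rather than a bare label string.
registry: Dict[str, OrderedDict[str, BenchmarkOperatorBackend]] = {}
registry["jagged_mean"] = OrderedDict()
registry["jagged_mean"]["triton_jagged_mean"] = BenchmarkOperatorBackend(
    name="triton_jagged_mean",   # hypothetical backend function name
    label="triton_jagged_mean",  # label falls back to the function name
)
backend = registry["jagged_mean"]["triton_jagged_mean"]
print(backend.baseline, backend.enabled, backend.ci)  # False True True (field defaults)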
@@ -220,6 +236,7 @@ class BenchmarkOperatorResult:
     op_name: str
     op_mode: str
     metrics: List[str]
+    # Tuple: (x_val, Dict[impl_name, BenchmarkOperatorMetrics])
     result: List[Tuple[Any, Dict[str, BenchmarkOperatorMetrics]]]
     _result_dict: Optional[Dict[Number, Dict[str, BenchmarkOperatorMetrics]]] = None
 
@@ -230,61 +247,62 @@ def _table(self):
         if len(self.result) == 0:
             return headers, table
         y_val = self.result[0][1]
-        y_val_keys = list(y_val.keys())
+        backends = list(y_val.keys())
         # move the baseline benchmark to the front of the list if exists
         if (
             self.op_name in BASELINE_BENCHMARKS
-            and BASELINE_BENCHMARKS[self.op_name] in y_val_keys
+            and BASELINE_BENCHMARKS[self.op_name] in backends
         ):
-            y_val_keys.insert(
-                0, y_val_keys.pop(y_val_keys.index(BASELINE_BENCHMARKS[self.op_name]))
+            backends.insert(
+                0, backends.pop(backends.index(BASELINE_BENCHMARKS[self.op_name]))
             )
-        y_val_keys = [(x, REGISTERED_BENCHMARKS[self.op_name][x]) for x in y_val_keys]
         key_metrics = {}
         # Add header for x_only_metrics
         x_only_metrics = sorted(
             [metric for metric in self.metrics if metric in X_ONLY_METRICS]
         )
         headers.extend(x_only_metrics)
-        for k, label in y_val_keys:
+        for backend in backends:
+            label = REGISTERED_BENCHMARKS[self.op_name][backend].label
 
-            def select_metric(m):
+            def select_metric(backend, m):
                 if m in x_only_metrics:
                     return False
                 if (
                     m in BASELINE_SKIP_METRICS
-                    and k == BASELINE_BENCHMARKS[self.op_name]
+                    and backend == BASELINE_BENCHMARKS[self.op_name]
                 ):
                     return False
                 return True
 
-            key_metrics[k] = sorted(filter(select_metric, self.metrics))
-            for metric in key_metrics[k]:
+            key_metrics[backend] = [
+                metric for metric in self.metrics if select_metric(backend, metric)
+            ]
+            for metric in key_metrics[backend]:
                 # add extra metrics
                 headers.append(f"{label}-{metric}")
         # generate rows
         for x_val, y_val in self.result:
             row = []
             row.append(x_val)
-            # Append x_val_only metrics
+            # Append x_only metrics
             for x_only_metric in x_only_metrics:
-                x_only_metric_dict = asdict(
-                    y_val[y_val_keys[0][0]]
-                )  # retrieve canonical name for metric function, where y_val_keys[0] = (canonical name, customized label name)
+                # retrieve x_only metrics from the first backend metrics
+                x_only_metric_dict = asdict(y_val[backends[0]])
                 if (
                     "extra_metrics" in x_only_metric_dict
                     and x_only_metric in x_only_metric_dict["extra_metrics"]
                 ):
                     row.append(x_only_metric_dict["extra_metrics"][x_only_metric])
                 else:
                     row.append(x_only_metric_dict[x_only_metric])
-            for k, _label in y_val_keys:
-                metrics_dict = asdict(y_val[k])
+            for backend in backends:
+                metrics_dict = asdict(y_val[backend])
                 if metrics_dict["error_msg"]:
                     row.append(metrics_dict["error_msg"])
-                    row.extend([None] * (len(key_metrics[k]) - 1))
+                    row.extend([None] * (len(key_metrics[backend]) - 1))
                     continue
-                for metric in key_metrics[k]:
+                for metric in key_metrics[backend]:
                     _metrics_dict = (
                         metrics_dict["extra_metrics"]
                         if metric in metrics_dict["extra_metrics"]
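The net effect of the _table changes above is that column headers are built from each backend's label while rows are looked up by the backend's canonical name. A hedged example of the resulting header layout (backend names, labels, and the metric list are made up; the baseline column assumes "speedup" appears in BASELINE_SKIP_METRICS):

# Illustrative only: headers produced for an operator with backends
# "eager" (baseline, label "torch") and "triton_kernel" (label "triton"),
# metrics ["latency", "speedup"], and x-value column "x_val".
headers = [
    "x_val",
    "torch-latency",   # baseline backend: "speedup" skipped via BASELINE_SKIP_METRICS (assumed)
    "triton-latency",
    "triton-speedup",
]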
@@ -333,13 +351,25 @@ def userbenchmark_dict(self) -> Dict[str, Any]:
         # tritonbench_{op_name}_{op_mode}[{x_val}-{provider}-{metric}]
         userbenchmark_metrics_dict = {}
         headers, table = self._table()
+        num_rows = len(table)
+        agg_data = {}
         for row in table:
             x_val = row[0]
+
             for ind, value in enumerate(row[1:]):
                 header = headers[ind + 1]
                 provider, _dash, metrics = header.partition("-")
                 metric_name = f"tritonbench_{self.op_name}_{self.op_mode}[x_{x_val}-{provider}]_{metrics}"
                 userbenchmark_metrics_dict[metric_name] = value
+                agg_metric_name = (
+                    f"tritonbench_{self.op_name}_{self.op_mode}[{provider}]_{metrics}"
+                )
+                if value is None:
+                    continue
+                agg_data[agg_metric_name] = agg_data.get(agg_metric_name, 0) + value
+        final_agg_data = {k: v / num_rows for k, v in agg_data.items()}
+        userbenchmark_metrics_dict.update(final_agg_data)
+
         return userbenchmark_metrics_dict
 
     def get_y_vals(self, x_val, provider, metric_name: str):
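The addition above also emits one aggregated key per provider/metric pair, computed as the mean over all rows (num_rows). A hedged worked example, with the operator name, mode, and numbers invented for illustration:

# Suppose the table has two rows for provider "torch", metric "latency":
#   x_val=128 -> 1.0, x_val=256 -> 3.0
# Per-row keys (pre-existing behavior):
#   tritonbench_jagged_mean_fwd[x_128-torch]_latency = 1.0
#   tritonbench_jagged_mean_fwd[x_256-torch]_latency = 3.0
# New aggregate key, averaged over num_rows = 2:
#   tritonbench_jagged_mean_fwd[torch]_latency = (1.0 + 3.0) / 2 == 2.0
values = [1.0, 3.0]
print(sum(values) / len(values))  # 2.0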
@@ -384,18 +414,26 @@ def _inner(self, *args, **kwargs):
 
 
 def register_benchmark(
-    baseline: bool = False, enabled: bool = True, label: Optional[str] = None
+    baseline: bool = False,
+    enabled: bool = True,
+    ci: bool = True,
+    label: Optional[str] = None,
 ):
     def decorator(function):
         operator_name = _find_op_name_from_module_path(function.__module__)
+        backend_config = BenchmarkOperatorBackend(
+            name=function.__name__,
+            label=label if label else function.__name__,
+            baseline=baseline,
+            enabled=enabled if ci else False,
+            ci=ci,
+        )
         if not operator_name in REGISTERED_BENCHMARKS:
             REGISTERED_BENCHMARKS[operator_name] = OrderedDict()
-        REGISTERED_BENCHMARKS[operator_name][function.__name__] = (
-            function.__name__ if not label else label
-        )
-        if baseline:
+        REGISTERED_BENCHMARKS[operator_name][function.__name__] = backend_config
+        if backend_config.baseline:
             BASELINE_BENCHMARKS[operator_name] = function.__name__
-        if enabled:
+        if backend_config.enabled:
             if not operator_name in ENABLED_BENCHMARKS:
                 ENABLED_BENCHMARKS[operator_name] = []
             ENABLED_BENCHMARKS[operator_name].append(function.__name__)
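A hedged usage sketch of the extended decorator follows; the operator module, method names, and bodies are invented, and only the keyword arguments come from this diff. Note that ci=False also forces the backend to be disabled (enabled=enabled if ci else False), matching the "ci = False implies enabled = False" comment on the dataclass:

# Hypothetical operator module, e.g. tritonbench/operators/my_op/operator.py
from tritonbench.utils.triton_op import BenchmarkOperator, register_benchmark

class Operator(BenchmarkOperator):
    @register_benchmark(baseline=True, label="torch")
    def eager(self, x):
        return lambda: x.sum()

    # Registered, but excluded from CI (and therefore disabled by default).
    @register_benchmark(ci=False)
    def experimental(self, x):
        return lambda: x.sum()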
@@ -414,6 +452,7 @@ def register_benchmark_mannually(
     baseline: bool = False,
     enabled: bool = True,
     label: Optional[str] = None,
+    ci: bool = True,
 ):
     """
     Manually register a benchmark function for a given operator.
@@ -435,7 +474,13 @@ def register_benchmark_mannually(
     """
     if not operator_name in REGISTERED_BENCHMARKS:
         REGISTERED_BENCHMARKS[operator_name] = OrderedDict()
-    REGISTERED_BENCHMARKS[operator_name][func_name] = func_name if not label else label
+    REGISTERED_BENCHMARKS[operator_name][func_name] = BenchmarkOperatorBackend(
+        name=func_name,
+        label=label if label else func_name,
+        baseline=baseline,
+        enabled=enabled,
+        ci=ci,
+    )
     if baseline:
         BASELINE_BENCHMARKS[operator_name] = func_name
     if enabled:
