From bec6d9f031d401f34b70331651cfce70e1b6372f Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Wed, 27 Nov 2024 10:39:14 -0800 Subject: [PATCH] Install tritonbench as a library (#81) Summary: Now users can import Tritonbench as a library: ``` $ pip install -e . $ python -c "import tritonbench; op = tritonbench.load_opbench_by_name('addmm');" ``` Clean up the init file so that dependencies like `os` and `importlib` will not pollute the namespace. Pull Request resolved: https://github.com/pytorch-labs/tritonbench/pull/81 Reviewed By: FindHao Differential Revision: D66544679 Pulled By: xuzhao9 fbshipit-source-id: 0339adef58d28f2c7f207a2869c21d4b55575386 --- .gitignore | 2 +- README.md | 16 +++- pyproject.toml | 15 ++++ tritonbench/__init__.py | 5 ++ tritonbench/operators/__init__.py | 79 +------------------ tritonbench/operators/addmm/hstu.py | 1 + .../operators/fp8_gemm_rowwise/operator.py | 1 + tritonbench/operators/geglu/operator.py | 1 + tritonbench/operators/gemm/operator.py | 1 + .../operators/jagged_layer_norm/operator.py | 1 + tritonbench/operators/jagged_mean/operator.py | 1 + .../operators/jagged_softmax/operator.py | 1 + tritonbench/operators/jagged_sum/operator.py | 1 + tritonbench/operators/layer_norm/operator.py | 1 + tritonbench/operators/op.py | 78 ++++++++++++++++++ tritonbench/operators/op_task.py | 1 + .../operators/ragged_attention/hstu.py | 1 + .../operators/ragged_attention/operator.py | 1 + tritonbench/operators/softmax/operator.py | 1 + tritonbench/operators/sum/operator.py | 1 + tritonbench/operators_collection/__init__.py | 72 +---------------- .../operators_collection/op_collection.py | 71 +++++++++++++++++ tritonbench/utils/__init__.py | 3 + tritonbench/utils/triton_op.py | 1 + 24 files changed, 205 insertions(+), 151 deletions(-) create mode 100644 tritonbench/__init__.py create mode 100644 tritonbench/operators/op.py create mode 100644 tritonbench/operators_collection/op_collection.py create mode 100644 tritonbench/utils/__init__.py diff --git a/.gitignore b/.gitignore index 8ca07fae..db8e0b86 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,4 @@ __pycache__/ .DS_Store .ipynb_checkpoints/ .idea -torchbench.egg-info/ +*.egg-info/ diff --git a/README.md b/README.md index 160ec7dd..ecc816f6 100644 --- a/README.md +++ b/README.md @@ -26,12 +26,26 @@ By default, it will install the latest PyTorch nightly release and use the Trito ## Basic Usage -To benchmark an operator, use the following command: +To benchmark an operator, run the following command: ``` $ python run.py --op gemm ``` +## Install as a library + +To install as a library: + +``` +$ pip install -e . 
+# in your own benchmark script +import tritonbench +from tritonbench.utils import parser +op_args = parser.parse_args() +addmm_bench = tritonbench.load_opbench_by_name("addmm")(op_args) +addmm_bench.run() +``` + ## Submodules We depend on the following projects as a source of customized Triton or CUTLASS kernels: diff --git a/pyproject.toml b/pyproject.toml index c8630b05..d8cab09c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,18 @@ +[build-system] +requires = ["setuptools>=40.8.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "tritonbench" +version = "0.0.1" +dependencies = [ + "torch", + "triton", +] + +[tool.setuptools.packages.find] +include = ["tritonbench*"] + [tool.ufmt] formatter = "ruff-api" excludes = ["submodules/"] diff --git a/tritonbench/__init__.py b/tritonbench/__init__.py new file mode 100644 index 00000000..858fb44e --- /dev/null +++ b/tritonbench/__init__.py @@ -0,0 +1,5 @@ +from .operators import list_operators, load_opbench_by_name +from .operators_collection import ( + list_operator_collections, + list_operators_by_collection, +) diff --git a/tritonbench/operators/__init__.py b/tritonbench/operators/__init__.py index 58a7b359..38c91993 100644 --- a/tritonbench/operators/__init__.py +++ b/tritonbench/operators/__init__.py @@ -1,78 +1 @@ -import importlib -import os -import pathlib -from typing import List - -OPBENCH_DIR = "operators" -INTERNAL_OPBENCH_DIR = "fb" - - -def _dir_contains_file(dir, file_name) -> bool: - names = map(lambda x: x.name, filter(lambda x: x.is_file(), dir.iterdir())) - return file_name in names - - -def _is_internal_operator(op_name: str) -> bool: - p = ( - pathlib.Path(__file__) - .parent.parent.joinpath(OPBENCH_DIR) - .joinpath(INTERNAL_OPBENCH_DIR) - .joinpath(op_name) - ) - if p.exists() and p.joinpath("__init__.py").exists(): - return True - return False - - -def _list_opbench_paths() -> List[str]: - p = pathlib.Path(__file__).parent.parent.joinpath(OPBENCH_DIR) - # Only load the model directories that contain a "__init.py__" file - opbench = sorted( - str(child.absolute()) - for child in p.iterdir() - if child.is_dir() and _dir_contains_file(child, "__init__.py") - ) - p = p.joinpath(INTERNAL_OPBENCH_DIR) - if p.exists(): - o = sorted( - str(child.absolute()) - for child in p.iterdir() - if child.is_dir() and _dir_contains_file(child, "__init__.py") - ) - opbench.extend(o) - return opbench - - -def list_operators() -> List[str]: - operators = list(map(lambda y: os.path.basename(y), _list_opbench_paths())) - if INTERNAL_OPBENCH_DIR in operators: - operators.remove(INTERNAL_OPBENCH_DIR) - return operators - - -def load_opbench_by_name(op_name: str): - opbench_list = filter( - lambda x: op_name.lower() == x.lower(), - map(lambda y: os.path.basename(y), _list_opbench_paths()), - ) - opbench_list = list(opbench_list) - if not opbench_list: - raise RuntimeError(f"{op_name} is not found in the Tritonbench operator list.") - assert ( - len(opbench_list) == 1 - ), f"Found more than one operators {opbench_list} matching the required name: {op_name}" - op_name = opbench_list[0] - op_pkg = ( - op_name - if not _is_internal_operator(op_name) - else f"{INTERNAL_OPBENCH_DIR}.{op_name}" - ) - module = importlib.import_module(f".{op_pkg}", package=__name__) - - Operator = getattr(module, "Operator", None) - if Operator is None: - print(f"Warning: {module} does not define attribute Operator, skip it") - return None - if not hasattr(Operator, "name"): - Operator.name = op_name - return Operator +from .op import 
list_operators, load_opbench_by_name diff --git a/tritonbench/operators/addmm/hstu.py b/tritonbench/operators/addmm/hstu.py index ddd08073..2aeed9f8 100644 --- a/tritonbench/operators/addmm/hstu.py +++ b/tritonbench/operators/addmm/hstu.py @@ -4,6 +4,7 @@ import torch import triton + from tritonbench.utils.path_utils import add_path, SUBMODULE_PATH with add_path(str(SUBMODULE_PATH.joinpath("generative-recommenders"))): diff --git a/tritonbench/operators/fp8_gemm_rowwise/operator.py b/tritonbench/operators/fp8_gemm_rowwise/operator.py index cae6fd34..53bf30b0 100644 --- a/tritonbench/operators/fp8_gemm_rowwise/operator.py +++ b/tritonbench/operators/fp8_gemm_rowwise/operator.py @@ -6,6 +6,7 @@ import torch import triton + from tritonbench.utils.data_utils import get_production_shapes from tritonbench.utils.triton_op import ( diff --git a/tritonbench/operators/geglu/operator.py b/tritonbench/operators/geglu/operator.py index 237f850d..9613bc96 100644 --- a/tritonbench/operators/geglu/operator.py +++ b/tritonbench/operators/geglu/operator.py @@ -5,6 +5,7 @@ from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaMLP + from tritonbench.utils.triton_op import ( BenchmarkOperator, register_benchmark, diff --git a/tritonbench/operators/gemm/operator.py b/tritonbench/operators/gemm/operator.py index b7893659..4082a814 100644 --- a/tritonbench/operators/gemm/operator.py +++ b/tritonbench/operators/gemm/operator.py @@ -6,6 +6,7 @@ import torch import torch._inductor.config as inductor_config import triton + from tritonbench.utils.data_utils import get_production_shapes from tritonbench.utils.path_utils import REPO_PATH diff --git a/tritonbench/operators/jagged_layer_norm/operator.py b/tritonbench/operators/jagged_layer_norm/operator.py index 63dc3a6e..d34f4896 100644 --- a/tritonbench/operators/jagged_layer_norm/operator.py +++ b/tritonbench/operators/jagged_layer_norm/operator.py @@ -7,6 +7,7 @@ import torch import triton + from tritonbench.utils.jagged_utils import ( ABSOLUTE_TOLERANCE, EPSILON, diff --git a/tritonbench/operators/jagged_mean/operator.py b/tritonbench/operators/jagged_mean/operator.py index d5a82269..c9c975d2 100644 --- a/tritonbench/operators/jagged_mean/operator.py +++ b/tritonbench/operators/jagged_mean/operator.py @@ -7,6 +7,7 @@ import torch import triton + from tritonbench.utils.jagged_utils import ( ABSOLUTE_TOLERANCE, generate_input_vals, diff --git a/tritonbench/operators/jagged_softmax/operator.py b/tritonbench/operators/jagged_softmax/operator.py index aad5a41c..cc3284c0 100644 --- a/tritonbench/operators/jagged_softmax/operator.py +++ b/tritonbench/operators/jagged_softmax/operator.py @@ -7,6 +7,7 @@ import torch import triton + from tritonbench.utils.jagged_utils import ( ABSOLUTE_TOLERANCE, generate_input_vals, diff --git a/tritonbench/operators/jagged_sum/operator.py b/tritonbench/operators/jagged_sum/operator.py index c531186c..36d519db 100644 --- a/tritonbench/operators/jagged_sum/operator.py +++ b/tritonbench/operators/jagged_sum/operator.py @@ -7,6 +7,7 @@ import torch import triton + from tritonbench.utils.jagged_utils import ( ABSOLUTE_TOLERANCE, generate_input_vals, diff --git a/tritonbench/operators/layer_norm/operator.py b/tritonbench/operators/layer_norm/operator.py index d2e1df29..6627697c 100644 --- a/tritonbench/operators/layer_norm/operator.py +++ b/tritonbench/operators/layer_norm/operator.py @@ -3,6 +3,7 @@ import torch import torch.nn.functional as F import triton + from 
tritonbench.utils.triton_op import ( BenchmarkOperator, BenchmarkOperatorMetrics, diff --git a/tritonbench/operators/op.py b/tritonbench/operators/op.py new file mode 100644 index 00000000..edcc3c45 --- /dev/null +++ b/tritonbench/operators/op.py @@ -0,0 +1,78 @@ +import importlib +import os +import pathlib +from typing import List + +OPBENCH_DIR = "operators" +INTERNAL_OPBENCH_DIR = "fb" + + +def _dir_contains_file(dir, file_name) -> bool: + names = map(lambda x: x.name, filter(lambda x: x.is_file(), dir.iterdir())) + return file_name in names + + +def _is_internal_operator(op_name: str) -> bool: + p = ( + pathlib.Path(__file__) + .parent.parent.joinpath(OPBENCH_DIR) + .joinpath(INTERNAL_OPBENCH_DIR) + .joinpath(op_name) + ) + if p.exists() and p.joinpath("__init__.py").exists(): + return True + return False + + +def _list_opbench_paths() -> List[str]: + p = pathlib.Path(__file__).parent.parent.joinpath(OPBENCH_DIR) + # Only load the model directories that contain a "__init.py__" file + opbench = sorted( + str(child.absolute()) + for child in p.iterdir() + if child.is_dir() and _dir_contains_file(child, "__init__.py") + ) + p = p.joinpath(INTERNAL_OPBENCH_DIR) + if p.exists(): + o = sorted( + str(child.absolute()) + for child in p.iterdir() + if child.is_dir() and _dir_contains_file(child, "__init__.py") + ) + opbench.extend(o) + return opbench + + +def list_operators() -> List[str]: + operators = list(map(lambda y: os.path.basename(y), _list_opbench_paths())) + if INTERNAL_OPBENCH_DIR in operators: + operators.remove(INTERNAL_OPBENCH_DIR) + return operators + + +def load_opbench_by_name(op_name: str): + opbench_list = filter( + lambda x: op_name.lower() == x.lower(), + map(lambda y: os.path.basename(y), _list_opbench_paths()), + ) + opbench_list = list(opbench_list) + if not opbench_list: + raise RuntimeError(f"{op_name} is not found in the Tritonbench operator list.") + assert ( + len(opbench_list) == 1 + ), f"Found more than one operators {opbench_list} matching the required name: {op_name}" + op_name = opbench_list[0] + op_pkg = ( + op_name + if not _is_internal_operator(op_name) + else f"{INTERNAL_OPBENCH_DIR}.{op_name}" + ) + module = importlib.import_module(f"..{op_pkg}", package=__name__) + + Operator = getattr(module, "Operator", None) + if Operator is None: + print(f"Warning: {module} does not define attribute Operator, skip it") + return None + if not hasattr(Operator, "name"): + Operator.name = op_name + return Operator diff --git a/tritonbench/operators/op_task.py b/tritonbench/operators/op_task.py index fc358ceb..413025b6 100644 --- a/tritonbench/operators/op_task.py +++ b/tritonbench/operators/op_task.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional import torch + from tritonbench.components.tasks import base as base_task from tritonbench.components.workers import subprocess_worker diff --git a/tritonbench/operators/ragged_attention/hstu.py b/tritonbench/operators/ragged_attention/hstu.py index c6dd4010..025bad0d 100644 --- a/tritonbench/operators/ragged_attention/hstu.py +++ b/tritonbench/operators/ragged_attention/hstu.py @@ -1,5 +1,6 @@ import torch import triton + from tritonbench.utils.path_utils import add_path, SUBMODULE_PATH from tritonbench.utils.triton_op import IS_FBCODE diff --git a/tritonbench/operators/ragged_attention/operator.py b/tritonbench/operators/ragged_attention/operator.py index 1beb351c..52db5482 100644 --- a/tritonbench/operators/ragged_attention/operator.py +++ b/tritonbench/operators/ragged_attention/operator.py @@ -3,6 +3,7 @@ 
from typing import Any, Callable, List, Optional import torch + from tritonbench.utils.input import input_filter from tritonbench.utils.triton_op import ( diff --git a/tritonbench/operators/softmax/operator.py b/tritonbench/operators/softmax/operator.py index 58b56aac..ce81768e 100644 --- a/tritonbench/operators/softmax/operator.py +++ b/tritonbench/operators/softmax/operator.py @@ -3,6 +3,7 @@ import torch import triton import triton.language as tl + from tritonbench.utils.data_utils import get_production_shapes from tritonbench.utils.triton_op import ( diff --git a/tritonbench/operators/sum/operator.py b/tritonbench/operators/sum/operator.py index ffdc790d..7f60c244 100644 --- a/tritonbench/operators/sum/operator.py +++ b/tritonbench/operators/sum/operator.py @@ -7,6 +7,7 @@ import torch import triton import triton.language as tl + from tritonbench.utils.triton_op import ( BenchmarkOperator, BenchmarkOperatorMetrics, diff --git a/tritonbench/operators_collection/__init__.py b/tritonbench/operators_collection/__init__.py index 9a60deb7..5e6cc3cd 100644 --- a/tritonbench/operators_collection/__init__.py +++ b/tritonbench/operators_collection/__init__.py @@ -1,71 +1 @@ -import importlib -import pathlib -from typing import List - -OP_COLLECTION_PATH = "operators_collection" - - -def list_operator_collections() -> List[str]: - """ - List the available operator collections. - - This function retrieves the list of available operator collections by scanning the directories - in the current path that contain an "__init__.py" file. - - Returns: - List[str]: A list of names of the available operator collections. - """ - p = pathlib.Path(__file__).parent - # only load the directories that contain a "__init__.py" file - collection_paths = sorted( - str(child.absolute()) - for child in p.iterdir() - if child.is_dir() and child.joinpath("__init__.py").exists() - ) - filtered_collections = [pathlib.Path(path).name for path in collection_paths] - return filtered_collections - - -def list_operators_by_collection(op_collection: str = "default") -> List[str]: - """ - List the operators from the specified operator collections. - - This function retrieves the list of operators from the specified operator collections. - If the collection name is "all", it retrieves operators from all available collections. - If the collection name is not specified, it defaults to the "default" collection. - - Args: - op_collection (str): Names of the operator collections to list operators from. - It can be a single collection name or a comma-separated list of names. - Special value "all" retrieves operators from all collections. - - Returns: - List[str]: A list of operator names from the specified collection(s). - - Raises: - ModuleNotFoundError: If the specified collection module is not found. - AttributeError: If the specified collection module does not have a 'get_operators' function. 
- """ - - def _list_all_operators(collection_name: str): - try: - module_name = f".{collection_name}" - module = importlib.import_module(module_name, package=__name__) - if hasattr(module, "get_operators"): - return module.get_operators() - else: - raise AttributeError( - f"Module '{module_name}' does not have a 'get_operators' function" - ) - except ModuleNotFoundError: - raise ModuleNotFoundError(f"Module '{module_name}' not found") - - if op_collection == "all": - collection_names = list_operator_collections() - else: - collection_names = op_collection.split(",") - - all_operators = [] - for collection_name in collection_names: - all_operators.extend(_list_all_operators(collection_name)) - return all_operators +from .op_collection import list_operator_collections, list_operators_by_collection diff --git a/tritonbench/operators_collection/op_collection.py b/tritonbench/operators_collection/op_collection.py new file mode 100644 index 00000000..65e67750 --- /dev/null +++ b/tritonbench/operators_collection/op_collection.py @@ -0,0 +1,71 @@ +import importlib +import pathlib +from typing import List + +OP_COLLECTION_PATH = "operators_collection" + + +def list_operator_collections() -> List[str]: + """ + List the available operator collections. + + This function retrieves the list of available operator collections by scanning the directories + in the current path that contain an "__init__.py" file. + + Returns: + List[str]: A list of names of the available operator collections. + """ + p = pathlib.Path(__file__).parent + # only load the directories that contain a "__init__.py" file + collection_paths = sorted( + str(child.absolute()) + for child in p.iterdir() + if child.is_dir() and child.joinpath("__init__.py").exists() + ) + filtered_collections = [pathlib.Path(path).name for path in collection_paths] + return filtered_collections + + +def list_operators_by_collection(op_collection: str = "default") -> List[str]: + """ + List the operators from the specified operator collections. + + This function retrieves the list of operators from the specified operator collections. + If the collection name is "all", it retrieves operators from all available collections. + If the collection name is not specified, it defaults to the "default" collection. + + Args: + op_collection (str): Names of the operator collections to list operators from. + It can be a single collection name or a comma-separated list of names. + Special value "all" retrieves operators from all collections. + + Returns: + List[str]: A list of operator names from the specified collection(s). + + Raises: + ModuleNotFoundError: If the specified collection module is not found. + AttributeError: If the specified collection module does not have a 'get_operators' function. 
+ """ + + def _list_all_operators(collection_name: str): + try: + module_name = f"..{collection_name}" + module = importlib.import_module(module_name, package=__name__) + if hasattr(module, "get_operators"): + return module.get_operators() + else: + raise AttributeError( + f"Module '{module_name}' does not have a 'get_operators' function" + ) + except ModuleNotFoundError: + raise ModuleNotFoundError(f"Module '{module_name}' not found") + + if op_collection == "all": + collection_names = list_operator_collections() + else: + collection_names = op_collection.split(",") + + all_operators = [] + for collection_name in collection_names: + all_operators.extend(_list_all_operators(collection_name)) + return all_operators diff --git a/tritonbench/utils/__init__.py b/tritonbench/utils/__init__.py new file mode 100644 index 00000000..3b76e765 --- /dev/null +++ b/tritonbench/utils/__init__.py @@ -0,0 +1,3 @@ +from .parser import get_parser + +parser = get_parser() diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index 38c6c037..d5cfb277 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -22,6 +22,7 @@ import tabulate import torch import triton + from tritonbench.components.ncu import analyzer as ncu_analyzer from tritonbench.utils.env_utils import ( apply_precision,