From e26bc150c24aa52b144296c39c9dc08a1479ac7c Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 26 Nov 2024 11:58:29 -0500 Subject: [PATCH 01/13] Add pyproject to install tritonbench --- tritonbench/pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 tritonbench/pyproject.toml diff --git a/tritonbench/pyproject.toml b/tritonbench/pyproject.toml new file mode 100644 index 00000000..d890fa04 --- /dev/null +++ b/tritonbench/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["torch", "triton"] + From 241ff1a71ad27233b154d3771ed8526b1b14d40e Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 26 Nov 2024 12:00:09 -0500 Subject: [PATCH 02/13] No build requires --- tritonbench/pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tritonbench/pyproject.toml b/tritonbench/pyproject.toml index d890fa04..024e9e6f 100644 --- a/tritonbench/pyproject.toml +++ b/tritonbench/pyproject.toml @@ -1,3 +1,2 @@ [build-system] -requires = ["torch", "triton"] - +requires = [] From 8d2ea1d49d6a430c3019483783c2df78709c935e Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 26 Nov 2024 12:01:57 -0500 Subject: [PATCH 03/13] Add pyproject --- tritonbench/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tritonbench/pyproject.toml b/tritonbench/pyproject.toml index 024e9e6f..186965cf 100644 --- a/tritonbench/pyproject.toml +++ b/tritonbench/pyproject.toml @@ -1,2 +1,2 @@ [build-system] -requires = [] +requires = ["setuptools>=40.8.0", "wheel"] From 2173b9536e80d04e85c4b43f3eceebf86a7bb386 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 26 Nov 2024 13:11:36 -0500 Subject: [PATCH 04/13] Install tritonbench package --- tritonbench/setup.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 tritonbench/setup.py diff --git a/tritonbench/setup.py b/tritonbench/setup.py new file mode 100644 index 00000000..1580f874 --- /dev/null +++ b/tritonbench/setup.py @@ -0,0 +1,18 @@ +import subprocess +from setuptools import setup + +def get_git_commit_hash(length=8): + try: + cmd = ['git', 'rev-parse', f'--short={length}', 'HEAD'] + return "+git{}".format(subprocess.check_output(cmd).strip().decode('utf-8')) + except Exception: + return "" + +setup( + name="tritonbench", + version="0.0.1" + get_git_commit_hash(), + author="Xu Zhao", + author_email="xzhao9@meta.com", + description="A benchmark suite for OpenAI Triton and TorchInductor", + long_description="", +) \ No newline at end of file From c8075f5f6a19200945e7d54d90eb6642c8168908 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 26 Nov 2024 13:33:27 -0500 Subject: [PATCH 05/13] Add package to install pyproject --- pyproject.toml | 15 +++++++++++++++ tritonbench/pyproject.toml | 2 -- tritonbench/setup.py | 18 ------------------ 3 files changed, 15 insertions(+), 20 deletions(-) delete mode 100644 tritonbench/pyproject.toml delete mode 100644 tritonbench/setup.py diff --git a/pyproject.toml b/pyproject.toml index c8630b05..d8cab09c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,18 @@ +[build-system] +requires = ["setuptools>=40.8.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "tritonbench" +version = "0.0.1" +dependencies = [ + "torch", + "triton", +] + +[tool.setuptools.packages.find] +include = ["tritonbench*"] + [tool.ufmt] formatter = "ruff-api" excludes = ["submodules/"] diff --git a/tritonbench/pyproject.toml b/tritonbench/pyproject.toml deleted file mode 100644 index 186965cf..00000000 --- a/tritonbench/pyproject.toml +++ /dev/null @@ -1,2 +0,0 @@ -[build-system] -requires = ["setuptools>=40.8.0", "wheel"] diff --git a/tritonbench/setup.py b/tritonbench/setup.py deleted file mode 100644 index 1580f874..00000000 --- a/tritonbench/setup.py +++ /dev/null @@ -1,18 +0,0 @@ -import subprocess -from setuptools import setup - -def get_git_commit_hash(length=8): - try: - cmd = ['git', 'rev-parse', f'--short={length}', 'HEAD'] - return "+git{}".format(subprocess.check_output(cmd).strip().decode('utf-8')) - except Exception: - return "" - -setup( - name="tritonbench", - version="0.0.1" + get_git_commit_hash(), - author="Xu Zhao", - author_email="xzhao9@meta.com", - description="A benchmark suite for OpenAI Triton and TorchInductor", - long_description="", -) \ No newline at end of file From 5b1ec59dce7ec249b10c496b4872b18a72ea96c8 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 26 Nov 2024 13:38:27 -0500 Subject: [PATCH 06/13] Add modules --- tritonbench/__init__.py | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 tritonbench/__init__.py diff --git a/tritonbench/__init__.py b/tritonbench/__init__.py new file mode 100644 index 00000000..a7ce9df4 --- /dev/null +++ b/tritonbench/__init__.py @@ -0,0 +1,3 @@ +from . import operator_loader +from . import operators +from . import operators_collection \ No newline at end of file From 38d470f821464af046b2327cbc585216db684c2d Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 26 Nov 2024 13:41:33 -0500 Subject: [PATCH 07/13] Add operators and loaders --- .gitignore | 2 +- tritonbench/operators/__init__.py | 79 +------------------------------ tritonbench/operators/op.py | 78 ++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 79 deletions(-) create mode 100644 tritonbench/operators/op.py diff --git a/.gitignore b/.gitignore index 8ca07fae..db8e0b86 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,4 @@ __pycache__/ .DS_Store .ipynb_checkpoints/ .idea -torchbench.egg-info/ +*.egg-info/ diff --git a/tritonbench/operators/__init__.py b/tritonbench/operators/__init__.py index 58a7b359..497b8ba0 100644 --- a/tritonbench/operators/__init__.py +++ b/tritonbench/operators/__init__.py @@ -1,78 +1 @@ -import importlib -import os -import pathlib -from typing import List - -OPBENCH_DIR = "operators" -INTERNAL_OPBENCH_DIR = "fb" - - -def _dir_contains_file(dir, file_name) -> bool: - names = map(lambda x: x.name, filter(lambda x: x.is_file(), dir.iterdir())) - return file_name in names - - -def _is_internal_operator(op_name: str) -> bool: - p = ( - pathlib.Path(__file__) - .parent.parent.joinpath(OPBENCH_DIR) - .joinpath(INTERNAL_OPBENCH_DIR) - .joinpath(op_name) - ) - if p.exists() and p.joinpath("__init__.py").exists(): - return True - return False - - -def _list_opbench_paths() -> List[str]: - p = pathlib.Path(__file__).parent.parent.joinpath(OPBENCH_DIR) - # Only load the model directories that contain a "__init.py__" file - opbench = sorted( - str(child.absolute()) - for child in p.iterdir() - if child.is_dir() and _dir_contains_file(child, "__init__.py") - ) - p = p.joinpath(INTERNAL_OPBENCH_DIR) - if p.exists(): - o = sorted( - str(child.absolute()) - for child in p.iterdir() - if child.is_dir() and _dir_contains_file(child, "__init__.py") - ) - opbench.extend(o) - return opbench - - -def list_operators() -> List[str]: - operators = list(map(lambda y: os.path.basename(y), _list_opbench_paths())) - if INTERNAL_OPBENCH_DIR in operators: - operators.remove(INTERNAL_OPBENCH_DIR) - return operators - - -def load_opbench_by_name(op_name: str): - opbench_list = filter( - lambda x: op_name.lower() == x.lower(), - map(lambda y: os.path.basename(y), _list_opbench_paths()), - ) - opbench_list = list(opbench_list) - if not opbench_list: - raise RuntimeError(f"{op_name} is not found in the Tritonbench operator list.") - assert ( - len(opbench_list) == 1 - ), f"Found more than one operators {opbench_list} matching the required name: {op_name}" - op_name = opbench_list[0] - op_pkg = ( - op_name - if not _is_internal_operator(op_name) - else f"{INTERNAL_OPBENCH_DIR}.{op_name}" - ) - module = importlib.import_module(f".{op_pkg}", package=__name__) - - Operator = getattr(module, "Operator", None) - if Operator is None: - print(f"Warning: {module} does not define attribute Operator, skip it") - return None - if not hasattr(Operator, "name"): - Operator.name = op_name - return Operator +from .op import load_opbench_by_name, list_operators \ No newline at end of file diff --git a/tritonbench/operators/op.py b/tritonbench/operators/op.py new file mode 100644 index 00000000..58a7b359 --- /dev/null +++ b/tritonbench/operators/op.py @@ -0,0 +1,78 @@ +import importlib +import os +import pathlib +from typing import List + +OPBENCH_DIR = "operators" +INTERNAL_OPBENCH_DIR = "fb" + + +def _dir_contains_file(dir, file_name) -> bool: + names = map(lambda x: x.name, filter(lambda x: x.is_file(), dir.iterdir())) + return file_name in names + + +def _is_internal_operator(op_name: str) -> bool: + p = ( + pathlib.Path(__file__) + .parent.parent.joinpath(OPBENCH_DIR) + .joinpath(INTERNAL_OPBENCH_DIR) + .joinpath(op_name) + ) + if p.exists() and p.joinpath("__init__.py").exists(): + return True + return False + + +def _list_opbench_paths() -> List[str]: + p = pathlib.Path(__file__).parent.parent.joinpath(OPBENCH_DIR) + # Only load the model directories that contain a "__init.py__" file + opbench = sorted( + str(child.absolute()) + for child in p.iterdir() + if child.is_dir() and _dir_contains_file(child, "__init__.py") + ) + p = p.joinpath(INTERNAL_OPBENCH_DIR) + if p.exists(): + o = sorted( + str(child.absolute()) + for child in p.iterdir() + if child.is_dir() and _dir_contains_file(child, "__init__.py") + ) + opbench.extend(o) + return opbench + + +def list_operators() -> List[str]: + operators = list(map(lambda y: os.path.basename(y), _list_opbench_paths())) + if INTERNAL_OPBENCH_DIR in operators: + operators.remove(INTERNAL_OPBENCH_DIR) + return operators + + +def load_opbench_by_name(op_name: str): + opbench_list = filter( + lambda x: op_name.lower() == x.lower(), + map(lambda y: os.path.basename(y), _list_opbench_paths()), + ) + opbench_list = list(opbench_list) + if not opbench_list: + raise RuntimeError(f"{op_name} is not found in the Tritonbench operator list.") + assert ( + len(opbench_list) == 1 + ), f"Found more than one operators {opbench_list} matching the required name: {op_name}" + op_name = opbench_list[0] + op_pkg = ( + op_name + if not _is_internal_operator(op_name) + else f"{INTERNAL_OPBENCH_DIR}.{op_name}" + ) + module = importlib.import_module(f".{op_pkg}", package=__name__) + + Operator = getattr(module, "Operator", None) + if Operator is None: + print(f"Warning: {module} does not define attribute Operator, skip it") + return None + if not hasattr(Operator, "name"): + Operator.name = op_name + return Operator From 3bbc161e097dc1e4f008a33b90f62ed662b5583b Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 26 Nov 2024 13:44:19 -0500 Subject: [PATCH 08/13] Add op collection and op to library --- tritonbench/__init__.py | 3 +- tritonbench/operators/__init__.py | 2 +- tritonbench/operators_collection/__init__.py | 72 +------------------ .../operators_collection/op_collection.py | 71 ++++++++++++++++++ 4 files changed, 74 insertions(+), 74 deletions(-) create mode 100644 tritonbench/operators_collection/op_collection.py diff --git a/tritonbench/__init__.py b/tritonbench/__init__.py index a7ce9df4..3186c98e 100644 --- a/tritonbench/__init__.py +++ b/tritonbench/__init__.py @@ -1,3 +1,2 @@ -from . import operator_loader from . import operators -from . import operators_collection \ No newline at end of file +from . import operators_collection diff --git a/tritonbench/operators/__init__.py b/tritonbench/operators/__init__.py index 497b8ba0..c2e499f4 100644 --- a/tritonbench/operators/__init__.py +++ b/tritonbench/operators/__init__.py @@ -1 +1 @@ -from .op import load_opbench_by_name, list_operators \ No newline at end of file +from .op import load_opbench_by_name, list_operators diff --git a/tritonbench/operators_collection/__init__.py b/tritonbench/operators_collection/__init__.py index 9a60deb7..5e6cc3cd 100644 --- a/tritonbench/operators_collection/__init__.py +++ b/tritonbench/operators_collection/__init__.py @@ -1,71 +1 @@ -import importlib -import pathlib -from typing import List - -OP_COLLECTION_PATH = "operators_collection" - - -def list_operator_collections() -> List[str]: - """ - List the available operator collections. - - This function retrieves the list of available operator collections by scanning the directories - in the current path that contain an "__init__.py" file. - - Returns: - List[str]: A list of names of the available operator collections. - """ - p = pathlib.Path(__file__).parent - # only load the directories that contain a "__init__.py" file - collection_paths = sorted( - str(child.absolute()) - for child in p.iterdir() - if child.is_dir() and child.joinpath("__init__.py").exists() - ) - filtered_collections = [pathlib.Path(path).name for path in collection_paths] - return filtered_collections - - -def list_operators_by_collection(op_collection: str = "default") -> List[str]: - """ - List the operators from the specified operator collections. - - This function retrieves the list of operators from the specified operator collections. - If the collection name is "all", it retrieves operators from all available collections. - If the collection name is not specified, it defaults to the "default" collection. - - Args: - op_collection (str): Names of the operator collections to list operators from. - It can be a single collection name or a comma-separated list of names. - Special value "all" retrieves operators from all collections. - - Returns: - List[str]: A list of operator names from the specified collection(s). - - Raises: - ModuleNotFoundError: If the specified collection module is not found. - AttributeError: If the specified collection module does not have a 'get_operators' function. - """ - - def _list_all_operators(collection_name: str): - try: - module_name = f".{collection_name}" - module = importlib.import_module(module_name, package=__name__) - if hasattr(module, "get_operators"): - return module.get_operators() - else: - raise AttributeError( - f"Module '{module_name}' does not have a 'get_operators' function" - ) - except ModuleNotFoundError: - raise ModuleNotFoundError(f"Module '{module_name}' not found") - - if op_collection == "all": - collection_names = list_operator_collections() - else: - collection_names = op_collection.split(",") - - all_operators = [] - for collection_name in collection_names: - all_operators.extend(_list_all_operators(collection_name)) - return all_operators +from .op_collection import list_operator_collections, list_operators_by_collection diff --git a/tritonbench/operators_collection/op_collection.py b/tritonbench/operators_collection/op_collection.py new file mode 100644 index 00000000..9a60deb7 --- /dev/null +++ b/tritonbench/operators_collection/op_collection.py @@ -0,0 +1,71 @@ +import importlib +import pathlib +from typing import List + +OP_COLLECTION_PATH = "operators_collection" + + +def list_operator_collections() -> List[str]: + """ + List the available operator collections. + + This function retrieves the list of available operator collections by scanning the directories + in the current path that contain an "__init__.py" file. + + Returns: + List[str]: A list of names of the available operator collections. + """ + p = pathlib.Path(__file__).parent + # only load the directories that contain a "__init__.py" file + collection_paths = sorted( + str(child.absolute()) + for child in p.iterdir() + if child.is_dir() and child.joinpath("__init__.py").exists() + ) + filtered_collections = [pathlib.Path(path).name for path in collection_paths] + return filtered_collections + + +def list_operators_by_collection(op_collection: str = "default") -> List[str]: + """ + List the operators from the specified operator collections. + + This function retrieves the list of operators from the specified operator collections. + If the collection name is "all", it retrieves operators from all available collections. + If the collection name is not specified, it defaults to the "default" collection. + + Args: + op_collection (str): Names of the operator collections to list operators from. + It can be a single collection name or a comma-separated list of names. + Special value "all" retrieves operators from all collections. + + Returns: + List[str]: A list of operator names from the specified collection(s). + + Raises: + ModuleNotFoundError: If the specified collection module is not found. + AttributeError: If the specified collection module does not have a 'get_operators' function. + """ + + def _list_all_operators(collection_name: str): + try: + module_name = f".{collection_name}" + module = importlib.import_module(module_name, package=__name__) + if hasattr(module, "get_operators"): + return module.get_operators() + else: + raise AttributeError( + f"Module '{module_name}' does not have a 'get_operators' function" + ) + except ModuleNotFoundError: + raise ModuleNotFoundError(f"Module '{module_name}' not found") + + if op_collection == "all": + collection_names = list_operator_collections() + else: + collection_names = op_collection.split(",") + + all_operators = [] + for collection_name in collection_names: + all_operators.extend(_list_all_operators(collection_name)) + return all_operators From 6e6df109440447dd2670777fa784712c8c9ad26c Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 26 Nov 2024 13:46:07 -0500 Subject: [PATCH 09/13] Change org --- tritonbench/operators/op.py | 2 +- tritonbench/operators_collection/op_collection.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tritonbench/operators/op.py b/tritonbench/operators/op.py index 58a7b359..edcc3c45 100644 --- a/tritonbench/operators/op.py +++ b/tritonbench/operators/op.py @@ -67,7 +67,7 @@ def load_opbench_by_name(op_name: str): if not _is_internal_operator(op_name) else f"{INTERNAL_OPBENCH_DIR}.{op_name}" ) - module = importlib.import_module(f".{op_pkg}", package=__name__) + module = importlib.import_module(f"..{op_pkg}", package=__name__) Operator = getattr(module, "Operator", None) if Operator is None: diff --git a/tritonbench/operators_collection/op_collection.py b/tritonbench/operators_collection/op_collection.py index 9a60deb7..65e67750 100644 --- a/tritonbench/operators_collection/op_collection.py +++ b/tritonbench/operators_collection/op_collection.py @@ -49,7 +49,7 @@ def list_operators_by_collection(op_collection: str = "default") -> List[str]: def _list_all_operators(collection_name: str): try: - module_name = f".{collection_name}" + module_name = f"..{collection_name}" module = importlib.import_module(module_name, package=__name__) if hasattr(module, "get_operators"): return module.get_operators() From 164849db07d894a4d457c916abac93fbe86f1bf1 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 26 Nov 2024 13:49:58 -0500 Subject: [PATCH 10/13] Load operators --- tritonbench/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tritonbench/__init__.py b/tritonbench/__init__.py index 3186c98e..617daa4e 100644 --- a/tritonbench/__init__.py +++ b/tritonbench/__init__.py @@ -1,2 +1,2 @@ -from . import operators -from . import operators_collection +from .operators import list_operators, load_opbench_by_name +from .operators_collection import list_operator_collections, list_operators_by_collection From 48eb3a42b7aaeaf23272f96e9ba3ecf6623aa08f Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 26 Nov 2024 13:53:05 -0500 Subject: [PATCH 11/13] Linting --- tritonbench/__init__.py | 5 ++++- tritonbench/operators/__init__.py | 2 +- tritonbench/operators/addmm/hstu.py | 1 + tritonbench/operators/fp8_gemm_rowwise/operator.py | 1 + tritonbench/operators/geglu/operator.py | 1 + tritonbench/operators/gemm/operator.py | 1 + tritonbench/operators/jagged_layer_norm/operator.py | 1 + tritonbench/operators/jagged_mean/operator.py | 1 + tritonbench/operators/jagged_softmax/operator.py | 1 + tritonbench/operators/jagged_sum/operator.py | 1 + tritonbench/operators/layer_norm/operator.py | 1 + tritonbench/operators/op_task.py | 1 + tritonbench/operators/ragged_attention/hstu.py | 1 + tritonbench/operators/ragged_attention/operator.py | 1 + tritonbench/operators/softmax/operator.py | 1 + tritonbench/operators/sum/operator.py | 1 + 16 files changed, 19 insertions(+), 2 deletions(-) diff --git a/tritonbench/__init__.py b/tritonbench/__init__.py index 617daa4e..858fb44e 100644 --- a/tritonbench/__init__.py +++ b/tritonbench/__init__.py @@ -1,2 +1,5 @@ from .operators import list_operators, load_opbench_by_name -from .operators_collection import list_operator_collections, list_operators_by_collection +from .operators_collection import ( + list_operator_collections, + list_operators_by_collection, +) diff --git a/tritonbench/operators/__init__.py b/tritonbench/operators/__init__.py index c2e499f4..38c91993 100644 --- a/tritonbench/operators/__init__.py +++ b/tritonbench/operators/__init__.py @@ -1 +1 @@ -from .op import load_opbench_by_name, list_operators +from .op import list_operators, load_opbench_by_name diff --git a/tritonbench/operators/addmm/hstu.py b/tritonbench/operators/addmm/hstu.py index ddd08073..2aeed9f8 100644 --- a/tritonbench/operators/addmm/hstu.py +++ b/tritonbench/operators/addmm/hstu.py @@ -4,6 +4,7 @@ import torch import triton + from tritonbench.utils.path_utils import add_path, SUBMODULE_PATH with add_path(str(SUBMODULE_PATH.joinpath("generative-recommenders"))): diff --git a/tritonbench/operators/fp8_gemm_rowwise/operator.py b/tritonbench/operators/fp8_gemm_rowwise/operator.py index cae6fd34..53bf30b0 100644 --- a/tritonbench/operators/fp8_gemm_rowwise/operator.py +++ b/tritonbench/operators/fp8_gemm_rowwise/operator.py @@ -6,6 +6,7 @@ import torch import triton + from tritonbench.utils.data_utils import get_production_shapes from tritonbench.utils.triton_op import ( diff --git a/tritonbench/operators/geglu/operator.py b/tritonbench/operators/geglu/operator.py index 237f850d..9613bc96 100644 --- a/tritonbench/operators/geglu/operator.py +++ b/tritonbench/operators/geglu/operator.py @@ -5,6 +5,7 @@ from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaMLP + from tritonbench.utils.triton_op import ( BenchmarkOperator, register_benchmark, diff --git a/tritonbench/operators/gemm/operator.py b/tritonbench/operators/gemm/operator.py index 2c5bf1f1..d02c4b48 100644 --- a/tritonbench/operators/gemm/operator.py +++ b/tritonbench/operators/gemm/operator.py @@ -6,6 +6,7 @@ import torch import torch._inductor.config as inductor_config import triton + from tritonbench.utils.data_utils import get_production_shapes from tritonbench.utils.path_utils import REPO_PATH diff --git a/tritonbench/operators/jagged_layer_norm/operator.py b/tritonbench/operators/jagged_layer_norm/operator.py index 63dc3a6e..d34f4896 100644 --- a/tritonbench/operators/jagged_layer_norm/operator.py +++ b/tritonbench/operators/jagged_layer_norm/operator.py @@ -7,6 +7,7 @@ import torch import triton + from tritonbench.utils.jagged_utils import ( ABSOLUTE_TOLERANCE, EPSILON, diff --git a/tritonbench/operators/jagged_mean/operator.py b/tritonbench/operators/jagged_mean/operator.py index d5a82269..c9c975d2 100644 --- a/tritonbench/operators/jagged_mean/operator.py +++ b/tritonbench/operators/jagged_mean/operator.py @@ -7,6 +7,7 @@ import torch import triton + from tritonbench.utils.jagged_utils import ( ABSOLUTE_TOLERANCE, generate_input_vals, diff --git a/tritonbench/operators/jagged_softmax/operator.py b/tritonbench/operators/jagged_softmax/operator.py index aad5a41c..cc3284c0 100644 --- a/tritonbench/operators/jagged_softmax/operator.py +++ b/tritonbench/operators/jagged_softmax/operator.py @@ -7,6 +7,7 @@ import torch import triton + from tritonbench.utils.jagged_utils import ( ABSOLUTE_TOLERANCE, generate_input_vals, diff --git a/tritonbench/operators/jagged_sum/operator.py b/tritonbench/operators/jagged_sum/operator.py index c531186c..36d519db 100644 --- a/tritonbench/operators/jagged_sum/operator.py +++ b/tritonbench/operators/jagged_sum/operator.py @@ -7,6 +7,7 @@ import torch import triton + from tritonbench.utils.jagged_utils import ( ABSOLUTE_TOLERANCE, generate_input_vals, diff --git a/tritonbench/operators/layer_norm/operator.py b/tritonbench/operators/layer_norm/operator.py index d2e1df29..6627697c 100644 --- a/tritonbench/operators/layer_norm/operator.py +++ b/tritonbench/operators/layer_norm/operator.py @@ -3,6 +3,7 @@ import torch import torch.nn.functional as F import triton + from tritonbench.utils.triton_op import ( BenchmarkOperator, BenchmarkOperatorMetrics, diff --git a/tritonbench/operators/op_task.py b/tritonbench/operators/op_task.py index fc358ceb..413025b6 100644 --- a/tritonbench/operators/op_task.py +++ b/tritonbench/operators/op_task.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional import torch + from tritonbench.components.tasks import base as base_task from tritonbench.components.workers import subprocess_worker diff --git a/tritonbench/operators/ragged_attention/hstu.py b/tritonbench/operators/ragged_attention/hstu.py index c6dd4010..025bad0d 100644 --- a/tritonbench/operators/ragged_attention/hstu.py +++ b/tritonbench/operators/ragged_attention/hstu.py @@ -1,5 +1,6 @@ import torch import triton + from tritonbench.utils.path_utils import add_path, SUBMODULE_PATH from tritonbench.utils.triton_op import IS_FBCODE diff --git a/tritonbench/operators/ragged_attention/operator.py b/tritonbench/operators/ragged_attention/operator.py index 1beb351c..52db5482 100644 --- a/tritonbench/operators/ragged_attention/operator.py +++ b/tritonbench/operators/ragged_attention/operator.py @@ -3,6 +3,7 @@ from typing import Any, Callable, List, Optional import torch + from tritonbench.utils.input import input_filter from tritonbench.utils.triton_op import ( diff --git a/tritonbench/operators/softmax/operator.py b/tritonbench/operators/softmax/operator.py index 08840dd1..b079eab6 100644 --- a/tritonbench/operators/softmax/operator.py +++ b/tritonbench/operators/softmax/operator.py @@ -3,6 +3,7 @@ import torch import triton import triton.language as tl + from tritonbench.utils.data_utils import get_production_shapes from tritonbench.utils.triton_op import ( diff --git a/tritonbench/operators/sum/operator.py b/tritonbench/operators/sum/operator.py index ffdc790d..7f60c244 100644 --- a/tritonbench/operators/sum/operator.py +++ b/tritonbench/operators/sum/operator.py @@ -7,6 +7,7 @@ import torch import triton import triton.language as tl + from tritonbench.utils.triton_op import ( BenchmarkOperator, BenchmarkOperatorMetrics, From 0e64ae1fe47fa28f79412312be2931a369ec7ddc Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 26 Nov 2024 14:07:07 -0500 Subject: [PATCH 12/13] Add demo script --- README.md | 16 +++++++++++++++- tritonbench/utils/__init__.py | 3 +++ 2 files changed, 18 insertions(+), 1 deletion(-) create mode 100644 tritonbench/utils/__init__.py diff --git a/README.md b/README.md index 160ec7dd..ecc816f6 100644 --- a/README.md +++ b/README.md @@ -26,12 +26,26 @@ By default, it will install the latest PyTorch nightly release and use the Trito ## Basic Usage -To benchmark an operator, use the following command: +To benchmark an operator, run the following command: ``` $ python run.py --op gemm ``` +## Install as a library + +To install as a library: + +``` +$ pip install -e . +# in your own benchmark script +import tritonbench +from tritonbench.utils import parser +op_args = parser.parse_args() +addmm_bench = tritonbench.load_opbench_by_name("addmm")(op_args) +addmm_bench.run() +``` + ## Submodules We depend on the following projects as a source of customized Triton or CUTLASS kernels: diff --git a/tritonbench/utils/__init__.py b/tritonbench/utils/__init__.py new file mode 100644 index 00000000..672c3f1d --- /dev/null +++ b/tritonbench/utils/__init__.py @@ -0,0 +1,3 @@ +from .parser import get_parser + +parser = get_parser() \ No newline at end of file From 11eb5debe5a6a3c697d3c734fe46de0b2ef44e1b Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 26 Nov 2024 14:08:10 -0500 Subject: [PATCH 13/13] Load parser --- tritonbench/utils/__init__.py | 2 +- tritonbench/utils/triton_op.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tritonbench/utils/__init__.py b/tritonbench/utils/__init__.py index 672c3f1d..3b76e765 100644 --- a/tritonbench/utils/__init__.py +++ b/tritonbench/utils/__init__.py @@ -1,3 +1,3 @@ from .parser import get_parser -parser = get_parser() \ No newline at end of file +parser = get_parser() diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index 38c6c037..d5cfb277 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -22,6 +22,7 @@ import tabulate import torch import triton + from tritonbench.components.ncu import analyzer as ncu_analyzer from tritonbench.utils.env_utils import ( apply_precision,