From 3bbc161e097dc1e4f008a33b90f62ed662b5583b Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 26 Nov 2024 13:44:19 -0500 Subject: [PATCH] Add op collection and op to library --- tritonbench/__init__.py | 3 +- tritonbench/operators/__init__.py | 2 +- tritonbench/operators_collection/__init__.py | 72 +------------------ .../operators_collection/op_collection.py | 71 ++++++++++++++++++ 4 files changed, 74 insertions(+), 74 deletions(-) create mode 100644 tritonbench/operators_collection/op_collection.py diff --git a/tritonbench/__init__.py b/tritonbench/__init__.py index a7ce9df4..3186c98e 100644 --- a/tritonbench/__init__.py +++ b/tritonbench/__init__.py @@ -1,3 +1,2 @@ -from . import operator_loader from . import operators -from . import operators_collection \ No newline at end of file +from . import operators_collection diff --git a/tritonbench/operators/__init__.py b/tritonbench/operators/__init__.py index 497b8ba0..c2e499f4 100644 --- a/tritonbench/operators/__init__.py +++ b/tritonbench/operators/__init__.py @@ -1 +1 @@ -from .op import load_opbench_by_name, list_operators \ No newline at end of file +from .op import load_opbench_by_name, list_operators diff --git a/tritonbench/operators_collection/__init__.py b/tritonbench/operators_collection/__init__.py index 9a60deb7..5e6cc3cd 100644 --- a/tritonbench/operators_collection/__init__.py +++ b/tritonbench/operators_collection/__init__.py @@ -1,71 +1 @@ -import importlib -import pathlib -from typing import List - -OP_COLLECTION_PATH = "operators_collection" - - -def list_operator_collections() -> List[str]: - """ - List the available operator collections. - - This function retrieves the list of available operator collections by scanning the directories - in the current path that contain an "__init__.py" file. - - Returns: - List[str]: A list of names of the available operator collections. - """ - p = pathlib.Path(__file__).parent - # only load the directories that contain a "__init__.py" file - collection_paths = sorted( - str(child.absolute()) - for child in p.iterdir() - if child.is_dir() and child.joinpath("__init__.py").exists() - ) - filtered_collections = [pathlib.Path(path).name for path in collection_paths] - return filtered_collections - - -def list_operators_by_collection(op_collection: str = "default") -> List[str]: - """ - List the operators from the specified operator collections. - - This function retrieves the list of operators from the specified operator collections. - If the collection name is "all", it retrieves operators from all available collections. - If the collection name is not specified, it defaults to the "default" collection. - - Args: - op_collection (str): Names of the operator collections to list operators from. - It can be a single collection name or a comma-separated list of names. - Special value "all" retrieves operators from all collections. - - Returns: - List[str]: A list of operator names from the specified collection(s). - - Raises: - ModuleNotFoundError: If the specified collection module is not found. - AttributeError: If the specified collection module does not have a 'get_operators' function. - """ - - def _list_all_operators(collection_name: str): - try: - module_name = f".{collection_name}" - module = importlib.import_module(module_name, package=__name__) - if hasattr(module, "get_operators"): - return module.get_operators() - else: - raise AttributeError( - f"Module '{module_name}' does not have a 'get_operators' function" - ) - except ModuleNotFoundError: - raise ModuleNotFoundError(f"Module '{module_name}' not found") - - if op_collection == "all": - collection_names = list_operator_collections() - else: - collection_names = op_collection.split(",") - - all_operators = [] - for collection_name in collection_names: - all_operators.extend(_list_all_operators(collection_name)) - return all_operators +from .op_collection import list_operator_collections, list_operators_by_collection diff --git a/tritonbench/operators_collection/op_collection.py b/tritonbench/operators_collection/op_collection.py new file mode 100644 index 00000000..9a60deb7 --- /dev/null +++ b/tritonbench/operators_collection/op_collection.py @@ -0,0 +1,71 @@ +import importlib +import pathlib +from typing import List + +OP_COLLECTION_PATH = "operators_collection" + + +def list_operator_collections() -> List[str]: + """ + List the available operator collections. + + This function retrieves the list of available operator collections by scanning the directories + in the current path that contain an "__init__.py" file. + + Returns: + List[str]: A list of names of the available operator collections. + """ + p = pathlib.Path(__file__).parent + # only load the directories that contain a "__init__.py" file + collection_paths = sorted( + str(child.absolute()) + for child in p.iterdir() + if child.is_dir() and child.joinpath("__init__.py").exists() + ) + filtered_collections = [pathlib.Path(path).name for path in collection_paths] + return filtered_collections + + +def list_operators_by_collection(op_collection: str = "default") -> List[str]: + """ + List the operators from the specified operator collections. + + This function retrieves the list of operators from the specified operator collections. + If the collection name is "all", it retrieves operators from all available collections. + If the collection name is not specified, it defaults to the "default" collection. + + Args: + op_collection (str): Names of the operator collections to list operators from. + It can be a single collection name or a comma-separated list of names. + Special value "all" retrieves operators from all collections. + + Returns: + List[str]: A list of operator names from the specified collection(s). + + Raises: + ModuleNotFoundError: If the specified collection module is not found. + AttributeError: If the specified collection module does not have a 'get_operators' function. + """ + + def _list_all_operators(collection_name: str): + try: + module_name = f".{collection_name}" + module = importlib.import_module(module_name, package=__name__) + if hasattr(module, "get_operators"): + return module.get_operators() + else: + raise AttributeError( + f"Module '{module_name}' does not have a 'get_operators' function" + ) + except ModuleNotFoundError: + raise ModuleNotFoundError(f"Module '{module_name}' not found") + + if op_collection == "all": + collection_names = list_operator_collections() + else: + collection_names = op_collection.split(",") + + all_operators = [] + for collection_name in collection_names: + all_operators.extend(_list_all_operators(collection_name)) + return all_operators