Skip to content

Commit

Permalink
Add op collection and op to library
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Nov 26, 2024
1 parent 38d470f commit 3bbc161
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 74 deletions.
3 changes: 1 addition & 2 deletions tritonbench/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from . import operator_loader
from . import operators
from . import operators_collection
from . import operators_collection
2 changes: 1 addition & 1 deletion tritonbench/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .op import load_opbench_by_name, list_operators
from .op import load_opbench_by_name, list_operators
72 changes: 1 addition & 71 deletions tritonbench/operators_collection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,71 +1 @@
import importlib
import pathlib
from typing import List

OP_COLLECTION_PATH = "operators_collection"


def list_operator_collections() -> List[str]:
"""
List the available operator collections.
This function retrieves the list of available operator collections by scanning the directories
in the current path that contain an "__init__.py" file.
Returns:
List[str]: A list of names of the available operator collections.
"""
p = pathlib.Path(__file__).parent
# only load the directories that contain a "__init__.py" file
collection_paths = sorted(
str(child.absolute())
for child in p.iterdir()
if child.is_dir() and child.joinpath("__init__.py").exists()
)
filtered_collections = [pathlib.Path(path).name for path in collection_paths]
return filtered_collections


def list_operators_by_collection(op_collection: str = "default") -> List[str]:
"""
List the operators from the specified operator collections.
This function retrieves the list of operators from the specified operator collections.
If the collection name is "all", it retrieves operators from all available collections.
If the collection name is not specified, it defaults to the "default" collection.
Args:
op_collection (str): Names of the operator collections to list operators from.
It can be a single collection name or a comma-separated list of names.
Special value "all" retrieves operators from all collections.
Returns:
List[str]: A list of operator names from the specified collection(s).
Raises:
ModuleNotFoundError: If the specified collection module is not found.
AttributeError: If the specified collection module does not have a 'get_operators' function.
"""

def _list_all_operators(collection_name: str):
try:
module_name = f".{collection_name}"
module = importlib.import_module(module_name, package=__name__)
if hasattr(module, "get_operators"):
return module.get_operators()
else:
raise AttributeError(
f"Module '{module_name}' does not have a 'get_operators' function"
)
except ModuleNotFoundError:
raise ModuleNotFoundError(f"Module '{module_name}' not found")

if op_collection == "all":
collection_names = list_operator_collections()
else:
collection_names = op_collection.split(",")

all_operators = []
for collection_name in collection_names:
all_operators.extend(_list_all_operators(collection_name))
return all_operators
from .op_collection import list_operator_collections, list_operators_by_collection
71 changes: 71 additions & 0 deletions tritonbench/operators_collection/op_collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import importlib
import pathlib
from typing import List

OP_COLLECTION_PATH = "operators_collection"


def list_operator_collections() -> List[str]:
"""
List the available operator collections.
This function retrieves the list of available operator collections by scanning the directories
in the current path that contain an "__init__.py" file.
Returns:
List[str]: A list of names of the available operator collections.
"""
p = pathlib.Path(__file__).parent
# only load the directories that contain a "__init__.py" file
collection_paths = sorted(
str(child.absolute())
for child in p.iterdir()
if child.is_dir() and child.joinpath("__init__.py").exists()
)
filtered_collections = [pathlib.Path(path).name for path in collection_paths]
return filtered_collections


def list_operators_by_collection(op_collection: str = "default") -> List[str]:
"""
List the operators from the specified operator collections.
This function retrieves the list of operators from the specified operator collections.
If the collection name is "all", it retrieves operators from all available collections.
If the collection name is not specified, it defaults to the "default" collection.
Args:
op_collection (str): Names of the operator collections to list operators from.
It can be a single collection name or a comma-separated list of names.
Special value "all" retrieves operators from all collections.
Returns:
List[str]: A list of operator names from the specified collection(s).
Raises:
ModuleNotFoundError: If the specified collection module is not found.
AttributeError: If the specified collection module does not have a 'get_operators' function.
"""

def _list_all_operators(collection_name: str):
try:
module_name = f".{collection_name}"
module = importlib.import_module(module_name, package=__name__)
if hasattr(module, "get_operators"):
return module.get_operators()
else:
raise AttributeError(
f"Module '{module_name}' does not have a 'get_operators' function"
)
except ModuleNotFoundError:
raise ModuleNotFoundError(f"Module '{module_name}' not found")

if op_collection == "all":
collection_names = list_operator_collections()
else:
collection_names = op_collection.split(",")

all_operators = []
for collection_name in collection_names:
all_operators.extend(_list_all_operators(collection_name))
return all_operators

0 comments on commit 3bbc161

Please sign in to comment.