-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Install tritonbench as a library (#81)
Summary: Now users can import Tritonbench as a library: ``` $ pip install -e . $ python -c "import tritonbench; op = tritonbench.load_opbench_by_name('addmm');" ``` Clean up the init file so that dependencies like `os` and `importlib` will not pollute the namespace. Pull Request resolved: #81 Reviewed By: FindHao Differential Revision: D66544679 Pulled By: xuzhao9 fbshipit-source-id: 0339adef58d28f2c7f207a2869c21d4b55575386
- Loading branch information
1 parent
c666f87
commit bec6d9f
Showing
24 changed files
with
205 additions
and
151 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,4 +13,4 @@ __pycache__/ | |
.DS_Store | ||
.ipynb_checkpoints/ | ||
.idea | ||
torchbench.egg-info/ | ||
*.egg-info/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .operators import list_operators, load_opbench_by_name | ||
from .operators_collection import ( | ||
list_operator_collections, | ||
list_operators_by_collection, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,78 +1 @@ | ||
import importlib | ||
import os | ||
import pathlib | ||
from typing import List | ||
|
||
OPBENCH_DIR = "operators" | ||
INTERNAL_OPBENCH_DIR = "fb" | ||
|
||
|
||
def _dir_contains_file(dir, file_name) -> bool: | ||
names = map(lambda x: x.name, filter(lambda x: x.is_file(), dir.iterdir())) | ||
return file_name in names | ||
|
||
|
||
def _is_internal_operator(op_name: str) -> bool: | ||
p = ( | ||
pathlib.Path(__file__) | ||
.parent.parent.joinpath(OPBENCH_DIR) | ||
.joinpath(INTERNAL_OPBENCH_DIR) | ||
.joinpath(op_name) | ||
) | ||
if p.exists() and p.joinpath("__init__.py").exists(): | ||
return True | ||
return False | ||
|
||
|
||
def _list_opbench_paths() -> List[str]: | ||
p = pathlib.Path(__file__).parent.parent.joinpath(OPBENCH_DIR) | ||
# Only load the model directories that contain a "__init.py__" file | ||
opbench = sorted( | ||
str(child.absolute()) | ||
for child in p.iterdir() | ||
if child.is_dir() and _dir_contains_file(child, "__init__.py") | ||
) | ||
p = p.joinpath(INTERNAL_OPBENCH_DIR) | ||
if p.exists(): | ||
o = sorted( | ||
str(child.absolute()) | ||
for child in p.iterdir() | ||
if child.is_dir() and _dir_contains_file(child, "__init__.py") | ||
) | ||
opbench.extend(o) | ||
return opbench | ||
|
||
|
||
def list_operators() -> List[str]: | ||
operators = list(map(lambda y: os.path.basename(y), _list_opbench_paths())) | ||
if INTERNAL_OPBENCH_DIR in operators: | ||
operators.remove(INTERNAL_OPBENCH_DIR) | ||
return operators | ||
|
||
|
||
def load_opbench_by_name(op_name: str): | ||
opbench_list = filter( | ||
lambda x: op_name.lower() == x.lower(), | ||
map(lambda y: os.path.basename(y), _list_opbench_paths()), | ||
) | ||
opbench_list = list(opbench_list) | ||
if not opbench_list: | ||
raise RuntimeError(f"{op_name} is not found in the Tritonbench operator list.") | ||
assert ( | ||
len(opbench_list) == 1 | ||
), f"Found more than one operators {opbench_list} matching the required name: {op_name}" | ||
op_name = opbench_list[0] | ||
op_pkg = ( | ||
op_name | ||
if not _is_internal_operator(op_name) | ||
else f"{INTERNAL_OPBENCH_DIR}.{op_name}" | ||
) | ||
module = importlib.import_module(f".{op_pkg}", package=__name__) | ||
|
||
Operator = getattr(module, "Operator", None) | ||
if Operator is None: | ||
print(f"Warning: {module} does not define attribute Operator, skip it") | ||
return None | ||
if not hasattr(Operator, "name"): | ||
Operator.name = op_name | ||
return Operator | ||
from .op import list_operators, load_opbench_by_name |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
|
||
import torch | ||
import triton | ||
|
||
from tritonbench.utils.jagged_utils import ( | ||
ABSOLUTE_TOLERANCE, | ||
EPSILON, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import importlib | ||
import os | ||
import pathlib | ||
from typing import List | ||
|
||
OPBENCH_DIR = "operators" | ||
INTERNAL_OPBENCH_DIR = "fb" | ||
|
||
|
||
def _dir_contains_file(dir, file_name) -> bool: | ||
names = map(lambda x: x.name, filter(lambda x: x.is_file(), dir.iterdir())) | ||
return file_name in names | ||
|
||
|
||
def _is_internal_operator(op_name: str) -> bool: | ||
p = ( | ||
pathlib.Path(__file__) | ||
.parent.parent.joinpath(OPBENCH_DIR) | ||
.joinpath(INTERNAL_OPBENCH_DIR) | ||
.joinpath(op_name) | ||
) | ||
if p.exists() and p.joinpath("__init__.py").exists(): | ||
return True | ||
return False | ||
|
||
|
||
def _list_opbench_paths() -> List[str]: | ||
p = pathlib.Path(__file__).parent.parent.joinpath(OPBENCH_DIR) | ||
# Only load the model directories that contain a "__init.py__" file | ||
opbench = sorted( | ||
str(child.absolute()) | ||
for child in p.iterdir() | ||
if child.is_dir() and _dir_contains_file(child, "__init__.py") | ||
) | ||
p = p.joinpath(INTERNAL_OPBENCH_DIR) | ||
if p.exists(): | ||
o = sorted( | ||
str(child.absolute()) | ||
for child in p.iterdir() | ||
if child.is_dir() and _dir_contains_file(child, "__init__.py") | ||
) | ||
opbench.extend(o) | ||
return opbench | ||
|
||
|
||
def list_operators() -> List[str]: | ||
operators = list(map(lambda y: os.path.basename(y), _list_opbench_paths())) | ||
if INTERNAL_OPBENCH_DIR in operators: | ||
operators.remove(INTERNAL_OPBENCH_DIR) | ||
return operators | ||
|
||
|
||
def load_opbench_by_name(op_name: str): | ||
opbench_list = filter( | ||
lambda x: op_name.lower() == x.lower(), | ||
map(lambda y: os.path.basename(y), _list_opbench_paths()), | ||
) | ||
opbench_list = list(opbench_list) | ||
if not opbench_list: | ||
raise RuntimeError(f"{op_name} is not found in the Tritonbench operator list.") | ||
assert ( | ||
len(opbench_list) == 1 | ||
), f"Found more than one operators {opbench_list} matching the required name: {op_name}" | ||
op_name = opbench_list[0] | ||
op_pkg = ( | ||
op_name | ||
if not _is_internal_operator(op_name) | ||
else f"{INTERNAL_OPBENCH_DIR}.{op_name}" | ||
) | ||
module = importlib.import_module(f"..{op_pkg}", package=__name__) | ||
|
||
Operator = getattr(module, "Operator", None) | ||
if Operator is None: | ||
print(f"Warning: {module} does not define attribute Operator, skip it") | ||
return None | ||
if not hasattr(Operator, "name"): | ||
Operator.name = op_name | ||
return Operator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.