Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Install tritonbench as a library #81

Closed
wants to merge 13 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ __pycache__/
.DS_Store
.ipynb_checkpoints/
.idea
torchbench.egg-info/
*.egg-info/
16 changes: 15 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,26 @@ By default, it will install the latest PyTorch nightly release and use the Trito

## Basic Usage

To benchmark an operator, use the following command:
To benchmark an operator, run the following command:

```
$ python run.py --op gemm
```

## Install as a library

To install as a library:

```
$ pip install -e .
# in your own benchmark script
import tritonbench
from tritonbench.utils import parser
op_args = parser.parse_args()
addmm_bench = tritonbench.load_opbench_by_name("addmm")(op_args)
addmm_bench.run()
```

## Submodules

We depend on the following projects as a source of customized Triton or CUTLASS kernels:
Expand Down
15 changes: 15 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
[build-system]
requires = ["setuptools>=40.8.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "tritonbench"
version = "0.0.1"
dependencies = [
"torch",
"triton",
]

[tool.setuptools.packages.find]
include = ["tritonbench*"]

[tool.ufmt]
formatter = "ruff-api"
excludes = ["submodules/"]
Expand Down
5 changes: 5 additions & 0 deletions tritonbench/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .operators import list_operators, load_opbench_by_name
from .operators_collection import (
list_operator_collections,
list_operators_by_collection,
)
79 changes: 1 addition & 78 deletions tritonbench/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,78 +1 @@
import importlib
import os
import pathlib
from typing import List

OPBENCH_DIR = "operators"
INTERNAL_OPBENCH_DIR = "fb"


def _dir_contains_file(dir, file_name) -> bool:
names = map(lambda x: x.name, filter(lambda x: x.is_file(), dir.iterdir()))
return file_name in names


def _is_internal_operator(op_name: str) -> bool:
p = (
pathlib.Path(__file__)
.parent.parent.joinpath(OPBENCH_DIR)
.joinpath(INTERNAL_OPBENCH_DIR)
.joinpath(op_name)
)
if p.exists() and p.joinpath("__init__.py").exists():
return True
return False


def _list_opbench_paths() -> List[str]:
p = pathlib.Path(__file__).parent.parent.joinpath(OPBENCH_DIR)
# Only load the model directories that contain a "__init.py__" file
opbench = sorted(
str(child.absolute())
for child in p.iterdir()
if child.is_dir() and _dir_contains_file(child, "__init__.py")
)
p = p.joinpath(INTERNAL_OPBENCH_DIR)
if p.exists():
o = sorted(
str(child.absolute())
for child in p.iterdir()
if child.is_dir() and _dir_contains_file(child, "__init__.py")
)
opbench.extend(o)
return opbench


def list_operators() -> List[str]:
operators = list(map(lambda y: os.path.basename(y), _list_opbench_paths()))
if INTERNAL_OPBENCH_DIR in operators:
operators.remove(INTERNAL_OPBENCH_DIR)
return operators


def load_opbench_by_name(op_name: str):
opbench_list = filter(
lambda x: op_name.lower() == x.lower(),
map(lambda y: os.path.basename(y), _list_opbench_paths()),
)
opbench_list = list(opbench_list)
if not opbench_list:
raise RuntimeError(f"{op_name} is not found in the Tritonbench operator list.")
assert (
len(opbench_list) == 1
), f"Found more than one operators {opbench_list} matching the required name: {op_name}"
op_name = opbench_list[0]
op_pkg = (
op_name
if not _is_internal_operator(op_name)
else f"{INTERNAL_OPBENCH_DIR}.{op_name}"
)
module = importlib.import_module(f".{op_pkg}", package=__name__)

Operator = getattr(module, "Operator", None)
if Operator is None:
print(f"Warning: {module} does not define attribute Operator, skip it")
return None
if not hasattr(Operator, "name"):
Operator.name = op_name
return Operator
from .op import list_operators, load_opbench_by_name
1 change: 1 addition & 0 deletions tritonbench/operators/addmm/hstu.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
import triton

from tritonbench.utils.path_utils import add_path, SUBMODULE_PATH

with add_path(str(SUBMODULE_PATH.joinpath("generative-recommenders"))):
Expand Down
1 change: 1 addition & 0 deletions tritonbench/operators/fp8_gemm_rowwise/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import torch
import triton

from tritonbench.utils.data_utils import get_production_shapes

from tritonbench.utils.triton_op import (
Expand Down
1 change: 1 addition & 0 deletions tritonbench/operators/geglu/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaMLP

from tritonbench.utils.triton_op import (
BenchmarkOperator,
register_benchmark,
Expand Down
1 change: 1 addition & 0 deletions tritonbench/operators/gemm/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
import torch._inductor.config as inductor_config
import triton

from tritonbench.utils.data_utils import get_production_shapes

from tritonbench.utils.path_utils import REPO_PATH
Expand Down
1 change: 1 addition & 0 deletions tritonbench/operators/jagged_layer_norm/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import torch
import triton

from tritonbench.utils.jagged_utils import (
ABSOLUTE_TOLERANCE,
EPSILON,
Expand Down
1 change: 1 addition & 0 deletions tritonbench/operators/jagged_mean/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import torch
import triton

from tritonbench.utils.jagged_utils import (
ABSOLUTE_TOLERANCE,
generate_input_vals,
Expand Down
1 change: 1 addition & 0 deletions tritonbench/operators/jagged_softmax/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import torch
import triton

from tritonbench.utils.jagged_utils import (
ABSOLUTE_TOLERANCE,
generate_input_vals,
Expand Down
1 change: 1 addition & 0 deletions tritonbench/operators/jagged_sum/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import torch
import triton

from tritonbench.utils.jagged_utils import (
ABSOLUTE_TOLERANCE,
generate_input_vals,
Expand Down
1 change: 1 addition & 0 deletions tritonbench/operators/layer_norm/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
import torch.nn.functional as F
import triton

from tritonbench.utils.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
Expand Down
78 changes: 78 additions & 0 deletions tritonbench/operators/op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import importlib
import os
import pathlib
from typing import List

OPBENCH_DIR = "operators"
INTERNAL_OPBENCH_DIR = "fb"


def _dir_contains_file(dir, file_name) -> bool:
names = map(lambda x: x.name, filter(lambda x: x.is_file(), dir.iterdir()))
return file_name in names


def _is_internal_operator(op_name: str) -> bool:
p = (
pathlib.Path(__file__)
.parent.parent.joinpath(OPBENCH_DIR)
.joinpath(INTERNAL_OPBENCH_DIR)
.joinpath(op_name)
)
if p.exists() and p.joinpath("__init__.py").exists():
return True
return False


def _list_opbench_paths() -> List[str]:
p = pathlib.Path(__file__).parent.parent.joinpath(OPBENCH_DIR)
# Only load the model directories that contain a "__init.py__" file
opbench = sorted(
str(child.absolute())
for child in p.iterdir()
if child.is_dir() and _dir_contains_file(child, "__init__.py")
)
p = p.joinpath(INTERNAL_OPBENCH_DIR)
if p.exists():
o = sorted(
str(child.absolute())
for child in p.iterdir()
if child.is_dir() and _dir_contains_file(child, "__init__.py")
)
opbench.extend(o)
return opbench


def list_operators() -> List[str]:
operators = list(map(lambda y: os.path.basename(y), _list_opbench_paths()))
if INTERNAL_OPBENCH_DIR in operators:
operators.remove(INTERNAL_OPBENCH_DIR)
return operators


def load_opbench_by_name(op_name: str):
opbench_list = filter(
lambda x: op_name.lower() == x.lower(),
map(lambda y: os.path.basename(y), _list_opbench_paths()),
)
opbench_list = list(opbench_list)
if not opbench_list:
raise RuntimeError(f"{op_name} is not found in the Tritonbench operator list.")
assert (
len(opbench_list) == 1
), f"Found more than one operators {opbench_list} matching the required name: {op_name}"
op_name = opbench_list[0]
op_pkg = (
op_name
if not _is_internal_operator(op_name)
else f"{INTERNAL_OPBENCH_DIR}.{op_name}"
)
module = importlib.import_module(f"..{op_pkg}", package=__name__)

Operator = getattr(module, "Operator", None)
if Operator is None:
print(f"Warning: {module} does not define attribute Operator, skip it")
return None
if not hasattr(Operator, "name"):
Operator.name = op_name
return Operator
1 change: 1 addition & 0 deletions tritonbench/operators/op_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Dict, List, Optional

import torch

from tritonbench.components.tasks import base as base_task
from tritonbench.components.workers import subprocess_worker

Expand Down
1 change: 1 addition & 0 deletions tritonbench/operators/ragged_attention/hstu.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import triton

from tritonbench.utils.path_utils import add_path, SUBMODULE_PATH
from tritonbench.utils.triton_op import IS_FBCODE

Expand Down
1 change: 1 addition & 0 deletions tritonbench/operators/ragged_attention/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Callable, List, Optional

import torch

from tritonbench.utils.input import input_filter

from tritonbench.utils.triton_op import (
Expand Down
1 change: 1 addition & 0 deletions tritonbench/operators/softmax/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
import triton
import triton.language as tl

from tritonbench.utils.data_utils import get_production_shapes

from tritonbench.utils.triton_op import (
Expand Down
1 change: 1 addition & 0 deletions tritonbench/operators/sum/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
import triton
import triton.language as tl

from tritonbench.utils.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
Expand Down
Loading
Loading