Install tritonbench as a library (#81)
Summary:
Users can now import Tritonbench as a library:

```
$ pip install -e .
$ python -c "import tritonbench; op = tritonbench.load_opbench_by_name('addmm');"
```

Also clean up the `__init__.py` file so that implementation imports like `os` and `importlib` no longer pollute the package namespace.
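
A minimal sketch of what the cleanup means for library users (illustrative only; the printed operator list is an assumed example):

```
import tritonbench.operators as ops

# operators/__init__.py now only re-exports the public helpers; the loader
# implementation lives in operators/op.py, so module-level helper imports
# such as `os` and `importlib` are no longer attributes of this package.
assert not hasattr(ops, "os") and not hasattr(ops, "importlib")

print(ops.list_operators())  # e.g. ["addmm", "gemm", ...]
```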

Pull Request resolved: #81

Reviewed By: FindHao

Differential Revision: D66544679

Pulled By: xuzhao9

fbshipit-source-id: 0339adef58d28f2c7f207a2869c21d4b55575386
xuzhao9 authored and facebook-github-bot committed Nov 27, 2024
1 parent c666f87 commit bec6d9f
Showing 24 changed files with 205 additions and 151 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -13,4 +13,4 @@ __pycache__/
.DS_Store
.ipynb_checkpoints/
.idea
torchbench.egg-info/
*.egg-info/
16 changes: 15 additions & 1 deletion README.md
@@ -26,12 +26,26 @@ By default, it will install the latest PyTorch nightly release and use the Triton

## Basic Usage

To benchmark an operator, use the following command:
To benchmark an operator, run the following command:

```
$ python run.py --op gemm
```

## Install as a library

To install as a library:

```
$ pip install -e .
# in your own benchmark script
import tritonbench
from tritonbench.utils import parser
op_args = parser.parse_args()
addmm_bench = tritonbench.load_opbench_by_name("addmm")(op_args)
addmm_bench.run()
```

## Submodules

We depend on the following projects as a source of customized Triton or CUTLASS kernels:
15 changes: 15 additions & 0 deletions pyproject.toml
@@ -1,3 +1,18 @@
[build-system]
requires = ["setuptools>=40.8.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "tritonbench"
version = "0.0.1"
dependencies = [
    "torch",
    "triton",
]

[tool.setuptools.packages.find]
include = ["tritonbench*"]

[tool.ufmt]
formatter = "ruff-api"
excludes = ["submodules/"]
5 changes: 5 additions & 0 deletions tritonbench/__init__.py
@@ -0,0 +1,5 @@
from .operators import list_operators, load_opbench_by_name
from .operators_collection import (
    list_operator_collections,
    list_operators_by_collection,
)
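
With this new top-level `__init__.py`, these four helpers form the package's public surface. A hedged usage sketch (the argument taken by `list_operators_by_collection` and the `"default"` collection name are assumptions, not shown in this diff):

```
import tritonbench

print(tritonbench.list_operators())             # all discoverable operator benchmarks
print(tritonbench.list_operator_collections())  # names of predefined operator collections

# Assumed signature: a collection name selects a subset of operators.
print(tritonbench.list_operators_by_collection("default"))
```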
79 changes: 1 addition & 78 deletions tritonbench/operators/__init__.py
@@ -1,78 +1 @@
import importlib
import os
import pathlib
from typing import List

OPBENCH_DIR = "operators"
INTERNAL_OPBENCH_DIR = "fb"


def _dir_contains_file(dir, file_name) -> bool:
    names = map(lambda x: x.name, filter(lambda x: x.is_file(), dir.iterdir()))
    return file_name in names


def _is_internal_operator(op_name: str) -> bool:
    p = (
        pathlib.Path(__file__)
        .parent.parent.joinpath(OPBENCH_DIR)
        .joinpath(INTERNAL_OPBENCH_DIR)
        .joinpath(op_name)
    )
    if p.exists() and p.joinpath("__init__.py").exists():
        return True
    return False


def _list_opbench_paths() -> List[str]:
    p = pathlib.Path(__file__).parent.parent.joinpath(OPBENCH_DIR)
    # Only load the model directories that contain a "__init.py__" file
    opbench = sorted(
        str(child.absolute())
        for child in p.iterdir()
        if child.is_dir() and _dir_contains_file(child, "__init__.py")
    )
    p = p.joinpath(INTERNAL_OPBENCH_DIR)
    if p.exists():
        o = sorted(
            str(child.absolute())
            for child in p.iterdir()
            if child.is_dir() and _dir_contains_file(child, "__init__.py")
        )
        opbench.extend(o)
    return opbench


def list_operators() -> List[str]:
    operators = list(map(lambda y: os.path.basename(y), _list_opbench_paths()))
    if INTERNAL_OPBENCH_DIR in operators:
        operators.remove(INTERNAL_OPBENCH_DIR)
    return operators


def load_opbench_by_name(op_name: str):
    opbench_list = filter(
        lambda x: op_name.lower() == x.lower(),
        map(lambda y: os.path.basename(y), _list_opbench_paths()),
    )
    opbench_list = list(opbench_list)
    if not opbench_list:
        raise RuntimeError(f"{op_name} is not found in the Tritonbench operator list.")
    assert (
        len(opbench_list) == 1
    ), f"Found more than one operators {opbench_list} matching the required name: {op_name}"
    op_name = opbench_list[0]
    op_pkg = (
        op_name
        if not _is_internal_operator(op_name)
        else f"{INTERNAL_OPBENCH_DIR}.{op_name}"
    )
    module = importlib.import_module(f".{op_pkg}", package=__name__)

    Operator = getattr(module, "Operator", None)
    if Operator is None:
        print(f"Warning: {module} does not define attribute Operator, skip it")
        return None
    if not hasattr(Operator, "name"):
        Operator.name = op_name
    return Operator
from .op import list_operators, load_opbench_by_name
1 change: 1 addition & 0 deletions tritonbench/operators/addmm/hstu.py
@@ -4,6 +4,7 @@

import torch
import triton

from tritonbench.utils.path_utils import add_path, SUBMODULE_PATH

with add_path(str(SUBMODULE_PATH.joinpath("generative-recommenders"))):
1 change: 1 addition & 0 deletions tritonbench/operators/fp8_gemm_rowwise/operator.py
@@ -6,6 +6,7 @@

import torch
import triton

from tritonbench.utils.data_utils import get_production_shapes

from tritonbench.utils.triton_op import (
1 change: 1 addition & 0 deletions tritonbench/operators/geglu/operator.py
@@ -5,6 +5,7 @@

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaMLP

from tritonbench.utils.triton_op import (
    BenchmarkOperator,
    register_benchmark,
1 change: 1 addition & 0 deletions tritonbench/operators/gemm/operator.py
@@ -6,6 +6,7 @@
import torch
import torch._inductor.config as inductor_config
import triton

from tritonbench.utils.data_utils import get_production_shapes

from tritonbench.utils.path_utils import REPO_PATH
1 change: 1 addition & 0 deletions tritonbench/operators/jagged_layer_norm/operator.py
@@ -7,6 +7,7 @@

import torch
import triton

from tritonbench.utils.jagged_utils import (
    ABSOLUTE_TOLERANCE,
    EPSILON,
1 change: 1 addition & 0 deletions tritonbench/operators/jagged_mean/operator.py
@@ -7,6 +7,7 @@

import torch
import triton

from tritonbench.utils.jagged_utils import (
    ABSOLUTE_TOLERANCE,
    generate_input_vals,
1 change: 1 addition & 0 deletions tritonbench/operators/jagged_softmax/operator.py
@@ -7,6 +7,7 @@

import torch
import triton

from tritonbench.utils.jagged_utils import (
    ABSOLUTE_TOLERANCE,
    generate_input_vals,
1 change: 1 addition & 0 deletions tritonbench/operators/jagged_sum/operator.py
@@ -7,6 +7,7 @@

import torch
import triton

from tritonbench.utils.jagged_utils import (
    ABSOLUTE_TOLERANCE,
    generate_input_vals,
1 change: 1 addition & 0 deletions tritonbench/operators/layer_norm/operator.py
@@ -3,6 +3,7 @@
import torch
import torch.nn.functional as F
import triton

from tritonbench.utils.triton_op import (
    BenchmarkOperator,
    BenchmarkOperatorMetrics,
78 changes: 78 additions & 0 deletions tritonbench/operators/op.py
@@ -0,0 +1,78 @@
import importlib
import os
import pathlib
from typing import List

OPBENCH_DIR = "operators"
INTERNAL_OPBENCH_DIR = "fb"


def _dir_contains_file(dir, file_name) -> bool:
    names = map(lambda x: x.name, filter(lambda x: x.is_file(), dir.iterdir()))
    return file_name in names


def _is_internal_operator(op_name: str) -> bool:
    p = (
        pathlib.Path(__file__)
        .parent.parent.joinpath(OPBENCH_DIR)
        .joinpath(INTERNAL_OPBENCH_DIR)
        .joinpath(op_name)
    )
    if p.exists() and p.joinpath("__init__.py").exists():
        return True
    return False


def _list_opbench_paths() -> List[str]:
    p = pathlib.Path(__file__).parent.parent.joinpath(OPBENCH_DIR)
    # Only load the operator directories that contain an "__init__.py" file
    opbench = sorted(
        str(child.absolute())
        for child in p.iterdir()
        if child.is_dir() and _dir_contains_file(child, "__init__.py")
    )
    p = p.joinpath(INTERNAL_OPBENCH_DIR)
    if p.exists():
        o = sorted(
            str(child.absolute())
            for child in p.iterdir()
            if child.is_dir() and _dir_contains_file(child, "__init__.py")
        )
        opbench.extend(o)
    return opbench


def list_operators() -> List[str]:
    operators = list(map(lambda y: os.path.basename(y), _list_opbench_paths()))
    if INTERNAL_OPBENCH_DIR in operators:
        operators.remove(INTERNAL_OPBENCH_DIR)
    return operators


def load_opbench_by_name(op_name: str):
    opbench_list = filter(
        lambda x: op_name.lower() == x.lower(),
        map(lambda y: os.path.basename(y), _list_opbench_paths()),
    )
    opbench_list = list(opbench_list)
    if not opbench_list:
        raise RuntimeError(f"{op_name} is not found in the Tritonbench operator list.")
    assert (
        len(opbench_list) == 1
    ), f"Found more than one operator {opbench_list} matching the required name: {op_name}"
    op_name = opbench_list[0]
    op_pkg = (
        op_name
        if not _is_internal_operator(op_name)
        else f"{INTERNAL_OPBENCH_DIR}.{op_name}"
    )
    module = importlib.import_module(f"..{op_pkg}", package=__name__)

    Operator = getattr(module, "Operator", None)
    if Operator is None:
        print(f"Warning: {module} does not define attribute Operator, skip it")
        return None
    if not hasattr(Operator, "name"):
        Operator.name = op_name
    return Operator
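
Since `operators/__init__.py` now simply re-exports these helpers from `op.py`, the loader can be used exactly as in the README example above. A short sketch (passing `[]` to `parse_args` and relying on its defaults is an assumption):

```
from tritonbench.operators import list_operators, load_opbench_by_name
from tritonbench.utils import parser

print(list_operators())                    # discover available operator benchmarks
op_args = parser.parse_args([])            # assumption: default CLI arguments suffice
Operator = load_opbench_by_name("addmm")   # returns the Operator class, not an instance
bench = Operator(op_args)
bench.run()
```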
1 change: 1 addition & 0 deletions tritonbench/operators/op_task.py
@@ -7,6 +7,7 @@
from typing import Any, Dict, List, Optional

import torch

from tritonbench.components.tasks import base as base_task
from tritonbench.components.workers import subprocess_worker

1 change: 1 addition & 0 deletions tritonbench/operators/ragged_attention/hstu.py
@@ -1,5 +1,6 @@
import torch
import triton

from tritonbench.utils.path_utils import add_path, SUBMODULE_PATH
from tritonbench.utils.triton_op import IS_FBCODE

1 change: 1 addition & 0 deletions tritonbench/operators/ragged_attention/operator.py
@@ -3,6 +3,7 @@
from typing import Any, Callable, List, Optional

import torch

from tritonbench.utils.input import input_filter

from tritonbench.utils.triton_op import (
1 change: 1 addition & 0 deletions tritonbench/operators/softmax/operator.py
@@ -3,6 +3,7 @@
import torch
import triton
import triton.language as tl

from tritonbench.utils.data_utils import get_production_shapes

from tritonbench.utils.triton_op import (
1 change: 1 addition & 0 deletions tritonbench/operators/sum/operator.py
@@ -7,6 +7,7 @@
import torch
import triton
import triton.language as tl

from tritonbench.utils.triton_op import (
    BenchmarkOperator,
    BenchmarkOperatorMetrics,
