Skip to content

Commit

Permalink
Fix gemv
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Nov 18, 2024
1 parent 59bb8fe commit 597facf
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
2 changes: 0 additions & 2 deletions test/test_gpu/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@

TEST_OPERATORS = set(list_operators_by_collection(op_collection="default")) - SKIP_OPS

print(f"Testing operators: {TEST_OPERATORS}")


def check_ci_output(op):
from tritonbench.utils.triton_op import REGISTERED_BENCHMARKS
Expand Down
1 change: 1 addition & 0 deletions tritonbench/operators/addmm/hstu.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
with add_path(str(SUBMODULE_PATH.joinpath("generative-recommenders"))):
from generative_recommenders.ops.triton.triton_addmm import _addmm_fwd


class _AddMmFunction(torch.autograd.Function):
@staticmethod
# pyre-ignore[14]
Expand Down
3 changes: 3 additions & 0 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,11 +521,13 @@ def __call__(cls, *args, **kwargs):
obj.__post__init__()
return obj


def _translate_mode(tb_args):
def _has_and_true(attr):
if hasattr(tb_args, attr) and getattr(tb_args, attr):
return True
return False

if _has_and_true("fwd"):
tb_args.mode = "fwd"
if _has_and_true("bwd"):
Expand All @@ -535,6 +537,7 @@ def _has_and_true(attr):
if _has_and_true("fwd_no_grad"):
tb_args.mode = "fwd_no_grad"


class BenchmarkOperator(metaclass=PostInitProcessor):
mode: Mode = Mode.FWD
test: str = "eval"
Expand Down

0 comments on commit 597facf

Please sign in to comment.