Enable bwd for flash_attention #74

Closed · wants to merge 11 commits
11 changes: 7 additions & 4 deletions test/test_gpu/main.py
@@ -43,13 +43,16 @@


 def check_ci_output(op):
-    from tritonbench.utils.triton_op import REGISTERED_BENCHMARKS
+    from tritonbench.utils.triton_op import (
+        find_enabled_benchmarks,
+        REGISTERED_BENCHMARKS,
+    )

     output = op.output
     output_impls = output.result[0][1].keys()
-    ci_enabled_impls = [
-        x for x in REGISTERED_BENCHMARKS[output.op_name].keys() if x not in op._skip
-    ]
+    ci_enabled_impls = find_enabled_benchmarks(
+        op.mode, REGISTERED_BENCHMARKS[op.name], op._skip
+    )
     # Make sure that all the ci_enabled impls are in the output
     logger.info(f"output impls: {output_impls}, ci_enabled impls: {ci_enabled_impls}")
     assert set(output_impls) == set(
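In plain terms, the CI check above now derives the set of expected implementations from the benchmark mode instead of only subtracting the skip list, so forward-only backends are no longer expected to show up in backward-mode output. A rough sketch of the invariant being asserted (the impl names below are invented for illustration):

```python
# Sketch of the invariant check_ci_output enforces; impl names are made up.
output_impls = {"triton_tutorial_flash_v2", "xformers"}      # impls that produced results
ci_enabled_impls = {"triton_tutorial_flash_v2", "xformers"}  # enabled, not skipped, runnable in this mode
assert set(output_impls) == set(ci_enabled_impls), (
    f"output impls {output_impls} do not match CI-enabled impls {ci_enabled_impls}"
)
```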
2 changes: 0 additions & 2 deletions test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -43,5 +43,3 @@ jagged_sum:
 ragged_attention:
   - hstu_triton_ragged_attention_persistent
 test_op:
-fwd_only_ops:
-  - flash_attention
2 changes: 0 additions & 2 deletions test/test_gpu/skip_tests_h100_triton_main.yaml
@@ -36,5 +36,3 @@ jagged_sum:
 # FIXME: ragged attention will Abort (Core Dump) on Triton Main
 ragged_attention:
 test_op:
-fwd_only_ops:
-  - flash_attention
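The two YAML edits above delete the old `fwd_only_ops` escape hatch: forward-only kernels are no longer listed per test config, but are declared at registration time via the new `fwd_only` flag (see the `operator.py` and `triton_op.py` diffs below). A self-contained toy version of that pattern, using only names visible in this PR; tritonbench's real registration code lives in `tritonbench/utils/triton_op.py`, this is just an illustration:

```python
from collections import OrderedDict
from dataclasses import dataclass

# Toy re-creation of the registration flow described by this PR.
@dataclass
class BenchmarkOperatorBackend:
    label: str
    enabled: bool = True
    fwd_only: bool = False  # backend implements only the forward pass

REGISTERED_BENCHMARKS: "OrderedDict[str, BenchmarkOperatorBackend]" = OrderedDict()

def register_benchmark(enabled=True, fwd_only=False):
    def decorator(fn):
        REGISTERED_BENCHMARKS[fn.__name__] = BenchmarkOperatorBackend(
            label=fn.__name__, enabled=enabled, fwd_only=fwd_only
        )
        return fn
    return decorator

@register_benchmark(fwd_only=True)       # replaces the YAML `fwd_only_ops` entry
def flash_attention_fwd_only(q, k, v):   # hypothetical backend name
    ...

assert REGISTERED_BENCHMARKS["flash_attention_fwd_only"].fwd_only
```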
2 changes: 1 addition & 1 deletion tritonbench/kernels/triton_fused_attention.py
@@ -1954,7 +1954,7 @@ def backward(ctx, do):
             num_stages=NUM_STAGES, #
         )

-        return dq, dk, dv, None, None
+        return dq, dk, dv, None, None, None


 attention_opt = _attention_opt.apply
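The extra `None` tracks PyTorch's `torch.autograd.Function` contract: `backward` must return exactly one value per argument that `forward` received, with `None` for inputs that carry no gradient (flags, scale factors, configs). The jump from five to six return values suggests `forward` now takes one more non-differentiable argument. A self-contained illustration of that contract (generic PyTorch, not the tritonbench kernel):

```python
import torch

class ScaledMatmul(torch.autograd.Function):
    """Toy autograd.Function: forward takes 3 args, so backward returns 3 values."""

    @staticmethod
    def forward(ctx, a, b, scale):
        ctx.save_for_backward(a, b)
        ctx.scale = scale
        return (a @ b) * scale

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        grad_a = grad_out @ b.t() * ctx.scale
        grad_b = a.t() @ grad_out * ctx.scale
        # One return value per forward input; the non-tensor `scale` gets None.
        return grad_a, grad_b, None

a = torch.randn(4, 3, requires_grad=True)
b = torch.randn(3, 5, requires_grad=True)
ScaledMatmul.apply(a, b, 0.5).sum().backward()
```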
12 changes: 9 additions & 3 deletions tritonbench/operators/flash_attention/operator.py
@@ -338,20 +338,26 @@ def xformers(
         k: torch.Tensor,
         v: torch.Tensor,
     ) -> Callable:
+        need_gradient = not (self.mode == BenchmarkMode.FWD_NO_GRAD)
         fhma_input = self.xformers_preprocess(q, k, v)
         xformers_cutlass_fhma = xformers.ops.fmha.cutlass.FwOp
-        return lambda: xformers_cutlass_fhma().apply(fhma_input, needs_gradient=False)
+        return lambda: xformers_cutlass_fhma().apply(
+            fhma_input, needs_gradient=need_gradient
+        )

-    @register_benchmark(enabled=HAS_XFORMERS)
+    @register_benchmark(enabled=HAS_XFORMERS, fwd_only=True)
     def xformers_splitk(
         self,
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
     ):
+        need_gradient = not (self.mode == BenchmarkMode.FWD_NO_GRAD)
         fhma_input = self.xformers_preprocess(q, k, v)
         xformers_splitk_fhma = xformers_fmha.triton_splitk.FwOp
-        return lambda: xformers_splitk_fhma().apply(
+            fhma_input, needs_gradient=need_gradient
+        )

     def colfax_cutlass_preprocess(self, q, k, v):
         return (
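Both xformers backends previously hard-coded `needs_gradient=False`, which breaks once these kernels are exercised under backward modes. The new code derives the flag from the benchmark mode. A standalone sketch of that mapping, with an assumed `BenchmarkMode` enum standing in for tritonbench's:

```python
from enum import Enum

class BenchmarkMode(Enum):  # stand-in for tritonbench's mode enum; values are assumptions
    FWD = "fwd"
    BWD = "bwd"
    FWD_BWD = "fwd_bwd"
    FWD_NO_GRAD = "fwd_no_grad"

def need_gradient(mode: BenchmarkMode) -> bool:
    # Mirrors the diff: only the no-grad forward mode may skip gradient bookkeeping.
    return not (mode == BenchmarkMode.FWD_NO_GRAD)

assert need_gradient(BenchmarkMode.FWD_NO_GRAD) is False
assert all(need_gradient(m) for m in (BenchmarkMode.FWD, BenchmarkMode.BWD, BenchmarkMode.FWD_BWD))
```

Note that `xformers_splitk` is additionally marked `fwd_only=True` above, so the new mode filtering keeps it out of backward runs regardless; the `needs_gradient` fix still matters wherever the kernel is expected to stash state for a later backward.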
11 changes: 7 additions & 4 deletions tritonbench/operators/op_task.py
@@ -190,13 +190,16 @@ def get_attribute(
     @staticmethod
     def check_output() -> None:
         op = globals()["op"]
-        from tritonbench.utils.triton_op import REGISTERED_BENCHMARKS
+        from tritonbench.utils.triton_op import (
+            find_enabled_benchmarks,
+            REGISTERED_BENCHMARKS,
+        )

         output = op.output
         output_impls = output.result[0][1].keys()
-        ci_enabled_impls = [
-            x for x in REGISTERED_BENCHMARKS[output.op_name].keys() if x not in op._skip
-        ]
+        ci_enabled_impls = find_enabled_benchmarks(
+            op.mode, REGISTERED_BENCHMARKS[output.op_name], op._skip
+        )
         # Make sure that all the ci_enabled impls are in the output
         assert set(output_impls) == set(
             ci_enabled_impls
59 changes: 38 additions & 21 deletions tritonbench/utils/triton_op.py
@@ -49,6 +49,10 @@ class BenchmarkOperatorBackend:
     baseline: bool = False
     # enabled
     enabled: bool = True
+    # fwd_only
+    # if an operator supports backward, but one of its kernels does not,
+    # set fwd_only = True
+    fwd_only: bool = False
     # need to be tested in ci
     # ci = False implies enabled = False
     ci: bool = True
@@ -59,7 +63,6 @@ class BenchmarkOperatorBackend:
 DEFAULT_RUN_ITERS = 100
 DEFAULT_QUANTILES = [0.5, 0.1, 0.9]
 REGISTERED_BENCHMARKS: Dict[str, OrderedDict[str, BenchmarkOperatorBackend]] = {}
-ENABLED_BENCHMARKS: Dict[str, List[str]] = {}
 REGISTERED_METRICS: Dict[str, List[str]] = {}
 REGISTERED_X_VALS: Dict[str, str] = {}
 BASELINE_BENCHMARKS: Dict[str, str] = {}
@@ -406,6 +409,26 @@ def __str__(self):
         return table


+def find_enabled_benchmarks(mode, benchmark_backends, skip_benchmarks):
+    """Condition: enabled, not skipped, and runnable in the current mode."""
+    runnable = lambda m, backend: (not (m == Mode.BWD or m == Mode.FWD_BWD)) or (
+        not backend.fwd_only
+    )
+    if skip_benchmarks:
+        benchmarks = [
+            bm
+            for bm in benchmark_backends.keys()
+            if not bm in skip_benchmarks and runnable(mode, benchmark_backends[bm])
+        ]
+    else:
+        benchmarks = [
+            bm
+            for bm in benchmark_backends.keys()
+            if benchmark_backends[bm].enabled and runnable(mode, benchmark_backends[bm])
+        ]
+    return benchmarks
+
+
 def register_x_val(label: str = "x_val"):
     def decorator(function):
         operator_name = _find_op_name_from_module_path(function.__module__)
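A quick usage sketch of the new helper. It is runnable against a checkout with this PR applied; the import path and the filtering semantics come from the diff, while the backend stand-in and the names are invented for illustration:

```python
from collections import OrderedDict
from dataclasses import dataclass

# Assumes a tritonbench checkout with this PR applied.
from tritonbench.utils.triton_op import Mode, find_enabled_benchmarks

@dataclass
class FakeBackend:  # only the fields the helper inspects
    enabled: bool = True
    fwd_only: bool = False

backends = OrderedDict(
    triton_flash=FakeBackend(),                  # full fwd+bwd support
    xformers_splitk=FakeBackend(fwd_only=True),  # forward-only kernel
    experimental=FakeBackend(enabled=False),     # registered but disabled
)

# Without a skip list: disabled backends drop out, fwd_only ones drop out of bwd modes.
assert find_enabled_benchmarks(Mode.FWD, backends, None) == ["triton_flash", "xformers_splitk"]
assert find_enabled_benchmarks(Mode.BWD, backends, None) == ["triton_flash"]

# With a skip list, the skip list replaces the `enabled` check, but fwd_only still applies.
assert find_enabled_benchmarks(Mode.FWD, backends, ["triton_flash"]) == ["xformers_splitk", "experimental"]
```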
@@ -422,6 +445,7 @@ def _inner(self, *args, **kwargs):
 def register_benchmark(
     baseline: bool = False,
     enabled: bool = True,
+    fwd_only: bool = False,
     label: Optional[str] = None,
 ):
     def decorator(function):
@@ -431,16 +455,13 @@ def decorator(function):
             label=label if label else function.__name__,
             baseline=baseline,
             enabled=enabled,
+            fwd_only=fwd_only,
         )
         if not operator_name in REGISTERED_BENCHMARKS:
             REGISTERED_BENCHMARKS[operator_name] = OrderedDict()
         REGISTERED_BENCHMARKS[operator_name][function.__name__] = backend_config
         if backend_config.baseline:
             BASELINE_BENCHMARKS[operator_name] = function.__name__
-        if backend_config.enabled:
-            if not operator_name in ENABLED_BENCHMARKS:
-                ENABLED_BENCHMARKS[operator_name] = []
-            ENABLED_BENCHMARKS[operator_name].append(function.__name__)

         def _inner(self, *args, **kwargs):
             return function(self, *args, **kwargs)
@@ -468,8 +489,8 @@ def register_benchmark_mannually(
         enabled (bool, optional): If True, this benchmark function is enabled. Defaults to True.
         label (Optional[str], optional): An optional label for the benchmark function. Defaults to None.

-    This function updates the global dictionaries REGISTERED_BENCHMARKS, BASELINE_BENCHMARKS,
-    and ENABLED_BENCHMARKS to include the new benchmark function. If the operator or function
+    This function updates the global dictionaries REGISTERED_BENCHMARKS and BASELINE_BENCHMARKS
+    to include the new benchmark function. If the operator or function
     is already registered, it updates the existing entries.

     We need this manually register function because decorator doesn't work for
@@ -485,10 +506,6 @@ def register_benchmark_mannually(
     )
     if baseline:
         BASELINE_BENCHMARKS[operator_name] = func_name
-    if enabled:
-        if not operator_name in ENABLED_BENCHMARKS:
-            ENABLED_BENCHMARKS[operator_name] = []
-        ENABLED_BENCHMARKS[operator_name].append(func_name)


 def register_metric(
@@ -639,14 +656,21 @@ def _get_bm_func(self, bm_func_name: str):
             bm_func_name,
         )

+        backend = REGISTERED_BENCHMARKS[self.name][bm_func_name]
         if self.mode == Mode.FWD:
             setattr(fwd_fn, "_name", bm_func_name)
             return fwd_fn
         elif self.mode == Mode.BWD:
+            assert (
+                not backend.fwd_only
+            ), f"Backend {bm_func_name} does not support backward pass."
             bwd_fn = self.get_bwd_fn(fwd_fn)
             setattr(bwd_fn, "_name", bm_func_name)
             return bwd_fn
         elif self.mode == Mode.FWD_BWD:
+            assert (
+                not backend.fwd_only
+            ), f"Backend {bm_func_name} does not support backward pass."
             bwd_fn = self.get_bwd_fn(fwd_fn)
             fwd_bwd_fn = lambda: (fwd_fn(), bwd_fn())
             setattr(fwd_bwd_fn, "_name", bm_func_name)
@@ -696,18 +720,11 @@ def run(
         x_val = self.get_x_val(self.example_inputs)
         if self._only:
             benchmarks = self._only
-        elif self._skip:
-            benchmarks = [
-                bm
-                for bm in REGISTERED_BENCHMARKS[self.name].keys()
-                if bm not in self._skip
-            ]
         else:
-            benchmarks = (
-                [bm for bm in ENABLED_BENCHMARKS[self.name]]
-                if self.name in ENABLED_BENCHMARKS
-                else []
+            benchmarks = find_enabled_benchmarks(
+                self.mode, REGISTERED_BENCHMARKS[self.name], self._skip
             )

         # Run the baseline first, if baseline exists
         baseline_name = (
             BASELINE_BENCHMARKS[self.name]
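With the `run()` change, benchmark selection collapses to a single precedence rule: an explicit `--only` list wins outright; otherwise `find_enabled_benchmarks` applies the skip list, the `enabled` flag, and the new mode/`fwd_only` filter in one place. A compact restatement of that precedence (a simplified sketch, not the real `run()` body; the hypothetical `select_benchmarks` helper is introduced here only for illustration):

```python
from tritonbench.utils.triton_op import REGISTERED_BENCHMARKS, find_enabled_benchmarks

def select_benchmarks(op_name, only, mode, skip):
    """Simplified selection logic mirroring run() after this PR."""
    if only:  # --only bypasses every other filter
        return list(only)
    return find_enabled_benchmarks(mode, REGISTERED_BENCHMARKS[op_name], skip)
```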