Update hstu
xuzhao9 committed Nov 14, 2024
1 parent b140dcc commit c3a22e5
Showing 2 changed files with 62 additions and 17 deletions.
26 changes: 12 additions & 14 deletions tritonbench/operators/ragged_attention/hstu.py
@@ -19,10 +19,10 @@
triton_ragged_hstu_attention = importlib.import_module(
"generative-recommenders.ops.triton.triton_ragged_hstu_attention"
)
_ragged_hstu_attn_fwd = triton_ragged_hstu_attention._ragged_hstu_attn_fwd
_ragged_hstu_attn_fwd_persistent = (
triton_ragged_hstu_attention._ragged_hstu_attn_fwd_persistent
)
_RaggedAttentionRelativeBiasFunction = (
triton_ragged_hstu_attention._RaggedAttentionRelativeBiasFunction
)

@torch.fx.wrap
def prev_power_of_2(x: int) -> int:
@@ -141,24 +141,22 @@ def forward(
"HAS_SORT_BY_LENGTH_INDICES": False,
"sort_by_length_indices": None,
}
if not IS_FBCODE:
del kwargs["MAX_ATTN_LEN"]
del kwargs["HAS_CONTEXTUAL_SEQ_LEN"]
del kwargs["contextual_seq_len"]
del kwargs["HAS_SORT_BY_LENGTH_INDICES"]
del kwargs["sort_by_length_indices"]
kwargs["HAS_MAX_ATTN_LEN"] = False
kwargs["max_attn_len"] = 0

if self.persistent_kernel:
grid = (1216,)
_ragged_hstu_attn_fwd_persistent[grid](**kwargs)
else:
grid = lambda meta: ( # noqa E731
triton.cdiv(N, meta["BLOCK_M"]),
Z * H,
)
_ragged_hstu_attn_fwd[grid](**kwargs)
kwargs = {
"max_seq_len": kwargs["max_seq_len"],
"alpha": kwargs["alpha"],
"q": kwargs["q"],
"k": kwargs["k"],
"v":kwargs["v"],
"seq_offsets": kwargs["seq_offsets"],
"invalid_attn_mask_type": kwargs["invalid_attn_mask_type"],
"num_targets": kwargs["num_targets"],
}
# Route the non-persistent path through the autograd Function so that
# operator.py's new get_bwd_fn can benchmark the backward pass.
_RaggedAttentionRelativeBiasFunction.apply(**kwargs)

return out

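Note on the hstu.py change above: the non-persistent path now calls `_RaggedAttentionRelativeBiasFunction`, which in generative-recommenders is an autograd-Function-style wrapper around the Triton kernels; that wiring is what lets operator.py (below) benchmark a backward pass. The sketch below is a hypothetical, dense stand-in for illustration only — `ToyAttentionFn`, its shapes, and its math are invented here and are not the generative-recommenders implementation.

```python
import torch


class ToyAttentionFn(torch.autograd.Function):
    # Hypothetical stand-in for _RaggedAttentionRelativeBiasFunction. The real op
    # lives in generative-recommenders and launches Triton kernels; this dense,
    # simplified version only illustrates the forward/backward wiring.

    @staticmethod
    def forward(ctx, q, k, v):
        # q, k, v: (seq, heads, dim); placeholder dense math, not the ragged kernel.
        ctx.save_for_backward(q, k, v)
        scores = torch.einsum("nhd,mhd->hnm", q, k).softmax(dim=-1)
        return torch.einsum("hnm,mhd->nhd", scores, v)

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v = ctx.saved_tensors
        # The real op implements analytic gradients; recomputing the forward under
        # enable_grad keeps this sketch short.
        with torch.enable_grad():
            q2, k2, v2 = (t.detach().requires_grad_(True) for t in (q, k, v))
            scores = torch.einsum("nhd,mhd->hnm", q2, k2).softmax(dim=-1)
            out = torch.einsum("hnm,mhd->nhd", scores, v2)
            gq, gk, gv = torch.autograd.grad(out, (q2, k2, v2), grad_out)
        return gq, gk, gv


# Usage: out = ToyAttentionFn.apply(q, k, v); out.backward(grad) then backpropagates
# into q, k, v (when they require grad), which is what get_bwd_fn in operator.py
# relies on.
```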
53 changes: 50 additions & 3 deletions tritonbench/operators/ragged_attention/operator.py
@@ -1,8 +1,10 @@
import argparse
import torch

from typing import List, Optional
from typing import Any, Callable, List, Optional
from tritonbench.utils.input import input_filter

from tritonbench.utils.triton_op import BenchmarkOperator, register_benchmark
from tritonbench.utils.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
Mode,
register_benchmark,
register_metric,
)

from .hstu import get_test_inputs, RaggedHSTUAttn

@@ -42,7 +44,8 @@ def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps):
)
return lambda: attn(qkv, seq_offsets, timestamps)

@register_benchmark()
# TODO: enable persistent kernels when the OSS backward is ready
@register_benchmark(enabled=False)
def hstu_triton_ragged_attention_persistent(self, qkv, seq_offsets, timestamps):
attn = RaggedHSTUAttn(
self.batch_size,
@@ -60,3 +63,47 @@ def get_input_iter(self):
for _input_id in range(self._num_inputs):
inputs = get_test_inputs(self.batch_size, self.num_heads, self.max_seq_len)
yield inputs

def get_bwd_fn(self, fwd_fn: Callable[..., Any]) -> Callable[..., Any]:
o = fwd_fn()
o_tensor = input_filter(
lambda x: isinstance(x, torch.Tensor),
o,
)
do = torch.rand_like(o_tensor)
# retain_graph=True lets the benchmark replay backward on the same graph
fn = lambda: o_tensor.backward(do, retain_graph=True)
return fn

@register_metric()
def tflops(
self,
fn_name,
example_inputs,
metrics: BenchmarkOperatorMetrics
) -> float:
ratio = 2.0 # triangular masking
f1 = 0.0
f2 = 0.0
jagged = True
seq_offsets = example_inputs["seq_offsets"]
q = example_inputs["qkv"][:, :, :128]
v = example_inputs["qkv"][:, :, 256:384]
_, nheads, attn_dim = q.shape
_, _, hidden_dim = v.shape
max_seqlen = example_inputs["timestamps"].size(1) - 1

for i in range(self.batch_size):
seq_len = (
int((seq_offsets[i + 1] - seq_offsets[i]).item()) if jagged else max_seqlen
)
# (QK^T), dQ = d(QK^T)K, dK^T = Q^Td(QK^T)
f1 += 2 * self.num_heads * attn_dim * seq_len**2 // ratio
# (QK^T)V, d(QK^T) = dOV^T, dV = (QK^T)^TdO,
f2 += 2 * self.num_heads * hidden_dim * seq_len**2 // ratio
if self.mode == Mode.FWD:
tflops = f1 + f2 # computes (QK^T) and (QK^T)V
elif self.mode == Mode.BWD:
tflops = 3 * f1 + 2 * f2 # computes (QK^T), dQ, dK, dV, d(QK^T)
elif self.mode == Mode.FWD_BWD:
tflops = 4 * f1 + 3 * f2
return tflops / metrics.latency * 1e-9
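
For a quick sanity check of the FLOP accounting above, here is a minimal worked example. The sizes and latency are hypothetical (not the benchmark defaults), and it assumes `metrics.latency` is reported in milliseconds, which is what the `1e-9` scale factor implies.

```python
# Hypothetical sizes and latency, for illustration only (not benchmark defaults).
num_heads, attn_dim, hidden_dim = 4, 128, 128
seq_len = 1024
ratio = 2.0        # triangular masking halves the useful work
latency_ms = 0.50  # assumed: metrics.latency is in milliseconds

f1 = 2 * num_heads * attn_dim * seq_len**2 / ratio      # QK^T
f2 = 2 * num_heads * hidden_dim * seq_len**2 / ratio    # (QK^T)V
fwd_flops = f1 + f2              # ~1.07e9 FLOPs for one sequence
fwd_bwd_flops = 4 * f1 + 3 * f2  # forward + backward, per the branches above

print(fwd_flops / latency_ms * 1e-9)  # ~2.15 TFLOP/s for these numbers
```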
