diff --git a/tritonbench/operators/ragged_attention/hstu.py b/tritonbench/operators/ragged_attention/hstu.py
index 025bad0d..327c283e 100644
--- a/tritonbench/operators/ragged_attention/hstu.py
+++ b/tritonbench/operators/ragged_attention/hstu.py
@@ -64,7 +64,7 @@ def __init__(
         self.all_ts_weights = torch.nn.Parameter(
             torch.randn(
                 (self.num_buckets + 1,),
-                dtype=torch.bfloat16,
+                dtype=torch.float32,
             )
             .requires_grad_(requires_grad)
             .cuda()
@@ -72,7 +72,7 @@ def __init__(
         self.all_pos_weights = torch.nn.Parameter(
             torch.randn(
                 (2 * self.max_seq_len - 1,),
-                dtype=torch.bfloat16,
+                dtype=torch.float32,
             )
             .requires_grad_(requires_grad)
             .cuda()
@@ -81,7 +81,9 @@ def __init__(
 
     def forward(
         self,
-        qkv: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         seq_offsets: torch.Tensor,
         timestamps: torch.Tensor,
         num_targets: torch.Tensor,
@@ -89,9 +91,6 @@ def forward(
         NUM_BUCKETS = self.num_buckets
         torch._check(timestamps.size(0) + 1 == seq_offsets.size(0))
 
-        q = qkv[:, :, :128]
-        k = qkv[:, :, 128:256]
-        v = qkv[:, :, 256:384]
         out = torch.zeros_like(v)
 
         Z = timestamps.size(0)
@@ -134,13 +133,13 @@ def forward(
                 "DeltaSize": None,
                 "num_buckets": NUM_BUCKETS,
                 "max_pos_ind": None,
-                "time_bucket_incr": 60.0,
+                "time_bucket_incr": 60,
                 "time_bucket_div": 1.0,
                 "time_delta": 0.0,
                 "INVALID_MASK_TYPE": "lower_triangular",
                 "CAUSAL": True,
                 "BUCKET_FN": "sqrt",
-                "ATTN_BIAS_TYPE": "fused",
+                "ATTN_BIAS_TYPE": "ALL",
                 "USE_TIME_BIAS": False,
                 "USE_POS_BIAS": False,
                 "HAS_MAX_POS_IND": False,
@@ -150,7 +149,7 @@ def forward(
                 "ALLOW_TF32": True,
                 "BLOCK_D_Q": DimQ,
                 "BLOCK_D_V": DimV,
-                "MAX_ATTN_LEN": 0,
+                "MAX_ATTN_LEN": None,
                 "CONTEXTUAL_SEQ_LEN": 0,
                 "HAS_SORT_BY_LENGTH_INDICES": False,
                 "sort_by_length_indices": None,
@@ -219,27 +218,42 @@ def generate_sparse_seq_len(
     )
 
 
+try:
+    from hammer.benchmark.module_factory.hstu_utils import (
+        apply_SL,
+        generate_hstu_timestamps,
+    )
+except ImportError:
+
+    def apply_SL(lengths: torch.Tensor, alpha: float, max_seq_len: int):
+        return lengths
+
+    def generate_hstu_timestamps(batch_size, seq_len):
+        ts = torch.rand(batch_size, seq_len + 1, device="cuda") ** -0.8
+        ts = torch.clamp(torch.abs(ts * 86400), max=1e7)
+        ts, _ = torch.sort(ts, dim=1)
+        return ts.long()
+
+
 def get_test_inputs(
     batch_size,
     num_heads,
+    attn_dim,
+    hidden_dim,
     max_seq_len,
     sparsity,
     target_size,
     sort_by_length,
     requires_grad,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    timestamp_deltas: torch.Tensor = torch.randint(
-        86400,
-        size=(batch_size, max_seq_len + 1),
-    ).cuda()
-    timestamps = timestamp_deltas.cumsum(dim=1)
-
+    timestamps = generate_hstu_timestamps(batch_size, max_seq_len)
     lengths = generate_sparse_seq_len(
         size=batch_size,
         max_seq_len=max_seq_len,
         sparsity=sparsity,
         device=torch.device("cuda"),
     )
+    lengths = apply_SL(lengths, alpha=2.0, max_seq_len=max_seq_len)
     # assume has_delta_q is False
     num_targets = None
     if target_size != 0:
@@ -254,19 +268,21 @@ def get_test_inputs(
     seq_offsets = torch.zeros(
         (batch_size + 1,),
         dtype=torch.int64,
-    ).cuda()
+        device="cuda",
+    )
     seq_offsets[1:] = torch.cumsum(
         lengths,
         dim=0,
     )
     L = int(seq_offsets[-1].item())
-    qkv = (
-        torch.randn(
-            (L, num_heads, 512),
-            dtype=torch.bfloat16,
-        )
-        .requires_grad_(requires_grad)
-        .cuda()
+    qkv = torch.randn(
+        (L, num_heads, attn_dim * 2 + hidden_dim),
+        dtype=torch.bfloat16,
+        device="cuda",
     )
-    return qkv, seq_offsets, timestamps, num_targets
+    q, k, v = torch.split(qkv, [attn_dim, attn_dim, hidden_dim], dim=-1)
+    q.requires_grad_(True)
+    k.requires_grad_(True)
+    v.requires_grad_(True)
+    return q, k, v, seq_offsets, timestamps, num_targets, max_seq_len
 
diff --git a/tritonbench/operators/ragged_attention/operator.py b/tritonbench/operators/ragged_attention/operator.py
index 52db5482..2ab21e09 100644
--- a/tritonbench/operators/ragged_attention/operator.py
+++ b/tritonbench/operators/ragged_attention/operator.py
@@ -19,13 +19,15 @@ def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
-    parser.add_argument("--batch-size", type=int, default=8, help="Batch size")
+    parser.add_argument("--batch-size", type=int, default=256, help="Batch size")
     parser.add_argument("--heads", type=int, default=4, help="Number of heads")
-    parser.add_argument("--max-seq-len-log2", type=int, default=9)
+    parser.add_argument("--attn-dim", type=int, default=128)
+    parser.add_argument("--hidden-dim", type=int, default=128)
+    parser.add_argument("--max-seq-len-log2", type=int, default=15)
     parser.add_argument("--num-buckets", type=int, default=2048)
-    parser.add_argument("--seq-sparsity", type=float, default=0.8)
+    parser.add_argument("--seq-sparsity", type=float, default=0.95)
     parser.add_argument("--target-size", type=int, default=20)
-    parser.add_argument("--sort-by-length", type=bool, default=False)
+    parser.add_argument("--sort-by-length", type=bool, default=True)
     return parser.parse_args(args)
 
 
@@ -39,21 +41,23 @@ def __init__(
         args = parse_op_args(self.extra_args)
         self.batch_size = args.batch_size
         self.num_heads = args.heads
-        self.max_seq_len = 2**args.max_seq_len_log2
+        self.attn_dim = args.attn_dim
+        self.hidden_dim = args.hidden_dim
+        self.max_seq_len_log2 = args.max_seq_len_log2
         self.num_buckets = args.num_buckets
         self.sparsity = args.seq_sparsity
         self.target_size = args.target_size
         self.sort_by_length = args.sort_by_length
-        # set a default number of inputs
-        self._num_inputs = 10 if self._num_inputs is None else self._num_inputs
         self.requires_grad = not (self.mode == Mode.FWD_NO_GRAD)
 
     @register_benchmark()
-    def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps, num_targets):
+    def hstu_triton_ragged_attention(
+        self, q, k, v, seq_offsets, timestamps, num_targets, seq_len
+    ):
         attn = RaggedHSTUAttn(
             self.batch_size,
             self.num_heads,
-            self.max_seq_len,
+            seq_len,
             self.num_buckets,
             self.sparsity,
             self.target_size,
@@ -61,17 +65,24 @@ def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps, num_targets
             self.requires_grad,
             persistent_kernel=False,
         )
-        return lambda: attn(qkv, seq_offsets, timestamps, num_targets)
+        return lambda: attn(q, k, v, seq_offsets, timestamps, num_targets)
 
     # TODO: enable persistent kernels when the OSS backward is ready
     @register_benchmark(enabled=False)
     def hstu_triton_ragged_attention_persistent(
-        self, qkv, seq_offsets, timestamps, num_targets
+        self,
+        q,
+        k,
+        v,
+        seq_offsets,
+        timestamps,
+        num_targets,
+        seq_len,
     ):
         attn = RaggedHSTUAttn(
             self.batch_size,
             self.num_heads,
-            self.max_seq_len,
+            seq_len,
             self.num_buckets,
             self.sparsity,
             self.target_size,
@@ -79,13 +90,14 @@ def hstu_triton_ragged_attention_persistent(
             self.requires_grad,
             persistent_kernel=True,
         )
-        return lambda: attn(qkv, seq_offsets, timestamps, num_targets)
+        return lambda: attn(q, k, v, seq_offsets, timestamps, num_targets)
 
     def get_x_val(self, example_inputs):
+        seq_len = example_inputs[-1]
         return (
             self.batch_size,
             self.num_heads,
-            self.max_seq_len,
+            seq_len,
             self.num_buckets,
             self.sparsity,
             self.target_size,
@@ -93,17 +105,18 @@ def get_x_val(self, example_inputs):
         )
 
     def get_input_iter(self):
-        for _input_id in range(self._num_inputs):
-            inputs = get_test_inputs(
+        for seq_len in [2**i for i in range(8, self.max_seq_len_log2)]:
+            yield get_test_inputs(
                 self.batch_size,
                 self.num_heads,
-                self.max_seq_len,
+                self.attn_dim,
+                self.hidden_dim,
+                seq_len,
                 self.sparsity,
                 self.target_size,
                 self.sort_by_length,
                 self.requires_grad,
             )
-            yield inputs
 
     def get_bwd_fn(self, fwd_fn: Callable[..., Any]) -> Callable[..., Any]:
         o = fwd_fn()
@@ -123,9 +136,7 @@ def tflops(
         f1 = 0.0
         f2 = 0.0
         jagged = True
-        qkv, seq_offsets, timestamps, num_targets = example_inputs
-        q = qkv[:, :, :128]
-        v = qkv[:, :, 256:384]
+        q, k, v, seq_offsets, timestamps, num_targets, seq_len = example_inputs
         _, nheads, attn_dim = q.shape
         _, _, hidden_dim = v.shape
         max_seqlen = timestamps.size(1) - 1