lint
Summary: Lint-only change: reformat long function signatures and call sites in hstu.py and operator.py to one argument per line; no functional changes.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
manman-ren committed Nov 21, 2024
1 parent 6114f86 commit eb55eb4
Showing 2 changed files with 31 additions and 5 deletions.
14 changes: 12 additions & 2 deletions tritonbench/operators/ragged_attention/hstu.py
@@ -79,7 +79,11 @@ def __init__(
         self.persistent_kernel = persistent_kernel
 
     def forward(
-        self, qkv: torch.Tensor, seq_offsets: torch.Tensor, timestamps: torch.Tensor, num_targets: torch.Tensor
+        self,
+        qkv: torch.Tensor,
+        seq_offsets: torch.Tensor,
+        timestamps: torch.Tensor,
+        num_targets: torch.Tensor,
     ) -> torch.Tensor:
         NUM_BUCKETS = self.num_buckets
         torch._check(timestamps.size(0) + 1 == seq_offsets.size(0))
@@ -215,7 +219,13 @@ def generate_sparse_seq_len(
 
 
 def get_test_inputs(
-    batch_size, num_heads, max_seq_len, sparsity, target_size, sort_by_length, requires_grad
+    batch_size,
+    num_heads,
+    max_seq_len,
+    sparsity,
+    target_size,
+    sort_by_length,
+    requires_grad,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     timestamp_deltas: torch.Tensor = torch.randint(
         86400,
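For reference, a minimal usage sketch of the reformatted get_test_inputs signature. The concrete argument values below are illustrative assumptions, not taken from this commit; only the parameter order and the four-tensor return type come from the hunk above, and the sketch assumes the tritonbench package (with its GPU dependencies) is importable.

from tritonbench.operators.ragged_attention.hstu import get_test_inputs

# Illustrative sizes only; the parameter order matches the reformatted signature.
qkv, seq_offsets, timestamps, num_targets = get_test_inputs(
    8,      # batch_size
    4,      # num_heads
    512,    # max_seq_len
    0.5,    # sparsity
    20,     # target_size
    False,  # sort_by_length
    False,  # requires_grad
)

# RaggedHSTUAttn.forward() later validates the ragged-batch invariant
# torch._check(timestamps.size(0) + 1 == seq_offsets.size(0)) on these tensors.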
22 changes: 19 additions & 3 deletions tritonbench/operators/ragged_attention/operator.py
@@ -64,7 +64,9 @@ def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps, num_targets
 
     # TODO: enable persistent kernels when the OSS backward is ready
     @register_benchmark(enabled=False)
-    def hstu_triton_ragged_attention_persistent(self, qkv, seq_offsets, timestamps, num_targets):
+    def hstu_triton_ragged_attention_persistent(
+        self, qkv, seq_offsets, timestamps, num_targets
+    ):
         attn = RaggedHSTUAttn(
             self.batch_size,
             self.num_heads,
@@ -79,12 +81,26 @@ def hstu_triton_ragged_attention_persistent(self, qkv, seq_offsets, timestamps,
         return lambda: attn(qkv, seq_offsets, timestamps, num_targets)
 
     def get_x_val(self, example_inputs):
-        return (self.batch_size, self.num_heads, self.max_seq_len, self.num_buckets, self.sparsity, self.target_size, self.sort_by_length)
+        return (
+            self.batch_size,
+            self.num_heads,
+            self.max_seq_len,
+            self.num_buckets,
+            self.sparsity,
+            self.target_size,
+            self.sort_by_length,
+        )
 
     def get_input_iter(self):
         for _input_id in range(self._num_inputs):
             inputs = get_test_inputs(
-                self.batch_size, self.num_heads, self.max_seq_len, self.sparsity, self.target_size, self.sort_by_length, self.requires_grad
+                self.batch_size,
+                self.num_heads,
+                self.max_seq_len,
+                self.sparsity,
+                self.target_size,
+                self.sort_by_length,
+                self.requires_grad,
             )
             yield inputs

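A similarly hedged sketch of the wiring visible in the operator.py hunks: each tuple yielded by get_input_iter is unpacked into (qkv, seq_offsets, timestamps, num_targets) and captured by a zero-argument lambda that the harness times. Here make_benchmark_fn is a hypothetical helper, the sizes are placeholders, and attn stands in for the RaggedHSTUAttn module, whose full constructor argument list is not shown in this diff.

from typing import Callable

import torch

from tritonbench.operators.ragged_attention.hstu import get_test_inputs


def make_benchmark_fn(attn: Callable[..., torch.Tensor]) -> Callable[[], torch.Tensor]:
    # get_input_iter() yields one such tuple per _input_id.
    qkv, seq_offsets, timestamps, num_targets = get_test_inputs(
        8, 4, 512, 0.5, 20, False, False
    )
    # Mirrors `return lambda: attn(qkv, seq_offsets, timestamps, num_targets)`
    # in hstu_triton_ragged_attention_persistent.
    return lambda: attn(qkv, seq_offsets, timestamps, num_targets)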
