From eb55eb43a126de7074cc51167a2b552e4df6ac75 Mon Sep 17 00:00:00 2001
From: Manman Ren
Date: Wed, 20 Nov 2024 17:53:33 -0800
Subject: [PATCH] lint

Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
---
 .../operators/ragged_attention/hstu.py     | 14 ++++++++++--
 .../operators/ragged_attention/operator.py | 22 ++++++++++++++++---
 2 files changed, 31 insertions(+), 5 deletions(-)

diff --git a/tritonbench/operators/ragged_attention/hstu.py b/tritonbench/operators/ragged_attention/hstu.py
index 4e92a28f..c6dd4010 100644
--- a/tritonbench/operators/ragged_attention/hstu.py
+++ b/tritonbench/operators/ragged_attention/hstu.py
@@ -79,7 +79,11 @@ def __init__(
         self.persistent_kernel = persistent_kernel
 
     def forward(
-        self, qkv: torch.Tensor, seq_offsets: torch.Tensor, timestamps: torch.Tensor, num_targets: torch.Tensor
+        self,
+        qkv: torch.Tensor,
+        seq_offsets: torch.Tensor,
+        timestamps: torch.Tensor,
+        num_targets: torch.Tensor,
     ) -> torch.Tensor:
         NUM_BUCKETS = self.num_buckets
         torch._check(timestamps.size(0) + 1 == seq_offsets.size(0))
@@ -215,7 +219,13 @@ def generate_sparse_seq_len(
 
 
 def get_test_inputs(
-    batch_size, num_heads, max_seq_len, sparsity, target_size, sort_by_length, requires_grad
+    batch_size,
+    num_heads,
+    max_seq_len,
+    sparsity,
+    target_size,
+    sort_by_length,
+    requires_grad,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     timestamp_deltas: torch.Tensor = torch.randint(
         86400,
diff --git a/tritonbench/operators/ragged_attention/operator.py b/tritonbench/operators/ragged_attention/operator.py
index c157a3af..c4834ac7 100644
--- a/tritonbench/operators/ragged_attention/operator.py
+++ b/tritonbench/operators/ragged_attention/operator.py
@@ -64,7 +64,9 @@ def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps, num_targets
 
     # TODO: enable persistent kernels when the OSS backward is ready
     @register_benchmark(enabled=False)
-    def hstu_triton_ragged_attention_persistent(self, qkv, seq_offsets, timestamps, num_targets):
+    def hstu_triton_ragged_attention_persistent(
+        self, qkv, seq_offsets, timestamps, num_targets
+    ):
         attn = RaggedHSTUAttn(
             self.batch_size,
             self.num_heads,
@@ -79,12 +81,26 @@ def hstu_triton_ragged_attention_persistent(self, qkv, seq_offsets, timestamps,
         return lambda: attn(qkv, seq_offsets, timestamps, num_targets)
 
     def get_x_val(self, example_inputs):
-        return (self.batch_size, self.num_heads, self.max_seq_len, self.num_buckets, self.sparsity, self.target_size, self.sort_by_length)
+        return (
+            self.batch_size,
+            self.num_heads,
+            self.max_seq_len,
+            self.num_buckets,
+            self.sparsity,
+            self.target_size,
+            self.sort_by_length,
+        )
 
     def get_input_iter(self):
         for _input_id in range(self._num_inputs):
             inputs = get_test_inputs(
-                self.batch_size, self.num_heads, self.max_seq_len, self.sparsity, self.target_size, self.sort_by_length, self.requires_grad
+                self.batch_size,
+                self.num_heads,
+                self.max_seq_len,
+                self.sparsity,
+                self.target_size,
+                self.sort_by_length,
+                self.requires_grad,
             )
             yield inputs