
Lint the code
xuzhao9 committed Nov 9, 2024
1 parent f607751 commit 0185a8d
Showing 22 changed files with 116 additions and 90 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/linter.yaml
@@ -17,6 +17,9 @@ jobs:
uses: actions/checkout@v3
with:
path: tritonbench
- name: Install deps
run: |
pip install ruff-api
- name: Check Formatting
uses: omnilib/ufmt@action-v1
with:
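
The new workflow step installs ruff-api before the ufmt formatting check, presumably as the formatter backend the action relies on. A minimal, hypothetical sketch (not part of this commit) of reproducing the same check locally, assuming ufmt and ruff-api are installed as in the workflow step above:

```python
# Hypothetical local reproduction of the CI formatting check; assumes
# `pip install ufmt ruff-api` has been run, mirroring the workflow step above.
import subprocess
import sys

# `ufmt check` reports files that would be reformatted without changing them
# and should exit non-zero when reformatting is needed.
result = subprocess.run(["ufmt", "check", "tritonbench"], check=False)
sys.exit(result.returncode)
```
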
1 change: 0 additions & 1 deletion test/test_cpu/main.py
@@ -7,7 +7,6 @@


class TestTritonbenchCpu(unittest.TestCase):

def _get_test_op(self):
parser = get_parser(["--device", "cpu", "--op", "test_op"])
tb_args, extra_args = parser.parse_known_args(
4 changes: 1 addition & 3 deletions tools/cuda_utils.py
@@ -238,7 +238,5 @@ def check_torch_nightly_version(force_date: Optional[str] = None):
if args.install_torch_nightly:
install_pytorch_nightly(cuda_version=args.cudaver, env=os.environ)
if args.check_torch_nightly_version:
assert (
not args.install_torch_nightly
), "Error: Can't run install torch nightly and check version in the same command."
assert not args.install_torch_nightly, "Error: Can't run install torch nightly and check version in the same command."
check_torch_nightly_version(args.force_date)
4 changes: 3 additions & 1 deletion tritonbench/components/workers/subprocess_rpc.py
@@ -274,8 +274,10 @@ def write(self, msg: bytes) -> None:
def get_writer_pid(self) -> int:
assert (
self._writer_pid is not None
), "Writer pid is not specified. Maybe calling from child process or input pipe.\
), (
"Writer pid is not specified. Maybe calling from child process or input pipe.\
Please report a bug."
)
return self._writer_pid

def set_writer_pid(self, writer_pid: int) -> None:
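
The reformatted assert wraps the message in parentheses but keeps the trailing backslash that continues the string literal onto the next source line. A small standalone illustration (not from the repo) of how that continuation behaves:

```python
# The backslash at the end of a string literal continues it on the next line;
# the newline is dropped, but any leading whitespace on the continued line
# becomes part of the message.
msg = (
    "Writer pid is not specified. Maybe calling from child process or input pipe.\
    Please report a bug."
)
print(msg)
# Writer pid is not specified. Maybe calling from child process or input pipe.    Please report a bug.
```
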
17 changes: 10 additions & 7 deletions tritonbench/kernels/triton_fused_attention.py
@@ -35,7 +35,6 @@


class TmaAutoTuneHelper:

# duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
class KernelParamWrapper:
def __init__(self, desc):
@@ -457,7 +456,6 @@ def _attn_fwd_tma( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
HEAD_DIM: tl.constexpr, #
STAGE: tl.constexpr, #
):

tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
@@ -569,7 +567,14 @@ def _attn_fwd_tma( # Q, V, desc_k, desc_v, sm_scale, M, Out, #

@triton.jit
def _attn_bwd_preprocess(
O, DO, Delta, Z, H, N_CTX, BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # # # #
O,
DO,
Delta,
Z,
H,
N_CTX,
BLOCK_M: tl.constexpr,
HEAD_DIM: tl.constexpr, # # # #
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_hz = tl.program_id(1)
@@ -900,7 +905,6 @@ def _attn_bwd(


class _attention(torch.autograd.Function):

@staticmethod
def forward(ctx, q, k, v, causal, sm_scale):
# shape constraints
@@ -949,7 +953,7 @@ def forward(ctx, q, k, v, causal, sm_scale):
N_CTX=q.shape[2], #
HEAD_DIM=HEAD_DIM_K, #
STAGE=stage, #
**extra_kern_args
**extra_kern_args,
)

ctx.save_for_backward(q, k, v, o, M)
@@ -1021,7 +1025,6 @@ def backward(ctx, do):


class _attention_tma(torch.autograd.Function):

@staticmethod
def forward(ctx, q, k, v, causal, sm_scale):
# shape constraints
@@ -1175,7 +1178,7 @@ def grid_tma(META):
N_CTX=q.shape[2], #
HEAD_DIM=HEAD_DIM_K, #
STAGE=stage, #
**extra_kern_args
**extra_kern_args,
)

ctx.save_for_backward(q, k, v, o, M)
1 change: 0 additions & 1 deletion tritonbench/operators/gather_gemv/operator.py
@@ -26,7 +26,6 @@


class Operator(BenchmarkOperator):

@register_metric()
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
arg0_1, arg1_1, arg2_1 = example_inputs
2 changes: 1 addition & 1 deletion tritonbench/operators/gather_gemv/triton_gather_gemv.py
@@ -82,7 +82,7 @@ def triton_red_fused_mv_0(
rbase = tl.arange(0, RBLOCK)[None, :].to(tl.int64)
x0 = xindex
# x0 // rnumel should have the same value of either 0 or 1
tmp0 = tl.load(in_ptr0 + ((x0 // rnumel)), None, eviction_policy="evict_last")
tmp0 = tl.load(in_ptr0 + (x0 // rnumel), None, eviction_policy="evict_last")
_tmp11 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
9 changes: 4 additions & 5 deletions tritonbench/operators/jagged_layer_norm/operator.py
@@ -36,7 +36,6 @@ def parse_op_args(args: List[str]):


class Operator(BenchmarkOperator):

DEFAULT_METRICS = ["latency", "accuracy"]
DEFAULT_PRECISION = "fp32"

@@ -48,8 +47,8 @@ def __init__(
self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
):
super().__init__(tb_args, extra_args)
self.sizes = list(range(2, 12, 4)) + list(
range(12, 23, 3)
self.sizes = (
list(range(2, 12, 4)) + list(range(12, 23, 3))
) # bias towards larger sizes, which are more representative of real-world shapes

args = parse_op_args(self.extra_args)
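
The reformatting only changes how the expression is wrapped; the sizes list itself is unchanged. A quick worked check (illustration only) of what it evaluates to:

```python
# Same expression as in the hunk above: the sampling is denser at the upper
# end, which matches the "bias towards larger sizes" comment.
sizes = list(range(2, 12, 4)) + list(range(12, 23, 3))
print(sizes)  # [2, 6, 10, 12, 15, 18, 21]
```
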
@@ -105,8 +104,8 @@ def _inner():
) # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm

padded_normalized = (
padded_values - mean
) * padded_mask_values # mask elements outside of the ragged dimension size for correct variance calculation
(padded_values - mean) * padded_mask_values
) # mask elements outside of the ragged dimension size for correct variance calculation

variance = (
torch.sum(
20 changes: 12 additions & 8 deletions tritonbench/operators/jagged_mean/kernels.py
@@ -53,8 +53,9 @@ def triton_jagged_mean_kernel_simple_fused_sum_then_buffer(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
input_ptr_offsets + (pid_b + 1)
ragged_start, ragged_end = (
tl.load(input_ptr_offsets + pid_b),
tl.load(input_ptr_offsets + (pid_b + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
ragged_len = ragged_end - ragged_start

@@ -133,8 +134,9 @@ def triton_jagged_mean_kernel_simple_fused_buffer_then_sum(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
input_ptr_offsets + (pid_b + 1)
ragged_start, ragged_end = (
tl.load(input_ptr_offsets + pid_b),
tl.load(input_ptr_offsets + (pid_b + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
ragged_len = ragged_end - ragged_start

@@ -212,8 +214,9 @@ def triton_jagged_mean_kernel_variable_length_loop_sum_then_buffer(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
input_ptr_offsets + (pid_b + 1)
ragged_start, ragged_end = (
tl.load(input_ptr_offsets + pid_b),
tl.load(input_ptr_offsets + (pid_b + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
ragged_len = ragged_end - ragged_start

@@ -288,8 +291,9 @@ def triton_jagged_mean_kernel_variable_length_loop_buffer_then_sum(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load(
input_ptr_offsets + (pid_ragged + 1)
ragged_start, ragged_end = (
tl.load(input_ptr_offsets + pid_ragged),
tl.load(input_ptr_offsets + (pid_ragged + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
ragged_len = ragged_end - ragged_start

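
Each kernel reads two consecutive entries of the offsets tensor to find the bounds of the current ragged row, as the inline comments note. A plain-Python analogue of those two loads (illustration only, not Triton, with hypothetical offset values):

```python
# Offsets encode row boundaries of a jagged batch: row i spans
# [offsets[i], offsets[i + 1]) in the flattened values array.
offsets = [0, 3, 7, 12]  # hypothetical offsets for a 3-row jagged batch
pid_b = 1                # program id along the ragged (batch) axis
ragged_start, ragged_end = offsets[pid_b], offsets[pid_b + 1]
ragged_len = ragged_end - ragged_start
print(ragged_start, ragged_end, ragged_len)  # 3 7 4
```
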
50 changes: 30 additions & 20 deletions tritonbench/operators/jagged_mean/operator.py
@@ -92,7 +92,6 @@ def execute_kernel_variable_length_loop(x, sum_then_buffer):


class Operator(BenchmarkOperator):

DEFAULT_METRICS = ["latency", "accuracy"]
DEFAULT_PRECISION = "fp32"

@@ -104,8 +103,8 @@ def __init__(
self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
):
super().__init__(tb_args, extra_args)
self.sizes = list(range(2, 12, 4)) + list(
range(12, 23, 3)
self.sizes = (
list(range(2, 12, 4)) + list(range(12, 23, 3))
) # bias towards larger sizes, which are more representative of real-world shapes

args = parse_op_args(self.extra_args)
@@ -130,28 +129,37 @@ def torch_jagged_mean_unbind_torch_mean(
def torch_jagged_mean_torch_nanmean(
self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
):
return lambda: torch.nanmean(
torch.ops.aten._jagged_to_padded_dense_forward(
x.values(),
[x.offsets()], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
max_lengths=[seqlen], # max length of ragged dimension
padding_value=float("nan"),
),
dim=1,
return (
lambda: torch.nanmean(
torch.ops.aten._jagged_to_padded_dense_forward(
x.values(),
[
x.offsets()
], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
max_lengths=[seqlen], # max length of ragged dimension
padding_value=float("nan"),
),
dim=1,
)
)

@register_benchmark()
def torch_jagged_mean_torch_sum(
self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
):
return lambda: torch.sum(
torch.ops.aten._jagged_to_padded_dense_forward(
x.values(),
[x.offsets()], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
max_lengths=[seqlen], # max length of ragged dimension
),
dim=1,
) / x.offsets().diff().unsqueeze(1)
return (
lambda: torch.sum(
torch.ops.aten._jagged_to_padded_dense_forward(
x.values(),
[
x.offsets()
], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
max_lengths=[seqlen], # max length of ragged dimension
),
dim=1,
)
/ x.offsets().diff().unsqueeze(1)
)

@register_benchmark()
def triton_jagged_mean_simple_fused(
@@ -176,7 +184,9 @@ def torch_compile_nested_tensor_integration(
self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
):
def _inner(x: torch.Tensor): # mean along ragged dimension (dim == 1)
return torch.mean(x, dim=x._ragged_idx, keepdim=True) # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `_ragged_idx`.
return torch.mean(
x, dim=x._ragged_idx, keepdim=True
) # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `_ragged_idx`.

torch_compile_func = torch.compile(_inner)
return lambda: torch_compile_func(x)
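
The torch baselines above pad the jagged batch to a dense tensor before reducing, either with NaN followed by nanmean or with a padded sum divided by per-row lengths. A small plain-torch illustration (not the aten jagged op, values are made up) of why NaN padding plus nanmean yields the per-row means:

```python
# Illustration only: pad two ragged rows with NaN, then nanmean along dim=1
# ignores the padded entries, reproducing each row's mean.
import torch

rows = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0])]
padded = torch.full((2, 3), float("nan"))
for i, row in enumerate(rows):
    padded[i, : row.numel()] = row
print(torch.nanmean(padded, dim=1))  # tensor([2.0000, 4.5000])
```
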
10 changes: 6 additions & 4 deletions tritonbench/operators/jagged_softmax/kernels.py
@@ -54,8 +54,9 @@ def triton_jagged_softmax_kernel_simple_fused_buffer_then_sum(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
input_ptr_offsets + (pid_b + 1)
ragged_start, ragged_end = (
tl.load(input_ptr_offsets + pid_b),
tl.load(input_ptr_offsets + (pid_b + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]

buffer_max_all = tl.full(
@@ -163,8 +164,9 @@ def triton_jagged_softmax_kernel_variable_length_loop_buffer_then_sum(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
input_ptr_offsets + (pid_b + 1)
ragged_start, ragged_end = (
tl.load(input_ptr_offsets + pid_b),
tl.load(input_ptr_offsets + (pid_b + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]

buffer_max_all = tl.full(
13 changes: 8 additions & 5 deletions tritonbench/operators/jagged_softmax/operator.py
@@ -71,7 +71,6 @@ def parse_op_args(args: List[str]):


class Operator(BenchmarkOperator):

DEFAULT_METRICS = ["latency", "accuracy", "best_config"]
DEFAULT_PRECISION = "fp32"

@@ -83,8 +82,8 @@ def __init__(
self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
):
super().__init__(tb_args, extra_args)
self.sizes = list(range(2, 12, 4)) + list(
range(12, 23, 3)
self.sizes = (
list(range(2, 12, 4)) + list(range(12, 23, 3))
) # bias towards larger sizes, which are more representative of real-world shapes

args = parse_op_args(self.extra_args)
@@ -114,7 +113,9 @@ def torch_jagged_softmax_torch_sum(
def _inner():
padded = torch.ops.aten._jagged_to_padded_dense_forward(
x.values(),
[x.offsets()], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
[
x.offsets()
], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
max_lengths=[seqlen], # max length of ragged dimension
padding_value=float("-inf"), # e^-inf = 0
)
@@ -153,7 +154,9 @@ def torch_compile_nested_tensor_integration(
self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
):
def _inner(x: torch.Tensor): # softmax along ragged dimension
return torch.softmax(x, dim=x._ragged_idx) # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `_ragged_idx`.
return torch.softmax(
x, dim=x._ragged_idx
) # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `_ragged_idx`.

torch_compile_func = torch.compile(_inner)
return lambda: torch_compile_func(
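
The torch_jagged_softmax_torch_sum baseline pads with -inf before softmax, as the "e^-inf = 0" comment above notes. A tiny standalone check (illustration only, made-up values) that -inf padding contributes zero probability:

```python
# Illustration only: positions padded with -inf get exp(-inf) == 0, so they
# receive zero weight and the softmax over the real entries is unchanged.
import torch

row = torch.tensor([1.0, 2.0, float("-inf"), float("-inf")])
print(torch.softmax(row, dim=0))  # tensor([0.2689, 0.7311, 0.0000, 0.0000])
```
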
20 changes: 12 additions & 8 deletions tritonbench/operators/jagged_sum/kernels.py
@@ -52,8 +52,9 @@ def triton_jagged_sum_kernel_simple_fused_sum_then_buffer(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load(
input_ptr_offsets + (pid_ragged + 1)
ragged_start, ragged_end = (
tl.load(input_ptr_offsets + pid_ragged),
tl.load(input_ptr_offsets + (pid_ragged + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]

for block_pos in range(
@@ -127,8 +128,9 @@ def triton_jagged_sum_kernel_simple_fused_buffer_then_sum(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load(
input_ptr_offsets + (pid_ragged + 1)
ragged_start, ragged_end = (
tl.load(input_ptr_offsets + pid_ragged),
tl.load(input_ptr_offsets + (pid_ragged + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]

for block_pos in range(
@@ -201,8 +203,9 @@ def triton_jagged_sum_kernel_variable_length_loop_sum_then_buffer(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
input_ptr_offsets + (pid_b + 1)
ragged_start, ragged_end = (
tl.load(input_ptr_offsets + pid_b),
tl.load(input_ptr_offsets + (pid_b + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]

for block_start_ragged in range(
@@ -272,8 +275,9 @@ def triton_jagged_sum_kernel_variable_length_loop_buffer_then_sum(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load(
input_ptr_offsets + (pid_ragged + 1)
ragged_start, ragged_end = (
tl.load(input_ptr_offsets + pid_ragged),
tl.load(input_ptr_offsets + (pid_ragged + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]

for block_start_ragged in range(
(Diff for the remaining changed files was not loaded.)
