Add ufmt linter for pyproject (#47)
Summary:
Add a linter to make sure PRs are consistent with the internal `arc lint`.

To format the code, run the following in the repo directory:

```
ufmt format .
```
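
To only verify formatting without rewriting files (a sketch; it assumes the standard `ufmt` CLI, which also ships `check` and `diff` subcommands alongside `format`):

```
ufmt check .   # exit non-zero if any file would be reformatted
ufmt diff .    # print the pending changes without touching files
```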

We also test 25 operators in our H100 CI. Note that many of the flash_attention backends do not work right now and we need to fix them.

Pull Request resolved: #47

Reviewed By: FindHao

Differential Revision: D65709181

Pulled By: xuzhao9

fbshipit-source-id: 5b013906e7b04c8ee41d74db5756de08eec5b5b2
xuzhao9 authored and facebook-github-bot committed Nov 11, 2024
1 parent 7b4a0eb commit 359dfb4
Showing 29 changed files with 206 additions and 106 deletions.
3 changes: 3 additions & 0 deletions .ci/tritonbench/test-gpu.sh
@@ -8,4 +8,7 @@ fi

. "${SETUP_SCRIPT}"

# install deps
pip install psutil tabulate

python -m unittest test.test_gpu.main
29 changes: 29 additions & 0 deletions .github/workflows/linter.yaml
@@ -0,0 +1,29 @@
name: Linter
on:
pull_request:
push:
branches:
- main
workflow_dispatch:

jobs:
pylint:
permissions:
contents: read
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
with:
path: tritonbench
- name: Install deps
run: |
pip install ruff-api==0.1.0
- name: Check Formatting
uses: omnilib/ufmt@action-v1
with:
path: tritonbench

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
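
To reproduce this check locally (a sketch; the CI step gets `ufmt` from the `omnilib/ufmt` action, so locally you would install it yourself along with the same pinned `ruff-api`):

```
pip install ufmt ruff-api==0.1.0
ufmt check .
```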
11 changes: 9 additions & 2 deletions .github/workflows/pr.yaml
@@ -5,6 +5,9 @@ on:
- .ci/*
- tritonbench/*
- .github/workflows/pr.yaml
push:
branches:
- main

jobs:
h100-pytorch-test:
@@ -27,6 +30,10 @@ jobs:
sudo nvidia-smi -pm 1
sudo ldconfig
nvidia-smi
- name: Test Tritonbench operators
- name: Test Tritonbench operators on H100 GPU
run: |
bash ./.ci/tritonbench/test-operators.sh
bash ./.ci/tritonbench/test-gpu.sh
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
7 changes: 7 additions & 0 deletions pyproject.toml
@@ -0,0 +1,7 @@
[tool.ufmt]
formatter = "ruff-api"
excludes = ["submodules/"]

[tool.black]
line-length = 88
target-version = ["py312"]
2 changes: 2 additions & 0 deletions requirements.txt
@@ -1,3 +1,5 @@
packaging
pynvml
psutil
tabulate
transformers==4.46.1
1 change: 0 additions & 1 deletion test/test_cpu/main.py
@@ -7,7 +7,6 @@


class TestTritonbenchCpu(unittest.TestCase):

def _get_test_op(self):
parser = get_parser(["--device", "cpu", "--op", "test_op"])
tb_args, extra_args = parser.parse_known_args(
8 changes: 6 additions & 2 deletions test/test_gpu/main.py
@@ -18,7 +18,11 @@
fbcode_skip_file_path = "fb/skip_tests_h100_fbcode.yaml"
SKIP_FILE = importlib.resources.files(__package__).joinpath(fbcode_skip_file_path)
else:
SKIP_FILE = "skip_tests_h100_pytorch.yaml"
import os

SKIP_FILE = os.path.abspath(
os.path.join(os.path.dirname(__file__), "skip_tests_h100_pytorch.yaml")
)

with open(SKIP_FILE, "r") as f:
skip_tests = yaml.safe_load(f)
@@ -55,7 +59,7 @@ def _run_one_operator(
):
if tb_args.op in skip_tests:
# If the op itself is in the skip list, skip all tests
if skip_tests[tb_args.op] is None:
if not skip_tests[tb_args.op]:
return
tb_args.skip = ",".join(skip_tests[tb_args.op])
Operator = load_opbench_by_name(tb_args.op)
44 changes: 35 additions & 9 deletions test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -1,11 +1,37 @@
# Tests we skip in OSS CI
# This file applies to the Triton version bundled with pytorch
# Use <op-name> to skip an entire operator
# Use <op-name/impl-name> to skip an impl
- test_op
- bf16xint16_gemm/bf16xint16
- fp8_attention/colfax_fmha
- fp8_fused_quant_gemm_rowwise
- fp8_gemm/triton_persistent_fp8_gemm
- fp8_gemm/triton_tma_persistent_fp8_gemm
- fp8_gemm_rowwise
# Use <op-name:> to skip an entire operator
# Use <op-name:\n - impl-name> to skip an impl
bf16xint16_gemm:
- bf16xint16
# TODO: we have many buggy backends for flash_attention
# Need to fix them in the CI
flash_attention:
# - triton_tutorial_flash_v2_tma
# - triton_op_flash_v2
# - xformers_splitk
# - colfax_cutlass
# - tk
# - sdpa
# - cudnn
# - flex_attention
fp8_attention:
- colfax_fmha
fp8_fused_quant_gemm_rowwise:
fp8_gemm:
- triton_persistent_fp8_gemm
- triton_tma_persistent_fp8_gemm
fp8_gemm_rowwise:
gemm:
grouped_gemm:
int4_gemm:
jagged_layer_norm:
jagged_mean:
jagged_softmax:
jagged_sum:
layer_norm:
low_mem_dropout:
rms_norm:
rope:
template_attention:
test_op:
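
The new mapping format is consumed by the updated logic in `test/test_gpu/main.py` shown above: a key with an empty value skips the whole operator, while a list of impl names is joined into the benchmark's skip argument. A minimal sketch of that lookup (PyYAML assumed; the file path mirrors the diff, the helper name is illustrative):

```python
import yaml

# Path of the skip file added in this PR (OSS CI variant).
SKIP_FILE = "test/test_gpu/skip_tests_h100_pytorch.yaml"

with open(SKIP_FILE, "r") as f:
    skip_tests = yaml.safe_load(f) or {}

def resolve_skip(op_name: str):
    """Return "ALL" to skip the whole operator, a comma-separated impl list
    to skip only those impls, or None when nothing is skipped."""
    if op_name not in skip_tests:
        return None
    impls = skip_tests[op_name]
    if not impls:  # empty value (e.g. `gemm:`) means skip every impl
        return "ALL"
    return ",".join(impls)  # e.g. `fp8_gemm:` lists two impls to skip

print(resolve_skip("gemm"))      # ALL
print(resolve_skip("fp8_gemm"))  # triton_persistent_fp8_gemm,triton_tma_persistent_fp8_gemm
print(resolve_skip("softmax"))   # None
```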
4 changes: 1 addition & 3 deletions tools/cuda_utils.py
@@ -238,7 +238,5 @@ def check_torch_nightly_version(force_date: Optional[str] = None):
if args.install_torch_nightly:
install_pytorch_nightly(cuda_version=args.cudaver, env=os.environ)
if args.check_torch_nightly_version:
assert (
not args.install_torch_nightly
), "Error: Can't run install torch nightly and check version in the same command."
assert not args.install_torch_nightly, "Error: Can't run install torch nightly and check version in the same command."
check_torch_nightly_version(args.force_date)
4 changes: 3 additions & 1 deletion tritonbench/components/workers/subprocess_rpc.py
@@ -274,8 +274,10 @@ def write(self, msg: bytes) -> None:
def get_writer_pid(self) -> int:
assert (
self._writer_pid is not None
), "Writer pid is not specified. Maybe calling from child process or input pipe.\
), (
"Writer pid is not specified. Maybe calling from child process or input pipe.\
Please report a bug."
)
return self._writer_pid

def set_writer_pid(self, writer_pid: int) -> None:
20 changes: 11 additions & 9 deletions tritonbench/kernels/triton_fused_attention.py
@@ -37,7 +37,6 @@


class TmaAutoTuneHelper:

# duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
class KernelParamWrapper:
def __init__(self, desc):
@@ -734,7 +733,6 @@ def _attn_fwd_tma( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
HEAD_DIM: tl.constexpr, #
STAGE: tl.constexpr, #
):

tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
@@ -848,7 +846,14 @@ def _attn_fwd_tma( # Q, V, desc_k, desc_v, sm_scale, M, Out, #

@triton.jit
def _attn_bwd_preprocess(
O, DO, Delta, Z, H, N_CTX, BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # # # #
O,
DO,
Delta,
Z,
H,
N_CTX,
BLOCK_M: tl.constexpr,
HEAD_DIM: tl.constexpr, # # # #
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_hz = tl.program_id(1)
@@ -1179,7 +1184,6 @@ def _attn_bwd(


class _attention_ws(torch.autograd.Function):

@staticmethod
def forward(ctx, q, k, v, causal, sm_scale):
# shape constraints
@@ -1232,7 +1236,7 @@ def forward(ctx, q, k, v, causal, sm_scale):
N_CTX=q.shape[2], #
HEAD_DIM=HEAD_DIM_K, #
STAGE=stage, #
**extra_kern_args
**extra_kern_args,
)

ctx.save_for_backward(q, k, v, o, M)
@@ -1304,7 +1308,6 @@ def backward(ctx, do):


class _attention(torch.autograd.Function):

@staticmethod
def forward(ctx, q, k, v, causal, sm_scale):
# shape constraints
@@ -1355,7 +1358,7 @@ def forward(ctx, q, k, v, causal, sm_scale):
N_CTX=q.shape[2], #
HEAD_DIM=HEAD_DIM_K, #
STAGE=stage, #
**extra_kern_args
**extra_kern_args,
)

ctx.save_for_backward(q, k, v, o, M)
@@ -1427,7 +1430,6 @@ def backward(ctx, do):


class _attention_tma(torch.autograd.Function):

@staticmethod
def forward(ctx, q, k, v, causal, sm_scale):
# shape constraints
@@ -1587,7 +1589,7 @@ def grid_tma(META):
N_CTX=q.shape[2], #
HEAD_DIM=HEAD_DIM_K, #
STAGE=stage, #
**extra_kern_args
**extra_kern_args,
)

ctx.save_for_backward(q, k, v, o, M)
1 change: 0 additions & 1 deletion tritonbench/operators/gather_gemv/operator.py
@@ -26,7 +26,6 @@


class Operator(BenchmarkOperator):

@register_metric()
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
arg0_1, arg1_1, arg2_1 = example_inputs
2 changes: 1 addition & 1 deletion tritonbench/operators/gather_gemv/triton_gather_gemv.py
@@ -82,7 +82,7 @@ def triton_red_fused_mv_0(
rbase = tl.arange(0, RBLOCK)[None, :].to(tl.int64)
x0 = xindex
# x0 // rnumel should have the same value of either 0 or 1
tmp0 = tl.load(in_ptr0 + ((x0 // rnumel)), None, eviction_policy="evict_last")
tmp0 = tl.load(in_ptr0 + (x0 // rnumel), None, eviction_policy="evict_last")
_tmp11 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
9 changes: 4 additions & 5 deletions tritonbench/operators/jagged_layer_norm/operator.py
@@ -36,7 +36,6 @@ def parse_op_args(args: List[str]):


class Operator(BenchmarkOperator):

DEFAULT_METRICS = ["latency", "accuracy"]
DEFAULT_PRECISION = "fp32"

@@ -48,8 +47,8 @@ def __init__(
self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
):
super().__init__(tb_args, extra_args)
self.sizes = list(range(2, 12, 4)) + list(
range(12, 23, 3)
self.sizes = (
list(range(2, 12, 4)) + list(range(12, 23, 3))
) # bias towards larger sizes, which are more representative of real-world shapes

args = parse_op_args(self.extra_args)
@@ -105,8 +104,8 @@ def _inner():
) # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm

padded_normalized = (
padded_values - mean
) * padded_mask_values # mask elements outside of the ragged dimension size for correct variance calculation
(padded_values - mean) * padded_mask_values
) # mask elements outside of the ragged dimension size for correct variance calculation

variance = (
torch.sum(
20 changes: 12 additions & 8 deletions tritonbench/operators/jagged_mean/kernels.py
@@ -53,8 +53,9 @@ def triton_jagged_mean_kernel_simple_fused_sum_then_buffer(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
input_ptr_offsets + (pid_b + 1)
ragged_start, ragged_end = (
tl.load(input_ptr_offsets + pid_b),
tl.load(input_ptr_offsets + (pid_b + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
ragged_len = ragged_end - ragged_start

@@ -133,8 +134,9 @@ def triton_jagged_mean_kernel_simple_fused_buffer_then_sum(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
input_ptr_offsets + (pid_b + 1)
ragged_start, ragged_end = (
tl.load(input_ptr_offsets + pid_b),
tl.load(input_ptr_offsets + (pid_b + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
ragged_len = ragged_end - ragged_start

@@ -212,8 +214,9 @@ def triton_jagged_mean_kernel_variable_length_loop_sum_then_buffer(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
input_ptr_offsets + (pid_b + 1)
ragged_start, ragged_end = (
tl.load(input_ptr_offsets + pid_b),
tl.load(input_ptr_offsets + (pid_b + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
ragged_len = ragged_end - ragged_start

@@ -288,8 +291,9 @@ def triton_jagged_mean_kernel_variable_length_loop_buffer_then_sum(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load(
input_ptr_offsets + (pid_ragged + 1)
ragged_start, ragged_end = (
tl.load(input_ptr_offsets + pid_ragged),
tl.load(input_ptr_offsets + (pid_ragged + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
ragged_len = ragged_end - ragged_start
