Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[UT] Port and run operator tests #246

Merged
merged 6 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,24 @@ jobs:
python3 assert_helper.py device_assert
python3 print_helper.py device_print float 1> /dev/null

- name: Clear cache
run: |
rm -rf ~/.triton

- name: Run interpreter tests
env:
# TRITON_INTERPRET: "1"
CUA_VISIBLE_DEVICES: ""
run: |
cd python/test/unit
python3 -m pytest -vs operators/test_flash_attention.py

- name: Run partial operators tests
if: ${{ env.BACKEND == 'XPU'}}
run: |
cd python/test/unit
python3 -m pytest -n 8 --verbose operators

- name: Run XPU python tests
if: ${{ env.BACKEND == 'XPU'}}
run: |
Expand Down
17 changes: 17 additions & 0 deletions .github/workflows/build_and_test_2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,23 @@ jobs:
python3 assert_helper.py device_assert
python3 print_helper.py device_print float 1> /dev/null

- name: Clear cache
run: |
rm -rf ~/.triton

- name: Run interpreter tests
env:
# TRITON_INTERPRET: "1"
CUA_VISIBLE_DEVICES: ""
run: |
cd python/test/unit
python3 -m pytest -vs operators/test_flash_attention.py

- name: Run partial operators tests
run: |
cd python/test/unit
python3 -m pytest -n 8 --verbose operators

- name: Run XPU python tests
run: |
cd python/test/backend/third_party_backends
Expand Down
14 changes: 10 additions & 4 deletions python/test/unit/operators/test_blocksparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import triton
import triton.ops

# FIXME remove this once Triton L0 queue and IPEX SYCL queue can be synchronized through events
torch.xpu.enable_sync_mode()


def sparsify_tensor(x, mask, block):
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
Expand All @@ -12,7 +15,7 @@ def sparsify_tensor(x, mask, block):
return ret


def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32):
def make_pair(shape, device="xpu", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32):
if data is None:
data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device)
ref_ret = data
Expand All @@ -38,6 +41,7 @@ def mask_tensor(x, mask, block, value=0):
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
@pytest.mark.parametrize("DTYPE", [torch.float16])
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
pytest.skip("RuntimeError: Triton Error [ZE]: 2013265944")
seed = 0
torch.manual_seed(seed)
is_sdd = MODE == "sdd"
Expand Down Expand Up @@ -79,7 +83,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
b_tri = do_sparsify(b_tri) if is_dds else b_tri
a_tri.retain_grad()
b_tri.retain_grad()
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="xpu")
c_tri = op(a_tri, b_tri)
c_tri.backward(dc_tri)
da_tri = a_tri.grad
Expand All @@ -101,6 +105,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
@pytest.mark.parametrize("is_dense", [False, True])
@pytest.mark.parametrize("BLOCK, WIDTH", configs)
def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
pytest.skip("RuntimeError: Triton Error [ZE]: 2013265944")
# set seed
torch.random.manual_seed(0)
Z, H, M, N = 2, 3, WIDTH, WIDTH
Expand All @@ -119,7 +124,7 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
# compute [torch]
a_ref = mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
a_ref.retain_grad()
at_mask = torch.ones((M, N), device="cuda")
at_mask = torch.ones((M, N), device="xpu")
if is_causal:
at_mask = torch.tril(at_mask)
M = at_mask[None, None, :, :] + torch.zeros_like(a_ref)
Expand All @@ -132,7 +137,7 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
a_tri = sparsify_tensor(a_tri, layout, BLOCK)
a_tri.retain_grad()
dout_tri = sparsify_tensor(dout_tri, layout, BLOCK)
op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
op = triton.ops.blocksparse.softmax(layout, BLOCK, device="xpu", is_dense=is_dense)
out_tri = op(a_tri, scale=scale, is_causal=is_causal)
out_tri.backward(dout_tri)
da_tri = a_tri.grad
Expand All @@ -152,6 +157,7 @@ def test_attention_fwd_bwd(
batch_size=2,
n_heads=2,
):
pytest.skip("FIXME: Port get_device_capability to XPU")
ESI-SYD marked this conversation as resolved.
Show resolved Hide resolved
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
Expand Down
4 changes: 4 additions & 0 deletions python/test/unit/operators/test_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import triton
import triton.ops

# FIXME remove this once Triton L0 queue and IPEX SYCL queue can be synchronized through events
torch.xpu.enable_sync_mode()


@pytest.mark.parametrize("M, N, dtype, mode", [ #
(M, N, dtype, mode)
Expand All @@ -13,6 +16,7 @@
for mode in ['forward', 'backward']
])
def test_op(M, N, dtype, mode):
pytest.skip("FIXME: Port get_device_capability to XPU")
ESI-SYD marked this conversation as resolved.
Show resolved Hide resolved
capability = torch.cuda.get_device_capability()
if capability[0] < 8 and dtype == "bfloat16":
pytest.skip("Only test bfloat16 on devices with sm >= 80")
Expand Down
13 changes: 8 additions & 5 deletions python/test/unit/operators/test_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import triton
import triton.ops

# FIXME remove this once Triton L0 queue and IPEX SYCL queue can be synchronized through events
torch.xpu.enable_sync_mode()


@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ #
(2, 4, 512, 16),
Expand All @@ -20,7 +23,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
if enable_tma in ["on", "true", "1"]:
if dtype == torch.bfloat16:
pytest.skip('bfloat16 tma not support currently')

pytest.skip("FIXME: Port get_device_capability to XPU")
ESI-SYD marked this conversation as resolved.
Show resolved Hide resolved
capability = torch.cuda.get_device_capability()
interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"]
if not interpreter and capability[0] < 8:
Expand Down Expand Up @@ -87,14 +90,14 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):


@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, casual, seq_par, provider, dtype=torch.float16, device="cuda"):
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, casual, seq_par, provider, dtype=torch.float16, device="xpu"):
assert mode in ['fwd', 'bwd']
warmup = 25
rep = 100
sm_scale = 1.3
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="xpu", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="xpu", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="xpu", requires_grad=True)
if provider == "triton":
fn = lambda: triton.ops.attention(q, k, v, casual, sm_scale, seq_par)
if mode == 'bwd':
Expand Down
26 changes: 15 additions & 11 deletions python/test/unit/operators/test_inductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import triton
import triton.language as tl

# FIXME remove this once Triton L0 queue and IPEX SYCL queue can be synchronized through events
torch.xpu.enable_sync_mode()


def test_normalization_with_remat():

Expand Down Expand Up @@ -47,12 +50,12 @@ def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel

torch.manual_seed(123)

buf14 = torch.rand(8, 64, 64, 64, device="cuda")
buf16 = torch.rand(8, 1, 64, device="cuda")
arg114_1 = torch.rand(64, device="cuda")
arg115_1 = torch.rand(64, device="cuda")
arg8_1 = torch.rand(64, device="cuda")
arg9_1 = torch.rand(64, device="cuda")
buf14 = torch.rand(8, 64, 64, 64, device="xpu")
buf16 = torch.rand(8, 1, 64, device="xpu")
arg114_1 = torch.rand(64, device="xpu")
arg115_1 = torch.rand(64, device="xpu")
arg8_1 = torch.rand(64, device="xpu")
arg9_1 = torch.rand(64, device="xpu")
triton_[(512, )](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048)
torch.testing.assert_close(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0)

Expand Down Expand Up @@ -146,7 +149,7 @@ def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr):
tmp76 = tl.where(tmp74, tmp75, tmp71)
tl.store(out_ptr0 + (x5 + tl.zeros([XBLOCK], tl.int32)), tmp76, None)

inp = torch.ones(8, 2048, 8, 8, device="cuda", dtype=torch.half)
inp = torch.ones(8, 2048, 8, 8, device="xpu", dtype=torch.half)
out = torch.ones_like(inp) * 3
numel = inp.numel()
triton_[(numel // 1024, )](inp, out, 1024)
Expand All @@ -160,6 +163,7 @@ def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr):
@pytest.mark.parametrize("RBLOCK", [1, 16, 32, 64, 128])
@pytest.mark.parametrize("num_warps", [1, 4])
def test_scan2d_broadcast(RBLOCK, num_warps):
pytest.skip("FIXME: worker crashed cases")

@triton.jit(debug=True)
def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
Expand All @@ -172,8 +176,8 @@ def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
tl.store(out_ptr + xindex * RBLOCK + rindex, scan)

XBLOCK = 4
input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int64, device='cuda')
output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int64, device='cuda')
input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int64, device='xpu')
output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int64, device='xpu')
fn[(1, )](input, output, XBLOCK, RBLOCK, num_warps=num_warps)
ref = input.cumsum(1).broadcast_to((XBLOCK, RBLOCK))
torch.testing.assert_close(output, ref)
Expand All @@ -192,7 +196,7 @@ def fn(out_ptr0, rnumel, RBLOCK: tl.constexpr):
tl.store(out_ptr0 + rindex, tmp6, rmask)

RBLOCK = 8
out0 = torch.empty(RBLOCK, device="cuda", dtype=torch.int64)
out0 = torch.empty(RBLOCK, device="xpu", dtype=torch.int64)
fn[(1, )](out0, RBLOCK, RBLOCK)
ref = torch.arange(RBLOCK, device="cuda", dtype=torch.int64) + 1
ref = torch.arange(RBLOCK, device="xpu", dtype=torch.int64) + 1
torch.testing.assert_close(out0, ref)
12 changes: 8 additions & 4 deletions python/test/unit/operators/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import triton.language as tl
import triton.ops

# FIXME remove this once Triton L0 queue and IPEX SYCL queue can be synchronized through events
torch.xpu.enable_sync_mode()


@pytest.mark.parametrize(
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE",
Expand Down Expand Up @@ -102,6 +105,7 @@
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32,
F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE):
pytest.skip("FIXME: Port get_device_capability to XPU")
ESI-SYD marked this conversation as resolved.
Show resolved Hide resolved
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
Expand Down Expand Up @@ -152,15 +156,15 @@ def upcast_if_fp8(x, dtype):
def init_input(m, n, dtype, acc_dtype):
if 'float8' in dtype:
ewidth = {'float8e4b15': 4, 'float8e4nv': 4, 'float8e5': 5}[dtype]
sign = torch.randint(2, size=(m, n), device="cuda", dtype=torch.int8) * 128
val = torch.randint(2**3 - 1, size=(m, n), device="cuda", dtype=torch.int8) << 7 - ewidth
sign = torch.randint(2, size=(m, n), device="xpu", dtype=torch.int8) * 128
val = torch.randint(2**3 - 1, size=(m, n), device="xpu", dtype=torch.int8) << 7 - ewidth
return sign | val
if dtype == "int8":
return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8)
return torch.randint(-128, 127, (m, n), device="xpu", dtype=torch.int8)
# Use small range of values to prevent numerical issues.
min_exp = -4 if acc_dtype == "float16" else -10
exponents = torch.randint(min_exp, 0, size=(m, n))
ret = (2.**exponents).to(getattr(torch, dtype)).to("cuda")
ret = (2.**exponents).to(getattr(torch, dtype)).to("xpu")
return ret

# allocate/transpose inputs
Expand Down
6 changes: 6 additions & 0 deletions scripts/test-triton.sh
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ function run_core_tests {
echo "FAILED: return code $?" ; exit $?
fi

cd $CORE_TEST_DIR/operators
whitneywhtsang marked this conversation as resolved.
Show resolved Hide resolved
TRITON_DISABLE_LINE_INFO=1 python3 -m pytest -n 8 --verbose
whitneywhtsang marked this conversation as resolved.
Show resolved Hide resolved
if [ $? -ne 0 ]; then
echo "FAILED: return code $?" ; exit $?
fi

cd $CORE_TEST_DIR/runtime
TRITON_DISABLE_LINE_INFO=1 python3 -m pytest --verbose
if [ $? -ne 0 ]; then
Expand Down
Loading