Inference UTs check for triton support from accelerator #6782

Merged · 10 commits · Dec 10, 2024
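Across all six test files, the change replaces the global deepspeed.HAS_TRITON flag with a query on the active accelerator, deepspeed.get_accelerator().is_triton_supported(), so the skip decision follows what the running hardware supports rather than whether the triton package happens to be importable. Note that the old guard in several files only skipped when use_triton_ops was set, while the new guard skips whenever the accelerator lacks triton support. As a minimal sketch of what such an accelerator hook could look like (only the method name is taken from this diff; the class and body below are assumptions for illustration, not DeepSpeed's actual implementation):

    import torch

    class ExampleCudaAccelerator:
        # Hypothetical accelerator class for illustration; DeepSpeed's real
        # accelerator implementations live under deepspeed/accelerator/.
        def is_triton_supported(self):
            # Assumption: treat triton as usable only on NVIDIA GPUs with
            # compute capability >= 8.0 (Ampere or newer).
            if not torch.cuda.is_available():
                return False
            major, _ = torch.cuda.get_device_capability()
            return major >= 8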
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_attention.py
@@ -27,8 +27,8 @@ def ref_torch_attention(q, k, v, mask, sm_scale):
 @pytest.mark.parametrize("causal", [True, False])
 @pytest.mark.parametrize("use_flash", [True, False])
 def test_attention(BATCH, H, N_CTX, D_HEAD, causal, use_flash, dtype=torch.float16):
-    if not deepspeed.HAS_TRITON:
-        pytest.skip("triton has to be installed for the test")
+    if not deepspeed.get_accelerator().is_triton_supported():
+        pytest.skip("triton is not supported on this system")

     minus_inf = -65504.0
     dev = deepspeed.accelerator.get_accelerator().device_name()
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_gelu.py
@@ -61,8 +61,8 @@ def test_gelu(batch, sequence, channels, dtype, use_triton_ops):
     activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=device)
     activations_ref = activations_ds.clone().detach()

-    if not deepspeed.HAS_TRITON and use_triton_ops:
-        pytest.skip("triton has to be installed for the test")
+    if not deepspeed.get_accelerator().is_triton_supported():
+        pytest.skip("triton is not supported on this system")
     ds_out = run_gelu_ds(activations_ds, use_triton_ops)
     ref_out = run_gelu_reference(activations_ref)
     assert (allclose(ds_out, ref_out))
12 changes: 6 additions & 6 deletions tests/unit/ops/transformer/inference/test_layer_norm.py
@@ -45,8 +45,8 @@ def ds_triton_implementation(vals, gamma, beta, epsilon):
 @pytest.mark.parametrize("dtype", get_dtypes())
 @pytest.mark.parametrize("use_triton_ops", [False, True])
 def test_layer_norm(batch, seq_len, channels, dtype, use_triton_ops):
-    if not deepspeed.HAS_TRITON and use_triton_ops:
-        pytest.skip("triton has to be installed for the test")
+    if not deepspeed.get_accelerator().is_triton_supported():
+        pytest.skip("triton is not supported on this system")

     vals = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name())
     gamma = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name())
@@ -93,8 +93,8 @@ def residual_ds_triton_implementation(vals, bias, res, gamma, beta, epsilon):
 @pytest.mark.parametrize("dtype", get_dtypes())
 @pytest.mark.parametrize("use_triton_ops", [False, True])
 def test_layer_norm_residual(batch, seq_len, channels, dtype, use_triton_ops):
-    if not deepspeed.HAS_TRITON and use_triton_ops:
-        pytest.skip("triton has to be installed for the test")
+    if not deepspeed.get_accelerator().is_triton_supported():
+        pytest.skip("triton is not supported on this system")

     vals = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name())
     residual = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name())
@@ -163,8 +163,8 @@ def test_layer_norm_residual_store_pre_ln_res(batch, seq_len, channels, dtype):
 @pytest.mark.parametrize("residual", [True, False])
 @pytest.mark.parametrize("input_bias", [True, False])
 def test_triton_layer_norm(M, N, dtype, residual, input_bias, eps=1e-5, device='cuda'):
-    if not deepspeed.HAS_TRITON:
-        pytest.skip("triton has to be installed for the test")
+    if not deepspeed.get_accelerator().is_triton_supported():
+        pytest.skip("triton is not supported on this system")
     dev = get_accelerator().device_name()
     torch.manual_seed(0)
     # create data
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_matmul.py
@@ -42,8 +42,8 @@ def run_matmul_ds(a, b, use_triton_ops=False):
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("use_triton_ops", [True])
 def test_matmul_4d(B, H, M, K, N, dtype, use_triton_ops):
-    if not deepspeed.HAS_TRITON and use_triton_ops:
-        pytest.skip("triton has to be installed for the test")
+    if not deepspeed.get_accelerator().is_triton_supported():
+        pytest.skip("triton is not supported on this system")

     # skip autotune in testing
     from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_residual_add.py
@@ -74,8 +74,8 @@ def run_residual_add_reference(hidden_state, residual, attn_output, attn_bias, f
 @pytest.mark.parametrize("use_triton_ops", [True, False])
 def test_residual_add(batch, sequence, hidden_dim, dtype, mlp_after_attn, add_bias, mp_size, pre_attn_norm,
                       use_triton_ops):
-    if not deepspeed.HAS_TRITON and use_triton_ops:
-        pytest.skip("triton has to be installed for the test")
+    if not deepspeed.get_accelerator().is_triton_supported():
+        pytest.skip("triton is not supported on this system")
     ds_out = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
     residual = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
     attn_output = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_softmax.py
@@ -40,8 +40,8 @@ def run_softmax_ds(input, use_triton_ops=False):
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
 @pytest.mark.parametrize("use_triton_ops", [True])
 def test_softmax(batch, sequence, channels, dtype, use_triton_ops):
-    if not deepspeed.HAS_TRITON and use_triton_ops:
-        pytest.skip("triton has to be installed for the test")
+    if not deepspeed.get_accelerator().is_triton_supported():
+        pytest.skip("triton is not supported on this system")

     device = deepspeed.accelerator.get_accelerator().device_name()
     input_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=device)
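With this guard in place, running the touched modules on a machine whose accelerator reports no triton support should skip every parametrization rather than failing at import or kernel launch. The paths below come from the diff; the pytest flags are just one reasonable invocation, not a project-mandated command:

    pytest -v tests/unit/ops/transformer/inference/test_gelu.py \
        tests/unit/ops/transformer/inference/test_softmax.py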