Inference UTs check for triton support from accelerator #6782

Open · wants to merge 3 commits into base: master
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_attention.py
@@ -27,8 +27,8 @@ def ref_torch_attention(q, k, v, mask, sm_scale):
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("use_flash", [True, False])
def test_attention(BATCH, H, N_CTX, D_HEAD, causal, use_flash, dtype=torch.float16):
if not deepspeed.HAS_TRITON:
pytest.skip("triton has to be installed for the test")
if not deepspeed.get_accelerator().is_triton_supported():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @raza-sikander - do you think we can extrapolate this to replace all instances of HAS_TRITON?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@loadams Yes.
This replacing would help cover the case where triton is installed on system but its not supported by device, test would still run as the triton has been installed and fail.
So the ideal case would be to check if it is supported.

pytest.skip("triton is not supported on this system")

minus_inf = -65504.0
dev = deepspeed.accelerator.get_accelerator().device_name()
Expand Down
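To illustrate the distinction the author describes, here is a minimal, hypothetical sketch of an install-level check versus a device-level check. It is not DeepSpeed's implementation, and the capability threshold below is only an assumed example of a device-level criterion; it shows how the two predicates can disagree on a machine where the triton package is installed but the device cannot run its kernels.

# Illustrative sketch only -- not DeepSpeed's code. The capability check is a
# hypothetical example of a device-level criterion.
import importlib.util


def has_triton_package() -> bool:
    # What a HAS_TRITON-style flag captures: the triton package is importable.
    return importlib.util.find_spec("triton") is not None


def triton_supported_on_this_device() -> bool:
    # What an accelerator-level check can additionally capture: the active
    # device can actually run Triton kernels, not just import the package.
    if not has_triton_package():
        return False
    import torch
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7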
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_gelu.py
@@ -61,8 +61,8 @@ def test_gelu(batch, sequence, channels, dtype, use_triton_ops):
     activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=device)
     activations_ref = activations_ds.clone().detach()

-    if not deepspeed.HAS_TRITON and use_triton_ops:
-        pytest.skip("triton has to be installed for the test")
+    if not deepspeed.get_accelerator().is_triton_supported():
+        pytest.skip("triton is not supported on this system")
     ds_out = run_gelu_ds(activations_ds, use_triton_ops)
     ref_out = run_gelu_reference(activations_ref)
     assert (allclose(ds_out, ref_out))
12 changes: 6 additions & 6 deletions tests/unit/ops/transformer/inference/test_layer_norm.py
@@ -45,8 +45,8 @@ def ds_triton_implementation(vals, gamma, beta, epsilon):
 @pytest.mark.parametrize("dtype", get_dtypes())
 @pytest.mark.parametrize("use_triton_ops", [False, True])
 def test_layer_norm(batch, seq_len, channels, dtype, use_triton_ops):
-    if not deepspeed.HAS_TRITON and use_triton_ops:
-        pytest.skip("triton has to be installed for the test")
+    if not deepspeed.get_accelerator().is_triton_supported():
+        pytest.skip("triton is not supported on this system")

     vals = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name())
     gamma = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name())
@@ -93,8 +93,8 @@ def residual_ds_triton_implementation(vals, bias, res, gamma, beta, epsilon):
 @pytest.mark.parametrize("dtype", get_dtypes())
 @pytest.mark.parametrize("use_triton_ops", [False, True])
 def test_layer_norm_residual(batch, seq_len, channels, dtype, use_triton_ops):
-    if not deepspeed.HAS_TRITON and use_triton_ops:
-        pytest.skip("triton has to be installed for the test")
+    if not deepspeed.get_accelerator().is_triton_supported():
+        pytest.skip("triton is not supported on this system")

     vals = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name())
     residual = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name())
@@ -163,8 +163,8 @@ def test_layer_norm_residual_store_pre_ln_res(batch, seq_len, channels, dtype):
 @pytest.mark.parametrize("residual", [True, False])
 @pytest.mark.parametrize("input_bias", [True, False])
 def test_triton_layer_norm(M, N, dtype, residual, input_bias, eps=1e-5, device='cuda'):
-    if not deepspeed.HAS_TRITON:
-        pytest.skip("triton has to be installed for the test")
+    if not deepspeed.get_accelerator().is_triton_supported():
+        pytest.skip("triton is not supported on this system")
     dev = get_accelerator().device_name()
     torch.manual_seed(0)
     # create data
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_matmul.py
@@ -42,8 +42,8 @@ def run_matmul_ds(a, b, use_triton_ops=False):
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("use_triton_ops", [True])
 def test_matmul_4d(B, H, M, K, N, dtype, use_triton_ops):
-    if not deepspeed.HAS_TRITON and use_triton_ops:
-        pytest.skip("triton has to be installed for the test")
+    if not deepspeed.get_accelerator().is_triton_supported():
+        pytest.skip("triton is not supported on this system")

     # skip autotune in testing
     from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_residual_add.py
@@ -74,8 +74,8 @@ def run_residual_add_reference(hidden_state, residual, attn_output, attn_bias, f
 @pytest.mark.parametrize("use_triton_ops", [True, False])
 def test_residual_add(batch, sequence, hidden_dim, dtype, mlp_after_attn, add_bias, mp_size, pre_attn_norm,
                       use_triton_ops):
-    if not deepspeed.HAS_TRITON and use_triton_ops:
-        pytest.skip("triton has to be installed for the test")
+    if not deepspeed.get_accelerator().is_triton_supported():
+        pytest.skip("triton is not supported on this system")
     ds_out = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
     residual = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
     attn_output = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_softmax.py
@@ -40,8 +40,8 @@ def run_softmax_ds(input, use_triton_ops=False):
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
 @pytest.mark.parametrize("use_triton_ops", [True])
 def test_softmax(batch, sequence, channels, dtype, use_triton_ops):
-    if not deepspeed.HAS_TRITON and use_triton_ops:
-        pytest.skip("triton has to be installed for the test")
+    if not deepspeed.get_accelerator().is_triton_supported():
+        pytest.skip("triton is not supported on this system")

     device = deepspeed.accelerator.get_accelerator().device_name()
     input_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=device)
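The same two-line guard appears in every file touched by this PR. One way the "extrapolate this to all instances of HAS_TRITON" suggestion could be taken further is a shared pytest marker. This is only a sketch, not part of the change, and it assumes get_accelerator().is_triton_supported() is importable from deepspeed.accelerator at collection time, as the diff already does at test time.

# Hypothetical refactor sketch, not part of this PR: centralize the repeated
# skip guard as a reusable pytest marker.
import pytest
from deepspeed.accelerator import get_accelerator

requires_triton = pytest.mark.skipif(
    not get_accelerator().is_triton_supported(),
    reason="triton is not supported on this system",
)


@requires_triton
def test_softmax_example():
    # Each Triton-dependent test would carry the marker instead of repeating
    # the inline check at the top of its body.
    ...

A marker like this keeps the skip reason in one place and evaluates the support check once at collection time instead of inside each test body.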