diff --git a/pytest.ini b/pytest.ini
index b9bb2d26ca..8d456e23ba 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -4,3 +4,4 @@ testpaths = test/
 python_paths = ./
 markers =
     gpu_test: marks cuda tests
+    npu_test: marks ascend npu tests
diff --git a/test/torchtext_unittest/models/npu_tests/models_npu_test.py b/test/torchtext_unittest/models/npu_tests/models_npu_test.py
new file mode 100644
index 0000000000..f3ff4919b2
--- /dev/null
+++ b/test/torchtext_unittest/models/npu_tests/models_npu_test.py
@@ -0,0 +1,33 @@
+import importlib
+import unittest
+
+import pytest
+import torch
+from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
+from torchtext_unittest.models.roberta_models_test_impl import RobertaBaseTestModels
+from torchtext_unittest.models.t5_models_test_impl import T5BaseTestModels
+
+
+def is_npu_available(check_device=False):
+    "Checks if `torch_npu` is installed and, optionally, if an NPU is present in the environment"
+    if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
+        return False
+
+    import torch
+    import torch_npu  # noqa: F401
+
+    if check_device:
+        try:
+            # Will raise a RuntimeError if no NPU is found
+            _ = torch.npu.device_count()
+            return torch.npu.is_available()
+        except RuntimeError:
+            return False
+    return hasattr(torch, "npu") and torch.npu.is_available()
+
+
+@pytest.mark.npu_test
+@unittest.skipIf(not is_npu_available(), reason="Ascend NPU is not available")
+class TestModels32NPU(RobertaBaseTestModels, T5BaseTestModels, TorchtextTestCase):
    dtype = torch.float32
+    device = torch.device("npu")
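
How to exercise the change (a suggested invocation, not part of the patch; it relies only on the npu_test marker registered above and pytest's standard -m marker selection):

    # Run only the new Ascend NPU tests; they self-skip when no NPU is available.
    pytest -m npu_test test/torchtext_unittest/models/npu_tests/models_npu_test.py

    # Deselect them explicitly on CUDA-only or CPU-only machines.
    pytest -m "not npu_test" test/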