diff --git a/unyt/tests/test_array_functions.py b/unyt/tests/test_array_functions.py index 69b9bbb1..875ffc36 100644 --- a/unyt/tests/test_array_functions.py +++ b/unyt/tests/test_array_functions.py @@ -149,6 +149,13 @@ np.take_along_axis, # works out of the box (tested) } + +if NUMPY_VERSION >= Version("2.0.0dev0"): + NOOP_FUNCTIONS |= { + np.linalg.diagonal, # works out of the box (tested) + np.linalg.trace, # works out of the box (tested) + } + # Functions for which behaviour is intentionally left to default IGNORED_FUNCTIONS = { np.i0, @@ -441,6 +448,26 @@ def test_invalid_matrix_stack_linalg_pinv(): np.linalg.pinv(stack) +@pytest.mark.skipif( + NUMPY_VERSION < Version("2.0.0dev0"), reason="linalg.diagonal is new in numpy 2.0" +) +def test_linalg_diagonal(): + a = np.eye(3) * cm + b = np.linalg.diagonal(a) + assert type(b) is unyt_array + assert b.units == a.units + + +@pytest.mark.skipif( + NUMPY_VERSION < Version("2.0.0dev0"), reason="linalg.trace is new in numpy 2.0" +) +def test_linalg_trace(): + a = np.eye(3) * cm + b = np.linalg.trace(a) + assert type(b) is unyt_quantity + assert b.units == a.units + + def test_histogram(): arr = np.random.normal(size=1000) * cm counts, bins = np.histogram(arr, bins=10, range=(arr.min(), arr.max()))