Skip to content

Commit

Permalink
Merge pull request #454 from neutrinoceros/test_new_nep18_funcs
Browse files Browse the repository at this point in the history
ENH: (NEP 18) add tests for new functions in numpy.linalg
  • Loading branch information
jzuhone authored Oct 13, 2023
2 parents 6523a7e + a6f635f commit 02824de
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions unyt/tests/test_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,13 @@
np.take_along_axis, # works out of the box (tested)
}


if NUMPY_VERSION >= Version("2.0.0dev0"):
NOOP_FUNCTIONS |= {
np.linalg.diagonal, # works out of the box (tested)
np.linalg.trace, # works out of the box (tested)
}

# Functions for which behaviour is intentionally left to default
IGNORED_FUNCTIONS = {
np.i0,
Expand Down Expand Up @@ -441,6 +448,26 @@ def test_invalid_matrix_stack_linalg_pinv():
np.linalg.pinv(stack)


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"), reason="linalg.diagonal is new in numpy 2.0"
)
def test_linalg_diagonal():
a = np.eye(3) * cm
b = np.linalg.diagonal(a)
assert type(b) is unyt_array
assert b.units == a.units


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"), reason="linalg.trace is new in numpy 2.0"
)
def test_linalg_trace():
a = np.eye(3) * cm
b = np.linalg.trace(a)
assert type(b) is unyt_quantity
assert b.units == a.units


def test_histogram():
arr = np.random.normal(size=1000) * cm
counts, bins = np.histogram(arr, bins=10, range=(arr.min(), arr.max()))
Expand Down

0 comments on commit 02824de

Please sign in to comment.