From 827f912fbed05a69734ef05f4ccf78eefa9f367e Mon Sep 17 00:00:00 2001 From: Lee Johnston Date: Sat, 22 Aug 2020 16:30:18 -0500 Subject: [PATCH] Improve efficiency of unyt_array.__getitem__ and add more tests --- unyt/array.py | 10 ++++++---- unyt/tests/test_unyt_array.py | 20 ++++++++++++++++++-- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/unyt/array.py b/unyt/array.py index 605ae1b4..4b5ecc38 100644 --- a/unyt/array.py +++ b/unyt/array.py @@ -1624,10 +1624,12 @@ def ua(self): def __getitem__(self, item): ret = super(unyt_array, self).__getitem__(item) - if hasattr(ret, "shape") and ret.shape == (): - ret = unyt_quantity(ret, self.units, bypass_validation=True, name=self.name) - elif hasattr(ret, "units"): - ret.units = self.units + if getattr(ret, "shape", None) == (): + ret = unyt_quantity(ret, bypass_validation=True, name=self.name) + try: + setattr(ret, "units", self.units) + except AttributeError: + pass return ret # diff --git a/unyt/tests/test_unyt_array.py b/unyt/tests/test_unyt_array.py index 7f915f07..971f8be8 100644 --- a/unyt/tests/test_unyt_array.py +++ b/unyt/tests/test_unyt_array.py @@ -2426,8 +2426,24 @@ def test_ksi(): def test_masked_array(): - data = unyt_array([1, 2], "s") - mask = [True, False] + data = unyt_array([1, 2, 3], "s") + mask = [False, False, True] marr = np.ma.MaskedArray(data, mask) + assert_array_equal(marr.data, data) + assert all(marr.mask == mask) + assert marr.sum() == unyt_quantity(3, "s") + assert np.ma.notmasked_contiguous(marr) == [slice(0, 2, None)] + assert marr.argmax() == 1 + assert marr.max() == unyt_quantity(2, "s") + data = unyt_array([1, 2, np.inf], "s") + marr = np.ma.MaskedArray(data) + marr_masked = np.ma.masked_invalid(marr) + assert all(marr_masked.mask == [False, False, True]) + marr_masked.set_fill_value(unyt_quantity(3, "s")) + assert_array_equal(marr_masked.filled(), unyt_array([1, 2, 3], "s")) + marr_fixed = np.ma.fix_invalid(marr) + assert_array_equal(marr_fixed.data, unyt_array([1, 2, 1e20], "s")) + assert_array_equal(np.ma.filled(marr, unyt_quantity(3, "s")), data) + assert_array_equal(np.ma.compressed(marr_masked), unyt_array([1, 2], "s")) # executing the repr should not raise an exception marr.__repr__()