Skip to content

Commit

Permalink
Improve efficiency of unyt_array.__getitem__ and add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
l-johnston committed Aug 22, 2020
1 parent 6fd4de1 commit 827f912
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
10 changes: 6 additions & 4 deletions unyt/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1624,10 +1624,12 @@ def ua(self):

def __getitem__(self, item):
ret = super(unyt_array, self).__getitem__(item)
if hasattr(ret, "shape") and ret.shape == ():
ret = unyt_quantity(ret, self.units, bypass_validation=True, name=self.name)
elif hasattr(ret, "units"):
ret.units = self.units
if getattr(ret, "shape", None) == ():
ret = unyt_quantity(ret, bypass_validation=True, name=self.name)
try:
setattr(ret, "units", self.units)
except AttributeError:
pass
return ret

#
Expand Down
20 changes: 18 additions & 2 deletions unyt/tests/test_unyt_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2426,8 +2426,24 @@ def test_ksi():


def test_masked_array():
data = unyt_array([1, 2], "s")
mask = [True, False]
data = unyt_array([1, 2, 3], "s")
mask = [False, False, True]
marr = np.ma.MaskedArray(data, mask)
assert_array_equal(marr.data, data)
assert all(marr.mask == mask)
assert marr.sum() == unyt_quantity(3, "s")
assert np.ma.notmasked_contiguous(marr) == [slice(0, 2, None)]
assert marr.argmax() == 1
assert marr.max() == unyt_quantity(2, "s")
data = unyt_array([1, 2, np.inf], "s")
marr = np.ma.MaskedArray(data)
marr_masked = np.ma.masked_invalid(marr)
assert all(marr_masked.mask == [False, False, True])
marr_masked.set_fill_value(unyt_quantity(3, "s"))
assert_array_equal(marr_masked.filled(), unyt_array([1, 2, 3], "s"))
marr_fixed = np.ma.fix_invalid(marr)
assert_array_equal(marr_fixed.data, unyt_array([1, 2, 1e20], "s"))
assert_array_equal(np.ma.filled(marr, unyt_quantity(3, "s")), data)
assert_array_equal(np.ma.compressed(marr_masked), unyt_array([1, 2], "s"))
# executing the repr should not raise an exception
marr.__repr__()

0 comments on commit 827f912

Please sign in to comment.