Skip to content

Commit

Permalink
fix tests; adjust Asig.__setitem__
Browse files Browse the repository at this point in the history
  • Loading branch information
aleneum committed Feb 28, 2020
1 parent c39cbaa commit ab2824b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 19 deletions.
30 changes: 14 additions & 16 deletions pya/Asig.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,13 +361,9 @@ def __setitem__(self, index, value):
# parse row index rindex into ridx
# sr = self.sr # unused default case for conversion if not changed by special case
# e.g. a[[4,5,7,8,9]], or a[[True, False, True...]]
if isinstance(rindex, list):
ridx = rindex
elif isinstance(rindex, int): # picking a single row
ridx = rindex
elif isinstance(rindex, slice):
# _, _, step = rindex.indices(len(self.sig))
# sr = int(self.sr / abs(step)) # This is unused.
# if isinstance(rindex, (list, int, slice)):
# ridx = rindex
if isinstance(rindex, (slice, int)):
ridx = rindex
elif isinstance(rindex, dict): # time slicing
for key, val in rindex.items():
Expand All @@ -380,27 +376,29 @@ def __setitem__(self, index, value):
except TypeError:
stop = None
ridx = slice(start, stop, 1)
else: # Dont think there is a usecase.
ridx = rindex
elif hasattr(rindex, '__iter__'):
ridx = list(rindex)
else:
return # we cannot determine a row index; return without changes

# now parse cindex
if isinstance(cindex, list):
if hasattr(cindex, '__iter__'):
if isinstance(cindex[0], str):
cidx = [self.col_name.get(s) for s in cindex]
cidx = cidx[0] if len(cidx) == 1 else cidx # hotfix for now.
elif isinstance(cindex[0], bool):
cidx = cindex
elif isinstance(cindex[0], int):
cidx = cindex
else:
try:
cidx = list(cindex)
except TypeError:
cidx = slice(None)
# int, slice are the same.
elif isinstance(cindex, int) or isinstance(cindex, slice):
elif isinstance(cindex, (int, slice)):
cidx = cindex
# if only a single channel name is given.
elif isinstance(cindex, str):
cidx = self.col_name.get(cindex)
else:
cidx = slice(None)
# cidx = None

_LOGGER.debug("self.sig.ndim == %d", self.sig.ndim)
if self.sig.ndim == 1:
Expand Down
10 changes: 7 additions & 3 deletions tests/test_setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,12 @@ def test_numpy_index(self):

def test_byte_index(self):
self.azeros[bytes([0, 1, 2])] = np.ones(3)
self.assertTrue(np.array_equal(self.azeros[0, 1, 2].sig, self.aones[0, 1, 2].sig))
self.assertTrue(np.array_equal(self.azeros[[0, 1, 2]].sig, self.aones[[0, 1, 2]].sig))

def test_asig_index(self):
self.azeros[self.aones.sig.astype(np.bool)] = self.aones.sig
self.assertTrue(np.array_equal(np.ones(self.sr), self.azeros.sig))

def test_invalid_slicing_type(self):
self.azeros[self.aones] = self.aones
self.assertEqual(np.zeros(self.sr), self.azeros)
self.azeros[self.aones] = self.aones.sig
self.assertTrue(np.array_equal(np.zeros(self.sr), self.azeros.sig))

0 comments on commit ab2824b

Please sign in to comment.