diff --git a/pya/Asig.py b/pya/Asig.py index 694aafe1..94e0ffa3 100644 --- a/pya/Asig.py +++ b/pya/Asig.py @@ -361,13 +361,9 @@ def __setitem__(self, index, value): # parse row index rindex into ridx # sr = self.sr # unused default case for conversion if not changed by special case # e.g. a[[4,5,7,8,9]], or a[[True, False, True...]] - if isinstance(rindex, list): - ridx = rindex - elif isinstance(rindex, int): # picking a single row - ridx = rindex - elif isinstance(rindex, slice): - # _, _, step = rindex.indices(len(self.sig)) - # sr = int(self.sr / abs(step)) # This is unused. + # if isinstance(rindex, (list, int, slice)): + # ridx = rindex + if isinstance(rindex, (slice, int)): ridx = rindex elif isinstance(rindex, dict): # time slicing for key, val in rindex.items(): @@ -380,27 +376,29 @@ def __setitem__(self, index, value): except TypeError: stop = None ridx = slice(start, stop, 1) - else: # Dont think there is a usecase. - ridx = rindex + elif hasattr(rindex, '__iter__'): + ridx = list(rindex) + else: + return # we cannot determine a row index; return without changes # now parse cindex - if isinstance(cindex, list): + if hasattr(cindex, '__iter__'): if isinstance(cindex[0], str): cidx = [self.col_name.get(s) for s in cindex] cidx = cidx[0] if len(cidx) == 1 else cidx # hotfix for now. - elif isinstance(cindex[0], bool): - cidx = cindex - elif isinstance(cindex[0], int): - cidx = cindex + else: + try: + cidx = list(cindex) + except TypeError: + cidx = slice(None) # int, slice are the same. - elif isinstance(cindex, int) or isinstance(cindex, slice): + elif isinstance(cindex, (int, slice)): cidx = cindex # if only a single channel name is given. elif isinstance(cindex, str): cidx = self.col_name.get(cindex) else: cidx = slice(None) - # cidx = None _LOGGER.debug("self.sig.ndim == %d", self.sig.ndim) if self.sig.ndim == 1: diff --git a/tests/test_setitem.py b/tests/test_setitem.py index 29a365cb..bff3d87b 100644 --- a/tests/test_setitem.py +++ b/tests/test_setitem.py @@ -104,8 +104,12 @@ def test_numpy_index(self): def test_byte_index(self): self.azeros[bytes([0, 1, 2])] = np.ones(3) - self.assertTrue(np.array_equal(self.azeros[0, 1, 2].sig, self.aones[0, 1, 2].sig)) + self.assertTrue(np.array_equal(self.azeros[[0, 1, 2]].sig, self.aones[[0, 1, 2]].sig)) + + def test_asig_index(self): + self.azeros[self.aones.sig.astype(np.bool)] = self.aones.sig + self.assertTrue(np.array_equal(np.ones(self.sr), self.azeros.sig)) def test_invalid_slicing_type(self): - self.azeros[self.aones] = self.aones - self.assertEqual(np.zeros(self.sr), self.azeros) \ No newline at end of file + self.azeros[self.aones] = self.aones.sig + self.assertTrue(np.array_equal(np.zeros(self.sr), self.azeros.sig))