diff --git a/rsciio/_hierarchical.py b/rsciio/_hierarchical.py
index 149545bc5..c5e9bbc06 100644
--- a/rsciio/_hierarchical.py
+++ b/rsciio/_hierarchical.py
@@ -38,6 +38,28 @@
 _logger = logging.getLogger(__name__)
 
 
+# Functions to flatten and unflatten the data to allow for storing
+# ragged arrays in hdf5 with dimensionality higher than 1
+
+
+def flatten_data(x):
+    new_data = np.empty(shape=x.shape, dtype=object)
+    shapes = np.empty(shape=x.shape, dtype=object)
+    for i in np.ndindex(x.shape):
+        new_data[i] = x[i].ravel()
+        shapes[i] = np.array(x[i].shape)
+    return new_data, shapes
+
+
+def unflatten_data(data, shape):
+    new_data = np.empty(shape=data.shape, dtype=object)
+    for i in np.ndindex(new_data.shape):
+        new_data[i] = np.reshape(data[i], shape[i])
+    return new_data
+
+
+# ---------------------------------
+
+
 def get_signal_chunks(shape, dtype, signal_axes=None, target_size=1e6):
     """
@@ -243,13 +265,10 @@ def _read_array(group, dataset_key):
             # if the data is chunked saved array we must first
             # cast to a numpy array to avoid multiple calls to
             # _decode_chunk in zarr (or h5py)
-            ragged_shape = np.array(ragged_shape)
-            new_data = np.empty(shape=data.shape, dtype=object)
-            # cast to numpy array to stop multiple calls to _decode_chunk in zarr
-            data = np.array(data)
-            for i in np.ndindex(data.shape):
-                new_data[i] = np.reshape(data[i], ragged_shape[i])
-            data = new_data
+            data = da.from_array(data, chunks=data.chunks)
+            shape = da.from_array(ragged_shape, chunks=ragged_shape.chunks)
+            shape = shape.rechunk(data.chunks)
+            data = da.apply_gufunc(unflatten_data, "(),()->()", data, shape)
         return data
 
     def group2signaldict(self, group, lazy=False):
@@ -299,9 +318,12 @@ def group2signaldict(self, group, lazy=False):
         data = self._read_array(group, "data")
 
         if lazy:
-            data = da.from_array(data, chunks=data.chunks)
+            if not isinstance(data, da.Array):
+                data = da.from_array(data, chunks=data.chunks)
             exp["attributes"]["_lazy"] = True
         else:
+            if isinstance(data, da.Array):
+                data = data.compute()
             data = np.asanyarray(data)
             exp["attributes"]["_lazy"] = False
         exp["data"] = data
@@ -724,28 +746,32 @@ def overwrite_dataset(
     _logger.info(f"Chunks used for saving: {chunks}")
 
     if data.dtype == np.dtype("O"):
-        new_data = np.empty(shape=data.shape, dtype=object)
-        shapes = np.empty(shape=data.shape, dtype=object)
-        for i in np.ndindex(data.shape):
-            new_data[i] = data[i].ravel()
-            shapes[i] = np.array(data[i].shape)
+        if isinstance(data, da.Array):
+            new_data, shapes = da.apply_gufunc(
+                flatten_data,
+                "()->(),()",
+                data,
+                dtype=object,
+                output_dtypes=[object, object],
+                allow_rechunk=False,
+            )
+        else:
+            new_data = np.empty(shape=data.shape, dtype=object)
+            shapes = np.empty(shape=data.shape, dtype=object)
+            for i in np.ndindex(data.shape):
+                new_data[i] = data[i].ravel()
+                shapes[i] = np.array(data[i].shape)
+
         shape_dset = cls._get_object_dset(
             group, shapes, f"_ragged_shapes_{key}", shapes.shape, **kwds
         )
+
         cls._store_data(
-            shapes,
-            shape_dset,
-            group,
-            f"_ragged_shapes_{key}",
-            chunks=shapes.shape,
-            show_progressbar=show_progressbar,
-        )
-        cls._store_data(
-            new_data,
-            dset,
+            (new_data, shapes),
+            (dset, shape_dset),
             group,
-            key,
-            chunks,
+            (key, f"_ragged_shapes_{key}"),
+            (chunks, shapes.shape),
            show_progressbar,
         )
     else:
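A note on the blockwise round trip introduced above: `flatten_data` and `unflatten_data` both treat their input as an object array and loop over it with `np.ndindex`, so `da.apply_gufunc` can map them over each chunk of a ragged dask array. A minimal, standalone sketch of that round trip follows (the two helpers are copied from the patch rather than imported from rsciio; the array shape and chunking are illustrative only):

import dask.array as da
import numpy as np

def flatten_data(x):
    new_data = np.empty(shape=x.shape, dtype=object)
    shapes = np.empty(shape=x.shape, dtype=object)
    for i in np.ndindex(x.shape):
        new_data[i] = x[i].ravel()
        shapes[i] = np.array(x[i].shape)
    return new_data, shapes

def unflatten_data(data, shape):
    new_data = np.empty(shape=data.shape, dtype=object)
    for i in np.ndindex(new_data.shape):
        new_data[i] = np.reshape(data[i], shape[i])
    return new_data

# build a small 2-D ragged array: every element is an (n, 2) float array
rng = np.random.default_rng(0)
ragged = np.empty((4, 3), dtype=object)
for i in np.ndindex(ragged.shape):
    ragged[i] = rng.random((int(rng.integers(1, 5)), 2))

x = da.from_array(ragged, chunks=(2, 3))
# "()->(),()": one object element in, two object elements out (data, shape)
flat, shapes = da.apply_gufunc(
    flatten_data, "()->(),()", x, dtype=object, output_dtypes=[object, object]
)
# "(),()->()": recombine each flattened element with its stored shape
restored = da.apply_gufunc(
    unflatten_data, "(),()->()", flat, shapes, output_dtypes=object
).compute()
for i in np.ndindex(ragged.shape):
    np.testing.assert_array_equal(restored[i], ragged[i])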
diff --git a/rsciio/hspy/_api.py b/rsciio/hspy/_api.py
index ae4c2f5e2..c51970572 100644
--- a/rsciio/hspy/_api.py
+++ b/rsciio/hspy/_api.py
@@ -73,16 +73,40 @@ def __init__(self, file, signal, expg, **kwds):
 
     @staticmethod
     def _store_data(data, dset, group, key, chunks, show_progressbar=True):
-        if isinstance(data, da.Array):
-            if data.chunks != dset.chunks:
-                data = data.rechunk(dset.chunks)
+        # A tuple of dask arrays can also be passed, in which case the task
+        # graphs are merged and the data is written in a single `da.store` call.
+        # This is useful when saving a ragged array, where the data and the
+        # shapes must be written together, since each ragged element can only
+        # be stored as a flattened, one-dimensional array.
+        if isinstance(data, tuple):
+            data = list(data)
+        elif not isinstance(data, list):
+            data = [
+                data,
+            ]
+            dset = [
+                dset,
+            ]
+        for i, (data_, dset_) in enumerate(zip(data, dset)):
+            if isinstance(data_, da.Array):
+                if data_.chunks != dset_.chunks:
+                    data[i] = data_.rechunk(dset_.chunks)
+                if data_.ndim == 1 and data_.dtype == object:
+                    raise ValueError(
+                        "Saving a 1-D ragged dask array to hspy is not supported yet. "
+                        "Please use the .zspy extension."
+                    )
+                # for performance reasons, the dask arrays are written later,
+                # all at once in a single `da.store` call
+            elif data_.flags.c_contiguous:
+                dset_.write_direct(data_)
+            else:
+                dset_[:] = data_
+        if isinstance(data[0], da.Array):
             cm = ProgressBar if show_progressbar else dummy_context_manager
             with cm():
+                # storing a list merges the task graphs and avoids computing twice
                 da.store(data, dset)
-        elif data.flags.c_contiguous:
-            dset.write_direct(data)
-        else:
-            dset[:] = data
 
     @staticmethod
     def _get_object_dset(group, data, key, chunks, **kwds):
@@ -90,8 +114,13 @@ def _get_object_dset(group, data, key, chunks, **kwds):
         # For saving ragged array
         if chunks is None:
             chunks = 1
+        test_ind = data.ndim * (0,)
+        if isinstance(data, da.Array):
+            dtype = data[test_ind].compute().dtype
+        else:
+            dtype = data[test_ind].dtype
         dset = group.require_dataset(
-            key, chunks, dtype=h5py.special_dtype(vlen=data.flatten()[0].dtype), **kwds
+            key, data.shape, dtype=h5py.special_dtype(vlen=dtype), chunks=chunks, **kwds
         )
         return dset
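Why `_store_data` now funnels everything through one `da.store` call: when the flattened data and the shapes are two dask arrays derived from the same source, storing them separately would traverse the shared task graph twice. A small sketch of the pattern, independent of rsciio (file and dataset names are made up):

import dask.array as da
import h5py

x = da.random.random((64, 64), chunks=(32, 32))
row_sums = x.sum(axis=1)  # shares x's task graph

with h5py.File("example.h5", "w") as f:
    d1 = f.create_dataset("data", shape=x.shape, dtype=x.dtype, chunks=(32, 32))
    d2 = f.create_dataset("row_sums", shape=row_sums.shape, dtype=row_sums.dtype)
    # one call with lists: dask merges both graphs, so each chunk of x is
    # computed once and feeds both targets
    da.store([x, row_sums], [d1, d2])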
diff --git a/rsciio/tests/test_hspy.py b/rsciio/tests/test_hspy.py
index 57e6dbe66..6d2534da2 100644
--- a/rsciio/tests/test_hspy.py
+++ b/rsciio/tests/test_hspy.py
@@ -857,7 +857,8 @@ def test_save_ragged_array(tmp_path, file):
 
 @zspy_marker
 @pytest.mark.parametrize("nav_dim", [1, 2, 3])
-def test_save_ragged_dim(tmp_path, file, nav_dim):
+@pytest.mark.parametrize("lazy", [True, False])
+def test_save_ragged_dim(tmp_path, file, nav_dim, lazy):
     file = f"nav{nav_dim}_" + file
     rng = np.random.default_rng(0)
     nav_shape = np.arange(10, 10 * (nav_dim + 1), step=10)
@@ -867,19 +868,28 @@ def test_save_ragged_dim(tmp_path, file, nav_dim):
         data[ind] = rng.random((num, 2)) * 100
 
     s = hs.signals.BaseSignal(data, ragged=True)
+    if lazy:
+        s = s.as_lazy()
     assert s.axes_manager.navigation_dimension == nav_dim
     np.testing.assert_allclose(s.axes_manager.navigation_shape, nav_shape[::-1])
     assert s.data.ndim == nav_dim
     np.testing.assert_allclose(s.data.shape, nav_shape)
 
     filename = tmp_path / file
-    s.save(filename)
-    s2 = hs.load(filename)
-    assert s.axes_manager.navigation_shape == s2.axes_manager.navigation_shape
-    assert s.data.shape == s2.data.shape
+    if ".hspy" in file and nav_dim == 1 and lazy:
+        with pytest.raises(ValueError):
+            s.save(filename)
+
+    else:
+        s.save(filename)
+        s2 = hs.load(filename, lazy=lazy)
+        assert s.axes_manager.navigation_shape == s2.axes_manager.navigation_shape
+        assert s.data.shape == s2.data.shape
+        if lazy:
+            assert isinstance(s2.data, da.Array)
 
-    for indices in np.ndindex(s.data.shape):
-        np.testing.assert_allclose(s.data[indices], s2.data[indices])
+        for indices in np.ndindex(s.data.shape):
+            np.testing.assert_allclose(s.data[indices], s2.data[indices])
 
 
 def test_load_missing_extension(caplog):
diff --git a/rsciio/zspy/_api.py b/rsciio/zspy/_api.py
index 5c831d660..39bade250 100644
--- a/rsciio/zspy/_api.py
+++ b/rsciio/zspy/_api.py
@@ -102,26 +102,51 @@ def _get_object_dset(group, data, key, chunks, **kwds):
             chunks = data.shape
         these_kwds = kwds.copy()
         these_kwds.update(dict(dtype=object, exact=True, chunks=chunks))
+        test_ind = data.ndim * (0,)
+        # Need to know the underlying dtype of the ragged elements for the codec
+        # Note that this cannot be an object dtype
+        if isinstance(data, da.Array):
+            dtype = data[test_ind].compute().dtype
+        else:
+            dtype = data[test_ind].dtype
         dset = group.require_dataset(
             key,
             data.shape,
-            object_codec=numcodecs.VLenArray(data.flatten()[0].dtype),
+            object_codec=numcodecs.VLenArray(dtype),
             **these_kwds,
         )
        return dset
 
     @staticmethod
     def _store_data(data, dset, group, key, chunks, show_progressbar=True):
-        """Write data to zarr format."""
-        if isinstance(data, da.Array):
-            if data.chunks != dset.chunks:
-                data = data.rechunk(dset.chunks)
+        # A tuple of dask arrays can also be passed, in which case the task
+        # graphs are merged and the data is written in a single `da.store` call.
+        # This is useful when saving a ragged array, where the data and the
+        # shapes must be written together, since each ragged element can only
+        # be stored as a flattened, one-dimensional array.
+        if isinstance(data, tuple):
+            data = list(data)
+        elif not isinstance(data, list):
+            data = [
+                data,
+            ]
+            dset = [
+                dset,
+            ]
+        for i, (data_, dset_) in enumerate(zip(data, dset)):
+            if isinstance(data_, da.Array):
+                if data_.chunks != dset_.chunks:
+                    data[i] = data_.rechunk(dset_.chunks)
+                # for performance reasons, the dask arrays are written later,
+                # all at once in a single `da.store` call
+            else:
+                dset_[:] = data_
+        if isinstance(data[0], da.Array):
             cm = ProgressBar if show_progressbar else dummy_context_manager
             with cm():
                 # lock=False is necessary with the distributed scheduler
-                data.store(dset, lock=False)
-        else:
-            dset[:] = data
+                # storing a list merges the task graphs and avoids computing twice
+                da.store(data, dset, lock=False)
 
 
 def file_writer(
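For reference, the dtype handling in `_get_object_dset` exists because `numcodecs.VLenArray` must be constructed with the dtype of the ragged elements, not `object`. A standalone sketch of how a flattened ragged block round-trips through zarr (zarr 2.x API; the store name and values are illustrative):

import numcodecs
import numpy as np
import zarr

rng = np.random.default_rng(0)
data = np.empty((2, 3), dtype=object)
for i in np.ndindex(data.shape):
    # already flattened: every element is a 1-D float array
    data[i] = rng.random(int(rng.integers(1, 5)))

group = zarr.open_group("example.zarr", mode="w")
dset = group.require_dataset(
    "flattened",
    data.shape,
    dtype=object,
    exact=True,
    chunks=data.shape,
    # the codec needs the dtype of the element arrays, not `object`
    object_codec=numcodecs.VLenArray(data[0, 0].dtype),
)
dset[:] = data
np.testing.assert_array_equal(group["flattened"][0, 0], data[0, 0])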