Add saving of ragged lazy vectors #193

Merged · 6 commits · Dec 12, 2023
76 changes: 51 additions & 25 deletions rsciio/_hierarchical.py
@@ -38,6 +38,28 @@
 
 _logger = logging.getLogger(__name__)
 
+# Functions to flatten and unflatten the data to allow for storing
+# ragged arrays in hdf5 with dimensionality higher than 1
+
+
+def flatten_data(x):
+    new_data = np.empty(shape=x.shape, dtype=object)
+    shapes = np.empty(shape=x.shape, dtype=object)
+    for i in np.ndindex(x.shape):
+        new_data[i] = x[i].ravel()
+        shapes[i] = np.array(x[i].shape)
+    return new_data, shapes
+
+
+def unflatten_data(data, shape):
+    new_data = np.empty(shape=data.shape, dtype=object)
+    for i in np.ndindex(new_data.shape):
+        new_data[i] = np.reshape(data[i], shape[i])
+    return new_data
+
+
+# ---------------------------------
+
 
 def get_signal_chunks(shape, dtype, signal_axes=None, target_size=1e6):
     """
@@ -243,13 +265,10 @@ def _read_array(group, dataset_key):
             # if the data is chunked saved array we must first
             # cast to a numpy array to avoid multiple calls to
             # _decode_chunk in zarr (or h5py)
-            ragged_shape = np.array(ragged_shape)
-            new_data = np.empty(shape=data.shape, dtype=object)
-            # cast to numpy array to stop multiple calls to _decode_chunk in zarr
-            data = np.array(data)
-            for i in np.ndindex(data.shape):
-                new_data[i] = np.reshape(data[i], ragged_shape[i])
-            data = new_data
+            data = da.from_array(data, chunks=data.chunks)
+            shape = da.from_array(ragged_shape, chunks=ragged_shape.chunks)
+            shape = shape.rechunk(data.chunks)
+            data = da.apply_gufunc(unflatten_data, "(),()->()", data, shape)
         return data
 
     def group2signaldict(self, group, lazy=False):
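The unflattening above runs lazily through da.apply_gufunc; a sketch of the same call on in-memory object arrays instead of on-disk datasets (helper import assumed as before):

import dask.array as da
import numpy as np

from rsciio._hierarchical import unflatten_data  # assumed import path

flat = np.empty((2,), dtype=object)
flat[0], flat[1] = np.arange(6.0), np.arange(4.0)
shapes = np.empty((2,), dtype=object)
shapes[0], shapes[1] = np.array((2, 3)), np.array((4, 1))

d_flat = da.from_array(flat, chunks=1)
d_shapes = da.from_array(shapes, chunks=d_flat.chunks)
# "(),()->()" applies unflatten_data blockwise with no core dimensions,
# so each navigation element is reshaped independently
lazy = da.apply_gufunc(unflatten_data, "(),()->()", d_flat, d_shapes, output_dtypes=object)
assert lazy.compute()[0].shape == (2, 3)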
@@ -299,9 +318,12 @@ def group2signaldict(self, group, lazy=False):
 
         data = self._read_array(group, "data")
         if lazy:
-            data = da.from_array(data, chunks=data.chunks)
+            if not isinstance(data, da.Array):
+                data = da.from_array(data, chunks=data.chunks)
             exp["attributes"]["_lazy"] = True
         else:
+            if isinstance(data, da.Array):
+                data = data.compute()
             data = np.asanyarray(data)
             exp["attributes"]["_lazy"] = False
         exp["data"] = data
@@ -724,28 +746,32 @@ def overwrite_dataset(
 
         _logger.info(f"Chunks used for saving: {chunks}")
         if data.dtype == np.dtype("O"):
-            new_data = np.empty(shape=data.shape, dtype=object)
-            shapes = np.empty(shape=data.shape, dtype=object)
-            for i in np.ndindex(data.shape):
-                new_data[i] = data[i].ravel()
-                shapes[i] = np.array(data[i].shape)
+            if isinstance(data, da.Array):
+                new_data, shapes = da.apply_gufunc(
+                    flatten_data,
+                    "()->(),()",
+                    data,
+                    dtype=object,
+                    output_dtypes=[object, object],
+                    allow_rechunk=False,
+                )
+            else:
+                new_data = np.empty(shape=data.shape, dtype=object)
+                shapes = np.empty(shape=data.shape, dtype=object)
+                for i in np.ndindex(data.shape):
+                    new_data[i] = data[i].ravel()
+                    shapes[i] = np.array(data[i].shape)
+
             shape_dset = cls._get_object_dset(
                 group, shapes, f"_ragged_shapes_{key}", shapes.shape, **kwds
             )
 
-            cls._store_data(
-                shapes,
-                shape_dset,
-                group,
-                f"_ragged_shapes_{key}",
-                chunks=shapes.shape,
-                show_progressbar=show_progressbar,
-            )
             cls._store_data(
-                new_data,
-                dset,
+                (new_data, shapes),
+                (dset, shape_dset),
                 group,
-                key,
-                chunks,
+                (key, f"_ragged_shapes_{key}"),
+                (chunks, shapes.shape),
                 show_progressbar,
             )
         else:
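The save path runs the same trick in reverse: flatten_data is mapped lazily and returns two object arrays that share the input chunking. A sketch of that apply_gufunc call, under the same import assumption:

import dask.array as da
import numpy as np

from rsciio._hierarchical import flatten_data  # assumed import path

data = np.empty((3,), dtype=object)
for i in range(3):
    data[i] = np.ones((i + 1, 2))

d = da.from_array(data, chunks=1)
new_data, shapes = da.apply_gufunc(
    flatten_data,
    "()->(),()",
    d,
    output_dtypes=[object, object],
    allow_rechunk=False,
)
# both outputs inherit d's chunks, so they can be written in a single pass
flat, shp = da.compute(new_data, shapes)
assert flat[1].shape == (4,) and tuple(shp[1]) == (2, 2)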
45 changes: 37 additions & 8 deletions rsciio/hspy/_api.py
@@ -73,25 +73,54 @@ def __init__(self, file, signal, expg, **kwds):
 
     @staticmethod
     def _store_data(data, dset, group, key, chunks, show_progressbar=True):
-        if isinstance(data, da.Array):
-            if data.chunks != dset.chunks:
-                data = data.rechunk(dset.chunks)
+        # Tuple of dask arrays can also be passed, in which case the task graphs
+        # are merged and the data is written in a single `da.store` call.
+        # This is useful when saving a ragged array, where we need to write
+        # the data and the shape at the same time as the ragged array must have
+        # only one dimension.
+        if isinstance(data, tuple):
+            data = list(data)
+        elif not isinstance(data, list):
+            data = [
+                data,
+            ]
+            dset = [
+                dset,
+            ]
+        for i, (data_, dset_) in enumerate(zip(data, dset)):
+            if isinstance(data_, da.Array):
+                if data_.chunks != dset_.chunks:
+                    data[i] = data_.rechunk(dset_.chunks)
+                if data_.ndim == 1 and data_.dtype == object:
+                    raise ValueError(
+                        "Saving a 1-D ragged dask array to hspy is not supported yet. "
+                        "Please use the .zspy extension."
+                    )
+                # for performance reason, we write the data later, with all data
+                # at the same time in a single `da.store` call
+            elif data_.flags.c_contiguous:
+                dset_.write_direct(data_)
+            else:
+                dset_[:] = data_
+        if isinstance(data[0], da.Array):
             cm = ProgressBar if show_progressbar else dummy_context_manager
             with cm():
+                # da.store of tuple helps to merge task graphs and avoid computing twice
                 da.store(data, dset)
-        elif data.flags.c_contiguous:
-            dset.write_direct(data)
-        else:
-            dset[:] = data
 
     @staticmethod
     def _get_object_dset(group, data, key, chunks, **kwds):
         """Creates a h5py dataset object for saving ragged data"""
         # For saving ragged array
         if chunks is None:
             chunks = 1
+        test_ind = data.ndim * (0,)
+        if isinstance(data, da.Array):
+            dtype = data[test_ind].compute().dtype
+        else:
+            dtype = data[test_ind].dtype
         dset = group.require_dataset(
-            key, chunks, dtype=h5py.special_dtype(vlen=data.flatten()[0].dtype), **kwds
+            key, data.shape, dtype=h5py.special_dtype(vlen=dtype), chunks=chunks, **kwds
         )
         return dset
 
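For the h5py side, a minimal sketch of what _get_object_dset builds: a variable-length dtype holds each flattened element, one cell per navigation position (file name and values are illustrative):

import h5py
import numpy as np

with h5py.File("example.hspy", "w") as f:  # illustrative file name
    vlen = h5py.special_dtype(vlen=np.dtype("float64"))
    dset = f.require_dataset("data", shape=(2,), dtype=vlen, chunks=(1,))
    dset[0] = np.arange(6.0)  # a flattened (2, 3) element
    dset[1] = np.arange(4.0)  # a flattened (4, 1) element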
24 changes: 17 additions & 7 deletions rsciio/tests/test_hspy.py
@@ -857,7 +857,8 @@ def test_save_ragged_array(tmp_path, file):
 
 @zspy_marker
 @pytest.mark.parametrize("nav_dim", [1, 2, 3])
-def test_save_ragged_dim(tmp_path, file, nav_dim):
+@pytest.mark.parametrize("lazy", [True, False])
+def test_save_ragged_dim(tmp_path, file, nav_dim, lazy):
     file = f"nav{nav_dim}_" + file
     rng = np.random.default_rng(0)
     nav_shape = np.arange(10, 10 * (nav_dim + 1), step=10)
@@ -867,19 +868,28 @@ def test_save_ragged_dim(tmp_path, file, nav_dim):
         data[ind] = rng.random((num, 2)) * 100
 
     s = hs.signals.BaseSignal(data, ragged=True)
+    if lazy:
+        s = s.as_lazy()
     assert s.axes_manager.navigation_dimension == nav_dim
     np.testing.assert_allclose(s.axes_manager.navigation_shape, nav_shape[::-1])
     assert s.data.ndim == nav_dim
     np.testing.assert_allclose(s.data.shape, nav_shape)
 
     filename = tmp_path / file
-    s.save(filename)
-    s2 = hs.load(filename)
-    assert s.axes_manager.navigation_shape == s2.axes_manager.navigation_shape
-    assert s.data.shape == s2.data.shape
+    if ".hspy" in file and nav_dim == 1 and lazy:
+        with pytest.raises(ValueError):
+            s.save(filename)
+
+    else:
+        s.save(filename)
+        s2 = hs.load(filename, lazy=lazy)
+        assert s.axes_manager.navigation_shape == s2.axes_manager.navigation_shape
+        assert s.data.shape == s2.data.shape
+        if lazy:
+            assert isinstance(s2.data, da.Array)
 
-    for indices in np.ndindex(s.data.shape):
-        np.testing.assert_allclose(s.data[indices], s2.data[indices])
+        for indices in np.ndindex(s.data.shape):
+            np.testing.assert_allclose(s.data[indices], s2.data[indices])
 
 
 def test_load_missing_extension(caplog):
41 changes: 33 additions & 8 deletions rsciio/zspy/_api.py
@@ -102,26 +102,51 @@ def _get_object_dset(group, data, key, chunks, **kwds):
             chunks = data.shape
         these_kwds = kwds.copy()
         these_kwds.update(dict(dtype=object, exact=True, chunks=chunks))
+        test_ind = data.ndim * (0,)
+        # Need to know the underlying dtype for the codec
+        # Note this can't be an object array
+        if isinstance(data, da.Array):
+            dtype = data[test_ind].compute().dtype
+        else:
+            dtype = data[test_ind].dtype
         dset = group.require_dataset(
             key,
             data.shape,
-            object_codec=numcodecs.VLenArray(data.flatten()[0].dtype),
+            object_codec=numcodecs.VLenArray(dtype),
             **these_kwds,
         )
         return dset
 
     @staticmethod
     def _store_data(data, dset, group, key, chunks, show_progressbar=True):
         """Write data to zarr format."""
-        if isinstance(data, da.Array):
-            if data.chunks != dset.chunks:
-                data = data.rechunk(dset.chunks)
+        # Tuple of dask arrays can also be passed, in which case the task graphs
+        # are merged and the data is written in a single `da.store` call.
+        # This is useful when saving a ragged array, where we need to write
+        # the data and the shape at the same time as the ragged array must have
+        # only one dimension.
+        if isinstance(data, tuple):
+            data = list(data)
+        elif not isinstance(data, list):
+            data = [
+                data,
+            ]
+            dset = [
+                dset,
+            ]
+        for i, (data_, dset_) in enumerate(zip(data, dset)):
+            if isinstance(data_, da.Array):
+                if data_.chunks != dset_.chunks:
+                    data[i] = data_.rechunk(dset_.chunks)
+                # for performance reason, we write the data later, with all data
+                # at the same time in a single `da.store` call
+            else:
+                dset_[:] = data_
+        if isinstance(data[0], da.Array):
             cm = ProgressBar if show_progressbar else dummy_context_manager
             with cm():
-                # lock=False is necessary with the distributed scheduler
-                data.store(dset, lock=False)
-        else:
-            dset[:] = data
+                # da.store of tuple helps to merge task graphs and avoid computing twice
+                da.store(data, dset, lock=False)
 
 
 def file_writer(
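Putting the zarr pieces together: the flattened data and its shapes are two dask arrays with matching chunks, so a single da.store writes both targets without computing the task graph twice. A sketch under the same assumptions (illustrative store path; flatten_data imported as above):

import dask.array as da
import numcodecs
import numpy as np
import zarr

from rsciio._hierarchical import flatten_data  # assumed import path

data = np.empty((2,), dtype=object)
data[0] = np.arange(6.0).reshape(2, 3)
data[1] = np.arange(4.0).reshape(4, 1)
d = da.from_array(data, chunks=1)
new_data, shapes = da.apply_gufunc(
    flatten_data, "()->(),()", d, output_dtypes=[object, object]
)

g = zarr.open_group("example.zspy", mode="w")  # illustrative path
dset = g.require_dataset(
    "data",
    shape=d.shape,
    dtype=object,
    object_codec=numcodecs.VLenArray("float64"),
    chunks=(1,),
)
shape_dset = g.require_dataset(
    "_ragged_shapes_data",
    shape=d.shape,
    dtype=object,
    object_codec=numcodecs.VLenArray("int64"),
    chunks=(1,),
)
# one call merges both graphs; lock=False is needed with the distributed scheduler
da.store([new_data, shapes], [dset, shape_dset], lock=False)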