Skip to content

Commit

Permalink
(fix): cache arrays in BaseCompressedSparseDataset (#1744)
Browse files Browse the repository at this point in the history
* (fix): lazy chunking respects -1

* (fix): cache arrays in `BaseCompressedSparseDataset`

* (fix): clean up typing

* (fix): doctest double >>>

* (chore): add tests

* (fix): more typing updates

* (chore): add tests

* (fix): remove extra >>>

* (fix): spelling

* (chore): release note

* (chore): release note

* (fix): support `None` and `-1`

* (chore): typing

* (chore): add cache bust test

* (chore): type

* (chore): types

* (chore): better name

* (Fix): overload type

* (chore): bring back test comment

* Update 1744.bugfix.md

* (fix): revert erroneous change

* (fix): dont generate coo matrices

---------

Co-authored-by: Philipp A. <[email protected]>
  • Loading branch information
ilan-gold and flying-sheep authored Nov 25, 2024
1 parent 997fbd7 commit 41369da
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 45 deletions.
1 change: 1 addition & 0 deletions docs/release-notes/1744.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cache accesses to the `data` and `indices` arrays in {class}`~anndata.abc.CSRDataset` and {class}`~anndata.abc.CSCDataset` {user}`ilan-gold`
33 changes: 25 additions & 8 deletions src/anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from scipy.sparse._compressed import _cs_matrix

from .._types import GroupStorageType
from ..compat import H5Array
from .index import Index
else:
from scipy.sparse import spmatrix as _cs_matrix
Expand Down Expand Up @@ -380,7 +381,7 @@ def backend(self) -> Literal["zarr", "hdf5"]:
@property
def dtype(self) -> np.dtype:
"""The :class:`numpy.dtype` of the `data` attribute of the sparse matrix."""
return self.group["data"].dtype
return self._data.dtype

@classmethod
def _check_group_format(cls, group):
Expand Down Expand Up @@ -545,16 +546,18 @@ def append(self, sparse_matrix: ss.csr_matrix | ss.csc_matrix | SpArray) -> None
indptr[orig_data_size:] = (
sparse_matrix.indptr[1:].astype(np.int64) + indptr_offset
)
# Clear cached property
if hasattr(self, "indptr"):
del self._indptr

# indices
indices = self.group["indices"]
orig_data_size = indices.shape[0]
indices.resize((orig_data_size + sparse_matrix.indices.shape[0],))
indices[orig_data_size:] = sparse_matrix.indices

# Clear cached property
for attr in ["_indptr", "_indices", "_data"]:
if hasattr(self, attr):
delattr(self, attr)

@cached_property
def _indptr(self) -> np.ndarray:
"""\
Expand All @@ -565,11 +568,25 @@ def _indptr(self) -> np.ndarray:
arr = self.group["indptr"][...]
return arr

@cached_property
def _indices(self) -> H5Array | ZarrArray:
"""\
Cache access to the indices to prevent unnecessary reads of the zarray
"""
return self.group["indices"]

@cached_property
def _data(self) -> H5Array | ZarrArray:
"""\
Cache access to the data to prevent unnecessary reads of the zarray
"""
return self.group["data"]

def _to_backed(self) -> BackedSparseMatrix:
format_class = get_backed_class(self.format)
mtx = format_class(self.shape, dtype=self.dtype)
mtx.data = self.group["data"]
mtx.indices = self.group["indices"]
mtx.data = self._data
mtx.indices = self._indices
mtx.indptr = self._indptr
return mtx

Expand All @@ -578,8 +595,8 @@ def to_memory(self) -> ss.csr_matrix | ss.csc_matrix | SpArray:
self.format, use_sparray_in_io=settings.use_sparse_array_on_read
)
mtx = format_class(self.shape, dtype=self.dtype)
mtx.data = self.group["data"][...]
mtx.indices = self.group["indices"][...]
mtx.data = self._data[...]
mtx.indices = self._indices[...]
mtx.indptr = self._indptr
return mtx

Expand Down
54 changes: 32 additions & 22 deletions src/anndata/_io/specs/lazy_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,24 @@
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, overload

import h5py
import numpy as np
from scipy import sparse

import anndata as ad
from anndata.abc import CSCDataset, CSRDataset

from ..._core.file_backing import filename, get_elem_name
from ...compat import H5Array, H5Group, ZarrArray, ZarrGroup
from .registry import _LAZY_REGISTRY, IOSpec

if TYPE_CHECKING:
from collections.abc import Callable, Generator, Mapping, Sequence
from collections.abc import Generator, Mapping, Sequence
from typing import Literal, ParamSpec, TypeVar

from ..._core.sparse_dataset import _CSCDataset, _CSRDataset
from ..._types import ArrayStorageType, StorageType
from ...compat import DaskArray
from ...compat import DaskArray, H5File, SpArray
from .registry import DaskReader

BlockInfo = Mapping[
Expand All @@ -31,16 +30,25 @@

P = ParamSpec("P")
R = TypeVar("R")
D = TypeVar("D")


@overload
@contextmanager
def maybe_open_h5(
path_or_group: Path | ZarrGroup, elem_name: str
) -> Generator[StorageType, None, None]:
if not isinstance(path_or_group, Path):
yield path_or_group
path_or_other: Path, elem_name: str
) -> Generator[H5File, None, None]: ...
@overload
@contextmanager
def maybe_open_h5(path_or_other: D, elem_name: str) -> Generator[D, None, None]: ...
@contextmanager
def maybe_open_h5(
path_or_other: H5File | D, elem_name: str
) -> Generator[H5File | D, None, None]:
if not isinstance(path_or_other, Path):
yield path_or_other
return
file = h5py.File(path_or_group, "r")
file = h5py.File(path_or_other, "r")
try:
yield file[elem_name]
finally:
Expand All @@ -61,20 +69,17 @@ def compute_chunk_layout_for_axis_shape(


def make_dask_chunk(
path_or_group: Path | ZarrGroup,
path_or_sparse_dataset: Path | D,
elem_name: str,
block_info: BlockInfo | None = None,
*,
wrap: Callable[[ArrayStorageType], ArrayStorageType]
| Callable[[H5Group | ZarrGroup], _CSRDataset | _CSCDataset] = lambda g: g,
):
) -> sparse.csr_matrix | sparse.csc_matrix | SpArray:
if block_info is None:
msg = "Block info is required"
raise ValueError(msg)
# We need to open the file in each task since `dask` cannot share h5py objects when using `dask.distributed`
# https://github.com/scverse/anndata/issues/1105
with maybe_open_h5(path_or_group, elem_name) as f:
mtx = wrap(f)
with maybe_open_h5(path_or_sparse_dataset, elem_name) as f:
mtx = ad.io.sparse_dataset(f) if isinstance(f, H5Group) else f
idx = tuple(
slice(start, stop) for start, stop in block_info[None]["array-location"]
)
Expand All @@ -94,10 +99,17 @@ def read_sparse_as_dask(
) -> DaskArray:
import dask.array as da

path_or_group = Path(filename(elem)) if isinstance(elem, H5Group) else elem
path_or_sparse_dataset = (
Path(filename(elem))
if isinstance(elem, H5Group)
else ad.io.sparse_dataset(elem)
)
elem_name = get_elem_name(elem)
shape: tuple[int, int] = tuple(elem.attrs["shape"])
dtype = elem["data"].dtype
if isinstance(path_or_sparse_dataset, CSRDataset | CSCDataset):
dtype = path_or_sparse_dataset.dtype
else:
dtype = elem["data"].dtype
is_csc: bool = elem.attrs["encoding-type"] == "csc_matrix"

stride: int = _DEFAULT_STRIDE
Expand All @@ -123,9 +135,7 @@ def read_sparse_as_dask(
(chunks_minor, chunks_major) if is_csc else (chunks_major, chunks_minor)
)
memory_format = sparse.csc_matrix if is_csc else sparse.csr_matrix
make_chunk = partial(
make_dask_chunk, path_or_group, elem_name, wrap=ad.io.sparse_dataset
)
make_chunk = partial(make_dask_chunk, path_or_sparse_dataset, elem_name)
da_mtx = da.map_blocks(
make_chunk,
dtype=dtype,
Expand Down
72 changes: 58 additions & 14 deletions tests/test_backed_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import anndata as ad
from anndata._core.anndata import AnnData
from anndata._core.sparse_dataset import sparse_dataset
from anndata.compat import CAN_USE_SPARSE_ARRAY, SpArray
from anndata._io.specs.registry import read_elem_as_dask
from anndata.compat import CAN_USE_SPARSE_ARRAY, DaskArray, SpArray
from anndata.experimental import read_dispatched
from anndata.tests.helpers import AccessTrackingStore, assert_equal, subset_func

Expand All @@ -26,6 +27,9 @@
from numpy.typing import ArrayLike, NDArray
from pytest_mock import MockerFixture

from anndata.abc import CSCDataset, CSRDataset
from anndata.compat import ZarrGroup

Idx = slice | int | NDArray[np.integer] | NDArray[np.bool_]


Expand Down Expand Up @@ -281,6 +285,25 @@ def test_dataset_append_memory(
assert_equal(fromdisk, frommem)


def test_append_array_cache_bust(tmp_path: Path, diskfmt: Literal["h5ad", "zarr"]):
path = tmp_path / f"test.{diskfmt.replace('ad', '')}"
a = sparse.random(100, 100, format="csr")
if diskfmt == "zarr":
f = zarr.open_group(path, "a")
else:
f = h5py.File(path, "a")
ad.io.write_elem(f, "mtx", a)
ad.io.write_elem(f, "mtx_2", a)
diskmtx = sparse_dataset(f["mtx"])
old_array_shapes = {}
array_names = ["indptr", "indices", "data"]
for name in array_names:
old_array_shapes[name] = getattr(diskmtx, f"_{name}").shape
diskmtx.append(sparse_dataset(f["mtx_2"]))
for name in array_names:
assert old_array_shapes[name] != getattr(diskmtx, f"_{name}").shape


@pytest.mark.parametrize("sparse_format", [sparse.csr_matrix, sparse.csc_matrix])
@pytest.mark.parametrize(
("subset_func", "subset_func2"),
Expand Down Expand Up @@ -354,16 +377,18 @@ def test_dataset_append_disk(


@pytest.mark.parametrize("sparse_format", [sparse.csr_matrix, sparse.csc_matrix])
def test_indptr_cache(
def test_lazy_array_cache(
tmp_path: Path,
sparse_format: Callable[[ArrayLike], sparse.spmatrix],
):
elems = {"indptr", "indices", "data"}
path = tmp_path / "test.zarr"
a = sparse_format(sparse.random(10, 10))
f = zarr.open_group(path, "a")
ad.io.write_elem(f, "X", a)
store = AccessTrackingStore(path)
store.initialize_key_trackers(["X/indptr"])
for elem in elems:
store.initialize_key_trackers([f"X/{elem}"])
f = zarr.open_group(store, "a")
a_disk = sparse_dataset(f["X"])
a_disk[:1]
Expand All @@ -372,6 +397,14 @@ def test_indptr_cache(
a_disk[8:9]
# one each for .zarray and actual access
assert store.get_access_count("X/indptr") == 2
for elem_not_indptr in elems - {"indptr"}:
assert (
sum(
".zarray" in key_accessed
for key_accessed in store.get_accessed_keys(f"X/{elem_not_indptr}")
)
== 1
)


Kind = Literal["slice", "int", "array", "mask"]
Expand Down Expand Up @@ -421,27 +454,38 @@ def width_idx_kinds(
(
[0],
slice(None, None),
["X/data/.zarray", "X/data/.zarray", "X/data/0"],
["X/data/.zarray", "X/data/0"],
),
(
[0],
slice(None, 3),
["X/data/.zarray", "X/data/.zarray", "X/data/0"],
["X/data/.zarray", "X/data/0"],
),
(
[3, 4, 5],
slice(None, None),
["X/data/.zarray", "X/data/.zarray", "X/data/3", "X/data/4", "X/data/5"],
["X/data/.zarray", "X/data/3", "X/data/4", "X/data/5"],
),
l=10,
),
)
@pytest.mark.parametrize(
"open_func",
[
sparse_dataset,
lambda x: read_elem_as_dask(
x, chunks=(1, -1) if x.attrs["encoding-type"] == "csr_matrix" else (-1, 1)
),
],
ids=["sparse_dataset", "read_elem_as_dask"],
)
def test_data_access(
tmp_path: Path,
sparse_format: Callable[[ArrayLike], sparse.spmatrix],
idx_maj: Idx,
idx_min: Idx,
exp: Sequence[str],
open_func: Callable[[ZarrGroup], CSRDataset | CSCDataset | DaskArray],
):
path = tmp_path / "test.zarr"
a = sparse_format(np.eye(10, 10))
Expand All @@ -454,19 +498,19 @@ def test_data_access(
store = AccessTrackingStore(path)
store.initialize_key_trackers(["X/data"])
f = zarr.open_group(store)
a_disk = sparse_dataset(f["X"])

# Do the slicing with idx
store.reset_key_trackers()
if a_disk.format == "csr":
a_disk[idx_maj, idx_min]
a_disk = AnnData(X=open_func(f["X"]))
if a.format == "csr":
subset = a_disk[idx_maj, idx_min]
else:
a_disk[idx_min, idx_maj]
subset = a_disk[idx_min, idx_maj]
if isinstance(subset.X, DaskArray):
subset.X.compute(scheduler="single-threaded")

assert store.get_access_count("X/data") == len(exp), store.get_accessed_keys(
"X/data"
)
assert store.get_accessed_keys("X/data") == exp
# dask access order is not guaranteed so need to sort
assert sorted(store.get_accessed_keys("X/data")) == sorted(exp)


@pytest.mark.parametrize(
Expand Down
4 changes: 3 additions & 1 deletion tests/test_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,9 @@ def gen_list(n):


def gen_sparse(n):
return sparse.random(np.random.randint(1, 100), np.random.randint(1, 100))
return sparse.random(
np.random.randint(1, 100), np.random.randint(1, 100), format="csr"
)


def gen_something(n):
Expand Down

0 comments on commit 41369da

Please sign in to comment.