Add support for sparse arrays inside dask arrays #1114

Merged · 19 commits · Sep 8, 2023
19 changes: 8 additions & 11 deletions anndata/_core/anndata.py
@@ -1329,9 +1329,10 @@ def transpose(self) -> "AnnData":
         Transpose whole object.

         Data matrix is transposed, observations and variables are interchanged.

         Ignores `.raw`.
         """
+        from anndata.compat import _safe_transpose
+
         if not self.isbacked:
             X = self.X
         else:
@@ -1342,21 +1343,17 @@ def transpose(self) -> "AnnData":
                 "which is currently not implemented. Call `.copy()` before transposing."
             )

-        def t_csr(m: sparse.spmatrix) -> sparse.csr_matrix:
-            return m.T.tocsr() if sparse.isspmatrix_csr(m) else m.T
-
         return AnnData(
-            X=t_csr(X) if X is not None else None,
+            X=_safe_transpose(X) if X is not None else None,
+            layers={k: _safe_transpose(v) for k, v in self.layers.items()},
             obs=self.var,
             var=self.obs,
             # we're taking a private attributes here to be able to modify uns of the original object
             uns=self._uns,
-            obsm=self.varm.flipped(),
-            varm=self.obsm.flipped(),
-            obsp=self.varp.copy(),
-            varp=self.obsp.copy(),
+            obsm=self._varm,
+            varm=self._obsm,
+            obsp=self._varp,
+            varp=self._obsp,
             filename=self.filename,
-            layers={k: t_csr(v) for k, v in self.layers.items()},
         )

     T = property(transpose)
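
For illustration, a minimal usage sketch of the resulting behavior (not part of the diff; the shapes are made up):

import anndata as ad
from scipy import sparse

adata = ad.AnnData(X=sparse.random(5, 3, format="csr", density=0.5))
adata_t = adata.T  # goes through _safe_transpose now

assert adata_t.shape == (3, 5)
# obs and var are interchanged
assert list(adata_t.obs_names) == list(adata.var_names)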
12 changes: 9 additions & 3 deletions anndata/_core/index.py
@@ -6,7 +6,7 @@
 import h5py
 import numpy as np
 import pandas as pd
-from scipy.sparse import spmatrix, issparse
+from scipy.sparse import spmatrix, issparse, csc_matrix
 from ..compat import AwkArray, DaskArray, Index, Index1D


@@ -127,8 +127,14 @@ def _subset(a: Union[np.ndarray, pd.DataFrame], subset_idx: Index):
 @_subset.register(DaskArray)
 def _subset_dask(a: DaskArray, subset_idx: Index):
     if all(isinstance(x, cabc.Iterable) for x in subset_idx):
-        subset_idx = np.ix_(*subset_idx)
-        return a.vindex[subset_idx]
+        if isinstance(a._meta, csc_matrix):
+            return a[:, subset_idx[1]][subset_idx[0], :]
+        elif isinstance(a._meta, spmatrix):
+            return a[subset_idx[0], :][:, subset_idx[1]]
+        else:
+            # TODO: this may have been working for some cases?
+            subset_idx = np.ix_(*subset_idx)
+            return a.vindex[subset_idx]
Comment on lines +135 to +137
Member:
Works for some sparse cases? I don't understand the problem here.

Member Author:

Ideally the vindex case would be more efficient, since it indexes across both dimensions at the same time. Though I haven't actually checked this.

But I believe there were some test cases that passed without this change and are now going through the new branches. It could be worth making those cases go through vindex instead of `[obs_idx, :][:, var_idx]`.

     return a[subset_idx]
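
For context, a sketch of what the new branches do (not from the PR; shapes and chunk sizes are illustrative). For CSR-backed dask arrays the row axis is indexed first, then columns; for CSC the order is flipped, so the major axis is always sliced first:

import numpy as np
import dask.array as da
from scipy import sparse

X = sparse.random(100, 50, format="csr", density=0.1)
Xd = da.from_array(X, chunks=(50, 50))  # _meta is a csr_matrix

obs_idx, var_idx = np.array([0, 3, 7]), np.array([1, 2, 5])
sub = Xd[obs_idx, :][:, var_idx]  # what _subset_dask now does for csr

expected = X[obs_idx, :][:, var_idx].toarray()
assert np.array_equal(sub.compute().toarray(), expected)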


29 changes: 26 additions & 3 deletions anndata/_core/merge.py
@@ -122,7 +122,11 @@ def equal_dask_array(a, b) -> bool:
     if isinstance(b, DaskArray):
         if tokenize(a) == tokenize(b):
             return True
-    return da.equal(a, b, where=~(da.isnan(a) == da.isnan(b))).all()
+    if isinstance(a._meta, spmatrix):
+        # TODO: Maybe also do this in the other case?
+        return da.map_blocks(equal, a, b, drop_axis=(0, 1)).all()
Comment on lines +125 to +127
Member:
Any specific reason not to use np.equal? Wouldn't this check the type for each block, even though the block type is the same for all of them?

Member Author:

np.equal doesn't work on sparse arrays:

In [3]: X = sparse.random(
   ...:     1000,
   ...:     100,
   ...:     format="csr",
   ...:     density=0.01,
   ...:     random_state=np.random.default_rng(),
   ...: )

In [4]: np.equal(X, X)
/usr/local/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3508: SparseEfficiencyWarning: Comparing sparse matrices using == is inefficient, try using != instead.
  exec(code_obj, self.user_global_ns, self.user_ns)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[4], line 1
----> 1 np.equal(X, X)

File /usr/local/lib/python3.9/site-packages/scipy/sparse/_base.py:332, in _spbase.__bool__(self)
    330     return self.nnz != 0
    331 else:
--> 332     raise ValueError("The truth value of an array with more than one "
    333                      "element is ambiguous. Use a.any() or a.all().")

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all().

Plus, we want the NaN handling behavior here.

Member:

True, sorry, I must have made a mistake, because it somehow passed when I wrote it but fails now. In that case it would make sense for that to be just the else branch. However, it also fails for me in this case:

def test_sparse_vs_dense_dask():
    data = np.random.random((100, 100))
    x = as_sparse_dask_array(data)
    y = as_dense_dask_array(data)
    assert equal(y, x)
    assert equal(x, y)

Member Author:

It looks like that is also true for in-memory dense and sparse arrays:

from anndata._core.merge import equal
from scipy import sparse

s = sparse.random(100, 100, format="csr", density=0.1)

equal(s, s.toarray())  # False
equal(s.toarray(), s)  # True

But I think I would consider this a separate issue. In general, these functions currently assume that the array types being compared are the same (which fits the use case).

A real solution here would need an array type promotion hierarchy:

  • If you are comparing dense to sparse, which method gets used?
  • If you use equal, and some elements are dense, some are sparse, what gets returned?

+    else:
+        return da.equal(a, b, where=~(da.isnan(a) == da.isnan(b))).all()
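
To illustrate the sparse branch (a sketch, not from the PR): equal is applied to each pair of corresponding blocks, every call collapses a 2-d block to a scalar boolean (hence drop_axis=(0, 1)), and .all() reduces over the per-block results:

import dask.array as da
from scipy import sparse
from anndata._core.merge import equal

X = sparse.random(100, 100, format="csr", density=0.05)
Xd = da.from_array(X, chunks=(50, 100))

result = da.map_blocks(equal, Xd, Xd, drop_axis=(0, 1)).all()
assert bool(result.compute())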


 @equal.register(np.ndarray)
@@ -299,7 +303,7 @@ def check_combinable_cols(cols: list[pd.Index], join: Literal["inner", "outer"])


 # TODO: open PR or feature request to cupy
-def _cpblock_diag(mats, format=None, dtype=None):
+def _cp_block_diag(mats, format=None, dtype=None):
     """
     Modified version of scipy.sparse.block_diag for cupy sparse.
     """
@@ -335,6 +339,23 @@ def _cpblock_diag(mats, format=None, dtype=None):
     ).asformat(format)


+def _dask_block_diag(mats):
+    from itertools import permutations
+    import dask.array as da
+
+    blocks = np.zeros((len(mats), len(mats)), dtype=object)
+    for i, j in permutations(range(len(mats)), 2):
+        blocks[i, j] = da.from_array(
+            sparse.csr_matrix((mats[i].shape[0], mats[j].shape[1]))
+        )
+    for i, x in enumerate(mats):
+        if not isinstance(x._meta, sparse.csr_matrix):
+            x = x.map_blocks(sparse.csr_matrix)
+        blocks[i, i] = x
+
+    return da.block(blocks.tolist())
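
A hedged usage sketch of _dask_block_diag, mirroring the concat_pairwise_mapping call further below (array sizes are made up; single-chunk inputs keep the example simple):

import numpy as np
import dask.array as da
from scipy import sparse
from anndata._core.merge import _dask_block_diag

a = da.from_array(sparse.random(4, 4, format="csr", density=0.5), chunks=(4, 4))
b = da.from_array(sparse.random(3, 3, format="csr", density=0.5), chunks=(3, 3))

result = _dask_block_diag([a, b]).compute()
expected = sparse.block_diag([a.compute(), b.compute()], format="csr")
assert np.array_equal(result.toarray(), expected.toarray())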


 ###################
 # Per element logic
 ###################
@@ -911,7 +932,9 @@ def concat_pairwise_mapping(
         for m, s in zip(mappings, shapes)
     ]
     if all(isinstance(el, (CupySparseMatrix, CupyArray)) for el in els):
-        result[k] = _cpblock_diag(els, format="csr")
+        result[k] = _cp_block_diag(els, format="csr")
+    elif all(isinstance(el, DaskArray) for el in els):
+        result[k] = _dask_block_diag(els)
     else:
         result[k] = sparse.block_diag(els, format="csr")
     return result
48 changes: 48 additions & 0 deletions anndata/_io/specs/methods.py
@@ -543,6 +543,54 @@ def write_sparse_dataset(f, k, elem, _writer, dataset_kwargs=MappingProxyType({}
     f[k].attrs["encoding-version"] = "0.1.0"


+@_REGISTRY.register_write(
+    H5Group, (DaskArray, sparse.csr_matrix), IOSpec("csr_matrix", "0.1.0")
+)
+@_REGISTRY.register_write(
+    H5Group, (DaskArray, sparse.csc_matrix), IOSpec("csc_matrix", "0.1.0")
+)
+@_REGISTRY.register_write(
+    ZarrGroup, (DaskArray, sparse.csr_matrix), IOSpec("csr_matrix", "0.1.0")
+)
+@_REGISTRY.register_write(
+    ZarrGroup, (DaskArray, sparse.csc_matrix), IOSpec("csc_matrix", "0.1.0")
+)
+def write_dask_sparse(f, k, elem, _writer, dataset_kwargs=MappingProxyType({})):
+    sparse_format = elem._meta.format
+    if sparse_format == "csr":
+        axis = 0
+    elif sparse_format == "csc":
+        axis = 1
+    else:
+        raise NotImplementedError(
+            f"Cannot write dask sparse arrays with format {sparse_format}"
+        )
+
+    def chunk_slice(start: int, stop: int) -> tuple[slice | None, slice | None]:
+        result = [slice(None), slice(None)]
+        result[axis] = slice(start, stop)
+        return tuple(result)
+
+    axis_chunks = elem.chunks[axis]
+    chunk_start = 0
+    chunk_stop = axis_chunks[0]
+
+    _writer.write_elem(
+        f,
+        k,
+        elem[chunk_slice(chunk_start, chunk_stop)].compute(),
+        dataset_kwargs=dataset_kwargs,
+    )
+
+    disk_mtx = sparse_dataset(f[k])
+
+    for chunk_size in axis_chunks[1:]:
+        chunk_start = chunk_stop
+        chunk_stop += chunk_size
+
+        disk_mtx.append(elem[chunk_slice(chunk_start, chunk_stop)].compute())

 @_REGISTRY.register_read(H5Group, IOSpec("csc_matrix", "0.1.0"))
 @_REGISTRY.register_read(H5Group, IOSpec("csr_matrix", "0.1.0"))
 @_REGISTRY.register_read(ZarrGroup, IOSpec("csc_matrix", "0.1.0"))
44 changes: 30 additions & 14 deletions anndata/_io/specs/registry.py
@@ -211,6 +211,20 @@ def get_spec(
     )
+
+
+def _iter_patterns(elem):
+    """Iterates over possible patterns for an element in order of precedence."""
+    from anndata.compat import DaskArray
+
+    t = type(elem)
+
+    if isinstance(elem, DaskArray):
+        yield (t, type(elem._meta), elem.dtype.kind)
+        yield (t, type(elem._meta))
+    if hasattr(elem, "dtype"):
+        yield (t, elem.dtype.kind)
+    yield t
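
A sketch of the precedence this yields for a dask array wrapping CSR chunks (illustrative; the printed class names are abbreviated in the comments):

import dask.array as da
from scipy import sparse
from anndata._io.specs.registry import _iter_patterns

Xd = da.from_array(sparse.random(10, 10, format="csr"), chunks=(5, 10))
for pattern in _iter_patterns(Xd):
    print(pattern)
# (DaskArray, csr_matrix, 'f')  most specific: wrapper, meta type, dtype kind
# (DaskArray, csr_matrix)
# (DaskArray, 'f')
# DaskArray                     least specific fallback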


 class Reader:
     def __init__(
         self, registry: IORegistry, callback: Union[Callable, None] = None
@@ -257,6 +271,13 @@ def __init__(
         self.registry = registry
         self.callback = callback

+    def find_writer(self, dest_type, elem, modifiers):
+        for pattern in _iter_patterns(elem):
+            if self.registry.has_writer(dest_type, pattern, modifiers):
+                return self.registry.get_writer(dest_type, pattern, modifiers)
+        # Raises IORegistryError
+        return self.registry.get_writer(dest_type, type(elem), modifiers)

     @report_write_key_on_error
     def write_elem(
         self,
@@ -269,9 +290,12 @@ def write_elem(
     ):
         from functools import partial
         from pathlib import PurePosixPath
+        import h5py
+
+        if isinstance(store, h5py.File):
+            store = store["/"]

         dest_type = type(store)
-        t = type(elem)

         if elem is None:
             return lambda *_, **__: None
@@ -284,19 +308,11 @@
             store.clear()
         elif k in store:
             del store[k]
-        if (
-            hasattr(elem, "dtype")
-            and (dest_type, (t, elem.dtype.kind), modifiers) in self.registry.write
-        ):
-            write_func = partial(
-                self.registry.get_writer(dest_type, (t, elem.dtype.kind), modifiers),
-                _writer=self,
-            )
-        else:
-            write_func = partial(
-                self.registry.get_writer(dest_type, t, modifiers),
-                _writer=self,
-            )
+
+        write_func = partial(
+            self.find_writer(dest_type, elem, modifiers),
+            _writer=self,
+        )

         if self.callback is not None:
             return self.callback(
29 changes: 28 additions & 1 deletion anndata/compat/__init__.py
@@ -12,7 +12,7 @@
 from warnings import warn

 import h5py
-from scipy.sparse import spmatrix
+from scipy.sparse import spmatrix, issparse
 import numpy as np
 import pandas as pd
@@ -360,3 +360,30 @@ def inner_f(*args, **kwargs):
         return _inner_deprecate_positional_args(func)

     return _inner_deprecate_positional_args
+
+
+def _transpose_by_block(dask_array: DaskArray) -> DaskArray:
+    import dask.array as da
+
+    b = dask_array.blocks
+    b_raveled = b.ravel()
+    block_layout = np.zeros(b.shape, dtype=object)
+
+    for i in range(block_layout.size):
+        block_layout.flat[i] = b_raveled[i].map_blocks(
+            lambda x: x.T, chunks=b_raveled[i].chunks[::-1]
+        )
+
+    return da.block(block_layout.T.tolist())
+
+
+def _safe_transpose(x):
+    """Safely transpose x
+
+    This is a workaround for: https://github.com/scipy/scipy/issues/19161
+    """
+
+    if isinstance(x, DaskArray) and issparse(x._meta):
+        return _transpose_by_block(x)
+    else:
+        return x.T
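
A sketch verifying the blockwise transpose (chunk sizes are illustrative): each chunk is transposed via map_blocks and the block grid itself is flipped before reassembly with da.block:

import numpy as np
import dask.array as da
from scipy import sparse
from anndata.compat import _safe_transpose

X = sparse.random(6, 4, format="csr", density=0.5)
Xd = da.from_array(X, chunks=(3, 2))  # a 2x2 grid of sparse blocks

Xt = _safe_transpose(Xd)
assert Xt.shape == (4, 6)
assert Xt.chunks == ((2, 2), (3, 3))  # chunk structure is transposed too
assert np.array_equal(Xt.compute().toarray(), X.T.toarray())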
52 changes: 44 additions & 8 deletions anndata/tests/helpers.py
@@ -442,13 +442,7 @@ def assert_equal_h5py_dataset(a, b, exact=False, elem_name=None):

 @assert_equal.register(DaskArray)
 def assert_equal_dask_array(a, b, exact=False, elem_name=None):
-    from dask.array.utils import assert_eq
-
-    if exact:
-        assert_eq(a, b, check_dtype=True, check_type=True, check_graph=False)
-    else:
-        # TODO: Why does it fail when check_graph=True
-        assert_eq(a, b, check_dtype=False, check_type=False, check_graph=False)
+    assert_equal(b, a.compute(), exact, elem_name)


 @assert_equal.register(pd.DataFrame)
@@ -605,6 +599,33 @@ def _(a):
     return as_dense_dask_array(a.toarray())
+
+
+def _half_chunk_size(a: tuple[int, ...]) -> tuple[int, ...]:
+    def half_rounded_up(x):
+        div, mod = divmod(x, 2)
+        return div + (mod > 0)
+
+    return tuple(half_rounded_up(x) for x in a)
+
+
+@singledispatch
+def as_sparse_dask_array(a) -> DaskArray:
+    import dask.array as da
+
+    return da.from_array(sparse.csr_matrix(a), chunks=_half_chunk_size(a.shape))
+
+
+@as_sparse_dask_array.register(sparse.spmatrix)
+def _(a):
+    import dask.array as da
+
+    return da.from_array(a, _half_chunk_size(a.shape))
+
+
+@as_sparse_dask_array.register(DaskArray)
+def _(a):
+    return a.map_blocks(sparse.csr_matrix)
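
A quick sketch of the helper's dispatch behavior (illustrative shapes, not from the PR):

import numpy as np
from scipy import sparse
from anndata.tests.helpers import as_sparse_dask_array

arr = as_sparse_dask_array(np.random.random((10, 10)))
assert isinstance(arr._meta, sparse.csr_matrix)  # dense input becomes csr-backed
assert arr.chunks == ((5, 5), (5, 5))  # _half_chunk_size halves each axis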


 @contextmanager
 def pytest_8_raises(exc_cls, *, match: str | re.Pattern = None):
     """Error handling using pytest 8's support for __notes__.
@@ -681,14 +702,29 @@ def as_cupy_type(val, typ=None):
     )

+
+@singledispatch
+def shares_memory(x, y) -> bool:
+    return np.shares_memory(x, y)
+
+
+@shares_memory.register(sparse.spmatrix)
+def shares_memory_sparse(x, y):
+    return (
+        np.shares_memory(x.data, y.data)
+        and np.shares_memory(x.indices, y.indices)
+        and np.shares_memory(x.indptr, y.indptr)
+    )
+
+
 BASE_MATRIX_PARAMS = [
     pytest.param(asarray, id="np_array"),
     pytest.param(sparse.csr_matrix, id="scipy_csr"),
     pytest.param(sparse.csc_matrix, id="scipy_csc"),
 ]

 DASK_MATRIX_PARAMS = [
-    pytest.param(as_dense_dask_array, id="dask_array"),
+    pytest.param(as_dense_dask_array, id="dense_dask_array"),
+    pytest.param(as_sparse_dask_array, id="sparse_dask_array"),
 ]

 CUPY_MATRIX_PARAMS = [
3 changes: 3 additions & 0 deletions anndata/tests/test_concatenate.py
@@ -818,6 +818,9 @@ def gen_dim_array(m):
     orig_arr = getattr(adatas[k], dim_attr)["arr"]
     full_arr = getattr(w_pairwise, dim_attr)["arr"]

+    if isinstance(full_arr, DaskArray):
+        full_arr = full_arr.compute()
+
     # Check original values are intact
     assert_equal(orig_arr, _subset(full_arr, (inds, inds)))
     # Check that entries are filled with zeroes