Add support for sparse arrays inside dask arrays #1114
@@ -122,7 +122,11 @@ def equal_dask_array(a, b) -> bool:
     if isinstance(b, DaskArray):
         if tokenize(a) == tokenize(b):
             return True
-        return da.equal(a, b, where=~(da.isnan(a) == da.isnan(b))).all()
+        if isinstance(a._meta, spmatrix):
+            # TODO: Maybe also do this in the other case?
+            return da.map_blocks(equal, a, b, drop_axis=(0, 1)).all()
Comment on lines +125 to +127

Any specific reason not to use np.equal? Wouldn't it check the type for each block, even though the block type is the same for all blocks?
In [3]: X = sparse.random(
   ...:     1000,
   ...:     100,
   ...:     format="csr",
   ...:     density=0.01,
   ...:     random_state=np.random.default_rng(),
   ...: )
In [4]: np.equal(X, X)
/usr/local/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3508: SparseEfficiencyWarning: Comparing sparse matrices using == is inefficient, try using != instead.
exec(code_obj, self.user_global_ns, self.user_ns)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[4], line 1
----> 1 np.equal(X, X)
File /usr/local/lib/python3.9/site-packages/scipy/sparse/_base.py:332, in _spbase.__bool__(self)
330 return self.nnz != 0
331 else:
--> 332 raise ValueError("The truth value of an array with more than one "
333 "element is ambiguous. Use a.any() or a.all().")
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all().

Plus we want the […]

True, sorry, I guess I made a mistake, because it somehow passed for me when I wrote it but now fails. Then it would make sense if that were just the else branch. However, both of these fail for me in this case:

def test_sparse_vs_dense_dask():
    data = np.random.random((100, 100))
    x = as_sparse_dask_array(data)
    y = as_dense_dask_array(data)
    assert equal(y, x)
    assert equal(x, y)

It looks like that is also true for in-memory dense and sparse arrays:

from anndata._core.merge import equal
from scipy import sparse
s = sparse.random(100, 100, format="csr", density=0.1)
equal(s, s.toarray()) # False
equal(s.toarray(), s)  # True

But I think I would consider this a separate issue. In general, these functions currently assume that the array types being compared are the same (which fits the use case). A real solution here would need an array type promotion hierarchy; a rough sketch of what that could look like is below.
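A minimal sketch of one possible promotion approach, covering only the scipy-sparse-vs-ndarray case; _promote and equal_promoting are hypothetical names for illustration, not anndata API:

import numpy as np
from scipy import sparse

def _promote(a, b):
    # If exactly one operand is sparse, densify it so both sides share a
    # representation before comparing. A fuller hierarchy would also rank
    # cupy- and dask-wrapped types.
    if sparse.issparse(a) and not sparse.issparse(b):
        a = a.toarray()
    elif sparse.issparse(b) and not sparse.issparse(a):
        b = b.toarray()
    return a, b

def equal_promoting(a, b) -> bool:
    # Symmetric by construction: promotion happens before the comparison,
    # so equal_promoting(x, y) == equal_promoting(y, x).
    a, b = _promote(a, b)
    return np.array_equal(a, b)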
+        else:
+            return da.equal(a, b, where=~(da.isnan(a) == da.isnan(b))).all()


 @equal.register(np.ndarray)
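An aside on the drop_axis=(0, 1) argument in the new branch: equal collapses each block comparison to a single boolean, so both array dimensions disappear from the output. A minimal single-block illustration (illustrative only, not the PR's code):

import dask.array as da
import numpy as np

x = da.ones((4, 4), chunks=(4, 4))  # one block keeps the illustration simple
y = da.ones((4, 4), chunks=(4, 4))

# Each block maps to a 0-d boolean, hence both axes are dropped.
blockwise_eq = da.map_blocks(
    lambda u, v: np.array(np.array_equal(u, v)), x, y, drop_axis=(0, 1)
)
assert bool(blockwise_eq.compute())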
@@ -299,7 +303,7 @@ def check_combinable_cols(cols: list[pd.Index], join: Literal["inner", "outer"])


 # TODO: open PR or feature request to cupy
-def _cpblock_diag(mats, format=None, dtype=None):
+def _cp_block_diag(mats, format=None, dtype=None):
     """
     Modified version of scipy.sparse.block_diag for cupy sparse.
     """

@@ -335,6 +339,23 @@ def _cpblock_diag(mats, format=None, dtype=None):
     ).asformat(format)
+def _dask_block_diag(mats):
+    from itertools import permutations
+    import dask.array as da
+
+    blocks = np.zeros((len(mats), len(mats)), dtype=object)
+    for i, j in permutations(range(len(mats)), 2):
+        blocks[i, j] = da.from_array(
+            sparse.csr_matrix((mats[i].shape[0], mats[j].shape[1]))
+        )
+    for i, x in enumerate(mats):
+        if not isinstance(x._meta, sparse.csr_matrix):
+            x = x.map_blocks(sparse.csr_matrix)
+        blocks[i, i] = x
+
+    return da.block(blocks.tolist())


 ###################
 # Per element logic
 ###################
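A minimal usage sketch of the new helper, assuming sparse-backed dask arrays as inputs (the shapes and densities here are arbitrary, and _dask_block_diag is assumed to be in scope from the diff above):

import dask.array as da
from scipy import sparse

a = da.from_array(sparse.random(3, 4, density=0.5, format="csr"))
b = da.from_array(sparse.random(2, 2, density=0.5, format="csr"))

out = _dask_block_diag([a, b])
assert out.shape == (5, 6)  # (3 + 2, 4 + 2); off-diagonal blocks are all zeros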
@@ -911,7 +932,9 @@ def concat_pairwise_mapping(
             for m, s in zip(mappings, shapes)
         ]
         if all(isinstance(el, (CupySparseMatrix, CupyArray)) for el in els):
-            result[k] = _cpblock_diag(els, format="csr")
+            result[k] = _cp_block_diag(els, format="csr")
+        elif all(isinstance(el, DaskArray) for el in els):
+            result[k] = _dask_block_diag(els)
         else:
             result[k] = sparse.block_diag(els, format="csr")
     return result
Works for some sparse cases? I don't understand the problem here.

Ideally, the vindex case would be more efficient, since it does the indexing across both dimensions at the same time, though I haven't actually checked this. But I believe there were some test cases that passed without this change and are now going through the new branches. It could be worth making those cases go through vindex instead of [obs_idx, :][:, var_idx].
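To illustrate the indexing point with plain NumPy (illustrative only; np.ix_ builds an open mesh so both axes are indexed in a single pass, which is the kind of one-step indexing a vindex-style path would do):

import numpy as np

x = np.arange(20).reshape(4, 5)
obs_idx = [0, 2]
var_idx = [1, 3, 4]

chained = x[obs_idx, :][:, var_idx]   # two indexing passes
single = x[np.ix_(obs_idx, var_idx)]  # both dimensions at the same time

assert np.array_equal(chained, single)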