
Add support for sparse arrays inside dask arrays #1114

Merged
merged 19 commits into scverse:main on Sep 8, 2023

Conversation

ivirshup
Member

@ivirshup ivirshup commented Aug 29, 2023

I got a little excited. Maybe this will be 0.10.0.

  • This is a good POC for dask arrays holding cupy and cupyx arrays
  • This definitely could be more performant. At the very least we could handle writing better, i.e. don't build a SparseDataset and append; instead do something like:
    • hold indptr plus a chunk-sized buffer for indices and data in memory
    • for each chunk generated, write storage chunks until we can't write an entire one
    • fill the memory buffer with the remaining data
    • generate the next matrix chunk, starting by filling the memory buffer, then writing to storage

TODO:


I've now tried making it faster using the approach above. It's faster, but only by about 20% with default zarr settings. Here's that implementation:

def write_dask_sparse(f, k, elem, _writer=None, dataset_kwargs=MappingProxyType({})):
    if k in f:  # TODO: Delete in actual implementation, handled by machinery
        del f[k]  # TODO: Delete in actual implementation, handled by machinery

    def chunk_slice(start: int, stop: int) -> tuple[slice | None, slice | None]:
        result = [slice(None), slice(None)]
        result[axis] = slice(start, stop)
        return tuple(result)

    def resize_and_write(storage: ZarrArray | H5Array, buffer: np.ndarray):
        prev_shape = storage.shape[0]
        storage.resize((prev_shape + buffer.shape[0],))
        storage[prev_shape:] = buffer

    if isinstance(f, H5Group) and "maxshape" not in dataset_kwargs:
        dataset_kwargs = dict(maxshape=(None,), **dataset_kwargs)

    chunktype = elem._meta  # 0-dim instance of spmatrix with the correct dtype
    sparse_format = chunktype.format

    # Determine axis to concatenate along from sparse format
    if sparse_format == "csr":
        axis = 0
    elif sparse_format == "csc":
        axis = 1
    else:
        raise NotImplementedError(
            f"Cannot write dask sparse arrays with format {sparse_format}"
        )
    axis_chunks: tuple[int, ...] = elem.chunks[axis]
    chunk_start, chunk_stop = 0, 0

    # Start initializing output
    group = f.create_group(k)
    group.attrs["shape"] = elem.shape
    
    group.attrs["encoding-type"] = f"{sparse_format}_matrix"  # TODO: Delete in actual implementation, handled by machinery
    group.attrs["encoding-version"] = "0.1.0"  # TODO: Delete in actual implementation, handled by machinery

    # Initialize final storage arrays + intermediate buffers
    data_storage = group.create_dataset("data", shape=(0,), dtype=chunktype.data.dtype, **dataset_kwargs)
    storage_chunksize = data_storage.chunks[0]  # This should probably actually be determined by shape
    data_buffer = np.zeros(storage_chunksize, dtype=chunktype.data.dtype)
    indices_storage = group.create_dataset("indices", shape=(0,), dtype=chunktype.indices.dtype, **dataset_kwargs)
    indices_buffer = np.empty(storage_chunksize, dtype=chunktype.indices.dtype)

    indptr = np.zeros(elem.shape[axis] + 1, dtype=np.int64)
    buffer_offset = 0

    for chunk_size in axis_chunks:
        chunk_start = chunk_stop
        chunk_stop += chunk_size

        matrix_chunk = elem[chunk_slice(chunk_start, chunk_stop)].compute()

        local_offset = 0  # offset into the current chunk's indices/data arrays
        nnz = matrix_chunk.nnz

        indptr[chunk_start + 1: chunk_stop + 1] = indptr[chunk_start] + matrix_chunk.indptr[1:]

        # Fill remaining buffer
        if buffer_offset != 0:
            buffer_remaining = storage_chunksize - buffer_offset
            local_offset = min(buffer_remaining, nnz)

            new_buffer_offset = buffer_offset + local_offset

            indices_buffer[buffer_offset : new_buffer_offset] = matrix_chunk.indices[:local_offset]
            data_buffer[buffer_offset : new_buffer_offset] = matrix_chunk.data[:local_offset]

            if new_buffer_offset == storage_chunksize:
                resize_and_write(indices_storage, indices_buffer)
                resize_and_write(data_storage, data_buffer)
                buffer_offset = 0
            else:
                buffer_offset = new_buffer_offset

        # Write any full sized chunks
        n_full_chunks, n_remaining = divmod(nnz - local_offset, storage_chunksize)
        if n_full_chunks > 0:
            local_slice = slice(local_offset, local_offset + n_full_chunks * storage_chunksize)
            resize_and_write(indices_storage, matrix_chunk.indices[local_slice])
            resize_and_write(data_storage, matrix_chunk.data[local_slice])
            local_offset += n_full_chunks * storage_chunksize

        # Write rest to buffer
        if n_remaining > 0:
            indices_buffer[:n_remaining] = matrix_chunk.indices[local_offset : local_offset + n_remaining]
            data_buffer[:n_remaining] = matrix_chunk.data[local_offset : local_offset + n_remaining]
            buffer_offset = n_remaining

    # Write whatever is left in the buffer
    if buffer_offset:
        resize_and_write(indices_storage, indices_buffer[:buffer_offset])  # .append exists on zarr but not h5py
        resize_and_write(data_storage, data_buffer[:buffer_offset])

    group.create_dataset("indptr", data=indptr, **dataset_kwargs)

Some timing results:

Setup
X = sparse.random(
    100_000,
    10_000,
    format="csr",
    density=0.01,
    random_state=np.random.default_rng(),
)
X_dask = da.from_array(X, chunks=(1000, 5000))

z = zarr.open("test.zarr", "w")

%timeit write_elem(z, "from_memory", X, dataset_kwargs={"chunks": 10_000})
%timeit write_elem(z, "original_implementation", X_dask, dataset_kwargs={"chunks": 10_000})
%timeit write_dask_sparse(z, "new_implementation", X_dask, dataset_kwargs={"chunks": 10_000})
In memory:  1.43 s ± 60.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Naive Dask: 2.36 s ± 88.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
New impl:   2.08 s ± 88.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

From profiling, it seems like a lot of the difference is from:

  • Every time the array is resized, the metadata is updated
  • Time for dask compute

I think a lot of this could be made up with more write buffering.

But also this is way more complex.

Transpose

Transpose is an issue, since scipy doesn't let you pass an axes argument to transpose and dask always passes that argument.

I've opened a PR to scipy to fix this, but we'll need to work around it until then, either with a fix here or by disallowing transposition.

I think the fix should look like:

def transpose_by_block(dask_array: da.Array) -> da.Array:
    b = dask_array.blocks
    b_raveled = b.ravel()
    block_layout = np.zeros(b.shape, dtype=object)

    for i in range(block_layout.size):
        block = b_raveled[i]
        # The new chunks are the block's own shape, reversed
        block_layout.flat[i] = block.map_blocks(lambda x: x.T, chunks=block.shape[::-1])

    return da.block(block_layout.T.tolist())

  • Transpose each block of the dask array
  • Transpose the block structure of the dask array
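As a sanity check, here's a self-contained version of that sketch (the map_blocks chunks come from each block's own shape). It's demonstrated on a dense array so the snippet runs without sparse concatenation support, but the approach is the same for sparse chunks:

```python
import numpy as np
import dask.array as da


def transpose_by_block(dask_array: da.Array) -> da.Array:
    # Transpose each block, then transpose the layout of the blocks.
    block_layout = np.zeros(dask_array.numblocks, dtype=object)
    for idx in np.ndindex(*dask_array.numblocks):
        block = dask_array.blocks[idx]
        # The single block's shape is known, so pass it reversed as the new chunks
        block_layout[idx] = block.map_blocks(lambda x: x.T, chunks=block.shape[::-1])
    return da.block(block_layout.T.tolist())


X = np.arange(12.0).reshape(3, 4)
Xd = da.from_array(X, chunks=(2, 2))
assert np.array_equal(transpose_by_block(Xd).compute(), X.T)
```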

Indexing

dask.Array.vindex is broken when the meta array is a sparse matrix. It may also be broken when the meta is the dense np.matrix class, but that is less important.
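For context, a minimal illustration of the two-step indexing that avoids vindex (illustrative names, dense-backed array so the snippet runs standalone):

```python
import numpy as np
import dask.array as da

X = np.arange(100.0).reshape(10, 10)
Xd = da.from_array(X, chunks=(5, 5))
obs_idx = np.array([1, 3, 5])
var_idx = np.array([0, 2])

# vindex does point-wise selection at paired coordinates; to take the
# outer product of the two index arrays, use np.ix_ with vindex, or
# index each axis separately:
out = Xd[obs_idx, :][:, var_idx].compute()
assert np.array_equal(out, X[np.ix_(obs_idx, var_idx)])
```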

@codecov

codecov bot commented Aug 29, 2023

Codecov Report

Merging #1114 (14771bd) into main (8c00a29) will decrease coverage by 32.55%.
The diff coverage is 46.29%.

Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1114       +/-   ##
===========================================
- Coverage   84.97%   52.43%   -32.55%     
===========================================
  Files          36       36               
  Lines        5197     5287       +90     
===========================================
- Hits         4416     2772     -1644     
- Misses        781     2515     +1734     
Flag Coverage Δ
gpu-tests 52.43% <46.29%> (-0.34%) ⬇️

Flags with carried forward coverage won't be shown.

Files Changed Coverage Δ
anndata/_io/specs/methods.py 48.52% <20.83%> (-39.25%) ⬇️
anndata/_core/merge.py 82.97% <27.77%> (-11.14%) ⬇️
anndata/compat/__init__.py 40.57% <38.46%> (-36.86%) ⬇️
anndata/tests/helpers.py 79.06% <52.17%> (-16.96%) ⬇️
anndata/_core/index.py 66.41% <71.42%> (-28.20%) ⬇️
anndata/utils.py 40.96% <75.00%> (-43.70%) ⬇️
anndata/_io/specs/registry.py 82.60% <77.77%> (-13.31%) ⬇️
anndata/_core/anndata.py 57.90% <100.00%> (-25.70%) ⬇️

... and 16 files with indirect coverage changes


Comment on lines 1305 to 1306
X=_safe_transpose(X) if X is not None else None,
# X=t_csr(X) if X is not None else None,
Member Author

@flying-sheep do you remember why we would convert to csr here? It seems like an odd choice considering transpose is "free" if we don't insist on the result type.

Member

@flying-sheep flying-sheep Aug 31, 2023

I think older versions of scanpy or anndata assumed csr, but I don’t really know.

I guess we should run scanpy’s test suite on this branch before merging.

Member Author

It seems fine... But part of me feels like we should do a deprecation.

Tbh I think transposing is pretty rare. Maybe we change it with a release candidate and if anyone complains we do a deprecation?

@ivirshup
Member Author

@selmanozleyen pinging you for a look at this. It's pretty close to done I think, just 3 failing tests left. WDYT?

Comment on lines 445 to 454
assert_equal(b, a.compute(), exact, elem_name)
# TODO: Figure out why we did this
# from dask.array.utils import assert_eq
#
# I believe the above fails for sparse matrices due to some coercion to np.matrix
# if exact:
# assert_eq(a, b, check_dtype=True, check_type=True, check_graph=False)
# else:
# # TODO: Why does it fail when check_graph=True
# assert_eq(a, b, check_dtype=False, check_type=False, check_graph=False)
Member

I remember not wanting to use compute for the equality check, to see if the lazy version worked fine with the assertion. However, it also checked whether the computation graph was the same, which turned out to fail. Since lazy equality checking should also be a dask feature, I think we could write an assert_eq function for two dask arrays and compute only when the other operand isn't a dask array.
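The suggestion above might look roughly like the following (assert_dask_equal is a hypothetical helper, not part of the PR):

```python
import numpy as np
import dask.array as da
from dask.array.utils import assert_eq


def assert_dask_equal(a, b):
    # Hypothetical sketch: compute a side only when exactly one operand is
    # lazy, and skip the graph comparison that was failing in the tests.
    if isinstance(a, da.Array) and not isinstance(b, da.Array):
        a = a.compute()
    elif isinstance(b, da.Array) and not isinstance(a, da.Array):
        b = b.compute()
    assert_eq(a, b, check_graph=False)


assert_dask_equal(da.from_array(np.arange(10), chunks=5), np.arange(10))
```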

Member Author

I think dask's assert_eq always computes the values. To me, this makes sense for a test assertion, since it's okay if it's a bit slow.

Just graph based comparisons are probably a better fit for the anndata._core.merge.equals.

Member

@selmanozleyen selmanozleyen left a comment

Very exciting!

I am not really familiar with the transpose code, so I only skimmed it. I left some comments on parts of the code that I think would be good to clarify (e.g. the indexing and equal functions).

Some new warnings show up in the tests; capturing or, better, fixing them would be important imo.

anndata/_io/specs/registry.py (outdated review thread, resolved)
Comment on lines +135 to +137
# TODO: this may have been working for some cases?
subset_idx = np.ix_(*subset_idx)
return a.vindex[subset_idx]
Member

Works for some sparse cases? I don't understand the problem here.

Member Author

Ideally the vindex case would be more efficient by doing indexing across both dimensions at the same time. Though I haven't actually checked this.

But I believe there were some test cases passing without this change which are now going through the new branches. It could be worth making those cases go through vindex instead of [obs_idx, :][:, var_idx]

Comment on lines +125 to +127
if isinstance(a._meta, spmatrix):
# TODO: Maybe also do this in the other case?
return da.map_blocks(equal, a, b, drop_axis=(0, 1)).all()
Member

Any specific reason to not use np.equal? Wouldn't it check type for each block even though the block type is the same for all?

Member Author

np.equal doesn't work on sparse arrays:

In [3]: X = sparse.random(
   ...:     1000,
   ...:     100,
   ...:     format="csr",
   ...:     density=0.01,
   ...:     random_state=np.random.default_rng(),
   ...: )

In [4]: np.equal(X, X)
/usr/local/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3508: SparseEfficiencyWarning: Comparing sparse matrices using == is inefficient, try using != instead.
  exec(code_obj, self.user_global_ns, self.user_ns)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[4], line 1
----> 1 np.equal(X, X)

File /usr/local/lib/python3.9/site-packages/scipy/sparse/_base.py:332, in _spbase.__bool__(self)
    330     return self.nnz != 0
    331 else:
--> 332     raise ValueError("The truth value of an array with more than one "
    333                      "element is ambiguous. Use a.any() or a.all().")

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all().

Plus we want the nan handling behavior here.
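A sparse-aware comparison therefore has to go through the sparse API. A rough sketch of the kind of check wanted here (sparse_equal is a hypothetical name; NaN == NaN is treated as equal, as described):

```python
import numpy as np
from scipy import sparse


def sparse_equal(a, b):
    # Hypothetical sketch: elementwise equality of two sparse matrices,
    # treating NaN == NaN as equal.
    if a.shape != b.shape:
        return False
    diff = a != b  # sparse boolean matrix of mismatching positions
    if diff.nnz == 0:
        return True
    # Remaining mismatches may be NaN-vs-NaN, which we count as equal
    rows, cols = diff.nonzero()
    av = np.asarray(a[rows, cols]).ravel()
    bv = np.asarray(b[rows, cols]).ravel()
    return bool(np.all(np.isnan(av) & np.isnan(bv)))


X = sparse.random(100, 50, format="csr", density=0.1, random_state=0)
assert sparse_equal(X, X.copy())
assert not sparse_equal(X, 2 * X)
```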

Member

True, sorry, I guess I made a mistake: it somehow passed for me when I wrote it, but now it fails. Then it would make sense if that were just the else branch. However, both directions fail for me in this case:

def test_sparse_vs_dense_dask():
    data = np.random.random((100, 100))
    x = as_sparse_dask_array(data)
    y = as_dense_dask_array(data)
    assert equal(y, x)
    assert equal(x, y)

Member Author

It looks like that is also true for in-memory dense and sparse arrays:

from anndata._core.merge import equal
from scipy import sparse

s = sparse.random(100, 100, format="csr", density=0.1)

equal(s, s.toarray())  # False
equal(s.toarray(), s)  # True

But I think I would consider this a separate issue. In general, these functions currently assume that the array types being compared are the same (which fits the use case).

A real solution here would need an array type promotion hierarchy:

  • If you are comparing dense to sparse, which method gets used?
  • If you use equal, and some elements are dense, some are sparse, what gets returned?

anndata/tests/test_concatenate.py (outdated review thread, resolved)
@ivirshup ivirshup marked this pull request as ready for review September 5, 2023 18:33
@ivirshup ivirshup added this to the 0.10.0 milestone Sep 5, 2023
@ivirshup ivirshup enabled auto-merge (squash) September 8, 2023 14:32
@ivirshup ivirshup merged commit 36cd7b9 into scverse:main Sep 8, 2023
12 checks passed

Successfully merging this pull request may close these issues.

dask.array support with sparse chunks
3 participants