
(fix): use dask array for missing element in dask concatenation #1780

Open · wants to merge 9 commits into main
Conversation


@ilan-gold ilan-gold commented Nov 27, 2024

To test:

from pathlib import Path
import time
import anndata as ad
import numpy as np
import zarr
import h5py

def read_as_dask(store: str) -> ad.AnnData:
    """\
    Read from a hierarchical Zarr array store.

    Parameters
    ----------
    store
        The filename, a :class:`~typing.MutableMapping`, or a Zarr storage class.
    """
    if not isinstance(store, str):
       raise ValueError("Only string paths are supported")

    if store.endswith(".h5ad"):
        f = h5py.File(store, "r")
    elif store.endswith(".zarr"):
        f = zarr.open(store, mode="r")
    else:
        raise ValueError("Unknown file format")

    # Read with handling for backwards compat
    def callback(func, elem_name: str, elem, iospec):
        if iospec.encoding_type == "anndata" or elem_name.endswith("/"):
            return ad.AnnData(
                **{
                    k: ad.experimental.read_dispatched(v, callback)
                    for k, v in dict(elem).items()
                    if not k.startswith("raw.")
                }
            )
        elif elem_name.startswith("/raw"):  # skip raw here; adjust to your needs, but beware of missing elements
            return None
        elif iospec.encoding_type in {
            "csr_matrix",
            "csc_matrix",
            "array",
        }:
            return ad.experimental.read_elem_as_dask(elem)
        elif iospec.encoding_type == "dict":
            return {k: ad.experimental.read_dispatched(v, callback=callback) for k, v in elem.items()}
        return ad.io.read_elem(elem)

    adata = ad.experimental.read_dispatched(f, callback=callback)

    return adata

shape = (100_000, 10_000)
n_datasets = 2
layer_key = "foo"
def gen_path(i: int):
    return f"data/test_{i}.zarr"
arr = None
for i in range(n_datasets):
    file_path = Path(gen_path(i))
    if not file_path.exists():
        if arr is None:
            arr = np.random.random(shape)
        adata = ad.AnnData(X=arr)
        if i == 0:
            adata.layers[layer_key] = arr
        adata.write_zarr(file_path)
adatas = [read_as_dask(gen_path(i)) for i in range(n_datasets)]
assert sum(layer_key in a.layers for a in adatas) == 1

t = time.time()
concatenated = ad.concat(adatas, join="outer")
print('Concatenation took: ', time.time() - t)

On main this takes about 30 seconds; with this PR it takes about 0.3 seconds.

With the Python profiler you can see where the performance hit was coming from: by passing in numpy arrays as missing values instead of dask arrays, we were triggering dask's tokenization mechanism for an in-memory data structure:

       56    0.000    0.000   34.738    0.620 tokenize.py:47(tokenize)
       56    0.001    0.000   34.737    0.620 tokenize.py:33(_tokenize)
   124/56    0.001    0.000   34.734    0.620 tokenize.py:141(_normalize_seq_func)
  338/138    0.000    0.000   34.734    0.252 tokenize.py:142(_inner_normalize_token)
   139/77    0.001    0.000   34.733    0.451 utils.py:767(__call__)
       10    0.000    0.000   34.725    3.472 core.py:4694(asarray)
        4    0.738    0.184   34.699    8.675 tokenize.py:401(normalize_array)
      208    0.000    0.000   27.179    0.131 hashing.py:94(hash_buffer_hex)
      208    0.000    0.000   27.178    0.131 hashing.py:73(hash_buffer)
      208    0.000    0.000   27.178    0.131 hashing.py:63(_hash_sha1)
      209   27.176    0.130   27.176    0.130 {built-in method _hashlib.openssl_sha1}
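To see why the numpy path is so much slower, here is a small standalone sketch (plain numpy plus `hashlib`, no anndata or dask required; the shapes and the metadata tuple are illustrative, not the actual dask internals). Tokenizing an in-memory numpy array forces dask to hash the full data buffer (the `openssl_sha1` line dominating the profile above), while a lazily-created dask array can be tokenized from constant-size metadata alone:

```python
import hashlib
import time

import numpy as np

# An in-memory array of roughly one chunk's worth of the test data (~80 MB).
arr = np.random.random((1_000, 10_000))

# Tokenizing a numpy array means hashing every byte of its buffer:
t = time.time()
content_token = hashlib.sha1(arr.tobytes()).hexdigest()
buffer_time = time.time() - t

# A dask placeholder array's token can instead be derived from metadata
# (shape, dtype, chunking), so the cost is independent of the data size:
t = time.time()
meta_token = hashlib.sha1(repr((arr.shape, arr.dtype, "chunks")).encode()).hexdigest()
meta_time = time.time() - t

print(f"buffer hash: {buffer_time:.4f}s, metadata hash: {meta_time:.6f}s")
```

Scaled up to the full 100,000 × 10,000 missing layer, the buffer-hashing cost is what accounts for the ~30 seconds spent in `normalize_array` above.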


codecov bot commented Nov 27, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 84.53%. Comparing base (7d9fba8) to head (3979acf).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1780      +/-   ##
==========================================
- Coverage   87.01%   84.53%   -2.48%     
==========================================
  Files          40       40              
  Lines        6075     6080       +5     
==========================================
- Hits         5286     5140     -146     
- Misses        789      940     +151     
Files with missing lines Coverage Δ
src/anndata/_core/merge.py 84.04% <100.00%> (-10.94%) ⬇️

... and 7 files with indirect coverage changes

@ilan-gold ilan-gold added this to the 0.11.2 milestone Nov 28, 2024
@ilan-gold ilan-gold changed the title (fix): use dask array for missing element in concatenation (fix): use dask array for missing element in dask concatenation Nov 28, 2024

Successfully merging this pull request may close these issues.

Dask Concatenation Should Impute With Dask Array, not Numpy