Created
January 30, 2024 15:52
-
-
Save ivirshup/ecf131b6ba8b094c04bd4a275ad7a804 to your computer and use it in GitHub Desktop.
Benchmark highly variable genes w/ dask
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
git checkout dask-hvg
python run_hvg.py
# 544.9354245662689
git checkout dask-hvg-w-pandas
python run_hvg.py
# 130.07193517684937
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import scanpy as sc, anndata as ad, pandas as pd | |
from anndata.experimental import read_elem, sparse_dataset, write_elem | |
from scipy import sparse | |
import h5py | |
import dask | |
import dask.array as da | |
from dask import delayed | |
import zarr | |
def csr_callable(shape: tuple[int, int], dtype) -> sparse.csr_matrix:
    """Build an empty CSR matrix, normalizing 0-d/1-d shapes to 2-d.

    Dask may ask a meta-callable for shapes of any rank; scipy sparse
    matrices are strictly 2-d, so lower ranks are padded with zeros and
    ranks above 2 are rejected.
    """
    ndim = len(shape)
    if ndim == 0:
        shape = (0, 0)
    elif ndim == 1:
        shape = (shape[0], 0)
    elif ndim != 2:
        raise ValueError(shape)
    return sparse.csr_matrix(shape, dtype=dtype)
class CSRCallable:
    """Stand-in type so dask's meta machinery accepts sparse CSR chunks.

    Dask calls its ``meta`` with ``(shape, dtype)``; "instantiating" this
    class actually returns a ``scipy.sparse.csr_matrix`` (via
    ``csr_callable``), never a ``CSRCallable`` instance.
    """

    def __new__(cls, shape, dtype):
        # Delegate straight to the factory; dask only needs the callable shape.
        return csr_callable(shape, dtype)
def make_dask_chunk(x: "SparseDataset", start: int, end: int) -> da.Array:
    """Wrap rows ``[start, end)`` of an on-disk sparse dataset as one lazy dask chunk."""

    def slice_rows(dataset, rows):
        # Materializes the row range only when the delayed task runs.
        return dataset[rows]

    chunk_shape = (end - start, x.shape[1])
    return da.from_delayed(
        delayed(slice_rows)(x, slice(start, end)),
        shape=chunk_shape,
        dtype=x.dtype,
        meta=CSRCallable,  # lets dask treat the chunk as a CSR matrix
    )
def sparse_dataset_as_dask(x, stride: int):
    """Expose an on-disk sparse dataset as a dask array of row-wise chunks.

    Rows are split into consecutive chunks of ``stride`` rows; the final
    chunk carries any remainder.
    """
    n_rows = x.shape[0]
    # min() caps the last chunk at the dataset's end, covering the remainder.
    pieces = [
        make_dask_chunk(x, start, min(start + stride, n_rows))
        for start in range(0, n_rows, stride)
    ]
    return da.concatenate(pieces, axis=0)
def read_w_sparse_dask(group: h5py.Group | zarr.Group, obs_chunk: int = 1000) -> ad.AnnData:
    """Read an AnnData whose ``X`` is a lazy dask array over the on-disk sparse matrix.

    All other elements are read eagerly; absent ones default to empty dicts.
    """
    elem_keys = ["layers", "obs", "var", "obsm", "varm", "uns", "obsp", "varp"]
    elems = {k: read_elem(group[k]) if k in group else {} for k in elem_keys}
    lazy_X = sparse_dataset_as_dask(sparse_dataset(group["X"]), obs_chunk)
    return ad.AnnData(X=lazy_X, **elems)
# --- Benchmark: highly_variable_genes over a lazily loaded 417k-cell atlas ---
# NOTE(review): hard-coded local path — adjust for other machines.
f = h5py.File("/mnt/workspace/data/cd19-carT-atlas-417k.h5ad")
adata = read_w_sparse_dask(f, obs_chunk=5_000)
del adata.raw  # drop .raw so it isn't carried through downstream ops
# Densify each sparse chunk lazily; HVG below runs on this dense layer.
adata.layers["dense"] = adata.X.map_blocks(lambda x: x.toarray(), dtype=adata.X.dtype)
subset = adata[:1000].copy()
# warmup for imports + cache
_ = dask.compute(sc.pp.highly_variable_genes(subset, layer="dense", inplace=False))
t0 = time.time()
result = dask.compute(sc.pp.highly_variable_genes(adata, layer="dense", inplace=False))[0]
t1 = time.time()
print(t1 - t0)  # wall-clock seconds for the full-dataset HVG pass
# Persist the per-version result for cross-branch comparison.
pd.to_pickle(result, f"hvg_{sc.__version__}.pkl")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment