Skip to content

Instantly share code, notes, and snippets.

@ivirshup
Created January 30, 2024 15:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ivirshup/ecf131b6ba8b094c04bd4a275ad7a804 to your computer and use it in GitHub Desktop.
Save ivirshup/ecf131b6ba8b094c04bd4a275ad7a804 to your computer and use it in GitHub Desktop.
Benchmark highly variable genes with dask (each number below is the elapsed time in seconds printed by the script)
git checkout dask-hvg
python run_hvg.py
# 544.9354245662689
git checkout dask-hvg-w-pandas
python run_hvg.py
# 130.07193517684937
import time
import scanpy as sc, anndata as ad, pandas as pd
from anndata.experimental import read_elem, sparse_dataset, write_elem
from scipy import sparse
import h5py
import dask
import dask.array as da
from dask import delayed
import zarr
def csr_callable(shape: tuple[int, ...], dtype) -> sparse.csr_matrix:
    """Build an empty CSR matrix, padding short shapes to two dimensions.

    Used (through ``CSRCallable``) to produce the ``meta`` object for the
    dask chunks built in ``make_dask_chunk``.

    Parameters
    ----------
    shape
        Requested shape with 0, 1 or 2 entries. A 0-length shape becomes
        ``(0, 0)``; a 1-length shape ``(n,)`` becomes ``(n, 0)``.
    dtype
        dtype of the returned matrix.

    Returns
    -------
    An all-zero ``scipy.sparse.csr_matrix`` of the (padded) shape.

    Raises
    ------
    ValueError
        If ``shape`` has more than two entries.
    """
    ndim = len(shape)
    if ndim == 0:
        shape = (0, 0)
    elif ndim == 1:
        shape = (shape[0], 0)
    elif ndim != 2:
        # CSR matrices are strictly two-dimensional.
        raise ValueError(shape)
    return sparse.csr_matrix(shape, dtype=dtype)
class CSRCallable:
    """Type-like stand-in whose "construction" yields an empty CSR matrix.

    It is handed to dask as ``meta`` in ``make_dask_chunk`` instead of a real
    array instance; the original author describes it as a dummy class to
    bypass dask checks.
    """

    def __new__(cls, shape, dtype):
        # __new__ hands back a csr_matrix, not a CSRCallable, so __init__ is
        # never run and "instantiating" this class is simply a call to
        # csr_callable with the same shape and dtype.
        return csr_callable(shape=shape, dtype=dtype)
def make_dask_chunk(x: "SparseDataset", start: int, end: int) -> da.Array:
    """Wrap rows ``start:end`` of ``x`` as one lazy dask chunk.

    The rows are only read from ``x`` (via ``x[slice(start, end)]``) when the
    chunk is computed.

    Parameters
    ----------
    x
        Sparse dataset exposing ``shape``, ``dtype`` and row slicing.
    start, end
        Half-open row range covered by the chunk.

    Returns
    -------
    A dask array of shape ``(end - start, x.shape[1])`` and dtype ``x.dtype``.
    ``CSRCallable`` is passed as ``meta``, presumably so dask treats the chunk
    as a CSR matrix rather than a dense ndarray.
    """
    def take_slice(dataset, rows):
        return dataset[rows]

    row_window = slice(start, end)
    lazy_rows = delayed(take_slice)(x, row_window)
    chunk_shape = (end - start, x.shape[1])
    return da.from_delayed(
        lazy_rows,
        shape=chunk_shape,
        dtype=x.dtype,
        meta=CSRCallable,
    )
def sparse_dataset_as_dask(x, stride: int):
    """Expose a sparse dataset as a dask array chunked along rows.

    Parameters
    ----------
    x
        Sparse dataset exposing ``shape``, ``dtype`` and row slicing.
    stride
        Number of rows per chunk. The final chunk holds any remainder.

    Returns
    -------
    A dask array with the same shape as ``x``. It is built from
    ``make_dask_chunk`` pieces concatenated along axis 0. A dataset with
    zero rows yields a single zero-row chunk.

    Raises
    ------
    ValueError
        If ``stride`` is not positive.
    """
    if stride <= 0:
        raise ValueError(f"stride must be a positive integer, got {stride!r}")
    n_rows = x.shape[0]
    if n_rows == 0:
        # da.concatenate rejects an empty sequence; emit one empty chunk so
        # the result still has shape (0, n_cols).
        return make_dask_chunk(x, 0, 0)
    chunks = [
        make_dask_chunk(x, start, min(start + stride, n_rows))
        for start in range(0, n_rows, stride)
    ]
    return da.concatenate(chunks, axis=0)
def read_w_sparse_dask(group: h5py.Group | zarr.Group, obs_chunk: int = 1000) -> ad.AnnData:
    """Read an AnnData whose ``X`` is a lazy, row-chunked dask array.

    ``X`` is opened with ``sparse_dataset`` and split into chunks of
    ``obs_chunk`` rows. The elements ``layers``, ``obs``, ``var``, ``obsm``,
    ``varm``, ``uns``, ``obsp`` and ``varp`` are read eagerly with
    ``read_elem``. Any that are missing from ``group`` are passed as empty
    dicts. ``raw`` is not read.

    Parameters
    ----------
    group
        h5py or zarr group, assumed to hold an AnnData layout.
    obs_chunk
        Rows per dask chunk of ``X``.
    """
    lazy_x = sparse_dataset_as_dask(sparse_dataset(group["X"]), obs_chunk)
    elements = {}
    for key in ("layers", "obs", "var", "obsm", "varm", "uns", "obsp", "varp"):
        if key in group:
            elements[key] = read_elem(group[key])
        else:
            elements[key] = {}
    return ad.AnnData(X=lazy_x, **elements)
# Benchmark: time sc.pp.highly_variable_genes on a densified, dask-backed
# layer, then pickle the result tagged with the scanpy version.
# The file is opened read-only inside ``with`` so it is closed once the
# computation that reads from it has finished.
with h5py.File("/mnt/workspace/data/cd19-carT-atlas-417k.h5ad", "r") as f:
    adata = read_w_sparse_dask(f, obs_chunk=5_000)
    del adata.raw
    # Convert each sparse chunk to a dense block (lazily, via map_blocks).
    adata.layers["dense"] = adata.X.map_blocks(lambda x: x.toarray(), dtype=adata.X.dtype)
    subset = adata[:1000].copy()
    # warmup for imports + cache
    _ = dask.compute(sc.pp.highly_variable_genes(subset, layer="dense", inplace=False))
    # perf_counter is monotonic and high-resolution, unlike time.time, which
    # can jump if the system clock is adjusted mid-benchmark.
    t0 = time.perf_counter()
    result = dask.compute(sc.pp.highly_variable_genes(adata, layer="dense", inplace=False))[0]
    t1 = time.perf_counter()
print(t1 - t0)
pd.to_pickle(result, f"hvg_{sc.__version__}.pkl")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment