
@ivirshup
Last active October 9, 2023 10:42
Using sparse dask chunks inside of anndata
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "734d0a4f-79fd-4672-b028-afe1cf3282de",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'0.10.1'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%load_ext memory_profiler\n",
"\n",
"import anndata as ad, scanpy as sc, numpy as np, dask.array as da, pandas as pd\n",
"import zarr\n",
"from dask import delayed\n",
"from scipy import sparse\n",
"\n",
"ad.__version__"
]
},
{
"cell_type": "markdown",
"id": "2c0b1594",
"metadata": {},
"source": [
"## Download a couple of example datasets"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "065c5f51-d95a-4c7f-8941-318b0bdf6d34",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/workspace/mambaforge/envs/anndata-dev/lib/python3.11/site-packages/anndata/_core/anndata.py:1900: UserWarning: Variable names are not unique. To make them unique, call `.var_names_make_unique`.\n",
" utils.warn_names_duplicates(\"var\")\n",
"/mnt/workspace/mambaforge/envs/anndata-dev/lib/python3.11/site-packages/anndata/_core/anndata.py:1900: UserWarning: Variable names are not unique. To make them unique, call `.var_names_make_unique`.\n",
" utils.warn_names_duplicates(\"var\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 213 ms, sys: 40.2 ms, total: 253 ms\n",
"Wall time: 256 ms\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/workspace/mambaforge/envs/anndata-dev/lib/python3.11/site-packages/anndata/_core/anndata.py:1900: UserWarning: Variable names are not unique. To make them unique, call `.var_names_make_unique`.\n",
" utils.warn_names_duplicates(\"var\")\n",
"/mnt/workspace/mambaforge/envs/anndata-dev/lib/python3.11/site-packages/anndata/_core/anndata.py:1900: UserWarning: Variable names are not unique. To make them unique, call `.var_names_make_unique`.\n",
" utils.warn_names_duplicates(\"var\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 171 ms, sys: 47.1 ms, total: 218 ms\n",
"Wall time: 218 ms\n"
]
}
],
"source": [
"import pooch\n",
"\n",
"EXAMPLE_DATA = pooch.create(\n",
"    path=pooch.os_cache(\"scverse_tutorials\"),\n",
"    base_url=\"doi:10.6084/m9.figshare.22716739.v1/\",\n",
")\n",
"EXAMPLE_DATA.load_registry_from_doi()\n",
"\n",
"samples = {\n",
"    \"s1d1\": \"s1d1_filtered_feature_bc_matrix.h5\",\n",
"    \"s1d3\": \"s1d3_filtered_feature_bc_matrix.h5\",\n",
"}\n",
"adatas = {}\n",
"\n",
"for sample_id, filename in samples.items():\n",
"    path = EXAMPLE_DATA.fetch(filename)\n",
"    sample_adata = sc.read_10x_h5(path)\n",
"    sample_adata.var_names_make_unique()\n",
"    %time sample_adata.write_zarr(f\"{sample_id}.zarr\")"
]
},
{
"cell_type": "markdown",
"id": "fcb4c6f4",
"metadata": {},
"source": [
"## Functions for IO"
]
},
{
"cell_type": "markdown",
"id": "79f766de",
"metadata": {},
"source": [
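"Unfortunately, dask's support for sparse arrays is limited, so we need to do a little legwork to get things working.\n",
"\n",
"Dask tracks the chunk type of each array through a zero-size \"meta array\". When `meta` is given as a class, dask instantiates it with `shape` and `dtype` keyword arguments, which `scipy.sparse.csr_matrix` does not accept on its own. Here we define a helper function, and a class wrapping it, that let us create dask arrays with a CSR \"meta array\" from a `dask.delayed` task."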
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a3574cc4-b557-4ef5-9227-6e86d548e837",
"metadata": {},
"outputs": [],
"source": [
"def csr_callable(shape: tuple[int, ...], dtype) -> sparse.csr_matrix:\n",
"    # Build an empty CSR matrix, padding 0-d and 1-d shape requests out to 2-d\n",
"    if len(shape) == 0:\n",
"        shape = (0, 0)\n",
"    elif len(shape) == 1:\n",
"        shape = (shape[0], 0)\n",
"    elif len(shape) != 2:\n",
"        raise ValueError(shape)\n",
"\n",
"    return sparse.csr_matrix(shape, dtype=dtype)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "48333808-de8e-4ab7-955c-92948ecb91f4",
"metadata": {},
"outputs": [],
"source": [
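"class CSRCallable:\n",
"    \"\"\"Dummy class to bypass dask checks: dask instantiates a class passed as `meta`\n",
"    with `shape` and `dtype` keywords, so forward them to `csr_callable`.\"\"\"\n",
"\n",
"    def __new__(cls, shape, dtype):\n",
"        return csr_callable(shape, dtype)"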
]
},
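{
"cell_type": "markdown",
"id": "sparse-meta-class-note",
"metadata": {},
"source": [
"Why a class, rather than `scipy.sparse.csr_matrix` itself? Below is a minimal, self-contained sketch on toy data, assuming dask instantiates a class given as `meta` with `shape` and `dtype` keyword arguments. `EmptyCSR` is a hypothetical stand-in for `CSRCallable` so that the cell runs on its own, and `arr._meta` is a private dask attribute read here only for inspection.\n",
"\n",
"If the assumption holds, the keyword-only call to `csr_matrix` fails with a `TypeError`, while `EmptyCSR` returns an empty `(0, 0)` CSR matrix and the dask array records a sparse meta."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sparse-meta-class-demo",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import dask.array as da\n",
"from dask import delayed\n",
"from scipy import sparse\n",
"\n",
"\n",
"class EmptyCSR:\n",
"    # Hypothetical stand-in for CSRCallable: EmptyCSR(shape=..., dtype=...) -> empty CSR\n",
"    def __new__(cls, shape, dtype):\n",
"        # Pad 0-d and 1-d shape requests out to two dimensions\n",
"        shape = tuple(shape) + (0,) * (2 - len(shape))\n",
"        return sparse.csr_matrix(shape, dtype=dtype)\n",
"\n",
"\n",
"# csr_matrix needs a positional first argument, so a keyword-only call fails\n",
"try:\n",
"    sparse.csr_matrix(shape=(0, 0), dtype=np.float32)\n",
"    keyword_call_ok = True\n",
"except TypeError:\n",
"    keyword_call_ok = False\n",
"\n",
"# The wrapper accepts the same keywords and returns an empty CSR matrix\n",
"empty = EmptyCSR(shape=(0, 0), dtype=np.float32)\n",
"\n",
"# Wrap one delayed CSR block; dask derives the meta from EmptyCSR instead of numpy\n",
"block = delayed(sparse.identity)(3, dtype=np.float32, format=\"csr\")\n",
"arr = da.from_delayed(block, shape=(3, 3), dtype=np.float32, meta=EmptyCSR)\n",
"\n",
"print(keyword_call_ok, empty.shape, type(arr._meta).__name__)\n",
"print(arr.compute().toarray())"
]
},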
{
"cell_type": "markdown",
"id": "94cac178",
"metadata": {},
"source": [
"Here we create dask chunks from a `CSRDataset`:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "79db8e22",
"metadata": {},
"outputs": [],
"source": [
"def make_dask_chunk(\n",
"    x: ad.experimental.CSRDataset,\n",
"    start: int,\n",
"    end: int\n",
") -> da.Array:\n",
"    def take_slice(x, idx):\n",
"        return x[idx]\n",
"\n",
"    return da.from_delayed(\n",
"        delayed(take_slice)(x, slice(start, end)),\n",
"        dtype=x.dtype,\n",
"        shape=(end - start, x.shape[1]),\n",
"        meta=CSRCallable,\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "60fe5b0e-f563-45e3-83a8-357a8b78c95c",
"metadata": {},
"outputs": [],
"source": [
"def sparse_dataset_as_dask(x, stride: int):\n",
"    # One chunk per `stride` rows, plus a final chunk for any remainder\n",
"    n_chunks, rem = divmod(x.shape[0], stride)\n",
"\n",
"    chunks = []\n",
"    cur_pos = 0\n",
"    for _ in range(n_chunks):\n",
"        chunks.append(make_dask_chunk(x, cur_pos, cur_pos + stride))\n",
"        cur_pos += stride\n",
"    if rem:\n",
"        chunks.append(make_dask_chunk(x, cur_pos, x.shape[0]))\n",
"\n",
"    return da.concatenate(chunks, axis=0)"
]
},
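{
"cell_type": "markdown",
"id": "synthetic-row-chunks-note",
"metadata": {},
"source": [
"The same row-chunking pattern can be tried without downloading anything. Below is a hypothetical in-memory analogue, a sketch on toy data rather than a benchmark: `csr_as_dask` plays the role of `sparse_dataset_as_dask` but slices a `scipy.sparse.csr_matrix` instead of a `CSRDataset`, and `EmptyCSR` stands in for `CSRCallable`. It steps through row offsets with `range(0, n_rows, stride)` instead of `divmod`, and the hypothetical `loads` list records which row blocks are actually read, to show that nothing is loaded until `compute()`.\n",
"\n",
"With 10 rows and `stride=4`, this should give chunks of 4, 4 and 2 rows, with `loads` empty until the array is computed. The synchronous scheduler only keeps the example single-threaded."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "synthetic-row-chunks-demo",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import dask.array as da\n",
"from dask import delayed\n",
"from scipy import sparse\n",
"\n",
"\n",
"class EmptyCSR:\n",
"    # Hypothetical stand-in for CSRCallable: called by dask as EmptyCSR(shape=..., dtype=...)\n",
"    def __new__(cls, shape, dtype):\n",
"        shape = tuple(shape) + (0,) * (2 - len(shape))\n",
"        return sparse.csr_matrix(shape, dtype=dtype)\n",
"\n",
"\n",
"loads = []  # (start, end) of every row block that actually gets read\n",
"\n",
"\n",
"def take_rows(mat, start, end):\n",
"    loads.append((start, end))\n",
"    return mat[start:end]\n",
"\n",
"\n",
"def csr_as_dask(mat, stride):\n",
"    # Hypothetical in-memory analogue of sparse_dataset_as_dask\n",
"    n_rows, n_cols = mat.shape\n",
"    chunks = []\n",
"    for start in range(0, n_rows, stride):\n",
"        end = min(start + stride, n_rows)  # the last block holds any remainder\n",
"        chunks.append(\n",
"            da.from_delayed(\n",
"                delayed(take_rows)(mat, start, end),\n",
"                shape=(end - start, n_cols),\n",
"                dtype=mat.dtype,\n",
"                meta=EmptyCSR,\n",
"            )\n",
"        )\n",
"    return da.concatenate(chunks, axis=0)\n",
"\n",
"\n",
"dense = np.zeros((10, 6), dtype=np.float32)\n",
"dense[::3, ::2] = 1.0  # stored values in rows 0, 3, 6, 9\n",
"mat = sparse.csr_matrix(dense)\n",
"\n",
"arr = csr_as_dask(mat, stride=4)\n",
"print(arr.chunks)  # ((4, 4, 2), (6,))\n",
"print(len(loads))  # 0: building the graph reads no data\n",
"\n",
"result = arr.compute(scheduler=\"synchronous\")\n",
"print(sorted(loads))  # [(0, 4), (4, 8), (8, 10)]\n",
"print(np.array_equal(result.toarray(), dense))  # True"
]
},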
{
"cell_type": "markdown",
"id": "07e2dd9c-c2dd-4019-a1ac-c5c966a427cc",
"metadata": {},
"source": [
"## Demo"
]
},
{
"cell_type": "markdown",
"id": "b3406ef7",
"metadata": {},
"source": [
"First, check that this gives the same matrix as loading `X` into memory:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "622fd838-0c91-4ac3-a9dd-cae9de5ec5c9",
"metadata": {},
"outputs": [],
"source": [
"z = zarr.open(\"s1d1.zarr\")\n",
"x = ad.experimental.sparse_dataset(z[\"X\"])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "18b92237-0716-40e8-9359-211a9c58de67",
"metadata": {},
"outputs": [],
"source": [
"# check that no entry of the dask-backed matrix differs from the in-memory one\n",
"assert not (x.to_memory() != sparse_dataset_as_dask(x, 1000).compute()).nnz"
]
},
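{
"cell_type": "markdown",
"id": "sparse-ne-note",
"metadata": {},
"source": [
"The check above relies on `!=` between two sparse matrices returning a sparse boolean matrix that stores only the positions where they differ, so `nnz == 0` means the matrices are equal. A minimal illustration on made-up toy values:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sparse-ne-demo",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy import sparse\n",
"\n",
"a = sparse.csr_matrix(np.array([[0, 1], [2, 0]], dtype=np.float32))\n",
"b = a.copy()\n",
"b[1, 0] = 3  # overwrite an entry that is already stored in a\n",
"\n",
"same = a != a.copy()  # no differing entries are stored\n",
"different = a != b  # only position (1, 0) is stored\n",
"print(same.nnz, different.nnz)  # 0 1"
]
},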
{
"cell_type": "markdown",
"id": "67ff3ea5",
"metadata": {},
"source": [
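"Now a real use case: read each sample with a lazily loaded, sparse dask `X`, then concatenate the samples and write the result:"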
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "ae06d1d3-0e5a-4da2-a16e-20ddf7ffc3e3",
"metadata": {},
"outputs": [],
"source": [
"def read_w_sparse_dask(path: str, obs_chunk: int = 1000) -> ad.AnnData:\n",
"    z = zarr.open(path)\n",
"    return ad.AnnData(\n",
"        X=sparse_dataset_as_dask(ad.experimental.sparse_dataset(z[\"X\"]), obs_chunk),\n",
"        obs=ad.experimental.read_elem(z[\"obs\"]),\n",
"        var=ad.experimental.read_elem(z[\"var\"]),\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "8659881f-c539-43a2-94c0-38e7eed4b724",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 30.5 ms, sys: 7.18 ms, total: 37.7 ms\n",
"Wall time: 35.7 ms\n"
]
}
],
"source": [
"%%time\n",
"adatas = {s: read_w_sparse_dask(f\"{s}.zarr\", 4000) for s in samples}"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "61261b51-2c7d-42ed-978c-bb840272f2b2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 19.7 ms, sys: 3.83 ms, total: 23.6 ms\n",
"Wall time: 22.9 ms\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/workspace/mambaforge/envs/anndata-dev/lib/python3.11/site-packages/anndata/_core/merge.py:1242: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning\n",
" concat_indices = concat_indices.str.cat(label_col.map(str), sep=index_unique)\n"
]
}
],
"source": [
"%%time\n",
"combined = ad.concat(adatas, index_unique=\"-\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "003d068d-0bd4-4696-93dd-6a1e2e50e456",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"combined.X.visualize()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "a3cc32d9-07b5-4d20-87ad-4c2e1f55164e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 369 ms, sys: 172 ms, total: 541 ms\n",
"Wall time: 540 ms\n"
]
}
],
"source": [
"%%time\n",
"combined.write_zarr(\"combined.zarr\")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "7ffedac3-3365-43e0-9f08-5f9e14b570c2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 197 ms, sys: 135 ms, total: 333 ms\n",
"Wall time: 198 ms\n"
]
},
{
"data": {
"text/plain": [
"<17125x36601 sparse matrix of type '<class 'numpy.float32'>'\n",
"\twith 26550469 stored elements in Compressed Sparse Row format>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"result = combined.to_memory()\n",
"result.X"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "419f9575-d8c0-437b-880b-90b8e88f5b78",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}