Created
February 21, 2024 09:23
-
-
Save Intron7/bbf5058794be7b81d3953ae39c17d8b8 to your computer and use it in GitHub Desktop.
Dask Singlecell
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "37588552-b9d5-4113-9f53-ecc63d3da815", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import scanpy as sc\n", | |
"import anndata\n", | |
"\n", | |
"import dask\n", | |
"import time\n", | |
"\n", | |
"import os, wget\n", | |
"\n", | |
"from dask_cuda import initialize, LocalCUDACluster\n", | |
"from dask.distributed import Client, default_client\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "2f159c97-e435-40c5-ac7b-aa37a6812ced", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import rmm\n", | |
"import cupy as cp\n", | |
"\n", | |
"from rmm.allocators.cupy import rmm_cupy_allocator\n", | |
"\n", | |
"def set_mem():\n", | |
" rmm.reinitialize(managed_memory=True)\n", | |
" cp.cuda.set_allocator(rmm_cupy_allocator)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "49a27e3c-5473-42eb-a0e3-73a22a1a9e16", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"preprocessing_gpus=\"0, 1, 2, 3, 4, 5, 6, 7\"\n", | |
"#preprocessing_gpus=\"0\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "41eb36fe-1a49-4f16-b8a5-e01cdf39a7c2", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 934 ms, sys: 875 ms, total: 1.81 s\n", | |
"Wall time: 33 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
" <div style=\"width: 24px; height: 24px; background-color: #e1e1e1; border: 3px solid #9D9D9D; border-radius: 5px; position: absolute;\"> </div>\n", | |
" <div style=\"margin-left: 48px;\">\n", | |
" <h3 style=\"margin-bottom: 0px;\">Client</h3>\n", | |
" <p style=\"color: #9D9D9D; margin-bottom: 0px;\">Client-0a46410e-d004-11ee-8919-043f72ce7f82</p>\n", | |
" <table style=\"width: 100%; text-align: left;\">\n", | |
"\n", | |
" <tr>\n", | |
" \n", | |
" <td style=\"text-align: left;\"><strong>Connection method:</strong> Cluster object</td>\n", | |
" <td style=\"text-align: left;\"><strong>Cluster type:</strong> dask_cuda.LocalCUDACluster</td>\n", | |
" \n", | |
" </tr>\n", | |
"\n", | |
" \n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Dashboard: </strong> <a href=\"http://127.0.0.1:8787/status\" target=\"_blank\">http://127.0.0.1:8787/status</a>\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\"></td>\n", | |
" </tr>\n", | |
" \n", | |
"\n", | |
" </table>\n", | |
"\n", | |
" \n", | |
" <button style=\"margin-bottom: 12px;\" data-commandlinker-command=\"dask:populate-and-launch-layout\" data-commandlinker-args='{\"url\": \"http://127.0.0.1:8787/status\" }'>\n", | |
" Launch dashboard in JupyterLab\n", | |
" </button>\n", | |
" \n", | |
"\n", | |
" \n", | |
" <details>\n", | |
" <summary style=\"margin-bottom: 20px;\"><h3 style=\"display: inline;\">Cluster Info</h3></summary>\n", | |
" <div class=\"jp-RenderedHTMLCommon jp-RenderedHTML jp-mod-trusted jp-OutputArea-output\">\n", | |
" <div style=\"width: 24px; height: 24px; background-color: #e1e1e1; border: 3px solid #9D9D9D; border-radius: 5px; position: absolute;\">\n", | |
" </div>\n", | |
" <div style=\"margin-left: 48px;\">\n", | |
" <h3 style=\"margin-bottom: 0px; margin-top: 0px;\">LocalCUDACluster</h3>\n", | |
" <p style=\"color: #9D9D9D; margin-bottom: 0px;\">27a5eec8</p>\n", | |
" <table style=\"width: 100%; text-align: left;\">\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Dashboard:</strong> <a href=\"http://127.0.0.1:8787/status\" target=\"_blank\">http://127.0.0.1:8787/status</a>\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Workers:</strong> 8\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Total threads:</strong> 8\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Total memory:</strong> 1.86 TiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\"><strong>Status:</strong> running</td>\n", | |
" <td style=\"text-align: left;\"><strong>Using processes:</strong> True</td>\n", | |
"</tr>\n", | |
"\n", | |
" \n", | |
" </table>\n", | |
"\n", | |
" <details>\n", | |
" <summary style=\"margin-bottom: 20px;\">\n", | |
" <h3 style=\"display: inline;\">Scheduler Info</h3>\n", | |
" </summary>\n", | |
"\n", | |
" <div style=\"\">\n", | |
" <div>\n", | |
" <div style=\"width: 24px; height: 24px; background-color: #FFF7E5; border: 3px solid #FF6132; border-radius: 5px; position: absolute;\"> </div>\n", | |
" <div style=\"margin-left: 48px;\">\n", | |
" <h3 style=\"margin-bottom: 0px;\">Scheduler</h3>\n", | |
" <p style=\"color: #9D9D9D; margin-bottom: 0px;\">Scheduler-f5323a92-56fa-4fc7-93ba-8c32d627e723</p>\n", | |
" <table style=\"width: 100%; text-align: left;\">\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Comm:</strong> tcp://127.0.0.1:42393\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Workers:</strong> 8\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Dashboard:</strong> <a href=\"http://127.0.0.1:8787/status\" target=\"_blank\">http://127.0.0.1:8787/status</a>\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Total threads:</strong> 8\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Started:</strong> Just now\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Total memory:</strong> 1.86 TiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" </table>\n", | |
" </div>\n", | |
" </div>\n", | |
"\n", | |
" <details style=\"margin-left: 48px;\">\n", | |
" <summary style=\"margin-bottom: 20px;\">\n", | |
" <h3 style=\"display: inline;\">Workers</h3>\n", | |
" </summary>\n", | |
"\n", | |
" \n", | |
" <div style=\"margin-bottom: 20px;\">\n", | |
" <div style=\"width: 24px; height: 24px; background-color: #DBF5FF; border: 3px solid #4CC9FF; border-radius: 5px; position: absolute;\"> </div>\n", | |
" <div style=\"margin-left: 48px;\">\n", | |
" <details>\n", | |
" <summary>\n", | |
" <h4 style=\"margin-bottom: 0px; display: inline;\">Worker: 0</h4>\n", | |
" </summary>\n", | |
" <table style=\"width: 100%; text-align: left;\">\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Comm: </strong> tcp://127.0.0.1:41181\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Total threads: </strong> 1\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Dashboard: </strong> <a href=\"http://127.0.0.1:33463/status\" target=\"_blank\">http://127.0.0.1:33463/status</a>\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Memory: </strong> 237.50 GiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Nanny: </strong> tcp://127.0.0.1:34937\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\"></td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td colspan=\"2\" style=\"text-align: left;\">\n", | |
" <strong>Local directory: </strong> /tmp/dask-scratch-space/worker-x_85b63u\n", | |
" </td>\n", | |
" </tr>\n", | |
"\n", | |
" \n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>GPU: </strong>NVIDIA A100-SXM4-80GB\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>GPU memory: </strong> 80.00 GiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" \n", | |
"\n", | |
" \n", | |
"\n", | |
" </table>\n", | |
" </details>\n", | |
" </div>\n", | |
" </div>\n", | |
" \n", | |
" <div style=\"margin-bottom: 20px;\">\n", | |
" <div style=\"width: 24px; height: 24px; background-color: #DBF5FF; border: 3px solid #4CC9FF; border-radius: 5px; position: absolute;\"> </div>\n", | |
" <div style=\"margin-left: 48px;\">\n", | |
" <details>\n", | |
" <summary>\n", | |
" <h4 style=\"margin-bottom: 0px; display: inline;\">Worker: 1</h4>\n", | |
" </summary>\n", | |
" <table style=\"width: 100%; text-align: left;\">\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Comm: </strong> tcp://127.0.0.1:39227\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Total threads: </strong> 1\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Dashboard: </strong> <a href=\"http://127.0.0.1:41193/status\" target=\"_blank\">http://127.0.0.1:41193/status</a>\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Memory: </strong> 237.50 GiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Nanny: </strong> tcp://127.0.0.1:35679\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\"></td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td colspan=\"2\" style=\"text-align: left;\">\n", | |
" <strong>Local directory: </strong> /tmp/dask-scratch-space/worker-1w4snx39\n", | |
" </td>\n", | |
" </tr>\n", | |
"\n", | |
" \n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>GPU: </strong>NVIDIA A100-SXM4-80GB\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>GPU memory: </strong> 80.00 GiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" \n", | |
"\n", | |
" \n", | |
"\n", | |
" </table>\n", | |
" </details>\n", | |
" </div>\n", | |
" </div>\n", | |
" \n", | |
" <div style=\"margin-bottom: 20px;\">\n", | |
" <div style=\"width: 24px; height: 24px; background-color: #DBF5FF; border: 3px solid #4CC9FF; border-radius: 5px; position: absolute;\"> </div>\n", | |
" <div style=\"margin-left: 48px;\">\n", | |
" <details>\n", | |
" <summary>\n", | |
" <h4 style=\"margin-bottom: 0px; display: inline;\">Worker: 2</h4>\n", | |
" </summary>\n", | |
" <table style=\"width: 100%; text-align: left;\">\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Comm: </strong> tcp://127.0.0.1:42895\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Total threads: </strong> 1\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Dashboard: </strong> <a href=\"http://127.0.0.1:38303/status\" target=\"_blank\">http://127.0.0.1:38303/status</a>\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Memory: </strong> 237.50 GiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Nanny: </strong> tcp://127.0.0.1:39305\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\"></td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td colspan=\"2\" style=\"text-align: left;\">\n", | |
" <strong>Local directory: </strong> /tmp/dask-scratch-space/worker-frdj5cke\n", | |
" </td>\n", | |
" </tr>\n", | |
"\n", | |
" \n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>GPU: </strong>NVIDIA A100-SXM4-80GB\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>GPU memory: </strong> 80.00 GiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" \n", | |
"\n", | |
" \n", | |
"\n", | |
" </table>\n", | |
" </details>\n", | |
" </div>\n", | |
" </div>\n", | |
" \n", | |
" <div style=\"margin-bottom: 20px;\">\n", | |
" <div style=\"width: 24px; height: 24px; background-color: #DBF5FF; border: 3px solid #4CC9FF; border-radius: 5px; position: absolute;\"> </div>\n", | |
" <div style=\"margin-left: 48px;\">\n", | |
" <details>\n", | |
" <summary>\n", | |
" <h4 style=\"margin-bottom: 0px; display: inline;\">Worker: 3</h4>\n", | |
" </summary>\n", | |
" <table style=\"width: 100%; text-align: left;\">\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Comm: </strong> tcp://127.0.0.1:46437\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Total threads: </strong> 1\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Dashboard: </strong> <a href=\"http://127.0.0.1:34027/status\" target=\"_blank\">http://127.0.0.1:34027/status</a>\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Memory: </strong> 237.50 GiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Nanny: </strong> tcp://127.0.0.1:36439\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\"></td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td colspan=\"2\" style=\"text-align: left;\">\n", | |
" <strong>Local directory: </strong> /tmp/dask-scratch-space/worker-v01hcrk_\n", | |
" </td>\n", | |
" </tr>\n", | |
"\n", | |
" \n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>GPU: </strong>NVIDIA A100-SXM4-80GB\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>GPU memory: </strong> 80.00 GiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" \n", | |
"\n", | |
" \n", | |
"\n", | |
" </table>\n", | |
" </details>\n", | |
" </div>\n", | |
" </div>\n", | |
" \n", | |
" <div style=\"margin-bottom: 20px;\">\n", | |
" <div style=\"width: 24px; height: 24px; background-color: #DBF5FF; border: 3px solid #4CC9FF; border-radius: 5px; position: absolute;\"> </div>\n", | |
" <div style=\"margin-left: 48px;\">\n", | |
" <details>\n", | |
" <summary>\n", | |
" <h4 style=\"margin-bottom: 0px; display: inline;\">Worker: 4</h4>\n", | |
" </summary>\n", | |
" <table style=\"width: 100%; text-align: left;\">\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Comm: </strong> tcp://127.0.0.1:42063\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Total threads: </strong> 1\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Dashboard: </strong> <a href=\"http://127.0.0.1:41045/status\" target=\"_blank\">http://127.0.0.1:41045/status</a>\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Memory: </strong> 237.50 GiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Nanny: </strong> tcp://127.0.0.1:45531\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\"></td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td colspan=\"2\" style=\"text-align: left;\">\n", | |
" <strong>Local directory: </strong> /tmp/dask-scratch-space/worker-lp1fcy5q\n", | |
" </td>\n", | |
" </tr>\n", | |
"\n", | |
" \n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>GPU: </strong>NVIDIA A100-SXM4-80GB\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>GPU memory: </strong> 80.00 GiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" \n", | |
"\n", | |
" \n", | |
"\n", | |
" </table>\n", | |
" </details>\n", | |
" </div>\n", | |
" </div>\n", | |
" \n", | |
" <div style=\"margin-bottom: 20px;\">\n", | |
" <div style=\"width: 24px; height: 24px; background-color: #DBF5FF; border: 3px solid #4CC9FF; border-radius: 5px; position: absolute;\"> </div>\n", | |
" <div style=\"margin-left: 48px;\">\n", | |
" <details>\n", | |
" <summary>\n", | |
" <h4 style=\"margin-bottom: 0px; display: inline;\">Worker: 5</h4>\n", | |
" </summary>\n", | |
" <table style=\"width: 100%; text-align: left;\">\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Comm: </strong> tcp://127.0.0.1:37545\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Total threads: </strong> 1\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Dashboard: </strong> <a href=\"http://127.0.0.1:46027/status\" target=\"_blank\">http://127.0.0.1:46027/status</a>\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Memory: </strong> 237.50 GiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Nanny: </strong> tcp://127.0.0.1:43059\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\"></td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td colspan=\"2\" style=\"text-align: left;\">\n", | |
" <strong>Local directory: </strong> /tmp/dask-scratch-space/worker-c2mqyb7l\n", | |
" </td>\n", | |
" </tr>\n", | |
"\n", | |
" \n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>GPU: </strong>NVIDIA A100-SXM4-80GB\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>GPU memory: </strong> 80.00 GiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" \n", | |
"\n", | |
" \n", | |
"\n", | |
" </table>\n", | |
" </details>\n", | |
" </div>\n", | |
" </div>\n", | |
" \n", | |
" <div style=\"margin-bottom: 20px;\">\n", | |
" <div style=\"width: 24px; height: 24px; background-color: #DBF5FF; border: 3px solid #4CC9FF; border-radius: 5px; position: absolute;\"> </div>\n", | |
" <div style=\"margin-left: 48px;\">\n", | |
" <details>\n", | |
" <summary>\n", | |
" <h4 style=\"margin-bottom: 0px; display: inline;\">Worker: 6</h4>\n", | |
" </summary>\n", | |
" <table style=\"width: 100%; text-align: left;\">\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Comm: </strong> tcp://127.0.0.1:43897\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Total threads: </strong> 1\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Dashboard: </strong> <a href=\"http://127.0.0.1:41089/status\" target=\"_blank\">http://127.0.0.1:41089/status</a>\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Memory: </strong> 237.50 GiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Nanny: </strong> tcp://127.0.0.1:39879\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\"></td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td colspan=\"2\" style=\"text-align: left;\">\n", | |
" <strong>Local directory: </strong> /tmp/dask-scratch-space/worker-_0f_tcs6\n", | |
" </td>\n", | |
" </tr>\n", | |
"\n", | |
" \n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>GPU: </strong>NVIDIA A100-SXM4-80GB\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>GPU memory: </strong> 80.00 GiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" \n", | |
"\n", | |
" \n", | |
"\n", | |
" </table>\n", | |
" </details>\n", | |
" </div>\n", | |
" </div>\n", | |
" \n", | |
" <div style=\"margin-bottom: 20px;\">\n", | |
" <div style=\"width: 24px; height: 24px; background-color: #DBF5FF; border: 3px solid #4CC9FF; border-radius: 5px; position: absolute;\"> </div>\n", | |
" <div style=\"margin-left: 48px;\">\n", | |
" <details>\n", | |
" <summary>\n", | |
" <h4 style=\"margin-bottom: 0px; display: inline;\">Worker: 7</h4>\n", | |
" </summary>\n", | |
" <table style=\"width: 100%; text-align: left;\">\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Comm: </strong> tcp://127.0.0.1:39193\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Total threads: </strong> 1\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Dashboard: </strong> <a href=\"http://127.0.0.1:37869/status\" target=\"_blank\">http://127.0.0.1:37869/status</a>\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Memory: </strong> 237.50 GiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>Nanny: </strong> tcp://127.0.0.1:44741\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\"></td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td colspan=\"2\" style=\"text-align: left;\">\n", | |
" <strong>Local directory: </strong> /tmp/dask-scratch-space/worker-j7bszc0g\n", | |
" </td>\n", | |
" </tr>\n", | |
"\n", | |
" \n", | |
" <tr>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>GPU: </strong>NVIDIA A100-SXM4-80GB\n", | |
" </td>\n", | |
" <td style=\"text-align: left;\">\n", | |
" <strong>GPU memory: </strong> 80.00 GiB\n", | |
" </td>\n", | |
" </tr>\n", | |
" \n", | |
"\n", | |
" \n", | |
"\n", | |
" </table>\n", | |
" </details>\n", | |
" </div>\n", | |
" </div>\n", | |
" \n", | |
"\n", | |
" </details>\n", | |
"</div>\n", | |
"\n", | |
" </details>\n", | |
" </div>\n", | |
"</div>\n", | |
" </details>\n", | |
" \n", | |
"\n", | |
" </div>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
"<Client: 'tcp://127.0.0.1:42393' processes=8 threads=8, memory=1.86 TiB>" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES=preprocessing_gpus)\n", | |
"client = Client(cluster) \n", | |
"\n", | |
"set_mem()\n", | |
"client.run(set_mem)\n", | |
"\n", | |
"client" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "b4ac3268-f54b-40c6-9f56-6f0d942c6753", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import cudf\n", | |
"import cuml\n", | |
"import cupy as cp\n", | |
"from cuml.dask.common.part_utils import _extract_partitions\n", | |
"import math\n", | |
"from cuml.internals.memory_utils import with_cupy_rmm\n", | |
"import h5py\n", | |
"import rapids_singlecell as rsc" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "c8541cf5-9a40-4b1a-baea-309bfdc8d15a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from cupyx.scipy import sparse" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "87cb9613-bed9-4b7c-accc-e88b3d4bc415", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "3fc9a948-80ea-457f-8ded-151505a23418", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<KeysViewHDF5 ['Biotype', 'Chromosome', 'End', 'Gene', 'Start', 'ensembl_ids', 'feature_biotype', 'feature_is_filtered', 'feature_length', 'feature_name', 'feature_reference']>\n", | |
"CPU times: user 6.52 s, sys: 3.24 s, total: 9.76 s\n", | |
"Wall time: 1min 28s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<table>\n", | |
" <tr>\n", | |
" <td>\n", | |
" <table style=\"border-collapse: collapse;\">\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <td> </td>\n", | |
" <th> Array </th>\n", | |
" <th> Chunk </th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Bytes </th>\n", | |
" <td> 548.59 GiB </td>\n", | |
" <td> 11.06 GiB </td>\n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Shape </th>\n", | |
" <td> (2480956, 59357) </td>\n", | |
" <td> (50000, 59357) </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Dask graph </th>\n", | |
" <td colspan=\"2\"> 50 chunks in 1 graph layer </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Data type </th>\n", | |
" <td colspan=\"2\"> float32 numpy.ndarray </td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
" </table>\n", | |
" </td>\n", | |
" <td>\n", | |
" <svg width=\"80\" height=\"170\" style=\"stroke:rgb(0,0,0);stroke-width:1\" >\n", | |
"\n", | |
" <!-- Horizontal lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"30\" y2=\"0\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"0\" y1=\"4\" x2=\"30\" y2=\"4\" />\n", | |
" <line x1=\"0\" y1=\"12\" x2=\"30\" y2=\"12\" />\n", | |
" <line x1=\"0\" y1=\"16\" x2=\"30\" y2=\"16\" />\n", | |
" <line x1=\"0\" y1=\"24\" x2=\"30\" y2=\"24\" />\n", | |
" <line x1=\"0\" y1=\"31\" x2=\"30\" y2=\"31\" />\n", | |
" <line x1=\"0\" y1=\"36\" x2=\"30\" y2=\"36\" />\n", | |
" <line x1=\"0\" y1=\"43\" x2=\"30\" y2=\"43\" />\n", | |
" <line x1=\"0\" y1=\"50\" x2=\"30\" y2=\"50\" />\n", | |
" <line x1=\"0\" y1=\"55\" x2=\"30\" y2=\"55\" />\n", | |
" <line x1=\"0\" y1=\"62\" x2=\"30\" y2=\"62\" />\n", | |
" <line x1=\"0\" y1=\"67\" x2=\"30\" y2=\"67\" />\n", | |
" <line x1=\"0\" y1=\"74\" x2=\"30\" y2=\"74\" />\n", | |
" <line x1=\"0\" y1=\"82\" x2=\"30\" y2=\"82\" />\n", | |
" <line x1=\"0\" y1=\"87\" x2=\"30\" y2=\"87\" />\n", | |
" <line x1=\"0\" y1=\"94\" x2=\"30\" y2=\"94\" />\n", | |
" <line x1=\"0\" y1=\"101\" x2=\"30\" y2=\"101\" />\n", | |
" <line x1=\"0\" y1=\"106\" x2=\"30\" y2=\"106\" />\n", | |
" <line x1=\"0\" y1=\"113\" x2=\"30\" y2=\"113\" />\n", | |
" <line x1=\"0\" y1=\"120\" x2=\"30\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Vertical lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"0\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"30\" y1=\"0\" x2=\"30\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Colored Rectangle -->\n", | |
" <polygon points=\"0.0,0.0 30.05538364053684,0.0 30.05538364053684,120.0 0.0,120.0\" style=\"fill:#8B4903A0;stroke-width:0\"/>\n", | |
"\n", | |
" <!-- Text -->\n", | |
" <text x=\"15.027692\" y=\"140.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" >59357</text>\n", | |
" <text x=\"50.055384\" y=\"60.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" transform=\"rotate(-90,50.055384,60.000000)\">2480956</text>\n", | |
"</svg>\n", | |
" </td>\n", | |
" </tr>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"dask.array<concatenate, shape=(2480956, 59357), dtype=float32, chunksize=(50000, 59357), chunktype=numpy.ndarray>" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"def read_with_filter(client,\n", | |
" sample_file, batch_size = 50000):\n", | |
" \"\"\"\n", | |
" Reads an h5ad file and applies cell and geans count filter. Dask Array is\n", | |
" used allow partitioning the input file. This function supports multi-GPUs.\n", | |
" \"\"\"\n", | |
"\n", | |
" # Path in h5 file\n", | |
" _data = '/X/data'\n", | |
" _index = '/X/indices'\n", | |
" _indprt = '/X/indptr'\n", | |
" _genes = '/var/ensembl_ids'\n", | |
" #_genes = '/var/ensembl_id'\n", | |
" #_genes = '/var/_index'\n", | |
" #_genes = '/var/feature_id'\n", | |
" _barcodes = '/obs/_index'\n", | |
"\n", | |
" @dask.delayed\n", | |
" def _read_partition_to_sparse_matrix(sample_file,\n", | |
" total_cols, batch_start, batch_end,\n", | |
" ):\n", | |
" with h5py.File(sample_file, 'r') as h5f:\n", | |
" indptrs = h5f[_indprt]\n", | |
" start_ptr = indptrs[batch_start]\n", | |
" end_ptr = indptrs[batch_end]\n", | |
"\n", | |
" # Read all things data and index\n", | |
" sub_data = cp.array(h5f[_data][start_ptr:end_ptr])\n", | |
" sub_indices = cp.array(h5f[_index][start_ptr:end_ptr])\n", | |
"\n", | |
" # recompute the row pointer for the partial dataset\n", | |
" sub_indptrs = cp.array(indptrs[batch_start:(batch_end + 1)])\n", | |
" sub_indptrs = sub_indptrs - sub_indptrs[0]\n", | |
"\n", | |
" # Reconstruct partial sparse array\n", | |
" partial_sparse_array = cp.sparse.csr_matrix(\n", | |
" (sub_data, sub_indices, sub_indptrs),\n", | |
" shape=(batch_end - batch_start, total_cols))\n", | |
" \n", | |
" return partial_sparse_array\n", | |
"\n", | |
"\n", | |
" with h5py.File(sample_file, 'r') as h5f:\n", | |
" # Compute the number of cells to read\n", | |
" indptr = h5f[_indprt]\n", | |
" vars= h5f[\"/var/\"]\n", | |
" print(vars.keys())\n", | |
" genes = cudf.Series(h5f[_genes], dtype=cp.dtype('object'))\n", | |
"\n", | |
" total_cols = genes.shape[0]\n", | |
" max_cells = indptr.shape[0] - 1\n", | |
"\n", | |
" dls = []\n", | |
" for batch_start in range(0, max_cells, batch_size):\n", | |
" actual_batch_size = min(batch_size, max_cells - batch_start)\n", | |
" dls.append(dask.array.from_delayed(\n", | |
" (_read_partition_to_sparse_matrix)\n", | |
" (sample_file,\n", | |
" total_cols,\n", | |
" batch_start,\n", | |
" batch_start + actual_batch_size),\n", | |
" dtype=cp.float32,\n", | |
" shape=(actual_batch_size, total_cols)))\n", | |
"\n", | |
" dask_sparse_arr = dask.array.concatenate(dls)\n", | |
" dask_sparse_arr = dask_sparse_arr.persist()\n", | |
" return dask_sparse_arr\n", | |
"\n", | |
"dask_sparse_arr = read_with_filter(client, \"h5/human_brain.h5ad\", batch_size=50000)\n", | |
"dask_sparse_arr = dask_sparse_arr.persist()\n", | |
"\n", | |
"dask_sparse_arr.compute_chunk_sizes()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "a7976fd6-24d0-4dc9-a2d9-ab2bf0fb05c8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@with_cupy_rmm\n", | |
"def calc_qc_dask(client, csr_matrix, axis=0):\n", | |
" '''\n", | |
" Implements sum operation for dask array when the backend is cupy sparse csr matrix\n", | |
" '''\n", | |
" from rapids_singlecell.preprocessing._kernels._qc_kernels import _sparse_qc_csr\n", | |
" sparse_qc_csr = _sparse_qc_csr(csr_matrix.dtype)\n", | |
" sparse_qc_csr.compile()\n", | |
"\n", | |
" def __qc_calc(X):\n", | |
" sums_cells = cp.zeros(X.shape[0], dtype=X.dtype)\n", | |
" sums_genes = cp.zeros((X.shape[1],1), dtype=X.dtype)\n", | |
" cell_ex = cp.zeros(X.shape[0], dtype=cp.int32)\n", | |
" gene_ex = cp.zeros((X.shape[1],1), dtype=cp.int32)\n", | |
" block = (32,)\n", | |
" grid = (int(math.ceil(X.shape[0] / block[0])),)\n", | |
" sparse_qc_csr(\n", | |
" grid,\n", | |
" block,\n", | |
" (\n", | |
" X.indptr,\n", | |
" X.indices,\n", | |
" X.data,\n", | |
" sums_cells,\n", | |
" sums_genes,\n", | |
" cell_ex,\n", | |
" gene_ex,\n", | |
" X.shape[0],\n", | |
" ),\n", | |
" )\n", | |
" return sums_cells,sums_genes,cell_ex,gene_ex\n", | |
" parts = client.sync(_extract_partitions, csr_matrix)\n", | |
" futures = [client.submit(__qc_calc, part, workers=[w]) for w, part in parts]\n", | |
" # Gather results from futures\n", | |
" results = client.gather(futures)\n", | |
"\n", | |
" # Initialize lists to hold the Dask arrays\n", | |
" sums_cells_objs = []\n", | |
" sums_genes_objs = []\n", | |
" cell_ex_objs = []\n", | |
" gene_ex_objs = []\n", | |
"\n", | |
" # Process each result\n", | |
" for sums_cells, sums_genes, cell_ex, gene_ex in results:\n", | |
" # Append the arrays to their respective lists as Dask arrays\n", | |
" sums_cells_objs.append(dask.array.from_array(sums_cells, chunks=sums_cells.shape))\n", | |
" sums_genes_objs.append(dask.array.from_array(sums_genes, chunks=sums_genes.shape))\n", | |
" cell_ex_objs.append(dask.array.from_array(cell_ex, chunks=cell_ex.shape))\n", | |
" gene_ex_objs.append(dask.array.from_array(gene_ex, chunks=gene_ex.shape))\n", | |
" sums_cells = dask.array.concatenate(sums_cells_objs).compute().ravel()\n", | |
" sums_genes = dask.array.concatenate(sums_genes_objs,axis=1).compute().sum(axis=1).ravel()\n", | |
" cell_ex = dask.array.concatenate(cell_ex_objs).compute().ravel()\n", | |
" gene_ex = dask.array.concatenate(gene_ex_objs,axis=1).compute().sum(axis=1).ravel()\n", | |
" return sums_cells, sums_genes, cell_ex, gene_ex\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "2b9a4f2c-7d79-4316-b3b5-3b447322b1af", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/sedi10/conda/envs/rapids-23.12/lib/python3.10/site-packages/distributed/client.py:3163: UserWarning: Sending large graph of size 11.34 MiB.\n", | |
"This may cause some slowdown.\n", | |
"Consider scattering data ahead of time and using futures.\n", | |
" warnings.warn(\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 825 ms, sys: 214 ms, total: 1.04 s\n", | |
"Wall time: 1.29 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"sums_cells, sums_genes, cell_ex, gene_ex = calc_qc_dask(client, dask_sparse_arr)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "3250452d-fa10-47bf-9f00-9ef6e52fdf56", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 4.36 ms, sys: 2.5 ms, total: 6.86 ms\n", | |
"Wall time: 18.4 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"qc_cells = (cell_ex <= 10000) & (200 <= cell_ex)\n", | |
"qc_genes = (10 <= gene_ex)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "7772b50c-2e75-42e0-be7b-39db3c5bd8c6", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<timed exec>:1: PerformanceWarning: Slicing is producing a large chunk. To accept the large\n", | |
"chunk and silence this warning, set the option\n", | |
" >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):\n", | |
" ... array[indexer]\n", | |
"\n", | |
"To avoid creating the large chunks, set the option\n", | |
" >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):\n", | |
" ... array[indexer]\n", | |
"/home/sedi10/conda/envs/rapids-23.12/lib/python3.10/site-packages/distributed/client.py:3163: UserWarning: Sending large graph of size 18.56 MiB.\n", | |
"This may cause some slowdown.\n", | |
"Consider scattering data ahead of time and using futures.\n", | |
" warnings.warn(\n", | |
"<timed exec>:2: PerformanceWarning: Slicing is producing a large chunk. To accept the large\n", | |
"chunk and silence this warning, set the option\n", | |
" >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):\n", | |
" ... array[indexer]\n", | |
"\n", | |
"To avoid creating the large chunks, set the option\n", | |
" >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):\n", | |
" ... array[indexer]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 824 ms, sys: 626 ms, total: 1.45 s\n", | |
"Wall time: 15.5 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<table>\n", | |
" <tr>\n", | |
" <td>\n", | |
" <table style=\"border-collapse: collapse;\">\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <td> </td>\n", | |
" <th> Array </th>\n", | |
" <th> Chunk </th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Bytes </th>\n", | |
" <td> 426.35 GiB </td>\n", | |
" <td> 8.77 GiB </td>\n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Shape </th>\n", | |
" <td> (2432024, 47059) </td>\n", | |
" <td> (50000, 47059) </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Dask graph </th>\n", | |
" <td colspan=\"2\"> 50 chunks in 1 graph layer </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Data type </th>\n", | |
" <td colspan=\"2\"> float32 numpy.ndarray </td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
" </table>\n", | |
" </td>\n", | |
" <td>\n", | |
" <svg width=\"78\" height=\"170\" style=\"stroke:rgb(0,0,0);stroke-width:1\" >\n", | |
"\n", | |
" <!-- Horizontal lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"28\" y2=\"0\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"0\" y1=\"4\" x2=\"28\" y2=\"4\" />\n", | |
" <line x1=\"0\" y1=\"12\" x2=\"28\" y2=\"12\" />\n", | |
" <line x1=\"0\" y1=\"17\" x2=\"28\" y2=\"17\" />\n", | |
" <line x1=\"0\" y1=\"24\" x2=\"28\" y2=\"24\" />\n", | |
" <line x1=\"0\" y1=\"31\" x2=\"28\" y2=\"31\" />\n", | |
" <line x1=\"0\" y1=\"36\" x2=\"28\" y2=\"36\" />\n", | |
" <line x1=\"0\" y1=\"43\" x2=\"28\" y2=\"43\" />\n", | |
" <line x1=\"0\" y1=\"51\" x2=\"28\" y2=\"51\" />\n", | |
" <line x1=\"0\" y1=\"55\" x2=\"28\" y2=\"55\" />\n", | |
" <line x1=\"0\" y1=\"63\" x2=\"28\" y2=\"63\" />\n", | |
" <line x1=\"0\" y1=\"68\" x2=\"28\" y2=\"68\" />\n", | |
" <line x1=\"0\" y1=\"75\" x2=\"28\" y2=\"75\" />\n", | |
" <line x1=\"0\" y1=\"82\" x2=\"28\" y2=\"82\" />\n", | |
" <line x1=\"0\" y1=\"87\" x2=\"28\" y2=\"87\" />\n", | |
" <line x1=\"0\" y1=\"94\" x2=\"28\" y2=\"94\" />\n", | |
" <line x1=\"0\" y1=\"101\" x2=\"28\" y2=\"101\" />\n", | |
" <line x1=\"0\" y1=\"106\" x2=\"28\" y2=\"106\" />\n", | |
" <line x1=\"0\" y1=\"113\" x2=\"28\" y2=\"113\" />\n", | |
" <line x1=\"0\" y1=\"120\" x2=\"28\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Vertical lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"0\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"28\" y1=\"0\" x2=\"28\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Colored Rectangle -->\n", | |
" <polygon points=\"0.0,0.0 28.845202251346922,0.0 28.845202251346922,120.0 0.0,120.0\" style=\"fill:#8B4903A0;stroke-width:0\"/>\n", | |
"\n", | |
" <!-- Text -->\n", | |
" <text x=\"14.422601\" y=\"140.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" >47059</text>\n", | |
" <text x=\"48.845202\" y=\"60.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" transform=\"rotate(-90,48.845202,60.000000)\">2432024</text>\n", | |
"</svg>\n", | |
" </td>\n", | |
" </tr>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"dask.array<getitem, shape=(2432024, 47059), dtype=float32, chunksize=(50000, 47059), chunktype=numpy.ndarray>" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"dask_sparse_arr = dask_sparse_arr[qc_cells,:].persist()\n", | |
"dask_sparse_arr = dask_sparse_arr[:,qc_genes].persist()\n", | |
"dask_sparse_arr.compute_chunk_sizes()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "fcaa4d7c-743c-449a-b895-63af2cc5c29b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 69 ms, sys: 4.14 ms, total: 73.2 ms\n", | |
"Wall time: 84.6 ms\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<table>\n", | |
" <tr>\n", | |
" <td>\n", | |
" <table style=\"border-collapse: collapse;\">\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <td> </td>\n", | |
" <th> Array </th>\n", | |
" <th> Chunk </th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Bytes </th>\n", | |
" <td> 426.35 GiB </td>\n", | |
" <td> 8.77 GiB </td>\n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Shape </th>\n", | |
" <td> (2432024, 47059) </td>\n", | |
" <td> (50000, 47059) </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Dask graph </th>\n", | |
" <td colspan=\"2\"> 50 chunks in 1 graph layer </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Data type </th>\n", | |
" <td colspan=\"2\"> float32 cupy.ndarray </td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
" </table>\n", | |
" </td>\n", | |
" <td>\n", | |
" <svg width=\"78\" height=\"170\" style=\"stroke:rgb(0,0,0);stroke-width:1\" >\n", | |
"\n", | |
" <!-- Horizontal lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"28\" y2=\"0\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"0\" y1=\"4\" x2=\"28\" y2=\"4\" />\n", | |
" <line x1=\"0\" y1=\"12\" x2=\"28\" y2=\"12\" />\n", | |
" <line x1=\"0\" y1=\"17\" x2=\"28\" y2=\"17\" />\n", | |
" <line x1=\"0\" y1=\"24\" x2=\"28\" y2=\"24\" />\n", | |
" <line x1=\"0\" y1=\"31\" x2=\"28\" y2=\"31\" />\n", | |
" <line x1=\"0\" y1=\"36\" x2=\"28\" y2=\"36\" />\n", | |
" <line x1=\"0\" y1=\"43\" x2=\"28\" y2=\"43\" />\n", | |
" <line x1=\"0\" y1=\"51\" x2=\"28\" y2=\"51\" />\n", | |
" <line x1=\"0\" y1=\"55\" x2=\"28\" y2=\"55\" />\n", | |
" <line x1=\"0\" y1=\"63\" x2=\"28\" y2=\"63\" />\n", | |
" <line x1=\"0\" y1=\"68\" x2=\"28\" y2=\"68\" />\n", | |
" <line x1=\"0\" y1=\"75\" x2=\"28\" y2=\"75\" />\n", | |
" <line x1=\"0\" y1=\"82\" x2=\"28\" y2=\"82\" />\n", | |
" <line x1=\"0\" y1=\"87\" x2=\"28\" y2=\"87\" />\n", | |
" <line x1=\"0\" y1=\"94\" x2=\"28\" y2=\"94\" />\n", | |
" <line x1=\"0\" y1=\"101\" x2=\"28\" y2=\"101\" />\n", | |
" <line x1=\"0\" y1=\"106\" x2=\"28\" y2=\"106\" />\n", | |
" <line x1=\"0\" y1=\"113\" x2=\"28\" y2=\"113\" />\n", | |
" <line x1=\"0\" y1=\"120\" x2=\"28\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Vertical lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"0\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"28\" y1=\"0\" x2=\"28\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Colored Rectangle -->\n", | |
" <polygon points=\"0.0,0.0 28.845202251346922,0.0 28.845202251346922,120.0 0.0,120.0\" style=\"fill:#8B4903A0;stroke-width:0\"/>\n", | |
"\n", | |
" <!-- Text -->\n", | |
" <text x=\"14.422601\" y=\"140.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" >47059</text>\n", | |
" <text x=\"48.845202\" y=\"60.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" transform=\"rotate(-90,48.845202,60.000000)\">2432024</text>\n", | |
"</svg>\n", | |
" </td>\n", | |
" </tr>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"dask.array<lambda, shape=(2432024, 47059), dtype=float32, chunksize=(50000, 47059), chunktype=cupy.ndarray>" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"from rapids_singlecell.preprocessing._kernels._norm_kernel import _mul_csr\n", | |
"\n", | |
"mul_kernel = _mul_csr(dask_sparse_arr.dtype)\n", | |
"mul_kernel.compile()\n", | |
"def norm(X, target_sum = 10000):\n", | |
" mul_kernel(\n", | |
" (math.ceil(X.shape[0] / 128),),\n", | |
" (128,),\n", | |
" (X.indptr, X.data, X.shape[0], int(target_sum)),\n", | |
" )\n", | |
" return X\n", | |
"dask_sparse_arr = dask_sparse_arr.map_blocks(lambda X: norm(X),dtype=cp.float32,meta=cp.array((0,),dtype=dask_sparse_arr.dtype))\n", | |
"dask_sparse_arr = dask_sparse_arr.persist()\n", | |
"dask_sparse_arr.compute_chunk_sizes()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "28107a4a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 74.7 ms, sys: 10.7 ms, total: 85.4 ms\n", | |
"Wall time: 194 ms\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<table>\n", | |
" <tr>\n", | |
" <td>\n", | |
" <table style=\"border-collapse: collapse;\">\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <td> </td>\n", | |
" <th> Array </th>\n", | |
" <th> Chunk </th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Bytes </th>\n", | |
" <td> 426.35 GiB </td>\n", | |
" <td> 8.77 GiB </td>\n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Shape </th>\n", | |
" <td> (2432024, 47059) </td>\n", | |
" <td> (50000, 47059) </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Dask graph </th>\n", | |
" <td colspan=\"2\"> 50 chunks in 1 graph layer </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Data type </th>\n", | |
" <td colspan=\"2\"> float32 cupy.ndarray </td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
" </table>\n", | |
" </td>\n", | |
" <td>\n", | |
" <svg width=\"78\" height=\"170\" style=\"stroke:rgb(0,0,0);stroke-width:1\" >\n", | |
"\n", | |
" <!-- Horizontal lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"28\" y2=\"0\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"0\" y1=\"4\" x2=\"28\" y2=\"4\" />\n", | |
" <line x1=\"0\" y1=\"12\" x2=\"28\" y2=\"12\" />\n", | |
" <line x1=\"0\" y1=\"17\" x2=\"28\" y2=\"17\" />\n", | |
" <line x1=\"0\" y1=\"24\" x2=\"28\" y2=\"24\" />\n", | |
" <line x1=\"0\" y1=\"31\" x2=\"28\" y2=\"31\" />\n", | |
" <line x1=\"0\" y1=\"36\" x2=\"28\" y2=\"36\" />\n", | |
" <line x1=\"0\" y1=\"43\" x2=\"28\" y2=\"43\" />\n", | |
" <line x1=\"0\" y1=\"51\" x2=\"28\" y2=\"51\" />\n", | |
" <line x1=\"0\" y1=\"55\" x2=\"28\" y2=\"55\" />\n", | |
" <line x1=\"0\" y1=\"63\" x2=\"28\" y2=\"63\" />\n", | |
" <line x1=\"0\" y1=\"68\" x2=\"28\" y2=\"68\" />\n", | |
" <line x1=\"0\" y1=\"75\" x2=\"28\" y2=\"75\" />\n", | |
" <line x1=\"0\" y1=\"82\" x2=\"28\" y2=\"82\" />\n", | |
" <line x1=\"0\" y1=\"87\" x2=\"28\" y2=\"87\" />\n", | |
" <line x1=\"0\" y1=\"94\" x2=\"28\" y2=\"94\" />\n", | |
" <line x1=\"0\" y1=\"101\" x2=\"28\" y2=\"101\" />\n", | |
" <line x1=\"0\" y1=\"106\" x2=\"28\" y2=\"106\" />\n", | |
" <line x1=\"0\" y1=\"113\" x2=\"28\" y2=\"113\" />\n", | |
" <line x1=\"0\" y1=\"120\" x2=\"28\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Vertical lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"0\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"28\" y1=\"0\" x2=\"28\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Colored Rectangle -->\n", | |
" <polygon points=\"0.0,0.0 28.845202251346922,0.0 28.845202251346922,120.0 0.0,120.0\" style=\"fill:#8B4903A0;stroke-width:0\"/>\n", | |
"\n", | |
" <!-- Text -->\n", | |
" <text x=\"14.422601\" y=\"140.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" >47059</text>\n", | |
" <text x=\"48.845202\" y=\"60.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" transform=\"rotate(-90,48.845202,60.000000)\">2432024</text>\n", | |
"</svg>\n", | |
" </td>\n", | |
" </tr>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"dask.array<lambda, shape=(2432024, 47059), dtype=float32, chunksize=(50000, 47059), chunktype=cupy.ndarray>" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"dask_sparse_arr = dask_sparse_arr.map_blocks(lambda X: X.log1p(),dtype=cp.float32, meta=cp.array((0,),dtype=dask_sparse_arr.dtype))\n", | |
"dask_sparse_arr = dask_sparse_arr.persist()\n", | |
"dask_sparse_arr.compute_chunk_sizes()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "cf816ba6-55d2-45ae-aad9-5f3b904aae01", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pandas as pd\n", | |
"import warnings" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "39316898-d771-45a1-a675-e60279a96ab5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"\n", | |
"@with_cupy_rmm\n", | |
"def get_mean_var_dask(client, csr_matrix):\n", | |
" '''\n", | |
" Implements sum operation for dask array when the backend is cupy sparse csr matrix\n", | |
" '''\n", | |
" from rapids_singlecell.preprocessing._kernels._mean_var_kernel import _get_mean_var_minor\n", | |
" get_mean_var_minor = _get_mean_var_minor(csr_matrix.dtype)\n", | |
" get_mean_var_minor.compile()\n", | |
"\n", | |
" def __mean_var(X, minor, major):\n", | |
" mean = cp.zeros((minor,1), dtype=cp.float64)\n", | |
" var = cp.zeros((minor,1), dtype=cp.float64)\n", | |
" block = (32,)\n", | |
" grid = (int(math.ceil(X.nnz / block[0])),)\n", | |
" get_mean_var_minor(grid, block, (X.indices, X.data, mean, var, major, X.nnz))\n", | |
" return mean,var\n", | |
" major = csr_matrix.shape[0]\n", | |
" minor = csr_matrix.shape[1]\n", | |
" parts = client.sync(_extract_partitions, csr_matrix)\n", | |
" futures = [client.submit(__mean_var, part,minor, major, workers=[w]) for w, part in parts]\n", | |
" # Gather results from futures\n", | |
" results = client.gather(futures)\n", | |
"\n", | |
" # Initialize lists to hold the Dask arrays\n", | |
" means_objs = []\n", | |
" var_objs = []\n", | |
"\n", | |
" # Process each result\n", | |
" for means, vars in results:\n", | |
" # Append the arrays to their respective lists as Dask arrays\n", | |
" means_objs.append(dask.array.from_array(means, chunks=means.shape))\n", | |
" var_objs.append(dask.array.from_array(vars, chunks=vars.shape))\n", | |
" mean = dask.array.concatenate(means_objs,axis=1).compute().sum(axis=1).ravel()\n", | |
" var = dask.array.concatenate(var_objs,axis=1).compute().sum(axis=1).ravel()\n", | |
"\n", | |
" var = (var - mean**2) * (major / (major - 1))\n", | |
" return mean, var\n", | |
"\n", | |
"def highly_variable_genes_filter(client,\n", | |
" data_mat,\n", | |
" n_top_genes=None):\n", | |
"\n", | |
"\n", | |
" mean, variance = get_mean_var_dask(client, data_mat)\n", | |
" dispersion = variance / mean\n", | |
"\n", | |
" df = pd.DataFrame()\n", | |
" df['genes'] = np.arange(data_mat.shape[1])\n", | |
" df['means'] = mean.tolist()\n", | |
" df['dispersions'] = dispersion.tolist()\n", | |
" df['mean_bin'] = pd.cut(\n", | |
" df['means'],\n", | |
" np.r_[-np.inf, np.percentile(df['means'], np.arange(10, 105, 5)), np.inf],\n", | |
" )\n", | |
"\n", | |
" disp_grouped = df.groupby('mean_bin')['dispersions']\n", | |
" disp_median_bin = disp_grouped.median()\n", | |
"\n", | |
" with warnings.catch_warnings():\n", | |
" from statsmodels import robust\n", | |
" warnings.simplefilter('ignore')\n", | |
" disp_mad_bin = disp_grouped.apply(robust.mad)\n", | |
" df['dispersions_norm'] = (\n", | |
" df['dispersions'].values - disp_median_bin[df['mean_bin'].values].values\n", | |
" ) / disp_mad_bin[df['mean_bin'].values].values\n", | |
"\n", | |
" dispersion_norm = df['dispersions_norm'].values\n", | |
"\n", | |
" dispersion_norm = dispersion_norm[~np.isnan(dispersion_norm)]\n", | |
" dispersion_norm[::-1].sort()\n", | |
"\n", | |
" if n_top_genes > df.shape[0]:\n", | |
" n_top_genes = df.shape[0]\n", | |
"\n", | |
" disp_cut_off = dispersion_norm[n_top_genes - 1]\n", | |
" vaiable_genes = np.nan_to_num(df['dispersions_norm'].values) >= disp_cut_off\n", | |
"\n", | |
" return vaiable_genes\n", | |
" \n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "bb84f792-2b06-4bfa-8cca-3b4eaec042ea", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/sedi10/conda/envs/rapids-23.12/lib/python3.10/site-packages/distributed/client.py:3163: UserWarning: Sending large graph of size 17.97 MiB.\n", | |
"This may cause some slowdown.\n", | |
"Consider scattering data ahead of time and using futures.\n", | |
" warnings.warn(\n", | |
"<timed exec>:2: PerformanceWarning: Slicing is producing a large chunk. To accept the large\n", | |
"chunk and silence this warning, set the option\n", | |
" >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):\n", | |
" ... array[indexer]\n", | |
"\n", | |
"To avoid creating the large chunks, set the option\n", | |
" >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):\n", | |
" ... array[indexer]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1.6 s, sys: 503 ms, total: 2.11 s\n", | |
"Wall time: 10.3 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<table>\n", | |
" <tr>\n", | |
" <td>\n", | |
" <table style=\"border-collapse: collapse;\">\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <td> </td>\n", | |
" <th> Array </th>\n", | |
" <th> Chunk </th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Bytes </th>\n", | |
" <td> 45.30 GiB </td>\n", | |
" <td> 0.93 GiB </td>\n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Shape </th>\n", | |
" <td> (2432024, 5000) </td>\n", | |
" <td> (50000, 5000) </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Dask graph </th>\n", | |
" <td colspan=\"2\"> 50 chunks in 1 graph layer </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Data type </th>\n", | |
" <td colspan=\"2\"> float32 cupy.ndarray </td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
" </table>\n", | |
" </td>\n", | |
" <td>\n", | |
" <svg width=\"75\" height=\"170\" style=\"stroke:rgb(0,0,0);stroke-width:1\" >\n", | |
"\n", | |
" <!-- Horizontal lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"25\" y2=\"0\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"0\" y1=\"4\" x2=\"25\" y2=\"4\" />\n", | |
" <line x1=\"0\" y1=\"12\" x2=\"25\" y2=\"12\" />\n", | |
" <line x1=\"0\" y1=\"17\" x2=\"25\" y2=\"17\" />\n", | |
" <line x1=\"0\" y1=\"24\" x2=\"25\" y2=\"24\" />\n", | |
" <line x1=\"0\" y1=\"31\" x2=\"25\" y2=\"31\" />\n", | |
" <line x1=\"0\" y1=\"36\" x2=\"25\" y2=\"36\" />\n", | |
" <line x1=\"0\" y1=\"43\" x2=\"25\" y2=\"43\" />\n", | |
" <line x1=\"0\" y1=\"51\" x2=\"25\" y2=\"51\" />\n", | |
" <line x1=\"0\" y1=\"55\" x2=\"25\" y2=\"55\" />\n", | |
" <line x1=\"0\" y1=\"63\" x2=\"25\" y2=\"63\" />\n", | |
" <line x1=\"0\" y1=\"68\" x2=\"25\" y2=\"68\" />\n", | |
" <line x1=\"0\" y1=\"75\" x2=\"25\" y2=\"75\" />\n", | |
" <line x1=\"0\" y1=\"82\" x2=\"25\" y2=\"82\" />\n", | |
" <line x1=\"0\" y1=\"87\" x2=\"25\" y2=\"87\" />\n", | |
" <line x1=\"0\" y1=\"94\" x2=\"25\" y2=\"94\" />\n", | |
" <line x1=\"0\" y1=\"101\" x2=\"25\" y2=\"101\" />\n", | |
" <line x1=\"0\" y1=\"106\" x2=\"25\" y2=\"106\" />\n", | |
" <line x1=\"0\" y1=\"113\" x2=\"25\" y2=\"113\" />\n", | |
" <line x1=\"0\" y1=\"120\" x2=\"25\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Vertical lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"0\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"25\" y1=\"0\" x2=\"25\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Colored Rectangle -->\n", | |
" <polygon points=\"0.0,0.0 25.412616514582485,0.0 25.412616514582485,120.0 0.0,120.0\" style=\"fill:#8B4903A0;stroke-width:0\"/>\n", | |
"\n", | |
" <!-- Text -->\n", | |
" <text x=\"12.706308\" y=\"140.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" >5000</text>\n", | |
" <text x=\"45.412617\" y=\"60.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" transform=\"rotate(-90,45.412617,60.000000)\">2432024</text>\n", | |
"</svg>\n", | |
" </td>\n", | |
" </tr>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"dask.array<getitem, shape=(2432024, 5000), dtype=float32, chunksize=(50000, 5000), chunktype=cupy.ndarray>" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"genes = highly_variable_genes_filter(client, dask_sparse_arr, n_top_genes=5000)\n", | |
"dask_sparse_arr = dask_sparse_arr[:,genes].persist()\n", | |
"dask_sparse_arr.compute_chunk_sizes()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "75761f4c-c178-49b8-a716-aefc5f31b276", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1.18 s, sys: 1.07 s, total: 2.26 s\n", | |
"Wall time: 28.2 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<table>\n", | |
" <tr>\n", | |
" <td>\n", | |
" <table style=\"border-collapse: collapse;\">\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <td> </td>\n", | |
" <th> Array </th>\n", | |
" <th> Chunk </th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Bytes </th>\n", | |
" <td> 90.60 GiB </td>\n", | |
" <td> 11.32 GiB </td>\n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Shape </th>\n", | |
" <td> (2432024, 5000) </td>\n", | |
" <td> (304003, 5000) </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Dask graph </th>\n", | |
" <td colspan=\"2\"> 8 chunks in 1 graph layer </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Data type </th>\n", | |
" <td colspan=\"2\"> float64 cupy.ndarray </td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
" </table>\n", | |
" </td>\n", | |
" <td>\n", | |
" <svg width=\"75\" height=\"170\" style=\"stroke:rgb(0,0,0);stroke-width:1\" >\n", | |
"\n", | |
" <!-- Horizontal lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"25\" y2=\"0\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"0\" y1=\"15\" x2=\"25\" y2=\"15\" />\n", | |
" <line x1=\"0\" y1=\"30\" x2=\"25\" y2=\"30\" />\n", | |
" <line x1=\"0\" y1=\"45\" x2=\"25\" y2=\"45\" />\n", | |
" <line x1=\"0\" y1=\"60\" x2=\"25\" y2=\"60\" />\n", | |
" <line x1=\"0\" y1=\"75\" x2=\"25\" y2=\"75\" />\n", | |
" <line x1=\"0\" y1=\"90\" x2=\"25\" y2=\"90\" />\n", | |
" <line x1=\"0\" y1=\"105\" x2=\"25\" y2=\"105\" />\n", | |
" <line x1=\"0\" y1=\"120\" x2=\"25\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Vertical lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"0\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"25\" y1=\"0\" x2=\"25\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Colored Rectangle -->\n", | |
" <polygon points=\"0.0,0.0 25.412616514582485,0.0 25.412616514582485,120.0 0.0,120.0\" style=\"fill:#ECB172A0;stroke-width:0\"/>\n", | |
"\n", | |
" <!-- Text -->\n", | |
" <text x=\"12.706308\" y=\"140.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" >5000</text>\n", | |
" <text x=\"45.412617\" y=\"60.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" transform=\"rotate(-90,45.412617,60.000000)\">2432024</text>\n", | |
"</svg>\n", | |
" </td>\n", | |
" </tr>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"dask.array<rechunk-merge, shape=(2432024, 5000), dtype=float64, chunksize=(304003, 5000), chunktype=cupy.ndarray>" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"dense_array = dask_sparse_arr.map_blocks(lambda x: x.todense().astype(cp.float64), dtype=cp.float64,meta=cp.array((0,),dtype=cp.float64))\n", | |
"\n", | |
"n_rows = dense_array.shape[0]\n", | |
"n_cols = dense_array.shape[1]\n", | |
"cols_per_worker = int(n_rows / 8)\n", | |
"dense_array = dense_array.rechunk((cols_per_worker, n_cols)).persist()\n", | |
"\n", | |
"dense_array.compute_chunk_sizes()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "ae91f9cd-bbaf-42f9-a75b-d2a363f457a2", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1.17 s, sys: 1.12 s, total: 2.29 s\n", | |
"Wall time: 36.4 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<table>\n", | |
" <tr>\n", | |
" <td>\n", | |
" <table style=\"border-collapse: collapse;\">\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <td> </td>\n", | |
" <th> Array </th>\n", | |
" <th> Chunk </th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Bytes </th>\n", | |
" <td> 1.81 GiB </td>\n", | |
" <td> 231.94 MiB </td>\n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Shape </th>\n", | |
" <td> (2432024, 100) </td>\n", | |
" <td> (304003, 100) </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Dask graph </th>\n", | |
" <td colspan=\"2\"> 8 chunks in 19 graph layers </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Data type </th>\n", | |
" <td colspan=\"2\"> float64 cupy.ndarray </td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
" </table>\n", | |
" </td>\n", | |
" <td>\n", | |
" <svg width=\"75\" height=\"170\" style=\"stroke:rgb(0,0,0);stroke-width:1\" >\n", | |
"\n", | |
" <!-- Horizontal lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"25\" y2=\"0\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"0\" y1=\"15\" x2=\"25\" y2=\"15\" />\n", | |
" <line x1=\"0\" y1=\"30\" x2=\"25\" y2=\"30\" />\n", | |
" <line x1=\"0\" y1=\"45\" x2=\"25\" y2=\"45\" />\n", | |
" <line x1=\"0\" y1=\"60\" x2=\"25\" y2=\"60\" />\n", | |
" <line x1=\"0\" y1=\"75\" x2=\"25\" y2=\"75\" />\n", | |
" <line x1=\"0\" y1=\"90\" x2=\"25\" y2=\"90\" />\n", | |
" <line x1=\"0\" y1=\"105\" x2=\"25\" y2=\"105\" />\n", | |
" <line x1=\"0\" y1=\"120\" x2=\"25\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Vertical lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"0\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"25\" y1=\"0\" x2=\"25\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Colored Rectangle -->\n", | |
" <polygon points=\"0.0,0.0 25.41261651458249,0.0 25.41261651458249,120.0 0.0,120.0\" style=\"fill:#ECB172A0;stroke-width:0\"/>\n", | |
"\n", | |
" <!-- Text -->\n", | |
" <text x=\"12.706308\" y=\"140.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" >100</text>\n", | |
" <text x=\"45.412617\" y=\"60.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" transform=\"rotate(-90,45.412617,60.000000)\">2432024</text>\n", | |
"</svg>\n", | |
" </td>\n", | |
" </tr>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"dask.array<concatenate, shape=(2432024, 100), dtype=float64, chunksize=(304003, 100), chunktype=cupy.ndarray>" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"from cuml.dask.decomposition import PCA\n", | |
"pca_func = PCA(n_components=100)\n", | |
"pca_data_d = pca_func.fit_transform(dense_array)\n", | |
"pca_data_d.compute_chunk_sizes()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "b8ff7564-0b43-40e8-b920-d22ef9943729", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import time" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "9f0ece51-4022-4ffc-9b38-ff5ba2480f2a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "776d53f2-8a77-4f0e-addb-444068718821", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from cupyx import cusparse\n", | |
"class PCA_sparse_dask:\n", | |
" def __init__(self, n_components, client, whiten = False) -> None:\n", | |
" self.n_components = n_components\n", | |
" self.client = client\n", | |
" self.whiten = whiten\n", | |
" \n", | |
" def fit(self, x):\n", | |
" if self.n_components is None:\n", | |
" n_rows = x.shape[0]\n", | |
" n_cols = x.shape[1]\n", | |
" self.n_components_ = min(n_rows, n_cols)\n", | |
" else:\n", | |
" self.n_components_ = self.n_components\n", | |
"\n", | |
" self.n_samples_ = x.shape[0]\n", | |
" self.n_features_in_ = x.shape[1] if x.ndim == 2 else 1\n", | |
" self.dtype = x.dtype\n", | |
" covariance, self.mean_, _ = _cov_sparse_dask(self.client, x=x, return_mean=True)\n", | |
" self.explained_variance_, self.components_ = cp.linalg.eigh(\n", | |
" covariance, UPLO=\"U\"\n", | |
" )\n", | |
" # NOTE: We reverse the eigen vector and eigen values here\n", | |
" # because cupy provides them in ascending order. Make a copy otherwise\n", | |
" # it is not C_CONTIGUOUS anymore and would error when converting to\n", | |
" # CumlArray\n", | |
" self.explained_variance_ = self.explained_variance_[::-1]\n", | |
"\n", | |
" self.components_ = cp.flip(self.components_, axis=1)\n", | |
"\n", | |
" self.components_ = self.components_.T[: self.n_components_, :]\n", | |
"\n", | |
" self.explained_variance_ratio_ = self.explained_variance_ / cp.sum(\n", | |
" self.explained_variance_\n", | |
" )\n", | |
" if self.n_components_ < min(self.n_samples_, self.n_features_in_):\n", | |
" self.noise_variance_ = \\\n", | |
" self.explained_variance_[self.n_components_:].mean()\n", | |
" else:\n", | |
" self.noise_variance_ = cp.array([0.0])\n", | |
" self.explained_variance_ = self.explained_variance_[: self.n_components_]\n", | |
"\n", | |
" self.explained_variance_ratio_ = self.explained_variance_ratio_[\n", | |
" : self.n_components_\n", | |
" ]\n", | |
" # Truncating negative explained variance values to 0\n", | |
" self.singular_values_ = \\\n", | |
" cp.where(self.explained_variance_ < 0, 0,\n", | |
" self.explained_variance_)\n", | |
" self.singular_values_ = \\\n", | |
" cp.sqrt(self.singular_values_ * (self.n_samples_ - 1))\n", | |
" return self\n", | |
"\n", | |
" def transform(self, X):\n", | |
"\n", | |
" if self.whiten:\n", | |
" self.components_ *= cp.sqrt(self.n_samples_ - 1)\n", | |
" self.components_ /= self.singular_values_.reshape((-1, 1))\n", | |
" \n", | |
" def _transform(X_part, mean_, components_):\n", | |
" dense = cusparse.csr2dense(X_part)\n", | |
" dense = dense - mean_\n", | |
" X_pca = dense.dot(components_.T)\n", | |
" return X_pca\n", | |
"\n", | |
" X_pca = X.map_blocks(_transform, \n", | |
" mean_=self.mean_, \n", | |
" components_=self.components_, \n", | |
" dtype=X.dtype, \n", | |
" meta=cp.array((0,),dtype=X.dtype))\n", | |
"\n", | |
" if self.whiten:\n", | |
" self.components_ *= self.singular_values_.reshape((-1, 1))\n", | |
" self.components_ *= (1 / cp.sqrt(self.n_samples_ - 1))\n", | |
"\n", | |
" #self.components_ = self.components_.get()\n", | |
" #self.explained_variance_ = self.explained_variance_.get()\n", | |
" #self.explained_variance_ratio_ = self.explained_variance_ratio_.get()\n", | |
" return X_pca.persist()\n", | |
"\n", | |
" def fit_transform(self, X, y=None):\n", | |
" return self.fit(X).transform(X)\n", | |
"\n", | |
" def inverse_transform(self, X, return_sparse=False,\n", | |
" sparse_tol=1e-10):\n", | |
"\n", | |
" # NOTE: All intermediate calculations are done using cupy.ndarray and\n", | |
" # then converted to CumlArray at the end to minimize conversions\n", | |
" # between types\n", | |
"\n", | |
" if self.whiten:\n", | |
" cp.multiply(self.components_,\n", | |
" (1 / cp.sqrt(self.n_samples_ - 1)),\n", | |
" out=self.components_)\n", | |
" cp.multiply(self.components_,\n", | |
" self.singular_values_.reshape((-1, 1)),\n", | |
" out=self.components_)\n", | |
"\n", | |
" def _inv_transform(X_part, mean_, components_):\n", | |
" X_inv = cp.dot(X_part, self.components_)\n", | |
" cp.add(X_inv, self.mean_, out=X_inv)\n", | |
" return X_inv\n", | |
"\n", | |
" X_inv = X.map_blocks(_inv_transform, \n", | |
" mean_=self.mean_, \n", | |
" components_=self.components_, \n", | |
" dtype=X.dtype, \n", | |
" meta=cp.array((0,),dtype=X.dtype))\n", | |
" X_inv = X_inv.persist()\n", | |
"\n", | |
" if self.whiten:\n", | |
" self.components_ /= self.singular_values_.reshape((-1, 1))\n", | |
" self.components_ *= cp.sqrt(self.n_samples_ - 1)\n", | |
"\n", | |
" if return_sparse:\n", | |
" def _ret_sparse(X_part, sparse_tol):\n", | |
" X_part = cp.where(X_part < sparse_tol, 0, X_inv)\n", | |
" X_part = cupyx.scipy.sparse.csr_matrix(X_inv)\n", | |
" return X_part\n", | |
"\n", | |
" X_inv = X_inv.map_blocks(_ret_sparse, sparse_tol=sparse_tol, dtype=X_inv.dtype)\n", | |
" return X_inv.persist()\n", | |
"\n", | |
" return X_inv.persist()\n", | |
"\n", | |
"\n", | |
"@with_cupy_rmm\n", | |
"def _cov_sparse_dask(client, x, return_gram=False, return_mean=False):\n", | |
" \"\"\"\n", | |
" Computes the mean and the covariance of matrix X of\n", | |
" the form Cov(X, X) = E(XX) - E(X)E(X)\n", | |
"\n", | |
" This is a temporary fix for\n", | |
" cuml issue #5475 and cupy issue #7699,\n", | |
" where the operation `x.T.dot(x)` did not work for\n", | |
" larger sparse matrices.\n", | |
"\n", | |
" Parameters\n", | |
" ----------\n", | |
"\n", | |
" x : cupyx.scipy.sparse of size (m, n)\n", | |
" return_gram : boolean (default = False)\n", | |
" If True, gram matrix of the form (1 / n) * X.T.dot(X)\n", | |
" will be returned.\n", | |
" When True, a copy will be created\n", | |
" to store the results of the covariance.\n", | |
" When False, the local gram matrix result\n", | |
" will be overwritten\n", | |
" return_mean: boolean (default = False)\n", | |
" If True, the Maximum Likelihood Estimate used to\n", | |
" calculate the mean of X and X will be returned,\n", | |
" of the form (1 / n) * mean(X) and (1 / n) * mean(X)\n", | |
"\n", | |
" Returns\n", | |
" -------\n", | |
"\n", | |
" result : cov(X, X) when return_gram and return_mean are False\n", | |
" cov(X, X), gram(X, X) when return_gram is True,\n", | |
" return_mean is False\n", | |
" cov(X, X), mean(X), mean(X) when return_gram is False,\n", | |
" return_mean is True\n", | |
" cov(X, X), gram(X, X), mean(X), mean(X)\n", | |
" when return_gram is True and return_mean is True\n", | |
" \"\"\"\n", | |
"\n", | |
" from rapids_singlecell.preprocessing._kernels._pca_sparse_kernel import (\n", | |
" _copy_kernel,\n", | |
" _cov_kernel,\n", | |
" _gramm_kernel_csr,\n", | |
" )\n", | |
"\n", | |
" compute_mean_cov = _gramm_kernel_csr(x.dtype)\n", | |
" compute_mean_cov.compile()\n", | |
"\n", | |
" def __gram_block(x_part, n_cols):\n", | |
" gram_matrix = cp.zeros((n_cols, n_cols), dtype=x.dtype)\n", | |
" \n", | |
" block = (128,)\n", | |
" grid = (x_part.shape[0],)\n", | |
" compute_mean_cov(\n", | |
" grid,\n", | |
" block,\n", | |
" (\n", | |
" x_part.indptr,\n", | |
" x_part.indices,\n", | |
" x_part.data,\n", | |
" x_part.shape[0],\n", | |
" n_cols,\n", | |
" gram_matrix,\n", | |
" ),\n", | |
" )\n", | |
" return gram_matrix\n", | |
"\n", | |
" parts = client.sync(_extract_partitions, x)\n", | |
" futures = [client.submit(__gram_block, part,x.shape[1], workers=[w]) for w, part in parts]\n", | |
" # Gather results from futures\n", | |
" objs = []\n", | |
" for i in range(len(futures)):\n", | |
" obj = dask.array.from_delayed(futures[i],\n", | |
" shape=(x.shape[1],x.shape[1]),\n", | |
" dtype=x.dtype)\n", | |
" objs.append(obj)\n", | |
" gram_matrix = dask.array.stack(objs).sum(axis=0).compute() \n", | |
" mean_x, _ = get_mean_var_dask(client, x)\n", | |
" mean_x = mean_x.astype(x.dtype)\n", | |
" copy_gram = _copy_kernel(x.dtype)\n", | |
" block = (32, 32)\n", | |
" grid = (math.ceil(x.shape[1] / block[0]), math.ceil(x.shape[1] / block[1]))\n", | |
" copy_gram(\n", | |
" grid,\n", | |
" block,\n", | |
" (gram_matrix, x.shape[1]),\n", | |
" )\n", | |
"\n", | |
" gram_matrix *= 1 / x.shape[0]\n", | |
"\n", | |
" if return_gram:\n", | |
" cov_result = cp.zeros(\n", | |
" (gram_matrix.shape[0], gram_matrix.shape[0]),\n", | |
" dtype=gram_matrix.dtype,\n", | |
" )\n", | |
" else:\n", | |
" cov_result = gram_matrix\n", | |
"\n", | |
" compute_cov = _cov_kernel(gram_matrix.dtype)\n", | |
"\n", | |
" block_size = (32, 32)\n", | |
" grid_size = (math.ceil(gram_matrix.shape[0] / 8),) * 2\n", | |
" compute_cov(\n", | |
" grid_size,\n", | |
" block_size,\n", | |
" (cov_result, gram_matrix, mean_x, mean_x, gram_matrix.shape[0]),\n", | |
" )\n", | |
"\n", | |
" if not return_gram and not return_mean:\n", | |
" return cov_result\n", | |
" elif return_gram and not return_mean:\n", | |
" return cov_result, gram_matrix\n", | |
" elif not return_gram and return_mean:\n", | |
" return cov_result, mean_x, mean_x\n", | |
" elif return_gram and return_mean:\n", | |
" return cov_result, gram_matrix, mean_x, mean_x\n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "c5065da6-356d-406d-ba6f-991b42f8dca5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table>\n", | |
" <tr>\n", | |
" <td>\n", | |
" <table style=\"border-collapse: collapse;\">\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <td> </td>\n", | |
" <th> Array </th>\n", | |
" <th> Chunk </th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Bytes </th>\n", | |
" <td> 90.60 GiB </td>\n", | |
" <td> 5.66 GiB </td>\n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Shape </th>\n", | |
" <td> (2432024, 5000) </td>\n", | |
" <td> (152001, 5000) </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Dask graph </th>\n", | |
" <td colspan=\"2\"> 17 chunks in 1 graph layer </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Data type </th>\n", | |
" <td colspan=\"2\"> float64 cupy.ndarray </td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
" </table>\n", | |
" </td>\n", | |
" <td>\n", | |
" <svg width=\"75\" height=\"170\" style=\"stroke:rgb(0,0,0);stroke-width:1\" >\n", | |
"\n", | |
" <!-- Horizontal lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"25\" y2=\"0\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"0\" y1=\"7\" x2=\"25\" y2=\"7\" />\n", | |
" <line x1=\"0\" y1=\"14\" x2=\"25\" y2=\"14\" />\n", | |
" <line x1=\"0\" y1=\"22\" x2=\"25\" y2=\"22\" />\n", | |
" <line x1=\"0\" y1=\"29\" x2=\"25\" y2=\"29\" />\n", | |
" <line x1=\"0\" y1=\"37\" x2=\"25\" y2=\"37\" />\n", | |
" <line x1=\"0\" y1=\"44\" x2=\"25\" y2=\"44\" />\n", | |
" <line x1=\"0\" y1=\"52\" x2=\"25\" y2=\"52\" />\n", | |
" <line x1=\"0\" y1=\"59\" x2=\"25\" y2=\"59\" />\n", | |
" <line x1=\"0\" y1=\"67\" x2=\"25\" y2=\"67\" />\n", | |
" <line x1=\"0\" y1=\"74\" x2=\"25\" y2=\"74\" />\n", | |
" <line x1=\"0\" y1=\"82\" x2=\"25\" y2=\"82\" />\n", | |
" <line x1=\"0\" y1=\"89\" x2=\"25\" y2=\"89\" />\n", | |
" <line x1=\"0\" y1=\"97\" x2=\"25\" y2=\"97\" />\n", | |
" <line x1=\"0\" y1=\"104\" x2=\"25\" y2=\"104\" />\n", | |
" <line x1=\"0\" y1=\"112\" x2=\"25\" y2=\"112\" />\n", | |
" <line x1=\"0\" y1=\"119\" x2=\"25\" y2=\"119\" />\n", | |
" <line x1=\"0\" y1=\"120\" x2=\"25\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Vertical lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"0\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"25\" y1=\"0\" x2=\"25\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Colored Rectangle -->\n", | |
" <polygon points=\"0.0,0.0 25.412616514582485,0.0 25.412616514582485,120.0 0.0,120.0\" style=\"fill:#ECB172A0;stroke-width:0\"/>\n", | |
"\n", | |
" <!-- Text -->\n", | |
" <text x=\"12.706308\" y=\"140.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" >5000</text>\n", | |
" <text x=\"45.412617\" y=\"60.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" transform=\"rotate(-90,45.412617,60.000000)\">2432024</text>\n", | |
"</svg>\n", | |
" </td>\n", | |
" </tr>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"dask.array<rechunk-merge, shape=(2432024, 5000), dtype=float64, chunksize=(152001, 5000), chunktype=cupy.ndarray>" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dask_sparse_arr = dask_sparse_arr.astype(cp.float64)\n", | |
"n_rows = dask_sparse_arr.shape[0]\n", | |
"n_cols = dask_sparse_arr.shape[1]\n", | |
"cols_per_worker = int(n_rows / 16)\n", | |
"dask_sparse_arr = dask_sparse_arr.rechunk((cols_per_worker, n_cols)).persist()\n", | |
"\n", | |
"dask_sparse_arr.compute_chunk_sizes()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "b525a19c-6ed6-4e2c-a249-c4fc536862bb", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1.11 s, sys: 608 ms, total: 1.72 s\n", | |
"Wall time: 10.7 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<table>\n", | |
" <tr>\n", | |
" <td>\n", | |
" <table style=\"border-collapse: collapse;\">\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <td> </td>\n", | |
" <th> Array </th>\n", | |
" <th> Chunk </th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Bytes </th>\n", | |
" <td> 1.81 GiB </td>\n", | |
" <td> 115.97 MiB </td>\n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <th> Shape </th>\n", | |
" <td> (2432024, 100) </td>\n", | |
" <td> (152001, 100) </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Dask graph </th>\n", | |
" <td colspan=\"2\"> 17 chunks in 1 graph layer </td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th> Data type </th>\n", | |
" <td colspan=\"2\"> float64 cupy.ndarray </td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
" </table>\n", | |
" </td>\n", | |
" <td>\n", | |
" <svg width=\"75\" height=\"170\" style=\"stroke:rgb(0,0,0);stroke-width:1\" >\n", | |
"\n", | |
" <!-- Horizontal lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"25\" y2=\"0\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"0\" y1=\"7\" x2=\"25\" y2=\"7\" />\n", | |
" <line x1=\"0\" y1=\"14\" x2=\"25\" y2=\"14\" />\n", | |
" <line x1=\"0\" y1=\"22\" x2=\"25\" y2=\"22\" />\n", | |
" <line x1=\"0\" y1=\"29\" x2=\"25\" y2=\"29\" />\n", | |
" <line x1=\"0\" y1=\"37\" x2=\"25\" y2=\"37\" />\n", | |
" <line x1=\"0\" y1=\"44\" x2=\"25\" y2=\"44\" />\n", | |
" <line x1=\"0\" y1=\"52\" x2=\"25\" y2=\"52\" />\n", | |
" <line x1=\"0\" y1=\"59\" x2=\"25\" y2=\"59\" />\n", | |
" <line x1=\"0\" y1=\"67\" x2=\"25\" y2=\"67\" />\n", | |
" <line x1=\"0\" y1=\"74\" x2=\"25\" y2=\"74\" />\n", | |
" <line x1=\"0\" y1=\"82\" x2=\"25\" y2=\"82\" />\n", | |
" <line x1=\"0\" y1=\"89\" x2=\"25\" y2=\"89\" />\n", | |
" <line x1=\"0\" y1=\"97\" x2=\"25\" y2=\"97\" />\n", | |
" <line x1=\"0\" y1=\"104\" x2=\"25\" y2=\"104\" />\n", | |
" <line x1=\"0\" y1=\"112\" x2=\"25\" y2=\"112\" />\n", | |
" <line x1=\"0\" y1=\"119\" x2=\"25\" y2=\"119\" />\n", | |
" <line x1=\"0\" y1=\"120\" x2=\"25\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Vertical lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"0\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"25\" y1=\"0\" x2=\"25\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Colored Rectangle -->\n", | |
" <polygon points=\"0.0,0.0 25.41261651458249,0.0 25.41261651458249,120.0 0.0,120.0\" style=\"fill:#ECB172A0;stroke-width:0\"/>\n", | |
"\n", | |
" <!-- Text -->\n", | |
" <text x=\"12.706308\" y=\"140.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" >100</text>\n", | |
" <text x=\"45.412617\" y=\"60.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" transform=\"rotate(-90,45.412617,60.000000)\">2432024</text>\n", | |
"</svg>\n", | |
" </td>\n", | |
" </tr>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"dask.array<_transform, shape=(2432024, 100), dtype=float64, chunksize=(152001, 100), chunktype=cupy.ndarray>" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"dask_pca =PCA_sparse_dask(n_components=100, client=client, whiten=False)\n", | |
"dask_pca.fit(dask_sparse_arr)\n", | |
"dask_pca_data = dask_pca.transform(dask_sparse_arr)\n", | |
"dask_pca_data.compute_chunk_sizes()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"id": "2e9c53db-f805-4d9a-8db9-94d8e188059a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1.11 s, sys: 1.47 s, total: 2.58 s\n", | |
"Wall time: 2.75 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"pca = dask_pca_data.compute()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"id": "4c67de58-c3ac-4138-9c9a-5218acfd6a76", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dask_sparse_arr_np = dask_sparse_arr.map_blocks(lambda x: x.get(), dtype=cp.float64,meta = np.array((0,)))\n", | |
"sparse_scipy_matrix = dask_sparse_arr_np.compute()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"id": "6f7b0471-5fd7-4d08-b372-182d6bc35b72", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<2432024x5000 sparse matrix of type '<class 'numpy.float64'>'\n", | |
"\twith 1608641402 stored elements in Compressed Sparse Row format>" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
    "# Host-side scipy CSR copy of the GPU matrix; input to the scanpy reference PCA.\n", | |
    "sparse_scipy_matrix" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "d09450b7-8773-401a-ba36-782de9d320c1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%%time\n", | |
"adata = anndata.AnnData(sparse_scipy_matrix)\n", | |
"sc.pp.pca(adata, n_comps=100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "daf960e9-0cc2-475f-8a6b-8c12bfc852f6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"cp.testing.assert_allclose(cp.abs(pca),np.abs(adata.obsm[\"X_pca\"]))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "e3a9c399-e712-4c1e-ab3b-728c4679ab4b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "rapids-23.12", | |
"language": "python", | |
"name": "rapids-23.12" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment