Skip to content

Instantly share code, notes, and snippets.

@junpenglao
Last active July 22, 2022 20:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save junpenglao/f5b48c34dd8ea5029202fb607806ea0f to your computer and use it in GitHub Desktop.
Save junpenglao/f5b48c34dd8ea5029202fb607806ea0f to your computer and use it in GitHub Desktop.
Sparse cholesky in JAX.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/junpenglao/f5b48c34dd8ea5029202fb607806ea0f/sparse-cholesky-in-jax.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"id": "97b6acba-83df-4220-a98d-f538b299acf0",
"metadata": {
"id": "97b6acba-83df-4220-a98d-f538b299acf0"
},
"source": [
"Implementation from https://github.com/dpsimpson/blog/blob/master/_posts/2022-03-23-getting-jax-to-love-sparse-matrices/Python/\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a5363916",
"metadata": {
"id": "a5363916"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48cd55a9-5202-4f3d-b4ed-25e42eb1430d",
"metadata": {
"id": "48cd55a9-5202-4f3d-b4ed-25e42eb1430d"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"from functools import partial"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf756eac-bdb4-447b-8f3f-90b11c9462aa",
"metadata": {
"id": "bf756eac-bdb4-447b-8f3f-90b11c9462aa"
},
"outputs": [],
"source": [
"def _symbolic_factor_csc(A_indices, A_indptr):\n",
"    \"\"\"Compute the sparsity pattern of the Cholesky factor L of A.\n",
"\n",
"    A_indices and A_indptr must index the lower triangle of A ONLY (CSC),\n",
"    with sorted row indices and the diagonal stored in every column: the\n",
"    parent of column j is taken to be the second row index of L's column j.\n",
"\n",
"    Returns (L_indices, L_indptr), the CSC row indices and column pointers\n",
"    of L.\n",
"    \"\"\"\n",
"    n = len(A_indptr) - 1\n",
"    if n == 0:\n",
"        # np.concatenate([]) below would raise on an empty matrix.\n",
"        return np.array([], dtype=int), np.zeros(1, dtype=int)\n",
"\n",
"    L_sym = [None] * n\n",
"    # children[p] lists the columns whose parent (see below) is column p.\n",
"    children = [[] for _ in range(n)]\n",
"\n",
"    for j in range(n):\n",
"        L_sym[j] = A_indices[A_indptr[j] : A_indptr[j + 1]]\n",
"        for child in children[j]:\n",
"            # Fill-in: rows of a child column below j also appear in column j.\n",
"            tmp = L_sym[child][L_sym[child] > j]\n",
"            L_sym[j] = np.union1d(L_sym[j], tmp)\n",
"\n",
"        if len(L_sym[j]) > 1:\n",
"            # Parent of column j: its first off-diagonal row.\n",
"            p = L_sym[j][1]\n",
"            children[p].append(j)\n",
"\n",
"    L_indptr = np.zeros(n + 1, dtype=int)\n",
"    L_indptr[1:] = np.cumsum([len(x) for x in L_sym])\n",
"    L_indices = np.concatenate(L_sym)\n",
"\n",
"    return L_indices, L_indptr\n",
"\n",
"\n",
"def _deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr):\n",
"    \"\"\"Scatter the values of A into a zero array laid out with L's pattern.\n",
"\n",
"    Every column of A must be a subset of the matching column of L (as for\n",
"    the pattern from _symbolic_factor_csc), with row indices sorted in both.\n",
"    Returns a float array L_x of length len(L_indices).\n",
"    \"\"\"\n",
"    n = len(A_indptr) - 1\n",
"    L_x = np.zeros(len(L_indices))\n",
"\n",
"    for j in range(n):\n",
"        A_col = slice(A_indptr[j], A_indptr[j + 1])\n",
"        L_start = L_indptr[j]\n",
"        # Positions (within L's column j) of the rows stored in A's column j.\n",
"        copy_idx = np.nonzero(\n",
"            np.isin(L_indices[L_start : L_indptr[j + 1]], A_indices[A_col])\n",
"        )[0]\n",
"        L_x[L_start + copy_idx] = A_x[A_col]\n",
"    return L_x\n",
"\n",
"\n",
"def _sparse_cholesky_csc_impl(L_indices, L_indptr, L_x):\n",
"    \"\"\"Numeric left-looking Cholesky factorisation, computed in place on L_x.\n",
"\n",
"    L_indices/L_indptr give the (symbolically factored) CSC pattern of L,\n",
"    with the diagonal first in every column; on entry L_x holds the lower\n",
"    triangle of A scattered onto that pattern. L_x is overwritten with the\n",
"    values of L and also returned. A non-positive pivot yields NaN (np.sqrt).\n",
"    \"\"\"\n",
"    n = len(L_indptr) - 1\n",
"    # descendant[j] collects (k, idx) with k < j already factored and\n",
"    # L_x[idx] == L[j, k] != 0: the columns that update column j.\n",
"    descendant = [[] for _ in range(n)]\n",
"    for j in range(n):\n",
"        col_j = slice(L_indptr[j], L_indptr[j + 1])\n",
"        tmp = L_x[col_j]  # a view: updates below write straight into L_x\n",
"        for k, idx in descendant[j]:\n",
"            Ljk = L_x[idx]\n",
"            # Offset inside column k of the row matching column j's first row.\n",
"            pad = np.nonzero(\n",
"                L_indices[L_indptr[k] : L_indptr[k + 1]] == L_indices[L_indptr[j]]\n",
"            )[0][0]\n",
"            tail_k = slice(L_indptr[k] + pad, L_indptr[k + 1])\n",
"            update_idx = np.nonzero(np.isin(L_indices[col_j], L_indices[tail_k]))[0]\n",
"            tmp[update_idx] = tmp[update_idx] - Ljk * L_x[tail_k]\n",
"\n",
"        diag = np.sqrt(tmp[0])\n",
"        L_x[L_indptr[j]] = diag\n",
"        L_x[(L_indptr[j] + 1) : L_indptr[j + 1]] = tmp[1:] / diag\n",
"        # Register column j as an updater of every row below its diagonal.\n",
"        for idx in range(L_indptr[j] + 1, L_indptr[j + 1]):\n",
"            descendant[L_indices[idx]].append((j, idx))\n",
"    return L_x\n",
"\n",
"\n",
"def sparse_cholesky_csc(A_indices, A_indptr, A_x):\n",
"    \"\"\"Sparse Cholesky factorisation A = L L^T of a matrix given in CSC form.\n",
"\n",
"    A_indices, A_indptr, A_x must describe the lower triangle of A only.\n",
"    Returns L in CSC form as (L_indices, L_indptr, L_x).\n",
"    \"\"\"\n",
"    L_indices, L_indptr = _symbolic_factor_csc(A_indices, A_indptr)\n",
"    L_x = _deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr)\n",
"    L_x = _sparse_cholesky_csc_impl(L_indices, L_indptr, L_x)\n",
"    return L_indices, L_indptr, L_x\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5185359-8feb-412f-81ae-d8d91d2dec68",
"metadata": {
"id": "f5185359-8feb-412f-81ae-d8d91d2dec68",
"outputId": "008f2fcb-8acd-4864-c4c5-3d79d6f2dfc9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Error in Cholesky is 3.871041263071504e-12\n"
]
}
],
"source": [
"from scipy import sparse\n",
"\n",
"# Test matrix: Kronecker sum of two 1-D second-difference operators plus the\n",
"# identity, of size n^2 x n^2.\n",
"n = 50\n",
"one_d = sparse.diags(\n",
"    [[-1.0] * (n - 1), [2.0] * n, [-1.0] * (n - 1)],\n",
"    [-1, 0, 1],\n",
")\n",
"A = sparse.kronsum(one_d, one_d) + sparse.eye(n * n)\n",
"\n",
"# The factorisation only reads the lower triangle, in CSC form.\n",
"A_lower = sparse.tril(A, format=\"csc\")\n",
"A_indices, A_indptr, A_x = A_lower.indices, A_lower.indptr, A_lower.data\n",
"\n",
"L_indices, L_indptr, L_x = sparse_cholesky_csc(A_indices, A_indptr, A_x)\n",
"L = sparse.csc_array((L_x, L_indices, L_indptr), shape=(n**2, n**2))\n",
"\n",
"# Sum of absolute entries of the residual A - L L^T.\n",
"err = np.sum(np.abs((A - L @ L.T).todense()))\n",
"print(f\"Error in Cholesky is {err}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "89ee3cad-c94d-4b0c-8bc4-b6c13d246424",
"metadata": {
"id": "89ee3cad-c94d-4b0c-8bc4-b6c13d246424",
"outputId": "41e5b54d-78ee-452c-cf2b-6d383f49b9e1"
},
"outputs": [
{
"data": {
"text/plain": [
"(3, 51)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.diff(A_indptr).max(), np.diff(L_indptr).max()  # longest column of A (lower) and of L\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e50533d5",
"metadata": {
"id": "e50533d5",
"outputId": "878ac5ff-2f5e-4ec6-840c-fcb1279c8e42"
},
"outputs": [
{
"data": {
"text/plain": [
"(2501, 2501)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"A_indptr.shape[0], L_indptr.shape[0]  # n + 1 column pointers each\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c01f20eb",
"metadata": {
"id": "c01f20eb",
"outputId": "1e578270-6c31-4c1a-dde4-8eb6b24b3711"
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f8d488005e0>]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAaXklEQVR4nO3de5Cb13nf8d+zALkri7rRWrK0pC4lh81EiSPZ3mqUunXGUZ0oThspTdw6M6n5h6ZKM8mMM23ala0mitM4VePYrd3cylhuaDu24owupNeyJYaSLDuWSC0lkiJFUryJd3KXt+VtbwCe/oEXy33xArsgd4E9Z/n9zOwAeBbAe855gQcPDl7gmLsLABCfttluAADg8pDAASBSJHAAiBQJHAAiRQIHgEjlW7mxG2+80ZcuXdrKTQJA9DZs2HDc3Tur4y1N4EuXLlVfX18rNwkA0TOzfbXiTKEAQKRI4AAQKRI4AESKBA4AkSKBA0CkSOAAECkSOABEqqXHgc+4A+ulnWtmuxUAUNZxnXT3b0htuZZsLu4E/vx/l/a+JMlmuyUArnjJ2gq3/gtpyR0t2WJDCdzM3pZ0VlJRUsHdu81soaS/lbRU0tuS/q27n2pOM+sYG5Ju+5D08adbulkAyNi1VvravynnpRa5lDnwD7n7ne7enVx+SNJad18maW1yubUKw1K+o+WbBYCMSi4qDLdsk9P5EPM+SSuT8ysl3T/t1lyqwqiUb2/5ZgEgYzyBj7Zsk40mcJf0nJltMLMHk9hidz8iScnpolo3NLMHzazPzPoGBgam3+KJqMABhKJSTLawAm/0Q8wPuPthM1skaY2ZbW90A+6+QtIKSeru7p7ZFZQLI1TgAMIwXoGPtGyTDVXg7n44Oe2X9JSkuyQdM7MlkpSc9jerkXVRgQMIxSxU4FMmcDO72syuqZyX9LOStkhaLWl5crXlklY1q5F1UYEDCMUsfIjZyBTKYklPmVnl+l939++a2auSvmlmD0jaL+mjzWtmDe5U4ADCMV6Bt24KZcoE7u57JGWOSnf3E5LuaUajGlIck+RU4ADCENlhhLOrMkgkcAAhyM0rn4b2IWaQKoPEFAqAEJiV8xEVeAOKlQROBQ4gEPl2qRjeF3nCQwUOIDRU4A1iDhxAaPLtzIE3ZDyBU4EDCAQVeIMKzIEDCAwVeIOowAGEhgq8QVTgAEKT76ACb0jlVS5HAgcQiNx8KvCGUIEDCA0VeIM4DhxAaPgQs0EkcAChoQJvEF/kARCafDtz4A2hAgcQGirwBhWGJctJuUaX9QSAJqMCbxCr8QAITb6j/EupPrPrt9cTcQJnPUwAgcnPL5+2aBol4gQ+TAIHEJYWL6sWcQKnAgcQmBYvbBxvAi+OMAcOICyVnFQkgU+OChxAaManUEjgk+MoFAChGZ9CYQ58clTgAEJDBd4gKnAAoaECbxAVOIDQcBhhg6jAAYQmxxd5GlMYYTUeAGGhAm8Q38QEEJrxOfDRlmwu4gQ+yhQKgLBQgTeIChxAaEL9Kr2Z5czsdTPrTS4vNLM1ZrYzOb2hec2sUipKpTEqcABhCbgC/4SkbRMuPyRprbsvk7Q2udwarEgPIEQhVuBmdrOkX5D0pQnh+yStTM6vlHT/jLZsMuPrYVKBAwhIW05qmxdcBf6/Jf1XSaUJscXufkSSktNFtW5oZg+aWZ+Z9Q0MDEynrRdRgQMIVQvXxZwygZvZv5LU7+4bLmcD7r7C3bvdvbuzs/Ny7iKLChxAqPLzW1aBN7Ii8Ack/aKZfURSh6Rrzexrko6Z2RJ3P2JmSyT1N7OhKeMV+PyWbRIAGhJSBe7un3T3m919qaSPSXre3X9N0mpJy5OrLZe0qmmtrEYFDiBU+fYoFnR4VNKHzWynpA8nl1ujmHzLiTlwAKHJdwQ1hTLO3V+U9GJy/oSke2a+SQ2gAgcQqnx7OFMoQRqfAyeBAwhMCyvwSBN4pQJnCgVAYKjAp0AFDi
BUVOBToAIHECoq8ClUEjgLOgAITa6dCnxSfJUeQKiowKfAYYQAQhXSNzGDVOCLPAACRQU+hcJweZ7JbLZbAgBpHIUyhcII0ycAwpTvkLwoFQtN31SkCZz1MAEEanxVnuZX4ZEmcCpwAIEaXxez+fPgkSZwKnAAgaqsU0AFXkdhhAQOIEwtXJk+0gROBQ4gUC1cmT7SBM4cOIBAVXJTC1bliTOBF5lCARAoKvApFIapwAGEiTnwKfAhJoBQUYFPgQocQKiowKdABQ4gVHyRZwpU4ABCleOLPJMrjFwcJAAICRX4JNypwAGEix+zmkSRxRwABIwKfBLj62FSgQMIUG6eJCOB18SCxgBCZtayVXkiTOAsaAwgcC1aFzPCBM4UCoDAUYHXMV6BM4UCIFChVOBm1mFm681sk5ltNbNPJ/GFZrbGzHYmpzc0vbUSFTiA8OXbg6nARyT9jLvfIelOSfea2d2SHpK01t2XSVqbXG6+8QqcL/IACFQoFbiXnUsuzkv+XNJ9klYm8ZWS7m9GAzP4EBNA6EKaAzeznJltlNQvaY27r5O02N2PSFJyuqjObR80sz4z6xsYGJh+izmMEEDo8h0Xv3TYRA0lcHcvuvudkm6WdJeZ/USjG3D3Fe7e7e7dnZ2dl9nMCYrMgQMIXEBz4OPc/bSkFyXdK+mYmS2RpOS0f6YbVxMVOIDQ5TvCmAM3s04zuz45f5Wkfylpu6TVkpYnV1suaVWT2pjGHDiA0LWoAs83cJ0lklaaWU7lhP9Nd+81s5clfdPMHpC0X9JHm9jOiziMEEDoWlSBT5nA3X2zpPfWiJ+QdE8zGjUpvsgDIHQhzoEHofKqliOBAwhULpDjwINTGJYsJ+Uamf0BgFlABV5HYYT5bwBhqxwH7t7UzUSYwIeZ/wYQtvFl1Zo7jRJhAqcCBxC48WXVmjuNEmkCpwIHEDAq8DpYkR5A6KjA66ACBxA6KvA6qMABhG48gVOBp1GBAwjd+BQKFXgahxECCB0VeB1U4ABCV6nAi1TgacyBAwgdH2LWURylAgcQNg4jrIMKHEDoqMDr4Kv0AEJHBV4HR6EACB0VeA3FglQqUIEDCFuOwwiziqxIDyACVOA1sJwagBi05aS2eVTgKSxoDCAW+Q6pMNrUTUSawJkDBxC4FqyLGVkCT17NqMABhC7fwRx4ChU4gFhQgVcpcBQKgEjkO0jgKVTgAGKRb2cKJWW8AieBAwgcUyhVOIwQQCyowKswBw4gFsyBV6ECBxCLfHt5/YImijSBMwcOIHAhVOBmdouZvWBm28xsq5l9IokvNLM1ZrYzOb2hqS2VLr6aUYEDCF0gc+AFSf/Z3X9M0t2SftPMbpf0kKS17r5M0trkcnNRgQOIRQgVuLsfcffXkvNnJW2TdJOk+yStTK62UtL9TWrjReO/Rji/6ZsCgGkJpAIfZ2ZLJb1X0jpJi939iFRO8pIW1bnNg2bWZ2Z9AwMD02ttZT1Ms+ndDwA0W6UCd2/aJhpO4Ga2QNITkn7b3c80ejt3X+Hu3e7e3dnZeTltvKgwwvw3gDjk2iUvlVcRa5KGEriZzVM5ef+Nuz+ZhI+Z2ZLk/0sk9TeniROwIj2AWOSbv6xaI0ehmKTHJG1z989P+NdqScuT88slrZr55lUpjLAaD4A4jK9M37x58HwD1/mApH8v6Q0z25jEPiXpUUnfNLMHJO2X9NGmtHAiVqQHEIsWrIs5ZQJ39x9Iqvep4T0z25wpFEaYQgEQh/EKfBanUILCh5gAYtGCCjzCBE4FDiACVOBVmAMHEAsq8CpU4ABiEcJhhEGhAgcQCyrwKlTgAGLBHHiVwrCU54esAESgBV/kiSyBU4EDiERlCqVIAi9jDhxALKjAJ3Avv5JRgQOIAUehTMByagBiwoeYE7CcGoCYtOUla2MKRdLFQaACBxADs/LPX1OBiwocQHyavC5mRAm8UoGTwAFEoskr00eUwJNBYEV6ALGgAk9QgQOITb6DBC5pwh
w4H2ICiAQVeIIKHEBsmANPcBghgNhQgSc4jBBAbKjAE1TgAGJDBZ6gAgcQmzzfxCyjAgcQGw4jTHAYIYDY5NtZ0EHSxVexHAkcQCSowBOF4fLPM+bys90SAGgMc+CJ4igfYAKIS76jnLtKpabcfTwJnPUwAcSmyQsbR5bAqcABRKTJy6pNmcDN7Mtm1m9mWybEFprZGjPbmZze0JTWTVQYoQIHEJfKz1836YPMRirwv5Z0b1XsIUlr3X2ZpLXJ5eaiAgcQm9muwN39JUknq8L3SVqZnF8p6f6ZbVYNhREWcwAQl8qswSxW4LUsdvcjkpScLqp3RTN70Mz6zKxvYGDgMjcnKnAA8RmvwMNK4A1z9xXu3u3u3Z2dnZd/R8yBA4hNoAn8mJktkaTktH/mmlQHFTiA2IxPoczSHHgdqyUtT84vl7RqZpozicIoFTiAuMx2BW5m35D0sqQfNbODZvaApEclfdjMdkr6cHK5uajAAcSmyRX4lD8s4u6/Wudf98xwWybHHDiA2Mz2YYTBoAIHEJv87H+RJwxU4ABiQwWe4MesAMQm0C/ytFaxIHmRKRQAcankrCv61whZTg1AjHJU4BMWNKYCBxCRtrbybzhd0XPgRVakBxCpJq6LGUcCH59CoQIHEJkmrosZSQKnAgcQqVw7FbgkKnAA8aECpwIHECnmwJNXrxwJHEBkqMA5jBBApKjA+SIPgEjlr/gPManAAUQq38EUiiQqcADxoQLnMEIAkaICpwIHEKn8fCpwSVTgAOJDBT4iyaTcvNluCQBcGubAk9V4zGa7JQBwaSoVuPvM3/WM32MTbD0woK5iXv9hxSup+OsHTml4rKSfuu2dqfjO/nM6fm4kEz96Zlh7j5/PxM8Mj2nr4TO6+7aFMl18kRgrltS375T+6dIblG+7+Frncr2y56Tec9N1WtCeHsKX95zQbZ1Xa/E1HZn4omva9e7OBZn4O+bndMfN12fikjJtrRffsP+URgu1xuKsjp8bzcQPDw5p34kLmfjg0JjePJIdi9FiSRv2ndJdSxcq13YxXnLXur0n9ZM3X6er52fH4t2dV2tRjbH4R9d26NYbr87EF7Tn9Z6brpvWWPTtO6mxomfiO46d1cnz2bE4ePqCDpwcysRPXRjV9qNnM2MxXCjq9f2n647FHbdcr3fMy2XaumzRAt24oD0Tf9d1Hep6Z3YsrunI6yfelR0LM+nuWxsbi/Vvn1SxlB2LN4+c0eDQWCa+/+QFHTqdHYuT50e149jZTHxorKiNB07rrlsXKjehwCq6a/3ek7rzlut1VY2x+NHF12jh1fMz8Zuuv0r/eOE7MvHrrpqn25dcm4nn2kx3LV3Y0Fi8sveE3LPxrYcHdWa4kInvO3FehweHM/Hj50a0s/9cJi5J77r+Kn32V35SbRMeF+XP7lwqFWZ8FiGKCvxsxxJtn//jKpY89Tc8VpKkTPz4ufLbleFCMRXfe/y8JOncSCEV33r4jCTp0OmhVLxv3ylJ0uv7T6fi+09ckCS9cWgwFT87MiZJ2jNwPhUfGitKkvrPjmTaKkkXRos147X6Vi8+Wqg3FqOSpJGqsdiX9KF6LN48Uh6LY2fSbd2QjMXmQ+mx2JOM6eaD6bE4M1wei911xuLomeFUvFAq1WzP5YzFWNFrxk+eL4/FaLGUih84OSRJOj+a3vb2o2clScfPjabir+8/LamcBCfGd/afkyRtOpAeo8Gh8ljs7D+Xil9IxuLw4HBV+8tjcXa49li4Nz4Wlf9VxyptGqsai0Onh5LHZHrbO46Vx+LUhfRYbDxQHosdR8+mxy55HG2sGotTF8r7YMex9PUvjBYkZZ+Do8lYDA6NNdSvycaiUgBXx88MF2qOxeHB8rz1hbH0c6eyn6vbdHhwSE+8dlAnkz6Oa+bCxu7esr/3v//9PpO6enq9q6e3bvzs8Fgq/s/+x1rv6un1PQPnUvEH/nq9d/X0+rNbjqTin/3udu/q6fUv/P1bqfi3Nh
3yrp5e/42v9aXiO4+d8a6eXv/QZ19IxU+fH52yrc2OXxgppOLdf7jGu3p6ff+J86n4r33pFe/q6fUXth9LxT/z7Te9q6fX/+LFXan4ExsOeFdPr3/iG6+l4lsPDXpXT6//3P/6Xip+4txIzbaWSqWWjcXIWDEVv+PTz3pXT68fOnUhFf93//eH3tXT6/+wayAVf2TVFu/q6fXHvr8nFX98/T7v6un1//J3G1PxjftPeVdPr//r//P9VPzYmSHv6un1ZZ96JhUvFls3FmOF9Fj82O9+x7t6ev3o4FAq/kt/9gPv6un1V/eeSMU/+eRm7+rp9a++/HYq/pUf7vWunl5/+KnNqfj6vSe8q6fXf/nP/yEVP3K6PBY//nvfTcXHCsWWjUWpVErF3/3Jb3tXT6/3nxlOxX/hiy95V0+vbzpwKhWv7P/qx5GvW+H+yLXu59KPo0shqc9r5NQoKvDL1Z5Pd6/yDi/fVj2XXntuvd6Uu9W9fu14+7zZH+bqsaiYXydery/1PoWod/1G29Ho7WfCvFztbdUbi2bpqJpaqGjLPD6bJ5+reo4kp9nnzuSPh8y/k0D1c6Xe9TuS50j13VS3r5nq9bGjzvO3um/t+fL+HEneDY8bX5l+5ivwKObAL1d1op5f58FQeUJX78DK/Gau6n4qF3NttR/8+Vz1jp39BF6dFCpjUf2QnVcnXhmDtgbHqHK1eVVjXhmLevuiFar3c6Ut1X27OBbpeH58LNL3W7l9vbGo93is94IyG+bn23R+tJjpw8XnSPr6+XqPi7pjYcnt0vu/kqjnBfBcqZifb1NhtJh5Ya9+TFdUHtv/8asb1DH/4ovzTw/v13+SNDx0QR3X1bzp5atVljfrb6anUH6wc8D/3w/2ZOKbDpzyL1ZNe7i77+4/63/0zJuZt0rHBof8955+w0er3k6eHR7zTz252c9VTcWMjBX9d59+I/PWqlQq+We+/WZmisbd/Qt//5ZvPnA6E3/s+3syb9Hd3f/21f2ZKR1392c2H/YnNhzIxF96q99X/nBvJv7avpP+p8/vzMR3Hjvjj35nW2Ysjpwe8kdWbcm8tR4cGvVPPbnZz4+kx2J4rOD/7ak3/PjZ9FgUiyX/w96tvu94eorG3f3zz+3wNw5mx+KvXtrtL+8+nok/vn6fr9l6NBPv3XTYn3rtYCb+4o5+/0rVW3p39763T/qfv7ArE3/r6Bn/4+9mx+Lw6Qv+yKotXiim46cvlMeielpqaLTgDz+12U+eG0nFi8WS/8G3tmamq0qlkn/uuR3+5uHBTJtWfG+3r9tzIhP/+rp9vnZbdixWbzzkT7+eHYvntx/zr72SHYtX957wv3wxOxbbjgz6nzy7PTMWB09d8E+v3poZi1PnR/zhpzb70GjtsTh1Pj0WhWLJP716qx+smmYolUr+J89u9+1HzmTa9Jcv7spM3bi7f/Xlt/35qqk+d/enXz/oqzceysTXbjvqX1+3LxNft+eEr/je7kx866FB/9xzOzLx/SfO+x98a6sXi9k88utf6fOPP7Yu9feZz/6R+yPX+p4tr2Tuq1GqM4Vi3oRDW+rp7u72vr6+lm0PAGbbprWP647v/7re+sXV+ifv++nLug8z2+Du3dXxcN6vAMAc1DbvKklSYXRo5u97xu8RADAu314+jLBIAgeAuCx457v02oIP6qrrFs34fU/rKBQzu1fSFyTlJH3J3R+dkVYBwBxxy4+8R7f8zreact+XXYGbWU7Sn0n6eUm3S/pVM7t9phoGAJjcdKZQ7pK0y933uPuopMcl3TczzQIATGU6CfwmSQcmXD6YxFLM7EEz6zOzvoGBgWlsDgAw0XQSeK2vj2UOKnf3Fe7e7e7dnZ2d09gcAGCi6STwg5JumXD5ZkmHp9ccAECjppPAX5W0zMxuNbP5kj4mafXMNAsAMJXLPozQ3Qtm9luSnlX5MMIvu/vWGWsZAGBS0zoO3N2fkfTMDLUFAHAJWvpjVmY2IGnfZd78Rk
nHZ7A5MaDPVwb6fGWYTp+73D1zFEhLE/h0mFlfrV/jmsvo85WBPl8ZmtFnfgsFACJFAgeASMWUwFfMdgNmAX2+MtDnK8OM9zmaOXAAQFpMFTgAYAISOABEKooEbmb3mtkOM9tlZg/Ndntmipm9bWZvmNlGM+tLYgvNbI2Z7UxOb5hw/U8mY7DDzH5u9lreODP7spn1m9mWCbFL7qOZvT8Zq11m9kUzq/VjakGo0+ffN7NDyb7eaGYfmfC/udDnW8zsBTPbZmZbzewTSXzO7utJ+ty6fV1rqfqQ/lT+mv5uSbdJmi9pk6TbZ7tdM9S3tyXdWBX7Y0kPJecfkvQ/k/O3J31vl3RrMia52e5DA338oKT3SdoynT5KWi/pp1T+FczvSPr52e7bJfb59yX9To3rzpU+L5H0vuT8NZLeSvo2Z/f1JH1u2b6OoQK/0haOuE/SyuT8Skn3T4g/7u4j7r5X0i6VxyZo7v6SpJNV4Uvqo5ktkXStu7/s5Uf7VybcJjh1+lzPXOnzEXd/LTl/VtI2ldcHmLP7epI+1zPjfY4hgTe0cESkXNJzZrbBzB5MYovd/YhUfoBIqqyEOpfG4VL7eFNyvjoem98ys83JFEtlKmHO9dnMlkp6r6R1ukL2dVWfpRbt6xgSeEMLR0TqA+7+PpXXFf1NM/vgJNedy+NQUa+Pc6HvfyHp3ZLulHRE0ueS+Jzqs5ktkPSEpN929zOTXbVGLMp+1+hzy/Z1DAl8zi4c4e6Hk9N+SU+pPCVyLHlLpeS0P7n6XBqHS+3jweR8dTwa7n7M3YvuXpL0V7o4/TVn+mxm81ROZH/j7k8m4Tm9r2v1uZX7OoYEPicXjjCzq83smsp5ST8raYvKfVueXG25pFXJ+dWSPmZm7WZ2q6RlKn/wEaNL6mPy1vusmd2dfDr/8Qm3iUIliSV+SeV9Lc2RPidtfEzSNnf//IR/zdl9Xa/PLd3Xs/1JboOf9n5E5U94d0t6eLbbM0N9uk3lT6Q3Sdpa6Zekd0paK2lncrpwwm0eTsZghwL9ZL5GP7+h8tvIMZUrjQcup4+SupMnwm5Jf6rkW8Qh/tXp81clvSFpc/JEXjLH+vzPVX7bv1nSxuTvI3N5X0/S55bta75KDwCRimEKBQBQAwkcACJFAgeASJHAASBSJHAAiBQJHAAiRQIHgEj9f5Y2lXFozNFhAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Number of stored entries in each column of A (lower triangle) and of L.\n",
"fig, ax = plt.subplots()\n",
"ax.plot(np.diff(A_indptr), label=\"A\")\n",
"ax.plot(np.diff(L_indptr), label=\"L\")\n",
"ax.set(xlabel=\"column\", ylabel=\"stored entries per column\")\n",
"ax.legend();\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0efea52e",
"metadata": {
"id": "0efea52e",
"outputId": "0fdf83b0-4280-4dad-854d-c9cd62d85662"
},
"outputs": [
{
"data": {
"text/plain": [
"(7400, 125049)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"A_indices.size, L_indices.size  # stored entries in A (lower) vs. in L\n"
]
},
{
"cell_type": "markdown",
"id": "f119ad55",
"metadata": {
"id": "f119ad55"
},
"source": [
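"Ragged arrays (columns with different numbers of non-zeros) are difficult to handle in JAX.\n",
"\n",
"**Idea 1: pad every column to the same length**\n",
"\n",
"- pro: vmap-able, usually fast\n",
"- con: needs more memory (every column is padded to the longest one)\n"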
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83d6fee8",
"metadata": {
"id": "83d6fee8"
},
"outputs": [],
"source": [
"@partial(jax.jit, static_argnums=(0,))\n",
"def fix_length_arange(max_length, start, stop):\n",
"    \"\"\"Return arange(start, stop) padded with -1 to the static length max_length.\n",
"\n",
"    Entry i of the result is start + i when i < stop - start, else -1.\n",
"    max_length fixes the output shape and must be static; start and stop\n",
"    may be traced.\n",
"    \"\"\"\n",
"    offsets = np.arange(max_length)\n",
"    return jnp.where(offsets < (stop - start), offsets + start, -1)\n",
"\n",
"\n",
"# m_A and m_L set the padded array shapes, so they must be static: bind them\n",
"# with functools.partial before wrapping this function in jax.jit.\n",
"def _deep_copy_csc_jax(m_A, m_L, A_indices, A_indptr, A_x, L_indices, L_indptr):\n",
"    \"\"\"Padded (vmap) version of _deep_copy_csc.\n",
"\n",
"    Each column is padded to a fixed length: m_A for A and m_L for L (the\n",
"    maximum number of stored entries per column of each). Returns\n",
"    (L_x_mat, L_indptr_mat), both of shape (n, m_L). L_indptr_mat holds the\n",
"    flat index into L_x of every slot, or -1 for padding, so the flat L_x is\n",
"    L_x_mat[L_indptr_mat != -1]; padding slots of L_x_mat are 0.\n",
"    \"\"\"\n",
"    A_indptr_mat = jax.vmap(partial(fix_length_arange, m_A))(\n",
"        A_indptr[:-1], A_indptr[1:]\n",
"    )\n",
"    L_indptr_mat = jax.vmap(partial(fix_length_arange, m_L))(\n",
"        L_indptr[:-1], L_indptr[1:]\n",
"    )\n",
"\n",
"    def row_fun(row_idx_A, row_idxptr_A, row_idx_L, row_idxptr_L, row_val_A):\n",
"        # Padding slots (pointer -1) gather a real entry through index -1, so\n",
"        # mask them with sentinels that never match each other or a real row.\n",
"        row_idx_A_ = jnp.where(row_idxptr_A != -1, row_idx_A, -1)\n",
"        row_idx_L_ = jnp.where(row_idxptr_L != -1, row_idx_L, -2)\n",
"        # Positions in L's column of A's rows. Unused slots get the\n",
"        # out-of-bounds index m_L and are dropped by the scatter; the default\n",
"        # fill of 0 would overwrite the diagonal with a padding value.\n",
"        copy_idx = jnp.nonzero(\n",
"            jnp.isin(row_idx_L_, row_idx_A_), size=m_A, fill_value=m_L\n",
"        )[0]\n",
"        return jnp.zeros(m_L).at[copy_idx].set(row_val_A, mode=\"drop\")\n",
"\n",
"    L_x_mat = jax.vmap(row_fun)(\n",
"        A_indices[A_indptr_mat],\n",
"        A_indptr_mat,\n",
"        L_indices[L_indptr_mat],\n",
"        L_indptr_mat,\n",
"        A_x[A_indptr_mat],\n",
"    )\n",
"    return L_x_mat, L_indptr_mat\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b5b0566-1b87-4ee7-9446-f90c16dc9816",
"metadata": {
"id": "9b5b0566-1b87-4ee7-9446-f90c16dc9816",
"outputId": "799dffc9-4a9a-480d-f3d2-07296abe9786"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.49 ms ± 24.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
]
}
],
"source": [
"L_indices, L_indptr = _symbolic_factor_csc(A_indices, A_indptr)\n",
"L_x0 = _deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr)\n",
"\n",
"# Maximum column lengths fix the padded shapes, so they must be static.\n",
"m_A = np.max(np.diff(A_indptr))\n",
"m_L = np.max(np.diff(L_indptr))\n",
"partial_deep_copy_csc = jax.jit(partial(_deep_copy_csc_jax, m_A, m_L))\n",
"L_x_mat, L_indptr_mat = partial_deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr)\n",
"# JAX dispatches asynchronously: block on the outputs so %timeit measures\n",
"# the computation rather than only the dispatch.\n",
"%timeit [x.block_until_ready() for x in partial_deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr)]\n",
"L_x = L_x_mat[L_indptr_mat != -1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ca45e30a",
"metadata": {
"id": "ca45e30a"
},
"outputs": [],
"source": [
"# The padded vmap copy must agree with the NumPy reference copy.\n",
"np.testing.assert_array_almost_equal(L_x, L_x0, decimal=6)\n"
]
},
{
"cell_type": "markdown",
"id": "10cf6278",
"metadata": {
"id": "10cf6278"
},
"source": [
"**Idea 2: scan and build an index**\n",
"\n",
"There are two options here: scan over the `n` columns of the matrix (as in the original NumPy implementation), or scan over the entries of `A_x` and build an index vector recording where each value goes in `L_x`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "de867282",
"metadata": {
"id": "de867282"
},
"outputs": [],
"source": [
"# Option 1 (scan over columns) does not work: under jit, the slice sizes of\n",
"# `jax.lax.dynamic_slice` and the `size` of `jnp.nonzero` must be static,\n",
"# but the column lengths here are traced values. One could slice a fixed-size\n",
"# placeholder inside one_step instead, but that is essentially the vmap version above.\n",
"\n",
"# def _deep_copy_csc_scan(m_A, A_indices, A_indptr, A_x, L_indices, L_indptr):\n",
"# L_indptr_last = L_indptr[0]\n",
"# A_indptr_last = A_indptr[0]\n",
"# L_x_init = jnp.zeros_like(L_indptr, dtype=A_x.dtype)\n",
"\n",
"# def one_step(carry, current):\n",
"# L_indptr_last, A_indptr_last, L_x = carry\n",
"# L_indptr_current, A_indptr_current = current\n",
"# copy_idx = jnp.nonzero(\n",
"# jnp.in1d(\n",
"# # L_indices[L_indptr_last:L_indptr_current],\n",
"# jax.lax.dynamic_slice(L_indices, [L_indptr_last], [L_indptr_current-L_indptr_last]),\n",
"# jax.lax.dynamic_slice(A_indices, [A_indptr_last], [A_indptr_current-A_indptr_last]),\n",
"# ),\n",
"# size=A_indptr_current-A_indptr_last\n",
"# )[0]\n",
"# L_x = L_x.at[L_indptr_last + copy_idx].set(\n",
"# jax.lax.dynamic_slice(A_x, [A_indptr_last], [A_indptr_current-A_indptr_last])\n",
"# )\n",
"# return (L_indptr_current, A_indptr_current, L_x), None\n",
"\n",
"# (*_, L_x), _ = jax.lax.scan(\n",
"# one_step,\n",
"# (L_indptr_last, A_indptr_last, L_x_init),\n",
"# (L_indptr[1:], A_indptr[1:]))\n",
"# return L_x\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6ff7c8f-df2f-40bf-a874-2b975342a2bf",
"metadata": {
"tags": [],
"id": "f6ff7c8f-df2f-40bf-a874-2b975342a2bf"
},
"outputs": [],
"source": [
"def _deep_copy_csc_scan(m_L, A_indices, A_indptr, A_x, L_indices, L_indptr):\n",
"    \"\"\"Scan version of _deep_copy_csc: one scan step per stored entry of A.\n",
"\n",
"    Each step finds where entry i of A_x belongs in L_x (the slot of L's\n",
"    column holding the same row); A_x is then scattered in one update.\n",
"    m_L (max entries per column of L) must be static. Returns\n",
"    (L_x, update_index) with L_x[update_index] == A_x.\n",
"    \"\"\"\n",
"    # Pad so the fixed-size dynamic_slice never runs past the end: JAX clamps\n",
"    # out-of-bounds slice starts, which would shift the window.\n",
"    L_indices_padded = jnp.pad(L_indices, [0, m_L], constant_values=-1)\n",
"    slot = jnp.arange(m_L)\n",
"\n",
"    def one_step(_, i):\n",
"        A_idx = A_indices[i]\n",
"        # First j with A_indptr[j] > i, i.e. entry i lives in column j - 1.\n",
"        j = jnp.searchsorted(A_indptr, i, side=\"right\")\n",
"        start, stop = L_indptr[j - 1], L_indptr[j]\n",
"        L_indices_slice = jax.lax.dynamic_slice(L_indices_padded, [start], [m_L])\n",
"        # Mask slots that belong to the following columns of L.\n",
"        L_indices_slice = jnp.where(slot < stop - start, L_indices_slice, -1)\n",
"        k = jnp.argwhere(A_idx == L_indices_slice, size=1)[0][0]\n",
"        return None, start + k\n",
"\n",
"    _, update_index = jax.lax.scan(one_step, None, jnp.arange(A_indices.shape[-1]))\n",
"\n",
"    L_x = jnp.zeros_like(L_indices, dtype=A_x.dtype)\n",
"    L_x = L_x.at[update_index].set(A_x)\n",
"    return L_x, update_index\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b522bd88",
"metadata": {
"tags": [],
"id": "b522bd88",
"outputId": "a670552f-38ac-4fa6-bf5b-678996855fe9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.53 ms ± 33 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
]
}
],
"source": [
"m_L = np.max(np.diff(L_indptr))\n",
"partial_deep_copy_csc_1 = jax.jit(partial(_deep_copy_csc_scan, m_L))\n",
"L_x, update_index = partial_deep_copy_csc_1(A_indices, A_indptr, A_x, L_indices, L_indptr)\n",
"# Time the scan version (partial_deep_copy_csc_1, not the vmap one) and block\n",
"# on its outputs, since JAX dispatches asynchronously.\n",
"%timeit [x.block_until_ready() for x in partial_deep_copy_csc_1(A_indices, A_indptr, A_x, L_indices, L_indptr)]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "773b1cf8",
"metadata": {
"id": "773b1cf8"
},
"outputs": [],
"source": [
"np.testing.assert_array_almost_equal(L_x, L_x0, decimal=6)  # scan copy vs. NumPy reference\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8754ef41-b0c1-438d-90b1-82212ad0a0cc",
"metadata": {
"id": "8754ef41-b0c1-438d-90b1-82212ad0a0cc",
"outputId": "866c1a0b-8b23-4509-bbe7-e6596beca48b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of non-zeros is 125049 (fill in of 117649)\n"
]
}
],
"source": [
"nnz = L_x.shape[0]\n",
"print(f\"Number of non-zeros is {nnz} (fill in of {nnz - len(A_x)})\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29700fdc-e619-475f-a5a4-d360112e4c5f",
"metadata": {
"id": "29700fdc-e619-475f-a5a4-d360112e4c5f",
"outputId": "9422736a-913c-4aeb-c50f-71674d9ac71c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2499 2498 2449 ... 50 1 0]\n"
]
}
],
"source": [
"# Reverse Cuthill-McKee reordering of the (symmetric) matrix A.\n",
"perm = sparse.csgraph.reverse_cuthill_mckee(\n",
"    A,\n",
"    symmetric_mode=True,\n",
")\n",
"print(perm)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d40f647-f9b9-4e6c-b30a-dbaa94ecbb5f",
"metadata": {
"id": "3d40f647-f9b9-4e6c-b30a-dbaa94ecbb5f"
},
"outputs": [],
"source": [
"# Same factorisation and residual check after applying the RCM permutation.\n",
"A_perm = A[perm[:, None], perm]\n",
"A_perm_lower = sparse.tril(A_perm, format=\"csc\")\n",
"A_indices, A_indptr, A_x = A_perm_lower.indices, A_perm_lower.indptr, A_perm_lower.data\n",
"\n",
"L_indices, L_indptr, L_x = sparse_cholesky_csc(A_indices, A_indptr, A_x)\n",
"L = sparse.csc_array((L_x, L_indices, L_indptr), shape=(n**2, n**2))\n",
"err = np.sum(np.abs((A_perm - L @ L.T).todense()))\n",
"print(f\"Error in Cholesky is {err}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53052e6f",
"metadata": {
"id": "53052e6f",
"outputId": "26a14ae4-b63d-459e-fa91-3a341a4e82d2"
},
"outputs": [
{
"data": {
"text/plain": [
"(3, 51)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.diff(A_indptr).max(), np.diff(L_indptr).max()  # longest column after reordering\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d9a5a51",
"metadata": {
"id": "5d9a5a51",
"outputId": "6fc321ab-e78c-49ee-9042-657f34ab247a"
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f8cd8969e20>]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAnFElEQVR4nO3deXxU9b3/8dc3mawECCHsWwARQZTFgOCCVK3ivrVWr1tbftLNW29vbWuvvdXe28XaqvX21ra0eoutbbVVr9prXUrFpcWFRUWlLAKyQwiBACHLJN/fH99zMhkIWWfmzJm8n4/HPM4538zM+ZyZ5JMzn/Od79dYaxERkfDJCjoAERHpGiVwEZGQUgIXEQkpJXARkZBSAhcRCalIKndWWlpqy8rKUrlLEZHQW7Zs2W5r7YDD21OawMvKyli6dGkqdykiEnrGmA9ba1cJRUQkpJTARURCSglcRCSklMBFREJKCVxEJKSUwEVEQkoJXEQkpFLaD1wk5Q7thTd/CdG6jj9m4ASYdHnSQhJJFCVwyWxrn4e//qe3YTrwAAuRAiVwCYUOJXBjzEZgP9AIRK215caYEuARoAzYCFxpra1KTpgiXdRQ45Zfeh/6Dmv//i9+D166E6wF05GELxKcztTAP2KtnWKtLfe2bwUWWWvHAYu8bZH0Eq13y0h+x+4fyfMe14mSi0hAulNCuQSY460vBBYDX+tmPCJd09gAr9wDtfvi27e/5ZZ+Ym6Pn+ifvw2yD3vMkMkw+RPdClMkkTqawC3wvDHGAj+31i4ABllrtwNYa7cbYwa29kBjzHxgPsDIkSMTELJIK3a8A4u/6+rXWYf9Wg8+EXIKO/Y8gydBQT94+5H49mgt5BYqgUta6WgCP9Vau81L0i8YY/7R0R14yX4BQHl5uWZQluRoqHXLf3oExpzR9ecZPRu+tvHI9hduh9d+2vXnFUmCDtXArbXbvOUu4AlgBrDTGDMEwFvuSlaQIu2Kegm8o6WSzorkQWOdu7gpkibaPQM3xvQCsqy1+731c4D/AJ4CbgDu9JZPJjNQkWbLfgXbVsS37dvilslM4ABP/TNkZcf/bNw5cNwFydmvSBs6UkIZBDxhXJeqCPBba+2zxpg3gUeNMfOATcDHkxemSAsvfNNdtMzrHd9eOh6KRyVnn8NOgj7DXL/ylg5VwY6VSuASiHYTuLV2PTC5lfZK4KxkBCXSpmgdnPwZ+Oh/pG6fY+bAv75/ZPvvr4GqjamLQ6QFjYUi4WKtq3d3tF93skXyYvV3kRTTV+klfW1/G5b8BGxTrM1fT1atu7Mi+VC9DR77f/HtJhtO/SIMOj6YuKRHUAKX9LXyj/DOI1AyJr699FgYcXIwMR1u9Bmw6TXYuiy+fc96KB6hBC5JpQQu6StaB/nF8MUV7d41MJM/0fqXe749WKUVSTrVwCV9pVOtu7MiebFxWESSRGfgkh5e+CbsXhvftv3t9Kl1d1YkH1b/GfZtjm8vGgQX3H1kX3KRLlACl+A1NsDf7oOiwVA0INZeWALHnB1cXN1xwsdgw0vxCbymCqq3wOlfdvVxkW5SApfg+bXiU26CU/452FgS5dzvHNn2zqPw+I3QqNKKJIZq4BI8f+ztsNa7O6p5rHFd3JTE0Bm4pFblB/Dnr7mBoXzNky6EtN7dUf4/qCe/cNgwAMb1GQ9ruUgCozNwSa0P/w7rXoC6A6723djgpi4bfQaMnBV0dMk1ZAqMPcslcv/YGxvca7Lq6aCjkxDSGbikVrTFuN1Frc4Bkrl6D4LrHj+y/d4TNIWbdInOwCW1muvdGV4u6QyNpyJdpDNwSZ5VT8OrP8LNyOfZv8MtM/2CZWfk5MMHf4VfnBnfPnQaXPDDYGKSUNAZuCTP6j/DznfdHJP+beAEOPlzkJ0bdHTp46RPwfDp8a/Twd3w1sNBRyZpTmfgkjzRWjcJwrWPBR1Jeps+z91a+uu34ZW73fC5bjIVkSPoDFySJ1
qnUklXRfLc0LlN0aAjkTSmM3BJjCX3w/KF8W37tkDpuGDiCTv/H99PTwHT4jwrvxiueRTy+wYSlqQXJXBJjDXPwoFdMPr0WNuA8TBec0V2ybFzYdtb0NQQa9u/Eza/Bns2wNApQUUmaUQJXBIjWgeDT4ArHwo6ksxQOg4+9kB827pF8JvL1WdcmqkGLokRPaR6d7L5r6/6jItHZ+DSea/eC28edna4fzsUjwomnp7CT+B//DTkFMTacwrg6t9D/7HBxCWBUQKXzvvgRXcWOO6c+PbJVwUTT08xeJLrQ19XHWs7VAWrn4Fdq5TAeyAlcOm8xnr3hZxL7w86kp4lkgfn3RnfVrHGJXCVVXok1cCl88I8V2WmaR5jXBc2eyKdgUvblj8EL9xO3HgmtfvgOHUPTAt+LfyZW+D521q0F8L1T6offoZTApe2bXnTnd1NvSa+fdIVwcQj8XoNgLPvgOptsbaaSnj3Mdi9Rgk8wymBS9uidW6i4fN/EHQk0hpj4LQvxbdVrHYJXHXxjKcauLQtWgvZGrs7VPyRHlUXz3g6A5eYbStg4cXxZ26NDTDkxOBiks7LKXTLJ2+Cp2+OtRf0g8+/BoUlwcQlCacELjG717o+xid9CgqKY+2jZwcWknRB0UA47y735Spf5To3wUb1NiXwDNLhBG6MyQaWAluttRcaY0qAR4AyYCNwpbW2KhlBSor4Z96zb4G+w4ONRbrOGDj5M/Fta553CVxllYzSmRr4zcCqFtu3AousteOARd62hElTU/zs6PU1rl19vDOP31+8/kD8e25t24+TtNahM3BjzHDgAuA7wL96zZcAc7z1hcBi4GuJDU+S6uenuynPDqcEnnlye7nlQxfHt4+YCfOeS308khAdLaH8CPgq0LtF2yBr7XYAa+12Y8zA1h5ojJkPzAcYOXJk1yOVxLIWdr0PZafDmDNi7cWjIK8ouLgkOYZMcXXxluOorHkeKlYd9SGS/tpN4MaYC4Fd1tplxpg5nd2BtXYBsACgvLxcn9fSRVPUTdk1+gyY/ZWgo5Fky44cWRevrYYd7wQTjyRER87ATwUuNsacD+QDfYwxvwF2GmOGeGffQ4BdyQxUEqC2GmyjW2+ud6uPd48VyXcXNWv2xCZONlmari1E2k3g1tqvA18H8M7Ab7HWXmuM+QFwA3Cnt3wyeWFKt73zKDx+45Htfm1Uep7cXoCFu0bHt59315Fn65KWutMP/E7gUWPMPGAT8PHEhCRJsWeDW577vdjZVlZEY5r0ZFOvc0m8KRpr+8u3Yr8rkvY6lcCttYtxvU2w1lYCZyU+JEmKaC1k5cCszwcdiaSLXv1hxmGfyl65R2OohIi+iZmponVwsCK2XbNb3QOlfZE8OLQH9m2JteX3hbzeR3+MBEYJPFM9dAlsWhLf1ntIMLFIeOT1hvefdDdffl+4ZR1EcoOLS1qlBJ6p9m1xX9JoOY73gAnBxSPhcNnPYPvbse31L8G7f3Tf4IxoDJV0owSeqaK1MGgiTLs+6EgkTIZMdjdfU9QlcI2hkpaUwDPFgV1waG9su+GQat7SfRFvyraKVVC3361nR6Df6FhvJgmMEngmqNkD90yEpob49rw+wcQjmSPf+x369WXx7ZctgMmfSH08EkcJPBMcqnLJe/qNMHKmazNZMPbMYOOS8Dvmo3DVb90nOoDGevjfz8GBHcHGJYASeGbw++2OPh0mXhJsLJJZIrlw3AWx7caoS+CqiacFJfAwstbNnlN/wG1XrHZL1bwl2bIjYLJh74ewdXmsfcB4DcsQACXwMNr+FiyYc2R7fnGKA5EeqaAfrPiNu/lO/ARcviC4mHooJfAwOrjbLc/5DvQ/xq3n9oLh04OLSXqOTz0TP17K89+I/U5KSimBh5Ff8x5zBgw+IdhYpOcZMN7dfH//sWriAVECDwNrYdvy2Bje271B+FXzlnQQyXPf/N3wSqytdBz0HhxcTD2EEngYbH4dHjz3sEajmrekh8L+8M
EiWHhhrG3oNJj/YnAx9RBK4GFQU+mWF90HJWPdemF/KBoQXEwivvN/ED9kwyt3Q5XGFE8FJfAw8GveI2fF1x5F0kFBsfsOgm/lo1Dxj8DC6UmUwNPVptdiY5tsWeaWmr9SwiCSD3UHYPWzsbYB46Fk9NEfI12iBJ6OKj84suZtsjXZrIRDrwFQvx9+12KslEGT4HN/Cy6mDKUEno5q97rlud+LjW1S2N99gUIk3Z32JRh3Dtgmt/3SXbBjZbAxZSgl8HTk96kdOAGGTQs2FpHOys6BoVNi232GwpY3AwsnkymBp4ttK1zpBGDXKrdUP2/JBJF8N27Pyj/G2oaXQ7+ywELKFErg6eLXl7vJZFsqGhhMLCKJ1Huw60n12LxY29gz4bongospQyiBp4vafXDSp2Dm5912Xm/oo0mIJQPMugnGnx+riT/9RaitDjamDKEEng4ao2AbXa1wwLFBRyOSWFlZUHpMbLugBPZuCi6eDKIEHpTqbbDuL26ck8Z616Z+3tITRPLg4C5YttBtmywYfx70Kg02rhBSAg/Kyz+EpQ/Et/UdHkwsIqnUdzgc2OlKKb7ZX4UzbwsuppBSAg9K3X7oOwI+/Zzbzs7RRUvpGc7+Fpz82dj2T2bEZpeSTlECD0pjnZuEoe+woCMRSa2srPjf+5yC2Hg/0ilK4KliLaz8Q2zmkoo1qnmLgOsnvv1tWHK/2y4ohslXgzGBhhUGSuCpsncTPH5jfNvxlwUTi0g6KRkDG16CrctibUOmwKCJgYUUFu0mcGNMPvAykOfd/4/W2tuNMSXAI0AZsBG40lpblbxQQ67+oFte+lPXJxYgr09w8Yiki+uecNeEANYvhj/cAA01gYYUFlkduE8dcKa1djIwBZhrjJkJ3AosstaOAxZ523I0fo2voMR9RCwodrVAkZ4uKzv2N1FY4tpUE++Qds/ArbUW8C8R53g3C1wCzPHaFwKLga8lPMIwq/rQ1b2thX2bXVskN9iYRNJZtnddaMXD8OEStz7ubBg6NbiY0liHauDGmGxgGXAM8BNr7evGmEHW2u0A1trtxphW+8AZY+YD8wFGjhyZmKjD4s1fuBm7fTmFUDwquHhE0l3xCMgtgrd/G2vb8gZc84fgYkpjHUrg1tpGYIoxphh4whgzqaM7sNYuABYAlJeX264EGVr1Na5kcstat22M+7goIq3rMxRu3RwbN2XhRe7vSFrVqSKstXYvrlQyF9hpjBkC4C13JTq40IvWuT6u2RF3U/IWaV9WVuxvJidf9fA2dKQXygCgwVq71xhTAJwNfB94CrgBuNNbPpnMQENjzXPw4d/d+tal6ust0h2RfNi7El643W0XDXQjdqqPONCxEsoQYKFXB88CHrXW/skYswR41BgzD9gEfDyJcYbH8/8Ou9dAtnexcsJFwcYjEmZDp8K6RfDaT92InU1ROO5C6KdrSdCxXijvAEdcArbWVgJnJSOoUGs4BJOvgst+FnQkIuF3xlfdDdyMPo/NU0mlBXVETrTGOpVNRJLBn2JQCbyZvkqfCK8vcGUTgEN7NZelSDL4f1cv/xCKBrkRPE/5Z9dzpYdSAu+upkb481cgUuB6nOT1hmHlQUclknlKx0HfkbDxVdfNsHavG0dlxo3tPjRTKYF3l/9xbs6tcNq/BBqKSEbrNwq+tNKt1+2H7w3v8eUU1cC7K1rnliqbiKSO6uGAzsC7bvH3oWojRA+5bV24FEmdrAhg4L3/hcr1rl/49Hkw7KSgI0spJfCuqN0Hi78L+cVuSNj+x8DQKUFHJdJzGOOGZd6x0tXEq7e4kyglcGmXXzY58xs9+gKKSKCubjHg1T3HQ7Q+uFgCohp4V/h1N9W9RdJDJK9H1sN1Bt4ZL93lPrL5s+sogYukh0g+fPg3eOQ6tz19HoyZE2hIqaAE3hkvfR/y+0KvgW6MBtW9RdLDhAvh/adg91qoXOe+k6EELs0ao24gnZM/GxubQUTSw0f+zd0A/ntG7DpVhl
MNvKMa/f7e6i4oktYiuT0mgesMvD3LH4L3n4TGBreturdIeovkw+bX4TdXgMmC074Eo04JOqqk0Bl4e5YthM1vQv0BGDkrY38RRDLGpCugZDQcqoJ1f4FVTwcdUdLoDLw90VooOy2+z6mIpK+Zn3M3gLvGZnT3Qp2Btydaq7q3SFhF8jO6Hq4z8NZsW+GmRmuKwt7NMHxG0BGJSFdE8mDNs/DgXLd+4b1uCNoMoTPw1qxfDBtfcQPmjJwJx18adEQi0hUnfRIGHe/G7V+/GDa/EXRECaUz8Nb4H7mufwqy9D9OJLRO/aK77dsK907MuHq4slNrorVuVnklb5HM0Dx+eGYNeKUzcJ+1bsbr3Wugejtk68KlSMbwOyK8eg+seMjVwT++0A1LG2JK4L7GBnj3MSgdDyNmwNBpQUckIomS2wtmfgGqNsCe9e7LeQ2HILcw6Mi6RQnc59fGpl0Pp9wUbCwikljGwNzvuvXXfgbPfs39zYc8gavI64tqrBORHsH/G8+A/uE6A//w7/D0zZqkQaSn8P/GHzgHsnPcKIYnfCzYmLpIZ+Cb33AXLoeVw5RrYeyZQUckIsk0Zg5MvRZGngzVW2HDy0FH1GU6A/c/Rl3xS8jKDjYWEUm+PkPgkp+49R+dEOpSis7Ao7WQlaPkLdITRfJD/eWennsG/pdvwVsPQ91+1b1FeqpIPqx+Bn54LJSMhU/+X6i+wNdzE7g/1smJV8KQyUFHIyJBmP0V+GAR7HgXNv09dF0L2/1XY4wZYYx50RizyhjznjHmZq+9xBjzgjFmrbfsl/xwEyha6xL3RfdB+aeDjkZEgjDxYpcD/F4oISundOSzQhT4srV2AjAT+IIxZiJwK7DIWjsOWORth0e0Tn2+RcTxc0FjuMZKaTeBW2u3W2uXe+v7gVXAMOASYKF3t4XApUmKMbEWXgTfHea6DkYKgo5GRNKBnwv+ayrcNQa2vx1sPB3UqWq9MaYMmAq8Dgyy1m4Hl+SBgUd5zHxjzFJjzNKKiopuhpsAm16H0mNh1k0w87NBRyMi6WDcOXD6l2HS5VBTCRVrgo6oQzqcwI0xRcBjwL9Ya6s7+jhr7QJrbbm1tnzAgAFdiTFxrIXGOvdmnfsdXbwUEadXfzjrm3CGVwkOSS28QwncGJODS94PW2sf95p3GmOGeD8fAuxKTogJpPFORKQtzeOGZ0gCN8YY4AFglbX2nhY/egq4wVu/AXgy8eElSFMT/Pgk+M4gt50Tnm5CIpJCOV4t/Jlb4Fv9YMXDwcbTjo70Az8VuA5YaYx5y2v7N+BO4FFjzDxgE/DxpESYCNFDULnOjYEw6jSYdEXQEYlIOsorcl+z37sJXrkHdr0fdERtajeBW2tfBY42bcVZiQ0nSfzSybHn6cKliLRt6rVu+caCtB8nJTzfGe2qpkY4VOXWc/SVeRHpoEgB1B+A+pqgIzmqzE/g98+CH3vTo+X0CjYWEQmP3EJ4+3fw3SHw6r1BR9OqzB4LxVr3hZ0xc2D8+TB+btARiUhYXHQfbF0Gr9wNu9cFHU2rMvsMvLEBsFB2Opz8GcjrHXREIhIWZafBqTdDYWnadivM3ATeGIV9m926hosVka6K5EPtPji4232qTyOZm8AfvS5W+85V7VtEuiivCNa9AD8YC6/dH3Q0cTI3gVd9CINPgIt/rH7fItJ1F9ztbjmFrn94GsncBB6thdLxMO16yO8TdDQiElaDT4Dp/89dQ2s4FHQ0cTIvgVvrzr7rD6r2LSKJE8lzIxXu2xp0JM0yL4G/8Qu470Q4sEO9TkQkcfL6wj/+BPdOhDXPBR0NkIkJvHqrm+vysgVufF8RkUS4fAGcd5dbr94WbCyezEvgjfXuYsPkT0BRwOOPi0jmGDQRTvDG7EuTqdcyK4FXb3M3jfctIsng55aqjVC9PdBQIJMSeN1+uG8yvP+/kN836GhEJBNF8iE7z/UHv28y1B0INJzMSeC1+9zHmuk3wlW/CzoaEclEWd
kw73kon+emZ6zr8OySyQkn0L0nkj9u74gZMODYYGMRkcw1dAoMn+7WAx4jJTMSeLTOjRoGkJ0bbCwikvn8WviWpYGWUTIjgb/6I3j8Rrde0C/QUESkBygsccvHb4Rnbw0sjMxI4DW7Ibc3fPo5N3SsiEgylc12+abfaKjZE1gYmTGhQ7TWjRg2cmbQkYhIT5CV5fJNYf9A6+DhPwOv2ujGPlHfbxFJtUg+HNgJ298OZPfhT+C/uQI2vOT+E4qIpFKv/rDzXfj5bDiwK+W7D38CP1QFx10I//Ro0JGISE9z8X/DnK+79dp9Kd99+BN4tB6KR0Gv0qAjEZGeJr8PDJzo1v3voqRQuBP4ptcgekj1bxEJjj/vwPrFbh6CFApvAj+4Gx48F5qiUDQw6GhEpKfy88/zt8HyX6d01+FN4P4YBGffATM+E2goItKDDZ0CX1zh1lNcBw9vAvfrTf3KXJ9MEZGglIyBrJyU9wkPZ+azFt57wq1nq/4tImkgkgc73knpzPXhTOC73oeXvu/Wew8ONhYREYA+Q2HdX+CFb6Zsl+0mcGPMg8aYXcaYd1u0lRhjXjDGrPWWqR1Bqtarf1+2AIZNS+muRURadeOLMGiSm1wmRTpyBv4rYO5hbbcCi6y144BF3nbq+HWm4pEp3a2IyFHlFUFen5T2B283gVtrXwYOH27rEmCht74QuDSxYbUZELz1sFtX/28RSSeRPNj4CuxYmZLddbUGPshaux3AWx61I7YxZr4xZqkxZmlFRUUXd9fC7jWw8g9uXfVvEUkn/Ua55TNfScnukn4R01q7wFpbbq0tHzBgQPef0J/94mMPuosGIiLp4oJ7YfiMlM3S09UEvtMYMwTAW6ZuGC6//q3RB0Uk3WRlQd9hbsLjVOyui497CrjBW78BeDIx4XTASm/UQX/8ARGRdBLJd6XercuTvquOdCP8HbAEGG+M2WKMmQfcCXzUGLMW+Ki3nXzROlj2K7feZ1hKdiki0iml49wyBf3B251SzVp79VF+dFaCY2lfwyG3PPtbUDwi5bsXEWnX6V+GdX+Fhpqk7ypc38T0+1fm9Q42DhGRthQUp6Q/eLgS+PrFbqn+3yKSzrJz3VRrFauTuptwJfA3f+GWpccGG4eISFsGn+CWS36S1N2EK4E3NsC4c2DEjKAjERE5utP/1Q117V+3S5KQJfB6dR8UkXCI5Ce9P3h4EvjBSjeMrBK4iIRBdi68/yQcqkraLsKTwNc+55bqPigiYeCPlvrBi0nbRXgSuF9LmjE/2DhERDri3O+4ZRKnWQtPAvf7VKoLoYiEgV/uVQIHVj3llqqBi0gY+Cebyx9K2i7Ck8D3bXVLJXARCYO8Pm65492279cN4UngTQ0w7XowJuhIRETal5XtxkXBJm8XSXvmRIvW6uxbRMIlkg9NUWiMJuXpw5HA62tcX0pdwBSRMPFzVsWqpDx9OBL4jnfcMrco2DhERDpj4PFuuf6lpDx9OBK43w1n9Oxg4xAR6Yyy09wySV+pD0kCVx9wEQkhP2claWzwcCRwf0xdXcQUkTAxBkwWrH0+KU8fjgTeWO+WfYYGG4eISGfZJtcTJQnCkcCjdYCB/OKgIxER6ZzjL4NofVKeOiQJvNbVkvQlHhEJm+w8iCZnYod2Z6VPB+9trmBUY4QbF7wWdCiSIfYcrGf1zv3MGtM/rv1QQyNvbd57RPvB+ijvbNl3RPumPTVs3XvoiPYl6yvJz8li6oh+zW0Wy2vr9zCjrITsrNjJyGsbKrGWuOeoqY/ydiv7W7K+kikjiinIyY5rGzugFwN757fbNnVkMfkR99jq2gbe21Ydt48l6yvJzjLMKCsBoNFa3tiwh5ljSjCY5vv0K8zhuMHuq+KrdlSzt6ah+Xl27q9lfcXB5u2KA3Ws23WgeXv97gPsrK5r3l6yvpKcbEP5qJIjjnHNzv1UHqxn1pj+za/f9LJ+RLKyWLK+ktKiXMYN7M2S9ZWAew2XbaqiPtrErD
H9WbK+kiwDJ49268P7FTC0bwFvbHTH9PqGPVgL08v68ebGKmaOKeG19XsYWVLIpj015GQbmiw0NrlvU2YZaLIwuE8+O6prKczNpqa+kbZ8N7KTf4psAmsTfhIaigS+P38I/8g9vvlFFOmu1Tv3A1BVU0+f/Jzm9rc27wVg94E6+hXmNre/s2Uf4BJ/34LY/bfudWdWNfVR8iKxpApQ29AU9zu7sfIgAMs3VTFtZIvE7t2l5X3f9vZXebCO4oLc5lj9GP0Ee7De1VY/qDhI/155R23b4z12xabYY9/bVt3qa9DYZJtj8V+PdbsOMKa0iEYv2Kqahub77K1pAKChsYksY1hf4Y5zf10DhTkR1u06ALh/GL1yI+ysdj0yDjU0kpud5T3W7fPwY6w86Lbroo3Nj3tzYxXTRhYDsPtAPWNKY69bY5OlPtoU93o22dhrsqXqEFuq3Hv2YWVN82v/5kY36cLanS7WTXtqmuNqyX+LdlS7rs3tJW+ACC4e6vZDfp92798ZoUjgM6+5HYBHA45DMsfl9/+N5Zv28u1LJ1HuJTSAGx58g5fWVPBv50/gI8cNbG6/7P6/sWLTXu64+HhmjY2dsZbd+n8A/OKG8rizXb/90c/Oam578NUN/Mef3ufamaO44+Lj27zvlT9fwhsb9nD7Rcdz6jGlACxevYtP/s+bnD6ulF/POxmAtTv389F7X6YwN7v58Wt27uece1+md16kue3Zd3fw2d8s45yJg1hwfTkAc3/0Mv/YsT/uNTg8lm//6X1++eoG5s8ew/zZY6mLNjL+G8/G3cd/zO/nzySSncWxt/2Z+sYm7rtqKmMHFHHiHc9RXRvlniunMH5w7+b7/+pT0ykuzI3b54urd/GpFsfo/+y3N87kiRVb+frjK7lq+gj+/cKJHH/7c82Pa/kcra3/7NqTOOXOvwJwzckjefj1TXx+zlj+/cn3ALh6xgh+98ZmPjF9BPcv/oBEWmlHcyUvxTpjJFA4auAiCWaO8lG2ufkon3RzI63/yRx+9t3mc3eAf9eWpRY/5pax5+ccud88P8ZW9tdaDK09x9Hu758xtyaSHb9fPw4/3rzDXrvW9tv88h+245aPNebI52pPy3219T609Vp0VR3ep5skjAuuBC49Uk62nwzj23t7pYTD/8aLvbLJ4fcvynMfYjuSUPwEl5fT/n1blml8fi7vnR/74Ozvd3Cf/BZtLgkNatHmP7ZXXuyxJb1yvZ8dPaP5CS2SFZ+M2zKkr9uv/89uUJ+8uG0/ltb+GfixtDxGf7/+P7P8nOzYP4sO8l+n4sKc5usH2Vmx5yjIcfsrzHU/Ky1K3JcG66x7L2sP1STsOX2hKKGIJNrdV07hgVc2MKXFRUaAm886hpwsw/QWZRWAr849jpJeeUwcEl/DvP+aaby4etcRCfxrc49jYO/4JHD+pMEs/7CK62eVxbV/+9JJNDQ2xbXdcu54ivIjTB5e3Nx20qh+XD5tGJ+fM7a5rbQoj6tnjOSKacOa2wb0zuPqGSP42EnDm9tOOaaUy6cO4wsfOaa57bYLJvA/f9vI2IG9mtu+ccEECnNjaeHqGSPZUnWIiybHvoNx00eOYdKw2Otw/zXT+Mf26ubtH358Mk+s2EqpV3//3uUn8sdlm5v/oTxww3R3cdFLyLedP4E+BW6f5WX9uHzqMD7/EXeMP7v2JFZu3QvARycM4rKpw/j0qaMB+Nycsc3XEv7z0kk0eq/hXR87kSqvdv7NCyeSl5NFYW42nzyljLMnDGJU/0J2H6hn7qTBNFlLQ2MTZ08YRFVNPVdMG86mPTXMPX4w722rpk9BhMYmV683uH++6ysOctzg3vz9g0omDOnNzuo69hysJzvLsL+2gd0H6tlzsJ7dB+rYXxulDvePcntlFaMHk1DG2tRdGCwvL7dLly5N2f5ERIL21qLfM+WVz7Dm4qc4dtoZXXoOY8wya2354e0qoYiIJFF2TgEA0frE9wVXAhcRSaJInisdNS
qBi4iES1H/oSwvmk1B34Ht37mTunUR0xgzF7gPyAZ+aa29MyFRiYhkiBHHnMCIW55OynN3+QzcGJMN/AQ4D5gIXG2MmZiowEREpG3dKaHMANZZa9dba+uB3wOXJCYsERFpT3cS+DBgc4vtLV5bHGPMfGPMUmPM0oqKim7sTkREWupOAm/tK1lHdCq31i6w1pZba8sHDBjQjd2JiEhL3UngW4ARLbaHA9u6F46IiHRUdxL4m8A4Y8xoY0wucBXwVGLCEhGR9nS5G6G1NmqMuQl4DteN8EFr7XsJi0xERNrUrX7g1tpngGcSFIuIiHRCSgezMsZUAB928eGlwO4EhhMGOuaeQcfcM3TnmEdZa4/oBZLSBN4dxpilrY3Glcl0zD2DjrlnSMYxaywUEZGQUgIXEQmpMCXwBUEHEAAdc8+gY+4ZEn7MoamBi4hIvDCdgYuISAtK4CIiIRWKBG6MmWuMWW2MWWeMuTXoeBLFGLPRGLPSGPOWMWap11ZijHnBGLPWW/Zrcf+ve6/BamPMucFF3nHGmAeNMbuMMe+2aOv0MRpjTvJeq3XGmP8yxrQ2mFpaOMox32GM2eq9128ZY85v8bNMOOYRxpgXjTGrjDHvGWNu9toz9r1u45hT915ba9P6hvua/gfAGCAXeBuYGHRcCTq2jUDpYW13Abd667cC3/fWJ3rHngeM9l6T7KCPoQPHOBuYBrzbnWME3gBm4UbB/DNwXtDH1sljvgO4pZX7ZsoxDwGmeeu9gTXesWXse93GMafsvQ7DGXhPmzjiEmCht74QuLRF+++ttXXW2g3AOtxrk9astS8Dew5r7tQxGmOGAH2stUus+21/qMVj0s5RjvloMuWYt1trl3vr+4FVuPkBMva9buOYjybhxxyGBN6hiSNCygLPG2OWGWPme22DrLXbwf2CAP5MqJn0OnT2GId564e3h81Nxph3vBKLX0rIuGM2xpQBU4HX6SHv9WHHDCl6r8OQwDs0cURInWqtnYabV/QLxpjZbdw3k18H39GOMROO/afAWGAKsB2422vPqGM2xhQBjwH/Yq2tbuuurbSF8rhbOeaUvddhSOAZO3GEtXabt9wFPIEriez0PlLhLXd5d8+k16Gzx7jFWz+8PTSstTuttY3W2ibgF8TKXxlzzMaYHFwie9ha+7jXnNHvdWvHnMr3OgwJPCMnjjDG9DLG9PbXgXOAd3HHdoN3txuAJ731p4CrjDF5xpjRwDjchY8w6tQxeh+99xtjZnpX569v8ZhQ8JOY5zLcew0ZcsxejA8Aq6y197T4Uca+10c75pS+10Ffye3g1d7zcVd4PwBuCzqeBB3TGNwV6beB9/zjAvoDi4C13rKkxWNu816D1aTplflWjvN3uI+RDbgzjXldOUag3PtD+AD4b7xvEafj7SjH/GtgJfCO94c8JMOO+TTcx/53gLe82/mZ/F63ccwpe6/1VXoRkZAKQwlFRERaoQQuIhJSSuAiIiGlBC4iElJK4CIiIaUELiISUkrgIiIh9f8BtM+GU/DGE/MAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"# Stored entries per column (np.diff of the CSC column pointers) for A and L.\n",
"ax.plot(np.diff(A_indptr), label=\"A\")\n",
"ax.plot(np.diff(L_indptr), label=\"L\")\n",
"ax.set(xlabel=\"column\", ylabel=\"stored entries per column\")\n",
"ax.legend();\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed1cce62-346e-4710-ad6f-64e2246cbdf0",
"metadata": {
"id": "ed1cce62-346e-4710-ad6f-64e2246cbdf0",
"outputId": "94011e5d-1b3e-4cb1-833a-bcc1fe2b0e7e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of non-zeros is 87025 (fill in of 79625), \n",
"which is less than the unpermuted matrix, which had 125049 non-zeros.\n"
]
}
],
"source": [
"fill_in = len(L_x) - len(A_x)\n",
"nnz_rcm = len(L_x)\n",
"message = f\"\"\"Number of non-zeros is {nnz_rcm} (fill in of {fill_in}), \n",
"which is less than the unpermuted matrix, which had {nnz} non-zeros.\"\"\"\n",
"print(message)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4212f971",
"metadata": {
"id": "4212f971"
},
"outputs": [],
"source": [
"L_indices, L_indptr = _symbolic_factor_csc(A_indices, A_indptr)\n",
"L_x0 = _deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr)\n",
"\n",
"# Reuse the jitted copy compiled earlier; this relies on the permutation not\n",
"# changing the static shapes it was compiled with (TODO confirm m_A / m_L).\n",
"L_x_mat, L_indptr_mat = partial_deep_copy_csc(\n",
"    A_indices, A_indptr, A_x, L_indices, L_indptr\n",
")\n",
"# Keep only real slots; padding slots are marked with -1.\n",
"is_stored = L_indptr_mat != -1\n",
"L_x = L_x_mat[is_stored]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ae429aa",
"metadata": {
"id": "0ae429aa"
},
"outputs": [],
"source": [
"np.testing.assert_array_almost_equal(L_x, L_x0, decimal=6)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "adeb463b",
"metadata": {
"id": "adeb463b"
},
"outputs": [],
"source": [
"# Reference factorisation: copy A into L's pattern, then run the reference Cholesky.\n",
"L_x_chol0 = _sparse_cholesky_csc_impl(\n",
"    L_indices,\n",
"    L_indptr,\n",
"    _deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr),\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c905e2a5",
"metadata": {
"id": "c905e2a5"
},
"outputs": [],
"source": [
"n = len(L_indptr) - 1\n",
"# descendant[row] collects (column, flat position) for every stored entry of L\n",
"# in that row, skipping each column's first stored entry (presumably the diagonal).\n",
"descendant = [[] for _ in range(n)]\n",
"for col in range(n):\n",
"    first, stop = L_indptr[col], L_indptr[col + 1]\n",
"    for pos in range(first + 1, stop):\n",
"        descendant[L_indices[pos]].append((col, pos))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "61f2a289",
"metadata": {
"id": "61f2a289"
},
"outputs": [],
"source": [
"# lax.scan construction of the descendant table; compared against the\n",
"# list-based `descendant` in the next cell.\n",
"counter = jnp.zeros(L_indptr.shape[0] - 1, dtype=np.int32)\n",
"m_L = np.max(np.diff(L_indptr))\n",
"# n+1 rows: row n is a sink for writes coming from padding slots (target -1).\n",
"descendant_arr = jnp.full((n + 1, m_L - 1, 2), -1, dtype=np.int32)\n",
"\n",
"\n",
"def one_step(carry, column):\n",
"    desc, cnt = carry\n",
"    col, ptr_row, rows = column\n",
"    # Flat positions of the column's slots after the first one; -1 is padding.\n",
"    below_ptr = ptr_row[1:]\n",
"    stored = below_ptr != -1\n",
"    target = jnp.where(stored, rows[1:], below_ptr)\n",
"    entries = jnp.where(\n",
"        stored,\n",
"        jnp.stack([jnp.full_like(below_ptr, col), below_ptr]),\n",
"        jnp.full((2, m_L - 1), -1, dtype=below_ptr.dtype),\n",
"    )\n",
"    # Append (col, flat position) to the next free slot of each target row.\n",
"    desc = desc.at[target, cnt[target]].set(entries.T)\n",
"    cnt = cnt.at[target].add(stored)\n",
"    return (desc, cnt), None\n",
"\n",
"\n",
"(descendant_arr, counter), _ = jax.lax.scan(\n",
"    one_step,\n",
"    (descendant_arr, counter),\n",
"    (jnp.arange(n), L_indptr_mat, L_indices[L_indptr_mat]),\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3b47f12c",
"metadata": {
"id": "3b47f12c"
},
"outputs": [],
"source": [
"for col in range(1, n):\n",
"    expected = jnp.asarray(descendant[col])\n",
"    actual = descendant_arr[col, : len(descendant[col])]\n",
"    np.testing.assert_array_equal(expected, actual)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6da62499",
"metadata": {
"id": "6da62499"
},
"outputs": [],
"source": [
"def _sparse_cholesky_csc_impl_jax(m_L, L_indices, L_indptr_mat, L_x):\n",
"    \"\"\"Column-by-column (left-looking) sparse Cholesky written with lax.scan.\n",
"\n",
"    Args:\n",
"        m_L: largest number of stored entries in any column of L; fixes the\n",
"            padded width of every per-column array (bound via partial below\n",
"            so it stays static under jit).\n",
"        L_indices: flat CSC row indices of L.\n",
"        L_indptr_mat: per-column flat positions into L_indices / L_x, with -1\n",
"            marking padding slots (presumably shape (n, m_L) -- TODO confirm).\n",
"        L_x: flat values of A already copied into L's sparsity pattern.\n",
"\n",
"    Returns:\n",
"        (L_x_chol, descendant_arr): per-column Cholesky values padded like\n",
"        L_indptr_mat (padding slots are not meaningful; mask them with\n",
"        L_indptr_mat != -1) and the (n + 1, m_L - 1, 2) table of\n",
"        (column, slot) pairs accumulated during the scan.\n",
"\n",
"    Assumes the first stored entry of every column is its diagonal.\n",
"    \"\"\"\n",
"    n = L_indptr_mat.shape[0]\n",
"    # counter[i]: number of (column, slot) pairs recorded so far for row i.\n",
"    counter = jnp.zeros(n, dtype=np.int32)\n",
"    # n+1 to store placeholder stuff\n",
"    # (row n receives the writes from padding slots, whose target row is -1).\n",
"    descendant_arr = jnp.ones((n + 1, m_L - 1, 2), dtype=np.int32) * -1\n",
"    # Padding positions (-1) gather the last flat element; those values are\n",
"    # placeholders only.\n",
"    L_indices_mat = L_indices[L_indptr_mat]\n",
"    L_x_mat = L_x[L_indptr_mat]\n",
"    L_x_chol = jnp.zeros_like(L_x_mat)\n",
"\n",
"    # No-op used for empty descendant slots (column index -1).\n",
"    def _inner_none_op(carry, _):\n",
"        return carry, None\n",
"\n",
"    # Subtract the contribution of one earlier column k from the working\n",
"    # column `tmp`; bebe = (k, slot in column k that holds row j).\n",
"    def _inner_one_step(carry, bebe):\n",
"        (L_x_chol, L_indices_slice, tmp) = carry\n",
"        k = bebe[0]\n",
"        L_indices_slice_k = L_indices_mat[k]\n",
"        L_x_chol_slice_k = L_x_chol[k]\n",
"        Ljk = L_x_chol_slice_k[bebe[1]]\n",
"        # Slot of column k holding the working column's first row (j).\n",
"        pad = jnp.argwhere(L_indices_slice_k == L_indices_slice[0], size=1)[0][0]\n",
"        # Rows of column k from that slot on, shifted to the front; earlier\n",
"        # slots and padding become -1 so they never match.\n",
"        L_indices_slice_k_pad = jnp.roll(\n",
"            jnp.where(\n",
"                (jnp.arange(m_L) < pad) | (L_indptr_mat[k] == -1),\n",
"                jnp.ones_like(L_indices_slice_k) * -1,\n",
"                L_indices_slice_k,\n",
"            ),\n",
"            -pad,\n",
"        )\n",
"        # Values of column k shifted the same way (earlier slots zeroed).\n",
"        L_x_chol_slice_k_pad = jnp.roll(\n",
"            jnp.where(\n",
"                jnp.arange(m_L) < pad,\n",
"                jnp.zeros_like(L_x_chol_slice_k),\n",
"                L_x_chol_slice_k,\n",
"            ),\n",
"            -pad,\n",
"        )\n",
"        # Slots of the working column whose row also occurs in column k.\n",
"        _update_idx = jnp.nonzero(\n",
"            jnp.in1d(L_indices_slice, L_indices_slice_k_pad), size=m_L\n",
"        )[0]\n",
"        # nonzero pads its fixed-size result with 0; entries that do not\n",
"        # strictly increase are that padding and are sent to index -1, the\n",
"        # extra placeholder slot appended to tmp in one_step.\n",
"        update_idx = jnp.where(\n",
"            jnp.concatenate([jnp.ones([1]), jnp.diff(_update_idx)]) > 0,\n",
"            _update_idx,\n",
"            jnp.ones_like(_update_idx) * -1,\n",
"        )\n",
"        tmp = tmp.at[update_idx].add(-Ljk * L_x_chol_slice_k_pad)\n",
"        return (L_x_chol, L_indices_slice, tmp), None\n",
"\n",
"    # Skip unused descendant slots, which hold column -1.\n",
"    def inner_one_step(carry, bebe):\n",
"        return jax.lax.cond(bebe[0] != -1, _inner_one_step, _inner_none_op, carry, bebe)\n",
"\n",
"    # Column j: record j as a descendant of its below-diagonal rows, apply the\n",
"    # updates from earlier columns, then take the square root and scale.\n",
"    def one_step(carry, x):\n",
"        (descendant_arr, counter, L_x_chol) = carry\n",
"        (j, L_indptr_slice, L_indices_slice, _tmp) = x\n",
"        # Append 1 more value to the end to account for the placeholder computation\n",
"        tmp = jnp.concatenate([_tmp, jnp.zeros([1])])\n",
"        idx_ptr = L_indptr_slice[1:]\n",
"        idx = jnp.where(idx_ptr != -1, L_indices_slice[1:], idx_ptr)\n",
"        # (j, slot) for each stored below-diagonal slot, (-1, -1) for padding.\n",
"        to_update = jnp.where(\n",
"            idx_ptr != -1,\n",
"            jnp.vstack([jnp.ones_like(idx_ptr) * j, jnp.arange(1, m_L)]),\n",
"            jnp.ones((2, m_L - 1), dtype=idx_ptr.dtype) * -1,\n",
"        )\n",
"        descendant_arr_next = descendant_arr.at[idx, counter[idx]].set(to_update.T)\n",
"        counter_next = counter.at[idx].add(idx_ptr != -1)\n",
"\n",
"        # Earlier columns k with a stored entry in row j.\n",
"        bebe = descendant_arr_next[j]\n",
"        (L_x_chol, L_indices_slice, tmp), _ = jax.lax.scan(\n",
"            inner_one_step, (L_x_chol, L_indices_slice, tmp), bebe\n",
"        )\n",
"\n",
"        # Slot 0 is the diagonal; the other slots are divided by it and the\n",
"        # trailing placeholder slot is dropped.\n",
"        diag = jnp.sqrt(tmp[:1])\n",
"        L_x_out = jnp.concatenate([diag, tmp[1:-1] / diag])\n",
"        L_x_chol = L_x_chol.at[j].set(L_x_out)\n",
"        return (descendant_arr_next, counter_next, L_x_chol), None\n",
"\n",
"    (descendant_arr, counter, L_x_chol), _ = jax.lax.scan(\n",
"        one_step,\n",
"        (descendant_arr, counter, L_x_chol),\n",
"        (jnp.arange(n), L_indptr_mat, L_indices_mat, L_x_mat),\n",
"    )\n",
"    return L_x_chol, descendant_arr\n",
"\n",
"\n",
"m_L = np.max(np.diff(L_indptr))\n",
"# Bind m_L with partial so it is a Python constant (static) inside jit.\n",
"partial_sparse_cholesky_csc_impl = jax.jit(partial(_sparse_cholesky_csc_impl_jax, m_L))\n",
"L_x_chol2, descendant_arr = partial_sparse_cholesky_csc_impl(\n",
"    L_indices, L_indptr_mat, L_x\n",
")\n",
"# Drop padding slots to get the flat CSC Cholesky values.\n",
"L_x_chol_ = L_x_chol2[L_indptr_mat != -1]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42c14161-bd5a-4b02-a9df-3434d028cccd",
"metadata": {
"id": "42c14161-bd5a-4b02-a9df-3434d028cccd"
},
"outputs": [],
"source": [
"np.testing.assert_array_almost_equal(L_x_chol0, L_x_chol_, decimal=6)\n"
]
},
{
"cell_type": "markdown",
"id": "fff3412d",
"metadata": {
"id": "fff3412d"
},
"source": [
"Putting everything together in a Python class.\n",
"\n",
"Why a Python class? There are a number of quantities (the symbolic factorisation of $A$ and the padded sizes `m_A` / `m_L`) that we want to compute once and store as static parameters, and that is much easier to do with a class.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00ea1cb1",
"metadata": {
"id": "00ea1cb1"
},
"outputs": [],
"source": [
"def _symbolic_factor_csc(A_indices, A_indptr):\n",
"    \"\"\"Symbolic Cholesky: sparsity pattern of L from the pattern of A.\n",
"\n",
"    Assumes A_indices and A_indptr index the lower triangle of $A$ ONLY\n",
"    (presumably with sorted rows and the diagonal first, since rows[1] is\n",
"    taken as the parent). Returns (L_indices, L_indptr) in CSC form.\n",
"    \"\"\"\n",
"    num_cols = len(A_indptr) - 1\n",
"    patterns = []\n",
"    kids = [[] for _ in range(num_cols)]\n",
"    for col in range(num_cols):\n",
"        rows = A_indices[A_indptr[col] : A_indptr[col + 1]]\n",
"        if kids[col]:\n",
"            # Merge in the rows below `col` of every child column.\n",
"            inherited = [patterns[c][patterns[c] > col] for c in kids[col]]\n",
"            rows = np.unique(np.concatenate([rows, *inherited]))\n",
"        patterns.append(rows)\n",
"        if len(rows) > 1:\n",
"            # Elimination-tree parent: the first row after the diagonal.\n",
"            kids[rows[1]].append(col)\n",
"\n",
"    L_indptr = np.zeros(num_cols + 1, dtype=int)\n",
"    L_indptr[1:] = np.cumsum([len(r) for r in patterns])\n",
"    return np.concatenate(patterns), L_indptr\n",
"\n",
"\n",
"@partial(jax.jit, static_argnums=(0,))\n",
"def fix_length_arange(max_length, start, stop):\n",
"    # Positions start, ..., stop - 1 padded with -1 up to the static length\n",
"    # max_length, so ragged CSC columns fit in a fixed-shape array.\n",
"    offsets = np.arange(max_length)\n",
"    positions = offsets + start\n",
"    return jnp.where(offsets < (stop - start), positions, jnp.full_like(positions, -1))\n",
"\n",
"\n",
"def _deep_copy_csc_vmap(m_A, m_L, A_indices, A_indptr, A_x, L_indices, L_indptr):\n",
"    \"\"\"Scatter A's values into L's sparsity pattern, one column per vmap lane.\n",
"\n",
"    m_A / m_L are the maximum entries per column of A / L (static sizes).\n",
"    Returns (L_x_mat, L_indptr_mat): per-column values padded to width m_L and\n",
"    the matching flat positions, with -1 marking padding slots.\n",
"    \"\"\"\n",
"    A_indptr_mat = jax.vmap(partial(fix_length_arange, m_A))(\n",
"        A_indptr[:-1], A_indptr[1:]\n",
"    )\n",
"    L_indptr_mat = jax.vmap(partial(fix_length_arange, m_L))(\n",
"        L_indptr[:-1], L_indptr[1:]\n",
"    )\n",
"\n",
"    def row_fun(row_idx_A, row_idxptr_A, row_idx_L, row_val_A):\n",
"        row_idx_A_ = jnp.where(\n",
"            row_idxptr_A != -1, row_idx_A, jnp.ones_like(row_idx_A) * -1\n",
"        )\n",
"        out = jnp.zeros(m_L)\n",
"        # Slots of L's column whose row also appears in A's column. Unused\n",
"        # outputs are filled with m_L (out of bounds) and dropped by the\n",
"        # scatter; the default fill of 0 would send several updates to slot 0\n",
"        # (the column's first entry) and could overwrite it with a padding value.\n",
"        copy_idx = jnp.nonzero(\n",
"            jnp.in1d(row_idx_L, row_idx_A_), size=m_A, fill_value=m_L\n",
"        )[0]\n",
"        out = out.at[copy_idx].set(row_val_A, mode=\"drop\")\n",
"        return out\n",
"\n",
"    L_x_mat = jax.vmap(row_fun)(\n",
"        A_indices[A_indptr_mat],\n",
"        A_indptr_mat,\n",
"        L_indices[L_indptr_mat],\n",
"        A_x[A_indptr_mat],\n",
"    )\n",
"    return L_x_mat, L_indptr_mat\n",
"\n",
"\n",
"def _deep_copy_csc_scan(m_L, A_indices, A_indptr, A_x, L_indices, L_indptr):\n",
"    # Scatter A's values into L's sparsity pattern by locating every stored entry\n",
"    # of A inside L with a lax.scan. Returns (L_x, update_index).\n",
"    # Pad value otherwise out of bound indexing result is weird.\n",
"    padded_rows = jnp.pad(L_indices, [0, m_L], constant_values=-1)\n",
"    slot = jnp.arange(m_L)\n",
"\n",
"    def locate(_, flat_pos):\n",
"        row = A_indices[flat_pos]\n",
"        # First column pointer greater than flat_pos -> entry is in column col - 1.\n",
"        col = jnp.argwhere(flat_pos < A_indptr, size=1)[0][0]\n",
"        col_start = L_indptr[col - 1]\n",
"        col_len = L_indptr[col] - col_start\n",
"        window = jax.lax.dynamic_slice(padded_rows, [col_start], [m_L])\n",
"        window = jnp.where(slot < col_len, window, jnp.full_like(window, -1))\n",
"        offset = jnp.argwhere(row == window, size=1)[0][0]\n",
"        return None, col_start + offset\n",
"\n",
"    _, update_index = jax.lax.scan(locate, None, jnp.arange(A_indices.shape[-1]))\n",
"\n",
"    L_x = jnp.zeros_like(L_indices, dtype=A_x.dtype).at[update_index].set(A_x)\n",
"    return L_x, update_index\n",
"\n",
"\n",
"def _sparse_cholesky_csc_impl_jax(m_L, L_indices_mat, L_indptr_mat, L_x_mat):\n",
"    \"\"\"Column-by-column (left-looking) sparse Cholesky written with lax.scan.\n",
"\n",
"    Args:\n",
"        m_L: largest number of stored entries in any column of L (static).\n",
"        L_indices_mat: row indices of L gathered per column with L_indptr_mat.\n",
"        L_indptr_mat: per-column flat positions into L, -1 for padding slots.\n",
"        L_x_mat: values of A in L's pattern, gathered per column the same way.\n",
"\n",
"    Returns:\n",
"        (L_x_chol, descendant_arr): per-column Cholesky values padded like\n",
"        L_indptr_mat (mask with L_indptr_mat != -1) and the\n",
"        (n + 1, m_L - 1, 2) table of (column, slot) pairs built during the scan.\n",
"\n",
"    Assumes the first stored entry of every column is its diagonal.\n",
"    \"\"\"\n",
"    n = L_indptr_mat.shape[0]\n",
"    # counter[i]: number of (column, slot) pairs recorded so far for row i.\n",
"    counter = jnp.zeros(n, dtype=np.int32)\n",
"    # n+1 to store placeholder stuff\n",
"    # (row n receives the writes from padding slots, whose target row is -1).\n",
"    descendant_arr = jnp.ones((n + 1, m_L - 1, 2), dtype=np.int32) * -1\n",
"    L_x_chol = jnp.zeros_like(L_x_mat)\n",
"\n",
"    # No-op used for empty descendant slots (column index -1).\n",
"    def _inner_none_op(carry, _):\n",
"        return carry, None\n",
"\n",
"    # Subtract the contribution of one earlier column k from the working\n",
"    # column `tmp`; bebe = (k, slot in column k that holds row j).\n",
"    def _inner_one_step(carry, bebe):\n",
"        (L_x_chol, L_indices_slice, tmp) = carry\n",
"        k = bebe[0]\n",
"        L_indices_slice_k = L_indices_mat[k]\n",
"        L_x_chol_slice_k = L_x_chol[k]\n",
"        Ljk = L_x_chol_slice_k[bebe[1]]\n",
"        # Slot of column k holding the working column's first row (j).\n",
"        pad = jnp.argwhere(L_indices_slice_k == L_indices_slice[0], size=1)[0][0]\n",
"        # Rows of column k from that slot on, shifted to the front; earlier\n",
"        # slots and padding become -1 so they never match.\n",
"        L_indices_slice_k_pad = jnp.roll(\n",
"            jnp.where(\n",
"                (jnp.arange(m_L) < pad) | (L_indptr_mat[k] == -1),\n",
"                jnp.ones_like(L_indices_slice_k) * -1,\n",
"                L_indices_slice_k,\n",
"            ),\n",
"            -pad,\n",
"        )\n",
"        # Values of column k shifted the same way (earlier slots zeroed).\n",
"        L_x_chol_slice_k_pad = jnp.roll(\n",
"            jnp.where(\n",
"                jnp.arange(m_L) < pad,\n",
"                jnp.zeros_like(L_x_chol_slice_k),\n",
"                L_x_chol_slice_k,\n",
"            ),\n",
"            -pad,\n",
"        )\n",
"        # Slots of the working column whose row also occurs in column k.\n",
"        _update_idx = jnp.nonzero(\n",
"            jnp.in1d(L_indices_slice, L_indices_slice_k_pad), size=m_L\n",
"        )[0]\n",
"        # nonzero pads its fixed-size result with 0; entries that do not\n",
"        # strictly increase are that padding and are sent to index -1, the\n",
"        # extra placeholder slot appended to tmp in one_step.\n",
"        update_idx = jnp.where(\n",
"            jnp.concatenate([jnp.ones([1]), jnp.diff(_update_idx)]) > 0,\n",
"            _update_idx,\n",
"            jnp.ones_like(_update_idx) * -1,\n",
"        )\n",
"        tmp = tmp.at[update_idx].add(-Ljk * L_x_chol_slice_k_pad)\n",
"        return (L_x_chol, L_indices_slice, tmp), None\n",
"\n",
"    # Skip unused descendant slots, which hold column -1.\n",
"    def inner_one_step(carry, bebe):\n",
"        return jax.lax.cond(bebe[0] != -1, _inner_one_step, _inner_none_op, carry, bebe)\n",
"\n",
"    # Column j: record j as a descendant of its below-diagonal rows, apply the\n",
"    # updates from earlier columns, then take the square root and scale.\n",
"    def one_step(carry, x):\n",
"        (descendant_arr, counter, L_x_chol) = carry\n",
"        (j, L_indptr_slice, L_indices_slice, _tmp) = x\n",
"        # Append 1 more value to the end to account for the placeholder computation\n",
"        tmp = jnp.concatenate([_tmp, jnp.zeros([1])])\n",
"        idx_ptr = L_indptr_slice[1:]\n",
"        idx = jnp.where(idx_ptr != -1, L_indices_slice[1:], idx_ptr)\n",
"        # (j, slot) for each stored below-diagonal slot, (-1, -1) for padding.\n",
"        to_update = jnp.where(\n",
"            idx_ptr != -1,\n",
"            jnp.vstack([jnp.ones_like(idx_ptr) * j, jnp.arange(1, m_L)]),\n",
"            jnp.ones((2, m_L - 1), dtype=idx_ptr.dtype) * -1,\n",
"        )\n",
"        descendant_arr_next = descendant_arr.at[idx, counter[idx]].set(to_update.T)\n",
"        counter_next = counter.at[idx].add(idx_ptr != -1)\n",
"\n",
"        # Earlier columns k with a stored entry in row j.\n",
"        bebe = descendant_arr_next[j]\n",
"        (L_x_chol, L_indices_slice, tmp), _ = jax.lax.scan(\n",
"            inner_one_step, (L_x_chol, L_indices_slice, tmp), bebe\n",
"        )\n",
"\n",
"        # Slot 0 is the diagonal; the other slots are divided by it and the\n",
"        # trailing placeholder slot is dropped.\n",
"        diag = jnp.sqrt(tmp[:1])\n",
"        L_x_out = jnp.concatenate([diag, tmp[1:-1] / diag])\n",
"        L_x_chol = L_x_chol.at[j].set(L_x_out)\n",
"        return (descendant_arr_next, counter_next, L_x_chol), None\n",
"\n",
"    (descendant_arr, counter, L_x_chol), _ = jax.lax.scan(\n",
"        one_step,\n",
"        (descendant_arr, counter, L_x_chol),\n",
"        (jnp.arange(n), L_indptr_mat, L_indices_mat, L_x_mat),\n",
"    )\n",
"    return L_x_chol, descendant_arr\n",
"\n",
"\n",
"class sparse_csc:\n",
"    \"\"\"Jitted sparse Cholesky for one fixed lower-triangular CSC pattern.\n",
"\n",
"    The symbolic factorisation and the padded sizes m_A / m_L are computed once\n",
"    here from the pattern of A; `chol` then maps A's values to L's values.\n",
"\n",
"    Args:\n",
"        A_indices, A_indptr: CSC pattern of the lower triangle of A.\n",
"        use_scan: copy A into L's pattern with _deep_copy_csc_scan instead of\n",
"            _deep_copy_csc_vmap.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, A_indices, A_indptr, use_scan=False):\n",
"        self.use_scan = use_scan\n",
"        L_indices, L_indptr = _symbolic_factor_csc(A_indices, A_indptr)\n",
"        # Size of the memory (max entries per column of A and L); these fix the\n",
"        # padded array widths, so they are kept as static Python ints.\n",
"        self.m_A = int(np.max(np.diff(A_indptr)))\n",
"        self.m_L = int(np.max(np.diff(L_indptr)))\n",
"\n",
"        self.A_indices = jnp.asarray(A_indices)\n",
"        self.A_indptr = jnp.asarray(A_indptr)\n",
"        self.L_indices = jnp.asarray(L_indices)\n",
"        self.L_indptr = jnp.asarray(L_indptr)\n",
"        # Per-column flat positions into L, padded with -1 to this pattern's own\n",
"        # m_L (not whatever m_L happens to exist at notebook level).\n",
"        self.L_indptr_mat = jax.vmap(partial(fix_length_arange, self.m_L))(\n",
"            L_indptr[:-1], L_indptr[1:]\n",
"        )\n",
"        self.mask = self.L_indptr_mat != -1\n",
"\n",
"    # self is a static argument, so chol is traced and compiled per instance.\n",
"    @partial(jax.jit, static_argnums=(0,))\n",
"    def chol(self, A_x):\n",
"        \"\"\"Values of L, in the CSC order of self.L_indices, with A = L L^T.\n",
"\n",
"        A_x holds the values of A's lower triangle in the CSC order passed to\n",
"        __init__.\n",
"        \"\"\"\n",
"        if self.use_scan:\n",
"            L_x, _ = _deep_copy_csc_scan(\n",
"                self.m_L,\n",
"                self.A_indices,\n",
"                self.A_indptr,\n",
"                A_x,\n",
"                self.L_indices,\n",
"                self.L_indptr,\n",
"            )\n",
"            L_x_mat = L_x[self.L_indptr_mat]\n",
"        else:\n",
"            L_x_mat, _ = _deep_copy_csc_vmap(\n",
"                self.m_A,\n",
"                self.m_L,\n",
"                self.A_indices,\n",
"                self.A_indptr,\n",
"                A_x,\n",
"                self.L_indices,\n",
"                self.L_indptr,\n",
"            )\n",
"        L_indices_mat = self.L_indices[self.L_indptr_mat]\n",
"        L_x_chol, _ = _sparse_cholesky_csc_impl_jax(\n",
"            self.m_L, L_indices_mat, self.L_indptr_mat, L_x_mat\n",
"        )\n",
"        return L_x_chol[self.mask]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "67c7dbfc",
"metadata": {
"id": "67c7dbfc"
},
"outputs": [],
"source": [
"A_perm = A[perm[:, None], perm]\n",
"A_perm_lower = sparse.tril(A_perm, format=\"csc\")\n",
"A_indices, A_indptr, A_x = (\n",
"    A_perm_lower.indices,\n",
"    A_perm_lower.indptr,\n",
"    A_perm_lower.data,\n",
")\n",
"\n",
"# Factor the permuted matrix with sparse_cholesky_csc and with the jitted\n",
"# class (scan-based copy); the next cell compares the two.\n",
"L_indices, L_indptr, L_x = sparse_cholesky_csc(A_indices, A_indptr, A_x)\n",
"L_x_ = sparse_csc(A_indices, A_indptr, use_scan=True).chol(A_x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d1201e44",
"metadata": {
"id": "d1201e44"
},
"outputs": [],
"source": [
"np.testing.assert_array_almost_equal(L_x, L_x_, decimal=6)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.11"
},
"vscode": {
"interpreter": {
"hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf"
}
},
"colab": {
"name": "Sparse cholesky in JAX.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment