Skip to content

Instantly share code, notes, and snippets.

@junpenglao
Last active July 22, 2022 20:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save junpenglao/f5b48c34dd8ea5029202fb607806ea0f to your computer and use it in GitHub Desktop.
Save junpenglao/f5b48c34dd8ea5029202fb607806ea0f to your computer and use it in GitHub Desktop.
Sparse cholesky in JAX.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/junpenglao/f5b48c34dd8ea5029202fb607806ea0f/sparse-cholesky-in-jax.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"id": "97b6acba-83df-4220-a98d-f538b299acf0",
"metadata": {
"id": "97b6acba-83df-4220-a98d-f538b299acf0"
},
"source": [
"Implementation from https://github.com/dpsimpson/blog/blob/master/_posts/2022-03-23-getting-jax-to-love-sparse-matrices/Python/\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a5363916",
"metadata": {
"id": "a5363916"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48cd55a9-5202-4f3d-b4ed-25e42eb1430d",
"metadata": {
"id": "48cd55a9-5202-4f3d-b4ed-25e42eb1430d"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"from functools import partial"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf756eac-bdb4-447b-8f3f-90b11c9462aa",
"metadata": {
"id": "bf756eac-bdb4-447b-8f3f-90b11c9462aa"
},
"outputs": [],
"source": [
"def _symbolic_factor_csc(A_indices, A_indptr):\n",
"    \"\"\"Compute the sparsity pattern of the Cholesky factor L (CSC layout).\n",
"\n",
"    Assumes A_indices and A_indptr index the lower triangle of $A$ ONLY.\n",
"    Entry [1] of a column's pattern is taken as its first sub-diagonal row,\n",
"    so the diagonal is expected to be stored first in every column.\n",
"\n",
"    Args:\n",
"        A_indices: row indices of the stored entries of tril(A).\n",
"        A_indptr: column pointers into A_indices, length n + 1.\n",
"\n",
"    Returns:\n",
"        (L_indices, L_indptr): row indices and column pointers of L.\n",
"    \"\"\"\n",
"    A_indices = np.asarray(A_indices)\n",
"    n = len(A_indptr) - 1\n",
"    if n <= 0:\n",
"        # np.concatenate raises on an empty list, so handle the 0 x 0 case here.\n",
"        return np.array([], dtype=int), np.zeros(1, dtype=int)\n",
"\n",
"    L_sym = [None] * n\n",
"    # children[p] lists the columns whose first sub-diagonal non-zero is in row p.\n",
"    # Plain lists avoid re-allocating an array on every append.\n",
"    children = [[] for _ in range(n)]\n",
"\n",
"    for j in range(n):\n",
"        L_sym[j] = A_indices[A_indptr[j] : A_indptr[j + 1]]\n",
"        for child in children[j]:\n",
"            # Column j inherits every row below j from each of its children.\n",
"            tmp = L_sym[child][L_sym[child] > j]\n",
"            L_sym[j] = np.unique(np.append(L_sym[j], tmp))\n",
"\n",
"        if len(L_sym[j]) > 1:\n",
"            p = L_sym[j][1]\n",
"            children[p].append(j)\n",
"\n",
"    L_indptr = np.zeros(n + 1, dtype=int)\n",
"    L_indptr[1:] = np.cumsum([len(x) for x in L_sym])\n",
"    L_indices = np.concatenate(L_sym)\n",
"\n",
"    return L_indices, L_indptr\n",
"\n",
"\n",
"def _deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr):\n",
"    \"\"\"Copy the values of tril(A) onto the sparsity pattern of L.\n",
"\n",
"    Returns:\n",
"        L_x: array of length len(L_indices) holding A's values at the\n",
"        positions A and L share, and zeros at the fill-in positions.\n",
"    \"\"\"\n",
"    n = len(A_indptr) - 1\n",
"    L_x = np.zeros(len(L_indices))\n",
"\n",
"    for j in range(n):\n",
"        col_A = slice(A_indptr[j], A_indptr[j + 1])\n",
"        col_L = slice(L_indptr[j], L_indptr[j + 1])\n",
"        # Offsets inside column j of L whose row index also occurs in A.\n",
"        # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.\n",
"        copy_idx = np.nonzero(np.isin(L_indices[col_L], A_indices[col_A]))[0]\n",
"        L_x[L_indptr[j] + copy_idx] = A_x[col_A]\n",
"    return L_x\n",
"\n",
"\n",
"def _sparse_cholesky_csc_impl(L_indices, L_indptr, L_x):\n",
"    \"\"\"Numeric Cholesky factorisation, column by column, on a fixed pattern.\n",
"\n",
"    L_x must hold the values of tril(A) laid out on the pattern of L (as\n",
"    returned by _deep_copy_csc). It is overwritten in place with the values\n",
"    of L and also returned.\n",
"    \"\"\"\n",
"    n = len(L_indptr) - 1\n",
"    # descendant[j] collects (k, idx) for every earlier column k with a\n",
"    # non-zero in row j; idx is the flat position of that entry in L_x.\n",
"    descendant = [[] for _ in range(n)]\n",
"    for j in range(n):\n",
"        col_j = slice(L_indptr[j], L_indptr[j + 1])\n",
"        # For a NumPy array this slice is a view, so updates land in L_x.\n",
"        tmp = L_x[col_j]\n",
"        for k, idx in descendant[j]:\n",
"            Ljk = L_x[idx]\n",
"            # Offset of row j inside column k; only the rows from there on matter.\n",
"            pad = np.nonzero(\n",
"                L_indices[L_indptr[k] : L_indptr[k + 1]] == L_indices[L_indptr[j]]\n",
"            )[0][0]\n",
"            tail_k = slice(L_indptr[k] + pad, L_indptr[k + 1])\n",
"            # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.\n",
"            update_idx = np.nonzero(np.isin(L_indices[col_j], L_indices[tail_k]))[0]\n",
"            tmp[update_idx] = tmp[update_idx] - Ljk * L_x[tail_k]\n",
"\n",
"        diag = np.sqrt(tmp[0])\n",
"        L_x[L_indptr[j]] = diag\n",
"        L_x[(L_indptr[j] + 1) : L_indptr[j + 1]] = tmp[1:] / diag\n",
"        # Register column j with every later column it will have to update.\n",
"        for idx in range(L_indptr[j] + 1, L_indptr[j + 1]):\n",
"            descendant[L_indices[idx]].append((j, idx))\n",
"    return L_x\n",
"\n",
"\n",
"def sparse_cholesky_csc(A_indices, A_indptr, A_x):\n",
"    \"\"\"Sparse Cholesky factorisation of a matrix given by its lower triangle in CSC form.\n",
"\n",
"    Runs the symbolic factorisation, copies A's values onto L's pattern and\n",
"    then runs the numeric factorisation.\n",
"\n",
"    Returns:\n",
"        (L_indices, L_indptr, L_x): the CSC representation of L.\n",
"    \"\"\"\n",
"    L_indices, L_indptr = _symbolic_factor_csc(A_indices, A_indptr)\n",
"    L_x = _deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr)\n",
"    L_x = _sparse_cholesky_csc_impl(L_indices, L_indptr, L_x)\n",
"    return L_indices, L_indptr, L_x\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5185359-8feb-412f-81ae-d8d91d2dec68",
"metadata": {
"id": "f5185359-8feb-412f-81ae-d8d91d2dec68",
"outputId": "008f2fcb-8acd-4864-c4c5-3d79d6f2dfc9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Error in Cholesky is 3.871041263071504e-12\n"
]
}
],
"source": [
"from scipy import sparse\n",
"\n",
"n = 50\n",
"one_d = sparse.diags([[-1.0] * (n - 1), [2.0] * n, [-1.0] * (n - 1)], [-1, 0, 1])\n",
"A = sparse.kronsum(one_d, one_d) + sparse.eye(n * n)\n",
"A_lower = sparse.tril(A, format=\"csc\")\n",
"A_indices = A_lower.indices\n",
"A_indptr = A_lower.indptr\n",
"A_x = A_lower.data\n",
"\n",
"L_indices, L_indptr, L_x = sparse_cholesky_csc(A_indices, A_indptr, A_x)\n",
"L = sparse.csc_array((L_x, L_indices, L_indptr), shape=(n**2, n**2))\n",
"\n",
"err = np.sum(np.abs((A - L @ L.transpose()).todense()))\n",
"print(f\"Error in Cholesky is {err}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "89ee3cad-c94d-4b0c-8bc4-b6c13d246424",
"metadata": {
"id": "89ee3cad-c94d-4b0c-8bc4-b6c13d246424",
"outputId": "41e5b54d-78ee-452c-cf2b-6d383f49b9e1"
},
"outputs": [
{
"data": {
"text/plain": [
"(3, 51)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.max(np.diff(A_indptr)), np.max(np.diff(L_indptr))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e50533d5",
"metadata": {
"id": "e50533d5",
"outputId": "878ac5ff-2f5e-4ec6-840c-fcb1279c8e42"
},
"outputs": [
{
"data": {
"text/plain": [
"(2501, 2501)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(A_indptr), len(L_indptr)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c01f20eb",
"metadata": {
"id": "c01f20eb",
"outputId": "1e578270-6c31-4c1a-dde4-8eb6b24b3711"
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f8d488005e0>]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Non-zeros per column of A and of L; the labels only show up once a legend is drawn.\n",
"plt.plot(np.diff(A_indptr), label=\"A\")\n",
"plt.plot(np.diff(L_indptr), label=\"L\")\n",
"plt.legend();\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0efea52e",
"metadata": {
"id": "0efea52e",
"outputId": "0fdf83b0-4280-4dad-854d-c9cd62d85662"
},
"outputs": [
{
"data": {
"text/plain": [
"(7400, 125049)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(A_indices), len(L_indices)\n"
]
},
{
"cell_type": "markdown",
"id": "f119ad55",
"metadata": {
"id": "f119ad55"
},
"source": [
"Ragged arrays are difficult to handle in JAX. \n",
"Idea 1: pad to same length\n",
"\n",
"- pro: vmap-able, usually fast\n",
"- con: need more memory\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83d6fee8",
"metadata": {
"id": "83d6fee8"
},
"outputs": [],
"source": [
"@partial(jax.jit, static_argnums=(0,))\n",
"def fix_length_arange(max_length, start, stop):\n",
"    \"\"\"Return arange(start, stop) padded with -1 to a fixed length.\n",
"\n",
"    Args:\n",
"        max_length: static output length.\n",
"        start, stop: bounds of the range; entries past stop - start are -1.\n",
"            If stop - start exceeds max_length the range is truncated.\n",
"    \"\"\"\n",
"    offsets = jnp.arange(max_length)\n",
"    return jnp.where(offsets < (stop - start), offsets + start, -1)\n",
"\n",
"\n",
"# Not decorated with jax.jit: m_A and m_L fix the array shapes, so callers jit\n",
"# the function after binding them, e.g. jax.jit(partial(_deep_copy_csc_jax, m_A, m_L)).\n",
"def _deep_copy_csc_jax(m_A, m_L, A_indices, A_indptr, A_x, L_indices, L_indptr):\n",
"    \"\"\"Copy the values of tril(A) onto the padded sparsity pattern of L.\n",
"\n",
"    Every column is padded to a fixed width (m_A for A, m_L for L) so that the\n",
"    per-column copy can be vmapped.\n",
"\n",
"    Returns:\n",
"        (L_x_mat, L_indptr_mat): an (n, m_L) array holding A's values at L's\n",
"        positions (zero at fill-in and padding), and the (n, m_L) array of\n",
"        flat positions into L_indices with -1 marking padding.\n",
"    \"\"\"\n",
"    A_indices, A_x, L_indices = (jnp.asarray(a) for a in (A_indices, A_x, L_indices))\n",
"    A_indptr_mat = jax.vmap(partial(fix_length_arange, m_A))(\n",
"        A_indptr[:-1], A_indptr[1:]\n",
"    )\n",
"    L_indptr_mat = jax.vmap(partial(fix_length_arange, m_L))(\n",
"        L_indptr[:-1], L_indptr[1:]\n",
"    )\n",
"    # Padded slots gather element -1 (the last stored entry), so overwrite them\n",
"    # with two different sentinels that can never match a row index or each other.\n",
"    A_rows = jnp.where(A_indptr_mat != -1, A_indices[A_indptr_mat], -1)\n",
"    L_rows = jnp.where(L_indptr_mat != -1, L_indices[L_indptr_mat], -2)\n",
"\n",
"    def row_fun(row_idx_A, row_idx_L, row_val_A):\n",
"        # Offsets in L's column whose row index also occurs in A's column.\n",
"        # Unused slots get the out-of-range index m_L so the scatter drops them;\n",
"        # the default fill of 0 would write a padding value onto offset 0.\n",
"        copy_idx = jnp.nonzero(\n",
"            jnp.isin(row_idx_L, row_idx_A), size=m_A, fill_value=m_L\n",
"        )[0]\n",
"        out = jnp.zeros(m_L, dtype=row_val_A.dtype)\n",
"        return out.at[copy_idx].set(row_val_A, mode=\"drop\")\n",
"\n",
"    L_x_mat = jax.vmap(row_fun)(A_rows, L_rows, A_x[A_indptr_mat])\n",
"    return L_x_mat, L_indptr_mat\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b5b0566-1b87-4ee7-9446-f90c16dc9816",
"metadata": {
"id": "9b5b0566-1b87-4ee7-9446-f90c16dc9816",
"outputId": "799dffc9-4a9a-480d-f3d2-07296abe9786"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.49 ms ± 24.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
]
}
],
"source": [
"L_indices, L_indptr = _symbolic_factor_csc(A_indices, A_indptr)\n",
"L_x0 = _deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr)\n",
"\n",
"# Widest column of A and of L. They set the padded array shapes, so they must\n",
"# be static (plain Python ints) when the copy is jitted.\n",
"m_A = int(np.max(np.diff(A_indptr)))\n",
"m_L = int(np.max(np.diff(L_indptr)))\n",
"partial_deep_copy_csc = jax.jit(partial(_deep_copy_csc_jax, m_A, m_L))\n",
"L_x_mat, L_indptr_mat = partial_deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr)\n",
"# JAX dispatches asynchronously: block on the result so the timing covers the computation.\n",
"%timeit partial_deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr)[0].block_until_ready()\n",
"L_x = L_x_mat[L_indptr_mat != -1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ca45e30a",
"metadata": {
"id": "ca45e30a"
},
"outputs": [],
"source": [
"np.testing.assert_array_almost_equal(L_x, L_x0)\n",
"# np.testing.assert_array_almost_equal(L_x_mat, L_x0[L_indptr_mat])\n"
]
},
{
"cell_type": "markdown",
"id": "10cf6278",
"metadata": {
"id": "10cf6278"
},
"source": [
"Idea 2: scan and build index \n",
"There are 2 options here: scan over n (the number of columns/rows of the matrix), which is how the original NumPy code is written, or scan over the entries of `A_x` and build an index vector that records where each value needs to go in `L_x`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "de867282",
"metadata": {
"id": "de867282"
},
"outputs": [],
"source": [
"# Option 1 does not work because JAX is absolutely allergic to dynamic slicing\n",
"# There is a way to have static shape inside one_step with a placeholder fix-shape\n",
"# array but then I might as well just use vmap version above\n",
"\n",
"# def _deep_copy_csc_scan(m_A, A_indices, A_indptr, A_x, L_indices, L_indptr):\n",
"# L_indptr_last = L_indptr[0]\n",
"# A_indptr_last = A_indptr[0]\n",
"# L_x_init = jnp.zeros_like(L_indptr, dtype=A_x.dtype)\n",
"\n",
"# def one_step(carry, current):\n",
"# L_indptr_last, A_indptr_last, L_x = carry\n",
"# L_indptr_current, A_indptr_current = current\n",
"# copy_idx = jnp.nonzero(\n",
"# jnp.in1d(\n",
"# # L_indices[L_indptr_last:L_indptr_current],\n",
"# jax.lax.dynamic_slice(L_indices, [L_indptr_last], [L_indptr_current-L_indptr_last]),\n",
"# jax.lax.dynamic_slice(A_indices, [A_indptr_last], [A_indptr_current-A_indptr_last]),\n",
"# ),\n",
"# size=A_indptr_current-A_indptr_last\n",
"# )[0]\n",
"# L_x = L_x.at[L_indptr_last + copy_idx].set(\n",
"# jax.lax.dynamic_slice(A_x, [A_indptr_last], [A_indptr_current-A_indptr_last])\n",
"# )\n",
"# return (L_indptr_current, A_indptr_current, L_x), None\n",
"\n",
"# (*_, L_x), _ = jax.lax.scan(\n",
"# one_step,\n",
"# (L_indptr_last, A_indptr_last, L_x_init),\n",
"# (L_indptr[1:], A_indptr[1:]))\n",
"# return L_x\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6ff7c8f-df2f-40bf-a874-2b975342a2bf",
"metadata": {
"tags": [],
"id": "f6ff7c8f-df2f-40bf-a874-2b975342a2bf"
},
"outputs": [],
"source": [
"def _deep_copy_csc_scan(m_L, A_indices, A_indptr, A_x, L_indices, L_indptr):\n",
"    \"\"\"Copy the values of tril(A) into the flat value array of L using lax.scan.\n",
"\n",
"    Scans over the stored entries of A and finds, for each one, its flat\n",
"    position in L. m_L is the widest column of L and must be static.\n",
"\n",
"    Returns:\n",
"        (L_x, update_index): L's values (A's entries, zeros at fill-in) and\n",
"        the flat position in L_x of every entry of A_x.\n",
"    \"\"\"\n",
"    # Index JAX arrays so the traced scan index also works without an outer jit.\n",
"    A_indices, A_indptr, A_x, L_indices, L_indptr = (\n",
"        jnp.asarray(a) for a in (A_indices, A_indptr, A_x, L_indices, L_indptr)\n",
"    )\n",
"    # Pad value otherwise out of bound indexing result is weird.\n",
"    L_indices_padded = jnp.pad(L_indices, [0, m_L], constant_values=-1)\n",
"\n",
"    def one_step(_, i):\n",
"        A_idx = A_indices[i]\n",
"        # First pointer strictly greater than i, i.e. (column of entry i) + 1.\n",
"        j = jnp.argwhere(i < A_indptr, size=1)[0][0]\n",
"        col_start = L_indptr[j - 1]\n",
"        # Fixed-width window over column j - 1 of L; mask what lies past its end.\n",
"        L_indices_slice = jax.lax.dynamic_slice(L_indices_padded, [col_start], [m_L])\n",
"        L_indices_slice = jnp.where(\n",
"            jnp.arange(m_L) < L_indptr[j] - col_start,\n",
"            L_indices_slice,\n",
"            jnp.ones_like(L_indices_slice) * -1,\n",
"        )\n",
"        k = jnp.argwhere(A_idx == L_indices_slice, size=1)[0][0]\n",
"        return None, k + col_start\n",
"\n",
"    _, update_index = jax.lax.scan(one_step, None, jnp.arange(A_indices.shape[-1]))\n",
"\n",
"    L_x = jnp.zeros_like(L_indices, dtype=A_x.dtype)\n",
"    L_x = L_x.at[update_index].set(A_x)\n",
"    return L_x, update_index\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b522bd88",
"metadata": {
"tags": [],
"id": "b522bd88",
"outputId": "a670552f-38ac-4fa6-bf5b-678996855fe9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.53 ms ± 33 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
]
}
],
"source": [
"m_L = int(np.max(np.diff(L_indptr)))\n",
"partial_deep_copy_csc_1 = jax.jit(partial(_deep_copy_csc_scan, m_L))\n",
"L_x, update_index = partial_deep_copy_csc_1(A_indices, A_indptr, A_x, L_indices, L_indptr)\n",
"# Time the scan version (not the vmap one) and block on the result, because JAX\n",
"# dispatches asynchronously.\n",
"%timeit partial_deep_copy_csc_1(A_indices, A_indptr, A_x, L_indices, L_indptr)[0].block_until_ready()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "773b1cf8",
"metadata": {
"id": "773b1cf8"
},
"outputs": [],
"source": [
"np.testing.assert_array_almost_equal(L_x, L_x0)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8754ef41-b0c1-438d-90b1-82212ad0a0cc",
"metadata": {
"id": "8754ef41-b0c1-438d-90b1-82212ad0a0cc",
"outputId": "866c1a0b-8b23-4509-bbe7-e6596beca48b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of non-zeros is 125049 (fill in of 117649)\n"
]
}
],
"source": [
"nnz = len(L_x)\n",
"print(f\"Number of non-zeros is {nnz} (fill in of {len(L_x) - len(A_x)})\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29700fdc-e619-475f-a5a4-d360112e4c5f",
"metadata": {
"id": "29700fdc-e619-475f-a5a4-d360112e4c5f",
"outputId": "9422736a-913c-4aeb-c50f-71674d9ac71c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2499 2498 2449 ... 50 1 0]\n"
]
}
],
"source": [
"perm = sparse.csgraph.reverse_cuthill_mckee(A, symmetric_mode=True)\n",
"print(perm)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d40f647-f9b9-4e6c-b30a-dbaa94ecbb5f",
"metadata": {
"id": "3d40f647-f9b9-4e6c-b30a-dbaa94ecbb5f"
},
"outputs": [],
"source": [
"A_perm = A[perm[:, None], perm]\n",
"A_perm_lower = sparse.tril(A_perm, format=\"csc\")\n",
"A_indices = A_perm_lower.indices\n",
"A_indptr = A_perm_lower.indptr\n",
"A_x = A_perm_lower.data\n",
"\n",
"L_indices, L_indptr, L_x = sparse_cholesky_csc(A_indices, A_indptr, A_x)\n",
"L = sparse.csc_array((L_x, L_indices, L_indptr), shape=(n**2, n**2))\n",
"err = np.sum(np.abs((A_perm - L @ L.transpose()).todense()))\n",
"print(f\"Error in Cholesky is {err}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53052e6f",
"metadata": {
"id": "53052e6f",
"outputId": "26a14ae4-b63d-459e-fa91-3a341a4e82d2"
},
"outputs": [
{
"data": {
"text/plain": [
"(3, 51)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.max(np.diff(A_indptr)), np.max(np.diff(L_indptr))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d9a5a51",
"metadata": {
"id": "5d9a5a51",
"outputId": "6fc321ab-e78c-49ee-9042-657f34ab247a"
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f8cd8969e20>]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Non-zeros per column of A and of L after the permutation; draw the legend so the labels show.\n",
"plt.plot(np.diff(A_indptr), label=\"A\")\n",
"plt.plot(np.diff(L_indptr), label=\"L\")\n",
"plt.legend();\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed1cce62-346e-4710-ad6f-64e2246cbdf0",
"metadata": {
"id": "ed1cce62-346e-4710-ad6f-64e2246cbdf0",
"outputId": "94011e5d-1b3e-4cb1-833a-bcc1fe2b0e7e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of non-zeros is 87025 (fill in of 79625), \n",
"which is less than the unpermuted matrix, which had 125049 non-zeros.\n"
]
}
],
"source": [
"nnz_rcm = len(L_x)\n",
"print(\n",
" f\"\"\"Number of non-zeros is {nnz_rcm} (fill in of {len(L_x) - len(A_x)}), \n",
"which is less than the unpermuted matrix, which had {nnz} non-zeros.\"\"\"\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4212f971",
"metadata": {
"id": "4212f971"
},
"outputs": [],
"source": [
"L_indices, L_indptr = _symbolic_factor_csc(A_indices, A_indptr)\n",
"L_x0 = _deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr)\n",
"# can reuse the same compiled function as the static shape config does not change by permutation\n",
"L_x_mat, L_indptr_mat = partial_deep_copy_csc(\n",
" A_indices, A_indptr, A_x, L_indices, L_indptr\n",
")\n",
"L_x = L_x_mat[L_indptr_mat != -1]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ae429aa",
"metadata": {
"id": "0ae429aa"
},
"outputs": [],
"source": [
"np.testing.assert_array_almost_equal(L_x, L_x0)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "adeb463b",
"metadata": {
"id": "adeb463b"
},
"outputs": [],
"source": [
"L_x_chol0 = _deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr)\n",
"L_x_chol0 = _sparse_cholesky_csc_impl(L_indices, L_indptr, L_x_chol0)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c905e2a5",
"metadata": {
"id": "c905e2a5"
},
"outputs": [],
"source": [
"# Reference (pure Python) descendant lists: descendant[r] holds (j, idx) for\n",
"# every column j with a sub-diagonal entry in row r, idx being its flat position.\n",
"n = len(L_indptr) - 1\n",
"descendant = [[] for _ in range(0, n)]\n",
"for j in range(0, n):\n",
"    for idx in range(L_indptr[j] + 1, L_indptr[j + 1]):\n",
"        descendant[L_indices[idx]].append((j, idx))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "61f2a289",
"metadata": {
"id": "61f2a289"
},
"outputs": [],
"source": [
"counter = jnp.zeros(L_indptr.shape[0] - 1, dtype=np.int32)\n",
"m_L = np.max(np.diff(L_indptr))\n",
"# n+1 rows: padded slots use row index -1, which wraps around to the extra\n",
"# last row, so placeholder writes never touch a real column.\n",
"descendant_arr = jnp.ones((n + 1, m_L - 1, 2), dtype=np.int32) * -1\n",
"\n",
"\n",
"def one_step(carry, x):\n",
"    \"\"\"Register column j as a descendant of every row below its diagonal.\"\"\"\n",
"    (descendant_arr, counter) = carry\n",
"    (j, L_indptr_slice, L_indices_slice) = x\n",
"    # Skip the first (diagonal) slot; -1 marks padding.\n",
"    idx_ptr = L_indptr_slice[1:]\n",
"    idx = jnp.where(idx_ptr != -1, L_indices_slice[1:], idx_ptr)\n",
"    to_update = jnp.where(\n",
"        idx_ptr != -1,\n",
"        jnp.vstack([jnp.ones_like(idx_ptr) * j, idx_ptr]),\n",
"        jnp.ones((2, m_L - 1), dtype=idx_ptr.dtype) * -1,\n",
"    )\n",
"    # Vectorised form of:\n",
"    # for i in range(m_L):\n",
"    #     if idx_ptr[i] != -1:\n",
"    #         descendant_arr = descendant_arr.at[idx[i], counter[idx[i]]].set(to_update[:, i])\n",
"    descendant_arr_next = descendant_arr.at[idx, counter[idx]].set(to_update.T)\n",
"    counter_next = counter.at[idx].add(idx_ptr != -1)\n",
"    return (descendant_arr_next, counter_next), None\n",
"\n",
"\n",
"(descendant_arr, counter), _ = jax.lax.scan(\n",
"    one_step,\n",
"    (descendant_arr, counter),\n",
"    (jnp.arange(n), L_indptr_mat, L_indices[L_indptr_mat]),\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3b47f12c",
"metadata": {
"id": "3b47f12c"
},
"outputs": [],
"source": [
"# Compare the scanned descendant table with the pure-Python reference, row by row.\n",
"for j in range(1, n):\n",
"    np.testing.assert_array_equal(\n",
"        jnp.asarray(descendant[j]), descendant_arr[j, : len(descendant[j])]\n",
"    )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6da62499",
"metadata": {
"id": "6da62499"
},
"outputs": [],
"source": [
"def _sparse_cholesky_csc_impl_jax(m_L, L_indices, L_indptr_mat, L_x):\n",
"    \"\"\"Numeric sparse Cholesky in JAX on columns padded to width m_L.\n",
"\n",
"    Args:\n",
"        m_L: static width of the padded columns.\n",
"        L_indices: flat row indices of L.\n",
"        L_indptr_mat: (n, m_L) flat positions into L_indices, -1 for padding.\n",
"        L_x: flat values of tril(A) laid out on L's pattern.\n",
"\n",
"    Returns:\n",
"        (L_x_chol, descendant_arr): the (n, m_L) values of L (padded slots\n",
"        hold placeholder values) and the descendant table built on the way.\n",
"    \"\"\"\n",
"    n = L_indptr_mat.shape[0]\n",
"    counter = jnp.zeros(n, dtype=np.int32)\n",
"    # n+1 rows: padded slots use row index -1, which wraps to the extra last row.\n",
"    descendant_arr = jnp.ones((n + 1, m_L - 1, 2), dtype=np.int32) * -1\n",
"    L_indices_mat = L_indices[L_indptr_mat]\n",
"    L_x_mat = L_x[L_indptr_mat]\n",
"    L_x_chol = jnp.zeros_like(L_x_mat)\n",
"\n",
"    def _inner_none_op(carry, _):\n",
"        return carry, None\n",
"\n",
"    def _inner_one_step(carry, bebe):\n",
"        # bebe = (k, offset): column k has an entry in the current row at `offset`.\n",
"        (L_x_chol, L_indices_slice, tmp) = carry\n",
"        k = bebe[0]\n",
"        L_indices_slice_k = L_indices_mat[k]\n",
"        L_x_chol_slice_k = L_x_chol[k]\n",
"        Ljk = L_x_chol_slice_k[bebe[1]]\n",
"        # Offset of the current column's diagonal row inside column k.\n",
"        pad = jnp.argwhere(L_indices_slice_k == L_indices_slice[0], size=1)[0][0]\n",
"        # Mask everything above that row (and the padding), then shift left.\n",
"        L_indices_slice_k_pad = jnp.roll(\n",
"            jnp.where(\n",
"                (jnp.arange(m_L) < pad) | (L_indptr_mat[k] == -1),\n",
"                jnp.ones_like(L_indices_slice_k) * -1,\n",
"                L_indices_slice_k,\n",
"            ),\n",
"            -pad,\n",
"        )\n",
"        L_x_chol_slice_k_pad = jnp.roll(\n",
"            jnp.where(\n",
"                jnp.arange(m_L) < pad,\n",
"                jnp.zeros_like(L_x_chol_slice_k),\n",
"                L_x_chol_slice_k,\n",
"            ),\n",
"            -pad,\n",
"        )\n",
"        _update_idx = jnp.nonzero(\n",
"            jnp.isin(L_indices_slice, L_indices_slice_k_pad), size=m_L\n",
"        )[0]\n",
"        # nonzero pads with 0; redirect those non-increasing slots to -1, the\n",
"        # placeholder element appended to tmp.\n",
"        update_idx = jnp.where(\n",
"            jnp.concatenate([jnp.ones([1]), jnp.diff(_update_idx)]) > 0,\n",
"            _update_idx,\n",
"            jnp.ones_like(_update_idx) * -1,\n",
"        )\n",
"        tmp = tmp.at[update_idx].add(-Ljk * L_x_chol_slice_k_pad)\n",
"        return (L_x_chol, L_indices_slice, tmp), None\n",
"\n",
"    def inner_one_step(carry, bebe):\n",
"        # Rows of the descendant table that are still -1 are unused slots.\n",
"        return jax.lax.cond(bebe[0] != -1, _inner_one_step, _inner_none_op, carry, bebe)\n",
"\n",
"    def one_step(carry, x):\n",
"        (descendant_arr, counter, L_x_chol) = carry\n",
"        (j, L_indptr_slice, L_indices_slice, _tmp) = x\n",
"        # Append 1 more value to the end to account for the placeholder computation\n",
"        tmp = jnp.concatenate([_tmp, jnp.zeros([1])])\n",
"        idx_ptr = L_indptr_slice[1:]\n",
"        idx = jnp.where(idx_ptr != -1, L_indices_slice[1:], idx_ptr)\n",
"        # Store (column, offset within the padded column) for each sub-diagonal row.\n",
"        to_update = jnp.where(\n",
"            idx_ptr != -1,\n",
"            jnp.vstack([jnp.ones_like(idx_ptr) * j, jnp.arange(1, m_L)]),\n",
"            jnp.ones((2, m_L - 1), dtype=idx_ptr.dtype) * -1,\n",
"        )\n",
"        descendant_arr_next = descendant_arr.at[idx, counter[idx]].set(to_update.T)\n",
"        counter_next = counter.at[idx].add(idx_ptr != -1)\n",
"\n",
"        bebe = descendant_arr_next[j]\n",
"        (L_x_chol, L_indices_slice, tmp), _ = jax.lax.scan(\n",
"            inner_one_step, (L_x_chol, L_indices_slice, tmp), bebe\n",
"        )\n",
"\n",
"        diag = jnp.sqrt(tmp[:1])\n",
"        # tmp[-1] is the placeholder slot and is dropped here.\n",
"        L_x_out = jnp.concatenate([diag, tmp[1:-1] / diag])\n",
"        L_x_chol = L_x_chol.at[j].set(L_x_out)\n",
"        return (descendant_arr_next, counter_next, L_x_chol), None\n",
"\n",
"    (descendant_arr, counter, L_x_chol), _ = jax.lax.scan(\n",
"        one_step,\n",
"        (descendant_arr, counter, L_x_chol),\n",
"        (jnp.arange(n), L_indptr_mat, L_indices_mat, L_x_mat),\n",
"    )\n",
"    return L_x_chol, descendant_arr\n",
"\n",
"\n",
"m_L = np.max(np.diff(L_indptr))\n",
"partial_sparse_cholesky_csc_impl = jax.jit(partial(_sparse_cholesky_csc_impl_jax, m_L))\n",
"L_x_chol2, descendant_arr = partial_sparse_cholesky_csc_impl(\n",
" L_indices, L_indptr_mat, L_x\n",
")\n",
"L_x_chol_ = L_x_chol2[L_indptr_mat != -1]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42c14161-bd5a-4b02-a9df-3434d028cccd",
"metadata": {
"id": "42c14161-bd5a-4b02-a9df-3434d028cccd"
},
"outputs": [],
"source": [
"np.testing.assert_array_almost_equal(L_x_chol0, L_x_chol_)\n"
]
},
{
"cell_type": "markdown",
"id": "fff3412d",
"metadata": {
"id": "fff3412d"
},
"source": [
"Putting everything together in a Python class.\n",
"\n",
"Why a Python class? There are a number of quantities that we want to compute once and store as static parameters, which is much easier to do with a class.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00ea1cb1",
"metadata": {
"id": "00ea1cb1"
},
"outputs": [],
"source": [
"def _symbolic_factor_csc(A_indices, A_indptr):\n",
"    \"\"\"Compute the sparsity pattern of the Cholesky factor L (CSC layout).\n",
"\n",
"    Assumes A_indices and A_indptr index the lower triangle of $A$ ONLY.\n",
"    Entry [1] of a column's pattern is taken as its first sub-diagonal row,\n",
"    so the diagonal is expected to be stored first in every column.\n",
"\n",
"    Args:\n",
"        A_indices: row indices of the stored entries of tril(A).\n",
"        A_indptr: column pointers into A_indices, length n + 1.\n",
"\n",
"    Returns:\n",
"        (L_indices, L_indptr): row indices and column pointers of L.\n",
"    \"\"\"\n",
"    A_indices = np.asarray(A_indices)\n",
"    n = len(A_indptr) - 1\n",
"    if n <= 0:\n",
"        # np.concatenate raises on an empty list, so handle the 0 x 0 case here.\n",
"        return np.array([], dtype=int), np.zeros(1, dtype=int)\n",
"\n",
"    L_sym = [None] * n\n",
"    # children[p] lists the columns whose first sub-diagonal non-zero is in row p.\n",
"    # Plain lists avoid re-allocating an array on every append.\n",
"    children = [[] for _ in range(n)]\n",
"\n",
"    for j in range(n):\n",
"        L_sym[j] = A_indices[A_indptr[j] : A_indptr[j + 1]]\n",
"        for child in children[j]:\n",
"            # Column j inherits every row below j from each of its children.\n",
"            tmp = L_sym[child][L_sym[child] > j]\n",
"            L_sym[j] = np.unique(np.append(L_sym[j], tmp))\n",
"\n",
"        if len(L_sym[j]) > 1:\n",
"            p = L_sym[j][1]\n",
"            children[p].append(j)\n",
"\n",
"    L_indptr = np.zeros(n + 1, dtype=int)\n",
"    L_indptr[1:] = np.cumsum([len(x) for x in L_sym])\n",
"    L_indices = np.concatenate(L_sym)\n",
"\n",
"    return L_indices, L_indptr\n",
"\n",
"\n",
"@partial(jax.jit, static_argnums=(0,))\n",
"def fix_length_arange(max_length, start, stop):\n",
"    \"\"\"Return arange(start, stop) padded with -1 to a fixed length.\n",
"\n",
"    Args:\n",
"        max_length: static output length.\n",
"        start, stop: bounds of the range; entries past stop - start are -1.\n",
"            If stop - start exceeds max_length the range is truncated.\n",
"    \"\"\"\n",
"    offsets = jnp.arange(max_length)\n",
"    return jnp.where(offsets < (stop - start), offsets + start, -1)\n",
"\n",
"\n",
"def _deep_copy_csc_vmap(m_A, m_L, A_indices, A_indptr, A_x, L_indices, L_indptr):\n",
"    \"\"\"Copy the values of tril(A) onto the padded sparsity pattern of L.\n",
"\n",
"    Every column is padded to a fixed width (m_A for A, m_L for L) so that the\n",
"    per-column copy can be vmapped. m_A and m_L must be static under jit.\n",
"\n",
"    Returns:\n",
"        (L_x_mat, L_indptr_mat): an (n, m_L) array holding A's values at L's\n",
"        positions (zero at fill-in and padding), and the (n, m_L) array of\n",
"        flat positions into L_indices with -1 marking padding.\n",
"    \"\"\"\n",
"    A_indices, A_x, L_indices = (jnp.asarray(a) for a in (A_indices, A_x, L_indices))\n",
"    A_indptr_mat = jax.vmap(partial(fix_length_arange, m_A))(\n",
"        A_indptr[:-1], A_indptr[1:]\n",
"    )\n",
"    L_indptr_mat = jax.vmap(partial(fix_length_arange, m_L))(\n",
"        L_indptr[:-1], L_indptr[1:]\n",
"    )\n",
"    # Padded slots gather element -1 (the last stored entry), so overwrite them\n",
"    # with two different sentinels that can never match a row index or each other.\n",
"    A_rows = jnp.where(A_indptr_mat != -1, A_indices[A_indptr_mat], -1)\n",
"    L_rows = jnp.where(L_indptr_mat != -1, L_indices[L_indptr_mat], -2)\n",
"\n",
"    def row_fun(row_idx_A, row_idx_L, row_val_A):\n",
"        # Offsets in L's column whose row index also occurs in A's column.\n",
"        # Unused slots get the out-of-range index m_L so the scatter drops them;\n",
"        # the default fill of 0 would write a padding value onto offset 0.\n",
"        copy_idx = jnp.nonzero(\n",
"            jnp.isin(row_idx_L, row_idx_A), size=m_A, fill_value=m_L\n",
"        )[0]\n",
"        out = jnp.zeros(m_L, dtype=row_val_A.dtype)\n",
"        return out.at[copy_idx].set(row_val_A, mode=\"drop\")\n",
"\n",
"    L_x_mat = jax.vmap(row_fun)(A_rows, L_rows, A_x[A_indptr_mat])\n",
"    return L_x_mat, L_indptr_mat\n",
"\n",
"\n",
"def _deep_copy_csc_scan(m_L, A_indices, A_indptr, A_x, L_indices, L_indptr):\n",
"    \"\"\"Copy the values of tril(A) into the flat value array of L using lax.scan.\n",
"\n",
"    Scans over the stored entries of A and finds, for each one, its flat\n",
"    position in L. m_L is the widest column of L and must be static.\n",
"\n",
"    Returns:\n",
"        (L_x, update_index): L's values (A's entries, zeros at fill-in) and\n",
"        the flat position in L_x of every entry of A_x.\n",
"    \"\"\"\n",
"    # Index JAX arrays so the traced scan index also works without an outer jit.\n",
"    A_indices, A_indptr, A_x, L_indices, L_indptr = (\n",
"        jnp.asarray(a) for a in (A_indices, A_indptr, A_x, L_indices, L_indptr)\n",
"    )\n",
"    # Pad value otherwise out of bound indexing result is weird.\n",
"    L_indices_padded = jnp.pad(L_indices, [0, m_L], constant_values=-1)\n",
"\n",
"    def one_step(_, i):\n",
"        A_idx = A_indices[i]\n",
"        # First pointer strictly greater than i, i.e. (column of entry i) + 1.\n",
"        j = jnp.argwhere(i < A_indptr, size=1)[0][0]\n",
"        col_start = L_indptr[j - 1]\n",
"        # Fixed-width window over column j - 1 of L; mask what lies past its end.\n",
"        L_indices_slice = jax.lax.dynamic_slice(L_indices_padded, [col_start], [m_L])\n",
"        L_indices_slice = jnp.where(\n",
"            jnp.arange(m_L) < L_indptr[j] - col_start,\n",
"            L_indices_slice,\n",
"            jnp.ones_like(L_indices_slice) * -1,\n",
"        )\n",
"        k = jnp.argwhere(A_idx == L_indices_slice, size=1)[0][0]\n",
"        return None, k + col_start\n",
"\n",
"    _, update_index = jax.lax.scan(one_step, None, jnp.arange(A_indices.shape[-1]))\n",
"\n",
"    L_x = jnp.zeros_like(L_indices, dtype=A_x.dtype)\n",
"    L_x = L_x.at[update_index].set(A_x)\n",
"    return L_x, update_index\n",
"\n",
"\n",
"def _sparse_cholesky_csc_impl_jax(m_L, L_indices_mat, L_indptr_mat, L_x_mat):\n",
"    \"\"\"Numeric sparse Cholesky in JAX on columns padded to width m_L.\n",
"\n",
"    Args:\n",
"        m_L: static width of the padded columns.\n",
"        L_indices_mat: (n, m_L) row indices of L, one padded column per row.\n",
"        L_indptr_mat: (n, m_L) flat positions into L_indices, -1 for padding.\n",
"        L_x_mat: (n, m_L) values of tril(A) laid out on L's padded pattern.\n",
"\n",
"    Returns:\n",
"        (L_x_chol, descendant_arr): the (n, m_L) values of L (padded slots\n",
"        hold placeholder values) and the descendant table built on the way.\n",
"    \"\"\"\n",
"    n = L_indptr_mat.shape[0]\n",
"    counter = jnp.zeros(n, dtype=np.int32)\n",
"    # n+1 rows: padded slots use row index -1, which wraps to the extra last row.\n",
"    descendant_arr = jnp.ones((n + 1, m_L - 1, 2), dtype=np.int32) * -1\n",
"    L_x_chol = jnp.zeros_like(L_x_mat)\n",
"\n",
"    def _inner_none_op(carry, _):\n",
"        return carry, None\n",
"\n",
"    def _inner_one_step(carry, bebe):\n",
"        # bebe = (k, offset): column k has an entry in the current row at `offset`.\n",
"        (L_x_chol, L_indices_slice, tmp) = carry\n",
"        k = bebe[0]\n",
"        L_indices_slice_k = L_indices_mat[k]\n",
"        L_x_chol_slice_k = L_x_chol[k]\n",
"        Ljk = L_x_chol_slice_k[bebe[1]]\n",
"        # Offset of the current column's diagonal row inside column k.\n",
"        pad = jnp.argwhere(L_indices_slice_k == L_indices_slice[0], size=1)[0][0]\n",
"        # Mask everything above that row (and the padding), then shift left.\n",
"        L_indices_slice_k_pad = jnp.roll(\n",
"            jnp.where(\n",
"                (jnp.arange(m_L) < pad) | (L_indptr_mat[k] == -1),\n",
"                jnp.ones_like(L_indices_slice_k) * -1,\n",
"                L_indices_slice_k,\n",
"            ),\n",
"            -pad,\n",
"        )\n",
"        L_x_chol_slice_k_pad = jnp.roll(\n",
"            jnp.where(\n",
"                jnp.arange(m_L) < pad,\n",
"                jnp.zeros_like(L_x_chol_slice_k),\n",
"                L_x_chol_slice_k,\n",
"            ),\n",
"            -pad,\n",
"        )\n",
"        _update_idx = jnp.nonzero(\n",
"            jnp.isin(L_indices_slice, L_indices_slice_k_pad), size=m_L\n",
"        )[0]\n",
"        # nonzero pads with 0; redirect those non-increasing slots to -1, the\n",
"        # placeholder element appended to tmp.\n",
"        update_idx = jnp.where(\n",
"            jnp.concatenate([jnp.ones([1]), jnp.diff(_update_idx)]) > 0,\n",
"            _update_idx,\n",
"            jnp.ones_like(_update_idx) * -1,\n",
"        )\n",
"        tmp = tmp.at[update_idx].add(-Ljk * L_x_chol_slice_k_pad)\n",
"        return (L_x_chol, L_indices_slice, tmp), None\n",
"\n",
"    def inner_one_step(carry, bebe):\n",
"        # Rows of the descendant table that are still -1 are unused slots.\n",
"        return jax.lax.cond(bebe[0] != -1, _inner_one_step, _inner_none_op, carry, bebe)\n",
"\n",
"    def one_step(carry, x):\n",
"        (descendant_arr, counter, L_x_chol) = carry\n",
"        (j, L_indptr_slice, L_indices_slice, _tmp) = x\n",
"        # Append 1 more value to the end to account for the placeholder computation\n",
"        tmp = jnp.concatenate([_tmp, jnp.zeros([1])])\n",
"        idx_ptr = L_indptr_slice[1:]\n",
"        idx = jnp.where(idx_ptr != -1, L_indices_slice[1:], idx_ptr)\n",
"        # Store (column, offset within the padded column) for each sub-diagonal row.\n",
"        to_update = jnp.where(\n",
"            idx_ptr != -1,\n",
"            jnp.vstack([jnp.ones_like(idx_ptr) * j, jnp.arange(1, m_L)]),\n",
"            jnp.ones((2, m_L - 1), dtype=idx_ptr.dtype) * -1,\n",
"        )\n",
"        descendant_arr_next = descendant_arr.at[idx, counter[idx]].set(to_update.T)\n",
"        counter_next = counter.at[idx].add(idx_ptr != -1)\n",
"\n",
"        bebe = descendant_arr_next[j]\n",
"        (L_x_chol, L_indices_slice, tmp), _ = jax.lax.scan(\n",
"            inner_one_step, (L_x_chol, L_indices_slice, tmp), bebe\n",
"        )\n",
"\n",
"        diag = jnp.sqrt(tmp[:1])\n",
"        # tmp[-1] is the placeholder slot and is dropped here.\n",
"        L_x_out = jnp.concatenate([diag, tmp[1:-1] / diag])\n",
"        L_x_chol = L_x_chol.at[j].set(L_x_out)\n",
"        return (descendant_arr_next, counter_next, L_x_chol), None\n",
"\n",
"    (descendant_arr, counter, L_x_chol), _ = jax.lax.scan(\n",
"        one_step,\n",
"        (descendant_arr, counter, L_x_chol),\n",
"        (jnp.arange(n), L_indptr_mat, L_indices_mat, L_x_mat),\n",
"    )\n",
"    return L_x_chol, descendant_arr\n",
"\n",
"\n",
"class sparse_csc:\n",
"    \"\"\"Sparse Cholesky factorisation for a fixed CSC sparsity pattern.\n",
"\n",
"    The symbolic factorisation and the padded index arrays are computed once in\n",
"    __init__; chol() then only needs the values of tril(A). The instance is\n",
"    passed to jax.jit as a static argument, so it is hashed by identity and\n",
"    each instance compiles its own version of chol.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, A_indices, A_indptr, use_scan=False):\n",
"        \"\"\"Precompute the pattern of L for the lower triangle (A_indices, A_indptr).\n",
"\n",
"        use_scan selects _deep_copy_csc_scan over _deep_copy_csc_vmap in chol().\n",
"        \"\"\"\n",
"        self.use_scan = use_scan\n",
"        L_indices, L_indptr = _symbolic_factor_csc(A_indices, A_indptr)\n",
"        # Widest column of A and of L. They set the padded array shapes, so\n",
"        # they are kept as plain Python ints and used as static values.\n",
"        self.m_A = int(np.max(np.diff(A_indptr)))\n",
"        self.m_L = int(np.max(np.diff(L_indptr)))\n",
"\n",
"        self.A_indices = jnp.asarray(A_indices)\n",
"        self.A_indptr = jnp.asarray(A_indptr)\n",
"        self.L_indices = jnp.asarray(L_indices)\n",
"        self.L_indptr = jnp.asarray(L_indptr)\n",
"        # Use self.m_L, not a notebook-level m_L left over from an earlier cell.\n",
"        self.L_indptr_mat = jax.vmap(partial(fix_length_arange, self.m_L))(\n",
"            L_indptr[:-1], L_indptr[1:]\n",
"        )\n",
"        # True at the real (non-padding) slots of the padded columns.\n",
"        self.mask = self.L_indptr_mat != -1\n",
"\n",
"    @partial(jax.jit, static_argnums=(0,))\n",
"    def chol(self, A_x):\n",
"        \"\"\"Return the flat values of L for the values A_x of tril(A).\n",
"\n",
"        The result is ordered like self.L_indices.\n",
"        \"\"\"\n",
"        if self.use_scan:\n",
"            L_x, _ = _deep_copy_csc_scan(\n",
"                self.m_L,\n",
"                self.A_indices,\n",
"                self.A_indptr,\n",
"                A_x,\n",
"                self.L_indices,\n",
"                self.L_indptr,\n",
"            )\n",
"            L_x_mat = L_x[self.L_indptr_mat]\n",
"        else:\n",
"            L_x_mat, _ = _deep_copy_csc_vmap(\n",
"                self.m_A,\n",
"                self.m_L,\n",
"                self.A_indices,\n",
"                self.A_indptr,\n",
"                A_x,\n",
"                self.L_indices,\n",
"                self.L_indptr,\n",
"            )\n",
"        L_indices_mat = self.L_indices[self.L_indptr_mat]\n",
"        L_x_chol, _ = _sparse_cholesky_csc_impl_jax(\n",
"            self.m_L, L_indices_mat, self.L_indptr_mat, L_x_mat\n",
"        )\n",
"        # Drop the padding to get back to the flat CSC value array.\n",
"        return L_x_chol[self.mask]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "67c7dbfc",
"metadata": {
"id": "67c7dbfc"
},
"outputs": [],
"source": [
"A_perm = A[perm[:, None], perm]\n",
"A_perm_lower = sparse.tril(A_perm, format=\"csc\")\n",
"A_indices = A_perm_lower.indices\n",
"A_indptr = A_perm_lower.indptr\n",
"A_x = A_perm_lower.data\n",
"\n",
"L_indices, L_indptr, L_x = sparse_cholesky_csc(A_indices, A_indptr, A_x)\n",
"L_x_ = sparse_csc(A_indices, A_indptr, True).chol(A_x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d1201e44",
"metadata": {
"id": "d1201e44"
},
"outputs": [],
"source": [
"np.testing.assert_array_almost_equal(L_x, L_x_)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.11"
},
"vscode": {
"interpreter": {
"hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf"
}
},
"colab": {
"name": "Sparse cholesky in JAX.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment