Last active
January 19, 2023 00:39
-
-
Save mcwitt/7c31e296a1292536014e156ffc4b4b58 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "e41bccde-100f-40b1-b748-eafb91231422", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%config InlineBackend.figure_format = \"retina\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "f1cc9c93-07ef-4329-b669-89081a7e884f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/javascript": [ | |
"\n", | |
" (function() {\n", | |
" jb_set_cell(\"cos2 = (rij.dot(rkj) + epsilon**2) / sp.sqrt(\\n (rij.dot(rij) + epsilon**2) * (rkj.dot(rkj) + epsilon**2)\\n)\\n\\nha2 = k / 2 * (cos2 - sp.cos(theta0)) ** 2\")\n", | |
" })();\n", | |
" " | |
], | |
"text/plain": [ | |
"<IPython.core.display.Javascript object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"import jupyter_black\n", | |
"jupyter_black.load(lab=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "01a4688e-993d-46aa-8897-810cfac92d64", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import jax" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "a480630e-d858-4bc6-9f69-b40b39a03bf0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"jax.config.update(\"jax_enable_x64\", True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "17b272a3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import jax.numpy as jnp\n", | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"import sympy as sp" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "3a592011", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from sympy.vector import CoordSys3D\n", | |
"\n", | |
"C = CoordSys3D(\"C\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "4f4dfbd0", | |
"metadata": {}, | |
"source": [ | |
"## Harmonic angle\n", | |
"\n", | |
"$\\renewcommand{\\vec}[1]{{\\mathbf{\\boldsymbol{{#1}}}}}$\n", | |
"\\begin{equation}\n", | |
"U(\\vec{r}_i, \\vec{r}_j, \\vec{r}_k; k, \\theta_0) = k (\\cos \\theta - \\cos \\theta_0)^2\n", | |
"\\end{equation}\n", | |
"where\n", | |
"\\begin{equation}\n", | |
"\\cos \\theta = \\frac{\\vec{r}_{ij} \\cdot \\vec{r}_{kj}}{r_{ij} r_{kj} + \\epsilon^2}\n", | |
"\\end{equation}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "808d8823", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from sympy.abc import k, epsilon" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "7434fcda", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"xi, xj, xk = sp.symbols(\"x_i x_j x_k\")\n", | |
"yi, yj, yk = sp.symbols(\"y_i y_j y_k\")\n", | |
"zi, zj, zk = sp.symbols(\"z_i z_j z_k\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "979c26cb", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ri = xi * C.i + yi * C.j + zi * C.k\n", | |
"rj = xj * C.i + yj * C.j + zj * C.k\n", | |
"rk = xk * C.i + yk * C.j + zk * C.k" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "d69d8370", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"rij = rj - ri\n", | |
"rkj = rj - rk" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "11005c24", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"cos1 = rij.dot(rkj) / (rij.magnitude() * rkj.magnitude() + epsilon**2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "0ce55e52", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"theta0 = sp.symbols(\"theta_0\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "9914400e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ha1 = k / 2 * (cos1 - sp.cos(theta0)) ** 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "8afde9ff", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def ha1_approx(x_i, x_j, x_k, y_i, y_j, y_k, z_i, z_j, z_k, k, theta_0, epsilon):\n", | |
" ri = jnp.array([x_i, y_i, z_i])\n", | |
" rj = jnp.array([x_j, y_j, z_j])\n", | |
" rk = jnp.array([x_k, y_k, z_k])\n", | |
" rij = rj - ri\n", | |
" rkj = rj - rk\n", | |
" cos1 = jnp.dot(rij, rkj) / (\n", | |
" jnp.linalg.norm(rij) * jnp.linalg.norm(rkj) + epsilon**2\n", | |
" )\n", | |
" return k / 2 * (cos1 - jnp.cos(theta_0)) ** 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "42d7a094", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def assert_consistency_on_random_inputs(f, f_approx, n=5):\n", | |
" for _ in range(n):\n", | |
" subs = dict(\n", | |
" x_i=np.random.uniform(),\n", | |
" x_j=np.random.uniform(),\n", | |
" x_k=np.random.uniform(),\n", | |
" y_i=np.random.uniform(),\n", | |
" y_j=np.random.uniform(),\n", | |
" y_k=np.random.uniform(),\n", | |
" z_i=np.random.uniform(),\n", | |
" z_j=np.random.uniform(),\n", | |
" z_k=np.random.uniform(),\n", | |
" k=np.random.uniform(),\n", | |
" theta_0=np.random.uniform(),\n", | |
" epsilon=np.random.uniform(),\n", | |
" )\n", | |
"\n", | |
" np.testing.assert_allclose(float(f.evalf(subs=subs)), f_approx(*subs.values()))\n", | |
"\n", | |
" # check with r_ij = 0\n", | |
" subs = dict(\n", | |
" x_i=0.0,\n", | |
" x_j=0.0,\n", | |
" x_k=np.random.uniform(),\n", | |
" y_i=0.0,\n", | |
" y_j=0.0,\n", | |
" y_k=np.random.uniform(),\n", | |
" z_i=0.0,\n", | |
" z_j=0.0,\n", | |
" z_k=np.random.uniform(),\n", | |
" k=np.random.uniform(),\n", | |
" theta_0=np.random.uniform(),\n", | |
" epsilon=np.random.uniform(),\n", | |
" )\n", | |
"\n", | |
" x = float(f.evalf(subs=subs))\n", | |
" assert not np.isnan(x)\n", | |
" np.testing.assert_allclose(x, f_approx(*subs.values()))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "bb5bab7d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" | |
] | |
} | |
], | |
"source": [ | |
"assert_consistency_on_random_inputs(ha1, ha1_approx)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "94055e08", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "AssertionError", | |
"evalue": "", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", | |
"Input \u001b[0;32mIn [17]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43massert_consistency_on_random_inputs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mha1\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdiff\u001b[49m\u001b[43m(\u001b[49m\u001b[43mxi\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgrad\u001b[49m\u001b[43m(\u001b[49m\u001b[43mha1_approx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margnums\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", | |
"Input \u001b[0;32mIn [15]\u001b[0m, in \u001b[0;36massert_consistency_on_random_inputs\u001b[0;34m(f, f_approx, n)\u001b[0m\n\u001b[1;32m 21\u001b[0m subs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mdict\u001b[39m(\n\u001b[1;32m 22\u001b[0m x_i\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.0\u001b[39m,\n\u001b[1;32m 23\u001b[0m x_j\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.0\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 33\u001b[0m epsilon\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39muniform(),\n\u001b[1;32m 34\u001b[0m )\n\u001b[1;32m 36\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mfloat\u001b[39m(f\u001b[38;5;241m.\u001b[39mevalf(subs\u001b[38;5;241m=\u001b[39msubs))\n\u001b[0;32m---> 37\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m np\u001b[38;5;241m.\u001b[39misnan(x)\n\u001b[1;32m 38\u001b[0m np\u001b[38;5;241m.\u001b[39mtesting\u001b[38;5;241m.\u001b[39massert_allclose(x, f_approx(\u001b[38;5;241m*\u001b[39msubs\u001b[38;5;241m.\u001b[39mvalues()))\n", | |
"\u001b[0;31mAssertionError\u001b[0m: " | |
] | |
} | |
], | |
"source": [ | |
"assert_consistency_on_random_inputs(ha1.diff(xi), jax.grad(ha1_approx, argnums=0))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "43a4fe61", | |
"metadata": {}, | |
"source": [ | |
"#### Using naive common subexpression elimination, the resulting expressions for the spatial derivatives are singular at $r_{ij}=0$ and $r_{kj}=0$" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "fd0f2a35", | |
"metadata": {}, | |
"source": [ | |
"## Harmonic angle (alternate)\n", | |
"\n", | |
"$\\renewcommand{\\vec}[1]{{\\mathbf{\\boldsymbol{{#1}}}}}$\n", | |
"\\begin{equation}\n", | |
"U(\\vec{r}_i, \\vec{r}_j, \\vec{r}_k; k, \\theta_0) = \\frac{1}{2} k (\\cos \\theta - \\cos \\theta_0)^2\n", | |
"\\end{equation}\n", | |
"where\n", | |
"\\begin{equation}\n", | |
"\\cos \\theta = \\frac{\\vec{r}_{ij} \\cdot \\vec{r}_{kj} + \\epsilon^2}{\\sqrt{(r_{ij}^2 + \\epsilon^2)(r_{kj}^2 + \\epsilon^2)}}\n", | |
"\\end{equation}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "795d89ee", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"cos2 = (rij.dot(rkj) + epsilon**2) / sp.sqrt(\n", | |
" (rij.dot(rij) + epsilon**2) * (rkj.dot(rkj) + epsilon**2)\n", | |
")\n", | |
"\n", | |
"ha2 = k / 2 * (cos2 - sp.cos(theta0)) ** 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "a77c67f7", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def ha2_approx(x_i, x_j, x_k, y_i, y_j, y_k, z_i, z_j, z_k, k, theta_0, epsilon):\n", | |
" ri = jnp.array([x_i, y_i, z_i])\n", | |
" rj = jnp.array([x_j, y_j, z_j])\n", | |
" rk = jnp.array([x_k, y_k, z_k])\n", | |
" rij = rj - ri\n", | |
" rkj = rj - rk\n", | |
" cos2 = (jnp.dot(rij, rkj) + epsilon**2) / jnp.sqrt(\n", | |
" (jnp.dot(rij, rij) + epsilon**2) * (jnp.dot(rkj, rkj) + epsilon**2)\n", | |
" )\n", | |
" return k / 2 * (cos2 - jnp.cos(theta_0)) ** 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "172ecb84", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"assert_consistency_on_random_inputs(ha2, ha2_approx)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "9380dd68", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"assert_consistency_on_random_inputs(ha2.diff(xi), jax.grad(ha2_approx, argnums=0))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "2aa8eed8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"([(x0, epsilon**2),\n", | |
" (x1, x_i - x_j),\n", | |
" (x2, -x1),\n", | |
" (x3, x_j - x_k),\n", | |
" (x4, -y_i + y_j),\n", | |
" (x5, y_j - y_k),\n", | |
" (x6, -z_i + z_j),\n", | |
" (x7, z_j - z_k),\n", | |
" (x8, x0 + x2*x3 + x4*x5 + x6*x7),\n", | |
" (x9, x0 + x2**2 + x4**2 + x6**2),\n", | |
" (x10, x0 + x3**2 + x5**2 + x7**2),\n", | |
" (x11, 1/sqrt(x10*x9)),\n", | |
" (x12, x11*x8 - cos(theta_0)),\n", | |
" (x13, x12**2/2),\n", | |
" (x14, 2*x_j),\n", | |
" (x15, -x14),\n", | |
" (x16, x15 + 2*x_i),\n", | |
" (x17, x11*x8),\n", | |
" (x18, x17/x9),\n", | |
" (x19, k*x12),\n", | |
" (x20, x19/2),\n", | |
" (x21, 2*x11),\n", | |
" (x22, x14 - 2*x_k),\n", | |
" (x23, 1/x10),\n", | |
" (x24, 2*x18*x23)],\n", | |
" [k*x13,\n", | |
" x20*(-2*x11*x3 - x16*x18),\n", | |
" x20*(x21*(-x15 - x_i - x_k) + x24*(x10*x16/2 - x22*x9/2)),\n", | |
" x20*(x1*x21 + x17*x22*x23),\n", | |
" x13,\n", | |
" x19*sin(theta_0),\n", | |
" x20*(4*epsilon*x11 + x24*(-epsilon*x10 - epsilon*x9))])" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"sp.cse(\n", | |
" [\n", | |
" ha2,\n", | |
" ha2.diff(xi),\n", | |
" ha2.diff(xj),\n", | |
" ha2.diff(xk),\n", | |
" ha2.diff(k),\n", | |
" ha2.diff(theta0),\n", | |
" ha2.diff(epsilon),\n", | |
" ]\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"id": "6a65a001", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def ha2_value_and_grad(\n", | |
" x_i, x_j, x_k, y_i, y_j, y_k, z_i, z_j, z_k, k, theta_0, epsilon\n", | |
"):\n", | |
" eps2 = epsilon**2\n", | |
"\n", | |
" xij = x_j - x_i\n", | |
" xkj = x_j - x_k\n", | |
" yij = y_j - y_i\n", | |
" ykj = y_j - y_k\n", | |
" zij = z_j - z_i\n", | |
" zkj = z_j - z_k\n", | |
"\n", | |
" rij_dot_rkj = eps2 + xij * xkj + yij * ykj + zij * zkj\n", | |
" rij_dot_rij = eps2 + xij**2 + yij**2 + zij**2\n", | |
" rkj_dot_rkj = eps2 + xkj**2 + ykj**2 + zkj**2\n", | |
"\n", | |
" norm = np.sqrt(rij_dot_rij * rkj_dot_rkj)\n", | |
" delta = rij_dot_rkj / norm - np.cos(theta_0)\n", | |
"\n", | |
" cij = rij_dot_rkj / rij_dot_rij\n", | |
" ckj = rij_dot_rkj / rkj_dot_rkj\n", | |
"\n", | |
" grad_x = (k * delta / norm) * np.array(\n", | |
" [\n", | |
" cij * xij - xkj,\n", | |
" (1 - cij) * xij + (1 - ckj) * xkj,\n", | |
" -xij + ckj * xkj,\n", | |
" ]\n", | |
" )\n", | |
"\n", | |
" grad_p = [\n", | |
" delta**2 / 2,\n", | |
" k * delta * np.sin(theta_0),\n", | |
" k * delta * epsilon * (2 - ckj - cij) / norm,\n", | |
" ]\n", | |
"\n", | |
" return k * delta**2 / 2, grad_x, grad_p" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"id": "23f35bd7", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"assert_consistency_on_random_inputs(ha2, lambda *args: ha2_value_and_grad(*args)[0])\n", | |
"assert_consistency_on_random_inputs(\n", | |
" ha2.diff(xi), lambda *args: ha2_value_and_grad(*args)[1][0]\n", | |
")\n", | |
"assert_consistency_on_random_inputs(\n", | |
" ha2.diff(xj), lambda *args: ha2_value_and_grad(*args)[1][1]\n", | |
")\n", | |
"assert_consistency_on_random_inputs(\n", | |
" ha2.diff(xk), lambda *args: ha2_value_and_grad(*args)[1][2]\n", | |
")\n", | |
"assert_consistency_on_random_inputs(\n", | |
" ha2.diff(k), lambda *args: ha2_value_and_grad(*args)[2][0]\n", | |
")\n", | |
"assert_consistency_on_random_inputs(\n", | |
" ha2.diff(theta0), lambda *args: ha2_value_and_grad(*args)[2][1]\n", | |
")\n", | |
"assert_consistency_on_random_inputs(\n", | |
" ha2.diff(epsilon), lambda *args: ha2_value_and_grad(*args)[2][2]\n", | |
")" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment