Skip to content

Instantly share code, notes, and snippets.

@mcwitt
Last active January 19, 2023 00:39
Show Gist options
  • Save mcwitt/7c31e296a1292536014e156ffc4b4b58 to your computer and use it in GitHub Desktop.
Save mcwitt/7c31e296a1292536014e156ffc4b4b58 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "e41bccde-100f-40b1-b748-eafb91231422",
"metadata": {},
"outputs": [],
"source": [
"# Render inline matplotlib figures as high-resolution (\"retina\") PNGs.\n",
"# NOTE(review): no cell below currently draws a figure.\n",
"%config InlineBackend.figure_format = 'retina'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f1cc9c93-07ef-4329-b669-89081a7e884f",
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" (function() {\n",
" jb_set_cell(\"cos2 = (rij.dot(rkj) + epsilon**2) / sp.sqrt(\\n (rij.dot(rij) + epsilon**2) * (rkj.dot(rkj) + epsilon**2)\\n)\\n\\nha2 = k / 2 * (cos2 - sp.cos(theta0)) ** 2\")\n",
" })();\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import jupyter_black\n",
"# Reformat code cells with black as they are executed; lab=False selects the\n",
"# classic Notebook frontend integration (lab=True is for JupyterLab).\n",
"jupyter_black.load(lab=False)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "01a4688e-993d-46aa-8897-810cfac92d64",
"metadata": {},
"outputs": [],
"source": [
"import jax"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a480630e-d858-4bc6-9f69-b40b39a03bf0",
"metadata": {},
"outputs": [],
"source": [
"# Enable 64-bit floats in JAX (float32 is the default). float32 round-off would\n",
"# likely exceed the default rtol=1e-7 of np.testing.assert_allclose used below.\n",
"jax.config.update(\"jax_enable_x64\", True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "17b272a3",
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import sympy as sp"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3a592011",
"metadata": {},
"outputs": [],
"source": [
"from sympy.vector import CoordSys3D\n",
"\n",
"# Cartesian frame; its unit vectors C.i, C.j, C.k build the symbolic position vectors.\n",
"C = CoordSys3D(\"C\")"
]
},
{
"cell_type": "markdown",
"id": "4f4dfbd0",
"metadata": {},
"source": [
"## Harmonic angle\n",
"\n",
"$\\renewcommand{\\vec}[1]{{\\mathbf{\\boldsymbol{{#1}}}}}$\n",
"\\begin{equation}\n",
"U(\\vec{r}_i, \\vec{r}_j, \\vec{r}_k; k, \\theta_0) = \\frac{1}{2} k (\\cos \\theta - \\cos \\theta_0)^2\n",
"\\end{equation}\n",
"where\n",
"\\begin{equation}\n",
"\\cos \\theta = \\frac{\\vec{r}_{ij} \\cdot \\vec{r}_{kj}}{r_{ij} r_{kj} + \\epsilon^2}\n",
"\\end{equation}"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "808d8823",
"metadata": {},
"outputs": [],
"source": [
"from sympy.abc import k, epsilon"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7434fcda",
"metadata": {},
"outputs": [],
"source": [
"# Cartesian coordinates of particles i, j, k; j is the vertex of the angle.\n",
"xi, xj, xk, yi, yj, yk, zi, zj, zk = sp.symbols(\n",
"    \"x_i x_j x_k y_i y_j y_k z_i z_j z_k\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "979c26cb",
"metadata": {},
"outputs": [],
"source": [
"# Symbolic position vectors r = x C.i + y C.j + z C.k of the three particles.\n",
"ri, rj, rk = (\n",
"    x * C.i + y * C.j + z * C.k\n",
"    for x, y, z in [(xi, yi, zi), (xj, yj, zj), (xk, yk, zk)]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d69d8370",
"metadata": {},
"outputs": [],
"source": [
"# Bond vectors pointing from the outer particles i and k into the vertex j.\n",
"rij, rkj = rj - ri, rj - rk"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "11005c24",
"metadata": {},
"outputs": [],
"source": [
"# cos(theta) regularized by adding epsilon**2 to the product of the bond lengths.\n",
"# magnitude() is a square root of the squared components, so differentiating it\n",
"# produces r/|r| terms that evaluate to NaN at zero bond length (see the failing\n",
"# gradient check below).\n",
"cos1 = rij.dot(rkj) / (rij.magnitude() * rkj.magnitude() + epsilon**2)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0ce55e52",
"metadata": {},
"outputs": [],
"source": [
"# Reference (equilibrium) angle of the harmonic term.\n",
"theta0 = sp.Symbol(\"theta_0\")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "9914400e",
"metadata": {},
"outputs": [],
"source": [
"# Harmonic-angle energy k/2 * (cos(theta) - cos(theta_0))**2 built on cos1.\n",
"ha1 = k / 2 * (cos1 - sp.cos(theta0)) ** 2"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "8afde9ff",
"metadata": {},
"outputs": [],
"source": [
"def ha1_approx(x_i, x_j, x_k, y_i, y_j, y_k, z_i, z_j, z_k, k, theta_0, epsilon):\n",
"    \"\"\"JAX mirror of the sympy expression ``ha1``.\n",
"\n",
"    Arguments are scalar values of the symbols of ``ha1``, in this order.\n",
"    Returns k / 2 * (cos1 - cos(theta_0))**2 with\n",
"    cos1 = r_ij . r_kj / (|r_ij| |r_kj| + epsilon**2), where r_ij = r_j - r_i\n",
"    and r_kj = r_j - r_k. Written with jnp so ``jax.grad`` can differentiate it;\n",
"    like the sympy derivative, the gradient presumably contains NaN at zero\n",
"    bond length because of the norm.\n",
"    \"\"\"\n",
"    pos_i, pos_j, pos_k = (\n",
"        jnp.array(coords)\n",
"        for coords in ((x_i, y_i, z_i), (x_j, y_j, z_j), (x_k, y_k, z_k))\n",
"    )\n",
"    bond_ij = pos_j - pos_i\n",
"    bond_kj = pos_j - pos_k\n",
"    length_product = jnp.linalg.norm(bond_ij) * jnp.linalg.norm(bond_kj)\n",
"    cos_theta = jnp.dot(bond_ij, bond_kj) / (length_product + epsilon**2)\n",
"    return k / 2 * (cos_theta - jnp.cos(theta_0)) ** 2"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "42d7a094",
"metadata": {},
"outputs": [],
"source": [
"def _random_subs(rng, **fixed):\n",
"    \"\"\"Draw one uniform [0, 1) value per argument of ``ha*_approx``, then apply ``fixed``.\n",
"\n",
"    Keys are listed in the positional-argument order of ``ha1_approx`` and\n",
"    ``ha2_approx``, so ``f_approx(*subs.values())`` passes each value to the\n",
"    matching parameter; ``dict.update`` keeps the position of overridden keys.\n",
"    \"\"\"\n",
"    names = (\n",
"        \"x_i\", \"x_j\", \"x_k\",\n",
"        \"y_i\", \"y_j\", \"y_k\",\n",
"        \"z_i\", \"z_j\", \"z_k\",\n",
"        \"k\", \"theta_0\", \"epsilon\",\n",
"    )\n",
"    subs = {name: rng.uniform() for name in names}\n",
"    subs.update(fixed)\n",
"    return subs\n",
"\n",
"\n",
"def assert_consistency_on_random_inputs(f, f_approx, n=5, seed=0):\n",
"    \"\"\"Assert that a sympy expression and a numerical implementation agree.\n",
"\n",
"    ``f`` is evaluated with ``f.evalf(subs=...)`` and ``f_approx`` is called\n",
"    positionally with the same values, in the order\n",
"    (x_i, x_j, x_k, y_i, y_j, y_k, z_i, z_j, z_k, k, theta_0, epsilon).\n",
"\n",
"    Checks ``n`` generic points drawn uniformly from [0, 1), then two degenerate\n",
"    geometries in which one bond vector vanishes: r_ij = 0 (i and j both at the\n",
"    origin) and r_kj = 0 (k and j both at the origin). At those points ``f``\n",
"    must not evaluate to NaN.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    f : sympy expression in the symbols listed above.\n",
"    f_approx : callable taking the 12 values positionally.\n",
"    n : number of generic random points.\n",
"    seed : seed for ``np.random.default_rng``; ``None`` draws fresh entropy.\n",
"\n",
"    Raises\n",
"    ------\n",
"    AssertionError\n",
"        If the values differ beyond ``np.testing.assert_allclose``'s default\n",
"        tolerance, or if ``f`` is NaN at a degenerate geometry.\n",
"    \"\"\"\n",
"    rng = np.random.default_rng(seed)\n",
"\n",
"    for _ in range(n):\n",
"        subs = _random_subs(rng)\n",
"        np.testing.assert_allclose(float(f.evalf(subs=subs)), f_approx(*subs.values()))\n",
"\n",
"    degenerate = {\n",
"        \"r_ij = 0\": dict(x_i=0.0, x_j=0.0, y_i=0.0, y_j=0.0, z_i=0.0, z_j=0.0),\n",
"        \"r_kj = 0\": dict(x_k=0.0, x_j=0.0, y_k=0.0, y_j=0.0, z_k=0.0, z_j=0.0),\n",
"    }\n",
"    for label, fixed in degenerate.items():\n",
"        subs = _random_subs(rng, **fixed)\n",
"        reference = float(f.evalf(subs=subs))\n",
"        assert not np.isnan(reference), f\"f is NaN at {label}: {subs}\"\n",
"        np.testing.assert_allclose(reference, f_approx(*subs.values()))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "bb5bab7d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
}
],
"source": [
"# The energies agree, including at the degenerate (zero bond length) test points.\n",
"assert_consistency_on_random_inputs(f=ha1, f_approx=ha1_approx)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "94055e08",
"metadata": {},
"outputs": [
{
"ename": "AssertionError",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
"Input \u001b[0;32mIn [17]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43massert_consistency_on_random_inputs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mha1\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdiff\u001b[49m\u001b[43m(\u001b[49m\u001b[43mxi\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgrad\u001b[49m\u001b[43m(\u001b[49m\u001b[43mha1_approx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margnums\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
"Input \u001b[0;32mIn [15]\u001b[0m, in \u001b[0;36massert_consistency_on_random_inputs\u001b[0;34m(f, f_approx, n)\u001b[0m\n\u001b[1;32m 21\u001b[0m subs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mdict\u001b[39m(\n\u001b[1;32m 22\u001b[0m x_i\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.0\u001b[39m,\n\u001b[1;32m 23\u001b[0m x_j\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.0\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 33\u001b[0m epsilon\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39muniform(),\n\u001b[1;32m 34\u001b[0m )\n\u001b[1;32m 36\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mfloat\u001b[39m(f\u001b[38;5;241m.\u001b[39mevalf(subs\u001b[38;5;241m=\u001b[39msubs))\n\u001b[0;32m---> 37\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m np\u001b[38;5;241m.\u001b[39misnan(x)\n\u001b[1;32m 38\u001b[0m np\u001b[38;5;241m.\u001b[39mtesting\u001b[38;5;241m.\u001b[39massert_allclose(x, f_approx(\u001b[38;5;241m*\u001b[39msubs\u001b[38;5;241m.\u001b[39mvalues()))\n",
"\u001b[0;31mAssertionError\u001b[0m: "
]
}
],
"source": [
"# Expected failure: the naive x_i-derivative of ha1 evaluates to NaN (0/0) at\n",
"# r_ij = 0, which the degenerate-geometry check rejects. Catching the error keeps\n",
"# Restart & Run All working, while an unexpected pass is still reported.\n",
"try:\n",
"    assert_consistency_on_random_inputs(ha1.diff(xi), jax.grad(ha1_approx, argnums=0))\n",
"except AssertionError as err:\n",
"    print(f\"Expected failure for d(ha1)/dx_i: {err!r}\")\n",
"else:\n",
"    raise AssertionError(\"expected the ha1 gradient check to fail at r_ij = 0\")"
]
},
{
"cell_type": "markdown",
"id": "43a4fe61",
"metadata": {},
"source": [
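"#### Differentiated naively, the spatial derivatives contain factors of $\\vec{r}_{ij}/r_{ij}$ (from $\\partial r_{ij}/\\partial \\vec{r}_i = -\\vec{r}_{ij}/r_{ij}$) and, likewise, $\\vec{r}_{kj}/r_{kj}$, so they evaluate to NaN (0/0) at $r_{ij}=0$ and $r_{kj}=0$ even though the energy itself is finite there. Common subexpression elimination does not remove this."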
]
},
{
"cell_type": "markdown",
"id": "fd0f2a35",
"metadata": {},
"source": [
"## Harmonic angle (alternate)\n",
"\n",
"$\\renewcommand{\\vec}[1]{{\\mathbf{\\boldsymbol{{#1}}}}}$\n",
"\\begin{equation}\n",
"U(\\vec{r}_i, \\vec{r}_j, \\vec{r}_k; k, \\theta_0) = \\frac{1}{2} k (\\cos \\theta - \\cos \\theta_0)^2\n",
"\\end{equation}\n",
"where\n",
"\\begin{equation}\n",
"\\cos \\theta = \\frac{\\vec{r}_{ij} \\cdot \\vec{r}_{kj} + \\epsilon^2}{\\sqrt{(r_{ij}^2 + \\epsilon^2)(r_{kj}^2 + \\epsilon^2)}}\n",
"\\end{equation}"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "795d89ee",
"metadata": {},
"outputs": [],
"source": [
"# Alternate regularization: epsilon**2 is added to the dot product and to each\n",
"# squared bond length, so for real epsilon != 0 the square root's argument is\n",
"# positive and no bare |r| (with its r/|r| derivative) appears.\n",
"cos2 = (rij.dot(rkj) + epsilon**2) / sp.sqrt(\n",
" (rij.dot(rij) + epsilon**2) * (rkj.dot(rkj) + epsilon**2)\n",
")\n",
"\n",
"# Same harmonic form as ha1, built on cos2.\n",
"ha2 = k / 2 * (cos2 - sp.cos(theta0)) ** 2"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "a77c67f7",
"metadata": {},
"outputs": [],
"source": [
"def ha2_approx(x_i, x_j, x_k, y_i, y_j, y_k, z_i, z_j, z_k, k, theta_0, epsilon):\n",
"    \"\"\"JAX mirror of the sympy expression ``ha2``.\n",
"\n",
"    Arguments are scalar values of (x_i, x_j, x_k, y_i, y_j, y_k, z_i, z_j, z_k,\n",
"    k, theta_0, epsilon). Returns k / 2 * (cos2 - cos(theta_0))**2 with\n",
"    cos2 = (r_ij . r_kj + eps**2) / sqrt((r_ij . r_ij + eps**2) (r_kj . r_kj + eps**2)).\n",
"    No bare vector norm appears, so the value and its ``jax.grad`` stay finite\n",
"    when a bond length is zero (for epsilon != 0).\n",
"    \"\"\"\n",
"    eps_sq = epsilon**2\n",
"    vertex = jnp.array([x_j, y_j, z_j])\n",
"    bond_ij = vertex - jnp.array([x_i, y_i, z_i])\n",
"    bond_kj = vertex - jnp.array([x_k, y_k, z_k])\n",
"    numerator = jnp.dot(bond_ij, bond_kj) + eps_sq\n",
"    denominator = jnp.sqrt(\n",
"        (jnp.dot(bond_ij, bond_ij) + eps_sq) * (jnp.dot(bond_kj, bond_kj) + eps_sq)\n",
"    )\n",
"    return k / 2 * (numerator / denominator - jnp.cos(theta_0)) ** 2"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "172ecb84",
"metadata": {},
"outputs": [],
"source": [
"# Energy of the alternate form matches its JAX mirror.\n",
"assert_consistency_on_random_inputs(f=ha2, f_approx=ha2_approx)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "9380dd68",
"metadata": {},
"outputs": [],
"source": [
"# Unlike ha1, the x_i-derivative of ha2 is finite and consistent at the degenerate points.\n",
"assert_consistency_on_random_inputs(\n",
"    f=ha2.diff(xi), f_approx=jax.grad(ha2_approx, argnums=0)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "2aa8eed8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"([(x0, epsilon**2),\n",
" (x1, x_i - x_j),\n",
" (x2, -x1),\n",
" (x3, x_j - x_k),\n",
" (x4, -y_i + y_j),\n",
" (x5, y_j - y_k),\n",
" (x6, -z_i + z_j),\n",
" (x7, z_j - z_k),\n",
" (x8, x0 + x2*x3 + x4*x5 + x6*x7),\n",
" (x9, x0 + x2**2 + x4**2 + x6**2),\n",
" (x10, x0 + x3**2 + x5**2 + x7**2),\n",
" (x11, 1/sqrt(x10*x9)),\n",
" (x12, x11*x8 - cos(theta_0)),\n",
" (x13, x12**2/2),\n",
" (x14, 2*x_j),\n",
" (x15, -x14),\n",
" (x16, x15 + 2*x_i),\n",
" (x17, x11*x8),\n",
" (x18, x17/x9),\n",
" (x19, k*x12),\n",
" (x20, x19/2),\n",
" (x21, 2*x11),\n",
" (x22, x14 - 2*x_k),\n",
" (x23, 1/x10),\n",
" (x24, 2*x18*x23)],\n",
" [k*x13,\n",
" x20*(-2*x11*x3 - x16*x18),\n",
" x20*(x21*(-x15 - x_i - x_k) + x24*(x10*x16/2 - x22*x9/2)),\n",
" x20*(x1*x21 + x17*x22*x23),\n",
" x13,\n",
" x19*sin(theta_0),\n",
" x20*(4*epsilon*x11 + x24*(-epsilon*x10 - epsilon*x9))])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Common subexpressions shared by ha2 and its partial derivatives with respect to\n",
"# x_i, x_j, x_k, k, theta_0 and epsilon (only the x coordinates are included).\n",
"sp.cse([ha2] + [ha2.diff(var) for var in (xi, xj, xk, k, theta0, epsilon)])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "6a65a001",
"metadata": {},
"outputs": [],
"source": [
"def ha2_value_and_grad(\n",
"    x_i, x_j, x_k, y_i, y_j, y_k, z_i, z_j, z_k, k, theta_0, epsilon\n",
"):\n",
"    \"\"\"Hand-derived value and partial derivatives of the ``ha2`` energy.\n",
"\n",
"    U = k / 2 * (cos2 - cos(theta_0))**2 with\n",
"    cos2 = (r_ij . r_kj + eps**2) / sqrt((r_ij . r_ij + eps**2) (r_kj . r_kj + eps**2)),\n",
"    r_ij = r_j - r_i and r_kj = r_j - r_k.\n",
"\n",
"    Returns\n",
"    -------\n",
"    value : U.\n",
"    grad_x : np.ndarray (dU/dx_i, dU/dx_j, dU/dx_k). Only the x components are\n",
"        computed; the y and z partial derivatives are not returned.\n",
"    grad_p : list (dU/dk, dU/dtheta_0, dU/depsilon).\n",
"    \"\"\"\n",
"    eps_sq = epsilon**2\n",
"\n",
"    # Bond-vector components; both bonds point into the vertex j.\n",
"    dx_ij, dy_ij, dz_ij = x_j - x_i, y_j - y_i, z_j - z_i\n",
"    dx_kj, dy_kj, dz_kj = x_j - x_k, y_j - y_k, z_j - z_k\n",
"\n",
"    # epsilon-regularized inner products r_ij.r_kj, r_ij.r_ij and r_kj.r_kj.\n",
"    cross_dot = eps_sq + dx_ij * dx_kj + dy_ij * dy_kj + dz_ij * dz_kj\n",
"    len_sq_ij = eps_sq + dx_ij**2 + dy_ij**2 + dz_ij**2\n",
"    len_sq_kj = eps_sq + dx_kj**2 + dy_kj**2 + dz_kj**2\n",
"\n",
"    denom = np.sqrt(len_sq_ij * len_sq_kj)\n",
"    cos_diff = cross_dot / denom - np.cos(theta_0)\n",
"\n",
"    # With these ratios, d(cos2)/dr_i = (ratio_ij * r_ij - r_kj) / denom,\n",
"    # d(cos2)/dr_k = (ratio_kj * r_kj - r_ij) / denom, and d(cos2)/dr_j is minus their sum.\n",
"    ratio_ij = cross_dot / len_sq_ij\n",
"    ratio_kj = cross_dot / len_sq_kj\n",
"    prefactor = k * cos_diff / denom\n",
"\n",
"    grad_x = prefactor * np.array(\n",
"        [\n",
"            ratio_ij * dx_ij - dx_kj,\n",
"            (1 - ratio_ij) * dx_ij + (1 - ratio_kj) * dx_kj,\n",
"            -dx_ij + ratio_kj * dx_kj,\n",
"        ]\n",
"    )\n",
"\n",
"    grad_p = [\n",
"        cos_diff**2 / 2,\n",
"        k * cos_diff * np.sin(theta_0),\n",
"        k * cos_diff * epsilon * (2 - ratio_kj - ratio_ij) / denom,\n",
"    ]\n",
"\n",
"    energy = k * cos_diff**2 / 2\n",
"    return energy, grad_x, grad_p"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "23f35bd7",
"metadata": {},
"outputs": [],
"source": [
"# Compare every output of ha2_value_and_grad with the sympy reference. Each entry\n",
"# pairs a sympy expression with a selector into (value, grad_x, grad_p).\n",
"for expr, pick in (\n",
"    (ha2, lambda out: out[0]),\n",
"    (ha2.diff(xi), lambda out: out[1][0]),\n",
"    (ha2.diff(xj), lambda out: out[1][1]),\n",
"    (ha2.diff(xk), lambda out: out[1][2]),\n",
"    (ha2.diff(k), lambda out: out[2][0]),\n",
"    (ha2.diff(theta0), lambda out: out[2][1]),\n",
"    (ha2.diff(epsilon), lambda out: out[2][2]),\n",
"):\n",
"    assert_consistency_on_random_inputs(\n",
"        expr, lambda *args, pick=pick: pick(ha2_value_and_grad(*args))\n",
"    )\n",
"\n",
"del expr, pick"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment