Created
September 30, 2022 16:03
-
-
Save mariogeiger/1be3d0d482d51ceea44cf43eaf44a648 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import e3nn_jax as e3nn\n", | |
"import haiku as hk\n", | |
"import jax\n", | |
"import jax.numpy as jnp\n", | |
"import jax.scipy\n", | |
"import matplotlib.pyplot as plt\n", | |
"import numpy as np\n", | |
"from e3nn_jax.util import assert_equivariant\n", | |
"\n", | |
"jax.config.update(\"jax_enable_x64\", True)\n", | |
"np.set_printoptions(precision=3, suppress=True)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def rotation_x_towards_y(\n", | |
" x: jnp.ndarray, y: jnp.ndarray, irreps: e3nn.Irreps = \"1e\"\n", | |
") -> jnp.ndarray:\n", | |
" x = x / jnp.linalg.norm(x)\n", | |
" y = y / jnp.linalg.norm(y)\n", | |
" a = jnp.cross(x, y)\n", | |
" na = jnp.linalg.norm(a)\n", | |
" an = jnp.where(\n", | |
" na > 0.0, a / jnp.where(na > 0.0, na, 1.0), jnp.array([1.0, 0.0, 0.0])\n", | |
" )\n", | |
" a = jnp.arccos(jnp.dot(x, y)) * an\n", | |
"\n", | |
" return jax.scipy.linalg.expm(\n", | |
" jnp.einsum(\"k,klm->lm\", a, e3nn.Irreps(irreps).generators())\n", | |
" )\n", | |
"\n", | |
"\n", | |
"# Checks\n", | |
"x = jax.random.normal(jax.random.PRNGKey(0), (3,))\n", | |
"y = jax.random.normal(jax.random.PRNGKey(1), (3,))\n", | |
"np.testing.assert_allclose(\n", | |
" rotation_x_towards_y(x, y) @ x,\n", | |
" jnp.linalg.norm(x) * y / jnp.linalg.norm(y),\n", | |
" atol=1e-8,\n", | |
")\n", | |
"\n", | |
"x = jnp.array([1.0, 0.0, 0.0])\n", | |
"y = jnp.array([1.0, 0.0, 0.0])\n", | |
"np.testing.assert_allclose(\n", | |
" rotation_x_towards_y(x, y),\n", | |
" jnp.eye(3),\n", | |
")\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[[0. 1. 0. 0. 0. 0. 0.]\n", | |
" [1. 0. 1. 0. 0. 0. 0.]\n", | |
" [0. 0. 0. 0. 1. 0. 0.]\n", | |
" [0. 0. 0. 1. 0. 1. 0.]\n", | |
" [0. 0. 0. 0. 0. 0. 1.]]\n" | |
] | |
} | |
], | |
"source": [ | |
"def tie_matrix(irreps: e3nn.Irreps) -> jnp.ndarray:\n", | |
" irreps = e3nn.Irreps(irreps)\n", | |
" n = sum([mul * (1 + ir.l) for mul, ir in irreps])\n", | |
" M = jnp.zeros((n, irreps.dim))\n", | |
"\n", | |
" i = 0\n", | |
" j = 0\n", | |
" for mul, ir in irreps:\n", | |
" for _ in range(mul):\n", | |
" for m in range(1 + ir.l):\n", | |
" M = M.at[i, j + ir.l - m].set(1.0)\n", | |
" M = M.at[i, j + ir.l + m].set(1.0)\n", | |
" i += 1\n", | |
" j += 2 * ir.l + 1\n", | |
" return M\n", | |
"\n", | |
"# Checks\n", | |
"print(tie_matrix(\"2x1e + 0e\"))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class TieBinding(hk.Module):\n", | |
" def __call__(self, x: e3nn.IrrepsArray, v: jnp.ndarray) -> e3nn.IrrepsArray:\n", | |
" assert v.shape == (3,)\n", | |
" R = rotation_x_towards_y(v, jnp.array([0.0, 1.0, 0.0]), irreps=x.irreps)\n", | |
" M = tie_matrix(x.irreps)\n", | |
" w = hk.get_parameter(\n", | |
" \"w\", (M.shape[0], M.shape[0]), init=hk.initializers.RandomNormal()\n", | |
" )\n", | |
"\n", | |
" def f(x):\n", | |
" if x.ndim == 1:\n", | |
" rx = R @ x\n", | |
" rx = (M.T @ w @ M @ rx**2) * rx\n", | |
" return R.T @ rx\n", | |
" else:\n", | |
" return jax.vmap(f)(x)\n", | |
"\n", | |
" return e3nn.IrrepsArray(x.irreps, f(x.array))\n", | |
"\n", | |
"\n", | |
"# Test equivariance\n", | |
"@hk.without_apply_rng\n", | |
"@hk.transform\n", | |
"def model(x: e3nn.IrrepsArray, v: e3nn.IrrepsArray) -> e3nn.IrrepsArray:\n", | |
" assert v.irreps == \"1e\"\n", | |
" return TieBinding()(x, v.array)\n", | |
"\n", | |
"\n", | |
"x = e3nn.normal(\"2e\", jax.random.PRNGKey(0), ())\n", | |
"v = e3nn.IrrepsArray(\"1e\", jnp.array([0.2, 0.6, 0.8]))\n", | |
"w = model.init(jax.random.PRNGKey(0), x, v)\n", | |
"assert_equivariant(\n", | |
" lambda x, v: model.apply(w, x, v), jax.random.PRNGKey(0), args_in=(x, v)\n", | |
")\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Result: TieBinding is SO(3)-equivariant (checked by assert_equivariant in the previous cell)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3.10.6 64-bit", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.6" | |
}, | |
"orig_nbformat": 4, | |
"vscode": { | |
"interpreter": { | |
"hash": "f26faf9d33dc8b83cd077f62f5d9010e5bc51611e479f12b96223e2da63ba699" | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment