Skip to content

Instantly share code, notes, and snippets.

@mariogeiger
Created September 30, 2022 16:03
Show Gist options
  • Save mariogeiger/1be3d0d482d51ceea44cf43eaf44a648 to your computer and use it in GitHub Desktop.
Save mariogeiger/1be3d0d482d51ceea44cf43eaf44a648 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import e3nn_jax as e3nn\n",
"import haiku as hk\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.scipy\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from e3nn_jax.util import assert_equivariant\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"np.set_printoptions(precision=3, suppress=True)\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def rotation_x_towards_y(\n",
"    x: jnp.ndarray, y: jnp.ndarray, irreps: e3nn.Irreps = \"1e\"\n",
") -> jnp.ndarray:\n",
"    \"\"\"Rotation matrix, in the representation ``irreps``, turning ``x`` towards ``y``.\n",
"\n",
"    The rotation is about the unit axis ``cross(x, y)`` by the angle between\n",
"    ``x`` and ``y``. When ``x`` and ``y`` are (anti)parallel the cross product\n",
"    vanishes and a fallback axis perpendicular to ``x`` is used instead.\n",
"\n",
"    Args:\n",
"        x: source vector with 3 components; need not be normalized.\n",
"        y: target vector with 3 components; need not be normalized.\n",
"        irreps: representation in which the rotation matrix is returned.\n",
"\n",
"    Returns:\n",
"        ``expm(angle * axis . generators)`` for ``e3nn.Irreps(irreps)``.\n",
"        For ``irreps=\"1e\"``, ``R @ x`` points along ``y``.\n",
"    \"\"\"\n",
"    x = x / jnp.linalg.norm(x)\n",
"    y = y / jnp.linalg.norm(y)\n",
"    a = jnp.cross(x, y)\n",
"    na = jnp.linalg.norm(a)\n",
"\n",
"    # Fallback axis for (anti)parallel x and y. It must be perpendicular to x,\n",
"    # otherwise a rotation by pi about it does not send x to -x (a fixed e_x\n",
"    # fails for x = +-e_x). Use the component of e_x orthogonal to x, or that\n",
"    # of e_y when it is longer; for x perpendicular to e_x this is exactly e_x.\n",
"    # The longer of the two always has norm >= sqrt(1/2), so the division is safe.\n",
"    px = jnp.array([1.0, 0.0, 0.0]) - x[0] * x\n",
"    py = jnp.array([0.0, 1.0, 0.0]) - x[1] * x\n",
"    fallback = jnp.where(jnp.linalg.norm(px) >= jnp.linalg.norm(py), px, py)\n",
"    fallback = fallback / jnp.linalg.norm(fallback)\n",
"\n",
"    # The inner `where` keeps the discarded branch from dividing by zero,\n",
"    # which would otherwise produce NaN gradients under jax.grad.\n",
"    an = jnp.where(na > 0.0, a / jnp.where(na > 0.0, na, 1.0), fallback)\n",
"\n",
"    # Rounding can push dot(x, y) of unit vectors slightly outside [-1, 1],\n",
"    # where arccos returns NaN (e.g. for exactly parallel/antiparallel inputs).\n",
"    angle = jnp.arccos(jnp.clip(jnp.dot(x, y), -1.0, 1.0))\n",
"\n",
"    return jax.scipy.linalg.expm(\n",
"        jnp.einsum(\"k,klm->lm\", angle * an, e3nn.Irreps(irreps).generators())\n",
"    )\n",
"\n",
"\n",
"# Checks\n",
"x = jax.random.normal(jax.random.PRNGKey(0), (3,))\n",
"y = jax.random.normal(jax.random.PRNGKey(1), (3,))\n",
"np.testing.assert_allclose(\n",
"    rotation_x_towards_y(x, y) @ x,\n",
"    jnp.linalg.norm(x) * y / jnp.linalg.norm(y),\n",
"    atol=1e-8,\n",
")\n",
"\n",
"# Parallel inputs: identity.\n",
"x = jnp.array([1.0, 0.0, 0.0])\n",
"y = jnp.array([1.0, 0.0, 0.0])\n",
"np.testing.assert_allclose(\n",
"    rotation_x_towards_y(x, y),\n",
"    jnp.eye(3),\n",
")\n",
"\n",
"# Antiparallel inputs along e_x: the fallback axis must be perpendicular to x.\n",
"x = jnp.array([1.0, 0.0, 0.0])\n",
"y = jnp.array([-1.0, 0.0, 0.0])\n",
"np.testing.assert_allclose(\n",
"    rotation_x_towards_y(x, y) @ x,\n",
"    y,\n",
"    atol=1e-8,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0. 1. 0. 0. 0. 0. 0.]\n",
" [1. 0. 1. 0. 0. 0. 0.]\n",
" [0. 0. 0. 0. 1. 0. 0.]\n",
" [0. 0. 0. 1. 0. 1. 0.]\n",
" [0. 0. 0. 0. 0. 0. 1.]]\n"
]
}
],
"source": [
"def tie_matrix(irreps: e3nn.Irreps) -> jnp.ndarray:\n",
"    \"\"\"0/1 matrix that groups each ``+m`` component of ``irreps`` with its ``-m`` partner.\n",
"\n",
"    There is one row per (copy of an irrep, ``m``) with ``m = 0..l``. The row has\n",
"    ones at columns ``j + l - m`` and ``j + l + m``, where ``j`` is the column\n",
"    offset of that copy (a single one when ``m = 0``).\n",
"\n",
"    Returns:\n",
"        Array of shape ``(sum(mul * (l + 1)), irreps.dim)``.\n",
"    \"\"\"\n",
"    irreps = e3nn.Irreps(irreps)\n",
"    n = sum(mul * (1 + ir.l) for mul, ir in irreps)\n",
"    # Fill a NumPy buffer in place and convert once: `.at[...].set` on a JAX\n",
"    # array returns a full copy of the matrix on every call.\n",
"    M = np.zeros((n, irreps.dim))\n",
"\n",
"    i = 0  # current row\n",
"    j = 0  # column offset of the current irrep copy\n",
"    for mul, ir in irreps:\n",
"        for _ in range(mul):\n",
"            for m in range(1 + ir.l):\n",
"                M[i, j + ir.l - m] = 1.0\n",
"                M[i, j + ir.l + m] = 1.0\n",
"                i += 1\n",
"            j += 2 * ir.l + 1\n",
"    return jnp.asarray(M)\n",
"\n",
"\n",
"# Checks\n",
"print(tie_matrix(\"2x1e + 0e\"))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class TieBinding(hk.Module):\n",
"    \"\"\"Nonlinearity on ``x`` steered by a direction ``v``.\n",
"\n",
"    ``x`` is rotated by ``R = rotation_x_towards_y(v, +y)``. Every component is\n",
"    then multiplied by a learned mix (parameter ``w``) of the sums of squares\n",
"    that ``tie_matrix`` groups together. The result is mapped back with ``R.T``.\n",
"    Leading axes of ``x.array`` are treated as batch axes.\n",
"    \"\"\"\n",
"\n",
"    def __call__(self, x: e3nn.IrrepsArray, v: jnp.ndarray) -> e3nn.IrrepsArray:\n",
"        assert v.shape == (3,)\n",
"        y_axis = jnp.array([0.0, 1.0, 0.0])\n",
"        to_y = rotation_x_towards_y(v, y_axis, irreps=x.irreps)\n",
"        tie = tie_matrix(x.irreps)\n",
"        rows = tie.shape[0]\n",
"        w = hk.get_parameter(\"w\", (rows, rows), init=hk.initializers.RandomNormal())\n",
"\n",
"        def single(vec: jnp.ndarray) -> jnp.ndarray:\n",
"            # Transform one unbatched irreps vector in the frame where v is +y.\n",
"            rotated = to_y @ vec\n",
"            scale = tie.T @ w @ tie @ rotated**2\n",
"            return to_y.T @ (scale * rotated)\n",
"\n",
"        # Map over every leading axis; the last axis holds the irreps components.\n",
"        batched = single\n",
"        for _ in range(x.array.ndim - 1):\n",
"            batched = jax.vmap(batched)\n",
"        return e3nn.IrrepsArray(x.irreps, batched(x.array))\n",
"\n",
"\n",
"# Test equivariance\n",
"@hk.without_apply_rng\n",
"@hk.transform\n",
"def model(x: e3nn.IrrepsArray, v: e3nn.IrrepsArray) -> e3nn.IrrepsArray:\n",
"    assert v.irreps == \"1e\"\n",
"    return TieBinding()(x, v.array)\n",
"\n",
"\n",
"x = e3nn.normal(\"2e\", jax.random.PRNGKey(0), ())\n",
"v = e3nn.IrrepsArray(\"1e\", jnp.array([0.2, 0.6, 0.8]))\n",
"w = model.init(jax.random.PRNGKey(0), x, v)\n",
"assert_equivariant(\n",
"    lambda x, v: model.apply(w, x, v), jax.random.PRNGKey(0), args_in=(x, v)\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Result: TieBinding is SO(3)-equivariant (the assert_equivariant check above passes)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.6 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "f26faf9d33dc8b83cd077f62f5d9010e5bc51611e479f12b96223e2da63ba699"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment