@mariogeiger
Created November 15, 2022 21:09
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Boilerplate"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"import matplotlib.pyplot as plt\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"np.set_printoptions(precision=4, suppress=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Installation\n",
"\n",
"To install the package, run the following command:\n",
"\n",
"```\n",
"pip install git+https://github.com/e3nn/e3nn-jax.git\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import e3nn_jax as e3nn"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Functions for rotations\n",
"\n",
"### parameterizations\n",
"- Euler angles\n",
"- Rotation matrices\n",
"- Axis-angle\n",
"- Quaternions\n",
"\n",
"### functions\n",
"- Conversions between parameterizations\n",
"- Compositions of rotations\n",
"- Random rotations\n",
"- Inverse rotation"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray([0.7821, 0.6875, 1.4887, 5.9424, 1.002 , 5.3127, 0.0833,\n",
" 2.333 , 2.1504, 3.9292], dtype=float64),\n",
" DeviceArray([0.4508, 1.5341, 1.291 , 2.6191, 1.2815, 1.9438, 2.1814,\n",
" 2.0649, 2.0775, 1.1074], dtype=float64),\n",
" DeviceArray([4.4497, 1.0236, 1.8509, 0.0122, 3.4505, 5.3254, 5.8178,\n",
" 3.4122, 2.5171, 5.5721], dtype=float64))"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.rand_angles(jax.random.PRNGKey(0), (10,))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray([-0.0449, -0.8229, 0.5665], dtype=float64),\n",
" DeviceArray(2.2121, dtype=float64))"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.rand_axis_angle(jax.random.PRNGKey(0), ())"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray([0.4486, 0.8926, 0.045 ], dtype=float64),\n",
" DeviceArray(0.4466, dtype=float64, weak_type=True))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.angles_to_axis_angle(alpha=0.1, beta=0.2, gamma=0.3)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray([0.4459, 0.894 , 0.0447], dtype=float64),\n",
" DeviceArray(0.2235, dtype=float64))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.compose_axis_angle(\n",
" axis1=jnp.array([1, 0, 0]),\n",
" angle1=0.1,\n",
" axis2=jnp.array([0, 1, 0]),\n",
" angle2=0.2,\n",
")"
]
},
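{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick consistency check, sketched with a hand-written Rodrigues formula. `rodrigues` is a hypothetical helper, not part of e3nn-jax. The axis-angle pair returned by `e3nn.compose_axis_angle` should describe the product of the two rotations. The order `R1 @ R2` is inferred from the output above (its small positive $z$ component), not taken from the documentation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import e3nn_jax as e3nn\n",
"\n",
"\n",
"def rodrigues(axis, angle):\n",
"    # Hypothetical helper: right-handed rotation by `angle` (radians) about `axis`\n",
"    axis = axis / jnp.linalg.norm(axis)\n",
"    # Cross-product matrix: K @ v == jnp.cross(axis, v)\n",
"    K = jnp.array([\n",
"        [0.0, -axis[2], axis[1]],\n",
"        [axis[2], 0.0, -axis[0]],\n",
"        [-axis[1], axis[0], 0.0],\n",
"    ])\n",
"    return jnp.eye(3) + jnp.sin(angle) * K + (1.0 - jnp.cos(angle)) * (K @ K)\n",
"\n",
"\n",
"axis1, angle1 = jnp.array([1.0, 0.0, 0.0]), 0.1\n",
"axis2, angle2 = jnp.array([0.0, 1.0, 0.0]), 0.2\n",
"axis, angle = e3nn.compose_axis_angle(axis1, angle1, axis2, angle2)\n",
"\n",
"R_composed = rodrigues(axis, angle)\n",
"R_product = rodrigues(axis1, angle1) @ rodrigues(axis2, angle2)\n",
"print(jnp.allclose(R_composed, R_product, atol=1e-6))"
]
},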
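{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Irreps\n",
"\n",
"An `Irreps` object describes a representation of the group $O(3)$ as a direct sum of irreducible representations (irreps). In `64x0e + 32x1o`, the number before `x` is the multiplicity, the integer after it is the degree $l$ (an irrep of dimension $2l+1$), and `e`/`o` is the parity (even/odd).\n",
"\n",
"### functions\n",
"- simplify\n",
"- sort\n",
"- filter\n",
"\n",
"Reordering the irreps, as `sort` does, yields an equivalent representation. Its matrices $D'(g)$ are related to the original $D(g)$ by a fixed invertible change of basis $A$ (here a permutation of the components):\n",
"\n",
"$$ D'(g) = A D(g) A^{-1} $$"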
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10x0e"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.Irreps([(10, (0, 1))])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"64x0e+32x1e+32x1o"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.Irreps(\"64x0e + 32x1e + 32x1o\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.Irreps(\"2x1o\").dim"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.Irreps(\"2x0e + 2x1o\").num_irreps"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4x0e"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.Irreps(\"2x0e + 2x0e\").simplify()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"sort(irreps=2x0e+2x0e+1x1o, p=(0, 2, 1), inv=(0, 2, 1))"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.Irreps(\"2x0e + 1o + 2x0e\").sort()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[ 0.8776, 0.4034, 0.259 ],\n",
" [ 0. , 0.5403, -0.8415],\n",
" [-0.4794, 0.7385, 0.4742]], dtype=float64)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.Irreps(\"1e\").D_from_angles(0.5, 1.0, 0.0)"
]
},
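{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sketch of two properties of the matrices returned by `D_from_angles`, using the same angles as above. For `1e` the matrix is a proper rotation (orthogonal, with determinant $+1$). For a reducible `Irreps`, the matrix is expected to be block diagonal, with identity blocks for the invariant scalars."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import e3nn_jax as e3nn\n",
"\n",
"alpha, beta, gamma = 0.5, 1.0, 0.0\n",
"\n",
"# A single l = 1 irrep: an ordinary 3x3 rotation matrix\n",
"D = e3nn.Irreps(\"1e\").D_from_angles(alpha, beta, gamma)\n",
"print(jnp.allclose(D @ D.T, jnp.eye(3), atol=1e-6))  # orthogonal\n",
"print(jnp.allclose(jnp.linalg.det(D), 1.0, atol=1e-6))  # determinant +1\n",
"\n",
"# Reducible case: one block per irrep, scalars are left unchanged\n",
"D_sum = e3nn.Irreps(\"2x0e + 1o\").D_from_angles(alpha, beta, gamma)\n",
"print(D_sum.shape)\n",
"print(jnp.allclose(D_sum[:2, :2], jnp.eye(2)))\n",
"print(jnp.allclose(D_sum[:2, 2:], 0.0) and jnp.allclose(D_sum[2:, :2], 0.0))\n",
"print(jnp.allclose(D_sum[2:, 2:], e3nn.Irreps(\"1o\").D_from_angles(alpha, beta, gamma)))"
]
},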
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# IrrepsArray\n",
"\n",
"An `IrrepsArray` pairs an array with the `Irreps` that describes how its last axis transforms under $O(3)$. The size of that last axis equals `irreps.dim`."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3x0e+1x1o [300. -10. -5. 1. 0. 0.]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"array = jnp.array([300.0, -10.0, -5.0, 1.0, 0.0, 0.0])\n",
"irreps = e3nn.Irreps(\"3x0e + 1o\")\n",
"\n",
"x = e3nn.IrrepsArray(irreps, array)\n",
"x"
]
},
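{
"cell_type": "markdown",
"metadata": {},
"source": [
"The last axis of an `IrrepsArray` is laid out as consecutive chunks, one per `mul x irrep` entry. Each chunk is `mul * ir.dim` numbers long: here 3 scalars, then one 3-component vector. The slicing helpers further down work on this layout. This sketch walks through it by hand:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import e3nn_jax as e3nn\n",
"\n",
"irreps = e3nn.Irreps(\"3x0e + 1o\")\n",
"x = e3nn.IrrepsArray(irreps, jnp.array([300.0, -10.0, -5.0, 1.0, 0.0, 0.0]))\n",
"\n",
"# Walk the last axis chunk by chunk: each entry occupies mul * ir.dim components\n",
"start = 0\n",
"for mul, ir in irreps:\n",
"    size = mul * ir.dim\n",
"    print(mul, ir, x.array[start:start + size])\n",
"    start += size\n",
"\n",
"# The chunks cover the whole last axis\n",
"print(start == irreps.dim == x.shape[-1])"
]
},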
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3x0e+1x1o [[300. -10. -5. 1. 0. 0.]]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x.reshape((1, 6))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- binary operations\n",
"- indexing\n",
"- `e3nn.mean`\n",
"- `e3nn.norm`\n",
"- `e3nn.normal`\n",
"- `axis_to_mul` and `mul_to_axis`\n",
"- `slice_by_*` functions\n",
"- `sorted`\n",
"- `simplify`\n",
"- `transform_by_*` functions"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x0e [4.]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.IrrepsArray(\"0e\", jnp.array([3.0])) + 1.0  # the Python scalar 1.0 is treated as 0e"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3x0e [301. -9. -4.]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x[:3] + 1.0"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3x0e+1x1o [301. -9. -4. 2. 1. 1.]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x + e3nn.IrrepsArray(\"3x0e + 1o\", jnp.ones(6))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.Irreps(\"3x0e + 1o\").dim"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.Irreps(\"3x0e + 1o\").num_irreps"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x1o [1. 0. 0.]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x[3:]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x1o [1. 0. 0.]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# by dimension: s0 s1 s2 [v0_0 v0_1 v0_2], keep the last 3 components\n",
"\n",
"x.slice_by_dim[-3:]"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x1o [1. 0. 0.]"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# by multiplicity: s0 s1 s2 [v0], keep the last copy of an irrep\n",
"\n",
"x.slice_by_mul[-1:]"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x1o [1. 0. 0.]"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# by chunk: [3x0e] [1x1o], keep the last mul x irrep entry\n",
"\n",
"x.slice_by_chunk[-1:]"
]
},
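{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three indexers count in different units:\n",
"\n",
"- `slice_by_dim` counts array components.\n",
"- `slice_by_mul` counts copies of an irrep.\n",
"- `slice_by_chunk` counts whole `mul x irrep` entries.\n",
"\n",
"This sketch selects the two middle scalars of the same `x` with the first two indexers. It assumes a slice may stop inside a chunk, which is unambiguous here because every scalar is its own irrep copy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import e3nn_jax as e3nn\n",
"\n",
"x = e3nn.IrrepsArray(\"3x0e + 1o\", jnp.array([300.0, -10.0, -5.0, 1.0, 0.0, 0.0]))\n",
"\n",
"print(x.slice_by_dim[1:3])  # components 1 and 2\n",
"print(x.slice_by_mul[1:3])  # copies 1 and 2, both 0e\n",
"print(x.slice_by_chunk[:1])  # the whole first entry, 3x0e"
]
},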
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2x0e+1x1e+2x0e"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = e3nn.IrrepsArray(\"2x0e + 1e + 2x0e\", jnp.ones(2 + 3 + 2))\n",
"\n",
"x.irreps"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4x0e+1x1e [1. 1. 1. 1. 1. 1. 1.]"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x.sorted().simplify()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2x0e [1. 1.]"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x.slice_by_chunk[:1]"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2x1e [ 3. 4. 5. 12. 14. 16.]"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = e3nn.IrrepsArray(\"2x0e\", jnp.array([1.0, 2.0]))\n",
"y = e3nn.IrrepsArray(\"2x1e\", jnp.array([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))\n",
"\n",
"x * y"
]
},
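{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here, with `2x0e` on one side and `2x1e` on the other, `x * y` scales each vector of `y` by the matching scalar of `x`. The same numbers can be reproduced by hand with plain arrays:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"\n",
"x = jnp.array([1.0, 2.0])  # 2x0e: two scalars\n",
"y = jnp.array([3.0, 4.0, 5.0, 6.0, 7.0, 8.0])  # 2x1e: two 3-component vectors\n",
"\n",
"# Scale vector u by scalar u, then flatten back to the 2x1e layout\n",
"manual = (x[:, None] * y.reshape(2, 3)).reshape(-1)\n",
"print(manual)  # same numbers as x * y above: 3 4 5 12 14 16"
]
},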
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"x = e3nn.IrrepsArray(\"0e + 1e\", jnp.array([1.0, 2.0, 3.0, 4.0]))\n",
"y = e3nn.IrrepsArray(\"1e + 0e\", jnp.array([3.0, 4.0, 5.0, 6.0]))\n",
"\n",
"# x * y ---> ValueError: x * y with both x and y non scalar and ambiguous. Use e3nn.elementwise_tensor_product or e3nn.tensor_product instead."
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"x = e3nn.IrrepsArray(\"2x1e\", jnp.array([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))\n",
"y = e3nn.IrrepsArray(\"2x1e\", jnp.array([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))\n",
"\n",
"# x * y ---> ValueError: x * y with both x and y non scalar and ambiguous. Use e3nn.elementwise_tensor_product or e3nn.tensor_product instead."
]
},
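{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the error message suggests, an explicit product resolves the ambiguity. This sketch uses `e3nn.elementwise_tensor_product`, which pairs the vectors of `x` and `y` one by one and decomposes each `1e x 1e` product into irreps. Judging from the `0e`/`0o` values printed in the tensor-product cells below, the scalar part should be the dot product divided by $\\sqrt{3}$. The code prints both results for comparison instead of assuming where the scalars sit in the output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import e3nn_jax as e3nn\n",
"\n",
"x = e3nn.IrrepsArray(\"2x1e\", jnp.array([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))\n",
"y = e3nn.IrrepsArray(\"2x1e\", jnp.array([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))\n",
"\n",
"z = e3nn.elementwise_tensor_product(x, y)\n",
"print(z)\n",
"\n",
"# Dot product of each pair of vectors, divided by sqrt(3): compare with the 0e entries of z\n",
"dots = jnp.sum(x.array.reshape(2, 3) * y.array.reshape(2, 3), axis=1) / jnp.sqrt(3.0)\n",
"print(dots)"
]
},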
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(3, 256)"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = e3nn.normal(\"64x0e + 64x1o\", jax.random.PRNGKey(0), (3,))\n",
"\n",
"x.shape"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(3, 64, 4)"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x.mul_to_axis().shape"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(3, 256)"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x.reshape((-1, -1)).shape"
]
},
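{
"cell_type": "markdown",
"metadata": {},
"source": [
"`mul_to_axis` moves the common multiplicity (64 here) into a new axis, leaving one copy of each irrep on the last axis. This sketch assumes that the companion method `axis_to_mul` (listed above but not demonstrated) undoes it exactly:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import e3nn_jax as e3nn\n",
"\n",
"x = e3nn.normal(\"64x0e + 64x1o\", jax.random.PRNGKey(0), (3,))\n",
"\n",
"# Multiplicity 64 becomes an axis; one 0e and one 1o remain on the last axis\n",
"y = x.mul_to_axis()\n",
"print(y.shape, y.irreps)\n",
"\n",
"# Undo it (assumed companion method) and compare with the original data\n",
"z = y.axis_to_mul()\n",
"print(z.shape, z.irreps)\n",
"print(jnp.allclose(z.array, x.array))"
]
},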
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x1o [2.1239 1.4495 2.7181]"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.IrrepsArray(\"1o\", jnp.array([1.0, 2.0, 3.0])).transform_by_angles(0.1, 0.2, 0.3)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x0e+1x1o+2x1e [6 0 1 2 3 4 5 7 8 9]"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.IrrepsArray(\"1o + 1e + 0e + 1e\", jnp.arange(10)).sorted().simplify()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tensor products\n",
"\n",
"- spherical harmonics\n",
"- tensor product\n",
"- elementwise product\n",
"- tensor square\n",
"- reduced tensor product basis\n",
"- symmetric contraction"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x1e+1x1o [ 3. 4. 5. 12. 18. 24.]"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = e3nn.IrrepsArray(\"0e + 1e\", jnp.array([1.0, 2.0, 3.0, 4.0]))\n",
"y = e3nn.IrrepsArray(\"1e + 0o\", jnp.array([3.0, 4.0, 5.0, 6.0]))\n",
"\n",
"e3nn.elementwise_tensor_product(x, y)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x0e+1x0o+1x1o+1x1e+1x1e+1x2e\n",
"[21.9393 6. 12. 18. 24. 3. 4. 5. -0.7071\n",
" 1.4142 -0.7071 15.5563 12.0208 -0.8165 21.9203 9.8995]"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.tensor_product(x, y)"
]
},
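{
"cell_type": "markdown",
"metadata": {},
"source": [
"The defining property of `e3nn.tensor_product` is equivariance. Rotating the inputs and then taking the product should give the same result as rotating the product. This is a quick numerical check with `transform_by_angles`, using arbitrary angles:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import e3nn_jax as e3nn\n",
"\n",
"x = e3nn.IrrepsArray(\"0e + 1e\", jnp.array([1.0, 2.0, 3.0, 4.0]))\n",
"y = e3nn.IrrepsArray(\"1e + 0o\", jnp.array([3.0, 4.0, 5.0, 6.0]))\n",
"alpha, beta, gamma = 0.1, 0.2, 0.3\n",
"\n",
"# Rotate, then multiply\n",
"lhs = e3nn.tensor_product(\n",
"    x.transform_by_angles(alpha, beta, gamma),\n",
"    y.transform_by_angles(alpha, beta, gamma),\n",
")\n",
"# Multiply, then rotate\n",
"rhs = e3nn.tensor_product(x, y).transform_by_angles(alpha, beta, gamma)\n",
"\n",
"print(lhs.irreps == rhs.irreps, jnp.allclose(lhs.array, rhs.array, atol=1e-6))"
]
},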
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x0o+1x1o+1x1o+1x2o\n",
"[ 3.4641 0.7071 -1.4142 0.7071 100. 100. 100. 2.8284\n",
" 2.1213 0. 3.5355 1.4142]"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = e3nn.IrrepsArray(\"1o\", jnp.array([1.0, 1.0, 1.0]))\n",
"y = e3nn.IrrepsArray(\"1e + 0e\", jnp.array([1.0, 2.0, 3.0, 100.0]))\n",
"\n",
"e3nn.tensor_product(x, y)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[ 1., 2., 3., 100.],\n",
" [ 1., 2., 3., 100.],\n",
" [ 1., 2., 3., 100.]], dtype=float64)"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jnp.einsum(\"i,j\", x.array, y.array)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([ 3.4641, 100. , 100. , 100. , 0.7071, -1.4142,\n",
" 0.7071, 2.8284, 2.1213, -0. , 3.5355, 1.4142], dtype=float64)"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"u = e3nn.reduced_tensor_product_basis(\"ij\", i=\"1o\", j=\"1e + 0e\")\n",
"\n",
"jnp.einsum(\"i,j,ijk->k\", x.array, y.array, u.array)"
]
},
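{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the last two outputs, the contraction with `reduced_tensor_product_basis` yields the same twelve numbers as `e3nn.tensor_product(x, y)`, but with the output irreps in a different order. This sketch prints the irreps of both routes and compares the values with the ordering removed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import e3nn_jax as e3nn\n",
"\n",
"x = e3nn.IrrepsArray(\"1o\", jnp.array([1.0, 1.0, 1.0]))\n",
"y = e3nn.IrrepsArray(\"1e + 0e\", jnp.array([1.0, 2.0, 3.0, 100.0]))\n",
"\n",
"tp = e3nn.tensor_product(x, y)\n",
"u = e3nn.reduced_tensor_product_basis(\"ij\", i=\"1o\", j=\"1e + 0e\")\n",
"basis = jnp.einsum(\"i,j,ijk->k\", x.array, y.array, u.array)\n",
"\n",
"# Output irreps of the two routes (their order differs)\n",
"print(tp.irreps)\n",
"print(u.irreps)\n",
"\n",
"# Sorted values ignore which irrep each number belongs to\n",
"print(jnp.sort(tp.array))\n",
"print(jnp.sort(basis))"
]
},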
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2x1o"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = e3nn.IrrepsArray(\"2x1o\", jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]))\n",
"x.irreps"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4x0e+4x1e+4x2e\n",
"[ 8.0829 18.4752 18.4752 44.456 -0. 0. 0. -2.1213 4.2426\n",
" -2.1213 2.1213 -4.2426 2.1213 -0. 0. 0. 4.2426 2.8284\n",
" -0.8165 8.4853 5.6569 12.7279 9.1924 -0.8165 19.0919 9.8995 12.7279\n",
" 9.1924 -0.8165 19.0919 9.8995 33.9411 28.2843 -0.8165 42.4264 14.1421]"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.tensor_product(x, x)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x0e+2x0e+1x1e+1x2e+2x2e\n",
"[18.4752 3.6148 19.8813 -2.1213 4.2426 -2.1213 12.7279 9.1924 -0.8165\n",
" 19.0919 9.8995 3. 2. -0.5774 6. 4. 24. 20.\n",
" -0.5774 30. 10. ]"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.tensor_square(x)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x2e [3.873 3.873 0. 3.873 0. ]"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = e3nn.IrrepsArray(\"1o\", jnp.array([1.0, 1.0, 1.0]))\n",
"\n",
"e3nn.spherical_harmonics(2, x, False, normalization=\"component\")"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x2e [1. 1. 0. 1. 0.]"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.tensor_square(x)[1:]"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x0o+1x1o+1x2o+1x0e\n",
"[ 3.4641 0.7071 -1.4142 0.7071 2.8284 2.1213 0. 3.5355\n",
" 1.4142 200. ]"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = e3nn.IrrepsArray(\"1o + 0e\", jnp.array([1.0, 1.0, 1.0, 2.0]))\n",
"y = e3nn.IrrepsArray(\"1e + 0e\", jnp.array([1.0, 2.0, 3.0, 100.0]))\n",
"\n",
"e3nn.elementwise_tensor_product(x, y)"
]
},
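{
"cell_type": "markdown",
"metadata": {},
"source": [
"The *Normalization* section below explains how `e3nn.tensor_product`, `e3nn.tensor_square` and `e3nn.spherical_harmonics` differ in normalization. One example is the factor $\\sqrt{15} \\approx 3.873$ between the `2e` outputs of `e3nn.spherical_harmonics` and `e3nn.tensor_square` above."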
]
},
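{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Linear\n",
"\n",
"`e3nn.Linear` is an equivariant linear layer. It only mixes channels that carry the same irrep, using one learned weight matrix per pair of matching input and output irreps; input irreps that do not appear in the output are dropped. Below, inside a Haiku model, it maps the `0e` channels of `e3nn.tensor_square(x)` to `2x0e`."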
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2x0e [46.1297 -5.4155]"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import haiku as hk\n",
"\n",
"@hk.without_apply_rng\n",
"@hk.transform\n",
"def model(x):\n",
" x = e3nn.tensor_square(x)\n",
" x = e3nn.Linear(\"2x0e\")(x)\n",
" return x\n",
"\n",
"x = e3nn.IrrepsArray(\"0e + 1o\", jnp.array([10.0, 1.0, 1.0, 1.0]))\n",
"params = model.init(jax.random.PRNGKey(0), x)\n",
"y = model.apply(params, x)\n",
"\n",
"y"
]
},
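{
"cell_type": "markdown",
"metadata": {},
"source": [
"Every operation in the model above is equivariant and its output is `2x0e`, so the output should not change when the input is rotated. This sketch checks that property; the model is redefined so the cell runs on its own. It also prints the shapes of the learned weights, whose names come from Haiku's module naming:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import haiku as hk\n",
"import e3nn_jax as e3nn\n",
"\n",
"\n",
"@hk.without_apply_rng\n",
"@hk.transform\n",
"def model(x):\n",
"    x = e3nn.tensor_square(x)\n",
"    x = e3nn.Linear(\"2x0e\")(x)\n",
"    return x\n",
"\n",
"\n",
"x = e3nn.IrrepsArray(\"0e + 1o\", jnp.array([10.0, 1.0, 1.0, 1.0]))\n",
"params = model.init(jax.random.PRNGKey(0), x)\n",
"\n",
"# Shapes of the learned weights (names follow Haiku's module naming)\n",
"print(jax.tree_util.tree_map(lambda w: w.shape, params))\n",
"\n",
"# The output is 2x0e, so rotating the input should leave it unchanged\n",
"y = model.apply(params, x)\n",
"y_rot = model.apply(params, x.transform_by_angles(0.1, 0.2, 0.3))\n",
"print(jnp.allclose(y.array, y_rot.array, atol=1e-6))"
]
},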
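{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Grad\n",
"\n",
"The `e3nn.grad` function is a wrapper around `jax.grad` for functions of `IrrepsArray`s. It differentiates with respect to the selected argument and returns an `IrrepsArray` whose irreps combine those of the output and the input. Below, differentiating the `2x0e` output of `model.apply` with respect to its `0e + 1o` input gives `2x0e+2x1o`."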
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2x0e+2x1o [ 9.2907 -1.1702 -0.216 -0.216 -0.216 0.2902 0.2902 0.2902]"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"e3nn.grad(model.apply, 1)(params, x)"
]
},
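{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a smaller case where the answer is known in closed form: the gradient of $|x|^2$ with respect to a vector $x$ is $2x$. This sketch assumes `e3nn.grad` accepts a function returning a `0e` `IrrepsArray`, just as `model.apply` above returns `2x0e`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import e3nn_jax as e3nn\n",
"\n",
"\n",
"def squared_norm(x):\n",
"    # 1o input -> 1x0e output holding |x|^2\n",
"    return e3nn.norm(x, squared=True)\n",
"\n",
"\n",
"x = e3nn.IrrepsArray(\"1o\", jnp.array([1.0, 2.0, 3.0]))\n",
"g = e3nn.grad(squared_norm, 0)(x)\n",
"\n",
"print(g)  # expected: a 1o vector equal to 2 * x\n",
"print(jnp.allclose(g.array, 2.0 * x.array))"
]
},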
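{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Gate Activation Function\n",
"\n",
"The `e3nn.gate` function is a simple way to apply activation functions to non-scalar quantities. Activations act on scalars only, and the extra scalars become gates that multiply the non-scalar irreps. In the example below, the first `0e` goes through a normalized GELU. The second `0e` goes through a normalized sigmoid and scales the `1o` vector. The next two cells reproduce both results by hand."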
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x0e+1x1o [15.3485 1.8337 3.6675 5.5012]"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = e3nn.IrrepsArray(\"2x0e + 1o\", jnp.array([10.0, 5.0, 1.0, 2.0, 3.0]))\n",
"\n",
"e3nn.gate(x)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(15.3485, dtype=float64)"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gelu = e3nn.normalize_function(jax.nn.gelu)\n",
"gelu(jnp.array(10.0))"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([1.8337, 3.6675, 5.5012], dtype=float64)"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sigmoid = e3nn.normalize_function(jax.nn.sigmoid)\n",
"\n",
"sigmoid(jnp.array(5.0)) * jnp.array([1.0, 2.0, 3.0])"
]
},
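{
"cell_type": "markdown",
"metadata": {},
"source": [
"`e3nn.normalize_function` rescales an activation so that its second moment under a standard normal input is one: $\\langle f(z)^2 \\rangle = 1$ for $z \\sim \\mathcal{N}(0, 1)$. This is why the normalized GELU maps `10.0` to `15.3485`. A Monte Carlo sketch comparing raw and normalized second moments:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import e3nn_jax as e3nn\n",
"\n",
"z = jax.random.normal(jax.random.PRNGKey(0), (1_000_000,))\n",
"\n",
"for name, f in [(\"gelu\", jax.nn.gelu), (\"sigmoid\", jax.nn.sigmoid)]:\n",
"    g = e3nn.normalize_function(f)\n",
"    # Second moment under z ~ N(0, 1): raw function vs normalized function (close to 1)\n",
"    print(name, jnp.mean(f(z) ** 2), jnp.mean(g(z) ** 2))"
]
},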
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Radial Basis Functions\n",
"\n",
"- Bessel\n",
"- One hot functions\n",
"- polynomial envelopes\n",
"- smooth envelope"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x142195e10>"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"x = jnp.linspace(0.0, 1.1, 100)\n",
"\n",
"plt.plot(x, e3nn.soft_envelope(x), label=\"soft_envelope\")\n",
"plt.plot(x, e3nn.poly_envelope(6, 2)(x), label=\"poly_envelope(6, 2)\")\n",
"\n",
"plt.legend()\n"
]
},
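{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Normalization\n",
"\n",
"Assume each component of the **input** has unit second moment:\n",
"\n",
"$$ \\langle x_i^2 \\rangle = 1 $$\n",
"\n",
"Then the function $f$ is **component normalized** if each component of its output also has unit second moment:\n",
"\n",
"$$ \\langle f(x)_i^2 \\rangle = 1 $$\n",
"\n",
"An irrep of degree $l$ has $2l+1$ components, so the mean squared norm of a component-normalized output of degree $l$ is $\\langle |f(x)_l|^2 \\rangle = 2l+1$. That is $1, 3, 5$ for $l = 0, 1, 2$, which is what the cells below measure."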
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x0e+1x0e+1x0e [0.966 2.9739 4.906 ]"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = e3nn.normal(\"1o\", jax.random.PRNGKey(0), (10_000,))\n",
"y = e3nn.normal(\"1o\", jax.random.PRNGKey(1), (10_000,))\n",
"\n",
"e3nn.mean(e3nn.norm(e3nn.tensor_product(x, y), squared=True), axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x0e+1x0e [0.9995 4.9973]"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = e3nn.normal(\"1o\", jax.random.PRNGKey(0), (10_000,))\n",
"\n",
"e3nn.mean(e3nn.norm(e3nn.tensor_square(x), squared=True), axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1x0e+1x0e+1x0e [1. 3. 5.]"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = e3nn.normal(\"1o\", jax.random.PRNGKey(0), (10_000,))\n",
"x = x / e3nn.norm(x)\n",
"\n",
"e3nn.mean(e3nn.norm(e3nn.spherical_harmonics([0, 1, 2], x, False, normalization=\"component\"), squared=True), axis=0)"
]
},
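{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, here is a sketch with `normalization=\"norm\"`. In that case each $Y^l$ of a unit vector has norm exactly 1. Every sample, not just the average, should therefore give $|Y^l|^2 = 1$ instead of $2l+1$:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import e3nn_jax as e3nn\n",
"\n",
"x = e3nn.normal(\"1o\", jax.random.PRNGKey(0), (10_000,))\n",
"x = x / e3nn.norm(x)  # unit vectors\n",
"\n",
"sh = e3nn.spherical_harmonics([0, 1, 2], x, False, normalization=\"norm\")\n",
"sq = e3nn.norm(sh, squared=True)  # one squared norm per sample and per l\n",
"\n",
"# With \"norm\" normalization every sample has |Y^l|^2 = 1, not only the average\n",
"print(jnp.allclose(sq.array, 1.0, atol=1e-5))"
]
},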
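{
"cell_type": "markdown",
"metadata": {},
"source": [
"# The `.list` optimization\n",
"\n",
"Besides the dense `.array`, an `IrrepsArray` stores its data in `.list`, with one entry per `mul x irrep` chunk of its irreps. A chunk that is known to be zero is stored as `None` rather than as an array of zeros. Below, `e3nn.Linear(\"1e\")` receives a `1o` input. `Linear` only connects identical irreps, so there is no path to `1e` and the output chunk is `None`."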
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[None]"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import haiku as hk\n",
"\n",
"x = e3nn.IrrepsArray(\"1o\", jnp.array([1.0, 0.0, 0.0]))\n",
"\n",
"foo = hk.transform(lambda x: e3nn.Linear(\"1e\")(x))\n",
"\n",
"w = foo.init(jax.random.PRNGKey(0), x)\n",
"foo.apply(w, jax.random.PRNGKey(0), x).list"
]
},
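{
"cell_type": "markdown",
"metadata": {},
"source": [
"This sketch shows that `.array` still materializes a `None` chunk as zeros, so code that only reads `.array` sees an ordinary zero vector:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import haiku as hk\n",
"import e3nn_jax as e3nn\n",
"\n",
"x = e3nn.IrrepsArray(\"1o\", jnp.array([1.0, 0.0, 0.0]))\n",
"foo = hk.transform(lambda x: e3nn.Linear(\"1e\")(x))\n",
"w = foo.init(jax.random.PRNGKey(0), x)\n",
"out = foo.apply(w, jax.random.PRNGKey(0), x)\n",
"\n",
"print(out.irreps, out.list)  # the 1e chunk is stored as None\n",
"print(out.array)  # the dense view is an ordinary array of zeros\n",
"print(jnp.allclose(out.array, 0.0))"
]
},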
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.7 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "f26faf9d33dc8b83cd077f62f5d9010e5bc51611e479f12b96223e2da63ba699"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}