Skip to content

Instantly share code, notes, and snippets.

@blondegeek
Last active January 31, 2021 09:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save blondegeek/616af44ed17d76fa6e11392b24cf0dc1 to your computer and use it in GitHub Desktop.
Save blondegeek/616af44ed17d76fa6e11392b24cf0dc1 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A quick `e3nn` tutorial\n",
"\n",
"For more examples see [e3nn.org](https://e3nn.org) and [e3nn_tutorial](https://blondegeek.github.io/e3nn_tutorial)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create a basic network"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from e3nn import rs, o3\n",
"from e3nn.networks import GatedConvParityNetwork\n",
"\n",
"# Use double precision; the equivariance check below compares outputs with torch.allclose\n",
"torch.set_default_dtype(torch.float64)\n",
"# Fix the seed so weight initialization and the random inputs/rotations below are reproducible\n",
"SEED = 0\n",
"torch.manual_seed(SEED)\n",
"\n",
"# Define the datatypes of the inputs and outputs.\n",
"# Each Rs entry is (multiplicity, L, parity), with parity 1 (even) or -1 (odd).\n",
"N_atom_types = 3  # For example H, C, O\n",
"# Inputs are N_atom_types scalars (e.g. an atom-type encoding) and a time-averaged\n",
"# oscillating linearly polarized vector field, represented by one even L=2 irrep\n",
"Rs_in = [(N_atom_types, 0, 1), (1, 2, 1)]\n",
"Rs_out = [(1, 1, -1)]  # Predict vectors (L=1, odd parity)\n",
"\n",
"# Define maximum radius for convolution radial functions\n",
"r_max = 1.5\n",
"\n",
"model_kwargs = {\n",
"    'Rs_in': Rs_in, 'Rs_out': Rs_out, 'mul': 4, 'lmax': 2,\n",
"    'layers': 3, 'max_radius': r_max,\n",
"    'number_of_basis': 10,  # Number of basis functions used for radial function\n",
"}\n",
"\n",
"# This network has gated nonlinearities (e3nn.non_linearities.GatedBlockParity)\n",
"# and the default (point) convolutions (e3nn.point.operations.Convolution).\n",
"# It uses irreps of O(3) rather than just SO(3) so you MUST specify parity.\n",
"model = GatedConvParityNetwork(**model_kwargs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Equivariance Test"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sample a random rotation and compute its representation for each data type"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One random rotation, expressed as the angles accepted by rs.rep and o3.rot\n",
"angles = o3.rand_angles()\n",
"\n",
"# Wigner D matrices: how this rotation acts on the input and on the output features\n",
"D_in = rs.rep(Rs_in, *angles)\n",
"D_out = rs.rep(Rs_out, *angles)\n",
"\n",
"# The same rotation as a 3x3 Cartesian matrix (used to rotate the geometry)\n",
"rot = o3.rot(*angles)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Visualize this rotation's representations for the input features, output features, and geometry"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"def irrep_labels(Rs):\n",
"    \"\"\"Return one 'l,m' tick label per row of rs.rep(Rs, ...), in row order.\"\"\"\n",
"    # e3nn orders the three L=1 components as (y, z, x)\n",
"    cartesian = {-1: ' (y)', 0: ' (z)', 1: ' (x)'}\n",
"    return [f'{l},{m}' + (cartesian[m] if l == 1 else '')\n",
"            for mul, l, *_ in Rs\n",
"            for _ in range(mul)\n",
"            for m in range(-l, l + 1)]\n",
"\n",
"fig, ax = plt.subplots(1, 3, figsize=(11, 3))\n",
"b = 1  # Shared color range; entries of orthogonal rotation matrices lie in [-1, 1]\n",
"panels = [\n",
"    (D_in, '$D_{in}$', irrep_labels(Rs_in)),\n",
"    (D_out, '$D_{out}$', irrep_labels(Rs_out)),\n",
"    (rot, 'Cartesian rotation matrix', ['x', 'y', 'z']),\n",
"]\n",
"for axis, (matrix, title, labels) in zip(ax, panels):\n",
"    axis.imshow(matrix, cmap='plasma', vmin=-b, vmax=b)\n",
"    axis.set_title(title)\n",
"    # One tick per matrix row, so tick count always matches the matrix being drawn\n",
"    axis.set_yticks(range(matrix.shape[0]))\n",
"    axis.set_xticks([])\n",
"    axis.set_yticklabels(labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Check equivariance of randomly initialized model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Random input features and positions for a single example (batch of 1) with N points\n",
"N = 5\n",
"features = torch.randn(1, N, rs.dim(Rs_in))\n",
"geo = torch.randn(1, N, 3)\n",
"\n",
"# Path 1: rotate the inputs (features with D_in, geometry with rot), then apply the model\n",
"rot_features = torch.einsum('ij,zaj->zai', D_in, features)\n",
"rot_geo = torch.einsum('ij,zaj->zai', rot, geo)\n",
"rot_out = model(rot_features, rot_geo)\n",
"\n",
"# Path 2: apply the model, then rotate its outputs with D_out\n",
"out = model(features, geo)\n",
"out_rot = torch.einsum('ij,zaj->zai', D_out, out)\n",
"\n",
"# An equivariant model gives the same answer along both paths\n",
"assert torch.allclose(rot_out, out_rot), f'max deviation: {(rot_out - out_rot).abs().max().item():.3e}'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Outputs must have equal or higher symmetry than the inputs\n",
"## Create network for graph data and visualize with SphericalTensor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import e3nn.point.message_passing as mp\n",
"from torch_scatter import scatter_mean\n",
"import e3nn.point.data_helpers as dh\n",
"from e3nn.tensor import SphericalTensor\n",
"from plotly.subplots import make_subplots\n",
"\n",
"Rs_in = [(1, 0, 1)]  # A single even scalar per node\n",
"L_max = 5  # Highest output degree (inclusive)\n",
"# Output is a signal on the sphere: one irrep for each L = 0..L_max, with parity (-1)^L\n",
"Rs_out = [(1, L, (-1)**L) for L in range(L_max + 1)]\n",
"\n",
"r_max = 1.5\n",
"\n",
"model_kwargs = {\n",
"    'Rs_in': Rs_in, 'Rs_out': Rs_out, 'mul': 4, 'lmax': 3,\n",
"    'layers': 3, 'max_radius': r_max, 'number_of_basis': 10,\n",
"    'convolution': mp.Convolution,  # We use a different convolution operation for graph data.\n",
"}\n",
"\n",
"# Create three independently (randomly) initialized models\n",
"model1 = GatedConvParityNetwork(**model_kwargs)\n",
"model2 = GatedConvParityNetwork(**model_kwargs)\n",
"model3 = GatedConvParityNetwork(**model_kwargs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Regular tetrahedron: four alternating corners of the unit cube\n",
"# (every edge has length sqrt(2), which is below r_max)\n",
"pos = torch.tensor([\n",
"    [0., 0., 0.],\n",
"    [1., 1., 0.],\n",
"    [1., 0., 1.],\n",
"    [0., 1., 1.],\n",
"])\n",
"# The same scalar input (1.0) at every vertex\n",
"node_features = torch.ones(len(pos), 1)\n",
"data = dh.DataNeighbors(node_features, Rs_in, pos, r_max)\n",
"\n",
"# Shapes of the input, edge index, and relative distance vectors on each edge\n",
"tuple(t.shape for t in (data.x, data.edge_index, data.edge_attr))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run the 3 random models on the same graph\n",
"outs = tuple(m(data.x, data.edge_index, data.edge_attr) for m in (model1, model2, model3))\n",
"\n",
"# Output shape of each model (presumably [num_nodes, rs.dim(Rs_out)])\n",
"[out.shape for out in outs]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## All outputs preserve tetrahedral symmetry ($T_d$)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import plotly.graph_objects as go\n",
"\n",
"# One 3D panel per model output\n",
"rows, cols = 1, len(outs)\n",
"specs = [[{'is_3d': True} for _ in range(cols)]\n",
"         for _ in range(rows)]\n",
"fig = make_subplots(rows, cols, specs=specs)\n",
"\n",
"for i, out in enumerate(outs):\n",
"    # Sum over the first axis (presumably nodes) and draw the result as a surface\n",
"    trace = SphericalTensor(out.detach().sum(0)).plotly_surface()\n",
"    trace['showscale'] = False\n",
"    fig.add_trace(go.Surface(**trace), row=1, col=i + 1)\n",
"\n",
"fig.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example for converting Cartesian tensors"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Degrees of freedom of the elasticity tensor and its transformation matrix to the irrep basis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from e3nn import rs\n",
"from e3nn.tensor import CartesianTensor\n",
"torch.set_default_dtype(torch.float64)\n",
"\n",
"# Index symmetries of the elasticity tensor: i<->j, and the pair ij<->kl\n",
"elasticity_formula = 'ijkl=jikl=klij'\n",
"\n",
"# Only the change of basis is wanted here, so the tensor entries are placeholders\n",
"placeholder = torch.zeros(3, 3, 3, 3)\n",
"Rs, Q = CartesianTensor(placeholder, elasticity_formula).to_irrep_transformation()\n",
"\n",
"print(\"Representations: \", Rs)\n",
"print(\"Degrees of freedom: \", rs.dim(Rs))\n",
"print('Q transformation [irrep_basis, flattened_cartesian]', Q.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualizing CartesianTensors as SphericalTensors"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import plotly\n",
"import plotly.graph_objects as go\n",
"\n",
"from e3nn.tensor import SphericalTensor\n",
"\n",
"# Random symmetric 3x3 matrix\n",
"M = torch.randn(3, 3)\n",
"M = M + M.transpose(0, 1)\n",
"\n",
"# Plot the matrix entries\n",
"matrix_fig, matrix_ax = plt.subplots()\n",
"matrix_image = matrix_ax.imshow(M, cmap='plasma')\n",
"matrix_fig.colorbar(matrix_image, ax=matrix_ax)\n",
"matrix_ax.set_title('Random symmetric matrix M');"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Convert the symmetric matrix to irreps; .plot() returns surface points r and signal values f\n",
"matrix = CartesianTensor(M, formula='ij=ji').to_irrep_tensor()\n",
"r, f = SphericalTensor.from_irrep_tensor(matrix).plot()\n",
"\n",
"def surface_plot(r, f):\n",
"    \"\"\"Build a plotly surface at points r (last axis = x, y, z), colored by f.\"\"\"\n",
"    return go.Surface(\n",
"        x=r[..., 0], y=r[..., 1], z=r[..., 2],\n",
"        surfacecolor=f, showscale=False)\n",
"\n",
"# Plot SH signal\n",
"go.Figure([surface_plot(r, f)])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment