Last active
January 31, 2021 09:10
-
-
Save blondegeek/616af44ed17d76fa6e11392b24cf0dc1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# A quick `e3nn` tutorial\n", | |
"\n", | |
"For more examples see [e3nn.org](https://e3nn.org) and [e3nn_tutorial](https://blondegeek.github.io/e3nn_tutorial)." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Create a basic network" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"from e3nn import rs, o3\n", | |
"from e3nn.networks import GatedConvParityNetwork\n", | |
"torch.set_default_dtype(torch.float64)\n", | |
"\n", | |
"# Define the datatypes of the inputs and outputs.\n", | |
"# Each Rs entry is a (multiplicity, L, parity) tuple.\n", | |
"N_atom_types = 3 # For example H, C, O\n", | |
"Rs_in = [(N_atom_types, 0, 1), (1, 2, 1)] # Inputs are atom-type scalars plus one even L=2 feature encoding a time-averaged, oscillating, linearly polarized vector field\n", | |
"Rs_out = [(1, 1, -1)] # Predict vectors (one odd-parity L=1 irrep)\n", | |
"\n", | |
"# Define maximum radius for convolution radial functions\n", | |
"r_max = 1.5\n", | |
"\n", | |
"model_kwargs = {\n", | |
" 'Rs_in': Rs_in, 'Rs_out': Rs_out, 'mul': 4, 'lmax': 2, \n", | |
" 'layers': 3, 'max_radius': r_max, \n", | |
" 'number_of_basis': 10, # Number of basis functions used for radial function\n", | |
"}\n", | |
"\n", | |
"# This network has gated nonlinearities (e3nn.non_linearities.GatedBlockParity)\n", | |
"# and the default (point) convolutions (e3nn.point.operations.Convolution).\n", | |
"# It uses irreps of O(3) rather than just SO(3) so you MUST specify parity.\n", | |
"model = GatedConvParityNetwork(**model_kwargs)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Equivariance Test" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Sample random rotation angles and build the representation of that rotation for each datatype" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Sample a single random rotation, given as a tuple of angles\n", | |
"angles = o3.rand_angles()\n", | |
"\n", | |
"# Wigner D matrices of that rotation, one for each datatype\n", | |
"D_in = rs.rep(Rs_in, *angles)\n", | |
"D_out = rs.rep(Rs_out, *angles)\n", | |
"\n", | |
"# The same rotation as a plain 3x3 Cartesian matrix (acts on positions)\n", | |
"rot = o3.rot(*angles)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Show representations of this random rotation for input and output features and geometry" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"fig, ax = plt.subplots(1, 3, figsize=(11, 3))\n", | |
"b = 1  # Shared color limits so all three panels use the same scale\n", | |
"# Tick labels give (L, m) for each row of the representation\n", | |
"ax[0].imshow(D_in, cmap='plasma', vmin=-b, vmax=b)\n", | |
"ax[0].set_title(\"$D_{in}$\")\n", | |
"ax[0].set_yticks(range(rs.dim(Rs_in)))\n", | |
"ax[0].set_xticks([])\n", | |
"ax[0].set_yticklabels([\"0,0\", \"0,0\", \"0,0\", \"2,-2\", \"2,-1\", \"2,0\", \"2,1\", \"2,2\"])\n", | |
"ax[1].imshow(D_out, cmap='plasma', vmin=-b, vmax=b)\n", | |
"ax[1].set_title(\"$D_{out}$\")\n", | |
"ax[1].set_yticks(range(rs.dim(Rs_out)))\n", | |
"ax[1].set_xticks([])\n", | |
"ax[1].set_yticklabels([\"1,-1 (y)\", \"1,0 (z)\", \"1,1 (x)\"])\n", | |
"ax[2].imshow(rot, cmap='plasma', vmin=-b, vmax=b)\n", | |
"ax[2].set_title(\"Cartesian rotation matrix\")\n", | |
"# One tick per Cartesian axis; this does not depend on Rs_out\n", | |
"ax[2].set_yticks(range(rot.shape[0]))\n", | |
"ax[2].set_xticks([])\n", | |
"ax[2].set_yticklabels([\"x\", \"y\", \"z\"]);" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Check equivariance of randomly initialized model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Random input: one batch of N points with features and 3D coordinates\n", | |
"N = 5\n", | |
"features = torch.randn(1, N, rs.dim(Rs_in))  # named to avoid shadowing the builtin `input`\n", | |
"geo = torch.randn(1, N, 3)\n", | |
"\n", | |
"# Path 1: rotate the inputs, then apply the model\n", | |
"rot_features = torch.einsum('ij,zaj->zai', D_in, features)\n", | |
"rot_geo = torch.einsum('ij,zaj->zai', rot, geo)\n", | |
"rot_out = model(rot_features, rot_geo)\n", | |
"\n", | |
"# Path 2: apply the model, then rotate the outputs\n", | |
"out = model(features, geo)\n", | |
"out_rot = torch.einsum('ij,zaj->zai', D_out, out)\n", | |
"\n", | |
"# Equivariance: both paths must yield the same answer\n", | |
"assert torch.allclose(rot_out, out_rot)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Outputs must have equal or higher symmetry than the inputs\n", | |
"## Create network for graph data and visualize with SphericalTensor" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import e3nn.point.message_passing as mp\n", | |
"from torch_scatter import scatter_mean\n", | |
"import e3nn.point.data_helpers as dh\n", | |
"from e3nn.tensor import SphericalTensor\n", | |
"from plotly.subplots import make_subplots\n", | |
"\n", | |
"Rs_in = [(1, 0, 1)]  # One even scalar per node\n", | |
"L_max = 6\n", | |
"# One irrep for each L = 0, ..., L_max - 1 with parity (-1)**L:\n", | |
"# the coefficients of a signal on the sphere (visualized with SphericalTensor)\n", | |
"Rs_out = [(1, L, (-1)**L) for L in range(L_max)]\n", | |
"\n", | |
"r_max = 1.5\n", | |
"\n", | |
"model_kwargs = {\n", | |
" 'Rs_in': Rs_in, 'Rs_out': Rs_out, 'mul': 4, 'lmax': 3, \n", | |
" 'layers': 3, 'max_radius': r_max, 'number_of_basis': 10,\n", | |
" 'convolution': mp.Convolution # We use a different convolution operation for graph data.\n", | |
"}\n", | |
"\n", | |
"# Create three random models\n", | |
"model1 = GatedConvParityNetwork(**model_kwargs)\n", | |
"model2 = GatedConvParityNetwork(**model_kwargs)\n", | |
"model3 = GatedConvParityNetwork(**model_kwargs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Vertices of a tetrahedron, each carrying the same scalar input feature\n", | |
"pos = torch.tensor([\n", | |
"    [0., 0., 0.],\n", | |
"    [1., 1., 0.],\n", | |
"    [1., 0., 1.],\n", | |
"    [0., 1., 1.],\n", | |
"])\n", | |
"n_nodes = pos.shape[0]\n", | |
"data = dh.DataNeighbors(torch.ones(n_nodes, 1), Rs_in, pos, r_max)\n", | |
"\n", | |
"# Node inputs, edge index, and relative distance vectors on each edge\n", | |
"data.x.shape, data.edge_index.shape, data.edge_attr.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Run the 3 random models on the same graph\n", | |
"models = (model1, model2, model3)\n", | |
"outs = tuple(m(data.x, data.edge_index, data.edge_attr) for m in models)\n", | |
"# All models share one architecture, so every output has this shape\n", | |
"print(outs[0].shape)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## All outputs preserve tetrahedral symmetry ($T_d$)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import plotly.graph_objects as go\n", | |
"from plotly.subplots import make_subplots\n", | |
"\n", | |
"rows, cols = 1, 3\n", | |
"specs = [[{'is_3d': True} for _ in range(cols)]\n", | |
"         for _ in range(rows)]\n", | |
"fig = make_subplots(rows, cols, specs=specs)\n", | |
"\n", | |
"for i, out in enumerate(outs):\n", | |
"    # Sum the per-node outputs along axis 0 into one spherical signal per model\n", | |
"    trace = SphericalTensor(out.detach().sum(0)).plotly_surface()\n", | |
"    trace['showscale'] = False\n", | |
"    trace = go.Surface(**trace)\n", | |
"    fig.add_trace(trace, row=1, col=i+1)\n", | |
"\n", | |
"fig.show()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Example for converting Cartesian tensors" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Degrees of freedom of the elasticity tensor and its transformation matrix to the irrep basis" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"from e3nn import rs\n", | |
"from e3nn.tensor import CartesianTensor\n", | |
"torch.set_default_dtype(torch.float64)\n", | |
"\n", | |
"# The values are only a placeholder; the formula string encodes the\n", | |
"# index symmetries of the rank-4 elasticity tensor.\n", | |
"elasticity_placeholder = torch.zeros(3, 3, 3, 3)\n", | |
"elasticity = CartesianTensor(elasticity_placeholder, 'ijkl=jikl=klij')\n", | |
"Rs, Q = elasticity.to_irrep_transformation()\n", | |
"\n", | |
"print(\"Representations: \", Rs)\n", | |
"print(\"Degrees of freedom: \", rs.dim(Rs))\n", | |
"print('Q transformation [irrep_basis, flattened_cartesian]', Q.shape)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Visualizing CartesianTensors as SphericalTensors" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"import plotly\n", | |
"import plotly.graph_objects as go\n", | |
"\n", | |
"from e3nn.tensor import SphericalTensor\n", | |
"\n", | |
"# Random symmetric 3x3 matrix\n", | |
"M = torch.randn(3,3)\n", | |
"M = M + M.transpose(0, 1)\n", | |
"\n", | |
"# Plot matrix\n", | |
"plt.imshow(M, cmap='plasma')\n", | |
"plt.title('Symmetric matrix M');" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Convert the symmetric matrix to irreps and evaluate it for plotting\n", | |
"matrix = CartesianTensor(M, formula='ij=ji').to_irrep_tensor()\n", | |
"r, f = SphericalTensor.from_irrep_tensor(matrix).plot()\n", | |
"\n", | |
"# Plot SH signal\n", | |
"def surface_plot(r, f):\n", | |
"    # Surface with vertices r[..., 0], r[..., 1], r[..., 2] colored by f\n", | |
"    return go.Surface(\n", | |
"        x=r[..., 0], y=r[..., 1], z=r[..., 2],\n", | |
"        surfacecolor=f, showscale=False)\n", | |
"\n", | |
"go.Figure([surface_plot(r, f)])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment