Skip to content

Instantly share code, notes, and snippets.

@cat-state
Created September 8, 2022 19:26
Show Gist options
  • Save cat-state/6308e46f323b909825d5146afa2945a0 to your computer and use it in GitHub Desktop.
Save cat-state/6308e46f323b909825d5146afa2945a0 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from math import pi\n",
"from typing import Tuple"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
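"# Routing\n",
"\n",
"Each `SLMRouter` layer combines all incoming beams, given as `(angle, intensity)` pairs, and passes the result through a bank of filters with learnable angles."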
]
},
{
"cell_type": "code",
"execution_count": 212,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[1.6784, 2.4763, 1.1539]], grad_fn=<RepeatBackward>),\n",
" tensor([[0.3223, 2.3602, 1.2148]], grad_fn=<NormBackward3>))"
]
},
"execution_count": 212,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Input = Tuple[torch.Tensor, torch.Tensor]  # (angles, intensities)\n",
"\n",
"class SLMRouter(nn.Module):\n",
"    \"\"\"Bank of `n_filters` filters, each with a learnable angle in radians.\n",
"\n",
"    All incoming beams are summed into one 2-D field vector (intensity-weighted\n",
"    unit vectors at each beam angle). Each filter scales that field\n",
"    component-wise by (cos a, sin a) and reports its L2 norm as the transmitted\n",
"    intensity; the outgoing beam angle is the filter's own angle.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, n_filters: int):\n",
"        super().__init__()\n",
"        # Filter angles in radians, initialised uniformly in [0, pi).\n",
"        self._angles = nn.Parameter(torch.empty(n_filters).uniform_(0, pi))\n",
"\n",
"    def forward(self, x: Input) -> Input:\n",
"        \"\"\"Route a batch of beams through the filter bank.\n",
"\n",
"        Args:\n",
"            x: ``(angle, intensity)``, both of shape ``(batch, n_in)``.\n",
"\n",
"        Returns:\n",
"            ``(angles, transmitted)``, both of shape ``(batch, n_filters)``:\n",
"            the filter angles repeated per sample, and the transmitted\n",
"            intensity of each filter.\n",
"        \"\"\"\n",
"        angle, intensity = x\n",
"\n",
"        # Intensity-weighted sum of the incoming unit vectors: (batch, 2).\n",
"        in_vecs = torch.stack([angle.cos(), angle.sin()], dim=1)  # (batch, 2, n_in)\n",
"        field = (in_vecs * intensity[:, None, :]).sum(dim=2)\n",
"\n",
"        # Filter unit vectors: (1, 2, n_filters).\n",
"        self_vecs = torch.stack([self._angles.cos(), self._angles.sin()]).unsqueeze(0)\n",
"\n",
"        # NOTE(review): this scales the field component-wise and takes the norm,\n",
"        # i.e. sqrt((cos a * Fx)^2 + (sin a * Fy)^2). A polariser projection\n",
"        # (Malus's law) would instead use |cos a * Fx + sin a * Fy| -- confirm\n",
"        # which model is intended.\n",
"        transmitted = (self_vecs * field[:, :, None]).norm(dim=1)\n",
"\n",
"        return (self._angles.repeat(intensity.shape[0], 1), transmitted)\n",
"\n",
"torch.manual_seed(0)  # reproducible filter initialisation\n",
"router = SLMRouter(3)\n",
"\n",
"# One sample with three beams, all at angle 0 and unit intensity.\n",
"demo_input = (torch.zeros(1, 3), torch.ones(1, 3))\n",
"router(demo_input)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
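"# Non-linearity\n",
"\n",
"A saturating optical-limiting response, applied to beam intensities between routing layers."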
]
},
{
"cell_type": "code",
"execution_count": 213,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x1e5622f1400>]"
]
},
"execution_count": 213,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"def optical_limiting(x, gain=5.0, scale=1.8):\n",
"    \"\"\"Saturating non-linearity ``(sigmoid(gain * x) - 0.5) * scale``.\n",
"\n",
"    Odd function: maps 0 to 0 and saturates towards ``+/- scale / 2``\n",
"    (+/- 0.9 with the defaults) for large positive/negative inputs.\n",
"    \"\"\"\n",
"    return ((x * gain).sigmoid() - 0.5) * scale\n",
"\n",
"# `steps` is required in current PyTorch; 100 matches the old implicit default.\n",
"xs = torch.linspace(0, 1, steps=100)\n",
"fig, ax = plt.subplots()\n",
"ax.plot(xs.numpy(), optical_limiting(xs).numpy())\n",
"ax.set(title=\"Optical limiting\", xlabel=\"input intensity\", ylabel=\"output intensity\");"
]
},
{
"cell_type": "code",
"execution_count": 214,
"metadata": {},
"outputs": [],
"source": [
"class Act(nn.Module):\n",
"    \"\"\"Apply `activation` to the intensities of an ``(angle, intensity)`` pair.\n",
"\n",
"    Angles pass through unchanged, so this can sit between ``SLMRouter``\n",
"    layers inside ``nn.Sequential``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, activation):\n",
"        super().__init__()\n",
"        self._activation = activation\n",
"\n",
"    def forward(self, x):\n",
"        angle, intensity = x\n",
"        return (angle, self._activation(intensity))"
]
},
{
"cell_type": "code",
"execution_count": 225,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(0.6629, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000],\n",
" [1.1859],\n",
" [1.2430],\n",
" [1.1676]], grad_fn=<NormBackward3>)\n",
"tensor(0.4733, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000],\n",
" [1.2728],\n",
" [1.2726],\n",
" [0.0262]], grad_fn=<NormBackward3>)\n",
"tensor(0.4276, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000e+00],\n",
" [1.7377e+00],\n",
" [1.7377e+00],\n",
" [3.2093e-05]], grad_fn=<NormBackward3>)\n",
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000e+00],\n",
" [1.7579e+00],\n",
" [1.7579e+00],\n",
" [6.3294e-06]], grad_fn=<NormBackward3>)\n",
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000e+00],\n",
" [1.7580e+00],\n",
" [1.7580e+00],\n",
" [1.0729e-07]], grad_fn=<NormBackward3>)\n",
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000],\n",
" [1.7580],\n",
" [1.7580],\n",
" [0.0000]], grad_fn=<NormBackward3>)\n",
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000],\n",
" [1.7580],\n",
" [1.7580],\n",
" [0.0000]], grad_fn=<NormBackward3>)\n",
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000],\n",
" [1.7580],\n",
" [1.7580],\n",
" [0.0000]], grad_fn=<NormBackward3>)\n",
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000],\n",
" [1.7580],\n",
" [1.7580],\n",
" [0.0000]], grad_fn=<NormBackward3>)\n",
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000],\n",
" [1.7580],\n",
" [1.7580],\n",
" [0.0000]], grad_fn=<NormBackward3>)\n"
]
}
],
"source": [
"# XOR truth table: input pair -> target bit.\n",
"xor = {\n",
"    (0, 0): 0,\n",
"    (0, 1): 1,\n",
"    (1, 0): 1,\n",
"    (1, 1): 0,\n",
"}\n",
"\n",
"# Build the batch from the table so inputs and targets cannot drift apart.\n",
"batch_x = torch.tensor(list(xor.keys())).float()  # (4, 2)\n",
"batch_y = torch.tensor(list(xor.values()))        # (4,)\n",
"\n",
"torch.manual_seed(0)  # reproducible initialisation of all learnable angles\n",
"\n",
"net = nn.Sequential(\n",
"    SLMRouter(2),\n",
"    Act(optical_limiting),\n",
"    SLMRouter(2),\n",
"    Act(optical_limiting),\n",
"    SLMRouter(1),\n",
")\n",
"\n",
"# One learnable input angle per input beam, shared by every sample.\n",
"# Kept as a (1, 2) parameter and expanded each step: expanding before wrapping\n",
"# in nn.Parameter makes all rows alias the same storage, which the optimizer's\n",
"# in-place update cannot handle correctly.\n",
"start_angles = nn.Parameter(torch.empty(1, 2).uniform_(0, pi))\n",
"\n",
"# The output intensity is a norm (>= 0); used directly as a logit it can never\n",
"# push P(class 0) below 0.5, so the loss plateaus. A learnable readout\n",
"# threshold (initialised at 0, i.e. the original behaviour) removes that floor.\n",
"threshold = nn.Parameter(torch.zeros(1))\n",
"\n",
"optim = torch.optim.Adam([start_angles, threshold, *net.parameters()], lr=0.01)\n",
"\n",
"for i in range(1000):\n",
"    optim.zero_grad()\n",
"\n",
"    inputs = (start_angles.expand(batch_x.shape[0], -1), batch_x)\n",
"    _, out_intensity = net(inputs)\n",
"    logits = out_intensity.squeeze(-1) - threshold\n",
"    loss = F.binary_cross_entropy_with_logits(logits, batch_y.float())\n",
"    loss.backward()\n",
"    optim.step()\n",
"\n",
"    if i % 100 == 0:\n",
"        print(f\"step {i:4d}  loss {loss.item():.4f}\")\n",
"\n",
"print(\"targets:    \", batch_y.tolist())\n",
"print(\"intensities:\", out_intensity.detach().squeeze(-1).tolist())\n",
"print(\"threshold:  \", threshold.item())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment