Skip to content

Instantly share code, notes, and snippets.

@cat-state
Created September 8, 2022 19:26
Show Gist options
  • Save cat-state/6308e46f323b909825d5146afa2945a0 to your computer and use it in GitHub Desktop.
Save cat-state/6308e46f323b909825d5146afa2945a0 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from math import pi\n",
"from typing import Tuple"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
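"# Routing\n",
"\n",
"Each `SLMRouter` layer carries one learnable angle per output filter. Signals travel between layers as `(angle, intensity)` pairs: a layer combines the incoming pairs into one transmitted intensity per filter and emits its own filter angles as the outgoing angles."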
]
},
{
"cell_type": "code",
"execution_count": 212,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[1.6784, 2.4763, 1.1539]], grad_fn=<RepeatBackward>),\n",
" tensor([[0.3223, 2.3602, 1.2148]], grad_fn=<NormBackward3>))"
]
},
"execution_count": 212,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Input = Tuple[torch.Tensor, torch.Tensor]\n",
"\n",
"\n",
"class SLMRouter(nn.Module):\n",
"    \"\"\"Layer of learnable-angle filters acting on (angle, intensity) inputs.\n",
"\n",
"    Each of the ``n_filters`` outputs owns a learnable angle a_f. Per sample, the\n",
"    incoming unit vectors (cos theta_i, sin theta_i) are weighted by their\n",
"    intensities and summed into one 2-vector; filter f scales that vector\n",
"    component-wise by (cos a_f, sin a_f) and transmits its Euclidean norm.\n",
"\n",
"    Args:\n",
"        n_filters: number of output filters (learnable angles).\n",
"        max_init_angle: angles are initialised uniformly in [0, max_init_angle).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, n_filters: int, max_init_angle: float = pi):\n",
"        super().__init__()\n",
"        self._angles = nn.Parameter(torch.empty(n_filters).uniform_(0, max_init_angle))\n",
"\n",
"    def forward(self, x: Input) -> Input:\n",
"        \"\"\"Route a batch of ``(angle, intensity)`` pairs through the filters.\n",
"\n",
"        Args:\n",
"            x: ``(angle, intensity)``, both of shape (batch, n_in).\n",
"\n",
"        Returns:\n",
"            ``(angles, transmitted)``, both of shape (batch, n_filters): this\n",
"            layer's filter angles repeated per sample, and the transmitted\n",
"            intensity of each filter.\n",
"        \"\"\"\n",
"        angle, intensity = x\n",
"\n",
"        # (1, 2, n_filters): unit vector of each filter angle.\n",
"        filter_vecs = torch.stack([self._angles.cos(), self._angles.sin()]).unsqueeze(0)\n",
"        # (batch, 2, n_in): unit vector of each incoming angle.\n",
"        in_vecs = torch.stack([angle.cos(), angle.sin()], dim=1)\n",
"\n",
"        # (batch, 2): intensity-weighted sum of the incoming unit vectors.\n",
"        # Factoring this sum out before scaling by the filter vectors avoids\n",
"        # materialising a (batch, 2, n_filters, n_in) intermediate tensor.\n",
"        weighted = (in_vecs * intensity[:, None, :]).sum(dim=2)\n",
"        # (batch, n_filters): scale by each filter vector, then take the 2-norm.\n",
"        # NOTE(review): this is a component-wise product, not a dot product\n",
"        # (cosine of the angle difference); confirm that is the intended model.\n",
"        transmitted = (filter_vecs * weighted[:, :, None]).norm(dim=1)\n",
"\n",
"        return self._angles.repeat(intensity.shape[0], 1), transmitted\n",
"\n",
"\n",
"router_demo = SLMRouter(3)\n",
"demo_input = (torch.tensor([[0.0, 0.0, 0.0]]), torch.ones(3)[None, :])\n",
"router_demo(demo_input)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
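"# Non-linearity\n",
"\n",
"`optical_limiting` is a saturating response that the `Act` wrapper applies to the intensity channel between routing layers; the angle channel passes through unchanged."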
]
},
{
"cell_type": "code",
"execution_count": 213,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x1e5622f1400>]"
]
},
"execution_count": 213,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAH5JJREFUeJzt3Xl81dWd//HXJ/u+h7CEXZBNUAibjrbVuqCtOqNtXQouFepvqu2vHa3a2tqOXbSddrppKQqtrVUq1YdLQS0dNxwFCQoBwhbWhC0rWclyc8/8kbSNiOYCN/ne5f18PO4jdzm593MIvHM43/M9X3POISIikSXG6wJERCT4FO4iIhFI4S4iEoEU7iIiEUjhLiISgRTuIiIRSOEuIhKBFO4iIhFI4S4iEoHivPrgvLw8N2LECK8+XkQkLK1bt67aOZffWzvPwn3EiBEUFxd79fEiImHJzPYG0k7TMiIiEUjhLiISgRTuIiIRSOEuIhKBFO4iIhFI4S4iEoEU7iIiEcizde4iIpGko9NPU6uP5nYfzW2dNLf7aPn71+7nWtp9tLR3Mm14NueO6fU8pFOicBeRqOfr9NPQ6qP+aAf1RztoONpBQ2sHDUd9NLZ23W9s9XXfuu43tXXfWn00tvlo9/kD/rxbPzZa4S4iciJaOzqpa2mnpqmd2uZ26lq6vza3U9fSQV1LO0daOjhytOtrfUsHjW2+j3zP2BgjPSmO9KQ40hLjSU+MY2BGEmlJcaQmxpGeGEdaYtf9tMQ4UhJj/3k/IZbUhK7nUhLiSI6PJTbG+vzPQeEuIiGvzddJVWMblY1tVPW8NbVR3dhGdVMbNc3t1Da1f2hQm0FmcjzZKQlkpcSTn5bI2AHpZKbEk5n8wVtGcjwZSfGkJ3UFtFnfB3IwKdxFxFONrR0cONLKgfqjHK5v5WB9K4fqWznU0Mrh7ltdS8cHvs8MclISyE1LIC8tkSmFWeSkJpCbmkBuWmLX/bQEslMSyElNIDM5vl9GzKFC4S4ifcY5x5GWDsrrWiivPUpFXQsVdUfZf+QoB450fW1sff9I2wzy0hIZmJFEYXYK04ZnU5CRxID0RAZkJDIgPYn89ERyUxOIi9WCvw+jcBeRU+L3Ow43trK7upm9NS3sqWlmb3ULe2tbqKht+cA0SWZyPEOykinMTmHmyBwGZyUzKCuZwZlJDMpKZkB6IvEK7VOmcBeRgBxt72RnVVPXrbKJndXN7KxsYk9NM60d/1wpkhAbQ2FOMsNzUpgxIpuhOSldt+wUCnOSyUiK97AX0UPhLiLv0+brpKyyie2HG9l2qIkdhxvZXtlIRd1RnOtqE2MwNCeFUXmpnD06j5H5qYzKS2V4bgqDMpOjam47VCncRaJYTVMbmw80UHqwgS0HGyg90MCu6mY6/V0pHh9rjMpLY0phFldPHcqYgjROG5DG8NwUEuNiPa5ePorCXSRK1DS1UbK/no0V9ZRU1LP5QD0H61v/8frgzCTGD8rg4okDOX1gOuMGpjMiL1Xz32FK4S4Sgdp8nWza38B7++pYX36E9eVHqKg7CnStRhmZl8qMkTlMGpzJxMEZTBicQVZKgsdVSzAp3EUiQG1zO+v21lG8p5a1e2rZtL+B9s6ug5xDspKZMjSTebOHM7kwi4mDM0jXQc2Ip3AXCUOVja2s2VXLmt01rNlVy47KJqBrpcoZhZnceM4Ipg7LYuqwbAZkJHlcrXhB4S4SBhpaO1i9s4a3dtbwv2XV/wjz1IRYikbkcOVZQ5gxMoczhmSSFK8DnaJwFwlJnX5HScURXt9exRvbq9hQUU+n35EcH8v0kTlcNa2Q2aNymTg4Q2dpynEp3EVCxJGWdl7fXsX/bKlk1Y4q6lo6MIPJhVn8v4+N5l/G5DF1WDYJcQpz6V1A4W5mlwA/B2KBR51zDxzzeibwODCs+z3/yzn32yDXKhJx9lQ3s7L0MCu3HKZ4Ty
1+B7mpCXzi9AF87PR8zhuTT3aqVrHIies13M0sFngIuBCoANaa2fPOudIezb4ElDrnPm1m+cA2M/ujc669T6oWCVPOObYcbOSlTQd5afMhth/umjsfPyiDL33iNM4fN4AphVnE6AxPOUWBjNxnAGXOuV0AZrYUuALoGe4OSLeuDY/TgFrgo3e/F4kSzjm2HmrkLyUHWF5ykD01LcQYTB+Rw7c/NYELJxQwNCfF6zIlwgQS7kOA8h6PK4CZx7T5FfA8cABIBz7nnAv8mlMiEWhvTTPPrT/Ac+v3s7OqmRiDs0fnseC80Vw0sYC8tESvS5QIFki4H+//h+6YxxcD64HzgdHASjNb5ZxreN8bmS0AFgAMGzbsxKsVCXFHWtp5oeQgz7xbwXv7jgAwY2QON54zkjmTBirQpd8EEu4VwNAejwvpGqH3dBPwgHPOAWVmthsYB7zTs5FzbhGwCKCoqOjYXxAiYanT73hjexXL1pXzt9JK2jv9nF6Qzt1zxnH5lMEMzkr2ukSJQoGE+1pgjJmNBPYD1wDXHdNmH3ABsMrMCoDTgV3BLFQk1JTXtvBUcTnLiis41NBKTmoC188axlVTC5k4OCPsrrkpkaXXcHfO+czsNuBlupZCLnHObTazW7tfXwjcD/zOzDbSNY1zl3Ouug/rFvFEp9/x6tZK/rhmL69tr8KA88bmc9+nJ3DB+AKtQZeQEdA6d+fcCmDFMc8t7HH/AHBRcEsTCR11ze0sXVvO46v3sv/IUQakJ3L7J07jczOGMUTTLhKCdIaqyEfYdqiRJW/u5tn1+2nz+Zk9Kpd7LxvPJycUaJ9zCWkKd5FjOOd4Y0c1j67axaod1STFx3DVtEJuPHsEYwvSvS5PJCAKd5FuHZ1+lpccZOHrO9l6qJEB6YncefHpXDdjmLYAkLCjcJeo19rRybJ1FSx8bSf7jxzltAFp/PjqyVxx5hAdIJWwpXCXqNXa0ckf1+zjN6/vpLKxjTOHZvGdyydywbgB2ttFwp7CXaLO30N94es7qWpsY9aoHH72uTOZPTpXa9MlYijcJWq0+/w8VVzOL1/ZweGGNmaPyuVX157FzFG5XpcmEnQKd4l4fr/j+Q0H+OnK7eyrbWH6iGx+fs1ZzFKoSwRTuEtEW7Wjih+u2ErpwQYmDMrgtzdN5+Nj8zX9IhFP4S4RaduhRr63vJRVO6opzE7m59ecyacnD9aBUokaCneJKDVNbfx05XaefGcf6Unx3HvZeObOHk5iXKzXpYn0K4W7RARfp58/rN7LT1dup6W9k3mzR/CVC8bo5COJWgp3CXurd9Vw33Ob2Xa4kXPH5PHtT01gjLYJkCincJewVd3Uxg+Wb+GZ9/YzJCuZhZ+fxsUTC3SwVASFu4Qhv9+xdG05D7y4haMdndx+/mn8+8dPIzlB8+oif6dwl7BSVtnEPc+UsHZPHbNG5fC9K8/gtAFpXpclEnIU7hIWOjr9LHxtJ798pYzkhFh+dPVkPjOtUFMwIh9C4S4hr/RAA3cs20DpwQY+NXkQ9316IvnpiV6XJRLSFO4Ssjo6/Tz86k5++coOslIS+M3caVw8caDXZYmEBYW7hKSyyia++qf1bNxfz+VTBvPdyydqzbrICVC4S0jx+x2/f3sPP3xxKykJsfz6+qnMOWOQ12WJhB2Fu4SMqsY27li2gde3V/Hx0/P50VWTGZCR5HVZImFJ4S4h4bVtldyxbAONrT7uv2Iin581XCthRE6Bwl081e7z8+OXt/LIqt2MG5jOE/NnMVZbB4icMoW7eKairoXbnniP9eVHmDd7ON+4dDxJ8TrLVCQYFO7iib+VHuZrT63HOXj4+qlcqoOmIkGlcJd+1el3/OSv23j4tZ1MGpLBQ9dNZXhuqtdliUQchbv0m5qmNr689D3+t6yGa2cM5b5PT9Q0jEgfUbhLv9hYUc8X/1BMdXM7P7p6Mp8tGup1SSIRTeEufe7Z9/Zz19Ml5KYm8PStZ3NGYabXJYlEPI
W79JlOv+PBl7ay6I1dzByZw0PXTyUvTRt+ifQHhbv0icbWDr6ydD2vbK1k3uzhfOtTE4iPjfG6LJGooXCXoCuvbeELj61lZ1Uz9185ibmzhntdkkjUUbhLUL27r475jxXT0ennsZtm8C9j8rwuSSQqKdwlaF7ceJD//6f1FGQk8dubpjM6X5e/E/GKwl1OmXOOxW/u5vsrtnDW0CwemVdErg6cingqoCNcZnaJmW0zszIzu/tD2nzczNab2WYzez24ZUqo8vsd9/9lC99bvoVLJw3iifmzFOwiIaDXkbuZxQIPARcCFcBaM3veOVfao00W8DBwiXNun5kN6KuCJXS0+Tr52lMbWF5ykJvPGcm9l40nJkbb9IqEgkCmZWYAZc65XQBmthS4Aijt0eY64Bnn3D4A51xlsAuV0NLY2sGC36/j7V01fOPSccw/d5T2XxcJIYFMywwByns8ruh+rqexQLaZvWZm68xs3vHeyMwWmFmxmRVXVVWdXMXiuZqmNq57ZA1r99Ty35+bwoLzRivYRUJMICP34/2rdcd5n2nABUAy8LaZrXbObX/fNzm3CFgEUFRUdOx7SBjYf+QocxevYX/dURbNm8b54wq8LklEjiOQcK8Aeu7yVAgcOE6baudcM9BsZm8AU4DtSMTYVdXE5x9dQ2Obj8dvmcn0ETlelyQiHyKQaZm1wBgzG2lmCcA1wPPHtHkOONfM4swsBZgJbAluqeKlbYca+exvVtPm87N0wSwFu0iI63Xk7pzzmdltwMtALLDEObfZzG7tfn2hc26Lmb0ElAB+4FHn3Ka+LFz6z8aKeuYuWUNiXAx/vGU2pw3QyUkioc6c82bqu6ioyBUXF3vy2RK4d/fVccPid8hMieeJW2YxLDfF65JEopqZrXPOFfXWTmeoyodat7eWG5asJS8tgSfmz2JwVrLXJYlIgLQHqxzX2j21zFv8DgPSE1m6YLaCXSTMaOQuH7B2Ty03LHmHgZlJLJ0/iwEZSV6XJCInSCN3eZ91e+u4UcEuEvYU7vIP68uPcOOSd8hPT+RJBbtIWFO4CwCbD9Qzb/EaslMTeHLBLAoU7CJhTeEulFU2MnfxO6QnxfPE/JkMytTBU5Fwp3CPcntrmrnukTXExhh/vGUmhdlaxy4SCRTuUexg/VGuf3QN7Z1+Hv/CTEbkpXpdkogEicI9StU1tzN38Tscaeng9zfP4PSB6V6XJCJBpHXuUai5zceNv1vLvtoWHrtpBpMLs7wuSUSCTCP3KNPm6+TWx9exseIIv7r2LGaPzvW6JBHpAxq5RxG/3/EfT21g1Y5qfnz1ZC6aONDrkkSkj2jkHiWcc3xv+Rb+UnKQu+eM4zNFQ3v/JhEJWwr3KPHIql0s+d/d3HTOCL543iivyxGRPqZwjwLPrd/PD1Zs5bLJg/jWZRN0MWuRKKBwj3Bv76zhjmUbmDkyh59+dgoxMQp2kWigcI9gZZWNfPEPxQzPTWXR3CIS42K9LklE+onCPUJVNrZyw5K1JMTF8tsbp5OZEu91SSLSjxTuEehoeyfzHyumtrmdJTcWMTRH+8WIRButc48wfr/jP5atp2R/PQs/P01nn4pEKY3cI8xPVm5jxcZD3DNnHBfrJCWRqKVwjyBPr6vgoVd3cs30ocw/V2vZRaKZwj1CrNtbyz3PbGT2qFzuv3KS1rKLRDmFewTYf+QoX/zDOgZlJfHrz08lPlY/VpFopwOqYa6l3cf8x4pp6/CzdEERWSkJXpckIiFA4R7GnHPcuayELYcaWHLDdE4boAtuiEgX/f89jD382k6WbzzIXZeM4xPjBnhdjoiEEIV7mHp1ayX/9ddtXD5lsHZ5FJEPULiHod3VzXx56XuMH5jBg1dN1soYEfkAhXuYaW7zseD3xcTHxrBo3jSSE7QZmIh8kMI9jDjn+PqfS9hZ1cSvrj2LwmztGSMix6dwDyOPrNrF8o0H+fol4zj7tDyvyxGREKZwDxNvlVXzwItbmTNpoA
6gikivFO5h4FB9K7c/+R4j81L58Wem6ACqiPQqoHA3s0vMbJuZlZnZ3R/RbrqZdZrZ1cErMbp1dPq57Yl3OdrRyW/mTiMtUeediUjveg13M4sFHgLmABOAa81swoe0exB4OdhFRrMHX9xK8d46Hrhqss5AFZGABTJynwGUOed2OefagaXAFcdpdzvwNFAZxPqi2kubDvLom7uZN3s4l08Z7HU5IhJGAgn3IUB5j8cV3c/9g5kNAf4VWBi80qLb3ppm7lxWwpShWXzzsvFelyMiYSaQcD/e0Tt3zOOfAXc55zo/8o3MFphZsZkVV1VVBVpj1GnzdfKlJ97FDB667iwS43SikoicmECOzlUAQ3s8LgQOHNOmCFjavYojD7jUzHzOuWd7NnLOLQIWARQVFR37C0K6fX/5Fjbtb+CReUU6UUlETkog4b4WGGNmI4H9wDXAdT0bOOdG/v2+mf0O+MuxwS6BWV5ykN+/vZf5547kwgkFXpcjImGq13B3zvnM7Da6VsHEAkucc5vN7Nbu1zXPHiT7alq4++kSzhqWxdcvGed1OSISxgJaNO2cWwGsOOa544a6c+7GUy8r+rT7/Nz+ZNc8+y+uOUuXyhORU6IzYkLEf/11Gxsq6vn19VMZmqN5dhE5NRoehoBXt1Wy6I1dfH7WMOacMcjrckQkAijcPVbZ0ModT21g3MB07r3sAyf+ioicFE3LeMjvd/zHsg00t/tYeu0skuK1nl1EgkMjdw8tfnM3q3ZU861PTWBMgfaNEZHgUbh7ZNP+en708lYumlDAdTOGeV2OiEQYhbsHWtp9fPnJ98hNTdQFrkWkT2jO3QP3/2ULu2ua+eMtM8lOTfC6HBGJQBq597OVpYd58p19LDh3FGeP1nVQRaRvKNz7UWVjK3c9XcKEQRl87aKxXpcjIhFM4d5PnHPcuayE5jYfP7/mTG3jKyJ9SuHeTx5fvZfXt1fxjUvHa9mjiPQ5hXs/2FXVxPdXbOG8sfnMmz3c63JEJAoo3PuYr9PPV5/aQGJcLD++WsseRaR/aClkH3vo1Z1sKD/CL689i4KMJK/LEZEooZF7HyqpOMIvXtnBFWcO5tNTBntdjohEEYV7H2nt6ORrT20gPy2R/7x8ktfliEiU0bRMH/nJX7dRVtnEYzfPIDMl3utyRCTKaOTeB97ZXcujb+7mupnD+NjYfK/LEZEopHAPsuY2H3cs20BhdjLfvHS81+WISJTStEyQ/fDFLZTXtbB0/ixSE/XHKyLe0Mg9iN4qq+bx1fu4+ZyRzByV63U5IhLFFO5B0tTm484/lzAqL5U7Ljrd63JEJMpp3iBIfrBiCwfqj/LnW2eTnKBNwUTEWxq5B8GbO6p5Ys0+5p87imnDc7wuR0RE4X6qmtp83PV0CaPyU/nahdqjXURCg6ZlTtEDL/5zOiYpXtMxIhIaNHI/BW/t/OfqGE3HiEgoUbifpJZ2H3c/vZERuSlaHSMiIUfTMifpxy9vY19tC39aMEurY0Qk5GjkfhLW7a3ld2/tYd7s4TpZSURCksL9BLV2dPL1P5cwODOZr18yzutyRESOS9MyJ+iXr+xgZ1Uzj908gzTtHSMiIUoj9xOw+UA9C1/fxVVTC7WVr4iENIV7gHydfu56uoTslAS+9Slt5SsioU3zCgFa/OZuNu1v4OHrp5KVkuB1OSIiHymgkbuZXWJm28yszMzuPs7r15tZSfftLTObEvxSvbOnupmfrtzORRMKmDNpoNfliIj0qtdwN7NY4CFgDjABuNbMJhzTbDfwMefcZOB+YFGwC/WKc457ntlIQmwM9185CTPzuiQRkV4FMnKfAZQ553Y559qBpcAVPRs4595yztV1P1wNFAa3TO88VVzO27tquOfS8RRkJHldjohIQAIJ9yFAeY/HFd3PfZgvAC8e7wUzW2BmxWZWXFVVFXiVHqlsbOX7y7cwc2QO10wf6nU5IiIBCyTcjzcP4Y7b0OwTdIX7Xcd73Tm3yDlX5J
wrys8P/aWE332+lFafnx/+2xnExGg6RkTCRyDhXgH0HLYWAgeObWRmk4FHgSucczXBKc87fys9zPKNB/ny+acxKj/N63JERE5IIOG+FhhjZiPNLAG4Bni+ZwMzGwY8A8x1zm0Pfpn9q7G1g289t4nTC9JZcN5or8sRETlhva5zd875zOw24GUgFljinNtsZrd2v74Q+DaQCzzcvZrE55wr6ruy+9ZP/rqdQw2tPHT9VBLidJ6XiISfgE5ics6tAFYc89zCHvdvAW4JbmneeG9fHY+9vYe5s4YzdVi21+WIiJwUDUt76Oj0c88zGylIT+LOi3UBDhEJX9p+oIfFb+5m66FGfjN3GulJ8V6XIyJy0jRy77avpoWf/a1ri4GLJ2qLAREJbwp3urYYuPe5TcTFxPDdKyZ6XY6IyClTuAMvlBzkje1V3HHRWAZlJntdjojIKYv6cK9v6eA/XyhlcmEmc2eP8LocEZGgiPoDqg++vJXa5jZ+d9N0YrXFgIhEiKgeua/bW8sTa/Zx8zkjmTQk0+tyRESCJmrDvaPTzzee2cTgzCS+euFYr8sREQmqqJ2WWfzmbrYdbmTR3GmkJkbtH4OIRKioHLmX13atab9wQgEXaU27iESgqAt35xz3Pb+ZGDO+c7nWtItIZIq6cH958yFe2VrJVz85liFZWtMuIpEpqsK9qc3Hd54vZfygDG46Z4TX5YiI9JmoCvefrdzO4cZWvnflJOJio6rrIhJloibhSg808Nu39nDN9GFMG6592kUkskVFuPv9jm8+u5Gs5HjuukT7tItI5IuKcF+6tpz39h3hm5eNJyslwetyRET6XMSHe3VTGw++tJVZo3L417OGeF2OiEi/iPhw/8GKLbS0+/jelZPovni3iEjEi+hwX72rhmfe3c/8c0dx2oB0r8sREek3ERvu7T4/9z67icLsZG4/f4zX5YiI9KuI3THr0Td3UVbZxOIbikhOiPW6HBGRfhWRI/eKuhZ+8T87uGhCAReML/C6HBGRfheR4f7dF0oxjPu0MZiIRKmIC/e/lR5mZelhvvLJMdoYTESiVkSF+9H2Tr7zwmbGDEjj5nNGel2OiIhnIuqA6q9e3UFF3VH+tGAWCXER9XtLROSEREwCllU2seiNXfzb1CHMHJXrdTkiIp6KiHB3zvHt5zaRHB/LPXPGe12OiIjnIiLcXyg5yFs7a7jzknHkpyd6XY6IiOfCPtwbWju4/y+lTC7M5LoZw7wuR0QkJIT9AdX/Xrmd6qY2Ft9QRGyMNgYTEYEwH7mXHmjgsbf2cP3MYUwuzPK6HBGRkBG24e73O+59diPZKQncedE4r8sREQkpAYW7mV1iZtvMrMzM7j7O62Zmv+h+vcTMpga/1Pdbtq6cd/cd4Z5Lx5OZEt/XHyciElZ6DXcziwUeAuYAE4BrzWzCMc3mAGO6bwuAXwe5zvepa27ngRe3MmNEDldN1dWVRESOFcjIfQZQ5pzb5ZxrB5YCVxzT5grg967LaiDLzAYFudZ/+NHLW2lo9XG/rq4kInJcgYT7EKC8x+OK7udOtE1QvLuvjiffKefmc0Zw+kBdXUlE5HgCCffjDY3dSbTBzBaYWbGZFVdVVQVS3wfEmHHumDy+8smxJ/X9IiLRIJBwrwCG9nhcCBw4iTY45xY554qcc0X5+fknWisAZw7N4g9fmElaYtgv0RcR6TOBhPtaYIyZjTSzBOAa4Plj2jwPzOteNTMLqHfOHQxyrSIiEqBeh7/OOZ+Z3Qa8DMQCS5xzm83s1u7XFwIrgEuBMqAFuKnvShYRkd4ENLfhnFtBV4D3fG5hj/sO+FJwSxMRkZMVtmeoiojIh1O4i4hEIIW7iEgEUriLiEQghbuISASyroUuHnywWRWw9yS/PQ+oDmI54UB9jg7qc3Q4lT4Pd871ehaoZ+F+Ksys2DlX5HUd/Ul9jg7qc3Tojz5rWkZEJAIp3EVEIlC4hvsirwvwgPocHdTn6NDnfQ7LOXcREf
lo4TpyFxGRjxDS4R6KF+buawH0+fruvpaY2VtmNsWLOoOptz73aDfdzDrN7Or+rK8vBNJnM/u4ma03s81m9np/1xhsAfzdzjSzF8xsQ3efw3p3WTNbYmaVZrbpQ17v2/xyzoXkja7thXcCo4AEYAMw4Zg2lwIv0nUlqFnAGq/r7oc+nw1kd9+fEw197tHuFbp2J73a67r74eecBZQCw7ofD/C67n7o8zeAB7vv5wO1QILXtZ9Cn88DpgKbPuT1Ps2vUB65h9yFuftBr312zr3lnKvrfriarqtehbNAfs4AtwNPA5X9WVwfCaTP1wHPOOf2ATjnwr3fgfTZAenWddX7NLrC3de/ZQaPc+4NuvrwYfo0v0I53EPqwtz95ET78wW6fvOHs177bGZDgH8FFhIZAvk5jwWyzew1M1tnZvP6rbq+EUiffwWMp+sSnRuBrzjn/P1Tnif6NL9C+UKkQbswdxgJuD9m9gm6wv1f+rSivhdIn38G3OWc6+wa1IW9QPocB0wDLgCSgbfNbLVzbntfF9dHAunzxcB64HxgNLDSzFY55xr6ujiP9Gl+hXK4B+3C3GEkoP6Y2WTgUWCOc66mn2rrK4H0uQhY2h3secClZuZzzj3bPyUGXaB/t6udc81As5m9AUwBwjXcA+nzTcADrmtCuszMdgPjgHf6p8R+16f5FcrTMtF4Ye5e+2xmw4BngLlhPIrrqdc+O+dGOudGOOdGAH8G/j2Mgx0C+7v9HHCumcWZWQowE9jSz3UGUyB93kfX/1QwswLgdGBXv1bZv/o0v0J25O6i8MLcAfb520Au8HD3SNbnwnjTpQD7HFEC6bNzbouZvQSUAH7gUefccZfUhYMAf873A78zs410TVnc5ZwL290izexJ4ONAnplVAPcB8dA/+aUzVEVEIlAoT8uIiMhJUriLiEQghbuISARSuIuIRCCFu4hIBFK4i4hEIIW7iEgEUriLiESg/wNhrl3YwrAXzwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"def optical_limiting(x: torch.Tensor, gain: float = 5.0, scale: float = 1.8) -> torch.Tensor:\n",
"    \"\"\"Saturating non-linearity ``(sigmoid(gain * x) - 0.5) * scale``.\n",
"\n",
"    Maps 0 to 0; for positive ``gain`` and ``scale`` it increases monotonically\n",
"    and approaches ``scale / 2`` (0.9 with the defaults) as x grows.\n",
"    \"\"\"\n",
"    return ((x * gain).sigmoid() - 0.5) * scale\n",
"\n",
"\n",
"# `steps` is given explicitly: newer PyTorch releases require it (it used to default to 100).\n",
"xs = torch.linspace(0, 1, steps=100)\n",
"fig, ax = plt.subplots()\n",
"ax.plot(xs.numpy(), optical_limiting(xs).numpy())\n",
"ax.set(title=\"Optical limiting non-linearity\", xlabel=\"input intensity\", ylabel=\"output intensity\");"
]
},
{
"cell_type": "code",
"execution_count": 214,
"metadata": {},
"outputs": [],
"source": [
"class Act(nn.Module):\n",
"    \"\"\"Apply ``activation`` to the intensity of an ``(angle, intensity)`` pair.\n",
"\n",
"    The angle is passed through untouched, so this wrapper can sit between\n",
"    ``SLMRouter`` layers inside ``nn.Sequential``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, activation):\n",
"        super().__init__()\n",
"        self._activation = activation\n",
"\n",
"    def forward(self, x):\n",
"        angles, intensities = x\n",
"        activated = self._activation(intensities)\n",
"        return angles, activated"
]
},
{
"cell_type": "code",
"execution_count": 225,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(0.6629, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000],\n",
" [1.1859],\n",
" [1.2430],\n",
" [1.1676]], grad_fn=<NormBackward3>)\n",
"tensor(0.4733, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000],\n",
" [1.2728],\n",
" [1.2726],\n",
" [0.0262]], grad_fn=<NormBackward3>)\n",
"tensor(0.4276, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000e+00],\n",
" [1.7377e+00],\n",
" [1.7377e+00],\n",
" [3.2093e-05]], grad_fn=<NormBackward3>)\n",
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000e+00],\n",
" [1.7579e+00],\n",
" [1.7579e+00],\n",
" [6.3294e-06]], grad_fn=<NormBackward3>)\n",
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000e+00],\n",
" [1.7580e+00],\n",
" [1.7580e+00],\n",
" [1.0729e-07]], grad_fn=<NormBackward3>)\n",
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000],\n",
" [1.7580],\n",
" [1.7580],\n",
" [0.0000]], grad_fn=<NormBackward3>)\n",
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000],\n",
" [1.7580],\n",
" [1.7580],\n",
" [0.0000]], grad_fn=<NormBackward3>)\n",
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000],\n",
" [1.7580],\n",
" [1.7580],\n",
" [0.0000]], grad_fn=<NormBackward3>)\n",
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000],\n",
" [1.7580],\n",
" [1.7580],\n",
" [0.0000]], grad_fn=<NormBackward3>)\n",
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n",
"tensor([0, 1, 1, 0])\n",
"tensor([[0.0000],\n",
" [1.7580],\n",
" [1.7580],\n",
" [0.0000]], grad_fn=<NormBackward3>)\n"
]
}
],
"source": [
"torch.manual_seed(0)  # reproducible parameter initialisation\n",
"\n",
"xor = {\n",
"    (0, 0): 0,\n",
"    (0, 1): 1,\n",
"    (1, 0): 1,\n",
"    (1, 1): 0,\n",
"}\n",
"\n",
"# Derive the training batch from the truth table so the two cannot drift apart.\n",
"batch_x = torch.tensor(list(xor.keys())).float()\n",
"batch_y = torch.tensor(list(xor.values()))\n",
"\n",
"net = nn.Sequential(\n",
"    SLMRouter(2),\n",
"    Act(optical_limiting),\n",
"    SLMRouter(2),\n",
"    Act(optical_limiting),\n",
"    SLMRouter(1),\n",
")\n",
"\n",
"# One learnable input angle per input channel, shared by every sample. The leaf\n",
"# parameter stays (1, 2) and is expanded inside the loop: wrapping an expanded\n",
"# (stride-0) view in nn.Parameter makes the optimizer's in-place update write\n",
"# several logical elements to the same memory location.\n",
"start_angles = nn.Parameter(torch.empty(1, 2).uniform_(0, pi))\n",
"\n",
"optim = torch.optim.Adam([start_angles, *net.parameters()], lr=0.01)\n",
"\n",
"N_STEPS = 1000\n",
"LOG_EVERY = 100\n",
"\n",
"print(batch_y)\n",
"for step in range(N_STEPS):\n",
"    optim.zero_grad()\n",
"\n",
"    x = (start_angles.expand(len(batch_x), -1), batch_x)\n",
"    _, out_intensity = net(x)\n",
"    # out_intensity is a norm (>= 0), so used as a logit it gives p >= 0.5 for\n",
"    # every sample: each label-0 sample contributes at least log(2) to the loss.\n",
"    loss = F.binary_cross_entropy_with_logits(out_intensity.squeeze(-1), batch_y.float())\n",
"    loss.backward()\n",
"\n",
"    optim.step()\n",
"    if step % LOG_EVERY == 0:\n",
"        print(f\"step {step}: loss={loss.item():.4f}\")\n",
"        print(out_intensity)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment