
@jaberg
Created August 21, 2016 16:45
Shared vars and small functions for a fast, flexible way of doing RNNs with Theano
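The *loop* style at the heart of this notebook can be sketched without Theano. In the hypothetical, dependency-free version below, plain Python stands in for the compiled `shared_update` functions. The forward loop appends each hidden state to a buffer (playing the role of the shared `H`) and pushes a matching backward closure onto a tape (playing the role of `trace` in `elman_loop`). The tape is then replayed in reverse. The sketch assumes a scalar hidden state and scalar weights, and the names (`elman_tape`, `sigmoid`) are made up for illustration. It also adds a data-dependent branch, running the recurrent step only when the input is 1, to show why a tape can follow arbitrary Python control flow.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def elman_tape(xs, w_rec, w_out, h0):
    """Return (cost, d cost/d w_rec, d cost/d w_out) for a scalar Elman-style RNN.

    Hypothetical illustration: the recurrent step runs only when the input
    is 1; otherwise the hidden state is copied unchanged.
    """
    n = len(xs)
    hs = [h0]   # hidden-state buffer, analogous to the shared H
    tape = []   # backward closures, analogous to `trace`

    # -- forward pass: ordinary Python control flow decides which step runs
    for x in xs:
        h_prev = hs[-1]
        if x == 1:
            h_new = math.tanh(w_rec * h_prev)

            def bprop(g_h, h_prev=h_prev, h_new=h_new):
                g_z = g_h * (1.0 - h_new ** 2)     # back through tanh
                return g_z * w_rec, g_z * h_prev    # (d/d h_prev, d/d w_rec)
        else:
            h_new = h_prev

            def bprop(g_h):
                return g_h, 0.0                     # identity step
        hs.append(h_new)
        tape.append(bprop)

    # -- cost on hs[1:], like sqerr in the notebook, plus its direct gradients
    cost, g_out = 0.0, 0.0
    g_hs = [0.0] * (n + 1)   # analogous to gH
    for t in range(1, n + 1):
        p = sigmoid(w_out * hs[t])
        err = p - xs[t - 1]
        cost += err ** 2 / n
        g_pre = 2.0 * err / n * p * (1.0 - p)   # d cost / d (w_out * hs[t])
        g_out += g_pre * hs[t]
        g_hs[t] += g_pre * w_out

    # -- backward pass: replay the tape in reverse order
    g_rec = 0.0
    for t in range(n, 0, -1):
        g_prev, g_w = tape[t - 1](g_hs[t])
        g_hs[t - 1] += g_prev
        g_rec += g_w
    return cost, g_rec, g_out
```

Changing the branch condition changes which closures land on the tape, but the same reverse replay still yields the gradient of whatever path was actually taken, which a finite-difference check on `w_rec` and `w_out` can confirm. In the notebook the closures are instead compiled Theano functions that read and write shared buffers, which is where the speed comes from.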
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Three ways to write RNNs with Theano\n",
"\n",
"Following up on some discussion with people at MILA, I thought it might be interesting to compare three ways of writing the core computations of a simple (Elman) RNN, and to introduce a `shared_update` or *loop* style of Theano programming that combines Python control flow with Theano's speed and automatic differentiation. The three versions below use `theano.scan`, a fully unrolled graph, and this loop style.\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10x5: 1000 scan evals averaged 0.000187505960464 seconds\n",
"10x5: 1000 loop evals averaged 0.000110383033752 seconds\n",
"10x5: 1000 full evals averaged 8.73539447784e-05 seconds\n",
"50x5: 100 scan evals averaged 0.000439031124115 seconds\n",
"50x5: 100 loop evals averaged 0.000379998683929 seconds\n",
"50x5: 100 full evals averaged 0.000342991352081 seconds\n",
"20x500: 50 scan evals averaged 0.00244964122772 seconds\n",
"20x500: 50 loop evals averaged 0.00248847961426 seconds\n",
"20x500: 50 full evals averaged 0.0127433204651 seconds\n"
]
},
{
"data": {
"text/plain": [
"[Figure: bar chart 'Backprop speeds for RNNs of len T, width N'; y-axis: 'Speedup relative to scan'; loop and full bars for T=10 N=5, T=50 N=5, T=20 N=500]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"main()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What I find interesting about the *loop* implementation based on the `shared_update` technique:\n",
"\n",
"* It matches or beats scan in all three benchmarks above (clearly faster for the two small RNNs, within about 2% at 20x500), while being simpler to debug and quicker to compile.\n",
"* It is close to full-graph compilation for the small RNNs (about 10% to 26% slower), and it scales to the larger RNN, where the fully unrolled graph is about 5x slower.\n",
"* It is strictly more powerful: the tape-based approach to automatic differentiation (AD) lets us handle models with arbitrary data-dependent control flow.\n",
"\n",
"There are also drawbacks: the programmer must take on some of the bookkeeping that `scan()` provides automatically (allocating the hidden-state buffer, advancing an index, and recording the backward steps to replay). If you look at `elman_loop` below, you'll see that it's longer and a bit less math-y than the two pure-Theano versions. I wonder whether syntactic constructs like the \"Layers\" in other neural-net libraries might help here.\n",
"\n",
"N.B. Getting the efficient memory usage and the BLAS-based update of `W_rec` required programming outside the individual fprop and bprop functions, based on my knowledge of how those fprops were going to be used. Maybe with some conventions or protocols for how certain *kinds* of Layer interact with shared variables, index variables, and parameters (e.g. linear ops that can be combined into a GEMM), there could be a useful abstraction layer that expresses a tight subset of what's possible and correct.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import time\n",
"import numpy as np\n",
"import theano\n",
"import theano.tensor as TT\n",
"\n",
"\n",
"# We want to iterate this computation\n",
"def step_math(H_i, W):\n",
"    return TT.tanh(TT.dot(H_i, W))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# -- Scan-style definition of iterated computation with Theano\n",
"\n",
"def elman_scan(rng, seqlen=10, nhid=5):\n",
"\n",
"    X_python = rng.randint(2, size=seqlen)\n",
"    X_shared = theano.shared(X_python)\n",
"    W_rec = theano.shared(rng.randn(nhid, nhid))\n",
"    W_readout = theano.shared(rng.randn(nhid))\n",
"    H_0 = rng.randn(nhid)\n",
"    del rng\n",
"\n",
"    # x_i is unused: the sequence X_shared only sets the number of steps\n",
"    result, updates = theano.scan(\n",
"        fn=lambda x_i, prior_result, W_rec_: step_math(prior_result, W_rec_),\n",
"        outputs_info=dict(\n",
"            initial=theano.shared(H_0)),\n",
"        sequences=X_shared,\n",
"        non_sequences=W_rec)\n",
"    assert not updates\n",
"    del updates\n",
"\n",
"    sqerr = TT.mean((TT.nnet.sigmoid(TT.dot(result, W_readout)) - X_shared) ** 2)\n",
"\n",
"    g_rec, g_readout, g_result = TT.grad(sqerr, [W_rec, W_readout, result])\n",
"\n",
"    fn = theano.function(\n",
"        inputs=[],\n",
"        outputs=[sqerr],\n",
"        updates=[\n",
"            (W_rec, W_rec - 0.1 * g_rec),\n",
"            (W_readout, W_readout - 0.1 * g_readout),\n",
"        ])\n",
"    return fn"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# -- Full fixed-graph computation with Theano\n",
"\n",
"def elman_full(rng, seqlen=10, nhid=5):\n",
"\n",
"    X_python = rng.randint(2, size=seqlen)\n",
"    X_shared = theano.shared(X_python)\n",
"    W_rec = theano.shared(rng.randn(nhid, nhid))\n",
"    W_readout = theano.shared(rng.randn(nhid))\n",
"    H_0 = theano.shared(rng.randn(nhid))\n",
"    del rng\n",
"\n",
"    # unroll the recurrence: the graph grows linearly with seqlen\n",
"    H_ii = H_0\n",
"    Hs = []\n",
"    for ii in range(seqlen):\n",
"        H_ii = step_math(H_ii, W_rec)\n",
"        Hs.append(H_ii)\n",
"\n",
"    result = TT.stack(Hs)\n",
"    sqerr = TT.mean((TT.nnet.sigmoid(TT.dot(result, W_readout)) - X_shared) ** 2)\n",
"    g_rec, g_readout, g_result = TT.grad(sqerr, [W_rec, W_readout, result])\n",
"\n",
"    fn = theano.function(\n",
"        inputs=[],\n",
"        outputs=[sqerr],\n",
"        updates=[\n",
"            (W_rec, W_rec - 0.1 * g_rec),\n",
"            (W_readout, W_readout - 0.1 * g_readout),\n",
"        ])\n",
"    def _fn():\n",
"        return fn()\n",
"    return _fn\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# -- Update-style computation with shared variables:\n",
"\n",
"def shared_update(*args):\n",
"    # Compile a function with no inputs or outputs that only applies the\n",
"    # given (shared_variable, new_value) updates. Returning its `.fn`\n",
"    # (the underlying compiled callable) skips the Python-side argument\n",
"    # handling of Function.__call__; note that this relies on Theano internals.\n",
"    return theano.function(inputs=[], outputs=[], updates=args).fn\n",
"\n",
"\n",
"def elman_loop(rng, seqlen=10, nhid=5):\n",
"\n",
"    X_python = rng.randint(2, size=seqlen)\n",
"    X_shared = theano.shared(X_python)\n",
"    W_rec = theano.shared(rng.randn(nhid, nhid))\n",
"    W_readout = theano.shared(rng.randn(nhid))\n",
"    H_0 = rng.randn(nhid)\n",
"    del rng\n",
"\n",
"    # storage for manually computed hidden state and gradient\n",
"    # N.B. if sequences have different lengths, then this should be the max length\n",
"    H_init = np.zeros((seqlen + 1, nhid))\n",
"    H_init[0] = H_0\n",
"    H = theano.shared(H_init)\n",
"    gH = theano.shared(H.get_value())\n",
"    gH_in = theano.shared(H.get_value())\n",
"\n",
"    # pointer into H where to store next recurrent value\n",
"    H_idx = theano.shared(np.asarray(0))\n",
"\n",
"    sqerr = TT.mean((TT.nnet.sigmoid(TT.dot(H[1:], W_readout)) - X_shared) ** 2)\n",
"    gW_readout, grad_H = TT.grad(sqerr, [W_readout, H])\n",
"\n",
"    cost = theano.shared(np.asarray(0.0))\n",
"\n",
"    cost_fn = shared_update(\n",
"        (cost, sqerr),\n",
"        (W_readout, W_readout - 0.1 * gW_readout),\n",
"        (gH, grad_H))\n",
"\n",
"    # -- FPROP (computing iteration H_idx + 1 from H_idx)\n",
"    step_fprop = shared_update(\n",
"        (H_idx, H_idx + 1),\n",
"        (H, TT.set_subtensor(H[H_idx + 1], step_math(H[H_idx], W_rec))))\n",
"\n",
"    # -- BPROP (computing iteration H_idx - 1 from H_idx)\n",
"    H_im1 = H[H_idx - 1]\n",
"    gH_im1, gW_rec = TT.grad(\n",
"        (gH[H_idx] * step_math(H_im1, W_rec)).sum(),\n",
"        wrt=[H_im1, W_rec])\n",
"\n",
"    # Knowing that W_rec is used via a dot operation (as right argument),\n",
"    # we stash the gradient on that dot's output (the pre-activation)\n",
"    # so we can compute the gradient on W_rec later with a single matrix dot().\n",
"    # This digs into the graph built by TT.grad, so it depends on Theano\n",
"    # internals and on the exact form of step_math.\n",
"    gH_i_in = gW_rec.owner.inputs[1].owner.inputs[0]\n",
"    # Otherwise, in the general case we would have to increment into a gW_rec\n",
"    # shared variable with each call to step_bprop, which would be slower\n",
"\n",
"    step_bprop = shared_update(\n",
"        (H_idx, H_idx - 1),\n",
"        (gH, TT.inc_subtensor(gH[H_idx - 1], gH_im1)),\n",
"        (gH_in, TT.set_subtensor(gH_in[H_idx], gH_i_in)))\n",
"\n",
"    sgd_update = shared_update(\n",
"        (W_rec, W_rec - 0.1 * TT.dot(H[:-1].T, gH_in[1:])))\n",
"\n",
"    # as our Python control flow unfolds within fn()\n",
"    # keep a stack of the bprops we're going to need to call\n",
"    trace = []\n",
"    def fn():\n",
"        del trace[:]\n",
"        for Xi in X_python:\n",
"            # Xi is unused here; a data-dependent model could branch on it\n",
"            # and run a different fprop/bprop pair\n",
"            step_fprop()\n",
"            trace.append(step_bprop)\n",
"        cost_fn()\n",
"        for bprop in reversed(trace):\n",
"            bprop()\n",
"        # after the bprops H_idx is back to 0, ready for the next call\n",
"        sgd_update()\n",
"        # 1-tuple, mirroring outputs=[sqerr] in the other versions\n",
"        return cost.get_value(),\n",
"    return fn\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def main():\n",
"\n",
"    times = {}\n",
"\n",
"    cases = [\n",
"        (1000, 10, 5),  # Tiny: 10 steps for 5 hidden units\n",
"        (100, 50, 5),   # Small: 50 steps for 5 hidden units\n",
"        (50, 20, 500)]  # Med: 20 steps for 500 hidden units\n",
"\n",
"    for n_epochs, seqlen, nhid in cases:\n",
"        fn_scan = elman_scan(np.random.RandomState(123),\n",
"                             seqlen=seqlen, nhid=nhid)\n",
"        fn_loop = elman_loop(np.random.RandomState(123),\n",
"                             seqlen=seqlen, nhid=nhid)\n",
"        fn_full = elman_full(np.random.RandomState(123),\n",
"                             seqlen=seqlen, nhid=nhid)\n",
"\n",
"        for fn, label in ((fn_scan, 'scan'),\n",
"                          (fn_loop, 'loop'),\n",
"                          (fn_full, 'full')):\n",
"            # epoch 0 is an untimed warm-up, so n_epochs - 1 calls are timed\n",
"            for epoch in range(n_epochs):\n",
"                if epoch == 1:\n",
"                    t0 = time.time()\n",
"                fn()\n",
"            t1 = time.time()\n",
"            times[n_epochs, seqlen, nhid, label] = t1 - t0\n",
"            print('{}x{}: {} {} evals averaged {} seconds'.format(\n",
"                seqlen, nhid,\n",
"                n_epochs, label, (t1 - t0) / (n_epochs - 1)))\n",
"\n",
"    loop_heights = []\n",
"    full_heights = []\n",
"    for n_epochs, seqlen, nhid in cases:\n",
"        rel_loop_speed = (\n",
"            times[n_epochs, seqlen, nhid, 'scan']\n",
"            / times[n_epochs, seqlen, nhid, 'loop'])\n",
"        rel_full_speed = (\n",
"            times[n_epochs, seqlen, nhid, 'scan']\n",
"            / times[n_epochs, seqlen, nhid, 'full'])\n",
"        loop_heights.append(rel_loop_speed)\n",
"        full_heights.append(rel_full_speed)\n",
"    import matplotlib.pyplot as plt\n",
"    # align='edge' keeps the xticks offset below correct on matplotlib >= 2.0,\n",
"    # whose default is centered bars\n",
"    plt.bar(np.arange(len(cases)), loop_heights, width=.3, color='b', label='loop', align='edge')\n",
"    plt.bar(np.arange(len(cases)) + .4, full_heights, width=.3, color='g', label='full', align='edge')\n",
"    plt.title('Backprop speeds for RNNs of len T, width N')\n",
"    plt.ylabel('Speedup relative to scan')\n",
"    plt.legend(loc='upper right')\n",
"    plt.xticks(\n",
"        np.arange(len(cases)) + .35,\n",
"        ['T={} N={}'.format(seqlen, nhid) for n_epochs, seqlen, nhid in cases])\n",
"    plt.show()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
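The `gH_in` trick in `elman_loop` rests on a simple identity. If step t computes z_t = h_{t-1} W_rec, then the gradient on W_rec summed over steps is sum_t outer(h_{t-1}, dz_t), which equals the single product H_prev^T dZ of the stacked rows. That is what `TT.dot(H[:-1].T, gH_in[1:])` computes as one GEMM instead of seqlen rank-1 updates. The pure-Python check below compares the two ways of accumulating. The helper names are hypothetical and no BLAS is involved, so it shows the algebra rather than the speedup.

```python
def outer(u, v):
    """Rank-1 matrix u v^T as nested lists."""
    return [[ui * vj for vj in v] for ui in u]


def per_step_grad_w(h_prev_rows, dz_rows):
    """Accumulate one outer product per time step (seqlen small updates)."""
    n, m = len(h_prev_rows[0]), len(dz_rows[0])
    g = [[0.0] * m for _ in range(n)]
    for h, dz in zip(h_prev_rows, dz_rows):
        step = outer(h, dz)
        for i in range(n):
            for j in range(m):
                g[i][j] += step[i][j]
    return g


def batched_grad_w(h_prev_rows, dz_rows):
    """One product H_prev^T dZ over the stacked rows (a single GEMM in BLAS terms)."""
    n, m, steps = len(h_prev_rows[0]), len(dz_rows[0]), len(h_prev_rows)
    return [[sum(h_prev_rows[t][i] * dz_rows[t][j] for t in range(steps))
             for j in range(m)]
            for i in range(n)]
```

In the notebook, the per-step half of this bookkeeping happens inside `step_bprop`, which writes each step's dz row into `gH_in[H_idx]`, and the batched half is the single `sgd_update` call at the end.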