Skip to content

Instantly share code, notes, and snippets.

@rajarsheem
Last active March 20, 2017 07:34
Show Gist options
  • Save rajarsheem/de132ba50f739f84c01ee4a0ff97e364 to your computer and use it in GitHub Desktop.
Save rajarsheem/de132ba50f739f84c01ee4a0ff97e364 to your computer and use it in GitHub Desktop.
tiny demos for pytorch
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PyTorch demos for recurrent nets\n",
"## Contents:\n",
"1. An RNN that predicts a sine curve.\n",
"2. A vanilla seq2seq model that reverses strings."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**A simple PyTorch RNN that predicts a sine curve**"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"from torch.autograd import Variable\n",
"import torch\n",
"import numpy as np\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"import matplotlib.pyplot as plt\n",
"import random, string\n",
"from itertools import chain\n",
"from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([5, 26]) torch.Size([5, 26])\n"
]
}
],
"source": [
"# Sample a minibatch of sine-curve segments for the RNN demo.\n",
"# Each row starts at a random integer offset in [0, 10) and holds seq_len\n",
"# samples spaced `step` apart: x[:, k] = sin(t_k + step) and y[:, k] = sin(t_k),\n",
"# i.e. the target trails the input by one step.\n",
"# Returns x, y as FloatTensors of shape (size, seq_len); the defaults reproduce\n",
"# the original hard-coded 26 points with step 0.2.\n",
"def batch_data(size=5, seq_len=26, step=0.2):\n",
"    x = torch.FloatTensor(size, seq_len)\n",
"    y = torch.FloatTensor(size, seq_len)\n",
"    for i in range(size):\n",
"        start = int(np.random.choice(10, 1)[0])\n",
"        # np.arange with an integer count always yields exactly seq_len points,\n",
"        # whereas torch.range with a float end point (end + 0.3) depends on rounding.\n",
"        t = start + step * np.arange(seq_len)\n",
"        x[i] = torch.from_numpy(np.sin(t + step)).float()\n",
"        y[i] = torch.from_numpy(np.sin(t)).float()\n",
"    return x, y\n",
"\n",
"a, b = batch_data(5)\n",
"print(a.size(), b.size())"
]
},
{
"cell_type": "code",
"execution_count": 177,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Single-feature RNN regressor for the sine demo: nn.RNN -> per-step Linear -> tanh\n",
"# (tanh keeps predictions in [-1, 1], the range of sin).\n",
"#   ntoken  -- features per time step, input and output (1 in this notebook)\n",
"#   dims    -- RNN hidden size\n",
"#   nlayers -- number of stacked RNN layers\n",
"class IntegerPredictor(nn.Module):\n",
"\n",
"    def __init__(self, ntoken, dims, nlayers):\n",
"        super(IntegerPredictor, self).__init__()\n",
"        self.rnn = nn.RNN(ntoken, dims, nlayers, nonlinearity='tanh', bias=False)\n",
"        self.predictor = nn.Linear(dims, ntoken)\n",
"        self.tanh = nn.Tanh()\n",
"\n",
"    def forward(self, inputs, hidden):\n",
"        # inputs: (seq_len, batch, ntoken) -- nn.RNN is time-major by default.\n",
"        # hidden: (nlayers, batch, dims) initial state.\n",
"        output, _ = self.rnn(inputs, hidden)\n",
"        seq_len, batch, dims = output.size(0), output.size(1), output.size(2)\n",
"        # Sizes come from the tensors instead of the hard-coded 20 / (5, 26),\n",
"        # so other hidden sizes, batch sizes and sequence lengths work.\n",
"        preds = self.tanh(self.predictor(output.contiguous().view(-1, dims)))\n",
"        # NOTE(review): as with the original .view(5, 26), the time-major predictions\n",
"        # are reshaped to (batch, seq_len) WITHOUT a transpose, so row b is not the\n",
"        # prediction for sequence b; callers recover that with .view(seq_len, batch).t().\n",
"        # The reshape also requires ntoken == 1.\n",
"        return preds.view(batch, seq_len)"
]
},
{
"cell_type": "code",
"execution_count": 183,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.6445996761322021\n",
"0.1592656970024109\n",
"0.03313932940363884\n",
"0.03268695995211601\n",
"0.01342559140175581\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAhQAAAFkCAYAAAB4sKK5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAIABJREFUeJzt3X2UZHdd5/H3N50AEpbJmOgMCEKY7h7icYmZDmBQBJIO\n3TO6YVezandYIigPgiunPRxYPaw8eA5ZeUhE3QjKLsgC5UHZFRcz3dAQwBUSDtMBCSTUdJOAGiYk\n6WFQkgDp/PaPW81U9fRDVd26devh/TqnT6Z+dW/1r2+q+37q9xgpJSRJkvI4rewKSJKk/megkCRJ\nuRkoJElSbgYKSZKUm4FCkiTlZqCQJEm5GSgkSVJuBgpJkpSbgUKSJOVmoJAkSbkVGigi4ukR8TcR\n8c8R8WBEXLbD8c+oHVf/tRYRP1xkPSVJUj5Ft1CcCXwOeCnQ7KYhCRgD9ta+HpVS+kYx1ZMkSZ1w\nepEvnlKaB+YBIiJaOPWulNK3iqmVJEnqtF4cQxHA5yLijoj4cEQ8rewKSZKk7RXaQtGGrwMvBj4L\nPBR4IfDxiHhKSulzm50QEWcDU8DtwP1dqqckSYPgYcDjgYWU0j15XqinAkVKqQpU64puiIh9wBxw\n5RanTQHvLbpukiQNsCuA9+V5gZ4KFFv4DPBT2zx/O8B73vMezjvvvK5UqJfNzc1xzTXXlF2N0nkd\nMl6Hk7wWGa/DSV4LuOWWW3juc58LtXtpHv0QKH6CrCtkK/cDnHfeeRw4cKA7Nephu3bt8jrgdVjn\ndTjJa5HxOpzktWiQe8hAoYEiIs4ERskGWgI8ISLOB1ZTSv8YEVcBj04pXVk7/uXAbcAXyfp1Xgg8\nC7i0yHpKkqR8im6huBC4nmxtiQS8pVb+58ALyNaZeGzd8Q+pHfNo4F7gH4BLUkqfLLiekiQph6LX\nofgE20xNTSk9f8PjNwFvKrJOkiSp83pxHQrlMDMzU3YVeoLXIeN1OMlrkfE6nOS16KxIqdkVsXtT\nRBwAjhw5csTBNZIktWBpaYmJiQmAiZTSUp7XsoVCkiTlZqCQJEm5GSgkSVJuBgpJkpSbgUKSJOVm\noJAkSbkZKCRJUm4GCkmSlJuBQpIk5WagkCRJuRkoJElSbgYKSZKUm4FCkiTlZqCQJEm5GSgkSVJu\nBgpJkpSbgUKSJOVmoJAkSbkZKCRJUm4GCkmSlJuBQpIk5WagkCRJuRkoJElSbgYKSZKUm4FCkiTl\nZqCQJEm5GSgkSVJuBgpJkpSbgUKSJOVmoJAkSbkZKCRJUm4GCkmSlJuBQpIk5WagkCRJuRkoJElS\nbgYKSZKUm4FCkiTlZqCQJEm5GSgkSVJuhQaKiHh6RPxNRPxzRDwYEZc1cc4zI+JIRNwfEdWIuLLI\nOkqSpPyKbqE4E/gc8FIg7XRwRDwe+BDwUeB84K3AOyLi0uKqKEnqV9UqHD4MR4+WXROdXuSLp5Tm\ngXmAiIgmTvl14CsppVfWHn85In4amAM+UkwtJUn9ZnUVZmdhYeFk2dQUVCqwe/fO51ersLICo6Mw\nNlZcPYdJr42h+ElgcUPZAnBRCXWRJPWo2VlY3HC3WFyEmZntz1tdhelp2L8fDh2C8fHs8fHjxdV1\nWPRaoNgL3Lmh7E7gkRHx0BLqI0nqkma7L6rVrGViba2xfG0tK9/u/HaDiHZWaJdHN83NzbFr166G\nspmZGWZ8l0hST2u1+2JlZfvXW17evBtjPYhsVB9EBrn7o1KpUKlUGspOnDjRsdfvtUBxDNizoWwP\n8K2U0ne2O/Gaa67hwIEDhVVMklSM7VoN5udPPX7fvu1fb3R08/J2g8ig2OxD9tLSEhMTEx15/V7r\n8vg0cMmGsmfXyiVJA6ad7ovx8awFY2SksXxkJCvfKhS0G0TUnKLXoTgzIs6PiJ+oFT2h9vixteev\niog/rzvlbbVjfj8i9kfES4HLgauLrKckqRzN
tBpsplKBycnGssnJrHwr7QYRNafoFooLgZuAI2Tr\nULwFWAJeV3t+L/DY9YNTSrcDPwtMkq1fMQf8akpp48wPSdIAaLfVYPfurDukWoXrrsv+Oz+/85TR\ndoKImlP0OhSfYJvQklJ6/iZlnwQ606EjSepp660Gi4uN3R4jI9mNfqdWg7Gx1loW1oPI0aNZ64fr\nUHROr42hkCQNmTJaDcbG4OBBw0Qn9dosD0nSAGhlJUpbDQaDgUKS1DF5lsRutftCvcUuD0lSx7gS\n5fAyUEiSOiLPktjqfwYKSVJHtLumRC7uX94zDBSSpI7o6kqUebcNNYh0nIFCktQRuVeibOUm7/7l\nPcdAIUnqmLbWlGj1Ju/+5T3JQCFJ6pi2lsRu9Sbf7mANR40WykAhSeq4sVTlIIcZY4ebdDs3+SL3\nL1fbDBSSpG21NH6x1e6Ldm7y7l/ekwwUkqRNtTV+sdXui3Zv8u5f3nMMFJKkTbU8frGd7ot2b/Lu\nX95z3MtDknSK9WywUX02OOVe30z3xWYBoVLJUkr9N2z2Ju/+5T3DQCFJOkV9Nhijyj5WWGaUZbKb\n76bZoN3uizJu8u5E1nF2eUiSTrFvH+xmlcNMU2U/hznEUcY5zDRncXzzbJB3jMLYGBw86I2+Txko\nJEmnGB+HD589yySNgygmWeQjZ89sfc93jMLQsstDknSqapUL7zl1EMXprGXlmw6iwDEKQ8xAIUk6\nVbsDLNc5RmHo2OUhSTqVi0CpRQYKSdKpXARKLTJQSJI25wBLtcAxFJKkzTnAUi0wUEjSsKhWs8GW\nrQYDB1iqCXZ5SNKga2uXL6k1BgpJGnQt7/Iltc5AIUmDrJ0dQKU2GCgkaZA1s0CV1AEGCkkaYLed\ntv0CVbef7gJV6gwDhSQNsFsfHGeeKR6gcYGqBxhhniluecDZG+oMA4UkDbB9+2CGCos0LlC1yCQz\nVFxBWx3jOhSSNMDGx+GpU7v5ucV5zl07yijLLDPKbSNjTE66vIQ6xxYKSRpw6ytoLzPGPAdZZswV\ntNVxtlBI0oBzBW11g4FCkoaEK2irSHZ5SJKk3AwUkiQpN7s8JKkftbtzqFQQWygkqZ+4c6h6VFcC\nRUS8LCJui4j7IuKGiHjyNsdeGREPRsRa7b8PRsS93ainJPU8dw5Vjyo8UETELwFvAV4DXAB8HliI\niHO2Oe0EsLfu63FF11OSep47h6qHdaOFYg54e0rp3SmlW4GXAPcCL9jmnJRSuiul9I3a111dqKck\n9TZ3DlUPKzRQRMQZwATw0fWylFICFoGLtjn1ERFxe0R8LSL+OiJ+rMh6SlJf2Lf9zqFuzKEyFd1C\ncQ4wAty5ofxOsq6MzXyZrPXiMuAKsjp+KiIeXVQlJakvjI/z2bM33zn0s2dPOdtDpeq5WR4ppRtS\nSu9JKf1DSunvgJ8H7gJeXHLVJKlU1Spces/mO4deek/FIRQqVdHrUNwNrAF7NpTvAY418wIppQci\n4iZg27a8ubk5du3a1VA2MzPDjCOfJQ2IlRX4Jrs5yDyjnNw5dJmsZWJ52UYKba1SqVDZsCPciRMn\nOvb6kQ1pKE5E3ADcmFJ6ee1xAF8D/jCl9KYmzj8N+CLwtymlV2zy/AHgyJEjRzhw4EBnKy9JPaRa\nzZaf2O55A4VasbS0xMTEBMBESmkpz2t1o8vjauCFEfG8iHgi8Dbg4cC7ACLi3RHxhvWDI+K/RsSl\nEXFuRFwAvBf4UeAdXairJPWs8XGYmoKRxiEUjIxk5YYJlanwQJFSej/wCuD1wE3Ak4Cpuqmgj6Fx\ngOZu4E+BLwF/CzwCuKg25VSShlqlApONQyiYnMzKpTJ1ZS+PlNK1wLVbPHfxhse/BfxWN+olSf1m\n926Yn8/WsFpedisP9Q43B5OkPjQ2ZpBQb+m5aaOSJKn/GCgkSVJuBgpJkpSbgUKSJOVmoJAkSbk5\ny0OSylat
ZutqOwdUfcwWCkkqy+oqTE9n62kfOpQthTk9DcePl10zqWUGCkkqy+wsLC42li0ugpsa\nqg8ZKCSpDNUqLCzA2lpj+dpaVu5e5OozBgpJKsPKyvbPLy93px5ShxgoJKkM+/Zt//zoaHfqIXWI\ngUKSyjA+zmfPnuIBGvcif4ARPnu2e5Gr/xgoJKkE1Spcek+FRRr3Il9kkkvvqTiEQn3HdSgkqQQr\nK/BNdnOQeUY5yijLLDPKMlnLxPKyjRTqLwYKSSpB/RCKZca+HyTWOYRC/cYuD0kqwfg4TE3BSOMQ\nCkZGsnJbJ9RvDBSSVJJKBSYbh1AwOZmVS/3GLg9JKsnu3TA/n61htbzsVh7qbwYKSSrZ2JhBQv3P\nLg9JkpSbgUKSJOVmoJAkSbkZKCRJUm4GCkmSlJuBQpIk5WagkCRJuRkoJElSbi5sNUCq1WwHQ1fb\nkyR1my0UA2B1FaanYf9+OHQo23RoehqOHy+7ZpKkYWGgGACzs7C42Fi2uAgzM+XURxpW1SocPpzt\nzSENGwNFn6tWYWEB1tYay9fWsnL/sEnFW28l/Ln9Vf7w0GEOjR+1lVBDx0DR51ZWtn9+ebk79ZCG\n2YsuX2VuYZoq+znMIY4yztzCNC+83ESh4WGg6HP79m3//Ohod+ohDatqFX7t+lkuobHf8RIW+bWP\nzdhKqKFhoOhz4+MwNQUjI43lIyNZubM9pGJ9/RNVplngdBr7HU9njWkWuOMTJgoNBwPFAKhUYHKy\nsWxyMiuXVKx9bN/vOIr9jhoOrkMxAHbvhvn5bADm8rLrUEjd9JhnbN/v+CPPsN9Rw8FAMUDGxgwS\nUteNj/O9i6c47fpFRtLJbo+1GOHBZ01yhr+UGhJ2eUhSTmf8VYWRZzf2O448e5Iz/sp+Rw0PWygk\nKS/7HSUDhSR1jP2OGmJd6fKIiJdFxG0RcV9E3BART97h+P8YEbfUjv98RBzsRj0lSVJ7Cg8UEfFL\nwFuA1wAXAJ8HFiLinC2OfxrwPuDPgJ8APgj8dUT8WNF1lSRJ7elGC8Uc8PaU0rtTSrcCLwHuBV6w\nxfG/CRxOKV2dUvpySul3gSXgN7pQV0mS1IZCA0VEnAFMAB9dL0spJWARuGiL0y6qPV9vYZvjJUlS\nyYpuoTgHGAHu3FB+J7B3i3P2tni8JEkq2cDM8pibm2PXrl0NZTMzM8zMzJRUI0mSekelUqGyYU+G\nEydOdOz1iw4UdwNrwJ4N5XuAY1ucc6zF4wG45pprOHDgQDt1lCRp4G32IXtpaYmJiYmOvH6hXR4p\npe8BR4BL1ssiImqPP7XFaZ+uP77m0lq5OqxahcOHcYtlSVIu3ZjlcTXwwoh4XkQ8EXgb8HDgXQAR\n8e6IeEPd8W8FpiPityJif0S8lmxg5x93oa5DY3UVpqdh/344dCjbBn16Go4fL7tmkqR+VHigSCm9\nH3gF8HrgJuBJwFRK6a7aIY+hbsBlSunTwCzwIuBzwM8Dz0kpfanoug6T2VlY3DCXZnERHHIiSWpH\nVwZlppSuBa7d4rmLNyn7APCBous1cKpVWFnZcR+BahUWFk4tX1vLyo8edfVgqclfJ0k17jY6CFrs\nv1hZ2f7llpcLqKPUJ+wOlNpjoBgELfZf7Nu3/cuNjnaoXlIfsjtQao+Bot+t91+srTWW1/dfbDA+\nDlNTMDLSWD4ykpXbvKth1cavk6QaA0W/a7P/olKBycnGssnJrFwaVnYHSu0bmJUyh1ab/Re7d8P8\nfPaJa3nZgWcS2B0o5WELRb/L2X8xNgYHDxomJGj8dRqjyjSHGeWo3YFSEwwUg8D+C6lj/uLaVW44\na5oq+znMIY4yzg1nTfMXf+I0D2k7dnkMAvsvpI4566WzXPjNxmkeF35zEX59Jvs9k7QpA8UgGRsz\nSEh5uOqb1Da7PCRpndM8pLYZKCRpndM8pLYZKCRpnau+SW0zUEhSPWdNSW
1xUKYk1XPWlNQWA4Uk\nbcZZU1JL7PKQJEm5GSgkSVJudnmoZdVqNl3frmVJ0jpbKNS01VWYnob9++HQoWyG3fQ0HHeLA0ka\negYKNW12FhYbtzhgcRFmZsqpj9SsahUOH84mbkgqhoFCTVnf4mBtrbG8fosDqdfYqiZ1j4FCTXGL\nA/UjW9Wk7jFQqClucaB+Y6ua1F0GCjXFLQ7Ub2xVk7rLQDHsWhit5hYH6ie2qknd5ToUw2p1Netg\nXlg4WTY1laWD3bs3PcUtDtRP1lvVFhcbuz1GRrIg7HtX6ixbKIZVjtFqY2Nw8KB/kNX7bFWTuscW\nimG0Plpto/rRaqYFDQBb1aTuMVAMo2ZGq/lXVwPEjUOl4hkohpGj1TRs3IBGKpxjKIaRc0A1LFwq\nU+oaA8WwcrSahoFLZUpdY5fHsHK0mgadg4+lrjJQDDtHq2lQOfhY6iq7PCQNJgcfS11loJA0mBx8\nLHWVgUJd0cKWIVLnOPhY6hrHUKhQbWwZInWOg4+lrjFQqFDbzdqbny+nTupvba1R5eBjqXB2eagw\n67P26nd6hMZZe1KzXKNK6m0GChWmmVl7UrNco0rqbYUGiojYHRHvjYgTEXE8It4REWfucM7HI+LB\nuq+1iLi2yHqqGM7aU6fY2iX1vqJbKN4HnAdcAvws8DPA23c4JwF/CuwB9gKPAl5ZYB1VEGftqVNs\n7ZJ6X2GBIiKeCEwBv5pS+mxK6VPAfwZ+OSL27nD6vSmlu1JK36h9/WtR9VSxnLWnTrC1S+p9RbZQ\nXAQcTyndVFe2SNYC8dQdzr0iIu6KiC9ExBsi4gcKq6Va18KiEuuz9qpVuO667L/z804ZVWts7ZJ6\nX5HTRvcC36gvSCmtRcRq7bmtvBf4KnAH8CTgjcA4cHlB9VSzciwq4aw95VWpZAMw699+tnZJvaPl\nQBERVwGv2uaQRDZuoi0ppXfUPfxiRBwDFiPi3JTSbVudNzc3x65duxrKZmZmmHEIeOe4qIRK5BpV\nUj6VSoXKhgR+4sSJjr1+pJRaOyHibODsHQ77CvCfgDenlL5/bESMAPcDl6eUPtjk93s48K/AVErp\nI5s8fwA4cuTIEQ4cONDkT6GWVavZAgDbPe9fd0nqK0tLS0xMTABMpJSW8rxWyy0UKaV7gHt2Oi4i\nPg2cFREX1I2juAQI4MYWvuUFZK0eX2+1ruogt4KWJG2jsEGZKaVbgQXgzyLiyRHxU8AfAZWU0jGA\niHh0RNwSERfWHj8hIl4dEQci4nERcRnw58AnUko3F1VXNcFh9uoV7jQn9aSi16GYBW4lm93xIeCT\nwIvrnj+DbMDlw2uPvwtMkgWRW4A3AX8JXFZwPbUTh9mrbK69LfW0QjcHSyl9E3juNs9/FRipe/xP\nwDOLrJNycJi9yuSgYKmnuduomucwe5Vlfe3tjerX3va9KJXKQKHWuaiEus1BwVLPM1Cop1Wr2b3E\nxpAh56Bgqee5fbl6kuPv1MBBwVLPM1CoJ203/k6DoeXZn+40J/U0uzzUcxx/N9ja3hLGQcFST7OF\nQt3RwsfRZsbfqX/lbn0aG4ODBw0TUo8xUKhYbQyGcPzd4FpvfVpbayyvb32S1J8MFCpWGx9HHX83\nuGx9kgaXgULFyfFx1PF3g8nWJ2lwGShUnBwfR9fH31WrcN112X/n53cYtKeeZ+uTNLic5aHidODj\nqItyDp71LWG+slBlHyssM8q+yTFbn6Q+Z6BQcdY/ji4uNnZ7jIxk/RcmhaG0O60yzyzZpsLrpoAK\nYBOU1K/s8lCxHAyhjVy1TBpItlCoWC5GpHquWiYNLAOFuqPdwRDuDjZY3DVUGlh2eag3uTvYYHLe\nqDSwDBTqTTn62VvedErd47xRaWAZKNR72lwQy0aN8rQU4hyoKw0kA4V6T5sLYjl5oPvaCnGuWiYN\nJAOFek8b/exuOlWO9RA3RpVpDjPK0e
ZDnLuGSgPFQKHe00Y/u5tOdV+1Cp9ZWOVDa9NU2c9hDnGU\ncT60Ns2NC8cNcdKQMVCoN7XYz17fqFH/aXmdkwc6b2UF3scskzT2M02ySIUZQ5w0ZFyHQr2pxQWx\nxsfhF561yguvn2WqbknnBab4s4srjI3ZP99pTzytyrmcukjV6awxzQK3n34UsDtDGhYGCvW2FhbE\nqsQsp8UipJNlk7HIxcwA88XUb4id++D2/UyPf2AZA4U0POzy0GCoVjnjYwuMpMZRmSNpjTM+5qjM\nljQ7B9RFqiTVMVBoMOQdlelqWK3PAXWRKkl1DBQaDO1+WnY1rJPaWcjDRaok1RgoNBja/bQ8yKth\ntdLq0u5CHi5SJanGQKHB0eqn5U6shtWtrpJWvk87rS55u4xcpEoaegYKDY5WPy3nuYnm6SopOhy0\n0+riAEtJORkoNHia/bSc5ybazk27G+Gg3VaX8XG+e/EUa9HYZbQWI3zvYgdYStqZgUJDq8o480zx\nAI030QcYYZ4pjm61hkK7N+1uhIMcrS6zVFhMjV1Gi2mSGRxgKWlnBgoNrZUVmKHCIhtuomQ30S3v\nve3ctLsVDtpsdalW4QMf280084xR5SDX1ZYwn+cDH9s91LNpJTXHlTI1tPbtg2+ym4PMM8pRRllm\nmVGWay0TW/Z4tHPTbiYcbOxWaOf7rM92WVxsDC8jI9kA1S26Luqrt8zY96/BdtWTpHq2UGho1c80\nXWaMeQ6yzNjO6zK1M0U1TzhodSpsG2tDOCZTUl4GCg21ttdlavXELoaDdtaGcNFLSXlFSmnno3pY\nRBwAjhw5coQDBw6UXR31qSY3NT3FbR8+yt03LPNDF43y+Et3OPH48WwA5kLdDp1TU1k42GkhqHYr\n2II81ZPUn5aWlpiYmACYSCkt5Xktx1BItLSpKZDNAJ2dhYWFMdZ31Nzx5tviluy5KkjWMLGy0vy3\nyVM9STJQSG3Ybgbo/E47pbcRDlpxMuycLGulpaHg6kkaUIWNoYiI34mIv4+Ib0fEagvnvT4i7oiI\neyPiIxHhcDD1lE6s2F2kQd6eRFLvKnJQ5hnA+4E/afaEiHgV8BvAi4CnAN8GFiLiIYXUUGpDGTul\nN3tOr4cdSYOrsECRUnpdSumtwBdaOO3lwO+llD6UUroZeB7waODfF1FHqR3d3Cm91XPyhh1JalfP\nTBuNiHOBvcBH18tSSt8CbgQuKqte0kbd3Cm91XNcT0JSWXomUJCFiQTcuaH8ztpzUs/oxk7p7Zzj\nehKSytJSoIiIqyLiwW2+1iJivKjKSr2iGzult9t90fZiXZKUQ6vTRt8MvHOHY77SZl2OAQHsobGV\nYg9w004nz83NsWvXroaymZkZZhzargI1O8Wyna6IdrsvXE9C0mYqlQqVDZ8sTpw40bHXL3ylzIi4\nErgmpfSDTRx7B/CmlNI1tcePJAsXz0sp/eUW57hSpvrC9PTWe3ZttXZFO+dIUrM6uVJmketQPDYi\nzgceB4xExPm1rzPrjrk1Ip5Td9ofAK+OiH8XEf8WeDfwT8AHi6qn1C3tdEXYfSGpXxS5UubryaZ9\nrltPPs8CPln79xjw/X6KlNIbI+LhwNuBs4C/Aw6mlL5bYD2lrminK8LuC0n9orBAkVJ6PvD8HY4Z\n2aTstcBri6mVVL52lrZ2OWxJva6Xpo1KkqQ+ZaCQJEm5GSgkSVJuBgpJkpSbgUKSJOVmoJAkSbkZ\nKCRJUm4GCkmSlJuBQpIk5WagkCRJuRkoJElSbgYKSZKUm4FCkiTlZqCQJEm5GSgkSVJuBgpJkpSb\ngUKSJOVmoJAkSbkZKCRJUm4GCkmSlJuBQpIk5WagkCRJuRkoJElSbgYKSZKUm4FCkiTlZqCQJEm5\nGSgkSVJuBgpJkpSbgUKSJOVmoJAkSbkZKCRJUm4GCkmSlJuBQpIk5WagkCRJuRkoJElSbgYKSZKU\nm4
FCkiTlZqCQJEm5GSgkSVJuBgpJkpSbgWLAVCqVsqvQE7wOGa/DSV6LjNfhJK9FZxUWKCLidyLi\n7yPi2xGx2uQ574yIBzd8XVdUHQeRvyAZr0PG63CS1yLjdTjJa9FZpxf42mcA7wc+DbyghfMOA78C\nRO3xdzpbLUmS1GmFBYqU0usAIuLKFk/9TkrprgKqJEmSCtKLYyieGRF3RsStEXFtRPxg2RWSJEnb\nK7LLox2HgQ8AtwH7gKuA6yLiopRS2uKchwHccsst3alhjztx4gRLS0tlV6N0XoeM1+Ekr0XG63CS\n16Lh3vmwvK8VW9+nNzk44irgVdsckoDzUkrVunOuBK5JKbXc0hAR5wIrwCUppeu3OGYWeG+rry1J\nkr7vipTS+/K8QKstFG8G3rnDMV9psy6nSCndFhF3A6PApoECWACuAG4H7u/U95YkaQg8DHg82b00\nl5YCRUrpHuCevN+0WRHxGOBs4Os71ClXqpIkaYh9qhMvUuQ6FI+NiPOBxwEjEXF+7evMumNujYjn\n1P59ZkS8MSKeGhGPi4hLgL8GqnQgOUmSpOIUOSjz9cDz6h6vj3x5FvDJ2r/HgF21f68BT6qdcxZw\nB1mQ+N2U0vcKrKckScqppUGZkiRJm+nFdSgkSVKfMVBIkqTcBipQRMTtGzYWW4uIV5Zdr26IiJdF\nxG0RcV9E3BARTy67Tt0WEa/ZZHO5L5Vdr6JFxNMj4m8i4p9rP/Nlmxzz+oi4IyLujYiPRMRoGXUt\n2k7XYlg2IIyI346Iz0TEt2orD/+fiBjfcMxDI+K/R8TdEfEvEfFXEfHDZdW5CE1eh49vct+4tqw6\nFyEiXhIRn4+IE7WvT0XEdN3zHXkvDFSgIFtY69XAHmAv8Cjgj0qtURdExC8BbwFeA1wAfB5YiIhz\nSq1YOW7m5P//vcBPl1udrjgT+BzwUrLfgQYR8SrgN4AXAU8Bvk32/nhINyvZJdtei5rDNL5HZrpT\nta56OtnfvqcCk2SbNX44In6g7pg/AH4W+AXgZ4BHk61UPEiauQ4J+FMa7xuD9kH0H8kWpTwATAAf\nAz4YEefVnu/MeyGlNDBfZEt2/2bZ9Sjh574BeGvd4wD+CXhl2XXr8nV4DbBUdj1KvgYPApdtKLsD\nmKt7/EjgPuAXy65vCdfincD/LrtuJVyLc2rX46fr3gPfAf5D3TH7a8c8pez6dus61MquB64uu24l\nXIt7gOd38r0waC0UAP+l1myzFBGviIiRsitUpIg4gyxxfnS9LGXviEXgorLqVaKxWnP3SkS8JyIe\nW3aFylRTndg+AAAEBUlEQVRbvn4vje+PbwE3MpzvDxjODQjPIvskvlp7PEG2bED9++LLwNcY7PfF\nxuuw7oqIuCsivhARb9jQgjFQIuK0iPhl4OHAp+nge6HXNgfL661k612sAk8D/hvZH9NXlFmpgp0D\njAB3bii/kyxlDpMbgF8BvkzWbPla4JMR8eMppW+XWK8y7SX7A7rZ+2Nv96tTunY2IOxrERFkTdr/\nL6W0PqZoL/DdWrisN7Dviy2uA2R7QX2VrCXvScAbgXHg8q5XskAR8eNkAeJhwL+QtUjcGhEX0KH3\nQs8HilY2JEsp/UFd+c0R8T3gbRHx28nFsQZeSql+RdWbI+IzZH8ofpGd96DREEgpvb/u4Rcj4gtk\nGxA+k633C+p31wI/xnCMJ9rO+nX4qfrClNI76h5+MSKOAYsRcW5K6bZuVrBgtwLnky0meTnw7oj4\nmU5+g54PFOTbkOxGsp/x8cDRDtapl9xNtsrong3le4Bj3a9O70gpnYiIKtnmcsPqGNmYmj00tlLs\nAW4qpUY9JDW3AWHfiog/Bg4BT08p3VH31DHgIRHxyA2fTAfy78aG67Dl3lA1N5L9zoyStWQNhJTS\nA5y8V94UEU8BXg68nw69F3p+DEVK6Z5a68N2Xw9scfoFZANLvtHF
KndVreXlCHDJelmtae8SOrTh\nS7+KiEeQNWvv9AdkYNU+YR2j8f3xSLJR70P9/oDmNiDsV7Wb6HOAZ6WUvrbh6SPAAzS+L/YDP0rW\nLD4wdrgOm7mArOV74N4TG5wGPJQOvhf6oYWiKRHxk2R/JK8n6x96GnA18L9SSifKrFsXXA28KyKO\nAJ8B5sgG3LyrzEp1W0S8Cfi/ZN0cPwK8juwXpVJmvYoW2YZ7o2SfqgCeENnGfKsppX8k6zd+dUQs\nA7cDv0c2C+iDJVS3UNtdi9rXa8jGUByrHff7DOAGhLV1FGaAy4BvR8R6C+aJlNL9KaVvRcT/AK6O\niONkfzP/EPj7lNJnyql15+10HSLiCcAscB3ZrIfzyf6efiKldHMZdS5CRLyBbPzQ14B/A1wBPAN4\ndkffC2VPXengFJgLyNLUKtk8+5vJ5hKfUXbduvTzv5TsZnFf7TpcWHadSrgGFbIb5X21X5z3AeeW\nXa8u/NzPIGuJW9vw9T/rjnkt2aCze8lunqNl17vb14JsMNo8WZi4n6z590+AHyq73gVch82uwRrw\nvLpjHkq2RsPdtZvIXwI/XHbdu3kdgMcAHwfuqv1ufJlsoO4jyq57h6/DO2rv9/tq7/8PAxd3+r3g\n5mCSJCm3nh9DIUmSep+BQpIk5WagkCRJuRkoJElSbgYKSZKUm4FCkiTlZqCQJEm5GSgkSVJuBgpJ\nkpSbgUKSJOVmoJAkSbn9f7rDhhUjNFAgAAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f947bd82c18>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Seed the RNGs so training and the plot below are reproducible.\n",
"np.random.seed(0)\n",
"torch.manual_seed(0)\n",
"\n",
"BATCH, HIDDEN, STEPS = 5, 20, 50\n",
"\n",
"def to_time_major(batch_major):\n",
"    # (batch, seq_len) -> (seq_len, batch, 1), the layout nn.RNN expects.\n",
"    # Transpose first: a bare .view(seq_len, batch, 1) of a batch-major tensor\n",
"    # interleaves samples from different sine segments into each sequence.\n",
"    return batch_major.t().contiguous().view(batch_major.size(1), batch_major.size(0), 1)\n",
"\n",
"def predict(model, inputs, hidden):\n",
"    # IntegerPredictor flattens its time-major output into a (batch, seq_len)\n",
"    # shape without transposing; undo that so row b is the prediction for sequence b.\n",
"    seq_len, batch = inputs.size(0), inputs.size(1)\n",
"    return model(inputs, hidden).view(seq_len, batch).t().contiguous()\n",
"\n",
"m = IntegerPredictor(1, HIDDEN, 1)\n",
"loss_fn = nn.MSELoss()\n",
"optimizer = optim.Adam(m.parameters(), lr = 0.01)\n",
"for i in range(STEPS):\n",
"    x_batch, y_batch = batch_data(BATCH)\n",
"    x = Variable(to_time_major(x_batch))\n",
"    hidden = Variable(torch.zeros(1, BATCH, HIDDEN))\n",
"    optimizer.zero_grad()\n",
"    p = predict(m, x, hidden)\n",
"    loss = loss_fn(p, Variable(y_batch))\n",
"    if i % 10 == 0:\n",
"        print(loss.data[0])\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"\n",
"m.eval()\n",
"a, b = batch_data(BATCH)\n",
"h = Variable(torch.zeros(1, BATCH, HIDDEN))\n",
"p = predict(m, Variable(to_time_major(a)), h)\n",
"steps = np.arange(a.size(1))\n",
"plt.scatter(steps, b[0].numpy(), color='b', label='target')\n",
"plt.scatter(steps, p[0].data.numpy(), color='r', label='prediction')\n",
"plt.title('RNN sine prediction (first sequence of the batch)')\n",
"plt.xlabel('time step')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"**Seq2Seq model that reverses strings; handles minibatches of variable-length sequences**\n",
"\n",
"E.g. for the input string `abc` the target is its reverse, `cba`."
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" 13 9 7 16 26\n",
" 21 22 24 26 26\n",
" 3 25 19 26 26\n",
" 6 1 21 26 26\n",
" 9 4 26 26 26\n",
"[torch.LongTensor of size 5x5]\n",
" \n",
" 16 7 9 26\n",
" 24 22 26 26\n",
" 19 25 26 26\n",
" 21 1 26 26\n",
" 4 26 26 26\n",
"[torch.LongTensor of size 5x4]\n",
" \n",
" 7 9 13 26\n",
" 22 21 26 26\n",
" 25 3 26 26\n",
" 1 6 26 26\n",
" 9 26 26 26\n",
"[torch.LongTensor of size 5x4]\n",
"\n"
]
}
],
"source": [
"# Build a minibatch for the string-reversal task.\n",
"# Each example is a random lowercase string `a` of length l in [2, max_len],\n",
"# encoded as letter indices 0..25 ('a'..'z') and right-padded with pad_id:\n",
"#   x   -- encoder input `a`                                  (size, max_len)\n",
"#   y_i -- decoder input: reversed `a` minus its last char,   (size, max_len - 1)\n",
"#   y_t -- decoder target: reversed `a` minus its first char, (size, max_len - 1)\n",
"# NOTE(review): there is no start token, so the first character of the reversed\n",
"# string is given to the decoder rather than predicted.\n",
"# Rows are sorted by decreasing length (required by pack_padded_sequence).\n",
"# Returns (x, y_i, y_t, lengths) with lengths a numpy array of encoder lengths.\n",
"def batch_data2(size, max_len, pad_id):\n",
"    if max_len < 2:\n",
"        raise ValueError('max_len must be at least 2, got %d' % max_len)\n",
"\n",
"    def encode(chars, width):\n",
"        # Letter indices for `chars`, padded with pad_id up to `width`.\n",
"        ids = [ord(c) - ord('a') for c in chars] + [pad_id] * (width - len(chars))\n",
"        return torch.from_numpy(np.array(ids))\n",
"\n",
"    lengths = []\n",
"    x = torch.LongTensor(size, max_len)\n",
"    y_i = torch.LongTensor(size, max_len - 1)\n",
"    y_t = torch.LongTensor(size, max_len - 1)\n",
"    for row in range(size):\n",
"        l = np.random.choice(np.arange(2, max_len + 1), 1)[0]\n",
"        a = ''.join(random.choice(string.ascii_lowercase) for _ in range(l))\n",
"        b = a[::-1]\n",
"        x[row] = encode(a, max_len)\n",
"        y_i[row] = encode(b[:-1], max_len - 1)\n",
"        y_t[row] = encode(b[1:], max_len - 1)\n",
"        lengths.append(l)\n",
"    order = np.argsort(-np.array(lengths))\n",
"    idx = torch.LongTensor(order)\n",
"    return x[idx], y_i[idx], y_t[idx], np.array(lengths)[order]\n",
"\n",
"x, y, z, l = batch_data2(5, 5, 26)\n",
"print(x, y, z)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Printed above: encoder inputs, decoder inputs and decoder targets (letters are encoded as 0-25; 26 is the padding id)."
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Encoder-decoder (seq2seq) LSTM for the string-reversal task.\n",
"# Tokens are indices 0..vocab_size-1 plus a padding index equal to vocab_size\n",
"# (26 here, matching the pad_id passed to batch_data2). The defaults reproduce\n",
"# the original hard-coded sizes.\n",
"class Model(nn.Module):\n",
"\n",
"    def __init__(self, vocab_size=26, emb_dim=100, hidden_dim=100):\n",
"        super(Model, self).__init__()\n",
"        self.vocab_size = vocab_size\n",
"        self.hidden_dim = hidden_dim\n",
"        # One extra row for the padding token, registered as padding_idx.\n",
"        self.enc_embed = nn.Embedding(vocab_size + 1, emb_dim, vocab_size)\n",
"        self.dec_embed = nn.Embedding(vocab_size + 1, emb_dim, vocab_size)\n",
"        self.encoder = nn.LSTM(emb_dim, hidden_dim, 1, batch_first=True)\n",
"        self.decoder = nn.LSTM(emb_dim, hidden_dim, 1, batch_first=True)\n",
"        # Output covers real characters only; padding is never predicted.\n",
"        self.dec_out = nn.Linear(hidden_dim, vocab_size)\n",
"        self.init_weights()\n",
"\n",
"    def init_weights(self):\n",
"        # Zero the output bias; embeddings keep their default initialisation.\n",
"        self.dec_out.bias.data.fill_(0)\n",
"\n",
"    def forward(self, enc_inputs, dec_inputs, lengths):\n",
"        # enc_inputs: (batch, max_len) indices; dec_inputs: (batch, max_len - 1).\n",
"        # lengths: encoder lengths in decreasing order, as a numpy array (the code\n",
"        # uses `lengths - 1` for the decoder). Returns a PackedSequence of per-step\n",
"        # softmax probabilities over vocab_size characters.\n",
"        batch = enc_inputs.size(0)\n",
"        enc_emb = pack_padded_sequence(self.enc_embed(enc_inputs), lengths, batch_first=True)\n",
"        dec_emb = pack_padded_sequence(self.dec_embed(dec_inputs), lengths - 1, batch_first=True)\n",
"\n",
"        # Zero initial encoder state, sized from the actual batch instead of a fixed 5.\n",
"        h0 = Variable(torch.zeros(1, batch, self.hidden_dim))\n",
"        c0 = Variable(torch.zeros(1, batch, self.hidden_dim))\n",
"        _, (ht, ct) = self.encoder(enc_emb, (h0, c0))\n",
"        # Seed the decoder with the final state of the (single) encoder layer.\n",
"        ht, ct = ht[-1], ct[-1]\n",
"        od, _ = self.decoder(dec_emb, (ht.view(1, ht.size(0), ht.size(1)), ct.view(1, ct.size(0), ct.size(1))))\n",
"        od = pad_packed_sequence(od)[0].transpose(0, 1)  # (batch, max_dec_len, hidden)\n",
"        od = od.contiguous().view(od.size(0) * od.size(1), od.size(2))\n",
"        logits = self.dec_out(od)\n",
"        probs = F.softmax(logits).view(batch, -1, self.vocab_size)\n",
"        return pack_padded_sequence(probs, lengths - 1, batch_first=True)"
]
},
{
"cell_type": "code",
"execution_count": 135,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.277252197265625\n",
"0.9902730584144592\n",
"0.9568517208099365\n",
"0.6452716588973999\n",
"0.2598739266395569\n",
"0.6236557960510254\n",
"0.2031468003988266\n",
"0.17586927115917206\n",
"[19 17 9 22 19 19 25 6 18 13 19 4 24 14 1 13 2 18 6 19 16 9 25 22 13\n",
" 12 25 15 18]\n",
"[ 5 17 9 22 19 19 25 6 18 13 19 4 24 14 1 13 2 18 6 19 16 9 25 22 13\n",
" 12 14 15 18]\n"
]
}
],
"source": [
"# Seed every RNG used here (batch_data2 draws from numpy and random).\n",
"np.random.seed(0)\n",
"random.seed(0)\n",
"torch.manual_seed(0)\n",
"\n",
"BATCH, MAX_LEN, PAD_ID, STEPS = 5, 10, 26, 4000\n",
"\n",
"m = Model()\n",
"criterion = nn.NLLLoss()\n",
"optimizer = optim.Adam(m.parameters(), lr=1e-3)\n",
"for i in range(STEPS):\n",
"    enc_in, dec_in_t, dec_tar_t, lengths = batch_data2(BATCH, MAX_LEN, PAD_ID)\n",
"    x = Variable(enc_in)\n",
"    dec_in = Variable(dec_in_t)\n",
"    optimizer.zero_grad()\n",
"    p = m(x, dec_in, lengths)\n",
"    # The model returns probabilities; clamp before the log so an underflowed 0\n",
"    # cannot turn into -inf/NaN in the NLL loss.\n",
"    p_log = torch.log(p.data.clamp(min=1e-8))\n",
"    # Pack the targets with the same lengths so they line up with p.data row by row.\n",
"    dec_tar = pack_padded_sequence(Variable(dec_tar_t).view(BATCH, -1, 1), lengths - 1, batch_first=True).data.view(-1)\n",
"    loss = criterion(p_log, dec_tar)\n",
"    if i % 500 == 0:\n",
"        print(loss.data[0])\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    if i == STEPS - 1:\n",
"        print(p_log.data.numpy().argmax(-1))\n",
"        print(dec_tar.data.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Printed above: the loss every 500 steps (generally decreasing), then the predicted characters and the decoder targets for the final batch."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment