Skip to content

Instantly share code, notes, and snippets.

@tims457
Created May 24, 2020 20:42
Show Gist options
  • Save tims457/60d1b1688dfab3e02cbd9bbd83b49a66 to your computer and use it in GitHub Desktop.
Save tims457/60d1b1688dfab3e02cbd9bbd83b49a66 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np\n",
"import torch.nn.functional as F\n",
"import matplotlib.pyplot as plt\n",
"from torch.utils.data import DataLoader, Dataset\n"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"class RNNCell(nn.Module):\n",
"    \"\"\"Hand-written single-layer RNN cell that keeps its hidden state between calls.\n",
"\n",
"    Each forward() call advances the stored state by one step:\n",
"        h      = tanh(bh + Wx @ x + Wh @ h)\n",
"        output = softmax(by + Wy @ h), taken over dim 0 (the output units)\n",
"\n",
"    Args:\n",
"        inputSize: number of input features.\n",
"        hiddenSize: number of hidden units.\n",
"        outputSize: number of output units.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, inputSize, hiddenSize, outputSize):\n",
"        super(RNNCell, self).__init__()\n",
"        # nn.Parameter registers the tensors with the module, so they show up in\n",
"        # self.parameters() and an optimizer can actually update them.\n",
"        self.Wx = nn.Parameter(torch.randn(hiddenSize, inputSize))   # input weights\n",
"        self.Wh = nn.Parameter(torch.randn(hiddenSize, hiddenSize))  # hidden weights\n",
"        self.Wy = nn.Parameter(torch.randn(outputSize, hiddenSize))  # output weights\n",
"        self.bh = nn.Parameter(torch.zeros(hiddenSize, 1))           # hidden state bias\n",
"        self.by = nn.Parameter(torch.zeros(outputSize, 1))           # output bias\n",
"        # Hidden state carried from one forward() call to the next (not trainable).\n",
"        # NOTE(review): when training step by step, detach it between backward()\n",
"        # calls (self.h = self.h.detach()) so autograd does not revisit old graphs.\n",
"        self.h = torch.zeros(hiddenSize, 1)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Advance one time step and return (output, h).\n",
"\n",
"        Args:\n",
"            x: input column(s); presumably shape (inputSize, 1) -- TODO confirm.\n",
"\n",
"        Returns:\n",
"            output: softmax over the output units (dim 0).\n",
"            h: the updated hidden state, which is also stored on self.h.\n",
"        \"\"\"\n",
"        self.h = torch.tanh(self.bh + torch.matmul(self.Wx, x) + torch.matmul(self.Wh, self.h))\n",
"        # torch.softmax applies the function; nn.Softmax(...) would only construct a module.\n",
"        output = torch.softmax(self.by + torch.matmul(self.Wy, self.h), dim=0)\n",
"\n",
"        return output, self.h"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 0, 'x')"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Sample sin(x) on 100000 evenly spaced points over [0, 100].\n",
"xGrid = torch.linspace(0, 100, 100000)\n",
"X = torch.sin(xGrid)\n",
"# Plot against the grid itself (not the sample index) so the 'x' axis label is accurate.\n",
"plt.plot(xGrid, X)\n",
"plt.ylabel('Sin x')\n",
"plt.xlabel('x')"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"\n",
"class RNNData(Dataset):\n",
"    \"\"\"Fixed-length windows of a 1-D series, each paired with a target value.\n",
"\n",
"    Sample i is (X[s : s + sequenceLength], X[s + sequenceLength + 1]) with\n",
"    s = i * stride. The target skips the element X[s + sequenceLength].\n",
"    NOTE(review): confirm that skipping one element is intended; the later\n",
"    sanity-check cell prints X[sequenceLength + 1] as the target.\n",
"\n",
"    Args:\n",
"        X: 1-D indexable series (e.g. a tensor).\n",
"        sequenceLength: number of elements per input window.\n",
"        stride: offset between consecutive window starts. The default of 1\n",
"            keeps the original sampling: at most len(X) // sequenceLength\n",
"            windows, all taken from the start of the series. Pass\n",
"            stride=sequenceLength to spread non-overlapping windows over\n",
"            the whole series.\n",
"    \"\"\"\n",
"    def __init__(self, X, sequenceLength, stride=1):\n",
"        if sequenceLength < 1 or stride < 1:\n",
"            raise ValueError('sequenceLength and stride must be positive')\n",
"        self.X = X\n",
"        self.sequenceLength = sequenceLength\n",
"        self.stride = stride\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Number of windows whose target index still lies inside X.\"\"\"\n",
"        # A start s is valid while s + sequenceLength + 1 <= len(X) - 1.\n",
"        numValid = (len(self.X) - self.sequenceLength - 2) // self.stride + 1\n",
"        # Never report more than len(X) // sequenceLength samples (the original count).\n",
"        return max(0, min(len(self.X) // self.sequenceLength, numValid))\n",
"\n",
"    def __getitem__(self, index):\n",
"        \"\"\"Return (window, target) for sample `index`.\"\"\"\n",
"        if not 0 <= index < len(self):\n",
"            raise IndexError(index)\n",
"        start = index * self.stride\n",
"        sequence = self.X[start:start + self.sequenceLength]\n",
"        y = self.X[start + self.sequenceLength + 1]\n",
"        return sequence, y"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"# hyperparameters\n",
"# -- data / batching --\n",
"sequenceLength = 50\n",
"batchSize = 100\n",
"# -- network --\n",
"hiddenSize = 4\n",
"numLayers = 1\n",
"# -- optimisation --\n",
"learningRate = 0.01\n",
"epochs = 100"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[0.1977, 0.1987, 0.1997, ..., 0.2435, 0.2445, 0.2455],\n",
" [0.4059, 0.4069, 0.4078, ..., 0.4484, 0.4493, 0.4502],\n",
" [0.8912, 0.8917, 0.8921, ..., 0.9115, 0.9119, 0.9124],\n",
" ...,\n",
" [0.8117, 0.8123, 0.8128, ..., 0.8382, 0.8388, 0.8393],\n",
" [0.9963, 0.9964, 0.9965, ..., 0.9992, 0.9993, 0.9993],\n",
" [0.3977, 0.3986, 0.3995, ..., 0.4404, 0.4413, 0.4422]])\n",
"tensor([0.2474, 0.4520, 0.9132, 0.9972, 0.4132, 0.9986, 0.2580, 0.6272, 0.7695,\n",
" 0.9936, 0.9777, 0.9785, 0.9365, 0.8152, 0.9614, 0.7420, 0.9471, 0.9474,\n",
" 0.9379, 0.9659, 0.5778, 0.9697, 0.8966, 0.9992, 0.9080, 0.7901, 0.8930,\n",
" 0.3485, 0.8537, 0.3012, 0.9874, 0.8338, 0.8894, 0.5819, 0.6020, 0.9029,\n",
" 0.9625, 0.8928, 0.9428, 0.9291, 0.9977, 0.8644, 0.3728, 0.6233, 0.1018,\n",
" 0.5346, 0.9990, 0.5505, 0.9744, 0.9530, 0.9150, 0.7413, 0.9313, 0.5932,\n",
" 0.9987, 0.9499, 0.0580, 0.9286, 0.1425, 0.9862, 0.9980, 0.9687, 0.2026,\n",
" 0.9423, 0.7004, 0.6846, 0.9855, 0.8238, 0.7033, 0.2812, 0.2974, 0.8624,\n",
" 0.8140, 0.8669, 0.6713, 0.9928, 0.9338, 0.9842, 0.5996, 0.9275, 0.9999,\n",
" 0.9997, 0.9934, 0.9819, 0.1237, 0.9384, 0.5972, 0.9973, 0.1455, 0.2888,\n",
" 0.9461, 0.9804, 0.9983, 0.9956, 0.8076, 0.1207, 0.4368, 0.8404, 0.9994,\n",
" 0.4440])\n"
]
}
],
"source": [
"data = RNNData(X, sequenceLength)\n",
"dataLoader = DataLoader(data, batch_size=batchSize, shuffle=True)\n",
"# Pull a single shuffled batch to eyeball the windows and their targets.\n",
"x, y = next(iter(dataLoader))\n",
"print(x)\n",
"print(y)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"# create our RNN based network with an RNN followed by a linear layer\n",
"class RNN(nn.Module):\n",
"    \"\"\"nn.RNN encoder followed by a linear read-out of the last time step.\n",
"\n",
"    Args:\n",
"        inputSize: features per time step.\n",
"        hiddenSize: number of RNN hidden units.\n",
"        numLayers: number of stacked RNN layers.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, inputSize, hiddenSize, numLayers):\n",
"        super().__init__()\n",
"        self.RNN = nn.RNN(input_size=inputSize,\n",
"                          hidden_size=hiddenSize,\n",
"                          num_layers=numLayers,\n",
"                          nonlinearity='tanh',\n",
"                          batch_first=True)  # inputs and outputs are (batch, seq, feature)\n",
"        self.linear = nn.Linear(hiddenSize, 1)\n",
"\n",
"    def forward(self, x, hState=None):\n",
"        \"\"\"Return one prediction per sequence.\n",
"\n",
"        Args:\n",
"            x: (batch, seq, inputSize) tensor, since batch_first=True.\n",
"            hState: initial hidden state of shape (numLayers, batch, hiddenSize),\n",
"                or None to let nn.RNN start from zeros for whatever batch size x has.\n",
"\n",
"        Returns:\n",
"            (batch, 1) tensor computed from the last time step's RNN output.\n",
"        \"\"\"\n",
"        x, _ = self.RNN(x, hState)  # final hidden state is not needed here\n",
"        out = self.linear(x[:, -1, :])  # gets last output\n",
"        return out"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"# create our network instance, pick loss function and optimizer\n",
"model = RNN(inputSize=1, hiddenSize=hiddenSize, numLayers=numLayers)  # one feature per time step\n",
"lossFn = nn.MSELoss()  # mean-squared-error regression loss\n",
"optimizer = torch.optim.Adam(params=model.parameters(), lr=learningRate)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([100, 1])"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# check output to see if everything is setup correctly\n",
"dummyInput = torch.randn(batchSize, sequenceLength, 1)\n",
"dummyHidden = torch.zeros([numLayers, batchSize, hiddenSize])\n",
"ytest = model(dummyInput, dummyHidden)\n",
"ytest.shape"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"11.319549560546875\n",
"1.7751704454421997\n",
"1.4014699459075928\n",
"1.1168690919876099\n",
"0.5675828456878662\n",
"0.05716589093208313\n",
"0.010280625894665718\n",
"0.004087743349373341\n",
"0.0029274208936840296\n",
"0.0025021100882440805\n",
"0.002253422513604164\n",
"0.002016317332163453\n",
"0.0018425789894536138\n",
"0.0017580073326826096\n",
"0.0015311933821067214\n",
"0.0013432155828922987\n",
"0.001226308522745967\n",
"0.0011239354498684406\n",
"0.0010009920224547386\n",
"0.0009049549116753042\n",
"0.0008012366015464067\n",
"0.0007353670080192387\n",
"0.0006402007420547307\n",
"0.0005788293201476336\n",
"0.0005184737383387983\n",
"0.0004636308876797557\n",
"0.00041253535891883075\n",
"0.00036798190558329225\n",
"0.00034769228659570217\n",
"0.0002972030488308519\n",
"0.0002639765734784305\n",
"0.00022841084864921868\n",
"0.00021119693701621145\n",
"0.00018246164836455137\n",
"0.00017494821804575622\n",
"0.0001459053164580837\n",
"0.00012809678446501493\n",
"0.00011152001388836652\n",
"0.00010181981633650139\n",
"9.421376307727769e-05\n",
"8.516746311215684e-05\n",
"9.344099089503288e-05\n",
"7.571556488983333e-05\n",
"6.939681770745665e-05\n",
"6.255778134800494e-05\n",
"6.651831790804863e-05\n",
"6.726705760229379e-05\n",
"5.616387352347374e-05\n",
"5.658273585140705e-05\n",
"6.0519552789628506e-05\n",
"5.55058904865291e-05\n",
"5.864330523763783e-05\n",
"5.548787521547638e-05\n",
"5.9752503148047253e-05\n",
"4.80966227769386e-05\n",
"5.113997031003237e-05\n",
"7.44025019230321e-05\n",
"5.991280704620294e-05\n",
"4.653463838621974e-05\n",
"5.885814243811183e-05\n",
"4.6151482820278034e-05\n",
"4.450901906238869e-05\n",
"4.508296478888951e-05\n",
"4.469474151846953e-05\n",
"4.367278961581178e-05\n",
"4.332353637437336e-05\n",
"4.1468832932878286e-05\n",
"4.0118800825439394e-05\n",
"4.7187353629851714e-05\n",
"4.421977064339444e-05\n",
"4.5634435082320124e-05\n",
"6.0528047470143065e-05\n",
"5.609228537650779e-05\n",
"4.7935765906004235e-05\n",
"4.159270611125976e-05\n",
"4.1005550883710384e-05\n",
"4.603252455126494e-05\n",
"4.5191103708930314e-05\n",
"4.270176577847451e-05\n",
"4.48863283963874e-05\n",
"4.341293970355764e-05\n",
"4.052995427628048e-05\n",
"3.95571296394337e-05\n",
"4.4394877477316186e-05\n",
"4.054577220813371e-05\n",
"3.578514952096157e-05\n",
"3.6663273931480944e-05\n",
"5.189451258047484e-05\n",
"3.78264558094088e-05\n",
"3.903014658135362e-05\n",
"4.4174295908305794e-05\n",
"3.90354725823272e-05\n",
"4.6599816414527595e-05\n",
"4.038702536490746e-05\n",
"4.039769555674866e-05\n",
"3.597330942284316e-05\n",
"5.0886817916762084e-05\n",
"3.824286250164732e-05\n",
"4.622007327270694e-05\n",
"3.545684376149438e-05\n"
]
},
{
"data": {
"text/plain": [
"Text(0, 0.5, 'Loss')"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAEWCAYAAABhffzLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAWXElEQVR4nO3de7BlZX3m8e+zL1zlTqNAAw3xFkW8TKuoGSuFmqBRMTEVcdRBxwkzqZnROI4ZramUc82lykmMFcupjqBmdHRG1IiWZSR4LxO0QSRAq6By04Y+KAhegNOnf/PHXuf0Pqf7NMfu3nvTa30/VbvO3muvs9939dr99Nu/9a61UlVIkrqjN+sOSJKmy+CXpI4x+CWpYwx+SeoYg1+SOsbgl6SOMfglqWMMfmlMkpuSPHfW/ZAmyeCXpI4x+KU1SPK7SW5M8qMklyY5qVmeJH+eZFuSHye5JsmZzXsvSHJ9knuTfD/Jf5jtVkgjBr/0IJKcA/wx8DvAicDNwIeat38NeDbwaOBo4GXAD5v3LgL+VVUdAZwJfHaK3ZZWNZh1B6QDwCuAi6vqKoAkbwHuSrIBmAeOAB4LfLWqtoz93jzwuCTfqKq7gLum2mtpFY74pQd3EqNRPgBV9RNGo/qTq+qzwF8C7wTuSLIpyZHNqi8FXgDcnOQLSZ4x5X5Lu2XwSw/uB8Bpiy+SHA4cB3wfoKreUVX/BHg8o5LPm5rlX6uq84ATgL8B/t+U+y3tlsEv7WqY5JDFB6PAfk2SJyU5GPgj4IqquinJU5M8PckQ+ClwH7CQ5KAkr0hyVFXNA/cACzPbImmMwS/t6lPAz8ce/xT4Q+AjwFbgl4Dzm3WPBP6KUf3+ZkYloLc1770KuCnJPcC/Bl45pf5LexRvxCJJ3eKIX5I6xuCXpI4x+CWpYwx+SeqYA+LM3eOPP742bNgw625I0gHlyiuvvLOq1q1cfkAE/4YNG9i8efOsuyFJB5QkN+9uuaUeSeoYg1+SOsbgl6SOMfglqWMMfknqGINfkjrG4Jekjml18H/s67fxgSt2O41Vkjqr1cH/iW9s5UNfvXXW3ZCkh5RWB/+gF7bv8H4DkjSu3cHfD9sXdsy6G5L0kNLu4O/1HPFL0grtDv5+mHfEL0nLtDv4e2HBEb8kLdPu4O/3mF8w+CVpXKuDf9gL23dY6pGkca0O/kG/x4Ijfklapt3B3wvzjvglaZl2B38/bHfEL0nLtDv4m3n8VYa/JC1qefAHwCmdkjSm3cHfH22eZ+9K0k6tDv5hfzTi9+xdSdqp1cG/WOrxAK8k7dTq4O9b6pGkXbQ6+IeLI37n8kvSklYH/9LBXUs9krSk3cG/NOI3+CVp0cSCP8nFSbYluXZs2bFJLktyQ/PzmEm1D6MzdwHvwiVJYyY54n8vcO6KZW8GLq+qRwGXN68nZtAbbZ6XZpaknSYW/FX1ReBHKxafB7yvef4+4CWTah92zuP34K4k7TTtGv/Dq2orQPPzhNVWTHJhks1JNs/Nze1VY31r/JK0i4fswd2q2lRVG6tq47p16/bqM4bO6pGkXUw7+O9IciJA83PbJBvbeeaupR5JWjTt4L8UuKB5fgHw8Uk25kXaJGlXk5zO+UHg74HHJLktyWuBPwGel+QG4HnN64kZeOauJO1iMKkPrqqXr/LWcybV5kqDpatzOuKXpEUP2YO7+4MHdyVpV60O/r6lHknaRauDf9hzxC9JK7U6+AeeuStJu+hE8HtwV5J2anfwN6WeBefxS9KSdge/N1uXpF20OviXDu464pekJa0O/sXpnJZ6JGmnVgf/0FKPJO2i1cGfhH4vzuOXpDGtDn4YXaht3nn8krSkE8G/4Ihfkpa0P/j7PWf1SNKY1gf/sB8P7krSmNYHf78Xp3NK0pjWB/+g1/NaPZI0pvXBP+zHq3NK0pjWB/+g33MevySNaX/w9xzxS9K49gd/3zN3JWlc+4
O/12PeWT2StKQDwR+2O49fkpa0P/j78cxdSRrT+uAf9nuO+CVpTOuDfzSrxxG/JC2aSfAneUOS65Jcm+SDSQ6ZVFv9nvP4JWnc1IM/ycnA64CNVXUm0AfOn1R7nrkrScvNqtQzAA5NMgAOA34wsYY8c1eSlpl68FfV94G3AbcAW4EfV9VnVq6X5MIkm5Nsnpub2+v2ht6BS5KWmUWp5xjgPOB04CTg8CSvXLleVW2qqo1VtXHdunV73V7fO3BJ0jKzKPU8F/heVc1V1TzwUeCZk2ps0PfMXUkaN4vgvwU4O8lhSQI8B9gyqcaGfc/claRxs6jxXwFcAlwF/GPTh02Taq/f8yJtkjRuMItGq+qtwFun0dbQm61L0jIdOXPXUo8kLWp/8PdH99ytctQvSdCF4O8FAKs9kjTS/uDvj4J/3pk9kgR0IPiHvdEmeoBXkkZaH/z9ptTjXH5JGml98A+bUo8jfkkaaX3wD/pNqceTuCQJ6ELw9zy4K0nj2h/8TalnwVKPJAFdCP6lWT2O+CUJOhD8w6V5/I74JQk6EPz9ngd3JWlc64N/sDSd01KPJEEHgt8zdyVpudYHv9fqkaTl2h/8S5dscMQvSdCF4G/O3HUevySNtD/4PXNXkpZpf/B7kTZJWqb9we+sHklapvXBv3RZZks9kgR0IPi9LLMkLdf+4F88uOuZu5IEdCj4nc4pSSPtD/6m1OPVOSVpZCbBn+ToJJck+WaSLUmeMam2Bt5sXZKWWVPwJ/mlJAc3z381yeuSHL0P7f4F8OmqeizwRGDLPnzWHjmPX5KWW+uI/yPAQpJHAhcBpwP/Z28aTHIk8Ozmc6iqB6rq7r35rLUYej1+SVpmrcG/o6q2A78JvL2q3gCcuJdtngHMAe9J8vUk705y+MqVklyYZHOSzXNzc3vZFPR6oRevxy9Ji9Ya/PNJXg5cAHyyWTbcyzYHwFOAd1XVk4GfAm9euVJVbaqqjVW1cd26dXvZVNNgr+fBXUlqrDX4XwM8A/gfVfW9JKcD79/LNm8DbquqK5rXlzD6h2BiBv2w4IhfkoDR6PtBVdX1wOsAkhwDHFFVf7I3DVbV7UluTfKYqvoW8Bzg+r35rLUa9OKIX5Iaawr+JJ8HXtysfzUwl+QLVfXv97Ldfwd8IMlBwHcZ/Y9iYob9njV+SWqsKfiBo6rqniT/EnhPVb01yTV722hVXQ1s3Nvf/0X1e3FWjyQ11lrjHyQ5Efgddh7cPWCMRvwGvyTB2oP/vwJ/C3ynqr6W5Azghsl1a/8a9OOZu5LUWOvB3Q8DHx57/V3gpZPq1P7W74V5R/ySBKz9kg3rk3wsybYkdyT5SJL1k+7c/jLs9Viwxi9JwNpLPe8BLgVOAk4GPtEsOyAM+nFWjyQ11hr866rqPVW1vXm8F9i302mnaND3zF1JWrTW4L8zySuT9JvHK4EfTrJj+9Og54hfkhatNfj/BaOpnLcDW4HfZsInXe1PA+fxS9KSNQV/Vd1SVS+uqnVVdUJVvQT4rQn3bb9xHr8k7bQvd+Da28s1TN3ozF1LPZIE+xb82W+9mLBhP474JamxL8F/wCTpoNezxi9JjT2euZvkXnYf8AEOnUiPJmDQD/PO6pEk4EGCv6qOmFZHJslZPZK0076Ueg4Yg36PBWv8kgR0JPiH/TDvrB5JAjoS/P2es3okaVEngn/Q6znil6RGJ4J/2I81fklqdCL4B33n8UvSom4Ef895/JK0qCPB36MKdljukaSOBH9/dFkhR/2S1JXg742C3zq/JHUl+PujzTT4JakjwT9sSj3eflGSZhj8zb17v57kk5Nua9BrRvwe3JWkmY74Xw9smUZDizV+z96VpBkFf5L1wG8A755Ge4uzeqzxS9LsRvxvB/4AWHUInuTCJJuTbJ6bm9unxpYO7lrqkaTpB3+SFwLbqurKPa1XVZuqamNVbVy3bt0+tTnseXBXkhbNYsT/LODFSW4CPgSck+T9k2yw7zx+SVoy9e
CvqrdU1fqq2gCcD3y2ql45yTaHlnokaUkn5vHvPLhrqUeS9niz9Umrqs8Dn590O/2l6ZyO+CWpEyP+naUeR/yS1IngX7pImzV+SepG8A+9SJskLelE8O+czmmpR5I6Efw7r87piF+SOhH8O6/O6YhfkjoR/E7nlKSdOhH8HtyVpJ06EfyLZ+4uWOqRpG4E/7Cp8VvqkaSOBH/fe+5K0pJOBP/Ag7uStKQTwb94cHfBefyS1I3gbwb8nrkrSXQk+JMw7Id5R/yS1I3gh9HZu5Z6JKlLwd8P85Z6JKlDwd+LZ+5KEl0K/n7PefySRIeCf+iIX5KADgV/vx+vxy9JdCj4h72eB3cliQ4F/6Afp3NKEl0K/l7Pa/VIEl0K/n6c1SNJdCn4ndUjScAMgj/JKUk+l2RLkuuSvH4a7TqPX5JGBjNoczvwxqq6KskRwJVJLquq6yfZ6LAf7p83+CVp6iP+qtpaVVc1z+8FtgAnT7rdfq/n1TkliRnX+JNsAJ4MXLGb9y5MsjnJ5rm5uX1ua3TmriN+SZpZ8Cd5GPAR4Per6p6V71fVpqraWFUb161bt8/tOY9fkkZmEvxJhoxC/wNV9dFptDnwzF1JAmYzqyfARcCWqvqzabU76If7PLgrSTMZ8T8LeBVwTpKrm8cLJt3oWeuP5vt3/5wrb75r0k1J0kPaLGb1fLmqUlVnVdWTmsenJt3u+U89haMPG/Kuz39n0k1J0kNaZ87cPfzgARc8YwN/t+UOvn3HvbPujiTNTGeCH+DVz9zAocM+/8tRv6QO61TwH3P4Qbz8aafy8W/8gFt/9LNZd0eSZqJTwQ/wu88+nV7g3V/67qy7Ikkz0bngP/GoQ/mtJ6/ng1+7lZvu/OmsuyNJU9e54Ad44689moP7Pf7w49dS5dm8krqlk8F/wpGH8KZzH8OXbriTT1yzddbdkaSp6mTwA7zi6adx1vqj+G+fvJ4f/3x+1t2RpKnpbPD3e+GPfvMJ/PAn9/M/P/OtWXdHkqams8EPcObJR/Gqs0/j/f9wM9vuvW/W3ZGkqeh08AP8s6efxo6Cv7t+26y7IklT0fngf/TDH8bpxx/Op6+7fdZdkaSp6HzwJ+HXH/8IvnLjnR7kldQJnQ9+gHPPfATbdxSf/eYds+6KJE2cwQ+cdfJRnHjUIXz6Wss9ktrP4Ad6vVG55wvfnuNnD2yfdXckaaIM/savP/4R3De/gy9+e27WXZGkiTL4G0/dcAzHHn6Q5R5JrWfwNwb9Hs/75Ydz+ZZt3L99YdbdkaSJMfjHnPuER3Dv/dv50rfvnHVXJGliDP4xv/LI4znmsCGfuOYHs+6KJE2MwT9m2O/x/CecyGXX38HPH7DcI6mdDP4VXnTWSfzsgQUu92QuSS1l8K/wtNOP5YQjDuYT37DcI6mdDP4V+r3wwrNO4nPfmuOe+7x2j6T2Mfh340VPPJEHtu/gM9dZ7pHUPjMJ/iTnJvlWkhuTvHkWfdiTJ51yNKcce6jlHkmtNPXgT9IH3gk8H3gc8PIkj5t2P/YkCS866yS+fOOd/OdLr+Oy6++w7COpNQYzaPNpwI1V9V2AJB8CzgOun0FfVvXqZ23gm7ffy4e+dgvv/cpNAAx64ZBhn4MHPXq90Av0EtL8TpJlnzH+ctlzlq+36u+sus7u31n9U3+xldb0OXv6/VX6p+UW/5RqP33OtK21334b9s1FFzyVU487bL9+5iyC/2Tg1rHXtwFPX7lSkguBCwFOPfXU6fRszAlHHMLFr34q929f4Ou33M2VN9/FT+/fzn3zO7h/+wI7CqqKhR2jr38BNfY3oVj2YndPd1FjH7DaerXKG2v5S1ir/fIv+DmT/YBuqBV/UHsaDCyuv7t1Vn7OtK2l39o3Bw32f2FmFsG/u2/KLt+OqtoEbALYuHHjzL49Bw/6nH3GcZx9xnGz6oIk7VezOLh7G3DK2Ov1gEdRJWlKZhH8XwMeleT0JAcB5wOXzqAfktRJUy/1VNX2JP
8W+FugD1xcVddNux+S1FWzqPFTVZ8CPjWLtiWp6zxzV5I6xuCXpI4x+CWpYwx+SeqYrOVszllLMgfcvJe/fjzQxZvodnG7u7jN0M3tdpvX5rSqWrdy4QER/Psiyeaq2jjrfkxbF7e7i9sM3dxut3nfWOqRpI4x+CWpY7oQ/Jtm3YEZ6eJ2d3GboZvb7Tbvg9bX+CVJy3VhxC9JGmPwS1LHtDr4H+o3dd8fkpyS5HNJtiS5Lsnrm+XHJrksyQ3Nz2Nm3df9LUk/ydeTfLJ5fXqSK5pt/r/NZb9bJcnRSS5J8s1mnz+j7fs6yRua7/a1ST6Y5JA27uskFyfZluTasWW73bcZeUeTbdckecov0lZrg/9AuKn7frIdeGNV/TJwNvBvmu18M3B5VT0KuLx53TavB7aMvf5T4M+bbb4LeO1MejVZfwF8uqoeCzyR0fa3dl8nORl4HbCxqs5kdCn382nnvn4vcO6KZavt2+cDj2oeFwLv+kUaam3wM3ZT96p6AFi8qXurVNXWqrqqeX4voyA4mdG2vq9Z7X3AS2bTw8lIsh74DeDdzesA5wCXNKu0cZuPBJ4NXARQVQ9U1d20fF8zunz8oUkGwGHAVlq4r6vqi8CPVixebd+eB/x1jfwDcHSSE9faVpuDf3c3dT95Rn2ZiiQbgCcDVwAPr6qtMPrHAThhdj2biLcDfwDsaF4fB9xdVdub123c32cAc8B7mhLXu5McTov3dVV9H3gbcAujwP8xcCXt39eLVtu3+5RvbQ7+Nd3UvS2SPAz4CPD7VXXPrPszSUleCGyrqivHF+9m1bbt7wHwFOBdVfVk4Ke0qKyzO01N+zzgdOAk4HBGZY6V2ravH8w+fd/bHPydual7kiGj0P9AVX20WXzH4n/9mp/bZtW/CXgW8OIkNzEq4Z3D6H8ARzflAGjn/r4NuK2qrmheX8LoH4I27+vnAt+rqrmqmgc+CjyT9u/rRavt233KtzYHfydu6t7Uti8CtlTVn429dSlwQfP8AuDj0+7bpFTVW6pqfVVtYLRfP1tVrwA+B/x2s1qrthmgqm4Hbk3ymGbRc4DrafG+ZlTiOTvJYc13fXGbW72vx6y2by8F/nkzu+ds4MeLJaE1qarWPoAXAN8GvgP8p1n3Z0Lb+CuM/ot3DXB183gBo5r35cANzc9jZ93XCW3/rwKfbJ6fAXwVuBH4MHDwrPs3ge19ErC52d9/AxzT9n0N/Bfgm8C1wP8GDm7jvgY+yOg4xjyjEf1rV9u3jEo972yy7R8ZzXpac1teskGSOqbNpR5J0m4Y/JLUMQa/JHWMwS9JHWPwS1LHGPwSkGQhydVjj/12RmySDeNXXJRmbfDgq0id8POqetKsOyFNgyN+aQ+S3JTkT5N8tXk8sll+WpLLm2uhX57k1Gb5w5N8LMk3msczm4/qJ/mr5rryn0ly6Mw2Sp1n8Esjh64o9bxs7L17quppwF8yuiYQzfO/rqqzgA8A72iWvwP4QlU9kdF1dK5rlj8KeGdVPR64G3jphLdHWpVn7kpAkp9U1cN2s/wm4Jyq+m5zMbzbq+q4JHcCJ1bVfLN8a1Udn2QOWF9V9499xgbgshrdTIMk/xEYVtV/n/yWSbtyxC89uFrl+Wrr7M79Y88X8PiaZsjglx7cy8Z+/n3z/CuMrgwK8Argy83zy4Hfg6V7Ah85rU5Ka+WoQxo5NMnVY68/XVWLUzoPTnIFo4HSy5tlrwMuTvImRnfFek2z/PXApiSvZTSy/z1GV1yUHjKs8Ut70NT4N1bVnbPui7S/WOqRpI5xxC9JHeOIX5I6xuCXpI4x+CWpYwx+SeoYg1+SOub/A6yD2AB/d07aAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# train the model!\n",
"model.train()\n",
"lossHistory = []  # per-epoch sum of batch losses, stored as plain floats\n",
"for epoch in range(epochs):\n",
"    lossTotal = 0.0\n",
"    for x, y in dataLoader:\n",
"        # Size the reshape and initial state from the actual batch, so a smaller\n",
"        # final batch does not break them.\n",
"        currentBatch = x.size(0)\n",
"        hState = torch.zeros([numLayers, currentBatch, hiddenSize])\n",
"        yhat = model(x.reshape([currentBatch, sequenceLength, 1]), hState)\n",
"\n",
"        loss = lossFn(yhat.view(-1), y)\n",
"\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        # .item() yields a plain float, so lossTotal and lossHistory carry no autograd history.\n",
"        lossTotal += loss.item()\n",
"    lossHistory.append(lossTotal)\n",
"    print(lossTotal)\n",
"\n",
"plt.plot(lossHistory)\n",
"plt.title('Loss')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Loss')"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0.0000, 0.0010, 0.0020, 0.0030, 0.0040, 0.0050, 0.0060, 0.0070, 0.0080,\n",
" 0.0090, 0.0100, 0.0110, 0.0120, 0.0130, 0.0140, 0.0150, 0.0160, 0.0170,\n",
" 0.0180, 0.0190, 0.0200, 0.0210, 0.0220, 0.0230, 0.0240, 0.0250, 0.0260,\n",
" 0.0270, 0.0280, 0.0290, 0.0300, 0.0310, 0.0320, 0.0330, 0.0340, 0.0350,\n",
" 0.0360, 0.0370, 0.0380, 0.0390, 0.0400, 0.0410, 0.0420, 0.0430, 0.0440,\n",
" 0.0450, 0.0460, 0.0470, 0.0480, 0.0490])\n",
"tensor(0.0510)\n"
]
}
],
"source": [
"# First window of the series and the value RNNData pairs with it (index sequenceLength + 1).\n",
"firstWindow = X[:sequenceLength]\n",
"firstTarget = X[sequenceLength + 1]\n",
"print(firstWindow)\n",
"print(firstTarget)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.0509]], grad_fn=<AddmmBackward>)"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Predict the target for the first window. no_grad() skips building an autograd graph.\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    prediction = model(X[:sequenceLength].reshape(1, sequenceLength, 1),\n",
"                       torch.zeros([numLayers, 1, hiddenSize]))\n",
"prediction"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment