Skip to content

Instantly share code, notes, and snippets.

@rajshah4
Last active March 4, 2018 01:43
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save rajshah4/aa6c67944f4a43a7c9a1204301788e0c to your computer and use it in GitHub Desktop.
Save rajshah4/aa6c67944f4a43a7c9a1204301788e0c to your computer and use it in GitHub Desktop.
RNN_Addition_1stgrade
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Teaching a computer to add (using memorization)**\n",
"The goal here is to take advantage of Recurrent Neural Networks. For more background, see [my blog post](http://projects.rajivshah.com/blog/2016/04/05/rnn_addition/). This code was partially derived from [yankev/tensorflow_example](https://github.com/yankev/tensorflow_example)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"#Import basic libraries\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"#from tensorflow.models.rnn import rnn_cell\n",
"#from tensorflow.models.rnn import rnn\n",
"#from tensorflow.models.rnn import seq2seq\n",
"from numpy import sum\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"#Defining some hyper-params\n",
"num_units = 50 #number of hidden units in each LSTM cell (the num_units argument of BasicLSTMCell)\n",
"input_size = 1 \n",
"batch_size = 50 \n",
"seq_len = 15\n",
"drop_out = 0.6 "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"#Creates our random sequences\n",
"def gen_data(min_length=5, max_length=15, n_batch=50):\n",
"\n",
" X = np.concatenate([np.random.randint(10,size=(n_batch, max_length, 1))],\n",
" axis=-1)\n",
" y = np.zeros((n_batch,))\n",
" # Compute masks and correct values\n",
" for n in range(n_batch):\n",
" # Randomly choose the sequence length\n",
" length = np.random.randint(min_length, max_length)\n",
" X[n, length:, 0] = 0\n",
" # Sum the digits of sequence n (positions past its length were zeroed above) to get the target value\n",
" y[n] = np.sum(X[n, :, 0]*1)\n",
" return (X,y)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"### Model Construction\n",
"num_layers = 2\n",
"cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)\n",
"cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers)\n",
"cell = tf.nn.rnn_cell.DropoutWrapper(cell,output_keep_prob=drop_out)\n",
"\n",
"#create placeholders for X and y\n",
"inputs = [tf.placeholder(tf.float32,shape=[batch_size,1]) for _ in range(seq_len)]\n",
"result = tf.placeholder(tf.float32, shape=[batch_size])\n",
"initial_state = cell.zero_state(batch_size, tf.float32)\n",
"\n",
"outputs, states = tf.nn.seq2seq.rnn_decoder(inputs, initial_state, cell, scope ='rnnln')\n",
"outputs2 = outputs[-1]\n",
"\n",
"W_o = tf.Variable(tf.random_normal([num_units,input_size], stddev=0.01)) \n",
"b_o = tf.Variable(tf.random_normal([input_size], stddev=0.01))\n",
"\n",
"outputs3 = tf.matmul(outputs2, W_o) + b_o\n",
"\n",
"cost = tf.pow(tf.sub(tf.reshape(outputs3, [-1]), result),2)\n",
"\n",
"train_op = tf.train.RMSPropOptimizer(0.005, 0.2).minimize(cost) \n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"### Generate Validation Data\n",
"tempX,y_val = gen_data(5,seq_len,batch_size)\n",
"X_val = []\n",
"for i in range(seq_len):\n",
" X_val.append(tempX[:,i,:])"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"##Run this cell to see what the inputs look like \n",
"print (tempX[1]) \n",
"print (y_val[1])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"##Session\n",
"sess = tf.Session()\n",
"sess.run(tf.initialize_all_variables())\n",
"train_score =[]\n",
"val_score= []\n",
"x_axis=[]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"num_epochs=1000\n",
" \n",
"for k in range(1,num_epochs):\n",
"\n",
" #Generate Data for each epoch\n",
" tempX,y = gen_data(5,seq_len,batch_size)\n",
" X = []\n",
" for i in range(seq_len):\n",
" X.append(tempX[:,i,:])\n",
"\n",
" #Create the dictionary of inputs to feed into sess.run\n",
" temp_dict = {inputs[i]:X[i] for i in range(seq_len)}\n",
" temp_dict.update({result: y})\n",
"\n",
" _,c_train = sess.run([train_op,cost],feed_dict=temp_dict) #perform an update on the parameters\n",
"\n",
" val_dict = {inputs[i]:X_val[i] for i in range(seq_len)} #create validation dictionary\n",
" val_dict.update({result: y_val})\n",
" c_val = sess.run([cost],feed_dict = val_dict ) #compute the cost on the validation set\n",
" if (k%100==0):\n",
" train_score.append(sum(c_train))\n",
" val_score.append(sum(c_val))\n",
" x_axis.append(k)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Final Train cost: 3086.54125977, on Epoch 999\n",
"Final Validation cost: 2445.63671875, on Epoch 999\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEACAYAAAC+gnFaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XmYVNW19/HvYhQBRcCgzKioYEwYoqBG08HhYmJABQWN\nwBPQKDjeq+aCyY2Ywau+McbcRIwRo+AQGRQVEUGlHRkMgqCI4AAKAkEUEFHpptf7xz5FF031XN2n\nquv3eZ566vSuc6pWkbhX7eHsbe6OiIjkrnpxByAiIvFSIhARyXFKBCIiOU6JQEQkxykRiIjkOCUC\nEZEcV2YiMLP9zGyhmS01s7fMbHxUPt7M1pnZkuhxZtI148xstZmtNLMzksp7m9ny6LU7k8obm9mj\nUfkCM+tUA99TRERKUWYicPevgR+6ew+gB9DfzPoADvzR3XtGj2cAzKw7MAToDvQH7jIzi95uAjDK\n3bsCXc2sf1Q+CtgSld8B3JrerygiImUpt2vI3XdGh42AhoQkAGApTh8IPOLuBe6+BngP6GNmhwLN\n3X1RdN4k4OzoeADwQHQ8HTi1sl9CRESqrtxEYGb1zGwpsAmYk1SZX2lmb5rZRDNrEZW1BdYlXb4O\naJeifH1UTvT8MYC7FwLbzKxlVb+QiIhUTkVaBEVR11B7wq/7YwjdPF0I3UUbgNtrNEoREakxDSp6\nortvM7N5QH9331Pxm9m9wFPRn+uBDkmXtSe0BNZHxyXLE9d0BD4xswbAge7+WcnPNzMtiiQiUknu\nnqobfy/lzRpqnej2MbMmwOnAO2Z2SNJp5wDLo+MngaFm1sjMugBdgUXuvhHYbmZ9osHjYcATSdeM\niI4HA8+X8YUy+nHjjTfGHoPiVJyKU3EmHhVVXovgUOABM6tPSBqPuvssM5tkZj0IA8cfApdGFfUK\nM5sCrAAKgTFeHM0Y4H6gCTDL3WdH5ROByWa2GtgCDK1w9CIiUm1lJgJ3Xw70SlE+vIxrbgZuTlG+\nGDg2Rfk3wPkVCVZERNJPdxanUV5eXtwhVIjiTC/FmV6Ks/ZZZfqR4mRmni2xiohkAjPDqztYLCIi\ndZ8SgYhIjlMiEBHJcUoEIiI5TolARCTHKRGIiOQ4JQIRkRynRCAikuOUCEREcpwSgYhIjlMiEBHJ\ncUoEIiI5TolARCTHKRGIiOQ4JQIRkRynRCAikuOUCEREcpwSgYhIjsuuRLBjR9wRiIjUOdmVCGbN\nijsCEZE6J7sSwfTpcUcgIlLnlJkIzGw/M1toZkvN7C0zGx+VtzSzuWa2yszmmFmLpGvGmdlqM1tp\nZmcklfc2s+XRa3cmlTc2s0ej8gVm1qnUgJ59Fr76qjrfV0RESigzEbj718AP3b0H0APob2Z9gLHA\nXHc/Eng++hsz6w4MAboD/YG7zMyit5sAjHL3rkBXM+sflY8CtkTldwC3lhpQr14hGYiISNqU2zXk\n7jujw0ZAQ8CBAcADUfkDwNnR8UDgEXcvcPc1wHtAHzM7FGju7oui8yYlXZP8XtOBU0sNZtAgdQ+J\niKRZuYnAzOqZ2VJgEzAnqszbuPum6JRNQJvouC2wLunydUC7FOXro3Ki548B3L0Q2GZmLVMGc845\nMHMmfPNNBb6aiIhURIPyTnD3IqCHmR0IPG5m3y7xupuZ11SAycbfcw80bw4jR5J3ySXk5eXVxseK\niGSF/Px88vPzK31duYkgwd23mdk84D+ATWZ2iLtvjLp9/h2dth7okHRZe0JLYH10XLI8cU1H4BMz\nawAc6O6fpYrhV78aT4MWLWDZMlASEBHZS15e3l4/kG+66aYKXVferKHWiRlBZtYEOB14B3gSGBGd\nNgKYER0/CQw1s0Zm1gXoCixy943AdjPrEw0eDwOeSLom8V6DCYPPKc2cCZx7Ljz5JBQUVOgLiohI\n2cobIzgUeMHM3gQWEcYIZgG3AKeb2SqgX/Q3
7r4CmAKsAJ4Bxrh7ottoDHAvsBp4z91nR+UTgVZm\nthq4hmgGUioTJgAdO8Jhh8GLL1b6y4qIyL6suJ7ObGbmrVs78+fDEY/dBh98AHffHXdYIiIZy8xw\ndyv3vGxKBNdd55jBbZe+DyeeCJ98AvXrxx2aiEhGqmgiyKolJi69FO6/H75udzi0bQuvvBJ3SCIi\nWS+rEsERR0DPnjBtGrq5TEQkTbIqEQCMHh0NGicSQVFR3CGJiGS1rEsEZ50Fa9fCsoJu0KIFLFwY\nd0giIlkt6xJBgwZwySUlWgUiIlJlWTVrKBHr+vXw7W/DR7PeovkFZ8GHH4KVOzAuIpJT6uSsoYR2\n7aBfP3ho6THQsCG88UbcIYmIZK2sTAQQDRrfbfigwdE0IhERqYqsTQT9+oXNyuZ3HR7GCbKki0tE\nJNNkbSKoVw8uuwwm5B8Nu3bBW2/FHZKISFbKysHihC1b4PDD4b0Lf03rgw0quOSqiEguqNODxQmt\nWsHAgfAPG6lppCIiVZTViQDCoPHfnu1E0WdbYeXKuMMREck6WZ8I+vSB5s2N544bp1aBiEgVZH0i\nMIumkn42RNNIRUSqIKsHixN27ICOHZ1l9XrSftFjYQczEZEclxODxQnNmsGFFxp/7/Q7dQ+JiFRS\nnWgRQLiN4D/yvmZNl340fP21WoxMRCQz5VSLAMIidIcd3YgnVx4JH30UdzgiIlmjziQCgNFj6jGh\n+S/gscfiDkVEJGvUqUQwaBAs/+pwVj24KO5QRESyRp1KBI0bw88urs/f3joRNmyIOxwRkaxQZiIw\nsw5mNs/M3jazt8zsqqh8vJmtM7Ml0ePMpGvGmdlqM1tpZmcklfc2s+XRa3cmlTc2s0ej8gVm1qk6\nX+jSMQ2YxHC+evTJ6ryNiEjOKK9FUAD8p7sfA/QFLjezboADf3T3ntHjGQAz6w4MAboD/YG7zPZs\nHTYBGOXuXYGuZtY/Kh8FbInK7wBurc4X6tIFjjtmJ1Pu2VqdtxERyRllJgJ33+juS6PjHcA7QLvo\n5VRTkgYCj7h7gbuvAd4D+pjZoUBzd0903k8Czo6OBwAPRMfTgVOr+F32GD3uICas6gebN1f3rURE\n6rwKjxGYWWegJ7AgKrrSzN40s4lm1iIqawusS7psHSFxlCxfT3FCaQd8DODuhcA2M2tZua+xtx+d\n05hPGnVhyf+9Up23ERHJCQ0qcpKZNQOmAVe7+w4zmwD8Jnr5t8DthC6eGjV+/Pg9x3l5eeTl5aU8\nr359+PlPNjDhH4255zcpTxERqXPy8/PJz8+v9HXl3llsZg2BmcAz7v6nFK93Bp5y92PNbCyAu98S\nvTYbuBFYC8xz925R+QXAKe4+OjpnvLsvMLMGwAZ3PzjF55R5Z3FJG9//km5H7GLNGjiw00EVvk5E\npK5Iy53F0UDvRGBFchKI+vwTzgGWR8dPAkPNrJGZdQG6AovcfSOw3cz6RO85DHgi6ZoR0fFg4Ply\nv10FHHJ4U05vu4IH/2dVOt5ORKTOKq9r6CTgImCZmS2Jym4ALjCzHoTZQx8ClwK4+wozmwKsAAqB\nMUk/48cA9wNNgFnuPjsqnwhMNrPVwBZgaDq+GMDoETu58i9HMsbDctUiIrKvOrPoXCr++Va6t97E\nPU+35+T+TWsoMhGRzJRzi86lYge14LKj8pnwW00jFREpTZ1OBADDRzdl1usH8+9/xx2JiEhmqvOJ\n4KALz+RcHuO+u3fFHYqISEaq84mAVq0Y3WMBf/trAbt3xx2MiEjmqfuJADhu5LG0KtzEnDlxRyIi\nknlyIhFwzjmM/uoOJvxVTQIRkZJyIxG0acPQXqt49aXd2sVSRKSE3EgEQNMhZ3FRu3zuuSfuSERE\nMkudvqFsL+vXs6L7YE7d/zXWrjUaNUpfbCIimUg3lJXUrh3du8NRB3/OjBlxByMikjlyJxEADB7M\n6NZTmTAh
7kBERDJH7nQNAXz4IbuOO4mODdaTn28cfXR6YhMRyUTqGkqlSxcadW7LqFPXcvfdcQcj\nIpIZcisRAAwaxM/rT+TBB2HnzriDERGJX04mgk5z7+WEvs4//xl3MCIi8cu9RHDkkXDwwVx2ygoN\nGouIkIuJAGDQIPqvn8jmzfCvf8UdjIhIvHIzEQweTP3Hp3Hpz12tAhHJebmZCLp3h/33Z1SvJTz2\nGGzdGndAIiLxyc1EYAaDB/OtF/5J//4waVLcAYmIxCc3EwHAoEEwfTqjL3Puvhuy5L46EZG0y91E\n0KMHACc3X4oZvPhizPGIiMQkdxOBGQwahD02ncsuQ4PGIpKzykwEZtbBzOaZ2dtm9paZXRWVtzSz\nuWa2yszmmFmLpGvGmdlqM1tpZmcklfc2s+XRa3cmlTc2s0ej8gVm1qkmvmhKgwbBtGkMH+bMmQMb\nN9baJ4uIZIzyWgQFwH+6+zFAX+ByM+sGjAXmuvuRwPPR35hZd2AI0B3oD9xlZokFjyYAo9y9K9DV\nzPpH5aOALVH5HcCtaft25Tn+eNi5kwPXr2DwYJg4sdY+WUQkY5SZCNx9o7svjY53AO8A7YABwAPR\naQ8AZ0fHA4FH3L3A3dcA7wF9zOxQoLm7L4rOm5R0TfJ7TQdOre6XqjAzOPfcMGg8Gu65B3ZrW2MR\nyTEVHiMws85AT2Ah0MbdN0UvbQLaRMdtgXVJl60jJI6S5eujcqLnjwHcvRDYZmYtK/MlqmXwYJg2\njV694JBD4Jlnau2TRUQyQoOKnGRmzQi/1q929y+Ke3vA3d3MamXy5fjx4/cc5+XlkZeXV/03PfFE\n2LwZVq9m9OiuTJgAZ51V/bcVEalt+fn55OfnV/q6cjemMbOGwEzgGXf/U1S2Eshz941Rt888dz/a\nzMYCuPst0XmzgRuBtdE53aLyC4BT3H10dM54d19gZg2ADe5+cIo4qr8xTWnGjIGOHfnq6rF06ACv\nvw5dutTMR4mI1Ja0bEwTDfROBFYkkkDkSWBEdDwCmJFUPtTMGplZF6ArsMjdNwLbzaxP9J7DgCdS\nvNdgwuBz7Yq6h5o0gWHDwliBiEiuKLNFYGbfB14ClgGJE8cBi4ApQEdgDXC+u2+NrrkBGAkUErqS\nno3KewP3A02AWe6emIraGJhMGH/YAgyNBppLxlJzLYLCQjj0UHj9dd79pjOnnAIffQSNG9fMx4mI\n1IaKtghya8/islxyCRx9NFx7LaeeChdfDBdcUHMfJyJS07RncWVFaw8BjB6tO41FJHeoRZCwa1eY\nP7psGQVt2tOpE8ydC8ccU3MfKSJSk9QiqKxGjeAnP4HHH6dhw9A1dPfdcQclIlLzlAiSJXUPXXIJ\nPPQQ7NgRc0wiIjVMiSDZGWfA0qWwaRMdOsDJJ8Mjj8QdlIhIzVIiSLbffnDmmTAj3BaRGDTOkmEU\nEZEqUSIoKbq5DEIDYds2WLSonGtERLKYZg2V9OWX0LYtfPABtGrFbbfBihVw//01/9EiIumkWUNV\n1bQpnH46PBFWwPjZz0JP0WefxRyXiEgNUSJIJWn20MEHh9VIH3ignGtERLKUuoZS2b4d2reHjz+G\nAw/k1Vdh5EhYuTLsZSMikg3UNVQdBxwAeXnw1FNA2LKgUSN44YV4wxIRqQlKBKVJ6h4y0/pDIlJ3\nqWuoNJ9/Dp06wSefQLNmbN8e/nz77TCpSNJv1y5YsgRefhleeSWs8/T738cdlUj2UtdQdR10UOgT\nmjULCL1FQ4bAvffGHFcdsn07zJkD//M/8MMfQsuWcOmlsGZNuJ3j3nvhrbfijlKk7lOLoCx//zs8\n9xw8+igAb74ZZhB9+CE0qNBuz5Lsk0/CL/3EY9Uq6N0bvv/98DjhBGjRovj8O+8MieLpp+OLWSSb\naWOadNi8GY44AjZuhCZNgNBI+O//hoEDazeUbOMO775b3M3zyiuwdSucdF
Jxxd+7d9m7wO3aBd26\nhXzcr1/txS5SVygRpEu/fnDVVXD22QBMnhxWJZ09u/ZDyWQl+/dffRWaNSuu9L///VCp16tkZ+SU\nKXDrrfD665W/ViTXKRGky113wWuvwYMPAvD119ChAyxYAIcfXvvhZIrt28O/QaLif/310HhKrvjb\nt6/+57hD375w9dVw4YXVfz+RXKJEkC4bNkD37qF7KOrHuO668Ov0tttqP5y4bNgQKvxExV9e/346\nvfQSDB8ebujbb7+a+QyRukiJIJ1OPhnGjoUf/xiA1atDX/dHH9XNiinRv59c8Ve2fz/dBg6EU06B\na6+tvc8UyXZKBOn0pz/BsmVw3317is44I/xKveiieEJKp0T/fqLiT1f/fjqtXBny8bvvhmmmIlI+\nJYJ0+ugj6NUr9I80bAjA44/D7beHyjPbJPr3ExV/TfXvp9tll4UE9Yc/xB2JSHZIWyIws/uAHwP/\ndvdjo7LxwMXA5ui0G9z9mei1ccBIYDdwlbvPicp7A/cD+wGz3P3qqLwxMAnoBWwBhrj72hRxxJcI\nAPr0gd/9LixRDRQWQufO8MwzcOyx8YVVHvfQn79gAcyfH57fe6/2+vfTaePGcLfx4sXh315EypbO\nRHAysAOYlJQIbgS+cPc/lji3O/AwcBzQDngO6OrubmaLgCvcfZGZzQL+7O6zzWwM8G13H2NmQ4Bz\n3H1oijjiTQS33RY2q7n77j1FN90EmzaFiUWZYvv2sKNaotJfsCD8ij7hhPDo2xd69Kjd/v10uumm\nkNgeeijuSEQyX1q7hsysM/BUiUSww91vL3HeOKDI3W+N/p4NjAfWAi+4e7eofCiQ5+6XRefc6O4L\nzawBsMHdD04RQ7yJ4P33w91kn3wC9esDsH59aA2sXQvNm9d+SEVFoc88UenPnx/ueu7Zs7jS79u3\nbq2NtGMHHHlkWBi2d++4oxHJbBVNBNVZKOFKMxsO/Au41t23Am2BBUnnrCO0DAqi44T1UTnR88cA\n7l5oZtvMrKW7Z9aeYIcfHmrUV16BH/wAgHbtwmrVDz0U+q9r2tatsHBhcaW/cGFYEilR6f/85/Cd\n74Qls+uqZs3gxhvh+uvh+ee1P4RIOlQ1EUwAfhMd/xa4HRiVlojKMH78+D3HeXl55OXl1fRH7i2x\nsX2UCCAsT33ddWGxtHRWSkVFYa/k5L79jz4Kv4L79g2fe//9cMgh6fvMbDFqVJjI9cwz8KMfxR2N\nSObIz88nPz+/0tdVqWuotNfMbCyAu98SvTYbuJHQNTQvqWvoAuAUdx+d6D5y9wUZ3TUEYQ7jqaeG\nncuiuZRFRXDUUWEryxNPrPpbf/ZZ+IWfqPQXLQrbZPbtW/yL/zvf0WJ3CU8+CTfcEBYCjHrqRKSE\nGl2G2swOTfrzHGB5dPwkMNTMGplZF6ArsMjdNwLbzayPmRkwDHgi6ZoR0fFg4PmqxFQrjj46TK9Z\nuHBPUb16oVsoaQy5XLt3h9sS/vY3+NnPwtt27gz/7/+F1668Mty0tnp1WNtozJgwe1VJoNhPfgKt\nWoVWkYhUT0VmDT0C/ABoDWwi/MLPA3oADnwIXOrum6LzbyBMHy0Ernb3Z6PyxPTRJoTpo1dF5Y2B\nyUBPwvTRoe6+JkUc8bcIIHRQ79gRbiKIbNkS5uG/916onEr69NPiGTzz54d5+4ceuvdMnm9/W79s\nK2vRIjj33DBg3rRp3NGIZB7dUFZTli2DAQPC9JykQYERI8IMomuuCZupzJ9f3M2zaRMcf3xxpd+n\nT+qEIZU3ZEjoMvvlL+OORCTzKBHUFPcwKPDww/C97+0pXrBgz71mdOhQ3Ld/wglheQb92q8Z778f\nEuuKFfCtb8UdjUhmUSKoSePGhef//d+9ihNLNRx0UAwx5bBrrgl3ev/lL3FHIpJZlAhq0uLFMHRo\nuMVVE9nTr7AwrD194IEVumvs009Dq+
vVV8PNZiISaPP6mtSrV6isli8v/1ypmMLCcIfYZZeFO/Wu\nuw7694d33in30tatw+mJhpqIVI4SQVWYhekq06fHHUl2KyyEuXPDLdFt24bNoLt0CTvCvfFG6Hob\nPBi+/LLct7rqqtA199prtRC3SB2jrqGqmj8fLr4Y3n477kiyS0EBzJsHU6fCjBmh4j/vvFDhd+my\n97nu4UaLwsJwQ0U53XCTJoX7OV59VT12IqAxgppXVBSmBz33XOigltIVFIRun2nTQuV/+OHFlX95\n60nv3BmmYF1+eVjHowy7d4chhV//OjTYRHKdEkFtuOqqMGfxV7+KO5LMs2tXqPynToUnngijuOed\nB4MGQadOlXuvVavCPpmzZ5c7eDxnDlxxRWioRXsIieQsJYLa8OKLYe7ikiVxR5IZdu0KLaSpU8Ni\nQEcdVVz5d+xYvfeeOjWMISxeXO783P/4j3DP3+WXV+8jRbKdEkFt2L07DHK+9lro7shF33wTBnyn\nTg2bBHTrVlz5d+iQ3s+65pqwOdCMGWVuoLx0aZhwtGoVHHBAekMQySZKBLXlssvgsMPgF7+IO5La\n8803oQ9m6lSYOTPsH5mo/Nu1K//6qtq1KywBfs455f57jxgR8tDvfldz4YhkOiWC2jJ3bhgjSFqR\ntE76+mt49tkw4DtzZlhYKVH51+YWaB9/DMcdB48+ute+EKlO69EjLA1Vk7lJJJMpEdSWgoKwlOgb\nb1S/HzzTfP11GKCdOhVmzYLvfjdU/ueeG75zXJ59FkaODOMFZezMM3YsbN4MEyfWYmwiGUSJoDaN\nHBmWwLzmmrgjqb6vvtq78u/Zs7jyz6Tt0MaPD4P1c+eWulHDtm1hstJzz4UGjEiuUSKoTbNmwc03\nh/2Ms9HOnWHfx2nTwnOvXsWVf5s2cUeX2u7dcOaZYQXYm28u9bQ77wzDGU8/XYuxiWQIJYLa9M03\n4dfyihXxdplUxs6dIYFNnRq6Wr73vVD5n3NO9qznvHlzSFoTJsBZZ6U8ZdeuMJHp73+Hfv1qOT6R\nmCkR1LaLLgqbFo8ZUzPvX1QUEs4334S++6+/rthxqtfWrAk/k48/vrjyP3ifbaKzw6uvhpbLggX7\nLlERmTIFbr01rEVUxqxTkTpHiaC2zZgBv/89XHtt9Srp0o537YLGjWG//Yqfq3rcpg38+Mdh2c66\n4I474KGHQlJo3Hifl93D5jVXXw0//WkM8YnERImgtn31Vdh1fseOsivjqlbijRppJbXSuId1i9q0\ngbvuSnnKSy/B8OGwcmX45xTJBUoEklu2bQvjHDfdBBdemPKUgQPhlFNCo00kFygRSO5580047bQw\nrbR7931efuedkAjefRdatowhPpFaph3KJPd897tw222hm2jHjn1e7tYt3AhdxmxTkZykFoHUPaNG\nhTGbhx7aZ1xl48awNNLixeVvhSCS7dLWIjCz+8xsk5ktTypraWZzzWyVmc0xsxZJr40zs9VmttLM\nzkgq721my6PX7kwqb2xmj0blC8yskovVi5Twl7+EDQnuvnuflw45JIzp//KXMcQlkqEq0jX0D6B/\nibKxwFx3PxJ4PvobM+sODAG6R9fcZbbnJ9kEYJS7dwW6mlniPUcBW6LyO4Bbq/F9RKBJk3CX9I03\nhpsHSrjuurBb5uLFMcQmkoHKTQTu/jLweYniAcAD0fEDwNnR8UDgEXcvcPc1wHtAHzM7FGju7oui\n8yYlXZP8XtOBU6vwPUT21rVraBGcfz589tleLzVrFnLE9deHmaciua6qg8Vt3H1TdLwJSCxI0xZY\nl3TeOqBdivL1UTnR88cA7l4IbDMzzemQ6jv33PAYPjzcmZ1k1CjYsCEsrSSS61Iv21gJ7u5mViu/\nq8aPH7/nOC8vj7y8vNr4WMlmt9wCeXlhjYlx4/YUN2gQin7xi7C1Zf368YUoki75+fnk5+dX+roK\nzR
oys87AU+5+bPT3SiDP3TdG3T7z3P1oMxsL4O63ROfNBm4E1kbndIvKLwBOcffR0Tnj3X2BmTUA\nNrj7PgvfaNaQVNm6dWEzm4cfhh/+cE+xe9jbZsSI0EIQqWtq+j6CJ4ER0fEIYEZS+VAza2RmXYCu\nwCJ33whsN7M+0eDxMOCJFO81mDD4LJI+7dvDpElhoaENG/YUm8Ef/hDGC778Msb4RGJWbovAzB4B\nfgC0JowH/JpQiU8BOgJrgPPdfWt0/g3ASKAQuNrdn43KewP3A02AWe5+VVTeGJgM9AS2AEOjgeaS\ncahFINXzm9/A88+HR9JmNkOGhI1rfvWrGGMTqQFaYkKkpKIi+NGPwmbGt9yyp/j998PqpCtWZM9W\nDCIVoUQgksqnn0Lv3vB//wcDBuwpvuYaKCwM96KJ1BVKBCKlWbAgLEU6fz4cdhgQ8kO3bmFLgyOP\njDk+kTTRonMipenbN6wxcd55YeMfwh4911231wxTkZyhFoHkJvcwStyy5Z41ib76Co46Ch55BE46\nKeb4RNJALQKRspjBvfeGRYcmTwbCEkW//a2WnpDco0QgueuAA8LidP/1X/DWWwBcdBHs3AmPPx5z\nbCK1SIlActuxx8Ltt4fNbL74gvr1w942Y8dCQUHcwYnUDiUCkeHDwx6WF18M7pxxRti05p574g5M\npHZosFgEwuyhE0+EkSPhiitYuhT694dVq0IPkkg20n0EIpX1/vtwwgkwcyYcfzwjRkCHDvC738Ud\nmEjVKBGIVMWMGeE248WL+XhnK3r0gGXLoF278i8VyTRKBCJVdf31Yc/jmTMZe0M9Nm+GiRPjDkqk\n8pQIRKqqoAD69YP+/dl6+S856ih47rkwwUgkmygRiFTHJ5/A974Hkydz51unMmcOPP103EGJVI7u\nLBapjrZtwx3Hw4YxesB6Vq6EF16IOyiRmqEWgUhZfv97mD2bRy+bx21/bMDrr0M9/XySLKEWgUg6\njBsHzZtz/pJx1K8fFqQTqWvUIhApz5Yt0Ls3L108ieH3nsLKlbDffnEHJVI+tQhE0qVVK5gyhVP+\nPJjvHvElf/1r3AGJpJdaBCIV9de/8s5fnueUT6fz7rtGy5ZxByRSNrUIRNJtzBi69WjMua1f5uab\n4w5GJH3UIhCpjC++YGOvH3HMhrn8a/l+dOkSd0AipVOLQKQmNG/OITPu5sqiP/OrK7bGHY1IWlQr\nEZjZGjNbZmZLzGxRVNbSzOaa2Sozm2NmLZLOH2dmq81spZmdkVTe28yWR6/dWZ2YRGrcMcdw3Z87\nMu/ZXSwR0C9QAAAKxElEQVR+cUfc0YhUW3VbBA7kuXtPdz8+KhsLzHX3I4Hno78xs+7AEKA70B+4\ny8wSTZYJwCh37wp0NbP+1YxLpEY1u3goN544l+vPW4MXqctSsls6uoZK9j8NAB6Ijh8Azo6OBwKP\nuHuBu68B3gP6mNmhQHN3XxSdNynpGpGMNWrWIDZsb8pN332M+XctYeeOorhDEqmSdLQInjOzf5nZ\nJVFZG3ffFB1vAtpEx22BdUnXrgPapShfH5WLZLQGzfbj4Vkt+GT/I7ji2sa0PuAbvn3wJoYP3Mqd\nd8LLL8MXX8QdpUj5GlTz+pPcfYOZHQzMNbOVyS+6u5tZ2trN48eP33Ocl5dHXl5eut5apEp69juI\nexYeBO7sWryct+98jsVPfcIb8/vwyP7fZ/m/D6FDR6NXL+jVC3r3hp49oUWL8t9bpLLy8/PJz8+v\n9HVpmz5qZjcCO4BLCOMGG6Nun3nufrSZjQVw91ui82cDNwJro3O6ReUXAD9w98tKvL+mj0p22L07\nLFX64IMUzHiald85nzeOGcbiet/jjWUNefNNaNOGfZJD69ZxBy51TY3vR2Bm+wP13f0LM2sKzAFu\nAk4Dtrj7rVHl38Ldx0aDxQ8DxxO6fp4DjohaDQuBq4BFwNPAn919
donPUyKQ7PPll2H7ywcfhAUL\nYMAAdl84jFXtfsgbb9bnjTdg8WJYsiS0Enr33jtBtGlT/keIlKY2EkEX4PHozwbAQ+7+v2bWEpgC\ndATWAOe7+9bomhuAkUAhcLW7PxuV9wbuB5oAs9z9qhSfp0Qg2W3jRvjnP8M+Bxs3woUXwrBh8J3v\nUFQEH3wQkkIiObzxBjRpUpwUEs9t24KV+5925isshM8/D2v6JR7btkGfPnDkkXFHVzdohzKRTLZi\nRWglPPhgaAoMGxYSQ7vieRLusHZtcVJIJAizvZNDr17QqVO8yWHnzr0r9MTj009Tl2/ZEgbSDzww\nrOnXqlXoGmvaFF56Cb71LTjvvPBQUqg6JQKRbFBUFKYXTZ4Mjz0WavWLLoJBg6B5831Od4f16/du\nNSxeDN98s29yOPzwyieHoqJ9f6VX5AHFFXpFHy1aQP36+8aweze8+ipMnQrTpoXusfPPD0mha9cq\n/BvnMCUCkWzz1Vcwc2ZICi++CD/+cUgKZ5wBDcqe4LdhQxhnSG49bNsWBqETiaFx4/Ir9K1bQ/4p\nWWm3bl12pb7//jXzT7J7N7zySnFSOPTQ4qRwxBE185l1iRKBSDb79FN49NHQdfTBBzB0aOg+6t27\nwj/zN28OySGRGHbvLv9X+kEHlZtzYpNIClOmwPTpYawk0X2kpJCaEoFIXbF6dfF4QqNGoZXw059C\n585xRxab3btDj9rUqcVJIdFSOPzwuKPLHEoEInWNO8yfHxLClCnQvXtoJZx3Xk7foZZIClOmhGGW\ndu2KWwq5nhSUCETqsl27YNaskBTmzoXTTw9J4cwzQ6shR+3eHWYdJVoK7dsXtxQOOyzu6GqfEoFI\nrvj88zCSOnlymJZ6/vmh++iEE+rGDQdVlEgKiZZChw7FLYU6kxTc4euvw1zcFA8bOlSJQCTnfPgh\nPPxwSAqFhcXjCTk+77KwcO+WQseOxS2FWt9lrrBw30p7+/ZSK/NyHw0ahKleKR42daoSgUjOcg9z\nSSdPDnczd+kCRx+99+slz6/o39W5trz3gtCKqV8f6tVL/VzWa+U9169PodfnxQ86MHXJETy2tAud\nWu3g/OPXcl6fj+h8yNcV+1wIy4dUpeIuKIBmzfatuA84oNQKvcxHw4b7/hvu+adU15CIQKh4Xngh\n3GyQrGS3UWX+rs615b1XUVF47N6973Oqsoo+pygrLIQX1x/BlA+P4/GPetG56aec134+57V7jc77\nbSz92qKicBt0VSrwJk1qrctOiUBEpBIKCyE/P3QfPfZYaEQluo86dYo7utK5hzvLd+wIjZTk51NP\nVSIQEamSRFKYMgUefzwMLicGmquaFNzDzeOpKuzEc1mvlXVuw4ahgdKs2d7P+flKBCIi1VZQUNxS\nSCSFAQPCkh2VqbC//DJc06zZvhV2qufKnFPa3eDqGhIRSbOCApg3D2bPDt38lanImzZNvcheTVIi\nEBHJcRVNBNXdvF5ERLKcEoGISI5TIhARyXFKBCIiOU6JQEQkxykRiIjkOCUCEZEclzGJwMz6m9lK\nM1ttZv8ddzwiIrkiIxKBmdUH/gL0B7oDF5hZt3ijqrz8/Py4Q6gQxZleijO9FGfty4hEABwPvOfu\na9y9APgnMDDmmCotW/6PoTjTS3Gml+KsfZmSCNoBHyf9vS4qExGRGpYpiUCLCImIxCQjFp0zs77A\neHfvH/09Dihy91uTzok/UBGRLJM1q4+aWQPgXeBU4BNgEXCBu78Ta2AiIjmglO0Mape7F5rZFcCz\nQH1gopKAiEjtyIgWgYiIxCdTBotLlQ03mpnZfWa2ycyWxx1LWcysg5nNM7O3zewtM7sq7phSMbP9\nzGyhmS2N4hwfd0ylMbP6ZrbEzJ6KO5bSmNkaM1sWxbko7nhKY2YtzGyamb1jZiuiscOMYmZHRf+O\nice2DP7v6D+j/36Wm9nDZta4
1HMzuUUQ3Wj2LnAasB54nQwcOzCzk4EdwCR3PzbueEpjZocAh7j7\nUjNrBiwGzs60f08AM9vf3XdG40evAFe7+8K44yrJzP4L6A00d/cBcceTipl9CPR298/ijqUsZvYA\n8KK73xf9797U3bfFHVdpzKweoV463t0/Lu/82mRm7YCXgW7u/o2ZPQrMcvcHUp2f6S2CrLjRzN1f\nBj6PO47yuPtGd18aHe8A3gHaxhtVau6+MzpsBDQEimIMJyUzaw/8CLgXKHdmRswyOj4zOxA42d3v\ngzBumMlJIHIa8H6mJYEkDYD9o6S6PyFppZTpiUA3mtUQM+sM9AQy7lc2hF9bZrYU2ATMcffX444p\nhTuA68nAJFWCA8+Z2b/M7JK4gylFF2Czmf3DzN4ws7+b2f5xB1WOocDDcQeRiruvB24HPiLMxNzq\n7s+Vdn6mJ4LM7bfKYlG30DRCd8uOuONJxd2L3L0H0B7oY2bHxB1TMjM7C/i3uy8hw39tAye5e0/g\nTODyqCsz0zQAegF3uXsv4EtgbLwhlc7MGgE/AabGHUsqZnYQMADoTGj1NzOzn5Z2fqYngvVAh6S/\nOxBaBVJFZtYQmA486O4z4o6nPFH3wDzCgoSZ5ERgQNT//gjQz8wmxRxTSu6+IXreDDxO6HLNNOuA\ndUktv2mExJCpzgQWR/+mmeg04EN33+LuhcBjhP/PppTpieBfQFcz6xxl4CHAkzHHlLXMzICJwAp3\n/1Pc8ZTGzFqbWYvouAlwOmE8I2O4+w3u3sHduxC6CF5w9+Fxx1WSme1vZs2j46bAGUDGzW5z943A\nx2Z2ZFR0GvB2jCGV5wLCD4BMtRboa2ZNov/uTwNWlHZyRtxQVppsudHMzB4BfgC0MrOPgV+7+z9i\nDiuVk4CLgGVmtiQqG+fus2OMKZVDgQeiWWP1gEfdfVbMMZUnU7sx2wCPh7qABsBD7j4n3pBKdSXw\nUPSj733gZzHHk1KUUE8DMnW8BXdfZGbTgDeAwuj5ntLOz+jpoyIiUvMyvWtIRERqmBKBiEiOUyIQ\nEclxSgQiIjlOiUBEJMcpEYiI5DglAhGRHKdEICKS4/4/QGQBkLhV900AAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x1102cbe90>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"print \"Final Train cost: {}, on Epoch {}\".format(train_score[-1],k)\n",
"print \"Final Validation cost: {}, on Epoch {}\".format(val_score[-1],k)\n",
"plt.plot(train_score, 'r-', val_score, 'b-')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"##This part generates a new validation set to test against\n",
"val_score_v =[]\n",
"num_epochs=1\n",
"\n",
"for k in range(num_epochs):\n",
"\n",
" #Generate Data for each epoch\n",
" tempX,y = gen_data(5,seq_len,batch_size)\n",
" X = []\n",
" for i in range(seq_len):\n",
" X.append(tempX[:,i,:])\n",
"\n",
" val_dict = {inputs[i]:X[i] for i in range(seq_len)}\n",
" val_dict.update({result: y})\n",
" outv, c_val = sess.run([outputs3,cost],feed_dict = val_dict ) \n",
" val_score_v.append([c_val])\n",
"#print \"Validation cost: {}, on Epoch {}\".format(c_val,k)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"(array([[8],\n",
" [2],\n",
" [8],\n",
" [8],\n",
" [9],\n",
" [6],\n",
" [0],\n",
" [0],\n",
" [0],\n",
" [0],\n",
" [0],\n",
" [0],\n",
" [0],\n",
" [0],\n",
" [0]]), 41.0)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"##Target\n",
"tempX[3],y[3]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([ 44.25109482], dtype=float32)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#Prediction\n",
"outv[3]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
@genho
Copy link

genho commented Apr 21, 2016

Hi Rajiv, thank you so much for your example.
May I know the reason for using seq2seq.rnn_decoder() in your code? I've tried searching the web, and many of the examples I found relate to language translation. I also cannot find any documentation that describes TensorFlow's seq2seq module.
Thanks a lot.

@rajshah4
Copy link
Author

I found the documentation deep in the TensorFlow ops code.
It explains how the decoder operates. Does this help?

@genho
Copy link

genho commented May 9, 2016

Got it. Thank you so much 👍

@JackMedley
Copy link

Hi Rajiv,
Can I ask what the purpose of the dropout layer is in a problem such as this? When training for something like addition, don't we need to know all of the inputs?
Thanks,
Jack

@rajshah4
Copy link
Author

Hmm, it's a good question. This was one of my first RNNs, and I just grabbed code from other projects. I think it works like dropout generally: it should help against overfitting and give the network a better sense of how addition works. If you have the time, I would be curious whether playing around with the dropout shows that it works like that.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment