Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
RNN_Addition_1stgrade
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Teaching a computer to add (using memorization)**\n",
"The goal here is to take advantage of Recurrent Neural Networks. For more background, see my blog post at http://projects.rajivshah.com/blog/2016/04/05/rnn_addition/. This code was partially derived from https://github.com/yankev/tensorflow_example."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"#Import basic libraries\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"#from tensorflow.models.rnn import rnn_cell\n",
"#from tensorflow.models.rnn import rnn\n",
"#from tensorflow.models.rnn import seq2seq\n",
"from numpy import sum\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Hyper-parameters shared by the model-construction and training cells\n",
"num_units = 50 # number of hidden units in each LSTM cell (passed to BasicLSTMCell)\n",
"input_size = 1 # width of the final linear layer (W_o, b_o); the input placeholders hard-code a width of 1\n",
"batch_size = 50 # sequences per batch; the placeholder shapes are fixed to this value\n",
"seq_len = 15 # number of time steps, i.e. the zero-padded sequence length\n",
"drop_out = 0.6 # NOTE: passed as output_keep_prob, so this is the KEEP probability, not the drop rate"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Creates our random sequences\n",
"def gen_data(min_length=5, max_length=15, n_batch=50):\n",
"    \"\"\"Generate a batch of zero-padded digit sequences and their sums.\n",
"\n",
"    Each sequence holds random digits 0-9 in its first `length` positions,\n",
"    where `length` is drawn uniformly from [min_length, max_length]\n",
"    (both inclusive), and zeros in the remaining positions.\n",
"\n",
"    Returns a tuple (X, y): X is an integer array of shape\n",
"    (n_batch, max_length, 1) and y is a float array of shape (n_batch,)\n",
"    holding the sum of each sequence.\n",
"    \"\"\"\n",
"    X = np.random.randint(10, size=(n_batch, max_length, 1))\n",
"    for n in range(n_batch):\n",
"        # randint's upper bound is exclusive: add 1 so full-length sequences can\n",
"        # occur and min_length == max_length no longer raises ValueError\n",
"        length = np.random.randint(min_length, max_length + 1)\n",
"        # Zero out everything past the chosen length (padding)\n",
"        X[n, length:, 0] = 0\n",
"    # The padding is zero, so summing each whole row gives the target value\n",
"    y = X[:, :, 0].sum(axis=1).astype(np.float64)\n",
"    return (X, y)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"### Model Construction\n",
"# Two stacked LSTM layers of num_units units each, with dropout applied to the\n",
"# output of the stack (drop_out is used as the keep probability).\n",
"num_layers = 2\n",
"cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)\n",
"cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers)\n",
"cell = tf.nn.rnn_cell.DropoutWrapper(cell,output_keep_prob=drop_out)\n",
"# NOTE(review): the keep probability is baked into the graph, so dropout is also\n",
"# active whenever cost/outputs3 are evaluated on validation data -- consider making\n",
"# it a placeholder that is fed 1.0 at evaluation time.\n",
"\n",
"# Placeholders: one [batch_size, 1] tensor per time step for X, and one target sum per sequence for y\n",
"inputs = [tf.placeholder(tf.float32,shape=[batch_size,1]) for _ in range(seq_len)]\n",
"result = tf.placeholder(tf.float32, shape=[batch_size])\n",
"initial_state = cell.zero_state(batch_size, tf.float32)\n",
"\n",
"# Run the cell over the seq_len inputs and keep only the last element of the returned outputs list\n",
"outputs, states = tf.nn.seq2seq.rnn_decoder(inputs, initial_state, cell, scope ='rnnln')\n",
"outputs2 = outputs[-1]\n",
"\n",
"# Linear read-out from the num_units-wide output to an input_size-wide (here 1) prediction\n",
"W_o = tf.Variable(tf.random_normal([num_units,input_size], stddev=0.01)) \n",
"b_o = tf.Variable(tf.random_normal([input_size], stddev=0.01))\n",
"\n",
"outputs3 = tf.matmul(outputs2, W_o) + b_o\n",
"\n",
"# Per-example squared error; shape [batch_size] (not reduced to a scalar here)\n",
"cost = tf.pow(tf.sub(tf.reshape(outputs3, [-1]), result),2)\n",
"\n",
"# presumably 0.005 is the learning rate and 0.2 the decay -- confirm against the RMSPropOptimizer signature\n",
"train_op = tf.train.RMSPropOptimizer(0.005, 0.2).minimize(cost) \n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"### Generate Validation Data\n",
"# One fixed batch, split into a list of per-time-step slices to match the `inputs` placeholders\n",
"tempX, y_val = gen_data(5, seq_len, batch_size)\n",
"X_val = [tempX[:, i, :] for i in range(seq_len)]"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"Convert this cell to a code cell and run it to see what the inputs look like:\n",
"\n",
"```python\n",
"print(tempX[1])\n",
"print(y_val[1])\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"## Session setup\n",
"# Start a TensorFlow session and initialise every variable in the graph\n",
"sess = tf.Session()\n",
"init_op = tf.initialize_all_variables()\n",
"sess.run(init_op)\n",
"\n",
"# Histories filled in by the training loop (one entry per logged step)\n",
"train_score, val_score, x_axis = [], [], []"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"num_epochs = 1000\n",
"\n",
"# The validation batch never changes, so build its feed dictionary once\n",
"val_dict = dict(zip(inputs, X_val))\n",
"val_dict[result] = y_val\n",
"\n",
"# NOTE: each 'epoch' here is a single parameter update on one freshly generated batch.\n",
"# range's end is exclusive, so use num_epochs + 1 to run exactly num_epochs updates\n",
"# and have the final step logged.\n",
"for k in range(1, num_epochs + 1):\n",
"\n",
"    # Generate a new random batch and split it into per-time-step slices\n",
"    tempX, y = gen_data(5, seq_len, batch_size)\n",
"    X = [tempX[:, i, :] for i in range(seq_len)]\n",
"\n",
"    # Feed dictionary: one entry per time-step placeholder plus the targets\n",
"    temp_dict = dict(zip(inputs, X))\n",
"    temp_dict[result] = y\n",
"\n",
"    # Perform one update on the parameters and fetch the per-example training cost\n",
"    _, c_train = sess.run([train_op, cost], feed_dict=temp_dict)\n",
"\n",
"    # Every 100 steps, record the summed training and validation costs\n",
"    if k % 100 == 0:\n",
"        c_val = sess.run(cost, feed_dict=val_dict)\n",
"        # np.sum is explicit so this does not rely on `from numpy import sum`\n",
"        train_score.append(np.sum(c_train))\n",
"        val_score.append(np.sum(c_val))\n",
"        x_axis.append(k)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Final Train cost: 3086.54125977, on Epoch 999\n",
"Final Validation cost: 2445.63671875, on Epoch 999\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEACAYAAAC+gnFaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XmYVNW19/HvYhQBRcCgzKioYEwYoqBG08HhYmJABQWN\nwBPQKDjeq+aCyY2Ywau+McbcRIwRo+AQGRQVEUGlHRkMgqCI4AAKAkEUEFHpptf7xz5FF031XN2n\nquv3eZ566vSuc6pWkbhX7eHsbe6OiIjkrnpxByAiIvFSIhARyXFKBCIiOU6JQEQkxykRiIjkOCUC\nEZEcV2YiMLP9zGyhmS01s7fMbHxUPt7M1pnZkuhxZtI148xstZmtNLMzksp7m9ny6LU7k8obm9mj\nUfkCM+tUA99TRERKUWYicPevgR+6ew+gB9DfzPoADvzR3XtGj2cAzKw7MAToDvQH7jIzi95uAjDK\n3bsCXc2sf1Q+CtgSld8B3JrerygiImUpt2vI3XdGh42AhoQkAGApTh8IPOLuBe6+BngP6GNmhwLN\n3X1RdN4k4OzoeADwQHQ8HTi1sl9CRESqrtxEYGb1zGwpsAmYk1SZX2lmb5rZRDNrEZW1BdYlXb4O\naJeifH1UTvT8MYC7FwLbzKxlVb+QiIhUTkVaBEVR11B7wq/7YwjdPF0I3UUbgNtrNEoREakxDSp6\nortvM7N5QH9331Pxm9m9wFPRn+uBDkmXtSe0BNZHxyXLE9d0BD4xswbAge7+WcnPNzMtiiQiUknu\nnqobfy/lzRpqnej2MbMmwOnAO2Z2SNJp5wDLo+MngaFm1sjMugBdgUXuvhHYbmZ9osHjYcATSdeM\niI4HA8+X8YUy+nHjjTfGHoPiVJyKU3EmHhVVXovgUOABM6tPSBqPuvssM5tkZj0IA8cfApdGFfUK\nM5sCrAAKgTFeHM0Y4H6gCTDL3WdH5ROByWa2GtgCDK1w9CIiUm1lJgJ3Xw70SlE+vIxrbgZuTlG+\nGDg2Rfk3wPkVCVZERNJPdxanUV5eXtwhVIjiTC/FmV6Ks/ZZZfqR4mRmni2xiohkAjPDqztYLCIi\ndZ8SgYhIjlMiEBHJcUoEIiI5TolARCTHKRGIiOQ4JQIRkRynRCAikuOUCEREcpwSgYhIjlMiEBHJ\ncUoEIiI5TolARCTHKRGIiOQ4JQIRkRynRCAikuOUCEREcpwSgYhIjsuuRLBjR9wRiIjUOdmVCGbN\nijsCEZE6J7sSwfTpcUcgIlLnlJkIzGw/M1toZkvN7C0zGx+VtzSzuWa2yszmmFmLpGvGmdlqM1tp\nZmcklfc2s+XRa3cmlTc2s0ej8gVm1qnUgJ59Fr76qjrfV0RESigzEbj718AP3b0H0APob2Z9gLHA\nXHc/Eng++hsz6w4MAboD/YG7zMyit5sAjHL3rkBXM+sflY8CtkTldwC3lhpQr14hGYiISNqU2zXk\n7jujw0ZAQ8CBAcADUfkDwNnR8UDgEXcvcPc1wHtAHzM7FGju7oui8yYlXZP8XtOBU0sNZtAgdQ+J\niKRZuYnAzOqZ2VJgEzAnqszbuPum6JRNQJvouC2wLunydUC7FOXro3Ki548B3L0Q2GZmLVMGc845\nMHMmfPNNBb6aiIhURIPyTnD3IqCHmR0IPG5m3y7xupuZ11SAycbfcw80bw4jR5J3ySXk5eXVxseK\niGSF/Px88vPzK31duYkgwd23mdk84D+ATWZ2iLtvjLp9/h2dth7okHRZe0JLYH10XLI8cU1H4BMz\nawAc6O6fpYrhV78aT4MWLWDZMlASEBHZS15e3l4/kG+66aYKXVferKHWiRlBZtYEOB14B3gSGBGd\nNgKYER0/CQw1s0Zm1gXoCixy943AdjPrEw0eDwOeSLom8V6DCYPPKc2cCZx7Ljz5JBQUVOgLiohI\n2cobIzgUeMHM3gQWEcYIZgG3AKeb2SqgX/Q3
7r4CmAKsAJ4Bxrh7ottoDHAvsBp4z91nR+UTgVZm\nthq4hmgGUioTJgAdO8Jhh8GLL1b6y4qIyL6suJ7ObGbmrVs78+fDEY/dBh98AHffHXdYIiIZy8xw\ndyv3vGxKBNdd55jBbZe+DyeeCJ98AvXrxx2aiEhGqmgiyKolJi69FO6/H75udzi0bQuvvBJ3SCIi\nWS+rEsERR0DPnjBtGrq5TEQkTbIqEQCMHh0NGicSQVFR3CGJiGS1rEsEZ50Fa9fCsoJu0KIFLFwY\nd0giIlkt6xJBgwZwySUlWgUiIlJlWTVrKBHr+vXw7W/DR7PeovkFZ8GHH4KVOzAuIpJT6uSsoYR2\n7aBfP3ho6THQsCG88UbcIYmIZK2sTAQQDRrfbfigwdE0IhERqYqsTQT9+oXNyuZ3HR7GCbKki0tE\nJNNkbSKoVw8uuwwm5B8Nu3bBW2/FHZKISFbKysHihC1b4PDD4b0Lf03rgw0quOSqiEguqNODxQmt\nWsHAgfAPG6lppCIiVZTViQDCoPHfnu1E0WdbYeXKuMMREck6WZ8I+vSB5s2N544bp1aBiEgVZH0i\nMIumkn42RNNIRUSqIKsHixN27ICOHZ1l9XrSftFjYQczEZEclxODxQnNmsGFFxp/7/Q7dQ+JiFRS\nnWgRQLiN4D/yvmZNl340fP21WoxMRCQz5VSLAMIidIcd3YgnVx4JH30UdzgiIlmjziQCgNFj6jGh\n+S/gscfiDkVEJGvUqUQwaBAs/+pwVj24KO5QRESyRp1KBI0bw88urs/f3joRNmyIOxwRkaxQZiIw\nsw5mNs/M3jazt8zsqqh8vJmtM7Ml0ePMpGvGmdlqM1tpZmcklfc2s+XRa3cmlTc2s0ej8gVm1qk6\nX+jSMQ2YxHC+evTJ6ryNiEjOKK9FUAD8p7sfA/QFLjezboADf3T3ntHjGQAz6w4MAboD/YG7zPZs\nHTYBGOXuXYGuZtY/Kh8FbInK7wBurc4X6tIFjjtmJ1Pu2VqdtxERyRllJgJ33+juS6PjHcA7QLvo\n5VRTkgYCj7h7gbuvAd4D+pjZoUBzd0903k8Czo6OBwAPRMfTgVOr+F32GD3uICas6gebN1f3rURE\n6rwKjxGYWWegJ7AgKrrSzN40s4lm1iIqawusS7psHSFxlCxfT3FCaQd8DODuhcA2M2tZua+xtx+d\n05hPGnVhyf+9Up23ERHJCQ0qcpKZNQOmAVe7+w4zmwD8Jnr5t8DthC6eGjV+/Pg9x3l5eeTl5aU8\nr359+PlPNjDhH4255zcpTxERqXPy8/PJz8+v9HXl3llsZg2BmcAz7v6nFK93Bp5y92PNbCyAu98S\nvTYbuBFYC8xz925R+QXAKe4+OjpnvLsvMLMGwAZ3PzjF55R5Z3FJG9//km5H7GLNGjiw00EVvk5E\npK5Iy53F0UDvRGBFchKI+vwTzgGWR8dPAkPNrJGZdQG6AovcfSOw3cz6RO85DHgi6ZoR0fFg4Ply\nv10FHHJ4U05vu4IH/2dVOt5ORKTOKq9r6CTgImCZmS2Jym4ALjCzHoTZQx8ClwK4+wozmwKsAAqB\nMUk/48cA9wNNgFnuPjsqnwhMNrPVwBZgaDq+GMDoETu58i9HMsbDctUiIrKvOrPoXCr++Va6t97E\nPU+35+T+TWsoMhGRzJRzi86lYge14LKj8pnwW00jFREpTZ1OBADDRzdl1usH8+9/xx2JiEhmqvOJ\n4KALz+RcHuO+u3fFHYqISEaq84mAVq0Y3WMBf/trAbt3xx2MiEjmqfuJADhu5LG0KtzEnDlxRyIi\nknlyIhFwzjmM/uoOJvxVTQIRkZJyIxG0acPQXqt49aXd2sVSRKSE3EgEQNMhZ3FRu3zuuSfuSERE\nMkudvqFsL+vXs6L7YE7d/zXWrjUaNUpfbCIimUg3lJXUrh3du8NRB3/OjBlxByMikjlyJxEADB7M\n6NZTmTAh
7kBERDJH7nQNAXz4IbuOO4mODdaTn28cfXR6YhMRyUTqGkqlSxcadW7LqFPXcvfdcQcj\nIpIZcisRAAwaxM/rT+TBB2HnzriDERGJX04mgk5z7+WEvs4//xl3MCIi8cu9RHDkkXDwwVx2ygoN\nGouIkIuJAGDQIPqvn8jmzfCvf8UdjIhIvHIzEQweTP3Hp3Hpz12tAhHJebmZCLp3h/33Z1SvJTz2\nGGzdGndAIiLxyc1EYAaDB/OtF/5J//4waVLcAYmIxCc3EwHAoEEwfTqjL3Puvhuy5L46EZG0y91E\n0KMHACc3X4oZvPhizPGIiMQkdxOBGQwahD02ncsuQ4PGIpKzykwEZtbBzOaZ2dtm9paZXRWVtzSz\nuWa2yszmmFmLpGvGmdlqM1tpZmcklfc2s+XRa3cmlTc2s0ej8gVm1qkmvmhKgwbBtGkMH+bMmQMb\nN9baJ4uIZIzyWgQFwH+6+zFAX+ByM+sGjAXmuvuRwPPR35hZd2AI0B3oD9xlZokFjyYAo9y9K9DV\nzPpH5aOALVH5HcCtaft25Tn+eNi5kwPXr2DwYJg4sdY+WUQkY5SZCNx9o7svjY53AO8A7YABwAPR\naQ8AZ0fHA4FH3L3A3dcA7wF9zOxQoLm7L4rOm5R0TfJ7TQdOre6XqjAzOPfcMGg8Gu65B3ZrW2MR\nyTEVHiMws85AT2Ah0MbdN0UvbQLaRMdtgXVJl60jJI6S5eujcqLnjwHcvRDYZmYtK/MlqmXwYJg2\njV694JBD4Jlnau2TRUQyQoOKnGRmzQi/1q929y+Ke3vA3d3MamXy5fjx4/cc5+XlkZeXV/03PfFE\n2LwZVq9m9OiuTJgAZ51V/bcVEalt+fn55OfnV/q6cjemMbOGwEzgGXf/U1S2Eshz941Rt888dz/a\nzMYCuPst0XmzgRuBtdE53aLyC4BT3H10dM54d19gZg2ADe5+cIo4qr8xTWnGjIGOHfnq6rF06ACv\nvw5dutTMR4mI1Ja0bEwTDfROBFYkkkDkSWBEdDwCmJFUPtTMGplZF6ArsMjdNwLbzaxP9J7DgCdS\nvNdgwuBz7Yq6h5o0gWHDwliBiEiuKLNFYGbfB14ClgGJE8cBi4ApQEdgDXC+u2+NrrkBGAkUErqS\nno3KewP3A02AWe6emIraGJhMGH/YAgyNBppLxlJzLYLCQjj0UHj9dd79pjOnnAIffQSNG9fMx4mI\n1IaKtghya8/islxyCRx9NFx7LaeeChdfDBdcUHMfJyJS07RncWVFaw8BjB6tO41FJHeoRZCwa1eY\nP7psGQVt2tOpE8ydC8ccU3MfKSJSk9QiqKxGjeAnP4HHH6dhw9A1dPfdcQclIlLzlAiSJXUPXXIJ\nPPQQ7NgRc0wiIjVMiSDZGWfA0qWwaRMdOsDJJ8Mjj8QdlIhIzVIiSLbffnDmmTAj3BaRGDTOkmEU\nEZEqUSIoKbq5DEIDYds2WLSonGtERLKYZg2V9OWX0LYtfPABtGrFbbfBihVw//01/9EiIumkWUNV\n1bQpnH46PBFWwPjZz0JP0WefxRyXiEgNUSJIJWn20MEHh9VIH3ignGtERLKUuoZS2b4d2reHjz+G\nAw/k1Vdh5EhYuTLsZSMikg3UNVQdBxwAeXnw1FNA2LKgUSN44YV4wxIRqQlKBKVJ6h4y0/pDIlJ3\nqWuoNJ9/Dp06wSefQLNmbN8e/nz77TCpSNJv1y5YsgRefhleeSWs8/T738cdlUj2UtdQdR10UOgT\nmjULCL1FQ4bAvffGHFcdsn07zJkD//M/8MMfQsuWcOmlsGZNuJ3j3nvhrbfijlKk7lOLoCx//zs8\n9xw8+igAb74ZZhB9+CE0qNBuz5Lsk0/CL/3EY9Uq6N0bvv/98DjhBGjRovj8O+8MieLpp+OLWSSb\naWOadNi8GY44AjZuhCZNgNBI+O//hoEDazeUbOMO775b3M3zyiuwdSucdF
Jxxd+7d9m7wO3aBd26\nhXzcr1/txS5SVygRpEu/fnDVVXD22QBMnhxWJZ09u/ZDyWQl+/dffRWaNSuu9L///VCp16tkZ+SU\nKXDrrfD665W/ViTXKRGky113wWuvwYMPAvD119ChAyxYAIcfXvvhZIrt28O/QaLif/310HhKrvjb\nt6/+57hD375w9dVw4YXVfz+RXKJEkC4bNkD37qF7KOrHuO668Ov0tttqP5y4bNgQKvxExV9e/346\nvfQSDB8ebujbb7+a+QyRukiJIJ1OPhnGjoUf/xiA1atDX/dHH9XNiinRv59c8Ve2fz/dBg6EU06B\na6+tvc8UyXZKBOn0pz/BsmVw3317is44I/xKveiieEJKp0T/fqLiT1f/fjqtXBny8bvvhmmmIlI+\nJYJ0+ugj6NUr9I80bAjA44/D7beHyjPbJPr3ExV/TfXvp9tll4UE9Yc/xB2JSHZIWyIws/uAHwP/\ndvdjo7LxwMXA5ui0G9z9mei1ccBIYDdwlbvPicp7A/cD+wGz3P3qqLwxMAnoBWwBhrj72hRxxJcI\nAPr0gd/9LixRDRQWQufO8MwzcOyx8YVVHvfQn79gAcyfH57fe6/2+vfTaePGcLfx4sXh315EypbO\nRHAysAOYlJQIbgS+cPc/lji3O/AwcBzQDngO6OrubmaLgCvcfZGZzQL+7O6zzWwM8G13H2NmQ4Bz\n3H1oijjiTQS33RY2q7n77j1FN90EmzaFiUWZYvv2sKNaotJfsCD8ij7hhPDo2xd69Kjd/v10uumm\nkNgeeijuSEQyX1q7hsysM/BUiUSww91vL3HeOKDI3W+N/p4NjAfWAi+4e7eofCiQ5+6XRefc6O4L\nzawBsMHdD04RQ7yJ4P33w91kn3wC9esDsH59aA2sXQvNm9d+SEVFoc88UenPnx/ueu7Zs7jS79u3\nbq2NtGMHHHlkWBi2d++4oxHJbBVNBNVZKOFKMxsO/Au41t23Am2BBUnnrCO0DAqi44T1UTnR88cA\n7l5oZtvMrKW7Z9aeYIcfHmrUV16BH/wAgHbtwmrVDz0U+q9r2tatsHBhcaW/cGFYEilR6f/85/Cd\n74Qls+uqZs3gxhvh+uvh+ee1P4RIOlQ1EUwAfhMd/xa4HRiVlojKMH78+D3HeXl55OXl1fRH7i2x\nsX2UCCAsT33ddWGxtHRWSkVFYa/k5L79jz4Kv4L79g2fe//9cMgh6fvMbDFqVJjI9cwz8KMfxR2N\nSObIz88nPz+/0tdVqWuotNfMbCyAu98SvTYbuJHQNTQvqWvoAuAUdx+d6D5y9wUZ3TUEYQ7jqaeG\nncuiuZRFRXDUUWEryxNPrPpbf/ZZ+IWfqPQXLQrbZPbtW/yL/zvf0WJ3CU8+CTfcEBYCjHrqRKSE\nGl2G2swOTfrzHGB5dPwkMNTMGplZF6ArsMjdNwLbzayPmRkwDHgi6ZoR0fFg4PmqxFQrjj46TK9Z\nuHBPUb16oVsoaQy5XLt3h9sS/vY3+NnPwtt27gz/7/+F1668Mty0tnp1WNtozJgwe1VJoNhPfgKt\nWoVWkYhUT0VmDT0C/ABoDWwi/MLPA3oADnwIXOrum6LzbyBMHy0Ernb3Z6PyxPTRJoTpo1dF5Y2B\nyUBPwvTRoe6+JkUc8bcIIHRQ79gRbiKIbNkS5uG/916onEr69NPiGTzz54d5+4ceuvdMnm9/W79s\nK2vRIjj33DBg3rRp3NGIZB7dUFZTli2DAQPC9JykQYERI8IMomuuCZupzJ9f3M2zaRMcf3xxpd+n\nT+qEIZU3ZEjoMvvlL+OORCTzKBHUFPcwKPDww/C97+0pXrBgz71mdOhQ3Ld/wglheQb92q8Z778f\nEuuKFfCtb8UdjUhmUSKoSePGhef//d+9ihNLNRx0UAwx5bBrrgl3ev/lL3FHIpJZlAhq0uLFMHRo\nuMVVE9nTr7AwrD194IEVumvs009Dq+
vVV8PNZiISaPP6mtSrV6isli8v/1ypmMLCcIfYZZeFO/Wu\nuw7694d33in30tatw+mJhpqIVI4SQVWYhekq06fHHUl2KyyEuXPDLdFt24bNoLt0CTvCvfFG6Hob\nPBi+/LLct7rqqtA199prtRC3SB2jrqGqmj8fLr4Y3n477kiyS0EBzJsHU6fCjBmh4j/vvFDhd+my\n97nu4UaLwsJwQ0U53XCTJoX7OV59VT12IqAxgppXVBSmBz33XOigltIVFIRun2nTQuV/+OHFlX95\n60nv3BmmYF1+eVjHowy7d4chhV//OjTYRHKdEkFtuOqqMGfxV7+KO5LMs2tXqPynToUnngijuOed\nB4MGQadOlXuvVavCPpmzZ5c7eDxnDlxxRWioRXsIieQsJYLa8OKLYe7ikiVxR5IZdu0KLaSpU8Ni\nQEcdVVz5d+xYvfeeOjWMISxeXO783P/4j3DP3+WXV+8jRbKdEkFt2L07DHK+9lro7shF33wTBnyn\nTg2bBHTrVlz5d+iQ3s+65pqwOdCMGWVuoLx0aZhwtGoVHHBAekMQySZKBLXlssvgsMPgF7+IO5La\n8803oQ9m6lSYOTPsH5mo/Nu1K//6qtq1KywBfs455f57jxgR8tDvfldz4YhkOiWC2jJ3bhgjSFqR\ntE76+mt49tkw4DtzZlhYKVH51+YWaB9/DMcdB48+ute+EKlO69EjLA1Vk7lJJJMpEdSWgoKwlOgb\nb1S/HzzTfP11GKCdOhVmzYLvfjdU/ueeG75zXJ59FkaODOMFZezMM3YsbN4MEyfWYmwiGUSJoDaN\nHBmWwLzmmrgjqb6vvtq78u/Zs7jyz6Tt0MaPD4P1c+eWulHDtm1hstJzz4UGjEiuUSKoTbNmwc03\nh/2Ms9HOnWHfx2nTwnOvXsWVf5s2cUeX2u7dcOaZYQXYm28u9bQ77wzDGU8/XYuxiWQIJYLa9M03\n4dfyihXxdplUxs6dIYFNnRq6Wr73vVD5n3NO9qznvHlzSFoTJsBZZ6U8ZdeuMJHp73+Hfv1qOT6R\nmCkR1LaLLgqbFo8ZUzPvX1QUEs4334S++6+/rthxqtfWrAk/k48/vrjyP3ifbaKzw6uvhpbLggX7\nLlERmTIFbr01rEVUxqxTkTpHiaC2zZgBv/89XHtt9Srp0o537YLGjWG//Yqfq3rcpg38+Mdh2c66\n4I474KGHQlJo3Hifl93D5jVXXw0//WkM8YnERImgtn31Vdh1fseOsivjqlbijRppJbXSuId1i9q0\ngbvuSnnKSy/B8OGwcmX45xTJBUoEklu2bQvjHDfdBBdemPKUgQPhlFNCo00kFygRSO5580047bQw\nrbR7931efuedkAjefRdatowhPpFaph3KJPd897tw222hm2jHjn1e7tYt3AhdxmxTkZykFoHUPaNG\nhTGbhx7aZ1xl48awNNLixeVvhSCS7dLWIjCz+8xsk5ktTypraWZzzWyVmc0xsxZJr40zs9VmttLM\nzkgq721my6PX7kwqb2xmj0blC8yskovVi5Twl7+EDQnuvnuflw45JIzp//KXMcQlkqEq0jX0D6B/\nibKxwFx3PxJ4PvobM+sODAG6R9fcZbbnJ9kEYJS7dwW6mlniPUcBW6LyO4Bbq/F9RKBJk3CX9I03\nhpsHSrjuurBb5uLFMcQmkoHKTQTu/jLweYniAcAD0fEDwNnR8UDgEXcvcPc1wHtAHzM7FGju7oui\n8yYlXZP8XtOBU6vwPUT21rVraBGcfz589tleLzVrFnLE9deHmaciua6qg8Vt3H1TdLwJSCxI0xZY\nl3TeOqBdivL1UTnR88cA7l4IbDMzzemQ6jv33PAYPjzcmZ1k1CjYsCEsrSSS61Iv21gJ7u5mViu/\nq8aPH7/nOC8vj7y8vNr4WMlmt9wCeXlhjYlx4/YUN2gQin7xi7C1Zf368YUoki75+fnk5+dX+roK\nzR
oys87AU+5+bPT3SiDP3TdG3T7z3P1oMxsL4O63ROfNBm4E1kbndIvKLwBOcffR0Tnj3X2BmTUA\nNrj7PgvfaNaQVNm6dWEzm4cfhh/+cE+xe9jbZsSI0EIQqWtq+j6CJ4ER0fEIYEZS+VAza2RmXYCu\nwCJ33whsN7M+0eDxMOCJFO81mDD4LJI+7dvDpElhoaENG/YUm8Ef/hDGC778Msb4RGJWbovAzB4B\nfgC0JowH/JpQiU8BOgJrgPPdfWt0/g3ASKAQuNrdn43KewP3A02AWe5+VVTeGJgM9AS2AEOjgeaS\ncahFINXzm9/A88+HR9JmNkOGhI1rfvWrGGMTqQFaYkKkpKIi+NGPwmbGt9yyp/j998PqpCtWZM9W\nDCIVoUQgksqnn0Lv3vB//wcDBuwpvuYaKCwM96KJ1BVKBCKlWbAgLEU6fz4cdhgQ8kO3bmFLgyOP\njDk+kTTRonMipenbN6wxcd55YeMfwh4911231wxTkZyhFoHkJvcwStyy5Z41ib76Co46Ch55BE46\nKeb4RNJALQKRspjBvfeGRYcmTwbCEkW//a2WnpDco0QgueuAA8LidP/1X/DWWwBcdBHs3AmPPx5z\nbCK1SIlActuxx8Ltt4fNbL74gvr1w942Y8dCQUHcwYnUDiUCkeHDwx6WF18M7pxxRti05p574g5M\npHZosFgEwuyhE0+EkSPhiitYuhT694dVq0IPkkg20n0EIpX1/vtwwgkwcyYcfzwjRkCHDvC738Ud\nmEjVKBGIVMWMGeE248WL+XhnK3r0gGXLoF278i8VyTRKBCJVdf31Yc/jmTMZe0M9Nm+GiRPjDkqk\n8pQIRKqqoAD69YP+/dl6+S856ih47rkwwUgkmygRiFTHJ5/A974Hkydz51unMmcOPP103EGJVI7u\nLBapjrZtwx3Hw4YxesB6Vq6EF16IOyiRmqEWgUhZfv97mD2bRy+bx21/bMDrr0M9/XySLKEWgUg6\njBsHzZtz/pJx1K8fFqQTqWvUIhApz5Yt0Ls3L108ieH3nsLKlbDffnEHJVI+tQhE0qVVK5gyhVP+\nPJjvHvElf/1r3AGJpJdaBCIV9de/8s5fnueUT6fz7rtGy5ZxByRSNrUIRNJtzBi69WjMua1f5uab\n4w5GJH3UIhCpjC++YGOvH3HMhrn8a/l+dOkSd0AipVOLQKQmNG/OITPu5sqiP/OrK7bGHY1IWlQr\nEZjZGjNbZmZLzGxRVNbSzOaa2Sozm2NmLZLOH2dmq81spZmdkVTe28yWR6/dWZ2YRGrcMcdw3Z87\nMu/ZXSwR0C9QAAAKxElEQVR+cUfc0YhUW3VbBA7kuXtPdz8+KhsLzHX3I4Hno78xs+7AEKA70B+4\ny8wSTZYJwCh37wp0NbP+1YxLpEY1u3goN544l+vPW4MXqctSsls6uoZK9j8NAB6Ijh8Azo6OBwKP\nuHuBu68B3gP6mNmhQHN3XxSdNynpGpGMNWrWIDZsb8pN332M+XctYeeOorhDEqmSdLQInjOzf5nZ\nJVFZG3ffFB1vAtpEx22BdUnXrgPapShfH5WLZLQGzfbj4Vkt+GT/I7ji2sa0PuAbvn3wJoYP3Mqd\nd8LLL8MXX8QdpUj5GlTz+pPcfYOZHQzMNbOVyS+6u5tZ2trN48eP33Ocl5dHXl5eut5apEp69juI\nexYeBO7sWryct+98jsVPfcIb8/vwyP7fZ/m/D6FDR6NXL+jVC3r3hp49oUWL8t9bpLLy8/PJz8+v\n9HVpmz5qZjcCO4BLCOMGG6Nun3nufrSZjQVw91ui82cDNwJro3O6ReUXAD9w98tKvL+mj0p22L07\nLFX64IMUzHiald85nzeOGcbiet/jjWUNefNNaNOGfZJD69ZxBy51TY3vR2Bm+wP13f0LM2sKzAFu\nAk4Dtrj7rVHl38Ldx0aDxQ8DxxO6fp4DjohaDQuBq4BFwNPAn919
donPUyKQ7PPll2H7ywcfhAUL\nYMAAdl84jFXtfsgbb9bnjTdg8WJYsiS0Enr33jtBtGlT/keIlKY2EkEX4PHozwbAQ+7+v2bWEpgC\ndATWAOe7+9bomhuAkUAhcLW7PxuV9wbuB5oAs9z9qhSfp0Qg2W3jRvjnP8M+Bxs3woUXwrBh8J3v\nUFQEH3wQkkIiObzxBjRpUpwUEs9t24KV+5925isshM8/D2v6JR7btkGfPnDkkXFHVzdohzKRTLZi\nRWglPPhgaAoMGxYSQ7vieRLusHZtcVJIJAizvZNDr17QqVO8yWHnzr0r9MTj009Tl2/ZEgbSDzww\nrOnXqlXoGmvaFF56Cb71LTjvvPBQUqg6JQKRbFBUFKYXTZ4Mjz0WavWLLoJBg6B5831Od4f16/du\nNSxeDN98s29yOPzwyieHoqJ9f6VX5AHFFXpFHy1aQP36+8aweze8+ipMnQrTpoXusfPPD0mha9cq\n/BvnMCUCkWzz1Vcwc2ZICi++CD/+cUgKZ5wBDcqe4LdhQxhnSG49bNsWBqETiaFx4/Ir9K1bQ/4p\nWWm3bl12pb7//jXzT7J7N7zySnFSOPTQ4qRwxBE185l1iRKBSDb79FN49NHQdfTBBzB0aOg+6t27\nwj/zN28OySGRGHbvLv9X+kEHlZtzYpNIClOmwPTpYawk0X2kpJCaEoFIXbF6dfF4QqNGoZXw059C\n585xRxab3btDj9rUqcVJIdFSOPzwuKPLHEoEInWNO8yfHxLClCnQvXtoJZx3Xk7foZZIClOmhGGW\ndu2KWwq5nhSUCETqsl27YNaskBTmzoXTTw9J4cwzQ6shR+3eHWYdJVoK7dsXtxQOOyzu6GqfEoFI\nrvj88zCSOnlymJZ6/vmh++iEE+rGDQdVlEgKiZZChw7FLYU6kxTc4euvw1zcFA8bOlSJQCTnfPgh\nPPxwSAqFhcXjCTk+77KwcO+WQseOxS2FWt9lrrBw30p7+/ZSK/NyHw0ahKleKR42daoSgUjOcg9z\nSSdPDnczd+kCRx+99+slz6/o39W5trz3gtCKqV8f6tVL/VzWa+U9169PodfnxQ86MHXJETy2tAud\nWu3g/OPXcl6fj+h8yNcV+1wIy4dUpeIuKIBmzfatuA84oNQKvcxHw4b7/hvu+adU15CIQKh4Xngh\n3GyQrGS3UWX+rs615b1XUVF47N6973Oqsoo+pygrLIQX1x/BlA+P4/GPetG56aec134+57V7jc77\nbSz92qKicBt0VSrwJk1qrctOiUBEpBIKCyE/P3QfPfZYaEQluo86dYo7utK5hzvLd+wIjZTk51NP\nVSIQEamSRFKYMgUefzwMLicGmquaFNzDzeOpKuzEc1mvlXVuw4ahgdKs2d7P+flKBCIi1VZQUNxS\nSCSFAQPCkh2VqbC//DJc06zZvhV2qufKnFPa3eDqGhIRSbOCApg3D2bPDt38lanImzZNvcheTVIi\nEBHJcRVNBNXdvF5ERLKcEoGISI5TIhARyXFKBCIiOU6JQEQkxykRiIjkOCUCEZEclzGJwMz6m9lK\nM1ttZv8ddzwiIrkiIxKBmdUH/gL0B7oDF5hZt3ijqrz8/Py4Q6gQxZleijO9FGfty4hEABwPvOfu\na9y9APgnMDDmmCotW/6PoTjTS3Gml+KsfZmSCNoBHyf9vS4qExGRGpYpiUCLCImIxCQjFp0zs77A\neHfvH/09Dihy91uTzok/UBGRLJM1q4+aWQPgXeBU4BNgEXCBu78Ta2AiIjmglO0Mape7F5rZFcCz\nQH1gopKAiEjtyIgWgYiIxCdTBotLlQ03mpnZfWa2ycyWxx1LWcysg5nNM7O3zewtM7sq7phSMbP9\nzGyhmS2N4hwfd0ylMbP6ZrbEzJ6KO5bSmNkaM1sWxbko7nhKY2YtzGyamb1jZiuiscOMYmZHRf+O\nice2DP7v6D+j/36Wm9nDZta4
1HMzuUUQ3Wj2LnAasB54nQwcOzCzk4EdwCR3PzbueEpjZocAh7j7\nUjNrBiwGzs60f08AM9vf3XdG40evAFe7+8K44yrJzP4L6A00d/cBcceTipl9CPR298/ijqUsZvYA\n8KK73xf9797U3bfFHVdpzKweoV463t0/Lu/82mRm7YCXgW7u/o2ZPQrMcvcHUp2f6S2CrLjRzN1f\nBj6PO47yuPtGd18aHe8A3gHaxhtVau6+MzpsBDQEimIMJyUzaw/8CLgXKHdmRswyOj4zOxA42d3v\ngzBumMlJIHIa8H6mJYEkDYD9o6S6PyFppZTpiUA3mtUQM+sM9AQy7lc2hF9bZrYU2ATMcffX444p\nhTuA68nAJFWCA8+Z2b/M7JK4gylFF2Czmf3DzN4ws7+b2f5xB1WOocDDcQeRiruvB24HPiLMxNzq\n7s+Vdn6mJ4LM7bfKYlG30DRCd8uOuONJxd2L3L0H0B7oY2bHxB1TMjM7C/i3uy8hw39tAye5e0/g\nTODyqCsz0zQAegF3uXsv4EtgbLwhlc7MGgE/AabGHUsqZnYQMADoTGj1NzOzn5Z2fqYngvVAh6S/\nOxBaBVJFZtYQmA486O4z4o6nPFH3wDzCgoSZ5ERgQNT//gjQz8wmxRxTSu6+IXreDDxO6HLNNOuA\ndUktv2mExJCpzgQWR/+mmeg04EN33+LuhcBjhP/PppTpieBfQFcz6xxl4CHAkzHHlLXMzICJwAp3\n/1Pc8ZTGzFqbWYvouAlwOmE8I2O4+w3u3sHduxC6CF5w9+Fxx1WSme1vZs2j46bAGUDGzW5z943A\nx2Z2ZFR0GvB2jCGV5wLCD4BMtRboa2ZNov/uTwNWlHZyRtxQVppsudHMzB4BfgC0MrOPgV+7+z9i\nDiuVk4CLgGVmtiQqG+fus2OMKZVDgQeiWWP1gEfdfVbMMZUnU7sx2wCPh7qABsBD7j4n3pBKdSXw\nUPSj733gZzHHk1KUUE8DMnW8BXdfZGbTgDeAwuj5ntLOz+jpoyIiUvMyvWtIRERqmBKBiEiOUyIQ\nEclxSgQiIjlOiUBEJMcpEYiI5DglAhGRHKdEICKS4/4/QGQBkLhV900AAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x1102cbe90>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Report the most recently logged costs; x_axis[-1] is the step at which they were recorded\n",
"print(\"Final Train cost: {}, on Epoch {}\".format(train_score[-1], x_axis[-1]))\n",
"print(\"Final Validation cost: {}, on Epoch {}\".format(val_score[-1], x_axis[-1]))\n",
"\n",
"# Plot both curves against the training step at which they were logged\n",
"fig, ax = plt.subplots()\n",
"ax.plot(x_axis, train_score, 'r-', label='train')\n",
"ax.plot(x_axis, val_score, 'b-', label='validation')\n",
"ax.set_xlabel('training step')\n",
"ax.set_ylabel('summed squared error over the batch')\n",
"ax.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"## Evaluate the trained network on a freshly generated batch\n",
"val_score_v = []\n",
"num_epochs = 1\n",
"\n",
"for k in range(num_epochs):\n",
"    # Draw a new batch and slice it into one array per time step\n",
"    tempX, y = gen_data(5, seq_len, batch_size)\n",
"    X = [tempX[:, i, :] for i in range(seq_len)]\n",
"\n",
"    # Map every time-step placeholder to its slice, then add the targets\n",
"    val_dict = dict(zip(inputs, X))\n",
"    val_dict[result] = y\n",
"\n",
"    # Fetch the predictions together with the per-example cost\n",
"    outv, c_val = sess.run([outputs3, cost], feed_dict=val_dict)\n",
"    val_score_v.append([c_val])\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"(array([[8],\n",
" [2],\n",
" [8],\n",
" [8],\n",
" [9],\n",
" [6],\n",
" [0],\n",
" [0],\n",
" [0],\n",
" [0],\n",
" [0],\n",
" [0],\n",
" [0],\n",
" [0],\n",
" [0]]), 41.0)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"## Target: the input sequence and true sum for example 3 of the most recently generated batch\n",
"tempX[3],y[3]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([ 44.25109482], dtype=float32)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Prediction: the network output stored in outv for example 3\n",
"outv[3]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
@genho

This comment has been minimized.

Copy link

@genho genho commented Apr 21, 2016

Hi Rajiv, thank you so much for your example.
May I know the reason for using seq2seq.rnn_decoder() in your code? I've tried searching the web, and many of the examples are related to language translation. I really cannot find documentation that talks about the TensorFlow seq2seq class.
Thanks a lot.

@rajshah4

This comment has been minimized.

Copy link
Owner Author

@rajshah4 rajshah4 commented Apr 30, 2016

I found the documentation deep in the tensorflow ops code
It explains how the decoder operates. Does this help?

@genho

This comment has been minimized.

Copy link

@genho genho commented May 9, 2016

Got it. Thank you so much 👍

@JackMedley

This comment has been minimized.

Copy link

@JackMedley JackMedley commented Aug 22, 2016

Hi Rajiv,
Can I ask what the purpose of the dropout layer is in a problem such as this? When training for something like addition don't we need to know all of the inputs?
Thanks,
Jack

@rajshah4

This comment has been minimized.

Copy link
Owner Author

@rajshah4 rajshah4 commented Aug 27, 2016

Hmm, it's a good question. This was one of my first RNNs and I just grabbed code from other projects. I am thinking that it would work like dropout generally: it would help against overfitting and get a better sense of how addition works. If you have the time to play around with the dropout, I would be curious whether it works like that.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment