Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save denny0323/80c0a9692af6416756cedcb61cafb485 to your computer and use it in GitHub Desktop.
Save denny0323/80c0a9692af6416756cedcb61cafb485 to your computer and use it in GitHub Desktop.
6_Partial Observability and Deep Recurrent Q-Networks
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Deep Recurrent Q-Network"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook provides an example implementation of a Deep Recurrent Q-Network which can solve Partially Observable Markov Decision Processes."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import random\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"import scipy.misc\n",
"import os\n",
"import csv\n",
"import itertools\n",
"import tensorflow.contrib.slim as slim\n",
"%matplotlib inline\n",
"\n",
"from helper import * # helper is the package including updateing tool for target network"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the game environment"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from gridworld import gameEnv"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Feel free to adjust the size of the gridworld (the `size` argument). Making it smaller provides an easier task for our DRQN agent, while making the world larger increases the challenge.\n",
"\n",
"Initializing the gridworld with `partial=True` limits the agent's field of view, resulting in a partially observable MDP (POMDP).\n",
"Initializing it with `partial=False` shows the agent the entire environment, resulting in a fully observable MDP."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
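"For comparison, the fully observable (MDP) version of the environment would be created with:\n",
"\n",
"```python\n",
"env = gameEnv(partial=False,size=9)\n",
"```"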
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAADMtJREFUeJzt3W+sZHV9x/H3p3tBBGuW5V+2LO1lE4KQpoDZKJQ+sCAtpQZ8oC1EG9PQ8oSm0Jro0j4pTZpo0vjnQWNCQEsayx9XWgkx2s2KaZo0K4tgCywI6BaurOwSQNQHpKvfPpiz9JbOMufunZl7D7/3K7mZOeeeued35uQz55yZud9vqgpJbfmFtR6ApPkz+FKDDL7UIIMvNcjgSw0y+FKDDL7UoFUFP8nlSZ5I8lSS7dMalKTZytF+gSfJBuC7wGXAEvAAcE1VPTa94UmahYVVPPZdwFNV9T2AJHcCVwFHDP7JJ59ci4uLq1ilpDeyb98+XnjhhUxabjXBPx14dtn0EvDuN3rA4uIie/bsWcUqJb2Rbdu29VpuNdf4415V/t91Q5LrkuxJsufgwYOrWJ2kaVlN8JeAM5ZNbwGee/1CVXVLVW2rqm2nnHLKKlYnaVpWE/wHgLOSnJnkWOBq4N7pDEvSLB31NX5VHUryJ8DXgQ3A56vq0amNTNLMrObNParqq8BXpzQWSXPiN/ekBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBk0MfpLPJzmQ5JFl8zYl2Znkye72xNkOU9I09Tni/z1w+evmbQd2VdVZwK5uWtJATAx+Vf0r8OLrZl8F3N7dvx14/5THJWmGjvYa/7Sq2g/Q3Z46vSFJmrWZv7lnJx1p/Tna4D+fZDNAd3vgSAvaSUdaf442+PcCH+nufwT4ynSGI2ke+nycdwfw78DZSZaSXAt8ArgsyZPAZd20pIGY2Emnqq45wq8unfJYJM2J39yTGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGmTwpQYZfKlBBl9qkMGXGtSn9NYZSe5PsjfJo0lu6ObbTUcaqD5H/EPAR6vqHOBC4Pok52I3HWmw+nTS2V9V3+7u/xjYC5yO3XSkwVrRNX6SReACYDc9u+nYUENaf3oHP8nbgC8DN1bVK30fZ0MNaf3pFfwkxzAK/Rer6p5udu9uOpLWlz7v6ge4DdhbVZ9a9iu76UgDNbGhBnAx8AfAfyZ5uJv3F4y659zdddZ5BvjgbIYoadr6dNL5NyBH+LXddKQB8pt7UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDTL4UoMMvtQggy81yOBLDUpVzW9lyfxWNng+VeMdqSaMDquqiU9Sn5p7xyX5VpLvdJ10bu7mn5lkd9dJ564kx05j0JJmr8+p/qvAJVV1HnA+cHmSC4FPAp/uOum8BFw7u2FKmqY+nXSqqn7STR7T/RRwCbCjm28nHWlA+tbV39BV2D0A7ASeBl6uqkPdIkuM2mqNe+xrnXSmMWBJq9cr+FX1s6o6H9gCvAs4Z9xiR3jsa510jn6YkqZpRR/nVdXLwDcZdc3dmORwee4twHPTHZqkWenzrv4pSTZ2998KvJdRx9z7gQ90i9lJRxqQiZ/jJ/k1Rm/ebWD0QnF3Vf11kq3AncAm4CHgw1X16oS/5YfTvflUjefn+JP0+RzfL/CsWz5V4xn8SabyBR5Jbz4GX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUG9g9+V2H4oyX3dtJ10pIFayRH/BkZFNg+zk440UH0bamwBfhe4tZsOdtKRBqvvEf8zwMeAn3fTJ2EnHWmw+tTVfx9woKoeXD57zKJ20pEGYmHyIlwMXJnkCuA44O2MzgA2Jlnojvp20pEGpE+33JuqaktVLQJXA9+oqg9hJx1psFbzOf
7HgT9P8hSja/7bpjMkSbNmJ511y6dqPDvpTGInHUljGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxpk8KUGGXypQQZfapDBlxrUp+YeSfYBPwZ+Bhyqqm1JNgF3AYvAPuD3quql2QxT0jSt5Ij/m1V1/rJquduBXV1DjV3dtKQBWM2p/lWMGmmADTWkQekb/AL+JcmDSa7r5p1WVfsButtTZzFASdPX6xofuLiqnktyKrAzyeN9V9C9UFw3cUFJc7PiKrtJ/gr4CfDHwHuqan+SzcA3q+rsCY+1dGxvPlXjWWV3kqlU2U1yQpJfPHwf+C3gEeBeRo00wIYa0qBMPOIn2Qr8Uze5APxjVf1NkpOAu4FfBp4BPlhVL074Wx7GevOpGs8j/iR9jvg21Fi3fKrGM/iT2FBD0lgGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUF9/ztPc+c31DQ7HvGlBhl8qUEGX2qQwZcaZPClBhl8qUEGX2pQr+An2ZhkR5LHk+xNclGSTUl2Jnmyuz1x1oOVNB19j/ifBb5WVe8AzgP2YicdabD6FNt8O/AdYGstWzjJE1heW1p3plVzbytwEPhCkoeS3NqV2baTjjRQfYK/ALwT+FxVXQD8lBWc1ie5LsmeJHuOcoySpqxP8JeApara3U3vYPRC8Hx3ik93e2Dcg6vqlqratqzLrqQ1NjH4VfVD4Nkkh6/fLwUew0460mD1aqiR5HzgVuBY4HvAHzJ60bCTjrTO2ElHapCddCSNZfClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaZPClBhl8qUEGX2qQwZcaNDH4Sc5O8vCyn1eS3GgnHWm4VlR6K8kG4AfAu4HrgRer6hNJtgMnVtXHJzze0lvSjM2i9NalwNNV9V/AVcDt3fzbgfev8G9JWiMrDf7VwB3dfTvpSAPVO/hJjgWuBL60khXYSUdaf1ZyxP8d4NtV9Xw3bScdaaBWEvxr+N/TfLCTjjRYfTvpHA88y6hV9o+6eSdhJx1p3bGTjtQgO+lIGsvgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNahX8JP8WZJHkzyS5I4kxyU5M8nurpPOXV0VXkkD0KeF1unAnwLbqupXgQ2M6ut/Evh0VZ0FvARcO8uBSpqevqf6C8BbkywAxwP7gUuAHd3v7aQjDcjE4FfVD4C/ZVRJdz/wI+BB4OWqOtQttgScPqtBSpquPqf6JzLqk3cm8EvACYyaa7ze2Aq6dtKR1p+FHsu8F/h+VR0ESHIP8OvAxiQL3VF/C/DcuAdX1S3ALd1jLa8trQN9rvGfAS5McnySMOqY+xhwP/CBbhk76UgD0reTzs3A7wOHgIeAP2J0TX8nsKmb9+GqenXC3/GIL82YnXSkBtlJR9JYBl9qkMGXGmTwpQb1+Rx/ml4AftrdvlmcjNuzXr2ZtgX6bc+v9PlDc31XHyDJnqraNteVzpDbs369mbYFprs9nupLDTL4UoPWIvi3rME6Z8ntWb/eTNsCU9yeuV/jS1p7nupLDZpr8JNcnuSJJE8l2T7Pda9WkjOS3J9kb1d/8IZu/qYkO7vagzu7+gWDkWRDkoeS3NdND7aWYpKNSXYkebzbTxcNef/Mstbl3IKfZAPwd4yKeJwLXJPk3HmtfwoOAR+tqnOAC4Hru/FvB3Z1tQd3ddNDcgOwd9n0kGspfhb4WlW9AziP0XYNcv/MvNZlVc3lB7gI+Pqy6ZuAm+a1/hlsz1eAy4AngM3dvM3AE2s9thVswxZGYbgEuA8Ioy+ILIzbZ+v5B3g78H26962WzR/k/mH0b+/PMvq394Vu//z2tPbPPE/1D2/IYYOt05dkEbgA2A2cVlX7AbrbU9duZCv2GeBjwM+76ZMYbi3FrcBB4AvdpcutSU
5goPunZlzrcp7BH/c/woP7SCHJ24AvAzdW1StrPZ6jleR9wIGqenD57DGLDmUfLQDvBD5XVRcw+mr4IE7rx1ltrctJ5hn8JeCMZdNHrNO3XiU5hlHov1hV93Szn0+yufv9ZuDAWo1vhS4Grkyyj1ElpUsYnQFs7Mqow7D20RKwVFW7u+kdjF4Ihrp/Xqt1WVX/DfyfWpfdMke9f+YZ/AeAs7p3JY9l9EbFvXNc/6p09QZvA/ZW1aeW/epeRjUHYUC1B6vqpqraUlWLjPbFN6rqQwy0lmJV/RB4NsnZ3azDtSEHuX+Yda3LOb9hcQXwXeBp4C/X+g2UFY79NxidVv0H8HD3cwWj6+JdwJPd7aa1HutRbNt7gPu6+1uBbwFPAV8C3rLW41vBdpwP7On20T8DJw55/wA3A48DjwD/ALxlWvvHb+5JDfKbe1KDDL7UIIMvNcjgSw0y+FKDDL7UIIMvNcjgSw36H0A1AqEut0z3AAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x28617d58e48>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# POMDP: partial=True restricts the agent's field of view\n",
"env = gameEnv(size=9, partial=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Above is an example of a starting environment in our simple game. The agent controls the blue square, and can move __up__, __down__, __left__, or __right__.\n",
"The goal is to move to __the green squares__ _(for +1 reward)_ and avoid the __red squares__ _(for -1 reward)_. \n",
"\n",
"When the agent moves through a green or red square, it is randomly moved to a new place in the environment."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Implementing the network itself"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class Qnetwork():\n",
"    \"\"\"Recurrent dueling Q-network (DRQN): conv encoder -> LSTM -> Advantage/Value heads.\n",
"\n",
"    Args:\n",
"        h_size: width of the conv4 output and of the recurrent layer. Must be even,\n",
"            because the RNN output is split in half into the Advantage and Value streams.\n",
"        rnn_cell: TF RNN cell (e.g. BasicLSTMCell) whose output size equals h_size,\n",
"            since the RNN output is reshaped to [-1, h_size].\n",
"        myScope: prefix for the variable scopes of the conv and RNN layers.\n",
"\n",
"    At run time feed scalarInput, trainLength, batch_size and state_in\n",
"    (plus targetQ and actions when running updateModel).\n",
"    \"\"\"\n",
"    def __init__(self, h_size, rnn_cell, myScope):\n",
"        # The network receives a frame from the game, flattened into an array.\n",
"        # It then reshapes it to 84x84x3 and processes it through four convolutional layers.\n",
"        self.scalarInput = tf.placeholder(shape=[None,21168],dtype=tf.float32)\n",
"        self.imageIn = tf.reshape(self.scalarInput,shape=[-1,84,84,3])\n",
"\n",
"        # With VALID padding the spatial size shrinks 84 -> 20 -> 9 -> 7 -> 1,\n",
"        # so conv4 yields one h_size-dimensional feature vector per frame.\n",
"        self.conv1 = slim.convolution2d(inputs=self.imageIn, num_outputs=32, kernel_size=[8,8], stride=[4,4], padding='VALID',\n",
"                                        biases_initializer=None, scope=myScope+'_conv1')\n",
"\n",
"        self.conv2 = slim.convolution2d(inputs=self.conv1, num_outputs=64, kernel_size=[4,4], stride=[2,2], padding='VALID',\n",
"                                        biases_initializer=None, scope=myScope+'_conv2')\n",
"\n",
"        self.conv3 = slim.convolution2d(inputs=self.conv2, num_outputs=64, kernel_size=[3,3], stride=[1,1], padding='VALID',\n",
"                                        biases_initializer=None, scope=myScope+'_conv3')\n",
"\n",
"        self.conv4 = slim.convolution2d(inputs=self.conv3, num_outputs=h_size, kernel_size=[7,7], stride=[1,1], padding='VALID',\n",
"                                        biases_initializer=None, scope=myScope+'_conv4')\n",
"\n",
"        # Number of consecutive steps per trace (fed as 1 when acting, trace_length when training).\n",
"        self.trainLength = tf.placeholder(dtype=tf.int32)\n",
"\n",
"        # We take the output from the final convolutional layer and send it to a recurrent layer.\n",
"        # The input must be reshaped into [batch x trace x units] for rnn processing,\n",
"        # and then returned to [(batch*trace) x units] when sent through the upper levels.\n",
"        self.batch_size = tf.placeholder(dtype=tf.int32, shape=[])\n",
"        self.convFlat = tf.reshape(slim.flatten(self.conv4), [self.batch_size, self.trainLength, h_size])\n",
"\n",
"        # Default initial RNN state is all zeros; callers may feed state_in to override it.\n",
"        self.state_in = rnn_cell.zero_state(self.batch_size, tf.float32)\n",
"        self.rnn, self.rnn_state = tf.nn.dynamic_rnn(\n",
"            inputs=self.convFlat, cell=rnn_cell, dtype=tf.float32, initial_state=self.state_in, scope=myScope+'_rnn')\n",
"        self.rnn = tf.reshape(self.rnn, shape=[-1, h_size])\n",
"\n",
"        # The output from the recurrent layer is then split into separate Advantage and Value streams.\n",
"        self.streamA, self.streamV = tf.split(self.rnn, 2, 1)\n",
"        self.AW = tf.Variable(tf.random_normal([h_size//2, 4]))\n",
"        self.VW = tf.Variable(tf.random_normal([h_size//2, 1]))\n",
"        self.Advantage = tf.matmul(self.streamA, self.AW)\n",
"        self.Value = tf.matmul(self.streamV, self.VW)\n",
"\n",
"        # Gradient of the advantages w.r.t. the input image (useful for saliency maps).\n",
"        self.salience = tf.gradients(self.Advantage, self.imageIn)\n",
"\n",
"        # Then combine them together to get our final Q-values: Q = V + (A - mean_a(A)).\n",
"        self.Qout = self.Value + tf.subtract(self.Advantage, tf.reduce_mean(self.Advantage, axis=1, keep_dims=True))\n",
"        self.predict = tf.argmax(self.Qout, 1)\n",
"\n",
"        # Below we obtain the loss as the (masked) mean squared difference between the target\n",
"        # and predicted Q-values of the actions actually taken.\n",
"        self.targetQ = tf.placeholder(shape=[None], dtype=tf.float32)\n",
"        self.actions = tf.placeholder(shape=[None], dtype=tf.int32)\n",
"        self.actions_onehot = tf.one_hot(self.actions, 4, dtype=tf.float32)\n",
"\n",
"        self.Q = tf.reduce_sum(tf.multiply(self.Qout, self.actions_onehot), axis=1)\n",
"\n",
"        self.td_error = tf.square(self.targetQ - self.Q)\n",
"\n",
"        # In order to only propagate accurate gradients through the network, we mask the first\n",
"        # half of the losses for each trace as per Lample & Chaplot 2016.\n",
"        # The ones-part covers the remaining trainLength - trainLength//2 steps so the mask has\n",
"        # exactly trainLength columns (and matches td_error) even when trainLength is odd.\n",
"        masked_steps = self.trainLength // 2\n",
"        self.maskA = tf.zeros([self.batch_size, masked_steps])\n",
"        self.maskB = tf.ones([self.batch_size, self.trainLength - masked_steps])\n",
"        self.mask = tf.concat([self.maskA, self.maskB], 1)\n",
"        self.mask = tf.reshape(self.mask, [-1])\n",
"        self.loss = tf.reduce_mean(self.td_error * self.mask)\n",
"\n",
"        self.trainer = tf.train.AdamOptimizer(learning_rate=0.0001)\n",
"        self.updateModel = self.trainer.minimize(self.loss)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Experience Replay"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# To store experiences and sample them randomly to train the network.\n",
"class experience_buffer():\n",
"    \"\"\"FIFO replay memory of whole episodes, sampled as fixed-length traces for the DRQN.\"\"\"\n",
"    def __init__(self, buffer_size = 1000):\n",
"        # Episodes are stored oldest-first.\n",
"        self.buffer = []\n",
"        self.buffer_size = buffer_size\n",
"\n",
"    def add(self, experience):\n",
"        \"\"\"Append one episode, first evicting the oldest ones so that (for buffer_size >= 1)\n",
"        the buffer holds at most buffer_size episodes afterwards.\"\"\"\n",
"        overflow = len(self.buffer) + 1 - self.buffer_size\n",
"        if overflow >= 0:\n",
"            del self.buffer[:overflow]\n",
"        self.buffer.append(experience)\n",
"\n",
"    def sample(self, batch_size, trace_length):\n",
"        \"\"\"Pick batch_size distinct episodes at random and cut one random window of\n",
"        trace_length consecutive steps from each.\n",
"\n",
"        Returns an array of shape [batch_size*trace_length, 5], grouped by trace and in\n",
"        time order within each trace. Raises ValueError if batch_size exceeds the number\n",
"        of stored episodes or if a sampled episode is shorter than trace_length.\n",
"        \"\"\"\n",
"        traces = []\n",
"        for episode in random.sample(self.buffer, batch_size):\n",
"            start = np.random.randint(0, len(episode) + 1 - trace_length)\n",
"            traces.append(episode[start:start + trace_length])\n",
"        return np.reshape(np.array(traces), [batch_size * trace_length, 5])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training the network"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# ---- Training hyper-parameters ----\n",
"# Replay / update schedule\n",
"batch_size = 4         # Number of traces (episodes) sampled for each training step.\n",
"trace_length = 8       # Number of consecutive steps in each sampled trace.\n",
"update_freq = 5        # Run a training step (and target update) every update_freq environment steps.\n",
"y = 0.99               # Discount factor applied to the target Q-values.\n",
"tau = 0.001            # Target-network update rate passed to updateTargetGraph.\n",
"\n",
"# Exploration schedule (epsilon-greedy)\n",
"startE = 1             # Initial probability of taking a random action.\n",
"endE = 0.1             # Final probability of taking a random action.\n",
"annealing_steps = 10000  # Training steps over which epsilon decays linearly from startE to endE.\n",
"pre_train_steps = 10000  # Steps of purely random actions before training begins.\n",
"\n",
"# Episodes\n",
"num_episodes = 10000   # Number of game episodes to train on.\n",
"max_epLength = 50      # Maximum number of steps per episode.\n",
"\n",
"# Model / bookkeeping\n",
"h_size = 512           # Size of the final conv layer / LSTM, later split into Advantage and Value streams.\n",
"load_model = False     # Whether to resume from a saved checkpoint in `path`.\n",
"path = \"./drqn\"       # Directory where checkpoints are saved.\n",
"time_per_step = 1      # Length of each step used in gif creation (not used in this notebook).\n",
"summaryLength = 100    # After every summaryLength episodes, print the mean reward over those episodes."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5000 0.63 1\n",
"10000 0.51 1\n",
"The Target network is updated.\n",
"15000 1.03 0.5499999999998275\n",
"The Target network is updated.\n",
"20000 2.21 0.09999999999985551\n",
"The Target network is updated.\n",
"25000 2.7 0.09999999999985551\n",
"The Target network is updated.\n",
"30000 2.64 0.09999999999985551\n",
"The Target network is updated.\n",
"35000 2.46 0.09999999999985551\n",
"The Target network is updated.\n",
"40000 3.19 0.09999999999985551\n",
"The Target network is updated.\n",
"45000 3.12 0.09999999999985551\n",
"The Target network is updated.\n",
"50000 3.02 0.09999999999985551\n",
"Saved Model\n",
"The Target network is updated.\n",
"55000 2.57 0.09999999999985551\n",
"The Target network is updated.\n",
"60000 3.02 0.09999999999985551\n",
"The Target network is updated.\n",
"65000 3.31 0.09999999999985551\n",
"The Target network is updated.\n",
"70000 2.99 0.09999999999985551\n",
"The Target network is updated.\n",
"75000 2.99 0.09999999999985551\n",
"The Target network is updated.\n",
"80000 3.36 0.09999999999985551\n",
"The Target network is updated.\n",
"85000 3.94 0.09999999999985551\n",
"The Target network is updated.\n",
"90000 4.26 0.09999999999985551\n",
"The Target network is updated.\n",
"95000 4.48 0.09999999999985551\n",
"The Target network is updated.\n",
"100000 5.05 0.09999999999985551\n",
"Saved Model\n",
"The Target network is updated.\n",
"105000 5.51 0.09999999999985551\n",
"The Target network is updated.\n",
"110000 5.1 0.09999999999985551\n",
"The Target network is updated.\n",
"115000 5.36 0.09999999999985551\n",
"The Target network is updated.\n",
"120000 5.24 0.09999999999985551\n",
"The Target network is updated.\n",
"125000 4.76 0.09999999999985551\n",
"The Target network is updated.\n",
"130000 4.59 0.09999999999985551\n",
"The Target network is updated.\n",
"135000 5.29 0.09999999999985551\n",
"The Target network is updated.\n",
"140000 5.32 0.09999999999985551\n",
"The Target network is updated.\n",
"145000 5.86 0.09999999999985551\n",
"The Target network is updated.\n",
"150000 4.81 0.09999999999985551\n",
"Saved Model\n",
"The Target network is updated.\n",
"155000 5.26 0.09999999999985551\n",
"The Target network is updated.\n",
"160000 5.89 0.09999999999985551\n",
"The Target network is updated.\n",
"165000 5.51 0.09999999999985551\n",
"The Target network is updated.\n",
"170000 5.04 0.09999999999985551\n",
"The Target network is updated.\n",
"175000 4.95 0.09999999999985551\n",
"The Target network is updated.\n",
"180000 5.17 0.09999999999985551\n",
"The Target network is updated.\n",
"185000 6.1 0.09999999999985551\n",
"The Target network is updated.\n",
"190000 5.46 0.09999999999985551\n",
"The Target network is updated.\n",
"195000 5.81 0.09999999999985551\n",
"The Target network is updated.\n",
"200000 5.51 0.09999999999985551\n",
"Saved Model\n",
"The Target network is updated.\n",
"205000 4.67 0.09999999999985551\n",
"The Target network is updated.\n",
"210000 5.25 0.09999999999985551\n",
"The Target network is updated.\n",
"215000 5.56 0.09999999999985551\n",
"The Target network is updated.\n",
"220000 5.46 0.09999999999985551\n",
"The Target network is updated.\n",
"225000 5.48 0.09999999999985551\n",
"The Target network is updated.\n",
"230000 5.85 0.09999999999985551\n",
"The Target network is updated.\n",
"235000 6.18 0.09999999999985551\n",
"The Target network is updated.\n",
"240000 5.22 0.09999999999985551\n",
"The Target network is updated.\n",
"245000 5.74 0.09999999999985551\n",
"The Target network is updated.\n",
"250000 5.38 0.09999999999985551\n",
"Saved Model\n",
"The Target network is updated.\n",
"255000 5.58 0.09999999999985551\n",
"The Target network is updated.\n",
"260000 5.07 0.09999999999985551\n",
"The Target network is updated.\n",
"265000 5.76 0.09999999999985551\n",
"The Target network is updated.\n",
"270000 5.79 0.09999999999985551\n",
"The Target network is updated.\n",
"275000 5.91 0.09999999999985551\n",
"The Target network is updated.\n",
"280000 6.05 0.09999999999985551\n",
"The Target network is updated.\n",
"285000 5.88 0.09999999999985551\n",
"The Target network is updated.\n",
"290000 5.21 0.09999999999985551\n",
"The Target network is updated.\n",
"295000 5.23 0.09999999999985551\n",
"The Target network is updated.\n",
"300000 5.61 0.09999999999985551\n",
"Saved Model\n",
"The Target network is updated.\n",
"305000 5.24 0.09999999999985551\n",
"The Target network is updated.\n",
"310000 4.79 0.09999999999985551\n",
"The Target network is updated.\n",
"315000 5.31 0.09999999999985551\n",
"The Target network is updated.\n",
"320000 5.42 0.09999999999985551\n",
"The Target network is updated.\n",
"325000 5.04 0.09999999999985551\n",
"The Target network is updated.\n",
"330000 5.69 0.09999999999985551\n",
"The Target network is updated.\n",
"335000 5.39 0.09999999999985551\n",
"The Target network is updated.\n",
"340000 5.67 0.09999999999985551\n",
"The Target network is updated.\n",
"345000 5.28 0.09999999999985551\n",
"The Target network is updated.\n",
"350000 5.4 0.09999999999985551\n",
"Saved Model\n",
"The Target network is updated.\n",
"355000 5.62 0.09999999999985551\n",
"The Target network is updated.\n",
"360000 5.7 0.09999999999985551\n",
"The Target network is updated.\n",
"365000 5.65 0.09999999999985551\n",
"The Target network is updated.\n",
"370000 5.52 0.09999999999985551\n",
"The Target network is updated.\n",
"375000 5.8 0.09999999999985551\n",
"The Target network is updated.\n",
"380000 5.17 0.09999999999985551\n",
"The Target network is updated.\n",
"385000 5.51 0.09999999999985551\n",
"The Target network is updated.\n",
"390000 5.61 0.09999999999985551\n",
"The Target network is updated.\n",
"395000 5.84 0.09999999999985551\n",
"The Target network is updated.\n",
"400000 5.8 0.09999999999985551\n",
"Saved Model\n",
"The Target network is updated.\n",
"405000 5.41 0.09999999999985551\n",
"The Target network is updated.\n",
"410000 5.97 0.09999999999985551\n",
"The Target network is updated.\n",
"415000 5.93 0.09999999999985551\n",
"The Target network is updated.\n",
"420000 6.41 0.09999999999985551\n",
"The Target network is updated.\n",
"425000 5.96 0.09999999999985551\n",
"The Target network is updated.\n",
"430000 6.16 0.09999999999985551\n",
"The Target network is updated.\n",
"435000 5.98 0.09999999999985551\n",
"The Target network is updated.\n",
"440000 5.83 0.09999999999985551\n",
"The Target network is updated.\n",
"445000 6.19 0.09999999999985551\n",
"The Target network is updated.\n",
"450000 5.54 0.09999999999985551\n",
"Saved Model\n",
"The Target network is updated.\n",
"455000 5.44 0.09999999999985551\n",
"The Target network is updated.\n",
"460000 5.52 0.09999999999985551\n",
"The Target network is updated.\n",
"465000 5.56 0.09999999999985551\n",
"The Target network is updated.\n",
"470000 5.7 0.09999999999985551\n",
"The Target network is updated.\n",
"475000 5.28 0.09999999999985551\n",
"The Target network is updated.\n",
"480000 5.97 0.09999999999985551\n",
"The Target network is updated.\n",
"485000 6.09 0.09999999999985551\n",
"The Target network is updated.\n",
"490000 5.77 0.09999999999985551\n",
"The Target network is updated.\n",
"495000 6.1 0.09999999999985551\n",
"The Target network is updated.\n",
"500000 5.37 0.09999999999985551\n"
]
}
],
"source": [
"tf.reset_default_graph()\n",
"\n",
"# We define the cells for the primary and target q-networks\n",
"cell = tf.contrib.rnn.BasicLSTMCell(num_units=h_size, state_is_tuple=True)\n",
"cellT = tf.contrib.rnn.BasicLSTMCell(num_units=h_size, state_is_tuple=True)\n",
"mainQN = Qnetwork(h_size, cell, 'main')\n",
"targetQN = Qnetwork(h_size, cellT, 'target')\n",
"\n",
"init = tf.global_variables_initializer()\n",
"\n",
"saver = tf.train.Saver(max_to_keep = 5)\n",
"trainables = tf.trainable_variables()\n",
"# NOTE(review): updateTargetGraph lives in helper (not shown); it presumably pairs the first half of\n",
"# `trainables` (mainQN) with the second half (targetQN), so mainQN must be built before targetQN.\n",
"targetOps = updateTargetGraph(trainables, tau)\n",
"myBuffer = experience_buffer()\n",
"\n",
"# Set the rate of random action decrease.\n",
"e = startE\n",
"stepDrop = (startE - endE)/annealing_steps\n",
"\n",
"# create lists to contain total rewards and steps per episode\n",
"jList = []\n",
"rList = []\n",
"total_steps = 0\n",
"\n",
"# Make a path for our model to be saved in.\n",
"if not os.path.exists(path):\n",
"    os.makedirs(path)\n",
"\n",
"with tf.Session() as sess:\n",
"    if load_model == True:\n",
"        print ('Loading Model...')\n",
"        # load the model\n",
"        ckpt = tf.train.get_checkpoint_state(path)\n",
"        if ckpt is None or not ckpt.model_checkpoint_path:\n",
"            raise FileNotFoundError('No checkpoint found in ' + path)\n",
"        saver.restore(sess, ckpt.model_checkpoint_path)\n",
"    else:\n",
"        # Initialise only when starting fresh: running `init` after `restore`\n",
"        # would overwrite the restored weights with new random values.\n",
"        sess.run(init)\n",
"\n",
"    # Move the target network toward the primary network before training starts.\n",
"    # NOTE(review): with tau < 1 this is presumably a soft update rather than an exact copy - confirm in helper.\n",
"    updateTarget(targetOps, sess)\n",
"\n",
"    for i in range(num_episodes):\n",
"        episodeBuffer = []\n",
"\n",
"        # Reset environment and get first new observation\n",
"        sP = env.reset()\n",
"        s = processState(sP)\n",
"        d = False\n",
"        rAll = 0\n",
"        j = 0\n",
"        state = (np.zeros([1, h_size]),np.zeros([1, h_size])) # Reset the recurrent layer's hidden state\n",
"\n",
"        # The Q-Network\n",
"        while j < max_epLength:\n",
"            j+=1\n",
"\n",
"            # Choose an action greedily (with e chance of random action) from the Q-network.\n",
"            # The RNN is run even for random actions so its hidden state keeps tracking the episode.\n",
"            if np.random.rand(1) < e or total_steps < pre_train_steps:\n",
"                state1 = sess.run(mainQN.rnn_state,\n",
"                    feed_dict={mainQN.scalarInput:[s/255.0], mainQN.trainLength:1, mainQN.state_in:state, mainQN.batch_size:1})\n",
"                a = np.random.randint(0,4)\n",
"            else:\n",
"                a, state1 = sess.run([mainQN.predict,mainQN.rnn_state],\n",
"                    feed_dict={mainQN.scalarInput:[s/255.0], mainQN.trainLength:1, mainQN.state_in:state, mainQN.batch_size:1})\n",
"                a = a[0]\n",
"            s1P, r, d = env.step(a)\n",
"            s1 = processState(s1P)\n",
"            total_steps += 1\n",
"            episodeBuffer.append(np.reshape(np.array([s,a,r,s1,d]), [1,5]))\n",
"\n",
"            if total_steps > pre_train_steps:\n",
"                # Anneal epsilon linearly, clamped so it never drops below endE.\n",
"                if e > endE:\n",
"                    e = max(endE, e - stepDrop)\n",
"\n",
"                # Progress marker only: the target network is actually updated every update_freq steps below.\n",
"                if total_steps % (update_freq*1000) == 0:\n",
"                    print('The Target network is updated.')\n",
"\n",
"                if total_steps % (update_freq) == 0:\n",
"                    updateTarget(targetOps,sess)\n",
"                    # Reset the recurrent layer's hidden state\n",
"                    state_train = (np.zeros([batch_size,h_size]), np.zeros([batch_size,h_size]))\n",
"                    trainBatch = myBuffer.sample(batch_size, trace_length) # Get a random batch of experiences.\n",
"\n",
"                    # Below we perform the Double-DQN update to the target Q-values:\n",
"                    # the main network chooses the next action, the target network evaluates it.\n",
"                    Q1 = sess.run(mainQN.predict,feed_dict={\n",
"                        mainQN.scalarInput:np.vstack(trainBatch[:,3]/255.0),\n",
"                        mainQN.trainLength:trace_length, mainQN.state_in:state_train, mainQN.batch_size:batch_size})\n",
"\n",
"                    Q2 = sess.run(targetQN.Qout,feed_dict={\n",
"                        targetQN.scalarInput:np.vstack(trainBatch[:,3]/255.0),\n",
"                        targetQN.trainLength:trace_length, targetQN.state_in:state_train, targetQN.batch_size:batch_size})\n",
"\n",
"                    end_multiplier = -(trainBatch[:,4] - 1) # 0 for terminal steps (d=True), 1 otherwise\n",
"                    doubleQ = Q2[range(batch_size*trace_length),Q1]\n",
"                    targetQ = trainBatch[:,2] + (y*doubleQ * end_multiplier)\n",
"\n",
"                    # Update the network with our target values.\n",
"                    sess.run(mainQN.updateModel,\n",
"                        feed_dict={mainQN.scalarInput:np.vstack(trainBatch[:,0]/255.0), mainQN.targetQ:targetQ,\n",
"                                   mainQN.actions:trainBatch[:,1], mainQN.trainLength:trace_length,\n",
"                                   mainQN.state_in:state_train, mainQN.batch_size:batch_size})\n",
"            rAll += r\n",
"            s = s1\n",
"            sP = s1P\n",
"            state = state1\n",
"\n",
"            if d == True:\n",
"                break\n",
"\n",
"        # Add the episode to the experience buffer\n",
"        bufferArray = np.array(episodeBuffer)\n",
"        episodeBuffer = list(zip(bufferArray))\n",
"        myBuffer.add(episodeBuffer)\n",
"        jList.append(j)\n",
"        rList.append(rAll)\n",
"\n",
"        # Periodically save the model.\n",
"        if i % 1000 == 0 and i != 0:\n",
"            saver.save(sess,path+'/model-'+str(i)+'.ckpt')\n",
"            print(\"Saved Model\")\n",
"\n",
"        if len(rList) % summaryLength == 0 and len(rList) != 0:\n",
"            print(total_steps, np.mean(rList[-summaryLength:]), e)\n",
"\n",
"    # save model\n",
"    saver.save(sess,path+'/model-'+str(i)+'.ckpt')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Testing the network"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# ---- Evaluation settings ----\n",
"path = \"./drqn\"       # Directory to load the trained checkpoint from.\n",
"load_model = True      # Restore trained weights from `path` (False = evaluate a freshly initialised network).\n",
"h_size = 512           # Must match the h_size used for training so the checkpoint shapes agree.\n",
"\n",
"e = 0.01               # Probability of choosing a random action during evaluation.\n",
"num_episodes = 10000   # Number of game episodes to evaluate on.\n",
"max_epLength = 50      # Maximum number of steps per episode.\n",
"time_per_step = 1      # Length of each step used in gif creation (not used in this notebook).\n",
"summaryLength = 100    # After every summaryLength episodes, print the mean reward over those episodes."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading Model...\n",
"INFO:tensorflow:Restoring parameters from ./drqn\\model-9999.ckpt\n",
"5000 6.18 0.01\n",
"10000 5.66 0.01\n",
"15000 5.64 0.01\n",
"20000 5.79 0.01\n",
"25000 5.75 0.01\n",
"30000 6.56 0.01\n",
"35000 6.18 0.01\n",
"40000 5.64 0.01\n",
"45000 6.01 0.01\n",
"50000 6.75 0.01\n",
"55000 6.07 0.01\n",
"60000 5.62 0.01\n",
"65000 5.68 0.01\n",
"70000 6.37 0.01\n",
"75000 6.13 0.01\n",
"80000 6.21 0.01\n",
"85000 5.66 0.01\n",
"90000 6.26 0.01\n",
"95000 6.09 0.01\n",
"100000 6.04 0.01\n",
"105000 6.11 0.01\n",
"110000 6.15 0.01\n",
"115000 6.18 0.01\n",
"120000 6.46 0.01\n",
"125000 5.9 0.01\n",
"130000 5.61 0.01\n",
"135000 6.03 0.01\n",
"140000 6.39 0.01\n",
"145000 6.57 0.01\n",
"150000 6.07 0.01\n",
"155000 5.8 0.01\n",
"160000 5.62 0.01\n",
"165000 5.84 0.01\n",
"170000 6.83 0.01\n",
"175000 5.5 0.01\n",
"180000 6.05 0.01\n",
"185000 6.23 0.01\n",
"190000 5.71 0.01\n",
"195000 6.16 0.01\n",
"200000 6.09 0.01\n",
"205000 5.56 0.01\n",
"210000 5.98 0.01\n",
"215000 5.73 0.01\n",
"220000 5.96 0.01\n",
"225000 6.02 0.01\n",
"230000 6.53 0.01\n",
"235000 6.1 0.01\n",
"240000 6.29 0.01\n",
"245000 6.12 0.01\n",
"250000 5.76 0.01\n",
"255000 6.0 0.01\n",
"260000 5.86 0.01\n",
"265000 6.65 0.01\n",
"270000 5.72 0.01\n",
"275000 5.86 0.01\n",
"280000 6.04 0.01\n",
"285000 6.2 0.01\n",
"290000 6.65 0.01\n",
"295000 6.43 0.01\n",
"300000 6.11 0.01\n",
"305000 6.05 0.01\n",
"310000 5.63 0.01\n",
"315000 6.16 0.01\n",
"320000 5.5 0.01\n",
"325000 6.37 0.01\n",
"330000 5.68 0.01\n",
"335000 6.82 0.01\n",
"340000 5.8 0.01\n",
"345000 6.2 0.01\n",
"350000 6.22 0.01\n",
"355000 6.08 0.01\n",
"360000 6.32 0.01\n",
"365000 6.27 0.01\n",
"370000 5.59 0.01\n",
"375000 6.12 0.01\n",
"380000 6.28 0.01\n",
"385000 6.5 0.01\n",
"390000 5.83 0.01\n",
"395000 5.9 0.01\n",
"400000 6.24 0.01\n",
"405000 6.21 0.01\n",
"410000 5.59 0.01\n",
"415000 6.15 0.01\n",
"420000 5.96 0.01\n",
"425000 6.28 0.01\n",
"430000 5.82 0.01\n",
"435000 5.79 0.01\n",
"440000 5.6 0.01\n",
"445000 6.04 0.01\n",
"450000 6.01 0.01\n",
"455000 6.02 0.01\n",
"460000 5.89 0.01\n",
"465000 5.74 0.01\n",
"470000 6.37 0.01\n",
"475000 6.04 0.01\n",
"480000 6.08 0.01\n",
"485000 5.53 0.01\n",
"490000 6.4 0.01\n",
"495000 6.15 0.01\n",
"500000 5.85 0.01\n",
"Percent of succesful episodes: 6.0424%\n"
]
}
],
"source": [
"tf.reset_default_graph()\n",
"cell = tf.contrib.rnn.BasicLSTMCell(num_units=h_size,state_is_tuple=True)\n",
"cellT = tf.contrib.rnn.BasicLSTMCell(num_units=h_size,state_is_tuple=True)\n",
"mainQN = Qnetwork(h_size,cell,'main')\n",
"targetQN = Qnetwork(h_size,cellT,'target')\n",
"\n",
"init = tf.global_variables_initializer()\n",
"\n",
"saver = tf.train.Saver(max_to_keep = 2)\n",
"\n",
"# create lists to contain total rewards and steps per episode\n",
"jList = []\n",
"rList = []\n",
"total_steps = 0\n",
"\n",
"# Make a path for our model to be saved in.\n",
"if not os.path.exists(path):\n",
"    os.makedirs(path)\n",
"\n",
"with tf.Session() as sess:\n",
"    if load_model == True:\n",
"        print ('Loading Model...')\n",
"        ckpt = tf.train.get_checkpoint_state(path)\n",
"        if ckpt is None or not ckpt.model_checkpoint_path:\n",
"            raise FileNotFoundError('No checkpoint found in ' + path)\n",
"        saver.restore(sess,ckpt.model_checkpoint_path)\n",
"    else:\n",
"        sess.run(init)\n",
"\n",
"    for i in range(num_episodes):\n",
"        # Reset environment and get first new observation\n",
"        sP = env.reset()\n",
"        s = processState(sP)\n",
"        d = False\n",
"        rAll = 0\n",
"        j = 0\n",
"\n",
"        # Reset the recurrent layer's hidden state at the start of each episode\n",
"        state = (np.zeros([1,h_size]),np.zeros([1,h_size]))\n",
"\n",
"        # The Q-Network\n",
"        while j < max_epLength: # End the trial once the agent has taken max_epLength moves.\n",
"            j+=1\n",
"\n",
"            # Choose an action greedily (with e chance of random action) from the Q-network\n",
"            if np.random.rand(1) < e:\n",
"                state1 = sess.run(mainQN.rnn_state,\n",
"                    feed_dict={mainQN.scalarInput:[s/255.0],mainQN.trainLength:1,mainQN.state_in:state,mainQN.batch_size:1})\n",
"                a = np.random.randint(0,4)\n",
"            else:\n",
"                a, state1 = sess.run([mainQN.predict,mainQN.rnn_state],\n",
"                    feed_dict={mainQN.scalarInput:[s/255.0],mainQN.trainLength:1,\n",
"                               mainQN.state_in:state,mainQN.batch_size:1})\n",
"                a = a[0]\n",
"\n",
"            s1P,r,d = env.step(a)\n",
"            s1 = processState(s1P)\n",
"            total_steps += 1\n",
"            rAll += r\n",
"            s = s1\n",
"            sP = s1P\n",
"            state = state1\n",
"\n",
"            if d == True:\n",
"                break\n",
"\n",
"        jList.append(j)\n",
"        rList.append(rAll)\n",
"\n",
"        # Periodically report the mean reward of the last summaryLength episodes.\n",
"        if len(rList) % summaryLength == 0 and len(rList) != 0:\n",
"            print (total_steps,np.mean(rList[-summaryLength:]), e)\n",
"\n",
"# rAll is an episode's summed reward, not a success flag, so report the mean reward per episode.\n",
"print (\"Average reward per episode: \" + str(sum(rList) / max(len(rList), 1)))"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment