{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Simple Reinforcement Learning in Tensorflow Part 2-b:\n",
"## Vanilla Policy Gradient Agent\n",
"This tutorial contains a simple example of how to build a policy-gradient based agent that can solve the CartPole problem. For more information, see this [Medium post](https://medium.com/@awjuliani/super-simple-reinforcement-learning-tutorial-part-2-ded33892c724#.mtwpvfi8b). This implementation is generalizable to more than two actions.\n",
"\n",
"For more Reinforcement Learning algorithms, including DQN and Model-based learning in Tensorflow, see my Github repo, [DeepRL-Agents](https://github.com/awjuliani/DeepRL-Agents)."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import tensorflow.contrib.slim as slim\n",
"import numpy as np\n",
"import gym\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"# Python 2/3 compatibility: fall back to range when xrange is absent.\n",
"try:\n",
"    xrange = xrange\n",
"except NameError:\n",
"    xrange = range"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2017-03-09 18:45:39,894] Making new env: CartPole-v0\n"
]
}
],
"source": [
"env = gym.make('CartPole-v0')"
]
},
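{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check, added for this write-up: `CartPole-v0` observations are 4-vectors and the action space has 2 discrete actions, which is where the `s_size=4` and `a_size=2` arguments below come from."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Added sanity check: the environment's spaces explain the network sizes used below.\n",
"print(env.observation_space.shape)  # (4,) -> s_size=4\n",
"print(env.action_space.n)           # 2    -> a_size=2"
]
},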
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"### The Policy-Based Agent"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"gamma = 0.99\n",
"\n",
"def discount_rewards(r):\n",
"    \"\"\" take 1D float array of rewards and compute discounted reward \"\"\"\n",
"    discounted_r = np.zeros_like(r)\n",
"    running_add = 0\n",
"    # Walk backwards through the episode, accumulating the gamma-discounted return.\n",
"    for t in reversed(xrange(0, r.size)):\n",
"        running_add = running_add * gamma + r[t]\n",
"        discounted_r[t] = running_add\n",
"    return discounted_r"
]
},
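{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small worked example, added for illustration: with `gamma = 0.99`, three rewards of 1.0 become `[2.9701, 1.99, 1.0]`, since each entry is that step's reward plus `gamma` times the discounted return of everything after it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Added illustration: each entry is r[t] + gamma * (discounted sum of later rewards).\n",
"print(discount_rewards(np.array([1.0, 1.0, 1.0])))\n",
"# Expected: [ 2.9701  1.99    1.    ], since 1 + 0.99 * (1 + 0.99 * 1) = 2.9701"
]
},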
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"class agent():\n",
"    def __init__(self, lr, s_size, a_size, h_size):\n",
"        #These lines establish the feed-forward part of the network. The agent takes a state and produces an action.\n",
"        self.state_in = tf.placeholder(shape=[None, s_size], dtype=tf.float32)\n",
"        hidden = slim.fully_connected(self.state_in, h_size, biases_initializer=None, activation_fn=tf.nn.relu)\n",
"        self.output = slim.fully_connected(hidden, a_size, activation_fn=tf.nn.softmax, biases_initializer=None)\n",
"        self.chosen_action = tf.argmax(self.output, 1)\n",
"\n",
"        #The next six lines establish the training procedure. We feed the reward and chosen action into the network\n",
"        #to compute the loss, and use it to update the network.\n",
"        self.reward_holder = tf.placeholder(shape=[None], dtype=tf.float32)\n",
"        self.action_holder = tf.placeholder(shape=[None], dtype=tf.int32)\n",
"\n",
"        #Flatten the (batch, actions) output and pick out the probability of each chosen action.\n",
"        self.indexes = tf.range(0, tf.shape(self.output)[0]) * tf.shape(self.output)[1] + self.action_holder\n",
"        self.responsible_outputs = tf.gather(tf.reshape(self.output, [-1]), self.indexes)\n",
"\n",
"        self.loss = -tf.reduce_mean(tf.log(self.responsible_outputs) * self.reward_holder)\n",
"\n",
"        #One gradient placeholder per trainable variable, so gradients can be accumulated across episodes.\n",
"        tvars = tf.trainable_variables()\n",
"        self.gradient_holders = []\n",
"        for idx, var in enumerate(tvars):\n",
"            placeholder = tf.placeholder(tf.float32, name=str(idx) + '_holder')\n",
"            self.gradient_holders.append(placeholder)\n",
"\n",
"        self.gradients = tf.gradients(self.loss, tvars)\n",
"\n",
"        optimizer = tf.train.AdamOptimizer(learning_rate=lr)\n",
"        self.update_batch = optimizer.apply_gradients(zip(self.gradient_holders, tvars))"
]
},
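{
"cell_type": "markdown",
"metadata": {},
"source": [
"An added numpy sketch of the indexing trick used in `agent`: flattening the `(batch, actions)` probability matrix row-major means that `row * n_actions + action` addresses the probability of each action actually taken, which is what `tf.gather` pulls out for the loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Added illustration of the flatten-and-gather trick, in numpy rather than TensorFlow.\n",
"probs = np.array([[0.7, 0.3],\n",
"                  [0.4, 0.6]])  # batch of 2 states, 2 actions each\n",
"actions = np.array([1, 0])      # actions actually taken in each state\n",
"indexes = np.arange(probs.shape[0]) * probs.shape[1] + actions\n",
"print(probs.reshape(-1)[indexes])  # [ 0.3  0.4]"
]
},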
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"### Training the Agent"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"16.0\n",
"21.47\n",
"25.57\n",
"38.03\n",
"43.59\n",
"53.05\n",
"67.38\n",
"90.44\n",
"120.19\n",
"131.75\n",
"162.65\n",
"156.48\n",
"168.18\n",
"181.43\n"
]
}
],
"source": [
"tf.reset_default_graph() #Clear the Tensorflow graph.\n",
"\n",
"myAgent = agent(lr=1e-2, s_size=4, a_size=2, h_size=8) #Load the agent.\n",
"\n",
"total_episodes = 5000 #Set total number of episodes to train agent on.\n",
"max_ep = 999\n",
"update_frequency = 5\n",
"\n",
"init = tf.global_variables_initializer()\n",
"\n",
"# Launch the tensorflow graph\n",
"with tf.Session() as sess:\n",
"    sess.run(init)\n",
"    i = 0\n",
"    total_reward = []\n",
"    total_length = []\n",
"\n",
"    #Buffer to accumulate gradients across episodes, initialized to zero.\n",
"    gradBuffer = sess.run(tf.trainable_variables())\n",
"    for ix, grad in enumerate(gradBuffer):\n",
"        gradBuffer[ix] = grad * 0\n",
"\n",
"    while i < total_episodes:\n",
"        s = env.reset()\n",
"        running_reward = 0\n",
"        ep_history = []\n",
"        for j in range(max_ep):\n",
"            #Probabilistically pick an action given our network outputs.\n",
"            a_dist = sess.run(myAgent.output, feed_dict={myAgent.state_in: [s]})\n",
"            a = np.random.choice(a_dist[0], p=a_dist[0])\n",
"            a = np.argmax(a_dist == a) #Map the sampled probability back to its action index.\n",
"\n",
"            s1, r, d, _ = env.step(a) #Step the environment and collect the reward.\n",
"            ep_history.append([s, a, r, s1])\n",
"            s = s1\n",
"            running_reward += r\n",
"            if d:\n",
"                #Update the network.\n",
"                ep_history = np.array(ep_history)\n",
"                ep_history[:, 2] = discount_rewards(ep_history[:, 2])\n",
"                feed_dict = {myAgent.reward_holder: ep_history[:, 2],\n",
"                             myAgent.action_holder: ep_history[:, 1],\n",
"                             myAgent.state_in: np.vstack(ep_history[:, 0])}\n",
"                grads = sess.run(myAgent.gradients, feed_dict=feed_dict)\n",
"                for idx, grad in enumerate(grads):\n",
"                    gradBuffer[idx] += grad\n",
"\n",
"                if i % update_frequency == 0 and i != 0:\n",
"                    feed_dict = dict(zip(myAgent.gradient_holders, gradBuffer))\n",
"                    _ = sess.run(myAgent.update_batch, feed_dict=feed_dict)\n",
"                    for ix, grad in enumerate(gradBuffer):\n",
"                        gradBuffer[ix] = grad * 0\n",
"\n",
"                total_reward.append(running_reward)\n",
"                total_length.append(j)\n",
"                break\n",
"\n",
"        #Update our running tally of scores.\n",
"        if i % 100 == 0:\n",
"            print(np.mean(total_reward[-100:]))\n",
"        i += 1"
]
},
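{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not part of the original notebook: a quick look at the learning curve, plotting the per-episode rewards collected in `total_reward` above, smoothed with a 100-episode moving average."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Added: visualize the learning curve from the run above.\n",
"rewards = np.array(total_reward)\n",
"window = 100\n",
"smoothed = np.convolve(rewards, np.ones(window) / window, mode='valid')\n",
"plt.plot(smoothed)\n",
"plt.xlabel('Episode')\n",
"plt.ylabel('Mean reward over previous 100 episodes')\n",
"plt.show()"
]
}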
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [conda env:py2]",
"language": "python",
"name": "conda-env-py2-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 0
}