@yingzwang
Last active July 8, 2022 08:29
Deep Q-learning implementation in TensorFlow and Keras (solving CartPole-v0)
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Deep Q-learning implementation in TensorFlow and Keras\n",
"with an example application to solving the `CartPole-v0` environment.\n",
"![dqn](https://user-images.githubusercontent.com/38169187/46908388-63807200-cf22-11e8-99f3-b471405495b3.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# import"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import tensorflow as tf\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# replay buffer"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# import numpy as np\n",
"from collections import deque\n",
"import random\n",
"\n",
"class ReplayBuffer:\n",
"    \"\"\"Fixed-size buffer to store experience tuples.\"\"\"\n",
"\n",
"    def __init__(self, buffer_size=int(1e5), random_seed=1234):\n",
"        \"\"\"Initialize a ReplayBuffer object.\n",
"        Params\n",
"        ======\n",
"            buffer_size: maximum size of buffer\n",
"            random_seed: seed for Python's random module (used when sampling)\n",
"        The right side of the deque contains the most recent experiences.\n",
"        \"\"\"\n",
"        self.buffer_size = buffer_size\n",
"        self.buffer = deque(maxlen=buffer_size)\n",
"        random.seed(random_seed)\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Return the current size of internal memory.\"\"\"\n",
"        return len(self.buffer)\n",
"\n",
"    def add(self, s, a, r, done, s2):\n",
"        \"\"\"Add a new experience to buffer.\n",
"        Params\n",
"        ======\n",
"            s: one state sample, numpy array shape (s_dim,)\n",
"            a: one action sample, scalar (for DQN)\n",
"            r: one reward sample, scalar\n",
"            done: True/False, scalar\n",
"            s2: the next state sample, numpy array shape (s_dim,)\n",
"        \"\"\"\n",
"        e = (s, a, r, done, s2)\n",
"        self.buffer.append(e)\n",
"\n",
"    def sample_batch(self, batch_size):\n",
"        \"\"\"Randomly sample a batch of experiences from buffer.\"\"\"\n",
"\n",
"        # ensure the buffer is large enough for sampling\n",
"        assert (len(self.buffer) >= batch_size)\n",
"\n",
"        # sample a batch\n",
"        batch = random.sample(self.buffer, batch_size)\n",
"\n",
"        # Convert experience tuples to separate arrays for each element (states, actions, rewards, etc.)\n",
"        states, actions, rewards, dones, next_states = zip(*batch)\n",
"        states = np.asarray(states).reshape(batch_size, -1)  # shape (batch_size, s_dim)\n",
"        next_states = np.asarray(next_states).reshape(batch_size, -1)  # shape (batch_size, s_dim)\n",
"        actions = np.asarray(actions)  # shape (batch_size,), for DQN, action is an int\n",
"        rewards = np.asarray(rewards)  # shape (batch_size,)\n",
"        dones = np.asarray(dones, dtype=np.uint8)  # shape (batch_size,)\n",
"        return states, actions, rewards, dones, next_states"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# DQN tf summary"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def build_summaries():\n",
"    \"\"\"\n",
"    TensorBoard summaries for monitoring the training process\n",
"    \"\"\"\n",
"\n",
"    # performance per episode\n",
"    ph_reward = tf.placeholder(tf.float32)\n",
"    tf.summary.scalar(\"Reward_ep\", ph_reward)\n",
"    ph_Qmax = tf.placeholder(tf.float32)\n",
"    tf.summary.scalar(\"Qmax_ep\", ph_Qmax)\n",
"\n",
"    # merge all summary ops (must be done last, after every summary has been defined)\n",
"    summary_op = tf.summary.merge_all()\n",
"\n",
"    return summary_op, ph_reward, ph_Qmax"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# DQN neural network model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"import time\n",
"from keras import layers\n",
"from functools import partial\n",
"\n",
"def build_net(model_name, state, a_dim, args, trainable):\n",
"    \"\"\"\n",
"    Q-network: two ReLU hidden layers (h1, h2 units) and a linear output layer\n",
"    model input: state, shape (batch_size, s_dim)\n",
"    model output: Qhat, shape (batch_size, a_dim)\n",
"    \"\"\"\n",
"    h1 = int(args['h1'])\n",
"    h2 = int(args['h2'])\n",
"\n",
"    my_dense = partial(layers.Dense, trainable=trainable)\n",
"    with tf.variable_scope(model_name):\n",
"        net = my_dense(h1, name=\"l1-dense-{}\".format(h1))(state)\n",
"        net = layers.Activation('relu', name=\"relu1\")(net)\n",
"        net = my_dense(h2, name=\"l2-dense-{}\".format(h2))(net)\n",
"        net = layers.Activation('relu', name=\"relu2\")(net)\n",
"        net = my_dense(a_dim, name=\"l3-dense-{}\".format(a_dim))(net)\n",
"        Qhat = layers.Activation('linear', name=\"Qhat\")(net)\n",
"    nn_params = tf.trainable_variables(scope=model_name)\n",
"    return Qhat, nn_params"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# DQN agent"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class DeepQNetwork:\n",
"    def __init__(self, sess, a_dim, s_dim, args):\n",
"        self.a_dim = a_dim\n",
"        self.s_dim = s_dim\n",
"        self.h1 = args[\"h1\"]\n",
"        self.h2 = args[\"h2\"]\n",
"        self.lr = args[\"learning_rate\"]\n",
"        self.gamma = args[\"gamma\"]\n",
"        self.epsilon_start = args[\"epsilon_start\"]\n",
"        self.epsilon_stop = args[\"epsilon_stop\"]\n",
"        self.epsilon_decay = args[\"epsilon_decay\"]\n",
"        self.epsilon = self.epsilon_start  # current exploration probability\n",
"        self.update_target_C = args[\"update_target_C\"]\n",
"        self.update_target_tau = args['update_target_tau']\n",
"        self.learn_step_counter = 0\n",
"\n",
"        # initialize replay buffer\n",
"        self.replay_buffer = ReplayBuffer(int(args['buffer_size']), int(args['random_seed']))\n",
"        self.minibatch_size = int(args['minibatch_size'])\n",
"\n",
"        self.s = tf.placeholder(tf.float32, [None, self.s_dim], name='state')  # input State\n",
"        self.s_ = tf.placeholder(tf.float32, [None, self.s_dim], name='state_next')  # input Next State\n",
"        self.r = tf.placeholder(tf.float32, [None,], name='reward')  # input Reward\n",
"        self.a = tf.placeholder(tf.int32, [None,], name='action')  # input Action\n",
"        self.done = tf.placeholder(tf.float32, [None,], name='done')  # input Done flag (1 if terminal)\n",
"\n",
"        # initialize NN, self.q shape (batch_size, a_dim)\n",
"        self.q, self.nn_params = build_net(\"DQN\", self.s, a_dim, args, trainable=True)\n",
"        self.q_, self.nn_params_ = build_net(\"target_DQN\", self.s_, a_dim, args, trainable=False)\n",
"        for var in self.nn_params:\n",
"            vname = var.name.replace(\"kernel:0\", \"W\").replace(\"bias:0\", \"b\")\n",
"            tf.summary.histogram(vname, var)\n",
"\n",
"        with tf.variable_scope(\"Qmax\"):\n",
"            self.Qmax = tf.reduce_max(self.q_, axis=1)  # shape (batch_size,)\n",
"\n",
"        with tf.variable_scope(\"yi\"):\n",
"            self.yi = self.r + self.gamma * self.Qmax * (1 - self.done)  # shape (batch_size,)\n",
"\n",
"        with tf.variable_scope(\"Qa_all\"):\n",
"            Qa = tf.Variable(tf.zeros([self.minibatch_size, self.a_dim]))\n",
"            for aval in np.arange(self.a_dim):\n",
"                tf.summary.histogram(\"Qa{}\".format(aval), Qa[:, aval])\n",
"            self.Qa_op = Qa.assign(self.q)\n",
"\n",
"        with tf.variable_scope(\"Q_at_a\"):\n",
"            # select the Q value corresponding to the action\n",
"            one_hot_actions = tf.one_hot(self.a, self.a_dim)  # shape (batch_size, a_dim)\n",
"            q_all = tf.multiply(self.q, one_hot_actions)  # shape (batch_size, a_dim)\n",
"            self.q_at_a = tf.reduce_sum(q_all, axis=1)  # shape (batch_size,)\n",
"\n",
"        with tf.variable_scope(\"loss_MSE\"):\n",
"            self.loss = tf.losses.mean_squared_error(labels=self.yi, predictions=self.q_at_a)\n",
"\n",
"        with tf.variable_scope(\"train_DQN\"):\n",
"            self.train_op = tf.train.AdamOptimizer(self.lr).minimize(loss=self.loss, var_list=self.nn_params)\n",
"\n",
"        with tf.variable_scope(\"soft_update\"):\n",
"            # target <- (1 - TAU) * target + TAU * online\n",
"            TAU = self.update_target_tau\n",
"            self.update_op = [tf.assign(t, (1 - TAU)*t + TAU*e) for t, e in zip(self.nn_params_, self.nn_params)]\n",
"\n",
"    def choose_action(self, sess, observation):\n",
"        # Explore or Exploit\n",
"        explore_p = self.epsilon  # exploration probability\n",
"\n",
"        if np.random.uniform() <= explore_p:\n",
"            # Explore: take a random action\n",
"            action = np.random.randint(0, self.a_dim)\n",
"        else:\n",
"            # Exploit: get the greedy action from the Q-network\n",
"            observation = np.reshape(observation, (1, self.s_dim))\n",
"            Qs = sess.run(self.q, feed_dict={self.s: observation})  # shape (1, a_dim)\n",
"            action = np.argmax(Qs[0])\n",
"        return action\n",
"\n",
"    def learn_a_batch(self, sess):\n",
"        # update target every C learning steps\n",
"        if self.learn_step_counter % self.update_target_C == 0:\n",
"            sess.run(self.update_op)\n",
"\n",
"        # Sample a batch\n",
"        s_batch, a_batch, r_batch, done_batch, s2_batch = self.replay_buffer.sample_batch(self.minibatch_size)\n",
"\n",
"        # Train\n",
"        _, _, Qhat, loss = sess.run([self.train_op, self.Qa_op, self.q_at_a, self.loss], feed_dict={\n",
"            self.s: s_batch, self.a: a_batch, self.r: r_batch, self.done: done_batch, self.s_: s2_batch})\n",
"\n",
"        # count learning steps\n",
"        self.learn_step_counter += 1\n",
"\n",
"        # decay exploration probability after each learning step\n",
"        if self.epsilon > self.epsilon_stop:\n",
"            self.epsilon *= self.epsilon_decay\n",
"\n",
"        return np.max(Qhat)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# args `CartPole-v0`"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"code_folding": []
},
"outputs": [],
"source": [
"args = {\"env\": 'CartPole-v0',\n",
"        \"random_seed\": 1234,\n",
"        \"max_episodes\": 150,  # number of episodes\n",
"        \"max_episode_len\": 200,  # time steps per episode, 200 for CartPole-v0\n",
"        ## NN params\n",
"        \"h1\": 32,  # 32\n",
"        \"h2\": 64,  # 64\n",
"        \"learning_rate\": 0.001,  # 1e-3\n",
"        \"gamma\": 0.9,  # 0.9 (32), 0.95 (34) better than 0.99\n",
"        \"update_target_C\": 1,  # update every C learning steps (C=1 if soft update, C=100 if hard update)\n",
"        \"update_target_tau\": 8e-2,  # soft update (tau=8e-2), hard update (tau=1)\n",
"        ## exploration prob\n",
"        \"epsilon_start\": 1.0,\n",
"        \"epsilon_stop\": 0.01,  # 0.01\n",
"        \"epsilon_decay\": 0.999,  # 0.999\n",
"        ## replay buffer\n",
"        \"buffer_size\": 1e5,\n",
"        \"minibatch_size\": 32,  # 32\n",
"        ## tensorboard logs\n",
"        \"summary_dir\": './results/dqn',\n",
"        }\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# main training"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ying/gym/gym/__init__.py:22: UserWarning: DEPRECATION WARNING: to improve load times, gym no longer automatically loads gym.spaces. Please run \"import gym.spaces\" to load gym.spaces on your own. This warning will turn into an error in a future version of gym.\n",
" warnings.warn('DEPRECATION WARNING: to improve load times, gym no longer automatically loads gym.spaces. Please run \"import gym.spaces\" to load gym.spaces on your own. This warning will turn into an error in a future version of gym.')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n",
"states: Box(4,)\n",
"actions: Discrete(2)\n",
"episode: 0/150, steps: 23, explore_prob: 1.00, total reward: 23.0\n",
"episode: 10/150, steps: 13, explore_prob: 0.89, total reward: 13.0\n",
"episode: 20/150, steps: 21, explore_prob: 0.69, total reward: 21.0\n",
"episode: 30/150, steps: 17, explore_prob: 0.49, total reward: 17.0\n",
"episode: 40/150, steps: 200, explore_prob: 0.11, total reward: 200.0\n",
"episode: 50/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 60/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 70/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 80/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 90/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 100/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 110/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 120/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 130/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 140/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n"
]
}
],
"source": [
"sess = tf.InteractiveSession()\n",
"tf.set_random_seed(int(args['random_seed']))\n",
"\n",
"# initialize numpy seed\n",
"np.random.seed(int(args['random_seed']))\n",
"\n",
"# initialize gym env\n",
"env = gym.make(args['env'])\n",
"env.seed(int(args['random_seed']))\n",
"state_size = env.observation_space.shape[0]\n",
"action_size = env.action_space.n\n",
"print(\"states:\", env.observation_space)\n",
"print(\"actions:\", env.action_space)\n",
"\n",
"# initialize DQN agent\n",
"agent = DeepQNetwork(sess, action_size, state_size, args)\n",
"\n",
"# initialize summary (for visualization in TensorBoard)\n",
"summary_op, ph_reward, ph_Qmax = build_summaries()\n",
"subdir = time.strftime(\"%Y%m%d-%H%M%S\", time.localtime())  # a sub folder, e.g., yyyymmdd-HHMMSS\n",
"logdir = args['summary_dir'] + '/' + subdir\n",
"writer = tf.summary.FileWriter(logdir, sess.graph)  # must be done after the graph is constructed\n",
"\n",
"# initialize all variables in the graph\n",
"sess.run(tf.global_variables_initializer())\n",
"\n",
"# training DQN agent\n",
"rewards_list = []\n",
"loss = -999\n",
"num_ep = args['max_episodes']\n",
"max_t = args['max_episode_len']\n",
"for ep in range(num_ep):\n",
"    state = env.reset()  # shape (s_dim,)\n",
"    ep_reward = 0  # total reward per episode\n",
"    ep_qmax = 0\n",
"    t_step = 0\n",
"    done = False\n",
"    while (t_step < max_t) and (not done):\n",
"\n",
"        # choose an action\n",
"        action = agent.choose_action(sess, state)\n",
"\n",
"        # interact with the env\n",
"        next_state, reward, done, _ = env.step(action)\n",
"\n",
"        # add the experience to replay buffer\n",
"        agent.replay_buffer.add(state, action, reward, done, next_state)\n",
"\n",
"        # learn from a batch of experiences\n",
"        if len(agent.replay_buffer) > 3 * agent.minibatch_size:\n",
"            qmax = agent.learn_a_batch(sess)\n",
"            ep_qmax = max(ep_qmax, qmax)\n",
"\n",
"        # next time step\n",
"        t_step += 1\n",
"        ep_reward += reward\n",
"        state = next_state\n",
"\n",
"    # end of an episode\n",
"    rewards_list.append((ep, ep_reward))\n",
"\n",
"    # write to tensorboard summary\n",
"    summary_str = sess.run(summary_op, feed_dict={ph_reward: ep_reward, ph_Qmax: ep_qmax})\n",
"    writer.add_summary(summary_str, ep)\n",
"    writer.flush()\n",
"\n",
"    if ep % 10 == 0:\n",
"        print(\"episode: {}/{}, steps: {}, explore_prob: {:.2f}, total reward: {}\".format(\n",
"            ep, num_ep, t_step, agent.epsilon, ep_reward))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Solved Requirements `CartPole-v0`**\n",
"\n",
"https://github.com/openai/gym/wiki/CartPole-v0\n",
"\n",
"Considered solved when the average reward is greater than or equal to **195.0** over 100 consecutive trials."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"episodes before solving: 32\n"
]
}
],
"source": [
"def my_sma(x, N):\n",
"    \"\"\"trailing simple moving average over a window of N samples\n",
"    (the first N-1 outputs are partial sums divided by N)\"\"\"\n",
"    filt = np.ones(N) / N\n",
"    xm = np.convolve(x, filt)\n",
"    xm = xm[:-(N-1)]  # remove the last (N-1) elements\n",
"    return xm\n",
"\n",
"eps, rewards = np.array(rewards_list).T\n",
"\n",
"# plot reward v.s. episode\n",
"plt.plot(eps, rewards)\n",
"plt.xlabel('episode')\n",
"plt.ylabel('reward')\n",
"plt.show()\n",
"\n",
"# check solved requirements\n",
"N = 100\n",
"thr = 195.0\n",
"ep_solve = np.argwhere(my_sma(rewards, N) >= thr).ravel()[0] - N  # find where sma >= thr\n",
"print(\"episodes before solving: {}\".format(ep_solve))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": false,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {
"height": "calc(100% - 180px)",
"left": "10px",
"top": "150px",
"width": "288px"
},
"toc_section_display": true,
"toc_window_display": true
}
},
"nbformat": 4,
"nbformat_minor": 2
}
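The notebook's solved criterion (average reward of at least 195.0 over 100 consecutive episodes) can also be checked with a trailing moving average that only looks at full windows. The sketch below is an alternative to the notebook's `my_sma` helper, not the author's code: `episodes_before_solve` is a hypothetical name, it uses `np.convolve(..., mode="valid")` so partial windows at the start are ignored, and it returns `None` instead of raising when the criterion is never met. It reports the number of episodes strictly before the first qualifying window, which may differ by one from the notebook's printed value because the notebook subtracts `N` from the window's last index.

```python
import numpy as np


def episodes_before_solve(rewards, window=100, threshold=195.0):
    """Return the number of episodes before the first full window whose mean
    reward is >= threshold, or None if no such window exists.

    Hypothetical helper mirroring the CartPole-v0 criterion: an average
    reward >= 195.0 over 100 consecutive episodes.
    """
    rewards = np.asarray(rewards, dtype=float)
    if rewards.size < window:
        return None  # not enough episodes for a single full window
    # 'valid' keeps only full windows: sma[j] = mean(rewards[j:j + window])
    sma = np.convolve(rewards, np.ones(window) / window, mode="valid")
    hits = np.flatnonzero(sma >= threshold)
    if hits.size == 0:
        return None
    # the window starting at episode j is preceded by exactly j episodes
    return int(hits[0])


# Usage with synthetic rewards: 30 poor episodes, then 150 perfect ones.
demo = [20.0] * 30 + [200.0] * 150
print(episodes_before_solve(demo))  # 28
```

A window starting at episode 28 contains two 20-point episodes and 98 perfect ones (mean 196.4), while the window starting at 27 contains three (mean 194.6), so 28 episodes precede the first qualifying window.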
@yingzwang
Author

DQN graph (generated by TensorBoard)
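The `DeepQNetwork` cell builds its regression target and loss in the `Qmax`, `yi`, `Q_at_a`, and `loss_MSE` variable scopes. As a framework-free sketch of that computation for one minibatch (NumPy only; `dqn_targets_and_loss` is a hypothetical name, and the inputs are assumed to be already-evaluated Q-value arrays rather than TensorFlow tensors):

```python
import numpy as np


def dqn_targets_and_loss(q_online, q_target_next, actions, rewards, dones, gamma=0.9):
    """NumPy sketch of the DQN regression target and MSE loss for one minibatch.

    q_online:      Q(s, .) from the online network, shape (batch, a_dim)
    q_target_next: Q'(s', .) from the target network, shape (batch, a_dim)
    actions:       integer actions taken, shape (batch,)
    rewards:       rewards received, shape (batch,)
    dones:         1.0 if s' is terminal else 0.0, shape (batch,)
    """
    # "Qmax": max over a' of Q'(s', a'), taken from the target network
    q_max_next = q_target_next.max(axis=1)
    # "yi": bootstrap from the target network only for non-terminal transitions
    y = rewards + gamma * q_max_next * (1.0 - dones)
    # "Q_at_a": select Q(s, a) for the action actually taken (one-hot mask, summed)
    one_hot = np.eye(q_online.shape[1])[actions]
    q_at_a = (q_online * one_hot).sum(axis=1)
    # "loss_MSE": mean squared error between target and prediction
    loss = np.mean((y - q_at_a) ** 2)
    return y, q_at_a, loss


q_online = np.array([[1.0, 2.0], [0.5, 0.0]])
q_target_next = np.array([[3.0, 1.0], [4.0, 5.0]])
actions = np.array([1, 0])
rewards = np.array([1.0, 1.0])
dones = np.array([0.0, 1.0])
y, q_at_a, loss = dqn_targets_and_loss(q_online, q_target_next, actions, rewards, dones)
# y is approximately [3.7, 1.0], q_at_a is [2.0, 0.5], loss is approximately 1.57
```

The second transition is terminal, so its target is just the reward; only the online network's Q-values enter the prediction, which matches the notebook restricting the optimizer to `var_list=self.nn_params`.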

@DylanHaiyangChen

Hi Ying,
This is a really great approach to solving the CartPole problem. I wonder if you could share more information about the DQN architecture, such as a report or references.
I am wondering why your implementation is so efficient.

@TomeASilva

Hi Dylan (HaiyangChen),

I'm not associated with yingzwang, but I can give some information. This is an implementation of the DQN algorithm (https://deepmind.com/research/dqn/), so the architecture is essentially the same as the one presented in the paper. The main difference is a soft update of the target network's weights, using an exponential moving average parameterized by tau. She also uses a decaying exploration strategy, which clearly helps in this problem. The rest is just good hyperparameter tuning.
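The two ingredients mentioned above can be sketched in a few lines of NumPy. This is an illustration, not the notebook's TensorFlow code: parameters are assumed to be plain arrays, and `soft_update` and `epsilon_after` are hypothetical helper names. The soft update moves each target parameter a fraction `tau` toward the online parameter (the same rule as the notebook's `(1 - TAU)*t + TAU*e`), and the exploration probability is multiplied by `epsilon_decay` after every learning step while it is still above `epsilon_stop`.

```python
import math

import numpy as np


def soft_update(target_params, online_params, tau=8e-2):
    """Exponential moving average of the online weights: t <- (1 - tau) * t + tau * e."""
    return [(1.0 - tau) * t + tau * e for t, e in zip(target_params, online_params)]


def epsilon_after(steps, start=1.0, stop=0.01, decay=0.999):
    """Exploration probability after `steps` learning steps.

    Same rule as the notebook: multiply by `decay` while epsilon > stop.
    """
    eps = start
    for _ in range(steps):
        if eps > stop:
            eps *= decay
    return eps


# With tau = 1 the soft update becomes a hard copy of the online weights.
online = [np.ones(3)]
target = [np.zeros(3)]
print(soft_update(target, online, tau=1.0)[0])  # [1. 1. 1.]

# Repeated soft updates close the gap geometrically: after k steps it is (1 - tau)**k.
for _ in range(50):
    target = soft_update(target, online)
gap = float(np.max(np.abs(online[0] - target[0])))

# Learning steps until epsilon first falls below 0.01 with decay = 0.999.
steps_to_stop = math.ceil(math.log(0.01) / math.log(0.999))  # 4603
```

With the notebook's settings (`update_target_C = 1`, `tau = 8e-2`) the target network trails the online network by roughly 1/tau ≈ 12 learning steps, and since epsilon decays per learning step rather than per episode, it bottoms out after about 4600 learning steps.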

The code is also very good, with good coding practices all around.
