@nambrot
Created April 23, 2018 17:27
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Approximate q-learning\n",
"\n",
"In this notebook you will teach a __tensorflow__ neural network to do Q-learning."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__Frameworks__ - we'll accept this homework in any deep learning framework. This particular notebook was designed for tensorflow, but you will find it easy to adapt it to almost any python-based deep learning framework."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Starting virtual X frame buffer: Xvfb../xvfb: line 8: start-stop-daemon: command not found\n",
".\n",
"env: DISPLAY=:1\n"
]
}
],
"source": [
"#XVFB will be launched if you run on a server\n",
"import os\n",
"if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\"))==0:\n",
" !bash ../xvfb start\n",
" %env DISPLAY=:1"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x11711a2b0>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"env = gym.make(\"CartPole-v0\").env\n",
"env.reset()\n",
"n_actions = env.action_space.n\n",
"state_dim = env.observation_space.shape\n",
"\n",
"plt.imshow(env.render(\"rgb_array\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Approximate (deep) Q-learning: building the network\n",
"\n",
"To train a neural network policy one must have a neural network policy. Let's build it.\n",
"\n",
"\n",
"Since we're working with pre-extracted features (cart positions, angles and velocities), we don't need a complicated network yet. In fact, let's build something like this for starters:\n",
"\n",
"![img](https://s14.postimg.org/uzay2q5rl/qlearning_scheme.png)\n",
"\n",
"For your first run, please only use linear layers (L.Dense) and activations. Stuff like batch normalization or dropout may ruin everything if used haphazardly. \n",
"\n",
"Also please avoid nonlinearities like sigmoid & tanh: the agent's observations are not normalized, so sigmoids may be saturated right from initialization.\n",
"\n",
"Ideally you should start small, with maybe 1-2 hidden layers of < 200 neurons, and then increase the network size if the agent doesn't beat the target score."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"import keras\n",
"import keras.layers as L\n",
"tf.reset_default_graph()\n",
"sess = tf.InteractiveSession()\n",
"keras.backend.set_session(sess)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"network = keras.models.Sequential()\n",
"network.add(L.InputLayer(state_dim))\n",
"\n",
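"# hidden layers need a nonlinearity; without one the stacked Dense layers collapse into a single linear map\n",
"network.add(L.Dense(units=200, activation='relu'))\n",
"network.add(L.Dense(units=50, activation='relu'))\n",
"# linear output layer: one Q-value per action\n",
"network.add(L.Dense(units=n_actions))"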
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from random import random\n",
"\n",
"def get_action(state, epsilon=0):\n",
" \"\"\"\n",
" sample actions with epsilon-greedy policy\n",
" recap: with p = epsilon pick random action, else pick action with highest Q(s,a)\n",
" \"\"\"\n",
" \n",
" q_values = network.predict(state[None])[0]\n",
" \n",
" if random() < epsilon:\n",
" return env.action_space.sample()\n",
" else:\n",
" return np.argmax(q_values)\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n",
"[ 0 10000]\n",
"e=0.0 tests passed\n",
"1\n",
"[ 540 9460]\n",
"e=0.1 tests passed\n",
"1\n",
"[2452 7548]\n",
"e=0.5 tests passed\n",
"1\n",
"[4949 5051]\n",
"e=1.0 tests passed\n"
]
}
],
"source": [
"assert network.output_shape == (None, n_actions), \"please make sure your model maps state s -> [Q(s,a0), ..., Q(s, a_last)]\"\n",
"assert network.layers[-1].activation == keras.activations.linear, \"please make sure you predict q-values without nonlinearity\"\n",
"\n",
"# test epsilon-greedy exploration\n",
"s = env.reset()\n",
"assert np.shape(get_action(s)) == (), \"please return just one action (integer)\"\n",
"for eps in [0., 0.1, 0.5, 1.0]:\n",
" state_frequencies = np.bincount([get_action(s, epsilon=eps) for i in range(10000)], minlength=n_actions)\n",
" best_action = state_frequencies.argmax()\n",
" print(best_action)\n",
" print(state_frequencies)\n",
" assert abs(state_frequencies[best_action] - 10000 * (1 - eps + eps / n_actions)) < 200\n",
" for other_action in range(n_actions):\n",
" if other_action != best_action:\n",
" assert abs(state_frequencies[other_action] - 10000 * (eps / n_actions)) < 200\n",
" print('e=%.1f tests passed'%eps)"
]
},
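{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tolerances in the test above follow directly from the sampling rule. Assuming the random branch of `get_action` draws uniformly over all $n$ actions (including the greedy one), the expected probabilities are:\n",
"\n",
"$$ P(a_{best}) = 1 - \\epsilon + {\\epsilon \\over n}, \\qquad P(a_{other}) = {\\epsilon \\over n} $$\n",
"\n",
"For example, with $\\epsilon = 0.1$ and $n = 2$ (CartPole has two actions):\n",
"\n",
"$$ P(a_{best}) = 1 - 0.1 + 0.05 = 0.95, \\qquad P(a_{other}) = 0.05 $$\n",
"\n",
"So out of 10000 samples we expect roughly 9500 greedy and 500 non-greedy picks, which is consistent with the `[ 540 9460]` counts printed above and well within the tolerance of 200."
]
},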
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Q-learning via gradient descent\n",
"\n",
"We shall now train our agent's Q-function by minimizing the TD loss:\n",
"$$ L = { 1 \\over N} \\sum_i (Q_{\\theta}(s,a) - [r(s,a) + \\gamma \\cdot \\max_{a'} Q_{-}(s', a')]) ^2 $$\n",
"\n",
"\n",
"Where\n",
"* $s, a, r, s'$ are current state, action, reward and next state respectively\n",
"* $\\gamma$ is the discount factor, set to 0.99 in the code below.\n",
"\n",
"The tricky part is with $Q_{-}(s',a')$. From an engineering standpoint, it's the same as $Q_{\\theta}$ - the output of your neural network policy. However, when doing gradient descent, __we won't propagate gradients through it__ to make training more stable (see lectures).\n",
"\n",
"To do so, we shall use the `tf.stop_gradient` function, which basically says \"consider this thing constant when doing backprop\"."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Create placeholders for the <s, a, r, s'> tuple and a special indicator for game end (is_done = True)\n",
"states_ph = tf.placeholder('float32', shape=(None,) + state_dim)\n",
"actions_ph = tf.placeholder('int32', shape=[None])\n",
"rewards_ph = tf.placeholder('float32', shape=[None])\n",
"next_states_ph = tf.placeholder('float32', shape=(None,) + state_dim)\n",
"is_done_ph = tf.placeholder('bool', shape=[None])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"#get q-values for all actions in current states\n",
"predicted_qvalues = network(states_ph)\n",
"\n",
"#select q-values for chosen actions\n",
"predicted_qvalues_for_actions = tf.reduce_sum(predicted_qvalues * tf.one_hot(actions_ph, n_actions), axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"gamma = 0.99\n",
"\n",
"# compute q-values for all actions in next states\n",
"predicted_next_qvalues = network(next_states_ph)\n",
"\n",
"# compute V*(next_states) using predicted next q-values\n",
"next_state_values = tf.reduce_max(predicted_next_qvalues, axis=1)\n",
"\n",
"# compute \"target q-values\" for loss - it's what's inside the square brackets in the above formula.\n",
"target_qvalues_for_actions = rewards_ph + gamma * next_state_values\n",
"\n",
"# at the last state we shall use simplified formula: Q(s,a) = r(s,a) since s' doesn't exist\n",
"target_qvalues_for_actions = tf.where(is_done_ph, rewards_ph, target_qvalues_for_actions)"
]
},
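{
"cell_type": "markdown",
"metadata": {},
"source": [
"A worked example of the target computed above, with made-up numbers chosen only for illustration: suppose $r = 1$, $\\gamma = 0.99$, and the network predicts $Q_{-}(s', \\cdot) = [0.5, 2.0]$ for the next state. Then\n",
"\n",
"$$ r + \\gamma \\cdot \\max_{a'} Q_{-}(s', a') = 1 + 0.99 \\cdot 2.0 = 2.98 $$\n",
"\n",
"If the current prediction is $Q_{\\theta}(s,a) = 2.5$, this transition contributes $(2.5 - 2.98)^2 = 0.2304$ to the TD loss. If instead the episode ended on this step (`is_done` is true), the target is just $r = 1$ and the contribution is $(2.5 - 1)^2 = 2.25$."
]
},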
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"#mean squared error loss to minimize\n",
"loss = (predicted_qvalues_for_actions - tf.stop_gradient(target_qvalues_for_actions)) ** 2\n",
"loss = tf.reduce_mean(loss)\n",
"\n",
"# training function that resembles agent.update(state, action, reward, next_state) from tabular agent\n",
"train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"assert tf.gradients(loss, [predicted_qvalues_for_actions])[0] is not None, \"make sure you update q-values for chosen actions and not just all actions\"\n",
"assert tf.gradients(loss, [predicted_next_qvalues])[0] is None, \"make sure you don't propagate gradient w.r.t. Q_(s',a')\"\n",
"assert predicted_next_qvalues.shape.ndims == 2, \"make sure you predicted q-values for all actions in next state\"\n",
"assert next_state_values.shape.ndims == 1, \"make sure you computed V(s') as maximum over just the actions axis and not all axes\"\n",
"assert target_qvalues_for_actions.shape.ndims == 1, \"there's something wrong with target q-values, they must be a vector\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Playing the game"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def generate_session(t_max=1000, epsilon=0, train=False):\n",
" \"\"\"play env with approximate q-learning agent and train it at the same time\"\"\"\n",
" total_reward = 0\n",
" s = env.reset()\n",
" \n",
" for t in range(t_max):\n",
" a = get_action(s, epsilon=epsilon) \n",
" next_s, r, done, _ = env.step(a)\n",
" \n",
" if train:\n",
" sess.run(train_step,{\n",
" states_ph: [s], actions_ph: [a], rewards_ph: [r], \n",
" next_states_ph: [next_s], is_done_ph: [done]\n",
" })\n",
"\n",
" total_reward += r\n",
" s = next_s\n",
" if done: break\n",
" \n",
" return total_reward"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"epsilon = 0.5"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch #0\tmean reward = 19.400\tepsilon = 0.500\n",
"epoch #1\tmean reward = 17.850\tepsilon = 0.495\n",
"epoch #2\tmean reward = 15.570\tepsilon = 0.490\n",
"epoch #3\tmean reward = 15.600\tepsilon = 0.485\n",
"epoch #4\tmean reward = 18.730\tepsilon = 0.480\n",
"epoch #5\tmean reward = 17.230\tepsilon = 0.475\n",
"epoch #6\tmean reward = 20.270\tepsilon = 0.471\n",
"epoch #7\tmean reward = 15.570\tepsilon = 0.466\n",
"epoch #8\tmean reward = 17.250\tepsilon = 0.461\n",
"epoch #9\tmean reward = 14.080\tepsilon = 0.457\n",
"epoch #10\tmean reward = 15.880\tepsilon = 0.452\n",
"epoch #11\tmean reward = 14.740\tepsilon = 0.448\n",
"epoch #12\tmean reward = 13.540\tepsilon = 0.443\n",
"epoch #13\tmean reward = 16.490\tepsilon = 0.439\n",
"epoch #14\tmean reward = 25.200\tepsilon = 0.434\n",
"epoch #15\tmean reward = 22.610\tepsilon = 0.430\n",
"epoch #16\tmean reward = 17.880\tepsilon = 0.426\n",
"epoch #17\tmean reward = 18.130\tepsilon = 0.421\n",
"epoch #18\tmean reward = 15.090\tepsilon = 0.417\n",
"epoch #19\tmean reward = 15.500\tepsilon = 0.413\n",
"epoch #20\tmean reward = 14.570\tepsilon = 0.409\n",
"epoch #21\tmean reward = 15.200\tepsilon = 0.405\n",
"epoch #22\tmean reward = 16.510\tepsilon = 0.401\n",
"epoch #23\tmean reward = 16.970\tepsilon = 0.397\n",
"epoch #24\tmean reward = 14.810\tepsilon = 0.393\n",
"epoch #25\tmean reward = 16.540\tepsilon = 0.389\n",
"epoch #26\tmean reward = 15.520\tepsilon = 0.385\n",
"epoch #27\tmean reward = 18.500\tepsilon = 0.381\n",
"epoch #28\tmean reward = 16.710\tepsilon = 0.377\n",
"epoch #29\tmean reward = 15.740\tepsilon = 0.374\n",
"epoch #30\tmean reward = 17.580\tepsilon = 0.370\n",
"epoch #31\tmean reward = 16.180\tepsilon = 0.366\n",
"epoch #32\tmean reward = 17.940\tepsilon = 0.362\n",
"epoch #33\tmean reward = 15.220\tepsilon = 0.359\n",
"epoch #34\tmean reward = 16.930\tepsilon = 0.355\n",
"epoch #35\tmean reward = 14.930\tepsilon = 0.352\n",
"epoch #36\tmean reward = 17.310\tepsilon = 0.348\n",
"epoch #37\tmean reward = 18.890\tepsilon = 0.345\n",
"epoch #38\tmean reward = 17.160\tepsilon = 0.341\n",
"epoch #39\tmean reward = 14.360\tepsilon = 0.338\n",
"epoch #40\tmean reward = 14.760\tepsilon = 0.334\n",
"epoch #41\tmean reward = 16.880\tepsilon = 0.331\n",
"epoch #42\tmean reward = 14.970\tepsilon = 0.328\n",
"epoch #43\tmean reward = 12.830\tepsilon = 0.325\n",
"epoch #44\tmean reward = 16.990\tepsilon = 0.321\n",
"epoch #45\tmean reward = 14.050\tepsilon = 0.318\n",
"epoch #46\tmean reward = 18.680\tepsilon = 0.315\n",
"epoch #47\tmean reward = 13.310\tepsilon = 0.312\n",
"epoch #48\tmean reward = 15.050\tepsilon = 0.309\n",
"epoch #49\tmean reward = 14.120\tepsilon = 0.306\n",
"epoch #50\tmean reward = 14.220\tepsilon = 0.303\n",
"epoch #51\tmean reward = 14.390\tepsilon = 0.299\n",
"epoch #52\tmean reward = 17.340\tepsilon = 0.296\n",
"epoch #53\tmean reward = 16.310\tepsilon = 0.294\n",
"epoch #54\tmean reward = 16.310\tepsilon = 0.291\n",
"epoch #55\tmean reward = 18.000\tepsilon = 0.288\n",
"epoch #56\tmean reward = 17.920\tepsilon = 0.285\n",
"epoch #57\tmean reward = 16.730\tepsilon = 0.282\n",
"epoch #58\tmean reward = 17.690\tepsilon = 0.279\n",
"epoch #59\tmean reward = 14.950\tepsilon = 0.276\n",
"epoch #60\tmean reward = 13.480\tepsilon = 0.274\n",
"epoch #61\tmean reward = 14.140\tepsilon = 0.271\n",
"epoch #62\tmean reward = 16.810\tepsilon = 0.268\n",
"epoch #63\tmean reward = 13.670\tepsilon = 0.265\n",
"epoch #64\tmean reward = 17.080\tepsilon = 0.263\n",
"epoch #65\tmean reward = 18.200\tepsilon = 0.260\n",
"epoch #66\tmean reward = 17.600\tepsilon = 0.258\n",
"epoch #67\tmean reward = 15.120\tepsilon = 0.255\n",
"epoch #68\tmean reward = 18.810\tepsilon = 0.252\n",
"epoch #69\tmean reward = 18.800\tepsilon = 0.250\n",
"epoch #70\tmean reward = 18.910\tepsilon = 0.247\n",
"epoch #71\tmean reward = 27.050\tepsilon = 0.245\n",
"epoch #72\tmean reward = 16.130\tepsilon = 0.242\n",
"epoch #73\tmean reward = 15.180\tepsilon = 0.240\n",
"epoch #74\tmean reward = 14.980\tepsilon = 0.238\n",
"epoch #75\tmean reward = 15.770\tepsilon = 0.235\n",
"epoch #76\tmean reward = 16.520\tepsilon = 0.233\n",
"epoch #77\tmean reward = 16.950\tepsilon = 0.231\n",
"epoch #78\tmean reward = 15.550\tepsilon = 0.228\n",
"epoch #79\tmean reward = 15.960\tepsilon = 0.226\n",
"epoch #80\tmean reward = 20.220\tepsilon = 0.224\n",
"epoch #81\tmean reward = 17.350\tepsilon = 0.222\n",
"epoch #82\tmean reward = 13.570\tepsilon = 0.219\n",
"epoch #83\tmean reward = 14.020\tepsilon = 0.217\n",
"epoch #84\tmean reward = 17.210\tepsilon = 0.215\n",
"epoch #85\tmean reward = 14.720\tepsilon = 0.213\n",
"epoch #86\tmean reward = 12.890\tepsilon = 0.211\n",
"epoch #87\tmean reward = 12.950\tepsilon = 0.209\n",
"epoch #88\tmean reward = 17.040\tepsilon = 0.206\n",
"epoch #89\tmean reward = 14.390\tepsilon = 0.204\n",
"epoch #90\tmean reward = 16.700\tepsilon = 0.202\n",
"epoch #91\tmean reward = 13.330\tepsilon = 0.200\n",
"epoch #92\tmean reward = 14.260\tepsilon = 0.198\n",
"epoch #93\tmean reward = 18.560\tepsilon = 0.196\n",
"epoch #94\tmean reward = 17.270\tepsilon = 0.194\n",
"epoch #95\tmean reward = 14.180\tepsilon = 0.192\n",
"epoch #96\tmean reward = 16.940\tepsilon = 0.191\n",
"epoch #97\tmean reward = 13.840\tepsilon = 0.189\n",
"epoch #98\tmean reward = 15.340\tepsilon = 0.187\n",
"epoch #99\tmean reward = 16.350\tepsilon = 0.185\n"
]
}
],
"source": [
"for i in range(100):\n",
" session_rewards = [generate_session(epsilon=epsilon, train=True) for _ in range(100)]\n",
" print(\"epoch #{}\\tmean reward = {:.3f}\\tepsilon = {:.3f}\".format(i, np.mean(session_rewards), epsilon))\n",
" \n",
" epsilon *= 0.99\n",
" assert epsilon >= 1e-4, \"Make sure epsilon is always nonzero during training\"\n",
" \n",
" if np.mean(session_rewards) > 300:\n",
" print (\"You Win!\")\n",
" break\n"
]
},
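{
"cell_type": "markdown",
"metadata": {},
"source": [
"The epsilon values in the log above follow a geometric decay. Since the loop multiplies epsilon by 0.99 after every epoch, starting from 0.5, the value printed at epoch $k$ is\n",
"\n",
"$$ \\epsilon_k = 0.5 \\cdot 0.99^k $$\n",
"\n",
"For the last epoch this gives\n",
"\n",
"$$ \\epsilon_{99} = 0.5 \\cdot 0.99^{99} \\approx 0.5 \\cdot 0.3697 \\approx 0.185 $$\n",
"\n",
"which matches the final log line. At this rate epsilon would only fall below 0.01 after about 390 epochs, since $0.5 \\cdot 0.99^k < 0.01$ requires $k > \\ln 50 / \\ln(1/0.99) \\approx 389.2$. Within 100 epochs the agent therefore still takes a random action on roughly one step in five."
]
},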
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### How to interpret results\n",
"\n",
"\n",
"Welcome to the f.. world of deep f...n reinforcement learning. Don't expect the agent's reward to go up smoothly. Hope for it to increase eventually. If it deems you worthy.\n",
"\n",
"Seriously though,\n",
"* __mean reward__ is the average reward per game. For a correct implementation it may stay low for some 10 epochs, then start growing while oscillating insanely, and converge by ~50-100 steps depending on the network architecture. \n",
"* If it never reaches the target score by the end of the for loop, try increasing the number of hidden neurons or look at the epsilon.\n",
"* __epsilon__ is the agent's willingness to explore. If you see that epsilon is already < 0.01 before the mean reward is at least 200, just reset it back to 0.1 - 0.5."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Record videos\n",
"\n",
"As usual, we now use `gym.wrappers.Monitor` to record a video of our agent playing the game. Unlike our previous attempts with state binarization, this time we expect our agent to act ~~(or fail)~~ more smoothly since there's no more binarization error at play.\n",
"\n",
"As you already did with tabular q-learning, we set epsilon=0 for final evaluation to prevent the agent from exploring itself to death."
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"#record sessions\n",
"import gym.wrappers\n",
"env = gym.wrappers.Monitor(gym.make(\"CartPole-v0\"),directory=\"videos\",force=True)\n",
"sessions = [generate_session(epsilon=0, train=False) for _ in range(100)]\n",
"env.close()\n"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"<video width=\"640\" height=\"480\" controls>\n",
" <source src=\"./videos/openaigym.video.0.51645.video000000.mp4\" type=\"video/mp4\">\n",
"</video>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"execution_count": 109,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#show video\n",
"from IPython.display import HTML\n",
"import os\n",
"\n",
"video_names = list(filter(lambda s:s.endswith(\".mp4\"),os.listdir(\"./videos/\")))\n",
"\n",
"HTML(\"\"\"\n",
"<video width=\"640\" height=\"480\" controls>\n",
" <source src=\"{}\" type=\"video/mp4\">\n",
"</video>\n",
"\"\"\".format(\"./videos/\"+video_names[-2])) #this may or may not be _last_ video. Try other indices"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 1
}