@pat-coady
Last active April 6, 2024 22:48
Sutton and Barto Racetrack: Off-Policy Monte Carlo Control
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Sutton and Barto Racetrack: Monte Carlo Control\n",
"Exercise 5.8 from *Reinforcement Learning: An Introduction* by Sutton and Barto.\n",
"\n",
"This is another surprisingly difficult, but very worthwhile exercise from Sutton and Barto. About 50% of the effort was building the racetrack environment, but this same environment will be useful for evaluating learning strategies in coming chapters.\n",
"\n",
"The learning strategy is from the algorithm box titled \"Off-policy every-visit MC control\". Even for a small environment (~14,000 states and 9 actions) learning is very slow. Various hyper-parameters have to be tuned to get reasonable training time: initial Q(s,a) values, off-policy probability (epsilon) and initial policy. It is readily apparent this approach will not scale to real problems. This problem provides excellent understanding of the challenge of exploring and learning in large state-action spaces. What if your self-driving car sees a state that it never encountered in training?\n",
"\n",
"Python Notebook by Patrick Coady: [Learning Artificial Intelligence](https://learningai.io/)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"import random\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"class RaceTrack(object):\n",
" \"\"\"\n",
" RaceTrack object maintains and updates the race track \n",
" state. Interaction with the class is through\n",
" the take_action() method, which returns the reward. The\n",
" successor state (s') is available via get_state().\n",
" \n",
" The class constructor is given a race course as a list of \n",
" strings. The constructor loads the course and initializes \n",
" the environment state.\n",
" \"\"\"\n",
" def __init__(self, course):\n",
" \"\"\"\n",
" Load race course, set any min or max limits in the \n",
" environment (e.g. max speed), and set initial state.\n",
" Initial state is random position on start line with \n",
" velocity = (0, 0).\n",
" \n",
" Args:\n",
" course: List of text strings used to construct\n",
" race-track.\n",
" '-': start line\n",
" '+': finish line\n",
" 'o': track\n",
" 'W': wall\n",
" \"\"\"\n",
" self.NOISE = 0.0\n",
" self.MAX_VELOCITY = 4\n",
" self.start_positions = []\n",
" self.course = None\n",
" self._load_course(course)\n",
" self._random_start_position()\n",
" self.velocity = np.array([0, 0], dtype=np.int16)\n",
"\n",
"\n",
" def take_action(self, action):\n",
" \"\"\"\n",
" Take action, return reward; successor state via get_state()\n",
" \n",
" Args:\n",
" action: 2-tuple of requested change in velocity in x- and\n",
" y-direction. valid action is -1, 0, +1 in each axis.\n",
" \n",
" Returns:\n",
" reward: integer\n",
" \"\"\"\n",
" if self.is_terminal_state():\n",
" return 0\n",
" self._update_velocity(action)\n",
" self._update_position()\n",
"\n",
" return -1\n",
" \n",
"\n",
" def get_state(self):\n",
" \"\"\"Return 2-tuple: (position, velocity). Each is a 2D numpy array.\"\"\"\n",
" return self.position.copy(), self.velocity.copy()\n",
" \n",
"\n",
" def _update_velocity(self, action):\n",
" \"\"\"\n",
" Update x- and y-velocity. Clip at 0 and self.MAX_VELOCITY\n",
" \n",
" Args:\n",
" action: 2-tuple of requested change in velocity in x- and\n",
" y-direction. valid action is -1, 0, +1 in each axis. \n",
" \"\"\"\n",
" if np.random.rand() > self.NOISE:\n",
" self.velocity += np.array(action, dtype=np.int16)\n",
" self.velocity = np.minimum(self.velocity, self.MAX_VELOCITY)\n",
" self.velocity = np.maximum(self.velocity, 0)\n",
" \n",
" def reset(self):\n",
" self._random_start_position()\n",
" self.velocity = np.array([0, 0], dtype=np.int16)\n",
"\n",
" def _update_position(self):\n",
" \"\"\"\n",
" Update position based on present velocity. Check at fine time \n",
" scale for wall or finish. If wall is hit, set position to random\n",
" position at start line. If finish is reached, set position to \n",
" first crossed point on finish line.\n",
" \"\"\"\n",
" for tstep in range(0, self.MAX_VELOCITY+1):\n",
" t = tstep / self.MAX_VELOCITY\n",
" pos = self.position + np.round(self.velocity * t).astype(np.int16)\n",
" if self._is_wall(pos):\n",
" self._random_start_position()\n",
" self.velocity = np.array([0, 0], dtype=np.int16)\n",
" return\n",
" if self._is_finish(pos):\n",
" self.position = pos\n",
" self.velocity = np.array([0, 0], dtype=np.int16)\n",
" return\n",
" self.position = pos\n",
" \n",
"\n",
" def _random_start_position(self):\n",
" \"\"\"Set car to random position on start line\"\"\"\n",
" self.position = np.array(random.choice(self.start_positions),\n",
" dtype=np.int16)\n",
" \n",
"\n",
" def _load_course(self, course):\n",
" \"\"\"Load course. Internally represented as numpy array\"\"\"\n",
" y_size, x_size = len(course), len(course[0])\n",
" self.course = np.zeros((x_size, y_size), dtype=np.int16)\n",
" for y in range(y_size):\n",
" for x in range(x_size):\n",
" point = course[y][x]\n",
" if point == 'o':\n",
" self.course[x, y] = 1\n",
" elif point == '-':\n",
" self.course[x, y] = 0\n",
" elif point == '+':\n",
" self.course[x, y] = 2\n",
" elif point == 'W':\n",
" self.course[x, y] = -1\n",
" # flip left/right so (0,0) is in bottom-left corner\n",
" self.course = np.fliplr(self.course)\n",
" for y in range(y_size):\n",
" for x in range(x_size):\n",
" if self.course[x, y] == 0:\n",
" self.start_positions.append((x, y))\n",
" \n",
"\n",
" def _is_wall(self, pos):\n",
" \"\"\"Return True if position is wall\"\"\"\n",
" return self.course[pos[0], pos[1]] == -1\n",
" \n",
"\n",
" def _is_finish(self, pos):\n",
" \"\"\"Return True if position is finish line\"\"\"\n",
" return self.course[pos[0], pos[1]] == 2\n",
" \n",
"\n",
" def is_terminal_state(self):\n",
" \"\"\"Return True at episode terminal state\"\"\"\n",
" return (self.course[self.position[0], \n",
" self.position[1]] == 2)\n",
" \n",
" \n",
" def action_to_tuple(self, a):\n",
" \"\"\"Convert integer action to 2-tuple: (ax, ay)\"\"\"\n",
" ax = a // 3 - 1\n",
" ay = a % 3 - 1\n",
" \n",
" return ax, ay\n",
" \n",
" \n",
" def tuple_to_action(self, a):\n",
" \"\"\"Convert 2-tuple to integer action: {0-8}\"\"\"\n",
" return int((a[0] + 1) * 3 + a[1] + 1)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Race Track from Sutton and Barto Figure 5.6\n",
"\n",
"big_course = ['WWWWWWWWWWWWWWWWWW',\n",
" 'WWWWooooooooooooo+',\n",
" 'WWWoooooooooooooo+',\n",
" 'WWWoooooooooooooo+',\n",
" 'WWooooooooooooooo+',\n",
" 'Woooooooooooooooo+',\n",
" 'Woooooooooooooooo+',\n",
" 'WooooooooooWWWWWWW',\n",
" 'WoooooooooWWWWWWWW',\n",
" 'WoooooooooWWWWWWWW',\n",
" 'WoooooooooWWWWWWWW',\n",
" 'WoooooooooWWWWWWWW',\n",
" 'WoooooooooWWWWWWWW',\n",
" 'WoooooooooWWWWWWWW',\n",
" 'WoooooooooWWWWWWWW',\n",
" 'WWooooooooWWWWWWWW',\n",
" 'WWooooooooWWWWWWWW',\n",
" 'WWooooooooWWWWWWWW',\n",
" 'WWooooooooWWWWWWWW',\n",
" 'WWooooooooWWWWWWWW',\n",
" 'WWooooooooWWWWWWWW',\n",
" 'WWooooooooWWWWWWWW',\n",
" 'WWooooooooWWWWWWWW',\n",
" 'WWWoooooooWWWWWWWW',\n",
" 'WWWoooooooWWWWWWWW',\n",
" 'WWWoooooooWWWWWWWW',\n",
" 'WWWoooooooWWWWWWWW',\n",
" 'WWWoooooooWWWWWWWW',\n",
" 'WWWoooooooWWWWWWWW',\n",
" 'WWWoooooooWWWWWWWW',\n",
" 'WWWWooooooWWWWWWWW',\n",
" 'WWWWooooooWWWWWWWW',\n",
" 'WWWW------WWWWWWWW']\n",
"\n",
"# Tiny course for debug\n",
"\n",
"tiny_course = ['WWWWWW',\n",
" 'Woooo+',\n",
" 'Woooo+',\n",
" 'WooWWW',\n",
" 'WooWWW',\n",
" 'WooWWW',\n",
" 'WooWWW',\n",
" 'W--WWW',]"
]
},
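{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Quick environment sanity check (a brief illustrative sketch, not part\n",
"# of the original experiment). Reward is -1 per time step until the\n",
"# finish line is reached.\n",
"demo = RaceTrack(tiny_course)\n",
"print(demo.get_state()) # random start-line position, velocity (0, 0)\n",
"print(demo.take_action((0, 1))) # accelerate upward: reward is -1\n",
"print(demo.get_state()) # position advanced, velocity now (0, 1)\n",
"print(demo.is_terminal_state()) # False: finish line not yet reached"
]
},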
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Episode 10000\n",
"Episode 20000\n",
"Episode 30000\n",
"Episode 40000\n",
"Episode 50000\n",
"Episode 60000\n",
"Episode 70000\n",
"Episode 80000\n",
"Episode 90000\n",
"Episode 100000\n",
"Episode 110000\n",
"Episode 120000\n",
"Episode 130000\n",
"Episode 140000\n",
"Episode 150000\n",
"Episode 160000\n",
"Episode 170000\n",
"Episode 180000\n",
"Episode 190000\n",
"Episode 200000\n"
]
}
],
"source": [
"# Problem Initialization\n",
"\n",
"course = big_course\n",
"x_size, y_size = len(course[0]), len(course)\n",
"# initialize Q to large negative - otherwise too much exploration\n",
"Q = np.zeros((x_size, y_size, 5, 5, 3, 3)) - 40 # 5 = num speeds, x & y\n",
"C = np.zeros((x_size, y_size, 5, 5, 3, 3)) # 3 = num actions, x & y\n",
"# initialize policy to all 4s: action = (0, 0)\n",
"pi = np.zeros((x_size, y_size, 5, 5), dtype=np.int16) + 4\n",
"actions = [(x, y) for x in range(3) for y in range(3)]\n",
"eps = 0.1 # epsilon for epsilon-greedy policy\n",
"N = 200000 # num episodes\n",
"gamma = 0.9\n",
"track = RaceTrack(course)\n",
"explore_map = np.zeros((x_size, y_size)) # track explored positions\n",
"\n",
"for e in range(N):\n",
" if (e+1) % 10000 == 0: print('Episode {}'.format(e+1))\n",
" # generate episode\n",
" track.reset()\n",
" episode = []\n",
" while not track.is_terminal_state():\n",
" s = track.get_state()\n",
" s_x, s_y = s[0][0], s[0][1]\n",
" s_vx, s_vy = s[1][0], s[1][1]\n",
" explore_map[s_x, s_y] += 1.0\n",
" if np.random.rand() > eps:\n",
" a = pi[s_x, s_y, s_vx, s_vy]\n",
" else:\n",
" a = random.randrange(9)\n",
" a = track.action_to_tuple(a)\n",
" r_prime = track.take_action(a)\n",
" episode.append((s, a, r_prime))\n",
" \n",
" # update Q and policy\n",
" G = 0\n",
" W = 1\n",
" while len(episode) > 0:\n",
" (s, a, r_prime) = episode.pop()\n",
" s_x, s_y = s[0][0], s[0][1]\n",
" s_vx, s_vy = s[1][0], s[1][1]\n",
" a_x, a_y = a\n",
" s_a = (s_x, s_y, s_vx, s_vy, a_x, a_y)\n",
" G = gamma * G + r_prime\n",
" C[s_a] += W\n",
" Q[s_a] += W / C[s_a] * (G - Q[s_a])\n",
" q_max = -1e6\n",
" for i in range(9): # argmax over actions\n",
" act = track.action_to_tuple(i)\n",
" s_a_max = (s_x, s_y, s_vx, s_vy, act[0], act[1])\n",
" if Q[s_a_max] > q_max:\n",
" a_max = i\n",
" q_max = Q[s_a_max]\n",
" pi[s_x, s_y, s_vx, s_vy] = a_max\n",
" if track.tuple_to_action(a) != a_max:\n",
" break # taken action is not greedy: importance weight -> 0\n",
" W *= 1 / (1 - 8 / 9 * eps) # 1/b(a|s) under eps-greedy behavior"
]
},
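{
"cell_type": "markdown",
"metadata": {},
"source": [
"A note on the weight update above (added for clarity): the behavior policy is epsilon-greedy, so the probability of taking the greedy action is $(1-\\epsilon) + \\epsilon/9 = 1 - 8\\epsilon/9$. Whenever the greedy action was taken, the importance-sampling weight is therefore multiplied by $1/b(a|s) = 1/(1 - 8\\epsilon/9)$; when a non-greedy action was taken, the target (greedy) policy assigns it probability zero and the backward pass terminates."
]
},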
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample trajectory on learned policy:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAANQAAAFfCAYAAAA23uK3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAEP9JREFUeJzt3X/sXXV9x/Hn25aC1DEizFanTrD+mFkkWqWyWWTDjA0T\nnNGIYGKALAaLxvCPxMwMh5mLGhqmpovLFCRqEyYa3QatiopDpE2oTAGRiEUQaPllWlNaa9v3/ji3\nern9/rj3e9/3e8/9fp+P5CS9555zzzvfe179fM75nHNPZCaSajxj3AVIC4mBkgoZKKmQgZIKGSip\nkIGSChkoqZCBkgoZKKnQ0nEXEBEnAGcB9wP7xluNNKVjgBcBmzPziRmXzMyRTMAlwHZgL3Ab8Npp\nljsfSCenCZjOn22/H0kLFRHnAlcC7wa2ApcCmyPipZn5eM/i9wMs4+n9z/2deZPI2sdjVLUf6nw2\nnX11JqPq8l0KfCYzrwWIiIuBNwEXAR/vWXYfNGHqDlQwuQd41j4e81D7rIck5duPiKOA1cBNh+dl\n07f7FnBa9fakNhlFoE8ElgA7e+bvBFaOYHtSa4z9LN9h+2ma7MMOAgdoUYFaFA7Q7HvdcoD1R7G/\nPk5T04qe+SuAHdOt1HtSYpLDtGTcBQxhsde+lCP3u0P0P55T3uXLzN8CtwNnHp4XEdF5fWu/nzOp\nYQJrH5c21D6qGtYD10TE7fz+tPmxwDUj2t5I7Nky7go0b06dvmO3bds2Vq9e3dfHjCRQmXldRJwI\nXEHT1bsDOCszHxvF9qS2GFkrmZkbgA2j+nypjSZ1DE9qJQMlFTJQUiEDJRUyUFIhAyUVMlBSoTZc\nrTHvvAJCo2ILJRUyUFIhAyUVMlBSIQMlFTJQUiEDJRUyUFIhAyUVMlBSIQMlFTJQUiEDJRUyUFIh\nAyUVMlBSIQMlFTJQUqEFdwu8t7e33Aw/yv87W2P2ZVrKFkoqZKCkQgZKKmSgpEIGSipkoKRCBkoq\nZKCkQhMzsOuA7QIxwYO2/ShvoSLi8og41DPdXb0dqY1G1ULdCZwJHP7v6MCItiO1yqgCdSAzHxvR\nZ0utNaqTEi+JiIci4r6I+EJEvGBE25FaZRSBug24ADgLuBg4CfheRCwfwbakVinv8mXm5q6Xd0bE\nVuAXwNuBq6dbbz+/P+A6bMkoCpRmsHHjRjZu3Pi0ebt27ep7/cjs4/6UIXVC9c3M/Icp3ns1cPsx\nzNxcetpcIzXDfVrbtm1j9erVAKszc9tMHzPygd2IeBbwYuCRUW9LGrdRjEN9IiJOj4g/iYg/B75K\nc9p84yyrShNvFIcozwe+BJwAPAbcArwuM58YwbY0jH5uR4cFf3VDpVGclDiv+jOlSeHFsVIhAyUV\nMlBSIQMlFTJQUiEDJRUyUFIhrz1dzPocsF2+ZsR1tML0f4tDA3yKLZRUyEBJhQyUVMhASYUMlFTI\nQEmFDJRUyEBJhRzYXcQWx4Dt/LKFkgoZKKmQgZIKGSipkIGSChkoqZCBkgoZKKmQgZIKeaXEJOrn\nN8n9PfKxsIWSChkoqZCBkgoZKKmQgZIKGSipkIGSChkoqZADu5PIQdvWGriFioi1EfH1iHgoIg5F\nxDlTLHNFRDwcEU9FxDcjYlVNuVK7zaXLtxy4A1gHHHENTERcBrwXeDdwKrAH2BwRy4aoU5oIA3f5\nMnMTsAkgIqbqe7wf+Ehm/ndnmXcBO4G/A66be6lS+5WelIiIk4CVwE2H52XmbmALcFrltqQ2qj7L\nt5KmG7izZ/7OznvSgtaas3z7OfIZcktoUYFaFA4AB3vm9XGzzO9U7687aHKxgqe3UiuAH8604jIc\nFNP4LeXIUBwC9vW5fuk+nJnbaUJ15uF5EXEcsAa4tXJbUhsN3EJFxHJgFb/voZ0cEacAT2bmg8BV\nwIci4mfA/cBHgF8CXyupWGqxuXT5XgN8h6Zr
mcCVnfmfBy7KzI9HxLHAZ4Djgf8F/jYz9xfUK7Xa\nXMahbmaWrmJmfhj48NxKkiaX5wGkQgZKKmSgpEIGSipkoKRCBkoqZKCkQgZKKmSgpEIGSipkoKRC\nBkoqZKCkQgZKKmSgpEIGSipkoKRC/kpXm/TzdHfwYQEtZgslFTJQUiEDJRUyUFIhAyUVMlBSIQMl\nFTJQUiEHdtvEAduJZwslFTJQUiEDJRUyUFIhAyUVMlBSIQMlFTJQUqGJGdhdvqa/5fZsGW0d0kwG\nbqEiYm1EfD0iHoqIQxFxTs/7V3fmd0831JUstddcunzLgTuAdcB0P4JwI7ACWNmZzptTddKEGbjL\nl5mbgE0AETHdxWe/yczHhilMmkSjOilxRkTsjIh7ImJDRDx7RNuRWmUUJyVuBK4HtgMvBv4FuCEi\nTsvMPn8nS5pM5YHKzOu6Xt4VET8G7gPOAL4z3Xr7gd7+45JRFCjN4ABwsGfeIK3AyMehMnM78Diw\naqbllgFH90yGSfNtKUfuh8sGWH/kgYqI5wMnAI+MelvSuA3cCETEcprW5nAP7eSIOAV4sjNdTnMM\ntaOz3MeAe4HNFQVLbTaXXtVraI6FsjNd2Zn/eZqxqVcC7wKOBx6mCdI/ZuZvh652ofO3zSfeXMah\nbmbmruLfzL0cabJ5caxUyEBJhQyUVMhASYUMlFTIQEmFDJRUyMvl2sQB24lnCyUVMlBSIQMlFTJQ\nUiEDJRUyUFIhAyUVMlBSIQMlFTJQUiEDJRUyUFIhAyUVMlBSIQMlFTJQUiEDJRUyUFIhAyUVMlBS\nIQMlFTJQUiEDJRUyUFIhAyUVMlBSIQMlFTJQUqGBHhYQER8E3gK8HNgL3Apclpn3di1zNLAeOBc4\nmuYp8Osy89GqohcsnwI/8QZtodYCnwLWAG8EjgK+ERHP7FrmKuBNwFuB04HnAdcPX6rUfgO1UJl5\ndvfriLgAeBRYDdwSEccBFwHvyMybO8tcCPwkIk7NzK0lVUstNewx1PFAAk92Xq+mCelNhxfIzJ8C\nDwCnDbktqfXmHKiICJru3S2ZeXdn9kpgf2bu7ll8Z+c9aUEb5gmGG4BXAK+vKGQ/0HuovQQfsaj5\ndQA42DOvz1NFwBz314j4NHA2sDYzH+56awewLCKO62mlVnTem9YyPIev8VvKkaE4BOzrc/2B9+FO\nmN4M/GVmPtDz9u00IT+za/mXAS8EfjDotqRJM+g41AbgPOAcYE9ErOi8tSsz92Xm7oj4LLA+In4F\n/Br4JPB9z/BpMRi0y3cxTZfyuz3zLwSu7fz7Uppu6JdpBnY3AZfMvcRFxAHbiTfoONSsXcTM/A3w\nvs4kLSqeB5AKGSipkIGSChkoqZCBkgoZKKmQgZIKGSipkIGSChkoqZCBkgoZKKmQgZIKGSipkIGS\nChkoqZCBkgotuF/pWr5m9mX2bBl9HVqcbKGkQgZKKmSgpEIGSipkoKRCBkoqZKCkQgZKKrTgBnYn\nmg+tnni2UFIhAyUVMlBSIQMlFTJQUiEDJRUyUFIhAyUVMlBSoYGulIiIDwJvAV4O7AVuBS7LzHu7\nlvkucHrXagl8JjPXDV3tQucVEBNv0BZqLfApYA3wRuAo4BsR8cyuZRL4d2AFsBJ4LvCB4UuV2m+g\nFiozz+5+HREXAI8Cq4Fbut56KjMfG7o6acIMewx1PE2L9GTP/HdGxGMR8eOI+GhPCyYtWHO+2jwi\nArgKuCUz7+5664vAL4CHgVcCHwdeCrxtiDqliTDM7RsbgFcAf9E9MzP/o+vlXRGxA/hWRJyUmdun\n+7D9QO8h+ZIhC5QGdQA42DOvz5tqgDnurxHxaeBsYG1mPjLL4ltosrIKmDZQy/AcvsZvKUeG4hCw\nb4D1B9IJ05uBN2TmA32s8iqakM8WPGniDToOtQE4DzgH2BMRKzpv7crMfRFxMnA+cAPwBHAKsB64\nOTPvrCtb
aqdBW6iLaVqb7/bMvxC4luZQ6I3A+4HlwIPAfwL/PFSV0oQYdBxqxsOczPwlcMYwBUmT\nzPMAUiEDJRUyUFIhAyUVMlBSIQMlFTJQUiGvPW0Tf9t84tlCSYUMlFTIQEmFDJRUyEBJhQyUVMhA\nSYUMlFTIQEmFvFKiTbwCYuLZQkmFDJRUyEBJhQyUVMhASYUMlFTIQEmFDJRUyIHdNvEW+IlnCyUV\nMlBSIQMlFTJQUiEDJRUyUFIhAyUVMlBSIQMlFRroSomIuBh4D/Cizqy7gCsyc1Pn/aOB9cC5wNHA\nZmBdZj5aVXCF5Wv6W27PltHWcQSvgJh4g7ZQDwKXAa8GVgPfBr4WEX/aef8q4E3AW4HTgecB19eU\nKrXfQC1UZv5Pz6wPRcR7gNdFxEPARcA7MvNmgIi4EPhJRJyamVtLKpZabM7HUBHxjIh4B3As8AOa\nFmspcNPhZTLzp8ADwGlD1ilNhIGvNo+IP6MJ0DHAr4G3ZOY9EfEqYH9m7u5ZZSewcuhKpQkwl9s3\n7gFOAf4QeBtwbUScPmwh+4HeQ/IleH+J5tcB4GDPvD5vqgHmsL9m5gHg552XP4yIU4H3A9cByyLi\nuJ5WagWwY7bPXYbn8DV+SzkyFIeAfX2uX7EPP4PmFPntNAE/8/AbEfEy4IU0XURpwRt0HOqjwI00\nJxr+AHgn8AbgrzNzd0R8FlgfEb+iOb76JPB9z/BpsRi0y/cc4PPAc4FdwI9owvTtzvuX0nRBv0zT\nam0CLqkpdRHwFviJN+g41N/P8v5vgPd1JmnR8TyAVMhASYUMlFTIQEmFDJRUyEBJhQyUVMhASYW8\nmHsGj/Z5q3yV52zp7wqI+a5L/bOFkgoZKKlQawN1YNwFDOEr4y5gCJP8d29D7a0NVO9dk5Pkq+Mu\nYAiT/HdvQ+2tDZQ0iQyUVMhASYXaMA51DDQ/hNEtp5g33340x/V2z3HdZ9/T33JPzuGzpzLV37cN\nf/e5GlXtXZ95zGzLRuYgP5JULyLOB7441iKk/rwzM7800wJtCNQJwFnA/fT/a03SfDqG5gEZmzPz\niZkWHHugpIXEkxJSIQMlFTJQUiEDJRUyUFKh1gUqIi6JiO0RsTcibouI1467pn5ExOURcahnunvc\ndU0lItZGxNcj4qFOnedMscwVEfFwRDwVEd+MiFXjqLXXbLVHxNVTfA83zFd9rQpURJwLXAlcDrwK\n+D9gc0ScONbC+ncnzeN7Vnam14+3nGktB+4A1jHF448i4jLgvcC7gVOBPTTfw7L5LHIaM9becSNP\n/x7Om5/SgMxszQTcBvxr1+sAfgl8YNy19VH75cC2cdcxh7oPAef0zHsYuLTr9XHAXuDt4663j9qv\nBr4yrppa00JFxFE0z+ntfkZvAt9icp7R+5JOV+S+iPhCRLxg3AUNKiJOovlfvft72A1sYXK+hzMi\nYmdE3BMRGyLi2fO14dYECjiR5imgO3vmT8ozem8DLqC5jOpi4CTgexGxfJxFzcFKmq7UpH4PNwLv\nAv4K+ADN88tuiIh5eQZQG642XxAyc3PXyzsjYivwC+DtNN0QzYPMvK7r5V0R8WPgPuAM4Duj3n6b\nWqjHae5iXtEzv69n9LZNZu4C7gVacXZsADtojl0XyvewnWbfmpfvoTWByszf0jynt/sZvdF5feu4\n6pqriHgW8GLgkXHXMojODriDp38PxwFrmMzv4fnACczT99C2Lt964JqIuB3YSvOI0WOBa8ZZVD8i\n4hPAf9F08/4Y+CeaH+LZOM66ptI5rltF0xIBnBwRpwBPZuaDwFXAhyLiZzS31XyE5mzr18ZQ7tPM\nVHtnuhy4nuY/hVXAx2h6CpuP/LQRGPepzylOha6j+RL30jw9/jXjrqnPujfS7HR7aR7q/SXgpHHX\nNU2tb6A55XywZ/pc1zIfpjl9/lRnZ1w17rpnq53mvqVNNGHaB/wc+Dfgj+
arPu+Hkgq15hhKWggM\nlFTIQEmFDJRUyEBJhQyUVMhASYUMlFTIQEmFDJRUyEBJhf4fEEXXbaR2UNUAAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f57caa86358>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Run learned policy on test case\n",
"\n",
"pos_map = np.zeros((x_size, y_size))\n",
"track.reset()\n",
"G = 0 # accumulate return along the trajectory\n",
"for e in range(1000):\n",
" s = track.get_state()\n",
" s_x, s_y = s[0][0], s[0][1]\n",
" s_vx, s_vy = s[1][0], s[1][1]\n",
" pos_map[s_x, s_y] += 1 # exploration map\n",
" act = track.action_to_tuple(pi[s_x, s_y, s_vx, s_vy])\n",
" G += track.take_action(act)\n",
" if track.is_terminal_state(): break \n",
"\n",
"print('Sample trajectory on learned policy:')\n",
"pos_map = (pos_map > 0).astype(np.float32)\n",
"pos_map += track.course # overlay track course\n",
"plt.imshow(np.flipud(pos_map.T), cmap='hot', interpolation='nearest')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learned Policy\n",
"\n",
"The \"racecar\" accelerates to maximum speed out of the gate (4 cells per time step). It then slows before the curve and accelerates to the finish. For some reason it slows down a bit at y-position 17; another half-million training episodes would likely smooth that out, but it is better to try a different learning algorithm."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Policy exploration heat map:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAANQAAAFfCAYAAAA23uK3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAEtVJREFUeJzt3X+sZGV9x/H3112WlbXbjaC7WlB+LEpNI9FVkNZFFFJa\nTLDEBgQTAqQxuGgMfwgxMaXF1kYJG6pmjU0tP2K1oVKjbWFXQcUiAsnqtvwQqbjA8mOXBeRHYZdl\nvU//OLNxmHv33mfufOfOmbvvVzIJc+aZc77szOc+Z87znHOilIKkHK8YdQHSfGKgpEQGSkpkoKRE\nBkpKZKCkRAZKSmSgpEQGSkq0cNQFRMSBwMnAA8DO0VYjTWkxcCiwoZTy5LQtSylDeQAXAJuBHcBt\nwDv30u4soPjwMQaPs2b63g+lh4qIM4DLgY8AdwAXAhsi4k2llCd6mj8AsIiX73/u6iwbR9Y+GsOq\nfaKzbjrf1ekMa5fvQuArpZRrACLifOD9wHnA53va7oQmTN2BCsb3B561j8Yc1D7jT5L07UfEfsAq\n4KY9y0qzb3cjcFz29qQ2GUagDwIWANt6lm8DVgxhe1JrjPwo3x67aLrsPX4D7KZFBWqfsJvmu9et\n9PH+YXxfn6CpaXnP8uXA1r29qfegxDiHacGoCxjAvl77QiZ/7yaoH89J3+UrpbwEbARO3LMsIqLz\n/Nba9YxrmMDaR6UNtQ+rhrXAVRGxkd8eNj8AuGpI29unZPUi+1W2yxxtX5y4vdr6a7yUtJ6hBKqU\ncm1EHARcSrOrtwk4uZSyfRjbk9piaL1kKWUdsG5Y65faaFzH8KRWMlBSIgMlJTJQUiIDJSUyUFIi\nAyUlasNsDXXUzoA4oKJNzWyD2u3VzG7onVA6SLvaGRA19dfWlcUeSkpkoKREBkpKZKCkRAZKSmSg\npEQGSkpkoKREDuy2SO0g5AtJ66o97btmoHVJ5bpqBqVr/v+grq6nK9eVxR5KSmSgpEQGSkpkoKRE\nBkpKZKCkRAZKSmSgpEQGSkrkTIl5qub08NpTzWtmXdRe4P+Ra2dus/T0ypVV8BR4aYwZKCmRgZIS\nGSgpkYGSEhkoKZGBkhIZKCmRA7stUnut8ddUtHk+aT0AWyraLKtc16EVg7ZnV67raxVtxn5gNyIu\niYiJnsc92duR2mhYPdRdwIlAdJ7vHtJ2pFYZVqB2l1K2D2ndUmsN66DEkRHxSETcHxFfi4hDhrQd\nqVWGEajbgHOAk4HzgcOAH0VE7aXbpLGVvstXStnQ9fSuiLgDeBA4Hbhyb+/bxW9/cO2xYBgFStPY\nzeQjg6WP9w/9+1pKeSYi7gNWTtduEQ6KafQWMjkUE9Sf7zX073BEvAo4Anhs2NuSRm0Y41CXRcTx\nEfHGiPhD4Fs0Pek3srcltc0wdvkOBr4OHAhsB24B3lVKeXII25pXav+6PZe0rjsrBzYurphS8aW6\nVVW5prLdXM+CqDGMgxJnZq9TGhceB5ASGSgpkYGSEhkoKZGBkhIZKCmRgZISOfe0RSYq2x1e0eZ/\nK9ocUnkO/FN1zdK0ccC2lj2UlMhASYkMlJTIQEmJDJSUyEBJiQyUlMhASYkc2B1DuyraXFTR5tJB\nC9Ek9lBSIgMlJTJQUiIDJSUyUFIiAyUlMlBSIgMlJTJQUiJnSrTIcZXtNlTcAvzQtwxUimbJHkpK\nZKCkRAZKSmSgpEQGSkpkoKREBkpKZKCkRA7stsgvKtv9TcWg7dMDVaLZ6ruHiojVEfGdiHgkIiYi\n4tQp2lwaEY9GxAsR8b2IWJlTrtRus9nlWwJsAtYApffFiLgY+BjwEeAY4HlgQ0QsGqBOaSz0vctX\nSlkPrAeIiJiiySeAz5RS/qPT5mxgG/BnwLWzL1Vqv9SDEhFxGLACuGnPslLKs8Dt1M/9lMZW9lG+\nFTS7gdt6lm/rvCbNa605yrcL6N1/XECLCtQ+
YTeT76A46UDBNLK/r1tpcrGcl/dSy4GfTffGRTgo\nptFbyORQTAA7K9+f+h0upWymCdWJe5ZFxFLgWODWzG1JbdR3DxURS4CV/HYP7fCIOBp4qpSyBbgC\n+HRE/BJ4APgM8DDw7ZSKpRabzS7fO4Af0OxaFuDyzvKrgfNKKZ+PiAOArwDLgP8C/rSUUnON+33a\nssp2X65oU3tHeeWKUvr5yTWEAiLeDmxcjL+hjqxst72izXMVbXp/fGtqXb+hVpVSfjpd2339Oyyl\nMlBSIgMlJTJQUiIDJSUyUFIiAyUlcu7pPFXzl9JxqHz2UFIiAyUlMlBSIgMlJTJQUiIDJSUyUFIi\nAyUlMlBSImdKtMim8smqdmfEZTO2uWHQYjQr9lBSIgMlJTJQUiIDJSUyUFIiAyUlMlBSIgMlJXJg\nt0WWVgzYgqe3t5k9lJTIQEmJDJSUyEBJiQyUlMhASYkMlJTIQEmJHNhtkTdWtnuwos2CijYO/ubr\nu4eKiNUR8Z2IeCQiJiLi1J7Xr+ws735cn1ey1F6z2eVbAmwC1gB7u4X8DcByYEXnceasqpPGTN+7\nfKWU9cB6gIiIvTR7sZSyfZDCpHE0rIMSJ0TEtoi4NyLWRcSrh7QdqVWGcVDiBuA6YDNwBPB3wPUR\ncVwpZW+7iNK8kB6oUsq1XU/vjog7gfuBE4Af7O19u4De/ccFwyhQmsZuJh/97KcXGPo4VCllM/AE\nsHK6douA/XsehklzbSGTv4eL+nj/0AMVEQcDBwKPDXtb0qj13QlExBKa3mbPHtrhEXE08FTncQnN\nb6itnXafA+4DNmQULLXZbPaq3kHzW6h0Hpd3ll9NMzb1VuBsYBnwKE2Q/rKU8tLA1Y6x/Sra3FmO\nqVrX++KOGdvcXrUmZZvNONTNTL+r+CezL0cab06OlRIZKCmRgZISGSgpkYGSEhkoKZGBkhI5XW6O\n1JySfnLFgC3AlsFK0RDZQ0mJDJSUyEBJiQyUlMhASYkMlJTIQEmJDJSUyEBJiZwpMUd2VrR5uHJd\n2yraeLOA0bCHkhIZKCmRgZISGSgpkYGSEhkoKZGBkhIZKCmRA7stUvvXreZ2kN6PdTTsoaREBkpK\nZKCkRAZKSmSgpEQGSkpkoKREBkpKZKCkRH3NlIiITwGnAUcBO4BbgYtLKfd1tdkfWAucAexPcxf4\nNaWUx7OKHkfLKtrcWT5eta4j44uDFaOh6beHWg18ETgWOAnYD/huRLyyq80VwPuBDwLHA68Hrhu8\nVKn9+uqhSimndD+PiHOAx4FVwC0RsRQ4D/hQKeXmTptzgZ9HxDGllLr7tUhjatDfUMuAAjzVeb6K\nJqQ37WlQSvkF8BBw3IDbklpv1oGKiKDZvbullHJPZ/EKYFcp5dme5ts6r0nz2iCnb6wD3gK8O6OQ\nXUD0LFuA55dobu1m8vUKSx/vn9X3NSK+BJwCrC6lPNr10lZgUUQs7emllnde26tFeAxfo7eQyaGY\noO5CpTCL73AnTB8A3ltKeajn5Y00IT+xq/2bgTcAP+l3W9K46Xccah1wJnAq8HxELO+89EwpZWcp\n5dmI+CqwNiJ+DTwHfAH4sUf4tC/od5fvfJpdyh/2LD8XuKbz3xfS7IZ+k2Zgdz1wwexLnB+ermjz\nrsoB25cq2nht89Hodxxqxl3EUsqLwMc7D2mf4nEAKZGBkhIZKCmRgZISGSgpkYGSEhkoKZFzT1uk\ndr5YzYCsg7ajYQ8lJTJQUiIDJSUyUFIiAyUlMlBSIgMlJTJQUiIDJSVypsSAFieu68nKdhOJ21Qu\neygpkYGSEhkoKZGBkhIZKCmRgZISGSgpkYGSEjmwO6Da09ZrrjW+pdTdoGRFzHwzyP0q1lNzjXT1\nxx5KSmSgpEQGSkpkoKREBkpKZKCkRAZKSmSgpEQGSkrU10yJiPgUcBpwFLADuBW4uJRyX1ebHwLH\nd72tAF8p
pawZuNoWqpkBAXUX7z+qYgZE7TY9TX40+u2hVgNfBI4FTqKZ4fLdiHhlV5sC/AOwHFgB\nvA64aPBSpfbrq4cqpZzS/TwizgEeB1YBt3S99EIpZfvA1UljZtDfUMtoeqSnepZ/OCK2R8SdEfHZ\nnh5MmrdmPds8IgK4ArillHJP10v/DDwIPAq8Ffg88CbgzweoUxoLg5y+sQ54C/BH3QtLKf/Y9fTu\niNgK3BgRh5VSNu9tZbuA6Fm2YMACpX7tZvIBpNLH+2f1fY2ILwGnAKtLKY/N0Px2mqysBPYaqEV4\nDF+jt5DJoZig/ry3vgPVCdMHgPeUUh6qeMvbaEI+U/CksdfvONQ64EzgVOD5iFjeeemZUsrOiDgc\nOAu4nubKwkcDa4GbSyl35ZUttVO/PdT5NL3ND3uWnwtcQ/NT6CTgE8ASYAvwr8DfDlRli9Wcag51\n10CvvU56ze5Hze6zd4rP1+841LSfUynlYeCEQQqSxpnHAaREBkpKZKCkRAZKSmSgpEQGSkpkoKRE\nzj0dUO0cryUVbTaVummYr43eacSTed3y0bCHkhIZKCmRgZISGSgpkYGSEhkoKZGBkhIZKCmRgZIS\nOVNiGjXXEK89bb3mdPMjK2ZA1Kqp3VPg89lDSYkMlJTIQEmJDJSUyEBJiQyUlMhASYkMlJTIgd1p\n1Ax87qpc19OnVTSqvCXdsg/P3KbmmusO7Oazh5ISGSgpkYGSEhkoKZGBkhIZKCmRgZISGSgpkYGS\nEvU1UyIizgc+ChzaWXQ3cGkpZX3n9f2BtcAZwP7ABmBNKeXxrILnUs1p5Isq13Xgt2Zuc0hFm1rO\nghiNfnuoLcDFwNuBVcD3gW9HxO93Xr8CeD/wQeB44PXAdTmlSu3XVw9VSvnPnkWfjoiPAu+KiEeA\n84APlVJuBoiIc4GfR8QxpZQ7UiqWWmzWv6Ei4hUR8SHgAOAnND3WQuCmPW1KKb8AHgKOG7BOaSz0\nPds8Iv6AJkCLgeeA00op90bE24BdpZRne96yDVgxcKXSGJjN6Rv3AkcDv0tzwsE1EXH8oIXsAnqv\nSrcAzy/R3NrN5AM6dfeVbPT9fS2l7AZ+1Xn6s4g4BvgEcC2wKCKW9vRSy4GtM613ER7D1+gtZHIo\nJqi/9WvGd/gVNIfIN9IE/MQ9L0TEm4E30OwiSvNev+NQnwVuoDnQ8DvAh4H3AH9cSnk2Ir4KrI2I\nX9P8vvoC8GOP8Glf0e8u32uBq4HXAc8A/0MTpu93Xr+QZhf0mzS91nrggpxS515N9/34J+vWdcRl\nM7fZdGPdulacNHObmmuub6/bnPrQ7zjUX8zw+ovAxzsPaZ/jcQApkYGSEhkoKZGBkhIZKCmRgZIS\nGSgpkYGSEkUp/cylHUIBEW8HNi6mfemuOW+/9i7wCw6uaPRc5cqWVbT52cxNlry6cnv7uK7JsatK\nKT+drm3bvsPSWDNQUqLWBmr3qAsYwL+MuoABjPO/extqb22gxvkyWOMcqHH+d29D7a0NlDSODJSU\nyEBJidpwUaHF0Bzr71amWDbXNlW0mepSzM8AvYMVC2rubl37I+DFijb/PXOTqf592/DvPlvDqr1r\nnTMOO7YhUIfC1HdTr73SzLC8e4D3HtO7IPPq7v9X0ea9s1/9qP/dBzHk2g8Fbp2uQRtmShwInAw8\nwHh/lpq/FtOEaUMp5cnpGo48UNJ84kEJKZGBkhIZKCmRgZISGSgpUesCFREXRMTmiNgREbdFxDtH\nXVONiLgkIiZ6HveMuq6pRMTqiPhORDzSqfPUKdpcGhGPRsQLEfG9iFg5ilp7zVR7RFw5xedw/VzV\n16pARcQZwOXAJcDbaMb7N0TEQSMtrN5dNLfvWdF5DDI2PExLaCaCrGGK2x9FxMXAx4CP0IxRP0/z\nOdTeo3uYpq294wZe/jmcOTelAaWU1jyA24C/73oewMPARaOuraL2S4Cfjr
qOWdQ9AZzas+xR4MKu\n50uBHcDpo663ovYrgX8bVU2t6aEiYj+a+/R236O3ADcyPvfoPbKzK3J/RHwtIg4ZdUH9iojDaP6q\nd38OzwK3Mz6fwwkRsS0i7o2IdRExZ1fPaE2ggINo7gK6rWf5uNyj9zbgHJppVOcDhwE/iogloyxq\nFlbQ7EqN6+dwA3A28D7gIpr7l10fEb13nB2KNkyOnRdKKRu6nt4VEXcADwKn0+yGaA6UUq7tenp3\nRNwJ3A+cAPxg2NtvUw/1BM0JDMt7llfdo7dtSinPAPcBrTg61oetNL9d58vnsJnmuzUnn0NrAlVK\neYnmPr3d9+iNzvNpp8y3UUS8CjgCeGzUtfSj8wXcyss/h6XAsYzn53AwcCBz9Dm0bZdvLXBVRGwE\n7qC5xegBwFWjLKpGRFwG/DvNbt7vAX9NcyGeb4yyrql0ftetpOmJAA6PiKOBp0opW4ArgE9HxC9p\nTqv5DM3R1m+PoNyXma72zuMS4DqaPworgc/R7ClsmLy2IRj1oc8pDoWuofkQd9DcPf4do66psu5v\n0HzpdtDc1PvrwGGjrmsvtb6H5pDzb3oe/9TV5q9oDp+/0Pkyrhx13TPVTnPe0nqaMO0EfgV8GXjN\nXNXn+VBSotb8hpLmAwMlJTJQUiIDJSUyUFIiAyUlMlBSIgMlJTJQUiIDJSUyUFKi/wfCpw4iNFrn\nIgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f57caa70320>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"print('Policy exploration heat map:')\n",
"plt.imshow(np.flipud(explore_map.T), cmap='hot', interpolation='nearest')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Policy Exploration\n",
"\n",
"With epsilon = 0.1, this learning algorithm spends most of its time on-policy."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment