Reinforcement Learning - FrozenLake Problem.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyMnN+bqS6OAnOvxCTxfynU8",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/jaroslawjanas/177ec5c648329f0c287351cd79587c7b/reinforcement-learning-frozenlake-problem.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# The FrozenLake Problem\n",
"\n",
"The *“deterministic”* FrozenLake is a toy problem from the so-called *“grid world”* category of problems. In this problem, the agent lives in a square grid and can move in four directions: “up”, “down”, “left” and “right”. The agent always starts in the top-left cell, and its goal is to reach the bottom-right cell of the grid (see image below)."
],
"metadata": {
"id": "rwNu_2ajQI0h"
}
},
{
"cell_type": "markdown",
"source": [
"*[Figure: the FrozenLake grid with legend, `frozenlake_legended.webp`]*"
],
"metadata": {
"id": "QHwvpd6qOkAS"
}
},
{
"cell_type": "markdown",
"source": [
"Just like in the basic *Gridworld*, actions are deterministic, i.e. a move to the “right” always moves the agent to the cell directly to its right. Unlike the basic Gridworld, however, the FrozenLake has holes in the ice, and if the agent falls into one, it drowns. Any action that would move the agent off the grid leaves its state unchanged."
],
"metadata": {
"id": "0rAX4vzvP8kO"
}
},
{
"cell_type": "markdown",
"source": [
"# Part 1: Create the FrozenLake\n",
"\n",
"1. Using Python, create a 5x5 FrozenLake grid with the start state in the top-left corner and the goal state in the bottom-right corner.\n",
"2. Place four holes in the FrozenLake at the following grid positions, given as *(row, column)*: **(1,0)**, **(1,3)**, **(3,1)**, **(4,2)**.\n",
"3. The reward for reaching the goal state is **+10.0**, the reward for falling into a hole is **-5.0** (because you die!), and the reward for each transition to a non-terminal state is **-1.0**.\n",
"4. The episode ends when the agent falls into a hole or reaches the goal state.\n",
"5. The actions are **“up”**, **“down”**, **“left”** and **“right”**.\n"
],
"metadata": {
"id": "tSjWuCQOQEmP"
}
},
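{
"cell_type": "markdown",
"source": [
"To check the layout against the specification above, the grid can be rendered as text. This is a sketch with hypothetical names (`HOLES`, `START`, `GOAL`, `render_lake`) that are independent of the `State` class below. `S` marks the start, `G` the goal, `H` a hole and `.` frozen ice."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Hypothetical sketch: print the Part 1 layout as a character grid\n",
"# S = start, G = goal, H = hole, . = frozen ice; positions are (row, col)\n",
"HOLES = {(1, 0), (1, 3), (3, 1), (4, 2)}\n",
"START, GOAL = (0, 0), (4, 4)\n",
"\n",
"\n",
"def render_lake(rows=5, cols=5):\n",
"    lines = []\n",
"    for r in range(rows):\n",
"        cells = []\n",
"        for c in range(cols):\n",
"            if (r, c) == START:\n",
"                cells.append(\"S\")\n",
"            elif (r, c) == GOAL:\n",
"                cells.append(\"G\")\n",
"            elif (r, c) in HOLES:\n",
"                cells.append(\"H\")\n",
"            else:\n",
"                cells.append(\".\")\n",
"        lines.append(\" \".join(cells))\n",
"    return \"\\n\".join(lines)\n",
"\n",
"\n",
"print(render_lake())\n",
"# S . . . .\n",
"# H . . H .\n",
"# . . . . .\n",
"# . H . . .\n",
"# . . H . G"
],
"metadata": {},
"execution_count": null,
"outputs": []
},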
{
"cell_type": "code",
"source": [
"# Imports\n",
"import numpy as np\n",
"import random\n",
"from itertools import groupby\n",
"import matplotlib.pyplot as plt\n",
"from scipy.interpolate import make_interp_spline"
],
"metadata": {
"id": "EqkSWCDJXG71"
},
"execution_count": 20,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Board Size\n",
"BOARD_ROWS = 5\n",
"BOARD_COLS = 5"
],
"metadata": {
"id": "ljT5gaMUh8ca"
},
"execution_count": 21,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
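"## Why is this solution good?\n",
"\n",
"* The `actions` and `q_actions` of each state are instance attributes, so they could easily be passed in as constructor arguments of *State*. This would make it possible to define a different set of actions for each *State*.\n",
"\n",
"* It is easier to work with: for any state-based application, **OOP** is a natural fit, since each *State* keeps its own position, reward and Q-values together."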
],
"metadata": {
"id": "z0KwI_JsvX3h"
}
},
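{
"cell_type": "markdown",
"source": [
"As an illustration of the first point, a hypothetical variant of *State* could take its action set as a constructor argument, so that, for example, a corner cell never offers a move off the grid. The class `RestrictedState` is a sketch and is not used by the solution in this notebook."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Hypothetical sketch: a State-like class whose action set is a constructor argument\n",
"class RestrictedState:\n",
"    def __init__(self, pos, actions=(\"up\", \"down\", \"left\", \"right\")):\n",
"        self.pos = pos\n",
"        self.actions = list(actions)\n",
"        # One Q-value per allowed action, initialised to 0\n",
"        self.q_actions = {action: 0.0 for action in self.actions}\n",
"\n",
"\n",
"# The top-left corner only offers moves that stay on the grid\n",
"corner = RestrictedState((0, 0), actions=[\"down\", \"right\"])\n",
"print(corner.q_actions)  # {'down': 0.0, 'right': 0.0}"
],
"metadata": {},
"execution_count": null,
"outputs": []
},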
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"id": "SERbkZfqRnlM"
},
"outputs": [],
"source": [
"# State\n",
"class State:\n",
"    # Registry of every State created so far, keyed by (row, col)\n",
"    states = {}\n",
"    # Q-learning hyperparameters: learning rate, discount, exploration rate\n",
"    alpha = 0.5\n",
"    gamma = 0.9\n",
"    epsilon = 0.1\n",
"\n",
"\n",
"    def __init__(self, pos, reward=-1, is_end=False):\n",
"        self.pos = pos  # (row, col)\n",
"        self.reward = reward  # reward received on entering this state\n",
"        self.is_end = is_end  # True for holes and the goal\n",
"        self.actions = [\"up\", \"down\", \"left\", \"right\"]\n",
"        # Q(s, a) estimates, initialised to 0\n",
"        self.q_actions = {\n",
"            \"up\": 0,\n",
"            \"down\": 0,\n",
"            \"left\": 0,\n",
"            \"right\": 0\n",
"        }\n",
"\n",
"        self.states[self.pos] = self\n",
"\n",
"\n",
"    def move(self, action):\n",
"        if self.is_end:\n",
"            return None\n",
"\n",
"        if action == \"up\":\n",
"            next_pos = (self.pos[0] - 1, self.pos[1])\n",
"        elif action == \"down\":\n",
"            next_pos = (self.pos[0] + 1, self.pos[1])\n",
"        elif action == \"left\":\n",
"            next_pos = (self.pos[0], self.pos[1] - 1)\n",
"        else:\n",
"            next_pos = (self.pos[0], self.pos[1] + 1)\n",
"\n",
"        # pos[0] is the row index, pos[1] the column index\n",
"        if 0 <= next_pos[0] < BOARD_ROWS and 0 <= next_pos[1] < BOARD_COLS:\n",
"            # Reuse an existing State (e.g. a hole or the goal) if there is one\n",
"            if next_pos in self.states:\n",
"                return self.states[next_pos]\n",
"            return State(next_pos)\n",
"\n",
"        # Moving off the grid leaves the agent where it is\n",
"        return self\n",
"\n",
"\n",
"    # https://stackoverflow.com/questions/3844801/\n",
"    # check-if-all-elements-in-a-list-are-identical\n",
"    @staticmethod\n",
"    def all_equal(iterable):\n",
"        g = groupby(iterable)\n",
"        return next(g, True) and not next(g, False)\n",
"\n",
"\n",
"    def best_action(self):\n",
"        # Terminal states have no actions; their value is 0\n",
"        if self.is_end:\n",
"            return (None, 0)\n",
"\n",
"        best_value = max(self.q_actions.values())\n",
"\n",
"        # Compare the Q-values (not the dict keys): if they are all equal,\n",
"        # there is no preference yet, so break the tie at random\n",
"        if State.all_equal(self.q_actions.values()):\n",
"            return random.choice(self.actions), best_value\n",
"\n",
"        best_action = max(self.q_actions, key=self.q_actions.get)\n",
"\n",
"        return best_action, best_value\n",
"\n",
"\n",
"    def choose_action(self):\n",
"        if self.is_end:\n",
"            return None\n",
"\n",
"        best_action, best_value = self.best_action()\n",
"\n",
"        # Explore with probability epsilon by picking one of the other actions\n",
"        if random.uniform(0, 1) < self.epsilon:\n",
"            other_actions = \\\n",
"                [action for action in self.actions if action != best_action]\n",
"            return random.choice(other_actions)\n",
"\n",
"        return best_action\n",
"\n",
"\n",
"    def update_q_action(self, action, next_state):\n",
"        # Q(s,a) += alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))\n",
"        self.q_actions[action] += \\\n",
"            self.alpha * (next_state.reward + \\\n",
"            (self.gamma * next_state.best_action()[1]) - \\\n",
"            self.q_actions[action])"
]
},
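{
"cell_type": "markdown",
"source": [
"The update in `update_q_action` is the tabular Q-learning rule. The reward $r$ is the reward of the state being entered, and terminal states contribute a value of 0 because `best_action` returns `(None, 0)` for them:\n",
"\n",
"$$Q(s,a) \\leftarrow Q(s,a) + \\alpha \\left[ r + \\gamma \\max_{a'} Q(s',a') - Q(s,a) \\right]$$\n",
"\n",
"For example, with all Q-values initialised to 0, $\\alpha = 0.5$ and $\\gamma = 0.9$, the first update after a step into a non-terminal state ($r = -1$) gives $Q(s,a) = 0 + 0.5\\,(-1 + 0.9 \\cdot 0 - 0) = -0.5$. At convergence, the action that enters the goal has value $10 + 0.9 \\cdot 0 = 10$, and each cell one step further back along the optimal path has value $-1 + 0.9\\,V'$, where $V'$ is the value of the next cell, e.g. $-1 + 0.9 \\cdot 10 = 8.0$ and $-1 + 0.9 \\cdot 8 = 6.2$."
],
"metadata": {}
},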
{
"cell_type": "code",
"source": [
"# Board Layout\n",
"\n",
"WIN_STATE = State((4, 4), reward=10, is_end=True)\n",
"START_STATE = State((0, 0), reward=-1)\n",
"HOLES_STATES = [\n",
" State((1, 0), reward=-5, is_end=True),\n",
" State((1, 3), reward=-5, is_end=True),\n",
" State((3, 1), reward=-5, is_end=True),\n",
" State((4, 2), reward=-5, is_end=True)\n",
" ]"
],
"metadata": {
"id": "6buN6gMqh3RI"
},
"execution_count": 23,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Part 2: Implement the Q-learning algorithm\n",
"\n",
"1. Using the algorithmic steps outlined in the notes (Week 9 – Model Free Learning (Temporal Difference Learning)) or the S&B book, implement Q-learning on the FrozenLake problem to learn a policy that navigates the lake optimally.\n",
"2. Set the parameters Alpha = **0.5**, Gamma = **0.9** and Epsilon = **0.10**.\n",
"3. Run the FrozenLake experiment for **10000** episodes and output the action-value estimates at the end of the learning process.\n",
"4. Plot a curve of the reward per episode (similar to the Q-learning vs SARSA comparison for the cliff-walking task in the slides)."
],
"metadata": {
"id": "PY4gH1aKRdMl"
}
},
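{
"cell_type": "markdown",
"source": [
"For reference, a common form of ε-greedy selection explores by choosing uniformly among *all* actions with probability ε (so the greedy action can also be picked) and breaks ties among greedy actions at random. The `choose_action` method above instead samples only among the non-greedy actions when it explores; both schemes keep every action reachable. The function `epsilon_greedy` below is a hypothetical standalone sketch that works on a plain `{action: Q-value}` dict and is not used by the `Agent`."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import random\n",
"\n",
"\n",
"def epsilon_greedy(q_values, epsilon=0.1):\n",
"    \"\"\"Pick an action from an {action: Q-value} dict epsilon-greedily.\"\"\"\n",
"    actions = list(q_values)\n",
"    if random.random() < epsilon:\n",
"        # Explore: any action, including the current greedy one\n",
"        return random.choice(actions)\n",
"    # Exploit: a greedy action, with ties broken uniformly at random\n",
"    best = max(q_values.values())\n",
"    return random.choice([a for a in actions if q_values[a] == best])\n",
"\n",
"\n",
"q = {\"up\": -0.5, \"down\": 0.0, \"left\": -0.5, \"right\": 0.0}\n",
"print(epsilon_greedy(q, epsilon=0.0))  # prints down or right"
],
"metadata": {},
"execution_count": null,
"outputs": []
},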
{
"cell_type": "code",
"source": [
"# Agent\n",
"\n",
"class Agent:\n",
"\n",
"    def __init__(self):\n",
"        self.start_state = START_STATE\n",
"\n",
"\n",
"    def qlearn(self, episodes):\n",
"        # Per-episode log used for the reward curve\n",
"        data = {\n",
"            \"episodes\": [],\n",
"            \"rewards\": []\n",
"        }\n",
"\n",
"        for episode in range(episodes):\n",
"            reward_sum = 0\n",
"            data[\"episodes\"].append(episode + 1)\n",
"\n",
"            state = self.start_state\n",
"\n",
"            # One episode: act until a hole or the goal is reached\n",
"            while not state.is_end:\n",
"                action = state.choose_action()\n",
"                next_state = state.move(action)\n",
"                state.update_q_action(action, next_state)\n",
"                reward_sum += next_state.reward\n",
"                state = next_state\n",
"\n",
"            data[\"rewards\"].append(reward_sum)\n",
"\n",
"        return data\n",
"\n",
"\n",
"    def show_values(self):\n",
"        # Print max_a Q(s, a) for each cell; terminal cells show END and\n",
"        # cells never visited show 0\n",
"        for i in range(BOARD_ROWS):\n",
"            print('-----------------------------------------')\n",
"            out = '| '\n",
"            for j in range(BOARD_COLS):\n",
"                state = State.states.get((i, j), None)\n",
"\n",
"                best_value = 0\n",
"                if state:\n",
"                    if state.is_end:\n",
"                        out += \"END\".ljust(5) + ' | '\n",
"                        continue\n",
"\n",
"                    _, best_value = state.best_action()\n",
"\n",
"                out += str(round(best_value, 2)).ljust(5) + ' | '\n",
"            print(out)\n",
"        print('-----------------------------------------')"
],
"metadata": {
"id": "FZWY5jHPXMaC"
},
"execution_count": 24,
"outputs": []
},
{
"cell_type": "code",
"source": [
"agent = Agent()\n",
"data = agent.qlearn(100000)\n",
"agent.show_values()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bzpnTFYzBn5R",
"outputId": "3233b517-2fdb-4dd3-cf11-25cfb599a511"
},
"execution_count": 25,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"-----------------------------------------\n",
"| -0.43 | 0.63 | 1.81 | 3.12 | 4.58 | \n",
"-----------------------------------------\n",
"| END | 1.81 | 3.12 | END | 6.2 | \n",
"-----------------------------------------\n",
"| 1.81 | 3.12 | 4.58 | 6.2 | 8.0 | \n",
"-----------------------------------------\n",
"| 0.63 | END | 6.2 | 8.0 | 10.0 | \n",
"-----------------------------------------\n",
"| -0.53 | -0.5 | END | 10.0 | END | \n",
"-----------------------------------------\n"
]
}
]
},
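{
"cell_type": "markdown",
"source": [
"Once training has finished, the greedy policy can be read off the learned Q-values. This hypothetical helper (`show_policy` and `ARROWS` are not part of the original solution) assumes the cells above have been run, since it reads `State.states`, `BOARD_ROWS` and `BOARD_COLS`. Ties are resolved by `max`, which keeps the first maximal action in dict order. `?` marks a cell that was never visited and `E` a terminal cell."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Hypothetical helper: print the greedy action of each cell as an arrow\n",
"ARROWS = {\"up\": \"^\", \"down\": \"v\", \"left\": \"<\", \"right\": \">\"}\n",
"\n",
"\n",
"def show_policy():\n",
"    for i in range(BOARD_ROWS):\n",
"        cells = []\n",
"        for j in range(BOARD_COLS):\n",
"            state = State.states.get((i, j))\n",
"            if state is None:\n",
"                cells.append(\"?\")  # never visited during training\n",
"            elif state.is_end:\n",
"                cells.append(\"E\")  # terminal: a hole or the goal\n",
"            else:\n",
"                q = state.q_actions\n",
"                cells.append(ARROWS[max(q, key=q.get)])\n",
"        print(\" \".join(cells))\n",
"\n",
"\n",
"show_policy()"
],
"metadata": {},
"execution_count": null,
"outputs": []
},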
{
"cell_type": "code",
"source": [
"plt.plot(data[\"episodes\"],data[\"rewards\"])"
],
"metadata": {
"id": "9N_q9WxFBptB",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 447
},
"outputId": "f951580c-b19a-49a1-8f99-683754eb56d0"
},
"execution_count": 26,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fe6f431c070>]"
]
},
"metadata": {},
"execution_count": 26
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# Smoothed\n",
"# https://www.geeksforgeeks.org/how-to-plot-a-smooth-curve-in-matplotlib/\n",
"spline = make_interp_spline(data[\"episodes\"],data[\"rewards\"])\n",
"\n",
"x = np.linspace(min(data[\"episodes\"]), max(data[\"episodes\"]), 50)\n",
"y = spline(x)\n",
"\n",
"plt.plot(x, y)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 447
},
"id": "J53mi64-z_wT",
"outputId": "de816835-115f-404d-b310-1ce4a924f01b"
},
"execution_count": 27,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fe6f4474430>]"
]
},
"metadata": {},
"execution_count": 27
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {}
}
]
},
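{
"cell_type": "markdown",
"source": [
"`make_interp_spline` builds an interpolating spline that passes through every data point, so evaluating it at 50 x-values samples the noisy reward curve rather than averaging it. A moving average is one common alternative. This sketch assumes `data` from the training cell above and uses an arbitrary window of 100 episodes."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Moving average of the per-episode reward; the window size is an arbitrary choice\n",
"window = 100\n",
"kernel = np.ones(window) / window\n",
"smoothed = np.convolve(data[\"rewards\"], kernel, mode=\"valid\")\n",
"\n",
"# mode=\"valid\" drops the first window - 1 episodes, so trim the x-axis to match\n",
"plt.plot(data[\"episodes\"][window - 1:], smoothed)\n",
"plt.xlabel(\"Episode\")\n",
"plt.ylabel(f\"Reward ({window}-episode moving average)\")\n",
"plt.show()"
],
"metadata": {},
"execution_count": null,
"outputs": []
},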
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "aZd9XTsi1cia"
},
"execution_count": 27,
"outputs": []
}
]
}