Skip to content

Instantly share code, notes, and snippets.

@tonghuikang
Last active April 19, 2020 19:26
Show Gist options
  • Save tonghuikang/ba457d6b9aed79c1935b6ae98e552d66 to your computer and use it in GitHub Desktop.
Save tonghuikang/ba457d6b9aed79c1935b6ae98e552d66 to your computer and use it in GitHub Desktop.
SML HW4 template
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%reset -sf"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Q3\n",
"[student-id]-gridworld.py"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import sys\n",
"from gym.envs.toy_text import discrete\n",
"\n",
"UP = 0\n",
"RIGHT = 1\n",
"DOWN = 2\n",
"LEFT = 3\n",
"\n",
"GOAL = 4 # upper-rightcorner\n",
"START = 20 # lower-leftcorner\n",
"SNAKE1 = 7\n",
"SNAKE2 = 17\n",
"\n",
"eps = 0.25\n",
"\n",
"\n",
"class Robot_vs_snakes_world(discrete.DiscreteEnv):\n",
" def __init__(self):\n",
" self.shape = [5, 5]\n",
" \n",
" # total number of states\n",
" nS = np.prod(self.shape)\n",
" \n",
" # total number of actions per state\n",
" nA = 4\n",
"\n",
" MAXY = self.shape[0]\n",
" MAXX = self.shape[1]\n",
"\n",
" P = {}\n",
" grid = np.arange(nS).reshape(self.shape)\n",
" it = np.nditer(grid, flags=[\"multi_index\"])\n",
"\n",
" while not it.finished:\n",
" s = it.iterindex\n",
" y, x = it.multi_index\n",
"\n",
" P[s] = {a:[] for a in range(nA)}\n",
"\n",
" is_done = lambda s: s == GOAL\n",
"\n",
" if is_done(s):\n",
" reward = 0.0\n",
" elif s == SNAKE1 or s == SNAKE2:\n",
" reward = -15.0\n",
" else:\n",
" reward = -1.0\n",
"\n",
" if is_done(s):\n",
" P[s][UP] = [(1.0, s, reward, True)]\n",
" P[s][RIGHT] = [(1.0, s, reward, True)]\n",
" P[s][DOWN] = [(1.0, s, reward, True)]\n",
" P[s][LEFT] = [(1.0, s, reward, True)]\n",
"\n",
" else:\n",
" ns_up = s if y == 0 else s - MAXX\n",
" ns_right = s if x == (MAXX - 1) else s + 1\n",
" ns_down = s if y == (MAXY - 1) else s + MAXX\n",
" ns_left = s if x == 0 else s - 1\n",
" P[s][UP] = [\n",
" (1 - (2 * eps), ns_up, reward, is_done(ns_up)),\n",
" (eps, ns_right, reward, is_done(ns_right)),\n",
" (eps, ns_left, reward, is_done(ns_left)),\n",
" ]\n",
" P[s][RIGHT] = [\n",
" (1 - (2 * eps), ns_right, reward, is_done(ns_right)),\n",
" (eps, ns_up, reward, is_done(ns_up)),\n",
" (eps, ns_down, reward, is_done(ns_down)),\n",
" ]\n",
" P[s][DOWN] = [\n",
" (1 - (2 * eps), ns_down, reward, is_done(ns_down)),\n",
" (eps, ns_right, reward, is_done(ns_right)),\n",
" (eps, ns_left, reward, is_done(ns_left)),\n",
" ]\n",
" P[s][LEFT] = [\n",
" (1 - (2 * eps), ns_left, reward, is_done(ns_left)),\n",
" (eps, ns_up, reward, is_done(ns_up)),\n",
" (eps, ns_down, reward, is_done(ns_down)),\n",
" ]\n",
" it.iternext()\n",
"\n",
" isd = np.zeros(nS)\n",
" isd[START] = 1.0\n",
" self.P = P\n",
"\n",
" super(Robot_vs_snakes_world, self).__init__(nS, nA, P, isd)\n",
"\n",
" def _render(self):\n",
" grid = np.arange(self.nS).reshape(self.shape)\n",
" it = np.nditer(grid, flags=[\"multi_index\"])\n",
"\n",
" while not it.finished:\n",
" s = it.iterindex\n",
" y, x = it.multi_index\n",
"\n",
" if self.s == s:\n",
" output = \"R\"\n",
" elif s == GOAL:\n",
" output = \"G\"\n",
" elif s == SNAKE1 or s == SNAKE2:\n",
" output = \"S\"\n",
"\n",
" else:\n",
" output = \"o\"\n",
" if x == 0:\n",
" output = output.lstrip()\n",
" if x == self.shape[1] - 1:\n",
" output = output.rstrip()\n",
"\n",
" sys.stdout.write(output)\n",
"\n",
" if x == self.shape[1] - 1:\n",
" sys.stdout.write(\"\\n\")\n",
"\n",
" it.iternext()\n",
"\n",
" sys.stdout.write(\"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ooooG\n",
"ooSoo\n",
"ooooo\n",
"ooSoo\n",
"Roooo\n",
"\n"
]
}
],
"source": [
"env = Robot_vs_snakes_world()\n",
"env._render()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# env.s\n",
"# env.step(DIR)\n",
"# env.p[state][action]\n",
"\n",
"def value_iteration(env):\n",
" policy = np.zeros([env.nS, env.nA])\n",
" V = np.zeros(env.nS)\n",
" return policy, V"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[NbConvertApp] Converting notebook hw4-gridworld.ipynb to html\n",
"[NbConvertApp] Writing 296037 bytes to hw4-gridworld.html\n",
"[NbConvertApp] Converting notebook hw4-gridworld.ipynb to script\n",
"[NbConvertApp] Writing 4138 bytes to hw4-gridworld.py\n"
]
}
],
"source": [
"%%bash\n",
"export THIS_NB=\"hw4-gridworld\"\n",
"jupyter nbconvert --to html $THIS_NB.ipynb --output=$THIS_NB\n",
"jupyter nbconvert --to script $THIS_NB.ipynb --output=$THIS_NB\n",
"python -c 'import sys;print(\"\".join(sys.stdin.readlines()[8:-19])),' < $THIS_NB.py > temp.txt\n",
"mv temp.txt $THIS_NB.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment