@tomtung
Created August 31, 2021 04:42
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "52eb889f",
"metadata": {},
"outputs": [],
"source": [
"import gym"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7672fa88",
"metadata": {},
"outputs": [],
"source": [
"from enum import IntEnum\n",
"\n",
"class CellState(IntEnum):\n",
" \"\"\"State of a cell.\n",
" \n",
" A cell can be undug.\n",
" \n",
" If it's dug, it could be a rupee:\n",
" - Green Rupee (+1): 0 nearby bombs / rupoors\n",
" - Blue Rupee (+5): 1 or 2 nearby bombs / rupoors\n",
" - Red Rupee (+20): 3 or 4 nearby bombs / rupoors\n",
" - Silver Rupee (+100): 5 or 6 nearby bombs / rupoors\n",
" - Gold Rupee (+300): 7 or 8 nearby bombs / rupoors\n",
" \n",
" It could also be a rupoor, which reduces the total reward by 10\n",
" (without going into negative).\n",
" \n",
" Finally, it could also be a bomb, which doesn't reduce the reward\n",
" but terminates the episode immediately.\n",
"\n",
" \"\"\"\n",
" UNDUG = 0\n",
" GREEN = 1\n",
" BLUE = 2\n",
" RED = 3\n",
" SILVER = 4\n",
" GOLD = 5\n",
" RUPOOR = 6\n",
" BOMB = 7\n",
" \n",
" @classmethod\n",
" def from_adj_bad_count(cls, count):\n",
" count_to_val = {\n",
" 0: cls.GREEN,\n",
" 1: cls.BLUE,\n",
" 2: cls.BLUE,\n",
" 3: cls.RED,\n",
" 4: cls.RED,\n",
" 5: cls.SILVER,\n",
" 6: cls.SILVER,\n",
" 7: cls.GOLD,\n",
" 8: cls.GOLD,\n",
" }\n",
"\n",
" return count_to_val[count]\n",
" \n",
" def to_reward(self):\n",
" val_to_reward = {\n",
" self.GREEN: 1,\n",
" self.BLUE: 5,\n",
" self.RED: 20,\n",
" self.SILVER: 100,\n",
" self.GOLD: 300,\n",
" self.RUPOOR: -10,\n",
" self.BOMB: 0,\n",
" }\n",
" return val_to_reward[self]"
]
},
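{
"cell_type": "code",
"execution_count": null,
"id": "cellstate-demo",
"metadata": {},
"outputs": [],
"source": [
"# Editor's sketch (not part of the original gist): a quick sanity check of\n",
"# the mapping above, printing the rupee revealed for each possible count of\n",
"# adjacent bombs / rupoors together with its reward.\n",
"for count in range(9):\n",
"    state = CellState.from_adj_bad_count(count)\n",
"    print(count, state.name, state.to_reward())"
]
},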
{
"cell_type": "code",
"execution_count": 3,
"id": "cd82b1b0",
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"import random\n",
"import numpy as np\n",
"\n",
"\n",
"class ThrillDiggerEnv(gym.Env):\n",
" metadata = {'render.modes': ['human']}\n",
" reward_range = (-10.0, 500.0)\n",
" \n",
" N_ROWS, N_COLS = 5, 8\n",
" N_CELLS = N_ROWS * N_COLS\n",
"\n",
" action_space = gym.spaces.Discrete(N_CELLS)\n",
" observation_space = gym.spaces.MultiDiscrete([len(CellState)] * N_CELLS)\n",
" \n",
" inner_states = [CellState.GREEN] * N_CELLS\n",
" is_dug = [False] * N_CELLS\n",
" total_reward = 0\n",
" is_done = False\n",
" \n",
" @property\n",
" def observation(self):\n",
" return [\n",
" cell_state if cell_is_dug else CellState.UNDUG\n",
" for cell_is_dug, cell_state in zip(self.is_dug, self.inner_states)\n",
" ]\n",
" \n",
" def reset(self):\n",
" grid_state = [\n",
" [None] * self.N_COLS\n",
" for _ in range(self.N_ROWS)\n",
" ]\n",
"\n",
" def fill_bombs_and_rupoor():\n",
" positions = [\n",
" (r, l)\n",
" for r in range(self.N_ROWS)\n",
" for l in range(self.N_COLS)\n",
" ]\n",
" random.shuffle(positions)\n",
" \n",
" for i in range(8):\n",
" r, l = positions[i]\n",
" grid_state[r][l] = CellState.RUPOOR\n",
" \n",
" for i in range(8, 16):\n",
" r, l = positions[i]\n",
" grid_state[r][l] = CellState.BOMB\n",
" \n",
" def is_bad(r, l):\n",
" return 0 <= l < self.N_COLS and \\\n",
" 0 <= r < self.N_ROWS and \\\n",
" grid_state[r][l] in (CellState.RUPOOR, CellState.BOMB)\n",
" \n",
" def set_rupee(r, l):\n",
" if grid_state[r][l] is not None:\n",
" return\n",
" \n",
" bad_count = sum([\n",
" int(is_bad(r + dr, l + dl))\n",
" for dr in [-1, 0, 1]\n",
" for dl in [-1, 0, 1]\n",
" ])\n",
" grid_state[r][l] = CellState.from_adj_bad_count(bad_count)\n",
" \n",
" fill_bombs_and_rupoor()\n",
" for r in range(self.N_ROWS):\n",
" for l in range(self.N_COLS):\n",
" set_rupee(r, l)\n",
" \n",
" self.inner_states = [\n",
" item\n",
" for row in grid_state\n",
" for item in row\n",
" ]\n",
" self.is_dug = [False] * self.N_CELLS\n",
" self.total_reward = 0\n",
" self.is_done = False\n",
" return np.array(self.observation, dtype=np.int64)\n",
"\n",
" def render(self, mode='human'):\n",
" for i in range(self.N_CELLS):\n",
" if i > 0 and i % self.N_COLS == 0:\n",
" print(\"\")\n",
"\n",
" name = self.inner_states[i].name.title()[:3]\n",
" if not self.is_dug[i]:\n",
" name = f\"({name})\"\n",
" \n",
" print(name, end=\"\\t\")\n",
" \n",
" print(\"\")\n",
"\n",
" def step(self, action):\n",
" # NB: agent should make sure to not dig cells that are already dug\n",
" reward = 0\n",
" if not self.is_done and not self.is_dug[action]:\n",
" self.is_dug[action] = True\n",
" self.is_done = self.is_done or self.inner_states[action] == CellState.BOMB\n",
"\n",
" # NB: Make sure that the total reward is always non-negative \n",
" new_total_reward = max(0, self.total_reward + self.inner_states[action].to_reward())\n",
" reward = new_total_reward - self.total_reward\n",
" self.total_reward = new_total_reward\n",
"\n",
" return (\n",
" np.array(self.observation, dtype=np.int64),\n",
" reward,\n",
" self.is_done,\n",
" {\"cell_state\": self.inner_states[action]}\n",
" )"
]
},
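{
"cell_type": "code",
"execution_count": null,
"id": "random-rollout-demo",
"metadata": {},
"outputs": [],
"source": [
"# Editor's sketch (not part of the original gist): play one full episode with\n",
"# a uniformly random policy that only digs undug cells, as the NB comment in\n",
"# step() requires. Undug cells are those observed as CellState.UNDUG (== 0).\n",
"env = ThrillDiggerEnv()\n",
"obs = env.reset()\n",
"done = False\n",
"while not done:\n",
"    undug = [i for i, v in enumerate(obs) if v == CellState.UNDUG]\n",
"    action = random.choice(undug)\n",
"    obs, reward, done, info = env.step(action)\n",
"print(\"total reward:\", env.total_reward)"
]
},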
{
"cell_type": "code",
"execution_count": 6,
"id": "3a1a13c8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.77 ms ± 60.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"from stable_baselines3.common.env_checker import check_env\n",
"\n",
"env = ThrillDiggerEnv()\n",
"check_env(env)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "26bc3dae",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(Blu)\t(Blu)\t(Bom)\t(Blu)\t(Blu)\t(Blu)\t(Red)\t(Rup)\t\n",
"(Bom)\t(Sil)\t(Red)\t(Red)\t(Red)\t(Rup)\t(Red)\t(Bom)\t\n",
"(Rup)\t(Bom)\t(Rup)\t(Bom)\tBom\t(Rup)\t(Red)\t(Blu)\t\n",
"(Blu)\t(Red)\t(Red)\t(Red)\t(Red)\t(Red)\t(Sil)\t(Rup)\t\n",
"(Gre)\t(Blu)\t(Bom)\t(Blu)\t(Blu)\t(Bom)\t(Rup)\t(Rup)\t\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"env = ThrillDiggerEnv()\n",
"env.reset()\n",
"env.step(20)\n",
"env.render()\n",
"env.is_done"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f03e7628",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n",
" 0,\n",
" True,\n",
" {'cell_state': <CellState.RED: 3>})"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"env.step(10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1bc9513",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}