@hanskyy
Created April 12, 2022 06:58
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### RL Problem \n",
"In this problem, you will use PyTorch to implement DQN in the OpenAI Gym environment 'Pendulum-v0'."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
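"Make sure the `gym` package is installed. Recent gym releases register this environment as `Pendulum-v1` and changed the `reset`/`step` API, so the code below assumes an older gym release that still provides `Pendulum-v0`."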
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install gym"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Q1 Define an RL_brain for DQN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
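"import torch\n",
"from torch import nn\n",
"import numpy as np\n",
"\n",
"class DQN:\n",
"    def __init__(\n",
"        self,\n",
"        n_actions,                # number of discrete actions\n",
"        n_features,               # size of the observation vector\n",
"        learning_rate=0.005,\n",
"        reward_decay=0.9,         # discount factor gamma\n",
"        e_greedy=0.9,             # epsilon-greedy exploration parameter\n",
"        replace_target_iter=200,  # learning steps between target-network updates\n",
"        memory_size=3000,         # capacity of the replay memory\n",
"        batch_size=32,\n",
"    ):\n",
"        self.n_actions = n_actions\n",
"        self.n_features = n_features\n",
"        self.lr = learning_rate\n",
"        self.gamma = reward_decay\n",
"        self.epsilon = e_greedy\n",
"        self.replace_target_iter = replace_target_iter\n",
"        self.memory_size = memory_size\n",
"        self.batch_size = batch_size\n",
"\n",
"        #-- complete the code here: initialize the replay memory and define your networks\n",
"\n",
"    def forward(self, observation):\n",
"        #-- complete the code here: return the Q-value of every action for this observation\n",
"        raise NotImplementedError\n",
"\n",
"    def store_transition(self, s, a, r, s_):\n",
"        #-- complete the code here: save (s, a, r, s_) into the replay memory\n",
"        raise NotImplementedError\n",
"\n",
"    def choose_action(self, observation):\n",
"        #-- complete the code here: return a discrete action index (epsilon-greedy)\n",
"        raise NotImplementedError\n",
"\n",
"    def learn(self):\n",
"        #-- complete the code here: sample a minibatch from the memory and update the network\n",
"        raise NotImplementedError\n"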
]
},
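{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is a sketch of the standard DQN update that `learn` typically implements. The variant presented in class may differ, for example in its loss function. Sample a minibatch of $B$ stored transitions $(s, a, r, s')$ from the replay memory. Then move the evaluation network's estimate $Q(s, a; \\theta)$ toward a target computed with a separate target network with parameters $\\theta^-$:\n",
"\n",
"$$y = r + \\gamma \\max_{a'} Q(s', a'; \\theta^-), \\qquad L(\\theta) = \\frac{1}{B} \\sum_{i=1}^{B} \\big(y_i - Q(s_i, a_i; \\theta)\\big)^2$$\n",
"\n",
"Here $\\gamma$ is `reward_decay` and $B$ is `batch_size`. The parameters $\\theta^-$ are typically overwritten with a copy of $\\theta$ every `replace_target_iter` learning steps. `store_transition` keeps no `done` flag and the unwrapped Pendulum task never terminates, so the target above has no terminal-state case."
]
},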
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Q2 Set up the main loop"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"DQN is designed for discrete action spaces, but the Pendulum environment has a continuous action space (a single torque value). Find a way to bridge this gap. Hint: discretize the action space!"
]
},
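{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to discretize is sketched below. It assumes $N \\ge 2$ evenly spaced torque levels, with $N$ being your choice of `ACTION_SPACE`. The network outputs one Q-value per index $k \\in \\{0, \\dots, N-1\\}$. The index returned by `choose_action` is then mapped to a torque between the bounds $a_{\\min}$ and $a_{\\max}$ given by `env.action_space.low` and `env.action_space.high`:\n",
"\n",
"$$a_{\\text{torque}} = a_{\\min} + k \\, \\frac{a_{\\max} - a_{\\min}}{N - 1}$$\n",
"\n",
"For example, with illustrative bounds $[-1, 1]$ and $N = 5$, the indices $0, 1, 2, 3, 4$ map to torques $-1, -0.5, 0, 0.5, 1$. Only the torque is passed to `env.step`. The replay memory should keep the integer index $k$, since that is what indexes the network's outputs."
]
},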
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"\n",
"env = gym.make('Pendulum-v0')\n",
"env = env.unwrapped  # strip wrappers such as the episode time limit\n",
"env.seed(1)\n",
"MEMORY_SIZE = 3000\n",
"# define ACTION_SPACE (number of discrete actions) and N_FEATURES (size of an observation)\n",
"\n",
"#-- complete the code here\n",
"\n",
"# Hint: check the Pendulum-v0 source in the OpenAI Gym GitHub repository, or inspect env.observation_space and env.action_space\n",
"dqn = DQN(n_actions=ACTION_SPACE, n_features=N_FEATURES, memory_size=MEMORY_SIZE)\n",
"\n",
"\n",
"def train(RL):\n",
"    total_steps = 0\n",
"    observation = env.reset()\n",
"    while True:\n",
"        # render only once the agent has been learning for a while\n",
"        if total_steps - MEMORY_SIZE > 8000:\n",
"            env.render()\n",
"\n",
"        action = RL.choose_action(observation)  # discrete action index\n",
"\n",
"        #-- complete the code here\n",
"        # hint: map the action index to a torque in the Pendulum action range and assign it to `torque`;\n",
"        # keep `action` itself as the index, because store_transition saves it for learning\n",
"\n",
"        observation_, reward, done, info = env.step(np.array([torque]))\n",
"\n",
"        RL.store_transition(observation, action, reward, observation_)\n",
"\n",
"        if total_steps > MEMORY_SIZE:  # start learning once the replay memory has filled up\n",
"            RL.learn()\n",
"\n",
"        if total_steps - MEMORY_SIZE > 20000:  # stop after roughly 20000 learning steps\n",
"            break\n",
"\n",
"        observation = observation_\n",
"        total_steps += 1\n",
"\n",
"train(dqn)\n",
"env.close()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Q3 Plot the reward over the timesteps"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#-- complete the code here"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For further reference, see the DQN code presented in class [here](https://github.com/hanskyy/RL)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}