Created
June 13, 2024 07:22
-
-
Save radekosmulski/3dd1adf3da443a81aa6880a59db21687 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "802fbb4e", | |
"metadata": {}, | |
"source": [ | |
"# Getting our feet wet with Stable Baselines 3" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "6b6e3cff", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# https://github.com/araffin/rl-tutorial-jnrr19\n", | |
"\n", | |
"import gymnasium as gym\n", | |
"import imageio\n", | |
"import numpy as np\n", | |
"\n", | |
"from stable_baselines3 import PPO\n", | |
"from stable_baselines3.common.evaluation import evaluate_policy\n", | |
"from stable_baselines3.ppo.policies import MlpPolicy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "3ac98a6e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"env = gym.make(\"CartPole-v1\")\n", | |
"\n", | |
"model = PPO(MlpPolicy, env, verbose=0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "782a1ccf", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"ActorCriticPolicy(\n", | |
" (features_extractor): FlattenExtractor(\n", | |
" (flatten): Flatten(start_dim=1, end_dim=-1)\n", | |
" )\n", | |
" (pi_features_extractor): FlattenExtractor(\n", | |
" (flatten): Flatten(start_dim=1, end_dim=-1)\n", | |
" )\n", | |
" (vf_features_extractor): FlattenExtractor(\n", | |
" (flatten): Flatten(start_dim=1, end_dim=-1)\n", | |
" )\n", | |
" (mlp_extractor): MlpExtractor(\n", | |
" (policy_net): Sequential(\n", | |
" (0): Linear(in_features=4, out_features=64, bias=True)\n", | |
" (1): Tanh()\n", | |
" (2): Linear(in_features=64, out_features=64, bias=True)\n", | |
" (3): Tanh()\n", | |
" )\n", | |
" (value_net): Sequential(\n", | |
" (0): Linear(in_features=4, out_features=64, bias=True)\n", | |
" (1): Tanh()\n", | |
" (2): Linear(in_features=64, out_features=64, bias=True)\n", | |
" (3): Tanh()\n", | |
" )\n", | |
" )\n", | |
" (action_net): Linear(in_features=64, out_features=2, bias=True)\n", | |
" (value_net): Linear(in_features=64, out_features=1, bias=True)\n", | |
")" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.policy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "41b529e9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mean_reward: 9.81 +/- 0.52\n" | |
] | |
} | |
], | |
"source": [ | |
"# evaluating the model before training\n", | |
"from stable_baselines3.common.evaluation import evaluate_policy\n", | |
"\n", | |
"mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100, warn=False)\n", | |
"print(f\"mean_reward: {mean_reward:.2f} +/- {std_reward:.2f}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "49695eb6", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 11.2 s, sys: 162 ms, total: 11.4 s\n", | |
"Wall time: 11.4 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<stable_baselines3.ppo.ppo.PPO at 0x79850070dd20>" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"model.learn(total_timesteps=10_000)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "1976908d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mean_reward: 466.10 +/- 63.32\n", | |
"CPU times: user 19.8 s, sys: 2.89 ms, total: 19.8 s\n", | |
"Wall time: 19.7 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100, warn=False)\n", | |
"print(f\"mean_reward: {mean_reward:.2f} +/- {std_reward:.2f}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "e340c408", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mean_reward: 500.00 +/- 0.00\n" | |
] | |
} | |
], | |
"source": [ | |
"model.learn(total_timesteps=10_000)\n", | |
"\n", | |
"mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100, warn=False)\n", | |
"print(f\"mean_reward: {mean_reward:.2f} +/- {std_reward:.2f}\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "4a3de575", | |
"metadata": {}, | |
"source": [ | |
"With just 20k steps we maxed out the score (CartPole-v1 episodes are capped at 500 steps) :)\n", | |
"\n", | |
"That makes for a really boring recording, but this is how to make one." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "4d7a5b6e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def record_video(env_name='CartPole-v1', ext='mp4'):\n", | |
" env = gym.make(env_name, render_mode=\"rgb_array\")\n", | |
" obs, info = env.reset()\n", | |
"\n", | |
" images = [\n", | |
" env.render()\n", | |
" ]\n", | |
"\n", | |
" ep_len = 0\n", | |
" while True:\n", | |
" a, _states = model.predict(obs, deterministic=True)\n", | |
" obs, _, terminated, truncated, _ = env.step(a)\n", | |
" images.append(env.render())\n", | |
" ep_len += 1\n", | |
" if terminated or truncated:\n", | |
" obs, info = env.reset()\n", | |
" break\n", | |
"\n", | |
" print(f'Episode length: {ep_len}')\n", | |
" fn = f'{env_name}.{ext}'\n", | |
" imageio.mimsave(fn, images)\n", | |
" print(f'Video saved to: {fn}')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "6eca4114", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/radek/miniforge3/envs/cleanrl/lib/python3.10/site-packages/pkg_resources/__init__.py:121: DeprecationWarning: pkg_resources is deprecated as an API\n", | |
" warnings.warn(\"pkg_resources is deprecated as an API\", DeprecationWarning)\n", | |
"/home/radek/miniforge3/envs/cleanrl/lib/python3.10/site-packages/pkg_resources/__init__.py:2870: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`.\n", | |
"Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages\n", | |
" declare_namespace(pkg)\n", | |
"/home/radek/miniforge3/envs/cleanrl/lib/python3.10/site-packages/pkg_resources/__init__.py:2870: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('mpl_toolkits')`.\n", | |
"Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages\n", | |
" declare_namespace(pkg)\n", | |
"IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (600, 400) to (608, 400) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Episode length: 500\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"[swscaler @ 0x704ff00] Warning: data is not aligned! This can lead to a speed loss\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Video saved to: CartPole-v1.mp4\n" | |
] | |
} | |
], | |
"source": [ | |
"record_video()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "0a36255d", | |
"metadata": {}, | |
"source": [ | |
"# Training on a more complex problem" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "564e6b02", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Using cuda device\n", | |
"Creating environment from the given name 'Acrobot-v1'\n", | |
"Wrapping the env with a `Monitor` wrapper\n", | |
"Wrapping the env in a DummyVecEnv.\n" | |
] | |
} | |
], | |
"source": [ | |
"env_name = \"Acrobot-v1\"\n", | |
"model = PPO('MlpPolicy', env_name, verbose=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "c73b98b5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (500, 500) to (512, 512) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Episode length: 500\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"[swscaler @ 0x67d5d80] Warning: data is not aligned! This can lead to a speed loss\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Video saved to: Acrobot-v1.mp4\n" | |
] | |
} | |
], | |
"source": [ | |
"record_video(env_name)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "49285996", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(-500.0, 0.0)" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"evaluate_policy(model, gym.make(env_name), n_eval_episodes=100, warn=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "bd2d343a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080\"> 100%</span> <span style=\"color: #729c1f; text-decoration-color: #729c1f\">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span style=\"color: #008000; text-decoration-color: #008000\">20,478/20,000 </span> [ <span style=\"color: #808000; text-decoration-color: #808000\">0:00:24</span> < <span style=\"color: #008080; text-decoration-color: #008080\">0:00:00</span> , <span style=\"color: #800000; text-decoration-color: #800000\">817 it/s</span> ]\n", | |
"</pre>\n" | |
], | |
"text/plain": [ | |
"\u001b[35m 100%\u001b[0m \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20,478/20,000 \u001b[0m [ \u001b[33m0:00:24\u001b[0m < \u001b[36m0:00:00\u001b[0m , \u001b[31m817 it/s\u001b[0m ]\n" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |
], | |
"text/plain": [] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |
"</pre>\n" | |
], | |
"text/plain": [ | |
"\n" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 26 s, sys: 126 ms, total: 26.1 s\n", | |
"Wall time: 26 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<stable_baselines3.ppo.ppo.PPO at 0x7984d1913a30>" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"model.learn(20_000, progress_bar=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "6fa32157", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(-80.96, 12.161348609426506)" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"evaluate_policy(model, gym.make(env_name), n_eval_episodes=100, warn=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "41a70522", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# are all these wrappers changing the behavior of the env? Let's try with a function I wrote myself\n", | |
"\n", | |
"def evaluate(n_runs=100):\n", | |
" env = gym.make(env_name, render_mode=\"rgb_array\")\n", | |
" obs, info = env.reset()\n", | |
" ep_lens = []\n", | |
" rewardss = []\n", | |
" for num_runs in range(n_runs):\n", | |
" ep_len = 0\n", | |
" rewards = 0\n", | |
" while True:\n", | |
" a, _states = model.predict(obs, deterministic=True)\n", | |
" obs, reward, terminated, truncated, _ = env.step(a)\n", | |
" ep_len += 1\n", | |
" rewards += reward\n", | |
" if terminated or truncated:\n", | |
" obs, info = env.reset()\n", | |
" break\n", | |
" ep_lens.append(ep_len)\n", | |
" rewardss.append(rewards)\n", | |
" return print(f'Mean episode length: {np.mean(ep_lens)}\\tMean reward: {np.mean(rewardss)}')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6c5976e9", | |
"metadata": {}, | |
"source": [ | |
"https://gymnasium.farama.org/environments/classic_control/acrobot/#rewards\n", | |
"\n", | |
"> The goal is to have the free end reach a designated target height in as few steps as possible, and as such all steps that do not reach the goal incur a reward of -1. Achieving the target height results in termination with a reward of 0. The reward threshold is -100." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "2de63951", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Mean episode length: 82.91\tMean reward: -81.91\n" | |
] | |
} | |
], | |
"source": [ | |
"# seems there are no fundamental differences introduced by using these wrappers\n", | |
"\n", | |
"evaluate()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "0383c997", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Episode length: 87\n", | |
"Video saved to: Acrobot-v1.gif\n" | |
] | |
} | |
], | |
"source": [ | |
"record_video(env_name, 'gif')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "e9883919", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080\"> 100%</span> <span style=\"color: #729c1f; text-decoration-color: #729c1f\">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span style=\"color: #008000; text-decoration-color: #008000\">20,374/20,000 </span> [ <span style=\"color: #808000; text-decoration-color: #808000\">0:00:26</span> < <span style=\"color: #008080; text-decoration-color: #008080\">0:00:00</span> , <span style=\"color: #800000; text-decoration-color: #800000\">765 it/s</span> ]\n", | |
"</pre>\n" | |
], | |
"text/plain": [ | |
"\u001b[35m 100%\u001b[0m \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20,374/20,000 \u001b[0m [ \u001b[33m0:00:26\u001b[0m < \u001b[36m0:00:00\u001b[0m , \u001b[31m765 it/s\u001b[0m ]\n" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |
], | |
"text/plain": [] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |
"</pre>\n" | |
], | |
"text/plain": [ | |
"\n" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 27.9 s, sys: 218 ms, total: 28.1 s\n", | |
"Wall time: 27.9 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<stable_baselines3.ppo.ppo.PPO at 0x7984d1913a30>" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"model.learn(20_000, progress_bar=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "61cb0001", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Mean episode length: 86.34\tMean reward: -85.34\n" | |
] | |
} | |
], | |
"source": [ | |
"evaluate()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "17016d79", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Episode length: 78\n", | |
"Video saved to: Acrobot-v1.gif\n" | |
] | |
} | |
], | |
"source": [ | |
"record_video(env_name, 'gif')" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.14" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment