Skip to content

Instantly share code, notes, and snippets.

@radekosmulski
Created June 13, 2024 07:22
Show Gist options
  • Save radekosmulski/3dd1adf3da443a81aa6880a59db21687 to your computer and use it in GitHub Desktop.
Save radekosmulski/3dd1adf3da443a81aa6880a59db21687 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "802fbb4e",
"metadata": {},
"source": [
"# Getting our feet wet with Stable Baselines 3"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6b6e3cff",
"metadata": {},
"outputs": [],
"source": [
"# https://github.com/araffin/rl-tutorial-jnrr19\n",
"\n",
"import gymnasium as gym\n",
"import numpy as np\n",
"\n",
"from stable_baselines3 import PPO\n",
"from stable_baselines3.ppo.policies import MlpPolicy\n",
"\n",
"import imageio"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3ac98a6e",
"metadata": {},
"outputs": [],
"source": [
"env = gym.make(\"CartPole-v1\")\n",
"\n",
"model = PPO(MlpPolicy, env, verbose=0)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "782a1ccf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ActorCriticPolicy(\n",
" (features_extractor): FlattenExtractor(\n",
" (flatten): Flatten(start_dim=1, end_dim=-1)\n",
" )\n",
" (pi_features_extractor): FlattenExtractor(\n",
" (flatten): Flatten(start_dim=1, end_dim=-1)\n",
" )\n",
" (vf_features_extractor): FlattenExtractor(\n",
" (flatten): Flatten(start_dim=1, end_dim=-1)\n",
" )\n",
" (mlp_extractor): MlpExtractor(\n",
" (policy_net): Sequential(\n",
" (0): Linear(in_features=4, out_features=64, bias=True)\n",
" (1): Tanh()\n",
" (2): Linear(in_features=64, out_features=64, bias=True)\n",
" (3): Tanh()\n",
" )\n",
" (value_net): Sequential(\n",
" (0): Linear(in_features=4, out_features=64, bias=True)\n",
" (1): Tanh()\n",
" (2): Linear(in_features=64, out_features=64, bias=True)\n",
" (3): Tanh()\n",
" )\n",
" )\n",
" (action_net): Linear(in_features=64, out_features=2, bias=True)\n",
" (value_net): Linear(in_features=64, out_features=1, bias=True)\n",
")"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.policy"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "41b529e9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean_reward: 9.81 +/- 0.52\n"
]
}
],
"source": [
"# evaluating the model before training\n",
"from stable_baselines3.common.evaluation import evaluate_policy\n",
"\n",
"mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100, warn=False)\n",
"print(f\"mean_reward: {mean_reward:.2f} +/- {std_reward:.2f}\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "49695eb6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 11.2 s, sys: 162 ms, total: 11.4 s\n",
"Wall time: 11.4 s\n"
]
},
{
"data": {
"text/plain": [
"<stable_baselines3.ppo.ppo.PPO at 0x79850070dd20>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"\n",
"model.learn(total_timesteps=10_000)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1976908d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean_reward: 466.10 +/- 63.32\n",
"CPU times: user 19.8 s, sys: 2.89 ms, total: 19.8 s\n",
"Wall time: 19.7 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100, warn=False)\n",
"print(f\"mean_reward: {mean_reward:.2f} +/- {std_reward:.2f}\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e340c408",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean_reward: 500.00 +/- 0.00\n"
]
}
],
"source": [
"model.learn(total_timesteps=10_000)\n",
"\n",
"mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100, warn=False)\n",
"print(f\"mean_reward: {mean_reward:.2f} +/- {std_reward:.2f}\")"
]
},
{
"cell_type": "markdown",
"id": "4a3de575",
"metadata": {},
"source": [
"With just 20k steps we maxed out the score :)\n",
"\n",
"That makes for a really boring recording, but this is how to make one."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "4d7a5b6e",
"metadata": {},
"outputs": [],
"source": [
"def record_video(env_name='CartPole-v1', ext='mp4'):\n",
" env = gym.make(env_name, render_mode=\"rgb_array\")\n",
" obs, info = env.reset()\n",
"\n",
" images = [\n",
" env.render()\n",
" ]\n",
"\n",
" ep_len = 0\n",
" while True:\n",
" a, _states = model.predict(obs, deterministic=True)\n",
" obs, _, terminated, truncated, _ = env.step(a)\n",
" images.append(env.render())\n",
" ep_len += 1\n",
" if terminated or truncated:\n",
" obs, info = env.reset()\n",
" break\n",
"\n",
" print(f'Episode length: {ep_len}')\n",
" fn = f'{env_name}.{ext}'\n",
" imageio.mimsave(fn, images)\n",
" print(f'Video saved to: {fn}')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "6eca4114",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/radek/miniforge3/envs/cleanrl/lib/python3.10/site-packages/pkg_resources/__init__.py:121: DeprecationWarning: pkg_resources is deprecated as an API\n",
" warnings.warn(\"pkg_resources is deprecated as an API\", DeprecationWarning)\n",
"/home/radek/miniforge3/envs/cleanrl/lib/python3.10/site-packages/pkg_resources/__init__.py:2870: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`.\n",
"Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages\n",
" declare_namespace(pkg)\n",
"/home/radek/miniforge3/envs/cleanrl/lib/python3.10/site-packages/pkg_resources/__init__.py:2870: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('mpl_toolkits')`.\n",
"Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages\n",
" declare_namespace(pkg)\n",
"IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (600, 400) to (608, 400) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Episode length: 500\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[swscaler @ 0x704ff00] Warning: data is not aligned! This can lead to a speed loss\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Video saved to: CartPole-v1.mp4\n"
]
}
],
"source": [
"record_video()"
]
},
{
"cell_type": "markdown",
"id": "0a36255d",
"metadata": {},
"source": [
"# Training on a more complex problem"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "564e6b02",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using cuda device\n",
"Creating environment from the given name 'Acrobot-v1'\n",
"Wrapping the env with a `Monitor` wrapper\n",
"Wrapping the env in a DummyVecEnv.\n"
]
}
],
"source": [
"env_name = \"Acrobot-v1\"\n",
"model = PPO('MlpPolicy', env_name, verbose=1)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c73b98b5",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (500, 500) to (512, 512) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Episode length: 500\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[swscaler @ 0x67d5d80] Warning: data is not aligned! This can lead to a speed loss\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Video saved to: Acrobot-v1.mp4\n"
]
}
],
"source": [
"record_video(env_name)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "49285996",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(-500.0, 0.0)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"evaluate_policy(model, gym.make(env_name), n_eval_episodes=100, warn=False)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "bd2d343a",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080\"> 100%</span> <span style=\"color: #729c1f; text-decoration-color: #729c1f\">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span style=\"color: #008000; text-decoration-color: #008000\">20,478/20,000 </span> [ <span style=\"color: #808000; text-decoration-color: #808000\">0:00:24</span> &lt; <span style=\"color: #008080; text-decoration-color: #008080\">0:00:00</span> , <span style=\"color: #800000; text-decoration-color: #800000\">817 it/s</span> ]\n",
"</pre>\n"
],
"text/plain": [
"\u001b[35m 100%\u001b[0m \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20,478/20,000 \u001b[0m [ \u001b[33m0:00:24\u001b[0m < \u001b[36m0:00:00\u001b[0m , \u001b[31m817 it/s\u001b[0m ]\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 26 s, sys: 126 ms, total: 26.1 s\n",
"Wall time: 26 s\n"
]
},
{
"data": {
"text/plain": [
"<stable_baselines3.ppo.ppo.PPO at 0x7984d1913a30>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"\n",
"model.learn(20_000, progress_bar=True)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "6fa32157",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(-80.96, 12.161348609426506)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"evaluate_policy(model, gym.make(env_name), n_eval_episodes=100, warn=False)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "41a70522",
"metadata": {},
"outputs": [],
"source": [
"# are all these wrappers changing the behavior of the env? Let's try with a function I wrote myself\n",
"\n",
"def evaluate(n_runs=100):\n",
" env = gym.make(env_name, render_mode=\"rgb_array\")\n",
" obs, info = env.reset()\n",
" ep_lens = []\n",
" rewardss = []\n",
" for num_runs in range(n_runs):\n",
" ep_len = 0\n",
" rewards = 0\n",
" while True:\n",
" a, _states = model.predict(obs, deterministic=True)\n",
" obs, reward, terminated, truncated, _ = env.step(a)\n",
" ep_len += 1\n",
" rewards += reward\n",
" if terminated or truncated:\n",
" obs, info = env.reset()\n",
" break\n",
" ep_lens.append(ep_len)\n",
" rewardss.append(rewards)\n",
" return print(f'Mean episode length: {np.mean(ep_lens)}\\tMean reward: {np.mean(rewardss)}')"
]
},
{
"cell_type": "markdown",
"id": "6c5976e9",
"metadata": {},
"source": [
"https://gymnasium.farama.org/environments/classic_control/acrobot/#rewards\n",
"\n",
"> The goal is to have the free end reach a designated target height in as few steps as possible, and as such all steps that do not reach the goal incur a reward of -1. Achieving the target height results in termination with a reward of 0. The reward threshold is -100."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "2de63951",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean episode length: 82.91\tMean reward: -81.91\n"
]
}
],
"source": [
"# seems there are no fundamental differences introduced by using these wrappers\n",
"\n",
"evaluate()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "0383c997",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Episode length: 87\n",
"Video saved to: Acrobot-v1.gif\n"
]
}
],
"source": [
"record_video(env_name, 'gif')"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "e9883919",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080\"> 100%</span> <span style=\"color: #729c1f; text-decoration-color: #729c1f\">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span style=\"color: #008000; text-decoration-color: #008000\">20,374/20,000 </span> [ <span style=\"color: #808000; text-decoration-color: #808000\">0:00:26</span> &lt; <span style=\"color: #008080; text-decoration-color: #008080\">0:00:00</span> , <span style=\"color: #800000; text-decoration-color: #800000\">765 it/s</span> ]\n",
"</pre>\n"
],
"text/plain": [
"\u001b[35m 100%\u001b[0m \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20,374/20,000 \u001b[0m [ \u001b[33m0:00:26\u001b[0m < \u001b[36m0:00:00\u001b[0m , \u001b[31m765 it/s\u001b[0m ]\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 27.9 s, sys: 218 ms, total: 28.1 s\n",
"Wall time: 27.9 s\n"
]
},
{
"data": {
"text/plain": [
"<stable_baselines3.ppo.ppo.PPO at 0x7984d1913a30>"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"\n",
"model.learn(20_000, progress_bar=True)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "61cb0001",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean episode length: 86.34\tMean reward: -85.34\n"
]
}
],
"source": [
"evaluate()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "17016d79",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Episode length: 78\n",
"Video saved to: Acrobot-v1.gif\n"
]
}
],
"source": [
"record_video(env_name, 'gif')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment